Skip to content

Commit 39b8aa7

Browse files
[Optim] Optimize the efficiency of TagDispatch expressions. (#431)
This PR optimizes the efficiency of TagDispatch expressions. In the initial state, only tokens that contain tags or stop strings can possibly be constrained; all other tokens are definitely accepted. This PR takes advantage of that fact to further improve efficiency. --------- Signed-off-by: Yuchuan <[email protected]>
1 parent 911a7a3 commit 39b8aa7

File tree

1 file changed

+128
-14
lines changed

1 file changed

+128
-14
lines changed

cpp/grammar_compiler.cc

Lines changed: 128 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
#include <cctype>
1111
#include <cstddef>
1212
#include <cstdint>
13+
#include <optional>
14+
#include <unordered_map>
1315
#include <utility>
1416
#include <variant>
1517
#include <vector>
@@ -19,10 +21,12 @@
1921
#include "fsm.h"
2022
#include "grammar_functor.h"
2123
#include "grammar_impl.h"
24+
#include "support/dynamic_bitset.h"
2225
#include "support/logging.h"
2326
#include "support/thread_pool.h"
2427
#include "support/thread_safe_cache.h"
2528
#include "support/utils.h"
29+
#include "xgrammar/grammar.h"
2630

2731
namespace xgrammar {
2832

@@ -32,11 +36,17 @@ namespace xgrammar {
3236
class GrammarMatcherForTokenMaskCache : public EarleyParser {
3337
public:
3438
GrammarMatcherForTokenMaskCache(
35-
const Grammar& grammar, const ParserState& init_state, const bool& need_expand = true
39+
const Grammar& grammar,
40+
const ParserState& init_state,
41+
const std::unordered_map<int32_t, DynamicBitset>&
42+
tag_dispatch_rule_id_to_second_slicing_bitset,
43+
const bool& need_expand = true
3644
)
3745
: EarleyParser(grammar, init_state),
3846
init_rule_id(init_state.rule_id),
39-
initial_state(init_state) {}
47+
initial_state(init_state),
48+
tag_dispatch_rule_id_to_second_slicing_bitset(tag_dispatch_rule_id_to_second_slicing_bitset
49+
) {}
4050
/*!
4151
* \brief Get the adaptive token mask for the given ParserState.
4252
* \param is_root_rule Whether to consider the parent rule. If false, there will be
@@ -87,6 +97,17 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
8797
// The initial state of the parser.
8898
ParserState initial_state;
8999

100+
/*!
101+
\brief A mapping from TagDispatch rule id to the bitset used for second slicing.
102+
\note If a rule is a TagDispatch rule, then there will be an AC automaton for its triggers,
103+
which means that it can accept a lot of tokens. However, checking that many tokens one by
104+
one is slow. The DynamicBitset here is used to do a second slicing: if a token's substr(1, n - 1)
105+
can be accepted from the start state of the AC automaton, then its bit is set in the bitset.
106+
When we check a token, we first check whether its first character leads to the start state.
107+
If so, we then check whether its bit is set in the bitset; if it is, we accept it directly.
108+
*/
109+
const std::unordered_map<int32_t, DynamicBitset>& tag_dispatch_rule_id_to_second_slicing_bitset;
110+
90111
// Temporary data for GetAdaptiveTokenMask.
91112
std::vector<int32_t> tmp_accepted_indices_;
92113
std::vector<int32_t> tmp_rejected_indices_;
@@ -218,6 +239,27 @@ std::pair<bool, std::bitset<256>> GrammarMatcherForTokenMaskCache::GetSpeculativ
218239
const std::vector<std::pair<int32_t, std::string>>& sorted_decoded_vocab
219240
) {
220241
using GrammarExprType = Grammar::Impl::GrammarExprType;
242+
// If the initial rule is a tag dispatch, collect the characters that lead from the current state back to the FSM's start state.
243+
const auto& rule = grammar_->GetRule(init_rule_id);
244+
const auto& rule_body = grammar_->GetGrammarExpr(rule.body_expr_id);
245+
if (rule_body.type == GrammarExprType::kTagDispatch) {
246+
std::bitset<256> speculative_mask;
247+
XGRAMMAR_DCHECK(grammar_->per_rule_fsms[init_rule_id].has_value());
248+
const auto& fsm = grammar_->per_rule_fsms[init_rule_id].value();
249+
for (const auto& edge : fsm.GetFsm().GetEdges(initial_state.element_id)) {
250+
if (edge.target != fsm.GetStart()) {
251+
continue;
252+
}
253+
if (!edge.IsCharRange()) {
254+
continue;
255+
}
256+
for (int32_t ch = edge.min; ch <= edge.max; ++ch) {
257+
speculative_mask.set(ch);
258+
}
259+
}
260+
return {true, speculative_mask};
261+
}
262+
221263
// Check if the initial state is self-recursive-like. If the state is self-recursive-like,
222264
// and it covers a large part of the vocabulary, we will do speculative calculation in compiling.
223265
if (!grammar_->per_rule_fsms[init_rule_id].has_value()) {
@@ -317,6 +359,15 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck(
317359
int prev_matched_size = 0;
318360
int last_rejected_range = 0;
319361
const bool& is_exact_lookahead = grammar_->GetRule(init_rule_id).is_exact_lookahead;
362+
std::optional<const DynamicBitset*> definite_accepted_bitset = std::nullopt;
363+
const bool is_tag_dispatch_rule =
364+
grammar_->GetGrammarExpr(grammar_->GetRule(init_rule_id).body_expr_id).type ==
365+
Grammar::Impl::GrammarExprType::kTagDispatch;
366+
if (is_tag_dispatch_rule) {
367+
XGRAMMAR_DCHECK(tag_dispatch_rule_id_to_second_slicing_bitset.count(init_rule_id) > 0);
368+
definite_accepted_bitset = &tag_dispatch_rule_id_to_second_slicing_bitset.at(init_rule_id);
369+
}
370+
320371
const std::string* prev_token = nullptr;
321372
for (size_t interval_idx = 0; interval_idx < possible_intervals.size(); ++interval_idx) {
322373
const auto& interval = possible_intervals[interval_idx];
@@ -339,18 +390,35 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck(
339390
const auto& token = sorted_decoded_vocab[i].second;
340391
// This optimization is useful for simple self-recursive rules, like string content.
341392
if (speculative_calculation) {
342-
bool all_accepted = true;
343-
for (char ch : token) {
344-
// If the first character is not the ascii character or can't be accepted by the
345-
// first character mask, we need to check them in the parser.
346-
if (isascii(ch) == 0 || !speculative_mask[static_cast<uint8_t>(ch)]) {
347-
all_accepted = false;
348-
break;
393+
// Optimization for tag dispatch rules.
394+
if (definite_accepted_bitset.has_value()) {
395+
// If the token is empty, it must be accepted.
396+
if (token.empty()) {
397+
tmp_accepted_indices_.push_back(i);
398+
continue;
399+
}
400+
// If the token contains no tag or stop string from the second character onward, and
401+
// consuming its first character leads back to the start state, it must be
402+
// accepted.
403+
if (speculative_mask[static_cast<uint8_t>(token[0])] &&
404+
(*definite_accepted_bitset.value())[i]) {
405+
tmp_accepted_indices_.push_back(i);
406+
continue;
407+
}
408+
} else {
409+
bool all_accepted = true;
410+
for (char ch : token) {
411+
// If any character is not an ASCII character or is not accepted by the
412+
// speculative mask, we need to check the token in the parser.
413+
if (isascii(ch) == 0 || !speculative_mask[static_cast<uint8_t>(ch)]) {
414+
all_accepted = false;
415+
break;
416+
}
417+
}
418+
if (all_accepted) {
419+
tmp_accepted_indices_.push_back(i);
420+
continue;
349421
}
350-
}
351-
if (all_accepted) {
352-
tmp_accepted_indices_.push_back(i);
353-
continue;
354422
}
355423
}
356424
// Many tokens may contain the same prefix, so we will avoid unnecessary matching
@@ -573,6 +641,8 @@ class GrammarCompilerNoCache {
573641
const TokenizerInfo tokenizer_info_;
574642
/*! \brief The maximum number of threads to use. */
575643
const int max_threads_;
644+
/*! \brief Mapping from a TagDispatch rule id to the bitset of tokens that are definitely accepted from their second character onward. */
645+
std::unordered_map<int32_t, DynamicBitset> tag_dispatch_rule_id_to_second_slicing_bitset;
576646
};
577647

578648
CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar grammar) {
@@ -588,6 +658,48 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma
588658
if (tokenizer_info_.GetVocabSize() == 0) {
589659
return CompiledGrammar(compiled_grammar_impl);
590660
}
661+
662+
// Optimization for TagDispatch: Precompute the definitely accepted tokens.
663+
for (int i = 0; i < compiled_grammar_impl->grammar->NumRules(); i++) {
664+
const auto& rule = compiled_grammar_impl->grammar->GetRule(i);
665+
const auto& rule_body = compiled_grammar_impl->grammar->GetGrammarExpr(rule.body_expr_id);
666+
if (rule_body.type != GrammarExprType::kTagDispatch) {
667+
continue;
668+
}
669+
XGRAMMAR_DCHECK(rule_body.type == GrammarExprType::kTagDispatch);
670+
Grammar::Impl::TagDispatch tag_dispatch =
671+
compiled_grammar_impl->grammar->GetTagDispatch(rule.body_expr_id);
672+
const auto& sorted_decoded_vocab = tokenizer_info_.GetSortedDecodedVocab();
673+
DynamicBitset definite_accepted_tokens_since_second_char(sorted_decoded_vocab.size());
674+
for (int i = 0; i < static_cast<int32_t>(sorted_decoded_vocab.size()); i++) {
675+
bool definite_accept_since_second_char = true;
676+
const auto& token = sorted_decoded_vocab[i].second;
677+
if (token.empty()) {
678+
definite_accepted_tokens_since_second_char.Set(i);
679+
continue;
680+
}
681+
682+
// Check if the token contains any tag or stop string after the first character.
683+
for (const auto& tag : tag_dispatch.tag_rule_pairs) {
684+
if (token.find(tag.first, 1) != std::string::npos) {
685+
definite_accept_since_second_char = false;
686+
break;
687+
}
688+
}
689+
for (const auto& stop_str : tag_dispatch.stop_str) {
690+
if (token.find(stop_str, 1) != std::string::npos) {
691+
definite_accept_since_second_char = false;
692+
break;
693+
}
694+
}
695+
696+
// If the token is definitely accepted from the second character onward, set the bit.
697+
if (definite_accept_since_second_char) {
698+
definite_accepted_tokens_since_second_char.Set(i);
699+
}
700+
}
701+
tag_dispatch_rule_id_to_second_slicing_bitset[i] = definite_accepted_tokens_since_second_char;
702+
}
591703
// Step 3. Compute the adaptive token mask cache
592704
// The token mask cache is computed for these positions in the grammar:
593705
// 1. All character class or character class star (with last_utf8_bytes=0, 1, 2, 3)
@@ -606,7 +718,9 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma
606718
}
607719

608720
auto add_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) {
609-
auto grammar_matcher = GrammarMatcherForTokenMaskCache(grammar, state, false);
721+
auto grammar_matcher = GrammarMatcherForTokenMaskCache(
722+
grammar, state, tag_dispatch_rule_id_to_second_slicing_bitset, false
723+
);
610724
auto cur_adaptive_token_mask_cache = grammar_matcher.GetAdaptiveTokenMask(
611725
tokenizer_info_.GetVocabSize(),
612726
tokenizer_info_.GetSortedDecodedVocab(),

0 commit comments

Comments
 (0)