10
10
#include < cctype>
11
11
#include < cstddef>
12
12
#include < cstdint>
13
+ #include < optional>
14
+ #include < unordered_map>
13
15
#include < utility>
14
16
#include < variant>
15
17
#include < vector>
19
21
#include " fsm.h"
20
22
#include " grammar_functor.h"
21
23
#include " grammar_impl.h"
24
+ #include " support/dynamic_bitset.h"
22
25
#include " support/logging.h"
23
26
#include " support/thread_pool.h"
24
27
#include " support/thread_safe_cache.h"
25
28
#include " support/utils.h"
29
+ #include " xgrammar/grammar.h"
26
30
27
31
namespace xgrammar {
28
32
@@ -32,11 +36,17 @@ namespace xgrammar {
32
36
class GrammarMatcherForTokenMaskCache : public EarleyParser {
33
37
public:
34
38
GrammarMatcherForTokenMaskCache (
35
- const Grammar& grammar, const ParserState& init_state, const bool & need_expand = true
39
+ const Grammar& grammar,
40
+ const ParserState& init_state,
41
+ const std::unordered_map<int32_t , DynamicBitset>&
42
+ tag_dispatch_rule_id_to_second_slicing_bitset,
43
+ const bool & need_expand = true
36
44
)
37
45
: EarleyParser(grammar, init_state),
38
46
init_rule_id (init_state.rule_id),
39
- initial_state(init_state) {}
47
+ initial_state(init_state),
48
+ tag_dispatch_rule_id_to_second_slicing_bitset(tag_dispatch_rule_id_to_second_slicing_bitset
49
+ ) {}
40
50
/* !
41
51
* \brief Get the adaptive token mask for the given ParserState.
42
52
* \param is_root_rule Whether to consider the parent rule. If false, there will be
@@ -87,6 +97,17 @@ class GrammarMatcherForTokenMaskCache : public EarleyParser {
87
97
// The initial state of the parser.
88
98
ParserState initial_state;
89
99
100
+ /* !
101
+ \brief This is a mapping from TagDispatch rule id to the bitset used for second slicing.
102
+ \note If a rule is a TagDispatch rule, then there will be an AC automaton for its triggers.
103
+ Which means that it can accept a lot of tokens. However, it will be slow to check a lot of
104
+ tokens. The DynamicBitset here is used to do a second slicing: if a token's substr(1, n - 1)
105
+ can be accepted by the start state of the AC automaton, then it will be True in the bitset.
106
+ When we check a token, we first check if its first character can transit to the start state.
107
+ If yes, then we check if it is in the bitset. If yes, then we accept it directly.
108
+ */
109
+ const std::unordered_map<int32_t , DynamicBitset>& tag_dispatch_rule_id_to_second_slicing_bitset;
110
+
90
111
// Temporary data for GetAdaptiveTokenMask.
91
112
std::vector<int32_t > tmp_accepted_indices_;
92
113
std::vector<int32_t > tmp_rejected_indices_;
@@ -218,6 +239,27 @@ std::pair<bool, std::bitset<256>> GrammarMatcherForTokenMaskCache::GetSpeculativ
218
239
const std::vector<std::pair<int32_t , std::string>>& sorted_decoded_vocab
219
240
) {
220
241
using GrammarExprType = Grammar::Impl::GrammarExprType;
242
+ // If the initial rule is a tag dispatch, we will check if it can achieve its initial state.
243
+ const auto & rule = grammar_->GetRule (init_rule_id);
244
+ const auto & rule_body = grammar_->GetGrammarExpr (rule.body_expr_id );
245
+ if (rule_body.type == GrammarExprType::kTagDispatch ) {
246
+ std::bitset<256 > speculative_mask;
247
+ XGRAMMAR_DCHECK (grammar_->per_rule_fsms [init_rule_id].has_value ());
248
+ const auto & fsm = grammar_->per_rule_fsms [init_rule_id].value ();
249
+ for (const auto & edge : fsm.GetFsm ().GetEdges (initial_state.element_id )) {
250
+ if (edge.target != fsm.GetStart ()) {
251
+ continue ;
252
+ }
253
+ if (!edge.IsCharRange ()) {
254
+ continue ;
255
+ }
256
+ for (int32_t ch = edge.min ; ch <= edge.max ; ++ch) {
257
+ speculative_mask.set (ch);
258
+ }
259
+ }
260
+ return {true , speculative_mask};
261
+ }
262
+
221
263
// Check if the initial state is self-recursive-like. If the state is self-recursive-like,
222
264
// and it covers a large part of the vocabulary, we will do speculative calculation in compiling.
223
265
if (!grammar_->per_rule_fsms [init_rule_id].has_value ()) {
@@ -317,6 +359,15 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck(
317
359
int prev_matched_size = 0 ;
318
360
int last_rejected_range = 0 ;
319
361
const bool & is_exact_lookahead = grammar_->GetRule (init_rule_id).is_exact_lookahead ;
362
+ std::optional<const DynamicBitset*> definite_accepted_bitset = std::nullopt ;
363
+ const bool is_tag_dispatch_rule =
364
+ grammar_->GetGrammarExpr (grammar_->GetRule (init_rule_id).body_expr_id ).type ==
365
+ Grammar::Impl::GrammarExprType::kTagDispatch ;
366
+ if (is_tag_dispatch_rule) {
367
+ XGRAMMAR_DCHECK (tag_dispatch_rule_id_to_second_slicing_bitset.count (init_rule_id) > 0 );
368
+ definite_accepted_bitset = &tag_dispatch_rule_id_to_second_slicing_bitset.at (init_rule_id);
369
+ }
370
+
320
371
const std::string* prev_token = nullptr ;
321
372
for (size_t interval_idx = 0 ; interval_idx < possible_intervals.size (); ++interval_idx) {
322
373
const auto & interval = possible_intervals[interval_idx];
@@ -339,18 +390,35 @@ bool GrammarMatcherForTokenMaskCache::GetTokenMaskWithFirstCharacterCheck(
339
390
const auto & token = sorted_decoded_vocab[i].second ;
340
391
// This optimization is useful for simple self-recursive rules, like string content.
341
392
if (speculative_calculation) {
342
- bool all_accepted = true ;
343
- for (char ch : token) {
344
- // If the first character is not the ascii character or can't be accepted by the
345
- // first character mask, we need to check them in the parser.
346
- if (isascii (ch) == 0 || !speculative_mask[static_cast <uint8_t >(ch)]) {
347
- all_accepted = false ;
348
- break ;
393
+ // Optimization for tag dispatch rules.
394
+ if (definite_accepted_bitset.has_value ()) {
395
+ // If the token is empty, it must be accepted.
396
+ if (token.empty ()) {
397
+ tmp_accepted_indices_.push_back (i);
398
+ continue ;
399
+ }
400
+ // If the token doesn't contain tags or stop strings since the second character, and it
401
+ // will transit to the start state after consuming the first character, it must be
402
+ // accepted.
403
+ if (speculative_mask[static_cast <uint8_t >(token[0 ])] &&
404
+ (*definite_accepted_bitset.value ())[i]) {
405
+ tmp_accepted_indices_.push_back (i);
406
+ continue ;
407
+ }
408
+ } else {
409
+ bool all_accepted = true ;
410
+ for (char ch : token) {
411
+ // If the first character is not the ascii character or can't be accepted by the
412
+ // first character mask, we need to check them in the parser.
413
+ if (isascii (ch) == 0 || !speculative_mask[static_cast <uint8_t >(ch)]) {
414
+ all_accepted = false ;
415
+ break ;
416
+ }
417
+ }
418
+ if (all_accepted) {
419
+ tmp_accepted_indices_.push_back (i);
420
+ continue ;
349
421
}
350
- }
351
- if (all_accepted) {
352
- tmp_accepted_indices_.push_back (i);
353
- continue ;
354
422
}
355
423
}
356
424
// Many tokens may contain the same prefix, so we will avoid unnecessary matching
@@ -573,6 +641,8 @@ class GrammarCompilerNoCache {
573
641
const TokenizerInfo tokenizer_info_;
574
642
/* ! \brief The maximum number of threads to use. */
575
643
const int max_threads_;
644
+ /* ! \brief Mapping from the rule_id to the definite accepted token mask. */
645
+ std::unordered_map<int32_t , DynamicBitset> tag_dispatch_rule_id_to_second_slicing_bitset;
576
646
};
577
647
578
648
CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar (Grammar grammar) {
@@ -588,6 +658,48 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma
588
658
if (tokenizer_info_.GetVocabSize () == 0 ) {
589
659
return CompiledGrammar (compiled_grammar_impl);
590
660
}
661
+
662
+ // Optimization for TagDispatch: Precompute the definitely accepted tokens.
663
+ for (int i = 0 ; i < compiled_grammar_impl->grammar ->NumRules (); i++) {
664
+ const auto & rule = compiled_grammar_impl->grammar ->GetRule (i);
665
+ const auto & rule_body = compiled_grammar_impl->grammar ->GetGrammarExpr (rule.body_expr_id );
666
+ if (rule_body.type != GrammarExprType::kTagDispatch ) {
667
+ continue ;
668
+ }
669
+ XGRAMMAR_DCHECK (rule_body.type == GrammarExprType::kTagDispatch );
670
+ Grammar::Impl::TagDispatch tag_dispatch =
671
+ compiled_grammar_impl->grammar ->GetTagDispatch (rule.body_expr_id );
672
+ const auto & sorted_decoded_vocab = tokenizer_info_.GetSortedDecodedVocab ();
673
+ DynamicBitset definite_accepted_tokens_since_second_char (sorted_decoded_vocab.size ());
674
+ for (int i = 0 ; i < static_cast <int32_t >(sorted_decoded_vocab.size ()); i++) {
675
+ bool definite_accept_since_second_char = true ;
676
+ const auto & token = sorted_decoded_vocab[i].second ;
677
+ if (token.empty ()) {
678
+ definite_accepted_tokens_since_second_char.Set (i);
679
+ continue ;
680
+ }
681
+
682
+ // Check if the token contains any tag or stop string after the first character.
683
+ for (const auto & tag : tag_dispatch.tag_rule_pairs ) {
684
+ if (token.find (tag.first , 1 ) != std::string::npos) {
685
+ definite_accept_since_second_char = false ;
686
+ break ;
687
+ }
688
+ }
689
+ for (const auto & stop_str : tag_dispatch.stop_str ) {
690
+ if (token.find (stop_str, 1 ) != std::string::npos) {
691
+ definite_accept_since_second_char = false ;
692
+ break ;
693
+ }
694
+ }
695
+
696
+ // If the token can be definitely accepted since the second character, set the bit.
697
+ if (definite_accept_since_second_char) {
698
+ definite_accepted_tokens_since_second_char.Set (i);
699
+ }
700
+ }
701
+ tag_dispatch_rule_id_to_second_slicing_bitset[i] = definite_accepted_tokens_since_second_char;
702
+ }
591
703
// Step 3. Compute the adaptive token mask cache
592
704
// The token mask cache is computed for these positions in the grammar:
593
705
// 1. All character class or character class star (with last_utf8_bytes=0, 1, 2, 3)
@@ -606,7 +718,9 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma
606
718
}
607
719
608
720
auto add_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) {
609
- auto grammar_matcher = GrammarMatcherForTokenMaskCache (grammar, state, false );
721
+ auto grammar_matcher = GrammarMatcherForTokenMaskCache (
722
+ grammar, state, tag_dispatch_rule_id_to_second_slicing_bitset, false
723
+ );
610
724
auto cur_adaptive_token_mask_cache = grammar_matcher.GetAdaptiveTokenMask (
611
725
tokenizer_info_.GetVocabSize (),
612
726
tokenizer_info_.GetSortedDecodedVocab (),
0 commit comments