@@ -636,13 +636,20 @@ class GrammarCompilerNoCache {
636636 private:
637637 /* ! \brief The main logic. Compile the grammar with multi-threading. */
638638 CompiledGrammar MultiThreadCompileGrammar (Grammar grammar);
639+ /* ! \brief Optimization for TagDispatch.
640+ * \param compiled_grammar_impl the compiled_grammar to be optimized.
641+ * \param tag_dispatch_rule_id_to_second_slicing_bitset Return value. Mapping from the rule_id to
642+ * the definite accepted token mask.
643+ */
644+ void TagDispatchOptimization (
645+ std::shared_ptr<CompiledGrammar::Impl> compiled_grammar_impl,
646+ std::unordered_map<int32_t , DynamicBitset>* tag_dispatch_rule_id_to_second_slicing_bitset
647+ );
639648
640649 /* ! \brief The vocabulary associated with this storage class. */
641650 const TokenizerInfo tokenizer_info_;
642651 /* ! \brief The maximum number of threads to use. */
643652 const int max_threads_;
644- /* ! \brief Mapping from the rule_id to the definite accepted token mask. */
645- std::unordered_map<int32_t , DynamicBitset> tag_dispatch_rule_id_to_second_slicing_bitset;
646653};
647654
648655CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar (Grammar grammar_unoptimized) {
@@ -655,48 +662,8 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma
655662 if (tokenizer_info_.GetVocabSize () == 0 ) {
656663 return CompiledGrammar (compiled_grammar_impl);
657664 }
658-
659- // Optimization for TagDispatch: Precompute the definitely accepted tokens.
660- for (int i = 0 ; i < compiled_grammar_impl->grammar ->NumRules (); i++) {
661- const auto & rule = compiled_grammar_impl->grammar ->GetRule (i);
662- const auto & rule_body = compiled_grammar_impl->grammar ->GetGrammarExpr (rule.body_expr_id );
663- if (rule_body.type != GrammarExprType::kTagDispatch ) {
664- continue ;
665- }
666- XGRAMMAR_DCHECK (rule_body.type == GrammarExprType::kTagDispatch );
667- Grammar::Impl::TagDispatch tag_dispatch =
668- compiled_grammar_impl->grammar ->GetTagDispatch (rule.body_expr_id );
669- const auto & sorted_decoded_vocab = tokenizer_info_.GetSortedDecodedVocab ();
670- DynamicBitset definite_accepted_tokens_since_second_char (sorted_decoded_vocab.size ());
671- for (int i = 0 ; i < static_cast <int32_t >(sorted_decoded_vocab.size ()); i++) {
672- bool definite_accept_since_second_char = true ;
673- const auto & token = sorted_decoded_vocab[i].second ;
674- if (token.empty ()) {
675- definite_accepted_tokens_since_second_char.Set (i);
676- continue ;
677- }
678-
679- // Check if the token contains any tag or stop string after the first character.
680- for (const auto & tag : tag_dispatch.tag_rule_pairs ) {
681- if (token.find (tag.first , 1 ) != std::string::npos) {
682- definite_accept_since_second_char = false ;
683- break ;
684- }
685- }
686- for (const auto & stop_str : tag_dispatch.stop_str ) {
687- if (token.find (stop_str, 1 ) != std::string::npos) {
688- definite_accept_since_second_char = false ;
689- break ;
690- }
691- }
692-
693- // If the token can be definitely accepted since the second character, set the bit.
694- if (definite_accept_since_second_char) {
695- definite_accepted_tokens_since_second_char.Set (i);
696- }
697- }
698- tag_dispatch_rule_id_to_second_slicing_bitset[i] = definite_accepted_tokens_since_second_char;
699- }
665+ std::unordered_map<int32_t , DynamicBitset> tag_dispatch_rule_id_to_second_slicing_bitset;
666+ TagDispatchOptimization (compiled_grammar_impl, &tag_dispatch_rule_id_to_second_slicing_bitset);
700667 // Step 3. Compute the adaptive token mask cache
701668 // The token mask cache is computed for these positions in the grammar:
702669 // 1. All character class or character class star (with last_utf8_bytes=0, 1, 2, 3)
@@ -842,6 +809,57 @@ CompiledGrammar GrammarCompilerNoCache::CompileGrammar(
842809 return MultiThreadCompileGrammar (Grammar::FromEBNF (ebnf_str, root_rule_name));
843810}
844811
812+ void GrammarCompilerNoCache::TagDispatchOptimization (
813+ std::shared_ptr<CompiledGrammar::Impl> compiled_grammar_impl,
814+ std::unordered_map<int32_t , DynamicBitset>* tag_dispatch_rule_id_to_second_slicing_bitset
815+ ) {
816+ using GrammarExprType = Grammar::Impl::GrammarExprType;
817+ tag_dispatch_rule_id_to_second_slicing_bitset->clear ();
818+
819+ // Optimization for TagDispatch: Precompute the definitely accepted tokens.
820+ for (int i = 0 ; i < compiled_grammar_impl->grammar ->NumRules (); i++) {
821+ const auto & rule = compiled_grammar_impl->grammar ->GetRule (i);
822+ const auto & rule_body = compiled_grammar_impl->grammar ->GetGrammarExpr (rule.body_expr_id );
823+ if (rule_body.type != GrammarExprType::kTagDispatch ) {
824+ continue ;
825+ }
826+ XGRAMMAR_DCHECK (rule_body.type == GrammarExprType::kTagDispatch );
827+ Grammar::Impl::TagDispatch tag_dispatch =
828+ compiled_grammar_impl->GetGrammar ()->GetTagDispatch (rule.body_expr_id );
829+ const auto & sorted_decoded_vocab = tokenizer_info_.GetSortedDecodedVocab ();
830+ DynamicBitset definite_accepted_tokens_since_second_char (sorted_decoded_vocab.size ());
831+ for (int i = 0 ; i < static_cast <int32_t >(sorted_decoded_vocab.size ()); i++) {
832+ bool definite_accept_since_second_char = true ;
833+ const auto & token = sorted_decoded_vocab[i].second ;
834+ if (token.empty ()) {
835+ definite_accepted_tokens_since_second_char.Set (i);
836+ continue ;
837+ }
838+
839+ // Check if the token contains any tag or stop string after the first character.
840+ for (const auto & tag : tag_dispatch.tag_rule_pairs ) {
841+ if (token.find (tag.first , 1 ) != std::string::npos) {
842+ definite_accept_since_second_char = false ;
843+ break ;
844+ }
845+ }
846+ for (const auto & stop_str : tag_dispatch.stop_str ) {
847+ if (token.find (stop_str, 1 ) != std::string::npos) {
848+ definite_accept_since_second_char = false ;
849+ break ;
850+ }
851+ }
852+
853+ // If the token can be definitely accepted since the second character, set the bit.
854+ if (definite_accept_since_second_char) {
855+ definite_accepted_tokens_since_second_char.Set (i);
856+ }
857+ }
858+ (*tag_dispatch_rule_id_to_second_slicing_bitset)[i] =
859+ definite_accepted_tokens_since_second_char;
860+ }
861+ }
862+
845863/* ****************** GrammarCompiler::Impl *******************/
846864
847865/* !
0 commit comments