@@ -645,16 +645,13 @@ class GrammarCompilerNoCache {
645
645
std::unordered_map<int32_t , DynamicBitset> tag_dispatch_rule_id_to_second_slicing_bitset;
646
646
};
647
647
648
- CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar (Grammar grammar ) {
648
+ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar (Grammar grammar_unoptimized ) {
649
649
using GrammarExprType = Grammar::Impl::GrammarExprType;
650
650
651
651
auto compiled_grammar_impl = std::make_shared<CompiledGrammar::Impl>();
652
652
653
- compiled_grammar_impl->grammar = grammar ;
653
+ compiled_grammar_impl->grammar = GrammarOptimizer::Apply (grammar_unoptimized) ;
654
654
compiled_grammar_impl->tokenizer_info = tokenizer_info_;
655
- grammar->allow_empty_rule_ids = AllowEmptyRuleAnalyzer::Apply (compiled_grammar_impl->grammar );
656
- RepetitionNormalizer::Apply (&compiled_grammar_impl->grammar );
657
- GrammarFSMBuilder::Apply (&compiled_grammar_impl->grammar );
658
655
if (tokenizer_info_.GetVocabSize () == 0 ) {
659
656
return CompiledGrammar (compiled_grammar_impl);
660
657
}
@@ -711,15 +708,14 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma
711
708
// not need ThreadPool or std::mutex, which throws error in runtime in WebAssembly.
712
709
std::optional<ThreadPool> thread_pool;
713
710
std::optional<std::mutex> adaptive_token_mask_cache_mutex;
714
-
715
711
if (max_threads_ > 1 ) {
716
712
thread_pool.emplace (max_threads_);
717
713
adaptive_token_mask_cache_mutex.emplace ();
718
714
}
719
715
720
716
auto add_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) {
721
717
auto grammar_matcher = GrammarMatcherForTokenMaskCache (
722
- grammar, state, tag_dispatch_rule_id_to_second_slicing_bitset, false
718
+ compiled_grammar_impl-> grammar , state, tag_dispatch_rule_id_to_second_slicing_bitset, false
723
719
);
724
720
auto cur_adaptive_token_mask_cache = grammar_matcher.GetAdaptiveTokenMask (
725
721
tokenizer_info_.GetVocabSize (),
@@ -746,12 +742,13 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma
746
742
}
747
743
};
748
744
749
- auto root_rule_id = grammar->GetRootRuleId ();
745
+ auto root_rule_id = compiled_grammar_impl-> grammar ->GetRootRuleId ();
750
746
751
- for (int32_t rule_id = 0 ; rule_id < static_cast <int >(grammar->NumRules ()); ++rule_id) {
752
- auto rule = grammar->GetRule (rule_id);
753
- auto rule_body = grammar->GetGrammarExpr (rule.body_expr_id );
754
- const auto & rule_fsm = grammar->per_rule_fsms [rule_id];
747
+ for (int32_t rule_id = 0 ; rule_id < static_cast <int >(compiled_grammar_impl->grammar ->NumRules ());
748
+ ++rule_id) {
749
+ auto rule = compiled_grammar_impl->grammar ->GetRule (rule_id);
750
+ auto rule_body = compiled_grammar_impl->grammar ->GetGrammarExpr (rule.body_expr_id );
751
+ const auto & rule_fsm = compiled_grammar_impl->grammar ->per_rule_fsms [rule_id];
755
752
if (rule_fsm.has_value ()) {
756
753
auto cur_stack_element =
757
754
ParserState (rule_id, rule.body_expr_id , 0 , ParserState::kNoPrevInputPos , 0 );
@@ -768,15 +765,15 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma
768
765
}
769
766
XGRAMMAR_DCHECK (rule_body.type == GrammarExprType::kChoices );
770
767
for (auto sequence_id : rule_body) {
771
- const auto & sequence = grammar->GetGrammarExpr (sequence_id);
768
+ const auto & sequence = compiled_grammar_impl-> grammar ->GetGrammarExpr (sequence_id);
772
769
if (sequence.type == GrammarExprType::kEmptyStr ) {
773
770
continue ;
774
771
}
775
772
XGRAMMAR_DCHECK (sequence.type == GrammarExprType::kSequence );
776
773
auto state = ParserState (rule_id, sequence_id, 0 , ParserState::kNoPrevInputPos , 0 );
777
774
for (int element_id = 0 ; element_id < sequence.size (); ++element_id) {
778
775
state.element_id = element_id;
779
- auto element = grammar->GetGrammarExpr (sequence[element_id]);
776
+ auto element = compiled_grammar_impl-> grammar ->GetGrammarExpr (sequence[element_id]);
780
777
if (element.type == GrammarExprType::kRuleRef || element.type == GrammarExprType::kRepeat ) {
781
778
continue ;
782
779
}
0 commit comments