Skip to content

Commit 9a15539

Browse files
[Refac] Refactor the pipeline. (#428)
This is a rebased version of #383. This PR refactors the pipeline. The functors in `grammar_functor` are divided into three types:
- `grammar_normalizer`: functors that should be called immediately when a grammar is constructed.
- `grammar_optimizer`: functors that are called when a grammar is about to be compiled.
- `grammar_constructor`: functors used to construct a new grammar, such as the union of two grammars.
---------
Signed-off-by: Yuchuan <[email protected]>
1 parent c487a83 commit 9a15539

14 files changed

+617
-382
lines changed

cpp/earley_parser.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,10 @@ EarleyParser::EarleyParser(
241241
const Grammar& grammar, const ParserState& init_state, const bool need_expand
242242
)
243243
: grammar_(grammar) {
244+
if (!grammar->optimized) {
245+
XGRAMMAR_LOG(FATAL) << "The grammar is not optimized. Please optimize the grammar before using "
246+
"the Earley parser.";
247+
}
244248
// Check if the initial state is valid. If invalid, then we choose the root state as default.
245249
ParserState init = init_state;
246250
if (init_state.IsInvalid()) {

cpp/grammar_compiler.cc

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -645,16 +645,13 @@ class GrammarCompilerNoCache {
645645
std::unordered_map<int32_t, DynamicBitset> tag_dispatch_rule_id_to_second_slicing_bitset;
646646
};
647647

648-
CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar grammar) {
648+
CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar grammar_unoptimized) {
649649
using GrammarExprType = Grammar::Impl::GrammarExprType;
650650

651651
auto compiled_grammar_impl = std::make_shared<CompiledGrammar::Impl>();
652652

653-
compiled_grammar_impl->grammar = grammar;
653+
compiled_grammar_impl->grammar = GrammarOptimizer::Apply(grammar_unoptimized);
654654
compiled_grammar_impl->tokenizer_info = tokenizer_info_;
655-
grammar->allow_empty_rule_ids = AllowEmptyRuleAnalyzer::Apply(compiled_grammar_impl->grammar);
656-
RepetitionNormalizer::Apply(&compiled_grammar_impl->grammar);
657-
GrammarFSMBuilder::Apply(&compiled_grammar_impl->grammar);
658655
if (tokenizer_info_.GetVocabSize() == 0) {
659656
return CompiledGrammar(compiled_grammar_impl);
660657
}
@@ -711,15 +708,14 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma
711708
// not need ThreadPool or std::mutex, which throws error in runtime in WebAssembly.
712709
std::optional<ThreadPool> thread_pool;
713710
std::optional<std::mutex> adaptive_token_mask_cache_mutex;
714-
715711
if (max_threads_ > 1) {
716712
thread_pool.emplace(max_threads_);
717713
adaptive_token_mask_cache_mutex.emplace();
718714
}
719715

720716
auto add_adaptive_token_mask = [&](const ParserState& state, bool is_root_rule) {
721717
auto grammar_matcher = GrammarMatcherForTokenMaskCache(
722-
grammar, state, tag_dispatch_rule_id_to_second_slicing_bitset, false
718+
compiled_grammar_impl->grammar, state, tag_dispatch_rule_id_to_second_slicing_bitset, false
723719
);
724720
auto cur_adaptive_token_mask_cache = grammar_matcher.GetAdaptiveTokenMask(
725721
tokenizer_info_.GetVocabSize(),
@@ -746,12 +742,13 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma
746742
}
747743
};
748744

749-
auto root_rule_id = grammar->GetRootRuleId();
745+
auto root_rule_id = compiled_grammar_impl->grammar->GetRootRuleId();
750746

751-
for (int32_t rule_id = 0; rule_id < static_cast<int>(grammar->NumRules()); ++rule_id) {
752-
auto rule = grammar->GetRule(rule_id);
753-
auto rule_body = grammar->GetGrammarExpr(rule.body_expr_id);
754-
const auto& rule_fsm = grammar->per_rule_fsms[rule_id];
747+
for (int32_t rule_id = 0; rule_id < static_cast<int>(compiled_grammar_impl->grammar->NumRules());
748+
++rule_id) {
749+
auto rule = compiled_grammar_impl->grammar->GetRule(rule_id);
750+
auto rule_body = compiled_grammar_impl->grammar->GetGrammarExpr(rule.body_expr_id);
751+
const auto& rule_fsm = compiled_grammar_impl->grammar->per_rule_fsms[rule_id];
755752
if (rule_fsm.has_value()) {
756753
auto cur_stack_element =
757754
ParserState(rule_id, rule.body_expr_id, 0, ParserState::kNoPrevInputPos, 0);
@@ -768,15 +765,15 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma
768765
}
769766
XGRAMMAR_DCHECK(rule_body.type == GrammarExprType::kChoices);
770767
for (auto sequence_id : rule_body) {
771-
const auto& sequence = grammar->GetGrammarExpr(sequence_id);
768+
const auto& sequence = compiled_grammar_impl->grammar->GetGrammarExpr(sequence_id);
772769
if (sequence.type == GrammarExprType::kEmptyStr) {
773770
continue;
774771
}
775772
XGRAMMAR_DCHECK(sequence.type == GrammarExprType::kSequence);
776773
auto state = ParserState(rule_id, sequence_id, 0, ParserState::kNoPrevInputPos, 0);
777774
for (int element_id = 0; element_id < sequence.size(); ++element_id) {
778775
state.element_id = element_id;
779-
auto element = grammar->GetGrammarExpr(sequence[element_id]);
776+
auto element = compiled_grammar_impl->grammar->GetGrammarExpr(sequence[element_id]);
780777
if (element.type == GrammarExprType::kRuleRef || element.type == GrammarExprType::kRepeat) {
781778
continue;
782779
}

0 commit comments

Comments
 (0)