5
5
6
6
#include " earley_parser.h"
7
7
8
+ #include < algorithm>
8
9
#include < cassert>
9
10
#include < cctype>
10
11
#include < cstdint>
@@ -55,18 +56,55 @@ void EarleyParser::Complete(const ParserState& state, const GrammarExpr& grammar
55
56
const auto & parent_state = parent_state_iter->second ;
56
57
const auto & parent_expr = grammar_->GetGrammarExpr (parent_state.sequence_id );
57
58
if (parent_state.rule_id == -1 || !grammar_->per_rule_fsms [parent_state.rule_id ].has_value ()) {
59
+ const auto & element_expr = grammar_->GetGrammarExpr (parent_expr[parent_state.element_id ]);
58
60
// The new rule is not referenced by a fsm.
59
61
XGRAMMAR_DCHECK (
60
- grammar_-> GetGrammarExpr (parent_expr[parent_state. element_id ]). type ==
61
- GrammarExprType::kRuleRef
62
+ element_expr. type == GrammarExprType:: kRuleRef ||
63
+ element_expr. type == GrammarExprType::kRepeat
62
64
);
63
- Enqueue (ParserState{
64
- parent_state.rule_id ,
65
- parent_state.sequence_id ,
66
- parent_state.element_id + 1 ,
67
- parent_state.rule_start_pos ,
68
- 0
69
- });
65
+ if (element_expr.type == GrammarExprType::kRuleRef ) {
66
+ Enqueue (ParserState{
67
+ parent_state.rule_id ,
68
+ parent_state.sequence_id ,
69
+ parent_state.element_id + 1 ,
70
+ parent_state.rule_start_pos ,
71
+ 0
72
+ });
73
+ continue ;
74
+ }
75
+ XGRAMMAR_DCHECK (element_expr.type == GrammarExprType::kRepeat );
76
+ if (state.rule_start_pos ==
77
+ static_cast <int32_t >(rule_id_to_completeable_states_.size () - 1 ) &&
78
+ std::binary_search (
79
+ grammar_->allow_empty_rule_ids .begin (),
80
+ grammar_->allow_empty_rule_ids .end (),
81
+ element_expr[0 ]
82
+ )) {
83
+ // It means that the subrule of the repeat is empty, and we have already detected it.
84
+ // We shouldn't add it into the queue.
85
+ continue ;
86
+ }
87
+ // The parent state is a repeat, we need to increase the repeat count.
88
+ auto new_state = parent_state;
89
+ const int32_t & min_repeat_count = element_expr[1 ];
90
+ const int32_t & max_repeat_count = element_expr[2 ];
91
+ new_state.repeat_count ++;
92
+ // The repeat rule can be completed, and we advance the state. Don't forget to
93
+ // reset the repeat count.
94
+ if (new_state.repeat_count >= min_repeat_count) {
95
+ Enqueue (ParserState{
96
+ parent_state.rule_id ,
97
+ parent_state.sequence_id ,
98
+ parent_state.element_id + 1 ,
99
+ parent_state.rule_start_pos ,
100
+ 0
101
+ });
102
+ }
103
+ // If the repeat count is less than the max repeat count, we can continue to
104
+ // visit the repeat state for another round.
105
+ if (new_state.repeat_count < max_repeat_count) {
106
+ Enqueue (new_state);
107
+ }
70
108
continue ;
71
109
}
72
110
// If the rule is referenced by a fsm, we need to advance the fsm.
@@ -105,16 +143,37 @@ std::pair</* scanable */ bool, /* completable */ bool> EarleyParser::Predict(
105
143
return std::make_pair (false , true );
106
144
}
107
145
const auto & element_expr = grammar_->GetGrammarExpr (grammar_expr[state.element_id ]);
108
- if (element_expr.type == GrammarExprType::kRuleRef ) {
109
- ExpandNextRuleRefElement (state, grammar_expr, &element_expr);
110
- return std::make_pair (false , false );
111
- }
112
- if (element_expr.type == GrammarExprType::kCharacterClassStar && state.sub_element_id == 0 ) {
113
- Enqueue (
114
- ParserState{state.rule_id , state.sequence_id , state.element_id + 1 , state.rule_start_pos , 0 }
115
- );
146
+ switch (element_expr.type ) {
147
+ case GrammarExprType::kRuleRef : {
148
+ ExpandNextRuleRefElement (state, grammar_expr, &element_expr);
149
+ return std::make_pair (false , false );
150
+ }
151
+ case GrammarExprType::kCharacterClassStar : {
152
+ if (state.sub_element_id == 0 ) {
153
+ Enqueue (ParserState{
154
+ state.rule_id , state.sequence_id , state.element_id + 1 , state.rule_start_pos , 0
155
+ });
156
+ }
157
+ return std::make_pair (true , false );
158
+ }
159
+ case GrammarExprType::kRepeat : {
160
+ const int32_t & min_repeat_count = element_expr[1 ];
161
+ const int32_t & max_repeat_count = element_expr[2 ];
162
+ // If the current repeat count is less than the max repeat count,
163
+ // we can expand the next rule reference element.
164
+ XGRAMMAR_DCHECK (state.repeat_count <= max_repeat_count);
165
+ ExpandNextRuleRefElement (state, grammar_expr, &element_expr);
166
+ if (state.repeat_count >= min_repeat_count) {
167
+ Enqueue (ParserState{
168
+ state.rule_id , state.sequence_id , state.element_id + 1 , state.rule_start_pos , 0
169
+ });
170
+ }
171
+ return std::make_pair (false , false );
172
+ }
173
+ default : {
174
+ return std::make_pair (true , false );
175
+ }
116
176
}
117
- return std::make_pair (true , false );
118
177
}
119
178
120
179
void EarleyParser::Scan (const ParserState& state, const uint8_t ch) {
@@ -165,7 +224,6 @@ bool EarleyParser::Advance(const uint8_t ch) {
165
224
tmp_states_to_be_added_.clear ();
166
225
tmp_accept_stop_token_ = false ;
167
226
const auto & latest_states = scanable_state_history_[scanable_state_history_.size () - 1 ];
168
-
169
227
// Scan all the scanable states.
170
228
for (const auto & state : latest_states) {
171
229
Scan (state, ch);
@@ -322,14 +380,18 @@ void EarleyParser::ExpandNextRuleRefElement(
322
380
}
323
381
} else {
324
382
XGRAMMAR_DCHECK (grammar_expr.type == GrammarExprType::kSequence );
325
- XGRAMMAR_DCHECK (sub_grammar_expr->type == GrammarExprType::kRuleRef );
383
+ XGRAMMAR_DCHECK (
384
+ sub_grammar_expr->type == GrammarExprType::kRuleRef ||
385
+ sub_grammar_expr->type == GrammarExprType::kRepeat
386
+ );
326
387
ref_rule_ids.push_back ((*sub_grammar_expr)[0 ]);
327
388
}
328
389
for (const auto & ref_rule_id : ref_rule_ids) {
329
390
{ // Add the reference rule to map.
330
391
if ((state.element_id != grammar_expr.size () - 1 ) ||
331
392
state.rule_start_pos == ParserState::kNoPrevInputPos ||
332
- (state.rule_id != -1 && grammar_->per_rule_fsms [state.rule_id ].has_value ())) {
393
+ (state.rule_id != -1 && grammar_->per_rule_fsms [state.rule_id ].has_value ()) ||
394
+ sub_grammar_expr->type == GrammarExprType::kRepeat ) {
333
395
// It's not the right recursion, or it's the root rule.
334
396
auto & states_map = rule_id_to_completeable_states_.back ();
335
397
states_map.insert ({ref_rule_id, state});
@@ -374,11 +436,11 @@ void EarleyParser::ExpandNextRuleRefElement(
374
436
375
437
// Check if the reference rule is already visited.
376
438
if (IsStateVisitedInQueue ({ref_rule_id, -1 , -1 , -1 , -1 })) {
377
- if (std::find (
439
+ if (std::binary_search (
378
440
grammar_->allow_empty_rule_ids .begin (),
379
441
grammar_->allow_empty_rule_ids .end (),
380
442
ref_rule_id
381
- ) != grammar_-> allow_empty_rule_ids . end () ) {
443
+ )) {
382
444
if (state.rule_id != -1 && grammar_->per_rule_fsms [state.rule_id ].has_value ()) {
383
445
const auto & current_fsm = grammar_->per_rule_fsms [state.rule_id ].value ();
384
446
for (const auto & edge : current_fsm->GetEdges (state.element_id )) {
@@ -391,9 +453,11 @@ void EarleyParser::ExpandNextRuleRefElement(
391
453
continue ;
392
454
}
393
455
XGRAMMAR_DCHECK (grammar_expr.type == GrammarExprType::kSequence );
394
- Enqueue (ParserState{
395
- state.rule_id , state.sequence_id , state.element_id + 1 , state.rule_start_pos , 0
396
- });
456
+ if (sub_grammar_expr->type == GrammarExprType::kRuleRef ) {
457
+ Enqueue (ParserState{
458
+ state.rule_id , state.sequence_id , state.element_id + 1 , state.rule_start_pos , 0
459
+ });
460
+ }
397
461
}
398
462
continue ;
399
463
}
0 commit comments