Skip to content

Commit ced69c3

Browse files
[Feature] Add a new expression to represent repetition to speed up. (#368)
Currently, if we try to compile a grammar with a expression like `[a-z]{2, 20000}`, it will cost a lot of time to compile the grammar. In this PR, we add a new type of expression `kRepeat` in the parser, which can significantly improve the efficiency in such cases. --------- Signed-off-by: Yuchuan <[email protected]>
1 parent 8d68545 commit ced69c3

15 files changed

+382
-148
lines changed

cpp/earley_parser.cc

Lines changed: 90 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
#include "earley_parser.h"
77

8+
#include <algorithm>
89
#include <cassert>
910
#include <cctype>
1011
#include <cstdint>
@@ -55,18 +56,55 @@ void EarleyParser::Complete(const ParserState& state, const GrammarExpr& grammar
5556
const auto& parent_state = parent_state_iter->second;
5657
const auto& parent_expr = grammar_->GetGrammarExpr(parent_state.sequence_id);
5758
if (parent_state.rule_id == -1 || !grammar_->per_rule_fsms[parent_state.rule_id].has_value()) {
59+
const auto& element_expr = grammar_->GetGrammarExpr(parent_expr[parent_state.element_id]);
5860
// The new rule is not referenced by a fsm.
5961
XGRAMMAR_DCHECK(
60-
grammar_->GetGrammarExpr(parent_expr[parent_state.element_id]).type ==
61-
GrammarExprType::kRuleRef
62+
element_expr.type == GrammarExprType::kRuleRef ||
63+
element_expr.type == GrammarExprType::kRepeat
6264
);
63-
Enqueue(ParserState{
64-
parent_state.rule_id,
65-
parent_state.sequence_id,
66-
parent_state.element_id + 1,
67-
parent_state.rule_start_pos,
68-
0
69-
});
65+
if (element_expr.type == GrammarExprType::kRuleRef) {
66+
Enqueue(ParserState{
67+
parent_state.rule_id,
68+
parent_state.sequence_id,
69+
parent_state.element_id + 1,
70+
parent_state.rule_start_pos,
71+
0
72+
});
73+
continue;
74+
}
75+
XGRAMMAR_DCHECK(element_expr.type == GrammarExprType::kRepeat);
76+
if (state.rule_start_pos ==
77+
static_cast<int32_t>(rule_id_to_completeable_states_.size() - 1) &&
78+
std::binary_search(
79+
grammar_->allow_empty_rule_ids.begin(),
80+
grammar_->allow_empty_rule_ids.end(),
81+
element_expr[0]
82+
)) {
83+
// It means that the subrule of the repeat is empty, and we have already detected it.
84+
// We shouldn't add it into the queue.
85+
continue;
86+
}
87+
// The parent state is a repeat, we need to increase the repeat count.
88+
auto new_state = parent_state;
89+
const int32_t& min_repeat_count = element_expr[1];
90+
const int32_t& max_repeat_count = element_expr[2];
91+
new_state.repeat_count++;
92+
// The repeat rule can be completed, and we advance the state. Don't forget to
93+
// reset the repeat count.
94+
if (new_state.repeat_count >= min_repeat_count) {
95+
Enqueue(ParserState{
96+
parent_state.rule_id,
97+
parent_state.sequence_id,
98+
parent_state.element_id + 1,
99+
parent_state.rule_start_pos,
100+
0
101+
});
102+
}
103+
// If the repeat count is less than the max repeat count, we can continue to
104+
// visit the repeat state for another round.
105+
if (new_state.repeat_count < max_repeat_count) {
106+
Enqueue(new_state);
107+
}
70108
continue;
71109
}
72110
// If the rule is referenced by a fsm, we need to advance the fsm.
@@ -105,16 +143,37 @@ std::pair</* scanable */ bool, /* completable */ bool> EarleyParser::Predict(
105143
return std::make_pair(false, true);
106144
}
107145
const auto& element_expr = grammar_->GetGrammarExpr(grammar_expr[state.element_id]);
108-
if (element_expr.type == GrammarExprType::kRuleRef) {
109-
ExpandNextRuleRefElement(state, grammar_expr, &element_expr);
110-
return std::make_pair(false, false);
111-
}
112-
if (element_expr.type == GrammarExprType::kCharacterClassStar && state.sub_element_id == 0) {
113-
Enqueue(
114-
ParserState{state.rule_id, state.sequence_id, state.element_id + 1, state.rule_start_pos, 0}
115-
);
146+
switch (element_expr.type) {
147+
case GrammarExprType::kRuleRef: {
148+
ExpandNextRuleRefElement(state, grammar_expr, &element_expr);
149+
return std::make_pair(false, false);
150+
}
151+
case GrammarExprType::kCharacterClassStar: {
152+
if (state.sub_element_id == 0) {
153+
Enqueue(ParserState{
154+
state.rule_id, state.sequence_id, state.element_id + 1, state.rule_start_pos, 0
155+
});
156+
}
157+
return std::make_pair(true, false);
158+
}
159+
case GrammarExprType::kRepeat: {
160+
const int32_t& min_repeat_count = element_expr[1];
161+
const int32_t& max_repeat_count = element_expr[2];
162+
// If the current repeat count is less than the max repeat count,
163+
// we can expand the next rule reference element.
164+
XGRAMMAR_DCHECK(state.repeat_count <= max_repeat_count);
165+
ExpandNextRuleRefElement(state, grammar_expr, &element_expr);
166+
if (state.repeat_count >= min_repeat_count) {
167+
Enqueue(ParserState{
168+
state.rule_id, state.sequence_id, state.element_id + 1, state.rule_start_pos, 0
169+
});
170+
}
171+
return std::make_pair(false, false);
172+
}
173+
default: {
174+
return std::make_pair(true, false);
175+
}
116176
}
117-
return std::make_pair(true, false);
118177
}
119178

120179
void EarleyParser::Scan(const ParserState& state, const uint8_t ch) {
@@ -165,7 +224,6 @@ bool EarleyParser::Advance(const uint8_t ch) {
165224
tmp_states_to_be_added_.clear();
166225
tmp_accept_stop_token_ = false;
167226
const auto& latest_states = scanable_state_history_[scanable_state_history_.size() - 1];
168-
169227
// Scan all the scanable states.
170228
for (const auto& state : latest_states) {
171229
Scan(state, ch);
@@ -322,14 +380,18 @@ void EarleyParser::ExpandNextRuleRefElement(
322380
}
323381
} else {
324382
XGRAMMAR_DCHECK(grammar_expr.type == GrammarExprType::kSequence);
325-
XGRAMMAR_DCHECK(sub_grammar_expr->type == GrammarExprType::kRuleRef);
383+
XGRAMMAR_DCHECK(
384+
sub_grammar_expr->type == GrammarExprType::kRuleRef ||
385+
sub_grammar_expr->type == GrammarExprType::kRepeat
386+
);
326387
ref_rule_ids.push_back((*sub_grammar_expr)[0]);
327388
}
328389
for (const auto& ref_rule_id : ref_rule_ids) {
329390
{ // Add the reference rule to map.
330391
if ((state.element_id != grammar_expr.size() - 1) ||
331392
state.rule_start_pos == ParserState::kNoPrevInputPos ||
332-
(state.rule_id != -1 && grammar_->per_rule_fsms[state.rule_id].has_value())) {
393+
(state.rule_id != -1 && grammar_->per_rule_fsms[state.rule_id].has_value()) ||
394+
sub_grammar_expr->type == GrammarExprType::kRepeat) {
333395
// It's not the right recursion, or it's the root rule.
334396
auto& states_map = rule_id_to_completeable_states_.back();
335397
states_map.insert({ref_rule_id, state});
@@ -374,11 +436,11 @@ void EarleyParser::ExpandNextRuleRefElement(
374436

375437
// Check if the reference rule is already visited.
376438
if (IsStateVisitedInQueue({ref_rule_id, -1, -1, -1, -1})) {
377-
if (std::find(
439+
if (std::binary_search(
378440
grammar_->allow_empty_rule_ids.begin(),
379441
grammar_->allow_empty_rule_ids.end(),
380442
ref_rule_id
381-
) != grammar_->allow_empty_rule_ids.end()) {
443+
)) {
382444
if (state.rule_id != -1 && grammar_->per_rule_fsms[state.rule_id].has_value()) {
383445
const auto& current_fsm = grammar_->per_rule_fsms[state.rule_id].value();
384446
for (const auto& edge : current_fsm->GetEdges(state.element_id)) {
@@ -391,9 +453,11 @@ void EarleyParser::ExpandNextRuleRefElement(
391453
continue;
392454
}
393455
XGRAMMAR_DCHECK(grammar_expr.type == GrammarExprType::kSequence);
394-
Enqueue(ParserState{
395-
state.rule_id, state.sequence_id, state.element_id + 1, state.rule_start_pos, 0
396-
});
456+
if (sub_grammar_expr->type == GrammarExprType::kRuleRef) {
457+
Enqueue(ParserState{
458+
state.rule_id, state.sequence_id, state.element_id + 1, state.rule_start_pos, 0
459+
});
460+
}
397461
}
398462
continue;
399463
}

cpp/earley_parser.h

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,22 +43,19 @@ struct ParserState {
4343
constexpr ParserState() = default;
4444

4545
constexpr ParserState(
46-
int32_t rule_id,
47-
int32_t sequence_id,
48-
int32_t element_id,
49-
int32_t rule_start_pos,
50-
int32_t sub_element_id
46+
const int32_t& rule_id,
47+
const int32_t& sequence_id,
48+
const int32_t& element_id,
49+
const int32_t& rule_start_pos,
50+
const int32_t& sub_element_id,
51+
const int32_t& repeat_count = 0
5152
)
5253
: rule_id(rule_id),
5354
sequence_id(sequence_id),
5455
element_id(element_id),
5556
rule_start_pos(rule_start_pos),
56-
sub_element_id(sub_element_id) {}
57-
58-
constexpr ParserState(const ParserState&) = default;
59-
constexpr ParserState(ParserState&&) = default;
60-
ParserState& operator=(const ParserState&) = default;
61-
ParserState& operator=(ParserState&&) = default;
57+
sub_element_id(sub_element_id),
58+
repeat_count(repeat_count) {}
6259

6360
/*!
6461
* \brief A sequence_id value of kUnexpandedRuleStartSequenceId means a rule hasn't been
@@ -92,6 +89,9 @@ struct ParserState {
9289
/*! \brief The id of the sub element in the current selement of the sequence. */
9390
int32_t sub_element_id = 0;
9491

92+
/*! \brief The number of times the element is repeated. It will be used in kRepeat.*/
93+
int32_t repeat_count = 0;
94+
9595
/*! \brief The element is invalid when sequence_id is -1. */
9696
bool IsInvalid() const { return sequence_id == -1; }
9797

@@ -107,13 +107,15 @@ struct ParserState {
107107
if (sequence_id != other.sequence_id) return sequence_id < other.sequence_id;
108108
if (element_id != other.element_id) return element_id < other.element_id;
109109
if (rule_start_pos != other.rule_start_pos) return rule_start_pos < other.rule_start_pos;
110-
return sub_element_id < other.sub_element_id;
110+
if (sub_element_id != other.sub_element_id) return sub_element_id < other.sub_element_id;
111+
return repeat_count < other.repeat_count;
111112
}
112113

113114
friend std::ostream& operator<<(std::ostream& os, const ParserState& state) {
114115
os << "ParserState(rule_id=" << state.rule_id << ", sequence_id=" << state.sequence_id
115116
<< ", element_id=" << state.element_id << ", rule_start_pos=" << state.rule_start_pos
116-
<< ", sub_element_id=" << state.sub_element_id << ")";
117+
<< ", sub_element_id=" << state.sub_element_id << ", repeat_count=" << state.repeat_count
118+
<< ")";
117119
return os;
118120
}
119121

@@ -132,7 +134,8 @@ XGRAMMAR_MEMBER_ARRAY(
132134
&ParserState::sequence_id,
133135
&ParserState::element_id,
134136
&ParserState::rule_start_pos,
135-
&ParserState::sub_element_id
137+
&ParserState::sub_element_id,
138+
&ParserState::repeat_count
136139
);
137140

138141
/*!
@@ -170,7 +173,8 @@ class StateHashForParsing {
170173
state.sequence_id,
171174
state.element_id,
172175
state.rule_start_pos,
173-
state.sub_element_id
176+
state.sub_element_id,
177+
state.repeat_count
174178
);
175179
}
176180
};

cpp/grammar_builder.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,12 @@ class GrammarBuilder {
199199
);
200200
}
201201

202+
int32_t AddRepeat(const int32_t ref_rule_id, int32_t min_repeat_count, int32_t max_repeat_count) {
203+
std::vector<int32_t> data({ref_rule_id, min_repeat_count, max_repeat_count});
204+
return AddGrammarExpr({GrammarExprType::kRepeat, data.data(), static_cast<int32_t>(data.size())}
205+
);
206+
}
207+
202208
/*! \brief Get the number of grammar_exprs. */
203209
int32_t NumGrammarExprs() const { return grammar_->NumGrammarExprs(); }
204210

0 commit comments

Comments
 (0)