Skip to content

Commit bdee539

Browse files
[Fix] Fix the optimization for TagDispatch and SubGrammarAdder. (#471)
This PR fixes a concurrency problem in the optimization for TagDispatch, and fixes a bug in `SubGrammarAdder`. --------- Signed-off-by: Yuchuan <[email protected]>
1 parent 65f45f7 commit bdee539

File tree

4 files changed

+185
-44
lines changed

4 files changed

+185
-44
lines changed

cpp/grammar_compiler.cc

Lines changed: 62 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -636,13 +636,20 @@ class GrammarCompilerNoCache {
636636
private:
637637
/*! \brief The main logic. Compile the grammar with multi-threading. */
638638
CompiledGrammar MultiThreadCompileGrammar(Grammar grammar);
639+
/*! \brief Optimization for TagDispatch.
640+
* \param compiled_grammar_impl the compiled_grammar to be optimized.
641+
* \param tag_dispatch_rule_id_to_second_slicing_bitset Return value. Mapping from the rule_id to
642+
* the definite accepted token mask.
643+
*/
644+
void TagDispatchOptimization(
645+
std::shared_ptr<CompiledGrammar::Impl> compiled_grammar_impl,
646+
std::unordered_map<int32_t, DynamicBitset>* tag_dispatch_rule_id_to_second_slicing_bitset
647+
);
639648

640649
/*! \brief The vocabulary associated with this storage class. */
641650
const TokenizerInfo tokenizer_info_;
642651
/*! \brief The maximum number of threads to use. */
643652
const int max_threads_;
644-
/*! \brief Mapping from the rule_id to the definite accepted token mask. */
645-
std::unordered_map<int32_t, DynamicBitset> tag_dispatch_rule_id_to_second_slicing_bitset;
646653
};
647654

648655
CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar grammar_unoptimized) {
@@ -655,48 +662,8 @@ CompiledGrammar GrammarCompilerNoCache::MultiThreadCompileGrammar(Grammar gramma
655662
if (tokenizer_info_.GetVocabSize() == 0) {
656663
return CompiledGrammar(compiled_grammar_impl);
657664
}
658-
659-
// Optimization for TagDispatch: Precompute the definitely accepted tokens.
660-
for (int i = 0; i < compiled_grammar_impl->grammar->NumRules(); i++) {
661-
const auto& rule = compiled_grammar_impl->grammar->GetRule(i);
662-
const auto& rule_body = compiled_grammar_impl->grammar->GetGrammarExpr(rule.body_expr_id);
663-
if (rule_body.type != GrammarExprType::kTagDispatch) {
664-
continue;
665-
}
666-
XGRAMMAR_DCHECK(rule_body.type == GrammarExprType::kTagDispatch);
667-
Grammar::Impl::TagDispatch tag_dispatch =
668-
compiled_grammar_impl->grammar->GetTagDispatch(rule.body_expr_id);
669-
const auto& sorted_decoded_vocab = tokenizer_info_.GetSortedDecodedVocab();
670-
DynamicBitset definite_accepted_tokens_since_second_char(sorted_decoded_vocab.size());
671-
for (int i = 0; i < static_cast<int32_t>(sorted_decoded_vocab.size()); i++) {
672-
bool definite_accept_since_second_char = true;
673-
const auto& token = sorted_decoded_vocab[i].second;
674-
if (token.empty()) {
675-
definite_accepted_tokens_since_second_char.Set(i);
676-
continue;
677-
}
678-
679-
// Check if the token contains any tag or stop string after the first character.
680-
for (const auto& tag : tag_dispatch.tag_rule_pairs) {
681-
if (token.find(tag.first, 1) != std::string::npos) {
682-
definite_accept_since_second_char = false;
683-
break;
684-
}
685-
}
686-
for (const auto& stop_str : tag_dispatch.stop_str) {
687-
if (token.find(stop_str, 1) != std::string::npos) {
688-
definite_accept_since_second_char = false;
689-
break;
690-
}
691-
}
692-
693-
// If the token can be definitely accepted since the second character, set the bit.
694-
if (definite_accept_since_second_char) {
695-
definite_accepted_tokens_since_second_char.Set(i);
696-
}
697-
}
698-
tag_dispatch_rule_id_to_second_slicing_bitset[i] = definite_accepted_tokens_since_second_char;
699-
}
665+
std::unordered_map<int32_t, DynamicBitset> tag_dispatch_rule_id_to_second_slicing_bitset;
666+
TagDispatchOptimization(compiled_grammar_impl, &tag_dispatch_rule_id_to_second_slicing_bitset);
700667
// Step 3. Compute the adaptive token mask cache
701668
// The token mask cache is computed for these positions in the grammar:
702669
// 1. All character class or character class star (with last_utf8_bytes=0, 1, 2, 3)
@@ -842,6 +809,57 @@ CompiledGrammar GrammarCompilerNoCache::CompileGrammar(
842809
return MultiThreadCompileGrammar(Grammar::FromEBNF(ebnf_str, root_rule_name));
843810
}
844811

812+
void GrammarCompilerNoCache::TagDispatchOptimization(
813+
std::shared_ptr<CompiledGrammar::Impl> compiled_grammar_impl,
814+
std::unordered_map<int32_t, DynamicBitset>* tag_dispatch_rule_id_to_second_slicing_bitset
815+
) {
816+
using GrammarExprType = Grammar::Impl::GrammarExprType;
817+
tag_dispatch_rule_id_to_second_slicing_bitset->clear();
818+
819+
// Optimization for TagDispatch: Precompute the definitely accepted tokens.
820+
for (int i = 0; i < compiled_grammar_impl->grammar->NumRules(); i++) {
821+
const auto& rule = compiled_grammar_impl->grammar->GetRule(i);
822+
const auto& rule_body = compiled_grammar_impl->grammar->GetGrammarExpr(rule.body_expr_id);
823+
if (rule_body.type != GrammarExprType::kTagDispatch) {
824+
continue;
825+
}
826+
XGRAMMAR_DCHECK(rule_body.type == GrammarExprType::kTagDispatch);
827+
Grammar::Impl::TagDispatch tag_dispatch =
828+
compiled_grammar_impl->GetGrammar()->GetTagDispatch(rule.body_expr_id);
829+
const auto& sorted_decoded_vocab = tokenizer_info_.GetSortedDecodedVocab();
830+
DynamicBitset definite_accepted_tokens_since_second_char(sorted_decoded_vocab.size());
831+
for (int i = 0; i < static_cast<int32_t>(sorted_decoded_vocab.size()); i++) {
832+
bool definite_accept_since_second_char = true;
833+
const auto& token = sorted_decoded_vocab[i].second;
834+
if (token.empty()) {
835+
definite_accepted_tokens_since_second_char.Set(i);
836+
continue;
837+
}
838+
839+
// Check if the token contains any tag or stop string after the first character.
840+
for (const auto& tag : tag_dispatch.tag_rule_pairs) {
841+
if (token.find(tag.first, 1) != std::string::npos) {
842+
definite_accept_since_second_char = false;
843+
break;
844+
}
845+
}
846+
for (const auto& stop_str : tag_dispatch.stop_str) {
847+
if (token.find(stop_str, 1) != std::string::npos) {
848+
definite_accept_since_second_char = false;
849+
break;
850+
}
851+
}
852+
853+
// If the token can be definitely accepted since the second character, set the bit.
854+
if (definite_accept_since_second_char) {
855+
definite_accepted_tokens_since_second_char.Set(i);
856+
}
857+
}
858+
(*tag_dispatch_rule_id_to_second_slicing_bitset)[i] =
859+
definite_accepted_tokens_since_second_char;
860+
}
861+
}
862+
845863
/******************* GrammarCompiler::Impl *******************/
846864

847865
/*!

cpp/grammar_functor.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,18 @@ class SubGrammarAdderImpl : public GrammarMutator {
7373
);
7474
}
7575

76+
int32_t VisitTagDispatch(const GrammarExpr& grammar_expr) final {
77+
Grammar::Impl::TagDispatch old_tag_dispatch = base_grammar_->GetTagDispatch(grammar_expr);
78+
Grammar::Impl::TagDispatch new_tag_dispatch;
79+
new_tag_dispatch.stop_eos = old_tag_dispatch.stop_eos;
80+
for (const auto& [tag, rule_id] : old_tag_dispatch.tag_rule_pairs) {
81+
new_tag_dispatch.tag_rule_pairs.emplace_back(tag, new_rule_ids_names[rule_id].first);
82+
}
83+
new_tag_dispatch.stop_str = old_tag_dispatch.stop_str;
84+
new_tag_dispatch.loop_after_dispatch = old_tag_dispatch.loop_after_dispatch;
85+
return builder_->AddTagDispatch(new_tag_dispatch);
86+
}
87+
7688
std::vector<std::pair<int32_t, std::string>> new_rule_ids_names;
7789
};
7890

tests/python/test_grammar_matcher_structural_tag.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import sys
3+
import threading
34
import time
45
from typing import List
56

@@ -340,5 +341,36 @@ def test_utf8_structural_tag_begin_end():
340341
_ = compiler.compile_structural_tag(structures, triggers)
341342

342343

344+
@pytest.mark.hf_token_required
345+
def test_pressure_structural_tag():
346+
model = "meta-llama/Llama-3.1-8B-Instruct"
347+
tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True, trust_remote_code=True)
348+
tokenizer_info = xgr.TokenizerInfo.from_huggingface(tokenizer)
349+
compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=1)
350+
threads = []
351+
start = "start"
352+
schema = {"type": "object", "properties": {"arg": {"type": "string"}}}
353+
end = "end"
354+
355+
def worker(idx: int):
356+
tag = xgr.StructuralTagItem(begin=start, schema=schema, end=end)
357+
triggers = [start]
358+
stag_grammar = xgr.Grammar.from_structural_tag([tag], triggers)
359+
start_grammar = xgr.Grammar.from_ebnf("root ::= [a-z] root | [a-z]")
360+
grammar = start_grammar
361+
for _ in range(idx):
362+
grammar = grammar.concat(grammar, start_grammar)
363+
final_grammar = xgr.Grammar.concat(grammar, stag_grammar)
364+
_ = compiler.compile_grammar(final_grammar)
365+
366+
for i in range(128):
367+
t = threading.Thread(target=worker, args=(i,))
368+
threads.append(t)
369+
t.start()
370+
371+
for t in threads:
372+
t.join()
373+
374+
343375
if __name__ == "__main__":
344376
pytest.main(sys.argv)

tests/python/test_grammar_union_concat.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,5 +83,84 @@ def test_grammar_concat():
8383
assert str(concat_grammar) == expected
8484

8585

86+
def test_grammar_union_with_stag():
87+
expected_grammar_union = r"""root ::= ((root_1_1) | (root_2))
88+
basic_escape ::= (([\"\\/bfnrt]) | ("u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]))
89+
basic_string_sub ::= (("\"") | ([^\0-\x1f\"\\\r\n] basic_string_sub) | ("\\" basic_escape basic_string_sub)) (=([ \n\t]* [,}\]:]))
90+
basic_any ::= ((basic_number) | (basic_string) | (basic_boolean) | (basic_null) | (basic_array) | (basic_object))
91+
basic_integer ::= (("0") | (basic_integer_1 [1-9] [0-9]*))
92+
basic_number ::= ((basic_number_1 basic_number_7 basic_number_3 basic_number_6))
93+
basic_string ::= (("\"" basic_string_sub))
94+
basic_boolean ::= (("true") | ("false"))
95+
basic_null ::= (("null"))
96+
basic_array ::= (("[" [ \n\t]* basic_any basic_array_1 [ \n\t]* "]") | ("[" [ \n\t]* "]"))
97+
basic_object ::= (("{" [ \n\t]* basic_string [ \n\t]* ":" [ \n\t]* basic_any basic_object_1 [ \n\t]* "}") | ("{" [ \n\t]* "}"))
98+
root_1 ::= (("{" [ \n\t]* "\"arg\"" [ \n\t]* ":" [ \n\t]* basic_string [ \n\t]* "}") | ("{" [ \n\t]* "}"))
99+
basic_integer_1 ::= ("" | ("-"))
100+
basic_number_1 ::= ("" | ("-"))
101+
basic_number_2 ::= (([0-9] basic_number_2) | ([0-9]))
102+
basic_number_3 ::= ("" | ("." basic_number_2))
103+
basic_number_4 ::= ("" | ([+\-]))
104+
basic_number_5 ::= (([0-9] basic_number_5) | ([0-9]))
105+
basic_number_6 ::= ("" | ([eE] basic_number_4 basic_number_5))
106+
basic_array_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_any basic_array_1))
107+
basic_object_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_string [ \n\t]* ":" [ \n\t]* basic_any basic_object_1))
108+
basic_number_7 ::= (("0") | ([1-9] [0-9]*))
109+
triggered_tags_group ::= (("" root_1 "end"))
110+
triggered_tags ::= TagDispatch(
111+
("start", triggered_tags_group),
112+
stop_eos=true,
113+
stop_str=(),
114+
loop_after_dispatch=true
115+
)
116+
root_1_1 ::= ((triggered_tags))
117+
root_2 ::= (([a-z] root_2) | ([a-z]))
118+
"""
119+
120+
expected_grammar_concat = r"""root ::= ((root_1_1 root_2))
121+
basic_escape ::= (([\"\\/bfnrt]) | ("u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]))
122+
basic_string_sub ::= (("\"") | ([^\0-\x1f\"\\\r\n] basic_string_sub) | ("\\" basic_escape basic_string_sub)) (=([ \n\t]* [,}\]:]))
123+
basic_any ::= ((basic_number) | (basic_string) | (basic_boolean) | (basic_null) | (basic_array) | (basic_object))
124+
basic_integer ::= (("0") | (basic_integer_1 [1-9] [0-9]*))
125+
basic_number ::= ((basic_number_1 basic_number_7 basic_number_3 basic_number_6))
126+
basic_string ::= (("\"" basic_string_sub))
127+
basic_boolean ::= (("true") | ("false"))
128+
basic_null ::= (("null"))
129+
basic_array ::= (("[" [ \n\t]* basic_any basic_array_1 [ \n\t]* "]") | ("[" [ \n\t]* "]"))
130+
basic_object ::= (("{" [ \n\t]* basic_string [ \n\t]* ":" [ \n\t]* basic_any basic_object_1 [ \n\t]* "}") | ("{" [ \n\t]* "}"))
131+
root_1 ::= (("{" [ \n\t]* "\"arg\"" [ \n\t]* ":" [ \n\t]* basic_string [ \n\t]* "}") | ("{" [ \n\t]* "}"))
132+
basic_integer_1 ::= ("" | ("-"))
133+
basic_number_1 ::= ("" | ("-"))
134+
basic_number_2 ::= (([0-9] basic_number_2) | ([0-9]))
135+
basic_number_3 ::= ("" | ("." basic_number_2))
136+
basic_number_4 ::= ("" | ([+\-]))
137+
basic_number_5 ::= (([0-9] basic_number_5) | ([0-9]))
138+
basic_number_6 ::= ("" | ([eE] basic_number_4 basic_number_5))
139+
basic_array_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_any basic_array_1))
140+
basic_object_1 ::= ("" | ([ \n\t]* "," [ \n\t]* basic_string [ \n\t]* ":" [ \n\t]* basic_any basic_object_1))
141+
basic_number_7 ::= (("0") | ([1-9] [0-9]*))
142+
triggered_tags_group ::= (("" root_1 "end"))
143+
triggered_tags ::= TagDispatch(
144+
("start", triggered_tags_group),
145+
stop_eos=true,
146+
stop_str=(),
147+
loop_after_dispatch=true
148+
)
149+
root_1_1 ::= ((triggered_tags))
150+
root_2 ::= (([a-z] root_2) | ([a-z]))
151+
"""
152+
start = "start"
153+
schema = {"type": "object", "properties": {"arg": {"type": "string"}}}
154+
end = "end"
155+
tag = xgr.StructuralTagItem(begin=start, schema=schema, end=end)
156+
triggers = [start]
157+
stag_grammar = xgr.Grammar.from_structural_tag([tag], triggers)
158+
start_grammar = xgr.Grammar.from_ebnf("root ::= [a-z] root | [a-z]")
159+
grammar_union = xgr.Grammar.union(stag_grammar, start_grammar)
160+
assert str(grammar_union) == expected_grammar_union
161+
grammar_concat = xgr.Grammar.concat(stag_grammar, start_grammar)
162+
assert str(grammar_concat) == expected_grammar_concat
163+
164+
86165
if __name__ == "__main__":
87166
pytest.main(sys.argv)

0 commit comments

Comments (0)