Skip to content

Commit ee680ff

Browse files
committed
[Serving][Grammar] Refactor GrammarStateMatcher and support LLaMA-3
This PR refactors GrammarStateMatcher and adds support for the LLaMA-3 tokenizer. Common tokenizers, including Phi-2, Gemma, LLaMA-2, etc., are also supported. The performance is optimized for the LLaMA-3 tokenizer, since its token table has size 128k, much larger than that of the LLaMA-2 tokenizer.

These changes are introduced to the grammar library:
1. Introduce the ByteString rule expression and simplify CharacterClass and CharacterClassStar.
2. Refactor BNFGrammarVisitor and BNFGrammarMutator for visiting and mutating grammar rules.
3. GrammarStateMatcherBase, the internal implementation of GrammarStateMatcher, now accepts input char by char instead of codepoint by codepoint. It therefore supports any valid UTF-8 string, even if a token is not a complete codepoint.
4. Support lookahead assertions for rules, which specify that the rule must be followed by a given sequence. This can eliminate some uncertain tokens in preprocessing.

Minor changes:
1. Introduce the template hash function HashCombine.
2. Update the UTF-8 encoding handling functions.

Performance:
1. For JSON, finding the mask requires <30us on a 5900X with a single thread. The number of uncertain tokens is <30 in most cases.
2. For JSON schema, finding the mask requires <30us on a 5900X with a single thread. The number of uncertain tokens is <30 in most cases.
1 parent 679d3a8 commit ee680ff

27 files changed

+1684
-1024
lines changed

cpp/serve/engine.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ class EngineImpl : public Engine {
122122
}
123123
n->token_table_ =
124124
Tokenizer::PostProcessTokenTable(n->tokenizer_->TokenTable(), token_table_postproc_method);
125-
n->grammar_init_context_storage_ = GrammarInitContextStorage(n->token_table_);
125+
n->grammar_init_context_cache_ = GrammarInitContextCache(n->token_table_);
126126
// - Create the logit processor and sampler, and
127127
// the DraftTokenWorkspaceManager for speculative decoding.
128128
int max_num_tokens = engine_config->max_num_sequence;
@@ -499,9 +499,9 @@ class EngineImpl : public Engine {
499499
if (response_format.type != "json_object") {
500500
return std::nullopt;
501501
} else if (!response_format.schema) {
502-
return grammar_init_context_storage_->GetInitContextForJSON();
502+
return grammar_init_context_cache_->GetInitContextForJSON();
503503
} else {
504-
return grammar_init_context_storage_->GetInitContextForJSONSchema(
504+
return grammar_init_context_cache_->GetInitContextForJSONSchema(
505505
response_format.schema.value());
506506
}
507507
}
@@ -513,7 +513,7 @@ class EngineImpl : public Engine {
513513
Tokenizer tokenizer_;
514514
std::vector<std::string> token_table_;
515515
// Helper to get the grammar init context for requests.
516-
GrammarInitContextStorage grammar_init_context_storage_;
516+
GrammarInitContextCache grammar_init_context_cache_;
517517
// Models
518518
Array<Model> models_;
519519
// Device that the models run on.

cpp/serve/grammar/grammar.cc

Lines changed: 78 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55

66
#include "grammar.h"
77

8+
#include "grammar_functor.h"
89
#include "grammar_parser.h"
910
#include "grammar_serializer.h"
10-
#include "grammar_simplifier.h"
1111
#include "json_schema_converter.h"
1212

1313
namespace mlc {
@@ -21,18 +21,28 @@ std::ostream& operator<<(std::ostream& os, const BNFGrammar& grammar) {
2121
return os;
2222
}
2323

24-
BNFGrammar BNFGrammar::FromEBNFString(const std::string& ebnf_string, const std::string& main_rule,
25-
bool normalize, bool simplify) {
24+
BNFGrammar BNFGrammar::FromEBNFString(const std::string& ebnf_string,
25+
const std::string& main_rule) {
2626
auto grammar = EBNFParser::Parse(ebnf_string, main_rule);
27-
if (normalize) {
28-
grammar = NestedRuleUnwrapper(grammar).Apply();
29-
}
27+
// Normalize the grammar by default
28+
grammar = BNFGrammarNormalizer().Apply(grammar);
3029
return grammar;
3130
}
3231

3332
TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromEBNFString")
34-
.set_body_typed([](String ebnf_string, String main_rule, bool normalize, bool simplify) {
35-
return BNFGrammar::FromEBNFString(ebnf_string, main_rule, normalize, simplify);
33+
.set_body_typed([](String ebnf_string, String main_rule) {
34+
return BNFGrammar::FromEBNFString(ebnf_string, main_rule);
35+
});
36+
37+
// Parse the EBNF string without normalizing it (for debugging and testing only)
38+
BNFGrammar DebugFromEBNFStringNoNormalize(const std::string& ebnf_string,
39+
const std::string& main_rule) {
40+
return EBNFParser::Parse(ebnf_string, main_rule);
41+
}
42+
43+
TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarDebugFromEBNFStringNoNormalize")
44+
.set_body_typed([](String ebnf_string, String main_rule) {
45+
return DebugFromEBNFStringNoNormalize(ebnf_string, main_rule);
3646
});
3747

3848
BNFGrammar BNFGrammar::FromJSON(const std::string& json_string) {
@@ -69,79 +79,90 @@ TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromSchema").set_body([](TVMArgs args,
6979
*rv = BNFGrammar::FromSchema(args[0], indent, separators, args[3]);
7080
});
7181

82+
// Optimized json grammar for the speed of the grammar state matcher
7283
const std::string kJSONGrammarString = R"(
7384
main ::= (
74-
"{" ws members_or_embrace |
75-
"[" ws elements_or_embrace
85+
"{" [ \n\t]* members_and_embrace |
86+
"[" [ \n\t]* elements_or_embrace
7687
)
77-
value ::= (
78-
"{" ws members_or_embrace |
79-
"[" ws elements_or_embrace |
80-
"\"" characters "\"" |
81-
[0-9] fraction exponent |
82-
[1-9] digits fraction exponent |
88+
value_non_str ::= (
89+
"{" [ \n\t]* members_and_embrace |
90+
"[" [ \n\t]* elements_or_embrace |
91+
"0" fraction exponent |
92+
[1-9] [0-9]* fraction exponent |
8393
"-" [0-9] fraction exponent |
84-
"-" [1-9] digits fraction exponent |
94+
"-" [1-9] [0-9]* fraction exponent |
8595
"true" |
8696
"false" |
8797
"null"
88-
)
89-
members_or_embrace ::= (
90-
"\"" characters "\"" ws ":" ws value members_rest ws "}" |
91-
"}"
92-
)
93-
members ::= "\"" characters "\"" ws ":" ws value members_rest
94-
members_rest ::= (
95-
"" |
96-
"," ws "\"" characters "\"" ws ":" ws value members_rest |
97-
" " ws "," ws "\"" characters "\"" ws ":" ws value members_rest |
98-
"\n" ws "," ws "\"" characters "\"" ws ":" ws value members_rest |
99-
"\t" ws "," ws "\"" characters "\"" ws ":" ws value members_rest
100-
)
98+
) (= [ \n\t,}\]])
99+
members_and_embrace ::= ("\"" characters_and_colon [ \n\t]* members_suffix | "}") (= [ \n\t,}\]])
100+
members_suffix ::= (
101+
value_non_str [ \n\t]* member_suffix_suffix |
102+
"\"" characters_and_embrace |
103+
"\"" characters_and_comma [ \n\t]* "\"" characters_and_colon [ \n\t]* members_suffix
104+
) (= [ \n\t,}\]])
105+
member_suffix_suffix ::= (
106+
"}" |
107+
"," [ \n\t]* "\"" characters_and_colon [ \n\t]* members_suffix
108+
) (= [ \n\t,}\]])
101109
elements_or_embrace ::= (
102-
"{" ws members_or_embrace elements_rest ws "]" |
103-
"[" ws elements_or_embrace elements_rest ws "]" |
104-
"\"" characters "\"" elements_rest ws "]" |
105-
[0-9] fraction exponent elements_rest ws "]" |
106-
[1-9] digits fraction exponent elements_rest ws "]" |
107-
"-" [0-9] fraction exponent elements_rest ws "]" |
108-
"-" [1-9] digits fraction exponent elements_rest ws "]" |
109-
"true" elements_rest ws "]" |
110-
"false" elements_rest ws "]" |
111-
"null" elements_rest ws "]" |
110+
"{" [ \n\t]* members_and_embrace elements_rest [ \n\t]* "]" |
111+
"[" [ \n\t]* elements_or_embrace elements_rest [ \n\t]* "]" |
112+
"\"" characters_item elements_rest [ \n\t]* "]" |
113+
"0" fraction exponent elements_rest [ \n\t]* "]" |
114+
[1-9] [0-9]* fraction exponent elements_rest [ \n\t]* "]" |
115+
"-" "0" fraction exponent elements_rest [ \n\t]* "]" |
116+
"-" [1-9] [0-9]* fraction exponent elements_rest [ \n\t]* "]" |
117+
"true" elements_rest [ \n\t]* "]" |
118+
"false" elements_rest [ \n\t]* "]" |
119+
"null" elements_rest [ \n\t]* "]" |
112120
"]"
113121
)
114122
elements ::= (
115-
"{" ws members_or_embrace elements_rest |
116-
"[" ws elements_or_embrace elements_rest |
117-
"\"" characters "\"" elements_rest |
118-
[0-9] fraction exponent elements_rest |
119-
[1-9] digits fraction exponent elements_rest |
123+
"{" [ \n\t]* members_and_embrace elements_rest |
124+
"[" [ \n\t]* elements_or_embrace elements_rest |
125+
"\"" characters_item elements_rest |
126+
"0" fraction exponent elements_rest |
127+
[1-9] [0-9]* fraction exponent elements_rest |
120128
"-" [0-9] fraction exponent elements_rest |
121-
"-" [1-9] digits fraction exponent elements_rest |
129+
"-" [1-9] [0-9]* fraction exponent elements_rest |
122130
"true" elements_rest |
123131
"false" elements_rest |
124132
"null" elements_rest
125133
)
126134
elements_rest ::= (
127135
"" |
128-
"," ws elements |
129-
" " ws "," ws elements |
130-
"\n" ws "," ws elements |
131-
"\t" ws "," ws elements
136+
[ \n\t]* "," [ \n\t]* elements
132137
)
133-
characters ::= "" | [^"\\\r\n] characters | "\\" escape characters
138+
characters_and_colon ::= (
139+
"\"" [ \n\t]* ":" |
140+
[^"\\\x00-\x1F] characters_and_colon |
141+
"\\" escape characters_and_colon
142+
) (=[ \n\t]* [\"{[0-9tfn-])
143+
characters_and_comma ::= (
144+
"\"" [ \n\t]* "," |
145+
[^"\\\x00-\x1F] characters_and_comma |
146+
"\\" escape characters_and_comma
147+
) (=[ \n\t]* "\"")
148+
characters_and_embrace ::= (
149+
"\"" [ \n\t]* "}" |
150+
[^"\\\x00-\x1F] characters_and_embrace |
151+
"\\" escape characters_and_embrace
152+
) (=[ \n\t]* [},])
153+
characters_item ::= (
154+
"\"" |
155+
[^"\\\x00-\x1F] characters_item |
156+
"\\" escape characters_item
157+
) (= [ \n\t]* [,\]])
134158
escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
135-
digits ::= [0-9] | [0-9] digits
136-
fraction ::= "" | "." digits
137-
exponent ::= "" | "e" sign digits | "E" sign digits
159+
fraction ::= "" | "." [0-9] [0-9]*
160+
exponent ::= "" | "e" sign [0-9] [0-9]* | "E" sign [0-9] [0-9]*
138161
sign ::= "" | "+" | "-"
139-
ws ::= [ \n\t]*
140162
)";
141163

142164
BNFGrammar BNFGrammar::GetGrammarOfJSON() {
143-
static const BNFGrammar grammar =
144-
BNFGrammar::FromEBNFString(kJSONGrammarString, "main", true, false);
165+
static const BNFGrammar grammar = BNFGrammar::FromEBNFString(kJSONGrammarString, "main");
145166
return grammar;
146167
}
147168

cpp/serve/grammar/grammar.h

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,15 @@ using namespace tvm::runtime;
4444
* #### Types of RuleExprs
4545
* Every RuleExpr is represented by a type as well as a variable-length array containing its data.
4646
* RuleExpr has several types:
47+
* - Byte string: a string of bytes (0~255). Supports UTF-8 strings.
4748
* - Character class: a range of characters (each character is a unicode codepoint), e.g. [a-z],
48-
* [ac-z].
49-
* A single character is represented by a character class with the same lower and upper bound.
50-
* A string is represented by a sequence of character classes.
51-
* - Negated character class: all characters that are not in the range, e.g. [^a-z], [^ac-z]
49+
* [ac-z]. Can be negated: [^a-z], [^ac-z]. Currently only ASCII chars are allowed in [], but this
50+
* expression can accept/reject unicode chars.
51+
* - Character class star: a star quantifier of a character class. e.g. [a-z]*, [^a-z]*.
5252
* - EmptyStr: an empty string, i.e. ""
5353
* - Rule reference: a reference to another rule
5454
* - Sequence: a sequence of rule_exprs, e.g. ("a" "b"). These rule_exprs are concatenated together.
5555
* - Choices: a choice of rule_exprs, e.g. ("a" "b") | "c". Each rule_expr can be matched.
56-
* - Character class star: special support for a repetition of a character class. e.g. [a-z]*
5756
*
5857
* #### Storage of RuleExprs
5958
* Each type of RuleExpr has a different data format. For the format of each type of RuleExpr, see
@@ -76,6 +75,9 @@ class BNFGrammarNode : public Object {
7675
std::string name;
7776
/*! \brief The RuleExpr id of the body of the rule. */
7877
int32_t body_expr_id;
78+
/*! \brief The id of the associated lookahead assertion expr. For now it must be the id of a
79+
* sequence RuleExpr. -1 if it does not exist. */
80+
int32_t lookahead_assertion_id = -1;
7981
};
8082

8183
/*! \brief Get the number of rules. */
@@ -86,6 +88,8 @@ class BNFGrammarNode : public Object {
8688
<< "rule_id " << rule_id << " is out of bound";
8789
return rules_[rule_id];
8890
}
91+
/*! \brief Get the main rule id of the grammar. */
92+
int32_t GetMainRuleId() const { return main_rule_id_; }
8993
/*! \brief Get the main rule of the grammar. */
9094
const Rule& GetMainRule() const {
9195
DCHECK(main_rule_id_ >= 0 && main_rule_id_ < static_cast<int32_t>(rules_.size()))
@@ -95,10 +99,11 @@ class BNFGrammarNode : public Object {
9599

96100
/*! \brief The type of the rule expr. */
97101
enum class RuleExprType : int32_t {
98-
// data format: [lower0, upper0, lower1, upper1, ...]
102+
// data format: [byte0, byte1, ...]
103+
kByteString,
104+
// data format: [is_negative, lower0, upper0, lower1, upper1, ...]
99105
kCharacterClass,
100-
// data format: [lower0, upper0, lower1, upper1, ...]
101-
kNegCharacterClass,
106+
kCharacterClassStar,
102107
// data format: []
103108
kEmptyStr,
104109
// data format: [rule_id]
@@ -107,8 +112,6 @@ class BNFGrammarNode : public Object {
107112
kSequence,
108113
// data format: [rule_expr_id0, rule_expr_id1, ...]
109114
kChoices,
110-
// data format: [rule_expr_id]
111-
kCharacterClassStar,
112115
};
113116

114117
/*! \brief The object representing a rule expr. */
@@ -154,8 +157,8 @@ class BNFGrammarNode : public Object {
154157
std::vector<Rule> rules_;
155158
/*! \brief The data of all rule_exprs. */
156159
std::vector<int32_t> rule_expr_data_;
157-
/*! \brief The start index of every rule_expr in rule_expr_data_. rule_expr_id corresponds the
158-
* index of this vector. */
160+
/*! \brief The start index of every rule_expr in rule_expr_data_. rule_expr_id is the index
161+
* to the elements in this vector. */
159162
std::vector<int32_t> rule_expr_indptr_;
160163
/*! \brief The id of the main rule. */
161164
int32_t main_rule_id_ = -1;
@@ -168,25 +171,13 @@ class BNFGrammarNode : public Object {
168171
class BNFGrammar : public ObjectRef {
169172
public:
170173
/*!
171-
* \brief Construct a BNF grammar with a EBNF-formatted string. Will parse the string and
172-
* transform it into BNF AST.
174+
* \brief Construct a BNF grammar from an EBNF-formatted string. The grammar will be normalized
175+
* (simplified) by default.
173176
* \param ebnf_string The EBNF-formatted string.
174177
* \param main_rule The name of the main rule.
175-
* \param normalize Whether to normalize the grammar. Default: true. Only set to false for the
176-
* purpose of testing.
177-
*
178-
* \note In The normalized form of a BNF grammar, every rule is in the form:
179-
* `rule_name ::= ("" | (element1_1 element1_2 ...) | (element2_1 element2_2 ...) | ...)`.
180-
*
181-
* I.e. a list of choices, each choice is a sequence of elements. Elements can be a character
182-
* class or a rule reference. And if the rule can be empty, the first choice will be an empty
183-
* string.
184-
* \param simplify Whether to simplify the grammar to make matching more efficient. Default: true.
185-
* Not implemented yet.
186178
*/
187179
static BNFGrammar FromEBNFString(const std::string& ebnf_string,
188-
const std::string& main_rule = "main", bool normalize = true,
189-
bool simplify = true);
180+
const std::string& main_rule = "main");
190181

191182
/*!
192183
* \brief Construct a BNF grammar from the dumped JSON string.

0 commit comments

Comments
 (0)