Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions cpp/serve/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class EngineImpl : public Engine {
}
n->token_table_ =
Tokenizer::PostProcessTokenTable(n->tokenizer_->TokenTable(), token_table_postproc_method);
n->grammar_init_context_storage_ = GrammarInitContextStorage(n->token_table_);
n->grammar_init_context_cache_ = GrammarInitContextCache(n->token_table_);
// - Create the logit processor and sampler, and
// the DraftTokenWorkspaceManager for speculative decoding.
int max_num_tokens = engine_config->max_num_sequence;
Expand Down Expand Up @@ -499,9 +499,9 @@ class EngineImpl : public Engine {
if (response_format.type != "json_object") {
return std::nullopt;
} else if (!response_format.schema) {
return grammar_init_context_storage_->GetInitContextForJSON();
return grammar_init_context_cache_->GetInitContextForJSON();
} else {
return grammar_init_context_storage_->GetInitContextForJSONSchema(
return grammar_init_context_cache_->GetInitContextForJSONSchema(
response_format.schema.value());
}
}
Expand All @@ -513,7 +513,7 @@ class EngineImpl : public Engine {
Tokenizer tokenizer_;
std::vector<std::string> token_table_;
// Helper to get the grammar init context for requests.
GrammarInitContextStorage grammar_init_context_storage_;
GrammarInitContextCache grammar_init_context_cache_;
// Models
Array<Model> models_;
// Device that the models run on.
Expand Down
135 changes: 78 additions & 57 deletions cpp/serve/grammar/grammar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

#include "grammar.h"

#include "grammar_functor.h"
#include "grammar_parser.h"
#include "grammar_serializer.h"
#include "grammar_simplifier.h"
#include "json_schema_converter.h"

namespace mlc {
Expand All @@ -21,18 +21,28 @@ std::ostream& operator<<(std::ostream& os, const BNFGrammar& grammar) {
return os;
}

BNFGrammar BNFGrammar::FromEBNFString(const std::string& ebnf_string, const std::string& main_rule,
bool normalize, bool simplify) {
BNFGrammar BNFGrammar::FromEBNFString(const std::string& ebnf_string,
const std::string& main_rule) {
auto grammar = EBNFParser::Parse(ebnf_string, main_rule);
if (normalize) {
grammar = NestedRuleUnwrapper(grammar).Apply();
}
// Normalize the grammar by default
grammar = BNFGrammarNormalizer().Apply(grammar);
return grammar;
}

TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromEBNFString")
.set_body_typed([](String ebnf_string, String main_rule, bool normalize, bool simplify) {
return BNFGrammar::FromEBNFString(ebnf_string, main_rule, normalize, simplify);
.set_body_typed([](String ebnf_string, String main_rule) {
return BNFGrammar::FromEBNFString(ebnf_string, main_rule);
});

// Debug helper: parse the EBNF string with EBNFParser and return the result as-is,
// without applying any normalization pass afterwards.
BNFGrammar DebugFromEBNFStringNoNormalize(const std::string& ebnf_string,
                                          const std::string& main_rule) {
  auto parsed_grammar = EBNFParser::Parse(ebnf_string, main_rule);
  return parsed_grammar;
}

// Register DebugFromEBNFStringNoNormalize as the global function
// "mlc.serve.BNFGrammarDebugFromEBNFStringNoNormalize". The typed lambda takes the EBNF
// string and the main rule name and forwards both unchanged.
TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarDebugFromEBNFStringNoNormalize")
.set_body_typed([](String ebnf_string, String main_rule) {
return DebugFromEBNFStringNoNormalize(ebnf_string, main_rule);
});

BNFGrammar BNFGrammar::FromJSON(const std::string& json_string) {
Expand Down Expand Up @@ -69,79 +79,90 @@ TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromSchema").set_body([](TVMArgs args,
*rv = BNFGrammar::FromSchema(args[0], indent, separators, args[3]);
});

// Optimized json grammar for the speed of the grammar state matcher
const std::string kJSONGrammarString = R"(
main ::= (
"{" ws members_or_embrace |
"[" ws elements_or_embrace
"{" [ \n\t]* members_and_embrace |
"[" [ \n\t]* elements_or_embrace
)
value ::= (
"{" ws members_or_embrace |
"[" ws elements_or_embrace |
"\"" characters "\"" |
[0-9] fraction exponent |
[1-9] digits fraction exponent |
value_non_str ::= (
"{" [ \n\t]* members_and_embrace |
"[" [ \n\t]* elements_or_embrace |
"0" fraction exponent |
[1-9] [0-9]* fraction exponent |
"-" [0-9] fraction exponent |
"-" [1-9] digits fraction exponent |
"-" [1-9] [0-9]* fraction exponent |
"true" |
"false" |
"null"
)
members_or_embrace ::= (
"\"" characters "\"" ws ":" ws value members_rest ws "}" |
"}"
)
members ::= "\"" characters "\"" ws ":" ws value members_rest
members_rest ::= (
"" |
"," ws "\"" characters "\"" ws ":" ws value members_rest |
" " ws "," ws "\"" characters "\"" ws ":" ws value members_rest |
"\n" ws "," ws "\"" characters "\"" ws ":" ws value members_rest |
"\t" ws "," ws "\"" characters "\"" ws ":" ws value members_rest
)
) (= [ \n\t,}\]])
members_and_embrace ::= ("\"" characters_and_colon [ \n\t]* members_suffix | "}") (= [ \n\t,}\]])
members_suffix ::= (
value_non_str [ \n\t]* member_suffix_suffix |
"\"" characters_and_embrace |
"\"" characters_and_comma [ \n\t]* "\"" characters_and_colon [ \n\t]* members_suffix
) (= [ \n\t,}\]])
member_suffix_suffix ::= (
"}" |
"," [ \n\t]* "\"" characters_and_colon [ \n\t]* members_suffix
) (= [ \n\t,}\]])
elements_or_embrace ::= (
"{" ws members_or_embrace elements_rest ws "]" |
"[" ws elements_or_embrace elements_rest ws "]" |
"\"" characters "\"" elements_rest ws "]" |
[0-9] fraction exponent elements_rest ws "]" |
[1-9] digits fraction exponent elements_rest ws "]" |
"-" [0-9] fraction exponent elements_rest ws "]" |
"-" [1-9] digits fraction exponent elements_rest ws "]" |
"true" elements_rest ws "]" |
"false" elements_rest ws "]" |
"null" elements_rest ws "]" |
"{" [ \n\t]* members_and_embrace elements_rest [ \n\t]* "]" |
"[" [ \n\t]* elements_or_embrace elements_rest [ \n\t]* "]" |
"\"" characters_item elements_rest [ \n\t]* "]" |
"0" fraction exponent elements_rest [ \n\t]* "]" |
[1-9] [0-9]* fraction exponent elements_rest [ \n\t]* "]" |
"-" "0" fraction exponent elements_rest [ \n\t]* "]" |
"-" [1-9] [0-9]* fraction exponent elements_rest [ \n\t]* "]" |
"true" elements_rest [ \n\t]* "]" |
"false" elements_rest [ \n\t]* "]" |
"null" elements_rest [ \n\t]* "]" |
"]"
)
elements ::= (
"{" ws members_or_embrace elements_rest |
"[" ws elements_or_embrace elements_rest |
"\"" characters "\"" elements_rest |
[0-9] fraction exponent elements_rest |
[1-9] digits fraction exponent elements_rest |
"{" [ \n\t]* members_and_embrace elements_rest |
"[" [ \n\t]* elements_or_embrace elements_rest |
"\"" characters_item elements_rest |
"0" fraction exponent elements_rest |
[1-9] [0-9]* fraction exponent elements_rest |
"-" [0-9] fraction exponent elements_rest |
"-" [1-9] digits fraction exponent elements_rest |
"-" [1-9] [0-9]* fraction exponent elements_rest |
"true" elements_rest |
"false" elements_rest |
"null" elements_rest
)
elements_rest ::= (
"" |
"," ws elements |
" " ws "," ws elements |
"\n" ws "," ws elements |
"\t" ws "," ws elements
[ \n\t]* "," [ \n\t]* elements
)
characters ::= "" | [^"\\\r\n] characters | "\\" escape characters
characters_and_colon ::= (
"\"" [ \n\t]* ":" |
[^"\\\x00-\x1F] characters_and_colon |
"\\" escape characters_and_colon
) (=[ \n\t]* [\"{[0-9tfn-])
characters_and_comma ::= (
"\"" [ \n\t]* "," |
[^"\\\x00-\x1F] characters_and_comma |
"\\" escape characters_and_comma
) (=[ \n\t]* "\"")
characters_and_embrace ::= (
"\"" [ \n\t]* "}" |
[^"\\\x00-\x1F] characters_and_embrace |
"\\" escape characters_and_embrace
) (=[ \n\t]* [},])
characters_item ::= (
"\"" |
[^"\\\x00-\x1F] characters_item |
"\\" escape characters_item
) (= [ \n\t]* [,\]])
escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]
digits ::= [0-9] | [0-9] digits
fraction ::= "" | "." digits
exponent ::= "" | "e" sign digits | "E" sign digits
fraction ::= "" | "." [0-9] [0-9]*
exponent ::= "" | "e" sign [0-9] [0-9]* | "E" sign [0-9] [0-9]*
sign ::= "" | "+" | "-"
ws ::= [ \n\t]*
)";

BNFGrammar BNFGrammar::GetGrammarOfJSON() {
static const BNFGrammar grammar =
BNFGrammar::FromEBNFString(kJSONGrammarString, "main", true, false);
static const BNFGrammar grammar = BNFGrammar::FromEBNFString(kJSONGrammarString, "main");
return grammar;
}

Expand Down
45 changes: 18 additions & 27 deletions cpp/serve/grammar/grammar.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,15 @@ using namespace tvm::runtime;
* #### Types of RuleExprs
* Every RuleExpr is represented by a type as well as a variable-length array containing its data.
* RuleExpr has several types:
* - Byte string: a string of bytes (0~255). Supports UTF-8 strings.
* - Character class: a range of characters (each character is a unicode codepoint), e.g. [a-z],
* [ac-z].
* A single character is represented by a character class with the same lower and upper bound.
* A string is represented by a sequence of character classes.
* - Negated character class: all characters that are not in the range, e.g. [^a-z], [^ac-z]
* [ac-z]. Can be negated: [^a-z], [^ac-z]. Currently only ASCII characters are allowed inside
* [], but the expression can still accept/reject Unicode characters.
* - Character class star: a star quantifier of a character class. e.g. [a-z]*, [^a-z]*.
* - EmptyStr: an empty string, i.e. ""
* - Rule reference: a reference to another rule
* - Sequence: a sequence of rule_exprs, e.g. ("a" "b"). These rule_exprs are concatenated together.
* - Choices: a choice of rule_exprs, e.g. ("a" "b") | "c". Each rule_expr can be matched.
* - Character class star: special support for a repetition of a character class. e.g. [a-z]*
*
* #### Storage of RuleExprs
* Each type of RuleExpr has a different data format. For the format of each type of RuleExpr, see
Expand All @@ -76,6 +75,9 @@ class BNFGrammarNode : public Object {
std::string name;
/*! \brief The RuleExpr id of the body of the rule. */
int32_t body_expr_id;
/*! \brief The id of the associated lookahead assertion expr. For now it must be the id of a
* sequence RuleExpr; -1 if there is no lookahead assertion. */
int32_t lookahead_assertion_id = -1;
};

/*! \brief Get the number of rules. */
Expand All @@ -86,6 +88,8 @@ class BNFGrammarNode : public Object {
<< "rule_id " << rule_id << " is out of bound";
return rules_[rule_id];
}
/*! \brief Get the main rule id of the grammar. */
int32_t GetMainRuleId() const { return main_rule_id_; }
/*! \brief Get the main rule of the grammar. */
const Rule& GetMainRule() const {
DCHECK(main_rule_id_ >= 0 && main_rule_id_ < static_cast<int32_t>(rules_.size()))
Expand All @@ -95,10 +99,11 @@ class BNFGrammarNode : public Object {

/*! \brief The type of the rule expr. */
enum class RuleExprType : int32_t {
// data format: [lower0, upper0, lower1, upper1, ...]
// data format: [byte0, byte1, ...]
kByteString,
// data format: [is_negative, lower0, upper0, lower1, upper1, ...]
kCharacterClass,
// data format: [lower0, upper0, lower1, upper1, ...]
kNegCharacterClass,
kCharacterClassStar,
// data format: []
kEmptyStr,
// data format: [rule_id]
Expand All @@ -107,8 +112,6 @@ class BNFGrammarNode : public Object {
kSequence,
// data format: [rule_expr_id0, rule_expr_id1, ...]
kChoices,
// data format: [rule_expr_id]
kCharacterClassStar,
};

/*! \brief The object representing a rule expr. */
Expand Down Expand Up @@ -154,8 +157,8 @@ class BNFGrammarNode : public Object {
std::vector<Rule> rules_;
/*! \brief The data of all rule_exprs. */
std::vector<int32_t> rule_expr_data_;
/*! \brief The start index of every rule_expr in rule_expr_data_. rule_expr_id corresponds the
* index of this vector. */
/*! \brief The start index of every rule_expr in rule_expr_data_. A rule_expr_id is an index
* into this vector. */
std::vector<int32_t> rule_expr_indptr_;
/*! \brief The id of the main rule. */
int32_t main_rule_id_ = -1;
Expand All @@ -168,25 +171,13 @@ class BNFGrammarNode : public Object {
class BNFGrammar : public ObjectRef {
public:
/*!
* \brief Construct a BNF grammar with a EBNF-formatted string. Will parse the string and
* transform it into BNF AST.
* \brief Construct a BNF grammar from an EBNF-formatted string. The grammar will be normalized
* (simplified) by default.
* \param ebnf_string The EBNF-formatted string.
* \param main_rule The name of the main rule.
* \param normalize Whether to normalize the grammar. Default: true. Only set to false for the
* purpose of testing.
*
* \note In The normalized form of a BNF grammar, every rule is in the form:
* `rule_name ::= ("" | (element1_1 element1_2 ...) | (element2_1 element2_2 ...) | ...)`.
*
* I.e. a list of choices, each choice is a sequence of elements. Elements can be a character
* class or a rule reference. And if the rule can be empty, the first choice will be an empty
* string.
* \param simplify Whether to simplify the grammar to make matching more efficient. Default: true.
* Not implemented yet.
*/
static BNFGrammar FromEBNFString(const std::string& ebnf_string,
const std::string& main_rule = "main", bool normalize = true,
bool simplify = true);
const std::string& main_rule = "main");

/*!
* \brief Construct a BNF grammar from the dumped JSON string.
Expand Down
Loading