diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 616c463d9c..9b9cf81fe7 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -122,7 +122,7 @@ class EngineImpl : public Engine { } n->token_table_ = Tokenizer::PostProcessTokenTable(n->tokenizer_->TokenTable(), token_table_postproc_method); - n->grammar_init_context_storage_ = GrammarInitContextStorage(n->token_table_); + n->grammar_init_context_cache_ = GrammarInitContextCache(n->token_table_); // - Create the logit processor and sampler, and // the DraftTokenWorkspaceManager for speculative decoding. int max_num_tokens = engine_config->max_num_sequence; @@ -499,9 +499,9 @@ class EngineImpl : public Engine { if (response_format.type != "json_object") { return std::nullopt; } else if (!response_format.schema) { - return grammar_init_context_storage_->GetInitContextForJSON(); + return grammar_init_context_cache_->GetInitContextForJSON(); } else { - return grammar_init_context_storage_->GetInitContextForJSONSchema( + return grammar_init_context_cache_->GetInitContextForJSONSchema( response_format.schema.value()); } } @@ -513,7 +513,7 @@ class EngineImpl : public Engine { Tokenizer tokenizer_; std::vector token_table_; // Helper to get the grammar init context for requests. - GrammarInitContextStorage grammar_init_context_storage_; + GrammarInitContextCache grammar_init_context_cache_; // Models Array models_; // Device that the models run on. 
diff --git a/cpp/serve/grammar/grammar.cc b/cpp/serve/grammar/grammar.cc index c8d760538c..2f0d7f565f 100644 --- a/cpp/serve/grammar/grammar.cc +++ b/cpp/serve/grammar/grammar.cc @@ -5,9 +5,9 @@ #include "grammar.h" +#include "grammar_functor.h" #include "grammar_parser.h" #include "grammar_serializer.h" -#include "grammar_simplifier.h" #include "json_schema_converter.h" namespace mlc { @@ -21,18 +21,28 @@ std::ostream& operator<<(std::ostream& os, const BNFGrammar& grammar) { return os; } -BNFGrammar BNFGrammar::FromEBNFString(const std::string& ebnf_string, const std::string& main_rule, - bool normalize, bool simplify) { +BNFGrammar BNFGrammar::FromEBNFString(const std::string& ebnf_string, + const std::string& main_rule) { auto grammar = EBNFParser::Parse(ebnf_string, main_rule); - if (normalize) { - grammar = NestedRuleUnwrapper(grammar).Apply(); - } + // Normalize the grammar by default + grammar = BNFGrammarNormalizer().Apply(grammar); return grammar; } TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromEBNFString") - .set_body_typed([](String ebnf_string, String main_rule, bool normalize, bool simplify) { - return BNFGrammar::FromEBNFString(ebnf_string, main_rule, normalize, simplify); + .set_body_typed([](String ebnf_string, String main_rule) { + return BNFGrammar::FromEBNFString(ebnf_string, main_rule); + }); + +// Parse the EBNF string but not normalize it +BNFGrammar DebugFromEBNFStringNoNormalize(const std::string& ebnf_string, + const std::string& main_rule) { + return EBNFParser::Parse(ebnf_string, main_rule); +} + +TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarDebugFromEBNFStringNoNormalize") + .set_body_typed([](String ebnf_string, String main_rule) { + return DebugFromEBNFStringNoNormalize(ebnf_string, main_rule); }); BNFGrammar BNFGrammar::FromJSON(const std::string& json_string) { @@ -69,79 +79,90 @@ TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarFromSchema").set_body([](TVMArgs args, *rv = BNFGrammar::FromSchema(args[0], indent, separators, args[3]); }); +// 
Optimized json grammar for the speed of the grammar state matcher const std::string kJSONGrammarString = R"( main ::= ( - "{" ws members_or_embrace | - "[" ws elements_or_embrace + "{" [ \n\t]* members_and_embrace | + "[" [ \n\t]* elements_or_embrace ) -value ::= ( - "{" ws members_or_embrace | - "[" ws elements_or_embrace | - "\"" characters "\"" | - [0-9] fraction exponent | - [1-9] digits fraction exponent | +value_non_str ::= ( + "{" [ \n\t]* members_and_embrace | + "[" [ \n\t]* elements_or_embrace | + "0" fraction exponent | + [1-9] [0-9]* fraction exponent | "-" [0-9] fraction exponent | - "-" [1-9] digits fraction exponent | + "-" [1-9] [0-9]* fraction exponent | "true" | "false" | "null" -) -members_or_embrace ::= ( - "\"" characters "\"" ws ":" ws value members_rest ws "}" | - "}" -) -members ::= "\"" characters "\"" ws ":" ws value members_rest -members_rest ::= ( - "" | - "," ws "\"" characters "\"" ws ":" ws value members_rest | - " " ws "," ws "\"" characters "\"" ws ":" ws value members_rest | - "\n" ws "," ws "\"" characters "\"" ws ":" ws value members_rest | - "\t" ws "," ws "\"" characters "\"" ws ":" ws value members_rest -) +) (= [ \n\t,}\]]) +members_and_embrace ::= ("\"" characters_and_colon [ \n\t]* members_suffix | "}") (= [ \n\t,}\]]) +members_suffix ::= ( + value_non_str [ \n\t]* member_suffix_suffix | + "\"" characters_and_embrace | + "\"" characters_and_comma [ \n\t]* "\"" characters_and_colon [ \n\t]* members_suffix +) (= [ \n\t,}\]]) +member_suffix_suffix ::= ( + "}" | + "," [ \n\t]* "\"" characters_and_colon [ \n\t]* members_suffix +) (= [ \n\t,}\]]) elements_or_embrace ::= ( - "{" ws members_or_embrace elements_rest ws "]" | - "[" ws elements_or_embrace elements_rest ws "]" | - "\"" characters "\"" elements_rest ws "]" | - [0-9] fraction exponent elements_rest ws "]" | - [1-9] digits fraction exponent elements_rest ws "]" | - "-" [0-9] fraction exponent elements_rest ws "]" | - "-" [1-9] digits fraction exponent elements_rest ws "]" 
| - "true" elements_rest ws "]" | - "false" elements_rest ws "]" | - "null" elements_rest ws "]" | + "{" [ \n\t]* members_and_embrace elements_rest [ \n\t]* "]" | + "[" [ \n\t]* elements_or_embrace elements_rest [ \n\t]* "]" | + "\"" characters_item elements_rest [ \n\t]* "]" | + "0" fraction exponent elements_rest [ \n\t]* "]" | + [1-9] [0-9]* fraction exponent elements_rest [ \n\t]* "]" | + "-" "0" fraction exponent elements_rest [ \n\t]* "]" | + "-" [1-9] [0-9]* fraction exponent elements_rest [ \n\t]* "]" | + "true" elements_rest [ \n\t]* "]" | + "false" elements_rest [ \n\t]* "]" | + "null" elements_rest [ \n\t]* "]" | "]" ) elements ::= ( - "{" ws members_or_embrace elements_rest | - "[" ws elements_or_embrace elements_rest | - "\"" characters "\"" elements_rest | - [0-9] fraction exponent elements_rest | - [1-9] digits fraction exponent elements_rest | + "{" [ \n\t]* members_and_embrace elements_rest | + "[" [ \n\t]* elements_or_embrace elements_rest | + "\"" characters_item elements_rest | + "0" fraction exponent elements_rest | + [1-9] [0-9]* fraction exponent elements_rest | "-" [0-9] fraction exponent elements_rest | - "-" [1-9] digits fraction exponent elements_rest | + "-" [1-9] [0-9]* fraction exponent elements_rest | "true" elements_rest | "false" elements_rest | "null" elements_rest ) elements_rest ::= ( "" | - "," ws elements | - " " ws "," ws elements | - "\n" ws "," ws elements | - "\t" ws "," ws elements + [ \n\t]* "," [ \n\t]* elements ) -characters ::= "" | [^"\\\r\n] characters | "\\" escape characters +characters_and_colon ::= ( + "\"" [ \n\t]* ":" | + [^"\\\x00-\x1F] characters_and_colon | + "\\" escape characters_and_colon +) (=[ \n\t]* [\"{[0-9tfn-]) +characters_and_comma ::= ( + "\"" [ \n\t]* "," | + [^"\\\x00-\x1F] characters_and_comma | + "\\" escape characters_and_comma +) (=[ \n\t]* "\"") +characters_and_embrace ::= ( + "\"" [ \n\t]* "}" | + [^"\\\x00-\x1F] characters_and_embrace | + "\\" escape characters_and_embrace +) (=[ \n\t]* 
[},]) +characters_item ::= ( + "\"" | + [^"\\\x00-\x1F] characters_item | + "\\" escape characters_item +) (= [ \n\t]* [,\]]) escape ::= ["\\/bfnrt] | "u" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] -digits ::= [0-9] | [0-9] digits -fraction ::= "" | "." digits -exponent ::= "" | "e" sign digits | "E" sign digits +fraction ::= "" | "." [0-9] [0-9]* +exponent ::= "" | "e" sign [0-9] [0-9]* | "E" sign [0-9] [0-9]* sign ::= "" | "+" | "-" -ws ::= [ \n\t]* )"; BNFGrammar BNFGrammar::GetGrammarOfJSON() { - static const BNFGrammar grammar = - BNFGrammar::FromEBNFString(kJSONGrammarString, "main", true, false); + static const BNFGrammar grammar = BNFGrammar::FromEBNFString(kJSONGrammarString, "main"); return grammar; } diff --git a/cpp/serve/grammar/grammar.h b/cpp/serve/grammar/grammar.h index ba15e58af3..b7922301cb 100644 --- a/cpp/serve/grammar/grammar.h +++ b/cpp/serve/grammar/grammar.h @@ -44,16 +44,15 @@ using namespace tvm::runtime; * #### Types of RuleExprs * Every RuleExpr is represented by a type as well as a variable-length array containing its data. * RuleExpr has several types: + * - Byte string: a string of bytes (0~255). Supports UTF-8 strings. * - Character class: a range of characters (each character is a unicode codepoint), e.g. [a-z], - * [ac-z]. - * A single character is represented by a character class with the same lower and upper bound. - * A string is represented by a sequence of character classes. - * - Negated character class: all characters that are not in the range, e.g. [^a-z], [^ac-z] + * [ac-z]. Can be negated: [^a-z], [^ac-z]. Now only ascii chars is allowed in [], but this + * expression can accept/reject unicode chars. + * - Character class star: a star quantifier of a character class. e.g. [a-z]*, [^a-z]*. * - EmptyStr: an empty string, i.e. "" * - Rule reference: a reference to another rule * - Sequence: a sequence of rule_exprs, e.g. ("a" "b"). These rule_exprs are concatenated together. * - Choices: a choice of rule_exprs, e.g. 
("a" "b") | "c". Each rule_expr can be matched. - * - Character class star: special support for a repetition of a character class. e.g. [a-z]* * * #### Storage of RuleExprs * Each type of RuleExpr has a different data format. For the format of each type of RuleExpr, see @@ -76,6 +75,9 @@ class BNFGrammarNode : public Object { std::string name; /*! \brief The RuleExpr id of the body of the rule. */ int32_t body_expr_id; + /*! \brief The id of the associated lookahead assertion expr. For now it must be the id of a + * sequence RuleExpr, or -1 if the rule has no lookahead assertion. */ + int32_t lookahead_assertion_id = -1; }; /*! \brief Get the number of rules. */ @@ -86,6 +88,8 @@ class BNFGrammarNode : public Object { << "rule_id " << rule_id << " is out of bound"; return rules_[rule_id]; } + /*! \brief Get the main rule id of the grammar. */ + int32_t GetMainRuleId() const { return main_rule_id_; } /*! \brief Get the main rule of the grammar. */ const Rule& GetMainRule() const { DCHECK(main_rule_id_ >= 0 && main_rule_id_ < static_cast(rules_.size())) @@ -95,10 +99,11 @@ class BNFGrammarNode : public Object { /*! \brief The type of the rule expr. */ enum class RuleExprType : int32_t { - // data format: [lower0, upper0, lower1, upper1, ...] + // data format: [byte0, byte1, ...] + kByteString, + // data format: [is_negative, lower0, upper0, lower1, upper1, ...] kCharacterClass, - // data format: [lower0, upper0, lower1, upper1, ...] - kNegCharacterClass, + kCharacterClassStar, // data format: [] kEmptyStr, // data format: [rule_id] @@ -107,8 +112,6 @@ class BNFGrammarNode : public Object { kSequence, // data format: [rule_expr_id0, rule_expr_id1, ...] kChoices, - // data format: [rule_expr_id] - kCharacterClassStar, }; /*! \brief The object representing a rule expr. */ @@ -154,8 +157,8 @@ class BNFGrammarNode : public Object { std::vector rules_; /*! \brief The data of all rule_exprs. */ std::vector rule_expr_data_; - /*! \brief The start index of every rule_expr in rule_expr_data_. 
rule_expr_id corresponds the - * index of this vector. */ + /*! \brief The start index of every rule_expr in rule_expr_data_. rule_expr_id is the index + * to the elements in this vector. */ std::vector rule_expr_indptr_; /*! \brief The id of the main rule. */ int32_t main_rule_id_ = -1; @@ -168,25 +171,13 @@ class BNFGrammarNode : public Object { class BNFGrammar : public ObjectRef { public: /*! - * \brief Construct a BNF grammar with a EBNF-formatted string. Will parse the string and - * transform it into BNF AST. + * \brief Construct a BNF grammar with a EBNF-formatted string. The grammar will be normalized + * (simplified) by default. * \param ebnf_string The EBNF-formatted string. * \param main_rule The name of the main rule. - * \param normalize Whether to normalize the grammar. Default: true. Only set to false for the - * purpose of testing. - * - * \note In The normalized form of a BNF grammar, every rule is in the form: - * `rule_name ::= ("" | (element1_1 element1_2 ...) | (element2_1 element2_2 ...) | ...)`. - * - * I.e. a list of choices, each choice is a sequence of elements. Elements can be a character - * class or a rule reference. And if the rule can be empty, the first choice will be an empty - * string. - * \param simplify Whether to simplify the grammar to make matching more efficient. Default: true. - * Not implemented yet. */ static BNFGrammar FromEBNFString(const std::string& ebnf_string, - const std::string& main_rule = "main", bool normalize = true, - bool simplify = true); + const std::string& main_rule = "main"); /*! * \brief Construct a BNF grammar from the dumped JSON string. diff --git a/cpp/serve/grammar/grammar_builder.h b/cpp/serve/grammar/grammar_builder.h index 0854cc9789..7987a67f98 100644 --- a/cpp/serve/grammar/grammar_builder.h +++ b/cpp/serve/grammar/grammar_builder.h @@ -56,6 +56,16 @@ class BNFGrammarBuilder { return static_cast(grammar_->rule_expr_indptr_.size()) - 1; } + /*! 
+ * \brief Add a RuleExpr for string stored in bytes. + * \param bytes A vector of int32_t, each representing a byte (0~255) in the string. + * The string is stored in int32 vector to match the storage format of the grammar. + */ + int32_t AddByteString(const std::vector& bytes) { + return AddRuleExpr( + {RuleExprType::kByteString, bytes.data(), static_cast(bytes.size())}); + } + /*! * \brief One element of a character class, containing a lower and a upper bound. Both bounds are * inclusive. @@ -66,19 +76,39 @@ class BNFGrammarBuilder { }; /*! - * \brief Add a RuleExpr for character class. + * \brief Add a RuleExpr for a character class. * \param elements A vector of CharacterClassElement, each containing a lower and a upper bound. - * \param is_neg_range Whether the character class is negated. + * \param is_negative Whether the character class is negated. */ int32_t AddCharacterClass(const std::vector& elements, - bool is_neg_range = false) { + bool is_negative = false) { std::vector data; + data.reserve(1 + elements.size() * 2); + data.push_back(static_cast(is_negative)); for (const auto& range : elements) { data.push_back(range.lower); data.push_back(range.upper); } - auto type = is_neg_range ? RuleExprType::kNegCharacterClass : RuleExprType::kCharacterClass; - return AddRuleExpr({type, data.data(), static_cast(data.size())}); + return AddRuleExpr( + {RuleExprType::kCharacterClass, data.data(), static_cast(data.size())}); + } + + /*! + * \brief Add a RuleExpr for a star quantifier of a character class. + * \param elements A vector of CharacterClassElement, each containing a lower and a upper bound. + * \param is_negative Whether the character class is negated. 
+ */ + int32_t AddCharacterClassStar(const std::vector& elements, + bool is_negative = false) { + std::vector data; + data.reserve(1 + elements.size() * 2); + data.push_back(static_cast(is_negative)); + for (const auto& range : elements) { + data.push_back(range.lower); + data.push_back(range.upper); + } + return AddRuleExpr( + {RuleExprType::kCharacterClassStar, data.data(), static_cast(data.size())}); } /*! \brief Add a RuleExpr for empty string.*/ @@ -93,23 +123,14 @@ class BNFGrammarBuilder { /*! \brief Add a RuleExpr for RuleExpr sequence.*/ int32_t AddSequence(const std::vector& elements) { - std::vector data; - data.insert(data.end(), elements.begin(), elements.end()); - return AddRuleExpr({RuleExprType::kSequence, data.data(), static_cast(data.size())}); + return AddRuleExpr( + {RuleExprType::kSequence, elements.data(), static_cast(elements.size())}); } /*! \brief Add a RuleExpr for RuleExpr choices.*/ int32_t AddChoices(const std::vector& choices) { - std::vector data; - data.insert(data.end(), choices.begin(), choices.end()); - return AddRuleExpr({RuleExprType::kChoices, data.data(), static_cast(data.size())}); - } - - int32_t AddCharacterClassStar(int32_t element) { - std::vector data; - data.push_back(element); return AddRuleExpr( - {RuleExprType::kCharacterClassStar, data.data(), static_cast(data.size())}); + {RuleExprType::kChoices, choices.data(), static_cast(choices.size())}); } size_t NumRuleExprs() const { return grammar_->NumRuleExprs(); } @@ -154,7 +175,7 @@ class BNFGrammarBuilder { * rule body of a rule inserted by BNFGrammarBuilder::AddEmptyRule. */ void UpdateRuleBody(int32_t rule_id, int32_t body_expr_id) { - CHECK(rule_id < static_cast(grammar_->rules_.size())) + CHECK(rule_id >= 0 && rule_id < static_cast(grammar_->rules_.size())) << "Rule id " << rule_id << " is out of range."; grammar_->rules_[rule_id].body_expr_id = body_expr_id; } @@ -169,6 +190,28 @@ class BNFGrammarBuilder { UpdateRuleBody(rule_id, body_expr_id); } + /*! 
+ * \brief Add a lookahead assertion to a rule referred by the given rule_id. The lookahead + * assertion should be a sequence RuleExpr id. An id of -1 means no lookahead assertion. + */ + void AddLookaheadAssertion(int32_t rule_id, int32_t lookahead_assertion_id) { + CHECK(rule_id >= 0 && rule_id < static_cast(grammar_->rules_.size())) + << "Rule id " << rule_id << " is out of range."; + CHECK(grammar_->rules_[rule_id].lookahead_assertion_id == -1) + << "Rule " << rule_id << " already has a lookahead assertion."; + grammar_->rules_[rule_id].lookahead_assertion_id = lookahead_assertion_id; + } + + /*! + * \brief Add a lookahead assertion to a rule referred by the given name. The lookahead + * assertion should be a sequence RuleExpr id. An id of -1 means no lookahead assertion. + */ + void AddLookaheadAssertion(std::string rule_name, int32_t lookahead_assertion_id) { + int32_t rule_id = GetRuleId(rule_name); + CHECK(rule_id != -1) << "Rule " << rule_name << " is not found."; + AddLookaheadAssertion(rule_id, lookahead_assertion_id); + } + /*! * \brief Find a name for a new rule starting with the given name hint. Some integer suffix (_1, * _2, ...) may be added to avoid name conflict. diff --git a/cpp/serve/grammar/grammar_simplifier.cc b/cpp/serve/grammar/grammar_functor.cc similarity index 54% rename from cpp/serve/grammar/grammar_simplifier.cc rename to cpp/serve/grammar/grammar_functor.cc index 109b5d85e1..ae4e108233 100644 --- a/cpp/serve/grammar/grammar_simplifier.cc +++ b/cpp/serve/grammar/grammar_functor.cc @@ -1,56 +1,101 @@ /*! * Copyright (c) 2023 by Contributors - * \file serve/grammar/grammar_simplifier.cc + * \file serve/grammar/grammar_functor.cc */ -#include "grammar_simplifier.h" +#include "grammar_functor.h" + +#include "../../support/encoding.h" namespace mlc { namespace llm { namespace serve { /*! - * \brief Eliminates single-element sequence or choice nodes in the grammar. - * \example The sequence `(a)` or the choice `(a)` will be replaced by `a` in a rule. 
- * \example The rule `A ::= ((b) (((d))))` will be replaced by `A ::= (b d)`. + * \brief Eliminates single-element sequence or choice or character class in the grammar. + * \example `A ::= choices("a")` --> `A ::= "a"` (the body is a string) + * \example `A ::= sequence("a")` --> `A ::= "a"` (the body is a string) + * \example `A ::= [a-a]` --> `A ::= "a"` (the body is a string) */ -class SingleElementSequenceOrChoiceEliminator : public BNFGrammarMutator { +class SingleElementExprEliminator : public BNFGrammarMutator { public: using BNFGrammarMutator::Apply; using BNFGrammarMutator::BNFGrammarMutator; private: - int32_t VisitSequence(const RuleExpr& rule_expr) { + // Keep the sequence expr in lookahead assertion + int32_t VisitLookaheadAssertion(int32_t lookahead_assertion_id) final { + if (lookahead_assertion_id == -1) { + return -1; + } + auto rule_expr = grammar_->GetRuleExpr(lookahead_assertion_id); + CHECK(rule_expr.type == RuleExprType::kSequence); + + std::vector sequence_ids; + for (int32_t i : rule_expr) { + sequence_ids.push_back(VisitExpr(i)); + } + return builder_.AddSequence(sequence_ids); + } + + int32_t VisitSequence(const RuleExpr& rule_expr) final { std::vector sequence_ids; for (int32_t i : rule_expr) { - sequence_ids.push_back(VisitExpr(grammar_->GetRuleExpr(i))); + sequence_ids.push_back(VisitExpr(i)); } if (sequence_ids.size() == 1) { return sequence_ids[0]; - } else { - return builder_.AddSequence(sequence_ids); } + return builder_.AddSequence(sequence_ids); } - int32_t VisitChoices(const RuleExpr& rule_expr) { + int32_t VisitChoices(const RuleExpr& rule_expr) final { std::vector choice_ids; for (int32_t i : rule_expr) { - choice_ids.push_back(VisitExpr(grammar_->GetRuleExpr(i))); + choice_ids.push_back(VisitExpr(i)); } if (choice_ids.size() == 1) { return choice_ids[0]; - } else { - return builder_.AddChoices(choice_ids); } + return builder_.AddChoices(choice_ids); + } + + int32_t VisitCharacterClass(const RuleExpr& rule_expr) final { + if 
(rule_expr.data_len == 3 && rule_expr[0] == 0 && rule_expr[1] == rule_expr[2]) { + std::string str = PrintAsUTF8(rule_expr[1]); + std::vector bytes; + bytes.reserve(str.size()); + for (unsigned char c : str) { + bytes.push_back(static_cast(c)); + } + return builder_.AddByteString(bytes); + } + return builder_.AddRuleExpr(rule_expr); } }; -class NestedRuleUnwrapperImpl : public BNFGrammarMutator { +/*! + * \brief Unwrap the rules containing nested expressions. After unwrapping, each rule will be in + * the form: `rule_name ::= ("" | (element1_1 element1_2 ...) | (element2_1 element2_2 ...) | ...)`. + * + * I.e. a list of choices, each choice is a sequence of elements. Elements can be a character class + * or a rule reference. And if the rule can be empty, the first choice will be an empty string. + * + * \example The rule `A ::= ((a) (((b)) (c)) "")` will be replaced by `A ::= ((a b c))`. One choice + * containing a sequence of three elements. The empty string is removed. + * \example The rule `A ::= (a | (b | (c | "")))` will be replaced by + * `A ::= ("" | (a) | (b) | (c))`. The first choice is an empty string, and each of the other three + * choices is a sequence containing a single element. + * \example The rule `A ::= (a | (b (c | d)))` will be replaced by + * `A ::= ((a) | (b B)), B ::= ((c) | (d))`. A new rule B is created to represent the nested + * choices. 
+ */ +class NestedRuleUnwrapper : public BNFGrammarMutator { public: using BNFGrammarMutator::BNFGrammarMutator; - BNFGrammar Apply() final { - grammar_ = SingleElementSequenceOrChoiceEliminator(grammar_).Apply(); + BNFGrammar Apply(const BNFGrammar& grammar) final { + Init(grammar); for (int i = 0; i < static_cast(grammar_->NumRules()); ++i) { builder_.AddEmptyRule(grammar_->GetRule(i).name); } @@ -60,11 +105,20 @@ class NestedRuleUnwrapperImpl : public BNFGrammarMutator { cur_rule_name_ = rule.name; auto new_body_expr_id = VisitRuleBody(rule_expr); builder_.UpdateRuleBody(i, new_body_expr_id); + builder_.AddLookaheadAssertion(i, VisitLookaheadAssertion(rule.lookahead_assertion_id)); } return builder_.Get(grammar_->GetMainRule().name); } private: + int32_t VisitLookaheadAssertion(int32_t lookahead_assertion_id) final { + if (lookahead_assertion_id == -1) { + return -1; + } + auto assertion_expr = grammar_->GetRuleExpr(lookahead_assertion_id); + return builder_.AddSequence(VisitSequence_(assertion_expr)); + } + /*! \brief Visit a RuleExpr as a rule body. 
*/ int32_t VisitRuleBody(const RuleExpr& rule_expr) { switch (rule_expr.type) { @@ -74,12 +128,11 @@ class NestedRuleUnwrapperImpl : public BNFGrammarMutator { return builder_.AddChoices(VisitChoices_(rule_expr)); case RuleExprType::kEmptyStr: return builder_.AddChoices({builder_.AddEmptyStr()}); + case RuleExprType::kByteString: case RuleExprType::kCharacterClass: - case RuleExprType::kNegCharacterClass: + case RuleExprType::kCharacterClassStar: case RuleExprType::kRuleRef: return builder_.AddChoices({builder_.AddSequence({builder_.AddRuleExpr(rule_expr)})}); - case RuleExprType::kCharacterClassStar: - return builder_.AddCharacterClassStar(VisitExpr(grammar_->GetRuleExpr(rule_expr[0]))); default: LOG(FATAL) << "Unexpected sequence type: " << static_cast(rule_expr.type); } @@ -104,14 +157,12 @@ class NestedRuleUnwrapperImpl : public BNFGrammarMutator { case RuleExprType::kEmptyStr: found_empty = true; break; + case RuleExprType::kByteString: case RuleExprType::kCharacterClass: - case RuleExprType::kNegCharacterClass: + case RuleExprType::kCharacterClassStar: case RuleExprType::kRuleRef: VisitElementInChoices(choice_expr, &new_choice_ids); break; - case RuleExprType::kCharacterClassStar: - VisitCharacterClassStarInChoices(choice_expr, &new_choice_ids); - break; default: LOG(FATAL) << "Unexpected choice type: " << static_cast(choice_expr.type); } @@ -154,16 +205,6 @@ class NestedRuleUnwrapperImpl : public BNFGrammarMutator { new_choice_ids->push_back(builder_.AddSequence({sub_expr_id})); } - /*! \brief Visit a character class star RuleExpr that is one of a list of choices. 
*/ - void VisitCharacterClassStarInChoices(const RuleExpr& rule_expr, - std::vector* new_choice_ids) { - auto sub_expr_id = builder_.AddRuleExpr(grammar_->GetRuleExpr(rule_expr[0])); - auto new_star_id = builder_.AddCharacterClassStar(sub_expr_id); - auto new_rule_id = builder_.AddRuleWithHint(cur_rule_name_ + "_star", new_star_id); - auto new_rule_ref_id = builder_.AddRuleRef(new_rule_id); - new_choice_ids->push_back(builder_.AddSequence({new_rule_ref_id})); - } - /*! * \brief Visit a RuleExpr containing a sequence. * \returns A list of new sequence RuleExpr ids. @@ -171,26 +212,24 @@ class NestedRuleUnwrapperImpl : public BNFGrammarMutator { std::vector VisitSequence_(const RuleExpr& rule_expr) { std::vector new_sequence_ids; for (auto i : rule_expr) { - auto seq_expr = grammar_->GetRuleExpr(i); - switch (seq_expr.type) { + auto element_expr = grammar_->GetRuleExpr(i); + switch (element_expr.type) { case RuleExprType::kSequence: - VisitSequenceInSequence(seq_expr, &new_sequence_ids); + VisitSequenceInSequence(element_expr, &new_sequence_ids); break; case RuleExprType::kChoices: - VisitChoiceInSequence(seq_expr, &new_sequence_ids); + VisitChoiceInSequence(element_expr, &new_sequence_ids); break; case RuleExprType::kEmptyStr: break; + case RuleExprType::kByteString: case RuleExprType::kCharacterClass: - case RuleExprType::kNegCharacterClass: - case RuleExprType::kRuleRef: - VisitElementInSequence(seq_expr, &new_sequence_ids); - break; case RuleExprType::kCharacterClassStar: - VisitCharacterClassStarInSequence(seq_expr, &new_sequence_ids); + case RuleExprType::kRuleRef: + VisitElementInSequence(element_expr, &new_sequence_ids); break; default: - LOG(FATAL) << "Unexpected sequence type: " << static_cast(seq_expr.type); + LOG(FATAL) << "Unexpected sequence type: " << static_cast(element_expr.type); } } return new_sequence_ids; @@ -223,22 +262,58 @@ class NestedRuleUnwrapperImpl : public BNFGrammarMutator { void VisitElementInSequence(const RuleExpr& rule_expr, 
std::vector* new_sequence_ids) { new_sequence_ids->push_back(builder_.AddRuleExpr(rule_expr)); } +}; - /*! \brief Visit a character class star RuleExpr that is in a sequence. */ - void VisitCharacterClassStarInSequence(const RuleExpr& rule_expr, - std::vector* new_sequence_ids) { - auto sub_expr_id = builder_.AddRuleExpr(grammar_->GetRuleExpr(rule_expr[0])); - auto new_star_id = builder_.AddCharacterClassStar(sub_expr_id); - auto new_rule_id = builder_.AddRuleWithHint(cur_rule_name_ + "_star", new_star_id); - auto new_rule_ref_id = builder_.AddRuleRef(new_rule_id); - new_sequence_ids->push_back(new_rule_ref_id); - } +class ByteStringFuser : public BNFGrammarMutator { + public: + using BNFGrammarMutator::Apply; + using BNFGrammarMutator::BNFGrammarMutator; - /*! \brief The name of the current rule being visited. */ - std::string cur_rule_name_; + private: + /*! + * \brief Visit a sequence RuleExpr, fusing each run of adjacent byte strings into one. + * \returns The id of the rebuilt sequence RuleExpr. + */ + int32_t VisitSequence(const RuleExpr& rule_expr) final { + std::vector new_sequence_ids; + std::vector cur_byte_string; + for (auto i : rule_expr) { + auto element_expr = grammar_->GetRuleExpr(i); + if (element_expr.type == RuleExprType::kByteString) { + cur_byte_string.insert(cur_byte_string.end(), element_expr.begin(), element_expr.end()); + continue; + } else { + if (!cur_byte_string.empty()) { + new_sequence_ids.push_back(builder_.AddByteString(cur_byte_string)); + cur_byte_string.clear(); + } + new_sequence_ids.push_back(builder_.AddRuleExpr(element_expr)); + } + } + if (!cur_byte_string.empty()) { + new_sequence_ids.push_back(builder_.AddByteString(cur_byte_string)); + } + return builder_.AddSequence(new_sequence_ids); + } }; -BNFGrammar NestedRuleUnwrapper::Apply() { return NestedRuleUnwrapperImpl(grammar_).Apply(); } +// Return the list of all normalizers in the class. The normalizers are applied one by one. 
+std::vector> BNFGrammarNormalizer::GetNormalizerList() { + std::vector> normalizer_mutators; + normalizer_mutators.emplace_back(std::make_unique()); + normalizer_mutators.emplace_back(std::make_unique()); + normalizer_mutators.emplace_back(std::make_unique()); + return normalizer_mutators; +} + +BNFGrammar BNFGrammarNormalizer::Apply(const BNFGrammar& grammar) { + std::vector> normalizer_mutators = GetNormalizerList(); + grammar_ = grammar; + for (auto& mutator : normalizer_mutators) { + grammar_ = mutator->Apply(grammar_); + } + return grammar_; +} } // namespace serve } // namespace llm diff --git a/cpp/serve/grammar/grammar_simplifier.h b/cpp/serve/grammar/grammar_functor.h similarity index 58% rename from cpp/serve/grammar/grammar_simplifier.h rename to cpp/serve/grammar/grammar_functor.h index 50f3804387..123700778e 100644 --- a/cpp/serve/grammar/grammar_simplifier.h +++ b/cpp/serve/grammar/grammar_functor.h @@ -1,11 +1,11 @@ /*! * Copyright (c) 2023 by Contributors - * \file serve/grammar/grammar_simplifier.h + * \file serve/grammar/grammar_functor.h * \brief The header for the simplification of the BNF AST. */ -#ifndef MLC_LLM_SERVE_GRAMMAR_GRAMMAR_SIMPLIFIER_H_ -#define MLC_LLM_SERVE_GRAMMAR_GRAMMAR_SIMPLIFIER_H_ +#ifndef MLC_LLM_SERVE_GRAMMAR_GRAMMAR_FUNCTOR_H_ +#define MLC_LLM_SERVE_GRAMMAR_GRAMMAR_FUNCTOR_H_ #include #include @@ -27,29 +27,44 @@ namespace serve { * are void (for visitor) and BNFGrammar (for mutator). */ template -class BNFGrammarMutator { +class BNFGrammarFunctor { public: /*! * \brief Constructor. * \param grammar The grammar to visit or mutate. */ - explicit BNFGrammarMutator(const BNFGrammar& grammar) : grammar_(grammar) {} + explicit BNFGrammarFunctor() {} /*! * \brief Apply the transformation to the grammar, or visit the grammar. * \return The transformed grammar, or the visiting result, or void. - * \note Should be called only once after the mutator is constructed. 
*/ - virtual ReturnType Apply() { - if constexpr (std::is_same::value && std::is_same::value) { + virtual ReturnType Apply(const BNFGrammar& grammar) { + Init(grammar); + if constexpr (std::is_same::value) { for (int i = 0; i < static_cast(grammar_->NumRules()); ++i) { auto rule = grammar_->GetRule(i); - auto rule_expr = grammar_->GetRuleExpr(rule.body_expr_id); - auto new_body_expr_id = VisitExpr(rule_expr); - builder_.AddRule(rule.name, new_body_expr_id); + cur_rule_name_ = rule.name; + VisitExpr(rule.body_expr_id); + VisitLookaheadAssertion(rule.lookahead_assertion_id); + } + } else if constexpr (std::is_same::value && + std::is_same::value) { + // First add empty rules to ensure the new rule ids the same as the old ones, then update + // the rule bodies + for (int i = 0; i < static_cast(grammar_->NumRules()); ++i) { + builder_.AddEmptyRule(grammar_->GetRule(i).name); + } + for (int i = 0; i < static_cast(grammar_->NumRules()); ++i) { + auto rule = grammar_->GetRule(i); + cur_rule_name_ = rule.name; + auto new_body_expr_id = VisitExpr(rule.body_expr_id); + builder_.UpdateRuleBody(i, new_body_expr_id); + // Handle lookahead assertion + builder_.AddLookaheadAssertion(i, VisitLookaheadAssertion(rule.lookahead_assertion_id)); } return builder_.Get(grammar_->GetMainRule().name); - } else if constexpr (!std::is_same::value) { + } else { return ReturnType(); } } @@ -59,6 +74,25 @@ class BNFGrammarMutator { using RuleExpr = BNFGrammarNode::RuleExpr; using RuleExprType = BNFGrammarNode::RuleExprType; + /*! \brief Initialize the functor. Should be called at the beginning of Apply(). */ + virtual void Init(const BNFGrammar& grammar) { + grammar_ = grammar; + builder_ = BNFGrammarBuilder(); + } + + /*! \brief Visit a lookahead assertion expr referred by id. */ + virtual T VisitLookaheadAssertion(int32_t lookahead_assertion_id) { + if (lookahead_assertion_id == -1) { + return -1; + } + return VisitExpr(lookahead_assertion_id); + } + + /*! \brief Visit a RuleExpr by id. 
*/ + virtual T VisitExpr(int32_t old_rule_expr_id) { + return VisitExpr(grammar_->GetRuleExpr(old_rule_expr_id)); + } + /*! \brief Visit a RuleExpr. Dispatch to the corresponding Visit function. */ virtual T VisitExpr(const RuleExpr& rule_expr) { switch (rule_expr.type) { @@ -68,47 +102,48 @@ class BNFGrammarMutator { return VisitChoices(rule_expr); case RuleExprType::kEmptyStr: return VisitEmptyStr(rule_expr); + case RuleExprType::kByteString: + return VisitByteString(rule_expr); case RuleExprType::kCharacterClass: - case RuleExprType::kNegCharacterClass: return VisitCharacterClass(rule_expr); - case RuleExprType::kRuleRef: - return VisitRuleRef(rule_expr); case RuleExprType::kCharacterClassStar: return VisitCharacterClassStar(rule_expr); + case RuleExprType::kRuleRef: + return VisitRuleRef(rule_expr); default: LOG(FATAL) << "Unexpected sequence type: " << static_cast(rule_expr.type); } } - /*! \brief Visit a sequence RuleExpr. */ - virtual T VisitSequence(const RuleExpr& rule_expr) { + /*! \brief Visit a choices RuleExpr. */ + virtual T VisitChoices(const RuleExpr& rule_expr) { if constexpr (std::is_same::value) { for (auto i : rule_expr) { - VisitExpr(grammar_->GetRuleExpr(i)); + VisitExpr(i); } } else if constexpr (std::is_same::value) { - std::vector sequence_ids; + std::vector choice_ids; for (int32_t i : rule_expr) { - sequence_ids.push_back(VisitExpr(grammar_->GetRuleExpr(i))); + choice_ids.push_back(VisitExpr(i)); } - return builder_.AddSequence(sequence_ids); + return builder_.AddChoices(choice_ids); } else { return T(); } } - /*! \brief Visit a choices RuleExpr. */ - virtual T VisitChoices(const RuleExpr& rule_expr) { + /*! \brief Visit a sequence RuleExpr. 
*/ + virtual T VisitSequence(const RuleExpr& rule_expr) { if constexpr (std::is_same::value) { for (auto i : rule_expr) { - VisitExpr(grammar_->GetRuleExpr(i)); + VisitExpr(i); } } else if constexpr (std::is_same::value) { - std::vector choice_ids; + std::vector sequence_ids; for (int32_t i : rule_expr) { - choice_ids.push_back(VisitExpr(grammar_->GetRuleExpr(i))); + sequence_ids.push_back(VisitExpr(i)); } - return builder_.AddChoices(choice_ids); + return builder_.AddSequence(sequence_ids); } else { return T(); } @@ -128,23 +163,18 @@ class BNFGrammarMutator { /*! \brief Visit an empty string RuleExpr. */ virtual T VisitEmptyStr(const RuleExpr& rule_expr) { return VisitElement(rule_expr); } + /*! \brief Visit a character class RuleExpr. */ + virtual T VisitByteString(const RuleExpr& rule_expr) { return VisitElement(rule_expr); } + /*! \brief Visit a character class RuleExpr. */ virtual T VisitCharacterClass(const RuleExpr& rule_expr) { return VisitElement(rule_expr); } + /*! \brief Visit a star quantifier RuleExpr. */ + virtual T VisitCharacterClassStar(const RuleExpr& rule_expr) { return VisitElement(rule_expr); } + /*! \brief Visit a rule reference RuleExpr. */ virtual T VisitRuleRef(const RuleExpr& rule_expr) { return VisitElement(rule_expr); } - /*! \brief Visit a star quantifier RuleExpr. */ - virtual T VisitCharacterClassStar(const RuleExpr& rule_expr) { - if constexpr (std::is_same::value) { - VisitExpr(grammar_->GetRuleExpr(rule_expr[0])); - } else if constexpr (std::is_same::value) { - return builder_.AddCharacterClassStar(VisitExpr(grammar_->GetRuleExpr(rule_expr[0]))); - } else { - return T(); - } - } - /*! \brief The grammar to visit or mutate. */ BNFGrammar grammar_; /*! @@ -152,33 +182,38 @@ class BNFGrammarMutator { * can be used to build a new grammar in subclasses. */ BNFGrammarBuilder builder_; + /*! \brief The name of the current rule being visited. */ + std::string cur_rule_name_; }; /*! 
- * \brief Unwrap the rules containing nested expressions. After unwrapping, each rule will be in - * the form: `rule_name ::= ("" | (element1_1 element1_2 ...) | (element2_1 element2_2 ...) | ...)`. - * - * I.e. a list of choices, each choice is a sequence of elements. Elements can be a character class - * or a rule reference. And if the rule can be empty, the first choice will be an empty string. - * - * \example The rule `A ::= ((a) (((b)) (c)) "")` will be replaced by `A ::= ((a b c))`. One choice - * containing a sequence of three elements. The empty string is removed. - * \example The rule `A ::= (a | (b | (c | "")))` will be replaced by - * `A ::= ("" | (a) | (b) | (c))`. The first choice is an empty string, and each of the other three - * choices is a sequence containing a single element. - * \example The rule `A ::= (a | (b (c | d)))` will be replaced by - * `A ::= ((a) | (b B)), B ::= ((c) | (d))`. A new rule B is created to represent the nested - * choices. + * \brief Visitor of BNFGrammar. + * \tparam ReturnType The return type of the Apply() function. Denotes the collected information. */ -class NestedRuleUnwrapper : public BNFGrammarMutator { +template +using BNFGrammarVisitor = BNFGrammarFunctor; + +/*! + * \brief Mutator of BNFGrammar. The Apply() function returns the updated grammar. + */ +using BNFGrammarMutator = BNFGrammarFunctor; + +/*! + * \brief Normalize a BNFGrammar: expand the nested rules, combine consequent sequences and strings, + * etc. 
+ */ +class BNFGrammarNormalizer : public BNFGrammarMutator { public: using BNFGrammarMutator::BNFGrammarMutator; - BNFGrammar Apply() final; + BNFGrammar Apply(const BNFGrammar& grammar) final; + + private: + std::vector> GetNormalizerList(); }; } // namespace serve } // namespace llm } // namespace mlc -#endif // MLC_LLM_SERVE_GRAMMAR_GRAMMAR_SIMPLIFIER_H_ +#endif // MLC_LLM_SERVE_GRAMMAR_GRAMMAR_FUNCTOR_H_ diff --git a/cpp/serve/grammar/grammar_parser.cc b/cpp/serve/grammar/grammar_parser.cc index a4eda4e395..2799ee4ba9 100644 --- a/cpp/serve/grammar/grammar_parser.cc +++ b/cpp/serve/grammar/grammar_parser.cc @@ -29,6 +29,7 @@ class EBNFParserImpl { int32_t ParseRuleRef(); int32_t ParseElement(); int32_t ParseQuantifier(); + int32_t ParseLookaheadAssertion(); int32_t ParseSequence(); int32_t ParseChoices(); Rule ParseRule(); @@ -157,10 +158,10 @@ int32_t EBNFParserImpl::ParseCharacterClass() { } auto [codepoint, new_cur] = ParseNextUTF8OrEscaped(cur_, kCustomEscapeMap); - if (codepoint == static_cast(CharHandlingError::kInvalidUtf8)) { + if (codepoint == CharHandlingError::kInvalidUTF8) { ThrowParseError("Invalid UTF8 sequence"); } - if (codepoint == static_cast(CharHandlingError::kInvalidEscape)) { + if (codepoint == CharHandlingError::kInvalidEscape) { ThrowParseError("Invalid escape sequence"); } Consume(new_cur - cur_); @@ -189,26 +190,37 @@ int32_t EBNFParserImpl::ParseCharacterClass() { // parse a c style string with utf8 support int32_t EBNFParserImpl::ParseString() { - std::vector character_classes; + std::vector codepoints; while (Peek() && Peek() != '\"') { if (Peek() == '\r' || Peek() == '\n') { ThrowParseError("There should be no newline character in a string literal"); } auto [codepoint, new_cur] = ParseNextUTF8OrEscaped(cur_); - if (codepoint == static_cast(CharHandlingError::kInvalidUtf8)) { + if (codepoint == CharHandlingError::kInvalidUTF8) { ThrowParseError("Invalid utf8 sequence"); } - if (codepoint == 
static_cast(CharHandlingError::kInvalidEscape)) { + if (codepoint == CharHandlingError::kInvalidEscape) { ThrowParseError("Invalid escape sequence"); } Consume(new_cur - cur_); - character_classes.push_back(builder_.AddCharacterClass({{codepoint, codepoint}})); + codepoints.push_back(codepoint); } - if (character_classes.empty()) { + if (codepoints.empty()) { return builder_.AddEmptyStr(); } - return builder_.AddSequence(character_classes); + + // convert codepoints to string + std::string str; + for (auto codepoint : codepoints) { + str += PrintAsUTF8(codepoint); + } + // convert str to int32_t vector + std::vector bytes; + for (auto c : str) { + bytes.push_back(static_cast(c)); + } + return builder_.AddByteString(bytes); } int32_t EBNFParserImpl::ParseRuleRef() { @@ -264,9 +276,11 @@ int32_t EBNFParserImpl::ParseElement() { } int32_t EBNFParserImpl::HandleStarQuantifier(int32_t rule_expr_id) { - if (builder_.GetRuleExpr(rule_expr_id).type == BNFGrammarBuilder::RuleExprType::kCharacterClass) { + BNFGrammarNode::RuleExpr rule_expr = builder_.GetRuleExpr(rule_expr_id); + if (rule_expr.type == BNFGrammarBuilder::RuleExprType::kCharacterClass) { // We have special handling for character class star, e.g. 
[a-z]* - return builder_.AddCharacterClassStar(rule_expr_id); + rule_expr.type = BNFGrammarBuilder::RuleExprType::kCharacterClassStar; + return builder_.AddRuleExpr(rule_expr); } else { // For other star quantifiers, we transform it into a rule: // a* --> rule ::= a rule | "" @@ -327,12 +341,11 @@ int32_t EBNFParserImpl::ParseQuantifier() { int32_t EBNFParserImpl::ParseSequence() { std::vector elements; - elements.push_back(ParseQuantifier()); - ConsumeSpace(in_parentheses_); - while (Peek() && Peek() != '|' && Peek() != ')' && Peek() != '\n' && Peek() != '\r') { + do { elements.push_back(ParseQuantifier()); ConsumeSpace(in_parentheses_); - } + } while (Peek() && Peek() != '|' && Peek() != ')' && Peek() != '\n' && Peek() != '\r' && + (Peek() != '(' || Peek(1) != '=')); return builder_.AddSequence(elements); } @@ -350,6 +363,24 @@ int32_t EBNFParserImpl::ParseChoices() { return builder_.AddChoices(choices); } +int32_t EBNFParserImpl::ParseLookaheadAssertion() { + if (Peek() != '(' || Peek(1) != '=') { + return -1; + } + Consume(2); + auto prev_in_parentheses = in_parentheses_; + in_parentheses_ = true; + ConsumeSpace(in_parentheses_); + auto result = ParseSequence(); + ConsumeSpace(in_parentheses_); + if (Peek() != ')') { + ThrowParseError("Expect )"); + } + Consume(); + in_parentheses_ = prev_in_parentheses; + return result; +} + EBNFParserImpl::Rule EBNFParserImpl::ParseRule() { std::string name = ParseName(); cur_rule_name_ = name; @@ -359,7 +390,10 @@ EBNFParserImpl::Rule EBNFParserImpl::ParseRule() { } Consume(3); ConsumeSpace(); - return {name, ParseChoices()}; + auto body_id = ParseChoices(); + ConsumeSpace(); + auto lookahead_id = ParseLookaheadAssertion(); + return {name, body_id, lookahead_id}; } void EBNFParserImpl::BuildRuleNameToId() { @@ -399,8 +433,14 @@ BNFGrammar EBNFParserImpl::DoParse(std::string ebnf_string, std::string main_rul ResetStringIterator(ebnf_string.c_str()); ConsumeSpace(); while (Peek()) { + // Throw error when there are multiple 
lookahead assertions + if (Peek() == '(' && Peek(1) == '=') { + ThrowParseError("Unexpected lookahead assertion"); + } auto new_rule = ParseRule(); builder_.UpdateRuleBody(new_rule.name, new_rule.body_expr_id); + // Update the lookahead assertion + builder_.AddLookaheadAssertion(new_rule.name, new_rule.lookahead_assertion_id); ConsumeSpace(); } diff --git a/cpp/serve/grammar/grammar_parser.h b/cpp/serve/grammar/grammar_parser.h index 4d10e8eb0d..94ac3d4ce1 100644 --- a/cpp/serve/grammar/grammar_parser.h +++ b/cpp/serve/grammar/grammar_parser.h @@ -23,7 +23,7 @@ using namespace tvm::runtime; * \details This function accepts the EBNF notation defined in the W3C XML Specification * (https://www.w3.org/TR/xml/#sec-notation), which is a popular standard, with the following * changes: - * - Using # as comment mark instead of /**\/ + * - Using # as comment mark instead of C-style comments * - Accept C-style unicode escape sequence \u01AB, \U000001AB, \xAB instead of #x0123 * - Rule A-B (match A and not match B) is not supported yet * diff --git a/cpp/serve/grammar/grammar_serializer.cc b/cpp/serve/grammar/grammar_serializer.cc index c3c2c88baa..5176b9f102 100644 --- a/cpp/serve/grammar/grammar_serializer.cc +++ b/cpp/serve/grammar/grammar_serializer.cc @@ -18,7 +18,11 @@ namespace serve { using namespace tvm::runtime; std::string BNFGrammarPrinter::PrintRule(const Rule& rule) { - return rule.name + " ::= " + PrintRuleExpr(rule.body_expr_id); + std::string res = rule.name + " ::= " + PrintRuleExpr(rule.body_expr_id); + if (rule.lookahead_assertion_id != -1) { + res += " (=" + PrintRuleExpr(rule.lookahead_assertion_id) + ")"; + } + return res; } std::string BNFGrammarPrinter::PrintRule(int32_t rule_id) { @@ -28,10 +32,12 @@ std::string BNFGrammarPrinter::PrintRule(int32_t rule_id) { std::string BNFGrammarPrinter::PrintRuleExpr(const RuleExpr& rule_expr) { std::string result; switch (rule_expr.type) { + case RuleExprType::kByteString: + return PrintByteString(rule_expr); 
case RuleExprType::kCharacterClass: return PrintCharacterClass(rule_expr); - case RuleExprType::kNegCharacterClass: - return PrintCharacterClass(rule_expr); + case RuleExprType::kCharacterClassStar: + return PrintCharacterClassStar(rule_expr); case RuleExprType::kEmptyStr: return PrintEmptyStr(rule_expr); case RuleExprType::kRuleRef: @@ -40,8 +46,6 @@ std::string BNFGrammarPrinter::PrintRuleExpr(const RuleExpr& rule_expr) { return PrintSequence(rule_expr); case RuleExprType::kChoices: return PrintChoices(rule_expr); - case RuleExprType::kCharacterClassStar: - return PrintCharacterClassStar(rule_expr); default: LOG(FATAL) << "Unexpected RuleExpr type: " << static_cast(rule_expr.type); } @@ -51,14 +55,29 @@ std::string BNFGrammarPrinter::PrintRuleExpr(int32_t rule_expr_id) { return PrintRuleExpr(grammar_->GetRuleExpr(rule_expr_id)); } +std::string BNFGrammarPrinter::PrintByteString(const RuleExpr& rule_expr) { + std::string internal_str; + internal_str.reserve(rule_expr.data_len); + for (int i = 0; i < rule_expr.data_len; ++i) { + internal_str += static_cast(rule_expr[i]); + } + auto codepoints = ParseUTF8(internal_str.c_str(), UTF8ErrorPolicy::kReturnByte); + std::string result; + for (auto codepoint : codepoints) { + result += PrintAsEscaped(codepoint); + } + return "\"" + result + "\""; +} + std::string BNFGrammarPrinter::PrintCharacterClass(const RuleExpr& rule_expr) { static const std::unordered_map kCustomEscapeMap = {{'-', "\\-"}, {']', "\\]"}}; std::string result = "["; - if (rule_expr.type == RuleExprType::kNegCharacterClass) { + bool is_negative = static_cast(rule_expr[0]); + if (is_negative) { result += "^"; } - for (auto i = 0; i < rule_expr.data_len; i += 2) { + for (auto i = 1; i < rule_expr.data_len; i += 2) { result += PrintAsEscaped(rule_expr[i], kCustomEscapeMap); if (rule_expr[i] == rule_expr[i + 1]) { continue; @@ -70,6 +89,10 @@ std::string BNFGrammarPrinter::PrintCharacterClass(const RuleExpr& rule_expr) { return result; } +std::string 
BNFGrammarPrinter::PrintCharacterClassStar(const RuleExpr& rule_expr) { + return PrintCharacterClass(rule_expr) + "*"; +} + std::string BNFGrammarPrinter::PrintEmptyStr(const RuleExpr& rule_expr) { return "\"\""; } std::string BNFGrammarPrinter::PrintRuleRef(const RuleExpr& rule_expr) { @@ -103,10 +126,6 @@ std::string BNFGrammarPrinter::PrintChoices(const RuleExpr& rule_expr) { return result; } -std::string BNFGrammarPrinter::PrintCharacterClassStar(const RuleExpr& rule_expr) { - return PrintRuleExpr(rule_expr[0]) + "*"; -} - std::string BNFGrammarPrinter::ToString() { std::string result; auto num_rules = grammar_->NumRules(); @@ -121,7 +140,7 @@ TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarToString").set_body_typed([](const BNFG }); std::string BNFGrammarJSONSerializer::ToString() { - picojson::object grammar_json; + picojson::object grammar_json_obj; picojson::array rules_json; for (const auto& rule : grammar_->rules_) { @@ -130,20 +149,21 @@ std::string BNFGrammarJSONSerializer::ToString() { rule_json["body_expr_id"] = picojson::value(static_cast(rule.body_expr_id)); rules_json.push_back(picojson::value(rule_json)); } - grammar_json["rules"] = picojson::value(rules_json); + grammar_json_obj["rules"] = picojson::value(rules_json); picojson::array rule_expr_data_json; for (const auto& data : grammar_->rule_expr_data_) { rule_expr_data_json.push_back(picojson::value(static_cast(data))); } - grammar_json["rule_expr_data"] = picojson::value(rule_expr_data_json); + grammar_json_obj["rule_expr_data"] = picojson::value(rule_expr_data_json); picojson::array rule_expr_indptr_json; for (const auto& index_ptr : grammar_->rule_expr_indptr_) { rule_expr_indptr_json.push_back(picojson::value(static_cast(index_ptr))); } - grammar_json["rule_expr_indptr"] = picojson::value(rule_expr_indptr_json); + grammar_json_obj["rule_expr_indptr"] = picojson::value(rule_expr_indptr_json); - return picojson::value(grammar_json).serialize(prettify_); + auto grammar_json = 
picojson::value(grammar_json_obj); + return grammar_json.serialize(prettify_); } TVM_REGISTER_GLOBAL("mlc.serve.BNFGrammarToJSON") diff --git a/cpp/serve/grammar/grammar_serializer.h b/cpp/serve/grammar/grammar_serializer.h index 4ad5c2103b..f0837d9638 100644 --- a/cpp/serve/grammar/grammar_serializer.h +++ b/cpp/serve/grammar/grammar_serializer.h @@ -62,8 +62,12 @@ class BNFGrammarPrinter : public BNFGrammarSerializer { std::string PrintRuleExpr(int32_t rule_expr_id); private: + /*! \brief Print a RuleExpr for byte string. */ + std::string PrintByteString(const RuleExpr& rule_expr); /*! \brief Print a RuleExpr for character class. */ std::string PrintCharacterClass(const RuleExpr& rule_expr); + /*! \brief Print a RuleExpr for a star quantifier of a character class. */ + std::string PrintCharacterClassStar(const RuleExpr& rule_expr); /*! \brief Print a RuleExpr for empty string. */ std::string PrintEmptyStr(const RuleExpr& rule_expr); /*! \brief Print a RuleExpr for rule reference. */ @@ -72,8 +76,6 @@ class BNFGrammarPrinter : public BNFGrammarSerializer { std::string PrintSequence(const RuleExpr& rule_expr); /*! \brief Print a RuleExpr for rule_expr choices. */ std::string PrintChoices(const RuleExpr& rule_expr); - /*! \brief Print a RuleExpr for star quantifier. */ - std::string PrintCharacterClassStar(const RuleExpr& rule_expr); }; /*! 
diff --git a/cpp/serve/grammar/grammar_state_matcher.cc b/cpp/serve/grammar/grammar_state_matcher.cc index 451127e746..e6e68f376f 100644 --- a/cpp/serve/grammar/grammar_state_matcher.cc +++ b/cpp/serve/grammar/grammar_state_matcher.cc @@ -2,6 +2,7 @@ * Copyright (c) 2023 by Contributors * \file serve/grammar/grammar_state_matcher.cc */ +// #define TVM_LOG_DEBUG 1 #include "grammar_state_matcher.h" #include @@ -123,13 +124,15 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public Gramm private: using RuleExpr = BNFGrammarNode::RuleExpr; using RuleExprType = BNFGrammarNode::RuleExprType; + using SaveType = CatagorizedTokens::SaveType; public: GrammarStateMatcherNodeImpl(std::shared_ptr init_ctx, int max_rollback_steps = 0) : GrammarStateMatcherBase(init_ctx->grammar), init_ctx_(init_ctx), - max_rollback_steps_(max_rollback_steps) {} + max_rollback_steps_(max_rollback_steps), + tmp_accepted_bitset_(init_ctx_->vocab_size) {} bool AcceptToken(int32_t token_id) final; @@ -143,8 +146,8 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public Gramm void ResetState() final { stack_tops_history_.Reset(); - token_size_history_.clear(); - InitStackState(); + token_length_history.clear(); + PushInitialState(kInvalidRulePosition, true); } private: @@ -160,14 +163,8 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public Gramm const std::vector& uncertain_tokens_bitset); /*! \brief Set the acceptable next token in next_token_bitmask. */ - void SetTokenBitmask(DLTensor* next_token_bitmask, std::vector& accepted_indices, - std::vector& rejected_indices, bool can_reach_end); - - /*! \brief Check if a token is a stop token. 
*/ - bool IsStopToken(int32_t token_id) const { - return std::find(init_ctx_->stop_token_ids.begin(), init_ctx_->stop_token_ids.end(), - token_id) != init_ctx_->stop_token_ids.end(); - } + void SetTokenBitmask(DLTensor* next_token_bitmask, const DynamicBitset& accepted_bitset, + const std::vector& rejected_indices, bool can_reach_end); /*! * \brief Accept the stop token and terminates the matcher. @@ -180,14 +177,12 @@ class GrammarStateMatcherNodeImpl : public GrammarStateMatcherNode, public Gramm std::shared_ptr init_ctx_; int max_rollback_steps_; - std::deque token_size_history_; + std::deque token_length_history; // Temporary data for FindNextTokenBitmask. They are stored here to avoid repeated allocation. - std::vector tmp_accepted_indices_; + DynamicBitset tmp_accepted_bitset_; std::vector tmp_rejected_indices_; - std::vector tmp_accepted_indices_delta_; std::vector tmp_rejected_indices_delta_; - std::vector tmp_uncertain_tokens_bitset_; }; bool GrammarStateMatcherNodeImpl::AcceptStopToken() { @@ -204,23 +199,31 @@ bool GrammarStateMatcherNodeImpl::AcceptToken(int32_t token_id) { "accept another token id " << token_id; + CHECK(token_id >= 0 && token_id < init_ctx_->vocab_size) + << "Invalid token id " << token_id << " for GrammarStateMatcher"; + // Handle the stop token - if (IsStopToken(token_id)) { + if (std::find(init_ctx_->stop_token_ids.begin(), init_ctx_->stop_token_ids.end(), token_id) != + init_ctx_->stop_token_ids.end()) { return AcceptStopToken(); } - CHECK(init_ctx_->id_to_token_codepoints.count(token_id) > 0) - << "Token id " << token_id << " is not supported in generation"; - const auto& token = init_ctx_->id_to_token_codepoints[token_id].token; - for (auto codepoint : token) { - if (!AcceptCodepoint(codepoint, false)) { + if (init_ctx_->special_token_ids.count(token_id) > 0) { + LOG(FATAL) + << "Token id " << token_id << ": " << init_ctx_->token_table[token_id] + << " is regarded as a special token, and cannot be accepted by the 
GrammarStateMatcher"; + } + + const auto& token = init_ctx_->token_table[token_id]; + for (auto char_value : token) { + if (!AcceptChar(char_value, false)) { return false; } } - token_size_history_.push_back(token.size()); - if (token_size_history_.size() > max_rollback_steps_) { - DiscardEarliestCodepoints(token_size_history_.front()); - token_size_history_.pop_front(); + token_length_history.push_back(token.size()); + if (token_length_history.size() > max_rollback_steps_) { + DiscardEarliestChars(token_length_history.front()); + token_length_history.pop_front(); } return true; } @@ -229,7 +232,7 @@ void GrammarStateMatcherNodeImpl::FindNextTokenBitmask(DLTensor* next_token_bitm CHECK(!IsTerminated()) << "GrammarStateMatcher has terminated after accepting the stop token, but is trying to " "find the next token mask"; - const auto& sorted_token_codepoints = init_ctx_->sorted_token_codepoints; + const auto& sorted_token_table = init_ctx_->sorted_token_table; const auto& catagorized_tokens_for_grammar = init_ctx_->catagorized_tokens_for_grammar; const auto& latest_stack_tops = stack_tops_history_.GetLatest(); @@ -238,113 +241,132 @@ void GrammarStateMatcherNodeImpl::FindNextTokenBitmask(DLTensor* next_token_bitm // The final accepted token set is the union of the accepted token sets of all stacks. // The final rejected token set is the intersection of the rejected token sets of all stacks. - // Note these indices store the indices in sorted_token_codepoints, instead of the token ids. - tmp_accepted_indices_.clear(); + // Note these indices store the indices in sorted_token_table, instead of the token ids. + tmp_accepted_bitset_.Reset(); // {-1} means the universal set, i.e. all tokens initially tmp_rejected_indices_.assign({-1}); + // std::chrono::microseconds time_unc(0); + // std::chrono::microseconds time_idx(0); + int check_cnt = 0; + for (auto top : latest_stack_tops) { - // Step 1. 
Find the current catagorized_tokens auto cur_rule_position = tree_[top]; - auto current_sequence = grammar_->GetRuleExpr(cur_rule_position.sequence_id); - if (cur_rule_position.parent_id == RulePosition::kNoParent && - cur_rule_position.element_id == current_sequence.size()) { + if (tree_.IsEndPosition(cur_rule_position)) { continue; } - const auto& catagorized_tokens = catagorized_tokens_for_grammar.at( - {cur_rule_position.sequence_id, cur_rule_position.element_id}); + const auto& catagorized_tokens = catagorized_tokens_for_grammar.at(cur_rule_position); + + // auto start = std::chrono::high_resolution_clock::now(); // For each stack, we will check every uncertain token and put them into the accepted or // rejected list. - // If the accepted tokens are saved, it means it is likely to be smaller than the rejected - // tokens, so we will just find the accepted tokens, and vice versa. - bool is_find_accept_mode = - catagorized_tokens.not_saved_index != CatagorizedTokens::NotSavedIndex::kAccepted; - - // If uncertain tokens are saved, we will iterate over the uncertain tokens. - // Otherwise, we will iterate over all_tokens - accepted_tokens - rejected_tokens. - bool is_uncertain_saved = - catagorized_tokens.not_saved_index != CatagorizedTokens::NotSavedIndex::kUncertain; // Step 2. Update the accepted tokens in accepted_indices_delta, or the rejected tokens in // rejected_indices_delta. - // Examine only the current one stack - stack_tops_history_.PushHistory({tree_.NewNode(cur_rule_position)}); - - const std::vector* prev_token = nullptr; - int prev_matched_size = 0; + // If the accepted tokens are saved, it means it is likely to be smaller than the rejected + // tokens, so we will just find the accepted tokens, and vice versa. 
- tmp_accepted_indices_delta_.clear(); tmp_rejected_indices_delta_.clear(); - if (!is_uncertain_saved) { - // unc_tokens = all_tokens - accepted_tokens - rejected_tokens - tmp_uncertain_tokens_bitset_.assign(sorted_token_codepoints.size(), true); - for (auto idx : catagorized_tokens.accepted_indices) { - tmp_uncertain_tokens_bitset_[idx] = false; - } - for (auto idx : catagorized_tokens.rejected_indices) { - tmp_uncertain_tokens_bitset_[idx] = false; - } - } + // Examine only the current one stack + stack_tops_history_.PushHistory({tree_.NewNode(cur_rule_position)}); - int iterator_uncertain = -1; + const std::string* prev_token = nullptr; + int prev_matched_size = 0; - while (true) { - // Step 2.1. Find the current token. - auto idx = - GetNextUncertainToken(is_uncertain_saved, &iterator_uncertain, - catagorized_tokens.uncertain_indices, tmp_uncertain_tokens_bitset_); - if (idx == -1) { - break; - } - const auto& cur_token = sorted_token_codepoints[idx].token; + // std::cout << tree_.PrintNode(top) << std::endl; + + // std::cout << "Accepted count: " << catagorized_tokens.accepted_indices.size() + // << ", rejected count: " << catagorized_tokens.rejected_indices.size() + // << ", uncertain count: " << catagorized_tokens.uncertain_indices.size() + // << ", save type: " << static_cast(catagorized_tokens.save_type) << std::endl; + + // if (catagorized_tokens.accepted_indices.size() < 200) { + // std::cout << "Accpeted: "; + // for (int i = 0; i < catagorized_tokens.accepted_indices.size(); ++i) { + // std::cout << "<" + // << PrintAsEscaped( + // sorted_token_table[catagorized_tokens.accepted_indices[i]].second) + // << "> "; + // } + // std::cout << "\n"; + // } + + // if (catagorized_tokens.uncertain_indices.size() > 100) { + // std::cout << "Uncertain: "; + // for (int i = 0; i < catagorized_tokens.uncertain_indices.size(); ++i) { + // std::cout << "<" + // << PrintAsEscaped( + // sorted_token_table[catagorized_tokens.uncertain_indices[i]].second) + // << "> "; + 
// } + // std::cout << "\n"; + // } + + for (auto cur_token_idx : catagorized_tokens.uncertain_indices) { + const auto& cur_token = sorted_token_table[cur_token_idx].second; + bool accepted = true; - // Step 2.2. Find the longest common prefix with the accepted part of the previous token. + // Step 2.1. Find the longest common prefix with the accepted part of the previous token. // We can reuse the previous matched size to avoid unnecessary matching. - int prev_useful_size = 0; if (prev_token) { - prev_useful_size = std::min(prev_matched_size, static_cast(cur_token.size())); - for (int j = 0; j < prev_useful_size; ++j) { - if (cur_token[j] != (*prev_token)[j]) { - prev_useful_size = j; - break; - } + int lcp_len = std::mismatch(cur_token.begin(), cur_token.end(), prev_token->begin(), + prev_token->end()) + .first - + cur_token.begin(); + if (lcp_len > prev_matched_size) { + accepted = false; + } else if (lcp_len < prev_matched_size) { + RollbackChars(prev_matched_size - lcp_len); } - RollbackCodepoints(prev_matched_size - prev_useful_size); + prev_matched_size = std::min(prev_matched_size, lcp_len); } - // Step 2.3. Find if the current token is accepted or rejected. - bool accepted = true; - prev_matched_size = prev_useful_size; - - for (int j = prev_useful_size; j < cur_token.size(); ++j) { - if (!AcceptCodepoint(cur_token[j], false)) { - accepted = false; - break; + // Step 2.2. Find if the current token is accepted or rejected. + if (accepted) { + for (int j = prev_matched_size; j < cur_token.size(); ++j) { + ++check_cnt; + if (!AcceptChar(cur_token[j], false)) { + accepted = false; + break; + } + prev_matched_size = j + 1; } - prev_matched_size = j + 1; } - // Step 2.4. Push the result to the delta list. - if (accepted && is_find_accept_mode) { - tmp_accepted_indices_delta_.push_back(idx); - } else if (!accepted && !is_find_accept_mode) { - tmp_rejected_indices_delta_.push_back(idx); + // Step 2.3. Push the result to the delta list. 
+ if (catagorized_tokens.save_type == SaveType::kAcceptedBitset || + catagorized_tokens.save_type == SaveType::kAccepted) { + if (accepted) { + tmp_accepted_bitset_.Set(sorted_token_table[cur_token_idx].first, true); + } + } else { + if (!accepted) { + tmp_rejected_indices_delta_.push_back(cur_token_idx); + } } prev_token = &cur_token; } - RollbackCodepoints(prev_matched_size + 1); + RollbackChars(prev_matched_size + 1); + + // auto end = std::chrono::high_resolution_clock::now(); + + // time_unc += std::chrono::duration_cast(end - start); + + // start = std::chrono::high_resolution_clock::now(); // Step 3. Update the accepted_indices and rejected_indices - if (is_find_accept_mode) { - // accepted_indices += catagorized_tokens.accepted_indices + accepted_indices_delta - IntsetUnion(&tmp_accepted_indices_delta_, catagorized_tokens.accepted_indices); - IntsetUnion(&tmp_accepted_indices_, tmp_accepted_indices_delta_); + if (catagorized_tokens.save_type == SaveType::kAcceptedBitset) { + tmp_accepted_bitset_ |= catagorized_tokens.accepted_bitset; + } else if (catagorized_tokens.save_type == SaveType::kAccepted) { + for (auto idx : catagorized_tokens.accepted_indices) { + tmp_accepted_bitset_.Set(sorted_token_table[idx].first, true); + } } else { // rejected_indices = Intersect( // rejected_indices, @@ -352,72 +374,81 @@ void GrammarStateMatcherNodeImpl::FindNextTokenBitmask(DLTensor* next_token_bitm IntsetUnion(&tmp_rejected_indices_delta_, catagorized_tokens.rejected_indices); IntsetIntersection(&tmp_rejected_indices_, tmp_rejected_indices_delta_); } + // end = std::chrono::high_resolution_clock::now(); + // time_idx += std::chrono::duration_cast(end - start); } // Finally update the rejected_ids bitset + // auto start = std::chrono::high_resolution_clock::now(); bool can_reach_end = CanReachEnd(); - SetTokenBitmask(next_token_bitmask, tmp_accepted_indices_, tmp_rejected_indices_, can_reach_end); + SetTokenBitmask(next_token_bitmask, tmp_accepted_bitset_, 
tmp_rejected_indices_, can_reach_end); + // auto end = std::chrono::high_resolution_clock::now(); + // time_idx += std::chrono::duration_cast(end - start); + // std::cout << "Time for uncertain: " << time_unc.count() + // << "us, time for index: " << time_idx.count() << "us" << std::endl; + // std::cout << "Check cnt " << check_cnt << std::endl; } void GrammarStateMatcherNodeImpl::Rollback(int num_tokens) { - CHECK(num_tokens <= token_size_history_.size()) + CHECK(num_tokens <= token_length_history.size()) << "Intended to rollback " << num_tokens << " tokens, but only the last " - << token_size_history_.size() << " steps of history are saved"; + << token_length_history.size() << " steps of history are saved"; while (num_tokens > 0) { - int steps = token_size_history_.back(); - RollbackCodepoints(steps); - token_size_history_.pop_back(); + int steps = token_length_history.back(); + RollbackChars(steps); + token_length_history.pop_back(); --num_tokens; } } void GrammarStateMatcherNodeImpl::SetTokenBitmask(DLTensor* next_token_bitmask, - std::vector& accepted_indices, - std::vector& rejected_indices, + const DynamicBitset& accepted_bitset, + const std::vector& rejected_indices, bool can_reach_end) { - // accepted_ids = Union(accepted_indices, all_tokens - rejected_indices) - // rejected_ids = Intersect(all_tokens - accepted_indices, rejected_indices) + // next_token_bitmask = set(all accepted tokens) = + // 1. all_tokens - (rejected_ids / accepted_ids) + // (when rejected_ids != {-1}, i.e. rejected_ids is not the universal set) + // 2. 
accepted_ids + // (otherwise, when rejected_ids is the universal set) CHECK(next_token_bitmask->dtype.code == kDLUInt && next_token_bitmask->dtype.bits == 32 && next_token_bitmask->data && next_token_bitmask->ndim == 1 && next_token_bitmask->shape) << "The provied bitmask's shape or dtype is not valid."; + CHECK(next_token_bitmask->shape[0] >= DynamicBitset::CalculateBufferSize(init_ctx_->vocab_size)) + << "The provided bitmask is not large enough to store the token set. The length should be " + << DynamicBitset::CalculateBufferSize(init_ctx_->vocab_size) << " at least"; - BitsetManager next_token_bitset(reinterpret_cast(next_token_bitmask->data), - next_token_bitmask->shape[0], init_ctx_->vocab_size); + DynamicBitset next_token_bitset(init_ctx_->vocab_size, + reinterpret_cast(next_token_bitmask->data)); + const auto& sorted_token_table = init_ctx_->sorted_token_table; if (rejected_indices.size() == 1 && rejected_indices[0] == -1) { // If rejected_indices is the universal set, the final accepted token set is just // accepted_indices - next_token_bitset.Reset(false); - for (int idx : accepted_indices) { - next_token_bitset.Set(init_ctx_->sorted_token_codepoints[idx].id, true); - } + next_token_bitset = accepted_bitset; if (can_reach_end) { // add end tokens - for (int idx : init_ctx_->stop_token_ids) { - next_token_bitset.Set(idx, true); + for (int id : init_ctx_->stop_token_ids) { + next_token_bitset.Set(id, true); } } } else { // Otherwise, the final rejected token set is (rejected_indices \ accepted_indices) - next_token_bitset.Reset(true); + next_token_bitset.Set(); - auto it_acc = accepted_indices.begin(); for (auto i : rejected_indices) { - while (it_acc != accepted_indices.end() && *it_acc < i) { - ++it_acc; - } - if (it_acc == accepted_indices.end() || *it_acc != i) { - next_token_bitset.Set(init_ctx_->sorted_token_codepoints[i].id, false); + auto id = sorted_token_table[i].first; + if (!accepted_bitset[id]) { + next_token_bitset.Set(id, false); } } - for 
(int idx : init_ctx_->special_token_ids) { - next_token_bitset.Set(idx, false); + for (int id : init_ctx_->special_token_ids) { + next_token_bitset.Set(id, false); } if (!can_reach_end) { - for (int idx : init_ctx_->stop_token_ids) { - next_token_bitset.Set(idx, false); + for (int id : init_ctx_->stop_token_ids) { + next_token_bitset.Set(id, false); } } } @@ -452,16 +483,24 @@ GrammarStateMatcher::GrammarStateMatcher(std::shared_ptr tokenizer, int max_rollback_steps) { + .set_body_typed([](BNFGrammar grammar, Optional tokenizer, int max_rollback_steps, + String token_table_postproc_method) { auto preproc_start = std::chrono::high_resolution_clock::now(); - auto init_ctx = GrammarStateMatcher::CreateInitContext( - grammar, tokenizer ? tokenizer.value()->TokenTable() : std::vector()); + std::shared_ptr init_ctx; + if (tokenizer) { + auto token_table = Tokenizer::PostProcessTokenTable(tokenizer.value()->TokenTable(), + token_table_postproc_method); + init_ctx = GrammarStateMatcher::CreateInitContext(grammar, token_table); + } else { + init_ctx = GrammarStateMatcher::CreateInitContext(grammar, {}); + } + auto preproc_end = std::chrono::high_resolution_clock::now(); - std::cerr << "Preprocess takes " + LOG(INFO) << "GrammarStateMatcher preprocess takes " << std::chrono::duration_cast(preproc_end - preproc_start) .count() - << "us" << std::endl; + << "us"; return GrammarStateMatcher(init_ctx, max_rollback_steps); }); #endif @@ -479,11 +518,11 @@ TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherFromTokenTable") *rv = GrammarStateMatcher(init_ctx, max_rollback_steps); }); -TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherDebugAcceptCodepoint") - .set_body_typed([](GrammarStateMatcher matcher, int32_t codepoint) { +TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherDebugAcceptChar") + .set_body_typed([](GrammarStateMatcher matcher, int32_t codepoint, bool verbose) { auto mutable_node = const_cast(matcher.as()); - return mutable_node->AcceptCodepoint(codepoint); + return 
mutable_node->AcceptChar(codepoint, verbose); }); TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherAcceptToken") @@ -507,32 +546,43 @@ TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherResetState") /*! \brief Check if a matcher can accept the complete string, and then reach the end of the * grammar. Does not change the state of the GrammarStateMatcher. For test purpose. */ -bool MatchCompleteString(GrammarStateMatcher matcher, String str) { +bool MatchCompleteString(GrammarStateMatcher matcher, String str, bool verbose) { auto mutable_node = const_cast(matcher.as()); - auto codepoints = ParseUTF8(str.c_str()); int accepted_cnt = 0; - for (auto codepoint : codepoints) { - if (!mutable_node->AcceptCodepoint(codepoint, false)) { - mutable_node->RollbackCodepoints(accepted_cnt); + for (auto char_value : str.operator std::string()) { + if (!mutable_node->AcceptChar(char_value, verbose)) { + if (verbose) { + LOG(INFO) << "Matching failed after accepting " << accepted_cnt << " characters"; + } + mutable_node->RollbackChars(accepted_cnt); return false; } ++accepted_cnt; } auto accepted = mutable_node->CanReachEnd(); - mutable_node->RollbackCodepoints(accepted_cnt); + if (verbose) { + if (accepted) { + LOG(INFO) << "Matching succeed after accepting " << accepted_cnt << " characters"; + } else { + LOG(INFO) << "Matching failed due to the end state not reached after all " << accepted_cnt + << " characters are accepted"; + } + } + mutable_node->RollbackChars(accepted_cnt); return accepted; } TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherDebugMatchCompleteString") - .set_body_typed([](GrammarStateMatcher matcher, String str) { - return MatchCompleteString(matcher, str); + .set_body_typed([](GrammarStateMatcher matcher, String str, bool verbose) { + return MatchCompleteString(matcher, str, verbose); }); /*! \brief Print the accepted and rejected tokens stored in the bitset. For debug purposes. 
*/ -void PrintAcceptedRejectedTokens( +std::string PrintAcceptedRejectedTokens( const std::shared_ptr& init_ctx, - const BitsetManager& bitset, int threshold = 500) { + const DynamicBitset& bitset, int threshold = 300) { + std::stringstream ss; auto vocab_size = init_ctx->vocab_size; std::vector accepted_ids; std::vector rejected_ids; @@ -544,42 +594,27 @@ void PrintAcceptedRejectedTokens( } } - if (accepted_ids.size() < threshold) { - std::cerr << "Accepted: "; - for (auto id : accepted_ids) { - std::cerr << "<"; - auto token = init_ctx->token_table[id]; - if (token.size() == 1 && (static_cast(token[0]) >= 128 || token[0] == 0)) { - // First cast to unsigned, then cast to int - std::cerr << static_cast(static_cast(token[0])); - } else { - auto codepoints = ParseUTF8(token.c_str()); - for (auto c : codepoints) { - std::cerr << PrintAsEscaped(c); - } - } - std::cerr << "> "; - } - std::cerr << "\n"; + ss << "Accepted: "; + auto end_it = + accepted_ids.size() > threshold ? accepted_ids.begin() + threshold : accepted_ids.end(); + for (auto it = accepted_ids.begin(); it != end_it; ++it) { + ss << "<" << PrintAsEscaped(init_ctx->token_table[*it]) << "> "; + } + if (accepted_ids.size() > threshold) { + ss << "..."; } + ss << "\n"; - if (rejected_ids.size() < threshold) { - std::cerr << "Rejected: "; - for (auto id : rejected_ids) { - std::cerr << "<"; - auto token = init_ctx->token_table[id]; - if (token.size() == 1 && ((unsigned char)token[0] >= 128 || token[0] == 0)) { - std::cerr << (int)(unsigned char)token[0]; - } else { - auto codepoints = ParseUTF8(token.c_str()); - for (auto c : codepoints) { - std::cerr << PrintAsEscaped(c); - } - } - std::cerr << "> "; - } - std::cerr << "\n"; + ss << "Rejected: "; + end_it = rejected_ids.size() > threshold ? 
rejected_ids.begin() + threshold : rejected_ids.end(); + for (auto it = rejected_ids.begin(); it != end_it; ++it) { + ss << "<" << PrintAsEscaped(init_ctx->token_table[*it]) << "> "; + } + if (rejected_ids.size() > threshold) { + ss << "..."; } + ss << "\n"; + return ss.str(); } /*! @@ -591,7 +626,7 @@ void PrintAcceptedRejectedTokens( IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher, bool verbose = false) { auto init_ctx = matcher.as()->init_ctx_; auto vocab_size = init_ctx->vocab_size; - auto bitset_size = BitsetManager::CalculateBufferSize(vocab_size); + auto bitset_size = DynamicBitset::CalculateBufferSize(vocab_size); auto ndarray = NDArray::Empty(ShapeTuple{static_cast(bitset_size)}, DLDataType{kDLUInt, 32, 1}, DLDevice{kDLCPU, 0}); auto dltensor = const_cast(ndarray.operator->()); @@ -605,7 +640,7 @@ IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher, bool verbose = fals end = std::chrono::high_resolution_clock::now(); } - auto bitset = BitsetManager(reinterpret_cast(dltensor->data), bitset_size, vocab_size); + auto bitset = DynamicBitset(vocab_size, reinterpret_cast(dltensor->data)); std::vector rejected_ids; for (int i = 0; i < vocab_size; i++) { if (bitset[i] == 0) { @@ -614,10 +649,10 @@ IntTuple FindNextRejectedTokens(GrammarStateMatcher matcher, bool verbose = fals } if (verbose) { - std::cerr << "FindNextTokenBitmask takes " + LOG(INFO) << "FindNextTokenBitmask takes " << std::chrono::duration_cast(end - start).count() << "us" << ", found accepted: " << vocab_size - rejected_ids.size() - << ", rejected: " << rejected_ids.size() << std::endl; + << ", rejected: " << rejected_ids.size(); } auto ret = IntTuple(rejected_ids); @@ -634,7 +669,7 @@ TVM_REGISTER_GLOBAL("mlc.serve.GrammarStateMatcherFindNextRejectedTokens") NDArray FindNextTokenBitmaskAsNDArray(GrammarStateMatcher matcher) { auto init_ctx = matcher.as()->init_ctx_; auto vocab_size = init_ctx->vocab_size; - auto bitset_size = BitsetManager::CalculateBufferSize(vocab_size); + 
auto bitset_size = DynamicBitset::CalculateBufferSize(vocab_size); auto bitmask = NDArray::Empty(ShapeTuple{static_cast(bitset_size)}, DLDataType{kDLUInt, 32, 1}, DLDevice{kDLCPU, 0}); auto dltensor = const_cast(bitmask.operator->()); diff --git a/cpp/serve/grammar/grammar_state_matcher.h b/cpp/serve/grammar/grammar_state_matcher.h index eceaa75d07..eedf7a1989 100644 --- a/cpp/serve/grammar/grammar_state_matcher.h +++ b/cpp/serve/grammar/grammar_state_matcher.h @@ -130,14 +130,13 @@ class GrammarStateMatcher : public ObjectRef { }; /*! - * \brief Helper class to get the grammar state init context for grammars or schemas. This class - * maintains cache internally, so the same grammar or schema will not be preprocessed multiple - * times. + * \brief A cache to get the grammar state init context for grammar or schema. This class avoids + * redundant preprocessing of the grammar or schema when constructing a GrammarStateInitContext. * \note This class is associated with a token table when constructed. The token table is used to * create every grammar state init context. If multiple toke tables are used to create init * contexts, an instance of this class for each token table should be created. */ -class GrammarInitContextStorageNode : public Object { +class GrammarInitContextCacheNode : public Object { public: /*! \brief Get the init context for pure JSON. */ virtual std::shared_ptr GetInitContextForJSON() = 0; @@ -147,25 +146,25 @@ class GrammarInitContextStorageNode : public Object { const std::string& schema) = 0; /*! \brief Clear the interal cache of init contexts. 
*/ - virtual void ClearCache() = 0; + virtual void Clear() = 0; - static constexpr const char* _type_key = "mlc.serve.GrammarInitContextStorageNode"; + static constexpr const char* _type_key = "mlc.serve.GrammarInitContextCacheNode"; static constexpr const bool _type_has_method_sequal_reduce = false; static constexpr const bool _type_has_method_shash_reduce = false; - TVM_DECLARE_BASE_OBJECT_INFO(GrammarInitContextStorageNode, Object); + TVM_DECLARE_BASE_OBJECT_INFO(GrammarInitContextCacheNode, Object); }; -class GrammarInitContextStorage : public ObjectRef { +class GrammarInitContextCache : public ObjectRef { public: /*! - * \brief Construct a GrammarInitContextStorage with a token table. This class will always create + * \brief Construct a GrammarInitContextCache with a token table. This class will always create * grammar state init contexts with this token table. * \param token_table The token table that the grammar will use. */ - GrammarInitContextStorage(const std::vector& token_table); + GrammarInitContextCache(const std::vector& token_table); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(GrammarInitContextStorage, ObjectRef, - GrammarInitContextStorageNode); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(GrammarInitContextCache, ObjectRef, + GrammarInitContextCacheNode); }; } // namespace serve diff --git a/cpp/serve/grammar/grammar_state_matcher_base.h b/cpp/serve/grammar/grammar_state_matcher_base.h index 5b774d33a4..1241e7307a 100644 --- a/cpp/serve/grammar/grammar_state_matcher_base.h +++ b/cpp/serve/grammar/grammar_state_matcher_base.h @@ -32,95 +32,172 @@ class GrammarStateMatcherBase { * \param grammar The grammar to match. * \param init_rule_position The initial rule position. If not specified, the main rule will be * used. + * \param expand_init_rule_position Whether to expand the initial rule position to all possible + * locations. See ExpandRulePosition. 
*/ - GrammarStateMatcherBase(const BNFGrammar& grammar, RulePosition init_rule_position = {}) + GrammarStateMatcherBase(const BNFGrammar& grammar, + RulePosition init_rule_position = kInvalidRulePosition, + bool expand_init_rule_position = true) : grammar_(grammar), tree_(grammar), stack_tops_history_(&tree_) { - InitStackState(init_rule_position); + PushInitialState(init_rule_position, expand_init_rule_position); } - /*! \brief Accept one codepoint. */ - bool AcceptCodepoint(TCodepoint codepoint, bool verbose = false); + /*! \brief Accept one character. */ + bool AcceptChar(uint8_t char_value, bool verbose = false); /*! \brief Check if the end of the main rule is reached. If so, the stop token can be accepted. */ bool CanReachEnd() const; - /*! \brief Rollback the matcher to a previous state. */ - void RollbackCodepoints(int rollback_codepoint_cnt); + /*! \brief Rollback the matcher to a previous state by the number of characters. */ + void RollbackChars(int rollback_cnt); - /*! \brief Discard the earliest history. */ - void DiscardEarliestCodepoints(int discard_codepoint_cnt); + /*! \brief Discard the earliest history by the number of characters. */ + void DiscardEarliestChars(int discard_cnt); /*! \brief Print the stack state. */ std::string PrintStackState(int steps_behind_latest = 0) const; protected: - // Init the stack state according to the given rule position. - // If init_rule_position is {}, init the stack with the main rule. - void InitStackState(RulePosition init_rule_position = {}); + // Push an initial stack state according to the given rule position. + // If init_rule_position is kInvalidRulePosition, init the stack with the main rule. + void PushInitialState(RulePosition init_rule_position, bool expand_init_rule_position); - // Update the char_class_star_id field of the given rule_position, if it refers to a character - // class star rule. 
- void UpdateCharClassStarId(RulePosition* rule_position) const; + // Check if the character is accepted by the current rule position. + bool CheckIfAccepted(const RulePosition& rule_position, uint8_t char_value) const; /*! * \brief Find the next position in the rule. If the next position is at the end of the rule, - * the result depends on the consider_parent parameter: - * - false: kInvalidRulePosition will be returned. - * - true: the next position of the parent rule will be returned. If the current rule is the root - * rule, the RulePosition will be returned as is to indicate the end of the grammar. + * and consider_parent is true, will iteratively find the next position in the parent rule. * \param rule_position The current position. - * \param consider_parent Whether to consider the parent position if the current position is at - * the end of the rule. + * \param consider_parent Whether to consider the parent position if the current position is + * at the end of the rule. + * \returns (success, next_rule_position), indicating if the iteration is successful and the + * next rule position. */ - RulePosition IterateToNextPosition(const RulePosition& rule_position, bool consider_parent) const; + std::pair GetNextPositionInSequence(const RulePosition& rule_position, + bool consider_parent) const; + + // Return the updated rule position after accepting the char + RulePosition UpdatePositionWithChar(const RulePosition& rule_position, uint8_t char_value) const; /*! - * \brief Expand the given rule position (may be a RuleRef element) s.t. every new position is a - * CharacterClass or refers to a CharacterClassStar rule. Push all new positions into - * new_stack_tops. - * \details This method will start from cur_rule_position and continuously iterate to the next - * position as long as the current position can be empty (e.g. the current position is a - * reference to an rule that can be empty, or to a character class star rule). 
If the current - * position can not be empty, stop expanding. All positions collected will be pushed into - * new_stack_tops. + * \brief Expand the given rule position to all possible positions approachable in the grammar. + * The expanded positions must refers to an element (CharacterClass or CharacterClassStar or + * ByteString) in a rule. Push all new positions into new_stack_tops. + * \example + * A ::= "a" B [a-z]* "c" + * B ::= "b" | "" * - * If the end of the current rule is reached: - * - If is_outmost_level is true, we can go to the next position in the parent rule. - * - Otherwise, stop iteration. + * Input position: (rule=A, position=B) + * Approachable positions: (rule=B, position="b"), (rule=A, position=[a-z]*), + * (rule=A, position="c"), since B and [a-z]* can be empty. * \param cur_rule_position The current rule position. * \param new_stack_tops The vector to store the new stack tops. - * \param is_outmost_level Whether the current position is the outmost level of the rule. - * \param first_id_if_inserted Being not -1 means the first node is already inserted. This is the - * id of the first node. This is used to avoid inserting the same node twice. - * \return Whether the end of the rule can be reached. Used as the condition of recursion. + * \param consider_parent Whether consider expanding the elements in the parent rule. Useful for + * inner recursion. + * \param first_id_if_inserted An optimization. When cur_rule_position is already inserted to + * the state tree, pass its id to avoid inserting it again. -1 (ignore it) by default. + * \return Whether the end of the rule can be reached. Useful for inner recursion. */ bool ExpandRulePosition(RulePosition cur_rule_position, std::vector* new_stack_tops, - bool is_outmost_level, int32_t first_id_if_inserted = -1); + bool consider_parent = true, int32_t first_id_if_inserted = -1); + // The matched grammar. 
BNFGrammar grammar_; + // The tree storing all states RulePositionTree tree_; + // The tracked history of stack tops (each stack top refers to a node in the tree). + // We store the stack tops in different steps in the history to support rollback. StackTopsHistory stack_tops_history_; - // Temporary data for AcceptCodepoint. + // Temporary data for AcceptChar. std::vector tmp_new_stack_tops_; }; /*! \brief Check the codepoint is contained in the character class. */ -inline bool CharacterClassContains(const BNFGrammarNode::RuleExpr& rule_expr, - TCodepoint codepoint) { - DCHECK(rule_expr.type == BNFGrammarNode::RuleExprType::kCharacterClass || - rule_expr.type == BNFGrammarNode::RuleExprType::kNegCharacterClass); - for (int i = 0; i < rule_expr.size(); i += 2) { - if (rule_expr.data[i] <= codepoint && codepoint <= rule_expr.data[i + 1]) { - return rule_expr.type == BNFGrammarNode::RuleExprType::kCharacterClass; +inline bool GrammarStateMatcherBase::CheckIfAccepted(const RulePosition& rule_position, + uint8_t char_value) const { + auto current_sequence = grammar_->GetRuleExpr(rule_position.sequence_id); + auto current_element = grammar_->GetRuleExpr(current_sequence[rule_position.element_id]); + if (current_element.type == RuleExprType::kCharacterClass || + current_element.type == RuleExprType::kCharacterClassStar) { + if (rule_position.left_utf8_bytes > 0) { + return (char_value & 0xC0) == 0x80; + } + auto [accepted, num_bytes, codepoint] = HandleUTF8FirstByte(char_value); + if (!accepted) { + return false; + } + bool is_negative = static_cast(current_element[0]); + if (num_bytes > 1) { + return is_negative; + } + for (int i = 1; i < current_element.size(); i += 2) { + if (current_element[i] <= char_value && char_value <= current_element[i + 1]) { + return !is_negative; + } + } + return is_negative; + } else if (current_element.type == RuleExprType::kByteString) { + return current_element[rule_position.element_in_string] == char_value; + } else { + LOG(FATAL) << 
"Unexpected RuleExprType in CheckIfAccepted: " + << static_cast(current_element.type); + } +} + +inline RulePosition GrammarStateMatcherBase::UpdatePositionWithChar( + const RulePosition& rule_position, uint8_t char_value) const { + auto current_sequence = grammar_->GetRuleExpr(rule_position.sequence_id); + auto current_element = grammar_->GetRuleExpr(current_sequence[rule_position.element_id]); + RulePosition new_rule_position = rule_position; + switch (current_element.type) { + case RuleExprType::kCharacterClass: { + if (rule_position.left_utf8_bytes > 1) { + new_rule_position.left_utf8_bytes -= 1; + return new_rule_position; + } else if (rule_position.left_utf8_bytes == 1) { + return GetNextPositionInSequence(rule_position, true).second; + } + // If no left utf8 bytes, check the first byte to find the left bytes needed. + DCHECK(rule_position.left_utf8_bytes == 0); + auto [accepted, num_bytes, codepoint] = HandleUTF8FirstByte(char_value); + DCHECK(accepted); + if (num_bytes > 1) { + new_rule_position.left_utf8_bytes = num_bytes - 1; + return new_rule_position; + } + return GetNextPositionInSequence(rule_position, true).second; + } + case RuleExprType::kCharacterClassStar: { + if (rule_position.left_utf8_bytes >= 1) { + new_rule_position.left_utf8_bytes -= 1; + } else { + DCHECK(rule_position.left_utf8_bytes == 0); + auto [accepted, num_bytes, codepoint] = HandleUTF8FirstByte(char_value); + DCHECK(accepted); + new_rule_position.left_utf8_bytes = num_bytes - 1; + } + return new_rule_position; + } + case RuleExprType::kByteString: { + if (rule_position.element_in_string + 1 < current_element.size()) { + new_rule_position.element_in_string += 1; + return new_rule_position; + } + return GetNextPositionInSequence(rule_position, true).second; } + default: + LOG(FATAL) << "Unexpected RuleExprType in UpdatePositionWithChar: " + << static_cast(current_element.type); } - return rule_expr.type == BNFGrammarNode::RuleExprType::kNegCharacterClass; } -inline bool 
GrammarStateMatcherBase::AcceptCodepoint(TCodepoint codepoint, bool verbose) { +inline bool GrammarStateMatcherBase::AcceptChar(uint8_t char_value, bool verbose) { if (verbose) { - std::cout << "Stack before accepting: " << PrintStackState() << std::endl; + LOG(INFO) << "Matching char: " << static_cast(char_value) << " \"" + << PrintAsEscaped(char_value) << "\""; + LOG(INFO) << "Previous stack: " << PrintStackState(); } const auto& prev_stack_tops = stack_tops_history_.GetLatest(); @@ -135,37 +212,31 @@ inline bool GrammarStateMatcherBase::AcceptCodepoint(TCodepoint codepoint, bool continue; } - auto current_char_class = - cur_rule_position.char_class_star_id != -1 - ? grammar_->GetRuleExpr(cur_rule_position.char_class_star_id) - : grammar_->GetRuleExpr(current_sequence[cur_rule_position.element_id]); - DCHECK(current_char_class.type == RuleExprType::kCharacterClass || - current_char_class.type == RuleExprType::kNegCharacterClass); - auto ok = CharacterClassContains(current_char_class, codepoint); - if (!ok) { + auto accepted = CheckIfAccepted(cur_rule_position, char_value); + if (!accepted) { continue; } - if (cur_rule_position.char_class_star_id == -1) { - auto next_rule_position = IterateToNextPosition(cur_rule_position, true); - DCHECK(next_rule_position != kInvalidRulePosition); - ExpandRulePosition(next_rule_position, &tmp_new_stack_tops_, true); + auto new_rule_position = UpdatePositionWithChar(cur_rule_position, char_value); + + if (new_rule_position == cur_rule_position) { + ExpandRulePosition(new_rule_position, &tmp_new_stack_tops_, true, prev_top); } else { - ExpandRulePosition(cur_rule_position, &tmp_new_stack_tops_, true, prev_top); + ExpandRulePosition(new_rule_position, &tmp_new_stack_tops_, true); } } if (tmp_new_stack_tops_.empty()) { if (verbose) { - std::cout << "Codepoint: " << codepoint << " \"" << PrintAsEscaped(codepoint) << "\" Rejected" - << std::endl; + LOG(INFO) << "Character " << static_cast(char_value) << " \"" + << 
PrintAsEscaped(char_value) << "\" Rejected"; } return false; } stack_tops_history_.PushHistory(tmp_new_stack_tops_); if (verbose) { - std::cout << "Codepoint: " << codepoint << " \"" << PrintAsEscaped(codepoint) << "\" Accepted" - << std::endl; - std::cout << "Stack after accepting: " << PrintStackState() << std::endl; + LOG(INFO) << "Character: " << static_cast(char_value) << " \"" + << PrintAsEscaped(char_value) << "\" Accepted"; + LOG(INFO) << "New stack after acceptance: " << PrintStackState(); } #if TVM_LOG_DEBUG stack_tops_history_.CheckWellFormed(); @@ -179,80 +250,92 @@ inline bool GrammarStateMatcherBase::CanReachEnd() const { [&](int32_t id) { return tree_.IsEndPosition(tree_[id]); }); } -inline void GrammarStateMatcherBase::RollbackCodepoints(int rollback_codepoint_cnt) { - stack_tops_history_.Rollback(rollback_codepoint_cnt); +inline void GrammarStateMatcherBase::RollbackChars(int rollback_cnt) { + stack_tops_history_.Rollback(rollback_cnt); } -inline void GrammarStateMatcherBase::DiscardEarliestCodepoints(int discard_codepoint_cnt) { - stack_tops_history_.DiscardEarliest(discard_codepoint_cnt); +inline void GrammarStateMatcherBase::DiscardEarliestChars(int discard_cnt) { + stack_tops_history_.DiscardEarliest(discard_cnt); } inline std::string GrammarStateMatcherBase::PrintStackState(int steps_behind_latest) const { return stack_tops_history_.PrintHistory(steps_behind_latest); } -inline void GrammarStateMatcherBase::InitStackState(RulePosition init_rule_position) { +inline void GrammarStateMatcherBase::PushInitialState(RulePosition init_rule_position, + bool expand_init_rule_position) { if (init_rule_position == kInvalidRulePosition) { // Initialize the stack with the main rule. 
auto main_rule = grammar_->GetMainRule(); auto main_rule_body = grammar_->GetRuleExpr(main_rule.body_expr_id); - std::vector new_stack_tops; + std::vector stack_tops; for (auto i : main_rule_body) { auto init_rule_position = RulePosition(0, i, 0, RulePosition::kNoParent); - UpdateCharClassStarId(&init_rule_position); - ExpandRulePosition(init_rule_position, &new_stack_tops, true); + if (expand_init_rule_position) { + ExpandRulePosition(init_rule_position, &stack_tops, true); + } else { + stack_tops.push_back(tree_.NewNode(init_rule_position)); + } } - stack_tops_history_.PushHistory(new_stack_tops); + stack_tops_history_.PushHistory(stack_tops); } else { - stack_tops_history_.PushHistory({tree_.NewNode(init_rule_position)}); - } -} - -inline void GrammarStateMatcherBase::UpdateCharClassStarId(RulePosition* rule_position) const { - auto rule_expr = grammar_->GetRuleExpr(rule_position->sequence_id); - auto element = grammar_->GetRuleExpr(rule_expr[rule_position->element_id]); - if (element.type == RuleExprType::kRuleRef) { - auto sub_rule_body = grammar_->GetRuleExpr(grammar_->GetRule(element[0]).body_expr_id); - if (sub_rule_body.type == RuleExprType::kCharacterClassStar) { - rule_position->char_class_star_id = sub_rule_body[0]; + if (expand_init_rule_position) { + std::vector stack_tops; + ExpandRulePosition(init_rule_position, &stack_tops, true); + stack_tops_history_.PushHistory(stack_tops); + } else { + stack_tops_history_.PushHistory({tree_.NewNode(init_rule_position)}); } } } -inline RulePosition GrammarStateMatcherBase::IterateToNextPosition( +inline std::pair GrammarStateMatcherBase::GetNextPositionInSequence( const RulePosition& rule_position, bool consider_parent) const { - auto next_position = RulePosition(rule_position.rule_id, rule_position.sequence_id, - rule_position.element_id + 1, rule_position.parent_id); - auto rule_expr = grammar_->GetRuleExpr(rule_position.sequence_id); - auto current_sequence_length = rule_expr.size(); - 
DCHECK(next_position.element_id <= current_sequence_length); - - if (next_position.element_id < current_sequence_length) { - // Update char_class_star_id if the position refers to a character class star rule. - UpdateCharClassStarId(&next_position); - return next_position; + auto sequence = grammar_->GetRuleExpr(rule_position.sequence_id); + + auto next_position = rule_position; + next_position.element_id += 1; + next_position.element_in_string = 0; + next_position.left_utf8_bytes = 0; + + DCHECK(next_position.element_id <= sequence.size()); + + if (next_position.element_id < sequence.size()) { + return {true, next_position}; } if (!consider_parent) { - return kInvalidRulePosition; + return {false, kInvalidRulePosition}; } - if (next_position.parent_id == RulePosition::kNoParent) { - return next_position; - } else { - auto parent_rule_position = tree_[next_position.parent_id]; - return IterateToNextPosition(parent_rule_position, true); + // Find the next position in the parent rule + while (next_position.parent_id != RulePosition::kNoParent) { + next_position = tree_[next_position.parent_id]; + next_position.element_id += 1; + DCHECK(next_position.element_in_string == 0); + DCHECK(next_position.left_utf8_bytes == 0); + + sequence = grammar_->GetRuleExpr(next_position.sequence_id); + DCHECK(next_position.element_id <= sequence.size()); + + if (next_position.element_id < sequence.size()) { + break; + } } + + return {true, next_position}; } inline bool GrammarStateMatcherBase::ExpandRulePosition(RulePosition cur_rule_position, std::vector* new_stack_tops, - bool is_outmost_level, + bool consider_parent, int32_t first_id_if_inserted) { bool is_first = false; + bool is_iteration_successful = true; - for (; cur_rule_position != kInvalidRulePosition; - cur_rule_position = IterateToNextPosition(cur_rule_position, is_outmost_level)) { + for (; is_iteration_successful; + std::tie(is_iteration_successful, cur_rule_position) = + GetNextPositionInSequence(cur_rule_position, 
consider_parent)) { // Insert the node to the tree, if not inserted before. int32_t new_node_id; if (is_first && first_id_if_inserted != -1) { @@ -263,7 +346,7 @@ inline bool GrammarStateMatcherBase::ExpandRulePosition(RulePosition cur_rule_po is_first = false; // Case 1. The current position points to the end of the grammar. - if (is_outmost_level) { + if (consider_parent) { if (tree_.IsEndPosition(cur_rule_position)) { new_stack_tops->push_back(new_node_id); return true; @@ -272,42 +355,39 @@ inline bool GrammarStateMatcherBase::ExpandRulePosition(RulePosition cur_rule_po DCHECK(!tree_.IsEndPosition(cur_rule_position)); } - // Case 2. The current position refers to a character class star rule. It can be empty. - if (cur_rule_position.char_class_star_id != -1) { - new_stack_tops->push_back(new_node_id); - continue; - } - - // Case 3. Character class: cannot be empty. auto sequence = grammar_->GetRuleExpr(cur_rule_position.sequence_id); auto element = grammar_->GetRuleExpr(sequence[cur_rule_position.element_id]); - if (element.type == RuleExprType::kCharacterClass || - element.type == RuleExprType::kNegCharacterClass) { - new_stack_tops->push_back(new_node_id); - return false; - } - - // Case 4. The current position refers to a normal rule, i.e. a rule of choices of sequences. - DCHECK(element.type == RuleExprType::kRuleRef); - auto sub_rule_id = element[0]; - auto sub_rule = grammar_->GetRule(sub_rule_id); - auto sub_rule_body = grammar_->GetRuleExpr(sub_rule.body_expr_id); - DCHECK(sub_rule_body.type == RuleExprType::kChoices); - - bool contain_empty = false; - - for (auto sequence_id : sub_rule_body) { - auto sequence = grammar_->GetRuleExpr(sequence_id); - if (sequence.type == RuleExprType::kEmptyStr) { - contain_empty = true; - continue; + bool can_be_empty = false; + + if (element.type == RuleExprType::kRuleRef) { + // Case 2. The current position refers to another rule. 
+ auto ref_rule = grammar_->GetRule(element[0]); + auto ref_rule_body = grammar_->GetRuleExpr(ref_rule.body_expr_id); + DCHECK(ref_rule_body.type == RuleExprType::kChoices); + + for (auto sequence_id : ref_rule_body) { + auto ref_rule_sequence = grammar_->GetRuleExpr(sequence_id); + if (ref_rule_sequence.type == RuleExprType::kEmptyStr) { + can_be_empty = true; + continue; + } + auto ref_rule_position = RulePosition(element[0], sequence_id, 0, new_node_id); + // Find the positions in every choice of the referred rule + can_be_empty |= ExpandRulePosition(ref_rule_position, new_stack_tops, false); } - auto sub_rule_position = RulePosition(sub_rule_id, sequence_id, 0, new_node_id); - UpdateCharClassStarId(&sub_rule_position); - contain_empty |= ExpandRulePosition(sub_rule_position, new_stack_tops, false); + } else if (element.type == RuleExprType::kCharacterClass || + element.type == RuleExprType::kByteString) { + // Case 3. Character class or byte string. cannot be empty. + new_stack_tops->push_back(new_node_id); + can_be_empty = false; + } else { + DCHECK(element.type == RuleExprType::kCharacterClassStar); + // Case 4. Character class star. Might be empty. + new_stack_tops->push_back(new_node_id); + can_be_empty = cur_rule_position.left_utf8_bytes == 0; } - if (!contain_empty) { + if (!can_be_empty) { return false; } } diff --git a/cpp/serve/grammar/grammar_state_matcher_preproc.h b/cpp/serve/grammar/grammar_state_matcher_preproc.h index f63eee2c5c..dc9fb9646e 100644 --- a/cpp/serve/grammar/grammar_state_matcher_preproc.h +++ b/cpp/serve/grammar/grammar_state_matcher_preproc.h @@ -9,6 +9,7 @@ #include #include "../../support/encoding.h" +#include "../../support/utils.h" #include "grammar.h" #include "grammar_state_matcher_base.h" @@ -18,34 +19,47 @@ namespace serve { using namespace tvm::runtime; -/*! \brief A token and its id. */ -struct TokenAndId { - std::vector token; - int32_t id; - /*! \brief Compare tokens by their unicode codepoint sequence. 
*/ - bool operator<(const TokenAndId& other) const; -}; - /*! - * \brief Preprocessed information, for a given specific rule and position, divides the token set + * \brief Preprocessed information, for a given specific RulePosition, divides the token set * into three categories: accepted, rejected, and uncertain. - * \note Since the union of these three sets is the whole token set, we only need to store the - * smaller two sets. The unsaved set is specified by not_saved_index. - * \note These indices are the indices of sorted_token_codepoints in the GrammarStateInitContext + * Accepted: tokens that can be determined by the current RulePosition to be acceptable + * Rejected: tokens that can be determined by the current RulePosition to be unacceptable + * Uncertain: tokens that need the state of the parent RulePositions to determine if acceptable + * + * \note uncertain indices are stored directly. Accepted / rejected indices have three ways to + * store to reduce memory and computation usage. See SaveType. + * \note These indices are the indices of sorted_token_table in the GrammarStateInitContext * object, instead of the token ids. That helps the matching process. */ struct CatagorizedTokens { + enum class SaveType { + // Only store all accepted token indices. Then rejected indices = all_indices - accepted_indices + // - uncertain_indices. This is useful when |accepted_indices| < |rejected_indices|. + kAccepted = 0, + // Only store all accepted token indices. Then accepted indices = all_indices - rejected_indices + // - uncertain_indices. This is useful when |accepted_indices| > |rejected_indices|. + kRejected = 1, + // Store all accepted token indices in a bitset. This is useful when both |accepted_indices| and + // |rejected_indices| are large. 
+ kAcceptedBitset = 2 + }; + SaveType save_type; + + static constexpr int USE_BITSET_THRESHOLD = 200; + std::vector accepted_indices; std::vector rejected_indices; + DynamicBitset accepted_bitset; + std::vector uncertain_indices; - enum class NotSavedIndex { kAccepted = 0, kRejected = 1, kUncertain = 2 }; - NotSavedIndex not_saved_index; CatagorizedTokens() = default; - CatagorizedTokens(std::vector&& accepted_indices, - std::vector&& rejected_indices, - std::vector&& uncertain_indices); + CatagorizedTokens(int vocab_size, + const std::vector>& sorted_token_table, + const std::vector& accepted_indices, + const std::vector& rejected_indices, + const std::vector& uncertain_indices); }; /*! @@ -57,189 +71,227 @@ class GrammarStateInitContext { public: /******************* Information about the tokenizer *******************/ - /*! \brief The token table. Now only used for debug purpose. */ - std::vector token_table; - /*! \brief The vocabulary size of the tokenizer. */ + /*! \brief The vocabulary size of the tokenizer. Special tokens are included. */ size_t vocab_size; - /*! \brief All tokens represented by the id and codepoints of each. The tokens are sorted by - * codepoint values to reuse the common prefix during matching. */ - std::vector sorted_token_codepoints; - /*! \brief The mapping from token id to token represented by codepoints. Only contains - * non-special and non-stop tokens. */ - std::unordered_map id_to_token_codepoints; - /*! \brief The stop tokens. They can be accepted iff GramamrMatcher can reach the end of the - * grammar. */ + /*! \brief The token table. Special tokens are included. */ + std::vector token_table; + /*! \brief All (id, token) pairs sorted in lexicographic order. This sorting is done to + * maximize prefix reuse during matching. Special tokens and stop tokens are not included. */ + std::vector> sorted_token_table; + /*! \brief The stop tokens. 
When the GrammarStateMatcher can reach the end of the= grammar, + * stop tokens can be accepted. */ std::vector stop_token_ids; - /*! \brief The special tokens. Currently we will ignore these tokens during grammar-guided - * matching. */ - std::vector special_token_ids; + /*! \brief The special tokens. These tokens are ignored (masked out) during the grammar-guided + * generation. */ + std::unordered_set special_token_ids; /******************* Information about the grammar *******************/ + /*! \brief The grammar for the GrammarStateMatcher. */ BNFGrammar grammar; /******************* Grammar-specific tokenizer information *******************/ - /*! \brief A sequence id and its position. */ - struct SequenceIdAndPosition { - int32_t sequence_id; - int32_t element_id; - bool operator==(const SequenceIdAndPosition& other) const { - return sequence_id == other.sequence_id && element_id == other.element_id; + struct RulePositionEqual { + std::size_t operator()(const RulePosition& lhs, const RulePosition& rhs) const noexcept { + return lhs.sequence_id == rhs.sequence_id && lhs.element_id == rhs.element_id && + lhs.left_utf8_bytes == rhs.left_utf8_bytes && + lhs.element_in_string == rhs.element_in_string; } }; - /*! \brief Hash function for SequenceIdAndPosition. */ - struct SequenceIdAndPositionHash { - std::size_t operator()(const SequenceIdAndPosition& k) const { - return std::hash()(k.sequence_id) ^ (std::hash()(k.element_id) << 1); + struct RulePositionHash { + std::size_t operator()(const RulePosition& rule_position) const noexcept { + return HashCombine(rule_position.sequence_id, rule_position.element_id, + rule_position.left_utf8_bytes, rule_position.element_in_string); } }; - /*! \brief Mapping from sequence id and its position to the catagorized tokens. */ - std::unordered_map + /*! \brief Mapping from RulePositions to the catagorized tokens. 
*/ + std::unordered_map catagorized_tokens_for_grammar; }; -/* \brief The concrete implementation of GrammarStateMatcherNode. */ +/*! \brief The concrete implementation of GrammarStateMatcherNode. */ class GrammarStateMatcherForInitContext : public GrammarStateMatcherBase { public: + // Do not expand the initial rule position: we want to find the accepted/rejected tokens + // that exactly start from the initial rule position. GrammarStateMatcherForInitContext(const BNFGrammar& grammar, RulePosition init_rule_position) - : GrammarStateMatcherBase(grammar, init_rule_position) {} - - CatagorizedTokens GetCatagorizedTokens(const std::vector& sorted_token_codepoints, - bool is_main_rule); + : GrammarStateMatcherBase(grammar, init_rule_position, false), + init_rule_id(init_rule_position.rule_id) {} + + /*! + * \brief Get the catagorized tokens for the given RulePosition. + * \param consider_parent_rule Whether to consider the parent rule. If false, there will be + * no uncertain tokens. Useful for the main rule. + */ + CatagorizedTokens GetCatagorizedTokens( + int vocab_size, const std::vector>& sorted_token_table, + bool consider_parent_rule); private: using RuleExpr = BNFGrammarNode::RuleExpr; using RuleExprType = BNFGrammarNode::RuleExprType; + /*! \brief Check if a token can pass the lookahead assertion. */ + bool IsTokenPassLookaheadAssertion(const std::string& token, + const std::vector& can_reach_end_stack); + + // The id of the initial rule. + int32_t init_rule_id; + // Temporary data for GetCatagorizedTokens. 
std::vector tmp_accepted_indices_; std::vector tmp_rejected_indices_; std::vector tmp_uncertain_indices_; - std::vector tmp_can_see_end_stack_; + std::vector tmp_can_reach_end_stack_; + std::vector tmp_can_reach_end_prefix_or_stack_; }; -inline bool TokenAndId::operator<(const TokenAndId& other) const { - for (size_t i = 0; i < token.size(); ++i) { - if (i >= other.token.size()) { - return false; - } - if (token[i] < other.token[i]) { - return true; - } else if (token[i] > other.token[i]) { - return false; +inline CatagorizedTokens::CatagorizedTokens( + int vocab_size, const std::vector>& sorted_token_table, + const std::vector& accepted_indices, const std::vector& rejected_indices, + const std::vector& uncertain_indices) { + auto size_acc = accepted_indices.size(); + auto size_rej = rejected_indices.size(); + + save_type = size_acc >= USE_BITSET_THRESHOLD && size_rej >= USE_BITSET_THRESHOLD + ? SaveType::kAcceptedBitset + : size_acc < size_rej ? SaveType::kAccepted + : SaveType::kRejected; + + if (save_type == SaveType::kAcceptedBitset) { + accepted_bitset = DynamicBitset(vocab_size); + for (auto idx : accepted_indices) { + accepted_bitset.Set(sorted_token_table[idx].first, true); } + } else if (save_type == SaveType::kAccepted) { + this->accepted_indices = accepted_indices; + } else { + this->rejected_indices = rejected_indices; } - return token.size() < other.token.size(); + + this->uncertain_indices = uncertain_indices; } -inline CatagorizedTokens::CatagorizedTokens(std::vector&& accepted_indices, - std::vector&& rejected_indices, - std::vector&& uncertain_indices) { - auto size_acc = accepted_indices.size(); - auto size_rej = rejected_indices.size(); - auto size_unc = uncertain_indices.size(); - not_saved_index = - (size_acc >= size_rej && size_acc >= size_unc) - ? NotSavedIndex::kAccepted - : (size_rej >= size_unc ? 
NotSavedIndex::kRejected : NotSavedIndex::kUncertain); - - if (not_saved_index != NotSavedIndex::kAccepted) { - this->accepted_indices = std::move(accepted_indices); +bool GrammarStateMatcherForInitContext::IsTokenPassLookaheadAssertion( + const std::string& token, const std::vector& can_reach_end_stack) { + auto lookahead_assertion_id = grammar_->GetRule(init_rule_id).lookahead_assertion_id; + if (lookahead_assertion_id == -1) { + return true; } - if (not_saved_index != NotSavedIndex::kRejected) { - this->rejected_indices = std::move(rejected_indices); - } - if (not_saved_index != NotSavedIndex::kUncertain) { - this->uncertain_indices = std::move(uncertain_indices); + auto lookahead_rule_position = RulePosition(-1, lookahead_assertion_id, 0); + PushInitialState(lookahead_rule_position, true); + int token_len = token.size(); + + // Find all positions that can come to and end. Then check if the suffix from that position + // can be accepted by the lookahead assertion. + for (int i = static_cast(can_reach_end_stack.size()); i >= 0; --i) { + if (!can_reach_end_stack[i]) { + continue; + } + int last_accept_pos = i - 1; + for (int pos = i; pos < token_len; ++pos) { + if (!AcceptChar(token[pos])) { + break; + } + last_accept_pos = pos; + // Case 1. The whole rule is finished. + if (CanReachEnd()) { + // accepted chars: pos - i + 1 + // we need to rollback the pushed initial state as well + RollbackChars(pos - i + 2); + return true; + } + } + // Case 2. The whole token is accepted + if (last_accept_pos == token_len - 1) { + RollbackChars(last_accept_pos - i + 2); + return true; + } + // Case 3. The token is not accepted. Check the next position. + RollbackChars(last_accept_pos - i + 1); } + + RollbackChars(1); + return false; } inline CatagorizedTokens GrammarStateMatcherForInitContext::GetCatagorizedTokens( - const std::vector& sorted_token_codepoints, bool is_main_rule) { - // Support the current stack contains only one stack with one RulePosition. 
- // Iterate over all tokens. Split them into three categories: - // - accepted_indices: If a token is accepted by current rule - // - rejected_indices: If a token is rejected by current rule - // - uncertain_indices: If a prefix of a token is accepted by current rule and comes to the end - // of the rule. - - // Note many tokens may contain the same prefix, so we will avoid unnecessary matching - + int vocab_size, const std::vector>& sorted_token_table, + bool consider_parent_rule) { tmp_accepted_indices_.clear(); tmp_rejected_indices_.clear(); tmp_uncertain_indices_.clear(); + // For every character in the current token, stores whether it is possible to reach the end of - // the rule when matching until this character. Useful for rollback. - tmp_can_see_end_stack_.assign({CanReachEnd()}); + // the rule when matching until this character. Store it in a stack for later rollback. + tmp_can_reach_end_stack_.assign({CanReachEnd()}); + tmp_can_reach_end_prefix_or_stack_.assign({tmp_can_reach_end_stack_.back()}); int prev_matched_size = 0; - for (int i = 0; i < static_cast(sorted_token_codepoints.size()); ++i) { - const auto& token = sorted_token_codepoints[i].token; - const auto* prev_token = i > 0 ? &sorted_token_codepoints[i - 1].token : nullptr; - - // Find the longest common prefix with the accepted part of the previous token. - auto prev_useful_size = 0; - if (prev_token) { - prev_useful_size = std::min(prev_matched_size, static_cast(token.size())); - for (int j = 0; j < prev_useful_size; ++j) { - if (token[j] != (*prev_token)[j]) { - prev_useful_size = j; - break; - } - } - RollbackCodepoints(prev_matched_size - prev_useful_size); - tmp_can_see_end_stack_.erase( - tmp_can_see_end_stack_.end() - (prev_matched_size - prev_useful_size), - tmp_can_see_end_stack_.end()); - } + for (int i = 0; i < static_cast(sorted_token_table.size()); ++i) { + const auto& token = sorted_token_table[i].second; - // Find if the current token is accepted or rejected or uncertain. 
bool accepted = true; - bool can_see_end = tmp_can_see_end_stack_.back(); - prev_matched_size = prev_useful_size; - for (int j = prev_useful_size; j < token.size(); ++j) { - if (!AcceptCodepoint(token[j], false)) { + + // Many tokens may contain the same prefix, so we will avoid unnecessary matching + // by finding the longest common prefix with the previous token. + if (i > 0) { + const auto& prev_token = sorted_token_table[i - 1].second; + int lcp_len = + std::mismatch(token.begin(), token.end(), prev_token.begin(), prev_token.end()).first - + token.begin(); + if (lcp_len > prev_matched_size) { + // Case 1. The common prefix is rejected by the matcher in the last token. Reject directly. accepted = false; - break; + } else if (lcp_len < prev_matched_size) { + // Case 2. The common prefix is shorter than the previous matched size. Rollback + // the non-common part. + RollbackChars(prev_matched_size - lcp_len); + tmp_can_reach_end_stack_.erase( + tmp_can_reach_end_stack_.end() - (prev_matched_size - lcp_len), + tmp_can_reach_end_stack_.end()); + tmp_can_reach_end_prefix_or_stack_.erase( + tmp_can_reach_end_prefix_or_stack_.end() - (prev_matched_size - lcp_len), + tmp_can_reach_end_prefix_or_stack_.end()); } - if (CanReachEnd()) { - can_see_end = true; + prev_matched_size = std::min(prev_matched_size, lcp_len); + } + + if (accepted) { + // Accept the rest chars one by one + for (int j = prev_matched_size; j < token.size(); ++j) { + if (!AcceptChar(token[j], false)) { + accepted = false; + break; + } + tmp_can_reach_end_stack_.push_back(CanReachEnd()); + tmp_can_reach_end_prefix_or_stack_.push_back(tmp_can_reach_end_stack_.back() || + tmp_can_reach_end_prefix_or_stack_.back()); + prev_matched_size = j + 1; } - tmp_can_see_end_stack_.push_back(can_see_end); - prev_matched_size = j + 1; } + + bool can_reach_end = tmp_can_reach_end_prefix_or_stack_.back(); + if (accepted) { tmp_accepted_indices_.push_back(i); - } else if (can_see_end && !is_main_rule) { - // If the 
current rule is the main rule, there will be no uncertain indices since we will - // never consider its parent rule. Unaccepted tokens are just rejected. + } else if (can_reach_end && consider_parent_rule && + IsTokenPassLookaheadAssertion(token, tmp_can_reach_end_stack_)) { + // 1. If the current rule is the main rule (consider_parent_rule=false), there are no + // uncertain tokens. Not accepted tokens are just rejected. + // 2. If a token cannot pass the lookahead assertion, it is rejected. tmp_uncertain_indices_.push_back(i); } else { tmp_rejected_indices_.push_back(i); } } - RollbackCodepoints(prev_matched_size); - return CatagorizedTokens(std::move(tmp_accepted_indices_), std::move(tmp_rejected_indices_), - std::move(tmp_uncertain_indices_)); -} - -inline std::string ReplaceUnderscoreWithSpace(const std::string& str, - const std::string& kSpecialUnderscore) { - std::string res; - size_t pos = 0; - while (pos < str.size()) { - size_t found = str.find(kSpecialUnderscore, pos); - if (found == std::string::npos) { - res += str.substr(pos); - break; - } - res += str.substr(pos, found - pos) + " "; - pos = found + kSpecialUnderscore.size(); - } - return res; + // Rollback the last matched part + RollbackChars(prev_matched_size); + return CatagorizedTokens(vocab_size, sorted_token_table, tmp_accepted_indices_, + tmp_rejected_indices_, tmp_uncertain_indices_); } inline std::shared_ptr GrammarStateMatcher::CreateInitContext( @@ -248,87 +300,94 @@ inline std::shared_ptr GrammarStateMatcher::CreateInitC auto ptr = std::make_shared(); ptr->grammar = grammar; - ptr->token_table = token_table; ptr->vocab_size = token_table.size(); + ptr->token_table = token_table; if (ptr->vocab_size == 0) { return ptr; } for (int i = 0; i < token_table.size(); ++i) { - auto token = token_table[i]; - if (token == "" || token == "" || token == "") { - ptr->special_token_ids.push_back(i); - } else if (token == "") { + const auto& token = token_table[i]; + // LLaMA2: + // LLaMA3: 
<|end_of_text|>, <|eot_id|> + // Phi-2: <|endoftext|> + // Gemma: , + if (token == "" || token == "<|end_of_text|>" || token == "<|eot_id|>" || + token == "<|endoftext|>" || token == "" || token == "") { ptr->stop_token_ids.push_back(i); - } else if (token.size() == 1 && - (static_cast(token[0]) >= 128 || token[0] == 0)) { - // Currently we consider all tokens with one character that >= 128 as special tokens, - // and will ignore generating them during grammar-guided generation. - ptr->special_token_ids.push_back(i); + } else if ((token[0] == '<' && token[token.size() - 1] == '>' && token.size() >= 3) || + token == "[@BOS@]") { + // gemma treats [@BOS@] as a special token + ptr->special_token_ids.insert(i); } else { - // First replace the special underscore with space. - auto codepoints = ParseUTF8(token.c_str()); - DCHECK(!codepoints.empty() && - codepoints[0] != static_cast(CharHandlingError::kInvalidUtf8)) - << "Invalid token: " << token; - ptr->sorted_token_codepoints.push_back({codepoints, i}); - ptr->id_to_token_codepoints[i] = {codepoints, i}; + ptr->sorted_token_table.push_back({i, token}); } } - std::sort(ptr->sorted_token_codepoints.begin(), ptr->sorted_token_codepoints.end()); + + auto f_compare_token = [](const std::pair& a, + const std::pair& b) { + return a.second < b.second; + }; + std::sort(ptr->sorted_token_table.begin(), ptr->sorted_token_table.end(), f_compare_token); // Find the corresponding catagorized tokens for: - // 1. All character elements in the grammar - // 2. All RuleRef elements that refers to a rule containing a CharacterClassStar RuleExpr. - for (int i = 0; i < static_cast(grammar->NumRules()); ++i) { - auto rule = grammar->GetRule(i); - auto rule_expr = grammar->GetRuleExpr(rule.body_expr_id); - // Skip CharacterClassStar since we just handle it at the reference element during matching. 
- if (rule_expr.type == RuleExprType::kCharacterClassStar) { - continue; - } - DCHECK(rule_expr.type == RuleExprType::kChoices); - for (auto sequence_id : rule_expr) { - auto sequence_expr = grammar->GetRuleExpr(sequence_id); - if (sequence_expr.type == RuleExprType::kEmptyStr) { + // 1. All character class or character class star (with last_utf8_bytes=0, 1, 2, 3) + // 2. All byte strings (with element_in_string=0, 1, 2, ...) + auto main_rule_id = grammar->GetMainRuleId(); + for (int rule_id = 0; rule_id < static_cast(grammar->NumRules()); ++rule_id) { + auto rule = grammar->GetRule(rule_id); + auto rule_body = grammar->GetRuleExpr(rule.body_expr_id); + DCHECK(rule_body.type == RuleExprType::kChoices); + for (auto sequence_id : rule_body) { + auto sequence = grammar->GetRuleExpr(sequence_id); + if (sequence.type == RuleExprType::kEmptyStr) { continue; } - DCHECK(sequence_expr.type == RuleExprType::kSequence); - for (int element_id = 0; element_id < sequence_expr.size(); ++element_id) { - auto element_expr = grammar->GetRuleExpr(sequence_expr[element_id]); - auto cur_rule_position = RulePosition{i, sequence_id, element_id}; - if (element_expr.type == RuleExprType::kRuleRef) { - auto ref_rule = grammar->GetRule(element_expr[0]); - auto ref_rule_expr = grammar->GetRuleExpr(ref_rule.body_expr_id); - if (ref_rule_expr.type == RuleExprType::kChoices) { - continue; - } else { - // Reference to a CharacterClassStar of a character class. 
- cur_rule_position.char_class_star_id = ref_rule_expr[0]; - } + DCHECK(sequence.type == RuleExprType::kSequence); + for (int element_id = 0; element_id < sequence.size(); ++element_id) { + auto element = grammar->GetRuleExpr(sequence[element_id]); + if (element.type == RuleExprType::kRuleRef) { + continue; } - auto grammar_state_matcher = GrammarStateMatcherForInitContext(grammar, cur_rule_position); - auto cur_catagorized_tokens_for_grammar = - grammar_state_matcher.GetCatagorizedTokens(ptr->sorted_token_codepoints, i == 0); - ptr->catagorized_tokens_for_grammar[{sequence_id, element_id}] = - cur_catagorized_tokens_for_grammar; + auto add_catagorized_tokens = [&](const RulePosition& rule_position) { + auto grammar_state_matcher = GrammarStateMatcherForInitContext(grammar, rule_position); + auto cur_catagorized_tokens_for_grammar = grammar_state_matcher.GetCatagorizedTokens( + ptr->vocab_size, ptr->sorted_token_table, rule_id != main_rule_id); + ptr->catagorized_tokens_for_grammar[rule_position] = cur_catagorized_tokens_for_grammar; + }; + + auto cur_rule_position = RulePosition(rule_id, sequence_id, element_id); + if (element.type == RuleExprType::kByteString) { + for (int idx = 0; idx < element.size(); ++idx) { + cur_rule_position.element_in_string = idx; + add_catagorized_tokens(cur_rule_position); + } + } else { + DCHECK(element.type == RuleExprType::kCharacterClassStar || + element.type == RuleExprType::kCharacterClass); + for (int left_utf8_bytes = 0; left_utf8_bytes <= 3; ++left_utf8_bytes) { + cur_rule_position.left_utf8_bytes = left_utf8_bytes; + add_catagorized_tokens(cur_rule_position); + } + } } } } return ptr; } -class GrammarInitContextStorageImpl : public GrammarInitContextStorageNode { +class GrammarInitContextCacheImpl : public GrammarInitContextCacheNode { public: - GrammarInitContextStorageImpl(const std::vector& token_table); + GrammarInitContextCacheImpl(const std::vector& token_table); - std::shared_ptr GetInitContextForJSONSchema(const 
std::string& schema); + std::shared_ptr GetInitContextForJSONSchema( + const std::string& schema) final; - std::shared_ptr GetInitContextForJSON(); + std::shared_ptr GetInitContextForJSON() final; - void ClearCache(); + void Clear() final; private: /*! \brief The token table associated with this storage class. */ @@ -340,7 +399,7 @@ class GrammarInitContextStorageImpl : public GrammarInitContextStorageNode { std::shared_ptr init_ctx_for_json_; }; -inline GrammarInitContextStorageImpl::GrammarInitContextStorageImpl( +inline GrammarInitContextCacheImpl::GrammarInitContextCacheImpl( const std::vector& token_table) : token_table_(token_table) { init_ctx_for_json_ = @@ -348,7 +407,7 @@ inline GrammarInitContextStorageImpl::GrammarInitContextStorageImpl( } inline std::shared_ptr -GrammarInitContextStorageImpl::GetInitContextForJSONSchema(const std::string& schema) { +GrammarInitContextCacheImpl::GetInitContextForJSONSchema(const std::string& schema) { auto it = init_ctx_for_schema_cache_.find(schema); if (it != init_ctx_for_schema_cache_.end()) { return it->second; @@ -360,14 +419,14 @@ GrammarInitContextStorageImpl::GetInitContextForJSONSchema(const std::string& sc } inline std::shared_ptr -GrammarInitContextStorageImpl::GetInitContextForJSON() { +GrammarInitContextCacheImpl::GetInitContextForJSON() { return init_ctx_for_json_; } -inline void GrammarInitContextStorageImpl::ClearCache() { init_ctx_for_schema_cache_.clear(); } +inline void GrammarInitContextCacheImpl::Clear() { init_ctx_for_schema_cache_.clear(); } -GrammarInitContextStorage::GrammarInitContextStorage(const std::vector& token_table) - : ObjectRef(make_object(token_table)) {} +GrammarInitContextCache::GrammarInitContextCache(const std::vector& token_table) + : ObjectRef(make_object(token_table)) {} } // namespace serve } // namespace llm diff --git a/cpp/serve/grammar/grammar_state_matcher_state.h b/cpp/serve/grammar/grammar_state_matcher_state.h index 47f3e11c7b..1b8a34074f 100644 --- 
a/cpp/serve/grammar/grammar_state_matcher_state.h +++ b/cpp/serve/grammar/grammar_state_matcher_state.h @@ -20,18 +20,20 @@ using namespace tvm::runtime; /*! \brief Specifies a position in a rule. */ struct RulePosition { - /*! \brief The rule's id. */ + /*! \brief The rule's id. Used for debug purposes. */ int32_t rule_id = -1; /*! \brief Which choice in this rule is selected. */ int32_t sequence_id = -1; - /*! \brief Which element of the choice sequence is being visited. */ + /*! \brief Which element of the choice sequence is to be visited. */ int32_t element_id = -1; - /*! - * \brief If the element refers to another rule, and the body of another rule is a - * CharacterClassStar RuleExpr, this field will be set to the id of the character class. - * This is for the special support of CharacterClassStar. - */ - int32_t char_class_star_id = -1; + + /*! \brief The number of left utf8 bytes in the current element. Used when the element is + * a character class or a character class star. */ + int32_t left_utf8_bytes = 0; + /*! \brief The next position to match in the current byte string. Used when the element is + * a byte string. */ + int32_t element_in_string = 0; + /*! \brief The id of the parent node in the RulePositionTree. */ int32_t parent_id = -1; /*! \brief The reference count of this RulePosition. If reduces to zero, the node will be @@ -43,24 +45,21 @@ struct RulePosition { constexpr RulePosition() = default; constexpr RulePosition(int32_t rule_id, int32_t sequence_id, int32_t element_id, - int32_t parent_id = kNoParent, int32_t char_class_star_id = -1) - : rule_id(rule_id), - sequence_id(sequence_id), - element_id(element_id), - char_class_star_id(char_class_star_id), - parent_id(parent_id) {} + int32_t parent_id = kNoParent) + : rule_id(rule_id), sequence_id(sequence_id), element_id(element_id), parent_id(parent_id) {} + + // The position is invalid when sequence_id is -1. 
+ bool IsInvalid() const { return sequence_id == -1; } bool operator==(const RulePosition& other) const { return rule_id == other.rule_id && sequence_id == other.sequence_id && - element_id == other.element_id && char_class_star_id == other.char_class_star_id && - parent_id == other.parent_id; + element_id == other.element_id && parent_id == other.parent_id && + left_utf8_bytes == other.left_utf8_bytes && element_in_string == other.element_in_string; } - - bool operator!=(const RulePosition& other) const { return !(*this == other); } }; /*! \brief A special value for invalid RulePosition. */ -inline constexpr RulePosition kInvalidRulePosition(-1, -1, -1, -1, -1); +inline constexpr RulePosition kInvalidRulePosition(-1, -1, -1, -1); /*! \brief A buffer to manage all RulePositions. */ class RulePositionBuffer { @@ -76,7 +75,7 @@ class RulePositionBuffer { id = buffer_.size() - 1; } else { id = free_nodes_.back(); - DCHECK(buffer_[id] == kInvalidRulePosition); + DCHECK(buffer_[id].IsInvalid()); free_nodes_.pop_back(); } rule_position.reference_count = 0; @@ -86,7 +85,7 @@ class RulePositionBuffer { /*! \brief Free the RulePosition with the given id. */ void Free(int32_t id) { - DCHECK(buffer_[id] != kInvalidRulePosition); + DCHECK(!buffer_[id].IsInvalid()); buffer_[id] = kInvalidRulePosition; free_nodes_.push_back(id); } @@ -102,11 +101,13 @@ class RulePositionBuffer { /*! \brief Get the RulePosition with the given id. 
*/ RulePosition& operator[](int32_t id) { - DCHECK(id < static_cast(buffer_.size()) && buffer_[id] != kInvalidRulePosition); + DCHECK(id >= 0 && id < static_cast(buffer_.size())); + DCHECK(!buffer_[id].IsInvalid()); return buffer_[id]; } const RulePosition& operator[](int32_t id) const { - DCHECK(id < static_cast(buffer_.size()) && buffer_[id] != kInvalidRulePosition); + DCHECK(id >= 0 && id < static_cast(buffer_.size())); + DCHECK(!buffer_[id].IsInvalid()); return buffer_[id]; } @@ -145,7 +146,7 @@ class RulePositionTree { auto id = node_buffer_.Allocate(rule_position); if (rule_position.parent_id != RulePosition::kNoParent) { DCHECK(rule_position.parent_id < static_cast(node_buffer_.Capacity()) && - node_buffer_[rule_position.parent_id] != kInvalidRulePosition); + !node_buffer_[rule_position.parent_id].IsInvalid()); node_buffer_[rule_position.parent_id].reference_count++; } return id; @@ -183,7 +184,7 @@ class RulePositionTree { /*! \brief Get the RulePosition with the given id. */ const RulePosition& operator[](int32_t id) const { DCHECK(id != RulePosition::kNoParent); - DCHECK(node_buffer_[id] != kInvalidRulePosition); + DCHECK(!node_buffer_[id].IsInvalid()); return node_buffer_[id]; } @@ -331,15 +332,26 @@ inline std::string RulePositionTree::PrintNode(int32_t id) const { inline std::string RulePositionTree::PrintNode(const RulePosition& rule_position) const { std::stringstream ss; - ss << "RulePosition: rule " << rule_position.rule_id << ": " - << grammar_->GetRule(rule_position.rule_id).name; + ss << "RulePosition: rule " << rule_position.rule_id; + if (rule_position.rule_id != -1) { + ss << ": " << grammar_->GetRule(rule_position.rule_id).name; + } ss << ", sequence " << rule_position.sequence_id << ": " << BNFGrammarPrinter(grammar_).PrintRuleExpr(rule_position.sequence_id); ss << ", element id: " << rule_position.element_id; - if (rule_position.char_class_star_id != -1) { - ss << ", char class " << rule_position.char_class_star_id << ": " - << 
BNFGrammarPrinter(grammar_).PrintRuleExpr(rule_position.char_class_star_id) << "*"; + + auto sequence = grammar_->GetRuleExpr(rule_position.sequence_id); + if (rule_position.element_id < static_cast(sequence.size())) { + auto element = grammar_->GetRuleExpr(sequence[rule_position.element_id]); + if (element.type == BNFGrammarNode::RuleExprType::kByteString) { + ss << ", element in string: " << rule_position.element_in_string; + } else { + DCHECK(element.type == BNFGrammarNode::RuleExprType::kCharacterClass || + element.type == BNFGrammarNode::RuleExprType::kCharacterClassStar); + ss << ", left utf8 bytes: " << rule_position.left_utf8_bytes; + } } + ss << ", parent id: " << rule_position.parent_id << ", ref count: " << rule_position.reference_count; return ss.str(); @@ -370,7 +382,7 @@ inline void RulePositionTree::CheckWellFormed(const std::vector& outsid std::queue visit_queue; for (auto id : outside_pointers) { CHECK(id >= 0 && id < buffer_size); - CHECK(buffer[id] != kInvalidRulePosition); + CHECK(!buffer[id].IsInvalid()); new_reference_counter[id]++; if (visited[id] == false) { visited[id] = true; @@ -383,7 +395,7 @@ inline void RulePositionTree::CheckWellFormed(const std::vector& outsid const auto& rule_position = buffer[cur_id]; if (rule_position.parent_id != RulePosition::kNoParent) { CHECK(rule_position.parent_id >= 0 && rule_position.parent_id < buffer_size); - CHECK(buffer[rule_position.parent_id] != kInvalidRulePosition); + CHECK(!buffer[rule_position.parent_id].IsInvalid()); new_reference_counter[rule_position.parent_id]++; if (visited[rule_position.parent_id] == false) { visited[rule_position.parent_id] = true; @@ -394,11 +406,11 @@ inline void RulePositionTree::CheckWellFormed(const std::vector& outsid for (int i = 0; i < static_cast(buffer.size()); ++i) { if (free_nodes_set.count(i)) { - CHECK(buffer[i] == kInvalidRulePosition); + CHECK(buffer[i].IsInvalid()); CHECK(visited[i] == false); } else { CHECK(visited[i] == true); - CHECK(buffer[i] != 
kInvalidRulePosition); + CHECK(!buffer[i].IsInvalid()); CHECK(new_reference_counter[i] == buffer[i].reference_count) << "Reference counters unmatch for node #" << i << ": Updated " << new_reference_counter[i] << ", Original " << buffer[i].reference_count; diff --git a/cpp/serve/grammar/json_schema_converter.cc b/cpp/serve/grammar/json_schema_converter.cc index 83be710cf5..e0c465ba9e 100644 --- a/cpp/serve/grammar/json_schema_converter.cc +++ b/cpp/serve/grammar/json_schema_converter.cc @@ -385,9 +385,9 @@ void JSONSchemaToEBNFConverter::AddBasicRules() { void JSONSchemaToEBNFConverter::AddHelperRules() { rules_.push_back(std::make_pair( kBasicEscape, "[\"\\\\/bfnrt] | \"u\" [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9] [A-Fa-f0-9]")); - rules_.push_back(std::make_pair(kBasicStringSub, "\"\" | [^\"\\\\\\r\\n] " + kBasicStringSub + - " | \"\\\\\" " + kBasicEscape + " " + - kBasicStringSub)); + rules_.push_back(std::make_pair( + kBasicStringSub, "(\"\\\"\" | [^\"\\\\\\r\\n] " + kBasicStringSub + " | \"\\\\\" " + + kBasicEscape + " " + kBasicStringSub + ") (= [ \\n\\t]* [,}\\]:])")); } void JSONSchemaToEBNFConverter::CreateBasicRule(const picojson::value& schema, @@ -648,7 +648,7 @@ std::string JSONSchemaToEBNFConverter::VisitString(const picojson::object& schem "pattern", "format", }); - return "[\"] " + kBasicStringSub + " [\"]"; + return "[\"] " + kBasicStringSub; } std::string JSONSchemaToEBNFConverter::VisitBoolean(const picojson::object& schema, diff --git a/cpp/serve/grammar/support.h b/cpp/serve/grammar/support.h index fb9002dbac..c8b3f34344 100644 --- a/cpp/serve/grammar/support.h +++ b/cpp/serve/grammar/support.h @@ -8,30 +8,72 @@ #include +#include #include #include +#include namespace mlc { namespace llm { namespace serve { -/*! \brief Manages a segment of externally provided memory and use it as a bitset. */ -class BitsetManager { +/*! \brief A bitset with runtime specified length. It manages memory internally or the memory + * provided externally with enough size. 
*/ +class DynamicBitset { public: - BitsetManager(uint32_t* data, int buffer_size, int element_cnt) - : data_(data), buffer_size_(buffer_size), element_cnt_(element_cnt) { - DCHECK(buffer_size >= CalculateBufferSize(element_cnt)); + static int CalculateBufferSize(int element_size) { return (element_size + 31) / 32; } + + DynamicBitset() : size_(0), buffer_size_(0), data_(nullptr), is_internal_(true) {} + + DynamicBitset(int size, uint32_t* data = nullptr) + : size_(size), buffer_size_(CalculateBufferSize(size)) { + if (data == nullptr) { + internal_buffer_.resize(buffer_size_, 0); + data_ = internal_buffer_.data(); + is_internal_ = true; + } else { + data_ = data; + is_internal_ = false; + } } - static int CalculateBufferSize(int element_cnt) { return (element_cnt + 31) / 32; } + DynamicBitset& operator=(const DynamicBitset& other) { + DCHECK(is_internal_ || size_ >= other.size_) << "Expanding bitset size is not allowed when the " + "memory of the bitset is externally managed"; + size_ = other.size_; + buffer_size_ = other.buffer_size_; + if (is_internal_) { + internal_buffer_.reserve(buffer_size_); + data_ = internal_buffer_.data(); + } + if (data_ != other.data_) { + std::memcpy(data_, other.data_, buffer_size_ * sizeof(uint32_t)); + } + return *this; + } + + DynamicBitset& operator=(DynamicBitset&& other) { + size_ = other.size_; + buffer_size_ = other.buffer_size_; + is_internal_ = other.is_internal_; + if (is_internal_) { + internal_buffer_ = std::move(other.internal_buffer_); + data_ = internal_buffer_.data(); + } else { + data_ = other.data_; + } + return *this; + } bool operator[](int index) const { - DCHECK(index >= 0 && index < element_cnt_); + DCHECK(data_ && index >= 0 && index < size_); return (data_[index / 32] >> (index % 32)) & 1; } + int Size() const { return size_; } + void Set(int index, bool value) { - DCHECK(index >= 0 && index < element_cnt_); + DCHECK(data_ && index >= 0 && index < size_); if (value) { data_[index / 32] |= 1 << (index % 32); 
} else { @@ -39,14 +81,30 @@ class BitsetManager { } } - void Reset(bool value) { std::memset(data_, value ? 0xFF : 0, buffer_size_ * sizeof(uint32_t)); } + void Set() { + DCHECK(data_); + std::memset(data_, 0xFF, buffer_size_ * sizeof(uint32_t)); + } + + void Reset() { + DCHECK(data_); + std::memset(data_, 0, buffer_size_ * sizeof(uint32_t)); + } - int GetElementCnt() const { return element_cnt_; } + DynamicBitset& operator|=(const DynamicBitset& other) { + DCHECK(buffer_size_ <= other.buffer_size_); + for (int i = 0; i < buffer_size_; ++i) { + data_[i] |= other.data_[i]; + } + return *this; + } private: - uint32_t* const data_; - const int buffer_size_; - const int element_cnt_; + int size_; + int buffer_size_; + uint32_t* data_; + std::vector internal_buffer_; + bool is_internal_; }; /*! diff --git a/cpp/support/encoding.cc b/cpp/support/encoding.cc index d9420bbbd5..9f33f98a7e 100644 --- a/cpp/support/encoding.cc +++ b/cpp/support/encoding.cc @@ -36,14 +36,15 @@ std::string PrintAsUTF8(TCodepoint codepoint) { return utf8; } -std::string PrintAsEscaped(TCodepoint codepoint, - const std::unordered_map& custom_escape_map) { +std::string PrintAsEscaped( + TCodepoint codepoint, + const std::unordered_map& additional_escape_map) { static const std::unordered_map kCodepointToEscape = { {'\'', "\\\'"}, {'\"', "\\\""}, {'\?', "\\\?"}, {'\\', "\\\\"}, {'\a', "\\a"}, {'\b', "\\b"}, {'\f', "\\f"}, {'\n', "\\n"}, {'\r', "\\r"}, {'\t', "\\t"}, {'\v', "\\v"}, {'\0', "\\0"}, {'\x1B', "\\e"}}; - if (auto it = custom_escape_map.find(codepoint); it != custom_escape_map.end()) { + if (auto it = additional_escape_map.find(codepoint); it != additional_escape_map.end()) { return it->second; } @@ -56,14 +57,24 @@ std::string PrintAsEscaped(TCodepoint codepoint, } // convert codepoint to hex - int width = codepoint <= 0xFFFF ? 4 : 8; + char prefix = codepoint <= 0xFF ? 'x' : codepoint <= 0xFFFF ? 'u' : 'U'; + int width = codepoint <= 0xFF ? 2 : codepoint <= 0xFFFF ? 
4 : 8; std::stringstream ss; ss << std::setfill('0') << std::setw(width) << std::hex << codepoint; auto hex = ss.str(); - return codepoint <= 0xFFFF ? "\\u" + hex : "\\U" + hex; + return std::string("\\") + prefix + hex; } -std::pair ParseNextUTF8(const char* utf8) { +std::string PrintAsEscaped(std::string raw_str) { + std::string res; + auto codepoints = ParseUTF8(raw_str.c_str(), UTF8ErrorPolicy::kReturnByte); + for (auto c : codepoints) { + res += PrintAsEscaped(c); + } + return res; +} + +std::tuple HandleUTF8FirstByte(uint8_t byte) { static const std::array kFirstByteMask = {0x00, 0x7F, 0x1F, 0x0F, 0x07}; // clang-format off static const std::array kUtf8Bytes = { @@ -85,30 +96,44 @@ std::pair ParseNextUTF8(const char* utf8) { 4, 4, 4, 4, 4, 4, 4, 4, -1, -1, -1, -1, -1, -1, -1, -1, }; // clang-format on + auto num_bytes = kUtf8Bytes[static_cast(byte)]; + if (num_bytes == -1) { + return {false, 0, 0}; + } + return {true, num_bytes, byte & kFirstByteMask[num_bytes]}; +} - auto bytes = kUtf8Bytes[static_cast(utf8[0])]; - if (bytes == -1) { - // invalid utf8 - return {static_cast(CharHandlingError::kInvalidUtf8), utf8}; +std::pair ParseNextUTF8(const char* utf8, UTF8ErrorPolicy error_policy) { + auto [accepted, num_bytes, res] = HandleUTF8FirstByte(utf8[0]); + if (accepted) { + for (int i = 1; i < num_bytes; ++i) { + if (utf8[i] == 0 || (static_cast(utf8[i]) & 0xC0) != 0x80) { + // invalid utf8 + accepted = false; + break; + } + res = (res << 6) | (static_cast(utf8[i]) & 0x3F); + } } - TCodepoint res = static_cast(utf8[0]) & kFirstByteMask[bytes]; - for (int i = 1; i < bytes; ++i) { - if (utf8[i] == 0 || (static_cast(utf8[i]) & 0xC0) != 0x80) { - // invalid utf8 - return {static_cast(CharHandlingError::kInvalidUtf8), 0}; + if (!accepted) { + // invalid utf8 + if (error_policy == UTF8ErrorPolicy::kReturnInvalid) { + return {CharHandlingError::kInvalidUTF8, utf8}; + } else { + return {static_cast(utf8[0]), utf8 + 1}; } - res = (res << 6) | (static_cast(utf8[i]) & 
0x3F); } - return {res, utf8 + bytes}; + + return {res, utf8 + num_bytes}; } -std::vector ParseUTF8(const char* utf8) { +std::vector ParseUTF8(const char* utf8, UTF8ErrorPolicy error_policy) { std::vector codepoints; while (*utf8 != 0) { TCodepoint codepoint; - std::tie(codepoint, utf8) = ParseNextUTF8(utf8); - if (codepoint == static_cast(CharHandlingError::kInvalidUtf8)) { + std::tie(codepoint, utf8) = ParseNextUTF8(utf8, error_policy); + if (codepoint == CharHandlingError::kInvalidUTF8) { return {codepoint}; } codepoints.push_back(codepoint); @@ -129,17 +154,17 @@ inline int HexCharToInt(char c) { } std::pair ParseNextUTF8OrEscaped( - const char* utf8, const std::unordered_map& custom_escape_map) { + const char* utf8, const std::unordered_map& additional_escape_map) { static const std::unordered_map kEscapeToCodepoint = { {"\\\'", '\''}, {"\\\"", '\"'}, {"\\\?", '\?'}, {"\\\\", '\\'}, {"\\a", '\a'}, {"\\b", '\b'}, {"\\f", '\f'}, {"\\n", '\n'}, {"\\r", '\r'}, {"\\t", '\t'}, {"\\v", '\v'}, {"\\0", '\0'}, {"\\e", '\x1B'}}; if (utf8[0] != '\\') { - return ParseNextUTF8(utf8); + return ParseNextUTF8(utf8, UTF8ErrorPolicy::kReturnInvalid); } auto escape_sequence = std::string(utf8, 2); - if (auto it = custom_escape_map.find(escape_sequence); it != custom_escape_map.end()) { + if (auto it = additional_escape_map.find(escape_sequence); it != additional_escape_map.end()) { return {it->second, utf8 + 2}; } if (auto it = kEscapeToCodepoint.find(escape_sequence); it != kEscapeToCodepoint.end()) { @@ -159,7 +184,7 @@ std::pair ParseNextUTF8OrEscaped( ++len; } if (len == 0) { - return {static_cast(CharHandlingError::kInvalidEscape), utf8}; + return {CharHandlingError::kInvalidEscape, utf8}; } return {codepoint, utf8 + len + 2}; } else if (utf8[1] == 'u' || utf8[1] == 'U') { @@ -170,13 +195,13 @@ std::pair ParseNextUTF8OrEscaped( for (int i = 0; i < len; ++i) { auto digit = HexCharToInt(utf8[i + 2]); if (digit == -1) { - return {static_cast(CharHandlingError::kInvalidEscape), 
utf8}; + return {CharHandlingError::kInvalidEscape, utf8}; } codepoint = codepoint * 16 + digit; } return {codepoint, utf8 + len + 2}; } else { - return {static_cast(CharHandlingError::kInvalidEscape), utf8}; + return {CharHandlingError::kInvalidEscape, utf8}; } } diff --git a/cpp/support/encoding.h b/cpp/support/encoding.h index 790040e97e..0b18c43b0d 100644 --- a/cpp/support/encoding.h +++ b/cpp/support/encoding.h @@ -17,59 +17,89 @@ namespace llm { using TCodepoint = int32_t; /*! - * \brief Convert a codepoint to a UTF-8 string. + * \brief Handle the utf-8 first byte. + * \returns (is_valid, total_number_of_bytes, initial_codepoint). + */ +std::tuple HandleUTF8FirstByte(uint8_t byte); + +/*! + * \brief Print a codepoint to a UTF-8 string. * \param codepoint The codepoint. * \return The UTF-8 string. */ std::string PrintAsUTF8(TCodepoint codepoint); /*! - * \brief Convert a codepoint to a printable string. If the codepoint is not printable, it will be + * \brief Print a codepoint to an escaped string. If the codepoint is not printable, it will be * escaped. By default the function support escape sequences in C ("\n", "\t", "\u0123"). User can - * specify more escape sequences using custom_escape_map. + * specify more escape sequences using additional_escape_map. * \param codepoint The codepoint. - * \param custom_escape_map A map from codepoint to escape sequence. If the codepoint is in the map, - * it will be escaped using the corresponding escape sequence. e.g. {{'-', "\\-"}}. - * \return The printable string. + * \param additional_escape_map A map from codepoint to escape sequence. If the codepoint is in the + * map, it will be escaped using the corresponding escape sequence. e.g. {{'-', "\\-"}}. + * \return The printable string. */ std::string PrintAsEscaped( TCodepoint codepoint, - const std::unordered_map& custom_escape_map = {}); + const std::unordered_map& additional_escape_map = {}); + +/*!
+ * \brief Print the given string to an escaped string that can be printed. + * \return The escaped string. + */ +std::string PrintAsEscaped(std::string raw_str); /*! * \brief Represents an error when handling characters. Will be returned as a special TCodepoint * value. */ -enum class CharHandlingError : TCodepoint { +enum CharHandlingError : TCodepoint { /*! \brief The UTF-8 string is invalid. */ - kInvalidUtf8 = -10, + kInvalidUTF8 = -10, /*! \brief The escape sequence is invalid. */ kInvalidEscape = -11, }; /*! - * \brief Convert a UTF-8 string to a codepoint. + * \brief The policy for handling an invalid UTF-8 sequence. + */ +enum class UTF8ErrorPolicy { + /*! \brief Return an error codepoint when an error is encountered. */ + kReturnInvalid, + /*! \brief Return the invalid byte as the codepoint and continue parsing from the next byte. */ + kReturnByte, +}; + +/*! + * \brief Parse the first codepoint in a UTF-8 string. * \param utf8 The UTF-8 string. - * \return The codepoint and the number of bytes consumed. If the UTF-8 string is invalid, the - * function returns (CharHandlingError::kInvalidUtf8, 0). + * \return The codepoint and new pointer. If the UTF-8 string is invalid, and the error policy is + * kReturnInvalid, the function returns (CharHandlingError::kInvalidUTF8, input char pointer). */ -std::pair ParseNextUTF8(const char* utf8); +std::pair ParseNextUTF8( + const char* utf8, UTF8ErrorPolicy error_policy = UTF8ErrorPolicy::kReturnInvalid); -std::vector ParseUTF8(const char* utf8); +/*! + * \brief Parse all codepoints in a UTF-8 string. + * \param utf8 The UTF-8 string. + * \return All codepoints. If the UTF-8 string is invalid, and the error policy is + * kReturnInvalid, the function returns {CharHandlingError::kInvalidUTF8}. + */ +std::vector ParseUTF8(const char* utf8, + UTF8ErrorPolicy error_policy = UTF8ErrorPolicy::kReturnInvalid); /*! - * \brief Convert a UTF-8 string or an escape sequence to a codepoint. By default the function - * supports escape sequences in C ("\n", "\t", "\u0123").
User can specify more escape sequences - * using custom_escape_map. + * \brief Parse the first codepoint from a UTF-8 string. Also checks escape sequences and converts + * the escaped char to its original value. * \param utf8 The UTF-8 string or the escape sequence. - * \param custom_escape_map A map from escape sequence to codepoint. If the escape sequence is in - * the map, it will be converted to the corresponding codepoint. e.g. {{"\\-", '-'}}. - * \return The codepoint and the number of bytes consumed. If the UTF-8 string or the escape - * sequence is invalid, the function returns - * (CharHandlingError::kInvalidUtf8 or CharHandlingError::kInvalidEscape, 0). + * \param additional_escape_map A map from escape sequence to codepoint. If the escape sequence is + * in the map, it will be converted to the corresponding codepoint. e.g. {{"\\-", '-'}}. + * \return The codepoint and the new pointer. If the UTF-8 string or the escape sequence is + * invalid, the function returns (CharHandlingError::kInvalidUTF8 or + * CharHandlingError::kInvalidEscape, input char pointer). */ std::pair ParseNextUTF8OrEscaped( - const char* utf8, const std::unordered_map& custom_escape_map = {}); + const char* utf8, + const std::unordered_map& additional_escape_map = {}); } // namespace llm } // namespace mlc diff --git a/cpp/support/utils.h b/cpp/support/utils.h index 6c53e35715..2789654a88 100644 --- a/cpp/support/utils.h +++ b/cpp/support/utils.h @@ -37,5 +37,23 @@ inline bool StartsWith(const std::string& str, const char* prefix) { return prefix[n] == '\0'; } +/*! + * \brief Hash and combine value into seed. + * \ref https://www.boost.org/doc/libs/1_84_0/boost/intrusive/detail/hash_combine.hpp + */ +inline void HashCombineBinary(uint32_t& seed, uint32_t value) { + seed ^= value + 0x9e3779b9 + (seed << 6) + (seed >> 2); +} + +/*! + * \brief Find the hash sum of several uint32_t args. + */ +template +uint32_t HashCombine(Args...
args) { + uint32_t seed = 0; + (..., HashCombineBinary(seed, args)); + return seed; +} + } // namespace llm } // namespace mlc diff --git a/cpp/tokenizers.cc b/cpp/tokenizers.cc index 6fe9217520..cc1c172697 100644 --- a/cpp/tokenizers.cc +++ b/cpp/tokenizers.cc @@ -152,7 +152,8 @@ inline std::string ByteLevelDecoder(const std::string& token) { }; // clang-format on - auto unicode_codepoints = ParseUTF8(token.c_str()); + auto unicode_codepoints = ParseUTF8(token.c_str(), UTF8ErrorPolicy::kReturnInvalid); + ICHECK(unicode_codepoints.size() != 1 || unicode_codepoints[0] != kInvalidUTF8); std::string decoded; for (auto unicode_codepoint : unicode_codepoints) { diff --git a/python/mlc_llm/serve/grammar.py b/python/mlc_llm/serve/grammar.py index cf491884c2..8b5b7d9649 100644 --- a/python/mlc_llm/serve/grammar.py +++ b/python/mlc_llm/serve/grammar.py @@ -1,6 +1,6 @@ """Classes handling the grammar guided generation of MLC LLM serving""" -from typing import List, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import tvm import tvm._ffi @@ -22,19 +22,20 @@ class BNFGrammar(Object): def from_ebnf_string( ebnf_string: str, main_rule: str = "main", - normalize: bool = True, - simplify: bool = True, ) -> "BNFGrammar": - r"""Parse a BNF grammar from a string in BNF/EBNF format. - - This method accepts the EBNF notation from the W3C XML Specification - (https://www.w3.org/TR/xml/#sec-notation), which is a popular standard, with the following - changes: - - Using # as comment mark instead of /**/ - - Using C-style unicode escape sequence \u01AB, \U000001AB, \xAB instead of #x0123 - - Do not support A-B (match A and not match B) yet - - See tests/python/serve/json.ebnf for an example. + r"""Construct a BNF grammar from an EBNF-formatted string. The grammar will be normalized + (simplified) by default. + + EBNF grammar: see https://www.w3.org/TR/xml/#sec-notation. Note: + 1. Use # as the comment mark + 2.
Use C-style unicode escape sequence \u01AB, \U000001AB, \xAB + 3. A-B (match A and not match B) is not supported yet + 4. A lookahead assertion can be added at the end of a rule to speed up matching. E.g. + ``` + main ::= "ab" a [a-z] + a ::= "cd" (=[a-z]) + ``` + The assertion (=[a-z]) means that rule a must be followed by [a-z]. Parameters ---------- @@ -44,28 +45,13 @@ def from_ebnf_string( main_rule : str The name of the main rule. Default: "main". - normalize : bool - Whether to normalize the grammar. Default: true. Only set to false for the purpose of - testing. - - In The normalized form of a BNF grammar, every rule is in the form: - `rule_name ::= ("" | (element1_1 element1_2 ...) | (element2_1 element2_2 ...) | ...)`. - - I.e. a list of choices, each choice is a sequence of elements. Elements can be a - character class or a rule reference. And if the rule can be empty, the first choice - will be an empty string. - - simplify : bool - Whether to simplify the grammar to make matching more efficient. Default: true. Not - implemented yet. - Returns ------- grammar : BNFGrammar The parsed BNF grammar. """ return _ffi_api.BNFGrammarFromEBNFString( # type: ignore # pylint: disable=no-member - ebnf_string, main_rule, normalize, simplify + ebnf_string, main_rule ) def to_string(self) -> str: @@ -167,6 +153,31 @@ def get_grammar_of_json() -> "BNFGrammar": """ return _ffi_api.BNFGrammarGetGrammarOfJSON() # type: ignore # pylint: disable=no-member + @staticmethod + def debug_from_ebnf_string_no_normalize( + ebnf_string: str, + main_rule: str = "main", + ) -> "BNFGrammar": + r"""Construct a BNF grammar from an EBNF-formatted string without normalizing it. + For test purposes. + + Parameters + ---------- + ebnf_string : str + The grammar string. + + main_rule : str + The name of the main rule. Default: "main". + + Returns + ------- + grammar : BNFGrammar + The parsed BNF grammar.
+ """ + return _ffi_api.BNFGrammarDebugFromEBNFStringNoNormalize( # type: ignore # pylint: disable=no-member + ebnf_string, main_rule + ) + @staticmethod def debug_json_schema_to_ebnf( schema: str, @@ -235,6 +246,11 @@ class GrammarStateMatcher(Object): max_rollback_steps : int The maximum number of steps to rollback when backtracking. Default: 0. + + token_table_postproc_method : Literal["byte_fallback", "byte_level"] + A helper parameter for the tokenizer. Only useful when the tokenizer is specified. + The method to postprocess the token table. For LLaMA and LLaMA-2 tokenizer, use + "byte_fallback"; for LLaMA-3 tokenizer, use "byte_level". Default: "byte_fallback". """ def __init__( @@ -242,6 +258,7 @@ def __init__( grammar: BNFGrammar, tokenizer: Union[None, Tokenizer, List[str]] = None, max_rollback_steps: int = 0, + token_table_postproc_method: Literal["byte_fallback", "byte_level"] = "byte_fallback", ): if isinstance(tokenizer, list): self.__init_handle_by_constructor__( @@ -256,6 +273,7 @@ def __init__( grammar, tokenizer, max_rollback_steps, + token_table_postproc_method, ) def accept_token(self, token_id: int) -> bool: @@ -346,7 +364,7 @@ def is_terminated(self) -> bool: """ return _ffi_api.GrammarStateMatcherIsTerminated(self) # type: ignore # pylint: disable=no-member - def debug_accept_char(self, codepoint: int) -> bool: + def debug_accept_char(self, codepoint: int, verbose: bool = False) -> bool: """Accept one unicode codepoint to the current state. For test purposes. Parameters @@ -354,11 +372,11 @@ def debug_accept_char(self, codepoint: int) -> bool: codepoint : int The unicode codepoint of the character to be accepted. 
""" - return _ffi_api.GrammarStateMatcherDebugAcceptCodepoint( # type: ignore # pylint: disable=no-member - self, codepoint + return _ffi_api.GrammarStateMatcherDebugAcceptChar( # type: ignore # pylint: disable=no-member + self, codepoint, verbose ) - def debug_match_complete_string(self, string: str) -> bool: + def debug_match_complete_string(self, string: str, verbose: bool = False) -> bool: """Check if the matcher can accept the complete string, and then reach the end of the grammar. Does not change the state of the GrammarStateMatcher. For test purposes. @@ -367,4 +385,4 @@ def debug_match_complete_string(self, string: str) -> bool: string : str The string to be matched. """ - return _ffi_api.GrammarStateMatcherDebugMatchCompleteString(self, string) # type: ignore # pylint: disable=no-member + return _ffi_api.GrammarStateMatcherDebugMatchCompleteString(self, string, verbose) # type: ignore # pylint: disable=no-member diff --git a/tests/python/serve/test_grammar_parser.py b/tests/python/serve/test_grammar_parser.py index 10eacdf9b9..5e335e15c7 100644 --- a/tests/python/serve/test_grammar_parser.py +++ b/tests/python/serve/test_grammar_parser.py @@ -1,4 +1,5 @@ # pylint: disable=missing-module-docstring,missing-function-docstring +import json import os import pytest @@ -14,11 +15,13 @@ def test_bnf_simple(): c ::= "c" """ expected = """main ::= ((b c)) -b ::= (([b])) -c ::= (([c])) +b ::= (("b")) +c ::= (("c")) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") after = bnf_grammar.to_string() + print(after) + print(expected) assert after == expected @@ -32,11 +35,11 @@ def test_ebnf(): b ::= ((b_1)) c ::= ((c_1)) d ::= ((d_1)) -b_1 ::= ("" | ([a] [b] b_1)) +b_1 ::= ("" | ("ab" b_1)) c_1 ::= (([acep-z] c_1) | ([acep-z])) -d_1 ::= ("" | ([d])) +d_1 ::= ("" | ("d")) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) + bnf_grammar = 
BNFGrammar.from_ebnf_string(before, "main") after = bnf_grammar.to_string() assert after == expected @@ -49,18 +52,33 @@ def test_star_quantifier(): e ::= [e]* [f]* | [g]* """ expected = """main ::= ((b c d)) -b ::= [b]* +b ::= (([b]*)) c ::= ((c_1)) d ::= ((d_1)) -e ::= ((e_star e_star_1) | (e_star_2)) -c_1 ::= ("" | ([b] c_1)) +e ::= (([e]* [f]*) | ([g]*)) +c_1 ::= ("" | ("b" c_1)) d_1 ::= ("" | (d_1_choice d_1)) -e_star ::= [e]* -e_star_1 ::= [f]* -e_star_2 ::= [g]* -d_1_choice ::= (([b] [c] [d]) | ([p] [q])) +d_1_choice ::= (("bcd") | ("pq")) +""" + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") + after = bnf_grammar.to_string() + assert after == expected + + +def test_lookahead_assertion(): + before = """main ::= ((b c d)) +b ::= (("abc" [a-z])) (=("abc")) +c ::= (("a") | ("b")) (=([a-z] "b")) +d ::= (("ac") | ("b" d_choice)) (=("abc")) +d_choice ::= (("e") | ("d")) +""" + expected = """main ::= ((b c d)) +b ::= (("abc" [a-z])) (=("abc")) +c ::= (("a") | ("b")) (=([a-z] "b")) +d ::= (("ac") | ("b" d_choice)) (=("abc")) +d_choice ::= (("e") | ("d")) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") after = bnf_grammar.to_string() assert after == expected @@ -68,14 +86,14 @@ def test_star_quantifier(): def test_char(): before = r"""main ::= [a-z] [A-z] "\u0234" "\U00000345\xff" [-A-Z] [--] [^a] rest rest ::= [a-zA-Z0-9-] [\u0234-\U00000345] [测-试] [\--\]] rest1 -rest1 ::= "\?\"\'测试あc" "👀" "" +rest1 ::= "\?\"\'测试あc" "👀" "" [a-a] [b-b] """ - expected = r"""main ::= (([a-z] [A-z] ([\u0234]) ([\u0345] [\u00ff]) [\-A-Z] [\-\-] [^a] rest)) + expected = r"""main ::= (([a-z] [A-z] "\u0234\u0345\u00ff" [\-A-Z] [\-\-] [^a] rest)) rest ::= (([a-zA-Z0-9\-] [\u0234-\u0345] [\u6d4b-\u8bd5] [\--\]] rest1)) -rest1 ::= ((([\?] 
[\"] [\'] [\u6d4b] [\u8bd5] [\u3042] [c]) ([\U0001f440]) "")) +rest1 ::= (("\?\"\'\u6d4b\u8bd5\u3042c\U0001f440ab")) """ # Disable unwrap_nesting_rules to expose the result before unwrapping. - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", False, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") after = bnf_grammar.to_string() assert after == expected @@ -88,9 +106,9 @@ def test_space(): "f" | "g" """ - expected = """main ::= (([a] [b] [c] [d] [e]) | ([f]) | ([g])) + expected = """main ::= (("abcde") | ("f") | ("g")) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") after = bnf_grammar.to_string() assert after == expected @@ -98,10 +116,10 @@ def test_space(): def test_nest(): before = """main::= "a" ("b" | "c" "d") | (("e" "f")) """ - expected = """main ::= (([a] main_choice) | ([e] [f])) -main_choice ::= (([b]) | ([c] [d])) + expected = """main ::= (("a" main_choice) | ("ef")) +main_choice ::= (("b") | ("cd")) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") after = bnf_grammar.to_string() assert after == expected @@ -115,15 +133,16 @@ def test_flatten(): empty_test ::= "d" | (("" | "" "") "" | "a" "") | ("" ("" | "")) "" "" """ expected = """main ::= ((or_test sequence_test nested_test empty_test)) -or_test ::= ("" | ([a]) | ([b]) | ([d] [e]) | (or_test) | ([^a-z])) -sequence_test ::= (([a] [a] [b] sequence_test_choice [d] [e] sequence_test)) -nested_test ::= (([a] [b] [c] [d]) | ([a]) | ([b]) | ([c]) | (nested_rest)) -nested_rest ::= (([a]) | ([b] [c]) | ([d]) | ([e] [f]) | ([g])) -empty_test ::= ("" | ([d]) | ([a])) -sequence_test_choice ::= (([c]) | ([d])) +or_test ::= ("" | ("a") | ("b") | ("de") | (or_test) | ([^a-z])) +sequence_test ::= (("aab" sequence_test_choice "de" sequence_test)) +nested_test ::= (("abcd") | ("a") | ("b") | ("c") | (nested_rest)) 
+nested_rest ::= (("a") | ("bc") | ("d") | ("ef") | ("g")) +empty_test ::= ("" | ("d") | ("a")) +sequence_test_choice ::= (("c") | ("d")) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") after = bnf_grammar.to_string() + print(after) assert after == expected @@ -135,51 +154,53 @@ def test_json(): before = file.read() expected = r"""main ::= ((element)) -value ::= ((object) | (array) | (string) | (number) | ([t] [r] [u] [e]) | ([f] [a] [l] [s] [e]) | ([n] [u] [l] [l])) -object ::= (([{] ws [}]) | ([{] members [}])) -members ::= ((member) | (member [,] members)) -member ::= ((ws string ws [:] element)) -array ::= (([[] ws [\]]) | ([[] elements [\]])) -elements ::= ((element) | (element [,] elements)) +value ::= ((object) | (array) | (string) | (number) | ("true") | ("false") | ("null")) +object ::= (("{" ws "}") | ("{" members "}")) +members ::= ((member) | (member "," members)) +member ::= ((ws string ws ":" element)) +array ::= (("[" ws "]") | ("[" elements "]")) +elements ::= ((element) | (element "," elements)) element ::= ((ws value ws)) -string ::= (([\"] characters [\"])) +string ::= (("\"" characters "\"")) characters ::= ("" | (character characters)) -character ::= (([^\"\\]) | ([\\] escape)) -escape ::= (([\"]) | ([\\]) | ([/]) | ([b]) | ([f]) | ([n]) | ([r]) | ([t]) | ([u] hex hex hex hex)) +character ::= (([^\"\\]) | ("\\" escape)) +escape ::= (("\"") | ("\\") | ("/") | ("b") | ("f") | ("n") | ("r") | ("t") | ("u" hex hex hex hex)) hex ::= (([A-Fa-f0-9])) number ::= ((integer fraction exponent)) -integer ::= ((digit) | (onenine digits) | ([\-] digit) | ([\-] onenine digits)) +integer ::= ((digit) | (onenine digits) | ("-" digit) | ("-" onenine digits)) digits ::= ((digit) | (digit digits)) digit ::= (([0-9])) onenine ::= (([1-9])) -fraction ::= ("" | ([.] digits)) +fraction ::= ("" | ("." 
digits)) exponent ::= ("" | (exponent_choice exponent_choice_1 digits)) -ws ::= ("" | ([ ] ws) | ([\n] ws) | ([\r] ws) | ([\t] ws)) -exponent_choice ::= (([e]) | ([E])) -exponent_choice_1 ::= ("" | ([+]) | ([\-])) +ws ::= ("" | (" " ws) | ("\n" ws) | ("\r" ws) | ("\t" ws)) +exponent_choice ::= (("e") | ("E")) +exponent_choice_1 ::= ("" | ("+") | ("-")) """ - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") after = bnf_grammar.to_string() + print(after) assert after == expected def test_to_string_roundtrip(): """Checks the printed result can be parsed, and the parsing-printing process is idempotent.""" - before = r"""main ::= (b c) | (b main) -b ::= b_1 d -c ::= c_1 -d ::= d_1 -b_1 ::= ([b] b_1) | "" -c_1 ::= (c_2 c_1) | c_2 -c_2 ::= [acep-z] -d_1 ::= [d] | "" + before = r"""main ::= ((b c) | (b main)) +b ::= ((b_1 d)) +c ::= ((c_1)) +d ::= ((d_1)) +b_1 ::= ("" | ("b" b_1)) +c_1 ::= ((c_2 c_1) | (c_2)) (=("abc" [a-z])) +c_2 ::= (([acep-z])) +d_1 ::= ("" | ("d")) """ - bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, "main", True, False) + bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, "main") output_string_1 = bnf_grammar_1.to_string() - bnf_grammar_2 = BNFGrammar.from_ebnf_string(output_string_1, "main", True, False) + bnf_grammar_2 = BNFGrammar.from_ebnf_string(output_string_1, "main") output_string_2 = bnf_grammar_2.to_string() + assert before == output_string_1 assert output_string_1 == output_string_2 @@ -245,34 +266,50 @@ def test_error(): ): BNFGrammar.from_ebnf_string('a ::= "a"') + with pytest.raises( + TVMError, + match="TVMError: EBNF parse error at line 1, column 21: Unexpected lookahead assertion", + ): + BNFGrammar.from_ebnf_string('main ::= "a" (="a") (="b")') + def test_to_json(): before = """main ::= b c | b main b ::= "bcd" c ::= [a-z] """ - expected = ( - '{"rule_expr_indptr":[0,3,6,10,13,16,20,24,28,32,36,41,44,48,51],"rule_expr_data"' - 
":[3,1,1,3,1,2,4,2,0,1,3,1,1,3,1,0,4,2,3,4,5,2,2,5,0,2,98,98,0,2,99,99,0,2,100,100," - '4,3,7,8,9,5,1,10,0,2,97,122,4,1,12,5,1,13],"rules":[{"body_expr_id":6,"name":"main"},' - '{"body_expr_id":11,"name":"b"},{"body_expr_id":14,"name":"c"}]}' - ) - bnf_grammar = BNFGrammar.from_ebnf_string(before, "main", True, False) - after = bnf_grammar.to_json(False) - assert after == expected + expected_obj = { + "rules": [ + {"body_expr_id": 6, "name": "main"}, + {"body_expr_id": 9, "name": "b"}, + {"body_expr_id": 12, "name": "c"}, + ], + "rule_expr_indptr": [0, 3, 6, 10, 13, 16, 20, 24, 29, 32, 35, 40, 43], + "rule_expr_data": [ + # fmt: off + 4,1,1,4,1,2,5,2,0,1,4,1,1,4,1,0,5,2,3,4,6,2,2,5,0,3,98,99, + 100,5,1,7,6,1,8,1,3,0,97,122,5,1,10,6,1,11 + # fmt: on + ], + } + bnf_grammar = BNFGrammar.from_ebnf_string(before, "main") + print(bnf_grammar) + after_str = bnf_grammar.to_json(False) + after_obj = json.loads(after_str) + assert after_obj == expected_obj def test_to_json_roundtrip(): before = r"""main ::= ((b c) | (b main)) -b ::= ((b_1 d)) +b ::= ((b_1 d [a]*)) c ::= ((c_1)) d ::= ((d_1)) -b_1 ::= ("" | ([b] b_1)) +b_1 ::= ("" | ("b" b_1)) c_1 ::= ((c_2 c_1) | (c_2)) c_2 ::= (([acep-z])) -d_1 ::= ("" | ([d])) +d_1 ::= ("" | ("d")) """ - bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, "main", True, False) + bnf_grammar_1 = BNFGrammar.from_ebnf_string(before, "main") output_json_1 = bnf_grammar_1.to_json(False) bnf_grammar_2 = BNFGrammar.from_json(output_json_1) output_json_2 = bnf_grammar_2.to_json(False) diff --git a/tests/python/serve/test_grammar_state_matcher_custom.py b/tests/python/serve/test_grammar_state_matcher_custom.py index 6fc48705d1..6ad6294d77 100644 --- a/tests/python/serve/test_grammar_state_matcher_custom.py +++ b/tests/python/serve/test_grammar_state_matcher_custom.py @@ -40,6 +40,20 @@ def json_grammar(): return get_json_grammar() +def test_simple(): + grammar_str = """main ::= rule1 rule2 +rule1 ::= (rule2 | rule3) "a" +rule2 ::= "b" +rule3 ::= "c" 
+""" + + grammar = BNFGrammar.from_ebnf_string(grammar_str) + matcher = GrammarStateMatcher(grammar) + assert matcher.debug_match_complete_string("bab") + assert not matcher.debug_match_complete_string("abb") + assert matcher.debug_match_complete_string("cab") + + (json_input_accepted,) = tvm.testing.parameters( ('{"name": "John"}',), ('{ "name" : "John" }',), @@ -241,8 +255,8 @@ def test_json_pressure(json_grammar: BNFGrammar, json_input_pressure): '{"id": 1,"name": "Example"}', [ # fmt: off - 31989, 31912, 299, 299, 299, 31973, 31846, 31846, 31948, 31915, 299, 299, 299, 299, - 299, 31973, 31846, 31846, 292, 292, 292, 292, 292, 292, 292, 292, 31974, 31999 + 31989, 31912, 272, 272, 272, 31973, 31846, 31846, 31948, 31915, 272, 272, 272, 272, + 272, 31973, 31846, 31846, 265, 265, 265, 265, 265, 265, 265, 265, 31974, 31999 # fmt: on ], ), @@ -258,15 +272,15 @@ def test_json_pressure(json_grammar: BNFGrammar, json_input_pressure): }""", [ # fmt: off - 31989, 31912, 31912, 299, 299, 299, 31973, 31846, 31846, 31948, 31915, 31915, 299, 299, - 299, 31973, 31846, 31846, 292, 292, 292, 31974, 31915, 31915, 299, 299, 299, 31973, - 31846, 31846, 31997, 31997, 31998, 31974, 31915, 31915, 299, 299, 31973, 31846, 31846, - 31840, 291, 291, 291, 31969, 31846, 31846, 291, 291, 291, 31969, 31974, 31915, 31915, - 299, 299, 299, 31973, 31846, 31846, 31908, 299, 299, 299, 299, 31973, 31846, 31846, - 31906, 299, 299, 299, 299, 31973, 31846, 31846, 291, 291, 291, 31968, 31970, 31915, - 31915, 299, 299, 299, 299, 31973, 31846, 31846, 31840, 31943, 31846, 31846, 31943, - 31846, 31846, 31943, 31970, 31974, 31915, 31915, 299, 299, 299, 299, 31973, 31846, - 31846, 292, 292, 292, 292, 31974, 31974, 31999 + 31989, 31912, 31912, 272, 272, 272, 31973, 31846, 31846, 31948, 31915, 31915, 272, 272, + 272, 31973, 31846, 31846, 265, 265, 265, 31974, 31915, 31915, 272, 272, 272, 31973, + 31846, 31846, 31997, 31997, 31998, 31974, 31915, 31915, 272, 272, 31973, 31846, 31846, + 31840, 264, 264, 264, 31969, 
31846, 31846, 264, 264, 264, 31969, 31974, 31915, 31915, + 272, 272, 272, 31973, 31846, 31846, 31908, 272, 272, 272, 272, 31973, 31846, 31846, + 31906, 272, 272, 272, 272, 31973, 31846, 31846, 264, 264, 264, 31968, 31970, 31915, + 31915, 272, 272, 272, 272, 31973, 31846, 31846, 31840, 31943, 31846, 31846, 31943, + 31846, 31846, 31943, 31970, 31974, 31915, 31915, 272, 272, 272, 272, 31973, 31846, + 31846, 265, 265, 265, 265, 31974, 31974, 31999 # fmt: on ], ), @@ -395,5 +409,6 @@ class MainModel(BaseModel): if __name__ == "__main__": # Run a benchmark to show the performance before running tests test_find_next_rejected_tokens(get_json_grammar(), '{"id": 1,"name": "Example"}') + test_find_next_rejected_tokens_schema() tvm.testing.main() diff --git a/tests/python/serve/test_grammar_state_matcher_json.py b/tests/python/serve/test_grammar_state_matcher_json.py index fc0f79a041..51737e1435 100644 --- a/tests/python/serve/test_grammar_state_matcher_json.py +++ b/tests/python/serve/test_grammar_state_matcher_json.py @@ -2,7 +2,7 @@ # pylint: disable=redefined-outer-name,unbalanced-tuple-unpacking """This test uses the optimized JSON grammar provided by the grammar library.""" import sys -from typing import List, Optional +from typing import List, Literal, Optional import pytest import tvm @@ -213,19 +213,40 @@ def test_json_pressure(json_grammar: BNFGrammar, json_input_pressure): assert GrammarStateMatcher(json_grammar).debug_match_complete_string(json_input_pressure) -(input_find_rejected_tokens, expected_rejected_sizes) = tvm.testing.parameters( +( + tokenizer_path, + input_find_rejected_tokens, + expected_rejected_sizes, + token_table_postproc_method, +) = tvm.testing.parameters( ( # short test + "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", '{"id": 1,"name": "Example"}', [ # fmt: off - 31989, 31912, 299, 299, 299, 31973, 31846, 31846, 31948, 31915, 299, 299, 299, 299, - 299, 31973, 31846, 31846, 292, 292, 292, 292, 292, 292, 292, 292, 31974, 31999 + 31989, 31912, 272, 272, 
272, 31973, 31846, 31846, 31948, 31915, 272, 272, 272, 272, + 272, 31973, 31846, 31846, 265, 265, 265, 265, 265, 265, 265, 265, 31974, 31999 # fmt: on ], + "byte_fallback", + ), + ( + # short test + "dist/Meta-Llama-3-8B-Instruct-q4f16_1-MLC", + '{"id": 1,"name": "Example哈哈"}', + [ + # fmt: off + 128235, 127497, 5002, 5002, 5002, 127849, 126399, 126399, 126760, 127499, 5002, 5002, + 5002, 5002, 5002, 127849, 126399, 126399, 4952, 4952, 4952, 4952, 4952, 4952, 4952, + 4952, 128066, 128111, 4952, 128066, 128111, 4952, 127873, 128254 + # fmt: on + ], + "byte_level", ), ( # long test + "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", """{ "id": 1, "na": "ex", @@ -236,40 +257,51 @@ def test_json_pressure(json_grammar: BNFGrammar, json_input_pressure): }""", [ # fmt: off - 31989, 31912, 31912, 299, 299, 299, 31973, 31846, 31846, 31948, 31915, 31915, 299, 299, - 299, 31973, 31846, 31846, 292, 292, 292, 31974, 31915, 31915, 299, 299, 299, 31973, - 31846, 31846, 31997, 31997, 31998, 31974, 31915, 31915, 299, 299, 31973, 31846, 31846, - 31840, 291, 291, 291, 31969, 31846, 31846, 291, 291, 291, 31969, 31974, 31915, 31915, - 299, 299, 299, 31973, 31846, 31846, 31908, 299, 299, 299, 299, 31973, 31846, 31846, - 31906, 299, 299, 299, 299, 31973, 31846, 31846, 291, 291, 291, 31968, 31970, 31915, - 31915, 299, 299, 299, 299, 31973, 31846, 31846, 31840, 31943, 31846, 31846, 31943, - 31846, 31846, 31943, 31970, 31974, 31915, 31915, 299, 299, 299, 299, 31973, 31846, - 31846, 292, 292, 292, 292, 31974, 31974, 31999 + 31989, 31912, 31912, 272, 272, 272, 31973, 31846, 31846, 31948, 31915, 31915, 272, 272, + 272, 31973, 31846, 31846, 265, 265, 265, 31974, 31915, 31915, 272, 272, 272, 31973, + 31846, 31846, 31997, 31997, 31998, 31974, 31915, 31915, 272, 272, 31973, 31846, 31846, + 31840, 264, 264, 264, 31969, 31846, 31846, 264, 264, 264, 31969, 31974, 31915, 31915, + 272, 272, 272, 31973, 31846, 31846, 31908, 272, 272, 272, 272, 31973, 31846, 31846, + 31906, 272, 272, 272, 272, 31973, 31846, 31846, 
264, 264, 264, 31968, 31970, 31915, + 31915, 272, 272, 272, 272, 31973, 31846, 31846, 31840, 31943, 31846, 31846, 31943, + 31846, 31846, 31943, 31970, 31974, 31915, 31915, 272, 272, 272, 272, 31973, 31846, + 31846, 265, 265, 265, 265, 31974, 31974, 31999 # fmt: on ], + "byte_fallback", ), ) def test_find_next_rejected_tokens( json_grammar: BNFGrammar, + tokenizer_path: str, input_find_rejected_tokens: str, - expected_rejected_sizes: Optional[List[int]] = None, + expected_rejected_sizes: Optional[List[int]], + token_table_postproc_method: Literal["byte_fallback", "byte_level"], ): - tokenizer_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" tokenizer = Tokenizer(tokenizer_path) - grammar_state_matcher = GrammarStateMatcher(json_grammar, tokenizer) + grammar_state_matcher = GrammarStateMatcher( + json_grammar, tokenizer, token_table_postproc_method=token_table_postproc_method + ) + input_bytes = input_find_rejected_tokens.encode("utf-8") + rejected_sizes = [] - real_sizes = [] - for c in input_find_rejected_tokens: + for i, c in enumerate(input_bytes): rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens(True) - real_sizes.append(len(rejected_token_ids)) - print("Accepting char:", c, file=sys.stderr) - assert grammar_state_matcher.debug_accept_char(ord(c)) + rejected_sizes.append(len(rejected_token_ids)) + if expected_rejected_sizes is not None: + assert rejected_sizes[-1] == expected_rejected_sizes[i], ( + len(rejected_token_ids), + expected_rejected_sizes[i], + ) + print("Accepting char:", c, bytes([c]), file=sys.stderr) + assert grammar_state_matcher.debug_accept_char(c) + rejected_token_ids = grammar_state_matcher.find_next_rejected_tokens(True) - real_sizes.append(len(rejected_token_ids)) + rejected_sizes.append(len(rejected_token_ids)) if expected_rejected_sizes is not None: - assert real_sizes == expected_rejected_sizes + assert rejected_sizes[-1] == expected_rejected_sizes[-1] def test_token_based_operations(json_grammar: BNFGrammar): @@ -305,7 
+337,7 @@ def test_token_based_operations(json_grammar: BNFGrammar): accepted = list(set(range(len(token_table))) - set(rejected)) accepted_tokens = [token_table[i] for i in accepted] result.append(accepted_tokens) - assert id in accepted + assert id in accepted, token_table[id] assert grammar_state_matcher.accept_token(id) rejected = grammar_state_matcher.find_next_rejected_tokens() @@ -407,6 +439,20 @@ def test_termination(json_grammar: BNFGrammar): if __name__ == "__main__": # Run a benchmark to show the performance before running tests - test_find_next_rejected_tokens(BNFGrammar.get_grammar_of_json(), '{"id": 1,"name": "Example"}') + test_find_next_rejected_tokens( + BNFGrammar.get_grammar_of_json(), + "dist/Llama-2-7b-chat-hf-q4f16_1-MLC", + '{"id": 1,"name": "Example"}', + None, + "byte_fallback", + ) + + test_find_next_rejected_tokens( + BNFGrammar.get_grammar_of_json(), + "dist/Meta-Llama-3-8B-Instruct-q4f16_1-MLC", + '{"id": 1,"name": "Example哈哈"}', + None, + "byte_level", + ) tvm.testing.main() diff --git a/tests/python/serve/test_serve_engine_grammar.py b/tests/python/serve/test_serve_engine_grammar.py index 2b3ce29c7f..8bd86a25a1 100644 --- a/tests/python/serve/test_serve_engine_grammar.py +++ b/tests/python/serve/test_serve_engine_grammar.py @@ -13,7 +13,7 @@ prompts_list = [ "Generate a JSON string containing 20 objects:", - "Generate a JSON containing a list:", + "Generate a JSON containing a non-empty list:", "Generate a JSON with 5 elements:", ] model_path = "dist/Llama-2-7b-chat-hf-q4f16_1-MLC" diff --git a/web/emcc/mlc_wasm_runtime.cc b/web/emcc/mlc_wasm_runtime.cc index b9a7f55bfa..6ba914ee9f 100644 --- a/web/emcc/mlc_wasm_runtime.cc +++ b/web/emcc/mlc_wasm_runtime.cc @@ -36,9 +36,9 @@ // Grammar related #include "serve/grammar/grammar.cc" +#include "serve/grammar/grammar_functor.cc" #include "serve/grammar/grammar_parser.cc" #include "serve/grammar/grammar_serializer.cc" -#include "serve/grammar/grammar_simplifier.cc" #include 
"serve/grammar/grammar_state_matcher.cc" #include "serve/grammar/json_schema_converter.cc" #include "support/encoding.cc"