diff --git a/python/sglang/srt/constrained/base_grammar_backend.py b/python/sglang/srt/constrained/base_grammar_backend.py
index dda3fab4f7b3..18cb9ebe4af7 100644
--- a/python/sglang/srt/constrained/base_grammar_backend.py
+++ b/python/sglang/srt/constrained/base_grammar_backend.py
@@ -112,6 +112,9 @@ def jump_and_retokenize(
         """
         raise NotImplementedError()
 
+    def accept_token_length(self) -> int:
+        raise NotImplementedError()
+
 
 INVALID_GRAMMAR_OBJ = BaseGrammarObject()
 
diff --git a/python/sglang/srt/constrained/reasoner_grammar_backend.py b/python/sglang/srt/constrained/reasoner_grammar_backend.py
index 57fd55a3bf98..1614fcadcb74 100644
--- a/python/sglang/srt/constrained/reasoner_grammar_backend.py
+++ b/python/sglang/srt/constrained/reasoner_grammar_backend.py
@@ -25,11 +25,16 @@
 
 
 class ReasonerGrammarObject(BaseGrammarObject):
-    def __init__(self, grammar: BaseGrammarObject, think_end_id):
+    def __init__(
+        self,
+        grammar: BaseGrammarObject,
+        think_end_id,
+        is_in_reasoning=True,
+    ):
         super().__init__()
         self.grammar = grammar
         self.think_end_id = think_end_id
-        self.is_in_reasoning = True
+        self.is_in_reasoning = is_in_reasoning
 
     def accept_token(self, token: int):
         if token == self.think_end_id:
@@ -38,12 +43,27 @@ def accept_token(self, token: int):
         if not self.is_in_reasoning and token != self.think_end_id:
             self.grammar.accept_token(token)
 
+    def rollback(self, k: int):
+        if self.accept_token_length() > 0:
+            return self.grammar.rollback(k)
+        else:
+            # For spec decoding: no grammar tokens accepted yet, so re-enter reasoning.
+            self.is_in_reasoning = True
+
+    def is_terminated(self):
+        return self.grammar.is_terminated()
+
     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
     ) -> torch.Tensor:
         return self.grammar.allocate_vocab_mask(vocab_size, batch_size, device)
 
     def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        # In reasoning, do not mask, otherwise there will be no token to sample.
+        if self.is_in_reasoning:
+            vocab_mask[idx].fill_(-1)
+            return
+
         if not self.is_in_reasoning:
             self.grammar.fill_vocab_mask(vocab_mask, idx)
 
@@ -55,7 +75,11 @@ def apply_vocab_mask(self):
         return self.grammar.apply_vocab_mask
 
     def copy(self) -> BaseGrammarObject:
-        return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)
+        return ReasonerGrammarObject(
+            self.grammar.copy(),
+            self.think_end_id,
+            self.is_in_reasoning,
+        )
 
     @property
     def finished(self):
@@ -78,6 +102,9 @@ def jump_and_retokenize(
             old_output_ids, new_output_ids, next_state
         )
 
+    def accept_token_length(self) -> int:
+        return self.grammar.accept_token_length()
+
 
 class ReasonerGrammarBackend(BaseGrammarBackend):
     def __init__(self, grammar_backend: BaseGrammarBackend, think_end_id):
diff --git a/python/sglang/srt/constrained/xgrammar_backend.py b/python/sglang/srt/constrained/xgrammar_backend.py
index 58ea764d6220..b0c62978acfd 100644
--- a/python/sglang/srt/constrained/xgrammar_backend.py
+++ b/python/sglang/srt/constrained/xgrammar_backend.py
@@ -158,6 +158,9 @@ def jump_and_retokenize(
         for i in range(k, len(new_output_ids)):
             assert self.matcher.accept_token(new_output_ids[i])
 
+    def accept_token_length(self) -> int:
+        return len(self.accepted_tokens)
+
     def __repr__(self):
         return f"XGrammarGrammar({self.key_string=}, {self.accepted_tokens=}, {self.current_token=})"
 