# Patch summary (unified git diff touching three files; the line breaks of the
# original patch have been collapsed, so the hunks appear run together below).
#
# python/sglang/srt/constrained/base_grammar_backend.py
#   - The grammar cache key grows from (key_type, key_string) to
#     (key_type, key_string, may_can_reasoning); the annotations on `cache`,
#     `_init_value_dispatch`, `get_cached_or_future_value` and `set_cache`
#     are widened to Tuple[str, str, bool] to match.
#   - `_init_value_dispatch` now unpacks the third element; the visible hunks
#     do not show the base class using it.
#   - `get_cached_or_future_value` loses its return annotation. It returns a
#     pair: (copy of the cached entry, True) on a cache hit, or
#     (Future submitted to the executor, False) on a miss.
#
# python/sglang/srt/constrained/reasoner_grammar_backend.py
#   - `ReasonerGrammarObject.__init__` takes a new `may_can_reasoning`
#     argument and uses it as the initial `is_in_reasoning` (previously
#     hard-coded to True).
#   - New `rollback(k)` and `is_terminated()` forward to the wrapped grammar.
#   - `copy()` passes the current `self.is_in_reasoning` to the new object.
#   - `_init_value_dispatch` reads the flag from `key[2]` when wrapping.
#
# python/sglang/srt/managers/scheduler.py
#   - `handle_generate_request` computes `may_can_reasoning` and appends it to
#     the key before the cache lookup. The expression
#     `not getattr(...) in getattr(...)` parses as `not (x in y)`: the flag is
#     True when the tokenizer's `think_end_id` (None if the attribute is
#     missing) is not among `req.origin_input_ids` (empty list if missing).
#
# NOTE(review): `ReasonerGrammarBackend._init_value_dispatch` is still
#   annotated `key: Tuple[str, str]` (no hunk touches its signature) although
#   it now reads `key[2]` - consider widening it to Tuple[str, str, bool].
# NOTE(review): `not x in y` would read more clearly as `x not in y`.
# NOTE(review): `rollback` forwards to the wrapped grammar unconditionally,
#   while `accept_token` only forwards once reasoning has ended - confirm that
#   a rollback can never span tokens consumed during the reasoning phase.
# NOTE(review): `copy()` propagates the current `is_in_reasoning`, not the
#   flag the object was created with - presumably cached entries are copied
#   before any token is accepted; verify against the caller of `set_cache`.
diff --git a/python/sglang/srt/constrained/base_grammar_backend.py b/python/sglang/srt/constrained/base_grammar_backend.py index dda3fab4f7b3..9b662756457e 100644 --- a/python/sglang/srt/constrained/base_grammar_backend.py +++ b/python/sglang/srt/constrained/base_grammar_backend.py @@ -125,7 +125,7 @@ class CacheEntry: class BaseGrammarBackend: def __init__(self): self.executor = ThreadPoolExecutor() - self.cache: Dict[Tuple[str, str], CacheEntry] = {} + self.cache: Dict[Tuple[str, str, bool], CacheEntry] = {} def _not_supported(self, key_type: str, key_string: str) -> None: logger.warning(f"Skip unsupported {key_type=}, {key_string=}") @@ -150,9 +150,11 @@ def dispatch_ebnf(self, key_string: str) -> Optional[BaseGrammarObject]: def dispatch_structural_tag(self, key_string: str) -> Optional[BaseGrammarObject]: return self._not_supported("structural_tag", key_string) - def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]: + def _init_value_dispatch( + self, key: Tuple[str, str, bool] + ) -> Optional[BaseGrammarObject]: s = time.perf_counter() - key_type, key_string = key + key_type, key_string, may_can_reasoning = key if key_type == "json": grammar = self.dispatch_json(key_string) elif key_type == "regex": @@ -172,16 +174,14 @@ def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObje grammar.grammar_stats.compilation_time = time.perf_counter() - s return grammar - def get_cached_or_future_value( - self, key: Tuple[str, str] - ) -> Optional[BaseGrammarObject]: + def get_cached_or_future_value(self, key: Tuple[str, str, bool]): value = self.cache.get(key) if value: return value.copy(), True value = self.executor.submit(self._init_value_dispatch, key) return value, False - def set_cache(self, key: Tuple[str, str], value: BaseGrammarObject): + def set_cache(self, key: Tuple[str, str, bool], value: BaseGrammarObject): self.cache[key] = value def reset(self): diff --git 
a/python/sglang/srt/constrained/reasoner_grammar_backend.py b/python/sglang/srt/constrained/reasoner_grammar_backend.py index 57fd55a3bf98..28e17d07fee9 100644 --- a/python/sglang/srt/constrained/reasoner_grammar_backend.py +++ b/python/sglang/srt/constrained/reasoner_grammar_backend.py @@ -25,11 +25,11 @@ class ReasonerGrammarObject(BaseGrammarObject): - def __init__(self, grammar: BaseGrammarObject, think_end_id): + def __init__(self, grammar: BaseGrammarObject, think_end_id, may_can_reasoning): super().__init__() self.grammar = grammar self.think_end_id = think_end_id - self.is_in_reasoning = True + self.is_in_reasoning = may_can_reasoning def accept_token(self, token: int): if token == self.think_end_id: @@ -38,11 +38,17 @@ def accept_token(self, token: int): if not self.is_in_reasoning and token != self.think_end_id: self.grammar.accept_token(token) + def rollback(self, k: int): + self.grammar.rollback(k) + def allocate_vocab_mask( self, vocab_size: int, batch_size: int, device ) -> torch.Tensor: return self.grammar.allocate_vocab_mask(vocab_size, batch_size, device) + def is_terminated(self): + return self.grammar.is_terminated() + def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None: if not self.is_in_reasoning: self.grammar.fill_vocab_mask(vocab_mask, idx) @@ -55,7 +61,9 @@ def apply_vocab_mask(self): return self.grammar.apply_vocab_mask def copy(self) -> BaseGrammarObject: - return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id) + return ReasonerGrammarObject( + self.grammar.copy(), self.think_end_id, self.is_in_reasoning + ) @property def finished(self): @@ -90,4 +98,5 @@ def _init_value_dispatch(self, key: Tuple[str, str]) -> Optional[BaseGrammarObje # avoid wrapping invalid grammar, so that the scheduler can detect it if ret is None or ret is INVALID_GRAMMAR_OBJ: return ret - return ReasonerGrammarObject(ret, self.think_end_id) + may_can_reasoning = key[2] + return ReasonerGrammarObject(ret, self.think_end_id, 
may_can_reasoning) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 2938511112a0..b9fca41af50f 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1318,6 +1318,11 @@ def handle_generate_request( elif req.sampling_params.structural_tag: key = ("structural_tag", req.sampling_params.structural_tag) + may_can_reasoning = not getattr( + self.tokenizer, "think_end_id", None + ) in getattr(req, "origin_input_ids", []) + key = (key[0], key[1], may_can_reasoning) + value, cache_hit = self.grammar_backend.get_cached_or_future_value(key) req.grammar = value