diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
index a131c6a1b92b..8e00f71b6a4b 100644
--- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py
+++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
@@ -18,8 +18,9 @@
 import math
 from collections import defaultdict
 from functools import lru_cache
-from typing import Callable, DefaultDict, Dict, List, Union
+from typing import Callable, DefaultDict, Dict, List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
 from outlines.fsm.json_schema import build_regex_from_schema
@@ -27,15 +28,24 @@
 from transformers import PreTrainedTokenizerBase
 
 
+class LogitsInfo:
+
+    def __init__(self, fsm: FSM):
+        self.fsm: FSM = fsm
+        self.mask: Optional[torch.Tensor] = None
+        self.allowed_tokens: Dict[Tuple[int, ...], torch.Tensor] = {}
+        self.cache_hash: Dict[Tuple[int, ...], Tuple[int, ...]] = {}
+
+
 class BaseLogitsProcessor:
 
-    def __init__(self):
-        # Child class should use initialize in their init.
-        self.fsm: FSM
+    def __init__(self, fsm: FSM):
+        self.fsm = fsm
 
     def init_state(self):
-        """Initialize the FSM states."""
+        """Initialize the FSM states."""
         self.fsm_state: DefaultDict[int, int] = defaultdict(int)
+        self.info: LogitsInfo = LogitsInfo(self.fsm)
 
     def __call__(self, input_ids: List[int],
                  scores: torch.Tensor) -> torch.Tensor:
@@ -47,18 +57,58 @@ def __call__(self, input_ids: List[int],
         else:
             last_token = input_ids[-1]
             last_seq_id = hash(tuple(input_ids[:-1]))
-            self.fsm_state[seq_id] = self.fsm.next_state(
+            self.fsm_state[seq_id] = self.info.fsm.next_state(
                 self.fsm_state[last_seq_id], last_token)
 
-        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
+        state = self.fsm_state[seq_id]
+        allowed_tokens = self.info.fsm.allowed_token_ids(state)
+
+        allowed_tokens_key = tuple(allowed_tokens)
+        cache_entry = (self.info.allowed_tokens[allowed_tokens_key]
+                       if allowed_tokens_key in self.info.cache_hash else None)
+
+        if cache_entry is None:
+            # Cache miss, calculate allowed tokens and cache them
+            np_allowed_tokens = np.array(allowed_tokens, dtype=np.int32)
+            allowed_tokens_tensor = torch.from_numpy(np_allowed_tokens)
+            if scores.device.type == "cuda":
+                allowed_tokens_tensor = allowed_tokens_tensor.pin_memory()
+            if allowed_tokens_tensor.device != scores.device:
+                allowed_tokens_tensor = allowed_tokens_tensor.to(
+                    scores.device, dtype=torch.int64, non_blocking=True)
+            else:
+                allowed_tokens_tensor = allowed_tokens_tensor.to(torch.int64)
+
+            self.info.allowed_tokens[
+                allowed_tokens_key] = allowed_tokens_tensor
+            self.info.cache_hash[allowed_tokens_key] = allowed_tokens_key
+
+        else:
+            allowed_tokens_tensor = self.info.allowed_tokens[
+                allowed_tokens_key]
+
+        if self.info.mask is None:
+            self.info.mask = torch.full((scores.shape[-1], ),
+                                        -math.inf,
+                                        device=scores.device)
+        else:
+            self.info.mask.fill_(-math.inf)
+
+        self.info.mask.index_fill_(0, allowed_tokens_tensor, 0)
+        scores.add_(self.info.mask)
 
-        mask = torch.full((scores.shape[-1], ),
-                          -math.inf,
-                          device=scores.device)
-        mask[allowed_tokens] = 0
-        scores.add_(mask)
         return scores
 
+    def state_reset_required(self) -> bool:
+        """Determine if a state reset is required for this processor.
+
+        Returns
+        -------
+        bool
+            Indicates whether a reset is required. Default is False.
+        """
+        return False
+
 
 class RegexLogitsProcessor(BaseLogitsProcessor):
 
@@ -75,14 +125,17 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
         """
         tokenizer = _adapt_tokenizer(tokenizer)
         fsm = RegexFSM(regex_string, tokenizer)
-        self.fsm = fsm
+        super().__init__(fsm=fsm)
 
 
 class JSONLogitsProcessor(RegexLogitsProcessor):
 
-    def __init__(self, schema: Union[str, Dict, BaseModel],
-                 tokenizer: PreTrainedTokenizerBase,
-                 whitespace_pattern: Union[str, None]):
+    def __init__(
+        self,
+        schema: Union[str, Dict, BaseModel],
+        tokenizer: PreTrainedTokenizerBase,
+        whitespace_pattern: Union[str, None],
+    ):
         """Compile the FSM that drives the JSON-guided generation.
 
         Parameters
@@ -128,12 +181,20 @@ def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
         """
         tokenizer = _adapt_tokenizer(tokenizer)
         fsm = CFGFSM(cfg, tokenizer)
-        self.fsm = fsm
+        self._previous_fsm = None
+        super().__init__(fsm=fsm)
 
     def init_state(self):
         """Initialize state with a CFGFSM copy."""
         super().init_state()
-        self.fsm = self.fsm.copy()
+        self.fsm = self.info.fsm = self.info.fsm.copy()
+
+    def state_reset_required(self) -> bool:
+        requires_reset = (self._previous_fsm is None
+                          or self.info.fsm.regex_fsm is not self._previous_fsm)
+
+        self._previous_fsm = self.info.fsm.regex_fsm
+        return requires_reset
 
 
 @lru_cache