diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index 1618705ff298..0c194a2d513b 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -18,10 +18,12 @@ import math from collections import defaultdict from functools import lru_cache -from typing import Callable, DefaultDict, Dict, List, Union +from typing import Any, Callable, DefaultDict, Dict, List, Optional, Union +import numpy as np import torch -from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write +from line_profiler import profile +from outlines.fsm.guide import CFGGuide, Generate, Guide, Instruction, RegexGuide, Write from outlines.fsm.json_schema import build_regex_from_schema from pydantic import BaseModel from transformers import PreTrainedTokenizerBase @@ -32,9 +34,13 @@ class BaseLogitsProcessor: def __init__(self, guide: Guide): self._guide: Guide = guide self._fsm_state: DefaultDict[int, int] = defaultdict(int) + self.cache: Dict[Instruction, torch.Tensor] = {} + self.mask: Optional[torch.Tensor] = None - def __call__(self, input_ids: List[int], - scores: torch.Tensor) -> torch.Tensor: + def init_state(self): + """Initialize the FSM states.""" + + def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: """Use the FSM to bias the logits before sampling the next token.""" seq_id = hash(tuple(input_ids)) @@ -42,10 +48,11 @@ def __call__(self, input_ids: List[int], last_token = input_ids[-1] last_seq_id = hash(tuple(input_ids[:-1])) self._fsm_state[seq_id] = self._guide.get_next_state( - state=self._fsm_state[last_seq_id], token_id=last_token) + state=self._fsm_state[last_seq_id], token_id=last_token + ) - instruction = self._guide.get_next_instruction( - state=self._fsm_state[seq_id]) + state = self._fsm_state[seq_id] + instruction = self._guide.get_next_instruction(state=state) if type(instruction) == Generate: allowed_tokens = instruction.tokens @@ -53,14 +60,36 @@ def __call__(self, input_ids: List[int], # TODO: support fast forward tokens allowed_tokens = [instruction.tokens[0]] else: - raise TypeError( - f"Unsupported instruction type {type(instruction)}") - - mask = torch.full((scores.shape[-1], ), - -math.inf, - device=scores.device) - mask[allowed_tokens] = 0 - scores.add_(mask) + raise TypeError(f"Unsupported instruction type {type(instruction)}") + + # Retrieve allowed tokens from cache using the current state + cacheKey = instruction.id + if cacheKey not in self.cache: + # Cache miss, calculate allowed tokens and cache them + + np_allowed_tokens = np.array(allowed_tokens, dtype=np.int32) + allowed_tokens_tensor = torch.from_numpy(np_allowed_tokens).pin_memory() + + if allowed_tokens_tensor.device != scores.device: + allowed_tokens_tensor = allowed_tokens_tensor.to( + scores.device, dtype=torch.int64, non_blocking=True + ) + else: + allowed_tokens_tensor = allowed_tokens_tensor.to(torch.int64) + + self.cache[cacheKey] = allowed_tokens_tensor + + else: + # Cache hit + allowed_tokens_tensor = self.cache[cacheKey] + + if self.mask is None: + self.mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device) + else: + self.mask.fill_(-math.inf) + + self.mask.index_fill_(0, allowed_tokens_tensor, 0) + scores.add_(self.mask) return scores @@ -68,8 +97,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor): @classmethod @lru_cache(maxsize=32) - def _get_guide(cls, regex_string: str, - tokenizer: PreTrainedTokenizerBase) -> Guide: + def _get_guide(cls, regex_string: str, tokenizer: PreTrainedTokenizerBase) -> Guide: tokenizer = _adapt_tokenizer(tokenizer) return RegexGuide(regex_string, tokenizer) @@ -84,15 +112,18 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase): The model's tokenizer """ - super().__init__( - RegexLogitsProcessor._get_guide(regex_string, tokenizer)) + tokenizer = _adapt_tokenizer(tokenizer) + super().__init__(RegexGuide(regex_string, tokenizer)) class JSONLogitsProcessor(RegexLogitsProcessor): - def __init__(self, schema: Union[str, Dict, BaseModel], - tokenizer: PreTrainedTokenizerBase, - whitespace_pattern: Union[str, None]): + def __init__( + self, + schema: Union[str, Dict, BaseModel], + tokenizer: PreTrainedTokenizerBase, + whitespace_pattern: Optional[str] = None, + ): """Compile the FSM that drives the JSON-guided generation. Parameters @@ -118,7 +149,8 @@ def __init__(self, schema: Union[str, Dict, BaseModel], raise ValueError( f"Cannot parse schema {schema}. The schema must be either " f"a Pydantic object, a dictionary or a string that contains " - f"the JSON Schema specification") + f"the JSON Schema specification" + ) regex_string = build_regex_from_schema(schema_str, whitespace_pattern) super().__init__(regex_string, tokenizer) @@ -142,11 +174,11 @@ def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase): The model's tokenizer """ - super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer)) - self._guide = self._guide.copy() + tokenizer = _adapt_tokenizer(tokenizer) + super().__init__(CFGGuide(cfg, tokenizer)) -@lru_cache(maxsize=32) +@lru_cache def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase): """Adapt vLLM's tokenizer to use to compile the FSM. @@ -178,8 +210,8 @@ def convert_token_to_string(token: str) -> str: return string def change_decoder( - decoder: Callable[[List[int]], - str]) -> Callable[[List[int]], List[str]]: + decoder: Callable[[List[int]], str] + ) -> Callable[[List[int]], List[str]]: """Sync vLLM's decoder with the outlines by returning list.""" def new_decoder(inp_tokens: List[int]) -> List[str]: