Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 79 additions & 18 deletions vllm/model_executor/guided_decoding/outlines_logits_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,34 @@
import math
from collections import defaultdict
from functools import lru_cache
from typing import Callable, DefaultDict, Dict, List, Union
from typing import Callable, DefaultDict, Dict, List, Optional, Union

import numpy as np
import torch
from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM
from outlines.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase


class LogitsInfo:
    """Per-processor scratch data for FSM-guided logits biasing.

    Bundles the FSM with lazily built caches that ``__call__`` reuses
    across decode steps: a reusable ``-inf`` mask buffer and a map from
    the hash of an allowed-token tuple to its device-resident tensor.
    """

    def __init__(self, fsm: FSM):
        # FSM that decides which token ids are allowed in each state.
        self.fsm: FSM = fsm
        # Reusable mask buffer; allocated on first use in __call__.
        self.mask: Optional[torch.Tensor] = None
        # hash(tuple(allowed_ids)) -> int64 tensor of those ids on the
        # scores device.
        self.allowed_tokens: Dict[int, torch.Tensor] = {}
        # NOTE(review): keys mirror allowed_tokens' keys exactly (values
        # are the hashes themselves), so this dict is redundant with
        # `allowed_tokens`; kept because __call__ consults it.
        self.cache_hash: Dict[int, int] = {}


class BaseLogitsProcessor:

def __init__(self):
# Child class should use initialize in their init.
self.fsm: FSM
    def __init__(self, fsm: FSM):
        """Store the FSM driving this processor.

        Parameters
        ----------
        fsm
            FSM built by a subclass (regex-, JSON-, or CFG-based) and
            passed up via ``super().__init__(fsm=...)``.
        """
        self.fsm = fsm

    def init_state(self):
        """Initialize per-request FSM state tracking and the logits cache."""
        # Maps hash(tuple(input_ids)) -> FSM state; defaultdict(int)
        # makes unseen sequences start at state 0.
        self.fsm_state: DefaultDict[int, int] = defaultdict(int)
        # Fresh cache holder wrapping the FSM (see LogitsInfo).
        self.info: LogitsInfo = LogitsInfo(self.fsm)

def __call__(self, input_ids: List[int],
scores: torch.Tensor) -> torch.Tensor:
Expand All @@ -47,18 +57,58 @@ def __call__(self, input_ids: List[int],
else:
last_token = input_ids[-1]
last_seq_id = hash(tuple(input_ids[:-1]))
self.fsm_state[seq_id] = self.fsm.next_state(
self.fsm_state[seq_id] = self.info.fsm.next_state(
self.fsm_state[last_seq_id], last_token)

allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
state = self.fsm_state[seq_id]
allowed_tokens = self.info.fsm.allowed_token_ids(state)

allowed_tokens_hash = hash(tuple(allowed_tokens))
cacheEntry = (self.info.allowed_tokens[allowed_tokens_hash]
if allowed_tokens_hash in self.info.cache_hash else None)

if cacheEntry is None:
# Cache miss, calculate allowed tokens and cache them
np_allowed_tokens = np.array(allowed_tokens, dtype=np.int32)
allowed_tokens_tensor = torch.from_numpy(
np_allowed_tokens).pin_memory()

if allowed_tokens_tensor.device != scores.device:
allowed_tokens_tensor = allowed_tokens_tensor.to(
scores.device, dtype=torch.int64, non_blocking=True)
else:
allowed_tokens_tensor = allowed_tokens_tensor.to(torch.int64)

self.info.allowed_tokens[
allowed_tokens_hash] = allowed_tokens_tensor
self.info.cache_hash[allowed_tokens_hash] = allowed_tokens_hash

else:
allowed_tokens_tensor = self.info.allowed_tokens[
allowed_tokens_hash]

if self.info.mask is None:
self.info.mask = torch.full((scores.shape[-1], ),
-math.inf,
device=scores.device)
else:
self.info.mask.fill_(-math.inf)

self.info.mask.index_fill_(0, allowed_tokens_tensor, 0)
scores.add_(self.info.mask)

mask = torch.full((scores.shape[-1], ),
-math.inf,
device=scores.device)
mask[allowed_tokens] = 0
scores.add_(mask)
return scores

def state_reset_required(self) -> bool:
"""Determine if a state reset is required for this processor.

Returns
-------
bool
Indicates whether a reset is required. Default is False.
"""
return False


class RegexLogitsProcessor(BaseLogitsProcessor):

    def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
        """Compile the FSM that drives the regex-guided generation.

        Parameters
        ----------
        regex_string
            A string that represents a regular expression.
        tokenizer
            The model's tokenizer.
        """
        adapted_tokenizer = _adapt_tokenizer(tokenizer)
        super().__init__(fsm=RegexFSM(regex_string, adapted_tokenizer))


class JSONLogitsProcessor(RegexLogitsProcessor):

def __init__(self, schema: Union[str, Dict, BaseModel],
tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Union[str, None]):
def __init__(
self,
schema: Union[str, Dict, BaseModel],
tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Union[str, None],
):
"""Compile the FSM that drives the JSON-guided generation.

Parameters
Expand Down Expand Up @@ -128,12 +181,20 @@ def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
"""
tokenizer = _adapt_tokenizer(tokenizer)
fsm = CFGFSM(cfg, tokenizer)
self.fsm = fsm
self._previous_fsm = None
super().__init__(fsm=fsm)

def init_state(self):
"""Initialize state with a CFGFSM copy."""
super().init_state()
self.fsm = self.fsm.copy()
self.fsm = self.info.fsm.copy()

def state_reset_required(self) -> bool:
requiresReset = (self._previous_fsm is None or self.info.fsm.regex_fsm
is not self.info._previous_fsm)

self._previous_fsm = self.info.fsm.regex_fsm
return requiresReset


@lru_cache
Expand Down