Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 60 additions & 28 deletions vllm/model_executor/guided_decoding/outlines_logits_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
import math
from collections import defaultdict
from functools import lru_cache
from typing import Callable, DefaultDict, Dict, List, Union
from typing import Any, Callable, DefaultDict, Dict, List, Optional, Union

import numpy as np
import torch
from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write
from line_profiler import profile
from outlines.fsm.guide import CFGGuide, Generate, Guide, Instruction, RegexGuide, Write
from outlines.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase
Expand All @@ -32,44 +34,70 @@ class BaseLogitsProcessor:
def __init__(self, guide: Guide):
    """Hold an outlines Guide plus the per-sequence decoding state.

    Parameters
    ----------
    guide
        The outlines Guide (regex/CFG FSM) that constrains generation.
    """
    self._guide: Guide = guide
    # FSM state per sequence, keyed by hash(tuple(input_ids));
    # defaultdict(int) makes every unseen sequence start in state 0.
    self._fsm_state: DefaultDict[int, int] = defaultdict(int)
    # Memoizes the allowed-token index tensor for a given instruction so
    # the host->device transfer happens once per distinct token set.
    self.cache: Dict[Instruction, torch.Tensor] = {}
    # Reusable -inf mask buffer, allocated lazily on first __call__.
    self.mask: Optional[torch.Tensor] = None

def init_state(self):
    """Initialize the FSM states."""

def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
    """Use the FSM to bias the logits before sampling the next token.

    Advances the guide's FSM by the last sampled token, then adds a mask of
    -inf to every logit that is not allowed in the current FSM state.
    Mutates ``scores`` in place and returns it.
    """
    seq_id = hash(tuple(input_ids))

    if len(input_ids) > 0:
        # Derive this sequence's FSM state from its parent prefix's state.
        last_token = input_ids[-1]
        last_seq_id = hash(tuple(input_ids[:-1]))
        self._fsm_state[seq_id] = self._guide.get_next_state(
            state=self._fsm_state[last_seq_id], token_id=last_token)

    state = self._fsm_state[seq_id]
    instruction = self._guide.get_next_instruction(state=state)

    # Exact type checks (not isinstance) mirror outlines' own dispatch.
    if type(instruction) == Generate:
        allowed_tokens = instruction.tokens
    elif type(instruction) == Write:
        # TODO: support fast forward tokens
        allowed_tokens = [instruction.tokens[0]]
    else:
        raise TypeError(f"Unsupported instruction type {type(instruction)}")

    # BUGFIX: outlines<=0.0.46 Instruction objects have no `id` attribute,
    # so key the cache on the allowed-token set itself.
    cache_key = hash(tuple(allowed_tokens))
    allowed_tokens_tensor = self.cache.get(cache_key)
    if allowed_tokens_tensor is None:
        # Cache miss: build the index tensor once and keep it on the
        # scores device so later steps skip the host->device copy.
        np_allowed_tokens = np.array(allowed_tokens, dtype=np.int64)
        allowed_tokens_tensor = torch.from_numpy(np_allowed_tokens)
        # BUGFIX: pin_memory() is unsupported on CPU-only backends and
        # raises NotImplementedError/RuntimeError there; pinning is only a
        # transfer optimization, so fall back to the unpinned tensor.
        try:
            allowed_tokens_tensor = allowed_tokens_tensor.pin_memory()
        except (NotImplementedError, RuntimeError):
            pass
        if allowed_tokens_tensor.device != scores.device:
            allowed_tokens_tensor = allowed_tokens_tensor.to(
                scores.device, non_blocking=True)
        self.cache[cache_key] = allowed_tokens_tensor
    # else: cache hit, tensor already on the right device.

    # Reuse one mask buffer instead of allocating a fresh tensor per step.
    if self.mask is None:
        self.mask = torch.full((scores.shape[-1],),
                               -math.inf,
                               device=scores.device)
    else:
        self.mask.fill_(-math.inf)

    # Zero out (i.e. allow) exactly the permitted token ids.
    self.mask.index_fill_(0, allowed_tokens_tensor, 0)
    scores.add_(self.mask)
    return scores


class RegexLogitsProcessor(BaseLogitsProcessor):

@classmethod
@lru_cache(maxsize=32)
def _get_guide(cls, regex_string: str,
               tokenizer: PreTrainedTokenizerBase) -> Guide:
    """Build and memoize a RegexGuide for a (regex, tokenizer) pair."""
    adapted_tokenizer = _adapt_tokenizer(tokenizer)
    return RegexGuide(regex_string, adapted_tokenizer)

Expand All @@ -84,15 +112,18 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
The model's tokenizer

"""
super().__init__(
RegexLogitsProcessor._get_guide(regex_string, tokenizer))
tokenizer = _adapt_tokenizer(tokenizer)
super().__init__(RegexGuide(regex_string, tokenizer))


class JSONLogitsProcessor(RegexLogitsProcessor):

def __init__(self, schema: Union[str, Dict, BaseModel],
tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Union[str, None]):
def __init__(
self,
schema: Union[str, Dict, BaseModel],
tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Optional[str] = None,
):
"""Compile the FSM that drives the JSON-guided generation.

Parameters
Expand All @@ -118,7 +149,8 @@ def __init__(self, schema: Union[str, Dict, BaseModel],
raise ValueError(
f"Cannot parse schema {schema}. The schema must be either "
f"a Pydantic object, a dictionary or a string that contains "
f"the JSON Schema specification")
f"the JSON Schema specification"
)
regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
super().__init__(regex_string, tokenizer)

Expand All @@ -142,11 +174,11 @@ def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
The model's tokenizer

"""
super().__init__(CFGLogitsProcessor._get_guide(cfg, tokenizer))
self._guide = self._guide.copy()
tokenizer = _adapt_tokenizer(tokenizer)
super().__init__(CFGGuide(cfg, tokenizer))


@lru_cache(maxsize=32)
@lru_cache
def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
"""Adapt vLLM's tokenizer to use to compile the FSM.

Expand Down Expand Up @@ -178,8 +210,8 @@ def convert_token_to_string(token: str) -> str:
return string

def change_decoder(
decoder: Callable[[List[int]],
str]) -> Callable[[List[int]], List[str]]:
decoder: Callable[[List[int]], str]
) -> Callable[[List[int]], List[str]]:
"""Sync vLLM's decoder with the outlines by returning list."""

def new_decoder(inp_tokens: List[int]) -> List[str]:
Expand Down