From 4a9e16a60960cc946f97d184c2a1822ae65edc8a Mon Sep 17 00:00:00 2001 From: Matt Psaltis Date: Fri, 24 May 2024 01:17:40 +1000 Subject: [PATCH 1/3] Adds outlines performance improvement --- .../outlines_logits_processors.py | 54 ++++++++++++++----- 1 file changed, 40 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index a131c6a1b92b..6ba7ac8dfc77 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -18,12 +18,13 @@ import math from collections import defaultdict from functools import lru_cache -from typing import Callable, DefaultDict, Dict, List, Union +from typing import Callable, DefaultDict, Dict, List, Optional, Union -import torch +import numpy as np from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM from outlines.fsm.json_schema import build_regex_from_schema from pydantic import BaseModel +from torch import Tensor, from_numpy, full_like, int64 from transformers import PreTrainedTokenizerBase @@ -32,13 +33,14 @@ class BaseLogitsProcessor: def __init__(self): # Child class should use initialize in their init. 
self.fsm: FSM + self.mask: Optional[Tensor] = None + self.allowed_tokens_cache: Dict[int, Tensor] = {} def init_state(self): - """Initialize the FSM states.""" + """Initialize the FSM states""" self.fsm_state: DefaultDict[int, int] = defaultdict(int) - def __call__(self, input_ids: List[int], - scores: torch.Tensor) -> torch.Tensor: + def __call__(self, input_ids: List[int], scores: Tensor) -> Tensor: """Use the FSM to bias the logits before sampling the next token.""" seq_id = hash(tuple(input_ids)) @@ -50,13 +52,34 @@ def __call__(self, input_ids: List[int], self.fsm_state[seq_id] = self.fsm.next_state( self.fsm_state[last_seq_id], last_token) - allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id]) + state = self.fsm_state[seq_id] + + # Retrieve allowed tokens from cache using the current state + if state not in self.allowed_tokens_cache: + # Cache miss, calculate allowed tokens and cache them + allowed_tokens = self.fsm.allowed_token_ids(state) + np_allowed_tokens = np.array(allowed_tokens, dtype=np.int32) + allowed_tokens_tensor = from_numpy(np_allowed_tokens) + + if allowed_tokens_tensor.device != scores.device: + allowed_tokens_tensor = allowed_tokens_tensor.to( + scores.device, dtype=int64, non_blocking=True) + else: + allowed_tokens_tensor = allowed_tokens_tensor.to(int64) + + self.allowed_tokens_cache[state] = allowed_tokens_tensor + + else: + allowed_tokens_tensor = self.allowed_tokens_cache[state] + + if self.mask is None: + self.mask = full_like(scores, -math.inf) + else: + self.mask.fill_(-math.inf) + + self.mask.index_fill_(0, allowed_tokens_tensor, 0) + scores.add_(self.mask) - mask = torch.full((scores.shape[-1], ), - -math.inf, - device=scores.device) - mask[allowed_tokens] = 0 - scores.add_(mask) return scores @@ -80,9 +103,12 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase): class JSONLogitsProcessor(RegexLogitsProcessor): - def __init__(self, schema: Union[str, Dict, BaseModel], - tokenizer: 
PreTrainedTokenizerBase, - whitespace_pattern: Union[str, None]): + def __init__( + self, + schema: Union[str, Dict, BaseModel], + tokenizer: PreTrainedTokenizerBase, + whitespace_pattern: Union[str, None], + ): """Compile the FSM that drives the JSON-guided generation. Parameters From 29d0b2de2fda2d7d5582000b78bf25cb39640f38 Mon Sep 17 00:00:00 2001 From: Matt Psaltis Date: Sat, 25 May 2024 15:16:31 +1000 Subject: [PATCH 2/3] Resolving feedback --- .../outlines_logits_processors.py | 35 ++++++++++++------- 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index 6ba7ac8dfc77..b82917bec041 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -21,10 +21,10 @@ from typing import Callable, DefaultDict, Dict, List, Optional, Union import numpy as np +import torch from outlines.fsm.fsm import CFGFSM, FSM, RegexFSM from outlines.fsm.json_schema import build_regex_from_schema from pydantic import BaseModel -from torch import Tensor, from_numpy, full_like, int64 from transformers import PreTrainedTokenizerBase @@ -33,14 +33,15 @@ class BaseLogitsProcessor: def __init__(self): # Child class should use initialize in their init. 
self.fsm: FSM - self.mask: Optional[Tensor] = None - self.allowed_tokens_cache: Dict[int, Tensor] = {} + self.mask: Optional[torch.Tensor] = None + self.allowed_tokens_cache: Dict[tuple, torch.Tensor] = {} def init_state(self): """Initialize the FSM states""" self.fsm_state: DefaultDict[int, int] = defaultdict(int) - def __call__(self, input_ids: List[int], scores: Tensor) -> Tensor: + def __call__(self, input_ids: List[int], + scores: torch.Tensor) -> torch.Tensor: """Use the FSM to bias the logits before sampling the next token.""" seq_id = hash(tuple(input_ids)) @@ -53,27 +54,33 @@ def __call__(self, input_ids: List[int], scores: Tensor) -> Tensor: self.fsm_state[last_seq_id], last_token) state = self.fsm_state[seq_id] + allowed_tokens = self.fsm.allowed_token_ids(state) + allowed_tokens_key = tuple(allowed_tokens) # Retrieve allowed tokens from cache using the current state - if state not in self.allowed_tokens_cache: + if allowed_tokens_key not in self.allowed_tokens_cache: # Cache miss, calculate allowed tokens and cache them - allowed_tokens = self.fsm.allowed_token_ids(state) np_allowed_tokens = np.array(allowed_tokens, dtype=np.int32) - allowed_tokens_tensor = from_numpy(np_allowed_tokens) + allowed_tokens_tensor = torch.from_numpy( + np_allowed_tokens).pin_memory() if allowed_tokens_tensor.device != scores.device: allowed_tokens_tensor = allowed_tokens_tensor.to( - scores.device, dtype=int64, non_blocking=True) + scores.device, dtype=torch.int64, non_blocking=True) else: - allowed_tokens_tensor = allowed_tokens_tensor.to(int64) + allowed_tokens_tensor = allowed_tokens_tensor.to(torch.int64) - self.allowed_tokens_cache[state] = allowed_tokens_tensor + self.allowed_tokens_cache[ + allowed_tokens_key] = allowed_tokens_tensor else: - allowed_tokens_tensor = self.allowed_tokens_cache[state] + allowed_tokens_tensor = self.allowed_tokens_cache[ + allowed_tokens_key] if self.mask is None: - self.mask = full_like(scores, -math.inf) + self.mask = 
torch.full((scores.shape[-1], ), + -math.inf, + device=scores.device) else: self.mask.fill_(-math.inf) @@ -99,6 +106,8 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase): tokenizer = _adapt_tokenizer(tokenizer) fsm = RegexFSM(regex_string, tokenizer) self.fsm = fsm + self.mask: Optional[torch.Tensor] = None + self.allowed_tokens_cache: Dict[tuple, torch.Tensor] = {} class JSONLogitsProcessor(RegexLogitsProcessor): @@ -155,6 +164,8 @@ def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase): tokenizer = _adapt_tokenizer(tokenizer) fsm = CFGFSM(cfg, tokenizer) self.fsm = fsm + self.mask: Optional[torch.Tensor] = None + self.allowed_tokens_cache: Dict[tuple, torch.Tensor] = {} def init_state(self): """Initialize state with a CFGFSM copy.""" From 69a2c31082eefb1a098db4d60c22168aac03a740 Mon Sep 17 00:00:00 2001 From: Matt Psaltis Date: Sun, 26 May 2024 23:53:30 +1000 Subject: [PATCH 3/3] Fixes --- .../outlines_logits_processors.py | 80 ++++++++++++------- 1 file changed, 52 insertions(+), 28 deletions(-) diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index b82917bec041..8e00f71b6a4b 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -28,17 +28,24 @@ from transformers import PreTrainedTokenizerBase -class BaseLogitsProcessor: +class LogitsInfo: - def __init__(self): - # Child class should use initialize in their init. 
- self.fsm: FSM + def __init__(self, fsm: FSM): + self.fsm: FSM = fsm self.mask: Optional[torch.Tensor] = None - self.allowed_tokens_cache: Dict[tuple, torch.Tensor] = {} + self.allowed_tokens: Dict[int, torch.Tensor] = {} + self.cache_hash: Dict[int, int] = {} + + +class BaseLogitsProcessor: + + def __init__(self, fsm: FSM): + self.fsm = fsm def init_state(self): """Initialize the FSM states""" self.fsm_state: DefaultDict[int, int] = defaultdict(int) + self.info: LogitsInfo = LogitsInfo(self.fsm) def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: @@ -50,15 +57,17 @@ def __call__(self, input_ids: List[int], else: last_token = input_ids[-1] last_seq_id = hash(tuple(input_ids[:-1])) - self.fsm_state[seq_id] = self.fsm.next_state( + self.fsm_state[seq_id] = self.info.fsm.next_state( self.fsm_state[last_seq_id], last_token) state = self.fsm_state[seq_id] - allowed_tokens = self.fsm.allowed_token_ids(state) - allowed_tokens_key = tuple(allowed_tokens) + allowed_tokens = self.info.fsm.allowed_token_ids(state) - # Retrieve allowed tokens from cache using the current state - if allowed_tokens_key not in self.allowed_tokens_cache: + allowed_tokens_hash = hash(tuple(allowed_tokens)) + cacheEntry = (self.info.allowed_tokens[allowed_tokens_hash] + if allowed_tokens_hash in self.info.cache_hash else None) + + if cacheEntry is None: # Cache miss, calculate allowed tokens and cache them np_allowed_tokens = np.array(allowed_tokens, dtype=np.int32) allowed_tokens_tensor = torch.from_numpy( @@ -70,25 +79,36 @@ def __call__(self, input_ids: List[int], else: allowed_tokens_tensor = allowed_tokens_tensor.to(torch.int64) - self.allowed_tokens_cache[ - allowed_tokens_key] = allowed_tokens_tensor + self.info.allowed_tokens[ + allowed_tokens_hash] = allowed_tokens_tensor + self.info.cache_hash[allowed_tokens_hash] = allowed_tokens_hash else: - allowed_tokens_tensor = self.allowed_tokens_cache[ - allowed_tokens_key] + allowed_tokens_tensor = 
self.info.allowed_tokens[
+                allowed_tokens_hash]
 
-        if self.mask is None:
-            self.mask = torch.full((scores.shape[-1], ),
-                                   -math.inf,
-                                   device=scores.device)
+        if self.info.mask is None:
+            self.info.mask = torch.full((scores.shape[-1], ),
+                                        -math.inf,
+                                        device=scores.device)
         else:
-            self.mask.fill_(-math.inf)
+            self.info.mask.fill_(-math.inf)
 
-        self.mask.index_fill_(0, allowed_tokens_tensor, 0)
-        scores.add_(self.mask)
+        self.info.mask.index_fill_(0, allowed_tokens_tensor, 0)
+        scores.add_(self.info.mask)
 
         return scores
 
+    def state_reset_required(self) -> bool:
+        """Determine if a state reset is required for this processor.
+
+        Returns
+        -------
+        bool
+            Indicates whether a reset is required. Default is False.
+        """
+        return False
+
 
 class RegexLogitsProcessor(BaseLogitsProcessor):
 
@@ -105,9 +125,7 @@ def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
         """
         tokenizer = _adapt_tokenizer(tokenizer)
         fsm = RegexFSM(regex_string, tokenizer)
-        self.fsm = fsm
-        self.mask: Optional[torch.Tensor] = None
-        self.allowed_tokens_cache: Dict[tuple, torch.Tensor] = {}
+        super().__init__(fsm=fsm)
 
 
 class JSONLogitsProcessor(RegexLogitsProcessor):
@@ -163,14 +181,20 @@ def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase):
         """
         tokenizer = _adapt_tokenizer(tokenizer)
         fsm = CFGFSM(cfg, tokenizer)
-        self.fsm = fsm
-        self.mask: Optional[torch.Tensor] = None
-        self.allowed_tokens_cache: Dict[tuple, torch.Tensor] = {}
+        self._previous_fsm = None
+        super().__init__(fsm=fsm)
 
     def init_state(self):
         """Initialize state with a CFGFSM copy."""
         super().init_state()
-        self.fsm = self.fsm.copy()
+        self.fsm = self.info.fsm = self.info.fsm.copy()
+
+    def state_reset_required(self) -> bool:
+        requiresReset = (self._previous_fsm is None or self.info.fsm.regex_fsm
+                         is not self._previous_fsm)
+
+        self._previous_fsm = self.info.fsm.regex_fsm
+        return requiresReset
 
 
 @lru_cache