diff --git a/tests/entrypoints/openai/test_guided_processors.py b/tests/entrypoints/openai/test_guided_processors.py
index 85cb4d52200c..ee35efa687f2 100644
--- a/tests/entrypoints/openai/test_guided_processors.py
+++ b/tests/entrypoints/openai/test_guided_processors.py
@@ -6,7 +6,7 @@
 from vllm.entrypoints.openai.protocol import CompletionRequest
 from vllm.model_executor.guided_decoding import (
-    get_guided_decoding_logits_processor)
+    get_guided_decoding_logits_processor_factory)
 from vllm.model_executor.guided_decoding.outlines_logits_processors import (
     JSONLogitsProcessor, RegexLogitsProcessor)
@@ -47,9 +47,10 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
     regex_request = CompletionRequest(model='test',
                                       prompt=token_ids,
                                       guided_regex=sample_regex)
-    regex_lp = await get_guided_decoding_logits_processor(
+    regex_lpf = await get_guided_decoding_logits_processor_factory(
         backend, regex_request, tokenizer)
-    assert regex_lp is not None
+    assert regex_lpf is not None
+    regex_lp = await regex_lpf.get_processor_async()
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
     tensor = regex_lp(token_ids, tensor)
@@ -62,9 +63,10 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex,
     json_request = CompletionRequest(model='test',
                                      prompt=token_ids,
                                      guided_json=sample_json_schema)
-    json_lp = await get_guided_decoding_logits_processor(
+    json_lpf = await get_guided_decoding_logits_processor_factory(
         backend, json_request, tokenizer)
-    assert json_lp is not None
+    assert json_lpf is not None
+    json_lp = await json_lpf.get_processor_async()
     tensor = torch.rand(32000)
     original_tensor = torch.clone(tensor)
     tensor = json_lp(token_ids, tensor)
diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py
index 8ee2d78190cd..0df357f664ca 100644
--- a/tests/test_logits_processor.py
+++ b/tests/test_logits_processor.py
@@ -65,11 +65,13 @@ def pick_ith(token_ids, logits):
     seq_group_metadata_list = []
     seq_lens = []
     for i in range(batch_size):
+        seq_data = {0: SequenceData([1, 2, 3])}
+        seq_data[0].logits_processors = [pick_ith]
         seq_group_metadata_list.append(
             SequenceGroupMetadata(
                 request_id=f"test_{i}",
                 is_prompt=True,
-                seq_data={0: SequenceData([1, 2, 3])},
+                seq_data=seq_data,
                 sampling_params=SamplingParams(temperature=0,
                                                logits_processors=[pick_ith]),
                 block_tables={0: [1]},
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index a28b20fcbbcd..d1504fc32131 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -458,7 +458,7 @@ async def add_request_async(
             prompt_adapter_request=prompt_adapter_request,
         )
 
-        self._add_processed_request(
+        seq_group = self._create_sequence_group(
             request_id=request_id,
             processed_inputs=processed_inputs,
             params=params,
@@ -468,6 +468,19 @@
             trace_headers=trace_headers,
         )
 
+        if isinstance(params, SamplingParams):
+            for seq in seq_group.get_seqs():
+                seq.data.logits_processors = \
+                    await params.get_logits_processors_async()
+
+        # Add the sequence group to the scheduler with least unfinished seqs.
+        costs = [
+            scheduler.get_num_unfinished_seq_groups()
+            for scheduler in self.scheduler
+        ]
+        min_cost_scheduler = self.scheduler[costs.index(min(costs))]
+        min_cost_scheduler.add_seq_group(seq_group)
+
     async def check_health_async(self) -> None:
         if self.tokenizer:
             self.tokenizer.check_health()
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 979555eb6a05..f1c654bc37cb 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -570,7 +570,7 @@ def _get_decoder_start_token_id(self) -> Optional[int]:
 
         return dec_start_token_id
 
-    def _add_processed_request(
+    def _create_sequence_group(
         self,
         request_id: str,
         processed_inputs: Union[LLMInputs, EncoderDecoderLLMInputs],
@@ -579,7 +579,7 @@ def _add_processed_request(
         lora_request: Optional[LoRARequest],
         prompt_adapter_request: Optional[PromptAdapterRequest],
         trace_headers: Optional[Mapping[str, str]] = None,
-    ) -> None:
+    ) -> SequenceGroup:
         # Create the sequences.
         block_size = self.cache_config.block_size
         seq_id = next(self.seq_counter)
@@ -622,13 +622,7 @@
             raise ValueError(
                 "Either SamplingParams or PoolingParams must be provided.")
 
-        # Add the sequence group to the scheduler with least unfinished seqs.
-        costs = [
-            scheduler.get_num_unfinished_seq_groups()
-            for scheduler in self.scheduler
-        ]
-        min_cost_scheduler = self.scheduler[costs.index(min(costs))]
-        min_cost_scheduler.add_seq_group(seq_group)
+        return seq_group
 
     def stop_remote_worker_execution_loop(self) -> None:
         self.model_executor.stop_remote_worker_execution_loop()
@@ -1026,7 +1020,7 @@ def add_request(
             prompt_adapter_request=prompt_adapter_request,
         )
 
-        self._add_processed_request(
+        seq_group = self._create_sequence_group(
             request_id=request_id,
             processed_inputs=processed_inputs,
             params=params,
@@ -1036,6 +1030,18 @@
             trace_headers=trace_headers,
         )
 
+        if isinstance(params, SamplingParams):
+            for seq in seq_group.get_seqs():
+                seq.data.logits_processors = params.get_logits_processors()
+
+        # Add the sequence group to the scheduler with least unfinished seqs.
+        costs = [
+            scheduler.get_num_unfinished_seq_groups()
+            for scheduler in self.scheduler
+        ]
+        min_cost_scheduler = self.scheduler[costs.index(min(costs))]
+        min_cost_scheduler.add_seq_group(seq_group)
+
     def _create_sequence_group_with_sampling(
         self,
         request_id: str,
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 7da3002b283f..08e51efd8131 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -12,7 +12,8 @@
 from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 from vllm.entrypoints.openai.logits_processors import get_logits_processors
 from vllm.pooling_params import PoolingParams
-from vllm.sampling_params import LogitsProcessor, SamplingParams
+from vllm.sampling_params import (LogitsProcessor, LogitsProcessorFactory,
+                                  SamplingParams)
 from vllm.utils import random_uuid
 
 # torch is mocked during docs generation,
@@ -232,10 +233,10 @@ class ChatCompletionRequest(OpenAIBaseModel):
 
     # doc: end-chat-completion-extra-params
 
-    def to_sampling_params(
-            self, tokenizer: PreTrainedTokenizer,
-            guided_decode_logits_processor: Optional[LogitsProcessor],
-            default_max_tokens: int) -> SamplingParams:
+    def to_sampling_params(self, tokenizer: PreTrainedTokenizer,
+                           guided_decode_logits_processor: Optional[Union[
+                               LogitsProcessor, LogitsProcessorFactory]],
+                           default_max_tokens: int) -> SamplingParams:
         max_tokens = self.max_tokens
         if max_tokens is None:
             max_tokens = default_max_tokens
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 2167b967b14b..fc39e1f4ae2e 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -128,7 +128,8 @@ async def create_chat_completion(
         request_id = f"chat-{random_uuid()}"
         try:
             guided_decode_logits_processor = (
-                await self._guided_decode_logits_processor(request, tokenizer))
+                await self._guided_decode_logits_processor_factory(
+                    request, tokenizer))
 
             prompt_inputs = self._tokenize_prompt_input(
                 request,
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index f4c91ce04684..5404569576a2 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -96,7 +96,8 @@ async def create_completion(self, request: CompletionRequest,
                 lora_request)
 
             guided_decode_logits_processor = (
-                await self._guided_decode_logits_processor(request, tokenizer))
+                await self._guided_decode_logits_processor_factory(
+                    request, tokenizer))
             prompts = list(
                 self._tokenize_prompt_input_or_inputs(
                     request,
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 8d8b5ea4bdf5..c8404db9b339 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -26,10 +26,10 @@
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding import (
-    get_guided_decoding_logits_processor)
+    get_guided_decoding_logits_processor_factory)
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import LogitsProcessor, SamplingParams
+from vllm.sampling_params import LogitsProcessorFactory, SamplingParams
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer_group import AnyTokenizer
 
@@ -152,13 +152,13 @@ def create_streaming_error_response(
         })
         return json_str
 
-    async def _guided_decode_logits_processor(
+    async def _guided_decode_logits_processor_factory(
            self, request: Union[ChatCompletionRequest, CompletionRequest],
-            tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]:
+            tokenizer: AnyTokenizer) -> Optional[LogitsProcessorFactory]:
         decoding_config = await self.async_engine_client.get_decoding_config()
         guided_decoding_backend = request.guided_decoding_backend \
             or decoding_config.guided_decoding_backend
-        return await get_guided_decoding_logits_processor(
+        return await get_guided_decoding_logits_processor_factory(
             guided_decoding_backend, request, tokenizer)
 
     async def _check_model(
diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py
index 4a2476dd6314..d0e9b293610a 100644
--- a/vllm/model_executor/guided_decoding/__init__.py
+++ b/vllm/model_executor/guided_decoding/__init__.py
@@ -8,13 +8,13 @@
 from vllm.model_executor.guided_decoding.outlines_decoding import (
     get_local_outlines_guided_decoding_logits_processor,
     get_outlines_guided_decoding_logits_processor)
-from vllm.sampling_params import LogitsProcessor
+from vllm.sampling_params import LogitsProcessor, LogitsProcessorFactory
 
 
-async def get_guided_decoding_logits_processor(
+async def get_guided_decoding_logits_processor_factory(
         guided_decoding_backend: str, request: Union[CompletionRequest,
                                                       ChatCompletionRequest],
-        tokenizer) -> Optional[LogitsProcessor]:
+        tokenizer) -> Optional[LogitsProcessorFactory]:
     request = _adapt_request_for_tool_use(request)
 
     if guided_decoding_backend == 'outlines':
@@ -70,4 +70,4 @@ def _adapt_request_for_tool_use(request: Union[CompletionRequest,
         tool = tools[tool_name]
         request.guided_json = tool.parameters
 
-    return request
+    return request
\ No newline at end of file
diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
index b2188c9cbc2b..b368ff5cf516 100644
--- a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
+++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
@@ -17,12 +17,23 @@
 from vllm.model_executor.guided_decoding.outlines_decoding import (
     get_local_outlines_guided_decoding_logits_processor,
     get_outlines_guided_decoding_logits_processor)
-from vllm.sampling_params import LogitsProcessor
+from vllm.sampling_params import LogitsProcessor, LogitsProcessorFactory
+
+
+class LMFormatDecodingLogitsProcessorFactory(LogitsProcessorFactory):
+
+    def __init__(self, tokenizer_data, character_level_parser):
+        self.tokenizer_data = tokenizer_data
+        self.character_level_parser = character_level_parser
+
+    def get_processor(self) -> LogitsProcessor:
+        return build_vllm_logits_processor(self.tokenizer_data,
+                                           self.character_level_parser)
 
 
 async def get_lm_format_enforcer_guided_decoding_logits_processor(
         request: Union[CompletionRequest, ChatCompletionRequest],
-        tokenizer) -> Optional[LogitsProcessor]:
+        tokenizer) -> Optional[LMFormatDecodingLogitsProcessorFactory]:
     """
     Given an OpenAI-compatible request, check for guided decoding parameters
     and get the necessary logits processor for the given guide.
@@ -52,9 +63,8 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor(
     else:
         return None
 
-    logits_processor = build_vllm_logits_processor(tokenizer_data,
-                                                   character_level_parser)
-    return logits_processor
+    return LMFormatDecodingLogitsProcessorFactory(tokenizer_data,
+                                                  character_level_parser)
 
 
 def get_local_lm_format_enforcer_guided_decoding_logits_processor(
diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py
index bc62224dabec..490f834c4c9f 100644
--- a/vllm/model_executor/guided_decoding/outlines_decoding.py
+++ b/vllm/model_executor/guided_decoding/outlines_decoding.py
@@ -3,7 +3,7 @@
 from enum import Enum
 from json import dumps as json_dumps
 from re import escape as regex_escape
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 from pydantic import BaseModel
 from transformers import PreTrainedTokenizerBase
@@ -14,6 +14,7 @@
     GuidedDecodingRequest)
 from vllm.model_executor.guided_decoding.outlines_logits_processors import (
     CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
+from vllm.sampling_params import LogitsProcessor, LogitsProcessorFactory
 
 
 class GuidedDecodingMode(Enum):
@@ -53,11 +54,37 @@ class GuidedDecodingMode(Enum):
 global_thread_pool = None  # used for generating logits processor fsm
 
 
+class OutlinesDecodingLogitsProcessorFactory(LogitsProcessorFactory):
+
+    def __init__(self, guide: str, tokenizer: PreTrainedTokenizerBase,
+                 mode: GuidedDecodingMode, whitespace_pattern: Union[str,
+                                                                     None]):
+        self.guide = guide
+        self.tokenizer = tokenizer
+        self.mode = mode
+        self.whitespace_pattern = whitespace_pattern
+
+    def get_processor(self) -> LogitsProcessor:
+        return _get_logits_processor(self.guide, self.tokenizer, self.mode,
+                                     self.whitespace_pattern)
+
+    async def get_processor_async(self) -> LogitsProcessor:
+        global global_thread_pool
+        if global_thread_pool is None:
+            global_thread_pool = concurrent.futures.ThreadPoolExecutor(
+                max_workers=2)
+        loop = asyncio.get_running_loop()
+
+        return await loop.run_in_executor(global_thread_pool,
+                                          _get_logits_processor, self.guide,
+                                          self.tokenizer, self.mode,
+                                          self.whitespace_pattern)
+
+
 async def get_outlines_guided_decoding_logits_processor(
     request: Union[CompletionRequest, ChatCompletionRequest],
     tokenizer: PreTrainedTokenizerBase
-) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
-           None]:
+) -> Optional[OutlinesDecodingLogitsProcessorFactory]:
     """
     Given an OpenAI-compatible request, check for guided decoding parameters
     and get the necessary logits processor for the given guide.
@@ -69,14 +96,8 @@ async def get_outlines_guided_decoding_logits_processor(
     if not guide or not mode:
         return None
 
-    if global_thread_pool is None:
-        global_thread_pool = concurrent.futures.ThreadPoolExecutor(
-            max_workers=2)
-    loop = asyncio.get_running_loop()
-
-    return await loop.run_in_executor(global_thread_pool,
-                                      _get_logits_processor, guide, tokenizer,
-                                      mode, request.guided_whitespace_pattern)
+    return OutlinesDecodingLogitsProcessorFactory(
+        guide, tokenizer, mode, request.guided_whitespace_pattern)
 
 
 def get_local_outlines_guided_decoding_logits_processor(
diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py
index 1d5b6fad2e16..13ad7c8d50b5 100644
--- a/vllm/model_executor/layers/logits_processor.py
+++ b/vllm/model_executor/layers/logits_processor.py
@@ -121,13 +121,13 @@ def _apply_logits_processors(
     logits_processed = 0
     for seq_group in sampling_metadata.seq_groups:
         seq_ids = seq_group.seq_ids
-        sampling_params = seq_group.sampling_params
-        logits_processors = sampling_params.logits_processors
-        if logits_processors:
-            found_logits_processors = True
 
-            for seq_id, logits_row_idx in zip(seq_ids,
-                                              seq_group.sample_indices):
+        for seq_id, logits_row_idx in zip(seq_ids, seq_group.sample_indices):
+            logits_processors = seq_group.seq_data[seq_id].logits_processors
+
+            if logits_processors:
+                found_logits_processors = True
+
                 logits_row = logits[logits_row_idx]
                 past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids
                 prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 04250c682cd2..5ff9417e56d0 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -1,5 +1,6 @@
 """Sampling parameters for text generation."""
 import copy
+from abc import ABC, abstractmethod
 from enum import IntEnum
 from functools import cached_property
 from typing import Any, Callable, Dict, List, Optional, Union
@@ -33,6 +34,29 @@ class SamplingType(IntEnum):
     to sample from."""
 
 
+class LogitsProcessorFactory(ABC):
+    """Factory for logits processors.
+
+    A factory allows the implementation of stateful logits processors by
+    making sure that each sequence gets its own instance. While a stateless
+    LogitsProcessor callable can be shared between multiple sequences,
+    processors whose internal state depends on the sequence seen so far
+    inherently cannot be shared.
+
+    For logits processors with expensive initialization, it is recommended
+    to override the get_processor_async method to build the object
+    asynchronously, for example using a thread pool, to prevent the event
+    loop from being blocked.
+    """
+
+    @abstractmethod
+    def get_processor(self) -> LogitsProcessor:
+        ...
+
+    async def get_processor_async(self) -> LogitsProcessor:
+        return self.get_processor()
+
+
 class SamplingParams:
     """Sampling parameters for text generation.
@@ -138,7 +162,8 @@ def __init__(
         detokenize: bool = True,
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
-        logits_processors: Optional[List[LogitsProcessor]] = None,
+        logits_processors: Optional[List[Union[
+            LogitsProcessor, LogitsProcessorFactory]]] = None,
         truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
     ) -> None:
         self.n = n
@@ -344,6 +369,20 @@ def clone(self) -> "SamplingParams":
         }
         return copy.deepcopy(self, memo=logit_processor_refs)
 
+    def get_logits_processors(self):
+        return [] if not self.logits_processors else [
+            lp.get_processor()
+            if isinstance(lp, LogitsProcessorFactory) else lp
+            for lp in self.logits_processors
+        ]
+
+    async def get_logits_processors_async(self):
+        return [] if not self.logits_processors else [
+            await lp.get_processor_async() if isinstance(
+                lp, LogitsProcessorFactory) else lp
+            for lp in self.logits_processors
+        ]
+
     def __repr__(self) -> str:
         return (
             f"SamplingParams(n={self.n}, "
diff --git a/vllm/sequence.py b/vllm/sequence.py
index b83e345235cd..4b983a6ea8f0 100644
--- a/vllm/sequence.py
+++ b/vllm/sequence.py
@@ -15,7 +15,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import LogitsProcessor, SamplingParams
 
 if TYPE_CHECKING:
     from vllm.inputs import LLMInputs
@@ -140,6 +140,7 @@ def __init__(
         # The number of tokens that are computed (that run against the model).
         self._num_computed_tokens = 0
         self._stage: SequenceStage = SequenceStage.PREFILL
+        self.logits_processors: List[LogitsProcessor] = []
 
         self._update_cached_all_tokens()
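
Usage sketch (editor's addition, not part of the patch): the snippet below
illustrates how the LogitsProcessorFactory interface introduced in
vllm/sampling_params.py is meant to be used by callers. Only
LogitsProcessorFactory, LogitsProcessor, and SamplingParams come from the
patch; the BanTokenAfterNStepsFactory name, the token id, and the step
threshold are made up for the example.

    from typing import List

    import torch

    from vllm.sampling_params import (LogitsProcessor, LogitsProcessorFactory,
                                      SamplingParams)


    class BanTokenAfterNStepsFactory(LogitsProcessorFactory):
        """Builds a fresh, stateful processor for every sequence."""

        def __init__(self, token_id: int, max_steps: int):
            self.token_id = token_id
            self.max_steps = max_steps

        def get_processor(self) -> LogitsProcessor:
            steps = 0  # per-sequence state lives in this closure

            def processor(token_ids: List[int],
                          logits: torch.Tensor) -> torch.Tensor:
                nonlocal steps
                steps += 1
                if steps > self.max_steps:
                    # Ban the configured token once the step budget is spent.
                    logits[self.token_id] = -float("inf")
                return logits

            return processor


    # A factory can be passed wherever a plain callable was accepted before;
    # the engine resolves it per sequence via get_processor() or
    # get_processor_async().
    params = SamplingParams(
        temperature=0.0,
        logits_processors=[BanTokenAfterNStepsFactory(token_id=13,
                                                      max_steps=4)])

Each sequence added to the engine then receives its own processor instance
instead of sharing one mutable callable, which is the point of the factory
indirection described in the LogitsProcessorFactory docstring.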