From ea25352c98ef0d6d009ba7b027375327cb47b26f Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Thu, 6 Jun 2024 11:16:29 -0300 Subject: [PATCH 1/4] Add factories for logits_processors This allows vllm to support stateful LPs that must be unique for each sequence. Signed-off-by: Max de Bayser --- tests/entrypoints/test_guided_processors.py | 6 +-- vllm/engine/async_llm_engine.py | 10 ++++- vllm/engine/llm_engine.py | 16 ++++--- vllm/entrypoints/openai/serving_chat.py | 4 +- vllm/entrypoints/openai/serving_completion.py | 10 ++--- .../guided_decoding/__init__.py | 8 ++-- .../lm_format_enforcer_decoding.py | 20 ++++++--- .../guided_decoding/outlines_decoding.py | 44 ++++++++++++++----- .../model_executor/layers/logits_processor.py | 12 ++--- vllm/sampling_params.py | 41 ++++++++++++++++- vllm/sequence.py | 3 +- 11 files changed, 129 insertions(+), 45 deletions(-) diff --git a/tests/entrypoints/test_guided_processors.py b/tests/entrypoints/test_guided_processors.py index fb32a9d155bc..4e3893b00c12 100644 --- a/tests/entrypoints/test_guided_processors.py +++ b/tests/entrypoints/test_guided_processors.py @@ -6,7 +6,7 @@ from vllm.entrypoints.openai.protocol import CompletionRequest from vllm.model_executor.guided_decoding import ( - get_guided_decoding_logits_processor) + get_guided_decoding_logits_processor_factory) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( JSONLogitsProcessor, RegexLogitsProcessor) @@ -89,7 +89,7 @@ async def test_guided_logits_processor_black_box(backend: str): regex_request = CompletionRequest(model='test', prompt=token_ids, guided_regex=TEST_REGEX) - regex_lp = await get_guided_decoding_logits_processor( + regex_lp = await get_guided_decoding_logits_processor_factory( backend, regex_request, tokenizer) assert regex_lp is not None tensor = torch.rand(32000) @@ -103,7 +103,7 @@ async def test_guided_logits_processor_black_box(backend: str): json_request = CompletionRequest(model='test', prompt=token_ids, guided_json=TEST_SCHEMA) - json_lp = await get_guided_decoding_logits_processor( + json_lp = await get_guided_decoding_logits_processor_factory( backend, json_request, tokenizer) assert json_lp is not None tensor = torch.rand(32000) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index aa1f07b5bdc2..64fee7a0c351 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -295,7 +295,7 @@ async def add_request_async( processed_inputs = await self.process_model_inputs_async( request_id=request_id, inputs=inputs, lora_request=lora_request) - self._add_processed_request( + seq_group = self._create_sequence_group( request_id=request_id, processed_inputs=processed_inputs, params=params, @@ -303,6 +303,14 @@ async def add_request_async( lora_request=lora_request, ) + if isinstance(params, SamplingParams): + for seq in seq_group.get_seqs(): + seq.data.logits_processors = \ + await params.get_logits_processors_async() + + # Add the sequence group to the scheduler. 
+ self.scheduler.add_seq_group(seq_group) + async def check_health_async(self) -> None: self.model_executor.check_health() diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index cb5893e707c8..7cf985ea288b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -425,14 +425,14 @@ def _get_eos_token_id( return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id - def _add_processed_request( + def _create_sequence_group( self, request_id: str, processed_inputs: LLMInputs, params: Union[SamplingParams, PoolingParams], arrival_time: float, lora_request: Optional[LoRARequest], - ) -> None: + ) -> SequenceGroup: # Create the sequences. block_size = self.cache_config.block_size seq_id = next(self.seq_counter) @@ -462,8 +462,7 @@ def _add_processed_request( raise ValueError( "Either SamplingParams or PoolingParams must be provided.") - # Add the sequence group to the scheduler. - self.scheduler.add_seq_group(seq_group) + return seq_group def process_model_inputs( self, @@ -547,7 +546,7 @@ def add_request( inputs=inputs, lora_request=lora_request) - self._add_processed_request( + seq_group = self._create_sequence_group( request_id=request_id, processed_inputs=processed_inputs, params=params, @@ -555,6 +554,13 @@ def add_request( lora_request=lora_request, ) + if isinstance(params, SamplingParams): + for seq in seq_group.get_seqs(): + seq.data.logits_processors = params.get_logits_processors() + + # Add the sequence group to the scheduler. + self.scheduler.add_seq_group(seq_group) + def _create_sequence_group_with_sampling( self, request_id: str, diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index afd87f49c1c4..9dbfd75d4850 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -23,7 +23,7 @@ OpenAIServing) from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( - get_guided_decoding_logits_processor) + get_guided_decoding_logits_processor_factory) from vllm.outputs import RequestOutput from vllm.sequence import Logprob from vllm.utils import random_uuid @@ -172,7 +172,7 @@ async def create_chat_completion( guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend guided_decode_logits_processor = ( - await get_guided_decoding_logits_processor( + await get_guided_decoding_logits_processor_factory( guided_decoding_backend, request, await self.engine.get_tokenizer())) if guided_decode_logits_processor: diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 572878b5527d..75cc29215055 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -21,7 +21,7 @@ OpenAIServing) from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( - get_guided_decoding_logits_processor) + get_guided_decoding_logits_processor_factory) from vllm.outputs import RequestOutput from vllm.sequence import Logprob from vllm.utils import merge_async_iterators, random_uuid @@ -99,15 +99,15 @@ async def create_completion(self, request: CompletionRequest, decoding_config = await self.engine.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend - guided_decode_logit_processor = ( - await get_guided_decoding_logits_processor( + guided_decode_logits_processor = ( + await get_guided_decoding_logits_processor_factory( 
guided_decoding_backend, request, await self.engine.get_tokenizer())) - if guided_decode_logit_processor is not None: + if guided_decode_logits_processor is not None: if sampling_params.logits_processors is None: sampling_params.logits_processors = [] sampling_params.logits_processors.append( - guided_decode_logit_processor) + guided_decode_logits_processor) prompt_is_tokens, prompts = parse_prompt_format(request.prompt) for i, prompt in enumerate(prompts): diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index 50aa3ec379f4..f4b6c6e580db 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -7,13 +7,13 @@ get_lm_format_enforcer_guided_decoding_logits_processor) from vllm.model_executor.guided_decoding.outlines_decoding import ( get_outlines_guided_decoding_logits_processor) -from vllm.sampling_params import LogitsProcessor +from vllm.sampling_params import LogitsProcessorFactory -async def get_guided_decoding_logits_processor( +async def get_guided_decoding_logits_processor_factory( guided_decoding_backend: str, request: Union[CompletionRequest, ChatCompletionRequest], - tokenizer) -> Optional[LogitsProcessor]: + tokenizer) -> Optional[LogitsProcessorFactory]: request = _adapt_request_for_tool_use(request) if guided_decoding_backend == 'outlines': @@ -48,4 +48,4 @@ def _adapt_request_for_tool_use(request: Union[CompletionRequest, tool = tools[tool_name] request.guided_json = tool.parameters - return request + return request \ No newline at end of file diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py index d0a5ca5592f9..71f15ed7ce47 100644 --- a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py @@ -14,12 +14,23 @@ CompletionRequest) from vllm.model_executor.guided_decoding.outlines_decoding import ( get_outlines_guided_decoding_logits_processor) -from vllm.sampling_params import LogitsProcessor +from vllm.sampling_params import LogitsProcessor, LogitsProcessorFactory + + +class LMFormatDecodingLogitsProcessorFactory(LogitsProcessorFactory): + + def __init__(self, tokenizer_data, character_level_parser): + self.tokenizer_data = tokenizer_data + self.character_level_parser = character_level_parser + + def get_processor(self) -> LogitsProcessor: + return build_vllm_logits_processor(self.tokenizer_data, + self.character_level_parser) async def get_lm_format_enforcer_guided_decoding_logits_processor( request: Union[CompletionRequest, ChatCompletionRequest], - tokenizer) -> Optional[LogitsProcessor]: + tokenizer) -> Optional[LMFormatDecodingLogitsProcessorFactory]: """ Given an OpenAI-compatible request, check for guided decoding parameters and get the necessary logits processor for the given guide. 
@@ -49,9 +60,8 @@ async def get_lm_format_enforcer_guided_decoding_logits_processor( else: return None - logits_processor = build_vllm_logits_processor(tokenizer_data, - character_level_parser) - return logits_processor + return LMFormatDecodingLogitsProcessorFactory(tokenizer_data, + character_level_parser) def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict: diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index 721f7e0530cb..3b0d49480074 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -3,7 +3,7 @@ from enum import Enum from json import dumps as json_dumps from re import escape as regex_escape -from typing import Tuple, Union +from typing import Optional, Tuple, Union from pydantic import BaseModel from transformers import PreTrainedTokenizerBase @@ -12,6 +12,7 @@ CompletionRequest) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) +from vllm.sampling_params import LogitsProcessor, LogitsProcessorFactory class GuidedDecodingMode(Enum): @@ -50,12 +51,37 @@ class GuidedDecodingMode(Enum): global_thread_pool = None # used for generating logits processor fsm +class OutlinesDecodingLogitsProcessorFactory(LogitsProcessorFactory): + + def __init__(self, guide: str, tokenizer: PreTrainedTokenizerBase, + mode: GuidedDecodingMode, whitespace_pattern: Union[str, + None]): + self.guide = guide + self.tokenizer = tokenizer + self.mode = mode + self.whitespace_pattern = whitespace_pattern + + def get_processor(self) -> LogitsProcessor: + return _get_logits_processor(self.guide, self.tokenizer, self.mode, + self.whitespace_pattern) + + async def get_processor_async(self) -> LogitsProcessor: + global global_thread_pool + if global_thread_pool is None: + global_thread_pool = concurrent.futures.ThreadPoolExecutor( + max_workers=2) + loop = asyncio.get_running_loop() + + return await loop.run_in_executor(global_thread_pool, + _get_logits_processor, + self.guide, self.tokenizer, + self.mode, self.whitespace_pattern) + async def get_outlines_guided_decoding_logits_processor( request: Union[CompletionRequest, ChatCompletionRequest], tokenizer: PreTrainedTokenizerBase -) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, - None]: +) -> Optional[OutlinesDecodingLogitsProcessorFactory]: """ Given an OpenAI-compatible request, check for guided decoding parameters and get the necessary logits processor for the given guide. 
@@ -66,15 +92,9 @@ async def get_outlines_guided_decoding_logits_processor(
     guide, mode = _get_guide_and_mode(request)
     if not guide or not mode:
         return None
-
-    if global_thread_pool is None:
-        global_thread_pool = concurrent.futures.ThreadPoolExecutor(
-            max_workers=2)
-    loop = asyncio.get_running_loop()
-
-    return await loop.run_in_executor(global_thread_pool,
-                                      _get_logits_processor, guide, tokenizer,
-                                      mode, request.guided_whitespace_pattern)
+    
+    return OutlinesDecodingLogitsProcessorFactory(
+        guide, tokenizer, mode, request.guided_whitespace_pattern)
 
 
 def _get_guide_and_mode(
diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py
index 7eee599473a1..d2e22b97832d 100644
--- a/vllm/model_executor/layers/logits_processor.py
+++ b/vllm/model_executor/layers/logits_processor.py
@@ -95,13 +95,13 @@ def _apply_logits_processors(
     logits_processed = 0
     for seq_group in sampling_metadata.seq_groups:
         seq_ids = seq_group.seq_ids
-        sampling_params = seq_group.sampling_params
-        logits_processors = sampling_params.logits_processors
-        if logits_processors:
-            found_logits_processors = True
 
-            for seq_id, logits_row_idx in zip(seq_ids,
-                                              seq_group.sample_indices):
+        for seq_id, logits_row_idx in zip(seq_ids, seq_group.sample_indices):
+            logits_processors = seq_group.seq_data[seq_id].logits_processors
+
+            if logits_processors:
+                found_logits_processors = True
+
                 logits_row = logits[logits_row_idx]
                 past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids
                 prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 9d8a361353e2..a82a46e42b92 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -1,5 +1,6 @@
 """Sampling parameters for text generation."""
 import copy
+from abc import ABC, abstractmethod
 from enum import IntEnum
 from functools import cached_property
 from typing import Any, Callable, Dict, List, Optional, Union
@@ -28,6 +29,29 @@ class SamplingType(IntEnum):
     to sample from."""
 
 
+class LogitsProcessorFactory(ABC):
+    """Factory for logits processors.
+
+    A factory allows the implementation of stateful logits processors by
+    making sure that each sequence gets its own instance.
+    While a stateless LogitsProcessor callable can be shared between multiple
+    sequences, processors that have internal state that depends on the tokens
+    seen so far inherently can't be shared.
+
+    For logits processors with an expensive initialization, it is
+    recommended to override the get_processor_async method to build the
+    object asynchronously, for example using a thread pool, to prevent the
+    event loop from being blocked.
+    """
+
+    @abstractmethod
+    def get_processor(self) -> LogitsProcessor:
+        ...
+
+    async def get_processor_async(self) -> LogitsProcessor:
+        return self.get_processor()
+
+
 class SamplingParams:
     """Sampling parameters for text generation.
@@ -132,7 +156,8 @@ def __init__( detokenize: bool = True, skip_special_tokens: bool = True, spaces_between_special_tokens: bool = True, - logits_processors: Optional[List[LogitsProcessor]] = None, + logits_processors: Optional[List[Union[ + LogitsProcessor, LogitsProcessorFactory]]] = None, truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None, ) -> None: self.n = n @@ -316,6 +341,20 @@ def clone(self) -> "SamplingParams": } return copy.deepcopy(self, memo=logit_processor_refs) + def get_logits_processors(self): + return [] if not self.logits_processors else [ + lp.get_processor() + if isinstance(lp, LogitsProcessorFactory) else lp + for lp in self.logits_processors + ] + + async def get_logits_processors_async(self): + return [] if not self.logits_processors else [ + await lp.get_processor_async() if isinstance( + lp, LogitsProcessorFactory) else lp + for lp in self.logits_processors + ] + def __repr__(self) -> str: return ( f"SamplingParams(n={self.n}, " diff --git a/vllm/sequence.py b/vllm/sequence.py index 2f27bf33b166..28deff29c06e 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -11,7 +11,7 @@ from vllm.inputs import LLMInputs from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import LogitsProcessor, SamplingParams if TYPE_CHECKING: from vllm.multimodal import MultiModalData @@ -129,6 +129,7 @@ def __init__( # The number of tokens that are computed (that run against the model). self._num_computed_tokens = 0 self._stage: SequenceStage = SequenceStage.PREFILL + self.logits_processors: List[LogitsProcessor] = [] def append_token_id(self, token_id: int, logprob: float) -> None: self.output_token_ids.append(token_id) From 60fd27a519b663cef86b0cec6585beae38c75006 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Mon, 10 Jun 2024 11:32:18 -0300 Subject: [PATCH 2/4] fix unit tests Signed-off-by: Max de Bayser --- tests/entrypoints/test_guided_processors.py | 10 ++++++---- tests/test_logits_processor.py | 4 +++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/tests/entrypoints/test_guided_processors.py b/tests/entrypoints/test_guided_processors.py index 4e3893b00c12..4c56a649edbe 100644 --- a/tests/entrypoints/test_guided_processors.py +++ b/tests/entrypoints/test_guided_processors.py @@ -89,9 +89,10 @@ async def test_guided_logits_processor_black_box(backend: str): regex_request = CompletionRequest(model='test', prompt=token_ids, guided_regex=TEST_REGEX) - regex_lp = await get_guided_decoding_logits_processor_factory( + regex_lpf = await get_guided_decoding_logits_processor_factory( backend, regex_request, tokenizer) - assert regex_lp is not None + assert regex_lpf is not None + regex_lp = await regex_lpf.get_processor_async() tensor = torch.rand(32000) original_tensor = torch.clone(tensor) tensor = regex_lp(token_ids, tensor) @@ -103,9 +104,10 @@ async def test_guided_logits_processor_black_box(backend: str): json_request = CompletionRequest(model='test', prompt=token_ids, guided_json=TEST_SCHEMA) - json_lp = await get_guided_decoding_logits_processor_factory( + json_lpf = await get_guided_decoding_logits_processor_factory( backend, json_request, tokenizer) - assert json_lp is not None + assert json_lpf is not None + json_lp = await json_lpf.get_processor_async() tensor = torch.rand(32000) original_tensor = torch.clone(tensor) tensor = json_lp(token_ids, tensor) diff --git a/tests/test_logits_processor.py b/tests/test_logits_processor.py 
index 4ee980505a3a..321f2d8fc947 100644 --- a/tests/test_logits_processor.py +++ b/tests/test_logits_processor.py @@ -65,11 +65,13 @@ def pick_ith(token_ids, logits): seq_group_metadata_list = [] seq_lens = [] for i in range(batch_size): + seq_data = {0: SequenceData([1, 2, 3])} + seq_data[0].logits_processors = [pick_ith] seq_group_metadata_list.append( SequenceGroupMetadata( request_id=f"test_{i}", is_prompt=True, - seq_data={0: SequenceData([1, 2, 3])}, + seq_data=seq_data, sampling_params=SamplingParams(temperature=0, logits_processors=[pick_ith]), block_tables={0: [1]}, From 501554998cbb522af0a52dc65c5ac34af7a54398 Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Mon, 10 Jun 2024 12:18:48 -0300 Subject: [PATCH 3/4] Try to fix yapf error Signed-off-by: Max de Bayser --- .../guided_decoding/outlines_decoding.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index 3b0d49480074..55d635db7f8b 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -51,6 +51,7 @@ class GuidedDecodingMode(Enum): global_thread_pool = None # used for generating logits processor fsm + class OutlinesDecodingLogitsProcessorFactory(LogitsProcessorFactory): def __init__(self, guide: str, tokenizer: PreTrainedTokenizerBase, @@ -63,7 +64,7 @@ def __init__(self, guide: str, tokenizer: PreTrainedTokenizerBase, def get_processor(self) -> LogitsProcessor: return _get_logits_processor(self.guide, self.tokenizer, self.mode, - self.whitespace_pattern) + self.whitespace_pattern) async def get_processor_async(self) -> LogitsProcessor: global global_thread_pool @@ -73,9 +74,9 @@ async def get_processor_async(self) -> LogitsProcessor: loop = asyncio.get_running_loop() return await loop.run_in_executor(global_thread_pool, - _get_logits_processor, - self.guide, self.tokenizer, - self.mode, self.whitespace_pattern) + _get_logits_processor, self.guide, + self.tokenizer, self.mode, + self.whitespace_pattern) async def get_outlines_guided_decoding_logits_processor( @@ -92,7 +93,7 @@ async def get_outlines_guided_decoding_logits_processor( guide, mode = _get_guide_and_mode(request) if not guide or not mode: return None - + return OutlinesDecodingLogitsProcessorFactory( guide, tokenizer, mode, request.guided_whitespace_pattern) From b01b4fec6eb2e708834c0325789ee6f286b36b9d Mon Sep 17 00:00:00 2001 From: Max de Bayser Date: Thu, 4 Jul 2024 17:09:40 -0300 Subject: [PATCH 4/4] fix merge problem Signed-off-by: Max de Bayser --- vllm/engine/async_llm_engine.py | 9 +++++++-- vllm/engine/llm_engine.py | 1 - 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 87b6492c8e04..5d2df34404d6 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -317,8 +317,13 @@ async def add_request_async( seq.data.logits_processors = \ await params.get_logits_processors_async() - # Add the sequence group to the scheduler. - self.scheduler.add_seq_group(seq_group) + # Add the sequence group to the scheduler with least unfinished seqs. 
+ costs = [ + scheduler.get_num_unfinished_seq_groups() + for scheduler in self.scheduler + ] + min_cost_scheduler = self.scheduler[costs.index(min(costs))] + min_cost_scheduler.add_seq_group(seq_group) async def check_health_async(self) -> None: if self.tokenizer: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index dc0fd176b076..1cc639c282aa 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -521,7 +521,6 @@ def _create_sequence_group( return seq_group - def stop_remote_worker_execution_loop(self) -> None: self.model_executor.stop_remote_worker_execution_loop()
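Usage illustration (not part of the patch series): the sketch below shows how a downstream caller could implement a stateful logits processor against the new LogitsProcessorFactory interface added to vllm/sampling_params.py by these patches. NoRepeatNGramProcessor and NoRepeatNGramFactory are hypothetical names invented for this example; only LogitsProcessor, LogitsProcessorFactory, and SamplingParams come from the code above.

# Illustrative sketch only -- assumes the LogitsProcessorFactory ABC added in
# vllm/sampling_params.py by this series. NoRepeatNGramProcessor is a
# hypothetical user-defined processor, not part of the patches.
from typing import List, Set, Tuple

import torch

from vllm.sampling_params import (LogitsProcessor, LogitsProcessorFactory,
                                  SamplingParams)


class NoRepeatNGramProcessor:
    """Stateful processor: bans tokens that would repeat a seen n-gram."""

    def __init__(self, n: int):
        self.n = n
        # Per-sequence state: n-grams observed so far. This mutable set is
        # why the processor must not be shared across sequences.
        self.seen: Set[Tuple[int, ...]] = set()

    def __call__(self, token_ids: List[int],
                 logits: torch.Tensor) -> torch.Tensor:
        # Record the n-gram ending at the most recently generated token.
        if len(token_ids) >= self.n:
            self.seen.add(tuple(token_ids[-self.n:]))
        # Ban any token that would complete an already-seen n-gram.
        if self.n > 1 and len(token_ids) >= self.n - 1:
            prefix = tuple(token_ids[-(self.n - 1):])
            for ngram in self.seen:
                if ngram[:-1] == prefix:
                    logits[ngram[-1]] = float("-inf")
        return logits


class NoRepeatNGramFactory(LogitsProcessorFactory):
    """Hands each sequence its own NoRepeatNGramProcessor instance."""

    def __init__(self, n: int):
        self.n = n

    def get_processor(self) -> LogitsProcessor:
        # Called once per sequence via SamplingParams.get_logits_processors()
        # (or its async variant), so `seen` is never shared between sequences.
        return NoRepeatNGramProcessor(self.n)


# Factories are accepted alongside plain callables in logits_processors.
params = SamplingParams(temperature=0.8,
                        logits_processors=[NoRepeatNGramFactory(n=3)])

Because the engine instantiates a processor per sequence through get_processor() or get_processor_async(), two sequences in the same request never share the `seen` state, which is the property this series is designed to guarantee.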