From 99f9c1ee806174150da3a86f5dec52aebe636105 Mon Sep 17 00:00:00 2001 From: Zifei Tong Date: Tue, 30 Jul 2024 17:25:49 +0000 Subject: [PATCH 1/4] Fix max_tokens --- tests/entrypoints/openai/test_serving_chat.py | 39 +++++++++++++++++++ .../e2e/test_multistep_correctness.py | 6 ++- vllm/entrypoints/openai/protocol.py | 18 ++++++--- vllm/entrypoints/openai/serving_chat.py | 25 +++++------- vllm/entrypoints/openai/serving_completion.py | 27 +++++-------- vllm/entrypoints/openai/serving_engine.py | 17 ++++++-- 6 files changed, 88 insertions(+), 44 deletions(-) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 464465494b71..168ba7ba888e 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -1,7 +1,12 @@ import asyncio +from contextlib import suppress from dataclasses import dataclass +from unittest.mock import MagicMock +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.transformers_utils.tokenizer import get_tokenizer MODEL_NAME = "openai-community/gpt2" CHAT_TEMPLATE = "Dummy chat template for testing {}" @@ -42,3 +47,37 @@ async def _async_serving_chat_init(): def test_async_serving_chat_init(): serving_completion = asyncio.run(_async_serving_chat_init()) assert serving_completion.chat_template == CHAT_TEMPLATE + + +def test_serving_chat_should_set_correct_max_tokens(): + mock_engine = MagicMock(spec=AsyncLLMEngine) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + + serving_chat = OpenAIServingChat(mock_engine, + MockModelConfig(), + served_model_names=[MODEL_NAME], + response_role="assistant", + chat_template=CHAT_TEMPLATE, + lora_modules=None, + prompt_adapters=None, + request_logger=None) + req = ChatCompletionRequest( + model=MODEL_NAME, + messages=[{ + "role": "user", + "content": "what is 1+1?" + }], + guided_decoding_backend="outlines", + ) + + with suppress(Exception): + asyncio.run(serving_chat.create_chat_completion(req)) + + # AsyncLLMEngine.generate(inputs, sampling_params, ...) + assert mock_engine.generate.call_args.args[1].max_tokens == 93 + + req.max_tokens = 10 + with suppress(Exception): + asyncio.run(serving_chat.create_chat_completion(req)) + + assert mock_engine.generate.call_args.args[1].max_tokens == 10 diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index 86cab7aba238..ad23c08a29aa 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -35,6 +35,7 @@ """ from itertools import cycle +from typing import Optional import pytest from transformers import AutoTokenizer @@ -72,14 +73,15 @@ ]) @pytest.mark.parametrize("test_llm_kwargs", [{}]) @pytest.mark.parametrize("batch_size", [1, 32]) +@pytest.mark.parametrize("output_len", [32, None]) @pytest.mark.parametrize("seed", [1]) def test_spec_decode_e2e_with_detokenization(test_llm_generator, - batch_size: int): + batch_size: int, + output_len: Optional[int]): """Run generation with speculative decoding on a batch. Verify the engine generates the correct number of tokens (via ignore_eos=True), and that the detokenization matches HF transformers. 
""" - output_len = 32 temperature = 0.0 prompts = [ diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 205860aa8e72..8b60b7a70b12 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -11,7 +11,7 @@ from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.openai.logits_processors import get_logits_processors from vllm.pooling_params import PoolingParams -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.utils import random_uuid @@ -215,15 +215,18 @@ class ChatCompletionRequest(OpenAIBaseModel): # doc: end-chat-completion-extra-params - def to_sampling_params(self, - tokenizer: PreTrainedTokenizer) -> SamplingParams: + def to_sampling_params( + self, tokenizer: PreTrainedTokenizer, + guided_decode_logits_processor: Optional[LogitsProcessor] + ) -> SamplingParams: # We now allow logprobs being true without top_logrobs. - logits_processors = get_logits_processors( logit_bias=self.logit_bias, allowed_token_ids=None, tokenizer=tokenizer, ) + if guided_decode_logits_processor: + logits_processors.append(guided_decode_logits_processor) return SamplingParams( n=self.n, @@ -395,7 +398,10 @@ class CompletionRequest(OpenAIBaseModel): # doc: end-completion-extra-params - def to_sampling_params(self, tokenizer: PreTrainedTokenizer): + def to_sampling_params( + self, tokenizer: PreTrainedTokenizer, + guided_decode_logits_processor: Optional[LogitsProcessor] + ) -> SamplingParams: echo_without_generation = self.echo and self.max_tokens == 0 logits_processors = get_logits_processors( @@ -403,6 +409,8 @@ def to_sampling_params(self, tokenizer: PreTrainedTokenizer): allowed_token_ids=self.allowed_token_ids, tokenizer=tokenizer, ) + if guided_decode_logits_processor: + logits_processors.append(guided_decode_logits_processor) return SamplingParams( n=self.n, diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 01843930bf11..a6fe538dda74 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -25,8 +25,6 @@ PromptAdapterPath) from vllm.inputs import PromptInputs from vllm.logger import init_logger -from vllm.model_executor.guided_decoding import ( - get_guided_decoding_logits_processor) from vllm.multimodal import MultiModalDataDict from vllm.outputs import RequestOutput from vllm.sequence import Logprob @@ -134,28 +132,25 @@ async def create_chat_completion( request_id = f"chat-{random_uuid()}" try: - sampling_params = request.to_sampling_params(tokenizer) - decoding_config = await self.engine.get_decoding_config() - guided_decoding_backend = request.guided_decoding_backend \ - or decoding_config.guided_decoding_backend guided_decode_logits_processor = ( - await - get_guided_decoding_logits_processor(guided_decoding_backend, - request, tokenizer)) - if guided_decode_logits_processor: - if sampling_params.logits_processors is None: - sampling_params.logits_processors = [] - sampling_params.logits_processors.append( - guided_decode_logits_processor) + await self._guided_decode_logits_processor(request, tokenizer)) + print(request) + print("prompt", prompt) prompt_inputs = self._tokenize_prompt_input( request, tokenizer, prompt, - truncate_prompt_tokens=sampling_params.truncate_prompt_tokens, + truncate_prompt_tokens=request.truncate_prompt_tokens, add_special_tokens=request.add_special_tokens, ) + sampling_params = 
request.to_sampling_params( + tokenizer, guided_decode_logits_processor) + if sampling_params.max_tokens is None: + sampling_params.max_tokens = \ + self.max_model_len - len(prompt_inputs["prompt_token_ids"]) + self._log_inputs(request_id, prompt_inputs, params=sampling_params, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 854835279168..51de2f182655 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -24,8 +24,6 @@ OpenAIServing, PromptAdapterPath) from vllm.logger import init_logger -from vllm.model_executor.guided_decoding import ( - get_guided_decoding_logits_processor) from vllm.outputs import RequestOutput from vllm.sequence import Logprob from vllm.tracing import (contains_trace_headers, extract_trace_headers, @@ -95,31 +93,24 @@ async def create_completion(self, request: CompletionRequest, tokenizer = await self.engine.get_tokenizer(lora_request) - sampling_params = request.to_sampling_params(tokenizer) - decoding_config = await self.engine.get_decoding_config() - guided_decoding_backend = request.guided_decoding_backend \ - or decoding_config.guided_decoding_backend - guided_decode_logit_processor = ( - await - get_guided_decoding_logits_processor(guided_decoding_backend, - request, tokenizer)) - if guided_decode_logit_processor is not None: - if sampling_params.logits_processors is None: - sampling_params.logits_processors = [] - sampling_params.logits_processors.append( - guided_decode_logit_processor) - + guided_decode_logits_processor = ( + await self._guided_decode_logits_processor(request, tokenizer)) prompts = list( self._tokenize_prompt_input_or_inputs( request, tokenizer, request.prompt, - truncate_prompt_tokens=sampling_params. 
- truncate_prompt_tokens, + truncate_prompt_tokens=request.truncate_prompt_tokens, add_special_tokens=request.add_special_tokens, )) for i, prompt_inputs in enumerate(prompts): + sampling_params = request.to_sampling_params( + tokenizer, guided_decode_logits_processor) + if sampling_params.max_tokens is None: + sampling_params.max_tokens = self.max_model_len - \ + len(prompt_inputs["prompt_token_ids"]) + request_id_item = f"{request_id}-{i}" self._log_inputs(request_id_item, diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 321c9ac2c1d5..a9992f12c604 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -26,9 +26,11 @@ from vllm.inputs import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.guided_decoding import ( + get_guided_decoding_logits_processor) from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.sequence import Logprob logger = init_logger(__name__) @@ -152,6 +154,15 @@ def create_streaming_error_response( }) return json_str + async def _guided_decode_logits_processor( + self, request: Union[ChatCompletionRequest, CompletionRequest], + tokenizer: AnyTokenizer) -> Optional[LogitsProcessor]: + decoding_config = await self.engine.get_decoding_config() + guided_decoding_backend = request.guided_decoding_backend \ + or decoding_config.guided_decoding_backend + return await get_guided_decoding_logits_processor( + guided_decoding_backend, request, tokenizer) + async def _check_model( self, request: AnyRequest, @@ -256,9 +267,7 @@ def _validate_input( f"{self.max_model_len} tokens. However, you requested " f"{token_num} tokens in the messages, " f"Please reduce the length of the messages.") - request.max_tokens = self.max_model_len - token_num - - if token_num + request.max_tokens > self.max_model_len: + elif token_num + request.max_tokens > self.max_model_len: raise ValueError( f"This model's maximum context length is " f"{self.max_model_len} tokens. However, you requested " From 47c887c36ff1115eee5e0ca1f3408ac6038e83ea Mon Sep 17 00:00:00 2001 From: Zifei Tong Date: Tue, 30 Jul 2024 21:33:11 +0000 Subject: [PATCH 2/4] Revert --- tests/spec_decode/e2e/test_multistep_correctness.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py index ad23c08a29aa..86cab7aba238 100644 --- a/tests/spec_decode/e2e/test_multistep_correctness.py +++ b/tests/spec_decode/e2e/test_multistep_correctness.py @@ -35,7 +35,6 @@ """ from itertools import cycle -from typing import Optional import pytest from transformers import AutoTokenizer @@ -73,15 +72,14 @@ ]) @pytest.mark.parametrize("test_llm_kwargs", [{}]) @pytest.mark.parametrize("batch_size", [1, 32]) -@pytest.mark.parametrize("output_len", [32, None]) @pytest.mark.parametrize("seed", [1]) def test_spec_decode_e2e_with_detokenization(test_llm_generator, - batch_size: int, - output_len: Optional[int]): + batch_size: int): """Run generation with speculative decoding on a batch. Verify the engine generates the correct number of tokens (via ignore_eos=True), and that the detokenization matches HF transformers. 
""" + output_len = 32 temperature = 0.0 prompts = [ From 68ada2df94c4c832c43cf87d76dad6bc990c2957 Mon Sep 17 00:00:00 2001 From: Zifei Tong Date: Tue, 30 Jul 2024 21:34:38 +0000 Subject: [PATCH 3/4] cleanup --- vllm/entrypoints/openai/serving_chat.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index a6fe538dda74..1dee798858a1 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -135,8 +135,6 @@ async def create_chat_completion( guided_decode_logits_processor = ( await self._guided_decode_logits_processor(request, tokenizer)) - print(request) - print("prompt", prompt) prompt_inputs = self._tokenize_prompt_input( request, tokenizer, From 5069fb70d675dd93ecc487bd7101fb77d16c16c4 Mon Sep 17 00:00:00 2001 From: Zifei Tong Date: Wed, 31 Jul 2024 01:50:11 +0000 Subject: [PATCH 4/4] Address comment --- vllm/entrypoints/openai/protocol.py | 24 ++++++++++++------- vllm/entrypoints/openai/serving_chat.py | 8 +++---- vllm/entrypoints/openai/serving_completion.py | 8 +++---- 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 8b60b7a70b12..3b35ae1ebd70 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -216,9 +216,13 @@ class ChatCompletionRequest(OpenAIBaseModel): # doc: end-chat-completion-extra-params def to_sampling_params( - self, tokenizer: PreTrainedTokenizer, - guided_decode_logits_processor: Optional[LogitsProcessor] - ) -> SamplingParams: + self, tokenizer: PreTrainedTokenizer, + guided_decode_logits_processor: Optional[LogitsProcessor], + default_max_tokens: int) -> SamplingParams: + max_tokens = self.max_tokens + if max_tokens is None: + max_tokens = default_max_tokens + # We now allow logprobs being true without top_logrobs. 
logits_processors = get_logits_processors( logit_bias=self.logit_bias, @@ -244,7 +248,7 @@ def to_sampling_params( logprobs=self.top_logprobs if self.logprobs else None, prompt_logprobs=self.top_logprobs if self.echo else None, ignore_eos=self.ignore_eos, - max_tokens=self.max_tokens, + max_tokens=max_tokens, min_tokens=self.min_tokens, use_beam_search=self.use_beam_search, early_stopping=self.early_stopping, @@ -399,9 +403,13 @@ class CompletionRequest(OpenAIBaseModel): # doc: end-completion-extra-params def to_sampling_params( - self, tokenizer: PreTrainedTokenizer, - guided_decode_logits_processor: Optional[LogitsProcessor] - ) -> SamplingParams: + self, tokenizer: PreTrainedTokenizer, + guided_decode_logits_processor: Optional[LogitsProcessor], + default_max_tokens: int) -> SamplingParams: + max_tokens = self.max_tokens + if max_tokens is None: + max_tokens = default_max_tokens + echo_without_generation = self.echo and self.max_tokens == 0 logits_processors = get_logits_processors( @@ -427,7 +435,7 @@ def to_sampling_params( stop_token_ids=self.stop_token_ids, logprobs=self.logprobs, ignore_eos=self.ignore_eos, - max_tokens=self.max_tokens if not echo_without_generation else 1, + max_tokens=max_tokens if not echo_without_generation else 1, min_tokens=self.min_tokens, use_beam_search=self.use_beam_search, early_stopping=self.early_stopping, diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 1dee798858a1..c832cf2a24b5 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -144,10 +144,10 @@ async def create_chat_completion( ) sampling_params = request.to_sampling_params( - tokenizer, guided_decode_logits_processor) - if sampling_params.max_tokens is None: - sampling_params.max_tokens = \ - self.max_model_len - len(prompt_inputs["prompt_token_ids"]) + tokenizer, + guided_decode_logits_processor, + default_max_tokens=self.max_model_len - + len(prompt_inputs["prompt_token_ids"])) self._log_inputs(request_id, prompt_inputs, diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 51de2f182655..7765c5903f34 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -106,10 +106,10 @@ async def create_completion(self, request: CompletionRequest, for i, prompt_inputs in enumerate(prompts): sampling_params = request.to_sampling_params( - tokenizer, guided_decode_logits_processor) - if sampling_params.max_tokens is None: - sampling_params.max_tokens = self.max_model_len - \ - len(prompt_inputs["prompt_token_ids"]) + tokenizer, + guided_decode_logits_processor, + default_max_tokens=self.max_model_len - + len(prompt_inputs["prompt_token_ids"])) request_id_item = f"{request_id}-{i}"
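
Note (reviewer sketch, not part of the patch series): the behavioral change running
through these commits is that an omitted max_tokens is now resolved against the
remaining context window (max_model_len minus the prompt length) when building
SamplingParams, instead of being written back onto the request object inside
_validate_input. A minimal standalone illustration of that rule follows; the
function name resolve_max_tokens and the numbers in the asserts are invented for
illustration, while max_model_len and prompt_token_ids mirror the names used in
the patch.

    from typing import List, Optional


    def resolve_max_tokens(requested: Optional[int],
                           max_model_len: int,
                           prompt_token_ids: List[int]) -> int:
        """If the client omitted max_tokens, default to the tokens left in
        the context window, as serving_chat/serving_completion now do via
        the default_max_tokens argument passed to to_sampling_params."""
        if requested is not None:
            return requested
        return max_model_len - len(prompt_token_ids)


    # Hypothetical numbers: a 2048-token context and a 100-token prompt
    # leave 1948 tokens for generation when max_tokens is unset.
    assert resolve_max_tokens(None, 2048, list(range(100))) == 1948
    assert resolve_max_tokens(16, 2048, list(range(100))) == 16

This mirrors the expectation in test_serving_chat_should_set_correct_max_tokens
above: with no max_tokens on the request, the engine receives the window-derived
default, and an explicit max_tokens (10 in the test) is passed through unchanged.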