From cd3b3bbd15b3d958527e2bdcda6e210eb9ea6494 Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Thu, 8 Feb 2024 02:29:43 -0800 Subject: [PATCH 01/33] first pass for JSON and regex --- vllm/entrypoints/openai/protocol.py | 2 ++ vllm/entrypoints/openai/serving_chat.py | 6 ++++ vllm/entrypoints/openai/serving_completion.py | 28 +++++++++++++++++++ vllm/model_executor/layers/sampler.py | 6 +++- 4 files changed, 41 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index fc15b7833ecf..96cec7c6bfdd 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -66,6 +66,7 @@ class ChatCompletionRequest(BaseModel): frequency_penalty: Optional[float] = 0.0 logit_bias: Optional[Dict[str, float]] = None user: Optional[str] = None + extra_body: dict = None # for structured generation # Additional parameters supported by vLLM best_of: Optional[int] = None top_k: Optional[int] = -1 @@ -122,6 +123,7 @@ class CompletionRequest(BaseModel): best_of: Optional[int] = None logit_bias: Optional[Dict[str, float]] = None user: Optional[str] = None + extra_body: dict = None # for structured generation # Additional parameters supported by vLLM top_k: Optional[int] = -1 ignore_eos: Optional[bool] = False diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index a9e4c355560b..c9db1a4154af 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -12,6 +12,7 @@ UsageInfo) from vllm.outputs import RequestOutput from vllm.entrypoints.openai.serving_engine import OpenAIServing +from .serving_completion import get_struct_gen_logits_processor logger = init_logger(__name__) @@ -64,6 +65,11 @@ async def create_chat_completion( token_ids = self._validate_prompt_and_tokenize(request, prompt=prompt) sampling_params = request.to_sampling_params() + if request.extra_body: # check for structured generation + sampling_params.logits_processors = \ + get_struct_gen_logits_processor( + request.extra_body, + self.engine.engine.tokenizer.tokenizer) except ValueError as e: return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 191142d222ea..af3e26a516d1 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -2,6 +2,7 @@ import time from fastapi import Request from typing import AsyncGenerator, AsyncIterator, Callable, List, Optional, Dict, Tuple +from outlines.serve.vllm import JSONLogitsProcessor, RegexLogitsProcessor from vllm.logger import init_logger from vllm.utils import random_uuid from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -16,6 +17,7 @@ ) from vllm.outputs import RequestOutput from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.llm import LLM logger = init_logger(__name__) @@ -247,6 +249,27 @@ async def consumer(): return consumer() +def get_struct_gen_logits_processor(extra_body, tokenizer): + def dummy_llm(): + return LLM(model="dummy", tokenizer=tokenizer) + + if "json" in extra_body: + from pydantic import BaseModel + assert type(extra_body["json"]) in (str, dict, BaseModel), "JSON schema error" + return [JSONLogitsProcessor(extra_body["json"], dummy_llm())] + elif "regex" in extra_body: + assert type(extra_body["regex"]) is str, "Regex must be string" + return [RegexLogitsProcessor(extra_body["regex"], dummy_llm())] + elif "grammar" in extra_body: + # TODO + pass + elif "choice" in extra_body: + # TODO + pass + else: + return None + + class OpenAIServingCompletion(OpenAIServing): def __init__(self, engine: AsyncLLMEngine, served_model: str): @@ -284,6 +307,11 @@ async def create_completion(self, request: CompletionRequest, generators = [] try: sampling_params = request.to_sampling_params() + if request.extra_body: # check for structured generation + sampling_params.logits_processors = \ + get_struct_gen_logits_processor( + request.extra_body, + self.engine.engine.tokenizer.tokenizer) prompt_is_tokens, prompts = parse_prompt_format(request.prompt) for i, prompt in enumerate(prompts): diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index bc86a916b5bb..934572f1ef9c 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -1,5 +1,6 @@ """A layer that samples the next tokens from the model's outputs.""" from typing import Dict, List, Optional, Tuple +from inspect import signature as fn_signature import torch import torch.nn as nn @@ -153,7 +154,10 @@ def _apply_logits_processors( logits_row = logits[logits_row_idx] token_ids = sampling_metadata.seq_data[seq_id].output_token_ids for logits_processor in logits_processors: - logits_row = logits_processor(token_ids, logits_row) + if len(fn_signature(logits_processor).parameters) == 3: + logits_row = logits_processor(seq_id, token_ids, logits_row) + else: # params len == 2 + logits_row = logits_processor(token_ids, logits_row) logits[logits_row_idx] = logits_row logits_row_idx += 1 else: From e589bd02a27b0899b8845076118faefb40b641ab Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Thu, 8 Feb 2024 12:39:21 -0800 Subject: [PATCH 02/33] tiny refactor --- vllm/model_executor/layers/sampler.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 934572f1ef9c..429cfb9af3f7 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -150,13 +150,14 @@ def _apply_logits_processors( logits_processors = sampling_params.logits_processors if logits_processors: found_logits_processors = True + logits_processor_argc = [len(fn_signature(fn).parameters) for fn in logits_processors] for seq_id in seq_ids: logits_row = logits[logits_row_idx] token_ids = sampling_metadata.seq_data[seq_id].output_token_ids - for logits_processor in logits_processors: - if len(fn_signature(logits_processor).parameters) == 3: + for i, logits_processor in enumerate(logits_processors): + if logits_processor_argc[i] == 3: logits_row = logits_processor(seq_id, token_ids, logits_row) - else: # params len == 2 + else: # args len == 2 logits_row = logits_processor(token_ids, logits_row) logits[logits_row_idx] = logits_row logits_row_idx += 1 From 5f55e6a09acd6b9b676540ef80f06ee568fde23a Mon Sep 17 00:00:00 2001 From: br3no Date: Thu, 8 Feb 2024 16:53:11 +0100 Subject: [PATCH 03/33] Added support for guided decoding in `api_server` by integrating _outlines_ (https://github.com/outlines-dev/outlines). --- requirements-guided-decoding.txt | 1 + tests/samplers/test_sampler.py | 3 +- vllm/entrypoints/api_server.py | 37 +++++++++++++++++ vllm/model_executor/guided_decoding.py | 56 ++++++++++++++++++++++++++ vllm/sampling_params.py | 7 ++-- 5 files changed, 100 insertions(+), 4 deletions(-) create mode 100644 requirements-guided-decoding.txt create mode 100644 vllm/model_executor/guided_decoding.py diff --git a/requirements-guided-decoding.txt b/requirements-guided-decoding.txt new file mode 100644 index 000000000000..04ef481381d0 --- /dev/null +++ b/requirements-guided-decoding.txt @@ -0,0 +1 @@ +outlines == 0.0.27 \ No newline at end of file diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index d34f32d03fee..1523ca2be7fb 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -230,7 +230,8 @@ def test_sampler_logits_processors(seed: int, device: str): # This sample logits processor gives infinite score to the i-th token, # where i is the length of the input sequence. # We therefore expect the output token sequence to be [0, 1, 2, ...] - def pick_ith(token_ids, logits): + # Since this processor is stateless, the seq_id is not used + def pick_ith(token_ids, logits, seq_id): logits[len(token_ids)] = float("inf") return logits diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index f7b8d258fae4..f20b7b1fd339 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -10,6 +10,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid +from vllm.model_executor.guided_decoding import GuidedDecodingEngine, GuidedDecodingMode, get_logits_processor TIMEOUT_KEEP_ALIVE = 5 # seconds. app = FastAPI() @@ -28,6 +29,8 @@ async def generate(request: Request) -> Response: The request should be a JSON object with the following fields: - prompt: the prompt to use for the generation. + - schema: the JSON schema to use for the generation (if regex is not provided). + - regex: the regex to use for the generation (if schema is not provided). - stream: whether to stream the results or not. - other fields: the sampling parameters (See `SamplingParams` for details). """ @@ -35,6 +38,11 @@ async def generate(request: Request) -> Response: prompt = request_dict.pop("prompt") prefix_pos = request_dict.pop("prefix_pos", None) stream = request_dict.pop("stream", False) + + if args.guided_decoding_engine is not None: + # Add logits processors if guided decoding is requested + _setup_guided_decoding(request_dict) + sampling_params = SamplingParams(**request_dict) request_id = random_uuid() @@ -72,6 +80,28 @@ async def stream_results() -> AsyncGenerator[bytes, None]: return JSONResponse(ret) +def _setup_guided_decoding(request_dict): + json_schema = request_dict.pop("schema", None) + regex_string = request_dict.pop("regex", None) + + if json_schema is not None or regex_string is not None: + assert json_schema is None or regex_string is None, \ + "Only one of 'schema' and 'regex' can be provided." + + guided_decoding_engine = GuidedDecodingEngine( + args.guided_decoding_engine) + mode = GuidedDecodingMode( + "schema" if json_schema is not None else "regex") + logits_processors = [ + get_logits_processor(json_schema or regex_string, mode, + guided_decoding_engine, engine.engine) + ] + if request_dict.get("logits_processors") is None: + request_dict["logits_processors"] = logits_processors + else: + request_dict["logits_processors"].extend(logits_processors) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default=None) @@ -83,6 +113,13 @@ async def stream_results() -> AsyncGenerator[bytes, None]: type=str, default=None, help="FastAPI root_path when app is behind a path based routing proxy") + parser.add_argument( + "--guided-decoding-engine", + type=str, + default=None, + help= + "What engine for guided decoding to use. Currently only `oulines` is supported." + ) parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py new file mode 100644 index 000000000000..84845b0e241a --- /dev/null +++ b/vllm/model_executor/guided_decoding.py @@ -0,0 +1,56 @@ +from enum import Enum +import time +from typing import List, Union +try: + from outlines.serve.vllm import JSONLogitsProcessor, RegexLogitsProcessor +except ImportError: + raise ValueError("Please install 'outlines' (pip install outlines) to use guided generation.") +import torch + +from vllm.engine.llm_engine import LLMEngine +from vllm.entrypoints.llm import LLM +from vllm.sampling_params import LogitsProcessor + +class GuidedDecodingEngine(Enum): + OUTLINES = "outlines" + +class GuidedDecodingMode(Enum): + REGEX = "regex" + JSON_SCHEMA = "schema" + +class OutlinesJSONLogitsProcessor(JSONLogitsProcessor): + + def __init__(self, json_schema: dict, llm: LLM): + super().__init__(json_schema, llm) + + def __call__( + self, + input_ids: List[int], + scores: torch.Tensor, + seq_id: int, + ) -> torch.Tensor: + return super().__call__(seq_id, input_ids, scores) + + +class OulinesRegexLogitsProcessor(RegexLogitsProcessor): + + def __init__(self, regex: str, llm: LLM): + super().__init__(regex, llm) + + def __call__( + self, + input_ids: List[int], + scores: torch.Tensor, + seq_id: int, + ) -> torch.Tensor: + return super().__call__(seq_id, input_ids, scores) + + +def get_logits_processor(specification: Union[str, dict], mode: GuidedDecodingMode, engine: GuidedDecodingEngine, llm_engine: LLMEngine): + if engine == GuidedDecodingEngine.OUTLINES: + if mode == GuidedDecodingMode.JSON_SCHEMA: + return OutlinesJSONLogitsProcessor(specification, llm_engine) + elif mode == GuidedDecodingMode.REGEX: + return OulinesRegexLogitsProcessor(specification, llm_engine) + else: + raise ValueError(f"Unknown mode: {mode}") \ No newline at end of file diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index bb7d0002c910..a3c577dbd96e 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -14,10 +14,11 @@ class SamplingType(IntEnum): BEAM = 2 -LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor] +LogitsProcessor = Callable[[List[int], torch.Tensor, int], torch.Tensor] """LogitsProcessor is a function that takes a list of previously generated -tokens and a tensor of the logits for the next token, and returns a modified -tensor of logits to sample from.""" +tokens, a tensor of the logits for the next token and an integer sequence id, +and returns a modified tensor of logits to sample from. The sequence id is used +to distinguish different generations, in case the processor is stateful.""" class SamplingParams: From 3a051cf1fa1d6fc6f1c1138c3458e8a43728b3f9 Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Thu, 8 Feb 2024 17:39:26 -0800 Subject: [PATCH 04/33] refactor/combine breno's PR with mine --- tests/samplers/test_sampler.py | 2 +- vllm/entrypoints/api_server.py | 58 ++++++++----------- vllm/entrypoints/openai/protocol.py | 6 +- vllm/entrypoints/openai/serving_chat.py | 20 +++++-- vllm/entrypoints/openai/serving_completion.py | 43 +++++--------- vllm/model_executor/guided_decoding.py | 57 +++++------------- vllm/sampling_params.py | 10 ++-- 7 files changed, 80 insertions(+), 116 deletions(-) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 1523ca2be7fb..44cd9ef2a07a 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -231,7 +231,7 @@ def test_sampler_logits_processors(seed: int, device: str): # where i is the length of the input sequence. # We therefore expect the output token sequence to be [0, 1, 2, ...] # Since this processor is stateless, the seq_id is not used - def pick_ith(token_ids, logits, seq_id): + def pick_ith(token_ids, logits): logits[len(token_ids)] = float("inf") return logits diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index f20b7b1fd339..1827db24916f 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -10,7 +10,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid -from vllm.model_executor.guided_decoding import GuidedDecodingEngine, GuidedDecodingMode, get_logits_processor +from vllm.vllm.model_executor.guided_decoding import GuidedDecodingMode, get_guided_decoding_logits_processor TIMEOUT_KEEP_ALIVE = 5 # seconds. app = FastAPI() @@ -29,20 +29,18 @@ async def generate(request: Request) -> Response: The request should be a JSON object with the following fields: - prompt: the prompt to use for the generation. - - schema: the JSON schema to use for the generation (if regex is not provided). - - regex: the regex to use for the generation (if schema is not provided). - stream: whether to stream the results or not. - other fields: the sampling parameters (See `SamplingParams` for details). + - guided decoding (structured generation) can be specified with: + - guided_json: JSON schema (str, dict, Pydantic BaseModel). + - guided_regex: a regex string. """ request_dict = await request.json() prompt = request_dict.pop("prompt") prefix_pos = request_dict.pop("prefix_pos", None) stream = request_dict.pop("stream", False) - if args.guided_decoding_engine is not None: - # Add logits processors if guided decoding is requested - _setup_guided_decoding(request_dict) - + _setup_guided_decoding(request_dict) sampling_params = SamplingParams(**request_dict) request_id = random_uuid() @@ -81,25 +79,25 @@ async def stream_results() -> AsyncGenerator[bytes, None]: def _setup_guided_decoding(request_dict): - json_schema = request_dict.pop("schema", None) - regex_string = request_dict.pop("regex", None) - - if json_schema is not None or regex_string is not None: - assert json_schema is None or regex_string is None, \ - "Only one of 'schema' and 'regex' can be provided." - - guided_decoding_engine = GuidedDecodingEngine( - args.guided_decoding_engine) - mode = GuidedDecodingMode( - "schema" if json_schema is not None else "regex") - logits_processors = [ - get_logits_processor(json_schema or regex_string, mode, - guided_decoding_engine, engine.engine) - ] - if request_dict.get("logits_processors") is None: - request_dict["logits_processors"] = logits_processors - else: - request_dict["logits_processors"].extend(logits_processors) + guided_json = request_dict.pop("guided_json", None) + guided_regex = request_dict.pop("guided_regex", None) + + # if both json and regex exist, use the json + if guided_json: + logits_processors = get_guided_decoding_logits_processor( + guided_json, GuidedDecodingMode("json"), + engine.engine.tokenizer.tokenizer) + elif guided_regex: + logits_processors = get_guided_decoding_logits_processor( + guided_regex, GuidedDecodingMode("regex"), + engine.engine.tokenizer.tokenizer) + else: + return + + if request_dict.get("logits_processors") is None: + request_dict["logits_processors"] = logits_processors + else: + request_dict["logits_processors"].extend(logits_processors) if __name__ == "__main__": @@ -113,13 +111,7 @@ def _setup_guided_decoding(request_dict): type=str, default=None, help="FastAPI root_path when app is behind a path based routing proxy") - parser.add_argument( - "--guided-decoding-engine", - type=str, - default=None, - help= - "What engine for guided decoding to use. Currently only `oulines` is supported." - ) + parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 96cec7c6bfdd..2edd8804f658 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -66,7 +66,6 @@ class ChatCompletionRequest(BaseModel): frequency_penalty: Optional[float] = 0.0 logit_bias: Optional[Dict[str, float]] = None user: Optional[str] = None - extra_body: dict = None # for structured generation # Additional parameters supported by vLLM best_of: Optional[int] = None top_k: Optional[int] = -1 @@ -81,6 +80,8 @@ class ChatCompletionRequest(BaseModel): min_p: Optional[float] = 0.0 include_stop_str_in_output: Optional[bool] = False length_penalty: Optional[float] = 1.0 + guided_json: Optional[dict] = None + guided_regex: Optional[str] = None def to_sampling_params(self) -> SamplingParams: return SamplingParams( @@ -123,7 +124,6 @@ class CompletionRequest(BaseModel): best_of: Optional[int] = None logit_bias: Optional[Dict[str, float]] = None user: Optional[str] = None - extra_body: dict = None # for structured generation # Additional parameters supported by vLLM top_k: Optional[int] = -1 ignore_eos: Optional[bool] = False @@ -135,6 +135,8 @@ class CompletionRequest(BaseModel): min_p: Optional[float] = 0.0 include_stop_str_in_output: Optional[bool] = False length_penalty: Optional[float] = 1.0 + guided_json: Optional[Union[str, dict, BaseModel]] = None + guided_regex: Optional[str] = None def to_sampling_params(self): echo_without_generation = self.echo and self.max_tokens == 0 diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index c9db1a4154af..579080c4d8af 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -12,7 +12,7 @@ UsageInfo) from vllm.outputs import RequestOutput from vllm.entrypoints.openai.serving_engine import OpenAIServing -from .serving_completion import get_struct_gen_logits_processor +from vllm.vllm.model_executor.guided_decoding import GuidedDecodingMode, get_guided_decoding_logits_processor logger = init_logger(__name__) @@ -65,11 +65,7 @@ async def create_chat_completion( token_ids = self._validate_prompt_and_tokenize(request, prompt=prompt) sampling_params = request.to_sampling_params() - if request.extra_body: # check for structured generation - sampling_params.logits_processors = \ - get_struct_gen_logits_processor( - request.extra_body, - self.engine.engine.tokenizer.tokenizer) + sampling_params.logits_processors = self._get_struct_gen_logits_processor(request) except ValueError as e: return self.create_error_response(str(e)) @@ -269,3 +265,15 @@ def _load_chat_template(self, chat_template): else: logger.warning( "No chat template provided. Chat API will not work.") + + def _get_guided_decoding_logits_processor(self, request: ChatCompletionRequest): + if request.guided_json: + return get_guided_decoding_logits_processor( + request.guided_json, GuidedDecodingMode("json"), + self.engine.engine.tokenizer.tokenizer) + elif request.guided_regex: + return get_guided_decoding_logits_processor( + request.guided_regex, GuidedDecodingMode("regex"), + self.engine.engine.tokenizer.tokenizer) + else: + return None diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index af3e26a516d1..4a28966521a2 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -2,7 +2,6 @@ import time from fastapi import Request from typing import AsyncGenerator, AsyncIterator, Callable, List, Optional, Dict, Tuple -from outlines.serve.vllm import JSONLogitsProcessor, RegexLogitsProcessor from vllm.logger import init_logger from vllm.utils import random_uuid from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -17,7 +16,7 @@ ) from vllm.outputs import RequestOutput from vllm.entrypoints.openai.serving_engine import OpenAIServing -from vllm.entrypoints.llm import LLM +from vllm.vllm.model_executor.guided_decoding import GuidedDecodingMode, get_guided_decoding_logits_processor logger = init_logger(__name__) @@ -249,27 +248,6 @@ async def consumer(): return consumer() -def get_struct_gen_logits_processor(extra_body, tokenizer): - def dummy_llm(): - return LLM(model="dummy", tokenizer=tokenizer) - - if "json" in extra_body: - from pydantic import BaseModel - assert type(extra_body["json"]) in (str, dict, BaseModel), "JSON schema error" - return [JSONLogitsProcessor(extra_body["json"], dummy_llm())] - elif "regex" in extra_body: - assert type(extra_body["regex"]) is str, "Regex must be string" - return [RegexLogitsProcessor(extra_body["regex"], dummy_llm())] - elif "grammar" in extra_body: - # TODO - pass - elif "choice" in extra_body: - # TODO - pass - else: - return None - - class OpenAIServingCompletion(OpenAIServing): def __init__(self, engine: AsyncLLMEngine, served_model: str): @@ -307,11 +285,7 @@ async def create_completion(self, request: CompletionRequest, generators = [] try: sampling_params = request.to_sampling_params() - if request.extra_body: # check for structured generation - sampling_params.logits_processors = \ - get_struct_gen_logits_processor( - request.extra_body, - self.engine.engine.tokenizer.tokenizer) + sampling_params.logits_processors = self._get_guided_decoding_logits_processor(request) prompt_is_tokens, prompts = parse_prompt_format(request.prompt) for i, prompt in enumerate(prompts): @@ -375,3 +349,16 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: return fake_stream_generator() return response + + def _get_guided_decoding_logits_processor(self, request: CompletionRequest): + # should this go inside CompletionRequest.to_sampling_params() instead? + if request.guided_json: + return get_guided_decoding_logits_processor( + request.guided_json, GuidedDecodingMode("json"), + self.engine.engine.tokenizer.tokenizer) + elif request.guided_regex: + return get_guided_decoding_logits_processor( + request.guided_regex, GuidedDecodingMode("regex"), + self.engine.engine.tokenizer.tokenizer) + else: + return None \ No newline at end of file diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index 84845b0e241a..19ab62772e84 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -1,56 +1,29 @@ from enum import Enum -import time -from typing import List, Union +from typing import Union +from pydantic import BaseModel try: from outlines.serve.vllm import JSONLogitsProcessor, RegexLogitsProcessor except ImportError: raise ValueError("Please install 'outlines' (pip install outlines) to use guided generation.") -import torch -from vllm.engine.llm_engine import LLMEngine from vllm.entrypoints.llm import LLM -from vllm.sampling_params import LogitsProcessor -class GuidedDecodingEngine(Enum): - OUTLINES = "outlines" class GuidedDecodingMode(Enum): + JSON = "json" REGEX = "regex" - JSON_SCHEMA = "schema" + # TODO: add grammar, choice -class OutlinesJSONLogitsProcessor(JSONLogitsProcessor): - def __init__(self, json_schema: dict, llm: LLM): - super().__init__(json_schema, llm) +def get_guided_decoding_logits_processor(guided_spec: Union[str, dict, BaseModel], mode: GuidedDecodingMode, tokenizer): + def dummy_llm(): + return LLM(model="dummy", tokenizer=tokenizer) - def __call__( - self, - input_ids: List[int], - scores: torch.Tensor, - seq_id: int, - ) -> torch.Tensor: - return super().__call__(seq_id, input_ids, scores) - - -class OulinesRegexLogitsProcessor(RegexLogitsProcessor): - - def __init__(self, regex: str, llm: LLM): - super().__init__(regex, llm) - - def __call__( - self, - input_ids: List[int], - scores: torch.Tensor, - seq_id: int, - ) -> torch.Tensor: - return super().__call__(seq_id, input_ids, scores) - - -def get_logits_processor(specification: Union[str, dict], mode: GuidedDecodingMode, engine: GuidedDecodingEngine, llm_engine: LLMEngine): - if engine == GuidedDecodingEngine.OUTLINES: - if mode == GuidedDecodingMode.JSON_SCHEMA: - return OutlinesJSONLogitsProcessor(specification, llm_engine) - elif mode == GuidedDecodingMode.REGEX: - return OulinesRegexLogitsProcessor(specification, llm_engine) - else: - raise ValueError(f"Unknown mode: {mode}") \ No newline at end of file + if mode == GuidedDecodingMode.JSON: + assert type(guided_spec) in (str, dict, BaseModel), "JSON schema error" + return [JSONLogitsProcessor(guided_spec, dummy_llm())] + elif mode == GuidedDecodingMode.REGEX: + assert type(guided_spec) is str, "Regex must be string" + return [RegexLogitsProcessor(guided_spec, dummy_llm())] + else: + return None \ No newline at end of file diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index a3c577dbd96e..bfd05a644580 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -14,11 +14,13 @@ class SamplingType(IntEnum): BEAM = 2 -LogitsProcessor = Callable[[List[int], torch.Tensor, int], torch.Tensor] +LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor], + Callable[[int, List[int], torch.Tensor], torch.Tensor]] """LogitsProcessor is a function that takes a list of previously generated -tokens, a tensor of the logits for the next token and an integer sequence id, -and returns a modified tensor of logits to sample from. The sequence id is used -to distinguish different generations, in case the processor is stateful.""" +tokens and a tensor of the logits for the next token, and returns a modified +tensor of logits to sample from. Some processors may also take in an integer +sequence id, which is used to distinguish different generations, in case the +processor is stateful (such as for guided decoding).""" class SamplingParams: From 54217ba85505d698d9887d27c3b5f8b3514a9ccc Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Thu, 8 Feb 2024 17:42:48 -0800 Subject: [PATCH 05/33] fix type check --- vllm/model_executor/guided_decoding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index 19ab62772e84..2f7f7945d6d1 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -20,10 +20,10 @@ def dummy_llm(): return LLM(model="dummy", tokenizer=tokenizer) if mode == GuidedDecodingMode.JSON: - assert type(guided_spec) in (str, dict, BaseModel), "JSON schema error" + assert isinstance(guided_spec, (str, dict, BaseModel)), "JSON schema error" return [JSONLogitsProcessor(guided_spec, dummy_llm())] elif mode == GuidedDecodingMode.REGEX: - assert type(guided_spec) is str, "Regex must be string" + assert isinstance(guided_spec, str), "Regex must be string" return [RegexLogitsProcessor(guided_spec, dummy_llm())] else: return None \ No newline at end of file From c9c6f4fc61baaf74bfbc2ba13910247d47de7b71 Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Thu, 8 Feb 2024 17:45:14 -0800 Subject: [PATCH 06/33] fix try-except --- vllm/model_executor/guided_decoding.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index 2f7f7945d6d1..4778ed63cb66 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -3,8 +3,10 @@ from pydantic import BaseModel try: from outlines.serve.vllm import JSONLogitsProcessor, RegexLogitsProcessor -except ImportError: - raise ValueError("Please install 'outlines' (pip install outlines) to use guided generation.") +except ImportError as e: + raise ValueError( + "Please install 'outlines' (pip install outlines) to use guided generation." + ) from e from vllm.entrypoints.llm import LLM From b82dedb00103f57371238e9c84c1cc8a0a9af412 Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Fri, 9 Feb 2024 13:50:52 -0800 Subject: [PATCH 07/33] fix import bug --- vllm/entrypoints/api_server.py | 2 +- vllm/entrypoints/openai/serving_chat.py | 4 ++-- vllm/entrypoints/openai/serving_completion.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 1827db24916f..7abaf00373eb 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -10,7 +10,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid -from vllm.vllm.model_executor.guided_decoding import GuidedDecodingMode, get_guided_decoding_logits_processor +from vllm.model_executor.guided_decoding import GuidedDecodingMode, get_guided_decoding_logits_processor TIMEOUT_KEEP_ALIVE = 5 # seconds. app = FastAPI() diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 579080c4d8af..e82ea33b45ff 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -12,7 +12,7 @@ UsageInfo) from vllm.outputs import RequestOutput from vllm.entrypoints.openai.serving_engine import OpenAIServing -from vllm.vllm.model_executor.guided_decoding import GuidedDecodingMode, get_guided_decoding_logits_processor +from vllm.model_executor.guided_decoding import GuidedDecodingMode, get_guided_decoding_logits_processor logger = init_logger(__name__) @@ -65,7 +65,7 @@ async def create_chat_completion( token_ids = self._validate_prompt_and_tokenize(request, prompt=prompt) sampling_params = request.to_sampling_params() - sampling_params.logits_processors = self._get_struct_gen_logits_processor(request) + sampling_params.logits_processors = self._get_guided_decoding_logits_processor(request) except ValueError as e: return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 4a28966521a2..fe3a709dc8fc 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -16,7 +16,7 @@ ) from vllm.outputs import RequestOutput from vllm.entrypoints.openai.serving_engine import OpenAIServing -from vllm.vllm.model_executor.guided_decoding import GuidedDecodingMode, get_guided_decoding_logits_processor +from vllm.model_executor.guided_decoding import GuidedDecodingMode, get_guided_decoding_logits_processor logger = init_logger(__name__) From ba92cb234f483c28b29dd4621151bc7fd7ebd097 Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Fri, 9 Feb 2024 13:59:32 -0800 Subject: [PATCH 08/33] add outlines v0.0.27 requirement --- requirements-guided-decoding.txt | 1 - requirements.txt | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) delete mode 100644 requirements-guided-decoding.txt diff --git a/requirements-guided-decoding.txt b/requirements-guided-decoding.txt deleted file mode 100644 index 04ef481381d0..000000000000 --- a/requirements-guided-decoding.txt +++ /dev/null @@ -1 +0,0 @@ -outlines == 0.0.27 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 5684b2c29634..b8d99009847e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,3 +12,4 @@ pydantic >= 2.0 # Required for OpenAI server. aioprometheus[starlette] pynvml == 11.5.0 triton >= 2.1.0 +outlines == 0.0.27 From da2f5b8fa49f0155e93ba9a03a86406448662e0f Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Sat, 10 Feb 2024 15:48:16 -0800 Subject: [PATCH 09/33] fix dummy_llm --- vllm/model_executor/guided_decoding.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index 4778ed63cb66..9a341d864d30 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -1,5 +1,6 @@ from enum import Enum from typing import Union +from types import SimpleNamespace from pydantic import BaseModel try: from outlines.serve.vllm import JSONLogitsProcessor, RegexLogitsProcessor @@ -8,8 +9,6 @@ "Please install 'outlines' (pip install outlines) to use guided generation." ) from e -from vllm.entrypoints.llm import LLM - class GuidedDecodingMode(Enum): JSON = "json" @@ -19,7 +18,11 @@ class GuidedDecodingMode(Enum): def get_guided_decoding_logits_processor(guided_spec: Union[str, dict, BaseModel], mode: GuidedDecodingMode, tokenizer): def dummy_llm(): - return LLM(model="dummy", tokenizer=tokenizer) + x = SimpleNamespace() + y = SimpleNamespace() + x.tokenizer = tokenizer + y.tokenizer = x + return y if mode == GuidedDecodingMode.JSON: assert isinstance(guided_spec, (str, dict, BaseModel)), "JSON schema error" From b090c18f852d7da3d150603383704db437e1eee1 Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Sat, 10 Feb 2024 17:39:42 -0800 Subject: [PATCH 10/33] start adding tests --- .../test_openai_server_guided_decoding.py | 163 ++++++++++++++++++ vllm/model_executor/guided_decoding.py | 6 +- 2 files changed, 167 insertions(+), 2 deletions(-) create mode 100644 tests/entrypoints/test_openai_server_guided_decoding.py diff --git a/tests/entrypoints/test_openai_server_guided_decoding.py b/tests/entrypoints/test_openai_server_guided_decoding.py new file mode 100644 index 000000000000..5e77047f02df --- /dev/null +++ b/tests/entrypoints/test_openai_server_guided_decoding.py @@ -0,0 +1,163 @@ +import os +import subprocess +import time + +import sys +import pytest +import requests +import ray # using Ray for overall ease of process management, parallel requests, and debugging. +import openai # use the official client for correctness check + +import json +import jsonschema + +MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds +MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # any model with a chat template should work here + +TEST_SCHEMA = { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "age": { + "type": "integer" + }, + "skills": { + "type": "array", + "items": { + "type": "string", + "maxLength": 10 + }, + "minItems": 3 + }, + "work history": { + "type": "array", + "items": { + "type": "object", + "properties": { + "company": { + "type": "string" + }, + "duration": { + "type": "string" + }, + "position": { + "type": "string" + } + }, + "required": [ + "company", + "position" + ] + } + } + }, + "required": [ + "name", + "age", + "skills", + "work history" + ] +} + +pytestmark = pytest.mark.asyncio + + +@ray.remote(num_gpus=1) +class ServerRunner: + + def __init__(self, args): + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" + self.proc = subprocess.Popen( + ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args, + env=env, + stdout=sys.stdout, + stderr=sys.stderr, + ) + self._wait_for_server() + + def ready(self): + return True + + def _wait_for_server(self): + # run health check + start = time.time() + while True: + try: + if requests.get( + "http://localhost:8000/health").status_code == 200: + break + except Exception as err: + if self.proc.poll() is not None: + raise RuntimeError("Server exited unexpectedly.") from err + + time.sleep(0.5) + if time.time() - start > MAX_SERVER_START_WAIT_S: + raise RuntimeError( + "Server failed to start in time.") from err + + def __del__(self): + if hasattr(self, "proc"): + self.proc.terminate() + + +@pytest.fixture(scope="session") +def server(): + ray.init() + server_runner = ServerRunner.remote([ + "--model", + MODEL_NAME, + "--dtype", + "bfloat16", # use half precision for speed and memory savings in CI environment + "--max-model-len", + "8192", + "--enforce-eager", + ]) + ray.get(server_runner.ready.remote()) + yield server_runner + ray.shutdown() + + +@pytest.fixture(scope="session") +def client(): + client = openai.AsyncOpenAI( + base_url="http://localhost:8000/v1", + api_key="token-abc123", + ) + yield client + + +async def test_guided_json_completion(server, client: openai.AsyncOpenAI): + completion = await client.completions.create( + model=MODEL_NAME, + prompt=f"Give an example JSON for an employee profile that fits this schema: {TEST_SCHEMA}", + temperature=0.0, + extra_body=dict( + guided_json=TEST_SCHEMA + ) + ) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + assert completion.choices[0].text is not None + output_json = json.loads(completion.choices[0].text) + jsonschema.validate(instance=output_json, schema=TEST_SCHEMA) + + +async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI): + with pytest.raises(TypeError): + _ = await client.completions.create( + model=MODEL_NAME, + prompt="Give an example JSON that fits this schema: 42", + temperature=0.0, + extra_body=dict( + guided_json=42 + ) + ) + + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index 9a341d864d30..2e884d86d270 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -25,10 +25,12 @@ def dummy_llm(): return y if mode == GuidedDecodingMode.JSON: - assert isinstance(guided_spec, (str, dict, BaseModel)), "JSON schema error" + if not isinstance(guided_spec, (str, dict, BaseModel)): + raise TypeError("JSON schema must be str, dict, or BaseModel") return [JSONLogitsProcessor(guided_spec, dummy_llm())] elif mode == GuidedDecodingMode.REGEX: - assert isinstance(guided_spec, str), "Regex must be string" + if not isinstance(guided_spec, str): + raise TypeError("Regex must be string") return [RegexLogitsProcessor(guided_spec, dummy_llm())] else: return None \ No newline at end of file From 9093c5e6bbad4774ad3bd181c5f6505550028dcd Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Sun, 11 Feb 2024 02:08:13 -0800 Subject: [PATCH 11/33] add more tests --- .../test_openai_server_guided_decoding.py | 132 +++++++++++++++++- vllm/entrypoints/openai/protocol.py | 2 +- 2 files changed, 126 insertions(+), 8 deletions(-) diff --git a/tests/entrypoints/test_openai_server_guided_decoding.py b/tests/entrypoints/test_openai_server_guided_decoding.py index 5e77047f02df..609c3663362e 100644 --- a/tests/entrypoints/test_openai_server_guided_decoding.py +++ b/tests/entrypoints/test_openai_server_guided_decoding.py @@ -10,6 +10,7 @@ import json import jsonschema +import re MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # any model with a chat template should work here @@ -60,6 +61,10 @@ "work history" ] } +# NOTE: outlines' underlying regex library (interegular) doesn't support +# ^ or $ or \b, kinda annoying +TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \ + "(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)" pytestmark = pytest.mark.asyncio @@ -105,7 +110,7 @@ def __del__(self): @pytest.fixture(scope="session") def server(): - ray.init() + ray.init(ignore_reinit_error=True) server_runner = ServerRunner.remote([ "--model", MODEL_NAME, @@ -133,21 +138,125 @@ async def test_guided_json_completion(server, client: openai.AsyncOpenAI): completion = await client.completions.create( model=MODEL_NAME, prompt=f"Give an example JSON for an employee profile that fits this schema: {TEST_SCHEMA}", - temperature=0.0, + n=3, + temperature=1.0, + max_tokens=500, extra_body=dict( guided_json=TEST_SCHEMA ) ) assert completion.id is not None - assert completion.choices is not None and len(completion.choices) == 1 - assert completion.choices[0].text is not None - output_json = json.loads(completion.choices[0].text) - jsonschema.validate(instance=output_json, schema=TEST_SCHEMA) + assert completion.choices is not None and len(completion.choices) == 3 + for i in range(3): + assert completion.choices[i].text is not None + output_json = json.loads(completion.choices[i].text) + jsonschema.validate(instance=output_json, schema=TEST_SCHEMA) + + +async def test_guided_json_chat(server, client: openai.AsyncOpenAI): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "Give an example JSON for an employee profile that " + \ + f"fits this schema: {TEST_SCHEMA}" + }] + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=500, + extra_body=dict( + guided_json=TEST_SCHEMA + ) + ) + message = chat_completion.choices[0].message + assert message.content is not None + json1 = json.loads(message.content) + jsonschema.validate(instance=json1, schema=TEST_SCHEMA) + + messages.append({"role": "assistant", "content": message.content}) + messages.append({ + "role": "user", + "content": "Give me another one with a different name and age" + }) + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=500, + extra_body=dict( + guided_json=TEST_SCHEMA + ) + ) + message = chat_completion.choices[0].message + assert message.content is not None + json2 = json.loads(message.content) + jsonschema.validate(instance=json2, schema=TEST_SCHEMA) + assert json1["name"] != json2["name"] + assert json1["age"] != json2["age"] + + +async def test_guided_regex_completion(server, client: openai.AsyncOpenAI): + completion = await client.completions.create( + model=MODEL_NAME, + prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}", + n=3, + temperature=1.0, + max_tokens=20, + extra_body=dict( + guided_regex=TEST_REGEX + ) + ) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 3 + for i in range(3): + assert completion.choices[i].text is not None + assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None + + +async def test_guided_regex_chat(server, client: openai.AsyncOpenAI): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": f"Give an example IP address with this regex: {TEST_REGEX}" + }] + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=20, + extra_body=dict( + guided_regex=TEST_REGEX + ) + ) + ip1 = chat_completion.choices[0].message.content + assert ip1 is not None + assert re.fullmatch(TEST_REGEX, ip1) is not None + + messages.append({"role": "assistant", "content": ip1}) + messages.append({ + "role": "user", + "content": "Give me a different one" + }) + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=20, + extra_body=dict( + guided_regex=TEST_REGEX + ) + ) + ip2 = chat_completion.choices[0].message.content + assert ip2 is not None + assert re.fullmatch(TEST_REGEX, ip2) is not None + assert ip1 != ip2 async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI): - with pytest.raises(TypeError): + with pytest.raises(Exception): _ = await client.completions.create( model=MODEL_NAME, prompt="Give an example JSON that fits this schema: 42", @@ -157,6 +266,15 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI): ) ) + with pytest.raises(Exception): + _ = await client.completions.create( + model=MODEL_NAME, + prompt="Give an example string that fits this regex: True", + temperature=0.0, + extra_body=dict( + guided_regex=True + ) + ) if __name__ == "__main__": diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 2edd8804f658..4212904d03f7 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -80,7 +80,7 @@ class ChatCompletionRequest(BaseModel): min_p: Optional[float] = 0.0 include_stop_str_in_output: Optional[bool] = False length_penalty: Optional[float] = 1.0 - guided_json: Optional[dict] = None + guided_json: Optional[Union[str, dict, BaseModel]] = None guided_regex: Optional[str] = None def to_sampling_params(self) -> SamplingParams: From 1efd64d6be42ed4a818340c41dd5cf4c79afe543 Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Mon, 12 Feb 2024 19:03:44 -0800 Subject: [PATCH 12/33] fix pytest fixtures scope --- tests/entrypoints/test_openai_server.py | 4 ++-- tests/entrypoints/test_openai_server_guided_decoding.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 54522f0a99fa..f1c758bc8e9a 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -53,7 +53,7 @@ def __del__(self): self.proc.terminate() -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def server(): ray.init() server_runner = ServerRunner.remote([ @@ -70,7 +70,7 @@ def server(): ray.shutdown() -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def client(): client = openai.AsyncOpenAI( base_url="http://localhost:8000/v1", diff --git a/tests/entrypoints/test_openai_server_guided_decoding.py b/tests/entrypoints/test_openai_server_guided_decoding.py index 609c3663362e..d6b5f6973fc2 100644 --- a/tests/entrypoints/test_openai_server_guided_decoding.py +++ b/tests/entrypoints/test_openai_server_guided_decoding.py @@ -108,9 +108,9 @@ def __del__(self): self.proc.terminate() -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def server(): - ray.init(ignore_reinit_error=True) + ray.init() server_runner = ServerRunner.remote([ "--model", MODEL_NAME, @@ -125,7 +125,7 @@ def server(): ray.shutdown() -@pytest.fixture(scope="session") +@pytest.fixture(scope="module") def client(): client = openai.AsyncOpenAI( base_url="http://localhost:8000/v1", From 736ca318f1ad50ff3a886d697fddc5d23573b75a Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Tue, 13 Feb 2024 14:10:44 -0800 Subject: [PATCH 13/33] remove guided decoding from vllm api server --- .../test_openai_server_guided_decoding.py | 2 +- vllm/entrypoints/api_server.py | 27 ------------------- 2 files changed, 1 insertion(+), 28 deletions(-) diff --git a/tests/entrypoints/test_openai_server_guided_decoding.py b/tests/entrypoints/test_openai_server_guided_decoding.py index d6b5f6973fc2..561a4d2839e7 100644 --- a/tests/entrypoints/test_openai_server_guided_decoding.py +++ b/tests/entrypoints/test_openai_server_guided_decoding.py @@ -64,7 +64,7 @@ # NOTE: outlines' underlying regex library (interegular) doesn't support # ^ or $ or \b, kinda annoying TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \ - "(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)" + r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)" pytestmark = pytest.mark.asyncio diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 7abaf00373eb..4e3e9ff7d746 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -10,7 +10,6 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid -from vllm.model_executor.guided_decoding import GuidedDecodingMode, get_guided_decoding_logits_processor TIMEOUT_KEEP_ALIVE = 5 # seconds. app = FastAPI() @@ -31,16 +30,12 @@ async def generate(request: Request) -> Response: - prompt: the prompt to use for the generation. - stream: whether to stream the results or not. - other fields: the sampling parameters (See `SamplingParams` for details). - - guided decoding (structured generation) can be specified with: - - guided_json: JSON schema (str, dict, Pydantic BaseModel). - - guided_regex: a regex string. """ request_dict = await request.json() prompt = request_dict.pop("prompt") prefix_pos = request_dict.pop("prefix_pos", None) stream = request_dict.pop("stream", False) - _setup_guided_decoding(request_dict) sampling_params = SamplingParams(**request_dict) request_id = random_uuid() @@ -78,28 +73,6 @@ async def stream_results() -> AsyncGenerator[bytes, None]: return JSONResponse(ret) -def _setup_guided_decoding(request_dict): - guided_json = request_dict.pop("guided_json", None) - guided_regex = request_dict.pop("guided_regex", None) - - # if both json and regex exist, use the json - if guided_json: - logits_processors = get_guided_decoding_logits_processor( - guided_json, GuidedDecodingMode("json"), - engine.engine.tokenizer.tokenizer) - elif guided_regex: - logits_processors = get_guided_decoding_logits_processor( - guided_regex, GuidedDecodingMode("regex"), - engine.engine.tokenizer.tokenizer) - else: - return - - if request_dict.get("logits_processors") is None: - request_dict["logits_processors"] = logits_processors - else: - request_dict["logits_processors"].extend(logits_processors) - - if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default=None) From 4d1b04941bebd1d58f3644cbfb9377a69a27e0a2 Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Tue, 13 Feb 2024 23:21:05 -0800 Subject: [PATCH 14/33] refactor + add guided_choice --- requirements.txt | 2 +- tests/entrypoints/test_openai_server.py | 262 +++++++++++++++- .../test_openai_server_guided_decoding.py | 281 ------------------ vllm/engine/async_llm_engine.py | 3 + vllm/entrypoints/openai/protocol.py | 2 + vllm/entrypoints/openai/serving_chat.py | 18 +- vllm/entrypoints/openai/serving_completion.py | 19 +- vllm/model_executor/guided_decoding.py | 46 ++- 8 files changed, 307 insertions(+), 326 deletions(-) delete mode 100644 tests/entrypoints/test_openai_server_guided_decoding.py diff --git a/requirements.txt b/requirements.txt index b8d99009847e..66468210c641 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,4 @@ pydantic >= 2.0 # Required for OpenAI server. aioprometheus[starlette] pynvml == 11.5.0 triton >= 2.1.0 -outlines == 0.0.27 +outlines >= 0.0.27 diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index f1c758bc8e9a..25b0413a3dff 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -8,9 +8,69 @@ import ray # using Ray for overall ease of process management, parallel requests, and debugging. import openai # use the official client for correctness check +# imports for guided decoding tests +import json +import jsonschema +import re + MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # any model with a chat template should work here +TEST_SCHEMA = { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "age": { + "type": "integer" + }, + "skills": { + "type": "array", + "items": { + "type": "string", + "maxLength": 10 + }, + "minItems": 3 + }, + "work history": { + "type": "array", + "items": { + "type": "object", + "properties": { + "company": { + "type": "string" + }, + "duration": { + "type": "string" + }, + "position": { + "type": "string" + } + }, + "required": [ + "company", + "position" + ] + } + } + }, + "required": [ + "name", + "age", + "skills", + "work history" + ] +} + +# NOTE: outlines' underlying regex library (interegular) doesn't support +# ^ or $ or \b, kinda annoying +TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \ + r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)" + +TEST_CHOICE = ["Python", "Java", "JavaScript", "C++", "C#", + "PHP", "TypeScript", "Ruby", "Swift", "Kotlin"] + pytestmark = pytest.mark.asyncio @@ -53,7 +113,7 @@ def __del__(self): self.proc.terminate() -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def server(): ray.init() server_runner = ServerRunner.remote([ @@ -70,7 +130,7 @@ def server(): ray.shutdown() -@pytest.fixture(scope="module") +@pytest.fixture(scope="session") def client(): client = openai.AsyncOpenAI( base_url="http://localhost:8000/v1", @@ -250,5 +310,203 @@ async def test_batch_completions(server, client: openai.AsyncOpenAI): assert texts[0] == texts[1] +async def test_guided_json_completion(server, client: openai.AsyncOpenAI): + completion = await client.completions.create( + model=MODEL_NAME, + prompt=f"Give an example JSON for an employee profile that fits this schema: {TEST_SCHEMA}", + n=3, + temperature=1.0, + max_tokens=500, + extra_body=dict( + guided_json=TEST_SCHEMA + ) + ) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 3 + for i in range(3): + assert completion.choices[i].text is not None + output_json = json.loads(completion.choices[i].text) + jsonschema.validate(instance=output_json, schema=TEST_SCHEMA) + + +async def test_guided_json_chat(server, client: openai.AsyncOpenAI): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "Give an example JSON for an employee profile that " + \ + f"fits this schema: {TEST_SCHEMA}" + }] + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=500, + extra_body=dict( + guided_json=TEST_SCHEMA + ) + ) + message = chat_completion.choices[0].message + assert message.content is not None + json1 = json.loads(message.content) + jsonschema.validate(instance=json1, schema=TEST_SCHEMA) + + messages.append({"role": "assistant", "content": message.content}) + messages.append({ + "role": "user", + "content": "Give me another one with a different name and age" + }) + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=500, + extra_body=dict( + guided_json=TEST_SCHEMA + ) + ) + message = chat_completion.choices[0].message + assert message.content is not None + json2 = json.loads(message.content) + jsonschema.validate(instance=json2, schema=TEST_SCHEMA) + assert json1["name"] != json2["name"] + assert json1["age"] != json2["age"] + + +async def test_guided_regex_completion(server, client: openai.AsyncOpenAI): + completion = await client.completions.create( + model=MODEL_NAME, + prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}", + n=3, + temperature=1.0, + max_tokens=20, + extra_body=dict( + guided_regex=TEST_REGEX + ) + ) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 3 + for i in range(3): + assert completion.choices[i].text is not None + assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None + + +async def test_guided_regex_chat(server, client: openai.AsyncOpenAI): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": f"Give an example IP address with this regex: {TEST_REGEX}" + }] + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=20, + extra_body=dict( + guided_regex=TEST_REGEX + ) + ) + ip1 = chat_completion.choices[0].message.content + assert ip1 is not None + assert re.fullmatch(TEST_REGEX, ip1) is not None + + messages.append({"role": "assistant", "content": ip1}) + messages.append({ + "role": "user", + "content": "Give me a different one" + }) + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=20, + extra_body=dict( + guided_regex=TEST_REGEX + ) + ) + ip2 = chat_completion.choices[0].message.content + assert ip2 is not None + assert re.fullmatch(TEST_REGEX, ip2) is not None + assert ip1 != ip2 + + +async def test_guided_choice_completion(server, client: openai.AsyncOpenAI): + completion = await client.completions.create( + model=MODEL_NAME, + prompt="The best language for type-safe systems programming is ", + n=2, + temperature=1.0, + max_tokens=10, + extra_body=dict( + guided_choice=TEST_CHOICE + ) + ) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 2 + for i in range(2): + assert completion.choices[i].text in TEST_CHOICE + + +async def test_guided_choice_chat(server, client: openai.AsyncOpenAI): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "The best language for type-safe systems programming is " + }] + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=10, + extra_body=dict( + guided_choice=TEST_CHOICE + ) + ) + choice1 = chat_completion.choices[0].message.content + assert choice1 in TEST_CHOICE + + messages.append({"role": "assistant", "content": choice1}) + messages.append({ + "role": "user", + "content": "I disagree, pick another one" + }) + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=10, + extra_body=dict( + guided_choice=TEST_CHOICE + ) + ) + choice2 = chat_completion.choices[0].message.content + assert choice2 in TEST_CHOICE + assert choice1 != choice2 + + +async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI): + with pytest.raises(Exception): + _ = await client.completions.create( + model=MODEL_NAME, + prompt="Give an example JSON that fits this schema: 42", + temperature=0.0, + extra_body=dict( + guided_json=42 + ) + ) + + with pytest.raises(Exception): + _ = await client.completions.create( + model=MODEL_NAME, + prompt="Give an example string that fits this regex: True", + temperature=0.0, + extra_body=dict( + guided_regex=True + ) + ) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/entrypoints/test_openai_server_guided_decoding.py b/tests/entrypoints/test_openai_server_guided_decoding.py deleted file mode 100644 index 561a4d2839e7..000000000000 --- a/tests/entrypoints/test_openai_server_guided_decoding.py +++ /dev/null @@ -1,281 +0,0 @@ -import os -import subprocess -import time - -import sys -import pytest -import requests -import ray # using Ray for overall ease of process management, parallel requests, and debugging. -import openai # use the official client for correctness check - -import json -import jsonschema -import re - -MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds -MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" # any model with a chat template should work here - -TEST_SCHEMA = { - "type": "object", - "properties": { - "name": { - "type": "string" - }, - "age": { - "type": "integer" - }, - "skills": { - "type": "array", - "items": { - "type": "string", - "maxLength": 10 - }, - "minItems": 3 - }, - "work history": { - "type": "array", - "items": { - "type": "object", - "properties": { - "company": { - "type": "string" - }, - "duration": { - "type": "string" - }, - "position": { - "type": "string" - } - }, - "required": [ - "company", - "position" - ] - } - } - }, - "required": [ - "name", - "age", - "skills", - "work history" - ] -} -# NOTE: outlines' underlying regex library (interegular) doesn't support -# ^ or $ or \b, kinda annoying -TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \ - r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)" - -pytestmark = pytest.mark.asyncio - - -@ray.remote(num_gpus=1) -class ServerRunner: - - def __init__(self, args): - env = os.environ.copy() - env["PYTHONUNBUFFERED"] = "1" - self.proc = subprocess.Popen( - ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args, - env=env, - stdout=sys.stdout, - stderr=sys.stderr, - ) - self._wait_for_server() - - def ready(self): - return True - - def _wait_for_server(self): - # run health check - start = time.time() - while True: - try: - if requests.get( - "http://localhost:8000/health").status_code == 200: - break - except Exception as err: - if self.proc.poll() is not None: - raise RuntimeError("Server exited unexpectedly.") from err - - time.sleep(0.5) - if time.time() - start > MAX_SERVER_START_WAIT_S: - raise RuntimeError( - "Server failed to start in time.") from err - - def __del__(self): - if hasattr(self, "proc"): - self.proc.terminate() - - -@pytest.fixture(scope="module") -def server(): - ray.init() - server_runner = ServerRunner.remote([ - "--model", - MODEL_NAME, - "--dtype", - "bfloat16", # use half precision for speed and memory savings in CI environment - "--max-model-len", - "8192", - "--enforce-eager", - ]) - ray.get(server_runner.ready.remote()) - yield server_runner - ray.shutdown() - - -@pytest.fixture(scope="module") -def client(): - client = openai.AsyncOpenAI( - base_url="http://localhost:8000/v1", - api_key="token-abc123", - ) - yield client - - -async def test_guided_json_completion(server, client: openai.AsyncOpenAI): - completion = await client.completions.create( - model=MODEL_NAME, - prompt=f"Give an example JSON for an employee profile that fits this schema: {TEST_SCHEMA}", - n=3, - temperature=1.0, - max_tokens=500, - extra_body=dict( - guided_json=TEST_SCHEMA - ) - ) - - assert completion.id is not None - assert completion.choices is not None and len(completion.choices) == 3 - for i in range(3): - assert completion.choices[i].text is not None - output_json = json.loads(completion.choices[i].text) - jsonschema.validate(instance=output_json, schema=TEST_SCHEMA) - - -async def test_guided_json_chat(server, client: openai.AsyncOpenAI): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": "Give an example JSON for an employee profile that " + \ - f"fits this schema: {TEST_SCHEMA}" - }] - chat_completion = await client.chat.completions.create( - model=MODEL_NAME, - messages=messages, - max_tokens=500, - extra_body=dict( - guided_json=TEST_SCHEMA - ) - ) - message = chat_completion.choices[0].message - assert message.content is not None - json1 = json.loads(message.content) - jsonschema.validate(instance=json1, schema=TEST_SCHEMA) - - messages.append({"role": "assistant", "content": message.content}) - messages.append({ - "role": "user", - "content": "Give me another one with a different name and age" - }) - chat_completion = await client.chat.completions.create( - model=MODEL_NAME, - messages=messages, - max_tokens=500, - extra_body=dict( - guided_json=TEST_SCHEMA - ) - ) - message = chat_completion.choices[0].message - assert message.content is not None - json2 = json.loads(message.content) - jsonschema.validate(instance=json2, schema=TEST_SCHEMA) - assert json1["name"] != json2["name"] - assert json1["age"] != json2["age"] - - -async def test_guided_regex_completion(server, client: openai.AsyncOpenAI): - completion = await client.completions.create( - model=MODEL_NAME, - prompt=f"Give an example IPv4 address with this regex: {TEST_REGEX}", - n=3, - temperature=1.0, - max_tokens=20, - extra_body=dict( - guided_regex=TEST_REGEX - ) - ) - - assert completion.id is not None - assert completion.choices is not None and len(completion.choices) == 3 - for i in range(3): - assert completion.choices[i].text is not None - assert re.fullmatch(TEST_REGEX, completion.choices[i].text) is not None - - -async def test_guided_regex_chat(server, client: openai.AsyncOpenAI): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": f"Give an example IP address with this regex: {TEST_REGEX}" - }] - chat_completion = await client.chat.completions.create( - model=MODEL_NAME, - messages=messages, - max_tokens=20, - extra_body=dict( - guided_regex=TEST_REGEX - ) - ) - ip1 = chat_completion.choices[0].message.content - assert ip1 is not None - assert re.fullmatch(TEST_REGEX, ip1) is not None - - messages.append({"role": "assistant", "content": ip1}) - messages.append({ - "role": "user", - "content": "Give me a different one" - }) - chat_completion = await client.chat.completions.create( - model=MODEL_NAME, - messages=messages, - max_tokens=20, - extra_body=dict( - guided_regex=TEST_REGEX - ) - ) - ip2 = chat_completion.choices[0].message.content - assert ip2 is not None - assert re.fullmatch(TEST_REGEX, ip2) is not None - assert ip1 != ip2 - - -async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI): - with pytest.raises(Exception): - _ = await client.completions.create( - model=MODEL_NAME, - prompt="Give an example JSON that fits this schema: 42", - temperature=0.0, - extra_body=dict( - guided_json=42 - ) - ) - - with pytest.raises(Exception): - _ = await client.completions.create( - model=MODEL_NAME, - prompt="Give an example string that fits this regex: True", - temperature=0.0, - extra_body=dict( - guided_regex=True - ) - ) - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 7cba65460277..daa6419cdad3 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -333,6 +333,9 @@ def is_running(self) -> bool: return (self.background_loop is not None and not self.background_loop.done()) + def get_tokenizer(self): + return self.engine.tokenizer.tokenizer + def start_background_loop(self) -> None: """Start the background loop.""" if self.is_running: diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 4212904d03f7..3a117bba0d22 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -82,6 +82,7 @@ class ChatCompletionRequest(BaseModel): length_penalty: Optional[float] = 1.0 guided_json: Optional[Union[str, dict, BaseModel]] = None guided_regex: Optional[str] = None + guided_choice: Optional[List[Union[str, int, float, bool]]] = None def to_sampling_params(self) -> SamplingParams: return SamplingParams( @@ -137,6 +138,7 @@ class CompletionRequest(BaseModel): length_penalty: Optional[float] = 1.0 guided_json: Optional[Union[str, dict, BaseModel]] = None guided_regex: Optional[str] = None + guided_choice: Optional[List[Union[str, int, float, bool]]] = None def to_sampling_params(self): echo_without_generation = self.echo and self.max_tokens == 0 diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index e82ea33b45ff..6e5b55bc45e2 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -12,7 +12,7 @@ UsageInfo) from vllm.outputs import RequestOutput from vllm.entrypoints.openai.serving_engine import OpenAIServing -from vllm.model_executor.guided_decoding import GuidedDecodingMode, get_guided_decoding_logits_processor +from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor logger = init_logger(__name__) @@ -65,7 +65,9 @@ async def create_chat_completion( token_ids = self._validate_prompt_and_tokenize(request, prompt=prompt) sampling_params = request.to_sampling_params() - sampling_params.logits_processors = self._get_guided_decoding_logits_processor(request) + sampling_params.logits_processors = \ + get_guided_decoding_logits_processor( + request, self.engine.get_tokenizer()) except ValueError as e: return self.create_error_response(str(e)) @@ -265,15 +267,3 @@ def _load_chat_template(self, chat_template): else: logger.warning( "No chat template provided. Chat API will not work.") - - def _get_guided_decoding_logits_processor(self, request: ChatCompletionRequest): - if request.guided_json: - return get_guided_decoding_logits_processor( - request.guided_json, GuidedDecodingMode("json"), - self.engine.engine.tokenizer.tokenizer) - elif request.guided_regex: - return get_guided_decoding_logits_processor( - request.guided_regex, GuidedDecodingMode("regex"), - self.engine.engine.tokenizer.tokenizer) - else: - return None diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index fe3a709dc8fc..accd236ae8cc 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -16,7 +16,7 @@ ) from vllm.outputs import RequestOutput from vllm.entrypoints.openai.serving_engine import OpenAIServing -from vllm.model_executor.guided_decoding import GuidedDecodingMode, get_guided_decoding_logits_processor +from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor logger = init_logger(__name__) @@ -285,7 +285,9 @@ async def create_completion(self, request: CompletionRequest, generators = [] try: sampling_params = request.to_sampling_params() - sampling_params.logits_processors = self._get_guided_decoding_logits_processor(request) + sampling_params.logits_processors = \ + get_guided_decoding_logits_processor( + request, self.engine.get_tokenizer()) prompt_is_tokens, prompts = parse_prompt_format(request.prompt) for i, prompt in enumerate(prompts): @@ -349,16 +351,3 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: return fake_stream_generator() return response - - def _get_guided_decoding_logits_processor(self, request: CompletionRequest): - # should this go inside CompletionRequest.to_sampling_params() instead? - if request.guided_json: - return get_guided_decoding_logits_processor( - request.guided_json, GuidedDecodingMode("json"), - self.engine.engine.tokenizer.tokenizer) - elif request.guided_regex: - return get_guided_decoding_logits_processor( - request.guided_regex, GuidedDecodingMode("regex"), - self.engine.engine.tokenizer.tokenizer) - else: - return None \ No newline at end of file diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index 2e884d86d270..726748f16b3e 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -2,6 +2,8 @@ from typing import Union from types import SimpleNamespace from pydantic import BaseModel + +from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest try: from outlines.serve.vllm import JSONLogitsProcessor, RegexLogitsProcessor except ImportError as e: @@ -10,27 +12,45 @@ ) from e -class GuidedDecodingMode(Enum): - JSON = "json" - REGEX = "regex" - # TODO: add grammar, choice - +def get_guided_decoding_logits_processor( + request: Union[CompletionRequest, ChatCompletionRequest], + tokenizer + ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]: -def get_guided_decoding_logits_processor(guided_spec: Union[str, dict, BaseModel], mode: GuidedDecodingMode, tokenizer): def dummy_llm(): + # outlines' logit processor takes in a vllm.LLM object + # to grab the LLM's tokenizer x = SimpleNamespace() y = SimpleNamespace() x.tokenizer = tokenizer y.tokenizer = x return y - if mode == GuidedDecodingMode.JSON: - if not isinstance(guided_spec, (str, dict, BaseModel)): + if request.guided_json: + if not isinstance(request.guided_json, (str, dict, BaseModel)): raise TypeError("JSON schema must be str, dict, or BaseModel") - return [JSONLogitsProcessor(guided_spec, dummy_llm())] - elif mode == GuidedDecodingMode.REGEX: - if not isinstance(guided_spec, str): + return [JSONLogitsProcessor(request.guided_json, dummy_llm())] + elif request.guided_regex: + if not isinstance(request.guided_regex, str): raise TypeError("Regex must be string") - return [RegexLogitsProcessor(guided_spec, dummy_llm())] + return [RegexLogitsProcessor(request.guided_regex, dummy_llm())] + elif request.guided_choice: + if not isinstance(request.guided_choice, list): + raise TypeError("Choices must be a list") + # create regex from choices + choices = [str_with_escape(choice) for choice in request.guided_choice] + choices_regex = "(" + "|".join(choices) + ")" + return [RegexLogitsProcessor(choices_regex, dummy_llm())] else: - return None \ No newline at end of file + return None + + +def str_with_escape(e: Union[str, int, float, bool]): + s = str(e) + a = [] + regex_reserved = set(".()[]{}|*+?^$-\\") + for ch in s: + if ch in regex_reserved: + a.append("\\") + a.append(ch) + return "".join(a) \ No newline at end of file From a46684eaef4c5e4169bb891df028cfbc5079c516 Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Wed, 14 Feb 2024 00:28:50 -0800 Subject: [PATCH 15/33] add caching for logit processors --- vllm/model_executor/guided_decoding.py | 65 ++++++++++++++++++++------ 1 file changed, 50 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index 726748f16b3e..a75a0f448b3b 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -1,4 +1,6 @@ from enum import Enum +from functools import lru_cache +from json import dumps as json_dumps from typing import Union from types import SimpleNamespace from pydantic import BaseModel @@ -8,7 +10,7 @@ from outlines.serve.vllm import JSONLogitsProcessor, RegexLogitsProcessor except ImportError as e: raise ValueError( - "Please install 'outlines' (pip install outlines) to use guided generation." + "Please install 'outlines' (pip install outlines) to use guided decoding." ) from e @@ -16,35 +18,68 @@ def get_guided_decoding_logits_processor( request: Union[CompletionRequest, ChatCompletionRequest], tokenizer ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]: - - def dummy_llm(): - # outlines' logit processor takes in a vllm.LLM object - # to grab the LLM's tokenizer - x = SimpleNamespace() - y = SimpleNamespace() - x.tokenizer = tokenizer - y.tokenizer = x - return y + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. + We cache logit processors by (json/regex, tokenizer). + """ if request.guided_json: if not isinstance(request.guided_json, (str, dict, BaseModel)): raise TypeError("JSON schema must be str, dict, or BaseModel") - return [JSONLogitsProcessor(request.guided_json, dummy_llm())] + + json = request.guided_json + if isinstance(request.guided_json, dict): + # turn dict into hashable string + json = json_dumps(request.guided_json, sort_keys=True) + elif isinstance(request.guided_json, BaseModel): + # use pydantic signature so that different model classes + # with the same fields will get hashed the same + json = str(request.guided_json.__signature__) + + return get_cached_logit_processor(json, tokenizer, True) + elif request.guided_regex: if not isinstance(request.guided_regex, str): raise TypeError("Regex must be string") - return [RegexLogitsProcessor(request.guided_regex, dummy_llm())] + + return get_cached_logit_processor( + request.guided_regex, tokenizer, False) + elif request.guided_choice: if not isinstance(request.guided_choice, list): raise TypeError("Choices must be a list") - # create regex from choices + + # choice just uses regex choices = [str_with_escape(choice) for choice in request.guided_choice] choices_regex = "(" + "|".join(choices) + ")" - return [RegexLogitsProcessor(choices_regex, dummy_llm())] + + return get_cached_logit_processor(choices_regex, tokenizer, False) + else: return None +@lru_cache +def get_cached_logit_processor(guide: str, tokenizer, is_json: bool): + # guide is guaranteed hashable (see above function) + # tokenizer should be hashable right?? + + def dummy_llm(): + # outlines' logit processor takes in a vllm.LLM object + # to grab the LLM's tokenizer, may break in future + x = SimpleNamespace() + y = SimpleNamespace() + x.tokenizer = tokenizer + y.tokenizer = x + return y + + if is_json: + return [JSONLogitsProcessor(guide, dummy_llm())] + else: + return [RegexLogitsProcessor(guide, dummy_llm())] + + def str_with_escape(e: Union[str, int, float, bool]): s = str(e) a = [] @@ -53,4 +88,4 @@ def str_with_escape(e: Union[str, int, float, bool]): if ch in regex_reserved: a.append("\\") a.append(ch) - return "".join(a) \ No newline at end of file + return "".join(a) From 058fce65c3caf5e585b248b4c98b3da8e5f42ac2 Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Wed, 14 Feb 2024 01:15:01 -0800 Subject: [PATCH 16/33] use re.escape --- vllm/model_executor/guided_decoding.py | 28 +++++++++----------------- 1 file changed, 9 insertions(+), 19 deletions(-) diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index a75a0f448b3b..e85754774991 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -1,6 +1,7 @@ from enum import Enum from functools import lru_cache from json import dumps as json_dumps +from re import escape as regex_escape from typing import Union from types import SimpleNamespace from pydantic import BaseModel @@ -37,31 +38,31 @@ def get_guided_decoding_logits_processor( # with the same fields will get hashed the same json = str(request.guided_json.__signature__) - return get_cached_logit_processor(json, tokenizer, True) + return [get_cached_logits_processor(json, tokenizer, True)] elif request.guided_regex: if not isinstance(request.guided_regex, str): raise TypeError("Regex must be string") - return get_cached_logit_processor( - request.guided_regex, tokenizer, False) + return [get_cached_logits_processor( + request.guided_regex, tokenizer, False)] elif request.guided_choice: if not isinstance(request.guided_choice, list): raise TypeError("Choices must be a list") # choice just uses regex - choices = [str_with_escape(choice) for choice in request.guided_choice] + choices = [regex_escape(choice) for choice in request.guided_choice] choices_regex = "(" + "|".join(choices) + ")" - return get_cached_logit_processor(choices_regex, tokenizer, False) + return [get_cached_logits_processor(choices_regex, tokenizer, False)] else: return None @lru_cache -def get_cached_logit_processor(guide: str, tokenizer, is_json: bool): +def get_cached_logits_processor(guide: str, tokenizer, is_json: bool): # guide is guaranteed hashable (see above function) # tokenizer should be hashable right?? @@ -75,17 +76,6 @@ def dummy_llm(): return y if is_json: - return [JSONLogitsProcessor(guide, dummy_llm())] + return JSONLogitsProcessor(guide, dummy_llm()) else: - return [RegexLogitsProcessor(guide, dummy_llm())] - - -def str_with_escape(e: Union[str, int, float, bool]): - s = str(e) - a = [] - regex_reserved = set(".()[]{}|*+?^$-\\") - for ch in s: - if ch in regex_reserved: - a.append("\\") - a.append(ch) - return "".join(a) + return RegexLogitsProcessor(guide, dummy_llm()) From dc601c76cd5285643ff46bf93f16805c1824e05c Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Wed, 14 Feb 2024 13:28:33 -0800 Subject: [PATCH 17/33] revert logits processor 2 vs 3 arg fix --- vllm/model_executor/guided_decoding.py | 2 +- vllm/model_executor/layers/sampler.py | 9 ++------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index e85754774991..c0b8d4351700 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -52,7 +52,7 @@ def get_guided_decoding_logits_processor( raise TypeError("Choices must be a list") # choice just uses regex - choices = [regex_escape(choice) for choice in request.guided_choice] + choices = [regex_escape(str(choice)) for choice in request.guided_choice] choices_regex = "(" + "|".join(choices) + ")" return [get_cached_logits_processor(choices_regex, tokenizer, False)] diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 429cfb9af3f7..bc86a916b5bb 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -1,6 +1,5 @@ """A layer that samples the next tokens from the model's outputs.""" from typing import Dict, List, Optional, Tuple -from inspect import signature as fn_signature import torch import torch.nn as nn @@ -150,15 +149,11 @@ def _apply_logits_processors( logits_processors = sampling_params.logits_processors if logits_processors: found_logits_processors = True - logits_processor_argc = [len(fn_signature(fn).parameters) for fn in logits_processors] for seq_id in seq_ids: logits_row = logits[logits_row_idx] token_ids = sampling_metadata.seq_data[seq_id].output_token_ids - for i, logits_processor in enumerate(logits_processors): - if logits_processor_argc[i] == 3: - logits_row = logits_processor(seq_id, token_ids, logits_row) - else: # args len == 2 - logits_row = logits_processor(token_ids, logits_row) + for logits_processor in logits_processors: + logits_row = logits_processor(token_ids, logits_row) logits[logits_row_idx] = logits_row logits_row_idx += 1 else: From 09d2a9cbadf9cbbe772de9e1eed96e85aa197bda Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Wed, 14 Feb 2024 14:09:26 -0800 Subject: [PATCH 18/33] copy on cache hit --- vllm/model_executor/guided_decoding.py | 24 +++++++++++++++++------- vllm/model_executor/layers/sampler.py | 9 +++++++-- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index c0b8d4351700..d921903e1670 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -1,4 +1,5 @@ -from enum import Enum +from collections import defaultdict +from copy import copy from functools import lru_cache from json import dumps as json_dumps from re import escape as regex_escape @@ -22,9 +23,11 @@ def get_guided_decoding_logits_processor( """ Given an OpenAI-compatible request, check for guided decoding parameters and get the necessary logits processor for the given guide. - We cache logit processors by (json/regex, tokenizer). + We cache logit processors by (json/regex, tokenizer), and on cache hit + we make a shallow copy to reuse the same underlying RegexFSM. """ + logits_processor = None if request.guided_json: if not isinstance(request.guided_json, (str, dict, BaseModel)): raise TypeError("JSON schema must be str, dict, or BaseModel") @@ -38,25 +41,32 @@ def get_guided_decoding_logits_processor( # with the same fields will get hashed the same json = str(request.guided_json.__signature__) - return [get_cached_logits_processor(json, tokenizer, True)] + logits_processor = copy(get_cached_logits_processor( + json, tokenizer, True)) elif request.guided_regex: if not isinstance(request.guided_regex, str): raise TypeError("Regex must be string") - return [get_cached_logits_processor( - request.guided_regex, tokenizer, False)] + logits_processor = copy(get_cached_logits_processor( + request.guided_regex, tokenizer, False)) elif request.guided_choice: if not isinstance(request.guided_choice, list): raise TypeError("Choices must be a list") # choice just uses regex - choices = [regex_escape(str(choice)) for choice in request.guided_choice] + choices = [regex_escape(str(choice)) + for choice in request.guided_choice] choices_regex = "(" + "|".join(choices) + ")" - return [get_cached_logits_processor(choices_regex, tokenizer, False)] + logits_processor = copy(get_cached_logits_processor( + choices_regex, tokenizer, False)) + if logits_processor: + # reset logits processor's internal state + logits_processor.fsm_state = defaultdict(int) + return [logits_processor] else: return None diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index bc86a916b5bb..429cfb9af3f7 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -1,5 +1,6 @@ """A layer that samples the next tokens from the model's outputs.""" from typing import Dict, List, Optional, Tuple +from inspect import signature as fn_signature import torch import torch.nn as nn @@ -149,11 +150,15 @@ def _apply_logits_processors( logits_processors = sampling_params.logits_processors if logits_processors: found_logits_processors = True + logits_processor_argc = [len(fn_signature(fn).parameters) for fn in logits_processors] for seq_id in seq_ids: logits_row = logits[logits_row_idx] token_ids = sampling_metadata.seq_data[seq_id].output_token_ids - for logits_processor in logits_processors: - logits_row = logits_processor(token_ids, logits_row) + for i, logits_processor in enumerate(logits_processors): + if logits_processor_argc[i] == 3: + logits_row = logits_processor(seq_id, token_ids, logits_row) + else: # args len == 2 + logits_row = logits_processor(token_ids, logits_row) logits[logits_row_idx] = logits_row logits_row_idx += 1 else: From d774cf65e3a0f870e026e85fa11d65d02e1de75d Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Thu, 15 Feb 2024 15:11:13 -0800 Subject: [PATCH 19/33] add separate thread for creating logits processor --- vllm/model_executor/guided_decoding.py | 121 ++++++++++++++++--------- 1 file changed, 76 insertions(+), 45 deletions(-) diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index d921903e1670..ef0c4144423c 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -1,5 +1,7 @@ from collections import defaultdict +import concurrent.futures from copy import copy +from enum import Enum from functools import lru_cache from json import dumps as json_dumps from re import escape as regex_escape @@ -16,6 +18,13 @@ ) from e +class GuidedDecodingMode(Enum): + JSON = "json" + REGEX = "regex" + CHOICE = "choice" + GRAMMAR = "grammar" + + def get_guided_decoding_logits_processor( request: Union[CompletionRequest, ChatCompletionRequest], tokenizer @@ -26,56 +35,74 @@ def get_guided_decoding_logits_processor( We cache logit processors by (json/regex, tokenizer), and on cache hit we make a shallow copy to reuse the same underlying RegexFSM. """ + guide_count = sum([ + request.guided_json is not None, + request.guided_regex is not None, + request.guided_choice is not None + ]) + if guide_count == 0: + return None + elif guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding " + "('guided_json', 'guided_regex' or 'guided_choice')." + ) - logits_processor = None - if request.guided_json: - if not isinstance(request.guided_json, (str, dict, BaseModel)): - raise TypeError("JSON schema must be str, dict, or BaseModel") - - json = request.guided_json - if isinstance(request.guided_json, dict): - # turn dict into hashable string - json = json_dumps(request.guided_json, sort_keys=True) - elif isinstance(request.guided_json, BaseModel): - # use pydantic signature so that different model classes - # with the same fields will get hashed the same - json = str(request.guided_json.__signature__) + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: + futures = [None] # weird workaround, there should be better semantics + if request.guided_json: + if not isinstance(request.guided_json, (str, dict, BaseModel)): + raise TypeError("JSON schema must be str, dict, or BaseModel") + + json = request.guided_json + if isinstance(json, dict): + # turn dict into hashable string + json = json_dumps(json, sort_keys=True) + elif isinstance(json, BaseModel): + # use pydantic signature so that different model classes + # with the same fields will get hashed the same + json = str(json.__signature__) - logits_processor = copy(get_cached_logits_processor( - json, tokenizer, True)) - - elif request.guided_regex: - if not isinstance(request.guided_regex, str): - raise TypeError("Regex must be string") + futures[0] = executor.submit( + get_cached_logits_processor, + json, tokenizer, GuidedDecodingMode.JSON + ) - logits_processor = copy(get_cached_logits_processor( - request.guided_regex, tokenizer, False)) - - elif request.guided_choice: - if not isinstance(request.guided_choice, list): - raise TypeError("Choices must be a list") + elif request.guided_regex: + if not isinstance(request.guided_regex, str): + raise TypeError("Regex must be string") + + futures[0] = executor.submit( + get_cached_logits_processor, + request.guided_regex, tokenizer, GuidedDecodingMode.REGEX + ) - # choice just uses regex - choices = [regex_escape(str(choice)) - for choice in request.guided_choice] - choices_regex = "(" + "|".join(choices) + ")" + elif request.guided_choice: + if not isinstance(request.guided_choice, list): + raise TypeError("Choices must be a list") + + # choice just uses regex + choices = [regex_escape(str(choice)) + for choice in request.guided_choice] + choices_regex = "(" + "|".join(choices) + ")" - logits_processor = copy(get_cached_logits_processor( - choices_regex, tokenizer, False)) - - if logits_processor: - # reset logits processor's internal state - logits_processor.fsm_state = defaultdict(int) - return [logits_processor] - else: - return None + futures[0] = executor.submit( + get_cached_logits_processor, + choices_regex, tokenizer, GuidedDecodingMode.CHOICE + ) + + future = next(concurrent.futures.as_completed(futures)) + try: + logits_processor = copy(future.result()) + # reset logits processor's internal state + logits_processor.fsm_state = defaultdict(int) + return [logits_processor] + except Exception as e: + print("Thread failed to get guided logits processor:", e) -@lru_cache -def get_cached_logits_processor(guide: str, tokenizer, is_json: bool): - # guide is guaranteed hashable (see above function) - # tokenizer should be hashable right?? - +@lru_cache(maxsize=32) +def get_cached_logits_processor(guide: str, tokenizer, mode: GuidedDecodingMode): def dummy_llm(): # outlines' logit processor takes in a vllm.LLM object # to grab the LLM's tokenizer, may break in future @@ -85,7 +112,11 @@ def dummy_llm(): y.tokenizer = x return y - if is_json: + if mode == GuidedDecodingMode.JSON: return JSONLogitsProcessor(guide, dummy_llm()) - else: + elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE: return RegexLogitsProcessor(guide, dummy_llm()) + elif mode == GuidedDecodingMode.GRAMMAR: + pass + else: + raise RuntimeError(f"Unknown guided decoding mode {mode}") \ No newline at end of file From 782b1da694ebbb36db7b90ddca7f3c5a8399c630 Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Thu, 15 Feb 2024 15:34:06 -0800 Subject: [PATCH 20/33] add simple cache test --- tests/entrypoints/test_openai_server.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 25b0413a3dff..68fea5b823d0 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -508,5 +508,25 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI): ) +async def test_guided_decoding_cache_performance(server, client: openai.AsyncOpenAI): + N = 10 + times = [] + for i in range(N): + start_t = time.time() + _ = await client.completions.create( + model=MODEL_NAME, + prompt="Give an example JSON for an employee profile " + f"that fits this schema: {TEST_SCHEMA}", + max_tokens=500, + extra_body=dict( + guided_json=TEST_SCHEMA + ) + ) + times.append(time.time() - start_t) + + for i in range(N): + print(f"Request #{i}, time: {times[i]}") + + if __name__ == "__main__": pytest.main([__file__]) From df3c7744d4d443fec005621f8d996255bd38ff42 Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Thu, 15 Feb 2024 16:15:09 -0800 Subject: [PATCH 21/33] use asyncio --- vllm/entrypoints/openai/serving_chat.py | 2 +- vllm/entrypoints/openai/serving_completion.py | 2 +- vllm/model_executor/guided_decoding.py | 80 ++++++++++++++----- 3 files changed, 64 insertions(+), 20 deletions(-) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 6e5b55bc45e2..dc05dc1d83d7 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -66,7 +66,7 @@ async def create_chat_completion( prompt=prompt) sampling_params = request.to_sampling_params() sampling_params.logits_processors = \ - get_guided_decoding_logits_processor( + await get_guided_decoding_logits_processor( request, self.engine.get_tokenizer()) except ValueError as e: return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index accd236ae8cc..b0b7b67c87da 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -286,7 +286,7 @@ async def create_completion(self, request: CompletionRequest, try: sampling_params = request.to_sampling_params() sampling_params.logits_processors = \ - get_guided_decoding_logits_processor( + await get_guided_decoding_logits_processor( request, self.engine.get_tokenizer()) prompt_is_tokens, prompts = parse_prompt_format(request.prompt) diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index ef0c4144423c..dce03d2f4522 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -1,3 +1,4 @@ +import asyncio from collections import defaultdict import concurrent.futures from copy import copy @@ -25,7 +26,7 @@ class GuidedDecodingMode(Enum): GRAMMAR = "grammar" -def get_guided_decoding_logits_processor( +async def get_guided_decoding_logits_processor( request: Union[CompletionRequest, ChatCompletionRequest], tokenizer ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]: @@ -33,7 +34,7 @@ def get_guided_decoding_logits_processor( Given an OpenAI-compatible request, check for guided decoding parameters and get the necessary logits processor for the given guide. We cache logit processors by (json/regex, tokenizer), and on cache hit - we make a shallow copy to reuse the same underlying RegexFSM. + we make a shallow copy to reuse the same underlying FSM. """ guide_count = sum([ request.guided_json is not None, @@ -48,8 +49,55 @@ def get_guided_decoding_logits_processor( "('guided_json', 'guided_regex' or 'guided_choice')." ) - with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor: - futures = [None] # weird workaround, there should be better semantics + loop = asyncio.get_running_loop() + + # if request.guided_json: + # if not isinstance(request.guided_json, (str, dict, BaseModel)): + # raise TypeError("JSON schema must be str, dict, or BaseModel") + + # json = request.guided_json + # if isinstance(json, dict): + # # turn dict into hashable string + # json = json_dumps(json, sort_keys=True) + # elif isinstance(json, BaseModel): + # # use pydantic signature so that different model classes + # # with the same fields will get hashed the same + # json = str(json.__signature__) + + # result = await loop.run_in_executor( + # None, get_cached_logits_processor, + # json, tokenizer, GuidedDecodingMode.JSON + # ) + + # elif request.guided_regex: + # if not isinstance(request.guided_regex, str): + # raise TypeError("Regex must be string") + + # result = await loop.run_in_executor( + # get_cached_logits_processor, + # request.guided_regex, tokenizer, GuidedDecodingMode.REGEX + # ) + + # elif request.guided_choice: + # if not isinstance(request.guided_choice, list): + # raise TypeError("Choices must be a list") + + # # choice just uses regex + # choices = [regex_escape(str(choice)) + # for choice in request.guided_choice] + # choices_regex = "(" + "|".join(choices) + ")" + + # result = await loop.run_in_executor( + # get_cached_logits_processor, + # choices_regex, tokenizer, GuidedDecodingMode.CHOICE + # ) + + # logits_processor = copy(result) + # # reset logits processor's internal state + # logits_processor.fsm_state = defaultdict(int) + # return [logits_processor] + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: if request.guided_json: if not isinstance(request.guided_json, (str, dict, BaseModel)): raise TypeError("JSON schema must be str, dict, or BaseModel") @@ -63,8 +111,8 @@ def get_guided_decoding_logits_processor( # with the same fields will get hashed the same json = str(json.__signature__) - futures[0] = executor.submit( - get_cached_logits_processor, + result = await loop.run_in_executor( + pool, get_cached_logits_processor, json, tokenizer, GuidedDecodingMode.JSON ) @@ -72,8 +120,8 @@ def get_guided_decoding_logits_processor( if not isinstance(request.guided_regex, str): raise TypeError("Regex must be string") - futures[0] = executor.submit( - get_cached_logits_processor, + result = await loop.run_in_executor( + pool, get_cached_logits_processor, request.guided_regex, tokenizer, GuidedDecodingMode.REGEX ) @@ -86,19 +134,15 @@ def get_guided_decoding_logits_processor( for choice in request.guided_choice] choices_regex = "(" + "|".join(choices) + ")" - futures[0] = executor.submit( - get_cached_logits_processor, + result = await loop.run_in_executor( + pool, get_cached_logits_processor, choices_regex, tokenizer, GuidedDecodingMode.CHOICE ) - future = next(concurrent.futures.as_completed(futures)) - try: - logits_processor = copy(future.result()) - # reset logits processor's internal state - logits_processor.fsm_state = defaultdict(int) - return [logits_processor] - except Exception as e: - print("Thread failed to get guided logits processor:", e) + logits_processor = copy(result) + # reset logits processor's internal state + logits_processor.fsm_state = defaultdict(int) + return [logits_processor] @lru_cache(maxsize=32) From c74f6bb2a92ba20782e6f295c56d67cc7797de69 Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Sat, 17 Feb 2024 23:41:07 -0800 Subject: [PATCH 22/33] add grammar support --- tests/entrypoints/test_openai_server.py | 60 +++++++++ tests/samplers/test_sampler.py | 1 - vllm/entrypoints/openai/protocol.py | 2 + vllm/model_executor/guided_decoding.py | 162 +++++++++--------------- 4 files changed, 123 insertions(+), 102 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 68fea5b823d0..a44f993c8032 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -71,6 +71,13 @@ TEST_CHOICE = ["Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Ruby", "Swift", "Kotlin"] +TEST_GRAMMAR = """ +start: DECIMAL +DIGIT: "0".."9" +INT: DIGIT+ +DECIMAL: INT "." INT? | "." INT +""" + pytestmark = pytest.mark.asyncio @@ -486,6 +493,59 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI): assert choice1 != choice2 +async def test_guided_grammar_completion(server, client: openai.AsyncOpenAI): + completion = await client.completions.create( + model=MODEL_NAME, + prompt=f"Give me a value that matches this context-free grammar: {TEST_GRAMMAR}", + n=2, + temperature=1.0, + max_tokens=10, + extra_body=dict( + guided_grammar=TEST_GRAMMAR + ) + ) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 2 + for i in range(2): + _ = float(completion.choices[i].text) + + +async def test_guided_grammar_chat(server, client: openai.AsyncOpenAI): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": f"Give me a value that matches this context-free grammar: {TEST_GRAMMAR}" + }] + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=10, + extra_body=dict( + guided_grammar=TEST_GRAMMAR + ) + ) + val1 = float(chat_completion.choices[0].message.content) + + messages.append({"role": "assistant", "content": val1}) + messages.append({ + "role": "user", + "content": "Give me a different value" + }) + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=10, + extra_body=dict( + guided_grammar=TEST_GRAMMAR + ) + ) + val2 = float(chat_completion.choices[0].message.content) + assert val1 != val2 + + async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI): with pytest.raises(Exception): _ = await client.completions.create( diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 44cd9ef2a07a..d34f32d03fee 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -230,7 +230,6 @@ def test_sampler_logits_processors(seed: int, device: str): # This sample logits processor gives infinite score to the i-th token, # where i is the length of the input sequence. # We therefore expect the output token sequence to be [0, 1, 2, ...] - # Since this processor is stateless, the seq_id is not used def pick_ith(token_ids, logits): logits[len(token_ids)] = float("inf") return logits diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 3a117bba0d22..bd06f2e7a23a 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -83,6 +83,7 @@ class ChatCompletionRequest(BaseModel): guided_json: Optional[Union[str, dict, BaseModel]] = None guided_regex: Optional[str] = None guided_choice: Optional[List[Union[str, int, float, bool]]] = None + guided_grammar: Optional[str] = None def to_sampling_params(self) -> SamplingParams: return SamplingParams( @@ -139,6 +140,7 @@ class CompletionRequest(BaseModel): guided_json: Optional[Union[str, dict, BaseModel]] = None guided_regex: Optional[str] = None guided_choice: Optional[List[Union[str, int, float, bool]]] = None + guided_grammar: Optional[str] = None def to_sampling_params(self): echo_without_generation = self.echo and self.max_tokens == 0 diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index dce03d2f4522..0b317891853a 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -6,13 +6,13 @@ from functools import lru_cache from json import dumps as json_dumps from re import escape as regex_escape -from typing import Union +from typing import Union, Tuple from types import SimpleNamespace from pydantic import BaseModel from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest try: - from outlines.serve.vllm import JSONLogitsProcessor, RegexLogitsProcessor + from outlines.serve.vllm import JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor except ImportError as e: raise ValueError( "Please install 'outlines' (pip install outlines) to use guided decoding." @@ -33,134 +33,94 @@ async def get_guided_decoding_logits_processor( """ Given an OpenAI-compatible request, check for guided decoding parameters and get the necessary logits processor for the given guide. - We cache logit processors by (json/regex, tokenizer), and on cache hit + We cache logit processors by (guide, tokenizer), and on cache hit we make a shallow copy to reuse the same underlying FSM. """ + guide, mode = _get_guide_and_mode(request) + if not guide: + return None + loop = asyncio.get_running_loop() + + with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: + result = await loop.run_in_executor( + pool, get_cached_logits_processor, + guide, tokenizer, mode + ) + logits_processor = copy(result) + # reset logits processor's internal state + logits_processor.fsm_state = defaultdict(int) + return [logits_processor] + + +def _get_guide_and_mode( + request: Union[CompletionRequest, ChatCompletionRequest] + ) -> Tuple[str, GuidedDecodingMode]: + # validate guided decoding parameters guide_count = sum([ request.guided_json is not None, request.guided_regex is not None, request.guided_choice is not None ]) if guide_count == 0: - return None + return None, None elif guide_count > 1: raise ValueError( "You can only use one kind of guided decoding " "('guided_json', 'guided_regex' or 'guided_choice')." ) - - loop = asyncio.get_running_loop() - - # if request.guided_json: - # if not isinstance(request.guided_json, (str, dict, BaseModel)): - # raise TypeError("JSON schema must be str, dict, or BaseModel") - - # json = request.guided_json - # if isinstance(json, dict): - # # turn dict into hashable string - # json = json_dumps(json, sort_keys=True) - # elif isinstance(json, BaseModel): - # # use pydantic signature so that different model classes - # # with the same fields will get hashed the same - # json = str(json.__signature__) - - # result = await loop.run_in_executor( - # None, get_cached_logits_processor, - # json, tokenizer, GuidedDecodingMode.JSON - # ) - # elif request.guided_regex: - # if not isinstance(request.guided_regex, str): - # raise TypeError("Regex must be string") - - # result = await loop.run_in_executor( - # get_cached_logits_processor, - # request.guided_regex, tokenizer, GuidedDecodingMode.REGEX - # ) - - # elif request.guided_choice: - # if not isinstance(request.guided_choice, list): - # raise TypeError("Choices must be a list") - - # # choice just uses regex - # choices = [regex_escape(str(choice)) - # for choice in request.guided_choice] - # choices_regex = "(" + "|".join(choices) + ")" - - # result = await loop.run_in_executor( - # get_cached_logits_processor, - # choices_regex, tokenizer, GuidedDecodingMode.CHOICE - # ) - - # logits_processor = copy(result) - # # reset logits processor's internal state - # logits_processor.fsm_state = defaultdict(int) - # return [logits_processor] - - with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: - if request.guided_json: - if not isinstance(request.guided_json, (str, dict, BaseModel)): - raise TypeError("JSON schema must be str, dict, or BaseModel") - - json = request.guided_json - if isinstance(json, dict): - # turn dict into hashable string - json = json_dumps(json, sort_keys=True) - elif isinstance(json, BaseModel): - # use pydantic signature so that different model classes - # with the same fields will get hashed the same - json = str(json.__signature__) - - result = await loop.run_in_executor( - pool, get_cached_logits_processor, - json, tokenizer, GuidedDecodingMode.JSON - ) - - elif request.guided_regex: - if not isinstance(request.guided_regex, str): - raise TypeError("Regex must be string") + if request.guided_json: + if not isinstance(request.guided_json, (str, dict, BaseModel)): + raise TypeError("JSON schema must be str, dict, or BaseModel") - result = await loop.run_in_executor( - pool, get_cached_logits_processor, - request.guided_regex, tokenizer, GuidedDecodingMode.REGEX - ) + json = request.guided_json + if isinstance(json, dict): + # turn dict into hashable string + json = json_dumps(json, sort_keys=True) + elif isinstance(json, BaseModel): + # use pydantic signature so that different model classes + # with the same fields will get hashed the same + json = str(json.__signature__) + return json, GuidedDecodingMode.JSON - elif request.guided_choice: - if not isinstance(request.guided_choice, list): - raise TypeError("Choices must be a list") - - # choice just uses regex - choices = [regex_escape(str(choice)) - for choice in request.guided_choice] - choices_regex = "(" + "|".join(choices) + ")" - - result = await loop.run_in_executor( - pool, get_cached_logits_processor, - choices_regex, tokenizer, GuidedDecodingMode.CHOICE - ) + elif request.guided_regex: + if not isinstance(request.guided_regex, str): + raise TypeError("Regex must be string") + return request.guided_regex, GuidedDecodingMode.REGEX + + elif request.guided_choice: + if not isinstance(request.guided_choice, list): + raise TypeError("Choices must be a list") - logits_processor = copy(result) - # reset logits processor's internal state - logits_processor.fsm_state = defaultdict(int) - return [logits_processor] + # choice just uses regex + choices = [regex_escape(str(choice)) + for choice in request.guided_choice] + choices_regex = "(" + "|".join(choices) + ")" + return choices_regex, GuidedDecodingMode.CHOICE + elif request.guided_grammar: + if not isinstance(request.guided_grammar, str): + raise TypeError("Grammar must be string") + return request.guided_grammar, GuidedDecodingMode.GRAMMAR + @lru_cache(maxsize=32) def get_cached_logits_processor(guide: str, tokenizer, mode: GuidedDecodingMode): def dummy_llm(): - # outlines' logit processor takes in a vllm.LLM object + # outlines' logit processor takes i"n a LLMEngine object # to grab the LLM's tokenizer, may break in future + # NOTE: as of 2/17, outlines PR 541 gets this wrong" x = SimpleNamespace() - y = SimpleNamespace() + # y = SimpleNamespace() x.tokenizer = tokenizer - y.tokenizer = x - return y + # y.tokenizer = x + return x if mode == GuidedDecodingMode.JSON: return JSONLogitsProcessor(guide, dummy_llm()) elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE: return RegexLogitsProcessor(guide, dummy_llm()) elif mode == GuidedDecodingMode.GRAMMAR: - pass + return CFGLogitsProcessor(guide, dummy_llm()) else: - raise RuntimeError(f"Unknown guided decoding mode {mode}") \ No newline at end of file + raise RuntimeError(f"Unknown guided decoding mode {mode}") From cf8494dddc2e778824de19151d1f6e8f8ae55427 Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Wed, 21 Feb 2024 14:52:44 -0800 Subject: [PATCH 23/33] remove grammar --- vllm/entrypoints/openai/protocol.py | 2 -- vllm/model_executor/guided_decoding.py | 10 +--------- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index bd06f2e7a23a..3a117bba0d22 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -83,7 +83,6 @@ class ChatCompletionRequest(BaseModel): guided_json: Optional[Union[str, dict, BaseModel]] = None guided_regex: Optional[str] = None guided_choice: Optional[List[Union[str, int, float, bool]]] = None - guided_grammar: Optional[str] = None def to_sampling_params(self) -> SamplingParams: return SamplingParams( @@ -140,7 +139,6 @@ class CompletionRequest(BaseModel): guided_json: Optional[Union[str, dict, BaseModel]] = None guided_regex: Optional[str] = None guided_choice: Optional[List[Union[str, int, float, bool]]] = None - guided_grammar: Optional[str] = None def to_sampling_params(self): echo_without_generation = self.echo and self.max_tokens == 0 diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index 0b317891853a..82c058d33b91 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -12,7 +12,7 @@ from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest try: - from outlines.serve.vllm import JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor + from outlines.serve.vllm import JSONLogitsProcessor, RegexLogitsProcessor except ImportError as e: raise ValueError( "Please install 'outlines' (pip install outlines) to use guided decoding." @@ -23,7 +23,6 @@ class GuidedDecodingMode(Enum): JSON = "json" REGEX = "regex" CHOICE = "choice" - GRAMMAR = "grammar" async def get_guided_decoding_logits_processor( @@ -97,11 +96,6 @@ def _get_guide_and_mode( for choice in request.guided_choice] choices_regex = "(" + "|".join(choices) + ")" return choices_regex, GuidedDecodingMode.CHOICE - - elif request.guided_grammar: - if not isinstance(request.guided_grammar, str): - raise TypeError("Grammar must be string") - return request.guided_grammar, GuidedDecodingMode.GRAMMAR @lru_cache(maxsize=32) @@ -120,7 +114,5 @@ def dummy_llm(): return JSONLogitsProcessor(guide, dummy_llm()) elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE: return RegexLogitsProcessor(guide, dummy_llm()) - elif mode == GuidedDecodingMode.GRAMMAR: - return CFGLogitsProcessor(guide, dummy_llm()) else: raise RuntimeError(f"Unknown guided decoding mode {mode}") From c30fed02a7bcc9200a82e48b82eb9944601b389b Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Wed, 21 Feb 2024 15:22:00 -0800 Subject: [PATCH 24/33] copy outlines' logits processors code --- tests/entrypoints/test_openai_server.py | 60 -------- vllm/model_executor/guided_decoding.py | 13 +- .../guided_logits_processors.py | 136 ++++++++++++++++++ 3 files changed, 140 insertions(+), 69 deletions(-) create mode 100644 vllm/model_executor/guided_logits_processors.py diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index a44f993c8032..68fea5b823d0 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -71,13 +71,6 @@ TEST_CHOICE = ["Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Ruby", "Swift", "Kotlin"] -TEST_GRAMMAR = """ -start: DECIMAL -DIGIT: "0".."9" -INT: DIGIT+ -DECIMAL: INT "." INT? | "." INT -""" - pytestmark = pytest.mark.asyncio @@ -493,59 +486,6 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI): assert choice1 != choice2 -async def test_guided_grammar_completion(server, client: openai.AsyncOpenAI): - completion = await client.completions.create( - model=MODEL_NAME, - prompt=f"Give me a value that matches this context-free grammar: {TEST_GRAMMAR}", - n=2, - temperature=1.0, - max_tokens=10, - extra_body=dict( - guided_grammar=TEST_GRAMMAR - ) - ) - - assert completion.id is not None - assert completion.choices is not None and len(completion.choices) == 2 - for i in range(2): - _ = float(completion.choices[i].text) - - -async def test_guided_grammar_chat(server, client: openai.AsyncOpenAI): - messages = [{ - "role": "system", - "content": "you are a helpful assistant" - }, { - "role": "user", - "content": f"Give me a value that matches this context-free grammar: {TEST_GRAMMAR}" - }] - chat_completion = await client.chat.completions.create( - model=MODEL_NAME, - messages=messages, - max_tokens=10, - extra_body=dict( - guided_grammar=TEST_GRAMMAR - ) - ) - val1 = float(chat_completion.choices[0].message.content) - - messages.append({"role": "assistant", "content": val1}) - messages.append({ - "role": "user", - "content": "Give me a different value" - }) - chat_completion = await client.chat.completions.create( - model=MODEL_NAME, - messages=messages, - max_tokens=10, - extra_body=dict( - guided_grammar=TEST_GRAMMAR - ) - ) - val2 = float(chat_completion.choices[0].message.content) - assert val1 != val2 - - async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI): with pytest.raises(Exception): _ = await client.completions.create( diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index 82c058d33b91..c15f5664e7d0 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -11,12 +11,7 @@ from pydantic import BaseModel from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest -try: - from outlines.serve.vllm import JSONLogitsProcessor, RegexLogitsProcessor -except ImportError as e: - raise ValueError( - "Please install 'outlines' (pip install outlines) to use guided decoding." - ) from e +from vllm.model_executor.guided_logits_processors import JSONLogitsProcessor, RegexLogitsProcessor class GuidedDecodingMode(Enum): @@ -105,10 +100,10 @@ def dummy_llm(): # to grab the LLM's tokenizer, may break in future # NOTE: as of 2/17, outlines PR 541 gets this wrong" x = SimpleNamespace() - # y = SimpleNamespace() + y = SimpleNamespace() x.tokenizer = tokenizer - # y.tokenizer = x - return x + y.tokenizer = x + return y if mode == GuidedDecodingMode.JSON: return JSONLogitsProcessor(guide, dummy_llm()) diff --git a/vllm/model_executor/guided_logits_processors.py b/vllm/model_executor/guided_logits_processors.py new file mode 100644 index 000000000000..2fc5ef705f3c --- /dev/null +++ b/vllm/model_executor/guided_logits_processors.py @@ -0,0 +1,136 @@ +# Make vLLM compatible with Outlines' structured generation. +# +# _______________________________ +# / Don't want to self-host? \ +# \ Try .json at http://dottxt.co / +# ------------------------------- +# \ ^__^ +# \ (oo)\_______ +# (__)\ )\/\ +# ||----w | +# || || +# +# Copyright 2024- the Outlines developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import math +from collections import defaultdict +from typing import DefaultDict, Dict, List, Optional + +import torch +from pydantic import BaseModel + +try: + from outlines.fsm.fsm import RegexFSM + from outlines.fsm.json_schema import build_regex_from_schema +except ImportError as e: + raise ValueError( + "Please install 'outlines' (pip install outlines) to use guided decoding." + ) from e + + +class RegexLogitsProcessor: + def __init__(self, regex_string, llm): + """Compile the FSM that drives the regex-structured generation. + + Parameters + ---------- + regex_string + A string that represents a regular expression + llm + An instance of `vllm.LLM` + + """ + tokenizer = self.adapt_tokenizer(llm.tokenizer.tokenizer) + + fsm = RegexFSM(regex_string, tokenizer) + self.fsm = fsm + + def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: + """Use the FSM to bias the logits before sampling the next token.""" + + seq_id = hash(tuple(input_ids)) + + if len(input_ids) == 0: # Initialize the fsm states + self.fsm_state: DefaultDict[int, int] = defaultdict(int) + else: + last_token = input_ids[-1] + last_seq_id = hash(tuple(input_ids[:-1])) + self.fsm_state[seq_id] = self.fsm.next_state( + self.fsm_state[last_seq_id], last_token + ) + + allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id]) + + mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device) + mask[allowed_tokens] = 0 + biased_scores = scores + mask + + return biased_scores + + def adapt_tokenizer(self, tokenizer): + """Adapt vLLM's tokenizer to use to compile the FSM. + + The API of Outlines tokenizers is slightly different to that of + `transformers`. In addition we need to handle the missing spaces to + Llama's tokenizer to be able to compile FSMs for this model. + + """ + tokenizer.vocabulary = tokenizer.get_vocab() + tokenizer.special_tokens = set(tokenizer.all_special_tokens) + + def convert_token_to_string(token: str) -> str: + from transformers.file_utils import SPIECE_UNDERLINE + + string = tokenizer.convert_tokens_to_string([token]) + + # A hack to handle missing spaces to HF's Llama tokenizers + if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": + return " " + string + + return string + + tokenizer.convert_token_to_string = convert_token_to_string + + return tokenizer + + +class JSONLogitsProcessor(RegexLogitsProcessor): + def __init__(self, schema: Dict, llm, whitespace_pattern: Optional[str] = None): + """Compile the FSM that drives the JSON-guided generation. + + Parameters + ---------- + schema + A JSON schema that encodes the structure we want the model to generate + llm + An instance of `vllm.LLM` + whitespace_pattern + Pattern to use for JSON syntactic whitespace (doesn't impact string literals) + Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + """ + if isinstance(schema, type(BaseModel)): + schema_str = json.dumps(schema.model_json_schema()) + elif isinstance(schema, Dict): + schema_str = json.dumps(schema) + elif isinstance(schema, str): + schema_str = schema + else: + raise ValueError( + f"Cannot parse schema {schema}. The schema must be either " + + "a Pydantic object, a dictionary or a string that contains the JSON " + + "Schema specification" + ) + regex_string = build_regex_from_schema(schema_str, whitespace_pattern) + super().__init__(regex_string, llm) \ No newline at end of file From 33dc0827c5a5ae2b13e49dad03d4855be8779089 Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Wed, 21 Feb 2024 21:15:22 -0800 Subject: [PATCH 25/33] breno PR comments --- vllm/model_executor/guided_decoding.py | 7 +++---- vllm/model_executor/guided_logits_processors.py | 8 ++++++-- vllm/model_executor/layers/sampler.py | 7 +------ vllm/sampling_params.py | 7 ++----- 4 files changed, 12 insertions(+), 17 deletions(-) diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index c15f5664e7d0..47d5d8a46d94 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -1,5 +1,4 @@ import asyncio -from collections import defaultdict import concurrent.futures from copy import copy from enum import Enum @@ -37,12 +36,12 @@ async def get_guided_decoding_logits_processor( with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: result = await loop.run_in_executor( - pool, get_cached_logits_processor, + pool, _get_cached_logits_processor, guide, tokenizer, mode ) logits_processor = copy(result) # reset logits processor's internal state - logits_processor.fsm_state = defaultdict(int) + logits_processor.init_state() return [logits_processor] @@ -94,7 +93,7 @@ def _get_guide_and_mode( @lru_cache(maxsize=32) -def get_cached_logits_processor(guide: str, tokenizer, mode: GuidedDecodingMode): +def _get_cached_logits_processor(guide: str, tokenizer, mode: GuidedDecodingMode): def dummy_llm(): # outlines' logit processor takes i"n a LLMEngine object # to grab the LLM's tokenizer, may break in future diff --git a/vllm/model_executor/guided_logits_processors.py b/vllm/model_executor/guided_logits_processors.py index 2fc5ef705f3c..dd031cac8fdf 100644 --- a/vllm/model_executor/guided_logits_processors.py +++ b/vllm/model_executor/guided_logits_processors.py @@ -57,13 +57,17 @@ def __init__(self, regex_string, llm): fsm = RegexFSM(regex_string, tokenizer) self.fsm = fsm + def init_state(self): + """Initialize the FSM states.""" + self.fsm_state: DefaultDict[int, int] = defaultdict(int) + def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: """Use the FSM to bias the logits before sampling the next token.""" seq_id = hash(tuple(input_ids)) - if len(input_ids) == 0: # Initialize the fsm states - self.fsm_state: DefaultDict[int, int] = defaultdict(int) + if len(input_ids) == 0: + self.init_state() else: last_token = input_ids[-1] last_seq_id = hash(tuple(input_ids[:-1])) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 429cfb9af3f7..744e8243aa57 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -1,6 +1,5 @@ """A layer that samples the next tokens from the model's outputs.""" from typing import Dict, List, Optional, Tuple -from inspect import signature as fn_signature import torch import torch.nn as nn @@ -150,15 +149,11 @@ def _apply_logits_processors( logits_processors = sampling_params.logits_processors if logits_processors: found_logits_processors = True - logits_processor_argc = [len(fn_signature(fn).parameters) for fn in logits_processors] for seq_id in seq_ids: logits_row = logits[logits_row_idx] token_ids = sampling_metadata.seq_data[seq_id].output_token_ids for i, logits_processor in enumerate(logits_processors): - if logits_processor_argc[i] == 3: - logits_row = logits_processor(seq_id, token_ids, logits_row) - else: # args len == 2 - logits_row = logits_processor(token_ids, logits_row) + logits_row = logits_processor(token_ids, logits_row) logits[logits_row_idx] = logits_row logits_row_idx += 1 else: diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index bfd05a644580..bb7d0002c910 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -14,13 +14,10 @@ class SamplingType(IntEnum): BEAM = 2 -LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor], - Callable[[int, List[int], torch.Tensor], torch.Tensor]] +LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor] """LogitsProcessor is a function that takes a list of previously generated tokens and a tensor of the logits for the next token, and returns a modified -tensor of logits to sample from. Some processors may also take in an integer -sequence id, which is used to distinguish different generations, in case the -processor is stateful (such as for guided decoding).""" +tensor of logits to sample from.""" class SamplingParams: From 8699039faaa23d8780e7fa56441551de4575c98e Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Thu, 22 Feb 2024 15:00:59 -0800 Subject: [PATCH 26/33] resolve PR comments --- tests/entrypoints/test_openai_server.py | 35 ++++++-------- vllm/entrypoints/api_server.py | 2 - vllm/entrypoints/openai/protocol.py | 32 ++++++++++++- vllm/model_executor/guided_decoding.py | 43 +++++------------ .../guided_logits_processors.py | 46 +++++++------------ vllm/model_executor/layers/sampler.py | 2 +- 6 files changed, 75 insertions(+), 85 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 68fea5b823d0..43eb8c442a0f 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -487,45 +487,40 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI): async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI): - with pytest.raises(Exception): + with pytest.raises(openai.BadRequestError): _ = await client.completions.create( model=MODEL_NAME, prompt="Give an example JSON that fits this schema: 42", - temperature=0.0, extra_body=dict( guided_json=42 ) ) - with pytest.raises(Exception): - _ = await client.completions.create( + messages = [{ + "role": "system", + "content": "you are a helpful assistant" + }, { + "role": "user", + "content": "The best language for type-safe systems programming is " + }] + with pytest.raises(openai.BadRequestError): + _ = await client.chat.completions.create( model=MODEL_NAME, - prompt="Give an example string that fits this regex: True", - temperature=0.0, + messages=messages, extra_body=dict( - guided_regex=True + guided_regex={1: "Python", 2: "C++"} ) ) - -async def test_guided_decoding_cache_performance(server, client: openai.AsyncOpenAI): - N = 10 - times = [] - for i in range(N): - start_t = time.time() + with pytest.raises(openai.BadRequestError): _ = await client.completions.create( model=MODEL_NAME, - prompt="Give an example JSON for an employee profile " - f"that fits this schema: {TEST_SCHEMA}", - max_tokens=500, + prompt="Give an example string that fits this regex", extra_body=dict( + guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA ) ) - times.append(time.time() - start_t) - - for i in range(N): - print(f"Request #{i}, time: {times[i]}") if __name__ == "__main__": diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 4e3e9ff7d746..f7b8d258fae4 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -35,7 +35,6 @@ async def generate(request: Request) -> Response: prompt = request_dict.pop("prompt") prefix_pos = request_dict.pop("prefix_pos", None) stream = request_dict.pop("stream", False) - sampling_params = SamplingParams(**request_dict) request_id = random_uuid() @@ -84,7 +83,6 @@ async def stream_results() -> AsyncGenerator[bytes, None]: type=str, default=None, help="FastAPI root_path when app is behind a path based routing proxy") - parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 3a117bba0d22..5534f05be342 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -3,7 +3,7 @@ import time from typing import Dict, List, Literal, Optional, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator from vllm.utils import random_uuid from vllm.sampling_params import SamplingParams @@ -105,6 +105,21 @@ def to_sampling_params(self) -> SamplingParams: include_stop_str_in_output=self.include_stop_str_in_output, length_penalty=self.length_penalty, ) + + @model_validator(mode="before") + @classmethod + def check_guided_decoding_count(cls, data): + guide_count = sum([ + "guided_json" in data and data["guided_json"] is not None, + "guided_regex" in data and data["guided_regex"] is not None, + "guided_choice" in data and data["guided_choice"] is not None + ]) + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding " + "('guided_json', 'guided_regex' or 'guided_choice')." + ) + return data class CompletionRequest(BaseModel): @@ -165,6 +180,21 @@ def to_sampling_params(self): include_stop_str_in_output=self.include_stop_str_in_output, length_penalty=self.length_penalty, ) + + @model_validator(mode="before") + @classmethod + def check_guided_decoding_count(cls, data): + guide_count = sum([ + "guided_json" in data and data["guided_json"] is not None, + "guided_regex" in data and data["guided_regex"] is not None, + "guided_choice" in data and data["guided_choice"] is not None + ]) + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding " + "('guided_json', 'guided_regex' or 'guided_choice')." + ) + return data class LogProbs(BaseModel): diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index 47d5d8a46d94..dbe2ffad8996 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -39,29 +39,17 @@ async def get_guided_decoding_logits_processor( pool, _get_cached_logits_processor, guide, tokenizer, mode ) - logits_processor = copy(result) - # reset logits processor's internal state - logits_processor.init_state() - return [logits_processor] + + logits_processor = copy(result) + # reset logits processor's internal state + logits_processor.init_state() + return [logits_processor] def _get_guide_and_mode( request: Union[CompletionRequest, ChatCompletionRequest] ) -> Tuple[str, GuidedDecodingMode]: - # validate guided decoding parameters - guide_count = sum([ - request.guided_json is not None, - request.guided_regex is not None, - request.guided_choice is not None - ]) - if guide_count == 0: - return None, None - elif guide_count > 1: - raise ValueError( - "You can only use one kind of guided decoding " - "('guided_json', 'guided_regex' or 'guided_choice')." - ) - + if request.guided_json: if not isinstance(request.guided_json, (str, dict, BaseModel)): raise TypeError("JSON schema must be str, dict, or BaseModel") @@ -90,23 +78,16 @@ def _get_guide_and_mode( for choice in request.guided_choice] choices_regex = "(" + "|".join(choices) + ")" return choices_regex, GuidedDecodingMode.CHOICE + + else: + return None, None @lru_cache(maxsize=32) def _get_cached_logits_processor(guide: str, tokenizer, mode: GuidedDecodingMode): - def dummy_llm(): - # outlines' logit processor takes i"n a LLMEngine object - # to grab the LLM's tokenizer, may break in future - # NOTE: as of 2/17, outlines PR 541 gets this wrong" - x = SimpleNamespace() - y = SimpleNamespace() - x.tokenizer = tokenizer - y.tokenizer = x - return y - if mode == GuidedDecodingMode.JSON: - return JSONLogitsProcessor(guide, dummy_llm()) + return JSONLogitsProcessor(guide, tokenizer) elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE: - return RegexLogitsProcessor(guide, dummy_llm()) + return RegexLogitsProcessor(guide, tokenizer) else: - raise RuntimeError(f"Unknown guided decoding mode {mode}") + raise ValueError(f"Unknown guided decoding mode {mode}") diff --git a/vllm/model_executor/guided_logits_processors.py b/vllm/model_executor/guided_logits_processors.py index dd031cac8fdf..a7e0c32e7971 100644 --- a/vllm/model_executor/guided_logits_processors.py +++ b/vllm/model_executor/guided_logits_processors.py @@ -1,15 +1,3 @@ -# Make vLLM compatible with Outlines' structured generation. -# -# _______________________________ -# / Don't want to self-host? \ -# \ Try .json at http://dottxt.co / -# ------------------------------- -# \ ^__^ -# \ (oo)\_______ -# (__)\ )\/\ -# ||----w | -# || || -# # Copyright 2024- the Outlines developers # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,34 +14,27 @@ import json import math from collections import defaultdict -from typing import DefaultDict, Dict, List, Optional +from typing import Union, DefaultDict, Dict, List, Optional import torch from pydantic import BaseModel - -try: - from outlines.fsm.fsm import RegexFSM - from outlines.fsm.json_schema import build_regex_from_schema -except ImportError as e: - raise ValueError( - "Please install 'outlines' (pip install outlines) to use guided decoding." - ) from e +from outlines.fsm.fsm import RegexFSM +from outlines.fsm.json_schema import build_regex_from_schema class RegexLogitsProcessor: - def __init__(self, regex_string, llm): + def __init__(self, regex_string: str, tokenizer): """Compile the FSM that drives the regex-structured generation. Parameters ---------- regex_string A string that represents a regular expression - llm - An instance of `vllm.LLM` + tokenizer + The model's tokenizer """ - tokenizer = self.adapt_tokenizer(llm.tokenizer.tokenizer) - + tokenizer = self.adapt_tokenizer(tokenizer) fsm = RegexFSM(regex_string, tokenizer) self.fsm = fsm @@ -111,15 +92,20 @@ def convert_token_to_string(token: str) -> str: class JSONLogitsProcessor(RegexLogitsProcessor): - def __init__(self, schema: Dict, llm, whitespace_pattern: Optional[str] = None): + def __init__( + self, + schema: Union[str, Dict, BaseModel], + tokenizer, + whitespace_pattern: Optional[str] = None + ): """Compile the FSM that drives the JSON-guided generation. Parameters ---------- schema A JSON schema that encodes the structure we want the model to generate - llm - An instance of `vllm.LLM` + tokenizer + The model's tokenizer whitespace_pattern Pattern to use for JSON syntactic whitespace (doesn't impact string literals) Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` @@ -137,4 +123,4 @@ def __init__(self, schema: Dict, llm, whitespace_pattern: Optional[str] = None): + "Schema specification" ) regex_string = build_regex_from_schema(schema_str, whitespace_pattern) - super().__init__(regex_string, llm) \ No newline at end of file + super().__init__(regex_string, tokenizer) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 744e8243aa57..bc86a916b5bb 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -152,7 +152,7 @@ def _apply_logits_processors( for seq_id in seq_ids: logits_row = logits[logits_row_idx] token_ids = sampling_metadata.seq_data[seq_id].output_token_ids - for i, logits_processor in enumerate(logits_processors): + for logits_processor in logits_processors: logits_row = logits_processor(token_ids, logits_row) logits[logits_row_idx] = logits_row logits_row_idx += 1 From 6bf277e2b066cc8bfe9151e98600f7d037e8eca6 Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Fri, 23 Feb 2024 16:52:31 -0800 Subject: [PATCH 27/33] add unit test for logits processors --- tests/entrypoints/test_openai_server.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 43eb8c442a0f..0d4e37b518e6 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -523,5 +523,30 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI): ) +def test_guided_logits_processors(): + """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" + from vllm.model_executor.guided_logits_processors import ( + RegexLogitsProcessor, + JSONLogitsProcessor + ) + from transformers import AutoTokenizer + from torch import rand + + tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') + regexLP = RegexLogitsProcessor(TEST_REGEX, tokenizer) + jsonLP = JSONLogitsProcessor(TEST_SCHEMA, tokenizer) + tensor = rand(32000) + + regexLP.init_state() + token_ids = tokenizer.encode( + f"Give an example IPv4 address with this regex: {TEST_REGEX}") + assert regexLP(token_ids, tensor).shape == tensor.shape + + jsonLP.init_state() + token_ids = tokenizer.encode( + f"Give an employee profile that fits this schema: {TEST_SCHEMA}") + assert jsonLP(token_ids, tensor).shape == tensor.shape + + if __name__ == "__main__": pytest.main([__file__]) From b45942ea93976160e9779a729f9e127151e1967c Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Tue, 27 Feb 2024 14:43:46 -0800 Subject: [PATCH 28/33] PR comments --- tests/entrypoints/test_guided_processors.py | 85 +++++++++++++++++++ tests/entrypoints/test_openai_server.py | 27 ------ vllm/model_executor/guided_decoding.py | 13 +-- .../guided_logits_processors.py | 6 +- 4 files changed, 97 insertions(+), 34 deletions(-) create mode 100644 tests/entrypoints/test_guided_processors.py diff --git a/tests/entrypoints/test_guided_processors.py b/tests/entrypoints/test_guided_processors.py new file mode 100644 index 000000000000..b9da4b51fa72 --- /dev/null +++ b/tests/entrypoints/test_guided_processors.py @@ -0,0 +1,85 @@ +# This unit test should be moved to a new +# tests/test_guided_decoding directory. + +from transformers import AutoTokenizer +import torch + +from vllm.model_executor.guided_logits_processors import ( + RegexLogitsProcessor, + JSONLogitsProcessor +) + +TEST_SCHEMA = { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "age": { + "type": "integer" + }, + "skills": { + "type": "array", + "items": { + "type": "string", + "maxLength": 10 + }, + "minItems": 3 + }, + "work history": { + "type": "array", + "items": { + "type": "object", + "properties": { + "company": { + "type": "string" + }, + "duration": { + "type": "string" + }, + "position": { + "type": "string" + } + }, + "required": [ + "company", + "position" + ] + } + } + }, + "required": [ + "name", + "age", + "skills", + "work history" + ] +} + +TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \ + r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)" + + +def test_guided_logits_processors(): + """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" + tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') + regex_LP = RegexLogitsProcessor(TEST_REGEX, tokenizer) + json_LP = JSONLogitsProcessor(TEST_SCHEMA, tokenizer) + + regex_LP.init_state() + token_ids = tokenizer.encode( + f"Give an example IPv4 address with this regex: {TEST_REGEX}") + tensor = torch.rand(32000) + original_tensor = torch.clone(tensor) + regex_LP(token_ids, tensor) + assert tensor.shape == original_tensor.shape + assert not torch.allclose(tensor, original_tensor) + + json_LP.init_state() + token_ids = tokenizer.encode( + f"Give an employee profile that fits this schema: {TEST_SCHEMA}") + tensor = torch.rand(32000) + original_tensor = torch.clone(tensor) + json_LP(token_ids, tensor) + assert tensor.shape == original_tensor.shape + assert not torch.allclose(tensor, original_tensor) \ No newline at end of file diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 0d4e37b518e6..077fc6b37fef 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -63,8 +63,6 @@ ] } -# NOTE: outlines' underlying regex library (interegular) doesn't support -# ^ or $ or \b, kinda annoying TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \ r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)" @@ -523,30 +521,5 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI): ) -def test_guided_logits_processors(): - """Basic unit test for RegexLogitsProcessor and JSONLogitsProcessor.""" - from vllm.model_executor.guided_logits_processors import ( - RegexLogitsProcessor, - JSONLogitsProcessor - ) - from transformers import AutoTokenizer - from torch import rand - - tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') - regexLP = RegexLogitsProcessor(TEST_REGEX, tokenizer) - jsonLP = JSONLogitsProcessor(TEST_SCHEMA, tokenizer) - tensor = rand(32000) - - regexLP.init_state() - token_ids = tokenizer.encode( - f"Give an example IPv4 address with this regex: {TEST_REGEX}") - assert regexLP(token_ids, tensor).shape == tensor.shape - - jsonLP.init_state() - token_ids = tokenizer.encode( - f"Give an employee profile that fits this schema: {TEST_SCHEMA}") - assert jsonLP(token_ids, tensor).shape == tensor.shape - - if __name__ == "__main__": pytest.main([__file__]) diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index dbe2ffad8996..a2b8debabad1 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -32,13 +32,16 @@ async def get_guided_decoding_logits_processor( guide, mode = _get_guide_and_mode(request) if not guide: return None + + global global_pool + if 'global_pool' not in globals(): + global_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2) loop = asyncio.get_running_loop() - with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool: - result = await loop.run_in_executor( - pool, _get_cached_logits_processor, - guide, tokenizer, mode - ) + result = await loop.run_in_executor( + global_pool, _get_cached_logits_processor, + guide, tokenizer, mode + ) logits_processor = copy(result) # reset logits processor's internal state diff --git a/vllm/model_executor/guided_logits_processors.py b/vllm/model_executor/guided_logits_processors.py index a7e0c32e7971..1f561f931464 100644 --- a/vllm/model_executor/guided_logits_processors.py +++ b/vllm/model_executor/guided_logits_processors.py @@ -1,4 +1,6 @@ # Copyright 2024- the Outlines developers +# This file is adapted from +# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -60,9 +62,9 @@ def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device) mask[allowed_tokens] = 0 - biased_scores = scores + mask + scores.add_(mask) - return biased_scores + return scores def adapt_tokenizer(self, tokenizer): """Adapt vLLM's tokenizer to use to compile the FSM. From c176dab35aa657d42277f55b96d1d8030467afa3 Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Tue, 27 Feb 2024 15:54:21 -0800 Subject: [PATCH 29/33] fix --- vllm/model_executor/guided_decoding.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index a2b8debabad1..bbca07030239 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -6,7 +6,6 @@ from json import dumps as json_dumps from re import escape as regex_escape from typing import Union, Tuple -from types import SimpleNamespace from pydantic import BaseModel from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest From bfbbce3b7111d3d5e859cf4693438b4c5c2e1d43 Mon Sep 17 00:00:00 2001 From: Felix Zhu Date: Tue, 27 Feb 2024 16:34:42 -0800 Subject: [PATCH 30/33] format with yapf and ruff --- tests/entrypoints/test_guided_processors.py | 24 ++-- tests/entrypoints/test_openai_server.py | 119 +++++++----------- vllm/entrypoints/openai/protocol.py | 10 +- vllm/model_executor/guided_decoding.py | 38 +++--- .../guided_logits_processors.py | 31 ++--- 5 files changed, 88 insertions(+), 134 deletions(-) diff --git a/tests/entrypoints/test_guided_processors.py b/tests/entrypoints/test_guided_processors.py index b9da4b51fa72..5b39269916f8 100644 --- a/tests/entrypoints/test_guided_processors.py +++ b/tests/entrypoints/test_guided_processors.py @@ -1,13 +1,11 @@ -# This unit test should be moved to a new +# This unit test should be moved to a new # tests/test_guided_decoding directory. from transformers import AutoTokenizer import torch -from vllm.model_executor.guided_logits_processors import ( - RegexLogitsProcessor, - JSONLogitsProcessor -) +from vllm.model_executor.guided_logits_processors import (RegexLogitsProcessor, + JSONLogitsProcessor) TEST_SCHEMA = { "type": "object", @@ -41,19 +39,11 @@ "type": "string" } }, - "required": [ - "company", - "position" - ] + "required": ["company", "position"] } } }, - "required": [ - "name", - "age", - "skills", - "work history" - ] + "required": ["name", "age", "skills", "work history"] } TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \ @@ -74,7 +64,7 @@ def test_guided_logits_processors(): regex_LP(token_ids, tensor) assert tensor.shape == original_tensor.shape assert not torch.allclose(tensor, original_tensor) - + json_LP.init_state() token_ids = tokenizer.encode( f"Give an employee profile that fits this schema: {TEST_SCHEMA}") @@ -82,4 +72,4 @@ def test_guided_logits_processors(): original_tensor = torch.clone(tensor) json_LP(token_ids, tensor) assert tensor.shape == original_tensor.shape - assert not torch.allclose(tensor, original_tensor) \ No newline at end of file + assert not torch.allclose(tensor, original_tensor) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index f4b68d0496dc..84a33961a6b1 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -52,26 +52,20 @@ "type": "string" } }, - "required": [ - "company", - "position" - ] + "required": ["company", "position"] } } }, - "required": [ - "name", - "age", - "skills", - "work history" - ] + "required": ["name", "age", "skills", "work history"] } TEST_REGEX = r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + \ r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)" -TEST_CHOICE = ["Python", "Java", "JavaScript", "C++", "C#", - "PHP", "TypeScript", "Ruby", "Swift", "Kotlin"] +TEST_CHOICE = [ + "Python", "Java", "JavaScript", "C++", "C#", "PHP", "TypeScript", "Ruby", + "Swift", "Kotlin" +] pytestmark = pytest.mark.asyncio @@ -419,14 +413,12 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI): async def test_guided_json_completion(server, client: openai.AsyncOpenAI): completion = await client.completions.create( model=MODEL_NAME, - prompt=f"Give an example JSON for an employee profile that fits this schema: {TEST_SCHEMA}", + prompt= + f"Give an example JSON for an employee profile that fits this schema: {TEST_SCHEMA}", n=3, temperature=1.0, max_tokens=500, - extra_body=dict( - guided_json=TEST_SCHEMA - ) - ) + extra_body=dict(guided_json=TEST_SCHEMA)) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 3 @@ -449,10 +441,7 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI): model=MODEL_NAME, messages=messages, max_tokens=500, - extra_body=dict( - guided_json=TEST_SCHEMA - ) - ) + extra_body=dict(guided_json=TEST_SCHEMA)) message = chat_completion.choices[0].message assert message.content is not None json1 = json.loads(message.content) @@ -460,17 +449,16 @@ async def test_guided_json_chat(server, client: openai.AsyncOpenAI): messages.append({"role": "assistant", "content": message.content}) messages.append({ - "role": "user", - "content": "Give me another one with a different name and age" + "role": + "user", + "content": + "Give me another one with a different name and age" }) chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_tokens=500, - extra_body=dict( - guided_json=TEST_SCHEMA - ) - ) + extra_body=dict(guided_json=TEST_SCHEMA)) message = chat_completion.choices[0].message assert message.content is not None json2 = json.loads(message.content) @@ -486,10 +474,7 @@ async def test_guided_regex_completion(server, client: openai.AsyncOpenAI): n=3, temperature=1.0, max_tokens=20, - extra_body=dict( - guided_regex=TEST_REGEX - ) - ) + extra_body=dict(guided_regex=TEST_REGEX)) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 3 @@ -503,34 +488,27 @@ async def test_guided_regex_chat(server, client: openai.AsyncOpenAI): "role": "system", "content": "you are a helpful assistant" }, { - "role": "user", - "content": f"Give an example IP address with this regex: {TEST_REGEX}" + "role": + "user", + "content": + f"Give an example IP address with this regex: {TEST_REGEX}" }] chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_tokens=20, - extra_body=dict( - guided_regex=TEST_REGEX - ) - ) + extra_body=dict(guided_regex=TEST_REGEX)) ip1 = chat_completion.choices[0].message.content assert ip1 is not None assert re.fullmatch(TEST_REGEX, ip1) is not None messages.append({"role": "assistant", "content": ip1}) - messages.append({ - "role": "user", - "content": "Give me a different one" - }) + messages.append({"role": "user", "content": "Give me a different one"}) chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_tokens=20, - extra_body=dict( - guided_regex=TEST_REGEX - ) - ) + extra_body=dict(guided_regex=TEST_REGEX)) ip2 = chat_completion.choices[0].message.content assert ip2 is not None assert re.fullmatch(TEST_REGEX, ip2) is not None @@ -544,10 +522,7 @@ async def test_guided_choice_completion(server, client: openai.AsyncOpenAI): n=2, temperature=1.0, max_tokens=10, - extra_body=dict( - guided_choice=TEST_CHOICE - ) - ) + extra_body=dict(guided_choice=TEST_CHOICE)) assert completion.id is not None assert completion.choices is not None and len(completion.choices) == 2 @@ -560,17 +535,16 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI): "role": "system", "content": "you are a helpful assistant" }, { - "role": "user", - "content": "The best language for type-safe systems programming is " + "role": + "user", + "content": + "The best language for type-safe systems programming is " }] chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, max_tokens=10, - extra_body=dict( - guided_choice=TEST_CHOICE - ) - ) + extra_body=dict(guided_choice=TEST_CHOICE)) choice1 = chat_completion.choices[0].message.content assert choice1 in TEST_CHOICE @@ -583,10 +557,7 @@ async def test_guided_choice_chat(server, client: openai.AsyncOpenAI): model=MODEL_NAME, messages=messages, max_tokens=10, - extra_body=dict( - guided_choice=TEST_CHOICE - ) - ) + extra_body=dict(guided_choice=TEST_CHOICE)) choice2 = chat_completion.choices[0].message.content assert choice2 in TEST_CHOICE assert choice1 != choice2 @@ -597,36 +568,30 @@ async def test_guided_decoding_type_error(server, client: openai.AsyncOpenAI): _ = await client.completions.create( model=MODEL_NAME, prompt="Give an example JSON that fits this schema: 42", - extra_body=dict( - guided_json=42 - ) - ) + extra_body=dict(guided_json=42)) messages = [{ "role": "system", "content": "you are a helpful assistant" }, { - "role": "user", - "content": "The best language for type-safe systems programming is " + "role": + "user", + "content": + "The best language for type-safe systems programming is " }] with pytest.raises(openai.BadRequestError): - _ = await client.chat.completions.create( - model=MODEL_NAME, - messages=messages, - extra_body=dict( - guided_regex={1: "Python", 2: "C++"} - ) - ) + _ = await client.chat.completions.create(model=MODEL_NAME, + messages=messages, + extra_body=dict(guided_regex={ + 1: "Python", + 2: "C++" + })) with pytest.raises(openai.BadRequestError): _ = await client.completions.create( model=MODEL_NAME, prompt="Give an example string that fits this regex", - extra_body=dict( - guided_regex=TEST_REGEX, - guided_json=TEST_SCHEMA - ) - ) + extra_body=dict(guided_regex=TEST_REGEX, guided_json=TEST_SCHEMA)) if __name__ == "__main__": diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 033d4938c889..d4e737837e76 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -133,7 +133,7 @@ def logit_bias_logits_processor( length_penalty=self.length_penalty, logits_processors=logits_processors, ) - + @model_validator(mode="before") @classmethod def check_guided_decoding_count(cls, data): @@ -145,8 +145,7 @@ def check_guided_decoding_count(cls, data): if guide_count > 1: raise ValueError( "You can only use one kind of guided decoding " - "('guided_json', 'guided_regex' or 'guided_choice')." - ) + "('guided_json', 'guided_regex' or 'guided_choice').") return data @@ -227,7 +226,7 @@ def logit_bias_logits_processor( length_penalty=self.length_penalty, logits_processors=logits_processors, ) - + @model_validator(mode="before") @classmethod def check_guided_decoding_count(cls, data): @@ -239,8 +238,7 @@ def check_guided_decoding_count(cls, data): if guide_count > 1: raise ValueError( "You can only use one kind of guided decoding " - "('guided_json', 'guided_regex' or 'guided_choice')." - ) + "('guided_json', 'guided_regex' or 'guided_choice').") return data diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index bbca07030239..6e18eff54e36 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -20,8 +20,7 @@ class GuidedDecodingMode(Enum): async def get_guided_decoding_logits_processor( request: Union[CompletionRequest, ChatCompletionRequest], - tokenizer - ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]: + tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]: """ Given an OpenAI-compatible request, check for guided decoding parameters and get the necessary logits processor for the given guide. @@ -31,31 +30,30 @@ async def get_guided_decoding_logits_processor( guide, mode = _get_guide_and_mode(request) if not guide: return None - + global global_pool if 'global_pool' not in globals(): global_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2) loop = asyncio.get_running_loop() - result = await loop.run_in_executor( - global_pool, _get_cached_logits_processor, - guide, tokenizer, mode - ) - + result = await loop.run_in_executor(global_pool, + _get_cached_logits_processor, guide, + tokenizer, mode) + logits_processor = copy(result) # reset logits processor's internal state logits_processor.init_state() return [logits_processor] - + def _get_guide_and_mode( - request: Union[CompletionRequest, ChatCompletionRequest] - ) -> Tuple[str, GuidedDecodingMode]: + request: Union[CompletionRequest, ChatCompletionRequest] +) -> Tuple[str, GuidedDecodingMode]: if request.guided_json: if not isinstance(request.guided_json, (str, dict, BaseModel)): raise TypeError("JSON schema must be str, dict, or BaseModel") - + json = request.guided_json if isinstance(json, dict): # turn dict into hashable string @@ -65,28 +63,30 @@ def _get_guide_and_mode( # with the same fields will get hashed the same json = str(json.__signature__) return json, GuidedDecodingMode.JSON - + elif request.guided_regex: if not isinstance(request.guided_regex, str): raise TypeError("Regex must be string") return request.guided_regex, GuidedDecodingMode.REGEX - + elif request.guided_choice: if not isinstance(request.guided_choice, list): raise TypeError("Choices must be a list") - + # choice just uses regex - choices = [regex_escape(str(choice)) - for choice in request.guided_choice] + choices = [ + regex_escape(str(choice)) for choice in request.guided_choice + ] choices_regex = "(" + "|".join(choices) + ")" return choices_regex, GuidedDecodingMode.CHOICE - + else: return None, None @lru_cache(maxsize=32) -def _get_cached_logits_processor(guide: str, tokenizer, mode: GuidedDecodingMode): +def _get_cached_logits_processor(guide: str, tokenizer, + mode: GuidedDecodingMode): if mode == GuidedDecodingMode.JSON: return JSONLogitsProcessor(guide, tokenizer) elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE: diff --git a/vllm/model_executor/guided_logits_processors.py b/vllm/model_executor/guided_logits_processors.py index 1f561f931464..1b3e5e71a591 100644 --- a/vllm/model_executor/guided_logits_processors.py +++ b/vllm/model_executor/guided_logits_processors.py @@ -25,6 +25,7 @@ class RegexLogitsProcessor: + def __init__(self, regex_string: str, tokenizer): """Compile the FSM that drives the regex-structured generation. @@ -43,8 +44,9 @@ def __init__(self, regex_string: str, tokenizer): def init_state(self): """Initialize the FSM states.""" self.fsm_state: DefaultDict[int, int] = defaultdict(int) - - def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: + + def __call__(self, input_ids: List[int], + scores: torch.Tensor) -> torch.Tensor: """Use the FSM to bias the logits before sampling the next token.""" seq_id = hash(tuple(input_ids)) @@ -55,12 +57,13 @@ def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: last_token = input_ids[-1] last_seq_id = hash(tuple(input_ids[:-1])) self.fsm_state[seq_id] = self.fsm.next_state( - self.fsm_state[last_seq_id], last_token - ) + self.fsm_state[last_seq_id], last_token) allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id]) - mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device) + mask = torch.full((scores.shape[-1], ), + -math.inf, + device=scores.device) mask[allowed_tokens] = 0 scores.add_(mask) @@ -94,12 +97,11 @@ def convert_token_to_string(token: str) -> str: class JSONLogitsProcessor(RegexLogitsProcessor): - def __init__( - self, - schema: Union[str, Dict, BaseModel], - tokenizer, - whitespace_pattern: Optional[str] = None - ): + + def __init__(self, + schema: Union[str, Dict, BaseModel], + tokenizer, + whitespace_pattern: Optional[str] = None): """Compile the FSM that drives the JSON-guided generation. Parameters @@ -120,9 +122,8 @@ def __init__( schema_str = schema else: raise ValueError( - f"Cannot parse schema {schema}. The schema must be either " - + "a Pydantic object, a dictionary or a string that contains the JSON " - + "Schema specification" - ) + f"Cannot parse schema {schema}. The schema must be either " + + "a Pydantic object, a dictionary or a string that contains the JSON " + + "Schema specification") regex_string = build_regex_from_schema(schema_str, whitespace_pattern) super().__init__(regex_string, tokenizer) From 9655b6e857e77a336b8236d7ee50319fa1c0334f Mon Sep 17 00:00:00 2001 From: simon-mo Date: Tue, 27 Feb 2024 17:55:35 -0800 Subject: [PATCH 31/33] fix global pool --- vllm/model_executor/guided_decoding.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index 6e18eff54e36..f8d3ba4cfd95 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -18,6 +18,9 @@ class GuidedDecodingMode(Enum): CHOICE = "choice" +global_thread_pool = None # uesd for generating logits processor fsm + + async def get_guided_decoding_logits_processor( request: Union[CompletionRequest, ChatCompletionRequest], tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]: @@ -27,16 +30,17 @@ async def get_guided_decoding_logits_processor( We cache logit processors by (guide, tokenizer), and on cache hit we make a shallow copy to reuse the same underlying FSM. """ + global global_thread_pool guide, mode = _get_guide_and_mode(request) if not guide: return None - global global_pool - if 'global_pool' not in globals(): - global_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2) + if global_thread_pool is None: + global_thread_pool = concurrent.futures.ThreadPoolExecutor( + max_workers=2) loop = asyncio.get_running_loop() - result = await loop.run_in_executor(global_pool, + result = await loop.run_in_executor(global_thread_pool, _get_cached_logits_processor, guide, tokenizer, mode) From 0c3d4759ae96e53657f6dc92e77f7b18179858cd Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Wed, 28 Feb 2024 14:33:45 -0800 Subject: [PATCH 32/33] Apply suggestions from code review --- vllm/model_executor/guided_decoding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index f8d3ba4cfd95..4ef735de3928 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -18,7 +18,7 @@ class GuidedDecodingMode(Enum): CHOICE = "choice" -global_thread_pool = None # uesd for generating logits processor fsm +global_thread_pool = None # used for generating logits processor fsm async def get_guided_decoding_logits_processor( From ce9c07a6ccdf8828e42c225e9bd33eef9fe6d7c0 Mon Sep 17 00:00:00 2001 From: simon-mo Date: Thu, 29 Feb 2024 20:05:46 +0000 Subject: [PATCH 33/33] some minor fix --- tests/entrypoints/test_openai_server.py | 1 + vllm/entrypoints/openai/protocol.py | 4 ++-- vllm/entrypoints/openai/serving_chat.py | 9 +++++++-- vllm/entrypoints/openai/serving_completion.py | 9 +++++++-- vllm/model_executor/guided_decoding.py | 2 +- 5 files changed, 18 insertions(+), 7 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 84a33961a6b1..e426cf7eed72 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -377,6 +377,7 @@ async def test_logits_bias(server, client: openai.AsyncOpenAI): max_tokens=max_tokens, temperature=0.0, logit_bias={str(token_id): 100}, + seed=42, ) assert completion.choices[0].text is not None and len( completion.choices[0].text) >= 5 diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 227841669770..26499b8d7a66 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -88,7 +88,7 @@ class ChatCompletionRequest(BaseModel): length_penalty: Optional[float] = 1.0 guided_json: Optional[Union[str, dict, BaseModel]] = None guided_regex: Optional[str] = None - guided_choice: Optional[List[Union[str, int, float, bool]]] = None + guided_choice: Optional[List[str]] = None def to_sampling_params(self) -> SamplingParams: if self.logprobs and not self.top_logprobs: @@ -182,7 +182,7 @@ class CompletionRequest(BaseModel): length_penalty: Optional[float] = 1.0 guided_json: Optional[Union[str, dict, BaseModel]] = None guided_regex: Optional[str] = None - guided_choice: Optional[List[Union[str, int, float, bool]]] = None + guided_choice: Optional[List[str]] = None def to_sampling_params(self): echo_without_generation = self.echo and self.max_tokens == 0 diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 9080eb049f8f..f4ad0aa5a018 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -63,9 +63,14 @@ async def create_chat_completion( prompt=prompt) sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) - sampling_params.logits_processors = \ + guided_decode_logits_processor = ( await get_guided_decoding_logits_processor( - request, self.engine.get_tokenizer()) + request, self.engine.get_tokenizer())) + if guided_decode_logits_processor: + if sampling_params.logits_processors is None: + sampling_params.logits_processors = [] + sampling_params.logits_processors.append( + guided_decode_logits_processor) except ValueError as e: return self.create_error_response(str(e)) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index c476153f9f23..713e67793b29 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -287,9 +287,14 @@ async def create_completion(self, request: CompletionRequest, try: sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) - sampling_params.logits_processors = \ + guided_decode_logit_processor = ( await get_guided_decoding_logits_processor( - request, self.engine.get_tokenizer()) + request, self.engine.get_tokenizer())) + if guided_decode_logit_processor is not None: + if sampling_params.logits_processors is None: + sampling_params.logits_processors = [] + sampling_params.logits_processors.append( + guided_decode_logit_processor) prompt_is_tokens, prompts = parse_prompt_format(request.prompt) for i, prompt in enumerate(prompts): diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py index 4ef735de3928..a8573f8bdc6c 100644 --- a/vllm/model_executor/guided_decoding.py +++ b/vllm/model_executor/guided_decoding.py @@ -47,7 +47,7 @@ async def get_guided_decoding_logits_processor( logits_processor = copy(result) # reset logits processor's internal state logits_processor.init_state() - return [logits_processor] + return logits_processor def _get_guide_and_mode(