diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index fc15b7833ecf..ea879a019102 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -78,6 +78,8 @@ class ChatCompletionRequest(BaseModel):
     echo: Optional[bool] = False
     repetition_penalty: Optional[float] = 1.0
     min_p: Optional[float] = 0.0
+    prefix_pos: Optional[int] = None
+    prefix_stop: Optional[str] = None
     include_stop_str_in_output: Optional[bool] = False
     length_penalty: Optional[float] = 1.0
 
@@ -131,6 +133,8 @@ class CompletionRequest(BaseModel):
     spaces_between_special_tokens: Optional[bool] = True
     repetition_penalty: Optional[float] = 1.0
     min_p: Optional[float] = 0.0
+    prefix_pos: Optional[int] = None
+    prefix_stop: Optional[str] = None
     include_stop_str_in_output: Optional[bool] = False
     length_penalty: Optional[float] = 1.0
 
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index a9e4c355560b..c39acb5b8e1d 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -67,8 +67,16 @@ async def create_chat_completion(
         except ValueError as e:
             return self.create_error_response(str(e))
 
+        prefix_pos = request.prefix_pos
+        if request.prefix_stop is not None and request.prefix_stop in prompt:
+            prefix_index = prompt.index(request.prefix_stop)
+            prefix_pos = len(self.tokenizer.encode(prompt[:prefix_index])) - 1
+            prompt = prompt.replace(request.prefix_stop, '')
+
         result_generator = self.engine.generate(prompt, sampling_params,
-                                                request_id, token_ids)
+                                                request_id, token_ids,
+                                                prefix_pos)
+
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 8c9a7ad309ce..1b48543a72c1 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -286,11 +286,19 @@ async def create_completion(self, request: CompletionRequest,
         sampling_params = request.to_sampling_params()
         prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
 
+        prefix_pos = request.prefix_pos
         for i, prompt in enumerate(prompts):
             if prompt_is_tokens:
                 input_ids = self._validate_prompt_and_tokenize(
                     request, prompt_ids=prompt)
             else:
+                # Derive the prefix position from the prefix_stop marker
+                if request.prefix_stop is not None and request.prefix_stop in prompt:
+                    prefix_index = prompt.index(request.prefix_stop)
+                    prefix_pos = len(
+                        self.tokenizer.encode(prompt[:prefix_index])) - 1
+                    prompt = prompt.replace(request.prefix_stop, '')
+
                 input_ids = self._validate_prompt_and_tokenize(
                     request, prompt=prompt)
 
@@ -298,7 +306,8 @@ async def create_completion(self, request: CompletionRequest,
                     self.engine.generate(None,
                                          sampling_params,
                                          f"{request_id}-{i}",
-                                         prompt_token_ids=input_ids))
+                                         prompt_token_ids=input_ids,
+                                         prefix_pos=prefix_pos))
 
         except ValueError as e:
             return self.create_error_response(str(e))
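
For reference, a minimal client-side sketch of how the new parameters could be exercised (the server URL, model name, and `<PREFIX_END>` marker below are illustrative assumptions, not values defined by this diff). Per the changes above, the server locates `prefix_stop` in the string prompt, tokenizes everything before it to derive `prefix_pos`, and strips the marker before generation:

```python
# Hypothetical usage sketch for the prefix_stop / prefix_pos parameters
# added in this diff. The URL, model name, and <PREFIX_END> marker are
# made-up examples, not part of the patch.
import requests

response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "meta-llama/Llama-2-7b-hf",
        # Text before the marker is the shared prefix; the server
        # tokenizes it to compute prefix_pos and removes the marker.
        "prompt": "You are a helpful assistant.<PREFIX_END> What is vLLM?",
        "prefix_stop": "<PREFIX_END>",
        "max_tokens": 64,
    },
)
print(response.json()["choices"][0]["text"])
```

A client that has already tokenized its prompt can instead pass `prefix_pos` directly, since the `prefix_stop` branch only runs for string prompts.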