diff --git a/tests/test_vllm_client_server.py b/tests/test_vllm_client_server.py index e3464dae6d6..611f9a2f781 100644 --- a/tests/test_vllm_client_server.py +++ b/tests/test_vllm_client_server.py @@ -152,65 +152,6 @@ def test_reset_prefix_cache(self): # Test resetting the prefix cache self.client.reset_prefix_cache() - def test_chat_completions_endpoint(self): - data = self.client.chat_completions( - messages=[{"role": "user", "content": "Say hello"}], - max_tokens=32, - ) - - assert "id" in data - assert "choices" in data - assert "usage" in data - assert len(data["choices"]) > 0 - assert data["choices"][0]["message"]["role"] == "assistant" - assert data["choices"][0]["finish_reason"] in ["stop", "length", "tool_calls"] - - def test_chat_completions_with_tools(self): - tools = [ - { - "type": "function", - "function": { - "name": "get_weather", - "description": "Get weather information for a location", - "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}, - }, - } - ] - data = self.client.chat_completions( - messages=[{"role": "user", "content": "What's the weather in San Francisco?"}], - tools=tools, - max_tokens=100, - ) - - assert "choices" in data - assert len(data["choices"]) > 0 - assert "message" in data["choices"][0] - - def test_chat_completions_with_params(self): - data = self.client.chat_completions( - messages=[{"role": "user", "content": "Tell me a joke"}], - n=2, - temperature=0.8, - top_p=0.9, - max_tokens=32, - ) - - assert len(data["choices"]) == 2 - - for i, choice in enumerate(data["choices"]): - assert choice["index"] == i, f"Expected choice at position {i} to have index {i}, got {choice['index']}" - assert "message" in choice - assert choice["message"]["role"] == "assistant" - - def test_tokenize_endpoint(self): - data = self.client.tokenize(messages=[{"role": "user", "content": "Hello, how are you?"}]) - - assert "tokens" in data - assert "model" in data - assert isinstance(data["tokens"], list) - assert len(data["tokens"]) > 0 - assert all(isinstance(tok, int) for tok in data["tokens"]) - @pytest.mark.xfail(reason="Importing `bitsandbytes` causes issues, see vllm-project/vllm#32793") def test_logprobs_match_with_non_default_sampling(self): prompts = ["Hello, AI!", "Tell me a joke"] diff --git a/trl/generation/vllm_client.py b/trl/generation/vllm_client.py index d66916177d7..51597382603 100644 --- a/trl/generation/vllm_client.py +++ b/trl/generation/vllm_client.py @@ -514,82 +514,6 @@ def reset_prefix_cache(self): if response.status_code != 200: raise Exception(f"Request failed: {response.status_code}, {response.text}") - def chat_completions( - self, - messages: list[dict], - model: str | None = None, - temperature: float = 1.0, - top_p: float = 1.0, - max_tokens: int | None = None, - n: int = 1, - tools: list[dict] | None = None, - **kwargs, - ) -> dict: - """ - OpenAI-compatible chat completions endpoint. - - Args: - messages (`list[dict]`): - List of messages in OpenAI format with "role" and "content" keys. - model (`str`, *optional*): - Model name to use. - temperature (`float`, *optional*, defaults to `1.0`): - Temperature for sampling. - top_p (`float`, *optional*, defaults to `1.0`): - Top-p sampling parameter. - max_tokens (`int`, *optional*): - Maximum number of tokens to generate. - n (`int`, *optional*, defaults to `1`): - Number of completions to generate. - tools (`list[dict]`, *optional*): - List of tool definitions for tool calling. - **kwargs: - Additional parameters to pass to the endpoint. - - Returns: - `dict`: - OpenAI-compatible response with "choices", "usage", etc. - """ - url = f"{self.base_url}/v1/chat/completions" - response = self.session.post( - url, - json={ - "messages": messages, - "model": model, - "temperature": temperature, - "top_p": top_p, - "max_tokens": max_tokens, - "n": n, - "tools": tools, - **kwargs, - }, - ) - if response.status_code == 200: - return response.json() - else: - raise Exception(f"Request failed: {response.status_code}, {response.text}") - - def tokenize(self, messages: list[dict], tools: list[dict] | None = None) -> dict: - """ - Tokenize messages to get token IDs. - - Args: - messages (`list[dict]`): - List of messages to tokenize. - tools (`list[dict]`, *optional*): - List of tool definitions. - - Returns: - `dict`: - Dictionary with "tokens" (list of token IDs) and "model" keys. - """ - url = f"{self.base_url}/tokenize" - response = self.session.post(url, json={"messages": messages, "tools": tools}) - if response.status_code == 200: - return response.json() - else: - raise Exception(f"Request failed: {response.status_code}, {response.text}") - def close_communicator(self): """ Closes the weight update group and cleans up the communication group. diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 286d42a0361..9ae419c8d43 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -14,12 +14,8 @@ import argparse import base64 -import json import logging import os -import re -import time -import uuid from collections.abc import Sequence from contextlib import asynccontextmanager from dataclasses import dataclass, field @@ -31,7 +27,7 @@ import torch import torch.distributed.distributed_c10d as c10d from packaging.version import Version -from transformers import AutoTokenizer, is_torch_xpu_available, is_vision_available +from transformers import is_torch_xpu_available, is_vision_available from trl import TrlParser from trl.generation.vllm_generation import sanitize_logprob @@ -389,23 +385,7 @@ def llm_worker( method_name = command["method"] args, kwargs = command.get("args", ()), command.get("kwargs", {}) method = getattr(llm, method_name) - - try: - result = method(*args, **kwargs) - except ValueError as e: - error_msg = str(e) - if "longer than the maximum model length" in error_msg or "context length" in error_msg: - logger.error(f"[Worker] Context length exceeded: {error_msg}") - if method_name in ["generate", "chat"]: - result = [] - else: - raise - else: - raise - except Exception as e: - logger.error(f"[Worker] Unexpected error in {method_name}: {e}") - raise - + result = method(*args, **kwargs) if command["type"] == "call": connection.send(result) elif command["type"] == "shutdown": @@ -432,61 +412,6 @@ def chunk_list(lst: list, n: int) -> list[list]: return [lst[i * k + min(i, r) : (i + 1) * k + min(i + 1, r)] for i in range(n)] -def _replace_prefix_tokens( - tokenizer, - model_prefix_token_ids: list[int], - template_prefix_token_ids: list[int], - template_token_ids: list[int], -) -> list[int]: - """ - This function is for fixing up the chat template-tokenized messages history to match the model output tokenization - up to the last assistant turn, in order to preserve the monotonic tokens property for optimized multi-turn - training. - - RL training frameworks train models on token IDs, but the OpenAI compatible server communicates in what is - basically de-tokenized text. When multiple model calls are made to the OpenAI compatible server in a single - trajectory, model generations in previous model calls may be re-tokenized to something that is different than what - was generated. This is not too big of an issue (that we know of) at inference time, but the log probs the model - produces are different enough for the differently re-tokenized generation result that it causes the training to be - off policy. Off policy isn't necessarily a bad thing in isolation, but this source of off-policyness may cause - unexpected issues if not properly accounted for. It also mis-aligns the token ID sequences across model calls, - which is strange during training. - - There are real cases where the model output string _does not match_ the chat template tokenization of the parsed - model output. A concrete example is inconsistent whitespace tokens around tool call special tokens. - - Based on NeMo RL's _replace_prefix_tokens: - https://github.com/NVIDIA-NeMo/RL/blob/748b9caff4e6d672b8a98a10b6e612d028cfc96b/nemo_rl/models/generation/vllm/vllm_worker_async.py#L40 - """ - if not model_prefix_token_ids: - return template_token_ids - - eos_token_id = tokenizer.eos_token_id - if eos_token_id is None: - logger.warning("Tokenizer has no EOS token ID, cannot apply _replace_prefix_tokens") - return template_token_ids - - model_cut_end = len(model_prefix_token_ids) - if model_prefix_token_ids and model_prefix_token_ids[-1] == eos_token_id: - model_cut_end -= 1 - - # We take everything starting with the EOS token ID. - template_cut_start = -1 - for pos in reversed(range(len(template_prefix_token_ids))): - if template_token_ids[pos] == eos_token_id: - template_cut_start = pos - break - - # This should never be the case, but - if template_cut_start < 0: - logger.warning("No EOS token found in template prefix, cannot apply _replace_prefix_tokens") - return template_token_ids - - result = model_prefix_token_ids[:model_cut_end] + template_token_ids[template_cut_start:] - - return result - - def main(script_args: ScriptArguments): if not is_fastapi_available(): raise ImportError( @@ -519,11 +444,6 @@ def main(script_args: ScriptArguments): @asynccontextmanager async def lifespan(app: FastAPI): - logger.info(f"Loading tokenizer for {script_args.model}...") - app.state.tokenizer = AutoTokenizer.from_pretrained( - script_args.model, trust_remote_code=script_args.trust_remote_code - ) - # Wait for all workers to send "ready" ready_connections = set() while len(ready_connections) < script_args.data_parallel_size: @@ -736,7 +656,6 @@ class ChatRequest(BaseModel): structured_outputs_regex: str | None = None generation_kwargs: dict = field(default_factory=dict) chat_template_kwargs: dict = field(default_factory=dict) - tools: list[dict] | None = None class ChatResponse(BaseModel): prompt_ids: list[list[int]] @@ -865,9 +784,7 @@ async def chat(request: ChatRequest): "messages": messages, "sampling_params": sampling_params, "chat_template_kwargs": request.chat_template_kwargs, - "tools": request.tools if request.tools else None, } - connection.send({"type": "call", "method": "chat", "kwargs": kwargs}) # Receive results @@ -970,337 +887,8 @@ async def close_communicator(): connection.send({"type": "fire_and_forget", "method": "collective_rpc", "kwargs": kwargs}) return {"message": "Request received, closing communicator"} - class ChatCompletionRequest(BaseModel): - messages: list[dict] - model: str | None = None - temperature: float = 1.0 - top_p: float = 1.0 - max_completion_tokens: int | None = None - max_tokens: int | None = None - n: int = 1 - stop: str | list[str] | None = None - presence_penalty: float = 0.0 - frequency_penalty: float = 0.0 - logprobs: bool = False - top_logprobs: int | None = None - tools: list[dict] | None = None - tool_choice: str | dict = "auto" - parallel_tool_calls: bool = True - - @app.post("/v1/chat/completions") - async def chat_completions(request: ChatCompletionRequest): - completion_id = f"chatcmpl-{uuid.uuid4().hex[:24]}" - created_at = int(time.time()) - - messages = [] - for msg in request.messages: - role = msg.get("role", "") - if role not in ["system", "user", "assistant", "tool"]: - logger.warning(f"Unknown message role: {role}") - messages.append(msg) - - max_tokens = request.max_completion_tokens or request.max_tokens or 512 - - sampling_kwargs = { - "n": request.n, - "temperature": request.temperature, - "top_p": request.top_p, - "max_tokens": max_tokens, - "presence_penalty": request.presence_penalty, - "frequency_penalty": request.frequency_penalty, - "stop": request.stop, - } - - if request.logprobs or request.top_logprobs: - sampling_kwargs["logprobs"] = request.top_logprobs if request.top_logprobs else 1 - - sampling_params = SamplingParams(**sampling_kwargs) - - chat_template_kwargs = {} - if request.tool_choice and request.tool_choice != "auto": - chat_template_kwargs["tool_choice"] = request.tool_choice - - has_prefix_token_ids = any(msg.get("role") == "assistant" and "prompt_token_ids" in msg for msg in messages) - - if has_prefix_token_ids: - # do on policy token id correction and call generate instead of chat - # see https://docs.nvidia.com/nemo/gym/latest/contribute/rl-framework-integration/openai-compatible-http-server-on-policy-correction.html - # and https://github.com/NVIDIA-NeMo/RL/blob/main/nemo_rl/models/generation/vllm/vllm_worker_async.py#L40 - tokenizer = app.state.tokenizer - - # preprocess full conversation - connections[0].send( - { - "type": "call", - "method": "preprocess_chat", - "kwargs": { - "messages": [messages], - "chat_template_kwargs": chat_template_kwargs, - "tools": request.tools, - "add_generation_prompt": True, - }, - } - ) - template_prompts = connections[0].recv() - template_prompt = template_prompts[0] - - # extract model prefix tokens from last assistant message - model_prefix_tokens = None - last_assistant_idx = None - for i in reversed(range(len(messages))): - if messages[i].get("role") == "assistant": - last_assistant_idx = i - if "prompt_token_ids" in messages[i]: - model_prefix_tokens = messages[i]["prompt_token_ids"] + messages[i].get( - "generation_token_ids", [] - ) - break - - if model_prefix_tokens and last_assistant_idx is not None: - messages_to_last_assistant = messages[: last_assistant_idx + 1] - connections[0].send( - { - "type": "call", - "method": "preprocess_chat", - "kwargs": { - "messages": [messages_to_last_assistant], - "chat_template_kwargs": chat_template_kwargs, - "tools": request.tools, - "add_generation_prompt": False, - }, - } - ) - template_prefix_prompts = connections[0].recv() - template_prefix_token_ids = template_prefix_prompts[0]["prompt_token_ids"] - - corrected_token_ids = _replace_prefix_tokens( - tokenizer, model_prefix_tokens, template_prefix_token_ids, template_prompt["prompt_token_ids"] - ) - - else: - corrected_token_ids = template_prompt["prompt_token_ids"] - - corrected_prompt = {"prompt_token_ids": corrected_token_ids} - chunked_prompts = chunk_list([corrected_prompt], script_args.data_parallel_size) - - for connection, prompts in zip(connections, chunked_prompts, strict=True): - if not prompts: - prompts = [{"prompt_token_ids": [tokenizer.eos_token_id]}] - connection.send( - { - "type": "call", - "method": "generate", - "kwargs": {"prompts": prompts, "sampling_params": sampling_params}, - } - ) - else: - # no prefix token IDs, use chat() - chunked_messages = chunk_list([messages], script_args.data_parallel_size) - - for connection, message_chunk in zip(connections, chunked_messages, strict=True): - if not message_chunk: - message_chunk = [[{"role": "user", "content": ""}]] - kwargs = { - "messages": message_chunk, - "sampling_params": sampling_params, - "tools": request.tools, - "chat_template_kwargs": chat_template_kwargs, - } - connection.send({"type": "call", "method": "chat", "kwargs": kwargs}) - - all_outputs = [connection.recv() for connection in connections] - if has_prefix_token_ids: - all_outputs = [ - output for output, prompt_chunk in zip(all_outputs, chunked_prompts, strict=True) if prompt_chunk - ] - else: - all_outputs = [ - output for output, msg_chunk in zip(all_outputs, chunked_messages, strict=True) if msg_chunk - ] - all_outputs = list(chain.from_iterable(all_outputs)) - - if not all_outputs: - return { - "id": completion_id, - "object": "chat.completion", - "created": created_at, - "model": request.model or script_args.model, - "choices": [ - { - "index": 0, - "message": {"role": "assistant", "content": ""}, - "finish_reason": "length", - "logprobs": None, - } - ], - "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, - } - - choices = [] - total_input_tokens = 0 - total_output_tokens = 0 - - idx = 0 - for output in all_outputs: - total_input_tokens += len(output.prompt_token_ids) - - for gen_output in output.outputs: - total_output_tokens += len(gen_output.token_ids) - text = gen_output.text if hasattr(gen_output, "text") else "" - - tool_calls = None - finish_reason = gen_output.finish_reason if hasattr(gen_output, "finish_reason") else "stop" - - # Manual XML-json tool call parsing - if request.tools and text: - pattern = r"(.*?)" - matches = re.findall(pattern, text, re.DOTALL) - if matches: - tool_calls = [] - for match in matches: - try: - data = json.loads(match.strip()) - tool_calls.append( - { - "id": f"call_{uuid.uuid4().hex[:24]}", - "type": "function", - "function": { - "name": data.get("name", ""), - "arguments": json.dumps(data.get("arguments", {})), - }, - } - ) - except json.JSONDecodeError: - continue - if tool_calls: - finish_reason = "tool_calls" - text = re.sub(pattern, "", text, flags=re.DOTALL).strip() - - if not request.parallel_tool_calls and tool_calls and len(tool_calls) > 1: - tool_calls = [tool_calls[0]] - - logprobs_data = None - if request.logprobs and hasattr(gen_output, "logprobs") and gen_output.logprobs: - logprobs_data = { - "content": [ - { - "token": str(token_id), - "logprob": float(list(logprob_dict.values())[0].logprob) if logprob_dict else 0.0, - "bytes": None, - "top_logprobs": [], - } - for token_id, logprob_dict in zip(gen_output.token_ids, gen_output.logprobs, strict=False) - ] - } - - choices.append( - { - "index": idx, - "message": { - "role": "assistant", - "content": text if not tool_calls else None, - "tool_calls": tool_calls, - }, - "logprobs": logprobs_data, - "finish_reason": finish_reason, - } - ) - idx += 1 - - return { - "id": completion_id, - "object": "chat.completion", - "created": created_at, - "model": request.model or script_args.model, - "choices": choices, - "usage": { - "prompt_tokens": total_input_tokens, - "completion_tokens": total_output_tokens, - "total_tokens": total_input_tokens + total_output_tokens, - }, - } - - class TokenizeRequest(BaseModel): - model: str | None = None - messages: list[dict] - tools: list[dict] | None = None - - @app.post("/tokenize") - async def tokenize(request: TokenizeRequest): - messages = request.messages - - has_prefix_token_ids = any(msg.get("role") == "assistant" and "prompt_token_ids" in msg for msg in messages) - - kwargs = { - "messages": [messages], - "tools": request.tools, - "add_generation_prompt": True, - "chat_template_kwargs": {}, - } - - connections[0].send({"type": "call", "method": "preprocess_chat", "kwargs": kwargs}) - preprocessed_prompts = connections[0].recv() - - if preprocessed_prompts and len(preprocessed_prompts) > 1: - logger.warning( - "More than one tokenized message returned from preprocess_chat inside tokenize, double check results!" - ) - - if not preprocessed_prompts or len(preprocessed_prompts) == 0: - return {"tokens": [], "model": request.model or script_args.model} - - template_prompt = preprocessed_prompts[0] - result_tokens = template_prompt["prompt_token_ids"] - - if has_prefix_token_ids: - tokenizer = app.state.tokenizer - - # Extract model prefix tokens from last assistant message - model_prefix_tokens = None - last_assistant_idx = None - for i in reversed(range(len(messages))): - if messages[i].get("role") == "assistant": - last_assistant_idx = i - if "prompt_token_ids" in messages[i]: - model_prefix_tokens = messages[i]["prompt_token_ids"] + messages[i].get( - "generation_token_ids", [] - ) - break - - if model_prefix_tokens and last_assistant_idx is not None: - # Preprocess up to last assistant - messages_to_last_assistant = messages[: last_assistant_idx + 1] - connections[0].send( - { - "type": "call", - "method": "preprocess_chat", - "kwargs": { - "messages": [messages_to_last_assistant], - "tools": request.tools, - "add_generation_prompt": False, - "chat_template_kwargs": {}, - }, - } - ) - template_prefix_prompts = connections[0].recv() - template_prefix_token_ids = template_prefix_prompts[0]["prompt_token_ids"] - - result_tokens = _replace_prefix_tokens( - tokenizer, model_prefix_tokens, template_prefix_token_ids, template_prompt["prompt_token_ids"] - ) - - return {"tokens": result_tokens, "model": request.model or script_args.model} - # Start the server - uvicorn.run( - app, - host=script_args.host, - port=script_args.port, - log_level=script_args.log_level, - limit_concurrency=256, - backlog=4096, - timeout_keep_alive=600, - ) + uvicorn.run(app, host=script_args.host, port=script_args.port, log_level=script_args.log_level) def make_parser(subparsers: argparse._SubParsersAction | None = None):