From f10285e1f956afbcd5cdc835e20de972ced58492 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 5 Mar 2026 19:10:52 +0000 Subject: [PATCH 1/7] support prompts or token IDs in VLLMClient and update API request handling --- trl/generation/vllm_client.py | 60 ++++++++++++++++------------ trl/generation/vllm_generation.py | 5 ++- trl/scripts/vllm_serve.py | 65 +++++++++++++++++++++---------- 3 files changed, 84 insertions(+), 46 deletions(-) diff --git a/trl/generation/vllm_client.py b/trl/generation/vllm_client.py index 411317e9057..d709a36a1a4 100644 --- a/trl/generation/vllm_client.py +++ b/trl/generation/vllm_client.py @@ -201,8 +201,9 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0): def generate( self, - prompts: list[str], + prompts: list[str] | None = None, images: list | None = None, + prompt_token_ids: list[list[int]] | None = None, n: int = 1, repetition_penalty: float = 1.0, temperature: float = 1.0, @@ -216,13 +217,17 @@ def generate( generation_kwargs: dict | None = None, ) -> dict[str, list[list[int]]]: """ - Generates model completions for the provided prompts. + Generates model completions for the provided prompts or token IDs. + + Either `prompts` or `prompt_token_ids` must be provided, but not both. Args: - prompts (`list[str]`): + prompts (`list[str]`, *optional*): List of text prompts for which the model will generate completions. images (`list[PIL.Image]`, *optional*): - List of PIL Images to send along with the prompts. + List of PIL Images to send along with the prompts. Only valid when `prompts` is provided. + prompt_token_ids (`list[list[int]]` or `None`, *optional*): + List of tokenized prompts (list of list of token IDs) for which the model will generate completions. n (`int`, *optional*, defaults to `1`): Number of completions to generate for each prompt. repetition_penalty (`float`, *optional*, defaults to `1.0`): @@ -263,29 +268,36 @@ def generate( - `logprob_token_ids` (`list[list[list[int]]]`): Token IDs corresponding to each logprob, same shape as `logprobs`. """ + if prompt_token_ids is not None and prompts is not None: + raise ValueError("Only one of 'prompts' or 'prompt_token_ids' can be provided, not both.") + if prompt_token_ids is None and prompts is None: + raise ValueError("Either 'prompts' or 'prompt_token_ids' must be provided.") + url = f"{self.base_url}/generate/" - # Convert PIL images to base64 strings - images = [pil_to_base64(img) for img in images] if images else None + # Build the payload with whichever input mode is provided + payload = { + "n": n, + "repetition_penalty": repetition_penalty, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "min_p": min_p, + "max_tokens": max_tokens, + "logprobs": logprobs, + "truncate_prompt_tokens": truncate_prompt_tokens, + "structured_outputs_regex": structured_outputs_regex, + "generation_kwargs": generation_kwargs or {}, + } + + if prompt_token_ids is not None: + payload["prompt_token_ids"] = prompt_token_ids + else: + payload["prompts"] = prompts + if images is not None: + payload["images"] = [pil_to_base64(image) if image is not None else None for image in images] - response = self.session.post( - url, - json={ - "prompts": prompts, - "images": images, - "n": n, - "repetition_penalty": repetition_penalty, - "temperature": temperature, - "top_p": top_p, - "top_k": top_k, - "min_p": min_p, - "max_tokens": max_tokens, - "logprobs": logprobs, - "truncate_prompt_tokens": truncate_prompt_tokens, - "structured_outputs_regex": structured_outputs_regex, - "generation_kwargs": generation_kwargs or {}, - }, - ) + response = self.session.post(url, json=payload) if response.status_code == 200: json_response = response.json() return { diff --git a/trl/generation/vllm_generation.py b/trl/generation/vllm_generation.py index 8ba6fc2a857..2eb0a928107 100644 --- a/trl/generation/vllm_generation.py +++ b/trl/generation/vllm_generation.py @@ -627,7 +627,10 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte chat_template=chat_template, ) else: - output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params) + ordered_set_of_prompt_ids = self.processing_class(text=ordered_set_of_prompts)["input_ids"] + output = self.vllm_client.generate( + prompt_token_ids=ordered_set_of_prompt_ids, **sampling_params + ) # Extract required fields and collect any extra fields for reward functions required_keys = {"prompt_ids", "completion_ids", "logprobs", "logprob_token_ids"} extra_fields = {k: v for k, v in output.items() if k not in required_keys} diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index fed2af86bf6..af3b3ed9020 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -495,8 +495,9 @@ async def get_world_size(): return {"world_size": script_args.tensor_parallel_size * script_args.data_parallel_size} class GenerateRequest(BaseModel): - prompts: list[str] + prompts: list[str] | None = None images: list[str] | None = None + prompt_token_ids: list[list[int]] | None = None n: int = 1 repetition_penalty: float = 1.0 temperature: float = 1.0 @@ -518,13 +519,19 @@ class GenerateResponse(BaseModel): @app.post("/generate/", response_model=GenerateResponse) async def generate(request: GenerateRequest): """ - Generates completions for the provided prompts. + Generates completions for the provided prompts or token IDs. + + Accepts either `prompts` (text strings, optionally with `images`) or `prompt_token_ids` (pre-tokenized). + Exactly one of `prompts` or `prompt_token_ids` must be provided. Args: request (`GenerateRequest`): - - `prompts` (list of `str`): A list of prompts (text strings) for the model to generate completions. + - `prompts` (list of `str`, *optional*): A list of prompts (text strings) for the model to generate + completions. - `images` (list of `str`, *optional*, default to `None`): A list of base64 encoded images to process along with prompts. + - `prompt_token_ids` (list of list of `int`, *optional*): A list of tokenized prompts for the model to + generate completions. - `n` (`int`, *optional*, defaults to `1`): Number of completions to generate for each prompt. - `repetition_penalty` (`float`, *optional*, defaults to `1.0`): Repetition penalty to apply during generation. @@ -558,11 +565,16 @@ async def generate(request: GenerateRequest): - `logprob_token_ids` (list of list of list of `int`): Token IDs corresponding to each logprob, same shape as `logprobs`. - Example request: + Example request (text prompts): ```json {"prompts": ["Hello world", "What is AI?"]} ``` + Example request (token IDs): + ```json + {"prompt_token_ids": [[101, 102], [201, 202]]} + ``` + Example response: ```json { @@ -573,14 +585,10 @@ async def generate(request: GenerateRequest): } ``` """ - request.images = request.images or [None] * len(request.prompts) - - prompts = [] - for prompt, image in zip(request.prompts, request.images, strict=True): - row = {"prompt": prompt} - if image is not None: - row["multi_modal_data"] = {"image": Image.open(BytesIO(base64.b64decode(image)))} - prompts.append(row) + if request.prompt_token_ids is not None and request.prompts is not None: + raise ValueError("Only one of 'prompts' or 'prompt_token_ids' can be provided, not both.") + if request.prompt_token_ids is None and request.prompts is None: + raise ValueError("Either 'prompts' or 'prompt_token_ids' must be provided.") generation_kwargs = { "n": request.n, @@ -627,24 +635,39 @@ async def generate(request: GenerateRequest): generation_kwargs[structured_outputs_key] = structured_outputs sampling_params = SamplingParams(**generation_kwargs) - # Evenly distribute prompts across DP ranks - chunked_prompts = chunk_list(prompts, script_args.data_parallel_size) - - # Send the prompts to each worker - for connection, prompts in zip(connections, chunked_prompts, strict=True): + if request.prompt_token_ids is not None: + # Token IDs path: pass pre-tokenized prompts directly to vLLM workers + chunked_inputs = chunk_list(request.prompt_token_ids, script_args.data_parallel_size) + input_key = "prompt_token_ids" + placeholder = [[0]] + else: + # Text prompts path: build prompt dicts with optional images + request.images = request.images or [None] * len(request.prompts) + prompts = [] + for prompt, image in zip(request.prompts, request.images, strict=True): + row = {"prompt": prompt} + if image is not None: + row["multi_modal_data"] = {"image": Image.open(BytesIO(base64.b64decode(image)))} + prompts.append(row) + chunked_inputs = chunk_list(prompts, script_args.data_parallel_size) + input_key = "prompts" + placeholder = [""] + + # Send inputs to each worker + for connection, chunk in zip(connections, chunked_inputs, strict=True): # When the number of prompts is less than data_parallel_size, some workers will receive empty prompts. # However, vLLM requires that we always send at least one prompt. So we send a placeholder prompt to comply # with vLLM's requirement, and we later ignore the result. - if not prompts: - prompts = [""] - kwargs = {"prompts": prompts, "sampling_params": sampling_params} + if not chunk: + chunk = placeholder + kwargs = {input_key: chunk, "sampling_params": sampling_params} connection.send({"type": "call", "method": "generate", "kwargs": kwargs}) # Receive results all_outputs = [connection.recv() for connection in connections] # Handle empty prompts (see above) - all_outputs = [output for output, prompts in zip(all_outputs, chunked_prompts, strict=True) if prompts] + all_outputs = [output for output, chunk in zip(all_outputs, chunked_inputs, strict=True) if chunk] # Flatten and combine all results all_outputs = list(chain.from_iterable(all_outputs)) # from list of list to single list From 7d2bb6727b61082c0f3ff72a82be0cc4c6fb5298 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 5 Mar 2026 19:12:14 +0000 Subject: [PATCH 2/7] test --- tests/test_vllm_client_server.py | 67 ++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/tests/test_vllm_client_server.py b/tests/test_vllm_client_server.py index 7c14af14e9b..d8335378328 100644 --- a/tests/test_vllm_client_server.py +++ b/tests/test_vllm_client_server.py @@ -207,6 +207,31 @@ def multiply(a: int, b: int) -> int: decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0]) assert "Multiplies two integers." in decoded_prompt + def test_generate_with_token_ids(self): + tokenizer = AutoTokenizer.from_pretrained(self.model_id) + prompts = ["Hello, AI!", "Tell me a joke"] + prompt_token_ids = tokenizer(prompts)["input_ids"] + outputs = self.client.generate(prompt_token_ids=prompt_token_ids) + prompt_ids = outputs["prompt_ids"] + completion_ids = outputs["completion_ids"] + + # Check that the outputs are lists + assert isinstance(prompt_ids, list) + assert isinstance(completion_ids, list) + + # Check that the number of sequences are equal to the number of prompts + assert len(prompt_ids) == len(prompts) + assert len(completion_ids) == len(prompts) + + # Check that prompt_ids match the input token IDs + assert prompt_ids == prompt_token_ids + + # Check that the sequences are lists of integers + for seq in prompt_ids: + assert all(isinstance(tok, int) for tok in seq) + for seq in completion_ids: + assert all(isinstance(tok, int) for tok in seq) + def test_generate_with_params(self): prompts = ["Hello, AI!", "Tell me a joke"] completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[ @@ -411,6 +436,20 @@ def multiply(a: int, b: int) -> int: decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0]) assert "Multiplies two integers." in decoded_prompt + def test_generate_with_token_ids(self): + tokenizer = AutoTokenizer.from_pretrained(self.model_id) + prompts = ["Hello, AI!", "Tell me a joke"] + prompt_token_ids = tokenizer(prompts)["input_ids"] + outputs = self.client.generate(prompt_token_ids=prompt_token_ids) + prompt_ids = outputs["prompt_ids"] + completion_ids = outputs["completion_ids"] + + assert isinstance(prompt_ids, list) + assert isinstance(completion_ids, list) + assert len(prompt_ids) == len(prompts) + assert len(completion_ids) == len(prompts) + assert prompt_ids == prompt_token_ids + def test_generate_with_params(self): prompts = ["Hello, AI!", "Tell me a joke"] completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[ @@ -536,6 +575,20 @@ def multiply(a: int, b: int) -> int: decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0]) assert "Multiplies two integers." in decoded_prompt + def test_generate_with_token_ids(self): + tokenizer = AutoTokenizer.from_pretrained(self.model_id) + prompts = ["Hello, AI!", "Tell me a joke"] + prompt_token_ids = tokenizer(prompts)["input_ids"] + outputs = self.client.generate(prompt_token_ids=prompt_token_ids) + prompt_ids = outputs["prompt_ids"] + completion_ids = outputs["completion_ids"] + + assert isinstance(prompt_ids, list) + assert isinstance(completion_ids, list) + assert len(prompt_ids) == len(prompts) + assert len(completion_ids) == len(prompts) + assert prompt_ids == prompt_token_ids + def test_generate_with_params(self): prompts = ["Hello, AI!", "Tell me a joke"] completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[ @@ -665,6 +718,20 @@ def multiply(a: int, b: int) -> int: decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0]) assert "Multiplies two integers." in decoded_prompt + def test_generate_with_token_ids(self): + tokenizer = AutoTokenizer.from_pretrained(self.model_id) + prompts = ["Hello, AI!", "Tell me a joke"] + prompt_token_ids = tokenizer(prompts)["input_ids"] + outputs = self.client.generate(prompt_token_ids=prompt_token_ids) + prompt_ids = outputs["prompt_ids"] + completion_ids = outputs["completion_ids"] + + assert isinstance(prompt_ids, list) + assert isinstance(completion_ids, list) + assert len(prompt_ids) == len(prompts) + assert len(completion_ids) == len(prompts) + assert prompt_ids == prompt_token_ids + def test_generate_with_params(self): prompts = ["Hello, AI!", "Tell me a joke"] completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[ From 3b356ac4b25de8a3b652ed3343eb83c6e629338b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 5 Mar 2026 19:20:14 +0000 Subject: [PATCH 3/7] consistency --- tests/test_vllm_client_server.py | 33 ++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/test_vllm_client_server.py b/tests/test_vllm_client_server.py index d8335378328..ec96854992d 100644 --- a/tests/test_vllm_client_server.py +++ b/tests/test_vllm_client_server.py @@ -444,12 +444,23 @@ def test_generate_with_token_ids(self): prompt_ids = outputs["prompt_ids"] completion_ids = outputs["completion_ids"] + # Check that the outputs are lists assert isinstance(prompt_ids, list) assert isinstance(completion_ids, list) + + # Check that the number of sequences are equal to the number of prompts assert len(prompt_ids) == len(prompts) assert len(completion_ids) == len(prompts) + + # Check that prompt_ids match the input token IDs assert prompt_ids == prompt_token_ids + # Check that the sequences are lists of integers + for seq in prompt_ids: + assert all(isinstance(tok, int) for tok in seq) + for seq in completion_ids: + assert all(isinstance(tok, int) for tok in seq) + def test_generate_with_params(self): prompts = ["Hello, AI!", "Tell me a joke"] completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[ @@ -583,12 +594,23 @@ def test_generate_with_token_ids(self): prompt_ids = outputs["prompt_ids"] completion_ids = outputs["completion_ids"] + # Check that the outputs are lists assert isinstance(prompt_ids, list) assert isinstance(completion_ids, list) + + # Check that the number of sequences are equal to the number of prompts assert len(prompt_ids) == len(prompts) assert len(completion_ids) == len(prompts) + + # Check that prompt_ids match the input token IDs assert prompt_ids == prompt_token_ids + # Check that the sequences are lists of integers + for seq in prompt_ids: + assert all(isinstance(tok, int) for tok in seq) + for seq in completion_ids: + assert all(isinstance(tok, int) for tok in seq) + def test_generate_with_params(self): prompts = ["Hello, AI!", "Tell me a joke"] completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[ @@ -726,12 +748,23 @@ def test_generate_with_token_ids(self): prompt_ids = outputs["prompt_ids"] completion_ids = outputs["completion_ids"] + # Check that the outputs are lists assert isinstance(prompt_ids, list) assert isinstance(completion_ids, list) + + # Check that the number of sequences are equal to the number of prompts assert len(prompt_ids) == len(prompts) assert len(completion_ids) == len(prompts) + + # Check that prompt_ids match the input token IDs assert prompt_ids == prompt_token_ids + # Check that the sequences are lists of integers + for seq in prompt_ids: + assert all(isinstance(tok, int) for tok in seq) + for seq in completion_ids: + assert all(isinstance(tok, int) for tok in seq) + def test_generate_with_params(self): prompts = ["Hello, AI!", "Tell me a joke"] completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[ From 82c4508f934bed685b852f78b4d114b96d145d2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 5 Mar 2026 19:53:20 +0000 Subject: [PATCH 4/7] fix --- tests/test_vllm_client_server.py | 8 ++-- trl/generation/vllm_client.py | 62 +++++++++++++------------------ trl/generation/vllm_generation.py | 4 +- trl/scripts/vllm_serve.py | 22 +++++------ 4 files changed, 39 insertions(+), 57 deletions(-) diff --git a/tests/test_vllm_client_server.py b/tests/test_vllm_client_server.py index ec96854992d..fe4eb184127 100644 --- a/tests/test_vllm_client_server.py +++ b/tests/test_vllm_client_server.py @@ -211,7 +211,7 @@ def test_generate_with_token_ids(self): tokenizer = AutoTokenizer.from_pretrained(self.model_id) prompts = ["Hello, AI!", "Tell me a joke"] prompt_token_ids = tokenizer(prompts)["input_ids"] - outputs = self.client.generate(prompt_token_ids=prompt_token_ids) + outputs = self.client.generate(prompt_token_ids) prompt_ids = outputs["prompt_ids"] completion_ids = outputs["completion_ids"] @@ -440,7 +440,7 @@ def test_generate_with_token_ids(self): tokenizer = AutoTokenizer.from_pretrained(self.model_id) prompts = ["Hello, AI!", "Tell me a joke"] prompt_token_ids = tokenizer(prompts)["input_ids"] - outputs = self.client.generate(prompt_token_ids=prompt_token_ids) + outputs = self.client.generate(prompt_token_ids) prompt_ids = outputs["prompt_ids"] completion_ids = outputs["completion_ids"] @@ -590,7 +590,7 @@ def test_generate_with_token_ids(self): tokenizer = AutoTokenizer.from_pretrained(self.model_id) prompts = ["Hello, AI!", "Tell me a joke"] prompt_token_ids = tokenizer(prompts)["input_ids"] - outputs = self.client.generate(prompt_token_ids=prompt_token_ids) + outputs = self.client.generate(prompt_token_ids) prompt_ids = outputs["prompt_ids"] completion_ids = outputs["completion_ids"] @@ -744,7 +744,7 @@ def test_generate_with_token_ids(self): tokenizer = AutoTokenizer.from_pretrained(self.model_id) prompts = ["Hello, AI!", "Tell me a joke"] prompt_token_ids = tokenizer(prompts)["input_ids"] - outputs = self.client.generate(prompt_token_ids=prompt_token_ids) + outputs = self.client.generate(prompt_token_ids) prompt_ids = outputs["prompt_ids"] completion_ids = outputs["completion_ids"] diff --git a/trl/generation/vllm_client.py b/trl/generation/vllm_client.py index d709a36a1a4..c1c1d962d8c 100644 --- a/trl/generation/vllm_client.py +++ b/trl/generation/vllm_client.py @@ -201,9 +201,8 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0): def generate( self, - prompts: list[str] | None = None, + prompts: list[str] | list[list[int]], images: list | None = None, - prompt_token_ids: list[list[int]] | None = None, n: int = 1, repetition_penalty: float = 1.0, temperature: float = 1.0, @@ -217,17 +216,13 @@ def generate( generation_kwargs: dict | None = None, ) -> dict[str, list[list[int]]]: """ - Generates model completions for the provided prompts or token IDs. - - Either `prompts` or `prompt_token_ids` must be provided, but not both. + Generates model completions for the provided prompts. Args: - prompts (`list[str]`, *optional*): - List of text prompts for which the model will generate completions. + prompts (`list[str]` or `list[list[int]]`): + List of text prompts or list of token ID lists for which the model will generate completions. images (`list[PIL.Image]`, *optional*): - List of PIL Images to send along with the prompts. Only valid when `prompts` is provided. - prompt_token_ids (`list[list[int]]` or `None`, *optional*): - List of tokenized prompts (list of list of token IDs) for which the model will generate completions. + List of PIL Images to send along with the prompts. Only valid when `prompts` is a list of strings. n (`int`, *optional*, defaults to `1`): Number of completions to generate for each prompt. repetition_penalty (`float`, *optional*, defaults to `1.0`): @@ -268,36 +263,29 @@ def generate( - `logprob_token_ids` (`list[list[list[int]]]`): Token IDs corresponding to each logprob, same shape as `logprobs`. """ - if prompt_token_ids is not None and prompts is not None: - raise ValueError("Only one of 'prompts' or 'prompt_token_ids' can be provided, not both.") - if prompt_token_ids is None and prompts is None: - raise ValueError("Either 'prompts' or 'prompt_token_ids' must be provided.") - url = f"{self.base_url}/generate/" - # Build the payload with whichever input mode is provided - payload = { - "n": n, - "repetition_penalty": repetition_penalty, - "temperature": temperature, - "top_p": top_p, - "top_k": top_k, - "min_p": min_p, - "max_tokens": max_tokens, - "logprobs": logprobs, - "truncate_prompt_tokens": truncate_prompt_tokens, - "structured_outputs_regex": structured_outputs_regex, - "generation_kwargs": generation_kwargs or {}, - } - - if prompt_token_ids is not None: - payload["prompt_token_ids"] = prompt_token_ids - else: - payload["prompts"] = prompts - if images is not None: - payload["images"] = [pil_to_base64(image) if image is not None else None for image in images] + # Convert PIL images to base64 strings + images = [pil_to_base64(img) for img in images] if images else None - response = self.session.post(url, json=payload) + response = self.session.post( + url, + json={ + "prompts": prompts, + "images": images, + "n": n, + "repetition_penalty": repetition_penalty, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "min_p": min_p, + "max_tokens": max_tokens, + "logprobs": logprobs, + "truncate_prompt_tokens": truncate_prompt_tokens, + "structured_outputs_regex": structured_outputs_regex, + "generation_kwargs": generation_kwargs or {}, + }, + ) if response.status_code == 200: json_response = response.json() return { diff --git a/trl/generation/vllm_generation.py b/trl/generation/vllm_generation.py index 2eb0a928107..1c4065bfc85 100644 --- a/trl/generation/vllm_generation.py +++ b/trl/generation/vllm_generation.py @@ -628,9 +628,7 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte ) else: ordered_set_of_prompt_ids = self.processing_class(text=ordered_set_of_prompts)["input_ids"] - output = self.vllm_client.generate( - prompt_token_ids=ordered_set_of_prompt_ids, **sampling_params - ) + output = self.vllm_client.generate(prompts=ordered_set_of_prompt_ids, **sampling_params) # Extract required fields and collect any extra fields for reward functions required_keys = {"prompt_ids", "completion_ids", "logprobs", "logprob_token_ids"} extra_fields = {k: v for k, v in output.items() if k not in required_keys} diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index af3b3ed9020..6bae0472fa0 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -495,9 +495,8 @@ async def get_world_size(): return {"world_size": script_args.tensor_parallel_size * script_args.data_parallel_size} class GenerateRequest(BaseModel): - prompts: list[str] | None = None + prompts: list[str] | list[list[int]] images: list[str] | None = None - prompt_token_ids: list[list[int]] | None = None n: int = 1 repetition_penalty: float = 1.0 temperature: float = 1.0 @@ -585,10 +584,8 @@ async def generate(request: GenerateRequest): } ``` """ - if request.prompt_token_ids is not None and request.prompts is not None: - raise ValueError("Only one of 'prompts' or 'prompt_token_ids' can be provided, not both.") - if request.prompt_token_ids is None and request.prompts is None: - raise ValueError("Either 'prompts' or 'prompt_token_ids' must be provided.") + # Detect whether prompts are text strings or pre-tokenized token ID lists + is_token_ids = request.prompts and isinstance(request.prompts[0], list) generation_kwargs = { "n": request.n, @@ -635,11 +632,11 @@ async def generate(request: GenerateRequest): generation_kwargs[structured_outputs_key] = structured_outputs sampling_params = SamplingParams(**generation_kwargs) - if request.prompt_token_ids is not None: - # Token IDs path: pass pre-tokenized prompts directly to vLLM workers - chunked_inputs = chunk_list(request.prompt_token_ids, script_args.data_parallel_size) - input_key = "prompt_token_ids" - placeholder = [[0]] + if is_token_ids: + # Token IDs path: wrap each list of token IDs as a TokensPrompt dict for vLLM + prompts = [{"prompt_token_ids": ids} for ids in request.prompts] + chunked_inputs = chunk_list(prompts, script_args.data_parallel_size) + placeholder = [{"prompt_token_ids": [0]}] else: # Text prompts path: build prompt dicts with optional images request.images = request.images or [None] * len(request.prompts) @@ -650,7 +647,6 @@ async def generate(request: GenerateRequest): row["multi_modal_data"] = {"image": Image.open(BytesIO(base64.b64decode(image)))} prompts.append(row) chunked_inputs = chunk_list(prompts, script_args.data_parallel_size) - input_key = "prompts" placeholder = [""] # Send inputs to each worker @@ -660,7 +656,7 @@ async def generate(request: GenerateRequest): # with vLLM's requirement, and we later ignore the result. if not chunk: chunk = placeholder - kwargs = {input_key: chunk, "sampling_params": sampling_params} + kwargs = {"prompts": chunk, "sampling_params": sampling_params} connection.send({"type": "call", "method": "generate", "kwargs": kwargs}) # Receive results From 3ea2fcff50fe54ed8e092afd4890f76cfa1fbf8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 5 Mar 2026 20:05:48 +0000 Subject: [PATCH 5/7] another fix --- trl/scripts/vllm_serve.py | 67 +++++++++++++++++---------------------- 1 file changed, 29 insertions(+), 38 deletions(-) diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 6bae0472fa0..bb69cf301ac 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -518,19 +518,14 @@ class GenerateResponse(BaseModel): @app.post("/generate/", response_model=GenerateResponse) async def generate(request: GenerateRequest): """ - Generates completions for the provided prompts or token IDs. - - Accepts either `prompts` (text strings, optionally with `images`) or `prompt_token_ids` (pre-tokenized). - Exactly one of `prompts` or `prompt_token_ids` must be provided. + Generates completions for the provided prompts. Args: request (`GenerateRequest`): - - `prompts` (list of `str`, *optional*): A list of prompts (text strings) for the model to generate - completions. + - `prompts` (list of `str` or list of list of `int`): A list of prompts. It accepts either text strings + or pre-tokenized token ID lists. When text strings are provided, `images` can optionally be included. - `images` (list of `str`, *optional*, default to `None`): A list of base64 encoded images to process along with prompts. - - `prompt_token_ids` (list of list of `int`, *optional*): A list of tokenized prompts for the model to - generate completions. - `n` (`int`, *optional*, defaults to `1`): Number of completions to generate for each prompt. - `repetition_penalty` (`float`, *optional*, defaults to `1.0`): Repetition penalty to apply during generation. @@ -571,21 +566,31 @@ async def generate(request: GenerateRequest): Example request (token IDs): ```json - {"prompt_token_ids": [[101, 102], [201, 202]]} - ``` + {"prompts": [[101, 102], [201, 202]]} Example response: ```json { - "prompt_ids": [[101, 102], [201, 202]], - "completion_ids": [[103, 104, 105], [203, 204, 205]], - "logprobs": [[[-0.1], [-0.2], [-0.3]], [[-0.4], [-0.5], [-0.6]]], - "logprob_token_ids": [[[103], [104], [105]], [[203], [204], [205]]] + "prompt_ids": [[101, 102], [201, 202]], "completion_ids": [[103, 104, 105], [203, 204, 205]], "logprobs": + [[[-0.1], [-0.2], [-0.3]], [[-0.4], [-0.5], [-0.6]]], "logprob_token_ids": [[[103], [104], [105]], [[203], + [204], [205]]] } ``` """ - # Detect whether prompts are text strings or pre-tokenized token ID lists - is_token_ids = request.prompts and isinstance(request.prompts[0], list) + # Build vLLM-compatible prompt inputs + if request.prompts and isinstance(request.prompts[0], list): + # Token IDs path: wrap each list of token IDs as a TokensPrompt dict for vLLM + prompts = [{"prompt_token_ids": ids} for ids in request.prompts] + else: + # Text prompts path: build prompt dicts with optional images + request.images = request.images or [None] * len(request.prompts) + + prompts = [] + for prompt, image in zip(request.prompts, request.images, strict=True): + row = {"prompt": prompt} + if image is not None: + row["multi_modal_data"] = {"image": Image.open(BytesIO(base64.b64decode(image)))} + prompts.append(row) generation_kwargs = { "n": request.n, @@ -632,38 +637,24 @@ async def generate(request: GenerateRequest): generation_kwargs[structured_outputs_key] = structured_outputs sampling_params = SamplingParams(**generation_kwargs) - if is_token_ids: - # Token IDs path: wrap each list of token IDs as a TokensPrompt dict for vLLM - prompts = [{"prompt_token_ids": ids} for ids in request.prompts] - chunked_inputs = chunk_list(prompts, script_args.data_parallel_size) - placeholder = [{"prompt_token_ids": [0]}] - else: - # Text prompts path: build prompt dicts with optional images - request.images = request.images or [None] * len(request.prompts) - prompts = [] - for prompt, image in zip(request.prompts, request.images, strict=True): - row = {"prompt": prompt} - if image is not None: - row["multi_modal_data"] = {"image": Image.open(BytesIO(base64.b64decode(image)))} - prompts.append(row) - chunked_inputs = chunk_list(prompts, script_args.data_parallel_size) - placeholder = [""] + # Evenly distribute prompts across DP ranks + chunked_prompts = chunk_list(prompts, script_args.data_parallel_size) - # Send inputs to each worker - for connection, chunk in zip(connections, chunked_inputs, strict=True): + # Send the prompts to each worker + for connection, prompts in zip(connections, chunked_prompts, strict=True): # When the number of prompts is less than data_parallel_size, some workers will receive empty prompts. # However, vLLM requires that we always send at least one prompt. So we send a placeholder prompt to comply # with vLLM's requirement, and we later ignore the result. - if not chunk: - chunk = placeholder - kwargs = {"prompts": chunk, "sampling_params": sampling_params} + if not prompts: + prompts = [""] + kwargs = {"prompts": prompts, "sampling_params": sampling_params} connection.send({"type": "call", "method": "generate", "kwargs": kwargs}) # Receive results all_outputs = [connection.recv() for connection in connections] # Handle empty prompts (see above) - all_outputs = [output for output, chunk in zip(all_outputs, chunked_inputs, strict=True) if chunk] + all_outputs = [output for output, prompts in zip(all_outputs, chunked_prompts, strict=True) if prompts] # Flatten and combine all results all_outputs = list(chain.from_iterable(all_outputs)) # from list of list to single list From 445f4ba764369eabf78654cf43fdc0dcf6492467 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 5 Mar 2026 20:25:33 +0000 Subject: [PATCH 6/7] fix docstring --- trl/scripts/vllm_serve.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index bb69cf301ac..6fe6b175e44 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -567,6 +567,7 @@ async def generate(request: GenerateRequest): Example request (token IDs): ```json {"prompts": [[101, 102], [201, 202]]} + ``` Example response: ```json From f033e63ee356ba67ef69b514fd42a5976b83b7b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 9 Mar 2026 17:23:14 +0000 Subject: [PATCH 7/7] revert doc modif --- trl/scripts/vllm_serve.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 3261c5d2fc5..77335b8cc13 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -567,9 +567,10 @@ async def generate(request: GenerateRequest): Example response: ```json { - "prompt_ids": [[101, 102], [201, 202]], "completion_ids": [[103, 104, 105], [203, 204, 205]], "logprobs": - [[[-0.1], [-0.2], [-0.3]], [[-0.4], [-0.5], [-0.6]]], "logprob_token_ids": [[[103], [104], [105]], [[203], - [204], [205]]] + "prompt_ids": [[101, 102], [201, 202]], + "completion_ids": [[103, 104, 105], [203, 204, 205]], + "logprobs": [[[-0.1], [-0.2], [-0.3]], [[-0.4], [-0.5], [-0.6]]], + "logprob_token_ids": [[[103], [104], [105]], [[203], [204], [205]]] } ``` """