diff --git a/tests/test_vllm_client_server.py b/tests/test_vllm_client_server.py index 7c14af14e9b..fe4eb184127 100644 --- a/tests/test_vllm_client_server.py +++ b/tests/test_vllm_client_server.py @@ -207,6 +207,31 @@ def multiply(a: int, b: int) -> int: decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0]) assert "Multiplies two integers." in decoded_prompt + def test_generate_with_token_ids(self): + tokenizer = AutoTokenizer.from_pretrained(self.model_id) + prompts = ["Hello, AI!", "Tell me a joke"] + prompt_token_ids = tokenizer(prompts)["input_ids"] + outputs = self.client.generate(prompt_token_ids) + prompt_ids = outputs["prompt_ids"] + completion_ids = outputs["completion_ids"] + + # Check that the outputs are lists + assert isinstance(prompt_ids, list) + assert isinstance(completion_ids, list) + + # Check that the number of sequences are equal to the number of prompts + assert len(prompt_ids) == len(prompts) + assert len(completion_ids) == len(prompts) + + # Check that prompt_ids match the input token IDs + assert prompt_ids == prompt_token_ids + + # Check that the sequences are lists of integers + for seq in prompt_ids: + assert all(isinstance(tok, int) for tok in seq) + for seq in completion_ids: + assert all(isinstance(tok, int) for tok in seq) + def test_generate_with_params(self): prompts = ["Hello, AI!", "Tell me a joke"] completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[ @@ -411,6 +436,31 @@ def multiply(a: int, b: int) -> int: decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0]) assert "Multiplies two integers." in decoded_prompt + def test_generate_with_token_ids(self): + tokenizer = AutoTokenizer.from_pretrained(self.model_id) + prompts = ["Hello, AI!", "Tell me a joke"] + prompt_token_ids = tokenizer(prompts)["input_ids"] + outputs = self.client.generate(prompt_token_ids) + prompt_ids = outputs["prompt_ids"] + completion_ids = outputs["completion_ids"] + + # Check that the outputs are lists + assert isinstance(prompt_ids, list) + assert isinstance(completion_ids, list) + + # Check that the number of sequences are equal to the number of prompts + assert len(prompt_ids) == len(prompts) + assert len(completion_ids) == len(prompts) + + # Check that prompt_ids match the input token IDs + assert prompt_ids == prompt_token_ids + + # Check that the sequences are lists of integers + for seq in prompt_ids: + assert all(isinstance(tok, int) for tok in seq) + for seq in completion_ids: + assert all(isinstance(tok, int) for tok in seq) + def test_generate_with_params(self): prompts = ["Hello, AI!", "Tell me a joke"] completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[ @@ -536,6 +586,31 @@ def multiply(a: int, b: int) -> int: decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0]) assert "Multiplies two integers." in decoded_prompt + def test_generate_with_token_ids(self): + tokenizer = AutoTokenizer.from_pretrained(self.model_id) + prompts = ["Hello, AI!", "Tell me a joke"] + prompt_token_ids = tokenizer(prompts)["input_ids"] + outputs = self.client.generate(prompt_token_ids) + prompt_ids = outputs["prompt_ids"] + completion_ids = outputs["completion_ids"] + + # Check that the outputs are lists + assert isinstance(prompt_ids, list) + assert isinstance(completion_ids, list) + + # Check that the number of sequences are equal to the number of prompts + assert len(prompt_ids) == len(prompts) + assert len(completion_ids) == len(prompts) + + # Check that prompt_ids match the input token IDs + assert prompt_ids == prompt_token_ids + + # Check that the sequences are lists of integers + for seq in prompt_ids: + assert all(isinstance(tok, int) for tok in seq) + for seq in completion_ids: + assert all(isinstance(tok, int) for tok in seq) + def test_generate_with_params(self): prompts = ["Hello, AI!", "Tell me a joke"] completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[ @@ -665,6 +740,31 @@ def multiply(a: int, b: int) -> int: decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0]) assert "Multiplies two integers." in decoded_prompt + def test_generate_with_token_ids(self): + tokenizer = AutoTokenizer.from_pretrained(self.model_id) + prompts = ["Hello, AI!", "Tell me a joke"] + prompt_token_ids = tokenizer(prompts)["input_ids"] + outputs = self.client.generate(prompt_token_ids) + prompt_ids = outputs["prompt_ids"] + completion_ids = outputs["completion_ids"] + + # Check that the outputs are lists + assert isinstance(prompt_ids, list) + assert isinstance(completion_ids, list) + + # Check that the number of sequences are equal to the number of prompts + assert len(prompt_ids) == len(prompts) + assert len(completion_ids) == len(prompts) + + # Check that prompt_ids match the input token IDs + assert prompt_ids == prompt_token_ids + + # Check that the sequences are lists of integers + for seq in prompt_ids: + assert all(isinstance(tok, int) for tok in seq) + for seq in completion_ids: + assert all(isinstance(tok, int) for tok in seq) + def test_generate_with_params(self): prompts = ["Hello, AI!", "Tell me a joke"] completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[ diff --git a/trl/generation/vllm_client.py b/trl/generation/vllm_client.py index 411317e9057..c1c1d962d8c 100644 --- a/trl/generation/vllm_client.py +++ b/trl/generation/vllm_client.py @@ -201,7 +201,7 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0): def generate( self, - prompts: list[str], + prompts: list[str] | list[list[int]], images: list | None = None, n: int = 1, repetition_penalty: float = 1.0, @@ -219,10 +219,10 @@ def generate( Generates model completions for the provided prompts. Args: - prompts (`list[str]`): - List of text prompts for which the model will generate completions. + prompts (`list[str]` or `list[list[int]]`): + List of text prompts or list of token ID lists for which the model will generate completions. images (`list[PIL.Image]`, *optional*): - List of PIL Images to send along with the prompts. + List of PIL Images to send along with the prompts. Only valid when `prompts` is a list of strings. n (`int`, *optional*, defaults to `1`): Number of completions to generate for each prompt. repetition_penalty (`float`, *optional*, defaults to `1.0`): diff --git a/trl/generation/vllm_generation.py b/trl/generation/vllm_generation.py index 8ba6fc2a857..1c4065bfc85 100644 --- a/trl/generation/vllm_generation.py +++ b/trl/generation/vllm_generation.py @@ -627,7 +627,8 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte chat_template=chat_template, ) else: - output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params) + ordered_set_of_prompt_ids = self.processing_class(text=ordered_set_of_prompts)["input_ids"] + output = self.vllm_client.generate(prompts=ordered_set_of_prompt_ids, **sampling_params) # Extract required fields and collect any extra fields for reward functions required_keys = {"prompt_ids", "completion_ids", "logprobs", "logprob_token_ids"} extra_fields = {k: v for k, v in output.items() if k not in required_keys} diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 5c5962eeb27..77335b8cc13 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -490,7 +490,7 @@ async def get_world_size(): return {"world_size": script_args.tensor_parallel_size * script_args.data_parallel_size} class GenerateRequest(BaseModel): - prompts: list[str] + prompts: list[str] | list[list[int]] images: list[str] | None = None n: int = 1 repetition_penalty: float = 1.0 @@ -517,7 +517,8 @@ async def generate(request: GenerateRequest): Args: request (`GenerateRequest`): - - `prompts` (list of `str`): A list of prompts (text strings) for the model to generate completions. + - `prompts` (list of `str` or list of list of `int`): A list of prompts. It accepts either text strings + or pre-tokenized token ID lists. When text strings are provided, `images` can optionally be included. - `images` (list of `str`, *optional*, default to `None`): A list of base64 encoded images to process along with prompts. - `n` (`int`, *optional*, defaults to `1`): Number of completions to generate for each prompt. @@ -553,11 +554,16 @@ async def generate(request: GenerateRequest): - `logprob_token_ids` (list of list of list of `int`): Token IDs corresponding to each logprob, same shape as `logprobs`. - Example request: + Example request (text prompts): ```json {"prompts": ["Hello world", "What is AI?"]} ``` + Example request (token IDs): + ```json + {"prompts": [[101, 102], [201, 202]]} + ``` + Example response: ```json { @@ -568,14 +574,20 @@ async def generate(request: GenerateRequest): } ``` """ - request.images = request.images or [None] * len(request.prompts) - - prompts = [] - for prompt, image in zip(request.prompts, request.images, strict=True): - row = {"prompt": prompt} - if image is not None: - row["multi_modal_data"] = {"image": Image.open(BytesIO(base64.b64decode(image)))} - prompts.append(row) + # Build vLLM-compatible prompt inputs + if request.prompts and isinstance(request.prompts[0], list): + # Token IDs path: wrap each list of token IDs as a TokensPrompt dict for vLLM + prompts = [{"prompt_token_ids": ids} for ids in request.prompts] + else: + # Text prompts path: build prompt dicts with optional images + request.images = request.images or [None] * len(request.prompts) + + prompts = [] + for prompt, image in zip(request.prompts, request.images, strict=True): + row = {"prompt": prompt} + if image is not None: + row["multi_modal_data"] = {"image": Image.open(BytesIO(base64.b64decode(image)))} + prompts.append(row) generation_kwargs = { "n": request.n,