Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 187 additions & 1 deletion tests/test_vllm_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import pytest
from packaging.version import Version
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
from transformers.testing_utils import torch_device

from trl.generation.vllm_client import VLLMClient
Expand All @@ -31,6 +31,7 @@
kill_process,
require_3_accelerators,
require_torch_multi_accelerator,
require_vision,
require_vllm,
)

Expand Down Expand Up @@ -207,6 +208,31 @@ def multiply(a: int, b: int) -> int:
decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0])
assert "Multiplies two integers." in decoded_prompt

def test_generate_with_token_ids(self):
    """Send pre-tokenized prompts to the server and validate the shape and types of what comes back."""
    texts = ["Hello, AI!", "Tell me a joke"]
    expected_prompt_ids = AutoTokenizer.from_pretrained(self.model_id)(texts)["input_ids"]

    result = self.client.generate(expected_prompt_ids)
    returned_prompt_ids, returned_completion_ids = result["prompt_ids"], result["completion_ids"]

    # Both fields must come back as plain Python lists
    assert isinstance(returned_prompt_ids, list)
    assert isinstance(returned_completion_ids, list)

    # One entry per submitted prompt
    expected_count = len(texts)
    assert len(returned_prompt_ids) == expected_count
    assert len(returned_completion_ids) == expected_count

    # The returned prompt IDs must be exactly the token IDs that were sent
    assert returned_prompt_ids == expected_prompt_ids

    # Every sequence (prompts first, then completions) must contain only integers
    for sequences in (returned_prompt_ids, returned_completion_ids):
        for seq in sequences:
            assert all(isinstance(tok, int) for tok in seq)

def test_generate_with_params(self):
prompts = ["Hello, AI!", "Tell me a joke"]
completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
Expand Down Expand Up @@ -411,6 +437,31 @@ def multiply(a: int, b: int) -> int:
decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0])
assert "Multiplies two integers." in decoded_prompt

def test_generate_with_token_ids(self):
    """Send pre-tokenized prompts to the server and validate the shape and types of what comes back."""
    texts = ["Hello, AI!", "Tell me a joke"]
    expected_prompt_ids = AutoTokenizer.from_pretrained(self.model_id)(texts)["input_ids"]

    result = self.client.generate(expected_prompt_ids)
    returned_prompt_ids, returned_completion_ids = result["prompt_ids"], result["completion_ids"]

    # Both fields must come back as plain Python lists
    assert isinstance(returned_prompt_ids, list)
    assert isinstance(returned_completion_ids, list)

    # One entry per submitted prompt
    expected_count = len(texts)
    assert len(returned_prompt_ids) == expected_count
    assert len(returned_completion_ids) == expected_count

    # The returned prompt IDs must be exactly the token IDs that were sent
    assert returned_prompt_ids == expected_prompt_ids

    # Every sequence (prompts first, then completions) must contain only integers
    for sequences in (returned_prompt_ids, returned_completion_ids):
        for seq in sequences:
            assert all(isinstance(tok, int) for tok in seq)

def test_generate_with_params(self):
prompts = ["Hello, AI!", "Tell me a joke"]
completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
Expand Down Expand Up @@ -536,6 +587,31 @@ def multiply(a: int, b: int) -> int:
decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0])
assert "Multiplies two integers." in decoded_prompt

def test_generate_with_token_ids(self):
    """Send pre-tokenized prompts to the server and validate the shape and types of what comes back."""
    texts = ["Hello, AI!", "Tell me a joke"]
    expected_prompt_ids = AutoTokenizer.from_pretrained(self.model_id)(texts)["input_ids"]

    result = self.client.generate(expected_prompt_ids)
    returned_prompt_ids, returned_completion_ids = result["prompt_ids"], result["completion_ids"]

    # Both fields must come back as plain Python lists
    assert isinstance(returned_prompt_ids, list)
    assert isinstance(returned_completion_ids, list)

    # One entry per submitted prompt
    expected_count = len(texts)
    assert len(returned_prompt_ids) == expected_count
    assert len(returned_completion_ids) == expected_count

    # The returned prompt IDs must be exactly the token IDs that were sent
    assert returned_prompt_ids == expected_prompt_ids

    # Every sequence (prompts first, then completions) must contain only integers
    for sequences in (returned_prompt_ids, returned_completion_ids):
        for seq in sequences:
            assert all(isinstance(tok, int) for tok in seq)

def test_generate_with_params(self):
prompts = ["Hello, AI!", "Tell me a joke"]
completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
Expand Down Expand Up @@ -665,6 +741,31 @@ def multiply(a: int, b: int) -> int:
decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0])
assert "Multiplies two integers." in decoded_prompt

def test_generate_with_token_ids(self):
    """Send pre-tokenized prompts to the server and validate the shape and types of what comes back."""
    texts = ["Hello, AI!", "Tell me a joke"]
    expected_prompt_ids = AutoTokenizer.from_pretrained(self.model_id)(texts)["input_ids"]

    result = self.client.generate(expected_prompt_ids)
    returned_prompt_ids, returned_completion_ids = result["prompt_ids"], result["completion_ids"]

    # Both fields must come back as plain Python lists
    assert isinstance(returned_prompt_ids, list)
    assert isinstance(returned_completion_ids, list)

    # One entry per submitted prompt
    expected_count = len(texts)
    assert len(returned_prompt_ids) == expected_count
    assert len(returned_completion_ids) == expected_count

    # The returned prompt IDs must be exactly the token IDs that were sent
    assert returned_prompt_ids == expected_prompt_ids

    # Every sequence (prompts first, then completions) must contain only integers
    for sequences in (returned_prompt_ids, returned_completion_ids):
        for seq in sequences:
            assert all(isinstance(tok, int) for tok in seq)

def test_generate_with_params(self):
prompts = ["Hello, AI!", "Tell me a joke"]
completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
Expand Down Expand Up @@ -774,3 +875,88 @@ def teardown_class(cls):
# vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to
# kill the server process and its children explicitly.
kill_process(cls.server_process)


@pytest.mark.slow
@require_vllm
@require_vision
class TestVLLMClientServerVLM(TrlTestCase):
    """Generation tests that send pre-tokenized prompts plus images to a `trl vllm-serve` process."""

    model_id = "Qwen/Qwen2.5-VL-3B-Instruct"

    @classmethod
    def setup_class(cls):
        # Launch the server as a child process, capturing its stdout/stderr through pipes
        server_cmd = ["trl", "vllm-serve", "--model", cls.model_id]
        cls.server_process = subprocess.Popen(server_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

        # These tests only generate, so the client is created without initializing a communicator
        cls.client = VLLMClient(connection_timeout=240, host="localhost")

    def test_generate_with_token_ids_and_image(self):
        """Batch of two tokenized prompts: the first carries two images, the second carries one."""
        from PIL import Image

        red, blue, green = (Image.new("RGB", (64, 64), color=name) for name in ("red", "blue", "green"))

        two_image_turn = {
            "role": "user",
            "content": [
                {"type": "image", "image": red},
                {"type": "image", "image": blue},
                {"type": "text", "text": "What are the differences between these two images?"},
            ],
        }
        one_image_turn = {
            "role": "user",
            "content": [
                {"type": "image", "image": green},
                {"type": "text", "text": "What is the color of this image?"},
            ],
        }
        conversations = [[two_image_turn], [one_image_turn]]

        processor = AutoProcessor.from_pretrained(self.model_id)
        token_ids = processor.apply_chat_template(
            conversation=conversations, tokenize=True, add_generation_prompt=True
        )

        # The images are passed as one list per prompt, aligned with the tokenized prompts
        result = self.client.generate(token_ids, images=[[red, blue], [green]], max_tokens=64)
        returned_prompt_ids, returned_completion_ids = result["prompt_ids"], result["completion_ids"]

        assert len(returned_prompt_ids) == 2
        assert len(returned_completion_ids) == 2
        # Only the first sequence of each field is type-checked here
        for first_sequence in (returned_prompt_ids[0], returned_completion_ids[0]):
            assert all(isinstance(tok, int) for tok in first_sequence)

    def test_generate_with_token_ids_mixed_images(self):
        """Batch of two tokenized prompts where only the first one comes with an image."""
        from PIL import Image

        red = Image.new("RGB", (64, 64), color="red")

        with_image = {
            "role": "user",
            "content": [{"type": "image", "image": red}, {"type": "text", "text": "Describe this image."}],
        }
        text_only = {"role": "user", "content": [{"type": "text", "text": "What is 1+1?"}]}
        conversations = [[with_image], [text_only]]

        processor = AutoProcessor.from_pretrained(self.model_id)
        token_ids = processor.apply_chat_template(
            conversation=conversations, tokenize=True, add_generation_prompt=True
        )

        # `None` marks the prompt that has no image attached
        result = self.client.generate(token_ids, images=[[red], None], max_tokens=64)
        returned_prompt_ids, returned_completion_ids = result["prompt_ids"], result["completion_ids"]

        assert len(returned_prompt_ids) == 2
        assert len(returned_completion_ids) == 2
        # Both sequences of each field (prompts first, then completions) must hold only integers
        for sequences in (returned_prompt_ids, returned_completion_ids):
            for index in (0, 1):
                assert all(isinstance(tok, int) for tok in sequences[index])

    @classmethod
    def teardown_class(cls):
        # Explicitly kill the server process started in `setup_class`
        kill_process(cls.server_process)
19 changes: 12 additions & 7 deletions trl/generation/vllm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):

def generate(
self,
prompts: list[str],
prompts: list[str] | list[list[int]],
images: list | None = None,
n: int = 1,
repetition_penalty: float = 1.0,
Expand All @@ -219,10 +219,11 @@ def generate(
Generates model completions for the provided prompts.

Args:
prompts (`list[str]`):
List of text prompts for which the model will generate completions.
images (`list[PIL.Image]`, *optional*):
List of PIL Images to send along with the prompts.
prompts (`list[str]` or `list[list[int]]`):
List of text prompts or list of token ID lists for which the model will generate completions.
images (`list[list[PIL.Image] | None]`, *optional*):
List of image lists for VLM support. Each element is a list of PIL images for the corresponding prompt,
or `None` if no images for that prompt.
n (`int`, *optional*, defaults to `1`):
Number of completions to generate for each prompt.
repetition_penalty (`float`, *optional*, defaults to `1.0`):
Expand Down Expand Up @@ -265,8 +266,12 @@ def generate(
"""
url = f"{self.base_url}/generate/"

# Convert PIL images to base64 strings
images = [pil_to_base64(img) for img in images] if images else None
# Convert PIL images to base64 strings. Each element is a list of images for the corresponding prompt,
# or None if no images for that prompt.
if images:
images = [
[pil_to_base64(img) for img in img_list] if img_list is not None else None for img_list in images
Comment thread
qgallouedec marked this conversation as resolved.
Comment thread
qgallouedec marked this conversation as resolved.
]
Comment thread
cursor[bot] marked this conversation as resolved.

response = self.session.post(
url,
Expand Down
3 changes: 2 additions & 1 deletion trl/generation/vllm_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,8 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte
chat_template=chat_template,
)
else:
output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params)
ordered_set_of_prompt_ids = self.processing_class(text=ordered_set_of_prompts)["input_ids"]
output = self.vllm_client.generate(prompts=ordered_set_of_prompt_ids, **sampling_params)
# Extract required fields and collect any extra fields for reward functions
required_keys = {"prompt_ids", "completion_ids", "logprobs", "logprob_token_ids"}
extra_fields = {k: v for k, v in output.items() if k not in required_keys}
Expand Down
36 changes: 22 additions & 14 deletions trl/scripts/vllm_serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,8 +495,8 @@ async def get_world_size():
return {"world_size": script_args.tensor_parallel_size * script_args.data_parallel_size}

class GenerateRequest(BaseModel):
prompts: list[str]
images: list[str] | None = None
prompts: list[str] | list[list[int]]
images: list[list[str] | None] | None = None
Comment thread
qgallouedec marked this conversation as resolved.
n: int = 1
repetition_penalty: float = 1.0
temperature: float = 1.0
Expand All @@ -522,9 +522,10 @@ async def generate(request: GenerateRequest):

Args:
request (`GenerateRequest`):
- `prompts` (list of `str`): A list of prompts (text strings) for the model to generate completions.
- `images` (list of `str`, *optional*, default to `None`): A list of base64 encoded images to process
along with prompts.
- `prompts` (list of `str` or list of list of `int`): A list of prompts. It accepts either text strings
or pre-tokenized token ID lists. In both cases, `images` can optionally be included.
- `images` (list of list of `str` or `None`, *optional*): A list of image lists. Each element is a list
of base64-encoded images for the corresponding prompt, or `None` if no images for that prompt.
- `n` (`int`, *optional*, defaults to `1`): Number of completions to generate for each prompt.
- `repetition_penalty` (`float`, *optional*, defaults to `1.0`): Repetition penalty to apply during
generation.
Expand Down Expand Up @@ -558,28 +559,35 @@ async def generate(request: GenerateRequest):
- `logprob_token_ids` (list of list of list of `int`): Token IDs corresponding to each logprob, same
shape as `logprobs`.

Example request:
Example request (text prompts):
```json
{"prompts": ["Hello world", "What is AI?"]}
```

Example request (token IDs):
```json
{"prompts": [[101, 102], [201, 202]]}
```

Example response:
```json
{
"prompt_ids": [[101, 102], [201, 202]],
"completion_ids": [[103, 104, 105], [203, 204, 205]],
"logprobs": [[[-0.1], [-0.2], [-0.3]], [[-0.4], [-0.5], [-0.6]]],
"logprob_token_ids": [[[103], [104], [105]], [[203], [204], [205]]]
"prompt_ids": [[101, 102], [201, 202]], "completion_ids": [[103, 104, 105], [203, 204, 205]], "logprobs":
[[[-0.1], [-0.2], [-0.3]], [[-0.4], [-0.5], [-0.6]]], "logprob_token_ids": [[[103], [104], [105]], [[203],
[204], [205]]]
}
```
"""
# Build vLLM-compatible prompt inputs
is_token_ids = request.prompts and isinstance(request.prompts[0], list)
request.images = request.images or [None] * len(request.prompts)

prompts = []
for prompt, image in zip(request.prompts, request.images, strict=True):
row = {"prompt": prompt}
if image is not None:
row["multi_modal_data"] = {"image": Image.open(BytesIO(base64.b64decode(image)))}
for prompt, image_list in zip(request.prompts, request.images, strict=True):
row = {"prompt_token_ids": prompt} if is_token_ids else {"prompt": prompt}
if image_list is not None:
decoded_images = [Image.open(BytesIO(base64.b64decode(img))) for img in image_list]
row["multi_modal_data"] = {"image": decoded_images if len(decoded_images) > 1 else decoded_images[0]}
Comment thread
cursor[bot] marked this conversation as resolved.
Outdated
Comment thread
qgallouedec marked this conversation as resolved.
Outdated
prompts.append(row)

generation_kwargs = {
Expand Down
Loading