Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 100 additions & 0 deletions tests/test_vllm_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,31 @@ def multiply(a: int, b: int) -> int:
decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0])
assert "Multiplies two integers." in decoded_prompt

def test_generate_with_token_ids(self):
    """Check that `generate` accepts pre-tokenized prompts and returns them unchanged as `prompt_ids`."""
    texts = ["Hello, AI!", "Tell me a joke"]
    encoded = AutoTokenizer.from_pretrained(self.model_id)(texts)["input_ids"]
    result = self.client.generate(encoded)
    returned_prompts, returned_completions = result["prompt_ids"], result["completion_ids"]

    # Both fields must come back as lists
    for field in (returned_prompts, returned_completions):
        assert isinstance(field, list)

    # The number of sequences is equal to the number of prompts
    for field in (returned_prompts, returned_completions):
        assert len(field) == len(texts)

    # The returned prompt IDs are exactly the token IDs that were sent
    assert returned_prompts == encoded

    # Every sequence is made of integers only
    for field in (returned_prompts, returned_completions):
        for sequence in field:
            assert all(isinstance(token, int) for token in sequence)

def test_generate_with_params(self):
prompts = ["Hello, AI!", "Tell me a joke"]
completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
Expand Down Expand Up @@ -411,6 +436,31 @@ def multiply(a: int, b: int) -> int:
decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0])
assert "Multiplies two integers." in decoded_prompt

def test_generate_with_token_ids(self):
    """Check that `generate` accepts pre-tokenized prompts and returns them unchanged as `prompt_ids`."""
    texts = ["Hello, AI!", "Tell me a joke"]
    encoded = AutoTokenizer.from_pretrained(self.model_id)(texts)["input_ids"]
    result = self.client.generate(encoded)
    returned_prompts, returned_completions = result["prompt_ids"], result["completion_ids"]

    # Both fields must come back as lists
    for field in (returned_prompts, returned_completions):
        assert isinstance(field, list)

    # The number of sequences is equal to the number of prompts
    for field in (returned_prompts, returned_completions):
        assert len(field) == len(texts)

    # The returned prompt IDs are exactly the token IDs that were sent
    assert returned_prompts == encoded

    # Every sequence is made of integers only
    for field in (returned_prompts, returned_completions):
        for sequence in field:
            assert all(isinstance(token, int) for token in sequence)

def test_generate_with_params(self):
prompts = ["Hello, AI!", "Tell me a joke"]
completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
Expand Down Expand Up @@ -536,6 +586,31 @@ def multiply(a: int, b: int) -> int:
decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0])
assert "Multiplies two integers." in decoded_prompt

def test_generate_with_token_ids(self):
    """Check that `generate` accepts pre-tokenized prompts and returns them unchanged as `prompt_ids`."""
    texts = ["Hello, AI!", "Tell me a joke"]
    encoded = AutoTokenizer.from_pretrained(self.model_id)(texts)["input_ids"]
    result = self.client.generate(encoded)
    returned_prompts, returned_completions = result["prompt_ids"], result["completion_ids"]

    # Both fields must come back as lists
    for field in (returned_prompts, returned_completions):
        assert isinstance(field, list)

    # The number of sequences is equal to the number of prompts
    for field in (returned_prompts, returned_completions):
        assert len(field) == len(texts)

    # The returned prompt IDs are exactly the token IDs that were sent
    assert returned_prompts == encoded

    # Every sequence is made of integers only
    for field in (returned_prompts, returned_completions):
        for sequence in field:
            assert all(isinstance(token, int) for token in sequence)

def test_generate_with_params(self):
prompts = ["Hello, AI!", "Tell me a joke"]
completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
Expand Down Expand Up @@ -665,6 +740,31 @@ def multiply(a: int, b: int) -> int:
decoded_prompt = tokenizer.decode(outputs["prompt_ids"][0])
assert "Multiplies two integers." in decoded_prompt

def test_generate_with_token_ids(self):
    """Check that `generate` accepts pre-tokenized prompts and returns them unchanged as `prompt_ids`."""
    texts = ["Hello, AI!", "Tell me a joke"]
    encoded = AutoTokenizer.from_pretrained(self.model_id)(texts)["input_ids"]
    result = self.client.generate(encoded)
    returned_prompts, returned_completions = result["prompt_ids"], result["completion_ids"]

    # Both fields must come back as lists
    for field in (returned_prompts, returned_completions):
        assert isinstance(field, list)

    # The number of sequences is equal to the number of prompts
    for field in (returned_prompts, returned_completions):
        assert len(field) == len(texts)

    # The returned prompt IDs are exactly the token IDs that were sent
    assert returned_prompts == encoded

    # Every sequence is made of integers only
    for field in (returned_prompts, returned_completions):
        for sequence in field:
            assert all(isinstance(token, int) for token in sequence)

def test_generate_with_params(self):
prompts = ["Hello, AI!", "Tell me a joke"]
completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[
Expand Down
8 changes: 4 additions & 4 deletions trl/generation/vllm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def check_server(self, total_timeout: float = 0.0, retry_interval: float = 2.0):

def generate(
self,
prompts: list[str],
prompts: list[str] | list[list[int]],
images: list | None = None,
n: int = 1,
repetition_penalty: float = 1.0,
Expand All @@ -219,10 +219,10 @@ def generate(
Generates model completions for the provided prompts.

Args:
prompts (`list[str]`):
List of text prompts for which the model will generate completions.
prompts (`list[str]` or `list[list[int]]`):
List of text prompts or list of token ID lists for which the model will generate completions.
images (`list[PIL.Image]`, *optional*):
List of PIL Images to send along with the prompts.
List of PIL Images to send along with the prompts. Only valid when `prompts` is a list of strings.
n (`int`, *optional*, defaults to `1`):
Number of completions to generate for each prompt.
repetition_penalty (`float`, *optional*, defaults to `1.0`):
Expand Down
3 changes: 2 additions & 1 deletion trl/generation/vllm_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,8 @@ def generate(self, prompts: list, num_generations: int, profiler: ProfilingConte
chat_template=chat_template,
)
else:
output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params)
ordered_set_of_prompt_ids = self.processing_class(text=ordered_set_of_prompts)["input_ids"]
output = self.vllm_client.generate(prompts=ordered_set_of_prompt_ids, **sampling_params)
# Extract required fields and collect any extra fields for reward functions
required_keys = {"prompt_ids", "completion_ids", "logprobs", "logprob_token_ids"}
extra_fields = {k: v for k, v in output.items() if k not in required_keys}
Expand Down
34 changes: 23 additions & 11 deletions trl/scripts/vllm_serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ async def get_world_size():
return {"world_size": script_args.tensor_parallel_size * script_args.data_parallel_size}

class GenerateRequest(BaseModel):
prompts: list[str]
prompts: list[str] | list[list[int]]
Comment thread
AmineDiro marked this conversation as resolved.
images: list[str] | None = None
n: int = 1
repetition_penalty: float = 1.0
Expand All @@ -517,7 +517,8 @@ async def generate(request: GenerateRequest):

Args:
request (`GenerateRequest`):
- `prompts` (list of `str`): A list of prompts (text strings) for the model to generate completions.
- `prompts` (list of `str` or list of list of `int`): A list of prompts. It accepts either text strings
or pre-tokenized token ID lists. When text strings are provided, `images` can optionally be included.
- `images` (list of `str`, *optional*, defaults to `None`): A list of base64 encoded images to process
along with prompts.
- `n` (`int`, *optional*, defaults to `1`): Number of completions to generate for each prompt.
Expand Down Expand Up @@ -553,11 +554,16 @@ async def generate(request: GenerateRequest):
- `logprob_token_ids` (list of list of list of `int`): Token IDs corresponding to each logprob, same
shape as `logprobs`.

Example request:
Example request (text prompts):
```json
{"prompts": ["Hello world", "What is AI?"]}
```

Example request (token IDs):
```json
{"prompts": [[101, 102], [201, 202]]}
```

Example response:
Comment thread
qgallouedec marked this conversation as resolved.
```json
{
Expand All @@ -568,14 +574,20 @@ async def generate(request: GenerateRequest):
}
```
"""
request.images = request.images or [None] * len(request.prompts)

prompts = []
for prompt, image in zip(request.prompts, request.images, strict=True):
row = {"prompt": prompt}
if image is not None:
row["multi_modal_data"] = {"image": Image.open(BytesIO(base64.b64decode(image)))}
prompts.append(row)
# Build vLLM-compatible prompt inputs
if request.prompts and isinstance(request.prompts[0], list):
Comment thread
qgallouedec marked this conversation as resolved.
# Token IDs path: wrap each list of token IDs as a TokensPrompt dict for vLLM
prompts = [{"prompt_token_ids": ids} for ids in request.prompts]
else:
# Text prompts path: build prompt dicts with optional images
request.images = request.images or [None] * len(request.prompts)

prompts = []
for prompt, image in zip(request.prompts, request.images, strict=True):
row = {"prompt": prompt}
if image is not None:
row["multi_modal_data"] = {"image": Image.open(BytesIO(base64.b64decode(image)))}
prompts.append(row)

generation_kwargs = {
"n": request.n,
Expand Down
Loading