Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 89 additions & 8 deletions tests/test_vllm_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@

import os
import subprocess
from types import SimpleNamespace

import pytest
from transformers import AutoModelForCausalLM
from transformers.testing_utils import torch_device

from trl.generation.vllm_client import VLLMClient
from trl.generation.vllm_generation import extract_logprobs
from trl.import_utils import is_vllm_available
from trl.scripts.vllm_serve import chunk_list

Expand Down Expand Up @@ -62,6 +64,59 @@ def test_any_dtype(self):
]


class TestExtractLogprobs(TrlTestCase):
    """Unit tests for `extract_logprobs`, driven by lightweight `SimpleNamespace` stand-ins.

    The fixtures only mimic the attributes these tests exercise: each request exposes `.outputs`, each output
    exposes `.logprobs`, and each per-position mapping goes from token ID to an object with `.rank` and `.logprob`.
    """

    def test_extract_logprobs_sorts_by_rank_and_replaces_nan(self):
        """Entries come back ordered by ascending rank, and a NaN logprob is reported as `None`."""

        def entry(rank, logprob):
            # One candidate token at a given position.
            return SimpleNamespace(rank=rank, logprob=logprob)

        def request_output(positions):
            # A request holding a single completion whose per-position logprobs are `positions`.
            return SimpleNamespace(outputs=[SimpleNamespace(logprobs=positions)])

        # Dict insertion order is deliberately not rank order, so the test checks the reordering.
        first_request = request_output(
            [
                {
                    11: entry(1, -0.2),
                    99: entry(0, -0.1),
                    42: entry(2, float("nan")),
                },
                {
                    5: entry(0, -1.1),
                },
            ]
        )
        second_request = request_output(
            [
                {
                    3: entry(1, -0.5),
                    7: entry(0, -0.4),
                }
            ]
        )

        logprobs, token_ids = extract_logprobs([first_request, second_request])

        expected_token_ids = [
            [[99, 11, 42], [5]],
            [[7, 3]],
        ]
        expected_logprobs = [
            [[-0.1, -0.2, None], [-1.1]],
            [[-0.4, -0.5]],
        ]
        assert token_ids == expected_token_ids
        assert logprobs == expected_logprobs

    def test_extract_logprobs_returns_none_token_ids_when_logprobs_missing(self):
        """When the completion carries no logprobs, both returned values are `None`."""
        completion_without_logprobs = SimpleNamespace(logprobs=None)
        request = SimpleNamespace(outputs=[completion_without_logprobs])

        logprobs, token_ids = extract_logprobs([request])

        assert logprobs is None
        assert token_ids is None


@pytest.mark.slow
@require_torch_multi_accelerator
@require_vllm
Expand Down Expand Up @@ -162,13 +217,15 @@ def test_logprobs_match_with_non_default_sampling(self):
top_p = 0.9
max_tokens = 8
seed = 1234
num_logprobs = 5

server_outputs = self.client.generate(
prompts,
temperature=temperature,
repetition_penalty=repetition_penalty,
top_p=top_p,
max_tokens=max_tokens,
logprobs=num_logprobs,
generation_kwargs={"seed": seed},
)
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
Expand All @@ -185,27 +242,51 @@ def test_logprobs_match_with_non_default_sampling(self):
repetition_penalty=repetition_penalty,
top_p=top_p,
max_tokens=max_tokens,
logprobs=0, # this is what's used in practice to get the logprobs of generated tokens
logprobs=num_logprobs,
seed=seed,
)
colocate_outputs = llm.generate(prompts, sampling_params=sampling_params, use_tqdm=False)
colocate_prompt_ids = [output.prompt_token_ids for output in colocate_outputs]
colocate_completion_ids = [
list(output.token_ids) for outputs in colocate_outputs for output in outputs.outputs
]
colocate_logprobs = [
[next(iter(logprob.values())).logprob for logprob in output.logprobs]
for outputs in colocate_outputs
for output in outputs.outputs
]
colocate_logprobs, colocate_logprob_token_ids = extract_logprobs(colocate_outputs)

# Generation correctness: prompt and completion IDs match between server and colocate
assert server_outputs["prompt_ids"] == colocate_prompt_ids
assert server_outputs["completion_ids"] == colocate_completion_ids

server_logprobs = server_outputs["logprobs"]
assert len(server_logprobs) == len(colocate_logprobs)
server_logprob_token_ids = server_outputs["logprob_token_ids"]

# Shape: both should be (num_sequences, seq_len, num_logprobs) with multiple logprobs per token
assert len(server_logprobs) == len(prompts)
assert len(server_logprob_token_ids) == len(prompts)
for seq_lps in server_logprobs:
for token_lps in seq_lps:
assert len(token_lps) > 1, "Expected multiple logprobs per token when logprobs > 0"

# Value correctness: server extraction matches colocate extraction via extract_logprobs
assert server_logprob_token_ids == colocate_logprob_token_ids
for server_seq, colocate_seq in zip(server_logprobs, colocate_logprobs, strict=True):
assert len(server_seq) == len(colocate_seq)
assert server_seq == pytest.approx(colocate_seq, rel=1e-6, abs=1e-6)
for server_token_lps, colocate_token_lps in zip(server_seq, colocate_seq, strict=True):
assert server_token_lps == pytest.approx(colocate_token_lps, rel=1e-6, abs=1e-6)

# Ordering: logprobs at each position should be sorted descending
for seq_lps in server_logprobs:
for token_lps in seq_lps:
assert token_lps == sorted(token_lps, reverse=True), "Logprobs should be sorted descending"

# Sampled token presence: the actual completion token should appear in the logprob token IDs
for seq_idx, (completion_seq, token_ids_seq) in enumerate(
zip(server_outputs["completion_ids"], server_logprob_token_ids, strict=True)
):
for pos, (sampled_id, lp_ids) in enumerate(zip(completion_seq, token_ids_seq, strict=True)):
assert sampled_id in lp_ids, (
f"Sampled token {sampled_id} not found in logprob token IDs {lp_ids} "
f"at sequence {seq_idx}, position {pos}"
)

@classmethod
def teardown_class(cls):
Expand Down
26 changes: 22 additions & 4 deletions trl/generation/vllm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def generate(
top_k: int = 0,
min_p: float = 0.0,
max_tokens: int = 16,
logprobs: int | None = 0,
truncate_prompt_tokens: int | None = None,
structured_outputs_regex: str | None = None,
generation_kwargs: dict | None = None,
Expand All @@ -235,6 +236,9 @@ def generate(
Minimum probability for sampling.
max_tokens (`int`, *optional*, defaults to `16`):
Maximum number of tokens to generate for each prompt.
logprobs (`int` or `None`, *optional*, defaults to `0`):
Number of top logprobs to return per token. When 0, only the sampled token's logprob is returned. When
N>0, returns the top-N logprobs sorted by descending probability.
truncate_prompt_tokens (`int`, *optional*):
If set to `-1`, will use the truncation size supported by the model. If set to an integer k, will use
only the last k tokens from the prompt (i.e., left truncation). If set to `None`, truncation is
Expand All @@ -252,8 +256,11 @@ def generate(
List of lists of token IDs representing the tokenized input prompts.
- `completion_ids` (`list[list[int]]`):
List of lists of token IDs representing the model-generated completions for each prompt.
- `logprobs` (`list[list[float]]`):
List of lists of log probabilities for each generated token.
- `logprobs` (`list[list[list[float]]]`):
Comment on lines -255 to +259

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is the implementation of the change from `list[list[float]]` to `list[list[list[float]]]`? Down on line 293, this function still just reads the value out of the response with no modifications.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's on the server side:

logprobs, logprob_token_ids = extract_logprobs(all_outputs)
return {
"prompt_ids": prompt_ids,
"completion_ids": completion_ids,
"logprobs": logprobs,
"logprob_token_ids": logprob_token_ids,
}

The client just passes it through.

Per-token logprobs of shape (num_sequences, seq_len, num_logprobs), sorted by descending
probability.
- `logprob_token_ids` (`list[list[list[int]]]`):
Token IDs corresponding to each logprob, same shape as `logprobs`.
"""
url = f"{self.base_url}/generate/"

Expand All @@ -272,6 +279,7 @@ def generate(
"top_k": top_k,
"min_p": min_p,
"max_tokens": max_tokens,
"logprobs": logprobs,
"truncate_prompt_tokens": truncate_prompt_tokens,
"structured_outputs_regex": structured_outputs_regex,
"generation_kwargs": generation_kwargs or {},
Expand All @@ -283,6 +291,7 @@ def generate(
"prompt_ids": json_response["prompt_ids"],
"completion_ids": json_response["completion_ids"],
"logprobs": json_response["logprobs"],
"logprob_token_ids": json_response["logprob_token_ids"],
}
else:
raise Exception(f"Request failed: {response.status_code}, {response.text}")
Expand All @@ -297,6 +306,7 @@ def chat(
top_k: int = 0,
min_p: float = 0.0,
max_tokens: int = 16,
logprobs: int | None = 0,
truncate_prompt_tokens: int | None = None,
structured_outputs_regex: str | None = None,
generation_kwargs: dict | None = None,
Expand Down Expand Up @@ -325,6 +335,9 @@ def chat(
Minimum probability for sampling.
max_tokens (`int`, *optional*, defaults to `16`):
Maximum number of tokens to generate for each message list.
logprobs (`int` or `None`, *optional*, defaults to `0`):
Number of top logprobs to return per token. When 0, only the sampled token's logprob is returned. When
N>0, returns the top-N logprobs sorted by descending probability.
truncate_prompt_tokens (`int`, *optional*):
If set to `-1`, will use the truncation size supported by the model. If set to an integer k, will use
only the last k tokens from the prompt (i.e., left truncation). If set to `None`, truncation is
Expand All @@ -349,8 +362,11 @@ def chat(
List of lists of token IDs representing the tokenized input messages.
- `completion_ids` (`list[list[int]]`):
List of lists of token IDs representing the model-generated completions for each message list.
- `logprobs` (`list[list[float]]`):
List of lists of log probabilities for each generated token.
- `logprobs` (`list[list[list[float]]]`):
Per-token logprobs of shape (num_sequences, seq_len, num_logprobs), sorted by descending
probability.
- `logprob_token_ids` (`list[list[list[int]]]`):
Token IDs corresponding to each logprob, same shape as `logprobs`.
"""
if tools:
raise NotImplementedError("Tool calling is not yet implemented in VLLMClient.chat().")
Expand Down Expand Up @@ -379,6 +395,7 @@ def chat(
"top_k": top_k,
"min_p": min_p,
"max_tokens": max_tokens,
"logprobs": logprobs,
"truncate_prompt_tokens": truncate_prompt_tokens,
"structured_outputs_regex": structured_outputs_regex,
"generation_kwargs": generation_kwargs or {},
Expand All @@ -391,6 +408,7 @@ def chat(
"prompt_ids": json_response["prompt_ids"],
"completion_ids": json_response["completion_ids"],
"logprobs": json_response["logprobs"],
"logprob_token_ids": json_response["logprob_token_ids"],
}
else:
raise Exception(f"Request failed: {response.status_code}, {response.text}")
Expand Down
Loading
Loading