Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions tests/test_vllm_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,14 @@ def test_generate(self):
for seq in completion_ids:
assert all(isinstance(tok, int) for tok in seq)

def test_generate_with_logprobs_none(self):
outputs = self.client.generate(["Hello, AI!"], logprobs=None)

assert isinstance(outputs["prompt_ids"], list)
assert isinstance(outputs["completion_ids"], list)
assert outputs["logprobs"] is None
assert outputs["logprob_token_ids"] is None

def test_chat(self):
messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]]
outputs = self.client.chat(messages)
Expand All @@ -186,6 +194,14 @@ def test_chat(self):
for seq in completion_ids:
assert all(isinstance(tok, int) for tok in seq)

def test_chat_with_logprobs_none(self):
outputs = self.client.chat([[{"role": "user", "content": "Hello, AI!"}]], logprobs=None)

assert isinstance(outputs["prompt_ids"], list)
assert isinstance(outputs["completion_ids"], list)
assert outputs["logprobs"] is None
assert outputs["logprob_token_ids"] is None

def test_chat_with_tools(self):
def multiply(a: int, b: int) -> int:
"""
Expand Down Expand Up @@ -395,6 +411,14 @@ def test_generate(self):
for seq in completion_ids:
assert all(isinstance(tok, int) for tok in seq)

def test_generate_with_logprobs_none(self):
outputs = self.client.generate(["Hello, AI!"], logprobs=None)

assert isinstance(outputs["prompt_ids"], list)
assert isinstance(outputs["completion_ids"], list)
assert outputs["logprobs"] is None
assert outputs["logprob_token_ids"] is None

def test_chat(self):
messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]]
outputs = self.client.chat(messages)
Expand All @@ -415,6 +439,14 @@ def test_chat(self):
for seq in completion_ids:
assert all(isinstance(tok, int) for tok in seq)

def test_chat_with_logprobs_none(self):
outputs = self.client.chat([[{"role": "user", "content": "Hello, AI!"}]], logprobs=None)

assert isinstance(outputs["prompt_ids"], list)
assert isinstance(outputs["completion_ids"], list)
assert outputs["logprobs"] is None
assert outputs["logprob_token_ids"] is None

def test_chat_with_tools(self):
def multiply(a: int, b: int) -> int:
"""
Expand Down Expand Up @@ -545,6 +577,14 @@ def test_generate(self):
for seq in completion_ids:
assert all(isinstance(tok, int) for tok in seq)

def test_generate_with_logprobs_none(self):
outputs = self.client.generate(["Hello, AI!"], logprobs=None)

assert isinstance(outputs["prompt_ids"], list)
assert isinstance(outputs["completion_ids"], list)
assert outputs["logprobs"] is None
assert outputs["logprob_token_ids"] is None

def test_chat(self):
messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]]
outputs = self.client.chat(messages)
Expand All @@ -565,6 +605,14 @@ def test_chat(self):
for seq in completion_ids:
assert all(isinstance(tok, int) for tok in seq)

def test_chat_with_logprobs_none(self):
outputs = self.client.chat([[{"role": "user", "content": "Hello, AI!"}]], logprobs=None)

assert isinstance(outputs["prompt_ids"], list)
assert isinstance(outputs["completion_ids"], list)
assert outputs["logprobs"] is None
assert outputs["logprob_token_ids"] is None

def test_chat_with_tools(self):
def multiply(a: int, b: int) -> int:
"""
Expand Down Expand Up @@ -699,6 +747,14 @@ def test_generate(self):
for seq in completion_ids:
assert all(isinstance(tok, int) for tok in seq)

def test_generate_with_logprobs_none(self):
outputs = self.client.generate(["Hello, AI!"], logprobs=None)

assert isinstance(outputs["prompt_ids"], list)
assert isinstance(outputs["completion_ids"], list)
assert outputs["logprobs"] is None
assert outputs["logprob_token_ids"] is None

def test_chat(self):
messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]]
outputs = self.client.chat(messages)
Expand All @@ -719,6 +775,14 @@ def test_chat(self):
for seq in completion_ids:
assert all(isinstance(tok, int) for tok in seq)

def test_chat_with_logprobs_none(self):
outputs = self.client.chat([[{"role": "user", "content": "Hello, AI!"}]], logprobs=None)

assert isinstance(outputs["prompt_ids"], list)
assert isinstance(outputs["completion_ids"], list)
assert outputs["logprobs"] is None
assert outputs["logprob_token_ids"] is None

def test_chat_with_tools(self):
def multiply(a: int, b: int) -> int:
"""
Expand Down
16 changes: 10 additions & 6 deletions trl/generation/vllm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,12 @@ class VLLMClient:
>>> client.generate(["Hello, AI!", "Tell me a joke"])
{'prompt_ids': [[9707, 11, 15235, 0],
[40451, 752, 264, 21646]],
'completion_ids': [[11479, 752, 5046, 279, 1465, 304, 419, 23670, 2038, 358, 2776, 4378, 369, 847, 15549, 6733],
[911, 19654, 382, 3838, 1558, 279, 16158, 1977, 979, 498, 2299, 4460, 311, 10542, 432, 518]],
'logprobs': [[-5.193126201629639, -0.05592319369316101, -4.861808776855469, -1.673396110534668, -2.6316866874694824, -0.2861405313014984, -0.35006725788116455, -5.23351526260376, -0.1447441577911377, -5.21489953994751, -1.6022650003433228, -1.9649192094802856, -2.1338791847229004, -1.2775304317474365, -10.004860877990723, -4.171003818511963],
[-0.012896230444312096, -5.747106552124023, -1.5248860120773315, -1.9286258220672607, -2.8512537479400635, -2.8055880069732666, -3.019822835922241, -0.37132859230041504, -0.6311739087104797, -2.562908411026001, -3.1664533615112305, -2.685293436050415, -0.007259538397192955, -7.339841842651367, -1.188662052154541, -3.54781436920166]]}
'completion_ids': [[2980, 498, 1492, 752, 448, 264, 13027, 8645, 30, 358, 2776, 4460, 311, 3270, 264, 2025],
[911, 98072, 2142, 624, 45, 51426, 2142, 374, 279, 16396, 429, 4302, 702, 36988, 7290, 476]],
'logprobs': [[[-1.6612], [-0.0081], [-1.5189], [-0.0123], [-1.2045], [-0.6227], [-2.9791], [-2.8387], [-0.1267], [-0.0366], [-2.6528], [-0.3197], [-0.0001], [-1.8174], [-0.0251], [-1.473]],
[[-0.018], [-10.7331], [-0.1605], [-0.891], [-3.7945], [-0.0127], [-0.3073], [-1.1648], [-1.8025], [-0.409], [-0.0256], [-1.6127], [-2.2935], [-4.1785], [-0.6531], [-0.2629]]],
'logprob_token_ids': [[[2980], [498], [1492], [752], [448], [264], [13027], [8645], [30], [358], [2776], [4460], [311], [3270], [264], [2025]],
[[911], [98072], [2142], [624], [45], [51426], [2142], [374], [279], [16396], [429], [4302], [702], [36988], [7290], [476]]]}

>>> from transformers import AutoModelForCausalLM

Expand Down Expand Up @@ -239,7 +241,8 @@ def generate(
Maximum number of tokens to generate for each prompt.
logprobs (`int` or `None`, *optional*, defaults to `0`):
Number of top logprobs to return per token. When 0, only the sampled token's logprob is returned. When
N>0, returns the top-N logprobs sorted by descending probability.
N>0, returns up to N+1 logprobs sorted by descending probability, because vLLM always includes the
sampled token's logprob (which may fall outside the top-N).
structured_outputs_regex (`str`, *optional*):
Regular expression to guide the decoding process.
generation_kwargs (`dict`, *optional*):
Expand Down Expand Up @@ -336,7 +339,8 @@ def chat(
Maximum number of tokens to generate for each message list.
logprobs (`int` or `None`, *optional*, defaults to `0`):
Number of top logprobs to return per token. When 0, only the sampled token's logprob is returned. When
N>0, returns the top-N logprobs sorted by descending probability.
N>0, returns up to N+1 logprobs sorted by descending probability, because vLLM always includes the
sampled token's logprob (which may fall outside the top-N).
structured_outputs_regex (`str`, *optional*):
Regular expression to guide the decoding process.
generation_kwargs (`dict`, *optional*):
Expand Down
18 changes: 10 additions & 8 deletions trl/scripts/vllm_serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,8 +506,8 @@ class GenerateRequest(BaseModel):
class GenerateResponse(BaseModel):
prompt_ids: list[list[int]]
completion_ids: list[list[int]]
logprobs: list[list[list[float]]]
logprob_token_ids: list[list[list[int]]]
logprobs: list[list[list[float | None]]] | None
logprob_token_ids: list[list[list[int]]] | None

@app.post("/generate/", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
Expand All @@ -533,8 +533,9 @@ async def generate(request: GenerateRequest):
- `max_tokens` (`int`, *optional*, defaults to `16`): Maximum number of tokens to generate for each
completion.
- `logprobs` (`int`, *optional*, defaults to `0`): Number of top logprobs to return per token. When 0,
only the sampled token's logprob is returned. When N>0, returns the top-N logprobs sorted by
descending probability.
only the sampled token's logprob is returned. When N>0, returns up to N+1 logprobs sorted by
descending probability, because vLLM always includes the sampled token's logprob (which may fall
outside the top-N).
- `structured_outputs_regex` (`str`, *optional*): A regex pattern for structured outputs. If provided,
the model will only generate tokens that match this regex pattern.
- `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM
Expand Down Expand Up @@ -675,8 +676,8 @@ class ChatRequest(BaseModel):
class ChatResponse(BaseModel):
prompt_ids: list[list[int]]
completion_ids: list[list[int]]
logprobs: list[list[list[float]]]
logprob_token_ids: list[list[list[int]]]
logprobs: list[list[list[float | None]]] | None
logprob_token_ids: list[list[list[int]]] | None

@app.post("/chat/", response_model=ChatResponse)
async def chat(request: ChatRequest):
Expand All @@ -700,8 +701,9 @@ async def chat(request: ChatRequest):
- `max_tokens` (`int`, *optional*, defaults to `16`): Maximum number of tokens to generate for each
completion.
- `logprobs` (`int`, *optional*, defaults to `0`): Number of top logprobs to return per token. When 0,
only the sampled token's logprob is returned. When N>0, returns the top-N logprobs sorted by
descending probability.
only the sampled token's logprob is returned. When N>0, returns up to N+1 logprobs sorted by
descending probability, because vLLM always includes the sampled token's logprob (which may fall
outside the top-N).
- `structured_outputs_regex` (`str`, *optional*): A regex pattern for structured outputs. If provided,
the model will only generate tokens that match this regex pattern.
- `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM
Expand Down
9 changes: 5 additions & 4 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,11 @@ class GRPOTrainer(_BaseTrainer):
rollout_func (`RolloutFunc`, *optional*):
Function to use for generating completions. It receives the list of prompts allocated to the current
process and the trainer instance. It must return a dict with `"prompt_ids"`, `"completion_ids"`, and
`"logprobs"` fields. Any other fields are forwarded to the reward functions. The function receives the raw
per-process prompt slice with no duplication; it is responsible for returning the correct number of
completions per prompt (see `num_generations` / `num_generations_eval` on the trainer). This feature is
experimental and may change or be removed at any time without prior notice.
`"logprobs"` fields, and can optionally return `"logprob_token_ids"` (same shape as `"logprobs"`). Any
other fields are forwarded to the reward functions. The function receives the raw per-process prompt slice
with no duplication; it is responsible for returning the correct number of completions per prompt (see
`num_generations` / `num_generations_eval` on the trainer). This feature is experimental and may change or
be removed at any time without prior notice.
environment_factory (`EnvironmentFactory`, *optional*):
A callable that creates and returns an environment instance. The environment class should define methods
that can be invoked as tools during generation. Each method should comply with the same requirements as the
Expand Down
Loading