diff --git a/tests/test_vllm_client_server.py b/tests/test_vllm_client_server.py index 84c4bc2ef9a..820483a9d7e 100644 --- a/tests/test_vllm_client_server.py +++ b/tests/test_vllm_client_server.py @@ -166,6 +166,14 @@ def test_generate(self): for seq in completion_ids: assert all(isinstance(tok, int) for tok in seq) + def test_generate_with_logprobs_none(self): + outputs = self.client.generate(["Hello, AI!"], logprobs=None) + + assert isinstance(outputs["prompt_ids"], list) + assert isinstance(outputs["completion_ids"], list) + assert outputs["logprobs"] is None + assert outputs["logprob_token_ids"] is None + def test_chat(self): messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]] outputs = self.client.chat(messages) @@ -186,6 +194,14 @@ def test_chat(self): for seq in completion_ids: assert all(isinstance(tok, int) for tok in seq) + def test_chat_with_logprobs_none(self): + outputs = self.client.chat([[{"role": "user", "content": "Hello, AI!"}]], logprobs=None) + + assert isinstance(outputs["prompt_ids"], list) + assert isinstance(outputs["completion_ids"], list) + assert outputs["logprobs"] is None + assert outputs["logprob_token_ids"] is None + def test_chat_with_tools(self): def multiply(a: int, b: int) -> int: """ @@ -395,6 +411,14 @@ def test_generate(self): for seq in completion_ids: assert all(isinstance(tok, int) for tok in seq) + def test_generate_with_logprobs_none(self): + outputs = self.client.generate(["Hello, AI!"], logprobs=None) + + assert isinstance(outputs["prompt_ids"], list) + assert isinstance(outputs["completion_ids"], list) + assert outputs["logprobs"] is None + assert outputs["logprob_token_ids"] is None + def test_chat(self): messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]] outputs = self.client.chat(messages) @@ -415,6 +439,14 @@ def test_chat(self): for seq in completion_ids: assert all(isinstance(tok, int) for tok in seq) + def test_chat_with_logprobs_none(self): + outputs = self.client.chat([[{"role": "user", "content": "Hello, AI!"}]], logprobs=None) + + assert isinstance(outputs["prompt_ids"], list) + assert isinstance(outputs["completion_ids"], list) + assert outputs["logprobs"] is None + assert outputs["logprob_token_ids"] is None + def test_chat_with_tools(self): def multiply(a: int, b: int) -> int: """ @@ -545,6 +577,14 @@ def test_generate(self): for seq in completion_ids: assert all(isinstance(tok, int) for tok in seq) + def test_generate_with_logprobs_none(self): + outputs = self.client.generate(["Hello, AI!"], logprobs=None) + + assert isinstance(outputs["prompt_ids"], list) + assert isinstance(outputs["completion_ids"], list) + assert outputs["logprobs"] is None + assert outputs["logprob_token_ids"] is None + def test_chat(self): messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]] outputs = self.client.chat(messages) @@ -565,6 +605,14 @@ def test_chat(self): for seq in completion_ids: assert all(isinstance(tok, int) for tok in seq) + def test_chat_with_logprobs_none(self): + outputs = self.client.chat([[{"role": "user", "content": "Hello, AI!"}]], logprobs=None) + + assert isinstance(outputs["prompt_ids"], list) + assert isinstance(outputs["completion_ids"], list) + assert outputs["logprobs"] is None + assert outputs["logprob_token_ids"] is None + def test_chat_with_tools(self): def multiply(a: int, b: int) -> int: """ @@ -699,6 +747,14 @@ def test_generate(self): for seq in completion_ids: assert all(isinstance(tok, int) for tok in seq) + def test_generate_with_logprobs_none(self): + outputs = self.client.generate(["Hello, AI!"], logprobs=None) + + assert isinstance(outputs["prompt_ids"], list) + assert isinstance(outputs["completion_ids"], list) + assert outputs["logprobs"] is None + assert outputs["logprob_token_ids"] is None + def test_chat(self): messages = [[{"role": "user", "content": "Hello, AI!"}], [{"role": "user", "content": "Tell me a joke"}]] outputs = self.client.chat(messages) @@ -719,6 +775,14 @@ def test_chat(self): for seq in completion_ids: assert all(isinstance(tok, int) for tok in seq) + def test_chat_with_logprobs_none(self): + outputs = self.client.chat([[{"role": "user", "content": "Hello, AI!"}]], logprobs=None) + + assert isinstance(outputs["prompt_ids"], list) + assert isinstance(outputs["completion_ids"], list) + assert outputs["logprobs"] is None + assert outputs["logprob_token_ids"] is None + def test_chat_with_tools(self): def multiply(a: int, b: int) -> int: """ diff --git a/trl/generation/vllm_client.py b/trl/generation/vllm_client.py index 818a1df3f18..d993b283b7f 100644 --- a/trl/generation/vllm_client.py +++ b/trl/generation/vllm_client.py @@ -95,10 +95,12 @@ class VLLMClient: >>> client.generate(["Hello, AI!", "Tell me a joke"]) {'prompt_ids': [[9707, 11, 15235, 0], [40451, 752, 264, 21646]], - 'completion_ids': [[11479, 752, 5046, 279, 1465, 304, 419, 23670, 2038, 358, 2776, 4378, 369, 847, 15549, 6733], - [911, 19654, 382, 3838, 1558, 279, 16158, 1977, 979, 498, 2299, 4460, 311, 10542, 432, 518]], - 'logprobs': [[-5.193126201629639, -0.05592319369316101, -4.861808776855469, -1.673396110534668, -2.6316866874694824, -0.2861405313014984, -0.35006725788116455, -5.23351526260376, -0.1447441577911377, -5.21489953994751, -1.6022650003433228, -1.9649192094802856, -2.1338791847229004, -1.2775304317474365, -10.004860877990723, -4.171003818511963], - [-0.012896230444312096, -5.747106552124023, -1.5248860120773315, -1.9286258220672607, -2.8512537479400635, -2.8055880069732666, -3.019822835922241, -0.37132859230041504, -0.6311739087104797, -2.562908411026001, -3.1664533615112305, -2.685293436050415, -0.007259538397192955, -7.339841842651367, -1.188662052154541, -3.54781436920166]]} + 'completion_ids': [[2980, 498, 1492, 752, 448, 264, 13027, 8645, 30, 358, 2776, 4460, 311, 3270, 264, 2025], + [911, 98072, 2142, 624, 45, 51426, 2142, 374, 279, 16396, 429, 4302, 702, 36988, 7290, 476]], + 'logprobs': [[[-1.6612], [-0.0081], [-1.5189], [-0.0123], [-1.2045], [-0.6227], [-2.9791], [-2.8387], [-0.1267], [-0.0366], [-2.6528], [-0.3197], [-0.0001], [-1.8174], [-0.0251], [-1.473]], + [[-0.018], [-10.7331], [-0.1605], [-0.891], [-3.7945], [-0.0127], [-0.3073], [-1.1648], [-1.8025], [-0.409], [-0.0256], [-1.6127], [-2.2935], [-4.1785], [-0.6531], [-0.2629]]], + 'logprob_token_ids': [[[2980], [498], [1492], [752], [448], [264], [13027], [8645], [30], [358], [2776], [4460], [311], [3270], [264], [2025]], + [[911], [98072], [2142], [624], [45], [51426], [2142], [374], [279], [16396], [429], [4302], [702], [36988], [7290], [476]]]} >>> from transformers import AutoModelForCausalLM @@ -239,7 +241,8 @@ def generate( Maximum number of tokens to generate for each prompt. logprobs (`int` or `None`, *optional*, defaults to `0`): Number of top logprobs to return per token. When 0, only the sampled token's logprob is returned. When - N>0, returns the top-N logprobs sorted by descending probability. + N>0, returns up to N+1 logprobs sorted by descending probability, because vLLM always includes the + sampled token's logprob (which may fall outside the top-N). structured_outputs_regex (`str`, *optional*): Regular expression to guide the decoding process. generation_kwargs (`dict`, *optional*): @@ -336,7 +339,8 @@ def chat( Maximum number of tokens to generate for each message list. logprobs (`int` or `None`, *optional*, defaults to `0`): Number of top logprobs to return per token. When 0, only the sampled token's logprob is returned. When - N>0, returns the top-N logprobs sorted by descending probability. + N>0, returns up to N+1 logprobs sorted by descending probability, because vLLM always includes the + sampled token's logprob (which may fall outside the top-N). structured_outputs_regex (`str`, *optional*): Regular expression to guide the decoding process. generation_kwargs (`dict`, *optional*): diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 354ce14202b..481c5810f29 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -506,8 +506,8 @@ class GenerateRequest(BaseModel): class GenerateResponse(BaseModel): prompt_ids: list[list[int]] completion_ids: list[list[int]] - logprobs: list[list[list[float]]] - logprob_token_ids: list[list[list[int]]] + logprobs: list[list[list[float | None]]] | None + logprob_token_ids: list[list[list[int]]] | None @app.post("/generate/", response_model=GenerateResponse) async def generate(request: GenerateRequest): @@ -533,8 +533,9 @@ async def generate(request: GenerateRequest): - `max_tokens` (`int`, *optional*, defaults to `16`): Maximum number of tokens to generate for each completion. - `logprobs` (`int`, *optional*, defaults to `0`): Number of top logprobs to return per token. When 0, - only the sampled token's logprob is returned. When N>0, returns the top-N logprobs sorted by - descending probability. + only the sampled token's logprob is returned. When N>0, returns up to N+1 logprobs sorted by + descending probability, because vLLM always includes the sampled token's logprob (which may fall + outside the top-N). - `structured_outputs_regex` (`str`, *optional*): A regex pattern for structured outputs. If provided, the model will only generate tokens that match this regex pattern. - `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM @@ -675,8 +676,8 @@ class ChatRequest(BaseModel): class ChatResponse(BaseModel): prompt_ids: list[list[int]] completion_ids: list[list[int]] - logprobs: list[list[list[float]]] - logprob_token_ids: list[list[list[int]]] + logprobs: list[list[list[float | None]]] | None + logprob_token_ids: list[list[list[int]]] | None @app.post("/chat/", response_model=ChatResponse) async def chat(request: ChatRequest): @@ -700,8 +701,9 @@ async def chat(request: ChatRequest): - `max_tokens` (`int`, *optional*, defaults to `16`): Maximum number of tokens to generate for each completion. - `logprobs` (`int`, *optional*, defaults to `0`): Number of top logprobs to return per token. When 0, - only the sampled token's logprob is returned. When N>0, returns the top-N logprobs sorted by - descending probability. + only the sampled token's logprob is returned. When N>0, returns up to N+1 logprobs sorted by + descending probability, because vLLM always includes the sampled token's logprob (which may fall + outside the top-N). - `structured_outputs_regex` (`str`, *optional*): A regex pattern for structured outputs. If provided, the model will only generate tokens that match this regex pattern. - `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index e5b3d21e96f..13429829645 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -232,10 +232,11 @@ class GRPOTrainer(_BaseTrainer): rollout_func (`RolloutFunc`, *optional*): Function to use for generating completions. It receives the list of prompts allocated to the current process and the trainer instance. It must return a dict with `"prompt_ids"`, `"completion_ids"`, and - `"logprobs"` fields. Any other fields are forwarded to the reward functions. The function receives the raw - per-process prompt slice with no duplication; it is responsible for returning the correct number of - completions per prompt (see `num_generations` / `num_generations_eval` on the trainer). This feature is - experimental and may change or be removed at any time without prior notice. + `"logprobs"` fields, and can optionally return `"logprob_token_ids"` (same shape as `"logprobs"`). Any + other fields are forwarded to the reward functions. The function receives the raw per-process prompt slice + with no duplication; it is responsible for returning the correct number of completions per prompt (see + `num_generations` / `num_generations_eval` on the trainer). This feature is experimental and may change or + be removed at any time without prior notice. environment_factory (`EnvironmentFactory`, *optional*): A callable that creates and returns an environment instance. The environment class should define methods that can be invoked as tools during generation. Each method should comply with the same requirements as the