diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 7e59f12680b..fc4e267bab8 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -30,6 +30,7 @@ from typing import Any, Protocol import datasets +import numpy as np import pandas as pd import torch import torch.utils.data @@ -1214,6 +1215,47 @@ def _generate_single_turn(self, prompts: list): device = self.accelerator.device mode = "train" if self.model.training else "eval" + # Tokenize prompts once, shared across all generation backends + if is_conversational({"prompt": prompts[0]}): + # Extract images from messages for VLM support + images = [] + has_images = False + for prompt in prompts: + prompt_images = [] + for message in prompt: + if isinstance(message["content"], list): + for part in message["content"]: + if part["type"] == "image": + prompt_images.append(part["image"]) + has_images = True + images.append(prompt_images if prompt_images else None) + images = images if has_images else None + + # We pass padding=True to work around a bug introduced in transformers 5.2.0 in some processors + # (e.g. Qwen2.5-VL) that crash on batched unpadded input. We then unpad input_ids using attention_mask. + # See: https://github.com/huggingface/transformers/issues/44514 + tokenized = self.processing_class.apply_chat_template( + conversation=prompts, + tools=self.tools, + chat_template=self.chat_template, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + padding=True, + **self.chat_template_kwargs, + ) + # Unpad input_ids: remove padding tokens using attention_mask to get per-sequence lists + prompt_ids = [ + [tok for tok, m in zip(ids, mask, strict=True) if m] + for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"], strict=True) + ] + # For VLMs, the processor returns extra multimodal fields (pixel_values, image_grid_thw, etc.) + multimodal_fields = {k: v for k, v in tokenized.items() if k not in ("input_ids", "attention_mask")} + else: + prompt_ids = self.processing_class(text=prompts)["input_ids"] + images = None + multimodal_fields = {} + # Generate completions using either vLLM or regular generation if self.use_vllm: # Sync weights if training step changed @@ -1222,40 +1264,10 @@ def _generate_single_turn(self, prompts: list): self.vllm_generation.sync_weights() self._last_loaded_step = self.state.global_step - # Tokenize prompts and extract images (for VLM) before calling vLLM - if is_conversational({"prompt": prompts[0]}): - # Extract images from messages for VLM support - images = [] - has_images = False - for prompt in prompts: - prompt_images = [] - for message in prompt: - if isinstance(message["content"], list): - for part in message["content"]: - if part["type"] == "image": - prompt_images.append(part["image"]) - has_images = True - images.append(prompt_images if prompt_images else None) - images = images if has_images else None - - tokenized = self.processing_class.apply_chat_template( - conversation=prompts, - tools=self.tools, - chat_template=self.chat_template, - add_generation_prompt=True, - tokenize=True, - return_dict=True, - **self.chat_template_kwargs, - ) - prompt_token_ids = tokenized["input_ids"] - else: - prompt_token_ids = self.processing_class(text=prompts)["input_ids"] - images = None - # Generate using vLLM with raw token IDs num_generations = self.num_generations if mode == "train" else self.num_generations_eval prompt_ids, completion_ids, logprobs, _, extra_fields = self.vllm_generation.generate( - prompts=prompt_token_ids, + prompts=prompt_ids, images=images, num_generations=num_generations, profiler=profiling_context(self, "vLLM.generate"), @@ -1264,19 +1276,6 @@ def _generate_single_turn(self, prompts: list): logprobs = [[lp[0] for lp in seq] for seq in logprobs] elif self.use_transformers_paged: - if is_conversational({"prompt": prompts[0]}): - processor_outputs = self.processing_class.apply_chat_template( - conversation=prompts, - tools=self.tools, - chat_template=self.chat_template, - add_generation_prompt=True, - tokenize=True, - return_dict=True, - **self.chat_template_kwargs, - ) - else: - processor_outputs = self.processing_class(text=prompts) - with ( profiling_context(self, "transformers.generate_batch"), unwrap_model_for_generation( @@ -1295,33 +1294,28 @@ def _generate_single_turn(self, prompts: list): with torch.inference_mode(): # Continuous batching API expects 'inputs' arg only all_outputs = unwrapped_model.generate_batch( - processor_outputs["input_ids"], generation_config=self.generation_config, progress_bar=False + prompt_ids, generation_config=self.generation_config, progress_bar=False ) unwrapped_model.train() # restore training mode, as generate_batch forces eval mode completion_ids = [output.generated_tokens for output in all_outputs.values()] - prompt_ids = processor_outputs["input_ids"] logprobs = None # not used in this case extra_fields = {} # No extra fields for paged mode else: - # Regular generation path - if is_conversational({"prompt": prompts[0]}): - generate_inputs = self.processing_class.apply_chat_template( - conversation=prompts, - tools=self.tools, - chat_template=self.chat_template, - add_generation_prompt=True, - tokenize=True, - padding=True, - padding_side="left", - return_tensors="pt", - return_dict=True, - **self.chat_template_kwargs, - ) - else: - generate_inputs = self.processing_class( - text=prompts, padding=True, padding_side="left", return_tensors="pt" - ) + # Regular generation path: left-pad token IDs into tensors + prompt_tensors = [torch.tensor(ids) for ids in prompt_ids] + padded_ids = pad(prompt_tensors, padding_value=self.pad_token_id, padding_side="left") + attention_mask = pad([torch.ones_like(t) for t in prompt_tensors], padding_value=0, padding_side="left") + generate_inputs = {"input_ids": padded_ids, "attention_mask": attention_mask} + # For VLMs, include multimodal fields as tensors (pixel_values, image_grid_thw, etc.) + for k, v in multimodal_fields.items(): + if isinstance(v, torch.Tensor): + generate_inputs[k] = v + elif isinstance(v, list) and v and isinstance(v[0], list): + # Per-token field (e.g., token_type_ids): left-pad like input_ids + generate_inputs[k] = pad([torch.tensor(x) for x in v], padding_value=0, padding_side="left") + else: + generate_inputs[k] = torch.tensor(np.array(v)) generate_inputs = super()._prepare_inputs(generate_inputs) with ( @@ -1339,8 +1333,8 @@ def _generate_single_turn(self, prompts: list): **generate_inputs, generation_config=self.generation_config, disable_compile=True ) # Compute prompt length and extract completion ids - prompt_ids, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"] - prompt_length = prompt_ids.size(1) + prompt_ids_tensor, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"] + prompt_length = prompt_ids_tensor.size(1) completion_ids = prompt_completion_ids[:, prompt_length:] # Mask everything after the first EOS token @@ -1350,7 +1344,9 @@ def _generate_single_turn(self, prompts: list): sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() # Move tensors to CPU before per-sample to avoid many CUDA syncs/copies (costly at scale/contention). - prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids.cpu(), prompt_mask.bool().cpu(), strict=True)] + prompt_ids = [ + p[m].tolist() for p, m in zip(prompt_ids_tensor.cpu(), prompt_mask.bool().cpu(), strict=True) + ] completion_ids = [ c[m].tolist() for c, m in zip(completion_ids.cpu(), completion_mask.bool().cpu(), strict=True) ] diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index f7376e1f0d1..a8fbdd44e89 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -26,6 +26,7 @@ from typing import Any import datasets +import numpy as np import pandas as pd import torch import torch.utils.data @@ -888,6 +889,45 @@ def _generate_single_turn(self, prompts: list): device = self.accelerator.device mode = "train" if self.model.training else "eval" + # Tokenize prompts once, shared across all generation backends + if is_conversational({"prompt": prompts[0]}): + # Extract images from messages for VLM support + images = [] + has_images = False + for prompt in prompts: + prompt_images = [] + for message in prompt: + if isinstance(message["content"], list): + for part in message["content"]: + if part["type"] == "image": + prompt_images.append(part["image"]) + has_images = True + images.append(prompt_images if prompt_images else None) + images = images if has_images else None + + # We pass padding=True to work around a bug introduced in transformers 5.2.0 in some processors + # (e.g. Qwen2.5-VL) that crash on batched unpadded input. We then unpad input_ids using attention_mask. + # See: https://github.com/huggingface/transformers/issues/44514 + tokenized = self.processing_class.apply_chat_template( + conversation=prompts, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + padding=True, + **self.chat_template_kwargs, + ) + # Unpad input_ids: remove padding tokens using attention_mask to get per-sequence lists + prompt_ids = [ + [tok for tok, m in zip(ids, mask, strict=True) if m] + for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"], strict=True) + ] + # For VLMs, the processor returns extra multimodal fields (pixel_values, image_grid_thw, etc.) + multimodal_fields = {k: v for k, v in tokenized.items() if k not in ("input_ids", "attention_mask")} + else: + prompt_ids = self.processing_class(text=prompts)["input_ids"] + images = None + multimodal_fields = {} + # Generate completions using either vLLM or regular generation if self.use_vllm: # Sync weights if training step changed @@ -896,56 +936,16 @@ def _generate_single_turn(self, prompts: list): self.vllm_generation.sync_weights() self._last_loaded_step = self.state.global_step - # Tokenize prompts and extract images (for VLM) before calling vLLM - if is_conversational({"prompt": prompts[0]}): - # Extract images from messages for VLM support - images = [] - has_images = False - for prompt in prompts: - prompt_images = [] - for message in prompt: - if isinstance(message["content"], list): - for part in message["content"]: - if part["type"] == "image": - prompt_images.append(part["image"]) - has_images = True - images.append(prompt_images if prompt_images else None) - images = images if has_images else None - - # RLOO does not support tools; omit tools/chat_template args - tokenized = self.processing_class.apply_chat_template( - conversation=prompts, - add_generation_prompt=True, - tokenize=True, - return_dict=True, - **self.chat_template_kwargs, - ) - prompt_token_ids = tokenized["input_ids"] - else: - prompt_token_ids = self.processing_class(text=prompts)["input_ids"] - images = None - # Generate using vLLM (note: RLOO doesn't use logprobs from generation, so we ignore them) num_generations = self.num_generations if mode == "train" else self.num_generations_eval prompt_ids, completion_ids, _, _, _ = self.vllm_generation.generate( - prompts=prompt_token_ids, + prompts=prompt_ids, images=images, num_generations=num_generations, profiler=profiling_context(self, "vLLM.generate"), ) elif self.use_transformers_paged: - if is_conversational({"prompt": prompts[0]}): - processor_outputs = self.processing_class.apply_chat_template( - conversation=prompts, - add_generation_prompt=True, - tokenize=True, - return_dict=True, - **self.chat_template_kwargs, - ) - else: - processor_outputs = self.processing_class(text=prompts) - with ( profiling_context(self, "transformers.generate_batch"), unwrap_model_for_generation( @@ -962,29 +962,26 @@ def _generate_single_turn(self, prompts: list): with torch.inference_mode(): # Continuous batching API expects 'inputs' arg only all_outputs = unwrapped_model.generate_batch( - processor_outputs["input_ids"], generation_config=self.generation_config, progress_bar=False + prompt_ids, generation_config=self.generation_config, progress_bar=False ) unwrapped_model.train() # restore training mode, as generate_batch forces eval mode completion_ids = [output.generated_tokens for output in all_outputs.values()] - prompt_ids = processor_outputs["input_ids"] else: - # Regular generation path - if is_conversational({"prompt": prompts[0]}): - generate_inputs = self.processing_class.apply_chat_template( - conversation=prompts, - add_generation_prompt=True, - tokenize=True, - padding=True, - padding_side="left", - return_tensors="pt", - return_dict=True, - **self.chat_template_kwargs, - ) - else: - generate_inputs = self.processing_class( - text=prompts, padding=True, padding_side="left", return_tensors="pt" - ) + # Regular generation path: left-pad token IDs into tensors + prompt_tensors = [torch.tensor(ids) for ids in prompt_ids] + padded_ids = pad(prompt_tensors, padding_value=self.pad_token_id, padding_side="left") + attention_mask = pad([torch.ones_like(t) for t in prompt_tensors], padding_value=0, padding_side="left") + generate_inputs = {"input_ids": padded_ids, "attention_mask": attention_mask} + # For VLMs, include multimodal fields as tensors (pixel_values, image_grid_thw, etc.) + for k, v in multimodal_fields.items(): + if isinstance(v, torch.Tensor): + generate_inputs[k] = v + elif isinstance(v, list) and v and isinstance(v[0], list): + # Per-token field (e.g., token_type_ids): left-pad like input_ids + generate_inputs[k] = pad([torch.tensor(x) for x in v], padding_value=0, padding_side="left") + else: + generate_inputs[k] = torch.tensor(np.array(v)) generate_inputs = super()._prepare_inputs(generate_inputs) with ( @@ -1002,8 +999,8 @@ def _generate_single_turn(self, prompts: list): **generate_inputs, generation_config=self.generation_config, disable_compile=True ) # Compute prompt length and extract completion ids - prompt_ids, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"] - prompt_length = prompt_ids.size(1) + prompt_ids_tensor, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"] + prompt_length = prompt_ids_tensor.size(1) completion_ids = prompt_completion_ids[:, prompt_length:] # Mask everything after the first EOS token @@ -1013,7 +1010,9 @@ def _generate_single_turn(self, prompts: list): sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() # Move tensors to CPU before per-sample to avoid many CUDA syncs/copies (costly at scale/contention). - prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids.cpu(), prompt_mask.bool().cpu(), strict=True)] + prompt_ids = [ + p[m].tolist() for p, m in zip(prompt_ids_tensor.cpu(), prompt_mask.bool().cpu(), strict=True) + ] completion_ids = [ c[m].tolist() for c, m in zip(completion_ids.cpu(), completion_mask.bool().cpu(), strict=True) ]