From cf3dbe2adccdb5a7b60d7eedf77d82271fa6536b Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Fri, 10 Apr 2026 13:14:22 +0200 Subject: [PATCH] Resctrict padding workaround to transformers 5.3.0 --- trl/experimental/dppo/dppo_trainer.py | 21 +++++++++++++-------- trl/trainer/grpo_trainer.py | 22 +++++++++++++--------- trl/trainer/rloo_trainer.py | 22 +++++++++++++--------- 3 files changed, 39 insertions(+), 26 deletions(-) diff --git a/trl/experimental/dppo/dppo_trainer.py b/trl/experimental/dppo/dppo_trainer.py index ffac0feb4ca..bb5e5933228 100644 --- a/trl/experimental/dppo/dppo_trainer.py +++ b/trl/experimental/dppo/dppo_trainer.py @@ -246,9 +246,10 @@ def _tokenize_prompts(self, prompts: list): images.append(prompt_images if prompt_images else None) images = images if has_images else None - # We pass padding=True to work around a bug introduced in transformers 5.2.0 in some processors - # (e.g. Qwen2.5-VL) that crash on batched unpadded input. We then unpad input_ids using attention_mask. - # See: https://github.com/huggingface/transformers/issues/44514 + # Workaround for a bug in transformers 5.3.0 where some processors (e.g. Qwen2.5-VL) crash on + # batched unpadded input (transformers#44514). + # Fixed in transformers 5.4.0 (transformers#44563). + needs_padding_workaround = Version("5.3.0") <= Version(transformers.__version__) < Version("5.4.0") tokenized = self.processing_class.apply_chat_template( conversation=prompts, tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[] @@ -256,13 +257,17 @@ def _tokenize_prompts(self, prompts: list): add_generation_prompt=True, tokenize=True, return_dict=True, - padding=True, + **({"padding": True} if needs_padding_workaround else {}), **self.chat_template_kwargs, ) - prompt_ids = [ - [tok for tok, mask in zip(ids, attention_mask, strict=True) if mask] - for ids, attention_mask in zip(tokenized["input_ids"], tokenized["attention_mask"], strict=True) - ] + if needs_padding_workaround: + # Unpad input_ids: remove padding tokens using attention_mask to get per-sequence lists + prompt_ids = [ + [tok for tok, m in zip(ids, mask, strict=True) if m] + for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"], strict=True) + ] + else: + prompt_ids = tokenized["input_ids"] multimodal_fields = {k: v for k, v in tokenized.items() if k not in ("input_ids", "attention_mask")} else: prompt_ids = self.processing_class(text=prompts)["input_ids"] diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 7544ab08f28..948841dc9a2 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1289,9 +1289,10 @@ def _tokenize_prompts(self, prompts: list): images.append(prompt_images if prompt_images else None) images = images if has_images else None - # We pass padding=True to work around a bug introduced in transformers 5.2.0 in some processors - # (e.g. Qwen2.5-VL) that crash on batched unpadded input. We then unpad input_ids using attention_mask. - # See: https://github.com/huggingface/transformers/issues/44514 + # Workaround for a bug in transformers 5.3.0 where some processors (e.g. Qwen2.5-VL) crash on + # batched unpadded input (transformers#44514). + # Fixed in transformers 5.4.0 (transformers#44563). + needs_padding_workaround = Version("5.3.0") <= Version(transformers.__version__) < Version("5.4.0") tokenized = self.processing_class.apply_chat_template( conversation=prompts, tools=self.tools or None, # `or None`: Llama bug: it renders tool boilerplate for tools=[] @@ -1299,14 +1300,17 @@ def _tokenize_prompts(self, prompts: list): add_generation_prompt=True, tokenize=True, return_dict=True, - padding=True, + **({"padding": True} if needs_padding_workaround else {}), **self.chat_template_kwargs, ) - # Unpad input_ids: remove padding tokens using attention_mask to get per-sequence lists - prompt_ids = [ - [tok for tok, m in zip(ids, mask, strict=True) if m] - for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"], strict=True) - ] + if needs_padding_workaround: + # Unpad input_ids: remove padding tokens using attention_mask to get per-sequence lists + prompt_ids = [ + [tok for tok, m in zip(ids, mask, strict=True) if m] + for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"], strict=True) + ] + else: + prompt_ids = tokenized["input_ids"] # For VLMs, the processor returns extra multimodal fields (pixel_values, image_grid_thw, etc.) multimodal_fields = {k: v for k, v in tokenized.items() if k not in ("input_ids", "attention_mask")} else: diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 6588e031f2c..6924332a336 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -919,22 +919,26 @@ def _tokenize_prompts(self, prompts: list): images.append(prompt_images if prompt_images else None) images = images if has_images else None - # We pass padding=True to work around a bug introduced in transformers 5.2.0 in some processors - # (e.g. Qwen2.5-VL) that crash on batched unpadded input. We then unpad input_ids using attention_mask. - # See: https://github.com/huggingface/transformers/issues/44514 + # Workaround for a bug in transformers 5.3.0 where some processors (e.g. Qwen2.5-VL) crash on + # batched unpadded input (transformers#44514). + # Fixed in transformers 5.4.0 (transformers#44563). + needs_padding_workaround = Version("5.3.0") <= Version(transformers.__version__) < Version("5.4.0") tokenized = self.processing_class.apply_chat_template( conversation=prompts, add_generation_prompt=True, tokenize=True, return_dict=True, - padding=True, + **({"padding": True} if needs_padding_workaround else {}), **self.chat_template_kwargs, ) - # Unpad input_ids: remove padding tokens using attention_mask to get per-sequence lists - prompt_ids = [ - [tok for tok, m in zip(ids, mask, strict=True) if m] - for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"], strict=True) - ] + if needs_padding_workaround: + # Unpad input_ids: remove padding tokens using attention_mask to get per-sequence lists + prompt_ids = [ + [tok for tok, m in zip(ids, mask, strict=True) if m] + for ids, mask in zip(tokenized["input_ids"], tokenized["attention_mask"], strict=True) + ] + else: + prompt_ids = tokenized["input_ids"] # For VLMs, the processor returns extra multimodal fields (pixel_values, image_grid_thw, etc.) multimodal_fields = {k: v for k, v in tokenized.items() if k not in ("input_ids", "attention_mask")} else: