Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 2 additions & 11 deletions trl/trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1556,13 +1556,7 @@ def _generate_and_score_completions(
prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left")
prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left")
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list]

# Allow custom completion_mask from rollout_func for multi-turn training
if "completion_mask" in extra_fields:
completion_mask_list = extra_fields.pop("completion_mask")
completion_mask = [torch.tensor(m, device=device, dtype=torch.long) for m in completion_mask_list]
else:
completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids]
Comment on lines -1559 to -1565

@qgallouedec qgallouedec Feb 4, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would allow a non-contiguous attention mask in the forward pass, which would break causality (or possibly error out?). To mask some tokens in the loss, we should use the "env_mask" extra field instead.

completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids]
completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right")
completion_mask = pad(completion_mask, padding_value=0, padding_side="right")
if sampling_per_token_logps_list is not None:
Expand All @@ -1584,10 +1578,7 @@ def _generate_and_score_completions(

# Concatenate prompt_mask with completion_mask for logit computation
prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C)

# attend to all non-padding tokens, but mask out user/tool result tokens in loss
completion_attention_mask = (completion_ids != self.pad_token_id).long()
attention_mask = torch.cat([prompt_mask, completion_attention_mask], dim=1) # (B, P+C)
Comment on lines -1589 to -1590

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pad tokens should be masked, but the mask should not be inferred from token values, because the same id can be used for other tokens (usually EOS). In practice, this often leads to unintentionally masking the EOS token. For more details, see: https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#5-when-pad_token-equals-eos_token

attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)

logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
Expand Down
Loading