-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Revert change in GRPO from NeMo-Gym Integration #4970
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1556,13 +1556,7 @@ def _generate_and_score_completions( | |
| prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") | ||
| prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") | ||
| completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] | ||
|
|
||
| # Allow custom completion_mask from rollout_func for multi-turn training | ||
| if "completion_mask" in extra_fields: | ||
| completion_mask_list = extra_fields.pop("completion_mask") | ||
| completion_mask = [torch.tensor(m, device=device, dtype=torch.long) for m in completion_mask_list] | ||
| else: | ||
| completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] | ||
| completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] | ||
| completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") | ||
| completion_mask = pad(completion_mask, padding_value=0, padding_side="right") | ||
| if sampling_per_token_logps_list is not None: | ||
|
|
@@ -1584,10 +1578,7 @@ def _generate_and_score_completions( | |
|
|
||
| # Concatenate prompt_mask with completion_mask for logit computation | ||
| prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) | ||
|
|
||
| # attend to all non-padding tokens, but mask out user/tool result tokens in loss | ||
| completion_attention_mask = (completion_ids != self.pad_token_id).long() | ||
| attention_mask = torch.cat([prompt_mask, completion_attention_mask], dim=1) # (B, P+C) | ||
|
Comment on lines
-1589
to
-1590
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. pad tokens should be masked, but the mask should not be inferred from token values, because a same id can be used for other tokens (usually EOS). In practice, this often leads to unintentionally masking the EOS token. For more details, see: https://huggingface.co/blog/qgallouedec/gotchas-in-tokenizer-behavior#5-when-pad_token-equals-eos_token |
||
| attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) | ||
|
|
||
| logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens | ||
| batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this would allow to have a non-continuous attention mask in the forward pass. This would just break the causality (error out?). To mask some tokens in the loss, we should use the "env_mask" extra field.