-
Notifications
You must be signed in to change notification settings - Fork 653
Decouple IS Weights from Rejection Sampling in MIS #657
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
141e5ef
7d7e4b1
4b43e68
1e353e7
be49013
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -427,8 +427,9 @@ def vanilla_tis_function( | |
| pg_loss: torch.Tensor, | ||
| train_log_probs: list[torch.Tensor], | ||
| rollout_log_probs: list[torch.Tensor], | ||
| loss_masks: list[torch.Tensor], | ||
| **kwargs: Any, | ||
| ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: | ||
| ) -> Tuple[torch.Tensor, list[torch.Tensor], Dict[str, torch.Tensor]]: | ||
| rollout_log_probs = torch.cat(rollout_log_probs, dim=0) | ||
| old_log_probs = torch.cat(train_log_probs, dim=0) | ||
| tis = torch.exp(old_log_probs - rollout_log_probs) | ||
|
|
@@ -441,7 +442,7 @@ def vanilla_tis_function( | |
| "tis_abs": tis_abs.clone().detach(), | ||
| } | ||
| pg_loss = pg_loss * tis_weights | ||
| return pg_loss, metrics | ||
| return pg_loss, loss_masks, metrics | ||
|
|
||
| assert "rollout_log_probs" in batch, "rollout_log_probs must be provided for TIS" | ||
|
|
||
|
|
@@ -460,7 +461,13 @@ def vanilla_tis_function( | |
| tis_func = load_function(args.custom_tis_function_path) | ||
| else: | ||
| tis_func = vanilla_tis_function | ||
| pg_loss, tis_metrics = tis_func(**tis_kwargs) | ||
| pg_loss, modified_response_masks, tis_metrics = tis_func(**tis_kwargs) | ||
|
|
||
| # [decouple IS and rejection] Rebuild sum_of_sample_mean with modified_response_masks for denominator correction | ||
| # modified_response_masks will be sliced with cp in get_sum_of_sample_mean | ||
| sum_of_sample_mean = get_sum_of_sample_mean( | ||
| total_lengths, response_lengths, modified_response_masks, args.calculate_per_token_loss | ||
| ) | ||
|
|
||
| pg_loss = sum_of_sample_mean(pg_loss) | ||
| pg_clipfrac = sum_of_sample_mean(pg_clipfrac) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If not sum_of_sample_mean = get_sum_of_sample_mean(
total_lengths, response_lengths, modified_response_masks, args.calculate_per_token_loss
)Right?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, if we don't use TIS then we do not update this sum_of_sample_mean function, which was originally created from loss_mask |
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shall we change the mis -> tis here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It just seems "use_mis" is not used anywhere - we may delete it?