-
Notifications
You must be signed in to change notification settings - Fork 438
feat: skip logprob and reference logprob computation under certain conditions #1891
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
dd9498d
5a3c52e
48ecb0b
a0b58e8
4f0f1f4
8d17ba3
4a3d561
a8156b8
1e25d65
aac0f37
7f9746f
74bc942
85fae95
5cca419
6ee59ac
ed4c613
4baca40
6abdc51
edba691
71f0ff6
39b6202
84a737b
b3c52da
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1576,11 +1576,26 @@ def grpo_train( | |
|
|
||
| metrics_logging_data["content"] = flat_messages["content"] | ||
|
|
||
| memory_tracker.snapshot_start_of_stage("Computing logprobs", dir()) | ||
| print("▶ Preparing for logprob inference...", flush=True) | ||
| with timer.time("logprob_inference_prep"): | ||
| policy.prepare_for_lp_inference() | ||
| force_on_policy_ratio = master_config["loss_fn"].get( | ||
| "force_on_policy_ratio", False | ||
| ) | ||
|
guyueh1 marked this conversation as resolved.
Outdated
|
||
| skip_prev_logprobs = force_on_policy_ratio | ||
| skip_reference_policy_logprobs = master_config["grpo"].get( | ||
| "skip_reference_policy_logprobs_calculation", False | ||
| ) | ||
| if skip_prev_logprobs: | ||
| print( | ||
| "Skipping prev_logprobs computation due to force_on_policy_ratio=True" | ||
| ) | ||
| train_data["prev_logprobs"] = torch.zeros_like( | ||
| train_data["generation_logprobs"] | ||
| ) | ||
|
guyueh1 marked this conversation as resolved.
Outdated
|
||
| if not (skip_prev_logprobs and skip_reference_policy_logprobs): | ||
| print("▶ Preparing for logprob inference...", flush=True) | ||
| with timer.time("logprob_inference_prep"): | ||
| policy.prepare_for_lp_inference() | ||
|
|
||
|
guyueh1 marked this conversation as resolved.
Outdated
|
||
| memory_tracker.snapshot_start_of_stage("Computing logprobs", dir()) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Before, it was taking a snapshot before doing offloading. In this PR, it is taking the snapshot after the offloading. Is this an intentional change?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh no, i think this is not intended, i will fix it |
||
| print("▶ Computing logprobs...", flush=True) | ||
| with timer.time("policy_and_reference_logprobs"): | ||
|
Comment on lines
1770
to
1771
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: When both
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i think it's ok to ignore |
||
| # Custom create this logprob_data so we avoid Ray comm overheads sending unused data to workers. | ||
|
|
@@ -1591,13 +1606,12 @@ def grpo_train( | |
| **extra_multimodal_data, | ||
| } | ||
|
guyueh1 marked this conversation as resolved.
|
||
| ) | ||
| train_data["prev_logprobs"] = policy.get_logprobs( | ||
| logprob_data, timer=timer | ||
| )["logprobs"] | ||
| if not skip_prev_logprobs: | ||
| train_data["prev_logprobs"] = policy.get_logprobs( | ||
| logprob_data, timer=timer | ||
| )["logprobs"] | ||
|
|
||
| if not master_config["grpo"].get( | ||
| "skip_reference_policy_logprobs_calculation" | ||
| ): | ||
| if not skip_reference_policy_logprobs: | ||
| train_data["reference_policy_logprobs"] = ( | ||
| policy.get_reference_policy_logprobs( | ||
| logprob_data, | ||
|
|
@@ -2584,22 +2598,39 @@ def async_grpo_train( | |
| train_data.to("cpu") | ||
|
|
||
| # Training phase (same as sync version) | ||
| print("▶ Preparing for logprob inference...") | ||
| with timer.time("logprob_inference_prep"): | ||
| policy.prepare_for_lp_inference() | ||
| force_on_policy_ratio = master_config["loss_fn"].get( | ||
| "force_on_policy_ratio", False | ||
| ) | ||
| skip_prev_logprobs = force_on_policy_ratio | ||
| skip_reference_policy_logprobs = master_config["grpo"].get( | ||
| "skip_reference_policy_logprobs_calculation", False | ||
| ) | ||
| if skip_prev_logprobs: | ||
| print( | ||
| "Skipping prev_logprobs computation due to force_on_policy_ratio=True" | ||
| ) | ||
| train_data["prev_logprobs"] = torch.zeros_like( | ||
| train_data["generation_logprobs"] | ||
| ) | ||
| if not (skip_prev_logprobs and skip_reference_policy_logprobs): | ||
| print("▶ Preparing for logprob inference...") | ||
| with timer.time("logprob_inference_prep"): | ||
| policy.prepare_for_lp_inference() | ||
|
|
||
| print("▶ Computing logprobs...") | ||
| with timer.time("policy_and_reference_logprobs"): | ||
| fprop_logprobs = policy.get_logprobs( | ||
| train_data, | ||
| timer=timer, | ||
| )["logprobs"] | ||
| reference_logprobs = policy.get_reference_policy_logprobs( | ||
| train_data, | ||
| timer=timer, | ||
| )["reference_logprobs"] | ||
| train_data["prev_logprobs"] = fprop_logprobs | ||
| train_data["reference_policy_logprobs"] = reference_logprobs | ||
| if not skip_prev_logprobs: | ||
| fprop_logprobs = policy.get_logprobs( | ||
| train_data, | ||
| timer=timer, | ||
| )["logprobs"] | ||
| train_data["prev_logprobs"] = fprop_logprobs | ||
| if not skip_reference_policy_logprobs: | ||
| reference_logprobs = policy.get_reference_policy_logprobs( | ||
| train_data, | ||
| timer=timer, | ||
| )["reference_logprobs"] | ||
| train_data["reference_policy_logprobs"] = reference_logprobs | ||
|
|
||
| # Compute advantages with adv_estimator using correct mask and logprobs | ||
| with timer.time("advantage_calculation"): | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.