Skip to content
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/configs/recipes/llm/dapo-qwen2.5-7b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ grpo:
enabled: true
overlong_buffer_length: 2048
max_response_length: 14336
skip_reference_policy_logprobs_calculation: true
loss_fn:
reference_policy_kl_penalty: 0.0
ratio_clip_max: 0.28
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
defaults: ../../grpo_math_1B.yaml
grpo:
max_num_steps: 500
loss_fn:
force_on_policy_ratio: true
checkpointing:
checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1
policy:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
defaults: ../../grpo_math_1B.yaml
grpo:
max_num_steps: 500
loss_fn:
force_on_policy_ratio: true
checkpointing:
enabled: false
checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n8g-megatron
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
defaults: ../../grpo_math_1B.yaml
grpo:
max_num_steps: 500
loss_fn:
force_on_policy_ratio: true
checkpointing:
enabled: false
checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n8g-megatron_generation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ grpo:
enabled: true
overlong_buffer_length: 512
max_response_length: 1024
skip_reference_policy_logprobs_calculation: true
loss_fn:
reference_policy_kl_penalty: 0.0 # Corresponds to actor_rollout_ref.actor.kl_loss_coef
ratio_clip_max: 0.28 # clip_ratio_high
Expand Down
77 changes: 54 additions & 23 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1576,11 +1576,26 @@ def grpo_train(

metrics_logging_data["content"] = flat_messages["content"]

memory_tracker.snapshot_start_of_stage("Computing logprobs", dir())
print("▶ Preparing for logprob inference...", flush=True)
with timer.time("logprob_inference_prep"):
policy.prepare_for_lp_inference()
force_on_policy_ratio = master_config["loss_fn"].get(
"force_on_policy_ratio", False
Comment thread
guyueh1 marked this conversation as resolved.
Outdated
)
Comment thread
guyueh1 marked this conversation as resolved.
Outdated
skip_prev_logprobs = force_on_policy_ratio
skip_reference_policy_logprobs = master_config["grpo"].get(
"skip_reference_policy_logprobs_calculation", False
)
if skip_prev_logprobs:
print(
"Skipping prev_logprobs computation due to force_on_policy_ratio=True"
)
train_data["prev_logprobs"] = torch.zeros_like(
train_data["generation_logprobs"]
)
Comment thread
guyueh1 marked this conversation as resolved.
Outdated
if not (skip_prev_logprobs and skip_reference_policy_logprobs):
print("▶ Preparing for logprob inference...", flush=True)
with timer.time("logprob_inference_prep"):
policy.prepare_for_lp_inference()

Comment thread
guyueh1 marked this conversation as resolved.
Outdated
memory_tracker.snapshot_start_of_stage("Computing logprobs", dir())

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before, it was taking a snapshot before doing offloading. In this PR, it is taking the snapshot after the offloading. Is this an intentional change?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh no, i think this is not intended, i will fix it

print("▶ Computing logprobs...", flush=True)
with timer.time("policy_and_reference_logprobs"):
Comment on lines 1770 to 1771

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: When both skip_prev_logprobs and skip_reference_policy_logprobs are true, this still prints "Computing logprobs..." and constructs logprob_data only to immediately delete it. Consider wrapping the entire block (including the print) in the skip check, or adjusting the log message.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think it's ok to ignore

# Custom create this logprob_data so we avoid Ray comm overheads sending unused data to workers.
Expand All @@ -1591,13 +1606,12 @@ def grpo_train(
**extra_multimodal_data,
}
Comment thread
guyueh1 marked this conversation as resolved.
)
train_data["prev_logprobs"] = policy.get_logprobs(
logprob_data, timer=timer
)["logprobs"]
if not skip_prev_logprobs:
train_data["prev_logprobs"] = policy.get_logprobs(
logprob_data, timer=timer
)["logprobs"]

if not master_config["grpo"].get(
"skip_reference_policy_logprobs_calculation"
):
if not skip_reference_policy_logprobs:
train_data["reference_policy_logprobs"] = (
policy.get_reference_policy_logprobs(
logprob_data,
Expand Down Expand Up @@ -2584,22 +2598,39 @@ def async_grpo_train(
train_data.to("cpu")

# Training phase (same as sync version)
print("▶ Preparing for logprob inference...")
with timer.time("logprob_inference_prep"):
policy.prepare_for_lp_inference()
force_on_policy_ratio = master_config["loss_fn"].get(
"force_on_policy_ratio", False
)
skip_prev_logprobs = force_on_policy_ratio
skip_reference_policy_logprobs = master_config["grpo"].get(
"skip_reference_policy_logprobs_calculation", False
)
if skip_prev_logprobs:
print(
"Skipping prev_logprobs computation due to force_on_policy_ratio=True"
)
train_data["prev_logprobs"] = torch.zeros_like(
train_data["generation_logprobs"]
)
if not (skip_prev_logprobs and skip_reference_policy_logprobs):
print("▶ Preparing for logprob inference...")
with timer.time("logprob_inference_prep"):
policy.prepare_for_lp_inference()

print("▶ Computing logprobs...")
with timer.time("policy_and_reference_logprobs"):
fprop_logprobs = policy.get_logprobs(
train_data,
timer=timer,
)["logprobs"]
reference_logprobs = policy.get_reference_policy_logprobs(
train_data,
timer=timer,
)["reference_logprobs"]
train_data["prev_logprobs"] = fprop_logprobs
train_data["reference_policy_logprobs"] = reference_logprobs
if not skip_prev_logprobs:
fprop_logprobs = policy.get_logprobs(
train_data,
timer=timer,
)["logprobs"]
train_data["prev_logprobs"] = fprop_logprobs
if not skip_reference_policy_logprobs:
reference_logprobs = policy.get_reference_policy_logprobs(
train_data,
timer=timer,
)["reference_logprobs"]
train_data["reference_policy_logprobs"] = reference_logprobs

# Compute advantages with adv_estimator using correct mask and logprobs
with timer.time("advantage_calculation"):
Expand Down
76 changes: 42 additions & 34 deletions nemo_rl/algorithms/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,53 @@ def __call__(
token_mask = data["token_mask"][:, 1:]
sample_mask = data["sample_mask"]
advantages = data["advantages"][:, 1:]
prev_logprobs = data["prev_logprobs"][:, 1:]
# Skip loading prev_logprobs when force_on_policy_ratio=True (will use curr_logprobs instead)
prev_logprobs = (
None if self.force_on_policy_ratio else data["prev_logprobs"][:, 1:]
)
generation_logprobs = data["generation_logprobs"][:, 1:]
if self.reference_policy_kl_penalty != 0:
reference_policy_logprobs = data["reference_policy_logprobs"][:, 1:]
seq_index = data.get("seq_index", None)

next_token_logits = next_token_logits.to(torch.float32)

if vocab_parallel_group is not None:
assert vocab_parallel_rank is not None, (
"vocab_parallel_rank must be provided when vocab_parallel_group is provided"
)
curr_logprobs = from_parallel_logits_to_logprobs(
next_token_logits,
data["input_ids"],
vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1],
vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1],
tp_group=vocab_parallel_group,
inference_only=False,
cp_group=context_parallel_group,
)
# slice off to the correct length to remove potential CP padding
curr_logprobs = curr_logprobs[:, : data["input_ids"].shape[1] - 1]
elif isinstance(next_token_logits, torch.distributed.tensor.DTensor):
curr_logprobs = get_logprobs_from_vocab_parallel_logits(
next_token_logits, data["input_ids"], seq_index=seq_index
)
else:
next_token_logits_wo_last = next_token_logits[
:, :-1
] # Remove last position's logits
next_token_logprobs = torch.nn.functional.log_softmax(
next_token_logits_wo_last, dim=-1
)
next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token
curr_logprobs = next_token_logprobs.gather(
dim=-1, index=next_tokens.unsqueeze(-1)
).squeeze(-1)

# For truly on-policy training, use curr_logprobs as prev_logprobs
# This avoids computing prev_logprobs upstream
if self.force_on_policy_ratio:
prev_logprobs = curr_logprobs.detach()

mask = token_mask * sample_mask.unsqueeze(-1)

# token_mult_prob_error
Expand Down Expand Up @@ -269,39 +310,6 @@ def __call__(
global_normalization_factor=global_valid_toks,
).item()

next_token_logits = next_token_logits.to(torch.float32)

if vocab_parallel_group is not None:
assert vocab_parallel_rank is not None, (
"vocab_parallel_rank must be provided when vocab_parallel_group is provided"
)
curr_logprobs = from_parallel_logits_to_logprobs(
next_token_logits,
data["input_ids"],
vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1],
vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1],
tp_group=vocab_parallel_group,
inference_only=False,
cp_group=context_parallel_group,
)
# slice off to the correct length to remove potential CP padding
curr_logprobs = curr_logprobs[:, : data["input_ids"].shape[1] - 1]
elif isinstance(next_token_logits, torch.distributed.tensor.DTensor):
curr_logprobs = get_logprobs_from_vocab_parallel_logits(
next_token_logits, data["input_ids"], seq_index=seq_index
)
else:
next_token_logits_wo_last = next_token_logits[
:, :-1
] # Remove last position's logits
next_token_logprobs = torch.nn.functional.log_softmax(
next_token_logits_wo_last, dim=-1
)
next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token
curr_logprobs = next_token_logprobs.gather(
dim=-1, index=next_tokens.unsqueeze(-1)
).squeeze(-1)

# Calculate KL regularization.
if self.reference_policy_kl_penalty != 0:
if self.use_on_policy_kl_approximation:
Expand Down
Loading