Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/configs/recipes/llm/dapo-qwen2.5-7b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ grpo:
enabled: true
overlong_buffer_length: 2048
max_response_length: 14336
skip_reference_policy_logprobs_calculation: true
loss_fn:
reference_policy_kl_penalty: 0.0
ratio_clip_max: 0.28
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
defaults: ../../grpo_math_1B.yaml
grpo:
max_num_steps: 500
loss_fn:
force_on_policy_ratio: true
checkpointing:
checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1
policy:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
defaults: ../../grpo_math_1B.yaml
grpo:
max_num_steps: 500
loss_fn:
force_on_policy_ratio: true
checkpointing:
enabled: false
checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n8g-megatron
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
defaults: ../../grpo_math_1B.yaml
grpo:
max_num_steps: 500
loss_fn:
force_on_policy_ratio: true
checkpointing:
enabled: false
checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n8g-megatron_generation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ grpo:
enabled: true
overlong_buffer_length: 512
max_response_length: 1024
skip_reference_policy_logprobs_calculation: true
loss_fn:
reference_policy_kl_penalty: 0.0 # Corresponds to actor_rollout_ref.actor.kl_loss_coef
ratio_clip_max: 0.28 # clip_ratio_high
Expand Down
77 changes: 54 additions & 23 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1739,11 +1739,26 @@ def grpo_train(

metrics_logging_data["content"] = flat_messages["content"]

memory_tracker.snapshot_start_of_stage("Computing logprobs", dir())
print("▶ Preparing for logprob inference...", flush=True)
with timer.time("logprob_inference_prep"):
policy.prepare_for_lp_inference()
force_on_policy_ratio = master_config["loss_fn"].get(
"force_on_policy_ratio", False
)
skip_prev_logprobs = force_on_policy_ratio
skip_reference_policy_logprobs = master_config["grpo"].get(
"skip_reference_policy_logprobs_calculation", False
)
if skip_prev_logprobs:
print(
"Skipping prev_logprobs computation due to force_on_policy_ratio=True"
)
train_data["prev_logprobs"] = torch.zeros_like(
train_data["generation_logprobs"]
)
if not (skip_prev_logprobs and skip_reference_policy_logprobs):
print("▶ Preparing for logprob inference...", flush=True)
with timer.time("logprob_inference_prep"):
policy.prepare_for_lp_inference()

Comment on lines +1742 to 1760
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Zero-filled prev_logprobs will produce misleading logs and diagnostics.

When force_on_policy_ratio=True, train_data["prev_logprobs"] is filled with zeros. The loss function correctly overrides this internally with curr_logprobs.detach(), but downstream code still reads the raw zeros:

  • Line 1886: log_data["prev_logprobs"] = train_data["prev_logprobs"].tolist() — logs zeros to JSONL.
  • Lines 1897–1909: The token_mult_prob_error visualization plots train_data["prev_logprobs"] (zeros) against generation_logprobs, producing a nonsensical chart.

Consider either (a) skipping these log entries when force_on_policy_ratio is on, or (b) back-filling train_data["prev_logprobs"] with the actual curr_logprobs returned from train_results (if available).

🤖 Prompt for AI Agents
In `@nemo_rl/algorithms/grpo.py` around lines 1579 - 1597, The code currently
zero-fills train_data["prev_logprobs"] when force_on_policy_ratio is True which
leads to misleading logs and plots (token_mult_prob_error); change the handling
so that when master_config["loss_fn"].get("force_on_policy_ratio", False) is
True you either (A) avoid emitting prev_logprobs into log_data and skip plotting
token_mult_prob_error, or (B) back-fill train_data["prev_logprobs"] with the
actual on-policy probabilities returned by the training step (e.g., use
train_results["curr_logprobs"] / .detach() if present) before any
logging/visualization; update the code paths around prev_logprobs,
train_results, log_data["prev_logprobs"], and token_mult_prob_error to implement
one of these behaviors.

memory_tracker.snapshot_start_of_stage("Computing logprobs", dir())
print("▶ Computing logprobs...", flush=True)
with timer.time("policy_and_reference_logprobs"):
# Custom create this logprob_data so we avoid Ray comm overheads sending unused data to workers.
Expand All @@ -1756,13 +1771,12 @@ def grpo_train(
**extra_multimodal_data,
}
)
train_data["prev_logprobs"] = policy.get_logprobs(
logprob_data, timer=timer
)["logprobs"]
if not skip_prev_logprobs:
train_data["prev_logprobs"] = policy.get_logprobs(
logprob_data, timer=timer
)["logprobs"]

if not master_config["grpo"].get(
"skip_reference_policy_logprobs_calculation"
):
if not skip_reference_policy_logprobs:
train_data["reference_policy_logprobs"] = (
policy.get_reference_policy_logprobs(
logprob_data,
Expand Down Expand Up @@ -2789,22 +2803,39 @@ def async_grpo_train(
train_data.to("cpu")

# Training phase (same as sync version)
print("▶ Preparing for logprob inference...")
with timer.time("logprob_inference_prep"):
policy.prepare_for_lp_inference()
force_on_policy_ratio = master_config["loss_fn"].get(
"force_on_policy_ratio", False
)
skip_prev_logprobs = force_on_policy_ratio
skip_reference_policy_logprobs = master_config["grpo"].get(
"skip_reference_policy_logprobs_calculation", False
)
if skip_prev_logprobs:
print(
"Skipping prev_logprobs computation due to force_on_policy_ratio=True"
)
train_data["prev_logprobs"] = torch.zeros_like(
train_data["generation_logprobs"]
)
if not (skip_prev_logprobs and skip_reference_policy_logprobs):
print("▶ Preparing for logprob inference...")
with timer.time("logprob_inference_prep"):
policy.prepare_for_lp_inference()

print("▶ Computing logprobs...")
with timer.time("policy_and_reference_logprobs"):
fprop_logprobs = policy.get_logprobs(
train_data,
timer=timer,
)["logprobs"]
reference_logprobs = policy.get_reference_policy_logprobs(
train_data,
timer=timer,
)["reference_logprobs"]
train_data["prev_logprobs"] = fprop_logprobs
train_data["reference_policy_logprobs"] = reference_logprobs
if not skip_prev_logprobs:
fprop_logprobs = policy.get_logprobs(
train_data,
timer=timer,
)["logprobs"]
train_data["prev_logprobs"] = fprop_logprobs
if not skip_reference_policy_logprobs:
reference_logprobs = policy.get_reference_policy_logprobs(
train_data,
timer=timer,
)["reference_logprobs"]
train_data["reference_policy_logprobs"] = reference_logprobs

(
max_seq_mult_prob_error,
Expand Down
43 changes: 42 additions & 1 deletion nemo_rl/algorithms/loss/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,55 @@ def __call__(
token_mask = data["token_mask"][:, 1:]
sample_mask = data["sample_mask"]
advantages = data["advantages"][:, 1:]
prev_logprobs = data["prev_logprobs"][:, 1:]
# Skip loading prev_logprobs when force_on_policy_ratio=True (will use curr_logprobs instead)
prev_logprobs = (
None if self.force_on_policy_ratio else data["prev_logprobs"][:, 1:]
)
generation_logprobs = data["generation_logprobs"][:, 1:]
if self.reference_policy_kl_penalty != 0:
reference_policy_logprobs = data["reference_policy_logprobs"][:, 1:]
curr_logprobs_unfiltered = data.get(
"curr_logprobs_unfiltered", curr_logprobs
)

next_token_logits = next_token_logits.to(torch.float32)

if vocab_parallel_group is not None:
assert vocab_parallel_rank is not None, (
"vocab_parallel_rank must be provided when vocab_parallel_group is provided"
)
curr_logprobs = from_parallel_logits_to_logprobs(
next_token_logits,
data["input_ids"],
vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1],
vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1],
tp_group=vocab_parallel_group,
inference_only=False,
cp_group=context_parallel_group,
)
# slice off to the correct length to remove potential CP padding
curr_logprobs = curr_logprobs[:, : data["input_ids"].shape[1] - 1]
elif isinstance(next_token_logits, torch.distributed.tensor.DTensor):
curr_logprobs = get_logprobs_from_vocab_parallel_logits(
next_token_logits, data["input_ids"], seq_index=seq_index
)
else:
next_token_logits_wo_last = next_token_logits[
:, :-1
] # Remove last position's logits
next_token_logprobs = torch.nn.functional.log_softmax(
next_token_logits_wo_last, dim=-1
)
next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token
curr_logprobs = next_token_logprobs.gather(
dim=-1, index=next_tokens.unsqueeze(-1)
).squeeze(-1)

# For truly on-policy training, use curr_logprobs as prev_logprobs
# This avoids computing prev_logprobs upstream
if self.force_on_policy_ratio:
prev_logprobs = curr_logprobs.detach()

mask = token_mask * sample_mask.unsqueeze(-1)

# token_mult_prob_error
Expand Down
Loading