diff --git a/examples/configs/recipes/llm/dapo-qwen2.5-7b.yaml b/examples/configs/recipes/llm/dapo-qwen2.5-7b.yaml index 9035a3598c..6956e7daa2 100644 --- a/examples/configs/recipes/llm/dapo-qwen2.5-7b.yaml +++ b/examples/configs/recipes/llm/dapo-qwen2.5-7b.yaml @@ -15,6 +15,7 @@ grpo: enabled: true overlong_buffer_length: 2048 max_response_length: 14336 + skip_reference_policy_logprobs_calculation: true loss_fn: reference_policy_kl_penalty: 0.0 ratio_clip_max: 0.28 diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml index 6eb8ed4872..184672be28 100644 --- a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.yaml @@ -1,6 +1,8 @@ defaults: ../../grpo_math_1B.yaml grpo: max_num_steps: 500 +loss_fn: + force_on_policy_ratio: true checkpointing: checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1 policy: diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron.yaml index 333a06d980..4bc30bdb2f 100755 --- a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron.yaml @@ -1,6 +1,8 @@ defaults: ../../grpo_math_1B.yaml grpo: max_num_steps: 500 +loss_fn: + force_on_policy_ratio: true checkpointing: enabled: false checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n8g-megatron diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.yaml index bb641388d8..23b1726342 100644 --- a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-megatron_generation.yaml @@ -1,6 +1,8 @@ defaults: ../../grpo_math_1B.yaml grpo: max_num_steps: 500 +loss_fn: + force_on_policy_ratio: true checkpointing: enabled: false checkpoint_dir: results/grpo-llama3.2-1b-instruct-1n8g-megatron_generation diff --git a/examples/configs/recipes/llm/performance/dapo-deepseek-v3-64n8g.yaml b/examples/configs/recipes/llm/performance/dapo-deepseek-v3-64n8g.yaml index 9c4edd2b30..989c97df85 100644 --- a/examples/configs/recipes/llm/performance/dapo-deepseek-v3-64n8g.yaml +++ b/examples/configs/recipes/llm/performance/dapo-deepseek-v3-64n8g.yaml @@ -21,6 +21,7 @@ grpo: enabled: true overlong_buffer_length: 512 max_response_length: 1024 + skip_reference_policy_logprobs_calculation: true loss_fn: reference_policy_kl_penalty: 0.0 # Corresponds to actor_rollout_ref.actor.kl_loss_coef ratio_clip_max: 0.28 # clip_ratio_high diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 005aea47a7..3d8ad019a4 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1739,11 +1739,26 @@ def grpo_train( metrics_logging_data["content"] = flat_messages["content"] - memory_tracker.snapshot_start_of_stage("Computing logprobs", dir()) - print("▶ Preparing for logprob inference...", flush=True) - with timer.time("logprob_inference_prep"): - policy.prepare_for_lp_inference() + force_on_policy_ratio = master_config["loss_fn"].get( + "force_on_policy_ratio", False + ) + skip_prev_logprobs = force_on_policy_ratio + skip_reference_policy_logprobs = master_config["grpo"].get( + "skip_reference_policy_logprobs_calculation", False + ) + if skip_prev_logprobs: + print( + "Skipping prev_logprobs computation due to force_on_policy_ratio=True" + ) + train_data["prev_logprobs"] = torch.zeros_like( + train_data["generation_logprobs"] + ) + if not (skip_prev_logprobs and skip_reference_policy_logprobs): + print("▶ Preparing for logprob inference...", flush=True) + with timer.time("logprob_inference_prep"): + policy.prepare_for_lp_inference() + memory_tracker.snapshot_start_of_stage("Computing logprobs", dir()) print("▶ Computing logprobs...", flush=True) with timer.time("policy_and_reference_logprobs"): # Custom create this logprob_data so we avoid Ray comm overheads sending unused data to workers. @@ -1756,13 +1771,12 @@ def grpo_train( **extra_multimodal_data, } ) - train_data["prev_logprobs"] = policy.get_logprobs( - logprob_data, timer=timer - )["logprobs"] + if not skip_prev_logprobs: + train_data["prev_logprobs"] = policy.get_logprobs( + logprob_data, timer=timer + )["logprobs"] - if not master_config["grpo"].get( - "skip_reference_policy_logprobs_calculation" - ): + if not skip_reference_policy_logprobs: train_data["reference_policy_logprobs"] = ( policy.get_reference_policy_logprobs( logprob_data, @@ -2789,22 +2803,39 @@ def async_grpo_train( train_data.to("cpu") # Training phase (same as sync version) - print("▶ Preparing for logprob inference...") - with timer.time("logprob_inference_prep"): - policy.prepare_for_lp_inference() + force_on_policy_ratio = master_config["loss_fn"].get( + "force_on_policy_ratio", False + ) + skip_prev_logprobs = force_on_policy_ratio + skip_reference_policy_logprobs = master_config["grpo"].get( + "skip_reference_policy_logprobs_calculation", False + ) + if skip_prev_logprobs: + print( + "Skipping prev_logprobs computation due to force_on_policy_ratio=True" + ) + train_data["prev_logprobs"] = torch.zeros_like( + train_data["generation_logprobs"] + ) + if not (skip_prev_logprobs and skip_reference_policy_logprobs): + print("▶ Preparing for logprob inference...") + with timer.time("logprob_inference_prep"): + policy.prepare_for_lp_inference() print("▶ Computing logprobs...") with timer.time("policy_and_reference_logprobs"): - fprop_logprobs = policy.get_logprobs( - train_data, - timer=timer, - )["logprobs"] - reference_logprobs = policy.get_reference_policy_logprobs( - train_data, - timer=timer, - )["reference_logprobs"] - train_data["prev_logprobs"] = fprop_logprobs - train_data["reference_policy_logprobs"] = reference_logprobs + if not skip_prev_logprobs: + fprop_logprobs = policy.get_logprobs( + train_data, + timer=timer, + )["logprobs"] + train_data["prev_logprobs"] = fprop_logprobs + if not skip_reference_policy_logprobs: + reference_logprobs = policy.get_reference_policy_logprobs( + train_data, + timer=timer, + )["reference_logprobs"] + train_data["reference_policy_logprobs"] = reference_logprobs ( max_seq_mult_prob_error, diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index c72269eee1..6cc2f5c5ca 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -200,7 +200,10 @@ def __call__( token_mask = data["token_mask"][:, 1:] sample_mask = data["sample_mask"] advantages = data["advantages"][:, 1:] - prev_logprobs = data["prev_logprobs"][:, 1:] + # Skip loading prev_logprobs when force_on_policy_ratio=True (will use curr_logprobs instead) + prev_logprobs = ( + None if self.force_on_policy_ratio else data["prev_logprobs"][:, 1:] + ) generation_logprobs = data["generation_logprobs"][:, 1:] if self.reference_policy_kl_penalty != 0: reference_policy_logprobs = data["reference_policy_logprobs"][:, 1:] @@ -208,6 +211,44 @@ def __call__( "curr_logprobs_unfiltered", curr_logprobs ) + next_token_logits = next_token_logits.to(torch.float32) + + if vocab_parallel_group is not None: + assert vocab_parallel_rank is not None, ( + "vocab_parallel_rank must be provided when vocab_parallel_group is provided" + ) + curr_logprobs = from_parallel_logits_to_logprobs( + next_token_logits, + data["input_ids"], + vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], + vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], + tp_group=vocab_parallel_group, + inference_only=False, + cp_group=context_parallel_group, + ) + # slice off to the correct length to remove potential CP padding + curr_logprobs = curr_logprobs[:, : data["input_ids"].shape[1] - 1] + elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): + curr_logprobs = get_logprobs_from_vocab_parallel_logits( + next_token_logits, data["input_ids"], seq_index=seq_index + ) + else: + next_token_logits_wo_last = next_token_logits[ + :, :-1 + ] # Remove last position's logits + next_token_logprobs = torch.nn.functional.log_softmax( + next_token_logits_wo_last, dim=-1 + ) + next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token + curr_logprobs = next_token_logprobs.gather( + dim=-1, index=next_tokens.unsqueeze(-1) + ).squeeze(-1) + + # For truly on-policy training, use curr_logprobs as prev_logprobs + # This avoids computing prev_logprobs upstream + if self.force_on_policy_ratio: + prev_logprobs = curr_logprobs.detach() + mask = token_mask * sample_mask.unsqueeze(-1) # token_mult_prob_error