diff --git a/docs/guides/prorlv2.md b/docs/guides/prorlv2.md new file mode 100644 index 0000000000..c028e88764 --- /dev/null +++ b/docs/guides/prorlv2.md @@ -0,0 +1,205 @@ +# An In-Depth Walkthrough of ProRLv2 in NeMo RL

This guide covers the ProRLv2 configuration pattern in NeMo RL, based on the example config [`examples/configs/prorlv2.yaml`](../../examples/configs/prorlv2.yaml).

ProRLv2 (as used in this repo) is best thought of as **GRPO plus a bundle of stability/efficiency techniques** commonly used for long-horizon RL fine-tuning:

- **DAPO dynamic sampling**: skip prompt-groups with zero reward variance
- **Decoupled (asymmetric) clipping**: `ratio_clip_max > ratio_clip_min`
- **Token-level policy gradient loss**
- **Importance sampling correction and TIS/ICE-POP** (especially helpful for MoE/backend-mismatch scenarios)
- **Reinforce++: Decoupled local/global advantage normalization** (`reinforce_plus_plus`)
- **“Stop properly” penalty** for truncated responses

This document focuses on ProRLv2-specific knobs and gotchas. For foundational concepts on GRPO (data, environments, generation backends, loss/metrics), see the [NeMo RL GRPO Guide](grpo.md). For the original DAPO motivation behind dynamic sampling/overlong shaping, see the [NeMo RL DAPO Guide](dapo.md).

## Quickstart: Launch a ProRLv2 Run

Use the example configuration [`examples/configs/prorlv2.yaml`](../../examples/configs/prorlv2.yaml):

```bash
uv run examples/run_grpo_math.py --config examples/configs/prorlv2.yaml {overrides}
```

`prorlv2.yaml` inherits from [`examples/configs/grpo_math_1B.yaml`](../../examples/configs/grpo_math_1B.yaml) and only overrides a small set of fields under `grpo` and `loss_fn`, plus output directories.

**Reminder**: Don’t forget to set your `HF_HOME`, `WANDB_API_KEY`, and `HF_DATASETS_CACHE` (if needed). You’ll also need to run `huggingface-cli login` for gated models.
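A minimal sketch of that environment setup before launching (all paths and key values below are illustrative placeholders, not repo defaults):

```bash
# Illustrative values -- substitute your own cache paths and keys
export HF_HOME=/path/to/hf_cache
export HF_DATASETS_CACHE="$HF_HOME/datasets"
export WANDB_API_KEY=your-wandb-key

# For gated models, authenticate once per machine:
#   huggingface-cli login
```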
+ +## DAPO: Dynamic Sampling + +Standard GRPO will train on all generated responses, even when a prompt’s `num_generations_per_prompt` responses all receive the same reward (no per-prompt learning signal). **Dynamic sampling** filters to keep only prompt-groups with diverse rewards (`std > 0`), and can accumulate across multiple generation batches until it reaches the target rollout batch size. + +- **Config**: enable with `grpo.use_dynamic_sampling: true` and tune: + - `grpo.batch_multiplier`: how many extra prompts to generate to compensate filtering + - `grpo.dynamic_sampling_max_gen_batches`: upper bound before raising an error +- **Implementation**: see `dynamic_sampling()` in [`nemo_rl/algorithms/grpo.py`](../../nemo_rl/algorithms/grpo.py). + +## Advantage Estimator: Reinforce++ + +The ProRLv2 recipe uses **Reinforce++** advantage estimation instead of the standard GRPO-style group baseline. + +Quick intuition: + +- Reinforce++ uses **decoupled local + global normalization**. +- Compared to GRPO-style **local-only normalization**, this decoupling can be **more stable** in longer runs (less sensitivity to per-batch scale/variance shifts). + +Computation (as implemented in this repo, with the ProRLv2 example defaults): + +```text +Defaults in examples/configs/prorlv2.yaml: + grpo.adv_estimator.minus_baseline = true + loss_fn.use_kl_in_reward = false + +Steps: + 1) Per prompt-group, compute mean reward, then subtract it: + a_i = r_i - mean_{j in same prompt} r_j + + 2) Global normalize across *all valid response tokens* in the batch: + A <- (A - mean(A)) / sqrt(max(var(A), 1e-8)) +``` + +```yaml +grpo: + adv_estimator: + name: "reinforce_plus_plus" + normalize_rewards: true + use_leave_one_out_baseline: false + minus_baseline: true +``` + +- **Config**: `grpo.adv_estimator.name: "reinforce_plus_plus"` +- **Implementation**: the training loop wires this via `ReinforcePlusPlusAdvantageEstimator` in [`nemo_rl/algorithms/grpo.py`](../../nemo_rl/algorithms/grpo.py). 
+- **Reference**: [REINFORCE++ paper](https://arxiv.org/abs/2501.03262) + +## Reward Shaping: “Stop properly” Penalty (Truncation Penalty) + +When a generation hits the max length without emitting EOS, many pipelines mark it as **truncated**. The “stop properly” penalty scales the reward for truncated samples: + +- `stop_properly_penalty_coef = 0.0`: truncated samples get **zero reward** +- `stop_properly_penalty_coef = 1.0`: **no penalty** (keep original rewards) +- Any value in \([0, 1]\) interpolates between the two. + +In the example config: + +```yaml +grpo: + reward_shaping: + enabled: true + stop_properly_penalty_coef: 0.0 +``` + +- **Implementation**: `apply_reward_shaping()` in [`nemo_rl/algorithms/reward_functions.py`](../../nemo_rl/algorithms/reward_functions.py). + +:::{important} +In the current implementation, if `stop_properly_penalty_coef` is set (not `null`), `apply_reward_shaping()` **returns early** after applying truncation scaling. That means you **cannot** apply DAPO "overlong reward shaping" in the same run unless you set `stop_properly_penalty_coef: null` and provide the DAPO overlong parameters (`overlong_buffer_length`, `overlong_buffer_penalty`, `max_response_length`). +::: + +## Loss: Decoupled (Asymmetric) Clipping + +ProRLv2 uses DAPO’s “decoupled clipping” idea by setting different lower/upper clip bounds: + +```yaml +loss_fn: + ratio_clip_min: 0.2 + ratio_clip_max: 0.27 +``` + +This keeps PPO/GRPO-style clipping behavior but allows a larger expansion region than the contraction region, which can help exploration and reduce early collapse. + +- **Implementation**: `ClippedPGLossFn` documents decoupled clipping in [`nemo_rl/algorithms/loss_functions.py`](../../nemo_rl/algorithms/loss_functions.py). 
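To make the asymmetry concrete, here is a minimal per-token sketch of the clipped surrogate with decoupled bounds. This is standalone illustrative code, not the repo's `ClippedPGLossFn`:

```python
import math

def decoupled_clip_term(logprob_new: float, logprob_old: float, advantage: float,
                        ratio_clip_min: float = 0.2, ratio_clip_max: float = 0.27) -> float:
    """PPO-style clipped surrogate for one token, with asymmetric (decoupled) bounds."""
    ratio = math.exp(logprob_new - logprob_old)
    # Expansion region extends to 1 + ratio_clip_max; contraction is cut at 1 - ratio_clip_min
    clipped = min(max(ratio, 1.0 - ratio_clip_min), 1.0 + ratio_clip_max)
    # Pessimistic min over unclipped and clipped objectives, negated to form a loss
    return -min(ratio * advantage, clipped * advantage)

# A ratio of 1.25 with positive advantage is NOT cut off (1.25 < 1 + 0.27),
# whereas symmetric 0.2 clipping would cap the objective at 1.2 * advantage:
loss = decoupled_clip_term(math.log(1.25), 0.0, 1.0)                       # -> -1.25
sym = decoupled_clip_term(math.log(1.25), 0.0, 1.0, ratio_clip_max=0.2)    # -> -1.2
```

With a positive advantage, the wider upper bound lets the policy move further toward good tokens before clipping kicks in, which is the extra exploration headroom the section describes.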
+ +## Loss: Token-level Policy Gradient + +ProRLv2 enables token-level loss: + +```yaml +loss_fn: + token_level_loss: true +``` + +This computes the policy gradient loss per token (under masking) instead of aggregating per sequence, which is often helpful for long CoT/variable-length rollouts. + +## Truncated Importance Sampling + +When training and generation backends differ (e.g., numerics, precision, MoE routing, or vLLM vs training framework), you may see a mismatch between: + +- `generation_logprobs` (logprobs under the generation backend that produced samples) +- `prev_logprobs` (logprobs under the training framework policy) + +NeMo RL supports **importance sampling correction**, and ProRLv2’s example config turns it on together with **truncated importance sampling**. + +Quick intuition: + +- This is mainly useful for **MoE/backend mismatch** cases, where the generation backend and the training policy can disagree on logprobs. +- We compute an importance weight from `prev_logprobs` (training policy) vs `generation_logprobs` (generator). **ICE-POP** drops outliers by zeroing weights outside \([min, max]\). +- In the common setup of **one policy update per rollout batch** (i.e., minibatch equals the per-step rollout batch; no PPO multi-epoch reuse), the PPO/GRPO likelihood ratio term is effectively **1.0** at update time, so the main stability issue is the MoE/backend-mismatch importance weights. +- “Online ICE-POP” here just means applying that ICE-POP filtering **during loss computation** on the current training batch. 
- **Reference**: [The Online IcePop Solution for MoE models](https://hijkzzz.notion.site/online-ice-pop)

```yaml
loss_fn:
  use_importance_sampling_correction: true
  truncated_importance_sampling_ratio: 5.0
  truncated_importance_sampling_ratio_min: 0.5
  truncated_importance_sampling_type: "icepop"
```

- **`use_importance_sampling_correction`**: enable token-level importance weights (must be `true` for truncated IS)
- **`truncated_importance_sampling_ratio`**: upper bound (the clamp ceiling for `"tis"`, the upper filter threshold for `"icepop"`)
- **`truncated_importance_sampling_ratio_min`**: lower bound used by ICE-POP filtering
- **`truncated_importance_sampling_type`**:
  - `"tis"`: clamp weights to `<= truncated_importance_sampling_ratio`
  - `"icepop"`: set weights outside \([min, max]\) to zero (filter outliers)

- **Implementation**: see `ClippedPGLossFn` init-time checks and logic in [`nemo_rl/algorithms/loss_functions.py`](../../nemo_rl/algorithms/loss_functions.py).

## Full Example Config (Annotated)

The ProRLv2 example config is intentionally small and relies on defaults from `grpo_math_1B.yaml`.
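A condensed view of the fields it overrides (taken from the full `prorlv2.yaml` in this PR; see that file for the annotated comments):

```yaml
defaults: "grpo_math_1B.yaml"

grpo:
  use_dynamic_sampling: true
  dynamic_sampling_max_gen_batches: 10
  batch_multiplier: 1.5
  adv_estimator:
    name: "reinforce_plus_plus"
  reward_shaping:
    enabled: true
    stop_properly_penalty_coef: 0.0

loss_fn:
  ratio_clip_min: 0.2
  ratio_clip_max: 0.27
  token_level_loss: true
  use_importance_sampling_correction: true
  truncated_importance_sampling_ratio: 5.0
  truncated_importance_sampling_ratio_min: 0.5
  truncated_importance_sampling_type: "icepop"
```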
+ +- **Example config**: [`examples/configs/prorlv2.yaml`](../../examples/configs/prorlv2.yaml) +- **Base defaults**: [`examples/configs/grpo_math_1B.yaml`](../../examples/configs/grpo_math_1B.yaml) + +## Practical Overrides + +A few common overrides when launching: + +```bash +uv run examples/run_grpo_math.py \ + --config examples/configs/prorlv2.yaml \ + policy.model_name="Qwen/Qwen2.5-1.5B" \ + logger.wandb_enabled=true \ + logger.wandb.project="prorlv2-dev" \ + checkpointing.checkpoint_dir="results/prorlv2" \ + logger.log_dir="logs/prorlv2" +``` + +If you want to enable DAPO overlong reward shaping instead of stop-properly: + +```bash +uv run examples/run_grpo_math.py \ + --config examples/configs/prorlv2.yaml \ + grpo.reward_shaping.stop_properly_penalty_coef=null \ + grpo.reward_shaping.overlong_buffer_length=4096 \ + grpo.reward_shaping.overlong_buffer_penalty=1.0 \ + grpo.reward_shaping.max_response_length=20480 +``` + +## What to Monitor + +In addition to task rewards/accuracy, a few stability signals are particularly useful with ProRLv2-style runs: + +- **Dynamic sampling efficiency**: if enabled, watch how often batches need multiple generation rounds (see `dapo.md` for detailed guidance). +- **Training–generation mismatch**: `token_mult_prob_error`, `gen_kl_error`, `policy_kl_error`, `js_divergence_error` are computed in `ClippedPGLossFn` (see the [GRPO metrics section](grpo.md#metrics)). +- **Truncation rate**: if high, either increase `policy.max_total_sequence_length`/`policy.generation.max_model_len` or relax truncation penalty (`stop_properly_penalty_coef`). 
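As a reference point for the mismatch metrics above, the per-token TIS/ICE-POP weight handling from the Truncated Importance Sampling section can be sketched as follows. This is standalone illustrative code under the example config's bounds (5.0 / 0.5), not the repo's implementation:

```python
import math

def importance_weight(prev_logprob: float, gen_logprob: float,
                      ratio_max: float = 5.0, ratio_min: float = 0.5,
                      kind: str = "icepop") -> float:
    """Per-token importance weight between the training policy and the generation backend.

    "tis": clamp the weight to at most ratio_max.
    "icepop": zero out weights outside [ratio_min, ratio_max] (filter outliers).
    """
    w = math.exp(prev_logprob - gen_logprob)
    if kind == "tis":
        return min(w, ratio_max)
    if kind == "icepop":
        return w if ratio_min <= w <= ratio_max else 0.0
    raise ValueError(f"unknown kind: {kind}")

# A 10x backend mismatch is clamped to 5.0 under "tis" but dropped under "icepop"
print(importance_weight(math.log(10.0), 0.0, kind="tis"))     # 5.0
print(importance_weight(math.log(10.0), 0.0, kind="icepop"))  # 0.0
```

A rising fraction of zeroed ICE-POP weights is itself a useful signal that training/generation numerics are drifting apart.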
+ +## References + +- **ProRLv2 blog**: [Scaling LLM Reinforcement Learning with Prolonged Training using ProRL v2](https://developer.nvidia.com/blog/scaling-llm-reinforcement-learning-with-prolonged-training-using-prorl-v2/) +- **DAPO**: [Decoupled Clip and Dynamic Sampling Policy Optimization](https://arxiv.org/pdf/2503.14476) +- **GRPO**: [Group Relative Policy Optimization](https://arxiv.org/abs/2402.03300) +- **REINFORCE++**: [REINFORCE++](https://arxiv.org/abs/2501.03262) +- **DLER (stop properly penalty explanation)**: [DLER](https://arxiv.org/pdf/2510.15110) +- **[NeMo RL GRPO Guide](grpo.md)** +- **[NeMo RL DAPO Guide](dapo.md)** diff --git a/docs/index.md b/docs/index.md index e67dea73b7..40af886f92 100644 --- a/docs/index.md +++ b/docs/index.md @@ -209,6 +209,7 @@ adding-new-models.md guides/sft.md guides/dpo.md guides/dapo.md +guides/prorlv2.md guides/grpo.md guides/grpo-deepscaler.md guides/grpo-sliding-puzzle.md diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 90269726d7..69e25f8f4f 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -22,6 +22,15 @@ grpo: overlong_buffer_length: 128 overlong_buffer_penalty: 1 max_response_length: ${policy.max_total_sequence_length} + stop_properly_penalty_coef: null + + # Advantage Estimator Configuration + # Options: "grpo" (default) or "reinforce_plus_plus" + adv_estimator: + name: "grpo" # Use "reinforce_plus_plus" for Reinforce++ estimator + normalize_rewards: true + use_leave_one_out_baseline: false + minus_baseline: true # Reinforce++-baseline specific: subtract per-prompt mean baseline reward_scaling: enabled: false source_min: 0.0 @@ -52,9 +61,12 @@ loss_fn: # Set to true when async_grpo.enabled is true use_importance_sampling_correction: false truncated_importance_sampling_ratio: null + truncated_importance_sampling_ratio_min: null # Lower bound for ICE-POP + truncated_importance_sampling_type: tis # "tis" (clamp to max) or 
"icepop" (filter outside [min, max]) sequence_level_importance_ratios: false token_level_loss: true force_on_policy_ratio: false # Set to true to force ratio=1.0 (requires train_global_batch_size == num_prompts_per_step * num_generations_per_prompt) + use_kl_in_reward: false # Reinforce++: add KL penalty to reward instead of loss checkpointing: enabled: true diff --git a/examples/configs/prorlv2.yaml b/examples/configs/prorlv2.yaml new file mode 100644 index 0000000000..74545ddbf4 --- /dev/null +++ b/examples/configs/prorlv2.yaml @@ -0,0 +1,106 @@ +# ProRLv2 Algorithm Configuration +# +# This configuration implements ProRLv2 with TIS techniques: +# - Dynamic Sampling: Filter prompts with zero reward variance +# - Decoupled Clipping: Asymmetric ratio clipping (clip_max > clip_min) +# - Token-level Loss: Fine-grained policy gradient +# - Truncated Importance Sampling (TIS) / IcePop for MoE models +# - REINFORCE++: Decoupled local and global advantage normalization estimator +# - Stop properly penalty: Reward scale coefficient for truncated responses +# +# Inherits from grpo_math_1B.yaml +# +# Usage: +# python examples/run_grpo_math.py --config examples/configs/prorlv2.yaml +# +# Reference papers and blogs: +# ProRLv2: https://developer.nvidia.com/blog/scaling-llm-reinforcement-learning-with-prolonged-training-using-prorl-v2/ +# REINFORCE++: https://arxiv.org/abs/2501.03262 +# The Online IcePop Solution for MoE models: https://hijkzzz.notion.site/online-ice-pop +# DLER (for Stop properly penalty): https://arxiv.org/pdf/2510.15110 + +defaults: "grpo_math_1B.yaml" + +grpo: + # ============================================================================ + # DAPO: Dynamic Sampling + # Filter out prompts where all generations have the same reward (std=0) + # This focuses training on "learnable" examples with mixed outcomes + # ============================================================================ + use_dynamic_sampling: true + dynamic_sampling_max_gen_batches: 10 # 
Max batches before error
  batch_multiplier: 1.5 # Generate more prompts to account for filtering


  # ============================================================================
  # Advantage Estimator
  # Options: "grpo" (default) or "reinforce_plus_plus"
  # ============================================================================
  adv_estimator:
    name: "reinforce_plus_plus" # Use "grpo" for standard GRPO
    # Global normalization of rewards
    normalize_rewards: true
    use_leave_one_out_baseline: false
    # Reinforce++-Baseline specific
    minus_baseline: true

  # ============================================================================
  # Reward Shaping
  # Applied to rewards before advantage calculation
  # Includes DAPO overlong penalty and stop properly penalty
  # ============================================================================
  reward_shaping:
    enabled: true
    # Stop properly penalty: scale factor for truncated responses (0-1)
    # 0 = zero reward for truncated (default), 1 = no penalty
    stop_properly_penalty_coef: 0.0 # Set to e.g., 0.5 to halve truncated rewards

# ============================================================================
# Loss Function Configuration
# ============================================================================
loss_fn:
  # KL regularization
  reference_policy_kl_penalty: 0.0001
  reference_policy_kl_type: "k2"
  kl_input_clamp_value: 20.0
  kl_output_clamp_value: 10.0

  # ============================================================================
  # DAPO: Decoupled (Asymmetric) Clipping
  # ratio_clip_max > ratio_clip_min allows more exploration
  # Standard PPO uses symmetric clipping (both = 0.2)
  # ============================================================================
  ratio_clip_min: 0.2
  ratio_clip_max: 0.27 # Slightly larger for exploration

  # Dual-clipping (set to e.g., 3.0 to enable, null to disable)
  ratio_clip_c: null

  #
============================================================================ + # DAPO: Token-level Loss + # Compute loss per-token instead of per-sequence + # ============================================================================ + token_level_loss: true + + # ============================================================================ + # Truncated Importance Sampling (TIS / ICE-POP) + # Requires use_importance_sampling_correction: true + # ============================================================================ + use_importance_sampling_correction: true + truncated_importance_sampling_ratio: 5.0 # Upper bound + truncated_importance_sampling_ratio_min: 0.5 # Lower bound (ICE-POP only) + # Type: "tis" (clamp to max) or "icepop" (filter outside [min, max]) + truncated_importance_sampling_type: "icepop" + + # Reinforce++: add KL penalty to reward instead of loss + # Set to false to use external KL loss (reference_policy_kl_penalty) for better stability + use_kl_in_reward: false + +# ============================================================================ +# Output directories +# ============================================================================ +checkpointing: + checkpoint_dir: "results/prorl" + +logger: + log_dir: "logs/prorl" diff --git a/examples/configs/recipes/llm/prorlv2-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.yaml b/examples/configs/recipes/llm/prorlv2-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.yaml new file mode 100644 index 0000000000..89d364c521 --- /dev/null +++ b/examples/configs/recipes/llm/prorlv2-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.yaml @@ -0,0 +1,29 @@ +defaults: ../../prorlv2.yaml +grpo: + max_num_steps: 450 +checkpointing: + checkpoint_dir: results/prorlv2-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1 +policy: + model_name: Qwen/Qwen2.5-Math-1.5B-Instruct + tokenizer: + name: Qwen/Qwen2.5-Math-1.5B-Instruct + dynamic_batching: + enabled: true + sequence_packing: + enabled: false + make_sequence_length_divisible_by: 1 + 
generation: + max_new_tokens: 512 + vllm_cfg: + max_model_len: 512 +data: + max_input_seq_length: 512 +logger: + log_dir: logs/prorlv2-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1 + wandb_enabled: true + tensorboard_enabled: true + wandb: + project: nemo-rl + name: prorlv2-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1 +cluster: + gpus_per_node: 8 diff --git a/examples/configs/vlm_grpo_3B.yaml b/examples/configs/vlm_grpo_3B.yaml index f9612007a4..9c6049d85d 100644 --- a/examples/configs/vlm_grpo_3B.yaml +++ b/examples/configs/vlm_grpo_3B.yaml @@ -22,6 +22,13 @@ grpo: overlong_buffer_length: 512 overlong_buffer_penalty: 1 max_response_length: ${policy.max_total_sequence_length} + # Advantage Estimator Configuration + # Options: "grpo" (default) or "reinforce_plus_plus" + adv_estimator: + name: "grpo" # Use "reinforce_plus_plus" for Reinforce++ estimator + normalize_rewards: true + use_leave_one_out_baseline: false + minus_baseline: true # Reinforce++-baseline specific: subtract per-prompt mean baseline reward_scaling: enabled: false source_min: 0.0 diff --git a/examples/configs/vlm_grpo_3B_megatron.yaml b/examples/configs/vlm_grpo_3B_megatron.yaml index b32cd7df04..8e13681629 100644 --- a/examples/configs/vlm_grpo_3B_megatron.yaml +++ b/examples/configs/vlm_grpo_3B_megatron.yaml @@ -20,6 +20,13 @@ grpo: overlong_buffer_length: 512 overlong_buffer_penalty: 1 max_response_length: ${policy.max_total_sequence_length} + # Advantage Estimator Configuration + # Options: "grpo" (default) or "reinforce_plus_plus" + adv_estimator: + name: "grpo" # Use "reinforce_plus_plus" for Reinforce++ estimator + normalize_rewards: true + use_leave_one_out_baseline: false + minus_baseline: true # Reinforce++-baseline specific: subtract per-prompt mean baseline reward_scaling: enabled: false source_min: 0.0 diff --git a/examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml b/examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml index f3e3dcccc8..45a32cd273 100644 --- 
a/examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml +++ b/examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml @@ -21,6 +21,13 @@ grpo: overlong_buffer_length: 128 overlong_buffer_penalty: 1 max_response_length: ${policy.max_total_sequence_length} + # Advantage Estimator Configuration + # Options: "grpo" (default) or "reinforce_plus_plus" + adv_estimator: + name: "grpo" # Use "reinforce_plus_plus" for Reinforce++ estimator + normalize_rewards: true + use_leave_one_out_baseline: false + minus_baseline: true # Reinforce++-baseline specific: subtract per-prompt mean baseline reward_scaling: enabled: false source_min: 0.0 diff --git a/nemo_rl/algorithms/advantage_estimator.py b/nemo_rl/algorithms/advantage_estimator.py new file mode 100644 index 0000000000..6d14d8637a --- /dev/null +++ b/nemo_rl/algorithms/advantage_estimator.py @@ -0,0 +1,144 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Advantage Estimators for RL algorithms. 
+ +This module provides different advantage estimation strategies: +- GRPOAdvantageEstimator: Standard GRPO advantage with leave-one-out baseline +- ReinforcePlusPlusAdvantageEstimator: Reinforce++ with optional baseline subtraction (minus_baseline) and KL penalty in reward +Reference papers: +- ProRLv2: https://developer.nvidia.com/blog/scaling-llm-reinforcement-learning-with-prolonged-training-using-prorl-v2/ +- Reinforce++: https://arxiv.org/abs/2501.03262 +""" + +import torch + +from nemo_rl.algorithms.utils import calculate_baseline_and_std_per_prompt, calculate_kl + + +class GRPOAdvantageEstimator: + """GRPO-style advantage estimator with leave-one-out baseline. + + Note: GRPO computes advantages over all responses for each prompt. + """ + + def __init__(self, estimator_config: dict, loss_config: dict): + self.use_leave_one_out_baseline = estimator_config["use_leave_one_out_baseline"] + self.normalize_rewards = estimator_config["normalize_rewards"] + + def compute_advantage(self, prompt_ids, rewards, mask, **kwargs): + """Compute GRPO advantages. + + Args: + prompt_ids: Tensor of shape [batch_size] identifying which prompt each sample belongs to. + rewards: Tensor of shape [batch_size] containing reward for each sample. + mask: Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding. + Used only for expanding advantages to token-level shape. + **kwargs: Additional arguments (unused). + + Returns: + Advantages tensor of shape [batch_size, seq_len]. 
+ """ + baseline, std = calculate_baseline_and_std_per_prompt( + prompt_ids, + rewards, + torch.ones_like(rewards), + leave_one_out_baseline=self.use_leave_one_out_baseline, + ) + advantages = (rewards - baseline).unsqueeze(-1) + + if self.normalize_rewards: + # don't sharpen the ones with no variation + epsilon = 1e-6 + non_zero_std_mask = std > 0 + advantages[non_zero_std_mask] = advantages[non_zero_std_mask] / ( + std.unsqueeze(-1)[non_zero_std_mask] + epsilon + ) + + return advantages.expand(mask.shape) + + +class ReinforcePlusPlusAdvantageEstimator: + """Reinforce++ advantage estimator with optional baseline subtraction and KL penalty in reward. + + Args: + minus_baseline: If True, subtract per-prompt mean baseline from rewards. + use_kl_in_reward: If True, add KL penalty to reward instead of loss. + """ + + def __init__(self, estimator_config: dict, loss_config: dict): + self.minus_baseline = estimator_config["minus_baseline"] + self.use_kl_in_reward = loss_config["use_kl_in_reward"] + self.kl_coef = loss_config["reference_policy_kl_penalty"] + self.kl_type = loss_config["reference_policy_kl_type"] + + def compute_advantage( + self, + prompt_ids, + rewards, + mask, + logprobs_policy=None, + logprobs_reference=None, + **kwargs, + ): + """Compute Reinforce++ advantages with optional KL penalty. + + Args: + prompt_ids: Tensor of shape [batch_size] identifying which prompt each sample belongs to. + rewards: Tensor of shape [batch_size] containing reward for each sample. + mask: Response token mask of shape [batch_size, seq_len], 1 for valid response tokens, 0 for padding. + Used for: (1) expanding advantages to token-level shape, (2) global normalization + that only considers valid tokens. + logprobs_policy: Policy log probabilities of shape [batch_size, seq_len], required if use_kl_in_reward. + logprobs_reference: Reference policy log probabilities of shape [batch_size, seq_len], required if use_kl_in_reward. + **kwargs: Additional arguments (unused). 
+ + Returns: + Advantages tensor of shape [batch_size, seq_len], globally normalized across valid tokens. + """ + # minus baseline + if self.minus_baseline: + mean, _ = calculate_baseline_and_std_per_prompt( + prompt_ids, + rewards, + torch.ones_like(rewards), + leave_one_out_baseline=False, + ) + adv = rewards - mean + else: + adv = rewards + + adv = adv.unsqueeze(-1) + adv = adv.expand(mask.shape) + + # add kl penalty to reward (token-level) + if ( + self.use_kl_in_reward + and logprobs_policy is not None + and logprobs_reference is not None + ): + kl = calculate_kl( + logprobs_policy, + logprobs_reference, + kl_type=self.kl_type, + ) + adv = adv - self.kl_coef * kl + + # global normalization across the batch + adv_mean = (adv * mask).sum() / mask.sum() + adv_var = ((adv - adv_mean).pow(2) * mask).sum() / mask.sum() + adv_rstd = adv_var.clamp(min=1e-8).rsqrt() + adv = (adv - adv_mean) * adv_rstd + + return adv diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index b959e2828f..103d4a4f6e 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -27,6 +27,10 @@ from transformers import AutoProcessor from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from nemo_rl.algorithms.advantage_estimator import ( + GRPOAdvantageEstimator, + ReinforcePlusPlusAdvantageEstimator, +) from nemo_rl.algorithms.interfaces import LossFunction from nemo_rl.algorithms.loss_functions import ( ClippedPGLossConfig, @@ -114,6 +118,17 @@ class AsyncGRPOConfig(TypedDict): recompute_kv_cache_after_weight_updates: NotRequired[bool] +class AdvEstimatorConfig(TypedDict): + """Configuration for advantage estimator (GRPO or Reinforce++).""" + + name: str # "grpo" or "reinforce_plus_plus" + # GRPO specific + normalize_rewards: NotRequired[bool] + use_leave_one_out_baseline: NotRequired[bool] + # Reinforce++ specific + minus_baseline: NotRequired[bool] + + class GRPOConfig(TypedDict): num_prompts_per_step: int num_generations_per_prompt: int 
@@ -146,6 +161,8 @@ class GRPOConfig(TypedDict): reward_scaling: RewardScalingConfig # By default advantages are calculated on CPU. Setting this flag to true leverages GPU for their computation. calculate_advantages_on_gpu: NotRequired[bool] + # Advantage estimator configuration (grpo or reinforce_plus_plus) + adv_estimator: NotRequired[AdvEstimatorConfig] class GRPOSaveState(TypedDict): @@ -736,33 +753,6 @@ def initialize_generation_with_policy( # =============================================================================== -def normalize_advantages_with_epsilon( - advantages: torch.Tensor, - std: torch.Tensor, - epsilon: float = 1e-6, -) -> torch.Tensor: - """Normalize advantages by standard deviation, skipping samples with zero std. - - When std is exactly zero (from leave-one-out baseline with identical rewards), - normalization is skipped for those samples to prevent numerical instability. - This makes normalize_rewards compatible with use_leave_one_out_baseline. - - Args: - advantages: Tensor of shape (batch_size, 1) containing advantage values - std: Tensor of shape (batch_size,) containing standard deviation values - epsilon: Small value to avoid division by very small std, defaults to 1e-6 - - Returns: - Normalized advantages tensor of same shape as input advantages - """ - # Only normalize where std > 0 to avoid division by near-zero - non_zero_std_mask = std > 0 - advantages[non_zero_std_mask] = advantages[non_zero_std_mask] / ( - std.unsqueeze(-1)[non_zero_std_mask] + epsilon - ) - return advantages - - def dynamic_sampling( repeated_batch: BatchedDataDict[DatumSpec], std: torch.Tensor, @@ -982,6 +972,73 @@ def _should_log_nemo_gym_responses(master_config: MasterConfig) -> bool: return should_log_nemo_gym_responses +def _create_advantage_estimator(master_config: MasterConfig): + """Create and return an advantage estimator based on configuration. + + Args: + master_config: The master configuration dictionary. 
+
+    Returns:
+        An advantage estimator instance (GRPOAdvantageEstimator or ReinforcePlusPlusAdvantageEstimator).
+
+    Raises:
+        ValueError: If the advantage estimator name is not recognized.
+    """
+    grpo_config = master_config["grpo"]
+    loss_config = master_config["loss_fn"]
+
+    # Provide backward-compatible defaults when adv_estimator is not in config.
+    # Fall back to top-level grpo.normalize_rewards / grpo.use_leave_one_out_baseline
+    # which older configs still use.
+    adv_estimator_config = grpo_config.get(
+        "adv_estimator",
+        {
+            "name": "grpo",
+            "normalize_rewards": grpo_config.get("normalize_rewards", True),
+            "use_leave_one_out_baseline": grpo_config.get(
+                "use_leave_one_out_baseline", False
+            ),
+            "minus_baseline": True,
+        },
+    )
+
+    adv_estimator_name = adv_estimator_config["name"]
+    if adv_estimator_name == "grpo":
+        adv_estimator = GRPOAdvantageEstimator(adv_estimator_config, loss_config)
+        print(" ✓ Using GRPO advantage estimator")
+    elif adv_estimator_name == "reinforce_plus_plus":
+        adv_estimator = ReinforcePlusPlusAdvantageEstimator(
+            adv_estimator_config, loss_config
+        )
+        print(" ✓ Using Reinforce++ advantage estimator")
+    else:
+        raise ValueError(f"Invalid adv_estimator name: {adv_estimator_name}")
+
+    return adv_estimator
+
+
+def _extract_prompt_only_messages(message_logs: list) -> list:
+    """Extract only prompt messages (user/system) from message logs.
+
+    This is used to get prompt IDs for advantage estimation, excluding
+    any assistant responses.
+
+    Args:
+        message_logs: List of message logs, where each log is a list of messages.
+
+    Returns:
+        List of message logs containing only user and system messages.
+    """
+    prompt_only_message_logs = []
+    for message_log in message_logs:
+        prompt_only_log = []
+        for message in message_log:
+            if message["role"] == "user" or message["role"] == "system":
+                prompt_only_log.append(message)
+        prompt_only_message_logs.append(prompt_only_log)
+    return prompt_only_message_logs
+
+
 def refit_policy_generation(
     policy: ColocatablePolicyInterface,
     policy_generation: GenerationInterface,
@@ -1172,6 +1229,9 @@ def grpo_train(
     val_period = master_config["grpo"]["val_period"]
     colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"]
 
+    # Initialize advantage estimator
+    adv_estimator = _create_advantage_estimator(master_config)
+
     # Run validation at the start if configured
     # TODO: Add validation with kv scales if needed
     if val_at_start and current_step == 0:
@@ -1440,20 +1500,20 @@ def grpo_train(
            gen_step_metrics = policy_generation.get_step_metrics()
 
            advantages = (rewards - baseline).unsqueeze(-1)
-           if master_config["grpo"]["normalize_rewards"]:
-               advantages = normalize_advantages_with_epsilon(
-                   advantages=advantages,
-                   std=std,
-               )
+           # Save baseline for logging (before deletion)
+           baseline_for_log = baseline.clone()
 
-           _log_mixed_rewards_and_advantages_information(
-               logger=logger,
-               total_steps=total_steps,
-               metrics=metrics,
-               baseline=baseline,
-               advantages=advantages,
+           # Extract prompt-only messages for advantage estimation
+           prompt_only_message_logs = _extract_prompt_only_messages(
+               repeated_batch["message_log"]
            )
-
+           prompt_batched_flat, _ = batched_message_log_to_flat_message(
+               prompt_only_message_logs,
+               pad_value_dict={"token_ids": tokenizer.pad_token_id},
+           )
+           prompt_ids_for_adv = prompt_batched_flat["token_ids"]
+           del prompt_only_message_logs
+           del prompt_batched_flat
            del input_ids
            del baseline
            del std
@@ -1469,7 +1529,7 @@ def grpo_train(
                loss_multiplier[truncated] = 0
                repeated_batch["loss_multiplier"] = loss_multiplier
 
-           # Add loss mask and advantages to each message in LLMMessageLogType
+           # Add loss mask to each message in LLMMessageLogType
            for i, message_log in enumerate(repeated_batch["message_log"]):
                for j, message in enumerate(message_log):
                    if message["role"] == "assistant":
@@ -1484,10 +1544,6 @@ def grpo_train(
                        message["generation_logprobs"] = torch.zeros_like(
                            message["token_ids"], dtype=torch.float32
                        )
-                       message["advantages"] = advantages[i].expand(
-                           message["token_ids"].shape
-                       )
-           del advantages
 
            # Convert updated LLMMessageLogType to FlatMessagesType for training
            flat_messages, input_lengths = batched_message_log_to_flat_message(
@@ -1499,11 +1555,11 @@ def grpo_train(
            )
 
            # Create training data from flattened messages
+           # Note: advantages will be computed and added after logprobs are available
            train_data = BatchedDataDict[ClippedPGLossDataDict](
                {
                    "input_ids": flat_messages["token_ids"],
                    "input_lengths": input_lengths,
-                   "advantages": flat_messages["advantages"],
                    "generation_logprobs": flat_messages["generation_logprobs"],
                    "token_mask": flat_messages["token_loss_mask"],
                    "sample_mask": repeated_batch["loss_multiplier"],
@@ -1551,6 +1607,33 @@ def grpo_train(
            del logprob_data
            del extra_multimodal_data
 
+           # Compute advantages with adv_estimator using correct mask and logprobs
+           with timer.time("advantage_calculation"):
+               print("▶ Computing advantages...", flush=True)
+               # Get token-level mask: token_mask * sample_mask
+               token_mask = train_data["token_mask"]
+               sample_mask = train_data["sample_mask"]
+               mask = token_mask * sample_mask.unsqueeze(-1)
+
+               train_data["advantages"] = adv_estimator.compute_advantage(
+                   prompt_ids=prompt_ids_for_adv,
+                   rewards=rewards,
+                   mask=mask,
+                   logprobs_policy=train_data["prev_logprobs"],
+                   logprobs_reference=train_data.get("reference_policy_logprobs"),
+               )
+               del prompt_ids_for_adv
+
+               # Log rewards and advantages information
+               _log_mixed_rewards_and_advantages_information(
+                   logger=logger,
+                   total_steps=total_steps,
+                   metrics=metrics,
+                   baseline=baseline_for_log,
+                   advantages=train_data["advantages"],
+               )
+               del baseline_for_log
+
            memory_tracker.snapshot_start_of_stage("Policy train", dir())
            print("▶ Preparing for training...", flush=True)
            with timer.time("training_prep"):
@@ -1618,7 +1701,7 @@ def grpo_train(
            )
 
            # Get flat advantages and token mask for masked metrics computation
-           flat_advantages = flat_messages["advantages"]
+           flat_advantages = train_data["advantages"]
            flat_token_mask = flat_messages["token_loss_mask"]
            del flat_messages
 
@@ -1809,6 +1892,7 @@ def grpo_train(
                    total_steps + 1,
                    name="train/token_mult_prob_error_plot_sample",
                )
+           del train_data
            if master_config["policy"]["generation"].get("vllm_cfg", {}).get(
                "enable_vllm_metrics_logger", False
            ) and master_config.get("logger", {}).get("wandb_enabled", False):
@@ -1903,7 +1987,7 @@ def grpo_train(
            # processing rewards
            del repeated_batch
            del rewards
-           del train_data
+           # train_data already deleted after logging above
            # logging
            del metrics
            if "val_metrics" in dir():
@@ -2160,6 +2244,9 @@ def async_grpo_train(
    val_at_end = master_config["grpo"]["val_at_end"]
    colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"]
 
+    # Initialize advantage estimator
+    adv_estimator = _create_advantage_estimator(master_config)
+
    assert not colocated_inference, (
        "Colocated inference is not supported for async GRPO. Please use non-colocated inference."
    )
@@ -2437,59 +2524,27 @@ def async_grpo_train(
                print("▶ Processing rewards...")
                with timer.time("reward_calculation"):
-                   prompt_only_message_logs = []
-                   for message_log in repeated_batch["message_log"]:
-                       prompt_only_log = []
-                       for message in message_log:
-                           if message["role"] == "user" or message["role"] == "system":
-                               prompt_only_log.append(message)
-                       prompt_only_message_logs.append(prompt_only_log)
-
-                   prompt_batched_flat, prompt_input_lengths = (
-                       batched_message_log_to_flat_message(
-                           prompt_only_message_logs,
-                           pad_value_dict={"token_ids": tokenizer.pad_token_id},
-                       )
+                   # Extract prompt-only messages for advantage estimation
+                   prompt_only_message_logs = _extract_prompt_only_messages(
+                       repeated_batch["message_log"]
+                   )
+                   prompt_batched_flat, _ = batched_message_log_to_flat_message(
+                       prompt_only_message_logs,
+                       pad_value_dict={"token_ids": tokenizer.pad_token_id},
                    )
-                   prompt_only_ids = prompt_batched_flat["token_ids"]
+                   prompt_ids_for_adv = prompt_batched_flat["token_ids"]
+                   del prompt_only_message_logs
+                   del prompt_batched_flat
 
                    rewards = repeated_batch["total_reward"]
-                   print("▶ Computing advantages...")
-
-                   baseline, std = calculate_baseline_and_std_per_prompt(
-                       prompt_only_ids,
-                       rewards,
-                       torch.ones_like(rewards),
-                       leave_one_out_baseline=master_config["grpo"][
-                           "use_leave_one_out_baseline"
-                       ],
-                   )
-                   advantages = (rewards - baseline).unsqueeze(-1)
-
                    print(
                        f" 📊 Rewards stats: min={rewards.min():.4f}, max={rewards.max():.4f}, mean={rewards.mean():.4f}, std={rewards.std():.4f}"
                    )
-                   print(
-                       f" 📊 Baseline stats: min={baseline.min():.4f}, max={baseline.max():.4f}, mean={baseline.mean():.4f}"
-                   )
-                   print(
-                       f" 📊 Advantages stats: min={advantages.min():.4f}, max={advantages.max():.4f}, mean={advantages.mean():.4f}, std={advantages.std():.4f}"
-                   )
-
-                   if master_config["grpo"]["normalize_rewards"]:
-                       advantages = normalize_advantages_with_epsilon(
-                           advantages=advantages,
-                           std=std,
-                       )
-
-                   print(
-                       f" 📊 Normalized advantages stats: min={advantages.min():.4f}, max={advantages.max():.4f}, mean={advantages.mean():.4f}, std={advantages.std():.4f}"
-                   )
 
                # Prepare training data (same as sync version)
                with timer.time("data_processing"):
-                   # Add loss mask and advantages to each message
+                   # Add loss mask to each message
                    for i, message_log in enumerate(repeated_batch["message_log"]):
                        for j, message in enumerate(message_log):
                            if message["role"] == "assistant":
@@ -2504,9 +2559,6 @@ def async_grpo_train(
                                message["generation_logprobs"] = torch.zeros_like(
                                    message["token_ids"], dtype=torch.float32
                                )
-                               message["advantages"] = advantages[i].expand(
-                                   message["token_ids"].shape
-                               )
 
                    # Convert to flat format for training
                    flat_messages, input_lengths = batched_message_log_to_flat_message(
@@ -2518,11 +2570,11 @@ def async_grpo_train(
                    )
 
                    # Create training data
+                   # Note: advantages will be computed and added after logprobs are available
                    train_data = BatchedDataDict[ClippedPGLossDataDict](
                        {
                            "input_ids": flat_messages["token_ids"],
                            "input_lengths": input_lengths,
-                           "advantages": flat_messages["advantages"],
                            "generation_logprobs": flat_messages["generation_logprobs"],
                            "token_mask": flat_messages["token_loss_mask"],
                            "sample_mask": repeated_batch["loss_multiplier"],
@@ -2548,6 +2600,33 @@ def async_grpo_train(
                    train_data["prev_logprobs"] = fprop_logprobs
                    train_data["reference_policy_logprobs"] = reference_logprobs
 
+               # Compute advantages with adv_estimator using correct mask and logprobs
+               with timer.time("advantage_calculation"):
+                   print("▶ Computing advantages...", flush=True)
+                   # Get token-level mask: token_mask * sample_mask
+                   token_mask = train_data["token_mask"]
+                   sample_mask = train_data["sample_mask"]
+                   mask = token_mask * sample_mask.unsqueeze(-1)
+
+                   train_data["advantages"] = adv_estimator.compute_advantage(
+                       prompt_ids=prompt_ids_for_adv,
+                       rewards=rewards,
+                       mask=mask,
+                       logprobs_policy=train_data["prev_logprobs"],
+                       logprobs_reference=train_data.get("reference_policy_logprobs"),
+                   )
+                   del prompt_ids_for_adv
+
+                   # Log advantages stats
+                   # Note: For GRPOAdvantageEstimator with normalize_rewards=True, these are
+                   # already normalized advantages (equivalent to "Normalized advantages stats"
+                   # in older versions). For ReinforcePlusPlusAdvantageEstimator, advantages
+                   # are globally normalized across valid tokens.
+                   advantages = train_data["advantages"]
+                   print(
+                       f" 📊 Advantages stats: min={advantages.min():.4f}, max={advantages.max():.4f}, mean={advantages.mean():.4f}, std={advantages.std():.4f}"
+                   )
+
                print("▶ Preparing for training...")
                with timer.time("training_prep"):
                    policy.prepare_for_training()
@@ -2635,8 +2714,11 @@ def async_grpo_train(
                    # Resume trajectory collection after validation
                    trajectory_collector.resume.remote()
                # Get flat advantages and token mask for masked metrics computation
-               flat_advantages = flat_messages["advantages"]
+               flat_advantages = train_data["advantages"]
                flat_token_mask = flat_messages["token_loss_mask"]
+               # Save content for logging before deleting flat_messages
+               flat_messages_content = flat_messages.get("content", [])
                del flat_messages
 
                # Filter advantages using token mask (only valid response tokens)
                response_advantages = torch.masked_select(
@@ -2774,7 +2856,7 @@ def async_grpo_train(
                        checkpointer.finalize_checkpoint(checkpoint_path)
                    policy.offload_after_refit()
 
-               log_data = {"content": flat_messages["content"]}
+               log_data = {"content": flat_messages_content}
                log_data["rewards"] = rewards.tolist()
                log_data["generation_logprobs"] = train_data["generation_logprobs"].tolist()
                log_data["prev_logprobs"] = train_data["prev_logprobs"].tolist()
@@ -2782,6 +2864,8 @@ def async_grpo_train(
                logger.log_batched_dict_as_jsonl(
                    log_data, f"train_data_step{step + 1}.jsonl"
                )
+               del train_data
+               del flat_messages_content
 
                timing_metrics: dict[str, float] = timer.get_timing_metrics(
                    reduction_op="sum"
diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py
index 21333d1f8d..c61cb5f0ce 100755
--- a/nemo_rl/algorithms/loss_functions.py
+++ b/nemo_rl/algorithms/loss_functions.py
@@ -45,6 +45,10 @@ class ClippedPGLossConfig(TypedDict):
    use_on_policy_kl_approximation: bool
    use_importance_sampling_correction: bool
    truncated_importance_sampling_ratio: float | None
+    # Type of truncated importance sampling: "tis" (clamp max) or "icepop" (filter [min, max])
+    truncated_importance_sampling_type: NotRequired[str | None]
+    # Lower bound for ICE-POP filtering (default 0.5)
+    truncated_importance_sampling_ratio_min: NotRequired[float | None]
    token_level_loss: bool
    # If True, apply the off-policy importance-sampling correction at the
    # sequence level (one weight per generated sample), as in GSPO.
@@ -57,6 +61,8 @@ class ClippedPGLossConfig(TypedDict):
    # NOTE: This should only be used when doing exactly one update per rollout
    # (i.e., num_prompts_per_step * num_generations_per_prompt == train_global_batch_size)
    force_on_policy_ratio: NotRequired[bool]
+    # If True, add KL penalty to reward instead of loss (used by Reinforce++)
+    use_kl_in_reward: NotRequired[bool]
 
 
 class ClippedPGLossDataDict(TypedDict):
@@ -132,6 +138,14 @@ def __init__(self, cfg: ClippedPGLossConfig):
        self.truncated_importance_sampling_ratio = cfg[
            "truncated_importance_sampling_ratio"
        ]
+        # Type of truncated importance sampling: "tis" (clamp max) or "icepop" (filter [min, max])
+        self.truncated_importance_sampling_type = cfg.get(
+            "truncated_importance_sampling_type"
+        )
+        # Lower bound for ICE-POP filtering (falls back to 0.5 when unset, matching the documented default)
+        self.truncated_importance_sampling_ratio_min = cfg.get(
+            "truncated_importance_sampling_ratio_min", 0.5
+        )
        # Whether to compute importance weights per-sequence instead of per-token.
        self.sequence_level_importance_ratios = cfg.get(
            "sequence_level_importance_ratios",
@@ -151,6 +165,23 @@ def __init__(self, cfg: ClippedPGLossConfig):
            assert self.truncated_importance_sampling_ratio > 0, (
                "truncated_importance_sampling_ratio should be positive"
            )
+            assert self.truncated_importance_sampling_type in ("tis", "icepop"), (
+                f"truncated_importance_sampling_type must be 'tis' or 'icepop', got {self.truncated_importance_sampling_type}"
+            )
+        else:
+            # Warn user that TIS-related parameters are ignored when truncated_importance_sampling_ratio is not set
+            ignored_params = []
+            if cfg.get("truncated_importance_sampling_type") is not None:
+                ignored_params.append("truncated_importance_sampling_type")
+            if cfg.get("truncated_importance_sampling_ratio_min") is not None:
+                ignored_params.append("truncated_importance_sampling_ratio_min")
+            if ignored_params:
+                print(
+                    f"[WARN] truncated_importance_sampling_ratio is not set, so the following "
+                    f"parameters are ignored: {', '.join(ignored_params)}. "
+                    f"Set truncated_importance_sampling_ratio to enable truncated importance sampling.",
+                    flush=True,
+                )
 
    def __call__(
        self,
@@ -370,12 +401,35 @@ def __call__(
                actor_importance_weights_expanded = torch.nan_to_num(
                    actor_importance_weights_expanded, nan=0.0, posinf=0.0, neginf=0.0
                )
-               # TIS see https://fengyao.notion.site/off-policy-rl
+               # Truncated Importance Sampling (TIS / ICE-POP)
+               # TIS: Simple clamp to max value
+               # ICE-POP: Filter out samples with importance weights outside [min, max]
                if self.truncated_importance_sampling_ratio is not None:
-                   actor_importance_weights_expanded = torch.clamp(
-                       actor_importance_weights_expanded,
-                       max=self.truncated_importance_sampling_ratio,
-                   )
+                   if self.truncated_importance_sampling_type == "tis":
+                       # TIS: Simple clamp to max value
+                       actor_importance_weights_expanded = torch.clamp(
+                           actor_importance_weights_expanded,
+                           max=self.truncated_importance_sampling_ratio,
+                       )
+                   elif self.truncated_importance_sampling_type == "icepop":
+                       # ICE-POP: Filter out samples with importance weights outside [min, max]
+                       actor_importance_weights_expanded = torch.where(
+                           (
+                               actor_importance_weights_expanded
+                               >= self.truncated_importance_sampling_ratio_min
+                           )
+                           & (
+                               actor_importance_weights_expanded
+                               <= self.truncated_importance_sampling_ratio
+                           ),
+                           actor_importance_weights_expanded,
+                           torch.zeros_like(actor_importance_weights_expanded),
+                       )
+                   else:
+                       raise ValueError(
+                           f"Invalid truncated importance sampling type: {self.truncated_importance_sampling_type}"
+                       )
+
                actor_importance_weights = actor_importance_weights_expanded
                del actor_importance_weights_expanded
                if self.use_importance_sampling_correction:
diff --git a/nemo_rl/algorithms/reward_functions.py b/nemo_rl/algorithms/reward_functions.py
index b4f2ad4d70..87c826db26 100644
--- a/nemo_rl/algorithms/reward_functions.py
+++ b/nemo_rl/algorithms/reward_functions.py
@@ -44,6 +44,11 @@ class RewardShapingConfig(TypedDict):
    # The maximum response length threshold. Responses exceeding this length will be penalized.
    max_response_length: NotRequired[int]
 
+    # Stop properly penalty: scale factor for rewards of truncated responses, in [0, 1].
+    # When set to 0, truncated responses get zero reward.
+    # When set to 1, no penalty is applied (default behavior).
+    stop_properly_penalty_coef: NotRequired[float | None]
+
 
 def apply_reward_shaping(
    batch: BatchedDataDict, cfg: RewardShapingConfig
@@ -56,11 +61,62 @@ def apply_reward_shaping(
    if not cfg["enabled"]:
        return batch
 
+    # Apply stop properly penalty if configured
+    stop_properly_penalty_coef = cfg.get("stop_properly_penalty_coef", None)
+    if stop_properly_penalty_coef is not None:
+        assert 0 <= stop_properly_penalty_coef <= 1, (
+            f"stop_properly_penalty_coef must be in [0, 1], got {stop_properly_penalty_coef}"
+        )
+        # Warn user that DAPO overlong parameters are ignored when stop_properly_penalty_coef is set
+        ignored_params = []
+        if cfg.get("overlong_buffer_length") is not None:
+            ignored_params.append("overlong_buffer_length")
+        if cfg.get("overlong_buffer_penalty") is not None:
+            ignored_params.append("overlong_buffer_penalty")
+        if cfg.get("max_response_length") is not None:
+            ignored_params.append("max_response_length")
+        if ignored_params:
+            print(
+                f"[WARN] stop_properly_penalty_coef is set, so the following DAPO overlong "
+                f"parameters are ignored: {', '.join(ignored_params)}. "
+                f"Set stop_properly_penalty_coef=null to use DAPO overlong reward shaping instead.",
+                flush=True,
+            )
+        truncated = batch.get("truncated")
+        assert truncated is not None, "truncated field not found in batch"
+        if isinstance(truncated, list):
+            truncated = torch.tensor(truncated, dtype=torch.bool, device=rewards.device)
+        else:
+            truncated = truncated.to(device=rewards.device)
+
+        num_truncated = truncated.sum().item()
+        if num_truncated > 0:
+            original_rewards = rewards.clone()
+            # For truncated samples, scale the reward by stop_properly_penalty_coef
+            rewards = torch.where(
+                truncated, rewards * stop_properly_penalty_coef, rewards
+            )
+            batch["total_reward"] = rewards
+            print(
+                f"[INFO] stop properly penalty applied: {num_truncated}/{len(truncated)} samples truncated, "
+                f"coef={stop_properly_penalty_coef}, "
+                f"original_reward_mean={original_rewards[truncated].mean().item():.4f}, "
+                f"shaped_reward_mean={rewards[truncated].mean().item():.4f}",
+                flush=True,
+            )
+        else:
+            print(
+                "[INFO] stop properly penalty: no truncated samples (truncation_rate=0)",
+                flush=True,
+            )
+
+        return batch
+
    # DAPO reward shaping requires overlong_buffer_length, overlong_buffer_penalty, and max_response_length to be set.
    if (
-        cfg["overlong_buffer_length"] is None
-        or cfg["overlong_buffer_penalty"] is None
-        or cfg["max_response_length"] is None
+        cfg.get("overlong_buffer_length") is None
+        or cfg.get("overlong_buffer_penalty") is None
+        or cfg.get("max_response_length") is None
    ):
        raise ValueError(
            "Reward function is enabled but only DAPO reward shaping is currently supported. Please ensure overlong_buffer_length, overlong_buffer_penalty, and max_response_length are properly configured."
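The "stop properly" shaping in `apply_reward_shaping` above reduces to a single masked multiply over the batch. A minimal standalone sketch of that rule (the `apply_stop_properly_penalty` helper name is illustrative, not part of the repo):

```python
import torch


def apply_stop_properly_penalty(
    rewards: torch.Tensor, truncated: torch.Tensor, coef: float
) -> torch.Tensor:
    """Scale the reward of truncated responses by coef (0 -> zero reward, 1 -> no-op)."""
    assert 0.0 <= coef <= 1.0, f"coef must be in [0, 1], got {coef}"
    truncated = truncated.to(dtype=torch.bool, device=rewards.device)
    # Keep rewards of properly-stopped samples; down-scale the truncated ones
    return torch.where(truncated, rewards * coef, rewards)


rewards = torch.tensor([1.0, 0.5, 1.0])
truncated = torch.tensor([False, True, True])  # last two hit max_tokens without EOS
shaped = apply_stop_properly_penalty(rewards, truncated, coef=0.0)
# only the properly-stopped first sample keeps its reward
```

With `coef=0.0` (the strictest setting), truncated responses contribute zero reward, which pushes the policy to emit a stop token before the length limit; `coef=1.0` disables the penalty entirely.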
diff --git a/nemo_rl/experience/rollouts.py b/nemo_rl/experience/rollouts.py
index 44dab2557f..231820fa10 100644
--- a/nemo_rl/experience/rollouts.py
+++ b/nemo_rl/experience/rollouts.py
@@ -81,6 +81,9 @@ def generate_responses(
    generation_lengths = generation_outputs["generation_lengths"]
    unpadded_sequence_lengths = generation_outputs["unpadded_sequence_lengths"]
 
+    # Extract truncated info if available (response hit max_tokens without stop token)
+    response_truncated = generation_outputs.get("truncated")
+
    # Extract generated parts
    generated_ids = []
    for i in range(len(input_lengths)):
@@ -115,6 +118,10 @@ def generate_responses(
        "total_generated_tokens": generation_lengths.sum().item(),
    }
 
+    # Add response_truncated to gen_metrics for use by caller
+    if response_truncated is not None:
+        gen_metrics["_response_truncated"] = response_truncated
+
    return batch, generated_ids, gen_metrics
 
 
@@ -176,6 +183,9 @@ async def generate_responses_async(
    generation_lengths = generation_outputs["generation_lengths"]
    unpadded_sequence_lengths = generation_outputs["unpadded_sequence_lengths"]
 
+    # Extract truncated info if available (response hit max_tokens without stop token)
+    response_truncated = generation_outputs.get("truncated")
+
    # Extract generated parts
    generated_ids = []
    for i in range(len(input_lengths)):
@@ -220,6 +230,10 @@ async def generate_responses_async(
    except Exception as e:
        print(f"Error occurred while extracting gen_leader_worker_idx: {e}")
 
+    # Add response_truncated to gen_metrics for use by caller
+    if response_truncated is not None:
+        gen_metrics["_response_truncated"] = response_truncated
+
    return batch, generated_ids, gen_metrics
 
 
@@ -427,6 +441,13 @@ def run_multi_turn_rollout(
            greedy=greedy,
        )
 
+        # Record response truncation (response hit max_tokens without stop token)
+        response_truncated = gen_metrics.pop("_response_truncated", None)
+        if response_truncated is not None:
+            for i, global_idx in enumerate(active_indices.tolist()):
+                if response_truncated[i]:
+                    sample_truncated[global_idx] = True
+
        # Record token usage - assistant
        for i, global_idx in enumerate(active_indices.tolist()):
            sample_assistant_token_counts[global_idx] += len(generated_ids[i])
@@ -683,6 +704,11 @@ async def run_sample_multi_turn_rollout(
        )
        current_message_log = updated_message_log
 
+        # Check if response was truncated (hit max_tokens without stop token)
+        response_truncated = gen_metrics.pop("_response_truncated", None)
+        if response_truncated is not None and response_truncated[0]:
+            truncated = True
+
        # Update token counts
        gen_token_count = len(generated_tokens)
        assistant_token_count += gen_token_count
diff --git a/nemo_rl/models/generation/interfaces.py b/nemo_rl/models/generation/interfaces.py
index 80f4ced95e..037b4880f5 100644
--- a/nemo_rl/models/generation/interfaces.py
+++ b/nemo_rl/models/generation/interfaces.py
@@ -172,6 +172,7 @@ class GenerationOutputSpec(TypedDict):
    - generation_lengths: Tensor containing the actual length of each generated sequence
    - unpadded_sequence_lengths: Tensor containing the actual length of each input + generated sequence (without padding)
    - logprobs: Tensor of log probabilities for each generated token (right padded with zeros)
+    - truncated: Boolean tensor indicating if each sequence was truncated (hit max_tokens limit)
    - __extra__: Additional model-specific data fields
 
    Example of a batch with 2 sequences:
@@ -197,6 +198,9 @@ class GenerationOutputSpec(TypedDict):
        [0.0, 0.0, 0.0, -1.2, -0.8, -2.1, -1.5, 0.0],  # First 3 are 0 (input tokens), next 4 are actual logprobs
        [0.0, 0.0, 0.0, 0.0, 0.0, -0.9, -1.7, 0.0],  # First 5 are 0 (input tokens), next 2 are actual logprobs
    ]
+
+    truncated:
+    [False, True]  # Example 2 was truncated (hit max_tokens limit without EOS)
    ```
 
    All functions receiving or returning GenerationOutputSpec should ensure
@@ -209,6 +213,9 @@ class GenerationOutputSpec(TypedDict):
        torch.Tensor
    )  # Length of full valid sequence (input + generated response)
    logprobs: torch.Tensor
+    truncated: NotRequired[
+        torch.Tensor
+    ]  # Whether each sequence was truncated (hit max_tokens without a stop token)
    __extra__: Any
 
diff --git a/nemo_rl/models/generation/vllm/vllm_worker.py b/nemo_rl/models/generation/vllm/vllm_worker.py
index 7e30a33aed..55fc58f774 100644
--- a/nemo_rl/models/generation/vllm/vllm_worker.py
+++ b/nemo_rl/models/generation/vllm/vllm_worker.py
@@ -604,6 +604,7 @@ def generate(
                    "logprobs": torch.zeros((0, 0), dtype=torch.float),
                    "generation_lengths": torch.zeros(0, dtype=torch.long),
                    "unpadded_sequence_lengths": torch.zeros(0, dtype=torch.long),
+                   "truncated": torch.zeros(0, dtype=torch.bool),
                }
            )
 
@@ -636,6 +637,7 @@ def generate(
        logprobs_list = []
        generation_lengths = []
        unpadded_sequence_lengths = []
+        truncated_list = []  # Track if response was truncated (hit max_tokens)
        max_length = 0
        for output in outputs:
            max_length = max(max_length, len(output.outputs[0].token_ids))
@@ -682,6 +684,11 @@ def generate(
            response_length = sequence_length + len(generated_tokens)
            generation_lengths.append(len(generated_tokens))
            unpadded_sequence_lengths.append(response_length)
+
+            # Check if response was truncated (hit max_tokens length limit)
+            is_truncated = generation.finish_reason == "length"
+            truncated_list.append(is_truncated)
+
            assert response_length <= self.llm.llm_engine.model_config.max_model_len, (
                f"response_length={response_length} > max_model_len={self.llm.llm_engine.model_config.max_model_len}, which should not happen. Please check this behavior in isolation by running `uv run --extra vllm tools/model_diagnostics/1.max_model_len_respected.py {self.llm.llm_engine.model_config.model}` and raise this issue with the vllm team."
            )
 
@@ -700,6 +707,7 @@ def generate(
                "unpadded_sequence_lengths": torch.tensor(
                    unpadded_sequence_lengths, dtype=torch.long
                ),
+               "truncated": torch.tensor(truncated_list, dtype=torch.bool),
            }
        )
 
diff --git a/nemo_rl/models/generation/vllm/vllm_worker_async.py b/nemo_rl/models/generation/vllm/vllm_worker_async.py
index 0e4ea5cdeb..0fd2b5c063 100644
--- a/nemo_rl/models/generation/vllm/vllm_worker_async.py
+++ b/nemo_rl/models/generation/vllm/vllm_worker_async.py
@@ -727,12 +727,18 @@ async def process_single_sample(sample_idx):
                    device=input_ids_single_row.device,
                )
 
+               # Not truncated since no generation was attempted (length constraint)
+               truncated_tensor = torch.tensor(
+                   [False], dtype=torch.bool, device=input_ids_single_row.device
+               )
+
                result_batch = BatchedDataDict[GenerationOutputSpec](
                    {
                        "output_ids": output_ids_single_item_batched,
                        "logprobs": logprobs_single_item,
                        "generation_lengths": generation_lengths_tensor,
                        "unpadded_sequence_lengths": unpadded_sequence_lengths_tensor,
+                       "truncated": truncated_tensor,
                    }
                )
 
@@ -832,12 +838,21 @@ async def process_single_sample(sample_idx):
                    device=original_input_ids_single_row.device,
                )
 
+               # Check if response was truncated (hit max_tokens length limit)
+               is_truncated = generation_details.finish_reason == "length"
+               truncated_tensor = torch.tensor(
+                   [is_truncated],
+                   dtype=torch.bool,
+                   device=original_input_ids_single_row.device,
+               )
+
                result_batch = BatchedDataDict[GenerationOutputSpec](
                    {
                        "output_ids": output_ids_single_item_batched,
                        "logprobs": logprobs_single_item,
                        "generation_lengths": generation_lengths_tensor,
                        "unpadded_sequence_lengths": unpadded_sequence_lengths_tensor,
+                       "truncated": truncated_tensor,
                    }
                )
 
diff --git a/nemo_rl/utils/config.py b/nemo_rl/utils/config.py
index 156a1b9b1c..78eeeea831 100644
--- a/nemo_rl/utils/config.py
+++ b/nemo_rl/utils/config.py
@@ -79,7 +79,10 @@ def load_config_with_inheritance(
        base_config = OmegaConf.create({})
        for default in defaults:
            parent_path = resolve_path(base_dir, str(default))
-           parent_config = load_config_with_inheritance(parent_path, base_dir)
+           # Use parent's directory as base_dir for resolving its own defaults
+           parent_config = load_config_with_inheritance(
+               parent_path, parent_path.parent
+           )
            base_config = cast(
                DictConfig, merge_with_override(base_config, parent_config)
            )
diff --git a/pyrefly.toml b/pyrefly.toml
index 32e67b658a..01284c88da 100644
--- a/pyrefly.toml
+++ b/pyrefly.toml
@@ -37,6 +37,7 @@ project-includes = [
    "examples/custom_parallel/custom_parallel.py",
    "examples/custom_parallel/llama_nemotron_super_49b_custom_plan.py",
    "nemo_rl/algorithms/__init__.py",
+    "nemo_rl/algorithms/advantage_estimator.py",
    "nemo_rl/algorithms/interfaces.py",
    "nemo_rl/algorithms/reward_functions.py",
    "nemo_rl/algorithms/utils.py",
diff --git a/research/template_project/configs/grpo_math_1B.yaml b/research/template_project/configs/grpo_math_1B.yaml
index c28e86d157..ef968b717e 100644
--- a/research/template_project/configs/grpo_math_1B.yaml
+++ b/research/template_project/configs/grpo_math_1B.yaml
@@ -14,6 +14,13 @@ grpo:
  max_val_samples: 256
  val_batch_size: 256
  seed: 42
+  # Advantage Estimator Configuration
+  # Options: "grpo" (default) or "reinforce_plus_plus"
+  adv_estimator:
+    name: "grpo"  # Use "reinforce_plus_plus" for Reinforce++ estimator
+    normalize_rewards: true
+    use_leave_one_out_baseline: false
+    minus_baseline: true  # Reinforce++-baseline specific: subtract per-prompt mean baseline
  async_grpo:
    enabled: false  # Set to true to enable async training mode
    # Max age (in training steps) for trajectories used in training
diff --git a/tests/functional/L1_Functional_Tests_GPU.sh b/tests/functional/L1_Functional_Tests_GPU.sh
index 52a53e41da..f10ef02f9f 100644
--- a/tests/functional/L1_Functional_Tests_GPU.sh
+++ b/tests/functional/L1_Functional_Tests_GPU.sh
@@ -26,6 +26,7 @@ time bash ./tests/functional/test_frozen_env.sh
 time uv run --no-sync bash ./tests/functional/sft.sh
 time uv run --no-sync bash ./tests/functional/sft_resume_diamond.sh
 time uv run --no-sync bash ./tests/functional/grpo.sh
+time uv run --no-sync bash ./tests/functional/prorlv2.sh
 time uv run --no-sync bash ./tests/functional/grpo_async.sh
 time uv run --no-sync bash ./tests/functional/grpo_automodel_lora_async.sh
 time uv run --no-sync bash ./tests/functional/grpo_automodel_lora_non_colocated.sh
diff --git a/tests/functional/prorlv2.sh b/tests/functional/prorlv2.sh
new file mode 100755
index 0000000000..c39cff1da7
--- /dev/null
+++ b/tests/functional/prorlv2.sh
@@ -0,0 +1,45 @@
+#!/bin/bash
+
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
+PROJECT_ROOT=$(realpath $SCRIPT_DIR/../..)
+# Mark the current repo as safe, since wandb fetches metadata about the repo
+git config --global --add safe.directory $PROJECT_ROOT
+
+set -eou pipefail
+
+EXP_NAME=$(basename $0 .sh)
+EXP_DIR=$SCRIPT_DIR/$EXP_NAME
+LOG_DIR=$EXP_DIR/logs
+JSON_METRICS=$EXP_DIR/metrics.json
+RUN_LOG=$EXP_DIR/run.log
+export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}
+
+rm -rf $EXP_DIR $LOG_DIR
+mkdir -p $EXP_DIR $LOG_DIR
+
+cd $PROJECT_ROOT
+uv run coverage run -a --data-file=$PROJECT_ROOT/tests/.coverage --source=$PROJECT_ROOT/nemo_rl \
+    $PROJECT_ROOT/examples/run_grpo.py \
+    --config $PROJECT_ROOT/examples/configs/prorlv2.yaml \
+    policy.model_name=Qwen/Qwen3-0.6B \
+    policy.tokenizer.name=Qwen/Qwen3-0.6B \
+    grpo.num_prompts_per_step=2 \
+    grpo.num_generations_per_prompt=4 \
+    grpo.use_dynamic_sampling=false \
+    grpo.batch_multiplier=1 \
+    policy.train_global_batch_size=4 \
+    policy.train_micro_batch_size=1 \
+    cluster.gpus_per_node=2 \
+    grpo.max_num_steps=2 \
+    logger.tensorboard_enabled=true \
+    logger.log_dir=$LOG_DIR \
+    logger.wandb_enabled=false \
+    logger.monitor_gpus=true \
+    checkpointing.enabled=false \
+    $@ \
+    2>&1 | tee $RUN_LOG
+
+uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
+
+uv run tests/check_metrics.py $JSON_METRICS \
+    'max(data["train/gen_kl_error"]) < 0.01'
diff --git a/tests/test_suites/llm/prorlv2-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.sh b/tests/test_suites/llm/prorlv2-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.sh
new file mode 100755
index 0000000000..5c3c20434c
--- /dev/null
+++ b/tests/test_suites/llm/prorlv2-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
+source $SCRIPT_DIR/common.env
+
+# ===== BEGIN CONFIG =====
+NUM_NODES=1
+STEPS_PER_RUN=450
+MAX_STEPS=450
+NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))  # Round up
+NUM_MINUTES=120
+# ===== END CONFIG =====
+
+exit_if_max_steps_reached
+
+# Run the experiment
+cd $PROJECT_ROOT
+uv run examples/run_grpo_math.py \
+    --config $CONFIG_PATH \
+    grpo.max_num_steps=$MAX_STEPS \
+    logger.log_dir=$LOG_DIR \
+    logger.wandb_enabled=True \
+    logger.wandb.project=nemo-rl \
+    logger.wandb.name=$EXP_NAME \
+    logger.monitor_gpus=True \
+    logger.tensorboard_enabled=True \
+    checkpointing.enabled=True \
+    checkpointing.checkpoint_dir=$CKPT_DIR \
+    $@ \
+    2>&1 | tee $RUN_LOG
+
+# Convert tensorboard logs to json
+uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
+
+# Only run metrics if the target step is reached
+if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
+    uv run tests/check_metrics.py $JSON_METRICS \
+        'median(data["train/token_mult_prob_error"]) < 1.1' \
+        'data["train/token_mult_prob_error"]["450"] < 1.1'
+
+    # Clean up checkpoint directory after successful run to save space.
+    rm -rf "$CKPT_DIR"
+fi
diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt
index 403646c680..35b5cec0f3 100644
--- a/tests/test_suites/nightly.txt
+++ b/tests/test_suites/nightly.txt
@@ -7,6 +7,9 @@ tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.v3.sh
 tests/test_suites/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.v3.sh
 tests/test_suites/llm/grpo-gemma3-1b-it-1n8g-fsdp2tp1.sh
 
+# ProRLv2 convergence test
+tests/test_suites/llm/prorlv2-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.sh
+
 # SGLang backend
 tests/test_suites/llm/grpo-qwen3-0.6b-1n8g-sglang.sh
 tests/test_suites/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1-sglang.sh
diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py
index 4ae313ab95..7cadcc40f8 100644
--- a/tests/unit/algorithms/test_grpo.py
+++ b/tests/unit/algorithms/test_grpo.py
@@ -19,12 +19,15 @@
 import torch
 from torchdata.stateful_dataloader import StatefulDataLoader
 
+from nemo_rl.algorithms.advantage_estimator import (
+    GRPOAdvantageEstimator,
+    ReinforcePlusPlusAdvantageEstimator,
+)
 from nemo_rl.algorithms.grpo import (
    _default_grpo_save_state,
    async_grpo_train,
    dynamic_sampling,
    grpo_train,
-   normalize_advantages_with_epsilon,
    validate,
 )
 from nemo_rl.algorithms.loss_functions import ClippedPGLossFn
@@ -1293,6 +1296,11 @@
            "enabled": False,
            "max_trajectory_age_steps": 1,
        },
+        "adv_estimator": {
+            "name": "grpo",
+            "use_leave_one_out_baseline": False,
+            "normalize_rewards": True,
+        },
    },
    "policy": {
        "train_global_batch_size": 1,
@@ -1563,140 +1571,208 @@ def test_grpo_exit_on_timeout(mock_grpo_components, train_func, capsys):
 
 # ============================================================================
-# Tests for normalize_advantages_with_epsilon function
+# Tests for GRPOAdvantageEstimator class
 # ============================================================================
 
-def test_normalize_advantages_with_epsilon_basic():
-   """Test basic functionality of normalize_advantages_with_epsilon."""
-   # Test case with normal values
-   advantages = torch.tensor([[2.0], [4.0], [6.0]])
-   std = torch.tensor([1.0, 2.0, 3.0])
-   epsilon = 1e-6
+def test_grpo_advantage_estimator_zero_std():
+    """Test GRPOAdvantageEstimator when std contains zeros (all rewards same for a prompt).
+
+    This test verifies that:
+    1. When std=0 (all rewards identical for a prompt), normalization is skipped and advantage=0
+    2. When std>0, advantages are properly normalized by std
+    """
+    estimator_config = {
+        "use_leave_one_out_baseline": False,
+        "normalize_rewards": True,
+    }
+    loss_config = {}
+    estimator = GRPOAdvantageEstimator(estimator_config, loss_config)
+
+    # prompt 0: all same rewards -> std=0; prompt 1: different rewards -> std>0
+    prompt_ids = torch.tensor(
+        [[0], [0], [1], [1]]
+    )  # Shape (4, 1) for unique prompt matching
+    rewards = torch.tensor(
+        [2.0, 2.0, 1.0, 3.0]
+    )  # prompt 0: std=0; prompt 1: std=sqrt(2)
+    mask = torch.ones(4, 5)
 
-   result = normalize_advantages_with_epsilon(advantages, std, epsilon)
+    result = estimator.compute_advantage(prompt_ids, rewards, mask)
 
-   expected = torch.tensor([[2.0], [2.0], [2.0]])
-   assert torch.allclose(result, expected, rtol=1e-5)
+    # prompt 0: std=0 -> skip normalization, advantage=0 (reward - mean = 0)
+    # prompt 1: with Bessel's correction for 2 samples, std = sqrt(2), normalized = ±1/sqrt(2) ≈ ±0.7071
+    expected_prompt_0 = torch.zeros(2, 5)  # advantage=0 for all-identical rewards
+    sqrt2_inv = 1.0 / (2.0**0.5)
+    expected_prompt_1 = torch.tensor([-sqrt2_inv, sqrt2_inv]).unsqueeze(-1).expand(2, 5)
+    assert torch.allclose(result[:2], expected_prompt_0, rtol=1e-5)
+    assert torch.allclose(result[2:], expected_prompt_1, rtol=1e-4)
 
 
-def test_normalize_advantages_with_epsilon_zero_std():
-   """Test normalize_advantages_with_epsilon when std contains zeros."""
-   advantages = torch.tensor([[1.0], [2.0], [3.0]])
-   std = torch.tensor([0.0, 1.0, 0.0])  # Zero std for indices 0 and 2
-   epsilon =
1e-6 - result = normalize_advantages_with_epsilon(advantages, std, epsilon) +def test_grpo_advantage_estimator_tensor_shapes(): + """Test GRPOAdvantageEstimator with different tensor shapes. - # When std=0 AND advantage!=0, normalization is skipped (advantages unchanged) - # When std>0, normal normalization occurs - expected = torch.tensor( - [[1.0], [2.0], [3.0]] - ) # Samples 0,2 unchanged; sample 1 normalized - assert torch.allclose(result, expected, rtol=1e-5) + This test verifies that the estimator works correctly with: + 1. Small batch size (batch=2, single prompt) + 2. Larger batch size (batch=10, single prompt) + """ + estimator_config = { + "use_leave_one_out_baseline": False, + "normalize_rewards": True, + } + loss_config = {} + estimator = GRPOAdvantageEstimator(estimator_config, loss_config) + # Test with batch size 2 + prompt_ids = torch.tensor([[0], [0]]) + rewards = torch.tensor([1.0, 3.0]) # mean=2, std=sqrt(2) with Bessel + mask = torch.ones(2, 3) -def test_normalize_advantages_with_epsilon_all_zero_std(): - """Test normalize_advantages_with_epsilon when all std values are zero.""" - advantages = torch.tensor([[1.5], [2.5], [3.5]]) - std = torch.tensor([0.0, 0.0, 0.0]) - epsilon = 1e-8 + result = estimator.compute_advantage(prompt_ids, rewards, mask) + assert result.shape == (2, 3) - # Save expected values BEFORE calling function (since it modifies in-place) - expected = advantages.clone() + # Verify normalized values: (reward - mean) / std + # With Bessel correction for 2 samples: std = sqrt(2) + sqrt2_inv = 1.0 / (2.0**0.5) + expected = torch.tensor([[-sqrt2_inv], [sqrt2_inv]]).expand(2, 3) + assert torch.allclose(result, expected, rtol=1e-4) - result = normalize_advantages_with_epsilon(advantages, std, epsilon) + # Test with larger batch (10 samples, single prompt) + prompt_ids = torch.tensor([[0]] * 10) + rewards = torch.arange(10, dtype=torch.float32) # 0, 1, 2, ..., 9 + mask = torch.ones(10, 5) - # When std=0 AND advantage!=0, normalization 
is skipped (all unchanged) - assert torch.allclose(result, expected, rtol=1e-5) + result = estimator.compute_advantage(prompt_ids, rewards, mask) + assert result.shape == (10, 5) + # After normalization, mean should be ~0 + result_mean = result.mean() + assert torch.abs(result_mean) < 1e-5 -def test_normalize_advantages_with_epsilon_tensor_shapes(): - """Test normalize_advantages_with_epsilon with different tensor shapes.""" - # Test with batch size 1 - advantages = torch.tensor([[5.0]]) - std = torch.tensor([2.0]) - result = normalize_advantages_with_epsilon(advantages, std) - expected = torch.tensor([[2.5]]) - assert torch.allclose(result, expected, rtol=1e-5) - # Test with larger batch - batch_size = 10 - advantages = torch.ones(batch_size, 1) * 3.0 - std = torch.ones(batch_size) * 1.5 - result = normalize_advantages_with_epsilon(advantages, std) - expected = torch.ones(batch_size, 1) * 2.0 - assert torch.allclose(result, expected, rtol=1e-5) +def test_grpo_advantage_estimator_negative_advantages(): + """Test GRPOAdvantageEstimator with rewards that produce negative advantages. + This test verifies that negative advantages are handled correctly. 
+ """ + estimator_config = { + "use_leave_one_out_baseline": False, + "normalize_rewards": True, + } + loss_config = {} + estimator = GRPOAdvantageEstimator(estimator_config, loss_config) -def test_normalize_advantages_with_epsilon_negative_advantages(): - """Test normalize_advantages_with_epsilon with negative advantages.""" - advantages = torch.tensor([[-2.0], [3.0], [-1.5]]) - std = torch.tensor([1.0, 1.5, 0.5]) + # Rewards with values below and above mean + prompt_ids = torch.tensor([[0], [0], [0]]) + rewards = torch.tensor([0.0, 2.0, 4.0]) # mean=2, deviations: -2, 0, +2 + mask = torch.ones(3, 4) - result = normalize_advantages_with_epsilon(advantages, std) + result = estimator.compute_advantage(prompt_ids, rewards, mask) - expected = torch.tensor([[-2.0], [2.0], [-3.0]]) - assert torch.allclose(result, expected, rtol=1e-5) + # Verify ordering: first should be negative, middle ~0, last positive + assert result[0, 0] < 0 # below mean -> negative advantage + assert torch.abs(result[1, 0]) < 1e-5 # at mean -> ~0 advantage + assert result[2, 0] > 0 # above mean -> positive advantage + # Verify symmetry + assert torch.allclose(result[0], -result[2], rtol=1e-5) -def test_normalize_advantages_with_zero_std_from_leave_one_out(): - """Test that zero std (from leave-one-out baseline) is handled gracefully by skipping normalization.""" - # Simulate the leave-one-out case: rewards [1.0, 0.0, 0.0, 0.0] - # Sample 0 has baseline from [0, 0, 0] -> std=0, advantage=1.0 - # Samples 1-3 have baseline from [1, 0, 0] -> std≈0.577, advantage≈-0.333 - advantages = torch.tensor([[1.0], [-0.333], [-0.333], [-0.333]]) - std = torch.tensor([0.0, 0.577, 0.577, 0.577]) - epsilon = 1e-6 - # Compute expected values BEFORE calling function (since it modifies in-place) - expected_sample_0 = advantages[0].clone() - expected_normalized = advantages[1:].clone() / (std[1:].unsqueeze(-1) + epsilon) +def test_grpo_advantage_estimator_zero_std_and_zero_advantage(): + """Test GRPOAdvantageEstimator 
when all rewards are identical (std=0, advantage=0). - result = normalize_advantages_with_epsilon(advantages, std, epsilon) + This test verifies that when all rewards for a prompt are the same: + 1. The advantages are all zero (since reward - mean = 0) + 2. No division by zero occurs (normalization is skipped when std=0) + """ + estimator_config = { + "use_leave_one_out_baseline": False, + "normalize_rewards": True, + } + loss_config = {} + estimator = GRPOAdvantageEstimator(estimator_config, loss_config) - # Sample 0: std=0 -> advantage unchanged (skip normalization) - assert torch.allclose(result[0], expected_sample_0, rtol=1e-5) + # All rewards identical -> std=0, all advantages=0 + prompt_ids = torch.tensor([[0], [0], [0], [0]]) + rewards = torch.tensor([5.0, 5.0, 5.0, 5.0]) # all same + mask = torch.ones(4, 3) - # Samples 1-3: std>0 -> normalized with epsilon - assert torch.allclose(result[1:], expected_normalized, rtol=1e-5) + result = estimator.compute_advantage(prompt_ids, rewards, mask) + # All advantages should be exactly 0 + expected = torch.zeros(4, 3) + assert torch.allclose(result, expected, rtol=1e-5) -def test_normalize_advantages_with_zero_std_and_zero_advantage(): - """Test that zero std with zero advantage is left unchanged.""" - advantages = torch.tensor([[0.0], [1.0], [0.0]]) - std = torch.tensor([0.0, 0.0, 1.0]) - epsilon = 1e-6 - # Compute expected values BEFORE calling function (since it modifies in-place) - expected_sample_0 = advantages[0].clone() - expected_sample_1 = advantages[1].clone() - expected_sample_2 = advantages[2].clone() / (std[2] + epsilon) +def test_grpo_advantage_estimator_small_nonzero_std(): + """Test GRPOAdvantageEstimator with small but non-zero std values. - result = normalize_advantages_with_epsilon(advantages, std, epsilon) + This test verifies that small but non-zero std values are still normalized + (no arbitrary threshold that would skip normalization). 
+ """ + estimator_config = { + "use_leave_one_out_baseline": False, + "normalize_rewards": True, + } + loss_config = {} + estimator = GRPOAdvantageEstimator(estimator_config, loss_config) - # Sample 0: std=0, advantage=0 -> unchanged (skip normalization) - assert torch.allclose(result[0], expected_sample_0, rtol=1e-5) + # Small reward differences -> small std but non-zero + # Use larger difference to avoid floating point precision issues in std calculation + prompt_ids = torch.tensor([[0], [0]]) + rewards = torch.tensor([1.0, 1.01]) # small but detectable difference + mask = torch.ones(2, 3) - # Sample 1: std=0, advantage!=0 -> unchanged (skip normalization) - assert torch.allclose(result[1], expected_sample_1, rtol=1e-5) + result = estimator.compute_advantage(prompt_ids, rewards, mask) - # Sample 2: std>0 -> normalize with epsilon - assert torch.allclose(result[2], expected_sample_2, rtol=1e-5) + # Even with small std, normalization should still happen + # After normalization, the values should be ±1/sqrt(2) (for 2 samples with Bessel) + sqrt2_inv = 1.0 / (2.0**0.5) + assert torch.allclose(torch.abs(result[0, 0]), torch.tensor(sqrt2_inv), rtol=1e-3) + assert torch.allclose(torch.abs(result[1, 0]), torch.tensor(sqrt2_inv), rtol=1e-3) + # Verify opposite signs + assert result[0, 0] * result[1, 0] < 0 -def test_normalize_advantages_with_small_nonzero_std(): - """Test that small but non-zero std values still get normalized (no threshold).""" - advantages = torch.tensor([[2.0], [3.0], [-1.0]]) - std = torch.tensor([0.001, 0.01, 0.0001]) # All small but non-zero - # Compute expected values BEFORE calling function (since it modifies in-place) - expected = advantages.clone() / (std.unsqueeze(-1) + 1e-6) +# ============================================================================ +# Tests for ReinforcePlusPlusAdvantageEstimator class +# ============================================================================ - result = normalize_advantages_with_epsilon(advantages, 
std) - # All should be normalized since std > 0 - assert torch.allclose(result, expected, rtol=1e-5) +def test_reinforce_plus_plus_global_normalization(): + """Test that ReinforcePlusPlusAdvantageEstimator applies global normalization. + + This test verifies that: + 1. After global normalization, the mean of advantages is approximately 0 + 2. The advantages are properly scaled by the global std + """ + estimator_config = { + "minus_baseline": True, + } + loss_config = { + "use_kl_in_reward": False, + "reference_policy_kl_penalty": 0.0001, + "reference_policy_kl_type": "k2", + } + estimator = ReinforcePlusPlusAdvantageEstimator(estimator_config, loss_config) + + prompt_ids = torch.tensor( + [[0], [0], [0], [0]] + ) # Shape (4, 1) for unique prompt matching + rewards = torch.tensor([0.0, 1.0, 2.0, 3.0]) # mean=1.5 + mask = torch.ones(4, 5) + + result = estimator.compute_advantage(prompt_ids, rewards, mask) + + # After global normalization, mean should be ~0 + result_mean = (result * mask).sum() / mask.sum() + assert torch.abs(result_mean) < 1e-5 + + # Check the normalized advantages have correct relative ordering + # Lower rewards should have negative advantages, higher should have positive + assert result[0, 0] < result[1, 0] < result[2, 0] < result[3, 0] # ============================================================================ diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index 4a09ce5b2a..5be0e69c80 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -1004,6 +1004,7 @@ def test_clipped_pg_loss_on_policy_truncated_importance_sampling( cfg = deepcopy(basic_pg_loss_test_config) cfg["use_importance_sampling_correction"] = True cfg["truncated_importance_sampling_ratio"] = 0.8 + cfg["truncated_importance_sampling_type"] = "tis" if sequence_level_importance_ratios: cfg["sequence_level_importance_ratios"] = True cfg["token_level_loss"] = False diff 
--git a/tests/unit/algorithms/test_reward_functions.py b/tests/unit/algorithms/test_reward_functions.py index 5d71bb3d8d..2f72b300cb 100755 --- a/tests/unit/algorithms/test_reward_functions.py +++ b/tests/unit/algorithms/test_reward_functions.py @@ -272,3 +272,61 @@ def test_reward_shaping_mismatched_lengths(): match="The number of messages in the batch must match the number of rewards", ): apply_reward_shaping(batch, config) + + +def test_stop_properly_penalty(): + """Test stop_properly_penalty_coef scales rewards for truncated samples.""" + batch = create_mock_batch_with_responses( + num_samples=4, + response_lengths=[10, 20, 30, 40], + initial_rewards=[1.0, 0.8, 0.6, 0.4], + ) + batch["truncated"] = torch.tensor([False, True, False, True]) + + config = RewardShapingConfig(enabled=True, stop_properly_penalty_coef=0.5) + result_batch = apply_reward_shaping(batch, config) + + # Non-truncated unchanged, truncated scaled by 0.5 + expected_rewards = torch.tensor([1.0, 0.4, 0.6, 0.2]) + assert torch.allclose(result_batch["total_reward"], expected_rewards, atol=1e-6) + + +def test_stop_properly_penalty_boundary_coefs(): + """Test boundary values: coef=0 gives zero reward, coef=1 has no effect.""" + # Test coef=0: truncated samples get zero reward + batch = create_mock_batch_with_responses( + num_samples=2, response_lengths=[10, 20], initial_rewards=[1.0, 0.5] + ) + batch["truncated"] = torch.tensor([True, True]) + + config = RewardShapingConfig(enabled=True, stop_properly_penalty_coef=0.0) + result = apply_reward_shaping(batch, config) + assert torch.allclose(result["total_reward"], torch.tensor([0.0, 0.0]), atol=1e-6) + + # Test coef=1: no penalty applied + batch["total_reward"] = torch.tensor([1.0, 0.5]) + config["stop_properly_penalty_coef"] = 1.0 + result = apply_reward_shaping(batch, config) + assert torch.allclose(result["total_reward"], torch.tensor([1.0, 0.5]), atol=1e-6) + + +def test_stop_properly_penalty_error_cases(): + """Test error handling for invalid 
coef and missing truncated field.""" + batch = create_mock_batch_with_responses( + num_samples=2, response_lengths=[10, 20], initial_rewards=[1.0, 0.5] + ) + + # Missing truncated field + config = RewardShapingConfig(enabled=True, stop_properly_penalty_coef=0.5) + with pytest.raises(AssertionError, match="truncated field not found"): + apply_reward_shaping(batch, config) + + # Invalid coef values + batch["truncated"] = torch.tensor([False, True]) + config["stop_properly_penalty_coef"] = -0.1 + with pytest.raises(AssertionError, match="stop_properly_penalty_coef must be in"): + apply_reward_shaping(batch, config) + + config["stop_properly_penalty_coef"] = 1.5 + with pytest.raises(AssertionError, match="stop_properly_penalty_coef must be in"): + apply_reward_shaping(batch, config) diff --git a/tests/unit/test_recipes_and_test_suites.py b/tests/unit/test_recipes_and_test_suites.py index 9c7761e748..60e880fd12 100644 --- a/tests/unit/test_recipes_and_test_suites.py +++ b/tests/unit/test_recipes_and_test_suites.py @@ -44,6 +44,7 @@ "distillation": "examples/configs/distillation_math.yaml", "rm": "examples/configs/rm.yaml", "dapo": "examples/configs/grpo_math_1B.yaml", + "prorlv2": "examples/configs/prorlv2.yaml", } # Configuration keys that are allowed to be added to base configs during testing @@ -216,7 +217,7 @@ def test_all_recipe_yamls_accounted_for_in_test_suites( ) -def test_nightly_compute_stays_below_1240_hours(nightly_test_suite, tracker): +def test_nightly_compute_stays_below_1300_hours(nightly_test_suite, tracker): command = f"DRYRUN=1 HF_HOME=... HF_DATASETS_CACHE=... 
CONTAINER= ACCOUNT= PARTITION= ./tools/launch {' '.join(nightly_test_suite)}" print(f"Running command: {command}") @@ -248,8 +249,8 @@ def test_nightly_compute_stays_below_1240_hours(nightly_test_suite, tracker): f"Last line of output was not as expected: '{last_line}'" ) total_gpu_hours = float(last_line.split(":")[-1].strip()) - assert total_gpu_hours <= 1240, ( - f"Total GPU hours exceeded 1240: {last_line}. We should revisit the test suites to reduce the total GPU hours." + assert total_gpu_hours <= 1300, ( + f"Total GPU hours exceeded 1300: {last_line}. We should revisit the test suites to reduce the total GPU hours." ) tracker.track("total_nightly_gpu_hours", total_gpu_hours) diff --git a/tools/config_cli.py b/tools/config_cli.py index 49ac296eee..f9dbaf3c93 100755 --- a/tools/config_cli.py +++ b/tools/config_cli.py @@ -133,7 +133,10 @@ def load_config_with_inheritance( base_config = OmegaConf.create({}) for default in defaults: parent_path = resolve_path(base_dir, str(default)) - parent_config = load_config_with_inheritance(parent_path, base_dir) + # Use parent's directory as base_dir for resolving its own defaults + parent_config = load_config_with_inheritance( + parent_path, parent_path.parent + ) base_config = cast( DictConfig, merge_with_override(base_config, parent_config) )