diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index dac6be0a7e..14ddf507d1 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -13,8 +13,8 @@ grpo: loss_fn: reference_policy_kl_penalty: 0.01 - ratio_eps_min: 0.2 - ratio_eps_max: 0.2 + ratio_clip_min: 0.2 + ratio_clip_max: 0.2 # (default off) loss formulation improvements (docs/guides/grpo.md#loss) use_on_policy_kl_approximation: false use_importance_sampling_correction: false diff --git a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.yaml b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.yaml index ba6ba255f3..7ea93d7425 100644 --- a/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-4n8g-fsdp2tp1-long.yaml @@ -11,8 +11,8 @@ grpo: val_batch_size: 256 loss_fn: reference_policy_kl_penalty: 0.01 - ratio_eps_min: 0.2 - ratio_eps_max: 0.2 + ratio_clip_min: 0.2 + ratio_clip_max: 0.2 use_on_policy_kl_approximation: false use_importance_sampling_correction: false checkpointing: diff --git a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.yaml b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.yaml index 96e8e023cb..6483c1a2da 100644 --- a/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.yaml +++ b/examples/configs/recipes/llm/grpo-llama3.2-1b-instruct-1n8g-fsdp2tp1.yaml @@ -11,8 +11,8 @@ grpo: val_batch_size: 256 loss_fn: reference_policy_kl_penalty: 0.01 - ratio_eps_min: 0.2 - ratio_eps_max: 0.2 + ratio_clip_min: 0.2 + ratio_clip_max: 0.2 use_on_policy_kl_approximation: false use_importance_sampling_correction: false checkpointing: diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.yaml index 3693ac4677..0ddb85dc2f 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt-long.yaml @@ -11,8 +11,8 @@ grpo: val_batch_size: 256 loss_fn: reference_policy_kl_penalty: 0.01 - ratio_eps_min: 0.2 - ratio_eps_max: 0.2 + ratio_clip_min: 0.2 + ratio_clip_max: 0.2 use_on_policy_kl_approximation: false use_importance_sampling_correction: false checkpointing: diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.yaml index aed12183a8..5e3240a5ce 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-32b-16n8g-fsdp2tp8sp-actckpt.yaml @@ -11,8 +11,8 @@ grpo: val_batch_size: 256 loss_fn: reference_policy_kl_penalty: 0.01 - ratio_eps_min: 0.2 - ratio_eps_max: 0.2 + ratio_clip_min: 0.2 + ratio_clip_max: 0.2 use_on_policy_kl_approximation: false use_importance_sampling_correction: false checkpointing: diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.yaml index 27211ddc7e..fa2c0a7c29 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp1.yaml @@ -11,8 +11,8 @@ grpo: val_batch_size: 256 loss_fn: reference_policy_kl_penalty: 0.01 - ratio_eps_min: 0.2 - ratio_eps_max: 0.2 + ratio_clip_min: 0.2 + ratio_clip_max: 0.2 use_on_policy_kl_approximation: false use_importance_sampling_correction: false checkpointing: diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.yaml index 87e2c592c0..8c8629cb02 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-7b-instruct-4n8g-fsdp2tp4sp.yaml @@ -11,8 +11,8 @@ grpo: val_batch_size: 256 loss_fn: reference_policy_kl_penalty: 0.01 - ratio_eps_min: 0.2 - ratio_eps_max: 0.2 + ratio_clip_min: 0.2 + ratio_clip_max: 0.2 use_on_policy_kl_approximation: false use_importance_sampling_correction: false checkpointing: diff --git a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.yaml b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.yaml index 9f5762f173..0dfb376979 100644 --- a/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.yaml +++ b/examples/configs/recipes/llm/grpo-qwen2.5-math-1.5b-instruct-1n8g-fsdp2tp1.yaml @@ -11,8 +11,8 @@ grpo: val_batch_size: 256 loss_fn: reference_policy_kl_penalty: 0.01 - ratio_eps_min: 0.2 - ratio_eps_max: 0.2 + ratio_clip_min: 0.2 + ratio_clip_max: 0.2 use_on_policy_kl_approximation: false use_importance_sampling_correction: false checkpointing: diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index 26441fe616..8d70325596 100644 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -29,8 +29,8 @@ class ClippedPGLossConfig(TypedDict): reference_policy_kl_penalty: float - ratio_eps_min: float - ratio_eps_max: float + ratio_clip_min: float + ratio_clip_max: float use_on_policy_kl_approximation: bool use_importance_sampling_correction: bool @@ -55,7 +55,7 @@ class ClippedPGLossFn(LossFunction): - PPO (Clipped) - https://arxiv.org/abs/1707.06347 - GRPO - https://arxiv.org/abs/2402.03300 - - REINFORCE/RLOO (set disable_ppo_ratio = True and ignores ratio_eps) - https://arxiv.org/abs/2402.14740 + - REINFORCE/RLOO (set disable_ppo_ratio = True and ignores ratio_clip_min/ratio_clip_max) - https://arxiv.org/abs/2402.14740 Formula: L(θ) = E_t [ min(r_t(θ) * A_t, clip(r_t(θ), 1-ε, 1+ε) * A_t) ] - β * KL(π_θ || π_ref) @@ -63,11 +63,11 @@ class ClippedPGLossFn(LossFunction): where: - r_t(θ) = π_θ(a_t|s_t) / π_θ_old(a_t|s_t) is the probability ratio - A_t is the advantage estimate - - ε is the clip parameter (ratio_eps) + - ε is the clip parameter (ratio_clip_min/ratio_clip_max) - As proposed in the DAPO paper (https://arxiv.org/pdf/2503.14476), we allow setting a distinct minimum and maximum value for the clip parameter (set to the same value for PPO/GRPO/etc.) - - ratio_eps_min: minimum value for the clip parameter - - ratio_eps_max: maximum value for the clip parameter + - ratio_clip_min: minimum value for the clip parameter + - ratio_clip_max: maximum value for the clip parameter - β is the KL penalty coefficient (reference_policy_kl_penalty) - KL(π_θ || π_ref) is the KL divergence between the current policy and reference policy (Schulman Approx.) @@ -78,8 +78,8 @@ class ClippedPGLossFn(LossFunction): """ def __init__(self, cfg: ClippedPGLossConfig): - self.ratio_eps_min = cfg["ratio_eps_min"] - self.ratio_eps_max = cfg["ratio_eps_max"] + self.ratio_clip_min = cfg["ratio_clip_min"] + self.ratio_clip_max = cfg["ratio_clip_max"] self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"] self.disable_ppo_ratio = cfg.get("disable_ppo_ratio", False) self.use_on_policy_kl_approximation = cfg["use_on_policy_kl_approximation"] @@ -154,7 +154,7 @@ def __call__( if not self.disable_ppo_ratio: ratios = (curr_logprobs - prev_logprobs).exp() ratios_clamped = ratios.clamp( - 1.0 - self.ratio_eps_min, 1.0 + self.ratio_eps_max + 1.0 - self.ratio_clip_min, 1.0 + self.ratio_clip_max ) else: ratios = curr_logprobs diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index d36d8c0b89..5c59da9335 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -373,10 +373,10 @@ def test_clipped_pg_loss_ppo_clipping(): device = "cuda" data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) - ratio_eps = 0.2 + ratio_clip = 0.2 cfg = { - "ratio_eps_min": ratio_eps, - "ratio_eps_max": ratio_eps, + "ratio_clip_min": ratio_clip, + "ratio_clip_max": ratio_clip, "reference_policy_kl_penalty": 0.0, # Disable KL "disable_ppo_ratio": False, "use_on_policy_kl_approximation": False, @@ -405,7 +405,7 @@ def test_clipped_pg_loss_ppo_clipping(): ) ratios_clamped = torch.clamp( - ratios, 1.0 - ratio_eps, 1.0 + ratio_eps + ratios, 1.0 - ratio_clip, 1.0 + ratio_clip ) # [0.8, 1.0, 1.2] assert torch.allclose( ratios_clamped, torch.tensor([[0.8, 1.0, 1.2]], device=device), rtol=1e-3 @@ -454,8 +454,8 @@ def test_clipped_pg_loss_reinforce_mode(): cfg = { "disable_ppo_ratio": True, "reference_policy_kl_penalty": 0.0, - "ratio_eps_min": 0.0, # Placeholder, ignored - "ratio_eps_max": 0.0, # Placeholder, ignored + "ratio_clip_min": 0.0, # Placeholder, ignored + "ratio_clip_max": 0.0, # Placeholder, ignored "use_on_policy_kl_approximation": False, "use_importance_sampling_correction": False, } @@ -501,8 +501,8 @@ def test_clipped_pg_loss_kl_penalty(): kl_beta = 0.1 cfg = { "reference_policy_kl_penalty": kl_beta, - "ratio_eps_min": 0.2, - "ratio_eps_max": 0.2, + "ratio_clip_min": 0.2, + "ratio_clip_max": 0.2, "disable_ppo_ratio": False, "use_on_policy_kl_approximation": False, "use_importance_sampling_correction": False, @@ -572,8 +572,8 @@ def test_clipped_pg_loss_masking(): data["advantages"] = torch.randn_like(data["advantages"]) + 1.0 cfg = { - "ratio_eps_min": 0.2, - "ratio_eps_max": 0.2, + "ratio_clip_min": 0.2, + "ratio_clip_max": 0.2, "reference_policy_kl_penalty": 0.1, "disable_ppo_ratio": False, "use_on_policy_kl_approximation": False, @@ -635,8 +635,8 @@ def test_clipped_pg_loss_zero_mask(): dummy_logits = torch.randn(1, seq_len, vocab_size, device=device) cfg = { - "ratio_eps_min": 0.2, - "ratio_eps_max": 0.2, + "ratio_clip_min": 0.2, + "ratio_clip_max": 0.2, "reference_policy_kl_penalty": 0.1, "disable_ppo_ratio": False, "use_on_policy_kl_approximation": False, @@ -661,12 +661,12 @@ def test_clipped_pg_loss_on_policy_kl_importance_sampling(): device = "cuda" data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) - ratio_eps = 0.2 + ratio_clip = 0.2 kl_beta = 0.1 cfg = { - "ratio_eps_min": ratio_eps, - "ratio_eps_max": ratio_eps, + "ratio_clip_min": ratio_clip, + "ratio_clip_max": ratio_clip, "reference_policy_kl_penalty": kl_beta, "disable_ppo_ratio": False, "use_on_policy_kl_approximation": True, @@ -708,7 +708,7 @@ def test_clipped_pg_loss_on_policy_kl_importance_sampling(): ) ratios_clamped = torch.clamp( - ratios, 1.0 - ratio_eps, 1.0 + ratio_eps + ratios, 1.0 - ratio_clip, 1.0 + ratio_clip ) # [0.8, 1.0, 1.2] assert torch.allclose( ratios_clamped, torch.tensor([[0.8, 1.0, 1.2]], device=device), rtol=1e-3