diff --git a/docs/adding-new-models.md b/docs/adding-new-models.md index 34aaaaf3b0..673cc602bf 100644 --- a/docs/adding-new-models.md +++ b/docs/adding-new-models.md @@ -8,7 +8,7 @@ In on-policy RL, we sample tokens (actions) from the latest version of the polic As an example, we would see errors in naive KL estimation: -$$\text{KL} = E_{x \sim \pi}[\pi(x) - \pi_{\text{ref}}(x)]$$ +$$\text{KL} = E_{x \sim \pi}[\pi(x) - \pi_{\text{ref}}(x)]$$ When summed/integrated, replacing the $x \sim \pi$ with $x \sim \pi_{\text{wrong}}$ leads to an error of: @@ -17,12 +17,12 @@ $$\sum_{x} \left( \pi(x) - \pi_{\text{ref}}(x) \right) \left( \pi_{\text{wrong}} So, to verify correctness, we calculate $$ -\frac{1}{n}\sum_{i=1}^{n\text{(tokens)}}\exp\left(\left\|\text{logprobs-train-fwk}_i - \text{logprobs-sampling-fwk}_i\right\|\right) +\frac{1}{n}\sum_{i=1}^{n\text{(tokens)}}\exp\left(\left\|\text{logprobs-train-fwk}_i - \text{logprobs-inference-fwk}_i\right\|\right) $$ -where samples are drawn as $x \sim \pi_{\text{sampling-framework}}$ +where samples are drawn as $x \sim \pi_{\text{inference-framework}}$ -As a measure of multiplicative probability error for sampled tokens. Note that this is not exhaustive (the sampling framework could lack distribution support and we wouldn't catch it here, as $x \sim \pi_{\text{sampling-framework}}$). To get a much stricter guarantee on correctness, you should run this metric twice and average the results, where in the second run, you sample $x \sim \pi_{\text{training-framework}}$. In practice, we use just the former in our tests and find it sufficient. +This serves as a measure of multiplicative probability error for sampled tokens. Note that this is not exhaustive (the inference framework could lack distribution support and we wouldn't catch it here, as $x \sim \pi_{\text{inference-framework}}$).
To get a much stricter guarantee on correctness, you should run this metric twice and average the results, where in the second run, you sample $x \sim \pi_{\text{training-framework}}$. In practice, we use just the former in our tests and find it sufficient. ## Understanding Discrepancies Between Backends diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index b84cbf9f0c..58fa6a7c9e 100644 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -16,12 +16,13 @@ If not specified, `config` will default to [examples/configs/grpo.yaml](../../ex ## Now, for the details: -In this guide, we'll walk through we handle +In this guide, we'll walk through how we handle * Data * Model training * Fast generation * Overall Resource Flow +* Loss ### Data @@ -108,3 +109,60 @@ This Policy object holds a [RayWorkerGroup](../../nemo_reinforcer/distributed/wo We support vLLM through the [VllmGeneration](../../nemo_reinforcer/models/generation/vllm.py) class right now. The function [grpo_train](../../nemo_reinforcer/algorithms/grpo.py) contains the core GRPO training loop. + +### Loss +We use the [ClippedPGLossFn](../../nemo_reinforcer/algorithms/loss_functions.py) to calculate the loss for GRPO. 
Formally, + +$$ +L(\theta) = E_{x \sim \pi_{\theta_{\text{old}}}} \Big[ \min \Big(\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}A_t, \text{clip} \big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}, 1 - \varepsilon, 1 + \varepsilon \big) A_t \Big) \Big] - \beta D_{\text{KL}} (\pi_\theta \| \pi_\text{ref}) +$$ + +where: + +- $\pi_\theta$ is the policy model we are currently optimizing +- $\pi_{\theta_{\text{old}}}$ is the previous policy model (from the beginning of this step) +- $A_t$ is the advantage estimate +- $\varepsilon$ is a clipping hyperparameter +- $\beta$ is the KL penalty coefficient +- $\pi_{\text{ref}}$ is the reference policy + +#### Improvements to the GRPO loss formulation for stability and accuracy + +#### On-Policy KL Approximation + +In practice, we calculate the KL divergence using the estimator from Schulman 2020 (http://joschu.net/blog/kl-approx.html), which is unbiased and guaranteed to be positive. + +$$ +D_{\text{KL}} (\pi_\theta || \pi_\text{ref}) \approx E_{x \sim \pi_{\theta}} \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] +$$ + +Note that the loss function above samples from $\pi_{\theta_{\text{old}}}$ instead of $\pi_\theta$, meaning that the KL approximation is off-policy if we use samples from $\pi_{\theta_{\text{old}}}$. This is the default formulation used in the [original GRPO paper](https://arxiv.org/abs/2402.03300). 
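In code, the estimator is a one-liner over per-token logprobs. A minimal PyTorch sketch of the computation (a standalone illustrative function; the repository's own version is `calculate_kl_penalty_joschu2020` in `loss_functions.py`):

```python
import torch

def kl_penalty_joschu2020(logprobs_policy: torch.Tensor,
                         logprobs_reference: torch.Tensor) -> torch.Tensor:
    """k3 estimator of KL(pi_theta || pi_ref) from per-token logprobs.

    With r = log(pi_ref(x) / pi_theta(x)), the estimate is exp(r) - r - 1,
    which is unbiased under x ~ pi_theta and non-negative for every r.
    """
    r = logprobs_reference - logprobs_policy
    return torch.exp(r) - r - 1

# Three sampled tokens with illustrative logprobs
logprobs_policy = torch.tensor([-1.0, -1.0, -1.0])    # log pi_theta(x)
logprobs_reference = torch.tensor([-2.0, -1.0, 0.0])  # log pi_ref(x)
print(kl_penalty_joschu2020(logprobs_policy, logprobs_reference))
# tensor([0.3679, 0.0000, 0.7183]) -- each per-token estimate is >= 0
```

The estimate is exactly zero only where the two logprobs agree, which makes it a convenient per-token diagnostic as well as a penalty term.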
In order to use an _on-policy_ KL approximation while sampling from $\pi_{\theta_{\text{old}}}$, we can incorporate importance weights: + +$$ +\begin{align*} +D_{\text{KL}} (\pi_\theta || \pi_\text{ref}) &\approx E_{x \sim \pi_{\theta}} \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] \\ +&= \sum_x \pi_{\theta}(x) \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] \\ +&= \sum_x \pi_{\theta_{\text{old}}}(x) \frac{\pi_{\theta}(x)}{\pi_{\theta_{\text{old}}}(x)} \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] \\ +&= E_{x \sim \pi_{\theta_\text{old}}} \frac{\pi_{\theta}(x)}{\pi_{\theta_{\text{old}}}(x)} \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] \\ +\end{align*} +$$ + +To enable the on-policy KL approximation, set the config `use_on_policy_kl_approximation=True` in the `ClippedPGLossConfig`. By default, we set this config to False to align with standard GRPO. + + +#### Importance Sampling Correction +The policy we use to draw samples, $\pi_{\theta_{\text{old}}}$, is instantiated in both the inference framework and the training framework. To distinguish the two, we refer to the inference framework policy as $\pi_{\text{inference}}$ and the training framework policy as $\pi_{\text{training}}$. As noted in [Adding New Models](../adding-new-models.md#understanding-discrepancies-between-backends), the token probabilities from $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ can disagree (due to numerics, precision differences, bugs, etc.), leading to off-policy samples. We can correct for this by introducing importance weights between $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ into the first term of the loss function.
+ +Let $f_\theta(x) = \min \Big(\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}A_t, \text{clip} \big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}, 1 - \varepsilon, 1 + \varepsilon \big) A_t \Big)$ denote the first term of the loss function. Then, + +$$ +\begin{align*} +E_{x \sim \pi_\text{training}} f_\theta(x) &= \sum_x \pi_\text{training}(x) f_\theta(x) \\ +&= \sum_x \pi_\text{inference}(x) \frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)} f_\theta(x) \\ +&= E_{x \sim \pi_\text{inference}} \frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)} f_\theta(x) +\end{align*} +$$ + +By multiplying the first term of the loss function by the importance weights $\frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)}$, we can correct for the distribution mismatch between $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ while still sampling from $\pi_{\text{inference}}$. + +To enable the importance sampling correction, set the config `use_importance_sampling_correction=True` in the `ClippedPGLossConfig`. By default, we set this config to False to align with standard GRPO.
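The correction composes naturally with the clipped objective. A minimal standalone sketch (hypothetical function and tensor names, unmasked mean for brevity; this is not the actual `ClippedPGLossFn` code):

```python
import torch

def clipped_pg_actor_loss(curr_lp: torch.Tensor, prev_lp: torch.Tensor,
                          gen_lp: torch.Tensor, advantages: torch.Tensor,
                          eps: float = 0.2,
                          importance_correction: bool = True) -> torch.Tensor:
    """Clipped PG actor loss with optional training/inference correction.

    curr_lp: log pi_theta(x)      (training framework, current weights)
    prev_lp: log pi_theta_old(x)  (training framework, rollout-time weights)
    gen_lp:  log pi_inference(x)  (inference framework that sampled x)
    """
    ratios = torch.exp(curr_lp - prev_lp)
    clipped = torch.clamp(ratios, 1.0 - eps, 1.0 + eps)
    # Negated because we minimize the loss but maximize the objective
    per_token = torch.maximum(-advantages * ratios, -advantages * clipped)
    if importance_correction:
        # pi_training(x) / pi_inference(x), guarded against inf/nan blowups
        weights = torch.nan_to_num(torch.exp(prev_lp - gen_lp),
                                   nan=0.0, posinf=0.0, neginf=0.0)
        per_token = weights * per_token
    return per_token.mean()
```

When the two frameworks agree exactly (`prev_lp == gen_lp`), every weight is 1 and the correction is a no-op, so enabling it only changes the loss on tokens where the backends actually disagree.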
diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml index 4cf474df01..72424f779c 100644 --- a/examples/configs/grpo_math_1B.yaml +++ b/examples/configs/grpo_math_1B.yaml @@ -14,6 +14,9 @@ loss_fn: reference_policy_kl_penalty: 0.01 ratio_eps_min: 0.2 ratio_eps_max: 0.2 + # (default off) loss formulation improvements (docs/guides/grpo.md#loss) + use_on_policy_kl_approximation: false + use_importance_sampling_correction: false checkpointing: enabled: true diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py index d674e7deb0..dd9ac45acf 100644 --- a/nemo_reinforcer/algorithms/loss_functions.py +++ b/nemo_reinforcer/algorithms/loss_functions.py @@ -31,6 +31,8 @@ class ClippedPGLossConfig(TypedDict): reference_policy_kl_penalty: float ratio_eps_min: float ratio_eps_max: float + use_on_policy_kl_approximation: bool + use_importance_sampling_correction: bool class ClippedPGLossDataDict(TypedDict): @@ -80,6 +82,10 @@ def __init__(self, cfg: ClippedPGLossConfig): self.ratio_eps_max = cfg["ratio_eps_max"] self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"] self.disable_ppo_ratio = cfg.get("disable_ppo_ratio", False) + self.use_on_policy_kl_approximation = cfg["use_on_policy_kl_approximation"] + self.use_importance_sampling_correction = cfg[ + "use_importance_sampling_correction" + ] def __call__( self, @@ -119,9 +125,23 @@ def __call__( # Calculate KL regularization. 
if self.reference_policy_kl_penalty != 0: - kl = self.reference_policy_kl_penalty * calculate_kl_penalty_joschu2020( - logprobs_policy=curr_logprobs, - logprobs_reference=reference_policy_logprobs, + if self.use_on_policy_kl_approximation: + # See: docs/guides/grpo.md#on-policy-kl-approximation + kl_importance_weights = torch.exp( + curr_logprobs - generation_logprobs + ).detach() + kl_importance_weights = torch.nan_to_num( + kl_importance_weights, nan=0.0, posinf=0.0, neginf=0.0 + ) + else: + kl_importance_weights = torch.ones_like(curr_logprobs) + kl = ( + kl_importance_weights + * self.reference_policy_kl_penalty + * calculate_kl_penalty_joschu2020( + logprobs_policy=curr_logprobs, + logprobs_reference=reference_policy_logprobs, + ) ) kl = masked_mean(kl, mask) else: @@ -140,7 +160,17 @@ def __call__( loss1 = -advantages * ratios loss2 = -advantages * ratios_clamped - actor_loss = masked_mean(torch.max(loss1, loss2), mask) + if self.use_importance_sampling_correction: + # See: docs/guides/grpo.md#importance-sampling-correction + actor_importance_weights = torch.exp(prev_logprobs - generation_logprobs) + actor_importance_weights = torch.nan_to_num( + actor_importance_weights, nan=0.0, posinf=0.0, neginf=0.0 + ) + else: + actor_importance_weights = torch.ones_like(prev_logprobs) + actor_loss = masked_mean( + actor_importance_weights * torch.max(loss1, loss2), mask + ) loss = actor_loss + kl with torch.no_grad(): probs_ratio = masked_mean(ratios.detach(), mask).item() diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index 447bd20a54..b6e9fab673 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -165,6 +165,8 @@ def test_clipped_pg_loss_ppo_clipping(): "ratio_eps_max": ratio_eps, "reference_policy_kl_penalty": 0.0, # Disable KL "disable_ppo_ratio": False, + "use_on_policy_kl_approximation": False, + "use_importance_sampling_correction": False, } 
loss_fn = ClippedPGLossFn(cfg) @@ -184,15 +186,38 @@ def test_clipped_pg_loss_ppo_clipping(): # --- Hand Calculation --- ratios = torch.exp(curr_lp_masked - prev_lp_masked) # approx [0.5, 1.0, 1.5] + assert torch.allclose( + ratios, torch.tensor([[0.5, 1.0, 1.5]], device=device), rtol=1e-3 + ) + ratios_clamped = torch.clamp( ratios, 1.0 - ratio_eps, 1.0 + ratio_eps ) # [0.8, 1.0, 1.2] + assert torch.allclose( + ratios_clamped, torch.tensor([[0.8, 1.0, 1.2]], device=device), rtol=1e-3 + ) + loss1 = -adv_masked * ratios # approx -[1*0.5, -1*1.0, 2*1.5] = [-0.5, 1.0, -3.0] + assert torch.allclose( + loss1, torch.tensor([[-0.5, 1.0, -3.0]], device=device), rtol=1e-3 + ) + loss2 = -adv_masked * ratios_clamped # -[1*0.8, -1*1.0, 2*1.2] = [-0.8, 1.0, -2.4] + assert torch.allclose( + loss2, torch.tensor([[-0.8, 1.0, -2.4]], device=device), rtol=1e-3 + ) + max_loss = torch.maximum(loss1, loss2) # approx [-0.5, 1.0, -2.4] + assert torch.allclose( + max_loss, torch.tensor([[-0.5, 1.0, -2.4]], device=device), rtol=1e-3 + ) + expected_loss = torch.mean( max_loss ) # approx (-0.5 + 1.0 - 2.4) / 3 = -1.9 / 3 = -0.6333 + assert torch.allclose( + expected_loss, torch.tensor(-0.6333, device=device), rtol=1e-3 + ) input_ids = data["input_ids"] dummy_logits = _create_exact_logits( @@ -217,6 +242,8 @@ def test_clipped_pg_loss_reinforce_mode(): "reference_policy_kl_penalty": 0.0, "ratio_eps_min": 0.0, # Placeholder, ignored "ratio_eps_max": 0.0, # Placeholder, ignored + "use_on_policy_kl_approximation": False, + "use_importance_sampling_correction": False, } loss_fn = ClippedPGLossFn(cfg) @@ -229,7 +256,14 @@ def test_clipped_pg_loss_reinforce_mode(): # --- Hand Calculation --- expected_loss_per_token = -adv_masked * curr_lp_masked # [0.5, -1.0, 3.0] + assert torch.allclose( + expected_loss_per_token, + torch.tensor([[0.5, -1.0, 3.0]], device=device), + rtol=1e-3, + ) + expected_loss = torch.mean(expected_loss_per_token) # 2.5 / 3 = 0.8333 + assert torch.allclose(expected_loss, 
torch.tensor(0.8333, device=device), rtol=1e-3) input_ids = data["input_ids"] dummy_logits = _create_exact_logits( @@ -256,6 +290,8 @@ def test_clipped_pg_loss_kl_penalty(): "ratio_eps_min": 0.2, "ratio_eps_max": 0.2, "disable_ppo_ratio": False, + "use_on_policy_kl_approximation": False, + "use_importance_sampling_correction": False, } loss_fn = ClippedPGLossFn(cfg) @@ -273,9 +309,20 @@ def test_clipped_pg_loss_kl_penalty(): # Actor loss is 0. Total loss = kl_beta * mean(kl_term) # kl_term = exp(ref - curr) - (ref - curr) - 1 r = ref_lp_masked - curr_lp_masked # [-1.0, 0.0, 1.0] + assert torch.allclose(r, torch.tensor([[-1.0, 0.0, 1.0]], device=device), rtol=1e-3) + kl_term_per_token = torch.exp(r) - r - 1 # [0.368, 0.0, 0.718] + assert torch.allclose( + kl_term_per_token, torch.tensor([[0.368, 0.0, 0.718]], device=device), rtol=1e-3 + ) + expected_kl_mean = torch.mean(kl_term_per_token) # 0.362 + assert torch.allclose( + expected_kl_mean, torch.tensor(0.362, device=device), rtol=1e-3 + ) + expected_loss = kl_beta * expected_kl_mean # 0.0362 + assert torch.allclose(expected_loss, torch.tensor(0.0362, device=device), rtol=1e-3) input_ids = data["input_ids"] dummy_logits = _create_exact_logits( @@ -315,6 +362,8 @@ def test_clipped_pg_loss_masking(): "ratio_eps_max": 0.2, "reference_policy_kl_penalty": 0.1, "disable_ppo_ratio": False, + "use_on_policy_kl_approximation": False, + "use_importance_sampling_correction": False, } loss_fn = ClippedPGLossFn(cfg) # Use original loss fn @@ -376,6 +425,8 @@ def test_clipped_pg_loss_zero_mask(): "ratio_eps_max": 0.2, "reference_policy_kl_penalty": 0.1, "disable_ppo_ratio": False, + "use_on_policy_kl_approximation": False, + "use_importance_sampling_correction": False, } loss_fn = ClippedPGLossFn(cfg) # Use original loss fn @@ -388,6 +439,150 @@ def test_clipped_pg_loss_zero_mask(): torch.testing.assert_close(loss, torch.tensor(0.0, device=device)) +def test_clipped_pg_loss_on_policy_kl_importance_sampling(): + """Tests PPO loss 
with KL penalty and importance sampling enabled.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + device = "cuda" + data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + + ratio_eps = 0.2 + kl_beta = 0.1 + + cfg = { + "ratio_eps_min": ratio_eps, + "ratio_eps_max": ratio_eps, + "reference_policy_kl_penalty": kl_beta, + "disable_ppo_ratio": False, + "use_on_policy_kl_approximation": True, + "use_importance_sampling_correction": True, + } + loss_fn = ClippedPGLossFn(cfg) + + adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device) + prev_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + curr_lp_masked = torch.tensor( + [[-1.69315, -1.0, -0.59453]], device=device + ) # approx log(0.5)-1, log(1)-1, log(1.5)-1 + + ref_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + + # For Importance Sampling + gen_lp_masked = torch.tensor([[-0.5, -1.5, -0.8]], device=device) + + # Fill full tensors + data["advantages"][0, 1:] = adv_masked + data["prev_logprobs"][0, 1:] = prev_lp_masked + data["generation_logprobs"][0, 1:] = gen_lp_masked + data["reference_policy_logprobs"][0, 1:] = ref_lp_masked + + # --- Hand Calculation --- + # Actor Loss Calculation + actor_importance_weights = torch.exp( + prev_lp_masked - gen_lp_masked + ) # exp([-1 - (-0.5), -1 - (-1.5), -1 - (-0.8)]) = [0.6065, 1.6487, 0.8187] + assert torch.allclose( + actor_importance_weights, + torch.tensor([[0.6065, 1.6487, 0.8187]], device=device), + rtol=1e-3, + ) + + ratios = torch.exp(curr_lp_masked - prev_lp_masked) # [0.5, 1.0, 1.5] + assert torch.allclose( + ratios, torch.tensor([[0.5, 1.0, 1.5]], device=device), rtol=1e-3 + ) + + ratios_clamped = torch.clamp( + ratios, 1.0 - ratio_eps, 1.0 + ratio_eps + ) # [0.8, 1.0, 1.2] + assert torch.allclose( + ratios_clamped, torch.tensor([[0.8, 1.0, 1.2]], device=device), rtol=1e-3 + ) + + loss1 = -adv_masked * ratios # [-0.5, 1.0, -3.0] + assert torch.allclose( + loss1, torch.tensor([[-0.5, 1.0, 
-3.0]], device=device), rtol=1e-3 + ) + + loss2 = -adv_masked * ratios_clamped # [-0.8, 1.0, -2.4] + assert torch.allclose( + loss2, torch.tensor([[-0.8, 1.0, -2.4]], device=device), rtol=1e-3 + ) + + max_loss = torch.maximum(loss1, loss2) # [-0.5, 1.0, -2.4] + assert torch.allclose( + max_loss, torch.tensor([[-0.5, 1.0, -2.4]], device=device), rtol=1e-3 + ) + + importance_weighted_max_loss = ( + actor_importance_weights * max_loss + ) # [0.6065*(-0.5), 1.6487*1.0, 0.8187*(-2.4)] = [-0.30325, 1.6487, -1.96488] + assert torch.allclose( + importance_weighted_max_loss, + torch.tensor([[-0.30325, 1.6487, -1.96488]], device=device), + rtol=1e-3, + ) + + expected_actor_loss = torch.mean(importance_weighted_max_loss) # -0.2065 + assert torch.allclose( + expected_actor_loss, torch.tensor(-0.2065, device=device), rtol=1e-3 + ) + + # KL Loss Calculation + kl_importance_weights = torch.exp( + curr_lp_masked - gen_lp_masked + ) # exp([-1.69315 - (-0.5), -1 - (-1.5), -0.59453 - (-0.8)]) = [0.3033, 1.6487, 1.2281] + assert torch.allclose( + kl_importance_weights, + torch.tensor([[0.3033, 1.6487, 1.2281]], device=device), + rtol=1e-3, + ) + + r = ( + ref_lp_masked - curr_lp_masked + ) # [-1.0 - (-1.69315), -1.0 - (-1.0), -1.0 - (-0.59453)] = [0.69315, 0.0, -0.40547] + assert torch.allclose( + r, torch.tensor([[0.69315, 0.0, -0.40547]], device=device), rtol=1e-3 + ) + + kl_term_per_token = ( + torch.exp(r) - r - 1 + ) # [exp(0.69315)-0.69315-1, exp(0)-0-1, exp(-0.40547)-(-0.40547)-1] = [0.3069, 0.0, 0.0721] + assert torch.allclose( + kl_term_per_token, + torch.tensor([[0.3069, 0.0, 0.0721]], device=device), + rtol=1e-3, + ) + # Apply importance weights to KL loss + # kl_term = importance_weights * kl_beta * kl_indiv + importance_weighted_kl_term_per_token = ( + kl_importance_weights * kl_term_per_token + ) # [0.3033*0.3069, 1.6487*0.0, 1.2281*0.0721] = [0.09308, 0.0, 0.08855] + assert torch.allclose( + importance_weighted_kl_term_per_token, + torch.tensor([[0.09308, 0.0, 
0.08855]], device=device), + rtol=1e-3, + ) + + expected_kl_mean = torch.mean( + importance_weighted_kl_term_per_token + ) # mean([0.09308, 0.0, 0.08855]) = 0.060543 + expected_kl_loss = kl_beta * expected_kl_mean # 0.1 * 0.060543 = 0.0060543 + + expected_total_loss = ( + expected_actor_loss + expected_kl_loss + ) # -0.2065 + 0.0060543 = -0.2004457 + + input_ids = data["input_ids"] + dummy_logits = _create_exact_logits( + curr_lp_masked, input_ids, seq_len, vocab_size, device + ) + + actual_loss, _ = loss_fn(dummy_logits, data) + torch.testing.assert_close(actual_loss, expected_total_loss, atol=1e-4, rtol=1e-3) + + def test_masked_mean_all_zeros(): """Test masked_mean function with all zeros mask.""" values = torch.tensor([1.0, 2.0, 3.0, 4.0])