Merged

30 commits
5883608
Importance sampling
yfw Apr 14, 2025
34d4d8b
Docs
yfw Apr 14, 2025
927f968
No math* in latex
yfw Apr 14, 2025
d750676
More doc fix
yfw Apr 14, 2025
726356e
Rename config to use_importance_sampling_correction
yfw Apr 14, 2025
ad69440
Add use_online_kl_approximation and assertions in test
yfw Apr 14, 2025
eb800b2
Docs
yfw Apr 14, 2025
a4489b5
Remove tag
yfw Apr 14, 2025
04732fd
Capitalization
yfw Apr 14, 2025
2d47a43
Typo
yfw Apr 14, 2025
d28883c
Merge remote-tracking branch 'origin' into yifu/importance_sampling
yfw Apr 15, 2025
cdfe7d6
on_policy
yfw Apr 15, 2025
060e662
Add use_on_policy_kl_approximation to config
yfw Apr 15, 2025
35052a7
Handle nan importance weights
yfw Apr 15, 2025
5144526
Detach kl importance weights
yfw Apr 15, 2025
495f259
Merge remote-tracking branch 'origin' into yifu/importance_sampling
yfw Apr 17, 2025
b80a3b1
ruff
yfw Apr 17, 2025
1e99327
Didn't commit by accident
yfw Apr 17, 2025
0932ee9
Fix docs
yfw Apr 17, 2025
4656f2e
Missed one
yfw Apr 17, 2025
a15878c
Merge remote-tracking branch 'origin' into yifu/importance_sampling
yfw Apr 17, 2025
b3a784b
Merge branch 'main' into yifu/importance_sampling
yfw Apr 18, 2025
3146bd1
Update docs/guides/grpo.md
yfw Apr 22, 2025
24011aa
Update docs/guides/grpo.md
yfw Apr 22, 2025
9bad18d
Update docs/guides/grpo.md
yfw Apr 22, 2025
d2f682e
Update docs/guides/grpo.md
yfw Apr 22, 2025
049ba18
Update docs/guides/grpo.md
yfw Apr 22, 2025
f468be7
Update examples/configs/grpo_math_1B.yaml
yfw Apr 22, 2025
12d33bc
Update nemo_reinforcer/algorithms/loss_functions.py
yfw Apr 22, 2025
fc0f7bc
Update nemo_reinforcer/algorithms/loss_functions.py
yfw Apr 22, 2025
8 changes: 4 additions & 4 deletions docs/adding-new-models.md
@@ -8,7 +8,7 @@ In on-policy RL, we sample tokens (actions) from the latest version of the polic

As an example, we would see errors in naive KL estimation:

$$\text{KL} = E_{x \sim \pi}[\pi(x) - \pi_{\text{ref}}(x)]$$

When summed/integrated, replacing the $x \sim \pi$ with $x \sim \pi_{\text{wrong}}$ leads to an error of:

@@ -17,12 +17,12 @@ $$\sum_{x} \left( \pi(x) - \pi_{\text{ref}}(x) \right) \left( \pi_{\text{wrong}}
So, to verify correctness, we calculate

$$
- \frac{1}{n}\sum_{i=1}^{n\text{(tokens)}}\exp\left(\left\|\text{logprobs-train-fwk}_i - \text{logprobs-sampling-fwk}_i\right\|\right)
+ \frac{1}{n}\sum_{i=1}^{n\text{(tokens)}}\exp\left(\left\|\text{logprobs-train-fwk}_i - \text{logprobs-inference-fwk}_i\right\|\right)
$$

- where samples are drawn as $x \sim \pi_{\text{sampling-framework}}$
+ where samples are drawn as $x \sim \pi_{\text{inference-framework}}$

- As a measure of multiplicative probability error for sampled tokens. Note that this is not exhaustive (the sampling framework could lack distribution support and we wouldn't catch it here, as $x \sim \pi_{\text{sampling-framework}}$). To get a much stricter guarantee on correctness, you should run this metric twice and average the results, where in the second run, you sample $x \sim \pi_{\text{training-framework}}$. In practice, we use just the former in our tests and find it sufficient.
+ As a measure of multiplicative probability error for sampled tokens. Note that this is not exhaustive (the inference framework could lack distribution support and we wouldn't catch it here, as $x \sim \pi_{\text{inference-framework}}$). To get a much stricter guarantee on correctness, you should run this metric twice and average the results, where in the second run, you sample $x \sim \pi_{\text{training-framework}}$. In practice, we use just the former in our tests and find it sufficient.
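As a rough illustration (a sketch, not code from this PR), the metric above is a few lines of torch; `train_logprobs` and `inference_logprobs` are assumed to be per-token logprobs for the *same* sampled tokens, one from each framework:

```python
import torch

def multiplicative_prob_error(train_logprobs: torch.Tensor,
                              inference_logprobs: torch.Tensor) -> float:
    """Mean of exp(|logprob difference|) over sampled tokens.

    1.0 means the two frameworks agree exactly; e.g. ~1.05 means the
    assigned token probabilities differ by roughly 5% on average.
    """
    return torch.exp(torch.abs(train_logprobs - inference_logprobs)).mean().item()

# Perfect agreement gives exactly 1.0
lp = torch.log(torch.tensor([0.5, 0.25, 0.125]))
print(multiplicative_prob_error(lp, lp))  # → 1.0
```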

## Understanding Discrepancies Between Backends

60 changes: 59 additions & 1 deletion docs/guides/grpo.md
@@ -16,12 +16,13 @@ If not specified, `config` will default to [examples/configs/grpo.yaml](../../ex

## Now, for the details:

- In this guide, we'll walk through we handle
+ In this guide, we'll walk through how we handle

* Data
* Model training
* Fast generation
* Overall Resource Flow
* Loss

### Data

@@ -108,3 +109,60 @@ This Policy object holds a [RayWorkerGroup](../../nemo_reinforcer/distributed/wo
We support vLLM through the [VllmGeneration](../../nemo_reinforcer/models/generation/vllm.py) class right now.

The function [grpo_train](../../nemo_reinforcer/algorithms/grpo.py) contains the core GRPO training loop.

### Loss
We use the [ClippedPGLossFn](../../nemo_reinforcer/algorithms/loss_functions.py) to calculate the loss for GRPO. Formally,

$$
L(\theta) = E_{x \sim \pi_{\theta_{\text{old}}}} \Big[ \min \Big(\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}A_t, \text{clip} \big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}, 1 - \varepsilon, 1 + \varepsilon \big) A_t \Big) \Big] - \beta D_{\text{KL}} (\pi_\theta \| \pi_\text{ref})
$$

where:

- $\pi_\theta$ is the policy model we are currently optimizing
- $\pi_{\theta_{\text{old}}}$ is the previous policy model (from the beginning of this step)
- $A_t$ is the advantage estimate
- $\varepsilon$ is a clipping hyperparameter
- $\beta$ is the KL penalty coefficient
- $\pi_{\text{ref}}$ is the reference policy
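The clipped surrogate term maps to code almost directly. Below is a condensed sketch (token masking and the KL penalty are omitted; the names mirror, but are not copied from, `loss_functions.py`):

```python
import torch

def clipped_pg_surrogate(curr_logprobs, old_logprobs, advantages, eps=0.2):
    # ratio pi_theta(x) / pi_theta_old(x), formed in log space for stability
    ratios = torch.exp(curr_logprobs - old_logprobs)
    ratios_clamped = torch.clamp(ratios, 1.0 - eps, 1.0 + eps)
    # Negate because we minimize: torch.max over the negated terms is the
    # pessimistic (min) bound of the original maximization objective.
    loss1 = -advantages * ratios
    loss2 = -advantages * ratios_clamped
    return torch.max(loss1, loss2).mean()
```

When the ratio drifts outside $[1-\varepsilon, 1+\varepsilon]$, the clamped branch caps how much a single sample can move the policy.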

#### Improvements to the GRPO loss formulation for stability and accuracy

#### On-Policy KL Approximation

In practice, we calculate the KL divergence using the estimator from Schulman (2020) (http://joschu.net/blog/kl-approx.html), which is unbiased and guaranteed to be non-negative.

$$
D_{\text{KL}} (\pi_\theta || \pi_\text{ref}) \approx E_{x \sim \pi_{\theta}} \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big]
$$

Note that the loss function above samples from $\pi_{\theta_{\text{old}}}$ instead of $\pi_\theta$, meaning that the KL approximation is off-policy if we use samples from $\pi_{\theta_{\text{old}}}$. This is the default formulation used in the [original GRPO paper](https://arxiv.org/abs/2402.03300). In order to use an _on-policy_ KL approximation while sampling from $\pi_{\theta_{\text{old}}}$, we can incorporate importance weights:

$$
\begin{align*}
D_{\text{KL}} (\pi_\theta || \pi_\text{ref}) &\approx E_{x \sim \pi_{\theta}} \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] \\
&= \sum_x \pi_{\theta}(x) \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] \\
&= \sum_x \pi_{\theta_{\text{old}}}(x) \frac{\pi_{\theta}(x)}{\pi_{\theta_{\text{old}}}(x)} \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] \\
&= E_{x \sim \pi_{\theta_\text{old}}} \frac{\pi_{\theta}(x)}{\pi_{\theta_{\text{old}}}(x)} \Big[ \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - \log \frac{\pi_\text{ref}(x)}{\pi_\theta(x)} - 1 \Big] \\
\end{align*}
$$

To enable the on-policy KL approximation, set the config `use_on_policy_kl_approximation=True` in the `ClippedPGLossConfig`. By default, we set this config to False to align with standard GRPO.
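Condensed from the `ClippedPGLossFn` changes later in this diff (the penalty coefficient and masking are omitted here), the estimator with the optional importance weights looks like this, where `generation_logprobs` are the $\pi_{\theta_{\text{old}}}$ logprobs the samples were drawn with:

```python
import torch

def on_policy_kl(curr_logprobs, ref_logprobs, generation_logprobs):
    # k3 estimator: r - log(r) - 1 with r = pi_ref(x) / pi_theta(x);
    # always >= 0, and exactly 0 when the two policies agree on x.
    log_r = ref_logprobs - curr_logprobs
    kl = torch.exp(log_r) - log_r - 1.0
    # Importance weights pi_theta(x) / pi_theta_old(x), detached so no
    # gradient flows through the correction factor.
    w = torch.exp(curr_logprobs - generation_logprobs).detach()
    w = torch.nan_to_num(w, nan=0.0, posinf=0.0, neginf=0.0)
    return w * kl
```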


#### Importance Sampling Correction
The policy we use to draw samples, $\pi_{\theta_{\text{old}}}$, is evaluated in both the inference framework and the training framework. To account for this distinction, we refer to the inference framework policy as $\pi_{\text{inference}}$ and the training framework policy as $\pi_{\text{training}}$. As noted in [Adding New Models](../adding-new-models.md#understanding-discrepancies-between-backends), the token probabilities from $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ can disagree (due to numerics, precision differences, bugs, etc.), leading to off-policy samples. We can correct for this by introducing importance weights between $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ in the first term of the loss function.

Let $f_\theta(x) = \min \Big(\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}A_t, \text{clip} \big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}, 1 - \varepsilon, 1 + \varepsilon \big) A_t \Big)$ denote the first term of the loss function. Then,

$$
\begin{align*}
E_{x \sim \pi_\text{training}} f_\theta(x) &= \sum_x \pi_\text{training}(x) f_\theta(x) \\
&= \sum_x \pi_\text{inference}(x) \frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)} f_\theta(x) \\
&= E_{x \sim \pi_\text{inference}} \frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)} f_\theta(x)
\end{align*}
$$

By multiplying the first term of the loss function by the importance weights $\frac{\pi_\text{training}(x)}{\pi_\text{inference}(x)}$, we can correct for the distribution mismatch between $\pi_{\text{training}}$ and $\pi_{\text{inference}}$ while still sampling from $\pi_{\text{inference}}$.

To enable the importance sampling correction, set the config `use_importance_sampling_correction=True` in the `ClippedPGLossConfig`. By default, we set this config to False to align with standard GRPO.
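A sketch of how the correction changes the actor-loss reduction (the `masked_mean` helper is inlined here; `prev_logprobs` are $\pi_{\text{training}}$'s logprobs for the sampled tokens and `inference_logprobs` are $\pi_{\text{inference}}$'s):

```python
import torch

def is_corrected_mean(per_token_loss, prev_logprobs, inference_logprobs, mask):
    # w = pi_training(x) / pi_inference(x); exactly 1 wherever the two
    # frameworks assign identical token probabilities.
    w = torch.exp(prev_logprobs - inference_logprobs)
    # Guard against inf/nan from extreme logprob gaps.
    w = torch.nan_to_num(w, nan=0.0, posinf=0.0, neginf=0.0)
    return (w * per_token_loss * mask).sum() / mask.sum()

# With perfectly matching logprobs the correction is a no-op:
loss = torch.tensor([1.0, 2.0, 3.0])
lp = torch.tensor([-1.0, -1.0, -1.0])
mask = torch.tensor([1.0, 1.0, 0.0])
print(is_corrected_mean(loss, lp, lp, mask))  # → tensor(1.5000)
```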
3 changes: 3 additions & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -14,6 +14,9 @@ loss_fn:
reference_policy_kl_penalty: 0.01
ratio_eps_min: 0.2
ratio_eps_max: 0.2
# (default off) loss formulation improvements (docs/guides/grpo.md#loss)
use_on_policy_kl_approximation: false
use_importance_sampling_correction: false

checkpointing:
enabled: true
38 changes: 34 additions & 4 deletions nemo_reinforcer/algorithms/loss_functions.py
@@ -31,6 +31,8 @@ class ClippedPGLossConfig(TypedDict):
reference_policy_kl_penalty: float
ratio_eps_min: float
ratio_eps_max: float
use_on_policy_kl_approximation: bool
use_importance_sampling_correction: bool


class ClippedPGLossDataDict(TypedDict):
@@ -80,6 +82,10 @@ def __init__(self, cfg: ClippedPGLossConfig):
self.ratio_eps_max = cfg["ratio_eps_max"]
self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"]
self.disable_ppo_ratio = cfg.get("disable_ppo_ratio", False)
self.use_on_policy_kl_approximation = cfg["use_on_policy_kl_approximation"]
self.use_importance_sampling_correction = cfg[
"use_importance_sampling_correction"
]

def __call__(
self,
@@ -119,9 +125,23 @@ def __call__(

# Calculate KL regularization.
if self.reference_policy_kl_penalty != 0:
kl = self.reference_policy_kl_penalty * calculate_kl_penalty_joschu2020(
logprobs_policy=curr_logprobs,
logprobs_reference=reference_policy_logprobs,
if self.use_on_policy_kl_approximation:
# See: docs/guides/grpo.md#on-policy-kl-approximation
kl_importance_weights = torch.exp(
curr_logprobs - generation_logprobs
).detach()
kl_importance_weights = torch.nan_to_num(
kl_importance_weights, nan=0.0, posinf=0.0, neginf=0.0
)
else:
kl_importance_weights = torch.ones_like(curr_logprobs)
kl = (
kl_importance_weights
* self.reference_policy_kl_penalty
* calculate_kl_penalty_joschu2020(
logprobs_policy=curr_logprobs,
logprobs_reference=reference_policy_logprobs,
)
)
kl = masked_mean(kl, mask)
else:
@@ -140,7 +160,17 @@ def __call__(
loss1 = -advantages * ratios
loss2 = -advantages * ratios_clamped

actor_loss = masked_mean(torch.max(loss1, loss2), mask)
if self.use_importance_sampling_correction:
# See: docs/guides/grpo.md#importance-sampling-correction
actor_importance_weights = torch.exp(prev_logprobs - generation_logprobs)
actor_importance_weights = torch.nan_to_num(
actor_importance_weights, nan=0.0, posinf=0.0, neginf=0.0
)
else:
actor_importance_weights = torch.ones_like(prev_logprobs)
actor_loss = masked_mean(
actor_importance_weights * torch.max(loss1, loss2), mask
)
loss = actor_loss + kl
with torch.no_grad():
probs_ratio = masked_mean(ratios.detach(), mask).item()