Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 64 additions & 7 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,25 @@ def setup(
# ===============================================================================


def normalize_advantages_with_epsilon(
    advantages: torch.Tensor,
    std: torch.Tensor,
    epsilon: float = 1e-6,
) -> torch.Tensor:
    """Normalize advantages by standard deviation with epsilon to avoid division by zero.

    Args:
        advantages: Tensor of shape (batch_size, 1) containing advantage values
        std: Tensor of shape (batch_size,) containing standard deviation values
        epsilon: Small value to avoid division by zero, defaults to 1e-6

    Returns:
        Normalized advantages tensor of same shape as input advantages
    """
    # Adding epsilon to the denominator (rather than masking zero-std rows)
    # keeps the operation fully vectorized and branch-free.
    denominator = std.unsqueeze(-1) + epsilon
    return advantages / denominator


def dynamic_sampling(
repeated_batch: BatchedDataDict[DatumSpec],
std: torch.Tensor,
Expand Down Expand Up @@ -1056,10 +1075,9 @@ def grpo_train(
advantages = (rewards - baseline).unsqueeze(-1)

if master_config["grpo"]["normalize_rewards"]:
# don't sharpen the ones with no variation
zero_std_mask = std > 0
advantages[zero_std_mask] = (
advantages[zero_std_mask] / std.unsqueeze(-1)[zero_std_mask]
advantages = normalize_advantages_with_epsilon(
advantages=advantages,
std=std,
)

with timer.time("data_processing"):
Expand Down Expand Up @@ -1172,12 +1190,31 @@ def grpo_train(
val_metrics, total_steps + 1, prefix="validation"
)

# Get flat advantages and token mask for masked metrics computation
flat_advantages = flat_messages["advantages"]
flat_token_mask = flat_messages["token_loss_mask"]

# Filter advantages using token mask (only valid response tokens)
response_advantages = torch.masked_select(
flat_advantages, flat_token_mask.bool()
)

metrics = {
"loss": train_results["loss"].numpy(),
"grad_norm": train_results["grad_norm"].numpy(),
"reward": rewards.numpy(),
"mean_prompt_length": repeated_batch["length"].numpy(),
"total_num_tokens": input_lengths.numpy(),
# Add masked advantages tracking metrics (only for valid response tokens)
"advantages/mean": torch.mean(response_advantages).detach().item()
if response_advantages.numel() > 0
else 0.0,
"advantages/max": torch.max(response_advantages).detach().item()
if response_advantages.numel() > 0
else 0.0,
"advantages/min": torch.min(response_advantages).detach().item()
if response_advantages.numel() > 0
else 0.0,
**ds_metrics,
}
if master_config["grpo"]["use_dynamic_sampling"]:
Expand Down Expand Up @@ -1929,10 +1966,11 @@ def async_grpo_train(
)

if master_config["grpo"]["normalize_rewards"]:
zero_std_mask = std > 0
advantages[zero_std_mask] = (
advantages[zero_std_mask] / std.unsqueeze(-1)[zero_std_mask]
advantages = normalize_advantages_with_epsilon(
advantages=advantages,
std=std,
)

print(
f" 📊 Normalized advantages stats: min={advantages.min():.4f}, max={advantages.max():.4f}, mean={advantages.mean():.4f}, std={advantages.std():.4f}"
)
Expand Down Expand Up @@ -2060,12 +2098,31 @@ def async_grpo_train(

# Resume trajectory collection after validation
trajectory_collector.resume.remote()
# Get flat advantages and token mask for masked metrics computation
flat_advantages = flat_messages["advantages"]
flat_token_mask = flat_messages["token_loss_mask"]

# Filter advantages using token mask (only valid response tokens)
response_advantages = torch.masked_select(
flat_advantages, flat_token_mask.bool()
)

metrics = {
"loss": train_results["loss"].numpy(),
"reward": rewards.numpy(),
"grad_norm": train_results["grad_norm"].numpy(),
"mean_prompt_length": repeated_batch["length"].numpy(),
"total_num_tokens": input_lengths.numpy(),
# Add masked advantages tracking metrics (only for valid response tokens)
"advantages/mean": torch.mean(response_advantages).detach().item()
if response_advantages.numel() > 0
else 0.0,
"advantages/max": torch.max(response_advantages).detach().item()
if response_advantages.numel() > 0
else 0.0,
"advantages/min": torch.min(response_advantages).detach().item()
if response_advantages.numel() > 0
else 0.0,
}
metrics.update(train_results["all_mb_metrics"])
for k, v in metrics.items():
Expand Down
12 changes: 10 additions & 2 deletions nemo_rl/algorithms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,12 @@ def calculate_baseline_and_std_per_prompt(

baseline = torch.zeros_like(rewards)
sq_baseline = torch.zeros_like(rewards)
std = torch.zeros_like(rewards)
device_ordinal = rewards.get_device()
if device_ordinal == -1:
reward_device = torch.device("cpu")
else:
reward_device = torch.device(reward_device)
reward_device = torch.device(f"cuda:{device_ordinal}")

for i in range(len(unique_prompts)):
is_matching_prompt = (prompts == unique_prompts[i]).all(1)
Expand Down Expand Up @@ -142,8 +143,15 @@ def calculate_baseline_and_std_per_prompt(

baseline[prompt_idx] = prompt_baseline
sq_baseline[prompt_idx] = prompt_baseline_square
std[prompt_idx] = (
(
(prompt_baseline_square - prompt_baseline.square())
* (num_valid / (num_valid - 1))
)
.sqrt()
.nan_to_num(0)
)

std = (sq_baseline - baseline.square()).sqrt().nan_to_num(0)
return baseline, std


Expand Down
73 changes: 73 additions & 0 deletions tests/unit/algorithms/test_grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
async_grpo_train,
dynamic_sampling,
grpo_train,
normalize_advantages_with_epsilon,
)
from nemo_rl.algorithms.loss_functions import ClippedPGLossFn
from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType
Expand Down Expand Up @@ -1208,3 +1209,75 @@ def test_grpo_exit_on_timeout(mock_grpo_components, train_func, capsys):
assert not (line.startswith("Step ") and "Step 9" in line), (
f"Training continued to next step after timeout: {line}"
)


# ============================================================================
# Tests for normalize_advantages_with_epsilon function
# ============================================================================


def test_normalize_advantages_with_epsilon_basic():
    """Test basic functionality of normalize_advantages_with_epsilon."""
    # Each advantage is exactly twice its per-prompt std, so every
    # normalized value should come out to ~2.0.
    adv = torch.tensor([[2.0], [4.0], [6.0]])
    stds = torch.tensor([1.0, 2.0, 3.0])

    normalized = normalize_advantages_with_epsilon(adv, stds, 1e-6)

    assert torch.allclose(normalized, torch.full((3, 1), 2.0), rtol=1e-5)


def test_normalize_advantages_with_epsilon_zero_std():
    """Test normalize_advantages_with_epsilon when std contains zeros."""
    eps = 1e-6
    adv = torch.tensor([[1.0], [2.0], [3.0]])
    # Zero std for rows 0 and 2; row 1 has a normal std of 1.0.
    stds = torch.tensor([0.0, 1.0, 0.0])

    normalized = normalize_advantages_with_epsilon(adv, stds, eps)

    # Zero-std rows reduce to advantages / epsilon; the middle row is
    # effectively unchanged.
    expected = torch.tensor([[1.0 / eps], [2.0], [3.0 / eps]])
    assert torch.allclose(normalized, expected, rtol=1e-5)


def test_normalize_advantages_with_epsilon_all_zero_std():
    """Test normalize_advantages_with_epsilon when all std values are zero."""
    eps = 1e-8
    adv = torch.tensor([[1.5], [2.5], [3.5]])

    normalized = normalize_advantages_with_epsilon(adv, torch.zeros(3), eps)

    # With std == 0 everywhere, epsilon alone forms the denominator.
    assert torch.allclose(normalized, adv / eps, rtol=1e-5)


def test_normalize_advantages_with_epsilon_tensor_shapes():
    """Test normalize_advantages_with_epsilon with different tensor shapes."""
    # Singleton batch.
    single = normalize_advantages_with_epsilon(
        torch.tensor([[5.0]]), torch.tensor([2.0])
    )
    assert torch.allclose(single, torch.tensor([[2.5]]), rtol=1e-5)

    # Larger batch with identical rows: 3.0 / 1.5 == 2.0 per row.
    n = 10
    batched = normalize_advantages_with_epsilon(
        torch.full((n, 1), 3.0), torch.full((n,), 1.5)
    )
    assert torch.allclose(batched, torch.full((n, 1), 2.0), rtol=1e-5)


def test_normalize_advantages_with_epsilon_negative_advantages():
    """Test normalize_advantages_with_epsilon with negative advantages."""
    adv = torch.tensor([[-2.0], [3.0], [-1.5]])
    stds = torch.tensor([1.0, 1.5, 0.5])

    normalized = normalize_advantages_with_epsilon(adv, stds)

    # Division preserves the sign of every advantage.
    expected = torch.tensor([[-2.0], [2.0], [-3.0]])
    assert torch.allclose(normalized, expected, rtol=1e-5)
Loading
Loading