From f3cbbc495782b19c929625c8b65498d29f1a707b Mon Sep 17 00:00:00 2001 From: Felipe Vieira Frujeri Date: Mon, 27 Oct 2025 16:11:14 -0700 Subject: [PATCH 1/5] Refactor normalize_advantages function and add tests. Signed-off-by: Felipe Vieira Frujeri --- nemo_rl/algorithms/grpo.py | 41 ++- nemo_rl/algorithms/utils.py | 9 + tests/unit/algorithms/test_grpo.py | 437 ++++++++++++++++++++++++++++ tests/unit/algorithms/test_utils.py | 200 +++++++++++++ 4 files changed, 680 insertions(+), 7 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index fdfa5e3ed2..818db4bb2c 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -538,6 +538,25 @@ def setup( # =============================================================================== +def normalize_advantages_with_epsilon( + advantages: torch.Tensor, + std: torch.Tensor, + epsilon: float = 1e-6, +) -> torch.Tensor: + """Normalize advantages by standard deviation with epsilon to avoid division by zero. + + Args: + advantages: Tensor of shape (batch_size, 1) containing advantage values + std: Tensor of shape (batch_size,) containing standard deviation values + epsilon: Small value to avoid division by zero, defaults to 1e-6 + + Returns: + Normalized advantages tensor of same shape as input advantages + """ + # Use epsilon to avoid division by zero instead of masking + return advantages / (std.unsqueeze(-1) + epsilon) + + def dynamic_sampling( repeated_batch: BatchedDataDict[DatumSpec], std: torch.Tensor, @@ -1056,10 +1075,9 @@ def grpo_train( advantages = (rewards - baseline).unsqueeze(-1) if master_config["grpo"]["normalize_rewards"]: - # don't sharpen the ones with no variation - zero_std_mask = std > 0 - advantages[zero_std_mask] = ( - advantages[zero_std_mask] / std.unsqueeze(-1)[zero_std_mask] + advantages = normalize_advantages_with_epsilon( + advantages=advantages, + std=std, ) with timer.time("data_processing"): @@ -1178,6 +1196,10 @@ def grpo_train( "reward": rewards.numpy(), "mean_prompt_length": repeated_batch["length"].numpy(), "total_num_tokens": input_lengths.numpy(), + # Add advantages tracking metrics + "advantages/mean": torch.mean(advantages).detach().item(), + "advantages/max": torch.max(advantages).detach().item(), + "advantages/min": torch.min(advantages).detach().item(), **ds_metrics, } if master_config["grpo"]["use_dynamic_sampling"]: @@ -1929,10 +1951,11 @@ def async_grpo_train( ) if master_config["grpo"]["normalize_rewards"]: - zero_std_mask = std > 0 - advantages[zero_std_mask] = ( - advantages[zero_std_mask] / std.unsqueeze(-1)[zero_std_mask] + advantages = normalize_advantages_with_epsilon( + advantages=advantages, + std=std, ) + print( f" 📊 Normalized advantages stats: min={advantages.min():.4f}, max={advantages.max():.4f}, mean={advantages.mean():.4f}, std={advantages.std():.4f}" ) @@ -2066,6 +2089,10 @@ def async_grpo_train( "grad_norm": train_results["grad_norm"].numpy(), "mean_prompt_length": repeated_batch["length"].numpy(), "total_num_tokens": input_lengths.numpy(), + # Add advantages tracking metrics + "advantages/mean": torch.mean(advantages).detach().item(), + "advantages/max": torch.max(advantages).detach().item(), + "advantages/min": torch.min(advantages).detach().item(), } metrics.update(train_results["all_mb_metrics"]) for k, v in metrics.items(): diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index e323bec734..306d7ba307 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -99,6 +99,7 @@ def calculate_baseline_and_std_per_prompt( baseline = torch.zeros_like(rewards) sq_baseline = torch.zeros_like(rewards) + std = torch.zeros_like(rewards) device_ordinal = rewards.get_device() if device_ordinal == -1: reward_device = torch.device("cpu") @@ -142,6 +143,14 @@ def calculate_baseline_and_std_per_prompt( baseline[prompt_idx] = prompt_baseline sq_baseline[prompt_idx] = prompt_baseline_square + std[prompt_idx] = ( + ( + (prompt_baseline_square - prompt_baseline.square()) + * (num_valid / (num_valid - 1)) + ) + .sqrt() + .nan_to_num(0) + ) std = (sq_baseline - baseline.square()).sqrt().nan_to_num(0) return baseline, std diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py index 6d42ad553d..6fab6d9276 100644 --- a/tests/unit/algorithms/test_grpo.py +++ b/tests/unit/algorithms/test_grpo.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math from unittest.mock import MagicMock, patch import pytest @@ -24,6 +25,7 @@ async_grpo_train, dynamic_sampling, grpo_train, + normalize_advantages_with_epsilon, ) from nemo_rl.algorithms.loss_functions import ClippedPGLossFn from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType @@ -1208,3 +1210,438 @@ def test_grpo_exit_on_timeout(mock_grpo_components, train_func, capsys): assert not (line.startswith("Step ") and "Step 9" in line), ( f"Training continued to next step after timeout: {line}" ) + + +# ============================================================================ +# Tests for normalize_advantages_with_epsilon function +# ============================================================================ + + +def test_normalize_advantages_with_epsilon_basic(): + """Test basic functionality of normalize_advantages_with_epsilon.""" + # Test case with normal values + advantages = torch.tensor([[2.0], [4.0], [6.0]]) + std = torch.tensor([1.0, 2.0, 3.0]) + epsilon = 1e-6 + + result = normalize_advantages_with_epsilon(advantages, std, epsilon) + + expected = torch.tensor([[2.0], [2.0], [2.0]]) + assert torch.allclose(result, expected, rtol=1e-5) + + +def test_normalize_advantages_with_epsilon_zero_std(): + """Test normalize_advantages_with_epsilon when std contains zeros.""" + advantages = torch.tensor([[1.0], [2.0], [3.0]]) + std = torch.tensor([0.0, 1.0, 0.0]) # Zero std for indices 0 and 2 + epsilon = 1e-6 + + result = normalize_advantages_with_epsilon(advantages, std, epsilon) + + # When std=0, result should be advantages / epsilon + expected = torch.tensor([[1.0 / epsilon], [2.0], [3.0 / epsilon]]) + assert torch.allclose(result, expected, rtol=1e-5) + + +def test_normalize_advantages_with_epsilon_all_zero_std(): + """Test normalize_advantages_with_epsilon when all std values are zero.""" + advantages = torch.tensor([[1.5], [2.5], [3.5]]) + std = torch.tensor([0.0, 0.0, 0.0]) + epsilon = 1e-8 + + result = normalize_advantages_with_epsilon(advantages, std, epsilon) + + expected = advantages / epsilon + assert torch.allclose(result, expected, rtol=1e-5) + + +def test_normalize_advantages_with_epsilon_different_epsilons(): + """Test normalize_advantages_with_epsilon with different epsilon values.""" + advantages = torch.tensor([[1.0], [2.0]]) + std = torch.tensor([0.0, 0.5]) + + # Test with different epsilon values + epsilons = [1e-6, 1e-8, 1e-4] + + for eps in epsilons: + result = normalize_advantages_with_epsilon(advantages, std, eps) + expected = torch.tensor([[1.0 / eps], [2.0 / 0.5]]) + assert torch.allclose(result, expected, rtol=1e-5) + + +def test_normalize_advantages_with_epsilon_tensor_shapes(): + """Test normalize_advantages_with_epsilon with different tensor shapes.""" + # Test with batch size 1 + advantages = torch.tensor([[5.0]]) + std = torch.tensor([2.0]) + result = normalize_advantages_with_epsilon(advantages, std) + expected = torch.tensor([[2.5]]) + assert torch.allclose(result, expected, rtol=1e-5) + + # Test with larger batch + batch_size = 10 + advantages = torch.ones(batch_size, 1) * 3.0 + std = torch.ones(batch_size) * 1.5 + result = normalize_advantages_with_epsilon(advantages, std) + expected = torch.ones(batch_size, 1) * 2.0 + assert torch.allclose(result, expected, rtol=1e-5) + + +def test_normalize_advantages_with_epsilon_negative_advantages(): + """Test normalize_advantages_with_epsilon with negative advantages.""" + advantages = torch.tensor([[-2.0], [3.0], [-1.5]]) + std = torch.tensor([1.0, 1.5, 0.5]) + + result = normalize_advantages_with_epsilon(advantages, std) + + expected = torch.tensor([[-2.0], [2.0], [-3.0]]) + assert torch.allclose(result, expected, rtol=1e-5) + + +def test_normalize_advantages_with_epsilon_very_small_std(): + """Test normalize_advantages_with_epsilon with very small std values.""" + advantages = torch.tensor([[1.0], [2.0]]) + std = torch.tensor([1e-10, 2e-9]) # Very small but non-zero std + epsilon = 1e-6 + + result = normalize_advantages_with_epsilon(advantages, std, epsilon) + + # Since std values are much smaller than epsilon, result should be close to advantages/epsilon + expected = torch.tensor([[1.0 / epsilon], [2.0 / epsilon]]) + assert torch.allclose(result, expected, rtol=1e-3) + + +def test_normalize_advantages_with_epsilon_gradient_flow(): + """Test that normalize_advantages_with_epsilon preserves gradients.""" + advantages = torch.tensor([[2.0], [4.0]], requires_grad=True) + std = torch.tensor([1.0, 2.0]) + + result = normalize_advantages_with_epsilon(advantages, std) + loss = result.sum() + loss.backward() + + # Check that gradients are computed correctly + assert advantages.grad is not None + expected_grad = torch.tensor([[1.0], [0.5]]) + assert torch.allclose(advantages.grad, expected_grad, rtol=1e-5) + + +# ============================================================================ +# Tests for advantages metrics tracking in GRPO +# ============================================================================ + + +def test_grpo_advantages_metrics_tracking(mock_grpo_components): + """Test that advantages metrics (mean, max, min) are correctly tracked during GRPO training.""" + from unittest.mock import patch + + # Set up config to enable reward normalization (where advantages metrics are logged) + mock_grpo_components["master_config"]["grpo"]["normalize_rewards"] = True + mock_grpo_components["master_config"]["grpo"]["max_num_steps"] = ( + 2 # Just run 2 steps + ) + + grpo_save_state = _default_grpo_save_state() + + # Mock batch with specific rewards to get predictable advantages + mock_batch = BatchedDataDict[DatumSpec]( + { + "message_log": [ + [ + { + "role": "user", + "content": "test", + "token_ids": torch.tensor([1, 2, 3]), + } + ] + ] + * 4, + "task_name": ["math"] * 4, + "extra_env_info": [{}] * 4, + "loss_multiplier": torch.tensor([1.0, 1.0, 1.0, 1.0]), + "idx": torch.tensor([0, 1, 2, 3]), + "length": torch.tensor([3, 3, 3, 3]), + "total_reward": torch.tensor( + [1.0, 3.0, 2.0, 4.0] + ), # Known rewards for predictable advantages + } + ) + + mock_rollout_metrics = { + "mean_gen_tokens_per_sample": 10.0, + "max_gen_tokens": 20, + "min_gen_tokens": 5, + } + + # Capture the logged metrics + logged_metrics = [] + original_log_metrics = mock_grpo_components["logger"].log_metrics + + def capture_log_metrics(metrics, step): + logged_metrics.append((dict(metrics), step)) + return original_log_metrics(metrics, step) + + mock_grpo_components["logger"].log_metrics.side_effect = capture_log_metrics + + with patch("nemo_rl.algorithms.grpo.run_multi_turn_rollout") as mock_rollout: + mock_rollout.return_value = (mock_batch, mock_rollout_metrics) + + # Run GRPO training + grpo_train( + mock_grpo_components["policy"], + None, # policy_generation + mock_grpo_components["train_dataloader"], + mock_grpo_components["val_dataloader"], + mock_grpo_components["tokenizer"], + mock_grpo_components["loss_fn"], + mock_grpo_components["task_to_env"], + mock_grpo_components["val_task_to_env"], + mock_grpo_components["logger"], + mock_grpo_components["checkpointer"], + grpo_save_state, + mock_grpo_components["master_config"], + ) + + # Verify that advantages metrics were logged + assert len(logged_metrics) >= 2, "Should have logged metrics for at least 2 steps" + + # Check each logged step for advantages metrics + for metrics_dict, step in logged_metrics: + if "advantages/mean" in metrics_dict: + # Verify that all three advantages metrics are present + assert "advantages/mean" in metrics_dict + assert "advantages/max" in metrics_dict + assert "advantages/min" in metrics_dict + + # Verify that the values are reasonable (should be numeric) + assert isinstance(metrics_dict["advantages/mean"], (int, float)) + assert isinstance(metrics_dict["advantages/max"], (int, float)) + assert isinstance(metrics_dict["advantages/min"], (int, float)) + + # Verify logical relationships between min, mean, max + assert metrics_dict["advantages/min"] <= metrics_dict["advantages/mean"] + assert metrics_dict["advantages/mean"] <= metrics_dict["advantages/max"] + + +# ============================================================================ +# Integration tests for GRPO with new advantage normalization logic +# ============================================================================ + + +def test_grpo_integration_with_new_normalization(mock_grpo_components): + """Integration test for GRPO training with the new advantage normalization logic.""" + from unittest.mock import patch + + # Enable reward normalization to test the new normalization logic + mock_grpo_components["master_config"]["grpo"]["normalize_rewards"] = True + mock_grpo_components["master_config"]["grpo"]["max_num_steps"] = 3 + + grpo_save_state = _default_grpo_save_state() + + # Create test batch with mixed std scenarios (including zero std cases) + mock_batch = BatchedDataDict[DatumSpec]( + { + "message_log": [ + [ + { + "role": "user", + "content": "test", + "token_ids": torch.tensor([1, 2, 3]), + } + ] + ] + * 6, + "task_name": ["math"] * 6, + "extra_env_info": [{}] * 6, + "loss_multiplier": torch.tensor([1.0] * 6), + "idx": torch.tensor([0, 1, 2, 3, 4, 5]), + "length": torch.tensor([3] * 6), + # Mix of rewards: some with variation, some identical (zero std) + "total_reward": torch.tensor([1.0, 1.0, 1.0, 2.0, 4.0, 6.0]), + } + ) + + mock_rollout_metrics = { + "mean_gen_tokens_per_sample": 10.0, + "max_gen_tokens": 20, + "min_gen_tokens": 5, + } + + # Track calls to normalize_advantages_with_epsilon + original_normalize = normalize_advantages_with_epsilon + normalize_calls = [] + + def track_normalize_calls(advantages, std, epsilon=1e-6): + normalize_calls.append((advantages.clone(), std.clone(), epsilon)) + return original_normalize(advantages, std, epsilon) + + # Capture logged metrics to verify advantages tracking + logged_metrics = [] + original_log_metrics = mock_grpo_components["logger"].log_metrics + + def capture_log_metrics(metrics, step): + logged_metrics.append((dict(metrics), step)) + return original_log_metrics(metrics, step) + + mock_grpo_components["logger"].log_metrics.side_effect = capture_log_metrics + + with ( + patch("nemo_rl.algorithms.grpo.run_multi_turn_rollout") as mock_rollout, + patch( + "nemo_rl.algorithms.grpo.normalize_advantages_with_epsilon", + side_effect=track_normalize_calls, + ) as mock_normalize, + ): + mock_rollout.return_value = (mock_batch, mock_rollout_metrics) + + # Run GRPO training + grpo_train( + mock_grpo_components["policy"], + None, # policy_generation + mock_grpo_components["train_dataloader"], + mock_grpo_components["val_dataloader"], + mock_grpo_components["tokenizer"], + mock_grpo_components["loss_fn"], + mock_grpo_components["task_to_env"], + mock_grpo_components["val_task_to_env"], + mock_grpo_components["logger"], + mock_grpo_components["checkpointer"], + grpo_save_state, + mock_grpo_components["master_config"], + ) + + # Verify that normalize_advantages_with_epsilon was called + assert len(normalize_calls) >= 3, ( + "normalize_advantages_with_epsilon should be called for each training step" + ) + assert mock_normalize.call_count >= 3, ( + "Should have called normalization function at least 3 times" + ) + + # Verify the function was called with expected parameters + for advantages, std, epsilon in normalize_calls: + # Verify tensor shapes are correct + assert advantages.dim() == 2, "Advantages should be 2D tensor" + assert advantages.shape[1] == 1, "Advantages should have shape (batch_size, 1)" + assert std.dim() == 1, "Std should be 1D tensor" + assert advantages.shape[0] == std.shape[0], ( + "Advantages and std should have same batch size" + ) + + # Verify epsilon is the expected default + assert epsilon == 1e-6, "Should use default epsilon value" + + # Verify that std values are non-negative (after baseline calculation) + assert (std >= 0).all(), "All std values should be non-negative" + + # Verify advantages metrics are logged + advantages_metrics_found = False + for metrics_dict, step in logged_metrics: + if "advantages/mean" in metrics_dict: + advantages_metrics_found = True + # Verify all three metrics are present + assert "advantages/mean" in metrics_dict + assert "advantages/max" in metrics_dict + assert "advantages/min" in metrics_dict + + # Verify metrics are finite numbers + assert math.isfinite(metrics_dict["advantages/mean"]) + assert math.isfinite(metrics_dict["advantages/max"]) + assert math.isfinite(metrics_dict["advantages/min"]) + + assert advantages_metrics_found, ( + "Should have logged advantages metrics during training" + ) + + # Verify training completed successfully (no exceptions raised) + assert mock_grpo_components["policy"].train.call_count == 3, ( + "Should have completed 3 training steps" + ) + + +def test_grpo_normalization_with_all_zero_std(mock_grpo_components): + """Test GRPO training when all prompts have zero standard deviation.""" + from unittest.mock import patch + + # Enable reward normalization + mock_grpo_components["master_config"]["grpo"]["normalize_rewards"] = True + mock_grpo_components["master_config"]["grpo"]["max_num_steps"] = 1 + + grpo_save_state = _default_grpo_save_state() + + # Create batch where all rewards are identical (zero std everywhere) + mock_batch = BatchedDataDict[DatumSpec]( + { + "message_log": [ + [ + { + "role": "user", + "content": "test", + "token_ids": torch.tensor([1, 2, 3]), + } + ] + ] + * 4, + "task_name": ["math"] * 4, + "extra_env_info": [{}] * 4, + "loss_multiplier": torch.tensor([1.0] * 4), + "idx": torch.tensor([0, 1, 2, 3]), + "length": torch.tensor([3] * 4), + "total_reward": torch.tensor([2.5, 2.5, 2.5, 2.5]), # All identical rewards + } + ) + + mock_rollout_metrics = { + "mean_gen_tokens_per_sample": 10.0, + "max_gen_tokens": 20, + "min_gen_tokens": 5, + } + + # Track normalization calls to verify epsilon is used properly + normalize_calls = [] + original_normalize = normalize_advantages_with_epsilon + + def track_normalize_calls(advantages, std, epsilon=1e-6): + normalize_calls.append((advantages.clone(), std.clone(), epsilon)) + return original_normalize(advantages, std, epsilon) + + with ( + patch("nemo_rl.algorithms.grpo.run_multi_turn_rollout") as mock_rollout, + patch( + "nemo_rl.algorithms.grpo.normalize_advantages_with_epsilon", + side_effect=track_normalize_calls, + ), + ): + mock_rollout.return_value = (mock_batch, mock_rollout_metrics) + + # Run GRPO training - should complete without errors even with all zero std + grpo_train( + mock_grpo_components["policy"], + None, + mock_grpo_components["train_dataloader"], + mock_grpo_components["val_dataloader"], + mock_grpo_components["tokenizer"], + mock_grpo_components["loss_fn"], + mock_grpo_components["task_to_env"], + mock_grpo_components["val_task_to_env"], + mock_grpo_components["logger"], + mock_grpo_components["checkpointer"], + grpo_save_state, + mock_grpo_components["master_config"], + ) + + # Verify normalization was called and handled zero std correctly + assert len(normalize_calls) >= 1, "Should have called normalization at least once" + + for advantages, std, epsilon in normalize_calls: + # When all rewards are identical, std should be 0 + if (std == 0).all(): + # Verify that normalized advantages = advantages / epsilon + expected_normalized = advantages / epsilon + actual_normalized = normalize_advantages_with_epsilon( + advantages, std, epsilon + ) + assert torch.allclose(actual_normalized, expected_normalized, rtol=1e-5) + + # Verify training completed successfully + assert mock_grpo_components["policy"].train.call_count == 1 diff --git a/tests/unit/algorithms/test_utils.py b/tests/unit/algorithms/test_utils.py index ce049a19db..45220acf29 100755 --- a/tests/unit/algorithms/test_utils.py +++ b/tests/unit/algorithms/test_utils.py @@ -19,6 +19,7 @@ import torch from nemo_rl.algorithms.utils import ( + calculate_baseline_and_std_per_prompt, get_tokenizer, maybe_pad_last_batch, print_performance_metrics, @@ -393,3 +394,202 @@ def test_minimal_inputs_no_counts_no_flops(capsys): out = capsys.readouterr().out assert "Throughputs (per GPU)" in out + + +# ============================================================================ +# Tests for calculate_baseline_and_std_per_prompt function +# ============================================================================ + + +def test_calculate_baseline_and_std_per_prompt_basic(): + """Test basic functionality of calculate_baseline_and_std_per_prompt.""" + # Create rewards for 2 prompts, each with 3 generations + rewards = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + prompts = torch.tensor( + [ + [1, 2, 3], # prompt 0 + [1, 2, 3], # prompt 0 + [1, 2, 3], # prompt 0 + [4, 5, 6], # prompt 1 + [4, 5, 6], # prompt 1 + [4, 5, 6], # prompt 1 + ] + ) + + baseline, std = calculate_baseline_and_std_per_prompt(rewards, prompts) + + # For prompt 0: rewards [1, 2, 3] -> mean = 2.0, std = 1.0 + # For prompt 1: rewards [4, 5, 6] -> mean = 5.0, std = 1.0 + expected_baseline = torch.tensor([2.0, 2.0, 2.0, 5.0, 5.0, 5.0]) + expected_std = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) + + assert torch.allclose(baseline, expected_baseline, rtol=1e-5) + assert torch.allclose(std, expected_std, rtol=1e-5) + + +def test_calculate_baseline_and_std_per_prompt_single_generation_per_prompt(): + """Test calculate_baseline_and_std_per_prompt when num_valid < 2 (single generation per prompt).""" + # Case where each prompt has only 1 generation (num_valid = 1 < 2) + rewards = torch.tensor([2.5, 4.0]) + prompts = torch.tensor( + [ + [1, 2, 3], # prompt 0 + [4, 5, 6], # prompt 1 + ] + ) + + baseline, std = calculate_baseline_and_std_per_prompt(rewards, prompts) + + # When num_valid < 2, std should be 0 due to nan_to_num(0) + expected_baseline = torch.tensor([2.5, 4.0]) + expected_std = torch.tensor([0.0, 0.0]) + + assert torch.allclose(baseline, expected_baseline, rtol=1e-5) + assert torch.allclose(std, expected_std, rtol=1e-5) + + +def test_calculate_baseline_and_std_per_prompt_identical_rewards(): + """Test calculate_baseline_and_std_per_prompt when all rewards for a prompt are identical.""" + # All generations for both prompts have the same reward + rewards = torch.tensor([3.0, 3.0, 3.0, 7.0, 7.0, 7.0]) + prompts = torch.tensor( + [ + [1, 2, 3], # prompt 0 + [1, 2, 3], # prompt 0 + [1, 2, 3], # prompt 0 + [4, 5, 6], # prompt 1 + [4, 5, 6], # prompt 1 + [4, 5, 6], # prompt 1 + ] + ) + + baseline, std = calculate_baseline_and_std_per_prompt(rewards, prompts) + + # When all rewards are identical, std should be 0 + expected_baseline = torch.tensor([3.0, 3.0, 3.0, 7.0, 7.0, 7.0]) + expected_std = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + + assert torch.allclose(baseline, expected_baseline, rtol=1e-5) + assert torch.allclose(std, expected_std, rtol=1e-5) + + +def test_calculate_baseline_and_std_per_prompt_mixed_prompt_sizes(): + """Test calculate_baseline_and_std_per_prompt with different number of generations per prompt.""" + # Prompt 0 has 2 generations, Prompt 1 has 3 generations + rewards = torch.tensor([1.0, 2.0, 4.0, 5.0, 6.0]) + prompts = torch.tensor( + [ + [1, 2, 3], # prompt 0 + [1, 2, 3], # prompt 0 + [4, 5, 6], # prompt 1 + [4, 5, 6], # prompt 1 + [4, 5, 6], # prompt 1 + ] + ) + + baseline, std = calculate_baseline_and_std_per_prompt(rewards, prompts) + + # For prompt 0: rewards [1, 2] -> mean = 1.5, std = 0.5 + # For prompt 1: rewards [4, 5, 6] -> mean = 5.0, std = 1.0 + expected_baseline = torch.tensor([1.5, 1.5, 5.0, 5.0, 5.0]) + expected_std = torch.tensor([0.5, 0.5, 1.0, 1.0, 1.0]) + + assert torch.allclose(baseline, expected_baseline, rtol=1e-5) + assert torch.allclose(std, expected_std, rtol=1e-5) + + +def test_calculate_baseline_and_std_per_prompt_empty_input(): + """Test calculate_baseline_and_std_per_prompt with empty tensors.""" + rewards = torch.tensor([]) + prompts = torch.empty(0, 3, dtype=torch.long) + + baseline, std = calculate_baseline_and_std_per_prompt(rewards, prompts) + + assert baseline.shape == torch.Size([0]) + assert std.shape == torch.Size([0]) + assert torch.equal(baseline, torch.tensor([])) + assert torch.equal(std, torch.tensor([])) + + +def test_calculate_baseline_and_std_per_prompt_nan_handling(): + """Test calculate_baseline_and_std_per_prompt handles NaN values correctly.""" + # Include some NaN rewards + rewards = torch.tensor([1.0, float("nan"), 3.0, 4.0, 5.0, 6.0]) + prompts = torch.tensor( + [ + [1, 2, 3], # prompt 0 + [1, 2, 3], # prompt 0 (NaN reward) + [1, 2, 3], # prompt 0 + [4, 5, 6], # prompt 1 + [4, 5, 6], # prompt 1 + [4, 5, 6], # prompt 1 + ] + ) + + baseline, std = calculate_baseline_and_std_per_prompt(rewards, prompts) + + # The function should handle NaN values gracefully + # For prompt 0: only valid rewards [1, 3] -> mean = 2.0, std = 1.0 + # For prompt 1: rewards [4, 5, 6] -> mean = 5.0, std = 1.0 + + # Check that NaN positions get filled appropriately + assert not torch.isnan(baseline).any(), "Baseline should not contain NaN values" + assert not torch.isnan(std).any(), ( + "Std should not contain NaN values due to nan_to_num(0)" + ) + + +def test_calculate_baseline_and_std_per_prompt_cuda_compatibility(): + """Test calculate_baseline_and_std_per_prompt works with CUDA tensors if available.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + rewards = torch.tensor([1.0, 2.0, 3.0, 4.0]).cuda() + prompts = torch.tensor( + [ + [1, 2, 3], # prompt 0 + [1, 2, 3], # prompt 0 + [4, 5, 6], # prompt 1 + [4, 5, 6], # prompt 1 + ] + ).cuda() + + baseline, std = calculate_baseline_and_std_per_prompt(rewards, prompts) + + # Verify results are on CUDA and have expected values + assert baseline.device.type == "cuda" + assert std.device.type == "cuda" + expected_baseline = torch.tensor([1.5, 1.5, 4.0, 4.0]).cuda() + expected_std = torch.tensor( + [0.5, 0.5, 0.0, 0.0] + ).cuda() # std=0 for single sample per prompt + + assert torch.allclose(baseline, expected_baseline, rtol=1e-5) + assert torch.allclose(std, expected_std, rtol=1e-5) + + +def test_calculate_baseline_and_std_per_prompt_numerical_precision(): + """Test calculate_baseline_and_std_per_prompt with edge case numerical values.""" + # Use very small and very large values + rewards = torch.tensor([1e-8, 2e-8, 3e-8, 1e8, 2e8, 3e8]) + prompts = torch.tensor( + [ + [1, 2, 3], # prompt 0 + [1, 2, 3], # prompt 0 + [1, 2, 3], # prompt 0 + [4, 5, 6], # prompt 1 + [4, 5, 6], # prompt 1 + [4, 5, 6], # prompt 1 + ] + ) + + baseline, std = calculate_baseline_and_std_per_prompt(rewards, prompts) + + # For prompt 0: very small values [1e-8, 2e-8, 3e-8] -> mean = 2e-8 + # For prompt 1: very large values [1e8, 2e8, 3e8] -> mean = 2e8 + expected_baseline = torch.tensor([2e-8, 2e-8, 2e-8, 2e8, 2e8, 2e8]) + + assert torch.allclose(baseline, expected_baseline, rtol=1e-5) + # Std values should be finite and not NaN + assert torch.isfinite(std).all() + assert not torch.isnan(std).any() From eb961c8f74dffcb87d8a9690677071be98b223cd Mon Sep 17 00:00:00 2001 From: Felipe Vieira Frujeri Date: Tue, 28 Oct 2025 15:57:52 -0700 Subject: [PATCH 2/5] Update tests, rebase from main. Signed-off-by: Felipe Vieira Frujeri --- tests/unit/algorithms/test_grpo.py | 364 ----------------------------- 1 file changed, 364 deletions(-) diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py index 6fab6d9276..07e94f8c48 100644 --- a/tests/unit/algorithms/test_grpo.py +++ b/tests/unit/algorithms/test_grpo.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math from unittest.mock import MagicMock, patch import pytest @@ -1255,20 +1254,6 @@ def test_normalize_advantages_with_epsilon_all_zero_std(): assert torch.allclose(result, expected, rtol=1e-5) -def test_normalize_advantages_with_epsilon_different_epsilons(): - """Test normalize_advantages_with_epsilon with different epsilon values.""" - advantages = torch.tensor([[1.0], [2.0]]) - std = torch.tensor([0.0, 0.5]) - - # Test with different epsilon values - epsilons = [1e-6, 1e-8, 1e-4] - - for eps in epsilons: - result = normalize_advantages_with_epsilon(advantages, std, eps) - expected = torch.tensor([[1.0 / eps], [2.0 / 0.5]]) - assert torch.allclose(result, expected, rtol=1e-5) - - def test_normalize_advantages_with_epsilon_tensor_shapes(): """Test normalize_advantages_with_epsilon with different tensor shapes.""" # Test with batch size 1 @@ -1296,352 +1281,3 @@ def test_normalize_advantages_with_epsilon_negative_advantages(): expected = torch.tensor([[-2.0], [2.0], [-3.0]]) assert torch.allclose(result, expected, rtol=1e-5) - - -def test_normalize_advantages_with_epsilon_very_small_std(): - """Test normalize_advantages_with_epsilon with very small std values.""" - advantages = torch.tensor([[1.0], [2.0]]) - std = torch.tensor([1e-10, 2e-9]) # Very small but non-zero std - epsilon = 1e-6 - - result = normalize_advantages_with_epsilon(advantages, std, epsilon) - - # Since std values are much smaller than epsilon, result should be close to advantages/epsilon - expected = torch.tensor([[1.0 / epsilon], [2.0 / epsilon]]) - assert torch.allclose(result, expected, rtol=1e-3) - - -def test_normalize_advantages_with_epsilon_gradient_flow(): - """Test that normalize_advantages_with_epsilon preserves gradients.""" - advantages = torch.tensor([[2.0], [4.0]], requires_grad=True) - std = torch.tensor([1.0, 2.0]) - - result = normalize_advantages_with_epsilon(advantages, std) - loss = result.sum() - loss.backward() - - # Check that gradients are computed correctly - assert advantages.grad is not None - expected_grad = torch.tensor([[1.0], [0.5]]) - assert torch.allclose(advantages.grad, expected_grad, rtol=1e-5) - - -# ============================================================================ -# Tests for advantages metrics tracking in GRPO -# ============================================================================ - - -def test_grpo_advantages_metrics_tracking(mock_grpo_components): - """Test that advantages metrics (mean, max, min) are correctly tracked during GRPO training.""" - from unittest.mock import patch - - # Set up config to enable reward normalization (where advantages metrics are logged) - mock_grpo_components["master_config"]["grpo"]["normalize_rewards"] = True - mock_grpo_components["master_config"]["grpo"]["max_num_steps"] = ( - 2 # Just run 2 steps - ) - - grpo_save_state = _default_grpo_save_state() - - # Mock batch with specific rewards to get predictable advantages - mock_batch = BatchedDataDict[DatumSpec]( - { - "message_log": [ - [ - { - "role": "user", - "content": "test", - "token_ids": torch.tensor([1, 2, 3]), - } - ] - ] - * 4, - "task_name": ["math"] * 4, - "extra_env_info": [{}] * 4, - "loss_multiplier": torch.tensor([1.0, 1.0, 1.0, 1.0]), - "idx": torch.tensor([0, 1, 2, 3]), - "length": torch.tensor([3, 3, 3, 3]), - "total_reward": torch.tensor( - [1.0, 3.0, 2.0, 4.0] - ), # Known rewards for predictable advantages - } - ) - - mock_rollout_metrics = { - "mean_gen_tokens_per_sample": 10.0, - "max_gen_tokens": 20, - "min_gen_tokens": 5, - } - - # Capture the logged metrics - logged_metrics = [] - original_log_metrics = mock_grpo_components["logger"].log_metrics - - def capture_log_metrics(metrics, step): - logged_metrics.append((dict(metrics), step)) - return original_log_metrics(metrics, step) - - mock_grpo_components["logger"].log_metrics.side_effect = capture_log_metrics - - with patch("nemo_rl.algorithms.grpo.run_multi_turn_rollout") as mock_rollout: - mock_rollout.return_value = (mock_batch, mock_rollout_metrics) - - # Run GRPO training - grpo_train( - mock_grpo_components["policy"], - None, # policy_generation - mock_grpo_components["train_dataloader"], - mock_grpo_components["val_dataloader"], - mock_grpo_components["tokenizer"], - mock_grpo_components["loss_fn"], - mock_grpo_components["task_to_env"], - mock_grpo_components["val_task_to_env"], - mock_grpo_components["logger"], - mock_grpo_components["checkpointer"], - grpo_save_state, - mock_grpo_components["master_config"], - ) - - # Verify that advantages metrics were logged - assert len(logged_metrics) >= 2, "Should have logged metrics for at least 2 steps" - - # Check each logged step for advantages metrics - for metrics_dict, step in logged_metrics: - if "advantages/mean" in metrics_dict: - # Verify that all three advantages metrics are present - assert "advantages/mean" in metrics_dict - assert "advantages/max" in metrics_dict - assert "advantages/min" in metrics_dict - - # Verify that the values are reasonable (should be numeric) - assert isinstance(metrics_dict["advantages/mean"], (int, float)) - assert isinstance(metrics_dict["advantages/max"], (int, float)) - assert isinstance(metrics_dict["advantages/min"], (int, float)) - - # Verify logical relationships between min, mean, max - assert metrics_dict["advantages/min"] <= metrics_dict["advantages/mean"] - assert metrics_dict["advantages/mean"] <= metrics_dict["advantages/max"] - - -# ============================================================================ -# Integration tests for GRPO with new advantage normalization logic -# ============================================================================ - - -def test_grpo_integration_with_new_normalization(mock_grpo_components): - """Integration test for GRPO training with the new advantage normalization logic.""" - from unittest.mock import patch - - # Enable reward normalization to test the new normalization logic - mock_grpo_components["master_config"]["grpo"]["normalize_rewards"] = True - mock_grpo_components["master_config"]["grpo"]["max_num_steps"] = 3 - - grpo_save_state = _default_grpo_save_state() - - # Create test batch with mixed std scenarios (including zero std cases) - mock_batch = BatchedDataDict[DatumSpec]( - { - "message_log": [ - [ - { - "role": "user", - "content": "test", - "token_ids": torch.tensor([1, 2, 3]), - } - ] - ] - * 6, - "task_name": ["math"] * 6, - "extra_env_info": [{}] * 6, - "loss_multiplier": torch.tensor([1.0] * 6), - "idx": torch.tensor([0, 1, 2, 3, 4, 5]), - "length": torch.tensor([3] * 6), - # Mix of rewards: some with variation, some identical (zero std) - "total_reward": torch.tensor([1.0, 1.0, 1.0, 2.0, 4.0, 6.0]), - } - ) - - mock_rollout_metrics = { - "mean_gen_tokens_per_sample": 10.0, - "max_gen_tokens": 20, - "min_gen_tokens": 5, - } - - # Track calls to normalize_advantages_with_epsilon - original_normalize = normalize_advantages_with_epsilon - normalize_calls = [] - - def track_normalize_calls(advantages, std, epsilon=1e-6): - normalize_calls.append((advantages.clone(), std.clone(), epsilon)) - return original_normalize(advantages, std, epsilon) - - # Capture logged metrics to verify advantages tracking - logged_metrics = [] - original_log_metrics = mock_grpo_components["logger"].log_metrics - - def capture_log_metrics(metrics, step): - logged_metrics.append((dict(metrics), step)) - return original_log_metrics(metrics, step) - - mock_grpo_components["logger"].log_metrics.side_effect = capture_log_metrics - - with ( - patch("nemo_rl.algorithms.grpo.run_multi_turn_rollout") as mock_rollout, - patch( - "nemo_rl.algorithms.grpo.normalize_advantages_with_epsilon", - side_effect=track_normalize_calls, - ) as mock_normalize, - ): - mock_rollout.return_value = (mock_batch, mock_rollout_metrics) - - # Run GRPO training - grpo_train( - mock_grpo_components["policy"], - None, # policy_generation - mock_grpo_components["train_dataloader"], - mock_grpo_components["val_dataloader"], - mock_grpo_components["tokenizer"], - mock_grpo_components["loss_fn"], - mock_grpo_components["task_to_env"], - mock_grpo_components["val_task_to_env"], - mock_grpo_components["logger"], - mock_grpo_components["checkpointer"], - grpo_save_state, - mock_grpo_components["master_config"], - ) - - # Verify that normalize_advantages_with_epsilon was called - assert len(normalize_calls) >= 3, ( - "normalize_advantages_with_epsilon should be called for each training step" - ) - assert mock_normalize.call_count >= 3, ( - "Should have called normalization function at least 3 times" - ) - - # Verify the function was called with expected parameters - for advantages, std, epsilon in normalize_calls: - # Verify tensor shapes are correct - assert advantages.dim() == 2, "Advantages should be 2D tensor" - assert advantages.shape[1] == 1, "Advantages should have shape (batch_size, 1)" - assert std.dim() == 1, "Std should be 1D tensor" - assert advantages.shape[0] == std.shape[0], ( - "Advantages and std should have same batch size" - ) - - # Verify epsilon is the expected default - assert epsilon == 1e-6, "Should use default epsilon value" - - # Verify that std values are non-negative (after baseline calculation) - assert (std >= 0).all(), "All std values should be non-negative" - - # Verify advantages metrics are logged - advantages_metrics_found = False - for metrics_dict, step in logged_metrics: - if "advantages/mean" in metrics_dict: - advantages_metrics_found = True - # Verify all three metrics are present - assert "advantages/mean" in metrics_dict - assert "advantages/max" in metrics_dict - assert "advantages/min" in metrics_dict - - # Verify metrics are finite numbers - assert math.isfinite(metrics_dict["advantages/mean"]) - assert math.isfinite(metrics_dict["advantages/max"]) - assert math.isfinite(metrics_dict["advantages/min"]) - - assert advantages_metrics_found, ( - "Should have logged advantages metrics during training" - ) - - # Verify training completed successfully (no exceptions raised) - assert mock_grpo_components["policy"].train.call_count == 3, ( - "Should have completed 3 training steps" - ) - - -def test_grpo_normalization_with_all_zero_std(mock_grpo_components): - """Test GRPO training when all prompts have zero standard deviation.""" - from unittest.mock import patch - - # Enable reward normalization - mock_grpo_components["master_config"]["grpo"]["normalize_rewards"] = True - mock_grpo_components["master_config"]["grpo"]["max_num_steps"] = 1 - - grpo_save_state = _default_grpo_save_state() - - # Create batch where all rewards are identical (zero std everywhere) - mock_batch = BatchedDataDict[DatumSpec]( - { - "message_log": [ - [ - { - "role": "user", - "content": "test", - "token_ids": torch.tensor([1, 2, 3]), - } - ] - ] - * 4, - "task_name": ["math"] * 4, - "extra_env_info": [{}] * 4, - "loss_multiplier": torch.tensor([1.0] * 4), - "idx": torch.tensor([0, 1, 2, 3]), - "length": torch.tensor([3] * 4), - "total_reward": torch.tensor([2.5, 2.5, 2.5, 2.5]), # All identical rewards - } - ) - - mock_rollout_metrics = { - "mean_gen_tokens_per_sample": 10.0, - "max_gen_tokens": 20, - "min_gen_tokens": 5, - } - - # Track normalization calls to verify epsilon is used properly - normalize_calls = [] - original_normalize = normalize_advantages_with_epsilon - - def track_normalize_calls(advantages, std, epsilon=1e-6): - normalize_calls.append((advantages.clone(), std.clone(), epsilon)) - return original_normalize(advantages, std, epsilon) - - with ( - patch("nemo_rl.algorithms.grpo.run_multi_turn_rollout") as mock_rollout, - patch( - "nemo_rl.algorithms.grpo.normalize_advantages_with_epsilon", - side_effect=track_normalize_calls, - ), - ): - mock_rollout.return_value = (mock_batch, mock_rollout_metrics) - - # Run GRPO training - should complete without errors even with all zero std - grpo_train( - mock_grpo_components["policy"], - None, - mock_grpo_components["train_dataloader"], - mock_grpo_components["val_dataloader"], - mock_grpo_components["tokenizer"], - mock_grpo_components["loss_fn"], - mock_grpo_components["task_to_env"], - mock_grpo_components["val_task_to_env"], - mock_grpo_components["logger"], - mock_grpo_components["checkpointer"], - grpo_save_state, - mock_grpo_components["master_config"], - ) - - # Verify normalization was called and handled zero std correctly - assert len(normalize_calls) >= 1, "Should have called normalization at least once" - - for advantages, std, epsilon in normalize_calls: - # When all rewards are identical, std should be 0 - if (std == 0).all(): - # Verify that normalized advantages = advantages / epsilon - expected_normalized = advantages / epsilon - actual_normalized = normalize_advantages_with_epsilon( - advantages, std, epsilon - ) - assert torch.allclose(actual_normalized, expected_normalized, rtol=1e-5) - - # Verify training completed successfully - assert mock_grpo_components["policy"].train.call_count == 1 From e6c0070df5fe295c96e624426684e4054165b8d2 Mon Sep 17 00:00:00 2001 From: Felipe Vieira Frujeri Date: Fri, 31 Oct 2025 22:34:23 -0700 Subject: [PATCH 3/5] Remove lingering std. Signed-off-by: Felipe Vieira Frujeri --- nemo_rl/algorithms/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index 306d7ba307..662b948d7f 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -152,7 +152,6 @@ def calculate_baseline_and_std_per_prompt( .nan_to_num(0) ) - std = (sq_baseline - baseline.square()).sqrt().nan_to_num(0) return baseline, std From 4783ef0eec58dea8780f0466c9bff9a688b993ca Mon Sep 17 00:00:00 2001 From: Felipe Vieira Frujeri Date: Thu, 6 Nov 2025 23:39:57 -0800 Subject: [PATCH 4/5] Fix unit tests. Signed-off-by: Felipe Vieira Frujeri --- nemo_rl/algorithms/utils.py | 2 +- tests/unit/algorithms/test_utils.py | 74 ++++++++++++++--------------- 2 files changed, 38 insertions(+), 38 deletions(-) diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index 662b948d7f..1a28f5f690 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -104,7 +104,7 @@ def calculate_baseline_and_std_per_prompt( if device_ordinal == -1: reward_device = torch.device("cpu") else: - reward_device = torch.device(reward_device) + reward_device = torch.device(f"cuda:{device_ordinal}") for i in range(len(unique_prompts)): is_matching_prompt = (prompts == unique_prompts[i]).all(1) diff --git a/tests/unit/algorithms/test_utils.py b/tests/unit/algorithms/test_utils.py index 45220acf29..edc8d0a812 100755 --- a/tests/unit/algorithms/test_utils.py +++ b/tests/unit/algorithms/test_utils.py @@ -415,13 +415,14 @@ def test_calculate_baseline_and_std_per_prompt_basic(): [4, 5, 6], # prompt 1 ] ) + valid_mask = torch.ones(6) - baseline, std = calculate_baseline_and_std_per_prompt(rewards, prompts) + baseline, std = calculate_baseline_and_std_per_prompt(prompts, rewards, valid_mask) - # For prompt 0: rewards [1, 2, 3] -> mean = 2.0, std = 1.0 - # For prompt 1: rewards [4, 5, 6] -> mean = 5.0, std = 1.0 - expected_baseline = torch.tensor([2.0, 2.0, 2.0, 5.0, 5.0, 5.0]) - expected_std = torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0, 1.0]) + expected_baseline = torch.tensor([2.5, 2.0, 1.5, 5.5, 5.0, 4.5]) + expected_std = torch.tensor( + [0.707107, 1.414214, 0.707107, 0.707107, 1.414214, 0.707107] + ) assert torch.allclose(baseline, expected_baseline, rtol=1e-5) assert torch.allclose(std, expected_std, rtol=1e-5) @@ -437,10 +438,11 @@ def test_calculate_baseline_and_std_per_prompt_single_generation_per_prompt(): [4, 5, 6], # prompt 1 ] ) + valid_mask = torch.ones(2) - baseline, std = calculate_baseline_and_std_per_prompt(rewards, prompts) + baseline, std = calculate_baseline_and_std_per_prompt(prompts, rewards, valid_mask) - # When num_valid < 2, std should be 0 due to nan_to_num(0) + # When num_valid <= 1 (single generation per prompt), baseline equals reward expected_baseline = torch.tensor([2.5, 4.0]) expected_std = torch.tensor([0.0, 0.0]) @@ -462,10 +464,10 @@ def test_calculate_baseline_and_std_per_prompt_identical_rewards(): [4, 5, 6], # prompt 1 ] ) + valid_mask = torch.ones(6) - baseline, std = calculate_baseline_and_std_per_prompt(rewards, prompts) + baseline, std = calculate_baseline_and_std_per_prompt(prompts, rewards, valid_mask) - # When all rewards are identical, std should be 0 expected_baseline = torch.tensor([3.0, 3.0, 3.0, 7.0, 7.0, 7.0]) expected_std = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) @@ -486,13 +488,12 @@ def test_calculate_baseline_and_std_per_prompt_mixed_prompt_sizes(): [4, 5, 6], # prompt 1 ] ) + valid_mask = torch.ones(5) - baseline, std = calculate_baseline_and_std_per_prompt(rewards, prompts) + baseline, std = calculate_baseline_and_std_per_prompt(prompts, rewards, valid_mask) - # For prompt 0: rewards [1, 2] -> mean = 1.5, std = 0.5 - # For prompt 1: rewards [4, 5, 6] -> mean = 5.0, std = 1.0 - expected_baseline = torch.tensor([1.5, 1.5, 5.0, 5.0, 5.0]) - expected_std = torch.tensor([0.5, 0.5, 1.0, 1.0, 1.0]) + expected_baseline = torch.tensor([2.0, 1.0, 5.5, 5.0, 4.5]) + expected_std = torch.tensor([0.0, 0.0, 0.707107, 1.414214, 0.707107]) assert torch.allclose(baseline, expected_baseline, rtol=1e-5) assert torch.allclose(std, expected_std, rtol=1e-5) @@ -502,8 +503,9 @@ def test_calculate_baseline_and_std_per_prompt_empty_input(): """Test calculate_baseline_and_std_per_prompt with empty tensors.""" rewards = torch.tensor([]) prompts = torch.empty(0, 3, dtype=torch.long) + valid_mask = torch.tensor([]) - baseline, std = calculate_baseline_and_std_per_prompt(rewards, prompts) + baseline, std = calculate_baseline_and_std_per_prompt(prompts, rewards, valid_mask) assert baseline.shape == torch.Size([0]) assert std.shape == torch.Size([0]) @@ -512,31 +514,30 @@ def test_calculate_baseline_and_std_per_prompt_empty_input(): def test_calculate_baseline_and_std_per_prompt_nan_handling(): - """Test calculate_baseline_and_std_per_prompt handles NaN values correctly.""" - # Include some NaN rewards - rewards = torch.tensor([1.0, float("nan"), 3.0, 4.0, 5.0, 6.0]) + """Test calculate_baseline_and_std_per_prompt handles valid_mask correctly with masked samples.""" + # Test that valid_mask properly excludes samples from baseline calculation + # Note: The function doesn't handle actual NaN values; it uses valid_mask to exclude samples + rewards = torch.tensor([1.0, 999.0, 3.0, 4.0, 5.0, 6.0]) # 999.0 should be ignored prompts = torch.tensor( [ [1, 2, 3], # prompt 0 - [1, 2, 3], # prompt 0 (NaN reward) + [1, 2, 3], # prompt 0 (invalid sample) [1, 2, 3], # prompt 0 [4, 5, 6], # prompt 1 [4, 5, 6], # prompt 1 [4, 5, 6], # prompt 1 ] ) + # Mark the second sample as invalid + valid_mask = torch.tensor([1.0, 0.0, 1.0, 1.0, 1.0, 1.0]) - baseline, std = calculate_baseline_and_std_per_prompt(rewards, prompts) + baseline, std = calculate_baseline_and_std_per_prompt(prompts, rewards, valid_mask) - # The function should handle NaN values gracefully - # For prompt 0: only valid rewards [1, 3] -> mean = 2.0, std = 1.0 - # For prompt 1: rewards [4, 5, 6] -> mean = 5.0, std = 1.0 + expected_baseline = torch.tensor([3.0, 4.0, 1.0, 5.5, 5.0, 4.5]) + expected_std = torch.tensor([0.0, 0.0, 0.0, 0.707107, 1.414214, 0.707107]) - # Check that NaN positions get filled appropriately - assert not torch.isnan(baseline).any(), "Baseline should not contain NaN values" - assert not torch.isnan(std).any(), ( - "Std should not contain NaN values due to nan_to_num(0)" - ) + assert torch.allclose(baseline, expected_baseline, rtol=1e-5) + assert torch.allclose(std, expected_std, rtol=1e-5) def test_calculate_baseline_and_std_per_prompt_cuda_compatibility(): @@ -553,16 +554,16 @@ def test_calculate_baseline_and_std_per_prompt_cuda_compatibility(): [4, 5, 6], # prompt 1 ] ).cuda() + valid_mask = torch.ones(4).cuda() - baseline, std = calculate_baseline_and_std_per_prompt(rewards, prompts) + baseline, std = calculate_baseline_and_std_per_prompt(prompts, rewards, valid_mask) # Verify results are on CUDA and have expected values assert baseline.device.type == "cuda" assert std.device.type == "cuda" - expected_baseline = torch.tensor([1.5, 1.5, 4.0, 4.0]).cuda() - expected_std = torch.tensor( - [0.5, 0.5, 0.0, 0.0] - ).cuda() # std=0 for single sample per prompt + + expected_baseline = torch.tensor([2.0, 1.0, 4.0, 3.0]).cuda() + expected_std = torch.tensor([0.0, 0.0, 0.0, 0.0]).cuda() assert torch.allclose(baseline, expected_baseline, rtol=1e-5) assert torch.allclose(std, expected_std, rtol=1e-5) @@ -582,12 +583,11 @@ def test_calculate_baseline_and_std_per_prompt_numerical_precision(): [4, 5, 6], # prompt 1 ] ) + valid_mask = torch.ones(6) - baseline, std = calculate_baseline_and_std_per_prompt(rewards, prompts) + baseline, std = calculate_baseline_and_std_per_prompt(prompts, rewards, valid_mask) - # For prompt 0: very small values [1e-8, 2e-8, 3e-8] -> mean = 2e-8 - # For prompt 1: very large values [1e8, 2e8, 3e8] -> mean = 2e8 - expected_baseline = torch.tensor([2e-8, 2e-8, 2e-8, 2e8, 2e8, 2e8]) + expected_baseline = torch.tensor([2.5e-8, 2e-8, 1.5e-8, 2.5e8, 2e8, 1.5e8]) assert torch.allclose(baseline, expected_baseline, rtol=1e-5) # Std values should be finite and not NaN From 2884e3069618fe99765b95daa3a11d3c6183b9b4 Mon Sep 17 00:00:00 2001 From: Felipe Vieira Frujeri Date: Tue, 11 Nov 2025 08:55:18 -0800 Subject: [PATCH 5/5] Compute advantages only for masked response tokens. Signed-off-by: Felipe Vieira Frujeri --- nemo_rl/algorithms/grpo.py | 46 +++++++++++++++++++++++++++++++------- 1 file changed, 38 insertions(+), 8 deletions(-) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 818db4bb2c..3618b9b92d 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -1190,16 +1190,31 @@ def grpo_train( val_metrics, total_steps + 1, prefix="validation" ) + # Get flat advantages and token mask for masked metrics computation + flat_advantages = flat_messages["advantages"] + flat_token_mask = flat_messages["token_loss_mask"] + + # Filter advantages using token mask (only valid response tokens) + response_advantages = torch.masked_select( + flat_advantages, flat_token_mask.bool() + ) + metrics = { "loss": train_results["loss"].numpy(), "grad_norm": train_results["grad_norm"].numpy(), "reward": rewards.numpy(), "mean_prompt_length": repeated_batch["length"].numpy(), "total_num_tokens": input_lengths.numpy(), - # Add advantages tracking metrics - "advantages/mean": torch.mean(advantages).detach().item(), - "advantages/max": torch.max(advantages).detach().item(), - "advantages/min": torch.min(advantages).detach().item(), + # Add masked advantages tracking metrics (only for valid response tokens) + "advantages/mean": torch.mean(response_advantages).detach().item() + if response_advantages.numel() > 0 + else 0.0, + "advantages/max": torch.max(response_advantages).detach().item() + if response_advantages.numel() > 0 + else 0.0, + "advantages/min": torch.min(response_advantages).detach().item() + if response_advantages.numel() > 0 + else 0.0, **ds_metrics, } if master_config["grpo"]["use_dynamic_sampling"]: @@ -2083,16 +2098,31 @@ def async_grpo_train( # Resume trajectory collection after validation trajectory_collector.resume.remote() + # Get flat advantages and token mask for masked metrics computation + flat_advantages = flat_messages["advantages"] + flat_token_mask = flat_messages["token_loss_mask"] + + # Filter advantages using token mask (only valid response tokens) + response_advantages = torch.masked_select( + flat_advantages, flat_token_mask.bool() + ) + metrics = { "loss": train_results["loss"].numpy(), "reward": rewards.numpy(), "grad_norm": train_results["grad_norm"].numpy(), "mean_prompt_length": repeated_batch["length"].numpy(), "total_num_tokens": input_lengths.numpy(), - # Add advantages tracking metrics - "advantages/mean": torch.mean(advantages).detach().item(), - "advantages/max": torch.max(advantages).detach().item(), - "advantages/min": torch.min(advantages).detach().item(), + # Add masked advantages tracking metrics (only for valid response tokens) + "advantages/mean": torch.mean(response_advantages).detach().item() + if response_advantages.numel() > 0 + else 0.0, + "advantages/max": torch.max(response_advantages).detach().item() + if response_advantages.numel() > 0 + else 0.0, + "advantages/min": torch.min(response_advantages).detach().item() + if response_advantages.numel() > 0 + else 0.0, } metrics.update(train_results["all_mb_metrics"]) for k, v in metrics.items():