diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index fdfa5e3ed2..3618b9b92d 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -538,6 +538,25 @@ def setup( # =============================================================================== +def normalize_advantages_with_epsilon( + advantages: torch.Tensor, + std: torch.Tensor, + epsilon: float = 1e-6, +) -> torch.Tensor: + """Normalize advantages by standard deviation with epsilon to avoid division by zero. + + Args: + advantages: Tensor of shape (batch_size, 1) containing advantage values + std: Tensor of shape (batch_size,) containing standard deviation values + epsilon: Small value to avoid division by zero, defaults to 1e-6 + + Returns: + Normalized advantages tensor of same shape as input advantages + """ + # Use epsilon to avoid division by zero instead of masking + return advantages / (std.unsqueeze(-1) + epsilon) + + def dynamic_sampling( repeated_batch: BatchedDataDict[DatumSpec], std: torch.Tensor, @@ -1056,10 +1075,9 @@ def grpo_train( advantages = (rewards - baseline).unsqueeze(-1) if master_config["grpo"]["normalize_rewards"]: - # don't sharpen the ones with no variation - zero_std_mask = std > 0 - advantages[zero_std_mask] = ( - advantages[zero_std_mask] / std.unsqueeze(-1)[zero_std_mask] + advantages = normalize_advantages_with_epsilon( + advantages=advantages, + std=std, ) with timer.time("data_processing"): @@ -1172,12 +1190,31 @@ def grpo_train( val_metrics, total_steps + 1, prefix="validation" ) + # Get flat advantages and token mask for masked metrics computation + flat_advantages = flat_messages["advantages"] + flat_token_mask = flat_messages["token_loss_mask"] + + # Filter advantages using token mask (only valid response tokens) + response_advantages = torch.masked_select( + flat_advantages, flat_token_mask.bool() + ) + metrics = { "loss": train_results["loss"].numpy(), "grad_norm": train_results["grad_norm"].numpy(), "reward": rewards.numpy(), "mean_prompt_length": repeated_batch["length"].numpy(), "total_num_tokens": input_lengths.numpy(), + # Add masked advantages tracking metrics (only for valid response tokens) + "advantages/mean": torch.mean(response_advantages).detach().item() + if response_advantages.numel() > 0 + else 0.0, + "advantages/max": torch.max(response_advantages).detach().item() + if response_advantages.numel() > 0 + else 0.0, + "advantages/min": torch.min(response_advantages).detach().item() + if response_advantages.numel() > 0 + else 0.0, **ds_metrics, } if master_config["grpo"]["use_dynamic_sampling"]: @@ -1929,10 +1966,11 @@ def async_grpo_train( ) if master_config["grpo"]["normalize_rewards"]: - zero_std_mask = std > 0 - advantages[zero_std_mask] = ( - advantages[zero_std_mask] / std.unsqueeze(-1)[zero_std_mask] + advantages = normalize_advantages_with_epsilon( + advantages=advantages, + std=std, ) + print( f" 📊 Normalized advantages stats: min={advantages.min():.4f}, max={advantages.max():.4f}, mean={advantages.mean():.4f}, std={advantages.std():.4f}" ) @@ -2060,12 +2098,31 @@ def async_grpo_train( # Resume trajectory collection after validation trajectory_collector.resume.remote() + # Get flat advantages and token mask for masked metrics computation + flat_advantages = flat_messages["advantages"] + flat_token_mask = flat_messages["token_loss_mask"] + + # Filter advantages using token mask (only valid response tokens) + response_advantages = torch.masked_select( + flat_advantages, flat_token_mask.bool() + ) + metrics = { "loss": train_results["loss"].numpy(), "reward": rewards.numpy(), "grad_norm": train_results["grad_norm"].numpy(), "mean_prompt_length": repeated_batch["length"].numpy(), "total_num_tokens": input_lengths.numpy(), + # Add masked advantages tracking metrics (only for valid response tokens) + "advantages/mean": torch.mean(response_advantages).detach().item() + if response_advantages.numel() > 0 + else 0.0, + "advantages/max": torch.max(response_advantages).detach().item() + if response_advantages.numel() > 0 + else 0.0, + "advantages/min": torch.min(response_advantages).detach().item() + if response_advantages.numel() > 0 + else 0.0, } metrics.update(train_results["all_mb_metrics"]) for k, v in metrics.items(): diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index e323bec734..1a28f5f690 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -99,11 +99,12 @@ def calculate_baseline_and_std_per_prompt( baseline = torch.zeros_like(rewards) sq_baseline = torch.zeros_like(rewards) + std = torch.zeros_like(rewards) device_ordinal = rewards.get_device() if device_ordinal == -1: reward_device = torch.device("cpu") else: - reward_device = torch.device(reward_device) + reward_device = torch.device(f"cuda:{device_ordinal}") for i in range(len(unique_prompts)): is_matching_prompt = (prompts == unique_prompts[i]).all(1) @@ -142,8 +143,15 @@ def calculate_baseline_and_std_per_prompt( baseline[prompt_idx] = prompt_baseline sq_baseline[prompt_idx] = prompt_baseline_square + std[prompt_idx] = ( + ( + (prompt_baseline_square - prompt_baseline.square()) + * (num_valid / (num_valid - 1)) + ) + .sqrt() + .nan_to_num(0) + ) - std = (sq_baseline - baseline.square()).sqrt().nan_to_num(0) return baseline, std diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py index 6d42ad553d..07e94f8c48 100644 --- a/tests/unit/algorithms/test_grpo.py +++ b/tests/unit/algorithms/test_grpo.py @@ -24,6 +24,7 @@ async_grpo_train, dynamic_sampling, grpo_train, + normalize_advantages_with_epsilon, ) from nemo_rl.algorithms.loss_functions import ClippedPGLossFn from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType @@ -1208,3 +1209,75 @@ def test_grpo_exit_on_timeout(mock_grpo_components, train_func, capsys): assert not (line.startswith("Step ") and "Step 9" in line), ( f"Training continued to next step after timeout: {line}" ) + + +# ============================================================================ +# Tests for normalize_advantages_with_epsilon function +# ============================================================================ + + +def test_normalize_advantages_with_epsilon_basic(): + """Test basic functionality of normalize_advantages_with_epsilon.""" + # Test case with normal values + advantages = torch.tensor([[2.0], [4.0], [6.0]]) + std = torch.tensor([1.0, 2.0, 3.0]) + epsilon = 1e-6 + + result = normalize_advantages_with_epsilon(advantages, std, epsilon) + + expected = torch.tensor([[2.0], [2.0], [2.0]]) + assert torch.allclose(result, expected, rtol=1e-5) + + +def test_normalize_advantages_with_epsilon_zero_std(): + """Test normalize_advantages_with_epsilon when std contains zeros.""" + advantages = torch.tensor([[1.0], [2.0], [3.0]]) + std = torch.tensor([0.0, 1.0, 0.0]) # Zero std for indices 0 and 2 + epsilon = 1e-6 + + result = normalize_advantages_with_epsilon(advantages, std, epsilon) + + # When std=0, result should be advantages / epsilon + expected = torch.tensor([[1.0 / epsilon], [2.0], [3.0 / epsilon]]) + assert torch.allclose(result, expected, rtol=1e-5) + + +def test_normalize_advantages_with_epsilon_all_zero_std(): + """Test normalize_advantages_with_epsilon when all std values are zero.""" + advantages = torch.tensor([[1.5], [2.5], [3.5]]) + std = torch.tensor([0.0, 0.0, 0.0]) + epsilon = 1e-8 + + result = normalize_advantages_with_epsilon(advantages, std, epsilon) + + expected = advantages / epsilon + assert torch.allclose(result, expected, rtol=1e-5) + + +def test_normalize_advantages_with_epsilon_tensor_shapes(): + """Test normalize_advantages_with_epsilon with different tensor shapes.""" + # Test with batch size 1 + advantages = torch.tensor([[5.0]]) + std = torch.tensor([2.0]) + result = normalize_advantages_with_epsilon(advantages, std) + expected = torch.tensor([[2.5]]) + assert torch.allclose(result, expected, rtol=1e-5) + + # Test with larger batch + batch_size = 10 + advantages = torch.ones(batch_size, 1) * 3.0 + std = torch.ones(batch_size) * 1.5 + result = normalize_advantages_with_epsilon(advantages, std) + expected = torch.ones(batch_size, 1) * 2.0 + assert torch.allclose(result, expected, rtol=1e-5) + + +def test_normalize_advantages_with_epsilon_negative_advantages(): + """Test normalize_advantages_with_epsilon with negative advantages.""" + advantages = torch.tensor([[-2.0], [3.0], [-1.5]]) + std = torch.tensor([1.0, 1.5, 0.5]) + + result = normalize_advantages_with_epsilon(advantages, std) + + expected = torch.tensor([[-2.0], [2.0], [-3.0]]) + assert torch.allclose(result, expected, rtol=1e-5) diff --git a/tests/unit/algorithms/test_utils.py b/tests/unit/algorithms/test_utils.py index ce049a19db..edc8d0a812 100755 --- a/tests/unit/algorithms/test_utils.py +++ b/tests/unit/algorithms/test_utils.py @@ -19,6 +19,7 @@ import torch from nemo_rl.algorithms.utils import ( + calculate_baseline_and_std_per_prompt, get_tokenizer, maybe_pad_last_batch, print_performance_metrics, @@ -393,3 +394,202 @@ def test_minimal_inputs_no_counts_no_flops(capsys): out = capsys.readouterr().out assert "Throughputs (per GPU)" in out + + +# ============================================================================ +# Tests for calculate_baseline_and_std_per_prompt function +# ============================================================================ + + +def test_calculate_baseline_and_std_per_prompt_basic(): + """Test basic functionality of calculate_baseline_and_std_per_prompt.""" + # Create rewards for 2 prompts, each with 3 generations + rewards = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + prompts = torch.tensor( + [ + [1, 2, 3], # prompt 0 + [1, 2, 3], # prompt 0 + [1, 2, 3], # prompt 0 + [4, 5, 6], # prompt 1 + [4, 5, 6], # prompt 1 + [4, 5, 6], # prompt 1 + ] + ) + valid_mask = torch.ones(6) + + baseline, std = calculate_baseline_and_std_per_prompt(prompts, rewards, valid_mask) + + expected_baseline = torch.tensor([2.5, 2.0, 1.5, 5.5, 5.0, 4.5]) + expected_std = torch.tensor( + [0.707107, 1.414214, 0.707107, 0.707107, 1.414214, 0.707107] + ) + + assert torch.allclose(baseline, expected_baseline, rtol=1e-5) + assert torch.allclose(std, expected_std, rtol=1e-5) + + +def test_calculate_baseline_and_std_per_prompt_single_generation_per_prompt(): + """Test calculate_baseline_and_std_per_prompt when num_valid < 2 (single generation per prompt).""" + # Case where each prompt has only 1 generation (num_valid = 1 < 2) + rewards = torch.tensor([2.5, 4.0]) + prompts = torch.tensor( + [ + [1, 2, 3], # prompt 0 + [4, 5, 6], # prompt 1 + ] + ) + valid_mask = torch.ones(2) + + baseline, std = calculate_baseline_and_std_per_prompt(prompts, rewards, valid_mask) + + # When num_valid <= 1 (single generation per prompt), baseline equals reward + expected_baseline = torch.tensor([2.5, 4.0]) + expected_std = torch.tensor([0.0, 0.0]) + + assert torch.allclose(baseline, expected_baseline, rtol=1e-5) + assert torch.allclose(std, expected_std, rtol=1e-5) + + +def test_calculate_baseline_and_std_per_prompt_identical_rewards(): + """Test calculate_baseline_and_std_per_prompt when all rewards for a prompt are identical.""" + # All generations for both prompts have the same reward + rewards = torch.tensor([3.0, 3.0, 3.0, 7.0, 7.0, 7.0]) + prompts = torch.tensor( + [ + [1, 2, 3], # prompt 0 + [1, 2, 3], # prompt 0 + [1, 2, 3], # prompt 0 + [4, 5, 6], # prompt 1 + [4, 5, 6], # prompt 1 + [4, 5, 6], # prompt 1 + ] + ) + valid_mask = torch.ones(6) + + baseline, std = calculate_baseline_and_std_per_prompt(prompts, rewards, valid_mask) + + expected_baseline = torch.tensor([3.0, 3.0, 3.0, 7.0, 7.0, 7.0]) + expected_std = torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + + assert torch.allclose(baseline, expected_baseline, rtol=1e-5) + assert torch.allclose(std, expected_std, rtol=1e-5) + + +def test_calculate_baseline_and_std_per_prompt_mixed_prompt_sizes(): + """Test calculate_baseline_and_std_per_prompt with different number of generations per prompt.""" + # Prompt 0 has 2 generations, Prompt 1 has 3 generations + rewards = torch.tensor([1.0, 2.0, 4.0, 5.0, 6.0]) + prompts = torch.tensor( + [ + [1, 2, 3], # prompt 0 + [1, 2, 3], # prompt 0 + [4, 5, 6], # prompt 1 + [4, 5, 6], # prompt 1 + [4, 5, 6], # prompt 1 + ] + ) + valid_mask = torch.ones(5) + + baseline, std = calculate_baseline_and_std_per_prompt(prompts, rewards, valid_mask) + + expected_baseline = torch.tensor([2.0, 1.0, 5.5, 5.0, 4.5]) + expected_std = torch.tensor([0.0, 0.0, 0.707107, 1.414214, 0.707107]) + + assert torch.allclose(baseline, expected_baseline, rtol=1e-5) + assert torch.allclose(std, expected_std, rtol=1e-5) + + +def test_calculate_baseline_and_std_per_prompt_empty_input(): + """Test calculate_baseline_and_std_per_prompt with empty tensors.""" + rewards = torch.tensor([]) + prompts = torch.empty(0, 3, dtype=torch.long) + valid_mask = torch.tensor([]) + + baseline, std = calculate_baseline_and_std_per_prompt(prompts, rewards, valid_mask) + + assert baseline.shape == torch.Size([0]) + assert std.shape == torch.Size([0]) + assert torch.equal(baseline, torch.tensor([])) + assert torch.equal(std, torch.tensor([])) + + +def test_calculate_baseline_and_std_per_prompt_nan_handling(): + """Test calculate_baseline_and_std_per_prompt handles valid_mask correctly with masked samples.""" + # Test that valid_mask properly excludes samples from baseline calculation + # Note: The function doesn't handle actual NaN values; it uses valid_mask to exclude samples + rewards = torch.tensor([1.0, 999.0, 3.0, 4.0, 5.0, 6.0]) # 999.0 should be ignored + prompts = torch.tensor( + [ + [1, 2, 3], # prompt 0 + [1, 2, 3], # prompt 0 (invalid sample) + [1, 2, 3], # prompt 0 + [4, 5, 6], # prompt 1 + [4, 5, 6], # prompt 1 + [4, 5, 6], # prompt 1 + ] + ) + # Mark the second sample as invalid + valid_mask = torch.tensor([1.0, 0.0, 1.0, 1.0, 1.0, 1.0]) + + baseline, std = calculate_baseline_and_std_per_prompt(prompts, rewards, valid_mask) + + expected_baseline = torch.tensor([3.0, 4.0, 1.0, 5.5, 5.0, 4.5]) + expected_std = torch.tensor([0.0, 0.0, 0.0, 0.707107, 1.414214, 0.707107]) + + assert torch.allclose(baseline, expected_baseline, rtol=1e-5) + assert torch.allclose(std, expected_std, rtol=1e-5) + + +def test_calculate_baseline_and_std_per_prompt_cuda_compatibility(): + """Test calculate_baseline_and_std_per_prompt works with CUDA tensors if available.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + rewards = torch.tensor([1.0, 2.0, 3.0, 4.0]).cuda() + prompts = torch.tensor( + [ + [1, 2, 3], # prompt 0 + [1, 2, 3], # prompt 0 + [4, 5, 6], # prompt 1 + [4, 5, 6], # prompt 1 + ] + ).cuda() + valid_mask = torch.ones(4).cuda() + + baseline, std = calculate_baseline_and_std_per_prompt(prompts, rewards, valid_mask) + + # Verify results are on CUDA and have expected values + assert baseline.device.type == "cuda" + assert std.device.type == "cuda" + + expected_baseline = torch.tensor([2.0, 1.0, 4.0, 3.0]).cuda() + expected_std = torch.tensor([0.0, 0.0, 0.0, 0.0]).cuda() + + assert torch.allclose(baseline, expected_baseline, rtol=1e-5) + assert torch.allclose(std, expected_std, rtol=1e-5) + + +def test_calculate_baseline_and_std_per_prompt_numerical_precision(): + """Test calculate_baseline_and_std_per_prompt with edge case numerical values.""" + # Use very small and very large values + rewards = torch.tensor([1e-8, 2e-8, 3e-8, 1e8, 2e8, 3e8]) + prompts = torch.tensor( + [ + [1, 2, 3], # prompt 0 + [1, 2, 3], # prompt 0 + [1, 2, 3], # prompt 0 + [4, 5, 6], # prompt 1 + [4, 5, 6], # prompt 1 + [4, 5, 6], # prompt 1 + ] + ) + valid_mask = torch.ones(6) + + baseline, std = calculate_baseline_and_std_per_prompt(prompts, rewards, valid_mask) + + expected_baseline = torch.tensor([2.5e-8, 2e-8, 1.5e-8, 2.5e8, 2e8, 1.5e8]) + + assert torch.allclose(baseline, expected_baseline, rtol=1e-5) + # Std values should be finite and not NaN + assert torch.isfinite(std).all() + assert not torch.isnan(std).any()