diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py
index fe874ecc26..af78baf34d 100644
--- a/tests/unit/algorithms/test_loss_functions.py
+++ b/tests/unit/algorithms/test_loss_functions.py
@@ -13,7 +13,14 @@
 # limitations under the License.
 import pytest
 import torch
-from nemo_reinforcer.algorithms.loss_functions import NLLLoss
+import numpy as np
+
+from nemo_reinforcer.algorithms.loss_functions import NLLLoss, ClippedPGLossFn
+from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict
+from nemo_reinforcer.algorithms.utils import (
+    calculate_kl_penalty_joschu2020,
+    masked_mean,
+)
 
 
 def test_nll_loss():
@@ -46,7 +53,7 @@ def test_nll_loss():
         .to("cuda")
     )
     loss, metrics_dict = loss_fn(next_token_logits, data)
-    torch.testing.assert_allclose(loss.cpu(), torch.tensor(0.0))
+    torch.testing.assert_close(loss.cpu(), torch.tensor(0.0))
     # Check the metrics dictionary contains the expected values
     assert metrics_dict["num_unmasked_tokens"] == 2
     assert metrics_dict["total_tokens"] == 3
@@ -66,8 +73,329 @@ def test_nll_loss():
     )
     loss, metrics_dict = loss_fn(next_token_logits, data)
     ## loss per token is 999, and we have two unmasked tokens
-    ## with the updated loss function, we now average the loss over unmasked tokens
-    torch.testing.assert_allclose(loss.cpu(), torch.tensor(999.0))
-    # Check the metrics dictionary contains the expected values
+    ## NLLLoss averages the loss over unmasked tokens
+    torch.testing.assert_close(loss.cpu(), torch.tensor(999.0))
     assert metrics_dict["num_unmasked_tokens"] == 2
     assert metrics_dict["total_tokens"] == 3
+
+
+def _setup_clipped_pg_test_data(batch_size=1, seq_len=4, vocab_size=8, device="cuda"):
+    """Build a mock batch for the ClippedPGLossFn tests.
+
+    All float fields (advantages and the three logprob tensors) start as zeros
+    of shape [batch_size, seq_len]; tests overwrite the entries they need.
+    The token mask hides position 0 of every sequence and the sample mask
+    keeps every sample.
+
+    Returns:
+        Tuple of (data, seq_len, vocab_size).
+    """
+    # Random token ids; _create_exact_logits reads them to pick target tokens
+    input_ids = torch.randint(
+        0, vocab_size, (batch_size, seq_len), dtype=torch.int64, device=device
+    )
+    # Default mask: Mask first token [[0, 1, 1, 1]]
+    token_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
+    token_mask[:, 0] = 0
+    # sample_mask needs shape [B]
+    sample_mask = torch.ones(batch_size, dtype=torch.int64, device=device)
+
+    # Simple default values, tests overwrite these
+    advantages = torch.zeros((batch_size, seq_len), device=device)
+    prev_logprobs = torch.zeros((batch_size, seq_len), device=device)
+    reference_policy_logprobs = torch.zeros((batch_size, seq_len), device=device)
+    generation_logprobs = torch.zeros((batch_size, seq_len), device=device)
+
+    data = BatchedDataDict(
+        {
+            "input_ids": input_ids,
+            "token_mask": token_mask,
+            "sample_mask": sample_mask,
+            "advantages": advantages,
+            "prev_logprobs": prev_logprobs,
+            "reference_policy_logprobs": reference_policy_logprobs,
+            "generation_logprobs": generation_logprobs,
+        }
+    )
+    # Return seq_len and vocab_size needed by tests
+    return data, seq_len, vocab_size
+
+
+def _create_exact_logits(target_curr_lp_masked, input_ids, seq_len, vocab_size, device):
+    """Construct logits whose log_softmax hits the given target log probs.
+
+    For effective position i, the token input_ids[0, i + 1] is given the log
+    probability target_curr_lp_masked[0, i]. Only batch index 0 is populated,
+    so the helper supports batch size 1 only.
+
+    Raises:
+        ValueError: if a target log probability is positive, or if a negative
+            target is requested with vocab_size < 2 (no distractor token).
+    """
+    dummy_logits = torch.full(
+        (1, seq_len, vocab_size), -100.0, device=device
+    )  # Start very low
+
+    # Loss fn uses logits[:, :-1] and gathers based on next_tokens = input_ids[:, 1:]
+    # We need to set logits for indices i=0..S-2 of the sliced logits tensor.
+    # These correspond to target logprobs at indices 0..S-2 of target_curr_lp_masked.
+    zero = torch.tensor(0.0, device=device)
+    num_effective_pos = target_curr_lp_masked.shape[1]
+    for i in range(num_effective_pos):
+        # Logit row i predicts the token stored at input_ids[0, i + 1]
+        target_token_id = input_ids[0, i + 1].item()
+        # Keep target_lp as a 0-dim tensor for torch ops
+        target_lp = target_curr_lp_masked[0, i]
+
+        if torch.isclose(target_lp, zero):
+            # log p ~= 0: one dominant logit puts (almost) all mass on the target
+            dummy_logits[0, i, target_token_id] = 100.0
+        elif target_lp < 0:
+            if vocab_size < 2:
+                # A single-token vocabulary always has log p = 0
+                raise ValueError(
+                    "vocab_size must be at least 2 to realise a negative target "
+                    "log probability"
+                )
+            # With the target logit at 0 and a single distractor logit d,
+            # log p(target) = -log(1 + exp(d)), hence d = log(exp(-lp) - 1).
+            # The remaining tokens stay at -100 and contribute negligibly.
+            dummy_logits[0, i, target_token_id] = 0.0
+            distractor_token_id = (target_token_id + 1) % vocab_size
+            # expm1 keeps precision when target_lp is close to zero
+            dummy_logits[0, i, distractor_token_id] = torch.log(
+                torch.expm1(-target_lp)
+            )
+        else:  # target_lp > 0 is not supported by this method
+            raise ValueError(
+                "Target log probability must be negative or zero for this construction"
+            )
+    return dummy_logits
+
+
+def test_clipped_pg_loss_ppo_clipping():
+    """Checks the clipped PPO surrogate against a hand calculation."""
+    if not torch.cuda.is_available():
+        pytest.skip("No GPU available")
+
+    device = "cuda"
+    data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device)
+
+    ratio_eps = 0.2
+    cfg = {
+        "ratio_eps_min": ratio_eps,
+        "ratio_eps_max": ratio_eps,
+        "reference_policy_kl_penalty": 0.0,  # Disable KL
+        "disable_ppo_ratio": False,
+    }
+    loss_fn = ClippedPGLossFn(cfg)
+
+    adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device)
+    # Use non-zero prev_lp to allow ratios > 1 with valid curr_lp <= 0
+    prev_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
+    # Target ratios: 0.5 (<0.8), 1.0 (in [0.8, 1.2]), 1.5 (>1.2)
+    target_ratios = torch.tensor([[0.5, 1.0, 1.5]], device=device)
+    # Curr = log(Ratio) + Prev, for masked positions 1, 2, 3
+    curr_lp_masked = torch.log(target_ratios) + prev_lp_masked
+
+    # Fill full tensors (only need first dim for B=1)
+    data["advantages"][0, 1:] = adv_masked
+    data["prev_logprobs"][0, 1:] = prev_lp_masked
+
+    # --- Hand Calculation ---
+    ratios = torch.exp(curr_lp_masked - prev_lp_masked)  # approx [0.5, 1.0, 1.5]
+    ratios_clamped = torch.clamp(
+        ratios, 1.0 - ratio_eps, 1.0 + ratio_eps
+    )  # [0.8, 1.0, 1.2]
+    loss1 = -adv_masked * ratios  # approx -[1*0.5, -1*1.0, 2*1.5] = [-0.5, 1.0, -3.0]
+    loss2 = -adv_masked * ratios_clamped  # -[1*0.8, -1*1.0, 2*1.2] = [-0.8, 1.0, -2.4]
+    max_loss = torch.maximum(loss1, loss2)  # approx [-0.5, 1.0, -2.4]
+    expected_loss = torch.mean(
+        max_loss
+    )  # approx (-0.5 + 1.0 - 2.4) / 3 = -1.9 / 3 = -0.6333
+
+    input_ids = data["input_ids"]
+    dummy_logits = _create_exact_logits(
+        curr_lp_masked, input_ids, seq_len, vocab_size, device
+    )
+
+    actual_loss, _ = loss_fn(dummy_logits, data)
+    torch.testing.assert_close(actual_loss, expected_loss)
+
+
+def test_clipped_pg_loss_reinforce_mode():
+    """Checks REINFORCE mode (PPO ratio disabled) against a hand calculation."""
+    if not torch.cuda.is_available():
+        pytest.skip("No GPU available")
+
+    device = "cuda"
+    data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device)
+
+    cfg = {
+        "disable_ppo_ratio": True,
+        "reference_policy_kl_penalty": 0.0,
+        "ratio_eps_min": 0.0,  # Placeholder, ignored
+        "ratio_eps_max": 0.0,  # Placeholder, ignored
+    }
+    loss_fn = ClippedPGLossFn(cfg)
+
+    adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device)
+    curr_lp_masked = torch.tensor([[-0.5, -1.0, -1.5]], device=device)
+
+    data["advantages"][0, 1:] = adv_masked
+    # NOTE(review): nothing in this file reads "_test_curr_logprobs"; it looks
+    # like a leftover from an earlier test harness - confirm and drop.
+    data["_test_curr_logprobs"] = curr_lp_masked
+    data["prev_logprobs"][0, 1:] = torch.zeros_like(curr_lp_masked)
+
+    # --- Hand Calculation ---
+    expected_loss_per_token = -adv_masked * curr_lp_masked  # [0.5, -1.0, 3.0]
+    expected_loss = torch.mean(expected_loss_per_token)  # 2.5 / 3 = 0.8333
+
+    input_ids = data["input_ids"]
+    dummy_logits = _create_exact_logits(
+        curr_lp_masked, input_ids, seq_len, vocab_size, device
+    )
+
+    actual_loss, _ = loss_fn(dummy_logits, data)
+    torch.testing.assert_close(actual_loss, expected_loss)
+
+
+def test_clipped_pg_loss_kl_penalty():
+    """Checks the reference-policy KL penalty against a hand calculation."""
+    if not torch.cuda.is_available():
+        pytest.skip("No GPU available")
+
+    device = "cuda"
+    data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device)
+
+    # --- Test Setup ---
+    kl_beta = 0.1
+    cfg = {
+        "reference_policy_kl_penalty": kl_beta,
+        "ratio_eps_min": 0.2,
+        "ratio_eps_max": 0.2,
+        "disable_ppo_ratio": False,
+    }
+    loss_fn = ClippedPGLossFn(cfg)
+
+    adv_masked = torch.tensor([[0.0, 0.0, 0.0]], device=device)
+    curr_lp_masked = torch.tensor([[0.0, -1.0, -2.0]], device=device)
+    ref_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
+    prev_lp_masked = torch.tensor([[0.0, 0.0, 0.0]], device=device)
+
+    data["advantages"][0, 1:] = adv_masked
+    data["reference_policy_logprobs"][0, 1:] = ref_lp_masked
+    data["prev_logprobs"][0, 1:] = prev_lp_masked
+    # NOTE(review): nothing in this file reads "_test_curr_logprobs"; it looks
+    # like a leftover from an earlier test harness - confirm and drop.
+    data["_test_curr_logprobs"] = curr_lp_masked
+
+    # --- Hand Calculation ---
+    # Actor loss is 0. Total loss = kl_beta * mean(kl_term)
+    # kl_term = exp(ref - curr) - (ref - curr) - 1
+    r = ref_lp_masked - curr_lp_masked  # [-1.0, 0.0, 1.0]
+    kl_term_per_token = torch.exp(r) - r - 1  # [0.368, 0.0, 0.718]
+    expected_kl_mean = torch.mean(kl_term_per_token)  # 0.362
+    expected_loss = kl_beta * expected_kl_mean  # 0.0362
+
+    input_ids = data["input_ids"]
+    dummy_logits = _create_exact_logits(
+        curr_lp_masked, input_ids, seq_len, vocab_size, device
+    )
+
+    actual_loss, _ = loss_fn(dummy_logits, data)
+    torch.testing.assert_close(actual_loss, expected_loss)
+
+
+def test_clipped_pg_loss_masking():
+    """Tests the effect of token_mask and sample_mask."""
+    if not torch.cuda.is_available():
+        pytest.skip("No GPU available")
+
+    batch_size = 2
+    seq_len = 4
+    device = "cuda"
+    data, seq_len, vocab_size = _setup_clipped_pg_test_data(
+        batch_size=batch_size, seq_len=seq_len, device=device
+    )
+    # Need some realistic-ish logits and logprobs for masking test
+    dummy_logits = torch.randn(batch_size, seq_len, vocab_size, device=device)
+    # Ensure logprobs used by the loss fn make sense relative to advantages
+    data["prev_logprobs"] = torch.randn_like(data["prev_logprobs"]) * 0.1
+    data["reference_policy_logprobs"] = (
+        torch.randn_like(data["reference_policy_logprobs"]) * 0.1
+    )
+    # Make advantages non-zero
+    data["advantages"] = torch.randn_like(data["advantages"]) + 1.0
+
+    cfg = {
+        "ratio_eps_min": 0.2,
+        "ratio_eps_max": 0.2,
+        "reference_policy_kl_penalty": 0.1,
+        "disable_ppo_ratio": False,
+    }
+    loss_fn = ClippedPGLossFn(cfg)
+
+    # --- Test 1: Token Mask ---
+    # Default mask: [[0, 1, 1, 1], [0, 1, 1, 1]] -> 3 tokens per sample
+    loss_default, _ = loss_fn(dummy_logits, data)
+
+    # Mask one more token (pos 1) of batch item 0:
+    # [[0, 0, 1, 1], [0, 1, 1, 1]] -> 2 tokens in sample 0, 3 in sample 1
+    data_mod_token = data.copy()
+    data_mod_token["token_mask"] = data["token_mask"].clone()
+    data_mod_token["token_mask"][0, 1] = 0
+
+    loss_token_masked, _ = loss_fn(dummy_logits, data_mod_token)
+    # Loss should change if a potentially contributing token is masked
+    assert not torch.isclose(loss_default, loss_token_masked, atol=1e-4), (
+        "Token mask did not change loss as expected"
+    )
+
+    # --- Test 2: Sample Mask ---
+    data_mod_sample = data.copy()
+    data_mod_sample["sample_mask"] = torch.tensor(
+        [1, 0], dtype=torch.int64, device=device
+    )  # Ignore item 1
+
+    loss_sample_masked, _ = loss_fn(dummy_logits, data_mod_sample)
+
+    # Reference: the same batch restricted to item 0 (non-tensor values pass through)
+    data_only_b0 = BatchedDataDict(
+        {
+            key: value[0:1] if isinstance(value, torch.Tensor) else value
+            for key, value in data.items()
+        }
+    )
+
+    loss_only_b0, _ = loss_fn(dummy_logits[0:1], data_only_b0)
+
+    torch.testing.assert_close(loss_sample_masked, loss_only_b0)
+
+
+def test_clipped_pg_loss_zero_mask():
+    """Tests the case where the combined mask sum is zero."""
+    if not torch.cuda.is_available():
+        pytest.skip("No GPU available")
+
+    device = "cuda"
+    data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device)
+    # Need dummy logits
+    dummy_logits = torch.randn(1, seq_len, vocab_size, device=device)
+
+    cfg = {
+        "ratio_eps_min": 0.2,
+        "ratio_eps_max": 0.2,
+        "reference_policy_kl_penalty": 0.1,
+        "disable_ppo_ratio": False,
+    }
+    loss_fn = ClippedPGLossFn(cfg)
+
+    # Set token mask to all zeros
+    data["token_mask"] = torch.zeros_like(data["token_mask"])
+
+    loss, _ = loss_fn(dummy_logits, data)
+
+    # Loss should be exactly zero
+    torch.testing.assert_close(loss, torch.tensor(0.0, device=device))