From 2d4e1c3423e6e6aad384ae3b3b4cde3e1b7d021c Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Thu, 10 Apr 2025 13:20:31 -0700 Subject: [PATCH 1/2] Add loss tests (prep for incoming vocab parallel Signed-off-by: Sahil Jain --- tests/unit/algorithms/test_loss_functions.py | 287 ++++++++++++++++++- 1 file changed, 282 insertions(+), 5 deletions(-) diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index fe874ecc26..564d9f699c 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -13,7 +13,11 @@ # limitations under the License. import pytest import torch -from nemo_reinforcer.algorithms.loss_functions import NLLLoss +import numpy as np + +from nemo_reinforcer.algorithms.loss_functions import NLLLoss, ClippedPGLossFn +from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict +from nemo_reinforcer.algorithms.utils import calculate_kl_penalty_joschu2020, masked_mean def test_nll_loss(): @@ -46,7 +50,7 @@ def test_nll_loss(): .to("cuda") ) loss, metrics_dict = loss_fn(next_token_logits, data) - torch.testing.assert_allclose(loss.cpu(), torch.tensor(0.0)) + torch.testing.assert_close(loss.cpu(), torch.tensor(0.0)) # Check the metrics dictionary contains the expected values assert metrics_dict["num_unmasked_tokens"] == 2 assert metrics_dict["total_tokens"] == 3 @@ -66,8 +70,281 @@ def test_nll_loss(): ) loss, metrics_dict = loss_fn(next_token_logits, data) ## loss per token is 999, and we have two unmasked tokens - ## with the updated loss function, we now average the loss over unmasked tokens - torch.testing.assert_allclose(loss.cpu(), torch.tensor(999.0)) - # Check the metrics dictionary contains the expected values + ## NLLLoss averages the loss over unmasked tokens + torch.testing.assert_close(loss.cpu(), torch.tensor(999.0)) assert metrics_dict["num_unmasked_tokens"] == 2 assert metrics_dict["total_tokens"] == 3 + + +def _setup_clipped_pg_test_data( + batch_size=1, seq_len=4, vocab_size=8, device="cuda" +): + """Sets up basic mock data structure. Tests should fill values.""" + input_ids = torch.randint( # Input IDs only needed if original loss fn used + 0, vocab_size, (batch_size, seq_len), dtype=torch.int64, device=device + ) + # Default mask: Mask first token [[0, 1, 1, 1]] + token_mask = torch.ones( + (batch_size, seq_len), dtype=torch.int64, device=device + ) + token_mask[:, 0] = 0 + # sample_mask needs shape [B] + sample_mask = torch.ones(batch_size, dtype=torch.int64, device=device) + + # Simple default values, tests overwrite these + advantages = torch.zeros((batch_size, seq_len), device=device) + prev_logprobs = torch.zeros((batch_size, seq_len), device=device) + reference_policy_logprobs = torch.zeros((batch_size, seq_len), device=device) + generation_logprobs = torch.zeros((batch_size, seq_len), device=device) + + data = BatchedDataDict( + { + "input_ids": input_ids, # Include for completeness + "token_mask": token_mask, + "sample_mask": sample_mask, + "advantages": advantages, + "prev_logprobs": prev_logprobs, + "reference_policy_logprobs": reference_policy_logprobs, + "generation_logprobs": generation_logprobs, + } + ) + # Return seq_len and vocab_size needed by tests + return data, seq_len, vocab_size + + +# Helper to create logits that yield specific target log probs after log_softmax +def _create_exact_logits(target_curr_lp_masked, input_ids, seq_len, vocab_size, device): + """Constructs logits such that log_softmax results in target_curr_lp_masked.""" + dummy_logits = torch.full((1, seq_len, vocab_size), -100.0, device=device) # Start very low + + # Loss fn uses logits[:, :-1] and gathers based on next_tokens = input_ids[:, 1:] + # We need to set logits for indices i=0..S-2 of the sliced logits tensor. + # These correspond to target logprobs at indices 0..S-2 of target_curr_lp_masked. + num_effective_pos = target_curr_lp_masked.shape[1] + for i in range(num_effective_pos): + logit_idx = i # Index in the sliced logits tensor (dummy_logits[:, 0:S-1, :]) + data_idx = i + 1 # Index in the original input_ids to find the target token + + target_token_id = input_ids[0, data_idx].item() + # Keep target_lp as a 0-dim tensor for torch ops + target_lp = target_curr_lp_masked[0, i] + + # Handle target_lp = 0 case separately + if torch.isclose(target_lp, torch.tensor(0.0, device=device)): + dummy_logits[0, logit_idx, target_token_id] = 100.0 # Large positive logit + elif target_lp < 0: + # Set target token logit to 0 + dummy_logits[0, logit_idx, target_token_id] = 0.0 + # Set one distractor token logit using the formula + distractor_token_id = (target_token_id + 1) % vocab_size + # Ensure distractor isn't same as target if vocab_size=1 (edge case) + if distractor_token_id == target_token_id: + distractor_token_id = (target_token_id + 2) % vocab_size + distractor_logit = torch.log(torch.exp(-target_lp) - 1.0) + dummy_logits[0, logit_idx, distractor_token_id] = distractor_logit + else: # target_lp > 0 is not supported by this method + raise ValueError("Target log probability must be negative or zero for this construction") + return dummy_logits + + +# Simplified PPO Clipping Test using original Loss +def test_clipped_pg_loss_ppo_clipping(): + """Tests PPO clipping calculations directly.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + device = "cuda" + data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + + ratio_eps = 0.2 + cfg = { + "ratio_eps_min": ratio_eps, + "ratio_eps_max": ratio_eps, + "reference_policy_kl_penalty": 0.0, # Disable KL + "disable_ppo_ratio": False, + } + loss_fn = ClippedPGLossFn(cfg) + + adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device) + # Use non-zero prev_lp to allow ratios > 1 with valid curr_lp <= 0 + prev_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + # Target Curr logprobs (masked pos 1, 2, 3) - design for clipping + # Target ratios: 0.5 (<0.8), 1.0 (in [0.8, 1.2]), 1.5 (>1.2) + # Curr = log(Ratio) + Prev + curr_lp_masked = torch.tensor([[-1.69315, -1.0, -0.59453]], device=device) # approx log(0.5)-1, log(1)-1, log(1.5)-1 + + # Fill full tensors (only need first dim for B=1) + data["advantages"][0, 1:] = adv_masked + data["prev_logprobs"][0, 1:] = prev_lp_masked + + # --- Hand Calculation --- + ratios = torch.exp(curr_lp_masked - prev_lp_masked) # approx [0.5, 1.0, 1.5] + ratios_clamped = torch.clamp(ratios, 1.0 - ratio_eps, 1.0 + ratio_eps) # [0.8, 1.0, 1.2] + loss1 = -adv_masked * ratios # approx -[1*0.5, -1*1.0, 2*1.5] = [-0.5, 1.0, -3.0] + loss2 = -adv_masked * ratios_clamped # -[1*0.8, -1*1.0, 2*1.2] = [-0.8, 1.0, -2.4] + max_loss = torch.maximum(loss1, loss2) # approx [-0.5, 1.0, -2.4] + expected_loss = torch.mean(max_loss) # approx (-0.5 + 1.0 - 2.4) / 3 = -1.9 / 3 = -0.6333 + + input_ids = data["input_ids"] + dummy_logits = _create_exact_logits(curr_lp_masked, input_ids, seq_len, vocab_size, device) + + actual_loss, _ = loss_fn(dummy_logits, data) + torch.testing.assert_close(actual_loss, expected_loss) + + +# Simplified REINFORCE Test using original Loss +def test_clipped_pg_loss_reinforce_mode(): + """Tests REINFORCE mode calculations directly.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + device = "cuda" + data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + + cfg = { + "disable_ppo_ratio": True, + "reference_policy_kl_penalty": 0.0, + "ratio_eps_min": 0.0, # Placeholder, ignored + "ratio_eps_max": 0.0, # Placeholder, ignored + } + loss_fn = ClippedPGLossFn(cfg) + + adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device) + curr_lp_masked = torch.tensor([[-0.5, -1.0, -1.5]], device=device) + + data["advantages"][0, 1:] = adv_masked + data["_test_curr_logprobs"] = curr_lp_masked + data["prev_logprobs"][0, 1:] = torch.zeros_like(curr_lp_masked) + + # --- Hand Calculation --- + expected_loss_per_token = -adv_masked * curr_lp_masked # [0.5, -1.0, 3.0] + expected_loss = torch.mean(expected_loss_per_token) # 2.5 / 3 = 0.8333 + + input_ids = data["input_ids"] + dummy_logits = _create_exact_logits(curr_lp_masked, input_ids, seq_len, vocab_size, device) + + actual_loss, _ = loss_fn(dummy_logits, data) + torch.testing.assert_close(actual_loss, expected_loss) + + +# Simplified KL Penalty Test using original Loss +def test_clipped_pg_loss_kl_penalty(): + """Tests KL penalty calculations directly.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + device = "cuda" + data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + + # --- Test Setup --- + kl_beta = 0.1 + cfg = { "reference_policy_kl_penalty": kl_beta, "ratio_eps_min": 0.2, "ratio_eps_max": 0.2, "disable_ppo_ratio": False } + loss_fn = ClippedPGLossFn(cfg) + + adv_masked = torch.tensor([[0.0, 0.0, 0.0]], device=device) + curr_lp_masked = torch.tensor([[0.0, -1.0, -2.0]], device=device) + ref_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device) + prev_lp_masked = torch.tensor([[0.0, 0.0, 0.0]], device=device) + + data["advantages"][0, 1:] = adv_masked + data["reference_policy_logprobs"][0, 1:] = ref_lp_masked + data["prev_logprobs"][0, 1:] = prev_lp_masked + data["_test_curr_logprobs"] = curr_lp_masked + + # --- Hand Calculation --- + # Actor loss is 0. Total loss = kl_beta * mean(kl_term) + # kl_term = exp(ref - curr) - (ref - curr) - 1 + r = ref_lp_masked - curr_lp_masked # [-1.0, 0.0, 1.0] + kl_term_per_token = torch.exp(r) - r - 1 # [0.368, 0.0, 0.718] + expected_kl_mean = torch.mean(kl_term_per_token) # 0.362 + expected_loss = kl_beta * expected_kl_mean # 0.0362 + + input_ids = data["input_ids"] + dummy_logits = _create_exact_logits(curr_lp_masked, input_ids, seq_len, vocab_size, device) + + actual_loss, _ = loss_fn(dummy_logits, data) + torch.testing.assert_close(actual_loss, expected_loss) + + +# Masking tests - Should work with original Loss Fn if needed, but less critical +def test_clipped_pg_loss_masking(): + """Tests the effect of token_mask and sample_mask.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + batch_size = 2 + seq_len = 4 + device = "cuda" + # Use original loss function for masking tests, as it involves interactions + # that the Testable class might obscure slightly. + data, seq_len, vocab_size = _setup_clipped_pg_test_data( batch_size=batch_size, seq_len=seq_len, device=device) + # Need some realistic-ish logits and logprobs for masking test + dummy_logits = torch.randn(batch_size, seq_len, vocab_size, device=device) + # Ensure logprobs used by the loss fn make sense relative to advantages + data["prev_logprobs"] = torch.randn_like(data["prev_logprobs"]) * 0.1 + data["reference_policy_logprobs"] = torch.randn_like(data["reference_policy_logprobs"]) * 0.1 + # Make advantages non-zero + data["advantages"] = torch.randn_like(data["advantages"]) + 1.0 + + cfg = { "ratio_eps_min": 0.2, "ratio_eps_max": 0.2, "reference_policy_kl_penalty": 0.1, "disable_ppo_ratio": False,} + loss_fn = ClippedPGLossFn(cfg) # Use original loss fn + + # --- Test 1: Token Mask --- + # Default mask: [[0, 1, 1, 1], [0, 1, 1, 1]] -> 3 tokens per sample + loss_default, _ = loss_fn(dummy_logits, data) + + # Modify token_mask for batch item 0 to mask one more token (pos 1) + data_mod_token = data.copy() + data_mod_token["token_mask"] = data["token_mask"].clone() + data_mod_token["token_mask"][0, 1] = 0 # New mask: [[0, 0, 1, 1], [0, 1, 1, 1]] -> 2 tokens sample 0, 3 tokens sample 1 + + loss_token_masked, _ = loss_fn(dummy_logits, data_mod_token) + # Loss should change if a potentially contributing token is masked + assert not torch.isclose(loss_default, loss_token_masked, atol=1e-4), \ + "Token mask did not change loss as expected" + + # --- Test 2: Sample Mask --- + data_mod_sample = data.copy() + data_mod_sample["sample_mask"] = torch.tensor([1, 0], dtype=torch.int64, device=device) # Ignore item 1 + + loss_sample_masked, _ = loss_fn(dummy_logits, data_mod_sample) + + # Manually create data dict for only batch 0 + data_only_b0_dict = {} + for key, value in data.items(): + if isinstance(value, torch.Tensor): + if key == "sample_mask": + data_only_b0_dict[key] = value[0:1] + else: + data_only_b0_dict[key] = value[0:1] + else: + data_only_b0_dict[key] = value + data_only_b0 = BatchedDataDict(data_only_b0_dict) + + logits_only_b0 = dummy_logits[0:1] + loss_only_b0, _ = loss_fn(logits_only_b0, data_only_b0) + + torch.testing.assert_close(loss_sample_masked, loss_only_b0) + + +def test_clipped_pg_loss_zero_mask(): + """Tests the case where the combined mask sum is zero.""" + if not torch.cuda.is_available(): + pytest.skip("No GPU available") + + device = "cuda" + data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) + # Need dummy logits + dummy_logits = torch.randn(1, seq_len, vocab_size, device=device) + + cfg = { "ratio_eps_min": 0.2,"ratio_eps_max": 0.2, "reference_policy_kl_penalty": 0.1, "disable_ppo_ratio": False, } + loss_fn = ClippedPGLossFn(cfg) # Use original loss fn + + # Set token mask to all zeros + data["token_mask"] = torch.zeros_like(data["token_mask"]) + + loss, _ = loss_fn(dummy_logits, data) + + # Loss should be exactly zero + torch.testing.assert_close(loss, torch.tensor(0.0, device=device)) + From be8aa5ecdf53b2cb7a2879d6ac7c49f0c4c0cbf4 Mon Sep 17 00:00:00 2001 From: Sahil Jain Date: Thu, 10 Apr 2025 13:23:38 -0700 Subject: [PATCH 2/2] formatting Signed-off-by: Sahil Jain --- tests/unit/algorithms/test_loss_functions.py | 136 ++++++++++++------- 1 file changed, 87 insertions(+), 49 deletions(-) diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index 564d9f699c..af78baf34d 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -17,7 +17,10 @@ from nemo_reinforcer.algorithms.loss_functions import NLLLoss, ClippedPGLossFn from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict -from nemo_reinforcer.algorithms.utils import calculate_kl_penalty_joschu2020, masked_mean +from nemo_reinforcer.algorithms.utils import ( + calculate_kl_penalty_joschu2020, + masked_mean, +) def test_nll_loss(): @@ -76,17 +79,13 @@ def test_nll_loss(): assert metrics_dict["total_tokens"] == 3 -def _setup_clipped_pg_test_data( - batch_size=1, seq_len=4, vocab_size=8, device="cuda" -): +def _setup_clipped_pg_test_data(batch_size=1, seq_len=4, vocab_size=8, device="cuda"): """Sets up basic mock data structure. Tests should fill values.""" - input_ids = torch.randint( # Input IDs only needed if original loss fn used + input_ids = torch.randint( # Input IDs only needed if original loss fn used 0, vocab_size, (batch_size, seq_len), dtype=torch.int64, device=device ) # Default mask: Mask first token [[0, 1, 1, 1]] - token_mask = torch.ones( - (batch_size, seq_len), dtype=torch.int64, device=device - ) + token_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device) token_mask[:, 0] = 0 # sample_mask needs shape [B] sample_mask = torch.ones(batch_size, dtype=torch.int64, device=device) @@ -99,7 +98,7 @@ def _setup_clipped_pg_test_data( data = BatchedDataDict( { - "input_ids": input_ids, # Include for completeness + "input_ids": input_ids, # Include for completeness "token_mask": token_mask, "sample_mask": sample_mask, "advantages": advantages, @@ -115,15 +114,17 @@ def _setup_clipped_pg_test_data( # Helper to create logits that yield specific target log probs after log_softmax def _create_exact_logits(target_curr_lp_masked, input_ids, seq_len, vocab_size, device): """Constructs logits such that log_softmax results in target_curr_lp_masked.""" - dummy_logits = torch.full((1, seq_len, vocab_size), -100.0, device=device) # Start very low - + dummy_logits = torch.full( + (1, seq_len, vocab_size), -100.0, device=device + ) # Start very low + # Loss fn uses logits[:, :-1] and gathers based on next_tokens = input_ids[:, 1:] # We need to set logits for indices i=0..S-2 of the sliced logits tensor. # These correspond to target logprobs at indices 0..S-2 of target_curr_lp_masked. num_effective_pos = target_curr_lp_masked.shape[1] for i in range(num_effective_pos): - logit_idx = i # Index in the sliced logits tensor (dummy_logits[:, 0:S-1, :]) - data_idx = i + 1 # Index in the original input_ids to find the target token + logit_idx = i # Index in the sliced logits tensor (dummy_logits[:, 0:S-1, :]) + data_idx = i + 1 # Index in the original input_ids to find the target token target_token_id = input_ids[0, data_idx].item() # Keep target_lp as a 0-dim tensor for torch ops @@ -131,7 +132,7 @@ def _create_exact_logits(target_curr_lp_masked, input_ids, seq_len, vocab_size, # Handle target_lp = 0 case separately if torch.isclose(target_lp, torch.tensor(0.0, device=device)): - dummy_logits[0, logit_idx, target_token_id] = 100.0 # Large positive logit + dummy_logits[0, logit_idx, target_token_id] = 100.0 # Large positive logit elif target_lp < 0: # Set target token logit to 0 dummy_logits[0, logit_idx, target_token_id] = 0.0 @@ -142,8 +143,10 @@ def _create_exact_logits(target_curr_lp_masked, input_ids, seq_len, vocab_size, distractor_token_id = (target_token_id + 2) % vocab_size distractor_logit = torch.log(torch.exp(-target_lp) - 1.0) dummy_logits[0, logit_idx, distractor_token_id] = distractor_logit - else: # target_lp > 0 is not supported by this method - raise ValueError("Target log probability must be negative or zero for this construction") + else: # target_lp > 0 is not supported by this method + raise ValueError( + "Target log probability must be negative or zero for this construction" + ) return dummy_logits @@ -160,7 +163,7 @@ def test_clipped_pg_loss_ppo_clipping(): cfg = { "ratio_eps_min": ratio_eps, "ratio_eps_max": ratio_eps, - "reference_policy_kl_penalty": 0.0, # Disable KL + "reference_policy_kl_penalty": 0.0, # Disable KL "disable_ppo_ratio": False, } loss_fn = ClippedPGLossFn(cfg) @@ -171,22 +174,30 @@ def test_clipped_pg_loss_ppo_clipping(): # Target Curr logprobs (masked pos 1, 2, 3) - design for clipping # Target ratios: 0.5 (<0.8), 1.0 (in [0.8, 1.2]), 1.5 (>1.2) # Curr = log(Ratio) + Prev - curr_lp_masked = torch.tensor([[-1.69315, -1.0, -0.59453]], device=device) # approx log(0.5)-1, log(1)-1, log(1.5)-1 + curr_lp_masked = torch.tensor( + [[-1.69315, -1.0, -0.59453]], device=device + ) # approx log(0.5)-1, log(1)-1, log(1.5)-1 # Fill full tensors (only need first dim for B=1) data["advantages"][0, 1:] = adv_masked data["prev_logprobs"][0, 1:] = prev_lp_masked # --- Hand Calculation --- - ratios = torch.exp(curr_lp_masked - prev_lp_masked) # approx [0.5, 1.0, 1.5] - ratios_clamped = torch.clamp(ratios, 1.0 - ratio_eps, 1.0 + ratio_eps) # [0.8, 1.0, 1.2] - loss1 = -adv_masked * ratios # approx -[1*0.5, -1*1.0, 2*1.5] = [-0.5, 1.0, -3.0] - loss2 = -adv_masked * ratios_clamped # -[1*0.8, -1*1.0, 2*1.2] = [-0.8, 1.0, -2.4] - max_loss = torch.maximum(loss1, loss2) # approx [-0.5, 1.0, -2.4] - expected_loss = torch.mean(max_loss) # approx (-0.5 + 1.0 - 2.4) / 3 = -1.9 / 3 = -0.6333 + ratios = torch.exp(curr_lp_masked - prev_lp_masked) # approx [0.5, 1.0, 1.5] + ratios_clamped = torch.clamp( + ratios, 1.0 - ratio_eps, 1.0 + ratio_eps + ) # [0.8, 1.0, 1.2] + loss1 = -adv_masked * ratios # approx -[1*0.5, -1*1.0, 2*1.5] = [-0.5, 1.0, -3.0] + loss2 = -adv_masked * ratios_clamped # -[1*0.8, -1*1.0, 2*1.2] = [-0.8, 1.0, -2.4] + max_loss = torch.maximum(loss1, loss2) # approx [-0.5, 1.0, -2.4] + expected_loss = torch.mean( + max_loss + ) # approx (-0.5 + 1.0 - 2.4) / 3 = -1.9 / 3 = -0.6333 input_ids = data["input_ids"] - dummy_logits = _create_exact_logits(curr_lp_masked, input_ids, seq_len, vocab_size, device) + dummy_logits = _create_exact_logits( + curr_lp_masked, input_ids, seq_len, vocab_size, device + ) actual_loss, _ = loss_fn(dummy_logits, data) torch.testing.assert_close(actual_loss, expected_loss) @@ -202,10 +213,10 @@ def test_clipped_pg_loss_reinforce_mode(): data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device) cfg = { - "disable_ppo_ratio": True, + "disable_ppo_ratio": True, "reference_policy_kl_penalty": 0.0, - "ratio_eps_min": 0.0, # Placeholder, ignored - "ratio_eps_max": 0.0, # Placeholder, ignored + "ratio_eps_min": 0.0, # Placeholder, ignored + "ratio_eps_max": 0.0, # Placeholder, ignored } loss_fn = ClippedPGLossFn(cfg) @@ -217,11 +228,13 @@ def test_clipped_pg_loss_reinforce_mode(): data["prev_logprobs"][0, 1:] = torch.zeros_like(curr_lp_masked) # --- Hand Calculation --- - expected_loss_per_token = -adv_masked * curr_lp_masked # [0.5, -1.0, 3.0] - expected_loss = torch.mean(expected_loss_per_token) # 2.5 / 3 = 0.8333 + expected_loss_per_token = -adv_masked * curr_lp_masked # [0.5, -1.0, 3.0] + expected_loss = torch.mean(expected_loss_per_token) # 2.5 / 3 = 0.8333 input_ids = data["input_ids"] - dummy_logits = _create_exact_logits(curr_lp_masked, input_ids, seq_len, vocab_size, device) + dummy_logits = _create_exact_logits( + curr_lp_masked, input_ids, seq_len, vocab_size, device + ) actual_loss, _ = loss_fn(dummy_logits, data) torch.testing.assert_close(actual_loss, expected_loss) @@ -238,7 +251,12 @@ def test_clipped_pg_loss_kl_penalty(): # --- Test Setup --- kl_beta = 0.1 - cfg = { "reference_policy_kl_penalty": kl_beta, "ratio_eps_min": 0.2, "ratio_eps_max": 0.2, "disable_ppo_ratio": False } + cfg = { + "reference_policy_kl_penalty": kl_beta, + "ratio_eps_min": 0.2, + "ratio_eps_max": 0.2, + "disable_ppo_ratio": False, + } loss_fn = ClippedPGLossFn(cfg) adv_masked = torch.tensor([[0.0, 0.0, 0.0]], device=device) @@ -254,13 +272,15 @@ def test_clipped_pg_loss_kl_penalty(): # --- Hand Calculation --- # Actor loss is 0. Total loss = kl_beta * mean(kl_term) # kl_term = exp(ref - curr) - (ref - curr) - 1 - r = ref_lp_masked - curr_lp_masked # [-1.0, 0.0, 1.0] - kl_term_per_token = torch.exp(r) - r - 1 # [0.368, 0.0, 0.718] - expected_kl_mean = torch.mean(kl_term_per_token) # 0.362 - expected_loss = kl_beta * expected_kl_mean # 0.0362 + r = ref_lp_masked - curr_lp_masked # [-1.0, 0.0, 1.0] + kl_term_per_token = torch.exp(r) - r - 1 # [0.368, 0.0, 0.718] + expected_kl_mean = torch.mean(kl_term_per_token) # 0.362 + expected_loss = kl_beta * expected_kl_mean # 0.0362 input_ids = data["input_ids"] - dummy_logits = _create_exact_logits(curr_lp_masked, input_ids, seq_len, vocab_size, device) + dummy_logits = _create_exact_logits( + curr_lp_masked, input_ids, seq_len, vocab_size, device + ) actual_loss, _ = loss_fn(dummy_logits, data) torch.testing.assert_close(actual_loss, expected_loss) @@ -277,17 +297,26 @@ def test_clipped_pg_loss_masking(): device = "cuda" # Use original loss function for masking tests, as it involves interactions # that the Testable class might obscure slightly. - data, seq_len, vocab_size = _setup_clipped_pg_test_data( batch_size=batch_size, seq_len=seq_len, device=device) + data, seq_len, vocab_size = _setup_clipped_pg_test_data( + batch_size=batch_size, seq_len=seq_len, device=device + ) # Need some realistic-ish logits and logprobs for masking test dummy_logits = torch.randn(batch_size, seq_len, vocab_size, device=device) # Ensure logprobs used by the loss fn make sense relative to advantages data["prev_logprobs"] = torch.randn_like(data["prev_logprobs"]) * 0.1 - data["reference_policy_logprobs"] = torch.randn_like(data["reference_policy_logprobs"]) * 0.1 + data["reference_policy_logprobs"] = ( + torch.randn_like(data["reference_policy_logprobs"]) * 0.1 + ) # Make advantages non-zero data["advantages"] = torch.randn_like(data["advantages"]) + 1.0 - cfg = { "ratio_eps_min": 0.2, "ratio_eps_max": 0.2, "reference_policy_kl_penalty": 0.1, "disable_ppo_ratio": False,} - loss_fn = ClippedPGLossFn(cfg) # Use original loss fn + cfg = { + "ratio_eps_min": 0.2, + "ratio_eps_max": 0.2, + "reference_policy_kl_penalty": 0.1, + "disable_ppo_ratio": False, + } + loss_fn = ClippedPGLossFn(cfg) # Use original loss fn # --- Test 1: Token Mask --- # Default mask: [[0, 1, 1, 1], [0, 1, 1, 1]] -> 3 tokens per sample @@ -296,16 +325,21 @@ def test_clipped_pg_loss_masking(): # Modify token_mask for batch item 0 to mask one more token (pos 1) data_mod_token = data.copy() data_mod_token["token_mask"] = data["token_mask"].clone() - data_mod_token["token_mask"][0, 1] = 0 # New mask: [[0, 0, 1, 1], [0, 1, 1, 1]] -> 2 tokens sample 0, 3 tokens sample 1 + data_mod_token["token_mask"][0, 1] = ( + 0 # New mask: [[0, 0, 1, 1], [0, 1, 1, 1]] -> 2 tokens sample 0, 3 tokens sample 1 + ) loss_token_masked, _ = loss_fn(dummy_logits, data_mod_token) # Loss should change if a potentially contributing token is masked - assert not torch.isclose(loss_default, loss_token_masked, atol=1e-4), \ - "Token mask did not change loss as expected" + assert not torch.isclose(loss_default, loss_token_masked, atol=1e-4), ( + "Token mask did not change loss as expected" + ) # --- Test 2: Sample Mask --- data_mod_sample = data.copy() - data_mod_sample["sample_mask"] = torch.tensor([1, 0], dtype=torch.int64, device=device) # Ignore item 1 + data_mod_sample["sample_mask"] = torch.tensor( + [1, 0], dtype=torch.int64, device=device + ) # Ignore item 1 loss_sample_masked, _ = loss_fn(dummy_logits, data_mod_sample) @@ -314,7 +348,7 @@ def test_clipped_pg_loss_masking(): for key, value in data.items(): if isinstance(value, torch.Tensor): if key == "sample_mask": - data_only_b0_dict[key] = value[0:1] + data_only_b0_dict[key] = value[0:1] else: data_only_b0_dict[key] = value[0:1] else: @@ -337,8 +371,13 @@ def test_clipped_pg_loss_zero_mask(): # Need dummy logits dummy_logits = torch.randn(1, seq_len, vocab_size, device=device) - cfg = { "ratio_eps_min": 0.2,"ratio_eps_max": 0.2, "reference_policy_kl_penalty": 0.1, "disable_ppo_ratio": False, } - loss_fn = ClippedPGLossFn(cfg) # Use original loss fn + cfg = { + "ratio_eps_min": 0.2, + "ratio_eps_max": 0.2, + "reference_policy_kl_penalty": 0.1, + "disable_ppo_ratio": False, + } + loss_fn = ClippedPGLossFn(cfg) # Use original loss fn # Set token mask to all zeros data["token_mask"] = torch.zeros_like(data["token_mask"]) @@ -347,4 +386,3 @@ def test_clipped_pg_loss_zero_mask(): # Loss should be exactly zero torch.testing.assert_close(loss, torch.tensor(0.0, device=device)) -