Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
325 changes: 320 additions & 5 deletions tests/unit/algorithms/test_loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,14 @@
# limitations under the License.
import pytest
import torch
from nemo_reinforcer.algorithms.loss_functions import NLLLoss
import numpy as np

from nemo_reinforcer.algorithms.loss_functions import NLLLoss, ClippedPGLossFn
from nemo_reinforcer.distributed.batched_data_dict import BatchedDataDict
from nemo_reinforcer.algorithms.utils import (
calculate_kl_penalty_joschu2020,
masked_mean,
)


def test_nll_loss():
Expand Down Expand Up @@ -46,7 +53,7 @@ def test_nll_loss():
.to("cuda")
)
loss, metrics_dict = loss_fn(next_token_logits, data)
torch.testing.assert_allclose(loss.cpu(), torch.tensor(0.0))
torch.testing.assert_close(loss.cpu(), torch.tensor(0.0))
# Check the metrics dictionary contains the expected values
assert metrics_dict["num_unmasked_tokens"] == 2
assert metrics_dict["total_tokens"] == 3
Expand All @@ -66,8 +73,316 @@ def test_nll_loss():
)
loss, metrics_dict = loss_fn(next_token_logits, data)
## loss per token is 999, and we have two unmasked tokens
## with the updated loss function, we now average the loss over unmasked tokens
torch.testing.assert_allclose(loss.cpu(), torch.tensor(999.0))
# Check the metrics dictionary contains the expected values
## NLLLoss averages the loss over unmasked tokens
torch.testing.assert_close(loss.cpu(), torch.tensor(999.0))
assert metrics_dict["num_unmasked_tokens"] == 2
assert metrics_dict["total_tokens"] == 3


def _setup_clipped_pg_test_data(batch_size=1, seq_len=4, vocab_size=8, device="cuda"):
    """Build a placeholder batch for the ClippedPGLossFn tests.

    All floating-point fields start out as zeros so each test only has to
    overwrite the entries it cares about.

    Args:
        batch_size: Number of sequences in the batch.
        seq_len: Number of token positions per sequence.
        vocab_size: Exclusive upper bound for the random token ids.
        device: Device on which every tensor is allocated.

    Returns:
        A ``(data, seq_len, vocab_size)`` tuple; ``seq_len`` and ``vocab_size``
        are echoed back for the caller's convenience.
    """
    per_token_shape = (batch_size, seq_len)

    # Random token ids; the tests only use them to locate the target-token
    # column when constructing logits.
    token_ids = torch.randint(
        0, vocab_size, per_token_shape, dtype=torch.int64, device=device
    )

    # Default token mask hides the first position of every row: [[0, 1, 1, 1]].
    tok_mask = torch.ones(per_token_shape, dtype=torch.int64, device=device)
    tok_mask[:, 0] = 0

    # One flag per sample, shape [B].
    smp_mask = torch.ones(batch_size, dtype=torch.int64, device=device)

    def _zeros():
        # Fresh tensor on every call so the fields never alias each other.
        return torch.zeros(per_token_shape, device=device)

    fields = {
        "input_ids": token_ids,
        "token_mask": tok_mask,
        "sample_mask": smp_mask,
        "advantages": _zeros(),
        "prev_logprobs": _zeros(),
        "reference_policy_logprobs": _zeros(),
        "generation_logprobs": _zeros(),
    }
    return BatchedDataDict(fields), seq_len, vocab_size


# Helper to create logits that yield specific target log probs after log_softmax
def _create_exact_logits(target_curr_lp_masked, input_ids, seq_len, vocab_size, device):
"""Constructs logits such that log_softmax results in target_curr_lp_masked."""
dummy_logits = torch.full(
(1, seq_len, vocab_size), -100.0, device=device
) # Start very low

# Loss fn uses logits[:, :-1] and gathers based on next_tokens = input_ids[:, 1:]
# We need to set logits for indices i=0..S-2 of the sliced logits tensor.
# These correspond to target logprobs at indices 0..S-2 of target_curr_lp_masked.
num_effective_pos = target_curr_lp_masked.shape[1]
for i in range(num_effective_pos):
logit_idx = i # Index in the sliced logits tensor (dummy_logits[:, 0:S-1, :])
data_idx = i + 1 # Index in the original input_ids to find the target token

target_token_id = input_ids[0, data_idx].item()
# Keep target_lp as a 0-dim tensor for torch ops
target_lp = target_curr_lp_masked[0, i]

# Handle target_lp = 0 case separately
if torch.isclose(target_lp, torch.tensor(0.0, device=device)):
dummy_logits[0, logit_idx, target_token_id] = 100.0 # Large positive logit
elif target_lp < 0:
# Set target token logit to 0
dummy_logits[0, logit_idx, target_token_id] = 0.0
# Set one distractor token logit using the formula
distractor_token_id = (target_token_id + 1) % vocab_size
# Ensure distractor isn't same as target if vocab_size=1 (edge case)
if distractor_token_id == target_token_id:
distractor_token_id = (target_token_id + 2) % vocab_size
distractor_logit = torch.log(torch.exp(-target_lp) - 1.0)
dummy_logits[0, logit_idx, distractor_token_id] = distractor_logit
else: # target_lp > 0 is not supported by this method
raise ValueError(
"Target log probability must be negative or zero for this construction"
)
return dummy_logits


# Simplified PPO Clipping Test using original Loss
def test_clipped_pg_loss_ppo_clipping():
    """Compare the clipped-ratio loss with a hand-computed reference value."""
    if not torch.cuda.is_available():
        pytest.skip("No GPU available")

    device = "cuda"
    data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device)

    ratio_eps = 0.2
    loss_fn = ClippedPGLossFn(
        {
            "ratio_eps_min": ratio_eps,
            "ratio_eps_max": ratio_eps,
            "reference_policy_kl_penalty": 0.0,  # KL term switched off
            "disable_ppo_ratio": False,
        }
    )

    # Values for the three unmasked positions (1, 2, 3).
    adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device)
    # A non-zero previous log prob lets the ratio exceed 1 while the current
    # log prob stays <= 0.
    prev_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
    # curr = log(ratio) + prev for target ratios 0.5 (below the clip range),
    # 1.0 (inside [0.8, 1.2]) and 1.5 (above the clip range).
    curr_lp_masked = torch.tensor([[-1.69315, -1.0, -0.59453]], device=device)

    # Batch size is 1, so only row 0 needs filling.
    data["advantages"][0, 1:] = adv_masked
    data["prev_logprobs"][0, 1:] = prev_lp_masked

    # --- Hand calculation ---
    ratios = torch.exp(curr_lp_masked - prev_lp_masked)  # ~[0.5, 1.0, 1.5]
    ratios_clamped = torch.clamp(
        ratios, min=1.0 - ratio_eps, max=1.0 + ratio_eps
    )  # [0.8, 1.0, 1.2]
    unclipped = -adv_masked * ratios  # ~[-0.5, 1.0, -3.0]
    clipped = -adv_masked * ratios_clamped  # [-0.8, 1.0, -2.4]
    # Element-wise max ~[-0.5, 1.0, -2.4]; mean ~ -1.9 / 3 = -0.6333
    expected_loss = torch.maximum(unclipped, clipped).mean()

    dummy_logits = _create_exact_logits(
        curr_lp_masked, data["input_ids"], seq_len, vocab_size, device
    )

    actual_loss, _ = loss_fn(dummy_logits, data)
    torch.testing.assert_close(actual_loss, expected_loss)


# Simplified REINFORCE Test using original Loss
def test_clipped_pg_loss_reinforce_mode():
    """Check the loss with ``disable_ppo_ratio=True`` against ``-adv * logprob``."""
    if not torch.cuda.is_available():
        pytest.skip("No GPU available")

    device = "cuda"
    data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device)

    loss_fn = ClippedPGLossFn(
        {
            "disable_ppo_ratio": True,
            "reference_policy_kl_penalty": 0.0,
            "ratio_eps_min": 0.0,  # placeholder, ignored
            "ratio_eps_max": 0.0,  # placeholder, ignored
        }
    )

    adv_masked = torch.tensor([[1.0, -1.0, 2.0]], device=device)
    curr_lp_masked = torch.tensor([[-0.5, -1.0, -1.5]], device=device)

    data["advantages"][0, 1:] = adv_masked
    # NOTE(review): nothing visible in this file reads "_test_curr_logprobs";
    # it looks like a leftover - confirm the loss fn ignores it before removing.
    data["_test_curr_logprobs"] = curr_lp_masked
    data["prev_logprobs"][0, 1:] = torch.zeros_like(curr_lp_masked)

    # --- Hand calculation ---
    per_token_loss = -adv_masked * curr_lp_masked  # [0.5, -1.0, 3.0]
    expected_loss = per_token_loss.mean()  # 2.5 / 3 = 0.8333

    dummy_logits = _create_exact_logits(
        curr_lp_masked, data["input_ids"], seq_len, vocab_size, device
    )

    actual_loss, _ = loss_fn(dummy_logits, data)
    torch.testing.assert_close(actual_loss, expected_loss)


# Simplified KL Penalty Test using original Loss
def test_clipped_pg_loss_kl_penalty():
    """Check the KL-penalty contribution with the actor term zeroed out."""
    if not torch.cuda.is_available():
        pytest.skip("No GPU available")

    device = "cuda"
    data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device)

    # --- Test setup ---
    kl_beta = 0.1
    loss_fn = ClippedPGLossFn(
        {
            "reference_policy_kl_penalty": kl_beta,
            "ratio_eps_min": 0.2,
            "ratio_eps_max": 0.2,
            "disable_ppo_ratio": False,
        }
    )

    # Zero advantages make the actor loss vanish, leaving only the KL term.
    adv_masked = torch.tensor([[0.0, 0.0, 0.0]], device=device)
    curr_lp_masked = torch.tensor([[0.0, -1.0, -2.0]], device=device)
    ref_lp_masked = torch.tensor([[-1.0, -1.0, -1.0]], device=device)
    prev_lp_masked = torch.tensor([[0.0, 0.0, 0.0]], device=device)

    data["advantages"][0, 1:] = adv_masked
    data["reference_policy_logprobs"][0, 1:] = ref_lp_masked
    data["prev_logprobs"][0, 1:] = prev_lp_masked
    # NOTE(review): nothing visible in this file reads "_test_curr_logprobs";
    # it looks like a leftover - confirm the loss fn ignores it before removing.
    data["_test_curr_logprobs"] = curr_lp_masked

    # --- Hand calculation ---
    # total loss = kl_beta * mean(exp(ref - curr) - (ref - curr) - 1)
    log_ratio = ref_lp_masked - curr_lp_masked  # [-1.0, 0.0, 1.0]
    kl_per_token = torch.exp(log_ratio) - log_ratio - 1  # [0.368, 0.0, 0.718]
    expected_loss = kl_beta * kl_per_token.mean()  # 0.1 * 0.362 = 0.0362

    dummy_logits = _create_exact_logits(
        curr_lp_masked, data["input_ids"], seq_len, vocab_size, device
    )

    actual_loss, _ = loss_fn(dummy_logits, data)
    torch.testing.assert_close(actual_loss, expected_loss)


# Masking tests - Should work with original Loss Fn if needed, but less critical
def test_clipped_pg_loss_masking():
    """Verify that ``token_mask`` and ``sample_mask`` both influence the loss.

    Part 1 masks one extra token and expects the loss to move. Part 2 masks
    out the second sample and expects the same loss as running the first
    sample on its own.
    """
    if not torch.cuda.is_available():
        pytest.skip("No GPU available")

    device = "cuda"
    batch_size = 2
    data, seq_len, vocab_size = _setup_clipped_pg_test_data(
        batch_size=batch_size, seq_len=4, device=device
    )

    # Random but plausible inputs; exact values do not matter here.
    dummy_logits = torch.randn(batch_size, seq_len, vocab_size, device=device)
    data["prev_logprobs"] = torch.randn_like(data["prev_logprobs"]) * 0.1
    data["reference_policy_logprobs"] = (
        torch.randn_like(data["reference_policy_logprobs"]) * 0.1
    )
    # Shift the advantages away from zero.
    data["advantages"] = torch.randn_like(data["advantages"]) + 1.0

    loss_fn = ClippedPGLossFn(
        {
            "ratio_eps_min": 0.2,
            "ratio_eps_max": 0.2,
            "reference_policy_kl_penalty": 0.1,
            "disable_ppo_ratio": False,
        }
    )

    # --- Test 1: token mask ---
    # Baseline mask [[0, 1, 1, 1], [0, 1, 1, 1]]: three tokens per sample.
    loss_default, _ = loss_fn(dummy_logits, data)

    # Also hide position 1 of sample 0 -> [[0, 0, 1, 1], [0, 1, 1, 1]].
    reduced_mask = data["token_mask"].clone()
    reduced_mask[0, 1] = 0
    data_mod_token = data.copy()
    data_mod_token["token_mask"] = reduced_mask

    loss_token_masked, _ = loss_fn(dummy_logits, data_mod_token)
    # Dropping a contributing token should shift the loss.
    assert not torch.isclose(loss_default, loss_token_masked, atol=1e-4), (
        "Token mask did not change loss as expected"
    )

    # --- Test 2: sample mask ---
    data_mod_sample = data.copy()
    # Keep sample 0, ignore sample 1.
    data_mod_sample["sample_mask"] = torch.tensor(
        [1, 0], dtype=torch.int64, device=device
    )
    loss_sample_masked, _ = loss_fn(dummy_logits, data_mod_sample)

    # Reference: a batch holding only sample 0 (tensors sliced, others kept).
    data_only_b0 = BatchedDataDict(
        {
            key: (value[0:1] if isinstance(value, torch.Tensor) else value)
            for key, value in data.items()
        }
    )
    loss_only_b0, _ = loss_fn(dummy_logits[0:1], data_only_b0)

    torch.testing.assert_close(loss_sample_masked, loss_only_b0)


def test_clipped_pg_loss_zero_mask():
    """Check that an all-zero token mask yields a loss of exactly zero."""
    if not torch.cuda.is_available():
        pytest.skip("No GPU available")

    device = "cuda"
    data, seq_len, vocab_size = _setup_clipped_pg_test_data(device=device)
    # Arbitrary logits; every position is masked out below.
    dummy_logits = torch.randn(1, seq_len, vocab_size, device=device)

    loss_fn = ClippedPGLossFn(
        {
            "ratio_eps_min": 0.2,
            "ratio_eps_max": 0.2,
            "reference_policy_kl_penalty": 0.1,
            "disable_ppo_ratio": False,
        }
    )

    # Mask out every token.
    data["token_mask"] = torch.zeros_like(data["token_mask"])

    loss, _ = loss_fn(dummy_logits, data)

    zero = torch.tensor(0.0, device=device)
    torch.testing.assert_close(loss, zero)
Loading