diff --git a/nemo_reinforcer/algorithms/loss_functions.py b/nemo_reinforcer/algorithms/loss_functions.py
index 158c9824eb..320ccbbe6f 100644
--- a/nemo_reinforcer/algorithms/loss_functions.py
+++ b/nemo_reinforcer/algorithms/loss_functions.py
@@ -91,7 +91,7 @@ def __call__(
 
         mask = token_mask * sample_mask.unsqueeze(-1)
         lp_error = torch.abs(generation_logprobs - prev_logprobs)  # noqa: F841 (precommit ignore for now)
-        mult_prob_error = ((torch.exp(lp_error) * mask).sum() / mask.sum()).item()
+        mult_prob_error = masked_mean(torch.exp(lp_error), mask).item()
 
         next_token_logits = next_token_logits[:, :-1]  # Remove last position's logits
         next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)
@@ -124,13 +124,8 @@ def __call__(
         loss1 = -advantages * ratios
         loss2 = -advantages * ratios_clamped
 
-        if mask.sum() > 0:
-            actor_loss = masked_mean(torch.max(loss1, loss2), mask)
-            loss = actor_loss + kl
-        else:
-            # disable this update since there are no valid tokens
-            loss = loss1.view(-1)[0] * 0
-
+        actor_loss = masked_mean(torch.max(loss1, loss2), mask)
+        loss = actor_loss + kl
         with torch.no_grad():
             probs_ratio = masked_mean(ratios.detach(), mask).item()
             probs_ratio_clamped = masked_mean(ratios_clamped.detach(), mask).item()
diff --git a/nemo_reinforcer/algorithms/utils.py b/nemo_reinforcer/algorithms/utils.py
index a3c42e2a19..e68ea1d6fa 100644
--- a/nemo_reinforcer/algorithms/utils.py
+++ b/nemo_reinforcer/algorithms/utils.py
@@ -120,9 +120,7 @@ def wrapper(*args, **kwargs):
 @surpress_user_warnings
 def masked_mean(values, mask, dim=None):
     """Masks values with mask, and computes the mean of the values using the masked values."""
-    if dim is None:
-        return values[mask.bool()].mean()
-    return as_masked_tensor(values, mask.bool()).mean(dim=dim).to_tensor(torch.nan)
+    return (values * mask).sum(dim=dim) / (mask.sum(dim=dim) + 1e-8)
 
 
 def set_seed(seed: int):
diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py
index af78baf34d..447bd20a54 100644
--- a/tests/unit/algorithms/test_loss_functions.py
+++ b/tests/unit/algorithms/test_loss_functions.py
@@ -386,3 +386,24 @@ def test_clipped_pg_loss_zero_mask():
 
     # Loss should be exactly zero
     torch.testing.assert_close(loss, torch.tensor(0.0, device=device))
+
+
+def test_masked_mean_all_zeros():
+    """Test masked_mean with all-zeros masks, both over all elements and per-dim."""
+    values = torch.tensor([1.0, 2.0, 3.0, 4.0])
+    mask = torch.zeros_like(values)
+
+    # All-zeros mask should return 0 instead of NaN (thanks to the epsilon).
+    result = masked_mean(values, mask)
+    torch.testing.assert_close(result, torch.tensor(0.0))
+
+    # Mask selecting a single element: the mean is that element.
+    mask[0] = 1
+    result = masked_mean(values, mask)
+    torch.testing.assert_close(result, torch.tensor(1.0))
+
+    # Case 2: dim is not None
+    values = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
+    mask = torch.zeros_like(values)
+    result = masked_mean(values, mask, dim=1)
+    torch.testing.assert_close(result, torch.tensor([0.0, 0.0]))