diff --git a/docs/design-docs/loss-functions.md b/docs/design-docs/loss-functions.md index b0fb9523e2..37cf868ad8 100644 --- a/docs/design-docs/loss-functions.md +++ b/docs/design-docs/loss-functions.md @@ -17,14 +17,13 @@ $$ which is, in general, not equivalent to the full-batch loss. To fix this, we need each microbatch to have information about how many tokens are in the other microbatches in the global batch. -In NeMo RL, this information is passed to the loss function directly. Each loss function is expected to fall into one of two categories, token-level or sequence-level, which is an attribute of the loss function itself (see [loss_functions.py](../../nemo_rl/algorithms/loss_functions.py) for some examples). The policy then uses this information to compute the global normalization factor using the full batch (for token-level losses, this is the total number of tokens in the batch. For sequence-level losses, this is the number of valid sequences in the batch). The normalization factor is then passed to the loss function, which uses it to normalize the microbatch loss. To get the loss for the global batch, the policy simply sums across all microbatch losses. +In NeMo RL, this information is passed to the loss function directly. Each loss function is expected to fall into one of two categories, token-level or sequence-level, which is an attribute of the loss function itself (see [loss_functions.py](../../nemo_rl/algorithms/loss/loss_functions.py) for some examples). The policy then uses this information to compute the global normalization factor using the full batch (for token-level losses, this is the total number of tokens in the batch. For sequence-level losses, this is the number of valid sequences in the batch). The normalization factor is then passed to the loss function, which uses it to normalize the microbatch loss. To get the loss for the global batch, the policy simply sums across all microbatch losses. For our simple example above, this would look like: ```{testcode} import torch -from nemo_rl.algorithms.interfaces import LossFunction -from nemo_rl.algorithms.loss_functions import LossType +from nemo_rl.algorithms.loss.interfaces import LossFunction, LossType from nemo_rl.distributed.batched_data_dict import BatchedDataDict diff --git a/docs/guides/grpo.md b/docs/guides/grpo.md index 2576b93303..2e7b410de9 100755 --- a/docs/guides/grpo.md +++ b/docs/guides/grpo.md @@ -343,7 +343,7 @@ The function, [grpo_train](../../nemo_rl/algorithms/grpo.py), contains the core RL generations typically produce highly variable sequence lengths, which result in a significant amount of padding if approached naively. We address this with Sequence Packing and Dynamic Batching, which are techniques to reduce the amount of padding required. You can read more about these in the [design doc](../design-docs/sequence-packing-and-dynamic-batching.md). ## Loss -We use the [ClippedPGLossFn](../../nemo_rl/algorithms/loss_functions.py) to calculate the loss for GRPO. Formally, +We use the [ClippedPGLossFn](../../nemo_rl/algorithms/loss/loss_functions.py) to calculate the loss for GRPO. Formally, $$ L(\theta) = E_{x \sim \pi_{\theta_{\text{old}}}} \Big[ \min \Big(\frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}A_t, \text{clip} \big( \frac{\pi_\theta(x)}{\pi_{\theta_{\text{old}}}(x)}, 1 - \varepsilon, 1 + \varepsilon \big) A_t \Big) \Big] - \beta D_{\text{KL}} (\pi_\theta \| \pi_\text{ref}) diff --git a/docs/guides/prorlv2.md b/docs/guides/prorlv2.md index 795bb0d08a..98d6142c12 100644 --- a/docs/guides/prorlv2.md +++ b/docs/guides/prorlv2.md @@ -106,7 +106,7 @@ loss_fn: This keeps PPO/GRPO-style clipping behavior but allows a larger expansion region than the contraction region, which can help exploration and reduce early collapse. -- **Implementation**: `ClippedPGLossFn` documents decoupled clipping in [`nemo_rl/algorithms/loss_functions.py`](../../nemo_rl/algorithms/loss_functions.py). +- **Implementation**: `ClippedPGLossFn` documents decoupled clipping in [`nemo_rl/algorithms/loss/loss_functions.py`](../../nemo_rl/algorithms/loss/loss_functions.py). ## Loss: Token-level Policy Gradient @@ -153,7 +153,7 @@ loss_fn: - `"icepop"`: set weights outside \([min, max]\) to zero (filter outliers) - `"seq-mask-tis"`: sequence-level geometric-mean mask + non-truncated token-level IS correction (see below) -- **Implementation**: see `ClippedPGLossFn` init-time checks and logic in [`nemo_rl/algorithms/loss_functions.py`](../../nemo_rl/algorithms/loss_functions.py). +- **Implementation**: see `ClippedPGLossFn` init-time checks and logic in [`nemo_rl/algorithms/loss/loss_functions.py`](../../nemo_rl/algorithms/loss/loss_functions.py). ### Seq-mask-tis: Sequence-level Geometric-Mean Mask diff --git a/nemo_rl/algorithms/distillation.py b/nemo_rl/algorithms/distillation.py index 6fa9689d1a..1c7cf86a0b 100644 --- a/nemo_rl/algorithms/distillation.py +++ b/nemo_rl/algorithms/distillation.py @@ -24,7 +24,7 @@ from transformers.tokenization_utils_base import PreTrainedTokenizerBase from nemo_rl.algorithms.grpo import _should_use_async_rollouts, refit_policy_generation -from nemo_rl.algorithms.loss_functions import ( +from nemo_rl.algorithms.loss import ( DistillationLossConfig, DistillationLossDataDict, DistillationLossFn, diff --git a/nemo_rl/algorithms/dpo.py b/nemo_rl/algorithms/dpo.py index 32df7bb10b..b91c3e6730 100644 --- a/nemo_rl/algorithms/dpo.py +++ b/nemo_rl/algorithms/dpo.py @@ -23,7 +23,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader from transformers import AutoTokenizer -from nemo_rl.algorithms.loss_functions import DPOLossFn +from nemo_rl.algorithms.loss import DPOLossFn from nemo_rl.algorithms.utils import maybe_pad_last_batch, set_seed from nemo_rl.data import DataConfig from nemo_rl.data.collate_fn import preference_collate_fn diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 84c9ec2b8b..c060a05a50 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -31,12 +31,12 @@ GRPOAdvantageEstimator, ReinforcePlusPlusAdvantageEstimator, ) -from nemo_rl.algorithms.interfaces import LossFunction -from nemo_rl.algorithms.loss_functions import ( +from nemo_rl.algorithms.loss import ( ClippedPGLossConfig, ClippedPGLossDataDict, ClippedPGLossFn, ) +from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.algorithms.reward_functions import ( RewardShapingConfig, apply_reward_shaping, diff --git a/nemo_rl/algorithms/loss/__init__.py b/nemo_rl/algorithms/loss/__init__.py new file mode 100644 index 0000000000..163ce71a24 --- /dev/null +++ b/nemo_rl/algorithms/loss/__init__.py @@ -0,0 +1,51 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_rl.algorithms.loss.loss_functions import ( + ClippedPGLossConfig, + ClippedPGLossDataDict, + ClippedPGLossFn, + DistillationLossConfig, + DistillationLossDataDict, + DistillationLossFn, + DPOLossConfig, + DPOLossDataDict, + DPOLossFn, + NLLLossFn, + PreferenceLossDataDict, + PreferenceLossFn, +) +from nemo_rl.algorithms.loss.utils import prepare_loss_input +from nemo_rl.algorithms.loss.wrapper import ( + SequencePackingLossWrapper, + wrap_loss_fn_with_input_preparation, +) + +__all__ = [ + "ClippedPGLossConfig", + "ClippedPGLossDataDict", + "ClippedPGLossFn", + "DistillationLossConfig", + "DistillationLossDataDict", + "DistillationLossFn", + "DPOLossConfig", + "DPOLossDataDict", + "DPOLossFn", + "NLLLossFn", + "PreferenceLossDataDict", + "PreferenceLossFn", + "prepare_loss_input", + "SequencePackingLossWrapper", + "wrap_loss_fn_with_input_preparation", +] diff --git a/nemo_rl/algorithms/interfaces.py b/nemo_rl/algorithms/loss/interfaces.py similarity index 80% rename from nemo_rl/algorithms/interfaces.py rename to nemo_rl/algorithms/loss/interfaces.py index d7b6bfe67b..f1c0db3e35 100644 --- a/nemo_rl/algorithms/interfaces.py +++ b/nemo_rl/algorithms/loss/interfaces.py @@ -25,6 +25,12 @@ class LossType(enum.Enum): SEQUENCE_LEVEL = "sequence_level" +class LossInputType(enum.Enum): + LOGIT = "logit" + LOGPROB = "logprob" + DISTILLATION = "distillation" + + class LossFunction(Protocol): """Signature for loss functions used in reinforcement learning algorithms. @@ -33,33 +39,33 @@ class LossFunction(Protocol): """ loss_type: LossType + input_type: LossInputType def __call__( self, - next_token_logits: torch.Tensor, data: BatchedDataDict, global_valid_seqs: torch.Tensor, global_valid_toks: torch.Tensor, + **kwargs: Any, ) -> tuple[torch.Tensor, dict[str, Any]]: """Compute loss and metrics from logprobs and other data. Args: - next_token_logits: Logits from the model, typically with shape [batch_size, seq_len, vocab_size]. - For each position (b, i), contains the logit distribution over the entire vocabulary - for predicting the next token (at position i+1). For example, if processing "The cat sat on", - then next_token_logits[b, 3] would contain the logits for predicting the word - that follows "on". data: Dictionary containing all relevant data for loss computation such as rewards, values, actions, advantages, masks, and other algorithm-specific information needed for the particular loss calculation. global_valid_seqs: torch.Tensor - this tensor should contain the number of valid sequences in the microbatch. + This tensor should contain the number of valid sequences in the microbatch. It's used for global normalization for losses/metrics that are computed at the sequence level and needs to be aggregated across all microbatches. global_valid_toks: torch.Tensor This tensor should contain the number of valid tokens in the microbatch. It's used for global normalization for losses/metrics that are computed at the token level and needs to be aggregated across all microbatches. + **kwargs: Loss function input, which varies by input_type: + - For LossInputType.LOGPROB: next_token_logprobs (torch.Tensor) + - For LossInputType.LOGIT: logits (torch.Tensor) + - For LossInputType.DISTILLATION: student_topk_logprobs, teacher_topk_logprobs, H_all (torch.Tensor) Returns: tuple: (loss, metrics) diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py similarity index 67% rename from nemo_rl/algorithms/loss_functions.py rename to nemo_rl/algorithms/loss/loss_functions.py index 1a275146d2..812d2917b6 100755 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -11,24 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math -from typing import Any, NotRequired, Optional, TypedDict, TypeVar + +from typing import Any, NotRequired, TypedDict, TypeVar import torch -import torch.distributed -from nemo_rl.algorithms.interfaces import LossFunction, LossType +from nemo_rl.algorithms.loss.interfaces import LossFunction, LossInputType, LossType from nemo_rl.algorithms.utils import calculate_kl, masked_mean from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from nemo_rl.distributed.model_utils import ( - ChunkedDistributedEntropy, - ChunkedDistributedGatherLogprob, - _get_tokens_on_this_cp_rank, - allgather_cp_sharded_tensor, - from_parallel_logits_to_logprobs, - gather_logits_at_global_indices, - get_logprobs_from_vocab_parallel_logits, -) Tensor = TypeVar("Tensor", bound=torch.Tensor) @@ -122,6 +112,8 @@ class ClippedPGLossFn(LossFunction): Due to potential numerical instability, we cast the logits to float32 before computing the loss. """ + input_type = LossInputType.LOGPROB + def __init__(self, cfg: ClippedPGLossConfig): self.ratio_clip_min = cfg["ratio_clip_min"] self.ratio_clip_max = cfg["ratio_clip_max"] @@ -198,15 +190,13 @@ def __init__(self, cfg: ClippedPGLossConfig): def __call__( self, - next_token_logits: Tensor, + next_token_logprobs: Tensor, data: BatchedDataDict[ClippedPGLossDataDict], global_valid_seqs: torch.Tensor, global_valid_toks: torch.Tensor, - vocab_parallel_rank: Optional[int] = None, - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[torch.Tensor, dict]: """Clipped Policy Gradient RL loss function.""" + curr_logprobs = next_token_logprobs token_mask = data["token_mask"][:, 1:] sample_mask = data["sample_mask"] advantages = data["advantages"][:, 1:] @@ -214,7 +204,6 @@ def __call__( generation_logprobs = data["generation_logprobs"][:, 1:] if self.reference_policy_kl_penalty != 0: reference_policy_logprobs = data["reference_policy_logprobs"][:, 1:] - seq_index = data.get("seq_index", None) mask = token_mask * sample_mask.unsqueeze(-1) @@ -282,39 +271,6 @@ def __call__( global_normalization_factor=global_valid_toks, ).item() - next_token_logits = next_token_logits.to(torch.float32) - - if vocab_parallel_group is not None: - assert vocab_parallel_rank is not None, ( - "vocab_parallel_rank must be provided when vocab_parallel_group is provided" - ) - curr_logprobs = from_parallel_logits_to_logprobs( - next_token_logits, - data["input_ids"], - vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], - vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], - tp_group=vocab_parallel_group, - inference_only=False, - cp_group=context_parallel_group, - ) - # slice off to the correct length to remove potential CP padding - curr_logprobs = curr_logprobs[:, : data["input_ids"].shape[1] - 1] - elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): - curr_logprobs = get_logprobs_from_vocab_parallel_logits( - next_token_logits, data["input_ids"], seq_index=seq_index - ) - else: - next_token_logits_wo_last = next_token_logits[ - :, :-1 - ] # Remove last position's logits - next_token_logprobs = torch.nn.functional.log_softmax( - next_token_logits_wo_last, dim=-1 - ) - next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token - curr_logprobs = next_token_logprobs.gather( - dim=-1, index=next_tokens.unsqueeze(-1) - ).squeeze(-1) - # Calculate KL regularization. if self.reference_policy_kl_penalty != 0: if self.use_on_policy_kl_approximation: @@ -599,20 +555,18 @@ def __call__( ) -class NLLLoss(LossFunction): +class NLLLossFn(LossFunction): """Negative Log Likelihood Loss function.""" loss_type = LossType.TOKEN_LEVEL + input_type = LossInputType.LOGPROB def __call__( self, - next_token_logits: Tensor, + next_token_logprobs: Tensor, data: BatchedDataDict[Any], global_valid_seqs: Tensor | None, global_valid_toks: Tensor, - vocab_parallel_rank: Optional[int] = None, - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, dpo_loss: bool = False, dpo_average_log_probs: bool = False, ) -> tuple[torch.Tensor, dict[str, Any]]: @@ -621,52 +575,19 @@ def __call__( token_mask = data["token_mask"][:, 1:] sample_mask = data["sample_mask"] mask = token_mask * sample_mask.unsqueeze(-1) - seq_index = data.get("seq_index", None) - - next_token_logits = next_token_logits.to(torch.float32) - - # Gather the logprobs for the actual next tokens - if vocab_parallel_group is not None: - assert vocab_parallel_rank is not None, ( - "vocab_parallel_rank must be provided when vocab_parallel_group is provided" - ) - token_logprobs = from_parallel_logits_to_logprobs( - next_token_logits, - data["input_ids"], - vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], - vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], - tp_group=vocab_parallel_group, - inference_only=False, - cp_group=context_parallel_group, - ) - # slice off to the correct length to remove potential CP padding - token_logprobs = token_logprobs[:, : data["input_ids"].shape[1] - 1] - elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): - token_logprobs = get_logprobs_from_vocab_parallel_logits( - next_token_logits, data["input_ids"], seq_index=seq_index - ) - else: - next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token - next_token_logprobs = torch.nn.functional.log_softmax( - next_token_logits, dim=-1 - ) - logprobs = next_token_logprobs[:, :-1] # Remove last position's logits - token_logprobs = logprobs.gather( - dim=-1, index=next_tokens.unsqueeze(-1) - ).squeeze(-1) if dpo_loss: ## shape: [batch_size] num_unmasked_tokens = torch.sum(mask, -1) ## multiply by sample_mask to zero out invalid samples - loss = -torch.sum(token_logprobs * mask, dim=-1) + loss = -torch.sum(next_token_logprobs * mask, dim=-1) if dpo_average_log_probs: loss = loss / num_unmasked_tokens.clamp(min=1) else: ## single scalar loss ## scale by the total number of tokens in the batch loss = -masked_mean( - token_logprobs, + next_token_logprobs, mask, global_normalization_factor=global_valid_toks, ) @@ -686,7 +607,7 @@ class PreferenceLossDataDict(TypedDict): sample_mask: torch.Tensor -class PreferenceLoss(LossFunction): +class PreferenceLossFn(LossFunction): """Preference Loss function. Optimizes the model to prefer chosen responses over rejected ones @@ -707,8 +628,8 @@ class PreferenceLoss(LossFunction): - accuracy: Fraction of examples where chosen response has higher reward """ - def __init__(self): - self.loss_type = LossType.SEQUENCE_LEVEL + loss_type = LossType.SEQUENCE_LEVEL + input_type = LossInputType.LOGIT def split_output_tensor(self, tensor: Tensor) -> tuple[Tensor, Tensor]: # tensor is of shape (2*micro_batch_size,) @@ -754,14 +675,14 @@ def _preference_loss( def __call__( self, - rewards: Tensor, + logits: Tensor, data: BatchedDataDict[PreferenceLossDataDict], global_valid_seqs: Tensor, global_valid_toks: Tensor | None, ) -> tuple[torch.Tensor, dict[str, Any]]: sample_mask = data["sample_mask"] - rewards = rewards.squeeze(-1) + rewards = logits.squeeze(-1) ( preference_loss, @@ -799,7 +720,7 @@ class DPOLossDataDict(TypedDict): sample_mask: torch.Tensor -class DPOLossFn(PreferenceLoss): +class DPOLossFn(PreferenceLossFn): """Direct Preference Optimization (DPO) loss function. This loss function implements the DPO algorithm as described in: @@ -855,63 +776,29 @@ class DPOLossFn(PreferenceLoss): - accuracy: Fraction of examples where chosen response has higher reward """ + loss_type = LossType.SEQUENCE_LEVEL + input_type = LossInputType.LOGPROB + def __init__(self, cfg: DPOLossConfig): self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"] self.preference_loss_weight = cfg["preference_loss_weight"] self.sft_loss_weight = cfg["sft_loss_weight"] self.preference_average_log_probs = cfg["preference_average_log_probs"] self.sft_average_log_probs = cfg["sft_average_log_probs"] - self.sft_loss = NLLLoss() - - self.loss_type = LossType.SEQUENCE_LEVEL + self.sft_loss = NLLLossFn() def _dpo_loss( self, - next_token_logits: Tensor, + next_token_logprobs: Tensor, data: BatchedDataDict[DPOLossDataDict], global_valid_seqs: Tensor, - vocab_parallel_rank: Optional[int] = None, - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[Tensor, Tensor, Tensor, Tensor]: - ## TODO(@ashors): there's some duplicate code here with the NLLLoss function. We should refactor + ## TODO(@ashors): there's some duplicate code here with the NLLLossFn function. We should refactor token_mask = data["token_mask"][:, 1:] sample_mask = data["sample_mask"] - seq_index = data.get("seq_index", None) - - next_token_logits = next_token_logits.to(torch.float32) - if vocab_parallel_group is not None: - assert vocab_parallel_rank is not None, ( - "vocab_parallel_rank must be provided when vocab_parallel_group is provided" - ) - token_logprobs = from_parallel_logits_to_logprobs( - next_token_logits, - data["input_ids"], - vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], - vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], - tp_group=vocab_parallel_group, - inference_only=False, - cp_group=context_parallel_group, - ) - # slice off to the correct length to remove potential CP padding - token_logprobs = token_logprobs[:, : data["input_ids"].shape[1] - 1] - elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): - token_logprobs = get_logprobs_from_vocab_parallel_logits( - next_token_logits, data["input_ids"], seq_index=seq_index - ) - else: - next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token - next_token_logprobs = torch.nn.functional.log_softmax( - next_token_logits, dim=-1 - ) - logprobs = next_token_logprobs[:, :-1] # Remove last position's logits - token_logprobs = logprobs.gather( - dim=-1, index=next_tokens.unsqueeze(-1) - ).squeeze(-1) ref_logprobs = data["reference_policy_logprobs"][:, :-1] - - diff = (token_logprobs - ref_logprobs) * token_mask + diff = (next_token_logprobs - ref_logprobs) * token_mask rewards = diff.sum(-1) if self.preference_average_log_probs: @@ -921,16 +808,13 @@ def _dpo_loss( rewards, sample_mask, global_valid_seqs, self.reference_policy_kl_penalty ) - # TODO a cleaner typing fix would be required (probably that DPOLossFn should not inherit from PreferenceLoss) + # TODO a cleaner typing fix would be required (probably that DPOLossFn should not inherit from PreferenceLossFn) def __call__( # type: ignore self, - next_token_logits: Tensor, + next_token_logprobs: Tensor, data: BatchedDataDict[DPOLossDataDict], global_valid_seqs: Tensor, global_valid_toks: Tensor | None, - vocab_parallel_rank: Optional[int] = None, - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[torch.Tensor, dict[str, Any]]: sft_loss_chosen = torch.tensor(0.0) if self.sft_loss_weight > 0: @@ -938,13 +822,10 @@ def __call__( # type: ignore "global_valid_toks must be provided for SFT loss" ) sft_loss, _ = self.sft_loss( - next_token_logits, + next_token_logprobs, data, global_valid_seqs=global_valid_seqs, global_valid_toks=global_valid_toks, ## unused because sft loss returned is at the sample level - vocab_parallel_rank=vocab_parallel_rank, - vocab_parallel_group=vocab_parallel_group, - context_parallel_group=context_parallel_group, dpo_loss=True, dpo_average_log_probs=self.sft_average_log_probs, ) @@ -960,14 +841,7 @@ def __call__( # type: ignore accuracy, rewards_chosen_mean, rewards_rejected_mean, - ) = self._dpo_loss( - next_token_logits, - data, - global_valid_seqs, - vocab_parallel_rank=vocab_parallel_rank, - vocab_parallel_group=vocab_parallel_group, - context_parallel_group=context_parallel_group, - ) + ) = self._dpo_loss(next_token_logprobs, data, global_valid_seqs) dpo_loss = ( self.sft_loss_weight * sft_loss_chosen @@ -988,103 +862,6 @@ def __call__( # type: ignore } -class SequencePackingLossWrapper: - def __init__( - self, - loss_fn: LossFunction, - cu_seqlens_q: Tensor, - cu_seqlens_q_padded: Optional[Tensor] = None, - ): - self.loss_fn = loss_fn - self.cu_seqlens_q = cu_seqlens_q - self.cu_seqlens_q_padded = cu_seqlens_q_padded - - def __call__( - self, - next_token_logits: Tensor, - data: BatchedDataDict[Any], - global_valid_seqs: Tensor | None, - global_valid_toks: Tensor | None, - vocab_parallel_rank: Optional[int] = None, - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - ) -> tuple[Tensor, dict[str, Any]]: - """Wraps a loss function to handle sequence packing by doing one sequence at a time to avoid excessive padding.""" - unpadded_cu_seqlens = self.cu_seqlens_q - unpadded_seq_lengths = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1] - if self.cu_seqlens_q_padded is not None: - padded_cu_seqlens = self.cu_seqlens_q_padded - padded_seq_lengths = ( - self.cu_seqlens_q_padded[1:] - self.cu_seqlens_q_padded[:-1] - ) - else: - padded_cu_seqlens = unpadded_cu_seqlens - padded_seq_lengths = unpadded_seq_lengths - seq_starts = padded_cu_seqlens[:-1] - seq_ends = padded_cu_seqlens[1:] - - loss_accum = 0 - metrics_accum = {} - for seq_idx in range(len(seq_starts)): - seq_start = seq_starts[seq_idx].item() - seq_end = seq_ends[seq_idx].item() - - # get sequence and unpad all 'data' tensors. The data dict is a BatchedDataDict of unpacked tensors - seq_data = data.slice(seq_idx, seq_idx + 1) - unpadded_seq_data = {} - for k, v in seq_data.items(): - if isinstance(v, torch.Tensor) and v.ndim > 1 and v.shape[1] > 1: - unpadded_seq_data[k] = v[:, : unpadded_seq_lengths[seq_idx]] - else: - unpadded_seq_data[k] = v - - # get next_token_logits - cp_size = ( - 1 - if context_parallel_group is None - else torch.distributed.get_world_size(context_parallel_group) - ) - logit_start = seq_start // cp_size - logit_end = (seq_start + padded_seq_lengths[seq_idx]) // cp_size - logit_length = logit_end - logit_start - next_token_logits_slice = next_token_logits.narrow( - 1, logit_start, logit_length - ) - - loss, metrics = self.loss_fn( - next_token_logits_slice, - unpadded_seq_data, - global_valid_seqs, - global_valid_toks, - vocab_parallel_rank=vocab_parallel_rank, - vocab_parallel_group=vocab_parallel_group, - context_parallel_group=context_parallel_group, - ) - loss_accum += loss - for k, v in metrics.items(): - if k not in metrics_accum: - if k in {"probs_ratio_min", "probs_ratio_clamped_min"}: - metrics_accum[k] = float("inf") - elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}: - metrics_accum[k] = float("-inf") - else: - metrics_accum[k] = 0 - - val = v.item() if isinstance(v, torch.Tensor) and v.ndim == 0 else v - - # Skip inf/-inf sentinel values (from sequences with no valid tokens) - if k in {"probs_ratio_min", "probs_ratio_clamped_min"}: - if not math.isinf(val): - metrics_accum[k] = min(metrics_accum[k], val) - elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}: - if not math.isinf(val): - metrics_accum[k] = max(metrics_accum[k], val) - else: - metrics_accum[k] += val - - return loss_accum, metrics_accum - - class DistillationLossConfig(TypedDict): kl_type: str mixed_kl_weight: float @@ -1103,12 +880,14 @@ class DistillationLossDataDict(TypedDict): class DistillationLossFn(LossFunction): """Distillation loss function.""" + loss_type = LossType.TOKEN_LEVEL + input_type = LossInputType.DISTILLATION + def __init__(self, cfg: DistillationLossConfig): self.kl_type = cfg["kl_type"] self.mixed_kl_weight = cfg["mixed_kl_weight"] self.zero_outside_topk = cfg["zero_outside_topk"] self.log_infinitesimal = -100 - self.loss_type = LossType.TOKEN_LEVEL assert self.kl_type in ["forward", "reverse", "mixed"], "Invalid KL type" assert self.mixed_kl_weight >= 0 and self.mixed_kl_weight <= 1, ( @@ -1117,165 +896,14 @@ def __init__(self, cfg: DistillationLossConfig): def __call__( self, - next_token_logits: torch.Tensor, + student_topk_logprobs: torch.Tensor, + teacher_topk_logprobs: torch.Tensor, + H_all: torch.Tensor | None, data: DistillationLossDataDict, global_valid_seqs: torch.Tensor, global_valid_toks: torch.Tensor, - vocab_parallel_rank: Optional[int] = None, - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[torch.Tensor, dict[str, Any]]: """Compute distillation loss between teacher and student logits.""" - # Basic shapes - input_ids = data["input_ids"] - batch_size = input_ids.shape[0] - - # CP support: get CP group and size - cp_group = context_parallel_group - cp_size = 1 if cp_group is None else torch.distributed.get_world_size(cp_group) - - # Ensure float32 for stability (match other losses) - next_token_logits = next_token_logits.to(torch.float32) - per_token_kl = None - # Preferred truncated-KL path: teacher provides top-k support per position - teacher_topk_logits = data["teacher_topk_logits"] # [B, S, k] - teacher_topk_indices = data["teacher_topk_indices"] # [B, S, k] - - if teacher_topk_indices.shape[-1] <= 0: - raise ValueError( - f"topk must be positive, got {teacher_topk_indices.shape[-1]}. " - "topk=0 is not supported as it would result in empty tensor operations." - ) - - # Determine processing path and setup variables - if vocab_parallel_group is not None: - assert vocab_parallel_rank is not None, ( - "vocab_parallel_rank must be provided when vocab_parallel_group is provided" - ) - V_local = int(next_token_logits.shape[-1]) - vocab_start_index = vocab_parallel_rank * V_local - vocab_end_index = (vocab_parallel_rank + 1) * V_local - parallel_group = vocab_parallel_group - logits_tensor = next_token_logits - elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): - device_mesh = next_token_logits.device_mesh - tp_group = device_mesh.get_group("tp") - tp_rank = tp_group.rank() - local_student_logits = next_token_logits.to_local() - V_local = int(local_student_logits.shape[-1]) - vocab_start_index = tp_rank * V_local - vocab_end_index = (tp_rank + 1) * V_local - parallel_group = tp_group - logits_tensor = local_student_logits - teacher_topk_indices = teacher_topk_indices.to(local_student_logits.device) - # For DTensor, derive CP group/size from the device mesh to ensure CP-aware alignment - if ( - device_mesh.mesh_dim_names is not None - and "cp" in device_mesh.mesh_dim_names - ): - cp_group = device_mesh.get_group("cp") - cp_size = cp_group.size() - else: - cp_group = None - cp_size = 1 - else: - parallel_group = None - logits_tensor = next_token_logits - - # Process based on zero_outside_topk setting - if self.zero_outside_topk and parallel_group is not None: - # Distributed processing with chunking - indices_local = teacher_topk_indices - pad_len = 0 - if cp_size > 1: - pad_len = logits_tensor.shape[1] * cp_size - indices_local.shape[1] - if pad_len > 0: - indices_local = torch.nn.functional.pad( - indices_local, (0, 0, 0, pad_len), value=0 - ) - cp_rank = torch.distributed.get_rank(cp_group) - indices_local = _get_tokens_on_this_cp_rank( - indices_local, cp_rank, cp_size, seq_dim=1 - ) - - S_local = int(logits_tensor.shape[1]) - chunk_size = max(1, min(S_local, 1024)) - student_topk_logprobs = ChunkedDistributedGatherLogprob.apply( # type: ignore - logits_tensor, - indices_local, - vocab_start_index, - vocab_end_index, - chunk_size, - parallel_group, - False, - ) - - if self.kl_type != "forward": - H_all = ChunkedDistributedEntropy.apply( # type: ignore - logits_tensor, - chunk_size, - parallel_group, - False, - ) - - if cp_size > 1: - student_topk_logprobs = allgather_cp_sharded_tensor( - student_topk_logprobs, cp_group, seq_dim=1 - ) - if self.kl_type != "forward": - H_all = allgather_cp_sharded_tensor(H_all, cp_group, seq_dim=1) - if pad_len > 0: - student_topk_logprobs = student_topk_logprobs[:, :-pad_len, :] - if self.kl_type != "forward": - H_all = H_all[:, :-pad_len] - elif self.zero_outside_topk: - # Non-distributed processing - student_logprobs = torch.nn.functional.log_softmax(logits_tensor, dim=-1) - student_topk_logprobs = student_logprobs.gather( - dim=-1, index=teacher_topk_indices.to(student_logprobs.device) - ) - if self.kl_type != "forward": - H_all = (student_logprobs.exp() * student_logprobs).sum(-1) - else: - # Gather logits at global indices - if (parallel_group is not None) or (cp_size > 1): - student_topk_logits = gather_logits_at_global_indices( - logits_tensor, - teacher_topk_indices, - tp_group=parallel_group, - cp_group=cp_group, - vocab_start_index=( - vocab_start_index if parallel_group is not None else 0 - ), - vocab_end_index=( - vocab_end_index - if parallel_group is not None - else int(logits_tensor.shape[-1]) - ), - ) - else: - student_topk_logits = logits_tensor.gather( - dim=-1, index=teacher_topk_indices.to(logits_tensor.device) - ) - student_topk_logprobs = torch.nn.functional.log_softmax( - student_topk_logits, dim=-1 - ) - - # Move teacher tensors to the same device/dtype as student_topk_logits - teacher_topk_logits = teacher_topk_logits.to( - student_topk_logprobs.device, dtype=student_topk_logprobs.dtype - ) - teacher_topk_logprobs = torch.nn.functional.log_softmax( - teacher_topk_logits, dim=-1 - ) - - # Single point of next-token alignment after TP/CP processing - teacher_topk_logprobs = teacher_topk_logprobs[:, :-1, :] - student_topk_logprobs = student_topk_logprobs[:, :-1, :] - if self.zero_outside_topk and self.kl_type != "forward": - # Align H_all with next-token prediction - H_all = H_all[:, :-1] - student_probs = student_topk_logprobs.exp() # [B, S-1, k] teacher_probs = teacher_topk_logprobs.exp() # [B, S-1, k] @@ -1328,7 +956,7 @@ def __call__( metrics = { "loss": float(kl_loss.item()) if kl_loss.ndim == 0 else kl_loss, - "num_valid_samples": int(batch_size), + "num_valid_samples": data["input_ids"].shape[0], } return kl_loss, metrics diff --git a/nemo_rl/algorithms/loss/utils.py b/nemo_rl/algorithms/loss/utils.py new file mode 100644 index 0000000000..359641ae09 --- /dev/null +++ b/nemo_rl/algorithms/loss/utils.py @@ -0,0 +1,89 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional + +import torch + +from nemo_rl.algorithms.loss.interfaces import LossFunction, LossInputType +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.model_utils import ( + get_distillation_topk_logprobs_from_logits, + get_next_token_logprobs_from_logits, +) + + +def prepare_loss_input( + logits: torch.Tensor, + data: BatchedDataDict[Any], + loss_fn: LossFunction, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, +) -> dict[str, Any]: + """Prepare loss input for a loss function. + + Args: + logits: Logits from the model. + data: Microbatch data. + loss_fn: Loss function. + vocab_parallel_rank: Vocab parallel rank. + vocab_parallel_group: Vocab parallel group. + context_parallel_group: Context parallel group. + + vocab_parallel_rank, vocab_parallel_group, context_parallel_group are only used for megatron policy worker. + + Returns: + Loss input. + """ + if loss_fn.input_type == LossInputType.LOGIT: + loss_input = {"logits": logits} + + elif loss_fn.input_type == LossInputType.LOGPROB: + logprobs = get_next_token_logprobs_from_logits( + input_ids=data["input_ids"], + next_token_logits=logits, + seq_index=data.get("seq_index", None), + vocab_parallel_rank=vocab_parallel_rank, + vocab_parallel_group=vocab_parallel_group, + context_parallel_group=context_parallel_group, + ) + + loss_input = {"next_token_logprobs": logprobs} + + elif loss_fn.input_type == LossInputType.DISTILLATION: + calculate_entropy = loss_fn.zero_outside_topk and loss_fn.kl_type != "forward" + student_topk_logprobs, teacher_topk_logprobs, H_all = ( + get_distillation_topk_logprobs_from_logits( + student_logits=logits, + teacher_topk_logits=data["teacher_topk_logits"], + teacher_topk_indices=data["teacher_topk_indices"], + zero_outside_topk=loss_fn.zero_outside_topk, + calculate_entropy=calculate_entropy, + vocab_parallel_rank=vocab_parallel_rank, + vocab_parallel_group=vocab_parallel_group, + context_parallel_group=context_parallel_group, + ) + ) + + loss_input = { + "student_topk_logprobs": student_topk_logprobs, + "teacher_topk_logprobs": teacher_topk_logprobs, + "H_all": H_all, + } + + else: + raise ValueError(f"Unknown loss function input type: {loss_fn.input_type}") + + return loss_input diff --git a/nemo_rl/algorithms/loss/wrapper.py b/nemo_rl/algorithms/loss/wrapper.py new file mode 100644 index 0000000000..39e8b12814 --- /dev/null +++ b/nemo_rl/algorithms/loss/wrapper.py @@ -0,0 +1,185 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any, Callable, Optional, TypeVar + +import torch +import torch.distributed + +from nemo_rl.algorithms.loss.interfaces import LossFunction +from nemo_rl.distributed.batched_data_dict import BatchedDataDict + +Tensor = TypeVar("Tensor", bound=torch.Tensor) + + +class SequencePackingLossWrapper: + def __init__( + self, + loss_fn: LossFunction, + prepare_fn: Callable[Any, Any], + cu_seqlens_q: Tensor, + cu_seqlens_q_padded: Optional[Tensor] = None, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + ): + """Wrap a loss function to handle sequence packing. + + Args: + loss_fn: Loss function. + prepare_fn: Prepare function. + cu_seqlens_q: Unpadded cu seqlens q. + cu_seqlens_q_padded: Padded cu seqlens q. + vocab_parallel_rank: Vocab parallel rank. + vocab_parallel_group: Vocab parallel group. + context_parallel_group: Context parallel group. + + vocab_parallel_rank, vocab_parallel_group, context_parallel_group are only used for megatron policy worker. + + Returns: + Sequence packing loss wrapper. + """ + self.loss_fn = loss_fn + self.prepare_fn = prepare_fn + self.cu_seqlens_q = cu_seqlens_q + self.cu_seqlens_q_padded = cu_seqlens_q_padded + self.vocab_parallel_rank = vocab_parallel_rank + self.vocab_parallel_group = vocab_parallel_group + self.context_parallel_group = context_parallel_group + + def __call__( + self, + next_token_logits: Tensor, + data: BatchedDataDict[Any], + global_valid_seqs: Tensor | None, + global_valid_toks: Tensor | None, + ) -> tuple[Tensor, dict[str, Any]]: + """Wraps a loss function to handle sequence packing by doing one sequence at a time to avoid excessive padding.""" + unpadded_cu_seqlens = self.cu_seqlens_q + unpadded_seq_lengths = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1] + if self.cu_seqlens_q_padded is not None: + padded_cu_seqlens = self.cu_seqlens_q_padded + padded_seq_lengths = ( + self.cu_seqlens_q_padded[1:] - self.cu_seqlens_q_padded[:-1] + ) + else: + padded_cu_seqlens = unpadded_cu_seqlens + padded_seq_lengths = unpadded_seq_lengths + seq_starts = padded_cu_seqlens[:-1] + seq_ends = padded_cu_seqlens[1:] + + loss_accum = 0 + metrics_accum = {} + for seq_idx in range(len(seq_starts)): + seq_start = seq_starts[seq_idx].item() + seq_end = seq_ends[seq_idx].item() + + # get sequence and unpad all 'data' tensors. The data dict is a BatchedDataDict of unpacked tensors + seq_data = data.slice(seq_idx, seq_idx + 1) + unpadded_seq_data = {} + for k, v in seq_data.items(): + if isinstance(v, torch.Tensor) and v.ndim > 1 and v.shape[1] > 1: + unpadded_seq_data[k] = v[:, : unpadded_seq_lengths[seq_idx]] + else: + unpadded_seq_data[k] = v + + # get next_token_logits + cp_size = ( + 1 + if self.context_parallel_group is None + else torch.distributed.get_world_size(self.context_parallel_group) + ) + logit_start = seq_start // cp_size + logit_end = (seq_start + padded_seq_lengths[seq_idx]) // cp_size + logit_length = logit_end - logit_start + next_token_logits_slice = next_token_logits.narrow( + 1, logit_start, logit_length + ) + + # prepare data for loss function + loss_input = self.prepare_fn( + logits=next_token_logits_slice, + data=unpadded_seq_data, + loss_fn=self.loss_fn, + vocab_parallel_rank=self.vocab_parallel_rank, + vocab_parallel_group=self.vocab_parallel_group, + context_parallel_group=self.context_parallel_group, + ) + + # call loss function + loss, metrics = self.loss_fn( + data=unpadded_seq_data, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + **loss_input, + ) + + # aggregate loss and metrics + loss_accum += loss + for k, v in metrics.items(): + if k not in metrics_accum: + if k in {"probs_ratio_min", "probs_ratio_clamped_min"}: + metrics_accum[k] = float("inf") + elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}: + metrics_accum[k] = float("-inf") + else: + metrics_accum[k] = 0 + + val = v.item() if isinstance(v, torch.Tensor) and v.ndim == 0 else v + + # Skip inf/-inf sentinel values (from sequences with no valid tokens) + if k in {"probs_ratio_min", "probs_ratio_clamped_min"}: + if not math.isinf(val): + metrics_accum[k] = min(metrics_accum[k], val) + elif k in {"probs_ratio_max", "probs_ratio_clamped_max"}: + if not math.isinf(val): + metrics_accum[k] = max(metrics_accum[k], val) + else: + metrics_accum[k] += val + + return loss_accum, metrics_accum + + +def wrap_loss_fn_with_input_preparation( + next_token_logits: Tensor, + data: BatchedDataDict[Any], + global_valid_seqs: Tensor | None, + global_valid_toks: Tensor | None, + loss_fn: LossFunction, + prepare_fn: Callable[Any, Any], + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, +) -> tuple[Tensor, dict[str, Any]]: + """Wraps a loss function to handle input preparation for megatron policy worker.""" + # prepare loss input + loss_input = prepare_fn( + logits=next_token_logits, + data=data, + loss_fn=loss_fn, + vocab_parallel_rank=vocab_parallel_rank, + vocab_parallel_group=vocab_parallel_group, + context_parallel_group=context_parallel_group, + ) + + # call loss function + loss, loss_metrics = loss_fn( + data=data, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + **loss_input, + ) + + return loss, loss_metrics diff --git a/nemo_rl/algorithms/rm.py b/nemo_rl/algorithms/rm.py index 2d7d4c936a..8787888777 100644 --- a/nemo_rl/algorithms/rm.py +++ b/nemo_rl/algorithms/rm.py @@ -23,9 +23,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader from transformers import AutoTokenizer -from nemo_rl.algorithms.loss_functions import ( - PreferenceLoss, -) +from nemo_rl.algorithms.loss import PreferenceLossFn from nemo_rl.algorithms.utils import maybe_pad_last_batch, set_seed from nemo_rl.data import DataConfig from nemo_rl.data.collate_fn import preference_collate_fn @@ -103,7 +101,7 @@ def setup( RayVirtualCluster, StatefulDataLoader, dict[str, StatefulDataLoader], - PreferenceLoss, + PreferenceLossFn, MasterConfig, Logger, TaskDataSpec, @@ -229,7 +227,7 @@ def setup( # print the node IP and GPU ID of the policy workers for debugging policy.print_node_ip_and_gpu_id() - loss_fn = PreferenceLoss() + loss_fn = PreferenceLossFn() print(" ✓ Model initialized") print("\n" + "=" * 60) diff --git a/nemo_rl/algorithms/sft.py b/nemo_rl/algorithms/sft.py index dcd7b9d025..a08c76022c 100644 --- a/nemo_rl/algorithms/sft.py +++ b/nemo_rl/algorithms/sft.py @@ -21,9 +21,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader from transformers import AutoTokenizer, PreTrainedTokenizerBase -from nemo_rl.algorithms.loss_functions import ( - NLLLoss, -) +from nemo_rl.algorithms.loss import NLLLossFn from nemo_rl.algorithms.utils import maybe_pad_last_batch, set_seed from nemo_rl.data import DataConfig from nemo_rl.data.collate_fn import rl_collate_fn @@ -98,7 +96,7 @@ def setup( RayVirtualCluster, StatefulDataLoader, Optional[StatefulDataLoader], - NLLLoss, + NLLLossFn, Logger, CheckpointManager, SFTSaveState, @@ -210,7 +208,7 @@ def setup( # print the node IP and GPU ID of the policy workers for debugging policy.print_node_ip_and_gpu_id() - loss_fn = NLLLoss() + loss_fn = NLLLossFn() print(" ✓ Model initialized") print("\n" + "=" * 60) diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index fb17ee1661..3eeff238e6 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -825,6 +825,50 @@ def get_logprobs_from_vocab_parallel_logits( ) +def get_next_token_logprobs_from_logits( + input_ids: torch.Tensor, + next_token_logits: torch.Tensor, + seq_index: Optional[torch.Tensor] = None, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, +): + """Computes next token log probabilities from logits.""" + next_token_logits = next_token_logits.to(torch.float32) + + if vocab_parallel_group is not None: + assert vocab_parallel_rank is not None, ( + "vocab_parallel_rank must be provided when vocab_parallel_group is provided" + ) + logprobs = from_parallel_logits_to_logprobs( + next_token_logits, + input_ids, + vocab_start_index=vocab_parallel_rank * next_token_logits.shape[-1], + vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1], + tp_group=vocab_parallel_group, + inference_only=False, + cp_group=context_parallel_group, + ) + # slice off to the correct length to remove potential CP padding + logprobs = logprobs[:, : input_ids.shape[1] - 1] + elif isinstance(next_token_logits, torch.distributed.tensor.DTensor): + logprobs = get_logprobs_from_vocab_parallel_logits( + next_token_logits, input_ids, seq_index=seq_index + ) + else: + # Remove last position's logits + next_token_logits_wo_last = next_token_logits[:, :-1] + next_token_logprobs = torch.nn.functional.log_softmax( + next_token_logits_wo_last, dim=-1 + ) + next_tokens = input_ids[:, 1:].cuda() # Skip first token + logprobs = next_token_logprobs.gather( + dim=-1, index=next_tokens.unsqueeze(-1) + ).squeeze(-1) + + return logprobs + + @torch.no_grad() def distributed_vocab_topk( vocab_parallel_logits: torch.Tensor, @@ -982,6 +1026,173 @@ def gather_logits_at_global_indices( return gathered_logits +def get_distillation_topk_logprobs_from_logits( + student_logits: torch.Tensor, + teacher_topk_logits: torch.Tensor, + teacher_topk_indices: torch.Tensor, + zero_outside_topk: bool, + calculate_entropy: bool, + vocab_parallel_rank: Optional[int] = None, + vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, + context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, +): + """Compute top-k log probabilities from logits.""" + if teacher_topk_indices.shape[-1] <= 0: + raise ValueError( + f"topk must be positive, got {teacher_topk_indices.shape[-1]}. " + "topk=0 is not supported as it would result in empty tensor operations." + ) + + # Ensure float32 for stability + student_logits = student_logits.to(torch.float32) + # Move teacher topk indices to the same device as student logits + teacher_topk_indices = teacher_topk_indices.to(student_logits.device) + + # CP support: get CP group and size + cp_group = context_parallel_group + cp_size = 1 if cp_group is None else torch.distributed.get_world_size(cp_group) + + # Process based on the student logits type + if vocab_parallel_group is not None: + assert vocab_parallel_rank is not None, ( + "vocab_parallel_rank must be provided when vocab_parallel_group is provided" + ) + student_logits = student_logits + parallel_group = vocab_parallel_group + + V_local = int(student_logits.shape[-1]) + vocab_start_index = vocab_parallel_rank * V_local + vocab_end_index = (vocab_parallel_rank + 1) * V_local + + elif isinstance(student_logits, torch.distributed.tensor.DTensor): + device_mesh = student_logits.device_mesh + tp_group = device_mesh.get_group("tp") + + student_logits = student_logits.to_local() + parallel_group = tp_group + + tp_rank = tp_group.rank() + V_local = int(student_logits.shape[-1]) + vocab_start_index = tp_rank * V_local + vocab_end_index = (tp_rank + 1) * V_local + + # For DTensor, derive CP group/size from the device mesh to ensure CP-aware alignment + if ( + device_mesh.mesh_dim_names is not None + and "cp" in device_mesh.mesh_dim_names + ): + cp_group = device_mesh.get_group("cp") + cp_size = cp_group.size() + else: + cp_group = None + cp_size = 1 + + else: + student_logits = student_logits + parallel_group = None + + # Process based on the zero_outside_topk setting + H_all = None + if zero_outside_topk: + # Distributed processing + if parallel_group is not None: + indices_local = teacher_topk_indices + pad_len = 0 + + if cp_size > 1: + pad_len = student_logits.shape[1] * cp_size - indices_local.shape[1] + if pad_len > 0: + indices_local = torch.nn.functional.pad( + indices_local, (0, 0, 0, pad_len), value=0 + ) + cp_rank = torch.distributed.get_rank(cp_group) + indices_local = _get_tokens_on_this_cp_rank( + indices_local, cp_rank, cp_size, seq_dim=1 + ) + + seq_len_local = int(student_logits.shape[1]) + chunk_size = max(1, min(seq_len_local, 1024)) + student_topk_logprobs = ChunkedDistributedGatherLogprob.apply( # type: ignore + student_logits, + indices_local, + vocab_start_index, + vocab_end_index, + chunk_size, + parallel_group, + False, + ) + + if calculate_entropy: + H_all = ChunkedDistributedEntropy.apply( # type: ignore + student_logits, + chunk_size, + parallel_group, + False, + ) + + if cp_size > 1: + student_topk_logprobs = allgather_cp_sharded_tensor( + student_topk_logprobs, cp_group, seq_dim=1 + ) + if calculate_entropy: + H_all = allgather_cp_sharded_tensor(H_all, cp_group, seq_dim=1) + if pad_len > 0: + student_topk_logprobs = student_topk_logprobs[:, :-pad_len, :] + if calculate_entropy: + H_all = H_all[:, :-pad_len] + + # Non-distributed processing + else: + student_logprobs = torch.nn.functional.log_softmax(student_logits, dim=-1) + student_topk_logprobs = student_logprobs.gather( + dim=-1, index=teacher_topk_indices + ) + + if calculate_entropy: + H_all = (student_logprobs.exp() * student_logprobs).sum(-1) + + else: + # Distributed processing + if parallel_group is not None or cp_size > 1: + if parallel_group is None: + vocab_start_index = 0 + vocab_end_index = int(student_logits.shape[-1]) + + student_topk_logits = gather_logits_at_global_indices( + student_logits, + teacher_topk_indices, + tp_group=parallel_group, + cp_group=cp_group, + vocab_start_index=vocab_start_index, + vocab_end_index=vocab_end_index, + ) + + # Non-distributed processing + else: + student_topk_logits = student_logits.gather( + dim=-1, index=teacher_topk_indices + ) + + student_topk_logprobs = torch.nn.functional.log_softmax( + student_topk_logits, dim=-1 + ) + + # Move teacher tensors to the same device/dtype as student_topk_logits + teacher_topk_logits = teacher_topk_logits.to( + student_topk_logprobs.device, dtype=student_topk_logprobs.dtype + ) + teacher_topk_logprobs = torch.nn.functional.log_softmax(teacher_topk_logits, dim=-1) + + # Single point of next-token alignment after TP/CP processing + teacher_topk_logprobs = teacher_topk_logprobs[:, :-1, :] + student_topk_logprobs = student_topk_logprobs[:, :-1, :] + + if calculate_entropy: + H_all = H_all[:, :-1] + + return student_topk_logprobs, teacher_topk_logprobs, H_all + + class ChunkedDistributedEntropy(torch.autograd.Function): """Compute H_all = sum_v p_v log p_v across TP with chunking over sequence. diff --git a/nemo_rl/models/automodel/data.py b/nemo_rl/models/automodel/data.py index 3ffbbc4d0a..1004542284 100644 --- a/nemo_rl/models/automodel/data.py +++ b/nemo_rl/models/automodel/data.py @@ -21,7 +21,7 @@ import torch from transformers import AutoTokenizer -from nemo_rl.algorithms.interfaces import LossFunction, LossType +from nemo_rl.algorithms.loss.interfaces import LossFunction, LossType from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.models.huggingface.common import ( get_flash_attention_kwargs, diff --git a/nemo_rl/models/automodel/train.py b/nemo_rl/models/automodel/train.py index acbfec711e..32e3386ab6 100644 --- a/nemo_rl/models/automodel/train.py +++ b/nemo_rl/models/automodel/train.py @@ -31,8 +31,8 @@ from torch import nn from torch.distributed.tensor import DTensor, Shard -from nemo_rl.algorithms.interfaces import LossFunction -from nemo_rl.algorithms.loss_functions import SequencePackingLossWrapper +from nemo_rl.algorithms.loss import SequencePackingLossWrapper, prepare_loss_input +from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import ( allgather_cp_sharded_tensor, @@ -473,8 +473,8 @@ def __init__( dp_size: Data parallel size enable_seq_packing: Whether sequence packing is enabled """ - self.loss_fn = loss_fn - self.cfg = cfg + self.loss_fn: LossFunction = loss_fn + self.cfg: PolicyConfig = cfg self.device_mesh = device_mesh self.cp_mesh = cp_mesh self.tp_mesh = tp_mesh @@ -515,20 +515,26 @@ def __call__( # Wrap loss function for sequence packing if needed if self.enable_seq_packing: - loss_fn_ = SequencePackingLossWrapper( + loss_fn = SequencePackingLossWrapper( loss_fn=self.loss_fn, + prepare_fn=prepare_loss_input, cu_seqlens_q=processed_inputs.flash_attn_kwargs.cu_seqlens_q, cu_seqlens_q_padded=processed_inputs.flash_attn_kwargs.cu_seqlens_q, ) + loss, loss_metrics = loss_fn( + logits, + mb, + global_valid_seqs, + global_valid_toks, + ) else: - loss_fn_ = self.loss_fn - - loss, loss_metrics = loss_fn_( - logits, - mb, - global_valid_seqs, - global_valid_toks, - ) + loss_input = prepare_loss_input(logits, mb, self.loss_fn) + loss, loss_metrics = self.loss_fn( + data=mb, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + **loss_input, + ) return loss, loss_metrics diff --git a/nemo_rl/models/megatron/data.py b/nemo_rl/models/megatron/data.py index 7c765f19b5..13daee1352 100644 --- a/nemo_rl/models/megatron/data.py +++ b/nemo_rl/models/megatron/data.py @@ -25,7 +25,7 @@ from megatron.core.utils import StragglerDetector from megatron.training.utils import get_ltor_masks_and_position_ids -from nemo_rl.algorithms.interfaces import LossFunction, LossType +from nemo_rl.algorithms.loss.interfaces import LossFunction, LossType from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import _get_tokens_on_this_cp_rank from nemo_rl.models.megatron.common import _round_up_to_multiple diff --git a/nemo_rl/models/megatron/train.py b/nemo_rl/models/megatron/train.py index 8459eada93..5535c9025c 100644 --- a/nemo_rl/models/megatron/train.py +++ b/nemo_rl/models/megatron/train.py @@ -29,7 +29,12 @@ from megatron.core.pipeline_parallel import get_forward_backward_func from megatron.core.utils import StragglerDetector -from nemo_rl.algorithms.loss_functions import LossFunction, SequencePackingLossWrapper +from nemo_rl.algorithms.loss import ( + SequencePackingLossWrapper, + prepare_loss_input, + wrap_loss_fn_with_input_preparation, +) +from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import ( allgather_cp_sharded_tensor, @@ -302,24 +307,33 @@ def __call__( Returns: Callable: Function that takes output tensor and returns (loss, metrics) tuple """ - loss_fn = self.loss_fn + # wrap loss function with loss input preparation pack_sequences = self.cfg["sequence_packing"]["enabled"] if pack_sequences and packed_seq_params is not None: - # remove padding - loss_fn = SequencePackingLossWrapper( - loss_fn=loss_fn, + loss_fn_wrapped = SequencePackingLossWrapper( + loss_fn=self.loss_fn, + prepare_fn=prepare_loss_input, cu_seqlens_q=packed_seq_params.cu_seqlens_q, cu_seqlens_q_padded=packed_seq_params.cu_seqlens_q_padded, + vocab_parallel_rank=get_tensor_model_parallel_rank(), + vocab_parallel_group=get_tensor_model_parallel_group(), + context_parallel_group=get_context_parallel_group(), + ) + else: + loss_fn_wrapped = partial( + wrap_loss_fn_with_input_preparation, + loss_fn=self.loss_fn, + prepare_fn=prepare_loss_input, + vocab_parallel_rank=get_tensor_model_parallel_rank(), + vocab_parallel_group=get_tensor_model_parallel_group(), + context_parallel_group=get_context_parallel_group(), ) loss_fn_wrapped = partial( - loss_fn, + loss_fn_wrapped, data=data_dict, global_valid_seqs=global_valid_seqs, global_valid_toks=global_valid_toks, - vocab_parallel_rank=get_tensor_model_parallel_rank(), - vocab_parallel_group=get_tensor_model_parallel_group(), - context_parallel_group=get_context_parallel_group(), ) if self.cp_normalize: diff --git a/nemo_rl/models/policy/interfaces.py b/nemo_rl/models/policy/interfaces.py index 464377c57a..f6facfc748 100644 --- a/nemo_rl/models/policy/interfaces.py +++ b/nemo_rl/models/policy/interfaces.py @@ -17,7 +17,7 @@ import ray import torch -from nemo_rl.algorithms.interfaces import LossFunction +from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.models.generation.interfaces import GenerationDatumSpec from nemo_rl.utils.timer import Timer diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index 29f034b065..20864b6d24 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -23,7 +23,7 @@ from ray.util.queue import Queue as RayQueue from transformers import AutoProcessor, PreTrainedTokenizerBase -from nemo_rl.algorithms.interfaces import LossFunction +from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.distributed.batched_data_dict import ( BatchedDataDict, DynamicBatchingArgs, diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker.py b/nemo_rl/models/policy/workers/dtensor_policy_worker.py index 6028506f92..661254da23 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker.py @@ -46,8 +46,8 @@ ) from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM -from nemo_rl.algorithms.interfaces import LossFunction, LossType -from nemo_rl.algorithms.loss_functions import SequencePackingLossWrapper +from nemo_rl.algorithms.loss import SequencePackingLossWrapper, prepare_loss_input +from nemo_rl.algorithms.loss.interfaces import LossFunction, LossType from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.model_utils import ( allgather_cp_sharded_tensor, @@ -776,20 +776,28 @@ def train( placements=[Shard(sequence_dim), Shard(-1)], ) + # Wrap loss function for sequence packing if needed if self.enable_seq_packing: loss_fn_ = SequencePackingLossWrapper( loss_fn=loss_fn, + prepare_fn=prepare_loss_input, cu_seqlens_q=flash_attn_kwargs.cu_seqlens_q, cu_seqlens_q_padded=flash_attn_kwargs.cu_seqlens_q, ) + loss, loss_metrics = loss_fn_( + logits, + mb, + global_valid_seqs, + global_valid_toks, + ) else: - loss_fn_ = loss_fn - loss, loss_metrics = loss_fn_( - logits, - mb, - global_valid_seqs, - global_valid_toks, - ) + loss_input = prepare_loss_input(logits, mb, loss_fn) + loss, loss_metrics = loss_fn( + data=mb, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + **loss_input, + ) del logits # skip the update for dummy batches diff --git a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py index 4eb730e5a0..f5adf24fac 100644 --- a/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py +++ b/nemo_rl/models/policy/workers/dtensor_policy_worker_v2.py @@ -39,7 +39,7 @@ AutoTokenizer, ) -from nemo_rl.algorithms.interfaces import LossFunction +from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.models.automodel.data import ( check_sequence_dim, diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 5a6a683765..50fde65eec 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -48,7 +48,7 @@ from megatron.core.rerun_state_machine import get_rerun_state_machine from transformers import PreTrainedTokenizerBase -from nemo_rl.algorithms.interfaces import LossFunction +from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.named_sharding import NamedSharding from nemo_rl.models.generation.interfaces import ( diff --git a/pyrefly.toml b/pyrefly.toml index ac3cd167ed..4c4ae33fa2 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -39,7 +39,9 @@ project-includes = [ "examples/custom_parallel/llama_nemotron_super_49b_custom_plan.py", "nemo_rl/algorithms/__init__.py", "nemo_rl/algorithms/advantage_estimator.py", - "nemo_rl/algorithms/interfaces.py", + "nemo_rl/algorithms/loss/__init__.py", + "nemo_rl/algorithms/loss/interfaces.py", + "nemo_rl/algorithms/loss/utils.py", "nemo_rl/algorithms/reward_functions.py", "nemo_rl/algorithms/utils.py", "nemo_rl/data/__init__.py", diff --git a/research/template_project/single_update.py b/research/template_project/single_update.py index 598744ab99..43cb51c66c 100644 --- a/research/template_project/single_update.py +++ b/research/template_project/single_update.py @@ -17,7 +17,7 @@ 1) Sets up a RayVirtualCluster 2) Initializes VllmGeneration 3) Initializes LM Policy - 4) Trains on a tiny synthetic batch (global batch size = 2) with NLLLoss + 4) Trains on a tiny synthetic batch (global batch size = 2) with NLLLossFn 5) Refits the generation engine with the latest policy weights 6) Optionally repeats the train→refit cycle in a short loop @@ -34,7 +34,7 @@ from template_project.data_utils import create_batch_from from nemo_rl.algorithms.grpo import MasterConfig, refit_policy_generation -from nemo_rl.algorithms.loss_functions import NLLLoss +from nemo_rl.algorithms.loss import NLLLossFn from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.virtual_cluster import RayVirtualCluster, init_ray @@ -95,8 +95,8 @@ def main(config: MasterConfig) -> None: state_dict_info = policy.prepare_refit_info() policy_generation.prepare_refit_info(state_dict_info or {}) - # 4) Create tiny numeric batch and train with NLLLoss - print("\n▶ Creating tiny numeric batch and training with NLLLoss...") + # 4) Create tiny numeric batch and train with NLLLossFn + print("\n▶ Creating tiny numeric batch and training with NLLLossFn...") train_sentences = ["a b c d e hello", "a d f world"] * config["policy"][ "train_global_batch_size" ] @@ -116,7 +116,7 @@ def main(config: MasterConfig) -> None: "What is the capital of the Nepal?", ] data = create_batch_from(tokenizer, sentences=train_sentences) - loss_fn = NLLLoss() + loss_fn = NLLLossFn() # Optionally repeat the train→refit cycle num_iters = int(os.environ.get("SINGLE_UPDATE_ITERS", "10")) diff --git a/research/template_project/template_project/data_utils.py b/research/template_project/template_project/data_utils.py index 8f76d58715..0670e917d7 100644 --- a/research/template_project/template_project/data_utils.py +++ b/research/template_project/template_project/data_utils.py @@ -34,7 +34,7 @@ def create_batch_from(tokenizer, sentences: list[str]) -> BatchedDataDict: sample_mask = torch.ones(input_ids.size(0), dtype=torch.float32) # For simple NLL training, use the attention mask as token_mask - # (loss will be applied to positions 1..len-1 via NLLLoss) + # (loss will be applied to positions 1..len-1 via NLLLossFn) token_mask = torch.ones_like(input_ids) return BatchedDataDict( diff --git a/tests/unit/algorithms/sequence_packing_gradient_actor.py b/tests/unit/algorithms/sequence_packing_gradient_actor.py index 20564d77af..a5e750d358 100644 --- a/tests/unit/algorithms/sequence_packing_gradient_actor.py +++ b/tests/unit/algorithms/sequence_packing_gradient_actor.py @@ -23,10 +23,8 @@ import ray import torch -from nemo_rl.algorithms.loss_functions import ( - ClippedPGLossFn, - SequencePackingLossWrapper, -) +from nemo_rl.algorithms.loss import ClippedPGLossFn, SequencePackingLossWrapper +from nemo_rl.algorithms.loss.utils import prepare_loss_input from nemo_rl.distributed.batched_data_dict import BatchedDataDict @@ -152,11 +150,12 @@ def test_sequence_packing_gradients(self): global_valid_seqs = torch.tensor(batch_size, dtype=torch.float, device="cuda") # Forward pass - baseline_loss, baseline_metrics = base_loss_fn( - baseline_logits, - data_dict, - global_valid_seqs, - global_valid_toks, + loss_input = prepare_loss_input(baseline_logits, data_dict, base_loss_fn) + baseline_loss, _ = base_loss_fn( + data=data_dict, + global_valid_seqs=global_valid_seqs, + global_valid_toks=global_valid_toks, + **loss_input, ) # Backward pass @@ -218,26 +217,26 @@ def make_packed_logits(logits): packed_logits = make_packed_logits(baseline_logits) # Create sequence packing wrapper + tp_group = torch.distributed.new_group(ranks=[rank]) wrapper = SequencePackingLossWrapper( loss_fn=base_loss_fn, + prepare_fn=prepare_loss_input, cu_seqlens_q=cu_seqlens, cu_seqlens_q_padded=cu_seqlens_padded, + vocab_parallel_rank=0, + vocab_parallel_group=tp_group, + context_parallel_group=cp_group, ) # Create data dict for packed sequences packed_data_dict = BatchedDataDict(original_data) - tp_group = torch.distributed.new_group(ranks=[rank]) - # Forward pass - packed_loss, packed_metrics = wrapper( + packed_loss, _ = wrapper( packed_logits, packed_data_dict, global_valid_seqs, global_valid_toks, - vocab_parallel_rank=0, - vocab_parallel_group=tp_group, - context_parallel_group=cp_group, ) # Backward pass diff --git a/tests/unit/algorithms/test_distillation.py b/tests/unit/algorithms/test_distillation.py index a0dfc19d69..1bcbe2e2bd 100644 --- a/tests/unit/algorithms/test_distillation.py +++ b/tests/unit/algorithms/test_distillation.py @@ -25,7 +25,7 @@ distillation_train, validate, ) -from nemo_rl.algorithms.loss_functions import DistillationLossFn +from nemo_rl.algorithms.loss import DistillationLossFn from nemo_rl.data.interfaces import DatumSpec from nemo_rl.distributed.batched_data_dict import BatchedDataDict diff --git a/tests/unit/algorithms/test_dpo.py b/tests/unit/algorithms/test_dpo.py index b2155ac91f..214bd90572 100644 --- a/tests/unit/algorithms/test_dpo.py +++ b/tests/unit/algorithms/test_dpo.py @@ -24,7 +24,7 @@ add_ref_logprobs_to_data, dpo_train, ) -from nemo_rl.algorithms.loss_functions import PreferenceLoss +from nemo_rl.algorithms.loss import PreferenceLossFn from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.named_sharding import NamedSharding @@ -169,7 +169,7 @@ def val_iter(self): tokenizer = MagicMock() tokenizer.pad_token_id = 0 - loss_fn = PreferenceLoss() + loss_fn = PreferenceLossFn() logger = MagicMock() checkpointer = MagicMock() diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py index 73a75fe64e..7a0783f132 100644 --- a/tests/unit/algorithms/test_grpo.py +++ b/tests/unit/algorithms/test_grpo.py @@ -31,7 +31,7 @@ grpo_train, validate, ) -from nemo_rl.algorithms.loss_functions import ClippedPGLossFn +from nemo_rl.algorithms.loss import ClippedPGLossFn from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.environments.interfaces import ( diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index fbec4c8504..84ee67eb93 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -17,12 +17,13 @@ import pytest import torch -from nemo_rl.algorithms.loss_functions import ( +from nemo_rl.algorithms.loss import ( ClippedPGLossConfig, ClippedPGLossFn, DistillationLossFn, DPOLossFn, - NLLLoss, + NLLLossFn, + prepare_loss_input, ) from nemo_rl.algorithms.utils import calculate_kl, masked_mean from nemo_rl.distributed.batched_data_dict import BatchedDataDict @@ -65,7 +66,7 @@ def test_nll_loss(): if not torch.cuda.is_available(): pytest.skip("No GPU available") - loss_fn = NLLLoss() + loss_fn = NLLLossFn() vocab_size = 8 data = { @@ -91,13 +92,14 @@ def test_nll_loss(): .unsqueeze(0) .to("cuda") ) + loss_input = prepare_loss_input(next_token_logits, data, loss_fn) loss, metrics_dict = loss_fn( - next_token_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["token_mask"] * data["sample_mask"].unsqueeze(-1) ), + **loss_input, ) torch.testing.assert_close(loss.cpu(), torch.tensor(0.0)) # Check the metrics dictionary contains the expected values @@ -116,16 +118,17 @@ def test_nll_loss(): .unsqueeze(0) .to("cuda") ) + loss_input = prepare_loss_input(next_token_logits, data, loss_fn) loss, metrics_dict = loss_fn( - next_token_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["token_mask"] * data["sample_mask"].unsqueeze(-1) ), + **loss_input, ) ## loss per token is 999, and we have two unmasked tokens - ## NLLLoss averages the loss over unmasked tokens + ## NLLLossFn averages the loss over unmasked tokens torch.testing.assert_close(loss.cpu(), torch.tensor(999.0)) assert metrics_dict["num_unmasked_tokens"] == 2 @@ -151,13 +154,14 @@ def test_dpo_loss(): } ) - loss, metrics_dict = loss_fn( - next_token_logits, - data, + loss_input = prepare_loss_input(next_token_logits, data, loss_fn) + loss, _ = loss_fn( + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( - data["sample_mask"].unsqueeze(-1) * data["token_mask"] + data["token_mask"] * data["sample_mask"].unsqueeze(-1) ), + **loss_input, ) ## chosen and rejected errors are the same, so difference between them is 0 @@ -173,6 +177,16 @@ def test_dpo_loss(): } ) + loss_input = prepare_loss_input(next_token_logits, data, loss_fn_with_sft) + loss_sft, _ = loss_fn_with_sft( + data=data, + global_valid_seqs=torch.sum(data["sample_mask"]), + global_valid_toks=torch.sum( + data["sample_mask"].unsqueeze(-1) * data["token_mask"] + ), + **loss_input, + ) + expected_sft_loss = ( -( torch.nn.functional.log_softmax(torch.tensor([[0.0] * vocab_size]), dim=-1)[ @@ -184,14 +198,7 @@ def test_dpo_loss(): ) expected_preference_loss = -torch.nn.functional.logsigmoid(torch.tensor(0.0)) assert torch.isclose( - loss_fn_with_sft( - next_token_logits, - data, - global_valid_seqs=torch.sum(data["sample_mask"]), - global_valid_toks=torch.sum( - data["sample_mask"].unsqueeze(-1) * data["token_mask"] - ), - )[0].cpu(), + loss_sft.cpu(), 0.5 * expected_sft_loss + expected_preference_loss, ) @@ -261,20 +268,25 @@ def test_dpo_loss_varying_sequence_lengths(): } ) - # Compute loss - loss, metrics = dpo_loss_fn_no_avg( - next_token_logits, - data, + # Compute no averaging loss + loss_input = prepare_loss_input(next_token_logits, data, dpo_loss_fn_no_avg) + _, metrics = dpo_loss_fn_no_avg( + data=data, global_valid_seqs=torch.sum(sample_mask), global_valid_toks=torch.sum(sample_mask.unsqueeze(-1) * token_mask), + **loss_input, ) - loss_avg, metrics_avg = dpo_loss_fn_avg( - next_token_logits, - data, + + # Compute averaging loss + loss_input = prepare_loss_input(next_token_logits, data, dpo_loss_fn_avg) + _, metrics_avg = dpo_loss_fn_avg( + data=data, global_valid_seqs=torch.sum(sample_mask), global_valid_toks=torch.sum(sample_mask.unsqueeze(-1) * token_mask), + **loss_input, ) + # Compute expected losses num_unmasked_tokens = token_mask[:, 1:][::2].sum().item() logprobs = torch.nn.functional.log_softmax(next_token_logits[:, 1:], dim=-1) token_logprobs = logprobs.gather( @@ -321,14 +333,15 @@ def test_dpo_sft_matches_nll_loss(): next_token_logits = torch.randn((batch_size * 2, 5, vocab_size)).to("cuda") # Compute NLL loss - nll_loss_fn = NLLLoss() - nll_loss, nll_metrics = nll_loss_fn( - next_token_logits[::2], - sft_data, + nll_loss_fn = NLLLossFn() + loss_input = prepare_loss_input(next_token_logits[::2], sft_data, nll_loss_fn) + nll_loss, _ = nll_loss_fn( + data=sft_data, global_valid_seqs=None, global_valid_toks=torch.sum( sft_data["sample_mask"].unsqueeze(-1) * torch.sum(sft_data["token_mask"]) ), + **loss_input, ) # Compute DPO loss with preference_loss_weight=0 @@ -341,13 +354,14 @@ def test_dpo_sft_matches_nll_loss(): "sft_average_log_probs": False, } ) - dpo_loss, dpo_metrics = dpo_loss_fn( - next_token_logits, - dpo_data, + loss_input = prepare_loss_input(next_token_logits, dpo_data, dpo_loss_fn) + dpo_loss, _ = dpo_loss_fn( + data=dpo_data, global_valid_seqs=torch.sum(dpo_data["sample_mask"]), global_valid_toks=torch.sum( dpo_data["sample_mask"].unsqueeze(-1) * dpo_data["token_mask"] ), + **loss_input, ) # Verify losses match @@ -504,12 +518,13 @@ def test_clipped_pg_loss_ppo_clipping(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( - dummy_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + **loss_input, ) torch.testing.assert_close(actual_loss, expected_loss) @@ -551,14 +566,15 @@ def test_clipped_pg_loss_reinforce_mode(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( - dummy_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) torch.testing.assert_close(actual_loss, expected_loss) @@ -596,14 +612,15 @@ def test_clipped_pg_loss_force_on_policy_ratio(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, metrics = loss_fn( - dummy_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) # Loss should match the on-policy expectation @@ -706,14 +723,15 @@ def test_clipped_pg_loss_kl_penalty(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( - dummy_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) torch.testing.assert_close(actual_loss, expected_loss) @@ -734,6 +752,7 @@ def test_clipped_pg_loss_masking(): ) # Need some realistic-ish logits and logprobs for masking test dummy_logits = torch.randn(batch_size, seq_len, vocab_size, device=device) + # Ensure logprobs used by the loss fn make sense relative to advantages data["prev_logprobs"] = torch.randn_like(data["prev_logprobs"]) * 0.1 data["reference_policy_logprobs"] = ( @@ -745,16 +764,17 @@ def test_clipped_pg_loss_masking(): cfg = deepcopy(basic_pg_loss_test_config) cfg["reference_policy_kl_penalty"] = 0.1 loss_fn = ClippedPGLossFn(cfg) # Use original loss fn + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) # --- Test 1: Token Mask --- # Default mask: [[0, 1, 1, 1], [0, 1, 1, 1]] -> 3 tokens per sample loss_default, _ = loss_fn( - dummy_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) # Modify token_mask for batch item 0 to mask one more token (pos 1) @@ -765,12 +785,12 @@ def test_clipped_pg_loss_masking(): ) loss_token_masked, _ = loss_fn( - dummy_logits, - data_mod_token, + data=data_mod_token, global_valid_seqs=torch.sum(data_mod_token["sample_mask"]), global_valid_toks=torch.sum( data_mod_token["sample_mask"].unsqueeze(-1) * data_mod_token["token_mask"] ), + **loss_input, ) # Loss should change if a potentially contributing token is masked assert not torch.isclose(loss_default, loss_token_masked, atol=1e-4), ( @@ -784,12 +804,12 @@ def test_clipped_pg_loss_masking(): ) # Ignore item 1 loss_sample_masked, _ = loss_fn( - dummy_logits, - data_mod_sample, + data=data_mod_sample, global_valid_seqs=torch.sum(data_mod_sample["sample_mask"]), global_valid_toks=torch.sum( data_mod_sample["sample_mask"].unsqueeze(-1) * data_mod_sample["token_mask"] ), + **loss_input, ) # Manually create data dict for only batch 0 @@ -805,13 +825,14 @@ def test_clipped_pg_loss_masking(): data_only_b0 = BatchedDataDict(data_only_b0_dict) logits_only_b0 = dummy_logits[0:1] + loss_input = prepare_loss_input(logits_only_b0, data_only_b0, loss_fn) loss_only_b0, _ = loss_fn( - logits_only_b0, - data_only_b0, + data=data_only_b0, global_valid_seqs=torch.sum(data_only_b0["sample_mask"]), global_valid_toks=torch.sum( data_only_b0["sample_mask"].unsqueeze(-1) * data_only_b0["token_mask"] ), + **loss_input, ) torch.testing.assert_close(loss_sample_masked, loss_only_b0) @@ -830,17 +851,18 @@ def test_clipped_pg_loss_zero_mask(): cfg = deepcopy(basic_pg_loss_test_config) cfg["reference_policy_kl_penalty"] = 0.1 loss_fn = ClippedPGLossFn(cfg) # Use original loss fn + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) # Set token mask to all zeros data["token_mask"] = torch.zeros_like(data["token_mask"]) loss, _ = loss_fn( - dummy_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) # Loss should be exactly zero @@ -980,12 +1002,13 @@ def test_clipped_pg_loss_on_policy_kl_importance_sampling(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( - dummy_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + **loss_input, ) torch.testing.assert_close(actual_loss, expected_total_loss, atol=1e-4, rtol=1e-3) @@ -1112,12 +1135,13 @@ def test_clipped_pg_loss_on_policy_truncated_importance_sampling( dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( - dummy_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + **loss_input, ) torch.testing.assert_close(actual_loss, expected_loss, atol=1e-4, rtol=1e-3) @@ -1159,11 +1183,12 @@ def test_clipped_pg_loss_icepop_importance_sampling(): dummy_logits = _create_exact_logits( prev_lp, data["input_ids"], batch_size, seq_len, vocab_size, device ) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( - dummy_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + **loss_input, ) torch.testing.assert_close(actual_loss, expected_loss, atol=1e-4, rtol=1e-3) @@ -1202,21 +1227,22 @@ def test_clipped_pg_loss_seq_mask_tis(): dummy_logits = _create_exact_logits( prev_lp, data["input_ids"], batch_size, seq_len, vocab_size, device ) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( - dummy_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + **loss_input, ) torch.testing.assert_close(actual_loss, expected_loss, atol=1e-4, rtol=1e-3) # nan_to_num: inject -inf → loss must stay finite data["generation_logprobs"][0, 2] = float("-inf") actual_loss2, _ = loss_fn( - dummy_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + **loss_input, ) assert not torch.isnan(actual_loss2), "Loss is NaN — nan_to_num fix not working" assert not torch.isinf(actual_loss2), "Loss is inf — nan_to_num fix not working" @@ -1333,14 +1359,15 @@ def test_clipped_pg_loss_dual_clip(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( - dummy_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) torch.testing.assert_close(actual_loss, expected_loss) @@ -1382,11 +1409,13 @@ def test_clipped_pg_loss_entropy(): dummy_logits = _create_exact_logits( curr_lp_masked, data["input_ids"], batch_size, seq_len, vocab_size, device ) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) + _, metrics = loss_fn( - dummy_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + **loss_input, ) torch.testing.assert_close( @@ -1465,12 +1494,13 @@ def test_clipped_pg_loss_gspo(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( - dummy_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + **loss_input, ) torch.testing.assert_close(actual_loss, expected_loss) @@ -1563,14 +1593,15 @@ def test_clipped_pg_loss_gspo_batch_size_2(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( - dummy_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(1) * data["token_mask"] ), + **loss_input, ) torch.testing.assert_close(actual_loss, expected_loss) @@ -1664,12 +1695,13 @@ def test_clipped_pg_loss_gspo_importance_sampling_correction(): dummy_logits = _create_exact_logits( curr_lp_masked, input_ids, batch_size, seq_len, vocab_size, device ) + loss_input = prepare_loss_input(dummy_logits, data, loss_fn) actual_loss, _ = loss_fn( - dummy_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum(data["sample_mask"] * data["token_mask"]), + **loss_input, ) torch.testing.assert_close(actual_loss, expected_actor_loss, atol=1e-4, rtol=1e-3) @@ -1681,6 +1713,10 @@ def setup_distillation_test_data(batch_size=2, seq_len=4, vocab_size=8, topk=64) device = "cuda" + # Set seed for reproducibility + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) + # Create input data input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device) input_lengths = torch.tensor([seq_len] * batch_size, device=device) @@ -1708,138 +1744,88 @@ def setup_distillation_test_data(batch_size=2, seq_len=4, vocab_size=8, topk=64) return data, student_logits -def test_distillation_loss_forward_kl(): - """Test forward KL divergence loss calculation.""" +@pytest.mark.parametrize("kl_type", ["forward", "reverse", "mixed"]) +@pytest.mark.parametrize("zero_outside_topk", [True, False]) +def test_distillation_loss_different_settings(kl_type, zero_outside_topk): + """Test different distillation loss settings.""" data, student_logits = setup_distillation_test_data() loss_fn = DistillationLossFn( { - "kl_type": "forward", - "mixed_kl_weight": 0.5, - "zero_outside_topk": False, + "kl_type": kl_type, + "mixed_kl_weight": 0.3, + "zero_outside_topk": zero_outside_topk, } ) + loss_input = prepare_loss_input(student_logits, data, loss_fn) loss, metrics = loss_fn( - student_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) - # Verify loss is a scalar tensor - assert loss.dim() == 0 - assert not torch.isnan(loss) - assert not torch.isinf(loss) + # Verify loss + if zero_outside_topk: + if kl_type == "forward": + assert torch.allclose(loss, torch.tensor(-0.9636520743370056)) + elif kl_type == "reverse": + assert torch.allclose(loss, torch.tensor(-490.5150451660156)) + elif kl_type == "mixed": + assert torch.allclose(loss, torch.tensor(-343.6496276855469)) + else: + if kl_type == "forward": + assert torch.allclose(loss, torch.tensor(0.5783048868179321)) + elif kl_type == "reverse": + assert torch.allclose(loss, torch.tensor(0.5811167359352112)) + elif kl_type == "mixed": + assert torch.allclose(loss, torch.tensor(0.5802732110023499)) # Verify metrics dictionary assert isinstance(metrics, dict) assert "loss" in metrics -def test_distillation_loss_reverse_kl(): - """Test reverse KL divergence loss calculation.""" - data, student_logits = setup_distillation_test_data() +@pytest.mark.parametrize("k", [1, 32, 64, 1000000]) +@pytest.mark.parametrize("zero_outside_topk", [True, False]) +def test_distillation_loss_topk_filtering(k, zero_outside_topk): + """Test top-k filtering functionality with various k values.""" + data, student_logits = setup_distillation_test_data(topk=k) loss_fn = DistillationLossFn( { - "kl_type": "reverse", + "kl_type": "forward", "mixed_kl_weight": 0.5, - "zero_outside_topk": False, + "zero_outside_topk": zero_outside_topk, } ) - loss, metrics = loss_fn( - student_logits, - data, - global_valid_seqs=torch.sum(data["sample_mask"]), - global_valid_toks=torch.sum( - data["sample_mask"].unsqueeze(-1) * data["token_mask"] - ), - ) - - # Verify loss is a scalar tensor - assert loss.dim() == 0 - assert not torch.isnan(loss) - assert not torch.isinf(loss) - - # Verify metrics dictionary - assert isinstance(metrics, dict) - assert "loss" in metrics - - -def test_distillation_loss_mixed_kl(): - """Test mixed KL divergence loss calculation.""" - data, student_logits = setup_distillation_test_data() - - mixed_kl_weight = 0.3 - loss_fn = DistillationLossFn( - { - "kl_type": "mixed", - "mixed_kl_weight": mixed_kl_weight, - "zero_outside_topk": False, - } - ) - - loss, metrics = loss_fn( - student_logits, - data, + loss_input = prepare_loss_input(student_logits, data, loss_fn) + loss, _ = loss_fn( + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) - # Verify loss is a scalar tensor + # Verify loss is calculated correctly with top-k filtering assert loss.dim() == 0 assert not torch.isnan(loss) assert not torch.isinf(loss) - # Verify metrics dictionary - assert isinstance(metrics, dict) - assert "loss" in metrics - + # For k=1, we expect only the top-1 token to be considered + if k == 1: + assert isinstance(loss, torch.Tensor) -def test_distillation_loss_topk_filtering(): - """Test top-k filtering functionality with various k values.""" - # Test with different k values (excluding k=0 which should be invalid) - k_values = [1, 32, 64, 1000000] # Valid k values - - for k in k_values: - data, student_logits = setup_distillation_test_data(topk=k) - - loss_fn = DistillationLossFn( - { - "kl_type": "forward", - "mixed_kl_weight": 0.5, - "zero_outside_topk": False, - } - ) - - loss, metrics = loss_fn( - student_logits, - data, - global_valid_seqs=torch.sum(data["sample_mask"]), - global_valid_toks=torch.sum( - data["sample_mask"].unsqueeze(-1) * data["token_mask"] - ), - ) - - # Verify loss is calculated correctly with top-k filtering - assert loss.dim() == 0 - assert not torch.isnan(loss) - assert not torch.isinf(loss) - - # For k=1, we expect only the top-1 token to be considered - if k == 1: - assert isinstance(loss, torch.Tensor) - - # For large k values, we expect normal behavior - if k >= 32: - assert isinstance(loss, torch.Tensor) - assert loss.item() != 0.0 # Should have some meaningful loss + # For large k values, we expect normal behavior + if k >= 32: + assert isinstance(loss, torch.Tensor) + assert loss.item() != 0.0 # Should have some meaningful loss def test_distillation_loss_invalid_k_zero(): @@ -1857,54 +1843,7 @@ def test_distillation_loss_invalid_k_zero(): # This should raise a ValueError for k=0 with pytest.raises(ValueError, match="topk must be positive"): - loss_fn( - student_logits, - data, - global_valid_seqs=torch.sum(data["sample_mask"]), - global_valid_toks=torch.sum( - data["sample_mask"].unsqueeze(-1) * data["token_mask"] - ), - ) - - -def test_distillation_loss_zero_outside_topk(): - """Test zeroing outside top-k functionality with various k values.""" - # Test with different k values for zero_outside_topk (excluding k=0 which should be invalid) - k_values = [1, 32, 64, 1000000] # Valid k values - - for k in k_values: - data, student_logits = setup_distillation_test_data(topk=k) - - loss_fn = DistillationLossFn( - { - "kl_type": "forward", - "mixed_kl_weight": 0.5, - "zero_outside_topk": True, - } - ) - - loss, metrics = loss_fn( - student_logits, - data, - global_valid_seqs=torch.sum(data["sample_mask"]), - global_valid_toks=torch.sum( - data["sample_mask"].unsqueeze(-1) * data["token_mask"] - ), - ) - - # Verify loss is calculated correctly with zeroing - assert loss.dim() == 0 - assert not torch.isnan(loss) - assert not torch.isinf(loss) - - # For k=1, only top-1 token should remain non-zero - if k == 1: - assert isinstance(loss, torch.Tensor) - - # For large k values, most tokens should remain non-zero - if k >= 32: - assert isinstance(loss, torch.Tensor) - assert loss.item() != 0.0 # Should have some meaningful loss + _ = prepare_loss_input(student_logits, data, loss_fn) def test_distillation_loss_gradient_flow(): @@ -1922,13 +1861,14 @@ def test_distillation_loss_gradient_flow(): } ) + loss_input = prepare_loss_input(student_logits, data, loss_fn) loss, _ = loss_fn( - student_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) # Compute gradients @@ -1955,39 +1895,42 @@ def test_distillation_loss_edge_cases(): # Test with all-zero logits zero_logits = torch.zeros_like(student_logits) + loss_input = prepare_loss_input(zero_logits, data, loss_fn) loss, _ = loss_fn( - zero_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) assert not torch.isnan(loss) assert not torch.isinf(loss) # Test with very large logits large_logits = torch.ones_like(student_logits) * 100.0 + loss_input = prepare_loss_input(large_logits, data, loss_fn) loss, _ = loss_fn( - large_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) assert not torch.isnan(loss) assert not torch.isinf(loss) # Test with very small logits small_logits = torch.ones_like(student_logits) * -100.0 + loss_input = prepare_loss_input(small_logits, data, loss_fn) loss, _ = loss_fn( - small_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) assert not torch.isnan(loss) assert not torch.isinf(loss) @@ -2030,13 +1973,14 @@ def test_distillation_loss_fn_call(): } ) + loss_input = prepare_loss_input(student_logits, data, loss_fn) loss, metrics = loss_fn( - student_logits, - data, + data=data, global_valid_seqs=torch.sum(data["sample_mask"]), global_valid_toks=torch.sum( data["sample_mask"].unsqueeze(-1) * data["token_mask"] ), + **loss_input, ) # Verify return types diff --git a/tests/unit/algorithms/test_rm.py b/tests/unit/algorithms/test_rm.py index f053c4246d..b5c0328681 100644 --- a/tests/unit/algorithms/test_rm.py +++ b/tests/unit/algorithms/test_rm.py @@ -18,7 +18,7 @@ import torch from torchdata.stateful_dataloader import StatefulDataLoader -from nemo_rl.algorithms.loss_functions import PreferenceLoss +from nemo_rl.algorithms.loss import PreferenceLossFn from nemo_rl.algorithms.rm import _default_rm_save_state, rm_train @@ -75,7 +75,7 @@ def val_iter(self): tokenizer = MagicMock() tokenizer.pad_token_id = 0 - loss_fn = PreferenceLoss() + loss_fn = PreferenceLossFn() logger = MagicMock() checkpointer = MagicMock() diff --git a/tests/unit/algorithms/test_sft.py b/tests/unit/algorithms/test_sft.py index c507f8a987..2e76dda6ab 100644 --- a/tests/unit/algorithms/test_sft.py +++ b/tests/unit/algorithms/test_sft.py @@ -18,7 +18,7 @@ import torch from torchdata.stateful_dataloader import StatefulDataLoader -from nemo_rl.algorithms.loss_functions import NLLLoss +from nemo_rl.algorithms.loss import NLLLossFn from nemo_rl.algorithms.sft import _default_sft_save_state, sft_train @@ -58,7 +58,7 @@ def val_iter(self): tokenizer = MagicMock() tokenizer.pad_token_id = 0 - loss_fn = NLLLoss() + loss_fn = NLLLossFn() logger = MagicMock() checkpointer = MagicMock() diff --git a/tests/unit/models/automodel/test_automodel_data.py b/tests/unit/models/automodel/test_automodel_data.py index c362e3168d..27bd4e3b99 100644 --- a/tests/unit/models/automodel/test_automodel_data.py +++ b/tests/unit/models/automodel/test_automodel_data.py @@ -17,7 +17,7 @@ import pytest import torch -from nemo_rl.algorithms.interfaces import LossType +from nemo_rl.algorithms.loss.interfaces import LossType from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.models.automodel.data import ( ProcessedInputs, diff --git a/tests/unit/models/automodel/test_automodel_train.py b/tests/unit/models/automodel/test_automodel_train.py index a2dfddf9e3..9eb5d72bb9 100644 --- a/tests/unit/models/automodel/test_automodel_train.py +++ b/tests/unit/models/automodel/test_automodel_train.py @@ -24,6 +24,7 @@ except ImportError: pytest.skip("nemo_automodel not available", allow_module_level=True) +from nemo_rl.algorithms.loss.interfaces import LossInputType from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.models.automodel.data import ( ProcessedInputs, @@ -63,6 +64,7 @@ def mock_model(): def mock_loss_fn(): loss_fn = MagicMock() loss_fn.return_value = (torch.tensor(0.5), {"loss": 0.5}) + loss_fn.input_type = LossInputType.LOGIT return loss_fn @@ -310,10 +312,10 @@ def test_basic_loss_computation( # Verify loss function was called mock_loss_fn.assert_called_once() - call_args = mock_loss_fn.call_args[0] - assert torch.is_tensor(call_args[0]) # logits - assert call_args[2] == global_valid_seqs # global_valid_seqs - assert call_args[3] == global_valid_toks # global_valid_toks + call_kwargs = mock_loss_fn.call_args[1] + assert torch.is_tensor(call_kwargs["logits"]) + assert call_kwargs["global_valid_seqs"] == global_valid_seqs + assert call_kwargs["global_valid_toks"] == global_valid_toks @patch("nemo_rl.models.automodel.train.SequencePackingLossWrapper") def test_loss_with_sequence_packing( @@ -1896,10 +1898,12 @@ def forward(self, input_ids, **kwargs): ) # Create loss function that returns requires_grad tensor - def loss_fn(logits, mb, global_valid_seqs, global_valid_toks): + def loss_fn(logits, data, global_valid_seqs, global_valid_toks): loss = logits.mean() return loss, {"loss": loss.item()} + loss_fn.input_type = LossInputType.LOGIT + # Create loss post-processor loss_post_processor = LossPostProcessor( loss_fn=loss_fn, diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index c27a183b5c..ac5d2484ab 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -23,7 +23,7 @@ import torch from nemo_rl.algorithms.grpo import refit_policy_generation -from nemo_rl.algorithms.loss_functions import NLLLoss +from nemo_rl.algorithms.loss import NLLLossFn from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.virtual_cluster import RayVirtualCluster @@ -719,7 +719,7 @@ async def run_hf_train_process( 1. Use vLLM for generation 2. Use HF policy for training and logprob computation """ - from tests.unit.test_utils import SimpleNLLLoss + from tests.unit.test_utils import SimpleNLLLossFn try: prompts = [ @@ -848,7 +848,7 @@ async def run_hf_train_process( { "input_ids": train_input_ids, "input_lengths": generation_results["unpadded_sequence_lengths"], - "token_loss_mask": token_loss_mask, + "token_mask": token_loss_mask, "sample_mask": torch.ones(train_input_ids.shape[0]), } ) @@ -858,7 +858,7 @@ async def run_hf_train_process( lm_policy.prepare_for_training() # Just do one training step to verify it works - results = lm_policy.train(train_data, SimpleNLLLoss()) + results = lm_policy.train(train_data, SimpleNLLLossFn()) print(f"Training loss: {results['loss']}") lm_policy.finish_training() @@ -2164,7 +2164,7 @@ def test_vllm_generation_with_megatron_training( megatron_policy.prepare_for_training() # Do one training step to verify it works - results = megatron_policy.train(train_data, NLLLoss()) + results = megatron_policy.train(train_data, NLLLossFn()) print(f"Training loss: {results['loss']}") megatron_policy.finish_training() @@ -2331,7 +2331,7 @@ def test_vllm_generation_with_megatron_training_moe_model( megatron_policy.prepare_for_training() # Do one training step to verify it works - results = megatron_policy.train(train_data, NLLLoss()) + results = megatron_policy.train(train_data, NLLLossFn()) print(f"Training loss: {results['loss']}") megatron_policy.finish_training() diff --git a/tests/unit/models/megatron/test_train.py b/tests/unit/models/megatron/test_train.py index 24dda67eec..b80d1a7986 100644 --- a/tests/unit/models/megatron/test_train.py +++ b/tests/unit/models/megatron/test_train.py @@ -27,6 +27,8 @@ import pytest import torch +from nemo_rl.algorithms.loss.interfaces import LossInputType + class TestModelForward: """Tests for model_forward function.""" @@ -685,6 +687,7 @@ def test_loss_post_processor_no_packing( from nemo_rl.models.megatron.train import LossPostProcessor mock_loss_fn = MagicMock(return_value=(torch.tensor(0.5), {"loss": 0.5})) + mock_loss_fn.input_type = LossInputType.LOGIT cfg = {"sequence_packing": {"enabled": False}} processor = LossPostProcessor(loss_fn=mock_loss_fn, cfg=cfg, cp_normalize=False) @@ -723,6 +726,7 @@ def test_loss_post_processor_with_cp_normalize( from nemo_rl.models.megatron.train import LossPostProcessor mock_loss_fn = MagicMock(return_value=(torch.tensor(1.0), {})) + mock_loss_fn.input_type = LossInputType.LOGIT cfg = {"sequence_packing": {"enabled": False}} processor = LossPostProcessor( diff --git a/tests/unit/models/policy/test_dtensor_worker.py b/tests/unit/models/policy/test_dtensor_worker.py index a750a78f9a..363c23ae09 100644 --- a/tests/unit/models/policy/test_dtensor_worker.py +++ b/tests/unit/models/policy/test_dtensor_worker.py @@ -19,15 +19,15 @@ import torch from transformers import AutoModelForCausalLM -from nemo_rl.algorithms.interfaces import LossFunction -from nemo_rl.algorithms.loss_functions import ClippedPGLossFn, NLLLoss +from nemo_rl.algorithms.loss import ClippedPGLossFn, NLLLossFn +from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.virtual_cluster import RayVirtualCluster from nemo_rl.models.generation import configure_generation_config from nemo_rl.models.policy import PolicyConfig from nemo_rl.models.policy.lm_policy import Policy -from tests.unit.test_utils import SimpleLoss +from tests.unit.test_utils import SimpleLossFn def create_test_config( @@ -267,7 +267,7 @@ def _base_setup_impl(request, cluster): if mode == "train": # Create loss function - loss_fn: LossFunction = SimpleLoss() + loss_fn: LossFunction = SimpleLossFn() yield policy, data, loss_fn elif mode == "logprob": token_logprobs = calculate_token_logprobs(model_name, data) @@ -424,7 +424,7 @@ def test_dtensor_single_gpu_training( # Create test batch data = create_test_batch(mode="train") - loss_fn = SimpleLoss() + loss_fn = SimpleLossFn() # Test training policy.prepare_for_training() @@ -977,8 +977,8 @@ def test_dtensor_loss_independent_of_microbatch_size_two_gpus( tokenizer=tokenizer, ) - # Test NLLLoss and ClippedPGLossFn with mbs=1 - nll_loss_fn = NLLLoss() + # Test NLLLossFn and ClippedPGLossFn with mbs=1 + nll_loss_fn = NLLLossFn() pg_loss_fn = ClippedPGLossFn( { "ratio_clip_min": 0.2, @@ -1022,7 +1022,7 @@ def test_dtensor_loss_independent_of_microbatch_size_two_gpus( tokenizer=tokenizer, ) - # Test NLLLoss and ClippedPGLossFn with mbs=2 + # Test NLLLossFn and ClippedPGLossFn with mbs=2 policy_mbs2.prepare_for_training() mbs2_nll_results = policy_mbs2.train(data, nll_loss_fn) mbs2_nll_loss = mbs2_nll_results["loss"] @@ -1087,7 +1087,7 @@ def test_dtensor_v1_policy_flops_range_check( ) # Create loss function - loss_fn = SimpleLoss() + loss_fn = SimpleLossFn() try: # Prepare for training diff --git a/tests/unit/models/policy/test_dtensor_worker_v2.py b/tests/unit/models/policy/test_dtensor_worker_v2.py index 0a257baa86..648cbbaa1b 100644 --- a/tests/unit/models/policy/test_dtensor_worker_v2.py +++ b/tests/unit/models/policy/test_dtensor_worker_v2.py @@ -27,7 +27,7 @@ from nemo_rl.distributed.virtual_cluster import RayVirtualCluster from nemo_rl.models.policy import AutomodelKwargs, PolicyConfig from nemo_rl.models.policy.lm_policy import Policy -from tests.unit.test_utils import SimpleLoss +from tests.unit.test_utils import SimpleLossFn try: from nemo_rl.models.policy.workers.dtensor_policy_worker_v2 import ( @@ -423,7 +423,7 @@ def test_dtensor_v2_mixed_precision_training_and_logprobs( try: # --- Test Training --- train_data = create_test_batch(mode="train") - loss_fn = SimpleLoss() + loss_fn = SimpleLossFn() policy.prepare_for_training() results = policy.train(train_data, loss_fn) diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py index 7d329ab411..3c4625ba23 100644 --- a/tests/unit/models/policy/test_megatron_worker.py +++ b/tests/unit/models/policy/test_megatron_worker.py @@ -21,20 +21,20 @@ import ray import torch -from nemo_rl.algorithms.interfaces import LossFunction -from nemo_rl.algorithms.loss_functions import ( +from nemo_rl.algorithms.loss import ( ClippedPGLossConfig, ClippedPGLossFn, DPOLossFn, - NLLLoss, + NLLLossFn, ) +from nemo_rl.algorithms.loss.interfaces import LossFunction from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.virtual_cluster import RayVirtualCluster from nemo_rl.models.generation import configure_generation_config from nemo_rl.models.policy import PolicyConfig from nemo_rl.models.policy.lm_policy import Policy -from tests.unit.test_utils import SimpleLoss +from tests.unit.test_utils import SimpleLossFn basic_pg_loss_test_config: ClippedPGLossConfig = { "ratio_clip_min": 0.2, @@ -347,7 +347,7 @@ def training_setup(request): ) # Create loss function - loss_fn: LossFunction = SimpleLoss() + loss_fn: LossFunction = SimpleLossFn() yield policy, cluster, data, loss_fn @@ -824,7 +824,7 @@ def test_megatron_loss_independent_of_microbatch_size(tiny_llama_model_path): ) # Test loss functions - nll_loss_fn = NLLLoss() + nll_loss_fn = NLLLossFn() pg_loss_fn = ClippedPGLossFn(basic_pg_loss_test_config) policy1.prepare_for_training() @@ -902,7 +902,7 @@ def test_megatron_grad_norm_invariant_to_number_of_microbatches(tiny_llama_model ) tokenizer = get_tokenizer({"name": tiny_llama_model_path}) - nll_loss_fn = NLLLoss() + nll_loss_fn = NLLLossFn() cluster1 = RayVirtualCluster( name="test-gradnorm-mbs1", @@ -1032,7 +1032,7 @@ def test_megatron_reference_policy_functionality(tiny_llama_model_path): } ) - loss_fn = SimpleLoss() + loss_fn = SimpleLossFn() policy.prepare_for_training() # Train for more steps and monitor loss to ensure training is working @@ -1147,7 +1147,7 @@ def test_megatron_checkpoint_save_kill_and_restore( } ) - loss_fn = SimpleLoss() + loss_fn = SimpleLossFn() # Train for several steps to modify model state significantly policy1.prepare_for_training() @@ -1842,7 +1842,7 @@ def test_megatron_sft_training(tiny_llama_model_path): ) # Create NLL loss function for SFT - sft_loss_fn = NLLLoss() + sft_loss_fn = NLLLossFn() try: # Prepare for training @@ -2358,8 +2358,8 @@ def test_megatron_gradient_norm_consistency_across_parallelism(tiny_llama_model_ init_reference_model=False, ) - # Use SimpleLoss for consistent comparison - loss_fn = NLLLoss() + # Use SimpleLossFn for consistent comparison + loss_fn = NLLLossFn() try: # Prepare for training @@ -2532,7 +2532,7 @@ def test_megatron_policy_flops_range_check(tiny_llama_model_path): ) # Create loss function - loss_fn = SimpleLoss() + loss_fn = SimpleLossFn() try: # Prepare for training diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 11515ec661..3369761655 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -11,29 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any import torch -from nemo_rl.algorithms.interfaces import LossType +from nemo_rl.algorithms.loss.interfaces import LossInputType, LossType from nemo_rl.distributed.batched_data_dict import BatchedDataDict -class SimpleLoss: +class SimpleLossFn: loss_type = LossType.SEQUENCE_LEVEL + input_type = LossInputType.LOGIT def __call__( self, - next_token_logits: torch.Tensor, + logits: torch.Tensor, data: BatchedDataDict, global_valid_seqs: torch.Tensor | None, global_valid_toks: torch.Tensor | None, - vocab_parallel_rank: Optional[int] = None, - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[torch.Tensor, dict[str, Any]]: - # Just return mean of logprobs as the loss for testing - loss = next_token_logits.mean() + # Just return mean of logits as the loss for testing + loss = logits.mean() metrics = { "loss": loss.item(), "test_metric": loss.item() * 0.5, @@ -43,34 +41,21 @@ def __call__( # Create a simple masked NLL loss function -class SimpleNLLLoss: - loss_type = LossType.SEQUENCE_LEVEL +class SimpleNLLLossFn: + loss_type = LossType.TOKEN_LEVEL + input_type = LossInputType.LOGPROB def __call__( self, - next_token_logits: torch.Tensor, + next_token_logprobs: torch.Tensor, data: BatchedDataDict, global_valid_seqs: torch.Tensor | None, global_valid_toks: torch.Tensor | None, - vocab_parallel_rank: Optional[int] = None, - vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None, - context_parallel_group: Optional[torch.distributed.ProcessGroup] = None, ) -> tuple[torch.Tensor, dict[str, Any]]: - # logits shape: [batch_size, seq_len, vocab_size] - # Get the next token logits for each position - next_tokens = data["input_ids"][:, 1:].cuda() # Skip first token - next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1) - logprobs = next_token_logprobs[:, :-1] # Remove last position's logits - - # Gather the logprobs for the actual next tokens - token_logprobs = logprobs.gather( - dim=-1, index=next_tokens.unsqueeze(-1) - ).squeeze(-1) - # Only compute loss on generated tokens (not input tokens) - # by applying the token_loss_mask (shifted by 1 since we're predicting next tokens) - token_loss_mask = data["token_loss_mask"][:, 1:].cuda() - loss = -torch.sum(token_logprobs * token_loss_mask) + # by applying the token_mask (shifted by 1 since we're predicting next tokens) + mask = data["token_mask"][:, 1:].cuda() + loss = -torch.sum(next_token_logprobs * mask) return loss, { "loss": loss.item(), diff --git a/tests/unit/utils/test_native_checkpoint.py b/tests/unit/utils/test_native_checkpoint.py index ad4c9b7728..f94da69c54 100755 --- a/tests/unit/utils/test_native_checkpoint.py +++ b/tests/unit/utils/test_native_checkpoint.py @@ -29,7 +29,7 @@ load_checkpoint, save_checkpoint, ) -from tests.unit.test_utils import SimpleLoss +from tests.unit.test_utils import SimpleLossFn # Define basic test config simple_policy_config = { @@ -310,7 +310,7 @@ def test_convert_dcp_to_hf(policy, num_gpus, request): "sample_mask": torch.ones(input_ids.shape[0]), } ) - policy.train(dummy_fwd_dict, SimpleLoss()) + policy.train(dummy_fwd_dict, SimpleLossFn()) policy_version_is_v2 = request.node.callspec.params["policy"] with TemporaryDirectory() as tmp_dir: