Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion nemo_rl/algorithms/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,12 @@
PreferenceLossDataDict,
PreferenceLossFn,
)
from nemo_rl.algorithms.loss.utils import prepare_loss_input
from nemo_rl.algorithms.loss.utils import (
prepare_loss_input,
prepare_packed_loss_input,
)
from nemo_rl.algorithms.loss.wrapper import (
SequencePackingFusionLossWrapper,
SequencePackingLossWrapper,
wrap_loss_fn_with_input_preparation,
)
Expand All @@ -46,6 +50,8 @@
"PreferenceLossDataDict",
"PreferenceLossFn",
"prepare_loss_input",
"prepare_packed_loss_input",
"SequencePackingFusionLossWrapper",
"SequencePackingLossWrapper",
"wrap_loss_fn_with_input_preparation",
]
158 changes: 158 additions & 0 deletions nemo_rl/algorithms/loss/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from nemo_rl.algorithms.utils import mask_out_neg_inf_logprobs
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.distributed.model_utils import (
_get_tokens_on_this_cp_rank,
from_parallel_logits_to_logprobs_packed_sequences,
get_distillation_topk_logprobs_from_logits,
get_next_token_logprobs_from_logits,
)
Expand Down Expand Up @@ -119,3 +121,159 @@ def prepare_loss_input(
raise ValueError(f"Unknown loss function input type: {loss_fn.input_type}")

return loss_input, data


def _pack_input_ids(
    input_ids: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_q_padded: torch.Tensor,
    cp_rank: int = 0,
    cp_size: int = 1,
    roll_shift: int = 0,
) -> torch.Tensor:
    """Repack input_ids from [B, S] into a single packed row [1, T_packed // CP].

    Every sequence is copied into a zero buffer of its own padded length
    (taken from cu_seqlens_q_padded), optionally rolled by ``roll_shift``,
    then CP-sharded at that padded length and written into its slot of the
    packed output — mirroring Megatron's _pack_sequences_for_megatron layout.

    Args:
        input_ids: Unpacked token IDs of shape [B, S].
        cu_seqlens_q: Unpadded cumulative sequence lengths, shape [B+1].
        cu_seqlens_q_padded: Padded cumulative sequence lengths, shape [B+1].
        cp_rank: Context-parallel rank of this worker.
        cp_size: Context-parallel world size.
        roll_shift: Per-sequence roll applied before CP-sharding; pass -1 to
            produce shifted next-token-prediction targets.
    """
    # One materialized output buffer covering every padded sequence,
    # already scaled down by the CP shard factor.
    out_len = int(cu_seqlens_q_padded[-1].item()) // cp_size
    out = torch.zeros(out_len, dtype=input_ids.dtype, device=input_ids.device)

    # Pull boundaries to the host once instead of per-iteration .item() calls.
    starts = cu_seqlens_q.tolist()
    padded_starts = cu_seqlens_q_padded.tolist()

    for row in range(input_ids.shape[0]):
        valid = int(starts[row + 1] - starts[row])
        p_lo = int(padded_starts[row])
        p_hi = int(padded_starts[row + 1])

        # Zero-pad this sequence up to its padded slot length.
        seq_buf = torch.zeros(
            p_hi - p_lo, dtype=input_ids.dtype, device=input_ids.device
        )
        seq_buf[:valid] = input_ids[row, :valid]
        if roll_shift:
            # Roll at padded length so the shift matches the packed layout.
            seq_buf = seq_buf.roll(shifts=roll_shift, dims=0)

        shard = _get_tokens_on_this_cp_rank(seq_buf, cp_rank, cp_size, seq_dim=0)
        out[p_lo // cp_size : p_hi // cp_size] = shard

    return out.unsqueeze(0)


def prepare_packed_loss_input(
    logits: torch.Tensor,
    data: BatchedDataDict[Any],
    loss_fn: LossFunction,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_q_padded: torch.Tensor,
    vocab_parallel_rank: Optional[int] = None,
    vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
    context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
    sampling_params: Optional[TrainingSamplingParams] = None,
) -> tuple[dict[str, Any], BatchedDataDict[Any]]:
    """Prepare loss input from packed logits in a single fused pass.

    Unlike prepare_loss_input which operates on a single (unpacked) sequence,
    this function computes log probabilities from packed logits across all
    sequences at once using from_parallel_logits_to_logprobs_packed_sequences.

    Currently only supports LossInputType.LOGPROB.

    Args:
        logits: Packed logits from the model [1, T_packed // CP, V // TP].
        data: Microbatch data (unpacked, [B, S]).
        loss_fn: Loss function (must have input_type == LossInputType.LOGPROB).
        cu_seqlens_q: Unpadded cumulative sequence lengths [B+1].
        cu_seqlens_q_padded: Padded cumulative sequence lengths [B+1].
        vocab_parallel_rank: Vocab parallel rank.
        vocab_parallel_group: Vocab parallel group.
        context_parallel_group: Context parallel group.
        sampling_params: Sampling parameters.

    Returns:
        tuple(loss_input, maybe_updated_data). ``loss_input`` carries
        ``next_token_logprobs``; ``data`` is the same dict, possibly extended
        in place with ``curr_logprobs_unfiltered`` (see below).
    """
    # This fused path only knows how to hand back log-probs; other input
    # types must go through the per-sequence SequencePackingLossWrapper.
    if loss_fn.input_type != LossInputType.LOGPROB:
        raise ValueError(
            f"prepare_packed_loss_input only supports LossInputType.LOGPROB, "
            f"got {loss_fn.input_type}. Use SequencePackingLossWrapper with "
            f"prepare_loss_input for other types."
        )
    # TP sharding info is mandatory: the logits are vocab-parallel, so the
    # logprob reduction needs the group and this rank's vocab slice.
    assert vocab_parallel_group is not None, (
        "prepare_packed_loss_input requires vocab_parallel_group (Megatron TP)."
    )
    assert vocab_parallel_rank is not None, (
        "vocab_parallel_rank must be provided with vocab_parallel_group."
    )

    input_ids = data["input_ids"]
    unpacked_seqlen = input_ids.shape[1]
    # CP defaults to a single shard when no context-parallel group is given.
    cp_size = (
        1
        if context_parallel_group is None
        else torch.distributed.get_world_size(context_parallel_group)
    )
    cp_rank = (
        0
        if context_parallel_group is None
        else torch.distributed.get_rank(context_parallel_group)
    )

    # Build next-token targets: roll each sequence by -1 at its padded
    # length and CP-shard them so they align with the packed logits layout.
    packed_rolled_targets = _pack_input_ids(
        input_ids,
        cu_seqlens_q,
        cu_seqlens_q_padded,
        cp_rank=cp_rank,
        cp_size=cp_size,
        roll_shift=-1,
    )

    # Single fused TP/CP logprob computation over every packed sequence.
    # float32 upcast keeps the softmax/log numerically stable; the vocab
    # slice bounds follow from this rank's shard width (logits.shape[-1]).
    logprobs = from_parallel_logits_to_logprobs_packed_sequences(
        logits.to(torch.float32),
        packed_rolled_targets,
        cu_seqlens_q_padded,
        unpacked_seqlen,
        vocab_start_index=vocab_parallel_rank * logits.shape[-1],
        vocab_end_index=(vocab_parallel_rank + 1) * logits.shape[-1],
        group=vocab_parallel_group,
        inference_only=False,
        cp_group=context_parallel_group,
        sampling_params=sampling_params,
        target_is_pre_rolled=True,
    )

    # Match prepare_loss_input behavior for top-k/top-p filtered training:
    # use filtered curr_logprobs for actor loss, but keep unfiltered values for KL.
    if need_top_k_or_top_p_filtering(sampling_params):
        # mask[:, 1:] matches the logprobs shape, which is one token shorter
        # than the input due to the next-token shift.
        mask = data["token_mask"] * data["sample_mask"].unsqueeze(-1)
        logprobs = mask_out_neg_inf_logprobs(logprobs, mask[:, 1:], "curr_logprobs")

        if (
            hasattr(loss_fn, "reference_policy_kl_penalty")
            and loss_fn.reference_policy_kl_penalty != 0
        ):
            # Second pass with sampling_params=None: the KL term needs the
            # unfiltered log-probs. Stored in `data` in place so the loss fn
            # can pick it up by key.
            data["curr_logprobs_unfiltered"] = (
                from_parallel_logits_to_logprobs_packed_sequences(
                    logits.to(torch.float32),
                    packed_rolled_targets,
                    cu_seqlens_q_padded,
                    unpacked_seqlen,
                    vocab_start_index=vocab_parallel_rank * logits.shape[-1],
                    vocab_end_index=(vocab_parallel_rank + 1) * logits.shape[-1],
                    group=vocab_parallel_group,
                    inference_only=False,
                    cp_group=context_parallel_group,
                    sampling_params=None,
                    target_is_pre_rolled=True,
                )
            )

    return {"next_token_logprobs": logprobs}, data
62 changes: 62 additions & 0 deletions nemo_rl/algorithms/loss/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,68 @@ def __call__(
return loss_accum, metrics_accum


class SequencePackingFusionLossWrapper:
    """Loss wrapper that handles a whole packed microbatch in one shot.

    SequencePackingLossWrapper walks the packed sequences one by one; this
    wrapper instead invokes prepare_fn a single time on the packed logits,
    obtaining log probabilities for every sequence at once, and then runs
    the loss function once over the prepared result.

    The single-pass layout removes per-sequence kernel launches and repeated
    TP/CP communication while keeping the numerics identical.

    The prepare_fn should be prepare_packed_loss_input (from nemo_rl.algorithms.loss.utils),
    which currently only supports LossInputType.LOGPROB.
    """

    def __init__(
        self,
        loss_fn: LossFunction,
        prepare_fn: Callable[..., Any],
        cu_seqlens_q: Tensor,
        cu_seqlens_q_padded: Optional[Tensor] = None,
        vocab_parallel_rank: Optional[int] = None,
        vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
        context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
    ):
        # Fall back to the unpadded boundaries when no padded variant exists.
        if cu_seqlens_q_padded is None:
            cu_seqlens_q_padded = cu_seqlens_q
        self.loss_fn = loss_fn
        self.prepare_fn = prepare_fn
        self.cu_seqlens_q = cu_seqlens_q
        self.cu_seqlens_q_padded = cu_seqlens_q_padded
        self.vocab_parallel_rank = vocab_parallel_rank
        self.vocab_parallel_group = vocab_parallel_group
        self.context_parallel_group = context_parallel_group

    def __call__(
        self,
        next_token_logits: Tensor,
        data: BatchedDataDict[Any],
        global_valid_seqs: Tensor | None,
        global_valid_toks: Tensor | None,
    ) -> tuple[Tensor, dict[str, Any]]:
        """Run the fused prepare step, then the loss function, once each."""
        prepared = self.prepare_fn(
            logits=next_token_logits,
            data=data,
            loss_fn=self.loss_fn,
            cu_seqlens_q=self.cu_seqlens_q,
            cu_seqlens_q_padded=self.cu_seqlens_q_padded,
            vocab_parallel_rank=self.vocab_parallel_rank,
            vocab_parallel_group=self.vocab_parallel_group,
            context_parallel_group=self.context_parallel_group,
        )
        loss_input, prepared_data = prepared
        return self.loss_fn(
            data=prepared_data,
            global_valid_seqs=global_valid_seqs,
            global_valid_toks=global_valid_toks,
            **loss_input,
        )


def wrap_loss_fn_with_input_preparation(
next_token_logits: Tensor,
data: BatchedDataDict[Any],
Expand Down
63 changes: 35 additions & 28 deletions nemo_rl/distributed/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -942,61 +942,68 @@ def from_parallel_logits_to_logprobs_packed_sequences(
cp_group: Optional[torch.distributed.ProcessGroup] = None,
chunk_size: Optional[int] = None,
sampling_params: Optional[TrainingSamplingParams] = None,
target_is_pre_rolled: bool = False,
) -> torch.Tensor:
"""Get log probabilities from TP sharded vocab logits for packed sequences.

Args:
vocab_parallel_logits (torch.Tensor): Packed logits tensor with shape [1, T // CP, vocab_size//TP]
where T is the total number of tokens across all packed sequences.
target (torch.Tensor): Packed target token indices with shape [1, T].
NOTE: Must be the unmodified targets as this function will shift them internally.
cu_seqlens (torch.Tensor): Cumulative sequence lengths tensor with shape [batch_size + 1].
cu_seqlens[i] indicates the start position of sequence i in the packed format.
target (torch.Tensor): Packed target token indices.
If target_is_pre_rolled=False: shape [1, T] — unmodified targets, rolled internally.
If target_is_pre_rolled=True: shape [1, T // CP] — pre-rolled and pre-CP-sharded.
cu_seqlens_padded (torch.Tensor): Cumulative sequence lengths tensor with shape [batch_size + 1].
cu_seqlens_padded[i] indicates the start position of sequence i in the packed format
(full, not CP-adjusted).
unpacked_seqlen (int): The length of the unpacked sequence tensor.
vocab_start_index (int): Starting vocabulary index for this worker's partition.
vocab_end_index (int): Ending vocabulary index for this worker's partition.
group (torch.distributed.ProcessGroup): Process group for distributed communication.
inference_only (bool, optional): If True, tensors won't be saved for backward pass. Defaults to False.
cp_group (torch.distributed.ProcessGroup, optional): Context parallelism process group. Defaults to None.
chunk_size (int, optional): Sequence dimension chunk size for computing the log probabilities.
sampling_params (TrainingSamplingParams, optional): Sampling parameters for Top-k/Top-p filtering.
target_is_pre_rolled (bool): If True, target is already shifted and CP-sharded to match
vocab_parallel_logits shape, skipping the internal per-sequence roll+CP-shard loop.

Returns:
torch.Tensor: Unpacked log probabilities tensor with shape [batch_size, unpacked_seqlen-1].
The total length is reduced by batch_size due to target shifting (one token per sequence).
"""
# Remove batch dimension to work with [T, vocab_size] and [T]
vocab_parallel_logits = vocab_parallel_logits.squeeze(0)
target = target.squeeze(0)

batch_size = cu_seqlens_padded.shape[0] - 1
cp_size = 1 if cp_group is None else torch.distributed.get_world_size(cp_group)
cp_rank = 0 if cp_group is None else torch.distributed.get_rank(cp_group)

# Roll each sequence individually
rolled_targets = torch.zeros(
target.shape[0] // cp_size, dtype=target.dtype, device=target.device
)
for i in range(batch_size):
start_idx = cu_seqlens_padded[i].item()
end_idx = cu_seqlens_padded[i + 1].item()
if not target_is_pre_rolled:
# Roll each sequence individually and CP-shard the targets
# Remove batch dimension to work with [T, vocab_size] and [T]
vocab_parallel_logits = vocab_parallel_logits.squeeze(0)
target = target.squeeze(0)
cp_rank = 0 if cp_group is None else torch.distributed.get_rank(cp_group)

# Get the sequence targets and roll by -1
seq_targets = target[start_idx:end_idx]
rolled_seq_targets = seq_targets.roll(shifts=-1, dims=0)
rolled_targets[start_idx // cp_size : end_idx // cp_size] = (
_get_tokens_on_this_cp_rank(rolled_seq_targets, cp_rank, cp_size, seq_dim=0)
rolled_targets = torch.zeros(
target.shape[0] // cp_size, dtype=target.dtype, device=target.device
)
for i in range(batch_size):
start_idx = cu_seqlens_padded[i].item()
end_idx = cu_seqlens_padded[i + 1].item()

seq_targets = target[start_idx:end_idx]
rolled_seq_targets = seq_targets.roll(shifts=-1, dims=0)
rolled_targets[start_idx // cp_size : end_idx // cp_size] = (
_get_tokens_on_this_cp_rank(
rolled_seq_targets, cp_rank, cp_size, seq_dim=0
)
)

# Add batch dimension back for DistributedLogprob
rolled_targets = rolled_targets.unsqueeze(0)
vocab_parallel_logits = vocab_parallel_logits.unsqueeze(0)
target = rolled_targets.unsqueeze(0)
vocab_parallel_logits = vocab_parallel_logits.unsqueeze(0)

# Apply distributed log probability computation
if need_top_k_or_top_p_filtering(sampling_params):
if chunk_size is not None:
probs: torch.Tensor = ChunkedDistributedLogprobWithSampling.apply( # type: ignore
vocab_parallel_logits,
rolled_targets,
target,
group,
sampling_params.top_k,
sampling_params.top_p,
Expand All @@ -1006,7 +1013,7 @@ def from_parallel_logits_to_logprobs_packed_sequences(
else:
probs: torch.Tensor = DistributedLogprobWithSampling.apply( # type: ignore
vocab_parallel_logits,
rolled_targets,
target,
group,
sampling_params.top_k,
sampling_params.top_p,
Expand All @@ -1016,7 +1023,7 @@ def from_parallel_logits_to_logprobs_packed_sequences(
if chunk_size is not None:
probs: torch.Tensor = ChunkedDistributedLogprob.apply( # type: ignore
vocab_parallel_logits,
rolled_targets,
target,
vocab_start_index,
vocab_end_index,
chunk_size,
Expand All @@ -1026,7 +1033,7 @@ def from_parallel_logits_to_logprobs_packed_sequences(
else:
probs: torch.Tensor = DistributedLogprob.apply( # type: ignore
vocab_parallel_logits,
rolled_targets,
target,
vocab_start_index,
vocab_end_index,
group,
Expand Down
Loading
Loading