Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/model-quirks.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ NeMo-RL uses the vLLM V1 runtime for both synchronous and asynchronous inference

- NeMo-RL implemented this feature based on torch CP [implementation](https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/experimental/_attention.py). And we inherit its limitations.
Whether model level support CP only depends on arguments passed to `torch.nn.functional.scaled_dot_product_attention`. Current NeMo-RL passed all ones attention mask to `model.forward`. For Gemma-3, it won't ignore attention mask as result `attn_bias` is not None which is not supported by torch CP. Please see [assertion](https://github.com/pytorch/pytorch/blob/134179474539648ba7dee1317959529fbd0e7f89/torch/distributed/tensor/experimental/_attention.py#L262) .
- Context parallel can't be used together with sequence packing. Sequence packing requires `attn_implementation="flash_attention_2"`, this conflict with context parallel requires SDPA impl. Refer to [here](https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/modeling_utils.py#L2317) for more details.


- It's a known issue that context parallel can't be used together with sequence parallel.
Refer to [here](https://github.com/NVIDIA-NeMo/RL/issues/659) for more details.

- It's a known issue that context parallel can't be used together with sequence parallel.
Refer to [here](https://github.com/NVIDIA-NeMo/RL/issues/659) for more details.
Expand Down
3 changes: 3 additions & 0 deletions examples/configs/grpo-deepscaler-1.5b-8K.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ policy:
dynamic_batching:
enabled: False

sequence_packing:
enabled: False

# makes the training sequence length divisible by the tensor parallel size
# this is useful for sequence parallel training
make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size}
Expand Down
3 changes: 3 additions & 0 deletions examples/configs/grpo_deepscaler-1.5b-24K.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ policy:
dynamic_batching:
enabled: False

sequence_packing:
enabled: False

optimizer:
name: "torch.optim.AdamW"
kwargs:
Expand Down
10 changes: 10 additions & 0 deletions examples/configs/grpo_math_1B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,26 @@ policy:
tensor_parallel_size: 1
context_parallel_size: 1
custom_parallel_plan: null

megatron_cfg:
enabled: false

# dynamic_batching improves performance by ensuring logprob and training microbatches
# have a sufficent number of tokens to maximize GPU utilization. Specifically, variable length
# responses are sorted by sequence length and bucketed into microbatches with a total
# amount of tokens is approximately close to 'train_mb_tokens' and 'logprob_mb_tokens' for the
# training and logprob stages respectively.
dynamic_batching:
enabled: False
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
sequence_length_round: 64

sequence_packing:
enabled: True
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
algorithm: "modified_first_fit_decreasing"
sequence_length_round: 64

# makes the training sequence length divisible by the tensor parallel size
Expand Down
9 changes: 7 additions & 2 deletions examples/configs/grpo_math_1B_megatron.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,19 @@ policy:
# responses are sorted by sequence length and bucketed into microbatches with a total
# amount of tokens is approximately close to 'train_mb_tokens' and 'logprob_mb_tokens' for the
# training and logprob stages respectively.
#
# We disable it for Megatron as it is incompatible with Pipeline parallelism. Instead, we use sequence packing
dynamic_batching:
enabled: False
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
sequence_length_round: 64

sequence_packing:
enabled: False # coming soon
enabled: True
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
algorithm: "modified_ffd"
algorithm: "modified_first_fit_decreasing"
sequence_length_round: 64

max_grad_norm: 1.0
Expand Down
2 changes: 1 addition & 1 deletion examples/configs/grpo_math_8B_megatron.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,4 @@ policy:

cluster:
gpus_per_node: 8
num_nodes: 1
num_nodes: 1
8 changes: 7 additions & 1 deletion examples/configs/sft.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ policy:
dynamic_batching:
enabled: false

sequence_packing:
enabled: False
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
algorithm: "modified_first_fit_decreasing"
sequence_length_round: 64

# makes the training sequence length divisible by the tensor parallel size
# this is useful for sequence parallel training
make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size}
Expand Down Expand Up @@ -121,7 +127,7 @@ policy:
average_in_collective: true
data_parallel_sharding_strategy: "optim_grads_params"


data:
max_input_seq_length: ${policy.max_total_sequence_length}
dataset_name: "squad"
Expand Down
3 changes: 3 additions & 0 deletions examples/configs/sft_openmathinstruct2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ policy:
context_parallel_size: 1
custom_parallel_plan: null

sequence_packing:
enabled: False

dynamic_batching:
enabled: false

Expand Down
2 changes: 2 additions & 0 deletions examples/run_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from nemo_rl.utils.config import load_config, parse_hydra_overrides
from nemo_rl.utils.logger import get_next_experiment_dir

OmegaConf.register_new_resolver("mul", lambda a, b: a * b)


def parse_args():
"""Parse command line arguments."""
Expand Down
95 changes: 95 additions & 0 deletions nemo_rl/algorithms/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __call__(
global_valid_toks: torch.Tensor,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
) -> tuple[torch.Tensor, dict]:
"""Clipped Policy Gradient RL loss function."""
token_mask = data["token_mask"][:, 1:]
Expand Down Expand Up @@ -149,7 +150,10 @@ def __call__(
vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1],
tp_group=vocab_parallel_group,
inference_only=False,
cp_group=context_parallel_group,
)
# slice off to the correct length to remove potential CP padding
curr_logprobs = curr_logprobs[:, : data["input_ids"].shape[1] - 1]
elif isinstance(next_token_logits, torch.distributed.tensor.DTensor):
curr_logprobs = get_logprobs_from_vocab_parallel_logits(
next_token_logits, data["input_ids"], seq_index=seq_index
Expand Down Expand Up @@ -312,6 +316,7 @@ def __call__(
global_valid_toks: Tensor,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
dpo_loss: bool = False,
dpo_average_log_probs: bool = False,
) -> tuple[torch.Tensor, dict[str, Any]]:
Expand All @@ -335,7 +340,10 @@ def __call__(
vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1],
tp_group=vocab_parallel_group,
inference_only=False,
cp_group=context_parallel_group,
)
# slice off to the correct length to remove potential CP padding
token_logprobs = token_logprobs[:, : data["input_ids"].shape[1] - 1]
elif isinstance(next_token_logits, torch.distributed.tensor.DTensor):
token_logprobs = get_logprobs_from_vocab_parallel_logits(
next_token_logits, data["input_ids"]
Expand Down Expand Up @@ -466,6 +474,7 @@ def _preference_loss(
global_valid_seqs: Tensor,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
## TODO(@ashors): there's some duplicate code here with the NLLLoss function. We should refactor
token_mask = data["token_mask"][:, 1:]
Expand All @@ -483,7 +492,10 @@ def _preference_loss(
vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1],
tp_group=vocab_parallel_group,
inference_only=False,
cp_group=context_parallel_group,
)
# slice off to the correct length to remove potential CP padding
token_logprobs = token_logprobs[:, : data["input_ids"].shape[1] - 1]
elif isinstance(next_token_logits, torch.distributed.tensor.DTensor):
token_logprobs = get_logprobs_from_vocab_parallel_logits(
next_token_logits, data["input_ids"]
Expand Down Expand Up @@ -548,6 +560,7 @@ def __call__(
global_valid_toks: Tensor | None,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
sft_loss_chosen = torch.tensor(0.0)
if self.sft_loss_weight > 0:
Expand All @@ -561,6 +574,7 @@ def __call__(
global_valid_toks=global_valid_toks, ## unused because sft loss returned is at the sample level
vocab_parallel_rank=vocab_parallel_rank,
vocab_parallel_group=vocab_parallel_group,
context_parallel_group=context_parallel_group,
dpo_loss=True,
dpo_average_log_probs=self.sft_average_log_probs,
)
Expand All @@ -582,6 +596,7 @@ def __call__(
global_valid_seqs,
vocab_parallel_rank=vocab_parallel_rank,
vocab_parallel_group=vocab_parallel_group,
context_parallel_group=context_parallel_group,
)

dpo_loss = (
Expand All @@ -601,3 +616,83 @@ def __call__(
"rewards_rejected_mean": rewards_rejected_mean.item(),
"num_valid_samples": num_valid_samples.item(),
}


class SequencePackingLossWrapper:
def __init__(
self,
loss_fn: LossFunction,
cu_seqlens_q: Tensor,
cu_seqlens_q_padded: Optional[Tensor] = None,
):
self.loss_fn = loss_fn
self.cu_seqlens_q = cu_seqlens_q
self.cu_seqlens_q_padded = cu_seqlens_q_padded

def __call__(
self,
next_token_logits: Tensor,
data: BatchedDataDict[Any],
global_valid_seqs: Tensor | None,
global_valid_toks: Tensor | None,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
) -> tuple[Tensor, dict[str, Any]]:
"""Wraps a loss function to handle sequence packing by doing one sequence at a time to avoid excessive padding."""
unpadded_cu_seqlens = self.cu_seqlens_q
unpadded_seq_lengths = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1]
if self.cu_seqlens_q_padded is not None:
padded_cu_seqlens = self.cu_seqlens_q_padded
padded_seq_lengths = (
self.cu_seqlens_q_padded[1:] - self.cu_seqlens_q_padded[:-1]
)
else:
padded_cu_seqlens = unpadded_cu_seqlens
padded_seq_lengths = unpadded_seq_lengths
seq_starts = padded_cu_seqlens[:-1]
seq_ends = padded_cu_seqlens[1:]

loss_accum = 0
metrics_accum = {}
for seq_idx in range(len(seq_starts)):
seq_start = seq_starts[seq_idx].item()
seq_end = seq_ends[seq_idx].item()

# get sequence and unpad all 'data' tensors. The data dict is a BatchedDataDict of unpacked tensors
seq_data = data.slice(seq_idx, seq_idx + 1)
unpadded_seq_data = {}
for k, v in seq_data.items():
if isinstance(v, torch.Tensor) and v.ndim > 1 and v.shape[1] > 1:
unpadded_seq_data[k] = v[:, : unpadded_seq_lengths[seq_idx]]
else:
unpadded_seq_data[k] = v

# get next_token_logits
cp_size = (
1
if context_parallel_group is None
else torch.distributed.get_world_size(context_parallel_group)
)
logit_slice_idxs = slice(
seq_start // cp_size,
(seq_start + padded_seq_lengths[seq_idx]) // cp_size,
)
next_token_logits_slice = next_token_logits[:, logit_slice_idxs, :]

loss, metrics = self.loss_fn(
next_token_logits_slice,
unpadded_seq_data,
global_valid_seqs,
global_valid_toks,
vocab_parallel_rank=vocab_parallel_rank,
vocab_parallel_group=vocab_parallel_group,
context_parallel_group=context_parallel_group,
)
loss_accum += loss
for k, v in metrics.items():
if k not in metrics_accum:
metrics_accum[k] = 0
metrics_accum[k] += v

return loss_accum, metrics_accum
35 changes: 35 additions & 0 deletions nemo_rl/data/packing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo_rl.data.packing.algorithms import (
ConcatenativePacker,
FirstFitDecreasingPacker,
FirstFitShufflePacker,
ModifiedFirstFitDecreasingPacker,
PackingAlgorithm,
SequencePacker,
get_packer,
)
from nemo_rl.data.packing.metrics import PackingMetrics

__all__ = [
"PackingAlgorithm",
"SequencePacker",
"ConcatenativePacker",
"FirstFitDecreasingPacker",
"FirstFitShufflePacker",
"ModifiedFirstFitDecreasingPacker",
"get_packer",
"PackingMetrics",
]
Loading
Loading