Closed
126 commits
097de0d
Fixed ~100 type errors
SahilJain314 May 9, 2025
21afaf7
fixed 50 more type issues
SahilJain314 May 9, 2025
f38ab81
lint
SahilJain314 May 9, 2025
e5296bc
Added mypy config, fixed 50 more type errors, and updated old typing.…
SahilJain314 May 11, 2025
a9cee8f
Updated pyproject
SahilJain314 May 11, 2025
89cf9cf
Down to 100 errors
SahilJain314 May 11, 2025
e8160c3
Update pyproject
SahilJain314 May 11, 2025
6a494b9
Down to 50 errors
SahilJain314 May 12, 2025
8a95964
Added testing doc
SahilJain314 May 12, 2025
3b0bfbc
Fixed missing import
SahilJain314 May 12, 2025
588cfda
Fixed tokenizer type
SahilJain314 May 12, 2025
bc07e76
Down to 50 errors
SahilJain314 May 12, 2025
72e4819
fixed 150 strict mypy errors
SahilJain314 May 12, 2025
9bb8d83
lint
SahilJain314 May 12, 2025
7376384
Fixed another 100 strict typing mypy errors
SahilJain314 May 12, 2025
b3f5098
Fixed another 100 strict typing mypy errors (down to 130)
SahilJain314 May 12, 2025
e8b3552
Brought non-strict errors down to 18
SahilJain314 May 13, 2025
7aa1f44
Brought non-strict errors down further
SahilJain314 May 13, 2025
958f538
Fixed pynvml test type
SahilJain314 May 13, 2025
6a14aec
feat: support mcore extra (megatron + tron)
terrykong May 2, 2025
17f5136
typo
terrykong May 17, 2025
b17e163
all good
terrykong May 17, 2025
5b3b96d
pin pre-commit
terrykong May 17, 2025
48f8ffc
undo
terrykong May 17, 2025
8aa6ff8
move submodule stuff into comment until
terrykong May 18, 2025
c5f4305
ok
terrykong May 18, 2025
fdc6024
remove this round
terrykong May 18, 2025
5c22f20
fix
terrykong May 18, 2025
23c0943
dockerignore too
terrykong May 18, 2025
7d2d8dc
Moved all actors to using NamedSharding-based distribution instead of…
SahilJain314 May 20, 2025
9cd192b
oops forgot named_sharding files
SahilJain314 May 20, 2025
9952110
lint
SahilJain314 May 20, 2025
e599e76
Merge remote-tracking branch 'origin/main' into sahilj/type_fix
SahilJain314 May 20, 2025
c917363
Updated with tot merge
SahilJain314 May 20, 2025
a843293
Updated configure_generation_config
SahilJain314 May 21, 2025
11b2f42
updated uv lock
SahilJain314 May 21, 2025
d72dfcb
original uv lock
SahilJain314 May 21, 2025
18aee25
updated uv lock
SahilJain314 May 21, 2025
a9c285b
original uv
SahilJain314 May 21, 2025
b678022
Merge remote-tracking branch 'origin/main' into sahilj/type_fix
SahilJain314 May 21, 2025
3b241df
updated uv lock
SahilJain314 May 21, 2025
40c6275
Pushed coordinate finding into the NamedSharding
SahilJain314 May 21, 2025
aacb783
Unit test failure
SahilJain314 May 21, 2025
c477048
Added tests and fixed types
SahilJain314 May 21, 2025
c0043c1
Added mypy to ci (just a warning rn)
SahilJain314 May 21, 2025
e3f98ea
Merge remote-tracking branch 'origin/sahilj/type_fix' into sahilj/nam…
SahilJain314 May 21, 2025
76d510a
try setting max_jobs really low
terrykong May 21, 2025
3058683
Merge remote-tracking branch 'origin' into sahilj/named_sharding
SahilJain314 May 21, 2025
c38fcb4
Merge branch 'tk/megatron-extra' into sahilj/megatron_tot
SahilJain314 May 22, 2025
3f9667f
Added Megatron
SahilJain314 May 22, 2025
4b82ffd
Megatron fixes
SahilJain314 May 23, 2025
e36935d
tot bugfixes
SahilJain314 May 27, 2025
36d41ba
Updated git module
SahilJain314 May 27, 2025
9678444
Added kwargs to save checkpoint
SahilJain314 May 27, 2025
e5bcd1e
Fixed checkpointing Megatron
SahilJain314 May 28, 2025
f9be255
Don't bother with rng restore
SahilJain314 May 28, 2025
42800e6
Fixed metric logging
SahilJain314 May 28, 2025
dc7920b
lint
SahilJain314 May 28, 2025
f88bbc3
Updated nemo patch
SahilJain314 May 29, 2025
8910bd4
Updated patch
SahilJain314 May 29, 2025
84bd407
Fixed memory offloading for parameter and grad buffers
SahilJain314 Jun 3, 2025
ac6405c
fix: Don't call state_dict in loop + dtype fix (#445)
yfw Jun 3, 2025
ea2d51f
Enable dynamic batching
SahilJain314 Jun 3, 2025
4fd43fd
Fixed pp bug
SahilJain314 Jun 3, 2025
ff60018
lint
SahilJain314 Jun 3, 2025
265c365
Merge remote-tracking branch 'origin' into sahilj/megatron_tot
SahilJain314 Jun 3, 2025
732228c
Fixed merge artifact
SahilJain314 Jun 4, 2025
acf8b15
Fixes for tests
SahilJain314 Jun 4, 2025
292a447
Fixed dynamic batching and improved memory usage
SahilJain314 Jun 4, 2025
6243456
default expandable segments on
SahilJain314 Jun 4, 2025
8c8544b
Added basic sequence packing
SahilJain314 Jun 10, 2025
1fa69a6
Added basic sequence packing
SahilJain314 Jun 10, 2025
fca4b6d
Fixed PP with sequence packing
SahilJain314 Jun 10, 2025
1ffe9f7
Updated Megatron patch
SahilJain314 Jun 10, 2025
1e51324
Remove custom_fsdp mentions
SahilJain314 Jun 11, 2025
e03a36c
Bump ray
SahilJain314 Jun 11, 2025
d04b0ba
Added a 70b config with megatron
SahilJain314 Jun 11, 2025
44d40af
Merge branch 'sahilj/megatron_packed' of github.com:NVIDIA/NeMo-RL in…
SahilJain314 Jun 11, 2025
6a24f56
Sequence packing
ahmadki Jun 3, 2025
04859b7
checkpoint
ahmadki Jun 4, 2025
7592a35
revert some packing changes
ahmadki Jun 4, 2025
6c57e22
initial fix for different micro batch lengths
ahmadki Jun 6, 2025
051b329
benchmark configs
ahmadki Jun 6, 2025
fbbd77a
minor fixes
ahmadki Jun 8, 2025
88919cf
grpo configs
ahmadki Jun 8, 2025
4fb888d
grpo fixes so code would run
ahmadki Jun 8, 2025
1da878e
loss function fixes, added packing strategy to policy
ahmadki Jun 9, 2025
db26d63
SFT config/API cleanup
ahmadki Jun 9, 2025
5598f50
debug mode on
ahmadki Jun 9, 2025
c1b072f
new packing
ahmadki Jun 9, 2025
06de588
cleanup
ahmadki Jun 9, 2025
29a0cae
Merge branch 'main' into ahmadki/dev/sequence_packing_2
ahmadki Jun 11, 2025
4a3fa32
Merge branch 'sahilj/megatron_packed' into ahmadki/dev/sequence_packi…
ahmadki Jun 11, 2025
9dfddc0
implemented MFFD as a "SequencePacker", moved it to bin packing algor…
ahmadki Jun 15, 2025
ffb3b98
logging cleanup
ahmadki Jun 15, 2025
395155c
made packing algorithms naming more clear
ahmadki Jun 15, 2025
9b5c2fa
more code cleanup
ahmadki Jun 18, 2025
c314338
Merge branch 'main' into ahmadki/dev/sequence_packing_2
ahmadki Jun 18, 2025
bd9e081
reduce amount of diff with main
ahmadki Jun 18, 2025
0faea67
reduce amount of diff with main 2
ahmadki Jun 18, 2025
b4f9297
added back flash-attn dependency
ahmadki Jun 22, 2025
394e3fd
cleanup and config alignments
ahmadki Jun 22, 2025
c1cab59
Merge branch 'main' into ahmadki/dev/sequence_packing_2
ahmadki Jun 22, 2025
48fbbad
config alignments, configs for new implementation
ahmadki Jun 24, 2025
227f6b0
generic get_packer
ahmadki Jun 24, 2025
93d05f1
config syntax cleanup
ahmadki Jun 24, 2025
b858ef2
moved dtensor sequence packing functions into hf common
ahmadki Jun 24, 2025
67861dd
typed flash attention kwargs
ahmadki Jun 24, 2025
c25667f
dropped database based seq packing
ahmadki Jun 24, 2025
79848fc
typo
ahmadki Jun 24, 2025
7d50e22
unified loss_fn for seq packing
ahmadki Jun 24, 2025
163da64
config organization
ahmadki Jun 24, 2025
91d09e9
removed debug configs
ahmadki Jun 24, 2025
bda4219
more config cleanup
ahmadki Jun 24, 2025
e7e4038
removed PackedDataset
ahmadki Jun 29, 2025
66bd9c7
Merge branch 'main' into ahmadki/sequence_packing
ahmadki Jun 29, 2025
be32c8e
aligned NeMo git submodule with main
ahmadki Jun 29, 2025
d63e440
Merge branch 'main' into ahmadki/sequence_packing
SahilJain314 Jun 30, 2025
28410f5
Merge branch 'main' into ahmadki/sequence_packing
SahilJain314 Jul 1, 2025
ae8f12f
Lint fix
SahilJain314 Jul 1, 2025
e078be1
Load AutoModelForCausalLM weight in FP32
ahmadki Jul 1, 2025
0d2f2c0
Merge branch 'main' into ahmadki/sequence_packing
ahmadki Jul 7, 2025
344275c
Merge branch 'main' into ahmadki/sequence_packing
ahmadki Jul 9, 2025
43a401e
Merge branch 'main' into ahmadki/sequence_packing
ahmadki Jul 14, 2025
e1c22a1
fix cp_size reference after merge with main
ahmadki Jul 14, 2025
29d34ba
added missing megatron_cfg to grpo_math config
ahmadki Jul 14, 2025
3 changes: 3 additions & 0 deletions examples/configs/grpo-deepscaler-1.5b-8K.yaml
@@ -55,6 +55,9 @@ policy:
dynamic_batching:
enabled: False

sequence_packing:
enabled: False

# makes the training sequence length divisible by the tensor parallel size
# this is useful for sequence parallel training
make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size}
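The comment above notes that the training sequence length is made divisible by the tensor-parallel size. The arithmetic behind that is a plain round-up; a sketch (the helper name `round_up_to_multiple` is ours, not from the repo):

```python
def round_up_to_multiple(seq_len: int, divisor: int) -> int:
    """Round seq_len up to the nearest multiple of divisor."""
    return ((seq_len + divisor - 1) // divisor) * divisor

# a 1001-token sequence with tensor_parallel_size=8 pads to 1008,
# while 1000 tokens (already divisible) stay at 1000
print(round_up_to_multiple(1001, 8))  # 1008
```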
3 changes: 3 additions & 0 deletions examples/configs/grpo_deepscaler-1.5b-24K.yaml
@@ -21,6 +21,9 @@ policy:
dynamic_batching:
enabled: False

sequence_packing:
enabled: False

optimizer:
name: "torch.optim.AdamW"
kwargs:
10 changes: 10 additions & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -63,6 +63,16 @@ policy:
logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
sequence_length_round: 64

sequence_packing:
enabled: False
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
algorithm: "concatenative"
sequence_length_round: 64

megatron_cfg:
enabled: false

# makes the training sequence length divisible by the tensor parallel size
# this is useful for sequence parallel training
make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size}
4 changes: 0 additions & 4 deletions examples/configs/grpo_math_1B_megatron.yaml
@@ -51,12 +51,8 @@ policy:
# training and logprob stages respectively.
dynamic_batching:
enabled: False

sequence_packing:
enabled: False # coming soon
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
algorithm: "modified_ffd"
sequence_length_round: 64

max_grad_norm: 1.0
2 changes: 1 addition & 1 deletion examples/configs/grpo_math_8B_megatron.yaml
@@ -72,4 +72,4 @@ policy:

cluster:
gpus_per_node: 8
num_nodes: 1
num_nodes: 1
8 changes: 7 additions & 1 deletion examples/configs/sft.yaml
@@ -44,6 +44,12 @@ policy:
dynamic_batching:
enabled: false

sequence_packing:
enabled: False
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
algorithm: "concatenative"
sequence_length_round: 64

# makes the training sequence length divisible by the tensor parallel size
# this is useful for sequence parallel training
make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size}
@@ -121,7 +127,7 @@ policy:
average_in_collective: true
data_parallel_sharding_strategy: "optim_grads_params"


data:
max_input_seq_length: ${policy.max_total_sequence_length}
dataset_name: "squad"
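sft.yaml selects the `"concatenative"` packing algorithm. Assuming — as the name suggests — that it appends sequences in arrival order and opens a new pack once the token budget (`train_mb_tokens`) would be exceeded, the strategy can be sketched as follows (function name ours, not the repo's implementation):

```python
def concatenative_pack(seq_lengths: list[int], budget: int) -> list[list[int]]:
    """Greedily concatenate sequences in arrival order; start a new pack
    whenever adding the next sequence would exceed the token budget."""
    packs: list[list[int]] = []
    current: list[int] = []
    used = 0
    for idx, length in enumerate(seq_lengths):
        if current and used + length > budget:
            packs.append(current)
            current, used = [], 0
        # a sequence longer than the budget still gets its own pack
        current.append(idx)
        used += length
    if current:
        packs.append(current)
    return packs

# sequences of 300, 200, 600, 100 tokens with a 512-token budget:
print(concatenative_pack([300, 200, 600, 100], 512))  # [[0, 1], [2], [3]]
```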
3 changes: 3 additions & 0 deletions examples/configs/sft_openmathinstruct2.yaml
@@ -37,6 +37,9 @@ policy:
context_parallel_size: 1
custom_parallel_plan: null

sequence_packing:
enabled: False

dynamic_batching:
enabled: false

2 changes: 2 additions & 0 deletions examples/run_sft.py
@@ -31,6 +31,8 @@
from nemo_rl.utils.config import load_config, parse_hydra_overrides
from nemo_rl.utils.logger import get_next_experiment_dir

OmegaConf.register_new_resolver("mul", lambda a, b: a * b)


def parse_args():
"""Parse command line arguments."""
73 changes: 73 additions & 0 deletions nemo_rl/algorithms/loss_functions.py
@@ -601,3 +601,76 @@ def __call__(
"rewards_rejected_mean": rewards_rejected_mean.item(),
"num_valid_samples": num_valid_samples.item(),
}


class SequencePackingLossWrapper:
def __init__(
self,
loss_fn: LossFunction,
cu_seqlens_q: Tensor,
cu_seqlens_q_padded: Optional[Tensor] = None,
):
self.loss_fn = loss_fn
self.cu_seqlens_q = cu_seqlens_q
self.cu_seqlens_q_padded = cu_seqlens_q_padded

def __call__(
self,
next_token_logits: Tensor,
data: BatchedDataDict[Any],
global_valid_seqs: Tensor | None,
global_valid_toks: Tensor | None,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
) -> tuple[Tensor, dict[str, Any]]:
"""Wraps a loss function to handle sequence packing by doing one sequence at a time to avoid padding."""
unpadded_cu_seqlens = self.cu_seqlens_q
unpadded_seq_lengths = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1]
if self.cu_seqlens_q_padded is not None:
padded_cu_seqlens = self.cu_seqlens_q_padded
padded_seq_lengths = (
self.cu_seqlens_q_padded[1:] - self.cu_seqlens_q_padded[:-1]
)
else:
padded_cu_seqlens = unpadded_cu_seqlens
padded_seq_lengths = unpadded_seq_lengths
seq_starts = padded_cu_seqlens[:-1]
seq_ends = padded_cu_seqlens[1:]

loss_accum = 0
metrics_accum = {}
for seq_idx in range(len(seq_starts)):
seq_start = seq_starts[seq_idx].item()
seq_end = seq_ends[seq_idx].item()

# get sequence and unpad all 'data' tensors. The data dict is a BatchedDataDict of unpacked tensors
seq_data = data.slice(seq_idx, seq_idx + 1)
unpadded_seq_data = {}
for k, v in seq_data.items():
# only tensors with a sequence dimension need trimming to the unpadded length
if isinstance(v, torch.Tensor) and v.ndim > 1 and v.shape[1] > 1:
unpadded_seq_data[k] = v[:, : unpadded_seq_lengths[seq_idx]]
else:
unpadded_seq_data[k] = v

# get next_token_logits
next_token_logits_slice = next_token_logits[
:, seq_start : seq_start + unpadded_seq_lengths[seq_idx], :
]

loss, metrics = self.loss_fn(
next_token_logits_slice,
unpadded_seq_data,
global_valid_seqs,
global_valid_toks,
vocab_parallel_rank=vocab_parallel_rank,
vocab_parallel_group=vocab_parallel_group,
)
loss_accum += loss
for k, v in metrics.items():
if k not in metrics_accum:
metrics_accum[k] = 0
metrics_accum[k] += v

return loss_accum, metrics_accum
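`SequencePackingLossWrapper` recovers per-sequence boundaries from `cu_seqlens` (cumulative sequence lengths) and slices the packed logits one sequence at a time. The indexing can be illustrated torch-free with a toy per-sequence "loss" (a sum; the repo's `LossFunction` protocol is richer):

```python
def per_sequence_losses(packed: list[float], cu_seqlens: list[int]) -> list[float]:
    """Split a packed 1-D buffer at the cumulative offsets and compute a toy
    per-sequence loss (here: a sum), mirroring how the wrapper slices
    next_token_logits using seq_start and the per-sequence length."""
    # cu_seqlens[i+1] - cu_seqlens[i] is the length of sequence i
    seq_lengths = [b - a for a, b in zip(cu_seqlens[:-1], cu_seqlens[1:])]
    losses = []
    for start, length in zip(cu_seqlens[:-1], seq_lengths):
        seq = packed[start : start + length]
        losses.append(sum(seq))
    return losses

# three sequences of lengths 2, 3, 1 packed into one buffer
packed = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
print(per_sequence_losses(packed, [0, 2, 5, 6]))  # [3.0, 12.0, 6.0]
```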
35 changes: 35 additions & 0 deletions nemo_rl/data/packing/__init__.py
@@ -0,0 +1,35 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo_rl.data.packing.algorithms import (
ConcatenativePacker,
FirstFitDecreasingPacker,
FirstFitShufflePacker,
ModifiedFirstFitDecreasingPacker,
PackingAlgorithm,
SequencePacker,
get_packer,
)
from nemo_rl.data.packing.metrics import PackingMetrics

__all__ = [
"PackingAlgorithm",
"SequencePacker",
"ConcatenativePacker",
"FirstFitDecreasingPacker",
"FirstFitShufflePacker",
"ModifiedFirstFitDecreasingPacker",
"get_packer",
"PackingMetrics",
]
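`FirstFitDecreasingPacker` is named after the classic bin-packing heuristic: sort sequences longest-first and place each into the first bin with room. A minimal sketch of that heuristic (ours — the exported class's actual interface may differ):

```python
def first_fit_decreasing(seq_lengths: list[int], bin_capacity: int) -> list[list[int]]:
    """Place sequences (longest first) into the first bin with room,
    opening a new bin when none fits. Returns bins of sequence indices."""
    order = sorted(range(len(seq_lengths)), key=lambda i: seq_lengths[i], reverse=True)
    bins: list[list[int]] = []
    remaining: list[int] = []  # spare capacity per bin
    for idx in order:
        length = seq_lengths[idx]
        for b, room in enumerate(remaining):
            if length <= room:
                bins[b].append(idx)
                remaining[b] -= length
                break
        else:
            bins.append([idx])
            remaining.append(bin_capacity - length)
    return bins

# lengths [5, 1, 4, 3, 2] with capacity 6: sorted order 5, 4, 3, 2, 1
# packs as (5+1), (4+2), (3) by index:
print(first_fit_decreasing([5, 1, 4, 3, 2], 6))  # [[0, 1], [2, 4], [3]]
```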