159 commits
097de0d
Fixed ~100 type errors
SahilJain314 May 9, 2025
21afaf7
fixed 50 more type issues
SahilJain314 May 9, 2025
f38ab81
lint
SahilJain314 May 9, 2025
e5296bc
Added mypy config, fixed 50 more type errors, and updated old typing.…
SahilJain314 May 11, 2025
a9cee8f
Updated pyproject
SahilJain314 May 11, 2025
89cf9cf
Down to 100 errors
SahilJain314 May 11, 2025
e8160c3
Update pyproject
SahilJain314 May 11, 2025
6a494b9
Down to 50 errors
SahilJain314 May 12, 2025
8a95964
Added testing doc
SahilJain314 May 12, 2025
3b0bfbc
Fixed missing import
SahilJain314 May 12, 2025
588cfda
Fixed tokenizer type
SahilJain314 May 12, 2025
bc07e76
Down to 50 errors
SahilJain314 May 12, 2025
72e4819
fixed 150 strict mypy errors
SahilJain314 May 12, 2025
9bb8d83
lint
SahilJain314 May 12, 2025
7376384
Fixed another 100 strict typing mypy errors
SahilJain314 May 12, 2025
b3f5098
Fixed another 100 strict typing mypy errors (down to 130)
SahilJain314 May 12, 2025
e8b3552
Brought non-strict errors down to 18
SahilJain314 May 13, 2025
7aa1f44
Brought non-strict errors down further
SahilJain314 May 13, 2025
958f538
Fixed pynvml test type
SahilJain314 May 13, 2025
6a14aec
feat: support mcore extra (megatron + tron)
terrykong May 2, 2025
17f5136
typo
terrykong May 17, 2025
b17e163
all good
terrykong May 17, 2025
5b3b96d
pin pre-commit
terrykong May 17, 2025
48f8ffc
undo
terrykong May 17, 2025
8aa6ff8
move submodule stuff into comment until
terrykong May 18, 2025
c5f4305
ok
terrykong May 18, 2025
fdc6024
remove this round
terrykong May 18, 2025
5c22f20
fix
terrykong May 18, 2025
23c0943
dockerignore too
terrykong May 18, 2025
7d2d8dc
Moved all actors to using NamedSharding-based distribution instead of…
SahilJain314 May 20, 2025
9cd192b
oops forgot named_sharding files
SahilJain314 May 20, 2025
9952110
lint
SahilJain314 May 20, 2025
e599e76
Merge remote-tracking branch 'origin/main' into sahilj/type_fix
SahilJain314 May 20, 2025
c917363
Updated with tot merge
SahilJain314 May 20, 2025
a843293
Updated configure_generation_config
SahilJain314 May 21, 2025
11b2f42
updated uv lock
SahilJain314 May 21, 2025
d72dfcb
original uv lock
SahilJain314 May 21, 2025
18aee25
updated uv lock
SahilJain314 May 21, 2025
a9c285b
original uv
SahilJain314 May 21, 2025
b678022
Merge remote-tracking branch 'origin/main' into sahilj/type_fix
SahilJain314 May 21, 2025
3b241df
updated uv lock
SahilJain314 May 21, 2025
40c6275
Pushed coordinate finding into the NamedSharding
SahilJain314 May 21, 2025
aacb783
Unit test failure
SahilJain314 May 21, 2025
c477048
Added tests and fixed types
SahilJain314 May 21, 2025
c0043c1
Added mypy to ci (just a warning rn)
SahilJain314 May 21, 2025
e3f98ea
Merge remote-tracking branch 'origin/sahilj/type_fix' into sahilj/nam…
SahilJain314 May 21, 2025
76d510a
try setting max_jobs really low
terrykong May 21, 2025
3058683
Merge remote-tracking branch 'origin' into sahilj/named_sharding
SahilJain314 May 21, 2025
c38fcb4
Merge branch 'tk/megatron-extra' into sahilj/megatron_tot
SahilJain314 May 22, 2025
3f9667f
Added Megatron
SahilJain314 May 22, 2025
4b82ffd
Megatron fixes
SahilJain314 May 23, 2025
e36935d
tot bugfixes
SahilJain314 May 27, 2025
36d41ba
Updated git module
SahilJain314 May 27, 2025
9678444
Added kwargs to save checkpoint
SahilJain314 May 27, 2025
e5bcd1e
Fixed checkpointing Megatron
SahilJain314 May 28, 2025
f9be255
Don't bother with rng restore
SahilJain314 May 28, 2025
42800e6
Fixed metric logging
SahilJain314 May 28, 2025
dc7920b
lint
SahilJain314 May 28, 2025
f88bbc3
Updated nemo patch
SahilJain314 May 29, 2025
8910bd4
Updated patch
SahilJain314 May 29, 2025
84bd407
Fixed memory offloading for parameter and grad buffers
SahilJain314 Jun 3, 2025
2271229
fix: Don't call state_dict in loop + dtype fix (#445)
yfw Jun 3, 2025
b08ab1c
Enable dynamic batching
SahilJain314 Jun 3, 2025
bbf8de7
Fixed pp bug
SahilJain314 Jun 3, 2025
41a542c
lint
SahilJain314 Jun 3, 2025
6b6355e
Merge remote-tracking branch 'origin' into sahilj/megatron_tot
SahilJain314 Jun 3, 2025
8ad9565
Fixed merge artifact
SahilJain314 Jun 4, 2025
f8f55a8
Fixes for tests
SahilJain314 Jun 4, 2025
1688ff2
Fixed dynamic batching and improved memory usage
SahilJain314 Jun 4, 2025
af9f767
default expandable segments on
SahilJain314 Jun 4, 2025
678caf7
Added basic sequence packing
SahilJain314 Jun 10, 2025
ec9b174
Added basic sequence packing
SahilJain314 Jun 10, 2025
72ff438
Fixed PP with sequence packing
SahilJain314 Jun 10, 2025
7142efa
Updated Megatron patch
SahilJain314 Jun 10, 2025
ce21205
Remove custom_fsdp mentions
SahilJain314 Jun 11, 2025
7b3cc97
Bump ray
SahilJain314 Jun 11, 2025
6608e62
Added a 70b config with megatron
SahilJain314 Jun 11, 2025
e0002c4
Merge branch 'sahilj/megatron_packed' of github.com:NVIDIA/NeMo-RL in…
SahilJain314 Jun 11, 2025
6a24f56
Sequence packing
ahmadki Jun 3, 2025
04859b7
checkpoint
ahmadki Jun 4, 2025
7592a35
revert some packing changes
ahmadki Jun 4, 2025
6c57e22
initial fix for different micro batch lengths
ahmadki Jun 6, 2025
051b329
benchmark configs
ahmadki Jun 6, 2025
fbbd77a
minor fixes
ahmadki Jun 8, 2025
88919cf
grpo configs
ahmadki Jun 8, 2025
4fb888d
grpo fixes so code would run
ahmadki Jun 8, 2025
1da878e
loss function fixes, added packing strategy to policy
ahmadki Jun 9, 2025
db26d63
SFT config/API cleanup
ahmadki Jun 9, 2025
5598f50
debug mode on
ahmadki Jun 9, 2025
c1b072f
new packing
ahmadki Jun 9, 2025
06de588
cleanup
ahmadki Jun 9, 2025
29a0cae
Merge branch 'main' into ahmadki/dev/sequence_packing_2
ahmadki Jun 11, 2025
34bca2e
Merge branch 'sahilj/megatron_packed' into ahmadki/dev/sequence_packi…
ahmadki Jun 11, 2025
41972d7
implemented MFFD as a "SequencePacker", moved it to bin packing algor…
ahmadki Jun 15, 2025
ab47ed1
logging cleanup
ahmadki Jun 15, 2025
35e0421
made packing algorithms naming more clear
ahmadki Jun 15, 2025
e2a3375
more code cleanup
ahmadki Jun 18, 2025
099ccce
Merge branch 'main' into ahmadki/dev/sequence_packing_2
ahmadki Jun 18, 2025
1c8cc46
reduce amount of diff with main
ahmadki Jun 18, 2025
76094e0
reduce amount of diff with main 2
ahmadki Jun 18, 2025
fd75847
added back flash-attn dependency
ahmadki Jun 22, 2025
773c6db
cleanup and config alignments
ahmadki Jun 22, 2025
8f64913
Merge branch 'main' into ahmadki/dev/sequence_packing_2
ahmadki Jun 22, 2025
61804e4
config alignments, configs for new implementation
ahmadki Jun 24, 2025
67c2958
generic get_packer
ahmadki Jun 24, 2025
ad31fce
config syntax cleanup
ahmadki Jun 24, 2025
e5811c2
moved dtensor sequence packing functions into hf common
ahmadki Jun 24, 2025
9f35db8
typed flash attention kwargs
ahmadki Jun 24, 2025
8674b96
dropped database based seq packing
ahmadki Jun 24, 2025
948096d
typo
ahmadki Jun 24, 2025
c3c8b66
unified loss_fn for seq packing
ahmadki Jun 24, 2025
ff27c79
config organization
ahmadki Jun 24, 2025
106ef8c
removed debug configs
ahmadki Jun 24, 2025
3b1b444
more config cleanup
ahmadki Jun 24, 2025
dc3e2d6
removed PackedDataset
ahmadki Jun 29, 2025
84be3b8
Merge branch 'main' into ahmadki/sequence_packing
ahmadki Jun 29, 2025
fe5e0e1
aligned NeMo git submodule with main
ahmadki Jun 29, 2025
6f9e203
Merge branch 'main' into ahmadki/sequence_packing
SahilJain314 Jun 30, 2025
d81a24d
Merge branch 'main' into ahmadki/sequence_packing
SahilJain314 Jul 1, 2025
b4c56b6
Lint fix
SahilJain314 Jul 1, 2025
222dc8a
Load AutoModelForCausalLM weight in FP32
ahmadki Jul 1, 2025
a9ec7cd
Merge branch 'main' into ahmadki/sequence_packing
ahmadki Jul 7, 2025
049e193
Merge branch 'main' into ahmadki/sequence_packing
ahmadki Jul 9, 2025
03208bc
Updating seq packing algo to modified ffd
SahilJain314 Jul 1, 2025
3764f77
Enabling sequence packing by default for megatron
SahilJain314 Jul 1, 2025
bd2a393
Critical sequence packing fixes for Megatron
SahilJain314 Jul 3, 2025
f2db981
Init CP (no PP)
SahilJain314 Jul 10, 2025
c0d2898
Fixed CP + PP
SahilJain314 Jul 11, 2025
0543108
Cleanup
SahilJain314 Jul 11, 2025
2b1b4b5
Fixed unit tests
SahilJain314 Jul 11, 2025
24cf74f
lint
SahilJain314 Jul 11, 2025
da84d6f
copyright
SahilJain314 Jul 11, 2025
e14e41b
copyright
SahilJain314 Jul 11, 2025
790d803
Merge branch 'main' into sahilj/cp-rebase
SahilJain314 Jul 11, 2025
3e22ec2
bugfix
SahilJain314 Jul 11, 2025
75e2465
PR fixes
SahilJain314 Jul 11, 2025
3dadbf2
Merge remote-tracking branch 'origin' into sahilj/cp-rebase
SahilJain314 Jul 11, 2025
508fdf4
PR Fixes
SahilJain314 Jul 11, 2025
9c2b013
Update nemo_rl/models/policy/__init__.py
SahilJain314 Jul 11, 2025
8b9ed7e
Update tests/unit/data/packing/test_algorithms.py
SahilJain314 Jul 11, 2025
b386ab3
Lint, also adding Ahmad as Coauthor
SahilJain314 Jul 11, 2025
11e3b7d
Fixed dtensor sequence packing
SahilJain314 Jul 17, 2025
a4416f2
Merge remote-tracking branch 'origin/main' into sahilj/cp-rebase
SahilJain314 Jul 17, 2025
9ee7bf5
Fixed NeMo commit merge
SahilJain314 Jul 17, 2025
1e95143
feat: Enable CP during get_logprobs for dtensor worker. (#678)
joyang-nv Jul 18, 2025
5e497ce
Try unit fix
SahilJain314 Jul 18, 2025
62a6f01
fix: remove unnecessary ray initialization since it's handled at the …
terrykong Jul 19, 2025
4500cde
Unit fix
SahilJain314 Jul 21, 2025
36f67ac
docs: update converter path in README. (#672)
xxman-google Jul 17, 2025
75e2f69
fix: make mcore lr scheduler configuration consistent with dtensor (#…
ashors1 Jul 17, 2025
d431685
fix: fix mcore LR increment (#685)
ashors1 Jul 17, 2025
f9ef28f
fix: upgrade datasets to fix squad download (#692)
ashors1 Jul 18, 2025
cc31642
fix: Megatron config updates to avoid OOM (#687)
ashors1 Jul 18, 2025
df53dbc
fix: fix lr scheduler for config that was missed in #681 (#693)
ashors1 Jul 18, 2025
e93284d
fix: Fix gemma models broken by HF update (#676)
yfw Jul 19, 2025
172dd0a
chore: add CP+SP (sequence parallel) assertion in DTensor worker (#689)
yuki-97 Jul 19, 2025
9083a2e
Lint
SahilJain314 Jul 21, 2025
db08e14
Fixed generation test
SahilJain314 Jul 21, 2025
75bb1b2
revert conftest
SahilJain314 Jul 21, 2025
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -351,7 +351,7 @@ If you have trained a model and saved the checkpoint in the Pytorch DCP format,

```sh
# Example for a GRPO checkpoint at step 170
uv run python examples/convert_dcp_to_hf.py \
uv run python examples/converters/convert_dcp_to_hf.py \
--config results/grpo/step_170/config.yaml \
--dcp-ckpt-path results/grpo/step_170/policy/weights/ \
--hf-ckpt-path results/grpo/hf
7 changes: 6 additions & 1 deletion docs/model-quirks.md
@@ -31,8 +31,13 @@ NeMo-RL uses the vLLM V1 runtime for both synchronous and asynchronous inference

### Context Parallel with FSDP2

NeMo-RL implemented this feature based on torch CP [implementation](https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/experimental/_attention.py). And we inherit its limitations.
- NeMo-RL implements this feature on top of the torch CP [implementation](https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/experimental/_attention.py) and inherits its limitations.
- Whether a model supports CP depends solely on the arguments passed to `torch.nn.functional.scaled_dot_product_attention`. NeMo-RL currently passes an all-ones attention mask to `model.forward`. Gemma-3 does not ignore this mask, so `attn_bias` ends up non-None, which torch CP does not support. See this [assertion](https://github.com/pytorch/pytorch/blob/134179474539648ba7dee1317959529fbd0e7f89/torch/distributed/tensor/experimental/_attention.py#L262).
- Context parallel cannot be used together with sequence packing. Sequence packing requires `attn_implementation="flash_attention_2"`, which conflicts with context parallel's requirement for the SDPA implementation. Refer to [here](https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/modeling_utils.py#L2317) for more details.

- It is a known issue that context parallel cannot be used together with sequence parallel. Refer to [here](https://github.com/NVIDIA-NeMo/RL/issues/659) for more details.

## vLLM Async Rollout Timeout

3 changes: 3 additions & 0 deletions examples/configs/grpo-deepscaler-1.5b-8K.yaml
@@ -55,6 +55,9 @@ policy:
dynamic_batching:
enabled: False

sequence_packing:
enabled: False

# makes the training sequence length divisible by the tensor parallel size
# this is useful for sequence parallel training
make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size}
3 changes: 3 additions & 0 deletions examples/configs/grpo_deepscaler-1.5b-24K.yaml
@@ -21,6 +21,9 @@ policy:
dynamic_batching:
enabled: False

sequence_packing:
enabled: False

optimizer:
name: "torch.optim.AdamW"
kwargs:
10 changes: 10 additions & 0 deletions examples/configs/grpo_math_1B.yaml
@@ -51,16 +51,26 @@ policy:
tensor_parallel_size: 1
context_parallel_size: 1
custom_parallel_plan: null

megatron_cfg:
enabled: false

# dynamic_batching improves performance by ensuring logprob and training microbatches
# have a sufficient number of tokens to maximize GPU utilization. Specifically, variable-length
# responses are sorted by sequence length and bucketed into microbatches whose total token
# count is approximately 'train_mb_tokens' and 'logprob_mb_tokens' for the
# training and logprob stages respectively.
dynamic_batching:
enabled: False
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
sequence_length_round: 64

sequence_packing:
enabled: True
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
algorithm: "modified_first_fit_decreasing"
sequence_length_round: 64

# makes the training sequence length divisible by the tensor parallel size
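The `"modified_first_fit_decreasing"` algorithm selected above is a variant of classic first-fit-decreasing bin packing. A minimal sketch of the plain FFD idea (function and variable names are illustrative, not NeMo-RL's actual API): sort sequences longest-first, then drop each one into the first microbatch "bin" that still has room under the token budget.

```python
# Illustrative first-fit-decreasing packer; not NeMo-RL's actual implementation.
def first_fit_decreasing(seq_lengths, max_tokens):
    bins = []          # each bin is a list of sequence indices
    bin_totals = []    # running token count per bin
    # visit sequences longest-first
    order = sorted(range(len(seq_lengths)), key=lambda i: -seq_lengths[i])
    for i in order:
        length = seq_lengths[i]
        for b, total in enumerate(bin_totals):
            if total + length <= max_tokens:
                bins[b].append(i)
                bin_totals[b] += length
                break
        else:
            # no existing bin fits: open a new one
            bins.append([i])
            bin_totals.append(length)
    return bins

# e.g. packing responses of varying length into 512-token microbatches
packs = first_fit_decreasing([300, 200, 180, 100, 60], max_tokens=512)
print(packs)  # [[0, 1], [2, 3, 4]]
```

The "modified" variants referenced in the configs tune this baseline for training-specific constraints (e.g. padding/rounding via `sequence_length_round`), but the core packing loop is the same.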
11 changes: 8 additions & 3 deletions examples/configs/grpo_math_1B_megatron.yaml
@@ -49,14 +49,19 @@ policy:
# responses are sorted by sequence length and bucketed into microbatches whose total token
# count is approximately 'train_mb_tokens' and 'logprob_mb_tokens' for the
# training and logprob stages respectively.
#
# We disable it for Megatron as it is incompatible with pipeline parallelism. Instead, we use sequence packing.
dynamic_batching:
enabled: False
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
sequence_length_round: 64

sequence_packing:
enabled: False # coming soon
enabled: True
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
algorithm: "modified_ffd"
algorithm: "modified_first_fit_decreasing"
sequence_length_round: 64

max_grad_norm: 1.0
@@ -116,7 +121,7 @@ policy:
weight_decay_incr_style: "constant"
lr_decay_style: "constant"
lr_decay_iters: null
lr_warmup_iters: 50
lr_warmup_iters: 13
lr_warmup_init: 5.0e-7

distributed_data_parallel_config:
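The dynamic-batching comment in the config above can be sketched in a few lines (a simplified illustration under assumed names, not NeMo-RL's actual code): sort variable-length responses by length, then greedily cut microbatches so each holds roughly `mb_tokens` tokens.

```python
# Illustrative dynamic-batching bucketer; a simplification of the behavior
# described in the config comments, not the actual NeMo-RL implementation.
def bucket_by_tokens(seq_lengths, mb_tokens):
    # sort indices by sequence length so similar lengths share a microbatch
    order = sorted(range(len(seq_lengths)), key=lambda i: seq_lengths[i])
    batches, current, current_toks = [], [], 0
    for i in order:
        if current and current_toks + seq_lengths[i] > mb_tokens:
            batches.append(current)
            current, current_toks = [], 0
        current.append(i)
        current_toks += seq_lengths[i]
    if current:
        batches.append(current)
    return batches

print(bucket_by_tokens([60, 100, 180, 200, 300], mb_tokens=512))
# [[0, 1, 2], [3, 4]]
```

Unlike sequence packing, each sequence here still occupies its own row; the bucketing only controls how many rows go into one microbatch, which is why it breaks down under pipeline parallelism (variable microbatch counts per stage) and packing is used instead.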
4 changes: 2 additions & 2 deletions examples/configs/grpo_math_70B_megatron.yaml
@@ -49,7 +49,7 @@ policy:
weight_decay_incr_style: "constant"
lr_decay_style: "constant"
lr_decay_iters: null
lr_warmup_iters: 50
lr_warmup_iters: 13
lr_warmup_init: 3.0e-8

generation:
@@ -62,7 +62,7 @@ policy:
stop_strings: null
vllm_cfg:
tensor_parallel_size: 4
gpu_memory_utilization: 0.8
gpu_memory_utilization: 0.6
max_model_len: ${policy.max_total_sequence_length}

cluster:
6 changes: 3 additions & 3 deletions examples/configs/grpo_math_8B_megatron.yaml
@@ -54,7 +54,7 @@ policy:
weight_decay_incr_style: "constant"
lr_decay_style: "constant"
lr_decay_iters: null
lr_warmup_iters: 50
lr_warmup_iters: 13
lr_warmup_init: 3.0e-8

generation:
@@ -67,9 +67,9 @@ policy:
stop_strings: null
vllm_cfg:
tensor_parallel_size: 1
gpu_memory_utilization: 0.8
gpu_memory_utilization: 0.6
max_model_len: ${policy.max_total_sequence_length}

cluster:
gpus_per_node: 8
num_nodes: 1
num_nodes: 1
11 changes: 6 additions & 5 deletions examples/configs/grpo_math_qwen30ba3b_megatron.yaml
@@ -29,11 +29,11 @@ policy:
enabled: true
empty_unused_memory_level: 1
converter_type: "LlamaForCausalLM"
tensor_model_parallel_size: 4
pipeline_model_parallel_size: 4
tensor_model_parallel_size: 2
pipeline_model_parallel_size: 1
context_parallel_size: 1
expert_tensor_parallel_size: 1
expert_model_parallel_size: 4
expert_model_parallel_size: 8
sequence_parallel: True
pipeline_dtype: ${policy.precision}

@@ -52,7 +52,7 @@ policy:
weight_decay_incr_style: "constant"
lr_decay_style: "constant"
lr_decay_iters: null
lr_warmup_iters: 50
lr_warmup_iters: 13
lr_warmup_init: 3.0e-8

env_vars:
@@ -68,7 +68,8 @@ policy:
stop_strings: null
vllm_cfg:
tensor_parallel_size: 4
gpu_memory_utilization: 0.8
gpu_memory_utilization: 0.7
enforce_eager: false
max_model_len: ${policy.max_total_sequence_length}

cluster:
8 changes: 7 additions & 1 deletion examples/configs/sft.yaml
@@ -44,6 +44,12 @@ policy:
dynamic_batching:
enabled: false

sequence_packing:
enabled: False
train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
algorithm: "modified_first_fit_decreasing"
sequence_length_round: 64

# makes the training sequence length divisible by the tensor parallel size
# this is useful for sequence parallel training
make_sequence_length_divisible_by: ${policy.dtensor_cfg.tensor_parallel_size}
@@ -121,7 +127,7 @@ policy:
average_in_collective: true
data_parallel_sharding_strategy: "optim_grads_params"


data:
max_input_seq_length: ${policy.max_total_sequence_length}
dataset_name: "squad"
3 changes: 3 additions & 0 deletions examples/configs/sft_openmathinstruct2.yaml
@@ -37,6 +37,9 @@ policy:
context_parallel_size: 1
custom_parallel_plan: null

sequence_packing:
enabled: False

dynamic_batching:
enabled: false

2 changes: 2 additions & 0 deletions examples/run_sft.py
@@ -31,6 +31,8 @@
from nemo_rl.utils.config import load_config, parse_hydra_overrides
from nemo_rl.utils.logger import get_next_experiment_dir

OmegaConf.register_new_resolver("mul", lambda a, b: a * b)


def parse_args():
"""Parse command line arguments."""
95 changes: 95 additions & 0 deletions nemo_rl/algorithms/loss_functions.py
@@ -114,6 +114,7 @@ def __call__(
global_valid_toks: torch.Tensor,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
) -> tuple[torch.Tensor, dict]:
"""Clipped Policy Gradient RL loss function."""
token_mask = data["token_mask"][:, 1:]
@@ -149,7 +150,10 @@
vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1],
tp_group=vocab_parallel_group,
inference_only=False,
cp_group=context_parallel_group,
)
# slice off to the correct length to remove potential CP padding
curr_logprobs = curr_logprobs[:, : data["input_ids"].shape[1] - 1]
elif isinstance(next_token_logits, torch.distributed.tensor.DTensor):
curr_logprobs = get_logprobs_from_vocab_parallel_logits(
next_token_logits, data["input_ids"], seq_index=seq_index
@@ -312,6 +316,7 @@ def __call__(
global_valid_toks: Tensor,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
dpo_loss: bool = False,
dpo_average_log_probs: bool = False,
) -> tuple[torch.Tensor, dict[str, Any]]:
@@ -335,7 +340,10 @@
vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1],
tp_group=vocab_parallel_group,
inference_only=False,
cp_group=context_parallel_group,
)
# slice off to the correct length to remove potential CP padding
token_logprobs = token_logprobs[:, : data["input_ids"].shape[1] - 1]
elif isinstance(next_token_logits, torch.distributed.tensor.DTensor):
token_logprobs = get_logprobs_from_vocab_parallel_logits(
next_token_logits, data["input_ids"]
@@ -466,6 +474,7 @@ def _preference_loss(
global_valid_seqs: Tensor,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
## TODO(@ashors): there's some duplicate code here with the NLLLoss function. We should refactor
token_mask = data["token_mask"][:, 1:]
@@ -483,7 +492,10 @@
vocab_end_index=(vocab_parallel_rank + 1) * next_token_logits.shape[-1],
tp_group=vocab_parallel_group,
inference_only=False,
cp_group=context_parallel_group,
)
# slice off to the correct length to remove potential CP padding
token_logprobs = token_logprobs[:, : data["input_ids"].shape[1] - 1]
elif isinstance(next_token_logits, torch.distributed.tensor.DTensor):
token_logprobs = get_logprobs_from_vocab_parallel_logits(
next_token_logits, data["input_ids"]
@@ -548,6 +560,7 @@ def __call__(
global_valid_toks: Tensor | None,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
) -> tuple[torch.Tensor, dict[str, Any]]:
sft_loss_chosen = torch.tensor(0.0)
if self.sft_loss_weight > 0:
@@ -561,6 +574,7 @@
global_valid_toks=global_valid_toks, ## unused because sft loss returned is at the sample level
vocab_parallel_rank=vocab_parallel_rank,
vocab_parallel_group=vocab_parallel_group,
context_parallel_group=context_parallel_group,
dpo_loss=True,
dpo_average_log_probs=self.sft_average_log_probs,
)
@@ -582,6 +596,7 @@
global_valid_seqs,
vocab_parallel_rank=vocab_parallel_rank,
vocab_parallel_group=vocab_parallel_group,
context_parallel_group=context_parallel_group,
)

dpo_loss = (
@@ -601,3 +616,83 @@
"rewards_rejected_mean": rewards_rejected_mean.item(),
"num_valid_samples": num_valid_samples.item(),
}


class SequencePackingLossWrapper:
def __init__(
self,
loss_fn: LossFunction,
cu_seqlens_q: Tensor,
cu_seqlens_q_padded: Optional[Tensor] = None,
):
self.loss_fn = loss_fn
self.cu_seqlens_q = cu_seqlens_q
self.cu_seqlens_q_padded = cu_seqlens_q_padded

def __call__(
self,
next_token_logits: Tensor,
data: BatchedDataDict[Any],
global_valid_seqs: Tensor | None,
global_valid_toks: Tensor | None,
vocab_parallel_rank: Optional[int] = None,
vocab_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
context_parallel_group: Optional[torch.distributed.ProcessGroup] = None,
) -> tuple[Tensor, dict[str, Any]]:
"""Wraps a loss function to handle sequence packing by doing one sequence at a time to avoid excessive padding."""
unpadded_cu_seqlens = self.cu_seqlens_q
unpadded_seq_lengths = self.cu_seqlens_q[1:] - self.cu_seqlens_q[:-1]
if self.cu_seqlens_q_padded is not None:
padded_cu_seqlens = self.cu_seqlens_q_padded
padded_seq_lengths = (
self.cu_seqlens_q_padded[1:] - self.cu_seqlens_q_padded[:-1]
)
else:
padded_cu_seqlens = unpadded_cu_seqlens
padded_seq_lengths = unpadded_seq_lengths
seq_starts = padded_cu_seqlens[:-1]
seq_ends = padded_cu_seqlens[1:]

loss_accum = 0
metrics_accum = {}
for seq_idx in range(len(seq_starts)):
seq_start = seq_starts[seq_idx].item()
seq_end = seq_ends[seq_idx].item()

# get sequence and unpad all 'data' tensors. The data dict is a BatchedDataDict of unpacked tensors
seq_data = data.slice(seq_idx, seq_idx + 1)
unpadded_seq_data = {}
for k, v in seq_data.items():
if isinstance(v, torch.Tensor) and v.ndim > 1 and v.shape[1] > 1:
unpadded_seq_data[k] = v[:, : unpadded_seq_lengths[seq_idx]]
else:
unpadded_seq_data[k] = v

# get next_token_logits
cp_size = (
1
if context_parallel_group is None
else torch.distributed.get_world_size(context_parallel_group)
)
logit_slice_idxs = slice(
seq_start // cp_size,
(seq_start + padded_seq_lengths[seq_idx]) // cp_size,
)
next_token_logits_slice = next_token_logits[:, logit_slice_idxs, :]

loss, metrics = self.loss_fn(
next_token_logits_slice,
unpadded_seq_data,
global_valid_seqs,
global_valid_toks,
vocab_parallel_rank=vocab_parallel_rank,
vocab_parallel_group=vocab_parallel_group,
context_parallel_group=context_parallel_group,
)
loss_accum += loss
for k, v in metrics.items():
if k not in metrics_accum:
metrics_accum[k] = 0
metrics_accum[k] += v

return loss_accum, metrics_accum
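The `cu_seqlens_q` bookkeeping that `SequencePackingLossWrapper` iterates over follows the flash-attention varlen convention: cumulative sequence boundaries into a packed buffer. A minimal sketch of that convention in plain Python (names illustrative):

```python
# Sketch of cu_seqlens-style unpacking: cumulative boundaries let a packed
# buffer be sliced back into its component sequences, which is essentially
# what the wrapper above does per sequence before calling the inner loss_fn.
def unpack(packed, cu_seqlens):
    return [packed[s:e] for s, e in zip(cu_seqlens[:-1], cu_seqlens[1:])]

# three sequences of lengths 3, 2, and 4 packed back-to-back
packed = [1, 2, 3, 4, 5, 6, 7, 8, 9]
cu_seqlens = [0, 3, 5, 9]
print(unpack(packed, cu_seqlens))  # [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
```

The wrapper additionally handles a padded variant (`cu_seqlens_q_padded`) and divides the logit slice boundaries by the context-parallel world size, since each CP rank holds only its shard of the packed sequence dimension.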