Merged
28 commits
b593561
first commit, will clean up and add local tests
pengdurice Feb 26, 2026
52b85d4
code clean up and unit tests added
pengdurice Feb 26, 2026
9ad25fe
more code clean up
pengdurice Feb 26, 2026
322bea8
remove a print line
pengdurice Feb 27, 2026
fbc7a7d
remove a print line
pengdurice Feb 27, 2026
57fd917
rebase on main and then smoothout again
pengdurice Mar 2, 2026
a636341
fix the wiring of linear ce fusion
pengdurice Mar 3, 2026
fc199e5
address comments.
pengdurice Mar 4, 2026
3e10e27
add nightly config and sh files and tested locally
pengdurice Mar 5, 2026
e314b90
Update nemo_rl/models/megatron/setup.py
pengdurice Mar 5, 2026
71e5275
Update nemo_rl/algorithms/sft.py
pengdurice Mar 12, 2026
a18b24d
save local and rebase upstream
pengdurice Mar 12, 2026
de5ee1f
fix according to comments, maybe need some more
pengdurice Mar 13, 2026
f278ada
address more comments
pengdurice Mar 13, 2026
72a7254
Apply suggestions from code review
pengdurice Mar 13, 2026
86fd0b0
address comments and cleaning up
pengdurice Mar 13, 2026
797ee8d
remove local test env setup
pengdurice Mar 13, 2026
0485dee
some more clean up
pengdurice Mar 13, 2026
8296c6d
Update nemo_rl/models/megatron/setup.py
pengdurice Mar 17, 2026
fb26086
address comments
pengdurice Mar 17, 2026
b96b569
Merge branch 'main' into peng-add-linear-ce-fusion-v1
yuki-97 Mar 18, 2026
de97d32
fix ruff-format issues
pengdurice Mar 18, 2026
eabe55d
Update nemo_rl/models/megatron/setup.py
pengdurice Mar 18, 2026
751ccb0
Update nemo_rl/models/policy/__init__.py
pengdurice Mar 18, 2026
a6cee3a
Update nemo_rl/models/megatron/setup.py
pengdurice Mar 18, 2026
c8b0a97
run pre-commit, also change another place for accessing use_linear_ce…
pengdurice Mar 18, 2026
74bdffb
Apply suggestion
yuki-97 Mar 18, 2026
c87b344
add nightly fix
pengdurice Mar 18, 2026
31 changes: 30 additions & 1 deletion docs/guides/sft.md
@@ -320,4 +320,33 @@ uv run examples/run_sft.py \
policy.megatron_cfg.peft.enabled=true
```

For more details on LoRA, see [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685).

## Optimizations

### Chunked Linear Cross-Entropy Fusion Loss

During standard SFT training, the model materializes a full logit tensor of shape `[batch_size, seq_length, vocab_size]`, which can cause out-of-memory (OOM) errors for long sequences or large vocabularies. The **chunked linear cross-entropy fusion loss** avoids this by computing the loss directly from the hidden states: it chunks the sequence dimension, projects each chunk to logits on the fly, computes per-token log probabilities, and discards the logits before moving to the next chunk.
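
The chunked computation can be illustrated with a small pure-Python sketch. This is illustrative only; the names (`full_nll`, `chunked_nll`) are hypothetical and unrelated to the actual implementation, which operates on GPU tensors:

```python
import math

def full_nll(hidden, weight, targets):
    """Reference path: materialize all logits [seq, vocab] at once."""
    logits = [[sum(h * w for h, w in zip(hv, wv)) for wv in weight] for hv in hidden]
    return sum(
        math.log(sum(math.exp(x) for x in row)) - row[t]  # -log p(target)
        for row, t in zip(logits, targets)
    ) / len(targets)

def chunked_nll(hidden, weight, targets, chunk_size=2):
    """Fused path: project each chunk to logits, take log-probs, discard logits."""
    total = 0.0
    for start in range(0, len(hidden), chunk_size):
        for hv, t in zip(hidden[start:start + chunk_size],
                         targets[start:start + chunk_size]):
            # on-the-fly projection: logits for one token only
            row = [sum(h * w for h, w in zip(hv, wv)) for wv in weight]
            total += math.log(sum(math.exp(x) for x in row)) - row[t]
        # `row` goes out of scope here; at most one chunk of logits ever
        # exists at a time, instead of the full [seq, vocab] matrix
    return total / len(targets)
```

Both paths sum the same per-token terms in the same order, so the results agree to floating-point precision.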

**Benefits:**

- Extends the maximum trainable sequence length significantly (e.g. from <65K to >100K tokens) by eliminating the large logit tensor from GPU memory.
- Produces numerically equivalent loss values to the standard path.

**How to enable:**

Add the following to your Megatron config in your YAML file:

```yaml
policy:
megatron_cfg:
enabled: true
use_linear_ce_fusion_loss: true
linear_ce_fusion_chunk_size: 256 # tokens per chunk; smaller = less memory, larger = more throughput
```

**Notes:**

- This optimization only applies to SFT training with `NLLLossFn`. It does not affect other algorithms (GRPO, DPO, etc.).
- Context parallelism is not supported when linear CE fusion is enabled.
- The `linear_ce_fusion_chunk_size` parameter controls the trade-off between memory savings and compute throughput. The default value of 256 is a good starting point.
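
A back-of-envelope estimate shows why the logit tensor dominates memory. The numbers below are illustrative assumptions (a ~152K-entry vocabulary stored in bf16, 2 bytes per element), not measured figures:

```python
def logit_bytes(batch, seq, vocab, bytes_per_elem=2):
    """Memory held by a [batch, seq, vocab] logit tensor (bf16 by default)."""
    return batch * seq * vocab * bytes_per_elem

# standard path: full-sequence logits for a 64K-token sample
full = logit_bytes(batch=1, seq=65536, vocab=152064)
# fused path: only one 256-token chunk of logits is alive at a time
chunk = logit_bytes(batch=1, seq=256, vocab=152064)
print(f"full: {full / 2**30:.1f} GiB, per chunk: {chunk / 2**20:.1f} MiB")
```

Under these assumptions the full logit tensor costs tens of GiB while a single chunk costs tens of MiB, which is why chunking extends the trainable sequence length so dramatically.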
@@ -0,0 +1,57 @@
defaults: ../../sft.yaml
sft:
max_num_steps: 10
checkpointing:
enabled: false
policy:
model_name: Qwen/Qwen2.5-Math-7B
train_global_batch_size: 64
max_total_sequence_length: 3200
dtensor_cfg:
enabled: false
megatron_cfg:
enabled: true
use_linear_ce_fusion_loss: true
linear_ce_fusion_chunk_size: 128
tensor_model_parallel_size: 4
pipeline_model_parallel_size: 2
sequence_parallel: true
attention_backend: unfused
freeze_moe_router: true
moe_router_dtype: fp64
moe_router_bias_update_rate: 0.0
moe_permute_fusion: true
optimizer:
lr: 1.0e-06
min_lr: 1.0e-06
adam_beta2: 0.999
adam_eps: 1.0e-08
use_distributed_optimizer: false
use_precision_aware_optimizer: false
scheduler:
lr_warmup_iters: 10
lr_warmup_init: 1.0e-11
lr_decay_iters: 32
make_sequence_length_divisible_by: 8
data:
add_generation_prompt: true
num_workers: 8
train:
dataset_name: OpenMathInstruct-2
output_key: generated_solution
split: train_1M
split_validation_size: 0.05
seed: ${sft.seed}
validation: null
default:
prompt_file: examples/prompts/math.txt
logger:
wandb:
project: nemo-rl
name: sft-qwen2.5-math-7b-megatron-chunked-linear-ce-loss-1n8g
tensorboard:
log_dir: tb_logs-sft-qwen2.5-math-7b-megatron-chunked-linear-ce-loss-1n8g
mlflow:
run_name: sft-qwen2.5-math-7b-megatron-chunked-linear-ce-loss-1n8g
cluster:
gpus_per_node: 8
2 changes: 2 additions & 0 deletions examples/configs/sft.yaml
@@ -91,6 +91,8 @@ policy:
## ignored since enabled=false, but needed for testing purposes
megatron_cfg:
enabled: false
use_linear_ce_fusion_loss: false
linear_ce_fusion_chunk_size: 256
env_vars: {}
empty_unused_memory_level: 1
activation_checkpointing: false
3 changes: 3 additions & 0 deletions nemo_rl/algorithms/loss/loss_functions.py
@@ -576,6 +576,9 @@ class NLLLossFn(LossFunction):
loss_type = LossType.TOKEN_LEVEL
input_type = LossInputType.LOGPROB

def __init__(self, use_linear_ce_fusion: bool = False):
self.use_linear_ce_fusion = use_linear_ce_fusion

def __call__(
self,
next_token_logprobs: Tensor,
25 changes: 16 additions & 9 deletions nemo_rl/algorithms/loss/utils.py
@@ -62,15 +62,22 @@ def prepare_loss_input(
loss_input = {"logits": logits}

elif loss_fn.input_type == LossInputType.LOGPROB:
logprobs = get_next_token_logprobs_from_logits(
input_ids=data["input_ids"],
next_token_logits=logits,
seq_index=data.get("seq_index", None),
vocab_parallel_rank=vocab_parallel_rank,
vocab_parallel_group=vocab_parallel_group,
context_parallel_group=context_parallel_group,
sampling_params=sampling_params,
)
# Linear CE fusion patch returns precomputed next-token logprobs (2D tensor).
# Keep normal path unchanged for standard logits (3D tensor).
if getattr(loss_fn, "use_linear_ce_fusion", False):
logprobs = logits.to(torch.float32)
logprobs = logprobs[:, : data["input_ids"].shape[1] - 1]
else:
logprobs = get_next_token_logprobs_from_logits(
input_ids=data["input_ids"],
next_token_logits=logits,
seq_index=data.get("seq_index", None),
vocab_parallel_rank=vocab_parallel_rank,
vocab_parallel_group=vocab_parallel_group,
context_parallel_group=context_parallel_group,
sampling_params=sampling_params,
)

# handle top-k/top-p filtering for logprobs, only used for ClippedPGLossFn now
if need_top_k_or_top_p_filtering(sampling_params):
29 changes: 21 additions & 8 deletions nemo_rl/algorithms/loss/wrapper.py
@@ -95,20 +95,33 @@ def __call__(
else:
unpadded_seq_data[k] = v

# get next_token_logits
cp_size = (
1
if self.context_parallel_group is None
else torch.distributed.get_world_size(self.context_parallel_group)
)
logit_start = seq_start // cp_size
logit_end = (seq_start + padded_seq_lengths[seq_idx]) // cp_size
logit_length = logit_end - logit_start
next_token_logits_slice = next_token_logits.narrow(
1, logit_start, logit_length
)

# prepare data for loss function
if getattr(self.loss_fn, "use_linear_ce_fusion", False):
# Linear CE fusion returns precomputed token logprobs where shape
# can be shorter by 1 token than padded sequence metadata.
# Use slicing (clamped end) to avoid narrow() OOB on packed tails.
logit_start = seq_start // cp_size
logit_end = min(
(seq_start + padded_seq_lengths[seq_idx]) // cp_size,
next_token_logits.shape[1],
)
logit_slice_idxs = slice(logit_start, logit_end)
next_token_logits_slice = next_token_logits[:, logit_slice_idxs]
else:
logit_start = seq_start // cp_size
logit_end = (seq_start + padded_seq_lengths[seq_idx]) // cp_size
logit_length = logit_end - logit_start
next_token_logits_slice = next_token_logits.narrow(
1, logit_start, logit_length
)
loss_input, unpadded_seq_data = self.prepare_fn(
logits=next_token_logits_slice,
data=unpadded_seq_data,
7 changes: 5 additions & 2 deletions nemo_rl/algorithms/sft.py
@@ -21,7 +21,7 @@
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from nemo_rl.algorithms.loss import NLLLossFn
from nemo_rl.algorithms.loss.loss_functions import NLLLossFn
from nemo_rl.algorithms.utils import maybe_pad_last_batch, set_seed
from nemo_rl.data import DataConfig
from nemo_rl.data.collate_fn import rl_collate_fn
@@ -208,7 +208,10 @@ def setup(
# print the node IP and GPU ID of the policy workers for debugging
policy.print_node_ip_and_gpu_id()

loss_fn = NLLLossFn()
loss_fn = NLLLossFn(
use_linear_ce_fusion=policy_config["megatron_cfg"]["enabled"]
and policy_config["megatron_cfg"]["use_linear_ce_fusion_loss"]
)
print(" ✓ Model initialized")

print("\n" + "=" * 60)