30 changes: 30 additions & 0 deletions docs/guides/dpo.md
@@ -195,6 +195,36 @@ The DPO implementation in NeMo RL supports several key parameters that can be adjusted

These parameters can be adjusted in the config file or via command-line overrides to optimize training for your specific use case.

## Optimizations

### Chunked Linear Cross-Entropy Fusion Loss

During standard DPO training, the model materializes a full logit tensor of shape `[batch_size, seq_length, vocab_size]` for both the policy forward-backward pass and the reference-model logprob computation. This can cause out-of-memory (OOM) errors for long sequences or large vocabularies. The **chunked linear cross-entropy fusion loss** avoids this by computing log probabilities directly from the hidden states: it chunks the sequence dimension, projects each chunk to logits on the fly, gathers per-token log probabilities, and discards the logits before moving to the next chunk.
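
The chunking idea can be sketched in a few lines of PyTorch. This is a minimal illustration of the technique, not the NeMo RL implementation; the function name, arguments, and lack of tensor/sequence parallelism are all simplifying assumptions:

```python
import torch

def chunked_token_logprobs(hidden, lm_head_weight, target_ids, chunk_size=256):
    """Per-token log-probabilities without a full [batch, seq, vocab] logit tensor.

    hidden:         [batch, seq, hidden_dim] final hidden states
    lm_head_weight: [vocab, hidden_dim] output-projection weight
    target_ids:     [batch, seq] target token ids
    """
    out = []
    for start in range(0, hidden.shape[1], chunk_size):
        h = hidden[:, start : start + chunk_size]  # [b, chunk, d]
        # Only one chunk's logits are ever alive: [b, chunk, vocab]
        logits = h @ lm_head_weight.T
        logprobs = torch.log_softmax(logits.float(), dim=-1)
        tgt = target_ids[:, start : start + chunk_size]
        out.append(logprobs.gather(-1, tgt.unsqueeze(-1)).squeeze(-1))
        # logits/logprobs for this chunk go out of scope before the next iteration
    return torch.cat(out, dim=1)  # [b, seq]
```

Peak memory for the projection drops from `O(seq_length * vocab_size)` to `O(chunk_size * vocab_size)` per sample, at the cost of looping over the sequence.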

**Benefits:**

- Extends the maximum trainable sequence length significantly by eliminating the large logit tensor from GPU memory.
- Applies to both the training forward-backward pass and the reference model logprob computation.
- Produces loss values numerically equivalent to those of the standard path.

**How to enable:**

Add the following to your Megatron config in your YAML file:

```yaml
policy:
megatron_cfg:
enabled: true
use_linear_ce_fusion_loss: true
linear_ce_fusion_chunk_size: 256 # tokens per chunk; smaller = less memory, larger = more throughput
```

**Notes:**

- Context parallelism is not supported when linear CE fusion is enabled.
- Sequence packing is not supported with DPO regardless of this setting (see [#719](https://github.com/NVIDIA-NeMo/RL/issues/719)).
- The `linear_ce_fusion_chunk_size` parameter controls the trade-off between memory savings and compute throughput. The default value of 256 is a good starting point.

## Evaluate the Trained Model

Upon completion of the training process, you can refer to our [evaluation guide](eval.md) to assess model capabilities.
2 changes: 1 addition & 1 deletion docs/guides/sft.md
@@ -347,6 +347,6 @@ policy:

**Notes:**

- This optimization only applies to SFT training with `NLLLoss`. It does not affect other algorithms (GRPO, DPO, etc.).
- This optimization applies to SFT training with `NLLLoss` and DPO training. See the [DPO guide](dpo.md#chunked-linear-cross-entropy-fusion-loss) for DPO-specific details.
- Context parallelism is not supported when linear CE fusion is enabled.
- The `linear_ce_fusion_chunk_size` parameter controls the trade-off between memory savings and compute throughput. The default value of 256 is a good starting point.
2 changes: 2 additions & 0 deletions examples/configs/dpo.yaml
@@ -110,6 +110,8 @@ policy:
## ignored since enabled=false, but needed for testing purposes
megatron_cfg:
enabled: false
use_linear_ce_fusion_loss: false
linear_ce_fusion_chunk_size: 256
empty_unused_memory_level: 1
activation_checkpointing: false
tensor_model_parallel_size: 2
@@ -0,0 +1,47 @@
defaults: ../../dpo.yaml
dpo:
max_num_steps: 10
checkpointing:
enabled: false
policy:
model_name: Qwen/Qwen2.5-Math-7B
tokenizer:
name: ${policy.model_name}
train_global_batch_size: 32
train_micro_batch_size: 1
max_total_sequence_length: 6000
dtensor_cfg:
enabled: false
megatron_cfg:
enabled: true
use_linear_ce_fusion_loss: true
linear_ce_fusion_chunk_size: 128
tensor_model_parallel_size: 4
pipeline_model_parallel_size: 2
attention_backend: unfused
freeze_moe_router: true
moe_router_bias_update_rate: 0.0
moe_permute_fusion: true
optimizer:
lr: 1.0e-06
min_lr: 1.0e-06
adam_beta2: 0.999
use_distributed_optimizer: false
use_precision_aware_optimizer: false
scheduler:
lr_warmup_iters: 10
lr_warmup_init: 1.0e-11
lr_decay_iters: 32
make_sequence_length_divisible_by: 8
data:
num_workers: 8
logger:
wandb:
project: nemo-rl
name: dpo-qwen2.5-math-7b-megatron-chunked-linear-ce-loss-1n8g
tensorboard:
log_dir: tb_logs-dpo-qwen2.5-math-7b-megatron-chunked-linear-ce-loss-1n8g
mlflow:
run_name: dpo-qwen2.5-math-7b-megatron-chunked-linear-ce-loss-1n8g
cluster:
gpus_per_node: 8
6 changes: 5 additions & 1 deletion nemo_rl/algorithms/dpo.py
@@ -252,7 +252,11 @@ def setup(
# print the node IP and GPU ID of the policy workers for debugging
policy.print_node_ip_and_gpu_id()

loss_fn = DPOLossFn(master_config["dpo"])
loss_fn = DPOLossFn(
master_config["dpo"],
use_linear_ce_fusion=policy_config["megatron_cfg"]["enabled"]
and policy_config["megatron_cfg"]["use_linear_ce_fusion_loss"],
)
print(" ✓ Model initialized")

print("\n" + "=" * 60)
5 changes: 3 additions & 2 deletions nemo_rl/algorithms/loss/loss_functions.py
@@ -797,13 +797,14 @@ class DPOLossFn(PreferenceLossFn):
loss_type = LossType.SEQUENCE_LEVEL
input_type = LossInputType.LOGPROB

def __init__(self, cfg: DPOLossConfig):
def __init__(self, cfg: DPOLossConfig, use_linear_ce_fusion: bool = False):
self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"]
self.preference_loss_weight = cfg["preference_loss_weight"]
self.sft_loss_weight = cfg["sft_loss_weight"]
self.preference_average_log_probs = cfg["preference_average_log_probs"]
self.sft_average_log_probs = cfg["sft_average_log_probs"]
self.sft_loss = NLLLossFn()
self.use_linear_ce_fusion = use_linear_ce_fusion
self.sft_loss = NLLLossFn(use_linear_ce_fusion=use_linear_ce_fusion)

def _dpo_loss(
self,
16 changes: 12 additions & 4 deletions nemo_rl/models/megatron/train.py
@@ -400,9 +400,11 @@ def __init__(
self,
cfg: PolicyConfig,
sampling_params: Optional[TrainingSamplingParams] = None,
use_linear_ce_fusion: bool = False,
):
self.cfg = cfg
self.sampling_params = sampling_params
self.use_linear_ce_fusion = use_linear_ce_fusion

def __call__(
self,
@@ -427,10 +429,13 @@ def __call__(
original_seq_length = unpacked_input_ids.shape[1]

def processor_fn_inner(output_tensor):
tp_grp = get_tensor_model_parallel_group()
tp_rank = get_tensor_model_parallel_rank()
logprob_chunk_size = self.cfg.get("logprob_chunk_size", None)
if self.cfg["sequence_packing"]["enabled"]:
if self.use_linear_ce_fusion:
token_logprobs = output_tensor.to(torch.float32)
token_logprobs = token_logprobs[:, : original_seq_length - 1]
elif self.cfg["sequence_packing"]["enabled"]:
tp_grp = get_tensor_model_parallel_group()
tp_rank = get_tensor_model_parallel_rank()
logprob_chunk_size = self.cfg.get("logprob_chunk_size", None)
token_logprobs = from_parallel_logits_to_logprobs_packed_sequences(
output_tensor,
target=input_ids,
@@ -445,6 +450,9 @@ def processor_fn_inner(output_tensor):
sampling_params=self.sampling_params,
)
else:
tp_grp = get_tensor_model_parallel_group()
tp_rank = get_tensor_model_parallel_rank()
logprob_chunk_size = self.cfg.get("logprob_chunk_size", None)
token_logprobs = from_parallel_logits_to_logprobs(
output_tensor,
target=unpacked_input_ids,
5 changes: 5 additions & 0 deletions nemo_rl/models/policy/workers/megatron_policy_worker.py
@@ -490,9 +490,13 @@ def get_logprobs(
straggler_timer=self.mcore_state.straggler_timer,
)

use_linear_ce_fusion = self.cfg["megatron_cfg"].get(
"use_linear_ce_fusion_loss", False
)
logprobs_post_processor = LogprobsPostProcessor(
cfg=self.cfg,
sampling_params=self.sampling_params,
use_linear_ce_fusion=use_linear_ce_fusion,
)

list_of_logprobs = megatron_forward_backward(
@@ -506,6 +510,7 @@ def get_logprobs(
defer_fp32_logits=self.defer_fp32_logits,
sampling_params=self.sampling_params,
straggler_timer=self.mcore_state.straggler_timer,
use_linear_ce_fusion_loss=use_linear_ce_fusion,
)

if is_pipeline_last_stage(ignore_virtual=True):
@@ -0,0 +1,43 @@
#!/bin/bash
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
source $SCRIPT_DIR/common.env

# ===== BEGIN CONFIG =====
NUM_NODES=1
STEPS_PER_RUN=10
MAX_STEPS=10
NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up
NUM_MINUTES=25
# ===== END CONFIG =====

exit_if_max_steps_reached

# Run the experiment
cd $PROJECT_ROOT
uv run examples/run_dpo.py \
--config $CONFIG_PATH \
dpo.max_num_steps=$MAX_STEPS \
logger.log_dir=$LOG_DIR \
logger.wandb_enabled=True \
logger.wandb.project=nemo-rl \
logger.wandb.name=$EXP_NAME \
logger.monitor_gpus=True \
logger.tensorboard_enabled=True \
checkpointing.enabled=true \
checkpointing.checkpoint_dir=$CKPT_DIR \
$@ \
2>&1 | tee $RUN_LOG

# Convert tensorboard logs to json
uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS

# Only run metrics if the target step is reached
if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
# Smoke checks: run completed and loss is finite/reasonable.
uv run tests/check_metrics.py $JSON_METRICS \
'data["train/loss"]["10"] > 0.0' \
'data["train/loss"]["10"] < 20.0'

# Clean up checkpoint directory after successful run to save space.
rm -rf "$CKPT_DIR"
fi
1 change: 1 addition & 0 deletions tests/test_suites/nightly.txt
@@ -108,6 +108,7 @@ tests/test_suites/llm/sft-llama3.1-8b-1n8g-megatron-seqpack.sh
tests/test_suites/llm/sft-qwen2.5-math7b-2n8g-megatron.sh
# chunked linear CE loss
tests/test_suites/llm/sft-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.sh
tests/test_suites/llm/dpo-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.sh

# Nemotron Super 49B SFT tests
# Issue with details: https://github.com/NVIDIA-NeMo/RL/issues/1571
102 changes: 102 additions & 0 deletions tests/unit/models/policy/test_megatron_worker.py
@@ -1971,6 +1971,108 @@ def test_megatron_sft_linear_ce_fusion_agreement(tiny_qwen2_model_path):
torch.testing.assert_close(loss_std, loss_fuse, rtol=1e-2, atol=1e-2)


@pytest.mark.timeout(600)
def test_megatron_dpo_linear_ce_fusion_agreement(tiny_qwen2_model_path):
"""Test that linear CE fusion loss produces the same results as the standard path for DPO."""
import time

num_gpus = 2
batch_size = 4
seq_len = 64
vocab_size = 151936

torch.manual_seed(42)
input_ids = torch.randint(0, vocab_size, (batch_size * 2, seq_len))
attention_mask = torch.ones(batch_size * 2, seq_len)
input_lengths = attention_mask.sum(dim=1).to(torch.int32)
token_mask = torch.triu(torch.ones(batch_size * 2, seq_len), diagonal=1)
sample_mask = torch.ones(batch_size * 2)
reference_policy_logprobs = torch.randn(batch_size * 2, seq_len)

data = BatchedDataDict(
{
"input_ids": input_ids,
"input_lengths": input_lengths,
"attention_mask": attention_mask,
"token_mask": token_mask,
"sample_mask": sample_mask,
"reference_policy_logprobs": reference_policy_logprobs,
}
)

dpo_cfg = {
"reference_policy_kl_penalty": 0.1,
"preference_loss_weight": 1.0,
"sft_loss_weight": 0.5,
"preference_average_log_probs": False,
"sft_average_log_probs": False,
}

# --- Standard DPO (no linear CE fusion) ---
cluster_std = RayVirtualCluster(
name="test-dpo-std",
bundle_ct_per_node_list=[num_gpus],
use_gpus=True,
num_gpus_per_node=num_gpus,
max_colocated_worker_groups=1,
)
config_std = create_megatron_test_config(tiny_qwen2_model_path)
tokenizer = get_tokenizer(config_std["tokenizer"])
policy_std = Policy(
cluster=cluster_std,
config=config_std,
tokenizer=tokenizer,
init_reference_model=False,
)
dpo_loss_std = DPOLossFn(dpo_cfg)

try:
policy_std.prepare_for_training()
results_std = policy_std.train(data, dpo_loss_std)
loss_std = results_std["loss"]
finally:
policy_std.shutdown()
cluster_std.shutdown()

time.sleep(10)

# --- DPO with linear CE fusion ---
cluster_fuse = RayVirtualCluster(
name="test-dpo-fuse",
bundle_ct_per_node_list=[num_gpus],
use_gpus=True,
num_gpus_per_node=num_gpus,
max_colocated_worker_groups=1,
)
config_fuse = create_megatron_test_config(tiny_qwen2_model_path)
config_fuse["megatron_cfg"]["use_linear_ce_fusion_loss"] = True
config_fuse["megatron_cfg"]["linear_ce_fusion_chunk_size"] = 256
policy_fuse = Policy(
cluster=cluster_fuse,
config=config_fuse,
tokenizer=tokenizer,
init_reference_model=False,
)
dpo_loss_fuse = DPOLossFn(dpo_cfg, use_linear_ce_fusion=True)

try:
policy_fuse.prepare_for_training()
results_fuse = policy_fuse.train(data, dpo_loss_fuse)
loss_fuse = results_fuse["loss"]
finally:
policy_fuse.shutdown()
cluster_fuse.shutdown()

# Verify both produce valid losses
assert not torch.isnan(loss_std).any(), "Standard DPO loss should not be NaN"
assert not torch.isnan(loss_fuse).any(), "Fusion DPO loss should not be NaN"
assert not torch.isinf(loss_std).any(), "Standard DPO loss should not be Inf"
assert not torch.isinf(loss_fuse).any(), "Fusion DPO loss should not be Inf"

# Verify losses are numerically close
torch.testing.assert_close(loss_std, loss_fuse, rtol=1e-2, atol=1e-2)


@pytest.mark.hf_gated
@pytest.mark.timeout(300)
def test_megatron_context_parallel_logprob_agreement(tiny_llama_model_path):