diff --git a/docs/guides/dpo.md b/docs/guides/dpo.md index f5451dc760..61a80dc2ab 100644 --- a/docs/guides/dpo.md +++ b/docs/guides/dpo.md @@ -195,6 +195,36 @@ The DPO implementation in NeMo RL supports several key parameters that can be ad These parameters can be adjusted in the config file or via command-line overrides to optimize training for your specific use case. +## Optimizations + +### Chunked Linear Cross-Entropy Fusion Loss + +During standard DPO training the model materializes a full logit tensor of shape `[batch_size, seq_length, vocab_size]` for both the policy forward-backward pass and the reference model logprob computation. This can cause out-of-memory (OOM) errors for long sequences or large vocabularies. The **chunked linear cross-entropy fusion loss** avoids this by computing log probabilities directly from the hidden states: it chunks the sequence dimension, projects each chunk to logits on the fly, gathers per-token log probabilities, and discards the logits before moving to the next chunk. + +**Benefits:** + +- Extends the maximum trainable sequence length significantly by eliminating the large logit tensor from GPU memory. +- Applies to both the training forward-backward pass and the reference model logprob computation. +- Produces numerically equivalent loss values to the standard path. + +**How to enable:** + +Add the following to your Megatron config in your YAML file: + +```yaml +policy: + megatron_cfg: + enabled: true + use_linear_ce_fusion_loss: true + linear_ce_fusion_chunk_size: 256 # tokens per chunk; smaller = less memory, larger = more throughput +``` + +**Notes:** + +- Context parallelism is not supported when linear CE fusion is enabled. +- Sequence packing is not supported with DPO regardless of this setting (see [#719](https://github.com/NVIDIA-NeMo/RL/issues/719)). +- The `linear_ce_fusion_chunk_size` parameter controls the trade-off between memory savings and compute throughput. The default value of 256 is a good starting point. + ## Evaluate the Trained Model Upon completion of the training process, you can refer to our [evaluation guide](eval.md) to assess model capabilities. diff --git a/docs/guides/sft.md b/docs/guides/sft.md index f661c1f146..1ce19416bd 100644 --- a/docs/guides/sft.md +++ b/docs/guides/sft.md @@ -347,6 +347,6 @@ policy: **Notes:** -- This optimization only applies to SFT training with `NLLLoss`. It does not affect other algorithms (GRPO, DPO, etc.). +- This optimization applies to SFT training with `NLLLoss` and DPO training. See the [DPO guide](dpo.md#chunked-linear-cross-entropy-fusion-loss) for DPO-specific details. - Context parallelism is not supported when linear CE fusion is enabled. - The `linear_ce_fusion_chunk_size` parameter controls the trade-off between memory savings and compute throughput. The default value of 256 is a good starting point. \ No newline at end of file diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml index e48293070a..82f79d0595 100755 --- a/examples/configs/dpo.yaml +++ b/examples/configs/dpo.yaml @@ -110,6 +110,8 @@ policy: ## ignored since enabled=false, but needed for testing purposes megatron_cfg: enabled: false + use_linear_ce_fusion_loss: false + linear_ce_fusion_chunk_size: 256 empty_unused_memory_level: 1 activation_checkpointing: false tensor_model_parallel_size: 2 diff --git a/examples/configs/recipes/llm/dpo-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.yaml b/examples/configs/recipes/llm/dpo-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.yaml new file mode 100644 index 0000000000..071f90d077 --- /dev/null +++ b/examples/configs/recipes/llm/dpo-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.yaml @@ -0,0 +1,47 @@ +defaults: ../../dpo.yaml +dpo: + max_num_steps: 10 +checkpointing: + enabled: false +policy: + model_name: Qwen/Qwen2.5-Math-7B + tokenizer: + name: ${policy.model_name} + train_global_batch_size: 32 + train_micro_batch_size: 1 + max_total_sequence_length: 6000 + dtensor_cfg: + enabled: false + megatron_cfg: + enabled: true + use_linear_ce_fusion_loss: true + linear_ce_fusion_chunk_size: 128 + tensor_model_parallel_size: 4 + pipeline_model_parallel_size: 2 + attention_backend: unfused + freeze_moe_router: true + moe_router_bias_update_rate: 0.0 + moe_permute_fusion: true + optimizer: + lr: 1.0e-06 + min_lr: 1.0e-06 + adam_beta2: 0.999 + use_distributed_optimizer: false + use_precision_aware_optimizer: false + scheduler: + lr_warmup_iters: 10 + lr_warmup_init: 1.0e-11 + lr_decay_iters: 32 + make_sequence_length_divisible_by: 8 +data: + num_workers: 8 +logger: + wandb: + project: nemo-rl + name: dpo-qwen2.5-math-7b-megatron-chunked-linear-ce-loss-1n8g + tensorboard: + log_dir: tb_logs-dpo-qwen2.5-math-7b-megatron-chunked-linear-ce-loss-1n8g + mlflow: + run_name: dpo-qwen2.5-math-7b-megatron-chunked-linear-ce-loss-1n8g +cluster: + gpus_per_node: 8 diff --git a/nemo_rl/algorithms/dpo.py b/nemo_rl/algorithms/dpo.py index cfb4909bba..27c638933f 100644 --- a/nemo_rl/algorithms/dpo.py +++ b/nemo_rl/algorithms/dpo.py @@ -252,7 +252,11 @@ def setup( # print the node IP and GPU ID of the policy workers for debugging policy.print_node_ip_and_gpu_id() - loss_fn = DPOLossFn(master_config["dpo"]) + loss_fn = DPOLossFn( + master_config["dpo"], + use_linear_ce_fusion=policy_config["megatron_cfg"]["enabled"] + and policy_config["megatron_cfg"]["use_linear_ce_fusion_loss"], + ) print(" ✓ Model initialized") print("\n" + "=" * 60) diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index ab05586d7c..49860b541c 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -797,13 +797,14 @@ class DPOLossFn(PreferenceLossFn): loss_type = LossType.SEQUENCE_LEVEL input_type = LossInputType.LOGPROB - def __init__(self, cfg: DPOLossConfig): + def __init__(self, cfg: DPOLossConfig, use_linear_ce_fusion: bool = False): self.reference_policy_kl_penalty = cfg["reference_policy_kl_penalty"] self.preference_loss_weight = cfg["preference_loss_weight"] self.sft_loss_weight = cfg["sft_loss_weight"] self.preference_average_log_probs = cfg["preference_average_log_probs"] self.sft_average_log_probs = cfg["sft_average_log_probs"] - self.sft_loss = NLLLossFn() + self.use_linear_ce_fusion = use_linear_ce_fusion + self.sft_loss = NLLLossFn(use_linear_ce_fusion=use_linear_ce_fusion) def _dpo_loss( self, diff --git a/nemo_rl/models/megatron/train.py b/nemo_rl/models/megatron/train.py index 248c85f3ff..c9b4d81aca 100644 --- a/nemo_rl/models/megatron/train.py +++ b/nemo_rl/models/megatron/train.py @@ -400,9 +400,11 @@ def __init__( self, cfg: PolicyConfig, sampling_params: Optional[TrainingSamplingParams] = None, + use_linear_ce_fusion: bool = False, ): self.cfg = cfg self.sampling_params = sampling_params + self.use_linear_ce_fusion = use_linear_ce_fusion def __call__( self, @@ -427,10 +429,13 @@ def __call__( original_seq_length = unpacked_input_ids.shape[1] def processor_fn_inner(output_tensor): - tp_grp = get_tensor_model_parallel_group() - tp_rank = get_tensor_model_parallel_rank() - logprob_chunk_size = self.cfg.get("logprob_chunk_size", None) - if self.cfg["sequence_packing"]["enabled"]: + if self.use_linear_ce_fusion: + token_logprobs = output_tensor.to(torch.float32) + token_logprobs = token_logprobs[:, : original_seq_length - 1] + elif self.cfg["sequence_packing"]["enabled"]: + tp_grp = get_tensor_model_parallel_group() + tp_rank = get_tensor_model_parallel_rank() + logprob_chunk_size = self.cfg.get("logprob_chunk_size", None) token_logprobs = from_parallel_logits_to_logprobs_packed_sequences( output_tensor, target=input_ids, @@ -445,6 +450,9 @@ def processor_fn_inner(output_tensor): sampling_params=self.sampling_params, ) else: + tp_grp = get_tensor_model_parallel_group() + tp_rank = get_tensor_model_parallel_rank() + logprob_chunk_size = self.cfg.get("logprob_chunk_size", None) token_logprobs = from_parallel_logits_to_logprobs( output_tensor, target=unpacked_input_ids, diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index fdb141fcf8..0da8470d68 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -490,9 +490,13 @@ def get_logprobs( straggler_timer=self.mcore_state.straggler_timer, ) + use_linear_ce_fusion = self.cfg["megatron_cfg"].get( + "use_linear_ce_fusion_loss", False + ) logprobs_post_processor = LogprobsPostProcessor( cfg=self.cfg, sampling_params=self.sampling_params, + use_linear_ce_fusion=use_linear_ce_fusion, ) list_of_logprobs = megatron_forward_backward( @@ -506,6 +510,7 @@ def get_logprobs( defer_fp32_logits=self.defer_fp32_logits, sampling_params=self.sampling_params, straggler_timer=self.mcore_state.straggler_timer, + use_linear_ce_fusion_loss=use_linear_ce_fusion, ) if is_pipeline_last_stage(ignore_virtual=True): diff --git a/tests/test_suites/llm/dpo-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.sh b/tests/test_suites/llm/dpo-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.sh new file mode 100755 index 0000000000..dbab6cb2f7 --- /dev/null +++ b/tests/test_suites/llm/dpo-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.sh @@ -0,0 +1,43 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=10 +MAX_STEPS=10 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=25 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_dpo.py \ + --config $CONFIG_PATH \ + dpo.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=true \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + # Smoke checks: run completed and loss is finite/reasonable. + uv run tests/check_metrics.py $JSON_METRICS \ + 'data["train/loss"]["10"] > 0.0' \ + 'data["train/loss"]["10"] < 20.0' + + # Clean up checkpoint directory after successful run to save space. + rm -rf "$CKPT_DIR" +fi diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index 61e474f2c3..74f149716c 100644 --- a/tests/test_suites/nightly.txt +++ b/tests/test_suites/nightly.txt @@ -108,6 +108,7 @@ tests/test_suites/llm/sft-llama3.1-8b-1n8g-megatron-seqpack.sh tests/test_suites/llm/sft-qwen2.5-math7b-2n8g-megatron.sh # chunked linear CE loss tests/test_suites/llm/sft-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.sh +tests/test_suites/llm/dpo-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.sh # Nemotron Super 49B SFT tests # Issue with details: https://github.com/NVIDIA-NeMo/RL/issues/1571 diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py index eff97b215c..595d22a4cd 100644 --- a/tests/unit/models/policy/test_megatron_worker.py +++ b/tests/unit/models/policy/test_megatron_worker.py @@ -1971,6 +1971,108 @@ def test_megatron_sft_linear_ce_fusion_agreement(tiny_qwen2_model_path): torch.testing.assert_close(loss_std, loss_fuse, rtol=1e-2, atol=1e-2) +@pytest.mark.timeout(600) +def test_megatron_dpo_linear_ce_fusion_agreement(tiny_qwen2_model_path): + """Test that linear CE fusion loss produces the same results as the standard path for DPO.""" + import time + + num_gpus = 2 + batch_size = 4 + seq_len = 64 + vocab_size = 151936 + + torch.manual_seed(42) + input_ids = torch.randint(0, vocab_size, (batch_size * 2, seq_len)) + attention_mask = torch.ones(batch_size * 2, seq_len) + input_lengths = attention_mask.sum(dim=1).to(torch.int32) + token_mask = torch.triu(torch.ones(batch_size * 2, seq_len), diagonal=1) + sample_mask = torch.ones(batch_size * 2) + reference_policy_logprobs = torch.randn(batch_size * 2, seq_len) + + data = BatchedDataDict( + { + "input_ids": input_ids, + "input_lengths": input_lengths, + "attention_mask": attention_mask, + "token_mask": token_mask, + "sample_mask": sample_mask, + "reference_policy_logprobs": reference_policy_logprobs, + } + ) + + dpo_cfg = { + "reference_policy_kl_penalty": 0.1, + "preference_loss_weight": 1.0, + "sft_loss_weight": 0.5, + "preference_average_log_probs": False, + "sft_average_log_probs": False, + } + + # --- Standard DPO (no linear CE fusion) --- + cluster_std = RayVirtualCluster( + name="test-dpo-std", + bundle_ct_per_node_list=[num_gpus], + use_gpus=True, + num_gpus_per_node=num_gpus, + max_colocated_worker_groups=1, + ) + config_std = create_megatron_test_config(tiny_qwen2_model_path) + tokenizer = get_tokenizer(config_std["tokenizer"]) + policy_std = Policy( + cluster=cluster_std, + config=config_std, + tokenizer=tokenizer, + init_reference_model=False, + ) + dpo_loss_std = DPOLossFn(dpo_cfg) + + try: + policy_std.prepare_for_training() + results_std = policy_std.train(data, dpo_loss_std) + loss_std = results_std["loss"] + finally: + policy_std.shutdown() + cluster_std.shutdown() + + time.sleep(10) + + # --- DPO with linear CE fusion --- + cluster_fuse = RayVirtualCluster( + name="test-dpo-fuse", + bundle_ct_per_node_list=[num_gpus], + use_gpus=True, + num_gpus_per_node=num_gpus, + max_colocated_worker_groups=1, + ) + config_fuse = create_megatron_test_config(tiny_qwen2_model_path) + config_fuse["megatron_cfg"]["use_linear_ce_fusion_loss"] = True + config_fuse["megatron_cfg"]["linear_ce_fusion_chunk_size"] = 256 + policy_fuse = Policy( + cluster=cluster_fuse, + config=config_fuse, + tokenizer=tokenizer, + init_reference_model=False, + ) + dpo_loss_fuse = DPOLossFn(dpo_cfg, use_linear_ce_fusion=True) + + try: + policy_fuse.prepare_for_training() + results_fuse = policy_fuse.train(data, dpo_loss_fuse) + loss_fuse = results_fuse["loss"] + finally: + policy_fuse.shutdown() + cluster_fuse.shutdown() + + # Verify both produce valid losses + assert not torch.isnan(loss_std).any(), "Standard DPO loss should not be NaN" + assert not torch.isnan(loss_fuse).any(), "Fusion DPO loss should not be NaN" + assert not torch.isinf(loss_std).any(), "Standard DPO loss should not be Inf" + assert not torch.isinf(loss_fuse).any(), "Fusion DPO loss should not be Inf" + + # Verify losses are numerically close + torch.testing.assert_close(loss_std, loss_fuse, rtol=1e-2, atol=1e-2) + + @pytest.mark.hf_gated @pytest.mark.timeout(300) def test_megatron_context_parallel_logprob_agreement(tiny_llama_model_path):