diff --git a/docs/guides/sft.md b/docs/guides/sft.md index 368ccd216e..f661c1f146 100644 --- a/docs/guides/sft.md +++ b/docs/guides/sft.md @@ -320,4 +320,33 @@ uv run examples/run_sft.py \ policy.megatron_cfg.peft.enabled=true ``` -For more details on LoRA, see [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685). \ No newline at end of file +For more details on LoRA, see [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685). + +## Optimizations + +### Chunked Linear Cross-Entropy Fusion Loss + +During standard SFT training the model materializes a full logit tensor of shape `[batch_size, seq_length, vocab_size]`, which can cause out-of-memory (OOM) errors for long sequences or large vocabularies. The **chunked linear cross-entropy fusion loss** avoids this by computing the loss directly from the hidden states: it chunks the sequence dimension, projects each chunk to logits on the fly, computes per-token log probabilities, and discards the logits before moving to the next chunk. + +**Benefits:** + +- Extends the maximum trainable sequence length significantly (e.g. from <65K to >100K tokens) by eliminating the large logit tensor from GPU memory. +- Produces numerically equivalent loss values to the standard path. + +**How to enable:** + +Add the following to your Megatron config in your YAML file: + +```yaml +policy: + megatron_cfg: + enabled: true + use_linear_ce_fusion_loss: true + linear_ce_fusion_chunk_size: 256 # tokens per chunk; smaller = less memory, larger = more throughput +``` + +**Notes:** + +- This optimization only applies to SFT training with `NLLLoss`. It does not affect other algorithms (GRPO, DPO, etc.). +- Context parallelism is not supported when linear CE fusion is enabled. +- The `linear_ce_fusion_chunk_size` parameter controls the trade-off between memory savings and compute throughput. The default value of 256 is a good starting point. 
\ No newline at end of file diff --git a/examples/configs/recipes/llm/sft-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.yaml b/examples/configs/recipes/llm/sft-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.yaml new file mode 100644 index 0000000000..c2b5206810 --- /dev/null +++ b/examples/configs/recipes/llm/sft-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.yaml @@ -0,0 +1,57 @@ +defaults: ../../sft.yaml +sft: + max_num_steps: 10 +checkpointing: + enabled: false +policy: + model_name: Qwen/Qwen2.5-Math-7B + train_global_batch_size: 64 + max_total_sequence_length: 3200 + dtensor_cfg: + enabled: false + megatron_cfg: + enabled: true + use_linear_ce_fusion_loss: true + linear_ce_fusion_chunk_size: 128 + tensor_model_parallel_size: 4 + pipeline_model_parallel_size: 2 + sequence_parallel: true + attention_backend: unfused + freeze_moe_router: true + moe_router_dtype: fp64 + moe_router_bias_update_rate: 0.0 + moe_permute_fusion: true + optimizer: + lr: 1.0e-06 + min_lr: 1.0e-06 + adam_beta2: 0.999 + adam_eps: 1.0e-08 + use_distributed_optimizer: false + use_precision_aware_optimizer: false + scheduler: + lr_warmup_iters: 10 + lr_warmup_init: 1.0e-11 + lr_decay_iters: 32 + make_sequence_length_divisible_by: 8 +data: + add_generation_prompt: true + num_workers: 8 + train: + dataset_name: OpenMathInstruct-2 + output_key: generated_solution + split: train_1M + split_validation_size: 0.05 + seed: ${sft.seed} + validation: null + default: + prompt_file: examples/prompts/math.txt +logger: + wandb: + project: nemo-rl + name: sft-qwen2.5-math-7b-megatron-chunked-linear-ce-loss-1n8g + tensorboard: + log_dir: tb_logs-sft-qwen2.5-math-7b-megatron-chunked-linear-ce-loss-1n8g + mlflow: + run_name: sft-qwen2.5-math-7b-megatron-chunked-linear-ce-loss-1n8g +cluster: + gpus_per_node: 8 diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml index 821da4e530..34cb03c4b9 100644 --- a/examples/configs/sft.yaml +++ b/examples/configs/sft.yaml @@ -91,6 +91,8 @@ 
policy: ## ignored since enabled=false, but needed for testing purposes megatron_cfg: enabled: false + use_linear_ce_fusion_loss: false + linear_ce_fusion_chunk_size: 256 env_vars: {} empty_unused_memory_level: 1 activation_checkpointing: false diff --git a/nemo_rl/algorithms/loss/loss_functions.py b/nemo_rl/algorithms/loss/loss_functions.py index c72269eee1..ab05586d7c 100755 --- a/nemo_rl/algorithms/loss/loss_functions.py +++ b/nemo_rl/algorithms/loss/loss_functions.py @@ -576,6 +576,9 @@ class NLLLossFn(LossFunction): loss_type = LossType.TOKEN_LEVEL input_type = LossInputType.LOGPROB + def __init__(self, use_linear_ce_fusion: bool = False): + self.use_linear_ce_fusion = use_linear_ce_fusion + def __call__( self, next_token_logprobs: Tensor, diff --git a/nemo_rl/algorithms/loss/utils.py b/nemo_rl/algorithms/loss/utils.py index ad92522db0..15e19037df 100644 --- a/nemo_rl/algorithms/loss/utils.py +++ b/nemo_rl/algorithms/loss/utils.py @@ -62,15 +62,22 @@ def prepare_loss_input( loss_input = {"logits": logits} elif loss_fn.input_type == LossInputType.LOGPROB: - logprobs = get_next_token_logprobs_from_logits( - input_ids=data["input_ids"], - next_token_logits=logits, - seq_index=data.get("seq_index", None), - vocab_parallel_rank=vocab_parallel_rank, - vocab_parallel_group=vocab_parallel_group, - context_parallel_group=context_parallel_group, - sampling_params=sampling_params, - ) + # Linear CE fusion patch returns precomputed next-token logprobs (2D tensor). + # Keep normal path unchanged for standard logits (3D tensor). 
+ if hasattr(loss_fn, "use_linear_ce_fusion") and loss_fn.use_linear_ce_fusion: + logprobs = logits + logprobs = logprobs.to(torch.float32) + logprobs = logprobs[:, : data["input_ids"].shape[1] - 1] + else: + logprobs = get_next_token_logprobs_from_logits( + input_ids=data["input_ids"], + next_token_logits=logits, + seq_index=data.get("seq_index", None), + vocab_parallel_rank=vocab_parallel_rank, + vocab_parallel_group=vocab_parallel_group, + context_parallel_group=context_parallel_group, + sampling_params=sampling_params, + ) # handle top-k/top-p filtering for logprobs, only used for ClippedPGLossFn now if need_top_k_or_top_p_filtering(sampling_params): diff --git a/nemo_rl/algorithms/loss/wrapper.py b/nemo_rl/algorithms/loss/wrapper.py index a28bb18a19..e27095379a 100644 --- a/nemo_rl/algorithms/loss/wrapper.py +++ b/nemo_rl/algorithms/loss/wrapper.py @@ -95,20 +95,33 @@ def __call__( else: unpadded_seq_data[k] = v - # get next_token_logits cp_size = ( 1 if self.context_parallel_group is None else torch.distributed.get_world_size(self.context_parallel_group) ) - logit_start = seq_start // cp_size - logit_end = (seq_start + padded_seq_lengths[seq_idx]) // cp_size - logit_length = logit_end - logit_start - next_token_logits_slice = next_token_logits.narrow( - 1, logit_start, logit_length - ) - # prepare data for loss function + if ( + hasattr(self.loss_fn, "use_linear_ce_fusion") + and self.loss_fn.use_linear_ce_fusion + ): + # Linear CE fusion returns precomputed token logprobs where shape + # can be shorter by 1 token than padded sequence metadata. + # Use slicing (clamped end) to avoid narrow() OOB on packed tails. 
+ logit_start = seq_start // cp_size + logit_end = min( + (seq_start + padded_seq_lengths[seq_idx]) // cp_size, + next_token_logits.shape[1], + ) + logit_slice_idxs = slice(logit_start, logit_end) + next_token_logits_slice = next_token_logits[:, logit_slice_idxs] + else: + logit_start = seq_start // cp_size + logit_end = (seq_start + padded_seq_lengths[seq_idx]) // cp_size + logit_length = logit_end - logit_start + next_token_logits_slice = next_token_logits.narrow( + 1, logit_start, logit_length + ) loss_input, unpadded_seq_data = self.prepare_fn( logits=next_token_logits_slice, data=unpadded_seq_data, diff --git a/nemo_rl/algorithms/sft.py b/nemo_rl/algorithms/sft.py index a08c76022c..465d41a505 100644 --- a/nemo_rl/algorithms/sft.py +++ b/nemo_rl/algorithms/sft.py @@ -21,7 +21,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader from transformers import AutoTokenizer, PreTrainedTokenizerBase -from nemo_rl.algorithms.loss import NLLLossFn +from nemo_rl.algorithms.loss.loss_functions import NLLLossFn from nemo_rl.algorithms.utils import maybe_pad_last_batch, set_seed from nemo_rl.data import DataConfig from nemo_rl.data.collate_fn import rl_collate_fn @@ -208,7 +208,10 @@ def setup( # print the node IP and GPU ID of the policy workers for debugging policy.print_node_ip_and_gpu_id() - loss_fn = NLLLossFn() + loss_fn = NLLLossFn( + use_linear_ce_fusion=policy_config["megatron_cfg"]["enabled"] + and policy_config["megatron_cfg"]["use_linear_ce_fusion_loss"] + ) print(" ✓ Model initialized") print("\n" + "=" * 60) diff --git a/nemo_rl/distributed/model_utils.py b/nemo_rl/distributed/model_utils.py index 2575fd891d..3c98cdbca1 100644 --- a/nemo_rl/distributed/model_utils.py +++ b/nemo_rl/distributed/model_utils.py @@ -15,6 +15,12 @@ from typing import Any, Optional import torch +from megatron.core.models.gpt import GPTModel +from megatron.core.parallel_state import ( + get_tensor_model_parallel_group, + get_tensor_model_parallel_rank, +) +from 
megatron.core.utils import deprecate_inference_params, get_pg_size from torch.distributed.tensor import DTensor, distribute_tensor from nemo_rl.algorithms.logits_sampling_utils import ( @@ -1729,6 +1735,346 @@ def backward( return grad_input, None, None, None +def from_parallel_hidden_states_to_logprobs( + tensor_parallel_hidden_states: torch.Tensor, + output_weight_layer: torch.Tensor, + output_weight: torch.Tensor, + runtime_gather_output: bool, + target: torch.Tensor, + vocab_start_index: int, + vocab_end_index: int, + tp_group: torch.distributed.ProcessGroup, + inference_only: bool = False, + cp_group: Optional[torch.distributed.ProcessGroup] = None, + chunk_size: Optional[int] = None, +) -> torch.Tensor: + """Get log probabilities from TP sharded hidden states.""" + target = target.roll(shifts=-1, dims=-1) + assert cp_group is None or torch.distributed.get_world_size(cp_group) == 1, ( + "Context parallelism is not supported for linear CE fusion loss" + ) + logprobs: torch.Tensor = ChunkedDistributedHiddenStatesToLogprobs.apply( # type: ignore + tensor_parallel_hidden_states, + target, + output_weight_layer, + vocab_start_index, + vocab_end_index, + chunk_size, + tp_group, + inference_only, + ).contiguous() + + return logprobs[:, :-1] + + +class ChunkedDistributedHiddenStatesToLogprobs(torch.autograd.Function): + """Compute target-token logprobs from hidden states in sequence chunks, projecting each chunk to logits on the fly.""" + + @staticmethod + def forward( + ctx: Any, + tensor_parallel_hidden_states: torch.Tensor, + target: torch.Tensor, + output_weight_layer: torch.Tensor, + vocab_start_index: int, + vocab_end_index: int, + chunk_size: int, + tp_group: torch.distributed.ProcessGroup, + inference_only: bool = False, + ) -> torch.Tensor: + target_mask = (target < vocab_start_index) | (target >= vocab_end_index) + masked_target = target - vocab_start_index + masked_target[target_mask] = 0 + tp_group_size = torch.distributed.get_world_size(tp_group) + if tp_group_size > 1: +
original_tensor_parallel_hidden_states = ( + tensor_parallel_hidden_states.clone() + ) + all_hidden_states = [ + torch.zeros_like(tensor_parallel_hidden_states) + for _ in range(tp_group_size) + ] + torch.distributed.all_gather( + all_hidden_states, tensor_parallel_hidden_states, group=tp_group + ) + tensor_parallel_hidden_states = torch.cat(all_hidden_states, dim=0) + else: + original_tensor_parallel_hidden_states = tensor_parallel_hidden_states + seq_size = int(tensor_parallel_hidden_states.shape[0]) + num_chunks = (seq_size + chunk_size - 1) // chunk_size + all_log_probs = [] + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size + chunk_end = min(seq_size, (chunk_idx + 1) * chunk_size) + logits = torch.matmul( + tensor_parallel_hidden_states[chunk_start:chunk_end, :, :], + output_weight_layer.T, + ) + logits = logits.to(dtype=torch.float32).transpose(0, 1).contiguous() + log_probs = _compute_distributed_log_softmax( + logits, + group=tp_group, + ) + + log_probs = ( + torch.gather( + log_probs, -1, masked_target[:, chunk_start:chunk_end].unsqueeze(-1) + ) + .squeeze(-1) + .detach() + ) + log_probs[target_mask[:, chunk_start:chunk_end]] = 0.0 + + all_log_probs.append(log_probs) + + log_probs = torch.cat(all_log_probs, dim=1) + torch.distributed.all_reduce( + log_probs, + op=torch.distributed.ReduceOp.SUM, + group=tp_group, + ) + if not inference_only: + # save tensors for backward only when inference_only is False + # save the local (pre-all-gather) hidden states and the output layer weight to the context + ctx.save_for_backward( + original_tensor_parallel_hidden_states.detach(), + target_mask.detach(), + masked_target.detach(), + output_weight_layer.detach(), + ) + ctx.chunk_size = chunk_size + ctx.tp_group = tp_group + + return log_probs + + @staticmethod + def backward( + ctx: Any, *grad_outputs: torch.Tensor + ) -> tuple[torch.Tensor, None, torch.Tensor, None, None, None, None, None]: + grad_output = grad_outputs[0] + # forward saved the local (pre-all-gather) hidden states, so they
are all-gathered again below + ( + tensor_parallel_hidden_states, + target_mask, + masked_target, + output_weight_layer, + ) = ctx.saved_tensors + tp_group = ctx.tp_group + tp_group_size = torch.distributed.get_world_size(tp_group) + if tp_group_size > 1: + all_hidden_states = [ + torch.zeros_like(tensor_parallel_hidden_states) + for _ in range(tp_group_size) + ] + torch.distributed.all_gather( + all_hidden_states, tensor_parallel_hidden_states, group=tp_group + ) + tensor_parallel_hidden_states = torch.cat(all_hidden_states, dim=0) + chunk_size = ctx.chunk_size + tp_group = ctx.tp_group + # this is the vocab size for this partition when the output_layer is a ColumnParallelLinear + partition_vocab_size = output_weight_layer.size(0) + seq_size = int(tensor_parallel_hidden_states.shape[0]) + num_chunks = (seq_size + chunk_size - 1) // chunk_size + all_grad_input_hidden_states = [] + all_grad_input_output_layer = [] + grad_input_output_layer = torch.zeros_like(output_weight_layer) + for chunk_idx in range(num_chunks): + chunk_start = chunk_idx * chunk_size + chunk_end = min(seq_size, (chunk_idx + 1) * chunk_size) + # recalculate the logits using the output_layer + logits = torch.matmul( + tensor_parallel_hidden_states[chunk_start:chunk_end, :, :], + output_weight_layer.T, + ) + logits = logits.to(dtype=torch.float32).transpose(0, 1).contiguous() + softmax_output = _compute_distributed_log_softmax( + logits, + group=tp_group, + ) + softmax_output = softmax_output.exp().detach() + is_chosen = (~(target_mask[:, chunk_start:chunk_end])).unsqueeze( + -1 + ) * torch.nn.functional.one_hot( + masked_target[:, chunk_start:chunk_end], + num_classes=partition_vocab_size, + ) + grad_input = is_chosen.float().sub_(softmax_output) + used_grad_output = grad_output[:, chunk_start:chunk_end] + grad_input.mul_(used_grad_output.unsqueeze(dim=-1)) + grad_input_hidden_states = torch.matmul( + grad_input, output_weight_layer.to(dtype=torch.float32) + ) # [chunk_start:chunk_end,
:, :] + grad_input_output_layer_local = torch.einsum( + "bsd, bsv -> dv", + tensor_parallel_hidden_states[chunk_start:chunk_end, :, :] + .transpose(0, 1) + .contiguous() + .to(dtype=torch.float32), + grad_input.to(dtype=torch.float32), + ) + all_grad_input_hidden_states.append(grad_input_hidden_states) + grad_input_output_layer.add_( + grad_input_output_layer_local.transpose(0, 1).contiguous() + ) + + grad_input_hidden_states = ( + torch.cat(all_grad_input_hidden_states, dim=1).transpose(0, 1).contiguous() + ) + weight_grad = grad_input_output_layer + local_seq_size = seq_size // tp_group_size + + sharded_grad_hidden_states = torch.empty_like( + grad_input_hidden_states[:local_seq_size] + ) + grad_input_hidden_states_list = list( + torch.chunk(grad_input_hidden_states, chunks=tp_group_size, dim=0) + ) + torch.distributed.reduce_scatter( + sharded_grad_hidden_states, + grad_input_hidden_states_list, + op=torch.distributed.ReduceOp.SUM, + group=tp_group, + ) + + return ( + sharded_grad_hidden_states, + None, + weight_grad, + None, + None, + None, + None, + None, + ) + + +def patch_gpt_model_forward_for_linear_ce_fusion(*, chunk_size: int) -> None: + if getattr(GPTModel, "_linear_ce_fusion_forward_patched", False): + GPTModel._linear_ce_fusion_chunk_size = chunk_size + return + GPTModel._original_forward_for_linear_ce_fusion = GPTModel.forward + GPTModel._linear_ce_fusion_chunk_size = chunk_size + GPTModel.forward = _gpt_forward_with_linear_ce_fusion + GPTModel._linear_ce_fusion_forward_patched = True + + +def _gpt_forward_with_linear_ce_fusion( + self: GPTModel, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor, + decoder_input: torch.Tensor = None, + labels: torch.Tensor = None, + inference_context: Any = None, + packed_seq_params: Any = None, + extra_block_kwargs: Optional[dict] = None, + runtime_gather_output: Optional[bool] = None, + *, + inference_params: Optional[Any] = None, + loss_mask: Optional[torch.Tensor] = None, + 
padding_mask: Optional[torch.Tensor] = None, + return_logprobs_for_linear_ce_fusion: bool = False, +) -> torch.Tensor: + if not return_logprobs_for_linear_ce_fusion: + return self._original_forward_for_linear_ce_fusion( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + decoder_input=decoder_input, + labels=labels, + inference_context=inference_context, + packed_seq_params=packed_seq_params, + extra_block_kwargs=extra_block_kwargs, + runtime_gather_output=runtime_gather_output, + inference_params=inference_params, + loss_mask=loss_mask, + padding_mask=padding_mask, + ) + """ + original forward function signature: + def forward( + self, + input_ids: Tensor, + position_ids: Tensor, + attention_mask: Tensor, + decoder_input: Tensor = None, + labels: Tensor = None, + inference_context: BaseInferenceContext = None, + packed_seq_params: PackedSeqParams = None, + extra_block_kwargs: dict = None, + runtime_gather_output: Optional[bool] = None, + *, + inference_params: Optional[BaseInferenceContext] = None, + loss_mask: Optional[Tensor] = None, + padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + if labels is None: + raise ValueError("labels must be provided when linear CE fusion is enabled") + + inference_context = deprecate_inference_params(inference_context, inference_params) + + preproc_output = self._preprocess( + input_ids=input_ids, + position_ids=position_ids, + decoder_input=decoder_input, + inference_context=inference_context, + packed_seq_params=packed_seq_params, + padding_mask=padding_mask, + ) + ( + decoder_input, + rotary_pos_emb, + rotary_pos_cos, + rotary_pos_sin, + sequence_len_offset, + padding_mask, + ) = preproc_output[:6] + rotary_pos_cos_sin = preproc_output[6] if len(preproc_output) == 7 else None + + hidden_states = self.decoder( + hidden_states=decoder_input, + attention_mask=attention_mask, + inference_context=inference_context, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + 
rotary_pos_sin=rotary_pos_sin, + rotary_pos_cos_sin=rotary_pos_cos_sin, + packed_seq_params=packed_seq_params, + sequence_len_offset=sequence_len_offset, + padding_mask=padding_mask, + **(extra_block_kwargs or {}), + ) + + # Non post-process pipeline stages do not own the output layer. + if not self.post_process or not hasattr(self, "output_layer"): + return hidden_states + + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_pg_size(get_tensor_model_parallel_group()) + # compute this TP rank's vocab shard bounds, then compute the logprobs of the label tokens from the hidden states + vocab_start_index = tp_rank * (self.vocab_size // tp_size) + vocab_end_index = min((tp_rank + 1) * (self.vocab_size // tp_size), self.vocab_size) + output_weight_layer = self.output_layer.weight + logprobs = from_parallel_hidden_states_to_logprobs( + hidden_states, # .transpose(0, 1).contiguous(), + output_weight_layer, + self.shared_embedding_or_output_weight() + if self.share_embeddings_and_output_weights + else self.output_layer.weight, + runtime_gather_output, + labels, + vocab_start_index=vocab_start_index, + vocab_end_index=vocab_end_index, + inference_only=inference_context is not None and not self.training, + tp_group=get_tensor_model_parallel_group(), + cp_group=self.cp_group, + chunk_size=self._linear_ce_fusion_chunk_size, + ) + return logprobs + + def all_to_all_vp2sq( vocab_parallel_logits: torch.Tensor, tp_group: torch.distributed.ProcessGroup, diff --git a/nemo_rl/models/megatron/setup.py b/nemo_rl/models/megatron/setup.py index 2e8c2c3bfe..6a58d930d5 100644 --- a/nemo_rl/models/megatron/setup.py +++ b/nemo_rl/models/megatron/setup.py @@ -53,10 +53,13 @@ from megatron.core import parallel_state from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer import MegatronModule +from megatron.core.transformer.enums import AttnBackend from megatron.core.transformer.module import Float16Module from megatron.core.transformer.transformer_config import TransformerConfig
from transformers import PreTrainedTokenizerBase +from nemo_rl.distributed.model_utils import patch_gpt_model_forward_for_linear_ce_fusion + try: from megatron.core.distributed import ( TorchFullyShardedDataParallel as torch_FSDP, # noqa: F401 unused-import @@ -375,6 +378,9 @@ def _apply_parallelism_config(model_cfg: Any, config: PolicyConfig) -> None: assert config["sequence_packing"]["enabled"], ( "Sequence Packing must be enabled to use Context Parallelism with MCore" ) + assert not config["megatron_cfg"].get("use_linear_ce_fusion_loss", False), ( + "Context Parallelism is not supported with linear CE fusion loss, please set use_linear_ce_fusion_loss to false" + ) def _apply_moe_config(model_cfg: Any, config: PolicyConfig) -> None: @@ -462,6 +468,18 @@ def _apply_performance_config(model_cfg: Any, config: PolicyConfig) -> None: # Fusion settings model_cfg.apply_rope_fusion = config["megatron_cfg"]["apply_rope_fusion"] model_cfg.bias_activation_fusion = config["megatron_cfg"]["bias_activation_fusion"] + # Optional explicit attention backend override for environments where + # TE auto backend probing is unstable. + attention_backend = config["megatron_cfg"].get("attention_backend") + if attention_backend is not None: + if isinstance(attention_backend, str): + model_cfg.attention_backend = AttnBackend[attention_backend] + elif isinstance(attention_backend, int): + model_cfg.attention_backend = AttnBackend(attention_backend) + else: + raise ValueError( + f"Unsupported {type(attention_backend)=}, expected str or int" + ) # FP8 configuration fp8_cfg = config["megatron_cfg"].get("fp8_cfg", None) @@ -752,6 +770,10 @@ def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]: # Model, optimizer, and learning rate. 
pg_collection = ProcessGroupCollection.use_mpu_process_groups() setattr(megatron_cfg.model, "_pg_collection", pg_collection) + if policy_cfg["megatron_cfg"].get("use_linear_ce_fusion_loss", False): + patch_gpt_model_forward_for_linear_ce_fusion( + chunk_size=policy_cfg["megatron_cfg"]["linear_ce_fusion_chunk_size"] + ) model = get_model( megatron_cfg.model, megatron_cfg.ddp, diff --git a/nemo_rl/models/megatron/train.py b/nemo_rl/models/megatron/train.py index 883aa44ad7..248c85f3ff 100644 --- a/nemo_rl/models/megatron/train.py +++ b/nemo_rl/models/megatron/train.py @@ -69,19 +69,20 @@ def model_forward( packed_seq_params: Optional[PackedSeqParams] = None, defer_fp32_logits: Optional[bool] = False, straggler_timer: Optional[StragglerDetector] = None, + use_linear_ce_fusion_loss: bool = False, ) -> torch.Tensor: """Perform a single forward pass through the model. Args: model: The model to run forward pass on data_dict: Dictionary containing batch data - cfg: Policy configuration dictionary input_ids_cp_sharded: Context-parallel sharded input token IDs position_ids: Position IDs for tokens attention_mask: Attention mask for the sequence packed_seq_params: Parameters for packed sequences (optional) defer_fp32_logits: Whether to skip the conversion of logits to fp32 straggler_timer: Straggler detector for profiling the forward pass + use_linear_ce_fusion_loss: Whether to use linear CE fusion loss Returns: torch.Tensor: Output tensor from the model (logits) @@ -98,6 +99,11 @@ def model_forward( additional_kwargs["packed_seq_params"] = packed_seq_params if defer_fp32_logits: additional_kwargs["fp32_output"] = False + if use_linear_ce_fusion_loss: + additional_kwargs["labels"] = input_ids_cp_sharded + # Only pass this kwarg when linear CE fusion is enabled. Older Megatron-LM + # GPTModel.forward signatures do not accept it. 
+ additional_kwargs["return_logprobs_for_linear_ce_fusion"] = True with straggler_timer() if straggler_timer is not None else nullcontext(): output_tensor = model( @@ -137,6 +143,7 @@ def forward_with_post_processing_fn( global_valid_toks: Optional[torch.Tensor] = None, sampling_params: Optional[TrainingSamplingParams] = None, straggler_timer: Optional[StragglerDetector] = None, + use_linear_ce_fusion_loss: bool = False, ) -> Tuple[torch.Tensor, Callable]: """Perform forward pass with pre-processed microbatch and return output tensor and post-processing function. @@ -180,6 +187,7 @@ def forward_with_post_processing_fn( packed_seq_params=packed_seq_params, defer_fp32_logits=defer_fp32_logits, straggler_timer=straggler_timer, + use_linear_ce_fusion_loss=use_linear_ce_fusion_loss, ) # Apply temperature scaling only for sampling-oriented post-processors. @@ -233,6 +241,7 @@ def megatron_forward_backward( global_valid_toks: Optional[torch.Tensor] = None, sampling_params: Optional[TrainingSamplingParams] = None, straggler_timer: Optional[StragglerDetector] = None, + use_linear_ce_fusion_loss: bool = False, ) -> Any: """Execute forward and backward passes using Megatron's utilities. 
@@ -265,6 +274,7 @@ def megatron_forward_backward( global_valid_toks=global_valid_toks, sampling_params=sampling_params, straggler_timer=straggler_timer, + use_linear_ce_fusion_loss=use_linear_ce_fusion_loss, ) forward_backward_func = get_forward_backward_func() return forward_backward_func( diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py index 4099543bde..3636e5ac64 100644 --- a/nemo_rl/models/policy/__init__.py +++ b/nemo_rl/models/policy/__init__.py @@ -215,6 +215,15 @@ class MegatronConfig(TypedDict): optimizer: MegatronOptimizerConfig scheduler: MegatronSchedulerConfig distributed_data_parallel_config: MegatronDDPConfig + # When True, uses chunked linear cross-entropy fusion loss to compute loss + # directly from hidden states, avoiding materialization of the full + # [batch, seq_len, vocab_size] logit tensor. This significantly reduces peak + # GPU memory, extending the maximum trainable sequence length (e.g. from <65K + # to >100K tokens). Only applicable to SFT with NLLLoss. + use_linear_ce_fusion_loss: NotRequired[bool] + # Number of tokens per chunk when computing the fused linear CE loss. + # Smaller values reduce peak memory further but may decrease throughput. + linear_ce_fusion_chunk_size: NotRequired[int] class TokenizerConfig(TypedDict): diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 3a77d406e4..fdb141fcf8 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -340,6 +340,9 @@ def train( global_valid_toks=global_valid_toks, sampling_params=self.sampling_params, straggler_timer=self.mcore_state.straggler_timer, + use_linear_ce_fusion_loss=self.cfg["megatron_cfg"].get( + "use_linear_ce_fusion_loss", False + ), ) # Empty unused memory. 
diff --git a/tests/test_suites/llm/sft-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.sh b/tests/test_suites/llm/sft-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.sh new file mode 100755 index 0000000000..261a0e851c --- /dev/null +++ b/tests/test_suites/llm/sft-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.sh @@ -0,0 +1,44 @@ +#!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) +source $SCRIPT_DIR/common.env + +# ===== BEGIN CONFIG ===== +NUM_NODES=1 +STEPS_PER_RUN=10 +MAX_STEPS=10 +NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up +NUM_MINUTES=25 +# ===== END CONFIG ===== + +exit_if_max_steps_reached + +# Run the experiment +cd $PROJECT_ROOT +uv run examples/run_sft.py \ + --config $CONFIG_PATH \ + sft.max_num_steps=$MAX_STEPS \ + logger.log_dir=$LOG_DIR \ + logger.wandb_enabled=True \ + logger.wandb.project=nemo-rl \ + logger.wandb.name=$EXP_NAME \ + logger.monitor_gpus=True \ + logger.tensorboard_enabled=True \ + checkpointing.enabled=true \ + checkpointing.checkpoint_dir=$CKPT_DIR \ + ~policy.tokenizer.chat_template \ + $@ \ + 2>&1 | tee $RUN_LOG + +# Convert tensorboard logs to json +uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS + +# Only run metrics if the target step is reached +if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then + # Smoke checks: run completed and loss is finite/reasonable. + uv run tests/check_metrics.py $JSON_METRICS \ + 'data["train/loss"]["10"] > 0.0' \ + 'data["train/loss"]["10"] < 20.0' + + # Clean up checkpoint directory after successful run to save space. 
+ rm -rf "$CKPT_DIR" +fi diff --git a/tests/test_suites/nightly.txt b/tests/test_suites/nightly.txt index 072945331e..61e474f2c3 100644 --- a/tests/test_suites/nightly.txt +++ b/tests/test_suites/nightly.txt @@ -106,6 +106,8 @@ tests/test_suites/llm/sft-llama3.1-8b-1n8g-megatron.sh tests/test_suites/llm/sft-llama3.1-8b-1n8g-megatron-seqpack.sh # validate TP/DP tests/test_suites/llm/sft-qwen2.5-math7b-2n8g-megatron.sh +# chunked linear CE loss +tests/test_suites/llm/sft-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.sh # Nemotron Super 49B SFT tests # Issue with details: https://github.com/NVIDIA-NeMo/RL/issues/1571 diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py index ffd2be0445..eff97b215c 100644 --- a/tests/unit/models/policy/test_megatron_worker.py +++ b/tests/unit/models/policy/test_megatron_worker.py @@ -138,6 +138,8 @@ def create_megatron_test_config( "moe_token_dispatcher_type": "alltoall", "moe_shared_expert_overlap": False, "defer_fp32_logits": defer_fp32_logits, + "use_linear_ce_fusion_loss": False, + "linear_ce_fusion_chunk_size": 256, "train_iters": 100, # Required for Megatron training "optimizer": { "optimizer": "adam", @@ -1879,6 +1881,96 @@ def test_megatron_sft_training(tiny_llama_model_path): cluster.shutdown() +@pytest.mark.timeout(300) +def test_megatron_sft_linear_ce_fusion_agreement(tiny_qwen2_model_path): + """Test that linear CE fusion loss produces the same results as the standard path.""" + num_gpus = 2 + batch_size = 8 + seq_len = 64 + vocab_size = 151936 + + torch.manual_seed(42) + input_ids = torch.randint(0, vocab_size, (batch_size, seq_len)) + attention_mask = torch.ones(batch_size, seq_len) + input_lengths = attention_mask.sum(dim=1).to(torch.int32) + token_mask = torch.triu(torch.ones(batch_size, seq_len), diagonal=1) + sample_mask = torch.ones(batch_size) + labels = torch.randint(0, vocab_size, (batch_size, seq_len)) + + data = BatchedDataDict( + { + "input_ids": 
input_ids, + "input_lengths": input_lengths, + "attention_mask": attention_mask, + "token_mask": token_mask, + "sample_mask": sample_mask, + "labels": labels, + } + ) + + # --- Standard SFT (no linear CE fusion) --- + cluster_std = RayVirtualCluster( + name="test-sft-std", + bundle_ct_per_node_list=[num_gpus], + use_gpus=True, + num_gpus_per_node=num_gpus, + max_colocated_worker_groups=1, + ) + config_std = create_megatron_test_config(tiny_qwen2_model_path) + tokenizer = get_tokenizer(config_std["tokenizer"]) + policy_std = Policy( + cluster=cluster_std, + config=config_std, + tokenizer=tokenizer, + init_reference_model=False, + ) + sft_loss_std = NLLLossFn() + + try: + policy_std.prepare_for_training() + results_std = policy_std.train(data, sft_loss_std) + loss_std = results_std["loss"] + finally: + policy_std.shutdown() + cluster_std.shutdown() + + # --- SFT with linear CE fusion --- + cluster_fuse = RayVirtualCluster( + name="test-sft-fuse", + bundle_ct_per_node_list=[num_gpus], + use_gpus=True, + num_gpus_per_node=num_gpus, + max_colocated_worker_groups=1, + ) + config_fuse = create_megatron_test_config(tiny_qwen2_model_path) + config_fuse["megatron_cfg"]["use_linear_ce_fusion_loss"] = True + config_fuse["megatron_cfg"]["linear_ce_fusion_chunk_size"] = 256 + policy_fuse = Policy( + cluster=cluster_fuse, + config=config_fuse, + tokenizer=tokenizer, + init_reference_model=False, + ) + sft_loss_fuse = NLLLossFn(use_linear_ce_fusion=True) + + try: + policy_fuse.prepare_for_training() + results_fuse = policy_fuse.train(data, sft_loss_fuse) + loss_fuse = results_fuse["loss"] + finally: + policy_fuse.shutdown() + cluster_fuse.shutdown() + + # Verify both produce valid losses + assert not torch.isnan(loss_std).any(), "Standard loss should not be NaN" + assert not torch.isnan(loss_fuse).any(), "Fusion loss should not be NaN" + assert not torch.isinf(loss_std).any(), "Standard loss should not be Inf" + assert not torch.isinf(loss_fuse).any(), "Fusion loss should not be 
Inf" + + # Verify losses are numerically close + torch.testing.assert_close(loss_std, loss_fuse, rtol=1e-2, atol=1e-2) + + @pytest.mark.hf_gated @pytest.mark.timeout(300) def test_megatron_context_parallel_logprob_agreement(tiny_llama_model_path):