feat: Add chunked linear ce loss function from hidden states #2036

yuki-97 merged 28 commits into NVIDIA-NeMo:main from
Conversation
Signed-off-by: pengdurice <pengduhit@gmail.com>
📝 Walkthrough

This PR introduces a linear cross-entropy fusion loss mechanism for efficient distributed training.
Sequence Diagram(s)

sequenceDiagram
participant Training as Training Loop
participant GPTModel as GPT Model<br/>(Patched Forward)
participant DistLogprobs as ChunkedDistributedHiddenStatesToLogprobs
participant LossFn as NLLLinearCEFusionLoss
participant Wrapper as SequencePackingNLLLinearCEFusionLossWrapper
Training->>GPTModel: forward(input_ids, labels,<br/>return_logprobs_for_linear_ce_fusion=True)
activate GPTModel
GPTModel->>DistLogprobs: from_parallel_hidden_states_to_logprobs<br/>(tensor_parallel_hidden_states, target, ...)
activate DistLogprobs
DistLogprobs->>DistLogprobs: Chunked distributed log-softmax<br/>with gather across TP ranks
DistLogprobs-->>GPTModel: token_logprobs
deactivate DistLogprobs
GPTModel-->>Training: logprobs (instead of loss)
deactivate GPTModel
Training->>Wrapper: forward(logprobs, labels, ...)
activate Wrapper
Wrapper->>Wrapper: Iterate over packed sequences,<br/>unpad per-sequence data
Wrapper->>LossFn: forward(unpadded_logprobs,<br/>unpadded_labels, ...)
activate LossFn
LossFn->>LossFn: Compute token-level NLL loss
LossFn-->>Wrapper: per-batch loss + metrics
deactivate LossFn
Wrapper->>Wrapper: Accumulate losses<br/>across sequences
Wrapper-->>Training: final_loss, metrics
deactivate Wrapper
Training->>Training: Backward pass with accumulated loss
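The chunked computation at the heart of this flow can be sketched in plain PyTorch. This is a single-GPU illustration only, not the PR's actual API: there is no tensor parallelism here, and the function and argument names are invented for the example. Each chunk of hidden states is projected to logits, reduced to the target token's log-probability, and the full `[tokens, vocab]` logit tensor is never materialized.

```python
import torch

def chunked_token_logprobs(hidden: torch.Tensor,
                           weight: torch.Tensor,
                           target: torch.Tensor,
                           chunk_size: int = 256) -> torch.Tensor:
    """Per-token log-probs without materializing [num_tokens, vocab] logits.

    hidden: [num_tokens, hidden_dim], weight: [vocab, hidden_dim],
    target: [num_tokens] int64 token ids.
    """
    num_tokens = hidden.shape[0]
    out = torch.empty(num_tokens, dtype=hidden.dtype)
    for start in range(0, num_tokens, chunk_size):
        end = min(start + chunk_size, num_tokens)
        # Only this chunk's logits ever exist in memory.
        logits = hidden[start:end] @ weight.t()            # [chunk, vocab]
        logz = torch.logsumexp(logits, dim=-1)             # log-partition per token
        tgt = logits.gather(1, target[start:end, None]).squeeze(1)
        out[start:end] = tgt - logz                        # log p(target token)
    return out
```

A tensor-parallel version would additionally reduce the max/sum terms of the log-softmax across vocab-parallel ranks before gathering the target entries, which is what the distributed autograd op in this PR handles.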
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 7
🧹 Nitpick comments (3)
nemo_rl/algorithms/sft.py (1)
213-213: Update `setup` return typing to include fusion loss. Line 213 can now return `NLLLinearCEFusionLoss`, but `setup(...)` is still annotated as returning `NLLLoss` only. Please update the return type to a union for type safety.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `nemo_rl/algorithms/sft.py` at line 213, the annotated return type of `setup` is too narrow: since `loss_fn` can be either `NLLLoss` or `NLLLinearCEFusionLoss` (see the assignment to `loss_fn` and the Megatron/linear-fusion flags), update the `setup(...)` signature to a union return type (e.g. `Union[NLLLoss, NLLLinearCEFusionLoss]`) or the common base class if one exists; also add the necessary typing import (`from typing import Union`) and adjust any downstream type hints or stubs that assumed only `NLLLoss`.

nemo_rl/models/megatron/train.py (1)
313-314: Add an explicit loss-type guard for the fusion wrapper path. Line 313 chooses the wrapper only from config. If a non-fusion-compatible loss is passed while the flag is on, the failure mode will be late and opaque. Consider a clear upfront `TypeError` guard.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `nemo_rl/models/megatron/train.py` around lines 313-314, the selection of `sequence_packing_loss_wrapper_type` uses only the cfg flag and can pick the fusion wrapper even when an incompatible loss object is passed; add an explicit upfront `TypeError` guard when `self.cfg["megatron_cfg"]["use_linear_ce_fusion_loss"]` is true that validates the provided loss is a fusion-compatible type (e.g., check isinstance against the appropriate fusion-compatible loss classes, or the presence of a fusion-compatible attribute/method) before constructing `SequencePackingNLLLinearCEFusionLossWrapper`, and raise a clear `TypeError` naming the expected loss types and the actual type if validation fails.

nemo_rl/distributed/model_utils.py (1)
1067-1069: Trim unused parameters from `from_parallel_hidden_states_to_logprobs` to keep the API truthful. `output_weight` and `runtime_gather_output` are passed (lines 1386-1390) but never consumed by the implementation, which makes the interface misleading.

Proposed refactor:

```diff
 def from_parallel_hidden_states_to_logprobs(
     tensor_parallel_hidden_states: torch.Tensor,
     output_weight_layer: torch.Tensor,
-    output_weight: torch.Tensor,
-    runtime_gather_output: bool,
     target: torch.Tensor,
@@
     logprobs = from_parallel_hidden_states_to_logprobs(
         hidden_states,
         output_weight_layer,
-        self.shared_embedding_or_output_weight()
-        if self.share_embeddings_and_output_weights
-        else self.output_layer.weight,
-        runtime_gather_output,
         labels,
```

Also applies to: 1386-1390
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@nemo_rl/distributed/model_utils.py` around lines 1067 - 1069, The function signature for from_parallel_hidden_states_to_logprobs currently declares unused parameters output_weight and runtime_gather_output; remove these parameters from the function definition and from any calls that pass them (the callers that invoke from_parallel_hidden_states_to_logprobs with output_weight and runtime_gather_output) so the API matches the implementation, then run tests/formatting to ensure no remaining references to those symbols remain; keep the function name from_parallel_hidden_states_to_logprobs as the identifier to locate and update both the definition and all call sites.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@nemo_rl/algorithms/loss_functions.py`:
- Line 1358: Remove the unused local assignments that cause Ruff F841 by
deleting the unused variables seq_index and seq_end where they are assigned;
locate the assignments (e.g., the line setting seq_index = data.get("seq_index",
None) and the block assigning seq_end) in the function(s) in loss_functions.py
and simply remove those assignment statements (or the unused unpacking) so the
variables are not created if they are not referenced elsewhere.
In `@nemo_rl/algorithms/sft.py`:
- Line 213: The loss selection currently uses hidden defaults via .get(...,
False); change it to read the config keys directly and rely on YAML-provided
defaults by using policy_config["megatron_cfg"]["enabled"] and
policy_config["megatron_cfg"]["use_linear_ce_fusion_loss"] (or validate their
presence earlier) when deciding between NLLLinearCEFusionLoss and NLLLoss
(symbols: loss_fn, NLLLinearCEFusionLoss, NLLLoss,
policy_config["megatron_cfg"]). Ensure you remove .get default values here and
either add an explicit config validation step before this line or let a KeyError
surface so missing values are fixed in configuration rather than silently
defaulted in code.
In `@nemo_rl/distributed/model_utils.py`:
- Around line 1065-1087: The function from_parallel_hidden_states_to_logprobs
forwards chunk_size directly to the chunked autograd op
ChunkedDistributedHiddenStatesToLogprobs.apply but the op expects a positive
integer; validate and normalize chunk_size in
from_parallel_hidden_states_to_logprobs before the apply call (e.g., if
chunk_size is None or <=0, set it to a safe positive default such as 1 or the
hidden dimension), then pass the validated integer to
ChunkedDistributedHiddenStatesToLogprobs.apply; apply the same validation to any
other call sites in this module that forward chunk_size to the chunked autograd
op.
- Around line 1065-1076: The function from_parallel_hidden_states_to_logprobs
currently accepts cp_group but never shards/gathers targets or logprobs by CP,
which can misalign tokens when CP is enabled; update this function to either (A)
implement CP-aware handling: shard targets and gathered logprobs across cp_group
(mirror the TP logic that uses tp_group so that tensor_parallel_hidden_states,
target, and final logprobs are correctly reduced/concatenated across the
model-parallel column group), or (B) add a fail-fast check at the top of
from_parallel_hidden_states_to_logprobs that raises a clear error if cp_group is
not None (or indicates unsupported CP config) so callers (e.g., the call site
passing cp_group=self.cp_group) cannot silently produce incorrect results;
reference the function name from_parallel_hidden_states_to_logprobs and the
parameters cp_group, tensor_parallel_hidden_states, target, tp_group,
runtime_gather_output, output_weight_layer, and output_weight when making the
change.
- Line 1204: Remove the unused local variable assignment
`all_grad_input_output_layer = []` (it is declared but never used) to avoid
leaving a dead local that can interfere with the backward path; locate the
occurrence of `all_grad_input_output_layer` in the function in model_utils.py
and delete the assignment line (or if the variable was intended to be used, wire
it into the computation where gradients are collected instead of leaving it
unused).
- Around line 1276-1281: The code embeds a hidden default chunk_size (256) in
patch_gpt_model_forward_for_linear_ce_fusion and in a getattr call; remove these
hard-coded defaults and read the required value from the canonical config
instead. Update patch_gpt_model_forward_for_linear_ce_fusion to not use a
default parameter (make chunk_size required or accept None and immediately load
policy_cfg['linear_ce_fusion_chunk_size']), remove the getattr(..., 256) usage
so it does not fall back silently, and set GPTModel._linear_ce_fusion_chunk_size
from the explicit config value; keep the existing attribute names
(GPTModel._linear_ce_fusion_chunk_size,
GPTModel._original_forward_for_linear_ce_fusion,
GPTModel._linear_ce_fusion_forward_patched) so the patch logic still finds and
sets the model attributes.
In `@nemo_rl/models/megatron/setup.py`:
- Around line 744-749: The code currently falls back to a hardcoded default
(256) for linear_ce_fusion_chunk_size inside the call to
patch_gpt_model_forward_for_linear_ce_fusion; remove that inline default and
instead require the value be supplied from configuration
(policy_cfg["megatron_cfg"]["linear_ce_fusion_chunk_size"]) or explicitly
validate/panic if missing. Update the conditional around
use_linear_ce_fusion_loss to read the chunk size from policy_cfg (no .get
default), pass that value into patch_gpt_model_forward_for_linear_ce_fusion, and
add a clear error/validation message if linear_ce_fusion_chunk_size is absent or
None so YAML remains the single source of truth.
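The `TypeError` guard suggested for `train.py` could look roughly like this. The loss classes here are empty stand-ins for the real ones in `loss_functions.py`, and the helper name is invented for illustration:

```python
class NLLLoss:
    """Stand-in for the standard NLL loss class."""

class NLLLinearCEFusionLoss:
    """Stand-in for the fusion-compatible loss class."""

def check_fusion_loss(loss, use_linear_ce_fusion_loss: bool):
    """Fail fast if the fusion flag is on but the loss is incompatible."""
    if use_linear_ce_fusion_loss and not isinstance(loss, NLLLinearCEFusionLoss):
        raise TypeError(
            "use_linear_ce_fusion_loss=True requires an NLLLinearCEFusionLoss, "
            f"got {type(loss).__name__}"
        )
    return loss
```

Raising before the wrapper is constructed turns a late, opaque shape or attribute error into an immediate, named configuration error.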
ℹ️ Review info
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)

- nemo_rl/algorithms/loss_functions.py
- nemo_rl/algorithms/sft.py
- nemo_rl/distributed/model_utils.py
- nemo_rl/models/megatron/setup.py
- nemo_rl/models/megatron/train.py
- nemo_rl/models/policy/workers/megatron_policy_worker.py
- tests/unit/distributed/test_model_utils.py
@yuki-97, hi, this is Peng. This PR adds a chunked linear CE fusion loss to avoid the OOM caused by the large logit tensor. Would you please help take a look? Thanks!
yuki-97
left a comment
hi @pengdurice , thanks for the contribution, great work! I left some comments.
And do you mind adding your experiment from the PR description as a nightly test? You can refer to https://github.com/NVIDIA-NeMo/RL/pull/1866/changes.

- add a config under `examples/configs/recipes/llm/`
- add a script under `tests/test_suites/llm/`
- add the test to `tests/test_suites/nightly.txt`
@terrykong, could you, or someone who's familiar with the mcore distributed part, take a review at
Co-authored-by: Yuki Huang <yukih@nvidia.com> Signed-off-by: pengdurice <pengduhit@gmail.com>
@yuki-97, thank you so much for your review! I have fixed things according to your comments and added the nightly test config and sh files. @terrykong, would you please review model_utils.py when you get a chance? Thanks!
terrykong
left a comment
@yaoyu-33 @ananthsub can you review?
@yaoyu-33 @ananthsub can you help review please? thanks!
yuki-97
left a comment
hi @pengdurice, thanks for the update, and sorry for the wait.
I just took a review at model_utils.py and left some comments for that file and the previous updates. There also seem to be some conflicts with the current main, so you'll need to do a rebase.
@terrykong could you take a review as well?
tests/test_suites/llm/sft-qwen2.5-math7b-1n8g-megatron_chunked_linear_ce_loss.sh
/ok to test de97d32
hi @pengdurice, it looks like some unit tests failed. Can you help to fix them?
https://github.com/NVIDIA-NeMo/RL/actions/runs/23230488220/job/67540009384?pr=2036
https://github.com/NVIDIA-NeMo/RL/actions/runs/23230488220/job/67540009385?pr=2036
just a reminder that you may need to run the lint command again after fixing.
and thanks again for your contribution!
…_fusion_loss value Signed-off-by: pengdurice <pengduhit@gmail.com>
Thank you! It looks to me like it is a missing-config issue. Thank you for the fix ;-) I also changed another two places to be on the safe side (one in sft.py); let me know if that's not necessary, thanks!
Signed-off-by: Yuki Huang <yukih@nvidia.com>
Good catch for these two! The one in SFT should always be in the config, so I directly updated it; let's run CI again.
/ok to test 74bdffb
Head branch was pushed to by a user without write access
@yuki-97, some tests failed (due to the wrong sh name in nightly.txt; that was skipped on my local machine because I called the sh file directly to test). Would you mind triggering the CI/CD again? Thank you!
/ok to test c87b344
What does this PR do?
Adds a chunked linear cross-entropy loss that avoids materializing the full logit tensor, preventing OOM.
Issues
None
Key changes
Usage
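Based on the config keys referenced in the review comments (`megatron_cfg.use_linear_ce_fusion_loss` and `megatron_cfg.linear_ce_fusion_chunk_size`), enabling the fusion path in a training YAML presumably looks like the fragment below; the exact nesting under `policy` is an assumption, not taken from the PR.

```yaml
policy:
  megatron_cfg:
    enabled: true
    # Compute CE from hidden states in chunks instead of materializing logits.
    use_linear_ce_fusion_loss: true
    # Tokens per chunk; the review asks that this come from YAML, not a code default.
    linear_ce_fusion_chunk_size: 256
```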
Before your PR is "Ready for review"
Pre checks:
Additional Information
Summary by CodeRabbit
New Features
Tests