[training] fix: normalize cuda_graph_scope type before membership checks#2578

Merged
yaoyu-33 merged 8 commits into main from yuya/fix-cuda-graph-scope-type-error
Mar 31, 2026

Conversation

@yaoyu-33
Contributor

@yaoyu-33 yaoyu-33 commented Feb 27, 2026

Summary

  • Fix `TypeError: 'in <string>' requires string as left operand, not CudaGraphScope` raised by the delay_wgrad validation in comm_overlap.py
  • Normalize cuda_graph_scope to a list before performing `in` membership checks, handling the cases where it is a raw string (e.g. "full"), a single CudaGraphScope enum value, or None
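
The normalization can be sketched roughly as follows. The enum and helper name here are illustrative stand-ins, not the actual Megatron API; the real CudaGraphScope lives in the transformer config:

```python
from enum import Enum


class CudaGraphScope(Enum):
    # Stand-in for the real enum; members and values are illustrative.
    full = "full"
    attn = "attn"


def normalize_cuda_graph_scope(scope):
    """Coerce cuda_graph_scope into a list of CudaGraphScope values.

    Accepts None, a raw string (e.g. "full"), a single enum value, or a
    list of either form, so later `in` checks always see a list of enums.
    """
    if scope is None:
        return []
    if isinstance(scope, (str, CudaGraphScope)):
        scope = [scope]
    return [s if isinstance(s, CudaGraphScope) else CudaGraphScope(s) for s in scope]


# Membership checks are now safe regardless of the input form:
assert CudaGraphScope.attn in normalize_cuda_graph_scope("attn")
assert CudaGraphScope.attn in normalize_cuda_graph_scope(CudaGraphScope.attn)
assert CudaGraphScope.attn in normalize_cuda_graph_scope([CudaGraphScope.attn])
assert normalize_cuda_graph_scope(None) == []
```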

Test plan

  • tests/unit_tests/training/test_comm_overlap.py::TestMegatronCommOverlapConfig::test_delay_wgrad_config_validation should pass
  • tests/unit_tests/training/test_comm_overlap.py::TestMegatronCommOverlapConfig::test_delay_wgrad_config_validation_with_overlap_grad_reduce should pass
  • All other test_comm_overlap.py tests should continue to pass

Made with Cursor

Summary by CodeRabbit

  • New Features

    • Added CUDA graph scope validations for delayed weight gradient computation to ensure training stability and compatibility requirements
    • Enhanced FLOP calculations with Multi-Token Prediction depth support for improved performance metrics in hybrid models
  • Tests

    • Added comprehensive unit tests for CUDA graph validation and hybrid pattern parsing scenarios

yaoyu-33 and others added 5 commits February 11, 2026 09:11
…and delay_wgrad checks

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
…rad cuda-graph checks

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
…omm_overlap

cuda_graph_scope can be a string (e.g., "full") when TransformerConfig's
__post_init__ normalization hasn't processed it. The `in` operator on a
string requires a string left operand, but CudaGraphScope enum values
are not strings, causing TypeError in delay_wgrad validation.

Signed-off-by: Yu Yao <yaoyu.094@gmail.com>
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
Made-with: Cursor
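
The failure mode described in the commit message can be reproduced in isolation: when the left operand of `in` is not a string but the right operand is, `str.__contains__` raises exactly this TypeError. The enum below is a stand-in for the real one:

```python
from enum import Enum


class CudaGraphScope(Enum):
    full = "full"


scope = "full"  # unnormalized config value: still a raw string

try:
    # Enum member `in` a string: str membership requires a string left operand.
    CudaGraphScope.full in scope
except TypeError as exc:
    print(exc)  # 'in <string>' requires string as left operand, not CudaGraphScope
```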
@copy-pr-bot

copy-pr-bot bot commented Feb 27, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@yaoyu-33
Contributor Author

/ok to test 662e98c

@coderabbitai
Contributor

coderabbitai bot commented Feb 27, 2026

📝 Walkthrough

Walkthrough

The PR adds CUDA graph scope validations for delayed weight gradient computation, introduces MTP (Multi-Token Prediction) depth awareness to FLOP calculations, and implements dynamic hybrid layout pattern parsing. Changes include new configuration requirements when CUDA graph scopes are enabled and parameter propagation of mtp_num_layers through FLOP calculation functions.

Changes

  • CUDA Graph Scope Validation (src/megatron/bridge/training/comm_overlap.py):
    Adds validation logic for delay_wgrad with CUDA graph scope enabled: enforces TE version ≥ 2.12.0, requires gradient_accumulation_fusion when wgrad is in graph scope, and rejects attention bias when attention scope is in graph scope.
  • FLOP Calculation Enhancement (src/megatron/bridge/training/utils/flop_utils.py):
    Introduces dynamic hybrid pattern parsing via importlib with fallback, adds an mtp_num_layers parameter to hybrid_flops(), and integrates MTP depth into logits and overall FLOP calculations with backward-compatible error handling.
  • CUDA Graph Validation Tests (tests/unit_tests/training/test_comm_overlap.py):
    Adds three unit tests validating CUDA graph scope with delay_wgrad: ensures the gradient_accumulation_fusion requirement, rejects attention bias configuration, and verifies validation success with supported settings.
  • MTP FLOP Scaling Tests (tests/unit_tests/training/utils/test_flop_utils.py):
    Adds a test class validating MTP depth inference from hybrid patterns and verifies that logits FLOP scaling matches an explicit mtp_num_layers specification.
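
Based on the summary above, the delay_wgrad checks in comm_overlap.py can be sketched roughly as follows. The function name, signature, and error messages are illustrative assumptions, not the actual Megatron API; the review notes that invalid configurations raise AssertionError:

```python
from enum import Enum


class CudaGraphScope(Enum):
    # Stand-in members for the real enum.
    attn = "attn"
    wgrad = "wgrad"


def validate_delay_wgrad(scopes, te_version, grad_accum_fusion, attention_bias):
    """Illustrative version of the validations described above.

    scopes: list of CudaGraphScope (already normalized)
    te_version: Transformer Engine version as a (major, minor, patch) tuple
    """
    if not scopes:
        return  # no CUDA graph scope enabled, nothing to validate
    if te_version < (2, 12, 0):
        raise AssertionError("delay_wgrad with CUDA graphs requires TE >= 2.12.0")
    if CudaGraphScope.wgrad in scopes and not grad_accum_fusion:
        raise AssertionError("wgrad in cuda_graph_scope requires gradient_accumulation_fusion")
    if CudaGraphScope.attn in scopes and attention_bias:
        raise AssertionError("attention bias is unsupported when attn is in cuda_graph_scope")


# A configuration satisfying all three constraints passes silently:
validate_delay_wgrad([CudaGraphScope.attn], (2, 12, 0), True, False)
```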

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested labels

cherry-pick, r0.3.0

Suggested reviewers

  • ananthsub
  • gautham-kollu
🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

  • Test Results For Major Changes (⚠️ Warning): PR contains major changes (MTP feature, function signature change, CUDA graph validations) but lacks documented test results despite mentioning a test plan. Resolution: add actual test results showing the tests passed, and document no regression in FLOP calculations and model convergence.

✅ Passed checks (3 passed)

  • Description Check (✅ Passed): Check skipped; CodeRabbit’s high-level summary is enabled.
  • Title check (✅ Passed): The title accurately reflects the main fix: normalizing the cuda_graph_scope type before membership checks to resolve a TypeError in the delay_wgrad validation logic.
  • Docstring Coverage (✅ Passed): No functions found in the changed files to evaluate docstring coverage; skipping the check.


Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
tests/unit_tests/training/test_comm_overlap.py (1)

615-715: Please add regression cases for raw string and single-enum cuda_graph_scope forms.

These tests only pass [CudaGraphScope.attn], which would not catch the original type-normalization failure mode ("attn" or CudaGraphScope.attn directly).

✅ Suggested test tightening
+    @pytest.mark.parametrize("cuda_scope", ["attn", CudaGraphScope.attn, [CudaGraphScope.attn]])
-    def test_delay_wgrad_cuda_graph_attn_requires_grad_accum_fusion(self):
+    def test_delay_wgrad_cuda_graph_attn_requires_grad_accum_fusion(self, cuda_scope):
@@
-            cuda_graph_scope=[CudaGraphScope.attn],
+            cuda_graph_scope=cuda_scope,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unit_tests/training/test_comm_overlap.py` around lines 615 - 715, Add
regression tests to cover the two alternative forms of cuda_graph_scope that
previously broke normalization: one where cuda_graph_scope is provided as the
raw string "attn" and one where it is provided as the single enum value
CudaGraphScope.attn (not wrapped in a list). In the same test group that
exercises CommOverlapConfig._get_model_comm_overlap_cfgs (using
CommOverlapConfig, create_gpt_config and DistributedDataParallelConfig),
duplicate the existing passing and failing cases but set
model_cfg.cuda_graph_scope to "attn" and to CudaGraphScope.attn respectively,
and assert the same outcomes (raising AssertionError for invalid configs and
returning delay_wgrad_compute True for the valid config) so the normalization
handling is exercised.

ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e62dac3 and 662e98c.

📒 Files selected for processing (4)
  • src/megatron/bridge/training/comm_overlap.py
  • src/megatron/bridge/training/utils/flop_utils.py
  • tests/unit_tests/training/test_comm_overlap.py
  • tests/unit_tests/training/utils/test_flop_utils.py

@yaoyu-33
Contributor Author

/ok to test 6f87c74

Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
@yaoyu-33
Contributor Author

/ok to test 39a4d2c

@yaoyu-33
Contributor Author

/ok to test a31341d
