[training] fix: normalize cuda_graph_scope type before membership checks #2578
…and delay_wgrad checks Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
…rad cuda-graph checks Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
…omm_overlap cuda_graph_scope can be a string (e.g., "full") when TransformerConfig's __post_init__ normalization hasn't processed it. The `in` operator on a string requires a string left operand, but CudaGraphScope enum values are not strings, causing TypeError in delay_wgrad validation. Signed-off-by: Yu Yao <yaoyu.094@gmail.com> Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com> Made-with: Cursor
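The failure described in this commit message can be reproduced in isolation. The sketch below uses a hypothetical stand-in `CudaGraphScope` enum (not the real Megatron class, whose members and values are not shown in this PR) and a hypothetical helper `membership_error`; it only illustrates why `in` against a string raises while `in` against a list does not.

```python
from enum import Enum


class CudaGraphScope(Enum):  # hypothetical stand-in, not the Megatron enum
    attn = 1
    mlp = 2


def membership_error(scope):
    """Return the TypeError message raised by `CudaGraphScope.attn in scope`, or None."""
    try:
        # str.__contains__ only accepts str operands, so this raises for a string scope.
        CudaGraphScope.attn in scope
    except TypeError as exc:
        return str(exc)
    return None


# A raw string scope triggers the error quoted in the PR summary.
print(membership_error("full"))
# A list scope is an ordinary membership test and does not raise.
print(membership_error([CudaGraphScope.attn]))
```

The first call reports the `'in <string>' requires string as left operand` message; the second returns `None` because list membership falls back to equality comparison.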
/ok to test 662e98c
📝 Walkthrough
The PR adds CUDA graph scope validations for delayed weight gradient computation, introduces MTP (Multi-Token Prediction) depth awareness to FLOP calculations, and implements dynamic hybrid layout pattern parsing. Changes include new configuration requirements when CUDA graph scopes are enabled and parameter propagation of …
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 3 passed | ❌ 1 failed (1 warning)
🧹 Nitpick comments (1)
tests/unit_tests/training/test_comm_overlap.py (1)
615-715: Please add regression cases for raw string and single-enum `cuda_graph_scope` forms.
These tests only pass `[CudaGraphScope.attn]`, which would not catch the original type-normalization failure mode (`"attn"` or `CudaGraphScope.attn` directly).
✅ Suggested test tightening
```diff
+    @pytest.mark.parametrize("cuda_scope", ["attn", CudaGraphScope.attn, [CudaGraphScope.attn]])
-    def test_delay_wgrad_cuda_graph_attn_requires_grad_accum_fusion(self):
+    def test_delay_wgrad_cuda_graph_attn_requires_grad_accum_fusion(self, cuda_scope):
@@
-            cuda_graph_scope=[CudaGraphScope.attn],
+            cuda_graph_scope=cuda_scope,
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit_tests/training/test_comm_overlap.py` around lines 615 - 715, Add regression tests to cover the two alternative forms of cuda_graph_scope that previously broke normalization: one where cuda_graph_scope is provided as the raw string "attn" and one where it is provided as the single enum value CudaGraphScope.attn (not wrapped in a list). In the same test group that exercises CommOverlapConfig._get_model_comm_overlap_cfgs (using CommOverlapConfig, create_gpt_config and DistributedDataParallelConfig), duplicate the existing passing and failing cases but set model_cfg.cuda_graph_scope to "attn" and to CudaGraphScope.attn respectively, and assert the same outcomes (raising AssertionError for invalid configs and returning delay_wgrad_compute True for the valid config) so the normalization handling is exercised.
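The regression coverage requested above can be illustrated without the project's fixtures. This is a minimal sketch only: `CudaGraphScope` is a hypothetical stand-in enum, `scope_contains` is a hypothetical helper (the real check lives inside `CommOverlapConfig` and is not shown here), and stdlib `unittest.subTest` stands in for the suggested `pytest.mark.parametrize`.

```python
import unittest
from enum import Enum


class CudaGraphScope(Enum):  # hypothetical stand-in, not the Megatron enum
    attn = "attn"
    mlp = "mlp"


def scope_contains(scope, member):
    """Hypothetical helper: normalize `scope` to a list, then test membership."""
    if scope is None:
        items = []
    elif isinstance(scope, (str, CudaGraphScope)):
        items = [scope]
    else:
        items = list(scope)
    # Compare by name so the raw string "attn" matches CudaGraphScope.attn.
    names = [i.name if isinstance(i, CudaGraphScope) else i for i in items]
    return member.name in names


class ScopeFormsTest(unittest.TestCase):
    def test_attn_detected_in_every_form(self):
        # The three forms named in the review: raw string, bare enum, list of enums.
        for scope in ("attn", CudaGraphScope.attn, [CudaGraphScope.attn]):
            with self.subTest(scope=scope):
                self.assertTrue(scope_contains(scope, CudaGraphScope.attn))

    def test_non_matching_forms_do_not_raise(self):
        for scope in (None, "full", [], [CudaGraphScope.mlp]):
            with self.subTest(scope=scope):
                self.assertFalse(scope_contains(scope, CudaGraphScope.attn))
```

Saved as a module, this can be run with `python -m unittest`. In the actual test suite the same three forms would be passed as `cuda_graph_scope` to the existing config fixtures instead of to a standalone helper.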
ℹ️ Review info
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (4)
- src/megatron/bridge/training/comm_overlap.py
- src/megatron/bridge/training/utils/flop_utils.py
- tests/unit_tests/training/test_comm_overlap.py
- tests/unit_tests/training/utils/test_flop_utils.py
/ok to test 6f87c74
Signed-off-by: yaoyu-33 <yaoyu.094@gmail.com>
/ok to test 39a4d2c
3 similar comments
/ok to test a31341d
Summary
- Fix `TypeError: 'in <string>' requires string as left operand, not CudaGraphScope` in `comm_overlap.py` delay_wgrad validation
- Normalize `cuda_graph_scope` to a list before performing `in` membership checks, handling cases where it's a string (`"full"`), a single `CudaGraphScope` enum, or `None`

Test plan
- `tests/unit_tests/training/test_comm_overlap.py::TestMegatronCommOverlapConfig::test_delay_wgrad_config_validation` should pass
- `tests/unit_tests/training/test_comm_overlap.py::TestMegatronCommOverlapConfig::test_delay_wgrad_config_validation_with_overlap_grad_reduce` should pass
- `test_comm_overlap.py` tests should continue to pass

Made with Cursor
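The normalization step described in the summary might look roughly like the following. This is a sketch under stated assumptions, not the code merged in this PR: `CudaGraphScope` is a hypothetical stand-in enum, `normalize_cuda_graph_scope` is a hypothetical name, and converting matching strings to enum members is an illustrative choice (the actual fix may simply wrap the value in a list).

```python
from enum import Enum


class CudaGraphScope(Enum):  # hypothetical stand-in, not the Megatron enum
    attn = 1
    mlp = 2


def normalize_cuda_graph_scope(scope):
    """Return `scope` as a list so that later `in` checks are list membership tests."""
    if scope is None:
        return []
    # A bare string or a single enum member becomes a one-element list.
    if isinstance(scope, (str, CudaGraphScope)):
        scope = [scope]
    normalized = []
    for item in scope:
        # Assumption: strings naming an enum member are converted; others are kept as-is.
        if isinstance(item, str) and item in CudaGraphScope.__members__:
            item = CudaGraphScope[item]
        normalized.append(item)
    return normalized


# Membership checks are now safe for every input form mentioned in the summary.
for raw in (None, "full", "attn", CudaGraphScope.attn, [CudaGraphScope.attn]):
    scopes = normalize_cuda_graph_scope(raw)
    print(repr(raw), "->", CudaGraphScope.attn in scopes)
```

Normalizing once at the top of the validation keeps every later `CudaGraphScope.X in scopes` check free of type-specific branches.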