Skip to content
35 changes: 35 additions & 0 deletions src/megatron/bridge/training/comm_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,41 @@ def _get_model_comm_overlap_cfgs(
"CUDA graph with delay_wgrad_compute does not support attention bias for now."
)

# CUDA graph scope-specific validations for delayed wgrad.
cuda_graph_scope = getattr(model_cfg, "cuda_graph_scope", None)
if cuda_graph_scope is None or cuda_graph_scope == "full":
cuda_graph_scope = []
elif isinstance(cuda_graph_scope, (str, CudaGraphScope)):
cuda_graph_scope = [cuda_graph_scope]
attn_scope_enabled = (
CudaGraphScope.attn in cuda_graph_scope
or CudaGraphScope.attn.value in cuda_graph_scope
or f"CudaGraphScope.{CudaGraphScope.attn.value}" in cuda_graph_scope
)
moe_router_scope_enabled = (
CudaGraphScope.moe_router in cuda_graph_scope
or CudaGraphScope.moe_router.value in cuda_graph_scope
or f"CudaGraphScope.{CudaGraphScope.moe_router.value}" in cuda_graph_scope
)
wgrad_in_graph_scope = attn_scope_enabled or (
moe_router_scope_enabled
and getattr(model_cfg, "moe_shared_expert_intermediate_size", None) is not None
and not getattr(model_cfg, "moe_shared_expert_overlap", False)
)
if wgrad_in_graph_scope:
assert is_te_min_version("2.12.0"), (
"CUDA graph with delay_wgrad_compute requires TE version >= 2.12.0."
)
assert model_cfg.gradient_accumulation_fusion, (
"CUDA graph with delay_wgrad_compute requires gradient_accumulation_fusion "
"to be enabled. This is because default gradient accumulation does not use "
"static memory addresses, which breaks CUDA graph requirements."
)
if attn_scope_enabled:
assert not model_cfg.add_bias_linear and not model_cfg.add_qkv_bias, (
"CUDA graph with delay_wgrad_compute does not support attention bias for now."
)

comm_overlap_cfg = self._override_user_cfgs(comm_overlap_cfg)
return comm_overlap_cfg

Expand Down
Loading