Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/megatron/bridge/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1482,6 +1482,13 @@ def validate(self) -> None:
print_rank_0("Gradient accumulation fusion is not supported with Megatron FSDP, setting to False")
self.model.gradient_accumulation_fusion = False

# reuse_grad_buf_for_mxfp8_param_ag is not supported with Megatron FSDP
if self.ddp.reuse_grad_buf_for_mxfp8_param_ag:
print_rank_0("reuse_grad_buf_for_mxfp8_param_ag is not supported with Megatron FSDP, setting to False")
self.ddp.reuse_grad_buf_for_mxfp8_param_ag = False
if self.optimizer.reuse_grad_buf_for_mxfp8_param_ag:
self.optimizer.reuse_grad_buf_for_mxfp8_param_ag = False

# ModelOpt/Quantization checks
if getattr(self.model, "restore_modelopt_state", False):
assert not self.model.gradient_accumulation_fusion, (
Expand Down
36 changes: 36 additions & 0 deletions tests/unit_tests/training/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,6 +1046,42 @@ def test_megatron_fsdp_config(self, monkeypatch):
finally:
restore_get_world_size_safe(og_ws, cfg_mod)

def test_megatron_fsdp_forces_reuse_grad_buf_false(self, monkeypatch):
    """Check that validation with Megatron FSDP clears reuse_grad_buf_for_mxfp8_param_ag.

    The DDP and optimizer configs are both created with the flag set to True.
    After ``validate()`` runs on the container, both are expected to read False.
    """
    # Helper calls are kept in the same order as the keyword arguments below.
    container_kwargs = {
        "model_config": create_test_gpt_config(),
        "train_config": create_test_training_config(train_iters=500, global_batch_size=16),
        "scheduler_config": create_test_scheduler_config(),
        "dist_config": create_test_distributed_init_config(use_megatron_fsdp=True),
        # Optimizer side starts with the flag enabled.
        "optimizer_config": create_test_optimizer_config(reuse_grad_buf_for_mxfp8_param_ag=True),
        # DDP side starts with the flag enabled as well; fp8_param_gather=True
        # is needed for reuse_grad_buf to pass the DDP config validation.
        "ddp_config": create_test_ddp_config(
            use_megatron_fsdp=True,
            reuse_grad_buf_for_mxfp8_param_ag=True,
            fp8_param_gather=True,
        ),
    }

    cfg_container, saved_world_size_fn, patched_module = create_test_config_container(
        world_size_override=1,
        **container_kwargs,
    )
    try:
        # Precondition: nothing has touched the flags before validation.
        assert cfg_container.ddp.reuse_grad_buf_for_mxfp8_param_ag is True
        assert cfg_container.optimizer.reuse_grad_buf_for_mxfp8_param_ag is True

        cfg_container.validate()

        # Postcondition: FSDP validation has switched both flags off.
        assert cfg_container.ddp.reuse_grad_buf_for_mxfp8_param_ag is False
        assert cfg_container.optimizer.reuse_grad_buf_for_mxfp8_param_ag is False
    finally:
        # Always undo the world-size override, even if an assertion fails.
        restore_get_world_size_safe(saved_world_size_fn, patched_module)

def test_megatron_fsdp_config_with_torch_fsdp2(self, monkeypatch):
"""Test MegatronFSDP config with torch_fsdp2, should raise ValueError."""
gpt_model_cfg = create_test_gpt_config()
Expand Down