diff --git a/src/megatron/bridge/training/config.py b/src/megatron/bridge/training/config.py
index 739669c6b1..408ec20b5d 100644
--- a/src/megatron/bridge/training/config.py
+++ b/src/megatron/bridge/training/config.py
@@ -1482,6 +1482,13 @@ def validate(self) -> None:
             print_rank_0("Gradient accumulation fusion is not supported with Megatron FSDP, setting to False")
             self.model.gradient_accumulation_fusion = False
 
+        # reuse_grad_buf_for_mxfp8_param_ag is not supported with Megatron FSDP
+        if self.ddp.reuse_grad_buf_for_mxfp8_param_ag:
+            print_rank_0("reuse_grad_buf_for_mxfp8_param_ag is not supported with Megatron FSDP, setting to False")
+            self.ddp.reuse_grad_buf_for_mxfp8_param_ag = False
+        if self.optimizer.reuse_grad_buf_for_mxfp8_param_ag:
+            self.optimizer.reuse_grad_buf_for_mxfp8_param_ag = False
+
         # ModelOpt/Quantization checks
         if getattr(self.model, "restore_modelopt_state", False):
             assert not self.model.gradient_accumulation_fusion, (
diff --git a/tests/unit_tests/training/test_config.py b/tests/unit_tests/training/test_config.py
index 0a6464c2cc..3bce9d61f7 100644
--- a/tests/unit_tests/training/test_config.py
+++ b/tests/unit_tests/training/test_config.py
@@ -1046,6 +1046,42 @@ def test_megatron_fsdp_config(self, monkeypatch):
         finally:
             restore_get_world_size_safe(og_ws, cfg_mod)
 
+    def test_megatron_fsdp_forces_reuse_grad_buf_false(self, monkeypatch):
+        """Test that Megatron FSDP forces reuse_grad_buf_for_mxfp8_param_ag=False on ddp and optimizer."""
+        gpt_model_cfg = create_test_gpt_config()
+        train_cfg = create_test_training_config(train_iters=500, global_batch_size=16)
+        sched_cfg = create_test_scheduler_config()
+        dist_cfg = create_test_distributed_init_config(use_megatron_fsdp=True)
+        # Create optimizer config with reuse_grad_buf_for_mxfp8_param_ag=True
+        optimizer_cfg = create_test_optimizer_config(reuse_grad_buf_for_mxfp8_param_ag=True)
+        # Create ddp config with reuse_grad_buf_for_mxfp8_param_ag=True
+        # fp8_param_gather=True is required for reuse_grad_buf in DDP config validation
+        ddp_cfg = create_test_ddp_config(
+            use_megatron_fsdp=True, reuse_grad_buf_for_mxfp8_param_ag=True, fp8_param_gather=True
+        )
+
+        container, og_ws, cfg_mod = create_test_config_container(
+            world_size_override=1,
+            model_config=gpt_model_cfg,
+            train_config=train_cfg,
+            scheduler_config=sched_cfg,
+            dist_config=dist_cfg,
+            optimizer_config=optimizer_cfg,
+            ddp_config=ddp_cfg,
+        )
+        try:
+            # Verify the values are True before validation
+            assert container.ddp.reuse_grad_buf_for_mxfp8_param_ag is True
+            assert container.optimizer.reuse_grad_buf_for_mxfp8_param_ag is True
+
+            container.validate()
+
+            # After validation, both should be forced to False due to FSDP
+            assert container.ddp.reuse_grad_buf_for_mxfp8_param_ag is False
+            assert container.optimizer.reuse_grad_buf_for_mxfp8_param_ag is False
+        finally:
+            restore_get_world_size_safe(og_ws, cfg_mod)
+
     def test_megatron_fsdp_config_with_torch_fsdp2(self, monkeypatch):
         """Test MegatronFSDP config with torch_fsdp2, should raise ValueError."""
         gpt_model_cfg = create_test_gpt_config()