diff --git a/src/accelerate/accelerator.py b/src/accelerate/accelerator.py index 152efd6ccd0..e3be2af35a4 100644 --- a/src/accelerate/accelerator.py +++ b/src/accelerate/accelerator.py @@ -382,7 +382,6 @@ def __init__( self.scaler = torch.cuda.amp.GradScaler(**kwargs) elif self.state.mixed_precision == "bf16" and self.distributed_type not in ( DistributedType.DEEPSPEED, - DistributedType.FSDP, DistributedType.MEGATRON_LM, ): if self.device.type == "cpu":