diff --git a/src/transformers/deepspeed.py b/src/transformers/deepspeed.py index 46cfb9730760..6feabdaa8095 100644 --- a/src/transformers/deepspeed.py +++ b/src/transformers/deepspeed.py @@ -250,7 +250,7 @@ def trainer_config_process(self, args): self.fill_match("bf16.enabled", (args.bf16 or args.bf16_full_eval), "bf16|bf16_full_eval") # deepspeed's default mode is fp16 unless there is a config that says differently - if self.is_true("bfoat16.enabled"): + if self.is_true("bf16.enabled"): self._dtype = torch.bfloat16 elif self.is_false("fp16.enabled"): self._dtype = torch.float32