From 0810050be53228f779e3d72f3594475e58376007 Mon Sep 17 00:00:00 2001 From: Jeffrey Huynh Date: Thu, 8 Dec 2022 05:32:46 +0000 Subject: [PATCH] Enable bf16 option for XLA devices --- src/transformers/trainer.py | 2 +- src/transformers/training_args.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 5c01cbd0427f..e566486fe09d 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -565,7 +565,7 @@ def __init__( logger.info(f"Using {args.half_precision_backend} half precision backend") self.do_grad_scaling = False - if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled()): + if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled() or is_torch_tpu_available()): # deepspeed and SageMaker Model Parallel manage their own half precision if args.half_precision_backend == "cuda_amp": self.use_cuda_amp = True diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index c8c0a4588888..01b009cb2adc 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1090,9 +1090,9 @@ def __post_init__(self): if self.bf16 or self.bf16_full_eval: - if self.no_cuda and not is_torch_bf16_cpu_available(): + if self.no_cuda and not is_torch_bf16_cpu_available() and not is_torch_tpu_available(): # cpu - raise ValueError("Your setup doesn't support bf16/cpu. You need torch>=1.10") + raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10") elif not self.no_cuda and torch.cuda.is_available() and not is_torch_bf16_gpu_available(): # gpu raise ValueError( @@ -1140,12 +1140,13 @@ def __post_init__(self): and is_torch_available() and (self.device.type != "cuda") and (get_xla_device_type(self.device) != "GPU") + and (get_xla_device_type(self.device) != "TPU") and (self.device.type != "cpu") and (self.bf16 or self.bf16_full_eval) ): raise ValueError( "BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation" - " (`--bf16_full_eval`) can only be used on CUDA or CPU devices." + " (`--bf16_full_eval`) can only be used on CUDA or CPU/TPU/NeuronCore devices." ) if self.framework == "pt" and is_torch_available() and self.torchdynamo is not None: