diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 409dd88d0c78..566b39e30454 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -32,7 +32,7 @@
 from torch.nn import CrossEntropyLoss
 
 from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
-from transformers.utils.import_utils import is_sagemaker_mp_enabled
+from transformers.utils.import_utils import ENV_VARS_TRUE_VALUES, is_sagemaker_mp_enabled
 
 from .activations import get_activation
 from .configuration_utils import PretrainedConfig
@@ -68,12 +68,16 @@
     is_offline_mode,
     is_remote_url,
     is_safetensors_available,
+    is_torch_tpu_available,
     logging,
     replace_return_docstrings,
 )
 from .utils.versions import require_version_core
 
 
+XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
+XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
+
 if is_accelerate_available():
     from accelerate import __version__ as accelerate_version
     from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
@@ -181,6 +185,17 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtil
     for t in parameter.parameters():
         last_dtype = t.dtype
         if t.is_floating_point():
+            # Adding fix for https://github.com/pytorch/xla/issues/4152
+            # Fixes issue where the model code passes a value that is out of range for XLA_USE_BF16=1
+            # and XLA_DOWNCAST_BF16=1 so the conversion would cast it to -inf
+            if is_torch_tpu_available():
+                if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES:
+                    return torch.bfloat16
+                if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES:
+                    if t.dtype == torch.float:
+                        return torch.bfloat16
+                    if t.dtype == torch.double:
+                        return torch.float32
             return t.dtype
 
     if last_dtype is not None:
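
Note on the behavior being patched: under torch_xla, XLA_USE_BF16=1 stores every float as bfloat16 on device and XLA_DOWNCAST_BF16=1 maps float32 to bfloat16 and float64 to float32, while host-side tensors still report their nominal dtypes. A constant such as torch.finfo(torch.float32).min falls just outside bfloat16's range and becomes -inf after the downcast (the pytorch/xla#4152 report referenced in the patch), which is why get_parameter_dtype now reports the effective dtype. The sketch below is illustrative, not library code: it reproduces only the env-var mapping, drops the is_torch_tpu_available() gate so it runs anywhere, and assumes ENV_VARS_TRUE_VALUES matches the set defined in transformers.utils.import_utils; the helper name xla_effective_dtype is made up for the example.

# Sketch only: mirrors the dtype mapping added to get_parameter_dtype, without the TPU gate.
import os

import torch

ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}  # assumed to match transformers.utils.import_utils


def xla_effective_dtype(t: torch.Tensor) -> torch.dtype:
    """Dtype a floating-point tensor is actually computed in under torch_xla's bf16 env flags."""
    if os.environ.get("XLA_USE_BF16", "0").upper() in ENV_VARS_TRUE_VALUES:
        return torch.bfloat16  # every float is bfloat16 on device
    if os.environ.get("XLA_DOWNCAST_BF16", "0").upper() in ENV_VARS_TRUE_VALUES:
        if t.dtype == torch.float:
            return torch.bfloat16  # float32 is downcast to bfloat16
        if t.dtype == torch.double:
            return torch.float32  # float64 is downcast to float32
    return t.dtype


os.environ["XLA_DOWNCAST_BF16"] = "1"
print(xla_effective_dtype(torch.zeros(1)))                      # torch.bfloat16
print(xla_effective_dtype(torch.zeros(1, dtype=torch.double)))  # torch.float32
# The overflow the patch guards against: float32 min is not representable in bfloat16.
print(torch.tensor(torch.finfo(torch.float32).min).to(torch.bfloat16))  # tensor(-inf, dtype=torch.bfloat16)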