17 changes: 16 additions & 1 deletion src/transformers/modeling_utils.py

@@ -32,7 +32,7 @@
 from torch.nn import CrossEntropyLoss

 from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
-from transformers.utils.import_utils import is_sagemaker_mp_enabled
+from transformers.utils.import_utils import ENV_VARS_TRUE_VALUES, is_sagemaker_mp_enabled

 from .activations import get_activation
 from .configuration_utils import PretrainedConfig

@@ -68,12 +68,16 @@
     is_offline_mode,
     is_remote_url,
     is_safetensors_available,
+    is_torch_tpu_available,
     logging,
     replace_return_docstrings,
 )
 from .utils.versions import require_version_core


+XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
+XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
+
 if is_accelerate_available():
     from accelerate import __version__ as accelerate_version
     from accelerate import dispatch_model, infer_auto_device_map, init_empty_weights
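
These two module-level flags mirror PyTorch/XLA's environment variables: `XLA_USE_BF16=1` casts floating-point tensors to bfloat16, while `XLA_DOWNCAST_BF16=1` downcasts float32 to bfloat16 and float64 to float32. A minimal sketch of how such a flag string resolves to a boolean (assuming `ENV_VARS_TRUE_VALUES` is the `{"1", "ON", "YES", "TRUE"}` set from `transformers.utils.import_utils`):

```python
# Sketch only: how an XLA_* flag string resolves to a boolean.
# ENV_VARS_TRUE_VALUES is assumed to match transformers.utils.import_utils.
import os

ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}

os.environ["XLA_USE_BF16"] = "yes"  # any of 1/on/yes/true (case-insensitive) enables it
XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()

print(XLA_USE_BF16 in ENV_VARS_TRUE_VALUES)  # True: the bf16 override is active
```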

@@ -181,6 +185,17 @@ def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]):
     for t in parameter.parameters():
         last_dtype = t.dtype
         if t.is_floating_point():
+            # Adding fix for https://github.com/pytorch/xla/issues/4152
+            # Fixes issue where the model code passes a value that is out of range for XLA_USE_BF16=1
+            # and XLA_DOWNCAST_BF16=1 so the conversion would cast it to -inf
+            if is_torch_tpu_available():
+                if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES:
+                    return torch.bfloat16
+                if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES:
+                    if t.dtype == torch.float:
+                        return torch.bfloat16
+                    if t.dtype == torch.double:
+                        return torch.float32
             return t.dtype

     if last_dtype is not None:
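
To illustrate the intended behavior (a sketch, not the library's test suite): with `XLA_DOWNCAST_BF16=1` on a TPU host, float32 parameters should report `torch.bfloat16` and float64 parameters `torch.float32`. A self-contained approximation of the decision logic, with the TPU check stubbed out via an `on_tpu` argument standing in for `is_torch_tpu_available()`:

```python
# Standalone sketch of the dtype-override logic in the diff above; `on_tpu`
# stands in for is_torch_tpu_available() and is assumed True for illustration.
import os
import torch

ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
os.environ["XLA_DOWNCAST_BF16"] = "1"

XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()


def reported_dtype(t: torch.Tensor, on_tpu: bool = True) -> torch.dtype:
    if t.is_floating_point() and on_tpu:
        if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES:
            return torch.bfloat16
        if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES:
            if t.dtype == torch.float:  # torch.float is float32
                return torch.bfloat16
            if t.dtype == torch.double:  # torch.double is float64
                return torch.float32
    return t.dtype


print(reported_dtype(torch.zeros(1, dtype=torch.float)))   # torch.bfloat16
print(reported_dtype(torch.zeros(1, dtype=torch.double)))  # torch.float32
```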