diff --git a/src/transformers/models/afmoe/modeling_afmoe.py b/src/transformers/models/afmoe/modeling_afmoe.py index 4ee42f8137c2..0d7f330d7918 100644 --- a/src/transformers/models/afmoe/modeling_afmoe.py +++ b/src/transformers/models/afmoe/modeling_afmoe.py @@ -37,7 +37,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_afmoe import AfmoeConfig @@ -97,7 +97,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/apertus/modeling_apertus.py b/src/transformers/models/apertus/modeling_apertus.py index 2425da1ea658..b6d0aa94cfbc 100644 --- a/src/transformers/models/apertus/modeling_apertus.py +++ b/src/transformers/models/apertus/modeling_apertus.py @@ -36,7 +36,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_apertus import ApertusConfig @@ -131,7 +131,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/arcee/modeling_arcee.py b/src/transformers/models/arcee/modeling_arcee.py index f9f869c90191..900a530a34b7 100644 --- a/src/transformers/models/arcee/modeling_arcee.py +++ b/src/transformers/models/arcee/modeling_arcee.py @@ -43,7 +43,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_arcee import ArceeConfig @@ -138,7 +138,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 9eca2c271f4f..a9ef9ea7cb28 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -38,7 +38,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from ..auto import AutoModel from .configuration_aria import AriaConfig, AriaTextConfig @@ -675,7 +675,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index ec6e5ad8aec6..9a5790c1fb6a 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -43,6 +43,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.generic import maybe_autocast from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available from .configuration_bamba import BambaConfig @@ -250,7 +251,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py index 05017cdb56b8..76532cd10982 100644 --- a/src/transformers/models/bitnet/modeling_bitnet.py +++ b/src/transformers/models/bitnet/modeling_bitnet.py @@ -36,7 +36,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_bitnet import BitNetConfig @@ -326,7 +326,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/blt/modeling_blt.py b/src/transformers/models/blt/modeling_blt.py index 294c0c2e99f0..b640a1b3db09 100644 --- a/src/transformers/models/blt/modeling_blt.py +++ b/src/transformers/models/blt/modeling_blt.py @@ -38,7 +38,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import OutputRecorder, check_model_inputs +from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast from .configuration_blt import ( BltConfig, BltGlobalTransformerConfig, @@ -141,7 +141,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.repeat_interleave(freqs, 2, dim=-1) # diff from Llama: we interleave() instead of cat() cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/blt/modular_blt.py b/src/transformers/models/blt/modular_blt.py index 66876d58e404..c8c0812b00c1 100644 --- a/src/transformers/models/blt/modular_blt.py +++ b/src/transformers/models/blt/modular_blt.py @@ -29,7 +29,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, logging -from ...utils.generic import OutputRecorder, check_model_inputs +from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast from ..cohere2.modeling_cohere2 import rotate_half # noqa: F401 from ..llama.modeling_llama import LlamaRotaryEmbedding from ..mllama.modeling_mllama import ( @@ -277,7 +277,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.repeat_interleave(freqs, 2, dim=-1) # diff from Llama: we interleave() instead of cat() cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 2e816ad63b4d..04abf453c815 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -38,6 +38,7 @@ can_return_tuple, logging, ) +from ...utils.generic import maybe_autocast from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig @@ -122,7 +123,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index a3df29f848e1..10c9c845695f 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -45,7 +45,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_cohere import CohereConfig @@ -122,7 +122,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.repeat_interleave(freqs, 2, dim=-1) # diff from Llama: we interleave() instead of cat() cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/cohere/modular_cohere.py b/src/transformers/models/cohere/modular_cohere.py index 5147e5638eb2..f3bce4003e6a 100644 --- a/src/transformers/models/cohere/modular_cohere.py +++ b/src/transformers/models/cohere/modular_cohere.py @@ -36,6 +36,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import TransformersKwargs, logging +from ...utils.generic import maybe_autocast from ..llama.modeling_llama import ( LlamaAttention, LlamaForCausalLM, @@ -75,7 +76,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.repeat_interleave(freqs, 2, dim=-1) # diff from Llama: we interleave() instead of cat() cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 9699844cfc6a..e398e363409b 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -36,7 +36,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_cohere2 import Cohere2Config @@ -96,7 +96,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.repeat_interleave(freqs, 2, dim=-1) # diff from Llama: we interleave() instead of cat() cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index 070040a2e426..2fa8c10bbc1f 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -30,6 +30,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import TransformersKwargs, logging +from ...utils.generic import maybe_autocast from ..cohere.modeling_cohere import ( CohereAttention, CohereDecoderLayer, @@ -222,7 +223,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.repeat_interleave(freqs, 2, dim=-1) # diff from Llama: we interleave() instead of cat() cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py index 73be86a3648b..d27a1cd19671 100644 --- a/src/transformers/models/csm/modeling_csm.py +++ b/src/transformers/models/csm/modeling_csm.py @@ -40,6 +40,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.generic import maybe_autocast from ...utils.import_utils import is_torchdynamo_compiling from ..auto import AutoModel from .configuration_csm import CsmConfig, CsmDepthDecoderConfig @@ -174,7 +175,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/cwm/modeling_cwm.py b/src/transformers/models/cwm/modeling_cwm.py index 838143d22cbb..4bfa7842294e 100644 --- a/src/transformers/models/cwm/modeling_cwm.py +++ b/src/transformers/models/cwm/modeling_cwm.py @@ -37,7 +37,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_cwm import CwmConfig @@ -97,7 +97,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index f3aa337828ab..459941267397 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -37,7 +37,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_dbrx import DbrxConfig @@ -97,7 +97,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 11fa8a88deb3..ecce72e40aa8 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -34,6 +34,7 @@ auto_docstring, logging, ) +from ...utils.generic import maybe_autocast from .configuration_decision_transformer import DecisionTransformerConfig @@ -141,7 +142,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None): scale_factor /= float(self.layer_idx + 1) # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) - with torch.autocast(query.device.type, enabled=False): + with maybe_autocast(query.device.type, enabled=False): q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) diff --git a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py index f7d1d12ef44e..c8de73dba293 100644 --- a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py +++ b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py @@ -38,7 +38,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_deepseek_v2 import DeepseekV2Config @@ -223,7 +223,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2) freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # Convert to complex representation freqs_cis = freqs_cis * self.attention_scaling diff --git a/src/transformers/models/deepseek_v2/modular_deepseek_v2.py b/src/transformers/models/deepseek_v2/modular_deepseek_v2.py index 510be8576374..d39c8db60036 100644 --- a/src/transformers/models/deepseek_v2/modular_deepseek_v2.py +++ b/src/transformers/models/deepseek_v2/modular_deepseek_v2.py @@ -25,6 +25,7 @@ from ...modeling_rope_utils import RopeParameters, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import logging +from ...utils.generic import maybe_autocast from ..llama.configuration_llama import LlamaConfig from ..llama.modeling_llama import ( LlamaDecoderLayer, @@ -303,7 +304,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2) freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # Convert to complex representation freqs_cis = freqs_cis * self.attention_scaling diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 97fe6cf3e843..0f6e93abefe2 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -29,7 +29,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_deepseek_v3 import DeepseekV3Config @@ -110,7 +110,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/dia/modeling_dia.py b/src/transformers/models/dia/modeling_dia.py index 85d363e0fde6..cf665cc3c473 100644 --- a/src/transformers/models/dia/modeling_dia.py +++ b/src/transformers/models/dia/modeling_dia.py @@ -41,6 +41,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils.generic import maybe_autocast from .configuration_dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig from .generation_dia import DiaGenerationMixin @@ -184,7 +185,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index 5149b670cab5..673925f7d819 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -46,7 +46,7 @@ from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_diffllama import DiffLlamaConfig @@ -125,7 +125,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py b/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py index 018b071d6179..2827524127e9 100644 --- a/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py +++ b/src/transformers/models/dinov3_vit/modeling_dinov3_vit.py @@ -36,7 +36,7 @@ from ...pytorch_utils import compile_compatible_method_lru_cache from ...utils import TransformersKwargs, auto_docstring, can_return_tuple from ...utils.backbone_utils import BackboneMixin -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_dinov3_vit import DINOv3ViTConfig @@ -156,7 +156,7 @@ def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tenso device = pixel_values.device device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 # Although we could precompute static patch_coords from image_size and patch_size in the config, # the model was trained with random_scale, so it can process images of varying sizes. # Therefore, it's better to compute patch_coords dynamically (with lru_cache). diff --git a/src/transformers/models/dinov3_vit/modular_dinov3_vit.py b/src/transformers/models/dinov3_vit/modular_dinov3_vit.py index 0e04f532ea65..1f6848269966 100644 --- a/src/transformers/models/dinov3_vit/modular_dinov3_vit.py +++ b/src/transformers/models/dinov3_vit/modular_dinov3_vit.py @@ -40,7 +40,7 @@ from ...pytorch_utils import compile_compatible_method_lru_cache from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging from ...utils.backbone_utils import BackboneMixin -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_dinov3_vit import DINOv3ViTConfig @@ -163,7 +163,7 @@ def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tenso device = pixel_values.device device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 # Although we could precompute static patch_coords from image_size and patch_size in the config, # the model was trained with random_scale, so it can process images of varying sizes. # Therefore, it's better to compute patch_coords dynamically (with lru_cache). diff --git a/src/transformers/models/doge/modeling_doge.py b/src/transformers/models/doge/modeling_doge.py index 57f4b9d88902..ccc3911b8c95 100644 --- a/src/transformers/models/doge/modeling_doge.py +++ b/src/transformers/models/doge/modeling_doge.py @@ -42,7 +42,7 @@ from ...modeling_utils import AttentionInterface, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available -from ...utils.generic import OutputRecorder, check_model_inputs +from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast from .configuration_doge import DogeConfig @@ -127,7 +127,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index 1ef3a6396fb6..c80090877c0f 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -38,7 +38,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_dots1 import Dots1Config @@ -119,7 +119,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/efficientloftr/modeling_efficientloftr.py b/src/transformers/models/efficientloftr/modeling_efficientloftr.py index a6711138ba49..3e464d7c22fd 100644 --- a/src/transformers/models/efficientloftr/modeling_efficientloftr.py +++ b/src/transformers/models/efficientloftr/modeling_efficientloftr.py @@ -33,7 +33,7 @@ can_return_tuple, torch_int, ) -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_efficientloftr import EfficientLoFTRConfig @@ -147,7 +147,7 @@ def forward( embed_height = (feats_height - self.config.q_aggregation_kernel_size) // self.config.q_aggregation_stride + 1 embed_width = (feats_width - self.config.q_aggregation_kernel_size) // self.config.q_aggregation_stride + 1 device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 emb = compute_embeddings(self.inv_freq, embed_height, embed_width, self.config.hidden_size) sin = emb.sin() cos = emb.cos() diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index d6654c69b614..3b3d7513e7ee 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -41,7 +41,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig @@ -1167,7 +1167,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/ernie4_5/modeling_ernie4_5.py b/src/transformers/models/ernie4_5/modeling_ernie4_5.py index deecb9e6940f..2a04ca804c49 100644 --- a/src/transformers/models/ernie4_5/modeling_ernie4_5.py +++ b/src/transformers/models/ernie4_5/modeling_ernie4_5.py @@ -35,7 +35,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_ernie4_5 import Ernie4_5Config @@ -95,7 +95,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/ernie4_5/modular_ernie4_5.py b/src/transformers/models/ernie4_5/modular_ernie4_5.py index 780b07164ec0..51c5d03a8a8b 100644 --- a/src/transformers/models/ernie4_5/modular_ernie4_5.py +++ b/src/transformers/models/ernie4_5/modular_ernie4_5.py @@ -18,6 +18,7 @@ from ...modeling_rope_utils import dynamic_rope_update from ...utils import auto_docstring, can_return_tuple +from ...utils.generic import maybe_autocast from ..glm.modeling_glm import rotate_half from ..llama.modeling_llama import ( LlamaAttention, @@ -36,7 +37,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py index 5bd3c7b896ad..11c4ef6b0a72 100644 --- a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +++ b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py @@ -37,7 +37,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import OutputRecorder, check_model_inputs +from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast from .configuration_ernie4_5_moe import Ernie4_5_MoeConfig @@ -135,7 +135,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling @@ -371,7 +371,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens else "cpu" ) - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 router_logits = F.linear(hidden_states.float(), self.weight) router_logits = F.softmax(router_logits, dim=1, dtype=torch.float) router_top_value, router_indices = torch.topk(self.moe_statics(router_logits), self.top_k, dim=-1) diff --git a/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py index 74feb925a7ce..2c27e1ad81de 100644 --- a/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +++ b/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py @@ -26,7 +26,7 @@ from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ...utils.generic import OutputRecorder, check_model_inputs +from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast from ..ernie4_5.modeling_ernie4_5 import Ernie4_5RotaryEmbedding, apply_rotary_pos_emb, rotate_half # noqa: F401 from ..llama.modeling_llama import LlamaAttention, LlamaRMSNorm from ..mixtral.modeling_mixtral import ( @@ -146,7 +146,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens else "cpu" ) - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 router_logits = F.linear(hidden_states.float(), self.weight) router_logits = F.softmax(router_logits, dim=1, dtype=torch.float) router_top_value, router_indices = torch.topk(self.moe_statics(router_logits), self.top_k, dim=-1) diff --git a/src/transformers/models/esm/modeling_esmfold.py b/src/transformers/models/esm/modeling_esmfold.py index 2d9f7a20eb1b..dc00161d478c 100644 --- a/src/transformers/models/esm/modeling_esmfold.py +++ b/src/transformers/models/esm/modeling_esmfold.py @@ -32,6 +32,7 @@ auto_docstring, logging, ) +from ...utils.generic import maybe_autocast from .modeling_esm import EsmModel, EsmPreTrainedModel from .openfold_utils import ( OFProtein, @@ -267,7 +268,7 @@ def __init__(self, c_in, eps=1e-5): def forward(self, x): d = x.dtype if d is torch.bfloat16 and not is_deepspeed_initialized(): - with torch.autocast(device_type="cuda", enabled=False): + with maybe_autocast(device_type="cuda", enabled=False): out = nn.functional.layer_norm(x, self.c_in, self.weight.to(dtype=d), self.bias.to(dtype=d), self.eps) else: out = nn.functional.layer_norm(x, self.c_in, self.weight, self.bias, self.eps) @@ -282,7 +283,7 @@ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor: """ d = t.dtype if d is torch.bfloat16 and not is_deepspeed_initialized(): - with torch.autocast(device_type="cuda", enabled=False): + with maybe_autocast(device_type="cuda", enabled=False): s = torch.nn.functional.softmax(t, dim=dim) else: s = torch.nn.functional.softmax(t, dim=dim) @@ -868,7 +869,7 @@ def forward( device_type = a.device.type if a.device.type != "mps" else "cpu" if is_fp16_enabled(device_type): - with torch.autocast(device_type=device_type, enabled=False): + with maybe_autocast(device_type=device_type, enabled=False): x = self._combine_projections(a.float(), b.float()) else: x = self._combine_projections(a, b) @@ -1491,7 +1492,7 @@ def forward( # [*, H, N_res, N_res] device_type = q.device.type if q.device.type != "mps" else "cpu" if is_fp16_enabled(device_type): - with torch.autocast(device_type=device_type, enabled=False): + with maybe_autocast(device_type=device_type, enabled=False): a = torch.matmul( permute_final_dims(q.float(), (1, 0, 2)), # [*, H, N_res, C_hidden] permute_final_dims(k.float(), (1, 2, 0)), # [*, H, C_hidden, N_res] diff --git a/src/transformers/models/evolla/modeling_evolla.py b/src/transformers/models/evolla/modeling_evolla.py index 81cca5d3bf9a..597e64e2370d 100644 --- a/src/transformers/models/evolla/modeling_evolla.py +++ b/src/transformers/models/evolla/modeling_evolla.py @@ -45,7 +45,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import OutputRecorder, check_model_inputs +from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast from .configuration_evolla import EvollaConfig, SaProtConfig @@ -1019,7 +1019,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/exaone4/modeling_exaone4.py b/src/transformers/models/exaone4/modeling_exaone4.py index ff083bd93a18..40e557c5fa6b 100644 --- a/src/transformers/models/exaone4/modeling_exaone4.py +++ b/src/transformers/models/exaone4/modeling_exaone4.py @@ -44,6 +44,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import maybe_autocast from .configuration_exaone4 import Exaone4Config @@ -124,7 +125,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 3d14f1d40df7..32380fb4df54 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -48,6 +48,7 @@ auto_docstring, logging, ) +from ...utils.generic import maybe_autocast from .configuration_falcon import FalconConfig @@ -160,7 +161,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index ba4f8ccbfd67..a52daaab1352 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -45,6 +45,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils.generic import maybe_autocast from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available from .configuration_falcon_h1 import FalconH1Config @@ -279,7 +280,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/flex_olmo/modeling_flex_olmo.py b/src/transformers/models/flex_olmo/modeling_flex_olmo.py index 8aeb014ff12b..dd57fe199486 100644 --- a/src/transformers/models/flex_olmo/modeling_flex_olmo.py +++ b/src/transformers/models/flex_olmo/modeling_flex_olmo.py @@ -38,7 +38,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import OutputRecorder, check_model_inputs +from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast from .configuration_flex_olmo import FlexOlmoConfig @@ -119,7 +119,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index 8c14f33fb46c..d70e2be87843 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -41,7 +41,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_gemma import GemmaConfig @@ -137,7 +137,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index fb3cfcc2f5db..c2c8c72f3605 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -42,7 +42,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_gemma2 import Gemma2Config @@ -138,7 +138,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index b991f2d2fc65..3150dd1772b2 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -34,6 +34,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import TransformersKwargs, logging +from ...utils.generic import maybe_autocast from ..gemma.modeling_gemma import ( GemmaAttention, GemmaForCausalLM, @@ -252,7 +253,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index df3277003398..60e24e26a0ca 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -39,7 +39,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from ..auto import AutoModel from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig @@ -214,7 +214,7 @@ def forward(self, x, position_ids, layer_type=None): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * attention_scaling diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index f17945d90931..2bf81619f1e5 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -33,6 +33,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.generic import maybe_autocast from ..gemma2.configuration_gemma2 import Gemma2Config from ..gemma2.modeling_gemma2 import ( Gemma2Attention, @@ -437,7 +438,7 @@ def forward(self, x, position_ids, layer_type=None): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * attention_scaling diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index b72b20b61261..8d95a5a367e1 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -40,7 +40,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from ..auto import AutoModel from .configuration_gemma3n import Gemma3nAudioConfig, Gemma3nConfig, Gemma3nTextConfig, Gemma3nVisionConfig @@ -1525,7 +1525,7 @@ def forward(self, x, position_ids, layer_type=None): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * attention_scaling diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index 52ba32f84267..51072f7a67c6 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -40,7 +40,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_glm import GlmConfig @@ -120,7 +120,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index 91cfcd7e4e65..3b1261a1ef27 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -41,7 +41,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_glm4 import Glm4Config @@ -325,7 +325,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/glm4_moe/modeling_glm4_moe.py b/src/transformers/models/glm4_moe/modeling_glm4_moe.py index f21a66f8c819..4b34c5987569 100644 --- a/src/transformers/models/glm4_moe/modeling_glm4_moe.py +++ b/src/transformers/models/glm4_moe/modeling_glm4_moe.py @@ -39,7 +39,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_glm4_moe import Glm4MoeConfig @@ -101,7 +101,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index 58c0ca50d17e..78c4c9581994 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -40,7 +40,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_glm4v import Glm4vConfig, Glm4vTextConfig, Glm4vVisionConfig @@ -446,7 +446,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index 7f81d03f8ac9..c6e832048cd5 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -36,7 +36,7 @@ from ...processing_utils import Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from ...video_utils import VideoInput from ..glm4.modeling_glm4 import Glm4MLP, Glm4RMSNorm, Glm4RotaryEmbedding, eager_attention_forward from ..qwen2_5_vl.modeling_qwen2_5_vl import ( @@ -511,7 +511,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index 5ef301974813..491ce7c3d3f4 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -41,7 +41,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling -from ...utils.generic import OutputRecorder, check_model_inputs +from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast from .configuration_glm4v_moe import Glm4vMoeConfig, Glm4vMoeTextConfig, Glm4vMoeVisionConfig @@ -150,7 +150,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 16a0d987c2f9..de9f9530d55c 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -45,6 +45,7 @@ auto_docstring, logging, ) +from ...utils.generic import maybe_autocast from .configuration_gpt2 import GPT2Config @@ -150,7 +151,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None): scale_factor /= float(self.layer_idx + 1) # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) - with torch.autocast(query.device.type, enabled=False): + with maybe_autocast(query.device.type, enabled=False): q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 7490e019e18d..ad4805cb5633 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -28,7 +28,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_gpt_neox import GPTNeoXConfig @@ -107,7 +107,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 811e628cc3a0..078a40c8de1f 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -30,6 +30,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils.generic import maybe_autocast from .configuration_gpt_neox_japanese import GPTNeoXJapaneseConfig @@ -116,7 +117,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index 544ece409a22..91a8546e0c9a 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -41,7 +41,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import OutputRecorder, check_model_inputs +from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast from .configuration_gpt_oss import GptOssConfig @@ -236,7 +236,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = freqs cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py index 9432ac28b0fb..7a5227946fe5 100644 --- a/src/transformers/models/gpt_oss/modular_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py @@ -34,7 +34,7 @@ auto_docstring, logging, ) -from ...utils.generic import OutputRecorder, check_model_inputs +from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast from ..llama.modeling_llama import ( LlamaDecoderLayer, LlamaPreTrainedModel, @@ -185,7 +185,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = freqs cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index ca67268b843b..66f9c38e7832 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -36,7 +36,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_granite import GraniteConfig @@ -376,7 +376,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 6cafb3a32eee..75d34787c5a1 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -38,7 +38,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring -from ...utils.generic import can_return_tuple, check_model_inputs +from ...utils.generic import can_return_tuple, check_model_inputs, maybe_autocast from .configuration_granitemoe import GraniteMoeConfig @@ -119,7 +119,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 9e45bd420f91..e96944c45f35 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -39,7 +39,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available from .configuration_granitemoehybrid import GraniteMoeHybridConfig @@ -954,7 +954,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index 425e36a5ed05..24cbe16f2052 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -38,7 +38,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring -from ...utils.generic import can_return_tuple, check_model_inputs +from ...utils.generic import can_return_tuple, check_model_inputs, maybe_autocast from .configuration_granitemoeshared import GraniteMoeSharedConfig @@ -533,7 +533,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index eb786565defe..e46e2f57fe05 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -41,7 +41,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_helium import HeliumConfig @@ -118,7 +118,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py index ad47eb08d7e4..6079a5aeedc3 100644 --- a/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +++ b/src/transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py @@ -38,7 +38,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_hunyuan_v1_dense import HunYuanDenseV1Config @@ -359,7 +359,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py index bbfd8cc58c47..458bd0f5723d 100644 --- a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +++ b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py @@ -38,7 +38,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_hunyuan_v1_moe import HunYuanMoEV1Config @@ -452,7 +452,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 0b2288f682c4..288d10748be0 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -38,6 +38,7 @@ logging, torch_float, ) +from ...utils.generic import maybe_autocast from .configuration_imagegpt import ImageGPTConfig @@ -150,7 +151,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None): scale_factor /= float(self.layer_idx + 1) # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk)) - with torch.autocast(query.device.type, enabled=False): + with maybe_autocast(query.device.type, enabled=False): q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len) attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor) attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len) diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 275f67592eb3..2a33f09bc6bd 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -38,7 +38,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ...utils.generic import OutputRecorder, check_model_inputs +from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast from .configuration_jetmoe import JetMoeConfig @@ -122,7 +122,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py index e43c548d57b9..7e8c37919957 100644 --- a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +++ b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py @@ -39,6 +39,7 @@ from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils.generic import maybe_autocast from ..auto import AutoModel from .configuration_kyutai_speech_to_text import KyutaiSpeechToTextConfig @@ -315,7 +316,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/lasr/modeling_lasr.py b/src/transformers/models/lasr/modeling_lasr.py index c7ef834845fe..df66da296f7c 100644 --- a/src/transformers/models/lasr/modeling_lasr.py +++ b/src/transformers/models/lasr/modeling_lasr.py @@ -35,7 +35,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_lasr import LasrCTCConfig, LasrEncoderConfig @@ -123,7 +123,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index fbd6eb8aa39b..d7d4716500ae 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -34,7 +34,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from ...utils.import_utils import is_causal_conv1d_available, is_torchdynamo_compiling from .configuration_lfm2 import Lfm2Config @@ -122,7 +122,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py index 6ad0ab9067ca..f90c3c207c39 100644 --- a/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py +++ b/src/transformers/models/lfm2_moe/modeling_lfm2_moe.py @@ -35,7 +35,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from ...utils.import_utils import is_causal_conv1d_available, is_torchdynamo_compiling from .configuration_lfm2_moe import Lfm2MoeConfig @@ -123,7 +123,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index c27684870f05..54ea8e0da271 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -42,7 +42,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_llama import LlamaConfig @@ -126,7 +126,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index 657b8e003dc2..9d45fa3f6c08 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -40,7 +40,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_llama4 import Llama4Config, Llama4TextConfig @@ -228,7 +228,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2) freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # Convert to complex representation freqs_cis = freqs_cis * self.attention_scaling diff --git a/src/transformers/models/longcat_flash/modeling_longcat_flash.py b/src/transformers/models/longcat_flash/modeling_longcat_flash.py index e7e541ba70ca..2246fbc71756 100644 --- a/src/transformers/models/longcat_flash/modeling_longcat_flash.py +++ b/src/transformers/models/longcat_flash/modeling_longcat_flash.py @@ -40,7 +40,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_longcat_flash import LongcatFlashConfig @@ -121,7 +121,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 2f7626f73bb5..be06fe36261c 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -32,6 +32,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging +from ...utils.generic import maybe_autocast from .configuration_mimi import MimiConfig @@ -559,7 +560,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index c24bd5eef573..a7c535888007 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -45,7 +45,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import OutputRecorder, check_model_inputs +from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast from .configuration_minimax import MiniMaxConfig @@ -310,7 +310,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/ministral/modeling_ministral.py b/src/transformers/models/ministral/modeling_ministral.py index eed97cc17d69..7b3aa97ff5ee 100644 --- a/src/transformers/models/ministral/modeling_ministral.py +++ b/src/transformers/models/ministral/modeling_ministral.py @@ -27,7 +27,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_ministral import MinistralConfig @@ -328,7 +328,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/ministral3/modeling_ministral3.py b/src/transformers/models/ministral3/modeling_ministral3.py index dad049e3634d..871d333f3430 100644 --- a/src/transformers/models/ministral3/modeling_ministral3.py +++ b/src/transformers/models/ministral3/modeling_ministral3.py @@ -29,6 +29,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import maybe_autocast from .configuration_ministral3 import Ministral3Config @@ -333,7 +334,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 0af1c53babe8..bcb8b007400d 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -29,6 +29,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import maybe_autocast from .configuration_mistral import MistralConfig @@ -323,7 +324,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 1b1c65215dc8..e41870213a4a 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -51,7 +51,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import OutputRecorder +from ...utils.generic import OutputRecorder, maybe_autocast from .configuration_mixtral import MixtralConfig @@ -208,7 +208,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 0dc1770cc139..d84e22c4d2be 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -37,7 +37,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging -from ...utils.generic import OutputRecorder, check_model_inputs +from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast from .configuration_mllama import MllamaConfig, MllamaTextConfig, MllamaVisionConfig @@ -781,7 +781,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index 8740d3ed4b5e..993d7e634614 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -45,6 +45,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, is_flash_attn_2_available, logging +from ...utils.generic import maybe_autocast from ...utils.import_utils import is_triton_available from .configuration_modernbert import ModernBertConfig @@ -316,7 +317,7 @@ def forward(self, x, position_ids, layer_type=None): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * attention_scaling diff --git a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py index ec9cc6fdd58a..066cd0e0539b 100644 --- a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +++ b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py @@ -39,7 +39,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_modernbert_decoder import ModernBertDecoderConfig @@ -168,7 +168,7 @@ def forward(self, x, position_ids, layer_type=None): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * attention_scaling diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index ca385796fd89..3e9d585f6ea6 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -46,6 +46,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import maybe_autocast from .configuration_moonshine import MoonshineConfig @@ -138,7 +139,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 771d3f341327..a825cc53379f 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -34,6 +34,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils.generic import maybe_autocast from ..auto.modeling_auto import AutoModel from .configuration_moshi import MoshiConfig, MoshiDepthConfig @@ -327,7 +328,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/nanochat/modeling_nanochat.py b/src/transformers/models/nanochat/modeling_nanochat.py index 65630762a1c3..f7d0d9261a9f 100644 --- a/src/transformers/models/nanochat/modeling_nanochat.py +++ b/src/transformers/models/nanochat/modeling_nanochat.py @@ -38,7 +38,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_nanochat import NanoChatConfig @@ -113,7 +113,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 7b4e94599d96..00a1d2f3ca18 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -45,6 +45,7 @@ ) from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils.generic import maybe_autocast from .configuration_nemotron import NemotronConfig @@ -87,7 +88,7 @@ def forward(self, input: Tensor) -> Tensor: args = _cast_if_autocast_enabled( device_type, input, self.normalized_shape, self.weight + 1, self.bias, self.eps ) - with torch.autocast(device_type=input.device.type, enabled=False): + with maybe_autocast(device_type=input.device.type, enabled=False): return F.layer_norm(*args) @@ -151,7 +152,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 2172462ccc28..8df79456c60a 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -42,7 +42,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_olmo import OlmoConfig @@ -132,7 +132,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/olmo/modular_olmo.py b/src/transformers/models/olmo/modular_olmo.py index 4f2796837cb5..5f60d4d57c57 100644 --- a/src/transformers/models/olmo/modular_olmo.py +++ b/src/transformers/models/olmo/modular_olmo.py @@ -29,6 +29,7 @@ from ...modeling_rope_utils import dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...utils import logging +from ...utils.generic import maybe_autocast from ..llama.modeling_llama import ( LlamaAttention, LlamaDecoderLayer, @@ -77,7 +78,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 130da226011e..6c0772e14556 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -43,7 +43,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_olmo2 import Olmo2Config @@ -124,7 +124,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/olmo3/modeling_olmo3.py b/src/transformers/models/olmo3/modeling_olmo3.py index 5fca8a74d449..61c9096d7748 100644 --- a/src/transformers/models/olmo3/modeling_olmo3.py +++ b/src/transformers/models/olmo3/modeling_olmo3.py @@ -38,7 +38,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_olmo3 import Olmo3Config @@ -332,7 +332,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index a890f6db5ff8..4e41a75fea5f 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -35,7 +35,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import OutputRecorder, check_model_inputs +from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast from .configuration_olmoe import OlmoeConfig @@ -116,7 +116,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py index 2d16b470838a..292638ee3f14 100644 --- a/src/transformers/models/oneformer/modeling_oneformer.py +++ b/src/transformers/models/oneformer/modeling_oneformer.py @@ -39,6 +39,7 @@ requires_backends, ) from ...utils.backbone_utils import load_backbone +from ...utils.generic import maybe_autocast from .configuration_oneformer import OneFormerConfig @@ -322,7 +323,7 @@ def forward(self, masks_queries_logits, class_queries_logits, mask_labels, class align_corners=False, ).squeeze(1) - with torch.autocast(device_type="cuda", enabled=False): + with maybe_autocast(device_type="cuda", enabled=False): pred_mask = pred_mask.float() target_mask = target_mask.float() diff --git a/src/transformers/models/parakeet/modeling_parakeet.py b/src/transformers/models/parakeet/modeling_parakeet.py index 672f63bb4b34..58909c249bb6 100644 --- a/src/transformers/models/parakeet/modeling_parakeet.py +++ b/src/transformers/models/parakeet/modeling_parakeet.py @@ -35,7 +35,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig @@ -88,7 +88,7 @@ def forward(self, hidden_states: torch.Tensor): if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps" else "cpu" ) - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) sin = freqs.sin() cos = freqs.cos() diff --git a/src/transformers/models/parakeet/modular_parakeet.py b/src/transformers/models/parakeet/modular_parakeet.py index 2424a4afadd9..bdafbfc3507f 100644 --- a/src/transformers/models/parakeet/modular_parakeet.py +++ b/src/transformers/models/parakeet/modular_parakeet.py @@ -29,7 +29,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from ..fastspeech2_conformer.modeling_fastspeech2_conformer import FastSpeech2ConformerConvolutionModule from ..llama.modeling_llama import LlamaAttention, eager_attention_forward from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig @@ -84,7 +84,7 @@ def forward(self, hidden_states: torch.Tensor): if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps" else "cpu" ) - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) sin = freqs.sin() cos = freqs.cos() diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index a7054a2bd989..2eb827778e37 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -46,6 +46,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils.generic import maybe_autocast from .configuration_persimmon import PersimmonConfig @@ -118,7 +119,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 23c060f0a11f..f58f3be5347c 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -25,7 +25,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_phi import PhiConfig @@ -90,7 +90,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 56d09ad58c18..aef959f42720 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -44,6 +44,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import maybe_autocast from .configuration_phi3 import Phi3Config @@ -123,7 +124,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index 27b312a7e389..b37888669ee9 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -47,7 +47,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, torch_int -from ...utils.generic import TransformersKwargs, check_model_inputs +from ...utils.generic import TransformersKwargs, check_model_inputs, maybe_autocast from .configuration_phi4_multimodal import Phi4MultimodalAudioConfig, Phi4MultimodalConfig, Phi4MultimodalVisionConfig @@ -602,7 +602,7 @@ def forward( # Temporarily disable autocast to avoid issue on bf16 tensors # Ref: https://github.com/pytorch/pytorch/issues/132715 - with torch.autocast(device_type=inputs_embeds.device.type, enabled=False): + with maybe_autocast(device_type=inputs_embeds.device.type, enabled=False): image_embeds = inputs_embeds.index_put( indices=positions_tuple, values=merged_img_set_tensor, accumulate=False ) @@ -1116,7 +1116,7 @@ def forward( merged_audio_embeds = merged_audio_embeds.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device) # Temporarily disable autocast to avoid issue on bf16 tensors # Ref: https://github.com/pytorch/pytorch/issues/132715 - with torch.autocast(device_type=inputs_embeds.device.type, enabled=False): + with maybe_autocast(device_type=inputs_embeds.device.type, enabled=False): audio_embeds = inputs_embeds.index_put( indices=positions_tuple, values=merged_audio_embeds, accumulate=False ) @@ -1500,7 +1500,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py index 2d1fd160e2ee..2f34aa160b25 100644 --- a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py @@ -37,7 +37,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, logging -from ...utils.generic import TransformersKwargs, check_model_inputs +from ...utils.generic import TransformersKwargs, check_model_inputs, maybe_autocast from ..phi3.configuration_phi3 import Phi3Config from ..phi3.modeling_phi3 import ( Phi3DecoderLayer, @@ -844,7 +844,7 @@ def forward( # Temporarily disable autocast to avoid issue on bf16 tensors # Ref: https://github.com/pytorch/pytorch/issues/132715 - with torch.autocast(device_type=inputs_embeds.device.type, enabled=False): + with maybe_autocast(device_type=inputs_embeds.device.type, enabled=False): image_embeds = inputs_embeds.index_put( indices=positions_tuple, values=merged_img_set_tensor, accumulate=False ) @@ -1358,7 +1358,7 @@ def forward( merged_audio_embeds = merged_audio_embeds.to(dtype=inputs_embeds.dtype, device=inputs_embeds.device) # Temporarily disable autocast to avoid issue on bf16 tensors # Ref: https://github.com/pytorch/pytorch/issues/132715 - with torch.autocast(device_type=inputs_embeds.device.type, enabled=False): + with maybe_autocast(device_type=inputs_embeds.device.type, enabled=False): audio_embeds = inputs_embeds.index_put( indices=positions_tuple, values=merged_audio_embeds, accumulate=False ) diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 752af7a19257..60b0ba979eeb 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -38,7 +38,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import OutputRecorder, check_model_inputs +from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast from .configuration_phimoe import PhimoeConfig @@ -113,7 +113,7 @@ def forward(self, x, position_ids=None, layer_type=None): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * mscale diff --git a/src/transformers/models/phimoe/modular_phimoe.py b/src/transformers/models/phimoe/modular_phimoe.py index 76693282256a..774190308c87 100644 --- a/src/transformers/models/phimoe/modular_phimoe.py +++ b/src/transformers/models/phimoe/modular_phimoe.py @@ -24,7 +24,7 @@ GenericForSequenceClassification, ) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS -from ...utils.generic import OutputRecorder +from ...utils.generic import OutputRecorder, maybe_autocast from ..llama.modeling_llama import LlamaAttention from ..mixtral.modeling_mixtral import ( MixtralDecoderLayer, @@ -74,7 +74,7 @@ def forward(self, x, position_ids=None, layer_type=None): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * mscale diff --git a/src/transformers/models/pixtral/modeling_pixtral.py b/src/transformers/models/pixtral/modeling_pixtral.py index ef861b9ebdc1..64cf923b11cd 100644 --- a/src/transformers/models/pixtral/modeling_pixtral.py +++ b/src/transformers/models/pixtral/modeling_pixtral.py @@ -28,6 +28,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, logging +from ...utils.generic import maybe_autocast from .configuration_pixtral import PixtralVisionConfig @@ -125,7 +126,7 @@ def compute_default_rope_parameters( def forward(self, x, position_ids): freqs = self.inv_freq[position_ids] device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 emb = freqs cos = emb.cos() sin = emb.sin() diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 1d4f6a6baf6b..a2d67f93a9be 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -27,7 +27,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_qwen2 import Qwen2Config @@ -103,7 +103,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 17124a5b4b66..656471617677 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -43,6 +43,7 @@ from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, check_torch_load_is_safe, logging from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import maybe_autocast from ...utils.hub import cached_file from ..qwen2.modeling_qwen2 import Qwen2RMSNorm from .configuration_qwen2_5_omni import ( @@ -1291,7 +1292,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling @@ -2559,7 +2560,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index ab377b745f5e..80a39b96b411 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -43,6 +43,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.generic import maybe_autocast from ..qwen2.modeling_qwen2 import Qwen2RMSNorm from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig @@ -547,7 +548,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index f5d484dd3ef7..cc67efff6bd9 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -48,7 +48,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import OutputRecorder, check_model_inputs +from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast from .configuration_qwen2_moe import Qwen2MoeConfig @@ -129,7 +129,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 335439a1d3a1..f99d6b1c4990 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -44,6 +44,7 @@ can_return_tuple, logging, ) +from ...utils.generic import maybe_autocast from ..qwen2.modeling_qwen2 import ( Qwen2RMSNorm, ) @@ -164,7 +165,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index b0f8084f8da6..bac3690fcef1 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -42,7 +42,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_qwen3 import Qwen3Config @@ -139,7 +139,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 66ceff7ae5e8..5700813e5dac 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -44,7 +44,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import OutputRecorder, check_model_inputs +from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast from .configuration_qwen3_moe import Qwen3MoeConfig @@ -440,7 +440,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py index 8c066f2de40a..f67d83282696 100644 --- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py +++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py @@ -44,7 +44,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ...utils.generic import OutputRecorder, check_model_inputs +from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast from ...utils.import_utils import ( is_causal_conv1d_available, is_flash_linear_attention_available, @@ -233,7 +233,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py index 213230d660b9..dc9c02824ebb 100644 --- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py @@ -50,7 +50,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple -from ...utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs +from ...utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs, maybe_autocast from .configuration_qwen3_omni_moe import ( Qwen3OmniMoeAudioEncoderConfig, Qwen3OmniMoeCode2WavConfig, @@ -1291,7 +1291,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) emb = torch.cat((freqs, freqs), dim=-1) @@ -2516,7 +2516,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py index 0938d540742a..41438c1a37b2 100644 --- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py @@ -39,7 +39,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_qwen3_vl import Qwen3VLConfig, Qwen3VLTextConfig, Qwen3VLVisionConfig @@ -337,7 +337,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) emb = torch.cat((freqs, freqs), dim=-1) diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py index 20585db81039..8bf9572baca1 100644 --- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py +++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py @@ -35,7 +35,7 @@ from ...processing_utils import ProcessingKwargs, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import auto_docstring, can_return_tuple, logging -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from ...video_utils import VideoInput from ..llama.modeling_llama import LlamaRotaryEmbedding from ..qwen2_5_vl.modeling_qwen2_5_vl import ( @@ -389,7 +389,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) emb = torch.cat((freqs, freqs), dim=-1) diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py index 07bdadc9b18f..ca59338c1ae6 100644 --- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py @@ -40,7 +40,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import OutputRecorder, check_model_inputs +from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast from .configuration_qwen3_vl_moe import Qwen3VLMoeConfig, Qwen3VLMoeTextConfig, Qwen3VLMoeVisionConfig @@ -860,7 +860,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) emb = torch.cat((freqs, freqs), dim=-1) diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index 63573208d36b..2da5adcb27f0 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -31,6 +31,7 @@ from ...modeling_rope_utils import dynamic_rope_update from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging +from ...utils.generic import maybe_autocast from ...utils.import_utils import is_tracing from .configuration_recurrent_gemma import RecurrentGemmaConfig @@ -121,7 +122,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/seed_oss/modeling_seed_oss.py b/src/transformers/models/seed_oss/modeling_seed_oss.py index 9126606a68fb..c3dd76357e2e 100644 --- a/src/transformers/models/seed_oss/modeling_seed_oss.py +++ b/src/transformers/models/seed_oss/modeling_seed_oss.py @@ -40,7 +40,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_seed_oss import SeedOssConfig @@ -350,7 +350,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/smollm3/modeling_smollm3.py b/src/transformers/models/smollm3/modeling_smollm3.py index 05445e812855..2a9107bc67e9 100644 --- a/src/transformers/models/smollm3/modeling_smollm3.py +++ b/src/transformers/models/smollm3/modeling_smollm3.py @@ -42,7 +42,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_smollm3 import SmolLM3Config @@ -102,7 +102,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 5412dd600c82..2328585f34b1 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -45,6 +45,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils.generic import maybe_autocast from .configuration_stablelm import StableLmConfig @@ -117,7 +118,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index ce504c5947ac..dcd71a74d47c 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -48,6 +48,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import maybe_autocast from .configuration_starcoder2 import Starcoder2Config @@ -327,7 +328,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py index a7be0f0e465f..ca07431ec3fd 100644 --- a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -45,7 +45,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ...utils.generic import OutputRecorder, check_model_inputs +from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast from .configuration_t5gemma import T5GemmaConfig, T5GemmaModuleConfig @@ -147,7 +147,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/t5gemma2/modeling_t5gemma2.py b/src/transformers/models/t5gemma2/modeling_t5gemma2.py index 09935adf509f..4f5cca0edec5 100644 --- a/src/transformers/models/t5gemma2/modeling_t5gemma2.py +++ b/src/transformers/models/t5gemma2/modeling_t5gemma2.py @@ -46,7 +46,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import OutputRecorder, check_model_inputs +from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast from ..auto import AutoModel from .configuration_t5gemma2 import T5Gemma2Config, T5Gemma2DecoderConfig, T5Gemma2EncoderConfig, T5Gemma2TextConfig @@ -162,7 +162,7 @@ def forward(self, x, position_ids, layer_type=None): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * attention_scaling diff --git a/src/transformers/models/vaultgemma/modeling_vaultgemma.py b/src/transformers/models/vaultgemma/modeling_vaultgemma.py index 4be80f4573d3..95171197fdd8 100644 --- a/src/transformers/models/vaultgemma/modeling_vaultgemma.py +++ b/src/transformers/models/vaultgemma/modeling_vaultgemma.py @@ -38,7 +38,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple -from ...utils.generic import check_model_inputs +from ...utils.generic import check_model_inputs, maybe_autocast from .configuration_vaultgemma import VaultGemmaConfig @@ -336,7 +336,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 6079b846188f..a6bec1ff724b 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -41,6 +41,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, logging +from ...utils.generic import maybe_autocast from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available from .configuration_zamba2 import Zamba2Config @@ -263,7 +264,7 @@ def forward(self, x, position_ids): position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 + with maybe_autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos() * self.attention_scaling diff --git a/src/transformers/utils/generic.py b/src/transformers/utils/generic.py index aef86d382a76..fc3b04549414 100644 --- a/src/transformers/utils/generic.py +++ b/src/transformers/utils/generic.py @@ -21,7 +21,7 @@ import warnings from collections import OrderedDict, UserDict, defaultdict from collections.abc import Callable, Iterable, MutableMapping -from contextlib import AbstractContextManager, ExitStack +from contextlib import AbstractContextManager, ExitStack, nullcontext from dataclasses import dataclass, fields, is_dataclass from enum import Enum from functools import partial, wraps @@ -42,6 +42,7 @@ if is_torch_available(): # required for @can_return_tuple decorator to work with torchdynamo import torch + from torch.types import _dtype from ..model_debugging_utils import model_addition_debugger_context @@ -154,6 +155,28 @@ def is_torch_dtype(x): return isinstance(x, torch.dtype) +def maybe_autocast( + device_type: str, + dtype: Optional["_dtype"] = None, + enabled: bool = True, + cache_enabled: Optional[bool] = None, +): + """ + Context manager that only autocasts if: + + - `autocast` is already enabled in this context + - Or this call to `maybe_autocast` has `enabled=True` + + This prevents `autocast` being added to the graph when it is effectively a no-op. + Which makes graph splitting in `torch.compile` more flexible as it removes the + requirement that partition IDs be monotonically increasing. + """ + if torch.is_autocast_enabled(device_type) or enabled: + return torch.autocast(device_type, dtype=dtype, enabled=enabled, cache_enabled=cache_enabled) + else: + return nullcontext() + + def _is_mlx(x): import mlx.core as mx