src/transformers/models/afmoe/modeling_afmoe.py (4 changes: 2 additions & 2 deletions)

@@ -37,7 +37,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_afmoe import AfmoeConfig

@@ -97,7 +97,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
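The same substitution repeats in every file below: the float32-forcing `torch.autocast(device_type=..., enabled=False)` context in each model's rotary-embedding forward (and, in one model, an upcast attention path) is routed through a `maybe_autocast` helper imported from `...utils.generic`. The helper's definition is not part of this diff; as a rough mental model only, a guarded wrapper along these lines would match the call sites. The fallback behavior here is an assumption, not the actual transformers implementation:

# Hypothetical sketch of maybe_autocast; the real helper lives in
# src/transformers/utils/generic.py and may differ.
from contextlib import nullcontext

import torch


def maybe_autocast(device_type=None, enabled=True, **kwargs):
    """Return a torch.autocast context when the device type supports it,
    and a no-op context otherwise (assumed fallback)."""
    if device_type is None:
        return nullcontext()
    try:
        return torch.autocast(device_type=device_type, enabled=enabled, **kwargs)
    except RuntimeError:
        # Assumed: torch.autocast rejects unsupported device types at
        # construction time, so fall back to running without autocast.
        return nullcontext()

With this shape, every `with torch.autocast(device_type=device_type, enabled=False):` site can be swapped mechanically for `with maybe_autocast(device_type=device_type, enabled=False):`, which is exactly what each hunk below does.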
src/transformers/models/apertus/modeling_apertus.py (4 changes: 2 additions & 2 deletions)

@@ -36,7 +36,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_apertus import ApertusConfig

@@ -131,7 +131,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
src/transformers/models/arcee/modeling_arcee.py (4 changes: 2 additions & 2 deletions)

@@ -43,7 +43,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_arcee import ArceeConfig

@@ -138,7 +138,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
src/transformers/models/aria/modeling_aria.py (4 changes: 2 additions & 2 deletions)

@@ -38,7 +38,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from ..auto import AutoModel
 from .configuration_aria import AriaConfig, AriaTextConfig

@@ -675,7 +675,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
src/transformers/models/bamba/modeling_bamba.py (3 changes: 2 additions & 1 deletion)

@@ -43,6 +43,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.generic import maybe_autocast
 from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
 from .configuration_bamba import BambaConfig

@@ -250,7 +251,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
src/transformers/models/bitnet/modeling_bitnet.py (4 changes: 2 additions & 2 deletions)

@@ -36,7 +36,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_bitnet import BitNetConfig

@@ -326,7 +326,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
src/transformers/models/blt/modeling_blt.py (4 changes: 2 additions & 2 deletions)

@@ -38,7 +38,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_blt import (
     BltConfig,
     BltGlobalTransformerConfig,

@@ -141,7 +141,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.repeat_interleave(freqs, 2, dim=-1)  # diff from Llama: we interleave() instead of cat()
             cos = emb.cos() * self.attention_scaling
src/transformers/models/blt/modular_blt.py (4 changes: 2 additions & 2 deletions)

@@ -29,7 +29,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, logging
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from ..cohere2.modeling_cohere2 import rotate_half  # noqa: F401
 from ..llama.modeling_llama import LlamaRotaryEmbedding
 from ..mllama.modeling_mllama import (

@@ -277,7 +277,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.repeat_interleave(freqs, 2, dim=-1)  # diff from Llama: we interleave() instead of cat()
             cos = emb.cos() * self.attention_scaling
src/transformers/models/chameleon/modeling_chameleon.py (3 changes: 2 additions & 1 deletion)

@@ -38,6 +38,7 @@
     can_return_tuple,
     logging,
 )
+from ...utils.generic import maybe_autocast
 from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig

@@ -122,7 +123,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
src/transformers/models/cohere/modeling_cohere.py (4 changes: 2 additions & 2 deletions)

@@ -45,7 +45,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_cohere import CohereConfig

@@ -122,7 +122,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.repeat_interleave(freqs, 2, dim=-1)  # diff from Llama: we interleave() instead of cat()
             cos = emb.cos() * self.attention_scaling
src/transformers/models/cohere/modular_cohere.py (3 changes: 2 additions & 1 deletion)

@@ -36,6 +36,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, logging
+from ...utils.generic import maybe_autocast
 from ..llama.modeling_llama import (
     LlamaAttention,
     LlamaForCausalLM,

@@ -75,7 +76,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.repeat_interleave(freqs, 2, dim=-1)  # diff from Llama: we interleave() instead of cat()
             cos = emb.cos() * self.attention_scaling
src/transformers/models/cohere2/modeling_cohere2.py (4 changes: 2 additions & 2 deletions)

@@ -36,7 +36,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_cohere2 import Cohere2Config

@@ -96,7 +96,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.repeat_interleave(freqs, 2, dim=-1)  # diff from Llama: we interleave() instead of cat()
             cos = emb.cos() * self.attention_scaling
src/transformers/models/cohere2/modular_cohere2.py (3 changes: 2 additions & 1 deletion)

@@ -30,6 +30,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, logging
+from ...utils.generic import maybe_autocast
 from ..cohere.modeling_cohere import (
     CohereAttention,
     CohereDecoderLayer,

@@ -222,7 +223,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.repeat_interleave(freqs, 2, dim=-1)  # diff from Llama: we interleave() instead of cat()
             cos = emb.cos() * self.attention_scaling
src/transformers/models/csm/modeling_csm.py (3 changes: 2 additions & 1 deletion)

@@ -40,6 +40,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging
+from ...utils.generic import maybe_autocast
 from ...utils.import_utils import is_torchdynamo_compiling
 from ..auto import AutoModel
 from .configuration_csm import CsmConfig, CsmDepthDecoderConfig

@@ -174,7 +175,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
src/transformers/models/cwm/modeling_cwm.py (4 changes: 2 additions & 2 deletions)

@@ -37,7 +37,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_cwm import CwmConfig

@@ -97,7 +97,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
src/transformers/models/dbrx/modeling_dbrx.py (4 changes: 2 additions & 2 deletions)

@@ -37,7 +37,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_dbrx import DbrxConfig

@@ -97,7 +97,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
src/transformers/models/decision_transformer/modeling_decision_transformer.py

@@ -34,6 +34,7 @@
     auto_docstring,
     logging,
 )
+from ...utils.generic import maybe_autocast
 from .configuration_decision_transformer import DecisionTransformerConfig

@@ -141,7 +142,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None):
             scale_factor /= float(self.layer_idx + 1)

         # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
-        with torch.autocast(query.device.type, enabled=False):
+        with maybe_autocast(query.device.type, enabled=False):
             q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
             attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
             attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
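Note that this call site, unlike the RoPE hunks, passes the device type positionally (`maybe_autocast(query.device.type, enabled=False)`), so the helper evidently accepts `device_type` as its first positional parameter, mirroring `torch.autocast`.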
src/transformers/models/deepseek_v2/modeling_deepseek_v2.py (4 changes: 2 additions & 2 deletions)

@@ -38,7 +38,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_deepseek_v2 import DeepseekV2Config

@@ -223,7 +223,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2)
             freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # Convert to complex representation
             freqs_cis = freqs_cis * self.attention_scaling
src/transformers/models/deepseek_v2/modular_deepseek_v2.py (3 changes: 2 additions & 1 deletion)

@@ -25,6 +25,7 @@
 from ...modeling_rope_utils import RopeParameters, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...utils import logging
+from ...utils.generic import maybe_autocast
 from ..llama.configuration_llama import LlamaConfig
 from ..llama.modeling_llama import (
     LlamaDecoderLayer,

@@ -303,7 +304,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.to(x.device) @ position_ids_expanded).transpose(1, 2)
             freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # Convert to complex representation
             freqs_cis = freqs_cis * self.attention_scaling
src/transformers/models/deepseek_v3/modeling_deepseek_v3.py (4 changes: 2 additions & 2 deletions)

@@ -29,7 +29,7 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import check_model_inputs
+from ...utils.generic import check_model_inputs, maybe_autocast
 from .configuration_deepseek_v3 import DeepseekV3Config

@@ -110,7 +110,7 @@ def forward(self, x, position_ids):
         position_ids_expanded = position_ids[:, None, :].float()

         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
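Since the swap is meant to be behavior-preserving on supported devices, a quick smoke test (not part of the PR, and assuming `maybe_autocast` is importable from `transformers.utils.generic`, the path every hunk above uses) is that the guarded RoPE math still comes out in float32 even under an active autocast region:

import torch

from transformers.utils.generic import maybe_autocast  # same import path the hunks use

# Hypothetical check: run the RoPE computation from the hunks above under
# bf16 autocast and confirm the enabled=False guard still forces float32,
# matching the old torch.autocast behavior.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 64, 2).float() / 64))
position_ids = torch.arange(16)[None, :]

inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    with maybe_autocast(device_type="cpu", enabled=False):  # Force float32
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)

assert emb.dtype == torch.float32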