Skip to content
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1548,3 +1548,7 @@ def is_sliding(self):
@property
def is_compileable(self) -> bool:
return self.self_attention_cache.is_compileable


# Deprecated alias: SlidingWindowCache was removed in transformers v5. StaticCache is the replacement.
SlidingWindowCache = StaticCache
5 changes: 4 additions & 1 deletion src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,10 @@ def __post_init__(self, **kwargs):
def __init_subclass__(cls, *args, **kwargs):
super().__init_subclass__(*args, **kwargs)
cls_has_custom_init = "__init__" in cls.__dict__
cls = dataclass(cls, repr=False)
# kw_only=True ensures fields without defaults in subclasses can follow
# parent fields that have defaults (Python dataclass ordering rule).
# Config fields are always passed as keyword arguments, so this is safe.
cls = dataclass(cls, repr=False, kw_only=True)

if not cls_has_custom_init:
# Wrap all subclasses to accept arbitrary kwargs for BC
Expand Down
31 changes: 31 additions & 0 deletions src/transformers/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,3 +300,34 @@ def empty_func(*args, **kwargs):
finally:
# Set back the original
PreTrainedModel.tie_weights = original_tie_weights


@contextmanager
def meta_device_safe_creation_ops():
"""
During meta-device model initialisation, ``torch.linspace`` produces meta
tensors that have no data. Custom models loaded from the Hub (remote code)
often call ``.item()`` on these tensors to compute scalar hyperparameters
(e.g. stochastic-depth / drop-path schedules). Native transformers models
already pass ``device="cpu"`` explicitly for such calls (see e.g.
``modeling_swin.py``, ``modeling_pvt_v2.py``), but remote-code models
written before v5 do not.

This context manager patches ``torch.linspace`` to default to
``device="cpu"`` when no explicit device is requested, matching the best
practice already used throughout transformers. Calls that supply an
explicit ``device`` argument (e.g. ``device=self.logits.device``) are left
untouched. ``torch.arange`` is intentionally NOT patched because it is
used in RoPE computations where the device must match model parameters.
"""
original_linspace = torch.linspace

def _safe_linspace(*args, **kwargs):
kwargs.setdefault("device", "cpu")
return original_linspace(*args, **kwargs)

torch.linspace = _safe_linspace
try:
yield
finally:
torch.linspace = original_linspace
28 changes: 19 additions & 9 deletions src/transformers/modeling_rope_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,34 +801,44 @@ def validate_rope(self: "PreTrainedConfig"):
)

def _validate_default_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None):
required_keys = {"rope_type", "rope_theta"}
required_keys = {"rope_type"}
optional_keys = {"rope_theta"}

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

likely needs to be done elsewhere then too? I think all need rope theta?

Comment on lines +805 to +806

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, this kinda defeats the point of validation because a RoPE dict with no theta isn't valid for our modules

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah but we always default to default_theta if its not there no?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

validation always happens after teh defaults are set, so ideally it shouldn't raise an error. Do we know why the theta was missing?

received_keys = set(rope_parameters.keys())
rope_type = rope_parameters["rope_type"]
self._check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
self._check_received_keys(
rope_type, received_keys, required_keys, optional_keys=optional_keys, ignore_keys=ignore_keys
)

def _validate_linear_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None):
required_keys = {"rope_type", "factor", "rope_theta"}
required_keys = {"rope_type", "factor"}
optional_keys = {"rope_theta"}
received_keys = set(rope_parameters.keys())
rope_type = rope_parameters["rope_type"]
self._check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
self._check_received_keys(
rope_type, received_keys, required_keys, optional_keys=optional_keys, ignore_keys=ignore_keys
)

factor = rope_parameters["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}")

def _validate_dynamic_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None):
required_keys = {"rope_type", "factor", "rope_theta"}
required_keys = {"rope_type", "factor"}
optional_keys = {"rope_theta"}
received_keys = set(rope_parameters.keys())
rope_type = rope_parameters["rope_type"]
self._check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
self._check_received_keys(
rope_type, received_keys, required_keys, optional_keys=optional_keys, ignore_keys=ignore_keys
)

factor = rope_parameters["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}")

def _validate_yarn_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None):
required_keys = {"rope_type", "factor", "rope_theta", "original_max_position_embeddings"}
required_keys = {"rope_type", "factor", "original_max_position_embeddings"}
optional_keys = {
"rope_theta",
"attention_factor",
"beta_fast",
"beta_slow",
Expand Down Expand Up @@ -878,8 +888,8 @@ def _validate_yarn_rope_parameters(self, rope_parameters: dict, ignore_keys: set
)

def _validate_longrope_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None):
required_keys = {"rope_type", "short_factor", "long_factor", "rope_theta", "original_max_position_embeddings"}
optional_keys = {"attention_factor", "factor"}
required_keys = {"rope_type", "short_factor", "long_factor", "original_max_position_embeddings"}
optional_keys = {"rope_theta", "attention_factor", "factor"}
received_keys = set(rope_parameters.keys())
rope_type = rope_parameters["rope_type"]
self._check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
Expand Down
7 changes: 5 additions & 2 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3618,7 +3618,10 @@ def get_init_context(
elif is_quantized:
init_contexts.extend([torch.device("meta"), set_quantized_state()])
else:
init_contexts.append(torch.device("meta"))
# meta_device_safe_creation_ops patches torch.linspace to default to CPU
# so that custom models calling .item() during __init__ (e.g. drop-path
# schedules) don't crash on meta tensors.
init_contexts.extend([torch.device("meta"), init.meta_device_safe_creation_ops()])

return init_contexts

Expand Down Expand Up @@ -4612,7 +4615,7 @@ def mark_tied_weights_as_initialized(self, loading_info):
later as they will be tied (overwritten) anyway.
This is very important as most embeddings are tied, and they are huge params (vocabularies are often 256k), so
running inits on them is very costly."""
for tied_param in self.all_tied_weights_keys.keys():
for tied_param in getattr(self, "all_tied_weights_keys", {}).keys():

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix remote code

param = self.get_parameter(tied_param)
param._is_hf_initialized = True

Expand Down
15 changes: 6 additions & 9 deletions src/transformers/models/aimv2/modeling_aimv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,15 +258,12 @@ def forward(
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Input shape: Batch x Time x Channel"""

batch_size, seq_length, embed_dim = hidden_states.shape
input_shape = hidden_states.shape[:-1]

queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)

queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
hidden_shape = (*input_shape, -1, self.head_dim)
queries = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
keys = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
values = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
Expand All @@ -283,7 +280,7 @@ def forward(
dropout=0.0 if not self.training else self.dropout,
)

attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.out_proj(attn_output)

return attn_output, attn_weights
Expand Down
11 changes: 6 additions & 5 deletions src/transformers/models/altclip/modeling_altclip.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,15 +455,16 @@ def forward(
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Input shape: Batch x Time x Channel"""

batch_size, seq_length, embed_dim = hidden_states.shape
input_shape = hidden_states.shape[:-1]

hidden_shape = (*input_shape, -1, self.head_dim)
queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)

queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
queries = queries.view(hidden_shape).transpose(1, 2)
keys = keys.view(hidden_shape).transpose(1, 2)
values = values.view(hidden_shape).transpose(1, 2)

attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
Expand All @@ -480,7 +481,7 @@ def forward(
**kwargs,
)

attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights

Expand Down
21 changes: 10 additions & 11 deletions src/transformers/models/audioflamingo3/modeling_audioflamingo3.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,18 +130,15 @@ def forward(
# for the decoder
is_cross_attention = key_value_states is not None

# determine input shapes
bsz, tgt_len = hidden_states.shape[:-1]
q_input_shape = (bsz, tgt_len, -1, self.head_dim)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)

# Scaling is susceptible to floating point arithmetics' inprecisions
# which can lead to different results (this is dependent from model
# to model, e.g. audioflamingo3 is one such case). We therefore keep the
# original order of scaling to follow the original implementation
# and enforce no scaling (1.0) in the attention call below.
query_states = self.q_proj(hidden_states) * self.scaling
query_states = query_states.view(*q_input_shape)
query_states = query_states.transpose(1, 2).contiguous()
query_states = (self.q_proj(hidden_states) * self.scaling).view(hidden_shape).transpose(1, 2).contiguous()

# Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache):
Expand All @@ -160,10 +157,12 @@ def forward(
key_states = past_key_values.layers[self.layer_idx].keys
value_states = past_key_values.layers[self.layer_idx].values
else:
key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
key_states = key_states.transpose(1, 2).contiguous()
value_states = value_states.transpose(1, 2).contiguous()
# Use the query's batch dimension for kv view so that a different-batch
# encoder output (e.g. in tests) gets absorbed into the sequence axis,
# preserving backward-compatible behaviour.
kv_shape = (input_shape[0], -1, self.num_heads, self.head_dim)
Comment thread
ArthurZucker marked this conversation as resolved.
key_states = self.k_proj(current_states).view(kv_shape).transpose(1, 2).contiguous()
value_states = self.v_proj(current_states).view(kv_shape).transpose(1, 2).contiguous()
if past_key_values is not None:
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

Expand All @@ -183,7 +182,7 @@ def forward(
**kwargs,
)

attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.out_proj(attn_output)

return attn_output, attn_weights
Expand Down
15 changes: 7 additions & 8 deletions src/transformers/models/bart/modeling_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,14 +199,12 @@ def forward(
is_cross_attention = key_value_states is not None

# determine input shapes
bsz, tgt_len = hidden_states.shape[:-1]
src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
input_shape = hidden_states.shape[:-1]

q_input_shape = (bsz, tgt_len, -1, self.head_dim)
kv_input_shape = (bsz, src_len, -1, self.head_dim)
hidden_shape = (*input_shape, -1, self.head_dim)

# get query proj
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)

is_updated = False
if past_key_values is not None:
Expand All @@ -228,8 +226,9 @@ def forward(
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)
key_states = key_states.view(*kv_input_shape).transpose(1, 2)
value_states = value_states.view(*kv_input_shape).transpose(1, 2)
kv_shape = (*current_states.shape[:-1], -1, self.head_dim)
key_states = key_states.view(kv_shape).transpose(1, 2)
value_states = value_states.view(kv_shape).transpose(1, 2)

if past_key_values is not None:
key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
Expand All @@ -252,7 +251,7 @@ def forward(
**kwargs,
)

attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.out_proj(attn_output)

return attn_output, attn_weights
Expand Down
42 changes: 10 additions & 32 deletions src/transformers/models/beit/modeling_beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,22 +254,11 @@ def forward(
interpolate_pos_encoding: bool = False,
resolution: tuple[int] | None = None,
) -> tuple[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]:
batch_size, seq_length, _ = hidden_states.shape
query_layer = (
self.query(hidden_states)
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)
key_layer = (
self.key(hidden_states)
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)
value_layer = (
self.value(hidden_states)
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.attention_head_size)
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
Expand Down Expand Up @@ -320,22 +309,11 @@ def forward(
f"{self.__class__.__name__} does not support `output_attentions=True`. The returned attention weights will "
"be `None`. If you want to get attention weights, please set `attn_implementation='eager'` when loading the model."
)
batch_size, seq_length, _ = hidden_states.shape
query_layer = (
self.query(hidden_states)
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)
key_layer = (
self.key(hidden_states)
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)
value_layer = (
self.value(hidden_states)
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.attention_head_size)
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

attn_bias = None
if self.has_relative_position_bias:
Expand Down
15 changes: 7 additions & 8 deletions src/transformers/models/bert/modeling_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,23 +240,22 @@ def forward(
**kwargs: Unpack[TransformersKwargs],
) -> tuple[torch.Tensor]:
# determine input shapes
bsz, tgt_len = hidden_states.shape[:-1]
src_len = encoder_hidden_states.shape[1]
input_shape = hidden_states.shape[:-1]

q_input_shape = (bsz, tgt_len, -1, self.attention_head_size)
kv_input_shape = (bsz, src_len, -1, self.attention_head_size)
hidden_shape = (*input_shape, -1, self.attention_head_size)

# get query proj
query_layer = self.query(hidden_states).view(*q_input_shape).transpose(1, 2)
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)

is_updated = past_key_values.is_updated.get(self.layer_idx) if past_key_values is not None else False
if past_key_values is not None and is_updated:
# reuse k,v, cross_attentions
key_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].keys
value_layer = past_key_values.cross_attention_cache.layers[self.layer_idx].values
else:
key_layer = self.key(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2)
value_layer = self.value(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2)
kv_shape = (*encoder_hidden_states.shape[:-1], -1, self.attention_head_size)
key_layer = self.key(encoder_hidden_states).view(kv_shape).transpose(1, 2)
value_layer = self.value(encoder_hidden_states).view(kv_shape).transpose(1, 2)

if past_key_values is not None:
# save all states to the cache
Expand All @@ -280,7 +279,7 @@ def forward(
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
return attn_output, attn_weights


Expand Down
Loading
Loading