Skip to content
Merged
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1548,3 +1548,7 @@ def is_sliding(self):
@property
def is_compileable(self) -> bool:
return self.self_attention_cache.is_compileable


# Deprecated alias: SlidingWindowCache was removed in transformers v5. StaticCache is the replacement.
SlidingWindowCache = StaticCache
5 changes: 4 additions & 1 deletion src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,10 @@ def __post_init__(self, **kwargs):
def __init_subclass__(cls, *args, **kwargs):
super().__init_subclass__(*args, **kwargs)
cls_has_custom_init = "__init__" in cls.__dict__
cls = dataclass(cls, repr=False)
# kw_only=True ensures fields without defaults in subclasses can follow
# parent fields that have defaults (Python dataclass ordering rule).
# Config fields are always passed as keyword arguments, so this is safe.
cls = dataclass(cls, repr=False, kw_only=True)

if not cls_has_custom_init:
# Wrap all subclasses to accept arbitrary kwargs for BC
Expand Down
31 changes: 31 additions & 0 deletions src/transformers/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,3 +300,34 @@ def empty_func(*args, **kwargs):
finally:
# Set back the original
PreTrainedModel.tie_weights = original_tie_weights


@contextmanager
def meta_device_safe_creation_ops():
"""
During meta-device model initialisation, ``torch.linspace`` produces meta
tensors that have no data. Custom models loaded from the Hub (remote code)
often call ``.item()`` on these tensors to compute scalar hyperparameters
(e.g. stochastic-depth / drop-path schedules). Native transformers models
already pass ``device="cpu"`` explicitly for such calls (see e.g.
``modeling_swin.py``, ``modeling_pvt_v2.py``), but remote-code models
written before v5 do not.

This context manager patches ``torch.linspace`` to default to
``device="cpu"`` when no explicit device is requested, matching the best
practice already used throughout transformers. Calls that supply an
explicit ``device`` argument (e.g. ``device=self.logits.device``) are left
untouched. ``torch.arange`` is intentionally NOT patched because it is
used in RoPE computations where the device must match model parameters.
"""
original_linspace = torch.linspace

def _safe_linspace(*args, **kwargs):
kwargs.setdefault("device", "cpu")
return original_linspace(*args, **kwargs)

torch.linspace = _safe_linspace
try:
yield
finally:
torch.linspace = original_linspace
68 changes: 65 additions & 3 deletions src/transformers/modeling_rope_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import math
import warnings
from collections.abc import Callable
Expand Down Expand Up @@ -626,10 +627,47 @@ def _compute_llama3_parameters(
return inv_freq_llama, attention_factor


def _compute_default_rope_parameters(
config: Optional["PreTrainedConfig"] = None,
device: Optional["torch.device"] = None,
seq_len: int | None = None,
layer_type: str | None = None,
) -> tuple["torch.Tensor", float]:
"""
Computes the inverse frequencies for the default RoPE implementation (no scaling).

Args:
config ([`~transformers.PreTrainedConfig`]):
The model configuration.
device (`torch.device`):
The device to use for initialization of the inverse frequencies.
seq_len (`int`, *optional*):
The current sequence length. Unused for this type of RoPE.
layer_type (`str`, *optional*):
The layer type for per-layer rope configs.

Returns:
Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
"""
config.standardize_rope_params()
rope_parameters_dict = config.rope_parameters[layer_type] if layer_type is not None else config.rope_parameters

base = rope_parameters_dict.get("rope_theta", getattr(config, "rope_theta", config.default_theta))
partial_rotary_factor = rope_parameters_dict.get("partial_rotary_factor", 1.0)
head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
dim = int(head_dim * partial_rotary_factor)
attention_factor = 1.0 # Unused in this type of RoPE

inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim))
return inv_freq, attention_factor


# This maps the "rope_type" string field in rope config to the corresponding function to compute the RoPE parameters
# from the model config. You can append new {'rope_type': callable} pairs to this rope_parameters to enable custom RoPE
# parameterizations, as long as the callable has the same signature.
ROPE_INIT_FUNCTIONS: dict[str, Callable[..., tuple["torch.Tensor", float]]] = {
"default": _compute_default_rope_parameters,

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remote code BC?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @zucchini-nlp

It is true that remote code won't have it, but we likely would also need to refactor a lot of models, seems risky; especially for models that do have a different default init so we need to check if some code exists first and then use it as fallback

"linear": _compute_linear_scaling_rope_parameters,
"dynamic": _compute_dynamic_ntk_parameters,
"yarn": _compute_yarn_parameters,
Expand Down Expand Up @@ -771,10 +809,25 @@ def standardize_rope_params(self):

self.rope_parameters = rope_parameters

def validate_rope(self: "PreTrainedConfig"):
def validate_rope(self: "PreTrainedConfig", **kwargs):
"""
Validate the RoPE config arguments, given a `"PreTrainedConfig"` object

Note: the `ignore_keys` keyword argument is accepted for backward compatibility with external libraries
(e.g. vllm) but is deprecated. Set `config.ignore_keys_at_rope_validation` directly instead.
"""
if kwargs:
import warnings

warnings.warn(
"Passing keyword arguments to `validate_rope()` is deprecated. "
"Set `config.ignore_keys_at_rope_validation` directly instead.",
FutureWarning,
stacklevel=2,
)
ignore_keys = kwargs.pop("ignore_keys", None)
if ignore_keys is not None:
self.ignore_keys_at_rope_validation = self.ignore_keys_at_rope_validation | ignore_keys
# Don't validate if no rope_parameters found (`None`) or if it's an empty dict
# Note that validation runs every time a new config is created, even if config is non-RoPE
rope_parameters_dict = getattr(self, "rope_parameters", None)
Expand All @@ -800,11 +853,20 @@ def validate_rope(self: "PreTrainedConfig"):
f"Missing validation function in 'RotaryEmbeddingConfigMixin' for 'rope_type'='{rope_type}'"
)

# Override __signature__ so that @strict dataclass validation (huggingface_hub) sees only `self`.
# The method still accepts **kwargs for backward compatibility with external callers (e.g. vllm).
validate_rope.__signature__ = inspect.Signature(
[inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)]
)

def _validate_default_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None):
required_keys = {"rope_type", "rope_theta"}
required_keys = {"rope_type"}
optional_keys = {"rope_theta"}

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

likely needs to be done elsewhere then too? I think all need rope theta?

Comment on lines +805 to +806

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, this kinda defeats the point of validation because a RoPE dict with no theta isn't valid for our modules

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah but we always default to default_theta if its not there no?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

validation always happens after teh defaults are set, so ideally it shouldn't raise an error. Do we know why the theta was missing?

received_keys = set(rope_parameters.keys())
rope_type = rope_parameters["rope_type"]
self._check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
self._check_received_keys(
rope_type, received_keys, required_keys, optional_keys=optional_keys, ignore_keys=ignore_keys
)

def _validate_linear_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None):
required_keys = {"rope_type", "factor", "rope_theta"}
Expand Down
7 changes: 5 additions & 2 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3618,7 +3618,10 @@ def get_init_context(
elif is_quantized:
init_contexts.extend([torch.device("meta"), set_quantized_state()])
else:
init_contexts.append(torch.device("meta"))
# meta_device_safe_creation_ops patches torch.linspace to default to CPU
# so that custom models calling .item() during __init__ (e.g. drop-path
# schedules) don't crash on meta tensors.
init_contexts.extend([torch.device("meta"), init.meta_device_safe_creation_ops()])

return init_contexts

Expand Down Expand Up @@ -4612,7 +4615,7 @@ def mark_tied_weights_as_initialized(self, loading_info):
later as they will be tied (overwritten) anyway.
This is very important as most embeddings are tied, and they are huge params (vocabularies are often 256k), so
running inits on them is very costly."""
for tied_param in self.all_tied_weights_keys.keys():
for tied_param in getattr(self, "all_tied_weights_keys", {}).keys():

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix remote code

param = self.get_parameter(tied_param)
param._is_hf_initialized = True

Expand Down
15 changes: 6 additions & 9 deletions src/transformers/models/aimv2/modeling_aimv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,15 +258,12 @@ def forward(
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Input shape: Batch x Time x Channel"""

batch_size, seq_length, embed_dim = hidden_states.shape
input_shape = hidden_states.shape[:-1]

queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)

queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
hidden_shape = (*input_shape, -1, self.head_dim)
queries = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
keys = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
values = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
Expand All @@ -283,7 +280,7 @@ def forward(
dropout=0.0 if not self.training else self.dropout,
)

attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.out_proj(attn_output)

return attn_output, attn_weights
Expand Down
11 changes: 6 additions & 5 deletions src/transformers/models/altclip/modeling_altclip.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,15 +455,16 @@ def forward(
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Input shape: Batch x Time x Channel"""

batch_size, seq_length, embed_dim = hidden_states.shape
input_shape = hidden_states.shape[:-1]

hidden_shape = (*input_shape, -1, self.head_dim)
queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)

queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
queries = queries.view(hidden_shape).transpose(1, 2)
keys = keys.view(hidden_shape).transpose(1, 2)
values = values.view(hidden_shape).transpose(1, 2)

attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
Expand All @@ -480,7 +481,7 @@ def forward(
**kwargs,
)

attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights

Expand Down
21 changes: 10 additions & 11 deletions src/transformers/models/audioflamingo3/modeling_audioflamingo3.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,18 +130,15 @@ def forward(
# for the decoder
is_cross_attention = key_value_states is not None

# determine input shapes
bsz, tgt_len = hidden_states.shape[:-1]
q_input_shape = (bsz, tgt_len, -1, self.head_dim)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)

# Scaling is susceptible to floating point arithmetics' inprecisions
# which can lead to different results (this is dependent from model
# to model, e.g. audioflamingo3 is one such case). We therefore keep the
# original order of scaling to follow the original implementation
# and enforce no scaling (1.0) in the attention call below.
query_states = self.q_proj(hidden_states) * self.scaling
query_states = query_states.view(*q_input_shape)
query_states = query_states.transpose(1, 2).contiguous()
query_states = (self.q_proj(hidden_states) * self.scaling).view(hidden_shape).transpose(1, 2).contiguous()

# Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache):
Expand All @@ -160,10 +157,12 @@ def forward(
key_states = past_key_values.layers[self.layer_idx].keys
value_states = past_key_values.layers[self.layer_idx].values
else:
key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
key_states = key_states.transpose(1, 2).contiguous()
value_states = value_states.transpose(1, 2).contiguous()
# Use the query's batch dimension for kv view so that a different-batch
# encoder output (e.g. in tests) gets absorbed into the sequence axis,
# preserving backward-compatible behaviour.
kv_shape = (input_shape[0], -1, self.num_heads, self.head_dim)
Comment thread
ArthurZucker marked this conversation as resolved.
key_states = self.k_proj(current_states).view(kv_shape).transpose(1, 2).contiguous()
value_states = self.v_proj(current_states).view(kv_shape).transpose(1, 2).contiguous()
if past_key_values is not None:
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

Expand All @@ -183,7 +182,7 @@ def forward(
**kwargs,
)

attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.out_proj(attn_output)

return attn_output, attn_weights
Expand Down
15 changes: 7 additions & 8 deletions src/transformers/models/bart/modeling_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,14 +199,12 @@ def forward(
is_cross_attention = key_value_states is not None

# determine input shapes
bsz, tgt_len = hidden_states.shape[:-1]
src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
input_shape = hidden_states.shape[:-1]

q_input_shape = (bsz, tgt_len, -1, self.head_dim)
kv_input_shape = (bsz, src_len, -1, self.head_dim)
hidden_shape = (*input_shape, -1, self.head_dim)

# get query proj
query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)

is_updated = False
if past_key_values is not None:
Expand All @@ -228,8 +226,9 @@ def forward(
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)
key_states = key_states.view(*kv_input_shape).transpose(1, 2)
value_states = value_states.view(*kv_input_shape).transpose(1, 2)
kv_shape = (*current_states.shape[:-1], -1, self.head_dim)
key_states = key_states.view(kv_shape).transpose(1, 2)
value_states = value_states.view(kv_shape).transpose(1, 2)

if past_key_values is not None:
key_states, value_states = curr_past_key_values.update(key_states, value_states, self.layer_idx)
Expand All @@ -252,7 +251,7 @@ def forward(
**kwargs,
)

attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.out_proj(attn_output)

return attn_output, attn_weights
Expand Down
42 changes: 10 additions & 32 deletions src/transformers/models/beit/modeling_beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,22 +254,11 @@ def forward(
interpolate_pos_encoding: bool = False,
resolution: tuple[int] | None = None,
) -> tuple[torch.Tensor] | tuple[torch.Tensor, torch.Tensor]:
batch_size, seq_length, _ = hidden_states.shape
query_layer = (
self.query(hidden_states)
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)
key_layer = (
self.key(hidden_states)
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)
value_layer = (
self.value(hidden_states)
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.attention_head_size)
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

# Take the dot product between "query" and "key" to get the raw attention scores.
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
Expand Down Expand Up @@ -320,22 +309,11 @@ def forward(
f"{self.__class__.__name__} does not support `output_attentions=True`. The returned attention weights will "
"be `None`. If you want to get attention weights, please set `attn_implementation='eager'` when loading the model."
)
batch_size, seq_length, _ = hidden_states.shape
query_layer = (
self.query(hidden_states)
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)
key_layer = (
self.key(hidden_states)
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)
value_layer = (
self.value(hidden_states)
.view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
.transpose(1, 2)
)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.attention_head_size)
query_layer = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
key_layer = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
value_layer = self.value(hidden_states).view(hidden_shape).transpose(1, 2)

attn_bias = None
if self.has_relative_position_bias:
Expand Down
Loading
Loading