Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/en/internal/rope_utils.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ from transformers import LlamaConfig
config = LlamaConfig()
config.rope_parameters = {
"rope_type": "default", # type of RoPE to use
"rope_theta": 10000.0 # base frequency parameter
# rope_theta is optional — omitting it uses the model’s default_theta (typically 10000.0)
}

# If we want to apply a scaled RoPE type, we need to pass extra parameters
config.rope_parameters = {
"rope_type": "linear",
"rope_theta": 10000.0,
"rope_theta": 10000.0, # can be omitted to fall back to default_theta
"factor": 8.0 # scale factor for context extension
}
```
Expand Down
4 changes: 4 additions & 0 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1548,3 +1548,7 @@ def is_sliding(self):
@property
def is_compileable(self) -> bool:
return self.self_attention_cache.is_compileable


# Deprecated alias: SlidingWindowCache was removed in transformers v5. StaticCache is the replacement.
SlidingWindowCache = StaticCache
5 changes: 4 additions & 1 deletion src/transformers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,10 @@ def __post_init__(self, **kwargs):
def __init_subclass__(cls, *args, **kwargs):
super().__init_subclass__(*args, **kwargs)
cls_has_custom_init = "__init__" in cls.__dict__
cls = dataclass(cls, repr=False)
# kw_only=True ensures fields without defaults in subclasses can follow
# parent fields that have defaults (Python dataclass ordering rule).
# Config fields are always passed as keyword arguments, so this is safe.
cls = dataclass(cls, repr=False, kw_only=True)

if not cls_has_custom_init:
# Wrap all subclasses to accept arbitrary kwargs for BC
Expand Down
31 changes: 31 additions & 0 deletions src/transformers/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,3 +300,34 @@ def empty_func(*args, **kwargs):
finally:
# Set back the original
PreTrainedModel.tie_weights = original_tie_weights


@contextmanager
def meta_device_safe_creation_ops():
"""
During meta-device model initialisation, ``torch.linspace`` produces meta
tensors that have no data. Custom models loaded from the Hub (remote code)
often call ``.item()`` on these tensors to compute scalar hyperparameters
(e.g. stochastic-depth / drop-path schedules). Native transformers models
already pass ``device="cpu"`` explicitly for such calls (see e.g.
``modeling_swin.py``, ``modeling_pvt_v2.py``), but remote-code models
written before v5 do not.

This context manager patches ``torch.linspace`` to default to
``device="cpu"`` when no explicit device is requested, matching the best
practice already used throughout transformers. Calls that supply an
explicit ``device`` argument (e.g. ``device=self.logits.device``) are left
untouched. ``torch.arange`` is intentionally NOT patched because it is
used in RoPE computations where the device must match model parameters.
"""
original_linspace = torch.linspace

def _safe_linspace(*args, **kwargs):
kwargs.setdefault("device", "cpu")
return original_linspace(*args, **kwargs)

torch.linspace = _safe_linspace
try:
yield
finally:
torch.linspace = original_linspace
47 changes: 29 additions & 18 deletions src/transformers/modeling_rope_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def _compute_linear_scaling_rope_parameters(
The model configuration. This function assumes that the config will provide at least the following
properties:

* rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived.
* rope_theta (`float`, *optional*): The base wavelength from which the inverse frequencies will be derived. Defaults to `config.default_theta` if omitted.
* hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly.
* num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly.

Expand Down Expand Up @@ -199,7 +199,7 @@ def _compute_proportional_rope_parameters(
The model configuration. This function assumes that the config will provide at least the following
properties:

* rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived.
* rope_theta (`float`, *optional*): The base wavelength from which the inverse frequencies will be derived. Defaults to `config.default_theta` if omitted.
* hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly.
* num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly.

Expand Down Expand Up @@ -268,7 +268,7 @@ def _compute_dynamic_ntk_parameters(
The model configuration. This function assumes that the config will provide at least the following
properties:

* rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived.
* rope_theta (`float`, *optional*): The base wavelength from which the inverse frequencies will be derived. Defaults to `config.default_theta` if omitted.
* hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly.
* num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly.
* max_position_embeddings (`int`): The default sequence length used to update the dynamic RoPE at
Expand Down Expand Up @@ -339,7 +339,7 @@ def _compute_yarn_parameters(
The model configuration. This function assumes that the config will provide at least the following
properties:

* rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived.
* rope_theta (`float`, *optional*): The base wavelength from which the inverse frequencies will be derived. Defaults to `config.default_theta` if omitted.
* hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly.
* num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly.
* max_position_embeddings (`int`): The maximum length of the positional embeddings.
Expand Down Expand Up @@ -474,7 +474,7 @@ def _compute_longrope_parameters(
The model configuration. This function assumes that the config will provide at least the following
properties:

* rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived.
* rope_theta (`float`, *optional*): The base wavelength from which the inverse frequencies will be derived. Defaults to `config.default_theta` if omitted.
* hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly.
* num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly.
* max_position_embeddings (`int`): The maximum length of the positional embeddings.
Expand Down Expand Up @@ -561,7 +561,7 @@ def _compute_llama3_parameters(
The model configuration. This function assumes that the config will provide at least the following
properties:

* rope_theta (`float`): The base wavelength from which the inverse frequencies will be derived.
* rope_theta (`float`, *optional*): The base wavelength from which the inverse frequencies will be derived. Defaults to `config.default_theta` if omitted.
* hidden_size (`int`): The numerator when deriving a head_dim, if not provided directly.
* num_attention_heads (`int`): The denominator when deriving a head_dim, if not provided directly.
* rope_parameters (`dict[str, float | int]`): The standard RoPE scaling parameters, from which the following
Expand Down Expand Up @@ -642,8 +642,9 @@ def _compute_llama3_parameters(
class RopeParameters(TypedDict):
"""
Args:
rope_theta (`float`):
The base period of the RoPE embeddings.
rope_theta (`float`, *optional*, defaults to `RotaryEmbeddingConfigMixin.default_theta`):
The base period of the RoPE embeddings. Optional in serialized configs — if omitted,
the model's `default_theta` (typically 10000.0) is used.
rope_type (`str`, *optional*, defaults to "default"):
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
'llama3'], with 'default' being the original RoPE implementation.
Expand Down Expand Up @@ -680,7 +681,7 @@ class RopeParameters(TypedDict):
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
"""

rope_theta: float
rope_theta: float | None
rope_type: str | None
partial_rotary_factor: float | None
factor: float | None
Expand Down Expand Up @@ -801,34 +802,44 @@ def validate_rope(self: "PreTrainedConfig"):
)

def _validate_default_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None):
required_keys = {"rope_type", "rope_theta"}
required_keys = {"rope_type"}
optional_keys = {"rope_theta"}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

likely needs to be done elsewhere then too? I think all need rope theta?

Comment on lines +805 to +806
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, this kinda defeats the point of validation because a RoPE dict with no theta isn't valid for our modules

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah but we always default to default_theta if its not there no?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

validation always happens after teh defaults are set, so ideally it shouldn't raise an error. Do we know why the theta was missing?

received_keys = set(rope_parameters.keys())
rope_type = rope_parameters["rope_type"]
self._check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
self._check_received_keys(
rope_type, received_keys, required_keys, optional_keys=optional_keys, ignore_keys=ignore_keys
)

def _validate_linear_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None):
required_keys = {"rope_type", "factor", "rope_theta"}
required_keys = {"rope_type", "factor"}
optional_keys = {"rope_theta"}
received_keys = set(rope_parameters.keys())
rope_type = rope_parameters["rope_type"]
self._check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
self._check_received_keys(
rope_type, received_keys, required_keys, optional_keys=optional_keys, ignore_keys=ignore_keys
)

factor = rope_parameters["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}")

def _validate_dynamic_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None):
required_keys = {"rope_type", "factor", "rope_theta"}
required_keys = {"rope_type", "factor"}
optional_keys = {"rope_theta"}
received_keys = set(rope_parameters.keys())
rope_type = rope_parameters["rope_type"]
self._check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys)
self._check_received_keys(
rope_type, received_keys, required_keys, optional_keys=optional_keys, ignore_keys=ignore_keys
)

factor = rope_parameters["factor"]
if factor is None or not isinstance(factor, float) or factor < 1.0:
logger.warning(f"`rope_parameters`'s factor field must be a float >= 1, got {factor}")

def _validate_yarn_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None):
required_keys = {"rope_type", "factor", "rope_theta", "original_max_position_embeddings"}
required_keys = {"rope_type", "factor", "original_max_position_embeddings"}
optional_keys = {
"rope_theta",
"attention_factor",
"beta_fast",
"beta_slow",
Expand Down Expand Up @@ -878,8 +889,8 @@ def _validate_yarn_rope_parameters(self, rope_parameters: dict, ignore_keys: set
)

def _validate_longrope_rope_parameters(self, rope_parameters: dict, ignore_keys: set | None = None):
required_keys = {"rope_type", "short_factor", "long_factor", "rope_theta", "original_max_position_embeddings"}
optional_keys = {"attention_factor", "factor"}
required_keys = {"rope_type", "short_factor", "long_factor", "original_max_position_embeddings"}
optional_keys = {"rope_theta", "attention_factor", "factor"}
received_keys = set(rope_parameters.keys())
rope_type = rope_parameters["rope_type"]
self._check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys)
Expand Down
7 changes: 5 additions & 2 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3618,7 +3618,10 @@ def get_init_context(
elif is_quantized:
init_contexts.extend([torch.device("meta"), set_quantized_state()])
else:
init_contexts.append(torch.device("meta"))
# meta_device_safe_creation_ops patches torch.linspace to default to CPU
# so that custom models calling .item() during __init__ (e.g. drop-path
# schedules) don't crash on meta tensors.
init_contexts.extend([torch.device("meta"), init.meta_device_safe_creation_ops()])

return init_contexts

Expand Down Expand Up @@ -4612,7 +4615,7 @@ def mark_tied_weights_as_initialized(self, loading_info):
later as they will be tied (overwritten) anyway.
This is very important as most embeddings are tied, and they are huge params (vocabularies are often 256k), so
running inits on them is very costly."""
for tied_param in self.all_tied_weights_keys.keys():
for tied_param in getattr(self, "all_tied_weights_keys", {}).keys():
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix remote code

param = self.get_parameter(tied_param)
param._is_hf_initialized = True

Expand Down
15 changes: 6 additions & 9 deletions src/transformers/models/aimv2/modeling_aimv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,15 +258,12 @@ def forward(
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Input shape: Batch x Time x Channel"""

batch_size, seq_length, embed_dim = hidden_states.shape
input_shape = hidden_states.shape[:-1]

queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)

queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
hidden_shape = (*input_shape, -1, self.head_dim)
queries = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
keys = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
values = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
Expand All @@ -283,7 +280,7 @@ def forward(
dropout=0.0 if not self.training else self.dropout,
)

attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.out_proj(attn_output)

return attn_output, attn_weights
Expand Down
11 changes: 6 additions & 5 deletions src/transformers/models/altclip/modeling_altclip.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,15 +455,16 @@ def forward(
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""Input shape: Batch x Time x Channel"""

batch_size, seq_length, embed_dim = hidden_states.shape
input_shape = hidden_states.shape[:-1]

hidden_shape = (*input_shape, -1, self.head_dim)
queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)

queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
queries = queries.view(hidden_shape).transpose(1, 2)
keys = keys.view(hidden_shape).transpose(1, 2)
values = values.view(hidden_shape).transpose(1, 2)

attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
self.config._attn_implementation, eager_attention_forward
Expand All @@ -480,7 +481,7 @@ def forward(
**kwargs,
)

attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights

Expand Down
21 changes: 10 additions & 11 deletions src/transformers/models/audioflamingo3/modeling_audioflamingo3.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,18 +130,15 @@ def forward(
# for the decoder
is_cross_attention = key_value_states is not None

# determine input shapes
bsz, tgt_len = hidden_states.shape[:-1]
q_input_shape = (bsz, tgt_len, -1, self.head_dim)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)

# Scaling is susceptible to floating point arithmetics' inprecisions
# which can lead to different results (this is dependent from model
# to model, e.g. audioflamingo3 is one such case). We therefore keep the
# original order of scaling to follow the original implementation
# and enforce no scaling (1.0) in the attention call below.
query_states = self.q_proj(hidden_states) * self.scaling
query_states = query_states.view(*q_input_shape)
query_states = query_states.transpose(1, 2).contiguous()
query_states = (self.q_proj(hidden_states) * self.scaling).view(hidden_shape).transpose(1, 2).contiguous()

# Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache):
Expand All @@ -160,10 +157,12 @@ def forward(
key_states = past_key_values.layers[self.layer_idx].keys
value_states = past_key_values.layers[self.layer_idx].values
else:
key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim)
key_states = key_states.transpose(1, 2).contiguous()
value_states = value_states.transpose(1, 2).contiguous()
# Use the query's batch dimension for kv view so that a different-batch
# encoder output (e.g. in tests) gets absorbed into the sequence axis,
# preserving backward-compatible behaviour.
kv_shape = (input_shape[0], -1, self.num_heads, self.head_dim)
Comment thread
ArthurZucker marked this conversation as resolved.
key_states = self.k_proj(current_states).view(kv_shape).transpose(1, 2).contiguous()
value_states = self.v_proj(current_states).view(kv_shape).transpose(1, 2).contiguous()
if past_key_values is not None:
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

Expand All @@ -183,7 +182,7 @@ def forward(
**kwargs,
)

attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.out_proj(attn_output)

return attn_output, attn_weights
Expand Down
Loading
Loading