521 changes: 339 additions & 182 deletions src/transformers/modeling_rope_utils.py

Large diffs are not rendered by default.
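Note: since the diff for `modeling_rope_utils.py` is not rendered above, the shape of the new `compute_rope_parameters` helper can only be inferred from its call sites in the model files below. A rough sketch under that assumption — a per-layer-type dict when the config carries a nested `rope_scaling`, a single `(inv_freq, attention_scaling)` pair otherwise. Everything except the `compute_rope_parameters` name and the pre-existing `ROPE_INIT_FUNCTIONS` registry is illustrative only:

```python
# Hypothetical sketch only -- the real implementation lives in the unrendered diff above.
# ROPE_INIT_FUNCTIONS is the pre-existing registry mapping rope_type -> init function.
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS


def compute_rope_parameters(config, device=None):
    rope_scaling = getattr(config, "rope_scaling", None) or {"rope_type": "default"}
    layer_types = getattr(config, "layer_types", None)

    # Nested form, e.g. {"full_attention": {...}, "sliding_attention": {...}}
    if layer_types is not None and set(rope_scaling) & set(layer_types):
        rope_parameters, rope_types = {}, {}
        for layer_type, params in rope_scaling.items():
            rope_type = params.get("rope_type", "default")
            # The real helper presumably threads the per-layer-type rope_theta into the
            # init function; that plumbing is omitted in this sketch.
            rope_parameters[layer_type] = ROPE_INIT_FUNCTIONS[rope_type](config, device)
            rope_types[layer_type] = rope_type
        return rope_parameters, rope_types

    # Flat form: a single RoPE parameterization shared by every layer (e.g. Llama)
    rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", "default"))
    return ROPE_INIT_FUNCTIONS[rope_type](config, device)
```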

20 changes: 12 additions & 8 deletions src/transformers/models/gemma3/configuration_gemma3.py
@@ -81,8 +81,6 @@ class Gemma3TextConfig(PretrainedConfig):
Beginning of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `True`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 1000000.0):
The base period of the RoPE embeddings.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
@@ -183,7 +181,6 @@ def __init__(
eos_token_id=1,
bos_token_id=2,
tie_word_embeddings=True,
rope_theta=1_000_000.0,
attention_bias=False,
attention_dropout=0.0,
query_pre_attn_scalar=256,
@@ -213,7 +210,6 @@ def __init__(
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.hidden_activation = hidden_activation
@@ -223,10 +219,6 @@ def __init__(
self.attn_logit_softcapping = attn_logit_softcapping
self.layer_types = layer_types

self.rope_local_base_freq = rope_local_base_freq
self.rope_scaling = rope_scaling
rope_config_validation(self)

# BC -> the pattern used to be a simple int, and it's still present in configs on the Hub
self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6)

@@ -237,6 +229,18 @@ def __init__(
]
layer_type_validation(self.layer_types)

# Validate the correctness of rotary position embeddings parameters
# The config was saved with a simple rope scaling dict, we need to convert to nested structure per RoPE type
rope_theta = getattr(self, "rope_theta", 1_000_000)
sliding_attention_rope = {"rope_type": "default", "rope_theta": rope_local_base_freq}
full_attention_rope = {"rope_type": "default", "rope_theta": rope_theta}
if rope_scaling is not None:
full_attention_rope.update(**rope_scaling)

rope_scaling = {"full_attention": full_attention_rope, "sliding_attention": sliding_attention_rope}
self.rope_scaling = {k: v for k, v in rope_scaling.items() if k in self.layer_types}
rope_config_validation(self)

@property
def sliding_window_pattern(self):
warnings.warn(
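With `rope_theta` and `rope_local_base_freq` no longer stored as flat attributes, `Gemma3TextConfig` now folds them into a nested `rope_scaling` dict keyed by layer type. A standalone illustration of that conversion, mirroring the `__init__` code above and using Gemma3's documented defaults (`rope_theta=1_000_000`, `rope_local_base_freq=10_000.0`); the flat `rope_scaling` input is just an example of a legacy Hub config:

```python
# Standalone illustration of the rope_scaling conversion performed in Gemma3TextConfig.__init__ above.
layer_types = ["sliding_attention"] * 5 + ["full_attention"]  # default pattern: one full-attention layer every 6

rope_theta = 1_000_000                                   # fallback used when rope_theta is absent from the config
rope_local_base_freq = 10_000.0                          # documented Gemma3 default for sliding-attention layers
rope_scaling = {"rope_type": "linear", "factor": 8.0}    # example of a legacy flat dict saved on the Hub

sliding_attention_rope = {"rope_type": "default", "rope_theta": rope_local_base_freq}
full_attention_rope = {"rope_type": "default", "rope_theta": rope_theta}
if rope_scaling is not None:
    full_attention_rope.update(**rope_scaling)

nested = {"full_attention": full_attention_rope, "sliding_attention": sliding_attention_rope}
nested = {k: v for k, v in nested.items() if k in layer_types}
print(nested)
# {'full_attention': {'rope_type': 'linear', 'rope_theta': 1000000, 'factor': 8.0},
#  'sliding_attention': {'rope_type': 'default', 'rope_theta': 10000.0}}
```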
80 changes: 45 additions & 35 deletions src/transformers/models/gemma3/modeling_gemma3.py
@@ -19,7 +19,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from collections.abc import Callable
from dataclasses import dataclass
from typing import Optional, Union
@@ -35,7 +34,7 @@
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_rope_utils import compute_rope_parameters, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import (
@@ -159,36 +158,63 @@ def extra_repr(self):
class Gemma3RotaryEmbedding(nn.Module):
def __init__(self, config: Gemma3TextConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings

self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
rope_inv_freqs, rope_types = compute_rope_parameters(self.config, device)
self.rope_type = rope_types
for layer_type in rope_inv_freqs:
self._update_inv_freq(rope_inv_freqs[layer_type][0], update_original=True, layer_type=layer_type)
setattr(self, f"{layer_type}_attention_scaling", rope_inv_freqs[layer_type][1])

@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
if getattr(self.config, "layer_types", None) is not None:
position_embeddings = {}
for layer_type in self.config.layer_types:
position_embeddings[layer_type] = self.apply_rope(x, position_ids, layer_type=layer_type)
else:
position_embeddings = self.apply_rope(x, position_ids)
return position_embeddings

@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def apply_rope(self, x, position_ids, layer_type=None):
inv_freq, attention_scaling = self._get_inv_freq(layer_type=layer_type)
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()

device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
cos = emb.cos() * attention_scaling
sin = emb.sin() * attention_scaling

return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

def _update_inv_freq(self, new_inv_freq, update_original=False, layer_type=None):
if layer_type:
inv_freq_name = f"{layer_type}_inv_freq"
original_freq_name = f"{layer_type}_original_inv_freq"
else:
inv_freq_name = "inv_freq"
original_freq_name = "original_inv_freq"

self.register_buffer(inv_freq_name, new_inv_freq, persistent=False)
if update_original:
setattr(self, original_freq_name, new_inv_freq)

def _get_inv_freq(self, layer_type=None):
if layer_type is not None:
inv_freq = getattr(self, f"{layer_type}_inv_freq")
attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
else:
inv_freq = self.inv_freq
attention_scaling = self.attention_scaling

return inv_freq, attention_scaling


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
@@ -368,8 +394,7 @@ def __init__(self, config: Gemma3TextConfig, layer_idx: int):
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings_global: torch.Tensor,
position_embeddings_local: torch.Tensor,
position_embeddings: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
@@ -382,12 +407,6 @@ def forward(

hidden_states = self.input_layernorm(hidden_states)

# apply global RoPE to non-sliding layer only
if self.self_attn.is_sliding:
position_embeddings = position_embeddings_local
else:
position_embeddings = position_embeddings_global

hidden_states, self_attn_weights = self.self_attn(
hidden_states=hidden_states,
position_embeddings=position_embeddings,
@@ -465,13 +484,6 @@ def __init__(self, config: Gemma3TextConfig):
self.rotary_emb = Gemma3RotaryEmbedding(config=config)
self.gradient_checkpointing = False

# TODO: raushan fix this after RoPE refactor. For now we hack it by reassigning thetas
# when we want to create a local RoPE layer. Config defaults should hold values for global RoPE
config = copy.deepcopy(config)
config.rope_theta = config.rope_local_base_freq
config.rope_scaling = {"rope_type": "default"}
self.rotary_emb_local = Gemma3RotaryEmbedding(config=config)

# Initialize weights and apply final processing
self.post_init()

@@ -543,8 +555,7 @@ def forward(
hidden_states = inputs_embeds

# create position embeddings to be shared across the decoder layers
position_embeddings_global = self.rotary_emb(hidden_states, position_ids)
position_embeddings_local = self.rotary_emb_local(hidden_states, position_ids)
position_embeddings = self.rotary_emb(hidden_states, position_ids)

# decoder layers
all_hidden_states = () if output_hidden_states else None
@@ -556,8 +567,7 @@

layer_outputs = decoder_layer(
hidden_states,
position_embeddings_global=position_embeddings_global,
position_embeddings_local=position_embeddings_local,
position_embeddings=position_embeddings[decoder_layer.attention_type],
attention_mask=causal_mask_mapping[decoder_layer.attention_type],
position_ids=position_ids,
past_key_value=past_key_values,
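The separate `rotary_emb_local` module (and the `copy.deepcopy` theta hack) goes away: the single `Gemma3RotaryEmbedding` now keeps one `inv_freq` buffer per layer type, and its `forward` returns a dict of `(cos, sin)` pairs that the decoder loop indexes with each layer's `attention_type`. A minimal sketch of the resulting call pattern, assuming a transformers build that includes this refactor:

```python
# Sketch of the new usage, assuming this PR is applied; the config overrides are illustrative.
import torch
from transformers import Gemma3TextConfig
from transformers.models.gemma3.modeling_gemma3 import Gemma3RotaryEmbedding

config = Gemma3TextConfig(num_hidden_layers=6)  # default pattern: 5 sliding-attention layers, then 1 full-attention layer
rotary_emb = Gemma3RotaryEmbedding(config)

x = torch.randn(1, 8, config.hidden_size)       # only the dtype/device of x are used by the rotary embedding
position_ids = torch.arange(8).unsqueeze(0)

# One dict entry per layer type instead of separate global/local rotary modules
position_embeddings = rotary_emb(x, position_ids)
for layer_type in config.layer_types:
    cos, sin = position_embeddings[layer_type]  # selected per decoder layer via `decoder_layer.attention_type`
```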
55 changes: 41 additions & 14 deletions src/transformers/models/llama/modeling_llama.py
@@ -37,7 +37,7 @@
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from ...modeling_rope_utils import compute_rope_parameters, dynamic_rope_update
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
@@ -72,36 +72,63 @@ def extra_repr(self):
class LlamaRotaryEmbedding(nn.Module):
def __init__(self, config: LlamaConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings

self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
inv_freq, attention_scaling = compute_rope_parameters(self.config, device)
self.rope_type = config.rope_scaling_dict["rope_type"]
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq
self.original_inv_freq = inv_freq
self.attention_scaling = attention_scaling

@torch.no_grad()
@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def forward(self, x, position_ids):
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
if getattr(self.config, "layer_types", None) is not None:
position_embeddings = {}
for layer_type in self.config.layer_types:
position_embeddings[layer_type] = self.apply_rope(x, position_ids, layer_type=layer_type)
else:
position_embeddings = self.apply_rope(x, position_ids)
return position_embeddings

@dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
def apply_rope(self, x, position_ids, layer_type=None):
inv_freq, attention_scaling = self._get_inv_freq(layer_type=layer_type)
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
position_ids_expanded = position_ids[:, None, :].float()

device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
with torch.autocast(device_type=device_type, enabled=False): # Force float32
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos() * self.attention_scaling
sin = emb.sin() * self.attention_scaling
cos = emb.cos() * attention_scaling
sin = emb.sin() * attention_scaling

return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

def _update_inv_freq(self, new_inv_freq, update_original=False, layer_type=None):
if layer_type:
inv_freq_name = f"{layer_type}_inv_freq"
original_freq_name = f"{layer_type}_original_inv_freq"
else:
inv_freq_name = "inv_freq"
original_freq_name = "original_inv_freq"

self.register_buffer(inv_freq_name, new_inv_freq, persistent=False)
if update_original:
setattr(self, original_freq_name, new_inv_freq)

def _get_inv_freq(self, layer_type=None):
if layer_type is not None:
inv_freq = getattr(self, f"{layer_type}_inv_freq")
attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
else:
inv_freq = self.inv_freq
attention_scaling = self.attention_scaling

return inv_freq, attention_scaling


def rotate_half(x):
"""Rotates half the hidden dims of the input."""
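For Llama itself nothing user-visible changes: `LlamaConfig` has no `layer_types`, so the refactored `forward` falls through to a single `apply_rope` call and still returns one `(cos, sin)` pair (the `rope_scaling_dict` attribute referenced in `__init__` is presumably added by the `modeling_rope_utils.py` changes above). A quick check, again assuming a build with this PR:

```python
# Minimal check, assuming this PR is applied.
import torch
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

config = LlamaConfig(hidden_size=128, num_attention_heads=4, num_hidden_layers=2)
rotary_emb = LlamaRotaryEmbedding(config)

x = torch.randn(1, 16, config.hidden_size)
position_ids = torch.arange(16).unsqueeze(0)

cos, sin = rotary_emb(x, position_ids)  # plain tuple, not a dict -- no `layer_types` on LlamaConfig
print(cos.shape)  # torch.Size([1, 16, 32]); 32 = hidden_size // num_attention_heads
```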
28 changes: 22 additions & 6 deletions src/transformers/models/modernbert/configuration_modernbert.py
@@ -21,7 +21,8 @@

from typing import Literal

from ...configuration_utils import PretrainedConfig
from ...configuration_utils import PretrainedConfig, layer_type_validation
from ...modeling_rope_utils import rope_config_validation


class ModernBertConfig(PretrainedConfig):
@@ -69,8 +70,6 @@ class ModernBertConfig(PretrainedConfig):
Classification token id.
sep_token_id (`int`, *optional*, defaults to 50282):
Separation token id.
global_rope_theta (`float`, *optional*, defaults to 160000.0):
The base period of the global RoPE embeddings.
attention_bias (`bool`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
@@ -150,7 +149,6 @@ def __init__(
bos_token_id=50281,
cls_token_id=50281,
sep_token_id=50282,
global_rope_theta=160000.0,
attention_bias=False,
attention_dropout=0.0,
global_attn_every_n_layers=3,
@@ -169,6 +167,7 @@ def __init__(
sparse_pred_ignore_index=-100,
reference_compile=None,
repad_logits_with_grad=False,
rope_scaling=None,
**kwargs,
):
super().__init__(
@@ -189,13 +188,11 @@ def __init__(
self.initializer_cutoff_factor = initializer_cutoff_factor
self.norm_eps = norm_eps
self.norm_bias = norm_bias
self.global_rope_theta = global_rope_theta
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.hidden_activation = hidden_activation
self.global_attn_every_n_layers = global_attn_every_n_layers
self.local_attention = local_attention
self.local_rope_theta = local_rope_theta
self.embedding_dropout = embedding_dropout
self.mlp_bias = mlp_bias
self.mlp_dropout = mlp_dropout
Expand All @@ -210,6 +207,25 @@ def __init__(
self.reference_compile = reference_compile
self.repad_logits_with_grad = repad_logits_with_grad

self.layer_types = [
"sliding_attention" if bool((i + 1) % self.global_attn_every_n_layers) else "full_attention"
for i in range(self.num_hidden_layers)
]
layer_type_validation(self.layer_types)

# Validate the correctness of rotary position embeddings parameters
# If the config was saved with a simple rope scaling dict, we need to convert to nested structure
# per RoPE type and raise a warning
rope_theta = getattr(self, "rope_theta", 160000)
sliding_attention_rope = {"rope_type": "default", "rope_theta": local_rope_theta}
full_attention_rope = {"rope_type": "default", "rope_theta": rope_theta}
if rope_scaling is not None:
full_attention_rope.update(**rope_scaling)

rope_scaling = {"full_attention": full_attention_rope, "sliding_attention": sliding_attention_rope}
self.rope_scaling = {k: v for k, v in rope_scaling.items() if k in self.layer_types}
rope_config_validation(self)

if self.classifier_pooling not in ["cls", "mean"]:
raise ValueError(
f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.'
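ModernBERT follows the same pattern, but additionally derives `layer_types` from `global_attn_every_n_layers` instead of accepting it as an argument. Mirroring the `__init__` code above with the documented defaults (22 layers, a global-attention layer every 3, `rope_theta` falling back to 160,000 globally and `local_rope_theta=10_000.0` locally), the derived structures come out as follows:

```python
# Standalone illustration of the layer_types / rope_scaling construction in
# ModernBertConfig.__init__ above, using ModernBERT's documented defaults.
num_hidden_layers = 22
global_attn_every_n_layers = 3

layer_types = [
    "sliding_attention" if bool((i + 1) % global_attn_every_n_layers) else "full_attention"
    for i in range(num_hidden_layers)
]
print([i for i, t in enumerate(layer_types) if t == "full_attention"])
# [2, 5, 8, 11, 14, 17, 20]

rope_scaling = {
    "full_attention": {"rope_type": "default", "rope_theta": 160_000},
    "sliding_attention": {"rope_type": "default", "rope_theta": 10_000.0},
}
rope_scaling = {k: v for k, v in rope_scaling.items() if k in layer_types}  # both types occur here, so both are kept
```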