From 2d8b9277344ae3cbaf5ff7ceabe488a5015222f8 Mon Sep 17 00:00:00 2001
From: raushan
Date: Thu, 31 Jul 2025 17:07:08 +0200
Subject: [PATCH 1/4] like this?

---
 src/transformers/modeling_rope_utils.py      | 546 ++++++++++++------
 .../models/gemma3/configuration_gemma3.py    |  27 +-
 .../models/gemma3/modeling_gemma3.py         |  82 +--
 .../models/gemma3/modular_gemma3.py          |  16 +-
 .../models/llama/modeling_llama.py           |  55 +-
 .../modernbert/configuration_modernbert.py   |  28 +-
 .../models/modernbert/modeling_modernbert.py | 103 ++--
 .../models/modernbert/modular_modernbert.py  |  12 +-
 8 files changed, 576 insertions(+), 293 deletions(-)

diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py
index 59989aa5927c..d4dad0b2780c 100644
--- a/src/transformers/modeling_rope_utils.py
+++ b/src/transformers/modeling_rope_utils.py
@@ -27,6 +27,18 @@
     import torch
 
 
+def extract_rope_type_from_config(*args, **kwargs):
+    pass
+
+
+def _get_rope_scaling_dict(config, layer_type: str) -> dict:
+    """Get the RoPE scaling dictionary for the specified layer."""
+    rope_scaling_dict = config.rope_scaling_dict
+    if layer_type is not None:
+        rope_scaling_dict = rope_scaling_dict[layer_type]
+    return rope_scaling_dict
+
+
 def dynamic_rope_update(rope_forward):
     """
     Decorator function to update the RoPE parameters in the forward pass, if the model is using a dynamic RoPE
@@ -40,7 +52,7 @@ def dynamic_rope_update(rope_forward):
         The decorated forward pass.
     """
 
-    def longrope_frequency_update(self, position_ids, device):
+    def longrope_frequency_update(self, position_ids, device, layer_type=None):
         """Longrope uses long factor if sequence is larger than original pretraining length, short otherwise."""
         seq_len = torch.max(position_ids) + 1
         if hasattr(self.config, "original_max_position_embeddings"):
@@ -48,18 +60,28 @@
         else:
             original_max_position_embeddings = self.config.max_position_embeddings
         if seq_len > original_max_position_embeddings:
-            if not hasattr(self, "long_inv_freq"):
-                self.long_inv_freq, _ = self.rope_init_fn(
-                    self.config, device, seq_len=original_max_position_embeddings + 1
+            if not hasattr(self, f"{layer_type}_long_inv_freq"):
+                rope_type = self.rope_type if hasattr(self, "rope_type") else self.rope_types[layer_type]
+                rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
+                rope_scaling_dict = _get_rope_scaling_dict(self.config, layer_type=layer_type)
+                long_inv_freq, _ = rope_init_fn(
+                    self.config,
+                    device,
+                    seq_len=original_max_position_embeddings + 1,
+                    rope_scaling_dict=rope_scaling_dict,
                 )
-            self.register_buffer("inv_freq", self.long_inv_freq, persistent=False)
+                self._update_inv_freq(long_inv_freq, layer_type=layer_type)
+                setattr(self, f"{layer_type}_long_inv_freq", long_inv_freq)
         else:
             # This .to() is needed if the model has been moved to a device after being initialized (because
             # the buffer is automatically moved, but not the original copy)
-            self.original_inv_freq = self.original_inv_freq.to(device)
-            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+            if layer_type is not None:
+                original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq").to(device)
+            else:
+                original_inv_freq = self.original_inv_freq.to(device)
+            self._update_inv_freq(original_inv_freq, update_original=True, layer_type=layer_type)
 
-    def dynamic_frequency_update(self, position_ids, device):
+    def dynamic_frequency_update(self, position_ids, device, layer_type=None):
         """
         dynamic RoPE layers should recompute `inv_freq` in the following situations:
         1 - growing beyond the cached sequence length (allow scaling)
@@ -67,24 +89,37 @@
         """
         seq_len = torch.max(position_ids) + 1
         if seq_len > self.max_seq_len_cached:  # growth
-            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
-            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
+            rope_type = self.rope_type if hasattr(self, "rope_type") else self.rope_types[layer_type]
+            rope_init_fn = ROPE_INIT_FUNCTIONS[rope_type]
+            rope_scaling_dict = _get_rope_scaling_dict(self.config, layer_type=layer_type)
+            inv_freq, self.attention_scaling = rope_init_fn(
+                self.config,
+                device,
+                seq_len=seq_len,
+                rope_scaling_dict=rope_scaling_dict,
+            )
+            # TODO joao: may break with compilation
+            self._update_inv_freq(inv_freq, layer_type=layer_type)
             self.max_seq_len_cached = seq_len
 
         if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
             # This .to() is needed if the model has been moved to a device after being initialized (because
             # the buffer is automatically moved, but not the original copy)
-            self.original_inv_freq = self.original_inv_freq.to(device)
-            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+            if layer_type is not None:
+                original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq").to(device)
+            else:
+                original_inv_freq = self.original_inv_freq.to(device)
+            self._update_inv_freq(original_inv_freq, update_original=True, layer_type=layer_type)
             self.max_seq_len_cached = self.original_max_seq_len
 
     @wraps(rope_forward)
-    def wrapper(self, x, position_ids):
-        if "dynamic" in self.rope_type:
-            dynamic_frequency_update(self, position_ids, device=x.device)
-        elif self.rope_type == "longrope":
-            longrope_frequency_update(self, position_ids, device=x.device)
-        return rope_forward(self, x, position_ids)
+    def wrapper(self, x, position_ids, layer_type=None):
+        rope_type = self.rope_type if hasattr(self, "rope_type") else self.rope_types[layer_type]
+        if "dynamic" in rope_type:
+            dynamic_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
+        elif rope_type == "longrope":
+            longrope_frequency_update(self, position_ids, device=x.device, layer_type=layer_type)
+        return rope_forward(self, x, position_ids, layer_type=layer_type)
 
     return wrapper
 
@@ -93,6 +128,7 @@ def _compute_default_rope_parameters(
     config: Optional[PretrainedConfig] = None,
     device: Optional["torch.device"] = None,
     seq_len: Optional[int] = None,
+    rope_scaling_dict: Optional[dict] = None,
 ) -> tuple["torch.Tensor", float]:
     """
     Computes the inverse frequencies according to the original RoPE implementation
@@ -103,12 +139,16 @@
             The device to use for initialization of the inverse frequencies.
         seq_len (`int`, *optional*):
             The current sequence length. Unused for this type of RoPE.
+        rope_scaling_dict (`dict`, *optional*):
+            The RoPE parameters to use (e.g. `rope_theta` and any scaling fields), typically the entry of
+            `config.rope_scaling` for the current layer type.
     Returns:
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
""" - base = config.rope_theta - partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + + base = rope_scaling_dict["rope_theta"] + partial_rotary_factor = rope_scaling_dict.get("partial_rotary_factor", 1.0) head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads dim = int(head_dim * partial_rotary_factor) @@ -123,6 +163,7 @@ def _compute_linear_scaling_rope_parameters( config: Optional[PretrainedConfig] = None, device: Optional["torch.device"] = None, seq_len: Optional[int] = None, + rope_scaling_dict: Optional[dict] = None, ) -> tuple["torch.Tensor", float]: """ Computes the inverse frequencies with linear scaling. Credits to the Reddit user /u/kaiokendev @@ -133,14 +174,17 @@ def _compute_linear_scaling_rope_parameters( The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length. Unused for this type of RoPE. + is_global (`bool`, *optional*, defaults to `True`): + Whether to use global or local rope theta from config. For local rope theta, + the config object should have attributes - `local_rope_theta` and `local_rope_scaling`. Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). """ - factor = config.rope_scaling["factor"] + factor = rope_scaling_dict["factor"] # Gets the default RoPE parameters - inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len) + inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, rope_scaling_dict) # Then applies linear scaling to the frequencies. # NOTE: originally, scaling was applied to the position_ids. However, we get `embs = inv_freq @ position_ids`, so @@ -153,6 +197,7 @@ def _compute_dynamic_ntk_parameters( config: Optional[PretrainedConfig] = None, device: Optional["torch.device"] = None, seq_len: Optional[int] = None, + rope_scaling_dict: Optional[dict] = None, ) -> tuple["torch.Tensor", float]: """ Computes the inverse frequencies with NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla @@ -163,17 +208,20 @@ def _compute_dynamic_ntk_parameters( The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length, used to update the dynamic RoPE at inference time. + is_global (`bool`, *optional*, defaults to `True`): + Whether to use global or local rope theta from config. For local rope theta, + the config object should have attributes - `local_rope_theta` and `local_rope_scaling`. Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). 
""" # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling - base = config.rope_theta + base = rope_scaling_dict["rope_theta"] partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) dim = int(head_dim * partial_rotary_factor) max_position_embeddings = config.max_position_embeddings - factor = config.rope_scaling["factor"] + factor = rope_scaling_dict["factor"] attention_factor = 1.0 # Unused in this type of RoPE @@ -195,7 +243,10 @@ def _compute_dynamic_ntk_parameters( def _compute_yarn_parameters( - config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None + config: PretrainedConfig, + device: "torch.device", + seq_len: Optional[int] = None, + rope_scaling_dict: Optional[dict] = None, ) -> tuple["torch.Tensor", float]: """ Computes the inverse frequencies with NTK scaling. Please refer to the @@ -207,25 +258,28 @@ def _compute_yarn_parameters( The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length. Unused for this type of RoPE. + is_global (`bool`, *optional*, defaults to `True`): + Whether to use global or local rope theta from config. For local rope theta, + the config object should have attributes - `local_rope_theta` and `local_rope_scaling`. Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin. """ - - base = config.rope_theta + base = rope_scaling_dict["rope_theta"] partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) dim = int(head_dim * partial_rotary_factor) - factor = config.rope_scaling["factor"] - attention_factor = config.rope_scaling.get("attention_factor") - mscale = config.rope_scaling.get("mscale") - mscale_all_dim = config.rope_scaling.get("mscale_all_dim") + + factor = rope_scaling_dict["factor"] + attention_factor = rope_scaling_dict.get("attention_factor") + mscale = rope_scaling_dict.get("mscale") + mscale_all_dim = rope_scaling_dict.get("mscale_all_dim") # NOTE: DeekSeek-V3 (and potentially other models) modify `max_position_embeddings` and have a # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two # values to compute the default attention scaling factor, instead of using `factor`. 
- if "original_max_position_embeddings" in config.rope_scaling: - original_max_position_embeddings = config.rope_scaling["original_max_position_embeddings"] + if "original_max_position_embeddings" in rope_scaling_dict: + original_max_position_embeddings = rope_scaling_dict["original_max_position_embeddings"] factor = config.max_position_embeddings / original_max_position_embeddings else: original_max_position_embeddings = config.max_position_embeddings @@ -244,8 +298,8 @@ def get_mscale(scale, mscale=1): # Optional config options # beta_fast/beta_slow: as suggested in the paper, default to 32/1 (correspondingly) - beta_fast = config.rope_scaling.get("beta_fast") or 32 - beta_slow = config.rope_scaling.get("beta_slow") or 1 + beta_fast = rope_scaling_dict.get("beta_fast") or 32 + beta_slow = rope_scaling_dict.get("beta_slow") or 1 # Compute the inverse frequencies def find_correction_dim(num_rotations, dim, base, max_position_embeddings): @@ -284,7 +338,10 @@ def linear_ramp_factor(min, max, dim): def _compute_longrope_parameters( - config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None + config: PretrainedConfig, + device: "torch.device", + seq_len: Optional[int] = None, + rope_scaling_dict: Optional[dict] = None, ) -> tuple["torch.Tensor", float]: """ Computes the inverse frequencies with LongRoPE scaling. Please refer to the @@ -296,19 +353,23 @@ def _compute_longrope_parameters( The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length. + is_global (`bool`, *optional*, defaults to `True`): + Whether to use global or local rope theta from config. For local rope theta, + the config object should have attributes - `local_rope_theta` and `local_rope_scaling`. Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin. """ # TODO (joao): use the new `original_max_position_embeddings` from rope_scaling - base = config.rope_theta + base = rope_scaling_dict["rope_theta"] partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) dim = int(head_dim * partial_rotary_factor) - long_factor = config.rope_scaling["long_factor"] - short_factor = config.rope_scaling["short_factor"] - factor = config.rope_scaling.get("factor") - attention_factor = config.rope_scaling.get("attention_factor") + + long_factor = rope_scaling_dict["long_factor"] + short_factor = rope_scaling_dict["short_factor"] + factor = rope_scaling_dict.get("factor") + attention_factor = rope_scaling_dict.get("attention_factor") # NOTE: Phi3 (and potentially other models) modify `max_position_embeddings` and have a # `original_max_position_embeddings` field containing the pretrained value. They use the ratio between these two @@ -338,7 +399,10 @@ def _compute_longrope_parameters( def _compute_llama3_parameters( - config: PretrainedConfig, device: "torch.device", seq_len: Optional[int] = None + config: PretrainedConfig, + device: "torch.device", + seq_len: Optional[int] = None, + rope_scaling_dict: Optional[dict] = None, ) -> tuple["torch.Tensor", float]: """ Computes the inverse frequencies for llama 3.1. @@ -350,17 +414,20 @@ def _compute_llama3_parameters( The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length. 
+        rope_scaling_dict (`dict`, *optional*):
+            The RoPE parameters to use (e.g. `rope_theta`, `factor` and any other scaling fields), typically the
+            entry of `config.rope_scaling` for the current layer type.
     Returns:
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin.
     """
     # Gets the default RoPE parameters
-    inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len)
+    inv_freq, attention_factor = _compute_default_rope_parameters(config, device, seq_len, rope_scaling_dict)
 
-    factor = config.rope_scaling["factor"]  # `8` in the original implementation
-    low_freq_factor = config.rope_scaling["low_freq_factor"]  # `1` in the original implementation
-    high_freq_factor = config.rope_scaling["high_freq_factor"]  # `4` in the original implementation
-    old_context_len = config.rope_scaling["original_max_position_embeddings"]  # `8192` in the original implementation
+    factor = rope_scaling_dict["factor"]  # `8` in the original implementation
+    low_freq_factor = rope_scaling_dict["low_freq_factor"]  # `1` in the original implementation
+    high_freq_factor = rope_scaling_dict["high_freq_factor"]  # `4` in the original implementation
+    old_context_len = rope_scaling_dict["original_max_position_embeddings"]  # `8192` in the original implementation
 
     low_freq_wavelen = old_context_len / low_freq_factor
     high_freq_wavelen = old_context_len / high_freq_factor
@@ -391,6 +458,77 @@ def _compute_llama3_parameters(
 }
 
 
+def extract_rope_scaling_dict_from_config(config):
+    "Helper to extract the RoPE scaling dict from the config, while handling BC and local/global keys"
+
+    # The RoPE scaling dict might be serialized differently in older versions, which we need to support.
+    # Case 1: `config.rope_scaling` is a simple dict with values to configure RoPE type. Deprecated.
+    # Case 2: `config.rope_scaling` is a dict where keys define different RoPE configurations used by the model.
+    # For example `rope_scaling={"global": {}, "local": {}}` to alternate between global and local attention layers.
+    rope_scaling_dict = getattr(config, "rope_scaling", None)
+
+    if getattr(config, "layer_types", None) is not None:
+        if rope_scaling_dict is not None and ("type" in rope_scaling_dict or "rope_type" in rope_scaling_dict):
+            # if there is a 'type' field, copy it to 'rope_type'.
+ if "type" in rope_scaling_dict: + rope_scaling_dict["rope_type"] = rope_scaling_dict.pop("type") + rope_scaling_dict["rope_theta"] = config.rope_theta + rope_scaling_dict = {"full_attention": rope_scaling_dict} + elif rope_scaling_dict is None: + rope_scaling_dict = {"full_attention": {"rope_type": "default", "rope_theta": config.rope_theta}} + else: + for rope_key in rope_scaling_dict: + if "type" in rope_scaling_dict[rope_key]: + rope_scaling_dict[rope_key]["rope_type"] = rope_scaling_dict[rope_key].pop("type") + else: + if rope_scaling_dict is None: + rope_scaling_dict = {"rope_type": "default", "rope_theta": config.rope_theta} + else: + if "type" in rope_scaling_dict: + rope_scaling_dict["rope_type"] = rope_scaling_dict.pop("type") + if "rope_theta" not in rope_scaling_dict: + rope_scaling_dict["rope_theta"] = config.rope_theta + config.rope_scaling_dict = rope_scaling_dict + return rope_scaling_dict + + +def compute_rope_parameters( + config: Optional[PretrainedConfig] = None, + device: Optional["torch.device"] = None, + rope_config_key: Optional[str] = "global", +) -> tuple["torch.Tensor", float]: + """ + Extracts requested RoPE type from the config (e.g. "dynamic") and computes inverse frequencies. + + Args: + config ([`~transformers.PretrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + rope_config_key (`str`, *optional*, defaults to `"global"`): + RoPE type key + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). + """ + rope_scaling_dicts = extract_rope_scaling_dict_from_config(config) + + if hasattr(config, "layer_types"): + rope_inv_freqs = {} + rope_types = {} + for rope_key in rope_scaling_dicts: + rope_scaling_dict = rope_scaling_dicts[rope_key] + rope_init_fn = ROPE_INIT_FUNCTIONS[rope_scaling_dict["rope_type"]] + inv_freq, attention_scaling = rope_init_fn(config, rope_scaling_dict=rope_scaling_dict, device=device) + rope_inv_freqs[rope_key] = (inv_freq, attention_scaling) + rope_types[rope_key] = rope_scaling_dict["rope_type"] + return rope_inv_freqs, rope_types + else: + rope_init_fn = ROPE_INIT_FUNCTIONS[rope_scaling_dicts["rope_type"]] + inv_freq, attention_scaling = rope_init_fn(config, rope_scaling_dict=rope_scaling_dicts, device=device) + return inv_freq, attention_scaling + + def _check_received_keys( rope_type: str, received_keys: set, @@ -421,161 +559,205 @@ def _check_received_keys( def _validate_default_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): - rope_scaling = config.rope_scaling - rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" - required_keys = {"rope_type"} - received_keys = set(rope_scaling.keys()) - _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) + rope_scaling_dict = config.rope_scaling + + if getattr(config, "layer_types") is not None: + missing_rope_keys = set(config.layer_types) - set(rope_scaling_dict.keys()) + if missing_rope_keys: + raise KeyError( + f"Missing required keys in `rope_scaling`: {missing_rope_keys}. 
The `rope_scaling` dict should " + "contain keys for all types in `config.layer_types`" + ) + else: + rope_scaling_dict = {"full_attention": rope_scaling_dict} + + for dictionary in rope_scaling_dict.values(): + # BC: "rope_type" was originally "type" + rope_type = dictionary.get("rope_type", dictionary.get("type", None)) + required_keys = {"rope_type", "rope_theta"} + received_keys = set(dictionary.keys()) + _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) def _validate_linear_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): - rope_scaling = config.rope_scaling - rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" - required_keys = {"rope_type", "factor"} - received_keys = set(rope_scaling.keys()) - _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) + for key in ["rope_scaling", "local_rope_scaling"]: + rope_scaling = getattr(config, key, None) + if rope_scaling is not None: + rope_type = rope_scaling.get( + "rope_type", rope_scaling.get("type", None) + ) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor", "rope_theta"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) - factor = rope_scaling["factor"] - if factor is None or not isinstance(factor, float) or factor < 1.0: - logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") def _validate_dynamic_scaling_rope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): - rope_scaling = config.rope_scaling - rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" - required_keys = {"rope_type", "factor"} - # TODO (joao): update logic for the inclusion of `original_max_position_embeddings` - optional_keys = {"original_max_position_embeddings"} - received_keys = set(rope_scaling.keys()) - _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) - - factor = rope_scaling["factor"] - if factor is None or not isinstance(factor, float) or factor < 1.0: - logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + for key in ["rope_scaling", "local_rope_scaling"]: + rope_scaling = getattr(config, key, None) + if rope_scaling is not None: + rope_type = rope_scaling.get( + "rope_type", rope_scaling.get("type", None) + ) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor", "rope_theta"} + # TODO (joao): update logic for the inclusion of `original_max_position_embeddings` + optional_keys = {"original_max_position_embeddings"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") def _validate_yarn_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): - rope_scaling = config.rope_scaling - rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was 
originally "type" - required_keys = {"rope_type", "factor"} - optional_keys = { - "attention_factor", - "beta_fast", - "beta_slow", - "original_max_position_embeddings", - "mscale", - "mscale_all_dim", - } - received_keys = set(rope_scaling.keys()) - _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) - - factor = rope_scaling["factor"] - if factor is None or not isinstance(factor, float) or factor < 1.0: - logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") - - attention_factor = rope_scaling.get("attention_factor") - if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0): - logger.warning( - f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" - ) - beta_fast = rope_scaling.get("beta_fast") - if beta_fast is not None and not isinstance(beta_fast, float): - logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}") - beta_slow = rope_scaling.get("beta_slow") - if beta_slow is not None and not isinstance(beta_slow, float): - logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}") - - if (beta_fast or 32) < (beta_slow or 1): - logger.warning( - f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} " - f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)" - ) + for key in ["rope_scaling", "local_rope_scaling"]: + rope_scaling = getattr(config, key, None) + if rope_scaling is not None: + rope_type = rope_scaling.get( + "rope_type", rope_scaling.get("type", None) + ) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "factor", "rope_theta"} + optional_keys = { + "attention_factor", + "beta_fast", + "beta_slow", + "original_max_position_embeddings", + "mscale", + "mscale_all_dim", + } + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + attention_factor = rope_scaling.get("attention_factor") + if attention_factor is not None and (not isinstance(attention_factor, float) or attention_factor < 0): + logger.warning( + f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + beta_fast = rope_scaling.get("beta_fast") + if beta_fast is not None and not isinstance(beta_fast, float): + logger.warning(f"`rope_scaling`'s beta_fast field must be a float, got {beta_fast}") + beta_slow = rope_scaling.get("beta_slow") + if beta_slow is not None and not isinstance(beta_slow, float): + logger.warning(f"`rope_scaling`'s beta_slow field must be a float, got {beta_slow}") + + if (beta_fast or 32) < (beta_slow or 1): + logger.warning( + f"`rope_scaling`'s beta_fast field must be greater than beta_slow, got beta_fast={beta_fast} " + f"(defaults to 32 if None) and beta_slow={beta_slow} (defaults to 1 if None)" + ) def _validate_longrope_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): - rope_scaling = config.rope_scaling - rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" - required_keys = {"rope_type", "short_factor", "long_factor"} - # TODO (joao): update logic for the inclusion of 
`original_max_position_embeddings` - optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"} - received_keys = set(rope_scaling.keys()) - _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) - - partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 - head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) - dim = int(head_dim * partial_rotary_factor) + for key in ["rope_scaling", "local_rope_scaling"]: + rope_scaling = getattr(config, key, None) + if rope_scaling is not None: + rope_type = rope_scaling.get( + "rope_type", rope_scaling.get("type", None) + ) # BC: "rope_type" was originally "type" + required_keys = {"rope_type", "short_factor", "long_factor", "rope_theta"} + # TODO (joao): update logic for the inclusion of `original_max_position_embeddings` + optional_keys = {"attention_factor", "factor", "original_max_position_embeddings"} + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, optional_keys, ignore_keys=ignore_keys) + + partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0 + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + dim = int(head_dim * partial_rotary_factor) + + short_factor = rope_scaling.get("short_factor") + if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor): + logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}") + if not len(short_factor) == dim // 2: + logger.warning( + f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}" + ) - short_factor = rope_scaling.get("short_factor") - if not isinstance(short_factor, list) and all(isinstance(x, (int, float)) for x in short_factor): - logger.warning(f"`rope_scaling`'s short_factor field must be a list of numbers, got {short_factor}") - if not len(short_factor) == dim // 2: - logger.warning(f"`rope_scaling`'s short_factor field must have length {dim // 2}, got {len(short_factor)}") - - long_factor = rope_scaling.get("long_factor") - if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor): - logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}") - if not len(long_factor) == dim // 2: - logger.warning(f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}") - - # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over - # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is - # unique to longrope (= undesirable) - if hasattr(config, "original_max_position_embeddings"): - logger.warning_once( - "This model has set a `original_max_position_embeddings` field, to be used together with " - "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling`" - "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, " - "as it is compatible with most model architectures." 
- ) - else: - factor = rope_scaling.get("factor") - if factor is None: - logger.warning("Missing required keys in `rope_scaling`: 'factor'") - elif not isinstance(factor, float) or factor < 1.0: - logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") - - attention_factor = rope_scaling.get("attention_factor") - if attention_factor is not None: - if not isinstance(attention_factor, float) or attention_factor < 0.0: + long_factor = rope_scaling.get("long_factor") + if not isinstance(long_factor, list) and all(isinstance(x, (int, float)) for x in long_factor): + logger.warning(f"`rope_scaling`'s long_factor field must be a list of numbers, got {long_factor}") + if not len(long_factor) == dim // 2: logger.warning( - f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + f"`rope_scaling`'s long_factor field must have length {dim // 2}, got {len(long_factor)}" ) + # Handle Phi3 divergence: prefer the use of `attention_factor` and/or `factor` over + # `original_max_position_embeddings` to compute internal variables. The latter lives outside `rope_scaling` and is + # unique to longrope (= undesirable) + if hasattr(config, "original_max_position_embeddings"): + logger.warning_once( + "This model has set a `original_max_position_embeddings` field, to be used together with " + "`max_position_embeddings` to determine a scaling factor. Please set the `factor` field of `rope_scaling`" + "with this ratio instead -- we recommend the use of this field over `original_max_position_embeddings`, " + "as it is compatible with most model architectures." + ) + else: + factor = rope_scaling.get("factor") + if factor is None: + logger.warning("Missing required keys in `rope_scaling`: 'factor'") + elif not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + attention_factor = rope_scaling.get("attention_factor") + if attention_factor is not None: + if not isinstance(attention_factor, float) or attention_factor < 0.0: + logger.warning( + f"`rope_scaling`'s attention_factor field must be a float greater than 0, got {attention_factor}" + ) + def _validate_llama3_parameters(config: PretrainedConfig, ignore_keys: Optional[set] = None): - rope_scaling = config.rope_scaling - rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None)) # BC: "rope_type" was originally "type" - required_keys = {"rope_type", "factor", "original_max_position_embeddings", "low_freq_factor", "high_freq_factor"} - received_keys = set(rope_scaling.keys()) - _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) - - factor = rope_scaling["factor"] - if factor is None or not isinstance(factor, float) or factor < 1.0: - logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") - - low_freq_factor = rope_scaling["low_freq_factor"] - high_freq_factor = rope_scaling["high_freq_factor"] - if low_freq_factor is None or not isinstance(low_freq_factor, float): - logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}") - if high_freq_factor is None or not isinstance(high_freq_factor, float): - logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}") - if high_freq_factor <= low_freq_factor: - logger.warning( - "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor=" - f"{high_freq_factor} and 
low_freq_factor={low_freq_factor}" - ) + for key in ["rope_scaling", "local_rope_scaling"]: + rope_scaling = getattr(config, key, None) + if rope_scaling is not None: + rope_type = rope_scaling.get( + "rope_type", rope_scaling.get("type", None) + ) # BC: "rope_type" was originally "type" + required_keys = { + "rope_type", + "factor", + "original_max_position_embeddings", + "low_freq_factor", + "high_freq_factor", + "rope_theta", + } + received_keys = set(rope_scaling.keys()) + _check_received_keys(rope_type, received_keys, required_keys, ignore_keys=ignore_keys) + + factor = rope_scaling["factor"] + if factor is None or not isinstance(factor, float) or factor < 1.0: + logger.warning(f"`rope_scaling`'s factor field must be a float >= 1, got {factor}") + + low_freq_factor = rope_scaling["low_freq_factor"] + high_freq_factor = rope_scaling["high_freq_factor"] + if low_freq_factor is None or not isinstance(low_freq_factor, float): + logger.warning(f"`rope_scaling`'s low_freq_factor field must be a float, got {low_freq_factor}") + if high_freq_factor is None or not isinstance(high_freq_factor, float): + logger.warning(f"`rope_scaling`'s high_freq_factor field must be a float, got {high_freq_factor}") + if high_freq_factor <= low_freq_factor: + logger.warning( + "`rope_scaling`'s high_freq_factor field must be greater than low_freq_factor, got high_freq_factor=" + f"{high_freq_factor} and low_freq_factor={low_freq_factor}" + ) - original_max_position_embeddings = rope_scaling["original_max_position_embeddings"] - if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int): - logger.warning( - "`rope_scaling`'s original_max_position_embeddings field must be an integer, got " - f"{original_max_position_embeddings}" - ) - if original_max_position_embeddings >= config.max_position_embeddings: - logger.warning( - "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got " - f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}" - ) + original_max_position_embeddings = rope_scaling["original_max_position_embeddings"] + if original_max_position_embeddings is None or not isinstance(original_max_position_embeddings, int): + logger.warning( + "`rope_scaling`'s original_max_position_embeddings field must be an integer, got " + f"{original_max_position_embeddings}" + ) + if original_max_position_embeddings >= config.max_position_embeddings: + logger.warning( + "`rope_scaling`'s original_max_position_embeddings field must be less than max_position_embeddings, got " + f"{original_max_position_embeddings} and max_position_embeddings={config.max_position_embeddings}" + ) # Like `ROPE_INIT_FUNCTIONS`, this validation function mapping can be dynamically updated for custom RoPE types. diff --git a/src/transformers/models/gemma3/configuration_gemma3.py b/src/transformers/models/gemma3/configuration_gemma3.py index c0184c1993d3..674f68f22f8b 100644 --- a/src/transformers/models/gemma3/configuration_gemma3.py +++ b/src/transformers/models/gemma3/configuration_gemma3.py @@ -81,8 +81,6 @@ class Gemma3TextConfig(PretrainedConfig): Beginning of stream token id. tie_word_embeddings (`bool`, *optional*, defaults to `True`): Whether to tie weight embeddings - rope_theta (`float`, *optional*, defaults to 1000000.0): - The base period of the RoPE embeddings. 
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. attention_dropout (`float`, *optional*, defaults to 0.0): @@ -134,6 +132,9 @@ class Gemma3TextConfig(PretrainedConfig): Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE `high_freq_factor` (`float`, *optional*): Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + local_rope_scaling (`Dict`, *optional*): + Dictionary equivalent to `config.rope_scaling` containing the scaling configuration for the RoPE embeddings used + in local attention. rope_local_base_freq (float, *optional*, defaults to 10000.0): The base period of the RoPE embeddings for local attention. @@ -164,6 +165,7 @@ class Gemma3TextConfig(PretrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } + attribute_map = {"local_rope_theta": "rope_local_base_freq"} def __init__( self, @@ -183,7 +185,6 @@ def __init__( eos_token_id=1, bos_token_id=2, tie_word_embeddings=True, - rope_theta=1_000_000.0, attention_bias=False, attention_dropout=0.0, query_pre_attn_scalar=256, @@ -192,6 +193,7 @@ def __init__( final_logit_softcapping=None, attn_logit_softcapping=None, rope_scaling=None, + local_rope_scaling=None, rope_local_base_freq=10_000.0, **kwargs, ): @@ -213,7 +215,6 @@ def __init__( self.initializer_range = initializer_range self.rms_norm_eps = rms_norm_eps self.use_cache = use_cache - self.rope_theta = rope_theta self.attention_bias = attention_bias self.attention_dropout = attention_dropout self.hidden_activation = hidden_activation @@ -223,10 +224,6 @@ def __init__( self.attn_logit_softcapping = attn_logit_softcapping self.layer_types = layer_types - self.rope_local_base_freq = rope_local_base_freq - self.rope_scaling = rope_scaling - rope_config_validation(self) - # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub self._sliding_window_pattern = kwargs.get("sliding_window_pattern", 6) @@ -237,6 +234,20 @@ def __init__( ] layer_type_validation(self.layer_types) + # Validate the correctness of rotary position embeddings parameters + # If the config was saved with a simple rope scaling dict, we need to convert to nested structure + # per RoPE type and raise a warning + rope_theta = getattr(self, "rope_theta", 1_000_000) + local_rope_scaling = local_rope_scaling if local_rope_scaling is not None else {"rope_type": "default"} + sliding_attention_rope = {"rope_theta": rope_local_base_freq, **local_rope_scaling} + full_attention_rope = {"rope_type": "default", "rope_theta": rope_theta} + if rope_scaling is not None: + full_attention_rope.update(**rope_scaling) + + rope_scaling = {"full_attention": full_attention_rope, "sliding_attention": sliding_attention_rope} + self.rope_scaling = {k: v for k, v in rope_scaling.items() if k in self.layer_types} + rope_config_validation(self) + @property def sliding_window_pattern(self): warnings.warn( diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 76eb0b697bda..a37cfadaa3d3 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -19,7 +19,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -import copy from collections.abc import Callable from dataclasses import dataclass from typing import Optional, Union @@ -35,7 +34,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_rope_utils import compute_rope_parameters, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ( @@ -157,38 +156,65 @@ def extra_repr(self): class Gemma3RotaryEmbedding(nn.Module): - def __init__(self, config: Gemma3TextConfig, device=None): + def __init__(self, config: Gemma3TextConfig, device=None, is_global=True): super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq + rope_inv_freqs, rope_types = compute_rope_parameters(self.config, device) + self.rope_type = rope_types + for layer_type in rope_inv_freqs: + self._update_inv_freq(rope_inv_freqs[layer_type][0], update_original=True, layer_type=layer_type) + setattr(self, f"{layer_type}_attention_scaling", rope_inv_freqs[layer_type][1]) @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + if getattr(self.config, "layer_types", None) is not None: + position_embeddings = {} + for layer_type in self.config.layer_types: + position_embeddings[layer_type] = self.apply_rope(x, position_ids, layer_type=layer_type) + else: + position_embeddings = self.apply_rope(x, position_ids) + return position_embeddings + + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def apply_rope(self, x, position_ids, layer_type=None): + inv_freq, attention_scaling = self._get_inv_freq(layer_type=layer_type) + inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling + cos = emb.cos() * attention_scaling + sin = emb.sin() * attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + def _update_inv_freq(self, new_inv_freq, update_original=False, layer_type=None): + if layer_type: + inv_freq_name = f"{layer_type}_inv_freq" + original_freq_name = f"{layer_type}_original_inv_freq" + else: + inv_freq_name = "inv_freq" + original_freq_name = "original_inv_freq" + + self.register_buffer(inv_freq_name, new_inv_freq, persistent=False) + if update_original: + setattr(self, original_freq_name, new_inv_freq) + + def _get_inv_freq(self, layer_type=None): + if layer_type is not None: + inv_freq = getattr(self, f"{layer_type}_inv_freq") + attention_scaling = getattr(self, f"{layer_type}_attention_scaling") + else: + inv_freq = self.inv_freq + attention_scaling = self.attention_scaling + + return inv_freq, attention_scaling + def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -368,8 +394,7 @@ def __init__(self, config: Gemma3TextConfig, layer_idx: int): def forward( self, hidden_states: torch.Tensor, - position_embeddings_global: torch.Tensor, - position_embeddings_local: torch.Tensor, + position_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, @@ -382,12 +407,6 @@ def forward( hidden_states = self.input_layernorm(hidden_states) - # apply global RoPE to non-sliding layer only - if self.self_attn.is_sliding: - position_embeddings = position_embeddings_local - else: - position_embeddings = position_embeddings_global - hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, position_embeddings=position_embeddings, @@ -465,13 +484,6 @@ def __init__(self, config: Gemma3TextConfig): self.rotary_emb = Gemma3RotaryEmbedding(config=config) self.gradient_checkpointing = False - # TODO: raushan fix this after RoPE refactor. For now we hack it by reassigning thetas - # when we want to create a local RoPE layer. 
Config defaults should hold values for global RoPE - config = copy.deepcopy(config) - config.rope_theta = config.rope_local_base_freq - config.rope_scaling = {"rope_type": "default"} - self.rotary_emb_local = Gemma3RotaryEmbedding(config=config) - # Initialize weights and apply final processing self.post_init() @@ -543,8 +555,7 @@ def forward( hidden_states = inputs_embeds # create position embeddings to be shared across the decoder layers - position_embeddings_global = self.rotary_emb(hidden_states, position_ids) - position_embeddings_local = self.rotary_emb_local(hidden_states, position_ids) + position_embeddings = self.rotary_emb(hidden_states, position_ids) # decoder layers all_hidden_states = () if output_hidden_states else None @@ -556,8 +567,7 @@ def forward( layer_outputs = decoder_layer( hidden_states, - position_embeddings_global=position_embeddings_global, - position_embeddings_local=position_embeddings_local, + position_embeddings=position_embeddings[decoder_layer.attention_type], attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, past_key_value=past_key_values, diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index c2ad52f10809..82eb1a05210a 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import copy import warnings from collections.abc import Callable from typing import Any, Optional, Union @@ -159,6 +158,9 @@ class Gemma3TextConfig(Gemma2Config, PretrainedConfig): Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE `high_freq_factor` (`float`, *optional*): Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE + local_rope_scaling (`Dict`, *optional*): + Dictionary equivalent to `config.rope_scaling` containing the scaling configuration for the RoPE embeddings used + in local attention. rope_local_base_freq (float, *optional*, defaults to 10000.0): The base period of the RoPE embeddings for local attention. @@ -174,6 +176,7 @@ class Gemma3TextConfig(Gemma2Config, PretrainedConfig): """ model_type = "gemma3_text" + attribute_map = {"local_rope_theta": "rope_local_base_freq"} def __init__( self, @@ -202,6 +205,7 @@ def __init__( final_logit_softcapping=None, attn_logit_softcapping=None, rope_scaling=None, + local_rope_scaling=None, rope_local_base_freq=10_000.0, **kwargs, ): @@ -235,6 +239,7 @@ def __init__( self.rope_local_base_freq = rope_local_base_freq self.rope_scaling = rope_scaling + self.local_rope_scaling = local_rope_scaling rope_config_validation(self) # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub @@ -386,7 +391,7 @@ def __init__(self, dim: int, eps: float = 1e-6): class Gemma3RotaryEmbedding(Gemma2RotaryEmbedding): - def __init__(self, config: Gemma3TextConfig, device=None): + def __init__(self, config: Gemma3TextConfig, device=None, is_global=True): super().__init__(config) @@ -543,12 +548,7 @@ def __init__(self, config: Gemma3TextConfig): config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5 ) - # TODO: raushan fix this after RoPE refactor. For now we hack it by reassigning thetas - # when we want to create a local RoPE layer. 
Config defaults should hold values for global RoPE - config = copy.deepcopy(config) - config.rope_theta = config.rope_local_base_freq - config.rope_scaling = {"rope_type": "default"} - self.rotary_emb_local = Gemma3RotaryEmbedding(config=config) + self.rotary_emb_local = Gemma3RotaryEmbedding(config=config, is_global=False) def forward( self, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 11bcb93f8666..b04ae1040afe 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -37,7 +37,7 @@ BaseModelOutputWithPast, CausalLMOutputWithPast, ) -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_rope_utils import compute_rope_parameters, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging @@ -72,36 +72,63 @@ def extra_repr(self): class LlamaRotaryEmbedding(nn.Module): def __init__(self, config: LlamaConfig, device=None): super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + inv_freq, attention_scaling = compute_rope_parameters(self.config, device) + self.rope_type = config.rope_scaling_dict["rope_type"] self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq + self.original_inv_freq = inv_freq + self.attention_scaling = attention_scaling @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + if getattr(self.config, "layer_types", None) is not None: + position_embeddings = {} + for layer_type in self.config.layer_types: + position_embeddings[layer_type] = self.apply_rope(x, position_ids, layer_type=layer_type) + else: + position_embeddings = self.apply_rope(x, position_ids) + return position_embeddings + + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def apply_rope(self, x, position_ids, layer_type=None): + inv_freq, attention_scaling = self._get_inv_freq(layer_type=layer_type) + inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling + cos = emb.cos() * attention_scaling + sin = emb.sin() * attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + def _update_inv_freq(self, new_inv_freq, update_original=False, layer_type=None): + if layer_type: + inv_freq_name = f"{layer_type}_inv_freq" + original_freq_name = f"{layer_type}_original_inv_freq" + else: + inv_freq_name = "inv_freq" + original_freq_name = "original_inv_freq" + + self.register_buffer(inv_freq_name, new_inv_freq, persistent=False) + if update_original: + setattr(self, original_freq_name, new_inv_freq) + + def _get_inv_freq(self, layer_type=None): + if layer_type is not None: + inv_freq = getattr(self, f"{layer_type}_inv_freq") + attention_scaling = getattr(self, f"{layer_type}_attention_scaling") + else: + inv_freq = self.inv_freq + attention_scaling = self.attention_scaling + + return inv_freq, attention_scaling + def rotate_half(x): """Rotates half the hidden dims of the input.""" diff --git a/src/transformers/models/modernbert/configuration_modernbert.py b/src/transformers/models/modernbert/configuration_modernbert.py index 3b0da20ad203..1748f1ff9b3c 100644 --- a/src/transformers/models/modernbert/configuration_modernbert.py +++ b/src/transformers/models/modernbert/configuration_modernbert.py @@ -21,7 +21,8 @@ from typing import Literal -from ...configuration_utils import PretrainedConfig +from ...configuration_utils import PretrainedConfig, layer_type_validation +from ...modeling_rope_utils import rope_config_validation class ModernBertConfig(PretrainedConfig): @@ -69,8 +70,6 @@ class ModernBertConfig(PretrainedConfig): Classification token id. sep_token_id (`int`, *optional*, defaults to 50282): Separation token id. - global_rope_theta (`float`, *optional*, defaults to 160000.0): - The base period of the global RoPE embeddings. attention_bias (`bool`, *optional*, defaults to `False`): Whether to use a bias in the query, key, value and output projection layers during self-attention. 
attention_dropout (`float`, *optional*, defaults to 0.0): @@ -150,7 +149,6 @@ def __init__( bos_token_id=50281, cls_token_id=50281, sep_token_id=50282, - global_rope_theta=160000.0, attention_bias=False, attention_dropout=0.0, global_attn_every_n_layers=3, @@ -169,6 +167,7 @@ def __init__( sparse_pred_ignore_index=-100, reference_compile=None, repad_logits_with_grad=False, + rope_scaling=None, **kwargs, ): super().__init__( @@ -189,13 +188,11 @@ def __init__( self.initializer_cutoff_factor = initializer_cutoff_factor self.norm_eps = norm_eps self.norm_bias = norm_bias - self.global_rope_theta = global_rope_theta self.attention_bias = attention_bias self.attention_dropout = attention_dropout self.hidden_activation = hidden_activation self.global_attn_every_n_layers = global_attn_every_n_layers self.local_attention = local_attention - self.local_rope_theta = local_rope_theta self.embedding_dropout = embedding_dropout self.mlp_bias = mlp_bias self.mlp_dropout = mlp_dropout @@ -210,6 +207,25 @@ def __init__( self.reference_compile = reference_compile self.repad_logits_with_grad = repad_logits_with_grad + self.layer_types = [ + "sliding_attention" if bool((i + 1) % self.global_attn_every_n_layers) else "full_attention" + for i in range(self.num_hidden_layers) + ] + layer_type_validation(self.layer_types) + + # Validate the correctness of rotary position embeddings parameters + # If the config was saved with a simple rope scaling dict, we need to convert to nested structure + # per RoPE type and raise a warning + rope_theta = getattr(self, "rope_theta", 160000) + sliding_attention_rope = {"rope_type": "default", "rope_theta": local_rope_theta} + full_attention_rope = {"rope_type": "default", "rope_theta": rope_theta} + if rope_scaling is not None: + full_attention_rope.update(**rope_scaling) + + rope_scaling = {"full_attention": full_attention_rope, "sliding_attention": sliding_attention_rope} + self.rope_scaling = {k: v for k, v in rope_scaling.items() if k in self.layer_types} + rope_config_validation(self) + if self.classifier_pooling not in ["cls", "mean"]: raise ValueError( f'Invalid value for `classifier_pooling`, should be either "cls" or "mean", but is {self.classifier_pooling}.' 
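Illustrative note (not part of the patch): with this change, both `Gemma3TextConfig` and `ModernBertConfig` normalize their RoPE settings into a nested `rope_scaling` dict keyed by the entries of `config.layer_types`, each entry carrying its own `rope_type` and `rope_theta`. A minimal Python sketch of the resulting structure, with example theta values only:

    # Sketch of the nested per-layer-type `rope_scaling` layout built by the updated configs.
    # Keys mirror `config.layer_types`; the theta values below are placeholders, not defaults.
    rope_scaling = {
        "full_attention": {"rope_type": "default", "rope_theta": 160_000.0},
        "sliding_attention": {"rope_type": "default", "rope_theta": 10_000.0},
    }
    layer_types = ["sliding_attention", "sliding_attention", "full_attention"]
    # The updated validators require every layer type to have a matching entry:
    assert set(layer_types) <= set(rope_scaling)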
diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py index a76a6fead76d..d090ae73a676 100644 --- a/src/transformers/models/modernbert/modeling_modernbert.py +++ b/src/transformers/models/modernbert/modeling_modernbert.py @@ -38,9 +38,10 @@ SequenceClassifierOutput, TokenClassifierOutput, ) -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_rope_utils import compute_rope_parameters, dynamic_rope_update from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, is_flash_attn_2_available, logging +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, is_flash_attn_2_available, logging from ...utils.import_utils import is_triton_available from .configuration_modernbert import ModernBertConfig @@ -243,36 +244,63 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class ModernBertRotaryEmbedding(nn.Module): def __init__(self, config: ModernBertConfig, device=None): super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq + rope_inv_freqs, rope_types = compute_rope_parameters(self.config, device) + self.rope_type = rope_types + for layer_type in rope_inv_freqs: + self._update_inv_freq(rope_inv_freqs[layer_type][0], update_original=True, layer_type=layer_type) + setattr(self, f"{layer_type}_attention_scaling", rope_inv_freqs[layer_type][1]) @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) def forward(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + if getattr(self.config, "layer_types", None) is not None: + position_embeddings = {} + for layer_type in self.config.layer_types: + position_embeddings[layer_type] = self.apply_rope(x, position_ids, layer_type=layer_type) + else: + position_embeddings = self.apply_rope(x, position_ids) + return position_embeddings + + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def apply_rope(self, x, position_ids, layer_type=None): + inv_freq, attention_scaling = self._get_inv_freq(layer_type=layer_type) + inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) position_ids_expanded = position_ids[:, None, :].float() device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" with torch.autocast(device_type=device_type, enabled=False): # Force float32 freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling + cos = emb.cos() * attention_scaling + sin = emb.sin() * attention_scaling return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + def _update_inv_freq(self, new_inv_freq, update_original=False, layer_type=None): + if layer_type: + inv_freq_name = f"{layer_type}_inv_freq" + original_freq_name = f"{layer_type}_original_inv_freq" + else: + inv_freq_name = "inv_freq" + original_freq_name = "original_inv_freq" + + self.register_buffer(inv_freq_name, new_inv_freq, persistent=False) + if update_original: + setattr(self, original_freq_name, new_inv_freq) + + def _get_inv_freq(self, layer_type=None): + if layer_type is not None: + inv_freq = getattr(self, f"{layer_type}_inv_freq") + attention_scaling = getattr(self, f"{layer_type}_attention_scaling") + else: + inv_freq = self.inv_freq + attention_scaling = self.attention_scaling + + return inv_freq, attention_scaling + def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -317,11 +345,12 @@ def eager_attention_forward( local_attention: tuple[int, int], bs: int, dim: int, + position_embeddings: Optional[torch.Tensor], output_attentions: Optional[bool] = False, **_kwargs, ) -> Union[tuple[torch.Tensor, torch.Tensor], tuple[torch.Tensor]]: # qkv: [batch_size, seqlen, 3, nheads, headdim] - cos, sin = module.rotary_emb(qkv, position_ids=position_ids) + cos, sin = position_embeddings query, key, value = qkv.transpose(3, 1).unbind(dim=2) # query, key, value: [batch_size, heads, seq_len, head_dim] query, key = apply_rotary_pos_emb(query, key, cos, sin) @@ -358,7 +387,7 @@ def flash_attention_forward( **_kwargs, ) -> tuple[torch.Tensor]: # (total_seqlen, 3, nheads, headdim) - qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) + qkv = module.rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen) convert_dtype = qkv.dtype not in (torch.float16, torch.bfloat16) if convert_dtype: @@ -397,10 +426,11 @@ def sdpa_attention_forward( local_attention: tuple[int, int], bs: int, dim: int, + position_embeddings: Optional[torch.Tensor], **_kwargs, ) -> tuple[torch.Tensor]: # qkv: [batch_size, seqlen, 3, nheads, headdim] - cos, sin = module.rotary_emb(qkv, position_ids=position_ids) + cos, sin = position_embeddings query, key, value = qkv.transpose(3, 1).unbind(dim=2) # query, key, value: [batch_size, heads, seq_len, head_dim] query, key = apply_rotary_pos_emb(query, key, cos, sin) @@ -462,17 +492,8 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): else: self.local_attention = (-1, -1) - max_position_embeddings = config.max_position_embeddings - if self.local_attention != (-1, -1): - rope_theta = config.global_rope_theta if config.local_rope_theta is None else config.local_rope_theta - max_position_embeddings = config.local_attention - if config._attn_implementation == "flash_attention_2": - self.rotary_emb = 
ModernBertUnpaddedRotaryEmbedding( - dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta - ) - else: - self.rotary_emb = ModernBertRotaryEmbedding(config=config) + self.rotary_emb = ModernBertUnpaddedRotaryEmbedding(config) self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() @@ -482,6 +503,7 @@ def forward( self, hidden_states: torch.Tensor, output_attentions: Optional[bool] = False, + position_embeddings: Optional[torch.Tensor] = None, **kwargs, ) -> torch.Tensor: qkv = self.Wqkv(hidden_states) @@ -495,11 +517,11 @@ def forward( attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation]( self, qkv=qkv, - rotary_emb=self.rotary_emb, local_attention=self.local_attention, bs=bs, dim=self.all_head_size, output_attentions=output_attentions, + position_embeddings=position_embeddings, **kwargs, ) hidden_states = attn_outputs[0] @@ -519,6 +541,7 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): self.attn = ModernBertAttention(config=config, layer_id=layer_id) self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) self.mlp = ModernBertMLP(config) + self.attention_type = config.layer_types[layer_id] @torch.compile(dynamic=True) def compiled_mlp(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -533,6 +556,8 @@ def forward( cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None, output_attentions: Optional[bool] = False, + position_embeddings: Optional[torch.Tensor] = None, + **kwargs, ) -> torch.Tensor: attn_outputs = self.attn( self.attn_norm(hidden_states), @@ -542,6 +567,8 @@ def forward( cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, output_attentions=output_attentions, + position_embeddings=position_embeddings, + **kwargs, ) hidden_states = hidden_states + attn_outputs[0] mlp_output = ( @@ -757,6 +784,7 @@ def __init__(self, config: ModernBertConfig): [ModernBertEncoderLayer(config, layer_id) for layer_id in range(config.num_hidden_layers)] ) self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) + self.rotary_emb = ModernBertRotaryEmbedding(config=config) self.gradient_checkpointing = False self.post_init() @@ -782,6 +810,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple[torch.Tensor, ...], BaseModelOutput]: r""" sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -848,6 +877,7 @@ def forward( ) hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds) + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) for encoder_layer in self.layers: if output_hidden_states: @@ -861,6 +891,8 @@ def forward( cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, output_attentions=output_attentions, + position_embeddings=position_embeddings[encoder_layer.attention_type], + **kwargs, ) hidden_states = layer_outputs[0] if output_attentions and len(layer_outputs) > 1: @@ -981,7 +1013,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - **kwargs, + **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple[torch.Tensor], MaskedLMOutput]: r""" sliding_window_mask (`torch.Tensor` of shape 
`(batch_size, sequence_length)`, *optional*): @@ -1038,6 +1070,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + **kwargs, ) last_hidden_state = outputs[0] @@ -1113,7 +1146,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - **kwargs, + **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]: r""" sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1152,6 +1185,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + **kwargs, ) last_hidden_state = outputs[0] @@ -1236,6 +1270,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]: r""" sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1272,6 +1307,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + **kwargs, ) last_hidden_state = outputs[0] @@ -1326,7 +1362,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - **kwargs, + **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]: r""" sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -1360,6 +1396,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + **kwargs, ) last_hidden_state = outputs[0] diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index 254b5d3163f7..b1011b3e7482 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -659,20 +659,20 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): if layer_id % config.global_attn_every_n_layers != 0: self.local_attention = (config.local_attention // 2, config.local_attention // 2) + is_global = True if config.local_rope_theta is None else False + max_position_embeddings = config.local_attention else: self.local_attention = (-1, -1) - - max_position_embeddings = config.max_position_embeddings - if self.local_attention != (-1, -1): - rope_theta = config.global_rope_theta if config.local_rope_theta is None else config.local_rope_theta - max_position_embeddings = config.local_attention + is_global = True + max_position_embeddings = config.max_position_embeddings if config._attn_implementation == "flash_attention_2": + rope_theta = config.global_rope_theta if is_global else config.local_rope_theta self.rotary_emb = ModernBertUnpaddedRotaryEmbedding( dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta ) else: - self.rotary_emb = ModernBertRotaryEmbedding(config=config) + self.rotary_emb = ModernBertRotaryEmbedding(config=config, is_global=is_global) self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() From 9ce72cd6805cfa2a3ffc7f781b7f449e69e94584 Mon Sep 17 
00:00:00 2001 From: raushan Date: Thu, 31 Jul 2025 17:33:14 +0200 Subject: [PATCH 2/4] update --- src/transformers/modeling_rope_utils.py | 55 +++++++------------------ 1 file changed, 15 insertions(+), 40 deletions(-) diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py index d4dad0b2780c..a1df6a386579 100644 --- a/src/transformers/modeling_rope_utils.py +++ b/src/transformers/modeling_rope_utils.py @@ -27,10 +27,6 @@ import torch -def extract_rope_type_from_config(*args, **kwargs): - pass - - def _get_rope_scaling_dict(config, layer_type: str) -> dict: """Get the RoPE scaling dictionary for the specified layer.""" rope_scaling_dict = config.rope_scaling_dict @@ -139,9 +135,6 @@ def _compute_default_rope_parameters( The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length. Unused for this type of RoPE. - is_global (`bool`, *optional*, defaults to `True`): - Whether to use global or local rope theta from config. For local rope theta, - the config object should have attribute - `local_rope_theta`. Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). @@ -174,9 +167,6 @@ def _compute_linear_scaling_rope_parameters( The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length. Unused for this type of RoPE. - is_global (`bool`, *optional*, defaults to `True`): - Whether to use global or local rope theta from config. For local rope theta, - the config object should have attributes - `local_rope_theta` and `local_rope_scaling`. Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). @@ -208,9 +198,6 @@ def _compute_dynamic_ntk_parameters( The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length, used to update the dynamic RoPE at inference time. - is_global (`bool`, *optional*, defaults to `True`): - Whether to use global or local rope theta from config. For local rope theta, - the config object should have attributes - `local_rope_theta` and `local_rope_scaling`. Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). @@ -258,9 +245,6 @@ def _compute_yarn_parameters( The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length. Unused for this type of RoPE. - is_global (`bool`, *optional*, defaults to `True`): - Whether to use global or local rope theta from config. For local rope theta, - the config object should have attributes - `local_rope_theta` and `local_rope_scaling`. Returns: Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the post-processing scaling factor applied to the computed cos/sin. @@ -353,9 +337,6 @@ def _compute_longrope_parameters( The device to use for initialization of the inverse frequencies. seq_len (`int`, *optional*): The current sequence length. - is_global (`bool`, *optional*, defaults to `True`): - Whether to use global or local rope theta from config. 
For local rope theta,
-            the config object should have attributes - `local_rope_theta` and `local_rope_scaling`.
     Returns:
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin.
@@ -414,9 +395,6 @@ def _compute_llama3_parameters(
             The device to use for initialization of the inverse frequencies.
         seq_len (`int`, *optional*):
             The current sequence length. Unused for this type of RoPE.
-        is_global (`bool`, *optional*, defaults to `True`):
-            Whether to use global or local rope theta from config. For local rope theta,
-            the config object should have attributes - `local_rope_theta` and `local_rope_scaling`.
     Returns:
         Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
         post-processing scaling factor applied to the computed cos/sin.
@@ -462,32 +440,29 @@ def extract_rope_scaling_dict_from_config(config):
     "Helper to extract the rope type from config, while handling BC and local/global keys"
     # The RoPE scaling dict might be serialized differently in older versions, which we need to support.
-    # Case 1: `config.rope_scaling` is a simple dict with values to configure RoPE type. Deprecated.
+    # Case 1: `config.rope_scaling` is a simple dict with values to configure RoPE type. Used when the model
+    # has a single attention type, i.e. `config.layer_types` is `None`
     # Case 2: `config.rope_scaling` is a dict where keys define different RoPE configurations used by the model.
-    # For example `rope_scaling={"global": {}, "local": {}}` to alternate between global and local attention layers.
+    # For example `rope_scaling={"full_attention": {}, "sliding_attention": {}}` to alternate between global and local attention layers.
     rope_scaling_dict = getattr(config, "rope_scaling", None)
 
+    # Case 2: `config.layer_types` is defined, so every layer type needs its own RoPE config. For BC, if users
+    # still pass a simple Case 1 dict, we duplicate it for all layer types
     if getattr(config, "layer_types", None) is not None:
-        if rope_scaling_dict is not None and ("type" in rope_scaling_dict or "rope_type" in rope_scaling_dict):
-            # if there is a 'type' field, copy it it to 'rope_type'.
-            if "type" in rope_scaling_dict:
-                rope_scaling_dict["rope_type"] = rope_scaling_dict.pop("type")
-            rope_scaling_dict["rope_theta"] = config.rope_theta
-            rope_scaling_dict = {"full_attention": rope_scaling_dict}
-        elif rope_scaling_dict is None:
-            rope_scaling_dict = {"full_attention": {"rope_type": "default", "rope_theta": config.rope_theta}}
-        else:
-            for rope_key in rope_scaling_dict:
-                if "type" in rope_scaling_dict[rope_key]:
-                    rope_scaling_dict[rope_key]["rope_type"] = rope_scaling_dict[rope_key].pop("type")
+        if rope_scaling_dict is None:
+            default_rope_params = {"rope_type": "default", "rope_theta": config.rope_theta}
+            rope_scaling_dict = dict.fromkeys(config.layer_types, default_rope_params)
+        elif set(config.layer_types) != set(rope_scaling_dict.keys()):
+            rope_scaling_dict = dict.fromkeys(config.layer_types, rope_scaling_dict)
+
+    # Case 1: a single RoPE config for the whole model, just make sure `rope_scaling_dict` has all default attributes defined
     else:
         if rope_scaling_dict is None:
             rope_scaling_dict = {"rope_type": "default", "rope_theta": config.rope_theta}
         else:
-            if "type" in rope_scaling_dict:
-                rope_scaling_dict["rope_type"] = rope_scaling_dict.pop("type")
-            if "rope_theta" not in rope_scaling_dict:
-                rope_scaling_dict["rope_theta"] = config.rope_theta
+            rope_type = rope_scaling_dict.get("rope_type", rope_scaling_dict.get("type", "default"))
+            rope_theta = config.rope_theta if hasattr(config, "rope_theta") else rope_scaling_dict["rope_theta"]
+            rope_scaling_dict.update({"rope_type": rope_type, "rope_theta": rope_theta})
 
     config.rope_scaling_dict = rope_scaling_dict
     return rope_scaling_dict

From 62a218fe150cbfbe4cf17ed285d7ee7a11fa81fd Mon Sep 17 00:00:00 2001
From: raushan
Date: Thu, 31 Jul 2025 17:35:54 +0200
Subject: [PATCH 3/4] revert this

---
 src/transformers/models/gemma3/configuration_gemma3.py | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/src/transformers/models/gemma3/configuration_gemma3.py b/src/transformers/models/gemma3/configuration_gemma3.py
index 674f68f22f8b..a2ebd9367f62 100644
--- a/src/transformers/models/gemma3/configuration_gemma3.py
+++ b/src/transformers/models/gemma3/configuration_gemma3.py
@@ -132,9 +132,6 @@ class Gemma3TextConfig(PretrainedConfig):
             Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
         `high_freq_factor` (`float`, *optional*):
             Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
-        local_rope_scaling (`Dict`, *optional*):
-            Dictionary equivalent to `config.rope_scaling` containing the scaling configuration for the RoPE embeddings used
-            in local attention.
         rope_local_base_freq (float, *optional*, defaults to 10000.0):
             The base period of the RoPE embeddings for local attention.
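The backward-compatibility handling in the rewritten extract_rope_scaling_dict_from_config above boils down to: if the model declares layer_types but the user still passes a flat rope_scaling dict, the flat dict is reused for every layer type. A minimal, self-contained sketch of that normalization (SimpleNamespace stands in for a real config object; this mirrors the added logic rather than calling the transformers helper):

from types import SimpleNamespace

config = SimpleNamespace(
    rope_theta=1_000_000.0,
    layer_types=["sliding_attention", "full_attention"],
    rope_scaling={"rope_type": "linear", "factor": 2.0},  # old-style flat dict
)

rope_scaling_dict = getattr(config, "rope_scaling", None)
if getattr(config, "layer_types", None) is not None:
    if rope_scaling_dict is None:
        default_rope_params = {"rope_type": "default", "rope_theta": config.rope_theta}
        rope_scaling_dict = dict.fromkeys(config.layer_types, default_rope_params)
    elif set(config.layer_types) != set(rope_scaling_dict.keys()):
        # flat dict: reuse the same parameters for every layer type
        rope_scaling_dict = dict.fromkeys(config.layer_types, rope_scaling_dict)

print(rope_scaling_dict)
# {'sliding_attention': {'rope_type': 'linear', 'factor': 2.0},
#  'full_attention': {'rope_type': 'linear', 'factor': 2.0}}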
@@ -193,7 +190,6 @@ def __init__( final_logit_softcapping=None, attn_logit_softcapping=None, rope_scaling=None, - local_rope_scaling=None, rope_local_base_freq=10_000.0, **kwargs, ): @@ -235,11 +231,9 @@ def __init__( layer_type_validation(self.layer_types) # Validate the correctness of rotary position embeddings parameters - # If the config was saved with a simple rope scaling dict, we need to convert to nested structure - # per RoPE type and raise a warning + # The config was saved with a simple rope scaling dict, we need to convert to nested structure per RoPE type rope_theta = getattr(self, "rope_theta", 1_000_000) - local_rope_scaling = local_rope_scaling if local_rope_scaling is not None else {"rope_type": "default"} - sliding_attention_rope = {"rope_theta": rope_local_base_freq, **local_rope_scaling} + sliding_attention_rope = {"rope_type": "default", "rope_theta": rope_local_base_freq} full_attention_rope = {"rope_type": "default", "rope_theta": rope_theta} if rope_scaling is not None: full_attention_rope.update(**rope_scaling) From 5f73d3b3eb85878fac7eb2bd734c4cbf15b0e511 Mon Sep 17 00:00:00 2001 From: raushan Date: Thu, 31 Jul 2025 17:38:34 +0200 Subject: [PATCH 4/4] revert as well --- .../models/gemma3/configuration_gemma3.py | 1 - .../models/gemma3/modeling_gemma3.py | 2 +- .../models/gemma3/modular_gemma3.py | 16 +++++++------- .../models/modernbert/modular_modernbert.py | 21 ++++++++++++------- 4 files changed, 22 insertions(+), 18 deletions(-) diff --git a/src/transformers/models/gemma3/configuration_gemma3.py b/src/transformers/models/gemma3/configuration_gemma3.py index a2ebd9367f62..ed6f4d3eee24 100644 --- a/src/transformers/models/gemma3/configuration_gemma3.py +++ b/src/transformers/models/gemma3/configuration_gemma3.py @@ -162,7 +162,6 @@ class Gemma3TextConfig(PretrainedConfig): "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), "norm": (["hidden_states"], ["hidden_states"]), } - attribute_map = {"local_rope_theta": "rope_local_base_freq"} def __init__( self, diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index a37cfadaa3d3..446b6f50a722 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -156,7 +156,7 @@ def extra_repr(self): class Gemma3RotaryEmbedding(nn.Module): - def __init__(self, config: Gemma3TextConfig, device=None, is_global=True): + def __init__(self, config: Gemma3TextConfig, device=None): super().__init__() self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index 82eb1a05210a..c2ad52f10809 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import copy import warnings from collections.abc import Callable from typing import Any, Optional, Union @@ -158,9 +159,6 @@ class Gemma3TextConfig(Gemma2Config, PretrainedConfig): Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE `high_freq_factor` (`float`, *optional*): Only used with 'llama3'. 
Scaling factor applied to high frequency components of the RoPE - local_rope_scaling (`Dict`, *optional*): - Dictionary equivalent to `config.rope_scaling` containing the scaling configuration for the RoPE embeddings used - in local attention. rope_local_base_freq (float, *optional*, defaults to 10000.0): The base period of the RoPE embeddings for local attention. @@ -176,7 +174,6 @@ class Gemma3TextConfig(Gemma2Config, PretrainedConfig): """ model_type = "gemma3_text" - attribute_map = {"local_rope_theta": "rope_local_base_freq"} def __init__( self, @@ -205,7 +202,6 @@ def __init__( final_logit_softcapping=None, attn_logit_softcapping=None, rope_scaling=None, - local_rope_scaling=None, rope_local_base_freq=10_000.0, **kwargs, ): @@ -239,7 +235,6 @@ def __init__( self.rope_local_base_freq = rope_local_base_freq self.rope_scaling = rope_scaling - self.local_rope_scaling = local_rope_scaling rope_config_validation(self) # BC -> the pattern used to be a simple int, and it's still present in configs on the Hub @@ -391,7 +386,7 @@ def __init__(self, dim: int, eps: float = 1e-6): class Gemma3RotaryEmbedding(Gemma2RotaryEmbedding): - def __init__(self, config: Gemma3TextConfig, device=None, is_global=True): + def __init__(self, config: Gemma3TextConfig, device=None): super().__init__(config) @@ -548,7 +543,12 @@ def __init__(self, config: Gemma3TextConfig): config.vocab_size, config.hidden_size, self.padding_idx, embed_scale=self.config.hidden_size**0.5 ) - self.rotary_emb_local = Gemma3RotaryEmbedding(config=config, is_global=False) + # TODO: raushan fix this after RoPE refactor. For now we hack it by reassigning thetas + # when we want to create a local RoPE layer. Config defaults should hold values for global RoPE + config = copy.deepcopy(config) + config.rope_theta = config.rope_local_base_freq + config.rope_scaling = {"rope_type": "default"} + self.rotary_emb_local = Gemma3RotaryEmbedding(config=config) def forward( self, diff --git a/src/transformers/models/modernbert/modular_modernbert.py b/src/transformers/models/modernbert/modular_modernbert.py index b1011b3e7482..3648e30ac942 100644 --- a/src/transformers/models/modernbert/modular_modernbert.py +++ b/src/transformers/models/modernbert/modular_modernbert.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import copy import math from contextlib import nullcontext from typing import Literal, Optional, Union @@ -659,20 +660,21 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): if layer_id % config.global_attn_every_n_layers != 0: self.local_attention = (config.local_attention // 2, config.local_attention // 2) - is_global = True if config.local_rope_theta is None else False + rope_theta = config.local_rope_theta if config.local_rope_theta is not None else config.global_rope_theta max_position_embeddings = config.local_attention else: self.local_attention = (-1, -1) - is_global = True max_position_embeddings = config.max_position_embeddings + rope_theta = config.global_rope_theta if config._attn_implementation == "flash_attention_2": - rope_theta = config.global_rope_theta if is_global else config.local_rope_theta self.rotary_emb = ModernBertUnpaddedRotaryEmbedding( dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta ) else: - self.rotary_emb = ModernBertRotaryEmbedding(config=config, is_global=is_global) + config_copy = copy.deepcopy(config) + config_copy.rope_theta = rope_theta + self.rotary_emb = ModernBertRotaryEmbedding(config=config_copy) self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() @@ -811,7 +813,9 @@ def init_weight(module: nn.Module, std: float): if module.bias is not None: module.bias.data.zero_() - def set_attention_implementation(self, attn_implementation: Union[dict, str]): + def _check_and_adjust_attn_implementation( + self, attn_implementation: Optional[str], is_init_check: bool = False + ) -> str: """ Checks and dispatches to hhe requested attention implementation. """ @@ -820,16 +824,17 @@ def set_attention_implementation(self, attn_implementation: Union[dict, str]): # ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't # need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check. - requested_attn_implementation = self._check_attn_implementation(attn_implementation) try: attn_implementation = ( "flash_attention_2" - if requested_attn_implementation is None and self._flash_attn_2_can_dispatch() + if attn_implementation is None and self._flash_attn_2_can_dispatch() else attn_implementation ) except (ValueError, ImportError): pass - return super().set_attention_implementation(attn_implementation=attn_implementation) + return super()._check_and_adjust_attn_implementation( + attn_implementation=attn_implementation, is_init_check=is_init_check + ) def _maybe_set_compile(self): if self.config.reference_compile is False:
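The rotary-embedding hunks in this last patch (both the Gemma3 and ModernBERT modular files) rely on the same workaround for local-attention RoPE: deep-copy the config, override rope_theta on the copy, and build the rotary embedding from that copy so the original config keeps its global defaults. A minimal standalone sketch of the pattern, using hypothetical SimpleConfig/SimpleRotaryEmbedding classes rather than the transformers ones:

import copy

import torch


class SimpleConfig:
    def __init__(self, rope_theta=160000.0, local_rope_theta=10000.0, head_dim=64):
        self.rope_theta = rope_theta
        self.local_rope_theta = local_rope_theta
        self.head_dim = head_dim


class SimpleRotaryEmbedding(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        # Default RoPE inverse frequencies derived from the config's rope_theta.
        inv_freq = 1.0 / (
            config.rope_theta ** (torch.arange(0, config.head_dim, 2).float() / config.head_dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)


config = SimpleConfig()
global_rope = SimpleRotaryEmbedding(config)

# Local-attention variant: override the base period on a copy only.
local_config = copy.deepcopy(config)
local_config.rope_theta = config.local_rope_theta
local_rope = SimpleRotaryEmbedding(local_config)

assert config.rope_theta == 160000.0  # the original config is untouched
print(global_rope.inv_freq[:2], local_rope.inv_freq[:2])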