Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion src/transformers/modeling_rope_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import math
import warnings
from functools import wraps
from typing import TYPE_CHECKING, Optional, TypedDict

Expand Down Expand Up @@ -656,7 +657,7 @@ def standardize_rope_params(self):
# Move `rope_theta` and `partial_rotary_factor` to the params dict, if not there yet
rope_theta = getattr(self, "rope_theta", None)
partial_rotary_factor = getattr(self, "partial_rotary_factor", None)
rope_parameters = self.rope_parameters
rope_parameters = self.rope_parameters or {}

# Case 1: RoPE param keys do not intersect with possible `layer_types` -> one global dict
if getattr(self, "layer_types", None) is None or not set(rope_parameters.keys()).issubset(self.layer_types):
Expand Down Expand Up @@ -913,3 +914,20 @@ def _check_received_keys(
unused_keys = received_keys - required_keys
if unused_keys:
logger.warning(f"Unrecognized keys in `rope_parameters` for 'rope_type'='{rope_type}': {unused_keys}")


def rope_config_validation(config: RotaryEmbeddingConfigMixin, ignore_keys: Optional[set] = None):
"""
This is a deprecated function.
It has been kept for backward compatibility with custom code models.
"""
warnings.warn(
"`rope_config_validation` is deprecated and has been removed. "
"Its functionality has been moved to RotaryEmbeddingConfigMixin.validate_rope method. "
"PreTrainedConfig inherits this class, so please call self.validate_rope() instead. "
"Also, make sure to use the new rope_parameters syntax. "
"You can call self.standardize_rope_params() in the meantime.",
FutureWarning,
)
config.standardize_rope_params()
config.validate_rope(ignore_keys=ignore_keys)