diff --git a/src/transformers/modeling_rope_utils.py b/src/transformers/modeling_rope_utils.py
index de27e5f8bd20..542c81a4eaed 100644
--- a/src/transformers/modeling_rope_utils.py
+++ b/src/transformers/modeling_rope_utils.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import math
+import warnings
 from functools import wraps
 from typing import TYPE_CHECKING, Optional, TypedDict
 
@@ -656,7 +657,7 @@ def standardize_rope_params(self):
         # Move `rope_theta` and `partial_rotary_factor` to the params dict, if not there yet
         rope_theta = getattr(self, "rope_theta", None)
         partial_rotary_factor = getattr(self, "partial_rotary_factor", None)
-        rope_parameters = self.rope_parameters
+        rope_parameters = self.rope_parameters or {}
 
         # Case 1: RoPE param keys do not intersect with possible `layer_types` -> one global dict
         if getattr(self, "layer_types", None) is None or not set(rope_parameters.keys()).issubset(self.layer_types):
@@ -913,3 +914,20 @@ def _check_received_keys(
     unused_keys = received_keys - required_keys
     if unused_keys:
         logger.warning(f"Unrecognized keys in `rope_parameters` for 'rope_type'='{rope_type}': {unused_keys}")
+
+
+def rope_config_validation(config: RotaryEmbeddingConfigMixin, ignore_keys: Optional[set] = None):
+    """
+    This is a deprecated function.
+    It has been kept for backward compatibility with custom code models.
+    """
+    warnings.warn(
+        "`rope_config_validation` is deprecated and will be removed in a future version. "
+        "Its functionality has been moved to RotaryEmbeddingConfigMixin.validate_rope method. "
+        "PreTrainedConfig inherits this class, so please call self.validate_rope() instead. "
+        "Also, make sure to use the new rope_parameters syntax. "
+        "You can call self.standardize_rope_params() in the meantime.",
+        FutureWarning,
+    )
+    config.standardize_rope_params()
+    config.validate_rope(ignore_keys=ignore_keys)