Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions src/transformers/modeling_rope_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,20 +654,26 @@ def standardize_rope_params(self):
Helper to standardize the config's rope params field by ensuring the params are defined for each
layer type. For old models the fn will duplicate a single rope param in each layer type (backward compatibility)
"""
# Move `rope_theta` and `partial_rotary_factor` to the params dict, if not there yet
# Move `rope_theta` and `partial_rotary_factor` to the `rope_parameters`, if not there yet
rope_theta = getattr(self, "rope_theta", None)
partial_rotary_factor = getattr(self, "partial_rotary_factor", None)
rope_parameters = getattr(self, "rope_parameters", None) or {}
layer_types = getattr(self, "layer_types", None)

# Case 0: no RoPE params defined
if not (rope_parameters or rope_theta):
# partial_rotary_factor without rope_theta is invalid, so we don't check for it here
logger.warning("`standardize_rope_params` was called but no RoPE parameters were found.")
return
# Case 1: RoPE param keys do not intersect with possible `layer_types` -> one global dict
if getattr(self, "layer_types", None) is None or not set(rope_parameters.keys()).issubset(self.layer_types):
elif layer_types is None or rope_parameters == {} or not set(rope_parameters.keys()).issubset(layer_types):
rope_parameters.setdefault("rope_type", rope_parameters.get("type", "default"))
rope_parameters.setdefault("rope_theta", rope_theta)
if partial_rotary_factor is not None:
rope_parameters["partial_rotary_factor"] = partial_rotary_factor
# Case 2: different RoPE for each layer -> several params as nested dict
else:
for layer_type in self.layer_types:
for layer_type in layer_types:
rope_parameters[layer_type].setdefault("rope_type", rope_parameters[layer_type].get("type", "default"))
rope_parameters[layer_type].setdefault("rope_theta", rope_theta)
if partial_rotary_factor is not None:
Expand Down