vllm/model_executor/models/baichuan.py (1 addition, 1 deletion)

@@ -233,7 +233,7 @@ def __init__(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             position_embedding=position_embedding,
-            rope_parameters=config.rope_parameters,
+            rope_parameters=getattr(config, "rope_parameters", None),
             max_position_embeddings=max_position_embeddings,
             cache_config=cache_config,
             quant_config=quant_config,
vllm/model_executor/models/gpt_j.py (1 addition, 1 deletion)

@@ -100,7 +100,7 @@ def __init__(
             self.head_size,
             rotary_dim=config.rotary_dim,
             max_position=max_position_embeddings,
-            rope_parameters=config.rope_parameters,
+            rope_parameters=getattr(config, "rope_parameters", None),
             is_neox_style=False,
         )
         self.attn = Attention(
vllm/model_executor/models/grok1.py (1 addition, 1 deletion)

@@ -239,7 +239,7 @@ def __init__(
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
-           rope_parameters=config.rope_parameters,
+           rope_parameters=getattr(config, "rope_parameters", None),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
vllm/model_executor/models/llama.py (1 addition, 1 deletion)

@@ -262,7 +262,7 @@ def _init_rotary_emb(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
-           rope_parameters=config.rope_parameters,
+           rope_parameters=getattr(config, "rope_parameters", None),
            is_neox_style=is_neox_style,
            partial_rotary_factor=self.partial_rotary_factor,
        )
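All four model files make the same one-line change: direct attribute access becomes getattr(config, "rope_parameters", None), so a config that never gained a rope_parameters field no longer raises AttributeError at model construction. A minimal sketch of the behavioral difference, using a hypothetical SimpleNamespace stand-in rather than a real PretrainedConfig:

from types import SimpleNamespace

# Hypothetical stand-in for a config written against an older
# Transformers version, where rope_parameters was never set.
config = SimpleNamespace(num_attention_heads=32)

# Old pattern: raises when the field is absent.
try:
    rope_parameters = config.rope_parameters
except AttributeError:
    rope_parameters = None

# New pattern: same result with no exception; downstream layers
# receive None and can fall back to their default RoPE settings.
rope_parameters = getattr(config, "rope_parameters", None)
assert rope_parameters is None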
vllm/transformers_utils/config.py (33 additions, 29 deletions)

@@ -452,51 +452,55 @@ def set_default_rope_theta(config: PretrainedConfig, default_theta: float) -> None

 def patch_rope_parameters(config: PretrainedConfig) -> None:
     """Provide backwards compatibility for RoPE."""
-    # Retrieve rope_parameters differently based on Transformers version
+    # Patch rope_parameters differently based on Transformers version
     if Version(version("transformers")) >= Version("5.0.0.dev0"):
-        from transformers.modeling_rope_utils import RopeParameters
-
-        rope_parameters: RopeParameters | dict[str, RopeParameters] | None = getattr(
-            config, "rope_parameters", None
+        from transformers.modeling_rope_utils import (
+            rope_config_validation,
+            standardize_rope_params,
         )
-    elif hasattr(config, "rope_parameters"):
-        # We are in Transformers v4 and rope_parameters
-        # has already been patched for this config
-        return
+
+        # When Transformers v5 is installed, legacy rope_theta may be present
+        # when using custom code models written for Transformers v4
+        if (rope_theta := getattr(config, "rope_theta", None)) is not None:
+            standardize_rope_params(config, rope_theta=rope_theta)
+            rope_config_validation(config)
+            # Delete rope_theta to avoid confusion in downstream code
+            del config.rope_theta
     else:
-        # Convert Transformers v4 rope_theta and rope_scaling into rope_parameters
-        rope_theta: float | None = getattr(config, "rope_theta", None)
-        rope_scaling: dict | None = getattr(config, "rope_scaling", None)
-        rope_parameters = rope_scaling
-        # Move rope_theta into rope_parameters
-        if rope_theta is not None:
-            rope_parameters = rope_parameters or {"rope_type": "default"}
-            rope_parameters["rope_theta"] = rope_theta
-        # Add original_max_position_embeddings if present
-        if rope_parameters and (
-            ompe := getattr(config, "original_max_position_embeddings", None)
-        ):
-            rope_parameters["original_max_position_embeddings"] = ompe
-        # Write back to config
-        config.rope_parameters = rope_parameters
+        # When Transformers v4 is installed, legacy rope_scaling may be present
+        if (rope_scaling := getattr(config, "rope_scaling", None)) is not None:
+            config.rope_parameters = rope_scaling
+        # When Transformers v4 is installed, legacy rope_theta may be present
+        if (rope_theta := getattr(config, "rope_theta", None)) is not None:
+            if not hasattr(config, "rope_parameters"):
+                config.rope_parameters = {"rope_type": "default"}
+            config.rope_parameters["rope_theta"] = rope_theta

     # No RoPE parameters to patch
-    if rope_parameters is None:
+    if not hasattr(config, "rope_parameters"):
         return

+    # Add original_max_position_embeddings if present
+    if ompe := getattr(config, "original_max_position_embeddings", None):
+        config.rope_parameters["original_max_position_embeddings"] = ompe
+
     # Handle nested rope_parameters in interleaved sliding attention models
-    if set(rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES):
-        for rope_parameters_layer_type in rope_parameters.values():
+    if set(config.rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES):
+        for rope_parameters_layer_type in config.rope_parameters.values():
             patch_rope_parameters_dict(rope_parameters_layer_type)
     else:
-        patch_rope_parameters_dict(rope_parameters)
+        patch_rope_parameters_dict(config.rope_parameters)


 def patch_rope_parameters_dict(rope_parameters: dict[str, Any]) -> None:
     if "rope_type" in rope_parameters and "type" in rope_parameters:
         rope_type = rope_parameters["rope_type"]
         rope_type_legacy = rope_parameters["type"]
-        if rope_type != rope_type_legacy:
+        if (rope_type_legacy == "su" and rope_type == "longrope") or (
+            rope_type_legacy == "mrope" and rope_type == "default"
+        ):
+            pass  # No action needed
+        elif rope_type != rope_type_legacy:
             raise ValueError(
                 f"Found conflicts between 'rope_type={rope_type}' (modern "
                 f"field) and 'type={rope_type_legacy}' (legacy field). "
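On the Transformers v4 path, the rewritten patch_rope_parameters now writes config.rope_parameters in place instead of assembling a local dict and assigning it at the end, and the original_max_position_embeddings handling moves out of the v4 branch so it applies on both version paths. A minimal sketch of the v4 normalization, assuming an illustrative SimpleNamespace config (the helper name and values here are hypothetical, not from the PR):

from types import SimpleNamespace

def normalize_v4_rope(config):
    # Mirrors the new v4 branch above: promote legacy fields
    # into config.rope_parameters in place.
    if (rope_scaling := getattr(config, "rope_scaling", None)) is not None:
        config.rope_parameters = rope_scaling
    if (rope_theta := getattr(config, "rope_theta", None)) is not None:
        if not hasattr(config, "rope_parameters"):
            config.rope_parameters = {"rope_type": "default"}
        config.rope_parameters["rope_theta"] = rope_theta

# Illustrative v4-style config: rope_theta only, no rope_scaling.
config = SimpleNamespace(rope_theta=10000.0)
normalize_v4_rope(config)
print(config.rope_parameters)
# -> {'rope_type': 'default', 'rope_theta': 10000.0}

The patch_rope_parameters_dict change is in the same spirit: the known-equivalent legacy pairs (type="su" with rope_type="longrope", and type="mrope" with rope_type="default") are now tolerated rather than treated as conflicts.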