Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions vllm/model_executor/models/plamo3.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class Plamo3Config(PretrainedConfig): # type: ignore
# if `sliding_window` is list
interleaved_sliding_window: list[int | None]
sliding_window_pattern: int
rope_theta: int
rope_parameters: dict[str, Any]
rope_local_theta: int
# MLP
intermediate_size: int
Expand Down Expand Up @@ -153,13 +153,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> No
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)

layer_idx = extract_layer_index(prefix)
full_attn = config.interleaved_sliding_window[layer_idx] is None
layer_type = config.layer_types[layer_idx]
is_sliding = layer_type == "sliding_attention"

self.rope_theta = config.rope_theta if full_attn else config.rope_local_theta
self.rope_scaling = (
config.rope_scaling if hasattr(config, "rope_scaling") else None
)
# Initialize the rotary embedding.
if layer_type in config.rope_parameters:
# Transformers v5 rope config.
rope_parameters = config.rope_parameters[layer_type]
else:
# Transformers v4 rope config.
# Global attention. Use the values in config.json.
rope_parameters = config.rope_parameters
# Local attention. Override the values in config.json.
if is_sliding:
rope_parameters = dict(
rope_type="default", rope_theta=config.rope_local_theta
)
max_position = config.max_position_embeddings
if hasattr(vllm_config.model_config, "max_model_len") and isinstance(
vllm_config.model_config.max_model_len, int
Expand All @@ -170,8 +181,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs) -> No
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=self.rope_theta,
rope_scaling=self.rope_scaling,
rope_parameters=rope_parameters,
)
self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
set_weight_attrs(
Expand Down