diff --git a/vllm/model_executor/models/minimax_m2.py b/vllm/model_executor/models/minimax_m2.py
index 3e6a9add9ec4..88705a9cde2c 100644
--- a/vllm/model_executor/models/minimax_m2.py
+++ b/vllm/model_executor/models/minimax_m2.py
@@ -145,6 +145,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 class MiniMaxM2Attention(nn.Module):
     def __init__(
         self,
+        config: PretrainedConfig,
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
@@ -199,9 +200,14 @@ def __init__(
             prefix=f"{prefix}.o_proj",
         )
 
+        final_rotary_dim = rotary_dim
+        # https://github.com/vllm-project/vllm/pull/30384
+        if hasattr(config, "partial_rotary_factor"):
+            final_rotary_dim = self.head_dim
+
         self.rotary_emb = get_rope(
             self.head_dim,
-            rotary_dim=self.head_dim,
+            rotary_dim=final_rotary_dim,
             max_position=max_position_embeddings,
             rope_parameters=rope_parameters,
         )
@@ -260,6 +266,7 @@ def __init__(
         self.layer_idx = layer_idx
 
         self.self_attn = MiniMaxM2Attention(
+            config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,