diff --git a/vllm_metal/metal_kernel_backend/attention_sdpa.py b/vllm_metal/metal_kernel_backend/attention_sdpa.py
index 2453911d..c5d80300 100644
--- a/vllm_metal/metal_kernel_backend/attention_sdpa.py
+++ b/vllm_metal/metal_kernel_backend/attention_sdpa.py
@@ -65,7 +65,7 @@ def sdpa_forward(
     # split into queries (for attention) + gate (applied after attention).
     q_proj_out = inner.q_proj(x)
     gate = None
-    head_dim = inner.head_dim
+    head_dim = inner.k_proj.weight.shape[0] // n_kv_heads
     q_full_head = q_proj_out.shape[-1] // n_heads
    if q_full_head == 2 * head_dim:
         # Gated: split into queries + gate
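
The fix works because the key projection is never gated: for a standard nn.Linear, weight.shape is (out_features, in_features), so k_proj.weight.shape[0] is n_kv_heads * head_dim and the division recovers the true per-head dimension even when inner.head_dim reflects the doubled query-plus-gate width. A minimal sketch of the shape arithmetic, with assumed dimensions and plain nn.Linear stand-ins rather than the actual vllm_metal modules:

import torch
import torch.nn as nn

# Hypothetical sizes for illustration only.
n_heads, n_kv_heads, head_dim, hidden = 8, 2, 64, 512

# Gated attention: q_proj emits queries + gate, i.e. 2 * head_dim per head.
q_proj = nn.Linear(hidden, n_heads * 2 * head_dim)
# k_proj is never gated, so its output width is exactly n_kv_heads * head_dim.
k_proj = nn.Linear(hidden, n_kv_heads * head_dim)

# nn.Linear stores weight as (out_features, in_features).
derived_head_dim = k_proj.weight.shape[0] // n_kv_heads
assert derived_head_dim == head_dim

x = torch.randn(1, 4, hidden)
q_full_head = q_proj(x).shape[-1] // n_heads
assert q_full_head == 2 * derived_head_dim  # detects the gated layout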