# Patch summary (leading text before the first "diff --git" header; patch tools skip it).
#
# Purpose: let the CPU attention backend accept head sizes 80 and 112, i.e. sizes that
# are a multiple of 16 but not of 32.
#
# vllm/v1/attention/backends/cpu_attn.py
#   * get_supported_head_sizes(): adds 80 and 112 to the returned list.
#   * __init__: now passes self.head_dim as a third argument to _get_attn_isa().
#   * _get_attn_isa(): gains an optional `head_size` parameter (default None, so existing
#     two-argument callers keep working). When head_size is given, is not a multiple of 32
#     and is a multiple of 16, it returns "vec16" before the AMX support probe runs.
#     Both newly listed sizes (80, 112) satisfy that condition.
#
# csrc/cpu/cpu_attn.cpp
#   * Adds the comment "NEON requires head_dim to be a multiple of 32" above NEON_DISPATCH.
#
# Hunk referencing AMX_TILE_ROW_BYTES
#   * Comments out static_assert(bytes_per_head % AMX_TILE_ROW_BYTES == 0).
#   * NOTE(review): the next context line computes
#     head_size_block_num = bytes_per_head / AMX_TILE_ROW_BYTES with integer division, so
#     with the assertion disabled any remainder is silently dropped. Presumably the intent
#     is only to let the template compile for the new head sizes while the Python side
#     routes them to "vec16" -- confirm this code can never execute for such sizes.
#
# csrc/cpu/cpu_attn_neon.hpp
#   * Comments out static_assert(HeadDim % HeadDimAlignment == 0) in class AttentionImpl.
#   * NOTE(review): this removes the compile-time check while the comment added in
#     cpu_attn.cpp states NEON still requires a multiple of 32 -- confirm a runtime guard
#     exists, or replace the disabled assertion with an explicit check.
#
# NOTE(review): this copy of the patch has its line breaks collapsed, and some text appears
# to be missing (the template argument list after "cpu_attention::AttentionImpl" and the
# file/hunk header preceding the AMX_TILE_ROW_BYTES hunk). Regenerate from the original
# commit before attempting to apply it.
diff --git a/csrc/cpu/cpu_attn.cpp b/csrc/cpu/cpu_attn.cpp index 02c722ba031a..374fc2ee6ddc 100644 --- a/csrc/cpu/cpu_attn.cpp +++ b/csrc/cpu/cpu_attn.cpp @@ -15,6 +15,7 @@ #ifdef __aarch64__ #include "cpu_attn_neon.hpp" + // NEON requires head_dim to be a multiple of 32 #define NEON_DISPATCH(...) \ case cpu_attention::ISA::NEON: { \ using attn_impl = cpu_attention::AttentionImpl { const int32_t q_heads_per_kv, const int64_t q_num_stride, const int64_t q_head_stride, const float scale) { constexpr int64_t bytes_per_head = head_dim * sizeof(scalar_t); - static_assert(bytes_per_head % AMX_TILE_ROW_BYTES == 0); + // static_assert(bytes_per_head % AMX_TILE_ROW_BYTES == 0); constexpr int64_t head_size_block_num = bytes_per_head / AMX_TILE_ROW_BYTES; constexpr int64_t head_elem_num_pre_block = AMX_TILE_ROW_BYTES / sizeof(scalar_t); diff --git a/csrc/cpu/cpu_attn_neon.hpp b/csrc/cpu/cpu_attn_neon.hpp index 827f0cfbc718..e9ecd1d32904 100644 --- a/csrc/cpu/cpu_attn_neon.hpp +++ b/csrc/cpu/cpu_attn_neon.hpp @@ -264,7 +264,7 @@ class AttentionImpl { constexpr static ISA ISAType = ISA::NEON; constexpr static bool scale_on_logits = false; // apply scale on q_buffer - static_assert(HeadDim % HeadDimAlignment == 0); + // static_assert(HeadDim % HeadDimAlignment == 0); // the gemm micro kernel is Mx8 static_assert(HeadDimAlignment % 8 == 0); static_assert(BlockSizeAlignment % 8 == 0); diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 394d0c2f6713..abbee244af3d 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -42,7 +42,7 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: @classmethod def get_supported_head_sizes(cls) -> list[int]: - return [32, 64, 96, 128, 160, 192, 224, 256] + return [32, 64, 80, 96, 112, 128, 160, 192, 224, 256] @staticmethod def get_name() -> str: @@ -137,7 +137,7 @@ def __init__( if self.window_size is None: self.window_size = -1 self.block_size = 
vllm_config.cache_config.block_size - self.isa = _get_attn_isa(self.dtype, self.block_size) + self.isa = _get_attn_isa(self.dtype, self.block_size, self.head_dim) self.is_cross_attention = isinstance(kv_cache_spec, CrossAttentionSpec) def build( @@ -484,7 +484,11 @@ def _make_sliding_window_bias( return attn_biases -def _get_attn_isa(dtype: torch.dtype, block_size: int) -> str: +def _get_attn_isa( + dtype: torch.dtype, block_size: int, head_size: int | None = None +) -> str: + if head_size is not None and head_size % 32 != 0 and head_size % 16 == 0: + return "vec16" supports_amx = torch._C._cpu._is_amx_tile_supported() if supports_amx and dtype in (torch.bfloat16,) and block_size % 32 == 0: return "amx"