diff --git a/csrc/cpu/cpu_attn.cpp b/csrc/cpu/cpu_attn.cpp index 50f17c758c..c92871b076 100644 --- a/csrc/cpu/cpu_attn.cpp +++ b/csrc/cpu/cpu_attn.cpp @@ -24,7 +24,9 @@ switch (HEAD_DIM) { \ CPU_ATTN_DISPATCH_CASE(32, __VA_ARGS__) \ CPU_ATTN_DISPATCH_CASE(64, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(80, __VA_ARGS__) \ CPU_ATTN_DISPATCH_CASE(96, __VA_ARGS__) \ + CPU_ATTN_DISPATCH_CASE(112, __VA_ARGS__) \ CPU_ATTN_DISPATCH_CASE(128, __VA_ARGS__) \ CPU_ATTN_DISPATCH_CASE(160, __VA_ARGS__) \ CPU_ATTN_DISPATCH_CASE(192, __VA_ARGS__) \ diff --git a/csrc/cpu/cpu_attn_amx.hpp b/csrc/cpu/cpu_attn_amx.hpp index 8da458b991..129f6423dd 100644 --- a/csrc/cpu/cpu_attn_amx.hpp +++ b/csrc/cpu/cpu_attn_amx.hpp @@ -377,7 +377,7 @@ class AttentionImpl { const int32_t q_heads_per_kv, const int64_t q_num_stride, const int64_t q_head_stride, const float scale) { constexpr int64_t bytes_per_head = head_dim * sizeof(scalar_t); - static_assert(bytes_per_head % AMX_TILE_ROW_BYTES == 0); + // NOTE: head dims such as 80/112 make bytes_per_head non-divisible by AMX_TILE_ROW_BYTES; those sizes are routed to the vec16 ISA at runtime (see _get_attn_isa), so the divisibility static_assert is removed — head_size_block_num truncates for them, but the AMX kernel is never selected for such head dims. constexpr int64_t head_size_block_num = bytes_per_head / AMX_TILE_ROW_BYTES; constexpr int64_t head_elem_num_pre_block = AMX_TILE_ROW_BYTES / sizeof(scalar_t); diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index f1254352c0..eb087fc9ee 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -42,7 +42,7 @@ def get_supported_dtypes(cls) -> list[torch.dtype]: @classmethod def get_supported_head_sizes(cls) -> list[int]: - return [32, 64, 96, 128, 160, 192, 224, 256] + return [32, 64, 80, 96, 112, 128, 160, 192, 224, 256] @staticmethod def get_name() -> str: @@ -137,7 +137,7 @@ def __init__( if self.window_size is None: self.window_size = -1 self.block_size = vllm_config.cache_config.block_size - self.isa = _get_attn_isa(self.dtype, self.block_size) + self.isa = _get_attn_isa(self.dtype, self.block_size, self.head_dim) def build( self, @@ -486,7 +486,9 @@ def 
_make_sliding_window_bias( return attn_biases -def _get_attn_isa(dtype: torch.dtype, block_size: int) -> str: +def _get_attn_isa(dtype: torch.dtype, block_size: int, head_size: int | None = None) -> str: + if head_size is not None and head_size % 32 != 0 and head_size % 16 == 0: + return "vec16" supports_amx = torch._C._cpu._is_amx_tile_supported() if supports_amx and dtype in (torch.bfloat16,) and block_size % 32 == 0: return "amx"