Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions csrc/cpu/cpu_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
switch (HEAD_DIM) { \
CPU_ATTN_DISPATCH_CASE(32, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(64, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(80, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(96, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(112, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(128, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(160, __VA_ARGS__) \
CPU_ATTN_DISPATCH_CASE(192, __VA_ARGS__) \
Expand Down
2 changes: 1 addition & 1 deletion csrc/cpu/cpu_attn_amx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ class AttentionImpl<ISA::AMX, scalar_t, head_dim> {
const int32_t q_heads_per_kv, const int64_t q_num_stride,
const int64_t q_head_stride, const float scale) {
constexpr int64_t bytes_per_head = head_dim * sizeof(scalar_t);
static_assert(bytes_per_head % AMX_TILE_ROW_BYTES == 0);
//static_assert(bytes_per_head % AMX_TILE_ROW_BYTES == 0);
constexpr int64_t head_size_block_num = bytes_per_head / AMX_TILE_ROW_BYTES;
constexpr int64_t head_elem_num_pre_block =
AMX_TILE_ROW_BYTES / sizeof(scalar_t);
Expand Down
8 changes: 5 additions & 3 deletions vllm/v1/attention/backends/cpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_supported_dtypes(cls) -> list[torch.dtype]:

@classmethod
def get_supported_head_sizes(cls) -> list[int]:
    """Return the attention head sizes supported by the CPU backend.

    Returns:
        Sorted list of supported head dims. Includes 80 and 112, which
        are multiples of 16 but not 32 — these are presumably routed to
        the 16-wide vectorized path rather than AMX (see the head_size
        check in ``_get_attn_isa``; confirm against that dispatcher).
    """
    # NOTE(review): the diff residue had two return statements (the stale
    # pre-PR list shadowed the merged one); only the merged list is kept.
    return [32, 64, 80, 96, 112, 128, 160, 192, 224, 256]

@staticmethod
def get_name() -> str:
Expand Down Expand Up @@ -137,7 +137,7 @@ def __init__(
if self.window_size is None:
self.window_size = -1
self.block_size = vllm_config.cache_config.block_size
self.isa = _get_attn_isa(self.dtype, self.block_size)
self.isa = _get_attn_isa(self.dtype, self.block_size, self.head_dim)

def build(
self,
Expand Down Expand Up @@ -486,7 +486,9 @@ def _make_sliding_window_bias(
return attn_biases


def _get_attn_isa(dtype: torch.dtype, block_size: int) -> str:
def _get_attn_isa(dtype: torch.dtype, block_size: int, head_size: int | None = None) -> str:
if head_size is not None and head_size % 32 != 0 and head_size % 16 == 0:
return "vec16"
supports_amx = torch._C._cpu._is_amx_tile_supported()
if supports_amx and dtype in (torch.bfloat16,) and block_size % 32 == 0:
return "amx"
Expand Down