# [CPU] Add head sizes 80 and 112 with vec16 fallback #31968
```diff
@@ -42,7 +42,7 @@ def get_supported_dtypes(cls) -> list[torch.dtype]:

     @classmethod
     def get_supported_head_sizes(cls) -> list[int]:
-        return [32, 64, 96, 128, 160, 192, 224, 256]
+        return [32, 64, 80, 96, 112, 128, 160, 192, 224, 256]
```
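To see which entries in the updated list the new fallback actually applies to, here is a quick sketch (not part of the PR) that filters for head sizes divisible by 16 but not by 32 — the condition this PR uses to route to `vec16`:

```python
# Sketch: find the supported head sizes that are multiples of 16 but not
# of 32. Per this PR, these are the sizes that take the vec16 fallback.
head_sizes = [32, 64, 80, 96, 112, 128, 160, 192, 224, 256]
vec16_sizes = [h for h in head_sizes if h % 32 != 0 and h % 16 == 0]
print(vec16_sizes)  # [80, 112]
```

Only the two newly added sizes, 80 and 112, satisfy the condition; every other size in the list is a multiple of 32.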
**Contributor** commented on lines 44 to +45.
```diff
     @staticmethod
     def get_name() -> str:
```
```diff
@@ -137,7 +137,7 @@ def __init__(
         if self.window_size is None:
             self.window_size = -1
         self.block_size = vllm_config.cache_config.block_size
-        self.isa = _get_attn_isa(self.dtype, self.block_size)
+        self.isa = _get_attn_isa(self.dtype, self.block_size, self.head_dim)
         self.is_cross_attention = isinstance(kv_cache_spec, CrossAttentionSpec)

     def build(
```
```diff
@@ -484,7 +484,11 @@ def _make_sliding_window_bias(
     return attn_biases


-def _get_attn_isa(dtype: torch.dtype, block_size: int) -> str:
+def _get_attn_isa(
+    dtype: torch.dtype, block_size: int, head_size: int | None = None
+) -> str:
+    if head_size is not None and head_size % 32 != 0 and head_size % 16 == 0:
+        return "vec16"
```
**Contributor** commented on lines +487 to +491:

> The logic to force […] To maintain a single source of truth and avoid redundancy, this logic should only reside in the C++ backend. Please remove the […] Also, the call to this function in […] `def _get_attn_isa(dtype: torch.dtype, block_size: int) -> str:`
```diff
     supports_amx = torch._C._cpu._is_amx_tile_supported()
     if supports_amx and dtype in (torch.bfloat16,) and block_size % 32 == 0:
         return "amx"
```
**Reviewer:** Ok... now we've created too many redundant template instantiations; we need to reorganize the dispatch procedure.
Maybe a future PR?
Yes