-
-
Notifications
You must be signed in to change notification settings. Fork: 14.8k
Reenable features for ROCm attention backends #36185
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,6 +9,7 @@ | |
|
|
||
| from vllm._aiter_ops import rocm_aiter_ops | ||
| from vllm.config import VllmConfig | ||
| from vllm.config.cache import CacheDType | ||
| from vllm.logger import init_logger | ||
| from vllm.model_executor.layers.attention.mla_attention import ( | ||
| get_mla_dims, | ||
|
|
@@ -21,6 +22,7 @@ | |
| AttentionMetadata, | ||
| AttentionMetadataBuilder, | ||
| CommonAttentionMetadata, | ||
| MultipleOf, | ||
| SparseMLAAttentionImpl, | ||
| ) | ||
| from vllm.v1.attention.backends.mla.flashmla_sparse import ( | ||
|
|
@@ -77,7 +79,15 @@ def fetch_id_to_ragged_triton( | |
|
|
||
| class ROCMAiterMLASparseBackend(AttentionBackend): | ||
| accept_output_buffer: bool = True | ||
| supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16] | ||
| supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16] | ||
| supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ | ||
| "auto", | ||
| "bfloat16", | ||
| ] | ||
|
|
||
| @staticmethod | ||
| def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: | ||
| return [1] | ||
|
|
||
| @staticmethod | ||
| def get_name() -> str: | ||
|
|
@@ -105,10 +115,6 @@ def get_kv_cache_shape( | |
| ) -> tuple[int, ...]: | ||
| return (num_blocks, block_size, head_size) | ||
|
|
||
| @classmethod | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Reviewer comment: already default in |
||
| def get_supported_head_sizes(cls) -> list[int]: | ||
| return [576] | ||
|
|
||
| @classmethod | ||
| def is_mla(cls) -> bool: | ||
| return True | ||
|
|
@@ -117,11 +123,6 @@ def is_mla(cls) -> bool: | |
| def is_sparse(cls) -> bool: | ||
| return True | ||
|
|
||
| @classmethod | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Reviewer comment: redundant |
||
| def supports_block_size(cls, block_size: int | None) -> bool: | ||
| # The only supported block_size is 1 | ||
| return block_size is None or block_size == 1 | ||
|
|
||
|
|
||
| @dataclass | ||
| class ROCMAiterMLASparseMetadata(AttentionMetadata): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -45,11 +45,6 @@ def get_impl_cls() -> type["TritonMLAImpl"]: | |
| def supports_compute_capability(cls, capability: DeviceCapability) -> bool: | ||
| return True | ||
|
|
||
| @classmethod | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Reviewer comment: redundant with |
||
| def supports_block_size(cls, block_size: int | None) -> bool: | ||
| # The only unsupported block_size is 1 | ||
| return block_size is None or block_size != 1 | ||
|
|
||
|
|
||
| class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]): | ||
| can_return_lse_for_decode: bool = True | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,6 +9,7 @@ | |
|
|
||
| from vllm._aiter_ops import rocm_aiter_ops | ||
| from vllm.config import VllmConfig | ||
| from vllm.config.cache import CacheDType | ||
| from vllm.logger import init_logger | ||
| from vllm.model_executor.layers.quantization.utils.quant_utils import ( | ||
| QuantKey, | ||
|
|
@@ -163,6 +164,13 @@ class RocmAttentionBackend(AttentionBackend): | |
| torch.bfloat16, | ||
| torch.float32, | ||
| ] | ||
| supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [ | ||
| "auto", | ||
| "bfloat16", | ||
| "fp8", | ||
| "fp8_e4m3", | ||
| "fp8_e5m2", | ||
| ] | ||
|
|
||
| @staticmethod | ||
| def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: | ||
|
|
@@ -185,15 +193,12 @@ def get_supported_head_sizes(cls) -> list[int]: | |
| return [32, 64, 80, 96, 128, 160, 192, 224, 256] | ||
|
|
||
| @classmethod | ||
| def validate_head_size(cls, head_size: int) -> None: | ||
| if not cls.supports_head_size(head_size): | ||
| attn_type = cls.__name__.removesuffix("Backend") | ||
| raise ValueError( | ||
| f"Head size {head_size} is not supported by {attn_type}. " | ||
| f"Supported head sizes are: {cls.get_supported_head_sizes()}. " | ||
| "Set --attention-backend=FLEX_ATTENTION to use " | ||
| "FlexAttention backend which supports all head sizes." | ||
| ) | ||
| def supports_mm_prefix(cls) -> bool: | ||
| return True | ||
|
|
||
| @classmethod | ||
| def supports_sink(cls) -> bool: | ||
| return True | ||
|
|
||
| forward_includes_kv_cache_update: bool = False | ||
|
|
||
|
|
@@ -275,8 +280,6 @@ def __init__( | |
|
|
||
| self.num_queries_per_kv = self.num_heads // self.num_kv_heads | ||
|
|
||
| RocmAttentionBackend.validate_head_size(head_size) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Reviewer comment: redundant |
||
|
|
||
| self.fp8_dtype = current_platform.fp8_dtype() | ||
|
|
||
| self.sinks = sinks | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
rename for clarity