Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions docs/design/attention_backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,9 @@ Priority is **1 = highest** (tried first).
| `FLASH_ATTN` | FA4* | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ✅ | All | ≥10.0 |
| `FLASH_ATTN_DIFFKV` | | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ✅ | Decoder | Any |
| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
| `ROCM_AITER_FA` | | fp16, bf16 | `auto` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | Any | Any | | | ❌ | All | N/A |
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto` | 16, 32, 544 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | | | ❌ | All | N/A |
| `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | %16 | Any | | | ❌ | All | N/A |
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32, 544 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | | | ❌ | All | N/A |
| `TREE_ATTN` | | fp16, bf16 | `auto` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any |

Expand Down Expand Up @@ -210,7 +210,7 @@ configuration.
| `FLASHMLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x |
| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
| `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
| `ROCM_AITER_MLA` | fp16, bf16 | `auto` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_MLA_SPARSE` | bf16 | `auto` | Any | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 1 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 1 | Any | ❌ | ✅ | ❌ | ❌ | Decoder | N/A |
| `ROCM_AITER_TRITON_MLA` | fp16, bf16 | `auto` | Any | Any | ❌ | ❌ | ❌ | ❌ | Decoder | N/A |
| `TRITON_MLA` | fp16, bf16 | `auto`, `bfloat16` | Any | Any | ❌ | ❌ | ❌ | ✅ | Decoder | Any |
2 changes: 1 addition & 1 deletion vllm/v1/attention/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def validate_configuration(
else:
invalid_reasons.append("non-MLA not supported")
if has_sink and not cls.supports_sink():
invalid_reasons.append("sink setting not supported")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

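Renamed for clarity: the message now reads "attention sinks not supported" instead of the vaguer "sink setting not supported".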

invalid_reasons.append("attention sinks not supported")
if use_sparse != cls.is_sparse():
if use_sparse:
invalid_reasons.append("sparse not supported")
Expand Down
10 changes: 10 additions & 0 deletions vllm/v1/attention/backends/mla/rocm_aiter_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBackend,
MLACommonDecodeMetadata,
Expand All @@ -21,6 +22,15 @@


class AiterMLABackend(MLACommonBackend):
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
"fp8",
"fp8_e4m3",
"fp8_e5m2",
]

@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [1]
Expand Down
21 changes: 11 additions & 10 deletions vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
get_mla_dims,
Expand All @@ -21,6 +22,7 @@
AttentionMetadata,
AttentionMetadataBuilder,
CommonAttentionMetadata,
MultipleOf,
SparseMLAAttentionImpl,
)
from vllm.v1.attention.backends.mla.flashmla_sparse import (
Expand Down Expand Up @@ -77,7 +79,15 @@ def fetch_id_to_ragged_triton(

class ROCMAiterMLASparseBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
]

@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [1]

@staticmethod
def get_name() -> str:
Expand Down Expand Up @@ -105,10 +115,6 @@ def get_kv_cache_shape(
) -> tuple[int, ...]:
return (num_blocks, block_size, head_size)

@classmethod
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

already default in MLACommonBackend

def get_supported_head_sizes(cls) -> list[int]:
return [576]

@classmethod
def is_mla(cls) -> bool:
return True
Expand All @@ -117,11 +123,6 @@ def is_mla(cls) -> bool:
def is_sparse(cls) -> bool:
return True

@classmethod
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Redundant with get_supported_kernel_block_sizes, which now returns [1] for this backend.

def supports_block_size(cls, block_size: int | None) -> bool:
# The only supported block_size is 1
return block_size is None or block_size == 1


@dataclass
class ROCMAiterMLASparseMetadata(AttentionMetadata):
Expand Down
5 changes: 0 additions & 5 deletions vllm/v1/attention/backends/mla/triton_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,6 @@ def get_impl_cls() -> type["TritonMLAImpl"]:
def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
return True

@classmethod
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

redundant with get_supported_kernel_block_sizes

def supports_block_size(cls, block_size: int | None) -> bool:
# The only unsupported block_size is 1
return block_size is None or block_size != 1


class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
can_return_lse_for_decode: bool = True
Expand Down
8 changes: 8 additions & 0 deletions vllm/v1/attention/backends/rocm_aiter_fa.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention
from vllm.platforms import current_platform
Expand Down Expand Up @@ -732,6 +733,13 @@ def use_cascade_attention(self, *args, **kwargs) -> bool:
class AiterFlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
"fp8",
"fp8_e4m3",
"fp8_e5m2",
]

@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
Expand Down
18 changes: 17 additions & 1 deletion vllm/v1/attention/backends/rocm_aiter_unified_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
QuantKey,
kFp8StaticTensorSym,
)
from vllm.v1.attention.backend import AttentionLayer, AttentionType
from vllm.v1.attention.backend import AttentionLayer, AttentionType, MultipleOf
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.rocm_attn import (
RocmAttentionBackend,
Expand All @@ -25,6 +25,22 @@
class RocmAiterUnifiedAttentionBackend(RocmAttentionBackend):
accept_output_buffer: bool = True

@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
return [MultipleOf(16)]

@classmethod
def supports_head_size(cls, head_size: int) -> bool:
return head_size >= 32

@classmethod
def supports_mm_prefix(cls) -> bool:
return True

@classmethod
def supports_sink(cls) -> bool:
return True

forward_includes_kv_cache_update: bool = False

@staticmethod
Expand Down
25 changes: 14 additions & 11 deletions vllm/v1/attention/backends/rocm_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey,
Expand Down Expand Up @@ -163,6 +164,13 @@ class RocmAttentionBackend(AttentionBackend):
torch.bfloat16,
torch.float32,
]
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
"fp8",
"fp8_e4m3",
"fp8_e5m2",
]

@staticmethod
def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
Expand All @@ -185,15 +193,12 @@ def get_supported_head_sizes(cls) -> list[int]:
return [32, 64, 80, 96, 128, 160, 192, 224, 256]

@classmethod
def validate_head_size(cls, head_size: int) -> None:
if not cls.supports_head_size(head_size):
attn_type = cls.__name__.removesuffix("Backend")
raise ValueError(
f"Head size {head_size} is not supported by {attn_type}. "
f"Supported head sizes are: {cls.get_supported_head_sizes()}. "
"Set --attention-backend=FLEX_ATTENTION to use "
"FlexAttention backend which supports all head sizes."
)
def supports_mm_prefix(cls) -> bool:
return True

@classmethod
def supports_sink(cls) -> bool:
return True

forward_includes_kv_cache_update: bool = False

Expand Down Expand Up @@ -275,8 +280,6 @@ def __init__(

self.num_queries_per_kv = self.num_heads // self.num_kv_heads

RocmAttentionBackend.validate_head_size(head_size)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

redundant


self.fp8_dtype = current_platform.fp8_dtype()

self.sinks = sinks
Expand Down