Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/design/attention_backends.md
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ configuration.
|---------|--------|-----------|-------------|------------|------|--------|-----------|-----|-----------------|--------------|
| `CUTLASS_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 128 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 10.x |
| `FLASHINFER_MLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | Any | ❌ | ❌ | ❌ | ❌ | Decoder | 10.x |
| `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16` | 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x |
| `FLASHINFER_MLA_SPARSE` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 32, 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 10.x |
| `FLASHMLA` | fp16, bf16 | `auto`, `bfloat16`, `fp8`, `fp8_e4m3` | 64 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x-10.x |
| `FLASHMLA_SPARSE` | bf16 | `auto`, `bfloat16`, `fp8_ds_mla` | 64 | 576 | ❌ | ✅ | ❌ | ❌ | Decoder | 9.x-10.x |
| `FLASH_ATTN_MLA` | fp16, bf16 | `auto`, `bfloat16` | %16 | Any | ❌ | ❌ | ❌ | ✅ | Decoder | 9.x |
Expand Down
20 changes: 18 additions & 2 deletions tests/v1/attention/test_mla_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,12 @@ def __init__(
self._k_scale_float = 1.0
self._v_scale_float = 1.0

self._decode_concat_quant_fp8_op = _DecodeConcatQuantFP8(
static=True,
group_shape=GroupShape.PER_TENSOR,
compile_native=True,
)

def forward_impl(
self,
q: torch.Tensor,
Expand All @@ -338,6 +344,7 @@ def forward_impl(
) -> torch.Tensor:
"""Forward for sparse MLA - uses forward_mqa for all tokens."""
kv_cache_dtype = getattr(self.impl, "kv_cache_dtype", "auto")
fp8_attention = kv_cache_dtype.startswith("fp8")

# Write to KV cache
if kv_cache.numel() > 0:
Expand All @@ -350,6 +357,9 @@ def forward_impl(
scale=self._k_scale,
)

if fp8_attention and kv_cache_dtype != "fp8_ds_mla":
kv_cache = kv_cache.view(current_platform.fp8_dtype())

num_tokens = q.shape[0]

# Sparse MLA uses forward_mqa for all tokens
Expand All @@ -367,8 +377,14 @@ def forward_impl(
# Convert from (N, B, L) to (B, N, L)
mqa_ql_nope = mqa_ql_nope.transpose(0, 1)

# Pass as tuple to forward_mqa
mqa_q = (mqa_ql_nope, mqa_q_pe)
if fp8_attention and self.impl.supports_quant_query_input:
assert mqa_ql_nope.shape[0] == mqa_q_pe.shape[0]
assert mqa_ql_nope.shape[1] == mqa_q_pe.shape[1]
mqa_q = self._decode_concat_quant_fp8_op(
mqa_ql_nope, mqa_q_pe, self._q_scale
)
else:
mqa_q = (mqa_ql_nope, mqa_q_pe)

attn_out, _ = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self)

Expand Down
12 changes: 11 additions & 1 deletion tests/v1/attention/test_sparse_mla_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,16 @@ def test_sparse_backend_decode_correctness(
if kv_cache_dtype not in backend_cls.supported_kv_cache_dtypes:
pytest.skip(f"{backend_cls.get_name()} does not support {kv_cache_dtype}")

if (
backend_cls == FlashMLASparseBackend
and kv_cache_dtype.startswith("fp8")
and kv_cache_dtype != "fp8_ds_mla"
):
pytest.skip(
"FlashMLA Sparse Attention backend fp8 only supports "
"fp8_ds_mla kv-cache dtype"
)

supported_block_sizes = backend_cls.get_supported_kernel_block_sizes()
if block_size not in supported_block_sizes:
pytest.skip(
Expand Down Expand Up @@ -419,7 +429,7 @@ def test_sparse_backend_decode_correctness(
num_blocks=vllm_config.cache_config.num_gpu_blocks,
common_attn_metadata=common_attn_metadata,
randomize_blocks=False,
kv_cache_dtype=kv_cache_dtype if use_fp8_ds_mla_quantization else "auto",
kv_cache_dtype=kv_cache_dtype,
scale=kv_cache_scale,
)

Expand Down
16 changes: 15 additions & 1 deletion tools/pre_commit/generate_attention_backend_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@
# Backends to skip during doc generation
SKIP_BACKENDS = {"CUSTOM", "TORCH_SDPA"}

# Per-backend KV cache dtypes to omit from the "kv_cache_dtypes" entry that
# analyze_backend() reports. Keys are backend names; values are the dtype
# strings dropped from that backend's comma-separated dtype list.
BACKEND_KV_DTYPE_EXCLUDES: dict[str, set[str]] = {
# fp8 is an alias for fp8_ds_mla for FlashMLA Sparse, so it is not listed
# as a separate dtype for that backend.
"FLASHMLA_SPARSE": {"fp8"},
}


def is_relevant_file(filepath: str) -> bool:
"""Check if a file matches any of the relevant patterns."""
Expand Down Expand Up @@ -546,10 +551,19 @@ def analyze_backend(backend_name: str, class_path: str) -> dict[str, Any] | None
tree, impl_class_name, "can_return_lse_for_decode", False, file_path
)

kv_cache_dtypes = parse_kv_cache_dtypes(class_node)
if backend_name in BACKEND_KV_DTYPE_EXCLUDES:
excluded = BACKEND_KV_DTYPE_EXCLUDES[backend_name]
kv_cache_dtypes = ", ".join(
d
for d in (d.strip() for d in kv_cache_dtypes.split(","))
if d not in excluded
)

return {
"name": backend_name,
"dtypes": parse_supported_dtypes(class_node),
"kv_cache_dtypes": parse_kv_cache_dtypes(class_node),
"kv_cache_dtypes": kv_cache_dtypes,
"block_sizes": parse_block_sizes(class_node),
"head_sizes": parse_head_sizes(class_node),
"attn_types": parse_attention_types(class_node),
Expand Down
35 changes: 30 additions & 5 deletions vllm/model_executor/layers/attention/mla_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,11 +331,6 @@ def __init__(
calculate_kv_scales = False
self.quant_config = quant_config

# Initialize KV cache quantization attributes
self.kv_cache_dtype = kv_cache_dtype
self.calculate_kv_scales = calculate_kv_scales
_init_kv_cache_quant(self, quant_config, prefix)

dtype = torch.get_default_dtype()
self.attn_backend = get_attn_backend(
self.head_size,
Expand All @@ -347,6 +342,36 @@ def __init__(
num_heads=self.num_heads,
)

# FlashMLA Sparse Attention fp8 backend uses "fp8_ds_mla" kv-cache format
# Automatically convert fp8 kv-cache format to "fp8_ds_mla"
if (
self.attn_backend.get_name() == "FLASHMLA_SPARSE"
and kv_cache_dtype.startswith("fp8")
and kv_cache_dtype != "fp8_ds_mla"
):
assert cache_config is not None
cache_config.cache_dtype = "fp8_ds_mla"
kv_cache_dtype = "fp8_ds_mla"
logger.info_once(
"Using DeepSeek's fp8_ds_mla KV cache format. To use standard "
"fp8 kv-cache format, please set `--attention-backend "
"FLASHINFER_MLA_SPARSE`"
)

if (
self.attn_backend.get_name() == "FLASHINFER_MLA_SPARSE"
and kv_cache_dtype.startswith("fp8")
):
logger.info_once(
"Using standard fp8 KV cache format. To use DeepSeek's fp8_ds_mla "
"KV cache format, please set `--attention-backend FLASHMLA_SPARSE`"
)

# Initialize KV cache quantization attributes
self.kv_cache_dtype = kv_cache_dtype
self.calculate_kv_scales = calculate_kv_scales
_init_kv_cache_quant(self, quant_config, prefix)

if (
cache_config is not None
and cache_config.enable_prefix_caching
Expand Down
7 changes: 0 additions & 7 deletions vllm/model_executor/models/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,13 @@ def verify_and_update_model_config(model_config: "ModelConfig") -> None:
class DeepseekV32ForCausalLM(VerifyAndUpdateConfig):
@classmethod
def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
"""
Updated fp8 cache to custom "fp8_ds_mla" format for DeepSeekV32
"""
hf_config = vllm_config.model_config.hf_config

# Mirror the check in vllm/model_executor/models/deepseek_v2.py
is_v32 = hasattr(hf_config, "index_topk")
assert is_v32

# For DeepSeekV3.2, a custom fp8 format is used when fp8 kv-cache is enabled.
cache_config = vllm_config.cache_config
if cache_config.cache_dtype.startswith("fp8"):
cache_config.cache_dtype = "fp8_ds_mla"
logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2")
if cache_config.cache_dtype == "bfloat16":
cache_config.cache_dtype = "auto"
logger.info("Using bfloat16 kv-cache for DeepSeekV3.2")
Expand Down
7 changes: 7 additions & 0 deletions vllm/v1/attention/backends/mla/flashinfer_mla_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ class FlashInferMLASparseBackend(AttentionBackend):
supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
"auto",
"bfloat16",
"fp8",
"fp8_e4m3",
]

@staticmethod
Expand Down Expand Up @@ -304,6 +306,11 @@ def __init__(
self.bmm1_scale: float | None = None
self.bmm2_scale: float | None = None

# fp8 query quantization is required when using fp8 kv_cache,
# as the TRTLLM-GEN sparse MLA kernel requires matching dtypes
# for query and kv_cache (mixed bf16+fp8 is not supported).
self.supports_quant_query_input = True

def forward_mqa(
self,
q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
Expand Down
7 changes: 7 additions & 0 deletions vllm/v1/attention/backends/mla/flashmla_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class FlashMLASparseBackend(AttentionBackend):
"auto",
"bfloat16",
"fp8_ds_mla",
"fp8", # alias for fp8_ds_mla
]

@staticmethod
Expand Down Expand Up @@ -567,6 +568,12 @@ def __init__(
)
self.fp8_decode_padded_heads = self._compute_fp8_decode_padded_heads(num_heads)

if kv_cache_dtype.startswith("fp8"):
assert kv_cache_dtype == "fp8_ds_mla", (
"FlashMLA Sparse Attention backend fp8 only supports "
"fp8_ds_mla kv-cache dtype"
)
Comment thread
pavanimajety marked this conversation as resolved.

if kv_cache_dtype == "fp8_ds_mla":
# Reserve workspace during initialization
vllm_config = get_current_vllm_config()
Expand Down