-
-
Notifications
You must be signed in to change notification settings - Fork 15.5k
[Attention] Blackwell FP8 MLA support with CUTLASS_MLA backend #23289
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
LucasWilkinson
merged 1 commit into
vllm-project:main
from
MatthewBonanni:feature/fp8_mla_cutlass_blackwell
Sep 3, 2025
+186
−107
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,96 +1,180 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| import math | ||
| import random | ||
|
|
||
| import pytest | ||
| import torch | ||
| import torch.nn.functional as F | ||
| from torch import Tensor | ||
|
|
||
| import vllm._custom_ops as ops | ||
| from vllm.platforms import current_platform | ||
|
|
||
| if not current_platform.has_device_capability(100): | ||
| pytest.skip( | ||
| reason="Cutlass MLA Requires compute capability of 10 or above.", | ||
| allow_module_level=True) | ||
|
|
||
|
|
||
| def ref_mla( | ||
| out: Tensor, # (bs, num_heads, v_head_dim) | ||
| query: Tensor, # (bs, num_heads, head_dim) | ||
| kv_cache: Tensor, # (num_blocks, block_size, head_dim) | ||
| scale: float, | ||
| block_tables: Tensor, # (bs, max_num_blocks) | ||
| seq_lens: Tensor, # (bs,) | ||
| ): | ||
| bs, num_heads, v_head_dim = out.shape | ||
| head_dim = query.shape[2] | ||
|
|
||
| for i in range(bs): | ||
| # gather and flatten KV-cache | ||
| kv = kv_cache[ | ||
| block_tables[i]] # (max_num_blocks, block_size, head_dim) | ||
| kv = kv.view(1, -1, | ||
| head_dim)[:, :seq_lens[i]] # (1, seq_len, head_dim) | ||
| v = kv[:, :, :v_head_dim] | ||
|
|
||
| q = query[i].view(num_heads, 1, head_dim) | ||
| o = F.scaled_dot_product_attention(q, | ||
| kv, | ||
| v, | ||
| scale=scale, | ||
| enable_gqa=True) | ||
| out[i] = o.view(num_heads, v_head_dim) | ||
|
|
||
| return out | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) | ||
| @pytest.mark.parametrize("mean_seq_len", [128, 1024, 4096]) | ||
| @pytest.mark.parametrize("bs", [1, 2, 4]) | ||
| from vllm.triton_utils import triton | ||
|
|
||
|
|
||
| def cal_diff(x: torch.Tensor, | ||
| y: torch.Tensor, | ||
| name: str, | ||
| use_fp8: bool = False) -> None: | ||
| x, y = x.double(), y.double() | ||
| cos_diff = 1 - 2 * (x * y).sum().item() / max( | ||
| (x * x + y * y).sum().item(), 1e-12) | ||
| if (use_fp8): | ||
| assert cos_diff < 1e-4 | ||
| else: | ||
| assert cos_diff < 1e-5 | ||
|
|
||
|
|
||
| CUTLASS_MLA_UNSUPPORTED_REASON = \ | ||
| "Cutlass MLA Requires compute capability of 10 or above." \ | ||
| if not current_platform.is_device_capability(100) \ | ||
| else "Cutlass MLA is supported" | ||
|
|
||
|
|
||
| @pytest.mark.skipif(not current_platform.has_device_capability(100), | ||
| reason=CUTLASS_MLA_UNSUPPORTED_REASON) | ||
| @pytest.mark.parametrize("b", [128]) | ||
| @pytest.mark.parametrize("s_q", [1]) | ||
| @pytest.mark.parametrize("mean_sk", [4096, 8192, 16384]) | ||
| @pytest.mark.parametrize("h_q", [16, 32, 64, 128]) | ||
| @pytest.mark.parametrize("h_kv", [1]) | ||
| @pytest.mark.parametrize("d", [576]) | ||
| @pytest.mark.parametrize("dv", [512]) | ||
| @pytest.mark.parametrize("block_size", [64]) | ||
| @pytest.mark.parametrize("causal", [True]) | ||
| @pytest.mark.parametrize("varlen", [False, True]) | ||
| @pytest.mark.parametrize("block_size", [16, 64, 128]) | ||
| def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int, | ||
| varlen: bool, block_size: int): | ||
| torch.set_default_dtype(dtype) | ||
| torch.set_default_device('cuda') | ||
| @pytest.mark.parametrize("torch_dtype", [torch.bfloat16, torch.float8_e4m3fn]) | ||
| @torch.inference_mode() | ||
| def test_cutlass_mla_decode(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, | ||
| causal, varlen, torch_dtype): | ||
| device = torch.device("cuda:0") | ||
| if torch_dtype == torch.float8_e4m3fn: | ||
| init_dtype = torch.bfloat16 | ||
| else: | ||
| init_dtype = torch_dtype | ||
| torch.set_default_dtype(init_dtype) | ||
| torch.set_default_device(device) | ||
| torch.cuda.set_device(device) | ||
| torch.manual_seed(42) | ||
| random.seed(42) | ||
|
|
||
| d = 576 | ||
| h_q = 128 | ||
| dv = 512 | ||
| print(f"{b=}, {s_q=}, {mean_sk=}, {h_q=}, {h_kv=}, " | ||
| f"{d=}, {dv=}, {causal=}, {varlen=}, {torch_dtype=}") | ||
|
|
||
| q_nope_dim = 128 | ||
| q_pe_dim = 64 | ||
| scale = (q_nope_dim + q_pe_dim)**(-0.5) | ||
| use_fp8 = torch_dtype == torch.float8_e4m3fn | ||
| scale = math.sqrt(d)**(-1) | ||
| cache_seqlens = torch.full((b, ), mean_sk, dtype=torch.int32) | ||
| if varlen: | ||
| seq_lens = torch.empty(bs).normal_(mean_seq_len, mean_seq_len / 2) | ||
| seq_lens = seq_lens.clip(2).to(torch.int32) | ||
| for i in range(b): | ||
| cache_seqlens[i] = max(random.normalvariate(mean_sk, mean_sk / 2), | ||
| s_q) | ||
| total_seqlens = cache_seqlens.sum().item() | ||
| max_seqlen = cache_seqlens.max().item() | ||
| max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 | ||
|
|
||
| q = torch.randn(b, s_q, h_q, d) | ||
| block_table = torch.arange(b * max_seqlen_pad // block_size, | ||
| dtype=torch.int32).view( | ||
| b, max_seqlen_pad // block_size) | ||
| blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) | ||
| blocked_v = blocked_k[..., :dv] | ||
|
|
||
| init_dtype = q.dtype | ||
| if use_fp8: | ||
| fp8_dtype = torch.float8_e4m3fn | ||
| descale_q = torch.ones((1), dtype=torch.float32) | ||
| descale_k = torch.ones((1), dtype=torch.float32) | ||
|
|
||
| q = q.to(fp8_dtype) | ||
| blocked_k = blocked_k.to(fp8_dtype) | ||
| blocked_v = blocked_v.to(fp8_dtype) | ||
| else: | ||
| seq_lens = torch.full((bs, ), mean_seq_len, dtype=torch.int32) | ||
| max_seq_len = seq_lens.max().item() | ||
| block_num = (max_seq_len + block_size - 1) // block_size | ||
|
|
||
| # Pad block_num so that small blocks can be packed into full 128-sized | ||
| # CUTLASS tiles. One 128-wide tile can hold (128 // block_size) small | ||
| # blocks. | ||
| pack_factor = 128 // block_size | ||
| block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor | ||
|
|
||
| # Amplify input values to ensure test coverage of edge cases where CUTLASS | ||
| # kernel errors occur with split_k settings. | ||
| q = torch.randn(bs, h_q, d) * 100 | ||
| block_table = torch.randint(0, | ||
| bs * block_num, (bs, block_num), | ||
| dtype=torch.int32) | ||
|
|
||
| kv_cache = torch.randn(block_table.numel(), block_size, d) | ||
|
|
||
| out_ref = q.new_zeros(bs, h_q, dv) | ||
| ref_mla(out_ref, q, kv_cache, scale, block_table, seq_lens) | ||
| out_ans = torch.zeros_like(out_ref) | ||
| q_nope = q[:, :, :dv].clone() | ||
| q_pe = q[:, :, dv:].clone() | ||
| ops.cutlass_mla_decode(out_ans, q_nope, q_pe, kv_cache, seq_lens, | ||
| block_table, scale) | ||
|
|
||
| torch.testing.assert_close(out_ans, out_ref, atol=1e-2, rtol=1e-2) | ||
| descale_q = None | ||
| descale_k = None | ||
|
|
||
| def cutlass_mla(): | ||
| MAX_HEADS = 128 | ||
|
|
||
| q_reshaped = q.squeeze(1) | ||
| q_nope = q_reshaped[:, :, :dv].clone() | ||
| q_pe = q_reshaped[:, :, dv:].clone() | ||
|
|
||
| if h_q < MAX_HEADS: | ||
| q_nope_padded = q_nope.new_empty((b, MAX_HEADS, dv)) | ||
| q_nope_padded[:, :h_q] = q_nope | ||
| q_nope = q_nope_padded | ||
|
|
||
| q_pe_padded = q_pe.new_empty((b, MAX_HEADS, d - dv)) | ||
| q_pe_padded[:, :h_q] = q_pe | ||
| q_pe = q_pe_padded | ||
|
|
||
| kv_cache_flat = blocked_k.squeeze(2) | ||
| device_properties = torch.cuda.get_device_properties( | ||
| torch.device("cuda:0")) | ||
| sm_count = device_properties.multi_processor_count | ||
| workspace_size = ops.sm100_cutlass_mla_get_workspace_size( | ||
| max_seqlen * block_size, b, sm_count, num_kv_splits=1) | ||
| workspace = torch.empty(workspace_size, | ||
| device="cuda", | ||
| dtype=torch.uint8) | ||
|
|
||
| out_ans = torch.empty(b, MAX_HEADS, dv, dtype=init_dtype) | ||
|
|
||
| ops.sm100_cutlass_mla_decode(out_ans, q_nope, q_pe, kv_cache_flat, | ||
| cache_seqlens, block_table, workspace, | ||
| scale, 1) | ||
| return out_ans[:, :h_q].contiguous() | ||
|
|
||
| def scaled_dot_product_attention(query, key, value, is_causal=False): | ||
| query = query.float() | ||
| key = key.float() | ||
| value = value.float() | ||
| key = key.repeat_interleave(h_q // h_kv, dim=0) | ||
| value = value.repeat_interleave(h_q // h_kv, dim=0) | ||
| attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) | ||
| if is_causal: | ||
| s_q = query.shape[-2] | ||
| s_k = key.shape[-2] | ||
| attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) | ||
| temp_mask = torch.ones(s_q, s_k, | ||
| dtype=torch.bool).tril(diagonal=s_k - s_q) | ||
| attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) | ||
| attn_bias.to(query.dtype) | ||
| attn_weight += attn_bias | ||
| lse = attn_weight.logsumexp(dim=-1) | ||
| attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) | ||
| return attn_weight @ value, lse | ||
|
|
||
| def ref_mla(): | ||
| q_ = (q.to(torch.float) * descale_q).to(init_dtype) if use_fp8 else q | ||
| blocked_k_ = (blocked_k.to(torch.float) * | ||
| descale_k).to(init_dtype) if use_fp8 else blocked_k | ||
| blocked_v_ = (blocked_v.to(torch.float) * | ||
| descale_k).to(init_dtype) if use_fp8 else blocked_v | ||
| out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) | ||
| lse = torch.empty(b, h_q, s_q, dtype=torch.float32) | ||
| for i in range(b): | ||
| begin = i * max_seqlen_pad | ||
| end = begin + cache_seqlens[i] | ||
| out_i, lse_i = scaled_dot_product_attention( | ||
| q_[i].transpose(0, 1), | ||
| blocked_k_.view(-1, h_kv, d)[begin:end].transpose(0, 1), | ||
| blocked_v_.view(-1, h_kv, dv)[begin:end].transpose(0, 1), | ||
| is_causal=causal, | ||
| ) | ||
| out[i] = out_i.transpose(0, 1) | ||
| lse[i] = lse_i | ||
| return out, lse | ||
|
|
||
| out_cutlass = cutlass_mla() | ||
| out_torch, lse_torch = ref_mla() | ||
| # Extract the single token (s_q=1) slice to match cutlass output shape | ||
| out_torch_slice = out_torch[:, 0, :, :] # [b, h_q, dv] | ||
| cal_diff(out_cutlass, out_torch_slice, "out", use_fp8) | ||
|
|
||
| t = triton.testing.do_bench(cutlass_mla) | ||
| FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 | ||
| bytes = (total_seqlens * h_kv * d + | ||
| b * s_q * h_q * d) * (torch.finfo(torch_dtype).bits // 8) + ( | ||
| b * s_q * h_q * dv) * (torch.finfo(init_dtype).bits // 8) | ||
| print(f"{t:.3f} ms, {FLOPS / 10 ** 9 / t:.0f} TFLOPS,", | ||
| f"{bytes / 10 ** 6 / t:.0f} GB/s") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.