diff --git a/.buildkite/test_areas/kernels.yaml b/.buildkite/test_areas/kernels.yaml index 9328cad4bf94..e0be49cf39c3 100644 --- a/.buildkite/test_areas/kernels.yaml +++ b/.buildkite/test_areas/kernels.yaml @@ -8,8 +8,9 @@ steps: - csrc/ - tests/kernels/core - tests/kernels/test_top_k_per_row.py + - tests/kernels/test_concat_mla_q.py commands: - - pytest -v -s kernels/core kernels/test_top_k_per_row.py + - pytest -v -s kernels/core kernels/test_top_k_per_row.py kernels/test_concat_mla_q.py - label: Kernels Attention Test %N timeout_in_minutes: 35 diff --git a/benchmarks/kernels/bench_concat_mla_q.py b/benchmarks/kernels/bench_concat_mla_q.py new file mode 100644 index 000000000000..8d940484d6b3 --- /dev/null +++ b/benchmarks/kernels/bench_concat_mla_q.py @@ -0,0 +1,98 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse + +import torch + +from vllm import _custom_ops as ops +from vllm.triton_utils import triton + +# DeepSeek V3 dimensions +NOPE_DIM = 512 +ROPE_DIM = 64 +NUM_HEADS = 128 + +NUM_TOKENS = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] + + +def get_configs(): + return NUM_TOKENS + + +def make_inputs(num_tokens, dtype): + """Create inputs matching the real code path. + + Returns: + (ql_nope, q_pe): ql_nope is a transposed [num_tokens, NUM_HEADS, + NOPE_DIM] view that simulates the BMM output after + [N,B,L].transpose(0,1); q_pe is a freshly allocated tensor. 
+ """ + # Simulate: bmm output [N, B, L].transpose(0, 1) -> [B, N, L] + raw = torch.randn(NUM_HEADS, num_tokens, NOPE_DIM, dtype=dtype, device="cuda") + ql_nope = raw.transpose(0, 1) + + q_pe = torch.randn(num_tokens, NUM_HEADS, ROPE_DIM, dtype=dtype, device="cuda") + return ql_nope, q_pe + + +# ---- Non-contiguous nope benchmark (real code path) ---- +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["num_tokens"], + x_vals=get_configs(), + line_arg="provider", + line_vals=["torch_cat", "concat_mla_q"], + line_names=["torch.cat", "concat_mla_q (v8)"], + styles=[("blue", "--"), ("green", "-")], + ylabel="Latency (us)", + plot_name="concat_mla_q-transposed", + args={}, + ) +) +def bench_transposed(num_tokens, provider): + dtype = torch.bfloat16 + ql_nope, q_pe = make_inputs(num_tokens, dtype) + + q_out = torch.empty( + num_tokens, NUM_HEADS, NOPE_DIM + ROPE_DIM, dtype=dtype, device="cuda" + ) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch_cat": + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: torch.cat((ql_nope, q_pe), dim=-1), quantiles=quantiles, rep=500 + ) + else: + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: ops.concat_mla_q(ql_nope, q_pe, q_out), quantiles=quantiles, rep=500 + ) + + return ms * 1000, min_ms * 1000, max_ms * 1000 # (median, p20, p80) in us + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Benchmark concat_mla_q vs torch.cat") + parser.add_argument( + "--save-path", type=str, default=None, help="Path to save benchmark results" + ) + args = parser.parse_args() + + print("\n" + "=" * 70) + print("CONCAT MLA Q KERNEL BENCHMARKS") + print("=" * 70) + print(f"Dimensions: nope={NOPE_DIM}, rope={ROPE_DIM}, heads={NUM_HEADS}") + print( + f"Per-head output: {NOPE_DIM + ROPE_DIM} bf16 = " + f"{(NOPE_DIM + ROPE_DIM) * 2} bytes" + ) + print(f"num_tokens (decode=batch_size, prefill=chunk_size): {NUM_TOKENS}") + print("=" * 70) + + print("\n--- Non-contiguous nope inputs (transposed 
BMM output) ---") + bench_transposed.run(print_data=True, save_path=args.save_path) + + print("\n" + "=" * 70) + print("Benchmarking complete!") + print("=" * 70) diff --git a/csrc/cache.h b/csrc/cache.h index 0c7823ffe9e2..0188a568edc7 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -74,6 +74,12 @@ void indexer_k_quant_and_cache( int64_t quant_block_size, // quantization block size const std::string& scale_fmt); +// Concatenate query nope and rope for MLA/DSA attention +void concat_mla_q( + torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim] + torch::Tensor& q_pe, // [num_tokens, num_heads, rope_dim] + torch::Tensor& q_out); // [num_tokens, num_heads, nope_dim + rope_dim] + // Extract function to gather quantized K cache void cp_gather_indexer_k_quant_cache( const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 3e8ffe15b42d..71050cfcad1a 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -8,6 +8,7 @@ #include "cuda_compat.h" #include "dispatch_utils.h" #include "quantization/vectorization_utils.cuh" +#include "concat_mla_q.cuh" #ifdef USE_ROCM #include "quantization/w8a8/fp8/amd/quant_utils.cuh" @@ -1365,3 +1366,43 @@ void cp_gather_indexer_k_quant_cache( CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(32); } } + +// Concatenate ql_nope and q_pe into a contiguous q_out tensor for MLA/DSA. +// Replaces torch.cat((ql_nope, q_pe), dim=-1). 
+void concat_mla_q(torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim] + torch::Tensor& q_pe, // [num_tokens, num_heads, rope_dim] + torch::Tensor& q_out // [num_tokens, num_heads, nope_dim + + // rope_dim] +) { + const int num_tokens = ql_nope.size(0); + const int num_heads = ql_nope.size(1); + const int nope_dim = ql_nope.size(2); + const int rope_dim = q_pe.size(2); + + TORCH_CHECK(nope_dim == 512, "nope_dim must be 512 (kernel NOPE_DIM), got ", + nope_dim); + TORCH_CHECK(rope_dim == 64, "rope_dim must be 64, got ", rope_dim); + TORCH_CHECK(q_out.size(2) == nope_dim + rope_dim); + + TORCH_CHECK(ql_nope.stride(2) == 1, "ql_nope must have stride 1 in dim 2"); + TORCH_CHECK(q_pe.stride(2) == 1, "q_pe must have stride 1 in dim 2"); + TORCH_CHECK(q_out.stride(2) == 1, "q_out must have stride 1 in dim 2"); + + if (num_tokens == 0) return; + + constexpr int warps_per_block = 8; + const int total_warps = num_tokens * num_heads; + const int grid_size = (total_warps + warps_per_block - 1) / warps_per_block; + const int block_size = warps_per_block * 32; + + const at::cuda::OptionalCUDAGuard device_guard(device_of(ql_nope)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + VLLM_DISPATCH_FLOATING_TYPES(ql_nope.scalar_type(), "concat_mla_q", [&] { + vllm::ConcatMLAQKernel<<>>( + q_out.data_ptr(), ql_nope.data_ptr(), + q_pe.data_ptr(), num_tokens, num_heads, q_out.stride(0), + q_out.stride(1), ql_nope.stride(0), ql_nope.stride(1), q_pe.stride(0), + q_pe.stride(1)); + }); +} diff --git a/csrc/concat_mla_q.cuh b/csrc/concat_mla_q.cuh new file mode 100644 index 000000000000..68bcfa011fb3 --- /dev/null +++ b/csrc/concat_mla_q.cuh @@ -0,0 +1,60 @@ +#ifndef CONCAT_MLA_Q_CUH_ +#define CONCAT_MLA_Q_CUH_ + +#include +#include + +#include "cuda_vec_utils.cuh" + +namespace vllm { + +// Concatenates ql_nope [num_tokens, num_heads, NOPE_DIM] and +// q_pe [num_tokens, num_heads, 64] +// into q_out [num_tokens, num_heads, NOPE_DIM+64]. 
+// Currently instantiated only for NOPE_DIM=512. +// Rope dim is hardcoded to 64 (DeepSeek V3.2 MLA); assumes 2-byte DType. +template +__global__ void ConcatMLAQKernel( + DType* __restrict__ q_out, const DType* __restrict__ ql_nope, + const DType* __restrict__ q_pe, const int num_tokens, const int num_heads, + const int64_t out_stride_0, const int64_t out_stride_1, + const int64_t nope_stride_0, const int64_t nope_stride_1, + const int64_t pe_stride_0, const int64_t pe_stride_1) { + const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) >> 5; + if (flat_warp_id >= num_tokens * num_heads) return; + + const int token_id = flat_warp_id / num_heads; + const int head_id = flat_warp_id % num_heads; + const int lane_id = threadIdx.x & 31; + + constexpr bool use_256b = VLLM_256B_PTX_ENABLED; + constexpr int nope_vec_loads = + NOPE_DIM * sizeof(DType) / (VecTraits::ARCH_MAX_VEC_SIZE * 32); + + const DType* nope_src = + ql_nope + token_id * nope_stride_0 + head_id * nope_stride_1; + DType* nope_dst = q_out + token_id * out_stride_0 + head_id * out_stride_1; + +#pragma unroll + for (int i = 0; i < nope_vec_loads; i++) { + const int offset = i * 32 + lane_id; + if constexpr (use_256b) { + st256_cs(reinterpret_cast(nope_dst) + offset, + ld256_cs(reinterpret_cast(nope_src) + offset)); + } else { + st128_cs(reinterpret_cast(nope_dst) + offset, + ld128_cs(reinterpret_cast(nope_src) + offset)); + } + } + + const int* rope_src = reinterpret_cast( + q_pe + token_id * pe_stride_0 + head_id * pe_stride_1); + int* rope_dst = reinterpret_cast(q_out + token_id * out_stride_0 + + head_id * out_stride_1 + NOPE_DIM); + + st32_cs(rope_dst + lane_id, ld32_cs(rope_src + lane_id)); +} + +} // namespace vllm + +#endif // CONCAT_MLA_Q_CUH_ diff --git a/csrc/cuda_vec_utils.cuh b/csrc/cuda_vec_utils.cuh index 82a19f10a70e..8f997f3ba409 100644 --- a/csrc/cuda_vec_utils.cuh +++ b/csrc/cuda_vec_utils.cuh @@ -196,7 +196,6 @@ __forceinline__ __device__ u32x8_t ld256_cs(const u32x8_t* addr) { return val; #else 
assert(false && "ld256_cs requires SM100+ with CUDA 12.9+"); - return {}; #endif } @@ -211,23 +210,51 @@ __forceinline__ __device__ void st256_cs(u32x8_t* addr, u32x8_t val) { #endif } -// 32-bit cache-streaming (.cs) load / store — SM100+ only. +// 32-bit load / store. +__device__ __forceinline__ int ld32(const int* addr) { return __ldg(addr); } + +__device__ __forceinline__ void st32(int* addr, int val) { *addr = val; } + +// 32-bit cache-streaming (.cs) load / store. +// Falls back to ld32/st32 on ROCm (no .cs hint). __forceinline__ __device__ int ld32_cs(const int* addr) { -#if VLLM_256B_PTX_ENABLED int val; +#ifndef USE_ROCM asm volatile("ld.global.cs.b32 %0, [%1];" : "=r"(val) : "l"(addr)); - return val; #else - assert(false && "ld32_cs requires SM100+ with CUDA 12.9+"); - return 0; + val = ld32(addr); #endif + return val; } __forceinline__ __device__ void st32_cs(int* addr, int val) { -#if VLLM_256B_PTX_ENABLED +#ifndef USE_ROCM asm volatile("st.global.cs.b32 [%0], %1;" ::"l"(addr), "r"(val)); #else - assert(false && "st32_cs requires SM100+ with CUDA 12.9+"); + st32(addr, val); +#endif +} + +// 128-bit cache-streaming (.cs) load / store. +// Falls back to ld128/st128 on ROCm (no .cs hint). 
+__forceinline__ __device__ int4 ld128_cs(const int4* addr) { + int4 val; +#ifndef USE_ROCM + asm volatile("ld.global.cs.v4.u32 {%0,%1,%2,%3}, [%4];" + : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) + : "l"(addr)); +#else + ld128(val, addr); +#endif + return val; +} + +__forceinline__ __device__ void st128_cs(int4* addr, int4 val) { +#ifndef USE_ROCM + asm volatile("st.global.cs.v4.u32 [%0], {%1,%2,%3,%4};" ::"l"(addr), + "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w)); +#else + st128(val, addr); #endif } @@ -260,7 +287,7 @@ __device__ __forceinline__ void ld256_cg_or_zero(u32x8_t& val, const void* ptr, __device__ __forceinline__ void ld128_cg_or_zero(uint4& val, const void* ptr, bool pred) { -#if VLLM_256B_PTX_ENABLED +#ifndef USE_ROCM uint32_t r0, r1, r2, r3; asm volatile( @@ -278,7 +305,7 @@ __device__ __forceinline__ void ld128_cg_or_zero(uint4& val, const void* ptr, val = uint4{r0, r1, r2, r3}; #else - assert(false && "ld128_cg_or_zero requires SM100+ with CUDA 12.9+"); + assert(false && "ld128_cg_or_zero is not supported on ROCm"); #endif } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index f7ea8c788dd0..d98e987d92a2 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -802,6 +802,10 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA, &indexer_k_quant_and_cache); + cache_ops.def( + "concat_mla_q(Tensor ql_nope, Tensor q_pe, Tensor! q_out) -> ()"); + cache_ops.impl("concat_mla_q", torch::kCUDA, &concat_mla_q); + cache_ops.def( "cp_gather_indexer_k_quant_cache(Tensor kv_cache, Tensor! dst_k, Tensor! 
" "dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()"); diff --git a/tests/kernels/test_concat_mla_q.py b/tests/kernels/test_concat_mla_q.py new file mode 100644 index 000000000000..fec5c063c7ca --- /dev/null +++ b/tests/kernels/test_concat_mla_q.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm import _custom_ops as ops + +NUM_TOKENS = [1, 4, 16, 64, 128] +NUM_HEADS = [128] +NOPE_DIM = [512] +ROPE_DIM = [64] +DTYPES = [torch.bfloat16, torch.float16] + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("nope_dim", NOPE_DIM) +@pytest.mark.parametrize("rope_dim", ROPE_DIM) +@pytest.mark.parametrize("dtype", DTYPES) +def test_concat_mla_q_contiguous(num_tokens, num_heads, nope_dim, rope_dim, dtype): + """Test with contiguous inputs (standard layout).""" + torch.manual_seed(42) + ql_nope = torch.randn(num_tokens, num_heads, nope_dim, dtype=dtype, device="cuda") + q_pe = torch.randn(num_tokens, num_heads, rope_dim, dtype=dtype, device="cuda") + + ref = torch.cat((ql_nope, q_pe), dim=-1) + + q_out = torch.empty( + num_tokens, num_heads, nope_dim + rope_dim, dtype=dtype, device="cuda" + ) + ops.concat_mla_q(ql_nope, q_pe, q_out) + + torch.testing.assert_close(q_out, ref, atol=0, rtol=0) + + +@pytest.mark.parametrize("num_tokens", [t for t in NUM_TOKENS if t > 1]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("nope_dim", NOPE_DIM) +@pytest.mark.parametrize("rope_dim", ROPE_DIM) +@pytest.mark.parametrize("dtype", DTYPES) +def test_concat_mla_q_transposed_nope(num_tokens, num_heads, nope_dim, rope_dim, dtype): + """Test with transposed nope input (simulates BMM output after transpose). + + In the real code path, mqa_ql_nope is the result of: + torch.bmm(q_nope, W_UK_T) # [N, B, L] + .transpose(0, 1) # [B, N, L] — non-contiguous! 
+ """ + torch.manual_seed(42) + nope_raw = torch.randn(num_heads, num_tokens, nope_dim, dtype=dtype, device="cuda") + ql_nope = nope_raw.transpose(0, 1) # [B, N, L], non-contiguous + assert not ql_nope.is_contiguous() + + q_pe = torch.randn(num_tokens, num_heads, rope_dim, dtype=dtype, device="cuda") + + ref = torch.cat((ql_nope, q_pe), dim=-1) + + q_out = torch.empty( + num_tokens, num_heads, nope_dim + rope_dim, dtype=dtype, device="cuda" + ) + ops.concat_mla_q(ql_nope, q_pe, q_out) + + torch.testing.assert_close(q_out, ref, atol=0, rtol=0) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_concat_mla_q_split_rope(num_tokens, num_heads, dtype): + """Test with rope from a split (simulates the actual code path). + + In the real code path, q_pe comes from: + mqa_q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) + which creates a non-contiguous view with stride(1) != rope_dim. + """ + torch.manual_seed(42) + nope_dim = 512 + rope_dim = 64 + orig_dim = 128 + 64 # original q before absorption: [B, N, 192] + + # Simulate split from original q tensor + q_orig = torch.randn(num_tokens, num_heads, orig_dim, dtype=dtype, device="cuda") + q_nope_orig, q_pe = q_orig.split([128, 64], dim=-1) + + # q_pe is non-contiguous: stride(1) = 192, not 64 + assert q_pe.stride(1) == orig_dim + assert q_pe.stride(2) == 1 # but innermost is fine + + # Simulate absorbed nope (contiguous, different size) + ql_nope = torch.randn(num_tokens, num_heads, nope_dim, dtype=dtype, device="cuda") + + ref = torch.cat((ql_nope, q_pe), dim=-1) + + q_out = torch.empty( + num_tokens, num_heads, nope_dim + rope_dim, dtype=dtype, device="cuda" + ) + ops.concat_mla_q(ql_nope, q_pe, q_out) + + torch.testing.assert_close(q_out, ref, atol=0, rtol=0) + + +def test_concat_mla_q_zero_tokens(): + """Test with zero tokens (edge case).""" + ql_nope = torch.empty(0, 128, 512, dtype=torch.bfloat16, 
device="cuda") + q_pe = torch.empty(0, 128, 64, dtype=torch.bfloat16, device="cuda") + q_out = torch.empty(0, 128, 576, dtype=torch.bfloat16, device="cuda") + + ops.concat_mla_q(ql_nope, q_pe, q_out) + + +@pytest.mark.parametrize("num_tokens", [1, 64]) +def test_concat_mla_q_values_preserved(num_tokens): + """Verify exact bit-level preservation (no computation, pure copy). + + Compares raw int16 bits to avoid NaN != NaN issues from IEEE 754. + """ + nope_dim, rope_dim = 512, 64 + + # Use specific bit patterns (stay in int16 for bit-exact comparison) + ql_nope_bits = torch.arange( + num_tokens * 128 * nope_dim, dtype=torch.int16, device="cuda" + ).view(num_tokens, 128, nope_dim) + q_pe_bits = torch.arange( + num_tokens * 128 * rope_dim, dtype=torch.int16, device="cuda" + ).view(num_tokens, 128, rope_dim) + + ql_nope = ql_nope_bits.view(torch.bfloat16) + q_pe = q_pe_bits.view(torch.bfloat16) + + q_out = torch.empty( + num_tokens, 128, nope_dim + rope_dim, dtype=torch.bfloat16, device="cuda" + ) + ops.concat_mla_q(ql_nope, q_pe, q_out) + + out_bits = q_out.view(torch.int16) + + assert torch.equal(out_bits[..., :nope_dim], ql_nope_bits) + + assert torch.equal(out_bits[..., nope_dim:], q_pe_bits) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index e03a4c149689..dd2cca9b7443 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2672,6 +2672,21 @@ def cp_gather_and_upconvert_fp8_kv_cache( ) +def concat_mla_q( + ql_nope: torch.Tensor, + q_pe: torch.Tensor, + q_out: torch.Tensor, +) -> None: + """Concatenate query nope and rope for MLA/DSA attention. 
+ + Args: + ql_nope: Query nope component [num_tokens, num_heads, nope_dim] + q_pe: Query rope component [num_tokens, num_heads, rope_dim] + q_out: Output tensor [num_tokens, num_heads, nope_dim + rope_dim] + """ + torch.ops._C_cache_ops.concat_mla_q(ql_nope, q_pe, q_out) + + def indexer_k_quant_and_cache( k: torch.Tensor, kv_cache: torch.Tensor, diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index c0cdc204d2df..7cc50ec84584 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -568,6 +568,9 @@ def __init__( ) self.fp8_decode_padded_heads = self._compute_fp8_decode_padded_heads(num_heads) + vllm_config = get_current_vllm_config() + max_tokens = vllm_config.scheduler_config.max_num_batched_tokens + q_concat_shape = (max_tokens, num_heads, head_size) if kv_cache_dtype.startswith("fp8"): assert kv_cache_dtype == "fp8_ds_mla", ( "FlashMLA Sparse Attention backend fp8 only supports " @@ -576,17 +579,21 @@ def __init__( if kv_cache_dtype == "fp8_ds_mla": # Reserve workspace during initialization - vllm_config = get_current_vllm_config() assert vllm_config is not None and vllm_config.model_config is not None prefill_workspace_size = get_prefill_workspace_size( vllm_config.model_config.max_model_len ) self.prefill_workspace_shape = (prefill_workspace_size, head_size) - (self.prefill_bf16_workspace,) = ( + self.q_concat_buffer, self.prefill_bf16_workspace = ( current_workspace_manager().get_simultaneous( - (self.prefill_workspace_shape, torch.bfloat16) + (q_concat_shape, torch.bfloat16), + (self.prefill_workspace_shape, torch.bfloat16), ) ) + else: + (self.q_concat_buffer,) = current_workspace_manager().get_simultaneous( + (q_concat_shape, torch.bfloat16), + ) def _forward_bf16_kv( self, @@ -828,7 +835,9 @@ def forward_mqa( # Concatenate q if it's a tuple (ql_nope, q_pe) if isinstance(q, tuple): - q = torch.cat(q, dim=-1) + ql_nope, q_pe = 
q + q = self.q_concat_buffer[: ql_nope.shape[0]] + ops.concat_mla_q(ql_nope, q_pe, q) num_actual_toks = q.shape[0]