From e983155256652376d5a3baaad6bf08f82fe638fb Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Thu, 19 Feb 2026 20:14:24 +0000 Subject: [PATCH 01/11] add custom concat kernel Signed-off-by: LopezCastroRoberto --- csrc/cache.h | 5 ++ csrc/cache_kernels.cu | 65 ++++++++++++++ csrc/concat_mla_q.cuh | 85 +++++++++++++++++++ csrc/torch_bindings.cpp | 4 + vllm/_custom_ops.py | 13 +++ .../attention/backends/mla/flashmla_sparse.py | 6 +- 6 files changed, 177 insertions(+), 1 deletion(-) create mode 100644 csrc/concat_mla_q.cuh diff --git a/csrc/cache.h b/csrc/cache.h index 0c7823ffe9e2..e2c0df3ebcdd 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -74,6 +74,11 @@ void indexer_k_quant_and_cache( int64_t quant_block_size, // quantization block size const std::string& scale_fmt); +// Concatenate query nope and rope for MLA/DSA attention +void concat_mla_q(torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim] + torch::Tensor& q_pe, // [num_tokens, num_heads, rope_dim] + torch::Tensor& q_out); // [num_tokens, num_heads, nope_dim + rope_dim] + // Extract function to gather quantized K cache void cp_gather_indexer_k_quant_cache( const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 3e8ffe15b42d..b633f1257cee 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -8,6 +8,7 @@ #include "cuda_compat.h" #include "dispatch_utils.h" #include "quantization/vectorization_utils.cuh" +#include "concat_mla_q.cuh" #ifdef USE_ROCM #include "quantization/w8a8/fp8/amd/quant_utils.cuh" @@ -1365,3 +1366,67 @@ void cp_gather_indexer_k_quant_cache( CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(32); } } + + +// Concatenate ql_nope and q_pe into a contiguous q_out tensor for MLA/DSA. +// Replaces torch.cat((ql_nope, q_pe), dim=-1). +void concat_mla_q(torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim] + torch::Tensor& q_pe, // [num_tokens, num_heads, rope_dim] + torch::Tensor& q_out // [num_tokens, num_heads, nope_dim + + // rope_dim] +) { + const int num_tokens = ql_nope.size(0); + const int num_heads = ql_nope.size(1); + const int nope_dim = ql_nope.size(2); + const int rope_dim = q_pe.size(2); + + TORCH_CHECK(nope_dim % 512 == 0, + "nope_dim must be a multiple of 512, got ", nope_dim); + TORCH_CHECK(rope_dim == 64, "rope_dim must be 64, got ", rope_dim); + TORCH_CHECK(q_out.size(2) == nope_dim + rope_dim); + + TORCH_CHECK(ql_nope.stride(2) == 1, "ql_nope must have stride 1 in dim 2"); + TORCH_CHECK(q_pe.stride(2) == 1, "q_pe must have stride 1 in dim 2"); + TORCH_CHECK(q_out.stride(2) == 1, "q_out must have stride 1 in dim 2"); + + if (num_tokens == 0) return; + + const int nope_v8_loads = nope_dim / 512; + + constexpr int warps_per_block = 32; + const int total_warps = num_tokens * num_heads; + const int grid_size = + (total_warps + warps_per_block - 1) / warps_per_block; + const int block_size = warps_per_block * 32; + + const at::cuda::OptionalCUDAGuard device_guard(device_of(ql_nope)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + VLLM_DISPATCH_FLOATING_TYPES( + ql_nope.scalar_type(), "concat_mla_q", [&] { + auto* out_ptr = q_out.data_ptr(); + auto* nope_ptr = ql_nope.data_ptr(); + auto* pe_ptr = q_pe.data_ptr(); + auto out_s0 = q_out.stride(0); + auto out_s1 = q_out.stride(1); + auto nope_s0 = ql_nope.stride(0); + auto nope_s1 = ql_nope.stride(1); + auto pe_s0 = q_pe.stride(0); + auto pe_s1 = q_pe.stride(1); + + if (nope_v8_loads == 1) { + vllm::ConcatMLAQKernel + <<>>( + out_ptr, nope_ptr, pe_ptr, num_tokens, num_heads, nope_dim, + out_s0, out_s1, nope_s0, nope_s1, pe_s0, pe_s1); + } else if (nope_v8_loads == 2) { + vllm::ConcatMLAQKernel + <<>>( + out_ptr, nope_ptr, pe_ptr, num_tokens, num_heads, nope_dim, + out_s0, out_s1, nope_s0, nope_s1, pe_s0, pe_s1); + } else { + TORCH_CHECK(false, "Unsupported nope_dim: ", nope_dim, + " (nope_v8_loads=", nope_v8_loads, ")"); + } + }); +} diff --git a/csrc/concat_mla_q.cuh b/csrc/concat_mla_q.cuh new file mode 100644 index 000000000000..9b171cb6f856 --- /dev/null +++ b/csrc/concat_mla_q.cuh @@ -0,0 +1,85 @@ + #ifndef CONCAT_MLA_Q_CUH_ + #define CONCAT_MLA_Q_CUH_ + + #include + #include + #include + + namespace vllm { + + struct __align__(32) vec8 { + unsigned int d[8]; + }; + + __forceinline__ __device__ vec8 ld_cs_v8(const vec8* addr) { + vec8 val; + asm volatile( + "ld.global.cs.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];" + : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]), + "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7]) + : "l"(addr)); + return val; + } + + __forceinline__ __device__ void st_cs_v8(vec8* addr, vec8 val) { + asm volatile( + "st.global.cs.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};" + ::"l"(addr), + "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]), "r"(val.d[3]), + "r"(val.d[4]), "r"(val.d[5]), "r"(val.d[6]), "r"(val.d[7])); + } + + __forceinline__ __device__ int ld_cs_v1(const int* addr) { + int val; + asm volatile("ld.global.cs.b32 %0, [%1];" : "=r"(val) : "l"(addr)); + return val; + } + + __forceinline__ __device__ void st_cs_v1(int* addr, int val) { + asm volatile("st.global.cs.b32 [%0], %1;" ::"l"(addr), "r"(val)); + } + + template + __global__ void ConcatMLAQKernel( + DType* __restrict__ q_out, + const DType* __restrict__ ql_nope, + const DType* __restrict__ q_pe, + const int num_tokens, + const int num_heads, + const int nope_dim, + const int64_t out_stride_0, + const int64_t out_stride_1, + const int64_t nope_stride_0, + const int64_t nope_stride_1, + const int64_t pe_stride_0, + const int64_t pe_stride_1) { + const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) >> 5; + if (flat_warp_id >= num_tokens * num_heads) return; + + const int token_id = flat_warp_id / num_heads; + const int head_id = flat_warp_id % num_heads; + const int lane_id = threadIdx.x & 31; + + const vec8* nope_src = reinterpret_cast( + ql_nope + token_id * nope_stride_0 + head_id * nope_stride_1); + vec8* nope_dst = reinterpret_cast( + q_out + token_id * out_stride_0 + head_id * out_stride_1); + + #pragma unroll + for (int i = 0; i < NOPE_V8_LOADS; i++) { + const int offset = i * 32 + lane_id; + st_cs_v8(nope_dst + offset, ld_cs_v8(nope_src + offset)); + } + + const int* rope_src = reinterpret_cast( + q_pe + token_id * pe_stride_0 + head_id * pe_stride_1); + int* rope_dst = reinterpret_cast( + q_out + token_id * out_stride_0 + head_id * out_stride_1 + nope_dim); + + st_cs_v1(rope_dst + lane_id, ld_cs_v1(rope_src + lane_id)); + } + + } // namespace vllm + + #endif // CONCAT_MLA_Q_CUH_ + \ No newline at end of file diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 97c9eb7428c9..15a9e067207b 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -779,6 +779,10 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA, &indexer_k_quant_and_cache); + cache_ops.def( + "concat_mla_q(Tensor ql_nope, Tensor q_pe, Tensor! q_out) -> ()"); + cache_ops.impl("concat_mla_q", torch::kCUDA, &concat_mla_q); + cache_ops.def( "cp_gather_indexer_k_quant_cache(Tensor kv_cache, Tensor! dst_k, Tensor! " "dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()"); diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 9268eea50924..808a9e28df23 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2567,6 +2567,19 @@ def cp_gather_and_upconvert_fp8_kv_cache( src_cache, dst, block_table, seq_lens, workspace_starts, batch_size ) +def concat_mla_q( + ql_nope: torch.Tensor, + q_pe: torch.Tensor, + q_out: torch.Tensor, +) -> None: + """Concatenate query nope and rope for MLA/DSA attention. + + Args: + ql_nope: Query nope component [num_tokens, num_heads, nope_dim] + q_pe: Query rope component [num_tokens, num_heads, rope_dim] + q_out: Output tensor [num_tokens, num_heads, nope_dim + rope_dim] + """ + torch.ops._C_cache_ops.concat_mla_q(ql_nope, q_pe, q_out) def indexer_k_quant_and_cache( k: torch.Tensor, diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 799c77d73ad2..df4616fa1f8c 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -821,7 +821,11 @@ def forward_mqa( # Concatenate q if it's a tuple (ql_nope, q_pe) if isinstance(q, tuple): - q = torch.cat(q, dim=-1) + ql_nope, q_pe = q + q = torch.empty(ql_nope.shape[0], ql_nope.shape[1], + ql_nope.shape[2] + q_pe.shape[2], + dtype=ql_nope.dtype, device=ql_nope.device) + ops.concat_mla_q(ql_nope, q_pe, q) num_actual_toks = q.shape[0] From 0ce15f9da75ee33fe1364f51a47a26c712118a00 Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Thu, 19 Feb 2026 23:37:49 +0000 Subject: [PATCH 02/11] add tests Signed-off-by: LopezCastroRoberto --- tests/kernels/test_concat_mla_q.py | 141 ++++++++++++++++++ .../attention/backends/mla/flashmla_sparse.py | 15 +- 2 files changed, 151 insertions(+), 5 deletions(-) create mode 100644 tests/kernels/test_concat_mla_q.py diff --git a/tests/kernels/test_concat_mla_q.py b/tests/kernels/test_concat_mla_q.py new file mode 100644 index 000000000000..4863250b1969 --- /dev/null +++ b/tests/kernels/test_concat_mla_q.py @@ -0,0 +1,141 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from vllm import _custom_ops as ops + +NUM_TOKENS = [1, 4, 16, 64, 128] +NUM_HEADS = [128] +NOPE_DIM = [512] +ROPE_DIM = [64] +DTYPES = [torch.bfloat16, torch.float16] + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("nope_dim", NOPE_DIM) +@pytest.mark.parametrize("rope_dim", ROPE_DIM) +@pytest.mark.parametrize("dtype", DTYPES) +def test_concat_mla_q_contiguous(num_tokens, num_heads, nope_dim, rope_dim, + dtype): + """Test with contiguous inputs (standard layout).""" + torch.manual_seed(42) + ql_nope = torch.randn(num_tokens, num_heads, nope_dim, dtype=dtype, + device="cuda") + q_pe = torch.randn(num_tokens, num_heads, rope_dim, dtype=dtype, + device="cuda") + + ref = torch.cat((ql_nope, q_pe), dim=-1) + + q_out = torch.empty(num_tokens, num_heads, nope_dim + rope_dim, + dtype=dtype, device="cuda") + ops.concat_mla_q(ql_nope, q_pe, q_out) + + torch.testing.assert_close(q_out, ref, atol=0, rtol=0) + + +@pytest.mark.parametrize("num_tokens", [t for t in NUM_TOKENS if t > 1]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("nope_dim", NOPE_DIM) +@pytest.mark.parametrize("rope_dim", ROPE_DIM) +@pytest.mark.parametrize("dtype", DTYPES) +def test_concat_mla_q_transposed_nope(num_tokens, num_heads, nope_dim, + rope_dim, dtype): + """Test with transposed nope input (simulates BMM output after transpose). + + In the real code path, mqa_ql_nope is the result of: + torch.bmm(q_nope, W_UK_T) # [N, B, L] + .transpose(0, 1) # [B, N, L] — non-contiguous! + """ + torch.manual_seed(42) + nope_raw = torch.randn(num_heads, num_tokens, nope_dim, dtype=dtype, + device="cuda") + ql_nope = nope_raw.transpose(0, 1) # [B, N, L], non-contiguous + assert not ql_nope.is_contiguous() + + q_pe = torch.randn(num_tokens, num_heads, rope_dim, dtype=dtype, + device="cuda") + + ref = torch.cat((ql_nope, q_pe), dim=-1) + + q_out = torch.empty(num_tokens, num_heads, nope_dim + rope_dim, + dtype=dtype, device="cuda") + ops.concat_mla_q(ql_nope, q_pe, q_out) + + torch.testing.assert_close(q_out, ref, atol=0, rtol=0) + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_concat_mla_q_split_rope(num_tokens, num_heads, dtype): + """Test with rope from a split (simulates the actual code path). + + In the real code path, q_pe comes from: + mqa_q.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) + which creates a non-contiguous view with stride(1) != rope_dim. + """ + torch.manual_seed(42) + nope_dim = 512 + rope_dim = 64 + orig_dim = 128 + 64 # original q before absorption: [B, N, 192] + + # Simulate split from original q tensor + q_orig = torch.randn(num_tokens, num_heads, orig_dim, dtype=dtype, + device="cuda") + q_nope_orig, q_pe = q_orig.split([128, 64], dim=-1) + + # q_pe is non-contiguous: stride(1) = 192, not 64 + assert q_pe.stride(1) == orig_dim + assert q_pe.stride(2) == 1 # but innermost is fine + + # Simulate absorbed nope (contiguous, different size) + ql_nope = torch.randn(num_tokens, num_heads, nope_dim, dtype=dtype, + device="cuda") + + ref = torch.cat((ql_nope, q_pe), dim=-1) + + q_out = torch.empty(num_tokens, num_heads, nope_dim + rope_dim, + dtype=dtype, device="cuda") + ops.concat_mla_q(ql_nope, q_pe, q_out) + + torch.testing.assert_close(q_out, ref, atol=0, rtol=0) + + +def test_concat_mla_q_zero_tokens(): + """Test with zero tokens (edge case).""" + ql_nope = torch.empty(0, 128, 512, dtype=torch.bfloat16, device="cuda") + q_pe = torch.empty(0, 128, 64, dtype=torch.bfloat16, device="cuda") + q_out = torch.empty(0, 128, 576, dtype=torch.bfloat16, device="cuda") + + ops.concat_mla_q(ql_nope, q_pe, q_out) + + +@pytest.mark.parametrize("num_tokens", [1, 64]) +def test_concat_mla_q_values_preserved(num_tokens): + """Verify exact bit-level preservation (no computation, pure copy). + + Compares raw int16 bits to avoid NaN != NaN issues from IEEE 754. + """ + nope_dim, rope_dim = 512, 64 + + # Use specific bit patterns (stay in int16 for bit-exact comparison) + ql_nope_bits = torch.arange(num_tokens * 128 * nope_dim, dtype=torch.int16, + device="cuda").view(num_tokens, 128, nope_dim) + q_pe_bits = torch.arange(num_tokens * 128 * rope_dim, dtype=torch.int16, + device="cuda").view(num_tokens, 128, rope_dim) + + ql_nope = ql_nope_bits.view(torch.bfloat16) + q_pe = q_pe_bits.view(torch.bfloat16) + + q_out = torch.empty(num_tokens, 128, nope_dim + rope_dim, + dtype=torch.bfloat16, device="cuda") + ops.concat_mla_q(ql_nope, q_pe, q_out) + + out_bits = q_out.view(torch.int16) + + assert torch.equal(out_bits[..., :nope_dim], ql_nope_bits) + + assert torch.equal(out_bits[..., nope_dim:], q_pe_bits) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index df4616fa1f8c..a30630b9fc07 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -567,9 +567,16 @@ def __init__( ) self.fp8_decode_padded_heads = self._compute_fp8_decode_padded_heads(num_heads) + vllm_config = get_current_vllm_config() + max_tokens = vllm_config.scheduler_config.max_num_batched_tokens + self.q_concat_buffer = torch.empty( + max_tokens, num_heads, head_size, + dtype=torch.bfloat16, + device=current_platform.device_type, + ) + if kv_cache_dtype == "fp8_ds_mla": - # Reserve workspace during initialization - vllm_config = get_current_vllm_config() + # Reserve workspace during initialization assert vllm_config is not None and vllm_config.model_config is not None prefill_workspace_size = get_prefill_workspace_size( vllm_config.model_config.max_model_len @@ -822,9 +829,7 @@ def forward_mqa( # Concatenate q if it's a tuple (ql_nope, q_pe) if isinstance(q, tuple): ql_nope, q_pe = q - q = torch.empty(ql_nope.shape[0], ql_nope.shape[1], - ql_nope.shape[2] + q_pe.shape[2], - dtype=ql_nope.dtype, device=ql_nope.device) + q = self.q_concat_buffer[:ql_nope.shape[0]] ops.concat_mla_q(ql_nope, q_pe, q) num_actual_toks = q.shape[0] From 09150f912ebd3b57aee0ed65684afd1a17267c22 Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Mon, 23 Feb 2026 21:06:09 +0000 Subject: [PATCH 03/11] add helper vec instructions Signed-off-by: LopezCastroRoberto --- csrc/cache_kernels.cu | 38 ++--- csrc/concat_mla_q.cuh | 145 ++++++++---------- csrc/cuda_vec_utils.cuh | 327 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 398 insertions(+), 112 deletions(-) create mode 100644 csrc/cuda_vec_utils.cuh diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index b633f1257cee..d7d4716b3e4e 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1367,7 +1367,6 @@ void cp_gather_indexer_k_quant_cache( } } - // Concatenate ql_nope and q_pe into a contiguous q_out tensor for MLA/DSA. // Replaces torch.cat((ql_nope, q_pe), dim=-1). void concat_mla_q(torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim] @@ -1391,42 +1390,25 @@ void concat_mla_q(torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim] if (num_tokens == 0) return; - const int nope_v8_loads = nope_dim / 512; - constexpr int warps_per_block = 32; const int total_warps = num_tokens * num_heads; const int grid_size = (total_warps + warps_per_block - 1) / warps_per_block; - const int block_size = warps_per_block * 32; + const int block_size = warps_per_block * 32; const at::cuda::OptionalCUDAGuard device_guard(device_of(ql_nope)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( ql_nope.scalar_type(), "concat_mla_q", [&] { - auto* out_ptr = q_out.data_ptr(); - auto* nope_ptr = ql_nope.data_ptr(); - auto* pe_ptr = q_pe.data_ptr(); - auto out_s0 = q_out.stride(0); - auto out_s1 = q_out.stride(1); - auto nope_s0 = ql_nope.stride(0); - auto nope_s1 = ql_nope.stride(1); - auto pe_s0 = q_pe.stride(0); - auto pe_s1 = q_pe.stride(1); - - if (nope_v8_loads == 1) { - vllm::ConcatMLAQKernel - <<>>( - out_ptr, nope_ptr, pe_ptr, num_tokens, num_heads, nope_dim, - out_s0, out_s1, nope_s0, nope_s1, pe_s0, pe_s1); - } else if (nope_v8_loads == 2) { - vllm::ConcatMLAQKernel - <<>>( - out_ptr, nope_ptr, pe_ptr, num_tokens, num_heads, nope_dim, - out_s0, out_s1, nope_s0, nope_s1, pe_s0, pe_s1); - } else { - TORCH_CHECK(false, "Unsupported nope_dim: ", nope_dim, - " (nope_v8_loads=", nope_v8_loads, ")"); - } + vllm::ConcatMLAQKernel + <<>>( + q_out.data_ptr(), + ql_nope.data_ptr(), + q_pe.data_ptr(), + num_tokens, num_heads, + q_out.stride(0), q_out.stride(1), + ql_nope.stride(0), ql_nope.stride(1), + q_pe.stride(0), q_pe.stride(1)); }); } diff --git a/csrc/concat_mla_q.cuh b/csrc/concat_mla_q.cuh index 9b171cb6f856..dc26662bfe62 100644 --- a/csrc/concat_mla_q.cuh +++ b/csrc/concat_mla_q.cuh @@ -1,85 +1,62 @@ - #ifndef CONCAT_MLA_Q_CUH_ - #define CONCAT_MLA_Q_CUH_ - - #include - #include - #include - - namespace vllm { - - struct __align__(32) vec8 { - unsigned int d[8]; - }; +#ifndef CONCAT_MLA_Q_CUH_ +#define CONCAT_MLA_Q_CUH_ - __forceinline__ __device__ vec8 ld_cs_v8(const vec8* addr) { - vec8 val; - asm volatile( - "ld.global.cs.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];" - : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]), - "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7]) - : "l"(addr)); - return val; - } - - __forceinline__ __device__ void st_cs_v8(vec8* addr, vec8 val) { - asm volatile( - "st.global.cs.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};" - ::"l"(addr), - "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]), "r"(val.d[3]), - "r"(val.d[4]), "r"(val.d[5]), "r"(val.d[6]), "r"(val.d[7])); - } - - __forceinline__ __device__ int ld_cs_v1(const int* addr) { - int val; - asm volatile("ld.global.cs.b32 %0, [%1];" : "=r"(val) : "l"(addr)); - return val; - } - - __forceinline__ __device__ void st_cs_v1(int* addr, int val) { - asm volatile("st.global.cs.b32 [%0], %1;" ::"l"(addr), "r"(val)); - } - - template - __global__ void ConcatMLAQKernel( - DType* __restrict__ q_out, - const DType* __restrict__ ql_nope, - const DType* __restrict__ q_pe, - const int num_tokens, - const int num_heads, - const int nope_dim, - const int64_t out_stride_0, - const int64_t out_stride_1, - const int64_t nope_stride_0, - const int64_t nope_stride_1, - const int64_t pe_stride_0, - const int64_t pe_stride_1) { - const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) >> 5; - if (flat_warp_id >= num_tokens * num_heads) return; - - const int token_id = flat_warp_id / num_heads; - const int head_id = flat_warp_id % num_heads; - const int lane_id = threadIdx.x & 31; - - const vec8* nope_src = reinterpret_cast( - ql_nope + token_id * nope_stride_0 + head_id * nope_stride_1); - vec8* nope_dst = reinterpret_cast( - q_out + token_id * out_stride_0 + head_id * out_stride_1); - - #pragma unroll - for (int i = 0; i < NOPE_V8_LOADS; i++) { - const int offset = i * 32 + lane_id; - st_cs_v8(nope_dst + offset, ld_cs_v8(nope_src + offset)); - } - - const int* rope_src = reinterpret_cast( - q_pe + token_id * pe_stride_0 + head_id * pe_stride_1); - int* rope_dst = reinterpret_cast( - q_out + token_id * out_stride_0 + head_id * out_stride_1 + nope_dim); - - st_cs_v1(rope_dst + lane_id, ld_cs_v1(rope_src + lane_id)); - } - - } // namespace vllm - - #endif // CONCAT_MLA_Q_CUH_ - \ No newline at end of file +#include +#include + +#include "cuda_vec_utils.cuh" + +namespace vllm { + +template +__global__ void ConcatMLAQKernel( + DType* __restrict__ q_out, + const DType* __restrict__ ql_nope, + const DType* __restrict__ q_pe, + const int num_tokens, + const int num_heads, + const int64_t out_stride_0, + const int64_t out_stride_1, + const int64_t nope_stride_0, + const int64_t nope_stride_1, + const int64_t pe_stride_0, + const int64_t pe_stride_1) { + const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) >> 5; + if (flat_warp_id >= num_tokens * num_heads) return; + + const int token_id = flat_warp_id / num_heads; + const int head_id = flat_warp_id % num_heads; + const int lane_id = threadIdx.x & 31; + + constexpr bool use_256b = VLLM_256B_PTX_ENABLED; + constexpr int nope_vec_loads = NOPE_DIM * sizeof(DType) / + (VecTraits::ARCH_MAX_VEC_SIZE * 32); + + const DType* nope_src = + ql_nope + token_id * nope_stride_0 + head_id * nope_stride_1; + DType* nope_dst = + q_out + token_id * out_stride_0 + head_id * out_stride_1; + +#pragma unroll + for (int i = 0; i < nope_vec_loads; i++) { + const int offset = i * 32 + lane_id; + if constexpr (use_256b) { + st256_cs(reinterpret_cast(nope_dst) + offset, + ld256_cs(reinterpret_cast(nope_src) + offset)); + } else { + st128_cs(reinterpret_cast(nope_dst) + offset, + ld128_cs(reinterpret_cast(nope_src) + offset)); + } + } + + const int* rope_src = reinterpret_cast( + q_pe + token_id * pe_stride_0 + head_id * pe_stride_1); + int* rope_dst = reinterpret_cast( + q_out + token_id * out_stride_0 + head_id * out_stride_1 + NOPE_DIM); + + st32_cs(rope_dst + lane_id, ld32_cs(rope_src + lane_id)); +} + +} // namespace vllm + +#endif // CONCAT_MLA_Q_CUH_ diff --git a/csrc/cuda_vec_utils.cuh b/csrc/cuda_vec_utils.cuh new file mode 100644 index 000000000000..0fd41efb1359 --- /dev/null +++ b/csrc/cuda_vec_utils.cuh @@ -0,0 +1,327 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +#pragma once + +#include +#include +#include +#include +#include + +// Device-side: SM100+ architecture with CUDA 12.9+ toolkit, which +// together enable 256-bit (v8.u32) PTX load/store instructions. +// Use for PTX instruction selection with architecture fallback paths. +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 && \ + defined(CUDA_VERSION) && CUDA_VERSION >= 12090 + #define VLLM_256B_PTX_ENABLED 1 +#else + #define VLLM_256B_PTX_ENABLED 0 +#endif + +#ifndef USE_ROCM + +namespace vllm { + +// ============================================================ +// Types and traits +// ============================================================ + +// 256-bit (32-byte) aligned vector type: 8 x uint32_t +struct alignas(32) u32x8_t { + uint32_t d[8]; +}; + +// VecTraits — select between 128-bit (int4) and 256-bit +// (u32x8_t) vector types at compile time. +template +struct VecTraits; + +template <> +struct VecTraits { + static constexpr int ARCH_MAX_VEC_SIZE = 32; + using vec_t = u32x8_t; +}; + +template <> +struct VecTraits { + static constexpr int ARCH_MAX_VEC_SIZE = 16; + using vec_t = int4; +}; + +// TypeConverter — map between CUDA scalar and packed types +// half <-> half2, __nv_bfloat16 <-> __nv_bfloat162, etc. +template +struct TypeConverter { + using Type = half2; +}; + +template <> +struct TypeConverter { + using Type = half; +}; + +template <> +struct TypeConverter { + using Type = half2; +}; + +template <> +struct TypeConverter<__nv_bfloat162> { + using Type = __nv_bfloat16; +}; + +template <> +struct TypeConverter<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; + +template <> +struct TypeConverter { + using Type = float2; +}; + +template <> +struct TypeConverter { + using Type = float; +}; + +template <> +struct TypeConverter { + using Type = half2; +}; + +template <> +struct TypeConverter { + using Type = __nv_bfloat162; +}; + +// CUDATypeConverter — map PyTorch scalar types to CUDA scalar +// c10::Half -> half, c10::BFloat16 -> __nv_bfloat16 +template +struct CUDATypeConverter { + using Type = T; +}; + +template <> +struct CUDATypeConverter { + using Type = half; +}; + +template <> +struct CUDATypeConverter { + using Type = __nv_bfloat16; +}; + +// PackedVec — typed vector container for packed element access. +// Derives alignment and element count from VecTraits. +// Type is the CUDA scalar type (e.g. half, __nv_bfloat16). +template +struct alignas(VecTraits::ARCH_MAX_VEC_SIZE) PackedVec { + static constexpr int NUM_ELTS = VecTraits::ARCH_MAX_VEC_SIZE / + sizeof(typename TypeConverter::Type); + typename TypeConverter::Type elts[NUM_ELTS]; +}; + +// ============================================================ +// Load / store primitives +// ============================================================ + +// 256-bit load / store with architecture fallback. +// SM100+ : PTX v8 instructions (.nc / default hint) +// Older : two uint4 loads via __ldg +__device__ __forceinline__ void ld256(u32x8_t& val, const u32x8_t* ptr) { + #if VLLM_256B_PTX_ENABLED + asm volatile("ld.global.nc.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];\n" + : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]), + "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7]) + : "l"(ptr)); + #else + const uint4* src = reinterpret_cast(ptr); + uint4* dst = reinterpret_cast(val.d); + dst[0] = __ldg(&src[0]); + dst[1] = __ldg(&src[1]); + #endif +} + +__device__ __forceinline__ void st256(u32x8_t& val, u32x8_t* ptr) { + #if VLLM_256B_PTX_ENABLED + asm volatile("st.global.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};\n" + : + : "l"(ptr), "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]), + "r"(val.d[3]), "r"(val.d[4]), "r"(val.d[5]), "r"(val.d[6]), + "r"(val.d[7]) + : "memory"); + #else + uint4* dst = reinterpret_cast(ptr); + const uint4* src = reinterpret_cast(val.d); + dst[0] = src[0]; + dst[1] = src[1]; + #endif +} + +// Generic ld256 / st256 for any 32-byte aligned type (e.g. PackedVec). +// Non-template overloads above are preferred for u32x8_t. +template +__device__ __forceinline__ void ld256(T& val, const T* ptr) { + static_assert(sizeof(T) == 32, "ld256 requires a 32-byte type"); + ld256(reinterpret_cast(val), reinterpret_cast(ptr)); +} + +template +__device__ __forceinline__ void st256(T& val, T* ptr) { + static_assert(sizeof(T) == 32, "st256 requires a 32-byte type"); + st256(reinterpret_cast(val), reinterpret_cast(ptr)); +} + +// 128-bit load / store via __ldg (read-only cache hint). +template +__device__ __forceinline__ void ld128(T& val, const T* ptr) { + static_assert(sizeof(T) == 16, "ld128 requires a 16-byte type"); + *reinterpret_cast(&val) = __ldg(reinterpret_cast(ptr)); +} + +template +__device__ __forceinline__ void st128(T& val, T* ptr) { + static_assert(sizeof(T) == 16, "st128 requires a 16-byte type"); + *reinterpret_cast(ptr) = *reinterpret_cast(&val); +} + +// 256-bit cache-streaming (.cs) load / store — SM100+ only. +__forceinline__ __device__ u32x8_t ld256_cs(const u32x8_t* addr) { + u32x8_t val; + asm volatile("ld.global.cs.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];" + : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]), + "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7]) + : "l"(addr)); + return val; +} + +__forceinline__ __device__ void st256_cs(u32x8_t* addr, u32x8_t val) { + asm volatile( + "st.global.cs.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};" ::"l"(addr), + "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]), "r"(val.d[3]), "r"(val.d[4]), + "r"(val.d[5]), "r"(val.d[6]), "r"(val.d[7])); +} + +// 32-bit cache-streaming (.cs) load / store — SM100+ only. +__forceinline__ __device__ int ld32_cs(const int* addr) { + int val; + asm volatile("ld.global.cs.b32 %0, [%1];" : "=r"(val) : "l"(addr)); + return val; +} + +__forceinline__ __device__ void st32_cs(int* addr, int val) { + asm volatile("st.global.cs.b32 [%0], %1;" ::"l"(addr), "r"(val)); +} + +// 128-bit cache-streaming (.cs) load / store. +__forceinline__ __device__ int4 ld128_cs(const int4* addr) { + int4 val; + asm volatile("ld.global.cs.v4.u32 {%0,%1,%2,%3}, [%4];" + : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) + : "l"(addr)); + return val; +} + +__forceinline__ __device__ void st128_cs(int4* addr, int4 val) { + asm volatile( + "st.global.cs.v4.u32 [%0], {%1,%2,%3,%4};" ::"l"(addr), + "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w)); +} + +// Predicated 256-bit / 128-bit cache-global (.cg) loads. +// Returns zero if pred is false. SM100+ only. +__device__ __forceinline__ void ld256_cg_or_zero(u32x8_t& val, const void* ptr, + bool pred) { + asm volatile( + "{\n" + " .reg .pred pr;\n" + " setp.ne.u32 pr, %8, 0;\n" + " mov.u32 %0, 0;\n" + " mov.u32 %1, 0;\n" + " mov.u32 %2, 0;\n" + " mov.u32 %3, 0;\n" + " mov.u32 %4, 0;\n" + " mov.u32 %5, 0;\n" + " mov.u32 %6, 0;\n" + " mov.u32 %7, 0;\n" + " @pr ld.global.cg.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%9];\n" + "}\n" + : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]), + "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7]) + : "r"((int)pred), "l"(ptr)); +} + +__device__ __forceinline__ void ld128_cg_or_zero(uint4& val, const void* ptr, + bool pred) { + uint32_t r0, r1, r2, r3; + + asm volatile( + "{\n" + " .reg .pred pr;\n" + " setp.ne.u32 pr, %4, 0;\n" + " mov.u32 %0, 0;\n" + " mov.u32 %1, 0;\n" + " mov.u32 %2, 0;\n" + " mov.u32 %3, 0;\n" + " @pr ld.global.cg.v4.u32 {%0,%1,%2,%3}, [%5];\n" + "}\n" + : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) + : "r"((int)pred), "l"(ptr)); + + val = uint4{r0, r1, r2, r3}; +} + +// ============================================================ +// Alignment helpers +// ============================================================ + +__host__ __device__ __forceinline__ bool is_16byte_aligned(const void* ptr) { + return (reinterpret_cast(ptr) & 15) == 0; +} + +__host__ __device__ __forceinline__ bool is_32byte_aligned(const void* ptr) { + return (reinterpret_cast(ptr) & 31) == 0; +} + +// ============================================================ +// Packed type conversion and arithmetic +// ============================================================ + +template +__device__ __forceinline__ float2 cast_to_float2(const packed_t& val) { + if constexpr (std::is_same_v) { + return __bfloat1622float2(val); + } else if constexpr (std::is_same_v) { + return __half22float2(val); + } else if constexpr (std::is_same_v) { + return float2(val); + } +} + +template +__device__ __forceinline__ packed_t cast_to_packed(const float2& val) { + if constexpr (std::is_same_v) { + return __float22bfloat162_rn(val); + } else if constexpr (std::is_same_v) { + return __float22half2_rn(val); + } else if constexpr (std::is_same_v) { + return float2(val); + } +} + +template +__device__ __forceinline__ packed_t packed_mul(const packed_t& x, + const packed_t& y) { + if constexpr (std::is_same_v || + std::is_same_v) { + return __hmul2(x, y); + } else if constexpr (std::is_same_v) { + return make_float2(x.x * y.x, x.y * y.y); + } +} + +} // namespace vllm + +#endif // !USE_ROCM From 5a8fe25036a1c2fb16243b07726b7e8a34587f36 Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Wed, 25 Feb 2026 10:46:13 +0000 Subject: [PATCH 04/11] adding benchmark script Signed-off-by: LopezCastroRoberto --- benchmarks/kernels/bench_concat_mla_q.py | 99 ++++++++++++++++++++++++ csrc/concat_mla_q.cuh | 5 ++ csrc/cuda_vec_utils.cuh | 10 --- 3 files changed, 104 insertions(+), 10 deletions(-) create mode 100644 benchmarks/kernels/bench_concat_mla_q.py diff --git a/benchmarks/kernels/bench_concat_mla_q.py b/benchmarks/kernels/bench_concat_mla_q.py new file mode 100644 index 000000000000..5149791b0ea5 --- /dev/null +++ b/benchmarks/kernels/bench_concat_mla_q.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import argparse +import itertools + +import torch + +from vllm import _custom_ops as ops +from vllm.triton_utils import triton + +# DeepSeek V3 dimensions +NOPE_DIM = 512 +ROPE_DIM = 64 +NUM_HEADS = 128 + +NUM_TOKENS = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] + + +def get_configs(): + return NUM_TOKENS + + +def make_inputs(num_tokens, dtype): + """Create inputs matching the real code path. + + Args: + contiguous_nope: If False, simulate the transposed BMM output + (non-contiguous nope with stride pattern from + [N,B,L].transpose(0,1)). + """ + # Simulate: bmm output [N, B, L].transpose(0, 1) -> [B, N, L] + raw = torch.randn(NUM_HEADS, num_tokens, NOPE_DIM, dtype=dtype, + device="cuda") + ql_nope = raw.transpose(0, 1) + + q_pe = torch.randn(num_tokens, NUM_HEADS, ROPE_DIM, dtype=dtype, + device="cuda") + return ql_nope, q_pe + + +# ---- Non-contiguous nope benchmark (real code path) ---- +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["num_tokens"], + x_vals=get_configs(), + line_arg="provider", + line_vals=["torch_cat", "concat_mla_q"], + line_names=["torch.cat", "concat_mla_q (v8)"], + styles=[("blue", "--"), ("green", "-")], + ylabel="Latency (us)", + plot_name="concat_mla_q-transposed", + args={}, + ) +) +def bench_transposed(num_tokens, provider): + dtype = torch.bfloat16 + ql_nope, q_pe = make_inputs(num_tokens, dtype) + + q_out = torch.empty(num_tokens, NUM_HEADS, NOPE_DIM + ROPE_DIM, + dtype=dtype, device="cuda") + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch_cat": + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: torch.cat((ql_nope, q_pe), dim=-1), + quantiles=quantiles, rep=500) + else: + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: ops.concat_mla_q(ql_nope, q_pe, q_out), + quantiles=quantiles, rep=500) + + return ms * 1000, max_ms * 1000, min_ms * 1000 # us + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Benchmark concat_mla_q vs torch.cat" + ) + parser.add_argument("--save-path", type=str, default=None, + help="Path to save benchmark results") + args = parser.parse_args() + + print("\n" + "=" * 70) + print("CONCAT MLA Q KERNEL BENCHMARKS") + print("=" * 70) + print(f"Dimensions: nope={NOPE_DIM}, rope={ROPE_DIM}, heads={NUM_HEADS}") + print(f"Per-head output: {NOPE_DIM + ROPE_DIM} bf16 = " + f"{(NOPE_DIM + ROPE_DIM) * 2} bytes") + print(f"num_tokens (decode=batch_size, prefill=chunk_size): {NUM_TOKENS}") + print("=" * 70) + + print("\n--- Non-contiguous nope inputs (transposed BMM output) ---") + bench_transposed.run(print_data=True, save_path=args.save_path) + + print("\n" + "=" * 70) + print("Benchmarking complete!") + print("=" * 70) diff --git a/csrc/concat_mla_q.cuh b/csrc/concat_mla_q.cuh index dc26662bfe62..bd3d36917ed8 100644 --- a/csrc/concat_mla_q.cuh +++ b/csrc/concat_mla_q.cuh @@ -8,6 +8,11 @@ namespace vllm { +// Concatenates ql_nope [num_tokens, num_heads, NOPE_DIM] and +// q_pe [num_tokens, num_heads, 64] +// into q_out [num_tokens, num_heads, NOPE_DIM+64]. +// Currently instantiated only for NOPE_DIM=512. +// Rope dim is hardcoded to 64 (DeepSeek V3.2 MLA) template __global__ void ConcatMLAQKernel( DType* __restrict__ q_out, diff --git a/csrc/cuda_vec_utils.cuh b/csrc/cuda_vec_utils.cuh index 0fd41efb1359..4625ed8c8350 100644 --- a/csrc/cuda_vec_utils.cuh +++ b/csrc/cuda_vec_utils.cuh @@ -136,11 +136,6 @@ __device__ __forceinline__ void ld256(u32x8_t& val, const u32x8_t* ptr) { : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]), "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7]) : "l"(ptr)); - #else - const uint4* src = reinterpret_cast(ptr); - uint4* dst = reinterpret_cast(val.d); - dst[0] = __ldg(&src[0]); - dst[1] = __ldg(&src[1]); #endif } @@ -152,11 +147,6 @@ __device__ __forceinline__ void st256(u32x8_t& val, u32x8_t* ptr) { "r"(val.d[3]), "r"(val.d[4]), "r"(val.d[5]), "r"(val.d[6]), "r"(val.d[7]) : "memory"); - #else - uint4* dst = reinterpret_cast(ptr); - const uint4* src = reinterpret_cast(val.d); - dst[0] = src[0]; - dst[1] = src[1]; #endif } From 9b5c159b2b52e676f0f302e294d34fb1e990a2e5 Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Wed, 25 Feb 2026 10:48:55 +0000 Subject: [PATCH 05/11] adding benchmark script Signed-off-by: LopezCastroRoberto --- benchmarks/kernels/bench_concat_mla_q.py | 33 ++++++------ csrc/cache.h | 7 +-- csrc/cache_kernels.cu | 32 +++++------ csrc/concat_mla_q.cuh | 31 +++++------ csrc/cuda_vec_utils.cuh | 5 +- csrc/torch_bindings.cpp | 2 +- tests/kernels/test_concat_mla_q.py | 54 +++++++++---------- vllm/_custom_ops.py | 2 + .../attention/backends/mla/flashmla_sparse.py | 8 +-- 9 files changed, 81 insertions(+), 93 deletions(-) diff --git a/benchmarks/kernels/bench_concat_mla_q.py b/benchmarks/kernels/bench_concat_mla_q.py index 5149791b0ea5..8d940484d6b3 100644 --- a/benchmarks/kernels/bench_concat_mla_q.py +++ b/benchmarks/kernels/bench_concat_mla_q.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse -import itertools import torch @@ -30,12 +29,10 @@ def make_inputs(num_tokens, dtype): [N,B,L].transpose(0,1)). """ # Simulate: bmm output [N, B, L].transpose(0, 1) -> [B, N, L] - raw = torch.randn(NUM_HEADS, num_tokens, NOPE_DIM, dtype=dtype, - device="cuda") + raw = torch.randn(NUM_HEADS, num_tokens, NOPE_DIM, dtype=dtype, device="cuda") ql_nope = raw.transpose(0, 1) - q_pe = torch.randn(num_tokens, NUM_HEADS, ROPE_DIM, dtype=dtype, - device="cuda") + q_pe = torch.randn(num_tokens, NUM_HEADS, ROPE_DIM, dtype=dtype, device="cuda") return ql_nope, q_pe @@ -57,37 +54,39 @@ def bench_transposed(num_tokens, provider): dtype = torch.bfloat16 ql_nope, q_pe = make_inputs(num_tokens, dtype) - q_out = torch.empty(num_tokens, NUM_HEADS, NOPE_DIM + ROPE_DIM, - dtype=dtype, device="cuda") + q_out = torch.empty( + num_tokens, NUM_HEADS, NOPE_DIM + ROPE_DIM, dtype=dtype, device="cuda" + ) quantiles = [0.5, 0.2, 0.8] if provider == "torch_cat": ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( - lambda: torch.cat((ql_nope, q_pe), dim=-1), - quantiles=quantiles, rep=500) + lambda: torch.cat((ql_nope, q_pe), dim=-1), quantiles=quantiles, rep=500 + ) else: ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( - lambda: ops.concat_mla_q(ql_nope, q_pe, q_out), - quantiles=quantiles, rep=500) + lambda: ops.concat_mla_q(ql_nope, q_pe, q_out), quantiles=quantiles, rep=500 + ) return ms * 1000, max_ms * 1000, min_ms * 1000 # us if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Benchmark concat_mla_q vs torch.cat" + parser = argparse.ArgumentParser(description="Benchmark concat_mla_q vs torch.cat") + parser.add_argument( + "--save-path", type=str, default=None, help="Path to save benchmark results" ) - parser.add_argument("--save-path", type=str, default=None, - help="Path to save benchmark results") args = parser.parse_args() print("\n" + "=" * 70) print("CONCAT MLA Q KERNEL BENCHMARKS") print("=" * 70) print(f"Dimensions: nope={NOPE_DIM}, rope={ROPE_DIM}, heads={NUM_HEADS}") - print(f"Per-head output: {NOPE_DIM + ROPE_DIM} bf16 = " - f"{(NOPE_DIM + ROPE_DIM) * 2} bytes") + print( + f"Per-head output: {NOPE_DIM + ROPE_DIM} bf16 = " + f"{(NOPE_DIM + ROPE_DIM) * 2} bytes" + ) print(f"num_tokens (decode=batch_size, prefill=chunk_size): {NUM_TOKENS}") print("=" * 70) diff --git a/csrc/cache.h b/csrc/cache.h index e2c0df3ebcdd..0188a568edc7 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -75,9 +75,10 @@ void indexer_k_quant_and_cache( const std::string& scale_fmt); // Concatenate query nope and rope for MLA/DSA attention -void concat_mla_q(torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim] - torch::Tensor& q_pe, // [num_tokens, num_heads, rope_dim] - torch::Tensor& q_out); // [num_tokens, num_heads, nope_dim + rope_dim] +void concat_mla_q( + torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim] + torch::Tensor& q_pe, // [num_tokens, num_heads, rope_dim] + torch::Tensor& q_out); // [num_tokens, num_heads, nope_dim + rope_dim] // Extract function to gather quantized K cache void cp_gather_indexer_k_quant_cache( diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index d7d4716b3e4e..acfc1b5d7676 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1370,17 +1370,17 @@ void cp_gather_indexer_k_quant_cache( // Concatenate ql_nope and q_pe into a contiguous q_out tensor for MLA/DSA. // Replaces torch.cat((ql_nope, q_pe), dim=-1). void concat_mla_q(torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim] - torch::Tensor& q_pe, // [num_tokens, num_heads, rope_dim] - torch::Tensor& q_out // [num_tokens, num_heads, nope_dim + - // rope_dim] + torch::Tensor& q_pe, // [num_tokens, num_heads, rope_dim] + torch::Tensor& q_out // [num_tokens, num_heads, nope_dim + + // rope_dim] ) { const int num_tokens = ql_nope.size(0); const int num_heads = ql_nope.size(1); const int nope_dim = ql_nope.size(2); const int rope_dim = q_pe.size(2); - TORCH_CHECK(nope_dim % 512 == 0, - "nope_dim must be a multiple of 512, got ", nope_dim); + TORCH_CHECK(nope_dim % 512 == 0, "nope_dim must be a multiple of 512, got ", + nope_dim); TORCH_CHECK(rope_dim == 64, "rope_dim must be 64, got ", rope_dim); TORCH_CHECK(q_out.size(2) == nope_dim + rope_dim); @@ -1392,23 +1392,17 @@ void concat_mla_q(torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim] constexpr int warps_per_block = 32; const int total_warps = num_tokens * num_heads; - const int grid_size = - (total_warps + warps_per_block - 1) / warps_per_block; + const int grid_size = (total_warps + warps_per_block - 1) / warps_per_block; const int block_size = warps_per_block * 32; const at::cuda::OptionalCUDAGuard device_guard(device_of(ql_nope)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_FLOATING_TYPES( - ql_nope.scalar_type(), "concat_mla_q", [&] { - vllm::ConcatMLAQKernel - <<>>( - q_out.data_ptr(), - ql_nope.data_ptr(), - q_pe.data_ptr(), - num_tokens, num_heads, - q_out.stride(0), q_out.stride(1), - ql_nope.stride(0), ql_nope.stride(1), - q_pe.stride(0), q_pe.stride(1)); - }); + VLLM_DISPATCH_FLOATING_TYPES(ql_nope.scalar_type(), "concat_mla_q", [&] { + vllm::ConcatMLAQKernel<<>>( + q_out.data_ptr(), ql_nope.data_ptr(), + q_pe.data_ptr(), num_tokens, num_heads, q_out.stride(0), + q_out.stride(1), ql_nope.stride(0), ql_nope.stride(1), q_pe.stride(0), + q_pe.stride(1)); + }); } diff --git a/csrc/concat_mla_q.cuh b/csrc/concat_mla_q.cuh index bd3d36917ed8..68bcfa011fb3 100644 --- a/csrc/concat_mla_q.cuh +++ b/csrc/concat_mla_q.cuh @@ -9,23 +9,17 @@ namespace vllm { // Concatenates ql_nope [num_tokens, num_heads, NOPE_DIM] and -// q_pe [num_tokens, num_heads, 64] +// q_pe [num_tokens, num_heads, 64] // into q_out [num_tokens, num_heads, NOPE_DIM+64]. -// Currently instantiated only for NOPE_DIM=512. +// Currently instantiated only for NOPE_DIM=512. // Rope dim is hardcoded to 64 (DeepSeek V3.2 MLA) template __global__ void ConcatMLAQKernel( - DType* __restrict__ q_out, - const DType* __restrict__ ql_nope, - const DType* __restrict__ q_pe, - const int num_tokens, - const int num_heads, - const int64_t out_stride_0, - const int64_t out_stride_1, - const int64_t nope_stride_0, - const int64_t nope_stride_1, - const int64_t pe_stride_0, - const int64_t pe_stride_1) { + DType* __restrict__ q_out, const DType* __restrict__ ql_nope, + const DType* __restrict__ q_pe, const int num_tokens, const int num_heads, + const int64_t out_stride_0, const int64_t out_stride_1, + const int64_t nope_stride_0, const int64_t nope_stride_1, + const int64_t pe_stride_0, const int64_t pe_stride_1) { const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) >> 5; if (flat_warp_id >= num_tokens * num_heads) return; @@ -34,13 +28,12 @@ __global__ void ConcatMLAQKernel( const int lane_id = threadIdx.x & 31; constexpr bool use_256b = VLLM_256B_PTX_ENABLED; - constexpr int nope_vec_loads = NOPE_DIM * sizeof(DType) / - (VecTraits::ARCH_MAX_VEC_SIZE * 32); + constexpr int nope_vec_loads = + NOPE_DIM * sizeof(DType) / (VecTraits::ARCH_MAX_VEC_SIZE * 32); const DType* nope_src = ql_nope + token_id * nope_stride_0 + head_id * nope_stride_1; - DType* nope_dst = - q_out + token_id * out_stride_0 + head_id * out_stride_1; + DType* nope_dst = q_out + token_id * out_stride_0 + head_id * out_stride_1; #pragma unroll for (int i = 0; i < nope_vec_loads; i++) { @@ -56,8 +49,8 @@ __global__ void ConcatMLAQKernel( const int* rope_src = reinterpret_cast( q_pe + token_id * pe_stride_0 + head_id * pe_stride_1); - int* rope_dst = reinterpret_cast( - q_out + token_id * out_stride_0 + head_id * out_stride_1 + NOPE_DIM); + int* rope_dst = reinterpret_cast(q_out + token_id * out_stride_0 + + head_id * out_stride_1 + NOPE_DIM); st32_cs(rope_dst + lane_id, ld32_cs(rope_src + lane_id)); } diff --git a/csrc/cuda_vec_utils.cuh b/csrc/cuda_vec_utils.cuh index 4625ed8c8350..d19e5997be6d 100644 --- a/csrc/cuda_vec_utils.cuh +++ b/csrc/cuda_vec_utils.cuh @@ -215,9 +215,8 @@ __forceinline__ __device__ int4 ld128_cs(const int4* addr) { } __forceinline__ __device__ void st128_cs(int4* addr, int4 val) { - asm volatile( - "st.global.cs.v4.u32 [%0], {%1,%2,%3,%4};" ::"l"(addr), - "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w)); + asm volatile("st.global.cs.v4.u32 [%0], {%1,%2,%3,%4};" ::"l"(addr), + "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w)); } // Predicated 256-bit / 128-bit cache-global (.cg) loads. diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 15a9e067207b..f21f70658a1f 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -780,7 +780,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { &indexer_k_quant_and_cache); cache_ops.def( - "concat_mla_q(Tensor ql_nope, Tensor q_pe, Tensor! q_out) -> ()"); + "concat_mla_q(Tensor ql_nope, Tensor q_pe, Tensor! q_out) -> ()"); cache_ops.impl("concat_mla_q", torch::kCUDA, &concat_mla_q); cache_ops.def( diff --git a/tests/kernels/test_concat_mla_q.py b/tests/kernels/test_concat_mla_q.py index 4863250b1969..fec5c063c7ca 100644 --- a/tests/kernels/test_concat_mla_q.py +++ b/tests/kernels/test_concat_mla_q.py @@ -18,19 +18,17 @@ @pytest.mark.parametrize("nope_dim", NOPE_DIM) @pytest.mark.parametrize("rope_dim", ROPE_DIM) @pytest.mark.parametrize("dtype", DTYPES) -def test_concat_mla_q_contiguous(num_tokens, num_heads, nope_dim, rope_dim, - dtype): +def test_concat_mla_q_contiguous(num_tokens, num_heads, nope_dim, rope_dim, dtype): """Test with contiguous inputs (standard layout).""" torch.manual_seed(42) - ql_nope = torch.randn(num_tokens, num_heads, nope_dim, dtype=dtype, - device="cuda") - q_pe = torch.randn(num_tokens, num_heads, rope_dim, dtype=dtype, - device="cuda") + ql_nope = torch.randn(num_tokens, num_heads, nope_dim, dtype=dtype, device="cuda") + q_pe = torch.randn(num_tokens, num_heads, rope_dim, dtype=dtype, device="cuda") ref = torch.cat((ql_nope, q_pe), dim=-1) - q_out = torch.empty(num_tokens, num_heads, nope_dim + rope_dim, - dtype=dtype, device="cuda") + q_out = torch.empty( + num_tokens, num_heads, nope_dim + rope_dim, dtype=dtype, device="cuda" + ) ops.concat_mla_q(ql_nope, q_pe, q_out) torch.testing.assert_close(q_out, ref, atol=0, rtol=0) @@ -41,8 +39,7 @@ def test_concat_mla_q_contiguous(num_tokens, num_heads, nope_dim, rope_dim, @pytest.mark.parametrize("nope_dim", NOPE_DIM) @pytest.mark.parametrize("rope_dim", ROPE_DIM) @pytest.mark.parametrize("dtype", DTYPES) -def test_concat_mla_q_transposed_nope(num_tokens, num_heads, nope_dim, - rope_dim, dtype): +def test_concat_mla_q_transposed_nope(num_tokens, num_heads, nope_dim, rope_dim, dtype): """Test with transposed nope input (simulates BMM output after transpose). In the real code path, mqa_ql_nope is the result of: @@ -50,18 +47,17 @@ def test_concat_mla_q_transposed_nope(num_tokens, num_heads, nope_dim, .transpose(0, 1) # [B, N, L] — non-contiguous! """ torch.manual_seed(42) - nope_raw = torch.randn(num_heads, num_tokens, nope_dim, dtype=dtype, - device="cuda") + nope_raw = torch.randn(num_heads, num_tokens, nope_dim, dtype=dtype, device="cuda") ql_nope = nope_raw.transpose(0, 1) # [B, N, L], non-contiguous assert not ql_nope.is_contiguous() - q_pe = torch.randn(num_tokens, num_heads, rope_dim, dtype=dtype, - device="cuda") + q_pe = torch.randn(num_tokens, num_heads, rope_dim, dtype=dtype, device="cuda") ref = torch.cat((ql_nope, q_pe), dim=-1) - q_out = torch.empty(num_tokens, num_heads, nope_dim + rope_dim, - dtype=dtype, device="cuda") + q_out = torch.empty( + num_tokens, num_heads, nope_dim + rope_dim, dtype=dtype, device="cuda" + ) ops.concat_mla_q(ql_nope, q_pe, q_out) torch.testing.assert_close(q_out, ref, atol=0, rtol=0) @@ -83,8 +79,7 @@ def test_concat_mla_q_split_rope(num_tokens, num_heads, dtype): orig_dim = 128 + 64 # original q before absorption: [B, N, 192] # Simulate split from original q tensor - q_orig = torch.randn(num_tokens, num_heads, orig_dim, dtype=dtype, - device="cuda") + q_orig = torch.randn(num_tokens, num_heads, orig_dim, dtype=dtype, device="cuda") q_nope_orig, q_pe = q_orig.split([128, 64], dim=-1) # q_pe is non-contiguous: stride(1) = 192, not 64 @@ -92,13 +87,13 @@ def test_concat_mla_q_split_rope(num_tokens, num_heads, dtype): assert q_pe.stride(2) == 1 # but innermost is fine # Simulate absorbed nope (contiguous, different size) - ql_nope = torch.randn(num_tokens, num_heads, nope_dim, dtype=dtype, - device="cuda") + ql_nope = torch.randn(num_tokens, num_heads, nope_dim, dtype=dtype, device="cuda") ref = torch.cat((ql_nope, q_pe), dim=-1) - q_out = torch.empty(num_tokens, num_heads, nope_dim + rope_dim, - dtype=dtype, device="cuda") + q_out = torch.empty( + num_tokens, num_heads, nope_dim + rope_dim, dtype=dtype, device="cuda" + ) ops.concat_mla_q(ql_nope, q_pe, q_out) torch.testing.assert_close(q_out, ref, atol=0, rtol=0) @@ -122,16 +117,19 @@ def test_concat_mla_q_values_preserved(num_tokens): nope_dim, rope_dim = 512, 64 # Use specific bit patterns (stay in int16 for bit-exact comparison) - ql_nope_bits = torch.arange(num_tokens * 128 * nope_dim, dtype=torch.int16, - device="cuda").view(num_tokens, 128, nope_dim) - q_pe_bits = torch.arange(num_tokens * 128 * rope_dim, dtype=torch.int16, - device="cuda").view(num_tokens, 128, rope_dim) + ql_nope_bits = torch.arange( + num_tokens * 128 * nope_dim, dtype=torch.int16, device="cuda" + ).view(num_tokens, 128, nope_dim) + q_pe_bits = torch.arange( + num_tokens * 128 * rope_dim, dtype=torch.int16, device="cuda" + ).view(num_tokens, 128, rope_dim) ql_nope = ql_nope_bits.view(torch.bfloat16) q_pe = q_pe_bits.view(torch.bfloat16) - q_out = torch.empty(num_tokens, 128, nope_dim + rope_dim, - dtype=torch.bfloat16, device="cuda") + q_out = torch.empty( + num_tokens, 128, nope_dim + rope_dim, dtype=torch.bfloat16, device="cuda" + ) ops.concat_mla_q(ql_nope, q_pe, q_out) out_bits = q_out.view(torch.int16) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 808a9e28df23..0437d2fd5fa6 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2567,6 +2567,7 @@ def cp_gather_and_upconvert_fp8_kv_cache( src_cache, dst, block_table, seq_lens, workspace_starts, batch_size ) + def concat_mla_q( ql_nope: torch.Tensor, q_pe: torch.Tensor, @@ -2581,6 +2582,7 @@ def concat_mla_q( """ torch.ops._C_cache_ops.concat_mla_q(ql_nope, q_pe, q_out) + def indexer_k_quant_and_cache( k: torch.Tensor, kv_cache: torch.Tensor, diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index a30630b9fc07..e8be7faa6ea9 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -570,13 +570,15 @@ def __init__( vllm_config = get_current_vllm_config() max_tokens = vllm_config.scheduler_config.max_num_batched_tokens self.q_concat_buffer = torch.empty( - max_tokens, num_heads, head_size, + max_tokens, + num_heads, + head_size, dtype=torch.bfloat16, device=current_platform.device_type, ) if kv_cache_dtype == "fp8_ds_mla": - # Reserve workspace during initialization + # Reserve workspace during initialization assert vllm_config is not None and vllm_config.model_config is not None prefill_workspace_size = get_prefill_workspace_size( vllm_config.model_config.max_model_len @@ -829,7 +831,7 @@ def forward_mqa( # Concatenate q if it's a tuple (ql_nope, q_pe) if isinstance(q, tuple): ql_nope, q_pe = q - q = self.q_concat_buffer[:ql_nope.shape[0]] + q = self.q_concat_buffer[: ql_nope.shape[0]] ops.concat_mla_q(ql_nope, q_pe, q) num_actual_toks = q.shape[0] From 71d00a2c1e589b781040981d88d794ae42e7fe82 Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Wed, 25 Feb 2026 12:21:58 +0000 Subject: [PATCH 06/11] tune threadblock size Signed-off-by: LopezCastroRoberto --- csrc/cache_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index acfc1b5d7676..71050cfcad1a 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1390,7 +1390,7 @@ void concat_mla_q(torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim] if (num_tokens == 0) return; - constexpr int warps_per_block = 32; + constexpr int warps_per_block = 8; const int total_warps = num_tokens * num_heads; const int grid_size = (total_warps + warps_per_block - 1) / warps_per_block; const int block_size = warps_per_block * 32; From 8c0e7908e1756dfe683342821a3f57b53eb7b566 Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Wed, 4 Mar 2026 15:13:01 +0000 Subject: [PATCH 07/11] overlap with MoE workspace Signed-off-by: LopezCastroRoberto --- .../attention/backends/mla/flashmla_sparse.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py index 2e5b0ad0264b..2ef42d34c162 100644 --- a/vllm/v1/attention/backends/mla/flashmla_sparse.py +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -569,13 +569,7 @@ def __init__( vllm_config = get_current_vllm_config() max_tokens = vllm_config.scheduler_config.max_num_batched_tokens - self.q_concat_buffer = torch.empty( - max_tokens, - num_heads, - head_size, - dtype=torch.bfloat16, - device=current_platform.device_type, - ) + q_concat_shape = (max_tokens, num_heads, head_size) if kv_cache_dtype == "fp8_ds_mla": # Reserve workspace during initialization @@ -584,11 +578,16 @@ def __init__( vllm_config.model_config.max_model_len ) self.prefill_workspace_shape = (prefill_workspace_size, head_size) - (self.prefill_bf16_workspace,) = ( + self.q_concat_buffer, self.prefill_bf16_workspace = ( current_workspace_manager().get_simultaneous( - (self.prefill_workspace_shape, torch.bfloat16) + (q_concat_shape, torch.bfloat16), + (self.prefill_workspace_shape, torch.bfloat16), ) ) + else: + (self.q_concat_buffer,) = current_workspace_manager().get_simultaneous( + (q_concat_shape, torch.bfloat16), + ) def _forward_bf16_kv( self, From 36baa4c38ee79ccdb84e9cdf29b3c14a2235bffa Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Wed, 4 Mar 2026 16:17:26 +0000 Subject: [PATCH 08/11] add missing instructions to vec_utils Signed-off-by: LopezCastroRoberto --- csrc/cuda_vec_utils.cuh | 35 +++++++++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/csrc/cuda_vec_utils.cuh b/csrc/cuda_vec_utils.cuh index 82a19f10a70e..aae834290caf 100644 --- a/csrc/cuda_vec_utils.cuh +++ b/csrc/cuda_vec_utils.cuh @@ -213,21 +213,43 @@ __forceinline__ __device__ void st256_cs(u32x8_t* addr, u32x8_t val) { // 32-bit cache-streaming (.cs) load / store — SM100+ only. __forceinline__ __device__ int ld32_cs(const int* addr) { -#if VLLM_256B_PTX_ENABLED +#ifndef USE_ROCM int val; asm volatile("ld.global.cs.b32 %0, [%1];" : "=r"(val) : "l"(addr)); return val; #else - assert(false && "ld32_cs requires SM100+ with CUDA 12.9+"); + assert(false && "ld32_cs is not supported on ROCm"); return 0; #endif } __forceinline__ __device__ void st32_cs(int* addr, int val) { -#if VLLM_256B_PTX_ENABLED +#ifndef USE_ROCM asm volatile("st.global.cs.b32 [%0], %1;" ::"l"(addr), "r"(val)); #else - assert(false && "st32_cs requires SM100+ with CUDA 12.9+"); + assert(false && "st32_cs is not supported on ROCm"); +#endif +} + +__forceinline__ __device__ int4 ld128_cs(const int4* addr) { +#ifndef USE_ROCM + int4 val; + asm volatile("ld.global.cs.v4.u32 {%0,%1,%2,%3}, [%4];" + : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) + : "l"(addr)); + return val; +#else + assert(false && "ld128_cs is not supported on ROCm"); + return {}; +#endif +} + +__forceinline__ __device__ void st128_cs(int4* addr, int4 val) { +#ifndef USE_ROCM + asm volatile("st.global.cs.v4.u32 [%0], {%1,%2,%3,%4};" ::"l"(addr), + "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w)); +#else + assert(false && "st128_cs is not supported on ROCm"); #endif } @@ -260,7 +282,7 @@ __device__ __forceinline__ void ld256_cg_or_zero(u32x8_t& val, const void* ptr, __device__ __forceinline__ void ld128_cg_or_zero(uint4& val, const void* ptr, bool pred) { -#if VLLM_256B_PTX_ENABLED +#ifndef USE_ROCM uint32_t r0, r1, r2, r3; asm volatile( @@ -278,7 +300,8 @@ __device__ __forceinline__ void ld128_cg_or_zero(uint4& val, const void* ptr, val = uint4{r0, r1, r2, r3}; #else - assert(false && "ld128_cg_or_zero requires SM100+ with CUDA 12.9+"); + assert(false && "ld128_cg_or_zero is not supported on ROCm"); + return {}; #endif } From 75a5e3710bbdc4e751d10a7ebace6b31c548c8d6 Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Wed, 4 Mar 2026 16:38:47 +0000 Subject: [PATCH 09/11] remove return amd path Signed-off-by: LopezCastroRoberto --- csrc/cuda_vec_utils.cuh | 3 --- 1 file changed, 3 deletions(-) diff --git a/csrc/cuda_vec_utils.cuh b/csrc/cuda_vec_utils.cuh index aae834290caf..6d4adc8a4fff 100644 --- a/csrc/cuda_vec_utils.cuh +++ b/csrc/cuda_vec_utils.cuh @@ -196,7 +196,6 @@ __forceinline__ __device__ u32x8_t ld256_cs(const u32x8_t* addr) { return val; #else assert(false && "ld256_cs requires SM100+ with CUDA 12.9+"); - return {}; #endif } @@ -240,7 +239,6 @@ __forceinline__ __device__ int4 ld128_cs(const int4* addr) { return val; #else assert(false && "ld128_cs is not supported on ROCm"); - return {}; #endif } @@ -301,7 +299,6 @@ __device__ __forceinline__ void ld128_cg_or_zero(uint4& val, const void* ptr, val = uint4{r0, r1, r2, r3}; #else assert(false && "ld128_cg_or_zero is not supported on ROCm"); - return {}; #endif } From 6f8bb714fdd49df6ddf5f8ccdcad1ea093f50fa1 Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Mon, 9 Mar 2026 10:07:51 +0000 Subject: [PATCH 10/11] add AMD tests to CI Signed-off-by: LopezCastroRoberto --- .buildkite/test-amd.yaml | 6 ++++-- .buildkite/test_areas/kernels.yaml | 3 ++- csrc/cuda_vec_utils.cuh | 27 +++++++++++++++++---------- 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index ad11f3764f54..361a5e279102 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -687,8 +687,9 @@ steps: - csrc/ - tests/kernels/core - tests/kernels/test_top_k_per_row.py + - tests/kernels/test_concat_mla_q.py commands: - - pytest -v -s kernels/core kernels/test_top_k_per_row.py + - pytest -v -s kernels/core kernels/test_top_k_per_row.py kernels/test_concat_mla_q.py - label: Kernels Attention Test %N # 23min timeout_in_minutes: 35 @@ -2342,8 +2343,9 @@ steps: - csrc/ - tests/kernels/core - tests/kernels/test_top_k_per_row.py + - tests/kernels/test_concat_mla_q.py commands: - - pytest -v -s kernels/core kernels/test_top_k_per_row.py + - pytest -v -s kernels/core kernels/test_top_k_per_row.py kernels/test_concat_mla_q.py - label: Kernels Attention Test %N # 23min timeout_in_minutes: 35 diff --git a/.buildkite/test_areas/kernels.yaml b/.buildkite/test_areas/kernels.yaml index 9328cad4bf94..e0be49cf39c3 100644 --- a/.buildkite/test_areas/kernels.yaml +++ b/.buildkite/test_areas/kernels.yaml @@ -8,8 +8,9 @@ steps: - csrc/ - tests/kernels/core - tests/kernels/test_top_k_per_row.py + - tests/kernels/test_concat_mla_q.py commands: - - pytest -v -s kernels/core kernels/test_top_k_per_row.py + - pytest -v -s kernels/core kernels/test_top_k_per_row.py kernels/test_concat_mla_q.py - label: Kernels Attention Test %N timeout_in_minutes: 35 diff --git a/csrc/cuda_vec_utils.cuh b/csrc/cuda_vec_utils.cuh index 6d4adc8a4fff..8f997f3ba409 100644 --- a/csrc/cuda_vec_utils.cuh +++ b/csrc/cuda_vec_utils.cuh @@ -210,36 +210,43 @@ __forceinline__ __device__ void st256_cs(u32x8_t* addr, u32x8_t val) { #endif } -// 32-bit cache-streaming (.cs) load / store — SM100+ only. +// 32-bit load / store. +__device__ __forceinline__ int ld32(const int* addr) { return __ldg(addr); } + +__device__ __forceinline__ void st32(int* addr, int val) { *addr = val; } + +// 32-bit cache-streaming (.cs) load / store. +// Falls back to ld32/st32 on ROCm (no .cs hint). __forceinline__ __device__ int ld32_cs(const int* addr) { -#ifndef USE_ROCM int val; +#ifndef USE_ROCM asm volatile("ld.global.cs.b32 %0, [%1];" : "=r"(val) : "l"(addr)); - return val; #else - assert(false && "ld32_cs is not supported on ROCm"); - return 0; + val = ld32(addr); #endif + return val; } __forceinline__ __device__ void st32_cs(int* addr, int val) { #ifndef USE_ROCM asm volatile("st.global.cs.b32 [%0], %1;" ::"l"(addr), "r"(val)); #else - assert(false && "st32_cs is not supported on ROCm"); + st32(addr, val); #endif } +// 128-bit cache-streaming (.cs) load / store. +// Falls back to ld128/st128 on ROCm (no .cs hint). __forceinline__ __device__ int4 ld128_cs(const int4* addr) { -#ifndef USE_ROCM int4 val; +#ifndef USE_ROCM asm volatile("ld.global.cs.v4.u32 {%0,%1,%2,%3}, [%4];" : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w) : "l"(addr)); - return val; #else - assert(false && "ld128_cs is not supported on ROCm"); + ld128(val, addr); #endif + return val; } __forceinline__ __device__ void st128_cs(int4* addr, int4 val) { @@ -247,7 +254,7 @@ __forceinline__ __device__ void st128_cs(int4* addr, int4 val) { asm volatile("st.global.cs.v4.u32 [%0], {%1,%2,%3,%4};" ::"l"(addr), "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w)); #else - assert(false && "st128_cs is not supported on ROCm"); + st128(val, addr); #endif } From 43dffba1abd23a3ce8e8005e58a4427f26ee6194 Mon Sep 17 00:00:00 2001 From: LopezCastroRoberto Date: Mon, 9 Mar 2026 11:46:30 +0000 Subject: [PATCH 11/11] remove amd ci Signed-off-by: LopezCastroRoberto --- .buildkite/test-amd.yaml | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index 361a5e279102..ad11f3764f54 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -687,9 +687,8 @@ steps: - csrc/ - tests/kernels/core - tests/kernels/test_top_k_per_row.py - - tests/kernels/test_concat_mla_q.py commands: - - pytest -v -s kernels/core kernels/test_top_k_per_row.py kernels/test_concat_mla_q.py + - pytest -v -s kernels/core kernels/test_top_k_per_row.py - label: Kernels Attention Test %N # 23min timeout_in_minutes: 35 @@ -2343,9 +2342,8 @@ steps: - csrc/ - tests/kernels/core - tests/kernels/test_top_k_per_row.py - - tests/kernels/test_concat_mla_q.py commands: - - pytest -v -s kernels/core kernels/test_top_k_per_row.py kernels/test_concat_mla_q.py + - pytest -v -s kernels/core kernels/test_top_k_per_row.py - label: Kernels Attention Test %N # 23min timeout_in_minutes: 35