Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .buildkite/test_areas/kernels.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ steps:
- csrc/
- tests/kernels/core
- tests/kernels/test_top_k_per_row.py
- tests/kernels/test_concat_mla_q.py
commands:
- pytest -v -s kernels/core kernels/test_top_k_per_row.py
- pytest -v -s kernels/core kernels/test_top_k_per_row.py kernels/test_concat_mla_q.py

- label: Kernels Attention Test %N
timeout_in_minutes: 35
Expand Down
98 changes: 98 additions & 0 deletions benchmarks/kernels/bench_concat_mla_q.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse

import torch

from vllm import _custom_ops as ops
from vllm.triton_utils import triton

# DeepSeek V3 dimensions
NOPE_DIM = 512  # no-positional-encoding part of each query head
ROPE_DIM = 64  # rotary-encoded part of each query head
NUM_HEADS = 128

# Token counts swept by the benchmark: powers of two from 8 to 8192.
NUM_TOKENS = [2**exp for exp in range(3, 14)]


def get_configs():
    """Return the list of num_tokens values to benchmark."""
    return NUM_TOKENS


def make_inputs(num_tokens: int, dtype: "torch.dtype"):
    """Create ``(ql_nope, q_pe)`` inputs matching the real MLA code path.

    ``ql_nope`` simulates the transposed BMM output: a ``[N, B, L]`` tensor
    viewed through ``.transpose(0, 1)`` as ``[B, N, L]``, so it is NOT
    contiguous in memory. ``q_pe`` is a plain contiguous
    ``[B, N, ROPE_DIM]`` tensor.

    Args:
        num_tokens: batch dimension B (decode batch size / prefill chunk).
        dtype: element dtype for both tensors.

    Returns:
        Tuple ``(ql_nope, q_pe)`` of CUDA tensors.
    """
    # Simulate: bmm output [N, B, L].transpose(0, 1) -> [B, N, L]
    raw = torch.randn(NUM_HEADS, num_tokens, NOPE_DIM, dtype=dtype, device="cuda")
    ql_nope = raw.transpose(0, 1)

    q_pe = torch.randn(num_tokens, NUM_HEADS, ROPE_DIM, dtype=dtype, device="cuda")
    return ql_nope, q_pe


# ---- Non-contiguous nope benchmark (real code path) ----
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens"],
        x_vals=get_configs(),
        line_arg="provider",
        line_vals=["torch_cat", "concat_mla_q"],
        line_names=["torch.cat", "concat_mla_q (v8)"],
        styles=[("blue", "--"), ("green", "-")],
        ylabel="Latency (us)",
        plot_name="concat_mla_q-transposed",
        args={},
    )
)
def bench_transposed(num_tokens, provider):
    """Time one provider on non-contiguous (transposed BMM) nope inputs.

    Returns (median, max, min) latency in microseconds, as expected by
    triton's perf_report machinery.
    """
    dtype = torch.bfloat16
    ql_nope, q_pe = make_inputs(num_tokens, dtype)

    # Preallocated destination for the fused kernel path.
    q_out = torch.empty(
        num_tokens, NUM_HEADS, NOPE_DIM + ROPE_DIM, dtype=dtype, device="cuda"
    )

    if provider == "torch_cat":
        run = lambda: torch.cat((ql_nope, q_pe), dim=-1)  # noqa: E731
    else:
        run = lambda: ops.concat_mla_q(ql_nope, q_pe, q_out)  # noqa: E731

    # quantiles -> (median, 20th percentile, 80th percentile), in ms.
    med_ms, p20_ms, p80_ms = triton.testing.do_bench_cudagraph(
        run, quantiles=[0.5, 0.2, 0.8], rep=500
    )

    # Convert ms -> us; report (median, max-ish, min-ish).
    return med_ms * 1000, p80_ms * 1000, p20_ms * 1000


if __name__ == "__main__":
    # CLI: only an optional path for saving triton's benchmark output.
    cli = argparse.ArgumentParser(description="Benchmark concat_mla_q vs torch.cat")
    cli.add_argument(
        "--save-path", type=str, default=None, help="Path to save benchmark results"
    )
    opts = cli.parse_args()

    bar = "=" * 70

    # Banner describing the fixed DeepSeek V3 problem dimensions.
    for line in (
        "\n" + bar,
        "CONCAT MLA Q KERNEL BENCHMARKS",
        bar,
        f"Dimensions: nope={NOPE_DIM}, rope={ROPE_DIM}, heads={NUM_HEADS}",
        f"Per-head output: {NOPE_DIM + ROPE_DIM} bf16 = "
        f"{(NOPE_DIM + ROPE_DIM) * 2} bytes",
        f"num_tokens (decode=batch_size, prefill=chunk_size): {NUM_TOKENS}",
        bar,
    ):
        print(line)

    print("\n--- Non-contiguous nope inputs (transposed BMM output) ---")
    bench_transposed.run(print_data=True, save_path=opts.save_path)

    for line in ("\n" + bar, "Benchmarking complete!", bar):
        print(line)
6 changes: 6 additions & 0 deletions csrc/cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ void indexer_k_quant_and_cache(
int64_t quant_block_size, // quantization block size
const std::string& scale_fmt);

// Concatenate query nope and rope for MLA/DSA attention:
// q_out[..., :nope_dim] = ql_nope, q_out[..., nope_dim:] = q_pe.
// Fused replacement for torch.cat((ql_nope, q_pe), dim=-1).
void concat_mla_q(
    torch::Tensor& ql_nope,  // [num_tokens, num_heads, nope_dim]
    torch::Tensor& q_pe,     // [num_tokens, num_heads, rope_dim]
    torch::Tensor& q_out);   // [num_tokens, num_heads, nope_dim + rope_dim]

// Extract function to gather quantized K cache
void cp_gather_indexer_k_quant_cache(
const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
Expand Down
41 changes: 41 additions & 0 deletions csrc/cache_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/vectorization_utils.cuh"
#include "concat_mla_q.cuh"

#ifdef USE_ROCM
#include "quantization/w8a8/fp8/amd/quant_utils.cuh"
Expand Down Expand Up @@ -1365,3 +1366,43 @@ void cp_gather_indexer_k_quant_cache(
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(32);
}
}

// Concatenate ql_nope and q_pe into a contiguous q_out tensor for MLA/DSA.
// Replaces torch.cat((ql_nope, q_pe), dim=-1).
void concat_mla_q(torch::Tensor& ql_nope,  // [num_tokens, num_heads, nope_dim]
                  torch::Tensor& q_pe,     // [num_tokens, num_heads, rope_dim]
                  torch::Tensor& q_out     // [num_tokens, num_heads, nope_dim +
                                           // rope_dim]
) {
  const int num_tokens = ql_nope.size(0);
  const int num_heads = ql_nope.size(1);
  const int nope_dim = ql_nope.size(2);
  const int rope_dim = q_pe.size(2);

  // The kernel below is instantiated only with NOPE_DIM=512; any other
  // nope_dim (including larger multiples of 512, which the previous
  // `% 512 == 0` check admitted) would be silently mis-copied.
  TORCH_CHECK(nope_dim == 512, "nope_dim must be 512, got ", nope_dim);
  TORCH_CHECK(rope_dim == 64, "rope_dim must be 64, got ", rope_dim);
  // All three tensors must agree on the leading [num_tokens, num_heads]
  // dims and dtype: the kernel indexes them with a shared (token, head).
  TORCH_CHECK(q_pe.size(0) == num_tokens && q_pe.size(1) == num_heads,
              "q_pe leading dims must match ql_nope");
  TORCH_CHECK(q_out.size(0) == num_tokens && q_out.size(1) == num_heads,
              "q_out leading dims must match ql_nope");
  TORCH_CHECK(q_out.size(2) == nope_dim + rope_dim);
  TORCH_CHECK(q_pe.scalar_type() == ql_nope.scalar_type() &&
                  q_out.scalar_type() == ql_nope.scalar_type(),
              "ql_nope, q_pe and q_out must share one dtype");
  // The kernel copies the rope part as one 32-bit word per lane
  // (32 lanes * 4 B = 128 B), which covers rope_dim=64 only for
  // 2-byte element types.
  TORCH_CHECK(ql_nope.element_size() == 2,
              "concat_mla_q only supports 16-bit dtypes (fp16/bf16)");

  // Innermost dim must be dense; outer dims may be strided (e.g. the
  // transposed BMM output).
  TORCH_CHECK(ql_nope.stride(2) == 1, "ql_nope must have stride 1 in dim 2");
  TORCH_CHECK(q_pe.stride(2) == 1, "q_pe must have stride 1 in dim 2");
  TORCH_CHECK(q_out.stride(2) == 1, "q_out must have stride 1 in dim 2");

  // Nothing to launch for an empty batch.
  if (num_tokens == 0) return;

  // One warp per (token, head) pair, 8 warps per CTA.
  constexpr int warps_per_block = 8;
  const int total_warps = num_tokens * num_heads;
  const int grid_size = (total_warps + warps_per_block - 1) / warps_per_block;
  const int block_size = warps_per_block * 32;

  const at::cuda::OptionalCUDAGuard device_guard(device_of(ql_nope));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  VLLM_DISPATCH_FLOATING_TYPES(ql_nope.scalar_type(), "concat_mla_q", [&] {
    vllm::ConcatMLAQKernel<scalar_t, 512><<<grid_size, block_size, 0, stream>>>(
        q_out.data_ptr<scalar_t>(), ql_nope.data_ptr<scalar_t>(),
        q_pe.data_ptr<scalar_t>(), num_tokens, num_heads, q_out.stride(0),
        q_out.stride(1), ql_nope.stride(0), ql_nope.stride(1), q_pe.stride(0),
        q_pe.stride(1));
  });
}
60 changes: 60 additions & 0 deletions csrc/concat_mla_q.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#ifndef CONCAT_MLA_Q_CUH_
#define CONCAT_MLA_Q_CUH_

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include "cuda_vec_utils.cuh"

namespace vllm {

// Concatenates ql_nope [num_tokens, num_heads, NOPE_DIM] and
// q_pe [num_tokens, num_heads, 64]
// into q_out [num_tokens, num_heads, NOPE_DIM+64].
// Currently instantiated only for NOPE_DIM=512.
// Rope dim is hardcoded to 64 (DeepSeek V3.2 MLA).
// One warp handles one (token, head) pair; all strides are in elements.
template <typename DType, int NOPE_DIM>
__global__ void ConcatMLAQKernel(
    DType* __restrict__ q_out, const DType* __restrict__ ql_nope,
    const DType* __restrict__ q_pe, const int num_tokens, const int num_heads,
    const int64_t out_stride_0, const int64_t out_stride_1,
    const int64_t nope_stride_0, const int64_t nope_stride_1,
    const int64_t pe_stride_0, const int64_t pe_stride_1) {
  const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) >> 5;
  if (flat_warp_id >= num_tokens * num_heads) return;

  const int token_id = flat_warp_id / num_heads;
  const int head_id = flat_warp_id % num_heads;
  const int lane_id = threadIdx.x & 31;

  constexpr bool use_256b = VLLM_256B_PTX_ENABLED;
  // Warp-cooperative iterations needed to cover the nope payload; it must
  // tile exactly into (max vector width * 32 lanes) chunks or the integer
  // division below would silently drop a tail.
  constexpr int nope_vec_loads =
      NOPE_DIM * sizeof(DType) / (VecTraits<use_256b>::ARCH_MAX_VEC_SIZE * 32);
  static_assert(NOPE_DIM * sizeof(DType) %
                        (VecTraits<use_256b>::ARCH_MAX_VEC_SIZE * 32) ==
                    0,
                "NOPE_DIM must tile exactly into per-warp vector transfers");

  const DType* nope_src =
      ql_nope + token_id * nope_stride_0 + head_id * nope_stride_1;
  DType* nope_dst = q_out + token_id * out_stride_0 + head_id * out_stride_1;

  // Copy the nope part with cache-streaming (.cs) vector transfers.
#pragma unroll
  for (int i = 0; i < nope_vec_loads; i++) {
    const int offset = i * 32 + lane_id;
    if constexpr (use_256b) {
      st256_cs(reinterpret_cast<u32x8_t*>(nope_dst) + offset,
               ld256_cs(reinterpret_cast<const u32x8_t*>(nope_src) + offset));
    } else {
      st128_cs(reinterpret_cast<int4*>(nope_dst) + offset,
               ld128_cs(reinterpret_cast<const int4*>(nope_src) + offset));
    }
  }

  // Copy the 64-element rope part as 32-bit words. A single warp-wide
  // transfer (the previous code) moves exactly 32 lanes * 4 B = 128 B,
  // which is correct only for 2-byte element types; looping over
  // rope_words also covers wider dtypes (e.g. fp32 needs 64 words).
  constexpr int kRopeDim = 64;
  constexpr int rope_words = kRopeDim * sizeof(DType) / sizeof(int);
  static_assert(kRopeDim * sizeof(DType) % sizeof(int) == 0,
                "rope payload must be a whole number of 32-bit words");

  const int* rope_src = reinterpret_cast<const int*>(
      q_pe + token_id * pe_stride_0 + head_id * pe_stride_1);
  int* rope_dst = reinterpret_cast<int*>(q_out + token_id * out_stride_0 +
                                         head_id * out_stride_1 + NOPE_DIM);

#pragma unroll
  for (int i = lane_id; i < rope_words; i += 32) {
    st32_cs(rope_dst + i, ld32_cs(rope_src + i));
  }
}

}  // namespace vllm

#endif  // CONCAT_MLA_Q_CUH_
47 changes: 37 additions & 10 deletions csrc/cuda_vec_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ __forceinline__ __device__ u32x8_t ld256_cs(const u32x8_t* addr) {
return val;
#else
assert(false && "ld256_cs requires SM100+ with CUDA 12.9+");
return {};
#endif
}

Expand All @@ -211,23 +210,51 @@ __forceinline__ __device__ void st256_cs(u32x8_t* addr, u32x8_t val) {
#endif
}

// 32-bit cache-streaming (.cs) load / store — SM100+ only.
// 32-bit load / store.
__device__ __forceinline__ int ld32(const int* addr) { return __ldg(addr); }

__device__ __forceinline__ void st32(int* addr, int val) { *addr = val; }

// 32-bit cache-streaming (.cs) load / store.
// Falls back to ld32/st32 on ROCm (no .cs hint).
__forceinline__ __device__ int ld32_cs(const int* addr) {
#if VLLM_256B_PTX_ENABLED
int val;
#ifndef USE_ROCM
asm volatile("ld.global.cs.b32 %0, [%1];" : "=r"(val) : "l"(addr));
return val;
#else
assert(false && "ld32_cs requires SM100+ with CUDA 12.9+");
return 0;
val = ld32(addr);
#endif
return val;
}

__forceinline__ __device__ void st32_cs(int* addr, int val) {
#if VLLM_256B_PTX_ENABLED
#ifndef USE_ROCM
asm volatile("st.global.cs.b32 [%0], %1;" ::"l"(addr), "r"(val));
#else
assert(false && "st32_cs requires SM100+ with CUDA 12.9+");
st32(addr, val);
#endif
}

// 128-bit cache-streaming (.cs) load / store.
// Falls back to ld128/st128 on ROCm (no .cs hint).
__forceinline__ __device__ int4 ld128_cs(const int4* addr) {
  int4 val;
#ifndef USE_ROCM
  // PTX .cs qualifier: evict-first caching for data that is read once.
  asm volatile("ld.global.cs.v4.u32 {%0,%1,%2,%3}, [%4];"
               : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
               : "l"(addr));
#else
  // ROCm path: plain 128-bit load without a streaming hint.
  ld128(val, addr);
#endif
  return val;
}

// 128-bit cache-streaming (.cs) store; plain st128 on ROCm (no .cs hint).
__forceinline__ __device__ void st128_cs(int4* addr, int4 val) {
#ifndef USE_ROCM
  // PTX .cs qualifier: evict-first caching for write-once data.
  asm volatile("st.global.cs.v4.u32 [%0], {%1,%2,%3,%4};" ::"l"(addr),
               "r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w));
#else
  st128(val, addr);
#endif
}

Expand Down Expand Up @@ -260,7 +287,7 @@ __device__ __forceinline__ void ld256_cg_or_zero(u32x8_t& val, const void* ptr,

__device__ __forceinline__ void ld128_cg_or_zero(uint4& val, const void* ptr,
bool pred) {
#if VLLM_256B_PTX_ENABLED
#ifndef USE_ROCM
uint32_t r0, r1, r2, r3;

asm volatile(
Expand All @@ -278,7 +305,7 @@ __device__ __forceinline__ void ld128_cg_or_zero(uint4& val, const void* ptr,

val = uint4{r0, r1, r2, r3};
#else
assert(false && "ld128_cg_or_zero requires SM100+ with CUDA 12.9+");
assert(false && "ld128_cg_or_zero is not supported on ROCm");
#endif
}

Expand Down
4 changes: 4 additions & 0 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,10 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA,
&indexer_k_quant_and_cache);

cache_ops.def(
"concat_mla_q(Tensor ql_nope, Tensor q_pe, Tensor! q_out) -> ()");
cache_ops.impl("concat_mla_q", torch::kCUDA, &concat_mla_q);

cache_ops.def(
"cp_gather_indexer_k_quant_cache(Tensor kv_cache, Tensor! dst_k, Tensor! "
"dst_scale, Tensor block_table, Tensor cu_seq_lens) -> ()");
Expand Down
Loading