From 21ed9870df2360caccbc12dee5604d0e56e65eda Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 30 Apr 2026 11:04:23 +0000 Subject: [PATCH 01/21] initial version Signed-off-by: Thien Tran --- benchmark_fused_indexer_q.py | 95 +++++++++ indexer_q_mxfp4.py | 369 +++++++++++++++++++++++++++++++++++ 2 files changed, 464 insertions(+) create mode 100644 benchmark_fused_indexer_q.py create mode 100644 indexer_q_mxfp4.py diff --git a/benchmark_fused_indexer_q.py b/benchmark_fused_indexer_q.py new file mode 100644 index 000000000000..f9a7a406038f --- /dev/null +++ b/benchmark_fused_indexer_q.py @@ -0,0 +1,95 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Benchmark the FP4 fused_indexer_q_rope_quant path.""" + +import argparse + +import torch + +from indexer_q_mxfp4 import fused_indexer_q_rope_quant as dev_impl +from vllm.triton_utils import triton +from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import ( + fused_indexer_q_rope_quant as prod_impl, +) + +NUM_HEADS = 64 +HEAD_DIM = 128 +ROPE_DIM = 64 +MAX_POS = 100_000 +TOKENS = [1, 8, 32, 128, 256, 512, 1024, 2048, 4096, 8192] +QUANTILES = [0.5, 0.2, 0.8] +PROVIDERS = {"production": prod_impl, "dev": dev_impl} + + +def measure(num_tokens, provider): + torch.set_default_device("cuda") + positions = torch.randint(MAX_POS, (num_tokens,), dtype=torch.int64) + query = torch.randn(num_tokens, NUM_HEADS, HEAD_DIM, dtype=torch.bfloat16) + cos_sin_cache = torch.randn(MAX_POS, ROPE_DIM, dtype=torch.float32) + weights = torch.randn(num_tokens, NUM_HEADS, dtype=torch.bfloat16) + kernel_args = ( + positions, + query, + cos_sin_cache, + weights, + HEAD_DIM**-0.5, + NUM_HEADS**-0.5, + True, + ) + + selected_impl = PROVIDERS[provider] + if provider != "production": + prod_q, prod_weights = prod_impl(*kernel_args) + dev_q, dev_weights = selected_impl(*kernel_args) + prod_q_packed, prod_q_scale = prod_q + dev_q_packed, dev_q_scale = dev_q + + assert torch.equal(prod_q_packed, dev_q_packed), ( + f"q packed mismatch for num_tokens={num_tokens}" + ) + assert torch.equal(prod_q_scale, dev_q_scale), ( + f"q scale mismatch for num_tokens={num_tokens}" + ) + assert torch.equal(prod_weights, dev_weights), ( + f"weights mismatch for num_tokens={num_tokens}" + ) + + benchmark_fn = lambda: selected_impl(*kernel_args) + median_ms, p20_ms, p80_ms = triton.testing.do_bench( + benchmark_fn, + quantiles=QUANTILES, + ) + + bytes_per_token = 8 # position int64 + bytes_per_token += NUM_HEADS * HEAD_DIM * 2 # q in bf16 + bytes_per_token += ROPE_DIM * 4 # rope fp32 + bytes_per_token += NUM_HEADS * 2 # weights in bf16 + bytes_per_token += NUM_HEADS * HEAD_DIM // 2 # q out fp4 + bytes_per_token += NUM_HEADS * HEAD_DIM // 32 # q_scale uint8 + bytes_per_token += NUM_HEADS * 4 # weights out fp32 + total_bytes = bytes_per_token * num_tokens + + return median_ms, p20_ms, p80_ms, total_bytes + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--tokens", type=int, nargs="+", default=TOKENS) + args = parser.parse_args() + + print(f"Device: {torch.cuda.get_device_name(0)}") + print(f"H={NUM_HEADS} D={HEAD_DIM} rope_dim={ROPE_DIM} use_fp4=True\n") + + for num_tokens in args.tokens: + for provider in PROVIDERS: + median_ms, p20_ms, p80_ms, moved_bytes = measure( + num_tokens, + provider, + ) + bandwidth_gb_s = moved_bytes / (median_ms * 1e-3) * 1e-9 + print( + f"[{provider:10s}] T={num_tokens:6d} " + f"{median_ms * 1e3:7.2f} us " + f"BW {bandwidth_gb_s:7.1f} GB/s " + f"(p20={p20_ms * 1e3:.2f} p80={p80_ms * 1e3:.2f} us)" + ) diff --git a/indexer_q_mxfp4.py b/indexer_q_mxfp4.py new file mode 100644 index 000000000000..f371ea1b7b38 --- /dev/null +++ b/indexer_q_mxfp4.py @@ -0,0 +1,369 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import torch +from torch.utils.cpp_extension import load_inline + +CUDA_SRC = r""" +#include +#include +#include + +constexpr int WARP_SIZE = 32; +constexpr int MX_BLOCK_SIZE = 32; + +__device__ inline +void ldg_f32x8(float *data, const void *ptr) { + asm volatile("ld.global.v8.f32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];" + : "=f"(data[0]), "=f"(data[1]), "=f"(data[2]), "=f"(data[3]), + "=f"(data[4]), "=f"(data[5]), "=f"(data[6]), "=f"(data[7]) + : "l"(ptr)); +} + +__device__ inline +void ldg_b32x8_fast(int *data, const void *ptr) { + asm volatile("ld.global.relaxed.cta.L1::no_allocate.v8.b32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}, [%8];" + : "=r"(data[0]), "=r"(data[1]), "=r"(data[2]), "=r"(data[3]), + "=r"(data[4]), "=r"(data[5]), "=r"(data[6]), "=r"(data[7]) + : "l"(ptr)); +} + +__device__ inline +void bf16x2_to_fp32x2(float *out, uint32_t data) { + asm volatile("shl.b32 %0, %2, 16;\n" // low 16-bit + "and.b32 %1, %2, 0xFFFF0000;" // high 16-bit + : "=f"(out[0]), "=f"(out[1]) : "r"(data)); +} + +__device__ inline +int fp32x2_to_bf16x2(float a, float b) { + int tmp; + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;" : "=r"(tmp) : "f"(b), "f"(a)); + return tmp; +} + +__device__ inline +int bf16x2_abs(int a) { + int d; + asm volatile("abs.bf16x2 %0, %1;" : "=r"(d) : "r"(a)); + return d; +} + +__device__ inline +int bf16x2_max(int a, int b) { + int d; + asm volatile("max.bf16x2 %0, %1, %2;" : "=r"(d) : "r"(a), "r"(b)); + return d; +} + +__device__ inline +int fp32x8_to_fp4x8(const float *x) { + int out; + asm volatile( + "{\n" + ".reg .b8 x0, x1, x2, x3;\n" + "cvt.rn.satfinite.e2m1x2.f32 x0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 x1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 x2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 x3, %8, %7;\n" + "mov.b32 %0, {x0, x1, x2, x3};\n" + "}" + : "=r"(out) + : "f"(x[0]), "f"(x[1]), "f"(x[2]), "f"(x[3]), + "f"(x[4]), "f"(x[5]), "f"(x[6]), "f"(x[7]) + ); + return out; +} + +template +__global__ +void fused_indexer_q_rope_mxfp4_kernel( + const int64_t *positions_ptr, // [num_tokens] + const nv_bfloat16 *q_ptr, // [num_tokens, num_heads, head_dim] + const float *cos_sin_ptr, // [max_pos, rope_dim] + const nv_bfloat16 *weights, // [num_tokens, num_heads] + char *q_fp4_ptr, // [num_tokens, num_heads, head_dim/2] + uint8_t *q_scale_ptr, // [num_tokens, num_heads, head_dim/32] + float *weights_out, // [num_tokens, num_heads] + float scale, + int num_tokens, + int num_heads, + int q_stride0, int q_stride1, + int cos_sin_stride, + int weights_stride, + int q_fp4_stride0, int q_fp4_stride1, + int q_scale_stride0, int q_scale_stride1, + int weights_out_stride +) { + constexpr int NOPE_DIM = HEAD_DIM - ROPE_DIM; + + // we will use 32B load per thread = 16 BF16 elems. + // hence, we need 8 threads to load single head (128 elems). + // let's call subwarp = 8 threads -> 1 subwarp handles 1 token + constexpr int SUBWARP_SIZE = HEAD_DIM / 16; + static_assert(SUBWARP_SIZE <= WARP_SIZE); + + const int tid = threadIdx.x; + const int bid = blockIdx.x; + + const int global_tid = bid * blockDim.x + tid; + const int global_subwarp_id = global_tid / SUBWARP_SIZE; + const int sublane_id = tid % SUBWARP_SIZE; + + const int token_id = global_subwarp_id / num_heads; + const int head_id = global_subwarp_id % num_heads; + + // load Q + int q[8]; + float q_f32[16]; + const int q_offset = token_id * q_stride0 + head_id * q_stride1 + sublane_id * 16; + ldg_b32x8_fast(q, q_ptr + q_offset); + int64_t pos = positions_ptr[token_id]; + + // apply rope + // NOTE: warp divergence + if (sublane_id * 16 >= NOPE_DIM) { + float cos[8], sin[8]; + const int rope_idx = (sublane_id * 16 - NOPE_DIM) / 2; + ldg_f32x8(cos, cos_sin_ptr + (pos * cos_sin_stride + rope_idx)); + ldg_f32x8(sin, cos_sin_ptr + (pos * cos_sin_stride + ROPE_DIM / 2 + rope_idx)); + + // unpack + for (int i = 0; i < 8; i++) + bf16x2_to_fp32x2(q_f32 + i * 2, q[i]); + + for (int i = 0; i < 8; i++) { + float q0 = q_f32[i * 2 + 0] * cos[i] - q_f32[i * 2 + 1] * sin[i]; + float q1 = q_f32[i * 2 + 0] * sin[i] + q_f32[i * 2 + 1] * cos[i]; + q_f32[i * 2 + 0] = q0; + q_f32[i * 2 + 1] = q1; + } + + // BF16 round-trip to match reference + for (int i = 0; i < 8; i++) + q[i] = fp32x2_to_bf16x2(q_f32[i * 2], q_f32[i * 2 + 1]); + } + + // absmax in BF16 to save instructions + int q_amax = bf16x2_abs(q[0]); + for (int i = 1; i < 8; i++) + q_amax = bf16x2_max(q_amax, bf16x2_abs(q[i])); + + // amax between 2 threads -> 1 warp shuffle call + q_amax = bf16x2_max(q_amax, __shfl_xor_sync(0xFFFF'FFFF, q_amax, 1)); + + // final amax in FP32 + float q_amax_f32[2]; + bf16x2_to_fp32x2(q_amax_f32, q_amax); + float amax = max(q_amax_f32[0], q_amax_f32[1]); + + constexpr float amax_eps = 0x6p-126f; // 6.0f * 2^-126 + constexpr float inv_fp4_max = 1.0f / 6.0f; + float fp4_scale = max(amax, amax_eps) * inv_fp4_max; + + // compute ceil_log2 with bit manipulation + // add a magic number so that exponent increments by 1 + // when mantissa bits > 0 + uint32_t bits = __float_as_uint(fp4_scale); + uint32_t ue8m0 = ((bits + 0x7FFFFFU) >> 23U) & 0xFFU; + + // only 1 out of 2 threads need to store SF (rmb, 2 threads = 32 elems) + if (tid % 2 == 0) { + const int q_scale_offset = token_id * q_scale_stride0 + + head_id * q_scale_stride1 + + sublane_id / 2; + q_scale_ptr[q_scale_offset] = ue8m0; + } + + // unpack + for (int i = 0; i < 8; i++) + bf16x2_to_fp32x2(q_f32 + i * 2, q[i]); + + // let A = ceil(log2(fp4_scale)) be the actual mathematical value + // fp4_scale = 2^A, and ue8m0 = A + 127, where 127 is the exponent bias + // we want 1/fp4_scale = 2^(-A), whose exponent bits = -A + 127 = 254 - ue8m0 + float inv_fp4_scale = __uint_as_float((254U - ue8m0) << 23U); + for (int i = 0; i < 16; i++) + q_f32[i] *= inv_fp4_scale; + + int2 packed_fp4; + packed_fp4.x = fp32x8_to_fp4x8(q_f32); + packed_fp4.y = fp32x8_to_fp4x8(q_f32 + 8); + const int q_fp4_offset = token_id * q_fp4_stride0 + + head_id * q_fp4_stride1 + + sublane_id * 8; + reinterpret_cast(q_fp4_ptr + q_fp4_offset)[0] = packed_fp4; + + // scale weights + if (global_tid < num_tokens * num_heads) { + const int token_id = global_tid / num_heads; + const int head_id = global_tid % num_heads; + float w = __bfloat162float(weights[token_id * weights_stride + head_id]); + weights_out[token_id * weights_out_stride + head_id] = w * scale; + } +} + +#include +#include +#include +#include +#include + +at::Tensor fused_indexer_q_rope_mxfp4( + const at::Tensor& positions, + const at::Tensor& q, + const at::Tensor& cos_sin, + const at::Tensor& weights, + at::Tensor& q_fp4, + at::Tensor& q_scale, + at::Tensor& weights_out, + double scale) { + TORCH_CHECK(positions.is_cuda(), "positions must be CUDA"); + TORCH_CHECK(q.is_cuda(), "q must be CUDA"); + TORCH_CHECK(cos_sin.is_cuda(), "cos_sin must be CUDA"); + TORCH_CHECK(weights.is_cuda(), "weights must be CUDA"); + TORCH_CHECK(q_fp4.is_cuda(), "q_fp4 must be CUDA"); + TORCH_CHECK(q_scale.is_cuda(), "q_scale must be CUDA"); + TORCH_CHECK(weights_out.is_cuda(), "weights_out must be CUDA"); + + TORCH_CHECK(positions.scalar_type() == at::kLong, "positions must be int64"); + TORCH_CHECK(q.scalar_type() == at::kBFloat16, "q must be bfloat16"); + TORCH_CHECK(cos_sin.scalar_type() == at::kFloat, "cos_sin must be float32"); + TORCH_CHECK(weights.scalar_type() == at::kBFloat16, "weights must be bfloat16"); + TORCH_CHECK(q_fp4.scalar_type() == at::kByte, "q_fp4 must be uint8"); + TORCH_CHECK(q_scale.scalar_type() == at::kByte, "q_scale must be uint8"); + TORCH_CHECK(weights_out.scalar_type() == at::kFloat, "weights_out must be float32"); + + TORCH_CHECK(positions.dim() == 1, "positions must be rank 1"); + TORCH_CHECK(q.dim() == 3, "q must have shape [num_tokens, num_heads, 128]"); + TORCH_CHECK(cos_sin.dim() == 2, "cos_sin must have shape [max_pos, 64]"); + TORCH_CHECK(weights.dim() == 2, "weights must have shape [num_tokens, num_heads]"); + TORCH_CHECK(q.size(2) == 128, "q head_dim must be 128"); + TORCH_CHECK(cos_sin.size(1) == 64, "cos_sin rope_dim must be 64"); + + const int num_tokens = static_cast(positions.size(0)); + const int num_heads = static_cast(q.size(1)); + TORCH_CHECK(q.size(0) == num_tokens, "q and positions token counts differ"); + TORCH_CHECK(weights.size(0) == num_tokens, "weights token count differs"); + TORCH_CHECK(weights.size(1) == num_heads, "weights head count differs"); + TORCH_CHECK(q_fp4.sizes() == at::IntArrayRef({num_tokens, num_heads, 64}), + "q_fp4 must have shape [num_tokens, num_heads, 64]"); + TORCH_CHECK(q_scale.sizes() == at::IntArrayRef({num_tokens, num_heads, 4}), + "q_scale must have shape [num_tokens, num_heads, 4]"); + TORCH_CHECK(weights_out.sizes() == weights.sizes(), + "weights_out must have the same shape as weights"); + + // The kernel uses a full-warp shuffle mask. For 8-thread subwarps and a + // 256-thread block, DeepSeek's 64 heads makes the grid exact. + TORCH_CHECK(num_heads % 32 == 0, + "num_heads must be divisible by 32 for this launch wrapper"); + + c10::cuda::CUDAGuard device_guard(q.device()); + constexpr int kHeadDim = 128; + constexpr int kRopeDim = 64; + constexpr int kSubwarpSize = kHeadDim / 16; + constexpr int kBlockSize = 256; + const int total_threads = num_tokens * num_heads * kSubwarpSize; + const int grid = (total_threads + kBlockSize - 1) / kBlockSize; + + fused_indexer_q_rope_mxfp4_kernel + <<>>( + positions.data_ptr(), + reinterpret_cast(q.data_ptr()), + cos_sin.data_ptr(), + reinterpret_cast(weights.data_ptr()), + reinterpret_cast(q_fp4.data_ptr()), + q_scale.data_ptr(), + weights_out.data_ptr(), + static_cast(scale), + num_tokens, + num_heads, + static_cast(q.stride(0)), + static_cast(q.stride(1)), + static_cast(cos_sin.stride(0)), + static_cast(weights.stride(0)), + static_cast(q_fp4.stride(0)), + static_cast(q_fp4.stride(1)), + static_cast(q_scale.stride(0)), + static_cast(q_scale.stride(1)), + static_cast(weights_out.stride(0))); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return q_fp4; +} + +TORCH_LIBRARY(indexer_q_mxfp4, m) { + m.def("fused_indexer_q_rope_mxfp4(" + "Tensor positions, Tensor q, Tensor cos_sin, Tensor weights, " + "Tensor(a!) q_fp4, Tensor(b!) q_scale, Tensor(c!) weights_out, " + "float scale) -> Tensor"); + m.impl("fused_indexer_q_rope_mxfp4", + torch::dispatch(c10::DispatchKey::CUDA, + TORCH_FN(fused_indexer_q_rope_mxfp4))); +} +""" + +HEAD_DIM = 128 +ROPE_DIM = 64 + +load_inline( + "indexer_q_mxfp4", + cpp_sources="", + cuda_sources=CUDA_SRC, + verbose=False, + is_python_module=False, + no_implicit_headers=True, + extra_cuda_cflags=[ + "-O3", + "-gencode=arch=compute_100a,code=sm_100a", + "--expt-relaxed-constexpr", + "--relocatable-device-code=false", + "-lineinfo", + "-Xptxas=-v", + ], + extra_ldflags=["-lcuda"], +) +_fused_indexer_q_rope_mxfp4 = torch.ops.indexer_q_mxfp4.fused_indexer_q_rope_mxfp4 + + +def fused_indexer_q_rope_quant( + positions: torch.Tensor, + index_q: torch.Tensor, + index_q_cos_sin_cache: torch.Tensor, + index_weights: torch.Tensor, + index_weights_softmax_scale: float, + index_weights_head_scale: float, + use_fp4: bool = False, +) -> tuple[torch.Tensor | tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + if not use_fp4: + raise NotImplementedError("indexer_q_mxfp4 only implements use_fp4=True") + assert index_q.ndim == 3 and index_q.shape[-1] == HEAD_DIM + assert index_q_cos_sin_cache.ndim == 2 + assert index_q_cos_sin_cache.shape[-1] == ROPE_DIM + + num_tokens, num_heads, _ = index_q.shape + q_fp4 = torch.empty( + (num_tokens, num_heads, HEAD_DIM // 2), + dtype=torch.uint8, + device=index_q.device, + ) + q_scale = torch.empty( + (num_tokens, num_heads, HEAD_DIM // 32), + dtype=torch.uint8, + device=index_q.device, + ) + weights_out = torch.empty_like(index_weights, dtype=torch.float32) + + scale = float(index_weights_softmax_scale * index_weights_head_scale) + _fused_indexer_q_rope_mxfp4( + positions, + index_q, + index_q_cos_sin_cache, + index_weights, + q_fp4, + q_scale, + weights_out, + scale, + ) + return (q_fp4, q_scale.view(torch.int32).squeeze(-1)), weights_out From 7fb6d4dd8431e876a7f8470376274a1914d8beb1 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 30 Apr 2026 11:48:39 +0000 Subject: [PATCH 02/21] update Signed-off-by: Thien Tran --- benchmark_fused_indexer_q.py | 6 +++--- indexer_q_mxfp4.py | 16 +++++++++------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/benchmark_fused_indexer_q.py b/benchmark_fused_indexer_q.py index f9a7a406038f..5719da77c1b3 100644 --- a/benchmark_fused_indexer_q.py +++ b/benchmark_fused_indexer_q.py @@ -18,7 +18,7 @@ MAX_POS = 100_000 TOKENS = [1, 8, 32, 128, 256, 512, 1024, 2048, 4096, 8192] QUANTILES = [0.5, 0.2, 0.8] -PROVIDERS = {"production": prod_impl, "dev": dev_impl} +PROVIDERS = {"prod": prod_impl, "dev": dev_impl} def measure(num_tokens, provider): @@ -38,7 +38,7 @@ def measure(num_tokens, provider): ) selected_impl = PROVIDERS[provider] - if provider != "production": + if provider != "prod": prod_q, prod_weights = prod_impl(*kernel_args) dev_q, dev_weights = selected_impl(*kernel_args) prod_q_packed, prod_q_scale = prod_q @@ -88,7 +88,7 @@ def measure(num_tokens, provider): ) bandwidth_gb_s = moved_bytes / (median_ms * 1e-3) * 1e-9 print( - f"[{provider:10s}] T={num_tokens:6d} " + f"[{provider:4s}] T={num_tokens:6d} " f"{median_ms * 1e3:7.2f} us " f"BW {bandwidth_gb_s:7.1f} GB/s " f"(p20={p20_ms * 1e3:.2f} p80={p80_ms * 1e3:.2f} us)" diff --git a/indexer_q_mxfp4.py b/indexer_q_mxfp4.py index f371ea1b7b38..6180a2ea0961 100644 --- a/indexer_q_mxfp4.py +++ b/indexer_q_mxfp4.py @@ -76,7 +76,8 @@ return out; } -template +template +__block_size__((TB_SIZE, 1, 1)) __global__ void fused_indexer_q_rope_mxfp4_kernel( const int64_t *positions_ptr, // [num_tokens] @@ -150,8 +151,11 @@ for (int i = 1; i < 8; i++) q_amax = bf16x2_max(q_amax, bf16x2_abs(q[i])); - // amax between 2 threads -> 1 warp shuffle call - q_amax = bf16x2_max(q_amax, __shfl_xor_sync(0xFFFF'FFFF, q_amax, 1)); + // each thread holds 16 elems -> 2 threads hold 32 elems + // warp shuffle among 2 threads + constexpr int NUM_THREADS_PER_MX = MX_BLOCK_SIZE / 16; + for (int stride = NUM_THREADS_PER_MX / 2; stride > 0; stride /= 2) + q_amax = bf16x2_max(q_amax, __shfl_xor_sync(0xFFFF'FFFF, q_amax, stride)); // final amax in FP32 float q_amax_f32[2]; @@ -193,7 +197,7 @@ const int q_fp4_offset = token_id * q_fp4_stride0 + head_id * q_fp4_stride1 + sublane_id * 8; - reinterpret_cast(q_fp4_ptr + q_fp4_offset)[0] = packed_fp4; + __stcs(reinterpret_cast(q_fp4_ptr + q_fp4_offset), packed_fp4); // scale weights if (global_tid < num_tokens * num_heads) { @@ -254,8 +258,6 @@ TORCH_CHECK(weights_out.sizes() == weights.sizes(), "weights_out must have the same shape as weights"); - // The kernel uses a full-warp shuffle mask. For 8-thread subwarps and a - // 256-thread block, DeepSeek's 64 heads makes the grid exact. TORCH_CHECK(num_heads % 32 == 0, "num_heads must be divisible by 32 for this launch wrapper"); @@ -267,7 +269,7 @@ const int total_threads = num_tokens * num_heads * kSubwarpSize; const int grid = (total_threads + kBlockSize - 1) / kBlockSize; - fused_indexer_q_rope_mxfp4_kernel + fused_indexer_q_rope_mxfp4_kernel <<>>( positions.data_ptr(), reinterpret_cast(q.data_ptr()), From 293597268bc2ef26903f601163d889e6c9fd9aad Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 30 Apr 2026 16:07:24 +0000 Subject: [PATCH 03/21] add cutedsl version Signed-off-by: Thien Tran --- benchmark_fused_indexer_q.py | 71 ++++-- indexer_q_mxfp4_cutedsl.py | 460 +++++++++++++++++++++++++++++++++++ 2 files changed, 511 insertions(+), 20 deletions(-) create mode 100644 indexer_q_mxfp4_cutedsl.py diff --git a/benchmark_fused_indexer_q.py b/benchmark_fused_indexer_q.py index 5719da77c1b3..36798f40933c 100644 --- a/benchmark_fused_indexer_q.py +++ b/benchmark_fused_indexer_q.py @@ -6,7 +6,8 @@ import torch -from indexer_q_mxfp4 import fused_indexer_q_rope_quant as dev_impl +from indexer_q_mxfp4 import fused_indexer_q_rope_quant as cuda_cpp_impl +from indexer_q_mxfp4_cutedsl import fused_indexer_q_rope_quant as cutedsl_impl from vllm.triton_utils import triton from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import ( fused_indexer_q_rope_quant as prod_impl, @@ -18,14 +19,18 @@ MAX_POS = 100_000 TOKENS = [1, 8, 32, 128, 256, 512, 1024, 2048, 4096, 8192] QUANTILES = [0.5, 0.2, 0.8] -PROVIDERS = {"prod": prod_impl, "dev": dev_impl} +PROVIDERS = { + "cuda_cpp": cuda_cpp_impl, + "cutedsl": cutedsl_impl, + "prod": prod_impl, +} -def measure(num_tokens, provider): +def measure(num_tokens, provider, check_provider, cache_dtype, skip_check): torch.set_default_device("cuda") positions = torch.randint(MAX_POS, (num_tokens,), dtype=torch.int64) query = torch.randn(num_tokens, NUM_HEADS, HEAD_DIM, dtype=torch.bfloat16) - cos_sin_cache = torch.randn(MAX_POS, ROPE_DIM, dtype=torch.float32) + cos_sin_cache = torch.randn(MAX_POS, ROPE_DIM, dtype=cache_dtype) weights = torch.randn(num_tokens, NUM_HEADS, dtype=torch.bfloat16) kernel_args = ( positions, @@ -38,21 +43,27 @@ def measure(num_tokens, provider): ) selected_impl = PROVIDERS[provider] - if provider != "prod": - prod_q, prod_weights = prod_impl(*kernel_args) - dev_q, dev_weights = selected_impl(*kernel_args) - prod_q_packed, prod_q_scale = prod_q - dev_q_packed, dev_q_scale = dev_q - - assert torch.equal(prod_q_packed, dev_q_packed), ( - f"q packed mismatch for num_tokens={num_tokens}" + selected_q, selected_weights = selected_impl(*kernel_args) + torch.accelerator.synchronize() + + if not skip_check and provider != check_provider: + ref_q, ref_weights = PROVIDERS[check_provider](*kernel_args) + ref_q_packed, ref_q_scale = ref_q + selected_q_packed, selected_q_scale = selected_q + + assert torch.equal(ref_q_packed, selected_q_packed), ( + f"q packed mismatch for provider={provider} " + f"num_tokens={num_tokens} cache_dtype={cache_dtype}" ) - assert torch.equal(prod_q_scale, dev_q_scale), ( - f"q scale mismatch for num_tokens={num_tokens}" + assert torch.equal(ref_q_scale, selected_q_scale), ( + f"q scale mismatch for provider={provider} " + f"num_tokens={num_tokens} cache_dtype={cache_dtype}" ) - assert torch.equal(prod_weights, dev_weights), ( - f"weights mismatch for num_tokens={num_tokens}" + assert torch.equal(ref_weights, selected_weights), ( + f"weights mismatch for provider={provider} " + f"num_tokens={num_tokens} cache_dtype={cache_dtype}" ) + torch.accelerator.synchronize() benchmark_fn = lambda: selected_impl(*kernel_args) median_ms, p20_ms, p80_ms = triton.testing.do_bench( @@ -62,7 +73,7 @@ def measure(num_tokens, provider): bytes_per_token = 8 # position int64 bytes_per_token += NUM_HEADS * HEAD_DIM * 2 # q in bf16 - bytes_per_token += ROPE_DIM * 4 # rope fp32 + bytes_per_token += ROPE_DIM * torch.empty((), dtype=cache_dtype).element_size() bytes_per_token += NUM_HEADS * 2 # weights in bf16 bytes_per_token += NUM_HEADS * HEAD_DIM // 2 # q out fp4 bytes_per_token += NUM_HEADS * HEAD_DIM // 32 # q_scale uint8 @@ -75,20 +86,40 @@ def measure(num_tokens, provider): if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--tokens", type=int, nargs="+", default=TOKENS) + parser.add_argument( + "--providers", + choices=PROVIDERS, + nargs="+", + default=PROVIDERS.keys(), + ) + parser.add_argument("--check-provider", choices=PROVIDERS, default="prod") + parser.add_argument("--skip-check", action="store_true") + parser.add_argument( + "--cache-dtype", + choices=["float32", "bfloat16"], + default="float32", + ) args = parser.parse_args() + cache_dtype = getattr(torch, args.cache_dtype) print(f"Device: {torch.cuda.get_device_name(0)}") - print(f"H={NUM_HEADS} D={HEAD_DIM} rope_dim={ROPE_DIM} use_fp4=True\n") + print( + f"H={NUM_HEADS} D={HEAD_DIM} rope_dim={ROPE_DIM} " + f"use_fp4=True cache_dtype={args.cache_dtype}\n" + ) for num_tokens in args.tokens: - for provider in PROVIDERS: + for provider in args.providers: median_ms, p20_ms, p80_ms, moved_bytes = measure( num_tokens, provider, + args.check_provider, + cache_dtype, + args.skip_check, ) bandwidth_gb_s = moved_bytes / (median_ms * 1e-3) * 1e-9 print( - f"[{provider:4s}] T={num_tokens:6d} " + f"[{provider:8s}] T={num_tokens:6d} " f"{median_ms * 1e3:7.2f} us " f"BW {bandwidth_gb_s:7.1f} GB/s " f"(p20={p20_ms * 1e3:.2f} p80={p80_ms * 1e3:.2f} us)" diff --git a/indexer_q_mxfp4_cutedsl.py b/indexer_q_mxfp4_cutedsl.py new file mode 100644 index 000000000000..a06b5395ac9f --- /dev/null +++ b/indexer_q_mxfp4_cutedsl.py @@ -0,0 +1,460 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from functools import cache + +import cutlass +import cutlass.cute as cute +import torch +from cutlass import BFloat16, Float32, Int64, Uint8, Uint32, const_expr +from cutlass._mlir import ir +from cutlass._mlir.dialects import llvm, vector +from cutlass.cutlass_dsl import T, dsl_user_op +from quack.compile_utils import make_fake_tensor + +from vllm.vllm_flash_attn.cute import utils as cute_utils + +MXFP4_BLOCK_SIZE = 32 + +_TORCH_TO_CUTE = { + torch.bfloat16: BFloat16, + torch.float32: Float32, +} + + +@dsl_user_op +def _fp32x2_to_bf16x2(a: Float32, b: Float32, *, loc=None, ip=None) -> Uint32: + return Uint32( + llvm.inline_asm( + T.i32(), + [ + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + ], + "cvt.rn.bf16x2.f32 $0, $2, $1;", + "=r,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def _bf16x2_to_fp32(data: Uint32, *, loc=None, ip=None) -> tuple[Float32, Float32]: + out = llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32()]), + [Uint32(data).ir_value(loc=loc, ip=ip)], + "shl.b32 $0, $2, 16;\n\tand.b32 $1, $2, 0xFFFF0000;\n", + "=f,=f,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + return ( + Float32(llvm.extractvalue(T.f32(), out, [0], loc=loc, ip=ip)), + Float32(llvm.extractvalue(T.f32(), out, [1], loc=loc, ip=ip)), + ) + + +@dsl_user_op +def _bf16x2_abs(a: Uint32, *, loc=None, ip=None) -> Uint32: + return Uint32( + llvm.inline_asm( + T.i32(), + [Uint32(a).ir_value(loc=loc, ip=ip)], + "abs.bf16x2 $0, $1;", + "=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def _bf16x2_max(a: Uint32, b: Uint32, *, loc=None, ip=None) -> Uint32: + return Uint32( + llvm.inline_asm( + T.i32(), + [ + Uint32(a).ir_value(loc=loc, ip=ip), + Uint32(b).ir_value(loc=loc, ip=ip), + ], + "max.bf16x2 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def _fp32x8_to_fp4x8( + vals: cute.Tensor, + offset: cutlass.Constexpr[int], + *, + loc=None, + ip=None, +) -> Uint32: + # Pack eight scaled FP32 values into four E2M1x2 bytes, returned as one b32. + operands = [Float32(vals[offset + i]).ir_value(loc=loc, ip=ip) for i in range(8)] + return Uint32( + llvm.inline_asm( + T.i32(), + operands, + "{\n\t" + ".reg .b8 x0, x1, x2, x3;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 x0, $2, $1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 x1, $4, $3;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 x2, $6, $5;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 x3, $8, $7;\n\t" + "mov.b32 $0, {x0, x1, x2, x3};\n\t" + "}\n", + "=r,f,f,f,f,f,f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +# Custom vectorized load to support cache modifiers. For some reason, +# cute.autovec_copy() does not currently emit the requested modifiers. +# tensor and coord is only used to select the base pointer. actual load +# is done using out_dtype +@dsl_user_op +def _ldg_vec( + tensor: cute.Tensor, + coord: cute.Coord, + vec_size: cutlass.Constexpr[int], + modifier: cutlass.Constexpr[str] = "", + out_dtype: cutlass.Constexpr[type[cutlass.Numeric]] = Uint32, + *, + loc=None, + ip=None, +) -> cute.TensorSSA: + if const_expr(out_dtype is Float32): + mlir_ty = T.f32() + ptx_ty = "f32" + constraint = "=f" + elif const_expr(out_dtype is Uint32): + mlir_ty = T.i32() + ptx_ty = "b32" + constraint = "=r" + else: + raise TypeError(f"_ldg_vec only supports Uint32 and Float32, got {out_dtype}") + + # compute base pointer + base_ptr = ( + tensor.iterator + cute.crd2idx(coord, tensor.layout, loc=loc, ip=ip) + ).toint() + + # build PTX string + ptx_str = f"ld.global{modifier}.v{vec_size}.{ptx_ty}" + ptx_str += "{" + ", ".join(f"${i}" for i in range(vec_size)) + "}" + ptx_str += f", [${vec_size}];" + out = llvm.inline_asm( + llvm.StructType.get_literal([mlir_ty] * vec_size), + [Int64(base_ptr).ir_value(loc=loc, ip=ip)], + ptx_str, + ",".join([constraint] * vec_size + ["l"]), + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + vec = vector.from_elements( + ir.VectorType.get([vec_size], mlir_ty, loc=loc), + [llvm.extractvalue(mlir_ty, out, [i], loc=loc, ip=ip) for i in range(vec_size)], + loc=loc, + ip=ip, + ) + return cute.TensorSSA(vec, vec_size, out_dtype) + + +@dsl_user_op +def _stg_u32xN( + tensor: cute.Tensor, + coord: cute.Coord, + values: cute.Tensor, + vec_size: cutlass.Constexpr[int], + modifier: cutlass.Constexpr[str] = "", + *, + loc=None, + ip=None, +) -> None: + base_ptr = ( + tensor.iterator + cute.crd2idx(coord, tensor.layout, loc=loc, ip=ip) + ).toint() + value_operands = ", ".join(f"${i + 1}" for i in range(vec_size)) + llvm.inline_asm( + None, + [Int64(base_ptr).ir_value(loc=loc, ip=ip)] + + [Uint32(values[i]).ir_value(loc=loc, ip=ip) for i in range(vec_size)], + f"st.global{modifier}.v{vec_size}.u32 [$0], {{{value_operands}}};", + ",".join(["l"] + ["r"] * vec_size), + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +class IndexerQMxFp4Kernel: + """Eight-thread subwarps process one ``(token, head)`` row.""" + + def __init__( + self, + head_dim: int = 128, + rope_dim: int = 64, + num_heads: int = 64, + cos_sin_dtype: type[cutlass.Numeric] = cutlass.Float32, + ): + self.head_dim = head_dim + self.rope_dim = rope_dim + self.nope_dim = head_dim - rope_dim + self.num_heads = num_heads + self.cos_sin_dtype = cos_sin_dtype + + # later we will use 32B load = 16 BF16 elems + # thus, head_dim=128 requires 8 threads to handle. + # let's call subwarp = 8 threads. + self.subwarp_size = head_dim // 16 + self.tb_size = 256 + + @cute.jit + def __call__( + self, + positions: cute.Tensor, + q: cute.Tensor, + cos_sin_cache: cute.Tensor, + weights: cute.Tensor, + q_fp4: cute.Tensor, + q_scale: cute.Tensor, + weights_out: cute.Tensor, + scale: Float32, + ): + num_tokens, num_heads, _ = q.shape + total_threads = num_tokens * num_heads * self.subwarp_size + grid = [cute.ceil_div(total_threads, self.tb_size), 1, 1] + self.kernel( + positions, + q, + cos_sin_cache, + weights, + q_fp4, + q_scale, + weights_out, + scale, + ).launch(grid=grid, block=[self.tb_size, 1, 1]) + + @cute.kernel + def kernel( + self, + positions: cute.Tensor, + q: cute.Tensor, + cos_sin_cache: cute.Tensor, + weights: cute.Tensor, + q_fp4: cute.Tensor, + q_scale: cute.Tensor, + weights_out: cute.Tensor, + scale: Float32, + ): + block_id, _, _ = cute.arch.block_idx() + tidx, _, _ = cute.arch.thread_idx() + + num_token_heads = q.shape[0] * self.num_heads + global_tid = block_id * self.tb_size + tidx + + global_subwarp_id = global_tid // self.subwarp_size + sublane = tidx % self.subwarp_size + + token_id = global_subwarp_id // self.num_heads + head_id = global_subwarp_id - token_id * self.num_heads + + # each thread loads 16 BF16 elems + elem_base = sublane * 16 + + # q layout: [num_tokens, num_heads, head_dim] + _q_bits = _ldg_vec( + q, (token_id, head_id, elem_base), 8, ".relaxed.cta.L1::no_allocate" + ) + q_bits = cute.make_rmem_tensor(8, Uint32) + q_bits.store(_q_bits) # copy to make it mutable + + # RoPE applies only to the trailing rope_dim values. We keep the rounded + # BF16 result in q_bits so the later amax and quantization see BF16. + # cos_sin_cache layout: [max_pos, rope_dim] + if elem_base >= self.nope_dim: + pos = positions[token_id] + rope_idx = (elem_base - self.nope_dim) // 2 + if const_expr(self.cos_sin_dtype is Float32): + cos_vals = _ldg_vec( + cos_sin_cache, + (pos, rope_idx), + 8, + out_dtype=Float32, + ) + sin_vals = _ldg_vec( + cos_sin_cache, + (pos, self.nope_dim // 2 + rope_idx), + 8, + out_dtype=Float32, + ) + else: + # Each BF16 cache load lane contains two adjacent values. + cos_loaded = _ldg_vec(cos_sin_cache, (pos, rope_idx), 4) + sin_loaded = _ldg_vec( + cos_sin_cache, + (pos, self.rope_dim // 2 + rope_idx), + 4, + ) + cos_vals = cute.make_rmem_tensor(8, Float32) + sin_vals = cute.make_rmem_tensor(8, Float32) + for i in cutlass.range_constexpr(4): + cos_vals[i * 2], cos_vals[i * 2 + 1] = _bf16x2_to_fp32( + cos_loaded[i] + ) + sin_vals[i * 2], sin_vals[i * 2 + 1] = _bf16x2_to_fp32( + sin_loaded[i] + ) + + for i in cutlass.range_constexpr(8): + q0, q1 = _bf16x2_to_fp32(q_bits[i]) + cos = cos_vals[i] + sin = sin_vals[i] + rot0 = q0 * cos - q1 * sin + rot1 = q0 * sin + q1 * cos + # convert back to BF16 to match numerics + q_bits[i] = _fp32x2_to_bf16x2(rot0, rot1) + + # Each thread holds 16 elems. Two adjacent threads form one 32-elem + # MXFP4 block, so a width-2 shuffle gives the block amax. + local_amax = _bf16x2_abs(q_bits[0]) + for i in cutlass.range_constexpr(1, 8): + local_amax = _bf16x2_max(local_amax, _bf16x2_abs(q_bits[i])) + amax_bits = cute_utils.warp_reduce( + local_amax, _bf16x2_max, width=MXFP4_BLOCK_SIZE // 16 + ) + amax0, amax1 = _bf16x2_to_fp32(amax_bits) + amax = cute_utils.fmax(amax0, amax1) + + fp4_scale = cute_utils.fmax(amax, float.fromhex("0x6p-126")) * (1.0 / 6.0) + bits = Uint32(llvm.bitcast(T.i32(), fp4_scale.ir_value())) + # UE8M0 stores ceil(log2(fp4_scale)) + 127. Adding the mantissa mask + # increments the exponent whenever fp4_scale is not exactly a power of 2. + ue8m0 = cute_utils.shr_u32(bits + Uint32(0x7FFFFF), Uint32(23)) & Uint32(0xFF) + + # Only one of the two threads in an MXFP4 block writes the shared scale. + if tidx % 2 == 0: + mx_block = sublane // (MXFP4_BLOCK_SIZE // 16) + q_scale[token_id, head_id, mx_block] = Uint8(ue8m0) + + # If scale = 2^A and ue8m0 = A + 127, then inverse scale has exponent + # -A + 127 = 254 - ue8m0. + inv_scale_bits = (Uint32(254) - ue8m0) << Uint32(23) + inv_fp4_scale = Float32(llvm.bitcast(T.f32(), inv_scale_bits.ir_value())) + + vals = cute.make_rmem_tensor(16, Float32) + for i in cutlass.range_constexpr(8): + vals[i * 2], vals[i * 2 + 1] = _bf16x2_to_fp32(q_bits[i]) + vals[i * 2] = vals[i * 2] * inv_fp4_scale + vals[i * 2 + 1] = vals[i * 2 + 1] * inv_fp4_scale + + # pack to FP4 + packed = cute.make_rmem_tensor(2, Uint32) + packed[0] = _fp32x8_to_fp4x8(vals, 0) + packed[1] = _fp32x8_to_fp4x8(vals, 8) + # Each thread writes the eight packed bytes corresponding to its 16 Q values. + _stg_u32xN(q_fp4, (token_id, head_id, elem_base // 2), packed, 2, ".cs") + + # Weight scaling is independent of the Q subwarp work. The first + # num_tokens * num_heads logical threads cover one weight each. + if global_tid < num_token_heads: + weight_token_id = global_tid // self.num_heads + weight_head_id = global_tid - weight_token_id * self.num_heads + weights_out[weight_token_id, weight_head_id] = ( + weights[weight_token_id, weight_head_id].to(Float32) * scale + ) + + +@cache +def _compile_indexer_q_mxfp4( + head_dim: int, rope_dim: int, num_heads: int, cos_sin_dtype: type[cutlass.Numeric] +): + num_tokens = cute.sym_int() + max_pos = cute.sym_int() + + q = make_fake_tensor(BFloat16, (num_tokens, num_heads, head_dim), divisibility=8) + positions = make_fake_tensor(Int64, (num_tokens,), divisibility=1) + cos_div = math.gcd(128 // cos_sin_dtype.width, rope_dim) + cos_sin_cache = make_fake_tensor( + cos_sin_dtype, (max_pos, rope_dim), divisibility=cos_div + ) + weights = make_fake_tensor(BFloat16, (num_tokens, num_heads), divisibility=8) + q_fp4 = make_fake_tensor( + Uint8, (num_tokens, num_heads, head_dim // 2), divisibility=16 + ) + q_scale = make_fake_tensor( + Uint8, + (num_tokens, num_heads, head_dim // MXFP4_BLOCK_SIZE), + divisibility=4, + ) + weights_out = make_fake_tensor(Float32, (num_tokens, num_heads), divisibility=4) + + kernel = IndexerQMxFp4Kernel(head_dim, rope_dim, num_heads, cos_sin_dtype) + return cute.compile( + kernel, + positions, + q, + cos_sin_cache, + weights, + q_fp4, + q_scale, + weights_out, + Float32(0.0), + options="--enable-tvm-ffi", + ) + + +def fused_indexer_q_rope_quant( + positions: torch.Tensor, + q: torch.Tensor, + cos_sin_cache: torch.Tensor, + weights: torch.Tensor, + softmax_scale: float, + head_scale: float, + use_fp4: bool = False, +) -> tuple[torch.Tensor | tuple[torch.Tensor, torch.Tensor], torch.Tensor]: + if not use_fp4: + raise NotImplementedError( + "indexer_q_mxfp4_cutedsl only implements use_fp4=True" + ) + + num_tokens, num_heads, head_dim = q.shape + rope_dim = cos_sin_cache.shape[-1] + q_fp4 = q.new_empty((num_tokens, num_heads, head_dim // 2), dtype=torch.uint8) + q_scale = q.new_empty( + (num_tokens, num_heads, head_dim // MXFP4_BLOCK_SIZE), dtype=torch.uint8 + ) + weights_out = torch.empty_like(weights, dtype=torch.float32) + + compiled = _compile_indexer_q_mxfp4( + head_dim, + rope_dim, + num_heads, + _TORCH_TO_CUTE[cos_sin_cache.dtype], + ) + scale = float(softmax_scale * head_scale) + compiled( + positions, + q, + cos_sin_cache, + weights, + q_fp4, + q_scale, + weights_out, + scale, + ) + return (q_fp4, q_scale.view(torch.int32).squeeze(-1)), weights_out From 53bc65a446456551af1c6277b6be755e7eedab7e Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 30 Apr 2026 23:24:47 +0000 Subject: [PATCH 04/21] replace triton with cutedsl Signed-off-by: Thien Tran --- benchmark_fused_indexer_q.py | 110 +--- .../ops/deepseek_v4_ops/fused_indexer_q.py | 552 +++++++++++++----- 2 files changed, 447 insertions(+), 215 deletions(-) diff --git a/benchmark_fused_indexer_q.py b/benchmark_fused_indexer_q.py index 36798f40933c..366338049d55 100644 --- a/benchmark_fused_indexer_q.py +++ b/benchmark_fused_indexer_q.py @@ -1,16 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Benchmark the FP4 fused_indexer_q_rope_quant path.""" - -import argparse - import torch -from indexer_q_mxfp4 import fused_indexer_q_rope_quant as cuda_cpp_impl -from indexer_q_mxfp4_cutedsl import fused_indexer_q_rope_quant as cutedsl_impl from vllm.triton_utils import triton from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import ( - fused_indexer_q_rope_quant as prod_impl, + fused_indexer_q_rope_quant, ) NUM_HEADS = 64 @@ -18,21 +12,15 @@ ROPE_DIM = 64 MAX_POS = 100_000 TOKENS = [1, 8, 32, 128, 256, 512, 1024, 2048, 4096, 8192] -QUANTILES = [0.5, 0.2, 0.8] -PROVIDERS = { - "cuda_cpp": cuda_cpp_impl, - "cutedsl": cutedsl_impl, - "prod": prod_impl, -} +ROPE_DTYPE = torch.float32 -def measure(num_tokens, provider, check_provider, cache_dtype, skip_check): - torch.set_default_device("cuda") +def make_inputs(num_tokens: int): positions = torch.randint(MAX_POS, (num_tokens,), dtype=torch.int64) query = torch.randn(num_tokens, NUM_HEADS, HEAD_DIM, dtype=torch.bfloat16) - cos_sin_cache = torch.randn(MAX_POS, ROPE_DIM, dtype=cache_dtype) + cos_sin_cache = torch.randn(MAX_POS, ROPE_DIM, dtype=ROPE_DTYPE) weights = torch.randn(num_tokens, NUM_HEADS, dtype=torch.bfloat16) - kernel_args = ( + return ( positions, query, cos_sin_cache, @@ -42,85 +30,39 @@ def measure(num_tokens, provider, check_provider, cache_dtype, skip_check): True, ) - selected_impl = PROVIDERS[provider] - selected_q, selected_weights = selected_impl(*kernel_args) - torch.accelerator.synchronize() - if not skip_check and provider != check_provider: - ref_q, ref_weights = PROVIDERS[check_provider](*kernel_args) - ref_q_packed, ref_q_scale = ref_q - selected_q_packed, selected_q_scale = selected_q +def benchmark(num_tokens: int): + torch.set_default_device("cuda") - assert torch.equal(ref_q_packed, selected_q_packed), ( - f"q packed mismatch for provider={provider} " - f"num_tokens={num_tokens} cache_dtype={cache_dtype}" - ) - assert torch.equal(ref_q_scale, selected_q_scale), ( - f"q scale mismatch for provider={provider} " - f"num_tokens={num_tokens} cache_dtype={cache_dtype}" - ) - assert torch.equal(ref_weights, selected_weights), ( - f"weights mismatch for provider={provider} " - f"num_tokens={num_tokens} cache_dtype={cache_dtype}" - ) - torch.accelerator.synchronize() + # run multiple times per measurement for more reliable results + # separate sets of inputs to avoid L2 cache + N = 10 + inputs_list = [make_inputs(num_tokens) for _ in range(N)] - benchmark_fn = lambda: selected_impl(*kernel_args) - median_ms, p20_ms, p80_ms = triton.testing.do_bench( - benchmark_fn, - quantiles=QUANTILES, - ) + def f(): + for kernel_args in inputs_list: + fused_indexer_q_rope_quant(*kernel_args) + + median_ms = triton.testing.do_bench(f) / N bytes_per_token = 8 # position int64 bytes_per_token += NUM_HEADS * HEAD_DIM * 2 # q in bf16 - bytes_per_token += ROPE_DIM * torch.empty((), dtype=cache_dtype).element_size() + bytes_per_token += ROPE_DIM * torch.empty((), dtype=ROPE_DTYPE).element_size() bytes_per_token += NUM_HEADS * 2 # weights in bf16 bytes_per_token += NUM_HEADS * HEAD_DIM // 2 # q out fp4 bytes_per_token += NUM_HEADS * HEAD_DIM // 32 # q_scale uint8 bytes_per_token += NUM_HEADS * 4 # weights out fp32 total_bytes = bytes_per_token * num_tokens - return median_ms, p20_ms, p80_ms, total_bytes + return median_ms, total_bytes if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--tokens", type=int, nargs="+", default=TOKENS) - parser.add_argument( - "--providers", - choices=PROVIDERS, - nargs="+", - default=PROVIDERS.keys(), - ) - parser.add_argument("--check-provider", choices=PROVIDERS, default="prod") - parser.add_argument("--skip-check", action="store_true") - parser.add_argument( - "--cache-dtype", - choices=["float32", "bfloat16"], - default="float32", - ) - args = parser.parse_args() - - cache_dtype = getattr(torch, args.cache_dtype) - print(f"Device: {torch.cuda.get_device_name(0)}") - print( - f"H={NUM_HEADS} D={HEAD_DIM} rope_dim={ROPE_DIM} " - f"use_fp4=True cache_dtype={args.cache_dtype}\n" - ) - - for num_tokens in args.tokens: - for provider in args.providers: - median_ms, p20_ms, p80_ms, moved_bytes = measure( - num_tokens, - provider, - args.check_provider, - cache_dtype, - args.skip_check, - ) - bandwidth_gb_s = moved_bytes / (median_ms * 1e-3) * 1e-9 - print( - f"[{provider:8s}] T={num_tokens:6d} " - f"{median_ms * 1e3:7.2f} us " - f"BW {bandwidth_gb_s:7.1f} GB/s " - f"(p20={p20_ms * 1e3:.2f} p80={p80_ms * 1e3:.2f} us)" - ) + for num_tokens in TOKENS: + median_ms, moved_bytes = benchmark(num_tokens) + bandwidth_gb_s = moved_bytes / (median_ms * 1e-3) * 1e-9 + print( + f"T={num_tokens:6d} " + f"{median_ms * 1e3:7.2f} us " + f"BW {bandwidth_gb_s:7.1f} GB/s " + ) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py index f94fc013f5c6..c2e981f41363 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py @@ -1,12 +1,27 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from functools import cache + +import cutlass +import cutlass.cute as cute import torch +from cutlass import BFloat16, Float32, Int64, Uint8, Uint32, const_expr +from cutlass._mlir import ir +from cutlass._mlir.dialects import llvm, vector +from cutlass.cutlass_dsl import T, dsl_user_op +from quack.compile_utils import make_fake_tensor from vllm.triton_utils import tl, triton +from vllm.vllm_flash_attn.cute import utils as cute_utils # MXFP4: 32 elements per block, packed 2 nibbles per byte, ue8m0 block scale. MXFP4_BLOCK_SIZE = 32 +_TORCH_TO_CUTE = { + torch.bfloat16: BFloat16, + torch.float32: Float32, +} + @triton.jit def _get_cos_sin( @@ -169,116 +184,6 @@ def _fused_indexer_q_rope_quant_kernel( ) -@triton.jit -def _fused_indexer_q_rope_mxfp4_kernel( - pos_ptr, - # Index Q RoPE input (fp/bf16) - index_q_ptr, - index_q_stride0, - index_q_stride1, - index_q_cos_sin_ptr, - index_q_cos_sin_stride, - INDEX_Q_HALF_ROT_DIM: tl.constexpr, - # MXFP4 Q outputs - index_q_mxfp4_ptr, # uint8, (T, H, HEAD_DIM // 2) - index_q_mxfp4_stride0, - index_q_mxfp4_stride1, - index_q_scale_ptr, # uint8 ue8m0, (T, H, HEAD_DIM // BLOCK) - index_q_scale_stride0, - index_q_scale_stride1, - INDEX_Q_HEAD_DIM: tl.constexpr, - MXFP4_BLOCK: tl.constexpr, - # Weights (NO per-token q_scale fold for MXFP4; per-block scales stay - # with the Q values in the output scale tensor). - index_weights_ptr, - index_weights_stride, - index_weights_softmax_scale, - index_weights_head_scale, - index_weights_out_ptr, - index_weights_out_stride, -): - INDEX_Q_ROT_DIM: tl.constexpr = 2 * INDEX_Q_HALF_ROT_DIM - INDEX_Q_NOPE_DIM: tl.constexpr = INDEX_Q_HEAD_DIM - INDEX_Q_ROT_DIM - NUM_NOPE_BLOCKS: tl.constexpr = INDEX_Q_NOPE_DIM // MXFP4_BLOCK - NUM_ROPE_BLOCKS: tl.constexpr = INDEX_Q_ROT_DIM // MXFP4_BLOCK - HALF_BLOCK: tl.constexpr = MXFP4_BLOCK // 2 - tl.static_assert(INDEX_Q_NOPE_DIM >= 0) - tl.static_assert(INDEX_Q_NOPE_DIM % MXFP4_BLOCK == 0) - tl.static_assert(INDEX_Q_ROT_DIM % MXFP4_BLOCK == 0) - tl.static_assert(MXFP4_BLOCK % 2 == 0) - - tok_idx = tl.program_id(0) - head_idx = tl.program_id(1) - - pos = tl.load(pos_ptr + tok_idx) - - q_base = index_q_ptr + tok_idx * index_q_stride0 + head_idx * index_q_stride1 - out_base = ( - index_q_mxfp4_ptr - + tok_idx * index_q_mxfp4_stride0 - + head_idx * index_q_mxfp4_stride1 - ) - scale_base = ( - index_q_scale_ptr - + tok_idx * index_q_scale_stride0 - + head_idx * index_q_scale_stride1 - ) - - half_off = tl.arange(0, HALF_BLOCK) - - # ---- NoPE blocks: direct load, pair as (even-index, odd-index) values ---- - for b in tl.static_range(NUM_NOPE_BLOCKS): - base = b * MXFP4_BLOCK - x_lo = tl.load(q_base + base + half_off * 2).to(tl.float32) - x_hi = tl.load(q_base + base + half_off * 2 + 1).to(tl.float32) - packed, ue8m0 = _quantize_mxfp4_pair(x_lo, x_hi) - tl.store(out_base + base // 2 + half_off, packed) - tl.store(scale_base + b, ue8m0) - - # ---- RoPE blocks: apply GPT-J interleaved RoPE to the block's 16 pairs, - # then quantize. Each block covers HALF_BLOCK (=16) cos/sin pairs. ---- - rot_q_base = q_base + INDEX_Q_NOPE_DIM - for b in tl.static_range(NUM_ROPE_BLOCKS): - pair_off = b * HALF_BLOCK + half_off # indices in [0, HALF_ROT_DIM) - cos_b = tl.load( - index_q_cos_sin_ptr + pos * index_q_cos_sin_stride + pair_off - ).to(tl.float32) - sin_b = tl.load( - index_q_cos_sin_ptr - + pos * index_q_cos_sin_stride - + pair_off - + INDEX_Q_HALF_ROT_DIM - ).to(tl.float32) - x_even = tl.load(rot_q_base + pair_off * 2).to(tl.float32) - x_odd = tl.load(rot_q_base + pair_off * 2 + 1).to(tl.float32) - r_even = x_even * cos_b - x_odd * sin_b - r_odd = x_odd * cos_b + x_even * sin_b - # bf16 roundtrip for parity with the FP8 kernel / reference numerics. - r_even = r_even.to(tl.bfloat16).to(tl.float32) - r_odd = r_odd.to(tl.bfloat16).to(tl.float32) - packed, ue8m0 = _quantize_mxfp4_pair(r_even, r_odd) - rope_byte_off = (INDEX_Q_NOPE_DIM + b * MXFP4_BLOCK) // 2 - tl.store(out_base + rope_byte_off + half_off, packed) - tl.store(scale_base + NUM_NOPE_BLOCKS + b, ue8m0) - - # MXFP4 weight-fold contract: - # index_weights_out = index_weights * softmax_scale * head_scale - # NOTE: q_scale is NOT folded here (contrast with the FP8 kernel above). - # MXFP4 Q emits a separate ue8m0 scale tensor of shape - # (T, H, HEAD_DIM // MXFP4_BLOCK) alongside the packed values, so each - # per-block scale is applied by the downstream MXFP4 logits kernel when - # dequantizing Q — there is no per-token scalar to fold into `weights`. - index_weights = tl.load( - index_weights_ptr + tok_idx * index_weights_stride + head_idx - ).to(tl.float32) - index_weights *= index_weights_softmax_scale - index_weights *= index_weights_head_scale - tl.store( - index_weights_out_ptr + tok_idx * index_weights_out_stride + head_idx, - index_weights, - ) - - def fused_indexer_q_rope_quant( positions: torch.Tensor, index_q: torch.Tensor, @@ -332,39 +237,30 @@ def fused_indexer_q_rope_quant( f"size {MXFP4_BLOCK_SIZE}" ) num_scale_blocks = index_q_head_dim // MXFP4_BLOCK_SIZE - index_q_packed = torch.empty( + index_q_packed = index_q.new_empty( (num_tokens, num_index_q_heads, index_q_head_dim // 2), dtype=torch.uint8, - device=index_q.device, ) - index_q_scale = torch.empty( + index_q_scale = index_q.new_empty( (num_tokens, num_index_q_heads, num_scale_blocks), dtype=torch.uint8, - device=index_q.device, ) - _fused_indexer_q_rope_mxfp4_kernel[(num_tokens, num_index_q_heads)]( + compiled = _compile_indexer_q_mxfp4( + index_q_head_dim, + index_q_cos_sin_cache.shape[-1], + num_index_q_heads, + _TORCH_TO_CUTE[index_q_cos_sin_cache.dtype], + ) + scale = float(index_weights_softmax_scale * index_weights_head_scale) + compiled( positions, index_q, - index_q.stride(0), - index_q.stride(1), index_q_cos_sin_cache, - index_q_cos_sin_cache.stride(0), - index_q_cos_sin_cache.shape[-1] // 2, + index_weights, index_q_packed, - index_q_packed.stride(0), - index_q_packed.stride(1), index_q_scale, - index_q_scale.stride(0), - index_q_scale.stride(1), - index_q_head_dim, - MXFP4_BLOCK_SIZE, - index_weights, - index_weights.stride(0), - index_weights_softmax_scale, - index_weights_head_scale, index_weights_out, - index_weights_out.stride(0), - num_warps=1, # TODO: Tune this + scale, ) # Values stay uint8 (2 E2M1 nibbles per byte). Scales are 4 ue8m0 # bytes per (token, head) reinterpreted as one int32, then squeezed @@ -398,3 +294,397 @@ def fused_indexer_q_rope_quant( num_warps=1, # TODO: Tune this ) return index_q_fp8, index_weights_out + + +@dsl_user_op +def _fp32x2_to_bf16x2(a: Float32, b: Float32, *, loc=None, ip=None) -> Uint32: + return Uint32( + llvm.inline_asm( + T.i32(), + [ + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + ], + "cvt.rn.bf16x2.f32 $0, $2, $1;", + "=r,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def _bf16x2_to_fp32(data: Uint32, *, loc=None, ip=None) -> tuple[Float32, Float32]: + out = llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32()]), + [Uint32(data).ir_value(loc=loc, ip=ip)], + "shl.b32 $0, $2, 16;\n\tand.b32 $1, $2, 0xFFFF0000;\n", + "=f,=f,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + return ( + Float32(llvm.extractvalue(T.f32(), out, [0], loc=loc, ip=ip)), + Float32(llvm.extractvalue(T.f32(), out, [1], loc=loc, ip=ip)), + ) + + +@dsl_user_op +def _bf16x2_abs(a: Uint32, *, loc=None, ip=None) -> Uint32: + return Uint32( + llvm.inline_asm( + T.i32(), + [Uint32(a).ir_value(loc=loc, ip=ip)], + "abs.bf16x2 $0, $1;", + "=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def _bf16x2_max(a: Uint32, b: Uint32, *, loc=None, ip=None) -> Uint32: + return Uint32( + llvm.inline_asm( + T.i32(), + [ + Uint32(a).ir_value(loc=loc, ip=ip), + Uint32(b).ir_value(loc=loc, ip=ip), + ], + "max.bf16x2 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def _fp32x8_to_fp4x8( + vals: cute.Tensor, + offset: cutlass.Constexpr[int], + *, + loc=None, + ip=None, +) -> Uint32: + # Pack eight scaled FP32 values into four E2M1x2 bytes, returned as one b32. + operands = [Float32(vals[offset + i]).ir_value(loc=loc, ip=ip) for i in range(8)] + return Uint32( + llvm.inline_asm( + T.i32(), + operands, + "{\n\t" + ".reg .b8 x0, x1, x2, x3;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 x0, $2, $1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 x1, $4, $3;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 x2, $6, $5;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 x3, $8, $7;\n\t" + "mov.b32 $0, {x0, x1, x2, x3};\n\t" + "}\n", + "=r,f,f,f,f,f,f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +# Custom vectorized load to support cache modifiers. For some reason, +# cute.autovec_copy() does not currently emit the requested modifiers. +# tensor and coord is only used to select the base pointer. actual load +# is done using out_dtype +@dsl_user_op +def _ldg_vec( + tensor: cute.Tensor, + coord: cute.Coord, + vec_size: cutlass.Constexpr[int], + modifier: cutlass.Constexpr[str] = "", + out_dtype: cutlass.Constexpr[type[cutlass.Numeric]] = Uint32, + *, + loc=None, + ip=None, +) -> cute.TensorSSA: + if const_expr(out_dtype is Float32): + mlir_ty = T.f32() + ptx_ty = "f32" + constraint = "=f" + elif const_expr(out_dtype is Uint32): + mlir_ty = T.i32() + ptx_ty = "b32" + constraint = "=r" + else: + raise TypeError(f"_ldg_vec only supports Uint32 and Float32, got {out_dtype}") + + # compute base pointer + base_ptr = ( + tensor.iterator + cute.crd2idx(coord, tensor.layout, loc=loc, ip=ip) + ).toint() + + # build PTX string + ptx_str = f"ld.global{modifier}.v{vec_size}.{ptx_ty}" + ptx_str += "{" + ", ".join(f"${i}" for i in range(vec_size)) + "}" + ptx_str += f", [${vec_size}];" + out = llvm.inline_asm( + llvm.StructType.get_literal([mlir_ty] * vec_size), + [Int64(base_ptr).ir_value(loc=loc, ip=ip)], + ptx_str, + ",".join([constraint] * vec_size + ["l"]), + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + vec = vector.from_elements( + ir.VectorType.get([vec_size], mlir_ty, loc=loc), + [llvm.extractvalue(mlir_ty, out, [i], loc=loc, ip=ip) for i in range(vec_size)], + loc=loc, + ip=ip, + ) + return cute.TensorSSA(vec, vec_size, out_dtype) + + +@dsl_user_op +def _stg_u32xN( + tensor: cute.Tensor, + coord: cute.Coord, + values: cute.Tensor, + vec_size: cutlass.Constexpr[int], + modifier: cutlass.Constexpr[str] = "", + *, + loc=None, + ip=None, +) -> None: + base_ptr = ( + tensor.iterator + cute.crd2idx(coord, tensor.layout, loc=loc, ip=ip) + ).toint() + value_operands = ", ".join(f"${i + 1}" for i in range(vec_size)) + llvm.inline_asm( + None, + [Int64(base_ptr).ir_value(loc=loc, ip=ip)] + + [Uint32(values[i]).ir_value(loc=loc, ip=ip) for i in range(vec_size)], + f"st.global{modifier}.v{vec_size}.u32 [$0], {{{value_operands}}};", + ",".join(["l"] + ["r"] * vec_size), + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +class IndexerQMxFp4Kernel: + """Eight-thread subwarps process one ``(token, head)`` row.""" + + def __init__( + self, + head_dim: int = 128, + rope_dim: int = 64, + num_heads: int = 64, + cos_sin_dtype: type[cutlass.Numeric] = cutlass.Float32, + ): + self.head_dim = head_dim + self.rope_dim = rope_dim + self.nope_dim = head_dim - rope_dim + self.num_heads = num_heads + self.cos_sin_dtype = cos_sin_dtype + + # later we will use 32B load = 16 BF16 elems + # thus, head_dim=128 requires 8 threads to handle. + # let's call subwarp = 8 threads. + self.subwarp_size = head_dim // 16 + self.tb_size = 256 + + @cute.jit + def __call__( + self, + positions: cute.Tensor, + q: cute.Tensor, + cos_sin_cache: cute.Tensor, + weights: cute.Tensor, + q_fp4: cute.Tensor, + q_scale: cute.Tensor, + weights_out: cute.Tensor, + scale: Float32, + ): + num_tokens, num_heads, _ = q.shape + total_threads = num_tokens * num_heads * self.subwarp_size + grid = [cute.ceil_div(total_threads, self.tb_size), 1, 1] + self.kernel( + positions, + q, + cos_sin_cache, + weights, + q_fp4, + q_scale, + weights_out, + scale, + ).launch(grid=grid, block=[self.tb_size, 1, 1]) + + @cute.kernel + def kernel( + self, + positions: cute.Tensor, + q: cute.Tensor, + cos_sin_cache: cute.Tensor, + weights: cute.Tensor, + q_fp4: cute.Tensor, + q_scale: cute.Tensor, + weights_out: cute.Tensor, + scale: Float32, + ): + block_id, _, _ = cute.arch.block_idx() + tidx, _, _ = cute.arch.thread_idx() + + num_token_heads = q.shape[0] * self.num_heads + global_tid = block_id * self.tb_size + tidx + + global_subwarp_id = global_tid // self.subwarp_size + sublane = tidx % self.subwarp_size + + token_id = global_subwarp_id // self.num_heads + head_id = global_subwarp_id - token_id * self.num_heads + + # each thread loads 16 BF16 elems + elem_base = sublane * 16 + + # q layout: [num_tokens, num_heads, head_dim] + _q_bits = _ldg_vec( + q, (token_id, head_id, elem_base), 8, ".relaxed.cta.L1::no_allocate" + ) + q_bits = cute.make_rmem_tensor(8, Uint32) + q_bits.store(_q_bits) # copy to make it mutable + + # RoPE applies only to the trailing rope_dim values. We keep the rounded + # BF16 result in q_bits so the later amax and quantization see BF16. + # cos_sin_cache layout: [max_pos, rope_dim] + if elem_base >= self.nope_dim: + pos = positions[token_id] + rope_idx = (elem_base - self.nope_dim) // 2 + if const_expr(self.cos_sin_dtype is Float32): + cos_vals = _ldg_vec( + cos_sin_cache, + (pos, rope_idx), + 8, + out_dtype=Float32, + ) + sin_vals = _ldg_vec( + cos_sin_cache, + (pos, self.nope_dim // 2 + rope_idx), + 8, + out_dtype=Float32, + ) + else: + # Each BF16 cache load lane contains two adjacent values. + cos_loaded = _ldg_vec(cos_sin_cache, (pos, rope_idx), 4) + sin_loaded = _ldg_vec( + cos_sin_cache, + (pos, self.rope_dim // 2 + rope_idx), + 4, + ) + cos_vals = cute.make_rmem_tensor(8, Float32) + sin_vals = cute.make_rmem_tensor(8, Float32) + for i in cutlass.range_constexpr(4): + cos_vals[i * 2], cos_vals[i * 2 + 1] = _bf16x2_to_fp32( + cos_loaded[i] + ) + sin_vals[i * 2], sin_vals[i * 2 + 1] = _bf16x2_to_fp32( + sin_loaded[i] + ) + + for i in cutlass.range_constexpr(8): + q0, q1 = _bf16x2_to_fp32(q_bits[i]) + cos = cos_vals[i] + sin = sin_vals[i] + rot0 = q0 * cos - q1 * sin + rot1 = q0 * sin + q1 * cos + # convert back to BF16 to match numerics + q_bits[i] = _fp32x2_to_bf16x2(rot0, rot1) + + # compute amax in packed bf16x2 to save instructions + # Each thread holds 16 elems. Two adjacent threads form one 32-elem + # MXFP4 block, so a width-2 shuffle gives the block amax. + local_amax = _bf16x2_abs(q_bits[0]) + for i in cutlass.range_constexpr(1, 8): + local_amax = _bf16x2_max(local_amax, _bf16x2_abs(q_bits[i])) + amax_bits = cute_utils.warp_reduce( + local_amax, _bf16x2_max, width=MXFP4_BLOCK_SIZE // 16 + ) + amax0, amax1 = _bf16x2_to_fp32(amax_bits) + amax = cute_utils.fmax(amax0, amax1) + + # compute block scale with bit manipulation + # UE8M0 stores ceil(log2(fp4_scale)) + 127. Adding the mantissa mask + # increments the exponent whenever fp4_scale is not exactly a power of 2. + fp4_scale = cute_utils.fmax(amax, float.fromhex("0x6p-126")) * (1.0 / 6.0) + bits = Uint32(llvm.bitcast(T.i32(), fp4_scale.ir_value())) + ue8m0 = cute_utils.shr_u32(bits + Uint32(0x7FFFFF), Uint32(23)) & Uint32(0xFF) + + # Only one of the two threads in an MXFP4 block writes the shared scale. + if tidx % 2 == 0: + mx_block = sublane // (MXFP4_BLOCK_SIZE // 16) + q_scale[token_id, head_id, mx_block] = Uint8(ue8m0) + + # If scale = 2^A and ue8m0 = A + 127, then inverse scale has exponent + # -A + 127 = 254 - ue8m0. + inv_scale_bits = (Uint32(254) - ue8m0) << Uint32(23) + inv_fp4_scale = Float32(llvm.bitcast(T.f32(), inv_scale_bits.ir_value())) + + vals = cute.make_rmem_tensor(16, Float32) + for i in cutlass.range_constexpr(8): + vals[i * 2], vals[i * 2 + 1] = _bf16x2_to_fp32(q_bits[i]) + vals[i * 2] = vals[i * 2] * inv_fp4_scale + vals[i * 2 + 1] = vals[i * 2 + 1] * inv_fp4_scale + + # pack to FP4 + packed = cute.make_rmem_tensor(2, Uint32) + packed[0] = _fp32x8_to_fp4x8(vals, 0) + packed[1] = _fp32x8_to_fp4x8(vals, 8) + # Each thread writes the eight packed bytes corresponding to its 16 Q values. + _stg_u32xN(q_fp4, (token_id, head_id, elem_base // 2), packed, 2, ".cs") + + # Weight scaling is independent of the Q subwarp work. The first + # num_tokens * num_heads logical threads cover one weight each. + if global_tid < num_token_heads: + weight_token_id = global_tid // self.num_heads + weight_head_id = global_tid - weight_token_id * self.num_heads + weights_out[weight_token_id, weight_head_id] = ( + weights[weight_token_id, weight_head_id].to(Float32) * scale + ) + + +@cache +def _compile_indexer_q_mxfp4( + head_dim: int, rope_dim: int, num_heads: int, cos_sin_dtype: type[cutlass.Numeric] +): + num_tokens = cute.sym_int() + max_pos = cute.sym_int() + + q = make_fake_tensor(BFloat16, (num_tokens, num_heads, head_dim), divisibility=8) + positions = make_fake_tensor(Int64, (num_tokens,), divisibility=1) + cos_sin_cache = make_fake_tensor(cos_sin_dtype, (max_pos, rope_dim), divisibility=8) + weights = make_fake_tensor(BFloat16, (num_tokens, num_heads), divisibility=8) + q_fp4 = make_fake_tensor( + Uint8, (num_tokens, num_heads, head_dim // 2), divisibility=16 + ) + q_scale = make_fake_tensor( + Uint8, + (num_tokens, num_heads, head_dim // MXFP4_BLOCK_SIZE), + divisibility=4, + ) + weights_out = make_fake_tensor(Float32, (num_tokens, num_heads), divisibility=4) + + kernel = IndexerQMxFp4Kernel(head_dim, rope_dim, num_heads, cos_sin_dtype) + return cute.compile( + kernel, + positions, + q, + cos_sin_cache, + weights, + q_fp4, + q_scale, + weights_out, + Float32(0.0), + options="--enable-tvm-ffi", + ) From 7a014c025891c0201dd60e1eaba680b0a5a578c5 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 30 Apr 2026 23:26:09 +0000 Subject: [PATCH 05/21] remove old files Signed-off-by: Thien Tran --- benchmark_fused_indexer_q.py | 68 ------ indexer_q_mxfp4.py | 371 ---------------------------- indexer_q_mxfp4_cutedsl.py | 460 ----------------------------------- 3 files changed, 899 deletions(-) delete mode 100644 benchmark_fused_indexer_q.py delete mode 100644 indexer_q_mxfp4.py delete mode 100644 indexer_q_mxfp4_cutedsl.py diff --git a/benchmark_fused_indexer_q.py b/benchmark_fused_indexer_q.py deleted file mode 100644 index 366338049d55..000000000000 --- a/benchmark_fused_indexer_q.py +++ /dev/null @@ -1,68 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import torch - -from vllm.triton_utils import triton -from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import ( - fused_indexer_q_rope_quant, -) - -NUM_HEADS = 64 -HEAD_DIM = 128 -ROPE_DIM = 64 -MAX_POS = 100_000 -TOKENS = [1, 8, 32, 128, 256, 512, 1024, 2048, 4096, 8192] -ROPE_DTYPE = torch.float32 - - -def make_inputs(num_tokens: int): - positions = torch.randint(MAX_POS, (num_tokens,), dtype=torch.int64) - query = torch.randn(num_tokens, NUM_HEADS, HEAD_DIM, dtype=torch.bfloat16) - cos_sin_cache = torch.randn(MAX_POS, ROPE_DIM, dtype=ROPE_DTYPE) - weights = torch.randn(num_tokens, NUM_HEADS, dtype=torch.bfloat16) - return ( - positions, - query, - cos_sin_cache, - weights, - HEAD_DIM**-0.5, - NUM_HEADS**-0.5, - True, - ) - - -def benchmark(num_tokens: int): - torch.set_default_device("cuda") - - # run multiple times per measurement for more reliable results - # separate sets of inputs to avoid L2 cache - N = 10 - inputs_list = [make_inputs(num_tokens) for _ in range(N)] - - def f(): - for kernel_args in inputs_list: - fused_indexer_q_rope_quant(*kernel_args) - - median_ms = triton.testing.do_bench(f) / N - - bytes_per_token = 8 # position int64 - bytes_per_token += NUM_HEADS * HEAD_DIM * 2 # q in bf16 - bytes_per_token += ROPE_DIM * torch.empty((), dtype=ROPE_DTYPE).element_size() - bytes_per_token += NUM_HEADS * 2 # weights in bf16 - bytes_per_token += NUM_HEADS * HEAD_DIM // 2 # q out fp4 - bytes_per_token += NUM_HEADS * HEAD_DIM // 32 # q_scale uint8 - bytes_per_token += NUM_HEADS * 4 # weights out fp32 - total_bytes = bytes_per_token * num_tokens - - return median_ms, total_bytes - - -if __name__ == "__main__": - for num_tokens in TOKENS: - median_ms, moved_bytes = benchmark(num_tokens) - bandwidth_gb_s = moved_bytes / (median_ms * 1e-3) * 1e-9 - print( - f"T={num_tokens:6d} " - f"{median_ms * 1e3:7.2f} us " - f"BW {bandwidth_gb_s:7.1f} GB/s " - ) diff --git a/indexer_q_mxfp4.py b/indexer_q_mxfp4.py deleted file mode 100644 index 6180a2ea0961..000000000000 --- a/indexer_q_mxfp4.py +++ /dev/null @@ -1,371 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch -from torch.utils.cpp_extension import load_inline - -CUDA_SRC = r""" -#include -#include -#include - -constexpr int WARP_SIZE = 32; -constexpr int MX_BLOCK_SIZE = 32; - -__device__ inline -void ldg_f32x8(float *data, const void *ptr) { - asm volatile("ld.global.v8.f32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];" - : "=f"(data[0]), "=f"(data[1]), "=f"(data[2]), "=f"(data[3]), - "=f"(data[4]), "=f"(data[5]), "=f"(data[6]), "=f"(data[7]) - : "l"(ptr)); -} - -__device__ inline -void ldg_b32x8_fast(int *data, const void *ptr) { - asm volatile("ld.global.relaxed.cta.L1::no_allocate.v8.b32 " - "{%0, %1, %2, %3, %4, %5, %6, %7}, [%8];" - : "=r"(data[0]), "=r"(data[1]), "=r"(data[2]), "=r"(data[3]), - "=r"(data[4]), "=r"(data[5]), "=r"(data[6]), "=r"(data[7]) - : "l"(ptr)); -} - -__device__ inline -void bf16x2_to_fp32x2(float *out, uint32_t data) { - asm volatile("shl.b32 %0, %2, 16;\n" // low 16-bit - "and.b32 %1, %2, 0xFFFF0000;" // high 16-bit - : "=f"(out[0]), "=f"(out[1]) : "r"(data)); -} - -__device__ inline -int fp32x2_to_bf16x2(float a, float b) { - int tmp; - asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;" : "=r"(tmp) : "f"(b), "f"(a)); - return tmp; -} - -__device__ inline -int bf16x2_abs(int a) { - int d; - asm volatile("abs.bf16x2 %0, %1;" : "=r"(d) : "r"(a)); - return d; -} - -__device__ inline -int bf16x2_max(int a, int b) { - int d; - asm volatile("max.bf16x2 %0, %1, %2;" : "=r"(d) : "r"(a), "r"(b)); - return d; -} - -__device__ inline -int fp32x8_to_fp4x8(const float *x) { - int out; - asm volatile( - "{\n" - ".reg .b8 x0, x1, x2, x3;\n" - "cvt.rn.satfinite.e2m1x2.f32 x0, %2, %1;\n" - "cvt.rn.satfinite.e2m1x2.f32 x1, %4, %3;\n" - "cvt.rn.satfinite.e2m1x2.f32 x2, %6, %5;\n" - "cvt.rn.satfinite.e2m1x2.f32 x3, %8, %7;\n" - "mov.b32 %0, {x0, x1, x2, x3};\n" - "}" - : "=r"(out) - : "f"(x[0]), "f"(x[1]), "f"(x[2]), "f"(x[3]), - "f"(x[4]), "f"(x[5]), "f"(x[6]), "f"(x[7]) - ); - return out; -} - -template -__block_size__((TB_SIZE, 1, 1)) -__global__ -void fused_indexer_q_rope_mxfp4_kernel( - const int64_t *positions_ptr, // [num_tokens] - const nv_bfloat16 *q_ptr, // [num_tokens, num_heads, head_dim] - const float *cos_sin_ptr, // [max_pos, rope_dim] - const nv_bfloat16 *weights, // [num_tokens, num_heads] - char *q_fp4_ptr, // [num_tokens, num_heads, head_dim/2] - uint8_t *q_scale_ptr, // [num_tokens, num_heads, head_dim/32] - float *weights_out, // [num_tokens, num_heads] - float scale, - int num_tokens, - int num_heads, - int q_stride0, int q_stride1, - int cos_sin_stride, - int weights_stride, - int q_fp4_stride0, int q_fp4_stride1, - int q_scale_stride0, int q_scale_stride1, - int weights_out_stride -) { - constexpr int NOPE_DIM = HEAD_DIM - ROPE_DIM; - - // we will use 32B load per thread = 16 BF16 elems. - // hence, we need 8 threads to load single head (128 elems). - // let's call subwarp = 8 threads -> 1 subwarp handles 1 token - constexpr int SUBWARP_SIZE = HEAD_DIM / 16; - static_assert(SUBWARP_SIZE <= WARP_SIZE); - - const int tid = threadIdx.x; - const int bid = blockIdx.x; - - const int global_tid = bid * blockDim.x + tid; - const int global_subwarp_id = global_tid / SUBWARP_SIZE; - const int sublane_id = tid % SUBWARP_SIZE; - - const int token_id = global_subwarp_id / num_heads; - const int head_id = global_subwarp_id % num_heads; - - // load Q - int q[8]; - float q_f32[16]; - const int q_offset = token_id * q_stride0 + head_id * q_stride1 + sublane_id * 16; - ldg_b32x8_fast(q, q_ptr + q_offset); - int64_t pos = positions_ptr[token_id]; - - // apply rope - // NOTE: warp divergence - if (sublane_id * 16 >= NOPE_DIM) { - float cos[8], sin[8]; - const int rope_idx = (sublane_id * 16 - NOPE_DIM) / 2; - ldg_f32x8(cos, cos_sin_ptr + (pos * cos_sin_stride + rope_idx)); - ldg_f32x8(sin, cos_sin_ptr + (pos * cos_sin_stride + ROPE_DIM / 2 + rope_idx)); - - // unpack - for (int i = 0; i < 8; i++) - bf16x2_to_fp32x2(q_f32 + i * 2, q[i]); - - for (int i = 0; i < 8; i++) { - float q0 = q_f32[i * 2 + 0] * cos[i] - q_f32[i * 2 + 1] * sin[i]; - float q1 = q_f32[i * 2 + 0] * sin[i] + q_f32[i * 2 + 1] * cos[i]; - q_f32[i * 2 + 0] = q0; - q_f32[i * 2 + 1] = q1; - } - - // BF16 round-trip to match reference - for (int i = 0; i < 8; i++) - q[i] = fp32x2_to_bf16x2(q_f32[i * 2], q_f32[i * 2 + 1]); - } - - // absmax in BF16 to save instructions - int q_amax = bf16x2_abs(q[0]); - for (int i = 1; i < 8; i++) - q_amax = bf16x2_max(q_amax, bf16x2_abs(q[i])); - - // each thread holds 16 elems -> 2 threads hold 32 elems - // warp shuffle among 2 threads - constexpr int NUM_THREADS_PER_MX = MX_BLOCK_SIZE / 16; - for (int stride = NUM_THREADS_PER_MX / 2; stride > 0; stride /= 2) - q_amax = bf16x2_max(q_amax, __shfl_xor_sync(0xFFFF'FFFF, q_amax, stride)); - - // final amax in FP32 - float q_amax_f32[2]; - bf16x2_to_fp32x2(q_amax_f32, q_amax); - float amax = max(q_amax_f32[0], q_amax_f32[1]); - - constexpr float amax_eps = 0x6p-126f; // 6.0f * 2^-126 - constexpr float inv_fp4_max = 1.0f / 6.0f; - float fp4_scale = max(amax, amax_eps) * inv_fp4_max; - - // compute ceil_log2 with bit manipulation - // add a magic number so that exponent increments by 1 - // when mantissa bits > 0 - uint32_t bits = __float_as_uint(fp4_scale); - uint32_t ue8m0 = ((bits + 0x7FFFFFU) >> 23U) & 0xFFU; - - // only 1 out of 2 threads need to store SF (rmb, 2 threads = 32 elems) - if (tid % 2 == 0) { - const int q_scale_offset = token_id * q_scale_stride0 - + head_id * q_scale_stride1 - + sublane_id / 2; - q_scale_ptr[q_scale_offset] = ue8m0; - } - - // unpack - for (int i = 0; i < 8; i++) - bf16x2_to_fp32x2(q_f32 + i * 2, q[i]); - - // let A = ceil(log2(fp4_scale)) be the actual mathematical value - // fp4_scale = 2^A, and ue8m0 = A + 127, where 127 is the exponent bias - // we want 1/fp4_scale = 2^(-A), whose exponent bits = -A + 127 = 254 - ue8m0 - float inv_fp4_scale = __uint_as_float((254U - ue8m0) << 23U); - for (int i = 0; i < 16; i++) - q_f32[i] *= inv_fp4_scale; - - int2 packed_fp4; - packed_fp4.x = fp32x8_to_fp4x8(q_f32); - packed_fp4.y = fp32x8_to_fp4x8(q_f32 + 8); - const int q_fp4_offset = token_id * q_fp4_stride0 - + head_id * q_fp4_stride1 - + sublane_id * 8; - __stcs(reinterpret_cast(q_fp4_ptr + q_fp4_offset), packed_fp4); - - // scale weights - if (global_tid < num_tokens * num_heads) { - const int token_id = global_tid / num_heads; - const int head_id = global_tid % num_heads; - float w = __bfloat162float(weights[token_id * weights_stride + head_id]); - weights_out[token_id * weights_out_stride + head_id] = w * scale; - } -} - -#include -#include -#include -#include -#include - -at::Tensor fused_indexer_q_rope_mxfp4( - const at::Tensor& positions, - const at::Tensor& q, - const at::Tensor& cos_sin, - const at::Tensor& weights, - at::Tensor& q_fp4, - at::Tensor& q_scale, - at::Tensor& weights_out, - double scale) { - TORCH_CHECK(positions.is_cuda(), "positions must be CUDA"); - TORCH_CHECK(q.is_cuda(), "q must be CUDA"); - TORCH_CHECK(cos_sin.is_cuda(), "cos_sin must be CUDA"); - TORCH_CHECK(weights.is_cuda(), "weights must be CUDA"); - TORCH_CHECK(q_fp4.is_cuda(), "q_fp4 must be CUDA"); - TORCH_CHECK(q_scale.is_cuda(), "q_scale must be CUDA"); - TORCH_CHECK(weights_out.is_cuda(), "weights_out must be CUDA"); - - TORCH_CHECK(positions.scalar_type() == at::kLong, "positions must be int64"); - TORCH_CHECK(q.scalar_type() == at::kBFloat16, "q must be bfloat16"); - TORCH_CHECK(cos_sin.scalar_type() == at::kFloat, "cos_sin must be float32"); - TORCH_CHECK(weights.scalar_type() == at::kBFloat16, "weights must be bfloat16"); - TORCH_CHECK(q_fp4.scalar_type() == at::kByte, "q_fp4 must be uint8"); - TORCH_CHECK(q_scale.scalar_type() == at::kByte, "q_scale must be uint8"); - TORCH_CHECK(weights_out.scalar_type() == at::kFloat, "weights_out must be float32"); - - TORCH_CHECK(positions.dim() == 1, "positions must be rank 1"); - TORCH_CHECK(q.dim() == 3, "q must have shape [num_tokens, num_heads, 128]"); - TORCH_CHECK(cos_sin.dim() == 2, "cos_sin must have shape [max_pos, 64]"); - TORCH_CHECK(weights.dim() == 2, "weights must have shape [num_tokens, num_heads]"); - TORCH_CHECK(q.size(2) == 128, "q head_dim must be 128"); - TORCH_CHECK(cos_sin.size(1) == 64, "cos_sin rope_dim must be 64"); - - const int num_tokens = static_cast(positions.size(0)); - const int num_heads = static_cast(q.size(1)); - TORCH_CHECK(q.size(0) == num_tokens, "q and positions token counts differ"); - TORCH_CHECK(weights.size(0) == num_tokens, "weights token count differs"); - TORCH_CHECK(weights.size(1) == num_heads, "weights head count differs"); - TORCH_CHECK(q_fp4.sizes() == at::IntArrayRef({num_tokens, num_heads, 64}), - "q_fp4 must have shape [num_tokens, num_heads, 64]"); - TORCH_CHECK(q_scale.sizes() == at::IntArrayRef({num_tokens, num_heads, 4}), - "q_scale must have shape [num_tokens, num_heads, 4]"); - TORCH_CHECK(weights_out.sizes() == weights.sizes(), - "weights_out must have the same shape as weights"); - - TORCH_CHECK(num_heads % 32 == 0, - "num_heads must be divisible by 32 for this launch wrapper"); - - c10::cuda::CUDAGuard device_guard(q.device()); - constexpr int kHeadDim = 128; - constexpr int kRopeDim = 64; - constexpr int kSubwarpSize = kHeadDim / 16; - constexpr int kBlockSize = 256; - const int total_threads = num_tokens * num_heads * kSubwarpSize; - const int grid = (total_threads + kBlockSize - 1) / kBlockSize; - - fused_indexer_q_rope_mxfp4_kernel - <<>>( - positions.data_ptr(), - reinterpret_cast(q.data_ptr()), - cos_sin.data_ptr(), - reinterpret_cast(weights.data_ptr()), - reinterpret_cast(q_fp4.data_ptr()), - q_scale.data_ptr(), - weights_out.data_ptr(), - static_cast(scale), - num_tokens, - num_heads, - static_cast(q.stride(0)), - static_cast(q.stride(1)), - static_cast(cos_sin.stride(0)), - static_cast(weights.stride(0)), - static_cast(q_fp4.stride(0)), - static_cast(q_fp4.stride(1)), - static_cast(q_scale.stride(0)), - static_cast(q_scale.stride(1)), - static_cast(weights_out.stride(0))); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - - return q_fp4; -} - -TORCH_LIBRARY(indexer_q_mxfp4, m) { - m.def("fused_indexer_q_rope_mxfp4(" - "Tensor positions, Tensor q, Tensor cos_sin, Tensor weights, " - "Tensor(a!) q_fp4, Tensor(b!) q_scale, Tensor(c!) weights_out, " - "float scale) -> Tensor"); - m.impl("fused_indexer_q_rope_mxfp4", - torch::dispatch(c10::DispatchKey::CUDA, - TORCH_FN(fused_indexer_q_rope_mxfp4))); -} -""" - -HEAD_DIM = 128 -ROPE_DIM = 64 - -load_inline( - "indexer_q_mxfp4", - cpp_sources="", - cuda_sources=CUDA_SRC, - verbose=False, - is_python_module=False, - no_implicit_headers=True, - extra_cuda_cflags=[ - "-O3", - "-gencode=arch=compute_100a,code=sm_100a", - "--expt-relaxed-constexpr", - "--relocatable-device-code=false", - "-lineinfo", - "-Xptxas=-v", - ], - extra_ldflags=["-lcuda"], -) -_fused_indexer_q_rope_mxfp4 = torch.ops.indexer_q_mxfp4.fused_indexer_q_rope_mxfp4 - - -def fused_indexer_q_rope_quant( - positions: torch.Tensor, - index_q: torch.Tensor, - index_q_cos_sin_cache: torch.Tensor, - index_weights: torch.Tensor, - index_weights_softmax_scale: float, - index_weights_head_scale: float, - use_fp4: bool = False, -) -> tuple[torch.Tensor | tuple[torch.Tensor, torch.Tensor], torch.Tensor]: - if not use_fp4: - raise NotImplementedError("indexer_q_mxfp4 only implements use_fp4=True") - assert index_q.ndim == 3 and index_q.shape[-1] == HEAD_DIM - assert index_q_cos_sin_cache.ndim == 2 - assert index_q_cos_sin_cache.shape[-1] == ROPE_DIM - - num_tokens, num_heads, _ = index_q.shape - q_fp4 = torch.empty( - (num_tokens, num_heads, HEAD_DIM // 2), - dtype=torch.uint8, - device=index_q.device, - ) - q_scale = torch.empty( - (num_tokens, num_heads, HEAD_DIM // 32), - dtype=torch.uint8, - device=index_q.device, - ) - weights_out = torch.empty_like(index_weights, dtype=torch.float32) - - scale = float(index_weights_softmax_scale * index_weights_head_scale) - _fused_indexer_q_rope_mxfp4( - positions, - index_q, - index_q_cos_sin_cache, - index_weights, - q_fp4, - q_scale, - weights_out, - scale, - ) - return (q_fp4, q_scale.view(torch.int32).squeeze(-1)), weights_out diff --git a/indexer_q_mxfp4_cutedsl.py b/indexer_q_mxfp4_cutedsl.py deleted file mode 100644 index a06b5395ac9f..000000000000 --- a/indexer_q_mxfp4_cutedsl.py +++ /dev/null @@ -1,460 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import math -from functools import cache - -import cutlass -import cutlass.cute as cute -import torch -from cutlass import BFloat16, Float32, Int64, Uint8, Uint32, const_expr -from cutlass._mlir import ir -from cutlass._mlir.dialects import llvm, vector -from cutlass.cutlass_dsl import T, dsl_user_op -from quack.compile_utils import make_fake_tensor - -from vllm.vllm_flash_attn.cute import utils as cute_utils - -MXFP4_BLOCK_SIZE = 32 - -_TORCH_TO_CUTE = { - torch.bfloat16: BFloat16, - torch.float32: Float32, -} - - -@dsl_user_op -def _fp32x2_to_bf16x2(a: Float32, b: Float32, *, loc=None, ip=None) -> Uint32: - return Uint32( - llvm.inline_asm( - T.i32(), - [ - Float32(a).ir_value(loc=loc, ip=ip), - Float32(b).ir_value(loc=loc, ip=ip), - ], - "cvt.rn.bf16x2.f32 $0, $2, $1;", - "=r,f,f", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - ) - - -@dsl_user_op -def _bf16x2_to_fp32(data: Uint32, *, loc=None, ip=None) -> tuple[Float32, Float32]: - out = llvm.inline_asm( - llvm.StructType.get_literal([T.f32(), T.f32()]), - [Uint32(data).ir_value(loc=loc, ip=ip)], - "shl.b32 $0, $2, 16;\n\tand.b32 $1, $2, 0xFFFF0000;\n", - "=f,=f,r", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - return ( - Float32(llvm.extractvalue(T.f32(), out, [0], loc=loc, ip=ip)), - Float32(llvm.extractvalue(T.f32(), out, [1], loc=loc, ip=ip)), - ) - - -@dsl_user_op -def _bf16x2_abs(a: Uint32, *, loc=None, ip=None) -> Uint32: - return Uint32( - llvm.inline_asm( - T.i32(), - [Uint32(a).ir_value(loc=loc, ip=ip)], - "abs.bf16x2 $0, $1;", - "=r,r", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - ) - - -@dsl_user_op -def _bf16x2_max(a: Uint32, b: Uint32, *, loc=None, ip=None) -> Uint32: - return Uint32( - llvm.inline_asm( - T.i32(), - [ - Uint32(a).ir_value(loc=loc, ip=ip), - Uint32(b).ir_value(loc=loc, ip=ip), - ], - "max.bf16x2 $0, $1, $2;", - "=r,r,r", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - ) - - -@dsl_user_op -def _fp32x8_to_fp4x8( - vals: cute.Tensor, - offset: cutlass.Constexpr[int], - *, - loc=None, - ip=None, -) -> Uint32: - # Pack eight scaled FP32 values into four E2M1x2 bytes, returned as one b32. - operands = [Float32(vals[offset + i]).ir_value(loc=loc, ip=ip) for i in range(8)] - return Uint32( - llvm.inline_asm( - T.i32(), - operands, - "{\n\t" - ".reg .b8 x0, x1, x2, x3;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 x0, $2, $1;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 x1, $4, $3;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 x2, $6, $5;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 x3, $8, $7;\n\t" - "mov.b32 $0, {x0, x1, x2, x3};\n\t" - "}\n", - "=r,f,f,f,f,f,f,f,f", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - ) - - -# Custom vectorized load to support cache modifiers. For some reason, -# cute.autovec_copy() does not currently emit the requested modifiers. -# tensor and coord is only used to select the base pointer. actual load -# is done using out_dtype -@dsl_user_op -def _ldg_vec( - tensor: cute.Tensor, - coord: cute.Coord, - vec_size: cutlass.Constexpr[int], - modifier: cutlass.Constexpr[str] = "", - out_dtype: cutlass.Constexpr[type[cutlass.Numeric]] = Uint32, - *, - loc=None, - ip=None, -) -> cute.TensorSSA: - if const_expr(out_dtype is Float32): - mlir_ty = T.f32() - ptx_ty = "f32" - constraint = "=f" - elif const_expr(out_dtype is Uint32): - mlir_ty = T.i32() - ptx_ty = "b32" - constraint = "=r" - else: - raise TypeError(f"_ldg_vec only supports Uint32 and Float32, got {out_dtype}") - - # compute base pointer - base_ptr = ( - tensor.iterator + cute.crd2idx(coord, tensor.layout, loc=loc, ip=ip) - ).toint() - - # build PTX string - ptx_str = f"ld.global{modifier}.v{vec_size}.{ptx_ty}" - ptx_str += "{" + ", ".join(f"${i}" for i in range(vec_size)) + "}" - ptx_str += f", [${vec_size}];" - out = llvm.inline_asm( - llvm.StructType.get_literal([mlir_ty] * vec_size), - [Int64(base_ptr).ir_value(loc=loc, ip=ip)], - ptx_str, - ",".join([constraint] * vec_size + ["l"]), - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - vec = vector.from_elements( - ir.VectorType.get([vec_size], mlir_ty, loc=loc), - [llvm.extractvalue(mlir_ty, out, [i], loc=loc, ip=ip) for i in range(vec_size)], - loc=loc, - ip=ip, - ) - return cute.TensorSSA(vec, vec_size, out_dtype) - - -@dsl_user_op -def _stg_u32xN( - tensor: cute.Tensor, - coord: cute.Coord, - values: cute.Tensor, - vec_size: cutlass.Constexpr[int], - modifier: cutlass.Constexpr[str] = "", - *, - loc=None, - ip=None, -) -> None: - base_ptr = ( - tensor.iterator + cute.crd2idx(coord, tensor.layout, loc=loc, ip=ip) - ).toint() - value_operands = ", ".join(f"${i + 1}" for i in range(vec_size)) - llvm.inline_asm( - None, - [Int64(base_ptr).ir_value(loc=loc, ip=ip)] - + [Uint32(values[i]).ir_value(loc=loc, ip=ip) for i in range(vec_size)], - f"st.global{modifier}.v{vec_size}.u32 [$0], {{{value_operands}}};", - ",".join(["l"] + ["r"] * vec_size), - has_side_effects=True, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - - -class IndexerQMxFp4Kernel: - """Eight-thread subwarps process one ``(token, head)`` row.""" - - def __init__( - self, - head_dim: int = 128, - rope_dim: int = 64, - num_heads: int = 64, - cos_sin_dtype: type[cutlass.Numeric] = cutlass.Float32, - ): - self.head_dim = head_dim - self.rope_dim = rope_dim - self.nope_dim = head_dim - rope_dim - self.num_heads = num_heads - self.cos_sin_dtype = cos_sin_dtype - - # later we will use 32B load = 16 BF16 elems - # thus, head_dim=128 requires 8 threads to handle. - # let's call subwarp = 8 threads. - self.subwarp_size = head_dim // 16 - self.tb_size = 256 - - @cute.jit - def __call__( - self, - positions: cute.Tensor, - q: cute.Tensor, - cos_sin_cache: cute.Tensor, - weights: cute.Tensor, - q_fp4: cute.Tensor, - q_scale: cute.Tensor, - weights_out: cute.Tensor, - scale: Float32, - ): - num_tokens, num_heads, _ = q.shape - total_threads = num_tokens * num_heads * self.subwarp_size - grid = [cute.ceil_div(total_threads, self.tb_size), 1, 1] - self.kernel( - positions, - q, - cos_sin_cache, - weights, - q_fp4, - q_scale, - weights_out, - scale, - ).launch(grid=grid, block=[self.tb_size, 1, 1]) - - @cute.kernel - def kernel( - self, - positions: cute.Tensor, - q: cute.Tensor, - cos_sin_cache: cute.Tensor, - weights: cute.Tensor, - q_fp4: cute.Tensor, - q_scale: cute.Tensor, - weights_out: cute.Tensor, - scale: Float32, - ): - block_id, _, _ = cute.arch.block_idx() - tidx, _, _ = cute.arch.thread_idx() - - num_token_heads = q.shape[0] * self.num_heads - global_tid = block_id * self.tb_size + tidx - - global_subwarp_id = global_tid // self.subwarp_size - sublane = tidx % self.subwarp_size - - token_id = global_subwarp_id // self.num_heads - head_id = global_subwarp_id - token_id * self.num_heads - - # each thread loads 16 BF16 elems - elem_base = sublane * 16 - - # q layout: [num_tokens, num_heads, head_dim] - _q_bits = _ldg_vec( - q, (token_id, head_id, elem_base), 8, ".relaxed.cta.L1::no_allocate" - ) - q_bits = cute.make_rmem_tensor(8, Uint32) - q_bits.store(_q_bits) # copy to make it mutable - - # RoPE applies only to the trailing rope_dim values. We keep the rounded - # BF16 result in q_bits so the later amax and quantization see BF16. - # cos_sin_cache layout: [max_pos, rope_dim] - if elem_base >= self.nope_dim: - pos = positions[token_id] - rope_idx = (elem_base - self.nope_dim) // 2 - if const_expr(self.cos_sin_dtype is Float32): - cos_vals = _ldg_vec( - cos_sin_cache, - (pos, rope_idx), - 8, - out_dtype=Float32, - ) - sin_vals = _ldg_vec( - cos_sin_cache, - (pos, self.nope_dim // 2 + rope_idx), - 8, - out_dtype=Float32, - ) - else: - # Each BF16 cache load lane contains two adjacent values. - cos_loaded = _ldg_vec(cos_sin_cache, (pos, rope_idx), 4) - sin_loaded = _ldg_vec( - cos_sin_cache, - (pos, self.rope_dim // 2 + rope_idx), - 4, - ) - cos_vals = cute.make_rmem_tensor(8, Float32) - sin_vals = cute.make_rmem_tensor(8, Float32) - for i in cutlass.range_constexpr(4): - cos_vals[i * 2], cos_vals[i * 2 + 1] = _bf16x2_to_fp32( - cos_loaded[i] - ) - sin_vals[i * 2], sin_vals[i * 2 + 1] = _bf16x2_to_fp32( - sin_loaded[i] - ) - - for i in cutlass.range_constexpr(8): - q0, q1 = _bf16x2_to_fp32(q_bits[i]) - cos = cos_vals[i] - sin = sin_vals[i] - rot0 = q0 * cos - q1 * sin - rot1 = q0 * sin + q1 * cos - # convert back to BF16 to match numerics - q_bits[i] = _fp32x2_to_bf16x2(rot0, rot1) - - # Each thread holds 16 elems. Two adjacent threads form one 32-elem - # MXFP4 block, so a width-2 shuffle gives the block amax. - local_amax = _bf16x2_abs(q_bits[0]) - for i in cutlass.range_constexpr(1, 8): - local_amax = _bf16x2_max(local_amax, _bf16x2_abs(q_bits[i])) - amax_bits = cute_utils.warp_reduce( - local_amax, _bf16x2_max, width=MXFP4_BLOCK_SIZE // 16 - ) - amax0, amax1 = _bf16x2_to_fp32(amax_bits) - amax = cute_utils.fmax(amax0, amax1) - - fp4_scale = cute_utils.fmax(amax, float.fromhex("0x6p-126")) * (1.0 / 6.0) - bits = Uint32(llvm.bitcast(T.i32(), fp4_scale.ir_value())) - # UE8M0 stores ceil(log2(fp4_scale)) + 127. Adding the mantissa mask - # increments the exponent whenever fp4_scale is not exactly a power of 2. - ue8m0 = cute_utils.shr_u32(bits + Uint32(0x7FFFFF), Uint32(23)) & Uint32(0xFF) - - # Only one of the two threads in an MXFP4 block writes the shared scale. - if tidx % 2 == 0: - mx_block = sublane // (MXFP4_BLOCK_SIZE // 16) - q_scale[token_id, head_id, mx_block] = Uint8(ue8m0) - - # If scale = 2^A and ue8m0 = A + 127, then inverse scale has exponent - # -A + 127 = 254 - ue8m0. - inv_scale_bits = (Uint32(254) - ue8m0) << Uint32(23) - inv_fp4_scale = Float32(llvm.bitcast(T.f32(), inv_scale_bits.ir_value())) - - vals = cute.make_rmem_tensor(16, Float32) - for i in cutlass.range_constexpr(8): - vals[i * 2], vals[i * 2 + 1] = _bf16x2_to_fp32(q_bits[i]) - vals[i * 2] = vals[i * 2] * inv_fp4_scale - vals[i * 2 + 1] = vals[i * 2 + 1] * inv_fp4_scale - - # pack to FP4 - packed = cute.make_rmem_tensor(2, Uint32) - packed[0] = _fp32x8_to_fp4x8(vals, 0) - packed[1] = _fp32x8_to_fp4x8(vals, 8) - # Each thread writes the eight packed bytes corresponding to its 16 Q values. - _stg_u32xN(q_fp4, (token_id, head_id, elem_base // 2), packed, 2, ".cs") - - # Weight scaling is independent of the Q subwarp work. The first - # num_tokens * num_heads logical threads cover one weight each. - if global_tid < num_token_heads: - weight_token_id = global_tid // self.num_heads - weight_head_id = global_tid - weight_token_id * self.num_heads - weights_out[weight_token_id, weight_head_id] = ( - weights[weight_token_id, weight_head_id].to(Float32) * scale - ) - - -@cache -def _compile_indexer_q_mxfp4( - head_dim: int, rope_dim: int, num_heads: int, cos_sin_dtype: type[cutlass.Numeric] -): - num_tokens = cute.sym_int() - max_pos = cute.sym_int() - - q = make_fake_tensor(BFloat16, (num_tokens, num_heads, head_dim), divisibility=8) - positions = make_fake_tensor(Int64, (num_tokens,), divisibility=1) - cos_div = math.gcd(128 // cos_sin_dtype.width, rope_dim) - cos_sin_cache = make_fake_tensor( - cos_sin_dtype, (max_pos, rope_dim), divisibility=cos_div - ) - weights = make_fake_tensor(BFloat16, (num_tokens, num_heads), divisibility=8) - q_fp4 = make_fake_tensor( - Uint8, (num_tokens, num_heads, head_dim // 2), divisibility=16 - ) - q_scale = make_fake_tensor( - Uint8, - (num_tokens, num_heads, head_dim // MXFP4_BLOCK_SIZE), - divisibility=4, - ) - weights_out = make_fake_tensor(Float32, (num_tokens, num_heads), divisibility=4) - - kernel = IndexerQMxFp4Kernel(head_dim, rope_dim, num_heads, cos_sin_dtype) - return cute.compile( - kernel, - positions, - q, - cos_sin_cache, - weights, - q_fp4, - q_scale, - weights_out, - Float32(0.0), - options="--enable-tvm-ffi", - ) - - -def fused_indexer_q_rope_quant( - positions: torch.Tensor, - q: torch.Tensor, - cos_sin_cache: torch.Tensor, - weights: torch.Tensor, - softmax_scale: float, - head_scale: float, - use_fp4: bool = False, -) -> tuple[torch.Tensor | tuple[torch.Tensor, torch.Tensor], torch.Tensor]: - if not use_fp4: - raise NotImplementedError( - "indexer_q_mxfp4_cutedsl only implements use_fp4=True" - ) - - num_tokens, num_heads, head_dim = q.shape - rope_dim = cos_sin_cache.shape[-1] - q_fp4 = q.new_empty((num_tokens, num_heads, head_dim // 2), dtype=torch.uint8) - q_scale = q.new_empty( - (num_tokens, num_heads, head_dim // MXFP4_BLOCK_SIZE), dtype=torch.uint8 - ) - weights_out = torch.empty_like(weights, dtype=torch.float32) - - compiled = _compile_indexer_q_mxfp4( - head_dim, - rope_dim, - num_heads, - _TORCH_TO_CUTE[cos_sin_cache.dtype], - ) - scale = float(softmax_scale * head_scale) - compiled( - positions, - q, - cos_sin_cache, - weights, - q_fp4, - q_scale, - weights_out, - scale, - ) - return (q_fp4, q_scale.view(torch.int32).squeeze(-1)), weights_out From f7ce368a403f87c02d48389cbad5f57274415640 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 1 May 2026 00:05:54 +0000 Subject: [PATCH 06/21] fix Signed-off-by: Thien Tran --- vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py index c2e981f41363..9934dcb52760 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py @@ -571,7 +571,7 @@ def kernel( ) sin_vals = _ldg_vec( cos_sin_cache, - (pos, self.nope_dim // 2 + rope_idx), + (pos, self.rope_dim // 2 + rope_idx), 8, out_dtype=Float32, ) From 52885b2c1bf28e3b953ada9f93f4a82d9b1e3192 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 1 May 2026 03:33:15 +0000 Subject: [PATCH 07/21] use current cuda stream Signed-off-by: Thien Tran --- vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py index 9934dcb52760..464338d79a24 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py @@ -5,6 +5,7 @@ import cutlass import cutlass.cute as cute import torch +from cuda.bindings.driver import CUstream from cutlass import BFloat16, Float32, Int64, Uint8, Uint32, const_expr from cutlass._mlir import ir from cutlass._mlir.dialects import llvm, vector @@ -507,6 +508,7 @@ def __call__( q_scale: cute.Tensor, weights_out: cute.Tensor, scale: Float32, + stream: CUstream, ): num_tokens, num_heads, _ = q.shape total_threads = num_tokens * num_heads * self.subwarp_size @@ -520,7 +522,7 @@ def __call__( q_scale, weights_out, scale, - ).launch(grid=grid, block=[self.tb_size, 1, 1]) + ).launch(grid=grid, block=[self.tb_size, 1, 1], stream=stream) @cute.kernel def kernel( @@ -676,6 +678,7 @@ def _compile_indexer_q_mxfp4( weights_out = make_fake_tensor(Float32, (num_tokens, num_heads), divisibility=4) kernel = IndexerQMxFp4Kernel(head_dim, rope_dim, num_heads, cos_sin_dtype) + stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) return cute.compile( kernel, positions, @@ -686,5 +689,6 @@ def _compile_indexer_q_mxfp4( q_scale, weights_out, Float32(0.0), + stream, options="--enable-tvm-ffi", ) From 575e1660824dae5d107833e17f2efe768d3e875f Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 1 May 2026 08:14:08 +0000 Subject: [PATCH 08/21] add back triton for potential rocm fallback. add import guards Signed-off-by: Thien Tran --- .../ops/deepseek_v4_ops/fused_indexer_q.py | 618 ++++++------------ .../fused_indexer_q_cutedsl.py | 454 +++++++++++++ 2 files changed, 637 insertions(+), 435 deletions(-) create mode 100644 vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py index 464338d79a24..e36eba64f887 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py @@ -1,28 +1,35 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from functools import cache -import cutlass -import cutlass.cute as cute +from importlib.util import find_spec + import torch -from cuda.bindings.driver import CUstream -from cutlass import BFloat16, Float32, Int64, Uint8, Uint32, const_expr -from cutlass._mlir import ir -from cutlass._mlir.dialects import llvm, vector -from cutlass.cutlass_dsl import T, dsl_user_op -from quack.compile_utils import make_fake_tensor from vllm.triton_utils import tl, triton -from vllm.vllm_flash_attn.cute import utils as cute_utils + +HAS_CUTEDSL = find_spec("cutlass") is not None + +if HAS_CUTEDSL: + from .fused_indexer_q_cutedsl import fused_indexer_q_rope_quant_mxfp4_cutedsl +else: + + def fused_indexer_q_rope_quant_mxfp4_cutedsl( + positions: torch.Tensor, + index_q: torch.Tensor, + index_q_cos_sin_cache: torch.Tensor, + index_weights: torch.Tensor, + index_weights_softmax_scale: float, + index_weights_head_scale: float, + index_q_packed: torch.Tensor, + index_q_scale: torch.Tensor, + index_weights_out: torch.Tensor, + ) -> None: + pass + # MXFP4: 32 elements per block, packed 2 nibbles per byte, ue8m0 block scale. MXFP4_BLOCK_SIZE = 32 -_TORCH_TO_CUTE = { - torch.bfloat16: BFloat16, - torch.float32: Float32, -} - @triton.jit def _get_cos_sin( @@ -185,6 +192,116 @@ def _fused_indexer_q_rope_quant_kernel( ) +@triton.jit +def _fused_indexer_q_rope_mxfp4_kernel( + pos_ptr, + # Index Q RoPE input (fp/bf16) + index_q_ptr, + index_q_stride0, + index_q_stride1, + index_q_cos_sin_ptr, + index_q_cos_sin_stride, + INDEX_Q_HALF_ROT_DIM: tl.constexpr, + # MXFP4 Q outputs + index_q_mxfp4_ptr, # uint8, (T, H, HEAD_DIM // 2) + index_q_mxfp4_stride0, + index_q_mxfp4_stride1, + index_q_scale_ptr, # uint8 ue8m0, (T, H, HEAD_DIM // BLOCK) + index_q_scale_stride0, + index_q_scale_stride1, + INDEX_Q_HEAD_DIM: tl.constexpr, + MXFP4_BLOCK: tl.constexpr, + # Weights (NO per-token q_scale fold for MXFP4; per-block scales stay + # with the Q values in the output scale tensor). + index_weights_ptr, + index_weights_stride, + index_weights_softmax_scale, + index_weights_head_scale, + index_weights_out_ptr, + index_weights_out_stride, +): + INDEX_Q_ROT_DIM: tl.constexpr = 2 * INDEX_Q_HALF_ROT_DIM + INDEX_Q_NOPE_DIM: tl.constexpr = INDEX_Q_HEAD_DIM - INDEX_Q_ROT_DIM + NUM_NOPE_BLOCKS: tl.constexpr = INDEX_Q_NOPE_DIM // MXFP4_BLOCK + NUM_ROPE_BLOCKS: tl.constexpr = INDEX_Q_ROT_DIM // MXFP4_BLOCK + HALF_BLOCK: tl.constexpr = MXFP4_BLOCK // 2 + tl.static_assert(INDEX_Q_NOPE_DIM >= 0) + tl.static_assert(INDEX_Q_NOPE_DIM % MXFP4_BLOCK == 0) + tl.static_assert(INDEX_Q_ROT_DIM % MXFP4_BLOCK == 0) + tl.static_assert(MXFP4_BLOCK % 2 == 0) + + tok_idx = tl.program_id(0) + head_idx = tl.program_id(1) + + pos = tl.load(pos_ptr + tok_idx) + + q_base = index_q_ptr + tok_idx * index_q_stride0 + head_idx * index_q_stride1 + out_base = ( + index_q_mxfp4_ptr + + tok_idx * index_q_mxfp4_stride0 + + head_idx * index_q_mxfp4_stride1 + ) + scale_base = ( + index_q_scale_ptr + + tok_idx * index_q_scale_stride0 + + head_idx * index_q_scale_stride1 + ) + + half_off = tl.arange(0, HALF_BLOCK) + + # NoPE blocks: direct load, pair as (even-index, odd-index) values. + for b in tl.static_range(NUM_NOPE_BLOCKS): + base = b * MXFP4_BLOCK + x_lo = tl.load(q_base + base + half_off * 2).to(tl.float32) + x_hi = tl.load(q_base + base + half_off * 2 + 1).to(tl.float32) + packed, ue8m0 = _quantize_mxfp4_pair(x_lo, x_hi) + tl.store(out_base + base // 2 + half_off, packed) + tl.store(scale_base + b, ue8m0) + + # RoPE blocks: apply GPT-J interleaved RoPE to the block's 16 pairs, + # then quantize. Each block covers HALF_BLOCK (=16) cos/sin pairs. + rot_q_base = q_base + INDEX_Q_NOPE_DIM + for b in tl.static_range(NUM_ROPE_BLOCKS): + pair_off = b * HALF_BLOCK + half_off # indices in [0, HALF_ROT_DIM) + cos_b = tl.load( + index_q_cos_sin_ptr + pos * index_q_cos_sin_stride + pair_off + ).to(tl.float32) + sin_b = tl.load( + index_q_cos_sin_ptr + + pos * index_q_cos_sin_stride + + pair_off + + INDEX_Q_HALF_ROT_DIM + ).to(tl.float32) + x_even = tl.load(rot_q_base + pair_off * 2).to(tl.float32) + x_odd = tl.load(rot_q_base + pair_off * 2 + 1).to(tl.float32) + r_even = x_even * cos_b - x_odd * sin_b + r_odd = x_odd * cos_b + x_even * sin_b + # bf16 roundtrip for parity with the FP8 kernel / reference numerics. + r_even = r_even.to(tl.bfloat16).to(tl.float32) + r_odd = r_odd.to(tl.bfloat16).to(tl.float32) + packed, ue8m0 = _quantize_mxfp4_pair(r_even, r_odd) + rope_byte_off = (INDEX_Q_NOPE_DIM + b * MXFP4_BLOCK) // 2 + tl.store(out_base + rope_byte_off + half_off, packed) + tl.store(scale_base + NUM_NOPE_BLOCKS + b, ue8m0) + + # MXFP4 weight-fold contract: + # index_weights_out = index_weights * softmax_scale * head_scale + # NOTE: q_scale is NOT folded here (contrast with the FP8 kernel above). + # MXFP4 Q emits a separate ue8m0 scale tensor of shape + # (T, H, HEAD_DIM // MXFP4_BLOCK) alongside the packed values, so each + # per-block scale is applied by the downstream MXFP4 logits kernel when + # dequantizing Q. There is no per-token scalar to fold into `weights`. + index_weights = tl.load( + index_weights_ptr + tok_idx * index_weights_stride + head_idx + ).to(tl.float32) + index_weights *= index_weights_softmax_scale + index_weights *= index_weights_head_scale + tl.store( + index_weights_out_ptr + tok_idx * index_weights_out_stride + head_idx, + index_weights, + ) + + def fused_indexer_q_rope_quant( positions: torch.Tensor, index_q: torch.Tensor, @@ -226,43 +343,68 @@ def fused_indexer_q_rope_quant( assert index_q.ndim == 3 assert index_q_cos_sin_cache.ndim == 2 - num_tokens = positions.shape[0] - num_index_q_heads = index_q.shape[1] index_q_head_dim = index_q.shape[2] - index_weights_out = torch.empty_like(index_weights, dtype=torch.float32) - if use_fp4: assert index_q_head_dim % MXFP4_BLOCK_SIZE == 0, ( f"head_dim={index_q_head_dim} must be a multiple of MXFP4 block " f"size {MXFP4_BLOCK_SIZE}" ) + num_tokens = positions.shape[0] + num_index_q_heads = index_q.shape[1] num_scale_blocks = index_q_head_dim // MXFP4_BLOCK_SIZE - index_q_packed = index_q.new_empty( + index_q_packed = torch.empty( (num_tokens, num_index_q_heads, index_q_head_dim // 2), dtype=torch.uint8, + device=index_q.device, ) - index_q_scale = index_q.new_empty( + index_q_scale = torch.empty( (num_tokens, num_index_q_heads, num_scale_blocks), dtype=torch.uint8, + device=index_q.device, ) - compiled = _compile_indexer_q_mxfp4( - index_q_head_dim, - index_q_cos_sin_cache.shape[-1], - num_index_q_heads, - _TORCH_TO_CUTE[index_q_cos_sin_cache.dtype], - ) - scale = float(index_weights_softmax_scale * index_weights_head_scale) - compiled( - positions, - index_q, - index_q_cos_sin_cache, - index_weights, - index_q_packed, - index_q_scale, - index_weights_out, - scale, - ) + index_weights_out = torch.empty_like(index_weights, dtype=torch.float32) + + if fused_indexer_q_rope_quant_mxfp4_cutedsl is not None: + fused_indexer_q_rope_quant_mxfp4_cutedsl( + positions, + index_q, + index_q_cos_sin_cache, + index_weights, + index_weights_softmax_scale, + index_weights_head_scale, + index_q_packed, + index_q_scale, + index_weights_out, + ) + + else: + # Triton fallback + _fused_indexer_q_rope_mxfp4_kernel[(num_tokens, num_index_q_heads)]( + positions, + index_q, + index_q.stride(0), + index_q.stride(1), + index_q_cos_sin_cache, + index_q_cos_sin_cache.stride(0), + index_q_cos_sin_cache.shape[-1] // 2, + index_q_packed, + index_q_packed.stride(0), + index_q_packed.stride(1), + index_q_scale, + index_q_scale.stride(0), + index_q_scale.stride(1), + index_q_head_dim, + MXFP4_BLOCK_SIZE, + index_weights, + index_weights.stride(0), + index_weights_softmax_scale, + index_weights_head_scale, + index_weights_out, + index_weights_out.stride(0), + num_warps=1, # TODO: Tune this + ) + # Values stay uint8 (2 E2M1 nibbles per byte). Scales are 4 ue8m0 # bytes per (token, head) reinterpreted as one int32, then squeezed # from (T, H, 1) to (T, H) to match DeepGEMM's expected q_sf rank @@ -273,6 +415,9 @@ def fused_indexer_q_rope_quant( index_q_scale.view(torch.int32).squeeze(-1), ), index_weights_out + num_tokens = positions.shape[0] + num_index_q_heads = index_q.shape[1] + index_weights_out = torch.empty_like(index_weights, dtype=torch.float32) index_q_fp8 = torch.empty_like(index_q, dtype=torch.float8_e4m3fn) _fused_indexer_q_rope_quant_kernel[(num_tokens, num_index_q_heads)]( positions, @@ -295,400 +440,3 @@ def fused_indexer_q_rope_quant( num_warps=1, # TODO: Tune this ) return index_q_fp8, index_weights_out - - -@dsl_user_op -def _fp32x2_to_bf16x2(a: Float32, b: Float32, *, loc=None, ip=None) -> Uint32: - return Uint32( - llvm.inline_asm( - T.i32(), - [ - Float32(a).ir_value(loc=loc, ip=ip), - Float32(b).ir_value(loc=loc, ip=ip), - ], - "cvt.rn.bf16x2.f32 $0, $2, $1;", - "=r,f,f", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - ) - - -@dsl_user_op -def _bf16x2_to_fp32(data: Uint32, *, loc=None, ip=None) -> tuple[Float32, Float32]: - out = llvm.inline_asm( - llvm.StructType.get_literal([T.f32(), T.f32()]), - [Uint32(data).ir_value(loc=loc, ip=ip)], - "shl.b32 $0, $2, 16;\n\tand.b32 $1, $2, 0xFFFF0000;\n", - "=f,=f,r", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - return ( - Float32(llvm.extractvalue(T.f32(), out, [0], loc=loc, ip=ip)), - Float32(llvm.extractvalue(T.f32(), out, [1], loc=loc, ip=ip)), - ) - - -@dsl_user_op -def _bf16x2_abs(a: Uint32, *, loc=None, ip=None) -> Uint32: - return Uint32( - llvm.inline_asm( - T.i32(), - [Uint32(a).ir_value(loc=loc, ip=ip)], - "abs.bf16x2 $0, $1;", - "=r,r", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - ) - - -@dsl_user_op -def _bf16x2_max(a: Uint32, b: Uint32, *, loc=None, ip=None) -> Uint32: - return Uint32( - llvm.inline_asm( - T.i32(), - [ - Uint32(a).ir_value(loc=loc, ip=ip), - Uint32(b).ir_value(loc=loc, ip=ip), - ], - "max.bf16x2 $0, $1, $2;", - "=r,r,r", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - ) - - -@dsl_user_op -def _fp32x8_to_fp4x8( - vals: cute.Tensor, - offset: cutlass.Constexpr[int], - *, - loc=None, - ip=None, -) -> Uint32: - # Pack eight scaled FP32 values into four E2M1x2 bytes, returned as one b32. - operands = [Float32(vals[offset + i]).ir_value(loc=loc, ip=ip) for i in range(8)] - return Uint32( - llvm.inline_asm( - T.i32(), - operands, - "{\n\t" - ".reg .b8 x0, x1, x2, x3;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 x0, $2, $1;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 x1, $4, $3;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 x2, $6, $5;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 x3, $8, $7;\n\t" - "mov.b32 $0, {x0, x1, x2, x3};\n\t" - "}\n", - "=r,f,f,f,f,f,f,f,f", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - ) - - -# Custom vectorized load to support cache modifiers. For some reason, -# cute.autovec_copy() does not currently emit the requested modifiers. -# tensor and coord is only used to select the base pointer. actual load -# is done using out_dtype -@dsl_user_op -def _ldg_vec( - tensor: cute.Tensor, - coord: cute.Coord, - vec_size: cutlass.Constexpr[int], - modifier: cutlass.Constexpr[str] = "", - out_dtype: cutlass.Constexpr[type[cutlass.Numeric]] = Uint32, - *, - loc=None, - ip=None, -) -> cute.TensorSSA: - if const_expr(out_dtype is Float32): - mlir_ty = T.f32() - ptx_ty = "f32" - constraint = "=f" - elif const_expr(out_dtype is Uint32): - mlir_ty = T.i32() - ptx_ty = "b32" - constraint = "=r" - else: - raise TypeError(f"_ldg_vec only supports Uint32 and Float32, got {out_dtype}") - - # compute base pointer - base_ptr = ( - tensor.iterator + cute.crd2idx(coord, tensor.layout, loc=loc, ip=ip) - ).toint() - - # build PTX string - ptx_str = f"ld.global{modifier}.v{vec_size}.{ptx_ty}" - ptx_str += "{" + ", ".join(f"${i}" for i in range(vec_size)) + "}" - ptx_str += f", [${vec_size}];" - out = llvm.inline_asm( - llvm.StructType.get_literal([mlir_ty] * vec_size), - [Int64(base_ptr).ir_value(loc=loc, ip=ip)], - ptx_str, - ",".join([constraint] * vec_size + ["l"]), - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - vec = vector.from_elements( - ir.VectorType.get([vec_size], mlir_ty, loc=loc), - [llvm.extractvalue(mlir_ty, out, [i], loc=loc, ip=ip) for i in range(vec_size)], - loc=loc, - ip=ip, - ) - return cute.TensorSSA(vec, vec_size, out_dtype) - - -@dsl_user_op -def _stg_u32xN( - tensor: cute.Tensor, - coord: cute.Coord, - values: cute.Tensor, - vec_size: cutlass.Constexpr[int], - modifier: cutlass.Constexpr[str] = "", - *, - loc=None, - ip=None, -) -> None: - base_ptr = ( - tensor.iterator + cute.crd2idx(coord, tensor.layout, loc=loc, ip=ip) - ).toint() - value_operands = ", ".join(f"${i + 1}" for i in range(vec_size)) - llvm.inline_asm( - None, - [Int64(base_ptr).ir_value(loc=loc, ip=ip)] - + [Uint32(values[i]).ir_value(loc=loc, ip=ip) for i in range(vec_size)], - f"st.global{modifier}.v{vec_size}.u32 [$0], {{{value_operands}}};", - ",".join(["l"] + ["r"] * vec_size), - has_side_effects=True, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) - - -class IndexerQMxFp4Kernel: - """Eight-thread subwarps process one ``(token, head)`` row.""" - - def __init__( - self, - head_dim: int = 128, - rope_dim: int = 64, - num_heads: int = 64, - cos_sin_dtype: type[cutlass.Numeric] = cutlass.Float32, - ): - self.head_dim = head_dim - self.rope_dim = rope_dim - self.nope_dim = head_dim - rope_dim - self.num_heads = num_heads - self.cos_sin_dtype = cos_sin_dtype - - # later we will use 32B load = 16 BF16 elems - # thus, head_dim=128 requires 8 threads to handle. - # let's call subwarp = 8 threads. - self.subwarp_size = head_dim // 16 - self.tb_size = 256 - - @cute.jit - def __call__( - self, - positions: cute.Tensor, - q: cute.Tensor, - cos_sin_cache: cute.Tensor, - weights: cute.Tensor, - q_fp4: cute.Tensor, - q_scale: cute.Tensor, - weights_out: cute.Tensor, - scale: Float32, - stream: CUstream, - ): - num_tokens, num_heads, _ = q.shape - total_threads = num_tokens * num_heads * self.subwarp_size - grid = [cute.ceil_div(total_threads, self.tb_size), 1, 1] - self.kernel( - positions, - q, - cos_sin_cache, - weights, - q_fp4, - q_scale, - weights_out, - scale, - ).launch(grid=grid, block=[self.tb_size, 1, 1], stream=stream) - - @cute.kernel - def kernel( - self, - positions: cute.Tensor, - q: cute.Tensor, - cos_sin_cache: cute.Tensor, - weights: cute.Tensor, - q_fp4: cute.Tensor, - q_scale: cute.Tensor, - weights_out: cute.Tensor, - scale: Float32, - ): - block_id, _, _ = cute.arch.block_idx() - tidx, _, _ = cute.arch.thread_idx() - - num_token_heads = q.shape[0] * self.num_heads - global_tid = block_id * self.tb_size + tidx - - global_subwarp_id = global_tid // self.subwarp_size - sublane = tidx % self.subwarp_size - - token_id = global_subwarp_id // self.num_heads - head_id = global_subwarp_id - token_id * self.num_heads - - # each thread loads 16 BF16 elems - elem_base = sublane * 16 - - # q layout: [num_tokens, num_heads, head_dim] - _q_bits = _ldg_vec( - q, (token_id, head_id, elem_base), 8, ".relaxed.cta.L1::no_allocate" - ) - q_bits = cute.make_rmem_tensor(8, Uint32) - q_bits.store(_q_bits) # copy to make it mutable - - # RoPE applies only to the trailing rope_dim values. We keep the rounded - # BF16 result in q_bits so the later amax and quantization see BF16. - # cos_sin_cache layout: [max_pos, rope_dim] - if elem_base >= self.nope_dim: - pos = positions[token_id] - rope_idx = (elem_base - self.nope_dim) // 2 - if const_expr(self.cos_sin_dtype is Float32): - cos_vals = _ldg_vec( - cos_sin_cache, - (pos, rope_idx), - 8, - out_dtype=Float32, - ) - sin_vals = _ldg_vec( - cos_sin_cache, - (pos, self.rope_dim // 2 + rope_idx), - 8, - out_dtype=Float32, - ) - else: - # Each BF16 cache load lane contains two adjacent values. - cos_loaded = _ldg_vec(cos_sin_cache, (pos, rope_idx), 4) - sin_loaded = _ldg_vec( - cos_sin_cache, - (pos, self.rope_dim // 2 + rope_idx), - 4, - ) - cos_vals = cute.make_rmem_tensor(8, Float32) - sin_vals = cute.make_rmem_tensor(8, Float32) - for i in cutlass.range_constexpr(4): - cos_vals[i * 2], cos_vals[i * 2 + 1] = _bf16x2_to_fp32( - cos_loaded[i] - ) - sin_vals[i * 2], sin_vals[i * 2 + 1] = _bf16x2_to_fp32( - sin_loaded[i] - ) - - for i in cutlass.range_constexpr(8): - q0, q1 = _bf16x2_to_fp32(q_bits[i]) - cos = cos_vals[i] - sin = sin_vals[i] - rot0 = q0 * cos - q1 * sin - rot1 = q0 * sin + q1 * cos - # convert back to BF16 to match numerics - q_bits[i] = _fp32x2_to_bf16x2(rot0, rot1) - - # compute amax in packed bf16x2 to save instructions - # Each thread holds 16 elems. Two adjacent threads form one 32-elem - # MXFP4 block, so a width-2 shuffle gives the block amax. - local_amax = _bf16x2_abs(q_bits[0]) - for i in cutlass.range_constexpr(1, 8): - local_amax = _bf16x2_max(local_amax, _bf16x2_abs(q_bits[i])) - amax_bits = cute_utils.warp_reduce( - local_amax, _bf16x2_max, width=MXFP4_BLOCK_SIZE // 16 - ) - amax0, amax1 = _bf16x2_to_fp32(amax_bits) - amax = cute_utils.fmax(amax0, amax1) - - # compute block scale with bit manipulation - # UE8M0 stores ceil(log2(fp4_scale)) + 127. Adding the mantissa mask - # increments the exponent whenever fp4_scale is not exactly a power of 2. - fp4_scale = cute_utils.fmax(amax, float.fromhex("0x6p-126")) * (1.0 / 6.0) - bits = Uint32(llvm.bitcast(T.i32(), fp4_scale.ir_value())) - ue8m0 = cute_utils.shr_u32(bits + Uint32(0x7FFFFF), Uint32(23)) & Uint32(0xFF) - - # Only one of the two threads in an MXFP4 block writes the shared scale. - if tidx % 2 == 0: - mx_block = sublane // (MXFP4_BLOCK_SIZE // 16) - q_scale[token_id, head_id, mx_block] = Uint8(ue8m0) - - # If scale = 2^A and ue8m0 = A + 127, then inverse scale has exponent - # -A + 127 = 254 - ue8m0. - inv_scale_bits = (Uint32(254) - ue8m0) << Uint32(23) - inv_fp4_scale = Float32(llvm.bitcast(T.f32(), inv_scale_bits.ir_value())) - - vals = cute.make_rmem_tensor(16, Float32) - for i in cutlass.range_constexpr(8): - vals[i * 2], vals[i * 2 + 1] = _bf16x2_to_fp32(q_bits[i]) - vals[i * 2] = vals[i * 2] * inv_fp4_scale - vals[i * 2 + 1] = vals[i * 2 + 1] * inv_fp4_scale - - # pack to FP4 - packed = cute.make_rmem_tensor(2, Uint32) - packed[0] = _fp32x8_to_fp4x8(vals, 0) - packed[1] = _fp32x8_to_fp4x8(vals, 8) - # Each thread writes the eight packed bytes corresponding to its 16 Q values. - _stg_u32xN(q_fp4, (token_id, head_id, elem_base // 2), packed, 2, ".cs") - - # Weight scaling is independent of the Q subwarp work. The first - # num_tokens * num_heads logical threads cover one weight each. - if global_tid < num_token_heads: - weight_token_id = global_tid // self.num_heads - weight_head_id = global_tid - weight_token_id * self.num_heads - weights_out[weight_token_id, weight_head_id] = ( - weights[weight_token_id, weight_head_id].to(Float32) * scale - ) - - -@cache -def _compile_indexer_q_mxfp4( - head_dim: int, rope_dim: int, num_heads: int, cos_sin_dtype: type[cutlass.Numeric] -): - num_tokens = cute.sym_int() - max_pos = cute.sym_int() - - q = make_fake_tensor(BFloat16, (num_tokens, num_heads, head_dim), divisibility=8) - positions = make_fake_tensor(Int64, (num_tokens,), divisibility=1) - cos_sin_cache = make_fake_tensor(cos_sin_dtype, (max_pos, rope_dim), divisibility=8) - weights = make_fake_tensor(BFloat16, (num_tokens, num_heads), divisibility=8) - q_fp4 = make_fake_tensor( - Uint8, (num_tokens, num_heads, head_dim // 2), divisibility=16 - ) - q_scale = make_fake_tensor( - Uint8, - (num_tokens, num_heads, head_dim // MXFP4_BLOCK_SIZE), - divisibility=4, - ) - weights_out = make_fake_tensor(Float32, (num_tokens, num_heads), divisibility=4) - - kernel = IndexerQMxFp4Kernel(head_dim, rope_dim, num_heads, cos_sin_dtype) - stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) - return cute.compile( - kernel, - positions, - q, - cos_sin_cache, - weights, - q_fp4, - q_scale, - weights_out, - Float32(0.0), - stream, - options="--enable-tvm-ffi", - ) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py new file mode 100644 index 000000000000..92ba02d689a7 --- /dev/null +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py @@ -0,0 +1,454 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# once we have more CuteDSL kernels in vLLM, we can refactor small helper functions +# to a separate file +from functools import cache + +import cutlass +import cutlass.cute as cute +import torch +from cuda.bindings.driver import CUstream +from cutlass import BFloat16, Float32, Int64, Uint8, Uint32, const_expr +from cutlass._mlir import ir +from cutlass._mlir.dialects import llvm, vector +from cutlass.cutlass_dsl import T, dsl_user_op +from quack.compile_utils import make_fake_tensor + +from vllm.vllm_flash_attn.cute import utils as cute_utils + +# MXFP4: 32 elements per block, packed 2 nibbles per byte, ue8m0 block scale. +MXFP4_BLOCK_SIZE = 32 + +_TORCH_TO_CUTE = { + torch.bfloat16: BFloat16, + torch.float32: Float32, +} + + +def fused_indexer_q_rope_quant_mxfp4_cutedsl( + positions: torch.Tensor, + index_q: torch.Tensor, + index_q_cos_sin_cache: torch.Tensor, + index_weights: torch.Tensor, + index_weights_softmax_scale: float, + index_weights_head_scale: float, + index_q_packed: torch.Tensor, + index_q_scale: torch.Tensor, + index_weights_out: torch.Tensor, +) -> None: + num_index_q_heads = index_q.shape[1] + index_q_head_dim = index_q.shape[2] + compiled = _compile_indexer_q_mxfp4( + index_q_head_dim, + index_q_cos_sin_cache.shape[-1], + num_index_q_heads, + _TORCH_TO_CUTE[index_q_cos_sin_cache.dtype], + ) + scale = float(index_weights_softmax_scale * index_weights_head_scale) + compiled( + positions, + index_q, + index_q_cos_sin_cache, + index_weights, + index_q_packed, + index_q_scale, + index_weights_out, + scale, + ) + + +@dsl_user_op +def _fp32x2_to_bf16x2(a: Float32, b: Float32, *, loc=None, ip=None) -> Uint32: + return Uint32( + llvm.inline_asm( + T.i32(), + [ + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + ], + "cvt.rn.bf16x2.f32 $0, $2, $1;", + "=r,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def _bf16x2_to_fp32(data: Uint32, *, loc=None, ip=None) -> tuple[Float32, Float32]: + out = llvm.inline_asm( + llvm.StructType.get_literal([T.f32(), T.f32()]), + [Uint32(data).ir_value(loc=loc, ip=ip)], + "shl.b32 $0, $2, 16;\n\tand.b32 $1, $2, 0xFFFF0000;\n", + "=f,=f,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + return ( + Float32(llvm.extractvalue(T.f32(), out, [0], loc=loc, ip=ip)), + Float32(llvm.extractvalue(T.f32(), out, [1], loc=loc, ip=ip)), + ) + + +@dsl_user_op +def _bf16x2_abs(a: Uint32, *, loc=None, ip=None) -> Uint32: + return Uint32( + llvm.inline_asm( + T.i32(), + [Uint32(a).ir_value(loc=loc, ip=ip)], + "abs.bf16x2 $0, $1;", + "=r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def _bf16x2_max(a: Uint32, b: Uint32, *, loc=None, ip=None) -> Uint32: + return Uint32( + llvm.inline_asm( + T.i32(), + [ + Uint32(a).ir_value(loc=loc, ip=ip), + Uint32(b).ir_value(loc=loc, ip=ip), + ], + "max.bf16x2 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def _fp32x8_to_fp4x8( + vals: cute.Tensor, + offset: cutlass.Constexpr[int], + *, + loc=None, + ip=None, +) -> Uint32: + # Pack eight scaled FP32 values into four E2M1x2 bytes, returned as one b32. + operands = [Float32(vals[offset + i]).ir_value(loc=loc, ip=ip) for i in range(8)] + return Uint32( + llvm.inline_asm( + T.i32(), + operands, + "{\n\t" + ".reg .b8 x0, x1, x2, x3;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 x0, $2, $1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 x1, $4, $3;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 x2, $6, $5;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 x3, $8, $7;\n\t" + "mov.b32 $0, {x0, x1, x2, x3};\n\t" + "}\n", + "=r,f,f,f,f,f,f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +# Custom vectorized load to support cache modifiers. For some reason, +# cute.autovec_copy() does not currently emit the requested modifiers. +# tensor and coord is only used to select the base pointer. actual load +# is done using out_dtype +@dsl_user_op +def _ldg_vec( + tensor: cute.Tensor, + coord: cute.Coord, + vec_size: cutlass.Constexpr[int], + modifier: cutlass.Constexpr[str] = "", + out_dtype: cutlass.Constexpr[type[cutlass.Numeric]] = Uint32, + *, + loc=None, + ip=None, +) -> cute.TensorSSA: + if const_expr(out_dtype is Float32): + mlir_ty = T.f32() + ptx_ty = "f32" + constraint = "=f" + elif const_expr(out_dtype is Uint32): + mlir_ty = T.i32() + ptx_ty = "b32" + constraint = "=r" + else: + raise TypeError(f"_ldg_vec only supports Uint32 and Float32, got {out_dtype}") + + # compute base pointer + base_ptr = ( + tensor.iterator + cute.crd2idx(coord, tensor.layout, loc=loc, ip=ip) + ).toint() + + # build PTX string + ptx_str = f"ld.global{modifier}.v{vec_size}.{ptx_ty}" + ptx_str += "{" + ", ".join(f"${i}" for i in range(vec_size)) + "}" + ptx_str += f", [${vec_size}];" + out = llvm.inline_asm( + llvm.StructType.get_literal([mlir_ty] * vec_size), + [Int64(base_ptr).ir_value(loc=loc, ip=ip)], + ptx_str, + ",".join([constraint] * vec_size + ["l"]), + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + vec = vector.from_elements( + ir.VectorType.get([vec_size], mlir_ty, loc=loc), + [llvm.extractvalue(mlir_ty, out, [i], loc=loc, ip=ip) for i in range(vec_size)], + loc=loc, + ip=ip, + ) + return cute.TensorSSA(vec, vec_size, out_dtype) + + +@dsl_user_op +def _stg_u32xN( + tensor: cute.Tensor, + coord: cute.Coord, + values: cute.Tensor, + vec_size: cutlass.Constexpr[int], + modifier: cutlass.Constexpr[str] = "", + *, + loc=None, + ip=None, +) -> None: + base_ptr = ( + tensor.iterator + cute.crd2idx(coord, tensor.layout, loc=loc, ip=ip) + ).toint() + value_operands = ", ".join(f"${i + 1}" for i in range(vec_size)) + llvm.inline_asm( + None, + [Int64(base_ptr).ir_value(loc=loc, ip=ip)] + + [Uint32(values[i]).ir_value(loc=loc, ip=ip) for i in range(vec_size)], + f"st.global{modifier}.v{vec_size}.u32 [$0], {{{value_operands}}};", + ",".join(["l"] + ["r"] * vec_size), + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +class IndexerQMxFp4Kernel: + """Eight-thread subwarps process one ``(token, head)`` row.""" + + def __init__( + self, + head_dim: int = 128, + rope_dim: int = 64, + num_heads: int = 64, + cos_sin_dtype: type[cutlass.Numeric] = cutlass.Float32, + ): + self.head_dim = head_dim + self.rope_dim = rope_dim + self.nope_dim = head_dim - rope_dim + self.num_heads = num_heads + self.cos_sin_dtype = cos_sin_dtype + + # later we will use 32B load = 16 BF16 elems + # thus, head_dim=128 requires 8 threads to handle. + # let's call subwarp = 8 threads. + self.subwarp_size = head_dim // 16 + self.tb_size = 256 + + @cute.jit + def __call__( + self, + positions: cute.Tensor, + q: cute.Tensor, + cos_sin_cache: cute.Tensor, + weights: cute.Tensor, + q_fp4: cute.Tensor, + q_scale: cute.Tensor, + weights_out: cute.Tensor, + scale: Float32, + stream: CUstream, + ): + num_tokens, num_heads, _ = q.shape + total_threads = num_tokens * num_heads * self.subwarp_size + grid = [cute.ceil_div(total_threads, self.tb_size), 1, 1] + self.kernel( + positions, + q, + cos_sin_cache, + weights, + q_fp4, + q_scale, + weights_out, + scale, + ).launch(grid=grid, block=[self.tb_size, 1, 1], stream=stream) + + @cute.kernel + def kernel( + self, + positions: cute.Tensor, + q: cute.Tensor, + cos_sin_cache: cute.Tensor, + weights: cute.Tensor, + q_fp4: cute.Tensor, + q_scale: cute.Tensor, + weights_out: cute.Tensor, + scale: Float32, + ): + block_id, _, _ = cute.arch.block_idx() + tidx, _, _ = cute.arch.thread_idx() + + num_token_heads = q.shape[0] * self.num_heads + global_tid = block_id * self.tb_size + tidx + + global_subwarp_id = global_tid // self.subwarp_size + sublane = tidx % self.subwarp_size + + token_id = global_subwarp_id // self.num_heads + head_id = global_subwarp_id - token_id * self.num_heads + + # each thread loads 16 BF16 elems + elem_base = sublane * 16 + + # q layout: [num_tokens, num_heads, head_dim] + _q_bits = _ldg_vec( + q, (token_id, head_id, elem_base), 8, ".relaxed.cta.L1::no_allocate" + ) + q_bits = cute.make_rmem_tensor(8, Uint32) + q_bits.store(_q_bits) # copy to make it mutable + + # RoPE applies only to the trailing rope_dim values. We keep the rounded + # BF16 result in q_bits so the later amax and quantization see BF16. + # cos_sin_cache layout: [max_pos, rope_dim] + if elem_base >= self.nope_dim: + pos = positions[token_id] + rope_idx = (elem_base - self.nope_dim) // 2 + if const_expr(self.cos_sin_dtype is Float32): + cos_vals = _ldg_vec( + cos_sin_cache, + (pos, rope_idx), + 8, + out_dtype=Float32, + ) + sin_vals = _ldg_vec( + cos_sin_cache, + (pos, self.rope_dim // 2 + rope_idx), + 8, + out_dtype=Float32, + ) + else: + # Each BF16 cache load lane contains two adjacent values. + cos_loaded = _ldg_vec(cos_sin_cache, (pos, rope_idx), 4) + sin_loaded = _ldg_vec( + cos_sin_cache, + (pos, self.rope_dim // 2 + rope_idx), + 4, + ) + cos_vals = cute.make_rmem_tensor(8, Float32) + sin_vals = cute.make_rmem_tensor(8, Float32) + for i in cutlass.range_constexpr(4): + cos_vals[i * 2], cos_vals[i * 2 + 1] = _bf16x2_to_fp32( + cos_loaded[i] + ) + sin_vals[i * 2], sin_vals[i * 2 + 1] = _bf16x2_to_fp32( + sin_loaded[i] + ) + + for i in cutlass.range_constexpr(8): + q0, q1 = _bf16x2_to_fp32(q_bits[i]) + cos = cos_vals[i] + sin = sin_vals[i] + rot0 = q0 * cos - q1 * sin + rot1 = q0 * sin + q1 * cos + # convert back to BF16 to match numerics + q_bits[i] = _fp32x2_to_bf16x2(rot0, rot1) + + # compute amax in packed bf16x2 to save instructions + # Each thread holds 16 elems. Two adjacent threads form one 32-elem + # MXFP4 block, so a width-2 shuffle gives the block amax. + local_amax = _bf16x2_abs(q_bits[0]) + for i in cutlass.range_constexpr(1, 8): + local_amax = _bf16x2_max(local_amax, _bf16x2_abs(q_bits[i])) + amax_bits = cute_utils.warp_reduce( + local_amax, _bf16x2_max, width=MXFP4_BLOCK_SIZE // 16 + ) + amax0, amax1 = _bf16x2_to_fp32(amax_bits) + amax = cute_utils.fmax(amax0, amax1) + + # compute block scale with bit manipulation + # UE8M0 stores ceil(log2(fp4_scale)) + 127. Adding the mantissa mask + # increments the exponent whenever fp4_scale is not exactly a power of 2. + fp4_scale = cute_utils.fmax(amax, float.fromhex("0x6p-126")) * (1.0 / 6.0) + bits = Uint32(llvm.bitcast(T.i32(), fp4_scale.ir_value())) + ue8m0 = cute_utils.shr_u32(bits + Uint32(0x7FFFFF), Uint32(23)) & Uint32(0xFF) + + # Only one of the two threads in an MXFP4 block writes the shared scale. + if tidx % 2 == 0: + mx_block = sublane // (MXFP4_BLOCK_SIZE // 16) + q_scale[token_id, head_id, mx_block] = Uint8(ue8m0) + + # If scale = 2^A and ue8m0 = A + 127, then inverse scale has exponent + # -A + 127 = 254 - ue8m0. + inv_scale_bits = (Uint32(254) - ue8m0) << Uint32(23) + inv_fp4_scale = Float32(llvm.bitcast(T.f32(), inv_scale_bits.ir_value())) + + vals = cute.make_rmem_tensor(16, Float32) + for i in cutlass.range_constexpr(8): + vals[i * 2], vals[i * 2 + 1] = _bf16x2_to_fp32(q_bits[i]) + vals[i * 2] = vals[i * 2] * inv_fp4_scale + vals[i * 2 + 1] = vals[i * 2 + 1] * inv_fp4_scale + + # pack to FP4 + packed = cute.make_rmem_tensor(2, Uint32) + packed[0] = _fp32x8_to_fp4x8(vals, 0) + packed[1] = _fp32x8_to_fp4x8(vals, 8) + # Each thread writes the eight packed bytes corresponding to its 16 Q values. + _stg_u32xN(q_fp4, (token_id, head_id, elem_base // 2), packed, 2, ".cs") + + # Weight scaling is independent of the Q subwarp work. The first + # num_tokens * num_heads logical threads cover one weight each. + if global_tid < num_token_heads: + weight_token_id = global_tid // self.num_heads + weight_head_id = global_tid - weight_token_id * self.num_heads + weights_out[weight_token_id, weight_head_id] = ( + weights[weight_token_id, weight_head_id].to(Float32) * scale + ) + + +@cache +def _compile_indexer_q_mxfp4( + head_dim: int, rope_dim: int, num_heads: int, cos_sin_dtype: type[cutlass.Numeric] +): + num_tokens = cute.sym_int() + max_pos = cute.sym_int() + + q = make_fake_tensor(BFloat16, (num_tokens, num_heads, head_dim), divisibility=8) + positions = make_fake_tensor(Int64, (num_tokens,), divisibility=1) + cos_sin_cache = make_fake_tensor(cos_sin_dtype, (max_pos, rope_dim), divisibility=8) + weights = make_fake_tensor(BFloat16, (num_tokens, num_heads), divisibility=8) + q_fp4 = make_fake_tensor( + Uint8, (num_tokens, num_heads, head_dim // 2), divisibility=16 + ) + q_scale = make_fake_tensor( + Uint8, + (num_tokens, num_heads, head_dim // MXFP4_BLOCK_SIZE), + divisibility=4, + ) + weights_out = make_fake_tensor(Float32, (num_tokens, num_heads), divisibility=4) + + kernel = IndexerQMxFp4Kernel(head_dim, rope_dim, num_heads, cos_sin_dtype) + stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) + return cute.compile( + kernel, + positions, + q, + cos_sin_cache, + weights, + q_fp4, + q_scale, + weights_out, + Float32(0.0), + stream, + options="--enable-tvm-ffi", + ) From 0f70fc66c98d693f6929a4c2624e4a76f7a980e2 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 1 May 2026 12:52:40 +0000 Subject: [PATCH 09/21] remove unnecessary diffs Signed-off-by: Thien Tran --- .../ops/deepseek_v4_ops/fused_indexer_q.py | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py index e36eba64f887..567e007d70cd 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py @@ -249,7 +249,7 @@ def _fused_indexer_q_rope_mxfp4_kernel( half_off = tl.arange(0, HALF_BLOCK) - # NoPE blocks: direct load, pair as (even-index, odd-index) values. + # ---- NoPE blocks: direct load, pair as (even-index, odd-index) values ---- for b in tl.static_range(NUM_NOPE_BLOCKS): base = b * MXFP4_BLOCK x_lo = tl.load(q_base + base + half_off * 2).to(tl.float32) @@ -258,8 +258,8 @@ def _fused_indexer_q_rope_mxfp4_kernel( tl.store(out_base + base // 2 + half_off, packed) tl.store(scale_base + b, ue8m0) - # RoPE blocks: apply GPT-J interleaved RoPE to the block's 16 pairs, - # then quantize. Each block covers HALF_BLOCK (=16) cos/sin pairs. + # ---- RoPE blocks: apply GPT-J interleaved RoPE to the block's 16 pairs, + # then quantize. Each block covers HALF_BLOCK (=16) cos/sin pairs. ---- rot_q_base = q_base + INDEX_Q_NOPE_DIM for b in tl.static_range(NUM_ROPE_BLOCKS): pair_off = b * HALF_BLOCK + half_off # indices in [0, HALF_ROT_DIM) @@ -290,7 +290,7 @@ def _fused_indexer_q_rope_mxfp4_kernel( # MXFP4 Q emits a separate ue8m0 scale tensor of shape # (T, H, HEAD_DIM // MXFP4_BLOCK) alongside the packed values, so each # per-block scale is applied by the downstream MXFP4 logits kernel when - # dequantizing Q. There is no per-token scalar to fold into `weights`. + # dequantizing Q — there is no per-token scalar to fold into `weights`. index_weights = tl.load( index_weights_ptr + tok_idx * index_weights_stride + head_idx ).to(tl.float32) @@ -343,15 +343,17 @@ def fused_indexer_q_rope_quant( assert index_q.ndim == 3 assert index_q_cos_sin_cache.ndim == 2 + num_tokens = positions.shape[0] + num_index_q_heads = index_q.shape[1] index_q_head_dim = index_q.shape[2] + index_weights_out = torch.empty_like(index_weights, dtype=torch.float32) + if use_fp4: assert index_q_head_dim % MXFP4_BLOCK_SIZE == 0, ( f"head_dim={index_q_head_dim} must be a multiple of MXFP4 block " f"size {MXFP4_BLOCK_SIZE}" ) - num_tokens = positions.shape[0] - num_index_q_heads = index_q.shape[1] num_scale_blocks = index_q_head_dim // MXFP4_BLOCK_SIZE index_q_packed = torch.empty( (num_tokens, num_index_q_heads, index_q_head_dim // 2), @@ -363,8 +365,6 @@ def fused_indexer_q_rope_quant( dtype=torch.uint8, device=index_q.device, ) - index_weights_out = torch.empty_like(index_weights, dtype=torch.float32) - if fused_indexer_q_rope_quant_mxfp4_cutedsl is not None: fused_indexer_q_rope_quant_mxfp4_cutedsl( positions, @@ -377,9 +377,7 @@ def fused_indexer_q_rope_quant( index_q_scale, index_weights_out, ) - else: - # Triton fallback _fused_indexer_q_rope_mxfp4_kernel[(num_tokens, num_index_q_heads)]( positions, index_q, @@ -415,9 +413,6 @@ def fused_indexer_q_rope_quant( index_q_scale.view(torch.int32).squeeze(-1), ), index_weights_out - num_tokens = positions.shape[0] - num_index_q_heads = index_q.shape[1] - index_weights_out = torch.empty_like(index_weights, dtype=torch.float32) index_q_fp8 = torch.empty_like(index_q, dtype=torch.float8_e4m3fn) _fused_indexer_q_rope_quant_kernel[(num_tokens, num_index_q_heads)]( positions, From 432cc37a22f9b7657b0ed6a4af20102ba91aafab Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 1 May 2026 13:52:47 +0000 Subject: [PATCH 10/21] fix wrong check condition Signed-off-by: Thien Tran --- vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py index 567e007d70cd..f62bd68977b0 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py @@ -365,7 +365,7 @@ def fused_indexer_q_rope_quant( dtype=torch.uint8, device=index_q.device, ) - if fused_indexer_q_rope_quant_mxfp4_cutedsl is not None: + if HAS_CUTEDSL: fused_indexer_q_rope_quant_mxfp4_cutedsl( positions, index_q, From df23d880704d736eb56178513ec79045f121c586 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 1 May 2026 14:18:43 +0000 Subject: [PATCH 11/21] lazy import to fix CUDA init failure Signed-off-by: Thien Tran --- .../ops/deepseek_v4_ops/fused_indexer_q.py | 23 ++++--------------- 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py index f62bd68977b0..39723d533d3c 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py @@ -9,24 +9,6 @@ HAS_CUTEDSL = find_spec("cutlass") is not None -if HAS_CUTEDSL: - from .fused_indexer_q_cutedsl import fused_indexer_q_rope_quant_mxfp4_cutedsl -else: - - def fused_indexer_q_rope_quant_mxfp4_cutedsl( - positions: torch.Tensor, - index_q: torch.Tensor, - index_q_cos_sin_cache: torch.Tensor, - index_weights: torch.Tensor, - index_weights_softmax_scale: float, - index_weights_head_scale: float, - index_q_packed: torch.Tensor, - index_q_scale: torch.Tensor, - index_weights_out: torch.Tensor, - ) -> None: - pass - - # MXFP4: 32 elements per block, packed 2 nibbles per byte, ue8m0 block scale. MXFP4_BLOCK_SIZE = 32 @@ -366,6 +348,11 @@ def fused_indexer_q_rope_quant( device=index_q.device, ) if HAS_CUTEDSL: + # lazily import, otherwise some tests fail due to CUDA driver init failure. + from .fused_indexer_q_cutedsl import ( + fused_indexer_q_rope_quant_mxfp4_cutedsl, + ) + fused_indexer_q_rope_quant_mxfp4_cutedsl( positions, index_q, From 08cd0b407fc43ac1afaf68bc9fb5da4454b6445c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 1 May 2026 23:33:12 +0000 Subject: [PATCH 12/21] move has_cutedsl() to import_utils.py Signed-off-by: Thien Tran --- vllm/utils/import_utils.py | 5 +++++ vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py | 7 ++----- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/vllm/utils/import_utils.py b/vllm/utils/import_utils.py index 6cf57c6894ab..5822e5840afc 100644 --- a/vllm/utils/import_utils.py +++ b/vllm/utils/import_utils.py @@ -469,3 +469,8 @@ def has_mori() -> bool: def has_fbgemm_gpu() -> bool: """Whether the optional `fbgemm_gpu` package is available.""" return _has_module("fbgemm_gpu") + + +def has_cutedsl() -> bool: + """Whether the optional `cutelass` package is available.""" + return _has_module("cutlass") diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py index 39723d533d3c..ec880f7ab4c4 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py @@ -1,13 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from importlib.util import find_spec - import torch from vllm.triton_utils import tl, triton - -HAS_CUTEDSL = find_spec("cutlass") is not None +from vllm.utils.import_utils import has_cutedsl # MXFP4: 32 elements per block, packed 2 nibbles per byte, ue8m0 block scale. MXFP4_BLOCK_SIZE = 32 @@ -347,7 +344,7 @@ def fused_indexer_q_rope_quant( dtype=torch.uint8, device=index_q.device, ) - if HAS_CUTEDSL: + if has_cutedsl(): # lazily import, otherwise some tests fail due to CUDA driver init failure. from .fused_indexer_q_cutedsl import ( fused_indexer_q_rope_quant_mxfp4_cutedsl, From a7db9e9cc22fe9309c058d00a0a94053085e16a1 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 6 May 2026 03:49:10 +0000 Subject: [PATCH 13/21] remove asm_dialect. clean up Signed-off-by: Thien Tran --- .../fused_indexer_q_cutedsl.py | 293 +++++++++--------- 1 file changed, 154 insertions(+), 139 deletions(-) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py index 92ba02d689a7..c2e6399b77ff 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py @@ -38,7 +38,7 @@ def fused_indexer_q_rope_quant_mxfp4_cutedsl( ) -> None: num_index_q_heads = index_q.shape[1] index_q_head_dim = index_q.shape[2] - compiled = _compile_indexer_q_mxfp4( + compiled = IndexerQMxFp4Kernel.compile( index_q_head_dim, index_q_cos_sin_cache.shape[-1], num_index_q_heads, @@ -59,32 +59,26 @@ def fused_indexer_q_rope_quant_mxfp4_cutedsl( @dsl_user_op def _fp32x2_to_bf16x2(a: Float32, b: Float32, *, loc=None, ip=None) -> Uint32: - return Uint32( - llvm.inline_asm( - T.i32(), - [ - Float32(a).ir_value(loc=loc, ip=ip), - Float32(b).ir_value(loc=loc, ip=ip), - ], - "cvt.rn.bf16x2.f32 $0, $2, $1;", - "=r,f,f", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) + out = llvm.inline_asm( + T.i32(), + [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], + "cvt.rn.bf16x2.f32 $0, $2, $1;", + "=r,f,f", + has_side_effects=False, + is_align_stack=False, ) + return Uint32(out) @dsl_user_op def _bf16x2_to_fp32(data: Uint32, *, loc=None, ip=None) -> tuple[Float32, Float32]: out = llvm.inline_asm( llvm.StructType.get_literal([T.f32(), T.f32()]), - [Uint32(data).ir_value(loc=loc, ip=ip)], + [data.ir_value(loc=loc, ip=ip)], "shl.b32 $0, $2, 16;\n\tand.b32 $1, $2, 0xFFFF0000;\n", "=f,=f,r", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, ) return ( Float32(llvm.extractvalue(T.f32(), out, [0], loc=loc, ip=ip)), @@ -94,35 +88,28 @@ def _bf16x2_to_fp32(data: Uint32, *, loc=None, ip=None) -> tuple[Float32, Float3 @dsl_user_op def _bf16x2_abs(a: Uint32, *, loc=None, ip=None) -> Uint32: - return Uint32( - llvm.inline_asm( - T.i32(), - [Uint32(a).ir_value(loc=loc, ip=ip)], - "abs.bf16x2 $0, $1;", - "=r,r", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) + out = llvm.inline_asm( + T.i32(), + [a.ir_value(loc=loc, ip=ip)], + "abs.bf16x2 $0, $1;", + "=r,r", + has_side_effects=False, + is_align_stack=False, ) + return Uint32(out) @dsl_user_op def _bf16x2_max(a: Uint32, b: Uint32, *, loc=None, ip=None) -> Uint32: - return Uint32( - llvm.inline_asm( - T.i32(), - [ - Uint32(a).ir_value(loc=loc, ip=ip), - Uint32(b).ir_value(loc=loc, ip=ip), - ], - "max.bf16x2 $0, $1, $2;", - "=r,r,r", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) + out = llvm.inline_asm( + T.i32(), + [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], + "max.bf16x2 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, ) + return Uint32(out) @dsl_user_op @@ -134,25 +121,23 @@ def _fp32x8_to_fp4x8( ip=None, ) -> Uint32: # Pack eight scaled FP32 values into four E2M1x2 bytes, returned as one b32. - operands = [Float32(vals[offset + i]).ir_value(loc=loc, ip=ip) for i in range(8)] - return Uint32( - llvm.inline_asm( - T.i32(), - operands, - "{\n\t" - ".reg .b8 x0, x1, x2, x3;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 x0, $2, $1;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 x1, $4, $3;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 x2, $6, $5;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 x3, $8, $7;\n\t" - "mov.b32 $0, {x0, x1, x2, x3};\n\t" - "}\n", - "=r,f,f,f,f,f,f,f,f", - has_side_effects=False, - is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, - ) + assert vals.element_type is Float32 + out = llvm.inline_asm( + T.i32(), + [vals[offset + i].ir_value(loc=loc, ip=ip) for i in range(8)], + "{\n\t" + ".reg .b8 x0, x1, x2, x3;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 x0, $2, $1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 x1, $4, $3;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 x2, $6, $5;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 x3, $8, $7;\n\t" + "mov.b32 $0, {x0, x1, x2, x3};\n\t" + "}\n", + "=r,f,f,f,f,f,f,f,f", + has_side_effects=False, + is_align_stack=False, ) + return Uint32(out) # Custom vectorized load to support cache modifiers. For some reason, @@ -165,21 +150,21 @@ def _ldg_vec( coord: cute.Coord, vec_size: cutlass.Constexpr[int], modifier: cutlass.Constexpr[str] = "", - out_dtype: cutlass.Constexpr[type[cutlass.Numeric]] = Uint32, + ld_type: cutlass.Constexpr[type[cutlass.Numeric] | None] = None, *, loc=None, ip=None, ) -> cute.TensorSSA: - if const_expr(out_dtype is Float32): - mlir_ty = T.f32() + if ld_type is None: + ld_type = tensor.element_type + if const_expr(ld_type is Float32): ptx_ty = "f32" constraint = "=f" - elif const_expr(out_dtype is Uint32): - mlir_ty = T.i32() - ptx_ty = "b32" + elif const_expr(ld_type is Uint32): + ptx_ty = "u32" constraint = "=r" else: - raise TypeError(f"_ldg_vec only supports Uint32 and Float32, got {out_dtype}") + raise TypeError(f"_ldg_vec only supports Uint32 and Float32, got {ld_type}") # compute base pointer base_ptr = ( @@ -190,26 +175,29 @@ def _ldg_vec( ptx_str = f"ld.global{modifier}.v{vec_size}.{ptx_ty}" ptx_str += "{" + ", ".join(f"${i}" for i in range(vec_size)) + "}" ptx_str += f", [${vec_size}];" + out = llvm.inline_asm( - llvm.StructType.get_literal([mlir_ty] * vec_size), + llvm.StructType.get_literal([ld_type.mlir_type] * vec_size), [Int64(base_ptr).ir_value(loc=loc, ip=ip)], ptx_str, ",".join([constraint] * vec_size + ["l"]), has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, ) vec = vector.from_elements( - ir.VectorType.get([vec_size], mlir_ty, loc=loc), - [llvm.extractvalue(mlir_ty, out, [i], loc=loc, ip=ip) for i in range(vec_size)], + ir.VectorType.get([vec_size], ld_type.mlir_type, loc=loc), + [ + llvm.extractvalue(ld_type.mlir_type, out, [i], loc=loc, ip=ip) + for i in range(vec_size) + ], loc=loc, ip=ip, ) - return cute.TensorSSA(vec, vec_size, out_dtype) + return cute.TensorSSA(vec, vec_size, ld_type) @dsl_user_op -def _stg_u32xN( +def _stg_vec( tensor: cute.Tensor, coord: cute.Coord, values: cute.Tensor, @@ -219,19 +207,34 @@ def _stg_u32xN( loc=None, ip=None, ) -> None: + # NOTE: st_type is derived from values tensor + st_type = values.element_type + if const_expr(st_type is Float32): + ptx_ty = "f32" + constraint = "f" + elif const_expr(st_type is Uint32): + ptx_ty = "u32" + constraint = "r" + else: + raise TypeError(f"_stg_vec only supports Uint32 and Float32, got {st_type}") + + # compute base pointer base_ptr = ( tensor.iterator + cute.crd2idx(coord, tensor.layout, loc=loc, ip=ip) ).toint() - value_operands = ", ".join(f"${i + 1}" for i in range(vec_size)) + + # build PTX string + ptx_str = f"st.global{modifier}.v{vec_size}.{ptx_ty} [$0], " + ptx_str += "{" + ", ".join(f"${i + 1}" for i in range(vec_size)) + "};" + llvm.inline_asm( None, [Int64(base_ptr).ir_value(loc=loc, ip=ip)] - + [Uint32(values[i]).ir_value(loc=loc, ip=ip) for i in range(vec_size)], - f"st.global{modifier}.v{vec_size}.u32 [$0], {{{value_operands}}};", - ",".join(["l"] + ["r"] * vec_size), + + [values[i].ir_value(loc=loc, ip=ip) for i in range(vec_size)], + ptx_str, + ",".join(["l"] + [constraint] * vec_size), has_side_effects=True, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT, ) @@ -312,11 +315,15 @@ def kernel( elem_base = sublane * 16 # q layout: [num_tokens, num_heads, head_dim] - _q_bits = _ldg_vec( - q, (token_id, head_id, elem_base), 8, ".relaxed.cta.L1::no_allocate" + _q_bf16x2 = _ldg_vec( + q, + (token_id, head_id, elem_base), + 8, + ".relaxed.cta.L1::no_allocate", + ld_type=Uint32, ) - q_bits = cute.make_rmem_tensor(8, Uint32) - q_bits.store(_q_bits) # copy to make it mutable + q_bf16x2 = cute.make_rmem_tensor(8, Uint32) + q_bf16x2.store(_q_bf16x2) # copy to make it mutable # RoPE applies only to the trailing rope_dim values. We keep the rounded # BF16 result in q_bits so the later amax and quantization see BF16. @@ -325,51 +332,48 @@ def kernel( pos = positions[token_id] rope_idx = (elem_base - self.nope_dim) // 2 if const_expr(self.cos_sin_dtype is Float32): - cos_vals = _ldg_vec( - cos_sin_cache, - (pos, rope_idx), - 8, - out_dtype=Float32, - ) + # fp32x8 loads + cos_vals = _ldg_vec(cos_sin_cache, (pos, rope_idx), 8) sin_vals = _ldg_vec( - cos_sin_cache, - (pos, self.rope_dim // 2 + rope_idx), - 8, - out_dtype=Float32, + cos_sin_cache, (pos, self.rope_dim // 2 + rope_idx), 8 ) else: - # Each BF16 cache load lane contains two adjacent values. - cos_loaded = _ldg_vec(cos_sin_cache, (pos, rope_idx), 4) - sin_loaded = _ldg_vec( + # bf16x2x4 loads + cos_bf16x2 = _ldg_vec( + cos_sin_cache, + (pos, rope_idx), + 4, + ld_type=Uint32, + ) + sin_bf16x2 = _ldg_vec( cos_sin_cache, (pos, self.rope_dim // 2 + rope_idx), 4, + ld_type=Uint32, ) cos_vals = cute.make_rmem_tensor(8, Float32) sin_vals = cute.make_rmem_tensor(8, Float32) for i in cutlass.range_constexpr(4): cos_vals[i * 2], cos_vals[i * 2 + 1] = _bf16x2_to_fp32( - cos_loaded[i] + cos_bf16x2[i] ) sin_vals[i * 2], sin_vals[i * 2 + 1] = _bf16x2_to_fp32( - sin_loaded[i] + sin_bf16x2[i] ) for i in cutlass.range_constexpr(8): - q0, q1 = _bf16x2_to_fp32(q_bits[i]) - cos = cos_vals[i] - sin = sin_vals[i] - rot0 = q0 * cos - q1 * sin - rot1 = q0 * sin + q1 * cos + q0, q1 = _bf16x2_to_fp32(q_bf16x2[i]) + rot0 = q0 * cos_vals[i] - q1 * sin_vals[i] + rot1 = q0 * sin_vals[i] + q1 * cos_vals[i] # convert back to BF16 to match numerics - q_bits[i] = _fp32x2_to_bf16x2(rot0, rot1) + q_bf16x2[i] = _fp32x2_to_bf16x2(rot0, rot1) # compute amax in packed bf16x2 to save instructions # Each thread holds 16 elems. Two adjacent threads form one 32-elem # MXFP4 block, so a width-2 shuffle gives the block amax. - local_amax = _bf16x2_abs(q_bits[0]) + local_amax = _bf16x2_abs(q_bf16x2[0]) for i in cutlass.range_constexpr(1, 8): - local_amax = _bf16x2_max(local_amax, _bf16x2_abs(q_bits[i])) + local_amax = _bf16x2_max(local_amax, _bf16x2_abs(q_bf16x2[i])) amax_bits = cute_utils.warp_reduce( local_amax, _bf16x2_max, width=MXFP4_BLOCK_SIZE // 16 ) @@ -395,7 +399,7 @@ def kernel( vals = cute.make_rmem_tensor(16, Float32) for i in cutlass.range_constexpr(8): - vals[i * 2], vals[i * 2 + 1] = _bf16x2_to_fp32(q_bits[i]) + vals[i * 2], vals[i * 2 + 1] = _bf16x2_to_fp32(q_bf16x2[i]) vals[i * 2] = vals[i * 2] * inv_fp4_scale vals[i * 2 + 1] = vals[i * 2 + 1] * inv_fp4_scale @@ -404,7 +408,7 @@ def kernel( packed[0] = _fp32x8_to_fp4x8(vals, 0) packed[1] = _fp32x8_to_fp4x8(vals, 8) # Each thread writes the eight packed bytes corresponding to its 16 Q values. - _stg_u32xN(q_fp4, (token_id, head_id, elem_base // 2), packed, 2, ".cs") + _stg_vec(q_fp4, (token_id, head_id, elem_base // 2), packed, 2, ".cs") # Weight scaling is independent of the Q subwarp work. The first # num_tokens * num_heads logical threads cover one weight each. @@ -415,40 +419,51 @@ def kernel( weights[weight_token_id, weight_head_id].to(Float32) * scale ) + @cache + @staticmethod + def compile( + head_dim: int, + rope_dim: int, + num_heads: int, + cos_sin_dtype: type[cutlass.Numeric], + ): + num_tokens = cute.sym_int() + max_pos = cute.sym_int() -@cache -def _compile_indexer_q_mxfp4( - head_dim: int, rope_dim: int, num_heads: int, cos_sin_dtype: type[cutlass.Numeric] -): - num_tokens = cute.sym_int() - max_pos = cute.sym_int() - - q = make_fake_tensor(BFloat16, (num_tokens, num_heads, head_dim), divisibility=8) - positions = make_fake_tensor(Int64, (num_tokens,), divisibility=1) - cos_sin_cache = make_fake_tensor(cos_sin_dtype, (max_pos, rope_dim), divisibility=8) - weights = make_fake_tensor(BFloat16, (num_tokens, num_heads), divisibility=8) - q_fp4 = make_fake_tensor( - Uint8, (num_tokens, num_heads, head_dim // 2), divisibility=16 - ) - q_scale = make_fake_tensor( - Uint8, - (num_tokens, num_heads, head_dim // MXFP4_BLOCK_SIZE), - divisibility=4, - ) - weights_out = make_fake_tensor(Float32, (num_tokens, num_heads), divisibility=4) + q = make_fake_tensor( + BFloat16, (num_tokens, num_heads, head_dim), divisibility=8 + ) + positions = make_fake_tensor(Int64, (num_tokens,), divisibility=1) + cos_sin_cache = make_fake_tensor( + cos_sin_dtype, + (max_pos, rope_dim), + divisibility=8, + ) + weights = make_fake_tensor(BFloat16, (num_tokens, num_heads), divisibility=8) + q_fp4 = make_fake_tensor( + Uint8, + (num_tokens, num_heads, head_dim // 2), + divisibility=16, + ) + q_scale = make_fake_tensor( + Uint8, + (num_tokens, num_heads, head_dim // MXFP4_BLOCK_SIZE), + divisibility=4, + ) + weights_out = make_fake_tensor(Float32, (num_tokens, num_heads), divisibility=4) - kernel = IndexerQMxFp4Kernel(head_dim, rope_dim, num_heads, cos_sin_dtype) - stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) - return cute.compile( - kernel, - positions, - q, - cos_sin_cache, - weights, - q_fp4, - q_scale, - weights_out, - Float32(0.0), - stream, - options="--enable-tvm-ffi", - ) + kernel = IndexerQMxFp4Kernel(head_dim, rope_dim, num_heads, cos_sin_dtype) + stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) + return cute.compile( + kernel, + positions, + q, + cos_sin_cache, + weights, + q_fp4, + q_scale, + weights_out, + Float32(0.0), + stream, + options="--enable-tvm-ffi", + ) From 90d6713ce9756c3a580b56d2ed8fa1685ef9805d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 6 May 2026 08:45:40 +0000 Subject: [PATCH 14/21] add thread coarsening Signed-off-by: Thien Tran --- .../fused_indexer_q_cutedsl.py | 196 +++++++++++------- 1 file changed, 116 insertions(+), 80 deletions(-) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py index c2e6399b77ff..6e66407ad92e 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py @@ -36,13 +36,15 @@ def fused_indexer_q_rope_quant_mxfp4_cutedsl( index_q_scale: torch.Tensor, index_weights_out: torch.Tensor, ) -> None: - num_index_q_heads = index_q.shape[1] - index_q_head_dim = index_q.shape[2] + num_tokens, num_heads, head_dim = index_q.shape + # heuristic + tile_head = 1 if num_tokens < 512 else 4 compiled = IndexerQMxFp4Kernel.compile( - index_q_head_dim, + head_dim, index_q_cos_sin_cache.shape[-1], - num_index_q_heads, + num_heads, _TORCH_TO_CUTE[index_q_cos_sin_cache.dtype], + tile_head, ) scale = float(index_weights_softmax_scale * index_weights_head_scale) compiled( @@ -246,7 +248,8 @@ def __init__( head_dim: int = 128, rope_dim: int = 64, num_heads: int = 64, - cos_sin_dtype: type[cutlass.Numeric] = cutlass.Float32, + cos_sin_dtype: type[cutlass.Numeric] = Float32, + tile_head: int = 4, ): self.head_dim = head_dim self.rope_dim = rope_dim @@ -254,11 +257,16 @@ def __init__( self.num_heads = num_heads self.cos_sin_dtype = cos_sin_dtype + # process multiple heads at the same time to armotize RoPE load costs + assert num_heads % tile_head == 0 + self.tile_head = tile_head + # later we will use 32B load = 16 BF16 elems # thus, head_dim=128 requires 8 threads to handle. # let's call subwarp = 8 threads. self.subwarp_size = head_dim // 16 - self.tb_size = 256 + self.tb_size = 128 + self.threads_per_token = (self.num_heads // self.tile_head) * self.subwarp_size @cute.jit def __call__( @@ -273,9 +281,8 @@ def __call__( scale: Float32, stream: CUstream, ): - num_tokens, num_heads, _ = q.shape - total_threads = num_tokens * num_heads * self.subwarp_size - grid = [cute.ceil_div(total_threads, self.tb_size), 1, 1] + total_threads = q.shape[0] * self.threads_per_token + grid = (cute.ceil_div(total_threads, self.tb_size), 1, 1) self.kernel( positions, q, @@ -285,7 +292,7 @@ def __call__( q_scale, weights_out, scale, - ).launch(grid=grid, block=[self.tb_size, 1, 1], stream=stream) + ).launch(grid=grid, block=(self.tb_size, 1, 1), stream=stream) @cute.kernel def kernel( @@ -300,35 +307,47 @@ def kernel( scale: Float32, ): block_id, _, _ = cute.arch.block_idx() - tidx, _, _ = cute.arch.thread_idx() + tid, _, _ = cute.arch.thread_idx() num_token_heads = q.shape[0] * self.num_heads - global_tid = block_id * self.tb_size + tidx + global_tid = block_id * self.tb_size + tid global_subwarp_id = global_tid // self.subwarp_size - sublane = tidx % self.subwarp_size + sublane = tid % self.subwarp_size + + token_id = global_subwarp_id // (self.num_heads // self.tile_head) + head_start = ( + global_subwarp_id % (self.num_heads // self.tile_head) + ) * self.tile_head - token_id = global_subwarp_id // self.num_heads - head_id = global_subwarp_id - token_id * self.num_heads + # NOTE: token_id may exceed bounds, hence we need to add load/store guards + # we can't do early exit because CuteDSL doesn't support it. and we also need + # all threads in a warp to be active since we utilize warp shuffle later. + # must_in_bounds is constexpr, True when 1 threadblock fit within 1 token + # position. the compiler will remove bounds check when that happens. + must_in_bounds = cutlass.const_expr(self.tb_size % self.threads_per_token == 0) + in_bounds = must_in_bounds or (token_id < q.shape[0]) # each thread loads 16 BF16 elems elem_base = sublane * 16 - # q layout: [num_tokens, num_heads, head_dim] - _q_bf16x2 = _ldg_vec( - q, - (token_id, head_id, elem_base), - 8, - ".relaxed.cta.L1::no_allocate", - ld_type=Uint32, - ) - q_bf16x2 = cute.make_rmem_tensor(8, Uint32) - q_bf16x2.store(_q_bf16x2) # copy to make it mutable + q_bf16x2 = cute.make_rmem_tensor((self.tile_head, 8), Uint32) + if in_bounds: + for i in cutlass.range_constexpr(self.tile_head): + # q layout: [num_tokens, num_heads, head_dim] + _q_bf16x2 = _ldg_vec( + q, + (token_id, head_start + i, elem_base), + 8, + ".relaxed.cta.L1::no_allocate", + ld_type=Uint32, + ) + q_bf16x2[i, None].store(_q_bf16x2) # copy to make it mutable # RoPE applies only to the trailing rope_dim values. We keep the rounded # BF16 result in q_bits so the later amax and quantization see BF16. # cos_sin_cache layout: [max_pos, rope_dim] - if elem_base >= self.nope_dim: + if in_bounds and elem_base >= self.nope_dim: pos = positions[token_id] rope_idx = (elem_base - self.nope_dim) // 2 if const_expr(self.cos_sin_dtype is Float32): @@ -361,60 +380,74 @@ def kernel( sin_bf16x2[i] ) - for i in cutlass.range_constexpr(8): - q0, q1 = _bf16x2_to_fp32(q_bf16x2[i]) - rot0 = q0 * cos_vals[i] - q1 * sin_vals[i] - rot1 = q0 * sin_vals[i] + q1 * cos_vals[i] - # convert back to BF16 to match numerics - q_bf16x2[i] = _fp32x2_to_bf16x2(rot0, rot1) - - # compute amax in packed bf16x2 to save instructions - # Each thread holds 16 elems. Two adjacent threads form one 32-elem - # MXFP4 block, so a width-2 shuffle gives the block amax. - local_amax = _bf16x2_abs(q_bf16x2[0]) - for i in cutlass.range_constexpr(1, 8): - local_amax = _bf16x2_max(local_amax, _bf16x2_abs(q_bf16x2[i])) - amax_bits = cute_utils.warp_reduce( - local_amax, _bf16x2_max, width=MXFP4_BLOCK_SIZE // 16 - ) - amax0, amax1 = _bf16x2_to_fp32(amax_bits) - amax = cute_utils.fmax(amax0, amax1) - - # compute block scale with bit manipulation - # UE8M0 stores ceil(log2(fp4_scale)) + 127. Adding the mantissa mask - # increments the exponent whenever fp4_scale is not exactly a power of 2. - fp4_scale = cute_utils.fmax(amax, float.fromhex("0x6p-126")) * (1.0 / 6.0) - bits = Uint32(llvm.bitcast(T.i32(), fp4_scale.ir_value())) - ue8m0 = cute_utils.shr_u32(bits + Uint32(0x7FFFFF), Uint32(23)) & Uint32(0xFF) - - # Only one of the two threads in an MXFP4 block writes the shared scale. - if tidx % 2 == 0: - mx_block = sublane // (MXFP4_BLOCK_SIZE // 16) - q_scale[token_id, head_id, mx_block] = Uint8(ue8m0) - - # If scale = 2^A and ue8m0 = A + 127, then inverse scale has exponent - # -A + 127 = 254 - ue8m0. - inv_scale_bits = (Uint32(254) - ue8m0) << Uint32(23) - inv_fp4_scale = Float32(llvm.bitcast(T.f32(), inv_scale_bits.ir_value())) - - vals = cute.make_rmem_tensor(16, Float32) - for i in cutlass.range_constexpr(8): - vals[i * 2], vals[i * 2 + 1] = _bf16x2_to_fp32(q_bf16x2[i]) - vals[i * 2] = vals[i * 2] * inv_fp4_scale - vals[i * 2 + 1] = vals[i * 2 + 1] * inv_fp4_scale - - # pack to FP4 - packed = cute.make_rmem_tensor(2, Uint32) - packed[0] = _fp32x8_to_fp4x8(vals, 0) - packed[1] = _fp32x8_to_fp4x8(vals, 8) - # Each thread writes the eight packed bytes corresponding to its 16 Q values. - _stg_vec(q_fp4, (token_id, head_id, elem_base // 2), packed, 2, ".cs") + for i in cutlass.range_constexpr(self.tile_head): + for j in cutlass.range_constexpr(8): + q0, q1 = _bf16x2_to_fp32(q_bf16x2[i, j]) + rot0 = q0 * cos_vals[j] - q1 * sin_vals[j] + rot1 = q0 * sin_vals[j] + q1 * cos_vals[j] + # convert back to BF16 to match numerics + q_bf16x2[i, j] = _fp32x2_to_bf16x2(rot0, rot1) + + for i in cutlass.range_constexpr(self.tile_head): + # compute amax in packed bf16x2 to save instructions + # Each thread holds 16 elems. Two adjacent threads form one 32-elem + # MXFP4 block, so a width-2 shuffle gives the block amax. + amax_bf16x2 = _bf16x2_abs(q_bf16x2[i, 0]) + for j in cutlass.range_constexpr(1, 8): + amax_bf16x2 = _bf16x2_max(amax_bf16x2, _bf16x2_abs(q_bf16x2[i, j])) + amax_bf16x2 = cute_utils.warp_reduce( + amax_bf16x2, _bf16x2_max, width=MXFP4_BLOCK_SIZE // 16 + ) + amax0, amax1 = _bf16x2_to_fp32(amax_bf16x2) + amax = cute_utils.fmax(amax0, amax1) + + if in_bounds: + # compute block scale with bit manipulation + # UE8M0 stores ceil(log2(fp4_scale)) + 127. Adding the mantissa mask + # increments the exponent whenever fp4_scale is not exactly a power of 2 + fp4_scale = cute_utils.fmax(amax, float.fromhex("0x6p-126")) * Float32( + 1.0 / 6.0 + ) + bits = Uint32(llvm.bitcast(T.i32(), fp4_scale.ir_value())) + ue8m0 = cute_utils.shr_u32( + bits + Uint32(0x7FFFFF), Uint32(23) + ) & Uint32(0xFF) + + # Only one of the two threads in an MXFP4 block writes the shared scale. + if tid % 2 == 0: + mx_block = sublane // 2 + q_scale[token_id, head_start + i, mx_block] = Uint8(ue8m0) + + # If scale = 2^A and ue8m0 = A + 127, then inverse scale has exponent + # -A + 127 = 254 - ue8m0. + inv_scale_bits = (Uint32(254) - ue8m0) << Uint32(23) + inv_fp4_scale = Float32( + llvm.bitcast(T.f32(), inv_scale_bits.ir_value()) + ) + + vals = cute.make_rmem_tensor(16, Float32) + for j in cutlass.range_constexpr(8): + vals[j * 2], vals[j * 2 + 1] = _bf16x2_to_fp32(q_bf16x2[i, j]) + vals[j * 2] = vals[j * 2] * inv_fp4_scale + vals[j * 2 + 1] = vals[j * 2 + 1] * inv_fp4_scale + + # pack to FP4 + packed = cute.make_rmem_tensor(2, Uint32) + packed[0] = _fp32x8_to_fp4x8(vals, 0) + packed[1] = _fp32x8_to_fp4x8(vals, 8) + _stg_vec( + q_fp4, + (token_id, head_start + i, elem_base // 2), + packed, + 2, + modifier=".cs", + ) # Weight scaling is independent of the Q subwarp work. The first # num_tokens * num_heads logical threads cover one weight each. if global_tid < num_token_heads: weight_token_id = global_tid // self.num_heads - weight_head_id = global_tid - weight_token_id * self.num_heads + weight_head_id = global_tid % self.num_heads weights_out[weight_token_id, weight_head_id] = ( weights[weight_token_id, weight_head_id].to(Float32) * scale ) @@ -422,10 +455,11 @@ def kernel( @cache @staticmethod def compile( - head_dim: int, - rope_dim: int, - num_heads: int, - cos_sin_dtype: type[cutlass.Numeric], + head_dim: int = 128, + rope_dim: int = 64, + num_heads: int = 64, + cos_sin_dtype: type[cutlass.Numeric] = Float32, + tile_head: int = 4, ): num_tokens = cute.sym_int() max_pos = cute.sym_int() @@ -452,7 +486,9 @@ def compile( ) weights_out = make_fake_tensor(Float32, (num_tokens, num_heads), divisibility=4) - kernel = IndexerQMxFp4Kernel(head_dim, rope_dim, num_heads, cos_sin_dtype) + kernel = IndexerQMxFp4Kernel( + head_dim, rope_dim, num_heads, cos_sin_dtype, tile_head + ) stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) return cute.compile( kernel, From e764a1f25b20be40fe508096310382f2b5bd1603 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 6 May 2026 10:35:49 +0000 Subject: [PATCH 15/21] format Signed-off-by: Thien Tran --- .../fused_indexer_q_cutedsl.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py index 6e66407ad92e..4a8dc46ace36 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py @@ -335,13 +335,9 @@ def kernel( if in_bounds: for i in cutlass.range_constexpr(self.tile_head): # q layout: [num_tokens, num_heads, head_dim] - _q_bf16x2 = _ldg_vec( - q, - (token_id, head_start + i, elem_base), - 8, - ".relaxed.cta.L1::no_allocate", - ld_type=Uint32, - ) + coord = (token_id, head_start + i, elem_base) + cache_mod = ".relaxed.cta.L1::no_allocate" + _q_bf16x2 = _ldg_vec(q, coord, 8, cache_mod, ld_type=Uint32) q_bf16x2[i, None].store(_q_bf16x2) # copy to make it mutable # RoPE applies only to the trailing rope_dim values. We keep the rounded @@ -435,13 +431,8 @@ def kernel( packed = cute.make_rmem_tensor(2, Uint32) packed[0] = _fp32x8_to_fp4x8(vals, 0) packed[1] = _fp32x8_to_fp4x8(vals, 8) - _stg_vec( - q_fp4, - (token_id, head_start + i, elem_base // 2), - packed, - 2, - modifier=".cs", - ) + coord = (token_id, head_start + i, elem_base // 2) + _stg_vec(q_fp4, coord, packed, 2, modifier=".cs") # Weight scaling is independent of the Q subwarp work. The first # num_tokens * num_heads logical threads cover one weight each. From 86f442534aa9e4980319ad6a6a3312155495f29c Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Wed, 6 May 2026 10:58:05 +0000 Subject: [PATCH 16/21] compile all variants at 1st invocation Signed-off-by: Thien Tran --- .../ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py index 4a8dc46ace36..5bdc92089d94 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py @@ -37,14 +37,17 @@ def fused_indexer_q_rope_quant_mxfp4_cutedsl( index_weights_out: torch.Tensor, ) -> None: num_tokens, num_heads, head_dim = index_q.shape + rope_dim = index_q_cos_sin_cache.shape[-1] + rope_type = _TORCH_TO_CUTE[index_q_cos_sin_cache.dtype] + + # compile all variants at first invocation + for tile_head in (1, 4): + IndexerQMxFp4Kernel.compile(head_dim, rope_dim, num_heads, rope_type, tile_head) + # heuristic tile_head = 1 if num_tokens < 512 else 4 compiled = IndexerQMxFp4Kernel.compile( - head_dim, - index_q_cos_sin_cache.shape[-1], - num_heads, - _TORCH_TO_CUTE[index_q_cos_sin_cache.dtype], - tile_head, + head_dim, rope_dim, num_heads, rope_type, tile_head ) scale = float(index_weights_softmax_scale * index_weights_head_scale) compiled( From a1533c10a270c9bd25db47941df75cc11c63335a Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 8 May 2026 06:58:31 +0000 Subject: [PATCH 17/21] eliminate custom loads and stores Signed-off-by: Thien Tran --- .../fused_indexer_q_cutedsl.py | 223 ++++++------------ 1 file changed, 68 insertions(+), 155 deletions(-) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py index 5bdc92089d94..4371da50119b 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py @@ -9,8 +9,7 @@ import torch from cuda.bindings.driver import CUstream from cutlass import BFloat16, Float32, Int64, Uint8, Uint32, const_expr -from cutlass._mlir import ir -from cutlass._mlir.dialects import llvm, vector +from cutlass._mlir.dialects import llvm from cutlass.cutlass_dsl import T, dsl_user_op from quack.compile_utils import make_fake_tensor @@ -145,104 +144,6 @@ def _fp32x8_to_fp4x8( return Uint32(out) -# Custom vectorized load to support cache modifiers. For some reason, -# cute.autovec_copy() does not currently emit the requested modifiers. -# tensor and coord is only used to select the base pointer. actual load -# is done using out_dtype -@dsl_user_op -def _ldg_vec( - tensor: cute.Tensor, - coord: cute.Coord, - vec_size: cutlass.Constexpr[int], - modifier: cutlass.Constexpr[str] = "", - ld_type: cutlass.Constexpr[type[cutlass.Numeric] | None] = None, - *, - loc=None, - ip=None, -) -> cute.TensorSSA: - if ld_type is None: - ld_type = tensor.element_type - if const_expr(ld_type is Float32): - ptx_ty = "f32" - constraint = "=f" - elif const_expr(ld_type is Uint32): - ptx_ty = "u32" - constraint = "=r" - else: - raise TypeError(f"_ldg_vec only supports Uint32 and Float32, got {ld_type}") - - # compute base pointer - base_ptr = ( - tensor.iterator + cute.crd2idx(coord, tensor.layout, loc=loc, ip=ip) - ).toint() - - # build PTX string - ptx_str = f"ld.global{modifier}.v{vec_size}.{ptx_ty}" - ptx_str += "{" + ", ".join(f"${i}" for i in range(vec_size)) + "}" - ptx_str += f", [${vec_size}];" - - out = llvm.inline_asm( - llvm.StructType.get_literal([ld_type.mlir_type] * vec_size), - [Int64(base_ptr).ir_value(loc=loc, ip=ip)], - ptx_str, - ",".join([constraint] * vec_size + ["l"]), - has_side_effects=False, - is_align_stack=False, - ) - vec = vector.from_elements( - ir.VectorType.get([vec_size], ld_type.mlir_type, loc=loc), - [ - llvm.extractvalue(ld_type.mlir_type, out, [i], loc=loc, ip=ip) - for i in range(vec_size) - ], - loc=loc, - ip=ip, - ) - return cute.TensorSSA(vec, vec_size, ld_type) - - -@dsl_user_op -def _stg_vec( - tensor: cute.Tensor, - coord: cute.Coord, - values: cute.Tensor, - vec_size: cutlass.Constexpr[int], - modifier: cutlass.Constexpr[str] = "", - *, - loc=None, - ip=None, -) -> None: - # NOTE: st_type is derived from values tensor - st_type = values.element_type - if const_expr(st_type is Float32): - ptx_ty = "f32" - constraint = "f" - elif const_expr(st_type is Uint32): - ptx_ty = "u32" - constraint = "r" - else: - raise TypeError(f"_stg_vec only supports Uint32 and Float32, got {st_type}") - - # compute base pointer - base_ptr = ( - tensor.iterator + cute.crd2idx(coord, tensor.layout, loc=loc, ip=ip) - ).toint() - - # build PTX string - ptx_str = f"st.global{modifier}.v{vec_size}.{ptx_ty} [$0], " - ptx_str += "{" + ", ".join(f"${i + 1}" for i in range(vec_size)) + "};" - - llvm.inline_asm( - None, - [Int64(base_ptr).ir_value(loc=loc, ip=ip)] - + [values[i].ir_value(loc=loc, ip=ip) for i in range(vec_size)], - ptx_str, - ",".join(["l"] + [constraint] * vec_size), - has_side_effects=True, - is_align_stack=False, - ) - - class IndexerQMxFp4Kernel: """Eight-thread subwarps process one ``(token, head)`` row.""" @@ -252,7 +153,7 @@ def __init__( rope_dim: int = 64, num_heads: int = 64, cos_sin_dtype: type[cutlass.Numeric] = Float32, - tile_head: int = 4, + coarsen: int = 4, ): self.head_dim = head_dim self.rope_dim = rope_dim @@ -261,15 +162,15 @@ def __init__( self.cos_sin_dtype = cos_sin_dtype # process multiple heads at the same time to armotize RoPE load costs - assert num_heads % tile_head == 0 - self.tile_head = tile_head + assert num_heads % coarsen == 0 + self.coarsen = coarsen # later we will use 32B load = 16 BF16 elems # thus, head_dim=128 requires 8 threads to handle. # let's call subwarp = 8 threads. self.subwarp_size = head_dim // 16 self.tb_size = 128 - self.threads_per_token = (self.num_heads // self.tile_head) * self.subwarp_size + self.threads_per_token = (self.num_heads // self.coarsen) * self.subwarp_size @cute.jit def __call__( @@ -318,10 +219,9 @@ def kernel( global_subwarp_id = global_tid // self.subwarp_size sublane = tid % self.subwarp_size - token_id = global_subwarp_id // (self.num_heads // self.tile_head) - head_start = ( - global_subwarp_id % (self.num_heads // self.tile_head) - ) * self.tile_head + token_id = global_subwarp_id // (self.num_heads // self.coarsen) + head_tile_id = global_subwarp_id % (self.num_heads // self.coarsen) + head_start = head_tile_id * self.coarsen # NOTE: token_id may exceed bounds, hence we need to add load/store guards # we can't do early exit because CuteDSL doesn't support it. and we also need @@ -331,55 +231,60 @@ def kernel( must_in_bounds = cutlass.const_expr(self.tb_size % self.threads_per_token == 0) in_bounds = must_in_bounds or (token_id < q.shape[0]) - # each thread loads 16 BF16 elems - elem_base = sublane * 16 + cp_op = cute.nvgpu.CopyUniversalOp() + cp_u32x2 = cute.make_copy_atom(cp_op, Uint32, num_bits_per_copy=64) + cp_u32x4 = cute.make_copy_atom(cp_op, Uint32, num_bits_per_copy=128) + cp_u32x8 = cute.make_copy_atom(cp_op, Uint32, num_bits_per_copy=256) + cp_f32x8 = cute.make_copy_atom(cp_op, Float32, num_bits_per_copy=256) + + _layout = cute.make_layout((self.coarsen, 8), stride=(8, 1)) + q_bf16x2 = cute.make_rmem_tensor(_layout, Uint32) - q_bf16x2 = cute.make_rmem_tensor((self.tile_head, 8), Uint32) if in_bounds: - for i in cutlass.range_constexpr(self.tile_head): - # q layout: [num_tokens, num_heads, head_dim] - coord = (token_id, head_start + i, elem_base) - cache_mod = ".relaxed.cta.L1::no_allocate" - _q_bf16x2 = _ldg_vec(q, coord, 8, cache_mod, ld_type=Uint32) - q_bf16x2[i, None].store(_q_bf16x2) # copy to make it mutable + src = cute.local_tile( + q[token_id, None, None], + tiler=(self.coarsen, 16), + coord=(head_tile_id, sublane), + ) + cute.copy(cp_u32x8, cute.recast_tensor(src, Uint32), q_bf16x2) # RoPE applies only to the trailing rope_dim values. We keep the rounded # BF16 result in q_bits so the later amax and quantization see BF16. # cos_sin_cache layout: [max_pos, rope_dim] - if in_bounds and elem_base >= self.nope_dim: + if in_bounds and sublane * 16 >= self.nope_dim: + cos_vals = cute.make_rmem_tensor((8,), Float32) + sin_vals = cute.make_rmem_tensor((8,), Float32) + pos = positions[token_id] - rope_idx = (elem_base - self.nope_dim) // 2 + + # select 8 elems from cos and sin + cos_id = sublane - self.nope_dim // 16 + sin_id = cos_id + self.rope_dim // 16 + cos_src = cute.local_tile( + cos_sin_cache[pos, None], tiler=(8,), coord=(cos_id,) + ) + sin_src = cute.local_tile( + cos_sin_cache[pos, None], tiler=(8,), coord=(sin_id,) + ) + if const_expr(self.cos_sin_dtype is Float32): - # fp32x8 loads - cos_vals = _ldg_vec(cos_sin_cache, (pos, rope_idx), 8) - sin_vals = _ldg_vec( - cos_sin_cache, (pos, self.rope_dim // 2 + rope_idx), 8 - ) + cute.copy(cp_f32x8, cos_src, cos_vals) + cute.copy(cp_f32x8, sin_src, sin_vals) else: - # bf16x2x4 loads - cos_bf16x2 = _ldg_vec( - cos_sin_cache, - (pos, rope_idx), - 4, - ld_type=Uint32, - ) - sin_bf16x2 = _ldg_vec( - cos_sin_cache, - (pos, self.rope_dim // 2 + rope_idx), - 4, - ld_type=Uint32, - ) - cos_vals = cute.make_rmem_tensor(8, Float32) - sin_vals = cute.make_rmem_tensor(8, Float32) + cos_bf16x2 = cute.make_rmem_tensor((4,), Uint32) + sin_bf16x2 = cute.make_rmem_tensor((4,), Uint32) + cute.copy(cp_u32x4, cute.recast_tensor(cos_src, Uint32), cos_bf16x2) + cute.copy(cp_u32x4, cute.recast_tensor(sin_src, Uint32), sin_bf16x2) + for i in cutlass.range_constexpr(4): - cos_vals[i * 2], cos_vals[i * 2 + 1] = _bf16x2_to_fp32( - cos_bf16x2[i] - ) - sin_vals[i * 2], sin_vals[i * 2 + 1] = _bf16x2_to_fp32( - sin_bf16x2[i] - ) - - for i in cutlass.range_constexpr(self.tile_head): + tmp_cos = _bf16x2_to_fp32(cos_bf16x2[i]) + tmp_sin = _bf16x2_to_fp32(sin_bf16x2[i]) + cos_vals[i * 2] = tmp_cos[0] + cos_vals[i * 2 + 1] = tmp_cos[1] + sin_vals[i * 2] = tmp_sin[0] + sin_vals[i * 2 + 1] = tmp_sin[1] + + for i in cutlass.range_constexpr(self.coarsen): for j in cutlass.range_constexpr(8): q0, q1 = _bf16x2_to_fp32(q_bf16x2[i, j]) rot0 = q0 * cos_vals[j] - q1 * sin_vals[j] @@ -387,7 +292,14 @@ def kernel( # convert back to BF16 to match numerics q_bf16x2[i, j] = _fp32x2_to_bf16x2(rot0, rot1) - for i in cutlass.range_constexpr(self.tile_head): + # layout: [coarsen, 8] + q_fp4_tile = cute.local_tile( + q_fp4[token_id, None, None], + tiler=(self.coarsen, 8), + coord=(head_tile_id, sublane), + ) + + for i in cutlass.range_constexpr(self.coarsen): # compute amax in packed bf16x2 to save instructions # Each thread holds 16 elems. Two adjacent threads form one 32-elem # MXFP4 block, so a width-2 shuffle gives the block amax. @@ -426,16 +338,17 @@ def kernel( vals = cute.make_rmem_tensor(16, Float32) for j in cutlass.range_constexpr(8): - vals[j * 2], vals[j * 2 + 1] = _bf16x2_to_fp32(q_bf16x2[i, j]) - vals[j * 2] = vals[j * 2] * inv_fp4_scale - vals[j * 2 + 1] = vals[j * 2 + 1] * inv_fp4_scale + tmp = _bf16x2_to_fp32(q_bf16x2[i, j]) + vals[j * 2] = tmp[0] * inv_fp4_scale + vals[j * 2 + 1] = tmp[1] * inv_fp4_scale # pack to FP4 - packed = cute.make_rmem_tensor(2, Uint32) + packed = cute.make_rmem_tensor((2,), Uint32) packed[0] = _fp32x8_to_fp4x8(vals, 0) packed[1] = _fp32x8_to_fp4x8(vals, 8) - coord = (token_id, head_start + i, elem_base // 2) - _stg_vec(q_fp4, coord, packed, 2, modifier=".cs") + + dst = q_fp4_tile[i, None] + cute.copy(cp_u32x2, packed, cute.recast_tensor(dst, Uint32)) # Weight scaling is independent of the Q subwarp work. The first # num_tokens * num_heads logical threads cover one weight each. @@ -459,7 +372,7 @@ def compile( max_pos = cute.sym_int() q = make_fake_tensor( - BFloat16, (num_tokens, num_heads, head_dim), divisibility=8 + BFloat16, (num_tokens, num_heads, head_dim), divisibility=16 ) positions = make_fake_tensor(Int64, (num_tokens,), divisibility=1) cos_sin_cache = make_fake_tensor( From c0b5c69318de8184fc4fcbc6fa967bda95944fc1 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 8 May 2026 07:41:30 +0000 Subject: [PATCH 18/21] _bf16x2_to_fp32 returns TensorSSA. cleanup Signed-off-by: Thien Tran --- .../fused_indexer_q_cutedsl.py | 54 +++++++++++-------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py index 4371da50119b..ffeb26893d38 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py @@ -9,8 +9,8 @@ import torch from cuda.bindings.driver import CUstream from cutlass import BFloat16, Float32, Int64, Uint8, Uint32, const_expr -from cutlass._mlir.dialects import llvm -from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm, vector +from cutlass.cutlass_dsl import T, dsl_user_op, ir from quack.compile_utils import make_fake_tensor from vllm.vllm_flash_attn.cute import utils as cute_utils @@ -75,7 +75,7 @@ def _fp32x2_to_bf16x2(a: Float32, b: Float32, *, loc=None, ip=None) -> Uint32: @dsl_user_op -def _bf16x2_to_fp32(data: Uint32, *, loc=None, ip=None) -> tuple[Float32, Float32]: +def _bf16x2_to_fp32(data: Uint32, *, loc=None, ip=None) -> cute.TensorSSA: out = llvm.inline_asm( llvm.StructType.get_literal([T.f32(), T.f32()]), [data.ir_value(loc=loc, ip=ip)], @@ -84,10 +84,16 @@ def _bf16x2_to_fp32(data: Uint32, *, loc=None, ip=None) -> tuple[Float32, Float3 has_side_effects=False, is_align_stack=False, ) - return ( - Float32(llvm.extractvalue(T.f32(), out, [0], loc=loc, ip=ip)), - Float32(llvm.extractvalue(T.f32(), out, [1], loc=loc, ip=ip)), + vec = vector.from_elements( + ir.VectorType.get([2], T.f32(), loc=loc), + [ + llvm.extractvalue(T.f32(), out, [0], loc=loc, ip=ip), + llvm.extractvalue(T.f32(), out, [1], loc=loc, ip=ip), + ], + loc=loc, + ip=ip, ) + return cute.TensorSSA(vec, 2, Float32) @dsl_user_op @@ -232,10 +238,6 @@ def kernel( in_bounds = must_in_bounds or (token_id < q.shape[0]) cp_op = cute.nvgpu.CopyUniversalOp() - cp_u32x2 = cute.make_copy_atom(cp_op, Uint32, num_bits_per_copy=64) - cp_u32x4 = cute.make_copy_atom(cp_op, Uint32, num_bits_per_copy=128) - cp_u32x8 = cute.make_copy_atom(cp_op, Uint32, num_bits_per_copy=256) - cp_f32x8 = cute.make_copy_atom(cp_op, Float32, num_bits_per_copy=256) _layout = cute.make_layout((self.coarsen, 8), stride=(8, 1)) q_bf16x2 = cute.make_rmem_tensor(_layout, Uint32) @@ -246,6 +248,7 @@ def kernel( tiler=(self.coarsen, 16), coord=(head_tile_id, sublane), ) + cp_u32x8 = cute.make_copy_atom(cp_op, Uint32, num_bits_per_copy=256) cute.copy(cp_u32x8, cute.recast_tensor(src, Uint32), q_bf16x2) # RoPE applies only to the trailing rope_dim values. We keep the rounded @@ -267,6 +270,9 @@ def kernel( cos_sin_cache[pos, None], tiler=(8,), coord=(sin_id,) ) + cp_f32x8 = cute.make_copy_atom(cp_op, Float32, num_bits_per_copy=256) + cp_u32x4 = cute.make_copy_atom(cp_op, Uint32, num_bits_per_copy=128) + if const_expr(self.cos_sin_dtype is Float32): cute.copy(cp_f32x8, cos_src, cos_vals) cute.copy(cp_f32x8, sin_src, sin_vals) @@ -276,19 +282,18 @@ def kernel( cute.copy(cp_u32x4, cute.recast_tensor(cos_src, Uint32), cos_bf16x2) cute.copy(cp_u32x4, cute.recast_tensor(sin_src, Uint32), sin_bf16x2) + cos_vals_view = cute.logical_divide(cos_vals, 2) # (2, 4) + sin_vals_view = cute.logical_divide(sin_vals, 2) + for i in cutlass.range_constexpr(4): - tmp_cos = _bf16x2_to_fp32(cos_bf16x2[i]) - tmp_sin = _bf16x2_to_fp32(sin_bf16x2[i]) - cos_vals[i * 2] = tmp_cos[0] - cos_vals[i * 2 + 1] = tmp_cos[1] - sin_vals[i * 2] = tmp_sin[0] - sin_vals[i * 2 + 1] = tmp_sin[1] + cos_vals_view[None, i].store(_bf16x2_to_fp32(cos_bf16x2[i])) + sin_vals_view[None, i].store(_bf16x2_to_fp32(sin_bf16x2[i])) for i in cutlass.range_constexpr(self.coarsen): for j in cutlass.range_constexpr(8): - q0, q1 = _bf16x2_to_fp32(q_bf16x2[i, j]) - rot0 = q0 * cos_vals[j] - q1 * sin_vals[j] - rot1 = q0 * sin_vals[j] + q1 * cos_vals[j] + q_pair = _bf16x2_to_fp32(q_bf16x2[i, j]) + rot0 = q_pair[0] * cos_vals[j] - q_pair[1] * sin_vals[j] + rot1 = q_pair[0] * sin_vals[j] + q_pair[1] * cos_vals[j] # convert back to BF16 to match numerics q_bf16x2[i, j] = _fp32x2_to_bf16x2(rot0, rot1) @@ -307,10 +312,12 @@ def kernel( for j in cutlass.range_constexpr(1, 8): amax_bf16x2 = _bf16x2_max(amax_bf16x2, _bf16x2_abs(q_bf16x2[i, j])) amax_bf16x2 = cute_utils.warp_reduce( - amax_bf16x2, _bf16x2_max, width=MXFP4_BLOCK_SIZE // 16 + amax_bf16x2, + _bf16x2_max, + width=MXFP4_BLOCK_SIZE // 16, ) - amax0, amax1 = _bf16x2_to_fp32(amax_bf16x2) - amax = cute_utils.fmax(amax0, amax1) + amax_pair = _bf16x2_to_fp32(amax_bf16x2) + amax = cute_utils.fmax(amax_pair[0], amax_pair[1]) if in_bounds: # compute block scale with bit manipulation @@ -348,11 +355,12 @@ def kernel( packed[1] = _fp32x8_to_fp4x8(vals, 8) dst = q_fp4_tile[i, None] + cp_u32x2 = cute.make_copy_atom(cp_op, Uint32, num_bits_per_copy=64) cute.copy(cp_u32x2, packed, cute.recast_tensor(dst, Uint32)) # Weight scaling is independent of the Q subwarp work. The first # num_tokens * num_heads logical threads cover one weight each. - if global_tid < num_token_heads: + if global_tid * 8 < num_token_heads: weight_token_id = global_tid // self.num_heads weight_head_id = global_tid % self.num_heads weights_out[weight_token_id, weight_head_id] = ( From bbf1ff4390cd9d2d4931c6ab020feb23e574f4ca Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 8 May 2026 08:20:58 +0000 Subject: [PATCH 19/21] fix bug with recast_ptr Signed-off-by: Thien Tran --- .../test_fused_indexer_q_rope_quant.py | 2 +- .../fused_indexer_q_cutedsl.py | 79 ++++++++++--------- 2 files changed, 41 insertions(+), 40 deletions(-) diff --git a/tests/kernels/test_fused_indexer_q_rope_quant.py b/tests/kernels/test_fused_indexer_q_rope_quant.py index be2039ce513e..41a4d0ed0905 100644 --- a/tests/kernels/test_fused_indexer_q_rope_quant.py +++ b/tests/kernels/test_fused_indexer_q_rope_quant.py @@ -122,7 +122,7 @@ def _reference( return q_fp8, weights_out -@pytest.mark.parametrize("num_tokens", [1, 7, 32, 257]) +@pytest.mark.parametrize("num_tokens", [1, 7, 32, 257, 1023]) @pytest.mark.parametrize("cache_dtype", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize("use_fp4", [False, True]) @torch.inference_mode() diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py index ffeb26893d38..7fbbfe1206b2 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py @@ -9,8 +9,8 @@ import torch from cuda.bindings.driver import CUstream from cutlass import BFloat16, Float32, Int64, Uint8, Uint32, const_expr -from cutlass._mlir.dialects import llvm, vector -from cutlass.cutlass_dsl import T, dsl_user_op, ir +from cutlass._mlir.dialects import llvm +from cutlass.cutlass_dsl import T, dsl_user_op from quack.compile_utils import make_fake_tensor from vllm.vllm_flash_attn.cute import utils as cute_utils @@ -40,13 +40,13 @@ def fused_indexer_q_rope_quant_mxfp4_cutedsl( rope_type = _TORCH_TO_CUTE[index_q_cos_sin_cache.dtype] # compile all variants at first invocation - for tile_head in (1, 4): - IndexerQMxFp4Kernel.compile(head_dim, rope_dim, num_heads, rope_type, tile_head) + for coarsen in (1, 4): + IndexerQMxFp4Kernel.compile(head_dim, rope_dim, num_heads, rope_type, coarsen) # heuristic - tile_head = 1 if num_tokens < 512 else 4 + coarsen = 1 if num_tokens < 256 else 4 compiled = IndexerQMxFp4Kernel.compile( - head_dim, rope_dim, num_heads, rope_type, tile_head + head_dim, rope_dim, num_heads, rope_type, coarsen ) scale = float(index_weights_softmax_scale * index_weights_head_scale) compiled( @@ -61,6 +61,11 @@ def fused_indexer_q_rope_quant_mxfp4_cutedsl( ) +@dsl_user_op +def _recast_val(x, dtype, *, loc=None, ip=None): + return dtype(llvm.bitcast(dtype.mlir_type, x.ir_value(loc=loc, ip=ip))) + + @dsl_user_op def _fp32x2_to_bf16x2(a: Float32, b: Float32, *, loc=None, ip=None) -> Uint32: out = llvm.inline_asm( @@ -75,7 +80,7 @@ def _fp32x2_to_bf16x2(a: Float32, b: Float32, *, loc=None, ip=None) -> Uint32: @dsl_user_op -def _bf16x2_to_fp32(data: Uint32, *, loc=None, ip=None) -> cute.TensorSSA: +def _bf16x2_to_fp32(data: Uint32, *, loc=None, ip=None) -> tuple[Float32, Float32]: out = llvm.inline_asm( llvm.StructType.get_literal([T.f32(), T.f32()]), [data.ir_value(loc=loc, ip=ip)], @@ -84,16 +89,10 @@ def _bf16x2_to_fp32(data: Uint32, *, loc=None, ip=None) -> cute.TensorSSA: has_side_effects=False, is_align_stack=False, ) - vec = vector.from_elements( - ir.VectorType.get([2], T.f32(), loc=loc), - [ - llvm.extractvalue(T.f32(), out, [0], loc=loc, ip=ip), - llvm.extractvalue(T.f32(), out, [1], loc=loc, ip=ip), - ], - loc=loc, - ip=ip, + return ( + Float32(llvm.extractvalue(T.f32(), out, [0], loc=loc, ip=ip)), + Float32(llvm.extractvalue(T.f32(), out, [1], loc=loc, ip=ip)), ) - return cute.TensorSSA(vec, 2, Float32) @dsl_user_op @@ -243,13 +242,17 @@ def kernel( q_bf16x2 = cute.make_rmem_tensor(_layout, Uint32) if in_bounds: - src = cute.local_tile( + # we can't do cute.copy on the whole 2D tile directly because + # cute.recast_tensor() is wrong for 2D strided layout :( + q_tile = cute.local_tile( q[token_id, None, None], tiler=(self.coarsen, 16), coord=(head_tile_id, sublane), ) cp_u32x8 = cute.make_copy_atom(cp_op, Uint32, num_bits_per_copy=256) - cute.copy(cp_u32x8, cute.recast_tensor(src, Uint32), q_bf16x2) + for i in cutlass.range_constexpr(self.coarsen): + src = cute.recast_tensor(q_tile[i, None], Uint32) + cute.copy(cp_u32x8, src, q_bf16x2[i, None]) # RoPE applies only to the trailing rope_dim values. We keep the rounded # BF16 result in q_bits so the later amax and quantization see BF16. @@ -282,18 +285,19 @@ def kernel( cute.copy(cp_u32x4, cute.recast_tensor(cos_src, Uint32), cos_bf16x2) cute.copy(cp_u32x4, cute.recast_tensor(sin_src, Uint32), sin_bf16x2) - cos_vals_view = cute.logical_divide(cos_vals, 2) # (2, 4) - sin_vals_view = cute.logical_divide(sin_vals, 2) - for i in cutlass.range_constexpr(4): - cos_vals_view[None, i].store(_bf16x2_to_fp32(cos_bf16x2[i])) - sin_vals_view[None, i].store(_bf16x2_to_fp32(sin_bf16x2[i])) + cos0, cos1 = _bf16x2_to_fp32(cos_bf16x2[i]) + sin0, sin1 = _bf16x2_to_fp32(sin_bf16x2[i]) + cos_vals[i * 2] = cos0 + cos_vals[i * 2 + 1] = cos1 + sin_vals[i * 2] = sin0 + sin_vals[i * 2 + 1] = sin1 for i in cutlass.range_constexpr(self.coarsen): for j in cutlass.range_constexpr(8): - q_pair = _bf16x2_to_fp32(q_bf16x2[i, j]) - rot0 = q_pair[0] * cos_vals[j] - q_pair[1] * sin_vals[j] - rot1 = q_pair[0] * sin_vals[j] + q_pair[1] * cos_vals[j] + q0, q1 = _bf16x2_to_fp32(q_bf16x2[i, j]) + rot0 = q0 * cos_vals[j] - q1 * sin_vals[j] + rot1 = q0 * sin_vals[j] + q1 * cos_vals[j] # convert back to BF16 to match numerics q_bf16x2[i, j] = _fp32x2_to_bf16x2(rot0, rot1) @@ -323,10 +327,9 @@ def kernel( # compute block scale with bit manipulation # UE8M0 stores ceil(log2(fp4_scale)) + 127. Adding the mantissa mask # increments the exponent whenever fp4_scale is not exactly a power of 2 - fp4_scale = cute_utils.fmax(amax, float.fromhex("0x6p-126")) * Float32( - 1.0 / 6.0 - ) - bits = Uint32(llvm.bitcast(T.i32(), fp4_scale.ir_value())) + eps = cutlass.const_expr(float.fromhex("0x6p-126")) + fp4_scale = cute_utils.fmax(amax, eps) * Float32(1.0 / 6.0) + bits = _recast_val(fp4_scale, Uint32) ue8m0 = cute_utils.shr_u32( bits + Uint32(0x7FFFFF), Uint32(23) ) & Uint32(0xFF) @@ -339,15 +342,13 @@ def kernel( # If scale = 2^A and ue8m0 = A + 127, then inverse scale has exponent # -A + 127 = 254 - ue8m0. inv_scale_bits = (Uint32(254) - ue8m0) << Uint32(23) - inv_fp4_scale = Float32( - llvm.bitcast(T.f32(), inv_scale_bits.ir_value()) - ) + inv_fp4_scale = _recast_val(inv_scale_bits, Float32) vals = cute.make_rmem_tensor(16, Float32) for j in cutlass.range_constexpr(8): - tmp = _bf16x2_to_fp32(q_bf16x2[i, j]) - vals[j * 2] = tmp[0] * inv_fp4_scale - vals[j * 2 + 1] = tmp[1] * inv_fp4_scale + q0, q1 = _bf16x2_to_fp32(q_bf16x2[i, j]) + vals[j * 2] = q0 * inv_fp4_scale + vals[j * 2 + 1] = q1 * inv_fp4_scale # pack to FP4 packed = cute.make_rmem_tensor((2,), Uint32) @@ -360,7 +361,7 @@ def kernel( # Weight scaling is independent of the Q subwarp work. The first # num_tokens * num_heads logical threads cover one weight each. - if global_tid * 8 < num_token_heads: + if global_tid < num_token_heads: weight_token_id = global_tid // self.num_heads weight_head_id = global_tid % self.num_heads weights_out[weight_token_id, weight_head_id] = ( @@ -374,7 +375,7 @@ def compile( rope_dim: int = 64, num_heads: int = 64, cos_sin_dtype: type[cutlass.Numeric] = Float32, - tile_head: int = 4, + coarsen: int = 4, ): num_tokens = cute.sym_int() max_pos = cute.sym_int() @@ -402,7 +403,7 @@ def compile( weights_out = make_fake_tensor(Float32, (num_tokens, num_heads), divisibility=4) kernel = IndexerQMxFp4Kernel( - head_dim, rope_dim, num_heads, cos_sin_dtype, tile_head + head_dim, rope_dim, num_heads, cos_sin_dtype, coarsen ) stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) return cute.compile( From da38a13d3ab60a6fe13a5fad27461e6962ea1b0d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 8 May 2026 08:24:27 +0000 Subject: [PATCH 20/21] fix heuristic Signed-off-by: Thien Tran --- .../v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py index 7fbbfe1206b2..3bd47fd076de 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py @@ -44,7 +44,7 @@ def fused_indexer_q_rope_quant_mxfp4_cutedsl( IndexerQMxFp4Kernel.compile(head_dim, rope_dim, num_heads, rope_type, coarsen) # heuristic - coarsen = 1 if num_tokens < 256 else 4 + coarsen = 1 if num_tokens < 512 else 4 compiled = IndexerQMxFp4Kernel.compile( head_dim, rope_dim, num_heads, rope_type, coarsen ) From 9d469b8c12b132d96559872bea3be462f5e7ded4 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 8 May 2026 09:19:16 +0000 Subject: [PATCH 21/21] fix comment Signed-off-by: Thien Tran --- .../ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py index 3bd47fd076de..3ba60bb33102 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py @@ -242,8 +242,10 @@ def kernel( q_bf16x2 = cute.make_rmem_tensor(_layout, Uint32) if in_bounds: - # we can't do cute.copy on the whole 2D tile directly because - # cute.recast_tensor() is wrong for 2D strided layout :( + # we can't do cute.copy() on the whole 2D tile directly because + # cute.copy() wants the 1st mode to be covered by the copy atom, + # and other modes as for loop. there is no fast way to + # "transpose" the tensor view. q_tile = cute.local_tile( q[token_id, None, None], tiler=(self.coarsen, 16),