diff --git a/tests/kernels/test_fused_indexer_q_rope_quant.py b/tests/kernels/test_fused_indexer_q_rope_quant.py index 41a4d0ed0905..dd94dafd9585 100644 --- a/tests/kernels/test_fused_indexer_q_rope_quant.py +++ b/tests/kernels/test_fused_indexer_q_rope_quant.py @@ -13,6 +13,9 @@ Expects bit-exact equality on both q_fp8 and weights_out. """ +import contextlib +from unittest import mock + import pytest import torch @@ -20,6 +23,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import ( per_token_group_quant_fp8, ) +from vllm.utils.import_utils import has_cutedsl from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import ( fused_indexer_q_rope_quant, ) @@ -125,8 +129,14 @@ def _reference( @pytest.mark.parametrize("num_tokens", [1, 7, 32, 257, 1023]) @pytest.mark.parametrize("cache_dtype", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize("use_fp4", [False, True]) +@pytest.mark.parametrize("use_cutedsl", [False, True]) @torch.inference_mode() -def test_fused_indexer_q_rope_quant_matches_unfused(num_tokens, cache_dtype, use_fp4): +def test_fused_indexer_q_rope_quant_matches_unfused( + num_tokens, cache_dtype, use_fp4, use_cutedsl +): + if use_cutedsl and not has_cutedsl(): + pytest.skip("cutedsl (cutlass) not installed") + device = "cuda" torch.manual_seed(0) @@ -142,9 +152,26 @@ def test_fused_indexer_q_rope_quant_matches_unfused(num_tokens, cache_dtype, use q_quant_ref, weights_ref = _reference( positions, q, cos_sin_cache, weights, softmax_scale, head_scale, use_fp4 ) - q_quant_fused, weights_fused = fused_indexer_q_rope_quant( - positions, q.clone(), cos_sin_cache, weights, softmax_scale, head_scale, use_fp4 + # use_cutedsl=False: force the triton path even when cutedsl is installed + # by patching the dispatcher's has_cutedsl() binding to return False. + cutedsl_patch = ( + mock.patch( + "vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q.has_cutedsl", + return_value=False, + ) + if not use_cutedsl + else contextlib.nullcontext() ) + with cutedsl_patch: + q_quant_fused, weights_fused = fused_indexer_q_rope_quant( + positions, + q.clone(), + cos_sin_cache, + weights, + softmax_scale, + head_scale, + use_fp4, + ) if use_fp4: q_quant_ref, q_scale_ref = q_quant_ref diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/cutedsl_utils.py b/vllm/v1/attention/ops/deepseek_v4_ops/cutedsl_utils.py index dddd3f544f8e..6cb53cd07078 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/cutedsl_utils.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/cutedsl_utils.py @@ -117,6 +117,39 @@ def _fp8x4_to_bf16x4(x: Uint32, *, loc=None, ip=None) -> cute.TensorSSA: return cute.TensorSSA(vec, 2, Uint32) +@dsl_user_op +def _fp32x4_to_fp8x4( + a0: Float32, + a1: Float32, + a2: Float32, + a3: Float32, + *, + loc=None, + ip=None, +) -> Uint32: + # Pack four FP32 values into one b32 of four e4m3 bytes, byte order + # {a0, a1, a2, a3} from low to high address. + out = llvm.inline_asm( + T.i32(), + [ + a0.ir_value(loc=loc, ip=ip), + a1.ir_value(loc=loc, ip=ip), + a2.ir_value(loc=loc, ip=ip), + a3.ir_value(loc=loc, ip=ip), + ], + "{\n\t" + ".reg .b16 t0, t1;\n\t" + "cvt.rn.satfinite.e4m3x2.f32 t0, $2, $1;\n\t" + "cvt.rn.satfinite.e4m3x2.f32 t1, $4, $3;\n\t" + "mov.b32 $0, {t0, t1};\n\t" + "}\n", + "=r,f,f,f,f", + has_side_effects=False, + is_align_stack=False, + ) + return Uint32(out) + + @dsl_user_op def _fp32x8_to_fp4x8( vals: cute.Tensor, diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py index ec880f7ab4c4..d9c1e93f2b56 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py @@ -398,24 +398,41 @@ def fused_indexer_q_rope_quant( ), index_weights_out index_q_fp8 = torch.empty_like(index_q, dtype=torch.float8_e4m3fn) - _fused_indexer_q_rope_quant_kernel[(num_tokens, num_index_q_heads)]( - positions, - index_q, - index_q.stride(0), - index_q.stride(1), - index_q_cos_sin_cache, - index_q_cos_sin_cache.stride(0), - index_q_cos_sin_cache.shape[-1] // 2, - index_q_fp8, - index_q_fp8.stride(0), - index_q_fp8.stride(1), - index_q_head_dim, - index_weights, - index_weights.stride(0), - index_weights_softmax_scale, - index_weights_head_scale, - index_weights_out, - index_weights_out.stride(0), - num_warps=1, # TODO: Tune this - ) + if has_cutedsl(): + # lazily import, otherwise some tests fail due to CUDA driver init failure. + from .fused_indexer_q_cutedsl import ( + fused_indexer_q_rope_quant_fp8_cutedsl, + ) + + fused_indexer_q_rope_quant_fp8_cutedsl( + positions, + index_q, + index_q_cos_sin_cache, + index_weights, + index_weights_softmax_scale, + index_weights_head_scale, + index_q_fp8, + index_weights_out, + ) + else: + _fused_indexer_q_rope_quant_kernel[(num_tokens, num_index_q_heads)]( + positions, + index_q, + index_q.stride(0), + index_q.stride(1), + index_q_cos_sin_cache, + index_q_cos_sin_cache.stride(0), + index_q_cos_sin_cache.shape[-1] // 2, + index_q_fp8, + index_q_fp8.stride(0), + index_q_fp8.stride(1), + index_q_head_dim, + index_weights, + index_weights.stride(0), + index_weights_softmax_scale, + index_weights_head_scale, + index_weights_out, + index_weights_out.stride(0), + num_warps=1, # TODO: Tune this + ) return index_q_fp8, index_weights_out diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py index 4468a95651bf..01ace330855d 100644 --- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py +++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q_cutedsl.py @@ -14,6 +14,7 @@ _bf16x2_max, _bf16x2_to_fp32, _fp32x2_to_bf16x2, + _fp32x4_to_fp8x4, _fp32x8_to_fp4x8, _recast_val, ) @@ -65,8 +66,48 @@ def fused_indexer_q_rope_quant_mxfp4_cutedsl( ) -class IndexerQMxFp4Kernel: - """Eight-thread subwarps process one ``(token, head)`` row.""" +def fused_indexer_q_rope_quant_fp8_cutedsl( + positions: torch.Tensor, + index_q: torch.Tensor, + index_q_cos_sin_cache: torch.Tensor, + index_weights: torch.Tensor, + index_weights_softmax_scale: float, + index_weights_head_scale: float, + index_q_fp8: torch.Tensor, + index_weights_out: torch.Tensor, +) -> None: + num_tokens, num_heads, head_dim = index_q.shape + rope_dim = index_q_cos_sin_cache.shape[-1] + rope_type = _TORCH_TO_CUTE[index_q_cos_sin_cache.dtype] + + for coarsen in (1, 4): + IndexerQFp8Kernel.compile(head_dim, rope_dim, num_heads, rope_type, coarsen) + + coarsen = 1 if num_tokens < 512 else 4 + compiled = IndexerQFp8Kernel.compile( + head_dim, rope_dim, num_heads, rope_type, coarsen + ) + scale = float(index_weights_softmax_scale * index_weights_head_scale) + # The cute kernel treats the FP8 buffer as raw bytes (Uint8). + compiled( + positions, + index_q, + index_q_cos_sin_cache, + index_weights, + index_q_fp8.view(torch.uint8), + index_weights_out, + scale, + ) + + +class IndexerQRopeQuantKernel: + """Shared infrastructure for indexer-Q RoPE+quant fused kernels. + + Subclasses implement ``kernel`` for a particular Q quantization scheme + (MXFP4, FP8 e4m3, …). The base class owns the launch geometry and the + common preamble: thread/token addressing, the BF16 Q load, and the + interleaved-RoPE pass over the trailing ``rope_dim`` lanes. + """ def __init__( self, @@ -94,47 +135,27 @@ def __init__( self.threads_per_token = (self.num_heads // self.coarsen) * self.subwarp_size @cute.jit - def __call__( + def _load_q_and_rope( self, positions: cute.Tensor, q: cute.Tensor, cos_sin_cache: cute.Tensor, - weights: cute.Tensor, - q_fp4: cute.Tensor, - q_scale: cute.Tensor, - weights_out: cute.Tensor, - scale: Float32, - stream: CUstream, - ): - total_threads = q.shape[0] * self.threads_per_token - grid = (cute.ceil_div(total_threads, self.tb_size), 1, 1) - self.kernel( - positions, - q, - cos_sin_cache, - weights, - q_fp4, - q_scale, - weights_out, - scale, - ).launch(grid=grid, block=(self.tb_size, 1, 1), stream=stream) - - @cute.kernel - def kernel( - self, - positions: cute.Tensor, - q: cute.Tensor, - cos_sin_cache: cute.Tensor, - weights: cute.Tensor, - q_fp4: cute.Tensor, - q_scale: cute.Tensor, - weights_out: cute.Tensor, - scale: Float32, ): + """Compute thread indices, load Q (BF16), and apply interleaved RoPE. + + Returns a tuple + (q_bf16x2, tid, global_tid, sublane, token_id, head_tile_id, + head_start, in_bounds, num_token_heads) + where ``q_bf16x2`` is a (coarsen, 8) rmem tile of Uint32 packed + bf16x2 pairs covering the 16 BF16 lanes owned by this thread for + each of ``coarsen`` heads. RoPE is applied in place to the + trailing ``rope_dim`` lanes; the leading nope lanes pass through. + """ block_id, _, _ = cute.arch.block_idx() tid, _, _ = cute.arch.thread_idx() - num_token_heads = q.shape[0] * self.num_heads + num_tokens = q.shape[0] + num_token_heads = num_tokens * self.num_heads global_tid = block_id * self.tb_size + tid global_subwarp_id = global_tid // self.subwarp_size @@ -150,7 +171,7 @@ def kernel( # must_in_bounds is constexpr, True when 1 threadblock fit within 1 token # position. the compiler will remove bounds check when that happens. must_in_bounds = cutlass.const_expr(self.tb_size % self.threads_per_token == 0) - in_bounds = must_in_bounds or (token_id < q.shape[0]) + in_bounds = must_in_bounds or (token_id < num_tokens) cp_op = cute.nvgpu.CopyUniversalOp() @@ -219,9 +240,77 @@ def kernel( # convert back to BF16 to match numerics q_bf16x2[i, j] = _fp32x2_to_bf16x2(rot0, rot1) + return ( + q_bf16x2, + tid, + global_tid, + sublane, + token_id, + head_tile_id, + head_start, + in_bounds, + num_token_heads, + ) + + +class IndexerQMxFp4Kernel(IndexerQRopeQuantKernel): + """Eight-thread subwarps process one ``(token, head)`` row.""" + + @cute.jit + def __call__( + self, + positions: cute.Tensor, + q: cute.Tensor, + cos_sin_cache: cute.Tensor, + weights: cute.Tensor, + q_quant: cute.Tensor, + q_scale: cute.Tensor, + weights_out: cute.Tensor, + scale: Float32, + stream: CUstream, + ): + total_threads = q.shape[0] * self.threads_per_token + grid = (cute.ceil_div(total_threads, self.tb_size), 1, 1) + self.kernel( + positions, + q, + cos_sin_cache, + weights, + q_quant, + q_scale, + weights_out, + scale, + ).launch(grid=grid, block=(self.tb_size, 1, 1), stream=stream) + + @cute.kernel + def kernel( + self, + positions: cute.Tensor, + q: cute.Tensor, + cos_sin_cache: cute.Tensor, + weights: cute.Tensor, + q_quant: cute.Tensor, + q_scale: cute.Tensor, + weights_out: cute.Tensor, + scale: Float32, + ): + ( + q_bf16x2, + tid, + global_tid, + sublane, + token_id, + head_tile_id, + head_start, + in_bounds, + num_token_heads, + ) = self._load_q_and_rope(positions, q, cos_sin_cache) + + cp_op = cute.nvgpu.CopyUniversalOp() + # layout: [coarsen, 8] q_fp4_tile = cute.local_tile( - q_fp4[token_id, None, None], + q_quant[token_id, None, None], tiler=(self.coarsen, 8), coord=(head_tile_id, sublane), ) @@ -337,3 +426,188 @@ def compile( stream, options="--enable-tvm-ffi", ) + + +class IndexerQFp8Kernel(IndexerQRopeQuantKernel): + """Eight-thread subwarps process one ``(token, head)`` row and emit + float8 e4m3fn with a single per-(token, head) scalar scale folded + into the per-token weight (mirrors ``_fused_indexer_q_rope_quant_kernel``). + """ + + def __init__( + self, + head_dim: int = 128, + rope_dim: int = 64, + num_heads: int = 64, + cos_sin_dtype: type[cutlass.Numeric] = Float32, + coarsen: int = 4, + ): + super().__init__(head_dim, rope_dim, num_heads, cos_sin_dtype, coarsen) + # Each subwarp owns `coarsen` heads; we use the first `coarsen` + # threads of the subwarp to write the per-head weights using the + # fp8 scale computed in the matching loop iteration. + assert self.coarsen <= self.subwarp_size, ( + f"FP8 kernel requires coarsen ({self.coarsen}) <= " + f"subwarp_size ({self.subwarp_size}) for the weight-fold step" + ) + + @cute.jit + def __call__( + self, + positions: cute.Tensor, + q: cute.Tensor, + cos_sin_cache: cute.Tensor, + weights: cute.Tensor, + q_fp8: cute.Tensor, + weights_out: cute.Tensor, + scale: Float32, + stream: CUstream, + ): + total_threads = q.shape[0] * self.threads_per_token + grid = (cute.ceil_div(total_threads, self.tb_size), 1, 1) + self.kernel( + positions, + q, + cos_sin_cache, + weights, + q_fp8, + weights_out, + scale, + ).launch(grid=grid, block=(self.tb_size, 1, 1), stream=stream) + + @cute.kernel + def kernel( + self, + positions: cute.Tensor, + q: cute.Tensor, + cos_sin_cache: cute.Tensor, + weights: cute.Tensor, + q_fp8: cute.Tensor, + weights_out: cute.Tensor, + scale: Float32, + ): + ( + q_bf16x2, + _tid, + _global_tid, + sublane, + token_id, + head_tile_id, + head_start, + in_bounds, + _num_token_heads, + ) = self._load_q_and_rope(positions, q, cos_sin_cache) + + cp_op = cute.nvgpu.CopyUniversalOp() + + # layout: [coarsen, 16] bytes (one e4m3fn per element). + q_fp8_tile = cute.local_tile( + q_fp8[token_id, None, None], + tiler=(self.coarsen, 16), + coord=(head_tile_id, sublane), + ) + + for i in cutlass.range_constexpr(self.coarsen): + # Reduce amax across the full head_dim: each thread already holds + # the max over its 16 lanes; a width=subwarp_size warp shuffle + # spreads the head-wide max to every lane in the subwarp. + amax_bf16x2 = _bf16x2_abs(q_bf16x2[i, 0]) + for j in cutlass.range_constexpr(1, 8): + amax_bf16x2 = _bf16x2_max(amax_bf16x2, _bf16x2_abs(q_bf16x2[i, j])) + amax_bf16x2 = cute_utils.warp_reduce( + amax_bf16x2, + _bf16x2_max, + width=self.subwarp_size, + ) + amax_pair = _bf16x2_to_fp32(amax_bf16x2) + amax = cute_utils.fmax(amax_pair[0], amax_pair[1]) + + # scale = max(amax, eps) / fp8_max, then rounded UP to the next + # power of two. Adding the mantissa mask before shifting out the + # mantissa bumps the exponent whenever s isn't a pure pow2. + fp32_scale = cute_utils.fmax(amax, Float32(1e-4)) * Float32(1.0 / 448.0) + bits = _recast_val(fp32_scale, Uint32) + scale_exp = cute_utils.shr_u32( + bits + Uint32(0x7FFFFF), Uint32(23) + ) & Uint32(0xFF) + + # rounded scale = 2^(scale_exp - 127); bit pattern is scale_exp << 23 + fp8_scale_bits = scale_exp << Uint32(23) + fp8_scale = _recast_val(fp8_scale_bits, Float32) + # inverse = 2^-(scale_exp - 127); bit pattern is (254 - scale_exp) << 23 + inv_scale_bits = (Uint32(254) - scale_exp) << Uint32(23) + inv_fp8_scale = _recast_val(inv_scale_bits, Float32) + + # Weight fold: weights_out = weights * q_scale * scale_combined. + # All threads in the subwarp share the same fp8_scale after the + # warp_reduce above, so we let thread `sublane == i` write the + # weight for head `head_start + i`. + if in_bounds and sublane == i: + head_id = head_start + i + weights_out[token_id, head_id] = ( + weights[token_id, head_id].to(Float32) * scale * fp8_scale + ) + + if in_bounds: + # 16 BF16 → 16 e4m3 bytes per thread, packed into 4 b32s + # (one cp.async-shaped 128-bit store per row). + packed = cute.make_rmem_tensor((4,), Uint32) + for j in cutlass.range_constexpr(4): + q0, q1 = _bf16x2_to_fp32(q_bf16x2[i, j * 2]) + q2, q3 = _bf16x2_to_fp32(q_bf16x2[i, j * 2 + 1]) + packed[j] = _fp32x4_to_fp8x4( + q0 * inv_fp8_scale, + q1 * inv_fp8_scale, + q2 * inv_fp8_scale, + q3 * inv_fp8_scale, + ) + + dst = q_fp8_tile[i, None] + cp_u32x4 = cute.make_copy_atom(cp_op, Uint32, num_bits_per_copy=128) + cute.copy(cp_u32x4, packed, cute.recast_tensor(dst, Uint32)) + + @cache + @staticmethod + def compile( + head_dim: int = 128, + rope_dim: int = 64, + num_heads: int = 64, + cos_sin_dtype: type[cutlass.Numeric] = Float32, + coarsen: int = 4, + ): + num_tokens = cute.sym_int() + max_pos = cute.sym_int() + + q = make_fake_tensor( + BFloat16, (num_tokens, num_heads, head_dim), divisibility=16 + ) + positions = make_fake_tensor(Int64, (num_tokens,), divisibility=1) + cos_sin_cache = make_fake_tensor( + cos_sin_dtype, + (max_pos, rope_dim), + divisibility=8, + ) + weights = make_fake_tensor(BFloat16, (num_tokens, num_heads), divisibility=8) + q_fp8 = make_fake_tensor( + Uint8, + (num_tokens, num_heads, head_dim), + divisibility=16, + ) + weights_out = make_fake_tensor(Float32, (num_tokens, num_heads), divisibility=4) + + kernel = IndexerQFp8Kernel( + head_dim, rope_dim, num_heads, cos_sin_dtype, coarsen + ) + stream = cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True) + return cute.compile( + kernel, + positions, + q, + cos_sin_cache, + weights, + q_fp8, + weights_out, + Float32(0.0), + stream, + options="--enable-tvm-ffi", + )