From 37734f4306a7f0f7ceebd5e94eb566d92f163e98 Mon Sep 17 00:00:00 2001 From: Hemanth Acharya Date: Thu, 21 May 2026 05:06:54 -0500 Subject: [PATCH 1/9] replacing sparse mla triton kernel with an optimized HIP version Signed-off-by: Hemanth Acharya --- .../v1/attention/ops/rocm_aiter_mla_sparse.py | 981 +++++++++++++++++- 1 file changed, 958 insertions(+), 23 deletions(-) diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 8bda5bb7a86b..a2cd873b13a0 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -2,11 +2,16 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools import importlib +import logging import math +import os +import pathlib +import tempfile from importlib.util import find_spec import torch import torch.nn.functional as F +from torch.utils.cpp_extension import load_inline from vllm.compilation.breakable_cudagraph import eager_break_during_capture from vllm.forward_context import get_forward_context @@ -1650,6 +1655,903 @@ def rocm_sparse_attn_prefill( output.copy_(output_chunk.to(output.dtype)) +# ============================================================================ +# HIP MFMA kernel implementation for sparse-MLA decode. +# ============================================================================ + +_HIP_SPARSE_MLA_DECODE_SRC = r""" +#include +#include +#include + +#include +#include + +using bf16x8 = __attribute__((__vector_size__(8 * sizeof(__bf16)))) __bf16; +using fx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; + +static constexpr int NOPE_DIM = 448; +static constexpr int ROPE_DIM = 64; +static constexpr int TOKEN_BYTES = 576; +static constexpr int SCALE_BYTES = 8; +static constexpr int HEAD_DIM = 512; +static constexpr int BLOCK_H = 16; +static constexpr int BLOCK_K = 32; +static constexpr int N_TILES = HEAD_DIM / 16; // 32 +static constexpr int QK_N_TILES = BLOCK_K / 16; // 2 + +__device__ __forceinline__ fx4 mfma_16x16x32_bf16( + bf16x8 a, bf16x8 b, fx4 c) { + return __builtin_amdgcn_mfma_f32_16x16x32_bf16(a, b, c, 0, 0, 0); +} + +__device__ __forceinline__ void gather_and_dequant_k_tile( + int k_start, int k_len, const uint8_t* cache_base, + int64_t cache_stride0, int num_rows, int block_size, + const int32_t* idx_base, + __bf16* k_lds, int8_t* kv_lds, int tid) +{ + const int tok_id = tid >> 3; // 0..31 + const int chunk = tid & 7; // 0..7 + const int col0 = chunk * 64; + + int k_pos = k_start + tok_id; + bool in_range = (k_pos < k_len); + int slot = in_range ? idx_base[k_pos] : 0; + bool valid = in_range && (slot >= 0) && (slot < num_rows); + int safe_slot = valid ? slot : 0; + int bi = safe_slot / block_size; + int pib = safe_slot - bi * block_size; + const uint8_t* block_ptr = cache_base + + (int64_t)bi * cache_stride0; + const uint8_t* token_ptr = block_ptr + pib * TOKEN_BYTES; + + __bf16* dst_row = &k_lds[tok_id * HEAD_DIM + col0]; + + if (!valid) { + int4 z; z.x = z.y = z.z = z.w = 0; + int4* d4 = reinterpret_cast(dst_row); + #pragma unroll + for (int j = 0; j < 8; ++j) d4[j] = z; + } else if (col0 < NOPE_DIM) { + const uint8_t* scale_ptr = block_ptr + + block_size * TOKEN_BYTES + + pib * SCALE_BYTES; + uint8_t scl_u = scale_ptr[chunk]; + union { uint32_t u; float fv; } sb; + sb.u = ((uint32_t)scl_u) << 23; + float scl_f = sb.fv; + + const uint32_t* src32 = reinterpret_cast( + token_ptr + col0); + #pragma unroll + for (int u32_i = 0; u32_i < 16; ++u32_i) { + uint32_t word = src32[u32_i]; + #pragma unroll + for (int b = 0; b < 4; ++b) { + uint8_t kb = (word >> (b * 8)) & 0xFF; + uint32_t packed = (uint32_t)kb; + float f = __builtin_amdgcn_cvt_f32_fp8(packed, 0) * scl_f; + dst_row[u32_i * 4 + b] = (__bf16)f; + } + } + } else { + const int4* src4 = reinterpret_cast( + token_ptr + NOPE_DIM); + int4* d4 = reinterpret_cast(dst_row); + #pragma unroll + for (int j = 0; j < 8; ++j) d4[j] = src4[j]; + } + + if (tid < BLOCK_K) { + int kp = k_start + tid; + int sl = (kp < k_len) ? idx_base[kp] : -1; + kv_lds[tid] = (kp < k_len) && (sl >= 0) && (sl < num_rows) ? 1 : 0; + } +} + + +constexpr int N_TILES_PER_WAVE = 8; // 32 N-tiles / 4 waves +__device__ __forceinline__ void process_k_tile( + const __bf16* q_lds, const __bf16* k_lds, const int8_t* kv_lds, + __bf16* p_lds, float* scores_lds, + float* m_state, float* l_state, fx4* acc, float scale, + int lane, int m_a, int kg, int n_b, int m_d_base, int n_d, + int wave) +{ + if (wave == 0) { + fx4 qk[2] = {{0.f, 0.f, 0.f, 0.f}, {0.f, 0.f, 0.f, 0.f}}; + #pragma unroll + for (int c = 0; c < HEAD_DIM / 32; ++c) { + bf16x8 q_reg; + const __bf16* q_src = &q_lds[m_a * HEAD_DIM + c * 32 + kg * 8]; + #pragma unroll + for (int i = 0; i < 8; ++i) q_reg[i] = q_src[i]; + + #pragma unroll + for (int nt = 0; nt < 2; ++nt) { + bf16x8 k_reg; + const __bf16* k_src = &k_lds[(nt * 16 + n_b) * HEAD_DIM + + c * 32 + kg * 8]; + #pragma unroll + for (int i = 0; i < 8; ++i) k_reg[i] = k_src[i]; + qk[nt] = mfma_16x16x32_bf16(q_reg, k_reg, qk[nt]); + } + } + #pragma unroll + for (int nt = 0; nt < 2; ++nt) { + #pragma unroll + for (int i = 0; i < 4; ++i) { + int k_col = nt * 16 + n_d; + float s = qk[nt][i] * scale; + if (!kv_lds[k_col]) s = -3.4028234663852886e38f; + scores_lds[(m_d_base + i) * BLOCK_K + nt * 16 + n_d] = s; + } + } + } + + __syncthreads(); + + fx4 qk_local[2]; + #pragma unroll + for (int i = 0; i < 4; ++i) { + qk_local[0][i] = scores_lds[(m_d_base + i) * BLOCK_K + n_d]; + qk_local[1][i] = scores_lds[(m_d_base + i) * BLOCK_K + 16 + n_d]; + } + + fx4 p[2]; + #pragma unroll + for (int i = 0; i < 4; ++i) { + float row_max = fmaxf(qk_local[0][i], qk_local[1][i]); + row_max = fmaxf(row_max, __shfl_xor(row_max, 1)); + row_max = fmaxf(row_max, __shfl_xor(row_max, 2)); + row_max = fmaxf(row_max, __shfl_xor(row_max, 4)); + row_max = fmaxf(row_max, __shfl_xor(row_max, 8)); + + float m_new = fmaxf(m_state[i], row_max); + float alpha = __builtin_amdgcn_exp2f( + (m_state[i] - m_new) * 1.4426950408889634f); + + float e0 = __builtin_amdgcn_exp2f( + (qk_local[0][i] - m_new) * 1.4426950408889634f); + float e1 = __builtin_amdgcn_exp2f( + (qk_local[1][i] - m_new) * 1.4426950408889634f); + + float row_sum = e0 + e1; + row_sum += __shfl_xor(row_sum, 1); + row_sum += __shfl_xor(row_sum, 2); + row_sum += __shfl_xor(row_sum, 4); + row_sum += __shfl_xor(row_sum, 8); + + float l_new = l_state[i] * alpha + row_sum; + p[0][i] = e0; + p[1][i] = e1; + + #pragma unroll + for (int nt = 0; nt < N_TILES_PER_WAVE; ++nt) acc[nt][i] *= alpha; + + m_state[i] = m_new; + l_state[i] = l_new; + } + + if (wave == 0) { + #pragma unroll + for (int i = 0; i < 4; ++i) { + p_lds[(m_d_base + i) * BLOCK_K + n_d] = (__bf16)p[0][i]; + p_lds[(m_d_base + i) * BLOCK_K + 16 + n_d] = (__bf16)p[1][i]; + } + } + + __syncthreads(); + + bf16x8 p_reg; + const __bf16* p_src = &p_lds[m_a * BLOCK_K + kg * 8]; + #pragma unroll + for (int i = 0; i < 8; ++i) p_reg[i] = p_src[i]; + + #pragma unroll + for (int nt_local = 0; nt_local < N_TILES_PER_WAVE; ++nt_local) { + int n_tile = wave * N_TILES_PER_WAVE + nt_local; + bf16x8 k_reg; + #pragma unroll + for (int i = 0; i < 8; ++i) { + k_reg[i] = k_lds[(kg * 8 + i) * HEAD_DIM + + n_tile * 16 + n_b]; + } + acc[nt_local] = mfma_16x16x32_bf16(p_reg, k_reg, acc[nt_local]); + } +} + + +__device__ __forceinline__ void load_q( + const __bf16* q, int64_t q_stride0, int64_t q_stride1, + int query, int pid_h, int num_heads, + __bf16* q_lds, int tid) +{ + const int qh = tid >> 4; // 0..15 + const int qc0 = (tid & 15) << 5; // 0,32,...,480 + const int head_global = pid_h * BLOCK_H + qh; + __bf16* dst = &q_lds[qh * HEAD_DIM + qc0]; + if (head_global < num_heads) { + const __bf16* src = q + query * q_stride0 + + head_global * q_stride1 + qc0; + const int4* s4 = reinterpret_cast(src); + int4* d4 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < 4; ++i) d4[i] = s4[i]; + } else { + int4 z; z.x = z.y = z.z = z.w = 0; + int4* d4 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < 4; ++i) d4[i] = z; + } +} + + +template +__global__ __launch_bounds__(256, 2) +void sparse_mla_decode_kernel( + const __bf16* __restrict__ q, + const uint8_t* __restrict__ main_cache, + const int32_t* __restrict__ main_indices, + const int32_t* __restrict__ main_indptr, + const uint8_t* __restrict__ extra_cache, + const int32_t* __restrict__ extra_indices, + const int32_t* __restrict__ extra_indptr, + const float* __restrict__ attn_sink, + __bf16* __restrict__ output, + int64_t q_stride0, int64_t q_stride1, + int64_t out_stride0, int64_t out_stride1, + int64_t main_cache_stride0, int64_t extra_cache_stride0, + int main_num_rows, int extra_num_rows, + int main_block_size, int extra_block_size, + float scale, int num_heads) +{ + const int query = blockIdx.x; + const int pid_h = blockIdx.y; + const int tid = threadIdx.x; + const int wave = tid >> 6; + const int lane = tid & 63; + + const int m_a = lane & 15; + const int kg = lane >> 4; + const int n_b = lane & 15; + const int m_d_base = (lane >> 4) * 4; + const int n_d = lane & 15; + + __shared__ __bf16 q_lds[BLOCK_H * HEAD_DIM]; + __shared__ __bf16 k_lds[BLOCK_K * HEAD_DIM]; + __shared__ __bf16 p_lds[BLOCK_H * BLOCK_K]; + __shared__ float scores_lds[BLOCK_H * BLOCK_K]; + __shared__ int8_t kv_lds[BLOCK_K]; + __shared__ char force_1wg_per_cu[48 * 1024]; // pads LDS to ~96 KB + (void)force_1wg_per_cu; + + load_q(q, q_stride0, q_stride1, query, pid_h, num_heads, q_lds, tid); + + float m_state[4], l_state[4]; + fx4 acc[N_TILES_PER_WAVE]; + #pragma unroll + for (int i = 0; i < 4; ++i) { + m_state[i] = -3.4028234663852886e38f; + l_state[i] = 0.f; + } + #pragma unroll + for (int i = 0; i < N_TILES_PER_WAVE; ++i) { + acc[i] = (fx4){0.f, 0.f, 0.f, 0.f}; + } + + __syncthreads(); + + { + int main_start = main_indptr[query]; + int main_end = main_indptr[query + 1]; + int main_len = main_end - main_start; + for (int k_start = 0; k_start < main_len; k_start += BLOCK_K) { + gather_and_dequant_k_tile( + k_start, main_len, main_cache, main_cache_stride0, + main_num_rows, main_block_size, + main_indices + main_start, k_lds, kv_lds, tid); + __syncthreads(); + process_k_tile(q_lds, k_lds, kv_lds, p_lds, scores_lds, + m_state, l_state, acc, scale, + lane, m_a, kg, n_b, m_d_base, n_d, wave); + __syncthreads(); + } + } + + if (HAS_EXTRA) { + int extra_start = extra_indptr[query]; + int extra_end = extra_indptr[query + 1]; + int extra_len = extra_end - extra_start; + for (int k_start = 0; k_start < extra_len; k_start += BLOCK_K) { + gather_and_dequant_k_tile( + k_start, extra_len, extra_cache, extra_cache_stride0, + extra_num_rows, extra_block_size, + extra_indices + extra_start, k_lds, kv_lds, tid); + __syncthreads(); + process_k_tile(q_lds, k_lds, kv_lds, p_lds, scores_lds, + m_state, l_state, acc, scale, + lane, m_a, kg, n_b, m_d_base, n_d, wave); + __syncthreads(); + } + } + + { + #pragma unroll + for (int i = 0; i < 4; ++i) { + int head_local = m_d_base + i; + int head_global = pid_h * BLOCK_H + head_local; + if (head_global >= num_heads) continue; + + float m_final = m_state[i]; + float l_final = l_state[i]; + float alpha_final = 1.f; + if (HAS_ATTN_SINK) { + float sink_val = attn_sink[head_global]; + m_final = fmaxf(m_state[i], sink_val); + alpha_final = __builtin_amdgcn_exp2f( + (m_state[i] - m_final) * 1.4426950408889634f); + l_final = l_state[i] * alpha_final + __builtin_amdgcn_exp2f( + (sink_val - m_final) * 1.4426950408889634f); + } + float denom = fmaxf(l_final, 1.0e-30f); + bool live = (l_final > 0.f); + + __bf16* out_row = output + query * out_stride0 + + head_global * out_stride1; + #pragma unroll + for (int nt_local = 0; nt_local < N_TILES_PER_WAVE; ++nt_local) { + int n_tile = wave * N_TILES_PER_WAVE + nt_local; + int col = n_tile * 16 + n_d; + float v = live ? (acc[nt_local][i] * alpha_final) / denom : 0.f; + out_row[col] = (__bf16)v; + } + } + } +} + + +template +__global__ __launch_bounds__(256, 2) +void sparse_mla_decode_partial_kernel( + const __bf16* __restrict__ q, + const uint8_t* __restrict__ main_cache, + const int32_t* __restrict__ main_indices, + const int32_t* __restrict__ main_indptr, + const uint8_t* __restrict__ extra_cache, + const int32_t* __restrict__ extra_indices, + const int32_t* __restrict__ extra_indptr, + float* __restrict__ scratch_m, + float* __restrict__ scratch_l, + __bf16* __restrict__ scratch_acc, + int64_t q_stride0, int64_t q_stride1, + int64_t main_cache_stride0, int64_t extra_cache_stride0, + int main_num_rows, int extra_num_rows, + int main_block_size, int extra_block_size, + float scale, int num_heads, int num_head_blocks) +{ + const int query = blockIdx.x; + const int pid_hs = blockIdx.y; + const int pid_split = pid_hs / num_head_blocks; + const int pid_h = pid_hs - pid_split * num_head_blocks; + const int tid = threadIdx.x; + const int wave = tid >> 6; + const int lane = tid & 63; + + const int m_a = lane & 15; + const int kg = lane >> 4; + const int n_b = lane & 15; + const int m_d_base = (lane >> 4) * 4; + const int n_d = lane & 15; + + __shared__ __bf16 q_lds[BLOCK_H * HEAD_DIM]; + __shared__ __bf16 k_lds[BLOCK_K * HEAD_DIM]; + __shared__ __bf16 p_lds[BLOCK_H * BLOCK_K]; + __shared__ float scores_lds[BLOCK_H * BLOCK_K]; + __shared__ int8_t kv_lds[BLOCK_K]; + load_q(q, q_stride0, q_stride1, query, pid_h, num_heads, q_lds, tid); + + float m_state[4], l_state[4]; + fx4 acc[N_TILES_PER_WAVE]; + #pragma unroll + for (int i = 0; i < 4; ++i) { + m_state[i] = -3.4028234663852886e38f; + l_state[i] = 0.f; + } + #pragma unroll + for (int i = 0; i < N_TILES_PER_WAVE; ++i) { + acc[i] = (fx4){0.f, 0.f, 0.f, 0.f}; + } + + __syncthreads(); + + { + int main_start = main_indptr[query]; + int main_end = main_indptr[query + 1]; + int main_len = main_end - main_start; + for (int k_start = pid_split * BLOCK_K; k_start < main_len; + k_start += BLOCK_K * SPLIT_K) { + gather_and_dequant_k_tile( + k_start, main_len, main_cache, main_cache_stride0, + main_num_rows, main_block_size, + main_indices + main_start, k_lds, kv_lds, tid); + __syncthreads(); + process_k_tile(q_lds, k_lds, kv_lds, p_lds, scores_lds, + m_state, l_state, acc, scale, + lane, m_a, kg, n_b, m_d_base, n_d, wave); + __syncthreads(); + } + } + + if (HAS_EXTRA) { + int extra_start = extra_indptr[query]; + int extra_end = extra_indptr[query + 1]; + int extra_len = extra_end - extra_start; + for (int k_start = pid_split * BLOCK_K; k_start < extra_len; + k_start += BLOCK_K * SPLIT_K) { + gather_and_dequant_k_tile( + k_start, extra_len, extra_cache, extra_cache_stride0, + extra_num_rows, extra_block_size, + extra_indices + extra_start, k_lds, kv_lds, tid); + __syncthreads(); + process_k_tile(q_lds, k_lds, kv_lds, p_lds, scores_lds, + m_state, l_state, acc, scale, + lane, m_a, kg, n_b, m_d_base, n_d, wave); + __syncthreads(); + } + } + + const int triple = (query * num_head_blocks + pid_h) * SPLIT_K + pid_split; + + if (wave == 0 && n_d == 0) { + #pragma unroll + for (int i = 0; i < 4; ++i) { + int idx = triple * BLOCK_H + m_d_base + i; + scratch_m[idx] = m_state[i]; + scratch_l[idx] = l_state[i]; + } + } + + __syncthreads(); + #pragma unroll + for (int i = 0; i < 4; ++i) { + int row = m_d_base + i; + #pragma unroll + for (int nt_local = 0; nt_local < N_TILES_PER_WAVE; ++nt_local) { + int n_tile = wave * N_TILES_PER_WAVE + nt_local; + int col = n_tile * 16 + n_d; + k_lds[row * HEAD_DIM + col] = (__bf16)acc[nt_local][i]; + } + } + __syncthreads(); + + { + int my_row = tid >> 4; + int my_col0 = (tid & 15) << 5; + __bf16* dst = scratch_acc + (int64_t)triple * BLOCK_H * HEAD_DIM + + my_row * HEAD_DIM + my_col0; + const int4* src4 = reinterpret_cast( + &k_lds[my_row * HEAD_DIM + my_col0]); + int4* dst4 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < 4; ++i) dst4[i] = src4[i]; + } +} + + +template +__global__ __launch_bounds__(256, 4) +void sparse_mla_decode_reduce_kernel( + const float* __restrict__ scratch_m, + const float* __restrict__ scratch_l, + const __bf16* __restrict__ scratch_acc, + const float* __restrict__ attn_sink, + __bf16* __restrict__ output, + int64_t out_stride0, int64_t out_stride1, + int num_heads, int num_head_blocks) +{ + const int query = blockIdx.x; + const int pid_h = blockIdx.y; + const int tid = threadIdx.x; + + const int my_row = tid >> 4; // 0..15 + const int my_col0 = (tid & 15) << 5; // 0,32,...,480 + const int head_global = pid_h * BLOCK_H + my_row; + + float m_merged = -3.4028234663852886e38f; + float l_merged = 0.f; + float acc_merged[32]; + #pragma unroll + for (int i = 0; i < 32; ++i) acc_merged[i] = 0.f; + + #pragma unroll + for (int s = 0; s < SPLIT_K; ++s) { + const int triple = (query * num_head_blocks + pid_h) * SPLIT_K + s; + float m_s = scratch_m[triple * BLOCK_H + my_row]; + float l_s = scratch_l[triple * BLOCK_H + my_row]; + + float m_new = fmaxf(m_merged, m_s); + float alpha = __builtin_amdgcn_exp2f( + (m_merged - m_new) * 1.4426950408889634f); + float beta = __builtin_amdgcn_exp2f( + (m_s - m_new) * 1.4426950408889634f); + l_merged = l_merged * alpha + l_s * beta; + m_merged = m_new; + + const __bf16* acc_base = scratch_acc + + (int64_t)triple * BLOCK_H * HEAD_DIM + + my_row * HEAD_DIM + my_col0; + const int4* src4 = reinterpret_cast(acc_base); + #pragma unroll + for (int i = 0; i < 4; ++i) { + int4 v = src4[i]; + __bf16 vbf[8]; + *reinterpret_cast(vbf) = v; + #pragma unroll + for (int j = 0; j < 8; ++j) { + float a_s = (float)vbf[j]; + acc_merged[i * 8 + j] = acc_merged[i * 8 + j] * alpha + + a_s * beta; + } + } + } + + if (head_global >= num_heads) return; + + float m_final = m_merged; + float l_final = l_merged; + float alpha_final = 1.f; + if (HAS_ATTN_SINK) { + float sink_val = attn_sink[head_global]; + m_final = fmaxf(m_merged, sink_val); + alpha_final = __builtin_amdgcn_exp2f( + (m_merged - m_final) * 1.4426950408889634f); + l_final = l_merged * alpha_final + __builtin_amdgcn_exp2f( + (sink_val - m_final) * 1.4426950408889634f); + } + float denom = fmaxf(l_final, 1.0e-30f); + bool live = (l_final > 0.f); + float inv_denom = live ? (alpha_final / denom) : 0.f; + + __bf16* out_row = output + query * out_stride0 + + head_global * out_stride1 + my_col0; + __bf16 out_buf[32]; + #pragma unroll + for (int i = 0; i < 32; ++i) out_buf[i] = (__bf16)(acc_merged[i] * inv_denom); + int4* dst4 = reinterpret_cast(out_row); + const int4* sb4 = reinterpret_cast(out_buf); + #pragma unroll + for (int i = 0; i < 4; ++i) dst4[i] = sb4[i]; +} + + +void sparse_mla_decode_single( + torch::Tensor q, + torch::Tensor main_cache, + torch::Tensor main_indices, + torch::Tensor main_indptr, + torch::Tensor extra_cache, + torch::Tensor extra_indices, + torch::Tensor extra_indptr, + c10::optional attn_sink, + torch::Tensor output, + int64_t main_block_size, + int64_t extra_block_size, + int64_t main_num_rows, + int64_t extra_num_rows, + double scale_d, + bool has_extra) +{ + const int num_queries = q.size(0); + const int num_heads = q.size(1); + const int num_head_blocks = (num_heads + BLOCK_H - 1) / BLOCK_H; + const float scale_f = (float)scale_d; + const bool has_sink = attn_sink.has_value(); + + dim3 grid(num_queries, num_head_blocks); + dim3 block(256); + + const __bf16* q_ptr = reinterpret_cast(q.data_ptr()); + const uint8_t* mc_ptr = reinterpret_cast(main_cache.data_ptr()); + const uint8_t* ec_ptr = reinterpret_cast(extra_cache.data_ptr()); + const int32_t* mi_ptr = main_indices.data_ptr(); + const int32_t* mip_ptr = main_indptr.data_ptr(); + const int32_t* ei_ptr = extra_indices.data_ptr(); + const int32_t* eip_ptr = extra_indptr.data_ptr(); + __bf16* out_ptr = reinterpret_cast<__bf16*>(output.data_ptr()); + const float* sink_ptr = has_sink + ? attn_sink.value().data_ptr() : nullptr; + + auto stream = at::cuda::getCurrentCUDAStream(); + + #define LAUNCH(HAS_S, HAS_E) do { \ + sparse_mla_decode_kernel<<>>( \ + q_ptr, mc_ptr, mi_ptr, mip_ptr, \ + ec_ptr, ei_ptr, eip_ptr, sink_ptr, out_ptr, \ + q.stride(0), q.stride(1), \ + output.stride(0), output.stride(1), \ + main_cache.stride(0), extra_cache.stride(0), \ + main_num_rows, extra_num_rows, \ + main_block_size, extra_block_size, \ + scale_f, num_heads); \ + } while (0) + + if (has_sink && has_extra) LAUNCH(true, true); + else if (has_sink) LAUNCH(true, false); + else if (has_extra) LAUNCH(false, true); + else LAUNCH(false, false); + + #undef LAUNCH +} + + +void sparse_mla_decode_split( + torch::Tensor q, + torch::Tensor main_cache, + torch::Tensor main_indices, + torch::Tensor main_indptr, + torch::Tensor extra_cache, + torch::Tensor extra_indices, + torch::Tensor extra_indptr, + c10::optional attn_sink, + torch::Tensor output, + torch::Tensor scratch_m, + torch::Tensor scratch_l, + torch::Tensor scratch_acc, + int64_t main_block_size, + int64_t extra_block_size, + int64_t main_num_rows, + int64_t extra_num_rows, + double scale_d, + bool has_extra, + int64_t split_k) +{ + const int num_queries = q.size(0); + const int num_heads = q.size(1); + const int num_head_blocks = (num_heads + BLOCK_H - 1) / BLOCK_H; + const float scale_f = (float)scale_d; + const bool has_sink = attn_sink.has_value(); + + dim3 grid_p(num_queries, num_head_blocks * (int)split_k); + dim3 grid_r(num_queries, num_head_blocks); + dim3 block_p(256); + dim3 block_r(256); + + const __bf16* q_ptr = reinterpret_cast(q.data_ptr()); + const uint8_t* mc_ptr = reinterpret_cast(main_cache.data_ptr()); + const uint8_t* ec_ptr = reinterpret_cast(extra_cache.data_ptr()); + const int32_t* mi_ptr = main_indices.data_ptr(); + const int32_t* mip_ptr = main_indptr.data_ptr(); + const int32_t* ei_ptr = extra_indices.data_ptr(); + const int32_t* eip_ptr = extra_indptr.data_ptr(); + __bf16* out_ptr = reinterpret_cast<__bf16*>(output.data_ptr()); + float* sm_ptr = scratch_m.data_ptr(); + float* sl_ptr = scratch_l.data_ptr(); + __bf16* sa_ptr = reinterpret_cast<__bf16*>(scratch_acc.data_ptr()); + const float* sink_ptr = has_sink + ? attn_sink.value().data_ptr() : nullptr; + + auto stream = at::cuda::getCurrentCUDAStream(); + + #define LAUNCH_P(HAS_E, SK) do { \ + sparse_mla_decode_partial_kernel<<>>( \ + q_ptr, mc_ptr, mi_ptr, mip_ptr, \ + ec_ptr, ei_ptr, eip_ptr, \ + sm_ptr, sl_ptr, sa_ptr, \ + q.stride(0), q.stride(1), \ + main_cache.stride(0), extra_cache.stride(0), \ + main_num_rows, extra_num_rows, \ + main_block_size, extra_block_size, \ + scale_f, num_heads, num_head_blocks); \ + } while (0) + + #define LAUNCH_R(HAS_S, SK) do { \ + sparse_mla_decode_reduce_kernel<<>>( \ + sm_ptr, sl_ptr, sa_ptr, sink_ptr, out_ptr, \ + output.stride(0), output.stride(1), \ + num_heads, num_head_blocks); \ + } while (0) + + #define DISPATCH_SK(SK) do { \ + if (has_extra) LAUNCH_P(true, SK); \ + else LAUNCH_P(false, SK); \ + if (has_sink) LAUNCH_R(true, SK); \ + else LAUNCH_R(false, SK); \ + } while (0) + + switch ((int)split_k) { + case 2: DISPATCH_SK(2); break; + case 4: DISPATCH_SK(4); break; + case 8: DISPATCH_SK(8); break; + case 16: DISPATCH_SK(16); break; + default: TORCH_CHECK(false, "Unsupported SPLIT_K"); + } + #undef DISPATCH_SK + #undef LAUNCH_P + #undef LAUNCH_R +} + + +TORCH_LIBRARY_FRAGMENT(vllm_sparse_mla_hip, m) { + m.def("decode_single(Tensor q, Tensor main_cache, Tensor main_indices, " + "Tensor main_indptr, Tensor extra_cache, Tensor extra_indices, " + "Tensor extra_indptr, Tensor? attn_sink, Tensor output, " + "int main_block_size, int extra_block_size, int main_num_rows, " + "int extra_num_rows, float scale, bool has_extra) -> ()"); + m.def("decode_split(Tensor q, Tensor main_cache, Tensor main_indices, " + "Tensor main_indptr, Tensor extra_cache, Tensor extra_indices, " + "Tensor extra_indptr, Tensor? attn_sink, Tensor output, " + "Tensor scratch_m, Tensor scratch_l, Tensor scratch_acc, " + "int main_block_size, int extra_block_size, int main_num_rows, " + "int extra_num_rows, float scale, bool has_extra, int split_k) -> ()"); +} +TORCH_LIBRARY_IMPL(vllm_sparse_mla_hip, CUDA, m) { + m.impl("decode_single", &sparse_mla_decode_single); + m.impl("decode_split", &sparse_mla_decode_split); +} +""" + +logger = logging.getLogger(__name__) + +_sparse_mla_hip_module_cache: dict = {} + + +def _build_sparse_mla_hip_ext(): + if "ext" in _sparse_mla_hip_module_cache: + return _sparse_mla_hip_module_cache["ext"] + cache_dir = os.environ.get( + "VLLM_SPARSE_MLA_HIP_CACHE_DIR", + str(pathlib.Path(tempfile.gettempdir()) / "vllm_sparse_mla_hip_cache"), + ) + os.makedirs(cache_dir, exist_ok=True) + os.environ["PYTORCH_ROCM_ARCH"] = "gfx950" + ext = load_inline( + name="vllm_sparse_mla_hip", + cpp_sources=[""], + cuda_sources=[_HIP_SPARSE_MLA_DECODE_SRC], + functions=[], + extra_cflags=["-O3", "-DNDEBUG", "-std=c++17"], + extra_cuda_cflags=[ + "-O3", + "-std=c++17", + "--offload-arch=gfx950", + "-DNDEBUG", + "-Wno-c++11-narrowing", + "-Wno-unused-result", + ], + with_cuda=True, + build_directory=cache_dir, + verbose=False, + ) + _sparse_mla_hip_module_cache["ext"] = ext + return ext + + +def _sparse_mla_hip_as_int32_1d(x): + if x.dtype != torch.int32: + x = x.to(torch.int32) + if not x.is_contiguous(): + x = x.contiguous() + return x.view(-1) + + +_NUM_CUS = 256 +_MIN_K_PER_SPLIT = int(os.environ.get("SPARSE_MLA_HIP_MIN_K_PER_SPLIT", "32")) +_SPLIT_K_OVERRIDE = os.environ.get("SPARSE_MLA_HIP_SPLIT_K") + + +def _pick_split_k(num_queries, num_head_blocks, max_total_k): + if _SPLIT_K_OVERRIDE is not None: + return int(_SPLIT_K_OVERRIDE) + base_tiles = max(1, num_queries * num_head_blocks) + cu_target = max(1, _NUM_CUS // base_tiles) + k_limit = max(1, max_total_k // _MIN_K_PER_SPLIT) + target = max(1, min(cu_target, k_limit)) + best = 1 + for b in (1, 2, 4, 8): + if b <= target: + best = b + return best + + +def _decode_sparse_mla_hip( + q, + main_cache, + main_indices, + main_indptr, + scale, + attn_sink, + nope_head_dim, + rope_head_dim, + extra_cache, + extra_indices, + extra_indptr, + max_main_len, + max_extra_len, +): + main_indices = _sparse_mla_hip_as_int32_1d(main_indices) + main_indptr = _sparse_mla_hip_as_int32_1d(main_indptr) + num_queries, num_heads, _ = q.shape + + has_extra = ( + extra_cache is not None + and extra_indices is not None + and extra_indptr is not None + ) + if has_extra: + extra_indices = _sparse_mla_hip_as_int32_1d(extra_indices) + extra_indptr = _sparse_mla_hip_as_int32_1d(extra_indptr) + else: + extra_cache = main_cache + extra_indices = torch.empty(0, device=q.device, dtype=torch.int32) + extra_indptr = torch.zeros(num_queries + 1, device=q.device, dtype=torch.int32) + + out = torch.empty_like(q, dtype=torch.bfloat16) + sink = attn_sink.contiguous() if attn_sink is not None else None + q_in = q.contiguous() if not q.is_contiguous() else q + + BLOCK_H = 16 + num_head_blocks = (num_heads + BLOCK_H - 1) // BLOCK_H + total_max_k = max_main_len + (max_extra_len if has_extra else 0) + split_k = _pick_split_k(num_queries, num_head_blocks, total_max_k) + + _build_sparse_mla_hip_ext() + + if split_k == 1: + torch.ops.vllm_sparse_mla_hip.decode_single( + q_in, + main_cache, + main_indices, + main_indptr, + extra_cache, + extra_indices, + extra_indptr, + sink, + out, + int(main_cache.shape[1]), + int(extra_cache.shape[1]), + int(main_cache.shape[0] * main_cache.shape[1]), + int(extra_cache.shape[0] * extra_cache.shape[1]), + float(scale), + bool(has_extra), + ) + else: + scratch_m = torch.empty( + num_queries * num_head_blocks * split_k * BLOCK_H, + device=q.device, + dtype=torch.float32, + ) + scratch_l = torch.empty_like(scratch_m) + scratch_acc = torch.empty( + num_queries * num_head_blocks * split_k * BLOCK_H * 512, + device=q.device, + dtype=torch.bfloat16, + ) + torch.ops.vllm_sparse_mla_hip.decode_split( + q_in, + main_cache, + main_indices, + main_indptr, + extra_cache, + extra_indices, + extra_indptr, + sink, + out, + scratch_m, + scratch_l, + scratch_acc, + int(main_cache.shape[1]), + int(extra_cache.shape[1]), + int(main_cache.shape[0] * main_cache.shape[1]), + int(extra_cache.shape[0] * extra_cache.shape[1]), + float(scale), + bool(has_extra), + int(split_k), + ) + return out + + def rocm_sparse_attn_decode( q: torch.Tensor, kv_cache: torch.Tensor | None, @@ -1671,7 +2573,7 @@ def rocm_sparse_attn_decode( output: torch.Tensor, ) -> None: assert swa_k_cache.dtype == torch.uint8, ( - "ROCm Triton sparse decode expects uint8 fp8_ds_mla SWA cache, " + "ROCm sparse decode expects uint8 fp8_ds_mla SWA cache, " f"got {swa_k_cache.dtype}" ) _validate_dsv4_sparse_dims( @@ -1680,39 +2582,72 @@ def rocm_sparse_attn_decode( rope_head_dim, "rocm_sparse_attn_decode", ) + assert nope_head_dim == 448 + assert rope_head_dim == 64 + + if swa_ragged_indices is None or swa_ragged_indptr is None: + main_indices_dense = swa_indices.reshape(swa_indices.shape[0], -1) + lengths = ( + swa_lens + if swa_lens is not None + else ((main_indices_dense >= 0).sum(dim=-1, dtype=torch.int32)) + ) + main_ragged_indices, main_ragged_indptr = build_ragged_indices_from_dense( + main_indices_dense, + lengths, + num_rows=swa_k_cache.shape[0] * swa_k_cache.shape[1], + ) + else: + main_ragged_indices = swa_ragged_indices + main_ragged_indptr = swa_ragged_indptr - main_indices = swa_indices.reshape(swa_indices.shape[0], -1) - + has_extra = not swa_only extra_cache = None - extra_indices = None - if not swa_only: + extra_ragged_indices = None + extra_ragged_indptr = None + if has_extra: assert kv_cache is not None - assert topk_indices is not None or ( - topk_ragged_indices is not None and topk_ragged_indptr is not None - ) - assert kv_cache.dtype == torch.uint8, ( - "ROCm Triton sparse decode expects uint8 fp8_ds_mla extra cache, " - f"got {kv_cache.dtype}" - ) + assert kv_cache.dtype == torch.uint8 extra_cache = kv_cache - if topk_indices is not None: - extra_indices = topk_indices.reshape(topk_indices.shape[0], -1) + if topk_ragged_indices is None or topk_ragged_indptr is None: + assert topk_indices is not None + ex_dense = topk_indices.reshape(topk_indices.shape[0], -1) + lengths = ( + topk_lens + if topk_lens is not None + else ((ex_dense >= 0).sum(dim=-1, dtype=torch.int32)) + ) + extra_ragged_indices, extra_ragged_indptr = build_ragged_indices_from_dense( + ex_dense, + lengths, + num_rows=kv_cache.shape[0] * kv_cache.shape[1], + ) + else: + extra_ragged_indices = topk_ragged_indices + extra_ragged_indptr = topk_ragged_indptr - attn_out = _rocm_sparse_attn_decode_triton( + if torch.cuda.is_current_stream_capturing(): + max_main_len = swa_indices.shape[-1] + max_extra_len = max_main_len if has_extra else 0 + else: + max_main_len = int(swa_lens.max().item()) if swa_lens is not None else 0 + max_extra_len = 0 + if has_extra and topk_lens is not None: + max_extra_len = int(topk_lens.max().item()) + + attn_out = _decode_sparse_mla_hip( q=q, main_cache=swa_k_cache, - main_indices=main_indices, + main_indices=main_ragged_indices, + main_indptr=main_ragged_indptr, scale=scale, attn_sink=None if attn_sink is None else attn_sink[: q.shape[1]], nope_head_dim=nope_head_dim, rope_head_dim=rope_head_dim, extra_cache=extra_cache, - extra_indices=extra_indices, - main_lengths=swa_lens, - extra_lengths=topk_lens, - main_ragged_indices=swa_ragged_indices, - main_ragged_indptr=swa_ragged_indptr, - extra_ragged_indices=topk_ragged_indices, - extra_ragged_indptr=topk_ragged_indptr, + extra_indices=extra_ragged_indices, + extra_indptr=extra_ragged_indptr, + max_main_len=max_main_len, + max_extra_len=max_extra_len, ) output.copy_(attn_out.to(output.dtype)) From 2419a515c14989e9c8475c20e2802e3111c537f6 Mon Sep 17 00:00:00 2001 From: Hemanth Acharya Date: Thu, 21 May 2026 05:33:39 -0500 Subject: [PATCH 2/9] add gfx950 gate Signed-off-by: Hemanth Acharya --- .../v1/attention/ops/rocm_aiter_mla_sparse.py | 119 +++++++++++++++++- 1 file changed, 117 insertions(+), 2 deletions(-) diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index a2cd873b13a0..7481be6aea3b 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -2582,6 +2582,121 @@ def rocm_sparse_attn_decode( rope_head_dim, "rocm_sparse_attn_decode", ) + + if _ON_GFX950: + _rocm_sparse_attn_decode_hip( + q=q, + kv_cache=kv_cache, + swa_k_cache=swa_k_cache, + swa_only=swa_only, + topk_indices=topk_indices, + topk_lens=topk_lens, + swa_indices=swa_indices, + swa_lens=swa_lens, + swa_ragged_indices=swa_ragged_indices, + swa_ragged_indptr=swa_ragged_indptr, + topk_ragged_indices=topk_ragged_indices, + topk_ragged_indptr=topk_ragged_indptr, + attn_sink=attn_sink, + scale=scale, + nope_head_dim=nope_head_dim, + rope_head_dim=rope_head_dim, + output=output, + ) + else: + _rocm_sparse_attn_decode_triton_path( + q=q, + kv_cache=kv_cache, + swa_k_cache=swa_k_cache, + swa_only=swa_only, + topk_indices=topk_indices, + topk_lens=topk_lens, + swa_indices=swa_indices, + swa_lens=swa_lens, + swa_ragged_indices=swa_ragged_indices, + swa_ragged_indptr=swa_ragged_indptr, + topk_ragged_indices=topk_ragged_indices, + topk_ragged_indptr=topk_ragged_indptr, + attn_sink=attn_sink, + scale=scale, + nope_head_dim=nope_head_dim, + rope_head_dim=rope_head_dim, + output=output, + ) + + +def _rocm_sparse_attn_decode_triton_path( + q, + kv_cache, + swa_k_cache, + swa_only, + topk_indices, + topk_lens, + swa_indices, + swa_lens, + swa_ragged_indices, + swa_ragged_indptr, + topk_ragged_indices, + topk_ragged_indptr, + attn_sink, + scale, + nope_head_dim, + rope_head_dim, + output, +): + main_indices = swa_indices.reshape(swa_indices.shape[0], -1) + + extra_cache = None + extra_indices = None + if not swa_only: + assert kv_cache is not None + assert topk_indices is not None or ( + topk_ragged_indices is not None and topk_ragged_indptr is not None + ) + assert kv_cache.dtype == torch.uint8 + extra_cache = kv_cache + if topk_indices is not None: + extra_indices = topk_indices.reshape(topk_indices.shape[0], -1) + + attn_out = _rocm_sparse_attn_decode_triton( + q=q, + main_cache=swa_k_cache, + main_indices=main_indices, + scale=scale, + attn_sink=None if attn_sink is None else attn_sink[: q.shape[1]], + nope_head_dim=nope_head_dim, + rope_head_dim=rope_head_dim, + extra_cache=extra_cache, + extra_indices=extra_indices, + main_lengths=swa_lens, + extra_lengths=topk_lens, + main_ragged_indices=swa_ragged_indices, + main_ragged_indptr=swa_ragged_indptr, + extra_ragged_indices=topk_ragged_indices, + extra_ragged_indptr=topk_ragged_indptr, + ) + output.copy_(attn_out.to(output.dtype)) + + +def _rocm_sparse_attn_decode_hip( + q, + kv_cache, + swa_k_cache, + swa_only, + topk_indices, + topk_lens, + swa_indices, + swa_lens, + swa_ragged_indices, + swa_ragged_indptr, + topk_ragged_indices, + topk_ragged_indptr, + attn_sink, + scale, + nope_head_dim, + rope_head_dim, + output, +): assert nope_head_dim == 448 assert rope_head_dim == 64 @@ -2590,7 +2705,7 @@ def rocm_sparse_attn_decode( lengths = ( swa_lens if swa_lens is not None - else ((main_indices_dense >= 0).sum(dim=-1, dtype=torch.int32)) + else (main_indices_dense >= 0).sum(dim=-1, dtype=torch.int32) ) main_ragged_indices, main_ragged_indptr = build_ragged_indices_from_dense( main_indices_dense, @@ -2615,7 +2730,7 @@ def rocm_sparse_attn_decode( lengths = ( topk_lens if topk_lens is not None - else ((ex_dense >= 0).sum(dim=-1, dtype=torch.int32)) + else (ex_dense >= 0).sum(dim=-1, dtype=torch.int32) ) extra_ragged_indices, extra_ragged_indptr = build_ragged_indices_from_dense( ex_dense, From 0a5f08a564484b589860df1ef7e4d2a8cef7dfc3 Mon Sep 17 00:00:00 2001 From: Hemanth Acharya Date: Mon, 25 May 2026 06:00:45 -0500 Subject: [PATCH 3/9] shifted HIP source to csrc Signed-off-by: Hemanth Acharya --- csrc/rocm/sparse_mla_decode.cu | 691 ++++++++++++++++ .../v1/attention/ops/rocm_aiter_mla_sparse.py | 741 +----------------- 2 files changed, 699 insertions(+), 733 deletions(-) create mode 100644 csrc/rocm/sparse_mla_decode.cu diff --git a/csrc/rocm/sparse_mla_decode.cu b/csrc/rocm/sparse_mla_decode.cu new file mode 100644 index 000000000000..0a4c89732fe0 --- /dev/null +++ b/csrc/rocm/sparse_mla_decode.cu @@ -0,0 +1,691 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +#include +#include +#include + +#include +#include + +using bf16x8 = __attribute__((__vector_size__(8 * sizeof(__bf16)))) __bf16; +using fx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; + +static constexpr int NOPE_DIM = 448; +static constexpr int ROPE_DIM = 64; +static constexpr int TOKEN_BYTES = 576; +static constexpr int SCALE_BYTES = 8; +static constexpr int HEAD_DIM = 512; +static constexpr int BLOCK_H = 16; +static constexpr int BLOCK_K = 32; +static constexpr int N_TILES = HEAD_DIM / 16; // 32 +static constexpr int QK_N_TILES = BLOCK_K / 16; // 2 + +__device__ __forceinline__ fx4 mfma_16x16x32_bf16(bf16x8 a, bf16x8 b, fx4 c) { + return __builtin_amdgcn_mfma_f32_16x16x32_bf16(a, b, c, 0, 0, 0); +} + +__device__ __forceinline__ void gather_and_dequant_k_tile( + int k_start, int k_len, const uint8_t* cache_base, int64_t cache_stride0, + int num_rows, int block_size, const int32_t* idx_base, __bf16* k_lds, + int8_t* kv_lds, int tid) { + const int tok_id = tid >> 3; // 0..31 + const int chunk = tid & 7; // 0..7 + const int col0 = chunk * 64; + + int k_pos = k_start + tok_id; + bool in_range = (k_pos < k_len); + int slot = in_range ? idx_base[k_pos] : 0; + bool valid = in_range && (slot >= 0) && (slot < num_rows); + int safe_slot = valid ? slot : 0; + int bi = safe_slot / block_size; + int pib = safe_slot - bi * block_size; + const uint8_t* block_ptr = cache_base + (int64_t)bi * cache_stride0; + const uint8_t* token_ptr = block_ptr + pib * TOKEN_BYTES; + + __bf16* dst_row = &k_lds[tok_id * HEAD_DIM + col0]; + + if (!valid) { + int4 z; + z.x = z.y = z.z = z.w = 0; + int4* d4 = reinterpret_cast(dst_row); +#pragma unroll + for (int j = 0; j < 8; ++j) d4[j] = z; + } else if (col0 < NOPE_DIM) { + const uint8_t* scale_ptr = + block_ptr + block_size * TOKEN_BYTES + pib * SCALE_BYTES; + uint8_t scl_u = scale_ptr[chunk]; + union { + uint32_t u; + float fv; + } sb; + sb.u = ((uint32_t)scl_u) << 23; + float scl_f = sb.fv; + + const uint32_t* src32 = reinterpret_cast(token_ptr + col0); +#pragma unroll + for (int u32_i = 0; u32_i < 16; ++u32_i) { + uint32_t word = src32[u32_i]; +#pragma unroll + for (int b = 0; b < 4; ++b) { + uint8_t kb = (word >> (b * 8)) & 0xFF; + uint32_t packed = (uint32_t)kb; + float f = __builtin_amdgcn_cvt_f32_fp8(packed, 0) * scl_f; + dst_row[u32_i * 4 + b] = (__bf16)f; + } + } + } else { + const int4* src4 = reinterpret_cast(token_ptr + NOPE_DIM); + int4* d4 = reinterpret_cast(dst_row); +#pragma unroll + for (int j = 0; j < 8; ++j) d4[j] = src4[j]; + } + + if (tid < BLOCK_K) { + int kp = k_start + tid; + int sl = (kp < k_len) ? idx_base[kp] : -1; + kv_lds[tid] = (kp < k_len) && (sl >= 0) && (sl < num_rows) ? 1 : 0; + } +} + +constexpr int N_TILES_PER_WAVE = 8; // 32 N-tiles / 4 waves +__device__ __forceinline__ void process_k_tile( + const __bf16* q_lds, const __bf16* k_lds, const int8_t* kv_lds, + __bf16* p_lds, float* scores_lds, float* m_state, float* l_state, fx4* acc, + float scale, int lane, int m_a, int kg, int n_b, int m_d_base, int n_d, + int wave) { + if (wave == 0) { + fx4 qk[2] = {{0.f, 0.f, 0.f, 0.f}, {0.f, 0.f, 0.f, 0.f}}; +#pragma unroll + for (int c = 0; c < HEAD_DIM / 32; ++c) { + bf16x8 q_reg; + const __bf16* q_src = &q_lds[m_a * HEAD_DIM + c * 32 + kg * 8]; +#pragma unroll + for (int i = 0; i < 8; ++i) q_reg[i] = q_src[i]; + +#pragma unroll + for (int nt = 0; nt < 2; ++nt) { + bf16x8 k_reg; + const __bf16* k_src = + &k_lds[(nt * 16 + n_b) * HEAD_DIM + c * 32 + kg * 8]; +#pragma unroll + for (int i = 0; i < 8; ++i) k_reg[i] = k_src[i]; + qk[nt] = mfma_16x16x32_bf16(q_reg, k_reg, qk[nt]); + } + } +#pragma unroll + for (int nt = 0; nt < 2; ++nt) { +#pragma unroll + for (int i = 0; i < 4; ++i) { + int k_col = nt * 16 + n_d; + float s = qk[nt][i] * scale; + if (!kv_lds[k_col]) s = -3.4028234663852886e38f; + scores_lds[(m_d_base + i) * BLOCK_K + nt * 16 + n_d] = s; + } + } + } + + __syncthreads(); + + fx4 qk_local[2]; +#pragma unroll + for (int i = 0; i < 4; ++i) { + qk_local[0][i] = scores_lds[(m_d_base + i) * BLOCK_K + n_d]; + qk_local[1][i] = scores_lds[(m_d_base + i) * BLOCK_K + 16 + n_d]; + } + + fx4 p[2]; +#pragma unroll + for (int i = 0; i < 4; ++i) { + float row_max = fmaxf(qk_local[0][i], qk_local[1][i]); + row_max = fmaxf(row_max, __shfl_xor(row_max, 1)); + row_max = fmaxf(row_max, __shfl_xor(row_max, 2)); + row_max = fmaxf(row_max, __shfl_xor(row_max, 4)); + row_max = fmaxf(row_max, __shfl_xor(row_max, 8)); + + float m_new = fmaxf(m_state[i], row_max); + float alpha = + __builtin_amdgcn_exp2f((m_state[i] - m_new) * 1.4426950408889634f); + + float e0 = + __builtin_amdgcn_exp2f((qk_local[0][i] - m_new) * 1.4426950408889634f); + float e1 = + __builtin_amdgcn_exp2f((qk_local[1][i] - m_new) * 1.4426950408889634f); + + float row_sum = e0 + e1; + row_sum += __shfl_xor(row_sum, 1); + row_sum += __shfl_xor(row_sum, 2); + row_sum += __shfl_xor(row_sum, 4); + row_sum += __shfl_xor(row_sum, 8); + + float l_new = l_state[i] * alpha + row_sum; + p[0][i] = e0; + p[1][i] = e1; + +#pragma unroll + for (int nt = 0; nt < N_TILES_PER_WAVE; ++nt) acc[nt][i] *= alpha; + + m_state[i] = m_new; + l_state[i] = l_new; + } + + if (wave == 0) { +#pragma unroll + for (int i = 0; i < 4; ++i) { + p_lds[(m_d_base + i) * BLOCK_K + n_d] = (__bf16)p[0][i]; + p_lds[(m_d_base + i) * BLOCK_K + 16 + n_d] = (__bf16)p[1][i]; + } + } + + __syncthreads(); + + bf16x8 p_reg; + const __bf16* p_src = &p_lds[m_a * BLOCK_K + kg * 8]; +#pragma unroll + for (int i = 0; i < 8; ++i) p_reg[i] = p_src[i]; + +#pragma unroll + for (int nt_local = 0; nt_local < N_TILES_PER_WAVE; ++nt_local) { + int n_tile = wave * N_TILES_PER_WAVE + nt_local; + bf16x8 k_reg; +#pragma unroll + for (int i = 0; i < 8; ++i) { + k_reg[i] = k_lds[(kg * 8 + i) * HEAD_DIM + n_tile * 16 + n_b]; + } + acc[nt_local] = mfma_16x16x32_bf16(p_reg, k_reg, acc[nt_local]); + } +} + +__device__ __forceinline__ void load_q(const __bf16* q, int64_t q_stride0, + int64_t q_stride1, int query, int pid_h, + int num_heads, __bf16* q_lds, int tid) { + const int qh = tid >> 4; // 0..15 + const int qc0 = (tid & 15) << 5; // 0,32,...,480 + const int head_global = pid_h * BLOCK_H + qh; + __bf16* dst = &q_lds[qh * HEAD_DIM + qc0]; + if (head_global < num_heads) { + const __bf16* src = q + query * q_stride0 + head_global * q_stride1 + qc0; + const int4* s4 = reinterpret_cast(src); + int4* d4 = reinterpret_cast(dst); +#pragma unroll + for (int i = 0; i < 4; ++i) d4[i] = s4[i]; + } else { + int4 z; + z.x = z.y = z.z = z.w = 0; + int4* d4 = reinterpret_cast(dst); +#pragma unroll + for (int i = 0; i < 4; ++i) d4[i] = z; + } +} + +template +__global__ __launch_bounds__(256, 2) void sparse_mla_decode_kernel( + const __bf16* __restrict__ q, const uint8_t* __restrict__ main_cache, + const int32_t* __restrict__ main_indices, + const int32_t* __restrict__ main_indptr, + const uint8_t* __restrict__ extra_cache, + const int32_t* __restrict__ extra_indices, + const int32_t* __restrict__ extra_indptr, + const float* __restrict__ attn_sink, __bf16* __restrict__ output, + int64_t q_stride0, int64_t q_stride1, int64_t out_stride0, + int64_t out_stride1, int64_t main_cache_stride0, + int64_t extra_cache_stride0, int main_num_rows, int extra_num_rows, + int main_block_size, int extra_block_size, float scale, int num_heads) { + const int query = blockIdx.x; + const int pid_h = blockIdx.y; + const int tid = threadIdx.x; + const int wave = tid >> 6; + const int lane = tid & 63; + + const int m_a = lane & 15; + const int kg = lane >> 4; + const int n_b = lane & 15; + const int m_d_base = (lane >> 4) * 4; + const int n_d = lane & 15; + + __shared__ __bf16 q_lds[BLOCK_H * HEAD_DIM]; + __shared__ __bf16 k_lds[BLOCK_K * HEAD_DIM]; + __shared__ __bf16 p_lds[BLOCK_H * BLOCK_K]; + __shared__ float scores_lds[BLOCK_H * BLOCK_K]; + __shared__ int8_t kv_lds[BLOCK_K]; + __shared__ char force_1wg_per_cu[48 * 1024]; // pads LDS to ~96 KB + (void)force_1wg_per_cu; + + load_q(q, q_stride0, q_stride1, query, pid_h, num_heads, q_lds, tid); + + float m_state[4], l_state[4]; + fx4 acc[N_TILES_PER_WAVE]; +#pragma unroll + for (int i = 0; i < 4; ++i) { + m_state[i] = -3.4028234663852886e38f; + l_state[i] = 0.f; + } +#pragma unroll + for (int i = 0; i < N_TILES_PER_WAVE; ++i) { + acc[i] = (fx4){0.f, 0.f, 0.f, 0.f}; + } + + __syncthreads(); + + { + int main_start = main_indptr[query]; + int main_end = main_indptr[query + 1]; + int main_len = main_end - main_start; + for (int k_start = 0; k_start < main_len; k_start += BLOCK_K) { + gather_and_dequant_k_tile( + k_start, main_len, main_cache, main_cache_stride0, main_num_rows, + main_block_size, main_indices + main_start, k_lds, kv_lds, tid); + __syncthreads(); + process_k_tile(q_lds, k_lds, kv_lds, p_lds, scores_lds, m_state, l_state, + acc, scale, lane, m_a, kg, n_b, m_d_base, n_d, wave); + __syncthreads(); + } + } + + if (HAS_EXTRA) { + int extra_start = extra_indptr[query]; + int extra_end = extra_indptr[query + 1]; + int extra_len = extra_end - extra_start; + for (int k_start = 0; k_start < extra_len; k_start += BLOCK_K) { + gather_and_dequant_k_tile( + k_start, extra_len, extra_cache, extra_cache_stride0, extra_num_rows, + extra_block_size, extra_indices + extra_start, k_lds, kv_lds, tid); + __syncthreads(); + process_k_tile(q_lds, k_lds, kv_lds, p_lds, scores_lds, m_state, l_state, + acc, scale, lane, m_a, kg, n_b, m_d_base, n_d, wave); + __syncthreads(); + } + } + + { +#pragma unroll + for (int i = 0; i < 4; ++i) { + int head_local = m_d_base + i; + int head_global = pid_h * BLOCK_H + head_local; + if (head_global >= num_heads) continue; + + float m_final = m_state[i]; + float l_final = l_state[i]; + float alpha_final = 1.f; + if (HAS_ATTN_SINK) { + float sink_val = attn_sink[head_global]; + m_final = fmaxf(m_state[i], sink_val); + alpha_final = __builtin_amdgcn_exp2f((m_state[i] - m_final) * + 1.4426950408889634f); + l_final = + l_state[i] * alpha_final + + __builtin_amdgcn_exp2f((sink_val - m_final) * 1.4426950408889634f); + } + float denom = fmaxf(l_final, 1.0e-30f); + bool live = (l_final > 0.f); + + __bf16* out_row = + output + query * out_stride0 + head_global * out_stride1; +#pragma unroll + for (int nt_local = 0; nt_local < N_TILES_PER_WAVE; ++nt_local) { + int n_tile = wave * N_TILES_PER_WAVE + nt_local; + int col = n_tile * 16 + n_d; + float v = live ? (acc[nt_local][i] * alpha_final) / denom : 0.f; + out_row[col] = (__bf16)v; + } + } + } +} + +template +__global__ __launch_bounds__(256, 2) void sparse_mla_decode_partial_kernel( + const __bf16* __restrict__ q, const uint8_t* __restrict__ main_cache, + const int32_t* __restrict__ main_indices, + const int32_t* __restrict__ main_indptr, + const uint8_t* __restrict__ extra_cache, + const int32_t* __restrict__ extra_indices, + const int32_t* __restrict__ extra_indptr, float* __restrict__ scratch_m, + float* __restrict__ scratch_l, __bf16* __restrict__ scratch_acc, + int64_t q_stride0, int64_t q_stride1, int64_t main_cache_stride0, + int64_t extra_cache_stride0, int main_num_rows, int extra_num_rows, + int main_block_size, int extra_block_size, float scale, int num_heads, + int num_head_blocks) { + const int query = blockIdx.x; + const int pid_hs = blockIdx.y; + const int pid_split = pid_hs / num_head_blocks; + const int pid_h = pid_hs - pid_split * num_head_blocks; + const int tid = threadIdx.x; + const int wave = tid >> 6; + const int lane = tid & 63; + + const int m_a = lane & 15; + const int kg = lane >> 4; + const int n_b = lane & 15; + const int m_d_base = (lane >> 4) * 4; + const int n_d = lane & 15; + + __shared__ __bf16 q_lds[BLOCK_H * HEAD_DIM]; + __shared__ __bf16 k_lds[BLOCK_K * HEAD_DIM]; + __shared__ __bf16 p_lds[BLOCK_H * BLOCK_K]; + __shared__ float scores_lds[BLOCK_H * BLOCK_K]; + __shared__ int8_t kv_lds[BLOCK_K]; + load_q(q, q_stride0, q_stride1, query, pid_h, num_heads, q_lds, tid); + + float m_state[4], l_state[4]; + fx4 acc[N_TILES_PER_WAVE]; +#pragma unroll + for (int i = 0; i < 4; ++i) { + m_state[i] = -3.4028234663852886e38f; + l_state[i] = 0.f; + } +#pragma unroll + for (int i = 0; i < N_TILES_PER_WAVE; ++i) { + acc[i] = (fx4){0.f, 0.f, 0.f, 0.f}; + } + + __syncthreads(); + + { + int main_start = main_indptr[query]; + int main_end = main_indptr[query + 1]; + int main_len = main_end - main_start; + for (int k_start = pid_split * BLOCK_K; k_start < main_len; + k_start += BLOCK_K * SPLIT_K) { + gather_and_dequant_k_tile( + k_start, main_len, main_cache, main_cache_stride0, main_num_rows, + main_block_size, main_indices + main_start, k_lds, kv_lds, tid); + __syncthreads(); + process_k_tile(q_lds, k_lds, kv_lds, p_lds, scores_lds, m_state, l_state, + acc, scale, lane, m_a, kg, n_b, m_d_base, n_d, wave); + __syncthreads(); + } + } + + if (HAS_EXTRA) { + int extra_start = extra_indptr[query]; + int extra_end = extra_indptr[query + 1]; + int extra_len = extra_end - extra_start; + for (int k_start = pid_split * BLOCK_K; k_start < extra_len; + k_start += BLOCK_K * SPLIT_K) { + gather_and_dequant_k_tile( + k_start, extra_len, extra_cache, extra_cache_stride0, extra_num_rows, + extra_block_size, extra_indices + extra_start, k_lds, kv_lds, tid); + __syncthreads(); + process_k_tile(q_lds, k_lds, kv_lds, p_lds, scores_lds, m_state, l_state, + acc, scale, lane, m_a, kg, n_b, m_d_base, n_d, wave); + __syncthreads(); + } + } + + const int triple = (query * num_head_blocks + pid_h) * SPLIT_K + pid_split; + + if (wave == 0 && n_d == 0) { +#pragma unroll + for (int i = 0; i < 4; ++i) { + int idx = triple * BLOCK_H + m_d_base + i; + scratch_m[idx] = m_state[i]; + scratch_l[idx] = l_state[i]; + } + } + + __syncthreads(); +#pragma unroll + for (int i = 0; i < 4; ++i) { + int row = m_d_base + i; +#pragma unroll + for (int nt_local = 0; nt_local < N_TILES_PER_WAVE; ++nt_local) { + int n_tile = wave * N_TILES_PER_WAVE + nt_local; + int col = n_tile * 16 + n_d; + k_lds[row * HEAD_DIM + col] = (__bf16)acc[nt_local][i]; + } + } + __syncthreads(); + + { + int my_row = tid >> 4; + int my_col0 = (tid & 15) << 5; + __bf16* dst = scratch_acc + (int64_t)triple * BLOCK_H * HEAD_DIM + + my_row * HEAD_DIM + my_col0; + const int4* src4 = + reinterpret_cast(&k_lds[my_row * HEAD_DIM + my_col0]); + int4* dst4 = reinterpret_cast(dst); +#pragma unroll + for (int i = 0; i < 4; ++i) dst4[i] = src4[i]; + } +} + +template +__global__ __launch_bounds__(256, 4) void sparse_mla_decode_reduce_kernel( + const float* __restrict__ scratch_m, const float* __restrict__ scratch_l, + const __bf16* __restrict__ scratch_acc, const float* __restrict__ attn_sink, + __bf16* __restrict__ output, int64_t out_stride0, int64_t out_stride1, + int num_heads, int num_head_blocks) { + const int query = blockIdx.x; + const int pid_h = blockIdx.y; + const int tid = threadIdx.x; + + const int my_row = tid >> 4; // 0..15 + const int my_col0 = (tid & 15) << 5; // 0,32,...,480 + const int head_global = pid_h * BLOCK_H + my_row; + + float m_merged = -3.4028234663852886e38f; + float l_merged = 0.f; + float acc_merged[32]; +#pragma unroll + for (int i = 0; i < 32; ++i) acc_merged[i] = 0.f; + +#pragma unroll + for (int s = 0; s < SPLIT_K; ++s) { + const int triple = (query * num_head_blocks + pid_h) * SPLIT_K + s; + float m_s = scratch_m[triple * BLOCK_H + my_row]; + float l_s = scratch_l[triple * BLOCK_H + my_row]; + + float m_new = fmaxf(m_merged, m_s); + float alpha = + __builtin_amdgcn_exp2f((m_merged - m_new) * 1.4426950408889634f); + float beta = __builtin_amdgcn_exp2f((m_s - m_new) * 1.4426950408889634f); + l_merged = l_merged * alpha + l_s * beta; + m_merged = m_new; + + const __bf16* acc_base = scratch_acc + + (int64_t)triple * BLOCK_H * HEAD_DIM + + my_row * HEAD_DIM + my_col0; + const int4* src4 = reinterpret_cast(acc_base); +#pragma unroll + for (int i = 0; i < 4; ++i) { + int4 v = src4[i]; + __bf16 vbf[8]; + *reinterpret_cast(vbf) = v; +#pragma unroll + for (int j = 0; j < 8; ++j) { + float a_s = (float)vbf[j]; + acc_merged[i * 8 + j] = acc_merged[i * 8 + j] * alpha + a_s * beta; + } + } + } + + if (head_global >= num_heads) return; + + float m_final = m_merged; + float l_final = l_merged; + float alpha_final = 1.f; + if (HAS_ATTN_SINK) { + float sink_val = attn_sink[head_global]; + m_final = fmaxf(m_merged, sink_val); + alpha_final = + __builtin_amdgcn_exp2f((m_merged - m_final) * 1.4426950408889634f); + l_final = + l_merged * alpha_final + + __builtin_amdgcn_exp2f((sink_val - m_final) * 1.4426950408889634f); + } + float denom = fmaxf(l_final, 1.0e-30f); + bool live = (l_final > 0.f); + float inv_denom = live ? (alpha_final / denom) : 0.f; + + __bf16* out_row = + output + query * out_stride0 + head_global * out_stride1 + my_col0; + __bf16 out_buf[32]; +#pragma unroll + for (int i = 0; i < 32; ++i) out_buf[i] = (__bf16)(acc_merged[i] * inv_denom); + int4* dst4 = reinterpret_cast(out_row); + const int4* sb4 = reinterpret_cast(out_buf); +#pragma unroll + for (int i = 0; i < 4; ++i) dst4[i] = sb4[i]; +} + +void sparse_mla_decode_single( + torch::Tensor q, torch::Tensor main_cache, torch::Tensor main_indices, + torch::Tensor main_indptr, torch::Tensor extra_cache, + torch::Tensor extra_indices, torch::Tensor extra_indptr, + c10::optional attn_sink, torch::Tensor output, + int64_t main_block_size, int64_t extra_block_size, int64_t main_num_rows, + int64_t extra_num_rows, double scale_d, bool has_extra) { + const int num_queries = q.size(0); + const int num_heads = q.size(1); + const int num_head_blocks = (num_heads + BLOCK_H - 1) / BLOCK_H; + const float scale_f = (float)scale_d; + const bool has_sink = attn_sink.has_value(); + + dim3 grid(num_queries, num_head_blocks); + dim3 block(256); + + const __bf16* q_ptr = reinterpret_cast(q.data_ptr()); + const uint8_t* mc_ptr = + reinterpret_cast(main_cache.data_ptr()); + const uint8_t* ec_ptr = + reinterpret_cast(extra_cache.data_ptr()); + const int32_t* mi_ptr = main_indices.data_ptr(); + const int32_t* mip_ptr = main_indptr.data_ptr(); + const int32_t* ei_ptr = extra_indices.data_ptr(); + const int32_t* eip_ptr = extra_indptr.data_ptr(); + __bf16* out_ptr = reinterpret_cast<__bf16*>(output.data_ptr()); + const float* sink_ptr = + has_sink ? attn_sink.value().data_ptr() : nullptr; + + auto stream = at::cuda::getCurrentCUDAStream(); + +#define LAUNCH(HAS_S, HAS_E) \ + do { \ + sparse_mla_decode_kernel<<>>( \ + q_ptr, mc_ptr, mi_ptr, mip_ptr, ec_ptr, ei_ptr, eip_ptr, sink_ptr, \ + out_ptr, q.stride(0), q.stride(1), output.stride(0), output.stride(1), \ + main_cache.stride(0), extra_cache.stride(0), main_num_rows, \ + extra_num_rows, main_block_size, extra_block_size, scale_f, \ + num_heads); \ + } while (0) + + if (has_sink && has_extra) + LAUNCH(true, true); + else if (has_sink) + LAUNCH(true, false); + else if (has_extra) + LAUNCH(false, true); + else + LAUNCH(false, false); + +#undef LAUNCH +} + +void sparse_mla_decode_split( + torch::Tensor q, torch::Tensor main_cache, torch::Tensor main_indices, + torch::Tensor main_indptr, torch::Tensor extra_cache, + torch::Tensor extra_indices, torch::Tensor extra_indptr, + c10::optional attn_sink, torch::Tensor output, + torch::Tensor scratch_m, torch::Tensor scratch_l, torch::Tensor scratch_acc, + int64_t main_block_size, int64_t extra_block_size, int64_t main_num_rows, + int64_t extra_num_rows, double scale_d, bool has_extra, int64_t split_k) { + const int num_queries = q.size(0); + const int num_heads = q.size(1); + const int num_head_blocks = (num_heads + BLOCK_H - 1) / BLOCK_H; + const float scale_f = (float)scale_d; + const bool has_sink = attn_sink.has_value(); + + dim3 grid_p(num_queries, num_head_blocks * (int)split_k); + dim3 grid_r(num_queries, num_head_blocks); + dim3 block_p(256); + dim3 block_r(256); + + const __bf16* q_ptr = reinterpret_cast(q.data_ptr()); + const uint8_t* mc_ptr = + reinterpret_cast(main_cache.data_ptr()); + const uint8_t* ec_ptr = + reinterpret_cast(extra_cache.data_ptr()); + const int32_t* mi_ptr = main_indices.data_ptr(); + const int32_t* mip_ptr = main_indptr.data_ptr(); + const int32_t* ei_ptr = extra_indices.data_ptr(); + const int32_t* eip_ptr = extra_indptr.data_ptr(); + __bf16* out_ptr = reinterpret_cast<__bf16*>(output.data_ptr()); + float* sm_ptr = scratch_m.data_ptr(); + float* sl_ptr = scratch_l.data_ptr(); + __bf16* sa_ptr = reinterpret_cast<__bf16*>(scratch_acc.data_ptr()); + const float* sink_ptr = + has_sink ? attn_sink.value().data_ptr() : nullptr; + + auto stream = at::cuda::getCurrentCUDAStream(); + +#define LAUNCH_P(HAS_E, SK) \ + do { \ + sparse_mla_decode_partial_kernel \ + <<>>( \ + q_ptr, mc_ptr, mi_ptr, mip_ptr, ec_ptr, ei_ptr, eip_ptr, sm_ptr, \ + sl_ptr, sa_ptr, q.stride(0), q.stride(1), main_cache.stride(0), \ + extra_cache.stride(0), main_num_rows, extra_num_rows, \ + main_block_size, extra_block_size, scale_f, num_heads, \ + num_head_blocks); \ + } while (0) + +#define LAUNCH_R(HAS_S, SK) \ + do { \ + sparse_mla_decode_reduce_kernel \ + <<>>( \ + sm_ptr, sl_ptr, sa_ptr, sink_ptr, out_ptr, output.stride(0), \ + output.stride(1), num_heads, num_head_blocks); \ + } while (0) + +#define DISPATCH_SK(SK) \ + do { \ + if (has_extra) \ + LAUNCH_P(true, SK); \ + else \ + LAUNCH_P(false, SK); \ + if (has_sink) \ + LAUNCH_R(true, SK); \ + else \ + LAUNCH_R(false, SK); \ + } while (0) + + switch ((int)split_k) { + case 2: + DISPATCH_SK(2); + break; + case 4: + DISPATCH_SK(4); + break; + case 8: + DISPATCH_SK(8); + break; + case 16: + DISPATCH_SK(16); + break; + default: + TORCH_CHECK(false, "Unsupported SPLIT_K"); + } +#undef DISPATCH_SK +#undef LAUNCH_P +#undef LAUNCH_R +} + +TORCH_LIBRARY_FRAGMENT(vllm_sparse_mla_hip, m) { + m.def( + "decode_single(Tensor q, Tensor main_cache, Tensor main_indices, " + "Tensor main_indptr, Tensor extra_cache, Tensor extra_indices, " + "Tensor extra_indptr, Tensor? attn_sink, Tensor output, " + "int main_block_size, int extra_block_size, int main_num_rows, " + "int extra_num_rows, float scale, bool has_extra) -> ()"); + m.def( + "decode_split(Tensor q, Tensor main_cache, Tensor main_indices, " + "Tensor main_indptr, Tensor extra_cache, Tensor extra_indices, " + "Tensor extra_indptr, Tensor? attn_sink, Tensor output, " + "Tensor scratch_m, Tensor scratch_l, Tensor scratch_acc, " + "int main_block_size, int extra_block_size, int main_num_rows, " + "int extra_num_rows, float scale, bool has_extra, int split_k) -> ()"); +} +TORCH_LIBRARY_IMPL(vllm_sparse_mla_hip, CUDA, m) { + m.impl("decode_single", &sparse_mla_decode_single); + m.impl("decode_split", &sparse_mla_decode_split); +} diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 7481be6aea3b..b28dd7c45d78 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -1659,738 +1659,13 @@ def rocm_sparse_attn_prefill( # HIP MFMA kernel implementation for sparse-MLA decode. # ============================================================================ -_HIP_SPARSE_MLA_DECODE_SRC = r""" -#include -#include -#include - -#include -#include - -using bf16x8 = __attribute__((__vector_size__(8 * sizeof(__bf16)))) __bf16; -using fx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; - -static constexpr int NOPE_DIM = 448; -static constexpr int ROPE_DIM = 64; -static constexpr int TOKEN_BYTES = 576; -static constexpr int SCALE_BYTES = 8; -static constexpr int HEAD_DIM = 512; -static constexpr int BLOCK_H = 16; -static constexpr int BLOCK_K = 32; -static constexpr int N_TILES = HEAD_DIM / 16; // 32 -static constexpr int QK_N_TILES = BLOCK_K / 16; // 2 - -__device__ __forceinline__ fx4 mfma_16x16x32_bf16( - bf16x8 a, bf16x8 b, fx4 c) { - return __builtin_amdgcn_mfma_f32_16x16x32_bf16(a, b, c, 0, 0, 0); -} - -__device__ __forceinline__ void gather_and_dequant_k_tile( - int k_start, int k_len, const uint8_t* cache_base, - int64_t cache_stride0, int num_rows, int block_size, - const int32_t* idx_base, - __bf16* k_lds, int8_t* kv_lds, int tid) -{ - const int tok_id = tid >> 3; // 0..31 - const int chunk = tid & 7; // 0..7 - const int col0 = chunk * 64; - - int k_pos = k_start + tok_id; - bool in_range = (k_pos < k_len); - int slot = in_range ? idx_base[k_pos] : 0; - bool valid = in_range && (slot >= 0) && (slot < num_rows); - int safe_slot = valid ? slot : 0; - int bi = safe_slot / block_size; - int pib = safe_slot - bi * block_size; - const uint8_t* block_ptr = cache_base - + (int64_t)bi * cache_stride0; - const uint8_t* token_ptr = block_ptr + pib * TOKEN_BYTES; - - __bf16* dst_row = &k_lds[tok_id * HEAD_DIM + col0]; - - if (!valid) { - int4 z; z.x = z.y = z.z = z.w = 0; - int4* d4 = reinterpret_cast(dst_row); - #pragma unroll - for (int j = 0; j < 8; ++j) d4[j] = z; - } else if (col0 < NOPE_DIM) { - const uint8_t* scale_ptr = block_ptr - + block_size * TOKEN_BYTES - + pib * SCALE_BYTES; - uint8_t scl_u = scale_ptr[chunk]; - union { uint32_t u; float fv; } sb; - sb.u = ((uint32_t)scl_u) << 23; - float scl_f = sb.fv; - - const uint32_t* src32 = reinterpret_cast( - token_ptr + col0); - #pragma unroll - for (int u32_i = 0; u32_i < 16; ++u32_i) { - uint32_t word = src32[u32_i]; - #pragma unroll - for (int b = 0; b < 4; ++b) { - uint8_t kb = (word >> (b * 8)) & 0xFF; - uint32_t packed = (uint32_t)kb; - float f = __builtin_amdgcn_cvt_f32_fp8(packed, 0) * scl_f; - dst_row[u32_i * 4 + b] = (__bf16)f; - } - } - } else { - const int4* src4 = reinterpret_cast( - token_ptr + NOPE_DIM); - int4* d4 = reinterpret_cast(dst_row); - #pragma unroll - for (int j = 0; j < 8; ++j) d4[j] = src4[j]; - } - - if (tid < BLOCK_K) { - int kp = k_start + tid; - int sl = (kp < k_len) ? idx_base[kp] : -1; - kv_lds[tid] = (kp < k_len) && (sl >= 0) && (sl < num_rows) ? 1 : 0; - } -} - - -constexpr int N_TILES_PER_WAVE = 8; // 32 N-tiles / 4 waves -__device__ __forceinline__ void process_k_tile( - const __bf16* q_lds, const __bf16* k_lds, const int8_t* kv_lds, - __bf16* p_lds, float* scores_lds, - float* m_state, float* l_state, fx4* acc, float scale, - int lane, int m_a, int kg, int n_b, int m_d_base, int n_d, - int wave) -{ - if (wave == 0) { - fx4 qk[2] = {{0.f, 0.f, 0.f, 0.f}, {0.f, 0.f, 0.f, 0.f}}; - #pragma unroll - for (int c = 0; c < HEAD_DIM / 32; ++c) { - bf16x8 q_reg; - const __bf16* q_src = &q_lds[m_a * HEAD_DIM + c * 32 + kg * 8]; - #pragma unroll - for (int i = 0; i < 8; ++i) q_reg[i] = q_src[i]; - - #pragma unroll - for (int nt = 0; nt < 2; ++nt) { - bf16x8 k_reg; - const __bf16* k_src = &k_lds[(nt * 16 + n_b) * HEAD_DIM - + c * 32 + kg * 8]; - #pragma unroll - for (int i = 0; i < 8; ++i) k_reg[i] = k_src[i]; - qk[nt] = mfma_16x16x32_bf16(q_reg, k_reg, qk[nt]); - } - } - #pragma unroll - for (int nt = 0; nt < 2; ++nt) { - #pragma unroll - for (int i = 0; i < 4; ++i) { - int k_col = nt * 16 + n_d; - float s = qk[nt][i] * scale; - if (!kv_lds[k_col]) s = -3.4028234663852886e38f; - scores_lds[(m_d_base + i) * BLOCK_K + nt * 16 + n_d] = s; - } - } - } - - __syncthreads(); - - fx4 qk_local[2]; - #pragma unroll - for (int i = 0; i < 4; ++i) { - qk_local[0][i] = scores_lds[(m_d_base + i) * BLOCK_K + n_d]; - qk_local[1][i] = scores_lds[(m_d_base + i) * BLOCK_K + 16 + n_d]; - } - - fx4 p[2]; - #pragma unroll - for (int i = 0; i < 4; ++i) { - float row_max = fmaxf(qk_local[0][i], qk_local[1][i]); - row_max = fmaxf(row_max, __shfl_xor(row_max, 1)); - row_max = fmaxf(row_max, __shfl_xor(row_max, 2)); - row_max = fmaxf(row_max, __shfl_xor(row_max, 4)); - row_max = fmaxf(row_max, __shfl_xor(row_max, 8)); - - float m_new = fmaxf(m_state[i], row_max); - float alpha = __builtin_amdgcn_exp2f( - (m_state[i] - m_new) * 1.4426950408889634f); - - float e0 = __builtin_amdgcn_exp2f( - (qk_local[0][i] - m_new) * 1.4426950408889634f); - float e1 = __builtin_amdgcn_exp2f( - (qk_local[1][i] - m_new) * 1.4426950408889634f); - - float row_sum = e0 + e1; - row_sum += __shfl_xor(row_sum, 1); - row_sum += __shfl_xor(row_sum, 2); - row_sum += __shfl_xor(row_sum, 4); - row_sum += __shfl_xor(row_sum, 8); - - float l_new = l_state[i] * alpha + row_sum; - p[0][i] = e0; - p[1][i] = e1; - - #pragma unroll - for (int nt = 0; nt < N_TILES_PER_WAVE; ++nt) acc[nt][i] *= alpha; - - m_state[i] = m_new; - l_state[i] = l_new; - } - - if (wave == 0) { - #pragma unroll - for (int i = 0; i < 4; ++i) { - p_lds[(m_d_base + i) * BLOCK_K + n_d] = (__bf16)p[0][i]; - p_lds[(m_d_base + i) * BLOCK_K + 16 + n_d] = (__bf16)p[1][i]; - } - } - - __syncthreads(); - - bf16x8 p_reg; - const __bf16* p_src = &p_lds[m_a * BLOCK_K + kg * 8]; - #pragma unroll - for (int i = 0; i < 8; ++i) p_reg[i] = p_src[i]; - - #pragma unroll - for (int nt_local = 0; nt_local < N_TILES_PER_WAVE; ++nt_local) { - int n_tile = wave * N_TILES_PER_WAVE + nt_local; - bf16x8 k_reg; - #pragma unroll - for (int i = 0; i < 8; ++i) { - k_reg[i] = k_lds[(kg * 8 + i) * HEAD_DIM - + n_tile * 16 + n_b]; - } - acc[nt_local] = mfma_16x16x32_bf16(p_reg, k_reg, acc[nt_local]); - } -} - - -__device__ __forceinline__ void load_q( - const __bf16* q, int64_t q_stride0, int64_t q_stride1, - int query, int pid_h, int num_heads, - __bf16* q_lds, int tid) -{ - const int qh = tid >> 4; // 0..15 - const int qc0 = (tid & 15) << 5; // 0,32,...,480 - const int head_global = pid_h * BLOCK_H + qh; - __bf16* dst = &q_lds[qh * HEAD_DIM + qc0]; - if (head_global < num_heads) { - const __bf16* src = q + query * q_stride0 - + head_global * q_stride1 + qc0; - const int4* s4 = reinterpret_cast(src); - int4* d4 = reinterpret_cast(dst); - #pragma unroll - for (int i = 0; i < 4; ++i) d4[i] = s4[i]; - } else { - int4 z; z.x = z.y = z.z = z.w = 0; - int4* d4 = reinterpret_cast(dst); - #pragma unroll - for (int i = 0; i < 4; ++i) d4[i] = z; - } -} - - -template -__global__ __launch_bounds__(256, 2) -void sparse_mla_decode_kernel( - const __bf16* __restrict__ q, - const uint8_t* __restrict__ main_cache, - const int32_t* __restrict__ main_indices, - const int32_t* __restrict__ main_indptr, - const uint8_t* __restrict__ extra_cache, - const int32_t* __restrict__ extra_indices, - const int32_t* __restrict__ extra_indptr, - const float* __restrict__ attn_sink, - __bf16* __restrict__ output, - int64_t q_stride0, int64_t q_stride1, - int64_t out_stride0, int64_t out_stride1, - int64_t main_cache_stride0, int64_t extra_cache_stride0, - int main_num_rows, int extra_num_rows, - int main_block_size, int extra_block_size, - float scale, int num_heads) -{ - const int query = blockIdx.x; - const int pid_h = blockIdx.y; - const int tid = threadIdx.x; - const int wave = tid >> 6; - const int lane = tid & 63; - - const int m_a = lane & 15; - const int kg = lane >> 4; - const int n_b = lane & 15; - const int m_d_base = (lane >> 4) * 4; - const int n_d = lane & 15; - - __shared__ __bf16 q_lds[BLOCK_H * HEAD_DIM]; - __shared__ __bf16 k_lds[BLOCK_K * HEAD_DIM]; - __shared__ __bf16 p_lds[BLOCK_H * BLOCK_K]; - __shared__ float scores_lds[BLOCK_H * BLOCK_K]; - __shared__ int8_t kv_lds[BLOCK_K]; - __shared__ char force_1wg_per_cu[48 * 1024]; // pads LDS to ~96 KB - (void)force_1wg_per_cu; - - load_q(q, q_stride0, q_stride1, query, pid_h, num_heads, q_lds, tid); - - float m_state[4], l_state[4]; - fx4 acc[N_TILES_PER_WAVE]; - #pragma unroll - for (int i = 0; i < 4; ++i) { - m_state[i] = -3.4028234663852886e38f; - l_state[i] = 0.f; - } - #pragma unroll - for (int i = 0; i < N_TILES_PER_WAVE; ++i) { - acc[i] = (fx4){0.f, 0.f, 0.f, 0.f}; - } - - __syncthreads(); - - { - int main_start = main_indptr[query]; - int main_end = main_indptr[query + 1]; - int main_len = main_end - main_start; - for (int k_start = 0; k_start < main_len; k_start += BLOCK_K) { - gather_and_dequant_k_tile( - k_start, main_len, main_cache, main_cache_stride0, - main_num_rows, main_block_size, - main_indices + main_start, k_lds, kv_lds, tid); - __syncthreads(); - process_k_tile(q_lds, k_lds, kv_lds, p_lds, scores_lds, - m_state, l_state, acc, scale, - lane, m_a, kg, n_b, m_d_base, n_d, wave); - __syncthreads(); - } - } - - if (HAS_EXTRA) { - int extra_start = extra_indptr[query]; - int extra_end = extra_indptr[query + 1]; - int extra_len = extra_end - extra_start; - for (int k_start = 0; k_start < extra_len; k_start += BLOCK_K) { - gather_and_dequant_k_tile( - k_start, extra_len, extra_cache, extra_cache_stride0, - extra_num_rows, extra_block_size, - extra_indices + extra_start, k_lds, kv_lds, tid); - __syncthreads(); - process_k_tile(q_lds, k_lds, kv_lds, p_lds, scores_lds, - m_state, l_state, acc, scale, - lane, m_a, kg, n_b, m_d_base, n_d, wave); - __syncthreads(); - } - } - - { - #pragma unroll - for (int i = 0; i < 4; ++i) { - int head_local = m_d_base + i; - int head_global = pid_h * BLOCK_H + head_local; - if (head_global >= num_heads) continue; - - float m_final = m_state[i]; - float l_final = l_state[i]; - float alpha_final = 1.f; - if (HAS_ATTN_SINK) { - float sink_val = attn_sink[head_global]; - m_final = fmaxf(m_state[i], sink_val); - alpha_final = __builtin_amdgcn_exp2f( - (m_state[i] - m_final) * 1.4426950408889634f); - l_final = l_state[i] * alpha_final + __builtin_amdgcn_exp2f( - (sink_val - m_final) * 1.4426950408889634f); - } - float denom = fmaxf(l_final, 1.0e-30f); - bool live = (l_final > 0.f); - - __bf16* out_row = output + query * out_stride0 - + head_global * out_stride1; - #pragma unroll - for (int nt_local = 0; nt_local < N_TILES_PER_WAVE; ++nt_local) { - int n_tile = wave * N_TILES_PER_WAVE + nt_local; - int col = n_tile * 16 + n_d; - float v = live ? (acc[nt_local][i] * alpha_final) / denom : 0.f; - out_row[col] = (__bf16)v; - } - } - } -} - - -template -__global__ __launch_bounds__(256, 2) -void sparse_mla_decode_partial_kernel( - const __bf16* __restrict__ q, - const uint8_t* __restrict__ main_cache, - const int32_t* __restrict__ main_indices, - const int32_t* __restrict__ main_indptr, - const uint8_t* __restrict__ extra_cache, - const int32_t* __restrict__ extra_indices, - const int32_t* __restrict__ extra_indptr, - float* __restrict__ scratch_m, - float* __restrict__ scratch_l, - __bf16* __restrict__ scratch_acc, - int64_t q_stride0, int64_t q_stride1, - int64_t main_cache_stride0, int64_t extra_cache_stride0, - int main_num_rows, int extra_num_rows, - int main_block_size, int extra_block_size, - float scale, int num_heads, int num_head_blocks) -{ - const int query = blockIdx.x; - const int pid_hs = blockIdx.y; - const int pid_split = pid_hs / num_head_blocks; - const int pid_h = pid_hs - pid_split * num_head_blocks; - const int tid = threadIdx.x; - const int wave = tid >> 6; - const int lane = tid & 63; - - const int m_a = lane & 15; - const int kg = lane >> 4; - const int n_b = lane & 15; - const int m_d_base = (lane >> 4) * 4; - const int n_d = lane & 15; - - __shared__ __bf16 q_lds[BLOCK_H * HEAD_DIM]; - __shared__ __bf16 k_lds[BLOCK_K * HEAD_DIM]; - __shared__ __bf16 p_lds[BLOCK_H * BLOCK_K]; - __shared__ float scores_lds[BLOCK_H * BLOCK_K]; - __shared__ int8_t kv_lds[BLOCK_K]; - load_q(q, q_stride0, q_stride1, query, pid_h, num_heads, q_lds, tid); - - float m_state[4], l_state[4]; - fx4 acc[N_TILES_PER_WAVE]; - #pragma unroll - for (int i = 0; i < 4; ++i) { - m_state[i] = -3.4028234663852886e38f; - l_state[i] = 0.f; - } - #pragma unroll - for (int i = 0; i < N_TILES_PER_WAVE; ++i) { - acc[i] = (fx4){0.f, 0.f, 0.f, 0.f}; - } - - __syncthreads(); - - { - int main_start = main_indptr[query]; - int main_end = main_indptr[query + 1]; - int main_len = main_end - main_start; - for (int k_start = pid_split * BLOCK_K; k_start < main_len; - k_start += BLOCK_K * SPLIT_K) { - gather_and_dequant_k_tile( - k_start, main_len, main_cache, main_cache_stride0, - main_num_rows, main_block_size, - main_indices + main_start, k_lds, kv_lds, tid); - __syncthreads(); - process_k_tile(q_lds, k_lds, kv_lds, p_lds, scores_lds, - m_state, l_state, acc, scale, - lane, m_a, kg, n_b, m_d_base, n_d, wave); - __syncthreads(); - } - } - - if (HAS_EXTRA) { - int extra_start = extra_indptr[query]; - int extra_end = extra_indptr[query + 1]; - int extra_len = extra_end - extra_start; - for (int k_start = pid_split * BLOCK_K; k_start < extra_len; - k_start += BLOCK_K * SPLIT_K) { - gather_and_dequant_k_tile( - k_start, extra_len, extra_cache, extra_cache_stride0, - extra_num_rows, extra_block_size, - extra_indices + extra_start, k_lds, kv_lds, tid); - __syncthreads(); - process_k_tile(q_lds, k_lds, kv_lds, p_lds, scores_lds, - m_state, l_state, acc, scale, - lane, m_a, kg, n_b, m_d_base, n_d, wave); - __syncthreads(); - } - } - - const int triple = (query * num_head_blocks + pid_h) * SPLIT_K + pid_split; - - if (wave == 0 && n_d == 0) { - #pragma unroll - for (int i = 0; i < 4; ++i) { - int idx = triple * BLOCK_H + m_d_base + i; - scratch_m[idx] = m_state[i]; - scratch_l[idx] = l_state[i]; - } - } - - __syncthreads(); - #pragma unroll - for (int i = 0; i < 4; ++i) { - int row = m_d_base + i; - #pragma unroll - for (int nt_local = 0; nt_local < N_TILES_PER_WAVE; ++nt_local) { - int n_tile = wave * N_TILES_PER_WAVE + nt_local; - int col = n_tile * 16 + n_d; - k_lds[row * HEAD_DIM + col] = (__bf16)acc[nt_local][i]; - } - } - __syncthreads(); - - { - int my_row = tid >> 4; - int my_col0 = (tid & 15) << 5; - __bf16* dst = scratch_acc + (int64_t)triple * BLOCK_H * HEAD_DIM - + my_row * HEAD_DIM + my_col0; - const int4* src4 = reinterpret_cast( - &k_lds[my_row * HEAD_DIM + my_col0]); - int4* dst4 = reinterpret_cast(dst); - #pragma unroll - for (int i = 0; i < 4; ++i) dst4[i] = src4[i]; - } -} - - -template -__global__ __launch_bounds__(256, 4) -void sparse_mla_decode_reduce_kernel( - const float* __restrict__ scratch_m, - const float* __restrict__ scratch_l, - const __bf16* __restrict__ scratch_acc, - const float* __restrict__ attn_sink, - __bf16* __restrict__ output, - int64_t out_stride0, int64_t out_stride1, - int num_heads, int num_head_blocks) -{ - const int query = blockIdx.x; - const int pid_h = blockIdx.y; - const int tid = threadIdx.x; - - const int my_row = tid >> 4; // 0..15 - const int my_col0 = (tid & 15) << 5; // 0,32,...,480 - const int head_global = pid_h * BLOCK_H + my_row; - - float m_merged = -3.4028234663852886e38f; - float l_merged = 0.f; - float acc_merged[32]; - #pragma unroll - for (int i = 0; i < 32; ++i) acc_merged[i] = 0.f; - - #pragma unroll - for (int s = 0; s < SPLIT_K; ++s) { - const int triple = (query * num_head_blocks + pid_h) * SPLIT_K + s; - float m_s = scratch_m[triple * BLOCK_H + my_row]; - float l_s = scratch_l[triple * BLOCK_H + my_row]; - - float m_new = fmaxf(m_merged, m_s); - float alpha = __builtin_amdgcn_exp2f( - (m_merged - m_new) * 1.4426950408889634f); - float beta = __builtin_amdgcn_exp2f( - (m_s - m_new) * 1.4426950408889634f); - l_merged = l_merged * alpha + l_s * beta; - m_merged = m_new; - - const __bf16* acc_base = scratch_acc - + (int64_t)triple * BLOCK_H * HEAD_DIM - + my_row * HEAD_DIM + my_col0; - const int4* src4 = reinterpret_cast(acc_base); - #pragma unroll - for (int i = 0; i < 4; ++i) { - int4 v = src4[i]; - __bf16 vbf[8]; - *reinterpret_cast(vbf) = v; - #pragma unroll - for (int j = 0; j < 8; ++j) { - float a_s = (float)vbf[j]; - acc_merged[i * 8 + j] = acc_merged[i * 8 + j] * alpha - + a_s * beta; - } - } - } - - if (head_global >= num_heads) return; - - float m_final = m_merged; - float l_final = l_merged; - float alpha_final = 1.f; - if (HAS_ATTN_SINK) { - float sink_val = attn_sink[head_global]; - m_final = fmaxf(m_merged, sink_val); - alpha_final = __builtin_amdgcn_exp2f( - (m_merged - m_final) * 1.4426950408889634f); - l_final = l_merged * alpha_final + __builtin_amdgcn_exp2f( - (sink_val - m_final) * 1.4426950408889634f); - } - float denom = fmaxf(l_final, 1.0e-30f); - bool live = (l_final > 0.f); - float inv_denom = live ? (alpha_final / denom) : 0.f; - - __bf16* out_row = output + query * out_stride0 - + head_global * out_stride1 + my_col0; - __bf16 out_buf[32]; - #pragma unroll - for (int i = 0; i < 32; ++i) out_buf[i] = (__bf16)(acc_merged[i] * inv_denom); - int4* dst4 = reinterpret_cast(out_row); - const int4* sb4 = reinterpret_cast(out_buf); - #pragma unroll - for (int i = 0; i < 4; ++i) dst4[i] = sb4[i]; -} - - -void sparse_mla_decode_single( - torch::Tensor q, - torch::Tensor main_cache, - torch::Tensor main_indices, - torch::Tensor main_indptr, - torch::Tensor extra_cache, - torch::Tensor extra_indices, - torch::Tensor extra_indptr, - c10::optional attn_sink, - torch::Tensor output, - int64_t main_block_size, - int64_t extra_block_size, - int64_t main_num_rows, - int64_t extra_num_rows, - double scale_d, - bool has_extra) -{ - const int num_queries = q.size(0); - const int num_heads = q.size(1); - const int num_head_blocks = (num_heads + BLOCK_H - 1) / BLOCK_H; - const float scale_f = (float)scale_d; - const bool has_sink = attn_sink.has_value(); - - dim3 grid(num_queries, num_head_blocks); - dim3 block(256); - - const __bf16* q_ptr = reinterpret_cast(q.data_ptr()); - const uint8_t* mc_ptr = reinterpret_cast(main_cache.data_ptr()); - const uint8_t* ec_ptr = reinterpret_cast(extra_cache.data_ptr()); - const int32_t* mi_ptr = main_indices.data_ptr(); - const int32_t* mip_ptr = main_indptr.data_ptr(); - const int32_t* ei_ptr = extra_indices.data_ptr(); - const int32_t* eip_ptr = extra_indptr.data_ptr(); - __bf16* out_ptr = reinterpret_cast<__bf16*>(output.data_ptr()); - const float* sink_ptr = has_sink - ? attn_sink.value().data_ptr() : nullptr; - - auto stream = at::cuda::getCurrentCUDAStream(); - - #define LAUNCH(HAS_S, HAS_E) do { \ - sparse_mla_decode_kernel<<>>( \ - q_ptr, mc_ptr, mi_ptr, mip_ptr, \ - ec_ptr, ei_ptr, eip_ptr, sink_ptr, out_ptr, \ - q.stride(0), q.stride(1), \ - output.stride(0), output.stride(1), \ - main_cache.stride(0), extra_cache.stride(0), \ - main_num_rows, extra_num_rows, \ - main_block_size, extra_block_size, \ - scale_f, num_heads); \ - } while (0) - - if (has_sink && has_extra) LAUNCH(true, true); - else if (has_sink) LAUNCH(true, false); - else if (has_extra) LAUNCH(false, true); - else LAUNCH(false, false); - - #undef LAUNCH -} - - -void sparse_mla_decode_split( - torch::Tensor q, - torch::Tensor main_cache, - torch::Tensor main_indices, - torch::Tensor main_indptr, - torch::Tensor extra_cache, - torch::Tensor extra_indices, - torch::Tensor extra_indptr, - c10::optional attn_sink, - torch::Tensor output, - torch::Tensor scratch_m, - torch::Tensor scratch_l, - torch::Tensor scratch_acc, - int64_t main_block_size, - int64_t extra_block_size, - int64_t main_num_rows, - int64_t extra_num_rows, - double scale_d, - bool has_extra, - int64_t split_k) -{ - const int num_queries = q.size(0); - const int num_heads = q.size(1); - const int num_head_blocks = (num_heads + BLOCK_H - 1) / BLOCK_H; - const float scale_f = (float)scale_d; - const bool has_sink = attn_sink.has_value(); - - dim3 grid_p(num_queries, num_head_blocks * (int)split_k); - dim3 grid_r(num_queries, num_head_blocks); - dim3 block_p(256); - dim3 block_r(256); - - const __bf16* q_ptr = reinterpret_cast(q.data_ptr()); - const uint8_t* mc_ptr = reinterpret_cast(main_cache.data_ptr()); - const uint8_t* ec_ptr = reinterpret_cast(extra_cache.data_ptr()); - const int32_t* mi_ptr = main_indices.data_ptr(); - const int32_t* mip_ptr = main_indptr.data_ptr(); - const int32_t* ei_ptr = extra_indices.data_ptr(); - const int32_t* eip_ptr = extra_indptr.data_ptr(); - __bf16* out_ptr = reinterpret_cast<__bf16*>(output.data_ptr()); - float* sm_ptr = scratch_m.data_ptr(); - float* sl_ptr = scratch_l.data_ptr(); - __bf16* sa_ptr = reinterpret_cast<__bf16*>(scratch_acc.data_ptr()); - const float* sink_ptr = has_sink - ? attn_sink.value().data_ptr() : nullptr; - - auto stream = at::cuda::getCurrentCUDAStream(); - - #define LAUNCH_P(HAS_E, SK) do { \ - sparse_mla_decode_partial_kernel<<>>( \ - q_ptr, mc_ptr, mi_ptr, mip_ptr, \ - ec_ptr, ei_ptr, eip_ptr, \ - sm_ptr, sl_ptr, sa_ptr, \ - q.stride(0), q.stride(1), \ - main_cache.stride(0), extra_cache.stride(0), \ - main_num_rows, extra_num_rows, \ - main_block_size, extra_block_size, \ - scale_f, num_heads, num_head_blocks); \ - } while (0) - - #define LAUNCH_R(HAS_S, SK) do { \ - sparse_mla_decode_reduce_kernel<<>>( \ - sm_ptr, sl_ptr, sa_ptr, sink_ptr, out_ptr, \ - output.stride(0), output.stride(1), \ - num_heads, num_head_blocks); \ - } while (0) - - #define DISPATCH_SK(SK) do { \ - if (has_extra) LAUNCH_P(true, SK); \ - else LAUNCH_P(false, SK); \ - if (has_sink) LAUNCH_R(true, SK); \ - else LAUNCH_R(false, SK); \ - } while (0) - - switch ((int)split_k) { - case 2: DISPATCH_SK(2); break; - case 4: DISPATCH_SK(4); break; - case 8: DISPATCH_SK(8); break; - case 16: DISPATCH_SK(16); break; - default: TORCH_CHECK(false, "Unsupported SPLIT_K"); - } - #undef DISPATCH_SK - #undef LAUNCH_P - #undef LAUNCH_R -} - - -TORCH_LIBRARY_FRAGMENT(vllm_sparse_mla_hip, m) { - m.def("decode_single(Tensor q, Tensor main_cache, Tensor main_indices, " - "Tensor main_indptr, Tensor extra_cache, Tensor extra_indices, " - "Tensor extra_indptr, Tensor? attn_sink, Tensor output, " - "int main_block_size, int extra_block_size, int main_num_rows, " - "int extra_num_rows, float scale, bool has_extra) -> ()"); - m.def("decode_split(Tensor q, Tensor main_cache, Tensor main_indices, " - "Tensor main_indptr, Tensor extra_cache, Tensor extra_indices, " - "Tensor extra_indptr, Tensor? attn_sink, Tensor output, " - "Tensor scratch_m, Tensor scratch_l, Tensor scratch_acc, " - "int main_block_size, int extra_block_size, int main_num_rows, " - "int extra_num_rows, float scale, bool has_extra, int split_k) -> ()"); -} -TORCH_LIBRARY_IMPL(vllm_sparse_mla_hip, CUDA, m) { - m.impl("decode_single", &sparse_mla_decode_single); - m.impl("decode_split", &sparse_mla_decode_split); -} -""" +_SPARSE_MLA_DECODE_CU = ( + pathlib.Path(__file__).resolve().parents[4] + / "csrc" + / "rocm" + / "sparse_mla_decode.cu" +) + logger = logging.getLogger(__name__) @@ -2409,7 +1684,7 @@ def _build_sparse_mla_hip_ext(): ext = load_inline( name="vllm_sparse_mla_hip", cpp_sources=[""], - cuda_sources=[_HIP_SPARSE_MLA_DECODE_SRC], + cuda_sources=[_SPARSE_MLA_DECODE_CU.read_text()], functions=[], extra_cflags=["-O3", "-DNDEBUG", "-std=c++17"], extra_cuda_cflags=[ From 6c91f495bf37306aff6a39c525553f110747daf6 Mon Sep 17 00:00:00 2001 From: Hemanth Acharya Date: Mon, 25 May 2026 06:06:56 -0500 Subject: [PATCH 4/9] added unit test for sparse mla Signed-off-by: Hemanth Acharya --- .../test_rocm_sparse_mla_decode_gfx950.py | 413 ++++++++++++++++++ 1 file changed, 413 insertions(+) create mode 100644 tests/kernels/attention/test_rocm_sparse_mla_decode_gfx950.py diff --git a/tests/kernels/attention/test_rocm_sparse_mla_decode_gfx950.py b/tests/kernels/attention/test_rocm_sparse_mla_decode_gfx950.py new file mode 100644 index 000000000000..769119736e37 --- /dev/null +++ b/tests/kernels/attention/test_rocm_sparse_mla_decode_gfx950.py @@ -0,0 +1,413 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Unit tests for the HIP MFMA sparse-MLA decode kernels (gfx950). + +Tests cover: + - single-WG decode (split_k == 1): main-only, main+extra, with/without attn_sink + - split-K decode: forced via SPARSE_MLA_HIP_SPLIT_K env override + - various batch / head / sequence-length combinations +""" + +import os + +import pytest +import torch + +from vllm.platforms import current_platform + + +def _is_gfx950() -> bool: + if not current_platform.is_rocm(): + return False + try: + from vllm.platforms.rocm import _ON_GFX950 + + return _ON_GFX950 + except ImportError: + return False + + +pytestmark = pytest.mark.skipif( + not _is_gfx950(), reason="Requires ROCm gfx950 hardware" +) + +NOPE_HEAD_DIM = 448 +ROPE_HEAD_DIM = 64 +HEAD_DIM = NOPE_HEAD_DIM + ROPE_HEAD_DIM + + +# --------------------------------------------------------------------------- +# Helpers (shared with test_rocm_triton_attn_dsv4.py) +# --------------------------------------------------------------------------- + + +def _pack_fp8_ds_mla_cache(kv: torch.Tensor, block_size: int) -> torch.Tensor: + """Pack bf16 KV rows into the fp8_ds_mla uint8 cache layout.""" + assert kv.shape[-1] == HEAD_DIM + num_tokens = kv.shape[0] + num_blocks = (num_tokens + block_size - 1) // block_size + cache = torch.zeros( + (num_blocks, block_size, 584), + dtype=torch.uint8, + device=kv.device, + ) + cache_flat = cache.view(torch.uint8).flatten() + kv_nope_fp8 = ( + kv[:, :NOPE_HEAD_DIM].to(current_platform.fp8_dtype()).view(torch.uint8) + ) + kv_rope_u8 = kv[:, NOPE_HEAD_DIM:].contiguous().view(torch.uint8) + + for slot in range(num_tokens): + block_idx = slot // block_size + pos = slot % block_size + block_base = block_idx * cache.stride(0) + token_base = block_base + pos * 576 + scale_base = block_base + block_size * 576 + pos * 8 + cache_flat[token_base : token_base + NOPE_HEAD_DIM].copy_(kv_nope_fp8[slot]) + cache_flat[ + token_base + NOPE_HEAD_DIM : token_base + NOPE_HEAD_DIM + ROPE_HEAD_DIM * 2 + ].copy_(kv_rope_u8[slot]) + cache_flat[scale_base : scale_base + 7].fill_(127) + return cache + + +def _read_fp8_ds_mla_cache( + cache: torch.Tensor, slot: int, block_size: int +) -> torch.Tensor: + cache_flat = cache.view(torch.uint8).flatten() + block_idx = slot // block_size + pos = slot % block_size + block_base = block_idx * cache.stride(0) + token_base = block_base + pos * 576 + + nope_u8 = cache_flat[token_base : token_base + NOPE_HEAD_DIM] + nope = nope_u8.view(current_platform.fp8_dtype()).to(torch.float32) + rope_u8 = cache_flat[ + token_base + NOPE_HEAD_DIM : token_base + NOPE_HEAD_DIM + ROPE_HEAD_DIM * 2 + ] + rope = rope_u8.view(torch.bfloat16).to(torch.float32) + return torch.cat([nope, rope]) + + +def _ref_sparse_decode_ragged( + q: torch.Tensor, + main_cache: torch.Tensor, + main_rows: list[list[int]], + scale: float, + attn_sink: torch.Tensor | None, + block_size: int, + extra_cache: torch.Tensor | None = None, + extra_rows: list[list[int]] | None = None, +) -> torch.Tensor: + """Pure-Python reference for ragged sparse decode attention.""" + q_f32 = q.float() + out = torch.empty_like(q_f32) + + for query_idx in range(q.shape[0]): + row_kv = [ + _read_fp8_ds_mla_cache(main_cache, int(slot), block_size) + for slot in main_rows[query_idx] + ] + if extra_cache is not None and extra_rows is not None: + row_kv.extend( + _read_fp8_ds_mla_cache(extra_cache, int(slot), block_size) + for slot in extra_rows[query_idx] + ) + + kv = torch.stack(row_kv).to(q.device) + for head_idx in range(q.shape[1]): + scores = torch.mv(kv, q_f32[query_idx, head_idx]) * scale + if attn_sink is not None: + scores_with_sink = torch.cat( + [scores, attn_sink[head_idx].float().reshape(1)] + ) + probs = torch.softmax(scores_with_sink, dim=0)[:-1] + else: + probs = torch.softmax(scores, dim=0) + out[query_idx, head_idx] = torch.sum(probs[:, None] * kv, dim=0) + return out.to(torch.bfloat16) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def _call_hip_decode( + q, + main_cache, + main_indices, + main_indptr, + scale, + attn_sink, + extra_cache=None, + extra_indices=None, + extra_indptr=None, + max_main_len=None, + max_extra_len=None, +): + from vllm.v1.attention.ops.rocm_aiter_mla_sparse import ( + _decode_sparse_mla_hip, + ) + + if max_main_len is None: + max_main_len = int(main_indices.numel()) + if max_extra_len is None: + max_extra_len = int(extra_indices.numel()) if extra_indices is not None else 0 + + return _decode_sparse_mla_hip( + q=q, + main_cache=main_cache, + main_indices=main_indices, + main_indptr=main_indptr, + scale=scale, + attn_sink=attn_sink, + nope_head_dim=NOPE_HEAD_DIM, + rope_head_dim=ROPE_HEAD_DIM, + extra_cache=extra_cache, + extra_indices=extra_indices, + extra_indptr=extra_indptr, + max_main_len=max_main_len, + max_extra_len=max_extra_len, + ) + + +@torch.inference_mode() +def test_hip_decode_main_only_no_sink() -> None: + """Single-WG decode with main cache only, no attn_sink.""" + device = torch.device("cuda") + torch.manual_seed(42) + block_size = 4 + num_queries, num_heads = 2, 3 + q = ( + torch.randn( + num_queries, num_heads, HEAD_DIM, dtype=torch.bfloat16, device=device + ) + * 0.125 + ) + main_kv = torch.randn(6, HEAD_DIM, dtype=torch.bfloat16, device=device) * 0.125 + main_cache = _pack_fp8_ds_mla_cache(main_kv, block_size) + main_indices = torch.tensor([0, 2, 4, 1], dtype=torch.int32, device=device) + main_indptr = torch.tensor([0, 2, 4], dtype=torch.int32, device=device) + scale = HEAD_DIM**-0.5 + + actual = _call_hip_decode( + q, + main_cache, + main_indices, + main_indptr, + scale, + attn_sink=None, + ) + expected = _ref_sparse_decode_ragged( + q, + main_cache, + [[0, 2], [4, 1]], + scale, + attn_sink=None, + block_size=block_size, + ) + torch.testing.assert_close(actual, expected, atol=2e-2, rtol=2e-2) + + +@torch.inference_mode() +def test_hip_decode_main_only_with_sink() -> None: + """Single-WG decode with main cache only, with attn_sink.""" + device = torch.device("cuda") + torch.manual_seed(42) + block_size = 4 + num_queries, num_heads = 2, 3 + q = ( + torch.randn( + num_queries, num_heads, HEAD_DIM, dtype=torch.bfloat16, device=device + ) + * 0.125 + ) + main_kv = torch.randn(6, HEAD_DIM, dtype=torch.bfloat16, device=device) * 0.125 + main_cache = _pack_fp8_ds_mla_cache(main_kv, block_size) + main_indices = torch.tensor([0, 2, 4, 1], dtype=torch.int32, device=device) + main_indptr = torch.tensor([0, 2, 4], dtype=torch.int32, device=device) + attn_sink = torch.tensor([-0.1, 0.0, 0.1], dtype=torch.float32, device=device) + scale = HEAD_DIM**-0.5 + + actual = _call_hip_decode( + q, + main_cache, + main_indices, + main_indptr, + scale, + attn_sink=attn_sink, + ) + expected = _ref_sparse_decode_ragged( + q, + main_cache, + [[0, 2], [4, 1]], + scale, + attn_sink=attn_sink, + block_size=block_size, + ) + torch.testing.assert_close(actual, expected, atol=2e-2, rtol=2e-2) + + +@torch.inference_mode() +def test_hip_decode_main_extra_with_sink() -> None: + """Single-WG decode with main + extra cache and attn_sink.""" + device = torch.device("cuda") + torch.manual_seed(1) + block_size = 4 + num_queries, num_heads = 2, 3 + q = ( + torch.randn( + num_queries, num_heads, HEAD_DIM, dtype=torch.bfloat16, device=device + ) + * 0.125 + ) + main_kv = torch.randn(6, HEAD_DIM, dtype=torch.bfloat16, device=device) * 0.125 + extra_kv = torch.randn(5, HEAD_DIM, dtype=torch.bfloat16, device=device) * 0.125 + main_cache = _pack_fp8_ds_mla_cache(main_kv, block_size) + extra_cache = _pack_fp8_ds_mla_cache(extra_kv, block_size) + main_indices = torch.tensor([0, 2, 4, 1], dtype=torch.int32, device=device) + main_indptr = torch.tensor([0, 2, 4], dtype=torch.int32, device=device) + extra_indices = torch.tensor([1, 3, 0], dtype=torch.int32, device=device) + extra_indptr = torch.tensor([0, 1, 3], dtype=torch.int32, device=device) + attn_sink = torch.tensor([-0.1, 0.0, 0.1], dtype=torch.float32, device=device) + scale = HEAD_DIM**-0.5 + + actual = _call_hip_decode( + q, + main_cache, + main_indices, + main_indptr, + scale, + attn_sink=attn_sink, + extra_cache=extra_cache, + extra_indices=extra_indices, + extra_indptr=extra_indptr, + ) + expected = _ref_sparse_decode_ragged( + q, + main_cache, + [[0, 2], [4, 1]], + scale, + attn_sink=attn_sink, + block_size=block_size, + extra_cache=extra_cache, + extra_rows=[[1], [3, 0]], + ) + torch.testing.assert_close(actual, expected, atol=2e-2, rtol=2e-2) + + +@torch.inference_mode() +def test_hip_decode_split_k() -> None: + """Force split-K path (split_k=2) and verify correctness.""" + device = torch.device("cuda") + torch.manual_seed(7) + block_size = 4 + num_queries, num_heads = 1, 16 + num_tokens = 128 + q = ( + torch.randn( + num_queries, num_heads, HEAD_DIM, dtype=torch.bfloat16, device=device + ) + * 0.125 + ) + main_kv = ( + torch.randn(num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device) * 0.125 + ) + main_cache = _pack_fp8_ds_mla_cache(main_kv, block_size) + main_indices = torch.arange(num_tokens, dtype=torch.int32, device=device) + main_indptr = torch.tensor([0, num_tokens], dtype=torch.int32, device=device) + attn_sink = torch.randn(num_heads, dtype=torch.float32, device=device) * 0.1 + scale = HEAD_DIM**-0.5 + + old_val = os.environ.get("SPARSE_MLA_HIP_SPLIT_K") + try: + os.environ["SPARSE_MLA_HIP_SPLIT_K"] = "2" + from vllm.v1.attention.ops import rocm_aiter_mla_sparse as mod + + orig = mod._SPLIT_K_OVERRIDE + mod._SPLIT_K_OVERRIDE = "2" + + actual = _call_hip_decode( + q, + main_cache, + main_indices, + main_indptr, + scale, + attn_sink=attn_sink, + max_main_len=num_tokens, + ) + finally: + mod._SPLIT_K_OVERRIDE = orig + if old_val is None: + os.environ.pop("SPARSE_MLA_HIP_SPLIT_K", None) + else: + os.environ["SPARSE_MLA_HIP_SPLIT_K"] = old_val + + main_rows = [list(range(num_tokens))] + expected = _ref_sparse_decode_ragged( + q, + main_cache, + main_rows, + scale, + attn_sink=attn_sink, + block_size=block_size, + ) + torch.testing.assert_close(actual, expected, atol=2e-2, rtol=2e-2) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_queries", [1, 4]) +@pytest.mark.parametrize("num_heads", [3, 16, 32]) +@pytest.mark.parametrize("seq_len", [32, 64, 128]) +def test_hip_decode_shapes(num_queries, num_heads, seq_len) -> None: + """Parametrized test over different batch/head/seqlen combos.""" + device = torch.device("cuda") + torch.manual_seed(num_queries * 1000 + num_heads * 10 + seq_len) + block_size = 4 + q = ( + torch.randn( + num_queries, num_heads, HEAD_DIM, dtype=torch.bfloat16, device=device + ) + * 0.125 + ) + main_kv = ( + torch.randn(seq_len, HEAD_DIM, dtype=torch.bfloat16, device=device) * 0.125 + ) + main_cache = _pack_fp8_ds_mla_cache(main_kv, block_size) + scale = HEAD_DIM**-0.5 + + tokens_per_query = seq_len // num_queries + main_rows = [] + indices_list = [] + indptr = [0] + for qi in range(num_queries): + start = qi * tokens_per_query + end = start + tokens_per_query + row_slots = list(range(start, end)) + main_rows.append(row_slots) + indices_list.extend(row_slots) + indptr.append(indptr[-1] + len(row_slots)) + + main_indices = torch.tensor(indices_list, dtype=torch.int32, device=device) + main_indptr = torch.tensor(indptr, dtype=torch.int32, device=device) + + actual = _call_hip_decode( + q, + main_cache, + main_indices, + main_indptr, + scale, + attn_sink=None, + max_main_len=tokens_per_query, + ) + expected = _ref_sparse_decode_ragged( + q, + main_cache, + main_rows, + scale, + attn_sink=None, + block_size=block_size, + ) + torch.testing.assert_close(actual, expected, atol=2e-2, rtol=2e-2) From 2afbaba48341eb0cb8faa69c0d1ad29790a27f68 Mon Sep 17 00:00:00 2001 From: Hemanth Acharya Date: Mon, 25 May 2026 06:35:46 -0500 Subject: [PATCH 5/9] remove _path function Signed-off-by: Hemanth Acharya --- .../v1/attention/ops/rocm_aiter_mla_sparse.py | 94 ++++++------------- 1 file changed, 27 insertions(+), 67 deletions(-) diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index b28dd7c45d78..865cf215f2b7 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -1879,78 +1879,38 @@ def rocm_sparse_attn_decode( output=output, ) else: - _rocm_sparse_attn_decode_triton_path( + main_indices = swa_indices.reshape(swa_indices.shape[0], -1) + + extra_cache = None + extra_indices = None + if not swa_only: + assert kv_cache is not None + assert topk_indices is not None or ( + topk_ragged_indices is not None and topk_ragged_indptr is not None + ) + assert kv_cache.dtype == torch.uint8 + extra_cache = kv_cache + if topk_indices is not None: + extra_indices = topk_indices.reshape(topk_indices.shape[0], -1) + + attn_out = _rocm_sparse_attn_decode_triton( q=q, - kv_cache=kv_cache, - swa_k_cache=swa_k_cache, - swa_only=swa_only, - topk_indices=topk_indices, - topk_lens=topk_lens, - swa_indices=swa_indices, - swa_lens=swa_lens, - swa_ragged_indices=swa_ragged_indices, - swa_ragged_indptr=swa_ragged_indptr, - topk_ragged_indices=topk_ragged_indices, - topk_ragged_indptr=topk_ragged_indptr, - attn_sink=attn_sink, + main_cache=swa_k_cache, + main_indices=main_indices, scale=scale, + attn_sink=None if attn_sink is None else attn_sink[: q.shape[1]], nope_head_dim=nope_head_dim, rope_head_dim=rope_head_dim, - output=output, - ) - - -def _rocm_sparse_attn_decode_triton_path( - q, - kv_cache, - swa_k_cache, - swa_only, - topk_indices, - topk_lens, - swa_indices, - swa_lens, - swa_ragged_indices, - swa_ragged_indptr, - topk_ragged_indices, - topk_ragged_indptr, - attn_sink, - scale, - nope_head_dim, - rope_head_dim, - output, -): - main_indices = swa_indices.reshape(swa_indices.shape[0], -1) - - extra_cache = None - extra_indices = None - if not swa_only: - assert kv_cache is not None - assert topk_indices is not None or ( - topk_ragged_indices is not None and topk_ragged_indptr is not None + extra_cache=extra_cache, + extra_indices=extra_indices, + main_lengths=swa_lens, + extra_lengths=topk_lens, + main_ragged_indices=swa_ragged_indices, + main_ragged_indptr=swa_ragged_indptr, + extra_ragged_indices=topk_ragged_indices, + extra_ragged_indptr=topk_ragged_indptr, ) - assert kv_cache.dtype == torch.uint8 - extra_cache = kv_cache - if topk_indices is not None: - extra_indices = topk_indices.reshape(topk_indices.shape[0], -1) - - attn_out = _rocm_sparse_attn_decode_triton( - q=q, - main_cache=swa_k_cache, - main_indices=main_indices, - scale=scale, - attn_sink=None if attn_sink is None else attn_sink[: q.shape[1]], - nope_head_dim=nope_head_dim, - rope_head_dim=rope_head_dim, - extra_cache=extra_cache, - extra_indices=extra_indices, - main_lengths=swa_lens, - extra_lengths=topk_lens, - main_ragged_indices=swa_ragged_indices, - main_ragged_indptr=swa_ragged_indptr, - extra_ragged_indices=topk_ragged_indices, - extra_ragged_indptr=topk_ragged_indptr, - ) - output.copy_(attn_out.to(output.dtype)) + output.copy_(attn_out.to(output.dtype)) def _rocm_sparse_attn_decode_hip( From 8863212bf33b58c8895e380061cc6396a5e4ad9b Mon Sep 17 00:00:00 2001 From: Hemanth Acharya Date: Thu, 28 May 2026 11:14:59 -0500 Subject: [PATCH 6/9] move sparse mla hip kernel compilation to cmake Signed-off-by: Hemanth Acharya --- CMakeLists.txt | 26 ++++++++++ csrc/rocm/sparse_mla_decode.cu | 9 ++-- setup.py | 2 + vllm/platforms/rocm.py | 9 +++- .../v1/attention/ops/rocm_aiter_mla_sparse.py | 48 +------------------ 5 files changed, 43 insertions(+), 51 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f0b1f53af831..9c16f20ee48c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1277,6 +1277,32 @@ if(VLLM_GPU_LANG STREQUAL "HIP") ARCHITECTURES ${VLLM_GPU_ARCHES} USE_SABI 3 WITH_SOABI) + + # + # _rocm_sparse_mla_C extension (gfx950 only) + # + list(FIND VLLM_GPU_ARCHES "gfx950" _gfx950_idx) + if(NOT _gfx950_idx EQUAL -1) + set(VLLM_ROCM_SPARSE_MLA_SRC + "csrc/rocm/sparse_mla_decode.cu") + + set(VLLM_ROCM_SPARSE_MLA_FLAGS ${VLLM_GPU_FLAGS} + "-Wno-c++11-narrowing") + + define_extension_target( + _rocm_sparse_mla_C + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${VLLM_ROCM_SPARSE_MLA_SRC} + COMPILE_FLAGS ${VLLM_ROCM_SPARSE_MLA_FLAGS} + ARCHITECTURES "gfx950" + USE_SABI 3 + WITH_SOABI) + + message(STATUS "Building ROCm sparse MLA decode kernel for gfx950.") + else() + message(STATUS "Not building ROCm sparse MLA decode kernel (gfx950 not in target architectures).") + endif() endif() # Must run after the last HIP `define_extension_target` so every extension diff --git a/csrc/rocm/sparse_mla_decode.cu b/csrc/rocm/sparse_mla_decode.cu index 0a4c89732fe0..a0eeb6bfc288 100644 --- a/csrc/rocm/sparse_mla_decode.cu +++ b/csrc/rocm/sparse_mla_decode.cu @@ -5,21 +5,20 @@ #include #include -#include +#include #include +#include "core/registration.h" + using bf16x8 = __attribute__((__vector_size__(8 * sizeof(__bf16)))) __bf16; using fx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; static constexpr int NOPE_DIM = 448; -static constexpr int ROPE_DIM = 64; static constexpr int TOKEN_BYTES = 576; static constexpr int SCALE_BYTES = 8; static constexpr int HEAD_DIM = 512; static constexpr int BLOCK_H = 16; static constexpr int BLOCK_K = 32; -static constexpr int N_TILES = HEAD_DIM / 16; // 32 -static constexpr int QK_N_TILES = BLOCK_K / 16; // 2 __device__ __forceinline__ fx4 mfma_16x16x32_bf16(bf16x8 a, bf16x8 b, fx4 c) { return __builtin_amdgcn_mfma_f32_16x16x32_bf16(a, b, c, 0, 0, 0); @@ -689,3 +688,5 @@ TORCH_LIBRARY_IMPL(vllm_sparse_mla_hip, CUDA, m) { m.impl("decode_single", &sparse_mla_decode_single); m.impl("decode_split", &sparse_mla_decode_split); } + +REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/setup.py b/setup.py index a95ee3451b58..06d7045c7077 100644 --- a/setup.py +++ b/setup.py @@ -729,6 +729,7 @@ def extract_precompiled_and_patch_package( "vllm/spinloop.abi3.so", # ROCm-specific libraries "vllm/_rocm_C.abi3.so", + "vllm/_rocm_sparse_mla_C.abi3.so", } ) if extract_rust_frontend: @@ -1047,6 +1048,7 @@ def _read_requirements(filename: str) -> list[str]: if _is_hip(): ext_modules.append(CMakeExtension(name="vllm._rocm_C")) + ext_modules.append(CMakeExtension(name="vllm._rocm_sparse_mla_C", optional=True)) if _is_cuda(): ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C")) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 21e6ca18ea25..642067b5317a 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -54,6 +54,11 @@ except ImportError as e: logger.warning("Failed to import from vllm._rocm_C with %r", e) +try: + import vllm._rocm_sparse_mla_C # noqa: F401 +except ImportError as e: + logger.warning("Failed to import from vllm._rocm_sparse_mla_C with %r", e) + # Models not supported by ROCm. _ROCM_UNSUPPORTED_MODELS: list[str] = [] @@ -441,9 +446,11 @@ def import_kernels(cls) -> None: import contextlib - # Import ROCm-specific extension + # Import ROCm-specific extensions with contextlib.suppress(ImportError): import vllm._rocm_C # noqa: F401 + with contextlib.suppress(ImportError): + import vllm._rocm_sparse_mla_C # noqa: F401 @classmethod def get_valid_backends( diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index e469d1694106..5c2da35a07b3 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -5,13 +5,10 @@ import logging import math import os -import pathlib -import tempfile from importlib.util import find_spec import torch import torch.nn.functional as F -from torch.utils.cpp_extension import load_inline from vllm.compilation.breakable_cudagraph import eager_break_during_capture from vllm.forward_context import get_forward_context @@ -1688,51 +1685,12 @@ def rocm_sparse_attn_prefill( # ============================================================================ # HIP MFMA kernel implementation for sparse-MLA decode. +# Compiled at build time via CMakeLists.txt (_rocm_sparse_mla_C extension). +# Ops are registered under torch.ops.vllm_sparse_mla_hip namespace. # ============================================================================ -_SPARSE_MLA_DECODE_CU = ( - pathlib.Path(__file__).resolve().parents[4] - / "csrc" - / "rocm" - / "sparse_mla_decode.cu" -) - - logger = logging.getLogger(__name__) -_sparse_mla_hip_module_cache: dict = {} - - -def _build_sparse_mla_hip_ext(): - if "ext" in _sparse_mla_hip_module_cache: - return _sparse_mla_hip_module_cache["ext"] - cache_dir = os.environ.get( - "VLLM_SPARSE_MLA_HIP_CACHE_DIR", - str(pathlib.Path(tempfile.gettempdir()) / "vllm_sparse_mla_hip_cache"), - ) - os.makedirs(cache_dir, exist_ok=True) - os.environ["PYTORCH_ROCM_ARCH"] = "gfx950" - ext = load_inline( - name="vllm_sparse_mla_hip", - cpp_sources=[""], - cuda_sources=[_SPARSE_MLA_DECODE_CU.read_text()], - functions=[], - extra_cflags=["-O3", "-DNDEBUG", "-std=c++17"], - extra_cuda_cflags=[ - "-O3", - "-std=c++17", - "--offload-arch=gfx950", - "-DNDEBUG", - "-Wno-c++11-narrowing", - "-Wno-unused-result", - ], - with_cuda=True, - build_directory=cache_dir, - verbose=False, - ) - _sparse_mla_hip_module_cache["ext"] = ext - return ext - def _sparse_mla_hip_as_int32_1d(x): if x.dtype != torch.int32: @@ -1802,8 +1760,6 @@ def _decode_sparse_mla_hip( total_max_k = max_main_len + (max_extra_len if has_extra else 0) split_k = _pick_split_k(num_queries, num_head_blocks, total_max_k) - _build_sparse_mla_hip_ext() - if split_k == 1: torch.ops.vllm_sparse_mla_hip.decode_single( q_in, From 6ca69d637462fd9c49f7f370ab5221ecf27c4698 Mon Sep 17 00:00:00 2001 From: Hemanth Acharya Date: Fri, 5 Jun 2026 01:36:20 -0500 Subject: [PATCH 7/9] parameterise the sparse mla hip unit test for dsv4 params Signed-off-by: Hemanth Acharya --- .../test_rocm_sparse_mla_decode_gfx950.py | 369 ++++++++++-------- 1 file changed, 216 insertions(+), 153 deletions(-) diff --git a/tests/kernels/attention/test_rocm_sparse_mla_decode_gfx950.py b/tests/kernels/attention/test_rocm_sparse_mla_decode_gfx950.py index 769119736e37..67aaef9c5bbd 100644 --- a/tests/kernels/attention/test_rocm_sparse_mla_decode_gfx950.py +++ b/tests/kernels/attention/test_rocm_sparse_mla_decode_gfx950.py @@ -3,13 +3,17 @@ """ Unit tests for the HIP MFMA sparse-MLA decode kernels (gfx950). -Tests cover: - - single-WG decode (split_k == 1): main-only, main+extra, with/without attn_sink - - split-K decode: forced via SPARSE_MLA_HIP_SPLIT_K env override - - various batch / head / sequence-length combinations +All tests use the DeepSeek-V4-Pro production attention dims (see HF config +https://huggingface.co/deepseek-ai/DeepSeek-V4-Pro/blob/main/config.json) and +are parametrized over the runtime axes the kernel must support: + * batch (``num_queries``): single decode and small MTP-style batches. + * sequence length per query: the SWA window (``sliding_window=128``) and + the topk budget (``index_topk=1024``). + * split-K: both the single-WG path and the split-reduce path are covered + by forcing the internal split-K override. """ -import os +import contextlib import pytest import torch @@ -32,10 +36,25 @@ def _is_gfx950() -> bool: not _is_gfx950(), reason="Requires ROCm gfx950 hardware" ) +# Per-token KV head dims for DeepSeek-V4-Pro MLA. The HIP kernel is +# hard-coded for these values (576-byte token payload + 8-byte scales): +# nope_head_dim = head_dim(512) - qk_rope_head_dim(64) = 448 +# rope_head_dim = qk_rope_head_dim = 64 NOPE_HEAD_DIM = 448 ROPE_HEAD_DIM = 64 HEAD_DIM = NOPE_HEAD_DIM + ROPE_HEAD_DIM +# DeepSeek-V4-Pro production attention params (HF config.json): +# * num_attention_heads = 128 -> 16 heads per rank at TP=8. +# * ROCM_AITER_MLA_SPARSE backend supports kernel_block_size in {1, 64}; +# paged deployments use 64. +# * sliding_window = 128 -> max main/SWA tokens per query. +# * index_topk = 1024 -> max extra/topk tokens per query. +DSV4_NUM_HEADS = 16 +DSV4_BLOCK_SIZE = 64 +DSV4_SWA_WINDOW = 128 +DSV4_TOPK = 1024 + # --------------------------------------------------------------------------- # Helpers (shared with test_rocm_triton_attn_dsv4.py) @@ -173,162 +192,220 @@ def _call_hip_decode( ) -@torch.inference_mode() -def test_hip_decode_main_only_no_sink() -> None: - """Single-WG decode with main cache only, no attn_sink.""" - device = torch.device("cuda") - torch.manual_seed(42) - block_size = 4 - num_queries, num_heads = 2, 3 - q = ( +@contextlib.contextmanager +def _force_split_k(value: int | None): + """Temporarily override the internal split-K picker. + + The kernel module reads ``SPARSE_MLA_HIP_SPLIT_K`` once at import time + into ``mod._SPLIT_K_OVERRIDE``; tests mutate that module attribute + directly so they can exercise both the single-WG (split_k=1) and the + split-reduce paths regardless of the auto-tuned heuristic. + """ + from vllm.v1.attention.ops import rocm_aiter_mla_sparse as mod + + orig = mod._SPLIT_K_OVERRIDE + if value is not None: + mod._SPLIT_K_OVERRIDE = str(int(value)) + try: + yield + finally: + mod._SPLIT_K_OVERRIDE = orig + + +def _make_q(num_queries: int, num_heads: int, device: torch.device) -> torch.Tensor: + return ( torch.randn( num_queries, num_heads, HEAD_DIM, dtype=torch.bfloat16, device=device ) * 0.125 ) - main_kv = torch.randn(6, HEAD_DIM, dtype=torch.bfloat16, device=device) * 0.125 - main_cache = _pack_fp8_ds_mla_cache(main_kv, block_size) - main_indices = torch.tensor([0, 2, 4, 1], dtype=torch.int32, device=device) - main_indptr = torch.tensor([0, 2, 4], dtype=torch.int32, device=device) - scale = HEAD_DIM**-0.5 - actual = _call_hip_decode( - q, - main_cache, - main_indices, - main_indptr, - scale, - attn_sink=None, + +def _build_contiguous_ragged( + num_queries: int, + tokens_per_query: int, + device: torch.device, + start: int = 0, +) -> tuple[torch.Tensor, torch.Tensor, list[list[int]]]: + """Build (indices, indptr, rows) where each query owns + ``tokens_per_query`` contiguous slots starting at ``start``.""" + rows = [ + list(range(start + qi * tokens_per_query, start + (qi + 1) * tokens_per_query)) + for qi in range(num_queries) + ] + indices = torch.tensor( + [s for r in rows for s in r], dtype=torch.int32, device=device ) - expected = _ref_sparse_decode_ragged( - q, - main_cache, - [[0, 2], [4, 1]], - scale, - attn_sink=None, - block_size=block_size, + indptr = torch.tensor( + [qi * tokens_per_query for qi in range(num_queries + 1)], + dtype=torch.int32, + device=device, ) - torch.testing.assert_close(actual, expected, atol=2e-2, rtol=2e-2) + return indices, indptr, rows + + +# Parametrize against the realistic DSv4 deployment axes: +# * ``num_queries`` covers single-token decode and small MTP batches. +# * ``split_k`` covers both the single-WG path and the split-reduce path. +# * ``with_sink`` toggles the attention-sink branch. +# All tests pin ``num_heads = DSV4_NUM_HEADS`` (=16, DSv4 at TP=8) and +# ``block_size = DSV4_BLOCK_SIZE`` (=64, the paged-MLA cache block size). @torch.inference_mode() -def test_hip_decode_main_only_with_sink() -> None: - """Single-WG decode with main cache only, with attn_sink.""" +@pytest.mark.parametrize("split_k", [1, 4]) +@pytest.mark.parametrize("num_queries", [1, 4]) +@pytest.mark.parametrize("with_sink", [False, True]) +def test_hip_decode_main_only(num_queries, with_sink, split_k) -> None: + """SWA-only decode (main cache) at DSv4 dims, with/without sink.""" device = torch.device("cuda") - torch.manual_seed(42) - block_size = 4 - num_queries, num_heads = 2, 3 - q = ( + torch.manual_seed(42 + num_queries * 10 + int(with_sink) * 3 + split_k) + tokens_per_query = DSV4_SWA_WINDOW + + q = _make_q(num_queries, DSV4_NUM_HEADS, device) + main_kv = ( torch.randn( - num_queries, num_heads, HEAD_DIM, dtype=torch.bfloat16, device=device + num_queries * tokens_per_query, + HEAD_DIM, + dtype=torch.bfloat16, + device=device, ) * 0.125 ) - main_kv = torch.randn(6, HEAD_DIM, dtype=torch.bfloat16, device=device) * 0.125 - main_cache = _pack_fp8_ds_mla_cache(main_kv, block_size) - main_indices = torch.tensor([0, 2, 4, 1], dtype=torch.int32, device=device) - main_indptr = torch.tensor([0, 2, 4], dtype=torch.int32, device=device) - attn_sink = torch.tensor([-0.1, 0.0, 0.1], dtype=torch.float32, device=device) + main_cache = _pack_fp8_ds_mla_cache(main_kv, DSV4_BLOCK_SIZE) + main_indices, main_indptr, main_rows = _build_contiguous_ragged( + num_queries, tokens_per_query, device + ) + attn_sink = ( + torch.randn(DSV4_NUM_HEADS, dtype=torch.float32, device=device) * 0.1 + if with_sink + else None + ) scale = HEAD_DIM**-0.5 - actual = _call_hip_decode( - q, - main_cache, - main_indices, - main_indptr, - scale, - attn_sink=attn_sink, - ) + with _force_split_k(split_k): + actual = _call_hip_decode( + q, + main_cache, + main_indices, + main_indptr, + scale, + attn_sink=attn_sink, + max_main_len=tokens_per_query, + ) expected = _ref_sparse_decode_ragged( q, main_cache, - [[0, 2], [4, 1]], + main_rows, scale, attn_sink=attn_sink, - block_size=block_size, + block_size=DSV4_BLOCK_SIZE, ) torch.testing.assert_close(actual, expected, atol=2e-2, rtol=2e-2) @torch.inference_mode() -def test_hip_decode_main_extra_with_sink() -> None: - """Single-WG decode with main + extra cache and attn_sink.""" +@pytest.mark.parametrize("split_k", [1, 4]) +@pytest.mark.parametrize("num_queries", [1, 4]) +@pytest.mark.parametrize("with_sink", [False, True]) +def test_hip_decode_main_extra(num_queries, with_sink, split_k) -> None: + """SWA + topk decode (main + extra caches) at DSv4 dims.""" device = torch.device("cuda") - torch.manual_seed(1) - block_size = 4 - num_queries, num_heads = 2, 3 - q = ( + torch.manual_seed(7 + num_queries * 11 + int(with_sink) * 5 + split_k) + # ``main`` carries SWA tokens; ``extra`` carries topk tokens. Use a + # modest extra length (rather than the full DSV4_TOPK=1024) to keep + # the Python reference fast while still exercising both code paths. + main_per_query = DSV4_SWA_WINDOW + extra_per_query = DSV4_BLOCK_SIZE * 2 # 128 topk tokens + + q = _make_q(num_queries, DSV4_NUM_HEADS, device) + main_kv = ( torch.randn( - num_queries, num_heads, HEAD_DIM, dtype=torch.bfloat16, device=device + num_queries * main_per_query, + HEAD_DIM, + dtype=torch.bfloat16, + device=device, + ) + * 0.125 + ) + extra_kv = ( + torch.randn( + num_queries * extra_per_query, + HEAD_DIM, + dtype=torch.bfloat16, + device=device, ) * 0.125 ) - main_kv = torch.randn(6, HEAD_DIM, dtype=torch.bfloat16, device=device) * 0.125 - extra_kv = torch.randn(5, HEAD_DIM, dtype=torch.bfloat16, device=device) * 0.125 - main_cache = _pack_fp8_ds_mla_cache(main_kv, block_size) - extra_cache = _pack_fp8_ds_mla_cache(extra_kv, block_size) - main_indices = torch.tensor([0, 2, 4, 1], dtype=torch.int32, device=device) - main_indptr = torch.tensor([0, 2, 4], dtype=torch.int32, device=device) - extra_indices = torch.tensor([1, 3, 0], dtype=torch.int32, device=device) - extra_indptr = torch.tensor([0, 1, 3], dtype=torch.int32, device=device) - attn_sink = torch.tensor([-0.1, 0.0, 0.1], dtype=torch.float32, device=device) + main_cache = _pack_fp8_ds_mla_cache(main_kv, DSV4_BLOCK_SIZE) + extra_cache = _pack_fp8_ds_mla_cache(extra_kv, DSV4_BLOCK_SIZE) + main_indices, main_indptr, main_rows = _build_contiguous_ragged( + num_queries, main_per_query, device + ) + extra_indices, extra_indptr, extra_rows = _build_contiguous_ragged( + num_queries, extra_per_query, device + ) + attn_sink = ( + torch.randn(DSV4_NUM_HEADS, dtype=torch.float32, device=device) * 0.1 + if with_sink + else None + ) scale = HEAD_DIM**-0.5 - actual = _call_hip_decode( - q, - main_cache, - main_indices, - main_indptr, - scale, - attn_sink=attn_sink, - extra_cache=extra_cache, - extra_indices=extra_indices, - extra_indptr=extra_indptr, - ) + with _force_split_k(split_k): + actual = _call_hip_decode( + q, + main_cache, + main_indices, + main_indptr, + scale, + attn_sink=attn_sink, + extra_cache=extra_cache, + extra_indices=extra_indices, + extra_indptr=extra_indptr, + max_main_len=main_per_query, + max_extra_len=extra_per_query, + ) expected = _ref_sparse_decode_ragged( q, main_cache, - [[0, 2], [4, 1]], + main_rows, scale, attn_sink=attn_sink, - block_size=block_size, + block_size=DSV4_BLOCK_SIZE, extra_cache=extra_cache, - extra_rows=[[1], [3, 0]], + extra_rows=extra_rows, ) torch.testing.assert_close(actual, expected, atol=2e-2, rtol=2e-2) @torch.inference_mode() -def test_hip_decode_split_k() -> None: - """Force split-K path (split_k=2) and verify correctness.""" +@pytest.mark.parametrize("split_k", [2, 4, 8]) +@pytest.mark.parametrize("num_queries", [1, 4]) +def test_hip_decode_split_k(num_queries, split_k) -> None: + """Force the split-K reduce path with the DSv4 topk-sized seqlen.""" device = torch.device("cuda") - torch.manual_seed(7) - block_size = 4 - num_queries, num_heads = 1, 16 - num_tokens = 128 - q = ( + torch.manual_seed(2026 + num_queries * 13 + split_k) + tokens_per_query = DSV4_TOPK # full DSv4 topk budget + + q = _make_q(num_queries, DSV4_NUM_HEADS, device) + main_kv = ( torch.randn( - num_queries, num_heads, HEAD_DIM, dtype=torch.bfloat16, device=device + num_queries * tokens_per_query, + HEAD_DIM, + dtype=torch.bfloat16, + device=device, ) * 0.125 ) - main_kv = ( - torch.randn(num_tokens, HEAD_DIM, dtype=torch.bfloat16, device=device) * 0.125 + main_cache = _pack_fp8_ds_mla_cache(main_kv, DSV4_BLOCK_SIZE) + main_indices, main_indptr, main_rows = _build_contiguous_ragged( + num_queries, tokens_per_query, device ) - main_cache = _pack_fp8_ds_mla_cache(main_kv, block_size) - main_indices = torch.arange(num_tokens, dtype=torch.int32, device=device) - main_indptr = torch.tensor([0, num_tokens], dtype=torch.int32, device=device) - attn_sink = torch.randn(num_heads, dtype=torch.float32, device=device) * 0.1 + attn_sink = torch.randn(DSV4_NUM_HEADS, dtype=torch.float32, device=device) * 0.1 scale = HEAD_DIM**-0.5 - old_val = os.environ.get("SPARSE_MLA_HIP_SPLIT_K") - try: - os.environ["SPARSE_MLA_HIP_SPLIT_K"] = "2" - from vllm.v1.attention.ops import rocm_aiter_mla_sparse as mod - - orig = mod._SPLIT_K_OVERRIDE - mod._SPLIT_K_OVERRIDE = "2" - + with _force_split_k(split_k): actual = _call_hip_decode( q, main_cache, @@ -336,78 +413,64 @@ def test_hip_decode_split_k() -> None: main_indptr, scale, attn_sink=attn_sink, - max_main_len=num_tokens, + max_main_len=tokens_per_query, ) - finally: - mod._SPLIT_K_OVERRIDE = orig - if old_val is None: - os.environ.pop("SPARSE_MLA_HIP_SPLIT_K", None) - else: - os.environ["SPARSE_MLA_HIP_SPLIT_K"] = old_val - - main_rows = [list(range(num_tokens))] expected = _ref_sparse_decode_ragged( q, main_cache, main_rows, scale, attn_sink=attn_sink, - block_size=block_size, + block_size=DSV4_BLOCK_SIZE, ) torch.testing.assert_close(actual, expected, atol=2e-2, rtol=2e-2) @torch.inference_mode() +@pytest.mark.parametrize("split_k", [1, 4]) @pytest.mark.parametrize("num_queries", [1, 4]) -@pytest.mark.parametrize("num_heads", [3, 16, 32]) -@pytest.mark.parametrize("seq_len", [32, 64, 128]) -def test_hip_decode_shapes(num_queries, num_heads, seq_len) -> None: - """Parametrized test over different batch/head/seqlen combos.""" +@pytest.mark.parametrize( + "tokens_per_query", + [DSV4_BLOCK_SIZE, DSV4_SWA_WINDOW, DSV4_TOPK], + ids=["one_block", "swa_window", "topk"], +) +def test_hip_decode_seqlen(num_queries, tokens_per_query, split_k) -> None: + """Sweep DSv4 sequence lengths: one block, SWA window, full topk.""" device = torch.device("cuda") - torch.manual_seed(num_queries * 1000 + num_heads * 10 + seq_len) - block_size = 4 - q = ( + torch.manual_seed(num_queries * 1009 + tokens_per_query + split_k) + + q = _make_q(num_queries, DSV4_NUM_HEADS, device) + main_kv = ( torch.randn( - num_queries, num_heads, HEAD_DIM, dtype=torch.bfloat16, device=device + num_queries * tokens_per_query, + HEAD_DIM, + dtype=torch.bfloat16, + device=device, ) * 0.125 ) - main_kv = ( - torch.randn(seq_len, HEAD_DIM, dtype=torch.bfloat16, device=device) * 0.125 + main_cache = _pack_fp8_ds_mla_cache(main_kv, DSV4_BLOCK_SIZE) + main_indices, main_indptr, main_rows = _build_contiguous_ragged( + num_queries, tokens_per_query, device ) - main_cache = _pack_fp8_ds_mla_cache(main_kv, block_size) scale = HEAD_DIM**-0.5 - tokens_per_query = seq_len // num_queries - main_rows = [] - indices_list = [] - indptr = [0] - for qi in range(num_queries): - start = qi * tokens_per_query - end = start + tokens_per_query - row_slots = list(range(start, end)) - main_rows.append(row_slots) - indices_list.extend(row_slots) - indptr.append(indptr[-1] + len(row_slots)) - - main_indices = torch.tensor(indices_list, dtype=torch.int32, device=device) - main_indptr = torch.tensor(indptr, dtype=torch.int32, device=device) - - actual = _call_hip_decode( - q, - main_cache, - main_indices, - main_indptr, - scale, - attn_sink=None, - max_main_len=tokens_per_query, - ) + with _force_split_k(split_k): + actual = _call_hip_decode( + q, + main_cache, + main_indices, + main_indptr, + scale, + attn_sink=None, + max_main_len=tokens_per_query, + ) expected = _ref_sparse_decode_ragged( q, main_cache, main_rows, scale, attn_sink=None, - block_size=block_size, + block_size=DSV4_BLOCK_SIZE, ) torch.testing.assert_close(actual, expected, atol=2e-2, rtol=2e-2) From 3aa14d2c3b4aeaf9554a01d7c8472c4d9e6c1ad5 Mon Sep 17 00:00:00 2001 From: Hemanth Acharya Date: Fri, 5 Jun 2026 03:04:00 -0500 Subject: [PATCH 8/9] remove sparse mla extension Signed-off-by: Hemanth Acharya --- CMakeLists.txt | 33 +---- csrc/rocm/sparse_mla_decode.cu | 126 ++++++++++++------ setup.py | 2 - vllm/platforms/rocm.py | 7 - .../v1/attention/ops/rocm_aiter_mla_sparse.py | 5 +- 5 files changed, 92 insertions(+), 81 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 3000eda8d2dd..e9e25a3b1275 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1306,6 +1306,13 @@ if(VLLM_GPU_LANG STREQUAL "HIP") "csrc/rocm/q_gemm_rdna3_wmma.cu") endif() + if(VLLM_GPU_ARCHES MATCHES "gfx950") + list(APPEND VLLM_ROCM_EXT_SRC + "csrc/rocm/sparse_mla_decode.cu") + set_source_files_properties("csrc/rocm/sparse_mla_decode.cu" + PROPERTIES COMPILE_OPTIONS "-Wno-c++11-narrowing") + endif() + define_extension_target( _rocm_C DESTINATION vllm @@ -1316,32 +1323,6 @@ if(VLLM_GPU_LANG STREQUAL "HIP") USE_SABI 3 WITH_SOABI) - # - # _rocm_sparse_mla_C extension (gfx950 only) - # - list(FIND VLLM_GPU_ARCHES "gfx950" _gfx950_idx) - if(NOT _gfx950_idx EQUAL -1) - set(VLLM_ROCM_SPARSE_MLA_SRC - "csrc/rocm/sparse_mla_decode.cu") - - set(VLLM_ROCM_SPARSE_MLA_FLAGS ${VLLM_GPU_FLAGS} - "-Wno-c++11-narrowing") - - define_extension_target( - _rocm_sparse_mla_C - DESTINATION vllm - LANGUAGE ${VLLM_GPU_LANG} - SOURCES ${VLLM_ROCM_SPARSE_MLA_SRC} - COMPILE_FLAGS ${VLLM_ROCM_SPARSE_MLA_FLAGS} - ARCHITECTURES "gfx950" - USE_SABI 3 - WITH_SOABI) - - message(STATUS "Building ROCm sparse MLA decode kernel for gfx950.") - else() - message(STATUS "Not building ROCm sparse MLA decode kernel (gfx950 not in target architectures).") - endif() - if(VLLM_ROCM_HAS_GFX1100) target_compile_definitions(_rocm_C PRIVATE VLLM_ROCM_GFX1100) endif() diff --git a/csrc/rocm/sparse_mla_decode.cu b/csrc/rocm/sparse_mla_decode.cu index a0eeb6bfc288..264dfb788dac 100644 --- a/csrc/rocm/sparse_mla_decode.cu +++ b/csrc/rocm/sparse_mla_decode.cu @@ -8,17 +8,32 @@ #include #include -#include "core/registration.h" - using bf16x8 = __attribute__((__vector_size__(8 * sizeof(__bf16)))) __bf16; using fx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; -static constexpr int NOPE_DIM = 448; -static constexpr int TOKEN_BYTES = 576; -static constexpr int SCALE_BYTES = 8; -static constexpr int HEAD_DIM = 512; +// Most of these are referenced only from the gfx950 device kernels below. +// On non-gfx950 device passes the `#else` empty-stub branch hides those +// references, so we annotate them `[[maybe_unused]]` to silence +// `-Werror=unused-const-variable` for those passes. `BLOCK_H` is also used +// by the host launchers further down, so it is genuinely live on every pass. +[[maybe_unused]] static constexpr int NOPE_DIM = 448; +[[maybe_unused]] static constexpr int TOKEN_BYTES = 576; +[[maybe_unused]] static constexpr int SCALE_BYTES = 8; +[[maybe_unused]] static constexpr int HEAD_DIM = 512; static constexpr int BLOCK_H = 16; -static constexpr int BLOCK_K = 32; +[[maybe_unused]] static constexpr int BLOCK_K = 32; + +// gfx950 device-only path: the MFMA bf16 and fp8 conversion builtins below +// require gfx950 MFMA support. Following the q_gemm_rdna3.cu precedent, we +// gate the full device code on `__HIP__GFX950__ || !__HIP_DEVICE_COMPILE__` +// so the kernel bodies remain visible to the host compilation pass (for +// linkage of the host launchers) while non-gfx950 device passes get empty +// `__global__` stubs (see the `#else` branch at the end of this block). +#if defined(__HIPCC__) && defined(__gfx950__) + #define __HIP__GFX950__ +#endif + +#if defined(__HIP__GFX950__) || !defined(__HIP_DEVICE_COMPILE__) __device__ __forceinline__ fx4 mfma_16x16x32_bf16(bf16x8 a, bf16x8 b, fx4 c) { return __builtin_amdgcn_mfma_f32_16x16x32_bf16(a, b, c, 0, 0, 0); @@ -48,7 +63,7 @@ __device__ __forceinline__ void gather_and_dequant_k_tile( int4 z; z.x = z.y = z.z = z.w = 0; int4* d4 = reinterpret_cast(dst_row); -#pragma unroll + #pragma unroll for (int j = 0; j < 8; ++j) d4[j] = z; } else if (col0 < NOPE_DIM) { const uint8_t* scale_ptr = @@ -62,10 +77,10 @@ __device__ __forceinline__ void gather_and_dequant_k_tile( float scl_f = sb.fv; const uint32_t* src32 = reinterpret_cast(token_ptr + col0); -#pragma unroll + #pragma unroll for (int u32_i = 0; u32_i < 16; ++u32_i) { uint32_t word = src32[u32_i]; -#pragma unroll + #pragma unroll for (int b = 0; b < 4; ++b) { uint8_t kb = (word >> (b * 8)) & 0xFF; uint32_t packed = (uint32_t)kb; @@ -76,7 +91,7 @@ __device__ __forceinline__ void gather_and_dequant_k_tile( } else { const int4* src4 = reinterpret_cast(token_ptr + NOPE_DIM); int4* d4 = reinterpret_cast(dst_row); -#pragma unroll + #pragma unroll for (int j = 0; j < 8; ++j) d4[j] = src4[j]; } @@ -95,26 +110,26 @@ __device__ __forceinline__ void process_k_tile( int wave) { if (wave == 0) { fx4 qk[2] = {{0.f, 0.f, 0.f, 0.f}, {0.f, 0.f, 0.f, 0.f}}; -#pragma unroll + #pragma unroll for (int c = 0; c < HEAD_DIM / 32; ++c) { bf16x8 q_reg; const __bf16* q_src = &q_lds[m_a * HEAD_DIM + c * 32 + kg * 8]; -#pragma unroll + #pragma unroll for (int i = 0; i < 8; ++i) q_reg[i] = q_src[i]; -#pragma unroll + #pragma unroll for (int nt = 0; nt < 2; ++nt) { bf16x8 k_reg; const __bf16* k_src = &k_lds[(nt * 16 + n_b) * HEAD_DIM + c * 32 + kg * 8]; -#pragma unroll + #pragma unroll for (int i = 0; i < 8; ++i) k_reg[i] = k_src[i]; qk[nt] = mfma_16x16x32_bf16(q_reg, k_reg, qk[nt]); } } -#pragma unroll + #pragma unroll for (int nt = 0; nt < 2; ++nt) { -#pragma unroll + #pragma unroll for (int i = 0; i < 4; ++i) { int k_col = nt * 16 + n_d; float s = qk[nt][i] * scale; @@ -127,14 +142,14 @@ __device__ __forceinline__ void process_k_tile( __syncthreads(); fx4 qk_local[2]; -#pragma unroll + #pragma unroll for (int i = 0; i < 4; ++i) { qk_local[0][i] = scores_lds[(m_d_base + i) * BLOCK_K + n_d]; qk_local[1][i] = scores_lds[(m_d_base + i) * BLOCK_K + 16 + n_d]; } fx4 p[2]; -#pragma unroll + #pragma unroll for (int i = 0; i < 4; ++i) { float row_max = fmaxf(qk_local[0][i], qk_local[1][i]); row_max = fmaxf(row_max, __shfl_xor(row_max, 1)); @@ -161,7 +176,7 @@ __device__ __forceinline__ void process_k_tile( p[0][i] = e0; p[1][i] = e1; -#pragma unroll + #pragma unroll for (int nt = 0; nt < N_TILES_PER_WAVE; ++nt) acc[nt][i] *= alpha; m_state[i] = m_new; @@ -169,7 +184,7 @@ __device__ __forceinline__ void process_k_tile( } if (wave == 0) { -#pragma unroll + #pragma unroll for (int i = 0; i < 4; ++i) { p_lds[(m_d_base + i) * BLOCK_K + n_d] = (__bf16)p[0][i]; p_lds[(m_d_base + i) * BLOCK_K + 16 + n_d] = (__bf16)p[1][i]; @@ -180,14 +195,14 @@ __device__ __forceinline__ void process_k_tile( bf16x8 p_reg; const __bf16* p_src = &p_lds[m_a * BLOCK_K + kg * 8]; -#pragma unroll + #pragma unroll for (int i = 0; i < 8; ++i) p_reg[i] = p_src[i]; -#pragma unroll + #pragma unroll for (int nt_local = 0; nt_local < N_TILES_PER_WAVE; ++nt_local) { int n_tile = wave * N_TILES_PER_WAVE + nt_local; bf16x8 k_reg; -#pragma unroll + #pragma unroll for (int i = 0; i < 8; ++i) { k_reg[i] = k_lds[(kg * 8 + i) * HEAD_DIM + n_tile * 16 + n_b]; } @@ -206,13 +221,13 @@ __device__ __forceinline__ void load_q(const __bf16* q, int64_t q_stride0, const __bf16* src = q + query * q_stride0 + head_global * q_stride1 + qc0; const int4* s4 = reinterpret_cast(src); int4* d4 = reinterpret_cast(dst); -#pragma unroll + #pragma unroll for (int i = 0; i < 4; ++i) d4[i] = s4[i]; } else { int4 z; z.x = z.y = z.z = z.w = 0; int4* d4 = reinterpret_cast(dst); -#pragma unroll + #pragma unroll for (int i = 0; i < 4; ++i) d4[i] = z; } } @@ -254,12 +269,12 @@ __global__ __launch_bounds__(256, 2) void sparse_mla_decode_kernel( float m_state[4], l_state[4]; fx4 acc[N_TILES_PER_WAVE]; -#pragma unroll + #pragma unroll for (int i = 0; i < 4; ++i) { m_state[i] = -3.4028234663852886e38f; l_state[i] = 0.f; } -#pragma unroll + #pragma unroll for (int i = 0; i < N_TILES_PER_WAVE; ++i) { acc[i] = (fx4){0.f, 0.f, 0.f, 0.f}; } @@ -297,7 +312,7 @@ __global__ __launch_bounds__(256, 2) void sparse_mla_decode_kernel( } { -#pragma unroll + #pragma unroll for (int i = 0; i < 4; ++i) { int head_local = m_d_base + i; int head_global = pid_h * BLOCK_H + head_local; @@ -320,7 +335,7 @@ __global__ __launch_bounds__(256, 2) void sparse_mla_decode_kernel( __bf16* out_row = output + query * out_stride0 + head_global * out_stride1; -#pragma unroll + #pragma unroll for (int nt_local = 0; nt_local < N_TILES_PER_WAVE; ++nt_local) { int n_tile = wave * N_TILES_PER_WAVE + nt_local; int col = n_tile * 16 + n_d; @@ -367,12 +382,12 @@ __global__ __launch_bounds__(256, 2) void sparse_mla_decode_partial_kernel( float m_state[4], l_state[4]; fx4 acc[N_TILES_PER_WAVE]; -#pragma unroll + #pragma unroll for (int i = 0; i < 4; ++i) { m_state[i] = -3.4028234663852886e38f; l_state[i] = 0.f; } -#pragma unroll + #pragma unroll for (int i = 0; i < N_TILES_PER_WAVE; ++i) { acc[i] = (fx4){0.f, 0.f, 0.f, 0.f}; } @@ -414,7 +429,7 @@ __global__ __launch_bounds__(256, 2) void sparse_mla_decode_partial_kernel( const int triple = (query * num_head_blocks + pid_h) * SPLIT_K + pid_split; if (wave == 0 && n_d == 0) { -#pragma unroll + #pragma unroll for (int i = 0; i < 4; ++i) { int idx = triple * BLOCK_H + m_d_base + i; scratch_m[idx] = m_state[i]; @@ -423,10 +438,10 @@ __global__ __launch_bounds__(256, 2) void sparse_mla_decode_partial_kernel( } __syncthreads(); -#pragma unroll + #pragma unroll for (int i = 0; i < 4; ++i) { int row = m_d_base + i; -#pragma unroll + #pragma unroll for (int nt_local = 0; nt_local < N_TILES_PER_WAVE; ++nt_local) { int n_tile = wave * N_TILES_PER_WAVE + nt_local; int col = n_tile * 16 + n_d; @@ -443,7 +458,7 @@ __global__ __launch_bounds__(256, 2) void sparse_mla_decode_partial_kernel( const int4* src4 = reinterpret_cast(&k_lds[my_row * HEAD_DIM + my_col0]); int4* dst4 = reinterpret_cast(dst); -#pragma unroll + #pragma unroll for (int i = 0; i < 4; ++i) dst4[i] = src4[i]; } } @@ -465,10 +480,10 @@ __global__ __launch_bounds__(256, 4) void sparse_mla_decode_reduce_kernel( float m_merged = -3.4028234663852886e38f; float l_merged = 0.f; float acc_merged[32]; -#pragma unroll + #pragma unroll for (int i = 0; i < 32; ++i) acc_merged[i] = 0.f; -#pragma unroll + #pragma unroll for (int s = 0; s < SPLIT_K; ++s) { const int triple = (query * num_head_blocks + pid_h) * SPLIT_K + s; float m_s = scratch_m[triple * BLOCK_H + my_row]; @@ -485,12 +500,12 @@ __global__ __launch_bounds__(256, 4) void sparse_mla_decode_reduce_kernel( (int64_t)triple * BLOCK_H * HEAD_DIM + my_row * HEAD_DIM + my_col0; const int4* src4 = reinterpret_cast(acc_base); -#pragma unroll + #pragma unroll for (int i = 0; i < 4; ++i) { int4 v = src4[i]; __bf16 vbf[8]; *reinterpret_cast(vbf) = v; -#pragma unroll + #pragma unroll for (int j = 0; j < 8; ++j) { float a_s = (float)vbf[j]; acc_merged[i * 8 + j] = acc_merged[i * 8 + j] * alpha + a_s * beta; @@ -519,14 +534,39 @@ __global__ __launch_bounds__(256, 4) void sparse_mla_decode_reduce_kernel( __bf16* out_row = output + query * out_stride0 + head_global * out_stride1 + my_col0; __bf16 out_buf[32]; -#pragma unroll + #pragma unroll for (int i = 0; i < 32; ++i) out_buf[i] = (__bf16)(acc_merged[i] * inv_denom); int4* dst4 = reinterpret_cast(out_row); const int4* sb4 = reinterpret_cast(out_buf); -#pragma unroll + #pragma unroll for (int i = 0; i < 4; ++i) dst4[i] = sb4[i]; } +#else // non-gfx950 device pass: empty __global__ stubs for symbol parity. + +template +__global__ void sparse_mla_decode_kernel(const __bf16*, const uint8_t*, + const int32_t*, const int32_t*, + const uint8_t*, const int32_t*, + const int32_t*, const float*, __bf16*, + int64_t, int64_t, int64_t, int64_t, + int64_t, int64_t, int, int, int, int, + float, int) {} + +template +__global__ void sparse_mla_decode_partial_kernel( + const __bf16*, const uint8_t*, const int32_t*, const int32_t*, + const uint8_t*, const int32_t*, const int32_t*, float*, float*, __bf16*, + int64_t, int64_t, int64_t, int64_t, int, int, int, int, float, int, int) {} + +template +__global__ void sparse_mla_decode_reduce_kernel(const float*, const float*, + const __bf16*, const float*, + __bf16*, int64_t, int64_t, int, + int) {} + +#endif // __HIP__GFX950__ || !__HIP_DEVICE_COMPILE__ + void sparse_mla_decode_single( torch::Tensor q, torch::Tensor main_cache, torch::Tensor main_indices, torch::Tensor main_indptr, torch::Tensor extra_cache, @@ -688,5 +728,3 @@ TORCH_LIBRARY_IMPL(vllm_sparse_mla_hip, CUDA, m) { m.impl("decode_single", &sparse_mla_decode_single); m.impl("decode_split", &sparse_mla_decode_split); } - -REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/setup.py b/setup.py index 0c9202422280..07374807bee6 100644 --- a/setup.py +++ b/setup.py @@ -729,7 +729,6 @@ def extract_precompiled_and_patch_package( "vllm/spinloop.abi3.so", # ROCm-specific libraries "vllm/_rocm_C.abi3.so", - "vllm/_rocm_sparse_mla_C.abi3.so", } ) if extract_rust_frontend: @@ -1048,7 +1047,6 @@ def _read_requirements(filename: str) -> list[str]: if _is_hip(): ext_modules.append(CMakeExtension(name="vllm._rocm_C")) - ext_modules.append(CMakeExtension(name="vllm._rocm_sparse_mla_C", optional=True)) if _is_cuda(): ext_modules.append(CMakeExtension(name="vllm.vllm_flash_attn._vllm_fa2_C")) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 932baf81b8c4..a763dc09c117 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -54,11 +54,6 @@ except ImportError as e: logger.warning("Failed to import from vllm._rocm_C with %r", e) -try: - import vllm._rocm_sparse_mla_C # noqa: F401 -except ImportError as e: - logger.warning("Failed to import from vllm._rocm_sparse_mla_C with %r", e) - # Models not supported by ROCm. _ROCM_UNSUPPORTED_MODELS: list[str] = [] @@ -467,8 +462,6 @@ def import_kernels(cls) -> None: # Import ROCm-specific extensions with contextlib.suppress(ImportError): import vllm._rocm_C # noqa: F401 - with contextlib.suppress(ImportError): - import vllm._rocm_sparse_mla_C # noqa: F401 @classmethod def get_valid_backends( diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 5633d52ccd62..b9713d400175 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -1688,8 +1688,9 @@ def rocm_sparse_attn_prefill( # ============================================================================ # HIP MFMA kernel implementation for sparse-MLA decode. -# Compiled at build time via CMakeLists.txt (_rocm_sparse_mla_C extension). -# Ops are registered under torch.ops.vllm_sparse_mla_hip namespace. +# Compiled at build time via CMakeLists.txt into the _rocm_C extension when +# gfx950 is in VLLM_GPU_ARCHES (see csrc/rocm/sparse_mla_decode.cu). +# Ops are registered under the torch.ops.vllm_sparse_mla_hip namespace. # ============================================================================ logger = logging.getLogger(__name__) From 74334cedf615632e3fd06e09067034f4ace2533d Mon Sep 17 00:00:00 2001 From: Hemanth Acharya Date: Mon, 8 Jun 2026 03:45:51 -0500 Subject: [PATCH 9/9] remove sparse mla extension Signed-off-by: Hemanth Acharya --- CMakeLists.txt | 5 +++++ csrc/rocm/ops.h | 17 +++++++++++++++++ csrc/rocm/sparse_mla_decode.cu | 24 ++---------------------- csrc/rocm/torch_bindings.cpp | 22 ++++++++++++++++++++++ 4 files changed, 46 insertions(+), 22 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4259d5cb9b4a..5f809c98ca38 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1313,7 +1313,9 @@ if(VLLM_GPU_LANG STREQUAL "HIP") "csrc/rocm/q_gemm_rdna3_wmma.cu") endif() + set(VLLM_ROCM_HAS_GFX950 OFF) if(VLLM_GPU_ARCHES MATCHES "gfx950") + set(VLLM_ROCM_HAS_GFX950 ON) list(APPEND VLLM_ROCM_EXT_SRC "csrc/rocm/sparse_mla_decode.cu") set_source_files_properties("csrc/rocm/sparse_mla_decode.cu" @@ -1333,6 +1335,9 @@ if(VLLM_GPU_LANG STREQUAL "HIP") if(VLLM_ROCM_HAS_GFX1100) target_compile_definitions(_rocm_C PRIVATE VLLM_ROCM_GFX1100) endif() + if(VLLM_ROCM_HAS_GFX950) + target_compile_definitions(_rocm_C PRIVATE VLLM_ROCM_GFX950) + endif() endif() # Must run after the last HIP `define_extension_target` so every extension diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index 73197d8a5e20..9214a2ddc4a2 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -37,3 +37,20 @@ void paged_attention( const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const std::optional& fp8_out_scale, const std::string& mfma_type); + +void sparse_mla_decode_single( + torch::Tensor q, torch::Tensor main_cache, torch::Tensor main_indices, + torch::Tensor main_indptr, torch::Tensor extra_cache, + torch::Tensor extra_indices, torch::Tensor extra_indptr, + const std::optional& attn_sink, torch::Tensor output, + int64_t main_block_size, int64_t extra_block_size, int64_t main_num_rows, + int64_t extra_num_rows, double scale, bool has_extra); + +void sparse_mla_decode_split( + torch::Tensor q, torch::Tensor main_cache, torch::Tensor main_indices, + torch::Tensor main_indptr, torch::Tensor extra_cache, + torch::Tensor extra_indices, torch::Tensor extra_indptr, + const std::optional& attn_sink, torch::Tensor output, + torch::Tensor scratch_m, torch::Tensor scratch_l, torch::Tensor scratch_acc, + int64_t main_block_size, int64_t extra_block_size, int64_t main_num_rows, + int64_t extra_num_rows, double scale, bool has_extra, int64_t split_k); diff --git a/csrc/rocm/sparse_mla_decode.cu b/csrc/rocm/sparse_mla_decode.cu index 264dfb788dac..c5f8c18c868a 100644 --- a/csrc/rocm/sparse_mla_decode.cu +++ b/csrc/rocm/sparse_mla_decode.cu @@ -571,7 +571,7 @@ void sparse_mla_decode_single( torch::Tensor q, torch::Tensor main_cache, torch::Tensor main_indices, torch::Tensor main_indptr, torch::Tensor extra_cache, torch::Tensor extra_indices, torch::Tensor extra_indptr, - c10::optional attn_sink, torch::Tensor output, + const std::optional& attn_sink, torch::Tensor output, int64_t main_block_size, int64_t extra_block_size, int64_t main_num_rows, int64_t extra_num_rows, double scale_d, bool has_extra) { const int num_queries = q.size(0); @@ -624,7 +624,7 @@ void sparse_mla_decode_split( torch::Tensor q, torch::Tensor main_cache, torch::Tensor main_indices, torch::Tensor main_indptr, torch::Tensor extra_cache, torch::Tensor extra_indices, torch::Tensor extra_indptr, - c10::optional attn_sink, torch::Tensor output, + const std::optional& attn_sink, torch::Tensor output, torch::Tensor scratch_m, torch::Tensor scratch_l, torch::Tensor scratch_acc, int64_t main_block_size, int64_t extra_block_size, int64_t main_num_rows, int64_t extra_num_rows, double scale_d, bool has_extra, int64_t split_k) { @@ -708,23 +708,3 @@ void sparse_mla_decode_split( #undef LAUNCH_P #undef LAUNCH_R } - -TORCH_LIBRARY_FRAGMENT(vllm_sparse_mla_hip, m) { - m.def( - "decode_single(Tensor q, Tensor main_cache, Tensor main_indices, " - "Tensor main_indptr, Tensor extra_cache, Tensor extra_indices, " - "Tensor extra_indptr, Tensor? attn_sink, Tensor output, " - "int main_block_size, int extra_block_size, int main_num_rows, " - "int extra_num_rows, float scale, bool has_extra) -> ()"); - m.def( - "decode_split(Tensor q, Tensor main_cache, Tensor main_indices, " - "Tensor main_indptr, Tensor extra_cache, Tensor extra_indices, " - "Tensor extra_indptr, Tensor? attn_sink, Tensor output, " - "Tensor scratch_m, Tensor scratch_l, Tensor scratch_acc, " - "int main_block_size, int extra_block_size, int main_num_rows, " - "int extra_num_rows, float scale, bool has_extra, int split_k) -> ()"); -} -TORCH_LIBRARY_IMPL(vllm_sparse_mla_hip, CUDA, m) { - m.impl("decode_single", &sparse_mla_decode_single); - m.impl("decode_split", &sparse_mla_decode_split); -} diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index 1e589598c742..240bddf28b12 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -73,4 +73,26 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention); } +#ifdef VLLM_ROCM_GFX950 +TORCH_LIBRARY_FRAGMENT(vllm_sparse_mla_hip, m) { + m.def( + "decode_single(Tensor q, Tensor main_cache, Tensor main_indices, " + "Tensor main_indptr, Tensor extra_cache, Tensor extra_indices, " + "Tensor extra_indptr, Tensor? attn_sink, Tensor output, " + "int main_block_size, int extra_block_size, int main_num_rows, " + "int extra_num_rows, float scale, bool has_extra) -> ()"); + m.def( + "decode_split(Tensor q, Tensor main_cache, Tensor main_indices, " + "Tensor main_indptr, Tensor extra_cache, Tensor extra_indices, " + "Tensor extra_indptr, Tensor? attn_sink, Tensor output, " + "Tensor scratch_m, Tensor scratch_l, Tensor scratch_acc, " + "int main_block_size, int extra_block_size, int main_num_rows, " + "int extra_num_rows, float scale, bool has_extra, int split_k) -> ()"); +} +TORCH_LIBRARY_IMPL(vllm_sparse_mla_hip, CUDA, m) { + m.impl("decode_single", &sparse_mla_decode_single); + m.impl("decode_split", &sparse_mla_decode_split); +} +#endif + REGISTER_EXTENSION(TORCH_EXTENSION_NAME)