diff --git a/CMakeLists.txt b/CMakeLists.txt index d956e29e3990..954c57a03543 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1312,6 +1312,15 @@ if(VLLM_GPU_LANG STREQUAL "HIP") "csrc/rocm/moe_q_gemm_rdna3.cu") endif() + set(VLLM_ROCM_HAS_GFX950 OFF) + if(VLLM_GPU_ARCHES MATCHES "gfx950") + set(VLLM_ROCM_HAS_GFX950 ON) + list(APPEND VLLM_ROCM_EXT_SRC + "csrc/rocm/sparse_mla_decode.cu") + set_source_files_properties("csrc/rocm/sparse_mla_decode.cu" + PROPERTIES COMPILE_OPTIONS "-Wno-c++11-narrowing") + endif() + define_extension_target( _rocm_C DESTINATION vllm @@ -1325,6 +1334,9 @@ if(VLLM_GPU_LANG STREQUAL "HIP") if(VLLM_ROCM_HAS_GFX1100) target_compile_definitions(_rocm_C PRIVATE VLLM_ROCM_GFX1100) endif() + if(VLLM_ROCM_HAS_GFX950) + target_compile_definitions(_rocm_C PRIVATE VLLM_ROCM_GFX950) + endif() endif() # Must run after the last HIP `define_extension_target` so every extension diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index 549d50300d6f..fe9c5a7306a5 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -46,3 +46,20 @@ void paged_attention( const std::string& kv_cache_dtype, torch::Tensor& k_scale, torch::Tensor& v_scale, const std::optional& fp8_out_scale, const std::string& mfma_type); + +void sparse_mla_decode_single( + torch::Tensor q, torch::Tensor main_cache, torch::Tensor main_indices, + torch::Tensor main_indptr, torch::Tensor extra_cache, + torch::Tensor extra_indices, torch::Tensor extra_indptr, + const std::optional& attn_sink, torch::Tensor output, + int64_t main_block_size, int64_t extra_block_size, int64_t main_num_rows, + int64_t extra_num_rows, double scale, bool has_extra); + +void sparse_mla_decode_split( + torch::Tensor q, torch::Tensor main_cache, torch::Tensor main_indices, + torch::Tensor main_indptr, torch::Tensor extra_cache, + torch::Tensor extra_indices, torch::Tensor extra_indptr, + const std::optional& attn_sink, torch::Tensor output, + torch::Tensor scratch_m, torch::Tensor scratch_l, torch::Tensor scratch_acc, + 
int64_t main_block_size, int64_t extra_block_size, int64_t main_num_rows, + int64_t extra_num_rows, double scale, bool has_extra, int64_t split_k); diff --git a/csrc/rocm/sparse_mla_decode.cu b/csrc/rocm/sparse_mla_decode.cu new file mode 100644 index 000000000000..c5f8c18c868a --- /dev/null +++ b/csrc/rocm/sparse_mla_decode.cu @@ -0,0 +1,710 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +#include +#include +#include + +#include +#include + +using bf16x8 = __attribute__((__vector_size__(8 * sizeof(__bf16)))) __bf16; +using fx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; + +// Most of these are referenced only from the gfx950 device kernels below. +// On non-gfx950 device passes the `#else` empty-stub branch hides those +// references, so we annotate them `[[maybe_unused]]` to silence +// `-Werror=unused-const-variable` for those passes. `BLOCK_H` is also used +// by the host launchers further down, so it is genuinely live on every pass. +[[maybe_unused]] static constexpr int NOPE_DIM = 448; +[[maybe_unused]] static constexpr int TOKEN_BYTES = 576; +[[maybe_unused]] static constexpr int SCALE_BYTES = 8; +[[maybe_unused]] static constexpr int HEAD_DIM = 512; +static constexpr int BLOCK_H = 16; +[[maybe_unused]] static constexpr int BLOCK_K = 32; + +// gfx950 device-only path: the MFMA bf16 and fp8 conversion builtins below +// require gfx950 MFMA support. Following the q_gemm_rdna3.cu precedent, we +// gate the full device code on `__HIP__GFX950__ || !__HIP_DEVICE_COMPILE__` +// so the kernel bodies remain visible to the host compilation pass (for +// linkage of the host launchers) while non-gfx950 device passes get empty +// `__global__` stubs (see the `#else` branch at the end of this block). 
+#if defined(__HIPCC__) && defined(__gfx950__) + #define __HIP__GFX950__ +#endif + +#if defined(__HIP__GFX950__) || !defined(__HIP_DEVICE_COMPILE__) + +__device__ __forceinline__ fx4 mfma_16x16x32_bf16(bf16x8 a, bf16x8 b, fx4 c) { + return __builtin_amdgcn_mfma_f32_16x16x32_bf16(a, b, c, 0, 0, 0); +} + +__device__ __forceinline__ void gather_and_dequant_k_tile( + int k_start, int k_len, const uint8_t* cache_base, int64_t cache_stride0, + int num_rows, int block_size, const int32_t* idx_base, __bf16* k_lds, + int8_t* kv_lds, int tid) { + const int tok_id = tid >> 3; // 0..31 + const int chunk = tid & 7; // 0..7 + const int col0 = chunk * 64; + + int k_pos = k_start + tok_id; + bool in_range = (k_pos < k_len); + int slot = in_range ? idx_base[k_pos] : 0; + bool valid = in_range && (slot >= 0) && (slot < num_rows); + int safe_slot = valid ? slot : 0; + int bi = safe_slot / block_size; + int pib = safe_slot - bi * block_size; + const uint8_t* block_ptr = cache_base + (int64_t)bi * cache_stride0; + const uint8_t* token_ptr = block_ptr + pib * TOKEN_BYTES; + + __bf16* dst_row = &k_lds[tok_id * HEAD_DIM + col0]; + + if (!valid) { + int4 z; + z.x = z.y = z.z = z.w = 0; + int4* d4 = reinterpret_cast(dst_row); + #pragma unroll + for (int j = 0; j < 8; ++j) d4[j] = z; + } else if (col0 < NOPE_DIM) { + const uint8_t* scale_ptr = + block_ptr + block_size * TOKEN_BYTES + pib * SCALE_BYTES; + uint8_t scl_u = scale_ptr[chunk]; + union { + uint32_t u; + float fv; + } sb; + sb.u = ((uint32_t)scl_u) << 23; + float scl_f = sb.fv; + + const uint32_t* src32 = reinterpret_cast(token_ptr + col0); + #pragma unroll + for (int u32_i = 0; u32_i < 16; ++u32_i) { + uint32_t word = src32[u32_i]; + #pragma unroll + for (int b = 0; b < 4; ++b) { + uint8_t kb = (word >> (b * 8)) & 0xFF; + uint32_t packed = (uint32_t)kb; + float f = __builtin_amdgcn_cvt_f32_fp8(packed, 0) * scl_f; + dst_row[u32_i * 4 + b] = (__bf16)f; + } + } + } else { + const int4* src4 = reinterpret_cast(token_ptr + NOPE_DIM); + 
int4* d4 = reinterpret_cast(dst_row); + #pragma unroll + for (int j = 0; j < 8; ++j) d4[j] = src4[j]; + } + + if (tid < BLOCK_K) { + int kp = k_start + tid; + int sl = (kp < k_len) ? idx_base[kp] : -1; + kv_lds[tid] = (kp < k_len) && (sl >= 0) && (sl < num_rows) ? 1 : 0; + } +} + +constexpr int N_TILES_PER_WAVE = 8; // 32 N-tiles / 4 waves +__device__ __forceinline__ void process_k_tile( + const __bf16* q_lds, const __bf16* k_lds, const int8_t* kv_lds, + __bf16* p_lds, float* scores_lds, float* m_state, float* l_state, fx4* acc, + float scale, int lane, int m_a, int kg, int n_b, int m_d_base, int n_d, + int wave) { + if (wave == 0) { + fx4 qk[2] = {{0.f, 0.f, 0.f, 0.f}, {0.f, 0.f, 0.f, 0.f}}; + #pragma unroll + for (int c = 0; c < HEAD_DIM / 32; ++c) { + bf16x8 q_reg; + const __bf16* q_src = &q_lds[m_a * HEAD_DIM + c * 32 + kg * 8]; + #pragma unroll + for (int i = 0; i < 8; ++i) q_reg[i] = q_src[i]; + + #pragma unroll + for (int nt = 0; nt < 2; ++nt) { + bf16x8 k_reg; + const __bf16* k_src = + &k_lds[(nt * 16 + n_b) * HEAD_DIM + c * 32 + kg * 8]; + #pragma unroll + for (int i = 0; i < 8; ++i) k_reg[i] = k_src[i]; + qk[nt] = mfma_16x16x32_bf16(q_reg, k_reg, qk[nt]); + } + } + #pragma unroll + for (int nt = 0; nt < 2; ++nt) { + #pragma unroll + for (int i = 0; i < 4; ++i) { + int k_col = nt * 16 + n_d; + float s = qk[nt][i] * scale; + if (!kv_lds[k_col]) s = -3.4028234663852886e38f; + scores_lds[(m_d_base + i) * BLOCK_K + nt * 16 + n_d] = s; + } + } + } + + __syncthreads(); + + fx4 qk_local[2]; + #pragma unroll + for (int i = 0; i < 4; ++i) { + qk_local[0][i] = scores_lds[(m_d_base + i) * BLOCK_K + n_d]; + qk_local[1][i] = scores_lds[(m_d_base + i) * BLOCK_K + 16 + n_d]; + } + + fx4 p[2]; + #pragma unroll + for (int i = 0; i < 4; ++i) { + float row_max = fmaxf(qk_local[0][i], qk_local[1][i]); + row_max = fmaxf(row_max, __shfl_xor(row_max, 1)); + row_max = fmaxf(row_max, __shfl_xor(row_max, 2)); + row_max = fmaxf(row_max, __shfl_xor(row_max, 4)); + row_max = 
fmaxf(row_max, __shfl_xor(row_max, 8)); + + float m_new = fmaxf(m_state[i], row_max); + float alpha = + __builtin_amdgcn_exp2f((m_state[i] - m_new) * 1.4426950408889634f); + + float e0 = + __builtin_amdgcn_exp2f((qk_local[0][i] - m_new) * 1.4426950408889634f); + float e1 = + __builtin_amdgcn_exp2f((qk_local[1][i] - m_new) * 1.4426950408889634f); + + float row_sum = e0 + e1; + row_sum += __shfl_xor(row_sum, 1); + row_sum += __shfl_xor(row_sum, 2); + row_sum += __shfl_xor(row_sum, 4); + row_sum += __shfl_xor(row_sum, 8); + + float l_new = l_state[i] * alpha + row_sum; + p[0][i] = e0; + p[1][i] = e1; + + #pragma unroll + for (int nt = 0; nt < N_TILES_PER_WAVE; ++nt) acc[nt][i] *= alpha; + + m_state[i] = m_new; + l_state[i] = l_new; + } + + if (wave == 0) { + #pragma unroll + for (int i = 0; i < 4; ++i) { + p_lds[(m_d_base + i) * BLOCK_K + n_d] = (__bf16)p[0][i]; + p_lds[(m_d_base + i) * BLOCK_K + 16 + n_d] = (__bf16)p[1][i]; + } + } + + __syncthreads(); + + bf16x8 p_reg; + const __bf16* p_src = &p_lds[m_a * BLOCK_K + kg * 8]; + #pragma unroll + for (int i = 0; i < 8; ++i) p_reg[i] = p_src[i]; + + #pragma unroll + for (int nt_local = 0; nt_local < N_TILES_PER_WAVE; ++nt_local) { + int n_tile = wave * N_TILES_PER_WAVE + nt_local; + bf16x8 k_reg; + #pragma unroll + for (int i = 0; i < 8; ++i) { + k_reg[i] = k_lds[(kg * 8 + i) * HEAD_DIM + n_tile * 16 + n_b]; + } + acc[nt_local] = mfma_16x16x32_bf16(p_reg, k_reg, acc[nt_local]); + } +} + +__device__ __forceinline__ void load_q(const __bf16* q, int64_t q_stride0, + int64_t q_stride1, int query, int pid_h, + int num_heads, __bf16* q_lds, int tid) { + const int qh = tid >> 4; // 0..15 + const int qc0 = (tid & 15) << 5; // 0,32,...,480 + const int head_global = pid_h * BLOCK_H + qh; + __bf16* dst = &q_lds[qh * HEAD_DIM + qc0]; + if (head_global < num_heads) { + const __bf16* src = q + query * q_stride0 + head_global * q_stride1 + qc0; + const int4* s4 = reinterpret_cast(src); + int4* d4 = reinterpret_cast(dst); + #pragma 
unroll + for (int i = 0; i < 4; ++i) d4[i] = s4[i]; + } else { + int4 z; + z.x = z.y = z.z = z.w = 0; + int4* d4 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < 4; ++i) d4[i] = z; + } +} + +template +__global__ __launch_bounds__(256, 2) void sparse_mla_decode_kernel( + const __bf16* __restrict__ q, const uint8_t* __restrict__ main_cache, + const int32_t* __restrict__ main_indices, + const int32_t* __restrict__ main_indptr, + const uint8_t* __restrict__ extra_cache, + const int32_t* __restrict__ extra_indices, + const int32_t* __restrict__ extra_indptr, + const float* __restrict__ attn_sink, __bf16* __restrict__ output, + int64_t q_stride0, int64_t q_stride1, int64_t out_stride0, + int64_t out_stride1, int64_t main_cache_stride0, + int64_t extra_cache_stride0, int main_num_rows, int extra_num_rows, + int main_block_size, int extra_block_size, float scale, int num_heads) { + const int query = blockIdx.x; + const int pid_h = blockIdx.y; + const int tid = threadIdx.x; + const int wave = tid >> 6; + const int lane = tid & 63; + + const int m_a = lane & 15; + const int kg = lane >> 4; + const int n_b = lane & 15; + const int m_d_base = (lane >> 4) * 4; + const int n_d = lane & 15; + + __shared__ __bf16 q_lds[BLOCK_H * HEAD_DIM]; + __shared__ __bf16 k_lds[BLOCK_K * HEAD_DIM]; + __shared__ __bf16 p_lds[BLOCK_H * BLOCK_K]; + __shared__ float scores_lds[BLOCK_H * BLOCK_K]; + __shared__ int8_t kv_lds[BLOCK_K]; + __shared__ char force_1wg_per_cu[48 * 1024]; // pads LDS to ~96 KB + (void)force_1wg_per_cu; + + load_q(q, q_stride0, q_stride1, query, pid_h, num_heads, q_lds, tid); + + float m_state[4], l_state[4]; + fx4 acc[N_TILES_PER_WAVE]; + #pragma unroll + for (int i = 0; i < 4; ++i) { + m_state[i] = -3.4028234663852886e38f; + l_state[i] = 0.f; + } + #pragma unroll + for (int i = 0; i < N_TILES_PER_WAVE; ++i) { + acc[i] = (fx4){0.f, 0.f, 0.f, 0.f}; + } + + __syncthreads(); + + { + int main_start = main_indptr[query]; + int main_end = main_indptr[query + 1]; + 
int main_len = main_end - main_start; + for (int k_start = 0; k_start < main_len; k_start += BLOCK_K) { + gather_and_dequant_k_tile( + k_start, main_len, main_cache, main_cache_stride0, main_num_rows, + main_block_size, main_indices + main_start, k_lds, kv_lds, tid); + __syncthreads(); + process_k_tile(q_lds, k_lds, kv_lds, p_lds, scores_lds, m_state, l_state, + acc, scale, lane, m_a, kg, n_b, m_d_base, n_d, wave); + __syncthreads(); + } + } + + if (HAS_EXTRA) { + int extra_start = extra_indptr[query]; + int extra_end = extra_indptr[query + 1]; + int extra_len = extra_end - extra_start; + for (int k_start = 0; k_start < extra_len; k_start += BLOCK_K) { + gather_and_dequant_k_tile( + k_start, extra_len, extra_cache, extra_cache_stride0, extra_num_rows, + extra_block_size, extra_indices + extra_start, k_lds, kv_lds, tid); + __syncthreads(); + process_k_tile(q_lds, k_lds, kv_lds, p_lds, scores_lds, m_state, l_state, + acc, scale, lane, m_a, kg, n_b, m_d_base, n_d, wave); + __syncthreads(); + } + } + + { + #pragma unroll + for (int i = 0; i < 4; ++i) { + int head_local = m_d_base + i; + int head_global = pid_h * BLOCK_H + head_local; + if (head_global >= num_heads) continue; + + float m_final = m_state[i]; + float l_final = l_state[i]; + float alpha_final = 1.f; + if (HAS_ATTN_SINK) { + float sink_val = attn_sink[head_global]; + m_final = fmaxf(m_state[i], sink_val); + alpha_final = __builtin_amdgcn_exp2f((m_state[i] - m_final) * + 1.4426950408889634f); + l_final = + l_state[i] * alpha_final + + __builtin_amdgcn_exp2f((sink_val - m_final) * 1.4426950408889634f); + } + float denom = fmaxf(l_final, 1.0e-30f); + bool live = (l_final > 0.f); + + __bf16* out_row = + output + query * out_stride0 + head_global * out_stride1; + #pragma unroll + for (int nt_local = 0; nt_local < N_TILES_PER_WAVE; ++nt_local) { + int n_tile = wave * N_TILES_PER_WAVE + nt_local; + int col = n_tile * 16 + n_d; + float v = live ? 
(acc[nt_local][i] * alpha_final) / denom : 0.f; + out_row[col] = (__bf16)v; + } + } + } +} + +template +__global__ __launch_bounds__(256, 2) void sparse_mla_decode_partial_kernel( + const __bf16* __restrict__ q, const uint8_t* __restrict__ main_cache, + const int32_t* __restrict__ main_indices, + const int32_t* __restrict__ main_indptr, + const uint8_t* __restrict__ extra_cache, + const int32_t* __restrict__ extra_indices, + const int32_t* __restrict__ extra_indptr, float* __restrict__ scratch_m, + float* __restrict__ scratch_l, __bf16* __restrict__ scratch_acc, + int64_t q_stride0, int64_t q_stride1, int64_t main_cache_stride0, + int64_t extra_cache_stride0, int main_num_rows, int extra_num_rows, + int main_block_size, int extra_block_size, float scale, int num_heads, + int num_head_blocks) { + const int query = blockIdx.x; + const int pid_hs = blockIdx.y; + const int pid_split = pid_hs / num_head_blocks; + const int pid_h = pid_hs - pid_split * num_head_blocks; + const int tid = threadIdx.x; + const int wave = tid >> 6; + const int lane = tid & 63; + + const int m_a = lane & 15; + const int kg = lane >> 4; + const int n_b = lane & 15; + const int m_d_base = (lane >> 4) * 4; + const int n_d = lane & 15; + + __shared__ __bf16 q_lds[BLOCK_H * HEAD_DIM]; + __shared__ __bf16 k_lds[BLOCK_K * HEAD_DIM]; + __shared__ __bf16 p_lds[BLOCK_H * BLOCK_K]; + __shared__ float scores_lds[BLOCK_H * BLOCK_K]; + __shared__ int8_t kv_lds[BLOCK_K]; + load_q(q, q_stride0, q_stride1, query, pid_h, num_heads, q_lds, tid); + + float m_state[4], l_state[4]; + fx4 acc[N_TILES_PER_WAVE]; + #pragma unroll + for (int i = 0; i < 4; ++i) { + m_state[i] = -3.4028234663852886e38f; + l_state[i] = 0.f; + } + #pragma unroll + for (int i = 0; i < N_TILES_PER_WAVE; ++i) { + acc[i] = (fx4){0.f, 0.f, 0.f, 0.f}; + } + + __syncthreads(); + + { + int main_start = main_indptr[query]; + int main_end = main_indptr[query + 1]; + int main_len = main_end - main_start; + for (int k_start = pid_split * BLOCK_K; 
k_start < main_len; + k_start += BLOCK_K * SPLIT_K) { + gather_and_dequant_k_tile( + k_start, main_len, main_cache, main_cache_stride0, main_num_rows, + main_block_size, main_indices + main_start, k_lds, kv_lds, tid); + __syncthreads(); + process_k_tile(q_lds, k_lds, kv_lds, p_lds, scores_lds, m_state, l_state, + acc, scale, lane, m_a, kg, n_b, m_d_base, n_d, wave); + __syncthreads(); + } + } + + if (HAS_EXTRA) { + int extra_start = extra_indptr[query]; + int extra_end = extra_indptr[query + 1]; + int extra_len = extra_end - extra_start; + for (int k_start = pid_split * BLOCK_K; k_start < extra_len; + k_start += BLOCK_K * SPLIT_K) { + gather_and_dequant_k_tile( + k_start, extra_len, extra_cache, extra_cache_stride0, extra_num_rows, + extra_block_size, extra_indices + extra_start, k_lds, kv_lds, tid); + __syncthreads(); + process_k_tile(q_lds, k_lds, kv_lds, p_lds, scores_lds, m_state, l_state, + acc, scale, lane, m_a, kg, n_b, m_d_base, n_d, wave); + __syncthreads(); + } + } + + const int triple = (query * num_head_blocks + pid_h) * SPLIT_K + pid_split; + + if (wave == 0 && n_d == 0) { + #pragma unroll + for (int i = 0; i < 4; ++i) { + int idx = triple * BLOCK_H + m_d_base + i; + scratch_m[idx] = m_state[i]; + scratch_l[idx] = l_state[i]; + } + } + + __syncthreads(); + #pragma unroll + for (int i = 0; i < 4; ++i) { + int row = m_d_base + i; + #pragma unroll + for (int nt_local = 0; nt_local < N_TILES_PER_WAVE; ++nt_local) { + int n_tile = wave * N_TILES_PER_WAVE + nt_local; + int col = n_tile * 16 + n_d; + k_lds[row * HEAD_DIM + col] = (__bf16)acc[nt_local][i]; + } + } + __syncthreads(); + + { + int my_row = tid >> 4; + int my_col0 = (tid & 15) << 5; + __bf16* dst = scratch_acc + (int64_t)triple * BLOCK_H * HEAD_DIM + + my_row * HEAD_DIM + my_col0; + const int4* src4 = + reinterpret_cast(&k_lds[my_row * HEAD_DIM + my_col0]); + int4* dst4 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < 4; ++i) dst4[i] = src4[i]; + } +} + +template +__global__ 
__launch_bounds__(256, 4) void sparse_mla_decode_reduce_kernel( + const float* __restrict__ scratch_m, const float* __restrict__ scratch_l, + const __bf16* __restrict__ scratch_acc, const float* __restrict__ attn_sink, + __bf16* __restrict__ output, int64_t out_stride0, int64_t out_stride1, + int num_heads, int num_head_blocks) { + const int query = blockIdx.x; + const int pid_h = blockIdx.y; + const int tid = threadIdx.x; + + const int my_row = tid >> 4; // 0..15 + const int my_col0 = (tid & 15) << 5; // 0,32,...,480 + const int head_global = pid_h * BLOCK_H + my_row; + + float m_merged = -3.4028234663852886e38f; + float l_merged = 0.f; + float acc_merged[32]; + #pragma unroll + for (int i = 0; i < 32; ++i) acc_merged[i] = 0.f; + + #pragma unroll + for (int s = 0; s < SPLIT_K; ++s) { + const int triple = (query * num_head_blocks + pid_h) * SPLIT_K + s; + float m_s = scratch_m[triple * BLOCK_H + my_row]; + float l_s = scratch_l[triple * BLOCK_H + my_row]; + + float m_new = fmaxf(m_merged, m_s); + float alpha = + __builtin_amdgcn_exp2f((m_merged - m_new) * 1.4426950408889634f); + float beta = __builtin_amdgcn_exp2f((m_s - m_new) * 1.4426950408889634f); + l_merged = l_merged * alpha + l_s * beta; + m_merged = m_new; + + const __bf16* acc_base = scratch_acc + + (int64_t)triple * BLOCK_H * HEAD_DIM + + my_row * HEAD_DIM + my_col0; + const int4* src4 = reinterpret_cast(acc_base); + #pragma unroll + for (int i = 0; i < 4; ++i) { + int4 v = src4[i]; + __bf16 vbf[8]; + *reinterpret_cast(vbf) = v; + #pragma unroll + for (int j = 0; j < 8; ++j) { + float a_s = (float)vbf[j]; + acc_merged[i * 8 + j] = acc_merged[i * 8 + j] * alpha + a_s * beta; + } + } + } + + if (head_global >= num_heads) return; + + float m_final = m_merged; + float l_final = l_merged; + float alpha_final = 1.f; + if (HAS_ATTN_SINK) { + float sink_val = attn_sink[head_global]; + m_final = fmaxf(m_merged, sink_val); + alpha_final = + __builtin_amdgcn_exp2f((m_merged - m_final) * 1.4426950408889634f); + 
l_final = + l_merged * alpha_final + + __builtin_amdgcn_exp2f((sink_val - m_final) * 1.4426950408889634f); + } + float denom = fmaxf(l_final, 1.0e-30f); + bool live = (l_final > 0.f); + float inv_denom = live ? (alpha_final / denom) : 0.f; + + __bf16* out_row = + output + query * out_stride0 + head_global * out_stride1 + my_col0; + __bf16 out_buf[32]; + #pragma unroll + for (int i = 0; i < 32; ++i) out_buf[i] = (__bf16)(acc_merged[i] * inv_denom); + int4* dst4 = reinterpret_cast(out_row); + const int4* sb4 = reinterpret_cast(out_buf); + #pragma unroll + for (int i = 0; i < 4; ++i) dst4[i] = sb4[i]; +} + +#else // non-gfx950 device pass: empty __global__ stubs for symbol parity. + +template +__global__ void sparse_mla_decode_kernel(const __bf16*, const uint8_t*, + const int32_t*, const int32_t*, + const uint8_t*, const int32_t*, + const int32_t*, const float*, __bf16*, + int64_t, int64_t, int64_t, int64_t, + int64_t, int64_t, int, int, int, int, + float, int) {} + +template +__global__ void sparse_mla_decode_partial_kernel( + const __bf16*, const uint8_t*, const int32_t*, const int32_t*, + const uint8_t*, const int32_t*, const int32_t*, float*, float*, __bf16*, + int64_t, int64_t, int64_t, int64_t, int, int, int, int, float, int, int) {} + +template +__global__ void sparse_mla_decode_reduce_kernel(const float*, const float*, + const __bf16*, const float*, + __bf16*, int64_t, int64_t, int, + int) {} + +#endif // __HIP__GFX950__ || !__HIP_DEVICE_COMPILE__ + +void sparse_mla_decode_single( + torch::Tensor q, torch::Tensor main_cache, torch::Tensor main_indices, + torch::Tensor main_indptr, torch::Tensor extra_cache, + torch::Tensor extra_indices, torch::Tensor extra_indptr, + const std::optional& attn_sink, torch::Tensor output, + int64_t main_block_size, int64_t extra_block_size, int64_t main_num_rows, + int64_t extra_num_rows, double scale_d, bool has_extra) { + const int num_queries = q.size(0); + const int num_heads = q.size(1); + const int num_head_blocks = 
(num_heads + BLOCK_H - 1) / BLOCK_H; + const float scale_f = (float)scale_d; + const bool has_sink = attn_sink.has_value(); + + dim3 grid(num_queries, num_head_blocks); + dim3 block(256); + + const __bf16* q_ptr = reinterpret_cast(q.data_ptr()); + const uint8_t* mc_ptr = + reinterpret_cast(main_cache.data_ptr()); + const uint8_t* ec_ptr = + reinterpret_cast(extra_cache.data_ptr()); + const int32_t* mi_ptr = main_indices.data_ptr(); + const int32_t* mip_ptr = main_indptr.data_ptr(); + const int32_t* ei_ptr = extra_indices.data_ptr(); + const int32_t* eip_ptr = extra_indptr.data_ptr(); + __bf16* out_ptr = reinterpret_cast<__bf16*>(output.data_ptr()); + const float* sink_ptr = + has_sink ? attn_sink.value().data_ptr() : nullptr; + + auto stream = at::cuda::getCurrentCUDAStream(); + +#define LAUNCH(HAS_S, HAS_E) \ + do { \ + sparse_mla_decode_kernel<<>>( \ + q_ptr, mc_ptr, mi_ptr, mip_ptr, ec_ptr, ei_ptr, eip_ptr, sink_ptr, \ + out_ptr, q.stride(0), q.stride(1), output.stride(0), output.stride(1), \ + main_cache.stride(0), extra_cache.stride(0), main_num_rows, \ + extra_num_rows, main_block_size, extra_block_size, scale_f, \ + num_heads); \ + } while (0) + + if (has_sink && has_extra) + LAUNCH(true, true); + else if (has_sink) + LAUNCH(true, false); + else if (has_extra) + LAUNCH(false, true); + else + LAUNCH(false, false); + +#undef LAUNCH +} + +void sparse_mla_decode_split( + torch::Tensor q, torch::Tensor main_cache, torch::Tensor main_indices, + torch::Tensor main_indptr, torch::Tensor extra_cache, + torch::Tensor extra_indices, torch::Tensor extra_indptr, + const std::optional& attn_sink, torch::Tensor output, + torch::Tensor scratch_m, torch::Tensor scratch_l, torch::Tensor scratch_acc, + int64_t main_block_size, int64_t extra_block_size, int64_t main_num_rows, + int64_t extra_num_rows, double scale_d, bool has_extra, int64_t split_k) { + const int num_queries = q.size(0); + const int num_heads = q.size(1); + const int num_head_blocks = (num_heads + BLOCK_H - 1) 
/ BLOCK_H; + const float scale_f = (float)scale_d; + const bool has_sink = attn_sink.has_value(); + + dim3 grid_p(num_queries, num_head_blocks * (int)split_k); + dim3 grid_r(num_queries, num_head_blocks); + dim3 block_p(256); + dim3 block_r(256); + + const __bf16* q_ptr = reinterpret_cast(q.data_ptr()); + const uint8_t* mc_ptr = + reinterpret_cast(main_cache.data_ptr()); + const uint8_t* ec_ptr = + reinterpret_cast(extra_cache.data_ptr()); + const int32_t* mi_ptr = main_indices.data_ptr(); + const int32_t* mip_ptr = main_indptr.data_ptr(); + const int32_t* ei_ptr = extra_indices.data_ptr(); + const int32_t* eip_ptr = extra_indptr.data_ptr(); + __bf16* out_ptr = reinterpret_cast<__bf16*>(output.data_ptr()); + float* sm_ptr = scratch_m.data_ptr(); + float* sl_ptr = scratch_l.data_ptr(); + __bf16* sa_ptr = reinterpret_cast<__bf16*>(scratch_acc.data_ptr()); + const float* sink_ptr = + has_sink ? attn_sink.value().data_ptr() : nullptr; + + auto stream = at::cuda::getCurrentCUDAStream(); + +#define LAUNCH_P(HAS_E, SK) \ + do { \ + sparse_mla_decode_partial_kernel \ + <<>>( \ + q_ptr, mc_ptr, mi_ptr, mip_ptr, ec_ptr, ei_ptr, eip_ptr, sm_ptr, \ + sl_ptr, sa_ptr, q.stride(0), q.stride(1), main_cache.stride(0), \ + extra_cache.stride(0), main_num_rows, extra_num_rows, \ + main_block_size, extra_block_size, scale_f, num_heads, \ + num_head_blocks); \ + } while (0) + +#define LAUNCH_R(HAS_S, SK) \ + do { \ + sparse_mla_decode_reduce_kernel \ + <<>>( \ + sm_ptr, sl_ptr, sa_ptr, sink_ptr, out_ptr, output.stride(0), \ + output.stride(1), num_heads, num_head_blocks); \ + } while (0) + +#define DISPATCH_SK(SK) \ + do { \ + if (has_extra) \ + LAUNCH_P(true, SK); \ + else \ + LAUNCH_P(false, SK); \ + if (has_sink) \ + LAUNCH_R(true, SK); \ + else \ + LAUNCH_R(false, SK); \ + } while (0) + + switch ((int)split_k) { + case 2: + DISPATCH_SK(2); + break; + case 4: + DISPATCH_SK(4); + break; + case 8: + DISPATCH_SK(8); + break; + case 16: + DISPATCH_SK(16); + break; + default: + 
#ifdef VLLM_ROCM_GFX950
// Operator schemas for the gfx950 sparse-MLA decode kernels.
//
// Both ops return nothing and write their results into caller-provided
// tensors, so every tensor the kernels write is annotated with `!`:
//   * decode_single writes `output`;
//   * decode_split writes the `scratch_m` / `scratch_l` / `scratch_acc`
//     buffers in its partial pass and `output` in its reduce pass.
// Without the annotation the schema declares the ops as pure, which
// misinforms any consumer of the schema (e.g. functionalization under
// torch.compile) about these in-place writes.
TORCH_LIBRARY_FRAGMENT(vllm_sparse_mla_hip, m) {
  m.def(
      "decode_single(Tensor q, Tensor main_cache, Tensor main_indices, "
      "Tensor main_indptr, Tensor extra_cache, Tensor extra_indices, "
      "Tensor extra_indptr, Tensor? attn_sink, Tensor! output, "
      "int main_block_size, int extra_block_size, int main_num_rows, "
      "int extra_num_rows, float scale, bool has_extra) -> ()");
  m.def(
      "decode_split(Tensor q, Tensor main_cache, Tensor main_indices, "
      "Tensor main_indptr, Tensor extra_cache, Tensor extra_indices, "
      "Tensor extra_indptr, Tensor? attn_sink, Tensor! output, "
      "Tensor! scratch_m, Tensor! scratch_l, Tensor! scratch_acc, "
      "int main_block_size, int extra_block_size, int main_num_rows, "
      "int extra_num_rows, float scale, bool has_extra, int split_k) -> ()");
}

// GPU implementations (the CUDA dispatch key is also used for ROCm builds).
TORCH_LIBRARY_IMPL(vllm_sparse_mla_hip, CUDA, m) {
  m.impl("decode_single", &sparse_mla_decode_single);
  m.impl("decode_split", &sparse_mla_decode_split);
}
#endif
def _is_gfx950() -> bool:
    """Return the ROCm platform's ``_ON_GFX950`` flag.

    Returns ``False`` when the current platform is not ROCm or when the
    flag cannot be imported from ``vllm.platforms.rocm``.
    """
    if current_platform.is_rocm():
        try:
            from vllm.platforms.rocm import _ON_GFX950
        except ImportError:
            # Flag not available in this build: treat as "not gfx950".
            pass
        else:
            return _ON_GFX950
    return False
def _pack_fp8_ds_mla_cache(kv: torch.Tensor, block_size: int) -> torch.Tensor:
    """Pack bf16 KV rows into the fp8_ds_mla uint8 cache layout.

    Each block holds ``block_size`` token payloads of 576 bytes
    (``NOPE_HEAD_DIM`` fp8 bytes followed by the rope part as raw bf16
    bytes), then ``block_size`` scale groups of 8 bytes. The first 7 scale
    bytes of every written token are set to 127; everything else stays zero.
    Token ``i`` of ``kv`` is written to slot ``i``.
    """
    assert kv.shape[-1] == HEAD_DIM
    token_bytes = 576
    scale_bytes = 8
    rope_bytes = ROPE_HEAD_DIM * 2

    num_tokens = kv.shape[0]
    num_blocks = (num_tokens + block_size - 1) // block_size
    cache = torch.zeros(
        (num_blocks, block_size, 584),
        dtype=torch.uint8,
        device=kv.device,
    )
    # Flat byte view over the same storage, so writes below land in `cache`.
    flat = cache.view(-1)

    nope_bytes = (
        kv[:, :NOPE_HEAD_DIM].to(current_platform.fp8_dtype()).view(torch.uint8)
    )
    rope_raw = kv[:, NOPE_HEAD_DIM:].contiguous().view(torch.uint8)
    block_stride = cache.stride(0)

    for slot in range(num_tokens):
        block_idx, pos = divmod(slot, block_size)
        block_base = block_idx * block_stride
        nope_begin = block_base + pos * token_bytes
        rope_begin = nope_begin + NOPE_HEAD_DIM
        scale_begin = block_base + block_size * token_bytes + pos * scale_bytes

        flat[nope_begin:rope_begin].copy_(nope_bytes[slot])
        flat[rope_begin : rope_begin + rope_bytes].copy_(rope_raw[slot])
        flat[scale_begin : scale_begin + 7].fill_(127)
    return cache
def _ref_sparse_decode_ragged(
    q: torch.Tensor,
    main_cache: torch.Tensor,
    main_rows: list[list[int]],
    scale: float,
    attn_sink: torch.Tensor | None,
    block_size: int,
    extra_cache: torch.Tensor | None = None,
    extra_rows: list[list[int]] | None = None,
) -> torch.Tensor:
    """Pure-Python reference for ragged sparse decode attention.

    For every query, the listed main slots (and the extra slots, when both
    ``extra_cache`` and ``extra_rows`` are given) are read back from the
    packed cache and used as both keys and values. When ``attn_sink`` is
    given, the per-head sink logit joins the softmax and its probability
    is then dropped. Computation runs in float32; the result is returned
    as bfloat16.
    """
    use_extra = extra_cache is not None and extra_rows is not None
    q_f32 = q.float()
    out = torch.empty_like(q_f32)
    num_queries, num_heads = q.shape[0], q.shape[1]

    for qi in range(num_queries):
        gathered = []
        for slot in main_rows[qi]:
            gathered.append(_read_fp8_ds_mla_cache(main_cache, int(slot), block_size))
        if use_extra:
            for slot in extra_rows[qi]:
                gathered.append(
                    _read_fp8_ds_mla_cache(extra_cache, int(slot), block_size)
                )
        kv = torch.stack(gathered).to(q.device)

        for hi in range(num_heads):
            logits = torch.mv(kv, q_f32[qi, hi]) * scale
            if attn_sink is None:
                weights = torch.softmax(logits, dim=0)
            else:
                # The sink takes part in normalisation but has no value row.
                sink_logit = attn_sink[hi].float().reshape(1)
                weights = torch.softmax(torch.cat([logits, sink_logit]), dim=0)[:-1]
            out[qi, hi] = torch.sum(weights[:, None] * kv, dim=0)
    return out.to(torch.bfloat16)
attn_sink=attn_sink, + nope_head_dim=NOPE_HEAD_DIM, + rope_head_dim=ROPE_HEAD_DIM, + extra_cache=extra_cache, + extra_indices=extra_indices, + extra_indptr=extra_indptr, + max_main_len=max_main_len, + max_extra_len=max_extra_len, + ) + + +@contextlib.contextmanager +def _force_split_k(value: int | None): + """Temporarily override the internal split-K picker. + + The kernel module reads ``SPARSE_MLA_HIP_SPLIT_K`` once at import time + into ``mod._SPLIT_K_OVERRIDE``; tests mutate that module attribute + directly so they can exercise both the single-WG (split_k=1) and the + split-reduce paths regardless of the auto-tuned heuristic. + """ + from vllm.v1.attention.ops import rocm_aiter_mla_sparse as mod + + orig = mod._SPLIT_K_OVERRIDE + if value is not None: + mod._SPLIT_K_OVERRIDE = str(int(value)) + try: + yield + finally: + mod._SPLIT_K_OVERRIDE = orig + + +def _make_q(num_queries: int, num_heads: int, device: torch.device) -> torch.Tensor: + return ( + torch.randn( + num_queries, num_heads, HEAD_DIM, dtype=torch.bfloat16, device=device + ) + * 0.125 + ) + + +def _build_contiguous_ragged( + num_queries: int, + tokens_per_query: int, + device: torch.device, + start: int = 0, +) -> tuple[torch.Tensor, torch.Tensor, list[list[int]]]: + """Build (indices, indptr, rows) where each query owns + ``tokens_per_query`` contiguous slots starting at ``start``.""" + rows = [ + list(range(start + qi * tokens_per_query, start + (qi + 1) * tokens_per_query)) + for qi in range(num_queries) + ] + indices = torch.tensor( + [s for r in rows for s in r], dtype=torch.int32, device=device + ) + indptr = torch.tensor( + [qi * tokens_per_query for qi in range(num_queries + 1)], + dtype=torch.int32, + device=device, + ) + return indices, indptr, rows + + +# Parametrize against the realistic DSv4 deployment axes: +# * ``num_queries`` covers single-token decode and small MTP batches. +# * ``split_k`` covers both the single-WG path and the split-reduce path. 
+# * ``with_sink`` toggles the attention-sink branch. +# All tests pin ``num_heads = DSV4_NUM_HEADS`` (=16, DSv4 at TP=8) and +# ``block_size = DSV4_BLOCK_SIZE`` (=64, the paged-MLA cache block size). + + +@torch.inference_mode() +@pytest.mark.parametrize("split_k", [1, 4]) +@pytest.mark.parametrize("num_queries", [1, 4]) +@pytest.mark.parametrize("with_sink", [False, True]) +def test_hip_decode_main_only(num_queries, with_sink, split_k) -> None: + """SWA-only decode (main cache) at DSv4 dims, with/without sink.""" + device = torch.device("cuda") + torch.manual_seed(42 + num_queries * 10 + int(with_sink) * 3 + split_k) + tokens_per_query = DSV4_SWA_WINDOW + + q = _make_q(num_queries, DSV4_NUM_HEADS, device) + main_kv = ( + torch.randn( + num_queries * tokens_per_query, + HEAD_DIM, + dtype=torch.bfloat16, + device=device, + ) + * 0.125 + ) + main_cache = _pack_fp8_ds_mla_cache(main_kv, DSV4_BLOCK_SIZE) + main_indices, main_indptr, main_rows = _build_contiguous_ragged( + num_queries, tokens_per_query, device + ) + attn_sink = ( + torch.randn(DSV4_NUM_HEADS, dtype=torch.float32, device=device) * 0.1 + if with_sink + else None + ) + scale = HEAD_DIM**-0.5 + + with _force_split_k(split_k): + actual = _call_hip_decode( + q, + main_cache, + main_indices, + main_indptr, + scale, + attn_sink=attn_sink, + max_main_len=tokens_per_query, + ) + expected = _ref_sparse_decode_ragged( + q, + main_cache, + main_rows, + scale, + attn_sink=attn_sink, + block_size=DSV4_BLOCK_SIZE, + ) + torch.testing.assert_close(actual, expected, atol=2e-2, rtol=2e-2) + + +@torch.inference_mode() +@pytest.mark.parametrize("split_k", [1, 4]) +@pytest.mark.parametrize("num_queries", [1, 4]) +@pytest.mark.parametrize("with_sink", [False, True]) +def test_hip_decode_main_extra(num_queries, with_sink, split_k) -> None: + """SWA + topk decode (main + extra caches) at DSv4 dims.""" + device = torch.device("cuda") + torch.manual_seed(7 + num_queries * 11 + int(with_sink) * 5 + split_k) + # ``main`` carries 
SWA tokens; ``extra`` carries topk tokens. Use a + # modest extra length (rather than the full DSV4_TOPK=1024) to keep + # the Python reference fast while still exercising both code paths. + main_per_query = DSV4_SWA_WINDOW + extra_per_query = DSV4_BLOCK_SIZE * 2 # 128 topk tokens + + q = _make_q(num_queries, DSV4_NUM_HEADS, device) + main_kv = ( + torch.randn( + num_queries * main_per_query, + HEAD_DIM, + dtype=torch.bfloat16, + device=device, + ) + * 0.125 + ) + extra_kv = ( + torch.randn( + num_queries * extra_per_query, + HEAD_DIM, + dtype=torch.bfloat16, + device=device, + ) + * 0.125 + ) + main_cache = _pack_fp8_ds_mla_cache(main_kv, DSV4_BLOCK_SIZE) + extra_cache = _pack_fp8_ds_mla_cache(extra_kv, DSV4_BLOCK_SIZE) + main_indices, main_indptr, main_rows = _build_contiguous_ragged( + num_queries, main_per_query, device + ) + extra_indices, extra_indptr, extra_rows = _build_contiguous_ragged( + num_queries, extra_per_query, device + ) + attn_sink = ( + torch.randn(DSV4_NUM_HEADS, dtype=torch.float32, device=device) * 0.1 + if with_sink + else None + ) + scale = HEAD_DIM**-0.5 + + with _force_split_k(split_k): + actual = _call_hip_decode( + q, + main_cache, + main_indices, + main_indptr, + scale, + attn_sink=attn_sink, + extra_cache=extra_cache, + extra_indices=extra_indices, + extra_indptr=extra_indptr, + max_main_len=main_per_query, + max_extra_len=extra_per_query, + ) + expected = _ref_sparse_decode_ragged( + q, + main_cache, + main_rows, + scale, + attn_sink=attn_sink, + block_size=DSV4_BLOCK_SIZE, + extra_cache=extra_cache, + extra_rows=extra_rows, + ) + torch.testing.assert_close(actual, expected, atol=2e-2, rtol=2e-2) + + +@torch.inference_mode() +@pytest.mark.parametrize("split_k", [2, 4, 8]) +@pytest.mark.parametrize("num_queries", [1, 4]) +def test_hip_decode_split_k(num_queries, split_k) -> None: + """Force the split-K reduce path with the DSv4 topk-sized seqlen.""" + device = torch.device("cuda") + torch.manual_seed(2026 + num_queries * 13 + split_k) 
+ tokens_per_query = DSV4_TOPK # full DSv4 topk budget + + q = _make_q(num_queries, DSV4_NUM_HEADS, device) + main_kv = ( + torch.randn( + num_queries * tokens_per_query, + HEAD_DIM, + dtype=torch.bfloat16, + device=device, + ) + * 0.125 + ) + main_cache = _pack_fp8_ds_mla_cache(main_kv, DSV4_BLOCK_SIZE) + main_indices, main_indptr, main_rows = _build_contiguous_ragged( + num_queries, tokens_per_query, device + ) + attn_sink = torch.randn(DSV4_NUM_HEADS, dtype=torch.float32, device=device) * 0.1 + scale = HEAD_DIM**-0.5 + + with _force_split_k(split_k): + actual = _call_hip_decode( + q, + main_cache, + main_indices, + main_indptr, + scale, + attn_sink=attn_sink, + max_main_len=tokens_per_query, + ) + expected = _ref_sparse_decode_ragged( + q, + main_cache, + main_rows, + scale, + attn_sink=attn_sink, + block_size=DSV4_BLOCK_SIZE, + ) + torch.testing.assert_close(actual, expected, atol=2e-2, rtol=2e-2) + + +@torch.inference_mode() +@pytest.mark.parametrize("split_k", [1, 4]) +@pytest.mark.parametrize("num_queries", [1, 4]) +@pytest.mark.parametrize( + "tokens_per_query", + [DSV4_BLOCK_SIZE, DSV4_SWA_WINDOW, DSV4_TOPK], + ids=["one_block", "swa_window", "topk"], +) +def test_hip_decode_seqlen(num_queries, tokens_per_query, split_k) -> None: + """Sweep DSv4 sequence lengths: one block, SWA window, full topk.""" + device = torch.device("cuda") + torch.manual_seed(num_queries * 1009 + tokens_per_query + split_k) + + q = _make_q(num_queries, DSV4_NUM_HEADS, device) + main_kv = ( + torch.randn( + num_queries * tokens_per_query, + HEAD_DIM, + dtype=torch.bfloat16, + device=device, + ) + * 0.125 + ) + main_cache = _pack_fp8_ds_mla_cache(main_kv, DSV4_BLOCK_SIZE) + main_indices, main_indptr, main_rows = _build_contiguous_ragged( + num_queries, tokens_per_query, device + ) + scale = HEAD_DIM**-0.5 + + with _force_split_k(split_k): + actual = _call_hip_decode( + q, + main_cache, + main_indices, + main_indptr, + scale, + attn_sink=None, + max_main_len=tokens_per_query, + ) + 
expected = _ref_sparse_decode_ragged( + q, + main_cache, + main_rows, + scale, + attn_sink=None, + block_size=DSV4_BLOCK_SIZE, + ) + torch.testing.assert_close(actual, expected, atol=2e-2, rtol=2e-2) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 7f6d8794c285..426cb9ef4357 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -471,7 +471,7 @@ def import_kernels(cls) -> None: import contextlib - # Import ROCm-specific extension + # Import ROCm-specific extensions with contextlib.suppress(ImportError): import vllm._rocm_C # noqa: F401 diff --git a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py index 12fd3a174217..8f31d46e46c4 100644 --- a/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py +++ b/vllm/v1/attention/ops/rocm_aiter_mla_sparse.py @@ -2,7 +2,9 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools import importlib +import logging import math +import os from importlib.util import find_spec import torch @@ -1715,6 +1717,138 @@ def rocm_sparse_attn_prefill( output.copy_(output_chunk.to(output.dtype)) +# ============================================================================ +# HIP MFMA kernel implementation for sparse-MLA decode. +# Compiled at build time via CMakeLists.txt into the _rocm_C extension when +# gfx950 is in VLLM_GPU_ARCHES (see csrc/rocm/sparse_mla_decode.cu). +# Ops are registered under the torch.ops.vllm_sparse_mla_hip namespace. 
+# ============================================================================ + +logger = logging.getLogger(__name__) + + +def _sparse_mla_hip_as_int32_1d(x): + if x.dtype != torch.int32: + x = x.to(torch.int32) + if not x.is_contiguous(): + x = x.contiguous() + return x.view(-1) + + +_NUM_CUS = 256 +_MIN_K_PER_SPLIT = int(os.environ.get("SPARSE_MLA_HIP_MIN_K_PER_SPLIT", "32")) +_SPLIT_K_OVERRIDE = os.environ.get("SPARSE_MLA_HIP_SPLIT_K") + + +def _pick_split_k(num_queries, num_head_blocks, max_total_k): + if _SPLIT_K_OVERRIDE is not None: + return int(_SPLIT_K_OVERRIDE) + base_tiles = max(1, num_queries * num_head_blocks) + cu_target = max(1, _NUM_CUS // base_tiles) + k_limit = max(1, max_total_k // _MIN_K_PER_SPLIT) + target = max(1, min(cu_target, k_limit)) + best = 1 + for b in (1, 2, 4, 8): + if b <= target: + best = b + return best + + +def _decode_sparse_mla_hip( + q, + main_cache, + main_indices, + main_indptr, + scale, + attn_sink, + nope_head_dim, + rope_head_dim, + extra_cache, + extra_indices, + extra_indptr, + max_main_len, + max_extra_len, +): + main_indices = _sparse_mla_hip_as_int32_1d(main_indices) + main_indptr = _sparse_mla_hip_as_int32_1d(main_indptr) + num_queries, num_heads, _ = q.shape + + has_extra = ( + extra_cache is not None + and extra_indices is not None + and extra_indptr is not None + ) + if has_extra: + extra_indices = _sparse_mla_hip_as_int32_1d(extra_indices) + extra_indptr = _sparse_mla_hip_as_int32_1d(extra_indptr) + else: + extra_cache = main_cache + extra_indices = torch.empty(0, device=q.device, dtype=torch.int32) + extra_indptr = torch.zeros(num_queries + 1, device=q.device, dtype=torch.int32) + + out = torch.empty_like(q, dtype=torch.bfloat16) + sink = attn_sink.contiguous() if attn_sink is not None else None + q_in = q.contiguous() if not q.is_contiguous() else q + + BLOCK_H = 16 + num_head_blocks = (num_heads + BLOCK_H - 1) // BLOCK_H + total_max_k = max_main_len + (max_extra_len if has_extra else 0) + split_k = 
_pick_split_k(num_queries, num_head_blocks, total_max_k) + + if split_k == 1: + torch.ops.vllm_sparse_mla_hip.decode_single( + q_in, + main_cache, + main_indices, + main_indptr, + extra_cache, + extra_indices, + extra_indptr, + sink, + out, + int(main_cache.shape[1]), + int(extra_cache.shape[1]), + int(main_cache.shape[0] * main_cache.shape[1]), + int(extra_cache.shape[0] * extra_cache.shape[1]), + float(scale), + bool(has_extra), + ) + else: + scratch_m = torch.empty( + num_queries * num_head_blocks * split_k * BLOCK_H, + device=q.device, + dtype=torch.float32, + ) + scratch_l = torch.empty_like(scratch_m) + scratch_acc = torch.empty( + num_queries * num_head_blocks * split_k * BLOCK_H * 512, + device=q.device, + dtype=torch.bfloat16, + ) + torch.ops.vllm_sparse_mla_hip.decode_split( + q_in, + main_cache, + main_indices, + main_indptr, + extra_cache, + extra_indices, + extra_indptr, + sink, + out, + scratch_m, + scratch_l, + scratch_acc, + int(main_cache.shape[1]), + int(extra_cache.shape[1]), + int(main_cache.shape[0] * main_cache.shape[1]), + int(extra_cache.shape[0] * extra_cache.shape[1]), + float(scale), + bool(has_extra), + int(split_k), + ) + return out + + def rocm_sparse_attn_decode( q: torch.Tensor, kv_cache: torch.Tensor | None, @@ -1736,7 +1870,7 @@ def rocm_sparse_attn_decode( output: torch.Tensor, ) -> None: assert swa_k_cache.dtype == torch.uint8, ( - "ROCm Triton sparse decode expects uint8 fp8_ds_mla SWA cache, " + "ROCm sparse decode expects uint8 fp8_ds_mla SWA cache, " f"got {swa_k_cache.dtype}" ) _validate_dsv4_sparse_dims( @@ -1746,38 +1880,146 @@ def rocm_sparse_attn_decode( "rocm_sparse_attn_decode", ) - main_indices = swa_indices.reshape(swa_indices.shape[0], -1) + if _ON_GFX950: + _rocm_sparse_attn_decode_hip( + q=q, + kv_cache=kv_cache, + swa_k_cache=swa_k_cache, + swa_only=swa_only, + topk_indices=topk_indices, + topk_lens=topk_lens, + swa_indices=swa_indices, + swa_lens=swa_lens, + swa_ragged_indices=swa_ragged_indices, + 
swa_ragged_indptr=swa_ragged_indptr, + topk_ragged_indices=topk_ragged_indices, + topk_ragged_indptr=topk_ragged_indptr, + attn_sink=attn_sink, + scale=scale, + nope_head_dim=nope_head_dim, + rope_head_dim=rope_head_dim, + output=output, + ) + else: + main_indices = swa_indices.reshape(swa_indices.shape[0], -1) + + extra_cache = None + extra_indices = None + if not swa_only: + assert kv_cache is not None + assert topk_indices is not None or ( + topk_ragged_indices is not None and topk_ragged_indptr is not None + ) + assert kv_cache.dtype == torch.uint8 + extra_cache = kv_cache + if topk_indices is not None: + extra_indices = topk_indices.reshape(topk_indices.shape[0], -1) - extra_cache = None - extra_indices = None - if not swa_only: - assert kv_cache is not None - assert topk_indices is not None or ( - topk_ragged_indices is not None and topk_ragged_indptr is not None + attn_out = _rocm_sparse_attn_decode_triton( + q=q, + main_cache=swa_k_cache, + main_indices=main_indices, + scale=scale, + attn_sink=None if attn_sink is None else attn_sink[: q.shape[1]], + nope_head_dim=nope_head_dim, + rope_head_dim=rope_head_dim, + extra_cache=extra_cache, + extra_indices=extra_indices, + main_lengths=swa_lens, + extra_lengths=topk_lens, + main_ragged_indices=swa_ragged_indices, + main_ragged_indptr=swa_ragged_indptr, + extra_ragged_indices=topk_ragged_indices, + extra_ragged_indptr=topk_ragged_indptr, + ) + output.copy_(attn_out.to(output.dtype)) + + +def _rocm_sparse_attn_decode_hip( + q, + kv_cache, + swa_k_cache, + swa_only, + topk_indices, + topk_lens, + swa_indices, + swa_lens, + swa_ragged_indices, + swa_ragged_indptr, + topk_ragged_indices, + topk_ragged_indptr, + attn_sink, + scale, + nope_head_dim, + rope_head_dim, + output, +): + assert nope_head_dim == 448 + assert rope_head_dim == 64 + + if swa_ragged_indices is None or swa_ragged_indptr is None: + main_indices_dense = swa_indices.reshape(swa_indices.shape[0], -1) + lengths = ( + swa_lens + if swa_lens is not None 
+ else (main_indices_dense >= 0).sum(dim=-1, dtype=torch.int32) ) - assert kv_cache.dtype == torch.uint8, ( - "ROCm Triton sparse decode expects uint8 fp8_ds_mla extra cache, " - f"got {kv_cache.dtype}" + main_ragged_indices, main_ragged_indptr = build_ragged_indices_from_dense( + main_indices_dense, + lengths, + num_rows=swa_k_cache.shape[0] * swa_k_cache.shape[1], ) + else: + main_ragged_indices = swa_ragged_indices + main_ragged_indptr = swa_ragged_indptr + + has_extra = not swa_only + extra_cache = None + extra_ragged_indices = None + extra_ragged_indptr = None + if has_extra: + assert kv_cache is not None + assert kv_cache.dtype == torch.uint8 extra_cache = kv_cache - if topk_indices is not None: - extra_indices = topk_indices.reshape(topk_indices.shape[0], -1) + if topk_ragged_indices is None or topk_ragged_indptr is None: + assert topk_indices is not None + ex_dense = topk_indices.reshape(topk_indices.shape[0], -1) + lengths = ( + topk_lens + if topk_lens is not None + else (ex_dense >= 0).sum(dim=-1, dtype=torch.int32) + ) + extra_ragged_indices, extra_ragged_indptr = build_ragged_indices_from_dense( + ex_dense, + lengths, + num_rows=kv_cache.shape[0] * kv_cache.shape[1], + ) + else: + extra_ragged_indices = topk_ragged_indices + extra_ragged_indptr = topk_ragged_indptr + + if torch.cuda.is_current_stream_capturing(): + max_main_len = swa_indices.shape[-1] + max_extra_len = max_main_len if has_extra else 0 + else: + max_main_len = int(swa_lens.max().item()) if swa_lens is not None else 0 + max_extra_len = 0 + if has_extra and topk_lens is not None: + max_extra_len = int(topk_lens.max().item()) - attn_out = _rocm_sparse_attn_decode_triton( + attn_out = _decode_sparse_mla_hip( q=q, main_cache=swa_k_cache, - main_indices=main_indices, + main_indices=main_ragged_indices, + main_indptr=main_ragged_indptr, scale=scale, attn_sink=None if attn_sink is None else attn_sink[: q.shape[1]], nope_head_dim=nope_head_dim, rope_head_dim=rope_head_dim, 
extra_cache=extra_cache, - extra_indices=extra_indices, - main_lengths=swa_lens, - extra_lengths=topk_lens, - main_ragged_indices=swa_ragged_indices, - main_ragged_indptr=swa_ragged_indptr, - extra_ragged_indices=topk_ragged_indices, - extra_ragged_indptr=topk_ragged_indptr, + extra_indices=extra_ragged_indices, + extra_indptr=extra_ragged_indptr, + max_main_len=max_main_len, + max_extra_len=max_extra_len, ) output.copy_(attn_out.to(output.dtype))