From 97c352230f56b43e453a955702723ae6e08ce3d6 Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Fri, 15 May 2026 03:04:36 +0900 Subject: [PATCH 1/7] Port gfx12 native attention to ABI3 fork --- csrc/cp_async.cuh | 4 +- csrc/fused/fused.cu | 9 + csrc/numeric_conversion.cuh | 81 +- csrc/qattn/attn_gfx12_native.h | 81 + csrc/qattn/pybind_gfx12_native.cpp | 101 + csrc/qattn/qk_int_sv_gfx12_native.cu | 7446 +++++++++++++++++++++++++ csrc/reduction_utils.cuh | 10 +- sageattention/__init__.py | 3 +- sageattention/core.py | 573 +- sageattention/gfx12_native_compile.py | 129 + setup.py | 483 +- 11 files changed, 8697 insertions(+), 223 deletions(-) create mode 100644 csrc/qattn/attn_gfx12_native.h create mode 100644 csrc/qattn/pybind_gfx12_native.cpp create mode 100644 csrc/qattn/qk_int_sv_gfx12_native.cu create mode 100644 sageattention/gfx12_native_compile.py diff --git a/csrc/cp_async.cuh b/csrc/cp_async.cuh index c877c16e..c19568db 100644 --- a/csrc/cp_async.cuh +++ b/csrc/cp_async.cuh @@ -33,7 +33,7 @@ enum class PrefetchMode { kPrefetch // Fetch additional data from global memory to L2 }; -#if (__CUDACC_VER_MAJOR__ >= 11) +#if !defined(__HIP_PLATFORM_AMD__) && (__CUDACC_VER_MAJOR__ >= 11) #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) #define CP_ASYNC_ENABLED #endif @@ -138,4 +138,4 @@ __device__ __forceinline__ void pred_load_128b(T* smem_ptr, const T* gmem_ptr, b #endif } -} // namespace cp_async \ No newline at end of file +} // namespace cp_async diff --git a/csrc/fused/fused.cu b/csrc/fused/fused.cu index 70fe688a..4fd78b17 100644 --- a/csrc/fused/fused.cu +++ b/csrc/fused/fused.cu @@ -32,6 +32,11 @@ #include #include +#if defined(__HIP_PLATFORM_AMD__) +using nv_bfloat16 = __hip_bfloat16; +using nv_bfloat162 = __hip_bfloat162; +#endif + using torch::stable::Tensor; enum class QuantType @@ -329,7 +334,11 @@ __global__ void MeanScaleKernel(T *__restrict__ input, int8_t *__restrict__ outp const uint32_t stride_bz_mean, const uint32_t stride_h_mean, 
const uint32_t stride_bz_scale, const uint32_t stride_h_scale) { +#if defined(__HIP_PLATFORM_AMD__) + static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); +#else static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); +#endif constexpr uint32_t pack_size = 8; // float4 contains 8 half or 8 bfloat16 diff --git a/csrc/numeric_conversion.cuh b/csrc/numeric_conversion.cuh index 88f80f50..965f452b 100644 --- a/csrc/numeric_conversion.cuh +++ b/csrc/numeric_conversion.cuh @@ -17,19 +17,28 @@ */ #pragma once +#if defined(__HIP_PLATFORM_AMD__) +#include +#include +#include +#include +#else #include #include #include #include #include +#endif +#if !defined(__HIP_PLATFORM_AMD__) #if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120400) #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 890)) #define FP8_CAST_ENABLED #endif #endif +#endif -#if defined(__CUDA_ARCH__) +#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__) #define RUNTIME_ASSERT(x) __brkpt() #else #include @@ -39,13 +48,28 @@ __device__ __forceinline__ void unpack_half2_from_uint32_to_float(float* dest, uint32_t source) { uint16_t h0 = source & 0xFFFF; uint16_t h1 = (source >> 16) & 0xFFFF; +#if defined(__HIP_PLATFORM_AMD__) + union { + uint16_t bits; + half value; + } lo{h0}, hi{h1}; + dest[0] = __half2float(lo.value); + dest[1] = __half2float(hi.value); +#else asm("cvt.f32.f16 %0, %1;" : "=f"(dest[0]) : "h"(h0)); asm("cvt.f32.f16 %0, %1;" : "=f"(dest[1]) : "h"(h1)); +#endif } __device__ __forceinline__ void floatx4_to_e4m3x4(uint32_t *dest, float *source0, float *source1) { -#ifdef FP8_CAST_ENABLED +#if defined(__HIP_PLATFORM_AMD__) + const auto lo = __hip_cvt_float2_to_fp8x2( + make_float2(source0[0], source0[1]), __HIP_SATFINITE, __HIP_E4M3); + const auto hi = __hip_cvt_float2_to_fp8x2( + make_float2(source1[0], source1[1]), __HIP_SATFINITE, __HIP_E4M3); + dest[0] = static_cast(lo) | 
(static_cast(hi) << 16); +#elif defined(FP8_CAST_ENABLED) asm volatile( \ "{\n" \ ".reg .b16 lo;\n" \ @@ -62,7 +86,13 @@ __device__ __forceinline__ void floatx4_to_e4m3x4(uint32_t *dest, float *source0 __device__ __forceinline__ void floatx4_to_e5m2x4(uint32_t *dest, float *source0, float *source1) { -#ifdef FP8_CAST_ENABLED +#if defined(__HIP_PLATFORM_AMD__) + const auto lo = __hip_cvt_float2_to_fp8x2( + make_float2(source0[0], source0[1]), __HIP_SATFINITE, __HIP_E5M2); + const auto hi = __hip_cvt_float2_to_fp8x2( + make_float2(source1[0], source1[1]), __HIP_SATFINITE, __HIP_E5M2); + dest[0] = static_cast(lo) | (static_cast(hi) << 16); +#elif defined(FP8_CAST_ENABLED) asm volatile( \ "{\n" \ ".reg .b16 lo;\n" \ @@ -79,7 +109,13 @@ __device__ __forceinline__ void floatx4_to_e5m2x4(uint32_t *dest, float *source0 __device__ __forceinline__ void halfx4_to_e4m3x4(uint32_t *dest, uint32_t *source0, uint32_t *source1) { -#ifdef FP8_CAST_ENABLED +#if defined(__HIP_PLATFORM_AMD__) + float s0[2]; + float s1[2]; + unpack_half2_from_uint32_to_float(s0, source0[0]); + unpack_half2_from_uint32_to_float(s1, source1[0]); + floatx4_to_e4m3x4(dest, s0, s1); +#elif defined(FP8_CAST_ENABLED) asm volatile( \ "{\n" \ ".reg .b16 lo;\n" \ @@ -96,7 +132,13 @@ __device__ __forceinline__ void halfx4_to_e4m3x4(uint32_t *dest, uint32_t *sourc __device__ __forceinline__ void halfx4_to_e5m2x4(uint32_t *dest, uint32_t *source0, uint32_t *source1) { -#ifdef FP8_CAST_ENABLED +#if defined(__HIP_PLATFORM_AMD__) + float s0[2]; + float s1[2]; + unpack_half2_from_uint32_to_float(s0, source0[0]); + unpack_half2_from_uint32_to_float(s1, source1[0]); + floatx4_to_e5m2x4(dest, s0, s1); +#elif defined(FP8_CAST_ENABLED) asm volatile( \ "{\n" \ ".reg .b16 lo;\n" \ @@ -113,7 +155,16 @@ __device__ __forceinline__ void halfx4_to_e5m2x4(uint32_t *dest, uint32_t *sourc __device__ __forceinline__ void e4m3x4_to_halfx4(uint32_t *dest0, uint32_t *dest1, uint32_t *source) { -#ifdef FP8_CAST_ENABLED +#if 
defined(__HIP_PLATFORM_AMD__) + const auto lo = __hip_cvt_fp8x2_to_halfraw2( + static_cast<__hip_fp8x2_storage_t>(source[0] & 0xFFFF), __HIP_E4M3); + const auto hi = __hip_cvt_fp8x2_to_halfraw2( + static_cast<__hip_fp8x2_storage_t>(source[0] >> 16), __HIP_E4M3); + dest0[0] = static_cast(lo.x.x) | + (static_cast(lo.y.x) << 16); + dest1[0] = static_cast(hi.x.x) | + (static_cast(hi.y.x) << 16); +#elif defined(FP8_CAST_ENABLED) asm volatile( \ "{\n" \ ".reg .b16 lo, hi;\n" \ @@ -128,7 +179,16 @@ __device__ __forceinline__ void e4m3x4_to_halfx4(uint32_t *dest0, uint32_t *dest __device__ __forceinline__ void e5m2x4_to_halfx4(uint32_t *dest0, uint32_t *dest1, uint32_t *source) { -#ifdef FP8_CAST_ENABLED +#if defined(__HIP_PLATFORM_AMD__) + const auto lo = __hip_cvt_fp8x2_to_halfraw2( + static_cast<__hip_fp8x2_storage_t>(source[0] & 0xFFFF), __HIP_E5M2); + const auto hi = __hip_cvt_fp8x2_to_halfraw2( + static_cast<__hip_fp8x2_storage_t>(source[0] >> 16), __HIP_E5M2); + dest0[0] = static_cast(lo.x.x) | + (static_cast(lo.y.x) << 16); + dest1[0] = static_cast(hi.x.x) | + (static_cast(hi.y.x) << 16); +#elif defined(FP8_CAST_ENABLED) asm volatile( \ "{\n" \ ".reg .b16 lo, hi;\n" \ @@ -143,7 +203,12 @@ __device__ __forceinline__ void e5m2x4_to_halfx4(uint32_t *dest0, uint32_t *dest __device__ __forceinline__ int8_t float_to_int8_rn(float x) { +#if defined(__HIP_PLATFORM_AMD__) + const float clipped = fminf(127.0f, fmaxf(-128.0f, nearbyintf(x))); + return static_cast(clipped); +#else uint32_t dst; asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); return reinterpret_cast(dst); -} \ No newline at end of file +#endif +} diff --git a/csrc/qattn/attn_gfx12_native.h b/csrc/qattn/attn_gfx12_native.h new file mode 100644 index 00000000..2296c0c4 --- /dev/null +++ b/csrc/qattn/attn_gfx12_native.h @@ -0,0 +1,81 @@ +/* + * Copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0. 
+ */ + +#pragma once + +#include + +#include + +using torch::stable::Tensor; + +Tensor qk_int8_sv_f16_d64_native_attn_gfx12( + Tensor query, + Tensor key, + Tensor value, + Tensor output, + Tensor query_scale, + Tensor key_scale, + int64_t tensor_layout, + int64_t is_causal, + double sm_scale, + int64_t valid_kv_len = 0); + +Tensor qk_rawq_int8_sv_f8_native_attn_gfx12( + Tensor query, + Tensor key, + Tensor value, + Tensor output, + Tensor key_scale, + int64_t tensor_layout, + int64_t is_causal, + double sm_scale, + int64_t valid_kv_len = 0); + +Tensor qk_int8_sv_f8_scaled_native_attn_gfx12( + Tensor query, + Tensor key, + Tensor value, + Tensor output, + Tensor query_scale, + Tensor key_scale, + Tensor value_scale, + int64_t tensor_layout, + int64_t is_causal, + double sm_scale, + int64_t valid_kv_len = 0); + +Tensor qk_rawq_int8_sv_f8_scaled_native_attn_gfx12( + Tensor query, + Tensor key, + Tensor value, + Tensor output, + Tensor key_scale, + Tensor value_scale, + int64_t tensor_layout, + int64_t is_causal, + double sm_scale, + int64_t valid_kv_len = 0); + +Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( + Tensor query, + Tensor key, + Tensor value, + int64_t is_causal, + int64_t value_is_fp8, + int64_t use_raw_f16_value, + double sm_scale, + int64_t valid_kv_len = 0); + +std::vector quant_q_nhd_per_warp_gfx12(Tensor query); + +Tensor transpose_value_fp8_hnd_gfx12(Tensor value); + +Tensor transpose_value_fp8_scaled_hnd_gfx12(Tensor value, Tensor value_scale); + +Tensor transpose_value_f16_hnd_gfx12(Tensor value); + +Tensor convert_f16_to_bf16_gfx12(Tensor input); diff --git a/csrc/qattn/pybind_gfx12_native.cpp b/csrc/qattn/pybind_gfx12_native.cpp new file mode 100644 index 00000000..18dfc20e --- /dev/null +++ b/csrc/qattn/pybind_gfx12_native.cpp @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0. 
+ */ + +#include +#include + +#include "attn_gfx12_native.h" + +PyMODINIT_FUNC PyInit__qattn_gfx12_native(void) +{ + static struct PyModuleDef module_def = { + PyModuleDef_HEAD_INIT, + "_qattn_gfx12_native", + NULL, + -1, + NULL, + }; + return PyModule_Create(&module_def); +} + +STABLE_TORCH_LIBRARY(sageattention_qattn_gfx12_native, m) { + m.def("qk_int8_sv_f16_d64_native_attn(" + "Tensor query, " + "Tensor key, " + "Tensor value, " + "Tensor(a!) output, " + "Tensor query_scale, " + "Tensor key_scale, " + "int tensor_layout, " + "int is_causal, " + "float sm_scale, " + "int valid_kv_len" + ") -> Tensor"); + m.def("qk_rawq_int8_sv_f8_native_attn(" + "Tensor query, " + "Tensor key, " + "Tensor value, " + "Tensor(a!) output, " + "Tensor key_scale, " + "int tensor_layout, " + "int is_causal, " + "float sm_scale, " + "int valid_kv_len" + ") -> Tensor"); + m.def("qk_int8_sv_f8_scaled_native_attn(" + "Tensor query, " + "Tensor key, " + "Tensor value, " + "Tensor(a!) output, " + "Tensor query_scale, " + "Tensor key_scale, " + "Tensor value_scale, " + "int tensor_layout, " + "int is_causal, " + "float sm_scale, " + "int valid_kv_len" + ") -> Tensor"); + m.def("qk_rawq_int8_sv_f8_scaled_native_attn(" + "Tensor query, " + "Tensor key, " + "Tensor value, " + "Tensor(a!) 
output, " + "Tensor key_scale, " + "Tensor value_scale, " + "int tensor_layout, " + "int is_causal, " + "float sm_scale, " + "int valid_kv_len" + ") -> Tensor"); + m.def("qk_int8_sv_f16_d64_prepare_attn_hnd(" + "Tensor query, " + "Tensor key, " + "Tensor value, " + "int is_causal, " + "int value_is_fp8, " + "int use_raw_f16_value, " + "float sm_scale, " + "int valid_kv_len" + ") -> Tensor"); + m.def("quant_q_nhd_per_warp(Tensor query) -> Tensor[]"); + m.def("transpose_value_fp8_hnd(Tensor value) -> Tensor"); + m.def("transpose_value_fp8_scaled_hnd(Tensor value, Tensor value_scale) -> Tensor"); + m.def("transpose_value_f16_hnd(Tensor value) -> Tensor"); + m.def("convert_f16_to_bf16(Tensor input) -> Tensor"); +} + +STABLE_TORCH_LIBRARY_IMPL(sageattention_qattn_gfx12_native, CUDA, m) { + m.impl("qk_int8_sv_f16_d64_native_attn", TORCH_BOX(qk_int8_sv_f16_d64_native_attn_gfx12)); + m.impl("qk_rawq_int8_sv_f8_native_attn", TORCH_BOX(qk_rawq_int8_sv_f8_native_attn_gfx12)); + m.impl("qk_int8_sv_f8_scaled_native_attn", TORCH_BOX(qk_int8_sv_f8_scaled_native_attn_gfx12)); + m.impl("qk_rawq_int8_sv_f8_scaled_native_attn", TORCH_BOX(qk_rawq_int8_sv_f8_scaled_native_attn_gfx12)); + m.impl("qk_int8_sv_f16_d64_prepare_attn_hnd", TORCH_BOX(qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12)); + m.impl("quant_q_nhd_per_warp", TORCH_BOX(quant_q_nhd_per_warp_gfx12)); + m.impl("transpose_value_fp8_hnd", TORCH_BOX(transpose_value_fp8_hnd_gfx12)); + m.impl("transpose_value_fp8_scaled_hnd", TORCH_BOX(transpose_value_fp8_scaled_hnd_gfx12)); + m.impl("transpose_value_f16_hnd", TORCH_BOX(transpose_value_f16_hnd_gfx12)); + m.impl("convert_f16_to_bf16", TORCH_BOX(convert_f16_to_bf16_gfx12)); +} diff --git a/csrc/qattn/qk_int_sv_gfx12_native.cu b/csrc/qattn/qk_int_sv_gfx12_native.cu new file mode 100644 index 00000000..5011e264 --- /dev/null +++ b/csrc/qattn/qk_int_sv_gfx12_native.cu @@ -0,0 +1,7446 @@ +/* + * Copyright (c) 2024 by SageAttention team. + * Copyright (c) 2026 Advanced Micro Devices, Inc. 
+ * + * Licensed under the Apache License, Version 2.0. + */ + +#include +#include + +#include "../torch_version.h" +#if TORCH_FEATURE_VERSION >= TORCH_VERSION_2_10_0 +#include +#endif + +#include +#include + +#if defined(__HIP_PLATFORM_AMD__) +#include +#include +#include +#include +#include +#include +#else +#error "qk_int_sv_gfx12_native.cu is only intended for ROCm/HIP." +#endif + +#include "../reduction_utils.cuh" + +#include +#include +#include +#include +#include +#include + +using torch::stable::Tensor; +using ScalarType = torch::headeronly::ScalarType; + +namespace { + +constexpr int kNHD = 0; +constexpr int kHND = 1; +constexpr float kLog2e = 1.4426950408889634f; +constexpr float kFp8SoftmaxOffset = 8.807f; +constexpr float kF16SoftmaxOffset = 0.0f; + +Tensor new_empty_like(const Tensor& like, std::initializer_list sizes, ScalarType dtype) { + return torch::stable::new_empty(like, std::vector(sizes), std::make_optional(dtype)); +} + +Tensor new_empty_like( + const Tensor& like, + torch::headeronly::IntHeaderOnlyArrayRef sizes, + ScalarType dtype) { + return torch::stable::new_empty(like, sizes, std::make_optional(dtype)); +} + +std::vector contiguous_strides(const std::vector& sizes) { + std::vector strides(sizes.size()); + int64_t stride = 1; + for (int64_t idx = static_cast(sizes.size()) - 1; idx >= 0; --idx) { + strides[idx] = stride; + stride *= sizes[idx]; + } + return strides; +} + +Tensor from_blob_like( + void* data, + std::initializer_list sizes, + const Tensor& like, + ScalarType dtype) { + std::vector shape(sizes); + std::vector strides = contiguous_strides(shape); + return torch::stable::from_blob(data, shape, strides, like.device(), dtype); +} + +#define SAGEATTN_NATIVE_HAS_GFX12_WMMA 1 +#ifndef SAGEATTN_GFX12_NATIVE_WAVES_PER_EU_MAX +#define SAGEATTN_GFX12_NATIVE_WAVES_PER_EU_MAX 1 +#endif +#ifndef SAGEATTN_GFX12_NATIVE_D64_2Q_WAVES_PER_EU_MAX +#define SAGEATTN_GFX12_NATIVE_D64_2Q_WAVES_PER_EU_MAX 2 +#endif +#ifndef 
SAGEATTN_GFX12_NATIVE_D64_2Q_CAUSAL_WAVES_PER_EU_MAX +#define SAGEATTN_GFX12_NATIVE_D64_2Q_CAUSAL_WAVES_PER_EU_MAX 2 +#endif +#ifndef SAGEATTN_GFX12_NATIVE_D64_2Q_CAUSAL_WAVES_PER_EU_MIN +#define SAGEATTN_GFX12_NATIVE_D64_2Q_CAUSAL_WAVES_PER_EU_MIN 1 +#endif +#ifndef SAGEATTN_GFX12_NATIVE_D16_2Q_WAVES_PER_EU_MAX +#define SAGEATTN_GFX12_NATIVE_D16_2Q_WAVES_PER_EU_MAX 2 +#endif +#ifndef SAGEATTN_GFX12_NATIVE_D128_2Q_WAVES_PER_EU_MAX +#define SAGEATTN_GFX12_NATIVE_D128_2Q_WAVES_PER_EU_MAX 1 +#endif +#ifndef SAGEATTN_GFX12_NATIVE_F16_TV_PAD +#define SAGEATTN_GFX12_NATIVE_F16_TV_PAD 16 +#endif +#ifndef SAGEATTN_GFX12_NATIVE_F16_2Q_MIN_BLOCKS +#define SAGEATTN_GFX12_NATIVE_F16_2Q_MIN_BLOCKS 1 +#endif +#define SAGEATTN_NATIVE_WAVES_PER_EU \ + __attribute__((amdgpu_waves_per_eu(1, SAGEATTN_GFX12_NATIVE_WAVES_PER_EU_MAX))) +#define SAGEATTN_NATIVE_2Q_WAVES_PER_EU(HD_, CAUSAL_) \ + __attribute__((amdgpu_waves_per_eu( \ + ((HD_) == 64 && (CAUSAL_) ? SAGEATTN_GFX12_NATIVE_D64_2Q_CAUSAL_WAVES_PER_EU_MIN : 1), \ + ((HD_) == 16 ? SAGEATTN_GFX12_NATIVE_D16_2Q_WAVES_PER_EU_MAX : \ + ((HD_) == 64 ? ((CAUSAL_) ? 
SAGEATTN_GFX12_NATIVE_D64_2Q_CAUSAL_WAVES_PER_EU_MAX \ + : SAGEATTN_GFX12_NATIVE_D64_2Q_WAVES_PER_EU_MAX) \ + : SAGEATTN_GFX12_NATIVE_D128_2Q_WAVES_PER_EU_MAX))))) +#define SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BR_) \ + __launch_bounds__(BR_, SAGEATTN_GFX12_NATIVE_F16_2Q_MIN_BLOCKS) +#define SAGEATTN_F16_SCHED_BARRIER(MASK_) ((void)0) + +using half8_vec = _Float16 __attribute__((ext_vector_type(8))); +using float8_vec = float __attribute__((ext_vector_type(8))); +using i16x8_vec = int16_t __attribute__((ext_vector_type(8))); +using i32x2_vec = int32_t __attribute__((ext_vector_type(2))); +using u32x4_vec = uint32_t __attribute__((ext_vector_type(4))); +using i32x8_vec = int32_t __attribute__((ext_vector_type(8))); + +void hip_kernel_launch_check() { + const hipError_t err = hipGetLastError(); + STD_TORCH_CHECK(err == hipSuccess, "HIP kernel launch failed: ", hipGetErrorString(err)); +} + +__device__ __forceinline__ float value_to_float(const __half value) { + return __half2float(value); +} + +__device__ __forceinline__ float value_to_float(const __hip_bfloat16 value) { + return __bfloat162float(value); +} + +__device__ __forceinline__ int8_t float_to_int8_rn_gfx12(const float x) { + int32_t rounded; + asm volatile("v_cvt_i32_f32 %[dst], %[src]" + : [dst] "=v"(rounded) + : [src] "v"(x)); + rounded = rounded > 127 ? 127 : rounded; + rounded = rounded < -128 ? 
-128 : rounded; + return static_cast(rounded); +} + +__device__ __forceinline__ int8_t float_to_int8_nearby_gfx12(const float x) { + const float clipped = fminf(127.0f, fmaxf(-128.0f, nearbyintf(x))); + return static_cast(clipped); +} + +template +__global__ void quant_qk_int8_hnd_kernel( + const T* __restrict__ query, + const T* __restrict__ key, + int8_t* __restrict__ query_out, + int8_t* __restrict__ key_out, + float* __restrict__ query_scale, + float* __restrict__ key_scale, + const int64_t batch, + const int64_t q_heads, + const int64_t kv_heads, + const int64_t q_len, + const int64_t kv_len, + const int q_groups, + const int k_groups) { + constexpr int Threads = 256; + __shared__ float shared_amax; + + const int group = blockIdx.x; + const int head = blockIdx.y; + const int b = blockIdx.z; + const int tid = threadIdx.x; + const bool is_q = group < q_groups; + const int local_group = is_q ? group : group - q_groups; + const int rows_per_group = is_q ? 32 : 64; + const int64_t seq_len = is_q ? q_len : kv_len; + const int64_t base_row = static_cast(local_group) * rows_per_group; + const int active_heads = is_q ? static_cast(q_heads) : static_cast(kv_heads); + if (b >= batch || head >= active_heads || base_row >= seq_len) { + return; + } + + const T* in = is_q ? query : key; + int8_t* out = is_q ? query_out : key_out; + float* scale_out = is_q ? query_scale : key_scale; + const int64_t heads = is_q ? q_heads : kv_heads; + const int scale_groups = is_q ? 
q_groups : k_groups; + constexpr int PackElems = 8; + static_assert((HeadDim % PackElems) == 0, "native quantization packs eight elements"); + const int packs = (rows_per_group * HeadDim) / PackElems; + + float local_amax = 0.0000001f; + for (int pack = tid; pack < packs; pack += Threads) { + const int elem_base = pack * PackElems; + const int row = elem_base / HeadDim; + const int d = elem_base - row * HeadDim; + const int64_t seq = base_row + row; + if (seq < seq_len) { + const int64_t off = ((static_cast(b) * heads + head) * seq_len + seq) * HeadDim + d; + const uint4 raw = *reinterpret_cast(in + off); + const T* values = reinterpret_cast(&raw); +#pragma unroll + for (int i = 0; i < PackElems; ++i) { + local_amax = fmaxf(local_amax, fabsf(value_to_float(values[i]))); + } + } + } + const float block_amax = vllm::blockReduceMax(local_amax); + if (tid == 0) { + shared_amax = block_amax; + scale_out[(static_cast(b) * active_heads + head) * scale_groups + local_group] = + shared_amax / 127.0f; + } + __syncthreads(); + const float inv_scale = 127.0f / shared_amax; + + for (int pack = tid; pack < packs; pack += Threads) { + const int elem_base = pack * PackElems; + const int row = elem_base / HeadDim; + const int d = elem_base - row * HeadDim; + const int64_t seq = base_row + row; + if (seq < seq_len) { + const int64_t off = ((static_cast(b) * heads + head) * seq_len + seq) * HeadDim + d; + const uint4 raw = *reinterpret_cast(in + off); + const T* values = reinterpret_cast(&raw); + char4 out0; + char4 out1; + out0.x = float_to_int8_rn_gfx12(value_to_float(values[0]) * inv_scale); + out0.y = float_to_int8_rn_gfx12(value_to_float(values[1]) * inv_scale); + out0.z = float_to_int8_rn_gfx12(value_to_float(values[2]) * inv_scale); + out0.w = float_to_int8_rn_gfx12(value_to_float(values[3]) * inv_scale); + out1.x = float_to_int8_rn_gfx12(value_to_float(values[4]) * inv_scale); + out1.y = float_to_int8_rn_gfx12(value_to_float(values[5]) * inv_scale); + out1.z = 
float_to_int8_rn_gfx12(value_to_float(values[6]) * inv_scale); + out1.w = float_to_int8_rn_gfx12(value_to_float(values[7]) * inv_scale); + *reinterpret_cast(out + off) = out0; + *reinterpret_cast(out + off + 4) = out1; + } + } +} + +template +__global__ void quant_q_nhd_per_warp_kernel( + const T* __restrict__ query, + int8_t* __restrict__ query_out, + float* __restrict__ query_scale, + const int64_t batch, + const int64_t q_len, + const int64_t q_heads, + const int q_scale_groups) { + constexpr int Threads = 256; + constexpr int PackElems = 8; + constexpr int QRows = 32; + constexpr int GroupsPerBlock = 2; + static_assert((HeadDim % PackElems) == 0, "native Q quantization packs eight elements"); + + __shared__ float shared_amax[GroupsPerBlock]; + __shared__ float shared_pair_amax[GroupsPerBlock][32]; + + const int group_base = static_cast(blockIdx.x) * GroupsPerBlock; + const int head = blockIdx.y; + const int b = blockIdx.z; + const int tid = threadIdx.x; + const int lane = tid & 31; + const int wid = tid >> 5; + if (b >= batch || head >= q_heads || + static_cast(group_base) * QRows >= q_len) { + return; + } + + constexpr int Packs = (QRows * HeadDim) / PackElems; + float local_amax0 = 0.0000001f; + float local_amax1 = 0.0000001f; + const bool has_group1 = (group_base + 1) < q_scale_groups; + const int64_t base_row0 = static_cast(group_base) * QRows; + const int64_t base_row1 = base_row0 + QRows; + + for (int pack = tid; pack < Packs; pack += Threads) { + const int elem_base = pack * PackElems; + const int row = elem_base / HeadDim; + const int d = elem_base - row * HeadDim; + const int64_t seq0 = base_row0 + row; + if (seq0 < q_len) { + const int64_t off = + ((static_cast(b) * q_len + seq0) * q_heads + head) * HeadDim + d; + const uint4 raw = *reinterpret_cast(query + off); + const T* values = reinterpret_cast(&raw); +#pragma unroll + for (int i = 0; i < PackElems; ++i) { + local_amax0 = fmaxf(local_amax0, fabsf(value_to_float(values[i]))); + } + } + const 
int64_t seq1 = base_row1 + row; + if (has_group1 && seq1 < q_len) { + const int64_t off = + ((static_cast(b) * q_len + seq1) * q_heads + head) * HeadDim + d; + const uint4 raw = *reinterpret_cast(query + off); + const T* values = reinterpret_cast(&raw); +#pragma unroll + for (int i = 0; i < PackElems; ++i) { + local_amax1 = fmaxf(local_amax1, fabsf(value_to_float(values[i]))); + } + } + } + + local_amax0 = vllm::warpReduceMax(local_amax0); + local_amax1 = vllm::warpReduceMax(local_amax1); + if (lane == 0) { + shared_pair_amax[0][wid] = local_amax0; + shared_pair_amax[1][wid] = local_amax1; + } + __syncthreads(); + + const bool warp_lane_active = tid < (blockDim.x / 32); + local_amax0 = warp_lane_active ? shared_pair_amax[0][lane] : -1e20f; + local_amax1 = warp_lane_active ? shared_pair_amax[1][lane] : -1e20f; + local_amax0 = vllm::warpReduceMax(local_amax0); + local_amax1 = vllm::warpReduceMax(local_amax1); + if (tid == 0) { + shared_amax[0] = local_amax0; + query_scale[(static_cast(b) * q_heads + head) * q_scale_groups + + group_base] = local_amax0 / 127.0f; + if (has_group1) { + shared_amax[1] = local_amax1; + query_scale[(static_cast(b) * q_heads + head) * q_scale_groups + + group_base + 1] = local_amax1 / 127.0f; + } + } + __syncthreads(); + + const float inv_scale0 = 127.0f / shared_amax[0]; + const float inv_scale1 = has_group1 ? 
(127.0f / shared_amax[1]) : 0.0f; + for (int pack = tid; pack < Packs; pack += Threads) { + const int elem_base = pack * PackElems; + const int row = elem_base / HeadDim; + const int d = elem_base - row * HeadDim; + const int64_t seq0 = base_row0 + row; + if (seq0 < q_len) { + const int64_t off = + ((static_cast(b) * q_len + seq0) * q_heads + head) * HeadDim + d; + const uint4 raw = *reinterpret_cast(query + off); + const T* values = reinterpret_cast(&raw); + char4 out0; + char4 out1; + out0.x = float_to_int8_nearby_gfx12(value_to_float(values[0]) * inv_scale0); + out0.y = float_to_int8_nearby_gfx12(value_to_float(values[1]) * inv_scale0); + out0.z = float_to_int8_nearby_gfx12(value_to_float(values[2]) * inv_scale0); + out0.w = float_to_int8_nearby_gfx12(value_to_float(values[3]) * inv_scale0); + out1.x = float_to_int8_nearby_gfx12(value_to_float(values[4]) * inv_scale0); + out1.y = float_to_int8_nearby_gfx12(value_to_float(values[5]) * inv_scale0); + out1.z = float_to_int8_nearby_gfx12(value_to_float(values[6]) * inv_scale0); + out1.w = float_to_int8_nearby_gfx12(value_to_float(values[7]) * inv_scale0); + *reinterpret_cast(query_out + off) = out0; + *reinterpret_cast(query_out + off + 4) = out1; + } + const int64_t seq1 = base_row1 + row; + if (has_group1 && seq1 < q_len) { + const int64_t off = + ((static_cast(b) * q_len + seq1) * q_heads + head) * HeadDim + d; + const uint4 raw = *reinterpret_cast(query + off); + const T* values = reinterpret_cast(&raw); + char4 out0; + char4 out1; + out0.x = float_to_int8_nearby_gfx12(value_to_float(values[0]) * inv_scale1); + out0.y = float_to_int8_nearby_gfx12(value_to_float(values[1]) * inv_scale1); + out0.z = float_to_int8_nearby_gfx12(value_to_float(values[2]) * inv_scale1); + out0.w = float_to_int8_nearby_gfx12(value_to_float(values[3]) * inv_scale1); + out1.x = float_to_int8_nearby_gfx12(value_to_float(values[4]) * inv_scale1); + out1.y = float_to_int8_nearby_gfx12(value_to_float(values[5]) * inv_scale1); + out1.z = 
float_to_int8_nearby_gfx12(value_to_float(values[6]) * inv_scale1); + out1.w = float_to_int8_nearby_gfx12(value_to_float(values[7]) * inv_scale1); + *reinterpret_cast(query_out + off) = out0; + *reinterpret_cast(query_out + off + 4) = out1; + } + } +} + +template +__global__ void transpose_value_hnd_kernel( + const T* __restrict__ value, + OutT* __restrict__ output, + const int64_t total_heads, + const int64_t seq_len, + const int64_t head_dim) { + constexpr int TileS = 128; + constexpr int TileD = 16; + __shared__ OutT tile[TileS][TileD]; + + const int tid = threadIdx.x; + const int64_t bh = blockIdx.z; + + for (int linear = tid; linear < TileS * TileD; linear += blockDim.x) { + const int load_s = linear / TileD; + const int load_d = linear - load_s * TileD; + const int64_t s = static_cast(blockIdx.x) * TileS + load_s; + const int64_t d = static_cast(blockIdx.y) * TileD + load_d; + if (bh < total_heads && s < seq_len && d < head_dim) { + const float v = value_to_float(value[(bh * seq_len + s) * head_dim + d]); + if constexpr (ToFp8) { + tile[load_s][load_d] = + __hip_cvt_float_to_fp8(v, __HIP_SATFINITE, __HIP_E4M3); + } else { + tile[load_s][load_d] = __float2half_rn(v); + } + } + } + __syncthreads(); + + for (int linear = tid; linear < TileS * TileD; linear += blockDim.x) { + const int store_d_local = linear / TileS; + const int store_s_local = linear - store_d_local * TileS; + const int64_t store_s = static_cast(blockIdx.x) * TileS + store_s_local; + const int64_t store_d = static_cast(blockIdx.y) * TileD + store_d_local; + if (bh < total_heads && store_s < seq_len && store_d < head_dim) { + output[(bh * head_dim + store_d) * seq_len + store_s] = + tile[store_s_local][store_d_local]; + } + } +} + +template +Tensor transpose_value_hnd_gfx12(Tensor value) { + STD_TORCH_CHECK(value.is_cuda(), "gfx12 value transpose expects a CUDA/HIP tensor"); + STD_TORCH_CHECK(value.dim() == 4, "gfx12 value transpose expects [B, H, S, D]"); + STD_TORCH_CHECK(value.is_contiguous(), 
"gfx12 value transpose expects contiguous HND input"); + STD_TORCH_CHECK(value.scalar_type() == ScalarType::Half || value.scalar_type() == ScalarType::BFloat16, + "gfx12 value transpose supports fp16/bf16 input"); + + const int64_t batch = value.size(0); + const int64_t heads = value.size(1); + const int64_t seq_len = value.size(2); + const int64_t head_dim = value.size(3); + const ScalarType out_dtype = ToFp8 ? ScalarType::Byte : ScalarType::Half; + Tensor output = new_empty_like(value, {batch, heads, head_dim, seq_len}, out_dtype); + + dim3 block(256); + dim3 grid((seq_len + 127) / 128, (head_dim + 15) / 16, batch * heads); + if (value.scalar_type() == ScalarType::Half) { + transpose_value_hnd_kernel<__half, OutT, ToFp8><<>>( + reinterpret_cast(value.data_ptr()), + reinterpret_cast(output.data_ptr()), + batch * heads, seq_len, head_dim); + } else { + transpose_value_hnd_kernel<__hip_bfloat16, OutT, ToFp8><<>>( + reinterpret_cast(value.data_ptr()), + reinterpret_cast(output.data_ptr()), + batch * heads, seq_len, head_dim); + } + hip_kernel_launch_check(); + return output; +} + +template +__global__ void transpose_value_fp8_scaled_hnd_kernel( + const T* __restrict__ value, + const float* __restrict__ value_scale, + uint8_t* __restrict__ output, + const int64_t total_heads, + const int64_t seq_len, + const int64_t head_dim) { + constexpr int TileS = 128; + constexpr int TileD = 16; + __shared__ uint8_t tile[TileS][TileD]; + + const int tid = threadIdx.x; + const int64_t bh = blockIdx.z; + + for (int linear = tid; linear < TileS * TileD; linear += blockDim.x) { + const int load_s = linear / TileD; + const int load_d = linear - load_s * TileD; + const int64_t s = static_cast(blockIdx.x) * TileS + load_s; + const int64_t d = static_cast(blockIdx.y) * TileD + load_d; + if (bh < total_heads && s < seq_len && d < head_dim) { + const float scale = value_scale[bh * head_dim + d]; + const float v = scale == 0.0f ? 
0.0f : + value_to_float(value[(bh * seq_len + s) * head_dim + d]) / scale; + tile[load_s][load_d] = + __hip_cvt_float_to_fp8(v, __HIP_SATFINITE, __HIP_E4M3); + } + } + __syncthreads(); + + for (int linear = tid; linear < TileS * TileD; linear += blockDim.x) { + const int store_d_local = linear / TileS; + const int store_s_local = linear - store_d_local * TileS; + const int64_t store_s = static_cast(blockIdx.x) * TileS + store_s_local; + const int64_t store_d = static_cast(blockIdx.y) * TileD + store_d_local; + if (bh < total_heads && store_s < seq_len && store_d < head_dim) { + output[(bh * head_dim + store_d) * seq_len + store_s] = + tile[store_s_local][store_d_local]; + } + } +} + +__device__ __forceinline__ int64_t qkv_offset( + const int tensor_layout, + const int64_t b, + const int64_t h, + const int64_t n, + const int64_t d, + const int64_t stride_b, + const int64_t stride_n, + const int64_t stride_h) { + return tensor_layout == kNHD + ? b * stride_b + n * stride_n + h * stride_h + d + : b * stride_b + h * stride_h + n * stride_n + d; +} + +template +__device__ __forceinline__ int64_t qkv_offset_dispatch( + const int tensor_layout, + const int64_t b, + const int64_t h, + const int64_t n, + const int64_t d, + const int64_t stride_b, + const int64_t stride_n, + const int64_t stride_h) { + if constexpr (HndContiguous) { + return b * stride_b + h * stride_h + n * HeadDim + d; + } else { + return qkv_offset(tensor_layout, b, h, n, d, stride_b, stride_n, stride_h); + } +} + +__device__ __forceinline__ int q_scale_col_per_warp(const int64_t q_idx) { + return static_cast((q_idx / 128) * 4 + ((q_idx & 127) / 32)); +} + +__device__ __forceinline__ int k_scale_col_per_warp(const int64_t k_idx) { + return static_cast(k_idx / 64); +} + +template +__device__ __forceinline__ int64_t q_block_base_for_launch( + const int64_t block_x, + const int64_t qo_len) { + if constexpr (IsCausal) { + const int64_t q_blocks = (qo_len + BlockRows - 1) / BlockRows; + return (q_blocks - 1 - 
block_x) * BlockRows; + } else { + return block_x * BlockRows; + } +} + +__device__ __forceinline__ float fast_exp2(float x) { + return __builtin_amdgcn_exp2f(x); +} + +template +__device__ __forceinline__ void apply_tqk_causal_mask( + float8_vec& scores, + const int q_start, + const int kb_base, + const int col_tile, + const int lane) { + if constexpr (IsCausal) { + const int tile_end = kb_base + col_tile * 16 + 15; + if (tile_end <= q_start) { + return; + } + const int q_idx = q_start + (lane & 15); + const int k_base = kb_base + col_tile * 16 + ((lane >> 4) << 3); +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + scores[elem] = (k_base + elem) > q_idx ? -FLT_MAX * 0.5f : scores[elem]; + } + } +} + +template +__device__ __forceinline__ int active_causal_col_tiles( + const int64_t q_start, + const int64_t kb_base) { + if constexpr (!ApplyCausalMask) { + return ColTiles; + } else { + constexpr int BK = 16; + constexpr int RM = 16; + const int64_t q_end = q_start + RM; + if (q_end <= kb_base) { + return 0; + } + const int64_t cols = (q_end - kb_base + BK - 1) / BK; + return static_cast(cols < ColTiles ? cols : ColTiles); + } +} + +__device__ __forceinline__ int wmma_f16_k_for_lane_elem( + const int lane, + const int elem) { + const int reg = elem >> 1; + const int half = elem & 1; + return ((reg >> 1) << 3) + (((lane >> 4) & 1) << 2) + ((reg & 1) << 1) + half; +} + +template +__device__ __forceinline__ void apply_tqk_kv_tail_mask( + float8_vec& scores, + const int64_t valid_kv_len, + const int64_t kb_base, + const int col_tile, + const int lane) { + const int64_t col_base = kb_base + static_cast(col_tile) * 16; + if (col_base + 15 < valid_kv_len) { + return; + } +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + const int k_local = PvOrdered ? + wmma_f16_k_for_lane_elem(lane, elem) : + (((lane >> 4) << 3) + elem); + scores[elem] = (col_base + k_local) >= valid_kv_len ? 
+ -FLT_MAX * 0.5f : scores[elem]; + } +} + +__device__ __forceinline__ int gfx12_tr_b128_source_row_for_lane(const int lane) { + const int quad_group = lane >> 2; + const int quad_pos = lane & 3; + return quad_pos + ((quad_group & 1) << 3) + ((lane >> 4) << 2); +} + +__device__ __forceinline__ int gfx12_tr_b128_source_col_for_lane(const int lane) { + return ((lane >> 3) & 1) << 3; +} + +__device__ __forceinline__ half8_vec gfx12_global_load_tr_b128_f16(const __half* ptr) { +#if defined(__gfx1200__) || defined(__gfx1201__) + const i16x8_vec bits = __builtin_amdgcn_global_load_tr_b128_v8i16( + reinterpret_cast(const_cast<__half*>(ptr))); + return *reinterpret_cast(&bits); +#else + half8_vec regs; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + regs[elem] = static_cast<_Float16>(ptr[elem]); + } + return regs; +#endif +} + +__device__ __forceinline__ int pv_k_order_for_acc_row(const int row) { + return (row & 3) | ((row & 4) << 1) | ((row & 8) >> 1); +} + +template +__device__ __forceinline__ void apply_tqk_causal_mask_pv_order( + float8_vec& scores, + const int q_start, + const int kb_base, + const int col_tile, + const int lane) { + if constexpr (IsCausal) { + const int tile_end = kb_base + col_tile * 16 + 15; + if (tile_end <= q_start) { + return; + } + const int q_idx = q_start + (lane & 15); +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + const int k_idx = kb_base + col_tile * 16 + wmma_f16_k_for_lane_elem(lane, elem); + scores[elem] = k_idx > q_idx ? 
-FLT_MAX * 0.5f : scores[elem]; + } + } +} + +template +__device__ __forceinline__ i32x2_vec pack_quant_q_i8_wmma_b_regs( + const QueryT* __restrict__ q, + const int tensor_layout, + const int lane, + const int64_t b, + const int64_t h, + const int64_t q_start, + const int64_t qo_len, + const int d_base, + const int64_t q_stride_b, + const int64_t q_stride_n, + const int64_t q_stride_h, + const float inv_q_scale) { + i32x2_vec regs; + const int row = lane & 15; + const int k_base = 8 * (lane >> 4); + const int64_t q_idx = q_start + row; + if (q_idx >= qo_len) { +#pragma unroll + for (int gpr = 0; gpr < 2; ++gpr) { + regs[gpr] = 0; + } + return regs; + } + + const int d = d_base + k_base; + const int64_t q_off = qkv_offset_dispatch( + tensor_layout, b, h, q_idx, d, q_stride_b, q_stride_n, q_stride_h); + const uint4 raw = *reinterpret_cast(q + q_off); + const QueryT* values = reinterpret_cast(&raw); +#pragma unroll + for (int gpr = 0; gpr < 2; ++gpr) { + uint32_t packed = 0; +#pragma unroll + for (int byte = 0; byte < 4; ++byte) { + packed |= static_cast(static_cast( + float_to_int8_nearby_gfx12(value_to_float(values[4 * gpr + byte]) * inv_q_scale))) + << (8 * byte); + } + regs[gpr] = static_cast(packed); + } + return regs; +} + +template +__device__ __forceinline__ i32x2_vec pack_q_i8_wmma_b_regs( + const int8_t* __restrict__ q, + const int tensor_layout, + const int lane, + const int64_t b, + const int64_t h, + const int64_t q_start, + const int64_t qo_len, + const int d_base, + const int64_t q_stride_b, + const int64_t q_stride_n, + const int64_t q_stride_h) { + i32x2_vec regs; + const int row = lane & 15; + const int k_base = 8 * (lane >> 4); + const int64_t q_idx = q_start + row; + if (q_idx >= qo_len) { +#pragma unroll + for (int gpr = 0; gpr < 2; ++gpr) { + regs[gpr] = 0; + } + return regs; + } + + const int d = d_base + k_base; + const int64_t q_off = qkv_offset_dispatch( + tensor_layout, b, h, q_idx, d, q_stride_b, q_stride_n, q_stride_h); + const uint2 raw 
= *reinterpret_cast(q + q_off); + regs[0] = static_cast(raw.x); + regs[1] = static_cast(raw.y); + return regs; +} + +template +__device__ __forceinline__ i32x2_vec pack_k_i8_wmma_b_regs_from_shared( + const int8_t* __restrict__ k_tile, + const int lane, + const int col_tile, + const int d_base) { + i32x2_vec regs; + const int col = lane & 15; + const int k_base = 8 * (lane >> 4); + const int row = col_tile * 16 + col; +#pragma unroll + for (int gpr = 0; gpr < 2; ++gpr) { + const int d = d_base + k_base + 4 * gpr; + regs[gpr] = *reinterpret_cast( + k_tile + row * SharedHeadStride + d); + } + return regs; +} + +template +__device__ __forceinline__ i32x2_vec pack_k_i8_wmma_b_regs_from_shared_pv_order( + const int8_t* __restrict__ k_tile, + const int lane, + const int col_tile, + const int d_base) { + i32x2_vec regs; + const int col = lane & 15; + const int k_base = 8 * (lane >> 4); + const int row = col_tile * 16 + pv_k_order_for_acc_row(col); +#pragma unroll + for (int gpr = 0; gpr < 2; ++gpr) { + const int d = d_base + k_base + 4 * gpr; + regs[gpr] = *reinterpret_cast( + k_tile + row * SharedHeadStride + d); + } + return regs; +} + +template +__device__ __forceinline__ i32x2_vec pack_k_i8_wmma_b_regs_from_lane_major_global( + const int8_t* __restrict__ k, + const int64_t k_head_base, + const int64_t k_group_stride, + const int64_t kb_base, + const int col_tile, + const int d_tile, + const int lane) { + const int64_t group = kb_base >> 6; + const int group_col_tile = ((static_cast(kb_base) & 63) >> 4) + col_tile; + const int64_t off = k_head_base + group * k_group_stride + + (((static_cast(group_col_tile) * DTiles + d_tile) * 32 + lane) * 8); + const uint2 raw = *reinterpret_cast(k + off); + i32x2_vec regs; + regs[0] = static_cast(raw.x); + regs[1] = static_cast(raw.y); + return regs; +} + +template +__device__ __forceinline__ i32x2_vec pack_k_i8_wmma_b_regs_from_lane_major_shared( + const uint2* __restrict__ k_lane_tile, + const int col_tile, + const int d_tile, + 
const int lane) { + const uint2 raw = k_lane_tile[(col_tile * DTiles + d_tile) * 32 + lane]; + i32x2_vec regs; + regs[0] = static_cast(raw.x); + regs[1] = static_cast(raw.y); + return regs; +} + +__device__ __forceinline__ int wmma_fp8_k_for_lane_byte( + const int lane, + const int gpr, + const int byte) { + return 8 * (lane >> 4) + 4 * gpr + byte; +} + +__device__ __forceinline__ int32_t pack_f32x4_to_ocp_fp8( + const float x0, + const float x1, + const float x2, + const float x3) { + float a0 = x0; + float a1 = x1; + float a2 = x2; + float a3 = x3; + uint32_t packed; + asm volatile("v_cvt_pk_fp8_f32 %[dst], %[a0], %[a1]\n" + "v_cvt_pk_fp8_f32 %[dst], %[a2], %[a3], op_sel:[0, 0, 1]\n" + : [dst] "=v"(packed), [a0] "+v"(a0), [a1] "+v"(a1), + [a2] "+v"(a2), [a3] "+v"(a3)); + return static_cast(packed); +} + +__device__ __forceinline__ uint16_t half_to_u16_bits(const _Float16 value) { + return __builtin_bit_cast(uint16_t, value); +} + +__device__ __forceinline__ _Float16 u16_bits_to_half(const uint16_t value) { + return __builtin_bit_cast(_Float16, value); +} + +__device__ __forceinline__ i32x2_vec make_p_fp8_regs_from_tqk_prob_regs( + const float8_vec prob_values, + const int lane) { + (void)lane; + i32x2_vec regs; + regs[0] = pack_f32x4_to_ocp_fp8( + prob_values[0], prob_values[1], prob_values[2], prob_values[3]); + regs[1] = pack_f32x4_to_ocp_fp8( + prob_values[4], prob_values[5], prob_values[6], prob_values[7]); + return regs; +} + +__device__ __forceinline__ half8_vec make_p_regs_from_tqk_prob_regs( + const half8_vec prob_values, + const int lane) { + half8_vec regs; + const bool lane_upper = lane >= 16; + u32x4_vec local_values; + u32x4_vec peer_values; +#pragma unroll + for (int i = 0; i < 4; ++i) { + const uint32_t lo = half_to_u16_bits(prob_values[2 * i]); + const uint32_t hi = half_to_u16_bits(prob_values[2 * i + 1]); + local_values[i] = lo | (hi << 16); + peer_values[i] = __shfl_xor(local_values[i], 16, 32); + } +#pragma unroll + for (int elem = 0; elem < 
8; ++elem) { + const int p_col = wmma_f16_k_for_lane_elem(lane, elem); + const int source_elem = p_col & 7; + const bool source_upper = p_col >= 8; + const uint32_t packed = source_upper == lane_upper + ? local_values[source_elem >> 1] + : peer_values[source_elem >> 1]; + const uint16_t bits = static_cast(packed >> (16 * (source_elem & 1))); + regs[elem] = u16_bits_to_half(bits); + } + return regs; +} + +template +__device__ __forceinline__ i32x2_vec make_v_fp8_regs_from_shared( + const uint8_t* __restrict__ v_tile, + const int col_tile, + const int d_tile, + const int lane) { + constexpr int BK = 16; + i32x2_vec regs; + const int d = d_tile * BK + (lane & 15); +#pragma unroll + for (int gpr = 0; gpr < 2; ++gpr) { + uint32_t packed = 0; +#pragma unroll + for (int byte = 0; byte < 4; ++byte) { + const int k_local = wmma_fp8_k_for_lane_byte(lane, gpr, byte); + const int row = col_tile * BK + k_local; + packed |= static_cast(v_tile[row * SharedHeadStride + d]) << (8 * byte); + } + regs[gpr] = static_cast(packed); + } + return regs; +} + +template +__device__ __forceinline__ i32x2_vec make_v_fp8_regs_from_transposed_shared( + const uint32_t* __restrict__ v_tile, + const int col_tile, + const int d_tile, + const int lane) { + constexpr int BK = 16; + i32x2_vec regs; + const int d = d_tile * BK + (lane & 15); +#pragma unroll + for (int gpr = 0; gpr < 2; ++gpr) { + const int k_local = wmma_fp8_k_for_lane_byte(lane, gpr, 0); + const int n_group = (col_tile * BK + k_local) >> 2; + regs[gpr] = static_cast(v_tile[d * SharedValueStride32 + n_group]); + } + return regs; +} + +template +__device__ __forceinline__ i32x2_vec make_v_fp8_regs_from_transposed_shared( + const uint8_t* __restrict__ v_tile, + const int col_tile, + const int d_tile, + const int lane) { + constexpr int BK = 16; + i32x2_vec regs; + const int d = d_tile * BK + (lane & 15); +#pragma unroll + for (int gpr = 0; gpr < 2; ++gpr) { + uint32_t packed = 0; +#pragma unroll + for (int byte = 0; byte < 4; ++byte) { + 
const int k_local = wmma_fp8_k_for_lane_byte(lane, gpr, byte); + const int n = col_tile * BK + k_local; + packed |= static_cast(v_tile[d * SharedValueStride + n]) << (8 * byte); + } + regs[gpr] = static_cast(packed); + } + return regs; +} + +template +__device__ __forceinline__ i32x2_vec make_v_fp8_regs_from_lane_major_global( + const uint8_t* __restrict__ v, + const int64_t v_head_base, + const int64_t kb_base, + const int col_tile, + const int d_tile, + const int lane) { + constexpr int ColTiles64 = 4; + const int64_t group = kb_base >> 6; + const int group_col_tile = ((static_cast(kb_base) & 63) >> 4) + col_tile; + const int64_t off = v_head_base + + group * (ColTiles64 * DTiles * 32 * 8) + + (((static_cast(group_col_tile) * DTiles + d_tile) * 32 + lane) * 8); + const uint2 raw = *reinterpret_cast(v + off); + i32x2_vec regs; + regs[0] = static_cast(raw.x); + regs[1] = static_cast(raw.y); + return regs; +} + +template < + int DTiles, + int SharedHeadStride, + typename FragK, + typename FragQ, + typename FragScoreT> +__device__ __forceinline__ float8_vec compute_tqk_score_regs( + const int8_t* __restrict__ k_tile, + const FragQ (&q_frag)[DTiles], + const int col_tile, + const float score_scale) { + constexpr int BK = 16; + FragScoreT score_acc; + rocwmma::fill_fragment(score_acc, 0); +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + FragK k_frag; + rocwmma::load_matrix_sync( + k_frag, + &k_tile[(col_tile * BK) * SharedHeadStride + dt * BK], + static_cast(SharedHeadStride)); + rocwmma::mma_sync(score_acc, k_frag, q_frag[dt], score_acc); + } + + const auto score_rm = rocwmma::apply_data_layout(score_acc); + float8_vec scores; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + scores[elem] = static_cast(score_rm[elem]) * score_scale; + } + return scores; +} + +template +__device__ __forceinline__ float8_vec compute_tqk_score_regs_raw_kq( + const int8_t* __restrict__ k_tile, + const i32x2_vec (&q_regs)[DTiles], + const int lane, + const int col_tile, + 
const float score_scale) { + constexpr int BK = 16; + i32x8_vec score_acc; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + score_acc[elem] = 0; + } +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + const i32x2_vec k_regs = + pack_k_i8_wmma_b_regs_from_shared( + k_tile, lane, col_tile, dt * BK); + score_acc = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, k_regs, true, q_regs[dt], score_acc, true); + } + + float8_vec scores; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + scores[elem] = static_cast(score_acc[elem]) * score_scale; + } + return scores; +} + +template +__device__ __forceinline__ float8_vec compute_tqk_score_regs_raw_kq_one( + const int8_t* __restrict__ k_tile, + const i32x2_vec (&q_regs)[DTiles], + const int lane, + const int col_tile, + const float score_scale) { + constexpr int BK = 16; + i32x8_vec score_acc; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + score_acc[elem] = 0; + } +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + i32x2_vec k_regs; + if constexpr (PvOrderedQK) { + k_regs = pack_k_i8_wmma_b_regs_from_shared_pv_order( + k_tile, lane, col_tile, dt * BK); + } else { + k_regs = pack_k_i8_wmma_b_regs_from_shared( + k_tile, lane, col_tile, dt * BK); + } + score_acc = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, k_regs, true, q_regs[dt], score_acc, true); + } + + float8_vec scores; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + scores[elem] = static_cast(score_acc[elem]) * score_scale; + } + return scores; +} + +template +__device__ __forceinline__ void compute_tqk_score_regs_raw_kq_2( + const int8_t* __restrict__ k_tile, + const i32x2_vec (&q_regs)[2][DTiles], + const int lane, + const int col_tile, + const float score_scale0, + const float score_scale1, + const bool do0, + const bool do1, + float8_vec& scores0, + float8_vec& scores1) { + constexpr int BK = 16; + i32x8_vec score_acc0; + i32x8_vec score_acc1; +#pragma unroll + for (int elem = 0; elem < 8; 
++elem) { + score_acc0[elem] = 0; + score_acc1[elem] = 0; + } +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + i32x2_vec k_regs; + if constexpr (PvOrderedQK) { + k_regs = pack_k_i8_wmma_b_regs_from_shared_pv_order( + k_tile, lane, col_tile, dt * BK); + } else { + k_regs = pack_k_i8_wmma_b_regs_from_shared( + k_tile, lane, col_tile, dt * BK); + } + if (do0) { + score_acc0 = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, k_regs, true, q_regs[0][dt], score_acc0, true); + } + if (do1) { + score_acc1 = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, k_regs, true, q_regs[1][dt], score_acc1, true); + } + } + +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + scores0[elem] = do0 ? static_cast(score_acc0[elem]) * score_scale0 : -FLT_MAX * 0.5f; + scores1[elem] = do1 ? static_cast(score_acc1[elem]) * score_scale1 : -FLT_MAX * 0.5f; + } +} + +template +__device__ __forceinline__ void compute_tqk_score_regs_raw_kq_2_lane_key( + const int8_t* __restrict__ k, + const int64_t k_head_base, + const int64_t k_group_stride, + const int64_t kb_base, + const i32x2_vec (&q_regs)[2][DTiles], + const int lane, + const int col_tile, + const float score_scale0, + const float score_scale1, + const bool do0, + const bool do1, + float8_vec& scores0, + float8_vec& scores1) { + i32x8_vec score_acc0; + i32x8_vec score_acc1; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + score_acc0[elem] = 0; + score_acc1[elem] = 0; + } +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + const i32x2_vec k_regs = + pack_k_i8_wmma_b_regs_from_lane_major_global( + k, k_head_base, k_group_stride, kb_base, col_tile, dt, lane); + if (do0) { + score_acc0 = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, k_regs, true, q_regs[0][dt], score_acc0, true); + } + if (do1) { + score_acc1 = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, k_regs, true, q_regs[1][dt], score_acc1, true); + } + } + +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + 
scores0[elem] = do0 ? static_cast(score_acc0[elem]) * score_scale0 : -FLT_MAX * 0.5f; + scores1[elem] = do1 ? static_cast(score_acc1[elem]) * score_scale1 : -FLT_MAX * 0.5f; + } +} + +template +__device__ __forceinline__ void compute_tqk_score_regs_raw_kq_2_lane_shared_key( + const uint2* __restrict__ k_lane_tile, + const i32x2_vec (&q_regs)[2][DTiles], + const int lane, + const int col_tile, + const float score_scale0, + const float score_scale1, + const bool do0, + const bool do1, + float8_vec& scores0, + float8_vec& scores1) { + i32x8_vec score_acc0; + i32x8_vec score_acc1; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + score_acc0[elem] = 0; + score_acc1[elem] = 0; + } +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + const i32x2_vec k_regs = + pack_k_i8_wmma_b_regs_from_lane_major_shared( + k_lane_tile, col_tile, dt, lane); + if (do0) { + score_acc0 = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, k_regs, true, q_regs[0][dt], score_acc0, true); + } + if (do1) { + score_acc1 = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, k_regs, true, q_regs[1][dt], score_acc1, true); + } + } + +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + scores0[elem] = do0 ? static_cast(score_acc0[elem]) * score_scale0 : -FLT_MAX * 0.5f; + scores1[elem] = do1 ? 
static_cast(score_acc1[elem]) * score_scale1 : -FLT_MAX * 0.5f; + } +} + +template +__device__ __forceinline__ half8_vec make_v_regs_from_shared( + const __half* __restrict__ v_tile, + const int col_tile, + const int d_tile, + const int lane) { + constexpr int BK = 16; + half8_vec regs; + const int d = d_tile * BK + (lane & 15); +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + const int k_local = wmma_f16_k_for_lane_elem(lane, elem); + const int row = col_tile * BK + k_local; + regs[elem] = static_cast<_Float16>(v_tile[row * SharedHeadStride + d]); + } + return regs; +} + +template +__device__ __forceinline__ half8_vec make_v_regs_from_transposed_shared( + const __half* __restrict__ v_tile, + const int col_tile, + const int d_tile, + const int lane) { + constexpr int BK = 16; + half8_vec regs; + const int d = d_tile * BK + (lane & 15); +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + const int k_local = wmma_f16_k_for_lane_elem(lane, elem); + const int n = col_tile * BK + k_local; + regs[elem] = static_cast<_Float16>(v_tile[d * SharedValueStride + n]); + } + return regs; +} + +__device__ __forceinline__ half8_vec make_v_regs_from_transposed_global( + const __half* __restrict__ v, + const int64_t v_head_base, + const int64_t v_stride_n, + const int64_t kb_base, + const int col_tile, + const int d_tile, + const int lane) { + constexpr int BK = 16; + half8_vec regs; + const int d = d_tile * BK + (lane & 15); +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + const int k_local = wmma_f16_k_for_lane_elem(lane, elem); + const int64_t n = kb_base + col_tile * BK + k_local; + regs[elem] = static_cast<_Float16>(v[v_head_base + static_cast(d) * v_stride_n + n]); + } + return regs; +} + +__device__ __forceinline__ half8_vec make_v_regs_from_hnd_global( + const __half* __restrict__ v, + const int64_t v_head_base, + const int64_t v_stride_n, + const int64_t kb_base, + const int col_tile, + const int d_tile, + const int lane) { + constexpr int BK = 
16; + const int source_row = gfx12_tr_b128_source_row_for_lane(lane); + const int source_col = gfx12_tr_b128_source_col_for_lane(lane); + const int64_t n = kb_base + col_tile * BK + source_row; + const int64_t d = d_tile * BK + source_col; + return gfx12_global_load_tr_b128_f16(v + v_head_base + n * v_stride_n + d); +} + +template +__device__ __forceinline__ half8_vec make_v_regs_from_lane_major_shared( + const uint4* __restrict__ v_lane_tile, + const int col_tile, + const int d_tile, + const int lane) { + const uint4 packed = v_lane_tile[(col_tile * DTiles + d_tile) * 32 + lane]; + return *reinterpret_cast(&packed); +} + +template +__device__ __forceinline__ half8_vec make_v_regs_from_lane_major_global( + const __half* __restrict__ v, + const int64_t v_head_base, + const int64_t v_group_stride, + const int64_t kb_base, + const int col_tile, + const int d_tile, + const int lane) { + const int64_t group = kb_base >> 6; + const int group_col_tile = ((static_cast(kb_base) & 63) >> 4) + col_tile; + const int64_t off = v_head_base + group * v_group_stride + + (((static_cast(group_col_tile) * DTiles + d_tile) * 32 + lane) * 8); + const uint4 packed = *reinterpret_cast(v + off); + return *reinterpret_cast(&packed); +} + +__device__ __forceinline__ void store_half(__half* output, const int64_t offset, const float value) { + output[offset] = __float2half_rn(value); +} + +__device__ __forceinline__ void store_output_value( + __half* output, + const int64_t offset, + const float value) { + output[offset] = __float2half_rn(value); +} + +__device__ __forceinline__ void store_output_value( + __hip_bfloat16* output, + const int64_t offset, + const float value) { + output[offset] = __float2bfloat16(value); +} + +template < + int BlockCols, + int BlockRows, + bool HndContiguous = false, + bool ValueTransposed = false, + int ValuePad = SAGEATTN_GFX12_NATIVE_F16_TV_PAD, + bool IsCausal = false, + bool TransposeValueOnLoad = false, + bool F16PvAccum = false, + bool PvOrderedQK = 
false, + typename QueryT = int8_t, + bool QuantizeQuery = false, + bool SplitCausalPrefix = false> +SAGEATTN_NATIVE_WAVES_PER_EU __global__ __launch_bounds__((BlockRows / 16) * 32, 1) void qk_int8_sv_f16_d64_native_kernel( + const QueryT* __restrict__ q, + const int8_t* __restrict__ k, + const __half* __restrict__ v, + __half* __restrict__ output, + const float* __restrict__ q_scale, + const float* __restrict__ k_scale, + const int64_t batch_size, + const int64_t qo_len, + const int64_t kv_len, + const int64_t num_qo_heads, + const int64_t num_kv_heads, + const int64_t q_stride_b, + const int64_t q_stride_n, + const int64_t q_stride_h, + const int64_t k_stride_b, + const int64_t k_stride_n, + const int64_t k_stride_h, + const int64_t v_stride_b, + const int64_t v_stride_n, + const int64_t v_stride_h, + const int64_t o_stride_b, + const int64_t o_stride_n, + const int64_t o_stride_h, + const int64_t qs_stride_b, + const int64_t qs_stride_h, + const int64_t ks_stride_b, + const int64_t ks_stride_h, + const int tensor_layout, + const float sm_scale) { + constexpr int HeadDim = 64; + constexpr int BR = BlockRows; + constexpr int RM = 16; + constexpr int BK = 16; + constexpr int BC = BlockCols; + constexpr int Threads = (BlockRows / RM) * 32; + constexpr int DTiles = HeadDim / BK; + constexpr int ColTiles = BC / BK; + constexpr int SharedHeadStride = HeadDim + 16; + constexpr bool UseTransposedValueLayout = ValueTransposed || TransposeValueOnLoad; + constexpr bool UseTrLoadLaneMajorValue = false; + constexpr int SharedValueRows = UseTransposedValueLayout ? HeadDim : BC; + constexpr int SharedValueStride = UseTransposedValueLayout ? 
(BC + ValuePad) : SharedHeadStride; + constexpr int LaneMajorValueElems = ColTiles * DTiles * 32; + constexpr int PackedRows = 4; + static_assert(BlockCols == 64 || BlockCols == 128, + "native gfx12 D64 kernel supports BC64/BC128."); + static_assert(BlockRows == 64 || BlockRows == 128, + "native gfx12 D64 kernel supports BR64/BR128."); + static_assert(!UseTransposedValueLayout || HndContiguous, + "transposed fp16 value path requires contiguous HND tensors."); + static_assert(!IsCausal || ((BlockRows == 64 || BlockRows == 128) && BlockCols == 64), + "native gfx12 D64 single-q causal path supports BR64/BR128/BC64."); + static_assert(!QuantizeQuery || HndContiguous, + "direct fp16 Q quantization requires contiguous HND tensors."); + + __shared__ int8_t k_tile[BC][SharedHeadStride]; + __shared__ __half v_tile[UseTrLoadLaneMajorValue ? 1 : SharedValueRows] + [UseTrLoadLaneMajorValue ? 1 : SharedValueStride]; + __shared__ uint4 v_lane_tile[UseTrLoadLaneMajorValue ? LaneMajorValueElems : 1]; + + const int tid = threadIdx.x; + const int lane = tid & 31; + const int wave = tid >> 5; + const int row_base = (lane >> 4) << 3; + const int col = lane & 15; + const int64_t q_base = + q_block_base_for_launch(static_cast(blockIdx.x), qo_len); + const int64_t hq = blockIdx.y; + const int64_t b = blockIdx.z; + if (b >= batch_size || hq >= num_qo_heads || q_base >= qo_len) { + return; + } + + const int64_t hkv = hq / (num_qo_heads / num_kv_heads); + const int64_t q_start = q_base + static_cast(wave) * RM; + float qs = 1.0f; + + using FragK = rocwmma::fragment; + using FragQ = rocwmma::fragment; + using FragScoreT = rocwmma::fragment; + constexpr bool UseRawPreparedQ = !QuantizeQuery && HndContiguous; + + i32x2_vec q_regs[DTiles]; + FragQ q_frag[DTiles]; + if constexpr (QuantizeQuery) { + constexpr int QPackElems = 8; + constexpr int QPacks = (RM * HeadDim) / QPackElems; + float local_q_amax = 0.0000001f; + for (int pack = lane; pack < QPacks; pack += 32) { + const int elem_base = pack 
* QPackElems; + const int row = elem_base / HeadDim; + const int d = elem_base - row * HeadDim; + const int64_t q_idx = q_start + row; + if (q_idx < qo_len) { + const int64_t q_off = qkv_offset_dispatch( + tensor_layout, b, hq, q_idx, d, q_stride_b, q_stride_n, q_stride_h); + const uint4 raw = *reinterpret_cast(q + q_off); + const QueryT* values = reinterpret_cast(&raw); +#pragma unroll + for (int i = 0; i < QPackElems; ++i) { + local_q_amax = fmaxf(local_q_amax, fabsf(value_to_float(values[i]))); + } + } + } + local_q_amax = vllm::warpReduceMax(local_q_amax); + const float q_amax = __shfl(local_q_amax, 0, 32); + const float inv_q_scale = 127.0f / q_amax; + qs = (q_amax / 127.0f) * sm_scale * kLog2e; +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + q_regs[dt] = pack_quant_q_i8_wmma_b_regs( + q, tensor_layout, lane, b, hq, q_start, qo_len, dt * BK, + q_stride_b, q_stride_n, q_stride_h, inv_q_scale); + } + } else { + const int q_scale_idx = q_scale_col_per_warp(q_start); + qs = q_scale[b * qs_stride_b + hq * qs_stride_h + q_scale_idx] * + sm_scale * kLog2e; + if constexpr (UseRawPreparedQ) { +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + q_regs[dt] = pack_q_i8_wmma_b_regs( + reinterpret_cast(q), tensor_layout, lane, b, hq, q_start, + qo_len, dt * BK, q_stride_b, q_stride_n, q_stride_h); + } + } else { +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + const int64_t q_off = qkv_offset_dispatch( + tensor_layout, b, hq, q_start, dt * BK, q_stride_b, q_stride_n, q_stride_h); + rocwmma::load_matrix_sync(q_frag[dt], q + q_off, static_cast(q_stride_n)); + } + } + } + + using PvAccumVec = std::conditional_t; + PvAccumVec out_frag[DTiles]; +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + out_frag[dt][elem] = 0.0f; + } + } + float m = -FLT_MAX * 0.5f; + float l = 0.0f; + + auto process_kv_tile = [&](const int64_t kb_base, auto causal_mask_tag) { + constexpr bool ApplyCausalMask = 
decltype(causal_mask_tag)::value; + constexpr int KVecBytes = 16; + constexpr int KVecsPerRow = HeadDim / KVecBytes; + for (int vec = tid; vec < BC * KVecsPerRow; vec += Threads) { + const int n = vec / KVecsPerRow; + const int d = (vec - n * KVecsPerRow) * KVecBytes; + const int64_t k_off = qkv_offset_dispatch( + tensor_layout, b, hkv, kb_base + n, d, k_stride_b, k_stride_n, k_stride_h); + *reinterpret_cast(&k_tile[n][d]) = + *reinterpret_cast(k + k_off); + } + + if constexpr (UseTrLoadLaneMajorValue) { + for (int idx = tid; idx < LaneMajorValueElems; idx += Threads) { + const int lane_local = idx & 31; + const int d_tile = (idx >> 5) % DTiles; + const int col_tile = idx / (DTiles * 32); + const half8_vec regs = make_v_regs_from_hnd_global( + v, b * v_stride_b + hkv * v_stride_h, v_stride_n, + kb_base, col_tile, d_tile, lane_local); + v_lane_tile[idx] = *reinterpret_cast(®s); + } + } else if constexpr (ValueTransposed) { + constexpr int VElemsPerVec = 8; + constexpr int VVecsPerD = BC / VElemsPerVec; + for (int vec = tid; vec < HeadDim * VVecsPerD; vec += Threads) { + const int d = vec / VVecsPerD; + const int n = (vec - d * VVecsPerD) * VElemsPerVec; + const int64_t v_off = b * v_stride_b + hkv * v_stride_h + + static_cast(d) * v_stride_n + kb_base + n; + *reinterpret_cast(&v_tile[d][n]) = + *reinterpret_cast(v + v_off); + } + } else if constexpr (TransposeValueOnLoad) { + constexpr int VElemsPerVec = 8; + constexpr int VVecsPerRow = HeadDim / VElemsPerVec; + for (int vec = tid; vec < BC * VVecsPerRow; vec += Threads) { + const int n = vec / VVecsPerRow; + const int d = (vec - n * VVecsPerRow) * VElemsPerVec; + const int64_t v_off = qkv_offset_dispatch( + tensor_layout, b, hkv, kb_base + n, d, v_stride_b, v_stride_n, v_stride_h); + const uint4 packed = *reinterpret_cast(v + v_off); + const __half* vals = reinterpret_cast(&packed); +#pragma unroll + for (int elem = 0; elem < VElemsPerVec; ++elem) { + v_tile[d + elem][n] = vals[elem]; + } + } + } else { + constexpr 
int VElemsPerVec = 8; + constexpr int VVecsPerRow = HeadDim / VElemsPerVec; + for (int vec = tid; vec < BC * VVecsPerRow; vec += Threads) { + const int n = vec / VVecsPerRow; + const int d = (vec - n * VVecsPerRow) * VElemsPerVec; + const int64_t v_off = qkv_offset_dispatch( + tensor_layout, b, hkv, kb_base + n, d, v_stride_b, v_stride_n, v_stride_h); + *reinterpret_cast(&v_tile[n][d]) = + *reinterpret_cast(v + v_off); + } + } + __syncthreads(); + + if constexpr (IsCausal) { + float8_vec score_cache[ColTiles]; + half8_vec prob_cache[ColTiles]; + float local_max = -FLT_MAX * 0.5f; +#pragma unroll + for (int col_tile = 0; col_tile < ColTiles; ++col_tile) { + const int k_scale_idx = k_scale_col_per_warp(kb_base + col_tile * BK); + const float score_scale = qs * + k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + float8_vec scores; + const int64_t k_col_start = kb_base + static_cast(col_tile) * BK; + const bool fully_future = + ApplyCausalMask && (k_col_start > q_start + RM - 1); + if constexpr (ApplyCausalMask) { + if (fully_future) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + scores[elem] = -FLT_MAX * 0.5f; + } + } else if constexpr (QuantizeQuery || UseRawPreparedQ) { + scores = compute_tqk_score_regs_raw_kq_one( + &k_tile[0][0], q_regs, lane, col_tile, score_scale); + } else { + scores = compute_tqk_score_regs( + &k_tile[0][0], q_frag, col_tile, score_scale); + } + } else if constexpr (QuantizeQuery || UseRawPreparedQ) { + scores = compute_tqk_score_regs_raw_kq_one( + &k_tile[0][0], q_regs, lane, col_tile, score_scale); + } else { + scores = compute_tqk_score_regs( + &k_tile[0][0], q_frag, col_tile, score_scale); + } + if constexpr (ApplyCausalMask) { + if (!fully_future && k_col_start + BK > q_start) { + if constexpr (PvOrderedQK) { + apply_tqk_causal_mask_pv_order( + scores, static_cast(q_start), static_cast(kb_base), col_tile, lane); + } else { + apply_tqk_causal_mask( + scores, static_cast(q_start), static_cast(kb_base), col_tile, 
lane); + } + } + } + if (k_col_start + BK > kv_len) { + apply_tqk_kv_tail_mask(scores, kv_len, kb_base, col_tile, lane); + } + score_cache[col_tile] = scores; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + local_max = fmaxf(local_max, scores[elem]); + } + } + const float tile_max = fmaxf(local_max, __shfl_xor(local_max, 16, 32)); + const float old_m = m; + const float new_m = fmaxf(old_m, tile_max); + const float alpha = l == 0.0f ? 0.0f : fast_exp2(old_m - new_m); + m = new_m; + l *= alpha; + +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + const int row = row_base + elem; + out_frag[dt][elem] *= __shfl(alpha, row, 32); + } + } + + float local_sum = 0.0f; +#pragma unroll + for (int col_tile = 0; col_tile < ColTiles; ++col_tile) { + half8_vec prob_values; + const int64_t k_col_start = kb_base + static_cast(col_tile) * BK; + const bool fully_future = + ApplyCausalMask && (k_col_start > q_start + RM - 1); +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + const float prob = fully_future ? 
0.0f : + fast_exp2(score_cache[col_tile][elem] - m + kF16SoftmaxOffset); + local_sum += prob; + prob_values[elem] = static_cast<_Float16>(prob); + } + if constexpr (PvOrderedQK) { + prob_cache[col_tile] = prob_values; + } else { + prob_cache[col_tile] = make_p_regs_from_tqk_prob_regs(prob_values, lane); + } + } + l += local_sum + __shfl_xor(local_sum, 16, 32); +#pragma unroll + for (int col_tile = 0; col_tile < ColTiles; ++col_tile) { + const int64_t k_col_start = kb_base + static_cast(col_tile) * BK; + const bool fully_future = + ApplyCausalMask && (k_col_start > q_start + RM - 1); + if constexpr (ApplyCausalMask) { + if (fully_future) { + continue; + } + } +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + half8_vec v_regs; + if constexpr (UseTrLoadLaneMajorValue) { + v_regs = make_v_regs_from_lane_major_shared( + v_lane_tile, col_tile, dt, lane); + } else if constexpr (UseTransposedValueLayout) { + v_regs = make_v_regs_from_transposed_shared( + &v_tile[0][0], col_tile, dt, lane); + } else { + v_regs = make_v_regs_from_shared( + &v_tile[0][0], col_tile, dt, lane); + } + PvAccumVec acc; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + acc[elem] = out_frag[dt][elem]; + } + PvAccumVec pv_acc; + if constexpr (F16PvAccum) { + pv_acc = + __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12( + prob_cache[col_tile], v_regs, acc); + } else { + pv_acc = + __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12( + prob_cache[col_tile], v_regs, acc); + } +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + out_frag[dt][elem] = pv_acc[elem]; + } + } + } + } else { + float local_max = -FLT_MAX * 0.5f; +#pragma unroll + for (int col_tile = 0; col_tile < ColTiles; ++col_tile) { + const int k_scale_idx = k_scale_col_per_warp(kb_base + col_tile * BK); + const float score_scale = qs * + k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + const int64_t k_col_start = kb_base + static_cast(col_tile) * BK; + float8_vec scores; + if constexpr (QuantizeQuery || 
UseRawPreparedQ) { + scores = compute_tqk_score_regs_raw_kq( + &k_tile[0][0], q_regs, lane, col_tile, score_scale); + } else { + scores = compute_tqk_score_regs( + &k_tile[0][0], q_frag, col_tile, score_scale); + } + if (k_col_start + BK > kv_len) { + apply_tqk_kv_tail_mask(scores, kv_len, kb_base, col_tile, lane); + } + for (int elem = 0; elem < 8; ++elem) { + local_max = fmaxf(local_max, scores[elem]); + } + } + const float tile_max = fmaxf(local_max, __shfl_xor(local_max, 16, 32)); + const float old_m = m; + const float new_m = fmaxf(old_m, tile_max); + const float alpha = l == 0.0f ? 0.0f : fast_exp2(old_m - new_m); + m = new_m; + l *= alpha; + +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + const int row = row_base + elem; + out_frag[dt][elem] *= __shfl(alpha, row, 32); + } + } + + float local_sum = 0.0f; +#pragma unroll + for (int col_tile = 0; col_tile < ColTiles; ++col_tile) { + const int k_scale_idx = k_scale_col_per_warp(kb_base + col_tile * BK); + const float score_scale = qs * + k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + const int64_t k_col_start = kb_base + static_cast(col_tile) * BK; + float8_vec scores; + if constexpr (QuantizeQuery || UseRawPreparedQ) { + scores = compute_tqk_score_regs_raw_kq( + &k_tile[0][0], q_regs, lane, col_tile, score_scale); + } else { + scores = compute_tqk_score_regs( + &k_tile[0][0], q_frag, col_tile, score_scale); + } + if (k_col_start + BK > kv_len) { + apply_tqk_kv_tail_mask(scores, kv_len, kb_base, col_tile, lane); + } + half8_vec prob_values; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + const float prob = fast_exp2(scores[elem] - m + kF16SoftmaxOffset); + local_sum += prob; + prob_values[elem] = static_cast<_Float16>(prob); + } + +#if SAGEATTN_NATIVE_HAS_GFX12_WMMA + const half8_vec p_regs = make_p_regs_from_tqk_prob_regs(prob_values, lane); +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + half8_vec v_regs; + 
if constexpr (UseTrLoadLaneMajorValue) { + v_regs = make_v_regs_from_lane_major_shared( + v_lane_tile, col_tile, dt, lane); + } else if constexpr (UseTransposedValueLayout) { + v_regs = make_v_regs_from_transposed_shared( + &v_tile[0][0], col_tile, dt, lane); + } else { + v_regs = make_v_regs_from_shared( + &v_tile[0][0], col_tile, dt, lane); + } + PvAccumVec acc; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + acc[elem] = out_frag[dt][elem]; + } + PvAccumVec pv_acc; + if constexpr (F16PvAccum) { + pv_acc = + __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(p_regs, v_regs, acc); + } else { + pv_acc = + __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(p_regs, v_regs, acc); + } +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + out_frag[dt][elem] = pv_acc[elem]; + } + } +#else + static_assert(SAGEATTN_NATIVE_HAS_GFX12_WMMA, "native gfx12 kernel requires gfx12 WMMA builtins"); +#endif + } + l += local_sum + __shfl_xor(local_sum, 16, 32); + } + __syncthreads(); + }; + + const int64_t kv_limit = IsCausal && (q_base + BR) < kv_len ? q_base + BR : kv_len; + if constexpr (IsCausal && SplitCausalPrefix) { + const int64_t prefix_limit = q_base < kv_limit ? 
q_base : kv_limit; + for (int64_t kb_base = 0; kb_base < prefix_limit; kb_base += BC) { + process_kv_tile(kb_base, std::false_type{}); + } + for (int64_t kb_base = prefix_limit; kb_base < kv_limit; kb_base += BC) { + process_kv_tile(kb_base, std::true_type{}); + } + } else { + for (int64_t kb_base = 0; kb_base < kv_limit; kb_base += BC) { + process_kv_tile(kb_base, std::integral_constant{}); + } + } + +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + const int d = dt * BK + col; +#pragma unroll + for (int pair = 0; pair < PackedRows; ++pair) { + const int elem = pair * 2; + const int64_t q_idx0 = q_start + row_base + elem; + const int64_t q_idx1 = q_idx0 + 1; + const float l_sum0 = __shfl(l, row_base + elem, 32); + const float l_sum1 = __shfl(l, row_base + elem + 1, 32); + const float value0 = l_sum0 == 0.0f ? 0.0f : out_frag[dt][elem] / l_sum0; + const float value1 = l_sum1 == 0.0f ? 0.0f : out_frag[dt][elem + 1] / l_sum1; + store_half(output, qkv_offset_dispatch( + tensor_layout, b, hq, q_idx0, d, o_stride_b, o_stride_n, o_stride_h), value0); + store_half(output, qkv_offset_dispatch( + tensor_layout, b, hq, q_idx1, d, o_stride_b, o_stride_n, o_stride_h), value1); + } + } +} + +template +SAGEATTN_NATIVE_2Q_WAVES_PER_EU(HeadDim, IsCausal) __global__ +SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2q_kernel( + const QueryT* __restrict__ q, + const KeyT* __restrict__ k, + const __half* __restrict__ v, + __half* __restrict__ output, + const float* __restrict__ q_scale, + const float* __restrict__ k_scale, + const int64_t batch_size, + const int64_t qo_len, + const int64_t kv_len, + const int64_t num_qo_heads, + const int64_t num_kv_heads, + const int64_t q_stride_b, + const int64_t q_stride_n, + const int64_t q_stride_h, + const int64_t k_stride_b, + const int64_t k_stride_n, + const int64_t k_stride_h, + const int64_t v_stride_b, + const int64_t v_stride_n, + const int64_t v_stride_h, + const int64_t o_stride_b, + const int64_t 
o_stride_n, + const int64_t o_stride_h, + const int64_t qs_stride_b, + const int64_t qs_stride_h, + const int64_t ks_stride_b, + const int64_t ks_stride_h, + const int tensor_layout, + const float sm_scale) { + static_assert(HeadDim == 16 || HeadDim == 64, + "native gfx12 fp16 2q kernel supports D16/D64."); + constexpr int BR = BlockRows; + constexpr int RM = 16; + constexpr int RowsPerWave = 32; + constexpr int QGroups = 2; + constexpr int BK = 16; + constexpr int BC = BlockCols; + constexpr int Threads = BlockRows; + constexpr int DTiles = HeadDim / BK; + constexpr int ColTiles = BC / BK; + constexpr int SharedHeadStride = HeadDim + 16; + constexpr int SharedQKHeadStride = SharedHeadStride; + constexpr bool UseTransposedValueLayout = ValueTransposed || TransposeValueOnLoad; + constexpr bool UseLaneMajorKey = + LaneMajorKey && !QuantizeKey && HndContiguous && BlockCols == 64 && PvOrderedQK; + constexpr bool UseLaneMajorValue = + LaneMajorValue && HndContiguous && + (ValueTransposed || TransposeValueOnLoad); + constexpr bool StageValueInShared = !UseLaneMajorValue; + constexpr bool UsesTileSharedMemory = true; + constexpr bool UsesKeySharedMemory = true; + constexpr bool UsesValueSharedMemory = StageValueInShared || UseLaneMajorValue; + constexpr int SharedValueRows = + StageValueInShared ? (UseTransposedValueLayout ? HeadDim : BC) : 1; + constexpr int SharedValueStride = + StageValueInShared ? (UseTransposedValueLayout ? 
(BC + ValuePad) : SharedHeadStride) : 1; + constexpr int PackedRows = 4; + static_assert(BlockCols == 32 || BlockCols == 64 || BlockCols == 128, + "native gfx12 fp16 2q kernel supports BC32/BC64/BC128."); + static_assert(BlockRows == 32 || BlockRows == 64 || BlockRows == 128 || + BlockRows == 256 || BlockRows == 512 || BlockRows == 1024, + "native gfx12 fp16 2q kernel supports BR32/BR64/BR128/BR256/BR512/BR1024."); + static_assert(!UseTransposedValueLayout || HndContiguous, + "transposed fp16 value layout requires contiguous HND tensors."); + static_assert(!F16PvAccum || BlockCols <= 64, + "fp16 PV accumulation currently supports the BC64 2q path."); + static_assert(!QuantizeQuery || HndContiguous, + "direct fp16 Q quantization currently requires contiguous HND tensors."); + static_assert(!QuantizeKey || (HndContiguous && BlockCols == 64), + "direct fp16 K quantization currently requires contiguous HND BC64 tensors."); + static_assert(!LaneMajorValue || + (HndContiguous && (ValueTransposed || TransposeValueOnLoad)), + "lane-major fp16 value staging requires contiguous HND values."); + static_assert(!LaneMajorKey || + (!QuantizeKey && HndContiguous && BlockCols == 64 && PvOrderedQK), + "lane-major fp16 key staging requires prepared HND BC64 PvOrderedQK."); + static_assert(!StreamColTiles || (BlockCols <= 128 && (QuantizeQuery || HndContiguous)), + "streaming col-tile softmax is specialized for raw/quantized QK."); + __shared__ int8_t k_tile[UseLaneMajorKey ? 1 : BC] + [SharedQKHeadStride]; + __shared__ uint2 k_lane_tile[UseLaneMajorKey ? (ColTiles * DTiles * 32) : 1]; + __shared__ __half v_tile[SharedValueRows][SharedValueStride]; + __shared__ uint4 v_lane_tile[UseLaneMajorValue ? 
(ColTiles * DTiles * 32) : 1]; + __shared__ float raw_k_amax_shared; + + const int tid = threadIdx.x; + const int lane = tid & 31; + const int wave = tid >> 5; + const int row_base = (lane >> 4) << 3; + const int col = lane & 15; + const int64_t q_block_count = (qo_len + BR - 1) / BR; + const int64_t hb_count = num_qo_heads * batch_size; + for (;;) { + int64_t q_block = static_cast(blockIdx.x); + int64_t hq = blockIdx.y; + int64_t b = blockIdx.z; + if constexpr (FlatCausalSchedule) { + static_assert(IsCausal, "flat q scheduling is causal-only"); + const int64_t hb = static_cast(blockIdx.x) % hb_count; + q_block = static_cast(blockIdx.x) / hb_count; + hq = hb % num_qo_heads; + b = hb / num_qo_heads; + } + const int64_t q_base = q_block_base_for_launch(q_block, qo_len); + if (b >= batch_size || hq >= num_qo_heads || q_base >= qo_len) { + return; + } + + const int64_t hkv = hq / (num_qo_heads / num_kv_heads); + const int64_t q_head_base = b * q_stride_b + hq * q_stride_h; + const int64_t k_head_base = b * k_stride_b + hkv * k_stride_h; + const int64_t v_head_base = b * v_stride_b + hkv * v_stride_h; + const int64_t o_head_base = b * o_stride_b + hq * o_stride_h; + int64_t q_start[QGroups]; + float qs[QGroups]; +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + q_start[qg] = q_base + static_cast(wave) * RowsPerWave + qg * RM; + } + + using FragK = rocwmma::fragment; + using FragQ = rocwmma::fragment; + using FragScoreT = rocwmma::fragment; + constexpr bool UseRawPreparedQ = !QuantizeQuery && HndContiguous; + + i32x2_vec q_regs[QGroups][DTiles]; + if constexpr (QuantizeQuery) { + constexpr int QPackElems = 8; + constexpr int QPacksPerWave = (RowsPerWave * HeadDim) / QPackElems; + const int local_q_row_base = wave * RowsPerWave; + float local_q_amax = 0.0000001f; + for (int pack = lane; pack < QPacksPerWave; pack += 32) { + const int elem_base = pack * QPackElems; + const int row = elem_base / HeadDim; + const int d = elem_base - row * HeadDim; + const int64_t 
q_idx = q_base + local_q_row_base + row; + if (q_idx < qo_len) { + const int64_t q_off = qkv_offset_dispatch( + tensor_layout, b, hq, q_idx, d, q_stride_b, q_stride_n, q_stride_h); + const uint4 raw = *reinterpret_cast(q + q_off); + const QueryT* values = reinterpret_cast(&raw); +#pragma unroll + for (int i = 0; i < QPackElems; ++i) { + local_q_amax = fmaxf(local_q_amax, fabsf(value_to_float(values[i]))); + } + } + } + local_q_amax = vllm::warpReduceMax(local_q_amax); + const float q_scale_local = __shfl(local_q_amax, 0, 32) / 127.0f; + const float inv_q_scale = 127.0f / __shfl(local_q_amax, 0, 32); +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + qs[qg] = q_scale_local * sm_scale * kLog2e; + const int64_t qg_start = q_start[qg]; +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + q_regs[qg][dt] = pack_quant_q_i8_wmma_b_regs( + q, tensor_layout, lane, b, hq, qg_start, qo_len, dt * BK, + q_stride_b, q_stride_n, q_stride_h, inv_q_scale); + } + } + } else { +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + const int q_scale_idx = q_scale_col_per_warp(q_start[qg]); + qs[qg] = q_scale[b * qs_stride_b + hq * qs_stride_h + q_scale_idx] * + sm_scale * kLog2e; + } + if constexpr (UseRawPreparedQ) { +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + q_regs[qg][dt] = pack_q_i8_wmma_b_regs( + reinterpret_cast(q), tensor_layout, lane, b, hq, q_start[qg], + qo_len, dt * BK, q_stride_b, q_stride_n, q_stride_h); + } + } + } + } + + FragQ q_frag[QGroups][DTiles]; + if constexpr (!QuantizeQuery && !UseRawPreparedQ) { +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + int64_t q_off; + if constexpr (HndContiguous) { + q_off = q_head_base + q_start[qg] * HeadDim + dt * BK; + } else { + q_off = qkv_offset( + tensor_layout, b, hq, q_start[qg], dt * BK, q_stride_b, q_stride_n, q_stride_h); + } + rocwmma::load_matrix_sync(q_frag[qg][dt], q 
+ q_off, static_cast(q_stride_n)); + } + } + } + + using PvAccumVec = std::conditional_t; + PvAccumVec out_frag[QGroups][DTiles]; + float m[QGroups]; + float l[QGroups]; +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + m[qg] = -FLT_MAX * 0.5f; + l[qg] = 0.0f; +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + out_frag[qg][dt][elem] = 0.0f; + } + } + } + + const int64_t kv_limit = IsCausal && (q_base + BR) < kv_len ? q_base + BR : kv_len; + auto process_kv_tile = [&](const int64_t kb_base, auto apply_causal_mask_tag) { + constexpr int KVecBytes = 16; + constexpr int KBytesPerRow = HeadDim; + constexpr int KVecsPerRow = KBytesPerRow / KVecBytes; + float k_scale_tile = 1.0f; + if constexpr (QuantizeKey) { + constexpr int PackElems = 8; + constexpr int Packs = (BC * HeadDim) / PackElems; + float local_k_amax = 0.0000001f; + for (int pack = tid; pack < Packs; pack += Threads) { + const int elem_base = pack * PackElems; + const int n = elem_base / HeadDim; + const int d = elem_base - n * HeadDim; + const int64_t k_off = qkv_offset_dispatch( + tensor_layout, b, hkv, kb_base + n, d, k_stride_b, k_stride_n, k_stride_h); + const uint4 raw = *reinterpret_cast(k + k_off); + const KeyT* values = reinterpret_cast(&raw); +#pragma unroll + for (int i = 0; i < PackElems; ++i) { + local_k_amax = fmaxf(local_k_amax, fabsf(value_to_float(values[i]))); + } + } + const float block_k_amax = vllm::blockReduceMax(local_k_amax); + if (tid == 0) { + raw_k_amax_shared = block_k_amax; + } + __syncthreads(); + const float raw_k_amax = raw_k_amax_shared; + k_scale_tile = raw_k_amax / 127.0f; + const float inv_k_scale = 127.0f / raw_k_amax; + + for (int pack = tid; pack < Packs; pack += Threads) { + const int elem_base = pack * PackElems; + const int n = elem_base / HeadDim; + const int d = elem_base - n * HeadDim; + const int64_t k_off = qkv_offset_dispatch( + tensor_layout, b, hkv, kb_base + n, d, k_stride_b, k_stride_n, 
k_stride_h); + const uint4 raw = *reinterpret_cast(k + k_off); + const KeyT* values = reinterpret_cast(&raw); + char4 out0; + char4 out1; + out0.x = float_to_int8_rn_gfx12(value_to_float(values[0]) * inv_k_scale); + out0.y = float_to_int8_rn_gfx12(value_to_float(values[1]) * inv_k_scale); + out0.z = float_to_int8_rn_gfx12(value_to_float(values[2]) * inv_k_scale); + out0.w = float_to_int8_rn_gfx12(value_to_float(values[3]) * inv_k_scale); + out1.x = float_to_int8_rn_gfx12(value_to_float(values[4]) * inv_k_scale); + out1.y = float_to_int8_rn_gfx12(value_to_float(values[5]) * inv_k_scale); + out1.z = float_to_int8_rn_gfx12(value_to_float(values[6]) * inv_k_scale); + out1.w = float_to_int8_rn_gfx12(value_to_float(values[7]) * inv_k_scale); + *reinterpret_cast(&k_tile[n][d]) = out0; + *reinterpret_cast(&k_tile[n][d + 4]) = out1; + } + } else if constexpr (UseLaneMajorKey) { + constexpr int LaneMajorElems = ColTiles * DTiles * 32; + for (int idx = tid; idx < LaneMajorElems; idx += Threads) { + const int lane_local = idx & 31; + const int d_tile = (idx >> 5) % DTiles; + const int col_tile = idx / (DTiles * 32); + const int col_local = lane_local & 15; + const int row = col_tile * BK + pv_k_order_for_acc_row(col_local); + const int d = d_tile * BK + 8 * (lane_local >> 4); + const int64_t k_off = k_head_base + (kb_base + row) * HeadDim + d; + k_lane_tile[idx] = *reinterpret_cast(k + k_off); + } + } else { + for (int vec = tid; vec < BC * KVecsPerRow; vec += Threads) { + const int n = vec / KVecsPerRow; + const int d = (vec - n * KVecsPerRow) * KVecBytes; + int64_t k_off; + if constexpr (HndContiguous) { + k_off = k_head_base + (kb_base + n) * HeadDim + d; + } else { + k_off = qkv_offset( + tensor_layout, b, hkv, kb_base + n, d, k_stride_b, k_stride_n, k_stride_h); + } + *reinterpret_cast(&k_tile[n][d]) = + *reinterpret_cast(reinterpret_cast(k) + k_off); + } + } + + float prepared_k_scale_tile = k_scale_tile; + if constexpr (!QuantizeKey && BC <= 64) { + const int 
k_scale_idx = k_scale_col_per_warp(kb_base); + prepared_k_scale_tile = + k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + } + + auto stage_value_tile = [&]() { + if constexpr (UseLaneMajorValue) { + if constexpr (ValueTransposed) { + constexpr int LaneMajorElems = ColTiles * DTiles * 32; + for (int idx = tid; idx < LaneMajorElems; idx += Threads) { + const int lane_local = idx & 31; + const int d_tile = (idx >> 5) % DTiles; + const int col_tile = idx / (DTiles * 32); + const int d = d_tile * BK + (lane_local & 15); + const int high_half = (lane_local >> 4) & 1; + const int n0 = col_tile * BK + high_half * 4; + const int n1 = col_tile * BK + 8 + high_half * 4; + const int64_t base = v_head_base + static_cast(d) * v_stride_n + kb_base; + const uint2 raw0 = *reinterpret_cast(v + base + n0); + const uint2 raw1 = *reinterpret_cast(v + base + n1); + uint4 packed; + packed.x = raw0.x; + packed.y = raw0.y; + packed.z = raw1.x; + packed.w = raw1.y; + v_lane_tile[idx] = packed; + } + } else { + constexpr int VElemsPerVec = 8; + constexpr int VVecsPerRow = HeadDim / VElemsPerVec; + __half* lane_values = reinterpret_cast<__half*>(v_lane_tile); + for (int vec = tid; vec < BC * VVecsPerRow; vec += Threads) { + const int n = vec / VVecsPerRow; + const int d_base = (vec - n * VVecsPerRow) * VElemsPerVec; + const int64_t v_off = qkv_offset_dispatch( + tensor_layout, b, hkv, kb_base + n, d_base, v_stride_b, v_stride_n, v_stride_h); + const uint4 packed = *reinterpret_cast(v + v_off); + const __half* vals = reinterpret_cast(&packed); + const int col_tile = n / BK; + const int k_local = n - col_tile * BK; + const int dst_elem = ((k_local & 8) >> 1) | (k_local & 3); + const int dst_lane_hi = ((k_local >> 2) & 1) << 4; +#pragma unroll + for (int elem = 0; elem < VElemsPerVec; ++elem) { + const int d = d_base + elem; + const int d_tile = d >> 4; + const int dst_lane = (d & 15) | dst_lane_hi; + const int slot = ((col_tile * DTiles + d_tile) * 32 + dst_lane) * 8 + dst_elem; + 
lane_values[slot] = vals[elem]; + } + } + } + } else if constexpr (StageValueInShared && ValueTransposed) { + constexpr int VElemsPerVec = 8; + constexpr int VVecsPerD = BC / VElemsPerVec; + for (int vec = tid; vec < HeadDim * VVecsPerD; vec += Threads) { + const int d = vec / VVecsPerD; + const int n = (vec - d * VVecsPerD) * VElemsPerVec; + const int64_t v_off = v_head_base + static_cast(d) * v_stride_n + kb_base + n; + *reinterpret_cast(&v_tile[d][n]) = + *reinterpret_cast(v + v_off); + } + } else if constexpr (StageValueInShared && TransposeValueOnLoad) { + constexpr int VElemsPerVec = 8; + constexpr int VVecsPerRow = HeadDim / VElemsPerVec; + for (int vec = tid; vec < BC * VVecsPerRow; vec += Threads) { + const int n = vec / VVecsPerRow; + const int d = (vec - n * VVecsPerRow) * VElemsPerVec; + const int64_t v_off = qkv_offset_dispatch( + tensor_layout, b, hkv, kb_base + n, d, v_stride_b, v_stride_n, v_stride_h); + const uint4 packed = *reinterpret_cast(v + v_off); + const __half* vals = reinterpret_cast(&packed); +#pragma unroll + for (int elem = 0; elem < VElemsPerVec; ++elem) { + v_tile[d + elem][n] = vals[elem]; + } + } + } else if constexpr (StageValueInShared) { + constexpr int VElemsPerVec = 8; + constexpr int VVecsPerRow = HeadDim / VElemsPerVec; + for (int vec = tid; vec < BC * VVecsPerRow; vec += Threads) { + const int n = vec / VVecsPerRow; + const int d = (vec - n * VVecsPerRow) * VElemsPerVec; + const int64_t v_off = qkv_offset_dispatch( + tensor_layout, b, hkv, kb_base + n, d, v_stride_b, v_stride_n, v_stride_h); + *reinterpret_cast(&v_tile[n][d]) = + *reinterpret_cast(v + v_off); + } + } + }; + stage_value_tile(); + if constexpr (UsesKeySharedMemory || UsesValueSharedMemory) { + __syncthreads(); + } + SAGEATTN_F16_SCHED_BARRIER(0); + + auto compute_loaded_tile = [&](auto causal_mask_tag) { + constexpr bool ApplyCausalMask = decltype(causal_mask_tag)::value; + if constexpr (BlockCols <= 128) { + if constexpr ((QuantizeQuery || UseRawPreparedQ) && 
StreamColTiles) { + constexpr int StreamGroupCols = ColTiles >= 2 ? 2 : 1; +#pragma unroll + for (int group_base = 0; group_base < ColTiles; group_base += StreamGroupCols) { + float8_vec scores0[StreamGroupCols]; + float8_vec scores1[StreamGroupCols]; + bool fully_future[QGroups][StreamGroupCols]; + bool any_work = false; +#pragma unroll + for (int gc = 0; gc < StreamGroupCols; ++gc) { + const int col_tile = group_base + gc; + const int64_t k_col_start = kb_base + col_tile * BK; + const bool fully_future0 = + ApplyCausalMask && k_col_start >= q_start[0] + RM; + const bool fully_future1 = + ApplyCausalMask && k_col_start >= q_start[1] + RM; + fully_future[0][gc] = fully_future0; + fully_future[1][gc] = fully_future1; + any_work = any_work || !(fully_future0 && fully_future1); + if (fully_future0 && fully_future1) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + scores0[gc][elem] = -FLT_MAX * 0.5f; + scores1[gc][elem] = -FLT_MAX * 0.5f; + } + continue; + } + + float k_scale_local = prepared_k_scale_tile; + if constexpr (!QuantizeKey && BC > 64) { + const int k_scale_idx = k_scale_col_per_warp(k_col_start); + k_scale_local = + k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + } + if constexpr (UseLaneMajorKey) { + compute_tqk_score_regs_raw_kq_2_lane_shared_key( + k_lane_tile, q_regs, lane, col_tile, + qs[0] * k_scale_local, qs[1] * k_scale_local, + !fully_future0, !fully_future1, + scores0[gc], scores1[gc]); + } else { + compute_tqk_score_regs_raw_kq_2( + &k_tile[0][0], q_regs, lane, col_tile, + qs[0] * k_scale_local, qs[1] * k_scale_local, + !fully_future0, !fully_future1, + scores0[gc], scores1[gc]); + } + if constexpr (ApplyCausalMask) { + if (fully_future0) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + scores0[gc][elem] = -FLT_MAX * 0.5f; + } + } else if (k_col_start + BK > q_start[0]) { + if constexpr (PvOrderedQK) { + apply_tqk_causal_mask_pv_order( + scores0[gc], static_cast(q_start[0]), + static_cast(kb_base), col_tile, 
lane); + } else { + apply_tqk_causal_mask( + scores0[gc], static_cast(q_start[0]), + static_cast(kb_base), col_tile, lane); + } + } + if (fully_future1) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + scores1[gc][elem] = -FLT_MAX * 0.5f; + } + } else if (k_col_start + BK > q_start[1]) { + if constexpr (PvOrderedQK) { + apply_tqk_causal_mask_pv_order( + scores1[gc], static_cast(q_start[1]), + static_cast(kb_base), col_tile, lane); + } else { + apply_tqk_causal_mask( + scores1[gc], static_cast(q_start[1]), + static_cast(kb_base), col_tile, lane); + } + } + } + if (k_col_start + BK > kv_len) { + apply_tqk_kv_tail_mask( + scores0[gc], kv_len, kb_base, col_tile, lane); + apply_tqk_kv_tail_mask( + scores1[gc], kv_len, kb_base, col_tile, lane); + } + } + if (!any_work) { + continue; + } + + half8_vec p_regs[QGroups][StreamGroupCols]; +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + float local_max = -FLT_MAX * 0.5f; +#pragma unroll + for (int gc = 0; gc < StreamGroupCols; ++gc) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + const float score = qg == 0 ? scores0[gc][elem] : scores1[gc][elem]; + local_max = fmaxf(local_max, score); + } + } + const float tile_max = fmaxf(local_max, __shfl_xor(local_max, 16, 32)); + const float old_m = m[qg]; + const float new_m = fmaxf(old_m, tile_max); + const float alpha = l[qg] == 0.0f ? 0.0f : fast_exp2(old_m - new_m); + m[qg] = new_m; + l[qg] *= alpha; + + float8_vec alpha_rows; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + alpha_rows[elem] = __shfl(alpha, row_base + elem, 32); + } + +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + out_frag[qg][dt][elem] *= alpha_rows[elem]; + } + } + + float local_sum = 0.0f; +#pragma unroll + for (int gc = 0; gc < StreamGroupCols; ++gc) { + half8_vec prob_values; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + const float score = qg == 0 ? 
scores0[gc][elem] : scores1[gc][elem]; + float prob = 0.0f; + if (!fully_future[qg][gc]) { + prob = fast_exp2(score - m[qg] + kF16SoftmaxOffset); + local_sum += prob; + } + prob_values[elem] = static_cast<_Float16>(prob); + } + if constexpr (PvOrderedQK) { + p_regs[qg][gc] = prob_values; + } else { + p_regs[qg][gc] = make_p_regs_from_tqk_prob_regs(prob_values, lane); + } + } + l[qg] += local_sum + __shfl_xor(local_sum, 16, 32); + } + +#pragma unroll + for (int gc = 0; gc < StreamGroupCols; ++gc) { + if (fully_future[0][gc] && fully_future[1][gc]) { + continue; + } + const int col_tile = group_base + gc; +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + half8_vec v_regs; + if constexpr (UseLaneMajorValue) { + v_regs = make_v_regs_from_lane_major_shared( + v_lane_tile, col_tile, dt, lane); + } else if constexpr (UseTransposedValueLayout) { + v_regs = make_v_regs_from_transposed_shared( + &v_tile[0][0], col_tile, dt, lane); + } else { + v_regs = make_v_regs_from_shared( + &v_tile[0][0], col_tile, dt, lane); + } +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + if (fully_future[qg][gc]) { + continue; + } + PvAccumVec acc; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + acc[elem] = out_frag[qg][dt][elem]; + } + PvAccumVec pv_acc; + if constexpr (F16PvAccum) { + pv_acc = + __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(p_regs[qg][gc], v_regs, acc); + } else { + pv_acc = + __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(p_regs[qg][gc], v_regs, acc); + } +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + out_frag[qg][dt][elem] = pv_acc[elem]; + } + } + } + } + } + return; + } + half8_vec prob_cache[QGroups][ColTiles]; + if constexpr (QuantizeQuery || UseRawPreparedQ) { + float8_vec score_cache[QGroups][ColTiles]; + float local_max[QGroups]; +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + local_max[qg] = -FLT_MAX * 0.5f; + } +#pragma unroll + for (int col_tile = 0; col_tile < ColTiles; ++col_tile) { + const int64_t 
k_col_start = kb_base + col_tile * BK; + const bool fully_future0 = + ApplyCausalMask && k_col_start >= q_start[0] + RM; + const bool fully_future1 = + ApplyCausalMask && k_col_start >= q_start[1] + RM; + float8_vec scores0; + float8_vec scores1; + if (fully_future0 && fully_future1) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + scores0[elem] = -FLT_MAX * 0.5f; + scores1[elem] = -FLT_MAX * 0.5f; + } + } else { + float k_scale_local = prepared_k_scale_tile; + if constexpr (!QuantizeKey && BC > 64) { + const int k_scale_idx = k_scale_col_per_warp(k_col_start); + k_scale_local = + k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + } + if constexpr (UseLaneMajorKey) { + compute_tqk_score_regs_raw_kq_2_lane_shared_key( + k_lane_tile, q_regs, lane, col_tile, + qs[0] * k_scale_local, qs[1] * k_scale_local, + !fully_future0, !fully_future1, scores0, scores1); + } else { + compute_tqk_score_regs_raw_kq_2( + &k_tile[0][0], q_regs, lane, col_tile, + qs[0] * k_scale_local, qs[1] * k_scale_local, + !fully_future0, !fully_future1, scores0, scores1); + } + if constexpr (ApplyCausalMask) { + if (fully_future0) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + scores0[elem] = -FLT_MAX * 0.5f; + } + } else if (k_col_start + BK > q_start[0]) { + if constexpr (PvOrderedQK) { + apply_tqk_causal_mask_pv_order( + scores0, static_cast(q_start[0]), static_cast(kb_base), + col_tile, lane); + } else { + apply_tqk_causal_mask( + scores0, static_cast(q_start[0]), static_cast(kb_base), + col_tile, lane); + } + } + if (fully_future1) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + scores1[elem] = -FLT_MAX * 0.5f; + } + } else if (k_col_start + BK > q_start[1]) { + if constexpr (PvOrderedQK) { + apply_tqk_causal_mask_pv_order( + scores1, static_cast(q_start[1]), static_cast(kb_base), + col_tile, lane); + } else { + apply_tqk_causal_mask( + scores1, static_cast(q_start[1]), static_cast(kb_base), + col_tile, lane); + } + } + } + } + if (k_col_start 
+ BK > kv_len) { + apply_tqk_kv_tail_mask( + scores0, kv_len, kb_base, col_tile, lane); + apply_tqk_kv_tail_mask( + scores1, kv_len, kb_base, col_tile, lane); + } + score_cache[0][col_tile] = scores0; + score_cache[1][col_tile] = scores1; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + local_max[0] = fmaxf(local_max[0], scores0[elem]); + local_max[1] = fmaxf(local_max[1], scores1[elem]); + } + } +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + const float tile_max = fmaxf(local_max[qg], __shfl_xor(local_max[qg], 16, 32)); + const float old_m = m[qg]; + const float new_m = fmaxf(old_m, tile_max); + const float alpha = l[qg] == 0.0f ? 0.0f : fast_exp2(old_m - new_m); + m[qg] = new_m; + l[qg] *= alpha; + + float8_vec alpha_rows; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + alpha_rows[elem] = __shfl(alpha, row_base + elem, 32); + } + +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + out_frag[qg][dt][elem] *= alpha_rows[elem]; + } + } + + float local_sum = 0.0f; +#pragma unroll + for (int col_tile = 0; col_tile < ColTiles; ++col_tile) { + half8_vec prob_values; + const int64_t k_col_start = kb_base + col_tile * BK; + const bool fully_future = + ApplyCausalMask && k_col_start >= q_start[qg] + RM; + const float8_vec scores = score_cache[qg][col_tile]; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + float prob = 0.0f; + if (!fully_future) { + prob = fast_exp2(scores[elem] - m[qg] + kF16SoftmaxOffset); + local_sum += prob; + } + prob_values[elem] = static_cast<_Float16>(prob); + } + if constexpr (PvOrderedQK) { + prob_cache[qg][col_tile] = prob_values; + } else { + prob_cache[qg][col_tile] = make_p_regs_from_tqk_prob_regs(prob_values, lane); + } + } + l[qg] += local_sum + __shfl_xor(local_sum, 16, 32); + } + } else { +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + float8_vec score_cache[ColTiles]; + float local_max = -FLT_MAX * 0.5f; +#pragma unroll + 
for (int col_tile = 0; col_tile < ColTiles; ++col_tile) { + const int64_t k_col_start = kb_base + col_tile * BK; + const bool fully_future = + ApplyCausalMask && k_col_start >= q_start[qg] + RM; + float8_vec scores; + if (fully_future) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + scores[elem] = -FLT_MAX * 0.5f; + } + } else { + float k_scale_local = prepared_k_scale_tile; + if constexpr (!QuantizeKey && BC > 64) { + const int k_scale_idx = k_scale_col_per_warp(k_col_start); + k_scale_local = + k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + } + const float score_scale = qs[qg] * k_scale_local; + scores = + compute_tqk_score_regs( + &k_tile[0][0], q_frag[qg], col_tile, score_scale); + if constexpr (ApplyCausalMask) { + const bool needs_causal_mask = k_col_start + BK > q_start[qg]; + if (needs_causal_mask) { + apply_tqk_causal_mask( + scores, static_cast(q_start[qg]), static_cast(kb_base), + col_tile, lane); + } + } + } + if (k_col_start + BK > kv_len) { + apply_tqk_kv_tail_mask(scores, kv_len, kb_base, col_tile, lane); + } + score_cache[col_tile] = scores; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + local_max = fmaxf(local_max, scores[elem]); + } + } + const float tile_max = fmaxf(local_max, __shfl_xor(local_max, 16, 32)); + const float old_m = m[qg]; + const float new_m = fmaxf(old_m, tile_max); + const float alpha = l[qg] == 0.0f ? 
0.0f : fast_exp2(old_m - new_m); + m[qg] = new_m; + l[qg] *= alpha; + + float8_vec alpha_rows; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + alpha_rows[elem] = __shfl(alpha, row_base + elem, 32); + } + +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + out_frag[qg][dt][elem] *= alpha_rows[elem]; + } + } + + float local_sum = 0.0f; +#pragma unroll + for (int col_tile = 0; col_tile < ColTiles; ++col_tile) { + half8_vec prob_values; + const int64_t k_col_start = kb_base + col_tile * BK; + const bool fully_future = + ApplyCausalMask && k_col_start >= q_start[qg] + RM; + const float8_vec scores = score_cache[col_tile]; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + float prob = 0.0f; + if (!fully_future) { + prob = fast_exp2(scores[elem] - m[qg] + kF16SoftmaxOffset); + local_sum += prob; + } + prob_values[elem] = static_cast<_Float16>(prob); + } + prob_cache[qg][col_tile] = make_p_regs_from_tqk_prob_regs(prob_values, lane); + } + l[qg] += local_sum + __shfl_xor(local_sum, 16, 32); + } + } + + SAGEATTN_F16_SCHED_BARRIER(0); +#pragma unroll + for (int col_tile = 0; col_tile < ColTiles; ++col_tile) { + if constexpr (ApplyCausalMask) { + if (kb_base + col_tile * BK >= q_start[QGroups - 1] + RM) { + continue; + } + } +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + half8_vec v_regs; + if constexpr (UseLaneMajorValue) { + v_regs = make_v_regs_from_lane_major_shared( + v_lane_tile, col_tile, dt, lane); + } else if constexpr (UseTransposedValueLayout) { + v_regs = make_v_regs_from_transposed_shared( + &v_tile[0][0], col_tile, dt, lane); + } else { + v_regs = make_v_regs_from_shared( + &v_tile[0][0], col_tile, dt, lane); + } + const bool fully_future0 = + ApplyCausalMask && kb_base + col_tile * BK >= q_start[0] + RM; + const half8_vec p_regs0 = prob_cache[0][col_tile]; + if (!fully_future0) { + PvAccumVec acc0; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + acc0[elem] = 
out_frag[0][dt][elem]; + } + PvAccumVec pv_acc0; + if constexpr (F16PvAccum) { + pv_acc0 = + __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(p_regs0, v_regs, acc0); + } else { + pv_acc0 = + __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(p_regs0, v_regs, acc0); + } +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + out_frag[0][dt][elem] = pv_acc0[elem]; + } + } + const bool fully_future1 = + ApplyCausalMask && kb_base + col_tile * BK >= q_start[1] + RM; + const half8_vec p_regs1 = prob_cache[1][col_tile]; + if (!fully_future1) { + PvAccumVec acc1; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + acc1[elem] = out_frag[1][dt][elem]; + } + PvAccumVec pv_acc1; + if constexpr (F16PvAccum) { + pv_acc1 = + __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(p_regs1, v_regs, acc1); + } else { + pv_acc1 = + __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(p_regs1, v_regs, acc1); + } +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + out_frag[1][dt][elem] = pv_acc1[elem]; + } + } + } + } + } else { + const bool needs_causal_mask = ApplyCausalMask && (kb_base + BC > q_base); +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + float local_max = -FLT_MAX * 0.5f; +#pragma unroll + for (int col_tile = 0; col_tile < ColTiles; ++col_tile) { + float k_scale_local = prepared_k_scale_tile; + if constexpr (!QuantizeKey && BC > 64) { + const int k_scale_idx = k_scale_col_per_warp(kb_base + col_tile * BK); + k_scale_local = + k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + } + const float score_scale = qs[qg] * k_scale_local; + float8_vec scores; + if constexpr (QuantizeQuery || UseRawPreparedQ) { + scores = compute_tqk_score_regs_raw_kq( + &k_tile[0][0], q_regs[qg], lane, col_tile, score_scale); + } else { + scores = + compute_tqk_score_regs( + &k_tile[0][0], q_frag[qg], col_tile, score_scale); + } + if constexpr (ApplyCausalMask) { + if (needs_causal_mask) { + apply_tqk_causal_mask( + scores, static_cast(q_start[qg]), static_cast(kb_base), 
col_tile, lane); + } + } + if (kb_base + static_cast(col_tile) * BK + BK > kv_len) { + apply_tqk_kv_tail_mask(scores, kv_len, kb_base, col_tile, lane); + } +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + local_max = fmaxf(local_max, scores[elem]); + } + } + const float tile_max = fmaxf(local_max, __shfl_xor(local_max, 16, 32)); + const float old_m = m[qg]; + const float new_m = fmaxf(old_m, tile_max); + const float alpha = l[qg] == 0.0f ? 0.0f : fast_exp2(old_m - new_m); + m[qg] = new_m; + l[qg] *= alpha; + + float8_vec alpha_rows; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + alpha_rows[elem] = __shfl(alpha, row_base + elem, 32); + } + +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + out_frag[qg][dt][elem] *= alpha_rows[elem]; + } + } + + float local_sum = 0.0f; +#pragma unroll + for (int col_tile = 0; col_tile < ColTiles; ++col_tile) { + float k_scale_local = prepared_k_scale_tile; + if constexpr (!QuantizeKey && BC > 64) { + const int k_scale_idx = k_scale_col_per_warp(kb_base + col_tile * BK); + k_scale_local = + k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + } + const float score_scale = qs[qg] * k_scale_local; + float8_vec scores; + if constexpr (QuantizeQuery || UseRawPreparedQ) { + scores = compute_tqk_score_regs_raw_kq( + &k_tile[0][0], q_regs[qg], lane, col_tile, score_scale); + } else { + scores = + compute_tqk_score_regs( + &k_tile[0][0], q_frag[qg], col_tile, score_scale); + } + if constexpr (ApplyCausalMask) { + if (needs_causal_mask) { + apply_tqk_causal_mask( + scores, static_cast(q_start[qg]), static_cast(kb_base), col_tile, lane); + } + } + if (kb_base + static_cast(col_tile) * BK + BK > kv_len) { + apply_tqk_kv_tail_mask(scores, kv_len, kb_base, col_tile, lane); + } + half8_vec prob_values; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + const float prob = fast_exp2(scores[elem] - m[qg] + kF16SoftmaxOffset); + local_sum += prob; + 
prob_values[elem] = static_cast<_Float16>(prob); + } + + const half8_vec p_regs = make_p_regs_from_tqk_prob_regs(prob_values, lane); +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + half8_vec v_regs; + if constexpr (UseLaneMajorValue) { + v_regs = make_v_regs_from_lane_major_shared( + v_lane_tile, col_tile, dt, lane); + } else if constexpr (UseTransposedValueLayout) { + v_regs = make_v_regs_from_transposed_shared( + &v_tile[0][0], col_tile, dt, lane); + } else { + v_regs = make_v_regs_from_shared( + &v_tile[0][0], col_tile, dt, lane); + } + PvAccumVec acc; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + acc[elem] = out_frag[qg][dt][elem]; + } + PvAccumVec pv_acc; + if constexpr (F16PvAccum) { + pv_acc = + __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(p_regs, v_regs, acc); + } else { + pv_acc = + __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(p_regs, v_regs, acc); + } +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + out_frag[qg][dt][elem] = pv_acc[elem]; + } + } + } + l[qg] += local_sum + __shfl_xor(local_sum, 16, 32); + } + } + }; + compute_loaded_tile(apply_causal_mask_tag); + if constexpr (UsesTileSharedMemory) { + __syncthreads(); + } + }; + + if constexpr (IsCausal) { + const int64_t prefix_limit = q_base < kv_limit ? q_base : kv_limit; +#pragma unroll 2 + for (int64_t kb_base = 0; kb_base < prefix_limit; kb_base += BC) { + process_kv_tile(kb_base, std::false_type{}); + } +#pragma unroll 2 + for (int64_t kb_base = prefix_limit; kb_base < kv_limit; kb_base += BC) { + process_kv_tile(kb_base, std::true_type{}); + } + } else { +#pragma unroll 2 + for (int64_t kb_base = 0; kb_base < kv_limit; kb_base += BC) { + process_kv_tile(kb_base, std::false_type{}); + } + } + +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + float8_vec inv_l_rows; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + const float l_sum = __shfl(l[qg], row_base + elem, 32); + inv_l_rows[elem] = l_sum == 0.0f ? 
0.0f : 1.0f / l_sum; + } +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + const int d = dt * BK + col; +#pragma unroll + for (int pair = 0; pair < PackedRows; ++pair) { + const int elem = pair * 2; + const int64_t q_idx0 = q_start[qg] + row_base + elem; + const int64_t q_idx1 = q_idx0 + 1; + const float value0 = static_cast(out_frag[qg][dt][elem]) * inv_l_rows[elem]; + const float value1 = static_cast(out_frag[qg][dt][elem + 1]) * inv_l_rows[elem + 1]; + int64_t o_off0; + int64_t o_off1; + if constexpr (HndContiguous) { + o_off0 = o_head_base + q_idx0 * HeadDim + d; + o_off1 = o_head_base + q_idx1 * HeadDim + d; + } else { + o_off0 = qkv_offset( + tensor_layout, b, hq, q_idx0, d, o_stride_b, o_stride_n, o_stride_h); + o_off1 = qkv_offset( + tensor_layout, b, hq, q_idx1, d, o_stride_b, o_stride_n, o_stride_h); + } + store_half(output, o_off0, value0); + store_half(output, o_off1, value1); + } + } + } + return; + } +} + +template +SAGEATTN_NATIVE_2Q_WAVES_PER_EU(HeadDim, IsCausal) +__global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_f8_native_2q_kernel( + const QueryT* __restrict__ q, + const KeyT* __restrict__ k, + const ValueT* __restrict__ v, + OutputT* __restrict__ output, + const float* __restrict__ q_scale, + const float* __restrict__ k_scale, + const float* __restrict__ v_scale, + const int64_t batch_size, + const int64_t qo_len, + const int64_t kv_len, + const int64_t num_qo_heads, + const int64_t num_kv_heads, + const int64_t q_stride_b, + const int64_t q_stride_n, + const int64_t q_stride_h, + const int64_t k_stride_b, + const int64_t k_stride_n, + const int64_t k_stride_h, + const int64_t v_stride_b, + const int64_t v_stride_n, + const int64_t v_stride_h, + const int64_t o_stride_b, + const int64_t o_stride_n, + const int64_t o_stride_h, + const int64_t qs_stride_b, + const int64_t qs_stride_h, + const int64_t ks_stride_b, + const int64_t ks_stride_h, + const int tensor_layout, + const float sm_scale) { + 
static_assert(HeadDim == 16 || HeadDim == 64 || HeadDim == 128, + "native gfx12 fp8 2q kernel supports D16/D64/D128."); + static_assert(BlockCols == 16 || BlockCols == 32 || BlockCols == 64 || + BlockCols == 128, + "native gfx12 fp8 2q kernel supports BC16/BC32/BC64/BC128."); + static_assert(BlockRows == 32 || BlockRows == 64 || BlockRows == 128 || + BlockRows == 256 || BlockRows == 512, + "native fp8 2q supports 32, 64, 128, 256, or 512 query rows per CTA."); + static_assert(QGroupsParam == 1 || QGroupsParam == 2, + "native fp8 kernel supports one or two query groups per wave."); + static_assert(!LowPressureQGroups || + (HeadDim == 128 && BlockCols == 64 && BlockRows == 128 && + IsCausal && ValueTransposed && QGroupsParam == 2), + "low-pressure fp8 path is specialized for D128 BC64 BR128 causal."); + constexpr int BR = BlockRows; + constexpr int RM = 16; + constexpr int QGroups = QGroupsParam; + constexpr int RowsPerWave = RM * QGroups; + constexpr int BK = 16; + constexpr int BC = BlockCols; + constexpr int Threads = BlockRows * (2 / QGroups); + constexpr int DTiles = HeadDim / BK; + constexpr int ColTiles = BC / BK; + constexpr int SharedHeadStride = HeadDim + 16; + constexpr bool PackedTransposedValue = ValueTransposed && HeadDim == 64; + constexpr int SharedValueRows = ValueTransposed && !PackedTransposedValue ? HeadDim : (ValueTransposed ? 1 : BC); + constexpr int SharedValueStride = ValueTransposed && !PackedTransposedValue ? (BC + 16) : (ValueTransposed ? 1 : SharedHeadStride); + constexpr int SharedValueRowsT = PackedTransposedValue ? HeadDim : 1; + constexpr int SharedValueStrideT = PackedTransposedValue ? 
(BC / 4 + 1) : 1; + constexpr int PackedRows = 4; + constexpr bool UseLowPressureQGroups = + LowPressureQGroups && HeadDim == 128 && BlockCols == 64 && + BlockRows == 128 && IsCausal && ValueTransposed && + !QuantizeQuery && !QuantizeKeyValue && QGroups == 2; + constexpr bool UsePrepackedLaneMajorKV = + PrepackedLaneMajorKV && HeadDim == 64 && + BlockCols == 64 && HndContiguous && + ValueTransposed && !QuantizeKeyValue; + constexpr bool UsePrepackedLaneMajorKey = + PrepackedLaneMajorKeyOnly && (HeadDim == 64 || HeadDim == 128) && + BlockCols == 64 && HndContiguous && + ValueTransposed && !QuantizeKeyValue; + constexpr bool UsePrepackedLaneMajorK = + UsePrepackedLaneMajorKV || UsePrepackedLaneMajorKey; + constexpr bool UsePrepackedLaneMajorValue = + UsePrepackedLaneMajorKV || + (PrepackedLaneMajorValueOnly && (HeadDim == 64 || HeadDim == 128) && + BlockCols == 64 && HndContiguous && + ValueTransposed && !QuantizeKeyValue); + constexpr bool UsesTileSharedMemory = + !UsePrepackedLaneMajorK || !UsePrepackedLaneMajorValue; + constexpr bool PreloadQFragments = + !UseLowPressureQGroups && !UsePrepackedLaneMajorK && (HeadDim == 64 || + (HeadDim == 128 && IsCausal && BlockRows == 128 && + (BlockCols == 64 || BlockCols == 128) && + ValueTransposed)); + constexpr bool UseStreamedFp8Pv = + StreamColsOverride >= 0 && + (BlockCols == 64 || (BlockCols == 128 && HeadDim == 128)) && + ValueTransposed && + (QuantizeQuery || (HeadDim == 128 && BlockCols == 128 && + PreloadQFragments)) && + (HeadDim == 128 || + (HeadDim == 64 && IsCausal && + (UsePrepackedLaneMajorK || UsePrepackedLaneMajorValue))); + static_assert(!QuantizeKeyValue || + ((HeadDim == 64 || HeadDim == 128) && + BlockCols == 64 && HndContiguous && !ValueTransposed), + "raw K/V fp8 staging currently supports contiguous HND D64/D128 BC64 tensors."); + static_assert(!UsePrepackedLaneMajorKV || + (HeadDim == 64 && HndContiguous && ValueTransposed && + !QuantizeKeyValue), + "lane-major prepared fp8 K/V requires prepared 
transposed D64 HND tensors."); + static_assert(!UsePrepackedLaneMajorKey || + ((HeadDim == 64 || HeadDim == 128) && + HndContiguous && ValueTransposed && + !QuantizeKeyValue), + "lane-major prepared fp8 K requires prepared transposed D64/D128 HND tensors."); + static_assert(!UsePrepackedLaneMajorValue || + (HndContiguous && ValueTransposed && !QuantizeKeyValue), + "lane-major prepared fp8 V requires prepared transposed HND tensors."); + static_assert(ValueTiles == 1 || ValueTiles == 4 || ValueTiles == 8, + "native fp8 2q stores one D16, D64, or D128 value slice per launch."); + static_assert(ValueTileBase + ValueTiles <= DTiles, "invalid fp8 value tile slice."); + + __shared__ int8_t k_tile[UsePrepackedLaneMajorK ? 1 : BC][SharedHeadStride]; + __shared__ uint8_t v_tile[UsePrepackedLaneMajorValue ? 1 : SharedValueRows] + [UsePrepackedLaneMajorValue ? 1 : SharedValueStride]; + __shared__ uint32_t v_tile_t[UsePrepackedLaneMajorValue ? 1 : SharedValueRowsT] + [UsePrepackedLaneMajorValue ? 1 : SharedValueStrideT]; + __shared__ float raw_k_amax_shared; + + const int tid = threadIdx.x; + const int lane = tid & 31; + const int wave = tid >> 5; + const int row_base = (lane >> 4) << 3; + const int col = lane & 15; + const int64_t q_base = + q_block_base_for_launch(static_cast(blockIdx.x), qo_len); + const int64_t hq = blockIdx.y; + const int64_t b = blockIdx.z; + if (b >= batch_size || hq >= num_qo_heads || q_base >= qo_len) { + return; + } + + const int64_t hkv = hq / (num_qo_heads / num_kv_heads); + const int64_t k_head_base = b * k_stride_b + hkv * k_stride_h; + const int64_t v_head_base = b * v_stride_b + hkv * v_stride_h; + int64_t q_start[QGroups]; + float qs[QGroups]; + const int64_t wave_q_start = q_base + static_cast(wave) * RowsPerWave; +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + q_start[qg] = wave_q_start + qg * RM; + } + + i32x2_vec q_regs[QGroups][DTiles]; + if constexpr (QuantizeQuery) { + constexpr int QPackElems = 8; + constexpr int 
QPacksPerWave = (RowsPerWave * HeadDim) / QPackElems; + const int local_q_row_base = wave * RowsPerWave; + float local_q_amax = 0.0000001f; + for (int pack = lane; pack < QPacksPerWave; pack += 32) { + const int elem_base = pack * QPackElems; + const int row = elem_base / HeadDim; + const int d = elem_base - row * HeadDim; + const int64_t q_idx = q_base + local_q_row_base + row; + if (q_idx < qo_len) { + const int64_t q_off = qkv_offset_dispatch( + tensor_layout, b, hq, q_idx, d, q_stride_b, q_stride_n, q_stride_h); + const uint4 raw = *reinterpret_cast(q + q_off); + const QueryT* values = reinterpret_cast(&raw); +#pragma unroll + for (int i = 0; i < QPackElems; ++i) { + local_q_amax = fmaxf(local_q_amax, fabsf(value_to_float(values[i]))); + } + } + } + local_q_amax = vllm::warpReduceMax(local_q_amax); + float q_amax_for_scale = __shfl(local_q_amax, 0, 32); + if constexpr (QGroups == 1) { + __shared__ float q_amax_shared[Threads / 32]; + if (lane == 0) { + q_amax_shared[wave] = q_amax_for_scale; + } + __syncthreads(); + const int pair_wave = wave & ~1; + q_amax_for_scale = + fmaxf(q_amax_shared[pair_wave], q_amax_shared[pair_wave + 1]); + __syncthreads(); + } + const float q_scale_local = q_amax_for_scale / 127.0f; + const float inv_q_scale = 127.0f / q_amax_for_scale; +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + qs[qg] = q_scale_local * sm_scale * kLog2e; + const int64_t qg_start = q_start[qg]; +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + q_regs[qg][dt] = pack_quant_q_i8_wmma_b_regs( + q, tensor_layout, lane, b, hq, qg_start, qo_len, dt * BK, + q_stride_b, q_stride_n, q_stride_h, inv_q_scale); + } + } + } else { +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + const int q_scale_idx = q_scale_col_per_warp(q_start[qg]); + qs[qg] = q_scale[b * qs_stride_b + hq * qs_stride_h + q_scale_idx] * + sm_scale * kLog2e; + } + if constexpr (UsePrepackedLaneMajorK) { +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { +#pragma unroll + 
for (int dt = 0; dt < DTiles; ++dt) { + q_regs[qg][dt] = pack_q_i8_wmma_b_regs( + q, tensor_layout, lane, b, hq, q_start[qg], qo_len, dt * BK, + q_stride_b, q_stride_n, q_stride_h); + } + } + } + } + + using FragK = rocwmma::fragment; + using FragQ = rocwmma::fragment; + using FragScoreT = rocwmma::fragment; + + FragQ q_frag[QGroups][DTiles]; + if constexpr (PreloadQFragments && !QuantizeQuery) { +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + const int64_t q_off = qkv_offset_dispatch( + tensor_layout, b, hq, q_start[qg], dt * BK, q_stride_b, q_stride_n, q_stride_h); + rocwmma::load_matrix_sync(q_frag[qg][dt], q + q_off, static_cast(q_stride_n)); + } + } + } + + float out_frag[QGroups][ValueTiles][8]; + float m[QGroups]; + float l[QGroups]; +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + m[qg] = -FLT_MAX * 0.5f; + l[qg] = 0.0f; +#pragma unroll + for (int vdt = 0; vdt < ValueTiles; ++vdt) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + out_frag[qg][vdt][elem] = 0.0f; + } + } + } + + const int64_t kv_limit = IsCausal && (q_base + BR) < kv_len ? 
q_base + BR : kv_len; + auto process_kv_tile = [&](const int64_t kb_base, auto causal_mask_tag) { + constexpr bool ApplyCausalMask = decltype(causal_mask_tag)::value; + constexpr int VecBytes = 16; + constexpr int VecsPerRow = HeadDim / VecBytes; + constexpr bool UseActiveCausalColSkip = false; + if constexpr (UsePrepackedLaneMajorK) { + if constexpr (ApplyCausalMask) { + if (kb_base >= q_start[QGroups - 1] + RM) { + return; + } + } + } + float k_scale_tile = 1.0f; + if constexpr (QuantizeKeyValue) { + constexpr int PackElems = 8; + constexpr int Packs = (BC * HeadDim) / PackElems; + float local_k_amax = 0.0000001f; + for (int pack = tid; pack < Packs; pack += Threads) { + const int elem_base = pack * PackElems; + const int n = elem_base / HeadDim; + const int d = elem_base - n * HeadDim; + const int64_t k_off = qkv_offset_dispatch( + tensor_layout, b, hkv, kb_base + n, d, k_stride_b, k_stride_n, k_stride_h); + const uint4 raw = *reinterpret_cast(k + k_off); + const KeyT* values = reinterpret_cast(&raw); +#pragma unroll + for (int i = 0; i < PackElems; ++i) { + local_k_amax = fmaxf(local_k_amax, fabsf(value_to_float(values[i]))); + } + } + const float block_k_amax = vllm::blockReduceMax(local_k_amax); + if (tid == 0) { + raw_k_amax_shared = block_k_amax; + } + __syncthreads(); + const float raw_k_amax = raw_k_amax_shared; + k_scale_tile = raw_k_amax / 127.0f; + const float inv_k_scale = 127.0f / raw_k_amax; + + for (int pack = tid; pack < Packs; pack += Threads) { + const int elem_base = pack * PackElems; + const int n = elem_base / HeadDim; + const int d = elem_base - n * HeadDim; + const int64_t k_off = qkv_offset_dispatch( + tensor_layout, b, hkv, kb_base + n, d, k_stride_b, k_stride_n, k_stride_h); + const int64_t v_off = qkv_offset_dispatch( + tensor_layout, b, hkv, kb_base + n, d, v_stride_b, v_stride_n, v_stride_h); + const uint4 raw_k = *reinterpret_cast(k + k_off); + const uint4 raw_v = *reinterpret_cast(v + v_off); + const KeyT* k_values = 
reinterpret_cast(&raw_k); + const ValueT* v_values = reinterpret_cast(&raw_v); + char4 out0; + char4 out1; + out0.x = float_to_int8_rn_gfx12(value_to_float(k_values[0]) * inv_k_scale); + out0.y = float_to_int8_rn_gfx12(value_to_float(k_values[1]) * inv_k_scale); + out0.z = float_to_int8_rn_gfx12(value_to_float(k_values[2]) * inv_k_scale); + out0.w = float_to_int8_rn_gfx12(value_to_float(k_values[3]) * inv_k_scale); + out1.x = float_to_int8_rn_gfx12(value_to_float(k_values[4]) * inv_k_scale); + out1.y = float_to_int8_rn_gfx12(value_to_float(k_values[5]) * inv_k_scale); + out1.z = float_to_int8_rn_gfx12(value_to_float(k_values[6]) * inv_k_scale); + out1.w = float_to_int8_rn_gfx12(value_to_float(k_values[7]) * inv_k_scale); + *reinterpret_cast(&k_tile[n][d]) = out0; + *reinterpret_cast(&k_tile[n][d + 4]) = out1; + + const uint32_t v_pack0 = static_cast(pack_f32x4_to_ocp_fp8( + value_to_float(v_values[0]), value_to_float(v_values[1]), + value_to_float(v_values[2]), value_to_float(v_values[3]))); + const uint32_t v_pack1 = static_cast(pack_f32x4_to_ocp_fp8( + value_to_float(v_values[4]), value_to_float(v_values[5]), + value_to_float(v_values[6]), value_to_float(v_values[7]))); + *reinterpret_cast(&v_tile[n][d]) = v_pack0; + *reinterpret_cast(&v_tile[n][d + 4]) = v_pack1; + } + } else if constexpr (!UsePrepackedLaneMajorK) { + for (int vec = tid; vec < BC * VecsPerRow; vec += Threads) { + const int n = vec / VecsPerRow; + const int d = (vec - n * VecsPerRow) * VecBytes; + const int64_t k_off = qkv_offset_dispatch( + tensor_layout, b, hkv, kb_base + n, d, k_stride_b, k_stride_n, k_stride_h); + *reinterpret_cast(&k_tile[n][d]) = + *reinterpret_cast(k + k_off); + if constexpr (!ValueTransposed) { + const int64_t v_off = qkv_offset_dispatch( + tensor_layout, b, hkv, kb_base + n, d, v_stride_b, v_stride_n, v_stride_h); + *reinterpret_cast(&v_tile[n][d]) = + *reinterpret_cast(v + v_off); + } + } + } + if constexpr (ValueTransposed && !QuantizeKeyValue && 
!UsePrepackedLaneMajorValue) { + constexpr int VVecBytes = 16; + constexpr int VVecsPerD = BC / VVecBytes; + for (int vec = tid; vec < HeadDim * VVecsPerD; vec += Threads) { + const int d = vec / VVecsPerD; + const int n = (vec - d * VVecsPerD) * VVecBytes; + const int64_t v_off = b * v_stride_b + hkv * v_stride_h + + static_cast(d) * v_stride_n + kb_base + n; + const uint4 packed = *reinterpret_cast(v + v_off); + if constexpr (PackedTransposedValue) { + const int group = n >> 2; + v_tile_t[d][group + 0] = packed.x; + v_tile_t[d][group + 1] = packed.y; + v_tile_t[d][group + 2] = packed.z; + v_tile_t[d][group + 3] = packed.w; + } else { + *reinterpret_cast(&v_tile[d][n]) = packed; + } + } + } + if constexpr (UsesTileSharedMemory) { + __syncthreads(); + } + + if constexpr (!PreloadQFragments && !QuantizeQuery && !UsePrepackedLaneMajorK) { +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + const int64_t q_off = qkv_offset_dispatch( + tensor_layout, b, hq, q_start[qg], dt * BK, q_stride_b, q_stride_n, q_stride_h); + rocwmma::load_matrix_sync(q_frag[qg][dt], q + q_off, static_cast(q_stride_n)); + } + } + } + + int active_cols0 = ColTiles; + int active_cols1 = ColTiles; + if constexpr (UseActiveCausalColSkip) { + active_cols0 = + active_causal_col_tiles(q_start[0], kb_base); + active_cols1 = + active_causal_col_tiles(q_start[1], kb_base); + } + const int active_cols_any = active_cols0 > active_cols1 ? active_cols0 : active_cols1; + + if constexpr (BlockCols <= 64 || UseStreamedFp8Pv) { + i32x2_vec prob_cache[QGroups][ColTiles]; + if constexpr (QuantizeQuery || UsePrepackedLaneMajorKV || UseStreamedFp8Pv) { + if constexpr (UseStreamedFp8Pv) { + constexpr int StreamCols = StreamColsOverride > 0 ? 
StreamColsOverride : 2; + static_assert(StreamCols == 1 || StreamCols == 2 || StreamCols == 4, + "fp8 streaming supports one, two, or four col tiles per group."); +#pragma unroll + for (int stream_col = 0; stream_col < ColTiles; stream_col += StreamCols) { + float8_vec score_cache_stream[QGroups][StreamCols]; + i32x2_vec prob_cache_stream[QGroups][StreamCols]; + float local_max_stream[QGroups]; +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + local_max_stream[qg] = -FLT_MAX * 0.5f; + } + float prepared_k_scale_tile = k_scale_tile; + if constexpr (!QuantizeKeyValue && BC <= 64) { + const int k_scale_idx = k_scale_col_per_warp(kb_base); + prepared_k_scale_tile = + k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + } +#pragma unroll + for (int sc = 0; sc < StreamCols; ++sc) { + const int col_tile = stream_col + sc; + if constexpr (UseActiveCausalColSkip) { + if (col_tile >= active_cols_any) { + continue; + } + } + const int64_t k_col_start = kb_base + col_tile * BK; + const bool skip_all = + ApplyCausalMask && k_col_start >= q_start[QGroups - 1] + RM; + if (skip_all) { +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + float8_vec scores; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + scores[elem] = -FLT_MAX * 0.5f; + } + score_cache_stream[qg][sc] = scores; + } + continue; + } + bool fully_future[QGroups]; +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + if constexpr (UseActiveCausalColSkip) { + const int active_cols = qg == 0 ? 
active_cols0 : active_cols1; + fully_future[qg] = col_tile >= active_cols; + } else { + fully_future[qg] = + ApplyCausalMask && k_col_start >= q_start[qg] + RM; + } + } + float k_scale_local = k_scale_tile; + if constexpr (!QuantizeKeyValue && BC <= 64) { + k_scale_local = prepared_k_scale_tile; + } else if constexpr (!QuantizeKeyValue) { + const int k_scale_idx = k_scale_col_per_warp(k_col_start); + k_scale_local = k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + } + if constexpr (QuantizeQuery || UsePrepackedLaneMajorK) { + i32x8_vec score_acc[QGroups]; +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + score_acc[qg][elem] = 0; + } + } +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + i32x2_vec k_regs; + if constexpr (UsePrepackedLaneMajorK) { + k_regs = pack_k_i8_wmma_b_regs_from_lane_major_global( + k, k_head_base, k_stride_n * 64, kb_base, col_tile, dt, lane); + } else { + k_regs = pack_k_i8_wmma_b_regs_from_shared( + &k_tile[0][0], lane, col_tile, dt * BK); + } +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + if (!fully_future[qg]) { + score_acc[qg] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, k_regs, true, q_regs[qg][dt], score_acc[qg], true); + } + } + } +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + const float score_scale = qs[qg] * k_scale_local; + float8_vec scores; + if (fully_future[qg]) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + scores[elem] = -FLT_MAX * 0.5f; + } + } else { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + scores[elem] = static_cast(score_acc[qg][elem]) * score_scale; + } + if constexpr (ApplyCausalMask) { + const bool needs_causal_mask = k_col_start + BK > q_start[qg]; + if (needs_causal_mask) { + apply_tqk_causal_mask( + scores, static_cast(q_start[qg]), + static_cast(kb_base), col_tile, lane); + } + } + } + if (k_col_start + BK > kv_len) { + apply_tqk_kv_tail_mask(scores, kv_len, 
kb_base, col_tile, lane); + } + score_cache_stream[qg][sc] = scores; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + local_max_stream[qg] = fmaxf(local_max_stream[qg], scores[elem]); + } + } + } else { + FragScoreT score_acc[QGroups]; +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + rocwmma::fill_fragment(score_acc[qg], 0); + } +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + FragK k_frag; + rocwmma::load_matrix_sync( + k_frag, + &k_tile[0][0] + (col_tile * BK) * SharedHeadStride + dt * BK, + static_cast(SharedHeadStride)); +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + if (!fully_future[qg]) { + rocwmma::mma_sync(score_acc[qg], k_frag, q_frag[qg][dt], score_acc[qg]); + } + } + } +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + const float score_scale = qs[qg] * k_scale_local; + float8_vec scores; + if (fully_future[qg]) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + scores[elem] = -FLT_MAX * 0.5f; + } + } else { + const auto score_rm = + rocwmma::apply_data_layout(score_acc[qg]); +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + scores[elem] = static_cast(score_rm[elem]) * score_scale; + } + if constexpr (ApplyCausalMask) { + const bool needs_causal_mask = k_col_start + BK > q_start[qg]; + if (needs_causal_mask) { + apply_tqk_causal_mask( + scores, static_cast(q_start[qg]), + static_cast(kb_base), col_tile, lane); + } + } + } + if (k_col_start + BK > kv_len) { + apply_tqk_kv_tail_mask(scores, kv_len, kb_base, col_tile, lane); + } + score_cache_stream[qg][sc] = scores; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + local_max_stream[qg] = fmaxf(local_max_stream[qg], scores[elem]); + } + } + } + } + +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + const float tile_max = + fmaxf(local_max_stream[qg], __shfl_xor(local_max_stream[qg], 16, 32)); + const float old_m = m[qg]; + const float new_m = fmaxf(old_m, tile_max); + const bool has_previous_sum = l[qg] != 0.0f; + const 
float alpha = has_previous_sum ? fast_exp2(old_m - new_m) : 0.0f; + m[qg] = new_m; + l[qg] *= alpha; + + if (has_previous_sum) { + float alpha_rows[8]; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + alpha_rows[elem] = __shfl(alpha, row_base + elem, 32); + } +#pragma unroll + for (int vdt = 0; vdt < ValueTiles; ++vdt) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + out_frag[qg][vdt][elem] *= alpha_rows[elem]; + } + } + } + + float local_sum = 0.0f; +#pragma unroll + for (int sc = 0; sc < StreamCols; ++sc) { + const int col_tile = stream_col + sc; + if constexpr (UseActiveCausalColSkip) { + const int active_cols = qg == 0 ? active_cols0 : active_cols1; + if (col_tile >= active_cols) { + continue; + } + } + const int64_t k_col_start = kb_base + col_tile * BK; + bool fully_future = + ApplyCausalMask && k_col_start >= q_start[qg] + RM; + if constexpr (UseActiveCausalColSkip) { + fully_future = false; + } + const float8_vec scores = score_cache_stream[qg][sc]; + float8_vec prob_values; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + float prob = 0.0f; + if (!fully_future) { + prob = fast_exp2(scores[elem] - m[qg] + kFp8SoftmaxOffset); + local_sum += prob; + } + prob_values[elem] = prob; + } + prob_cache_stream[qg][sc] = make_p_fp8_regs_from_tqk_prob_regs( + prob_values, lane); + } + l[qg] += local_sum + __shfl_xor(local_sum, 16, 32); + } + +#pragma unroll + for (int sc = 0; sc < StreamCols; ++sc) { + const int col_tile = stream_col + sc; + if constexpr (UseActiveCausalColSkip) { + if (col_tile >= active_cols_any) { + continue; + } + } + const int64_t k_col_start = kb_base + col_tile * BK; + bool fully_future[QGroups]; + bool skip_all = true; +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + if constexpr (UseActiveCausalColSkip) { + const int active_cols = qg == 0 ? 
active_cols0 : active_cols1; + fully_future[qg] = col_tile >= active_cols; + } else { + fully_future[qg] = + ApplyCausalMask && k_col_start >= q_start[qg] + RM; + } + skip_all = skip_all && fully_future[qg]; + } + if (skip_all) { + continue; + } +#pragma unroll + for (int vdt = 0; vdt < ValueTiles; ++vdt) { + const int dt = ValueTileBase + vdt; + i32x2_vec v_regs; + if constexpr (UsePrepackedLaneMajorValue) { + v_regs = make_v_fp8_regs_from_lane_major_global( + v, v_head_base, kb_base, col_tile, dt, lane); + } else if constexpr (PackedTransposedValue) { + v_regs = make_v_fp8_regs_from_transposed_shared( + &v_tile_t[0][0], col_tile, dt, lane); + } else { + v_regs = make_v_fp8_regs_from_transposed_shared( + &v_tile[0][0], col_tile, dt, lane); + } +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + if (fully_future[qg]) { + continue; + } + float8_vec acc; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + acc[elem] = out_frag[qg][vdt][elem]; + } + acc = __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12( + prob_cache_stream[qg][sc], v_regs, acc); +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + out_frag[qg][vdt][elem] = acc[elem]; + } + } + } + } + } + } else { + float8_vec score_cache[QGroups][ColTiles]; + float local_max[QGroups]; + float prepared_k_scale_tile = k_scale_tile; + if constexpr (!QuantizeKeyValue && BC <= 64) { + const int k_scale_idx = k_scale_col_per_warp(kb_base); + prepared_k_scale_tile = + k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + } +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + local_max[qg] = -FLT_MAX * 0.5f; + } +#pragma unroll + for (int col_tile = 0; col_tile < ColTiles; ++col_tile) { + if constexpr (UseActiveCausalColSkip) { + if (col_tile >= active_cols_any) { + continue; + } + } + const int64_t k_col_start = kb_base + col_tile * BK; + const bool skip_all = + ApplyCausalMask && k_col_start >= q_start[QGroups - 1] + RM; + if (skip_all) { +#pragma unroll + for (int qg = 0; qg < 
QGroups; ++qg) { + float8_vec scores; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + scores[elem] = -FLT_MAX * 0.5f; + } + score_cache[qg][col_tile] = scores; + } + continue; + } + bool q0_fully_future = + ApplyCausalMask && k_col_start >= q_start[0] + RM; + bool q1_fully_future = + ApplyCausalMask && k_col_start >= q_start[1] + RM; + if constexpr (UseActiveCausalColSkip) { + q0_fully_future = col_tile >= active_cols0; + q1_fully_future = col_tile >= active_cols1; + } + float k_scale_local = k_scale_tile; + if constexpr (!QuantizeKeyValue && BC <= 64) { + k_scale_local = prepared_k_scale_tile; + } else if constexpr (!QuantizeKeyValue) { + const int k_scale_idx = k_scale_col_per_warp(k_col_start); + k_scale_local = k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + } + i32x8_vec score_acc[QGroups]; +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + score_acc[qg][elem] = 0; + } + } +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + i32x2_vec k_regs; + if constexpr (UsePrepackedLaneMajorK) { + k_regs = pack_k_i8_wmma_b_regs_from_lane_major_global( + k, k_head_base, k_stride_n * 64, kb_base, col_tile, dt, lane); + } else { + k_regs = pack_k_i8_wmma_b_regs_from_shared( + &k_tile[0][0], lane, col_tile, dt * BK); + } + if (!q0_fully_future) { + score_acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, k_regs, true, q_regs[0][dt], score_acc[0], true); + } + if (!q1_fully_future) { + score_acc[1] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + true, k_regs, true, q_regs[1][dt], score_acc[1], true); + } + } +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + const float score_scale = qs[qg] * k_scale_local; + float8_vec scores; + const bool fully_future = + qg == 0 ? 
q0_fully_future : q1_fully_future; + if constexpr (UseActiveCausalColSkip) { + if (fully_future) { + continue; + } + } + if (fully_future) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + scores[elem] = -FLT_MAX * 0.5f; + } + } else { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + scores[elem] = static_cast(score_acc[qg][elem]) * score_scale; + } + if constexpr (ApplyCausalMask) { + const bool needs_causal_mask = k_col_start + BK > q_start[qg]; + if (needs_causal_mask) { + apply_tqk_causal_mask( + scores, static_cast(q_start[qg]), static_cast(kb_base), col_tile, lane); + } + } + } + if (k_col_start + BK > kv_len) { + apply_tqk_kv_tail_mask(scores, kv_len, kb_base, col_tile, lane); + } + score_cache[qg][col_tile] = scores; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + local_max[qg] = fmaxf(local_max[qg], scores[elem]); + } + } + } + +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + const float tile_max = fmaxf(local_max[qg], __shfl_xor(local_max[qg], 16, 32)); + const float old_m = m[qg]; + const float new_m = fmaxf(old_m, tile_max); + const bool has_previous_sum = l[qg] != 0.0f; + const float alpha = has_previous_sum ? fast_exp2(old_m - new_m) : 0.0f; + m[qg] = new_m; + l[qg] *= alpha; + + if (has_previous_sum) { + float alpha_rows[8]; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + alpha_rows[elem] = __shfl(alpha, row_base + elem, 32); + } + +#pragma unroll + for (int vdt = 0; vdt < ValueTiles; ++vdt) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + out_frag[qg][vdt][elem] *= alpha_rows[elem]; + } + } + } + + float local_sum = 0.0f; +#pragma unroll + for (int col_tile = 0; col_tile < ColTiles; ++col_tile) { + if constexpr (UseActiveCausalColSkip) { + const int active_cols = qg == 0 ? 
active_cols0 : active_cols1; + if (col_tile >= active_cols) { + continue; + } + } + const int64_t k_col_start = kb_base + col_tile * BK; + bool fully_future = + ApplyCausalMask && k_col_start >= q_start[qg] + RM; + if constexpr (UseActiveCausalColSkip) { + fully_future = false; + } + const float8_vec scores = score_cache[qg][col_tile]; + float8_vec prob_values; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + float prob = 0.0f; + if (!fully_future) { + prob = fast_exp2(scores[elem] - m[qg] + kFp8SoftmaxOffset); + local_sum += prob; + } + prob_values[elem] = prob; + } + prob_cache[qg][col_tile] = make_p_fp8_regs_from_tqk_prob_regs(prob_values, lane); + } + l[qg] += local_sum + __shfl_xor(local_sum, 16, 32); + } + } + } else { +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + float8_vec score_cache[ColTiles]; + float local_max = -FLT_MAX * 0.5f; +#pragma unroll + for (int col_tile = 0; col_tile < ColTiles; ++col_tile) { + const int64_t k_col_start = kb_base + col_tile * BK; + const bool fully_future = + ApplyCausalMask && k_col_start >= q_start[qg] + RM; + float8_vec scores; + if (fully_future) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + scores[elem] = -FLT_MAX * 0.5f; + } + } else { + float k_scale_local = k_scale_tile; + if constexpr (!QuantizeKeyValue) { + const int k_scale_idx = k_scale_col_per_warp(k_col_start); + k_scale_local = k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + } + const float score_scale = qs[qg] * k_scale_local; + if constexpr (QuantizeQuery) { + scores = compute_tqk_score_regs_raw_kq( + &k_tile[0][0], q_regs[qg], lane, col_tile, score_scale); + } else { + scores = + compute_tqk_score_regs( + &k_tile[0][0], q_frag[qg], col_tile, score_scale); + } + if constexpr (ApplyCausalMask) { + const bool needs_causal_mask = k_col_start + BK > q_start[qg]; + if (needs_causal_mask) { + apply_tqk_causal_mask( + scores, static_cast(q_start[qg]), static_cast(kb_base), col_tile, lane); + } + } + } + if 
(k_col_start + BK > kv_len) { + apply_tqk_kv_tail_mask(scores, kv_len, kb_base, col_tile, lane); + } + score_cache[col_tile] = scores; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + local_max = fmaxf(local_max, scores[elem]); + } + } + const float tile_max = fmaxf(local_max, __shfl_xor(local_max, 16, 32)); + const float old_m = m[qg]; + const float new_m = fmaxf(old_m, tile_max); + const float alpha = l[qg] == 0.0f ? 0.0f : fast_exp2(old_m - new_m); + m[qg] = new_m; + l[qg] *= alpha; + + float alpha_rows[8]; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + alpha_rows[elem] = __shfl(alpha, row_base + elem, 32); + } + +#pragma unroll + for (int vdt = 0; vdt < ValueTiles; ++vdt) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + out_frag[qg][vdt][elem] *= alpha_rows[elem]; + } + } + + float local_sum = 0.0f; +#pragma unroll + for (int col_tile = 0; col_tile < ColTiles; ++col_tile) { + const int64_t k_col_start = kb_base + col_tile * BK; + const bool fully_future = + ApplyCausalMask && k_col_start >= q_start[qg] + RM; + const float8_vec scores = score_cache[col_tile]; + float8_vec prob_values; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + float prob = 0.0f; + if (!fully_future) { + prob = fast_exp2(scores[elem] - m[qg] + kFp8SoftmaxOffset); + local_sum += prob; + } + prob_values[elem] = prob; + } + prob_cache[qg][col_tile] = make_p_fp8_regs_from_tqk_prob_regs(prob_values, lane); + } + l[qg] += local_sum + __shfl_xor(local_sum, 16, 32); + } + } + + if constexpr (!UseStreamedFp8Pv) { +#pragma unroll + for (int col_tile = 0; col_tile < ColTiles; ++col_tile) { + if constexpr (UseActiveCausalColSkip) { + if (col_tile >= active_cols_any) { + continue; + } + } + const int64_t k_col_start = kb_base + col_tile * BK; + bool q0_fully_future = + ApplyCausalMask && k_col_start >= q_start[0] + RM; + bool q1_fully_future = + ApplyCausalMask && k_col_start >= q_start[1] + RM; + if constexpr (UseActiveCausalColSkip) { + q0_fully_future = 
col_tile >= active_cols0; + q1_fully_future = col_tile >= active_cols1; + } + if (q0_fully_future && q1_fully_future) { + continue; + } +#pragma unroll + for (int vdt = 0; vdt < ValueTiles; ++vdt) { + const int dt = ValueTileBase + vdt; + i32x2_vec v_regs; + if constexpr (UsePrepackedLaneMajorValue) { + v_regs = make_v_fp8_regs_from_lane_major_global( + v, v_head_base, kb_base, col_tile, dt, lane); + } else if constexpr (ValueTransposed) { + if constexpr (PackedTransposedValue) { + v_regs = make_v_fp8_regs_from_transposed_shared( + &v_tile_t[0][0], col_tile, dt, lane); + } else { + v_regs = make_v_fp8_regs_from_transposed_shared( + &v_tile[0][0], col_tile, dt, lane); + } + } else { + v_regs = make_v_fp8_regs_from_shared( + &v_tile[0][0], col_tile, dt, lane); + } + float8_vec acc0; + float8_vec acc1; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + acc0[elem] = out_frag[0][vdt][elem]; + acc1[elem] = out_frag[1][vdt][elem]; + } + if (!q0_fully_future) { + const i32x2_vec p_regs0 = prob_cache[0][col_tile]; + acc0 = __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12( + p_regs0, v_regs, acc0); + } + if (!q1_fully_future) { + const i32x2_vec p_regs1 = prob_cache[1][col_tile]; + acc1 = __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12( + p_regs1, v_regs, acc1); + } +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + out_frag[0][vdt][elem] = acc0[elem]; + out_frag[1][vdt][elem] = acc1[elem]; + } + } + } + } + } else { +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + float8_vec score_cache[ColTiles]; + float local_max = -FLT_MAX * 0.5f; +#pragma unroll + for (int col_tile = 0; col_tile < ColTiles; ++col_tile) { + const int64_t k_col_start = kb_base + col_tile * BK; + const bool fully_future = + ApplyCausalMask && k_col_start >= q_start[qg] + RM; + float8_vec scores; + if (fully_future) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + scores[elem] = -FLT_MAX * 0.5f; + } + } else { + float k_scale_local = k_scale_tile; + if 
constexpr (!QuantizeKeyValue) { + const int k_scale_idx = k_scale_col_per_warp(k_col_start); + k_scale_local = k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + } + const float score_scale = qs[qg] * k_scale_local; + if constexpr (QuantizeQuery) { + scores = compute_tqk_score_regs_raw_kq( + &k_tile[0][0], q_regs[qg], lane, col_tile, score_scale); + } else { + scores = + compute_tqk_score_regs( + &k_tile[0][0], q_frag[qg], col_tile, score_scale); + } + if constexpr (ApplyCausalMask) { + const bool needs_causal_mask = k_col_start + BK > q_start[qg]; + if (needs_causal_mask) { + apply_tqk_causal_mask( + scores, static_cast(q_start[qg]), static_cast(kb_base), col_tile, lane); + } + } + } + if (k_col_start + BK > kv_len) { + apply_tqk_kv_tail_mask(scores, kv_len, kb_base, col_tile, lane); + } + score_cache[col_tile] = scores; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + local_max = fmaxf(local_max, scores[elem]); + } + } + const float tile_max = fmaxf(local_max, __shfl_xor(local_max, 16, 32)); + const float old_m = m[qg]; + const float new_m = fmaxf(old_m, tile_max); + const float alpha = l[qg] == 0.0f ? 
0.0f : fast_exp2(old_m - new_m); + m[qg] = new_m; + l[qg] *= alpha; + + float alpha_rows[8]; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + alpha_rows[elem] = __shfl(alpha, row_base + elem, 32); + } + +#pragma unroll + for (int vdt = 0; vdt < ValueTiles; ++vdt) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + out_frag[qg][vdt][elem] *= alpha_rows[elem]; + } + } + + float local_sum = 0.0f; +#pragma unroll + for (int col_tile = 0; col_tile < ColTiles; ++col_tile) { + const int64_t k_col_start = kb_base + col_tile * BK; + const bool fully_future = + ApplyCausalMask && k_col_start >= q_start[qg] + RM; + if (fully_future) { + continue; + } + const float8_vec scores = score_cache[col_tile]; + float8_vec prob_values; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + const float prob = fast_exp2(scores[elem] - m[qg] + kFp8SoftmaxOffset); + local_sum += prob; + prob_values[elem] = prob; + } + + const i32x2_vec p_regs = make_p_fp8_regs_from_tqk_prob_regs(prob_values, lane); +#pragma unroll + for (int vdt = 0; vdt < ValueTiles; ++vdt) { + const int dt = ValueTileBase + vdt; + i32x2_vec v_regs; + if constexpr (UsePrepackedLaneMajorValue) { + v_regs = make_v_fp8_regs_from_lane_major_global( + v, v_head_base, kb_base, col_tile, dt, lane); + } else if constexpr (ValueTransposed) { + if constexpr (PackedTransposedValue) { + v_regs = make_v_fp8_regs_from_transposed_shared( + &v_tile_t[0][0], col_tile, dt, lane); + } else { + v_regs = make_v_fp8_regs_from_transposed_shared( + &v_tile[0][0], col_tile, dt, lane); + } + } else { + v_regs = make_v_fp8_regs_from_shared( + &v_tile[0][0], col_tile, dt, lane); + } + float8_vec acc; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + acc[elem] = out_frag[qg][vdt][elem]; + } + const float8_vec pv_acc = + __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12(p_regs, v_regs, acc); +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + out_frag[qg][vdt][elem] = pv_acc[elem]; + } + } + } + l[qg] += 
local_sum + __shfl_xor(local_sum, 16, 32); + } + } + if constexpr (UsesTileSharedMemory) { + __syncthreads(); + } + }; + + if constexpr (IsCausal) { + const int64_t diag_base = (q_base / BC) * BC; + const int64_t prefix_limit = diag_base < kv_limit ? diag_base : kv_limit; +#pragma unroll 2 + for (int64_t kb_base = 0; kb_base < prefix_limit; kb_base += BC) { + process_kv_tile(kb_base, std::false_type{}); + } +#pragma unroll 2 + for (int64_t kb_base = prefix_limit; kb_base < kv_limit; kb_base += BC) { + process_kv_tile(kb_base, std::true_type{}); + } + } else { +#pragma unroll 2 + for (int64_t kb_base = 0; kb_base < kv_limit; kb_base += BC) { + process_kv_tile(kb_base, std::false_type{}); + } + } + +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + float l_rows[8]; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + l_rows[elem] = __shfl(l[qg], row_base + elem, 32); + } +#pragma unroll + for (int vdt = 0; vdt < ValueTiles; ++vdt) { + const int d = (ValueTileBase + vdt) * BK + col; + const float value_scale = v_scale == nullptr ? + 1.0f : v_scale[(b * num_kv_heads + hkv) * HeadDim + d]; +#pragma unroll + for (int pair = 0; pair < PackedRows; ++pair) { + const int elem = pair * 2; + const int64_t q_idx0 = q_start[qg] + row_base + elem; + const int64_t q_idx1 = q_idx0 + 1; + const float l_sum0 = l_rows[elem]; + const float l_sum1 = l_rows[elem + 1]; + const float value0 = l_sum0 == 0.0f ? + 0.0f : (out_frag[qg][vdt][elem] / l_sum0) * value_scale; + const float value1 = l_sum1 == 0.0f ? 
+ 0.0f : (out_frag[qg][vdt][elem + 1] / l_sum1) * value_scale; + store_output_value(output, qkv_offset_dispatch( + tensor_layout, b, hq, q_idx0, d, o_stride_b, o_stride_n, o_stride_h), value0); + store_output_value(output, qkv_offset_dispatch( + tensor_layout, b, hq, q_idx1, d, o_stride_b, o_stride_n, o_stride_h), value1); + } + } + } +} + /* prepare_qkv_hnd_kernel: int8-quantizes query and key with one absmax scale per row group and converts value (fp8 when ToFp8, otherwise half). Launch mapping as read from the body: blockIdx.x = task, blockIdx.y = head, blockIdx.z = batch index; work loops stride by the Threads template constant. NOTE(review): presumably launched with blockDim.x equal to Threads - confirm at the launch site. Inputs are indexed as ((b * heads + head) * len + seq) * HeadDim + d. Outputs: query_out and key_out receive int8 data, query_scale and key_scale receive amax / 127 per (batch, head, group), value_out receives the converted value in a layout selected by the Prepack and TransposeValue flags. NOTE(review): the template parameter list and several cast target types are missing in this copy of the patch (angle-bracket text looks stripped), so element widths are not visible here; the 16-byte uint4 loads of PackElems = 8 values presumably assume a 2-byte T - confirm against the original file. */ +template +__global__ void prepare_qkv_hnd_kernel( + const T* __restrict__ query, + const T* __restrict__ key, + const T* __restrict__ value, + int8_t* __restrict__ query_out, + int8_t* __restrict__ key_out, + float* __restrict__ query_scale, + float* __restrict__ key_scale, + OutT* __restrict__ value_out, + const int64_t batch, + const int64_t q_heads, + const int64_t kv_heads, + const int64_t q_len, + const int64_t kv_len, + const int q_groups, + const int k_groups, + const bool fuse_self_qkv) { /* QRows and KRows are the row counts of one query and one key quantization group; ValueStride makes each shared tile row 16 elements wider than HeadDim (presumably padding against bank conflicts - confirm). */ + constexpr int PackElems = 8; + constexpr int KRows = 64; + constexpr int QRows = 32; + constexpr int ValueStride = HeadDim + 16; + constexpr int LaneColTiles = KRows / 16; + constexpr int LaneDTiles = HeadDim / 16; + static_assert((HeadDim % PackElems) == 0, "native preparation packs eight elements"); + static_assert(!PrepackF16VLane || (!ToFp8 && TransposeValue && HeadDim == 64), + "fp16 lane-major V prepack is specialized for transposed D64 fp16 values"); + static_assert(!PrepackF16KLane || (!ToFp8 && HeadDim == 64), + "fp16 lane-major K prepack is specialized for D64 fp16/bf16 keys"); + static_assert(!PrepackFp8Lane || (ToFp8 && TransposeValue && HeadDim == 64), + "fp8 lane-major K/V prepack is specialized for transposed D64 fp8 values"); + static_assert(StaticQLen == 0 || (StaticQLen % (2 * QRows)) == 0, + "static QKV preparation Q length must cover full two-group Q tasks."); + static_assert(StaticKvLen == 0 || (StaticKvLen % KRows) == 0, + "static QKV preparation KV length must cover full K groups."); + /* Shared staging: shared_amax holds the final per-group maxima, shared_pair_amax holds one partial per warp (32 slots per group), and the two tiles shrink to a single element when the selected path does not use them. */ + __shared__ float shared_amax[2]; + __shared__ float shared_pair_amax[2][32]; + __shared__ OutT value_tile[PrepackFp8Lane ? 
1 : KRows][PrepackFp8Lane ? 1 : ValueStride]; + __shared__ int8_t key_tile[PrepackF16KLane ? KRows : 1][PrepackF16KLane ? ValueStride : 1]; + + const int task = blockIdx.x; + const int head = blockIdx.y; + const int b = blockIdx.z; + const int tid = threadIdx.x; /* A non-zero StaticQLen or StaticKvLen replaces the runtime length and group-count arguments; when FullGroupsNoTail is also set the per-row bounds checks below are compiled out. */ + constexpr bool StaticFullQ = FullGroupsNoTail && StaticQLen != 0; + constexpr bool StaticFullKV = FullGroupsNoTail && StaticKvLen != 0; + constexpr int StaticQGroups = StaticQLen == 0 ? 0 : (StaticQLen / QRows); + constexpr int StaticQTaskGroups = StaticQLen == 0 ? 0 : ((StaticQGroups + 1) / 2); + constexpr int StaticKGroups = StaticKvLen == 0 ? 0 : (StaticKvLen / KRows); + const int effective_q_groups = StaticQLen == 0 ? q_groups : StaticQGroups; + const int effective_q_task_groups = + StaticQLen == 0 ? ((q_groups + 1) / 2) : StaticQTaskGroups; + const int effective_k_groups = StaticKvLen == 0 ? k_groups : StaticKGroups; + const int64_t effective_q_len = StaticQLen == 0 ? q_len : StaticQLen; + const int64_t effective_kv_len = StaticKvLen == 0 ? 
kv_len : StaticKvLen; + /* Query stage: each query task covers two consecutive QRows-row groups; the second group is processed only when its index is a valid group. */ + if constexpr (PrepareQuery) { + if (task < effective_q_task_groups) { + const int local_group_base = task * 2; + const int64_t base_row = static_cast(local_group_base) * QRows; /* This exit test uses only blockIdx-derived values and kernel arguments, so every thread of the block takes the same branch before any barrier. */ + if (b >= batch || head >= q_heads || + (!StaticFullQ && base_row >= effective_q_len)) { + return; + } + + constexpr int packs = (QRows * HeadDim) / PackElems; + const bool has_q_group1 = StaticFullQ || ((local_group_base + 1) < effective_q_groups); /* Pass 1: per-thread absolute maximum over 8-element packs. The 1e-7 seed keeps the later 127 / amax division away from zero for all-zero groups. */ + float local_amax0 = 0.0000001f; + float local_amax1 = 0.0000001f; + for (int pack = tid; pack < packs; pack += Threads) { + const int elem_base = pack * PackElems; + const int row = elem_base / HeadDim; + const int d = elem_base - row * HeadDim; + const int64_t seq0 = base_row + row; + if (StaticFullQ || seq0 < effective_q_len) { + const int64_t off = + ((static_cast(b) * q_heads + head) * effective_q_len + seq0) * + HeadDim + d; + const uint4 raw = *reinterpret_cast(query + off); + const T* values = reinterpret_cast(&raw); +#pragma unroll + for (int i = 0; i < PackElems; ++i) { + local_amax0 = fmaxf(local_amax0, fabsf(value_to_float(values[i]))); + } + } + if (has_q_group1) { + const int64_t seq1 = base_row + QRows + row; + if (StaticFullQ || seq1 < effective_q_len) { + const int64_t off = + ((static_cast(b) * q_heads + head) * effective_q_len + seq1) * + HeadDim + d; + const uint4 raw = *reinterpret_cast(query + off); + const T* values = reinterpret_cast(&raw); +#pragma unroll + for (int i = 0; i < PackElems; ++i) { + local_amax1 = fmaxf(local_amax1, fabsf(value_to_float(values[i]))); + } + } + } + } /* Two-stage reduction: vllm::warpReduceMax (presumably a warp-wide max - defined elsewhere) produces one partial per warp, lane 0 of each warp stages it in shared_pair_amax, and after the barrier the threads with tid below the warp count reload and reduce the partials. Only the tid 0 result is consumed. */ + const int lane = tid & 31; + const int wid = tid >> 5; + local_amax0 = vllm::warpReduceMax(local_amax0); + local_amax1 = vllm::warpReduceMax(local_amax1); + if (lane == 0) { + shared_pair_amax[0][wid] = local_amax0; + shared_pair_amax[1][wid] = local_amax1; + } + __syncthreads(); + const bool warp_lane_active = tid < (blockDim.x / 32); + local_amax0 = warp_lane_active ? 
shared_pair_amax[0][lane] : -1e20f; + local_amax1 = warp_lane_active ? shared_pair_amax[1][lane] : -1e20f; + local_amax0 = vllm::warpReduceMax(local_amax0); + local_amax1 = vllm::warpReduceMax(local_amax1); /* Thread 0 publishes the group maxima to shared memory and writes the dequantization scale amax / 127 for each processed group. */ + if (tid == 0) { + shared_amax[0] = local_amax0; + query_scale[(static_cast(b) * q_heads + head) * effective_q_groups + + local_group_base] = local_amax0 / 127.0f; + if (has_q_group1) { + shared_amax[1] = local_amax1; + query_scale[(static_cast(b) * q_heads + head) * effective_q_groups + + local_group_base + 1] = local_amax1 / 127.0f; + } + } + __syncthreads(); + const float inv_scale0 = 127.0f / shared_amax[0]; + const float inv_scale1 = has_q_group1 ? (127.0f / shared_amax[1]) : 0.0f; + /* Pass 2: reload each pack, multiply by 127 / amax, convert with float_to_int8_rn_gfx12 and store eight int8 values at the same element offset as the input. */ + for (int pack = tid; pack < packs; pack += Threads) { + const int elem_base = pack * PackElems; + const int row = elem_base / HeadDim; + const int d = elem_base - row * HeadDim; + const int64_t seq0 = base_row + row; + if (StaticFullQ || seq0 < effective_q_len) { + const int64_t off = + ((static_cast(b) * q_heads + head) * effective_q_len + seq0) * + HeadDim + d; + const uint4 raw = *reinterpret_cast(query + off); + const T* values = reinterpret_cast(&raw); + char4 out0; + char4 out1; + out0.x = float_to_int8_rn_gfx12(value_to_float(values[0]) * inv_scale0); + out0.y = float_to_int8_rn_gfx12(value_to_float(values[1]) * inv_scale0); + out0.z = float_to_int8_rn_gfx12(value_to_float(values[2]) * inv_scale0); + out0.w = float_to_int8_rn_gfx12(value_to_float(values[3]) * inv_scale0); + out1.x = float_to_int8_rn_gfx12(value_to_float(values[4]) * inv_scale0); + out1.y = float_to_int8_rn_gfx12(value_to_float(values[5]) * inv_scale0); + out1.z = float_to_int8_rn_gfx12(value_to_float(values[6]) * inv_scale0); + out1.w = float_to_int8_rn_gfx12(value_to_float(values[7]) * inv_scale0); + *reinterpret_cast(query_out + off) = out0; + *reinterpret_cast(query_out + off + 4) = out1; + } + if (has_q_group1) { + const int64_t seq1 = base_row + QRows + row; + if (StaticFullQ || seq1 < 
effective_q_len) { + const int64_t off = + ((static_cast(b) * q_heads + head) * effective_q_len + seq1) * + HeadDim + d; + const uint4 raw = *reinterpret_cast(query + off); + const T* values = reinterpret_cast(&raw); + char4 out0; + char4 out1; + out0.x = float_to_int8_rn_gfx12(value_to_float(values[0]) * inv_scale1); + out0.y = float_to_int8_rn_gfx12(value_to_float(values[1]) * inv_scale1); + out0.z = float_to_int8_rn_gfx12(value_to_float(values[2]) * inv_scale1); + out0.w = float_to_int8_rn_gfx12(value_to_float(values[3]) * inv_scale1); + out1.x = float_to_int8_rn_gfx12(value_to_float(values[4]) * inv_scale1); + out1.y = float_to_int8_rn_gfx12(value_to_float(values[5]) * inv_scale1); + out1.z = float_to_int8_rn_gfx12(value_to_float(values[6]) * inv_scale1); + out1.w = float_to_int8_rn_gfx12(value_to_float(values[7]) * inv_scale1); + *reinterpret_cast(query_out + off) = out0; + *reinterpret_cast(query_out + off + 4) = out1; + } + } + } /* Without fuse_self_qkv a query block is finished here. With it, the same block falls through and also prepares the K/V group whose index equals task. */ + if (!fuse_self_qkv) { + return; + } + } + } + /* K/V stage: one KRows-row group per block. In the unfused PrepareQuery launch the K/V tasks follow the query tasks in blockIdx.x, hence the subtraction. */ + const int local_group = + PrepareQuery ? (fuse_self_qkv ? 
task : task - effective_q_task_groups) : task; + const int64_t base_row = static_cast(local_group) * KRows; + if (b >= batch || head >= kv_heads || local_group >= effective_k_groups || + (!StaticFullKV && base_row >= effective_kv_len)) { + return; + } + + constexpr int kv_packs = (KRows * HeadDim) / PackElems; /* Key absmax over the whole group, seeded with 1e-7 as in the query stage. */ + float local_amax = 0.0000001f; + for (int pack = tid; pack < kv_packs; pack += Threads) { + const int elem_base = pack * PackElems; + const int row = elem_base / HeadDim; + const int d = elem_base - row * HeadDim; + const int64_t seq = base_row + row; + if (StaticFullKV || seq < effective_kv_len) { + const int64_t off = + ((static_cast(b) * kv_heads + head) * effective_kv_len + seq) * + HeadDim + d; + const uint4 raw = *reinterpret_cast(key + off); + const T* values = reinterpret_cast(&raw); +#pragma unroll + for (int i = 0; i < PackElems; ++i) { + local_amax = fmaxf(local_amax, fabsf(value_to_float(values[i]))); + } + } + } /* NOTE(review): vllm::blockReduceMax is defined elsewhere; this code relies on thread 0 receiving the block-wide maximum - confirm. The barrier below makes shared_amax[0] visible to all threads. */ + const float block_amax = vllm::blockReduceMax(local_amax); + if (tid == 0) { + shared_amax[0] = block_amax; + key_scale[(static_cast(b) * kv_heads + head) * effective_k_groups + + local_group] = shared_amax[0] / 127.0f; + } + __syncthreads(); + const float inv_scale = 127.0f / shared_amax[0]; + /* Main K/V pass: key and value are read at the same element offset, so value is indexed with kv_heads and effective_kv_len exactly like key. */ + for (int pack = tid; pack < kv_packs; pack += Threads) { + const int elem_base = pack * PackElems; + const int row = elem_base / HeadDim; + const int d = elem_base - row * HeadDim; + const int64_t seq = base_row + row; + if (StaticFullKV || seq < effective_kv_len) { + const int64_t off = + ((static_cast(b) * kv_heads + head) * effective_kv_len + seq) * + HeadDim + d; + const uint4 raw_k = *reinterpret_cast(key + off); + const uint4 raw_v = *reinterpret_cast(value + off); + const T* k_values = reinterpret_cast(&raw_k); + const T* v_values = reinterpret_cast(&raw_v); + char4 out0; + char4 out1; + out0.x = float_to_int8_rn_gfx12(value_to_float(k_values[0]) * inv_scale); + out0.y = float_to_int8_rn_gfx12(value_to_float(k_values[1]) * inv_scale); 
+ out0.z = float_to_int8_rn_gfx12(value_to_float(k_values[2]) * inv_scale); + out0.w = float_to_int8_rn_gfx12(value_to_float(k_values[3]) * inv_scale); + out1.x = float_to_int8_rn_gfx12(value_to_float(k_values[4]) * inv_scale); + out1.y = float_to_int8_rn_gfx12(value_to_float(k_values[5]) * inv_scale); + out1.z = float_to_int8_rn_gfx12(value_to_float(k_values[6]) * inv_scale); + out1.w = float_to_int8_rn_gfx12(value_to_float(k_values[7]) * inv_scale); /* Key store, three variants: lane-major direct to global memory, staged in key_tile for the lane-major reorder after the barrier, or written in place at the input offset. In the lane-major variant each slot holds 8 int8 values and is addressed by (group, 16-row tile, 16-wide d tile, lane), with lane = row within the tile, plus 16 when bit 3 of d is set. NOTE(review): the lane-major offsets use the runtime k_groups argument while key_scale uses effective_k_groups; they agree only if the caller passes k_groups consistent with StaticKvLen - confirm. */ + if constexpr (PrepackFp8Lane) { + const int row_in_group = row & 63; + const int col_tile = row_in_group >> 4; + const int col = row_in_group & 15; + const int d_tile = d >> 4; + const int lane_out = col | (((d & 8) != 0) ? 16 : 0); + const int64_t lane_off = + ((((static_cast(b) * kv_heads + head) * k_groups + local_group) * + LaneColTiles + col_tile) * LaneDTiles + d_tile) * 32 * 8 + + static_cast(lane_out) * 8; + uint2 packed; + packed.x = *reinterpret_cast(&out0); + packed.y = *reinterpret_cast(&out1); + *reinterpret_cast(key_out + lane_off) = packed; + } else if constexpr (PrepackF16KLane) { + *reinterpret_cast(&key_tile[row][d]) = out0; + *reinterpret_cast(&key_tile[row][d + 4]) = out1; + } else { + *reinterpret_cast(key_out + off) = out0; + *reinterpret_cast(key_out + off + 4) = out1; + } /* Value conversion: with ToFp8 the eight values are packed four at a time by pack_f32x4_to_ocp_fp8 (defined elsewhere); otherwise they are stored as half in value_tile. */ + if constexpr (ToFp8) { + const uint32_t v_pack0 = static_cast(pack_f32x4_to_ocp_fp8( + value_to_float(v_values[0]), value_to_float(v_values[1]), + value_to_float(v_values[2]), value_to_float(v_values[3]))); + const uint32_t v_pack1 = static_cast(pack_f32x4_to_ocp_fp8( + value_to_float(v_values[4]), value_to_float(v_values[5]), + value_to_float(v_values[6]), value_to_float(v_values[7]))); + if constexpr (PrepackFp8Lane) { /* Lane-major fp8 value store: the row within its 16-row tile selects the lane half, the 4-byte word and the byte inside the 8-byte lane slot, while each d element selects the lane; every element is written as one byte. */ + const int row_in_group = row & 63; + const int col_tile = row_in_group >> 4; + const int row_local = row_in_group & 15; + const int lane_hi = row_local >> 3; + const int gpr = (row_local & 7) >> 2; + const int byte = row_local & 3; + uint8_t* value_bytes = reinterpret_cast(value_out); +#pragma unroll + 
for (int elem = 0; elem < PackElems; ++elem) { + const int d_elem = d + elem; + const int d_tile = d_elem >> 4; + const int lane_local = (d_elem & 15) | (lane_hi << 4); + const int64_t byte_off = + (((((static_cast(b) * kv_heads + head) * k_groups + + local_group) * LaneColTiles + col_tile) * LaneDTiles + d_tile) * 32 + + lane_local) * 8 + + gpr * 4 + byte; + const uint32_t packed = elem < 4 ? v_pack0 : v_pack1; + value_bytes[byte_off] = + static_cast((packed >> (8 * (elem & 3))) & 0xff); + } + } else if constexpr (TransposeValue) { /* Stage the fp8 bytes in shared memory; the transposed store happens after the barrier. */ + *reinterpret_cast(&value_tile[row][d]) = v_pack0; + *reinterpret_cast(&value_tile[row][d + 4]) = v_pack1; + } else { + *reinterpret_cast(value_out + off) = v_pack0; + *reinterpret_cast(value_out + off + 4) = v_pack1; + } + } else { +#pragma unroll + for (int i = 0; i < PackElems; ++i) { + if constexpr (std::is_same::value && std::is_same::value) { + value_tile[row][d + i] = v_values[i]; + } else { + const float v = value_to_float(v_values[i]); + value_tile[row][d + i] = __float2half_rn(v); + } + } + } + } + } /* Second phase, only for paths that staged data in shared memory. The barrier separates the tile writes above from the reordered reads below. Rows past effective_kv_len were never written to the tiles. */ + if constexpr ((!ToFp8 || TransposeValue) && !PrepackFp8Lane) { + __syncthreads(); + + if constexpr (PrepackF16KLane) { /* Lane-major key reorder: one 8-byte slot per idx, read from key_tile at the row given by pv_k_order_for_acc_row (defined elsewhere) and at d offset 0 or 8 within the d tile. NOTE(review): tail rows are read from key_tile without a bounds check, so slots for rows past effective_kv_len carry unwritten shared memory - confirm the consumer masks them. */ + constexpr int ColTiles64 = KRows / 16; + constexpr int DTiles = HeadDim / 16; + constexpr int LaneMajorElems = ColTiles64 * DTiles * 32; + for (int idx = tid; idx < LaneMajorElems; idx += Threads) { + const int lane_local = idx & 31; + const int d_tile = (idx >> 5) % DTiles; + const int col_tile = idx / (DTiles * 32); + const int col = lane_local & 15; + const int k_base = 8 * (lane_local >> 4); + const int row = col_tile * 16 + pv_k_order_for_acc_row(col); + const int d = d_tile * 16 + k_base; + const int64_t out_off = + ((((static_cast(b) * kv_heads + head) * k_groups + local_group) * + ColTiles64 * DTiles * 32) + idx) * 8; + uint2 packed; + packed.x = *reinterpret_cast(&key_tile[row][d]); + packed.y = *reinterpret_cast(&key_tile[row][d + 4]); + *reinterpret_cast(key_out + out_off) = packed; + } + } + + if 
constexpr (PrepackF16VLane) { /* Lane-major half value store: each idx gathers eight halves for one d, rows n0..n0+3 and n1..n1+3 of the 16-row tile, substituting zero for rows past effective_kv_len. */ + constexpr int ColTiles64 = KRows / 16; + constexpr int DTiles = HeadDim / 16; + constexpr int LaneMajorElems = ColTiles64 * DTiles * 32; + for (int idx = tid; idx < LaneMajorElems; idx += Threads) { + const int lane_local = idx & 31; + const int d_tile = (idx >> 5) % DTiles; + const int col_tile = idx / (DTiles * 32); + const int d = d_tile * 16 + (lane_local & 15); + const int high_half = (lane_local >> 4) & 1; + const int n0 = col_tile * 16 + high_half * 4; + const int n1 = col_tile * 16 + 8 + high_half * 4; + uint4 packed; + __half* vals = reinterpret_cast<__half*>(&packed); +#pragma unroll + for (int i = 0; i < 4; ++i) { + const int row = n0 + i; + vals[i] = (StaticFullKV || (base_row + row) < effective_kv_len) ? + value_tile[row][d] : __float2half_rn(0.0f); + } +#pragma unroll + for (int i = 0; i < 4; ++i) { + const int row = n1 + i; + vals[4 + i] = + (StaticFullKV || (base_row + row) < effective_kv_len) ? + value_tile[row][d] : __float2half_rn(0.0f); + } + const int64_t out_off = + ((((static_cast(b) * kv_heads + head) * k_groups + local_group) * + ColTiles64 * DTiles * 32) + idx) * 8; + *reinterpret_cast(value_out + out_off) = packed; + } + } else if constexpr (ToFp8) { /* Transposed fp8 store: output index is ((b * kv_heads + head) * HeadDim + d) * effective_kv_len + seq. Sixteen consecutive sequence positions for one d are packed into one 16-byte write; a partially valid run falls back to per-element stores. NOTE(review): the 16-byte store is naturally aligned only when effective_kv_len is a multiple of 16 - confirm callers guarantee this or that the target tolerates unaligned vector stores. */ + constexpr int StoreRows = 16; + constexpr int RowGroups = KRows / StoreRows; + for (int linear = tid; linear < HeadDim * RowGroups; linear += Threads) { + const int d = linear / RowGroups; + const int row = (linear - d * RowGroups) * StoreRows; + const int64_t seq = base_row + row; + if (StaticFullKV || seq + StoreRows - 1 < effective_kv_len) { + uint4 packed; + packed.x = static_cast(value_tile[row + 0][d]) | + (static_cast(value_tile[row + 1][d]) << 8) | + (static_cast(value_tile[row + 2][d]) << 16) | + (static_cast(value_tile[row + 3][d]) << 24); + packed.y = static_cast(value_tile[row + 4][d]) | + (static_cast(value_tile[row + 5][d]) << 8) | + (static_cast(value_tile[row + 6][d]) << 16) | + (static_cast(value_tile[row + 7][d]) << 24); + packed.z = 
static_cast(value_tile[row + 8][d]) | + (static_cast(value_tile[row + 9][d]) << 8) | + (static_cast(value_tile[row + 10][d]) << 16) | + (static_cast(value_tile[row + 11][d]) << 24); + packed.w = static_cast(value_tile[row + 12][d]) | + (static_cast(value_tile[row + 13][d]) << 8) | + (static_cast(value_tile[row + 14][d]) << 16) | + (static_cast(value_tile[row + 15][d]) << 24); + const int64_t out_off = + ((static_cast(b) * kv_heads + head) * HeadDim + d) * + effective_kv_len + seq; + *reinterpret_cast(value_out + out_off) = packed; + } else { +#pragma unroll + for (int i = 0; i < StoreRows; ++i) { + const int64_t tail_seq = seq + i; + if (tail_seq < effective_kv_len) { + const int64_t out_off = + ((static_cast(b) * kv_heads + head) * HeadDim + d) * + effective_kv_len + + tail_seq; + value_out[out_off] = value_tile[row + i][d]; + } + } + } + } + } else { /* Transposed half store: same output indexing as the fp8 branch, eight halves per 16-byte write, with a per-element fallback for a partially valid run. NOTE(review): natural alignment of the vector store needs effective_kv_len to be a multiple of 8 - confirm. */ + constexpr int StoreRows = 8; + constexpr int RowGroups = KRows / StoreRows; + for (int linear = tid; linear < HeadDim * RowGroups; linear += Threads) { + const int d = linear / RowGroups; + const int row = (linear - d * RowGroups) * StoreRows; + const int64_t seq = base_row + row; + const int64_t out_off = + ((static_cast(b) * kv_heads + head) * HeadDim + d) * + effective_kv_len + seq; + if (StaticFullKV || seq + StoreRows - 1 < effective_kv_len) { + uint4 packed; + __half* vals = reinterpret_cast<__half*>(&packed); +#pragma unroll + for (int i = 0; i < StoreRows; ++i) { + vals[i] = value_tile[row + i][d]; + } + *reinterpret_cast(value_out + out_off) = packed; + } else { +#pragma unroll + for (int i = 0; i < StoreRows; ++i) { + const int64_t tail_seq = seq + i; + if (tail_seq < effective_kv_len) { + value_out[out_off + i] = value_tile[row + i][d]; + } + } + } + } + } + } +} + +template +__global__ void prepare_kv_hnd_fp8_kernel( + const T* __restrict__ key, + const T* __restrict__ value, + int8_t* __restrict__ key_out, + float* __restrict__ key_scale, + uint8_t* __restrict__ value_out, + const int64_t batch, 
const int64_t kv_heads, + const int64_t kv_len, + const int k_groups) { + constexpr int PackElems = 8; + constexpr int GroupRows = 64; + constexpr int KRows = GroupRows * GroupsPerBlock; + constexpr int Threads = PrepThreads; + constexpr int ValueStride = HeadDim + 16; + constexpr int StoreRows = 16; + constexpr int RowGroups = KRows / StoreRows; + constexpr int ValueRowGroups4 = KRows / 4; + constexpr int ValueStride32 = ValueRowGroups4 + 4; + constexpr int LaneColTiles = GroupRows / 16; + constexpr int LaneDTiles = HeadDim / 16; + constexpr int Packs = (KRows * HeadDim) / PackElems; + constexpr bool LaneMajorK = LaneMajorKV || LaneMajorKOnly; + constexpr bool LaneMajorV = LaneMajorKV || LaneMajorVOnly; + static_assert(HeadDim == 64 || HeadDim == 128, + "fp8 KV preparation supports D64/D128."); + static_assert(GroupsPerBlock == 1 || GroupsPerBlock == 2 || GroupsPerBlock == 4, + "fp8 KV preparation supports 1, 2, or 4 scale groups per CTA."); + static_assert(PrepThreads == 128 || PrepThreads == 256 || PrepThreads == 512, + "fp8 KV preparation supports 128, 256, or 512 threads."); + static_assert(!CacheKeyInShared || GroupsPerBlock == 1, + "cached-key fp8 KV preparation is specialized for one scale group per CTA."); + static_assert(!LaneMajorKOnly || ((HeadDim == 64 || HeadDim == 128) && !LaneMajorKV), + "K-only lane-major fp8 preparation is specialized for D64/D128."); + static_assert(!LaneMajorVOnly || + ((HeadDim == 64 || HeadDim == 128) && + !LaneMajorKV && !TransposedValueStaging), + "V-only lane-major fp8 preparation is specialized for D64/D128."); + static_assert(!FullGroupsNoTail || GroupsPerBlock == 1, + "no-tail fp8 KV preparation is specialized for one scale group per CTA."); + static_assert(StaticKvLen == 0 || (StaticKvLen % GroupRows) == 0, + "static fp8 KV preparation length must be a whole scale group."); + + __shared__ float shared_amax[GroupsPerBlock]; + __shared__ float shared_pair_amax[GroupsPerBlock][32]; + __shared__ uint4 
key_cache[CacheKeyInShared ? Packs : 1]; + __shared__ uint8_t value_tile[LaneMajorV ? 1 : (TransposedValueStaging ? 1 : KRows)] + [LaneMajorV ? 1 : ValueStride]; + __shared__ uint32_t value_tile_t[LaneMajorV ? 1 : (TransposedValueStaging ? HeadDim : 1)] + [LaneMajorV ? 1 : (TransposedValueStaging ? ValueStride32 : 1)]; + + const int block_group = blockIdx.x; + const int head = blockIdx.y; + const int b = blockIdx.z; + const int tid = threadIdx.x; + const int lane = tid & 31; + const int wid = tid >> 5; + const int group_base = block_group * GroupsPerBlock; + constexpr int StaticKGroups = StaticKvLen == 0 ? 0 : (StaticKvLen / GroupRows); + const int effective_k_groups = StaticKvLen == 0 ? k_groups : StaticKGroups; + const int64_t effective_kv_len = StaticKvLen == 0 ? kv_len : StaticKvLen; + const int64_t head_seq_base = + (static_cast(b) * kv_heads + head) * effective_kv_len; + const int64_t transposed_value_head_base = + (static_cast(b) * kv_heads + head) * HeadDim * effective_kv_len; + const int64_t scale_head_base = + (static_cast(b) * kv_heads + head) * effective_k_groups; + const int64_t base_row = static_cast(group_base) * GroupRows; + if (b >= batch || head >= kv_heads || group_base >= effective_k_groups || + base_row >= effective_kv_len) { + return; + } + + float local_amax[GroupsPerBlock]; +#pragma unroll + for (int group = 0; group < GroupsPerBlock; ++group) { + local_amax[group] = 0.0000001f; + } + for (int pack = tid; pack < Packs; pack += Threads) { + const int elem_base = pack * PackElems; + const int row = elem_base / HeadDim; + const int d = elem_base - row * HeadDim; + const int64_t seq = base_row + row; + if (FullGroupsNoTail || seq < effective_kv_len) { + const int64_t off = (head_seq_base + seq) * HeadDim + d; + const uint4 raw = *reinterpret_cast(key + off); + if constexpr (CacheKeyInShared) { + key_cache[pack] = raw; + } + const T* values = reinterpret_cast(&raw); + float pack_amax = 0.0000001f; +#pragma unroll + for (int i = 0; i < PackElems; 
++i) { + pack_amax = fmaxf(pack_amax, fabsf(value_to_float(values[i]))); + } + const int group = row >> 6; + local_amax[group] = fmaxf(local_amax[group], pack_amax); + } + } + +#pragma unroll + for (int group = 0; group < GroupsPerBlock; ++group) { + local_amax[group] = vllm::warpReduceMax(local_amax[group]); + if (lane == 0) { + shared_pair_amax[group][wid] = local_amax[group]; + } + } + __syncthreads(); + const bool warp_lane_active = tid < (blockDim.x / 32); +#pragma unroll + for (int group = 0; group < GroupsPerBlock; ++group) { + float group_amax = warp_lane_active ? shared_pair_amax[group][lane] : -1e20f; + group_amax = vllm::warpReduceMax(group_amax); + if (tid == 0) { + shared_amax[group] = group_amax; + if (FullGroupsNoTail || group_base + group < effective_k_groups) { + key_scale[scale_head_base + group_base + group] = group_amax / 127.0f; + } + } + } + __syncthreads(); + float inv_scales[GroupsPerBlock]; +#pragma unroll + for (int group = 0; group < GroupsPerBlock; ++group) { + inv_scales[group] = + (FullGroupsNoTail || (group_base + group) < effective_k_groups) ? + (127.0f / shared_amax[group]) : 0.0f; + } + + for (int pack = tid; pack < Packs; pack += Threads) { + const int elem_base = pack * PackElems; + const int row = elem_base / HeadDim; + const int d = elem_base - row * HeadDim; + const int64_t seq = base_row + row; + if (FullGroupsNoTail || seq < effective_kv_len) { + const int64_t off = (head_seq_base + seq) * HeadDim + d; + const uint4 raw_k = CacheKeyInShared ? 
key_cache[pack] : + *reinterpret_cast(key + off); + const T* k_values = reinterpret_cast(&raw_k); + const float inv_scale = inv_scales[row >> 6]; + char4 out0; + char4 out1; + out0.x = float_to_int8_rn_gfx12(value_to_float(k_values[0]) * inv_scale); + out0.y = float_to_int8_rn_gfx12(value_to_float(k_values[1]) * inv_scale); + out0.z = float_to_int8_rn_gfx12(value_to_float(k_values[2]) * inv_scale); + out0.w = float_to_int8_rn_gfx12(value_to_float(k_values[3]) * inv_scale); + out1.x = float_to_int8_rn_gfx12(value_to_float(k_values[4]) * inv_scale); + out1.y = float_to_int8_rn_gfx12(value_to_float(k_values[5]) * inv_scale); + out1.z = float_to_int8_rn_gfx12(value_to_float(k_values[6]) * inv_scale); + out1.w = float_to_int8_rn_gfx12(value_to_float(k_values[7]) * inv_scale); + if constexpr (LaneMajorK) { + const int group = row >> 6; + const int row_in_group = row & 63; + const int col_tile = row_in_group >> 4; + const int col = row_in_group & 15; + const int d_tile = d >> 4; + const int lane_out = col | (((d & 8) != 0) ? 
16 : 0); + const int64_t lane_off = + ((((static_cast(b) * kv_heads + head) * effective_k_groups + + (group_base + group)) * LaneColTiles + col_tile) * LaneDTiles + d_tile) * + 32 * 8 + + static_cast(lane_out) * 8; + uint2 packed; + packed.x = *reinterpret_cast(&out0); + packed.y = *reinterpret_cast(&out1); + *reinterpret_cast(key_out + lane_off) = packed; + } else { + *reinterpret_cast(key_out + off) = out0; + *reinterpret_cast(key_out + off + 4) = out1; + } + + if constexpr (!TransposedValueStaging && !LaneMajorV) { + const uint4 raw_v = *reinterpret_cast(value + off); + const T* v_values = reinterpret_cast(&raw_v); + const uint32_t v_pack0 = static_cast(pack_f32x4_to_ocp_fp8( + value_to_float(v_values[0]), value_to_float(v_values[1]), + value_to_float(v_values[2]), value_to_float(v_values[3]))); + const uint32_t v_pack1 = static_cast(pack_f32x4_to_ocp_fp8( + value_to_float(v_values[4]), value_to_float(v_values[5]), + value_to_float(v_values[6]), value_to_float(v_values[7]))); + *reinterpret_cast(&value_tile[row][d]) = v_pack0; + *reinterpret_cast(&value_tile[row][d + 4]) = v_pack1; + } + } + } + + if constexpr (LaneMajorV) { + constexpr int LaneMajorValueRegs = + GroupsPerBlock * LaneColTiles * LaneDTiles * 32; + for (int idx = tid; idx < LaneMajorValueRegs; idx += Threads) { + const int lane_out = idx & 31; + const int tile = idx >> 5; + const int d_tile = tile % LaneDTiles; + const int col_tile = (tile / LaneDTiles) % LaneColTiles; + const int group = tile / (LaneDTiles * LaneColTiles); + if constexpr (!FullGroupsNoTail) { + if ((group_base + group) >= effective_k_groups) { + continue; + } + } + const int row_base = + group * GroupRows + col_tile * 16 + ((lane_out >> 4) << 3); + const int d = d_tile * 16 + (lane_out & 15); + float values[8]; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + const int64_t seq = base_row + row_base + elem; + if (FullGroupsNoTail || seq < effective_kv_len) { + const int64_t off = (head_seq_base + seq) * HeadDim + d; + 
values[elem] = value_to_float(value[off]); + } else { + values[elem] = 0.0f; + } + } + const uint32_t packed0 = static_cast( + pack_f32x4_to_ocp_fp8(values[0], values[1], values[2], values[3])); + const uint32_t packed1 = static_cast( + pack_f32x4_to_ocp_fp8(values[4], values[5], values[6], values[7])); + const int64_t out_off = + (((((static_cast(b) * kv_heads + head) * effective_k_groups + + (group_base + group)) * LaneColTiles + col_tile) * LaneDTiles + d_tile) * + 32 + + lane_out) * + 8; + uint2 packed; + packed.x = packed0; + packed.y = packed1; + *reinterpret_cast(value_out + out_off) = packed; + } + } else if constexpr (TransposedValueStaging) { + for (int linear = tid; linear < HeadDim * ValueRowGroups4; linear += Threads) { + const int d = linear / ValueRowGroups4; + const int row4 = linear - d * ValueRowGroups4; + const int row = row4 * 4; + const int64_t seq = base_row + row; + float v0 = 0.0f; + float v1 = 0.0f; + float v2 = 0.0f; + float v3 = 0.0f; + if (FullGroupsNoTail || seq < effective_kv_len) { + const int64_t value_base = (head_seq_base + seq) * HeadDim + d; + v0 = value_to_float(value[value_base + 0 * HeadDim]); + if (FullGroupsNoTail || seq + 1 < effective_kv_len) { + v1 = value_to_float(value[value_base + 1 * HeadDim]); + } + if (FullGroupsNoTail || seq + 2 < effective_kv_len) { + v2 = value_to_float(value[value_base + 2 * HeadDim]); + } + if (FullGroupsNoTail || seq + 3 < effective_kv_len) { + v3 = value_to_float(value[value_base + 3 * HeadDim]); + } + } + value_tile_t[d][row4] = static_cast(pack_f32x4_to_ocp_fp8(v0, v1, v2, v3)); + } + __syncthreads(); + + for (int linear = tid; linear < HeadDim * RowGroups; linear += Threads) { + const int d = linear / RowGroups; + const int row = (linear - d * RowGroups) * StoreRows; + const int64_t seq = base_row + row; + const int row4 = row >> 2; + if constexpr (FullGroupsNoTail) { + uint4 packed; + packed.x = value_tile_t[d][row4 + 0]; + packed.y = value_tile_t[d][row4 + 1]; + packed.z = 
value_tile_t[d][row4 + 2]; + packed.w = value_tile_t[d][row4 + 3]; + const int64_t out_off = transposed_value_head_base + + static_cast(d) * effective_kv_len + seq; + *reinterpret_cast(value_out + out_off) = packed; + } else if (seq + StoreRows - 1 < effective_kv_len) { + uint4 packed; + packed.x = value_tile_t[d][row4 + 0]; + packed.y = value_tile_t[d][row4 + 1]; + packed.z = value_tile_t[d][row4 + 2]; + packed.w = value_tile_t[d][row4 + 3]; + const int64_t out_off = transposed_value_head_base + + static_cast(d) * effective_kv_len + seq; + *reinterpret_cast(value_out + out_off) = packed; + } else { +#pragma unroll + for (int i = 0; i < StoreRows; ++i) { + const int64_t tail_seq = seq + i; + if (tail_seq < effective_kv_len) { + const uint32_t packed = value_tile_t[d][row4 + (i >> 2)]; + const int64_t out_off = transposed_value_head_base + + static_cast(d) * effective_kv_len + tail_seq; + value_out[out_off] = static_cast((packed >> (8 * (i & 3))) & 0xffu); + } + } + } + } + } else { + __syncthreads(); + + for (int linear = tid; linear < HeadDim * RowGroups; linear += Threads) { + const int d = linear / RowGroups; + const int row = (linear - d * RowGroups) * StoreRows; + const int64_t seq = base_row + row; + if constexpr (FullGroupsNoTail) { + uint4 packed; + packed.x = static_cast(value_tile[row + 0][d]) | + (static_cast(value_tile[row + 1][d]) << 8) | + (static_cast(value_tile[row + 2][d]) << 16) | + (static_cast(value_tile[row + 3][d]) << 24); + packed.y = static_cast(value_tile[row + 4][d]) | + (static_cast(value_tile[row + 5][d]) << 8) | + (static_cast(value_tile[row + 6][d]) << 16) | + (static_cast(value_tile[row + 7][d]) << 24); + packed.z = static_cast(value_tile[row + 8][d]) | + (static_cast(value_tile[row + 9][d]) << 8) | + (static_cast(value_tile[row + 10][d]) << 16) | + (static_cast(value_tile[row + 11][d]) << 24); + packed.w = static_cast(value_tile[row + 12][d]) | + (static_cast(value_tile[row + 13][d]) << 8) | + (static_cast(value_tile[row + 14][d]) << 
16) | + (static_cast(value_tile[row + 15][d]) << 24); + const int64_t out_off = transposed_value_head_base + + static_cast(d) * effective_kv_len + seq; + *reinterpret_cast(value_out + out_off) = packed; + } else if (seq + StoreRows - 1 < effective_kv_len) { + uint4 packed; + packed.x = static_cast(value_tile[row + 0][d]) | + (static_cast(value_tile[row + 1][d]) << 8) | + (static_cast(value_tile[row + 2][d]) << 16) | + (static_cast(value_tile[row + 3][d]) << 24); + packed.y = static_cast(value_tile[row + 4][d]) | + (static_cast(value_tile[row + 5][d]) << 8) | + (static_cast(value_tile[row + 6][d]) << 16) | + (static_cast(value_tile[row + 7][d]) << 24); + packed.z = static_cast(value_tile[row + 8][d]) | + (static_cast(value_tile[row + 9][d]) << 8) | + (static_cast(value_tile[row + 10][d]) << 16) | + (static_cast(value_tile[row + 11][d]) << 24); + packed.w = static_cast(value_tile[row + 12][d]) | + (static_cast(value_tile[row + 13][d]) << 8) | + (static_cast(value_tile[row + 14][d]) << 16) | + (static_cast(value_tile[row + 15][d]) << 24); + const int64_t out_off = transposed_value_head_base + + static_cast(d) * effective_kv_len + seq; + *reinterpret_cast(value_out + out_off) = packed; + } else { +#pragma unroll + for (int i = 0; i < StoreRows; ++i) { + const int64_t tail_seq = seq + i; + if (tail_seq < effective_kv_len) { + const int64_t out_off = transposed_value_head_base + + static_cast(d) * effective_kv_len + tail_seq; + value_out[out_off] = value_tile[row + i][d]; + } + } + } + } + } +} + +} // namespace + +Tensor transpose_value_fp8_hnd_gfx12(Tensor value) { + return transpose_value_hnd_gfx12(value); +} + +Tensor transpose_value_fp8_scaled_hnd_gfx12(Tensor value, Tensor value_scale) { + STD_TORCH_CHECK(value.is_cuda() && value_scale.is_cuda(), + "gfx12 scaled value transpose expects CUDA/HIP tensors"); + STD_TORCH_CHECK(value.dim() == 4, "gfx12 scaled value transpose expects [B, H, S, D]"); + STD_TORCH_CHECK(value.is_contiguous(), "gfx12 scaled value transpose 
expects contiguous HND input"); + STD_TORCH_CHECK(value.scalar_type() == ScalarType::Half || value.scalar_type() == ScalarType::BFloat16, + "gfx12 scaled value transpose supports fp16/bf16 input"); + STD_TORCH_CHECK(value_scale.scalar_type() == ScalarType::Float, + "gfx12 value scale must be fp32"); + STD_TORCH_CHECK(value_scale.dim() == 3 && value_scale.is_contiguous(), + "gfx12 value scale expects contiguous [B, H, D]"); + STD_TORCH_CHECK(value_scale.size(0) == value.size(0) && + value_scale.size(1) == value.size(1) && + value_scale.size(2) == value.size(3), + "gfx12 value scale shape must match [B, H, D]"); + + const int64_t batch = value.size(0); + const int64_t heads = value.size(1); + const int64_t seq_len = value.size(2); + const int64_t head_dim = value.size(3); + Tensor output = new_empty_like(value, {batch, heads, head_dim, seq_len}, ScalarType::Byte); + + dim3 block(256); + dim3 grid((seq_len + 127) / 128, (head_dim + 15) / 16, batch * heads); + if (value.scalar_type() == ScalarType::Half) { + transpose_value_fp8_scaled_hnd_kernel<__half><<>>( + reinterpret_cast(value.data_ptr()), + reinterpret_cast(value_scale.data_ptr()), + reinterpret_cast(output.data_ptr()), + batch * heads, seq_len, head_dim); + } else { + transpose_value_fp8_scaled_hnd_kernel<__hip_bfloat16><<>>( + reinterpret_cast(value.data_ptr()), + reinterpret_cast(value_scale.data_ptr()), + reinterpret_cast(output.data_ptr()), + batch * heads, seq_len, head_dim); + } + hip_kernel_launch_check(); + return output; +} + +Tensor transpose_value_f16_hnd_gfx12(Tensor value) { + return transpose_value_hnd_gfx12<__half, false>(value); +} + +template +std::vector prepare_qkv_hnd_gfx12( + Tensor query, + Tensor key, + Tensor value) { + STD_TORCH_CHECK(query.is_cuda() && key.is_cuda() && value.is_cuda(), + "gfx12 QKV preparation expects CUDA/HIP tensors"); + STD_TORCH_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4, + "gfx12 QKV preparation expects [B, H, S, D]"); + 
STD_TORCH_CHECK(query.is_contiguous() && key.is_contiguous() && value.is_contiguous(), + "gfx12 QKV preparation expects contiguous HND tensors"); + STD_TORCH_CHECK(query.scalar_type() == key.scalar_type() && query.scalar_type() == value.scalar_type(), + "gfx12 QKV preparation expects matching input dtypes"); + STD_TORCH_CHECK(query.scalar_type() == ScalarType::Half || query.scalar_type() == ScalarType::BFloat16, + "gfx12 QKV preparation supports fp16/bf16 input"); + STD_TORCH_CHECK(query.size(0) == key.size(0) && query.size(0) == value.size(0), + "Q/K/V batch size mismatch"); + STD_TORCH_CHECK(query.size(3) == key.size(3) && query.size(3) == value.size(3), + "Q/K/V head_dim mismatch"); + STD_TORCH_CHECK(key.size(1) == value.size(1) && key.size(2) == value.size(2), + "K/V shape mismatch"); + + const int64_t batch = query.size(0); + const int64_t q_heads = query.size(1); + const int64_t q_len = query.size(2); + const int64_t kv_heads = key.size(1); + const int64_t kv_len = key.size(2); + const int64_t head_dim = query.size(3); + STD_TORCH_CHECK(head_dim == 16 || head_dim == 64 || head_dim == 128, + "gfx12 QKV preparation supports head_dim 16, 64, or 128"); + STD_TORCH_CHECK((q_len % 64) == 0 && (kv_len % 64) == 0, + "gfx12 QKV preparation requires sequence lengths divisible by 64"); + + const int q_groups = static_cast((q_len + 31) / 32); + const int q_task_groups = (q_groups + 1) / 2; + const int k_groups = static_cast((kv_len + 63) / 64); + const bool fuse_self_qkv = + q_heads == kv_heads && q_len == kv_len && q_task_groups == k_groups; + Tensor query_out = new_empty_like(query, query.sizes(), ScalarType::Char); + Tensor key_out = new_empty_like(key, key.sizes(), ScalarType::Char); + Tensor query_scale = + new_empty_like(query, {batch, q_heads, q_groups}, ScalarType::Float); + Tensor key_scale = + new_empty_like(key, {batch, kv_heads, k_groups}, ScalarType::Float); + const ScalarType value_dtype = ToFp8 ? 
ScalarType::Byte : ScalarType::Half; + Tensor value_out = new_empty_like(value, {batch, kv_heads, head_dim, kv_len}, value_dtype); + + constexpr int D64PrepThreads = 256; + const dim3 block(head_dim == 64 ? D64PrepThreads : 256); + const dim3 grid(fuse_self_qkv ? k_groups : (q_task_groups + k_groups), + std::max(q_heads, kv_heads), + batch); + const bool use_qkv_static_1024 = + ToFp8 && head_dim == 128 && q_len == 1024 && kv_len == 1024 && + fuse_self_qkv; + if (query.scalar_type() == ScalarType::Half) { + if (head_dim == 16) { + prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 16, 256><<>>( + reinterpret_cast(query.data_ptr()), + reinterpret_cast(key.data_ptr()), + reinterpret_cast(value.data_ptr()), + reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), + reinterpret_cast(value_out.data_ptr()), + batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, + fuse_self_qkv); + } else if (head_dim == 64) { + prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 64, D64PrepThreads><<>>( + reinterpret_cast(query.data_ptr()), + reinterpret_cast(key.data_ptr()), + reinterpret_cast(value.data_ptr()), + reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), + reinterpret_cast(value_out.data_ptr()), + batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, + fuse_self_qkv); + } else { + if (use_qkv_static_1024) { + prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 128, 256, + true, true, false, false, false, + true, 1024, 1024><<>>( + reinterpret_cast(query.data_ptr()), + reinterpret_cast(key.data_ptr()), + reinterpret_cast(value.data_ptr()), + reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), + reinterpret_cast(value_out.data_ptr()), + batch, q_heads, kv_heads, 
q_len, kv_len, q_groups, k_groups, + true); + } else { + prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 128><<>>( + reinterpret_cast(query.data_ptr()), + reinterpret_cast(key.data_ptr()), + reinterpret_cast(value.data_ptr()), + reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), + reinterpret_cast(value_out.data_ptr()), + batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, + fuse_self_qkv); + } + } + } else { + if (head_dim == 16) { + prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 16, 256><<>>( + reinterpret_cast(query.data_ptr()), + reinterpret_cast(key.data_ptr()), + reinterpret_cast(value.data_ptr()), + reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), + reinterpret_cast(value_out.data_ptr()), + batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, + fuse_self_qkv); + } else if (head_dim == 64) { + prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 64, D64PrepThreads><<>>( + reinterpret_cast(query.data_ptr()), + reinterpret_cast(key.data_ptr()), + reinterpret_cast(value.data_ptr()), + reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), + reinterpret_cast(value_out.data_ptr()), + batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, + fuse_self_qkv); + } else { + if (use_qkv_static_1024) { + prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 128, 256, + true, true, false, false, false, + true, 1024, 1024><<>>( + reinterpret_cast(query.data_ptr()), + reinterpret_cast(key.data_ptr()), + reinterpret_cast(value.data_ptr()), + reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), + 
reinterpret_cast(value_out.data_ptr()), + batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, + true); + } else { + prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 128><<>>( + reinterpret_cast(query.data_ptr()), + reinterpret_cast(key.data_ptr()), + reinterpret_cast(value.data_ptr()), + reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), + reinterpret_cast(value_out.data_ptr()), + batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, + fuse_self_qkv); + } + } + } + hip_kernel_launch_check(); + return {query_out, query_scale, key_out, key_scale, value_out}; +} + +template +std::vector prepare_qkv_hnd_packed_gfx12( + Tensor query, + Tensor key, + Tensor value) { + STD_TORCH_CHECK(query.is_cuda() && key.is_cuda() && value.is_cuda(), + "packed gfx12 QKV preparation expects CUDA/HIP tensors"); + STD_TORCH_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4, + "packed gfx12 QKV preparation expects [B, H, S, D]"); + STD_TORCH_CHECK(query.is_contiguous() && key.is_contiguous() && value.is_contiguous(), + "packed gfx12 QKV preparation expects contiguous HND tensors"); + STD_TORCH_CHECK(query.scalar_type() == key.scalar_type() && query.scalar_type() == value.scalar_type(), + "packed gfx12 QKV preparation expects matching input dtypes"); + + const int64_t batch = query.size(0); + const int64_t q_heads = query.size(1); + const int64_t q_len = query.size(2); + const int64_t kv_heads = key.size(1); + const int64_t kv_len = key.size(2); + const int64_t head_dim = query.size(3); + STD_TORCH_CHECK(head_dim == 16 || head_dim == 64 || head_dim == 128, + "packed gfx12 QKV preparation supports head_dim 16, 64, or 128"); + STD_TORCH_CHECK(!PrepackF16KLane || (!ToFp8 && head_dim == 64), + "fp16 lane-major K prepack requires D64 fp16/bf16 keys"); + STD_TORCH_CHECK(!PrepackFp8Lane || (ToFp8 && TransposeValue && head_dim == 64), + "fp8 lane-major K/V 
prepack requires transposed D64 fp8 values"); + STD_TORCH_CHECK(!PrepackFp8Lane || ((key.size(2) % 64) == 0), + "fp8 lane-major K/V prepack requires kv_len divisible by 64"); + STD_TORCH_CHECK(key.size(0) == batch && value.size(0) == batch, + "Q/K/V batch size mismatch"); + STD_TORCH_CHECK(key.size(2) == kv_len && value.size(2) == kv_len && value.size(1) == kv_heads, + "K/V shape mismatch"); + STD_TORCH_CHECK(key.size(3) == head_dim && value.size(3) == head_dim, + "Q/K/V head_dim mismatch"); + + const int q_groups = static_cast((q_len + 31) / 32); + const int q_task_groups = (q_groups + 1) / 2; + const int k_groups = static_cast((kv_len + 63) / 64); + const bool fuse_self_qkv = + q_heads == kv_heads && q_len == kv_len && q_task_groups == k_groups; + + const int64_t q_numel = query.numel(); + const int64_t k_numel = key.numel(); + const int64_t v_numel = batch * kv_heads * head_dim * kv_len; + const int64_t q_scale_numel = batch * q_heads * q_groups; + const int64_t k_scale_numel = batch * kv_heads * k_groups; + const bool pack_scales_with_bytes = ToFp8; + const int64_t value_byte_offset = q_numel + k_numel; + const int64_t scale_byte_offset = value_byte_offset + (ToFp8 ? v_numel : 0); + Tensor byte_workspace = + new_empty_like(query, {scale_byte_offset + + (pack_scales_with_bytes ? 
(q_scale_numel + k_scale_numel) * 4 : 0)}, ScalarType::Byte); + uint8_t* byte_ptr = reinterpret_cast(byte_workspace.data_ptr()); + int8_t* query_ptr = reinterpret_cast(byte_ptr); + int8_t* key_ptr = reinterpret_cast(byte_ptr + q_numel); + Tensor value_out; + OutT* value_ptr = nullptr; + if constexpr (ToFp8) { + value_ptr = reinterpret_cast(byte_ptr + value_byte_offset); + if constexpr (TransposeValue) { + value_out = from_blob_like(value_ptr, {batch, kv_heads, head_dim, kv_len}, value, ScalarType::Byte); + } else { + value_out = from_blob_like(value_ptr, {batch, kv_heads, kv_len, head_dim}, value, ScalarType::Byte); + } + } else { + value_out = new_empty_like(value, {batch, kv_heads, head_dim, kv_len}, ScalarType::Half); + value_ptr = reinterpret_cast(value_out.data_ptr()); + } + + Tensor scale_workspace; + float* scale_ptr = nullptr; + if (pack_scales_with_bytes) { + scale_ptr = reinterpret_cast(byte_ptr + scale_byte_offset); + } else { + scale_workspace = + new_empty_like(query, {q_scale_numel + k_scale_numel}, ScalarType::Float); + scale_ptr = reinterpret_cast(scale_workspace.data_ptr()); + } + + Tensor query_out = from_blob_like(query_ptr, {batch, q_heads, q_len, head_dim}, query, ScalarType::Char); + Tensor key_out = from_blob_like(key_ptr, {batch, kv_heads, kv_len, head_dim}, key, ScalarType::Char); + Tensor query_scale = from_blob_like(scale_ptr, {batch, q_heads, q_groups}, query, ScalarType::Float); + Tensor key_scale = from_blob_like(scale_ptr + q_scale_numel, {batch, kv_heads, k_groups}, key, ScalarType::Float); + + constexpr int D64PrepThreads = 256; + const dim3 block(head_dim == 64 ? D64PrepThreads : 256); + const dim3 grid(fuse_self_qkv ? 
k_groups : (q_task_groups + k_groups), + std::max(q_heads, kv_heads), + batch); + if (query.scalar_type() == ScalarType::Half) { + if (head_dim == 16) { + prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 16, 256, TransposeValue, true><<>>( + reinterpret_cast(query.data_ptr()), + reinterpret_cast(key.data_ptr()), + reinterpret_cast(value.data_ptr()), + query_ptr, key_ptr, scale_ptr, scale_ptr + q_scale_numel, + value_ptr, + batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, + fuse_self_qkv); + } else if (head_dim == 64) { + prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 64, D64PrepThreads, TransposeValue, true, false, PrepackF16KLane, PrepackFp8Lane><<>>( + reinterpret_cast(query.data_ptr()), + reinterpret_cast(key.data_ptr()), + reinterpret_cast(value.data_ptr()), + query_ptr, key_ptr, scale_ptr, scale_ptr + q_scale_numel, + value_ptr, + batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, + fuse_self_qkv); + } else { + prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 128, 256, TransposeValue, true, false, false, PrepackFp8Lane><<>>( + reinterpret_cast(query.data_ptr()), + reinterpret_cast(key.data_ptr()), + reinterpret_cast(value.data_ptr()), + query_ptr, key_ptr, scale_ptr, scale_ptr + q_scale_numel, + value_ptr, + batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, + fuse_self_qkv); + } + } else { + if (head_dim == 16) { + prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 16, 256, TransposeValue, true><<>>( + reinterpret_cast(query.data_ptr()), + reinterpret_cast(key.data_ptr()), + reinterpret_cast(value.data_ptr()), + query_ptr, key_ptr, scale_ptr, scale_ptr + q_scale_numel, + value_ptr, + batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, + fuse_self_qkv); + } else if (head_dim == 64) { + prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 64, D64PrepThreads, TransposeValue, true, false, PrepackF16KLane, PrepackFp8Lane><<>>( + reinterpret_cast(query.data_ptr()), + reinterpret_cast(key.data_ptr()), + reinterpret_cast(value.data_ptr()), 
+                query_ptr, key_ptr, scale_ptr, scale_ptr + q_scale_numel,
+                value_ptr,
+                batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups,
+                fuse_self_qkv);
+        } else {
+            prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 128, 256, TransposeValue, true, false, false, PrepackFp8Lane><<>>(
+                reinterpret_cast(query.data_ptr()),
+                reinterpret_cast(key.data_ptr()),
+                reinterpret_cast(value.data_ptr()),
+                query_ptr, key_ptr, scale_ptr, scale_ptr + q_scale_numel,
+                value_ptr,
+                batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups,
+                fuse_self_qkv);
+        }
+    }
+    hip_kernel_launch_check();
+    if (pack_scales_with_bytes) {
+        return {query_out, query_scale, key_out, key_scale, value_out, byte_workspace};
+    }
+    return {query_out, query_scale, key_out, key_scale, value_out, byte_workspace, scale_workspace};
+}
+
+// Prepares K (int8 + per-group float scales) and V for the gfx12 attention
+// kernels from contiguous HND [B, H, S, D] inputs. Q is only used for shape,
+// dtype and allocation options here; no quantized Q is produced (the Q output
+// pointers passed to prepare_qkv_hnd_kernel are nullptr).
+//
+// Workspace layout:
+//   * ToFp8:  one byte buffer holding [K int8 | V fp8 | K scales (float)];
+//             key_out / value_out / key_scale are views into it.
+//   * !ToFp8: the byte buffer holds K int8 only; V is a separate Half tensor
+//             and the K scales live in their own float tensor.
+//
+// Launch shape: 256 threads per block, grid (k_groups, kv_heads, batch) with
+// k_groups = ceil(kv_len / 64). The fp8 transposed path for q_len > 4096 uses
+// a grid with half as many blocks along x (presumably two K groups per block;
+// confirm against prepare_kv_hnd_fp8_kernel).
+//
+// Returns {key_out, key_scale, value_out, byte_workspace} when ToFp8, otherwise
+// {key_out, key_scale, value_out, byte_workspace, scale_workspace}. The
+// workspaces are returned so the views stay backed by live storage.
+template
+std::vector prepare_kv_hnd_packed_gfx12(
+    Tensor query,
+    Tensor key,
+    Tensor value) {
+    STD_TORCH_CHECK(query.is_cuda() && key.is_cuda() && value.is_cuda(),
+                    "packed gfx12 KV preparation expects CUDA/HIP tensors");
+    STD_TORCH_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4,
+                    "packed gfx12 KV preparation expects [B, H, S, D]");
+    STD_TORCH_CHECK(query.is_contiguous() && key.is_contiguous() && value.is_contiguous(),
+                    "packed gfx12 KV preparation expects contiguous HND tensors");
+    STD_TORCH_CHECK(query.scalar_type() == key.scalar_type() && query.scalar_type() == value.scalar_type(),
+                    "packed gfx12 KV preparation expects matching input dtypes");
+    // The dispatch below treats every non-Half dtype as bf16, so reject anything
+    // else up front instead of reinterpreting e.g. float32 storage as bf16.
+    STD_TORCH_CHECK(query.scalar_type() == ScalarType::Half || query.scalar_type() == ScalarType::BFloat16,
+                    "packed gfx12 KV preparation supports fp16/bf16 input");
+
+    const int64_t batch = query.size(0);
+    const int64_t q_heads = query.size(1);
+    const int64_t q_len = query.size(2);
+    const int64_t kv_heads = key.size(1);
+    const int64_t kv_len = key.size(2);
+    const int64_t head_dim = query.size(3);
+    STD_TORCH_CHECK(head_dim == 16 || head_dim == 64 || head_dim == 128,
+                    "packed gfx12 KV preparation supports head_dim 16, 64, or 128");
+    // Each prepack flag is only valid for specific value formats / head dims.
+    STD_TORCH_CHECK(!PrepackF16VLane || (!ToFp8 && TransposeValue && head_dim == 64),
+                    "fp16 lane-major V prepack requires transposed D64 fp16 values");
+    STD_TORCH_CHECK(!PrepackF16KLane || (!ToFp8 && head_dim == 64),
+                    "fp16 lane-major K prepack requires D64 fp16/bf16 keys");
+    STD_TORCH_CHECK(!PrepackFp8Lane || (ToFp8 && TransposeValue && head_dim == 64),
+                    "fp8 lane-major K/V prepack requires transposed D64 fp8 values");
+    STD_TORCH_CHECK(!PrepackFp8Lane || ((key.size(2) % 64) == 0),
+                    "fp8 lane-major K/V prepack requires kv_len divisible by 64");
+    STD_TORCH_CHECK(!PrepackFp8VLane ||
+                    (ToFp8 && TransposeValue && (head_dim == 64 || head_dim == 128)),
+                    "fp8 lane-major V prepack requires transposed D64/D128 fp8 values");
+    STD_TORCH_CHECK(!PrepackFp8VLane || ((key.size(2) % 64) == 0),
+                    "fp8 lane-major V prepack requires kv_len divisible by 64");
+    STD_TORCH_CHECK(!PrepackFp8KLane ||
+                    (ToFp8 && TransposeValue && (head_dim == 64 || head_dim == 128)),
+                    "fp8 lane-major K prepack requires transposed D64/D128 fp8 values");
+    STD_TORCH_CHECK(!PrepackFp8KLane || ((key.size(2) % 64) == 0),
+                    "fp8 lane-major K prepack requires kv_len divisible by 64");
+    STD_TORCH_CHECK(key.size(0) == batch && value.size(0) == batch,
+                    "Q/K/V batch size mismatch");
+    STD_TORCH_CHECK(key.size(2) == kv_len && value.size(2) == kv_len && value.size(1) == kv_heads,
+                    "K/V shape mismatch");
+    STD_TORCH_CHECK(key.size(3) == head_dim && value.size(3) == head_dim,
+                    "Q/K/V head_dim mismatch");
+
+    // Q scales are grouped per 32 rows, K scales per 64 rows.
+    const int q_groups = static_cast((q_len + 31) / 32);
+    const int k_groups = static_cast((kv_len + 63) / 64);
+    const int64_t k_numel = key.numel();
+    const int64_t v_numel = batch * kv_heads * head_dim * kv_len;
+    const int64_t k_scale_numel = batch * kv_heads * k_groups;
+    // Byte offsets of the V and scale regions inside the packed workspace.
+    const int64_t value_byte_offset = k_numel;
+    const int64_t scale_byte_offset = value_byte_offset + (ToFp8 ? v_numel : 0);
+    Tensor byte_workspace =
+        new_empty_like(query, {scale_byte_offset + (ToFp8 ? k_scale_numel * 4 : 0)}, ScalarType::Byte);
+    uint8_t* byte_ptr = reinterpret_cast(byte_workspace.data_ptr());
+    int8_t* key_ptr = reinterpret_cast(byte_ptr);
+    OutT* value_ptr = nullptr;
+    Tensor value_out;
+    if constexpr (ToFp8) {
+        // fp8 V lives in the packed workspace right after K.
+        value_ptr = reinterpret_cast(byte_ptr + value_byte_offset);
+        if constexpr (TransposeValue) {
+            value_out = from_blob_like(value_ptr, {batch, kv_heads, head_dim, kv_len}, value, ScalarType::Byte);
+        } else {
+            value_out = from_blob_like(value_ptr, {batch, kv_heads, kv_len, head_dim}, value, ScalarType::Byte);
+        }
+    } else {
+        // 16-bit V gets its own allocation.
+        if constexpr (PrepackF16VLane) {
+            value_out = new_empty_like(value, {batch, kv_heads, k_groups, 4, 4, 32, 8}, ScalarType::Half);
+        } else {
+            value_out = new_empty_like(value, {batch, kv_heads, head_dim, kv_len}, ScalarType::Half);
+        }
+        value_ptr = reinterpret_cast(value_out.data_ptr());
+    }
+    Tensor scale_workspace;
+    float* key_scale_ptr = nullptr;
+    if constexpr (ToFp8) {
+        key_scale_ptr = reinterpret_cast(byte_ptr + scale_byte_offset);
+    } else {
+        scale_workspace = new_empty_like(query, {k_scale_numel}, ScalarType::Float);
+        key_scale_ptr = reinterpret_cast(scale_workspace.data_ptr());
+    }
+    Tensor key_out = from_blob_like(key_ptr, {batch, kv_heads, kv_len, head_dim}, key, ScalarType::Char);
+    Tensor key_scale = from_blob_like(key_scale_ptr, {batch, kv_heads, k_groups}, key, ScalarType::Float);
+
+    constexpr int D64PrepThreads = 256;
+    const dim3 block(head_dim == 64 ? D64PrepThreads : 256);
+    const dim3 grid(k_groups, kv_heads, batch);
+    // Kernel variant selection for the fp8 transposed path.
+    const bool use_kv1 = q_len <= 4096;
+    const bool use_kv_static_1024 =
+        ToFp8 && TransposeValue && use_kv1 &&
+        q_len == 1024 && kv_len == 1024 &&
+        !PrepackFp8VLane && !PrepackFp8KLane;
+    if (query.scalar_type() == ScalarType::Half) {
+        if (head_dim == 16) {
+            prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 16, 256, TransposeValue, false><<>>(
+                reinterpret_cast(query.data_ptr()),
+                reinterpret_cast(key.data_ptr()),
+                reinterpret_cast(value.data_ptr()),
+                nullptr, key_ptr, nullptr, key_scale_ptr, value_ptr,
+                batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups,
+                false);
+        } else if (head_dim == 64) {
+            if constexpr (ToFp8 && TransposeValue) {
+                if (use_kv1) {
+                    if (use_kv_static_1024) {
+                        prepare_kv_hnd_fp8_kernel<__half, 1, false, PrepackFp8Lane, 64, false, false, false, true, 256, 1024><<>>(
+                            reinterpret_cast(key.data_ptr()),
+                            reinterpret_cast(value.data_ptr()),
+                            key_ptr, key_scale_ptr, reinterpret_cast(value_ptr),
+                            batch, kv_heads, kv_len, k_groups);
+                    } else {
+                        prepare_kv_hnd_fp8_kernel<__half, 1, false, PrepackFp8Lane><<>>(
+                            reinterpret_cast(key.data_ptr()),
+                            reinterpret_cast(value.data_ptr()),
+                            key_ptr, key_scale_ptr, reinterpret_cast(value_ptr),
+                            batch, kv_heads, kv_len, k_groups);
+                    }
+                } else {
+                    const dim3 grid_kv((k_groups + 1) / 2, kv_heads, batch);
+                    prepare_kv_hnd_fp8_kernel<__half, 2, false, PrepackFp8Lane><<>>(
+                        reinterpret_cast(key.data_ptr()),
+                        reinterpret_cast(value.data_ptr()),
+                        key_ptr, key_scale_ptr, reinterpret_cast(value_ptr),
+                        batch, kv_heads, kv_len, k_groups);
+                }
+            } else {
+                prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 64, D64PrepThreads, TransposeValue, false, PrepackF16VLane, PrepackF16KLane><<>>(
+                    reinterpret_cast(query.data_ptr()),
+                    reinterpret_cast(key.data_ptr()),
+                    reinterpret_cast(value.data_ptr()),
+                    nullptr, key_ptr, nullptr, key_scale_ptr, value_ptr,
+                    batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups,
+                    false);
+            }
+        } else {
+            if constexpr (ToFp8 && TransposeValue) {
+                if (use_kv1) {
+                    if (use_kv_static_1024) {
+                        prepare_kv_hnd_fp8_kernel<__half, 1, false, false, 128, false, false, false, true, 256, 1024><<>>(
+                            reinterpret_cast(key.data_ptr()),
+                            reinterpret_cast(value.data_ptr()),
+                            key_ptr, key_scale_ptr, reinterpret_cast(value_ptr),
+                            batch, kv_heads, kv_len, k_groups);
+                    } else {
+                        prepare_kv_hnd_fp8_kernel<__half, 1, false, false, 128><<>>(
+                            reinterpret_cast(key.data_ptr()),
+                            reinterpret_cast(value.data_ptr()),
+                            key_ptr, key_scale_ptr, reinterpret_cast(value_ptr),
+                            batch, kv_heads, kv_len, k_groups);
+                    }
+                } else {
+                    const dim3 grid_kv((k_groups + 1) / 2, kv_heads, batch);
+                    prepare_kv_hnd_fp8_kernel<__half, 2, false, false, 128><<>>(
+                        reinterpret_cast(key.data_ptr()),
+                        reinterpret_cast(value.data_ptr()),
+                        key_ptr, key_scale_ptr, reinterpret_cast(value_ptr),
+                        batch, kv_heads, kv_len, k_groups);
+                }
+            } else {
+                prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 128, 256, TransposeValue, false><<>>(
+                    reinterpret_cast(query.data_ptr()),
+                    reinterpret_cast(key.data_ptr()),
+                    reinterpret_cast(value.data_ptr()),
+                    nullptr, key_ptr, nullptr, key_scale_ptr, value_ptr,
+                    batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups,
+                    false);
+            }
+        }
+    } else {
+        // bf16 input: same dispatch as above with __hip_bfloat16 element type.
+        if (head_dim == 16) {
+            prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 16, 256, TransposeValue, false><<>>(
+                reinterpret_cast(query.data_ptr()),
+                reinterpret_cast(key.data_ptr()),
+                reinterpret_cast(value.data_ptr()),
+                nullptr, key_ptr, nullptr, key_scale_ptr, value_ptr,
+                batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups,
+                false);
+        } else if (head_dim == 64) {
+            if constexpr (ToFp8 && TransposeValue) {
+                if (use_kv1) {
+                    if (use_kv_static_1024) {
+                        prepare_kv_hnd_fp8_kernel<__hip_bfloat16, 1, false, PrepackFp8Lane, 64, false, false, false, true, 256, 1024><<>>(
+                            reinterpret_cast(key.data_ptr()),
+                            reinterpret_cast(value.data_ptr()),
+                            key_ptr, key_scale_ptr, reinterpret_cast(value_ptr),
+                            batch, kv_heads, kv_len, k_groups);
+                    } else {
+                        prepare_kv_hnd_fp8_kernel<__hip_bfloat16, 1, false, PrepackFp8Lane><<>>(
+                            reinterpret_cast(key.data_ptr()),
+                            reinterpret_cast(value.data_ptr()),
+                            key_ptr, key_scale_ptr, reinterpret_cast(value_ptr),
+                            batch, kv_heads, kv_len, k_groups);
+                    }
+                } else {
+                    const dim3 grid_kv((k_groups + 1) / 2, kv_heads, batch);
+                    prepare_kv_hnd_fp8_kernel<__hip_bfloat16, 2, false, PrepackFp8Lane><<>>(
+                        reinterpret_cast(key.data_ptr()),
+                        reinterpret_cast(value.data_ptr()),
+                        key_ptr, key_scale_ptr, reinterpret_cast(value_ptr),
+                        batch, kv_heads, kv_len, k_groups);
+                }
+            } else {
+                prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 64, D64PrepThreads, TransposeValue, false, PrepackF16VLane, PrepackF16KLane><<>>(
+                    reinterpret_cast(query.data_ptr()),
+                    reinterpret_cast(key.data_ptr()),
+                    reinterpret_cast(value.data_ptr()),
+                    nullptr, key_ptr, nullptr, key_scale_ptr, value_ptr,
+                    batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups,
+                    false);
+            }
+        } else {
+            if constexpr (ToFp8 && TransposeValue) {
+                if (use_kv1) {
+                    if (use_kv_static_1024) {
+                        prepare_kv_hnd_fp8_kernel<__hip_bfloat16, 1, false, false, 128, false, false, false, true, 256, 1024><<>>(
+                            reinterpret_cast(key.data_ptr()),
+                            reinterpret_cast(value.data_ptr()),
+                            key_ptr, key_scale_ptr, reinterpret_cast(value_ptr),
+                            batch, kv_heads, kv_len, k_groups);
+                    } else {
+                        prepare_kv_hnd_fp8_kernel<__hip_bfloat16, 1, false, false, 128><<>>(
+                            reinterpret_cast(key.data_ptr()),
+                            reinterpret_cast(value.data_ptr()),
+                            key_ptr, key_scale_ptr, reinterpret_cast(value_ptr),
+                            batch, kv_heads, kv_len, k_groups);
+                    }
+                } else {
+                    const dim3 grid_kv((k_groups + 1) / 2, kv_heads, batch);
+                    prepare_kv_hnd_fp8_kernel<__hip_bfloat16, 2, false, false, 128><<>>(
+                        reinterpret_cast(key.data_ptr()),
+                        reinterpret_cast(value.data_ptr()),
+                        key_ptr, key_scale_ptr, reinterpret_cast(value_ptr),
+                        batch, kv_heads, kv_len, k_groups);
+                }
+            } else {
+                prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 128, 256, TransposeValue, false><<>>(
+                    reinterpret_cast(query.data_ptr()),
+                    reinterpret_cast(key.data_ptr()),
+                    reinterpret_cast(value.data_ptr()),
+                    nullptr, key_ptr, nullptr, key_scale_ptr, value_ptr,
+                    batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups,
+                    false);
+            }
+        }
+    }
+    hip_kernel_launch_check();
+    if constexpr (ToFp8) {
+        return {key_out, key_scale, value_out, byte_workspace};
+    }
+    return {key_out, key_scale, value_out, byte_workspace, scale_workspace};
+}
+
+// Quantizes K to int8 with one float scale per (batch, head, 64-row group).
+// Launch layout: blockIdx.x = K group, blockIdx.y = head, blockIdx.z = batch,
+// with Threads threads per block. Uses one __shared__ float.
+// Each thread moves 8 elements at a time through a uint4 load, so `key` is
+// assumed to be 16-byte aligned per 8-element pack (TODO confirm at call sites).
+template
+__global__ void prepare_k_hnd_kernel(
+    const T* __restrict__ key,
+    int8_t* __restrict__ key_out,
+    float* __restrict__ key_scale,
+    const int64_t batch,
+    const int64_t kv_heads,
+    const int64_t kv_len,
+    const int k_groups) {
+    constexpr int PackElems = 8;
+    constexpr int KRows = 64;
+    const int local_group = blockIdx.x;
+    const int head = blockIdx.y;
+    const int b = blockIdx.z;
+    const int tid = threadIdx.x;
+    const int64_t base_row = static_cast(local_group) * KRows;
+    // Block-uniform exit (depends only on blockIdx), so the barrier below is safe.
+    if (b >= batch || head >= kv_heads || local_group >= k_groups || base_row >= kv_len) {
+        return;
+    }
+
+    __shared__ float shared_amax;
+    constexpr int kv_packs = (KRows * HeadDim) / PackElems;
+    // Small positive floor keeps the later 127 / amax division finite.
+    float local_amax = 0.0000001f;
+    // Pass 1: per-thread absolute maximum over this block's 64 x HeadDim tile.
+    for (int pack = tid; pack < kv_packs; pack += Threads) {
+        const int elem_base = pack * PackElems;
+        const int row = elem_base / HeadDim;
+        const int d = elem_base - row * HeadDim;
+        const int64_t seq = base_row + row;
+        if (seq < kv_len) {
+            const int64_t off =
+                ((static_cast(b) * kv_heads + head) * kv_len + seq) * HeadDim + d;
+            const uint4 raw = *reinterpret_cast(key + off);
+            const T* values = reinterpret_cast(&raw);
+#pragma unroll
+            for (int i = 0; i < PackElems; ++i) {
+                local_amax = fmaxf(local_amax, fabsf(value_to_float(values[i])));
+            }
+        }
+    }
+
+    const float block_amax = vllm::blockReduceMax(local_amax);
+    // Thread 0 publishes the block maximum and stores the dequantization scale.
+    if (tid == 0) {
+        shared_amax = block_amax;
+        key_scale[(static_cast(b) * kv_heads + head) * k_groups + local_group] =
+            block_amax / 127.0f;
+    }
+    __syncthreads();
+    const float
inv_scale = 127.0f / shared_amax;
+
+    // Pass 2: re-read the tile, scale into the int8 range and store 8 values
+    // per pack as two 4-byte char4 writes.
+    for (int pack = tid; pack < kv_packs; pack += Threads) {
+        const int elem_base = pack * PackElems;
+        const int row = elem_base / HeadDim;
+        const int d = elem_base - row * HeadDim;
+        const int64_t seq = base_row + row;
+        if (seq < kv_len) {
+            const int64_t off =
+                ((static_cast(b) * kv_heads + head) * kv_len + seq) * HeadDim + d;
+            const uint4 raw = *reinterpret_cast(key + off);
+            const T* values = reinterpret_cast(&raw);
+            char4 out0;
+            char4 out1;
+            out0.x = float_to_int8_rn_gfx12(value_to_float(values[0]) * inv_scale);
+            out0.y = float_to_int8_rn_gfx12(value_to_float(values[1]) * inv_scale);
+            out0.z = float_to_int8_rn_gfx12(value_to_float(values[2]) * inv_scale);
+            out0.w = float_to_int8_rn_gfx12(value_to_float(values[3]) * inv_scale);
+            out1.x = float_to_int8_rn_gfx12(value_to_float(values[4]) * inv_scale);
+            out1.y = float_to_int8_rn_gfx12(value_to_float(values[5]) * inv_scale);
+            out1.z = float_to_int8_rn_gfx12(value_to_float(values[6]) * inv_scale);
+            out1.w = float_to_int8_rn_gfx12(value_to_float(values[7]) * inv_scale);
+            *reinterpret_cast(key_out + off) = out0;
+            *reinterpret_cast(key_out + off + 4) = out1;
+        }
+    }
+}
+
+// Host wrapper: quantizes a contiguous HND [B, H, S, D] fp16/bf16 key tensor to
+// int8 with one float scale per 64-row group.
+//
+// Requirements: head_dim is 64 or 128 and the sequence length is a multiple
+// of 64. Launches prepare_k_hnd_kernel with 256 threads per block on a
+// (k_groups, kv_heads, batch) grid.
+//
+// Returns {key_out, key_scale, byte_workspace, scale_workspace}; key_out and
+// key_scale are views over the two workspaces, which are returned alongside
+// them so the backing storage stays alive.
+std::vector prepare_k_hnd_packed_gfx12(Tensor key) {
+    STD_TORCH_CHECK(key.is_cuda(), "packed gfx12 K preparation expects a CUDA/HIP tensor");
+    STD_TORCH_CHECK(key.dim() == 4, "packed gfx12 K preparation expects [B, H, S, D]");
+    STD_TORCH_CHECK(key.is_contiguous(), "packed gfx12 K preparation expects contiguous HND tensors");
+    STD_TORCH_CHECK(key.scalar_type() == ScalarType::Half || key.scalar_type() == ScalarType::BFloat16,
+                    "packed gfx12 K preparation supports fp16/bf16 input");
+    const int64_t batch = key.size(0);
+    const int64_t kv_heads = key.size(1);
+    const int64_t kv_len = key.size(2);
+    const int64_t head_dim = key.size(3);
+    STD_TORCH_CHECK(head_dim == 64 || head_dim == 128,
+                    "packed gfx12 K preparation supports head_dim 64 or 128");
+    STD_TORCH_CHECK((kv_len % 64) == 0,
+                    "packed gfx12 K preparation requires sequence length divisible by 64");
+
+    const int k_groups = static_cast((kv_len + 63) / 64);
+    const int64_t k_numel = key.numel();
+    Tensor byte_workspace = new_empty_like(key, {k_numel}, ScalarType::Byte);
+    Tensor scale_workspace =
+        new_empty_like(key, {batch * kv_heads * k_groups}, ScalarType::Float);
+    int8_t* key_ptr = reinterpret_cast(reinterpret_cast(byte_workspace.data_ptr()));
+    Tensor key_out = from_blob_like(key_ptr, {batch, kv_heads, kv_len, head_dim}, key, ScalarType::Char);
+    Tensor key_scale = from_blob_like(reinterpret_cast(scale_workspace.data_ptr()), {batch, kv_heads, k_groups}, key, ScalarType::Float);
+
+    // An empty key (zero batch, heads or sequence length) would produce a grid
+    // with a zero dimension, which is not a valid launch configuration. There is
+    // nothing to quantize, so hand back the (empty) outputs without launching.
+    if (k_numel == 0) {
+        return {key_out, key_scale, byte_workspace, scale_workspace};
+    }
+
+    constexpr int Threads = 256;
+    const dim3 block(Threads);
+    const dim3 grid(k_groups, kv_heads, batch);
+    if (key.scalar_type() == ScalarType::Half) {
+        if (head_dim == 64) {
+            prepare_k_hnd_kernel<__half, 64, Threads><<>>(
+                reinterpret_cast(key.data_ptr()),
+                key_ptr, reinterpret_cast(scale_workspace.data_ptr()),
+                batch, kv_heads, kv_len, k_groups);
+        } else {
+            prepare_k_hnd_kernel<__half, 128, Threads><<>>(
+                reinterpret_cast(key.data_ptr()),
+                key_ptr, reinterpret_cast(scale_workspace.data_ptr()),
+                batch, kv_heads, kv_len, k_groups);
+        }
+    } else {
+        if (head_dim == 64) {
+            prepare_k_hnd_kernel<__hip_bfloat16, 64, Threads><<>>(
+                reinterpret_cast(key.data_ptr()),
+                key_ptr, reinterpret_cast(scale_workspace.data_ptr()),
+                batch, kv_heads, kv_len, k_groups);
+        } else {
+            prepare_k_hnd_kernel<__hip_bfloat16, 128, Threads><<>>(
+                reinterpret_cast(key.data_ptr()),
+                key_ptr, reinterpret_cast(scale_workspace.data_ptr()),
+                batch, kv_heads, kv_len, k_groups);
+        }
+    }
+    hip_kernel_launch_check();
+    return {key_out, key_scale, byte_workspace, scale_workspace};
+}
+
+// Picks the query-row tile (64, 128 or 256) for the fp8 D64 attention kernels.
+static int select_fp8_d64_block_rows_gfx12(
+    const int64_t q_len,
+    const bool is_causal,
+    const bool value_transposed_hnd) {
+    if (is_causal) {
+        if (q_len <= 64) {
+            return 64;
+        }
+        return 128;
+    }
+    if (q_len <= 64) {
+        return 64;
+    }
+    if ((q_len
% 256) == 0 && (q_len >= 2048 || value_transposed_hnd)) {
+        return 256;
+    }
+    return 128;
+}
+
+// Element-wise fp16 -> bf16 conversion (via float, round-to-nearest).
+// Launch layout: 1-D grid; every thread owns two adjacent elements, so
+// ceil(numel / (2 * blockDim.x)) blocks are required. No shared memory.
+// The paired path uses 4-byte vector accesses and therefore assumes `input`
+// and `output` are at least 4-byte aligned (TODO confirm for offset views).
+__global__ void convert_f16_to_bf16_kernel(
+    const __half* __restrict__ input,
+    __hip_bfloat16* __restrict__ output,
+    const int64_t numel) {
+    const int64_t idx = (static_cast(blockIdx.x) * blockDim.x + threadIdx.x) * 2;
+    if (idx + 1 < numel) {
+        // Full pair: one vector load, one vector store.
+        const __half2 h2 = *reinterpret_cast(input + idx);
+        const __hip_bfloat162 b2 = __float22bfloat162_rn(__half22float2(h2));
+        *reinterpret_cast<__hip_bfloat162*>(output + idx) = b2;
+    } else if (idx < numel) {
+        // Odd element count: the last element is converted on its own.
+        output[idx] = __float2bfloat16(__half2float(input[idx]));
+    }
+}
+
+// Host wrapper: returns a new bf16 tensor with the same shape as the
+// contiguous fp16 `input`. An empty input yields an empty output without
+// launching a kernel.
+Tensor convert_f16_to_bf16_gfx12(Tensor input) {
+    STD_TORCH_CHECK(input.is_cuda(), "gfx12 f16 to bf16 conversion expects a CUDA/HIP tensor");
+    STD_TORCH_CHECK(input.scalar_type() == ScalarType::Half,
+                    "gfx12 f16 to bf16 conversion expects fp16 input");
+    STD_TORCH_CHECK(input.is_contiguous(),
+                    "gfx12 f16 to bf16 conversion expects contiguous input");
+    Tensor output = new_empty_like(input, input.sizes(), ScalarType::BFloat16);
+    const int64_t numel = input.numel();
+    // With no elements the ceil-div below gives a grid of 0 blocks, which is
+    // not a valid launch configuration; there is nothing to convert anyway.
+    if (numel == 0) {
+        return output;
+    }
+    const int threads = 256;
+    const dim3 block(threads);
+    // Two elements per thread -> ceil(numel / (2 * threads)) blocks.
+    const dim3 grid((numel + threads * 2 - 1) / (threads * 2));
+    convert_f16_to_bf16_kernel<<>>(
+        reinterpret_cast(input.data_ptr()),
+        reinterpret_cast<__hip_bfloat16*>(output.data_ptr()),
+        numel);
+    hip_kernel_launch_check();
+    return output;
+}
+
+std::vector quant_q_nhd_per_warp_gfx12(Tensor query) {
+    STD_TORCH_CHECK(query.is_cuda(), "gfx12 NHD Q quantization expects a CUDA/HIP tensor");
+    STD_TORCH_CHECK(query.dim() == 4, "gfx12 NHD Q quantization expects [B, S, H, D]");
+    STD_TORCH_CHECK(query.is_contiguous(), "gfx12 NHD Q quantization expects contiguous NHD input");
+    STD_TORCH_CHECK(query.scalar_type() == ScalarType::Half || query.scalar_type() == ScalarType::BFloat16,
+                    "gfx12 NHD Q quantization supports fp16/bf16 input");
+    const int64_t batch = query.size(0);
+    const int64_t q_len = query.size(1);
+    const int64_t
q_heads = query.size(2);
+    const int64_t head_dim = query.size(3);
+    STD_TORCH_CHECK(head_dim == 16 || head_dim == 64 || head_dim == 128,
+                    "gfx12 NHD Q quantization supports head_dim 16, 64, or 128");
+    // Four scale groups for every (started) block of 128 query rows.
+    const int q_scale_groups = static_cast(((q_len + 127) / 128) * 4);
+    Tensor query_out = new_empty_like(query, query.sizes(), ScalarType::Char);
+    Tensor query_scale = new_empty_like(query, {batch, q_heads, q_scale_groups}, ScalarType::Float);
+
+    // An empty query (zero batch, heads or sequence length) would give a grid
+    // with a zero dimension, which is not a valid launch configuration. Nothing
+    // needs quantizing, so return the (empty) outputs without launching.
+    if (query.numel() == 0) {
+        return {query_out, query_scale};
+    }
+
+    const dim3 block(256);
+    // Grid x covers the scale groups two at a time.
+    const dim3 grid((q_scale_groups + 1) / 2, q_heads, batch);
+    if (query.scalar_type() == ScalarType::Half) {
+        if (head_dim == 16) {
+            quant_q_nhd_per_warp_kernel<__half, 16><<>>(
+                reinterpret_cast(query.data_ptr()),
+                reinterpret_cast(query_out.data_ptr()),
+                reinterpret_cast(query_scale.data_ptr()),
+                batch, q_len, q_heads, q_scale_groups);
+        } else if (head_dim == 64) {
+            quant_q_nhd_per_warp_kernel<__half, 64><<>>(
+                reinterpret_cast(query.data_ptr()),
+                reinterpret_cast(query_out.data_ptr()),
+                reinterpret_cast(query_scale.data_ptr()),
+                batch, q_len, q_heads, q_scale_groups);
+        } else {
+            quant_q_nhd_per_warp_kernel<__half, 128><<>>(
+                reinterpret_cast(query.data_ptr()),
+                reinterpret_cast(query_out.data_ptr()),
+                reinterpret_cast(query_scale.data_ptr()),
+                batch, q_len, q_heads, q_scale_groups);
+        }
+    } else {
+        if (head_dim == 16) {
+            quant_q_nhd_per_warp_kernel<__hip_bfloat16, 16><<>>(
+                reinterpret_cast(query.data_ptr()),
+                reinterpret_cast(query_out.data_ptr()),
+                reinterpret_cast(query_scale.data_ptr()),
+                batch, q_len, q_heads, q_scale_groups);
+        } else if (head_dim == 64) {
+            quant_q_nhd_per_warp_kernel<__hip_bfloat16, 64><<>>(
+                reinterpret_cast(query.data_ptr()),
+                reinterpret_cast(query_out.data_ptr()),
+                reinterpret_cast(query_scale.data_ptr()),
+                batch, q_len, q_heads, q_scale_groups);
+        } else {
+            quant_q_nhd_per_warp_kernel<__hip_bfloat16, 128><<>>(
+                reinterpret_cast(query.data_ptr()),
+                reinterpret_cast(query_out.data_ptr()),
+                reinterpret_cast(query_scale.data_ptr()),
+                batch, q_len, q_heads, q_scale_groups);
+        }
+    }
+    hip_kernel_launch_check();
+    return {query_out, query_scale};
+}
+
+std::vector quant_qk_int8_hnd_gfx12(Tensor query, Tensor key) {
+    STD_TORCH_CHECK(query.is_cuda() && key.is_cuda(), "gfx12 Q/K quantization expects CUDA/HIP tensors");
+    STD_TORCH_CHECK(query.dim() == 4 && key.dim() == 4, "gfx12 Q/K quantization expects [B, H, S, D]");
+    STD_TORCH_CHECK(query.is_contiguous() && key.is_contiguous(),
+                    "gfx12 Q/K quantization expects contiguous HND tensors");
+    STD_TORCH_CHECK(query.scalar_type() == key.scalar_type(),
+                    "gfx12 Q/K quantization expects matching Q/K dtypes");
+    STD_TORCH_CHECK(query.scalar_type() == ScalarType::Half || query.scalar_type() == ScalarType::BFloat16,
+                    "gfx12 Q/K quantization supports fp16/bf16 input");
+    STD_TORCH_CHECK(query.size(0) == key.size(0), "Q/K batch size mismatch");
+    STD_TORCH_CHECK(query.size(3) == key.size(3), "Q/K head_dim mismatch");
+    const int64_t batch = query.size(0);
+    const int64_t q_heads = query.size(1);
+    const int64_t q_len = query.size(2);
+    const int64_t kv_heads = key.size(1);
+    const int64_t kv_len = key.size(2);
+    const int64_t head_dim = query.size(3);
+    STD_TORCH_CHECK(head_dim == 16 || head_dim == 64 || head_dim == 128,
+                    "gfx12 native Q/K quantization supports head_dim 16, 64, or 128");
+    STD_TORCH_CHECK((q_len % 64) == 0 && (kv_len % 64) == 0,
+                    "gfx12 native Q/K quantization requires sequence lengths divisible by 64");
+
+    const int q_groups = static_cast((q_len + 31) / 32);
+    const int k_groups = static_cast((kv_len + 63) / 64);
+    Tensor query_out = new_empty_like(query, query.sizes(), ScalarType::Char);
+    Tensor key_out = new_empty_like(key, key.sizes(), ScalarType::Char);
+    Tensor query_scale = new_empty_like(query, {batch, q_heads, q_groups}, ScalarType::Float);
+    Tensor key_scale = new_empty_like(key, {batch, kv_heads, k_groups}, ScalarType::Float);
+
+    const dim3 block(256);
+    const dim3 grid(q_groups + k_groups, std::max(q_heads, kv_heads),
batch);
+    // When both Q and K are empty (zero batch, zero heads, or both sequence
+    // lengths zero) the grid above has a zero dimension, which is not a valid
+    // launch configuration. All four outputs are empty in that case, so return
+    // them without launching.
+    if (query.numel() == 0 && key.numel() == 0) {
+        return {query_out, query_scale, key_out, key_scale};
+    }
+    // One launch quantizes both tensors: grid x spans q_groups + k_groups and
+    // grid y spans the larger of the two head counts.
+    if (query.scalar_type() == ScalarType::Half) {
+        if (head_dim == 16) {
+            quant_qk_int8_hnd_kernel<__half, 16><<>>(
+                reinterpret_cast(query.data_ptr()),
+                reinterpret_cast(key.data_ptr()),
+                reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()),
+                reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()),
+                batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups);
+        } else if (head_dim == 64) {
+            quant_qk_int8_hnd_kernel<__half, 64><<>>(
+                reinterpret_cast(query.data_ptr()),
+                reinterpret_cast(key.data_ptr()),
+                reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()),
+                reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()),
+                batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups);
+        } else {
+            quant_qk_int8_hnd_kernel<__half, 128><<>>(
+                reinterpret_cast(query.data_ptr()),
+                reinterpret_cast(key.data_ptr()),
+                reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()),
+                reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()),
+                batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups);
+        }
+    } else {
+        if (head_dim == 16) {
+            quant_qk_int8_hnd_kernel<__hip_bfloat16, 16><<>>(
+                reinterpret_cast(query.data_ptr()),
+                reinterpret_cast(key.data_ptr()),
+                reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()),
+                reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()),
+                batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups);
+        } else if (head_dim == 64) {
+            quant_qk_int8_hnd_kernel<__hip_bfloat16, 64><<>>(
+                reinterpret_cast(query.data_ptr()),
+                reinterpret_cast(key.data_ptr()),
+                reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()),
+                reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()),
+                batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups);
+        } else {
+            quant_qk_int8_hnd_kernel<__hip_bfloat16, 128><<>>(
+                reinterpret_cast(query.data_ptr()),
+                reinterpret_cast(key.data_ptr()),
+                reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()),
+                reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()),
+                batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups);
+        }
+    }
+    hip_kernel_launch_check();
+    return {query_out, query_scale, key_out, key_scale};
+}
+
+static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl(
+    Tensor query,
+    Tensor key,
+    Tensor value,
+    Tensor output,
+    Tensor query_scale,
+    Tensor key_scale,
+    int tensor_layout,
+    int is_causal,
+    float sm_scale,
+    int64_t valid_kv_len,
+    const float* value_scale_ptr = nullptr);
+
+Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12(
+    Tensor query,
+    Tensor key,
+    Tensor value,
+    int64_t is_causal,
+    int64_t value_is_fp8,
+    int64_t use_raw_f16_value,
+    double sm_scale,
+    int64_t valid_kv_len) {
+    STD_TORCH_CHECK(query.is_cuda() && key.is_cuda() && value.is_cuda(),
+                    "native gfx12 prepare+attention expects CUDA/HIP tensors");
+    STD_TORCH_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4,
+                    "native gfx12 prepare+attention expects HND [B, H, S, D] tensors");
+    STD_TORCH_CHECK(query.is_contiguous() && key.is_contiguous() && value.is_contiguous(),
+                    "native gfx12 prepare+attention expects contiguous HND tensors");
+    STD_TORCH_CHECK(query.scalar_type() == key.scalar_type() && query.scalar_type() == value.scalar_type(),
+                    "native gfx12 prepare+attention expects matching input dtypes");
+    STD_TORCH_CHECK(query.scalar_type() == ScalarType::Half || query.scalar_type() == ScalarType::BFloat16,
+                    "native gfx12 prepare+attention supports fp16/bf16 input");
+
+    const int64_t head_dim = query.size(3);
+    STD_TORCH_CHECK(value_is_fp8 || head_dim == 16 || head_dim == 64,
+                    "native gfx12 fp16 value prepare+attention supports head_dim 16 or 64");
+    STD_TORCH_CHECK(value_is_fp8 || !use_raw_f16_value || query.scalar_type() == ScalarType::Half,
+                    "raw fp16 value path
requires fp16 input"); + const int64_t batch = query.size(0); + const int64_t q_heads = query.size(1); + const int64_t q_len = query.size(2); + const int64_t kv_heads = key.size(1); + const int64_t padded_kv_len = key.size(2); + const int64_t kv_len = valid_kv_len > 0 ? valid_kv_len : padded_kv_len; + STD_TORCH_CHECK(kv_len > 0 && kv_len <= padded_kv_len, + "valid_kv_len must be in (0, padded_kv_len]"); + STD_TORCH_CHECK(key.size(0) == batch && value.size(0) == batch, + "Q/K/V batch size mismatch"); + STD_TORCH_CHECK(key.size(3) == head_dim && value.size(3) == head_dim, + "Q/K/V head_dim mismatch"); + STD_TORCH_CHECK(value.size(2) == padded_kv_len && value.size(1) == kv_heads, + "K/V shape mismatch"); + STD_TORCH_CHECK((q_heads % kv_heads) == 0, "q_heads must be divisible by kv_heads"); + STD_TORCH_CHECK((q_len % 64) == 0 && (padded_kv_len % 64) == 0, + "native gfx12 prepare+attention requires sequence lengths divisible by 64"); + STD_TORCH_CHECK(!is_causal || q_len == padded_kv_len, + "native gfx12 causal prepare+attention requires q_len == kv_len"); + + const auto output_dtype = + (value_is_fp8 && query.scalar_type() == ScalarType::BFloat16) ? 
ScalarType::BFloat16 : ScalarType::Half; + Tensor output; + if (!value_is_fp8) { + output = new_empty_like(query, query.sizes(), output_dtype); + } + const bool prefer_prepared_f16_causal = + !value_is_fp8 && head_dim == 64 && is_causal && + query.scalar_type() == ScalarType::Half && q_len >= 4096; + const bool auto_f16_fused_q = + !value_is_fp8 && (head_dim == 16 || head_dim == 64) && + query.scalar_type() == ScalarType::Half && + (is_causal || q_len >= 2048 || (head_dim == 64 && q_len >= 1024)) && + q_len <= 8192 && + !prefer_prepared_f16_causal; + const bool auto_f16_raw_qk = + !value_is_fp8 && is_causal && head_dim == 16 && + query.scalar_type() == ScalarType::Half && q_len <= 2048; + if (!value_is_fp8 && is_causal && (head_dim == 16 || head_dim == 64) && + query.scalar_type() == ScalarType::Half && + auto_f16_raw_qk) { + int block_rows = q_len <= 64 ? 64 : 128; + STD_TORCH_CHECK((q_len % block_rows) == 0, + "native raw-QK fp16 path requires q_len to be a multiple of block rows"); + const dim3 block(block_rows); + const dim3 grid((q_len + block_rows - 1) / block_rows, q_heads, batch); + const bool use_f16_pv_accum = + auto_f16_raw_qk && q_len <= 1024; + constexpr bool use_f16_tvload = false; +#define SAGEATTN_LAUNCH_RAW_QK_F16_CAUSAL(BR_, TVLOAD_, PAD_, F16ACC_) \ + if (head_dim == 16) { \ + qk_int8_sv_f16_d64_native_2q_kernel<64, true, BR_, false, PAD_, true, TVLOAD_, F16ACC_, __half, true, __half, true, false, false, false, false, 16><<>>( \ + reinterpret_cast(query.data_ptr()), \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast<__half*>(output.data_ptr()), \ + nullptr, nullptr, \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(2), query.stride(1), \ + key.stride(0), key.stride(2), key.stride(1), \ + value.stride(0), value.stride(2), value.stride(1), \ + output.stride(0), output.stride(2), output.stride(1), \ + 0, 0, 0, 0, \ + kHND, sm_scale); \ + } else { \ + 
qk_int8_sv_f16_d64_native_2q_kernel<64, true, BR_, false, PAD_, true, TVLOAD_, F16ACC_, __half, true, __half, true><<>>( \ + reinterpret_cast(query.data_ptr()), \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast<__half*>(output.data_ptr()), \ + nullptr, nullptr, \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(2), query.stride(1), \ + key.stride(0), key.stride(2), key.stride(1), \ + value.stride(0), value.stride(2), value.stride(1), \ + output.stride(0), output.stride(2), output.stride(1), \ + 0, 0, 0, 0, \ + kHND, sm_scale); \ + } +#define SAGEATTN_DISPATCH_RAW_QK_F16_CAUSAL(BR_) \ + if (use_f16_tvload) { \ + if (use_f16_pv_accum) { \ + SAGEATTN_LAUNCH_RAW_QK_F16_CAUSAL(BR_, true, 4, true); \ + } else { \ + SAGEATTN_LAUNCH_RAW_QK_F16_CAUSAL(BR_, true, 4, false); \ + } \ + } else if (use_f16_pv_accum) { \ + SAGEATTN_LAUNCH_RAW_QK_F16_CAUSAL(BR_, false, SAGEATTN_GFX12_NATIVE_F16_TV_PAD, true); \ + } else { \ + SAGEATTN_LAUNCH_RAW_QK_F16_CAUSAL(BR_, false, SAGEATTN_GFX12_NATIVE_F16_TV_PAD, false); \ + } + if (block_rows == 64) { + SAGEATTN_DISPATCH_RAW_QK_F16_CAUSAL(64); + } else if (block_rows == 256) { + SAGEATTN_DISPATCH_RAW_QK_F16_CAUSAL(256); + } else if (block_rows == 512) { + SAGEATTN_DISPATCH_RAW_QK_F16_CAUSAL(512); + } else if (block_rows == 1024) { + SAGEATTN_DISPATCH_RAW_QK_F16_CAUSAL(1024); + } else { + SAGEATTN_DISPATCH_RAW_QK_F16_CAUSAL(128); + } +#undef SAGEATTN_DISPATCH_RAW_QK_F16_CAUSAL +#undef SAGEATTN_LAUNCH_RAW_QK_F16_CAUSAL + hip_kernel_launch_check(); + return output; + } + if (value_is_fp8) { + int block_rows = head_dim == 64 ? + select_fp8_d64_block_rows_gfx12(q_len, is_causal, true) : + (q_len <= 64 ? 
64 : 128); + int block_cols = 64; + if (head_dim == 64 && !is_causal && q_len == 1024) { + block_rows = 128; + } + if (head_dim == 16 && is_causal && q_len <= 1024) { + block_rows = 64; + } + const bool transpose_fp8_value = true; + const bool use_fused_q = + (head_dim == 16 || head_dim == 64 || head_dim == 128) && transpose_fp8_value && + query.scalar_type() == ScalarType::Half && output_dtype == ScalarType::Half && + ((head_dim == 16 && block_cols == 64 && is_causal && q_len == 1024) || + (head_dim == 64 && block_cols == 64 && is_causal && + q_len >= 1024) || + (head_dim == 128 && block_cols <= 64 && is_causal && q_len >= 1024) || + (head_dim == 128 && block_cols <= 64 && !is_causal && q_len <= 512)); + const bool use_fp8_kvlane = + use_fused_q && head_dim == 64 && block_cols == 64 && transpose_fp8_value && + true; + const bool auto_fp8_streamcols4 = + use_fused_q && head_dim == 64 && is_causal && block_cols == 64; + const bool use_fp8_streamcols4 = + use_fused_q && block_cols == 64 && auto_fp8_streamcols4; + const dim3 block(block_rows); + const dim3 grid((q_len + block_rows - 1) / block_rows, q_heads, batch); + std::vector prepared; + int8_t* fused_key_ptr = nullptr; + uint8_t* fused_value_ptr = nullptr; + float* fused_k_scale_ptr = nullptr; + int64_t fused_k_stride_b = 0; + int64_t fused_k_stride_n = 0; + int64_t fused_k_stride_h = 0; + int64_t fused_v_stride_b = 0; + int64_t fused_v_stride_n = 0; + int64_t fused_v_stride_h = 0; + int64_t fused_ks_stride_b = 0; + int64_t fused_ks_stride_h = 0; + if (head_dim == 16 && !use_fused_q && transpose_fp8_value) { + prepared = prepare_qkv_hnd_gfx12(query, key, value); + } else if (use_fused_q) { + if (use_fp8_kvlane) { + prepared = prepare_kv_hnd_packed_gfx12( + query, key, value); + } else { + prepared = prepare_kv_hnd_packed_gfx12(query, key, value); + } + } else if (head_dim == 128 && is_causal && q_len == 1024 && transpose_fp8_value) { + prepared = prepare_qkv_hnd_gfx12(query, key, value); + } else if 
(transpose_fp8_value) { + prepared = prepare_qkv_hnd_packed_gfx12(query, key, value); + } else { + prepared = prepare_qkv_hnd_packed_gfx12(query, key, value); + } + output = new_empty_like(query, query.sizes(), output_dtype); + if (use_fused_q) { + fused_key_ptr = reinterpret_cast(prepared[0].data_ptr()); + fused_value_ptr = reinterpret_cast(prepared[2].data_ptr()); + fused_k_scale_ptr = reinterpret_cast(prepared[1].data_ptr()); + fused_k_stride_b = prepared[0].stride(0); + fused_k_stride_n = prepared[0].stride(2); + fused_k_stride_h = prepared[0].stride(1); + fused_v_stride_b = prepared[2].stride(0); + fused_v_stride_n = prepared[2].stride(2); + fused_v_stride_h = prepared[2].stride(1); + fused_ks_stride_b = prepared[1].stride(0); + fused_ks_stride_h = prepared[1].stride(1); + } +#define SAGEATTN_LAUNCH_PREPARED_FP8_EX(BC_, HD_, BR_, VT_, CAUSAL_, OUT_T_, LOWP_) \ + qk_int8_sv_f8_native_2q_kernel<<>>( \ + reinterpret_cast(prepared[0].data_ptr()), reinterpret_cast(prepared[2].data_ptr()), \ + reinterpret_cast(prepared[4].data_ptr()), \ + reinterpret_cast(output.data_ptr()), \ + reinterpret_cast(prepared[1].data_ptr()), reinterpret_cast(prepared[3].data_ptr()), \ + nullptr, \ + batch, q_len, kv_len, q_heads, kv_heads, \ + prepared[0].stride(0), prepared[0].stride(2), prepared[0].stride(1), \ + prepared[2].stride(0), prepared[2].stride(2), prepared[2].stride(1), \ + prepared[4].stride(0), prepared[4].stride(2), prepared[4].stride(1), \ + output.stride(0), output.stride(2), output.stride(1), \ + prepared[1].stride(0), prepared[1].stride(1), \ + prepared[3].stride(0), prepared[3].stride(1), \ + kHND, sm_scale) +#define SAGEATTN_LAUNCH_PREPARED_FP8(BC_, HD_, BR_, VT_, CAUSAL_, OUT_T_) \ + SAGEATTN_LAUNCH_PREPARED_FP8_EX(BC_, HD_, BR_, VT_, CAUSAL_, OUT_T_, false) + +#define SAGEATTN_LAUNCH_FUSED_Q_FP8_IMPL_SLICE(BC_, HD_, BR_, CAUSAL_, KVLANE_, SC_, KLANE_, VLANE_, VBASE_, VTILES_) \ + qk_int8_sv_f8_native_2q_kernel<<>>( \ + reinterpret_cast(query.data_ptr()), 
fused_key_ptr, \ + fused_value_ptr, \ + reinterpret_cast<__half*>(output.data_ptr()), \ + nullptr, fused_k_scale_ptr, \ + nullptr, \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(2), query.stride(1), \ + fused_k_stride_b, fused_k_stride_n, fused_k_stride_h, \ + fused_v_stride_b, fused_v_stride_n, fused_v_stride_h, \ + output.stride(0), output.stride(2), output.stride(1), \ + 0, 0, \ + fused_ks_stride_b, fused_ks_stride_h, \ + kHND, sm_scale) +#define SAGEATTN_LAUNCH_FUSED_Q_FP8_IMPL_SC(BC_, HD_, BR_, CAUSAL_, KVLANE_, SC_) \ + SAGEATTN_LAUNCH_FUSED_Q_FP8_IMPL_SLICE(BC_, HD_, BR_, CAUSAL_, KVLANE_, SC_, false, false, 0, ((HD_) / 16)) +#define SAGEATTN_LAUNCH_FUSED_Q_FP8_IMPL(BC_, HD_, BR_, CAUSAL_, KVLANE_) \ + if constexpr ((BC_) == 64) { \ + if (use_fp8_streamcols4) { \ + SAGEATTN_LAUNCH_FUSED_Q_FP8_IMPL_SC(BC_, HD_, BR_, CAUSAL_, KVLANE_, 4); \ + } else { \ + SAGEATTN_LAUNCH_FUSED_Q_FP8_IMPL_SC(BC_, HD_, BR_, CAUSAL_, KVLANE_, 0); \ + } \ + } else { \ + SAGEATTN_LAUNCH_FUSED_Q_FP8_IMPL_SC(BC_, HD_, BR_, CAUSAL_, KVLANE_, 0); \ + } +#define SAGEATTN_LAUNCH_FUSED_Q_FP8(BC_, HD_, BR_, CAUSAL_) \ + if (use_fp8_kvlane) { \ + SAGEATTN_LAUNCH_FUSED_Q_FP8_IMPL(BC_, HD_, BR_, CAUSAL_, true); \ + } else { \ + SAGEATTN_LAUNCH_FUSED_Q_FP8_IMPL(BC_, HD_, BR_, CAUSAL_, false); \ + } + +#define SAGEATTN_DISPATCH_PREPARED_FP8_VT_BC(BC_, VT_, OUT_T_) \ + if (head_dim == 16) { \ + if (block_rows == 64) { \ + if (is_causal) { SAGEATTN_LAUNCH_PREPARED_FP8(BC_, 16, 64, VT_, true, OUT_T_); } \ + else { SAGEATTN_LAUNCH_PREPARED_FP8(BC_, 16, 64, VT_, false, OUT_T_); } \ + } else if (block_rows == 256) { \ + if (is_causal) { SAGEATTN_LAUNCH_PREPARED_FP8(BC_, 16, 256, VT_, true, OUT_T_); } \ + else { SAGEATTN_LAUNCH_PREPARED_FP8(BC_, 16, 256, VT_, false, OUT_T_); } \ + } else { \ + if (is_causal) { SAGEATTN_LAUNCH_PREPARED_FP8(BC_, 16, 128, VT_, true, OUT_T_); } \ + else { SAGEATTN_LAUNCH_PREPARED_FP8(BC_, 16, 128, VT_, false, OUT_T_); } \ + } \ + } else if 
(head_dim == 128) { \ + if (block_rows == 64) { \ + if (is_causal) { SAGEATTN_LAUNCH_PREPARED_FP8(BC_, 128, 64, true, true, OUT_T_); } \ + else { SAGEATTN_LAUNCH_PREPARED_FP8(BC_, 128, 64, true, false, OUT_T_); } \ + } else { \ + if (is_causal) { \ + if constexpr ((BC_) == 64) { \ + if (q_len == 1024) { \ + SAGEATTN_LAUNCH_PREPARED_FP8_EX(BC_, 128, 128, true, true, OUT_T_, true); \ + } else { \ + SAGEATTN_LAUNCH_PREPARED_FP8(BC_, 128, 128, true, true, OUT_T_); \ + } \ + } else { \ + SAGEATTN_LAUNCH_PREPARED_FP8(BC_, 128, 128, true, true, OUT_T_); \ + } \ + } \ + else { SAGEATTN_LAUNCH_PREPARED_FP8(BC_, 128, 128, true, false, OUT_T_); } \ + } \ + } else if (block_rows == 256) { \ + if (is_causal) { SAGEATTN_LAUNCH_PREPARED_FP8(BC_, 64, 256, VT_, true, OUT_T_); } \ + else { SAGEATTN_LAUNCH_PREPARED_FP8(BC_, 64, 256, VT_, false, OUT_T_); } \ + } else if (block_rows == 64) { \ + if (is_causal) { SAGEATTN_LAUNCH_PREPARED_FP8(BC_, 64, 64, VT_, true, OUT_T_); } \ + else { SAGEATTN_LAUNCH_PREPARED_FP8(BC_, 64, 64, VT_, false, OUT_T_); } \ + } else { \ + if (is_causal) { SAGEATTN_LAUNCH_PREPARED_FP8(BC_, 64, 128, VT_, true, OUT_T_); } \ + else { SAGEATTN_LAUNCH_PREPARED_FP8(BC_, 64, 128, VT_, false, OUT_T_); } \ + } + +#define SAGEATTN_DISPATCH_PREPARED_FP8_TV(OUT_T_) \ + SAGEATTN_DISPATCH_PREPARED_FP8_VT_BC(64, true, OUT_T_) + + if (use_fused_q) { +#define SAGEATTN_DISPATCH_FUSED_Q_FP8_BC_HD_CAUSAL(BC_, HD_, CAUSAL_) \ + if (block_rows == 64) { SAGEATTN_LAUNCH_FUSED_Q_FP8(BC_, HD_, 64, CAUSAL_); } \ + else if (block_rows == 256) { SAGEATTN_LAUNCH_FUSED_Q_FP8(BC_, HD_, 256, CAUSAL_); } \ + else { SAGEATTN_LAUNCH_FUSED_Q_FP8(BC_, HD_, 128, CAUSAL_); } +#define SAGEATTN_DISPATCH_FUSED_Q_FP8_BC_HD(BC_, HD_) \ + if (is_causal) { SAGEATTN_DISPATCH_FUSED_Q_FP8_BC_HD_CAUSAL(BC_, HD_, true); } \ + else { SAGEATTN_DISPATCH_FUSED_Q_FP8_BC_HD_CAUSAL(BC_, HD_, false); } + if (head_dim == 16) { + SAGEATTN_DISPATCH_FUSED_Q_FP8_BC_HD(64, 16); + } else if (head_dim == 128) { + 
SAGEATTN_DISPATCH_FUSED_Q_FP8_BC_HD(64, 128); + } else { + SAGEATTN_DISPATCH_FUSED_Q_FP8_BC_HD(64, 64); + } +#undef SAGEATTN_DISPATCH_FUSED_Q_FP8_BC_HD +#undef SAGEATTN_DISPATCH_FUSED_Q_FP8_BC_HD_CAUSAL + } else if (output.scalar_type() == ScalarType::BFloat16) { + SAGEATTN_DISPATCH_PREPARED_FP8_TV(__hip_bfloat16); + } else { + SAGEATTN_DISPATCH_PREPARED_FP8_TV(__half); + } +#undef SAGEATTN_DISPATCH_PREPARED_FP8_TV +#undef SAGEATTN_DISPATCH_PREPARED_FP8_VT_BC +#undef SAGEATTN_LAUNCH_FUSED_Q_FP8 +#undef SAGEATTN_LAUNCH_FUSED_Q_FP8_IMPL +#undef SAGEATTN_LAUNCH_FUSED_Q_FP8_IMPL_SC +#undef SAGEATTN_LAUNCH_FUSED_Q_FP8_IMPL_SLICE +#undef SAGEATTN_LAUNCH_PREPARED_FP8 +#undef SAGEATTN_LAUNCH_PREPARED_FP8_EX + } else if (auto_f16_fused_q) { + const bool auto_f16_1q_short = + head_dim == 64 && is_causal && q_len == 1024; + const bool use_f16_fused_q_1q_tv = + head_dim == 64 && auto_f16_1q_short; + const bool use_f16_fused_q_1q = + head_dim == 64 && auto_f16_1q_short; + const bool use_f16_raw_value = + head_dim == 64 && + (use_raw_f16_value || + (use_f16_fused_q_1q && !use_f16_fused_q_1q_tv)); + int block_rows = q_len <= 64 ? 64 : 128; + if ((q_len % 256) == 0) { + if ((!is_causal && q_len >= 1024) || q_len <= 512 || + q_len >= 8192) { + block_rows = 256; + } + } + if (head_dim == 64 && is_causal && !use_f16_fused_q_1q && + q_len >= 2048 && (q_len % 256) == 0) { + block_rows = 256; + } + int block_cols = 64; + if (use_f16_fused_q_1q && block_rows != 128) { + block_rows = 64; + } + const dim3 block(use_f16_fused_q_1q ? 
(block_rows / 16) * 32 : block_rows); + const bool use_f16_flat_q_schedule = + head_dim == 64 && is_causal && !use_f16_fused_q_1q && + q_len >= 2048; + const int64_t q_blocks = (q_len + block_rows - 1) / block_rows; + const dim3 grid(q_blocks, q_heads, batch); + const dim3 grid_f16_flat(q_blocks * q_heads * batch); + constexpr bool use_f16_pv_accum = true; + const bool use_f16_pv_ordered_qk = + use_f16_pv_accum && !use_f16_raw_value && + q_len >= 1024; + const bool auto_f16_lane_qk = + head_dim == 64 && is_causal && (q_len == 2048 || q_len == 4096); + const bool use_f16_vlane = + (((q_len >= 1024 && q_len <= 2048) || q_len >= 8192) || + auto_f16_lane_qk); + const bool use_f16_streamk = + head_dim == 64 && is_causal && q_len == 2048 && block_rows == 256; + const bool use_f16_klane = + !use_f16_raw_value && block_cols == 64 && + use_f16_pv_ordered_qk && + auto_f16_lane_qk; + std::vector prepared = use_f16_raw_value ? + prepare_k_hnd_packed_gfx12(key) : + prepare_kv_hnd_packed_gfx12<__half, false>(query, key, value); +#define SAGEATTN_LAUNCH_F16_FUSED_Q_TV_CAUSAL(BC_, BR_, PAD_, F16ACC_, PVORDER_, VLANE_, STREAM_, KLANE_) \ + if (head_dim == 16) { \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + reinterpret_cast(query.data_ptr()), reinterpret_cast(prepared[0].data_ptr()), \ + reinterpret_cast(prepared[2].data_ptr()), \ + reinterpret_cast<__half*>(output.data_ptr()), \ + nullptr, reinterpret_cast(prepared[1].data_ptr()), \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(2), query.stride(1), \ + prepared[0].stride(0), prepared[0].stride(2), prepared[0].stride(1), \ + prepared[2].stride(0), prepared[2].stride(2), prepared[2].stride(1), \ + output.stride(0), output.stride(2), output.stride(1), \ + 0, 0, \ + prepared[1].stride(0), prepared[1].stride(1), \ + kHND, sm_scale); \ + } else { \ + if (use_f16_flat_q_schedule) { \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + reinterpret_cast(query.data_ptr()), 
reinterpret_cast(prepared[0].data_ptr()), \ + reinterpret_cast(prepared[2].data_ptr()), \ + reinterpret_cast<__half*>(output.data_ptr()), \ + nullptr, reinterpret_cast(prepared[1].data_ptr()), \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(2), query.stride(1), \ + prepared[0].stride(0), prepared[0].stride(2), prepared[0].stride(1), \ + prepared[2].stride(0), prepared[2].stride(2), prepared[2].stride(1), \ + output.stride(0), output.stride(2), output.stride(1), \ + 0, 0, \ + prepared[1].stride(0), prepared[1].stride(1), \ + kHND, sm_scale); \ + } else { \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + reinterpret_cast(query.data_ptr()), reinterpret_cast(prepared[0].data_ptr()), \ + reinterpret_cast(prepared[2].data_ptr()), \ + reinterpret_cast<__half*>(output.data_ptr()), \ + nullptr, reinterpret_cast(prepared[1].data_ptr()), \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(2), query.stride(1), \ + prepared[0].stride(0), prepared[0].stride(2), prepared[0].stride(1), \ + prepared[2].stride(0), prepared[2].stride(2), prepared[2].stride(1), \ + output.stride(0), output.stride(2), output.stride(1), \ + 0, 0, \ + prepared[1].stride(0), prepared[1].stride(1), \ + kHND, sm_scale); \ + } \ + } +#define SAGEATTN_LAUNCH_F16_FUSED_Q_TV_NONCAUSAL(BR_, PAD_, F16ACC_, PVORDER_, VLANE_, KLANE_) \ + if (head_dim == 16) { \ + qk_int8_sv_f16_d64_native_2q_kernel<64, true, BR_, true, PAD_, false, false, F16ACC_, __half, true, int8_t, false, PVORDER_, false, false, false, 16><<>>( \ + reinterpret_cast(query.data_ptr()), reinterpret_cast(prepared[0].data_ptr()), \ + reinterpret_cast(prepared[2].data_ptr()), \ + reinterpret_cast<__half*>(output.data_ptr()), \ + nullptr, reinterpret_cast(prepared[1].data_ptr()), \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(2), query.stride(1), \ + prepared[0].stride(0), prepared[0].stride(2), prepared[0].stride(1), \ + prepared[2].stride(0), 
prepared[2].stride(2), prepared[2].stride(1), \ + output.stride(0), output.stride(2), output.stride(1), \ + 0, 0, \ + prepared[1].stride(0), prepared[1].stride(1), \ + kHND, sm_scale); \ + } else { \ + qk_int8_sv_f16_d64_native_2q_kernel<64, true, BR_, true, PAD_, false, false, F16ACC_, __half, true, int8_t, false, PVORDER_, VLANE_, false, KLANE_><<>>( \ + reinterpret_cast(query.data_ptr()), reinterpret_cast(prepared[0].data_ptr()), \ + reinterpret_cast(prepared[2].data_ptr()), \ + reinterpret_cast<__half*>(output.data_ptr()), \ + nullptr, reinterpret_cast(prepared[1].data_ptr()), \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(2), query.stride(1), \ + prepared[0].stride(0), prepared[0].stride(2), prepared[0].stride(1), \ + prepared[2].stride(0), prepared[2].stride(2), prepared[2].stride(1), \ + output.stride(0), output.stride(2), output.stride(1), \ + 0, 0, \ + prepared[1].stride(0), prepared[1].stride(1), \ + kHND, sm_scale); \ + } +#define SAGEATTN_LAUNCH_F16_FUSED_Q_RAWV_CAUSAL(BC_, BR_, PAD_, F16ACC_, VLANE_, STREAM_) \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + reinterpret_cast(query.data_ptr()), reinterpret_cast(prepared[0].data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast<__half*>(output.data_ptr()), \ + nullptr, reinterpret_cast(prepared[1].data_ptr()), \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(2), query.stride(1), \ + prepared[0].stride(0), prepared[0].stride(2), prepared[0].stride(1), \ + value.stride(0), value.stride(2), value.stride(1), \ + output.stride(0), output.stride(2), output.stride(1), \ + 0, 0, \ + prepared[1].stride(0), prepared[1].stride(1), \ + kHND, sm_scale) +#define SAGEATTN_LAUNCH_F16_FUSED_Q_1Q_RAWV_CAUSAL(BR_, F16ACC_, SPLIT_) \ + qk_int8_sv_f16_d64_native_kernel<64, BR_, true, false, 4, true, true, F16ACC_, true, __half, true, SPLIT_><<>>( \ + reinterpret_cast(query.data_ptr()), reinterpret_cast(prepared[0].data_ptr()), \ + 
reinterpret_cast(value.data_ptr()), \ + reinterpret_cast<__half*>(output.data_ptr()), \ + nullptr, reinterpret_cast(prepared[1].data_ptr()), \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(2), query.stride(1), \ + prepared[0].stride(0), prepared[0].stride(2), prepared[0].stride(1), \ + value.stride(0), value.stride(2), value.stride(1), \ + output.stride(0), output.stride(2), output.stride(1), \ + 0, 0, \ + prepared[1].stride(0), prepared[1].stride(1), \ + kHND, sm_scale) +#define SAGEATTN_LAUNCH_F16_FUSED_Q_1Q_TV_CAUSAL(BR_, F16ACC_) \ + qk_int8_sv_f16_d64_native_kernel<64, BR_, true, true, 4, true, false, F16ACC_, true, __half, true><<>>( \ + reinterpret_cast(query.data_ptr()), reinterpret_cast(prepared[0].data_ptr()), \ + reinterpret_cast(prepared[2].data_ptr()), \ + reinterpret_cast<__half*>(output.data_ptr()), \ + nullptr, reinterpret_cast(prepared[1].data_ptr()), \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(2), query.stride(1), \ + prepared[0].stride(0), prepared[0].stride(2), prepared[0].stride(1), \ + prepared[2].stride(0), prepared[2].stride(2), prepared[2].stride(1), \ + output.stride(0), output.stride(2), output.stride(1), \ + 0, 0, \ + prepared[1].stride(0), prepared[1].stride(1), \ + kHND, sm_scale) +#define SAGEATTN_DISPATCH_F16_FUSED_Q_TV_CAUSAL(BR_, PAD_) \ + if (use_f16_raw_value) { \ + if (use_f16_vlane && use_f16_streamk) { SAGEATTN_LAUNCH_F16_FUSED_Q_RAWV_CAUSAL(64, BR_, PAD_, true, true, true); } \ + else if (use_f16_vlane) { SAGEATTN_LAUNCH_F16_FUSED_Q_RAWV_CAUSAL(64, BR_, PAD_, true, true, false); } \ + else if (use_f16_streamk) { SAGEATTN_LAUNCH_F16_FUSED_Q_RAWV_CAUSAL(64, BR_, PAD_, true, false, true); } \ + else { SAGEATTN_LAUNCH_F16_FUSED_Q_RAWV_CAUSAL(64, BR_, PAD_, true, false, false); } \ + } else if (use_f16_pv_ordered_qk) { \ + if (use_f16_klane && use_f16_vlane && use_f16_streamk) { SAGEATTN_LAUNCH_F16_FUSED_Q_TV_CAUSAL(64, BR_, PAD_, true, true, true, true, true); } \ 
+ else if (use_f16_klane && use_f16_vlane) { SAGEATTN_LAUNCH_F16_FUSED_Q_TV_CAUSAL(64, BR_, PAD_, true, true, true, false, true); } \ + else if (use_f16_klane && use_f16_streamk) { SAGEATTN_LAUNCH_F16_FUSED_Q_TV_CAUSAL(64, BR_, PAD_, true, true, false, true, true); } \ + else if (use_f16_klane) { SAGEATTN_LAUNCH_F16_FUSED_Q_TV_CAUSAL(64, BR_, PAD_, true, true, false, false, true); } \ + else if (use_f16_vlane && use_f16_streamk) { SAGEATTN_LAUNCH_F16_FUSED_Q_TV_CAUSAL(64, BR_, PAD_, true, true, true, true, false); } \ + else if (use_f16_vlane) { SAGEATTN_LAUNCH_F16_FUSED_Q_TV_CAUSAL(64, BR_, PAD_, true, true, true, false, false); } \ + else if (use_f16_streamk) { SAGEATTN_LAUNCH_F16_FUSED_Q_TV_CAUSAL(64, BR_, PAD_, true, true, false, true, false); } \ + else { SAGEATTN_LAUNCH_F16_FUSED_Q_TV_CAUSAL(64, BR_, PAD_, true, true, false, false, false); } \ + } else { \ + if (use_f16_vlane && use_f16_streamk) { SAGEATTN_LAUNCH_F16_FUSED_Q_TV_CAUSAL(64, BR_, PAD_, true, false, true, true, false); } \ + else if (use_f16_vlane) { SAGEATTN_LAUNCH_F16_FUSED_Q_TV_CAUSAL(64, BR_, PAD_, true, false, true, false, false); } \ + else if (use_f16_streamk) { SAGEATTN_LAUNCH_F16_FUSED_Q_TV_CAUSAL(64, BR_, PAD_, true, false, false, true, false); } \ + else { SAGEATTN_LAUNCH_F16_FUSED_Q_TV_CAUSAL(64, BR_, PAD_, true, false, false, false, false); } \ + } + if (!is_causal) { + STD_TORCH_CHECK(block_cols == 64, + "native fp16 fused-Q non-causal path currently supports BC64"); + STD_TORCH_CHECK(!use_f16_raw_value, + "native fp16 fused-Q non-causal path requires transposed prepared values"); + if (use_f16_pv_ordered_qk) { + if (block_rows == 64) { SAGEATTN_LAUNCH_F16_FUSED_Q_TV_NONCAUSAL(64, 4, true, true, false, false); } + else if (block_rows == 256) { SAGEATTN_LAUNCH_F16_FUSED_Q_TV_NONCAUSAL(256, 4, true, true, false, false); } + else { SAGEATTN_LAUNCH_F16_FUSED_Q_TV_NONCAUSAL(128, 4, true, true, false, false); } + } else { + if (block_rows == 64) { 
SAGEATTN_LAUNCH_F16_FUSED_Q_TV_NONCAUSAL(64, 4, true, false, false, false); } + else if (block_rows == 256) { SAGEATTN_LAUNCH_F16_FUSED_Q_TV_NONCAUSAL(256, 4, true, false, false, false); } + else { SAGEATTN_LAUNCH_F16_FUSED_Q_TV_NONCAUSAL(128, 4, true, false, false, false); } + } + } else if (use_f16_fused_q_1q) { + if (use_f16_fused_q_1q_tv && block_rows == 128) { + SAGEATTN_LAUNCH_F16_FUSED_Q_1Q_TV_CAUSAL(128, true); + } else if (use_f16_fused_q_1q_tv) { + SAGEATTN_LAUNCH_F16_FUSED_Q_1Q_TV_CAUSAL(64, true); + } else if (block_rows == 128) { + SAGEATTN_LAUNCH_F16_FUSED_Q_1Q_RAWV_CAUSAL(128, true, false); + } else { + SAGEATTN_LAUNCH_F16_FUSED_Q_1Q_RAWV_CAUSAL(64, true, false); + } + } else if (block_rows == 64) { + SAGEATTN_DISPATCH_F16_FUSED_Q_TV_CAUSAL(64, 4); + } else if (block_rows == 256) { + SAGEATTN_DISPATCH_F16_FUSED_Q_TV_CAUSAL(256, 4); + } else if (block_rows == 512) { + SAGEATTN_DISPATCH_F16_FUSED_Q_TV_CAUSAL(512, 4); + } else if (block_rows == 1024) { + SAGEATTN_DISPATCH_F16_FUSED_Q_TV_CAUSAL(1024, 4); + } else if (q_len >= 8192) { + SAGEATTN_DISPATCH_F16_FUSED_Q_TV_CAUSAL(128, 8); + } else if (q_len >= 1024) { + SAGEATTN_DISPATCH_F16_FUSED_Q_TV_CAUSAL(128, 4); + } else { + SAGEATTN_DISPATCH_F16_FUSED_Q_TV_CAUSAL(128, 16); + } +#undef SAGEATTN_DISPATCH_F16_FUSED_Q_TV_CAUSAL +#undef SAGEATTN_LAUNCH_F16_FUSED_Q_1Q_TV_CAUSAL +#undef SAGEATTN_LAUNCH_F16_FUSED_Q_1Q_RAWV_CAUSAL +#undef SAGEATTN_LAUNCH_F16_FUSED_Q_RAWV_CAUSAL +#undef SAGEATTN_LAUNCH_F16_FUSED_Q_TV_NONCAUSAL +#undef SAGEATTN_LAUNCH_F16_FUSED_Q_TV_CAUSAL + hip_kernel_launch_check(); + } else if (use_raw_f16_value) { + std::vector prepared = quant_qk_int8_hnd_gfx12(query, key); + qk_int8_sv_f16_d64_native_attn_gfx12_impl( + prepared[0], prepared[2], value, output, prepared[1], prepared[3], + kHND, is_causal, sm_scale, kv_len); + } else { + const bool use_f16_separate_prepared = + is_causal && head_dim == 64 && q_len == 4096 && + query.scalar_type() == ScalarType::Half; + std::vector prepared = + 
use_f16_separate_prepared ? + prepare_qkv_hnd_gfx12<__half, false>(query, key, value) : + prepare_qkv_hnd_packed_gfx12<__half, false>(query, key, value); + qk_int8_sv_f16_d64_native_attn_gfx12_impl( + prepared[0], prepared[2], prepared[4], output, prepared[1], prepared[3], + kHND, is_causal, sm_scale, kv_len); + } + return output; +} + +static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( + Tensor query, + Tensor key, + Tensor value, + Tensor output, + Tensor query_scale, + Tensor key_scale, + int tensor_layout, + int is_causal, + float sm_scale, + int64_t valid_kv_len, + const float* value_scale_ptr) { + STD_TORCH_CHECK(query.is_cuda() && key.is_cuda() && value.is_cuda() && output.is_cuda(), + "native gfx12 tensors must be CUDA/HIP tensors"); + STD_TORCH_CHECK(query.scalar_type() == ScalarType::Char, "query must be int8"); + STD_TORCH_CHECK(key.scalar_type() == ScalarType::Char, "key must be int8"); + const bool value_is_fp8 = value.scalar_type() == ScalarType::Byte; + STD_TORCH_CHECK(value.scalar_type() == ScalarType::Half || value_is_fp8, + "value must be fp16 or raw OCP e4m3 fp8 bytes"); + STD_TORCH_CHECK(value_scale_ptr == nullptr || value_is_fp8, + "value_scale is only valid for the fp8 value path"); + const bool output_is_bf16 = output.scalar_type() == ScalarType::BFloat16; + STD_TORCH_CHECK(output.scalar_type() == ScalarType::Half || (value_is_fp8 && output_is_bf16), + "output must be fp16, or bf16 for the fp8 value path"); + STD_TORCH_CHECK(query_scale.scalar_type() == ScalarType::Float, "query_scale must be fp32"); + STD_TORCH_CHECK(key_scale.scalar_type() == ScalarType::Float, "key_scale must be fp32"); + STD_TORCH_CHECK(tensor_layout == kHND || tensor_layout == kNHD, "invalid tensor_layout"); + const int64_t head_dim = query.size(-1); + const bool value_maybe_transposed_hnd = + tensor_layout == kHND && value.dim() == 4 && value.size(2) == head_dim; + STD_TORCH_CHECK(key.size(-1) == head_dim && + (value.size(-1) == head_dim || 
value_maybe_transposed_hnd), + "query, key, and value must have matching head_dim"); + STD_TORCH_CHECK(head_dim == 16 || head_dim == 64 || (value_is_fp8 && head_dim == 128), + "native gfx12 path supports D16/D64, plus D128 for the fp8 2q path"); + + const int64_t batch = query.size(0); + const int64_t q_heads = tensor_layout == kNHD ? query.size(2) : query.size(1); + const int64_t q_len = tensor_layout == kNHD ? query.size(1) : query.size(2); + const int64_t kv_heads = tensor_layout == kNHD ? key.size(2) : key.size(1); + const int64_t padded_kv_len = tensor_layout == kNHD ? key.size(1) : key.size(2); + const int64_t kv_len = valid_kv_len > 0 ? valid_kv_len : padded_kv_len; + STD_TORCH_CHECK(kv_len > 0 && kv_len <= padded_kv_len, + "valid_kv_len must be in (0, padded_kv_len]"); + const bool value_transposed_hnd = value_maybe_transposed_hnd && value.size(3) >= padded_kv_len; + STD_TORCH_CHECK(!value_maybe_transposed_hnd || value_transposed_hnd, + "transposed HND value must have shape [B, H, D, padded_kv_len]"); + STD_TORCH_CHECK(!value_transposed_hnd || value.is_contiguous(), + "transposed HND value must be contiguous"); + STD_TORCH_CHECK((q_len % 64) == 0 && (padded_kv_len % 64) == 0, + "native gfx12 path requires q_len and kv_len multiples of 64"); + STD_TORCH_CHECK(!is_causal || q_len == padded_kv_len, + "native gfx12 causal path currently requires q_len == kv_len"); + STD_TORCH_CHECK((q_heads % kv_heads) == 0, "q_heads must be divisible by kv_heads"); + STD_TORCH_CHECK(query_scale.stride(-1) == 1 && key_scale.stride(-1) == 1, + "scale tensors must have contiguous scale columns"); + const bool hnd_contiguous = tensor_layout == kHND && + query.is_contiguous() && key.is_contiguous() && + value.is_contiguous() && output.is_contiguous(); + + int block_cols = 64; + bool use_2q = !value_is_fp8; + bool use_f16_causal_1q = false; + bool use_fp8_2q = value_is_fp8; + if (!value_is_fp8 && !value_transposed_hnd && + head_dim == 64 && q_len <= 1024) { + use_2q = false; + } + 
if (is_causal) { + if (value_is_fp8) { + use_2q = false; + use_f16_causal_1q = false; + use_fp8_2q = true; + } else { + use_fp8_2q = false; + use_f16_causal_1q = false; + use_2q = true; + } + } + STD_TORCH_CHECK(!(value_transposed_hnd && !value_is_fp8 && !use_2q && !use_f16_causal_1q), + "transposed fp16 value path currently requires tqk1/tqk2/auto mode"); + int block_rows = q_len <= 64 ? 64 : 128; + if (use_2q) { + block_rows = q_len <= 64 ? 64 : 128; + } else if (use_fp8_2q && block_rows != 32 && block_rows != 64 && + block_rows != 256 && block_rows != 512) { + block_rows = q_len <= 64 ? 64 : 128; + } + if (is_causal && head_dim == 16 && (use_2q || use_fp8_2q)) { + block_rows = (use_fp8_2q && q_len <= 1024) ? 64 : 128; + } + if (use_f16_causal_1q) { + block_cols = 64; + block_rows = 64; + } + if (!is_causal && use_fp8_2q && head_dim == 64 && + (q_len >= 2048 || value_transposed_hnd) && (q_len % 256) == 0) { + block_rows = select_fp8_d64_block_rows_gfx12(q_len, is_causal, value_transposed_hnd); + } + if (is_causal && use_fp8_2q && head_dim == 64 && + (q_len % 256) == 0) { + block_rows = select_fp8_d64_block_rows_gfx12(q_len, is_causal, value_transposed_hnd); + } + if (is_causal && use_fp8_2q && head_dim == 128 && + value_transposed_hnd && q_len >= 2048 && q_len <= 4096) { + block_cols = 32; + } + if (!is_causal && use_2q && value_transposed_hnd && + q_len >= 2048 && (q_len % 256) == 0) { + block_rows = 256; + } + if (is_causal && use_2q && head_dim == 64 && value_transposed_hnd && + q_len >= 4096 && (q_len % 256) == 0) { + block_rows = 256; + } + STD_TORCH_CHECK(!(use_fp8_2q && block_rows == 64 && block_cols == 128), + "native fp8 2q BR64 is currently specialized for BC32/BC64"); + STD_TORCH_CHECK(!(use_fp8_2q && block_rows == 256 && block_cols != 64), + "native fp8 2q BR256 is currently specialized for BC64"); + STD_TORCH_CHECK(!(use_2q && value_transposed_hnd && block_cols != 64), + "native fp16 transposed value 2q path currently supports BC64"); + 
STD_TORCH_CHECK(!(use_2q && block_rows != 32 && block_rows != 64 && + block_rows != 128 && block_rows != 256 && block_rows != 512 && + block_rows != 1024), + "native fp16 2q path currently supports BR32/BR64/BR128/BR256/BR512/BR1024"); + STD_TORCH_CHECK(!(use_2q && !value_transposed_hnd && block_rows != 64 && block_rows != 128), + "native fp16 non-transposed 2q path currently supports BR64/BR128"); + STD_TORCH_CHECK(!(use_f16_causal_1q && + ((block_rows != 64 && block_rows != 128) || block_cols != 64)), + "native fp16 single-q causal path currently supports BR64/BR128/BC64"); + STD_TORCH_CHECK((q_len % block_rows) == 0, + "native gfx12 path requires q_len to be a multiple of the selected block rows"); + + const bool use_f16_flat_q_schedule = + head_dim == 64 && !value_is_fp8 && is_causal && use_2q && + value_transposed_hnd && q_len >= 2048; + const int64_t q_blocks = (q_len + block_rows - 1) / block_rows; + dim3 block((use_2q || use_fp8_2q) ? block_rows : (block_rows == 128 ? 256 : 128)); + dim3 grid(q_blocks, q_heads, batch); + dim3 grid_f16_flat(q_blocks * q_heads * batch); + const bool use_f16_tvload = + !value_is_fp8 && is_causal && hnd_contiguous && !value_transposed_hnd && + q_len >= 1024; + const bool use_f16_pv_accum = + !value_is_fp8 && is_causal && value_transposed_hnd && block_cols == 64 && + q_len >= 1024; + const bool use_f16_pv_ordered_qk = + use_f16_pv_accum && q_len >= 4096; + const bool use_f16_vlane = + head_dim == 64 && !value_is_fp8 && is_causal && value_transposed_hnd && + q_len == 4096 && block_rows == 256; + const bool use_f16_streamk = + head_dim == 64 && !value_is_fp8 && is_causal && value_transposed_hnd && + q_len == 4096 && block_rows == 256; +#define SAGEATTN_LAUNCH_FP8_2Q_OUT(BC_, HD_, HND_, BR_, OUT_T_) \ + if (is_causal) { \ + qk_int8_sv_f8_native_2q_kernel<<>>( \ + reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(output.data_ptr()), \ + 
reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), \ + value_scale_ptr, \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), \ + key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), \ + value.stride(0), value.stride(tensor_layout == kNHD ? 1 : 2), value.stride(tensor_layout == kNHD ? 2 : 1), \ + output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 2 : 1), \ + query_scale.stride(0), query_scale.stride(1), \ + key_scale.stride(0), key_scale.stride(1), \ + tensor_layout, sm_scale); \ + } else { \ + qk_int8_sv_f8_native_2q_kernel<<>>( \ + reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(output.data_ptr()), \ + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), \ + value_scale_ptr, \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), \ + key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), \ + value.stride(0), value.stride(tensor_layout == kNHD ? 1 : 2), value.stride(tensor_layout == kNHD ? 2 : 1), \ + output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 
2 : 1), \ + query_scale.stride(0), query_scale.stride(1), \ + key_scale.stride(0), key_scale.stride(1), \ + tensor_layout, sm_scale); \ + } +#define SAGEATTN_LAUNCH_FP8_2Q(BC_, HD_, HND_, BR_) \ + if (output_is_bf16) { \ + SAGEATTN_LAUNCH_FP8_2Q_OUT(BC_, HD_, HND_, BR_, __hip_bfloat16); \ + } else { \ + SAGEATTN_LAUNCH_FP8_2Q_OUT(BC_, HD_, HND_, BR_, __half); \ + } +#define SAGEATTN_LAUNCH_FP8_2Q_TV_OUT(BC_, HD_, BR_, OUT_T_) \ + if (is_causal) { \ + qk_int8_sv_f8_native_2q_kernel<<>>( \ + reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(output.data_ptr()), \ + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), \ + value_scale_ptr, \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), \ + key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), \ + value.stride(0), value.stride(2), value.stride(1), \ + output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 2 : 1), \ + query_scale.stride(0), query_scale.stride(1), \ + key_scale.stride(0), key_scale.stride(1), \ + tensor_layout, sm_scale); \ + } else { \ + qk_int8_sv_f8_native_2q_kernel<<>>( \ + reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(output.data_ptr()), \ + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), \ + value_scale_ptr, \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), \ + key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), \ + value.stride(0), value.stride(2), value.stride(1), \ + output.stride(0), output.stride(tensor_layout == kNHD ? 
1 : 2), output.stride(tensor_layout == kNHD ? 2 : 1), \ + query_scale.stride(0), query_scale.stride(1), \ + key_scale.stride(0), key_scale.stride(1), \ + tensor_layout, sm_scale); \ + } +#define SAGEATTN_LAUNCH_FP8_2Q_TV(BC_, HD_, BR_) \ + if (output_is_bf16) { \ + SAGEATTN_LAUNCH_FP8_2Q_TV_OUT(BC_, HD_, BR_, __hip_bfloat16); \ + } else { \ + SAGEATTN_LAUNCH_FP8_2Q_TV_OUT(BC_, HD_, BR_, __half); \ + } +#define SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL_GRID(BC_, BR_, PAD_, F16ACC_, PVORDER_, VLANE_, STREAM_, KLANE_, GRID_, FLAT_) \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast<__half*>(output.data_ptr()), \ + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), \ + key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), \ + value.stride(0), value.stride(2), value.stride(1), \ + output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 
2 : 1), \ + query_scale.stride(0), query_scale.stride(1), \ + key_scale.stride(0), key_scale.stride(1), \ + tensor_layout, sm_scale) +#define SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL(BC_, BR_, PAD_, F16ACC_, PVORDER_, VLANE_, STREAM_, KLANE_) \ + if (use_f16_flat_q_schedule) { \ + SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL_GRID(BC_, BR_, PAD_, F16ACC_, PVORDER_, VLANE_, STREAM_, KLANE_, grid_f16_flat, true); \ + } else { \ + SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL_GRID(BC_, BR_, PAD_, F16ACC_, PVORDER_, VLANE_, STREAM_, KLANE_, grid, false); \ + } +#define SAGEATTN_LAUNCH_F16_D16_2Q_TV(BC_, BR_, PAD_, CAUSAL_, F16ACC_) \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast<__half*>(output.data_ptr()), \ + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), \ + key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), \ + value.stride(0), value.stride(2), value.stride(1), \ + output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 
2 : 1), \ + query_scale.stride(0), query_scale.stride(1), \ + key_scale.stride(0), key_scale.stride(1), \ + tensor_layout, sm_scale) +#define SAGEATTN_LAUNCH_F16_2Q_TV(BC_, BR_, PAD_) \ + if (is_causal) { \ + if (use_f16_pv_accum) { \ + if (use_f16_pv_ordered_qk) { \ + if (use_f16_vlane && (BC_) == 64 && use_f16_streamk) { SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL(BC_, BR_, PAD_, true, true, true, true, false); } \ + else if (use_f16_vlane && (BC_) == 64) { SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL(BC_, BR_, PAD_, true, true, true, false, false); } \ + else if (use_f16_streamk) { SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL(BC_, BR_, PAD_, true, true, false, true, false); } \ + else { SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL(BC_, BR_, PAD_, true, true, false, false, false); } \ + } else { \ + if (use_f16_vlane && (BC_) == 64 && use_f16_streamk) { SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL(BC_, BR_, PAD_, true, false, true, true, false); } \ + else if (use_f16_vlane && (BC_) == 64) { SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL(BC_, BR_, PAD_, true, false, true, false, false); } \ + else if (use_f16_streamk) { SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL(BC_, BR_, PAD_, true, false, false, true, false); } \ + else { SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL(BC_, BR_, PAD_, true, false, false, false, false); } \ + } \ + } else { \ + if (use_f16_vlane && (BC_) == 64 && use_f16_streamk) { SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL(BC_, BR_, PAD_, false, false, true, true, false); } \ + else if (use_f16_vlane && (BC_) == 64) { SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL(BC_, BR_, PAD_, false, false, true, false, false); } \ + else if (use_f16_streamk) { SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL(BC_, BR_, PAD_, false, false, false, true, false); } \ + else { SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL(BC_, BR_, PAD_, false, false, false, false, false); } \ + } \ + } else { \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast<__half*>(output.data_ptr()), \ + 
reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), \ + key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), \ + value.stride(0), value.stride(2), value.stride(1), \ + output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 2 : 1), \ + query_scale.stride(0), query_scale.stride(1), \ + key_scale.stride(0), key_scale.stride(1), \ + tensor_layout, sm_scale); \ + } +#define SAGEATTN_LAUNCH_F16_2Q(BC_, HND_, BR_) \ + if (is_causal) { \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast<__half*>(output.data_ptr()), \ + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), \ + key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), \ + value.stride(0), value.stride(tensor_layout == kNHD ? 1 : 2), value.stride(tensor_layout == kNHD ? 2 : 1), \ + output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 
2 : 1), \ + query_scale.stride(0), query_scale.stride(1), \ + key_scale.stride(0), key_scale.stride(1), \ + tensor_layout, sm_scale); \ + } else { \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast<__half*>(output.data_ptr()), \ + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), \ + key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), \ + value.stride(0), value.stride(tensor_layout == kNHD ? 1 : 2), value.stride(tensor_layout == kNHD ? 2 : 1), \ + output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 2 : 1), \ + query_scale.stride(0), query_scale.stride(1), \ + key_scale.stride(0), key_scale.stride(1), \ + tensor_layout, sm_scale); \ + } +#define SAGEATTN_LAUNCH_F16_2Q_TVLOAD_CAUSAL(BC_, BR_, PAD_, F16ACC_) \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast<__half*>(output.data_ptr()), \ + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), \ + key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), \ + value.stride(0), value.stride(tensor_layout == kNHD ? 1 : 2), value.stride(tensor_layout == kNHD ? 2 : 1), \ + output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 
2 : 1), \ + query_scale.stride(0), query_scale.stride(1), \ + key_scale.stride(0), key_scale.stride(1), \ + tensor_layout, sm_scale) +#define SAGEATTN_LAUNCH_F16_2Q_TVLOAD(BC_, BR_, PAD_) \ + if (is_causal) { \ + SAGEATTN_LAUNCH_F16_2Q_TVLOAD_CAUSAL(BC_, BR_, PAD_, false); \ + } else { \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast<__half*>(output.data_ptr()), \ + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), \ + key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), \ + value.stride(0), value.stride(tensor_layout == kNHD ? 1 : 2), value.stride(tensor_layout == kNHD ? 2 : 1), \ + output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 2 : 1), \ + query_scale.stride(0), query_scale.stride(1), \ + key_scale.stride(0), key_scale.stride(1), \ + tensor_layout, sm_scale); \ + } +#define SAGEATTN_LAUNCH_F16_1Q(BC_, BR_) \ + qk_int8_sv_f16_d64_native_kernel<<>>( \ + reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast<__half*>(output.data_ptr()), \ + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), \ + key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), \ + value.stride(0), value.stride(tensor_layout == kNHD ? 1 : 2), value.stride(tensor_layout == kNHD ? 2 : 1), \ + output.stride(0), output.stride(tensor_layout == kNHD ? 
1 : 2), output.stride(tensor_layout == kNHD ? 2 : 1), \ + query_scale.stride(0), query_scale.stride(1), \ + key_scale.stride(0), key_scale.stride(1), \ + tensor_layout, sm_scale) +#define SAGEATTN_LAUNCH_F16_1Q_CAUSAL(BR_, TRANSPOSED_, TVLOAD_, PAD_, F16ACC_) \ + qk_int8_sv_f16_d64_native_kernel<64, BR_, true, TRANSPOSED_, PAD_, true, TVLOAD_, F16ACC_><<>>( \ + reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast<__half*>(output.data_ptr()), \ + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), \ + key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), \ + value.stride(0), (TRANSPOSED_ ? value.stride(2) : value.stride(tensor_layout == kNHD ? 1 : 2)), \ + (TRANSPOSED_ ? value.stride(1) : value.stride(tensor_layout == kNHD ? 2 : 1)), \ + output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 
2 : 1), \ + query_scale.stride(0), query_scale.stride(1), \ + key_scale.stride(0), key_scale.stride(1), \ + tensor_layout, sm_scale) + if (use_f16_causal_1q) { + STD_TORCH_CHECK(hnd_contiguous, "fp16 single-q causal path requires contiguous HND tensors"); + const bool use_f16_1q_pv_accum = use_f16_pv_accum; +#define SAGEATTN_DISPATCH_F16_1Q_CAUSAL(BR_, TRANSPOSED_, TVLOAD_, PAD_) \ + if (use_f16_1q_pv_accum) { \ + SAGEATTN_LAUNCH_F16_1Q_CAUSAL(BR_, TRANSPOSED_, TVLOAD_, PAD_, true); \ + } else { \ + SAGEATTN_LAUNCH_F16_1Q_CAUSAL(BR_, TRANSPOSED_, TVLOAD_, PAD_, false); \ + } + if (value_transposed_hnd) { + if (q_len >= 4096) { + if (block_rows == 128) { SAGEATTN_DISPATCH_F16_1Q_CAUSAL(128, true, false, 8); } + else { SAGEATTN_DISPATCH_F16_1Q_CAUSAL(64, true, false, 8); } + } else if (q_len >= 1024) { + if (block_rows == 128) { SAGEATTN_DISPATCH_F16_1Q_CAUSAL(128, true, false, 4); } + else { SAGEATTN_DISPATCH_F16_1Q_CAUSAL(64, true, false, 4); } + } else { + if (block_rows == 128) { SAGEATTN_DISPATCH_F16_1Q_CAUSAL(128, true, false, 16); } + else { SAGEATTN_DISPATCH_F16_1Q_CAUSAL(64, true, false, 16); } + } + } else if (use_f16_tvload) { + if (block_rows == 128) { SAGEATTN_DISPATCH_F16_1Q_CAUSAL(128, false, true, 4); } + else { SAGEATTN_DISPATCH_F16_1Q_CAUSAL(64, false, true, 4); } + } else { + if (block_rows == 128) { SAGEATTN_DISPATCH_F16_1Q_CAUSAL(128, false, false, SAGEATTN_GFX12_NATIVE_F16_TV_PAD); } + else { SAGEATTN_DISPATCH_F16_1Q_CAUSAL(64, false, false, SAGEATTN_GFX12_NATIVE_F16_TV_PAD); } + } +#undef SAGEATTN_DISPATCH_F16_1Q_CAUSAL + } else if (use_fp8_2q && value_transposed_hnd) { + STD_TORCH_CHECK(hnd_contiguous, "transposed fp8 value path requires contiguous HND Q/K/O"); + STD_TORCH_CHECK(block_cols == 32 || block_cols == 64, + "transposed fp8 value path currently supports BC32/BC64"); + STD_TORCH_CHECK(!(block_rows == 256 && block_cols != 64), + "transposed fp8 value BR256 path currently supports BC64"); + if (head_dim == 16) { + if (block_cols == 32) 
{ + if (block_rows == 32) { + SAGEATTN_LAUNCH_FP8_2Q_TV(32, 16, 32); + } else if (block_rows == 64) { + SAGEATTN_LAUNCH_FP8_2Q_TV(32, 16, 64); + } else { + SAGEATTN_LAUNCH_FP8_2Q_TV(32, 16, 128); + } + } else if (block_rows == 32) { + SAGEATTN_LAUNCH_FP8_2Q_TV(64, 16, 32); + } else if (block_rows == 64) { + SAGEATTN_LAUNCH_FP8_2Q_TV(64, 16, 64); + } else if (block_rows == 256) { + SAGEATTN_LAUNCH_FP8_2Q_TV(64, 16, 256); + } else if (block_rows == 512) { + SAGEATTN_LAUNCH_FP8_2Q_TV(64, 16, 512); + } else { + SAGEATTN_LAUNCH_FP8_2Q_TV(64, 16, 128); + } + } else if (block_rows == 512 && block_cols == 32 && head_dim == 128) { + SAGEATTN_LAUNCH_FP8_2Q_TV(32, 128, 512); + } else if (block_rows == 512 && block_cols == 32) { + SAGEATTN_LAUNCH_FP8_2Q_TV(32, 64, 512); + } else if (block_rows == 512 && head_dim == 128) { + SAGEATTN_LAUNCH_FP8_2Q_TV(64, 128, 512); + } else if (block_rows == 512) { + SAGEATTN_LAUNCH_FP8_2Q_TV(64, 64, 512); + } else if (block_rows == 256 && head_dim == 128) { + SAGEATTN_LAUNCH_FP8_2Q_TV(64, 128, 256); + } else if (block_rows == 256) { + SAGEATTN_LAUNCH_FP8_2Q_TV(64, 64, 256); + } else if (block_cols == 32 && head_dim == 128) { + if (block_rows == 32) { + SAGEATTN_LAUNCH_FP8_2Q_TV(32, 128, 32); + } else if (block_rows == 64) { + SAGEATTN_LAUNCH_FP8_2Q_TV(32, 128, 64); + } else { + SAGEATTN_LAUNCH_FP8_2Q_TV(32, 128, 128); + } + } else if (block_cols == 32) { + if (block_rows == 32) { + SAGEATTN_LAUNCH_FP8_2Q_TV(32, 64, 32); + } else if (block_rows == 64) { + SAGEATTN_LAUNCH_FP8_2Q_TV(32, 64, 64); + } else { + SAGEATTN_LAUNCH_FP8_2Q_TV(32, 64, 128); + } + } else if (head_dim == 128) { + if (block_rows == 32) { + SAGEATTN_LAUNCH_FP8_2Q_TV(64, 128, 32); + } else if (block_rows == 64) { + SAGEATTN_LAUNCH_FP8_2Q_TV(64, 128, 64); + } else { + SAGEATTN_LAUNCH_FP8_2Q_TV(64, 128, 128); + } + } else { + if (block_rows == 32) { + SAGEATTN_LAUNCH_FP8_2Q_TV(64, 64, 32); + } else if (block_rows == 64) { + SAGEATTN_LAUNCH_FP8_2Q_TV(64, 64, 64); + } else { + 
SAGEATTN_LAUNCH_FP8_2Q_TV(64, 64, 128); + } + } + } else if (use_fp8_2q && block_rows == 64 && block_cols == 32 && head_dim == 128) { + if (hnd_contiguous) { + SAGEATTN_LAUNCH_FP8_2Q(32, 128, true, 64); + } else { + SAGEATTN_LAUNCH_FP8_2Q(32, 128, false, 64); + } + } else if (use_fp8_2q && block_rows == 64 && block_cols == 32) { + if (hnd_contiguous) { + SAGEATTN_LAUNCH_FP8_2Q(32, 64, true, 64); + } else { + SAGEATTN_LAUNCH_FP8_2Q(32, 64, false, 64); + } + } else if (use_fp8_2q && block_rows == 64 && head_dim == 128) { + if (hnd_contiguous) { + SAGEATTN_LAUNCH_FP8_2Q(64, 128, true, 64); + } else { + SAGEATTN_LAUNCH_FP8_2Q(64, 128, false, 64); + } + } else if (use_fp8_2q && block_rows == 64) { + if (hnd_contiguous) { + SAGEATTN_LAUNCH_FP8_2Q(64, 64, true, 64); + } else { + SAGEATTN_LAUNCH_FP8_2Q(64, 64, false, 64); + } + } else if (use_fp8_2q && block_rows == 256 && head_dim == 128) { + if (hnd_contiguous) { + SAGEATTN_LAUNCH_FP8_2Q(64, 128, true, 256); + } else { + SAGEATTN_LAUNCH_FP8_2Q(64, 128, false, 256); + } + } else if (use_fp8_2q && block_rows == 256) { + if (hnd_contiguous) { + SAGEATTN_LAUNCH_FP8_2Q(64, 64, true, 256); + } else { + SAGEATTN_LAUNCH_FP8_2Q(64, 64, false, 256); + } + } else if (use_fp8_2q && block_cols == 32 && head_dim == 128) { + if (hnd_contiguous) { + SAGEATTN_LAUNCH_FP8_2Q(32, 128, true, 128); + } else { + SAGEATTN_LAUNCH_FP8_2Q(32, 128, false, 128); + } + } else if (use_fp8_2q && block_cols == 32) { + if (hnd_contiguous) { + SAGEATTN_LAUNCH_FP8_2Q(32, 64, true, 128); + } else { + SAGEATTN_LAUNCH_FP8_2Q(32, 64, false, 128); + } + } else if (use_fp8_2q && block_cols == 128 && head_dim == 128) { + qk_int8_sv_f8_native_2q_kernel<128, 128, 0, 8><<>>( + reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), + reinterpret_cast(value.data_ptr()), + reinterpret_cast<__half*>(output.data_ptr()), + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), + value_scale_ptr, + batch, q_len, kv_len, q_heads, 
kv_heads, + query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), + key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), + value.stride(0), value.stride(tensor_layout == kNHD ? 1 : 2), value.stride(tensor_layout == kNHD ? 2 : 1), + output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 2 : 1), + query_scale.stride(0), query_scale.stride(1), + key_scale.stride(0), key_scale.stride(1), + tensor_layout, sm_scale); + } else if (use_fp8_2q && block_cols == 128) { + qk_int8_sv_f8_native_2q_kernel<128, 64, 0, 4><<>>( + reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), + reinterpret_cast(value.data_ptr()), + reinterpret_cast<__half*>(output.data_ptr()), + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), + value_scale_ptr, + batch, q_len, kv_len, q_heads, kv_heads, + query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), + key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), + value.stride(0), value.stride(tensor_layout == kNHD ? 1 : 2), value.stride(tensor_layout == kNHD ? 2 : 1), + output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 
2 : 1), + query_scale.stride(0), query_scale.stride(1), + key_scale.stride(0), key_scale.stride(1), + tensor_layout, sm_scale); + } else if (use_fp8_2q && head_dim == 128) { + if (hnd_contiguous) { + SAGEATTN_LAUNCH_FP8_2Q(64, 128, true, 128); + } else { + SAGEATTN_LAUNCH_FP8_2Q(64, 128, false, 128); + } + } else if (use_fp8_2q) { + if (hnd_contiguous) { + SAGEATTN_LAUNCH_FP8_2Q(64, 64, true, 128); + } else { + SAGEATTN_LAUNCH_FP8_2Q(64, 64, false, 128); + } + } else if (use_2q && value_transposed_hnd) { + STD_TORCH_CHECK(hnd_contiguous, "transposed fp16 value path requires contiguous HND Q/K/O"); + if (head_dim == 16) { + if (is_causal) { + if (block_rows == 32) { + SAGEATTN_LAUNCH_F16_D16_2Q_TV(64, 32, 4, true, true); + } else if (block_rows == 64) { + SAGEATTN_LAUNCH_F16_D16_2Q_TV(64, 64, 4, true, true); + } else if (block_rows == 256) { + SAGEATTN_LAUNCH_F16_D16_2Q_TV(64, 256, 4, true, true); + } else { + SAGEATTN_LAUNCH_F16_D16_2Q_TV(64, 128, 4, true, true); + } + } else if (block_rows == 32) { + SAGEATTN_LAUNCH_F16_D16_2Q_TV(64, 32, 4, false, true); + } else if (block_rows == 64) { + SAGEATTN_LAUNCH_F16_D16_2Q_TV(64, 64, 4, false, true); + } else if (block_rows == 256) { + SAGEATTN_LAUNCH_F16_D16_2Q_TV(64, 256, 4, false, true); + } else { + SAGEATTN_LAUNCH_F16_D16_2Q_TV(64, 128, 4, false, true); + } + } else if (block_rows == 32) { + SAGEATTN_LAUNCH_F16_2Q_TV(64, 32, 4); + } else if (block_rows == 64) { + SAGEATTN_LAUNCH_F16_2Q_TV(64, 64, 4); + } else if (block_rows == 256) { + SAGEATTN_LAUNCH_F16_2Q_TV(64, 256, 4); + } else if (block_rows == 512) { + SAGEATTN_LAUNCH_F16_2Q_TV(64, 512, 4); + } else if (block_rows == 1024) { + SAGEATTN_LAUNCH_F16_2Q_TV(64, 1024, 4); + } else if (q_len >= 8192) { + SAGEATTN_LAUNCH_F16_2Q_TV(64, 128, 8); + } else if (q_len >= 1024) { + SAGEATTN_LAUNCH_F16_2Q_TV(64, 128, 4); + } else { + SAGEATTN_LAUNCH_F16_2Q_TV(64, 128, 16); + } + } else if (use_2q && block_cols == 128) { + if (hnd_contiguous) { + if (use_f16_tvload) { + if 
(block_rows == 64) { + SAGEATTN_LAUNCH_F16_2Q_TVLOAD(128, 64, 16); + } else { + SAGEATTN_LAUNCH_F16_2Q_TVLOAD(128, 128, 16); + } + } else { + if (block_rows == 64) { + SAGEATTN_LAUNCH_F16_2Q(128, true, 64); + } else { + SAGEATTN_LAUNCH_F16_2Q(128, true, 128); + } + } + } else { + if (block_rows == 64) { + SAGEATTN_LAUNCH_F16_2Q(128, false, 64); + } else { + SAGEATTN_LAUNCH_F16_2Q(128, false, 128); + } + } + } else if (use_2q) { + if (hnd_contiguous) { + if (use_f16_tvload) { + if (block_rows == 64) { + if (is_causal && use_f16_pv_accum) { + SAGEATTN_LAUNCH_F16_2Q_TVLOAD_CAUSAL(64, 64, 4, true); + } else { + SAGEATTN_LAUNCH_F16_2Q_TVLOAD(64, 64, 4); + } + } else { + if (is_causal && use_f16_pv_accum) { + SAGEATTN_LAUNCH_F16_2Q_TVLOAD_CAUSAL(64, 128, 4, true); + } else { + SAGEATTN_LAUNCH_F16_2Q_TVLOAD(64, 128, 4); + } + } + } else { + if (block_rows == 64) { + SAGEATTN_LAUNCH_F16_2Q(64, true, 64); + } else { + SAGEATTN_LAUNCH_F16_2Q(64, true, 128); + } + } + } else { + if (block_rows == 64) { + SAGEATTN_LAUNCH_F16_2Q(64, false, 64); + } else { + SAGEATTN_LAUNCH_F16_2Q(64, false, 128); + } + } + } else if (block_cols == 128 && block_rows == 128) { + SAGEATTN_LAUNCH_F16_1Q(128, 128); + } else if (block_cols == 128) { + SAGEATTN_LAUNCH_F16_1Q(128, 64); + } else if (block_rows == 128) { + SAGEATTN_LAUNCH_F16_1Q(64, 128); + } else { + SAGEATTN_LAUNCH_F16_1Q(64, 64); + } + hip_kernel_launch_check(); + return new_empty_like(query, {0}, ScalarType::Float); +} +/* Validates value_scale and returns its data pointer as const float*: the tensor must be a CUDA/HIP, fp32, contiguous 3-D tensor whose sizes equal [batch, kv_heads, head_dim]; any violated condition is reported through STD_TORCH_CHECK. */ +static const float* checked_value_scale_ptr_gfx12( + Tensor value_scale, + int64_t batch, + int64_t kv_heads, + int64_t head_dim) { + STD_TORCH_CHECK(value_scale.is_cuda(), "value_scale must be a CUDA/HIP tensor"); + STD_TORCH_CHECK(value_scale.scalar_type() == ScalarType::Float, + "value_scale must be fp32"); + STD_TORCH_CHECK(value_scale.dim() == 3 && value_scale.is_contiguous(), + "value_scale must be contiguous [B, H_kv, D]"); + STD_TORCH_CHECK(value_scale.size(0) == batch && + value_scale.size(1) == kv_heads && 
+ value_scale.size(2) == head_dim, + "value_scale shape must match [B, H_kv, D]"); + return reinterpret_cast(value_scale.data_ptr()); +} +/* Shared host-side implementation behind the raw-Q gfx12 entry points. Checks device, rank, dtype (fp16/bf16 query and output, int8 key, uint8 value holding e4m3 fp8 bytes, fp32 key_scale), layout and contiguity, then derives the problem sizes, launches the kernel and returns output. value_scale_ptr may be nullptr (the unscaled entry point passes nullptr). NOTE(review): key_scale is only dtype- and stride-checked in this function; its device and rank are not validated - confirm callers guarantee them. */ +static Tensor qk_rawq_int8_sv_f8_native_attn_gfx12_impl( + Tensor query, + Tensor key, + Tensor value, + Tensor output, + Tensor key_scale, + const float* value_scale_ptr, + int tensor_layout, + int is_causal, + float sm_scale, + int64_t valid_kv_len) { + STD_TORCH_CHECK(query.is_cuda() && key.is_cuda() && value.is_cuda() && output.is_cuda(), + "raw-Q gfx12 tensors must be CUDA/HIP tensors"); + STD_TORCH_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4 && output.dim() == 4, + "raw-Q gfx12 attention expects 4D Q/K/V/O tensors"); + STD_TORCH_CHECK(query.scalar_type() == ScalarType::Half || + query.scalar_type() == ScalarType::BFloat16, + "raw-Q gfx12 attention supports fp16/bf16 query"); + STD_TORCH_CHECK(key.scalar_type() == ScalarType::Char, + "raw-Q gfx12 attention expects pre-quantized int8 key"); + STD_TORCH_CHECK(value.scalar_type() == ScalarType::Byte, + "raw-Q gfx12 attention expects raw OCP e4m3 fp8 value bytes"); + STD_TORCH_CHECK(output.scalar_type() == ScalarType::Half || + output.scalar_type() == ScalarType::BFloat16, + "raw-Q gfx12 attention output must be fp16 or bf16"); + STD_TORCH_CHECK(key_scale.scalar_type() == ScalarType::Float, + "raw-Q gfx12 attention key_scale must be fp32"); + STD_TORCH_CHECK(tensor_layout == kHND || tensor_layout == kNHD, "invalid tensor_layout"); + STD_TORCH_CHECK(query.is_contiguous() && key.is_contiguous() && + value.is_contiguous() && output.is_contiguous(), + "raw-Q gfx12 attention expects contiguous tensors"); +/* Sizes are read per layout: kNHD keeps heads in dim 2 and sequence in dim 1, kHND the reverse. */ + const int64_t head_dim = query.size(-1); + STD_TORCH_CHECK(head_dim == 16 || head_dim == 64 || head_dim == 128, + "raw-Q gfx12 fp8 path supports head_dim 16, 64, or 128"); + const int64_t batch = query.size(0); + const int64_t q_heads = tensor_layout == kNHD ? query.size(2) : query.size(1); + const int64_t q_len = tensor_layout == kNHD ? 
query.size(1) : query.size(2); + const int64_t out_q_len = tensor_layout == kNHD ? output.size(1) : output.size(2); + const int64_t kv_heads = tensor_layout == kNHD ? key.size(2) : key.size(1); + const int64_t padded_kv_len = tensor_layout == kNHD ? key.size(1) : key.size(2); + /* A non-positive valid_kv_len selects the full padded key length. */const int64_t kv_len = valid_kv_len > 0 ? valid_kv_len : padded_kv_len; + STD_TORCH_CHECK(kv_len > 0 && kv_len <= padded_kv_len, + "valid_kv_len must be in (0, padded_kv_len]"); + STD_TORCH_CHECK(key.size(0) == batch && value.size(0) == batch && + output.size(0) == batch, + "raw-Q gfx12 batch size mismatch"); + /* Shape heuristic for a value tensor stored with head_dim before the sequence axis (dims 1..3 = kv_heads, head_dim, >= padded_kv_len). NOTE(review): under kHND with padded_kv_len == head_dim a regular [B, H_kv, N, D] value also satisfies this test - confirm callers cannot hit that ambiguity. */const bool value_transposed_hnd = + value.size(1) == kv_heads && value.size(2) == head_dim && + value.size(3) >= padded_kv_len; + STD_TORCH_CHECK(key.size(-1) == head_dim && output.size(-1) == head_dim && + (value_transposed_hnd || value.size(-1) == head_dim), + "raw-Q gfx12 Q/K/V/O head_dim mismatch"); + STD_TORCH_CHECK((tensor_layout == kNHD && + ((value_transposed_hnd && output.size(1) >= q_len && + key.size(2) == kv_heads && output.size(2) == q_heads) || + (!value_transposed_hnd && value.size(1) == padded_kv_len && + output.size(1) >= q_len && key.size(2) == kv_heads && + value.size(2) == kv_heads && output.size(2) == q_heads))) || + (tensor_layout == kHND && + ((value_transposed_hnd && output.size(2) >= q_len && + key.size(1) == kv_heads && output.size(1) == q_heads) || + (!value_transposed_hnd && value.size(2) == padded_kv_len && + output.size(2) >= q_len && key.size(1) == kv_heads && + value.size(1) == kv_heads && output.size(1) == q_heads))), + "raw-Q gfx12 Q/K/V/O shape mismatch"); + STD_TORCH_CHECK((q_heads % kv_heads) == 0, "q_heads must be divisible by kv_heads"); + STD_TORCH_CHECK((padded_kv_len % 64) == 0, + "raw-Q gfx12 attention requires padded kv_len divisible by 64"); + STD_TORCH_CHECK(!is_causal || (q_len % 64) == 0, + "raw-Q gfx12 causal attention requires q_len divisible by 64"); + STD_TORCH_CHECK(!is_causal || q_len == padded_kv_len, + "raw-Q gfx12 causal 
attention requires q_len == padded_kv_len"); + STD_TORCH_CHECK(key_scale.stride(-1) == 1, "key_scale must have contiguous scale columns"); +/* Choose the query tile height: head_dim 64 asks select_fp8_d64_block_rows_gfx12, other head dims use 64 rows when q_len <= 64 and 128 otherwise; the two if-statements below override specific cases. */ + int block_rows = head_dim == 64 ? + select_fp8_d64_block_rows_gfx12(q_len, is_causal, false) : + (q_len <= 64 ? 64 : 128); + if (head_dim == 64 && !is_causal && q_len == 1024) { + block_rows = 128; + } + if (head_dim == 16 && is_causal && q_len <= 1024) { + block_rows = 64; + } + const int64_t q_blocks = (q_len + block_rows - 1) / block_rows; + STD_TORCH_CHECK(out_q_len >= q_blocks * block_rows, + "raw-Q gfx12 attention output must cover the padded query tail"); +/* Launch geometry: block_rows threads per block and a (q_blocks, q_heads, batch) grid; the check above requires output to hold q_blocks whole tiles. */ + constexpr int block_cols = 64; + const bool hnd_contiguous = tensor_layout == kHND; + const dim3 block(block_rows); + const dim3 grid(q_blocks, q_heads, batch); +/* Function-local dispatch macros: each layer converts one runtime property (head_dim, block_rows, query dtype, output dtype) into a macro argument for the kernel launch. The launch passes nullptr and zero strides where the other gfx12 launchers in this file pass query_scale. */ +#define SAGEATTN_LAUNCH_RAWQ_FP8_TYPED(HD_, HND_, BR_, VT_, CAUSAL_, QUERY_T_, OUT_T_) \ + qk_int8_sv_f8_native_2q_kernel<<>>( \ + reinterpret_cast(query.data_ptr()), \ + reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(output.data_ptr()), \ + nullptr, reinterpret_cast(key_scale.data_ptr()), value_scale_ptr, \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), \ + key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), \ + value.stride(0), (VT_ ? value.stride(2) : value.stride(tensor_layout == kNHD ? 1 : 2)), (VT_ ? value.stride(1) : value.stride(tensor_layout == kNHD ? 2 : 1)), \ + output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 
2 : 1), \ + 0, 0, \ + key_scale.stride(0), key_scale.stride(1), \ + tensor_layout, sm_scale) +#define SAGEATTN_DISPATCH_RAWQ_FP8_OUT(HD_, HND_, BR_, VT_, CAUSAL_, QUERY_T_) \ + if (output.scalar_type() == ScalarType::BFloat16) { \ + SAGEATTN_LAUNCH_RAWQ_FP8_TYPED(HD_, HND_, BR_, VT_, CAUSAL_, QUERY_T_, __hip_bfloat16); \ + } else { \ + SAGEATTN_LAUNCH_RAWQ_FP8_TYPED(HD_, HND_, BR_, VT_, CAUSAL_, QUERY_T_, __half); \ + } +#define SAGEATTN_DISPATCH_RAWQ_FP8_QUERY(HD_, HND_, BR_, VT_, CAUSAL_) \ + if (query.scalar_type() == ScalarType::BFloat16) { \ + SAGEATTN_DISPATCH_RAWQ_FP8_OUT(HD_, HND_, BR_, VT_, CAUSAL_, __hip_bfloat16); \ + } else { \ + SAGEATTN_DISPATCH_RAWQ_FP8_OUT(HD_, HND_, BR_, VT_, CAUSAL_, __half); \ + } +#define SAGEATTN_DISPATCH_RAWQ_FP8_BR(HD_, HND_, VT_, CAUSAL_) \ + if (block_rows == 64) { \ + SAGEATTN_DISPATCH_RAWQ_FP8_QUERY(HD_, HND_, 64, VT_, CAUSAL_); \ + } else if (block_rows == 256) { \ + SAGEATTN_DISPATCH_RAWQ_FP8_QUERY(HD_, HND_, 256, VT_, CAUSAL_); \ + } else if (block_rows == 512) { \ + SAGEATTN_DISPATCH_RAWQ_FP8_QUERY(HD_, HND_, 512, VT_, CAUSAL_); \ + } else { \ + SAGEATTN_DISPATCH_RAWQ_FP8_QUERY(HD_, HND_, 128, VT_, CAUSAL_); \ + } +#define SAGEATTN_DISPATCH_RAWQ_FP8_HD(HND_, VT_, CAUSAL_) \ + if (head_dim == 16) { \ + SAGEATTN_DISPATCH_RAWQ_FP8_BR(16, HND_, VT_, CAUSAL_); \ + } else if (head_dim == 64) { \ + SAGEATTN_DISPATCH_RAWQ_FP8_BR(64, HND_, VT_, CAUSAL_); \ + } else { \ + SAGEATTN_DISPATCH_RAWQ_FP8_BR(128, HND_, VT_, CAUSAL_); \ + } +/* Expand layout, causal flag and value orientation into the eight macro instantiations. NOTE(review): the block_rows dispatch above falls back to the 128-row instantiation for every value other than 64, 256 or 512, while the launch uses block_rows threads - confirm block_rows can never be anything else (e.g. 32). */ + if (hnd_contiguous) { + if (is_causal) { + if (value_transposed_hnd) { SAGEATTN_DISPATCH_RAWQ_FP8_HD(true, true, true); } + else { SAGEATTN_DISPATCH_RAWQ_FP8_HD(true, false, true); } + } else { + if (value_transposed_hnd) { SAGEATTN_DISPATCH_RAWQ_FP8_HD(true, true, false); } + else { SAGEATTN_DISPATCH_RAWQ_FP8_HD(true, false, false); } + } + } else { + if (is_causal) { + if (value_transposed_hnd) { SAGEATTN_DISPATCH_RAWQ_FP8_HD(false, true, true); } + else { SAGEATTN_DISPATCH_RAWQ_FP8_HD(false, 
false, true); } + } else { + if (value_transposed_hnd) { SAGEATTN_DISPATCH_RAWQ_FP8_HD(false, true, false); } + else { SAGEATTN_DISPATCH_RAWQ_FP8_HD(false, false, false); } + } + } +/* The dispatch macros are scoped to this function, so drop them before leaving. */ +#undef SAGEATTN_DISPATCH_RAWQ_FP8_HD +#undef SAGEATTN_DISPATCH_RAWQ_FP8_BR +#undef SAGEATTN_DISPATCH_RAWQ_FP8_QUERY +#undef SAGEATTN_DISPATCH_RAWQ_FP8_OUT +#undef SAGEATTN_LAUNCH_RAWQ_FP8_TYPED + hip_kernel_launch_check(); + return output; +} +/* Raw-Q entry point without a value scale: converts tensor_layout, is_causal and sm_scale to the impl's int, int and float parameters and forwards with value_scale_ptr = nullptr. */ +Tensor qk_rawq_int8_sv_f8_native_attn_gfx12( + Tensor query, + Tensor key, + Tensor value, + Tensor output, + Tensor key_scale, + int64_t tensor_layout, + int64_t is_causal, + double sm_scale, + int64_t valid_kv_len) { + return qk_rawq_int8_sv_f8_native_attn_gfx12_impl( + query, key, value, output, key_scale, nullptr, + static_cast(tensor_layout), static_cast(is_causal), + static_cast(sm_scale), valid_kv_len); +} +/* Raw-Q entry point with a value scale: value_scale is validated against batch, kv_heads and head_dim read from query and key, then its pointer is forwarded to the impl. NOTE(review): query.size and key.size are read here before the impl has checked tensor rank and tensor_layout - confirm that ordering is acceptable. */ +Tensor qk_rawq_int8_sv_f8_scaled_native_attn_gfx12( + Tensor query, + Tensor key, + Tensor value, + Tensor output, + Tensor key_scale, + Tensor value_scale, + int64_t tensor_layout, + int64_t is_causal, + double sm_scale, + int64_t valid_kv_len) { + const int64_t head_dim = query.size(-1); + const int64_t batch = query.size(0); + const int64_t kv_heads = tensor_layout == kNHD ? key.size(2) : key.size(1); + const float* value_scale_ptr = + checked_value_scale_ptr_gfx12(value_scale, batch, kv_heads, head_dim); + return qk_rawq_int8_sv_f8_native_attn_gfx12_impl( + query, key, value, output, key_scale, value_scale_ptr, + static_cast(tensor_layout), static_cast(is_causal), + static_cast(sm_scale), valid_kv_len); +} +/* Entry point taking query_scale plus a value scale: validates value_scale the same way and forwards to qk_int8_sv_f16_d64_native_attn_gfx12_impl with the scale pointer as the last argument. */ +Tensor qk_int8_sv_f8_scaled_native_attn_gfx12( + Tensor query, + Tensor key, + Tensor value, + Tensor output, + Tensor query_scale, + Tensor key_scale, + Tensor value_scale, + int64_t tensor_layout, + int64_t is_causal, + double sm_scale, + int64_t valid_kv_len) { + const int64_t head_dim = query.size(-1); + const int64_t batch = query.size(0); + const int64_t kv_heads = tensor_layout == kNHD ? 
key.size(2) : key.size(1); + const float* value_scale_ptr = + checked_value_scale_ptr_gfx12(value_scale, batch, kv_heads, head_dim); + return qk_int8_sv_f16_d64_native_attn_gfx12_impl( + query, key, value, output, query_scale, key_scale, + static_cast(tensor_layout), static_cast(is_causal), + static_cast(sm_scale), valid_kv_len, value_scale_ptr); +} +/* Entry point without a value scale: forwards to qk_int8_sv_f16_d64_native_attn_gfx12_impl and leaves out the trailing value-scale argument. */ +Tensor qk_int8_sv_f16_d64_native_attn_gfx12( + Tensor query, + Tensor key, + Tensor value, + Tensor output, + Tensor query_scale, + Tensor key_scale, + int64_t tensor_layout, + int64_t is_causal, + double sm_scale, + int64_t valid_kv_len) { + return qk_int8_sv_f16_d64_native_attn_gfx12_impl( + query, key, value, output, query_scale, key_scale, + static_cast(tensor_layout), static_cast(is_causal), + static_cast(sm_scale), valid_kv_len); +} diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh index f6f74c45..2016f272 100644 --- a/csrc/reduction_utils.cuh +++ b/csrc/reduction_utils.cuh @@ -19,7 +19,11 @@ */ #pragma once +#if defined(__HIP_PLATFORM_AMD__) +#define FINAL_MASK 0xffffffffffffffffull +#else #define FINAL_MASK 0xffffffff +#endif namespace vllm { @@ -27,7 +31,7 @@ template __inline__ __device__ T warpReduceSum(T val) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) - val += __shfl_xor_sync(0xffffffff, val, mask, 32); + val += __shfl_xor_sync(FINAL_MASK, val, mask, 32); return val; } @@ -121,7 +125,7 @@ __inline__ __device__ T warpReduceMax(T val) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) - val = max(val, __shfl_xor_sync(0xffffffff, val, mask, 32)); + val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); return val; } /* Calculate the maximum of all elements in a block */ @@ -170,7 +174,7 @@ __inline__ __device__ T warpReduceMin(T val) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) - val = min(val, __shfl_xor_sync(0xffffffff, val, mask, 32)); + val = min(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); return val; } /* Calculate the minimum of all 
elements in a block */ diff --git a/sageattention/__init__.py b/sageattention/__init__.py index 73b0256d..9f63cf41 100644 --- a/sageattention/__init__.py +++ b/sageattention/__init__.py @@ -2,4 +2,5 @@ from .core import sageattn_qk_int8_pv_fp16_triton from .core import sageattn_qk_int8_pv_fp16_cuda from .core import sageattn_qk_int8_pv_fp8_cuda -from .core import sageattn_qk_int8_pv_fp8_cuda_sm90 \ No newline at end of file +from .core import sageattn_qk_int8_pv_fp8_cuda_sm90 +from .core import sageattn_qk_int8_pv_gfx12_native diff --git a/sageattention/core.py b/sageattention/core.py index 21a955fa..42ca0f9c 100644 --- a/sageattention/core.py +++ b/sageattention/core.py @@ -16,6 +16,8 @@ import torch import torch.nn.functional as F +import subprocess +import re from .triton.quant_per_block import per_block_int8 as per_block_int8_triton from .triton.quant_per_block_varlen import per_block_int8 as per_block_int8_varlen_triton @@ -44,31 +46,471 @@ except: SM90_ENABLED = False +try: + from .gfx12_native_compile import _qattn_gfx12_native + _qattn_gfx12_prepare_attn_hnd = _qattn_gfx12_native.qk_int8_sv_f16_d64_prepare_attn_hnd + GFX12_NATIVE_ENABLED = True +except Exception: + _qattn_gfx12_native = None + _qattn_gfx12_prepare_attn_hnd = None + GFX12_NATIVE_ENABLED = False + from .quant import per_block_int8 as per_block_int8_cuda from .quant import per_warp_int8 as per_warp_int8_cuda from .quant import sub_mean from .quant import per_channel_fp8 +from .quant import _fused as _quant_fused from typing import Any, List, Literal, Optional, Tuple, Union import warnings def get_cuda_version(): - version = torch.version.cuda - major, minor = version.split('.') - return int(major), int(minor) + try: + output = subprocess.check_output(['nvcc', '--version']).decode() + match = re.search(r'release (\d+)\.(\d+)', output) + if match: + major, minor = int(match.group(1)), int(match.group(2)) + return major, minor + except Exception as e: + print("Failed to get CUDA version:", e) + 
return None, None def get_cuda_arch_versions(): cuda_archs = [] - for i in range(torch.cuda.device_count()): - major, minor = torch.cuda.get_device_capability(i) - cuda_archs.append(f"sm{major}{minor}") + if torch.version.hip is not None: + for i in range(torch.cuda.device_count()): + props = torch.cuda.get_device_properties(i) + arch = getattr(props, "gcnArchName", "") + cuda_archs.append(arch.split(":", 1)[0] if arch else "") + else: + for i in range(torch.cuda.device_count()): + major, minor = torch.cuda.get_device_capability(i) + cuda_archs.append(f"sm{major}{minor}") return cuda_archs -# Currently get_cuda_arch_versions cannot be traced by torch.compile -_cuda_archs = get_cuda_arch_versions() +def _get_gfx12_native_extension(): + global _qattn_gfx12_native, _qattn_gfx12_prepare_attn_hnd, GFX12_NATIVE_ENABLED + if _qattn_gfx12_native is None: + from .gfx12_native_compile import _qattn_gfx12_native as ops + _qattn_gfx12_native = ops + _qattn_gfx12_prepare_attn_hnd = _qattn_gfx12_native.qk_int8_sv_f16_d64_prepare_attn_hnd + GFX12_NATIVE_ENABLED = True + return _qattn_gfx12_native + + +def _get_gfx12_prepare_attn_hnd(): + _get_gfx12_native_extension() + return _qattn_gfx12_prepare_attn_hnd + + +def _round_up_to_multiple(value: int, multiple: int) -> int: + return ((value + multiple - 1) // multiple) * multiple + + +def _pad_gfx12_hnd_sequence( + q_hnd: torch.Tensor, + k_hnd: torch.Tensor, + v_hnd: torch.Tensor, + q_len: int, + kv_len: int, + is_causal: bool = False, + k_pad_value: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q_padded_len = _round_up_to_multiple(q_len, 128) + kv_padded_len = q_padded_len if is_causal else _round_up_to_multiple(kv_len, 64) + q_pad_len = q_padded_len - q_len + kv_pad_len = kv_padded_len - kv_len + if q_pad_len > 0: + q_hnd = F.pad(q_hnd, (0, 0, 0, q_pad_len)) + if kv_pad_len > 0: + if k_pad_value is None: + k_hnd = F.pad(k_hnd, (0, 0, 0, kv_pad_len)) + else: + k_hnd = torch.cat([k_hnd, 
k_pad_value.expand(-1, -1, kv_pad_len, -1)], dim=2) + v_hnd = F.pad(v_hnd, (0, 0, 0, kv_pad_len)) + return q_hnd, k_hnd, v_hnd + + +def _pad_gfx12_nhd_sequence( + q_nhd: torch.Tensor, + k_nhd: torch.Tensor, + v_nhd: torch.Tensor, + q_len: int, + kv_len: int, + is_causal: bool = False, + k_pad_value: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q_padded_len = _round_up_to_multiple(q_len, 128) + kv_padded_len = q_padded_len if is_causal else _round_up_to_multiple(kv_len, 64) + q_pad_len = q_padded_len - q_len + kv_pad_len = kv_padded_len - kv_len + if q_pad_len > 0: + q_nhd = F.pad(q_nhd, (0, 0, 0, 0, 0, q_pad_len)) + if kv_pad_len > 0: + if k_pad_value is None: + k_nhd = F.pad(k_nhd, (0, 0, 0, 0, 0, kv_pad_len)) + else: + k_nhd = torch.cat([k_nhd, k_pad_value.expand(-1, kv_pad_len, -1, -1)], dim=1) + v_nhd = F.pad(v_nhd, (0, 0, 0, 0, 0, kv_pad_len)) + return q_nhd, k_nhd, v_nhd + + +_GFX12_FP8_VALUE_SCALE_MAX = 2.25 + + +def _gfx12_fp8_value_scale_hnd(v_hnd: torch.Tensor) -> torch.Tensor: + return v_hnd.abs().amax(dim=2).to(torch.float32).div(_GFX12_FP8_VALUE_SCALE_MAX).contiguous() + + +def sageattn_qk_int8_pv_gfx12_native( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str = "HND", + is_causal: bool = False, + sm_scale: Optional[float] = None, + value_dtype: str = "fp8", + smooth_k: bool = True, + return_lse: bool = False, + **kwargs: Any, +) -> torch.Tensor: + """ + ROCm gfx12 native SageAttention path. + + Supports fixed-length attention. The default smooth-K path follows the + CUDA quantization flow; NHD inputs use native NHD quantization to avoid an + extra layout conversion when possible. + + Current gfx12 constraints: + - q, k, and v must be fp16 or bf16. + - value_dtype="fp8" supports head_dim 16, 64, or 128. + - value_dtype="fp16" supports head_dim 16 or 64. + - Causal masking requires q_len == kv_len. + - smooth_k is enabled by default to match the CUDA and Triton paths. 
+ - return_lse is not implemented yet. + """ + + gfx12_native = _get_gfx12_native_extension() + gfx12_prepare_attn_hnd = _qattn_gfx12_prepare_attn_hnd + value_dtype_normalized = value_dtype.lower() + if value_dtype_normalized == "auto": + value_dtype_normalized = "fp8" + if ( + tensor_layout == "HND" + and value_dtype_normalized == "fp16" + and not smooth_k + and not return_lse + and q.dim() == 4 + and k.dim() == 4 + and v.dim() == 4 + and q.is_cuda + and q.is_contiguous() + and k.is_contiguous() + and v.is_contiguous() + and q.dtype == torch.float16 + and q.dtype == k.dtype == v.dtype + and q.device == k.device == v.device + and q.size(-1) in (16, 64) + and q.size(2) % 64 == 0 + and k.size(2) % 64 == 0 + ): + torch.cuda.set_device(v.device) + use_raw_f16_value = is_causal and q.size(-1) == 64 and q.size(2) <= 512 + return gfx12_prepare_attn_hnd( + q, + k, + v, + int(is_causal), + 0, + int(use_raw_f16_value), + float(sm_scale if sm_scale is not None else q.size(-1) ** -0.5), + ) + + assert q.is_cuda, "Input tensors must be on cuda/HIP." + assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + assert q.dtype in [torch.float16, torch.bfloat16], "gfx12 native path supports fp16/bf16 inputs." + assert tensor_layout in ["HND", "NHD"], "tensor_layout must be either 'HND' or 'NHD'." 
+ if return_lse: + raise ValueError("gfx12 native path does not return LSE yet.") + torch.cuda.set_device(v.device) + + input_dtype = q.dtype + value_dtype = value_dtype_normalized + if value_dtype not in {"fp16", "fp8"}: + raise ValueError("gfx12 native value_dtype must be 'auto', 'fp16', or 'fp8'.") + if sm_scale is None and q.dim() == 4: + sm_scale = q.size(-1) ** -0.5 + + if ( + tensor_layout == "HND" + and not smooth_k + and q.dim() == 4 + and k.dim() == 4 + and v.dim() == 4 + and q.is_contiguous() + and k.is_contiguous() + and v.is_contiguous() + and q.size(-1) in (16, 64, 128) + and value_dtype == "fp16" + and q.size(-1) in (16, 64) + and q.size(2) % 64 == 0 + and k.size(2) % 64 == 0 + ): + use_raw_f16_value = ( + value_dtype == "fp16" + and input_dtype == torch.float16 + and is_causal + and q.size(-1) == 64 + and q.size(2) <= 512 + ) + out = gfx12_prepare_attn_hnd( + q, + k, + v, + int(is_causal), + int(value_dtype == "fp8"), + int(use_raw_f16_value), + float(sm_scale), + ) + if input_dtype == torch.bfloat16: + return out if out.dtype == torch.bfloat16 else gfx12_native.convert_f16_to_bf16(out) + return out + + if tensor_layout == "NHD" and smooth_k: + q_nhd = q.contiguous() + k_nhd = k.contiguous() + v_nhd = v.contiguous() + + _, qo_len, h_qo, head_dim_og = q_nhd.shape + _, kv_len, h_kv, _ = k_nhd.shape + if h_qo % h_kv != 0: + raise ValueError("num_qo_heads must be divisible by num_kv_heads.") + if is_causal and qo_len != kv_len: + raise ValueError("gfx12 causal path currently requires q_len == kv_len.") + + head_dim = head_dim_og + if head_dim < 64: + pad = 64 - head_dim + q_nhd = F.pad(q_nhd, (0, pad)) + k_nhd = F.pad(k_nhd, (0, pad)) + v_nhd = F.pad(v_nhd, (0, pad)) + head_dim = 64 + elif value_dtype == "fp8" and 64 < head_dim < 128: + pad = 128 - head_dim + q_nhd = F.pad(q_nhd, (0, pad)) + k_nhd = F.pad(k_nhd, (0, pad)) + v_nhd = F.pad(v_nhd, (0, pad)) + head_dim = 128 + + if value_dtype == "fp16" and head_dim not in (16, 64): + raise 
ValueError("gfx12 fp16 value path currently supports head_dim 16 or 64.") + if value_dtype == "fp8" and head_dim not in (16, 64, 128): + raise ValueError("gfx12 fp8 value path currently supports head_dim 16, 64, or 128.") + + k_mean = k_nhd.mean(dim=1, keepdim=True) + use_rawq_tail = value_dtype == "fp8" and not is_causal and head_dim == 128 + if use_rawq_tail: + q_attn = q_nhd + q_out_len = ((qo_len + 127) // 128) * 128 + kv_pad_len = ((kv_len + 63) // 64) * 64 - kv_len + if kv_pad_len > 0: + k_nhd = torch.cat([k_nhd, k_mean.expand(-1, kv_pad_len, -1, -1)], dim=1) + v_nhd = F.pad(v_nhd, (0, 0, 0, 0, 0, kv_pad_len)) + else: + q_nhd, k_nhd, v_nhd = _pad_gfx12_nhd_sequence( + q_nhd, k_nhd, v_nhd, qo_len, kv_len, bool(is_causal), k_mean + ) + q_attn = q_nhd + q_out_len = q_nhd.size(1) + k_int8 = torch.empty_like(k_nhd, dtype=torch.int8) + k_scale = torch.empty( + (k_nhd.size(0), k_nhd.size(2), (k_nhd.size(1) + 63) // 64), + device=k_nhd.device, + dtype=torch.float32, + ) + _quant_fused.quant_per_block_int8_fuse_sub_mean_cuda( + k_nhd, k_mean.squeeze(1), k_int8, k_scale, 64, 0 + ) + value_scale = None + if value_dtype == "fp8": + v_hnd_for_value = v_nhd.transpose(1, 2).contiguous() + value_scale = _gfx12_fp8_value_scale_hnd(v_hnd_for_value) + value_native = gfx12_native.transpose_value_fp8_scaled_hnd( + v_hnd_for_value, value_scale + ) + else: + value_native = v_nhd if input_dtype == torch.float16 else v_nhd.to(torch.float16) + out = torch.empty( + (q_nhd.size(0), q_out_len, q_nhd.size(2), q_nhd.size(3)), + device=q_nhd.device, + dtype=torch.float16, + ) + if value_dtype == "fp8": + gfx12_native.qk_rawq_int8_sv_f8_scaled_native_attn( + q_attn, + k_int8, + value_native, + out, + k_scale, + value_scale, + 0, + int(is_causal), + float(sm_scale), + kv_len, + ) + else: + q_int8, q_scale = gfx12_native.quant_q_nhd_per_warp(q_attn) + gfx12_native.qk_int8_sv_f16_d64_native_attn( + q_int8, + k_int8, + value_native, + out, + q_scale, + k_scale, + 0, + int(is_causal), + 
float(sm_scale), + kv_len, + ) + out = out[:, :qo_len, :, :head_dim_og] + if input_dtype == torch.bfloat16 and out.dtype != torch.bfloat16: + out = gfx12_native.convert_f16_to_bf16(out.contiguous() if not out.is_contiguous() else out) + elif input_dtype != torch.float16: + out = out.to(input_dtype) + return out + + if tensor_layout == "NHD": + q_hnd = q.transpose(1, 2).contiguous() + k_hnd = k.transpose(1, 2).contiguous() + v_hnd = v.transpose(1, 2).contiguous() + else: + q_hnd = q.contiguous() + k_hnd = k.contiguous() + v_hnd = v.contiguous() + + _, h_qo, qo_len, head_dim_og = q_hnd.shape + _, h_kv, kv_len, _ = k_hnd.shape + if h_qo % h_kv != 0: + raise ValueError("num_qo_heads must be divisible by num_kv_heads.") + if is_causal and qo_len != kv_len: + raise ValueError("gfx12 causal path currently requires q_len == kv_len.") + + head_dim = head_dim_og + if head_dim < 64 and (smooth_k or head_dim != 16): + pad = 64 - head_dim + q_hnd = F.pad(q_hnd, (0, pad)) + k_hnd = F.pad(k_hnd, (0, pad)) + v_hnd = F.pad(v_hnd, (0, pad)) + head_dim = 64 + elif value_dtype == "fp8" and 64 < head_dim < 128: + pad = 128 - head_dim + q_hnd = F.pad(q_hnd, (0, pad)) + k_hnd = F.pad(k_hnd, (0, pad)) + v_hnd = F.pad(v_hnd, (0, pad)) + head_dim = 128 + + if value_dtype == "fp16" and head_dim not in (16, 64): + raise ValueError("gfx12 fp16 value path currently supports head_dim 16 or 64.") + if value_dtype == "fp8" and head_dim not in (16, 64, 128): + raise ValueError("gfx12 fp8 value path currently supports head_dim 16, 64, or 128.") + + k_mean = k_hnd.mean(dim=2, keepdim=True) if smooth_k else None + q_hnd, k_hnd, v_hnd = _pad_gfx12_hnd_sequence( + q_hnd, k_hnd, v_hnd, qo_len, kv_len, bool(is_causal), k_mean) + padded_qo_len = q_hnd.size(2) + + use_raw_f16_value = ( + value_dtype == "fp16" + and input_dtype == torch.float16 + and is_causal + and head_dim == 64 + and padded_qo_len <= 512 + ) + + if not smooth_k: + if value_dtype == "fp8": + q_int8, q_scale, k_int8, k_scale = 
per_warp_int8_cuda( + q_hnd, k_hnd, None, BLKQ=128, WARPQ=32, BLKK=64, tensor_layout="HND" + ) + value_scale = _gfx12_fp8_value_scale_hnd(v_hnd) + value_native = gfx12_native.transpose_value_fp8_scaled_hnd(v_hnd, value_scale) + out = torch.empty_like(q_hnd, dtype=torch.float16) + gfx12_native.qk_int8_sv_f8_scaled_native_attn( + q_int8, k_int8, value_native, out, q_scale, k_scale, value_scale, + 1, int(is_causal), float(sm_scale), kv_len + ) + else: + out = gfx12_prepare_attn_hnd( + q_hnd, + k_hnd, + v_hnd, + int(is_causal), + 0, + int(use_raw_f16_value), + float(sm_scale), + kv_len, + ) + else: + use_rawq_hnd_fp8 = ( + value_dtype == "fp8" + and head_dim in (64, 128) + and ( + not is_causal + or head_dim == 64 + or padded_qo_len <= 1024 + or padded_qo_len >= 8192 + ) + ) + if use_rawq_hnd_fp8: + k_int8 = torch.empty_like(k_hnd, dtype=torch.int8) + k_scale = torch.empty( + (k_hnd.size(0), k_hnd.size(1), (k_hnd.size(2) + 63) // 64), + device=k_hnd.device, + dtype=torch.float32, + ) + _quant_fused.quant_per_block_int8_fuse_sub_mean_cuda( + k_hnd, k_mean.squeeze(2), k_int8, k_scale, 64, 1 + ) + value_scale = _gfx12_fp8_value_scale_hnd(v_hnd) + value_native = gfx12_native.transpose_value_fp8_scaled_hnd(v_hnd, value_scale) + out = torch.empty_like( + q_hnd, + dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float16, + ) + gfx12_native.qk_rawq_int8_sv_f8_scaled_native_attn( + q_hnd, k_int8, value_native, out, k_scale, value_scale, + 1, int(is_causal), float(sm_scale), kv_len + ) + else: + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q_hnd, k_hnd, k_mean, BLKQ=128, WARPQ=32, BLKK=64, tensor_layout="HND" + ) + out = torch.empty_like(q_hnd, dtype=torch.float16) + if value_dtype == "fp8": + value_scale = _gfx12_fp8_value_scale_hnd(v_hnd) + value_native = gfx12_native.transpose_value_fp8_scaled_hnd(v_hnd, value_scale) + gfx12_native.qk_int8_sv_f8_scaled_native_attn( + q_int8, k_int8, value_native, out, q_scale, k_scale, value_scale, + 1, 
int(is_causal), float(sm_scale), kv_len + ) + else: + value_native = gfx12_native.transpose_value_f16_hnd(v_hnd) + gfx12_native.qk_int8_sv_f16_d64_native_attn( + q_int8, k_int8, value_native, out, q_scale, k_scale, + 1, int(is_causal), float(sm_scale), kv_len + ) + out = out[..., :qo_len, :head_dim_og] + if input_dtype == torch.bfloat16 and out.dtype != torch.bfloat16: + out = gfx12_native.convert_f16_to_bf16(out.contiguous() if not out.is_contiguous() else out) + elif input_dtype != torch.float16: + out = out.to(input_dtype) + if tensor_layout == "NHD": + out = out.transpose(1, 2).contiguous() + return out def sageattn( @@ -134,9 +576,13 @@ def sageattn( - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` - All tensors must be on the same cuda device. """ - - arch = _cuda_archs[q.device.index] - if arch == "sm75": + + arch = get_cuda_arch_versions()[q.device.index] + if arch.startswith("gfx12"): + return sageattn_qk_int8_pv_gfx12_native( + q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, + sm_scale=sm_scale, return_lse=return_lse, **kwargs) + elif arch == "sm75": return sageattn_qk_int8_pv_fp16_triton(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse) elif arch in {"sm80", "sm86", "sm87"}: return sageattn_qk_int8_pv_fp16_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32") @@ -162,14 +608,14 @@ def sageattn( def sageattn_qk_int8_pv_fp16_triton( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, tensor_layout: str = "HND", quantization_backend: str = "triton", - is_causal: bool =False, + is_causal: bool =False, attn_mask: Optional[torch.Tensor] = None, - sm_scale: Optional[float] = None, + sm_scale: Optional[float] = None, smooth_k: bool = True, return_lse: bool = False, **kwargs: Any, @@ -237,7 +683,7 @@ def 
sageattn_qk_int8_pv_fp16_triton( Note ---- - - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16``, ``torch.bfloat16`` or ``torch.float32``. - All tensors must be on the same cuda device. - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. @@ -253,6 +699,14 @@ def sageattn_qk_int8_pv_fp16_triton( assert attn_mask.dtype == torch.bool or attn_mask.dtype == q.dtype, "attn_mask must be of dtype bool or the same dtype as q." assert attn_mask.device == q.device, "All tensors must be on the same device." + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + head_dim_og = q.size(-1) if head_dim_og < 64: @@ -328,15 +782,15 @@ def sageattn_qk_int8_pv_fp16_triton( def sageattn_varlen( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, is_causal: bool = False, - sm_scale: Optional[float] = None, + sm_scale: Optional[float] = None, smooth_k: bool = True, **kwargs: Any, ) -> torch.Tensor: @@ -354,23 +808,23 @@ def sageattn_varlen( The value tensor, shape: ``[cu_seqlens_k[-1], num_kv_heads, head_dim]``. cu_seqlens_q : torch.Tensor - The cumulative sequence lengths for the query sequences in the batch, used to index into `q`. 
+ The cumulative sequence lengths for the query sequences in the batch, used to index into `q`. Shape: ``[batch_size + 1]``, where each entry represents the cumulative length of sequences up to that batch index. cu_seqlens_k : torch.Tensor - The cumulative sequence lengths for the key and value sequences in the batch, used to index into `k` and `v`. + The cumulative sequence lengths for the key and value sequences in the batch, used to index into `k` and `v`. Shape: ``[batch_size + 1]``, where each entry represents the cumulative length of sequences up to that batch index. max_seqlen_q : int The maximum sequence length for the query tensor in the batch. - + max_seqlen_k : int The maximum sequence length for the key and value tensors in the batch. is_causal : bool Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len for each sequence. Default: False. - + sm_scale : Optional[float] The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. @@ -391,13 +845,21 @@ def sageattn_varlen( - All tensors must be on the same cuda device. - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. """ - + dtype = q.dtype assert q.is_cuda, "Input tensors must be on cuda." assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16" assert q.device == k.device == v.device, "All tensors must be on the same device." assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. 
+ torch.cuda.set_device(v.device) + head_dim_og = q.size(-1) if head_dim_og < 64: @@ -437,8 +899,8 @@ def sageattn_varlen( def sageattn_qk_int8_pv_fp16_cuda( - q: torch.Tensor, - k: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, v: torch.Tensor, tensor_layout: str = "HND", is_causal: bool = False, @@ -495,7 +957,7 @@ def sageattn_qk_int8_pv_fp16_cuda( smooth_k : bool Whether to smooth the key tensor by subtracting the mean along the sequence dimension. Default: True. - + smooth_v : bool Whether to smooth the value tensor by subtracting the mean along the sequence dimension. smooth_v will be ignored if pv_accum_dtype is "fp32" or "fp16+fp32". @@ -519,7 +981,7 @@ def sageattn_qk_int8_pv_fp16_cuda( Note ---- - - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` - All tensors must be on the same cuda device. - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. @@ -533,6 +995,14 @@ def sageattn_qk_int8_pv_fp16_cuda( assert q.device == k.device == v.device, "All tensors must be on the same device." assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. 
+ torch.cuda.set_device(v.device) + _tensor_layout = 0 if tensor_layout == "NHD" else 1 _is_caual = 1 if is_causal else 0 _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 @@ -614,8 +1084,8 @@ def sageattn_qk_int8_pv_fp16_cuda( def sageattn_qk_int8_pv_fp8_cuda( - q: torch.Tensor, - k: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, v: torch.Tensor, tensor_layout: str = "HND", is_causal: bool = False, @@ -667,11 +1137,11 @@ def sageattn_qk_int8_pv_fp8_cuda( - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. Default: "fp32+fp32". - + smooth_k : bool Whether to smooth the key tensor by subtracting the mean along the sequence dimension. Default: True. - + smooth_v : bool Whether to smooth the value tensor by subtracting the mean along the sequence dimension. smooth_v will be ignored if pv_accum_dtype is "fp32+fp32". @@ -695,7 +1165,7 @@ def sageattn_qk_int8_pv_fp8_cuda( Note ---- - - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` - All tensors must be on the same cuda device. - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. @@ -709,6 +1179,19 @@ def sageattn_qk_int8_pv_fp8_cuda( assert q.device == k.device == v.device, "All tensors must be on the same device." assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." 
+ # cuda_major_version, cuda_minor_version = get_cuda_version() + # if(cuda_major_version, cuda_minor_version) < (12, 8) and pv_accum_dtype == 'fp32+fp16': + # warnings.warn("cuda version < 12.8, change pv_accum_dtype to 'fp32+fp32'") + # pv_accum_dtype = 'fp32+fp32' + + # FIXME(DefTruth): make sage attention work compatible with distributed + # env, for example, xDiT which launch by torchrun. Without this workaround, + # sage attention will run into illegal memory access error after first + # inference step in distributed env for multi gpus inference. This small + # workaround also make sage attention work compatible with torch.compile + # through non-fullgraph compile mode. + torch.cuda.set_device(v.device) + _tensor_layout = 0 if tensor_layout == "NHD" else 1 _is_caual = 1 if is_causal else 0 _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 @@ -734,7 +1217,7 @@ def sageattn_qk_int8_pv_fp8_cuda( sm_scale = head_dim_og**-0.5 seq_dim = 1 if _tensor_layout == 0 else 2 - nh_dim = 2 if _tensor_layout == 0 else 1 + nh_dim = 2 if _tensor_layout == 0 else 1 if smooth_k: km = k.mean(dim=seq_dim, keepdim=True) @@ -794,8 +1277,8 @@ def sageattn_qk_int8_pv_fp8_cuda( def sageattn_qk_int8_pv_fp8_cuda_sm90( - q: torch.Tensor, - k: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, v: torch.Tensor, tensor_layout: str = "HND", is_causal: bool = False, @@ -846,7 +1329,7 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90( - "fp32": PV accumulation is done in fully in FP32. However, due to the hardware issue, there are only 22 valid bits in the FP32 accumulator. - "fp32+fp32": PV accumulation is done in FP32 (actually FP22), but added to a FP32 buffer every few iterations. This offers a balance between speed and accuracy. Default: "fp32+fp32". - + smooth_k : bool Whether to smooth the key tensor by subtracting the mean along the sequence dimension. Default: True. 
@@ -869,7 +1352,7 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90( Note ---- - - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` - All tensors must be on the same cuda device. - `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances. @@ -883,6 +1366,8 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90( assert q.device == k.device == v.device, "All tensors must be on the same device." assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + torch.cuda.set_device(v.device) + _tensor_layout = 0 if tensor_layout == "NHD" else 1 _is_caual = 1 if is_causal else 0 _qk_quant_gran = 3 if qk_quant_gran == "per_thread" else 2 diff --git a/sageattention/gfx12_native_compile.py b/sageattention/gfx12_native_compile.py new file mode 100644 index 00000000..40d44f72 --- /dev/null +++ b/sageattention/gfx12_native_compile.py @@ -0,0 +1,129 @@ +import torch + +from . 
import _qattn_gfx12_native +_qattn_gfx12_native = torch.ops.sageattention_qattn_gfx12_native + + +def _empty_lse(query: torch.Tensor) -> torch.Tensor: + return torch.empty((0,), dtype=torch.float32, device=query.device) + + +@torch.library.register_fake("sageattention_qattn_gfx12_native::qk_int8_sv_f16_d64_native_attn") +def qk_int8_sv_f16_d64_native_attn_fake_impl( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + tensor_layout: int, + is_causal: int, + sm_scale: float, + valid_kv_len: int, +) -> torch.Tensor: + return _empty_lse(query) + + +@torch.library.register_fake("sageattention_qattn_gfx12_native::qk_rawq_int8_sv_f8_native_attn") +def qk_rawq_int8_sv_f8_native_attn_fake_impl( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + key_scale: torch.Tensor, + tensor_layout: int, + is_causal: int, + sm_scale: float, + valid_kv_len: int, +) -> torch.Tensor: + return output + + +@torch.library.register_fake("sageattention_qattn_gfx12_native::qk_int8_sv_f8_scaled_native_attn") +def qk_int8_sv_f8_scaled_native_attn_fake_impl( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + query_scale: torch.Tensor, + key_scale: torch.Tensor, + value_scale: torch.Tensor, + tensor_layout: int, + is_causal: int, + sm_scale: float, + valid_kv_len: int, +) -> torch.Tensor: + return _empty_lse(query) + + +@torch.library.register_fake("sageattention_qattn_gfx12_native::qk_rawq_int8_sv_f8_scaled_native_attn") +def qk_rawq_int8_sv_f8_scaled_native_attn_fake_impl( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + key_scale: torch.Tensor, + value_scale: torch.Tensor, + tensor_layout: int, + is_causal: int, + sm_scale: float, + valid_kv_len: int, +) -> torch.Tensor: + return output + + 
+@torch.library.register_fake("sageattention_qattn_gfx12_native::qk_int8_sv_f16_d64_prepare_attn_hnd") +def qk_int8_sv_f16_d64_prepare_attn_hnd_fake_impl( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: int, + value_is_fp8: int, + use_raw_f16_value: int, + sm_scale: float, + valid_kv_len: int, +) -> torch.Tensor: + dtype = torch.bfloat16 if value_is_fp8 and query.dtype == torch.bfloat16 else torch.float16 + return torch.empty(query.shape, dtype=dtype, device=query.device) + + +@torch.library.register_fake("sageattention_qattn_gfx12_native::quant_q_nhd_per_warp") +def quant_q_nhd_per_warp_fake_impl(query: torch.Tensor) -> list[torch.Tensor]: + batch, q_len, q_heads, _ = query.shape + q_scale_groups = ((q_len + 127) // 128) * 4 + return [ + torch.empty_like(query, dtype=torch.int8), + torch.empty((batch, q_heads, q_scale_groups), dtype=torch.float32, device=query.device), + ] + + +@torch.library.register_fake("sageattention_qattn_gfx12_native::transpose_value_fp8_hnd") +def transpose_value_fp8_hnd_fake_impl(value: torch.Tensor) -> torch.Tensor: + return torch.empty( + (value.size(0), value.size(1), value.size(3), value.size(2)), + dtype=torch.uint8, + device=value.device, + ) + + +@torch.library.register_fake("sageattention_qattn_gfx12_native::transpose_value_fp8_scaled_hnd") +def transpose_value_fp8_scaled_hnd_fake_impl(value: torch.Tensor, value_scale: torch.Tensor) -> torch.Tensor: + return torch.empty( + (value.size(0), value.size(1), value.size(3), value.size(2)), + dtype=torch.uint8, + device=value.device, + ) + + +@torch.library.register_fake("sageattention_qattn_gfx12_native::transpose_value_f16_hnd") +def transpose_value_f16_hnd_fake_impl(value: torch.Tensor) -> torch.Tensor: + return torch.empty( + (value.size(0), value.size(1), value.size(3), value.size(2)), + dtype=torch.float16, + device=value.device, + ) + + +@torch.library.register_fake("sageattention_qattn_gfx12_native::convert_f16_to_bf16") +def 
convert_f16_to_bf16_fake_impl(input: torch.Tensor) -> torch.Tensor: + return torch.empty(input.shape, dtype=torch.bfloat16, device=input.device) diff --git a/setup.py b/setup.py index 1d08f50c..8a587d63 100644 --- a/setup.py +++ b/setup.py @@ -32,199 +32,352 @@ ext_modules = [] cmdclass = {} -if not SKIP_CUDA_BUILD: - import torch - from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME - # Compiler flags. +def append_env_flags(flags, env_name): + extra = os.getenv(env_name, "").strip() + if extra: + flags += extra.split() + + +def unique_paths(paths): + out = [] + seen = set() + for path in paths: + if path and path not in seen: + out.append(path) + seen.add(path) + return out + + +def rocm_sdk_path(which): + try: + return subprocess.check_output(["rocm-sdk", "path", f"--{which}"], text=True).strip() + except Exception: + return None + + +def configure_rocm(default_rocm_home): + sdk_root = rocm_sdk_path("root") + sdk_bin = rocm_sdk_path("bin") + rocm_home = sdk_root or default_rocm_home or os.getenv("ROCM_HOME") + if not rocm_home: + raise RuntimeError("Cannot find ROCm. 
Activate a ROCm-enabled PyTorch environment.") + + os.environ["ROCM_HOME"] = rocm_home if os.name == "nt": - # TODO: Detect MSVC rather than OS - CXX_FLAGS = ["/O2", "/openmp", "/std:c++17", "/permissive-", "-DENABLE_BF16"] - else: - CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"] - CXX_FLAGS += ["-DPy_LIMITED_API=0x03090000", "-DTORCH_STABLE_ONLY"] - - NVCC_FLAGS_COMMON = [ - "-O3", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "--use_fast_math", - f"--threads={os.cpu_count()}", - # "-Xptxas=-v", - "-diag-suppress=174", - "-diag-suppress=177", - "-diag-suppress=221", - "-DPy_LIMITED_API=0x03090000", - "-DTORCH_STABLE_ONLY", + os.environ.setdefault("CC", "clang-cl") + os.environ.setdefault("CXX", "clang-cl") + os.environ.setdefault("DISTUTILS_USE_SDK", "1") + + path_parts = [ + os.path.join(rocm_home, "lib", "llvm", "bin"), + os.path.join(rocm_home, "bin"), + sdk_bin, ] - if os.name == "nt": - # https://github.com/pytorch/pytorch/issues/148317 - NVCC_FLAGS_COMMON += [ - "-Xcompiler=/Zc:preprocessor", - "-D_WIN32=1", - "-DUSE_CUDA=1", - ] - - # Append flags from env if provided - cxx_append = os.getenv("CXX_APPEND_FLAGS", "").strip() - if cxx_append: - CXX_FLAGS += cxx_append.split() - nvcc_append = os.getenv("NVCC_APPEND_FLAGS", "").strip() - if nvcc_append: - NVCC_FLAGS_COMMON += nvcc_append.split() - - if os.name != "nt": - ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0 - CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] - NVCC_FLAGS_COMMON += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] + os.environ["PATH"] = os.pathsep.join(unique_paths(path_parts) + [os.environ.get("PATH", "")]) + return rocm_home + + +def rocm_arches(torch): + arch_env = os.getenv("GPU_ARCHS") or os.getenv("PYTORCH_ROCM_ARCH") + if arch_env: + archs = [] + for arch in arch_env.replace(";", " ").replace(",", " ").split(): + arch = arch.strip() + if arch: + archs.append(arch.split(":", 1)[0]) + return archs + + archs = [] + if 
torch.cuda.is_available(): + for device_idx in range(torch.cuda.device_count()): + props = torch.cuda.get_device_properties(device_idx) + arch = getattr(props, "gcnArchName", "") + if arch: + archs.append(arch.split(":", 1)[0]) + return archs - if CUDA_HOME is None: - raise RuntimeError( - "Cannot find CUDA_HOME. CUDA must be available to build the package.") - - def get_nvcc_cuda_version(cuda_dir: str) -> Version: - """Get the CUDA version from nvcc. - - Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py - """ - nvcc_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], - universal_newlines=True) - output = nvcc_output.split() - release_idx = output.index("release") + 1 - nvcc_cuda_version = parse(output[release_idx].split(",")[0]) - return nvcc_cuda_version - - # Determine target compute capabilities - compute_capabilities = set() - if os.getenv("TORCH_CUDA_ARCH_LIST"): - # Prefer TORCH_CUDA_ARCH_LIST if explicitly specified (works without GPUs) - # TORCH_CUDA_ARCH_LIST is separated by space or semicolon - for x in os.getenv("TORCH_CUDA_ARCH_LIST").replace(";", " ").split(): - compute_capabilities.add(x) - else: - # If not provided, try to detect from local GPUs - device_count = torch.cuda.device_count() if torch.cuda.is_available() else 0 - for i in range(device_count): - major, minor = torch.cuda.get_device_capability(i) - if major < 7: - warnings.warn(f"skipping GPU {i} with compute capability {major}.{minor}") - continue - compute_capabilities.add(f"{major}.{minor}") - - nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) - - if not compute_capabilities: - raise RuntimeError( - "No target compute capabilities. Set TORCH_CUDA_ARCH_LIST or build on a machine with GPUs.") - else: - print(f"Target compute capabilities: {compute_capabilities}") - - def has_capability(target): - return any(cc.startswith(target) for cc in compute_capabilities) - - # Validate the NVCC CUDA version. 
- if nvcc_cuda_version < Version("12.4") and has_capability("8.9"): - raise RuntimeError( - "CUDA 12.4 or higher is required for compute capability 8.9.") - if nvcc_cuda_version < Version("12.3") and has_capability("9.0"): - raise RuntimeError( - "CUDA 12.3 or higher is required for compute capability 9.0.") - if nvcc_cuda_version < Version("12.8") and has_capability("12.0"): - raise RuntimeError( - "CUDA 12.8 or higher is required for compute capability 12.0.") - - # Add target compute capabilities to NVCC flags. - def get_nvcc_flags(allowed_capabilities): - NVCC_FLAGS = [] - for capability in compute_capabilities: - if capability not in allowed_capabilities: - continue - - # capability: "8.0+PTX" -> num: "80" - num = capability.split("+")[0].replace(".", "") - if num in {"90", "100", "120", "121"}: - # need to use sm90a instead of sm90 to use wgmma ptx instruction. - # need to use sm120a to use mxfp8/mxfp4/nvfp4 instructions. - num += "a" - - NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"] - if capability.endswith("+PTX"): - NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"] - NVCC_FLAGS += NVCC_FLAGS_COMMON - return NVCC_FLAGS - - if has_capability(("8.0", "8.6", "8.7")): - ext_modules.append( - CUDAExtension( - name="sageattention._qattn_sm80", - sources=[ - "csrc/qattn/pybind_sm80.cpp", - "csrc/qattn/qk_int_sv_f16_cuda_sm80.cu", - ], - extra_compile_args={ - "cxx": CXX_FLAGS, - # Build binary for sm80 if sm86/87 is detected. 
No need to build binary for sm86/87 - "nvcc": get_nvcc_flags(["8.0"]), - }, - py_limited_api=True, + +if not SKIP_CUDA_BUILD: + import torch + import torch.utils.cpp_extension as cpp_extension + from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME + + LIMITED_API_FLAGS = ["-DPy_LIMITED_API=0x03090000", "-DTORCH_STABLE_ONLY"] + ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0 + + if torch.version.hip is not None: + rocm_home = configure_rocm(ROCM_HOME) + cpp_extension.ROCM_HOME = rocm_home + amd_arches = rocm_arches(torch) or ["gfx1201"] + os.environ.setdefault("PYTORCH_ROCM_ARCH", ";".join(amd_arches)) + print(f"Target AMD GPU architectures: {amd_arches}") + + if os.name == "nt": + CXX_FLAGS = [ + "/O2", + "/std:c++17", + "/permissive-", + "/DENABLE_BF16", + f"/D_GLIBCXX_USE_CXX11_ABI={ABI}", + ] + else: + CXX_FLAGS = [ + "-g", + "-O3", + "-fopenmp", + "-lgomp", + "-std=c++17", + "-DENABLE_BF16", + f"-D_GLIBCXX_USE_CXX11_ABI={ABI}", + ] + CXX_FLAGS += LIMITED_API_FLAGS + + HIP_FLAGS = [ + "-O3", + "-std=c++17", + "-ffast-math", + "-fgpu-flush-denormals-to-zero", + "-fno-offload-uniform-block", + "-D__HIP_PLATFORM_AMD__=1", + "-U__HIP_NO_HALF_OPERATORS__", + "-U__HIP_NO_HALF_CONVERSIONS__", + f"-D_GLIBCXX_USE_CXX11_ABI={ABI}", + "-mllvm", + "--lsr-drop-solution=1", + "-mllvm", + "-enable-post-misched=1", + "-mllvm", + "-amdgpu-early-inline-all=true", + "-mllvm", + "-amdgpu-function-calls=false", + ] + LIMITED_API_FLAGS + for arch in amd_arches: + HIP_FLAGS.append(f"--offload-arch={arch}") + + append_env_flags(CXX_FLAGS, "CXX_APPEND_FLAGS") + append_env_flags(HIP_FLAGS, "NVCC_APPEND_FLAGS") + append_env_flags(HIP_FLAGS, "HIPCC_APPEND_FLAGS") + + include_dirs = unique_paths([os.path.join(rocm_home, "include")]) + + if any(arch.startswith("gfx12") for arch in amd_arches): + ext_modules.append( + CUDAExtension( + name="sageattention._qattn_gfx12_native", + sources=[ + "csrc/qattn/pybind_gfx12_native.cpp", + 
"csrc/qattn/qk_int_sv_gfx12_native.cu", + ], + include_dirs=include_dirs, + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": HIP_FLAGS, + }, + py_limited_api=True, + ) + ) + else: + warnings.warn( + "ROCm build detected, but no gfx12 architecture was selected; " + "skipping the gfx12 native attention extension." ) - ) - if has_capability(("8.9", "10.0", "12.0", "12.1")): ext_modules.append( CUDAExtension( - name="sageattention._qattn_sm89", + name="sageattention._fused", sources=[ - "csrc/qattn/pybind_sm89.cpp", - "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu", - "csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu", - "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu", - "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu", - "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu", - "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu", - "csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu", + "csrc/fused/pybind.cpp", + "csrc/fused/fused.cu", ], + include_dirs=include_dirs, extra_compile_args={ "cxx": CXX_FLAGS, - "nvcc": get_nvcc_flags(["8.9", "10.0", "12.0", "12.1"]), + "nvcc": HIP_FLAGS, }, py_limited_api=True, ) ) + else: + # Compiler flags. 
+ if os.name == "nt": + # TODO: Detect MSVC rather than OS + CXX_FLAGS = ["/O2", "/openmp", "/std:c++17", "/permissive-", "-DENABLE_BF16"] + else: + CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"] + CXX_FLAGS += LIMITED_API_FLAGS + + NVCC_FLAGS_COMMON = [ + "-O3", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "--use_fast_math", + f"--threads={os.cpu_count()}", + # "-Xptxas=-v", + "-diag-suppress=174", + "-diag-suppress=177", + "-diag-suppress=221", + ] + LIMITED_API_FLAGS + if os.name == "nt": + # https://github.com/pytorch/pytorch/issues/148317 + NVCC_FLAGS_COMMON += [ + "-D_WIN32=1", + "-DUSE_CUDA=1", + ] + + append_env_flags(CXX_FLAGS, "CXX_APPEND_FLAGS") + append_env_flags(NVCC_FLAGS_COMMON, "NVCC_APPEND_FLAGS") + + CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] + NVCC_FLAGS_COMMON += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] + + if CUDA_HOME is None: + raise RuntimeError( + "Cannot find CUDA_HOME. CUDA must be available to build the package.") + + def get_nvcc_cuda_version(cuda_dir: str) -> Version: + """Get the CUDA version from nvcc. 
+ + Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py + """ + nvcc_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], + universal_newlines=True) + output = nvcc_output.split() + release_idx = output.index("release") + 1 + nvcc_cuda_version = parse(output[release_idx].split(",")[0]) + return nvcc_cuda_version + + # Determine target compute capabilities + compute_capabilities = set() + if os.getenv("TORCH_CUDA_ARCH_LIST"): + # Prefer TORCH_CUDA_ARCH_LIST if explicitly specified (works without GPUs) + # TORCH_CUDA_ARCH_LIST is separated by space or semicolon + for x in os.getenv("TORCH_CUDA_ARCH_LIST").replace(";", " ").split(): + compute_capabilities.add(x) + else: + # If not provided, try to detect from local GPUs + device_count = torch.cuda.device_count() if torch.cuda.is_available() else 0 + for i in range(device_count): + major, minor = torch.cuda.get_device_capability(i) + if major < 7: + warnings.warn(f"skipping GPU {i} with compute capability {major}.{minor}") + continue + compute_capabilities.add(f"{major}.{minor}") + + nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) + + if not compute_capabilities: + raise RuntimeError( + "No target compute capabilities. Set TORCH_CUDA_ARCH_LIST or build on a machine with GPUs.") + else: + print(f"Target compute capabilities: {compute_capabilities}") + + def has_capability(target): + return any(cc.startswith(target) for cc in compute_capabilities) + + # Validate the NVCC CUDA version. 
+ if nvcc_cuda_version < Version("12.4") and has_capability("8.9"): + raise RuntimeError( + "CUDA 12.4 or higher is required for compute capability 8.9.") + if nvcc_cuda_version < Version("12.3") and has_capability("9.0"): + raise RuntimeError( + "CUDA 12.3 or higher is required for compute capability 9.0.") + if nvcc_cuda_version < Version("12.8") and has_capability("12.0"): + raise RuntimeError( + "CUDA 12.8 or higher is required for compute capability 12.0.") + + # Add target compute capabilities to NVCC flags. + def get_nvcc_flags(allowed_capabilities): + NVCC_FLAGS = [] + for capability in compute_capabilities: + if capability not in allowed_capabilities: + continue + + # capability: "8.0+PTX" -> num: "80" + num = capability.split("+")[0].replace(".", "") + if num in {"90", "100", "120", "121"}: + # need to use sm90a instead of sm90 to use wgmma ptx instruction. + # need to use sm120a to use mxfp8/mxfp4/nvfp4 instructions. + num += "a" + + NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"] + if capability.endswith("+PTX"): + NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"] + NVCC_FLAGS += NVCC_FLAGS_COMMON + return NVCC_FLAGS + + if has_capability(("8.0", "8.6", "8.7")): + ext_modules.append( + CUDAExtension( + name="sageattention._qattn_sm80", + sources=[ + "csrc/qattn/pybind_sm80.cpp", + "csrc/qattn/qk_int_sv_f16_cuda_sm80.cu", + ], + extra_compile_args={ + "cxx": CXX_FLAGS, + # Build binary for sm80 if sm86/87 is detected. 
No need to build binary for sm86/87 + "nvcc": get_nvcc_flags(["8.0"]), + }, + py_limited_api=True, + ) + ) + + if has_capability(("8.9", "10.0", "12.0", "12.1")): + ext_modules.append( + CUDAExtension( + name="sageattention._qattn_sm89", + sources=[ + "csrc/qattn/pybind_sm89.cpp", + "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn_inst_buf.cu", + "csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_attn_inst_buf.cu", + "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_attn.cu", + "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn.cu", + "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn.cu", + "csrc/qattn/sm89_qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf.cu", + "csrc/qattn/sm89_qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf.cu", + ], + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": get_nvcc_flags(["8.9", "10.0", "12.0", "12.1"]), + }, + py_limited_api=True, + ) + ) + + if has_capability("9.0"): + ext_modules.append( + CUDAExtension( + name="sageattention._qattn_sm90", + sources=[ + "csrc/qattn/pybind_sm90.cpp", + "csrc/qattn/qk_int_sv_f8_cuda_sm90.cu", + ], + libraries=["cuda"], + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": get_nvcc_flags(["9.0"]), + }, + py_limited_api=True, + ) + ) - if has_capability("9.0"): ext_modules.append( CUDAExtension( - name="sageattention._qattn_sm90", + name="sageattention._fused", sources=[ - "csrc/qattn/pybind_sm90.cpp", - "csrc/qattn/qk_int_sv_f8_cuda_sm90.cu", + "csrc/fused/pybind.cpp", + "csrc/fused/fused.cu", ], - libraries=["cuda"], extra_compile_args={ "cxx": CXX_FLAGS, - "nvcc": get_nvcc_flags(["9.0"]), + "nvcc": get_nvcc_flags(["8.0", "8.9", "9.0", "10.0", "12.0", "12.1"]), }, py_limited_api=True, ) ) - ext_modules.append( - CUDAExtension( - name="sageattention._fused", - sources=[ - "csrc/fused/pybind.cpp", - "csrc/fused/fused.cu", - ], - extra_compile_args={ - "cxx": CXX_FLAGS, - "nvcc": get_nvcc_flags(["8.0", "8.9", "9.0", "10.0", "12.0", "12.1"]), - }, - py_limited_api=True, - ) - ) - # Resolve 
parallelism from env parallel = None if 'EXT_PARALLEL' in os.environ: From 155ef96df54e5037654431c20800f210fa26d642 Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Sat, 16 May 2026 01:33:06 +0900 Subject: [PATCH 2/7] Port latest gfx12 parity fixes to ABI3 fork --- csrc/qattn/attn_gfx12_native.h | 9 +- csrc/qattn/pybind_gfx12_native.cpp | 13 +- csrc/qattn/qk_int_sv_gfx12_native.cu | 120 ++++++++++++++---- sageattention/core.py | 169 ++++++++++++++++++++------ sageattention/gfx12_native_compile.py | 13 +- 5 files changed, 248 insertions(+), 76 deletions(-) diff --git a/csrc/qattn/attn_gfx12_native.h b/csrc/qattn/attn_gfx12_native.h index 2296c0c4..ce8cf2be 100644 --- a/csrc/qattn/attn_gfx12_native.h +++ b/csrc/qattn/attn_gfx12_native.h @@ -22,7 +22,8 @@ Tensor qk_int8_sv_f16_d64_native_attn_gfx12( int64_t tensor_layout, int64_t is_causal, double sm_scale, - int64_t valid_kv_len = 0); + int64_t valid_kv_len = 0, + int64_t value_transposed_hnd = -1); Tensor qk_rawq_int8_sv_f8_native_attn_gfx12( Tensor query, @@ -33,7 +34,8 @@ Tensor qk_rawq_int8_sv_f8_native_attn_gfx12( int64_t tensor_layout, int64_t is_causal, double sm_scale, - int64_t valid_kv_len = 0); + int64_t valid_kv_len = 0, + int64_t value_transposed_hnd = -1); Tensor qk_int8_sv_f8_scaled_native_attn_gfx12( Tensor query, @@ -58,7 +60,8 @@ Tensor qk_rawq_int8_sv_f8_scaled_native_attn_gfx12( int64_t tensor_layout, int64_t is_causal, double sm_scale, - int64_t valid_kv_len = 0); + int64_t valid_kv_len = 0, + int64_t value_transposed_hnd = -1); Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( Tensor query, diff --git a/csrc/qattn/pybind_gfx12_native.cpp b/csrc/qattn/pybind_gfx12_native.cpp index 18dfc20e..f0931137 100644 --- a/csrc/qattn/pybind_gfx12_native.cpp +++ b/csrc/qattn/pybind_gfx12_native.cpp @@ -32,7 +32,8 @@ STABLE_TORCH_LIBRARY(sageattention_qattn_gfx12_native, m) { "int tensor_layout, " "int is_causal, " "float sm_scale, " - "int valid_kv_len" + "int valid_kv_len=0, " + "int 
value_transposed_hnd=-1" ") -> Tensor"); m.def("qk_rawq_int8_sv_f8_native_attn(" "Tensor query, " @@ -43,7 +44,8 @@ STABLE_TORCH_LIBRARY(sageattention_qattn_gfx12_native, m) { "int tensor_layout, " "int is_causal, " "float sm_scale, " - "int valid_kv_len" + "int valid_kv_len=0, " + "int value_transposed_hnd=-1" ") -> Tensor"); m.def("qk_int8_sv_f8_scaled_native_attn(" "Tensor query, " @@ -56,7 +58,7 @@ STABLE_TORCH_LIBRARY(sageattention_qattn_gfx12_native, m) { "int tensor_layout, " "int is_causal, " "float sm_scale, " - "int valid_kv_len" + "int valid_kv_len=0" ") -> Tensor"); m.def("qk_rawq_int8_sv_f8_scaled_native_attn(" "Tensor query, " @@ -68,7 +70,8 @@ STABLE_TORCH_LIBRARY(sageattention_qattn_gfx12_native, m) { "int tensor_layout, " "int is_causal, " "float sm_scale, " - "int valid_kv_len" + "int valid_kv_len=0, " + "int value_transposed_hnd=-1" ") -> Tensor"); m.def("qk_int8_sv_f16_d64_prepare_attn_hnd(" "Tensor query, " @@ -78,7 +81,7 @@ STABLE_TORCH_LIBRARY(sageattention_qattn_gfx12_native, m) { "int value_is_fp8, " "int use_raw_f16_value, " "float sm_scale, " - "int valid_kv_len" + "int valid_kv_len=0" ") -> Tensor"); m.def("quant_q_nhd_per_warp(Tensor query) -> Tensor[]"); m.def("transpose_value_fp8_hnd(Tensor value) -> Tensor"); diff --git a/csrc/qattn/qk_int_sv_gfx12_native.cu b/csrc/qattn/qk_int_sv_gfx12_native.cu index 5011e264..ea1cc15d 100644 --- a/csrc/qattn/qk_int_sv_gfx12_native.cu +++ b/csrc/qattn/qk_int_sv_gfx12_native.cu @@ -1871,8 +1871,8 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 const int64_t ks_stride_h, const int tensor_layout, const float sm_scale) { - static_assert(HeadDim == 16 || HeadDim == 64, - "native gfx12 fp16 2q kernel supports D16/D64."); + static_assert(HeadDim == 16 || HeadDim == 64 || HeadDim == 128, + "native gfx12 fp16 2q kernel supports D16/D64/D128."); constexpr int BR = BlockRows; constexpr int RM = 16; constexpr int RowsPerWave = 32; @@ -5915,7 +5915,8 @@ static Tensor 
qk_int8_sv_f16_d64_native_attn_gfx12_impl( int is_causal, float sm_scale, int64_t valid_kv_len, - const float* value_scale_ptr = nullptr); + const float* value_scale_ptr = nullptr, + int value_transposed_hnd_hint = -1); Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( Tensor query, @@ -5938,8 +5939,8 @@ Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( "native gfx12 prepare+attention supports fp16/bf16 input"); const int64_t head_dim = query.size(3); - STD_TORCH_CHECK(value_is_fp8 || head_dim == 16 || head_dim == 64, - "native gfx12 fp16 value prepare+attention supports head_dim 16 or 64"); + STD_TORCH_CHECK(value_is_fp8 || head_dim == 16 || head_dim == 64 || head_dim == 128, + "native gfx12 fp16 value prepare+attention supports head_dim 16, 64, or 128"); STD_TORCH_CHECK(value_is_fp8 || !use_raw_f16_value || query.scalar_type() == ScalarType::Half, "raw fp16 value path requires fp16 input"); const int64_t batch = query.size(0); @@ -6486,7 +6487,7 @@ Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( std::vector prepared = quant_qk_int8_hnd_gfx12(query, key); qk_int8_sv_f16_d64_native_attn_gfx12_impl( prepared[0], prepared[2], value, output, prepared[1], prepared[3], - kHND, is_causal, sm_scale, kv_len); + kHND, is_causal, sm_scale, kv_len, nullptr, 0); } else { const bool use_f16_separate_prepared = is_causal && head_dim == 64 && q_len == 4096 && @@ -6497,7 +6498,7 @@ Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( prepare_qkv_hnd_packed_gfx12<__half, false>(query, key, value); qk_int8_sv_f16_d64_native_attn_gfx12_impl( prepared[0], prepared[2], prepared[4], output, prepared[1], prepared[3], - kHND, is_causal, sm_scale, kv_len); + kHND, is_causal, sm_scale, kv_len, nullptr, 1); } return output; } @@ -6513,7 +6514,8 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( int is_causal, float sm_scale, int64_t valid_kv_len, - const float* value_scale_ptr) { + const float* value_scale_ptr, + int value_transposed_hnd_hint) { STD_TORCH_CHECK(query.is_cuda() && 
key.is_cuda() && value.is_cuda() && output.is_cuda(), "native gfx12 tensors must be CUDA/HIP tensors"); STD_TORCH_CHECK(query.scalar_type() == ScalarType::Char, "query must be int8"); @@ -6531,12 +6533,16 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( STD_TORCH_CHECK(tensor_layout == kHND || tensor_layout == kNHD, "invalid tensor_layout"); const int64_t head_dim = query.size(-1); const bool value_maybe_transposed_hnd = - tensor_layout == kHND && value.dim() == 4 && value.size(2) == head_dim; + tensor_layout == kHND && value.dim() == 4 && value.size(2) == head_dim && + (value_is_fp8 || value_transposed_hnd_hint > 0 || + (value_transposed_hnd_hint < 0 && (value_is_fp8 || head_dim != 128))); STD_TORCH_CHECK(key.size(-1) == head_dim && (value.size(-1) == head_dim || value_maybe_transposed_hnd), "query, key, and value must have matching head_dim"); - STD_TORCH_CHECK(head_dim == 16 || head_dim == 64 || (value_is_fp8 && head_dim == 128), - "native gfx12 path supports D16/D64, plus D128 for the fp8 2q path"); + STD_TORCH_CHECK(head_dim == 16 || head_dim == 64 || head_dim == 128, + "native gfx12 path supports D16/D64/D128"); + STD_TORCH_CHECK(value_is_fp8 || head_dim != 128 || value_maybe_transposed_hnd, + "native gfx12 fp16 D128 path requires transposed HND values"); const int64_t batch = query.size(0); const int64_t q_heads = tensor_layout == kNHD ? 
query.size(2) : query.size(1); @@ -6732,8 +6738,8 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( } else { \ SAGEATTN_LAUNCH_FP8_2Q_TV_OUT(BC_, HD_, BR_, __half); \ } -#define SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL_GRID(BC_, BR_, PAD_, F16ACC_, PVORDER_, VLANE_, STREAM_, KLANE_, GRID_, FLAT_) \ - qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ +#define SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL_GRID_HD(HD_, BC_, BR_, PAD_, F16ACC_, PVORDER_, VLANE_, STREAM_, KLANE_, GRID_, FLAT_) \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ reinterpret_cast(value.data_ptr()), \ reinterpret_cast<__half*>(output.data_ptr()), \ @@ -6746,6 +6752,8 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( query_scale.stride(0), query_scale.stride(1), \ key_scale.stride(0), key_scale.stride(1), \ tensor_layout, sm_scale) +#define SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL_GRID(BC_, BR_, PAD_, F16ACC_, PVORDER_, VLANE_, STREAM_, KLANE_, GRID_, FLAT_) \ + SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL_GRID_HD(64, BC_, BR_, PAD_, F16ACC_, PVORDER_, VLANE_, STREAM_, KLANE_, GRID_, FLAT_) #define SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL(BC_, BR_, PAD_, F16ACC_, PVORDER_, VLANE_, STREAM_, KLANE_) \ if (use_f16_flat_q_schedule) { \ SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL_GRID(BC_, BR_, PAD_, F16ACC_, PVORDER_, VLANE_, STREAM_, KLANE_, grid_f16_flat, true); \ @@ -6766,6 +6774,24 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( query_scale.stride(0), query_scale.stride(1), \ key_scale.stride(0), key_scale.stride(1), \ tensor_layout, sm_scale) +#define SAGEATTN_LAUNCH_F16_D128_2Q_TV(BC_, BR_, PAD_) \ + if (is_causal) { \ + SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL_GRID_HD(128, BC_, BR_, PAD_, false, false, false, false, false, grid, false); \ + } else { \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast<__half*>(output.data_ptr()), \ + 
reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), \ + key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), \ + value.stride(0), value.stride(2), value.stride(1), \ + output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 2 : 1), \ + query_scale.stride(0), query_scale.stride(1), \ + key_scale.stride(0), key_scale.stride(1), \ + tensor_layout, sm_scale); \ + } #define SAGEATTN_LAUNCH_F16_2Q_TV(BC_, BR_, PAD_) \ if (is_causal) { \ if (use_f16_pv_accum) { \ @@ -7083,7 +7109,25 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( } } else if (use_2q && value_transposed_hnd) { STD_TORCH_CHECK(hnd_contiguous, "transposed fp16 value path requires contiguous HND Q/K/O"); - if (head_dim == 16) { + if (head_dim == 128) { + if (block_rows == 32) { + SAGEATTN_LAUNCH_F16_D128_2Q_TV(64, 32, 4); + } else if (block_rows == 64) { + SAGEATTN_LAUNCH_F16_D128_2Q_TV(64, 64, 4); + } else if (block_rows == 256) { + SAGEATTN_LAUNCH_F16_D128_2Q_TV(64, 256, 4); + } else if (block_rows == 512) { + SAGEATTN_LAUNCH_F16_D128_2Q_TV(64, 512, 4); + } else if (block_rows == 1024) { + SAGEATTN_LAUNCH_F16_D128_2Q_TV(64, 1024, 4); + } else if (q_len >= 8192) { + SAGEATTN_LAUNCH_F16_D128_2Q_TV(64, 128, 8); + } else if (q_len >= 1024) { + SAGEATTN_LAUNCH_F16_D128_2Q_TV(64, 128, 4); + } else { + SAGEATTN_LAUNCH_F16_D128_2Q_TV(64, 128, 16); + } + } else if (head_dim == 16) { if (is_causal) { if (block_rows == 32) { SAGEATTN_LAUNCH_F16_D16_2Q_TV(64, 32, 4, true, true); @@ -7212,7 +7256,8 @@ static Tensor qk_rawq_int8_sv_f8_native_attn_gfx12_impl( int tensor_layout, int is_causal, float sm_scale, - int64_t valid_kv_len) { + int64_t valid_kv_len, + int value_transposed_hnd_hint) { STD_TORCH_CHECK(query.is_cuda() && 
key.is_cuda() && value.is_cuda() && output.is_cuda(), "raw-Q gfx12 tensors must be CUDA/HIP tensors"); STD_TORCH_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4 && output.dim() == 4, @@ -7249,11 +7294,32 @@ static Tensor qk_rawq_int8_sv_f8_native_attn_gfx12_impl( STD_TORCH_CHECK(key.size(0) == batch && value.size(0) == batch && output.size(0) == batch, "raw-Q gfx12 batch size mismatch"); - const bool value_transposed_hnd = + STD_TORCH_CHECK(value_transposed_hnd_hint >= -1 && value_transposed_hnd_hint <= 1, + "value_transposed_hnd must be -1, 0, or 1"); + const bool value_shape_transposed_hnd = value.size(1) == kv_heads && value.size(2) == head_dim && value.size(3) >= padded_kv_len; + const bool value_shape_normal = + (tensor_layout == kNHD && + value.size(1) == padded_kv_len && value.size(2) == kv_heads && + value.size(3) == head_dim) || + (tensor_layout == kHND && + value.size(1) == kv_heads && value.size(2) == padded_kv_len && + value.size(3) == head_dim); + const bool value_layout_ambiguous = + value_shape_transposed_hnd && value_shape_normal; + STD_TORCH_CHECK(value_transposed_hnd_hint <= 0 || value_shape_transposed_hnd, + "value_transposed_hnd=1 requires value shape [B, H, D, padded_kv_len]"); + STD_TORCH_CHECK(value_transposed_hnd_hint != 0 || value_shape_normal, + "value_transposed_hnd=0 requires normal value layout"); + STD_TORCH_CHECK(value_transposed_hnd_hint >= 0 || !value_layout_ambiguous, + "raw-Q gfx12 value layout is ambiguous; pass value_transposed_hnd=0 " + "for normal layout or 1 for transposed HND [B, H, D, padded_kv_len]"); + const bool value_transposed_hnd = + value_transposed_hnd_hint > 0 || + (value_transposed_hnd_hint < 0 && value_shape_transposed_hnd); STD_TORCH_CHECK(key.size(-1) == head_dim && output.size(-1) == head_dim && - (value_transposed_hnd || value.size(-1) == head_dim), + (value_transposed_hnd || value_shape_normal), "raw-Q gfx12 Q/K/V/O head_dim mismatch"); STD_TORCH_CHECK((tensor_layout == kNHD && 
((value_transposed_hnd && output.size(1) >= q_len && @@ -7376,11 +7442,13 @@ Tensor qk_rawq_int8_sv_f8_native_attn_gfx12( int64_t tensor_layout, int64_t is_causal, double sm_scale, - int64_t valid_kv_len) { + int64_t valid_kv_len, + int64_t value_transposed_hnd) { return qk_rawq_int8_sv_f8_native_attn_gfx12_impl( query, key, value, output, key_scale, nullptr, static_cast(tensor_layout), static_cast(is_causal), - static_cast(sm_scale), valid_kv_len); + static_cast(sm_scale), valid_kv_len, + static_cast(value_transposed_hnd)); } Tensor qk_rawq_int8_sv_f8_scaled_native_attn_gfx12( @@ -7393,7 +7461,8 @@ Tensor qk_rawq_int8_sv_f8_scaled_native_attn_gfx12( int64_t tensor_layout, int64_t is_causal, double sm_scale, - int64_t valid_kv_len) { + int64_t valid_kv_len, + int64_t value_transposed_hnd) { const int64_t head_dim = query.size(-1); const int64_t batch = query.size(0); const int64_t kv_heads = tensor_layout == kNHD ? key.size(2) : key.size(1); @@ -7402,7 +7471,8 @@ Tensor qk_rawq_int8_sv_f8_scaled_native_attn_gfx12( return qk_rawq_int8_sv_f8_native_attn_gfx12_impl( query, key, value, output, key_scale, value_scale_ptr, static_cast(tensor_layout), static_cast(is_causal), - static_cast(sm_scale), valid_kv_len); + static_cast(sm_scale), valid_kv_len, + static_cast(value_transposed_hnd)); } Tensor qk_int8_sv_f8_scaled_native_attn_gfx12( @@ -7425,7 +7495,7 @@ Tensor qk_int8_sv_f8_scaled_native_attn_gfx12( return qk_int8_sv_f16_d64_native_attn_gfx12_impl( query, key, value, output, query_scale, key_scale, static_cast(tensor_layout), static_cast(is_causal), - static_cast(sm_scale), valid_kv_len, value_scale_ptr); + static_cast(sm_scale), valid_kv_len, value_scale_ptr, 1); } Tensor qk_int8_sv_f16_d64_native_attn_gfx12( @@ -7438,9 +7508,11 @@ Tensor qk_int8_sv_f16_d64_native_attn_gfx12( int64_t tensor_layout, int64_t is_causal, double sm_scale, - int64_t valid_kv_len) { + int64_t valid_kv_len, + int64_t value_transposed_hnd) { return qk_int8_sv_f16_d64_native_attn_gfx12_impl( 
query, key, value, output, query_scale, key_scale, static_cast(tensor_layout), static_cast(is_causal), - static_cast(sm_scale), valid_kv_len); + static_cast(sm_scale), valid_kv_len, nullptr, + static_cast(value_transposed_hnd)); } diff --git a/sageattention/core.py b/sageattention/core.py index 42ca0f9c..0b323139 100644 --- a/sageattention/core.py +++ b/sageattention/core.py @@ -158,11 +158,57 @@ def _pad_gfx12_nhd_sequence( return q_nhd, k_nhd, v_nhd -_GFX12_FP8_VALUE_SCALE_MAX = 2.25 +_GFX12_FP8_VALUE_SCALE_MAX_FP32_FP16 = 2.25 -def _gfx12_fp8_value_scale_hnd(v_hnd: torch.Tensor) -> torch.Tensor: - return v_hnd.abs().amax(dim=2).to(torch.float32).div(_GFX12_FP8_VALUE_SCALE_MAX).contiguous() +def _gfx12_fp8_value_scale_hnd(v_hnd: torch.Tensor, scale_max: float) -> torch.Tensor: + return v_hnd.abs().amax(dim=2).to(torch.float32).div(scale_max).contiguous() + + +def _attention_lse_reference( + q: torch.Tensor, + k: torch.Tensor, + tensor_layout: str, + is_causal: bool, + sm_scale: float, + block_q: int = 128, + max_score_elems: int = 8 * 1024 * 1024, +) -> torch.Tensor: + if tensor_layout == "NHD": + q_hnd = q.transpose(1, 2) + k_hnd = k.transpose(1, 2) + else: + q_hnd = q + k_hnd = k + + bsz, num_q_heads, q_len, _ = q_hnd.shape + _, num_kv_heads, kv_len, _ = k_hnd.shape + if num_q_heads % num_kv_heads != 0: + raise ValueError("num_qo_heads must be divisible by num_kv_heads.") + + heads_per_kv = num_q_heads // num_kv_heads + block_q = max(1, min(block_q, max_score_elems // max(1, bsz * heads_per_kv * kv_len))) + lse = torch.empty((bsz, num_q_heads, q_len), device=q.device, dtype=torch.float32) + q_float = q_hnd.to(torch.float32) + k_float = k_hnd.to(torch.float32) + + for hkv in range(num_kv_heads): + h_start = hkv * heads_per_kv + h_stop = h_start + heads_per_kv + k_head = k_float[:, hkv] + for q_start in range(0, q_len, block_q): + q_stop = min(q_start + block_q, q_len) + scores = torch.einsum( + "bhsd,btd->bhst", + q_float[:, h_start:h_stop, q_start:q_stop], + 
k_head, + ).mul_(sm_scale) + if is_causal: + q_idx = torch.arange(q_start, q_stop, device=q.device)[:, None] + k_idx = torch.arange(kv_len, device=q.device)[None, :] + scores.masked_fill_(k_idx > q_idx, float("-inf")) + lse[:, h_start:h_stop, q_start:q_stop] = torch.logsumexp(scores, dim=-1) + return lse def sageattn_qk_int8_pv_gfx12_native( @@ -171,9 +217,12 @@ def sageattn_qk_int8_pv_gfx12_native( v: torch.Tensor, tensor_layout: str = "HND", is_causal: bool = False, + qk_quant_gran: str = "per_warp", sm_scale: Optional[float] = None, + pv_accum_dtype: Optional[str] = None, value_dtype: str = "fp8", smooth_k: bool = True, + smooth_v: bool = False, return_lse: bool = False, **kwargs: Any, ) -> torch.Tensor: @@ -187,17 +236,51 @@ def sageattn_qk_int8_pv_gfx12_native( Current gfx12 constraints: - q, k, and v must be fp16 or bf16. - value_dtype="fp8" supports head_dim 16, 64, or 128. - - value_dtype="fp16" supports head_dim 16 or 64. + - value_dtype="fp16" supports head_dim 16, 64, or 128. - Causal masking requires q_len == kv_len. - smooth_k is enabled by default to match the CUDA and Triton paths. - - return_lse is not implemented yet. + - return_lse uses an exact PyTorch logsumexp side computation and does + not affect the default return_lse=False fast path. 
""" - gfx12_native = _get_gfx12_native_extension() - gfx12_prepare_attn_hnd = _qattn_gfx12_prepare_attn_hnd + if kwargs: + unsupported = ", ".join(sorted(key for key, value in kwargs.items() if value is not None)) + if unsupported: + raise TypeError(f"Unsupported non-None gfx12 SageAttention arguments: {unsupported}") + if qk_quant_gran not in {"per_warp", "per_thread"}: + raise ValueError("qk_quant_gran must be either 'per_warp' or 'per_thread'.") + if qk_quant_gran != "per_warp": + raise NotImplementedError("gfx12 native currently supports qk_quant_gran='per_warp'.") + if smooth_v: + raise NotImplementedError("gfx12 native does not support smooth_v yet.") value_dtype_normalized = value_dtype.lower() if value_dtype_normalized == "auto": value_dtype_normalized = "fp8" + if pv_accum_dtype is None: + pv_accum_dtype = "fp32+fp16" if value_dtype_normalized == "fp8" else "fp32" + if value_dtype_normalized == "fp8" and pv_accum_dtype not in {"fp32+fp16", "fp32", "fp32+fp32"}: + raise ValueError("gfx12 fp8 value path supports pv_accum_dtype 'fp32+fp16', 'fp32', or 'fp32+fp32'.") + if value_dtype_normalized == "fp8" and pv_accum_dtype != "fp32+fp16": + raise NotImplementedError("gfx12 fp8 value path currently supports pv_accum_dtype='fp32+fp16'.") + if value_dtype_normalized == "fp16" and pv_accum_dtype not in {"fp32", "fp16", "fp16+fp32"}: + raise ValueError("gfx12 fp16 value path supports pv_accum_dtype 'fp32', 'fp16', or 'fp16+fp32'.") + if value_dtype_normalized == "fp16" and pv_accum_dtype != "fp32": + raise NotImplementedError("gfx12 fp16 value path currently supports pv_accum_dtype='fp32'.") + fp8_value_scale_max = _GFX12_FP8_VALUE_SCALE_MAX_FP32_FP16 + gfx12_native = _get_gfx12_native_extension() + gfx12_prepare_attn_hnd = _qattn_gfx12_prepare_attn_hnd + + lse_q = q + lse_k = k + lse_sm_scale = float(sm_scale if sm_scale is not None else q.size(-1) ** -0.5) + + def _with_lse(out: torch.Tensor): + if not return_lse: + return out + return out, 
_attention_lse_reference( + lse_q, lse_k, tensor_layout, bool(is_causal), lse_sm_scale + ) + if ( tensor_layout == "HND" and value_dtype_normalized == "fp16" @@ -234,8 +317,6 @@ def sageattn_qk_int8_pv_gfx12_native( assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." assert q.dtype in [torch.float16, torch.bfloat16], "gfx12 native path supports fp16/bf16 inputs." assert tensor_layout in ["HND", "NHD"], "tensor_layout must be either 'HND' or 'NHD'." - if return_lse: - raise ValueError("gfx12 native path does not return LSE yet.") torch.cuda.set_device(v.device) input_dtype = q.dtype @@ -277,10 +358,12 @@ def sageattn_qk_int8_pv_gfx12_native( float(sm_scale), ) if input_dtype == torch.bfloat16: - return out if out.dtype == torch.bfloat16 else gfx12_native.convert_f16_to_bf16(out) - return out + out = out if out.dtype == torch.bfloat16 else gfx12_native.convert_f16_to_bf16(out) + return _with_lse(out) - if tensor_layout == "NHD" and smooth_k: + if tensor_layout == "NHD" and smooth_k and not ( + value_dtype == "fp16" and q.size(-1) > 64 + ): q_nhd = q.contiguous() k_nhd = k.contiguous() v_nhd = v.contiguous() @@ -299,15 +382,15 @@ def sageattn_qk_int8_pv_gfx12_native( k_nhd = F.pad(k_nhd, (0, pad)) v_nhd = F.pad(v_nhd, (0, pad)) head_dim = 64 - elif value_dtype == "fp8" and 64 < head_dim < 128: + elif 64 < head_dim < 128: pad = 128 - head_dim q_nhd = F.pad(q_nhd, (0, pad)) k_nhd = F.pad(k_nhd, (0, pad)) v_nhd = F.pad(v_nhd, (0, pad)) head_dim = 128 - if value_dtype == "fp16" and head_dim not in (16, 64): - raise ValueError("gfx12 fp16 value path currently supports head_dim 16 or 64.") + if value_dtype == "fp16" and head_dim not in (16, 64, 128): + raise ValueError("gfx12 fp16 value path currently supports head_dim 16, 64, or 128.") if value_dtype == "fp8" and head_dim not in (16, 64, 128): raise ValueError("gfx12 fp8 value path currently supports head_dim 16, 64, or 128.") @@ -338,7 +421,7 @@ def sageattn_qk_int8_pv_gfx12_native( value_scale = 
None if value_dtype == "fp8": v_hnd_for_value = v_nhd.transpose(1, 2).contiguous() - value_scale = _gfx12_fp8_value_scale_hnd(v_hnd_for_value) + value_scale = _gfx12_fp8_value_scale_hnd(v_hnd_for_value, fp8_value_scale_max) value_native = gfx12_native.transpose_value_fp8_scaled_hnd( v_hnd_for_value, value_scale ) @@ -361,6 +444,7 @@ def sageattn_qk_int8_pv_gfx12_native( int(is_causal), float(sm_scale), kv_len, + 1, ) else: q_int8, q_scale = gfx12_native.quant_q_nhd_per_warp(q_attn) @@ -375,13 +459,14 @@ def sageattn_qk_int8_pv_gfx12_native( int(is_causal), float(sm_scale), kv_len, + 0, ) out = out[:, :qo_len, :, :head_dim_og] if input_dtype == torch.bfloat16 and out.dtype != torch.bfloat16: out = gfx12_native.convert_f16_to_bf16(out.contiguous() if not out.is_contiguous() else out) elif input_dtype != torch.float16: out = out.to(input_dtype) - return out + return _with_lse(out) if tensor_layout == "NHD": q_hnd = q.transpose(1, 2).contiguous() @@ -406,15 +491,15 @@ def sageattn_qk_int8_pv_gfx12_native( k_hnd = F.pad(k_hnd, (0, pad)) v_hnd = F.pad(v_hnd, (0, pad)) head_dim = 64 - elif value_dtype == "fp8" and 64 < head_dim < 128: + elif 64 < head_dim < 128: pad = 128 - head_dim q_hnd = F.pad(q_hnd, (0, pad)) k_hnd = F.pad(k_hnd, (0, pad)) v_hnd = F.pad(v_hnd, (0, pad)) head_dim = 128 - if value_dtype == "fp16" and head_dim not in (16, 64): - raise ValueError("gfx12 fp16 value path currently supports head_dim 16 or 64.") + if value_dtype == "fp16" and head_dim not in (16, 64, 128): + raise ValueError("gfx12 fp16 value path currently supports head_dim 16, 64, or 128.") if value_dtype == "fp8" and head_dim not in (16, 64, 128): raise ValueError("gfx12 fp8 value path currently supports head_dim 16, 64, or 128.") @@ -436,7 +521,7 @@ def sageattn_qk_int8_pv_gfx12_native( q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( q_hnd, k_hnd, None, BLKQ=128, WARPQ=32, BLKK=64, tensor_layout="HND" ) - value_scale = _gfx12_fp8_value_scale_hnd(v_hnd) + value_scale = 
_gfx12_fp8_value_scale_hnd(v_hnd, fp8_value_scale_max) value_native = gfx12_native.transpose_value_fp8_scaled_hnd(v_hnd, value_scale) out = torch.empty_like(q_hnd, dtype=torch.float16) gfx12_native.qk_int8_sv_f8_scaled_native_attn( @@ -475,7 +560,7 @@ def sageattn_qk_int8_pv_gfx12_native( _quant_fused.quant_per_block_int8_fuse_sub_mean_cuda( k_hnd, k_mean.squeeze(2), k_int8, k_scale, 64, 1 ) - value_scale = _gfx12_fp8_value_scale_hnd(v_hnd) + value_scale = _gfx12_fp8_value_scale_hnd(v_hnd, fp8_value_scale_max) value_native = gfx12_native.transpose_value_fp8_scaled_hnd(v_hnd, value_scale) out = torch.empty_like( q_hnd, @@ -483,26 +568,32 @@ def sageattn_qk_int8_pv_gfx12_native( ) gfx12_native.qk_rawq_int8_sv_f8_scaled_native_attn( q_hnd, k_int8, value_native, out, k_scale, value_scale, + 1, int(is_causal), float(sm_scale), kv_len, 1 + ) + out = out[..., :qo_len, :head_dim_og] + if input_dtype != torch.float16 and out.dtype != input_dtype: + out = out.to(input_dtype) + if tensor_layout == "NHD": + out = out.transpose(1, 2).contiguous() + return _with_lse(out) + + q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( + q_hnd, k_hnd, k_mean, BLKQ=128, WARPQ=32, BLKK=64, tensor_layout="HND" + ) + out = torch.empty_like(q_hnd, dtype=torch.float16) + if value_dtype == "fp8": + value_scale = _gfx12_fp8_value_scale_hnd(v_hnd, fp8_value_scale_max) + value_native = gfx12_native.transpose_value_fp8_scaled_hnd(v_hnd, value_scale) + gfx12_native.qk_int8_sv_f8_scaled_native_attn( + q_int8, k_int8, value_native, out, q_scale, k_scale, value_scale, 1, int(is_causal), float(sm_scale), kv_len ) else: - q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( - q_hnd, k_hnd, k_mean, BLKQ=128, WARPQ=32, BLKK=64, tensor_layout="HND" + value_native = gfx12_native.transpose_value_f16_hnd(v_hnd) + gfx12_native.qk_int8_sv_f16_d64_native_attn( + q_int8, k_int8, value_native, out, q_scale, k_scale, + 1, int(is_causal), float(sm_scale), kv_len, 1 ) - out = torch.empty_like(q_hnd, 
dtype=torch.float16) - if value_dtype == "fp8": - value_scale = _gfx12_fp8_value_scale_hnd(v_hnd) - value_native = gfx12_native.transpose_value_fp8_scaled_hnd(v_hnd, value_scale) - gfx12_native.qk_int8_sv_f8_scaled_native_attn( - q_int8, k_int8, value_native, out, q_scale, k_scale, value_scale, - 1, int(is_causal), float(sm_scale), kv_len - ) - else: - value_native = gfx12_native.transpose_value_f16_hnd(v_hnd) - gfx12_native.qk_int8_sv_f16_d64_native_attn( - q_int8, k_int8, value_native, out, q_scale, k_scale, - 1, int(is_causal), float(sm_scale), kv_len - ) out = out[..., :qo_len, :head_dim_og] if input_dtype == torch.bfloat16 and out.dtype != torch.bfloat16: out = gfx12_native.convert_f16_to_bf16(out.contiguous() if not out.is_contiguous() else out) @@ -510,7 +601,7 @@ def sageattn_qk_int8_pv_gfx12_native( out = out.to(input_dtype) if tensor_layout == "NHD": out = out.transpose(1, 2).contiguous() - return out + return _with_lse(out) def sageattn( diff --git a/sageattention/gfx12_native_compile.py b/sageattention/gfx12_native_compile.py index 40d44f72..abe01c77 100644 --- a/sageattention/gfx12_native_compile.py +++ b/sageattention/gfx12_native_compile.py @@ -19,7 +19,8 @@ def qk_int8_sv_f16_d64_native_attn_fake_impl( tensor_layout: int, is_causal: int, sm_scale: float, - valid_kv_len: int, + valid_kv_len: int = 0, + value_transposed_hnd: int = -1, ) -> torch.Tensor: return _empty_lse(query) @@ -34,7 +35,8 @@ def qk_rawq_int8_sv_f8_native_attn_fake_impl( tensor_layout: int, is_causal: int, sm_scale: float, - valid_kv_len: int, + valid_kv_len: int = 0, + value_transposed_hnd: int = -1, ) -> torch.Tensor: return output @@ -51,7 +53,7 @@ def qk_int8_sv_f8_scaled_native_attn_fake_impl( tensor_layout: int, is_causal: int, sm_scale: float, - valid_kv_len: int, + valid_kv_len: int = 0, ) -> torch.Tensor: return _empty_lse(query) @@ -67,7 +69,8 @@ def qk_rawq_int8_sv_f8_scaled_native_attn_fake_impl( tensor_layout: int, is_causal: int, sm_scale: float, - valid_kv_len: int, 
+ valid_kv_len: int = 0, + value_transposed_hnd: int = -1, ) -> torch.Tensor: return output @@ -81,7 +84,7 @@ def qk_int8_sv_f16_d64_prepare_attn_hnd_fake_impl( value_is_fp8: int, use_raw_f16_value: int, sm_scale: float, - valid_kv_len: int, + valid_kv_len: int = 0, ) -> torch.Tensor: dtype = torch.bfloat16 if value_is_fp8 and query.dtype == torch.bfloat16 else torch.float16 return torch.empty(query.shape, dtype=dtype, device=query.device) From 83ff386ba9d4dbf46e1849e4103ed048e7be07c7 Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Sun, 17 May 2026 18:18:00 +0900 Subject: [PATCH 3/7] Port gfx12 cleanup to ABI3 branch --- .gitignore | 14 + csrc/qattn/attn_gfx12_native.h | 35 +- csrc/qattn/pybind_gfx12_native.cpp | 80 +- csrc/qattn/qk_int_sv_gfx12_native.cu | 1057 ++++++++++++++--- csrc/qattn/qk_int_sv_gfx12_native_attn_f16.cu | 12 + csrc/qattn/qk_int_sv_gfx12_native_attn_fp8.cu | 12 + csrc/qattn/qk_int_sv_gfx12_native_aux.cu | 12 + csrc/qattn/qk_int_sv_gfx12_native_prepare.cu | 12 + csrc/qattn/qk_int_sv_gfx12_native_rawq_fp8.cu | 12 + sageattention/core.py | 403 ++++--- sageattention/gfx12_native_compile.py | 52 + setup.py | 10 +- 12 files changed, 1362 insertions(+), 349 deletions(-) create mode 100644 csrc/qattn/qk_int_sv_gfx12_native_attn_f16.cu create mode 100644 csrc/qattn/qk_int_sv_gfx12_native_attn_fp8.cu create mode 100644 csrc/qattn/qk_int_sv_gfx12_native_aux.cu create mode 100644 csrc/qattn/qk_int_sv_gfx12_native_prepare.cu create mode 100644 csrc/qattn/qk_int_sv_gfx12_native_rawq_fp8.cu diff --git a/.gitignore b/.gitignore index 31bf217a..b037236e 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ __pycache__/ # C extensions *.so +*.pyd # Distribution / packaging .Python @@ -174,3 +175,16 @@ cython_debug/ .pypirc /wheelhouse +csrc/cp_async_hip.cuh +csrc/fused/fused.hip +csrc/fused/pybind_hip.cpp +csrc/numeric_conversion_hip.cuh +csrc/qattn/qk_int_sv_gfx12_native.hip +csrc/qattn/qk_int_sv_gfx12_native_*.hip +csrc/reduction_utils_hip.cuh 
+evaluate/ +*.pkl +*.vscode/ +bench_qk_int8_pv_fp8_cuda_head96.py +*test.py +*unmask.py diff --git a/csrc/qattn/attn_gfx12_native.h b/csrc/qattn/attn_gfx12_native.h index ce8cf2be..44045b1e 100644 --- a/csrc/qattn/attn_gfx12_native.h +++ b/csrc/qattn/attn_gfx12_native.h @@ -23,7 +23,8 @@ Tensor qk_int8_sv_f16_d64_native_attn_gfx12( int64_t is_causal, double sm_scale, int64_t valid_kv_len = 0, - int64_t value_transposed_hnd = -1); + int64_t value_transposed_hnd = -1, + int64_t pv_accum_mode = -1); Tensor qk_rawq_int8_sv_f8_native_attn_gfx12( Tensor query, @@ -35,7 +36,20 @@ Tensor qk_rawq_int8_sv_f8_native_attn_gfx12( int64_t is_causal, double sm_scale, int64_t valid_kv_len = 0, - int64_t value_transposed_hnd = -1); + int64_t value_transposed_hnd = -1, + int64_t key_hnd_layout = 0); + +Tensor qk_rawq_int8_sv_f16_native_attn_gfx12( + Tensor query, + Tensor key, + Tensor value, + Tensor output, + Tensor key_scale, + int64_t tensor_layout, + int64_t is_causal, + double sm_scale, + int64_t valid_kv_len = 0, + int64_t pv_accum_mode = -1); Tensor qk_int8_sv_f8_scaled_native_attn_gfx12( Tensor query, @@ -61,7 +75,8 @@ Tensor qk_rawq_int8_sv_f8_scaled_native_attn_gfx12( int64_t is_causal, double sm_scale, int64_t valid_kv_len = 0, - int64_t value_transposed_hnd = -1); + int64_t value_transposed_hnd = -1, + int64_t key_hnd_layout = 0); Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( Tensor query, @@ -71,7 +86,8 @@ Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( int64_t value_is_fp8, int64_t use_raw_f16_value, double sm_scale, - int64_t valid_kv_len = 0); + int64_t valid_kv_len = 0, + int64_t pv_accum_mode = -1); std::vector quant_q_nhd_per_warp_gfx12(Tensor query); @@ -79,6 +95,17 @@ Tensor transpose_value_fp8_hnd_gfx12(Tensor value); Tensor transpose_value_fp8_scaled_hnd_gfx12(Tensor value, Tensor value_scale); +std::vector fp8_value_nhd_short_gfx12( + Tensor value, + double scale_max); + +Tensor mean_nhd_gfx12(Tensor input); + +std::vector 
mean_and_fp8_value_nhd_short_gfx12( + Tensor key, + Tensor value, + double scale_max); + Tensor transpose_value_f16_hnd_gfx12(Tensor value); Tensor convert_f16_to_bf16_gfx12(Tensor input); diff --git a/csrc/qattn/pybind_gfx12_native.cpp b/csrc/qattn/pybind_gfx12_native.cpp index f0931137..5aa1970f 100644 --- a/csrc/qattn/pybind_gfx12_native.cpp +++ b/csrc/qattn/pybind_gfx12_native.cpp @@ -23,69 +23,45 @@ PyMODINIT_FUNC PyInit__qattn_gfx12_native(void) STABLE_TORCH_LIBRARY(sageattention_qattn_gfx12_native, m) { m.def("qk_int8_sv_f16_d64_native_attn(" - "Tensor query, " - "Tensor key, " - "Tensor value, " - "Tensor(a!) output, " - "Tensor query_scale, " - "Tensor key_scale, " - "int tensor_layout, " - "int is_causal, " - "float sm_scale, " - "int valid_kv_len=0, " - "int value_transposed_hnd=-1" + "Tensor query, Tensor key, Tensor value, Tensor(a!) output, " + "Tensor query_scale, Tensor key_scale, int tensor_layout, " + "int is_causal, float sm_scale, int valid_kv_len=0, " + "int value_transposed_hnd=-1, int pv_accum_mode=-1" ") -> Tensor"); m.def("qk_rawq_int8_sv_f8_native_attn(" - "Tensor query, " - "Tensor key, " - "Tensor value, " - "Tensor(a!) output, " - "Tensor key_scale, " - "int tensor_layout, " - "int is_causal, " - "float sm_scale, " - "int valid_kv_len=0, " - "int value_transposed_hnd=-1" + "Tensor query, Tensor key, Tensor value, Tensor(a!) output, " + "Tensor key_scale, int tensor_layout, int is_causal, " + "float sm_scale, int valid_kv_len=0, int value_transposed_hnd=-1, " + "int key_hnd_layout=0" + ") -> Tensor"); + m.def("qk_rawq_int8_sv_f16_native_attn(" + "Tensor query, Tensor key, Tensor value, Tensor(a!) output, " + "Tensor key_scale, int tensor_layout, int is_causal, " + "float sm_scale, int valid_kv_len=0, int pv_accum_mode=-1" ") -> Tensor"); m.def("qk_int8_sv_f8_scaled_native_attn(" - "Tensor query, " - "Tensor key, " - "Tensor value, " - "Tensor(a!) 
output, " - "Tensor query_scale, " - "Tensor key_scale, " - "Tensor value_scale, " - "int tensor_layout, " - "int is_causal, " - "float sm_scale, " + "Tensor query, Tensor key, Tensor value, Tensor(a!) output, " + "Tensor query_scale, Tensor key_scale, Tensor value_scale, " + "int tensor_layout, int is_causal, float sm_scale, " "int valid_kv_len=0" ") -> Tensor"); m.def("qk_rawq_int8_sv_f8_scaled_native_attn(" - "Tensor query, " - "Tensor key, " - "Tensor value, " - "Tensor(a!) output, " - "Tensor key_scale, " - "Tensor value_scale, " - "int tensor_layout, " - "int is_causal, " - "float sm_scale, " - "int valid_kv_len=0, " - "int value_transposed_hnd=-1" + "Tensor query, Tensor key, Tensor value, Tensor(a!) output, " + "Tensor key_scale, Tensor value_scale, int tensor_layout, " + "int is_causal, float sm_scale, int valid_kv_len=0, " + "int value_transposed_hnd=-1, int key_hnd_layout=0" ") -> Tensor"); m.def("qk_int8_sv_f16_d64_prepare_attn_hnd(" - "Tensor query, " - "Tensor key, " - "Tensor value, " - "int is_causal, " - "int value_is_fp8, " - "int use_raw_f16_value, " - "float sm_scale, " - "int valid_kv_len=0" + "Tensor query, Tensor key, Tensor value, int is_causal, " + "int value_is_fp8, int use_raw_f16_value, float sm_scale, " + "int valid_kv_len=0, int pv_accum_mode=-1" ") -> Tensor"); m.def("quant_q_nhd_per_warp(Tensor query) -> Tensor[]"); m.def("transpose_value_fp8_hnd(Tensor value) -> Tensor"); m.def("transpose_value_fp8_scaled_hnd(Tensor value, Tensor value_scale) -> Tensor"); + m.def("fp8_value_nhd_short(Tensor value, float scale_max) -> Tensor[]"); + m.def("mean_nhd(Tensor input) -> Tensor"); + m.def("mean_and_fp8_value_nhd_short(Tensor key, Tensor value, float scale_max) -> Tensor[]"); m.def("transpose_value_f16_hnd(Tensor value) -> Tensor"); m.def("convert_f16_to_bf16(Tensor input) -> Tensor"); } @@ -93,12 +69,16 @@ STABLE_TORCH_LIBRARY(sageattention_qattn_gfx12_native, m) { STABLE_TORCH_LIBRARY_IMPL(sageattention_qattn_gfx12_native, CUDA, m) { 
m.impl("qk_int8_sv_f16_d64_native_attn", TORCH_BOX(qk_int8_sv_f16_d64_native_attn_gfx12)); m.impl("qk_rawq_int8_sv_f8_native_attn", TORCH_BOX(qk_rawq_int8_sv_f8_native_attn_gfx12)); + m.impl("qk_rawq_int8_sv_f16_native_attn", TORCH_BOX(qk_rawq_int8_sv_f16_native_attn_gfx12)); m.impl("qk_int8_sv_f8_scaled_native_attn", TORCH_BOX(qk_int8_sv_f8_scaled_native_attn_gfx12)); m.impl("qk_rawq_int8_sv_f8_scaled_native_attn", TORCH_BOX(qk_rawq_int8_sv_f8_scaled_native_attn_gfx12)); m.impl("qk_int8_sv_f16_d64_prepare_attn_hnd", TORCH_BOX(qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12)); m.impl("quant_q_nhd_per_warp", TORCH_BOX(quant_q_nhd_per_warp_gfx12)); m.impl("transpose_value_fp8_hnd", TORCH_BOX(transpose_value_fp8_hnd_gfx12)); m.impl("transpose_value_fp8_scaled_hnd", TORCH_BOX(transpose_value_fp8_scaled_hnd_gfx12)); + m.impl("fp8_value_nhd_short", TORCH_BOX(fp8_value_nhd_short_gfx12)); + m.impl("mean_nhd", TORCH_BOX(mean_nhd_gfx12)); + m.impl("mean_and_fp8_value_nhd_short", TORCH_BOX(mean_and_fp8_value_nhd_short_gfx12)); m.impl("transpose_value_f16_hnd", TORCH_BOX(transpose_value_f16_hnd_gfx12)); m.impl("convert_f16_to_bf16", TORCH_BOX(convert_f16_to_bf16_gfx12)); } diff --git a/csrc/qattn/qk_int_sv_gfx12_native.cu b/csrc/qattn/qk_int_sv_gfx12_native.cu index ea1cc15d..7bb709a0 100644 --- a/csrc/qattn/qk_int_sv_gfx12_native.cu +++ b/csrc/qattn/qk_int_sv_gfx12_native.cu @@ -39,6 +39,18 @@ using torch::stable::Tensor; using ScalarType = torch::headeronly::ScalarType; +#if !defined(SAGEATTN_GFX12_BUILD_AUX) && \ + !defined(SAGEATTN_GFX12_BUILD_PREPARE) && \ + !defined(SAGEATTN_GFX12_BUILD_ATTN_F16) && \ + !defined(SAGEATTN_GFX12_BUILD_ATTN_FP8) && \ + !defined(SAGEATTN_GFX12_BUILD_RAWQ_FP8) +#define SAGEATTN_GFX12_BUILD_AUX 1 +#define SAGEATTN_GFX12_BUILD_PREPARE 1 +#define SAGEATTN_GFX12_BUILD_ATTN_F16 1 +#define SAGEATTN_GFX12_BUILD_ATTN_FP8 1 +#define SAGEATTN_GFX12_BUILD_RAWQ_FP8 1 +#endif + namespace { constexpr int kNHD = 0; @@ -136,6 +148,14 @@ __device__ __forceinline__ 
float value_to_float(const __hip_bfloat16 value) { return __bfloat162float(value); } +__device__ __forceinline__ __half value_from_float_half(const float value) { + return __float2half_rn(value); +} + +__device__ __forceinline__ __hip_bfloat16 value_from_float_bfloat16(const float value) { + return __float2bfloat16(value); +} + __device__ __forceinline__ int8_t float_to_int8_rn_gfx12(const float x) { int32_t rounded; asm volatile("v_cvt_i32_f32 %[dst], %[src]" @@ -495,6 +515,250 @@ __global__ void transpose_value_fp8_scaled_hnd_kernel( } } +template +__global__ void fp8_value_nhd_short_kernel( + const T* __restrict__ value, + uint8_t* __restrict__ output, + float* __restrict__ value_scale, + const int64_t seq_len, + const int64_t heads, + const int64_t head_dim, + const float scale_max) { + constexpr int TileS = 128; + constexpr int TileD = 16; + __shared__ float partial_amax[256]; + __shared__ float scale_tile[TileD]; + __shared__ uint8_t tile[TileS][TileD]; + + const int tid = threadIdx.x; + const int d_local = tid & (TileD - 1); + const int s_lane = tid >> 4; + const int64_t d_base = static_cast(blockIdx.x) * TileD; + const int64_t h = blockIdx.y; + const int64_t b = blockIdx.z; + const int64_t d = d_base + d_local; + + float local_amax = 0.0f; + if (d < head_dim) { + for (int64_t s = s_lane; s < seq_len; s += 16) { + const int64_t offset = ((b * seq_len + s) * heads + h) * head_dim + d; + local_amax = fmaxf(local_amax, fabsf(value_to_float(value[offset]))); + } + } + partial_amax[tid] = local_amax; + __syncthreads(); + + if (tid < TileD) { + float amax = 0.0f; + for (int i = 0; i < 16; ++i) { + amax = fmaxf(amax, partial_amax[i * TileD + tid]); + } + const float scale = amax / scale_max; + scale_tile[tid] = scale; + const int64_t scale_d = d_base + tid; + if (scale_d < head_dim) { + value_scale[(b * heads + h) * head_dim + scale_d] = scale; + } + } + __syncthreads(); + + for (int64_t s_base = 0; s_base < seq_len; s_base += TileS) { + for (int linear = tid; 
linear < TileS * TileD; linear += blockDim.x) { + const int load_s = linear / TileD; + const int load_d = linear - load_s * TileD; + const int64_t s = s_base + load_s; + const int64_t value_d = d_base + load_d; + uint8_t packed = 0; + if (s < seq_len && value_d < head_dim) { + const float scale = scale_tile[load_d]; + const int64_t offset = ((b * seq_len + s) * heads + h) * head_dim + value_d; + const float v = scale == 0.0f ? 0.0f : value_to_float(value[offset]) / scale; + packed = __hip_cvt_float_to_fp8(v, __HIP_SATFINITE, __HIP_E4M3); + } + tile[load_s][load_d] = packed; + } + __syncthreads(); + + for (int linear = tid; linear < TileS * TileD; linear += blockDim.x) { + const int store_d_local = linear / TileS; + const int store_s_local = linear - store_d_local * TileS; + const int64_t s = s_base + store_s_local; + const int64_t value_d = d_base + store_d_local; + if (s < seq_len && value_d < head_dim) { + output[((b * heads + h) * head_dim + value_d) * seq_len + s] = + tile[store_s_local][store_d_local]; + } + } + __syncthreads(); + } +} + +template +__global__ void mean_nhd_kernel( + const T* __restrict__ input, + T* __restrict__ mean, + const int64_t seq_len, + const int64_t heads, + const int64_t head_dim) { + constexpr int TileD = 16; + __shared__ float partial_sum[256]; + + const int tid = threadIdx.x; + const int d_local = tid & (TileD - 1); + const int s_lane = tid >> 4; + const int64_t d_base = static_cast(blockIdx.x) * TileD; + const int64_t h = blockIdx.y; + const int64_t b = blockIdx.z; + const int64_t d = d_base + d_local; + + float local_sum = 0.0f; + if (d < head_dim) { + for (int64_t s = s_lane; s < seq_len; s += 16) { + const int64_t offset = ((b * seq_len + s) * heads + h) * head_dim + d; + local_sum += value_to_float(input[offset]); + } + } + partial_sum[tid] = local_sum; + __syncthreads(); + + if (tid < TileD) { + float sum = 0.0f; + for (int i = 0; i < 16; ++i) { + sum += partial_sum[i * TileD + tid]; + } + const int64_t mean_d = d_base + 
tid; + if (mean_d < head_dim) { + const float value = sum / static_cast(seq_len); + if constexpr (std::is_same::value) { + mean[(b * heads + h) * head_dim + mean_d] = value_from_float_half(value); + } else { + mean[(b * heads + h) * head_dim + mean_d] = value_from_float_bfloat16(value); + } + } + } +} + +__device__ __forceinline__ int32_t pack_f32x4_to_ocp_fp8( + const float x0, + const float x1, + const float x2, + const float x3); + +template +__global__ void mean_and_fp8_value_nhd_short_kernel( + const T* __restrict__ key, + const T* __restrict__ value, + T* __restrict__ key_mean, + uint8_t* __restrict__ output, + float* __restrict__ value_scale, + const int64_t seq_len, + const int64_t heads, + const int64_t head_dim, + const float scale_max) { + constexpr int TileS = 128; + constexpr int TileD = 32; + __shared__ float partial_sum[TileD * SeqLanes]; + __shared__ float partial_amax[TileD * SeqLanes]; + __shared__ float scale_tile[TileD]; + __shared__ uint8_t tile[TileS][TileD]; + + const int tid = threadIdx.x; + const int d_local = tid & (TileD - 1); + const int s_lane = tid / TileD; + const int64_t d_base = static_cast(blockIdx.x) * TileD; + const int64_t h = blockIdx.y; + const int64_t b = blockIdx.z; + const int64_t d = d_base + d_local; + + float local_sum = 0.0f; + float local_amax = 0.0f; + if (d < head_dim) { + for (int64_t s = s_lane; s < seq_len; s += SeqLanes) { + const int64_t offset = ((b * seq_len + s) * heads + h) * head_dim + d; + local_sum += value_to_float(key[offset]); + local_amax = fmaxf(local_amax, fabsf(value_to_float(value[offset]))); + } + } + partial_sum[tid] = local_sum; + partial_amax[tid] = local_amax; + __syncthreads(); + + if (tid < TileD) { + float sum = 0.0f; + float amax = 0.0f; + for (int i = 0; i < SeqLanes; ++i) { + const int partial_idx = i * TileD + tid; + sum += partial_sum[partial_idx]; + amax = fmaxf(amax, partial_amax[partial_idx]); + } + const int64_t value_d = d_base + tid; + if (value_d < head_dim) { + const float 
mean = sum / static_cast(seq_len); + const int64_t mean_offset = (b * heads + h) * head_dim + value_d; + if constexpr (std::is_same::value) { + key_mean[mean_offset] = value_from_float_half(mean); + } else { + key_mean[mean_offset] = value_from_float_bfloat16(mean); + } + const float scale = amax / scale_max; + scale_tile[tid] = scale; + value_scale[mean_offset] = scale; + } + } + __syncthreads(); + + for (int64_t s_base = 0; s_base < seq_len; s_base += TileS) { + constexpr int PackElems = 4; + constexpr int PacksPerRow = TileD / PackElems; + for (int pack = tid; pack < TileS * PacksPerRow; pack += blockDim.x) { + const int load_s = pack / PacksPerRow; + const int load_d = (pack - load_s * PacksPerRow) * PackElems; + const int64_t s = s_base + load_s; + const int64_t value_d = d_base + load_d; + if (s < seq_len && value_d + PackElems - 1 < head_dim) { + const int64_t offset = ((b * seq_len + s) * heads + h) * head_dim + value_d; + const float scale0 = scale_tile[load_d + 0]; + const float scale1 = scale_tile[load_d + 1]; + const float scale2 = scale_tile[load_d + 2]; + const float scale3 = scale_tile[load_d + 3]; + const float v0 = scale0 == 0.0f ? 0.0f : value_to_float(value[offset + 0]) / scale0; + const float v1 = scale1 == 0.0f ? 0.0f : value_to_float(value[offset + 1]) / scale1; + const float v2 = scale2 == 0.0f ? 0.0f : value_to_float(value[offset + 2]) / scale2; + const float v3 = scale3 == 0.0f ? 0.0f : value_to_float(value[offset + 3]) / scale3; + const uint32_t packed = static_cast(pack_f32x4_to_ocp_fp8(v0, v1, v2, v3)); + *reinterpret_cast(&tile[load_s][load_d]) = packed; + } else { +#pragma unroll + for (int i = 0; i < PackElems; ++i) { + const int elem_d = load_d + i; + uint8_t packed = 0; + if (s < seq_len && d_base + elem_d < head_dim) { + const float scale = scale_tile[elem_d]; + const int64_t offset = + ((b * seq_len + s) * heads + h) * head_dim + d_base + elem_d; + const float v = scale == 0.0f ? 
0.0f : value_to_float(value[offset]) / scale; + packed = __hip_cvt_float_to_fp8(v, __HIP_SATFINITE, __HIP_E4M3); + } + tile[load_s][elem_d] = packed; + } + } + } + __syncthreads(); + + for (int linear = tid; linear < TileS * TileD; linear += blockDim.x) { + const int store_d_local = linear / TileS; + const int store_s_local = linear - store_d_local * TileS; + const int64_t s = s_base + store_s_local; + const int64_t value_d = d_base + store_d_local; + if (s < seq_len && value_d < head_dim) { + output[((b * heads + h) * head_dim + value_d) * seq_len + s] = + tile[store_s_local][store_d_local]; + } + } + __syncthreads(); + } +} + __device__ __forceinline__ int64_t qkv_offset( const int tensor_layout, const int64_t b, @@ -509,7 +773,7 @@ __device__ __forceinline__ int64_t qkv_offset( : b * stride_b + h * stride_h + n * stride_n + d; } -template +template __device__ __forceinline__ int64_t qkv_offset_dispatch( const int tensor_layout, const int64_t b, @@ -521,6 +785,8 @@ __device__ __forceinline__ int64_t qkv_offset_dispatch( const int64_t stride_h) { if constexpr (HndContiguous) { return b * stride_b + h * stride_h + n * HeadDim + d; + } else if constexpr (StaticNhd) { + return b * stride_b + n * stride_n + h * stride_h + d; } else { return qkv_offset(tensor_layout, b, h, n, d, stride_b, stride_n, stride_h); } @@ -668,7 +934,11 @@ __device__ __forceinline__ void apply_tqk_causal_mask_pv_order( } } -template +template __device__ __forceinline__ i32x2_vec pack_quant_q_i8_wmma_b_regs( const QueryT* __restrict__ q, const int tensor_layout, @@ -686,7 +956,7 @@ __device__ __forceinline__ i32x2_vec pack_quant_q_i8_wmma_b_regs( const int row = lane & 15; const int k_base = 8 * (lane >> 4); const int64_t q_idx = q_start + row; - if (q_idx >= qo_len) { + if constexpr (!NoQueryTail) if (q_idx >= qo_len) { #pragma unroll for (int gpr = 0; gpr < 2; ++gpr) { regs[gpr] = 0; @@ -695,7 +965,7 @@ __device__ __forceinline__ i32x2_vec pack_quant_q_i8_wmma_b_regs( } const int d = d_base + 
k_base; - const int64_t q_off = qkv_offset_dispatch( + const int64_t q_off = qkv_offset_dispatch( tensor_layout, b, h, q_idx, d, q_stride_b, q_stride_n, q_stride_h); const uint4 raw = *reinterpret_cast(q + q_off); const QueryT* values = reinterpret_cast(&raw); @@ -713,7 +983,10 @@ __device__ __forceinline__ i32x2_vec pack_quant_q_i8_wmma_b_regs( return regs; } -template +template __device__ __forceinline__ i32x2_vec pack_q_i8_wmma_b_regs( const int8_t* __restrict__ q, const int tensor_layout, @@ -730,7 +1003,7 @@ __device__ __forceinline__ i32x2_vec pack_q_i8_wmma_b_regs( const int row = lane & 15; const int k_base = 8 * (lane >> 4); const int64_t q_idx = q_start + row; - if (q_idx >= qo_len) { + if constexpr (!NoQueryTail) if (q_idx >= qo_len) { #pragma unroll for (int gpr = 0; gpr < 2; ++gpr) { regs[gpr] = 0; @@ -739,7 +1012,7 @@ __device__ __forceinline__ i32x2_vec pack_q_i8_wmma_b_regs( } const int d = d_base + k_base; - const int64_t q_off = qkv_offset_dispatch( + const int64_t q_off = qkv_offset_dispatch( tensor_layout, b, h, q_idx, d, q_stride_b, q_stride_n, q_stride_h); const uint2 raw = *reinterpret_cast(q + q_off); regs[0] = static_cast(raw.x); @@ -1908,8 +2181,6 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 "transposed fp16 value layout requires contiguous HND tensors."); static_assert(!F16PvAccum || BlockCols <= 64, "fp16 PV accumulation currently supports the BC64 2q path."); - static_assert(!QuantizeQuery || HndContiguous, - "direct fp16 Q quantization currently requires contiguous HND tensors."); static_assert(!QuantizeKey || (HndContiguous && BlockCols == 64), "direct fp16 K quantization currently requires contiguous HND BC64 tensors."); static_assert(!LaneMajorValue || @@ -2912,8 +3183,12 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 o_off1 = qkv_offset( tensor_layout, b, hq, q_idx1, d, o_stride_b, o_stride_n, o_stride_h); } - store_half(output, o_off0, value0); - 
store_half(output, o_off1, value1); + if (q_idx0 < qo_len) { + store_half(output, o_off0, value0); + } + if (q_idx1 < qo_len) { + store_half(output, o_off1, value1); + } } } } @@ -2940,7 +3215,13 @@ template + bool LowPressureQGroups = false, + bool PerThreadQK = false, + bool KeyHndContiguous = HndContiguous, + bool StaticNhdLayout = false, + bool NoKvTail = false, + bool SameQKHeads = false, + bool NoQueryTail = false> SAGEATTN_NATIVE_2Q_WAVES_PER_EU(HeadDim, IsCausal) __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_f8_native_2q_kernel( const QueryT* __restrict__ q, @@ -3040,8 +3321,8 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ (UsePrepackedLaneMajorK || UsePrepackedLaneMajorValue))); static_assert(!QuantizeKeyValue || ((HeadDim == 64 || HeadDim == 128) && - BlockCols == 64 && HndContiguous && !ValueTransposed), - "raw K/V fp8 staging currently supports contiguous HND D64/D128 BC64 tensors."); + BlockCols == 64 && !ValueTransposed), + "raw K/V fp8 staging currently supports D64/D128 BC64 tensors."); static_assert(!UsePrepackedLaneMajorKV || (HeadDim == 64 && HndContiguous && ValueTransposed && !QuantizeKeyValue), @@ -3057,6 +3338,8 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ static_assert(ValueTiles == 1 || ValueTiles == 4 || ValueTiles == 8, "native fp8 2q stores one D16, D64, or D128 value slice per launch."); static_assert(ValueTileBase + ValueTiles <= DTiles, "invalid fp8 value tile slice."); + static_assert(!NoQueryTail || (StaticNhdLayout && !IsCausal), + "full-query fp8 path requires a static non-causal dispatch."); __shared__ int8_t k_tile[UsePrepackedLaneMajorK ? 1 : BC][SharedHeadStride]; __shared__ uint8_t v_tile[UsePrepackedLaneMajorValue ? 
1 : SharedValueRows] @@ -3078,7 +3361,7 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ return; } - const int64_t hkv = hq / (num_qo_heads / num_kv_heads); + const int64_t hkv = SameQKHeads ? hq : hq / (num_qo_heads / num_kv_heads); const int64_t k_head_base = b * k_stride_b + hkv * k_stride_h; const int64_t v_head_base = b * v_stride_b + hkv * v_stride_h; int64_t q_start[QGroups]; @@ -3100,8 +3383,17 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ const int row = elem_base / HeadDim; const int d = elem_base - row * HeadDim; const int64_t q_idx = q_base + local_q_row_base + row; - if (q_idx < qo_len) { - const int64_t q_off = qkv_offset_dispatch( + if constexpr (NoQueryTail) { + const int64_t q_off = qkv_offset_dispatch( + tensor_layout, b, hq, q_idx, d, q_stride_b, q_stride_n, q_stride_h); + const uint4 raw = *reinterpret_cast(q + q_off); + const QueryT* values = reinterpret_cast(&raw); +#pragma unroll + for (int i = 0; i < QPackElems; ++i) { + local_q_amax = fmaxf(local_q_amax, fabsf(value_to_float(values[i]))); + } + } else if (q_idx < qo_len) { + const int64_t q_off = qkv_offset_dispatch( tensor_layout, b, hq, q_idx, d, q_stride_b, q_stride_n, q_stride_h); const uint4 raw = *reinterpret_cast(q + q_off); const QueryT* values = reinterpret_cast(&raw); @@ -3132,7 +3424,8 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ const int64_t qg_start = q_start[qg]; #pragma unroll for (int dt = 0; dt < DTiles; ++dt) { - q_regs[qg][dt] = pack_quant_q_i8_wmma_b_regs( + q_regs[qg][dt] = + pack_quant_q_i8_wmma_b_regs( q, tensor_layout, lane, b, hq, qg_start, qo_len, dt * BK, q_stride_b, q_stride_n, q_stride_h, inv_q_scale); } @@ -3149,7 +3442,8 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ for (int qg = 0; qg < QGroups; ++qg) { #pragma unroll for (int dt = 0; dt < DTiles; ++dt) { - q_regs[qg][dt] = pack_q_i8_wmma_b_regs( + 
q_regs[qg][dt] = + pack_q_i8_wmma_b_regs( q, tensor_layout, lane, b, hq, q_start[qg], qo_len, dt * BK, q_stride_b, q_stride_n, q_stride_h); } @@ -3167,7 +3461,7 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ for (int qg = 0; qg < QGroups; ++qg) { #pragma unroll for (int dt = 0; dt < DTiles; ++dt) { - const int64_t q_off = qkv_offset_dispatch( + const int64_t q_off = qkv_offset_dispatch( tensor_layout, b, hq, q_start[qg], dt * BK, q_stride_b, q_stride_n, q_stride_h); rocwmma::load_matrix_sync(q_frag[qg][dt], q + q_off, static_cast(q_stride_n)); } @@ -3212,7 +3506,7 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ const int elem_base = pack * PackElems; const int n = elem_base / HeadDim; const int d = elem_base - n * HeadDim; - const int64_t k_off = qkv_offset_dispatch( + const int64_t k_off = qkv_offset_dispatch( tensor_layout, b, hkv, kb_base + n, d, k_stride_b, k_stride_n, k_stride_h); const uint4 raw = *reinterpret_cast(k + k_off); const KeyT* values = reinterpret_cast(&raw); @@ -3234,9 +3528,9 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ const int elem_base = pack * PackElems; const int n = elem_base / HeadDim; const int d = elem_base - n * HeadDim; - const int64_t k_off = qkv_offset_dispatch( + const int64_t k_off = qkv_offset_dispatch( tensor_layout, b, hkv, kb_base + n, d, k_stride_b, k_stride_n, k_stride_h); - const int64_t v_off = qkv_offset_dispatch( + const int64_t v_off = qkv_offset_dispatch( tensor_layout, b, hkv, kb_base + n, d, v_stride_b, v_stride_n, v_stride_h); const uint4 raw_k = *reinterpret_cast(k + k_off); const uint4 raw_v = *reinterpret_cast(v + v_off); @@ -3268,12 +3562,12 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ for (int vec = tid; vec < BC * VecsPerRow; vec += Threads) { const int n = vec / VecsPerRow; const int d = (vec - n * VecsPerRow) * VecBytes; - const int64_t k_off = 
qkv_offset_dispatch( + const int64_t k_off = qkv_offset_dispatch( tensor_layout, b, hkv, kb_base + n, d, k_stride_b, k_stride_n, k_stride_h); *reinterpret_cast(&k_tile[n][d]) = *reinterpret_cast(k + k_off); if constexpr (!ValueTransposed) { - const int64_t v_off = qkv_offset_dispatch( + const int64_t v_off = qkv_offset_dispatch( tensor_layout, b, hkv, kb_base + n, d, v_stride_b, v_stride_n, v_stride_h); *reinterpret_cast(&v_tile[n][d]) = *reinterpret_cast(v + v_off); @@ -3309,7 +3603,7 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ for (int qg = 0; qg < QGroups; ++qg) { #pragma unroll for (int dt = 0; dt < DTiles; ++dt) { - const int64_t q_off = qkv_offset_dispatch( + const int64_t q_off = qkv_offset_dispatch( tensor_layout, b, hq, q_start[qg], dt * BK, q_stride_b, q_stride_n, q_stride_h); rocwmma::load_matrix_sync(q_frag[qg][dt], q + q_off, static_cast(q_stride_n)); } @@ -3384,7 +3678,9 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ } float k_scale_local = k_scale_tile; if constexpr (!QuantizeKeyValue && BC <= 64) { - k_scale_local = prepared_k_scale_tile; + if constexpr (!PerThreadQK) { + k_scale_local = prepared_k_scale_tile; + } } else if constexpr (!QuantizeKeyValue) { const int k_scale_idx = k_scale_col_per_warp(k_col_start); k_scale_local = k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; @@ -3439,7 +3735,7 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ } } } - if (k_col_start + BK > kv_len) { + if constexpr (!NoKvTail) if (k_col_start + BK > kv_len) { apply_tqk_kv_tail_mask(scores, kv_len, kb_base, col_tile, lane); } score_cache_stream[qg][sc] = scores; @@ -3493,7 +3789,7 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ } } } - if (k_col_start + BK > kv_len) { + if constexpr (!NoKvTail) if (k_col_start + BK > kv_len) { apply_tqk_kv_tail_mask(scores, kv_len, kb_base, col_tile, lane); } 
score_cache_stream[qg][sc] = scores; @@ -3729,7 +4025,7 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ } } } - if (k_col_start + BK > kv_len) { + if constexpr (!NoKvTail) if (k_col_start + BK > kv_len) { apply_tqk_kv_tail_mask(scores, kv_len, kb_base, col_tile, lane); } score_cache[qg][col_tile] = scores; @@ -3836,7 +4132,7 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ } } } - if (k_col_start + BK > kv_len) { + if constexpr (!NoKvTail) if (k_col_start + BK > kv_len) { apply_tqk_kv_tail_mask(scores, kv_len, kb_base, col_tile, lane); } score_cache[col_tile] = scores; @@ -3992,7 +4288,7 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ } } } - if (k_col_start + BK > kv_len) { + if constexpr (!NoKvTail) if (k_col_start + BK > kv_len) { apply_tqk_kv_tail_mask(scores, kv_len, kb_base, col_tile, lane); } score_cache[col_tile] = scores; @@ -4099,6 +4395,14 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ } } + float value_scale_tile[ValueTiles]; +#pragma unroll + for (int vdt = 0; vdt < ValueTiles; ++vdt) { + const int d = (ValueTileBase + vdt) * BK + col; + value_scale_tile[vdt] = v_scale == nullptr ? + 1.0f : v_scale[(b * num_kv_heads + hkv) * HeadDim + d]; + } + #pragma unroll for (int qg = 0; qg < QGroups; ++qg) { float l_rows[8]; @@ -4109,8 +4413,7 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ #pragma unroll for (int vdt = 0; vdt < ValueTiles; ++vdt) { const int d = (ValueTileBase + vdt) * BK + col; - const float value_scale = v_scale == nullptr ? 
- 1.0f : v_scale[(b * num_kv_heads + hkv) * HeadDim + d]; + const float value_scale = value_scale_tile[vdt]; #pragma unroll for (int pair = 0; pair < PackedRows; ++pair) { const int elem = pair * 2; @@ -4122,9 +4425,9 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ 0.0f : (out_frag[qg][vdt][elem] / l_sum0) * value_scale; const float value1 = l_sum1 == 0.0f ? 0.0f : (out_frag[qg][vdt][elem + 1] / l_sum1) * value_scale; - store_output_value(output, qkv_offset_dispatch( + store_output_value(output, qkv_offset_dispatch( tensor_layout, b, hq, q_idx0, d, o_stride_b, o_stride_n, o_stride_h), value0); - store_output_value(output, qkv_offset_dispatch( + store_output_value(output, qkv_offset_dispatch( tensor_layout, b, hq, q_idx1, d, o_stride_b, o_stride_n, o_stride_h), value1); } } @@ -4955,6 +5258,27 @@ __global__ void prepare_kv_hnd_fp8_kernel( } // namespace +static int select_fp8_d64_block_rows_gfx12( + const int64_t q_len, + const bool is_causal, + const bool value_transposed_hnd) { + if (is_causal) { + if (q_len <= 64) { + return 64; + } + return 128; + } + if (q_len <= 64) { + return 64; + } + if ((q_len % 256) == 0 && (q_len >= 2048 || value_transposed_hnd)) { + return 256; + } + return 128; +} + +#if SAGEATTN_GFX12_BUILD_AUX + Tensor transpose_value_fp8_hnd_gfx12(Tensor value) { return transpose_value_hnd_gfx12(value); } @@ -5000,10 +5324,161 @@ Tensor transpose_value_fp8_scaled_hnd_gfx12(Tensor value, Tensor value_scale) { return output; } +std::vector fp8_value_nhd_short_gfx12( + Tensor value, + double scale_max) { + STD_TORCH_CHECK(value.is_cuda(), "gfx12 short NHD value prep expects a CUDA/HIP tensor"); + STD_TORCH_CHECK(value.dim() == 4, "gfx12 short NHD value prep expects [B, S, H, D]"); + STD_TORCH_CHECK(value.is_contiguous(), "gfx12 short NHD value prep expects contiguous NHD input"); + STD_TORCH_CHECK(value.scalar_type() == ScalarType::Half || value.scalar_type() == ScalarType::BFloat16, + "gfx12 short NHD value prep 
supports fp16/bf16 input"); + + const int64_t batch = value.size(0); + const int64_t seq_len = value.size(1); + const int64_t heads = value.size(2); + const int64_t head_dim = value.size(3); + STD_TORCH_CHECK(head_dim == 64 || head_dim == 128, + "gfx12 short NHD fp8 value prep currently supports head_dim 64 or 128"); + STD_TORCH_CHECK(seq_len == 512 || seq_len == 1024, + "gfx12 short NHD fp8 value prep currently supports sequence length 512 or 1024"); + + Tensor output = new_empty_like(value, {batch, heads, head_dim, seq_len}, ScalarType::Byte); + Tensor value_scale = new_empty_like(value, {batch, heads, head_dim}, ScalarType::Float); + + dim3 block(256); + dim3 grid((head_dim + 15) / 16, heads, batch); + if (value.scalar_type() == ScalarType::Half) { + fp8_value_nhd_short_kernel<__half><<>>( + reinterpret_cast(value.data_ptr()), + reinterpret_cast(output.data_ptr()), + reinterpret_cast(value_scale.data_ptr()), + seq_len, heads, head_dim, static_cast(scale_max)); + } else { + fp8_value_nhd_short_kernel<__hip_bfloat16><<>>( + reinterpret_cast(value.data_ptr()), + reinterpret_cast(output.data_ptr()), + reinterpret_cast(value_scale.data_ptr()), + seq_len, heads, head_dim, static_cast(scale_max)); + } + hip_kernel_launch_check(); + return {output, value_scale}; +} + +Tensor mean_nhd_gfx12(Tensor input) { + STD_TORCH_CHECK(input.is_cuda(), "gfx12 NHD mean expects a CUDA/HIP tensor"); + STD_TORCH_CHECK(input.dim() == 4, "gfx12 NHD mean expects [B, S, H, D]"); + STD_TORCH_CHECK(input.is_contiguous(), "gfx12 NHD mean expects contiguous NHD input"); + STD_TORCH_CHECK(input.scalar_type() == ScalarType::Half || input.scalar_type() == ScalarType::BFloat16, + "gfx12 NHD mean supports fp16/bf16 input"); + + const int64_t batch = input.size(0); + const int64_t seq_len = input.size(1); + const int64_t heads = input.size(2); + const int64_t head_dim = input.size(3); + Tensor mean = new_empty_like(input, {batch, heads, head_dim}, input.scalar_type()); + + dim3 block(256); + dim3 
grid((head_dim + 15) / 16, heads, batch); + if (input.scalar_type() == ScalarType::Half) { + mean_nhd_kernel<__half><<>>( + reinterpret_cast(input.data_ptr()), + reinterpret_cast<__half*>(mean.data_ptr()), + seq_len, heads, head_dim); + } else { + mean_nhd_kernel<__hip_bfloat16><<>>( + reinterpret_cast(input.data_ptr()), + reinterpret_cast<__hip_bfloat16*>(mean.data_ptr()), + seq_len, heads, head_dim); + } + hip_kernel_launch_check(); + return mean; +} + +std::vector mean_and_fp8_value_nhd_short_gfx12( + Tensor key, + Tensor value, + double scale_max) { + STD_TORCH_CHECK(key.is_cuda() && value.is_cuda(), + "gfx12 short NHD mean/value prep expects CUDA/HIP tensors"); + STD_TORCH_CHECK(key.dim() == 4 && value.dim() == 4, + "gfx12 short NHD mean/value prep expects [B, S, H, D]"); + STD_TORCH_CHECK(key.is_contiguous() && value.is_contiguous(), + "gfx12 short NHD mean/value prep expects contiguous NHD tensors"); + STD_TORCH_CHECK(key.scalar_type() == value.scalar_type(), + "gfx12 short NHD mean/value prep expects matching key/value dtypes"); + STD_TORCH_CHECK(key.scalar_type() == ScalarType::Half || key.scalar_type() == ScalarType::BFloat16, + "gfx12 short NHD mean/value prep supports fp16/bf16 input"); + STD_TORCH_CHECK(key.size(0) == value.size(0) && + key.size(1) == value.size(1) && + key.size(2) == value.size(2) && + key.size(3) == value.size(3), + "gfx12 short NHD mean/value prep expects matching key/value shapes"); + + const int64_t batch = value.size(0); + const int64_t seq_len = value.size(1); + const int64_t heads = value.size(2); + const int64_t head_dim = value.size(3); + STD_TORCH_CHECK(head_dim == 64 || head_dim == 128, + "gfx12 short NHD mean/value prep currently supports head_dim 64 or 128"); + STD_TORCH_CHECK(seq_len == 512 || seq_len == 1024, + "gfx12 short NHD mean/value prep currently supports sequence length 512 or 1024"); + + Tensor key_mean = new_empty_like(key, {batch, heads, head_dim}, key.scalar_type()); + Tensor output = new_empty_like(value, 
{batch, heads, head_dim, seq_len}, ScalarType::Byte); + Tensor value_scale = new_empty_like(value, {batch, heads, head_dim}, ScalarType::Float); + + const int seq_lanes = head_dim == 64 ? 32 : 16; + dim3 block(32 * seq_lanes); + dim3 grid((head_dim + 31) / 32, heads, batch); + if (value.scalar_type() == ScalarType::Half) { + if (head_dim == 64) { + mean_and_fp8_value_nhd_short_kernel<__half, 32><<>>( + reinterpret_cast(key.data_ptr()), + reinterpret_cast(value.data_ptr()), + reinterpret_cast<__half*>(key_mean.data_ptr()), + reinterpret_cast(output.data_ptr()), + reinterpret_cast(value_scale.data_ptr()), + seq_len, heads, head_dim, static_cast(scale_max)); + } else { + mean_and_fp8_value_nhd_short_kernel<__half, 16><<>>( + reinterpret_cast(key.data_ptr()), + reinterpret_cast(value.data_ptr()), + reinterpret_cast<__half*>(key_mean.data_ptr()), + reinterpret_cast(output.data_ptr()), + reinterpret_cast(value_scale.data_ptr()), + seq_len, heads, head_dim, static_cast(scale_max)); + } + } else { + if (head_dim == 64) { + mean_and_fp8_value_nhd_short_kernel<__hip_bfloat16, 32><<>>( + reinterpret_cast(key.data_ptr()), + reinterpret_cast(value.data_ptr()), + reinterpret_cast<__hip_bfloat16*>(key_mean.data_ptr()), + reinterpret_cast(output.data_ptr()), + reinterpret_cast(value_scale.data_ptr()), + seq_len, heads, head_dim, static_cast(scale_max)); + } else { + mean_and_fp8_value_nhd_short_kernel<__hip_bfloat16, 16><<>>( + reinterpret_cast(key.data_ptr()), + reinterpret_cast(value.data_ptr()), + reinterpret_cast<__hip_bfloat16*>(key_mean.data_ptr()), + reinterpret_cast(output.data_ptr()), + reinterpret_cast(value_scale.data_ptr()), + seq_len, heads, head_dim, static_cast(scale_max)); + } + } + hip_kernel_launch_check(); + return {key_mean, output, value_scale}; +} + Tensor transpose_value_f16_hnd_gfx12(Tensor value) { return transpose_value_hnd_gfx12<__half, false>(value); } +#endif // SAGEATTN_GFX12_BUILD_AUX + +#if SAGEATTN_GFX12_BUILD_PREPARE + template std::vector 
prepare_qkv_hnd_gfx12( Tensor query, @@ -5707,24 +6182,9 @@ std::vector prepare_k_hnd_packed_gfx12(Tensor key) { return {key_out, key_scale, byte_workspace, scale_workspace}; } -static int select_fp8_d64_block_rows_gfx12( - const int64_t q_len, - const bool is_causal, - const bool value_transposed_hnd) { - if (is_causal) { - if (q_len <= 64) { - return 64; - } - return 128; - } - if (q_len <= 64) { - return 64; - } - if ((q_len % 256) == 0 && (q_len >= 2048 || value_transposed_hnd)) { - return 256; - } - return 128; -} +#endif // SAGEATTN_GFX12_BUILD_PREPARE + +#if SAGEATTN_GFX12_BUILD_AUX __global__ void convert_f16_to_bf16_kernel( const __half* __restrict__ input, @@ -5822,6 +6282,10 @@ std::vector quant_q_nhd_per_warp_gfx12(Tensor query) { return {query_out, query_scale}; } +#endif // SAGEATTN_GFX12_BUILD_AUX + +#if SAGEATTN_GFX12_BUILD_PREPARE + std::vector quant_qk_int8_hnd_gfx12(Tensor query, Tensor key) { STD_TORCH_CHECK(query.is_cuda() && key.is_cuda(), "gfx12 Q/K quantization expects CUDA/HIP tensors"); STD_TORCH_CHECK(query.dim() == 4 && key.dim() == 4, "gfx12 Q/K quantization expects [B, H, S, D]"); @@ -5904,7 +6368,7 @@ std::vector quant_qk_int8_hnd_gfx12(Tensor query, Tensor key) { return {query_out, query_scale, key_out, key_scale}; } -static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( +static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_dispatch( Tensor query, Tensor key, Tensor value, @@ -5915,8 +6379,9 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( int is_causal, float sm_scale, int64_t valid_kv_len, - const float* value_scale_ptr = nullptr, - int value_transposed_hnd_hint = -1); + Tensor value_scale, + int value_transposed_hnd_hint, + int pv_accum_mode); Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( Tensor query, @@ -5926,7 +6391,8 @@ Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( int64_t value_is_fp8, int64_t use_raw_f16_value, double sm_scale, - int64_t valid_kv_len) { + int64_t valid_kv_len, + int64_t pv_accum_mode) { 
STD_TORCH_CHECK(query.is_cuda() && key.is_cuda() && value.is_cuda(), "native gfx12 prepare+attention expects CUDA/HIP tensors"); STD_TORCH_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4, @@ -6485,9 +6951,10 @@ Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( hip_kernel_launch_check(); } else if (use_raw_f16_value) { std::vector prepared = quant_qk_int8_hnd_gfx12(query, key); - qk_int8_sv_f16_d64_native_attn_gfx12_impl( + qk_int8_sv_f16_d64_native_attn_gfx12_dispatch( prepared[0], prepared[2], value, output, prepared[1], prepared[3], - kHND, is_causal, sm_scale, kv_len, nullptr, 0); + kHND, is_causal, sm_scale, kv_len, Tensor(), 0, + static_cast(pv_accum_mode)); } else { const bool use_f16_separate_prepared = is_causal && head_dim == 64 && q_len == 4096 && @@ -6496,13 +6963,19 @@ Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( use_f16_separate_prepared ? prepare_qkv_hnd_gfx12<__half, false>(query, key, value) : prepare_qkv_hnd_packed_gfx12<__half, false>(query, key, value); - qk_int8_sv_f16_d64_native_attn_gfx12_impl( + qk_int8_sv_f16_d64_native_attn_gfx12_dispatch( prepared[0], prepared[2], prepared[4], output, prepared[1], prepared[3], - kHND, is_causal, sm_scale, kv_len, nullptr, 1); + kHND, is_causal, sm_scale, kv_len, Tensor(), 1, + static_cast(pv_accum_mode)); } return output; } +#endif // SAGEATTN_GFX12_BUILD_PREPARE + +#if SAGEATTN_GFX12_BUILD_ATTN_F16 || SAGEATTN_GFX12_BUILD_ATTN_FP8 + +template static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( Tensor query, Tensor key, @@ -6514,17 +6987,27 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( int is_causal, float sm_scale, int64_t valid_kv_len, - const float* value_scale_ptr, - int value_transposed_hnd_hint) { + Tensor value_scale, + int value_transposed_hnd_hint, + int pv_accum_mode) { STD_TORCH_CHECK(query.is_cuda() && key.is_cuda() && value.is_cuda() && output.is_cuda(), "native gfx12 tensors must be CUDA/HIP tensors"); STD_TORCH_CHECK(query.scalar_type() == 
ScalarType::Char, "query must be int8"); STD_TORCH_CHECK(key.scalar_type() == ScalarType::Char, "key must be int8"); const bool value_is_fp8 = value.scalar_type() == ScalarType::Byte; +#if SAGEATTN_GFX12_BUILD_ATTN_F16 && !SAGEATTN_GFX12_BUILD_ATTN_FP8 + STD_TORCH_CHECK(!value_is_fp8, "native gfx12 fp16 attention TU expects fp16 values"); +#endif +#if SAGEATTN_GFX12_BUILD_ATTN_FP8 && !SAGEATTN_GFX12_BUILD_ATTN_F16 + STD_TORCH_CHECK(value_is_fp8, "native gfx12 fp8 attention TU expects fp8 values"); +#endif STD_TORCH_CHECK(value.scalar_type() == ScalarType::Half || value_is_fp8, "value must be fp16 or raw OCP e4m3 fp8 bytes"); - STD_TORCH_CHECK(value_scale_ptr == nullptr || value_is_fp8, + const bool has_value_scale = value_scale.defined() && value_scale.numel() > 0; + STD_TORCH_CHECK(!has_value_scale || value_is_fp8, "value_scale is only valid for the fp8 value path"); + const float* value_scale_ptr = + has_value_scale ? reinterpret_cast(value_scale.data_ptr()) : nullptr; const bool output_is_bf16 = output.scalar_type() == ScalarType::BFloat16; STD_TORCH_CHECK(output.scalar_type() == ScalarType::Half || (value_is_fp8 && output_is_bf16), "output must be fp16, or bf16 for the fp8 value path"); @@ -6662,6 +7145,79 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( const bool use_f16_streamk = head_dim == 64 && !value_is_fp8 && is_causal && value_transposed_hnd && q_len == 4096 && block_rows == 256; + if constexpr (PerThreadQK) { + STD_TORCH_CHECK(value_transposed_hnd, + "gfx12 per-thread QK path expects transposed HND values"); +#define SAGEATTN_LAUNCH_PERTHREAD_FP8_OUT(HD_, BR_, CAUSAL_, OUT_T_) \ + qk_int8_sv_f8_native_2q_kernel<64, HD_, 0, ((HD_) / 16), true, BR_, true, CAUSAL_, OUT_T_, int8_t, false, int8_t, uint8_t, false, false, 0, false, false, 2, false, true><<>>( \ + reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(output.data_ptr()), \ + 
reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), value_scale_ptr, \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), \ + key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), \ + value.stride(0), value.stride(2), value.stride(1), \ + output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 2 : 1), \ + query_scale.stride(0), query_scale.stride(1), \ + key_scale.stride(0), key_scale.stride(1), \ + tensor_layout, sm_scale, true) +#define SAGEATTN_LAUNCH_PERTHREAD_FP8(HD_, BR_, CAUSAL_) \ + if (output_is_bf16) { \ + SAGEATTN_LAUNCH_PERTHREAD_FP8_OUT(HD_, BR_, CAUSAL_, __hip_bfloat16); \ + } else { \ + SAGEATTN_LAUNCH_PERTHREAD_FP8_OUT(HD_, BR_, CAUSAL_, __half); \ + } +#define SAGEATTN_LAUNCH_PERTHREAD_F16(HD_, BR_, CAUSAL_) \ + qk_int8_sv_f16_d64_native_2q_kernel<64, true, BR_, true, 4, CAUSAL_, false, false, int8_t, false, int8_t, false, false, false, false, false, HD_, false, true><<>>( \ + reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast<__half*>(output.data_ptr()), \ + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), \ + key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), \ + value.stride(0), value.stride(2), value.stride(1), \ + output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 
2 : 1), \ + query_scale.stride(0), query_scale.stride(1), \ + key_scale.stride(0), key_scale.stride(1), \ + tensor_layout, sm_scale, true) +#if SAGEATTN_GFX12_BUILD_ATTN_F16 && SAGEATTN_GFX12_BUILD_ATTN_FP8 +#define SAGEATTN_DISPATCH_PERTHREAD_HEADS(BR_, CAUSAL_) \ + if (value_is_fp8) { \ + if (head_dim == 16) { SAGEATTN_LAUNCH_PERTHREAD_FP8(16, BR_, CAUSAL_); } \ + else if (head_dim == 64) { SAGEATTN_LAUNCH_PERTHREAD_FP8(64, BR_, CAUSAL_); } \ + else { SAGEATTN_LAUNCH_PERTHREAD_FP8(128, BR_, CAUSAL_); } \ + } else { \ + if (head_dim == 16) { SAGEATTN_LAUNCH_PERTHREAD_F16(16, BR_, CAUSAL_); } \ + else if (head_dim == 64) { SAGEATTN_LAUNCH_PERTHREAD_F16(64, BR_, CAUSAL_); } \ + else { SAGEATTN_LAUNCH_PERTHREAD_F16(128, BR_, CAUSAL_); } \ + } +#elif SAGEATTN_GFX12_BUILD_ATTN_FP8 +#define SAGEATTN_DISPATCH_PERTHREAD_HEADS(BR_, CAUSAL_) \ + if (head_dim == 16) { SAGEATTN_LAUNCH_PERTHREAD_FP8(16, BR_, CAUSAL_); } \ + else if (head_dim == 64) { SAGEATTN_LAUNCH_PERTHREAD_FP8(64, BR_, CAUSAL_); } \ + else { SAGEATTN_LAUNCH_PERTHREAD_FP8(128, BR_, CAUSAL_); } +#else +#define SAGEATTN_DISPATCH_PERTHREAD_HEADS(BR_, CAUSAL_) \ + if (head_dim == 16) { SAGEATTN_LAUNCH_PERTHREAD_F16(16, BR_, CAUSAL_); } \ + else if (head_dim == 64) { SAGEATTN_LAUNCH_PERTHREAD_F16(64, BR_, CAUSAL_); } \ + else { SAGEATTN_LAUNCH_PERTHREAD_F16(128, BR_, CAUSAL_); } +#endif + if (block_rows == 64) { + if (is_causal) { SAGEATTN_DISPATCH_PERTHREAD_HEADS(64, true); } + else { SAGEATTN_DISPATCH_PERTHREAD_HEADS(64, false); } + } else { + if (is_causal) { SAGEATTN_DISPATCH_PERTHREAD_HEADS(128, true); } + else { SAGEATTN_DISPATCH_PERTHREAD_HEADS(128, false); } + } +#undef SAGEATTN_DISPATCH_PERTHREAD_HEADS +#undef SAGEATTN_LAUNCH_PERTHREAD_F16 +#undef SAGEATTN_LAUNCH_PERTHREAD_FP8 +#undef SAGEATTN_LAUNCH_PERTHREAD_FP8_OUT + hip_kernel_launch_check(); + return output; + } #define SAGEATTN_LAUNCH_FP8_2Q_OUT(BC_, HD_, HND_, BR_, OUT_T_) \ if (is_causal) { \ qk_int8_sv_f8_native_2q_kernel<<>>( \ @@ -6918,6 
+7474,7 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( query_scale.stride(0), query_scale.stride(1), \ key_scale.stride(0), key_scale.stride(1), \ tensor_layout, sm_scale) +#if SAGEATTN_GFX12_BUILD_ATTN_F16 if (use_f16_causal_1q) { STD_TORCH_CHECK(hnd_contiguous, "fp16 single-q causal path requires contiguous HND tensors"); const bool use_f16_1q_pv_accum = use_f16_pv_accum; @@ -6946,7 +7503,14 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( else { SAGEATTN_DISPATCH_F16_1Q_CAUSAL(64, false, false, SAGEATTN_GFX12_NATIVE_F16_TV_PAD); } } #undef SAGEATTN_DISPATCH_F16_1Q_CAUSAL - } else if (use_fp8_2q && value_transposed_hnd) { + } +#endif // SAGEATTN_GFX12_BUILD_ATTN_F16 +#if SAGEATTN_GFX12_BUILD_ATTN_FP8 +#if SAGEATTN_GFX12_BUILD_ATTN_F16 + else if (use_fp8_2q && value_transposed_hnd) { +#else + if (use_fp8_2q && value_transposed_hnd) { +#endif STD_TORCH_CHECK(hnd_contiguous, "transposed fp8 value path requires contiguous HND Q/K/O"); STD_TORCH_CHECK(block_cols == 32 || block_cols == 64, "transposed fp8 value path currently supports BC32/BC64"); @@ -7107,7 +7671,15 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( } else { SAGEATTN_LAUNCH_FP8_2Q(64, 64, false, 128); } - } else if (use_2q && value_transposed_hnd) { + } +#if SAGEATTN_GFX12_BUILD_ATTN_FP8 && !SAGEATTN_GFX12_BUILD_ATTN_F16 + else { + STD_TORCH_CHECK(false, "native gfx12 fp8 attention dispatch could not select a kernel"); + } +#endif +#endif // SAGEATTN_GFX12_BUILD_ATTN_FP8 +#if SAGEATTN_GFX12_BUILD_ATTN_F16 + else if (use_2q && value_transposed_hnd) { STD_TORCH_CHECK(hnd_contiguous, "transposed fp16 value path requires contiguous HND Q/K/O"); if (head_dim == 128) { if (block_rows == 32) { @@ -7225,39 +7797,52 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( } else { SAGEATTN_LAUNCH_F16_1Q(64, 64); } +#endif // SAGEATTN_GFX12_BUILD_ATTN_F16 hip_kernel_launch_check(); return new_empty_like(query, {0}, ScalarType::Float); } -static const float* 
checked_value_scale_ptr_gfx12( +#if SAGEATTN_GFX12_BUILD_ATTN_F16 + +Tensor qk_int8_sv_f16_d64_native_attn_gfx12_dispatch( + Tensor query, + Tensor key, + Tensor value, + Tensor output, + Tensor query_scale, + Tensor key_scale, + int tensor_layout, + int is_causal, + float sm_scale, + int64_t valid_kv_len, Tensor value_scale, - int64_t batch, - int64_t kv_heads, - int64_t head_dim) { - STD_TORCH_CHECK(value_scale.is_cuda(), "value_scale must be a CUDA/HIP tensor"); - STD_TORCH_CHECK(value_scale.scalar_type() == ScalarType::Float, - "value_scale must be fp32"); - STD_TORCH_CHECK(value_scale.dim() == 3 && value_scale.is_contiguous(), - "value_scale must be contiguous [B, H_kv, D]"); - STD_TORCH_CHECK(value_scale.size(0) == batch && - value_scale.size(1) == kv_heads && - value_scale.size(2) == head_dim, - "value_scale shape must match [B, H_kv, D]"); - return reinterpret_cast(value_scale.data_ptr()); + int value_transposed_hnd_hint, + int pv_accum_mode) { + return qk_int8_sv_f16_d64_native_attn_gfx12_impl( + query, key, value, output, query_scale, key_scale, tensor_layout, + is_causal, sm_scale, valid_kv_len, value_scale, + value_transposed_hnd_hint, pv_accum_mode); } +#endif // SAGEATTN_GFX12_BUILD_ATTN_F16 + +#endif // SAGEATTN_GFX12_BUILD_ATTN_F16 || SAGEATTN_GFX12_BUILD_ATTN_FP8 + +#if SAGEATTN_GFX12_BUILD_RAWQ_FP8 + static Tensor qk_rawq_int8_sv_f8_native_attn_gfx12_impl( Tensor query, Tensor key, Tensor value, Tensor output, Tensor key_scale, - const float* value_scale_ptr, + Tensor value_scale, int tensor_layout, int is_causal, float sm_scale, int64_t valid_kv_len, - int value_transposed_hnd_hint) { + int value_transposed_hnd_hint, + int key_hnd_layout) { STD_TORCH_CHECK(query.is_cuda() && key.is_cuda() && value.is_cuda() && output.is_cuda(), "raw-Q gfx12 tensors must be CUDA/HIP tensors"); STD_TORCH_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4 && output.dim() == 4, @@ -7275,6 +7860,10 @@ static Tensor qk_rawq_int8_sv_f8_native_attn_gfx12_impl( 
STD_TORCH_CHECK(key_scale.scalar_type() == ScalarType::Float, "raw-Q gfx12 attention key_scale must be fp32"); STD_TORCH_CHECK(tensor_layout == kHND || tensor_layout == kNHD, "invalid tensor_layout"); + STD_TORCH_CHECK(key_hnd_layout == 0 || key_hnd_layout == 1, + "key_hnd_layout must be 0 or 1"); + STD_TORCH_CHECK(tensor_layout == kNHD || key_hnd_layout == 0, + "key_hnd_layout is only needed for NHD query/output tensors"); STD_TORCH_CHECK(query.is_contiguous() && key.is_contiguous() && value.is_contiguous() && output.is_contiguous(), "raw-Q gfx12 attention expects contiguous tensors"); @@ -7286,8 +7875,11 @@ static Tensor qk_rawq_int8_sv_f8_native_attn_gfx12_impl( const int64_t q_heads = tensor_layout == kNHD ? query.size(2) : query.size(1); const int64_t q_len = tensor_layout == kNHD ? query.size(1) : query.size(2); const int64_t out_q_len = tensor_layout == kNHD ? output.size(1) : output.size(2); - const int64_t kv_heads = tensor_layout == kNHD ? key.size(2) : key.size(1); - const int64_t padded_kv_len = tensor_layout == kNHD ? key.size(1) : key.size(2); + const bool key_hnd_contiguous = tensor_layout == kHND || key_hnd_layout != 0; + const int64_t kv_heads = key_hnd_contiguous ? key.size(1) : + (tensor_layout == kNHD ? key.size(2) : key.size(1)); + const int64_t padded_kv_len = key_hnd_contiguous ? key.size(2) : + (tensor_layout == kNHD ? key.size(1) : key.size(2)); const int64_t kv_len = valid_kv_len > 0 ? 
valid_kv_len : padded_kv_len; STD_TORCH_CHECK(kv_len > 0 && kv_len <= padded_kv_len, "valid_kv_len must be in (0, padded_kv_len]"); @@ -7318,21 +7910,29 @@ static Tensor qk_rawq_int8_sv_f8_native_attn_gfx12_impl( const bool value_transposed_hnd = value_transposed_hnd_hint > 0 || (value_transposed_hnd_hint < 0 && value_shape_transposed_hnd); - STD_TORCH_CHECK(key.size(-1) == head_dim && output.size(-1) == head_dim && + STD_TORCH_CHECK(key.size(0) == batch && key.size(-1) == head_dim && + output.size(-1) == head_dim && (value_transposed_hnd || value_shape_normal), "raw-Q gfx12 Q/K/V/O head_dim mismatch"); + const bool key_shape_matches = + key_hnd_contiguous + ? (key.size(1) == kv_heads && key.size(2) == padded_kv_len) + : (tensor_layout == kNHD + ? (key.size(1) == padded_kv_len && key.size(2) == kv_heads) + : (key.size(1) == kv_heads && key.size(2) == padded_kv_len)); + STD_TORCH_CHECK(key_shape_matches, "raw-Q gfx12 key shape mismatch"); STD_TORCH_CHECK((tensor_layout == kNHD && ((value_transposed_hnd && output.size(1) >= q_len && - key.size(2) == kv_heads && output.size(2) == q_heads) || + output.size(2) == q_heads) || (!value_transposed_hnd && value.size(1) == padded_kv_len && - output.size(1) >= q_len && key.size(2) == kv_heads && - value.size(2) == kv_heads && output.size(2) == q_heads))) || + output.size(1) >= q_len && value.size(2) == kv_heads && + output.size(2) == q_heads))) || (tensor_layout == kHND && ((value_transposed_hnd && output.size(2) >= q_len && - key.size(1) == kv_heads && output.size(1) == q_heads) || + output.size(1) == q_heads) || (!value_transposed_hnd && value.size(2) == padded_kv_len && - output.size(2) >= q_len && key.size(1) == kv_heads && - value.size(1) == kv_heads && output.size(1) == q_heads))), + output.size(2) >= q_len && value.size(1) == kv_heads && + output.size(1) == q_heads))), "raw-Q gfx12 Q/K/V/O shape mismatch"); STD_TORCH_CHECK((q_heads % kv_heads) == 0, "q_heads must be divisible by kv_heads"); 
STD_TORCH_CHECK((padded_kv_len % 64) == 0, @@ -7342,12 +7942,28 @@ static Tensor qk_rawq_int8_sv_f8_native_attn_gfx12_impl( STD_TORCH_CHECK(!is_causal || q_len == padded_kv_len, "raw-Q gfx12 causal attention requires q_len == padded_kv_len"); STD_TORCH_CHECK(key_scale.stride(-1) == 1, "key_scale must have contiguous scale columns"); + const bool has_value_scale = value_scale.defined() && value_scale.numel() > 0; + if (has_value_scale) { + STD_TORCH_CHECK(value_scale.is_cuda(), "value_scale must be a CUDA/HIP tensor"); + STD_TORCH_CHECK(value_scale.scalar_type() == ScalarType::Float, + "value_scale must be fp32"); + STD_TORCH_CHECK(value_scale.dim() == 3 && value_scale.is_contiguous(), + "value_scale must be contiguous [B, H_kv, D]"); + STD_TORCH_CHECK(value_scale.size(0) == batch && + value_scale.size(1) == kv_heads && + value_scale.size(2) == head_dim, + "value_scale shape must match [B, H_kv, D]"); + } + const float* value_scale_ptr = + has_value_scale ? reinterpret_cast(value_scale.data_ptr()) : nullptr; int block_rows = head_dim == 64 ? - select_fp8_d64_block_rows_gfx12(q_len, is_causal, false) : + select_fp8_d64_block_rows_gfx12(q_len, is_causal, value_transposed_hnd) : (q_len <= 64 ? 
64 : 128); - if (head_dim == 64 && !is_causal && q_len == 1024) { - block_rows = 128; + if (head_dim == 64 && !is_causal && value_transposed_hnd) { + if (q_len == 1024) { + block_rows = 256; + } } if (head_dim == 16 && is_causal && q_len <= 1024) { block_rows = 64; @@ -7356,79 +7972,109 @@ static Tensor qk_rawq_int8_sv_f8_native_attn_gfx12_impl( STD_TORCH_CHECK(out_q_len >= q_blocks * block_rows, "raw-Q gfx12 attention output must cover the padded query tail"); - constexpr int block_cols = 64; + const bool use_bc32 = + !is_causal && value_transposed_hnd && tensor_layout == kNHD && + !key_hnd_contiguous && head_dim == 128 && q_len == 1024; const bool hnd_contiguous = tensor_layout == kHND; const dim3 block(block_rows); const dim3 grid(q_blocks, q_heads, batch); -#define SAGEATTN_LAUNCH_RAWQ_FP8_TYPED(HD_, HND_, BR_, VT_, CAUSAL_, QUERY_T_, OUT_T_) \ - qk_int8_sv_f8_native_2q_kernel<<>>( \ +#define SAGEATTN_LAUNCH_RAWQ_FP8_TYPED_EX(BC_, HD_, HND_, KEY_HND_, BR_, VT_, CAUSAL_, QUERY_T_, OUT_T_, QUERY_AT_T_, OUT_AT_T_, STATIC_NHD_, NO_TAIL_, SAME_HEADS_, NO_Q_TAIL_) \ + qk_int8_sv_f8_native_2q_kernel<<>>( \ reinterpret_cast(query.data_ptr()), \ reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), \ reinterpret_cast(output.data_ptr()), \ nullptr, reinterpret_cast(key_scale.data_ptr()), value_scale_ptr, \ batch, q_len, kv_len, q_heads, kv_heads, \ query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), \ - key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), \ + key.stride(0), key.stride(key_hnd_contiguous ? 2 : (tensor_layout == kNHD ? 1 : 2)), key.stride(key_hnd_contiguous ? 1 : (tensor_layout == kNHD ? 2 : 1)), \ value.stride(0), (VT_ ? value.stride(2) : value.stride(tensor_layout == kNHD ? 1 : 2)), (VT_ ? value.stride(1) : value.stride(tensor_layout == kNHD ? 2 : 1)), \ output.stride(0), output.stride(tensor_layout == kNHD ? 
1 : 2), output.stride(tensor_layout == kNHD ? 2 : 1), \ 0, 0, \ key_scale.stride(0), key_scale.stride(1), \ tensor_layout, sm_scale) -#define SAGEATTN_DISPATCH_RAWQ_FP8_OUT(HD_, HND_, BR_, VT_, CAUSAL_, QUERY_T_) \ +#define SAGEATTN_LAUNCH_RAWQ_FP8_TYPED(BC_, HD_, HND_, KEY_HND_, BR_, VT_, CAUSAL_, QUERY_T_, OUT_T_, QUERY_AT_T_, OUT_AT_T_) \ + SAGEATTN_LAUNCH_RAWQ_FP8_TYPED_EX(BC_, HD_, HND_, KEY_HND_, BR_, VT_, CAUSAL_, QUERY_T_, OUT_T_, QUERY_AT_T_, OUT_AT_T_, false, false, false, false) +#define SAGEATTN_DISPATCH_RAWQ_FP8_OUT(BC_, HD_, HND_, KEY_HND_, BR_, VT_, CAUSAL_, QUERY_T_, QUERY_AT_T_) \ if (output.scalar_type() == ScalarType::BFloat16) { \ - SAGEATTN_LAUNCH_RAWQ_FP8_TYPED(HD_, HND_, BR_, VT_, CAUSAL_, QUERY_T_, __hip_bfloat16); \ + SAGEATTN_LAUNCH_RAWQ_FP8_TYPED(BC_, HD_, HND_, KEY_HND_, BR_, VT_, CAUSAL_, QUERY_T_, __hip_bfloat16, QUERY_AT_T_, at::BFloat16); \ } else { \ - SAGEATTN_LAUNCH_RAWQ_FP8_TYPED(HD_, HND_, BR_, VT_, CAUSAL_, QUERY_T_, __half); \ + SAGEATTN_LAUNCH_RAWQ_FP8_TYPED(BC_, HD_, HND_, KEY_HND_, BR_, VT_, CAUSAL_, QUERY_T_, __half, QUERY_AT_T_, at::Half); \ } -#define SAGEATTN_DISPATCH_RAWQ_FP8_QUERY(HD_, HND_, BR_, VT_, CAUSAL_) \ +#define SAGEATTN_DISPATCH_RAWQ_FP8_QUERY(BC_, HD_, HND_, KEY_HND_, BR_, VT_, CAUSAL_) \ if (query.scalar_type() == ScalarType::BFloat16) { \ - SAGEATTN_DISPATCH_RAWQ_FP8_OUT(HD_, HND_, BR_, VT_, CAUSAL_, __hip_bfloat16); \ + SAGEATTN_DISPATCH_RAWQ_FP8_OUT(BC_, HD_, HND_, KEY_HND_, BR_, VT_, CAUSAL_, __hip_bfloat16, at::BFloat16); \ } else { \ - SAGEATTN_DISPATCH_RAWQ_FP8_OUT(HD_, HND_, BR_, VT_, CAUSAL_, __half); \ + SAGEATTN_DISPATCH_RAWQ_FP8_OUT(BC_, HD_, HND_, KEY_HND_, BR_, VT_, CAUSAL_, __half, at::Half); \ } -#define SAGEATTN_DISPATCH_RAWQ_FP8_BR(HD_, HND_, VT_, CAUSAL_) \ +#define SAGEATTN_DISPATCH_RAWQ_FP8_BR(BC_, HD_, HND_, KEY_HND_, VT_, CAUSAL_) \ if (block_rows == 64) { \ - SAGEATTN_DISPATCH_RAWQ_FP8_QUERY(HD_, HND_, 64, VT_, CAUSAL_); \ + SAGEATTN_DISPATCH_RAWQ_FP8_QUERY(BC_, HD_, HND_, KEY_HND_, 
64, VT_, CAUSAL_); \ } else if (block_rows == 256) { \ - SAGEATTN_DISPATCH_RAWQ_FP8_QUERY(HD_, HND_, 256, VT_, CAUSAL_); \ + SAGEATTN_DISPATCH_RAWQ_FP8_QUERY(BC_, HD_, HND_, KEY_HND_, 256, VT_, CAUSAL_); \ } else if (block_rows == 512) { \ - SAGEATTN_DISPATCH_RAWQ_FP8_QUERY(HD_, HND_, 512, VT_, CAUSAL_); \ + SAGEATTN_DISPATCH_RAWQ_FP8_QUERY(BC_, HD_, HND_, KEY_HND_, 512, VT_, CAUSAL_); \ } else { \ - SAGEATTN_DISPATCH_RAWQ_FP8_QUERY(HD_, HND_, 128, VT_, CAUSAL_); \ + SAGEATTN_DISPATCH_RAWQ_FP8_QUERY(BC_, HD_, HND_, KEY_HND_, 128, VT_, CAUSAL_); \ } -#define SAGEATTN_DISPATCH_RAWQ_FP8_HD(HND_, VT_, CAUSAL_) \ +#define SAGEATTN_DISPATCH_RAWQ_FP8_HD(BC_, HND_, KEY_HND_, VT_, CAUSAL_) \ if (head_dim == 16) { \ - SAGEATTN_DISPATCH_RAWQ_FP8_BR(16, HND_, VT_, CAUSAL_); \ + SAGEATTN_DISPATCH_RAWQ_FP8_BR(BC_, 16, HND_, KEY_HND_, VT_, CAUSAL_); \ } else if (head_dim == 64) { \ - SAGEATTN_DISPATCH_RAWQ_FP8_BR(64, HND_, VT_, CAUSAL_); \ + SAGEATTN_DISPATCH_RAWQ_FP8_BR(BC_, 64, HND_, KEY_HND_, VT_, CAUSAL_); \ } else { \ - SAGEATTN_DISPATCH_RAWQ_FP8_BR(128, HND_, VT_, CAUSAL_); \ + SAGEATTN_DISPATCH_RAWQ_FP8_BR(BC_, 128, HND_, KEY_HND_, VT_, CAUSAL_); \ + } +#define SAGEATTN_DISPATCH_RAWQ_FP8_LAYOUT(BC_) \ + if (hnd_contiguous) { \ + if (is_causal) { \ + if (value_transposed_hnd) { SAGEATTN_DISPATCH_RAWQ_FP8_HD(BC_, true, true, true, true); } \ + else { SAGEATTN_DISPATCH_RAWQ_FP8_HD(BC_, true, true, false, true); } \ + } else { \ + if (value_transposed_hnd) { SAGEATTN_DISPATCH_RAWQ_FP8_HD(BC_, true, true, true, false); } \ + else { SAGEATTN_DISPATCH_RAWQ_FP8_HD(BC_, true, true, false, false); } \ + } \ + } else if (key_hnd_contiguous) { \ + if (is_causal) { \ + if (value_transposed_hnd) { SAGEATTN_DISPATCH_RAWQ_FP8_HD(BC_, false, true, true, true); } \ + else { SAGEATTN_DISPATCH_RAWQ_FP8_HD(BC_, false, true, false, true); } \ + } else { \ + if (value_transposed_hnd) { SAGEATTN_DISPATCH_RAWQ_FP8_HD(BC_, false, true, true, false); } \ + else { SAGEATTN_DISPATCH_RAWQ_FP8_HD(BC_, 
false, true, false, false); } \ + } \ + } else { \ + if (is_causal) { \ + if (value_transposed_hnd) { SAGEATTN_DISPATCH_RAWQ_FP8_HD(BC_, false, false, true, true); } \ + else { SAGEATTN_DISPATCH_RAWQ_FP8_HD(BC_, false, false, false, true); } \ + } else { \ + if (value_transposed_hnd) { SAGEATTN_DISPATCH_RAWQ_FP8_HD(BC_, false, false, true, false); } \ + else { SAGEATTN_DISPATCH_RAWQ_FP8_HD(BC_, false, false, false, false); } \ + } \ } - if (hnd_contiguous) { - if (is_causal) { - if (value_transposed_hnd) { SAGEATTN_DISPATCH_RAWQ_FP8_HD(true, true, true); } - else { SAGEATTN_DISPATCH_RAWQ_FP8_HD(true, false, true); } - } else { - if (value_transposed_hnd) { SAGEATTN_DISPATCH_RAWQ_FP8_HD(true, true, false); } - else { SAGEATTN_DISPATCH_RAWQ_FP8_HD(true, false, false); } - } + const bool use_static_short_nhd = + !is_causal && value_transposed_hnd && tensor_layout == kNHD && + !key_hnd_contiguous && q_heads == kv_heads && q_len == kv_len && + q_len == 512 && head_dim == 128 && + query.scalar_type() == ScalarType::Half && output.scalar_type() == ScalarType::Half; + + if (use_static_short_nhd) { + SAGEATTN_LAUNCH_RAWQ_FP8_TYPED_EX(64, 128, false, false, 128, true, false, + __half, __half, at::Half, at::Half, + true, true, true, true); + } else if (use_bc32) { + SAGEATTN_DISPATCH_RAWQ_FP8_QUERY(32, 128, false, false, 128, true, false); } else { - if (is_causal) { - if (value_transposed_hnd) { SAGEATTN_DISPATCH_RAWQ_FP8_HD(false, true, true); } - else { SAGEATTN_DISPATCH_RAWQ_FP8_HD(false, false, true); } - } else { - if (value_transposed_hnd) { SAGEATTN_DISPATCH_RAWQ_FP8_HD(false, true, false); } - else { SAGEATTN_DISPATCH_RAWQ_FP8_HD(false, false, false); } - } + SAGEATTN_DISPATCH_RAWQ_FP8_LAYOUT(64); } +#undef SAGEATTN_DISPATCH_RAWQ_FP8_LAYOUT #undef SAGEATTN_DISPATCH_RAWQ_FP8_HD #undef SAGEATTN_DISPATCH_RAWQ_FP8_BR #undef SAGEATTN_DISPATCH_RAWQ_FP8_QUERY #undef SAGEATTN_DISPATCH_RAWQ_FP8_OUT #undef SAGEATTN_LAUNCH_RAWQ_FP8_TYPED +#undef 
SAGEATTN_LAUNCH_RAWQ_FP8_TYPED_EX hip_kernel_launch_check(); return output; } @@ -7443,12 +8089,14 @@ Tensor qk_rawq_int8_sv_f8_native_attn_gfx12( int64_t is_causal, double sm_scale, int64_t valid_kv_len, - int64_t value_transposed_hnd) { + int64_t value_transposed_hnd, + int64_t key_hnd_layout) { return qk_rawq_int8_sv_f8_native_attn_gfx12_impl( - query, key, value, output, key_scale, nullptr, + query, key, value, output, key_scale, Tensor(), static_cast(tensor_layout), static_cast(is_causal), static_cast(sm_scale), valid_kv_len, - static_cast(value_transposed_hnd)); + static_cast(value_transposed_hnd), + static_cast(key_hnd_layout)); } Tensor qk_rawq_int8_sv_f8_scaled_native_attn_gfx12( @@ -7462,19 +8110,20 @@ Tensor qk_rawq_int8_sv_f8_scaled_native_attn_gfx12( int64_t is_causal, double sm_scale, int64_t valid_kv_len, - int64_t value_transposed_hnd) { - const int64_t head_dim = query.size(-1); - const int64_t batch = query.size(0); - const int64_t kv_heads = tensor_layout == kNHD ? key.size(2) : key.size(1); - const float* value_scale_ptr = - checked_value_scale_ptr_gfx12(value_scale, batch, kv_heads, head_dim); + int64_t value_transposed_hnd, + int64_t key_hnd_layout) { return qk_rawq_int8_sv_f8_native_attn_gfx12_impl( - query, key, value, output, key_scale, value_scale_ptr, + query, key, value, output, key_scale, value_scale, static_cast(tensor_layout), static_cast(is_causal), static_cast(sm_scale), valid_kv_len, - static_cast(value_transposed_hnd)); + static_cast(value_transposed_hnd), + static_cast(key_hnd_layout)); } +#endif // SAGEATTN_GFX12_BUILD_RAWQ_FP8 + +#if SAGEATTN_GFX12_BUILD_ATTN_FP8 + Tensor qk_int8_sv_f8_scaled_native_attn_gfx12( Tensor query, Tensor key, @@ -7487,17 +8136,16 @@ Tensor qk_int8_sv_f8_scaled_native_attn_gfx12( int64_t is_causal, double sm_scale, int64_t valid_kv_len) { - const int64_t head_dim = query.size(-1); - const int64_t batch = query.size(0); - const int64_t kv_heads = tensor_layout == kNHD ? 
key.size(2) : key.size(1); - const float* value_scale_ptr = - checked_value_scale_ptr_gfx12(value_scale, batch, kv_heads, head_dim); return qk_int8_sv_f16_d64_native_attn_gfx12_impl( query, key, value, output, query_scale, key_scale, static_cast(tensor_layout), static_cast(is_causal), - static_cast(sm_scale), valid_kv_len, value_scale_ptr, 1); + static_cast(sm_scale), valid_kv_len, value_scale, 1, -1); } +#endif // SAGEATTN_GFX12_BUILD_ATTN_FP8 + +#if SAGEATTN_GFX12_BUILD_ATTN_F16 + Tensor qk_int8_sv_f16_d64_native_attn_gfx12( Tensor query, Tensor key, @@ -7509,10 +8157,121 @@ Tensor qk_int8_sv_f16_d64_native_attn_gfx12( int64_t is_causal, double sm_scale, int64_t valid_kv_len, - int64_t value_transposed_hnd) { + int64_t value_transposed_hnd, + int64_t pv_accum_mode) { return qk_int8_sv_f16_d64_native_attn_gfx12_impl( query, key, value, output, query_scale, key_scale, static_cast(tensor_layout), static_cast(is_causal), - static_cast(sm_scale), valid_kv_len, nullptr, - static_cast(value_transposed_hnd)); + static_cast(sm_scale), valid_kv_len, Tensor(), + static_cast(value_transposed_hnd), static_cast(pv_accum_mode)); +} + +Tensor qk_rawq_int8_sv_f16_native_attn_gfx12( + Tensor query, + Tensor key, + Tensor value, + Tensor output, + Tensor key_scale, + int64_t tensor_layout, + int64_t is_causal, + double sm_scale, + int64_t valid_kv_len, + int64_t pv_accum_mode) { + STD_TORCH_CHECK(query.is_cuda() && key.is_cuda() && value.is_cuda() && output.is_cuda(), + "raw-Q fp16 gfx12 tensors must be CUDA/HIP tensors"); + STD_TORCH_CHECK(query.scalar_type() == ScalarType::Half || + query.scalar_type() == ScalarType::BFloat16, + "raw-Q fp16 gfx12 query must be fp16 or bf16"); + STD_TORCH_CHECK(key.scalar_type() == ScalarType::Char, "raw-Q fp16 gfx12 key must be int8"); + STD_TORCH_CHECK(value.scalar_type() == ScalarType::Half, "raw-Q fp16 gfx12 value must be fp16"); + STD_TORCH_CHECK(output.scalar_type() == ScalarType::Half, "raw-Q fp16 gfx12 output must be fp16"); + 
STD_TORCH_CHECK(key_scale.scalar_type() == ScalarType::Float, + "raw-Q fp16 gfx12 key_scale must be fp32"); + STD_TORCH_CHECK(tensor_layout == kHND || tensor_layout == kNHD, "invalid tensor_layout"); + STD_TORCH_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4 && output.dim() == 4, + "raw-Q fp16 gfx12 attention expects 4D tensors"); + const int64_t head_dim = query.size(-1); + STD_TORCH_CHECK(head_dim == 16 || head_dim == 64 || head_dim == 128, + "raw-Q fp16 gfx12 supports D16/D64/D128"); + STD_TORCH_CHECK(key.size(-1) == head_dim && value.size(-1) == head_dim && + output.size(-1) == head_dim, + "raw-Q fp16 gfx12 tensors must have matching head_dim"); + const int64_t batch = query.size(0); + const int64_t q_heads = tensor_layout == kNHD ? query.size(2) : query.size(1); + const int64_t q_len = tensor_layout == kNHD ? query.size(1) : query.size(2); + const int64_t kv_heads = tensor_layout == kNHD ? key.size(2) : key.size(1); + const int64_t padded_kv_len = tensor_layout == kNHD ? key.size(1) : key.size(2); + const int64_t kv_len = valid_kv_len > 0 ? 
valid_kv_len : padded_kv_len; + STD_TORCH_CHECK(kv_len > 0 && kv_len <= padded_kv_len, + "valid_kv_len must be in (0, padded_kv_len]"); + STD_TORCH_CHECK((padded_kv_len % 64) == 0, + "raw-Q fp16 gfx12 requires kv_len to be a multiple of 64"); + STD_TORCH_CHECK(!is_causal || q_len == padded_kv_len, + "raw-Q fp16 gfx12 causal path requires q_len == kv_len"); + STD_TORCH_CHECK((q_heads % kv_heads) == 0, "q_heads must be divisible by kv_heads"); + STD_TORCH_CHECK(key_scale.dim() == 3 && key_scale.stride(-1) == 1, + "raw-Q fp16 gfx12 key_scale must be [B, H_kv, ceil(K/64)]"); + STD_TORCH_CHECK(key_scale.size(0) == batch && key_scale.size(1) == kv_heads && + key_scale.size(2) == (padded_kv_len + 63) / 64, + "raw-Q fp16 gfx12 key_scale shape mismatch"); + STD_TORCH_CHECK(pv_accum_mode >= -1 && pv_accum_mode <= 1, + "invalid gfx12 fp16 PV accumulation mode"); + + const bool hnd_contiguous = tensor_layout == kHND && + query.is_contiguous() && key.is_contiguous() && + value.is_contiguous() && output.is_contiguous(); + const int block_rows = q_len <= 64 ? 64 : 128; + const dim3 block(block_rows); + const dim3 grid((q_len + block_rows - 1) / block_rows, q_heads, batch); +#define SAGEATTN_LAUNCH_RAWQ_F16_VALUE(HD_, HND_, BR_, CAUSAL_, QUERY_T_, F16ACC_) \ + qk_int8_sv_f16_d64_native_2q_kernel<64, HND_, BR_, false, SAGEATTN_GFX12_NATIVE_F16_TV_PAD, CAUSAL_, false, F16ACC_, QUERY_T_, true, int8_t, false, false, false, false, false, HD_><<>>( \ + reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast<__half*>(output.data_ptr()), \ + nullptr, reinterpret_cast(key_scale.data_ptr()), \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), \ + key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), \ + value.stride(0), value.stride(tensor_layout == kNHD ? 
1 : 2), value.stride(tensor_layout == kNHD ? 2 : 1), \ + output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 2 : 1), \ + 0, 0, key_scale.stride(0), key_scale.stride(1), \ + tensor_layout, sm_scale) +#define SAGEATTN_DISPATCH_RAWQ_F16_VALUE_FOR_HND(HD_, HND_, QUERY_T_) \ + if (is_causal) { \ + if (pv_accum_mode == 1) { \ + if (block_rows == 64) { SAGEATTN_LAUNCH_RAWQ_F16_VALUE(HD_, HND_, 64, true, QUERY_T_, true); } \ + else { SAGEATTN_LAUNCH_RAWQ_F16_VALUE(HD_, HND_, 128, true, QUERY_T_, true); } \ + } else { \ + if (block_rows == 64) { SAGEATTN_LAUNCH_RAWQ_F16_VALUE(HD_, HND_, 64, true, QUERY_T_, false); } \ + else { SAGEATTN_LAUNCH_RAWQ_F16_VALUE(HD_, HND_, 128, true, QUERY_T_, false); } \ + } \ + } else if (pv_accum_mode == 1) { \ + if (block_rows == 64) { SAGEATTN_LAUNCH_RAWQ_F16_VALUE(HD_, HND_, 64, false, QUERY_T_, true); } \ + else { SAGEATTN_LAUNCH_RAWQ_F16_VALUE(HD_, HND_, 128, false, QUERY_T_, true); } \ + } else { \ + if (block_rows == 64) { SAGEATTN_LAUNCH_RAWQ_F16_VALUE(HD_, HND_, 64, false, QUERY_T_, false); } \ + else { SAGEATTN_LAUNCH_RAWQ_F16_VALUE(HD_, HND_, 128, false, QUERY_T_, false); } \ + } +#define SAGEATTN_DISPATCH_RAWQ_F16_VALUE_FOR_DTYPE(QUERY_T_) \ + if (hnd_contiguous) { \ + if (head_dim == 16) { SAGEATTN_DISPATCH_RAWQ_F16_VALUE_FOR_HND(16, true, QUERY_T_); } \ + else if (head_dim == 64) { SAGEATTN_DISPATCH_RAWQ_F16_VALUE_FOR_HND(64, true, QUERY_T_); } \ + else { SAGEATTN_DISPATCH_RAWQ_F16_VALUE_FOR_HND(128, true, QUERY_T_); } \ + } else { \ + if (head_dim == 16) { SAGEATTN_DISPATCH_RAWQ_F16_VALUE_FOR_HND(16, false, QUERY_T_); } \ + else if (head_dim == 64) { SAGEATTN_DISPATCH_RAWQ_F16_VALUE_FOR_HND(64, false, QUERY_T_); } \ + else { SAGEATTN_DISPATCH_RAWQ_F16_VALUE_FOR_HND(128, false, QUERY_T_); } \ + } + if (query.scalar_type() == ScalarType::Half) { + SAGEATTN_DISPATCH_RAWQ_F16_VALUE_FOR_DTYPE(__half); + } else { + SAGEATTN_DISPATCH_RAWQ_F16_VALUE_FOR_DTYPE(__hip_bfloat16); + } 
+#undef SAGEATTN_DISPATCH_RAWQ_F16_VALUE_FOR_DTYPE +#undef SAGEATTN_DISPATCH_RAWQ_F16_VALUE_FOR_HND +#undef SAGEATTN_LAUNCH_RAWQ_F16_VALUE + hip_kernel_launch_check(); + return new_empty_like(query, {0}, ScalarType::Float); } + +#endif // SAGEATTN_GFX12_BUILD_ATTN_F16 diff --git a/csrc/qattn/qk_int_sv_gfx12_native_attn_f16.cu b/csrc/qattn/qk_int_sv_gfx12_native_attn_f16.cu new file mode 100644 index 00000000..f385a76b --- /dev/null +++ b/csrc/qattn/qk_int_sv_gfx12_native_attn_f16.cu @@ -0,0 +1,12 @@ +/* + * Copyright (c) 2026 Advanced Micro Devices, Inc. + * + * Licensed under the Apache License, Version 2.0. + */ + +#define SAGEATTN_GFX12_BUILD_AUX 0 +#define SAGEATTN_GFX12_BUILD_PREPARE 0 +#define SAGEATTN_GFX12_BUILD_ATTN_F16 1 +#define SAGEATTN_GFX12_BUILD_ATTN_FP8 0 +#define SAGEATTN_GFX12_BUILD_RAWQ_FP8 0 +#include "qk_int_sv_gfx12_native.cu" diff --git a/csrc/qattn/qk_int_sv_gfx12_native_attn_fp8.cu b/csrc/qattn/qk_int_sv_gfx12_native_attn_fp8.cu new file mode 100644 index 00000000..6347cc1d --- /dev/null +++ b/csrc/qattn/qk_int_sv_gfx12_native_attn_fp8.cu @@ -0,0 +1,12 @@ +/* + * Copyright (c) 2026 Advanced Micro Devices, Inc. + * + * Licensed under the Apache License, Version 2.0. + */ + +#define SAGEATTN_GFX12_BUILD_AUX 0 +#define SAGEATTN_GFX12_BUILD_PREPARE 0 +#define SAGEATTN_GFX12_BUILD_ATTN_F16 0 +#define SAGEATTN_GFX12_BUILD_ATTN_FP8 1 +#define SAGEATTN_GFX12_BUILD_RAWQ_FP8 0 +#include "qk_int_sv_gfx12_native.cu" diff --git a/csrc/qattn/qk_int_sv_gfx12_native_aux.cu b/csrc/qattn/qk_int_sv_gfx12_native_aux.cu new file mode 100644 index 00000000..19e3d006 --- /dev/null +++ b/csrc/qattn/qk_int_sv_gfx12_native_aux.cu @@ -0,0 +1,12 @@ +/* + * Copyright (c) 2026 Advanced Micro Devices, Inc. + * + * Licensed under the Apache License, Version 2.0. 
+ */ + +#define SAGEATTN_GFX12_BUILD_AUX 1 +#define SAGEATTN_GFX12_BUILD_PREPARE 0 +#define SAGEATTN_GFX12_BUILD_ATTN_F16 0 +#define SAGEATTN_GFX12_BUILD_ATTN_FP8 0 +#define SAGEATTN_GFX12_BUILD_RAWQ_FP8 0 +#include "qk_int_sv_gfx12_native.cu" diff --git a/csrc/qattn/qk_int_sv_gfx12_native_prepare.cu b/csrc/qattn/qk_int_sv_gfx12_native_prepare.cu new file mode 100644 index 00000000..a6bce427 --- /dev/null +++ b/csrc/qattn/qk_int_sv_gfx12_native_prepare.cu @@ -0,0 +1,12 @@ +/* + * Copyright (c) 2026 Advanced Micro Devices, Inc. + * + * Licensed under the Apache License, Version 2.0. + */ + +#define SAGEATTN_GFX12_BUILD_AUX 0 +#define SAGEATTN_GFX12_BUILD_PREPARE 1 +#define SAGEATTN_GFX12_BUILD_ATTN_F16 0 +#define SAGEATTN_GFX12_BUILD_ATTN_FP8 0 +#define SAGEATTN_GFX12_BUILD_RAWQ_FP8 0 +#include "qk_int_sv_gfx12_native.cu" diff --git a/csrc/qattn/qk_int_sv_gfx12_native_rawq_fp8.cu b/csrc/qattn/qk_int_sv_gfx12_native_rawq_fp8.cu new file mode 100644 index 00000000..5aa04fb5 --- /dev/null +++ b/csrc/qattn/qk_int_sv_gfx12_native_rawq_fp8.cu @@ -0,0 +1,12 @@ +/* + * Copyright (c) 2026 Advanced Micro Devices, Inc. + * + * Licensed under the Apache License, Version 2.0. 
+ */ + +#define SAGEATTN_GFX12_BUILD_AUX 0 +#define SAGEATTN_GFX12_BUILD_PREPARE 0 +#define SAGEATTN_GFX12_BUILD_ATTN_F16 0 +#define SAGEATTN_GFX12_BUILD_ATTN_FP8 0 +#define SAGEATTN_GFX12_BUILD_RAWQ_FP8 1 +#include "qk_int_sv_gfx12_native.cu" diff --git a/sageattention/core.py b/sageattention/core.py index 0b323139..06b3f60e 100644 --- a/sageattention/core.py +++ b/sageattention/core.py @@ -165,6 +165,82 @@ def _gfx12_fp8_value_scale_hnd(v_hnd: torch.Tensor, scale_max: float) -> torch.T return v_hnd.abs().amax(dim=2).to(torch.float32).div(scale_max).contiguous() +def _gfx12_fp8_value_native( + gfx12_native: Any, + value: torch.Tensor, + scale_max: float, + tensor_layout: str, +) -> Tuple[torch.Tensor, torch.Tensor]: + value_hnd = value if tensor_layout == "HND" else value.transpose(1, 2).contiguous() + value_scale = _gfx12_fp8_value_scale_hnd(value_hnd, scale_max) + value_native = gfx12_native.transpose_value_fp8_scaled_hnd(value_hnd, value_scale) + return value_native, value_scale + + +def _gfx12_normalize_v2_options( + value_dtype: str, + pv_accum_dtype: Optional[str], + smooth_v: bool, +) -> Tuple[str, str, bool, float]: + value_dtype_normalized = value_dtype.lower() + if value_dtype_normalized == "auto": + value_dtype_normalized = "fp8" + if value_dtype_normalized not in {"fp16", "fp8"}: + raise ValueError("gfx12 native value_dtype must be 'auto', 'fp16', or 'fp8'.") + if pv_accum_dtype is None: + pv_accum_dtype = "fp32+fp16" if value_dtype_normalized == "fp8" else "fp32" + if value_dtype_normalized == "fp8": + if pv_accum_dtype not in {"fp32+fp16", "fp32", "fp32+fp32"}: + raise ValueError("gfx12 fp8 value path supports pv_accum_dtype 'fp32+fp16', 'fp32', or 'fp32+fp32'.") + if smooth_v and pv_accum_dtype in {"fp32+fp16", "fp32+fp32"}: + warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.") + smooth_v = False + return value_dtype_normalized, pv_accum_dtype, smooth_v, ( + _GFX12_FP8_VALUE_SCALE_MAX_FP32_FP16 if pv_accum_dtype == 
"fp32+fp16" else 448.0 + ) + if pv_accum_dtype not in {"fp32", "fp16", "fp16+fp32"}: + raise ValueError("gfx12 fp16 value path supports pv_accum_dtype 'fp32', 'fp16', or 'fp16+fp32'.") + if smooth_v and pv_accum_dtype in {"fp32", "fp16+fp32"}: + warnings.warn(f"pv_accum_dtype is {pv_accum_dtype}, smooth_v will be ignored.") + smooth_v = False + return value_dtype_normalized, pv_accum_dtype, smooth_v, _GFX12_FP8_VALUE_SCALE_MAX_FP32_FP16 + + +def _gfx12_pv_accum_mode(value_dtype: str, pv_accum_dtype: str) -> int: + if value_dtype != "fp16": + return -1 + return 1 if pv_accum_dtype == "fp16" else 0 + + +def _gfx12_apply_smooth_v( + v: torch.Tensor, + tensor_layout: str, + q_heads: int, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + seq_dim = 1 if tensor_layout == "NHD" else 2 + head_dim = 2 if tensor_layout == "NHD" else 1 + vm = v.mean(dim=seq_dim) + centered = (v - vm.unsqueeze(seq_dim)).to(torch.float16) + kv_heads = v.size(head_dim) + if q_heads % kv_heads != 0: + raise ValueError("num_qo_heads must be divisible by num_kv_heads.") + if q_heads != kv_heads: + vm = torch.repeat_interleave(vm, q_heads // kv_heads, dim=1) + return centered, vm + + +def _gfx12_add_smooth_v_mean( + out: torch.Tensor, + vm: Optional[torch.Tensor], + tensor_layout: str, +) -> torch.Tensor: + if vm is None: + return out + if tensor_layout == "NHD": + return out + vm.unsqueeze(1).to(out.dtype) + return out + vm.unsqueeze(2).to(out.dtype) + + def _attention_lse_reference( q: torch.Tensor, k: torch.Tensor, @@ -243,86 +319,44 @@ def sageattn_qk_int8_pv_gfx12_native( not affect the default return_lse=False fast path. 
""" - if kwargs: - unsupported = ", ".join(sorted(key for key, value in kwargs.items() if value is not None)) - if unsupported: - raise TypeError(f"Unsupported non-None gfx12 SageAttention arguments: {unsupported}") if qk_quant_gran not in {"per_warp", "per_thread"}: raise ValueError("qk_quant_gran must be either 'per_warp' or 'per_thread'.") - if qk_quant_gran != "per_warp": - raise NotImplementedError("gfx12 native currently supports qk_quant_gran='per_warp'.") - if smooth_v: - raise NotImplementedError("gfx12 native does not support smooth_v yet.") - value_dtype_normalized = value_dtype.lower() - if value_dtype_normalized == "auto": - value_dtype_normalized = "fp8" - if pv_accum_dtype is None: - pv_accum_dtype = "fp32+fp16" if value_dtype_normalized == "fp8" else "fp32" - if value_dtype_normalized == "fp8" and pv_accum_dtype not in {"fp32+fp16", "fp32", "fp32+fp32"}: - raise ValueError("gfx12 fp8 value path supports pv_accum_dtype 'fp32+fp16', 'fp32', or 'fp32+fp32'.") - if value_dtype_normalized == "fp8" and pv_accum_dtype != "fp32+fp16": - raise NotImplementedError("gfx12 fp8 value path currently supports pv_accum_dtype='fp32+fp16'.") - if value_dtype_normalized == "fp16" and pv_accum_dtype not in {"fp32", "fp16", "fp16+fp32"}: - raise ValueError("gfx12 fp16 value path supports pv_accum_dtype 'fp32', 'fp16', or 'fp16+fp32'.") - if value_dtype_normalized == "fp16" and pv_accum_dtype != "fp32": - raise NotImplementedError("gfx12 fp16 value path currently supports pv_accum_dtype='fp32'.") - fp8_value_scale_max = _GFX12_FP8_VALUE_SCALE_MAX_FP32_FP16 + value_dtype_normalized, pv_accum_dtype, smooth_v, fp8_value_scale_max = ( + _gfx12_normalize_v2_options(value_dtype, pv_accum_dtype, smooth_v) + ) + pv_accum_mode = _gfx12_pv_accum_mode(value_dtype_normalized, pv_accum_dtype) gfx12_native = _get_gfx12_native_extension() gfx12_prepare_attn_hnd = _qattn_gfx12_prepare_attn_hnd + assert q.is_cuda, "Input tensors must be on cuda/HIP." 
+ assert q.device == k.device == v.device, "All tensors must be on the same device." + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." + assert q.dtype in [torch.float16, torch.bfloat16], "gfx12 native path supports fp16/bf16 inputs." + assert tensor_layout in ["HND", "NHD"], "tensor_layout must be either 'HND' or 'NHD'." + input_dtype = q.dtype + + if smooth_v: + q_heads = q.size(2) if tensor_layout == "NHD" else q.size(1) + v, smooth_v_mean = _gfx12_apply_smooth_v(v, tensor_layout, q_heads) + else: + smooth_v_mean = None + lse_q = q lse_k = k lse_sm_scale = float(sm_scale if sm_scale is not None else q.size(-1) ** -0.5) def _with_lse(out: torch.Tensor): + out = _gfx12_add_smooth_v_mean(out, smooth_v_mean, tensor_layout) if not return_lse: return out return out, _attention_lse_reference( lse_q, lse_k, tensor_layout, bool(is_causal), lse_sm_scale ) - if ( - tensor_layout == "HND" - and value_dtype_normalized == "fp16" - and not smooth_k - and not return_lse - and q.dim() == 4 - and k.dim() == 4 - and v.dim() == 4 - and q.is_cuda - and q.is_contiguous() - and k.is_contiguous() - and v.is_contiguous() - and q.dtype == torch.float16 - and q.dtype == k.dtype == v.dtype - and q.device == k.device == v.device - and q.size(-1) in (16, 64) - and q.size(2) % 64 == 0 - and k.size(2) % 64 == 0 - ): - torch.cuda.set_device(v.device) - use_raw_f16_value = is_causal and q.size(-1) == 64 and q.size(2) <= 512 - return gfx12_prepare_attn_hnd( - q, - k, - v, - int(is_causal), - 0, - int(use_raw_f16_value), - float(sm_scale if sm_scale is not None else q.size(-1) ** -0.5), - ) - - assert q.is_cuda, "Input tensors must be on cuda/HIP." - assert q.device == k.device == v.device, "All tensors must be on the same device." - assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype." - assert q.dtype in [torch.float16, torch.bfloat16], "gfx12 native path supports fp16/bf16 inputs." 
- assert tensor_layout in ["HND", "NHD"], "tensor_layout must be either 'HND' or 'NHD'." torch.cuda.set_device(v.device) - input_dtype = q.dtype + assert v.dtype in [torch.float16, torch.bfloat16], "gfx12 native path supports fp16/bf16 value inputs." value_dtype = value_dtype_normalized - if value_dtype not in {"fp16", "fp8"}: - raise ValueError("gfx12 native value_dtype must be 'auto', 'fp16', or 'fp8'.") if sm_scale is None and q.dim() == 4: sm_scale = q.size(-1) ** -0.5 @@ -332,6 +366,7 @@ def _with_lse(out: torch.Tensor): and q.dim() == 4 and k.dim() == 4 and v.dim() == 4 + and q.dtype == k.dtype == v.dtype and q.is_contiguous() and k.is_contiguous() and v.is_contiguous() @@ -356,12 +391,14 @@ def _with_lse(out: torch.Tensor): int(value_dtype == "fp8"), int(use_raw_f16_value), float(sm_scale), + 0, + pv_accum_mode, ) if input_dtype == torch.bfloat16: out = out if out.dtype == torch.bfloat16 else gfx12_native.convert_f16_to_bf16(out) return _with_lse(out) - if tensor_layout == "NHD" and smooth_k and not ( + if tensor_layout == "NHD" and smooth_k and qk_quant_gran == "per_warp" and not ( value_dtype == "fp16" and q.size(-1) > 64 ): q_nhd = q.contiguous() @@ -394,11 +431,45 @@ def _with_lse(out: torch.Tensor): if value_dtype == "fp8" and head_dim not in (16, 64, 128): raise ValueError("gfx12 fp8 value path currently supports head_dim 16, 64, or 128.") - k_mean = k_nhd.mean(dim=1, keepdim=True) + use_short_nhd_fp8_prep = ( + value_dtype == "fp8" + and not is_causal + and input_dtype == torch.float16 + and qo_len == kv_len + and kv_len in (512, 1024) + and head_dim in (64, 128) + ) + value_native = None + value_scale = None + if use_short_nhd_fp8_prep: + k_mean_flat, value_native, value_scale = ( + gfx12_native.mean_and_fp8_value_nhd_short( + k_nhd, v_nhd, float(fp8_value_scale_max) + ) + ) + k_mean = k_mean_flat.unsqueeze(1) + else: + k_mean = k_nhd.mean(dim=1, keepdim=True) + k_mean_flat = k_mean.squeeze(1) use_rawq_tail = value_dtype == "fp8" and not is_causal 
and head_dim == 128 - if use_rawq_tail: + use_mixed_key_hnd = value_dtype == "fp8" and ( + ( + is_causal + and ( + (head_dim == 64 and qo_len >= 8192) + or (head_dim == 128 and qo_len >= 4096) + ) + ) + ) + use_rawq_f16_value = ( + value_dtype == "fp16" + and not is_causal + and head_dim == 64 + and qk_quant_gran == "per_warp" + ) + if use_rawq_tail or use_rawq_f16_value: q_attn = q_nhd - q_out_len = ((qo_len + 127) // 128) * 128 + q_out_len = ((qo_len + 127) // 128) * 128 if use_rawq_tail else qo_len kv_pad_len = ((kv_len + 63) // 64) * 64 - kv_len if kv_pad_len > 0: k_nhd = torch.cat([k_nhd, k_mean.expand(-1, kv_pad_len, -1, -1)], dim=1) @@ -409,22 +480,33 @@ def _with_lse(out: torch.Tensor): ) q_attn = q_nhd q_out_len = q_nhd.size(1) - k_int8 = torch.empty_like(k_nhd, dtype=torch.int8) - k_scale = torch.empty( - (k_nhd.size(0), k_nhd.size(2), (k_nhd.size(1) + 63) // 64), - device=k_nhd.device, - dtype=torch.float32, - ) - _quant_fused.quant_per_block_int8_fuse_sub_mean_cuda( - k_nhd, k_mean.squeeze(1), k_int8, k_scale, 64, 0 - ) - value_scale = None - if value_dtype == "fp8": - v_hnd_for_value = v_nhd.transpose(1, 2).contiguous() - value_scale = _gfx12_fp8_value_scale_hnd(v_hnd_for_value, fp8_value_scale_max) - value_native = gfx12_native.transpose_value_fp8_scaled_hnd( - v_hnd_for_value, value_scale + if use_mixed_key_hnd: + k_attn = k_nhd.transpose(1, 2).contiguous() + k_mean_attn = k_mean.transpose(1, 2).contiguous() + k_int8 = torch.empty_like(k_attn, dtype=torch.int8) + k_scale = torch.empty( + (k_attn.size(0), k_attn.size(1), (k_attn.size(2) + 63) // 64), + device=k_attn.device, + dtype=torch.float32, + ) + _quant_fused.quant_per_block_int8_fuse_sub_mean_cuda( + k_attn, k_mean_attn.squeeze(2), k_int8, k_scale, 64, 1 + ) + else: + k_int8 = torch.empty_like(k_nhd, dtype=torch.int8) + k_scale = torch.empty( + (k_nhd.size(0), k_nhd.size(2), (k_nhd.size(1) + 63) // 64), + device=k_nhd.device, + dtype=torch.float32, + ) + 
_quant_fused.quant_per_block_int8_fuse_sub_mean_cuda( + k_nhd, k_mean_flat, k_int8, k_scale, 64, 0 ) + if value_dtype == "fp8": + if value_native is None: + value_native, value_scale = _gfx12_fp8_value_native( + gfx12_native, v_nhd, fp8_value_scale_max, "NHD" + ) else: value_native = v_nhd if input_dtype == torch.float16 else v_nhd.to(torch.float16) out = torch.empty( @@ -445,22 +527,38 @@ def _with_lse(out: torch.Tensor): float(sm_scale), kv_len, 1, + int(use_mixed_key_hnd), ) else: - q_int8, q_scale = gfx12_native.quant_q_nhd_per_warp(q_attn) - gfx12_native.qk_int8_sv_f16_d64_native_attn( - q_int8, - k_int8, - value_native, - out, - q_scale, - k_scale, - 0, - int(is_causal), - float(sm_scale), - kv_len, - 0, - ) + if head_dim == 64 and qk_quant_gran == "per_warp": + gfx12_native.qk_rawq_int8_sv_f16_native_attn( + q_attn, + k_int8, + value_native, + out, + k_scale, + 0, + int(is_causal), + float(sm_scale), + kv_len, + pv_accum_mode, + ) + else: + q_int8, q_scale = gfx12_native.quant_q_nhd_per_warp(q_attn) + gfx12_native.qk_int8_sv_f16_d64_native_attn( + q_int8, + k_int8, + value_native, + out, + q_scale, + k_scale, + 0, + int(is_causal), + float(sm_scale), + kv_len, + 0, + pv_accum_mode, + ) out = out[:, :qo_len, :, :head_dim_og] if input_dtype == torch.bfloat16 and out.dtype != torch.bfloat16: out = gfx12_native.convert_f16_to_bf16(out.contiguous() if not out.is_contiguous() else out) @@ -485,7 +583,9 @@ def _with_lse(out: torch.Tensor): raise ValueError("gfx12 causal path currently requires q_len == kv_len.") head_dim = head_dim_og - if head_dim < 64 and (smooth_k or head_dim != 16): + if head_dim < 64 and ( + smooth_k or head_dim != 16 or value_dtype == "fp8" or q_hnd.dtype != v_hnd.dtype + ): pad = 64 - head_dim q_hnd = F.pad(q_hnd, (0, pad)) k_hnd = F.pad(k_hnd, (0, pad)) @@ -516,29 +616,50 @@ def _with_lse(out: torch.Tensor): and padded_qo_len <= 512 ) + def _quant_qk_hnd(q_src: torch.Tensor, k_src: torch.Tensor, km_src: Optional[torch.Tensor]): + if 
qk_quant_gran == "per_thread": + return per_thread_int8_triton( + q_src, k_src, km_src, BLKQ=128, + WARPQ=(16 if (head_dim == 128 and pv_accum_dtype == "fp16+fp32") else 32), + BLKK=64, WARPK=64, tensor_layout="HND" + ) + return per_warp_int8_cuda( + q_src, k_src, km_src, BLKQ=128, WARPQ=32, BLKK=64, tensor_layout="HND" + ) + if not smooth_k: if value_dtype == "fp8": - q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( - q_hnd, k_hnd, None, BLKQ=128, WARPQ=32, BLKK=64, tensor_layout="HND" + q_int8, q_scale, k_int8, k_scale = _quant_qk_hnd(q_hnd, k_hnd, None) + value_native, value_scale = _gfx12_fp8_value_native( + gfx12_native, v_hnd, fp8_value_scale_max, "HND" ) - value_scale = _gfx12_fp8_value_scale_hnd(v_hnd, fp8_value_scale_max) - value_native = gfx12_native.transpose_value_fp8_scaled_hnd(v_hnd, value_scale) out = torch.empty_like(q_hnd, dtype=torch.float16) gfx12_native.qk_int8_sv_f8_scaled_native_attn( q_int8, k_int8, value_native, out, q_scale, k_scale, value_scale, 1, int(is_causal), float(sm_scale), kv_len ) else: - out = gfx12_prepare_attn_hnd( - q_hnd, - k_hnd, - v_hnd, - int(is_causal), - 0, - int(use_raw_f16_value), - float(sm_scale), - kv_len, - ) + if qk_quant_gran == "per_warp" and q_hnd.dtype == k_hnd.dtype == v_hnd.dtype: + out = gfx12_prepare_attn_hnd( + q_hnd, + k_hnd, + v_hnd, + int(is_causal), + 0, + int(use_raw_f16_value), + float(sm_scale), + kv_len, + pv_accum_mode, + ) + else: + q_int8, q_scale, k_int8, k_scale = _quant_qk_hnd(q_hnd, k_hnd, None) + value_native = gfx12_native.transpose_value_f16_hnd(v_hnd) + out = torch.empty_like(q_hnd, dtype=torch.float16) + gfx12_native.qk_int8_sv_f16_d64_native_attn( + q_int8, k_int8, value_native, out, q_scale, k_scale, + 1, int(is_causal), float(sm_scale), kv_len, 1, + pv_accum_mode + ) else: use_rawq_hnd_fp8 = ( value_dtype == "fp8" @@ -550,7 +671,7 @@ def _with_lse(out: torch.Tensor): or padded_qo_len >= 8192 ) ) - if use_rawq_hnd_fp8: + if use_rawq_hnd_fp8 and qk_quant_gran == "per_warp": 
k_int8 = torch.empty_like(k_hnd, dtype=torch.int8) k_scale = torch.empty( (k_hnd.size(0), k_hnd.size(1), (k_hnd.size(2) + 63) // 64), @@ -560,8 +681,9 @@ def _with_lse(out: torch.Tensor): _quant_fused.quant_per_block_int8_fuse_sub_mean_cuda( k_hnd, k_mean.squeeze(2), k_int8, k_scale, 64, 1 ) - value_scale = _gfx12_fp8_value_scale_hnd(v_hnd, fp8_value_scale_max) - value_native = gfx12_native.transpose_value_fp8_scaled_hnd(v_hnd, value_scale) + value_native, value_scale = _gfx12_fp8_value_native( + gfx12_native, v_hnd, fp8_value_scale_max, "HND" + ) out = torch.empty_like( q_hnd, dtype=torch.bfloat16 if input_dtype == torch.bfloat16 else torch.float16, @@ -577,13 +699,12 @@ def _with_lse(out: torch.Tensor): out = out.transpose(1, 2).contiguous() return _with_lse(out) - q_int8, q_scale, k_int8, k_scale = per_warp_int8_cuda( - q_hnd, k_hnd, k_mean, BLKQ=128, WARPQ=32, BLKK=64, tensor_layout="HND" - ) + q_int8, q_scale, k_int8, k_scale = _quant_qk_hnd(q_hnd, k_hnd, k_mean) out = torch.empty_like(q_hnd, dtype=torch.float16) if value_dtype == "fp8": - value_scale = _gfx12_fp8_value_scale_hnd(v_hnd, fp8_value_scale_max) - value_native = gfx12_native.transpose_value_fp8_scaled_hnd(v_hnd, value_scale) + value_native, value_scale = _gfx12_fp8_value_native( + gfx12_native, v_hnd, fp8_value_scale_max, "HND" + ) gfx12_native.qk_int8_sv_f8_scaled_native_attn( q_int8, k_int8, value_native, out, q_scale, k_scale, value_scale, 1, int(is_causal), float(sm_scale), kv_len @@ -592,7 +713,8 @@ def _with_lse(out: torch.Tensor): value_native = gfx12_native.transpose_value_f16_hnd(v_hnd) gfx12_native.qk_int8_sv_f16_d64_native_attn( q_int8, k_int8, value_native, out, q_scale, k_scale, - 1, int(is_causal), float(sm_scale), kv_len, 1 + 1, int(is_causal), float(sm_scale), kv_len, 1, + pv_accum_mode ) out = out[..., :qo_len, :head_dim_og] if input_dtype == torch.bfloat16 and out.dtype != torch.bfloat16: @@ -673,27 +795,18 @@ def sageattn( return sageattn_qk_int8_pv_gfx12_native( q, k, v, 
tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, **kwargs) - elif arch == "sm75": - return sageattn_qk_int8_pv_fp16_triton(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse) - elif arch in {"sm80", "sm86", "sm87"}: + if arch == "sm80": return sageattn_qk_int8_pv_fp16_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32") + elif arch == "sm86": + return sageattn_qk_int8_pv_fp16_triton(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse) elif arch == "sm89": - if get_cuda_version() < (12, 8): - pv_accum_dtype = "fp32+fp32" - else: - # SageAttention2++ - pv_accum_dtype = "fp32+fp16" - return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype=pv_accum_dtype) + return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp16") elif arch == "sm90": return sageattn_qk_int8_pv_fp8_cuda_sm90(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp32") - elif arch in {"sm100", "sm120", "sm121"}: - if get_cuda_version() < (12, 8): - # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120. 
- pv_accum_dtype = "fp32" - else: - # SageAttention2++ - pv_accum_dtype = "fp32+fp16" - return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype=pv_accum_dtype) + elif arch == "sm120": + return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp16") # sm120 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm120. + elif arch == "sm121": + return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, qk_quant_gran="per_warp", sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32+fp16") # sm121 has accurate fp32 accumulator for fp8 mma and triton kernel is currently not usable on sm121. else: raise ValueError(f"Unsupported CUDA architecture: {arch}") @@ -1152,17 +1265,17 @@ def sageattn_qk_int8_pv_fp16_cuda( if pv_accum_dtype == 'fp32': v = v.to(torch.float16) - lse = _qattn_sm80.qk_int8_sv_f16_accum_f32_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = sm80_compile.qk_int8_sv_f16_accum_f32_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) elif pv_accum_dtype == "fp16": if smooth_v: smoothed_v, vm = sub_mean(v, tensor_layout=tensor_layout) - lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(q_int8, k_int8, smoothed_v, o, q_scale, k_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = sm80_compile.qk_int8_sv_f16_accum_f16_fuse_v_mean_attn(q_int8, k_int8, smoothed_v, o, q_scale, k_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) else: v = v.to(torch.float16) - lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn(q_int8, k_int8, v, o, q_scale, k_scale, 
_tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = sm80_compile.qk_int8_sv_f16_accum_f16_attn(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) elif pv_accum_dtype == "fp16+fp32": v = v.to(torch.float16) - lse = _qattn_sm80.qk_int8_sv_f16_accum_f16_attn_inst_buf(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = sm80_compile.qk_int8_sv_f16_accum_f16_attn_inst_buf(q_int8, k_int8, v, o, q_scale, k_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) else: raise ValueError(f"Unsupported pv_accum_dtype: {pv_accum_dtype}") @@ -1351,13 +1464,13 @@ def sageattn_qk_int8_pv_fp8_cuda( if pv_accum_dtype == "fp32": if smooth_v: - lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = sm89_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) else: - lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = sm89_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) elif pv_accum_dtype == "fp32+fp32": - lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = sm89_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) elif pv_accum_dtype == 
"fp32+fp16": - lse = _qattn_sm89.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = sm89_compile.qk_int8_sv_f8_accum_f16_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) o = o[..., :head_dim_og] @@ -1525,9 +1638,9 @@ def sageattn_qk_int8_pv_fp8_cuda_sm90( if pv_accum_dtype == "fp32": raise NotImplementedError("Please use pv_accum_dtype='fp32+fp32' for sm90.") - lse = _qattn_sm90.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = sm90_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) elif pv_accum_dtype == "fp32+fp32": - lse = _qattn_sm90.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + lse = sm90_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn_inst_buf(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) o = o[..., :head_dim_og] diff --git a/sageattention/gfx12_native_compile.py b/sageattention/gfx12_native_compile.py index abe01c77..3c717d3f 100644 --- a/sageattention/gfx12_native_compile.py +++ b/sageattention/gfx12_native_compile.py @@ -21,6 +21,7 @@ def qk_int8_sv_f16_d64_native_attn_fake_impl( sm_scale: float, valid_kv_len: int = 0, value_transposed_hnd: int = -1, + pv_accum_mode: int = -1, ) -> torch.Tensor: return _empty_lse(query) @@ -37,10 +38,27 @@ def qk_rawq_int8_sv_f8_native_attn_fake_impl( sm_scale: float, valid_kv_len: int = 0, value_transposed_hnd: int = -1, + key_hnd_layout: int = 0, ) -> torch.Tensor: return 
output +@torch.library.register_fake("sageattention_qattn_gfx12_native::qk_rawq_int8_sv_f16_native_attn") +def qk_rawq_int8_sv_f16_native_attn_fake_impl( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + output: torch.Tensor, + key_scale: torch.Tensor, + tensor_layout: int, + is_causal: int, + sm_scale: float, + valid_kv_len: int = 0, + pv_accum_mode: int = -1, +) -> torch.Tensor: + return _empty_lse(query) + + @torch.library.register_fake("sageattention_qattn_gfx12_native::qk_int8_sv_f8_scaled_native_attn") def qk_int8_sv_f8_scaled_native_attn_fake_impl( query: torch.Tensor, @@ -71,6 +89,7 @@ def qk_rawq_int8_sv_f8_scaled_native_attn_fake_impl( sm_scale: float, valid_kv_len: int = 0, value_transposed_hnd: int = -1, + key_hnd_layout: int = 0, ) -> torch.Tensor: return output @@ -85,6 +104,7 @@ def qk_int8_sv_f16_d64_prepare_attn_hnd_fake_impl( use_raw_f16_value: int, sm_scale: float, valid_kv_len: int = 0, + pv_accum_mode: int = -1, ) -> torch.Tensor: dtype = torch.bfloat16 if value_is_fp8 and query.dtype == torch.bfloat16 else torch.float16 return torch.empty(query.shape, dtype=dtype, device=query.device) @@ -118,6 +138,38 @@ def transpose_value_fp8_scaled_hnd_fake_impl(value: torch.Tensor, value_scale: t ) +@torch.library.register_fake("sageattention_qattn_gfx12_native::fp8_value_nhd_short") +def fp8_value_nhd_short_fake_impl(value: torch.Tensor, scale_max: float) -> list[torch.Tensor]: + batch, seq_len, heads, head_dim = value.shape + return [ + torch.empty((batch, heads, head_dim, seq_len), dtype=torch.uint8, device=value.device), + torch.empty((batch, heads, head_dim), dtype=torch.float32, device=value.device), + ] + + +@torch.library.register_fake("sageattention_qattn_gfx12_native::mean_nhd") +def mean_nhd_fake_impl(input: torch.Tensor) -> torch.Tensor: + return torch.empty( + (input.size(0), input.size(2), input.size(3)), + dtype=input.dtype, + device=input.device, + ) + + 
+@torch.library.register_fake("sageattention_qattn_gfx12_native::mean_and_fp8_value_nhd_short") +def mean_and_fp8_value_nhd_short_fake_impl( + key: torch.Tensor, + value: torch.Tensor, + scale_max: float, +) -> list[torch.Tensor]: + batch, seq_len, heads, head_dim = value.shape + return [ + torch.empty((batch, heads, head_dim), dtype=key.dtype, device=key.device), + torch.empty((batch, heads, head_dim, seq_len), dtype=torch.uint8, device=value.device), + torch.empty((batch, heads, head_dim), dtype=torch.float32, device=value.device), + ] + + @torch.library.register_fake("sageattention_qattn_gfx12_native::transpose_value_f16_hnd") def transpose_value_f16_hnd_fake_impl(value: torch.Tensor) -> torch.Tensor: return torch.empty( diff --git a/setup.py b/setup.py index 8a587d63..417e3392 100644 --- a/setup.py +++ b/setup.py @@ -154,6 +154,10 @@ def rocm_arches(torch): ] + LIMITED_API_FLAGS for arch in amd_arches: HIP_FLAGS.append(f"--offload-arch={arch}") + HIP_FLAGS.append(f"--rocm-path={rocm_home}") + rocm_device_lib_path = os.path.join(rocm_home, "lib", "llvm", "amdgcn", "bitcode") + if os.path.isdir(rocm_device_lib_path): + HIP_FLAGS.append(f"--rocm-device-lib-path={rocm_device_lib_path}") append_env_flags(CXX_FLAGS, "CXX_APPEND_FLAGS") append_env_flags(HIP_FLAGS, "NVCC_APPEND_FLAGS") @@ -167,7 +171,11 @@ def rocm_arches(torch): name="sageattention._qattn_gfx12_native", sources=[ "csrc/qattn/pybind_gfx12_native.cpp", - "csrc/qattn/qk_int_sv_gfx12_native.cu", + "csrc/qattn/qk_int_sv_gfx12_native_aux.cu", + "csrc/qattn/qk_int_sv_gfx12_native_prepare.cu", + "csrc/qattn/qk_int_sv_gfx12_native_attn_f16.cu", + "csrc/qattn/qk_int_sv_gfx12_native_attn_fp8.cu", + "csrc/qattn/qk_int_sv_gfx12_native_rawq_fp8.cu", ], include_dirs=include_dirs, extra_compile_args={ From 4312b799fad3fae929f25096fd0a466cbf65181d Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Tue, 19 May 2026 00:28:33 +0900 Subject: [PATCH 4/7] Improve gfx12 fp16 stream routes --- 
csrc/qattn/qk_int_sv_gfx12_native.cu | 416 ++++++++++++++++++++------- 1 file changed, 319 insertions(+), 97 deletions(-) diff --git a/csrc/qattn/qk_int_sv_gfx12_native.cu b/csrc/qattn/qk_int_sv_gfx12_native.cu index 7bb709a0..60bac298 100644 --- a/csrc/qattn/qk_int_sv_gfx12_native.cu +++ b/csrc/qattn/qk_int_sv_gfx12_native.cu @@ -2112,7 +2112,14 @@ template + bool FlatCausalSchedule = false, + bool PerThreadQK = false, + bool StaticNhdLayout = false, + bool NoKvTail = false, + bool SameQKHeads = false, + bool NoQueryTail = false, + bool PrefetchStreamVRegs = false, + bool DirectStreamProbs = false> SAGEATTN_NATIVE_2Q_WAVES_PER_EU(HeadDim, IsCausal) __global__ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2q_kernel( const QueryT* __restrict__ q, @@ -2329,7 +2336,12 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 } } - const int64_t kv_limit = IsCausal && (q_base + BR) < kv_len ? q_base + BR : kv_len; + constexpr bool ExactStaticCausalBlock = + IsCausal && StaticNhdLayout && NoKvTail && NoQueryTail && + (BR % BC == 0); + const int64_t kv_limit = + ExactStaticCausalBlock ? (q_base + BR) : + (IsCausal && (q_base + BR) < kv_len ? q_base + BR : kv_len); auto process_kv_tile = [&](const int64_t kb_base, auto apply_causal_mask_tag) { constexpr int KVecBytes = 16; constexpr int KBytesPerRow = HeadDim; @@ -2603,102 +2615,242 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 continue; } - half8_vec p_regs[QGroups][StreamGroupCols]; + if constexpr (DirectStreamProbs) { + float local_sums[QGroups]; #pragma unroll - for (int qg = 0; qg < QGroups; ++qg) { - float local_max = -FLT_MAX * 0.5f; + for (int qg = 0; qg < QGroups; ++qg) { + float local_max = -FLT_MAX * 0.5f; #pragma unroll - for (int gc = 0; gc < StreamGroupCols; ++gc) { + for (int gc = 0; gc < StreamGroupCols; ++gc) { #pragma unroll - for (int elem = 0; elem < 8; ++elem) { - const float score = qg == 0 ? 
scores0[gc][elem] : scores1[gc][elem]; - local_max = fmaxf(local_max, score); + for (int elem = 0; elem < 8; ++elem) { + const float score = qg == 0 ? scores0[gc][elem] : scores1[gc][elem]; + local_max = fmaxf(local_max, score); + } } - } - const float tile_max = fmaxf(local_max, __shfl_xor(local_max, 16, 32)); - const float old_m = m[qg]; - const float new_m = fmaxf(old_m, tile_max); - const float alpha = l[qg] == 0.0f ? 0.0f : fast_exp2(old_m - new_m); - m[qg] = new_m; - l[qg] *= alpha; + const float tile_max = fmaxf(local_max, __shfl_xor(local_max, 16, 32)); + const float old_m = m[qg]; + const float new_m = fmaxf(old_m, tile_max); + const float alpha = l[qg] == 0.0f ? 0.0f : fast_exp2(old_m - new_m); + m[qg] = new_m; + l[qg] *= alpha; + local_sums[qg] = 0.0f; - float8_vec alpha_rows; + float8_vec alpha_rows; #pragma unroll - for (int elem = 0; elem < 8; ++elem) { - alpha_rows[elem] = __shfl(alpha, row_base + elem, 32); - } + for (int elem = 0; elem < 8; ++elem) { + alpha_rows[elem] = __shfl(alpha, row_base + elem, 32); + } #pragma unroll - for (int dt = 0; dt < DTiles; ++dt) { + for (int dt = 0; dt < DTiles; ++dt) { #pragma unroll - for (int elem = 0; elem < 8; ++elem) { - out_frag[qg][dt][elem] *= alpha_rows[elem]; + for (int elem = 0; elem < 8; ++elem) { + out_frag[qg][dt][elem] *= alpha_rows[elem]; + } } } - float local_sum = 0.0f; #pragma unroll for (int gc = 0; gc < StreamGroupCols; ++gc) { - half8_vec prob_values; + if (fully_future[0][gc] && fully_future[1][gc]) { + continue; + } + half8_vec p_regs_current[QGroups]; #pragma unroll - for (int elem = 0; elem < 8; ++elem) { - const float score = qg == 0 ? scores0[gc][elem] : scores1[gc][elem]; - float prob = 0.0f; - if (!fully_future[qg][gc]) { - prob = fast_exp2(score - m[qg] + kF16SoftmaxOffset); - local_sum += prob; + for (int qg = 0; qg < QGroups; ++qg) { + half8_vec prob_values; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + const float score = qg == 0 ? 
scores0[gc][elem] : scores1[gc][elem]; + float prob = 0.0f; + if (!fully_future[qg][gc]) { + prob = fast_exp2(score - m[qg] + kF16SoftmaxOffset); + local_sums[qg] += prob; + } + prob_values[elem] = static_cast<_Float16>(prob); + } + if constexpr (PvOrderedQK) { + p_regs_current[qg] = prob_values; + } else { + p_regs_current[qg] = make_p_regs_from_tqk_prob_regs(prob_values, lane); } - prob_values[elem] = static_cast<_Float16>(prob); } - if constexpr (PvOrderedQK) { - p_regs[qg][gc] = prob_values; + const int col_tile = group_base + gc; + auto load_stream_v_regs = [&](const int dt) { + half8_vec v_regs; + if constexpr (UseLaneMajorValue) { + v_regs = make_v_regs_from_lane_major_shared( + v_lane_tile, col_tile, dt, lane); + } else if constexpr (UseTransposedValueLayout) { + v_regs = make_v_regs_from_transposed_shared( + &v_tile[0][0], col_tile, dt, lane); + } else { + v_regs = make_v_regs_from_shared( + &v_tile[0][0], col_tile, dt, lane); + } + return v_regs; + }; + auto apply_stream_pv = [&](const int dt, const half8_vec v_regs) { +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + if (fully_future[qg][gc]) { + continue; + } + PvAccumVec acc; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + acc[elem] = out_frag[qg][dt][elem]; + } + if constexpr (F16PvAccum) { + acc = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12( + p_regs_current[qg], v_regs, acc); + } else { + acc = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12( + p_regs_current[qg], v_regs, acc); + } +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + out_frag[qg][dt][elem] = acc[elem]; + } + } + }; + if constexpr (PrefetchStreamVRegs) { + half8_vec v_regs = load_stream_v_regs(0); +#pragma unroll + for (int dt = 0; dt < DTiles - 1; ++dt) { + const half8_vec next_v_regs = load_stream_v_regs(dt + 1); + apply_stream_pv(dt, v_regs); + v_regs = next_v_regs; + } + apply_stream_pv(DTiles - 1, v_regs); } else { - p_regs[qg][gc] = make_p_regs_from_tqk_prob_regs(prob_values, lane); +#pragma 
unroll + for (int dt = 0; dt < DTiles; ++dt) { + const half8_vec v_regs = load_stream_v_regs(dt); + apply_stream_pv(dt, v_regs); + } } } - l[qg] += local_sum + __shfl_xor(local_sum, 16, 32); - } - #pragma unroll - for (int gc = 0; gc < StreamGroupCols; ++gc) { - if (fully_future[0][gc] && fully_future[1][gc]) { - continue; + for (int qg = 0; qg < QGroups; ++qg) { + l[qg] += local_sums[qg] + __shfl_xor(local_sums[qg], 16, 32); } - const int col_tile = group_base + gc; + } else { + half8_vec p_regs[QGroups][StreamGroupCols]; #pragma unroll - for (int dt = 0; dt < DTiles; ++dt) { - half8_vec v_regs; - if constexpr (UseLaneMajorValue) { - v_regs = make_v_regs_from_lane_major_shared( - v_lane_tile, col_tile, dt, lane); - } else if constexpr (UseTransposedValueLayout) { - v_regs = make_v_regs_from_transposed_shared( - &v_tile[0][0], col_tile, dt, lane); - } else { - v_regs = make_v_regs_from_shared( - &v_tile[0][0], col_tile, dt, lane); + for (int qg = 0; qg < QGroups; ++qg) { + float local_max = -FLT_MAX * 0.5f; +#pragma unroll + for (int gc = 0; gc < StreamGroupCols; ++gc) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + const float score = qg == 0 ? scores0[gc][elem] : scores1[gc][elem]; + local_max = fmaxf(local_max, score); + } } + const float tile_max = fmaxf(local_max, __shfl_xor(local_max, 16, 32)); + const float old_m = m[qg]; + const float new_m = fmaxf(old_m, tile_max); + const float alpha = l[qg] == 0.0f ? 
0.0f : fast_exp2(old_m - new_m); + m[qg] = new_m; + l[qg] *= alpha; + + float8_vec alpha_rows; #pragma unroll - for (int qg = 0; qg < QGroups; ++qg) { - if (fully_future[qg][gc]) { - continue; + for (int elem = 0; elem < 8; ++elem) { + alpha_rows[elem] = __shfl(alpha, row_base + elem, 32); + } + +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + out_frag[qg][dt][elem] *= alpha_rows[elem]; } - PvAccumVec acc; + } + + float local_sum = 0.0f; +#pragma unroll + for (int gc = 0; gc < StreamGroupCols; ++gc) { + half8_vec prob_values; #pragma unroll for (int elem = 0; elem < 8; ++elem) { - acc[elem] = out_frag[qg][dt][elem]; + const float score = qg == 0 ? scores0[gc][elem] : scores1[gc][elem]; + float prob = 0.0f; + if (!fully_future[qg][gc]) { + prob = fast_exp2(score - m[qg] + kF16SoftmaxOffset); + local_sum += prob; + } + prob_values[elem] = static_cast<_Float16>(prob); } - PvAccumVec pv_acc; - if constexpr (F16PvAccum) { - pv_acc = - __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(p_regs[qg][gc], v_regs, acc); + if constexpr (PvOrderedQK) { + p_regs[qg][gc] = prob_values; } else { - pv_acc = - __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(p_regs[qg][gc], v_regs, acc); + p_regs[qg][gc] = make_p_regs_from_tqk_prob_regs(prob_values, lane); } + } + l[qg] += local_sum + __shfl_xor(local_sum, 16, 32); + } + #pragma unroll - for (int elem = 0; elem < 8; ++elem) { - out_frag[qg][dt][elem] = pv_acc[elem]; + for (int gc = 0; gc < StreamGroupCols; ++gc) { + if (fully_future[0][gc] && fully_future[1][gc]) { + continue; + } + const int col_tile = group_base + gc; + auto load_stream_v_regs = [&](const int dt) { + half8_vec v_regs; + if constexpr (UseLaneMajorValue) { + v_regs = make_v_regs_from_lane_major_shared( + v_lane_tile, col_tile, dt, lane); + } else if constexpr (UseTransposedValueLayout) { + v_regs = make_v_regs_from_transposed_shared( + &v_tile[0][0], col_tile, dt, lane); + } else { + v_regs = 
make_v_regs_from_shared( + &v_tile[0][0], col_tile, dt, lane); + } + return v_regs; + }; + auto apply_stream_pv = [&](const int dt, const half8_vec v_regs) { +#pragma unroll + for (int qg = 0; qg < QGroups; ++qg) { + if (fully_future[qg][gc]) { + continue; + } + PvAccumVec acc; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + acc[elem] = out_frag[qg][dt][elem]; + } + if constexpr (F16PvAccum) { + acc = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12( + p_regs[qg][gc], v_regs, acc); + } else { + acc = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12( + p_regs[qg][gc], v_regs, acc); + } +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + out_frag[qg][dt][elem] = acc[elem]; + } + } + }; + if constexpr (PrefetchStreamVRegs) { + half8_vec v_regs = load_stream_v_regs(0); +#pragma unroll + for (int dt = 0; dt < DTiles - 1; ++dt) { + const half8_vec next_v_regs = load_stream_v_regs(dt + 1); + apply_stream_pv(dt, v_regs); + v_regs = next_v_regs; + } + apply_stream_pv(DTiles - 1, v_regs); + } else { +#pragma unroll + for (int dt = 0; dt < DTiles; ++dt) { + const half8_vec v_regs = load_stream_v_regs(dt); + apply_stream_pv(dt, v_regs); } } } @@ -2963,17 +3115,16 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 for (int elem = 0; elem < 8; ++elem) { acc0[elem] = out_frag[0][dt][elem]; } - PvAccumVec pv_acc0; if constexpr (F16PvAccum) { - pv_acc0 = - __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(p_regs0, v_regs, acc0); + acc0 = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12( + p_regs0, v_regs, acc0); } else { - pv_acc0 = - __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(p_regs0, v_regs, acc0); + acc0 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12( + p_regs0, v_regs, acc0); } #pragma unroll for (int elem = 0; elem < 8; ++elem) { - out_frag[0][dt][elem] = pv_acc0[elem]; + out_frag[0][dt][elem] = acc0[elem]; } } const bool fully_future1 = @@ -2985,17 +3136,16 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void 
qk_int8_sv_f16_d64_native_2 for (int elem = 0; elem < 8; ++elem) { acc1[elem] = out_frag[1][dt][elem]; } - PvAccumVec pv_acc1; if constexpr (F16PvAccum) { - pv_acc1 = - __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(p_regs1, v_regs, acc1); + acc1 = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12( + p_regs1, v_regs, acc1); } else { - pv_acc1 = - __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(p_regs1, v_regs, acc1); + acc1 = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12( + p_regs1, v_regs, acc1); } #pragma unroll for (int elem = 0; elem < 8; ++elem) { - out_frag[1][dt][elem] = pv_acc1[elem]; + out_frag[1][dt][elem] = acc1[elem]; } } } @@ -3113,17 +3263,16 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 for (int elem = 0; elem < 8; ++elem) { acc[elem] = out_frag[qg][dt][elem]; } - PvAccumVec pv_acc; if constexpr (F16PvAccum) { - pv_acc = - __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(p_regs, v_regs, acc); + acc = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12( + p_regs, v_regs, acc); } else { - pv_acc = - __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(p_regs, v_regs, acc); + acc = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12( + p_regs, v_regs, acc); } #pragma unroll for (int elem = 0; elem < 8; ++elem) { - out_frag[qg][dt][elem] = pv_acc[elem]; + out_frag[qg][dt][elem] = acc[elem]; } } } @@ -3138,7 +3287,8 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 }; if constexpr (IsCausal) { - const int64_t prefix_limit = q_base < kv_limit ? q_base : kv_limit; + const int64_t prefix_limit = ExactStaticCausalBlock ? + q_base : (((q_base / BC) * BC) < kv_limit ? ((q_base / BC) * BC) : kv_limit); #pragma unroll 2 for (int64_t kb_base = 0; kb_base < prefix_limit; kb_base += BC) { process_kv_tile(kb_base, std::false_type{}); @@ -8223,8 +8373,33 @@ Tensor qk_rawq_int8_sv_f16_native_attn_gfx12( const int block_rows = q_len <= 64 ? 
64 : 128; const dim3 block(block_rows); const dim3 grid((q_len + block_rows - 1) / block_rows, q_heads, batch); -#define SAGEATTN_LAUNCH_RAWQ_F16_VALUE(HD_, HND_, BR_, CAUSAL_, QUERY_T_, F16ACC_) \ - qk_int8_sv_f16_d64_native_2q_kernel<64, HND_, BR_, false, SAGEATTN_GFX12_NATIVE_F16_TV_PAD, CAUSAL_, false, F16ACC_, QUERY_T_, true, int8_t, false, false, false, false, false, HD_><<>>( \ + const bool use_d128_short_stream = + is_causal && head_dim == 128 && block_rows == 128 && q_len <= 1024; + const bool use_direct_stream_probs = + use_d128_short_stream && q_len == 1024 && pv_accum_mode != 1; + const bool use_d128_long_stream = + is_causal && head_dim == 128 && block_rows == 128 && + q_len >= 2048 && pv_accum_mode != 1; + const bool use_d64_noncausal_stream_direct = + !is_causal && head_dim == 64 && block_rows == 128 && + q_len >= 1024 && pv_accum_mode != 1; + const bool use_f16_d64_static_long = + head_dim == 64 && (q_len == 2048 || q_len == 4096 || q_len == 8192); + const bool use_f16_d128_static_long = + head_dim == 128 && (q_len == 2048 || q_len == 4096 || q_len == 8192); + const bool use_static_nhd_no_tail = + query.scalar_type() == ScalarType::Half && + tensor_layout == kNHD && q_heads == kv_heads && + block_rows == 128 && + ((!is_causal && (q_len == 512 || q_len == 1024)) || + (is_causal && (q_len == 512 || q_len == 1024)) || + use_f16_d64_static_long || + use_f16_d128_static_long) && + q_len == padded_kv_len && kv_len == padded_kv_len && + (head_dim == 64 || head_dim == 128); + +#define SAGEATTN_LAUNCH_RAWQ_F16_VALUE(BC_, HD_, HND_, BR_, CAUSAL_, QUERY_T_, F16ACC_, STREAM_, PVORDER_, STATIC_NHD_, NO_TAIL_, SAME_HEADS_, NO_Q_TAIL_, PREFETCH_STREAM_V_, DIRECT_STREAM_PROBS_) \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ reinterpret_cast(value.data_ptr()), \ reinterpret_cast<__half*>(output.data_ptr()), \ @@ -8235,25 +8410,67 @@ Tensor qk_rawq_int8_sv_f16_native_attn_gfx12( value.stride(0), 
value.stride(tensor_layout == kNHD ? 1 : 2), value.stride(tensor_layout == kNHD ? 2 : 1), \ output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 2 : 1), \ 0, 0, key_scale.stride(0), key_scale.stride(1), \ - tensor_layout, sm_scale) + tensor_layout, static_cast(sm_scale)) +#define SAGEATTN_LAUNCH_RAWQ_F16_VALUE_DEFAULT(HD_, HND_, BR_, CAUSAL_, QUERY_T_, F16ACC_, STREAM_) \ + SAGEATTN_LAUNCH_RAWQ_F16_VALUE(64, HD_, HND_, BR_, CAUSAL_, QUERY_T_, F16ACC_, STREAM_, false, false, false, false, false, false, false) +#define SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, CAUSAL_, QUERY_T_, F16ACC_, STREAM_, PREFETCH_STREAM_V_, DIRECT_STREAM_PROBS_) \ + SAGEATTN_LAUNCH_RAWQ_F16_VALUE(BC_, HD_, false, BR_, CAUSAL_, QUERY_T_, F16ACC_, STREAM_, true, true, true, true, true, PREFETCH_STREAM_V_, DIRECT_STREAM_PROBS_) #define SAGEATTN_DISPATCH_RAWQ_F16_VALUE_FOR_HND(HD_, HND_, QUERY_T_) \ if (is_causal) { \ if (pv_accum_mode == 1) { \ - if (block_rows == 64) { SAGEATTN_LAUNCH_RAWQ_F16_VALUE(HD_, HND_, 64, true, QUERY_T_, true); } \ - else { SAGEATTN_LAUNCH_RAWQ_F16_VALUE(HD_, HND_, 128, true, QUERY_T_, true); } \ + if (block_rows == 64) { SAGEATTN_LAUNCH_RAWQ_F16_VALUE_DEFAULT(HD_, HND_, 64, true, QUERY_T_, true, false); } \ + else if ((HD_) == 128 && use_d128_short_stream) { SAGEATTN_LAUNCH_RAWQ_F16_VALUE_DEFAULT(HD_, HND_, 128, true, QUERY_T_, true, true); } \ + else { SAGEATTN_LAUNCH_RAWQ_F16_VALUE_DEFAULT(HD_, HND_, 128, true, QUERY_T_, true, false); } \ } else { \ - if (block_rows == 64) { SAGEATTN_LAUNCH_RAWQ_F16_VALUE(HD_, HND_, 64, true, QUERY_T_, false); } \ - else { SAGEATTN_LAUNCH_RAWQ_F16_VALUE(HD_, HND_, 128, true, QUERY_T_, false); } \ + if (block_rows == 64) { SAGEATTN_LAUNCH_RAWQ_F16_VALUE_DEFAULT(HD_, HND_, 64, true, QUERY_T_, false, false); } \ + else if ((HD_) == 128 && use_d128_short_stream) { SAGEATTN_LAUNCH_RAWQ_F16_VALUE_DEFAULT(HD_, HND_, 128, true, QUERY_T_, false, true); } \ + else { 
SAGEATTN_LAUNCH_RAWQ_F16_VALUE_DEFAULT(HD_, HND_, 128, true, QUERY_T_, false, false); } \ } \ } else if (pv_accum_mode == 1) { \ - if (block_rows == 64) { SAGEATTN_LAUNCH_RAWQ_F16_VALUE(HD_, HND_, 64, false, QUERY_T_, true); } \ - else { SAGEATTN_LAUNCH_RAWQ_F16_VALUE(HD_, HND_, 128, false, QUERY_T_, true); } \ + if (block_rows == 64) { SAGEATTN_LAUNCH_RAWQ_F16_VALUE_DEFAULT(HD_, HND_, 64, false, QUERY_T_, true, false); } \ + else { SAGEATTN_LAUNCH_RAWQ_F16_VALUE_DEFAULT(HD_, HND_, 128, false, QUERY_T_, true, false); } \ } else { \ - if (block_rows == 64) { SAGEATTN_LAUNCH_RAWQ_F16_VALUE(HD_, HND_, 64, false, QUERY_T_, false); } \ - else { SAGEATTN_LAUNCH_RAWQ_F16_VALUE(HD_, HND_, 128, false, QUERY_T_, false); } \ + if (block_rows == 64) { SAGEATTN_LAUNCH_RAWQ_F16_VALUE_DEFAULT(HD_, HND_, 64, false, QUERY_T_, false, false); } \ + else { SAGEATTN_LAUNCH_RAWQ_F16_VALUE_DEFAULT(HD_, HND_, 128, false, QUERY_T_, false, false); } \ + } +#define SAGEATTN_DISPATCH_RAWQ_F16_VALUE_STATIC_NHD_FOR_BC_BR_DTYPE(BC_, HD_, BR_, QUERY_T_) \ + if (is_causal && pv_accum_mode == 1) { \ + if ((HD_) == 128 && use_d128_short_stream && (BR_) == 128) { \ + SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, true, QUERY_T_, true, true, true, false); \ + } else { SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, true, QUERY_T_, true, false, false, false); } \ + } else if (is_causal) { \ + if ((HD_) == 128 && use_d128_short_stream && (BR_) == 128) { \ + if (use_direct_stream_probs) { SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, true, QUERY_T_, false, true, true, true); } \ + else { SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, true, QUERY_T_, false, true, true, false); } \ + } else if ((HD_) == 128 && (BR_) == 128 && use_d128_long_stream) { \ + SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, true, QUERY_T_, false, true, true, false); \ + } else { SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, true, QUERY_T_, false, false, false, false); } 
\ + } else if (pv_accum_mode == 1) { \ + SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, false, QUERY_T_, true, false, false, false); \ + } else { \ + if ((HD_) == 64 && use_d64_noncausal_stream_direct) { \ + SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, false, QUERY_T_, false, true, true, true); \ + } else { \ + SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, false, QUERY_T_, false, false, false, false); \ + } \ + } +#define SAGEATTN_DISPATCH_RAWQ_F16_VALUE_STATIC_NHD_FOR_BC_DTYPE(BC_, HD_, QUERY_T_) \ + SAGEATTN_DISPATCH_RAWQ_F16_VALUE_STATIC_NHD_FOR_BC_BR_DTYPE(BC_, HD_, 128, QUERY_T_) +#define SAGEATTN_DISPATCH_RAWQ_F16_VALUE_STATIC_NHD_FOR_DTYPE(HD_, QUERY_T_) \ + if constexpr ((HD_) == 128) { \ + if (q_len == 512) { \ + SAGEATTN_DISPATCH_RAWQ_F16_VALUE_STATIC_NHD_FOR_BC_DTYPE(32, HD_, QUERY_T_); \ + } else { \ + SAGEATTN_DISPATCH_RAWQ_F16_VALUE_STATIC_NHD_FOR_BC_DTYPE(64, HD_, QUERY_T_); \ + } \ + } else { \ + SAGEATTN_DISPATCH_RAWQ_F16_VALUE_STATIC_NHD_FOR_BC_DTYPE(64, HD_, QUERY_T_); \ } #define SAGEATTN_DISPATCH_RAWQ_F16_VALUE_FOR_DTYPE(QUERY_T_) \ - if (hnd_contiguous) { \ + if (use_static_nhd_no_tail) { \ + if (head_dim == 64) { SAGEATTN_DISPATCH_RAWQ_F16_VALUE_STATIC_NHD_FOR_DTYPE(64, QUERY_T_); } \ + else { SAGEATTN_DISPATCH_RAWQ_F16_VALUE_STATIC_NHD_FOR_DTYPE(128, QUERY_T_); } \ + } else if (hnd_contiguous) { \ if (head_dim == 16) { SAGEATTN_DISPATCH_RAWQ_F16_VALUE_FOR_HND(16, true, QUERY_T_); } \ else if (head_dim == 64) { SAGEATTN_DISPATCH_RAWQ_F16_VALUE_FOR_HND(64, true, QUERY_T_); } \ else { SAGEATTN_DISPATCH_RAWQ_F16_VALUE_FOR_HND(128, true, QUERY_T_); } \ @@ -8268,7 +8485,12 @@ Tensor qk_rawq_int8_sv_f16_native_attn_gfx12( SAGEATTN_DISPATCH_RAWQ_F16_VALUE_FOR_DTYPE(__hip_bfloat16); } #undef SAGEATTN_DISPATCH_RAWQ_F16_VALUE_FOR_DTYPE +#undef SAGEATTN_DISPATCH_RAWQ_F16_VALUE_STATIC_NHD_FOR_DTYPE +#undef SAGEATTN_DISPATCH_RAWQ_F16_VALUE_STATIC_NHD_FOR_BC_DTYPE +#undef 
SAGEATTN_DISPATCH_RAWQ_F16_VALUE_STATIC_NHD_FOR_BC_BR_DTYPE #undef SAGEATTN_DISPATCH_RAWQ_F16_VALUE_FOR_HND +#undef SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD +#undef SAGEATTN_LAUNCH_RAWQ_F16_VALUE_DEFAULT #undef SAGEATTN_LAUNCH_RAWQ_F16_VALUE hip_kernel_launch_check(); return new_empty_like(query, {0}, ScalarType::Float); From 5cb040da1c4b4f7d1011e9f219fc0dd8b7f7b9fb Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Tue, 19 May 2026 16:21:50 +0900 Subject: [PATCH 5/7] gfx12: optimize short fp16 native paths --- csrc/qattn/qk_int_sv_gfx12_native.cu | 58 ++++++++++++++++++++-------- sageattention/core.py | 3 +- 2 files changed, 43 insertions(+), 18 deletions(-) diff --git a/csrc/qattn/qk_int_sv_gfx12_native.cu b/csrc/qattn/qk_int_sv_gfx12_native.cu index 60bac298..e9a35750 100644 --- a/csrc/qattn/qk_int_sv_gfx12_native.cu +++ b/csrc/qattn/qk_int_sv_gfx12_native.cu @@ -2119,7 +2119,8 @@ template + bool DirectStreamProbs = false, + bool DirectPvOutFragAccum = false> SAGEATTN_NATIVE_2Q_WAVES_PER_EU(HeadDim, IsCausal) __global__ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2q_kernel( const QueryT* __restrict__ q, @@ -2697,6 +2698,15 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 if (fully_future[qg][gc]) { continue; } + if constexpr (DirectPvOutFragAccum) { + if constexpr (F16PvAccum) { + out_frag[qg][dt] = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12( + p_regs_current[qg], v_regs, out_frag[qg][dt]); + } else { + out_frag[qg][dt] = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12( + p_regs_current[qg], v_regs, out_frag[qg][dt]); + } + } else { PvAccumVec acc; #pragma unroll for (int elem = 0; elem < 8; ++elem) { @@ -2713,6 +2723,7 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 for (int elem = 0; elem < 8; ++elem) { out_frag[qg][dt][elem] = acc[elem]; } + } } }; if constexpr (PrefetchStreamVRegs) { @@ -2819,6 +2830,15 @@ 
SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 if (fully_future[qg][gc]) { continue; } + if constexpr (DirectPvOutFragAccum) { + if constexpr (F16PvAccum) { + out_frag[qg][dt] = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12( + p_regs[qg][gc], v_regs, out_frag[qg][dt]); + } else { + out_frag[qg][dt] = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12( + p_regs[qg][gc], v_regs, out_frag[qg][dt]); + } + } else { PvAccumVec acc; #pragma unroll for (int elem = 0; elem < 8; ++elem) { @@ -2835,6 +2855,7 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 for (int elem = 0; elem < 8; ++elem) { out_frag[qg][dt][elem] = acc[elem]; } + } } }; if constexpr (PrefetchStreamVRegs) { @@ -7949,7 +7970,7 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( } #endif // SAGEATTN_GFX12_BUILD_ATTN_F16 hip_kernel_launch_check(); - return new_empty_like(query, {0}, ScalarType::Float); + return output; } #if SAGEATTN_GFX12_BUILD_ATTN_F16 @@ -8398,8 +8419,8 @@ Tensor qk_rawq_int8_sv_f16_native_attn_gfx12( q_len == padded_kv_len && kv_len == padded_kv_len && (head_dim == 64 || head_dim == 128); -#define SAGEATTN_LAUNCH_RAWQ_F16_VALUE(BC_, HD_, HND_, BR_, CAUSAL_, QUERY_T_, F16ACC_, STREAM_, PVORDER_, STATIC_NHD_, NO_TAIL_, SAME_HEADS_, NO_Q_TAIL_, PREFETCH_STREAM_V_, DIRECT_STREAM_PROBS_) \ - qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ +#define SAGEATTN_LAUNCH_RAWQ_F16_VALUE(BC_, HD_, HND_, BR_, CAUSAL_, QUERY_T_, F16ACC_, STREAM_, PVORDER_, STATIC_NHD_, NO_TAIL_, SAME_HEADS_, NO_Q_TAIL_, PREFETCH_STREAM_V_, DIRECT_STREAM_PROBS_, DIRECT_PV_OUTFRAG_) \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ reinterpret_cast(value.data_ptr()), \ reinterpret_cast<__half*>(output.data_ptr()), \ @@ -8412,9 +8433,9 @@ Tensor qk_rawq_int8_sv_f16_native_attn_gfx12( 0, 0, key_scale.stride(0), key_scale.stride(1), \ tensor_layout, static_cast(sm_scale)) #define 
SAGEATTN_LAUNCH_RAWQ_F16_VALUE_DEFAULT(HD_, HND_, BR_, CAUSAL_, QUERY_T_, F16ACC_, STREAM_) \ - SAGEATTN_LAUNCH_RAWQ_F16_VALUE(64, HD_, HND_, BR_, CAUSAL_, QUERY_T_, F16ACC_, STREAM_, false, false, false, false, false, false, false) -#define SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, CAUSAL_, QUERY_T_, F16ACC_, STREAM_, PREFETCH_STREAM_V_, DIRECT_STREAM_PROBS_) \ - SAGEATTN_LAUNCH_RAWQ_F16_VALUE(BC_, HD_, false, BR_, CAUSAL_, QUERY_T_, F16ACC_, STREAM_, true, true, true, true, true, PREFETCH_STREAM_V_, DIRECT_STREAM_PROBS_) + SAGEATTN_LAUNCH_RAWQ_F16_VALUE(64, HD_, HND_, BR_, CAUSAL_, QUERY_T_, F16ACC_, STREAM_, false, false, false, false, false, false, false, false) +#define SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, CAUSAL_, QUERY_T_, F16ACC_, STREAM_, PREFETCH_STREAM_V_, DIRECT_STREAM_PROBS_, DIRECT_PV_OUTFRAG_) \ + SAGEATTN_LAUNCH_RAWQ_F16_VALUE(BC_, HD_, false, BR_, CAUSAL_, QUERY_T_, F16ACC_, STREAM_, true, true, true, true, true, PREFETCH_STREAM_V_, DIRECT_STREAM_PROBS_, DIRECT_PV_OUTFRAG_) #define SAGEATTN_DISPATCH_RAWQ_F16_VALUE_FOR_HND(HD_, HND_, QUERY_T_) \ if (is_causal) { \ if (pv_accum_mode == 1) { \ @@ -8436,22 +8457,25 @@ Tensor qk_rawq_int8_sv_f16_native_attn_gfx12( #define SAGEATTN_DISPATCH_RAWQ_F16_VALUE_STATIC_NHD_FOR_BC_BR_DTYPE(BC_, HD_, BR_, QUERY_T_) \ if (is_causal && pv_accum_mode == 1) { \ if ((HD_) == 128 && use_d128_short_stream && (BR_) == 128) { \ - SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, true, QUERY_T_, true, true, true, false); \ - } else { SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, true, QUERY_T_, true, false, false, false); } \ + SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, true, QUERY_T_, true, true, true, false, false); \ + } else { SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, true, QUERY_T_, true, false, false, false, false); } \ } else if (is_causal) { \ if ((HD_) == 128 && use_d128_short_stream && (BR_) == 128) { \ - if 
(use_direct_stream_probs) { SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, true, QUERY_T_, false, true, true, true); } \ - else { SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, true, QUERY_T_, false, true, true, false); } \ + if (use_direct_stream_probs) { \ + SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, true, QUERY_T_, false, true, true, true, true); \ + } else { \ + SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, true, QUERY_T_, false, true, true, false, true); \ + } \ } else if ((HD_) == 128 && (BR_) == 128 && use_d128_long_stream) { \ - SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, true, QUERY_T_, false, true, true, false); \ - } else { SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, true, QUERY_T_, false, false, false, false); } \ + SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, true, QUERY_T_, false, true, true, false, false); \ + } else { SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, true, QUERY_T_, false, false, false, false, false); } \ } else if (pv_accum_mode == 1) { \ - SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, false, QUERY_T_, true, false, false, false); \ + SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, false, QUERY_T_, true, false, false, false, false); \ } else { \ if ((HD_) == 64 && use_d64_noncausal_stream_direct) { \ - SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, false, QUERY_T_, false, true, true, true); \ + SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, false, QUERY_T_, false, true, true, true, false); \ } else { \ - SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, false, QUERY_T_, false, false, false, false); \ + SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, false, QUERY_T_, false, false, false, false, false); \ } \ } #define SAGEATTN_DISPATCH_RAWQ_F16_VALUE_STATIC_NHD_FOR_BC_DTYPE(BC_, HD_, QUERY_T_) \ @@ -8493,7 +8517,7 @@ Tensor qk_rawq_int8_sv_f16_native_attn_gfx12( #undef 
SAGEATTN_LAUNCH_RAWQ_F16_VALUE_DEFAULT #undef SAGEATTN_LAUNCH_RAWQ_F16_VALUE hip_kernel_launch_check(); - return new_empty_like(query, {0}, ScalarType::Float); + return output; } #endif // SAGEATTN_GFX12_BUILD_ATTN_F16 diff --git a/sageattention/core.py b/sageattention/core.py index 06b3f60e..a4b5fc58 100644 --- a/sageattention/core.py +++ b/sageattention/core.py @@ -559,7 +559,8 @@ def _with_lse(out: torch.Tensor): 0, pv_accum_mode, ) - out = out[:, :qo_len, :, :head_dim_og] + if q_out_len != qo_len or head_dim != head_dim_og: + out = out[:, :qo_len, :, :head_dim_og] if input_dtype == torch.bfloat16 and out.dtype != torch.bfloat16: out = gfx12_native.convert_f16_to_bf16(out.contiguous() if not out.is_contiguous() else out) elif input_dtype != torch.float16: From df17c2f3f4ed8102cd9bb97226d0b0348d866417 Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Tue, 19 May 2026 18:14:35 +0900 Subject: [PATCH 6/7] Port latest gfx12 native paths to ABI3 --- csrc/qattn/attn_gfx12_native.h | 16 + csrc/qattn/pybind_gfx12_native.cpp | 11 + csrc/qattn/qk_int_sv_gfx12_native.cu | 1543 +++++++++++++++++++------ sageattention/core.py | 164 ++- sageattention/gfx12_native_compile.py | 39 + 5 files changed, 1393 insertions(+), 380 deletions(-) diff --git a/csrc/qattn/attn_gfx12_native.h b/csrc/qattn/attn_gfx12_native.h index 44045b1e..f4805e2d 100644 --- a/csrc/qattn/attn_gfx12_native.h +++ b/csrc/qattn/attn_gfx12_native.h @@ -78,6 +78,14 @@ Tensor qk_rawq_int8_sv_f8_scaled_native_attn_gfx12( int64_t value_transposed_hnd = -1, int64_t key_hnd_layout = 0); +Tensor sage_fp8_nhd_short_mha_gfx12( + Tensor query, + Tensor key, + Tensor value, + int64_t is_causal, + double sm_scale, + double scale_max); + Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( Tensor query, Tensor key, @@ -101,6 +109,14 @@ std::vector fp8_value_nhd_short_gfx12( Tensor mean_nhd_gfx12(Tensor input); +Tensor mean_hnd_gfx12(Tensor input); + +std::vector prepare_qkv_hnd_smooth_f16_gfx12( + Tensor query, + Tensor 
key, + Tensor value, + Tensor key_mean); + std::vector mean_and_fp8_value_nhd_short_gfx12( Tensor key, Tensor value, diff --git a/csrc/qattn/pybind_gfx12_native.cpp b/csrc/qattn/pybind_gfx12_native.cpp index 5aa1970f..4bd44219 100644 --- a/csrc/qattn/pybind_gfx12_native.cpp +++ b/csrc/qattn/pybind_gfx12_native.cpp @@ -51,6 +51,10 @@ STABLE_TORCH_LIBRARY(sageattention_qattn_gfx12_native, m) { "int is_causal, float sm_scale, int valid_kv_len=0, " "int value_transposed_hnd=-1, int key_hnd_layout=0" ") -> Tensor"); + m.def("sage_fp8_nhd_short_mha(" + "Tensor query, Tensor key, Tensor value, int is_causal, " + "float sm_scale, float scale_max" + ") -> Tensor"); m.def("qk_int8_sv_f16_d64_prepare_attn_hnd(" "Tensor query, Tensor key, Tensor value, int is_causal, " "int value_is_fp8, int use_raw_f16_value, float sm_scale, " @@ -61,6 +65,10 @@ STABLE_TORCH_LIBRARY(sageattention_qattn_gfx12_native, m) { m.def("transpose_value_fp8_scaled_hnd(Tensor value, Tensor value_scale) -> Tensor"); m.def("fp8_value_nhd_short(Tensor value, float scale_max) -> Tensor[]"); m.def("mean_nhd(Tensor input) -> Tensor"); + m.def("mean_hnd(Tensor input) -> Tensor"); + m.def("prepare_qkv_hnd_smooth_f16(" + "Tensor query, Tensor key, Tensor value, Tensor key_mean" + ") -> Tensor[]"); m.def("mean_and_fp8_value_nhd_short(Tensor key, Tensor value, float scale_max) -> Tensor[]"); m.def("transpose_value_f16_hnd(Tensor value) -> Tensor"); m.def("convert_f16_to_bf16(Tensor input) -> Tensor"); @@ -72,12 +80,15 @@ STABLE_TORCH_LIBRARY_IMPL(sageattention_qattn_gfx12_native, CUDA, m) { m.impl("qk_rawq_int8_sv_f16_native_attn", TORCH_BOX(qk_rawq_int8_sv_f16_native_attn_gfx12)); m.impl("qk_int8_sv_f8_scaled_native_attn", TORCH_BOX(qk_int8_sv_f8_scaled_native_attn_gfx12)); m.impl("qk_rawq_int8_sv_f8_scaled_native_attn", TORCH_BOX(qk_rawq_int8_sv_f8_scaled_native_attn_gfx12)); + m.impl("sage_fp8_nhd_short_mha", TORCH_BOX(sage_fp8_nhd_short_mha_gfx12)); m.impl("qk_int8_sv_f16_d64_prepare_attn_hnd", 
TORCH_BOX(qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12)); m.impl("quant_q_nhd_per_warp", TORCH_BOX(quant_q_nhd_per_warp_gfx12)); m.impl("transpose_value_fp8_hnd", TORCH_BOX(transpose_value_fp8_hnd_gfx12)); m.impl("transpose_value_fp8_scaled_hnd", TORCH_BOX(transpose_value_fp8_scaled_hnd_gfx12)); m.impl("fp8_value_nhd_short", TORCH_BOX(fp8_value_nhd_short_gfx12)); m.impl("mean_nhd", TORCH_BOX(mean_nhd_gfx12)); + m.impl("mean_hnd", TORCH_BOX(mean_hnd_gfx12)); + m.impl("prepare_qkv_hnd_smooth_f16", TORCH_BOX(prepare_qkv_hnd_smooth_f16_gfx12)); m.impl("mean_and_fp8_value_nhd_short", TORCH_BOX(mean_and_fp8_value_nhd_short_gfx12)); m.impl("transpose_value_f16_hnd", TORCH_BOX(transpose_value_f16_hnd_gfx12)); m.impl("convert_f16_to_bf16", TORCH_BOX(convert_f16_to_bf16_gfx12)); diff --git a/csrc/qattn/qk_int_sv_gfx12_native.cu b/csrc/qattn/qk_int_sv_gfx12_native.cu index e9a35750..0201a8a1 100644 --- a/csrc/qattn/qk_int_sv_gfx12_native.cu +++ b/csrc/qattn/qk_int_sv_gfx12_native.cu @@ -90,6 +90,25 @@ Tensor from_blob_like( return torch::stable::from_blob(data, shape, strides, like.device(), dtype); } +bool same_sizes(const Tensor& a, const Tensor& b) { + if (a.dim() != b.dim()) { + return false; + } + for (int64_t i = 0; i < a.dim(); ++i) { + if (a.size(i) != b.size(i)) { + return false; + } + } + return true; +} + +hipStream_t current_hip_stream(const Tensor& tensor) { + int32_t device_index = tensor.get_device_index(); + void* stream = nullptr; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_current_cuda_stream(device_index, &stream)); + return reinterpret_cast(stream); +} + #define SAGEATTN_NATIVE_HAS_GFX12_WMMA 1 #ifndef SAGEATTN_GFX12_NATIVE_WAVES_PER_EU_MAX #define SAGEATTN_GFX12_NATIVE_WAVES_PER_EU_MAX 1 @@ -458,13 +477,14 @@ Tensor transpose_value_hnd_gfx12(Tensor value) { dim3 block(256); dim3 grid((seq_len + 127) / 128, (head_dim + 15) / 16, batch * heads); - if (value.scalar_type() == ScalarType::Half) { - transpose_value_hnd_kernel<__half, OutT, ToFp8><<>>( + const 
hipStream_t stream = current_hip_stream(value); + if (value.scalar_type() == ScalarType::Half) { + transpose_value_hnd_kernel<__half, OutT, ToFp8><<>>( reinterpret_cast(value.data_ptr()), reinterpret_cast(output.data_ptr()), batch * heads, seq_len, head_dim); } else { - transpose_value_hnd_kernel<__hip_bfloat16, OutT, ToFp8><<>>( + transpose_value_hnd_kernel<__hip_bfloat16, OutT, ToFp8><<>>( reinterpret_cast(value.data_ptr()), reinterpret_cast(output.data_ptr()), batch * heads, seq_len, head_dim); @@ -638,6 +658,96 @@ __global__ void mean_nhd_kernel( } } +template +__global__ void mean_nhd_short_kernel( + const T* __restrict__ input, + T* __restrict__ mean, + const int64_t seq_len, + const int64_t heads, + const int64_t head_dim) { + __shared__ float partial_sum[TileD * SeqLanes]; + + const int tid = threadIdx.x; + const int d_local = tid & (TileD - 1); + const int s_lane = tid / TileD; + const int64_t d_base = static_cast(blockIdx.x) * TileD; + const int64_t h = blockIdx.y; + const int64_t b = blockIdx.z; + const int64_t d = d_base + d_local; + + float local_sum = 0.0f; + if (d < head_dim) { + for (int64_t s = s_lane; s < seq_len; s += SeqLanes) { + const int64_t offset = ((b * seq_len + s) * heads + h) * head_dim + d; + local_sum += value_to_float(input[offset]); + } + } + partial_sum[tid] = local_sum; + __syncthreads(); + + if (tid < TileD) { + float sum = 0.0f; +#pragma unroll + for (int i = 0; i < SeqLanes; ++i) { + sum += partial_sum[i * TileD + tid]; + } + const int64_t mean_d = d_base + tid; + if (mean_d < head_dim) { + const float value = sum / static_cast(seq_len); + if constexpr (std::is_same::value) { + mean[(b * heads + h) * head_dim + mean_d] = value_from_float_half(value); + } else { + mean[(b * heads + h) * head_dim + mean_d] = value_from_float_bfloat16(value); + } + } + } +} + +template +__global__ void mean_hnd_kernel( + const T* __restrict__ input, + T* __restrict__ mean, + const int64_t seq_len, + const int64_t heads, + const int64_t head_dim) { 
+ constexpr int TileD = 16; + __shared__ float partial_sum[256]; + + const int tid = threadIdx.x; + const int d_local = tid & (TileD - 1); + const int s_lane = tid >> 4; + const int64_t d_base = static_cast(blockIdx.x) * TileD; + const int64_t h = blockIdx.y; + const int64_t b = blockIdx.z; + const int64_t d = d_base + d_local; + + float local_sum = 0.0f; + if (d < head_dim) { + for (int64_t s = s_lane; s < seq_len; s += 16) { + const int64_t offset = ((b * heads + h) * seq_len + s) * head_dim + d; + local_sum += value_to_float(input[offset]); + } + } + partial_sum[tid] = local_sum; + __syncthreads(); + + if (tid < TileD) { + float sum = 0.0f; + for (int i = 0; i < 16; ++i) { + sum += partial_sum[i * TileD + tid]; + } + const int64_t mean_d = d_base + tid; + if (mean_d < head_dim) { + const float value = sum / static_cast(seq_len); + if constexpr (std::is_same::value) { + mean[(b * heads + h) * head_dim + mean_d] = value_from_float_half(value); + } else { + mean[(b * heads + h) * head_dim + mean_d] = value_from_float_bfloat16(value); + } + } + } +} + __device__ __forceinline__ int32_t pack_f32x4_to_ocp_fp8( const float x0, const float x1, @@ -701,7 +811,7 @@ __global__ void mean_and_fp8_value_nhd_short_kernel( key_mean[mean_offset] = value_from_float_bfloat16(mean); } const float scale = amax / scale_max; - scale_tile[tid] = scale; + scale_tile[tid] = scale == 0.0f ? 0.0f : 1.0f / scale; value_scale[mean_offset] = scale; } } @@ -721,10 +831,10 @@ __global__ void mean_and_fp8_value_nhd_short_kernel( const float scale1 = scale_tile[load_d + 1]; const float scale2 = scale_tile[load_d + 2]; const float scale3 = scale_tile[load_d + 3]; - const float v0 = scale0 == 0.0f ? 0.0f : value_to_float(value[offset + 0]) / scale0; - const float v1 = scale1 == 0.0f ? 0.0f : value_to_float(value[offset + 1]) / scale1; - const float v2 = scale2 == 0.0f ? 0.0f : value_to_float(value[offset + 2]) / scale2; - const float v3 = scale3 == 0.0f ? 
0.0f : value_to_float(value[offset + 3]) / scale3; + const float v0 = value_to_float(value[offset + 0]) * scale0; + const float v1 = value_to_float(value[offset + 1]) * scale1; + const float v2 = value_to_float(value[offset + 2]) * scale2; + const float v3 = value_to_float(value[offset + 3]) * scale3; const uint32_t packed = static_cast(pack_f32x4_to_ocp_fp8(v0, v1, v2, v3)); *reinterpret_cast(&tile[load_s][load_d]) = packed; } else { @@ -736,7 +846,7 @@ __global__ void mean_and_fp8_value_nhd_short_kernel( const float scale = scale_tile[elem_d]; const int64_t offset = ((b * seq_len + s) * heads + h) * head_dim + d_base + elem_d; - const float v = scale == 0.0f ? 0.0f : value_to_float(value[offset]) / scale; + const float v = value_to_float(value[offset]) * scale; packed = __hip_cvt_float_to_fp8(v, __HIP_SATFINITE, __HIP_E4M3); } tile[load_s][elem_d] = packed; @@ -759,6 +869,97 @@ __global__ void mean_and_fp8_value_nhd_short_kernel( } } +template +__global__ void quant_k_nhd_fuse_sub_mean_short_kernel( + const T* __restrict__ key, + const T* __restrict__ mean, + int8_t* __restrict__ output, + float* __restrict__ scale, + const int64_t seq_len, + const int64_t heads) { + static_assert(HeadDim == 64 || HeadDim == 128, + "short NHD smooth-K quant supports D64/D128"); + static_assert(NumPackPerThread == 1 || NumPackPerThread == 2, + "short NHD smooth-K quant supports pack1/pack2"); + constexpr int BlockSize = 64; + constexpr int PackElems = 8; + constexpr int ThreadsPerToken = HeadDim / PackElems; + constexpr int IterStride = BlockSize / NumPackPerThread; + + T x_val[NumPackPerThread][PackElems]; + T mean_val[PackElems]; + float x_float[NumPackPerThread][PackElems]; + float mean_float[PackElems]; + + const int k_block = blockIdx.x; + const int h = blockIdx.y; + const int b = blockIdx.z; + const int tid = threadIdx.x; + const int local_token = tid / ThreadsPerToken; + const int d = (tid % ThreadsPerToken) * PackElems; + const int64_t token_base = static_cast(k_block) * 
BlockSize + local_token; + const int64_t mean_off = (static_cast(b) * heads + h) * HeadDim + d; + + *reinterpret_cast(&mean_val[0]) = + *reinterpret_cast(mean + mean_off); +#pragma unroll + for (int i = 0; i < PackElems; ++i) { + mean_float[i] = value_to_float(mean_val[i]); + } + + float local_amax = 0.0000001f; +#pragma unroll + for (int pack = 0; pack < NumPackPerThread; ++pack) { + const int64_t s = token_base + static_cast(pack) * IterStride; + if (s < seq_len) { + const int64_t off = ((static_cast(b) * seq_len + s) * heads + h) * HeadDim + d; + *reinterpret_cast(&x_val[pack][0]) = + *reinterpret_cast(key + off); +#pragma unroll + for (int i = 0; i < PackElems; ++i) { + const float centered = value_to_float(x_val[pack][i]) - mean_float[i]; + x_float[pack][i] = centered; + local_amax = fmaxf(local_amax, fabsf(centered)); + } + } else { +#pragma unroll + for (int i = 0; i < PackElems; ++i) { + x_float[pack][i] = 0.0f; + } + } + } + + __shared__ float shared_amax; + const float block_amax = vllm::blockReduceMax(local_amax); + if (tid == 0) { + shared_amax = block_amax; + scale[(static_cast(b) * heads + h) * ((seq_len + 63) / 64) + k_block] = + block_amax / 127.0f; + } + __syncthreads(); + const float inv_scale = 127.0f / shared_amax; + +#pragma unroll + for (int pack = 0; pack < NumPackPerThread; ++pack) { + const int64_t s = token_base + static_cast(pack) * IterStride; + if (s < seq_len) { + const int64_t off = ((static_cast(b) * seq_len + s) * heads + h) * HeadDim + d; + char4 out0; + char4 out1; + out0.x = float_to_int8_nearby_gfx12(x_float[pack][0] * inv_scale); + out0.y = float_to_int8_nearby_gfx12(x_float[pack][1] * inv_scale); + out0.z = float_to_int8_nearby_gfx12(x_float[pack][2] * inv_scale); + out0.w = float_to_int8_nearby_gfx12(x_float[pack][3] * inv_scale); + out1.x = float_to_int8_nearby_gfx12(x_float[pack][4] * inv_scale); + out1.y = float_to_int8_nearby_gfx12(x_float[pack][5] * inv_scale); + out1.z = float_to_int8_nearby_gfx12(x_float[pack][6] * 
inv_scale); + out1.w = float_to_int8_nearby_gfx12(x_float[pack][7] * inv_scale); + *reinterpret_cast(output + off) = out0; + *reinterpret_cast(output + off + 4) = out1; + } + } +} + __device__ __forceinline__ int64_t qkv_offset( const int tensor_layout, const int64_t b, @@ -800,6 +1001,105 @@ __device__ __forceinline__ int k_scale_col_per_warp(const int64_t k_idx) { return static_cast(k_idx / 64); } +__device__ __forceinline__ int wmma_f16_k_for_lane_elem( + const int lane, + const int elem); + +__device__ __forceinline__ int ceil_div_i64_to_int( + const int64_t value, + const int64_t divisor) { + return static_cast((value + divisor - 1) / divisor); +} + +__device__ __forceinline__ int q_scale_col_per_thread( + const int64_t q_idx, + const int64_t qo_len, + const int64_t q_scale_groups) { + const int q_blocks = ceil_div_i64_to_int(qo_len, 128); + const int groups_per_128 = q_blocks > 0 ? + static_cast(q_scale_groups / q_blocks) : 32; + const int warp_q = groups_per_128 >= 64 ? 16 : 32; + return static_cast((q_idx / warp_q) * 8 + (q_idx & 7)); +} + +__device__ __forceinline__ int k_scale_col_per_thread( + const int64_t k_idx, + const int64_t kv_len, + const int64_t k_scale_groups) { + const int k_blocks64 = ceil_div_i64_to_int(kv_len, 64); + const int groups_per_64 = k_blocks64 > 0 ? + static_cast(k_scale_groups / k_blocks64) : 4; + const int warp_k = groups_per_64 <= 2 ? 
128 : 64; + return static_cast((k_idx / warp_k) * 4 + ((k_idx & 7) >> 1)); +} + +template +__device__ __forceinline__ float qk_score_scale_scalar( + const float* __restrict__ q_scale, + const float* __restrict__ k_scale, + const int64_t b, + const int64_t hq, + const int64_t hkv, + const int64_t q_start, + const int64_t kb_base, + const int col_tile, + const int64_t qo_len, + const int64_t kv_len, + const int64_t qs_stride_b, + const int64_t qs_stride_h, + const int64_t ks_stride_b, + const int64_t ks_stride_h, + const float sm_scale) { + if constexpr (PerThreadQK) { + return 1.0f; + } else { + const int q_scale_idx = q_scale_col_per_warp(q_start); + const int k_scale_idx = k_scale_col_per_warp(kb_base + col_tile * 16); + return q_scale[b * qs_stride_b + hq * qs_stride_h + q_scale_idx] * + k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx] * + sm_scale * kLog2e; + } +} + +template +__device__ __forceinline__ void apply_per_thread_qk_score_scale( + float8_vec& scores, + const float* __restrict__ q_scale, + const float* __restrict__ k_scale, + const int64_t b, + const int64_t hq, + const int64_t hkv, + const int64_t q_start, + const int64_t kb_base, + const int col_tile, + const int lane, + const int64_t qo_len, + const int64_t kv_len, + const int64_t qs_stride_b, + const int64_t qs_stride_h, + const int64_t ks_stride_b, + const int64_t ks_stride_h, + const float sm_scale) { + if constexpr (PerThreadQK) { + const int64_t q_idx = q_start + (lane & 15); + const int q_scale_idx = q_scale_col_per_thread(q_idx, qo_len, qs_stride_h); + const float q_scale_local = + q_scale[b * qs_stride_b + hq * qs_stride_h + q_scale_idx] * + sm_scale * kLog2e; +#pragma unroll + for (int elem = 0; elem < 8; ++elem) { + const int k_local = PvOrdered ? 
+ wmma_f16_k_for_lane_elem(lane, elem) : + (((lane >> 4) << 3) + elem); + const int64_t k_idx = kb_base + col_tile * 16 + k_local; + const int k_scale_idx = k_scale_col_per_thread(k_idx, kv_len, ks_stride_h); + scores[elem] *= + q_scale_local * + k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + } + } +} + template __device__ __forceinline__ int64_t q_block_base_for_launch( const int64_t block_x, @@ -1607,7 +1907,8 @@ template < bool PvOrderedQK = false, typename QueryT = int8_t, bool QuantizeQuery = false, - bool SplitCausalPrefix = false> + bool SplitCausalPrefix = false, + bool PerThreadQK = false> SAGEATTN_NATIVE_WAVES_PER_EU __global__ __launch_bounds__((BlockRows / 16) * 32, 1) void qk_int8_sv_f16_d64_native_kernel( const QueryT* __restrict__ q, const int8_t* __restrict__ k, @@ -1637,7 +1938,8 @@ SAGEATTN_NATIVE_WAVES_PER_EU __global__ __launch_bounds__((BlockRows / 16) * 32, const int64_t ks_stride_b, const int64_t ks_stride_h, const int tensor_layout, - const float sm_scale) { + const float sm_scale, + const bool per_thread_qk = false) { constexpr int HeadDim = 64; constexpr int BR = BlockRows; constexpr int RM = 16; @@ -1724,9 +2026,13 @@ SAGEATTN_NATIVE_WAVES_PER_EU __global__ __launch_bounds__((BlockRows / 16) * 32, q_stride_b, q_stride_n, q_stride_h, inv_q_scale); } } else { - const int q_scale_idx = q_scale_col_per_warp(q_start); - qs = q_scale[b * qs_stride_b + hq * qs_stride_h + q_scale_idx] * - sm_scale * kLog2e; + if constexpr (PerThreadQK) { + qs = 1.0f; + } else { + const int q_scale_idx = q_scale_col_per_warp(q_start); + qs = q_scale[b * qs_stride_b + hq * qs_stride_h + q_scale_idx] * + sm_scale * kLog2e; + } if constexpr (UseRawPreparedQ) { #pragma unroll for (int dt = 0; dt < DTiles; ++dt) { @@ -1825,9 +2131,9 @@ SAGEATTN_NATIVE_WAVES_PER_EU __global__ __launch_bounds__((BlockRows / 16) * 32, float local_max = -FLT_MAX * 0.5f; #pragma unroll for (int col_tile = 0; col_tile < ColTiles; ++col_tile) { - const int k_scale_idx = 
k_scale_col_per_warp(kb_base + col_tile * BK); - const float score_scale = qs * - k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + const float score_scale = qk_score_scale_scalar(q_scale, k_scale, b, hq, hkv, q_start, kb_base, col_tile, + qo_len, kv_len, qs_stride_b, qs_stride_h, ks_stride_b, ks_stride_h, + sm_scale); float8_vec scores; const int64_t k_col_start = kb_base + static_cast(col_tile) * BK; const bool fully_future = @@ -1852,6 +2158,9 @@ SAGEATTN_NATIVE_WAVES_PER_EU __global__ __launch_bounds__((BlockRows / 16) * 32, scores = compute_tqk_score_regs( &k_tile[0][0], q_frag, col_tile, score_scale); } + apply_per_thread_qk_score_scale(scores, q_scale, k_scale, b, hq, hkv, q_start, kb_base, col_tile, lane, + qo_len, kv_len, qs_stride_b, qs_stride_h, ks_stride_b, ks_stride_h, + sm_scale); if constexpr (ApplyCausalMask) { if (!fully_future && k_col_start + BK > q_start) { if constexpr (PvOrderedQK) { @@ -1957,9 +2266,9 @@ SAGEATTN_NATIVE_WAVES_PER_EU __global__ __launch_bounds__((BlockRows / 16) * 32, float local_max = -FLT_MAX * 0.5f; #pragma unroll for (int col_tile = 0; col_tile < ColTiles; ++col_tile) { - const int k_scale_idx = k_scale_col_per_warp(kb_base + col_tile * BK); - const float score_scale = qs * - k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + const float score_scale = qk_score_scale_scalar(q_scale, k_scale, b, hq, hkv, q_start, kb_base, col_tile, + qo_len, kv_len, qs_stride_b, qs_stride_h, ks_stride_b, ks_stride_h, + sm_scale); const int64_t k_col_start = kb_base + static_cast(col_tile) * BK; float8_vec scores; if constexpr (QuantizeQuery || UseRawPreparedQ) { @@ -1969,6 +2278,9 @@ SAGEATTN_NATIVE_WAVES_PER_EU __global__ __launch_bounds__((BlockRows / 16) * 32, scores = compute_tqk_score_regs( &k_tile[0][0], q_frag, col_tile, score_scale); } + apply_per_thread_qk_score_scale(scores, q_scale, k_scale, b, hq, hkv, q_start, kb_base, col_tile, lane, + qo_len, kv_len, qs_stride_b, qs_stride_h, ks_stride_b, ks_stride_h, + 
sm_scale); if (k_col_start + BK > kv_len) { apply_tqk_kv_tail_mask(scores, kv_len, kb_base, col_tile, lane); } @@ -1995,9 +2307,9 @@ SAGEATTN_NATIVE_WAVES_PER_EU __global__ __launch_bounds__((BlockRows / 16) * 32, float local_sum = 0.0f; #pragma unroll for (int col_tile = 0; col_tile < ColTiles; ++col_tile) { - const int k_scale_idx = k_scale_col_per_warp(kb_base + col_tile * BK); - const float score_scale = qs * - k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + const float score_scale = qk_score_scale_scalar(q_scale, k_scale, b, hq, hkv, q_start, kb_base, col_tile, + qo_len, kv_len, qs_stride_b, qs_stride_h, ks_stride_b, ks_stride_h, + sm_scale); const int64_t k_col_start = kb_base + static_cast(col_tile) * BK; float8_vec scores; if constexpr (QuantizeQuery || UseRawPreparedQ) { @@ -2007,6 +2319,9 @@ SAGEATTN_NATIVE_WAVES_PER_EU __global__ __launch_bounds__((BlockRows / 16) * 32, scores = compute_tqk_score_regs( &k_tile[0][0], q_frag, col_tile, score_scale); } + apply_per_thread_qk_score_scale(scores, q_scale, k_scale, b, hq, hkv, q_start, kb_base, col_tile, lane, + qo_len, kv_len, qs_stride_b, qs_stride_h, ks_stride_b, ks_stride_h, + sm_scale); if (k_col_start + BK > kv_len) { apply_tqk_kv_tail_mask(scores, kv_len, kb_base, col_tile, lane); } @@ -2151,7 +2466,8 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 const int64_t ks_stride_b, const int64_t ks_stride_h, const int tensor_layout, - const float sm_scale) { + const float sm_scale, + const bool per_thread_qk = false) { static_assert(HeadDim == 16 || HeadDim == 64 || HeadDim == 128, "native gfx12 fp16 2q kernel supports D16/D64/D128."); constexpr int BR = BlockRows; @@ -2199,6 +2515,8 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 "lane-major fp16 key staging requires prepared HND BC64 PvOrderedQK."); static_assert(!StreamColTiles || (BlockCols <= 128 && (QuantizeQuery || HndContiguous)), "streaming col-tile softmax is 
specialized for raw/quantized QK."); + static_assert(!NoQueryTail || StaticNhdLayout, + "full-query fp16 path requires a static dispatch."); __shared__ int8_t k_tile[UseLaneMajorKey ? 1 : BC] [SharedQKHeadStride]; __shared__ uint2 k_lane_tile[UseLaneMajorKey ? (ColTiles * DTiles * 32) : 1]; @@ -2211,7 +2529,6 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 const int wave = tid >> 5; const int row_base = (lane >> 4) << 3; const int col = lane & 15; - const int64_t q_block_count = (qo_len + BR - 1) / BR; const int64_t hb_count = num_qo_heads * batch_size; for (;;) { int64_t q_block = static_cast(blockIdx.x); @@ -2229,11 +2546,9 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 return; } - const int64_t hkv = hq / (num_qo_heads / num_kv_heads); - const int64_t q_head_base = b * q_stride_b + hq * q_stride_h; + const int64_t hkv = SameQKHeads ? hq : hq / (num_qo_heads / num_kv_heads); const int64_t k_head_base = b * k_stride_b + hkv * k_stride_h; const int64_t v_head_base = b * v_stride_b + hkv * v_stride_h; - const int64_t o_head_base = b * o_stride_b + hq * o_stride_h; int64_t q_start[QGroups]; float qs[QGroups]; #pragma unroll @@ -2257,8 +2572,17 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 const int row = elem_base / HeadDim; const int d = elem_base - row * HeadDim; const int64_t q_idx = q_base + local_q_row_base + row; - if (q_idx < qo_len) { - const int64_t q_off = qkv_offset_dispatch( + if constexpr (NoQueryTail) { + const int64_t q_off = qkv_offset_dispatch( + tensor_layout, b, hq, q_idx, d, q_stride_b, q_stride_n, q_stride_h); + const uint4 raw = *reinterpret_cast(q + q_off); + const QueryT* values = reinterpret_cast(&raw); +#pragma unroll + for (int i = 0; i < QPackElems; ++i) { + local_q_amax = fmaxf(local_q_amax, fabsf(value_to_float(values[i]))); + } + } else if (q_idx < qo_len) { + const int64_t q_off = qkv_offset_dispatch( tensor_layout, b, hq, 
q_idx, d, q_stride_b, q_stride_n, q_stride_h); const uint4 raw = *reinterpret_cast(q + q_off); const QueryT* values = reinterpret_cast(&raw); @@ -2277,7 +2601,8 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 const int64_t qg_start = q_start[qg]; #pragma unroll for (int dt = 0; dt < DTiles; ++dt) { - q_regs[qg][dt] = pack_quant_q_i8_wmma_b_regs( + q_regs[qg][dt] = + pack_quant_q_i8_wmma_b_regs( q, tensor_layout, lane, b, hq, qg_start, qo_len, dt * BK, q_stride_b, q_stride_n, q_stride_h, inv_q_scale); } @@ -2285,16 +2610,21 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 } else { #pragma unroll for (int qg = 0; qg < QGroups; ++qg) { - const int q_scale_idx = q_scale_col_per_warp(q_start[qg]); - qs[qg] = q_scale[b * qs_stride_b + hq * qs_stride_h + q_scale_idx] * - sm_scale * kLog2e; + if constexpr (PerThreadQK) { + qs[qg] = 1.0f; + } else { + const int q_scale_idx = q_scale_col_per_warp(q_start[qg]); + qs[qg] = q_scale[b * qs_stride_b + hq * qs_stride_h + q_scale_idx] * + sm_scale * kLog2e; + } } if constexpr (UseRawPreparedQ) { #pragma unroll for (int qg = 0; qg < QGroups; ++qg) { #pragma unroll for (int dt = 0; dt < DTiles; ++dt) { - q_regs[qg][dt] = pack_q_i8_wmma_b_regs( + q_regs[qg][dt] = + pack_q_i8_wmma_b_regs( reinterpret_cast(q), tensor_layout, lane, b, hq, q_start[qg], qo_len, dt * BK, q_stride_b, q_stride_n, q_stride_h); } @@ -2308,13 +2638,8 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 for (int qg = 0; qg < QGroups; ++qg) { #pragma unroll for (int dt = 0; dt < DTiles; ++dt) { - int64_t q_off; - if constexpr (HndContiguous) { - q_off = q_head_base + q_start[qg] * HeadDim + dt * BK; - } else { - q_off = qkv_offset( - tensor_layout, b, hq, q_start[qg], dt * BK, q_stride_b, q_stride_n, q_stride_h); - } + const int64_t q_off = qkv_offset_dispatch( + tensor_layout, b, hq, q_start[qg], dt * BK, q_stride_b, q_stride_n, q_stride_h); 
rocwmma::load_matrix_sync(q_frag[qg][dt], q + q_off, static_cast(q_stride_n)); } } @@ -2356,7 +2681,7 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 const int elem_base = pack * PackElems; const int n = elem_base / HeadDim; const int d = elem_base - n * HeadDim; - const int64_t k_off = qkv_offset_dispatch( + const int64_t k_off = qkv_offset_dispatch( tensor_layout, b, hkv, kb_base + n, d, k_stride_b, k_stride_n, k_stride_h); const uint4 raw = *reinterpret_cast(k + k_off); const KeyT* values = reinterpret_cast(&raw); @@ -2378,7 +2703,7 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 const int elem_base = pack * PackElems; const int n = elem_base / HeadDim; const int d = elem_base - n * HeadDim; - const int64_t k_off = qkv_offset_dispatch( + const int64_t k_off = qkv_offset_dispatch( tensor_layout, b, hkv, kb_base + n, d, k_stride_b, k_stride_n, k_stride_h); const uint4 raw = *reinterpret_cast(k + k_off); const KeyT* values = reinterpret_cast(&raw); @@ -2411,13 +2736,8 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 for (int vec = tid; vec < BC * KVecsPerRow; vec += Threads) { const int n = vec / KVecsPerRow; const int d = (vec - n * KVecsPerRow) * KVecBytes; - int64_t k_off; - if constexpr (HndContiguous) { - k_off = k_head_base + (kb_base + n) * HeadDim + d; - } else { - k_off = qkv_offset( - tensor_layout, b, hkv, kb_base + n, d, k_stride_b, k_stride_n, k_stride_h); - } + const int64_t k_off = qkv_offset_dispatch( + tensor_layout, b, hkv, kb_base + n, d, k_stride_b, k_stride_n, k_stride_h); *reinterpret_cast(&k_tile[n][d]) = *reinterpret_cast(reinterpret_cast(k) + k_off); } @@ -2425,9 +2745,11 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 float prepared_k_scale_tile = k_scale_tile; if constexpr (!QuantizeKey && BC <= 64) { - const int k_scale_idx = k_scale_col_per_warp(kb_base); - prepared_k_scale_tile = - k_scale[b * 
ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + if constexpr (!PerThreadQK) { + const int k_scale_idx = k_scale_col_per_warp(kb_base); + prepared_k_scale_tile = + k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + } } auto stage_value_tile = [&]() { @@ -2459,7 +2781,7 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 for (int vec = tid; vec < BC * VVecsPerRow; vec += Threads) { const int n = vec / VVecsPerRow; const int d_base = (vec - n * VVecsPerRow) * VElemsPerVec; - const int64_t v_off = qkv_offset_dispatch( + const int64_t v_off = qkv_offset_dispatch( tensor_layout, b, hkv, kb_base + n, d_base, v_stride_b, v_stride_n, v_stride_h); const uint4 packed = *reinterpret_cast(v + v_off); const __half* vals = reinterpret_cast(&packed); @@ -2493,7 +2815,7 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 for (int vec = tid; vec < BC * VVecsPerRow; vec += Threads) { const int n = vec / VVecsPerRow; const int d = (vec - n * VVecsPerRow) * VElemsPerVec; - const int64_t v_off = qkv_offset_dispatch( + const int64_t v_off = qkv_offset_dispatch( tensor_layout, b, hkv, kb_base + n, d, v_stride_b, v_stride_n, v_stride_h); const uint4 packed = *reinterpret_cast(v + v_off); const __half* vals = reinterpret_cast(&packed); @@ -2508,7 +2830,7 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 for (int vec = tid; vec < BC * VVecsPerRow; vec += Threads) { const int n = vec / VVecsPerRow; const int d = (vec - n * VVecsPerRow) * VElemsPerVec; - const int64_t v_off = qkv_offset_dispatch( + const int64_t v_off = qkv_offset_dispatch( tensor_layout, b, hkv, kb_base + n, d, v_stride_b, v_stride_n, v_stride_h); *reinterpret_cast(&v_tile[n][d]) = *reinterpret_cast(v + v_off); @@ -2554,9 +2876,11 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 float k_scale_local = prepared_k_scale_tile; if constexpr (!QuantizeKey && BC > 64) { - const int 
k_scale_idx = k_scale_col_per_warp(k_col_start); - k_scale_local = - k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + if constexpr (!PerThreadQK) { + const int k_scale_idx = k_scale_col_per_warp(k_col_start); + k_scale_local = + k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + } } if constexpr (UseLaneMajorKey) { compute_tqk_score_regs_raw_kq_2_lane_shared_key( @@ -2571,6 +2895,12 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 !fully_future0, !fully_future1, scores0[gc], scores1[gc]); } + apply_per_thread_qk_score_scale(scores0[gc], q_scale, k_scale, b, hq, hkv, q_start[0], kb_base, + col_tile, lane, qo_len, kv_len, qs_stride_b, qs_stride_h, + ks_stride_b, ks_stride_h, sm_scale); + apply_per_thread_qk_score_scale(scores1[gc], q_scale, k_scale, b, hq, hkv, q_start[1], kb_base, + col_tile, lane, qo_len, kv_len, qs_stride_b, qs_stride_h, + ks_stride_b, ks_stride_h, sm_scale); if constexpr (ApplyCausalMask) { if (fully_future0) { #pragma unroll @@ -2605,7 +2935,7 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 } } } - if (k_col_start + BK > kv_len) { + if constexpr (!NoKvTail) if (k_col_start + BK > kv_len) { apply_tqk_kv_tail_mask( scores0[gc], kv_len, kb_base, col_tile, lane); apply_tqk_kv_tail_mask( @@ -2905,9 +3235,11 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 } else { float k_scale_local = prepared_k_scale_tile; if constexpr (!QuantizeKey && BC > 64) { - const int k_scale_idx = k_scale_col_per_warp(k_col_start); - k_scale_local = - k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + if constexpr (!PerThreadQK) { + const int k_scale_idx = k_scale_col_per_warp(k_col_start); + k_scale_local = + k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + } } if constexpr (UseLaneMajorKey) { compute_tqk_score_regs_raw_kq_2_lane_shared_key( @@ -2920,6 +3252,12 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void 
qk_int8_sv_f16_d64_native_2 qs[0] * k_scale_local, qs[1] * k_scale_local, !fully_future0, !fully_future1, scores0, scores1); } + apply_per_thread_qk_score_scale(scores0, q_scale, k_scale, b, hq, hkv, q_start[0], kb_base, + col_tile, lane, qo_len, kv_len, qs_stride_b, qs_stride_h, + ks_stride_b, ks_stride_h, sm_scale); + apply_per_thread_qk_score_scale(scores1, q_scale, k_scale, b, hq, hkv, q_start[1], kb_base, + col_tile, lane, qo_len, kv_len, qs_stride_b, qs_stride_h, + ks_stride_b, ks_stride_h, sm_scale); if constexpr (ApplyCausalMask) { if (fully_future0) { #pragma unroll @@ -2955,7 +3293,7 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 } } } - if (k_col_start + BK > kv_len) { + if constexpr (!NoKvTail) if (k_col_start + BK > kv_len) { apply_tqk_kv_tail_mask( scores0, kv_len, kb_base, col_tile, lane); apply_tqk_kv_tail_mask( @@ -3036,14 +3374,19 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 } else { float k_scale_local = prepared_k_scale_tile; if constexpr (!QuantizeKey && BC > 64) { - const int k_scale_idx = k_scale_col_per_warp(k_col_start); - k_scale_local = - k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + if constexpr (!PerThreadQK) { + const int k_scale_idx = k_scale_col_per_warp(k_col_start); + k_scale_local = + k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + } } const float score_scale = qs[qg] * k_scale_local; scores = compute_tqk_score_regs( &k_tile[0][0], q_frag[qg], col_tile, score_scale); + apply_per_thread_qk_score_scale(scores, q_scale, k_scale, b, hq, hkv, q_start[qg], kb_base, + col_tile, lane, qo_len, kv_len, qs_stride_b, qs_stride_h, + ks_stride_b, ks_stride_h, sm_scale); if constexpr (ApplyCausalMask) { const bool needs_causal_mask = k_col_start + BK > q_start[qg]; if (needs_causal_mask) { @@ -3053,7 +3396,7 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 } } } - if (k_col_start + BK > kv_len) { + if 
constexpr (!NoKvTail) if (k_col_start + BK > kv_len) { apply_tqk_kv_tail_mask(scores, kv_len, kb_base, col_tile, lane); } score_cache[col_tile] = scores; @@ -3180,9 +3523,11 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 for (int col_tile = 0; col_tile < ColTiles; ++col_tile) { float k_scale_local = prepared_k_scale_tile; if constexpr (!QuantizeKey && BC > 64) { - const int k_scale_idx = k_scale_col_per_warp(kb_base + col_tile * BK); - k_scale_local = - k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + if constexpr (!PerThreadQK) { + const int k_scale_idx = k_scale_col_per_warp(kb_base + col_tile * BK); + k_scale_local = + k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + } } const float score_scale = qs[qg] * k_scale_local; float8_vec scores; @@ -3194,13 +3539,16 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 compute_tqk_score_regs( &k_tile[0][0], q_frag[qg], col_tile, score_scale); } + apply_per_thread_qk_score_scale(scores, q_scale, k_scale, b, hq, hkv, q_start[qg], kb_base, + col_tile, lane, qo_len, kv_len, qs_stride_b, qs_stride_h, + ks_stride_b, ks_stride_h, sm_scale); if constexpr (ApplyCausalMask) { if (needs_causal_mask) { apply_tqk_causal_mask( scores, static_cast(q_start[qg]), static_cast(kb_base), col_tile, lane); } } - if (kb_base + static_cast(col_tile) * BK + BK > kv_len) { + if constexpr (!NoKvTail) if (kb_base + static_cast(col_tile) * BK + BK > kv_len) { apply_tqk_kv_tail_mask(scores, kv_len, kb_base, col_tile, lane); } #pragma unroll @@ -3234,9 +3582,11 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 for (int col_tile = 0; col_tile < ColTiles; ++col_tile) { float k_scale_local = prepared_k_scale_tile; if constexpr (!QuantizeKey && BC > 64) { - const int k_scale_idx = k_scale_col_per_warp(kb_base + col_tile * BK); - k_scale_local = - k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + if constexpr 
(!PerThreadQK) { + const int k_scale_idx = k_scale_col_per_warp(kb_base + col_tile * BK); + k_scale_local = + k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + } } const float score_scale = qs[qg] * k_scale_local; float8_vec scores; @@ -3248,13 +3598,16 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 compute_tqk_score_regs( &k_tile[0][0], q_frag[qg], col_tile, score_scale); } + apply_per_thread_qk_score_scale(scores, q_scale, k_scale, b, hq, hkv, q_start[qg], kb_base, + col_tile, lane, qo_len, kv_len, qs_stride_b, qs_stride_h, + ks_stride_b, ks_stride_h, sm_scale); if constexpr (ApplyCausalMask) { if (needs_causal_mask) { apply_tqk_causal_mask( scores, static_cast(q_start[qg]), static_cast(kb_base), col_tile, lane); } } - if (kb_base + static_cast(col_tile) * BK + BK > kv_len) { + if constexpr (!NoKvTail) if (kb_base + static_cast(col_tile) * BK + BK > kv_len) { apply_tqk_kv_tail_mask(scores, kv_len, kb_base, col_tile, lane); } half8_vec prob_values; @@ -3343,22 +3696,20 @@ SAGEATTN_NATIVE_F16_2Q_LAUNCH_BOUNDS(BlockRows) void qk_int8_sv_f16_d64_native_2 const int64_t q_idx1 = q_idx0 + 1; const float value0 = static_cast(out_frag[qg][dt][elem]) * inv_l_rows[elem]; const float value1 = static_cast(out_frag[qg][dt][elem + 1]) * inv_l_rows[elem + 1]; - int64_t o_off0; - int64_t o_off1; - if constexpr (HndContiguous) { - o_off0 = o_head_base + q_idx0 * HeadDim + d; - o_off1 = o_head_base + q_idx1 * HeadDim + d; - } else { - o_off0 = qkv_offset( - tensor_layout, b, hq, q_idx0, d, o_stride_b, o_stride_n, o_stride_h); - o_off1 = qkv_offset( - tensor_layout, b, hq, q_idx1, d, o_stride_b, o_stride_n, o_stride_h); - } - if (q_idx0 < qo_len) { + const int64_t o_off0 = qkv_offset_dispatch( + tensor_layout, b, hq, q_idx0, d, o_stride_b, o_stride_n, o_stride_h); + const int64_t o_off1 = qkv_offset_dispatch( + tensor_layout, b, hq, q_idx1, d, o_stride_b, o_stride_n, o_stride_h); + if constexpr (NoQueryTail) { store_half(output, 
o_off0, value0); - } - if (q_idx1 < qo_len) { store_half(output, o_off1, value1); + } else { + if (q_idx0 < qo_len) { + store_half(output, o_off0, value0); + } + if (q_idx1 < qo_len) { + store_half(output, o_off1, value1); + } } } } @@ -3392,7 +3743,8 @@ template + bool NoQueryTail = false, + bool InvLRowsEpilogue = false> SAGEATTN_NATIVE_2Q_WAVES_PER_EU(HeadDim, IsCausal) __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_f8_native_2q_kernel( const QueryT* __restrict__ q, @@ -3424,7 +3776,8 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ const int64_t ks_stride_b, const int64_t ks_stride_h, const int tensor_layout, - const float sm_scale) { + const float sm_scale, + const bool per_thread_qk = false) { static_assert(HeadDim == 16 || HeadDim == 64 || HeadDim == 128, "native gfx12 fp8 2q kernel supports D16/D64/D128."); static_assert(BlockCols == 16 || BlockCols == 32 || BlockCols == 64 || @@ -3490,6 +3843,8 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ (HeadDim == 128 || (HeadDim == 64 && IsCausal && (UsePrepackedLaneMajorK || UsePrepackedLaneMajorValue))); + static_assert(!PrepackedLaneMajorKV || HndContiguous, + "prepacked transposed fp8 K/V requires contiguous HND tensors."); static_assert(!QuantizeKeyValue || ((HeadDim == 64 || HeadDim == 128) && BlockCols == 64 && !ValueTransposed), @@ -3509,8 +3864,8 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ static_assert(ValueTiles == 1 || ValueTiles == 4 || ValueTiles == 8, "native fp8 2q stores one D16, D64, or D128 value slice per launch."); static_assert(ValueTileBase + ValueTiles <= DTiles, "invalid fp8 value tile slice."); - static_assert(!NoQueryTail || (StaticNhdLayout && !IsCausal), - "full-query fp8 path requires a static non-causal dispatch."); + static_assert(!NoQueryTail || StaticNhdLayout, + "full-query fp8 path requires a static dispatch."); __shared__ int8_t 
k_tile[UsePrepackedLaneMajorK ? 1 : BC][SharedHeadStride]; __shared__ uint8_t v_tile[UsePrepackedLaneMajorValue ? 1 : SharedValueRows] @@ -3604,9 +3959,13 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ } else { #pragma unroll for (int qg = 0; qg < QGroups; ++qg) { - const int q_scale_idx = q_scale_col_per_warp(q_start[qg]); - qs[qg] = q_scale[b * qs_stride_b + hq * qs_stride_h + q_scale_idx] * - sm_scale * kLog2e; + if constexpr (PerThreadQK) { + qs[qg] = 1.0f; + } else { + const int q_scale_idx = q_scale_col_per_warp(q_start[qg]); + qs[qg] = q_scale[b * qs_stride_b + hq * qs_stride_h + q_scale_idx] * + sm_scale * kLog2e; + } } if constexpr (UsePrepackedLaneMajorK) { #pragma unroll @@ -3809,9 +4168,11 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ } float prepared_k_scale_tile = k_scale_tile; if constexpr (!QuantizeKeyValue && BC <= 64) { - const int k_scale_idx = k_scale_col_per_warp(kb_base); - prepared_k_scale_tile = - k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + if constexpr (!PerThreadQK) { + const int k_scale_idx = k_scale_col_per_warp(kb_base); + prepared_k_scale_tile = + k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + } } #pragma unroll for (int sc = 0; sc < StreamCols; ++sc) { @@ -3848,13 +4209,15 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ } } float k_scale_local = k_scale_tile; - if constexpr (!QuantizeKeyValue && BC <= 64) { + if constexpr (!QuantizeKeyValue && BC <= 64) { if constexpr (!PerThreadQK) { k_scale_local = prepared_k_scale_tile; } } else if constexpr (!QuantizeKeyValue) { - const int k_scale_idx = k_scale_col_per_warp(k_col_start); - k_scale_local = k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + if constexpr (!PerThreadQK) { + const int k_scale_idx = k_scale_col_per_warp(k_col_start); + k_scale_local = k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + } } if constexpr 
(QuantizeQuery || UsePrepackedLaneMajorK) { i32x8_vec score_acc[QGroups]; @@ -3897,6 +4260,9 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ for (int elem = 0; elem < 8; ++elem) { scores[elem] = static_cast(score_acc[qg][elem]) * score_scale; } + apply_per_thread_qk_score_scale(scores, q_scale, k_scale, b, hq, hkv, q_start[qg], kb_base, + col_tile, lane, qo_len, kv_len, qs_stride_b, qs_stride_h, + ks_stride_b, ks_stride_h, sm_scale); if constexpr (ApplyCausalMask) { const bool needs_causal_mask = k_col_start + BK > q_start[qg]; if (needs_causal_mask) { @@ -3951,6 +4317,9 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ for (int elem = 0; elem < 8; ++elem) { scores[elem] = static_cast(score_rm[elem]) * score_scale; } + apply_per_thread_qk_score_scale(scores, q_scale, k_scale, b, hq, hkv, q_start[qg], kb_base, + col_tile, lane, qo_len, kv_len, qs_stride_b, qs_stride_h, + ks_stride_b, ks_stride_h, sm_scale); if constexpr (ApplyCausalMask) { const bool needs_causal_mask = k_col_start + BK > q_start[qg]; if (needs_causal_mask) { @@ -4095,9 +4464,11 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ float local_max[QGroups]; float prepared_k_scale_tile = k_scale_tile; if constexpr (!QuantizeKeyValue && BC <= 64) { - const int k_scale_idx = k_scale_col_per_warp(kb_base); - prepared_k_scale_tile = - k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + if constexpr (!PerThreadQK) { + const int k_scale_idx = k_scale_col_per_warp(kb_base); + prepared_k_scale_tile = + k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + } } #pragma unroll for (int qg = 0; qg < QGroups; ++qg) { @@ -4135,10 +4506,14 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ } float k_scale_local = k_scale_tile; if constexpr (!QuantizeKeyValue && BC <= 64) { - k_scale_local = prepared_k_scale_tile; + if constexpr (!PerThreadQK) { + k_scale_local = 
prepared_k_scale_tile; + } } else if constexpr (!QuantizeKeyValue) { - const int k_scale_idx = k_scale_col_per_warp(k_col_start); - k_scale_local = k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + if constexpr (!PerThreadQK) { + const int k_scale_idx = k_scale_col_per_warp(k_col_start); + k_scale_local = k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + } } i32x8_vec score_acc[QGroups]; #pragma unroll @@ -4188,6 +4563,9 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ for (int elem = 0; elem < 8; ++elem) { scores[elem] = static_cast(score_acc[qg][elem]) * score_scale; } + apply_per_thread_qk_score_scale(scores, q_scale, k_scale, b, hq, hkv, q_start[qg], kb_base, + col_tile, lane, qo_len, kv_len, qs_stride_b, qs_stride_h, + ks_stride_b, ks_stride_h, sm_scale); if constexpr (ApplyCausalMask) { const bool needs_causal_mask = k_col_start + BK > q_start[qg]; if (needs_causal_mask) { @@ -4283,8 +4661,10 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ } else { float k_scale_local = k_scale_tile; if constexpr (!QuantizeKeyValue) { - const int k_scale_idx = k_scale_col_per_warp(k_col_start); - k_scale_local = k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + if constexpr (!PerThreadQK) { + const int k_scale_idx = k_scale_col_per_warp(k_col_start); + k_scale_local = k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + } } const float score_scale = qs[qg] * k_scale_local; if constexpr (QuantizeQuery) { @@ -4295,6 +4675,9 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ compute_tqk_score_regs( &k_tile[0][0], q_frag[qg], col_tile, score_scale); } + apply_per_thread_qk_score_scale(scores, q_scale, k_scale, b, hq, hkv, q_start[qg], kb_base, + col_tile, lane, qo_len, kv_len, qs_stride_b, qs_stride_h, + ks_stride_b, ks_stride_h, sm_scale); if constexpr (ApplyCausalMask) { const bool needs_causal_mask = k_col_start + BK > 
q_start[qg]; if (needs_causal_mask) { @@ -4439,8 +4822,10 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ } else { float k_scale_local = k_scale_tile; if constexpr (!QuantizeKeyValue) { - const int k_scale_idx = k_scale_col_per_warp(k_col_start); - k_scale_local = k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + if constexpr (!PerThreadQK) { + const int k_scale_idx = k_scale_col_per_warp(k_col_start); + k_scale_local = k_scale[b * ks_stride_b + hkv * ks_stride_h + k_scale_idx]; + } } const float score_scale = qs[qg] * k_scale_local; if constexpr (QuantizeQuery) { @@ -4451,6 +4836,9 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ compute_tqk_score_regs( &k_tile[0][0], q_frag[qg], col_tile, score_scale); } + apply_per_thread_qk_score_scale(scores, q_scale, k_scale, b, hq, hkv, q_start[qg], kb_base, + col_tile, lane, qo_len, kv_len, qs_stride_b, qs_stride_h, + ks_stride_b, ks_stride_h, sm_scale); if constexpr (ApplyCausalMask) { const bool needs_causal_mask = k_col_start + BK > q_start[qg]; if (needs_causal_mask) { @@ -4579,7 +4967,12 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ float l_rows[8]; #pragma unroll for (int elem = 0; elem < 8; ++elem) { - l_rows[elem] = __shfl(l[qg], row_base + elem, 32); + const float l_sum = __shfl(l[qg], row_base + elem, 32); + if constexpr (InvLRowsEpilogue) { + l_rows[elem] = l_sum == 0.0f ? 0.0f : 1.0f / l_sum; + } else { + l_rows[elem] = l_sum; + } } #pragma unroll for (int vdt = 0; vdt < ValueTiles; ++vdt) { @@ -4590,12 +4983,19 @@ __global__ __launch_bounds__(BlockRows * (2 / QGroupsParam), 1) void qk_int8_sv_ const int elem = pair * 2; const int64_t q_idx0 = q_start[qg] + row_base + elem; const int64_t q_idx1 = q_idx0 + 1; - const float l_sum0 = l_rows[elem]; - const float l_sum1 = l_rows[elem + 1]; - const float value0 = l_sum0 == 0.0f ? 
- 0.0f : (out_frag[qg][vdt][elem] / l_sum0) * value_scale; - const float value1 = l_sum1 == 0.0f ? - 0.0f : (out_frag[qg][vdt][elem + 1] / l_sum1) * value_scale; + float value0; + float value1; + if constexpr (InvLRowsEpilogue) { + value0 = out_frag[qg][vdt][elem] * l_rows[elem] * value_scale; + value1 = out_frag[qg][vdt][elem + 1] * l_rows[elem + 1] * value_scale; + } else { + const float l_sum0 = l_rows[elem]; + const float l_sum1 = l_rows[elem + 1]; + value0 = l_sum0 == 0.0f ? + 0.0f : (out_frag[qg][vdt][elem] / l_sum0) * value_scale; + value1 = l_sum1 == 0.0f ? + 0.0f : (out_frag[qg][vdt][elem + 1] / l_sum1) * value_scale; + } store_output_value(output, qkv_offset_dispatch( tensor_layout, b, hq, q_idx0, d, o_stride_b, o_stride_n, o_stride_h), value0); store_output_value(output, qkv_offset_dispatch( @@ -4617,11 +5017,13 @@ template + int StaticKvLen = 0, + bool SubtractKeyMean = false> __global__ void prepare_qkv_hnd_kernel( const T* __restrict__ query, const T* __restrict__ key, const T* __restrict__ value, + const T* __restrict__ key_mean, int8_t* __restrict__ query_out, int8_t* __restrict__ key_out, float* __restrict__ query_scale, @@ -4652,6 +5054,8 @@ __global__ void prepare_qkv_hnd_kernel( "static QKV preparation Q length must cover full two-group Q tasks."); static_assert(StaticKvLen == 0 || (StaticKvLen % KRows) == 0, "static QKV preparation KV length must cover full K groups."); + static_assert(!SubtractKeyMean || !PrepackF16KLane, + "smooth-K preparation does not use lane-major K prepack."); __shared__ float shared_amax[2]; __shared__ float shared_pair_amax[2][32]; @@ -4759,14 +5163,25 @@ __global__ void prepare_qkv_hnd_kernel( const T* values = reinterpret_cast(&raw); char4 out0; char4 out1; - out0.x = float_to_int8_rn_gfx12(value_to_float(values[0]) * inv_scale0); - out0.y = float_to_int8_rn_gfx12(value_to_float(values[1]) * inv_scale0); - out0.z = float_to_int8_rn_gfx12(value_to_float(values[2]) * inv_scale0); - out0.w = 
float_to_int8_rn_gfx12(value_to_float(values[3]) * inv_scale0); - out1.x = float_to_int8_rn_gfx12(value_to_float(values[4]) * inv_scale0); - out1.y = float_to_int8_rn_gfx12(value_to_float(values[5]) * inv_scale0); - out1.z = float_to_int8_rn_gfx12(value_to_float(values[6]) * inv_scale0); - out1.w = float_to_int8_rn_gfx12(value_to_float(values[7]) * inv_scale0); + if constexpr (SubtractKeyMean) { + out0.x = float_to_int8_nearby_gfx12(value_to_float(values[0]) * inv_scale0); + out0.y = float_to_int8_nearby_gfx12(value_to_float(values[1]) * inv_scale0); + out0.z = float_to_int8_nearby_gfx12(value_to_float(values[2]) * inv_scale0); + out0.w = float_to_int8_nearby_gfx12(value_to_float(values[3]) * inv_scale0); + out1.x = float_to_int8_nearby_gfx12(value_to_float(values[4]) * inv_scale0); + out1.y = float_to_int8_nearby_gfx12(value_to_float(values[5]) * inv_scale0); + out1.z = float_to_int8_nearby_gfx12(value_to_float(values[6]) * inv_scale0); + out1.w = float_to_int8_nearby_gfx12(value_to_float(values[7]) * inv_scale0); + } else { + out0.x = float_to_int8_rn_gfx12(value_to_float(values[0]) * inv_scale0); + out0.y = float_to_int8_rn_gfx12(value_to_float(values[1]) * inv_scale0); + out0.z = float_to_int8_rn_gfx12(value_to_float(values[2]) * inv_scale0); + out0.w = float_to_int8_rn_gfx12(value_to_float(values[3]) * inv_scale0); + out1.x = float_to_int8_rn_gfx12(value_to_float(values[4]) * inv_scale0); + out1.y = float_to_int8_rn_gfx12(value_to_float(values[5]) * inv_scale0); + out1.z = float_to_int8_rn_gfx12(value_to_float(values[6]) * inv_scale0); + out1.w = float_to_int8_rn_gfx12(value_to_float(values[7]) * inv_scale0); + } *reinterpret_cast(query_out + off) = out0; *reinterpret_cast(query_out + off + 4) = out1; } @@ -4780,14 +5195,25 @@ __global__ void prepare_qkv_hnd_kernel( const T* values = reinterpret_cast(&raw); char4 out0; char4 out1; - out0.x = float_to_int8_rn_gfx12(value_to_float(values[0]) * inv_scale1); - out0.y = 
float_to_int8_rn_gfx12(value_to_float(values[1]) * inv_scale1); - out0.z = float_to_int8_rn_gfx12(value_to_float(values[2]) * inv_scale1); - out0.w = float_to_int8_rn_gfx12(value_to_float(values[3]) * inv_scale1); - out1.x = float_to_int8_rn_gfx12(value_to_float(values[4]) * inv_scale1); - out1.y = float_to_int8_rn_gfx12(value_to_float(values[5]) * inv_scale1); - out1.z = float_to_int8_rn_gfx12(value_to_float(values[6]) * inv_scale1); - out1.w = float_to_int8_rn_gfx12(value_to_float(values[7]) * inv_scale1); + if constexpr (SubtractKeyMean) { + out0.x = float_to_int8_nearby_gfx12(value_to_float(values[0]) * inv_scale1); + out0.y = float_to_int8_nearby_gfx12(value_to_float(values[1]) * inv_scale1); + out0.z = float_to_int8_nearby_gfx12(value_to_float(values[2]) * inv_scale1); + out0.w = float_to_int8_nearby_gfx12(value_to_float(values[3]) * inv_scale1); + out1.x = float_to_int8_nearby_gfx12(value_to_float(values[4]) * inv_scale1); + out1.y = float_to_int8_nearby_gfx12(value_to_float(values[5]) * inv_scale1); + out1.z = float_to_int8_nearby_gfx12(value_to_float(values[6]) * inv_scale1); + out1.w = float_to_int8_nearby_gfx12(value_to_float(values[7]) * inv_scale1); + } else { + out0.x = float_to_int8_rn_gfx12(value_to_float(values[0]) * inv_scale1); + out0.y = float_to_int8_rn_gfx12(value_to_float(values[1]) * inv_scale1); + out0.z = float_to_int8_rn_gfx12(value_to_float(values[2]) * inv_scale1); + out0.w = float_to_int8_rn_gfx12(value_to_float(values[3]) * inv_scale1); + out1.x = float_to_int8_rn_gfx12(value_to_float(values[4]) * inv_scale1); + out1.y = float_to_int8_rn_gfx12(value_to_float(values[5]) * inv_scale1); + out1.z = float_to_int8_rn_gfx12(value_to_float(values[6]) * inv_scale1); + out1.w = float_to_int8_rn_gfx12(value_to_float(values[7]) * inv_scale1); + } *reinterpret_cast(query_out + off) = out0; *reinterpret_cast(query_out + off + 4) = out1; } @@ -4820,9 +5246,17 @@ __global__ void prepare_qkv_hnd_kernel( HeadDim + d; const uint4 raw = 
*reinterpret_cast(key + off); const T* values = reinterpret_cast(&raw); + const T* mean_values = nullptr; + if constexpr (SubtractKeyMean) { + mean_values = key_mean + (static_cast(b) * kv_heads + head) * HeadDim + d; + } #pragma unroll for (int i = 0; i < PackElems; ++i) { - local_amax = fmaxf(local_amax, fabsf(value_to_float(values[i]))); + float value = value_to_float(values[i]); + if constexpr (SubtractKeyMean) { + value -= value_to_float(mean_values[i]); + } + local_amax = fmaxf(local_amax, fabsf(value)); } } } @@ -4848,16 +5282,49 @@ __global__ void prepare_qkv_hnd_kernel( const uint4 raw_v = *reinterpret_cast(value + off); const T* k_values = reinterpret_cast(&raw_k); const T* v_values = reinterpret_cast(&raw_v); + const T* mean_values = nullptr; + if constexpr (SubtractKeyMean) { + mean_values = key_mean + (static_cast(b) * kv_heads + head) * HeadDim + d; + } char4 out0; char4 out1; - out0.x = float_to_int8_rn_gfx12(value_to_float(k_values[0]) * inv_scale); - out0.y = float_to_int8_rn_gfx12(value_to_float(k_values[1]) * inv_scale); - out0.z = float_to_int8_rn_gfx12(value_to_float(k_values[2]) * inv_scale); - out0.w = float_to_int8_rn_gfx12(value_to_float(k_values[3]) * inv_scale); - out1.x = float_to_int8_rn_gfx12(value_to_float(k_values[4]) * inv_scale); - out1.y = float_to_int8_rn_gfx12(value_to_float(k_values[5]) * inv_scale); - out1.z = float_to_int8_rn_gfx12(value_to_float(k_values[6]) * inv_scale); - out1.w = float_to_int8_rn_gfx12(value_to_float(k_values[7]) * inv_scale); + float k0 = value_to_float(k_values[0]); + float k1 = value_to_float(k_values[1]); + float k2 = value_to_float(k_values[2]); + float k3 = value_to_float(k_values[3]); + float k4 = value_to_float(k_values[4]); + float k5 = value_to_float(k_values[5]); + float k6 = value_to_float(k_values[6]); + float k7 = value_to_float(k_values[7]); + if constexpr (SubtractKeyMean) { + k0 -= value_to_float(mean_values[0]); + k1 -= value_to_float(mean_values[1]); + k2 -= 
value_to_float(mean_values[2]); + k3 -= value_to_float(mean_values[3]); + k4 -= value_to_float(mean_values[4]); + k5 -= value_to_float(mean_values[5]); + k6 -= value_to_float(mean_values[6]); + k7 -= value_to_float(mean_values[7]); + } + if constexpr (SubtractKeyMean) { + out0.x = float_to_int8_nearby_gfx12(k0 * inv_scale); + out0.y = float_to_int8_nearby_gfx12(k1 * inv_scale); + out0.z = float_to_int8_nearby_gfx12(k2 * inv_scale); + out0.w = float_to_int8_nearby_gfx12(k3 * inv_scale); + out1.x = float_to_int8_nearby_gfx12(k4 * inv_scale); + out1.y = float_to_int8_nearby_gfx12(k5 * inv_scale); + out1.z = float_to_int8_nearby_gfx12(k6 * inv_scale); + out1.w = float_to_int8_nearby_gfx12(k7 * inv_scale); + } else { + out0.x = float_to_int8_rn_gfx12(k0 * inv_scale); + out0.y = float_to_int8_rn_gfx12(k1 * inv_scale); + out0.z = float_to_int8_rn_gfx12(k2 * inv_scale); + out0.w = float_to_int8_rn_gfx12(k3 * inv_scale); + out1.x = float_to_int8_rn_gfx12(k4 * inv_scale); + out1.y = float_to_int8_rn_gfx12(k5 * inv_scale); + out1.z = float_to_int8_rn_gfx12(k6 * inv_scale); + out1.w = float_to_int8_rn_gfx12(k7 * inv_scale); + } if constexpr (PrepackFp8Lane) { const int row_in_group = row & 63; const int col_tile = row_in_group >> 4; @@ -5454,7 +5921,9 @@ Tensor transpose_value_fp8_hnd_gfx12(Tensor value) { return transpose_value_hnd_gfx12(value); } -Tensor transpose_value_fp8_scaled_hnd_gfx12(Tensor value, Tensor value_scale) { +Tensor transpose_value_fp8_scaled_hnd_gfx12( + Tensor value, + Tensor value_scale) { STD_TORCH_CHECK(value.is_cuda() && value_scale.is_cuda(), "gfx12 scaled value transpose expects CUDA/HIP tensors"); STD_TORCH_CHECK(value.dim() == 4, "gfx12 scaled value transpose expects [B, H, S, D]"); @@ -5474,20 +5943,22 @@ Tensor transpose_value_fp8_scaled_hnd_gfx12(Tensor value, Tensor value_scale) { const int64_t heads = value.size(1); const int64_t seq_len = value.size(2); const int64_t head_dim = value.size(3); - Tensor output = new_empty_like(value, {batch, 
heads, head_dim, seq_len}, ScalarType::Byte); + Tensor output = + new_empty_like(value, {batch, heads, head_dim, seq_len}, ScalarType::Byte); dim3 block(256); dim3 grid((seq_len + 127) / 128, (head_dim + 15) / 16, batch * heads); + const hipStream_t stream = current_hip_stream(value); if (value.scalar_type() == ScalarType::Half) { - transpose_value_fp8_scaled_hnd_kernel<__half><<>>( + transpose_value_fp8_scaled_hnd_kernel<__half><<>>( reinterpret_cast(value.data_ptr()), - reinterpret_cast(value_scale.data_ptr()), + reinterpret_cast(value_scale.data_ptr()), reinterpret_cast(output.data_ptr()), batch * heads, seq_len, head_dim); } else { - transpose_value_fp8_scaled_hnd_kernel<__hip_bfloat16><<>>( + transpose_value_fp8_scaled_hnd_kernel<__hip_bfloat16><<>>( reinterpret_cast(value.data_ptr()), - reinterpret_cast(value_scale.data_ptr()), + reinterpret_cast(value_scale.data_ptr()), reinterpret_cast(output.data_ptr()), batch * heads, seq_len, head_dim); } @@ -5513,19 +5984,22 @@ std::vector fp8_value_nhd_short_gfx12( STD_TORCH_CHECK(seq_len == 512 || seq_len == 1024, "gfx12 short NHD fp8 value prep currently supports sequence length 512 or 1024"); - Tensor output = new_empty_like(value, {batch, heads, head_dim, seq_len}, ScalarType::Byte); - Tensor value_scale = new_empty_like(value, {batch, heads, head_dim}, ScalarType::Float); + Tensor output = + new_empty_like(value, {batch, heads, head_dim, seq_len}, ScalarType::Byte); + Tensor value_scale = + new_empty_like(value, {batch, heads, head_dim}, ScalarType::Float); dim3 block(256); dim3 grid((head_dim + 15) / 16, heads, batch); + const hipStream_t stream = current_hip_stream(value); if (value.scalar_type() == ScalarType::Half) { - fp8_value_nhd_short_kernel<__half><<>>( + fp8_value_nhd_short_kernel<__half><<>>( reinterpret_cast(value.data_ptr()), reinterpret_cast(output.data_ptr()), reinterpret_cast(value_scale.data_ptr()), seq_len, heads, head_dim, static_cast(scale_max)); } else { - 
fp8_value_nhd_short_kernel<__hip_bfloat16><<>>( + fp8_value_nhd_short_kernel<__hip_bfloat16><<>>( reinterpret_cast(value.data_ptr()), reinterpret_cast(output.data_ptr()), reinterpret_cast(value_scale.data_ptr()), @@ -5548,15 +6022,65 @@ Tensor mean_nhd_gfx12(Tensor input) { const int64_t head_dim = input.size(3); Tensor mean = new_empty_like(input, {batch, heads, head_dim}, input.scalar_type()); + const bool use_short_mean = + (head_dim == 64 || head_dim == 128) && (seq_len == 512 || seq_len == 1024); + const hipStream_t stream = current_hip_stream(input); + if (use_short_mean) { + dim3 block(1024); + dim3 grid((head_dim + 31) / 32, heads, batch); + if (input.scalar_type() == ScalarType::Half) { + mean_nhd_short_kernel<__half, 32, 32><<>>( + reinterpret_cast(input.data_ptr()), + reinterpret_cast<__half*>(mean.data_ptr()), + seq_len, heads, head_dim); + } else { + mean_nhd_short_kernel<__hip_bfloat16, 32, 32><<>>( + reinterpret_cast(input.data_ptr()), + reinterpret_cast<__hip_bfloat16*>(mean.data_ptr()), + seq_len, heads, head_dim); + } + } else { + dim3 block(256); + dim3 grid((head_dim + 15) / 16, heads, batch); + if (input.scalar_type() == ScalarType::Half) { + mean_nhd_kernel<__half><<>>( + reinterpret_cast(input.data_ptr()), + reinterpret_cast<__half*>(mean.data_ptr()), + seq_len, heads, head_dim); + } else { + mean_nhd_kernel<__hip_bfloat16><<>>( + reinterpret_cast(input.data_ptr()), + reinterpret_cast<__hip_bfloat16*>(mean.data_ptr()), + seq_len, heads, head_dim); + } + } + hip_kernel_launch_check(); + return mean; +} + +Tensor mean_hnd_gfx12(Tensor input) { + STD_TORCH_CHECK(input.is_cuda(), "gfx12 HND mean expects a CUDA/HIP tensor"); + STD_TORCH_CHECK(input.dim() == 4, "gfx12 HND mean expects [B, H, S, D]"); + STD_TORCH_CHECK(input.is_contiguous(), "gfx12 HND mean expects contiguous HND input"); + STD_TORCH_CHECK(input.scalar_type() == ScalarType::Half || input.scalar_type() == ScalarType::BFloat16, + "gfx12 HND mean supports fp16/bf16 input"); + + const 
int64_t batch = input.size(0); + const int64_t heads = input.size(1); + const int64_t seq_len = input.size(2); + const int64_t head_dim = input.size(3); + Tensor mean = new_empty_like(input, {batch, heads, head_dim}, input.scalar_type()); + dim3 block(256); dim3 grid((head_dim + 15) / 16, heads, batch); + const hipStream_t stream = current_hip_stream(input); if (input.scalar_type() == ScalarType::Half) { - mean_nhd_kernel<__half><<>>( + mean_hnd_kernel<__half><<>>( reinterpret_cast(input.data_ptr()), reinterpret_cast<__half*>(mean.data_ptr()), seq_len, heads, head_dim); } else { - mean_nhd_kernel<__hip_bfloat16><<>>( + mean_hnd_kernel<__hip_bfloat16><<>>( reinterpret_cast(input.data_ptr()), reinterpret_cast<__hip_bfloat16*>(mean.data_ptr()), seq_len, heads, head_dim); @@ -5579,10 +6103,7 @@ std::vector mean_and_fp8_value_nhd_short_gfx12( "gfx12 short NHD mean/value prep expects matching key/value dtypes"); STD_TORCH_CHECK(key.scalar_type() == ScalarType::Half || key.scalar_type() == ScalarType::BFloat16, "gfx12 short NHD mean/value prep supports fp16/bf16 input"); - STD_TORCH_CHECK(key.size(0) == value.size(0) && - key.size(1) == value.size(1) && - key.size(2) == value.size(2) && - key.size(3) == value.size(3), + STD_TORCH_CHECK(same_sizes(key, value), "gfx12 short NHD mean/value prep expects matching key/value shapes"); const int64_t batch = value.size(0); @@ -5591,53 +6112,44 @@ std::vector mean_and_fp8_value_nhd_short_gfx12( const int64_t head_dim = value.size(3); STD_TORCH_CHECK(head_dim == 64 || head_dim == 128, "gfx12 short NHD mean/value prep currently supports head_dim 64 or 128"); - STD_TORCH_CHECK(seq_len == 512 || seq_len == 1024, - "gfx12 short NHD mean/value prep currently supports sequence length 512 or 1024"); + STD_TORCH_CHECK(seq_len == 512 || seq_len == 1024 || seq_len == 2048 || + seq_len == 4096 || seq_len == 8192, + "gfx12 NHD mean/value prep currently supports sequence length 512/1024/2048/4096/8192"); Tensor key_mean = new_empty_like(key, 
{batch, heads, head_dim}, key.scalar_type()); - Tensor output = new_empty_like(value, {batch, heads, head_dim, seq_len}, ScalarType::Byte); - Tensor value_scale = new_empty_like(value, {batch, heads, head_dim}, ScalarType::Float); + Tensor output = + new_empty_like(value, {batch, heads, head_dim, seq_len}, ScalarType::Byte); + Tensor value_scale = + new_empty_like(value, {batch, heads, head_dim}, ScalarType::Float); const int seq_lanes = head_dim == 64 ? 32 : 16; dim3 block(32 * seq_lanes); dim3 grid((head_dim + 31) / 32, heads, batch); + const hipStream_t stream = current_hip_stream(key); + +#define SAGEATTN_LAUNCH_MEAN_FP8_VALUE_SHORT(T_, LANES_) \ + mean_and_fp8_value_nhd_short_kernel<<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_mean.data_ptr()), \ + reinterpret_cast(output.data_ptr()), \ + reinterpret_cast(value_scale.data_ptr()), \ + seq_len, heads, head_dim, static_cast(scale_max)) + if (value.scalar_type() == ScalarType::Half) { if (head_dim == 64) { - mean_and_fp8_value_nhd_short_kernel<__half, 32><<>>( - reinterpret_cast(key.data_ptr()), - reinterpret_cast(value.data_ptr()), - reinterpret_cast<__half*>(key_mean.data_ptr()), - reinterpret_cast(output.data_ptr()), - reinterpret_cast(value_scale.data_ptr()), - seq_len, heads, head_dim, static_cast(scale_max)); + SAGEATTN_LAUNCH_MEAN_FP8_VALUE_SHORT(__half, 32); } else { - mean_and_fp8_value_nhd_short_kernel<__half, 16><<>>( - reinterpret_cast(key.data_ptr()), - reinterpret_cast(value.data_ptr()), - reinterpret_cast<__half*>(key_mean.data_ptr()), - reinterpret_cast(output.data_ptr()), - reinterpret_cast(value_scale.data_ptr()), - seq_len, heads, head_dim, static_cast(scale_max)); + SAGEATTN_LAUNCH_MEAN_FP8_VALUE_SHORT(__half, 16); } } else { if (head_dim == 64) { - mean_and_fp8_value_nhd_short_kernel<__hip_bfloat16, 32><<>>( - reinterpret_cast(key.data_ptr()), - reinterpret_cast(value.data_ptr()), - 
reinterpret_cast<__hip_bfloat16*>(key_mean.data_ptr()), - reinterpret_cast(output.data_ptr()), - reinterpret_cast(value_scale.data_ptr()), - seq_len, heads, head_dim, static_cast(scale_max)); + SAGEATTN_LAUNCH_MEAN_FP8_VALUE_SHORT(__hip_bfloat16, 32); } else { - mean_and_fp8_value_nhd_short_kernel<__hip_bfloat16, 16><<>>( - reinterpret_cast(key.data_ptr()), - reinterpret_cast(value.data_ptr()), - reinterpret_cast<__hip_bfloat16*>(key_mean.data_ptr()), - reinterpret_cast(output.data_ptr()), - reinterpret_cast(value_scale.data_ptr()), - seq_len, heads, head_dim, static_cast(scale_max)); + SAGEATTN_LAUNCH_MEAN_FP8_VALUE_SHORT(__hip_bfloat16, 16); } } +#undef SAGEATTN_LAUNCH_MEAN_FP8_VALUE_SHORT hip_kernel_launch_check(); return {key_mean, output, value_scale}; } @@ -5705,22 +6217,25 @@ std::vector prepare_qkv_hnd_gfx12( const bool use_qkv_static_1024 = ToFp8 && head_dim == 128 && q_len == 1024 && kv_len == 1024 && fuse_self_qkv; - if (query.scalar_type() == ScalarType::Half) { + const hipStream_t stream = current_hip_stream(query); + if (query.scalar_type() == ScalarType::Half) { if (head_dim == 16) { - prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 16, 256><<>>( + prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 16, 256><<>>( reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), + nullptr, reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), reinterpret_cast(value_out.data_ptr()), batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, fuse_self_qkv); } else if (head_dim == 64) { - prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 64, D64PrepThreads><<>>( + prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 64, D64PrepThreads><<>>( reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), + nullptr, reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), 
reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), reinterpret_cast(value_out.data_ptr()), @@ -5730,20 +6245,22 @@ std::vector prepare_qkv_hnd_gfx12( if (use_qkv_static_1024) { prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 128, 256, true, true, false, false, false, - true, 1024, 1024><<>>( + true, 1024, 1024><<>>( reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), + nullptr, reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), reinterpret_cast(value_out.data_ptr()), batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, true); } else { - prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 128><<>>( + prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 128><<>>( reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), + nullptr, reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), reinterpret_cast(value_out.data_ptr()), @@ -5753,20 +6270,22 @@ std::vector prepare_qkv_hnd_gfx12( } } else { if (head_dim == 16) { - prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 16, 256><<>>( + prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 16, 256><<>>( reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), + nullptr, reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), reinterpret_cast(value_out.data_ptr()), batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, fuse_self_qkv); } else if (head_dim == 64) { - prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 64, D64PrepThreads><<>>( + prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 64, D64PrepThreads><<>>( 
reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), + nullptr, reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), reinterpret_cast(value_out.data_ptr()), @@ -5776,20 +6295,22 @@ std::vector prepare_qkv_hnd_gfx12( if (use_qkv_static_1024) { prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 128, 256, true, true, false, false, false, - true, 1024, 1024><<>>( + true, 1024, 1024><<>>( reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), + nullptr, reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), reinterpret_cast(value_out.data_ptr()), batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, true); } else { - prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 128><<>>( + prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 128><<>>( reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), + nullptr, reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), reinterpret_cast(value_out.data_ptr()), @@ -5802,6 +6323,128 @@ std::vector prepare_qkv_hnd_gfx12( return {query_out, query_scale, key_out, key_scale, value_out}; } +std::vector prepare_qkv_hnd_smooth_f16_gfx12( + Tensor query, + Tensor key, + Tensor value, + Tensor key_mean) { + STD_TORCH_CHECK(query.is_cuda() && key.is_cuda() && value.is_cuda() && key_mean.is_cuda(), + "smooth gfx12 QKV preparation expects CUDA/HIP tensors"); + STD_TORCH_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4, + "smooth gfx12 QKV preparation expects [B, H, S, D]"); + STD_TORCH_CHECK(key_mean.dim() == 3, + "smooth gfx12 QKV preparation expects key_mean 
[B, H, D]"); + STD_TORCH_CHECK(query.is_contiguous() && key.is_contiguous() && value.is_contiguous() && + key_mean.is_contiguous(), + "smooth gfx12 QKV preparation expects contiguous HND tensors"); + STD_TORCH_CHECK(query.scalar_type() == key.scalar_type() && + query.scalar_type() == value.scalar_type() && + query.scalar_type() == key_mean.scalar_type(), + "smooth gfx12 QKV preparation expects matching input dtypes"); + STD_TORCH_CHECK(query.scalar_type() == ScalarType::Half || query.scalar_type() == ScalarType::BFloat16, + "smooth gfx12 QKV preparation supports fp16/bf16 input"); + STD_TORCH_CHECK(query.size(0) == key.size(0) && query.size(0) == value.size(0) && + query.size(0) == key_mean.size(0), + "Q/K/V batch size mismatch"); + STD_TORCH_CHECK(query.size(3) == key.size(3) && query.size(3) == value.size(3) && + query.size(3) == key_mean.size(2), + "Q/K/V head_dim mismatch"); + STD_TORCH_CHECK(key.size(1) == value.size(1) && key.size(2) == value.size(2) && + key.size(1) == key_mean.size(1), + "K/V shape mismatch"); + + const int64_t batch = query.size(0); + const int64_t q_heads = query.size(1); + const int64_t q_len = query.size(2); + const int64_t kv_heads = key.size(1); + const int64_t kv_len = key.size(2); + const int64_t head_dim = query.size(3); + STD_TORCH_CHECK(head_dim == 64 || head_dim == 128, + "smooth gfx12 QKV preparation supports head_dim 64 or 128"); + STD_TORCH_CHECK((q_len % 64) == 0 && (kv_len % 64) == 0, + "smooth gfx12 QKV preparation requires sequence lengths divisible by 64"); + + const int q_groups = static_cast((q_len + 31) / 32); + const int q_task_groups = (q_groups + 1) / 2; + const int k_groups = static_cast((kv_len + 63) / 64); + const bool fuse_self_qkv = + q_heads == kv_heads && q_len == kv_len && q_task_groups == k_groups; + Tensor query_out = new_empty_like(query, query.sizes(), ScalarType::Char); + Tensor key_out = new_empty_like(key, key.sizes(), ScalarType::Char); + Tensor query_scale = + new_empty_like(query, {batch, q_heads, 
q_groups}, ScalarType::Float); + Tensor key_scale = + new_empty_like(key, {batch, kv_heads, k_groups}, ScalarType::Float); + Tensor value_out = + new_empty_like(value, {batch, kv_heads, head_dim, kv_len}, ScalarType::Half); + + constexpr int D64PrepThreads = 256; + const dim3 block(head_dim == 64 ? D64PrepThreads : 256); + const dim3 grid(fuse_self_qkv ? k_groups : (q_task_groups + k_groups), + std::max(q_heads, kv_heads), + batch); + const hipStream_t stream = current_hip_stream(query); + if (query.scalar_type() == ScalarType::Half) { + if (head_dim == 64) { + prepare_qkv_hnd_kernel<__half, __half, false, 64, D64PrepThreads, + true, true, false, false, false, false, 0, 0, true> + <<>>( + reinterpret_cast(query.data_ptr()), + reinterpret_cast(key.data_ptr()), + reinterpret_cast(value.data_ptr()), + reinterpret_cast(key_mean.data_ptr()), + reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), + reinterpret_cast<__half*>(value_out.data_ptr()), + batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, + fuse_self_qkv); + } else { + prepare_qkv_hnd_kernel<__half, __half, false, 128, 256, + true, true, false, false, false, false, 0, 0, true> + <<>>( + reinterpret_cast(query.data_ptr()), + reinterpret_cast(key.data_ptr()), + reinterpret_cast(value.data_ptr()), + reinterpret_cast(key_mean.data_ptr()), + reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), + reinterpret_cast<__half*>(value_out.data_ptr()), + batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, + fuse_self_qkv); + } + } else { + if (head_dim == 64) { + prepare_qkv_hnd_kernel<__hip_bfloat16, __half, false, 64, D64PrepThreads, + true, true, false, false, false, false, 0, 0, true> + <<>>( + reinterpret_cast(query.data_ptr()), + reinterpret_cast(key.data_ptr()), + 
reinterpret_cast(value.data_ptr()), + reinterpret_cast(key_mean.data_ptr()), + reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), + reinterpret_cast<__half*>(value_out.data_ptr()), + batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, + fuse_self_qkv); + } else { + prepare_qkv_hnd_kernel<__hip_bfloat16, __half, false, 128, 256, + true, true, false, false, false, false, 0, 0, true> + <<>>( + reinterpret_cast(query.data_ptr()), + reinterpret_cast(key.data_ptr()), + reinterpret_cast(value.data_ptr()), + reinterpret_cast(key_mean.data_ptr()), + reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), + reinterpret_cast<__half*>(value_out.data_ptr()), + batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, + fuse_self_qkv); + } + } + hip_kernel_launch_check(); + return {query_out, query_scale, key_out, key_scale, value_out}; +} + template prepare_qkv_hnd_packed_gfx12( const dim3 grid(fuse_self_qkv ? 
k_groups : (q_task_groups + k_groups), std::max(q_heads, kv_heads), batch); - if (query.scalar_type() == ScalarType::Half) { + const hipStream_t stream = current_hip_stream(query); + if (query.scalar_type() == ScalarType::Half) { if (head_dim == 16) { - prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 16, 256, TransposeValue, true><<>>( + prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 16, 256, TransposeValue, true><<>>( reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), + nullptr, query_ptr, key_ptr, scale_ptr, scale_ptr + q_scale_numel, value_ptr, batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, fuse_self_qkv); } else if (head_dim == 64) { - prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 64, D64PrepThreads, TransposeValue, true, false, PrepackF16KLane, PrepackFp8Lane><<>>( + prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 64, D64PrepThreads, TransposeValue, true, false, PrepackF16KLane, PrepackFp8Lane><<>>( reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), + nullptr, query_ptr, key_ptr, scale_ptr, scale_ptr + q_scale_numel, value_ptr, batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, fuse_self_qkv); } else { - prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 128, 256, TransposeValue, true, false, false, PrepackFp8Lane><<>>( + prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 128, 256, TransposeValue, true, false, false, PrepackFp8Lane><<>>( reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), + nullptr, query_ptr, key_ptr, scale_ptr, scale_ptr + q_scale_numel, value_ptr, batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, @@ -5926,28 +6573,31 @@ std::vector prepare_qkv_hnd_packed_gfx12( } } else { if (head_dim == 16) { - prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 16, 256, TransposeValue, true><<>>( + prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 16, 256, TransposeValue, true><<>>( 
reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), + nullptr, query_ptr, key_ptr, scale_ptr, scale_ptr + q_scale_numel, value_ptr, batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, fuse_self_qkv); } else if (head_dim == 64) { - prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 64, D64PrepThreads, TransposeValue, true, false, PrepackF16KLane, PrepackFp8Lane><<>>( + prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 64, D64PrepThreads, TransposeValue, true, false, PrepackF16KLane, PrepackFp8Lane><<>>( reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), + nullptr, query_ptr, key_ptr, scale_ptr, scale_ptr + q_scale_numel, value_ptr, batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, fuse_self_qkv); } else { - prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 128, 256, TransposeValue, true, false, false, PrepackFp8Lane><<>>( + prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 128, 256, TransposeValue, true, false, false, PrepackFp8Lane><<>>( reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), + nullptr, query_ptr, key_ptr, scale_ptr, scale_ptr + q_scale_numel, value_ptr, batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, @@ -6057,31 +6707,32 @@ std::vector prepare_kv_hnd_packed_gfx12( constexpr int D64PrepThreads = 256; const dim3 block(head_dim == 64 ? 
D64PrepThreads : 256); const dim3 grid(k_groups, kv_heads, batch); - const bool use_kv1 = q_len <= 4096; + const hipStream_t stream = current_hip_stream(query); + const bool use_kv1 = q_len <= 4096; const bool use_kv_static_1024 = ToFp8 && TransposeValue && use_kv1 && q_len == 1024 && kv_len == 1024 && !PrepackFp8VLane && !PrepackFp8KLane; if (query.scalar_type() == ScalarType::Half) { if (head_dim == 16) { - prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 16, 256, TransposeValue, false><<>>( + prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 16, 256, TransposeValue, false><<>>( reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), - nullptr, key_ptr, nullptr, key_scale_ptr, value_ptr, + nullptr, nullptr, key_ptr, nullptr, key_scale_ptr, value_ptr, batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, false); } else if (head_dim == 64) { if constexpr (ToFp8 && TransposeValue) { if (use_kv1) { if (use_kv_static_1024) { - prepare_kv_hnd_fp8_kernel<__half, 1, false, PrepackFp8Lane, 64, false, false, false, true, 256, 1024><<>>( + prepare_kv_hnd_fp8_kernel<__half, 1, false, PrepackFp8Lane, 64, false, false, false, true, 256, 1024><<>>( reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), key_ptr, key_scale_ptr, reinterpret_cast(value_ptr), batch, kv_heads, kv_len, k_groups); } else { - prepare_kv_hnd_fp8_kernel<__half, 1, false, PrepackFp8Lane><<>>( + prepare_kv_hnd_fp8_kernel<__half, 1, false, PrepackFp8Lane><<>>( reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), key_ptr, key_scale_ptr, reinterpret_cast(value_ptr), @@ -6089,18 +6740,18 @@ std::vector prepare_kv_hnd_packed_gfx12( } } else { const dim3 grid_kv((k_groups + 1) / 2, kv_heads, batch); - prepare_kv_hnd_fp8_kernel<__half, 2, false, PrepackFp8Lane><<>>( + prepare_kv_hnd_fp8_kernel<__half, 2, false, PrepackFp8Lane><<>>( reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), key_ptr, key_scale_ptr, 
reinterpret_cast(value_ptr), batch, kv_heads, kv_len, k_groups); } } else { - prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 64, D64PrepThreads, TransposeValue, false, PrepackF16VLane, PrepackF16KLane><<>>( + prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 64, D64PrepThreads, TransposeValue, false, PrepackF16VLane, PrepackF16KLane><<>>( reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), - nullptr, key_ptr, nullptr, key_scale_ptr, value_ptr, + nullptr, nullptr, key_ptr, nullptr, key_scale_ptr, value_ptr, batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, false); } @@ -6108,13 +6759,13 @@ std::vector prepare_kv_hnd_packed_gfx12( if constexpr (ToFp8 && TransposeValue) { if (use_kv1) { if (use_kv_static_1024) { - prepare_kv_hnd_fp8_kernel<__half, 1, false, false, 128, false, false, false, true, 256, 1024><<>>( + prepare_kv_hnd_fp8_kernel<__half, 1, false, false, 128, false, false, false, true, 256, 1024><<>>( reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), key_ptr, key_scale_ptr, reinterpret_cast(value_ptr), batch, kv_heads, kv_len, k_groups); } else { - prepare_kv_hnd_fp8_kernel<__half, 1, false, false, 128><<>>( + prepare_kv_hnd_fp8_kernel<__half, 1, false, false, 128><<>>( reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), key_ptr, key_scale_ptr, reinterpret_cast(value_ptr), @@ -6122,42 +6773,42 @@ std::vector prepare_kv_hnd_packed_gfx12( } } else { const dim3 grid_kv((k_groups + 1) / 2, kv_heads, batch); - prepare_kv_hnd_fp8_kernel<__half, 2, false, false, 128><<>>( + prepare_kv_hnd_fp8_kernel<__half, 2, false, false, 128><<>>( reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), key_ptr, key_scale_ptr, reinterpret_cast(value_ptr), batch, kv_heads, kv_len, k_groups); } } else { - prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 128, 256, TransposeValue, false><<>>( + prepare_qkv_hnd_kernel<__half, OutT, ToFp8, 128, 256, TransposeValue, false><<>>( 
reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), - nullptr, key_ptr, nullptr, key_scale_ptr, value_ptr, + nullptr, nullptr, key_ptr, nullptr, key_scale_ptr, value_ptr, batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, false); } } } else { if (head_dim == 16) { - prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 16, 256, TransposeValue, false><<>>( + prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 16, 256, TransposeValue, false><<>>( reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), - nullptr, key_ptr, nullptr, key_scale_ptr, value_ptr, + nullptr, nullptr, key_ptr, nullptr, key_scale_ptr, value_ptr, batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, false); } else if (head_dim == 64) { if constexpr (ToFp8 && TransposeValue) { if (use_kv1) { if (use_kv_static_1024) { - prepare_kv_hnd_fp8_kernel<__hip_bfloat16, 1, false, PrepackFp8Lane, 64, false, false, false, true, 256, 1024><<>>( + prepare_kv_hnd_fp8_kernel<__hip_bfloat16, 1, false, PrepackFp8Lane, 64, false, false, false, true, 256, 1024><<>>( reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), key_ptr, key_scale_ptr, reinterpret_cast(value_ptr), batch, kv_heads, kv_len, k_groups); } else { - prepare_kv_hnd_fp8_kernel<__hip_bfloat16, 1, false, PrepackFp8Lane><<>>( + prepare_kv_hnd_fp8_kernel<__hip_bfloat16, 1, false, PrepackFp8Lane><<>>( reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), key_ptr, key_scale_ptr, reinterpret_cast(value_ptr), @@ -6165,18 +6816,18 @@ std::vector prepare_kv_hnd_packed_gfx12( } } else { const dim3 grid_kv((k_groups + 1) / 2, kv_heads, batch); - prepare_kv_hnd_fp8_kernel<__hip_bfloat16, 2, false, PrepackFp8Lane><<>>( + prepare_kv_hnd_fp8_kernel<__hip_bfloat16, 2, false, PrepackFp8Lane><<>>( reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), key_ptr, key_scale_ptr, reinterpret_cast(value_ptr), 
batch, kv_heads, kv_len, k_groups); } } else { - prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 64, D64PrepThreads, TransposeValue, false, PrepackF16VLane, PrepackF16KLane><<>>( + prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 64, D64PrepThreads, TransposeValue, false, PrepackF16VLane, PrepackF16KLane><<>>( reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), - nullptr, key_ptr, nullptr, key_scale_ptr, value_ptr, + nullptr, nullptr, key_ptr, nullptr, key_scale_ptr, value_ptr, batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, false); } @@ -6184,13 +6835,13 @@ std::vector prepare_kv_hnd_packed_gfx12( if constexpr (ToFp8 && TransposeValue) { if (use_kv1) { if (use_kv_static_1024) { - prepare_kv_hnd_fp8_kernel<__hip_bfloat16, 1, false, false, 128, false, false, false, true, 256, 1024><<>>( + prepare_kv_hnd_fp8_kernel<__hip_bfloat16, 1, false, false, 128, false, false, false, true, 256, 1024><<>>( reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), key_ptr, key_scale_ptr, reinterpret_cast(value_ptr), batch, kv_heads, kv_len, k_groups); } else { - prepare_kv_hnd_fp8_kernel<__hip_bfloat16, 1, false, false, 128><<>>( + prepare_kv_hnd_fp8_kernel<__hip_bfloat16, 1, false, false, 128><<>>( reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), key_ptr, key_scale_ptr, reinterpret_cast(value_ptr), @@ -6198,18 +6849,18 @@ std::vector prepare_kv_hnd_packed_gfx12( } } else { const dim3 grid_kv((k_groups + 1) / 2, kv_heads, batch); - prepare_kv_hnd_fp8_kernel<__hip_bfloat16, 2, false, false, 128><<>>( + prepare_kv_hnd_fp8_kernel<__hip_bfloat16, 2, false, false, 128><<>>( reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), key_ptr, key_scale_ptr, reinterpret_cast(value_ptr), batch, kv_heads, kv_len, k_groups); } } else { - prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, ToFp8, 128, 256, TransposeValue, false><<>>( + prepare_qkv_hnd_kernel<__hip_bfloat16, OutT, 
ToFp8, 128, 256, TransposeValue, false><<>>( reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), - nullptr, key_ptr, nullptr, key_scale_ptr, value_ptr, + nullptr, nullptr, key_ptr, nullptr, key_scale_ptr, value_ptr, batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups, false); } @@ -6317,33 +6968,34 @@ std::vector prepare_k_hnd_packed_gfx12(Tensor key) { Tensor byte_workspace = new_empty_like(key, {k_numel}, ScalarType::Byte); Tensor scale_workspace = new_empty_like(key, {batch * kv_heads * k_groups}, ScalarType::Float); - int8_t* key_ptr = reinterpret_cast(reinterpret_cast(byte_workspace.data_ptr())); + int8_t* key_ptr = reinterpret_cast(byte_workspace.data_ptr()); Tensor key_out = from_blob_like(key_ptr, {batch, kv_heads, kv_len, head_dim}, key, ScalarType::Char); Tensor key_scale = from_blob_like(reinterpret_cast(scale_workspace.data_ptr()), {batch, kv_heads, k_groups}, key, ScalarType::Float); constexpr int Threads = 256; const dim3 block(Threads); const dim3 grid(k_groups, kv_heads, batch); - if (key.scalar_type() == ScalarType::Half) { + const hipStream_t stream = current_hip_stream(key); + if (key.scalar_type() == ScalarType::Half) { if (head_dim == 64) { - prepare_k_hnd_kernel<__half, 64, Threads><<>>( + prepare_k_hnd_kernel<__half, 64, Threads><<>>( reinterpret_cast(key.data_ptr()), key_ptr, reinterpret_cast(scale_workspace.data_ptr()), batch, kv_heads, kv_len, k_groups); } else { - prepare_k_hnd_kernel<__half, 128, Threads><<>>( + prepare_k_hnd_kernel<__half, 128, Threads><<>>( reinterpret_cast(key.data_ptr()), key_ptr, reinterpret_cast(scale_workspace.data_ptr()), batch, kv_heads, kv_len, k_groups); } } else { if (head_dim == 64) { - prepare_k_hnd_kernel<__hip_bfloat16, 64, Threads><<>>( + prepare_k_hnd_kernel<__hip_bfloat16, 64, Threads><<>>( reinterpret_cast(key.data_ptr()), key_ptr, reinterpret_cast(scale_workspace.data_ptr()), batch, kv_heads, kv_len, k_groups); } else { - 
prepare_k_hnd_kernel<__hip_bfloat16, 128, Threads><<>>( + prepare_k_hnd_kernel<__hip_bfloat16, 128, Threads><<>>( reinterpret_cast(key.data_ptr()), key_ptr, reinterpret_cast(scale_workspace.data_ptr()), batch, kv_heads, kv_len, k_groups); @@ -6382,7 +7034,8 @@ Tensor convert_f16_to_bf16_gfx12(Tensor input) { const int threads = 256; const dim3 block(threads); const dim3 grid((numel + threads * 2 - 1) / (threads * 2)); - convert_f16_to_bf16_kernel<<>>( + const hipStream_t stream = current_hip_stream(input); + convert_f16_to_bf16_kernel<<>>( reinterpret_cast(input.data_ptr()), reinterpret_cast<__hip_bfloat16*>(output.data_ptr()), numel); @@ -6404,48 +7057,44 @@ std::vector quant_q_nhd_per_warp_gfx12(Tensor query) { "gfx12 NHD Q quantization supports head_dim 16, 64, or 128"); const int q_scale_groups = static_cast(((q_len + 127) / 128) * 4); Tensor query_out = new_empty_like(query, query.sizes(), ScalarType::Char); - Tensor query_scale = new_empty_like(query, {batch, q_heads, q_scale_groups}, ScalarType::Float); + Tensor query_scale = + new_empty_like(query, {batch, q_heads, q_scale_groups}, ScalarType::Float); const dim3 block(256); const dim3 grid((q_scale_groups + 1) / 2, q_heads, batch); + const hipStream_t stream = current_hip_stream(query); if (query.scalar_type() == ScalarType::Half) { if (head_dim == 16) { - quant_q_nhd_per_warp_kernel<__half, 16><<>>( + quant_q_nhd_per_warp_kernel<__half, 16><<>>( reinterpret_cast(query.data_ptr()), - reinterpret_cast(query_out.data_ptr()), - reinterpret_cast(query_scale.data_ptr()), + reinterpret_cast(query_out.data_ptr()), reinterpret_cast(query_scale.data_ptr()), batch, q_len, q_heads, q_scale_groups); } else if (head_dim == 64) { - quant_q_nhd_per_warp_kernel<__half, 64><<>>( + quant_q_nhd_per_warp_kernel<__half, 64><<>>( reinterpret_cast(query.data_ptr()), - reinterpret_cast(query_out.data_ptr()), - reinterpret_cast(query_scale.data_ptr()), + reinterpret_cast(query_out.data_ptr()), 
reinterpret_cast(query_scale.data_ptr()), batch, q_len, q_heads, q_scale_groups); } else { - quant_q_nhd_per_warp_kernel<__half, 128><<>>( + quant_q_nhd_per_warp_kernel<__half, 128><<>>( reinterpret_cast(query.data_ptr()), - reinterpret_cast(query_out.data_ptr()), - reinterpret_cast(query_scale.data_ptr()), + reinterpret_cast(query_out.data_ptr()), reinterpret_cast(query_scale.data_ptr()), batch, q_len, q_heads, q_scale_groups); } } else { if (head_dim == 16) { - quant_q_nhd_per_warp_kernel<__hip_bfloat16, 16><<>>( + quant_q_nhd_per_warp_kernel<__hip_bfloat16, 16><<>>( reinterpret_cast(query.data_ptr()), - reinterpret_cast(query_out.data_ptr()), - reinterpret_cast(query_scale.data_ptr()), + reinterpret_cast(query_out.data_ptr()), reinterpret_cast(query_scale.data_ptr()), batch, q_len, q_heads, q_scale_groups); } else if (head_dim == 64) { - quant_q_nhd_per_warp_kernel<__hip_bfloat16, 64><<>>( + quant_q_nhd_per_warp_kernel<__hip_bfloat16, 64><<>>( reinterpret_cast(query.data_ptr()), - reinterpret_cast(query_out.data_ptr()), - reinterpret_cast(query_scale.data_ptr()), + reinterpret_cast(query_out.data_ptr()), reinterpret_cast(query_scale.data_ptr()), batch, q_len, q_heads, q_scale_groups); } else { - quant_q_nhd_per_warp_kernel<__hip_bfloat16, 128><<>>( + quant_q_nhd_per_warp_kernel<__hip_bfloat16, 128><<>>( reinterpret_cast(query.data_ptr()), - reinterpret_cast(query_out.data_ptr()), - reinterpret_cast(query_scale.data_ptr()), + reinterpret_cast(query_out.data_ptr()), reinterpret_cast(query_scale.data_ptr()), batch, q_len, q_heads, q_scale_groups); } } @@ -6488,23 +7137,24 @@ std::vector quant_qk_int8_hnd_gfx12(Tensor query, Tensor key) { const dim3 block(256); const dim3 grid(q_groups + k_groups, std::max(q_heads, kv_heads), batch); - if (query.scalar_type() == ScalarType::Half) { + const hipStream_t stream = current_hip_stream(query); + if (query.scalar_type() == ScalarType::Half) { if (head_dim == 16) { - quant_qk_int8_hnd_kernel<__half, 16><<>>( + 
quant_qk_int8_hnd_kernel<__half, 16><<>>( reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups); } else if (head_dim == 64) { - quant_qk_int8_hnd_kernel<__half, 64><<>>( + quant_qk_int8_hnd_kernel<__half, 64><<>>( reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups); } else { - quant_qk_int8_hnd_kernel<__half, 128><<>>( + quant_qk_int8_hnd_kernel<__half, 128><<>>( reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), @@ -6513,21 +7163,21 @@ std::vector quant_qk_int8_hnd_gfx12(Tensor query, Tensor key) { } } else { if (head_dim == 16) { - quant_qk_int8_hnd_kernel<__hip_bfloat16, 16><<>>( + quant_qk_int8_hnd_kernel<__hip_bfloat16, 16><<>>( reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups); } else if (head_dim == 64) { - quant_qk_int8_hnd_kernel<__hip_bfloat16, 64><<>>( + quant_qk_int8_hnd_kernel<__hip_bfloat16, 64><<>>( reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), batch, q_heads, kv_heads, q_len, kv_len, q_groups, k_groups); } else { - quant_qk_int8_hnd_kernel<__hip_bfloat16, 128><<>>( + 
quant_qk_int8_hnd_kernel<__hip_bfloat16, 128><<>>( reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(query_out.data_ptr()), reinterpret_cast(key_out.data_ptr()), @@ -6539,7 +7189,7 @@ std::vector quant_qk_int8_hnd_gfx12(Tensor query, Tensor key) { return {query_out, query_scale, key_out, key_scale}; } -static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_dispatch( +Tensor qk_int8_sv_f16_d64_native_attn_gfx12_dispatch( Tensor query, Tensor key, Tensor value, @@ -6599,6 +7249,8 @@ Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( "native gfx12 prepare+attention requires sequence lengths divisible by 64"); STD_TORCH_CHECK(!is_causal || q_len == padded_kv_len, "native gfx12 causal prepare+attention requires q_len == kv_len"); + STD_TORCH_CHECK(pv_accum_mode >= -1 && pv_accum_mode <= 1, + "invalid gfx12 fp16 PV accumulation mode"); const auto output_dtype = (value_is_fp8 && query.scalar_type() == ScalarType::BFloat16) ? ScalarType::BFloat16 : ScalarType::Half; @@ -6606,17 +7258,18 @@ Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( if (!value_is_fp8) { output = new_empty_like(query, query.sizes(), output_dtype); } + const bool force_fp32_pv_accum = !value_is_fp8 && pv_accum_mode == 0; const bool prefer_prepared_f16_causal = - !value_is_fp8 && head_dim == 64 && is_causal && + !force_fp32_pv_accum && !value_is_fp8 && head_dim == 64 && is_causal && query.scalar_type() == ScalarType::Half && q_len >= 4096; const bool auto_f16_fused_q = - !value_is_fp8 && (head_dim == 16 || head_dim == 64) && + !force_fp32_pv_accum && !value_is_fp8 && (head_dim == 16 || head_dim == 64) && query.scalar_type() == ScalarType::Half && (is_causal || q_len >= 2048 || (head_dim == 64 && q_len >= 1024)) && q_len <= 8192 && !prefer_prepared_f16_causal; const bool auto_f16_raw_qk = - !value_is_fp8 && is_causal && head_dim == 16 && + !force_fp32_pv_accum && !value_is_fp8 && is_causal && head_dim == 16 && query.scalar_type() == ScalarType::Half && q_len <= 2048; if 
(!value_is_fp8 && is_causal && (head_dim == 16 || head_dim == 64) && query.scalar_type() == ScalarType::Half && @@ -6626,12 +7279,13 @@ Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( "native raw-QK fp16 path requires q_len to be a multiple of block rows"); const dim3 block(block_rows); const dim3 grid((q_len + block_rows - 1) / block_rows, q_heads, batch); - const bool use_f16_pv_accum = + const hipStream_t stream = current_hip_stream(query); + const bool use_f16_pv_accum = auto_f16_raw_qk && q_len <= 1024; constexpr bool use_f16_tvload = false; #define SAGEATTN_LAUNCH_RAW_QK_F16_CAUSAL(BR_, TVLOAD_, PAD_, F16ACC_) \ if (head_dim == 16) { \ - qk_int8_sv_f16_d64_native_2q_kernel<64, true, BR_, false, PAD_, true, TVLOAD_, F16ACC_, __half, true, __half, true, false, false, false, false, 16><<>>( \ + qk_int8_sv_f16_d64_native_2q_kernel<64, true, BR_, false, PAD_, true, TVLOAD_, F16ACC_, __half, true, __half, true, false, false, false, false, 16><<>>( \ reinterpret_cast(query.data_ptr()), \ reinterpret_cast(key.data_ptr()), \ reinterpret_cast(value.data_ptr()), \ @@ -6645,7 +7299,7 @@ Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( 0, 0, 0, 0, \ kHND, sm_scale); \ } else { \ - qk_int8_sv_f16_d64_native_2q_kernel<64, true, BR_, false, PAD_, true, TVLOAD_, F16ACC_, __half, true, __half, true><<>>( \ + qk_int8_sv_f16_d64_native_2q_kernel<64, true, BR_, false, PAD_, true, TVLOAD_, F16ACC_, __half, true, __half, true><<>>( \ reinterpret_cast(query.data_ptr()), \ reinterpret_cast(key.data_ptr()), \ reinterpret_cast(value.data_ptr()), \ @@ -6716,7 +7370,8 @@ Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( use_fused_q && block_cols == 64 && auto_fp8_streamcols4; const dim3 block(block_rows); const dim3 grid((q_len + block_rows - 1) / block_rows, q_heads, batch); - std::vector prepared; + const hipStream_t stream = current_hip_stream(query); + std::vector prepared; int8_t* fused_key_ptr = nullptr; uint8_t* fused_value_ptr = nullptr; float* fused_k_scale_ptr = nullptr; @@ 
-6759,12 +7414,11 @@ Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( fused_ks_stride_h = prepared[1].stride(1); } #define SAGEATTN_LAUNCH_PREPARED_FP8_EX(BC_, HD_, BR_, VT_, CAUSAL_, OUT_T_, LOWP_) \ - qk_int8_sv_f8_native_2q_kernel<<>>( \ + qk_int8_sv_f8_native_2q_kernel<<>>( \ reinterpret_cast(prepared[0].data_ptr()), reinterpret_cast(prepared[2].data_ptr()), \ reinterpret_cast(prepared[4].data_ptr()), \ reinterpret_cast(output.data_ptr()), \ - reinterpret_cast(prepared[1].data_ptr()), reinterpret_cast(prepared[3].data_ptr()), \ - nullptr, \ + reinterpret_cast(prepared[1].data_ptr()), reinterpret_cast(prepared[3].data_ptr()), nullptr, \ batch, q_len, kv_len, q_heads, kv_heads, \ prepared[0].stride(0), prepared[0].stride(2), prepared[0].stride(1), \ prepared[2].stride(0), prepared[2].stride(2), prepared[2].stride(1), \ @@ -6777,12 +7431,11 @@ Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( SAGEATTN_LAUNCH_PREPARED_FP8_EX(BC_, HD_, BR_, VT_, CAUSAL_, OUT_T_, false) #define SAGEATTN_LAUNCH_FUSED_Q_FP8_IMPL_SLICE(BC_, HD_, BR_, CAUSAL_, KVLANE_, SC_, KLANE_, VLANE_, VBASE_, VTILES_) \ - qk_int8_sv_f8_native_2q_kernel<<>>( \ + qk_int8_sv_f8_native_2q_kernel<<>>( \ reinterpret_cast(query.data_ptr()), fused_key_ptr, \ fused_value_ptr, \ reinterpret_cast<__half*>(output.data_ptr()), \ - nullptr, fused_k_scale_ptr, \ - nullptr, \ + nullptr, fused_k_scale_ptr, nullptr, \ batch, q_len, kv_len, q_heads, kv_heads, \ query.stride(0), query.stride(2), query.stride(1), \ fused_k_stride_b, fused_k_stride_n, fused_k_stride_h, \ @@ -6917,7 +7570,8 @@ Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( const int64_t q_blocks = (q_len + block_rows - 1) / block_rows; const dim3 grid(q_blocks, q_heads, batch); const dim3 grid_f16_flat(q_blocks * q_heads * batch); - constexpr bool use_f16_pv_accum = true; + const hipStream_t stream = current_hip_stream(query); + constexpr bool use_f16_pv_accum = true; const bool use_f16_pv_ordered_qk = use_f16_pv_accum && !use_f16_raw_value && q_len 
>= 1024; @@ -6937,7 +7591,7 @@ Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( prepare_kv_hnd_packed_gfx12<__half, false>(query, key, value); #define SAGEATTN_LAUNCH_F16_FUSED_Q_TV_CAUSAL(BC_, BR_, PAD_, F16ACC_, PVORDER_, VLANE_, STREAM_, KLANE_) \ if (head_dim == 16) { \ - qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(prepared[0].data_ptr()), \ reinterpret_cast(prepared[2].data_ptr()), \ reinterpret_cast<__half*>(output.data_ptr()), \ @@ -6952,7 +7606,7 @@ Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( kHND, sm_scale); \ } else { \ if (use_f16_flat_q_schedule) { \ - qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(prepared[0].data_ptr()), \ reinterpret_cast(prepared[2].data_ptr()), \ reinterpret_cast<__half*>(output.data_ptr()), \ @@ -6966,7 +7620,7 @@ Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( prepared[1].stride(0), prepared[1].stride(1), \ kHND, sm_scale); \ } else { \ - qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(prepared[0].data_ptr()), \ reinterpret_cast(prepared[2].data_ptr()), \ reinterpret_cast<__half*>(output.data_ptr()), \ @@ -6983,7 +7637,7 @@ Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( } #define SAGEATTN_LAUNCH_F16_FUSED_Q_TV_NONCAUSAL(BR_, PAD_, F16ACC_, PVORDER_, VLANE_, KLANE_) \ if (head_dim == 16) { \ - qk_int8_sv_f16_d64_native_2q_kernel<64, true, BR_, true, PAD_, false, false, F16ACC_, __half, true, int8_t, false, PVORDER_, false, false, false, 16><<>>( \ + qk_int8_sv_f16_d64_native_2q_kernel<64, true, BR_, true, PAD_, false, false, F16ACC_, __half, true, int8_t, false, PVORDER_, false, false, false, 16><<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(prepared[0].data_ptr()), \ reinterpret_cast(prepared[2].data_ptr()), \ 
reinterpret_cast<__half*>(output.data_ptr()), \ @@ -6997,7 +7651,7 @@ Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( prepared[1].stride(0), prepared[1].stride(1), \ kHND, sm_scale); \ } else { \ - qk_int8_sv_f16_d64_native_2q_kernel<64, true, BR_, true, PAD_, false, false, F16ACC_, __half, true, int8_t, false, PVORDER_, VLANE_, false, KLANE_><<>>( \ + qk_int8_sv_f16_d64_native_2q_kernel<64, true, BR_, true, PAD_, false, false, F16ACC_, __half, true, int8_t, false, PVORDER_, VLANE_, false, KLANE_><<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(prepared[0].data_ptr()), \ reinterpret_cast(prepared[2].data_ptr()), \ reinterpret_cast<__half*>(output.data_ptr()), \ @@ -7012,7 +7666,7 @@ Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( kHND, sm_scale); \ } #define SAGEATTN_LAUNCH_F16_FUSED_Q_RAWV_CAUSAL(BC_, BR_, PAD_, F16ACC_, VLANE_, STREAM_) \ - qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(prepared[0].data_ptr()), \ reinterpret_cast(value.data_ptr()), \ reinterpret_cast<__half*>(output.data_ptr()), \ @@ -7026,7 +7680,7 @@ Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( prepared[1].stride(0), prepared[1].stride(1), \ kHND, sm_scale) #define SAGEATTN_LAUNCH_F16_FUSED_Q_1Q_RAWV_CAUSAL(BR_, F16ACC_, SPLIT_) \ - qk_int8_sv_f16_d64_native_kernel<64, BR_, true, false, 4, true, true, F16ACC_, true, __half, true, SPLIT_><<>>( \ + qk_int8_sv_f16_d64_native_kernel<64, BR_, true, false, 4, true, true, F16ACC_, true, __half, true, SPLIT_><<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(prepared[0].data_ptr()), \ reinterpret_cast(value.data_ptr()), \ reinterpret_cast<__half*>(output.data_ptr()), \ @@ -7040,7 +7694,7 @@ Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( prepared[1].stride(0), prepared[1].stride(1), \ kHND, sm_scale) #define SAGEATTN_LAUNCH_F16_FUSED_Q_1Q_TV_CAUSAL(BR_, F16ACC_) \ - qk_int8_sv_f16_d64_native_kernel<64, BR_, true, true, 4, true, 
false, F16ACC_, true, __half, true><<>>( \ + qk_int8_sv_f16_d64_native_kernel<64, BR_, true, true, 4, true, false, F16ACC_, true, __half, true><<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(prepared[0].data_ptr()), \ reinterpret_cast(prepared[2].data_ptr()), \ reinterpret_cast<__half*>(output.data_ptr()), \ @@ -7125,7 +7779,7 @@ Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( qk_int8_sv_f16_d64_native_attn_gfx12_dispatch( prepared[0], prepared[2], value, output, prepared[1], prepared[3], kHND, is_causal, sm_scale, kv_len, Tensor(), 0, - static_cast(pv_accum_mode)); + pv_accum_mode); } else { const bool use_f16_separate_prepared = is_causal && head_dim == 64 && q_len == 4096 && @@ -7137,7 +7791,7 @@ Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12( qk_int8_sv_f16_d64_native_attn_gfx12_dispatch( prepared[0], prepared[2], prepared[4], output, prepared[1], prepared[3], kHND, is_causal, sm_scale, kv_len, Tensor(), 1, - static_cast(pv_accum_mode)); + pv_accum_mode); } return output; } @@ -7174,11 +7828,6 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( #endif STD_TORCH_CHECK(value.scalar_type() == ScalarType::Half || value_is_fp8, "value must be fp16 or raw OCP e4m3 fp8 bytes"); - const bool has_value_scale = value_scale.defined() && value_scale.numel() > 0; - STD_TORCH_CHECK(!has_value_scale || value_is_fp8, - "value_scale is only valid for the fp8 value path"); - const float* value_scale_ptr = - has_value_scale ? 
reinterpret_cast(value_scale.data_ptr()) : nullptr; const bool output_is_bf16 = output.scalar_type() == ScalarType::BFloat16; STD_TORCH_CHECK(output.scalar_type() == ScalarType::Half || (value_is_fp8 && output_is_bf16), "output must be fp16, or bf16 for the fp8 value path"); @@ -7216,8 +7865,48 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( STD_TORCH_CHECK(!is_causal || q_len == padded_kv_len, "native gfx12 causal path currently requires q_len == kv_len"); STD_TORCH_CHECK((q_heads % kv_heads) == 0, "q_heads must be divisible by kv_heads"); + STD_TORCH_CHECK(pv_accum_mode >= -1 && pv_accum_mode <= 1, + "invalid gfx12 fp16 PV accumulation mode"); STD_TORCH_CHECK(query_scale.stride(-1) == 1 && key_scale.stride(-1) == 1, "scale tensors must have contiguous scale columns"); + const int64_t per_warp_q_groups = ((q_len + 127) / 128) * 4; + const int64_t per_thread_q_groups_warp32 = ((q_len + 127) / 128) * 32; + const int64_t per_thread_q_groups_warp16 = ((q_len + 127) / 128) * 64; + const int64_t per_warp_k_groups = (padded_kv_len + 63) / 64; + const int64_t per_thread_k_groups = ((padded_kv_len + 63) / 64) * 4; + const bool use_per_thread_qk = + query_scale.size(2) == per_thread_q_groups_warp32 || + query_scale.size(2) == per_thread_q_groups_warp16 || + key_scale.size(2) == per_thread_k_groups; + STD_TORCH_CHECK((query_scale.size(2) == per_warp_q_groups && + key_scale.size(2) == per_warp_k_groups) || + ((query_scale.size(2) == per_thread_q_groups_warp32 || + query_scale.size(2) == per_thread_q_groups_warp16) && + key_scale.size(2) == per_thread_k_groups), + "gfx12 query/key scale shapes must both be per-warp or both be per-thread"); + if constexpr (!PerThreadQK) { + if (use_per_thread_qk) { + return qk_int8_sv_f16_d64_native_attn_gfx12_impl( + query, key, value, output, query_scale, key_scale, tensor_layout, + is_causal, sm_scale, valid_kv_len, value_scale, + value_transposed_hnd_hint, pv_accum_mode); + } + } + const bool has_value_scale = 
value_scale.defined() && value_scale.numel() > 0; + STD_TORCH_CHECK(!has_value_scale || value_is_fp8, + "value_scale is only valid for the fp8 value path"); + if (has_value_scale) { + STD_TORCH_CHECK(value_scale.is_cuda(), "value_scale must be a CUDA/HIP tensor"); + STD_TORCH_CHECK(value_scale.scalar_type() == ScalarType::Float, + "value_scale must be fp32"); + STD_TORCH_CHECK(value_scale.dim() == 3 && value_scale.is_contiguous(), + "value_scale must be contiguous [B, H_kv, D]"); + STD_TORCH_CHECK(value_scale.size(0) == batch && + value_scale.size(1) == kv_heads && + value_scale.size(2) == head_dim, + "value_scale shape must match [B, H_kv, D]"); + } + const float* value_scale_ptr = has_value_scale ? reinterpret_cast(value_scale.data_ptr()) : nullptr; const bool hnd_contiguous = tensor_layout == kHND && query.is_contiguous() && key.is_contiguous() && value.is_contiguous() && output.is_contiguous(); @@ -7277,6 +7966,13 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( q_len >= 4096 && (q_len % 256) == 0) { block_rows = 256; } + if constexpr (PerThreadQK) { + block_cols = 64; + block_rows = q_len <= 64 ? 64 : 128; + use_2q = !value_is_fp8; + use_fp8_2q = value_is_fp8; + use_f16_causal_1q = false; + } STD_TORCH_CHECK(!(use_fp8_2q && block_rows == 64 && block_cols == 128), "native fp8 2q BR64 is currently specialized for BC32/BC64"); STD_TORCH_CHECK(!(use_fp8_2q && block_rows == 256 && block_cols != 64), @@ -7302,12 +7998,14 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( dim3 block((use_2q || use_fp8_2q) ? block_rows : (block_rows == 128 ? 
256 : 128)); dim3 grid(q_blocks, q_heads, batch); dim3 grid_f16_flat(q_blocks * q_heads * batch); - const bool use_f16_tvload = + const hipStream_t stream = current_hip_stream(query); + const bool use_f16_tvload = !value_is_fp8 && is_causal && hnd_contiguous && !value_transposed_hnd && q_len >= 1024; const bool use_f16_pv_accum = - !value_is_fp8 && is_causal && value_transposed_hnd && block_cols == 64 && - q_len >= 1024; + !value_is_fp8 && pv_accum_mode != 0 && + (pv_accum_mode == 1 || + (is_causal && value_transposed_hnd && block_cols == 64 && q_len >= 1024)); const bool use_f16_pv_ordered_qk = use_f16_pv_accum && q_len >= 4096; const bool use_f16_vlane = @@ -7316,11 +8014,14 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( const bool use_f16_streamk = head_dim == 64 && !value_is_fp8 && is_causal && value_transposed_hnd && q_len == 4096 && block_rows == 256; + const bool use_f16_d128_short_stream = + head_dim == 128 && !value_is_fp8 && is_causal && block_cols == 64 && + block_rows == 128 && q_len <= 1024; if constexpr (PerThreadQK) { STD_TORCH_CHECK(value_transposed_hnd, "gfx12 per-thread QK path expects transposed HND values"); #define SAGEATTN_LAUNCH_PERTHREAD_FP8_OUT(HD_, BR_, CAUSAL_, OUT_T_) \ - qk_int8_sv_f8_native_2q_kernel<64, HD_, 0, ((HD_) / 16), true, BR_, true, CAUSAL_, OUT_T_, int8_t, false, int8_t, uint8_t, false, false, 0, false, false, 2, false, true><<>>( \ + qk_int8_sv_f8_native_2q_kernel<64, HD_, 0, ((HD_) / 16), true, BR_, true, CAUSAL_, OUT_T_, int8_t, false, int8_t, uint8_t, false, false, 0, false, false, 2, false, true><<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ reinterpret_cast(value.data_ptr()), \ reinterpret_cast(output.data_ptr()), \ @@ -7340,7 +8041,7 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( SAGEATTN_LAUNCH_PERTHREAD_FP8_OUT(HD_, BR_, CAUSAL_, __half); \ } #define SAGEATTN_LAUNCH_PERTHREAD_F16(HD_, BR_, CAUSAL_) \ - qk_int8_sv_f16_d64_native_2q_kernel<64, true, BR_, true, 
4, CAUSAL_, false, false, int8_t, false, int8_t, false, false, false, false, false, HD_, false, true><<>>( \ + qk_int8_sv_f16_d64_native_2q_kernel<64, true, BR_, true, 4, CAUSAL_, false, false, int8_t, false, int8_t, false, false, false, false, false, HD_, false, true><<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ reinterpret_cast(value.data_ptr()), \ reinterpret_cast<__half*>(output.data_ptr()), \ @@ -7391,12 +8092,11 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( } #define SAGEATTN_LAUNCH_FP8_2Q_OUT(BC_, HD_, HND_, BR_, OUT_T_) \ if (is_causal) { \ - qk_int8_sv_f8_native_2q_kernel<<>>( \ + qk_int8_sv_f8_native_2q_kernel<<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ reinterpret_cast(value.data_ptr()), \ reinterpret_cast(output.data_ptr()), \ - reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), \ - value_scale_ptr, \ + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), value_scale_ptr, \ batch, q_len, kv_len, q_heads, kv_heads, \ query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), \ key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), \ @@ -7404,14 +8104,13 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 
2 : 1), \ query_scale.stride(0), query_scale.stride(1), \ key_scale.stride(0), key_scale.stride(1), \ - tensor_layout, sm_scale); \ + tensor_layout, sm_scale, use_per_thread_qk); \ } else { \ - qk_int8_sv_f8_native_2q_kernel<<>>( \ + qk_int8_sv_f8_native_2q_kernel<<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ reinterpret_cast(value.data_ptr()), \ reinterpret_cast(output.data_ptr()), \ - reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), \ - value_scale_ptr, \ + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), value_scale_ptr, \ batch, q_len, kv_len, q_heads, kv_heads, \ query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), \ key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), \ @@ -7419,7 +8118,7 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 
2 : 1), \ query_scale.stride(0), query_scale.stride(1), \ key_scale.stride(0), key_scale.stride(1), \ - tensor_layout, sm_scale); \ + tensor_layout, sm_scale, use_per_thread_qk); \ } #define SAGEATTN_LAUNCH_FP8_2Q(BC_, HD_, HND_, BR_) \ if (output_is_bf16) { \ @@ -7429,12 +8128,11 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( } #define SAGEATTN_LAUNCH_FP8_2Q_TV_OUT(BC_, HD_, BR_, OUT_T_) \ if (is_causal) { \ - qk_int8_sv_f8_native_2q_kernel<<>>( \ + qk_int8_sv_f8_native_2q_kernel<<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ reinterpret_cast(value.data_ptr()), \ reinterpret_cast(output.data_ptr()), \ - reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), \ - value_scale_ptr, \ + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), value_scale_ptr, \ batch, q_len, kv_len, q_heads, kv_heads, \ query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), \ key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), \ @@ -7442,14 +8140,13 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 
2 : 1), \ query_scale.stride(0), query_scale.stride(1), \ key_scale.stride(0), key_scale.stride(1), \ - tensor_layout, sm_scale); \ + tensor_layout, sm_scale, use_per_thread_qk); \ } else { \ - qk_int8_sv_f8_native_2q_kernel<<>>( \ + qk_int8_sv_f8_native_2q_kernel<<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ reinterpret_cast(value.data_ptr()), \ reinterpret_cast(output.data_ptr()), \ - reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), \ - value_scale_ptr, \ + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), value_scale_ptr, \ batch, q_len, kv_len, q_heads, kv_heads, \ query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), \ key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), \ @@ -7457,7 +8154,7 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 2 : 1), \ query_scale.stride(0), query_scale.stride(1), \ key_scale.stride(0), key_scale.stride(1), \ - tensor_layout, sm_scale); \ + tensor_layout, sm_scale, use_per_thread_qk); \ } #define SAGEATTN_LAUNCH_FP8_2Q_TV(BC_, HD_, BR_) \ if (output_is_bf16) { \ @@ -7466,7 +8163,7 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( SAGEATTN_LAUNCH_FP8_2Q_TV_OUT(BC_, HD_, BR_, __half); \ } #define SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL_GRID_HD(HD_, BC_, BR_, PAD_, F16ACC_, PVORDER_, VLANE_, STREAM_, KLANE_, GRID_, FLAT_) \ - qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ reinterpret_cast(value.data_ptr()), \ reinterpret_cast<__half*>(output.data_ptr()), \ @@ -7478,7 +8175,7 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( output.stride(0), output.stride(tensor_layout == kNHD ? 
1 : 2), output.stride(tensor_layout == kNHD ? 2 : 1), \ query_scale.stride(0), query_scale.stride(1), \ key_scale.stride(0), key_scale.stride(1), \ - tensor_layout, sm_scale) + tensor_layout, sm_scale, use_per_thread_qk) #define SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL_GRID(BC_, BR_, PAD_, F16ACC_, PVORDER_, VLANE_, STREAM_, KLANE_, GRID_, FLAT_) \ SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL_GRID_HD(64, BC_, BR_, PAD_, F16ACC_, PVORDER_, VLANE_, STREAM_, KLANE_, GRID_, FLAT_) #define SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL(BC_, BR_, PAD_, F16ACC_, PVORDER_, VLANE_, STREAM_, KLANE_) \ @@ -7488,7 +8185,7 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL_GRID(BC_, BR_, PAD_, F16ACC_, PVORDER_, VLANE_, STREAM_, KLANE_, grid, false); \ } #define SAGEATTN_LAUNCH_F16_D16_2Q_TV(BC_, BR_, PAD_, CAUSAL_, F16ACC_) \ - qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ reinterpret_cast(value.data_ptr()), \ reinterpret_cast<__half*>(output.data_ptr()), \ @@ -7500,12 +8197,31 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 
2 : 1), \ query_scale.stride(0), query_scale.stride(1), \ key_scale.stride(0), key_scale.stride(1), \ - tensor_layout, sm_scale) + tensor_layout, sm_scale, use_per_thread_qk) #define SAGEATTN_LAUNCH_F16_D128_2Q_TV(BC_, BR_, PAD_) \ if (is_causal) { \ - SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL_GRID_HD(128, BC_, BR_, PAD_, false, false, false, false, false, grid, false); \ + if (use_f16_d128_short_stream) { \ + SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL_GRID_HD(128, BC_, BR_, PAD_, false, false, false, true, false, grid, false); \ + } else { \ + SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL_GRID_HD(128, BC_, BR_, PAD_, false, false, false, false, false, grid, false); \ + } \ } else { \ - qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + if (use_f16_d128_short_stream) { \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast<__half*>(output.data_ptr()), \ + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), \ + batch, q_len, kv_len, q_heads, kv_heads, \ + query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), \ + key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), \ + value.stride(0), value.stride(2), value.stride(1), \ + output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 2 : 1), \ + query_scale.stride(0), query_scale.stride(1), \ + key_scale.stride(0), key_scale.stride(1), \ + tensor_layout, sm_scale, use_per_thread_qk); \ + } else { \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ reinterpret_cast(value.data_ptr()), \ reinterpret_cast<__half*>(output.data_ptr()), \ @@ -7517,7 +8233,8 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( output.stride(0), output.stride(tensor_layout == kNHD ? 
1 : 2), output.stride(tensor_layout == kNHD ? 2 : 1), \ query_scale.stride(0), query_scale.stride(1), \ key_scale.stride(0), key_scale.stride(1), \ - tensor_layout, sm_scale); \ + tensor_layout, sm_scale, use_per_thread_qk); \ + } \ } #define SAGEATTN_LAUNCH_F16_2Q_TV(BC_, BR_, PAD_) \ if (is_causal) { \ @@ -7540,7 +8257,7 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( else { SAGEATTN_LAUNCH_F16_2Q_TV_CAUSAL(BC_, BR_, PAD_, false, false, false, false, false); } \ } \ } else { \ - qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ reinterpret_cast(value.data_ptr()), \ reinterpret_cast<__half*>(output.data_ptr()), \ @@ -7552,11 +8269,11 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 2 : 1), \ query_scale.stride(0), query_scale.stride(1), \ key_scale.stride(0), key_scale.stride(1), \ - tensor_layout, sm_scale); \ + tensor_layout, sm_scale, use_per_thread_qk); \ } #define SAGEATTN_LAUNCH_F16_2Q(BC_, HND_, BR_) \ if (is_causal) { \ - qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ reinterpret_cast(value.data_ptr()), \ reinterpret_cast<__half*>(output.data_ptr()), \ @@ -7568,9 +8285,9 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 
2 : 1), \ query_scale.stride(0), query_scale.stride(1), \ key_scale.stride(0), key_scale.stride(1), \ - tensor_layout, sm_scale); \ + tensor_layout, sm_scale, use_per_thread_qk); \ } else { \ - qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ reinterpret_cast(value.data_ptr()), \ reinterpret_cast<__half*>(output.data_ptr()), \ @@ -7582,10 +8299,10 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 2 : 1), \ query_scale.stride(0), query_scale.stride(1), \ key_scale.stride(0), key_scale.stride(1), \ - tensor_layout, sm_scale); \ + tensor_layout, sm_scale, use_per_thread_qk); \ } #define SAGEATTN_LAUNCH_F16_2Q_TVLOAD_CAUSAL(BC_, BR_, PAD_, F16ACC_) \ - qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ reinterpret_cast(value.data_ptr()), \ reinterpret_cast<__half*>(output.data_ptr()), \ @@ -7597,12 +8314,12 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 
2 : 1), \ query_scale.stride(0), query_scale.stride(1), \ key_scale.stride(0), key_scale.stride(1), \ - tensor_layout, sm_scale) + tensor_layout, sm_scale, use_per_thread_qk) #define SAGEATTN_LAUNCH_F16_2Q_TVLOAD(BC_, BR_, PAD_) \ if (is_causal) { \ SAGEATTN_LAUNCH_F16_2Q_TVLOAD_CAUSAL(BC_, BR_, PAD_, false); \ } else { \ - qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ reinterpret_cast(value.data_ptr()), \ reinterpret_cast<__half*>(output.data_ptr()), \ @@ -7614,10 +8331,10 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 2 : 1), \ query_scale.stride(0), query_scale.stride(1), \ key_scale.stride(0), key_scale.stride(1), \ - tensor_layout, sm_scale); \ + tensor_layout, sm_scale, use_per_thread_qk); \ } #define SAGEATTN_LAUNCH_F16_1Q(BC_, BR_) \ - qk_int8_sv_f16_d64_native_kernel<<>>( \ + qk_int8_sv_f16_d64_native_kernel<<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ reinterpret_cast(value.data_ptr()), \ reinterpret_cast<__half*>(output.data_ptr()), \ @@ -7629,9 +8346,9 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 
2 : 1), \ query_scale.stride(0), query_scale.stride(1), \ key_scale.stride(0), key_scale.stride(1), \ - tensor_layout, sm_scale) + tensor_layout, sm_scale, use_per_thread_qk) #define SAGEATTN_LAUNCH_F16_1Q_CAUSAL(BR_, TRANSPOSED_, TVLOAD_, PAD_, F16ACC_) \ - qk_int8_sv_f16_d64_native_kernel<64, BR_, true, TRANSPOSED_, PAD_, true, TVLOAD_, F16ACC_><<>>( \ + qk_int8_sv_f16_d64_native_kernel<64, BR_, true, TRANSPOSED_, PAD_, true, TVLOAD_, F16ACC_><<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ reinterpret_cast(value.data_ptr()), \ reinterpret_cast<__half*>(output.data_ptr()), \ @@ -7644,7 +8361,7 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 2 : 1), \ query_scale.stride(0), query_scale.stride(1), \ key_scale.stride(0), key_scale.stride(1), \ - tensor_layout, sm_scale) + tensor_layout, sm_scale, use_per_thread_qk) #if SAGEATTN_GFX12_BUILD_ATTN_F16 if (use_f16_causal_1q) { STD_TORCH_CHECK(hnd_contiguous, "fp16 single-q causal path requires contiguous HND tensors"); @@ -7801,12 +8518,11 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( SAGEATTN_LAUNCH_FP8_2Q(32, 64, false, 128); } } else if (use_fp8_2q && block_cols == 128 && head_dim == 128) { - qk_int8_sv_f8_native_2q_kernel<128, 128, 0, 8><<>>( + qk_int8_sv_f8_native_2q_kernel<128, 128, 0, 8><<>>( reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), reinterpret_cast<__half*>(output.data_ptr()), - reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), - value_scale_ptr, + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), value_scale_ptr, batch, q_len, kv_len, q_heads, kv_heads, query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), key.stride(0), key.stride(tensor_layout == kNHD ? 
1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), @@ -7814,14 +8530,13 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 2 : 1), query_scale.stride(0), query_scale.stride(1), key_scale.stride(0), key_scale.stride(1), - tensor_layout, sm_scale); + tensor_layout, sm_scale, use_per_thread_qk); } else if (use_fp8_2q && block_cols == 128) { - qk_int8_sv_f8_native_2q_kernel<128, 64, 0, 4><<>>( + qk_int8_sv_f8_native_2q_kernel<128, 64, 0, 4><<>>( reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), reinterpret_cast<__half*>(output.data_ptr()), - reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), - value_scale_ptr, + reinterpret_cast(query_scale.data_ptr()), reinterpret_cast(key_scale.data_ptr()), value_scale_ptr, batch, q_len, kv_len, q_heads, kv_heads, query.stride(0), query.stride(tensor_layout == kNHD ? 1 : 2), query.stride(tensor_layout == kNHD ? 2 : 1), key.stride(0), key.stride(tensor_layout == kNHD ? 1 : 2), key.stride(tensor_layout == kNHD ? 2 : 1), @@ -7829,7 +8544,7 @@ static Tensor qk_int8_sv_f16_d64_native_attn_gfx12_impl( output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 2 : 1), query_scale.stride(0), query_scale.stride(1), key_scale.stride(0), key_scale.stride(1), - tensor_layout, sm_scale); + tensor_layout, sm_scale, use_per_thread_qk); } else if (use_fp8_2q && head_dim == 128) { if (hnd_contiguous) { SAGEATTN_LAUNCH_FP8_2Q(64, 128, true, 128); @@ -8125,8 +8840,7 @@ static Tensor qk_rawq_int8_sv_f8_native_attn_gfx12_impl( value_scale.size(2) == head_dim, "value_scale shape must match [B, H_kv, D]"); } - const float* value_scale_ptr = - has_value_scale ? reinterpret_cast(value_scale.data_ptr()) : nullptr; + const float* value_scale_ptr = has_value_scale ? 
reinterpret_cast(value_scale.data_ptr()) : nullptr; int block_rows = head_dim == 64 ? select_fp8_d64_block_rows_gfx12(q_len, is_causal, value_transposed_hnd) : @@ -8149,9 +8863,10 @@ static Tensor qk_rawq_int8_sv_f8_native_attn_gfx12_impl( const bool hnd_contiguous = tensor_layout == kHND; const dim3 block(block_rows); const dim3 grid(q_blocks, q_heads, batch); + const hipStream_t stream = current_hip_stream(query); -#define SAGEATTN_LAUNCH_RAWQ_FP8_TYPED_EX(BC_, HD_, HND_, KEY_HND_, BR_, VT_, CAUSAL_, QUERY_T_, OUT_T_, QUERY_AT_T_, OUT_AT_T_, STATIC_NHD_, NO_TAIL_, SAME_HEADS_, NO_Q_TAIL_) \ - qk_int8_sv_f8_native_2q_kernel<<>>( \ +#define SAGEATTN_LAUNCH_RAWQ_FP8_TYPED_EX(BC_, HD_, HND_, KEY_HND_, BR_, VT_, CAUSAL_, QUERY_T_, OUT_T_, QUERY_AT_T_, OUT_AT_T_, STATIC_NHD_, NO_TAIL_, SAME_HEADS_, NO_Q_TAIL_, INVL_) \ + qk_int8_sv_f8_native_2q_kernel<<>>( \ reinterpret_cast(query.data_ptr()), \ reinterpret_cast(key.data_ptr()), reinterpret_cast(value.data_ptr()), \ reinterpret_cast(output.data_ptr()), \ @@ -8165,7 +8880,7 @@ static Tensor qk_rawq_int8_sv_f8_native_attn_gfx12_impl( key_scale.stride(0), key_scale.stride(1), \ tensor_layout, sm_scale) #define SAGEATTN_LAUNCH_RAWQ_FP8_TYPED(BC_, HD_, HND_, KEY_HND_, BR_, VT_, CAUSAL_, QUERY_T_, OUT_T_, QUERY_AT_T_, OUT_AT_T_) \ - SAGEATTN_LAUNCH_RAWQ_FP8_TYPED_EX(BC_, HD_, HND_, KEY_HND_, BR_, VT_, CAUSAL_, QUERY_T_, OUT_T_, QUERY_AT_T_, OUT_AT_T_, false, false, false, false) + SAGEATTN_LAUNCH_RAWQ_FP8_TYPED_EX(BC_, HD_, HND_, KEY_HND_, BR_, VT_, CAUSAL_, QUERY_T_, OUT_T_, QUERY_AT_T_, OUT_AT_T_, false, false, false, false, false) #define SAGEATTN_DISPATCH_RAWQ_FP8_OUT(BC_, HD_, HND_, KEY_HND_, BR_, VT_, CAUSAL_, QUERY_T_, QUERY_AT_T_) \ if (output.scalar_type() == ScalarType::BFloat16) { \ SAGEATTN_LAUNCH_RAWQ_FP8_TYPED(BC_, HD_, HND_, KEY_HND_, BR_, VT_, CAUSAL_, QUERY_T_, __hip_bfloat16, QUERY_AT_T_, at::BFloat16); \ @@ -8183,8 +8898,6 @@ static Tensor qk_rawq_int8_sv_f8_native_attn_gfx12_impl( 
SAGEATTN_DISPATCH_RAWQ_FP8_QUERY(BC_, HD_, HND_, KEY_HND_, 64, VT_, CAUSAL_); \ } else if (block_rows == 256) { \ SAGEATTN_DISPATCH_RAWQ_FP8_QUERY(BC_, HD_, HND_, KEY_HND_, 256, VT_, CAUSAL_); \ - } else if (block_rows == 512) { \ - SAGEATTN_DISPATCH_RAWQ_FP8_QUERY(BC_, HD_, HND_, KEY_HND_, 512, VT_, CAUSAL_); \ } else { \ SAGEATTN_DISPATCH_RAWQ_FP8_QUERY(BC_, HD_, HND_, KEY_HND_, 128, VT_, CAUSAL_); \ } @@ -8228,11 +8941,32 @@ static Tensor qk_rawq_int8_sv_f8_native_attn_gfx12_impl( !key_hnd_contiguous && q_heads == kv_heads && q_len == kv_len && q_len == 512 && head_dim == 128 && query.scalar_type() == ScalarType::Half && output.scalar_type() == ScalarType::Half; + const bool use_static_causal_nhd = + is_causal && value_transposed_hnd && tensor_layout == kNHD && + !key_hnd_contiguous && q_heads == kv_heads && q_len == kv_len && + block_rows == 128 && + (q_len % block_rows) == 0 && + (head_dim == 128 || (head_dim == 64 && q_len >= 1024)) && + query.scalar_type() == ScalarType::Half && output.scalar_type() == ScalarType::Half; + const bool use_bc32_causal_short_nhd = + use_static_causal_nhd && head_dim == 128 && q_len <= 1024; if (use_static_short_nhd) { SAGEATTN_LAUNCH_RAWQ_FP8_TYPED_EX(64, 128, false, false, 128, true, false, __half, __half, at::Half, at::Half, - true, true, true, true); + true, true, true, true, false); + } else if (use_static_causal_nhd && head_dim == 64) { + SAGEATTN_LAUNCH_RAWQ_FP8_TYPED_EX(64, 64, false, false, 128, true, true, + __half, __half, at::Half, at::Half, + true, true, true, true, false); + } else if (use_bc32_causal_short_nhd) { + SAGEATTN_LAUNCH_RAWQ_FP8_TYPED_EX(32, 128, false, false, 128, true, true, + __half, __half, at::Half, at::Half, + true, true, true, true, true); + } else if (use_static_causal_nhd) { + SAGEATTN_LAUNCH_RAWQ_FP8_TYPED_EX(64, 128, false, false, 128, true, true, + __half, __half, at::Half, at::Half, + true, true, true, true, false); } else if (use_bc32) { SAGEATTN_DISPATCH_RAWQ_FP8_QUERY(32, 128, false, 
false, 128, true, false); } else { @@ -8265,8 +8999,7 @@ Tensor qk_rawq_int8_sv_f8_native_attn_gfx12( return qk_rawq_int8_sv_f8_native_attn_gfx12_impl( query, key, value, output, key_scale, Tensor(), static_cast(tensor_layout), static_cast(is_causal), - static_cast(sm_scale), valid_kv_len, - static_cast(value_transposed_hnd), + static_cast(sm_scale), valid_kv_len, static_cast(value_transposed_hnd), static_cast(key_hnd_layout)); } @@ -8286,11 +9019,82 @@ Tensor qk_rawq_int8_sv_f8_scaled_native_attn_gfx12( return qk_rawq_int8_sv_f8_native_attn_gfx12_impl( query, key, value, output, key_scale, value_scale, static_cast(tensor_layout), static_cast(is_causal), - static_cast(sm_scale), valid_kv_len, - static_cast(value_transposed_hnd), + static_cast(sm_scale), valid_kv_len, static_cast(value_transposed_hnd), static_cast(key_hnd_layout)); } +std::vector mean_and_fp8_value_nhd_short_gfx12( + Tensor key, + Tensor value, + double scale_max); + +Tensor sage_fp8_nhd_short_mha_gfx12( + Tensor query, + Tensor key, + Tensor value, + int64_t is_causal, + double sm_scale, + double scale_max) { + STD_TORCH_CHECK(query.is_cuda() && key.is_cuda() && value.is_cuda(), + "gfx12 short fp8 wrapper expects CUDA/HIP tensors"); + STD_TORCH_CHECK(query.dim() == 4 && key.dim() == 4 && value.dim() == 4, + "gfx12 short fp8 wrapper expects [B, S, H, D]"); + STD_TORCH_CHECK(query.is_contiguous() && key.is_contiguous() && value.is_contiguous(), + "gfx12 short fp8 wrapper expects contiguous NHD tensors"); + STD_TORCH_CHECK(query.scalar_type() == ScalarType::Half && + key.scalar_type() == ScalarType::Half && + value.scalar_type() == ScalarType::Half, + "gfx12 short fp8 wrapper currently supports fp16 inputs"); + STD_TORCH_CHECK(same_sizes(query, key) && same_sizes(query, value), + "gfx12 short fp8 wrapper expects matching Q/K/V shapes"); + const int64_t batch = query.size(0); + const int64_t seq_len = query.size(1); + const int64_t heads = query.size(2); + const int64_t head_dim = query.size(3); + 
STD_TORCH_CHECK((seq_len == 512 || seq_len == 1024 || seq_len == 2048 || + seq_len == 4096 || seq_len == 8192) && + (head_dim == 64 || head_dim == 128), + "gfx12 fp8 wrapper supports S512/S1024/S2048/S4096/S8192 and D64/D128"); + + std::vector prep = + mean_and_fp8_value_nhd_short_gfx12(key, value, scale_max); + Tensor key_mean = prep[0]; + Tensor value_native = prep[1]; + Tensor value_scale = prep[2]; + Tensor key_int8 = new_empty_like(key, key.sizes(), ScalarType::Char); + Tensor key_scale = + new_empty_like(key, {batch, heads, (seq_len + 63) / 64}, ScalarType::Float); + + const dim3 grid((seq_len + 63) / 64, heads, batch); + const hipStream_t stream = current_hip_stream(key); + if (head_dim == 64) { + constexpr int NumPack = 1; + dim3 block(64 * (64 / 8) / NumPack); + quant_k_nhd_fuse_sub_mean_short_kernel<__half, 64, NumPack><<>>( + reinterpret_cast(key.data_ptr()), + reinterpret_cast(key_mean.data_ptr()), + reinterpret_cast(key_int8.data_ptr()), + reinterpret_cast(key_scale.data_ptr()), + seq_len, heads); + } else { + constexpr int NumPack = 2; + dim3 block(64 * (128 / 8) / NumPack); + quant_k_nhd_fuse_sub_mean_short_kernel<__half, 128, NumPack><<>>( + reinterpret_cast(key.data_ptr()), + reinterpret_cast(key_mean.data_ptr()), + reinterpret_cast(key_int8.data_ptr()), + reinterpret_cast(key_scale.data_ptr()), + seq_len, heads); + } + hip_kernel_launch_check(); + + Tensor output = torch::stable::empty_like(query); + qk_rawq_int8_sv_f8_scaled_native_attn_gfx12( + query, key_int8, value_native, output, key_scale, value_scale, + kNHD, is_causal, sm_scale, seq_len, 1, 0); + return output; +} + #endif // SAGEATTN_GFX12_BUILD_RAWQ_FP8 #if SAGEATTN_GFX12_BUILD_ATTN_FP8 @@ -8308,9 +9112,8 @@ Tensor qk_int8_sv_f8_scaled_native_attn_gfx12( double sm_scale, int64_t valid_kv_len) { return qk_int8_sv_f16_d64_native_attn_gfx12_impl( - query, key, value, output, query_scale, key_scale, - static_cast(tensor_layout), static_cast(is_causal), - static_cast(sm_scale), valid_kv_len, 
value_scale, 1, -1); + query, key, value, output, query_scale, key_scale, static_cast(tensor_layout), + static_cast(is_causal), static_cast(sm_scale), valid_kv_len, value_scale, 1, -1); } #endif // SAGEATTN_GFX12_BUILD_ATTN_FP8 @@ -8331,9 +9134,8 @@ Tensor qk_int8_sv_f16_d64_native_attn_gfx12( int64_t value_transposed_hnd, int64_t pv_accum_mode) { return qk_int8_sv_f16_d64_native_attn_gfx12_impl( - query, key, value, output, query_scale, key_scale, - static_cast(tensor_layout), static_cast(is_causal), - static_cast(sm_scale), valid_kv_len, Tensor(), + query, key, value, output, query_scale, key_scale, static_cast(tensor_layout), + static_cast(is_causal), static_cast(sm_scale), valid_kv_len, Tensor(), static_cast(value_transposed_hnd), static_cast(pv_accum_mode)); } @@ -8391,9 +9193,10 @@ Tensor qk_rawq_int8_sv_f16_native_attn_gfx12( const bool hnd_contiguous = tensor_layout == kHND && query.is_contiguous() && key.is_contiguous() && value.is_contiguous() && output.is_contiguous(); - const int block_rows = q_len <= 64 ? 64 : 128; + int block_rows = q_len <= 64 ? 
64 : 128; const dim3 block(block_rows); const dim3 grid((q_len + block_rows - 1) / block_rows, q_heads, batch); + const hipStream_t stream = current_hip_stream(query); const bool use_d128_short_stream = is_causal && head_dim == 128 && block_rows == 128 && q_len <= 1024; const bool use_direct_stream_probs = @@ -8420,7 +9223,7 @@ Tensor qk_rawq_int8_sv_f16_native_attn_gfx12( (head_dim == 64 || head_dim == 128); #define SAGEATTN_LAUNCH_RAWQ_F16_VALUE(BC_, HD_, HND_, BR_, CAUSAL_, QUERY_T_, F16ACC_, STREAM_, PVORDER_, STATIC_NHD_, NO_TAIL_, SAME_HEADS_, NO_Q_TAIL_, PREFETCH_STREAM_V_, DIRECT_STREAM_PROBS_, DIRECT_PV_OUTFRAG_) \ - qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ + qk_int8_sv_f16_d64_native_2q_kernel<<>>( \ reinterpret_cast(query.data_ptr()), reinterpret_cast(key.data_ptr()), \ reinterpret_cast(value.data_ptr()), \ reinterpret_cast<__half*>(output.data_ptr()), \ @@ -8431,7 +9234,7 @@ Tensor qk_rawq_int8_sv_f16_native_attn_gfx12( value.stride(0), value.stride(tensor_layout == kNHD ? 1 : 2), value.stride(tensor_layout == kNHD ? 2 : 1), \ output.stride(0), output.stride(tensor_layout == kNHD ? 1 : 2), output.stride(tensor_layout == kNHD ? 
2 : 1), \ 0, 0, key_scale.stride(0), key_scale.stride(1), \ - tensor_layout, static_cast(sm_scale)) + tensor_layout, sm_scale, false) #define SAGEATTN_LAUNCH_RAWQ_F16_VALUE_DEFAULT(HD_, HND_, BR_, CAUSAL_, QUERY_T_, F16ACC_, STREAM_) \ SAGEATTN_LAUNCH_RAWQ_F16_VALUE(64, HD_, HND_, BR_, CAUSAL_, QUERY_T_, F16ACC_, STREAM_, false, false, false, false, false, false, false, false) #define SAGEATTN_LAUNCH_RAWQ_F16_VALUE_STATIC_NHD(BC_, HD_, BR_, CAUSAL_, QUERY_T_, F16ACC_, STREAM_, PREFETCH_STREAM_V_, DIRECT_STREAM_PROBS_, DIRECT_PV_OUTFRAG_) \ diff --git a/sageattention/core.py b/sageattention/core.py index a4b5fc58..cc47babd 100644 --- a/sageattention/core.py +++ b/sageattention/core.py @@ -101,6 +101,39 @@ def _get_gfx12_native_extension(): return _qattn_gfx12_native +def _try_gfx12_fp8_nhd_short_mha( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + is_causal: bool, + sm_scale: float, + fp8_value_scale_max: float, +) -> Optional[torch.Tensor]: + if not ( + q.is_cuda + and k.is_cuda + and v.is_cuda + and q.device == k.device == v.device + and q.dtype == k.dtype == v.dtype == torch.float16 + and q.is_contiguous() + and k.is_contiguous() + and v.is_contiguous() + and q.dim() == 4 + and k.dim() == 4 + and v.dim() == 4 + and q.shape == k.shape == v.shape + and q.size(1) in (512, 1024, 2048, 4096, 8192) + and q.size(3) in (64, 128) + ): + return None + + torch.cuda.set_device(q.device) + gfx12_native = _get_gfx12_native_extension() + return gfx12_native.sage_fp8_nhd_short_mha( + q, k, v, int(is_causal), float(sm_scale), float(fp8_value_scale_max) + ) + + def _get_gfx12_prepare_attn_hnd(): _get_gfx12_native_extension() return _qattn_gfx12_prepare_attn_hnd @@ -398,9 +431,7 @@ def _with_lse(out: torch.Tensor): out = out if out.dtype == torch.bfloat16 else gfx12_native.convert_f16_to_bf16(out) return _with_lse(out) - if tensor_layout == "NHD" and smooth_k and qk_quant_gran == "per_warp" and not ( - value_dtype == "fp16" and q.size(-1) > 64 - ): + if tensor_layout == 
"NHD" and smooth_k and qk_quant_gran == "per_warp": q_nhd = q.contiguous() k_nhd = k.contiguous() v_nhd = v.contiguous() @@ -431,14 +462,26 @@ def _with_lse(out: torch.Tensor): if value_dtype == "fp8" and head_dim not in (16, 64, 128): raise ValueError("gfx12 fp8 value path currently supports head_dim 16, 64, or 128.") + use_gfx12_fp8_nhd_mha_wrapper = ( + value_dtype == "fp8" + and input_dtype == torch.float16 + and qo_len == kv_len + and kv_len in (512, 1024, 2048, 4096, 8192) + and head_dim in (64, 128) + ) use_short_nhd_fp8_prep = ( value_dtype == "fp8" - and not is_causal and input_dtype == torch.float16 and qo_len == kv_len and kv_len in (512, 1024) and head_dim in (64, 128) ) + if use_gfx12_fp8_nhd_mha_wrapper and head_dim_og in (64, 128) and h_qo == h_kv: + out = _try_gfx12_fp8_nhd_short_mha( + q_nhd, k_nhd, v_nhd, is_causal, float(sm_scale), fp8_value_scale_max + ) + if out is not None: + return _with_lse(out) value_native = None value_scale = None if use_short_nhd_fp8_prep: @@ -448,6 +491,9 @@ def _with_lse(out: torch.Tensor): ) ) k_mean = k_mean_flat.unsqueeze(1) + elif value_dtype == "fp16" and head_dim in (64, 128): + k_mean_flat = gfx12_native.mean_nhd(k_nhd) + k_mean = k_mean_flat.unsqueeze(1) else: k_mean = k_nhd.mean(dim=1, keepdim=True) k_mean_flat = k_mean.squeeze(1) @@ -463,9 +509,13 @@ def _with_lse(out: torch.Tensor): ) use_rawq_f16_value = ( value_dtype == "fp16" - and not is_causal - and head_dim == 64 + and input_dtype == torch.float16 + and head_dim in (64, 128) and qk_quant_gran == "per_warp" + and ( + not is_causal + or (qo_len == kv_len and qo_len % 64 == 0 and kv_len % 64 == 0) + ) ) if use_rawq_tail or use_rawq_f16_value: q_attn = q_nhd @@ -530,7 +580,7 @@ def _with_lse(out: torch.Tensor): int(use_mixed_key_hnd), ) else: - if head_dim == 64 and qk_quant_gran == "per_warp": + if use_rawq_f16_value: gfx12_native.qk_rawq_int8_sv_f16_native_attn( q_attn, k_int8, @@ -604,7 +654,12 @@ def _with_lse(out: torch.Tensor): if value_dtype == 
"fp8" and head_dim not in (16, 64, 128): raise ValueError("gfx12 fp8 value path currently supports head_dim 16, 64, or 128.") - k_mean = k_hnd.mean(dim=2, keepdim=True) if smooth_k else None + k_mean = None + if smooth_k: + if value_dtype == "fp16" and qk_quant_gran == "per_warp" and head_dim in (64, 128): + k_mean = gfx12_native.mean_hnd(k_hnd).unsqueeze(2) + else: + k_mean = k_hnd.mean(dim=2, keepdim=True) q_hnd, k_hnd, v_hnd = _pad_gfx12_hnd_sequence( q_hnd, k_hnd, v_hnd, qo_len, kv_len, bool(is_causal), k_mean) padded_qo_len = q_hnd.size(2) @@ -700,7 +755,56 @@ def _quant_qk_hnd(q_src: torch.Tensor, k_src: torch.Tensor, km_src: Optional[tor out = out.transpose(1, 2).contiguous() return _with_lse(out) - q_int8, q_scale, k_int8, k_scale = _quant_qk_hnd(q_hnd, k_hnd, k_mean) + use_rawq_hnd_f16 = ( + value_dtype == "fp16" + and input_dtype == torch.float16 + and qk_quant_gran == "per_warp" + and head_dim in (64, 128) + and qo_len == kv_len + and is_causal + and qo_len == 512 + and q_hnd.dtype == k_hnd.dtype == v_hnd.dtype + ) + if use_rawq_hnd_f16: + k_int8 = torch.empty_like(k_hnd, dtype=torch.int8) + k_scale = torch.empty( + (k_hnd.size(0), k_hnd.size(1), (k_hnd.size(2) + 63) // 64), + device=k_hnd.device, + dtype=torch.float32, + ) + _quant_fused.quant_per_block_int8_fuse_sub_mean_cuda( + k_hnd, k_mean.squeeze(2).contiguous(), k_int8, k_scale, 64, 1 + ) + out = torch.empty_like(q_hnd, dtype=torch.float16) + gfx12_native.qk_rawq_int8_sv_f16_native_attn( + q_hnd, k_int8, v_hnd, out, k_scale, + 1, int(is_causal), float(sm_scale), kv_len, pv_accum_mode + ) + out = out[..., :qo_len, :head_dim_og] + if input_dtype != torch.float16 and out.dtype != input_dtype: + out = out.to(input_dtype) + if tensor_layout == "NHD": + out = out.transpose(1, 2).contiguous() + return _with_lse(out) + + use_smooth_hnd_f16_prep = ( + value_dtype == "fp16" + and qk_quant_gran == "per_warp" + and head_dim in (64, 128) + and not is_causal + and qo_len == kv_len + and qo_len in (512, 1024) + 
and q_hnd.dtype == k_hnd.dtype == v_hnd.dtype + ) + value_native = None + if use_smooth_hnd_f16_prep: + q_int8, q_scale, k_int8, k_scale, value_native = ( + gfx12_native.prepare_qkv_hnd_smooth_f16( + q_hnd, k_hnd, v_hnd, k_mean.squeeze(2).contiguous() + ) + ) + else: + q_int8, q_scale, k_int8, k_scale = _quant_qk_hnd(q_hnd, k_hnd, k_mean) out = torch.empty_like(q_hnd, dtype=torch.float16) if value_dtype == "fp8": value_native, value_scale = _gfx12_fp8_value_native( @@ -711,7 +815,8 @@ def _quant_qk_hnd(q_src: torch.Tensor, k_src: torch.Tensor, km_src: Optional[tor 1, int(is_causal), float(sm_scale), kv_len ) else: - value_native = gfx12_native.transpose_value_f16_hnd(v_hnd) + if value_native is None: + value_native = gfx12_native.transpose_value_f16_hnd(v_hnd) gfx12_native.qk_int8_sv_f16_d64_native_attn( q_int8, k_int8, value_native, out, q_scale, k_scale, 1, int(is_causal), float(sm_scale), kv_len, 1, @@ -793,6 +898,45 @@ def sageattn( arch = get_cuda_arch_versions()[q.device.index] if arch.startswith("gfx12"): + fast_path_keys = {"value_dtype", "smooth_k", "qk_quant_gran", "pv_accum_dtype", "smooth_v"} + value_dtype = kwargs.get("value_dtype", "auto") + value_dtype = value_dtype.lower() if isinstance(value_dtype, str) else value_dtype + gfx12_fast_common = ( + not return_lse + and tensor_layout == "NHD" + and set(kwargs).issubset(fast_path_keys) + and kwargs.get("smooth_k", True) + and kwargs.get("qk_quant_gran", "per_warp") == "per_warp" + and not kwargs.get("smooth_v", False) + and q.is_cuda + and k.is_cuda + and v.is_cuda + and q.device == k.device == v.device + and q.dtype == k.dtype == v.dtype == torch.float16 + and q.is_contiguous() + and k.is_contiguous() + and v.is_contiguous() + and q.dim() == 4 + and k.dim() == 4 + and v.dim() == 4 + and q.size(0) == k.size(0) == v.size(0) + and q.size(1) == k.size(1) == v.size(1) + and q.size(2) == k.size(2) == v.size(2) + and q.size(3) == k.size(3) == v.size(3) + and q.size(1) in (512, 1024, 2048, 4096, 8192) + and 
q.size(3) in (64, 128) + ) + if ( + gfx12_fast_common + and value_dtype in {"auto", "fp8"} + and kwargs.get("pv_accum_dtype", None) in {None, "fp32+fp16"} + ): + fast_sm_scale = float(sm_scale if sm_scale is not None else q.size(-1) ** -0.5) + out = _try_gfx12_fp8_nhd_short_mha( + q, k, v, is_causal, fast_sm_scale, _GFX12_FP8_VALUE_SCALE_MAX_FP32_FP16 + ) + if out is not None: + return out return sageattn_qk_int8_pv_gfx12_native( q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, **kwargs) diff --git a/sageattention/gfx12_native_compile.py b/sageattention/gfx12_native_compile.py index 3c717d3f..e3661c35 100644 --- a/sageattention/gfx12_native_compile.py +++ b/sageattention/gfx12_native_compile.py @@ -94,6 +94,18 @@ def qk_rawq_int8_sv_f8_scaled_native_attn_fake_impl( return output +@torch.library.register_fake("sageattention_qattn_gfx12_native::sage_fp8_nhd_short_mha") +def sage_fp8_nhd_short_mha_fake_impl( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: int, + sm_scale: float, + scale_max: float, +) -> torch.Tensor: + return torch.empty_like(query) + + @torch.library.register_fake("sageattention_qattn_gfx12_native::qk_int8_sv_f16_d64_prepare_attn_hnd") def qk_int8_sv_f16_d64_prepare_attn_hnd_fake_impl( query: torch.Tensor, @@ -156,6 +168,33 @@ def mean_nhd_fake_impl(input: torch.Tensor) -> torch.Tensor: ) +@torch.library.register_fake("sageattention_qattn_gfx12_native::mean_hnd") +def mean_hnd_fake_impl(input: torch.Tensor) -> torch.Tensor: + return torch.empty( + (input.size(0), input.size(1), input.size(3)), + dtype=input.dtype, + device=input.device, + ) + + +@torch.library.register_fake("sageattention_qattn_gfx12_native::prepare_qkv_hnd_smooth_f16") +def prepare_qkv_hnd_smooth_f16_fake_impl( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_mean: torch.Tensor, +) -> list[torch.Tensor]: + batch, q_heads, q_len, _ = query.shape + _, kv_heads, kv_len, head_dim 
= key.shape + return [ + torch.empty_like(query, dtype=torch.int8), + torch.empty((batch, q_heads, (q_len + 127) // 128 * 4), dtype=torch.float32, device=query.device), + torch.empty_like(key, dtype=torch.int8), + torch.empty((batch, kv_heads, (kv_len + 63) // 64), dtype=torch.float32, device=key.device), + torch.empty((batch, kv_heads, head_dim, kv_len), dtype=torch.float16, device=value.device), + ] + + @torch.library.register_fake("sageattention_qattn_gfx12_native::mean_and_fp8_value_nhd_short") def mean_and_fp8_value_nhd_short_fake_impl( key: torch.Tensor, From a6f959cb56025cfbce25b219505dffbcf020e88c Mon Sep 17 00:00:00 2001 From: Aaryaman Vasishta Date: Tue, 19 May 2026 23:01:17 +0900 Subject: [PATCH 7/7] gfx12: specialize D64 long causal NHD mean --- csrc/qattn/attn_gfx12_native.h | 2 ++ csrc/qattn/pybind_gfx12_native.cpp | 2 ++ csrc/qattn/qk_int_sv_gfx12_native.cu | 28 ++++++++++++++++++++++++++++ sageattention/core.py | 12 +++++++++++- 4 files changed, 43 insertions(+), 1 deletion(-) diff --git a/csrc/qattn/attn_gfx12_native.h b/csrc/qattn/attn_gfx12_native.h index f4805e2d..631e62ad 100644 --- a/csrc/qattn/attn_gfx12_native.h +++ b/csrc/qattn/attn_gfx12_native.h @@ -109,6 +109,8 @@ std::vector fp8_value_nhd_short_gfx12( Tensor mean_nhd_gfx12(Tensor input); +Tensor mean_nhd_d64_seq32_gfx12(Tensor input); + Tensor mean_hnd_gfx12(Tensor input); std::vector prepare_qkv_hnd_smooth_f16_gfx12( diff --git a/csrc/qattn/pybind_gfx12_native.cpp b/csrc/qattn/pybind_gfx12_native.cpp index 4bd44219..33762359 100644 --- a/csrc/qattn/pybind_gfx12_native.cpp +++ b/csrc/qattn/pybind_gfx12_native.cpp @@ -65,6 +65,7 @@ STABLE_TORCH_LIBRARY(sageattention_qattn_gfx12_native, m) { m.def("transpose_value_fp8_scaled_hnd(Tensor value, Tensor value_scale) -> Tensor"); m.def("fp8_value_nhd_short(Tensor value, float scale_max) -> Tensor[]"); m.def("mean_nhd(Tensor input) -> Tensor"); + m.def("mean_nhd_d64_seq32(Tensor input) -> Tensor"); m.def("mean_hnd(Tensor input) -> Tensor"); 
m.def("prepare_qkv_hnd_smooth_f16(" "Tensor query, Tensor key, Tensor value, Tensor key_mean" @@ -87,6 +88,7 @@ STABLE_TORCH_LIBRARY_IMPL(sageattention_qattn_gfx12_native, CUDA, m) { m.impl("transpose_value_fp8_scaled_hnd", TORCH_BOX(transpose_value_fp8_scaled_hnd_gfx12)); m.impl("fp8_value_nhd_short", TORCH_BOX(fp8_value_nhd_short_gfx12)); m.impl("mean_nhd", TORCH_BOX(mean_nhd_gfx12)); + m.impl("mean_nhd_d64_seq32", TORCH_BOX(mean_nhd_d64_seq32_gfx12)); m.impl("mean_hnd", TORCH_BOX(mean_hnd_gfx12)); m.impl("prepare_qkv_hnd_smooth_f16", TORCH_BOX(prepare_qkv_hnd_smooth_f16_gfx12)); m.impl("mean_and_fp8_value_nhd_short", TORCH_BOX(mean_and_fp8_value_nhd_short_gfx12)); diff --git a/csrc/qattn/qk_int_sv_gfx12_native.cu b/csrc/qattn/qk_int_sv_gfx12_native.cu index 0201a8a1..e4b1143a 100644 --- a/csrc/qattn/qk_int_sv_gfx12_native.cu +++ b/csrc/qattn/qk_int_sv_gfx12_native.cu @@ -6058,6 +6058,34 @@ Tensor mean_nhd_gfx12(Tensor input) { return mean; } +Tensor mean_nhd_d64_seq32_gfx12(Tensor input) { + STD_TORCH_CHECK(input.is_cuda(), "gfx12 D64 seq32 NHD mean expects a CUDA/HIP tensor"); + STD_TORCH_CHECK(input.dim() == 4, "gfx12 D64 seq32 NHD mean expects [B, S, H, D]"); + STD_TORCH_CHECK(input.is_contiguous(), "gfx12 D64 seq32 NHD mean expects contiguous NHD input"); + STD_TORCH_CHECK(input.scalar_type() == ScalarType::Half, + "gfx12 D64 seq32 NHD mean supports fp16 input"); + + const int64_t batch = input.size(0); + const int64_t seq_len = input.size(1); + const int64_t heads = input.size(2); + const int64_t head_dim = input.size(3); + STD_TORCH_CHECK(head_dim == 64, + "gfx12 D64 seq32 NHD mean expects head_dim 64"); + STD_TORCH_CHECK(seq_len == 2048 || seq_len == 4096 || seq_len == 8192, + "gfx12 D64 seq32 NHD mean expects sequence length 2048/4096/8192"); + + Tensor mean = new_empty_like(input, {batch, heads, head_dim}, input.scalar_type()); + dim3 block(1024); + dim3 grid((head_dim + 31) / 32, heads, batch); + const hipStream_t stream = current_hip_stream(input); + 
mean_nhd_short_kernel<__half, 32, 32><<>>( + reinterpret_cast(input.data_ptr()), + reinterpret_cast<__half*>(mean.data_ptr()), + seq_len, heads, head_dim); + hip_kernel_launch_check(); + return mean; +} + Tensor mean_hnd_gfx12(Tensor input) { STD_TORCH_CHECK(input.is_cuda(), "gfx12 HND mean expects a CUDA/HIP tensor"); STD_TORCH_CHECK(input.dim() == 4, "gfx12 HND mean expects [B, H, S, D]"); diff --git a/sageattention/core.py b/sageattention/core.py index cc47babd..8adc5b1c 100644 --- a/sageattention/core.py +++ b/sageattention/core.py @@ -492,7 +492,17 @@ def _with_lse(out: torch.Tensor): ) k_mean = k_mean_flat.unsqueeze(1) elif value_dtype == "fp16" and head_dim in (64, 128): - k_mean_flat = gfx12_native.mean_nhd(k_nhd) + use_d64_causal_seq32_mean = ( + input_dtype == torch.float16 + and is_causal + and head_dim == 64 + and qo_len == kv_len + and kv_len in (2048, 4096, 8192) + ) + if use_d64_causal_seq32_mean: + k_mean_flat = gfx12_native.mean_nhd_d64_seq32(k_nhd) + else: + k_mean_flat = gfx12_native.mean_nhd(k_nhd) k_mean = k_mean_flat.unsqueeze(1) else: k_mean = k_nhd.mean(dim=1, keepdim=True)