Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,16 @@
/test
__pycache__
*.so
*.pyd
csrc/cp_async_hip.cuh
csrc/fused/fused.hip
csrc/fused/pybind_hip.cpp
csrc/numeric_conversion_hip.cuh
csrc/qattn/qk_int_sv_gfx12_native.hip
csrc/reduction_utils_hip.cuh
evaluate/
*.pkl
*.vscode/
bench_qk_int8_pv_fp8_cuda_head96.py
*test.py
*unmask.py
*unmask.py
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ attn_output = sageattn(q, k, v, tensor_layout="HND", is_causal=False)
+ `sageattn_qk_int8_pv_fp16_cuda`: INT8 quantization for $QK^\top$ and FP16 for $PV$ using CUDA backend.
+ `sageattn_qk_int8_pv_fp8_cuda`: INT8 quantization for $QK^\top$ and FP8 for $PV$ using CUDA backend. (Note that setting `pv_accum_dtype=fp32+fp16` corresponds to SageAttention2++.)
+ `sageattn_qk_int8_pv_fp8_cuda_sm90`: INT8 quantization for $QK^\top$ and FP8 for $PV$ using CUDA backend, specifically optimized for Hopper GPUs.
+ `sageattn_qk_int8_pv_gfx12_native`: INT8 quantization for $QK^\top$ and FP8/FP16 for $PV$ using ROCm native kernels on RDNA4/gfx12 GPUs.
+ `sageattn_varlen`: INT8 quantization for $QK^\top$ and FP16 for $PV$ using Triton backend. Support for varying sequence lengths within the same batch.

For optimal speed and accuracy performance on custom devices and models, we strongly recommend referring to the [this file](./sageattention/core.py) for detailed guidance.
Expand Down
4 changes: 2 additions & 2 deletions csrc/cp_async.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ enum class PrefetchMode {
kPrefetch // Fetch additional data from global memory to L2
};

#if (__CUDACC_VER_MAJOR__ >= 11)
#if !defined(__HIP_PLATFORM_AMD__) && (__CUDACC_VER_MAJOR__ >= 11)
#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
#define CP_ASYNC_ENABLED
#endif
Expand Down Expand Up @@ -138,4 +138,4 @@ __device__ __forceinline__ void pred_load_128b(T* smem_ptr, const T* gmem_ptr, b
#endif
}

} // namespace cp_async
} // namespace cp_async
12 changes: 11 additions & 1 deletion csrc/fused/fused.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@
#include <cuda_fp16.h>
#include <cuda_bf16.h>

#if defined(__HIP_PLATFORM_AMD__)
using nv_bfloat16 = __hip_bfloat16;
using nv_bfloat162 = __hip_bfloat162;
#define __float2bfloat16_rn __float2bfloat16
#endif

enum class QuantType
{
kInt8,
Expand Down Expand Up @@ -320,7 +326,11 @@ __global__ void MeanScaleKernel(T *__restrict__ input, int8_t *__restrict__ outp
const uint32_t stride_bz_mean, const uint32_t stride_h_mean,
const uint32_t stride_bz_scale, const uint32_t stride_h_scale)
{
#if defined(__HIP_PLATFORM_AMD__)
static_assert(std::is_same<T, half>::value || std::is_same<T, nv_bfloat16>::value, "Only half and bfloat16 are supported");
#else
static_assert(std::is_same<T, half>::value || std::is_same<T, __nv_bfloat16>::value, "Only half and bfloat16 are supported");
#endif

constexpr uint32_t pack_size = 8; // float4 contains 8 half or 8 bfloat16

Expand Down Expand Up @@ -1080,4 +1090,4 @@ void mean_scale_fuse_quant_cuda(
scale.stride(0), scale.stride(1)
);
});
}
}
81 changes: 73 additions & 8 deletions csrc/numeric_conversion.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,28 @@
*/

#pragma once
#if defined(__HIP_PLATFORM_AMD__)
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_fp8.h>
#include <hip/hip_runtime.h>
#else
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cuda/pipeline>
#endif

#if !defined(__HIP_PLATFORM_AMD__)
#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120400)
#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 890))
#define FP8_CAST_ENABLED
#endif
#endif
#endif

#if defined(__CUDA_ARCH__)
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
#define RUNTIME_ASSERT(x) __brkpt()
#else
#include <assert.h>
Expand All @@ -39,13 +48,28 @@
__device__ __forceinline__ void unpack_half2_from_uint32_to_float(float* dest, uint32_t source) {
uint16_t h0 = source & 0xFFFF;
uint16_t h1 = (source >> 16) & 0xFFFF;
#if defined(__HIP_PLATFORM_AMD__)
union {
uint16_t bits;
half value;
} lo{h0}, hi{h1};
dest[0] = __half2float(lo.value);
dest[1] = __half2float(hi.value);
#else
asm("cvt.f32.f16 %0, %1;" : "=f"(dest[0]) : "h"(h0));
asm("cvt.f32.f16 %0, %1;" : "=f"(dest[1]) : "h"(h1));
#endif
}

__device__ __forceinline__ void floatx4_to_e4m3x4(uint32_t *dest, float *source0, float *source1)
{
#ifdef FP8_CAST_ENABLED
#if defined(__HIP_PLATFORM_AMD__)
const auto lo = __hip_cvt_float2_to_fp8x2(
make_float2(source0[0], source0[1]), __HIP_SATFINITE, __HIP_E4M3);
const auto hi = __hip_cvt_float2_to_fp8x2(
make_float2(source1[0], source1[1]), __HIP_SATFINITE, __HIP_E4M3);
dest[0] = static_cast<uint32_t>(lo) | (static_cast<uint32_t>(hi) << 16);
#elif defined(FP8_CAST_ENABLED)
asm volatile( \
"{\n" \
".reg .b16 lo;\n" \
Expand All @@ -62,7 +86,13 @@ __device__ __forceinline__ void floatx4_to_e4m3x4(uint32_t *dest, float *source0

__device__ __forceinline__ void floatx4_to_e5m2x4(uint32_t *dest, float *source0, float *source1)
{
#ifdef FP8_CAST_ENABLED
#if defined(__HIP_PLATFORM_AMD__)
const auto lo = __hip_cvt_float2_to_fp8x2(
make_float2(source0[0], source0[1]), __HIP_SATFINITE, __HIP_E5M2);
const auto hi = __hip_cvt_float2_to_fp8x2(
make_float2(source1[0], source1[1]), __HIP_SATFINITE, __HIP_E5M2);
dest[0] = static_cast<uint32_t>(lo) | (static_cast<uint32_t>(hi) << 16);
#elif defined(FP8_CAST_ENABLED)
asm volatile( \
"{\n" \
".reg .b16 lo;\n" \
Expand All @@ -79,7 +109,13 @@ __device__ __forceinline__ void floatx4_to_e5m2x4(uint32_t *dest, float *source0

__device__ __forceinline__ void halfx4_to_e4m3x4(uint32_t *dest, uint32_t *source0, uint32_t *source1)
{
#ifdef FP8_CAST_ENABLED
#if defined(__HIP_PLATFORM_AMD__)
float s0[2];
float s1[2];
unpack_half2_from_uint32_to_float(s0, source0[0]);
unpack_half2_from_uint32_to_float(s1, source1[0]);
floatx4_to_e4m3x4(dest, s0, s1);
#elif defined(FP8_CAST_ENABLED)
asm volatile( \
"{\n" \
".reg .b16 lo;\n" \
Expand All @@ -96,7 +132,13 @@ __device__ __forceinline__ void halfx4_to_e4m3x4(uint32_t *dest, uint32_t *sourc

__device__ __forceinline__ void halfx4_to_e5m2x4(uint32_t *dest, uint32_t *source0, uint32_t *source1)
{
#ifdef FP8_CAST_ENABLED
#if defined(__HIP_PLATFORM_AMD__)
float s0[2];
float s1[2];
unpack_half2_from_uint32_to_float(s0, source0[0]);
unpack_half2_from_uint32_to_float(s1, source1[0]);
floatx4_to_e5m2x4(dest, s0, s1);
#elif defined(FP8_CAST_ENABLED)
asm volatile( \
"{\n" \
".reg .b16 lo;\n" \
Expand All @@ -113,7 +155,16 @@ __device__ __forceinline__ void halfx4_to_e5m2x4(uint32_t *dest, uint32_t *sourc

__device__ __forceinline__ void e4m3x4_to_halfx4(uint32_t *dest0, uint32_t *dest1, uint32_t *source)
{
#ifdef FP8_CAST_ENABLED
#if defined(__HIP_PLATFORM_AMD__)
const auto lo = __hip_cvt_fp8x2_to_halfraw2(
static_cast<__hip_fp8x2_storage_t>(source[0] & 0xFFFF), __HIP_E4M3);
const auto hi = __hip_cvt_fp8x2_to_halfraw2(
static_cast<__hip_fp8x2_storage_t>(source[0] >> 16), __HIP_E4M3);
dest0[0] = static_cast<uint32_t>(lo.x.x) |
(static_cast<uint32_t>(lo.y.x) << 16);
dest1[0] = static_cast<uint32_t>(hi.x.x) |
(static_cast<uint32_t>(hi.y.x) << 16);
#elif defined(FP8_CAST_ENABLED)
asm volatile( \
"{\n" \
".reg .b16 lo, hi;\n" \
Expand All @@ -128,7 +179,16 @@ __device__ __forceinline__ void e4m3x4_to_halfx4(uint32_t *dest0, uint32_t *dest

__device__ __forceinline__ void e5m2x4_to_halfx4(uint32_t *dest0, uint32_t *dest1, uint32_t *source)
{
#ifdef FP8_CAST_ENABLED
#if defined(__HIP_PLATFORM_AMD__)
const auto lo = __hip_cvt_fp8x2_to_halfraw2(
static_cast<__hip_fp8x2_storage_t>(source[0] & 0xFFFF), __HIP_E5M2);
const auto hi = __hip_cvt_fp8x2_to_halfraw2(
static_cast<__hip_fp8x2_storage_t>(source[0] >> 16), __HIP_E5M2);
dest0[0] = static_cast<uint32_t>(lo.x.x) |
(static_cast<uint32_t>(lo.y.x) << 16);
dest1[0] = static_cast<uint32_t>(hi.x.x) |
(static_cast<uint32_t>(hi.y.x) << 16);
#elif defined(FP8_CAST_ENABLED)
asm volatile( \
"{\n" \
".reg .b16 lo, hi;\n" \
Expand All @@ -143,7 +203,12 @@ __device__ __forceinline__ void e5m2x4_to_halfx4(uint32_t *dest0, uint32_t *dest

__device__ __forceinline__ int8_t float_to_int8_rn(float x)
{
#if defined(__HIP_PLATFORM_AMD__)
const float clipped = fminf(127.0f, fmaxf(-128.0f, nearbyintf(x)));
return static_cast<int8_t>(clipped);
#else
uint32_t dst;
asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
return reinterpret_cast<const int8_t&>(dst);
}
#endif
}
86 changes: 86 additions & 0 deletions csrc/qattn/attn_gfx12_native.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright (c) 2024 by SageAttention team.
*
* Licensed under the Apache License, Version 2.0.
*/

#pragma once

#include <vector>

#include <torch/extension.h>

torch::Tensor qk_int8_sv_f16_d64_native_attn_gfx12(
torch::Tensor query,
torch::Tensor key,
torch::Tensor value,
torch::Tensor output,
torch::Tensor query_scale,
torch::Tensor key_scale,
int tensor_layout,
int is_causal,
float sm_scale,
int64_t valid_kv_len = 0,
int value_transposed_hnd = -1,
int pv_accum_mode = -1);

torch::Tensor qk_rawq_int8_sv_f8_native_attn_gfx12(
torch::Tensor query,
torch::Tensor key,
torch::Tensor value,
torch::Tensor output,
torch::Tensor key_scale,
int tensor_layout,
int is_causal,
float sm_scale,
int64_t valid_kv_len = 0,
int value_transposed_hnd = -1);

torch::Tensor qk_int8_sv_f8_scaled_native_attn_gfx12(
torch::Tensor query,
torch::Tensor key,
torch::Tensor value,
torch::Tensor output,
torch::Tensor query_scale,
torch::Tensor key_scale,
torch::Tensor value_scale,
int tensor_layout,
int is_causal,
float sm_scale,
int64_t valid_kv_len = 0);

torch::Tensor qk_rawq_int8_sv_f8_scaled_native_attn_gfx12(
torch::Tensor query,
torch::Tensor key,
torch::Tensor value,
torch::Tensor output,
torch::Tensor key_scale,
torch::Tensor value_scale,
int tensor_layout,
int is_causal,
float sm_scale,
int64_t valid_kv_len = 0,
int value_transposed_hnd = -1);

torch::Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12(
torch::Tensor query,
torch::Tensor key,
torch::Tensor value,
int is_causal,
int value_is_fp8,
int use_raw_f16_value,
float sm_scale,
int64_t valid_kv_len = 0,
int pv_accum_mode = -1);

std::vector<torch::Tensor> quant_q_nhd_per_warp_gfx12(torch::Tensor query);

torch::Tensor transpose_value_fp8_hnd_gfx12(torch::Tensor value);

torch::Tensor transpose_value_fp8_scaled_hnd_gfx12(
torch::Tensor value,
torch::Tensor value_scale);

torch::Tensor transpose_value_f16_hnd_gfx12(torch::Tensor value);

torch::Tensor convert_f16_to_bf16_gfx12(torch::Tensor input);
45 changes: 45 additions & 0 deletions csrc/qattn/pybind_gfx12_native.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Copyright (c) 2024 by SageAttention team.
*
* Licensed under the Apache License, Version 2.0.
*/

#include <pybind11/pybind11.h>
#include <torch/extension.h>

#include "attn_gfx12_native.h"

namespace py = pybind11;

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("qk_int8_sv_f16_d64_native_attn", &qk_int8_sv_f16_d64_native_attn_gfx12,
py::arg("query"), py::arg("key"), py::arg("value"), py::arg("output"),
py::arg("query_scale"), py::arg("key_scale"), py::arg("tensor_layout"),
py::arg("is_causal"), py::arg("sm_scale"), py::arg("valid_kv_len") = 0,
py::arg("value_transposed_hnd") = -1, py::arg("pv_accum_mode") = -1);
m.def("qk_rawq_int8_sv_f8_native_attn", &qk_rawq_int8_sv_f8_native_attn_gfx12,
py::arg("query"), py::arg("key"), py::arg("value"), py::arg("output"),
py::arg("key_scale"), py::arg("tensor_layout"),
py::arg("is_causal"), py::arg("sm_scale"), py::arg("valid_kv_len") = 0,
py::arg("value_transposed_hnd") = -1);
m.def("qk_int8_sv_f8_scaled_native_attn", &qk_int8_sv_f8_scaled_native_attn_gfx12,
py::arg("query"), py::arg("key"), py::arg("value"), py::arg("output"),
py::arg("query_scale"), py::arg("key_scale"), py::arg("value_scale"),
py::arg("tensor_layout"), py::arg("is_causal"), py::arg("sm_scale"),
py::arg("valid_kv_len") = 0);
m.def("qk_rawq_int8_sv_f8_scaled_native_attn", &qk_rawq_int8_sv_f8_scaled_native_attn_gfx12,
py::arg("query"), py::arg("key"), py::arg("value"), py::arg("output"),
py::arg("key_scale"), py::arg("value_scale"), py::arg("tensor_layout"),
py::arg("is_causal"), py::arg("sm_scale"), py::arg("valid_kv_len") = 0,
py::arg("value_transposed_hnd") = -1);
m.def("qk_int8_sv_f16_d64_prepare_attn_hnd", &qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12,
py::arg("query"), py::arg("key"), py::arg("value"), py::arg("is_causal"),
py::arg("value_is_fp8"), py::arg("use_raw_f16_value"), py::arg("sm_scale"),
py::arg("valid_kv_len") = 0, py::arg("pv_accum_mode") = -1);
m.def("quant_q_nhd_per_warp", &quant_q_nhd_per_warp_gfx12);
m.def("transpose_value_fp8_hnd", &transpose_value_fp8_hnd_gfx12);
m.def("transpose_value_fp8_scaled_hnd", &transpose_value_fp8_scaled_hnd_gfx12);
m.def("transpose_value_f16_hnd", &transpose_value_f16_hnd_gfx12);
m.def("convert_f16_to_bf16", &convert_f16_to_bf16_gfx12);
}
Loading