Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ __pycache__/

# C extensions
*.so
*.pyd

# Distribution / packaging
.Python
Expand Down Expand Up @@ -174,3 +175,16 @@ cython_debug/
.pypirc

/wheelhouse
csrc/cp_async_hip.cuh
csrc/fused/fused.hip
csrc/fused/pybind_hip.cpp
csrc/numeric_conversion_hip.cuh
csrc/qattn/qk_int_sv_gfx12_native.hip
csrc/qattn/qk_int_sv_gfx12_native_*.hip
csrc/reduction_utils_hip.cuh
evaluate/
*.pkl
*.vscode/
bench_qk_int8_pv_fp8_cuda_head96.py
*test.py
*unmask.py
4 changes: 2 additions & 2 deletions csrc/cp_async.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ enum class PrefetchMode {
kPrefetch // Fetch additional data from global memory to L2
};

#if (__CUDACC_VER_MAJOR__ >= 11)
#if !defined(__HIP_PLATFORM_AMD__) && (__CUDACC_VER_MAJOR__ >= 11)
#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
#define CP_ASYNC_ENABLED
#endif
Expand Down Expand Up @@ -138,4 +138,4 @@ __device__ __forceinline__ void pred_load_128b(T* smem_ptr, const T* gmem_ptr, b
#endif
}

} // namespace cp_async
} // namespace cp_async
9 changes: 9 additions & 0 deletions csrc/fused/fused.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@
#include <cuda_fp16.h>
#include <cuda_bf16.h>

// HIP build only: map the CUDA bfloat16 spellings (nv_bfloat16 / nv_bfloat162)
// onto the HIP bfloat16 types so the code below can keep using the CUDA names.
#if defined(__HIP_PLATFORM_AMD__)
using nv_bfloat16 = __hip_bfloat16;
using nv_bfloat162 = __hip_bfloat162;
#endif

using torch::stable::Tensor;

enum class QuantType
Expand Down Expand Up @@ -329,7 +334,11 @@ __global__ void MeanScaleKernel(T *__restrict__ input, int8_t *__restrict__ outp
const uint32_t stride_bz_mean, const uint32_t stride_h_mean,
const uint32_t stride_bz_scale, const uint32_t stride_h_scale)
{
#if defined(__HIP_PLATFORM_AMD__)
static_assert(std::is_same<T, half>::value || std::is_same<T, nv_bfloat16>::value, "Only half and bfloat16 are supported");
#else
static_assert(std::is_same<T, half>::value || std::is_same<T, __nv_bfloat16>::value, "Only half and bfloat16 are supported");
#endif

constexpr uint32_t pack_size = 8; // float4 contains 8 half or 8 bfloat16

Expand Down
81 changes: 73 additions & 8 deletions csrc/numeric_conversion.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,28 @@
*/

#pragma once
#if defined(__HIP_PLATFORM_AMD__)
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_fp8.h>
#include <hip/hip_runtime.h>
#else
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cuda/pipeline>
#endif

#if !defined(__HIP_PLATFORM_AMD__)
#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120400)
#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 890))
#define FP8_CAST_ENABLED
#endif
#endif
#endif

#if defined(__CUDA_ARCH__)
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
#define RUNTIME_ASSERT(x) __brkpt()
#else
#include <assert.h>
Expand All @@ -39,13 +48,28 @@
/**
 * Split a 32-bit word into two 16-bit fp16 bit patterns and widen each to float.
 *
 * @param dest   Output array of at least two floats. dest[0] receives the value
 *               encoded in bits [15:0] of `source`, dest[1] the value encoded in
 *               bits [31:16].
 * @param source Two fp16 bit patterns packed into one 32-bit word.
 */
__device__ __forceinline__ void unpack_half2_from_uint32_to_float(float* dest, uint32_t source) {
  const uint16_t h0 = static_cast<uint16_t>(source & 0xFFFF);
  const uint16_t h1 = static_cast<uint16_t>((source >> 16) & 0xFFFF);
#if defined(__HIP_PLATFORM_AMD__)
  // No PTX on AMD. Reinterpret the raw bits with the fp16 header's bit-cast
  // intrinsic instead of reading an inactive union member (undefined behaviour
  // in ISO C++), then widen with __half2float.
  dest[0] = __half2float(__ushort_as_half(h0));
  dest[1] = __half2float(__ushort_as_half(h1));
#else
  // PTX: convert each 16-bit register holding an f16 value to f32.
  asm("cvt.f32.f16 %0, %1;" : "=f"(dest[0]) : "h"(h0));
  asm("cvt.f32.f16 %0, %1;" : "=f"(dest[1]) : "h"(h1));
#endif
}

__device__ __forceinline__ void floatx4_to_e4m3x4(uint32_t *dest, float *source0, float *source1)
{
#ifdef FP8_CAST_ENABLED
#if defined(__HIP_PLATFORM_AMD__)
const auto lo = __hip_cvt_float2_to_fp8x2(
make_float2(source0[0], source0[1]), __HIP_SATFINITE, __HIP_E4M3);
const auto hi = __hip_cvt_float2_to_fp8x2(
make_float2(source1[0], source1[1]), __HIP_SATFINITE, __HIP_E4M3);
dest[0] = static_cast<uint32_t>(lo) | (static_cast<uint32_t>(hi) << 16);
#elif defined(FP8_CAST_ENABLED)
asm volatile( \
"{\n" \
".reg .b16 lo;\n" \
Expand All @@ -62,7 +86,13 @@ __device__ __forceinline__ void floatx4_to_e4m3x4(uint32_t *dest, float *source0

__device__ __forceinline__ void floatx4_to_e5m2x4(uint32_t *dest, float *source0, float *source1)
{
#ifdef FP8_CAST_ENABLED
#if defined(__HIP_PLATFORM_AMD__)
const auto lo = __hip_cvt_float2_to_fp8x2(
make_float2(source0[0], source0[1]), __HIP_SATFINITE, __HIP_E5M2);
const auto hi = __hip_cvt_float2_to_fp8x2(
make_float2(source1[0], source1[1]), __HIP_SATFINITE, __HIP_E5M2);
dest[0] = static_cast<uint32_t>(lo) | (static_cast<uint32_t>(hi) << 16);
#elif defined(FP8_CAST_ENABLED)
asm volatile( \
"{\n" \
".reg .b16 lo;\n" \
Expand All @@ -79,7 +109,13 @@ __device__ __forceinline__ void floatx4_to_e5m2x4(uint32_t *dest, float *source0

__device__ __forceinline__ void halfx4_to_e4m3x4(uint32_t *dest, uint32_t *source0, uint32_t *source1)
{
#ifdef FP8_CAST_ENABLED
#if defined(__HIP_PLATFORM_AMD__)
float s0[2];
float s1[2];
unpack_half2_from_uint32_to_float(s0, source0[0]);
unpack_half2_from_uint32_to_float(s1, source1[0]);
floatx4_to_e4m3x4(dest, s0, s1);
#elif defined(FP8_CAST_ENABLED)
asm volatile( \
"{\n" \
".reg .b16 lo;\n" \
Expand All @@ -96,7 +132,13 @@ __device__ __forceinline__ void halfx4_to_e4m3x4(uint32_t *dest, uint32_t *sourc

__device__ __forceinline__ void halfx4_to_e5m2x4(uint32_t *dest, uint32_t *source0, uint32_t *source1)
{
#ifdef FP8_CAST_ENABLED
#if defined(__HIP_PLATFORM_AMD__)
float s0[2];
float s1[2];
unpack_half2_from_uint32_to_float(s0, source0[0]);
unpack_half2_from_uint32_to_float(s1, source1[0]);
floatx4_to_e5m2x4(dest, s0, s1);
#elif defined(FP8_CAST_ENABLED)
asm volatile( \
"{\n" \
".reg .b16 lo;\n" \
Expand All @@ -113,7 +155,16 @@ __device__ __forceinline__ void halfx4_to_e5m2x4(uint32_t *dest, uint32_t *sourc

__device__ __forceinline__ void e4m3x4_to_halfx4(uint32_t *dest0, uint32_t *dest1, uint32_t *source)
{
#ifdef FP8_CAST_ENABLED
#if defined(__HIP_PLATFORM_AMD__)
const auto lo = __hip_cvt_fp8x2_to_halfraw2(
static_cast<__hip_fp8x2_storage_t>(source[0] & 0xFFFF), __HIP_E4M3);
const auto hi = __hip_cvt_fp8x2_to_halfraw2(
static_cast<__hip_fp8x2_storage_t>(source[0] >> 16), __HIP_E4M3);
dest0[0] = static_cast<uint32_t>(lo.x.x) |
(static_cast<uint32_t>(lo.y.x) << 16);
dest1[0] = static_cast<uint32_t>(hi.x.x) |
(static_cast<uint32_t>(hi.y.x) << 16);
#elif defined(FP8_CAST_ENABLED)
asm volatile( \
"{\n" \
".reg .b16 lo, hi;\n" \
Expand All @@ -128,7 +179,16 @@ __device__ __forceinline__ void e4m3x4_to_halfx4(uint32_t *dest0, uint32_t *dest

__device__ __forceinline__ void e5m2x4_to_halfx4(uint32_t *dest0, uint32_t *dest1, uint32_t *source)
{
#ifdef FP8_CAST_ENABLED
#if defined(__HIP_PLATFORM_AMD__)
const auto lo = __hip_cvt_fp8x2_to_halfraw2(
static_cast<__hip_fp8x2_storage_t>(source[0] & 0xFFFF), __HIP_E5M2);
const auto hi = __hip_cvt_fp8x2_to_halfraw2(
static_cast<__hip_fp8x2_storage_t>(source[0] >> 16), __HIP_E5M2);
dest0[0] = static_cast<uint32_t>(lo.x.x) |
(static_cast<uint32_t>(lo.y.x) << 16);
dest1[0] = static_cast<uint32_t>(hi.x.x) |
(static_cast<uint32_t>(hi.y.x) << 16);
#elif defined(FP8_CAST_ENABLED)
asm volatile( \
"{\n" \
".reg .b16 lo, hi;\n" \
Expand All @@ -143,7 +203,12 @@ __device__ __forceinline__ void e5m2x4_to_halfx4(uint32_t *dest0, uint32_t *dest

/**
 * Convert a float to int8, rounding to the nearest integer and saturating to
 * the int8 range [-128, 127].
 *
 * @param x Value to convert.
 * @return  The rounded, saturated value as int8_t.
 */
__device__ __forceinline__ int8_t float_to_int8_rn(float x)
{
#if defined(__HIP_PLATFORM_AMD__)
  // No PTX on AMD: round with nearbyintf (current rounding mode), then clamp
  // to the int8 range before narrowing.
  // NOTE(review): fmaxf/fminf return the non-NaN operand, so a NaN input ends
  // up as -128 here; confirm this matches what the PTX cvt.sat path returns
  // for NaN if NaN inputs are possible.
  const float clipped = fminf(127.0f, fmaxf(-128.0f, nearbyintf(x)));
  return static_cast<int8_t>(clipped);
#else
  // PTX: round-to-nearest-integer (.rni) with saturation (.sat) to s8. The
  // result lands in a 32-bit register, so read it back through an int8 view.
  uint32_t dst;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
  return reinterpret_cast<const int8_t&>(dst);
#endif
}
129 changes: 129 additions & 0 deletions csrc/qattn/attn_gfx12_native.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Copyright (c) 2024 by SageAttention team.
*
* Licensed under the Apache License, Version 2.0.
*/

#pragma once

#include <torch/csrc/stable/tensor.h>

#include <vector>

using torch::stable::Tensor;

// Declaration only; the definition is not in this header.
// Takes query/key/value/output tensors plus query_scale and key_scale tensors,
// a tensor_layout selector, an is_causal flag and a softmax scale (sm_scale).
// Defaults: valid_kv_len = 0, value_transposed_hnd = -1, pv_accum_mode = -1.
// NOTE(review): the name suggests INT8 QK with FP16 SV at head dim 64 on gfx12,
// and the 0 / -1 defaults look like "unset/auto" sentinels — confirm against
// the implementation.
Tensor qk_int8_sv_f16_d64_native_attn_gfx12(
Tensor query,
Tensor key,
Tensor value,
Tensor output,
Tensor query_scale,
Tensor key_scale,
int64_t tensor_layout,
int64_t is_causal,
double sm_scale,
int64_t valid_kv_len = 0,
int64_t value_transposed_hnd = -1,
int64_t pv_accum_mode = -1);

// Declaration only; the definition is not in this header.
// Unlike qk_int8_sv_f16_d64_native_attn_gfx12 there is no query_scale
// parameter; only key_scale is passed.
// Defaults: valid_kv_len = 0, value_transposed_hnd = -1, key_hnd_layout = 0.
// NOTE(review): "rawq" presumably means the query is supplied unquantized and
// "sv_f8" an FP8 value path — confirm against the implementation.
Tensor qk_rawq_int8_sv_f8_native_attn_gfx12(
Tensor query,
Tensor key,
Tensor value,
Tensor output,
Tensor key_scale,
int64_t tensor_layout,
int64_t is_causal,
double sm_scale,
int64_t valid_kv_len = 0,
int64_t value_transposed_hnd = -1,
int64_t key_hnd_layout = 0);

// Declaration only; the definition is not in this header.
// Takes query/key/value/output tensors and a key_scale tensor (no query_scale).
// Defaults: valid_kv_len = 0, pv_accum_mode = -1.
// NOTE(review): presumably the FP16-value counterpart of the "rawq" FP8
// variant — confirm against the implementation.
Tensor qk_rawq_int8_sv_f16_native_attn_gfx12(
Tensor query,
Tensor key,
Tensor value,
Tensor output,
Tensor key_scale,
int64_t tensor_layout,
int64_t is_causal,
double sm_scale,
int64_t valid_kv_len = 0,
int64_t pv_accum_mode = -1);

// Declaration only; the definition is not in this header.
// Takes query/key/value/output tensors and three scale tensors
// (query_scale, key_scale, value_scale). Default: valid_kv_len = 0.
// NOTE(review): "scaled" presumably refers to the extra value_scale input —
// confirm against the implementation.
Tensor qk_int8_sv_f8_scaled_native_attn_gfx12(
Tensor query,
Tensor key,
Tensor value,
Tensor output,
Tensor query_scale,
Tensor key_scale,
Tensor value_scale,
int64_t tensor_layout,
int64_t is_causal,
double sm_scale,
int64_t valid_kv_len = 0);

// Declaration only; the definition is not in this header.
// Takes query/key/value/output tensors plus key_scale and value_scale tensors
// (no query_scale).
// Defaults: valid_kv_len = 0, value_transposed_hnd = -1, key_hnd_layout = 0.
// NOTE(review): meanings of the layout/transposed selectors are not visible
// here — TODO document once confirmed.
Tensor qk_rawq_int8_sv_f8_scaled_native_attn_gfx12(
Tensor query,
Tensor key,
Tensor value,
Tensor output,
Tensor key_scale,
Tensor value_scale,
int64_t tensor_layout,
int64_t is_causal,
double sm_scale,
int64_t valid_kv_len = 0,
int64_t value_transposed_hnd = -1,
int64_t key_hnd_layout = 0);

// Declaration only; the definition is not in this header.
// Takes query/key/value tensors (no output or scale tensors), an is_causal
// flag, a softmax scale and a scale_max value; returns a Tensor. No parameter
// has a default.
// NOTE(review): presumably an all-in-one path for short sequences in NHD
// layout — confirm against the implementation.
Tensor sage_fp8_nhd_short_mha_gfx12(
Tensor query,
Tensor key,
Tensor value,
int64_t is_causal,
double sm_scale,
double scale_max);

// Declaration only; the definition is not in this header.
// Takes query/key/value tensors, three integer flags (is_causal, value_is_fp8,
// use_raw_f16_value) and a softmax scale.
// Defaults: valid_kv_len = 0, pv_accum_mode = -1.
// NOTE(review): the flags are int64_t rather than bool; whether any non-zero
// value counts as "true" is not visible here — confirm.
Tensor qk_int8_sv_f16_d64_prepare_attn_hnd_gfx12(
Tensor query,
Tensor key,
Tensor value,
int64_t is_causal,
int64_t value_is_fp8,
int64_t use_raw_f16_value,
double sm_scale,
int64_t valid_kv_len = 0,
int64_t pv_accum_mode = -1);

// Declaration only. Takes a query tensor and returns several tensors.
// NOTE(review): presumably the quantized query and its scales; the number and
// order of returned tensors are not visible here — TODO document.
std::vector<Tensor> quant_q_nhd_per_warp_gfx12(Tensor query);

// Declaration only. Takes a value tensor and returns a single tensor.
// NOTE(review): name suggests a transposed FP8 copy of `value` in HND layout —
// confirm against the implementation.
Tensor transpose_value_fp8_hnd_gfx12(Tensor value);

// Declaration only. Same shape of call as transpose_value_fp8_hnd_gfx12 but
// with an additional value_scale tensor input.
// NOTE(review): how value_scale is applied is not visible here — confirm.
Tensor transpose_value_fp8_scaled_hnd_gfx12(Tensor value, Tensor value_scale);

// Declaration only. Takes a value tensor and a scale_max value and returns
// several tensors.
// NOTE(review): presumably the FP8 value plus its scale(s); the contents and
// order of the returned vector are not visible here — TODO document.
std::vector<Tensor> fp8_value_nhd_short_gfx12(
Tensor value,
double scale_max);

// Declaration only. Takes one tensor and returns one tensor.
// NOTE(review): name suggests a mean over an NHD-layout input; the reduced
// axis is not visible here — confirm.
Tensor mean_nhd_gfx12(Tensor input);

// Declaration only. Takes one tensor and returns one tensor.
// NOTE(review): presumably a specialization of mean_nhd_gfx12 for head dim 64
// and sequence length 32; any shape preconditions are not visible here —
// confirm and document.
Tensor mean_nhd_d64_seq32_gfx12(Tensor input);

// Declaration only. Takes one tensor and returns one tensor.
// NOTE(review): name suggests the HND-layout counterpart of mean_nhd_gfx12 —
// confirm against the implementation.
Tensor mean_hnd_gfx12(Tensor input);

// Declaration only. Takes query/key/value tensors plus a key_mean tensor and
// returns several tensors.
// NOTE(review): "smooth" presumably refers to using key_mean on the key; the
// contents and order of the returned vector are not visible here — TODO
// document.
std::vector<Tensor> prepare_qkv_hnd_smooth_f16_gfx12(
Tensor query,
Tensor key,
Tensor value,
Tensor key_mean);

// Declaration only. Takes key and value tensors and a scale_max value and
// returns several tensors.
// NOTE(review): name suggests it combines a key mean with the FP8 value
// conversion (cf. fp8_value_nhd_short_gfx12); the order of the returned
// tensors is not visible here — TODO document.
std::vector<Tensor> mean_and_fp8_value_nhd_short_gfx12(
Tensor key,
Tensor value,
double scale_max);

// Declaration only. Takes a value tensor and returns a single tensor.
// NOTE(review): name suggests the FP16 counterpart of
// transpose_value_fp8_hnd_gfx12 — confirm against the implementation.
Tensor transpose_value_f16_hnd_gfx12(Tensor value);

// Declaration only. Takes one tensor and returns one tensor.
// NOTE(review): name suggests an FP16 -> BF16 dtype conversion; whether the
// result is a new tensor or aliases the input is not visible here — confirm.
Tensor convert_f16_to_bf16_gfx12(Tensor input);
Loading