diff --git a/csrc/cp_async_hip.cuh b/csrc/cp_async_hip.cuh new file mode 100644 index 00000000..f16b3de9 --- /dev/null +++ b/csrc/cp_async_hip.cuh @@ -0,0 +1,142 @@ +// !!! This is a file automatically generated by hipify!!! +/* + * Copyright (c) 2024 by SageAttention team. + * + * This file is based on code from Flashinfer, https://github.com/flashinfer-ai/flashinfer/blob/v0.1.5/include/flashinfer/cp_async.cuh + * Copyright (c) 2023 by FlashInfer team. + * Small modifications made by SageAttention team, 2024 (e.g., renamed namespace). + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace cp_async { + +enum class SharedMemFillMode { + kFillZero, // Fill zero to shared memory when predicate is false + kNoFill // Do not fill zero to shared memory when predicate is false +}; + +enum class PrefetchMode { + kNoPrefetch, // Do not fetch additional data from global memory to L2 + kPrefetch // Fetch additional data from global memory to L2 +}; + +#if (__CUDACC_VER_MAJOR__ >= 11) +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) +#define CP_ASYNC_ENABLED +#endif +#endif + +/*! + * \brief Wrapper of PTX cp.async.commit_group instruction, commit all prior uncommitted + * cp.async instructions to a group + */ +__device__ __forceinline__ void commit_group() { +#ifdef CP_ASYNC_ENABLED + asm volatile("cp.async.commit_group;\n" ::); +#endif +} + +/*! + * \brief Wrapper of PTX cp.async.wait_group instruction + * \tparam n Wait till most recent n groups are committed + */ +template +__device__ __forceinline__ void wait_group() { +#ifdef CP_ASYNC_ENABLED + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +#endif +} + +/*! + * \brief Wrapper of PTX cp.async.cg.shared.global instruction, asynchronously copy data from + * global memory to shared memory + * \tparam prefetch_mode Whether to fetch additional data from global memory to L2 + * \tparam T Data type + * \param smem_ptr Pointer to shared memory + * \param gmem_ptr Pointer to global memory + */ +template +__device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) { +#ifdef CP_ASYNC_ENABLED + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "n"(16), "r"(16)); + } else { + asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "n"(16), "r"(16)); + } +#else + *((uint4*)smem_ptr) = *((uint4*)gmem_ptr); +#endif +} + +/*! + * \brief Wrapper of PTX cp.async.cg.shared.global instruction, asynchronously copy data from + * global memory to shared memory with predicate. + * \tparam prefetch_mode Whether to fetch additional data from global memory to L2 + * \tparam fill_mode Whether to fill zero to shared memory when predicate is false + * \tparam T Data type + * \param smem_ptr Pointer to shared memory + * \param gmem_ptr Pointer to global memory + * \param predicate Predicate value + * \note fill zero is slower than not fill zero + */ +template +__device__ __forceinline__ void pred_load_128b(T* smem_ptr, const T* gmem_ptr, bool predicate) { +#ifdef CP_ASYNC_ENABLED + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + int src_in_bytes = predicate ? 16 : 0; + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "n"(16), "r"(src_in_bytes)); + } else { + asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" ::"r"(smem_int_ptr), + "l"(gmem_ptr), "n"(16), "r"(src_in_bytes)); + } + } else { + if constexpr (prefetch_mode == PrefetchMode::kPrefetch) { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), "l"(gmem_ptr), "n"(16)); + } else { + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)predicate), + "r"(smem_int_ptr), "l"(gmem_ptr), "n"(16)); + } + } +#else + if (predicate) { + *((uint4*)smem_ptr) = *((uint4*)gmem_ptr); + } else { + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + *((uint4*)smem_ptr) = make_uint4(0, 0, 0, 0); + } + } +#endif +} + +} // namespace cp_async \ No newline at end of file diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 6d798e68..06fd8e55 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -72,6 +72,20 @@ throw std::invalid_argument(err_msg.str()); \ } +#if defined(USE_ROCM) +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \ + if (pytorch_dtype == at::ScalarType::Half) { \ + using c_type = half; \ + __VA_ARGS__ \ + } else if (pytorch_dtype == at::ScalarType::BFloat16) { \ + using c_type = hip_bfloat16; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + } +#else #define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \ if (pytorch_dtype == at::ScalarType::Half) { \ using c_type = half; \ @@ -84,6 +98,7 @@ oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ TORCH_CHECK(false, oss.str()); \ } +#endif #define DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, ...) \ if (block_size == 64) { \ diff --git a/csrc/fused/fused.cu b/csrc/fused/fused.cu index fb8b9f15..ff227e01 100644 --- a/csrc/fused/fused.cu +++ b/csrc/fused/fused.cu @@ -20,17 +20,32 @@ #include "../dispatch_utils.h" #include "../utils.cuh" #include "../reduction_utils.cuh" + +#if !defined(USE_ROCM) #include "../numeric_conversion.cuh" #include "../cp_async.cuh" #include #include +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + enum class QuantType { kInt8, kInt4, }; +#if !defined(USE_ROCM) template __device__ __forceinline__ float convert_to_float(T val) { @@ -61,6 +76,94 @@ __device__ __forceinline__ T convert_from_float(float val) } } +#else +__device__ __forceinline__ float u32_as_f32(uint32_t u) { + union { uint32_t u; float f; } v{u}; return v.f; +} + +__device__ __forceinline__ uint16_t bf16_bits(__hip_bfloat16 x) { + return *reinterpret_cast(&x); +} + +// ========== to-float ========== +template +__device__ __forceinline__ float convert_to_float(T val); + +// __half → float +template <> +__device__ __forceinline__ float convert_to_float<__half>(__half v) { + return __half2float(v); +} + +// __hip_bfloat16 → float +template <> +__device__ __forceinline__ float convert_to_float<__hip_bfloat16>(__hip_bfloat16 v) { + uint16_t hi = bf16_bits(v); + return u32_as_f32(uint32_t(hi) << 16); +} + +// hip_bfloat16 → float +template <> +__device__ __forceinline__ float convert_to_float(hip_bfloat16 v) { + return convert_to_float(*reinterpret_cast(&v)); +} + +template +__device__ __forceinline__ T convert_from_float(float val) { + static_assert(std::is_same::value || + std::is_same::value, + "Only __half and __hip_bfloat16 are supported (ROCm)."); + + if constexpr (std::is_same::value) { + // f32 -> fp16 (round-to-nearest) + return __float2half_rn(val); + } else { + // f32 -> bf16 (round-to-nearest) + return __float2bfloat16(val); + } +} + +namespace detail { + + struct vec16_t { float x, y, z, w; }; + + template + __device__ __forceinline__ void predicated_g2s_16B(T* smem_dst, const T* gmem_src, bool pred) { + if (pred) { + *reinterpret_cast(smem_dst) = *reinterpret_cast(gmem_src); + } else if constexpr (PadZero) { + *reinterpret_cast(smem_dst) = vec16_t{0.f, 0.f, 0.f, 0.f}; + } + } + + __device__ __forceinline__ void store_8fp8(const uint32_t* __restrict__ fp8x4, + int8_t* __restrict__ out) { + *reinterpret_cast(out) = *reinterpret_cast(fp8x4); + } + + __device__ __forceinline__ void floatx4_to_e4m3x4(uint32_t* dest, float* s0, float* s1) { + + #ifdef __ROCM_ARCH_GFX942 + uint8_t b0 = __hip_cvt_float_to_fp8(s0[0], __HIP_SATFINITE, __HIP_E4M3_FNUZ); + uint8_t b1 = __hip_cvt_float_to_fp8(s0[1], __HIP_SATFINITE, __HIP_E4M3_FNUZ); + uint8_t b2 = __hip_cvt_float_to_fp8(s1[0], __HIP_SATFINITE, __HIP_E4M3_FNUZ); + uint8_t b3 = __hip_cvt_float_to_fp8(s1[1], __HIP_SATFINITE, __HIP_E4M3_FNUZ); + #else + uint8_t b0 = __hip_cvt_float_to_fp8(s0[0], __HIP_SATFINITE, __HIP_E4M3); + uint8_t b1 = __hip_cvt_float_to_fp8(s0[1], __HIP_SATFINITE, __HIP_E4M3); + uint8_t b2 = __hip_cvt_float_to_fp8(s1[0], __HIP_SATFINITE, __HIP_E4M3); + uint8_t b3 = __hip_cvt_float_to_fp8(s1[1], __HIP_SATFINITE, __HIP_E4M3); + #endif + + + *dest = (uint32_t)b0 | ((uint32_t)b1 << 8) | ((uint32_t)b2 << 16) | ((uint32_t)b3 << 24); + } + +} // namespace detail + +#endif + +#if !defined(USE_ROCM) template __global__ void QuantInt8Kernel(T *__restrict__ input, T *__restrict__ mean, int8_t *__restrict__ output, float *__restrict__ scale, float sm_scale, const uint32_t num_tokens, const uint32_t stride_bz_input, const uint32_t stride_seq_input, const uint32_t stride_h_input, @@ -258,6 +361,7 @@ __global__ void SubMeanKernel(T *__restrict__ input, T *__restrict__ mean, half } } } +#endif template __global__ void TransposePadPermuteKernel(T *__restrict__ input, T *__restrict__ output, const uint32_t num_tokens, @@ -265,7 +369,11 @@ __global__ void TransposePadPermuteKernel(T *__restrict__ input, T *__restrict__ const uint32_t stride_bz_output, const uint32_t stride_d_output, const uint32_t stride_h_output) { - static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); +// #if !defined(USE_ROCM) +// static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); +// #else +// static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); +// #endif constexpr uint32_t pack_size = 8; // float4 contains 8 half or 8 bfloat16 uint32_t num_threads_per_token = head_dim / pack_size; @@ -284,6 +392,7 @@ __global__ void TransposePadPermuteKernel(T *__restrict__ input, T *__restrict__ __shared__ T shared_load[CTA_SIZE][head_dim]; __shared__ T shared_store[head_dim][CTA_SIZE]; +#if !defined(USE_ROCM) // 0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15 // permute on the seq dimension for fp8 mma uint32_t smem_load_row_base = ((thread_id / num_threads_per_token) / 16) * 16; @@ -294,6 +403,14 @@ __global__ void TransposePadPermuteKernel(T *__restrict__ input, T *__restrict__ cp_async::pred_load_128b(shared_load[smem_load_row] + thread_id % num_threads_per_token * pack_size, input_ptr_base, thread_base_token < num_tokens); cp_async::commit_group(); cp_async::wait_group<0>(); +#else + uint32_t smem_load_row = thread_id / num_threads_per_token; + + detail::predicated_g2s_16B( + &shared_load[smem_load_row][ (thread_id % num_threads_per_token) * pack_size ], + input_ptr_base, + thread_base_token < num_tokens); +#endif __syncthreads(); uint32_t smem_row_base = thread_id % CTA_SIZE; @@ -309,7 +426,14 @@ __global__ void TransposePadPermuteKernel(T *__restrict__ input, T *__restrict__ __syncthreads(); +#if !defined(USE_ROCM) *(float4*)(output_ptr_base) = *(float4*)(&shared_store[thread_id / num_threads_per_cta][thread_id % num_threads_per_cta * pack_size]); +#else + *reinterpret_cast(output_ptr_base) = + *reinterpret_cast( + &shared_store[ thread_id / num_threads_per_cta ] + [ (thread_id % num_threads_per_cta) * pack_size ]); +#endif } @@ -320,7 +444,12 @@ __global__ void MeanScaleKernel(T *__restrict__ input, int8_t *__restrict__ outp const uint32_t stride_bz_mean, const uint32_t stride_h_mean, const uint32_t stride_bz_scale, const uint32_t stride_h_scale) { - static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); + +// #if !defined(USE_ROCM) +// static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); +// #else +// static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); +// #endif constexpr uint32_t pack_size = 8; // float4 contains 8 half or 8 bfloat16 @@ -418,14 +547,21 @@ __global__ void MeanScaleKernel(T *__restrict__ input, int8_t *__restrict__ outp x_val_float[j] *= recp_scale; } } - +#if !defined(USE_ROCM) floatx4_to_e4m3x4(x_val_fp8, x_val_float, x_val_float + 2); floatx4_to_e4m3x4(x_val_fp8 + 1, x_val_float + 4, x_val_float + 6); *(uint2*)(output_ptr_base + i * gmem_stride) = *(uint2*)(&x_val_fp8[0]); +#else + detail::floatx4_to_e4m3x4(x_val_fp8, x_val_float, x_val_float + 2); + detail::floatx4_to_e4m3x4(x_val_fp8 + 1, x_val_float + 4, x_val_float + 6); + + detail::store_8fp8(&x_val_fp8[0], output_ptr_base + i * gmem_stride); +#endif } } +#if !defined(USE_ROCM) void quant_per_block_int8_cuda( torch::Tensor input, torch::Tensor output, @@ -847,6 +983,8 @@ void sub_mean_cuda( }); } +#endif + void transpose_pad_permute_cuda( torch::Tensor input, torch::Tensor output, @@ -999,6 +1137,7 @@ void scale_fuse_quant_cuda( }); } +#if !defined(USE_ROCM) void mean_scale_fuse_quant_cuda( torch::Tensor input, torch::Tensor output, @@ -1080,4 +1219,5 @@ void mean_scale_fuse_quant_cuda( scale.stride(0), scale.stride(1) ); }); -} \ No newline at end of file +} +#endif \ No newline at end of file diff --git a/csrc/fused/fused.h b/csrc/fused/fused.h index 268b2835..25277ff8 100644 --- a/csrc/fused/fused.h +++ b/csrc/fused/fused.h @@ -16,6 +16,7 @@ #include +#if !defined(USE_ROCM) void quant_per_block_int8_cuda( torch::Tensor input, torch::Tensor output, @@ -53,6 +54,8 @@ void sub_mean_cuda( torch::Tensor output, int tensor_layout); +#endif + void transpose_pad_permute_cuda( torch::Tensor input, torch::Tensor output, @@ -66,6 +69,7 @@ void scale_fuse_quant_cuda( float scale_max, int tensor_layout); +#if !defined(USE_ROCM) void mean_scale_fuse_quant_cuda( torch::Tensor input, torch::Tensor output, @@ -73,4 +77,5 @@ void mean_scale_fuse_quant_cuda( torch::Tensor scale, int num_tokens, float scale_max, - int tensor_layout); \ No newline at end of file + int tensor_layout); +#endif \ No newline at end of file diff --git a/csrc/fused/fused.hip b/csrc/fused/fused.hip new file mode 100644 index 00000000..1f62427e --- /dev/null +++ b/csrc/fused/fused.hip @@ -0,0 +1,1224 @@ +// !!! This is a file automatically generated by hipify!!! +/* + * Copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "../dispatch_utils.h" +#include "../utils.cuh" +#include "../reduction_utils_hip.cuh" + +#if !defined(USE_ROCM) +#include "../numeric_conversion_hip.cuh" +#include "../cp_async_hip.cuh" +#include +#include + +#else +#include +#include +#include +#include +#include +#include +#include +#include +#include +#endif + +enum class QuantType +{ + kInt8, + kInt4, +}; + +#if !defined(USE_ROCM) +template +__device__ __forceinline__ float convert_to_float(T val) +{ + static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); + + if constexpr (std::is_same::value) + { + return __half2float(val); + } + else if constexpr (std::is_same::value) + { + return __bfloat162float(val); + } +} + +template +__device__ __forceinline__ T convert_from_float(float val) +{ + static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); + + if constexpr (std::is_same::value) + { + return __float2half_rn(val); + } + else if constexpr (std::is_same::value) + { + return __float2bfloat16_rn(val); + } +} + +#else +__device__ __forceinline__ float u32_as_f32(uint32_t u) { + union { uint32_t u; float f; } v{u}; return v.f; +} + +__device__ __forceinline__ uint16_t bf16_bits(__hip_bfloat16 x) { + return *reinterpret_cast(&x); +} + +// ========== to-float ========== +template +__device__ __forceinline__ float convert_to_float(T val); + +// __half → float +template <> +__device__ __forceinline__ float convert_to_float<__half>(__half v) { + return __half2float(v); +} + +// __hip_bfloat16 → float +template <> +__device__ __forceinline__ float convert_to_float<__hip_bfloat16>(__hip_bfloat16 v) { + uint16_t hi = bf16_bits(v); + return u32_as_f32(uint32_t(hi) << 16); +} + +// hip_bfloat16 → float +template <> +__device__ __forceinline__ float convert_to_float(hip_bfloat16 v) { + return convert_to_float(*reinterpret_cast(&v)); +} + +template +__device__ __forceinline__ T convert_from_float(float val) { + static_assert(std::is_same::value || + std::is_same::value, + "Only __half and __hip_bfloat16 are supported (ROCm)."); + + if constexpr (std::is_same::value) { + // f32 -> fp16 (round-to-nearest) + return __float2half_rn(val); + } else { + // f32 -> bf16 (round-to-nearest) + return __float2bfloat16(val); + } +} + +namespace detail { + + struct vec16_t { float x, y, z, w; }; + + template + __device__ __forceinline__ void predicated_g2s_16B(T* smem_dst, const T* gmem_src, bool pred) { + if (pred) { + *reinterpret_cast(smem_dst) = *reinterpret_cast(gmem_src); + } else if constexpr (PadZero) { + *reinterpret_cast(smem_dst) = vec16_t{0.f, 0.f, 0.f, 0.f}; + } + } + + __device__ __forceinline__ void store_8fp8(const uint32_t* __restrict__ fp8x4, + int8_t* __restrict__ out) { + *reinterpret_cast(out) = *reinterpret_cast(fp8x4); + } + + __device__ __forceinline__ void floatx4_to_e4m3x4(uint32_t* dest, float* s0, float* s1) { + + #ifdef __ROCM_ARCH_GFX942 + uint8_t b0 = __hip_cvt_float_to_fp8(s0[0], __HIP_SATFINITE, __HIP_E4M3_FNUZ); + uint8_t b1 = __hip_cvt_float_to_fp8(s0[1], __HIP_SATFINITE, __HIP_E4M3_FNUZ); + uint8_t b2 = __hip_cvt_float_to_fp8(s1[0], __HIP_SATFINITE, __HIP_E4M3_FNUZ); + uint8_t b3 = __hip_cvt_float_to_fp8(s1[1], __HIP_SATFINITE, __HIP_E4M3_FNUZ); + #else + uint8_t b0 = __hip_cvt_float_to_fp8(s0[0], __HIP_SATFINITE, __HIP_E4M3); + uint8_t b1 = __hip_cvt_float_to_fp8(s0[1], __HIP_SATFINITE, __HIP_E4M3); + uint8_t b2 = __hip_cvt_float_to_fp8(s1[0], __HIP_SATFINITE, __HIP_E4M3); + uint8_t b3 = __hip_cvt_float_to_fp8(s1[1], __HIP_SATFINITE, __HIP_E4M3); + #endif + + + *dest = (uint32_t)b0 | ((uint32_t)b1 << 8) | ((uint32_t)b2 << 16) | ((uint32_t)b3 << 24); + } + +} // namespace detail + +#endif + +#if !defined(USE_ROCM) +template +__global__ void QuantInt8Kernel(T *__restrict__ input, T *__restrict__ mean, int8_t *__restrict__ output, float *__restrict__ scale, float sm_scale, const uint32_t num_tokens, + const uint32_t stride_bz_input, const uint32_t stride_seq_input, const uint32_t stride_h_input, + const uint32_t stride_bz_mean, const uint32_t stride_h_mean, + const uint32_t stride_bz_output, const uint32_t stride_seq_output, const uint32_t stride_h_output, + const uint32_t stride_bz_scale, const uint32_t stride_h_scale) +{ + static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); + static_assert(num_pack_per_thread > 0, "The number of pack per thread must be greater than 0"); + + constexpr uint32_t pack_size = 8; // float4 contains 8 half or 8 bfloat16 + constexpr uint32_t num_threads_per_token = head_dim / pack_size; + + static_assert(num_threads_per_token <= 32, "The number of threads per token must be less than or equal to warp size"); + + T x_val[num_pack_per_thread][8]; + T mean_val[8]; + float x_val_float[num_pack_per_thread][8]; + float mean_val_float[8]; + + uint32_t bx = blockIdx.x; + uint32_t head_id = blockIdx.y; + uint32_t batch_id = blockIdx.z; + uint32_t thread_id = threadIdx.x; + + uint32_t thread_base_token = bx * BLOCK_SIZE + thread_id / num_threads_per_token; + T *input_ptr_base = input + batch_id * stride_bz_input + head_id * stride_h_input + thread_base_token * stride_seq_input + thread_id % num_threads_per_token * pack_size; + T *mean_ptr_base = mean + batch_id * stride_bz_mean + head_id * stride_h_mean + thread_id % num_threads_per_token * pack_size; + int8_t *output_ptr_base = output + batch_id * stride_bz_output + head_id * stride_h_output + thread_base_token * stride_seq_output + thread_id % num_threads_per_token * pack_size; + float *scale_ptr_base = scale + batch_id * stride_bz_scale + head_id * stride_h_scale + bx; + + if constexpr (sub_mean) + { + *(float4*)(&mean_val[0]) = *(float4*)(mean_ptr_base); +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + mean_val_float[j] = convert_to_float(mean_val[j]); + } + } + + constexpr uint32_t iter_stride = BLOCK_SIZE / num_pack_per_thread; + + // load the data + for (uint32_t i = 0; i < num_pack_per_thread; i++) + { + if (thread_base_token + i * iter_stride < num_tokens) + { + *(float4*)(&x_val[i][0]) = *(float4*)(input_ptr_base + i * iter_stride * stride_seq_input); +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + x_val_float[i][j] = convert_to_float(x_val[i][j]); + } + + if constexpr (sub_mean) + { +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + x_val_float[i][j] -= mean_val_float[j]; + } + } + + if constexpr (has_sm_scale) + { +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + x_val_float[i][j] *= sm_scale; + } + } + } + else + { +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + x_val_float[i][j] = 0.0f; + } + } + } + + float amax_val = 0.0000001f; // prevent from dividing by zero + +#pragma unroll + for (uint32_t i = 0; i < num_pack_per_thread; i++) + { +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + amax_val = fmaxf(amax_val, fabsf(x_val_float[i][j])); + } + } + + __shared__ float s_amax; + const float block_amax_val = vllm::blockReduceMax(amax_val); + if (thread_id == 0) + { + s_amax = block_amax_val; + scale_ptr_base[0] = s_amax / 127.0f; + } + + __syncthreads(); + + float tmp_scale = 127.0f / s_amax; + + char4 o_val[num_pack_per_thread][2]; + +#pragma unroll + for (uint32_t i = 0; i < num_pack_per_thread; i++) + { +#pragma unroll + for (uint32_t j = 0; j < 2; j += 1) + { + o_val[i][j] = make_char4( + float_to_int8_rn(x_val_float[i][j * 4 + 0] * tmp_scale), + float_to_int8_rn(x_val_float[i][j * 4 + 1] * tmp_scale), + float_to_int8_rn(x_val_float[i][j * 4 + 2] * tmp_scale), + float_to_int8_rn(x_val_float[i][j * 4 + 3] * tmp_scale) + ); + } + } + + // int8 result +#pragma unroll + for (uint32_t i = 0; i < num_pack_per_thread; i++) + { + + if (thread_base_token + i * iter_stride < num_tokens) + { + *reinterpret_cast(output_ptr_base + i * iter_stride * stride_seq_output) = *reinterpret_cast(&o_val[i][0]); + } + } +} + +template +__global__ void SubMeanKernel(T *__restrict__ input, T *__restrict__ mean, half *__restrict__ output, const uint32_t num_tokens, + const uint32_t stride_bz_input, const uint32_t stride_seq_input, const uint32_t stride_h_input, + const uint32_t stride_bz_mean, const uint32_t stride_h_mean, + const uint32_t stride_bz_output, const uint32_t stride_seq_output, const uint32_t stride_h_output) +{ + static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); + static_assert(num_pack_per_thread > 0, "The number of pack per thread must be greater than 0"); + + using T2 = typename std::conditional::value, half2, nv_bfloat162>::type; + + constexpr uint32_t pack_size = 8; // float4 contains 8 half or 8 bfloat16 + constexpr uint32_t num_threads_per_token = head_dim / pack_size; + + static_assert(num_threads_per_token <= 32, "The number of threads per token must be less than or equal to warp size"); + + T2 x_val[num_pack_per_thread][4]; + T2 mean_val[4]; + + uint32_t bx = blockIdx.x; + uint32_t head_id = blockIdx.y; + uint32_t batch_id = blockIdx.z; + uint32_t thread_id = threadIdx.x; + + uint32_t thread_base_token = bx * BLOCK_SIZE + thread_id / num_threads_per_token; + T *input_ptr_base = input + batch_id * stride_bz_input + head_id * stride_h_input + thread_base_token * stride_seq_input + thread_id % num_threads_per_token * pack_size; + T *mean_ptr_base = mean + batch_id * stride_bz_mean + head_id * stride_h_mean + thread_id % num_threads_per_token * pack_size; + half *output_ptr_base = output + batch_id * stride_bz_output + head_id * stride_h_output + thread_base_token * stride_seq_output + thread_id % num_threads_per_token * pack_size; + + *(float4*)(&mean_val[0]) = *(float4*)(mean_ptr_base); + + constexpr uint32_t iter_stride = BLOCK_SIZE / num_pack_per_thread; + + // load the data + for (uint32_t i = 0; i < num_pack_per_thread; i++) + { + if (thread_base_token + i * iter_stride < num_tokens) + { + *(float4*)(&x_val[i][0]) = *(float4*)(input_ptr_base + i * iter_stride * stride_seq_input); +#pragma unroll + for (uint32_t j = 0; j < 4; j++) + { + x_val[i][j] = __hsub2(x_val[i][j], mean_val[j]); + + if constexpr (std::is_same::value) + { + ((half2*)x_val[i])[j] = __float22half2_rn(__bfloat1622float2(x_val[i][j])); + } + } + } + } + +#pragma unroll + for (uint32_t i = 0; i < num_pack_per_thread; i++) + { + if (thread_base_token + i * iter_stride < num_tokens) + { + *reinterpret_cast(output_ptr_base + i * iter_stride * stride_seq_output) = *reinterpret_cast(&x_val[i][0]); + } + } +} +#endif + +template +__global__ void TransposePadPermuteKernel(T *__restrict__ input, T *__restrict__ output, const uint32_t num_tokens, + const uint32_t stride_bz_input, const uint32_t stride_seq_input, const uint32_t stride_h_input, + const uint32_t stride_bz_output, const uint32_t stride_d_output, const uint32_t stride_h_output) +{ + +// #if !defined(USE_ROCM) +// static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); +// #else +// static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); +// #endif + + constexpr uint32_t pack_size = 8; // float4 contains 8 half or 8 bfloat16 + uint32_t num_threads_per_token = head_dim / pack_size; + uint32_t num_threads_per_cta = CTA_SIZE / pack_size; + + uint32_t bx = blockIdx.x; + uint32_t head_id = blockIdx.y; + uint32_t batch_id = blockIdx.z; + uint32_t thread_id = threadIdx.x; + + uint32_t thread_base_token = bx * CTA_SIZE + thread_id / num_threads_per_token; + + T *input_ptr_base = input + batch_id * stride_bz_input + head_id * stride_h_input + thread_base_token * stride_seq_input + thread_id % num_threads_per_token * pack_size; + T* output_ptr_base = output + batch_id * stride_bz_output + head_id * stride_h_output + bx * CTA_SIZE + thread_id % num_threads_per_cta * pack_size + thread_id / num_threads_per_cta * stride_d_output; + + __shared__ T shared_load[CTA_SIZE][head_dim]; + __shared__ T shared_store[head_dim][CTA_SIZE]; + +#if !defined(USE_ROCM) + // 0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15 + // permute on the seq dimension for fp8 mma + uint32_t smem_load_row_base = ((thread_id / num_threads_per_token) / 16) * 16; + uint32_t smem_load_row_mod = (thread_id / num_threads_per_token) % 16; + uint32_t smem_load_row = smem_load_row_base + (smem_load_row_mod / 8) * 2 + ((smem_load_row_mod / 2) % 4) * 4 + (smem_load_row_mod % 2); + + constexpr cp_async::SharedMemFillMode fill_mode = pad_zero ? cp_async::SharedMemFillMode::kFillZero : cp_async::SharedMemFillMode::kNoFill; + cp_async::pred_load_128b(shared_load[smem_load_row] + thread_id % num_threads_per_token * pack_size, input_ptr_base, thread_base_token < num_tokens); + cp_async::commit_group(); + cp_async::wait_group<0>(); +#else + uint32_t smem_load_row = thread_id / num_threads_per_token; + + detail::predicated_g2s_16B( + &shared_load[smem_load_row][ (thread_id % num_threads_per_token) * pack_size ], + input_ptr_base, + thread_base_token < num_tokens); +#endif + __syncthreads(); + + uint32_t smem_row_base = thread_id % CTA_SIZE; + uint32_t smem_col_base = thread_id / CTA_SIZE; + uint32_t smem_col_stride = head_dim / 8; + + // TODO: use ldmatrix to do permutation +#pragma unroll + for (uint32_t i = 0; i < 8; i++) + { + shared_store[smem_col_base + i * smem_col_stride][smem_row_base] = shared_load[smem_row_base][smem_col_base + i * smem_col_stride]; + } + + __syncthreads(); + +#if !defined(USE_ROCM) + *(float4*)(output_ptr_base) = *(float4*)(&shared_store[thread_id / num_threads_per_cta][thread_id % num_threads_per_cta * pack_size]); +#else + *reinterpret_cast(output_ptr_base) = + *reinterpret_cast( + &shared_store[ thread_id / num_threads_per_cta ] + [ (thread_id % num_threads_per_cta) * pack_size ]); +#endif +} + + +template +__global__ void MeanScaleKernel(T *__restrict__ input, int8_t *__restrict__ output, float *__restrict__ mean, float *__restrict__ scale, const float scale_max, const uint32_t num_tokens, + const uint32_t stride_bz_input, const uint32_t stride_d_input, const uint32_t stride_h_input, + const uint32_t stride_bz_output, const uint32_t stride_d_output, const uint32_t stride_h_output, + const uint32_t stride_bz_mean, const uint32_t stride_h_mean, + const uint32_t stride_bz_scale, const uint32_t stride_h_scale) +{ + +// #if !defined(USE_ROCM) +// static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); +// #else +// static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); +// #endif + + constexpr uint32_t pack_size = 8; // float4 contains 8 half or 8 bfloat16 + + uint32_t head_id = blockIdx.x; + uint32_t batch_id = blockIdx.y; + uint32_t d_id = blockIdx.z; + uint32_t thread_id = threadIdx.x; + + uint32_t num_threads = blockDim.x; + uint32_t gmem_stride = num_threads * pack_size; + // pad the number of tokens to 16 to deal with fp8 permute in previous kernel + uint32_t fp8_padded_num_tokens = (num_tokens + 15) / 16 * 16; + uint32_t num_iters = fp8_padded_num_tokens / gmem_stride + ((fp8_padded_num_tokens % gmem_stride) > thread_id * pack_size); + + T *input_ptr_base = input + batch_id * stride_bz_input + head_id * stride_h_input + d_id * stride_d_input + thread_id * pack_size; + int8_t *output_ptr_base = output + batch_id * stride_bz_output + head_id * stride_h_output + d_id * stride_d_output + thread_id * pack_size; + + T x_val[8]; + float x_val_float[8]; + uint32_t x_val_fp8[2]; + + float max_val = - 1000000.0f; + float min_val = 1000000.0f; + float sum_val = 0.0f; + + for (int i = 0; i < num_iters; i++) + { + *(float4*)(&x_val[0]) = *(float4*)(input_ptr_base + i * gmem_stride); +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + float x_temp = convert_to_float(x_val[j]); + max_val = fmaxf(max_val, x_temp); + min_val = fminf(min_val, x_temp); + + if constexpr (sub_mean) + { + sum_val += x_temp; + } + } + } + + // reduce + __shared__ float s_amax_val; + __shared__ float s_mean_val; + + float block_max_val = vllm::blockReduceMax(max_val); + float block_min_val = vllm::blockReduceMin(min_val); + float block_sum_val; + + if constexpr (sub_mean) + { + block_sum_val = vllm::blockReduceSum(sum_val); + } + + if (thread_id == 0) + { + s_mean_val = block_sum_val / fp8_padded_num_tokens; + + if constexpr (sub_mean) + { + s_amax_val = fmaxf(fabsf(block_max_val - s_mean_val), fabsf(block_min_val - s_mean_val)); + mean[batch_id * stride_bz_mean + head_id * stride_h_mean + d_id] = s_mean_val; + } + else + { + s_amax_val = fmaxf(fabsf(block_max_val), fabsf(block_min_val)); + } + + scale[batch_id * stride_bz_scale + head_id * stride_h_scale + d_id] = s_amax_val / scale_max; + } + + __syncthreads(); + + float mean_val = s_mean_val; + float recp_scale = scale_max / s_amax_val; + + // recalculate num_iters to cover all fp8 output tokens to prevent nan in random initialization + uint32_t padded_num_tokens = (num_tokens + pad_size - 1) / pad_size * pad_size; + num_iters = padded_num_tokens / gmem_stride + ((padded_num_tokens % gmem_stride) > thread_id * pack_size); + + for (int i = 0; i < num_iters; i++) + { + *(float4*)(&x_val[0]) = *(float4*)(input_ptr_base + i * gmem_stride); +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + x_val_float[j] = convert_to_float(x_val[j]); + if constexpr (sub_mean) + { + x_val_float[j] = (x_val_float[j] - mean_val) * recp_scale; + } + else + { + x_val_float[j] *= recp_scale; + } + } +#if !defined(USE_ROCM) + floatx4_to_e4m3x4(x_val_fp8, x_val_float, x_val_float + 2); + floatx4_to_e4m3x4(x_val_fp8 + 1, x_val_float + 4, x_val_float + 6); + + *(uint2*)(output_ptr_base + i * gmem_stride) = *(uint2*)(&x_val_fp8[0]); +#else + detail::floatx4_to_e4m3x4(x_val_fp8, x_val_float, x_val_float + 2); + detail::floatx4_to_e4m3x4(x_val_fp8 + 1, x_val_float + 4, x_val_float + 6); + + detail::store_8fp8(&x_val_fp8[0], output_ptr_base + i * gmem_stride); +#endif + } +} + +#if !defined(USE_ROCM) +void quant_per_block_int8_cuda( + torch::Tensor input, + torch::Tensor output, + torch::Tensor scale, + float sm_scale, + int block_size, + int tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(output); + CHECK_CUDA(scale); + + CHECK_DTYPE(output, torch::kInt8); + CHECK_DTYPE(scale, torch::kFloat); + + CHECK_LASTDIM_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output); + CHECK_CONTIGUOUS(scale); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(scale, 3); + + const int batch_size = input.size(0); + const int head_dim = input.size(3); + + int stride_bz_input = input.stride(0); + int stride_bz_output = output.stride(0); + + int num_tokens, num_heads; + int stride_seq_input, stride_h_input, stride_seq_output, stride_h_output; + + if (tensor_layout == 0) + { + num_tokens = input.size(1); + num_heads = input.size(2); + stride_seq_input = input.stride(1); + stride_h_input = input.stride(2); + stride_seq_output = output.stride(1); + stride_h_output = output.stride(2); + } + else + { + num_tokens = input.size(2); + num_heads = input.size(1); + stride_seq_input = input.stride(2); + stride_h_input = input.stride(1); + stride_seq_output = output.stride(2); + stride_h_output = output.stride(1); + } + + auto input_dtype = input.scalar_type(); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + + CHECK_SHAPE(output, input.size(0), input.size(1), input.size(2), input.size(3)); + CHECK_SHAPE(scale, batch_size, num_heads, (num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE); + + dim3 grid((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE, num_heads, batch_size); + + constexpr int num_pack_per_thread = (BLOCK_SIZE * (HEAD_DIM / 8) + 1023) / 1024; + + dim3 block(BLOCK_SIZE * (HEAD_DIM / 8) / num_pack_per_thread); + + hipLaunchKernelGGL(( QuantInt8Kernel), dim3(grid), dim3(block), 0, 0, + reinterpret_cast(input.data_ptr()), + nullptr, + output.data_ptr(), + reinterpret_cast(scale.data_ptr()), + sm_scale, + num_tokens, + stride_bz_input, stride_seq_input, stride_h_input, + 0, 0, + stride_bz_output, stride_seq_output, stride_h_output, + scale.stride(0), scale.stride(1) + ); + }); + }); + }); +} + +void quant_per_block_int8_cuda( + torch::Tensor input, + torch::Tensor output, + torch::Tensor scale, + int block_size, + int tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(output); + CHECK_CUDA(scale); + + CHECK_DTYPE(output, torch::kInt8); + CHECK_DTYPE(scale, torch::kFloat); + + CHECK_LASTDIM_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output); + CHECK_CONTIGUOUS(scale); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(scale, 3); + + const int batch_size = input.size(0); + const int head_dim = input.size(3); + + int stride_bz_input = input.stride(0); + int stride_bz_output = output.stride(0); + + int num_tokens, num_heads; + int stride_seq_input, stride_h_input, stride_seq_output, stride_h_output; + + if (tensor_layout == 0) + { + num_tokens = input.size(1); + num_heads = input.size(2); + stride_seq_input = input.stride(1); + stride_h_input = input.stride(2); + stride_seq_output = output.stride(1); + stride_h_output = output.stride(2); + } + else + { + num_tokens = input.size(2); + num_heads = input.size(1); + stride_seq_input = input.stride(2); + stride_h_input = input.stride(1); + stride_seq_output = output.stride(2); + stride_h_output = output.stride(1); + } + + auto input_dtype = input.scalar_type(); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + + CHECK_SHAPE(output, input.size(0), input.size(1), input.size(2), input.size(3)); + CHECK_SHAPE(scale, batch_size, num_heads, (num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE); + + dim3 grid((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE, num_heads, batch_size); + + constexpr int num_pack_per_thread = (BLOCK_SIZE * (HEAD_DIM / 8) + 1023) / 1024; + + dim3 block(BLOCK_SIZE * (HEAD_DIM / 8) / num_pack_per_thread); + + hipLaunchKernelGGL(( QuantInt8Kernel), dim3(grid), dim3(block), 0, 0, + reinterpret_cast(input.data_ptr()), + nullptr, + output.data_ptr(), + reinterpret_cast(scale.data_ptr()), + 0.0f, + num_tokens, + stride_bz_input, stride_seq_input, stride_h_input, + 0, 0, + stride_bz_output, stride_seq_output, stride_h_output, + scale.stride(0), scale.stride(1) + ); + }); + }); + }); +} + +void quant_per_block_int8_fuse_sub_mean_cuda( + torch::Tensor input, + torch::Tensor mean, + torch::Tensor output, + torch::Tensor scale, + int block_size, + int tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(mean); + CHECK_CUDA(output); + CHECK_CUDA(scale); + + CHECK_DTYPE(output, torch::kInt8); + CHECK_DTYPE(scale, torch::kFloat); + + CHECK_LASTDIM_CONTIGUOUS(input); + CHECK_CONTIGUOUS(mean); + CHECK_CONTIGUOUS(output); + CHECK_CONTIGUOUS(scale); + + CHECK_DIMS(input, 4); + CHECK_DIMS(mean, 3); + CHECK_DIMS(output, 4); + CHECK_DIMS(scale, 3); + + const int batch_size = input.size(0); + const int head_dim = input.size(3); + + int stride_bz_input = input.stride(0); + int stride_bz_output = output.stride(0); + + int num_tokens, num_heads; + int stride_seq_input, stride_h_input, stride_seq_output, stride_h_output; + + if (tensor_layout == 0) + { + num_tokens = input.size(1); + num_heads = input.size(2); + stride_seq_input = input.stride(1); + stride_h_input = input.stride(2); + stride_seq_output = output.stride(1); + stride_h_output = output.stride(2); + } + else + { + num_tokens = input.size(2); + num_heads = input.size(1); + stride_seq_input = input.stride(2); + stride_h_input = input.stride(1); + stride_seq_output = output.stride(2); + stride_h_output = output.stride(1); + } + + auto input_dtype = input.scalar_type(); + auto mean_dtype = mean.scalar_type(); + + TORCH_CHECK(input_dtype == mean_dtype, "Input and mean must have the same data type"); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + + CHECK_SHAPE(mean, batch_size, num_heads, head_dim); + CHECK_SHAPE(output, input.size(0), input.size(1), input.size(2), input.size(3)); + CHECK_SHAPE(scale, batch_size, num_heads, (num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE); + + dim3 grid((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE, num_heads, batch_size); + + constexpr int num_pack_per_thread = (BLOCK_SIZE * (HEAD_DIM / 8) + 1023) / 1024; + + dim3 block(BLOCK_SIZE * (HEAD_DIM / 8) / num_pack_per_thread); + + hipLaunchKernelGGL(( QuantInt8Kernel), dim3(grid), dim3(block), 0, 0, + reinterpret_cast(input.data_ptr()), + reinterpret_cast(mean.data_ptr()), + output.data_ptr(), + reinterpret_cast(scale.data_ptr()), + 0.0f, + num_tokens, + stride_bz_input, stride_seq_input, stride_h_input, + mean.stride(0), mean.stride(1), + stride_bz_output, stride_seq_output, stride_h_output, + scale.stride(0), scale.stride(1) + ); + }); + }); + }); +} + +// use block size 128 and warp_block size 32 +void quant_per_warp_int8_cuda( + torch::Tensor input, + torch::Tensor output, + torch::Tensor scale, + int block_size, + int warp_block_size, + int tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(output); + CHECK_CUDA(scale); + + CHECK_DTYPE(output, torch::kInt8); + CHECK_DTYPE(scale, torch::kFloat); + + CHECK_LASTDIM_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output); + CHECK_CONTIGUOUS(scale); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(scale, 3); + + const int batch_size = input.size(0); + const int head_dim = input.size(3); + + int stride_bz_input = input.stride(0); + int stride_bz_output = output.stride(0); + + int num_tokens, num_heads; + int stride_seq_input, stride_h_input, stride_seq_output, stride_h_output; + + if (tensor_layout == 0) + { + num_tokens = input.size(1); + num_heads = input.size(2); + stride_seq_input = input.stride(1); + stride_h_input = input.stride(2); + stride_seq_output = output.stride(1); + stride_h_output = output.stride(2); + } + else + { + num_tokens = input.size(2); + num_heads = input.size(1); + stride_seq_input = input.stride(2); + stride_h_input = input.stride(1); + stride_seq_output = output.stride(2); + stride_h_output = output.stride(1); + } + + auto input_dtype = input.scalar_type(); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, { + DISPATCH_WARP_BLOCK_SIZE(warp_block_size, WARP_BLOCK_SIZE, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + + CHECK_SHAPE(output, input.size(0), input.size(1), input.size(2), input.size(3)); + CHECK_SHAPE(scale, batch_size, num_heads, (num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE * (BLOCK_SIZE / WARP_BLOCK_SIZE)); + + dim3 grid((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE * (BLOCK_SIZE / WARP_BLOCK_SIZE), num_heads, batch_size); + + constexpr int num_pack_per_thread = (WARP_BLOCK_SIZE * (HEAD_DIM / 8) + 1023) / 1024; + + dim3 block(WARP_BLOCK_SIZE * (HEAD_DIM / 8) / num_pack_per_thread); + + hipLaunchKernelGGL(( QuantInt8Kernel), dim3(grid), dim3(block), 0, 0, + reinterpret_cast(input.data_ptr()), + nullptr, + output.data_ptr(), + reinterpret_cast(scale.data_ptr()), + 0.0, + num_tokens, + stride_bz_input, stride_seq_input, stride_h_input, + 0, 0, + stride_bz_output, stride_seq_output, stride_h_output, + scale.stride(0), scale.stride(1) + ); + }); + }); + }); + }); +} + +void sub_mean_cuda( + torch::Tensor input, + torch::Tensor mean, + torch::Tensor output, + int tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(mean); + CHECK_CUDA(output); + + CHECK_LASTDIM_CONTIGUOUS(input); + CHECK_CONTIGUOUS(mean); + CHECK_CONTIGUOUS(output); + + CHECK_DIMS(input, 4); + CHECK_DIMS(mean, 3); + CHECK_DIMS(output, 4); + + CHECK_DTYPE(output, torch::kHalf); + + const int batch_size = input.size(0); + const int head_dim = input.size(3); + + int stride_bz_input = input.stride(0); + int stride_bz_output = output.stride(0); + + int num_tokens, num_heads; + int stride_seq_input, stride_h_input, stride_seq_output, stride_h_output; + + if (tensor_layout == 0) + { + num_tokens = input.size(1); + num_heads = input.size(2); + stride_seq_input = input.stride(1); + stride_h_input = input.stride(2); + stride_seq_output = output.stride(1); + stride_h_output = output.stride(2); + } + else + { + num_tokens = input.size(2); + num_heads = input.size(1); + stride_seq_input = input.stride(2); + stride_h_input = input.stride(1); + stride_seq_output = output.stride(2); + stride_h_output = output.stride(1); + } + + auto input_dtype = input.scalar_type(); + auto mean_dtype = mean.scalar_type(); + + TORCH_CHECK(input_dtype == mean_dtype, "Input and mean must have the same data type"); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + + CHECK_SHAPE(mean, batch_size, num_heads, head_dim); + CHECK_SHAPE(output, input.size(0), input.size(1), input.size(2), input.size(3)); + + constexpr int BLOCK_SIZE = (HEAD_DIM == 128) ? 64 : 128; + + dim3 grid((num_tokens + BLOCK_SIZE - 1) / BLOCK_SIZE, num_heads, batch_size); + + constexpr int num_pack_per_thread = (BLOCK_SIZE * (HEAD_DIM / 8) + 1023) / 1024; + + dim3 block(BLOCK_SIZE * (HEAD_DIM / 8) / num_pack_per_thread); + + hipLaunchKernelGGL(( SubMeanKernel), dim3(grid), dim3(block), 0, 0, + reinterpret_cast(input.data_ptr()), + reinterpret_cast(mean.data_ptr()), + reinterpret_cast(output.data_ptr()), + num_tokens, + stride_bz_input, stride_seq_input, stride_h_input, + mean.stride(0), mean.stride(1), + stride_bz_output, stride_seq_output, stride_h_output + ); + }); + }); +} + +#endif + +void transpose_pad_permute_cuda( + torch::Tensor input, + torch::Tensor output, + int tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(output); + + CHECK_LASTDIM_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + + constexpr int CTA_SIZE = 64; + + const int batch_size = input.size(0); + const int head_dim = input.size(3); + + int stride_bz_input = input.stride(0); + int stride_bz_output = output.stride(0); + + int num_tokens, padded_num_tokens, num_heads; + int stride_seq_input, stride_h_input, stride_d_output, stride_h_output; + + if (tensor_layout == 0) + { + num_tokens = input.size(1); + num_heads = input.size(2); + stride_seq_input = input.stride(1); + stride_h_input = input.stride(2); + stride_d_output = output.stride(1); + stride_h_output = output.stride(2); + + padded_num_tokens = (num_tokens + CTA_SIZE - 1) / CTA_SIZE * CTA_SIZE; + + CHECK_SHAPE(output, batch_size, head_dim, num_heads, padded_num_tokens); + } + else + { + num_tokens = input.size(2); + num_heads = input.size(1); + stride_seq_input = input.stride(2); + stride_h_input = input.stride(1); + stride_d_output = output.stride(2); + stride_h_output = output.stride(1); + + padded_num_tokens = (num_tokens + CTA_SIZE - 1) / CTA_SIZE * CTA_SIZE; + CHECK_SHAPE(output, batch_size, num_heads, head_dim, padded_num_tokens); + } + + auto input_dtype = input.scalar_type(); + auto output_dtype = output.scalar_type(); + + TORCH_CHECK(input_dtype == output_dtype, "Input and output must have the same data type"); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + dim3 grid(padded_num_tokens / CTA_SIZE, num_heads, batch_size); + + static_assert(CTA_SIZE * HEAD_DIM <= 8192); + + dim3 block(CTA_SIZE * (HEAD_DIM / 8)); + + hipLaunchKernelGGL(( TransposePadPermuteKernel), dim3(grid), dim3(block), 0, 0, + reinterpret_cast(input.data_ptr()), + reinterpret_cast(output.data_ptr()), + num_tokens, + stride_bz_input, stride_seq_input, stride_h_input, + stride_bz_output, stride_d_output, stride_h_output + ); + }); + }); +} + +void scale_fuse_quant_cuda( + torch::Tensor input, + torch::Tensor output, + torch::Tensor scale, + int num_tokens, + float scale_max, + int tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(output); + CHECK_CUDA(scale); + + // CHECK_DTYPE(output, torch::kInt8); + CHECK_DTYPE(scale, torch::kFloat); + + CHECK_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output); + CHECK_CONTIGUOUS(scale); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(scale, 3); + + const int batch_size = input.size(0); + const int num_tokens_padded = input.size(3); + + int stride_bz_input = input.stride(0); + int stride_bz_output = output.stride(0); + + int num_heads, head_dim; + int stride_d_input, stride_h_input, stride_d_output, stride_h_output; + + if (tensor_layout == 0) + { + num_heads = input.size(2); + head_dim = input.size(1); + stride_d_input = input.stride(1); + stride_h_input = input.stride(2); + stride_d_output = output.stride(1); + stride_h_output = output.stride(2); + } + else + { + num_heads = input.size(1); + head_dim = input.size(2); + stride_d_input = input.stride(2); + stride_h_input = input.stride(1); + stride_d_output = output.stride(2); + stride_h_output = output.stride(1); + } + + CHECK_SHAPE(output, input.size(0), input.size(1), input.size(2), input.size(3)); + CHECK_SHAPE(scale, batch_size, num_heads, head_dim); + + constexpr int CTA_SIZE = 256; + + dim3 grid(num_heads, batch_size, head_dim); + dim3 block(CTA_SIZE); + + auto input_dtype = input.scalar_type(); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + hipLaunchKernelGGL(( MeanScaleKernel<64, false, c_type>), dim3(grid), dim3(block), 0, 0, + reinterpret_cast(input.data_ptr()), + reinterpret_cast(output.data_ptr()), + nullptr, + reinterpret_cast(scale.data_ptr()), + scale_max, + num_tokens, + stride_bz_input, stride_d_input, stride_h_input, + stride_bz_output, stride_d_output, stride_h_output, + 0, 0, + scale.stride(0), scale.stride(1) + ); + }); +} + +#if !defined(USE_ROCM) +void mean_scale_fuse_quant_cuda( + torch::Tensor input, + torch::Tensor output, + torch::Tensor mean, + torch::Tensor scale, + int num_tokens, + float scale_max, + int tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(output); + CHECK_CUDA(mean); + CHECK_CUDA(scale); + + // CHECK_DTYPE(output, torch::kInt8); + CHECK_DTYPE(mean, torch::kFloat); + CHECK_DTYPE(scale, torch::kFloat); + + CHECK_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output); + CHECK_CONTIGUOUS(mean); + CHECK_CONTIGUOUS(scale); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(mean, 3); + CHECK_DIMS(scale, 3); + + const int batch_size = input.size(0); + const int num_tokens_padded = input.size(3); + + int stride_bz_input = input.stride(0); + int stride_bz_output = output.stride(0); + + int num_heads, head_dim; + int stride_d_input, stride_h_input, stride_d_output, stride_h_output; + + if (tensor_layout == 0) + { + num_heads = input.size(2); + head_dim = input.size(1); + stride_d_input = input.stride(1); + stride_h_input = input.stride(2); + stride_d_output = output.stride(1); + stride_h_output = output.stride(2); + } + else + { + num_heads = input.size(1); + head_dim = input.size(2); + stride_d_input = input.stride(2); + stride_h_input = input.stride(1); + stride_d_output = output.stride(2); + stride_h_output = output.stride(1); + } + + CHECK_SHAPE(output, input.size(0), input.size(1), input.size(2), input.size(3)); + CHECK_SHAPE(mean, batch_size, num_heads, head_dim); + CHECK_SHAPE(scale, batch_size, num_heads, head_dim); + + constexpr int CTA_SIZE = 256; + + dim3 grid(num_heads, batch_size, head_dim); + dim3 block(CTA_SIZE); + + auto input_dtype = input.scalar_type(); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + hipLaunchKernelGGL(( MeanScaleKernel<64, true, c_type>), dim3(grid), dim3(block), 0, 0, + reinterpret_cast(input.data_ptr()), + reinterpret_cast(output.data_ptr()), + reinterpret_cast(mean.data_ptr()), + reinterpret_cast(scale.data_ptr()), + scale_max, + num_tokens, + stride_bz_input, stride_d_input, stride_h_input, + stride_bz_output, stride_d_output, stride_h_output, + mean.stride(0), mean.stride(1), + scale.stride(0), scale.stride(1) + ); + }); +} +#endif \ No newline at end of file diff --git a/csrc/fused/pybind.cpp b/csrc/fused/pybind.cpp index bffdb060..de652000 100644 --- a/csrc/fused/pybind.cpp +++ b/csrc/fused/pybind.cpp @@ -20,14 +20,19 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { +#if !defined(USE_ROCM) m.def("quant_per_block_int8_cuda", py::overload_cast(&quant_per_block_int8_cuda), "quant_per_block_int8_cuda"); m.def("quant_per_block_int8_cuda", py::overload_cast(&quant_per_block_int8_cuda), "quant_per_block_int8_cuda"); m.def("quant_per_block_int8_fuse_sub_mean_cuda", py::overload_cast(&quant_per_block_int8_fuse_sub_mean_cuda), "quant_per_block_int8_fuse_sub_mean_cuda"); m.def("quant_per_warp_int8_cuda", py::overload_cast(&quant_per_warp_int8_cuda), "quant_per_warp_int8_cuda"); m.def("sub_mean_cuda", py::overload_cast(&sub_mean_cuda), "sub_mean_cuda"); +#endif m.def("transpose_pad_permute_cuda", py::overload_cast(&transpose_pad_permute_cuda), "transpose_pad_permute_cuda"); m.def("scale_fuse_quant_cuda", py::overload_cast(&scale_fuse_quant_cuda), "scale_fuse_quant_cuda"); + +#if !defined(USE_ROCM) m.def("mean_scale_fuse_quant_cuda", py::overload_cast(&mean_scale_fuse_quant_cuda), "mean_scale_fuse_quant_cuda"); +#endif } \ No newline at end of file diff --git a/csrc/fused/pybind_hip.cpp b/csrc/fused/pybind_hip.cpp new file mode 100644 index 00000000..e721f30a --- /dev/null +++ b/csrc/fused/pybind_hip.cpp @@ -0,0 +1,39 @@ +// !!! This is a file automatically generated by hipify!!! +/* + * Copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "fused.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ +#if !defined(USE_ROCM) + m.def("quant_per_block_int8_cuda", py::overload_cast(&quant_per_block_int8_cuda), "quant_per_block_int8_cuda"); + m.def("quant_per_block_int8_cuda", py::overload_cast(&quant_per_block_int8_cuda), "quant_per_block_int8_cuda"); + m.def("quant_per_block_int8_fuse_sub_mean_cuda", py::overload_cast(&quant_per_block_int8_fuse_sub_mean_cuda), "quant_per_block_int8_fuse_sub_mean_cuda"); + m.def("quant_per_warp_int8_cuda", py::overload_cast(&quant_per_warp_int8_cuda), "quant_per_warp_int8_cuda"); + + m.def("sub_mean_cuda", py::overload_cast(&sub_mean_cuda), "sub_mean_cuda"); +#endif + + m.def("transpose_pad_permute_cuda", py::overload_cast(&transpose_pad_permute_cuda), "transpose_pad_permute_cuda"); + m.def("scale_fuse_quant_cuda", py::overload_cast(&scale_fuse_quant_cuda), "scale_fuse_quant_cuda"); + +#if !defined(USE_ROCM) + m.def("mean_scale_fuse_quant_cuda", py::overload_cast(&mean_scale_fuse_quant_cuda), "mean_scale_fuse_quant_cuda"); +#endif +} \ No newline at end of file diff --git a/csrc/fused/rocm/dispatch_utils.h b/csrc/fused/rocm/dispatch_utils.h new file mode 100755 index 00000000..3bf08f82 --- /dev/null +++ b/csrc/fused/rocm/dispatch_utils.h @@ -0,0 +1,112 @@ +// /* +// * Copyright (c) 2024 by SageAttention team. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"); +// * you may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ + +// #pragma once +// #include +// #include +// #include +// #include + +// #define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \ +// if (head_dim == 64) { \ +// constexpr int HEAD_DIM = 64; \ +// __VA_ARGS__ \ +// } else if (head_dim == 128) { \ +// constexpr int HEAD_DIM = 128; \ +// __VA_ARGS__ \ +// } else { \ +// std::ostringstream err_msg; \ +// err_msg << "Unsupported head dim: " << int(head_dim); \ +// throw std::invalid_argument(err_msg.str()); \ +// } + +// #define DISPATCH_CAUSAL(is_causal, IS_CAUSAL, ...) \ +// if (is_causal == 1) { \ +// constexpr bool IS_CAUSAL = true; \ +// __VA_ARGS__ \ +// } else if (is_causal == 0) { \ +// constexpr bool IS_CAUSAL = false; \ +// __VA_ARGS__ \ +// } else { \ +// std::ostringstream err_msg; \ +// err_msg << "Unsupported causal mode: " << int(is_causal); \ +// throw std::invalid_argument(err_msg.str()); \ +// } + +// #define DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, ...) \ +// if (qk_quant_gran == 2) { \ +// constexpr int QK_QUANT_GRAN = 2; \ +// __VA_ARGS__ \ +// } else if (qk_quant_gran == 3) { \ +// constexpr int QK_QUANT_GRAN = 3; \ +// __VA_ARGS__ \ +// } else { \ +// std::ostringstream err_msg; \ +// err_msg << "Unsupported qk_quant_gran: " << int(qk_quant_gran); \ +// throw std::invalid_argument(err_msg.str()); \ +// } + +// #define DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, ...) \ +// if (return_lse == 1) { \ +// constexpr bool RETURN_LSE = true; \ +// __VA_ARGS__ \ +// } else if (return_lse == 0) { \ +// constexpr bool RETURN_LSE = false; \ +// __VA_ARGS__ \ +// } else { \ +// std::ostringstream err_msg; \ +// err_msg << "Unsupported causal mode: " << int(return_lse); \ +// throw std::invalid_argument(err_msg.str()); \ +// } + +// #define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \ +// if (pytorch_dtype == at::ScalarType::Half) { \ +// using c_type = half; \ +// __VA_ARGS__ \ +// } else if (pytorch_dtype == at::ScalarType::BFloat16) { \ +// using c_type = hip_bfloat16; \ +// __VA_ARGS__ \ +// } else { \ +// std::ostringstream oss; \ +// oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ +// TORCH_CHECK(false, oss.str()); \ +// } + +// #define DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, ...) \ +// if (block_size == 64) { \ +// constexpr int BLOCK_SIZE = 64; \ +// __VA_ARGS__ \ +// } else if (block_size == 128) { \ +// constexpr int BLOCK_SIZE = 128; \ +// __VA_ARGS__ \ +// } else { \ +// std::ostringstream err_msg; \ +// err_msg << "Unsupported block_size " << int(block_size); \ +// throw std::invalid_argument(err_msg.str()); \ +// } + +// #define DISPATCH_WARP_BLOCK_SIZE(warp_block_size, WARP_BLOCK_SIZE, ...) \ +// if (warp_block_size == 16) { \ +// constexpr int WARP_BLOCK_SIZE = 16; \ +// __VA_ARGS__ \ +// } else if (warp_block_size == 32) { \ +// constexpr int WARP_BLOCK_SIZE = 32; \ +// __VA_ARGS__ \ +// } else { \ +// std::ostringstream err_msg; \ +// err_msg << "Unsupported warp_block_size " << int(warp_block_size); \ +// throw std::invalid_argument(err_msg.str()); \ +// } diff --git a/csrc/fused/rocm/fused.cu b/csrc/fused/rocm/fused.cu new file mode 100755 index 00000000..cb028b40 --- /dev/null +++ b/csrc/fused/rocm/fused.cu @@ -0,0 +1,451 @@ +/* + * Copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +#include +#include + +#include "../../dispatch_utils.h" +#include "../../utils.cuh" +#include "../../reduction_utils.cuh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +enum class QuantType +{ + kInt8, + kInt4, +}; + +__device__ __forceinline__ float u32_as_f32(uint32_t u) { + union { uint32_t u; float f; } v{u}; return v.f; +} + +__device__ __forceinline__ uint16_t bf16_bits(__hip_bfloat16 x) { + return *reinterpret_cast(&x); +} + +// ========== to-float ========== +template +__device__ __forceinline__ float convert_to_float(T val); + +// __half → float +template <> +__device__ __forceinline__ float convert_to_float<__half>(__half v) { + return __half2float(v); +} + +// __hip_bfloat16 → float +template <> +__device__ __forceinline__ float convert_to_float<__hip_bfloat16>(__hip_bfloat16 v) { + uint16_t hi = bf16_bits(v); + return u32_as_f32(uint32_t(hi) << 16); +} + +// hip_bfloat16 → float +template <> +__device__ __forceinline__ float convert_to_float(hip_bfloat16 v) { + return convert_to_float(*reinterpret_cast(&v)); +} + +template +__device__ __forceinline__ T convert_from_float(float val) { + static_assert(std::is_same::value || + std::is_same::value, + "Only __half and __hip_bfloat16 are supported (ROCm)."); + + if constexpr (std::is_same::value) { + // f32 -> fp16 (round-to-nearest) + return __float2half_rn(val); + } else { + // f32 -> bf16 (round-to-nearest) + return __float2bfloat16(val); + } +} + +namespace detail { + + struct vec16_t { float x, y, z, w; }; + + template + __device__ __forceinline__ void predicated_g2s_16B(T* smem_dst, const T* gmem_src, bool pred) { + if (pred) { + *reinterpret_cast(smem_dst) = *reinterpret_cast(gmem_src); + } else if constexpr (PadZero) { + *reinterpret_cast(smem_dst) = vec16_t{0.f, 0.f, 0.f, 0.f}; + } + } + + __device__ __forceinline__ void store_8fp8(const uint32_t* __restrict__ fp8x4, + int8_t* __restrict__ out) { + *reinterpret_cast(out) = *reinterpret_cast(fp8x4); + } + + __device__ __forceinline__ void floatx4_to_e4m3x4(uint32_t* dest, float* s0, float* s1) { + + #ifdef __ROCM_ARCH_GFX942 + uint8_t b0 = __hip_cvt_float_to_fp8(s0[0], __HIP_SATFINITE, __HIP_E4M3_FNUZ); + uint8_t b1 = __hip_cvt_float_to_fp8(s0[1], __HIP_SATFINITE, __HIP_E4M3_FNUZ); + uint8_t b2 = __hip_cvt_float_to_fp8(s1[0], __HIP_SATFINITE, __HIP_E4M3_FNUZ); + uint8_t b3 = __hip_cvt_float_to_fp8(s1[1], __HIP_SATFINITE, __HIP_E4M3_FNUZ); + #else + uint8_t b0 = __hip_cvt_float_to_fp8(s0[0], __HIP_SATFINITE, __HIP_E4M3); + uint8_t b1 = __hip_cvt_float_to_fp8(s0[1], __HIP_SATFINITE, __HIP_E4M3); + uint8_t b2 = __hip_cvt_float_to_fp8(s1[0], __HIP_SATFINITE, __HIP_E4M3); + uint8_t b3 = __hip_cvt_float_to_fp8(s1[1], __HIP_SATFINITE, __HIP_E4M3); + #endif + + + *dest = (uint32_t)b0 | ((uint32_t)b1 << 8) | ((uint32_t)b2 << 16) | ((uint32_t)b3 << 24); + } + +} // namespace detail + +template +__global__ void MeanScaleKernel(T *__restrict__ input, int8_t *__restrict__ output, float *__restrict__ mean, float *__restrict__ scale, const float scale_max, const uint32_t num_tokens, + const uint32_t stride_bz_input, const uint32_t stride_d_input, const uint32_t stride_h_input, + const uint32_t stride_bz_output, const uint32_t stride_d_output, const uint32_t stride_h_output, + const uint32_t stride_bz_mean, const uint32_t stride_h_mean, + const uint32_t stride_bz_scale, const uint32_t stride_h_scale) +{ + // static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); + + constexpr uint32_t pack_size = 8; // float4 contains 8 half or 8 bfloat16 + + uint32_t head_id = blockIdx.x; + uint32_t batch_id = blockIdx.y; + uint32_t d_id = blockIdx.z; + uint32_t thread_id = threadIdx.x; + + uint32_t num_threads = blockDim.x; + uint32_t gmem_stride = num_threads * pack_size; + // pad the number of tokens to 16 to deal with fp8 permute in previous kernel + uint32_t fp8_padded_num_tokens = (num_tokens + 15) / 16 * 16; + uint32_t num_iters = fp8_padded_num_tokens / gmem_stride + ((fp8_padded_num_tokens % gmem_stride) > thread_id * pack_size); + + T *input_ptr_base = input + batch_id * stride_bz_input + head_id * stride_h_input + d_id * stride_d_input + thread_id * pack_size; + int8_t *output_ptr_base = output + batch_id * stride_bz_output + head_id * stride_h_output + d_id * stride_d_output + thread_id * pack_size; + + T x_val[8]; + float x_val_float[8]; + uint32_t x_val_fp8[2]; + + float max_val = - 1000000.0f; + float min_val = 1000000.0f; + float sum_val = 0.0f; + + for (int i = 0; i < num_iters; i++) + { + *(float4*)(&x_val[0]) = *(float4*)(input_ptr_base + i * gmem_stride); +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + float x_temp = convert_to_float(x_val[j]); + max_val = fmaxf(max_val, x_temp); + min_val = fminf(min_val, x_temp); + + if constexpr (sub_mean) + { + sum_val += x_temp; + } + } + } + + // reduce + __shared__ float s_amax_val; + __shared__ float s_mean_val; + + float block_max_val = vllm::blockReduceMax(max_val); + float block_min_val = vllm::blockReduceMin(min_val); + float block_sum_val; + + if constexpr (sub_mean) + { + block_sum_val = vllm::blockReduceSum(sum_val); + } + + if (thread_id == 0) + { + s_mean_val = block_sum_val / fp8_padded_num_tokens; + + if constexpr (sub_mean) + { + s_amax_val = fmaxf(fabsf(block_max_val - s_mean_val), fabsf(block_min_val - s_mean_val)); + mean[batch_id * stride_bz_mean + head_id * stride_h_mean + d_id] = s_mean_val; + } + else + { + s_amax_val = fmaxf(fabsf(block_max_val), fabsf(block_min_val)); + } + + scale[batch_id * stride_bz_scale + head_id * stride_h_scale + d_id] = s_amax_val / scale_max; + } + + __syncthreads(); + + float mean_val = s_mean_val; + float recp_scale = scale_max / s_amax_val; + + // recalculate num_iters to cover all fp8 output tokens to prevent nan in random initialization + uint32_t padded_num_tokens = (num_tokens + pad_size - 1) / pad_size * pad_size; + num_iters = padded_num_tokens / gmem_stride + ((padded_num_tokens % gmem_stride) > thread_id * pack_size); + + for (int i = 0; i < num_iters; i++) + { + *(float4*)(&x_val[0]) = *(float4*)(input_ptr_base + i * gmem_stride); +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + x_val_float[j] = convert_to_float(x_val[j]); + if constexpr (sub_mean) + { + x_val_float[j] = (x_val_float[j] - mean_val) * recp_scale; + } + else + { + x_val_float[j] *= recp_scale; + } + } + + detail::floatx4_to_e4m3x4(x_val_fp8, x_val_float, x_val_float + 2); + detail::floatx4_to_e4m3x4(x_val_fp8 + 1, x_val_float + 4, x_val_float + 6); + + detail::store_8fp8(&x_val_fp8[0], output_ptr_base + i * gmem_stride); + } +} + + +template +__global__ void TransposePadPermuteKernel(T *__restrict__ input, T *__restrict__ output, const uint32_t num_tokens, + const uint32_t stride_bz_input, const uint32_t stride_seq_input, const uint32_t stride_h_input, + const uint32_t stride_bz_output, const uint32_t stride_d_output, const uint32_t stride_h_output) +{ + +// static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); + + constexpr uint32_t pack_size = 8; // float4 contains 8 half or 8 bfloat16 + uint32_t num_threads_per_token = head_dim / pack_size; + uint32_t num_threads_per_cta = CTA_SIZE / pack_size; + + uint32_t bx = blockIdx.x; + uint32_t head_id = blockIdx.y; + uint32_t batch_id = blockIdx.z; + uint32_t thread_id = threadIdx.x; + + uint32_t thread_base_token = bx * CTA_SIZE + thread_id / num_threads_per_token; + + T *input_ptr_base = input + batch_id * stride_bz_input + head_id * stride_h_input + + thread_base_token * stride_seq_input + thread_id % num_threads_per_token * pack_size; + T* output_ptr_base = output + batch_id * stride_bz_output + head_id * stride_h_output + + bx * CTA_SIZE + thread_id % num_threads_per_cta * pack_size + thread_id / num_threads_per_cta * stride_d_output; + + __shared__ T shared_load[CTA_SIZE][head_dim]; + __shared__ T shared_store[head_dim][CTA_SIZE]; + + // 0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15 + // permute on the seq dimension for fp8 mma + // uint32_t smem_load_row_base = ((thread_id / num_threads_per_token) / 16) * 16; + // uint32_t smem_load_row_mod = (thread_id / num_threads_per_token) % 16; + // uint32_t smem_load_row = smem_load_row_base + (smem_load_row_mod / 8) * 2 + ((smem_load_row_mod / 2) % 4) * 4 + (smem_load_row_mod % 2); + uint32_t smem_load_row = thread_id / num_threads_per_token; + + detail::predicated_g2s_16B( + &shared_load[smem_load_row][ (thread_id % num_threads_per_token) * pack_size ], + input_ptr_base, + thread_base_token < num_tokens); + __syncthreads(); + + uint32_t smem_row_base = thread_id % CTA_SIZE; + uint32_t smem_col_base = thread_id / CTA_SIZE; + uint32_t smem_col_stride = head_dim / 8; + + // TODO: use ldmatrix to do permutation +#pragma unroll + for (uint32_t i = 0; i < 8; i++) + { + shared_store[smem_col_base + i * smem_col_stride][smem_row_base] = shared_load[smem_row_base][smem_col_base + i * smem_col_stride]; + } + + __syncthreads(); + + *reinterpret_cast(output_ptr_base) = + *reinterpret_cast( + &shared_store[ thread_id / num_threads_per_cta ] + [ (thread_id % num_threads_per_cta) * pack_size ]); +} + + +void scale_fuse_quant_cuda( + torch::Tensor input, + torch::Tensor output, + torch::Tensor scale, + int num_tokens, + float scale_max, + int tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(output); + CHECK_CUDA(scale); + + // CHECK_DTYPE(output, torch::kInt8); + CHECK_DTYPE(scale, torch::kFloat); + + CHECK_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output); + CHECK_CONTIGUOUS(scale); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(scale, 3); + + const int batch_size = input.size(0); + const int num_tokens_padded = input.size(3); + + int stride_bz_input = input.stride(0); + int stride_bz_output = output.stride(0); + + int num_heads, head_dim; + int stride_d_input, stride_h_input, stride_d_output, stride_h_output; + + if (tensor_layout == 0) + { + num_heads = input.size(2); + head_dim = input.size(1); + stride_d_input = input.stride(1); + stride_h_input = input.stride(2); + stride_d_output = output.stride(1); + stride_h_output = output.stride(2); + } + else + { + num_heads = input.size(1); + head_dim = input.size(2); + stride_d_input = input.stride(2); + stride_h_input = input.stride(1); + stride_d_output = output.stride(2); + stride_h_output = output.stride(1); + } + + CHECK_SHAPE(output, input.size(0), input.size(1), input.size(2), input.size(3)); + CHECK_SHAPE(scale, batch_size, num_heads, head_dim); + + constexpr int CTA_SIZE = 256; + + dim3 grid(num_heads, batch_size, head_dim); + dim3 block(CTA_SIZE); + + auto input_dtype = input.scalar_type(); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + MeanScaleKernel<64, false, c_type><<>>( + reinterpret_cast(input.data_ptr()), + reinterpret_cast(output.data_ptr()), + nullptr, + reinterpret_cast(scale.data_ptr()), + scale_max, + num_tokens, + stride_bz_input, stride_d_input, stride_h_input, + stride_bz_output, stride_d_output, stride_h_output, + 0, 0, + scale.stride(0), scale.stride(1) + ); + }); +} + +void transpose_pad_permute_cuda( + torch::Tensor input, + torch::Tensor output, + int tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(output); + + CHECK_LASTDIM_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + + constexpr int CTA_SIZE = 64; + + const int batch_size = input.size(0); + const int head_dim = input.size(3); + + int stride_bz_input = input.stride(0); + int stride_bz_output = output.stride(0); + + int num_tokens, padded_num_tokens, num_heads; + int stride_seq_input, stride_h_input, stride_d_output, stride_h_output; + + if (tensor_layout == 0) + { + num_tokens = input.size(1); + num_heads = input.size(2); + stride_seq_input = input.stride(1); + stride_h_input = input.stride(2); + stride_d_output = output.stride(1); + stride_h_output = output.stride(2); + + padded_num_tokens = (num_tokens + CTA_SIZE - 1) / CTA_SIZE * CTA_SIZE; + + CHECK_SHAPE(output, batch_size, head_dim, num_heads, padded_num_tokens); + } + else + { + num_tokens = input.size(2); + num_heads = input.size(1); + stride_seq_input = input.stride(2); + stride_h_input = input.stride(1); + stride_d_output = output.stride(2); + stride_h_output = output.stride(1); + + padded_num_tokens = (num_tokens + CTA_SIZE - 1) / CTA_SIZE * CTA_SIZE; + CHECK_SHAPE(output, batch_size, num_heads, head_dim, padded_num_tokens); + } + + auto input_dtype = input.scalar_type(); + auto output_dtype = output.scalar_type(); + + TORCH_CHECK(input_dtype == output_dtype, "Input and output must have the same data type"); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + dim3 grid(padded_num_tokens / CTA_SIZE, num_heads, batch_size); + + static_assert(CTA_SIZE * HEAD_DIM <= 8192); + + dim3 block(CTA_SIZE * (HEAD_DIM / 8)); + + TransposePadPermuteKernel<<>>( + reinterpret_cast(input.data_ptr()), + reinterpret_cast(output.data_ptr()), + num_tokens, + stride_bz_input, stride_seq_input, stride_h_input, + stride_bz_output, stride_d_output, stride_h_output + ); + }); + }); +} diff --git a/csrc/fused/rocm/fused.h b/csrc/fused/rocm/fused.h new file mode 100755 index 00000000..bdfd5b06 --- /dev/null +++ b/csrc/fused/rocm/fused.h @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +// void quant_per_block_int8_cuda( +// torch::Tensor input, +// torch::Tensor output, +// torch::Tensor scale, +// float sm_scale, +// int block_size, +// int tensor_layout); + +// void quant_per_block_int8_cuda( +// torch::Tensor input, +// torch::Tensor output, +// torch::Tensor scale, +// int block_size, +// int tensor_layout); + +// void quant_per_block_int8_fuse_sub_mean_cuda( +// torch::Tensor input, +// torch::Tensor mean, +// torch::Tensor output, +// torch::Tensor scale, +// int block_size, +// int tensor_layout); + +// void quant_per_warp_int8_cuda( +// torch::Tensor input, +// torch::Tensor output, +// torch::Tensor scale, +// int block_size, +// int warp_block_size, +// int tensor_layout); + +// void sub_mean_cuda( +// torch::Tensor input, +// torch::Tensor mean, +// torch::Tensor output, +// int tensor_layout); + +void transpose_pad_permute_cuda( + torch::Tensor input, + torch::Tensor output, + int tensor_layout); + +void scale_fuse_quant_cuda( + torch::Tensor input, + torch::Tensor output, + torch::Tensor scale, + int num_tokens, + float scale_max, + int tensor_layout); + +// void mean_scale_fuse_quant_cuda( +// torch::Tensor input, +// torch::Tensor output, +// torch::Tensor mean, +// torch::Tensor scale, +// int num_tokens, +// float scale_max, +// int tensor_layout); \ No newline at end of file diff --git a/csrc/fused/rocm/fused.hip b/csrc/fused/rocm/fused.hip new file mode 100644 index 00000000..4a57fd15 --- /dev/null +++ b/csrc/fused/rocm/fused.hip @@ -0,0 +1,482 @@ +// !!! This is a file automatically generated by hipify!!! +/* + * Copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +#include +#include + +#include "../../dispatch_utils.h" +#include "../../utils.cuh" +#include "../../reduction_utils_hip.cuh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +enum class QuantType +{ + kInt8, + kInt4, +}; + +// __device__ __forceinline__ float u32_as_f32(uint32_t u) { +// union { uint32_t u; float f; } v{u}; return v.f; +// } + +// __device__ __forceinline__ uint16_t bf16_bits(__hip_bfloat16 x) { +// return *reinterpret_cast(&x); +// } + +// // ========== to-float ========== +// template +// __device__ __forceinline__ float convert_to_float(T val); + +// // __half → float +// template <> +// __device__ __forceinline__ float convert_to_float<__half>(__half v) { +// return __half2float(v); +// } + +// // __hip_bfloat16 → float +// template <> +// __device__ __forceinline__ float convert_to_float<__hip_bfloat16>(__hip_bfloat16 v) { +// uint16_t hi = bf16_bits(v); +// return u32_as_f32(uint32_t(hi) << 16); +// } + +// // hip_bfloat16 → float +// template <> +// __device__ __forceinline__ float convert_to_float(hip_bfloat16 v) { +// return convert_to_float(*reinterpret_cast(&v)); +// } + +// template +// __device__ __forceinline__ T convert_from_float(float val) { +// static_assert(std::is_same::value || +// std::is_same::value, +// "Only __half and __hip_bfloat16 are supported ."); + +// if constexpr (std::is_same::value) { +// // f32 -> fp16 (round-to-nearest) +// return __float2half_rn(val); +// } else { +// // f32 -> bf16 (round-to-nearest) +// return __float2bfloat16(val); +// } +// } + +template +__device__ __forceinline__ float convert_to_float(T val) +{ + static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); + + if constexpr (std::is_same::value) + { + return __half2float(val); + } + else if constexpr (std::is_same::value) + { + return __bfloat162float(val); + } +} + +template +__device__ __forceinline__ T convert_from_float(float val) +{ + static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); + + if constexpr (std::is_same::value) + { + return __float2half_rn(val); + } + else if constexpr (std::is_same::value) + { + return __float2bfloat16_rn(val); + } +} + +namespace detail { + + struct vec16_t { float x, y, z, w; }; + + template + __device__ __forceinline__ void predicated_g2s_16B(T* smem_dst, const T* gmem_src, bool pred) { + if (pred) { + *reinterpret_cast(smem_dst) = *reinterpret_cast(gmem_src); + } else if constexpr (PadZero) { + *reinterpret_cast(smem_dst) = vec16_t{0.f, 0.f, 0.f, 0.f}; + } + } + + __device__ __forceinline__ void store_8fp8(const uint32_t* __restrict__ fp8x4, + int8_t* __restrict__ out) { + *reinterpret_cast(out) = *reinterpret_cast(fp8x4); + } + + __device__ __forceinline__ void floatx4_to_e4m3x4(uint32_t* dest, float* s0, float* s1) { + + #ifdef __ROCM_ARCH_GFX942 + uint8_t b0 = __hip_cvt_float_to_fp8(s0[0], __HIP_SATFINITE, __HIP_E4M3_FNUZ); + uint8_t b1 = __hip_cvt_float_to_fp8(s0[1], __HIP_SATFINITE, __HIP_E4M3_FNUZ); + uint8_t b2 = __hip_cvt_float_to_fp8(s1[0], __HIP_SATFINITE, __HIP_E4M3_FNUZ); + uint8_t b3 = __hip_cvt_float_to_fp8(s1[1], __HIP_SATFINITE, __HIP_E4M3_FNUZ); + #else + uint8_t b0 = __hip_cvt_float_to_fp8(s0[0], __HIP_SATFINITE, __HIP_E4M3); + uint8_t b1 = __hip_cvt_float_to_fp8(s0[1], __HIP_SATFINITE, __HIP_E4M3); + uint8_t b2 = __hip_cvt_float_to_fp8(s1[0], __HIP_SATFINITE, __HIP_E4M3); + uint8_t b3 = __hip_cvt_float_to_fp8(s1[1], __HIP_SATFINITE, __HIP_E4M3); + #endif + + + *dest = (uint32_t)b0 | ((uint32_t)b1 << 8) | ((uint32_t)b2 << 16) | ((uint32_t)b3 << 24); + } + +} // namespace detail + +template +__global__ void MeanScaleKernel(T *__restrict__ input, int8_t *__restrict__ output, float *__restrict__ mean, float *__restrict__ scale, const float scale_max, const uint32_t num_tokens, + const uint32_t stride_bz_input, const uint32_t stride_d_input, const uint32_t stride_h_input, + const uint32_t stride_bz_output, const uint32_t stride_d_output, const uint32_t stride_h_output, + const uint32_t stride_bz_mean, const uint32_t stride_h_mean, + const uint32_t stride_bz_scale, const uint32_t stride_h_scale) +{ + // static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); + + constexpr uint32_t pack_size = 8; // float4 contains 8 half or 8 bfloat16 + + uint32_t head_id = blockIdx.x; + uint32_t batch_id = blockIdx.y; + uint32_t d_id = blockIdx.z; + uint32_t thread_id = threadIdx.x; + + uint32_t num_threads = blockDim.x; + uint32_t gmem_stride = num_threads * pack_size; + // pad the number of tokens to 16 to deal with fp8 permute in previous kernel + uint32_t fp8_padded_num_tokens = (num_tokens + 15) / 16 * 16; + uint32_t num_iters = fp8_padded_num_tokens / gmem_stride + ((fp8_padded_num_tokens % gmem_stride) > thread_id * pack_size); + + T *input_ptr_base = input + batch_id * stride_bz_input + head_id * stride_h_input + d_id * stride_d_input + thread_id * pack_size; + int8_t *output_ptr_base = output + batch_id * stride_bz_output + head_id * stride_h_output + d_id * stride_d_output + thread_id * pack_size; + + T x_val[8]; + float x_val_float[8]; + uint32_t x_val_fp8[2]; + + float max_val = - 1000000.0f; + float min_val = 1000000.0f; + float sum_val = 0.0f; + + for (int i = 0; i < num_iters; i++) + { + *(float4*)(&x_val[0]) = *(float4*)(input_ptr_base + i * gmem_stride); +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + float x_temp = convert_to_float(x_val[j]); + max_val = fmaxf(max_val, x_temp); + min_val = fminf(min_val, x_temp); + + if constexpr (sub_mean) + { + sum_val += x_temp; + } + } + } + + // reduce + __shared__ float s_amax_val; + __shared__ float s_mean_val; + + float block_max_val = vllm::blockReduceMax(max_val); + float block_min_val = vllm::blockReduceMin(min_val); + float block_sum_val; + + if constexpr (sub_mean) + { + block_sum_val = vllm::blockReduceSum(sum_val); + } + + if (thread_id == 0) + { + s_mean_val = block_sum_val / fp8_padded_num_tokens; + + if constexpr (sub_mean) + { + s_amax_val = fmaxf(fabsf(block_max_val - s_mean_val), fabsf(block_min_val - s_mean_val)); + mean[batch_id * stride_bz_mean + head_id * stride_h_mean + d_id] = s_mean_val; + } + else + { + s_amax_val = fmaxf(fabsf(block_max_val), fabsf(block_min_val)); + } + + scale[batch_id * stride_bz_scale + head_id * stride_h_scale + d_id] = s_amax_val / scale_max; + } + + __syncthreads(); + + float mean_val = s_mean_val; + float recp_scale = scale_max / s_amax_val; + + // recalculate num_iters to cover all fp8 output tokens to prevent nan in random initialization + uint32_t padded_num_tokens = (num_tokens + pad_size - 1) / pad_size * pad_size; + num_iters = padded_num_tokens / gmem_stride + ((padded_num_tokens % gmem_stride) > thread_id * pack_size); + + for (int i = 0; i < num_iters; i++) + { + *(float4*)(&x_val[0]) = *(float4*)(input_ptr_base + i * gmem_stride); +#pragma unroll + for (uint32_t j = 0; j < 8; j++) + { + x_val_float[j] = convert_to_float(x_val[j]); + if constexpr (sub_mean) + { + x_val_float[j] = (x_val_float[j] - mean_val) * recp_scale; + } + else + { + x_val_float[j] *= recp_scale; + } + } + + detail::floatx4_to_e4m3x4(x_val_fp8, x_val_float, x_val_float + 2); + detail::floatx4_to_e4m3x4(x_val_fp8 + 1, x_val_float + 4, x_val_float + 6); + + detail::store_8fp8(&x_val_fp8[0], output_ptr_base + i * gmem_stride); + } +} + + +template +__global__ void TransposePadPermuteKernel(T *__restrict__ input, T *__restrict__ output, const uint32_t num_tokens, + const uint32_t stride_bz_input, const uint32_t stride_seq_input, const uint32_t stride_h_input, + const uint32_t stride_bz_output, const uint32_t stride_d_output, const uint32_t stride_h_output) +{ + +// static_assert(std::is_same::value || std::is_same::value, "Only half and bfloat16 are supported"); + + constexpr uint32_t pack_size = 8; // float4 contains 8 half or 8 bfloat16 + uint32_t num_threads_per_token = head_dim / pack_size; + uint32_t num_threads_per_cta = CTA_SIZE / pack_size; + + uint32_t bx = blockIdx.x; + uint32_t head_id = blockIdx.y; + uint32_t batch_id = blockIdx.z; + uint32_t thread_id = threadIdx.x; + + uint32_t thread_base_token = bx * CTA_SIZE + thread_id / num_threads_per_token; + + T *input_ptr_base = input + batch_id * stride_bz_input + head_id * stride_h_input + + thread_base_token * stride_seq_input + thread_id % num_threads_per_token * pack_size; + T* output_ptr_base = output + batch_id * stride_bz_output + head_id * stride_h_output + + bx * CTA_SIZE + thread_id % num_threads_per_cta * pack_size + thread_id / num_threads_per_cta * stride_d_output; + + __shared__ T shared_load[CTA_SIZE][head_dim]; + __shared__ T shared_store[head_dim][CTA_SIZE]; + + // 0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15 + // permute on the seq dimension for fp8 mma + // uint32_t smem_load_row_base = ((thread_id / num_threads_per_token) / 16) * 16; + // uint32_t smem_load_row_mod = (thread_id / num_threads_per_token) % 16; + // uint32_t smem_load_row = smem_load_row_base + (smem_load_row_mod / 8) * 2 + ((smem_load_row_mod / 2) % 4) * 4 + (smem_load_row_mod % 2); + uint32_t smem_load_row = thread_id / num_threads_per_token; + + detail::predicated_g2s_16B( + &shared_load[smem_load_row][ (thread_id % num_threads_per_token) * pack_size ], + input_ptr_base, + thread_base_token < num_tokens); + __syncthreads(); + + uint32_t smem_row_base = thread_id % CTA_SIZE; + uint32_t smem_col_base = thread_id / CTA_SIZE; + uint32_t smem_col_stride = head_dim / 8; + + // TODO: use ldmatrix to do permutation +#pragma unroll + for (uint32_t i = 0; i < 8; i++) + { + shared_store[smem_col_base + i * smem_col_stride][smem_row_base] = shared_load[smem_row_base][smem_col_base + i * smem_col_stride]; + } + + __syncthreads(); + + *reinterpret_cast(output_ptr_base) = + *reinterpret_cast( + &shared_store[ thread_id / num_threads_per_cta ] + [ (thread_id % num_threads_per_cta) * pack_size ]); +} + + +void scale_fuse_quant_cuda( + torch::Tensor input, + torch::Tensor output, + torch::Tensor scale, + int num_tokens, + float scale_max, + int tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(output); + CHECK_CUDA(scale); + + // CHECK_DTYPE(output, torch::kInt8); + CHECK_DTYPE(scale, torch::kFloat); + + CHECK_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output); + CHECK_CONTIGUOUS(scale); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(scale, 3); + + const int batch_size = input.size(0); + const int num_tokens_padded = input.size(3); + + int stride_bz_input = input.stride(0); + int stride_bz_output = output.stride(0); + + int num_heads, head_dim; + int stride_d_input, stride_h_input, stride_d_output, stride_h_output; + + if (tensor_layout == 0) + { + num_heads = input.size(2); + head_dim = input.size(1); + stride_d_input = input.stride(1); + stride_h_input = input.stride(2); + stride_d_output = output.stride(1); + stride_h_output = output.stride(2); + } + else + { + num_heads = input.size(1); + head_dim = input.size(2); + stride_d_input = input.stride(2); + stride_h_input = input.stride(1); + stride_d_output = output.stride(2); + stride_h_output = output.stride(1); + } + + CHECK_SHAPE(output, input.size(0), input.size(1), input.size(2), input.size(3)); + CHECK_SHAPE(scale, batch_size, num_heads, head_dim); + + constexpr int CTA_SIZE = 256; + + dim3 grid(num_heads, batch_size, head_dim); + dim3 block(CTA_SIZE); + + auto input_dtype = input.scalar_type(); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + hipLaunchKernelGGL(( MeanScaleKernel<64, false, c_type>), dim3(grid), dim3(block), 0, 0, + reinterpret_cast(input.data_ptr()), + reinterpret_cast(output.data_ptr()), + nullptr, + reinterpret_cast(scale.data_ptr()), + scale_max, + num_tokens, + stride_bz_input, stride_d_input, stride_h_input, + stride_bz_output, stride_d_output, stride_h_output, + 0, 0, + scale.stride(0), scale.stride(1) + ); + }); +} + +void transpose_pad_permute_cuda( + torch::Tensor input, + torch::Tensor output, + int tensor_layout) +{ + CHECK_CUDA(input); + CHECK_CUDA(output); + + CHECK_LASTDIM_CONTIGUOUS(input); + CHECK_CONTIGUOUS(output); + + CHECK_DIMS(input, 4); + CHECK_DIMS(output, 4); + + constexpr int CTA_SIZE = 64; + + const int batch_size = input.size(0); + const int head_dim = input.size(3); + + int stride_bz_input = input.stride(0); + int stride_bz_output = output.stride(0); + + int num_tokens, padded_num_tokens, num_heads; + int stride_seq_input, stride_h_input, stride_d_output, stride_h_output; + + if (tensor_layout == 0) + { + num_tokens = input.size(1); + num_heads = input.size(2); + stride_seq_input = input.stride(1); + stride_h_input = input.stride(2); + stride_d_output = output.stride(1); + stride_h_output = output.stride(2); + + padded_num_tokens = (num_tokens + CTA_SIZE - 1) / CTA_SIZE * CTA_SIZE; + + CHECK_SHAPE(output, batch_size, head_dim, num_heads, padded_num_tokens); + } + else + { + num_tokens = input.size(2); + num_heads = input.size(1); + stride_seq_input = input.stride(2); + stride_h_input = input.stride(1); + stride_d_output = output.stride(2); + stride_h_output = output.stride(1); + + padded_num_tokens = (num_tokens + CTA_SIZE - 1) / CTA_SIZE * CTA_SIZE; + CHECK_SHAPE(output, batch_size, num_heads, head_dim, padded_num_tokens); + } + + auto input_dtype = input.scalar_type(); + auto output_dtype = output.scalar_type(); + + TORCH_CHECK(input_dtype == output_dtype, "Input and output must have the same data type"); + + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(input_dtype, c_type, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + dim3 grid(padded_num_tokens / CTA_SIZE, num_heads, batch_size); + + static_assert(CTA_SIZE * HEAD_DIM <= 8192); + + dim3 block(CTA_SIZE * (HEAD_DIM / 8)); + + hipLaunchKernelGGL(( TransposePadPermuteKernel), dim3(grid), dim3(block), 0, 0, + reinterpret_cast(input.data_ptr()), + reinterpret_cast(output.data_ptr()), + num_tokens, + stride_bz_input, stride_seq_input, stride_h_input, + stride_bz_output, stride_d_output, stride_h_output + ); + }); + }); +} diff --git a/csrc/fused/rocm/pybind_rocm.cpp b/csrc/fused/rocm/pybind_rocm.cpp new file mode 100755 index 00000000..d29b8b79 --- /dev/null +++ b/csrc/fused/rocm/pybind_rocm.cpp @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "fused.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ +// m.def("quant_per_block_int8_cuda", py::overload_cast(&quant_per_block_int8_cuda), "quant_per_block_int8_cuda"); +// m.def("quant_per_block_int8_cuda", py::overload_cast(&quant_per_block_int8_cuda), "quant_per_block_int8_cuda"); +// m.def("quant_per_block_int8_fuse_sub_mean_cuda", py::overload_cast(&quant_per_block_int8_fuse_sub_mean_cuda), "quant_per_block_int8_fuse_sub_mean_cuda"); +// m.def("quant_per_warp_int8_cuda", py::overload_cast(&quant_per_warp_int8_cuda), "quant_per_warp_int8_cuda"); + +// m.def("sub_mean_cuda", py::overload_cast(&sub_mean_cuda), "sub_mean_cuda"); + + m.def("transpose_pad_permute_cuda", py::overload_cast(&transpose_pad_permute_cuda), "transpose_pad_permute_cuda"); + m.def("scale_fuse_quant_cuda", py::overload_cast(&scale_fuse_quant_cuda), "scale_fuse_quant_cuda"); +// m.def("mean_scale_fuse_quant_cuda", py::overload_cast(&mean_scale_fuse_quant_cuda), "mean_scale_fuse_quant_cuda"); +} \ No newline at end of file diff --git a/csrc/mma_hip.cuh b/csrc/mma_hip.cuh new file mode 100644 index 00000000..8653369a --- /dev/null +++ b/csrc/mma_hip.cuh @@ -0,0 +1,723 @@ +// !!! This is a file automatically generated by hipify!!! +/* + * Adapted from Flashinfer, https://github.com/flashinfer-ai/flashinfer/blob/v0.1.5/include/flashinfer/mma.cuh + * Copyright (c) 2023 by FlashInfer team. + * + * Modifications copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include + +namespace mma{ + +#if (__CUDACC_VER_MAJOR__ >= 11) +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) +#define MMA_F16F16F32_M16N8K16_ENABLED +#define MMA_F16F16F16_M16N8K16_ENABLED +#define MMA_S8S8S32_M16N8K32_ENABLED +#define MMA_S4S4S32_M16N8K64_ENABLED +#endif +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 750)) +#define MMA_F16F16F32_M16N8K8_ENABLED +#define MMA_F16F16F16_M16N8K8_ENABLED +#define LDMATRIX_M8N8X2_ENABLED +#define LDMATRIX_M8N8X4_ENABLED +#endif +#endif + +#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120400) +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 890)) +#define MMA_F8F8F32_M16N8K16_ENABLED +#endif +#endif + +#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120800) +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 890)) +#define MMA_F8F8F16_M16N8K16_ENABLED +#endif +#endif + +#if defined(__CUDA_ARCH__) +#define RUNTIME_ASSERT(x) __brkpt() +#else +#include +#define RUNTIME_ASSERT(x) assert(0 && x) +#endif + +enum class MMAMode { + kInit = 0U, + kInplaceUpdate = 1U, +}; + +/*! + * \brief Wrapper of PTX ldmatrix m8n8.x2 instruction, loads data from shared memory + * to fragment + * \tparam T data type of the fragment + * \param R pointer to the fragment + * \param smem_ptr pointer to the shared memory + */ +template +__device__ __forceinline__ void ldmatrix_m8n8x2(uint32_t* R, T* smem_ptr) { +#ifdef LDMATRIX_M8N8X2_ENABLED + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n" + : "=r"(R[0]), "=r"(R[1]) + : "r"(smem_int_ptr)); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); +#endif +} + +/*! + * \brief Wrapper of PTX ldmatrix m8n8.x4 instruction, loads data from shared memory + * to fragment + * \tparam T data type of the fragment + * \param R pointer to the fragment + * \param smem_ptr pointer to the shared memory + */ +template +__device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t* R, T* smem_ptr) { +#ifdef LDMATRIX_M8N8X4_ENABLED + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) + : "r"(smem_int_ptr)); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); +#endif +} + +/*! + * \brief Wrapper of PTX ldmatrix m8n8.x4 transposed instruction, loads data from + * shared memory to fragment and transposes the fragment + * \tparam T data type of the fragment + * \param R pointer to the fragment + * \param smem_ptr pointer to the shared memory + */ +template +__device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t* R, T* smem_ptr) { +#ifdef LDMATRIX_M8N8X4_ENABLED + uint32_t smem_int_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.trans.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) + : "r"(smem_int_ptr)); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n8k16 instruction for row major and column major f16 matrix + * multiplication, accumulated in f32. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n8k16_row_col_f16f16f32(float* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_F16F16F32_M16N8K16_ENABLED + // ! only support half dtype now + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), + "f"(C[2]), "f"(C[3])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f), + "f"(0.f), "f"(0.f)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n16k16 instruction for row major and column major f16 matrix + * multiplication, accumulated in f32. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_F16F16F32_M16N8K16_ENABLED + // ! only support half dtype now + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), + "f"(C[2]), "f"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]), + "f"(C[6]), "f"(C[7])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f), + "f"(0.f), "f"(0.f)); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(0.f), "f"(0.f), + "f"(0.f), "f"(0.f)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n8k16 instruction for row major and column major f16 matrix + * multiplication, accumulated in f16. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n8k16_row_col_f16f16f16(uint32_t* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_F16F16F16_M16N8K16_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C[0]), "=r"(C[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C[0]), "=r"(C[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(0), "r"(0)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n16k16 instruction for row major and column major f16 matrix + * multiplication, accumulated in f16. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f16(uint32_t* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_F16F16F16_M16N8K16_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C[0]), "=r"(C[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1])); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(C[2]), "r"(C[3])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C[0]), "=r"(C[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(0), "r"(0)); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(0), "r"(0)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n8k32 instruction for row major and column major int8 matrix + * multiplication, accumulated in int32. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n8k32_row_col_s8s8s32(int32_t* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_S8S8S32_M16N8K32_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), + "r"(C[2]), "r"(C[3])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(0), "r"(0), + "r"(0), "r"(0)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n16k32 instruction for row major and column major int8 matrix + * multiplication, accumulated in int32. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n16k32_row_col_s8s8s32(int32_t* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_S8S8S32_M16N8K32_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), + "r"(C[2]), "r"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[4]), "=r"(C[5]), "=r"(C[6]), "=r"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(C[4]), "r"(C[5]), + "r"(C[6]), "r"(C[7])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(0), "r"(0), + "r"(0), "r"(0)); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[4]), "=r"(C[5]), "=r"(C[6]), "=r"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(0), "r"(0), + "r"(0), "r"(0)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n8k32 instruction for row major and column major int4 matrix + * multiplication, accumulated in int32. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n8k64_row_col_s4s4s32(int32_t* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_S4S4S32_M16N8K64_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), + "r"(C[2]), "r"(C[3])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(0), "r"(0), + "r"(0), "r"(0)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n16k64 instruction for row major and column major int4 matrix + * multiplication, accumulated in int32. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n16k64_row_col_s4s4s32(int32_t* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_S4S4S32_M16N8K64_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), + "r"(C[2]), "r"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[4]), "=r"(C[5]), "=r"(C[6]), "=r"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(C[4]), "r"(C[5]), + "r"(C[6]), "r"(C[7])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[0]), "=r"(C[1]), "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(0), "r"(0), + "r"(0), "r"(0)); + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(C[4]), "=r"(C[5]), "=r"(C[6]), "=r"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(0), "r"(0), + "r"(0), "r"(0)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n8k32 instruction for row major and column major fp8 e4m3 matrix + * multiplication, accumulated in fp32. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n8k32_row_col_f8f8f32(float* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_F8F8F32_M16N8K16_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), + "f"(C[2]), "f"(C[3])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f), + "f"(0.f), "f"(0.f)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of the mma m16n16k32 instruction for row major and column major fp8 matrix + * multiplication, accumulated in fp16. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n16k32_row_col_f8f8f16(uint32_t* C_uint32, uint32_t* A, + uint32_t* B) { + //uint32_t* C_uint32 = reinterpret_cast(C); +#ifdef MMA_F8F8F16_M16N8K16_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C_uint32[0]), "=r"(C_uint32[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C_uint32[0]), "r"(C_uint32[1])); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C_uint32[2]), "=r"(C_uint32[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(C_uint32[2]), "r"(C_uint32[3])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C_uint32[0]), "=r"(C_uint32[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(0), "r"(0)); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C_uint32[2]), "=r"(C_uint32[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "r"(0), "r"(0)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + + + +/*! + * \brief Wrapper of the mma m16n16k32 instruction for row major and column major fp8 matrix + * multiplication, accumulated in fp32. + * \tparam mma_mode The mode of mma instruction, either kInit or kInplaceUpdate + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n16k32_row_col_f8f8f32(float* C, uint32_t* A, + uint32_t* B) { +#ifdef MMA_F8F8F32_M16N8K16_ENABLED + if constexpr (mma_mode == MMAMode::kInplaceUpdate) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), + "f"(C[2]), "f"(C[3])); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]), + "f"(C[6]), "f"(C[7])); + } + else if constexpr (mma_mode == MMAMode::kInit) + { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f), + "f"(0.f), "f"(0.f)); + + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), "r"(B[3]), "f"(0.f), "f"(0.f), + "f"(0.f), "f"(0.f)); + } +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Use mma instructions to compute rowsum. + */ +__device__ __forceinline__ void rowsum_f16f16f32(float* d, uint32_t* s) { +#ifdef MMA_F16F16F32_M16N8K16_ENABLED + asm volatile( + "{\n" + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, _, %1, _}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, 0., %9, 0.};\n" + "}\n" + : "=f"(d[0]), "=f"(d[1]) + : "r"(s[0]), "r"(s[1]), "r"(s[2]), "r"(s[3]), "r"(1006648320), // 1006648320 packs two 1.0f in half precision + "r"(1006648320), "f"(d[0]), "f"(d[1])); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Use mma instructions to compute rowsum. + */ +__device__ __forceinline__ void rowsum_f8f8f32(float* d, uint32_t* s) { +#ifdef MMA_F8F8F32_M16N8K16_ENABLED + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, _, %1, _}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, 0., %9, 0.};\n" + : "=f"(d[0]), "=f"(d[1]) + : "r"(s[0]), "r"(s[1]), "r"(s[2]), "r"(s[3]), "r"(943208504), "r"(943208504), // 943208504 packs four 1.0f in e4m3 + "f"(d[0]), "f"(d[1])); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); +#endif +} + +} // namespace mma diff --git a/csrc/numeric_conversion_hip.cuh b/csrc/numeric_conversion_hip.cuh new file mode 100644 index 00000000..7e10a38b --- /dev/null +++ b/csrc/numeric_conversion_hip.cuh @@ -0,0 +1,150 @@ +// !!! This is a file automatically generated by hipify!!! +/* + * Copyright (c) 2024 by SageAttention team. + * + * Inspired by CUTLASS, https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/numeric_conversion.h + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include +#include +#include + +#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120400) +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 890)) +#define FP8_CAST_ENABLED +#endif +#endif + +#if defined(__CUDA_ARCH__) +#define RUNTIME_ASSERT(x) __brkpt() +#else +#include +#define RUNTIME_ASSERT(x) assert(0 && x) +#endif + +__device__ __forceinline__ void unpack_half2_from_uint32_to_float(float* dest, uint32_t source) { + uint16_t h0 = source & 0xFFFF; + uint16_t h1 = (source >> 16) & 0xFFFF; + asm("cvt.f32.f16 %0, %1;" : "=f"(dest[0]) : "h"(h0)); + asm("cvt.f32.f16 %0, %1;" : "=f"(dest[1]) : "h"(h1)); +} + +__device__ __forceinline__ void floatx4_to_e4m3x4(uint32_t *dest, float *source0, float *source1) +{ +#ifdef FP8_CAST_ENABLED + asm volatile( \ + "{\n" \ + ".reg .b16 lo;\n" \ + ".reg .b16 hi;\n" \ + "cvt.rn.satfinite.e4m3x2.f32 lo, %2, %1;\n" \ + "cvt.rn.satfinite.e4m3x2.f32 hi, %4, %3;\n" \ + "mov.b32 %0, {lo, hi};\n" \ + "}" \ + : "=r"(dest[0]) : "f"(source0[0]), "f"(source0[1]), "f"(source1[0]), "f"(source1[1])); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for FP8 CAST instruction"); +#endif +} + +__device__ __forceinline__ void floatx4_to_e5m2x4(uint32_t *dest, float *source0, float *source1) +{ +#ifdef FP8_CAST_ENABLED + asm volatile( \ + "{\n" \ + ".reg .b16 lo;\n" \ + ".reg .b16 hi;\n" \ + "cvt.rn.satfinite.e5m2x2.f32 lo, %2, %1;\n" \ + "cvt.rn.satfinite.e5m2x2.f32 hi, %4, %3;\n" \ + "mov.b32 %0, {lo, hi};\n" \ + "}" \ + : "=r"(dest[0]) : "f"(source0[0]), "f"(source1[1]), "f"(source1[0]), "f"(source1[1])); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for FP8 CAST instruction"); +#endif +} + +__device__ __forceinline__ void halfx4_to_e4m3x4(uint32_t *dest, uint32_t *source0, uint32_t *source1) +{ +#ifdef FP8_CAST_ENABLED + asm volatile( \ + "{\n" \ + ".reg .b16 lo;\n" \ + ".reg .b16 hi;\n" \ + "cvt.rn.satfinite.e4m3x2.f16x2 lo, %1;\n" \ + "cvt.rn.satfinite.e4m3x2.f16x2 hi, %2;\n" \ + "mov.b32 %0, {lo, hi};\n" \ + "}" \ + : "=r"(dest[0]) : "r"(source0[0]), "r"(source1[0])); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for FP8 CAST instruction"); +#endif +} + +__device__ __forceinline__ void halfx4_to_e5m2x4(uint32_t *dest, uint32_t *source0, uint32_t *source1) +{ +#ifdef FP8_CAST_ENABLED + asm volatile( \ + "{\n" \ + ".reg .b16 lo;\n" \ + ".reg .b16 hi;\n" \ + "cvt.rn.satfinite.e5m2x2.f16x2 lo, %1;\n" \ + "cvt.rn.satfinite.e5m2x2.f16x2 hi, %2;\n" \ + "mov.b32 %0, {lo, hi};\n" \ + "}" \ + : "=r"(dest[0]) : "r"(source0[0]), "r"(source1[0])); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for FP8 CAST instruction"); +#endif +} + +__device__ __forceinline__ void e4m3x4_to_halfx4(uint32_t *dest0, uint32_t *dest1, uint32_t *source) +{ +#ifdef FP8_CAST_ENABLED + asm volatile( \ + "{\n" \ + ".reg .b16 lo, hi;\n" \ + "mov.b32 {lo, hi}, %2;\n" \ + "cvt.rn.f16x2.e4m3x2 %0, lo;\n" \ + "cvt.rn.f16x2.e4m3x2 %1, hi;\n" \ + "}\n" : "=r"(dest0[0]), "=r"(dest1[0]) : "r"(source[0])); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for FP8 CAST instruction"); +#endif +} + +__device__ __forceinline__ void e5m2x4_to_halfx4(uint32_t *dest0, uint32_t *dest1, uint32_t *source) +{ +#ifdef FP8_CAST_ENABLED + asm volatile( \ + "{\n" \ + ".reg .b16 lo, hi;\n" \ + "mov.b32 {lo, hi}, %2;\n" \ + "cvt.rn.f16x2.e5m2x2 %0, lo;\n" \ + "cvt.rn.f16x2.e5m2x2 %1, hi;\n" \ + "}\n" : "=r"(dest0[0]), "=r"(dest1[0]) : "r"(source[0])); +#else + RUNTIME_ASSERT("Unsupported CUDA architecture for FP8 CAST instruction"); +#endif +} + +__device__ __forceinline__ int8_t float_to_int8_rn(float x) +{ + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +} \ No newline at end of file diff --git a/csrc/permuted_smem_hip.cuh b/csrc/permuted_smem_hip.cuh new file mode 100644 index 00000000..b6e5d29e --- /dev/null +++ b/csrc/permuted_smem_hip.cuh @@ -0,0 +1,197 @@ +// !!! This is a file automatically generated by hipify!!! +/* + * Adapted from Flashinfer, https://github.com/flashinfer-ai/flashinfer/blob/v0.1.5/include/flashinfer/permuted_smem.cuh + * Copyright (c) 2023 by FlashInfer team. + * + * Modifications copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include + +#include + +#include "cp_async_hip.cuh" +#include "mma_hip.cuh" + +enum class SwizzleMode { + k32B, // for k32B mode, a line of shared memory must have 32B (16 half value) + k64B, // for k64B mode, a line of shared memory must have 64B (32 half value) + k128B, // 128B already spans all banks in shared memory. a line of shared memory can have multiple 128B. +}; + +// Use 128bit as the granularity to fetch/store data per thread to maximize memory bandwidth +using b128_t = uint4; + +/*! + * \brief A stateless shared memory wrapper that uses templates to avoid runtime conditionals. It makes sure + * that access to consecutive rows idx in the same column idx will make full use of the shared memory bank through + * permutation in the granularity of 128bit. + * + * This struct treats all offsets to be the number of `b128_t` elements. It is designed to be stateless, + * meaning it does not maintain any information about the current pointer position. The offset returnd by + * the struct can be used to access the shared memory through the provided interface. + * + * The struct guarantees that the read to permuted offset (i, j) will be the value stored in permuted offset (i, j). + * We assume that shared memory operation operates on at least two consecutive 128-bit values in a row within a warp. + * Under this assumption, we do not permute for k32B mode. + */ +template +struct smem_t { + // The base pointer. + b128_t* base; + // How many b128_t value a row contains + // uint32_t stride; + + __device__ __forceinline__ smem_t() : base(nullptr) {} + template + __device__ __forceinline__ smem_t(T* base) : base((b128_t*)base) { + if constexpr (swizzle_mode == SwizzleMode::k128B) { + static_assert(stride % 8 == 0, "Stride must be multiple of 8 for 128B swizzle mode"); + } else if constexpr (swizzle_mode == SwizzleMode::k64B) { + static_assert(stride == 4, "Stride must be 4 for 64B swizzle mode"); + } else if constexpr (swizzle_mode == SwizzleMode::k32B) { + static_assert(stride == 2, "Stride must be 2 for 32B swizzle mode"); + } else { + static_assert(swizzle_mode != swizzle_mode, "Unsupported swizzle mode"); + } + } + + /*! + * \brief Set the base pointer. + */ + template + __device__ __forceinline__ void set_base(T* new_base) { + base = (b128_t*)new_base; + } + + /*! + * \brief Compute the element offset given coordinates in a permuted shared memory. + * \param i The row index. + * \param j The column index. + */ + static __device__ __forceinline__ uint32_t get_permuted_offset(const uint32_t &i, const uint32_t &j) { + if constexpr (swizzle_mode == SwizzleMode::k128B) { + return i * stride + (j ^ (i % 8)); + } else if constexpr (swizzle_mode == SwizzleMode::k64B) { + return i * stride + (j ^ ((i / 2) % 4)); + } else if constexpr (swizzle_mode == SwizzleMode::k32B) { + return i * stride + j; + } + } + + /*! + * \tparam step_size The step size to advance the offset in the permuted shared memory. + * \param offset The current offset. + */ + template + static __device__ __forceinline__ uint32_t advance_offset_by_column(const uint32_t &offset) { + if constexpr (swizzle_mode == SwizzleMode::k128B) { + static_assert(step_size % 8 == 0, + "Unsupported step size"); + return offset + step_size; + } else if constexpr (swizzle_mode == SwizzleMode::k64B) { + static_assert(step_size == 4, "Unsupported step size"); + return offset + step_size; + } else if constexpr (swizzle_mode == SwizzleMode::k32B) { + static_assert(step_size == 2, "Unsupported step size"); + return offset + step_size; + } + } + + // ! use with care + template + static __device__ __forceinline__ uint32_t advance_offset_by_column(const uint32_t &offset, const uint32_t &step_idx) { + if constexpr (swizzle_mode == SwizzleMode::k128B) { + static_assert(step_size == 2 || step_size == 4 || step_size % 8 == 0, + "Unsupported step size"); + if constexpr (step_size == 2) { + return (offset ^ (0x2 + (0x4 * (step_idx % 2 == 1)))) + (step_idx % 4 == 3) * 8; + } else if constexpr (step_size == 4) { + return (offset ^ 0x4) + (step_idx % 2 == 1) * 8; + } else { + // step_size % 8 == 0 + return offset + step_size; + } + } else if constexpr (swizzle_mode == SwizzleMode::k64B) { + static_assert(step_size == 2 || step_size == 4, "Unsupported step size"); + if constexpr (step_size == 2) { + return (offset ^ 0x2) + (step_idx % 2 == 1) * 4; + } else { + return offset + step_size; + } + } else if constexpr (swizzle_mode == SwizzleMode::k32B) { + return offset + step_size; + } + } + + template + static __device__ __forceinline__ uint32_t advance_offset_by_row(const uint32_t &offset) { + if constexpr (swizzle_mode == SwizzleMode::k128B) { + static_assert(step_size == 4 || step_size % 8 == 0, "Unsupported step size"); + if constexpr (step_size == 4) { + return (offset ^ 0x4) + step_size * stride; + } else { + // step_size % 8 == 0 + return offset + step_size * stride; + } + } else if constexpr (swizzle_mode == SwizzleMode::k64B) { + static_assert(step_size == 4 || step_size % 8 == 0, "Unsupported step size"); + if constexpr (step_size == 4) { + return (offset ^ 0x2) + step_size * stride; + } else { + // step_size % 8 == 0 + return offset + step_size * stride; + } + } else if constexpr (swizzle_mode == SwizzleMode::k32B) { + return offset + step_size * stride; + } + } + + __device__ __forceinline__ void ldmatrix_m8n8x2(const uint32_t &offset, uint32_t* R) const { + b128_t* smem_ptr = base + offset; + mma::ldmatrix_m8n8x2(R, smem_ptr); + } + + __device__ __forceinline__ void ldmatrix_m8n8x4(const uint32_t &offset, uint32_t* R) const { + b128_t* smem_ptr = base + offset; + mma::ldmatrix_m8n8x4(R, smem_ptr); + } + + __device__ __forceinline__ void ldmatrix_m8n8x4_trans(const uint32_t &offset, uint32_t* R) const { + b128_t* smem_ptr = base + offset; + mma::ldmatrix_m8n8x4_trans(R, smem_ptr); + } + + template + __device__ __forceinline__ void load_128b_async(const uint32_t &offset, const T* gptr, bool predicate) const { + b128_t* smem_ptr = base + offset; + cp_async::pred_load_128b( + smem_ptr, reinterpret_cast(gptr), predicate); + } + + template + __device__ __forceinline__ void load_128b_async(const uint32_t &offset, const T* gptr) const { + b128_t* smem_ptr = base + offset; + cp_async::load_128b(smem_ptr, reinterpret_cast(gptr)); + } + + template + __device__ __forceinline__ void store_128b(const uint32_t &offset, T* gptr) const { + *reinterpret_cast(gptr) = *(base + offset); + } +}; \ No newline at end of file diff --git a/csrc/qattn/attn_utils.cuh b/csrc/qattn/attn_utils.cuh index 6f8b87c8..831c32e5 100644 --- a/csrc/qattn/attn_utils.cuh +++ b/csrc/qattn/attn_utils.cuh @@ -16,6 +16,34 @@ #pragma once #include "../utils.cuh" + +#if defined(USE_ROCM) +enum class MaskMode { + kNone = 0, + kCausal = 1, +}; + +enum class DataType { + kHalf, + kInt8, + kInt4, + kE4M3, + kE5M2, +}; + +enum class QuantGranularity { + kPerTensor = 0, + kPerBlock = 1, + kPerWarp = 2, + kPerThread = 3, +}; + +enum class ComputeUnit { + kTensorCore, + kCudaCore, +}; +#else + #include #include #include @@ -989,4 +1017,6 @@ __device__ __forceinline__ void compute_fp8_sv_inst_buf_fp16_accu(const smem_t +#include +#include + +#include "../cp_async_hip.cuh" +#include "../mma_hip.cuh" +#include "../permuted_smem_hip.cuh" +#include "../numeric_conversion_hip.cuh" + +#define WARP_SIZE 32 + +#define S_FP8_OFFSET 8.807f +#define S_FP8_OFFSET_EXP 6680.8477f +#define S_FP8_OFFSET_EXP_INV 0.0022326917f + +#define div_ceil(M, N) (((M) + (N)-1) / (N)) + +enum class MaskMode { + kNone = 0, + kCausal = 1, +}; + +enum class DataType { + kHalf, + kInt8, + kInt4, + kE4M3, + kE5M2, +}; + +enum class QuantGranularity { + kPerTensor = 0, + kPerBlock = 1, + kPerWarp = 2, + kPerThread = 3, +}; + +enum class ComputeUnit { + kTensorCore, + kCudaCore, +}; + +__device__ __forceinline__ uint32_t get_warp_id() +{ + return threadIdx.y; +} + +__device__ __forceinline__ uint32_t get_lane_id() +{ + return threadIdx.x; +} + +template +__device__ __forceinline__ uint32_t get_warp_idx_q() +{ + return get_warp_id() / num_warps_k; +} + +template +__device__ __forceinline__ uint32_t get_warp_idx_k() +{ + return get_warp_id() % num_warps_k; +} + +template +__device__ __forceinline__ void load_global_to_share(T **lane_ptr, uint32_t &smem_offset, + const uint32_t &gmem_stride, + const smem_t &smem) +{ + static_assert(global_to_shared_copy_lines_per_warp_per_iter * global_to_shared_line_lanes == WARP_SIZE); + static_assert(std::is_same::value || std::is_same::value); + + constexpr uint32_t pack_size = std::is_same::value ? 8 : 16; + +#pragma unroll + for (uint32_t i = 0; i < smem_iters_col; i++) + { +#pragma unroll + for (uint32_t j = 0; j < smem_iters_row; j++) + { + smem.load_128b_async(smem_offset, *lane_ptr); + *lane_ptr += (global_to_shared_line_lanes * pack_size); + smem_offset = smem.advance_offset_by_column(smem_offset); + } + + smem_offset = smem.advance_offset_by_row(smem_offset - (smem_iters_row * global_to_shared_line_lanes)); + *lane_ptr += ((global_to_shared_copy_lines_per_warp_per_iter * gmem_stride) - (smem_iters_row * global_to_shared_line_lanes * pack_size)); + } + smem_offset -= (smem_iters_col * global_to_shared_copy_lines_per_warp_per_iter * stride); + *lane_ptr += (CTA - smem_iters_col * global_to_shared_copy_lines_per_warp_per_iter) * gmem_stride; +} + +// with predicate +template +__device__ __forceinline__ void load_global_to_share(T **lane_ptr, uint32_t &smem_offset, + const uint32_t &gmem_stride, + const smem_t &smem, uint32_t base_idx, uint32_t max_len) +{ + static_assert(global_to_shared_copy_lines_per_warp_per_iter * global_to_shared_line_lanes == WARP_SIZE); + static_assert(std::is_same::value || std::is_same::value); + + constexpr uint32_t pack_size = std::is_same::value ? 8 : 16; + +#pragma unroll + for (uint32_t i = 0; i < smem_iters_col; i++) + { +#pragma unroll + for (uint32_t j = 0; j < smem_iters_row; j++) + { + smem.load_128b_async(smem_offset, *lane_ptr, base_idx < max_len); + *lane_ptr += (global_to_shared_line_lanes * pack_size); + smem_offset = smem.advance_offset_by_column(smem_offset); + } + + smem_offset = smem.advance_offset_by_row(smem_offset - (smem_iters_row * global_to_shared_line_lanes)); + *lane_ptr += ((global_to_shared_copy_lines_per_warp_per_iter * gmem_stride) - (smem_iters_row * global_to_shared_line_lanes * pack_size)); + base_idx += global_to_shared_copy_lines_per_warp_per_iter; + } + smem_offset -= (smem_iters_col * global_to_shared_copy_lines_per_warp_per_iter * stride); + *lane_ptr += (CTA - smem_iters_col * global_to_shared_copy_lines_per_warp_per_iter) * gmem_stride; +} + +template +__device__ __forceinline__ void load_fp8_V_global_to_share(int8_t **lane_ptr, uint32_t &smem_offset, + const uint32_t &gmem_stride, + const smem_t &smem) +{ + static_assert(global_to_shared_copy_lines_per_warp_per_iter * global_to_shared_line_lanes == WARP_SIZE); + constexpr uint32_t pack_size_fp8 = 16; + +#pragma unroll + for (uint32_t i = 0; i < smem_iters_col; i++) + { +#pragma unroll + for (uint32_t j = 0; j < smem_iters_row; j++) + { + smem.load_128b_async(smem_offset, *lane_ptr); + *lane_ptr += (global_to_shared_line_lanes * pack_size_fp8); + smem_offset = smem.advance_offset_by_column(smem_offset); + } + + smem_offset = smem.advance_offset_by_row(smem_offset - (smem_iters_row * global_to_shared_line_lanes)); + *lane_ptr += ((global_to_shared_copy_lines_per_warp_per_iter * gmem_stride) - (smem_iters_row * global_to_shared_line_lanes * pack_size_fp8)); + } + smem_offset -= (smem_iters_col * global_to_shared_copy_lines_per_warp_per_iter * stride); + // for QK: *lane_ptr += (CTA - smem_iters_col * global_to_shared_copy_lines_per_warp_per_iter) * gmem_stride; + *lane_ptr += CTA; // ! prevent underflow + *lane_ptr -= (smem_iters_col * global_to_shared_copy_lines_per_warp_per_iter) * gmem_stride; +} + +template +__device__ __forceinline__ void compute_int_qk(const smem_t &smem_Q, const smem_t &smem_K, int32_t RS[][num_tiles_k][8], uint32_t &offset_Q, uint32_t &offset_K) +{ + static_assert(DTypeQK == DataType::kInt8 || DTypeQK == DataType::kInt4); + + uint32_t RQ[num_tiles_q][4]; + uint32_t RK[4]; + + // the first iteration, mma mode is kInit +#pragma unroll + for (uint32_t iter = 0; iter < 1; iter++) + { + // load RQ +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + smem_Q.ldmatrix_m8n8x4(offset_Q, RQ[fq]); + offset_Q = smem_Q.advance_offset_by_row<16>(offset_Q); + } + // ! using permutation invariance + offset_Q = smem_Q.advance_offset_by_column<2>(offset_Q - (num_tiles_q * 16 * stride), iter); + +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + // load RK + smem_K.ldmatrix_m8n8x4(offset_K, RK); + offset_K = smem_K.advance_offset_by_row<16>(offset_K); + + // mma +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (DTypeQK == DataType::kInt8) + { + mma::mma_sync_m16n16k32_row_col_s8s8s32(RS[fq][fk], RQ[fq], RK); + } + else if constexpr (DTypeQK == DataType::kInt4) + { + mma::mma_sync_m16n16k64_row_col_s4s4s32(RS[fq][fk], RQ[fq], RK); + } + } + } + offset_K = smem_K.advance_offset_by_column<2>(offset_K - (num_tiles_k * 16 * stride), iter); + } + + // following iteration, mma mode is kInplace +#pragma unroll + for (uint32_t iter = 1; iter < num_tiles_qk_inner; iter++) + { + // load RQ +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + smem_Q.ldmatrix_m8n8x4(offset_Q, RQ[fq]); + offset_Q = smem_Q.advance_offset_by_row<16>(offset_Q); + } + offset_Q = smem_Q.advance_offset_by_column<2>(offset_Q - (num_tiles_q * 16 * stride), iter); + +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + // load RK + smem_K.ldmatrix_m8n8x4(offset_K, RK); + offset_K = smem_K.advance_offset_by_row<16>(offset_K); + + // mma +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (DTypeQK == DataType::kInt8) + { + mma::mma_sync_m16n16k32_row_col_s8s8s32(RS[fq][fk], RQ[fq], RK); + } + else if constexpr (DTypeQK == DataType::kInt4) + { + mma::mma_sync_m16n16k64_row_col_s4s4s32(RS[fq][fk], RQ[fq], RK); + } + } + } + offset_K = smem_K.advance_offset_by_column<2>(offset_K - (num_tiles_k * 16 * stride), iter); + } + + offset_Q -= (2 * num_tiles_qk_inner); + offset_K -= (2 * num_tiles_qk_inner); +} + +// for case when num_tiles_qk_inner = 1 +template +__device__ __forceinline__ void compute_int_qk(const smem_t &smem_K, int32_t RS[][num_tiles_k][8], uint32_t RQ[][4], uint32_t offset_K) +{ + static_assert(DTypeQK == DataType::kInt8 || DTypeQK == DataType::kInt4); + static_assert(num_tiles_qk_inner == 1); + + uint32_t RK[4]; + + // mma mode is kInit +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + // load RK + smem_K.ldmatrix_m8n8x4(offset_K, RK); + offset_K = smem_K.advance_offset_by_row<16>(offset_K); + + // mma +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (DTypeQK == DataType::kInt8) + { + mma::mma_sync_m16n16k32_row_col_s8s8s32(RS[fq][fk], RQ[fq], RK); + } + else if constexpr (DTypeQK == DataType::kInt4) + { + mma::mma_sync_m16n16k64_row_col_s4s4s32(RS[fq][fk], RQ[fq], RK); + } + } + } +} + +template +__device__ __forceinline__ void apply_causal_mask(const uint32_t &Q_idx_lane_base, const uint32_t &K_idx_lane_base, DTypeQKAccum RS[][num_tiles_k][8]) +{ +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + const uint32_t q_idx = Q_idx_lane_base + fq * 16 + 8 * ((k % 4) / 2); + const uint32_t kv_idx = K_idx_lane_base + fk * 16 + 8 * (k / 4) + k % 2; + const bool out_of_boundary = (kv_idx > q_idx); + + if constexpr (std::is_same::value) + { + RS[fq][fk][k] = (out_of_boundary ? -5000000.0f : RS[fq][fk][k]); + } + else if constexpr (std::is_same::value) + { + RS[fq][fk][k] = (out_of_boundary ? __float2half_rn(-50000.0f) : RS[fq][fk][k]); + } + } + } + } +} + +template +__device__ __forceinline__ void apply_out_of_bound_mask(const uint32_t &K_idx_lane_base, DTypeQKAccum RS[][num_tiles_k][8], const uint32_t &kv_len) +{ +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + const uint32_t kv_idx = K_idx_lane_base + fk * 16 + 8 * (k / 4) + k % 2; + const bool out_of_boundary = (kv_idx >= kv_len); + + if constexpr (std::is_same::value) + { + RS[fq][fk][k] = (out_of_boundary ? -5000000.0f : RS[fq][fk][k]); + } + else if constexpr (std::is_same::value) + { + RS[fq][fk][k] = (out_of_boundary ? __float2half_rn(-50000.0f) : RS[fq][fk][k]); + } + } + } + } +} + +// for DTypeQKAccum float +template +__device__ __forceinline__ void update_mdo(float RS[][num_tiles_k][8], DTypeSVAccum RO[][num_tiles_v][8], float m[][2], float d[][2], const float &sm_scale) +{ + static_assert(std::is_same::value || (!use_half_o_scale)); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t k = 0; k < 2; k++) + { + // assign the smallest value possible + float m_prev = m[fq][k]; + float m_temp = -5000000.0f; +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + float m_local = max(max(RS[fq][fk][k * 2 + 0], RS[fq][fk][k * 2 + 1]), + max(RS[fq][fk][k * 2 + 4], RS[fq][fk][k * 2 + 5])); + m_temp = max(m_temp, m_local); + } + + if constexpr (!fuse_scale) + { + if constexpr (exp_offset) + { + m_temp = fmaf(m_temp, sm_scale, -S_FP8_OFFSET); + } + else + { + m_temp *= sm_scale; + } + } + else if constexpr (exp_offset) + { + m_temp += (-S_FP8_OFFSET); + } + + // exchange element with the 4 threads in the row + m_temp = max(m_temp, __shfl_xor_sync(0xffffffff, m_temp, 0x1)); // 0 exchange with 1, 2 exchange with 3 + m_temp = max(m_temp, __shfl_xor_sync(0xffffffff, m_temp, 0x2)); // 0 exchange with 2, 1 exchange with 3 + + m[fq][k] = max(m[fq][k], m_temp); + + float o_scale = math::ptx_exp2(m_prev - m[fq][k]); + + // update denominator + d[fq][k] *= o_scale; + + half2 o_scale2; + if constexpr (use_half_o_scale) + { + o_scale2 = __floats2half2_rn(o_scale, o_scale); + } + + // update RO +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + if constexpr (std::is_same::value) + { + RO[fq][fv][k * 2 + 0] *= o_scale; + RO[fq][fv][k * 2 + 1] *= o_scale; + RO[fq][fv][k * 2 + 4] *= o_scale; + RO[fq][fv][k * 2 + 5] *= o_scale; + } + else if constexpr (std::is_same::value) + { + if constexpr (use_half_o_scale) + { + ((half2*)RO[fq][fv])[k] = __hmul2(((half2*)RO[fq][fv])[k], o_scale2); + ((half2*)RO[fq][fv])[k + 2] = __hmul2(((half2*)RO[fq][fv])[k + 2], o_scale2); + } + else + { + RO[fq][fv][k * 2 + 0] = __float2half_rn(__half2float(RO[fq][fv][k * 2 + 0]) * o_scale); + RO[fq][fv][k * 2 + 1] = __float2half_rn(__half2float(RO[fq][fv][k * 2 + 1]) * o_scale); + RO[fq][fv][k * 2 + 4] = __float2half_rn(__half2float(RO[fq][fv][k * 2 + 4]) * o_scale); + RO[fq][fv][k * 2 + 5] = __float2half_rn(__half2float(RO[fq][fv][k * 2 + 5]) * o_scale); + } + } + } + + // raise RS to exponent + float negative_m = -m[fq][k]; +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + if constexpr (fuse_scale) + { + RS[fq][fk][k * 2 + 0] = math::ptx_exp2(RS[fq][fk][k * 2 + 0] + negative_m); + RS[fq][fk][k * 2 + 1] = math::ptx_exp2(RS[fq][fk][k * 2 + 1] + negative_m); + RS[fq][fk][k * 2 + 4] = math::ptx_exp2(RS[fq][fk][k * 2 + 4] + negative_m); + RS[fq][fk][k * 2 + 5] = math::ptx_exp2(RS[fq][fk][k * 2 + 5] + negative_m); + } + else + { + RS[fq][fk][k * 2 + 0] = math::ptx_exp2(fmaf(RS[fq][fk][k * 2 + 0], sm_scale, negative_m)); + RS[fq][fk][k * 2 + 1] = math::ptx_exp2(fmaf(RS[fq][fk][k * 2 + 1], sm_scale, negative_m)); + RS[fq][fk][k * 2 + 4] = math::ptx_exp2(fmaf(RS[fq][fk][k * 2 + 4], sm_scale, negative_m)); + RS[fq][fk][k * 2 + 5] = math::ptx_exp2(fmaf(RS[fq][fk][k * 2 + 5], sm_scale, negative_m)); + } + } + } + } +} + +template +__device__ __forceinline__ void RS_32_to_16(T RS[][num_tiles_k][8], uint32_t RS_16[][num_tiles_k][4]) +{ + static_assert(sizeof(T) == 4); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + ((half2*)RS_16[fq][fk])[0] = __float22half2_rn(((float2*)RS[fq][fk])[0]); + ((half2*)RS_16[fq][fk])[1] = __float22half2_rn(((float2*)RS[fq][fk])[1]); + ((half2*)RS_16[fq][fk])[2] = __float22half2_rn(((float2*)RS[fq][fk])[2]); + ((half2*)RS_16[fq][fk])[3] = __float22half2_rn(((float2*)RS[fq][fk])[3]); + } + } +} + +template +__device__ __forceinline__ void RS_32_to_8(float RS[][num_tiles_k][8], uint32_t RS_8[][num_tiles_k / 2][4]) +{ +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k / 2; fk++) + { + floatx4_to_e4m3x4(RS_8[fq][fk], RS[fq][fk * 2 + 0], RS[fq][fk * 2 + 0] + 4); + floatx4_to_e4m3x4(RS_8[fq][fk] + 1, RS[fq][fk * 2 + 0] + 2, RS[fq][fk * 2 + 0] + 6); + floatx4_to_e4m3x4(RS_8[fq][fk] + 2, RS[fq][fk * 2 + 1], RS[fq][fk * 2 + 1] + 4); + floatx4_to_e4m3x4(RS_8[fq][fk] + 3, RS[fq][fk * 2 + 1] + 2, RS[fq][fk * 2 + 1] + 6); + } + } +} + +template +__device__ __forceinline__ void RS_16_to_8(uint32_t RS[][num_tiles_k][4], uint32_t RS_8[][num_tiles_k / 2][4]) +{ +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k / 2; fk++) + { + halfx4_to_e4m3x4(RS_8[fq][fk], RS[fq][fk * 2 + 0], RS[fq][fk * 2 + 0] + 2); + halfx4_to_e4m3x4(RS_8[fq][fk] + 1, RS[fq][fk * 2 + 0] + 1, RS[fq][fk * 2 + 0] + 3); + halfx4_to_e4m3x4(RS_8[fq][fk] + 2, RS[fq][fk * 2 + 1], RS[fq][fk * 2 + 1] + 2); + halfx4_to_e4m3x4(RS_8[fq][fk] + 3, RS[fq][fk * 2 + 1] + 1, RS[fq][fk * 2 + 1] + 3); + } + } +} + +template +__device__ __forceinline__ void RS_8_to_16(uint32_t RS_8[][num_tiles_k / 2][4], uint32_t RS[][num_tiles_k][4]) +{ +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k / 2; fk++) + { + e4m3x4_to_halfx4(RS[fq][fk * 2 + 0], RS[fq][fk * 2 + 0] + 2, RS_8[fq][fk]); + e4m3x4_to_halfx4(RS[fq][fk * 2 + 0] + 1, RS[fq][fk * 2 + 0] + 3, RS_8[fq][fk] + 1); + e4m3x4_to_halfx4(RS[fq][fk * 2 + 1], RS[fq][fk * 2 + 1] + 2, RS_8[fq][fk] + 2); + e4m3x4_to_halfx4(RS[fq][fk * 2 + 1] + 1, RS[fq][fk * 2 + 1] + 3, RS_8[fq][fk] + 3); + } + } +} + +template +__device__ __forceinline__ void accumulate_d(T RS[][num_tiles_k][(compute_unit == ComputeUnit::kTensorCore)? 4 : 8], float d[][2]) +{ + // for compute unit cuda core, RS is float + // for compute unit tensor core, RS is packed half + static_assert((std::is_same::value && compute_unit == ComputeUnit::kCudaCore) || + (std::is_same::value && compute_unit == ComputeUnit::kTensorCore)); + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { + if constexpr (compute_unit == ComputeUnit::kTensorCore) + { + // full accumulate with tensor core + mma::rowsum_f16f16f32(d[fq], (uint32_t*)(RS[fq][fk])); + } + else if constexpr (compute_unit == ComputeUnit::kCudaCore) + { + // partial accumulate with cuda core + d[fq][0] += RS[fq][fk][0] + RS[fq][fk][1] + RS[fq][fk][4] + RS[fq][fk][5]; + d[fq][1] += RS[fq][fk][2] + RS[fq][fk][3] + RS[fq][fk][6] + RS[fq][fk][7]; + } + } + } +} + +template +__device__ __forceinline__ void accumulate_d_f8(uint32_t RS[][num_tiles_k / 2][4], float d[][2]) +{ +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k / 2; fk++) + { + mma::rowsum_f8f8f32(d[fq], RS[fq][fk]); + } + } +} + +template +__device__ __forceinline__ void compute_fp16_sv(const smem_t &smem_V, uint32_t RS_f16[][num_tiles_k][4], DTypeSVAccum RO[][num_tiles_v][8], float d[][2]) +{ + uint32_t smem_V_row_base = get_warp_idx_k() * (num_tiles_k * 16) + get_lane_id() % 16; + uint32_t smem_V_col_base = get_lane_id() / 16; +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // load RV + uint32_t RV[4]; + uint32_t offset_V = (smem_V).get_permuted_offset(smem_V_row_base + fk * 16, smem_V_col_base + fv * 2); + smem_V.ldmatrix_m8n8x4_trans(offset_V, RV); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (std::is_same::value) + { + mma::mma_sync_m16n16k16_row_col_f16f16f32(RO[fq][fv], RS_f16[fq][fk], RV); + } + else if constexpr (std::is_same::value) + { + mma::mma_sync_m16n16k16_row_col_f16f16f16((uint32_t*)RO[fq][fv], RS_f16[fq][fk], RV); + } + } + } + } +} + +template +__device__ __forceinline__ void compute_fp16_sv_permuted(const smem_t &smem_V, T RS_f16[][num_tiles_k][RS_width], DTypeSVAccum RO[][num_tiles_v][8], float d[][2], uint32_t &offset_V) +{ + static_assert(sizeof(T) == 4); + + // ! be sure you know what you are doing +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // load RV + uint32_t RV[4]; + smem_V.ldmatrix_m8n8x4_trans(offset_V, RV); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (std::is_same::value) + { + mma::mma_sync_m16n16k16_row_col_f16f16f32(RO[fq][fv], (uint32_t*)(RS_f16[fq][fk]), RV); + } + else if constexpr (std::is_same::value) + { + mma::mma_sync_m16n16k16_row_col_f16f16f16((uint32_t*)RO[fq][fv], (uint32_t*)(RS_f16[fq][fk]), RV); + } + } + + offset_V = smem_V.advance_offset_by_column<2>(offset_V, fv); + } + offset_V = smem_V.advance_offset_by_row<16>(offset_V - (2 * num_tiles_v)); + } + + // make offset_V their original value + offset_V -= (16 * num_tiles_k * stride); +} + +template +__device__ __forceinline__ void compute_fp16_sv_permuted_inst_buf(const smem_t &smem_V, T RS_f16[][num_tiles_k][RS_width], DTypeSVAccum RO[][num_tiles_v][8], float d[][2], uint32_t &offset_V) +{ + static_assert(sizeof(T) == 4); + static_assert(std::is_same::value); + + uint32_t RO_inst_buf[num_tiles_q][num_tiles_v][4]; + + // ! be sure you know what you are doing +#pragma unroll + for (uint32_t fk = 0; fk < 1; fk++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // load RV + uint32_t RV[4]; + smem_V.ldmatrix_m8n8x4_trans(offset_V, RV); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + { + mma::mma_sync_m16n16k16_row_col_f16f16f16((uint32_t*)RO_inst_buf[fq][fv], (uint32_t*)(RS_f16[fq][fk]), RV); + } + } + + offset_V = smem_V.advance_offset_by_column<2>(offset_V, fv); + } + offset_V = smem_V.advance_offset_by_row<16>(offset_V - (2 * num_tiles_v)); + } + +#pragma unroll + for (uint32_t fk = 1; fk < num_tiles_k; fk++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // load RV + uint32_t RV[4]; + smem_V.ldmatrix_m8n8x4_trans(offset_V, RV); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + { + mma::mma_sync_m16n16k16_row_col_f16f16f16((uint32_t*)RO_inst_buf[fq][fv], (uint32_t*)(RS_f16[fq][fk]), RV); + } + } + + offset_V = smem_V.advance_offset_by_column<2>(offset_V, fv); + } + offset_V = smem_V.advance_offset_by_row<16>(offset_V - (2 * num_tiles_v)); + } + + // accumulate into RO +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + RO[fq][fv][0] += __half2float(((half2*)RO_inst_buf[fq][fv])[0].x); + RO[fq][fv][1] += __half2float(((half2*)RO_inst_buf[fq][fv])[0].y); + RO[fq][fv][2] += __half2float(((half2*)RO_inst_buf[fq][fv])[1].x); + RO[fq][fv][3] += __half2float(((half2*)RO_inst_buf[fq][fv])[1].y); + RO[fq][fv][4] += __half2float(((half2*)RO_inst_buf[fq][fv])[2].x); + RO[fq][fv][5] += __half2float(((half2*)RO_inst_buf[fq][fv])[2].y); + RO[fq][fv][6] += __half2float(((half2*)RO_inst_buf[fq][fv])[3].x); + RO[fq][fv][7] += __half2float(((half2*)RO_inst_buf[fq][fv])[3].y); + } + } + + // make offset_V their original value + offset_V -= (16 * num_tiles_k * stride); +} + +template +__device__ __forceinline__ void normalize_d(DTypeSVAccum RO[][num_tiles_v][8], DTypeQKAccum m[][2], float d[][2]) +{ + if constexpr (compute_unit == ComputeUnit::kCudaCore) + { + // accumulate_d performs partial accumulation with cuda core + // aggregate d +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t k = 0; k < 2; k++) + { + d[fq][k] += __shfl_xor_sync(0xffffffff, d[fq][k], 0x1); // sum 0 and 1, 2 and 3 + d[fq][k] += __shfl_xor_sync(0xffffffff, d[fq][k], 0x2); // sum 0 and 2, 1 and 3 + } + } + } + + // divide O by d + float d_rcp[num_tiles_q][2]; +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t k = 0; k < 2; k++) + { + // TODO: check m to prevent nan + d_rcp[fq][k] = math::ptx_rcp(d[fq][k]); + } + } + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { +#pragma unroll + for (uint32_t k = 0; k < 8; k++) + { + if constexpr (std::is_same::value) + { + RO[fq][fv][k] *= d_rcp[fq][(k % 4) / 2]; + } + else if constexpr (std::is_same::value) + { + RO[fq][fv][k] = __float2half_rn(__half2float(RO[fq][fv][k]) * d_rcp[fq][(k % 4) / 2]); + } + } + } + } +} + +template +__device__ __forceinline__ void compute_fp8_sv(const smem_t &smem_V, uint32_t RS_f8[][num_tiles_k / 2][4], DTypeSVAccum RO[][num_tiles_v][8], float d[][2]) +{ + uint32_t smem_V_row_base = get_lane_id() % 8 + (get_lane_id() / 16) * 8; + // uint32_t smem_V_col_base = get_warp_idx_k() * ((16 * num_tiles_k) / 16) + (get_lane_id() / 8) % 2; + uint32_t smem_V_col_base = (get_lane_id() / 8) % 2; +#pragma unroll + for (uint32_t fk = 0; fk < num_tiles_k / 2; fk++) + { + uint32_t offset_V = smem_V.get_permuted_offset(smem_V_row_base, smem_V_col_base + fk * 2); +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // load RV + uint32_t RV[4]; + // uint32_t offset_V = (smem_V).get_permuted_offset(smem_V_row_base + fv * 16, smem_V_col_base + fk * 2); + smem_V.ldmatrix_m8n8x4(offset_V, RV); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (std::is_same::value) + { + mma::mma_sync_m16n16k32_row_col_f8f8f32(RO[fq][fv], RS_f8[fq][fk], RV); + } + else if constexpr (std::is_same::value) + { + // ! Not Implemented + } + } + offset_V = smem_V.advance_offset_by_row<16>(offset_V); + } + } +} + +template +__device__ __forceinline__ void compute_fp8_sv_inst_buf(const smem_t &smem_V, uint32_t RS_f8[][num_tiles_k / 2][4], DTypeSVAccum RO[][num_tiles_v][8], float d[][2]) +{ + uint32_t smem_V_row_base = get_lane_id() % 8 + (get_lane_id() / 16) * 8; + // uint32_t smem_V_col_base = get_warp_idx_k() * ((16 * num_tiles_k) / 16) + (get_lane_id() / 8) % 2; + uint32_t smem_V_col_base = (get_lane_id() / 8) % 2; + + float RO_inst_buf[num_tiles_q][num_tiles_v][8]; + +#pragma unroll + for (uint32_t fk = 0; fk < 1; fk++) + { + uint32_t offset_V = smem_V.get_permuted_offset(smem_V_row_base, smem_V_col_base + fk * 2); +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // load RV + uint32_t RV[4]; + // uint32_t offset_V = (smem_V).get_permuted_offset(smem_V_row_base + fv * 16, smem_V_col_base + fk * 2); + smem_V.ldmatrix_m8n8x4(offset_V, RV); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (std::is_same::value) + { + mma::mma_sync_m16n16k32_row_col_f8f8f32(RO_inst_buf[fq][fv], RS_f8[fq][fk], RV); + } + else if constexpr (std::is_same::value) + { + // ! Not Implemented + } + } + offset_V = smem_V.advance_offset_by_row<16>(offset_V); + } + } + +#pragma unroll + for (uint32_t fk = 1; fk < num_tiles_k / 2; fk++) + { + uint32_t offset_V = smem_V.get_permuted_offset(smem_V_row_base, smem_V_col_base + fk * 2); +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // load RV + uint32_t RV[4]; + // uint32_t offset_V = (smem_V).get_permuted_offset(smem_V_row_base + fv * 16, smem_V_col_base + fk * 2); + smem_V.ldmatrix_m8n8x4(offset_V, RV); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (std::is_same::value) + { + mma::mma_sync_m16n16k32_row_col_f8f8f32(RO_inst_buf[fq][fv], RS_f8[fq][fk], RV); + } + else if constexpr (std::is_same::value) + { + // ! Not Implemented + } + } + offset_V = smem_V.advance_offset_by_row<16>(offset_V); + } + } + +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + RO[fq][fv][0] += RO_inst_buf[fq][fv][0]; + RO[fq][fv][1] += RO_inst_buf[fq][fv][1]; + RO[fq][fv][2] += RO_inst_buf[fq][fv][2]; + RO[fq][fv][3] += RO_inst_buf[fq][fv][3]; + RO[fq][fv][4] += RO_inst_buf[fq][fv][4]; + RO[fq][fv][5] += RO_inst_buf[fq][fv][5]; + RO[fq][fv][6] += RO_inst_buf[fq][fv][6]; + RO[fq][fv][7] += RO_inst_buf[fq][fv][7]; + } + } +} + +template +__device__ __forceinline__ void compute_fp8_sv_inst_buf_fp16_accu(const smem_t &smem_V, uint32_t RS_f8[][num_tiles_k / 2][4], DTypeSVAccum RO[][num_tiles_v][8], float d[][2]) +{ + uint32_t smem_V_row_base = get_lane_id() % 8 + (get_lane_id() / 16) * 8; + // uint32_t smem_V_col_base = get_warp_idx_k() * ((16 * num_tiles_k) / 16) + (get_lane_id() / 8) % 2; + uint32_t smem_V_col_base = (get_lane_id() / 8) % 2; + + uint32_t RO_int32[num_tiles_q][num_tiles_v][4]; + +#pragma unroll + for (uint32_t fk = 0; fk < 1; fk++) + { + uint32_t offset_V = smem_V.get_permuted_offset(smem_V_row_base, smem_V_col_base + fk * 2); +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // load RV + uint32_t RV[4]; + // uint32_t offset_V = (smem_V).get_permuted_offset(smem_V_row_base + fv * 16, smem_V_col_base + fk * 2); + smem_V.ldmatrix_m8n8x4(offset_V, RV); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (std::is_same::value) + { + //mma::mma_sync_m16n16k32_row_col_f8f8f32(RO_inst_buf[fq][fv], RS_f8[fq][fk], RV); + mma::mma_sync_m16n16k32_row_col_f8f8f16(RO_int32[fq][fv], RS_f8[fq][fk], RV); + } + else if constexpr (std::is_same::value) + { + // ! Not Implemented + } + } + offset_V = smem_V.advance_offset_by_row<16>(offset_V); + } + } + +#pragma unroll + for (uint32_t fk = 1; fk < num_tiles_k / 2; fk++) + { + uint32_t offset_V = smem_V.get_permuted_offset(smem_V_row_base, smem_V_col_base + fk * 2); +#pragma unroll + for (uint32_t fv = 0; fv < num_tiles_v; fv++) + { + // load RV + uint32_t RV[4]; + // uint32_t offset_V = (smem_V).get_permuted_offset(smem_V_row_base + fv * 16, smem_V_col_base + fk * 2); + smem_V.ldmatrix_m8n8x4(offset_V, RV); +#pragma unroll + for (uint32_t fq = 0; fq < num_tiles_q; fq++) + { + if constexpr (std::is_same::value) + { + //mma::mma_sync_m16n16k32_row_col_f8f8f32(RO_inst_buf[fq][fv], RS_f8[fq][fk], RV); + mma::mma_sync_m16n16k32_row_col_f8f8f16(RO_int32[fq][fv], RS_f8[fq][fk], RV); + } + else if constexpr (std::is_same::value) + { + // ! Not Implemented + } + } + offset_V = smem_V.advance_offset_by_row<16>(offset_V); + } + } + float RO_tmp_float[2]; +#pragma unroll + for(int i = 0; i < num_tiles_q; i++){ +#pragma unroll + for(int j = 0; j < num_tiles_v; j++){ + #pragma unroll + for(int k = 0; k < 4; k++){ + unpack_half2_from_uint32_to_float(RO_tmp_float, RO_int32[i][j][k]); + RO[i][j][k * 2 + 0] += RO_tmp_float[0]; + RO[i][j][k * 2 + 1] += RO_tmp_float[1]; + } + } + } + +// #pragma unroll +// for (uint32_t fq = 0; fq < num_tiles_q; fq++) +// { +// #pragma unroll +// for (uint32_t fv = 0; fv < num_tiles_v; fv++) +// { +// RO[fq][fv][0] += RO_inst_buf[fq][fv][0]; +// RO[fq][fv][1] += RO_inst_buf[fq][fv][1]; +// RO[fq][fv][2] += RO_inst_buf[fq][fv][2]; +// RO[fq][fv][3] += RO_inst_buf[fq][fv][3]; +// RO[fq][fv][4] += RO_inst_buf[fq][fv][4]; +// RO[fq][fv][5] += RO_inst_buf[fq][fv][5]; +// RO[fq][fv][6] += RO_inst_buf[fq][fv][6]; +// RO[fq][fv][7] += RO_inst_buf[fq][fv][7]; +// } +// } +} + +#endif \ No newline at end of file diff --git a/csrc/qattn/rocm/args.hpp b/csrc/qattn/rocm/args.hpp new file mode 100644 index 00000000..586bf1e7 --- /dev/null +++ b/csrc/qattn/rocm/args.hpp @@ -0,0 +1,34 @@ +#pragma once + +#include + +namespace gfx9Params +{ + enum kernelParams : uint32_t + { + ROCWMMA_M = 16u, + ROCWMMA_N = 16u, + ROCWMMA_K = 32u, + BLOCKS_X = 1u, + BLOCKS_Y = 4u, + TBLOCK_X = 256u, + TBLOCK_Y = 1u, + WARP_SIZE = rocwmma::Constants::AMDGCN_WAVE_SIZE_64 + }; +} + +namespace gfx11Params +{ + enum kernelParams : uint32_t + { + ROCWMMA_M = 16u, + ROCWMMA_N = 16u, + ROCWMMA_K = 32u, + BLOCKS_X = 1u, + BLOCKS_Y = 4u, + TBLOCK_X = 128u, + TBLOCK_Y = 1u, + WARP_SIZE = rocwmma::Constants::AMDGCN_WAVE_SIZE_32 + }; +} + diff --git a/csrc/qattn/rocm/attn_rocm.h b/csrc/qattn/rocm/attn_rocm.h new file mode 100755 index 00000000..8c0b62d0 --- /dev/null +++ b/csrc/qattn/rocm/attn_rocm.h @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +torch::Tensor launch_sgattn(torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + torch::Tensor value_scale, + // torch::Tensor value_mean, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_lse); \ No newline at end of file diff --git a/csrc/qattn/rocm/attn_utils.h b/csrc/qattn/rocm/attn_utils.h new file mode 100755 index 00000000..888d19de --- /dev/null +++ b/csrc/qattn/rocm/attn_utils.h @@ -0,0 +1,81 @@ +// /* +// * Copyright (c) 2024 by SageAttention team. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"); +// * you may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ + +// #pragma once +// #include "../../utils.cuh" +// // #include +// // #include +// // #include + +// // #include "../cp_async.cuh" +// // #include "../mma.cuh" +// // #include "../permuted_smem.cuh" +// // #include "../numeric_conversion.cuh" + +// // #define WARP_SIZE_CUDA 32 + +// // #define S_FP8_OFFSET 8.807f +// // #define S_FP8_OFFSET_EXP 6680.8477f +// // #define S_FP8_OFFSET_EXP_INV 0.0022326917f + +// // #define div_ceil(M, N) (((M) + (N)-1) / (N)) + +// enum class MaskMode { +// kNone = 0, +// kCausal = 1, +// }; + +// enum class DataType { +// kHalf, +// kInt8, +// kInt4, +// kE4M3, +// kE5M2, +// }; + +// enum class QuantGranularity { +// kPerTensor = 0, +// kPerBlock = 1, +// kPerWarp = 2, +// kPerThread = 3, +// }; + +// enum class ComputeUnit { +// kTensorCore, +// kCudaCore, +// }; + +// // __device__ __forceinline__ uint32_t get_warp_id() +// // { +// // return threadIdx.y; +// // } + +// // __device__ __forceinline__ uint32_t get_lane_id() +// // { +// // return threadIdx.x; +// // } + +// // template +// // __device__ __forceinline__ uint32_t get_warp_idx_q() +// // { +// // return get_warp_id() / num_warps_k; +// // } + +// // template +// // __device__ __forceinline__ uint32_t get_warp_idx_k() +// // { +// // return get_warp_id() % num_warps_k; +// // } diff --git a/csrc/qattn/rocm/dispatch_utils.h b/csrc/qattn/rocm/dispatch_utils.h new file mode 100755 index 00000000..3bf08f82 --- /dev/null +++ b/csrc/qattn/rocm/dispatch_utils.h @@ -0,0 +1,112 @@ +// /* +// * Copyright (c) 2024 by SageAttention team. +// * +// * Licensed under the Apache License, Version 2.0 (the "License"); +// * you may not use this file except in compliance with the License. +// * You may obtain a copy of the License at +// * +// * http://www.apache.org/licenses/LICENSE-2.0 +// * +// * Unless required by applicable law or agreed to in writing, software +// * distributed under the License is distributed on an "AS IS" BASIS, +// * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// * See the License for the specific language governing permissions and +// * limitations under the License. +// */ + +// #pragma once +// #include +// #include +// #include +// #include + +// #define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \ +// if (head_dim == 64) { \ +// constexpr int HEAD_DIM = 64; \ +// __VA_ARGS__ \ +// } else if (head_dim == 128) { \ +// constexpr int HEAD_DIM = 128; \ +// __VA_ARGS__ \ +// } else { \ +// std::ostringstream err_msg; \ +// err_msg << "Unsupported head dim: " << int(head_dim); \ +// throw std::invalid_argument(err_msg.str()); \ +// } + +// #define DISPATCH_CAUSAL(is_causal, IS_CAUSAL, ...) \ +// if (is_causal == 1) { \ +// constexpr bool IS_CAUSAL = true; \ +// __VA_ARGS__ \ +// } else if (is_causal == 0) { \ +// constexpr bool IS_CAUSAL = false; \ +// __VA_ARGS__ \ +// } else { \ +// std::ostringstream err_msg; \ +// err_msg << "Unsupported causal mode: " << int(is_causal); \ +// throw std::invalid_argument(err_msg.str()); \ +// } + +// #define DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, ...) \ +// if (qk_quant_gran == 2) { \ +// constexpr int QK_QUANT_GRAN = 2; \ +// __VA_ARGS__ \ +// } else if (qk_quant_gran == 3) { \ +// constexpr int QK_QUANT_GRAN = 3; \ +// __VA_ARGS__ \ +// } else { \ +// std::ostringstream err_msg; \ +// err_msg << "Unsupported qk_quant_gran: " << int(qk_quant_gran); \ +// throw std::invalid_argument(err_msg.str()); \ +// } + +// #define DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, ...) \ +// if (return_lse == 1) { \ +// constexpr bool RETURN_LSE = true; \ +// __VA_ARGS__ \ +// } else if (return_lse == 0) { \ +// constexpr bool RETURN_LSE = false; \ +// __VA_ARGS__ \ +// } else { \ +// std::ostringstream err_msg; \ +// err_msg << "Unsupported causal mode: " << int(return_lse); \ +// throw std::invalid_argument(err_msg.str()); \ +// } + +// #define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \ +// if (pytorch_dtype == at::ScalarType::Half) { \ +// using c_type = half; \ +// __VA_ARGS__ \ +// } else if (pytorch_dtype == at::ScalarType::BFloat16) { \ +// using c_type = hip_bfloat16; \ +// __VA_ARGS__ \ +// } else { \ +// std::ostringstream oss; \ +// oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ +// TORCH_CHECK(false, oss.str()); \ +// } + +// #define DISPATCH_BLOCK_SIZE(block_size, BLOCK_SIZE, ...) \ +// if (block_size == 64) { \ +// constexpr int BLOCK_SIZE = 64; \ +// __VA_ARGS__ \ +// } else if (block_size == 128) { \ +// constexpr int BLOCK_SIZE = 128; \ +// __VA_ARGS__ \ +// } else { \ +// std::ostringstream err_msg; \ +// err_msg << "Unsupported block_size " << int(block_size); \ +// throw std::invalid_argument(err_msg.str()); \ +// } + +// #define DISPATCH_WARP_BLOCK_SIZE(warp_block_size, WARP_BLOCK_SIZE, ...) \ +// if (warp_block_size == 16) { \ +// constexpr int WARP_BLOCK_SIZE = 16; \ +// __VA_ARGS__ \ +// } else if (warp_block_size == 32) { \ +// constexpr int WARP_BLOCK_SIZE = 32; \ +// __VA_ARGS__ \ +// } else { \ +// std::ostringstream err_msg; \ +// err_msg << "Unsupported warp_block_size " << int(warp_block_size); \ +// throw std::invalid_argument(err_msg.str()); \ +// } diff --git a/csrc/qattn/rocm/launch_sgattn.cu b/csrc/qattn/rocm/launch_sgattn.cu new file mode 100755 index 00000000..0062fbd2 --- /dev/null +++ b/csrc/qattn/rocm/launch_sgattn.cu @@ -0,0 +1,237 @@ +#include "../../dispatch_utils.h" +#include "sgattn.hip" +#include +#include + +#include +#if defined(USE_ROCM) + #include + #define TORCH_STREAM at::hip::getCurrentHIPStream().stream() + #define DEVICE_GUARD at::hip::HIPGuard device_guard(query.device()); +#else + #include + #define TORCH_STREAM at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream() + #define DEVICE_GUARD at::hip::HIPGuardMasqueradingAsCUDA device_guard(query.device()); +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "args.hpp" + + +#ifdef __ROCM_ARCH_GFX942 + using fp8_type = __hip_fp8_e4m3_fnuz; + namespace Params = gfx9Params; +#else + using fp8_type = __hip_fp8_e4m3; + namespace Params = gfx11Params; +#endif + +torch::Tensor launch_sgattn(torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + torch::Tensor value_scale, + // torch::Tensor value_mean, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + CHECK_CUDA(value_scale); + + CHECK_LASTDIM_CONTIGUOUS(query); + CHECK_LASTDIM_CONTIGUOUS(key); + CHECK_CONTIGUOUS(value); // ensure value is contiguous to prevent troubles in the kernel + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + CHECK_CONTIGUOUS(value_scale); + + CHECK_DTYPE(query, torch::kInt8); + CHECK_DTYPE(key, torch::kInt8); + // TODO: how to check fp8 data type? + CHECK_DTYPE(query_scale, torch::kFloat32); + CHECK_DTYPE(key_scale, torch::kFloat32); + CHECK_DTYPE(value_scale, torch::kFloat32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + CHECK_DIMS(value_scale, 3); + + auto pad_seq128 = [](const torch::Tensor& x) { + TORCH_CHECK(x.dim() == 4, "expect [B,H,Seq,D]"); + const int64_t seq = x.size(2); + const int64_t pad = (128 - (seq % 128)) % 128; + if (pad == 0) return x.contiguous(); + return at::constant_pad_nd(x, {0,0, 0,pad}, /*value=*/0).contiguous(); + }; + + auto Q_pad = pad_seq128(query); // [B,Hq,M_pad,D] + auto K_pad = pad_seq128(key); // [B,Hk,N_pad,D] + auto O_pad = at::zeros(Q_pad.sizes(), output.options().dtype(torch::kFloat)).contiguous(); + + int M = query.size(2); + int N = key.size(2); + int M_pad = (query.size(2) + 127) / 128 * 128; + int N_pad = (key.size(2) + 127) / 128 * 128; + + auto V_pad = value.contiguous(); // [B,H,D,N_padded] + + const int batch_size = Q_pad.size(0); + const int head_dim = Q_pad.size(3); + + int stride_bz_q = Q_pad.stride(0); + int stride_bz_k = K_pad.stride(0); + int stride_bz_v = V_pad.stride(0); + int stride_bz_o = O_pad.stride(0); + + int qo_len, kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; + + + qo_len = Q_pad.size(2); + kv_len = K_pad.size(2); + num_qo_heads = Q_pad.size(1); + num_kv_heads = K_pad.size(1); + + stride_seq_q = Q_pad.stride(2); + stride_h_q = Q_pad.stride(1); + stride_seq_k = K_pad.stride(2); + stride_h_k = K_pad.stride(1); + stride_h_v = V_pad.stride(1); + stride_d_v = V_pad.stride(2); + stride_seq_o = O_pad.stride(2); + stride_h_o = O_pad.stride(1); + + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + // here need add lse + torch::Tensor lse = torch::empty({0}); + if (return_lse) + { + lse = torch::empty({batch_size, num_qo_heads, qo_len}, query.options().dtype(torch::kFloat32)); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_dtype = output.scalar_type(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + constexpr int CTA_Q = 128; + constexpr int CTA_K = 64; + constexpr int WARP_Q = 32; + constexpr int WARP_K = 64; + + assert(value.size(0) == batch_size); + // assert(value.size(3) >= div_ceil(kv_len, CTA_K) * CTA_K); + + constexpr MaskMode mask_mode = IS_CAUSAL ? MaskMode::kCausal : MaskMode::kNone; + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K)); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + uint32_t ROCWMMA_M = Params::ROCWMMA_M; + uint32_t ROCWMMA_N = Params::ROCWMMA_N; + uint32_t ROCWMMA_K = Params::ROCWMMA_K; + uint32_t BLOCKS_X = Params::BLOCKS_X; + uint32_t BLOCKS_Y = Params::BLOCKS_Y; + uint32_t TBLOCK_X = Params::TBLOCK_X; + uint32_t TBLOCK_Y = Params::TBLOCK_Y; + uint32_t WARP_SIZE = Params::WARP_SIZE; + + const size_t Mx = (size_t)ROCWMMA_M * BLOCKS_X * TBLOCK_X / WARP_SIZE; + const size_t My = (size_t)ROCWMMA_N * BLOCKS_Y * TBLOCK_Y; + size_t smem_max = std::max(2u * sizeof(int8_t) * (Mx + My) * ROCWMMA_K, 1u * sizeof(float) * Mx * My); + + hipFuncSetAttribute( + (const void*)qk_int_sv_f8_attn_kernel(QK_QUANT_GRAN), + static_cast(QK_QUANT_GRAN), + 1, float, false, DTypeOut, + ComputeUnit::kCudaCore, mask_mode, + RETURN_LSE, false, false, false>, + hipFuncAttributeMaxDynamicSharedMemorySize, + (int)smem_max); + + dim3 grid(div_ceil(M_pad, Mx), num_qo_heads, batch_size); + dim3 block(TBLOCK_X, TBLOCK_Y); + + hipLaunchKernelGGL((qk_int_sv_f8_attn_kernel(QK_QUANT_GRAN), + static_cast(QK_QUANT_GRAN), + HEAD_DIM / 64, float, false, DTypeOut, + ComputeUnit::kCudaCore, mask_mode, + RETURN_LSE, false, false, false>), + dim3(grid), dim3(block), smem_max, 0, + head_dim, head_dim, N_pad, head_dim, + Q_pad.data_ptr(), + K_pad.data_ptr(), + reinterpret_cast(V_pad.data_ptr()), + reinterpret_cast(O_pad.data_ptr()), + (RETURN_LSE) ? reinterpret_cast(lse.data_ptr()) : nullptr, + reinterpret_cast(query_scale.data_ptr()), + reinterpret_cast(key_scale.data_ptr()), + reinterpret_cast(value_scale.data_ptr()), + nullptr, + M_pad, N_pad, N, + num_kv_groups, + stride_bz_q, stride_seq_q, stride_h_q, + stride_bz_k, stride_seq_k, stride_h_k, + stride_bz_v, stride_h_v, stride_d_v, + stride_bz_o, stride_seq_o, stride_h_o, + sm_scale); + + C10_HIP_CHECK(hipGetLastError()); + auto O_valid = O_pad.narrow(/*dim=*/2, 0, M); + output.copy_(O_valid); + }); + }); + }); + }); + }); + return lse; +} \ No newline at end of file diff --git a/csrc/qattn/rocm/launch_sgattn.hip b/csrc/qattn/rocm/launch_sgattn.hip new file mode 100644 index 00000000..26611b5c --- /dev/null +++ b/csrc/qattn/rocm/launch_sgattn.hip @@ -0,0 +1,238 @@ +// !!! This is a file automatically generated by hipify!!! +#include "../../dispatch_utils.h" +#include "sgattn.hip" +#include +#include + +#include +#if defined(USE_ROCM) + #include + #define TORCH_STREAM at::hip::getCurrentHIPStream().stream() + #define DEVICE_GUARD at::hip::HIPGuard device_guard(query.device()); +#else + #include + #define TORCH_STREAM at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream() + #define DEVICE_GUARD at::hip::HIPGuardMasqueradingAsCUDA device_guard(query.device()); +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "args.hpp" + + +#ifdef __ROCM_ARCH_GFX942 + using fp8_type = __hip_fp8_e4m3_fnuz; + namespace Params = gfx9Params; +#else + using fp8_type = __hip_fp8_e4m3; + namespace Params = gfx11Params; +#endif + +torch::Tensor launch_sgattn(torch::Tensor query, + torch::Tensor key, + torch::Tensor value, + torch::Tensor output, + torch::Tensor query_scale, + torch::Tensor key_scale, + torch::Tensor value_scale, + // torch::Tensor value_mean, + int tensor_layout, + int is_causal, + int qk_quant_gran, + float sm_scale, + int return_lse) +{ + CHECK_CUDA(query); + CHECK_CUDA(key); + CHECK_CUDA(value); + CHECK_CUDA(output); + CHECK_CUDA(query_scale); + CHECK_CUDA(key_scale); + CHECK_CUDA(value_scale); + + CHECK_LASTDIM_CONTIGUOUS(query); + CHECK_LASTDIM_CONTIGUOUS(key); + CHECK_CONTIGUOUS(value); // ensure value is contiguous to prevent troubles in the kernel + CHECK_LASTDIM_CONTIGUOUS(output); + CHECK_CONTIGUOUS(query_scale); + CHECK_CONTIGUOUS(key_scale); + CHECK_CONTIGUOUS(value_scale); + + CHECK_DTYPE(query, torch::kInt8); + CHECK_DTYPE(key, torch::kInt8); + // TODO: how to check fp8 data type? + CHECK_DTYPE(query_scale, torch::kFloat32); + CHECK_DTYPE(key_scale, torch::kFloat32); + CHECK_DTYPE(value_scale, torch::kFloat32); + + CHECK_DIMS(query, 4); + CHECK_DIMS(key, 4); + CHECK_DIMS(value, 4); + CHECK_DIMS(output, 4); + CHECK_DIMS(query_scale, 3); + CHECK_DIMS(key_scale, 3); + CHECK_DIMS(value_scale, 3); + + auto pad_seq128 = [](const torch::Tensor& x) { + TORCH_CHECK(x.dim() == 4, "expect [B,H,Seq,D]"); + const int64_t seq = x.size(2); + const int64_t pad = (128 - (seq % 128)) % 128; + if (pad == 0) return x.contiguous(); + return at::constant_pad_nd(x, {0,0, 0,pad}, /*value=*/0).contiguous(); + }; + + auto Q_pad = pad_seq128(query); // [B,Hq,M_pad,D] + auto K_pad = pad_seq128(key); // [B,Hk,N_pad,D] + auto O_pad = at::zeros(Q_pad.sizes(), output.options().dtype(torch::kFloat)).contiguous(); + + int M = query.size(2); + int N = key.size(2); + int M_pad = (query.size(2) + 127) / 128 * 128; + int N_pad = (key.size(2) + 127) / 128 * 128; + + auto V_pad = value.contiguous(); // [B,H,D,N_padded] + + const int batch_size = Q_pad.size(0); + const int head_dim = Q_pad.size(3); + + int stride_bz_q = Q_pad.stride(0); + int stride_bz_k = K_pad.stride(0); + int stride_bz_v = V_pad.stride(0); + int stride_bz_o = O_pad.stride(0); + + int qo_len, kv_len, num_qo_heads, num_kv_heads; + int stride_seq_q, stride_h_q, stride_seq_k, stride_h_k, stride_h_v, stride_d_v, stride_seq_o, stride_h_o; + + + qo_len = Q_pad.size(2); + kv_len = K_pad.size(2); + num_qo_heads = Q_pad.size(1); + num_kv_heads = K_pad.size(1); + + stride_seq_q = Q_pad.stride(2); + stride_h_q = Q_pad.stride(1); + stride_seq_k = K_pad.stride(2); + stride_h_k = K_pad.stride(1); + stride_h_v = V_pad.stride(1); + stride_d_v = V_pad.stride(2); + stride_seq_o = O_pad.stride(2); + stride_h_o = O_pad.stride(1); + + + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads (" << num_qo_heads << ") must be divisible by num_kv_heads (" << num_kv_heads << ")"; + throw std::invalid_argument(err_msg.str()); + } + + // here need add lse + torch::Tensor lse = torch::empty({0}); + if (return_lse) + { + lse = torch::empty({batch_size, num_qo_heads, qo_len}, query.options().dtype(torch::kFloat32)); + } + + const int num_kv_groups = num_qo_heads / num_kv_heads; + + auto output_dtype = output.scalar_type(); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + DISPATCH_CAUSAL(is_causal, IS_CAUSAL, { + DISPATCH_QK_QUANT_GRAN(qk_quant_gran, QK_QUANT_GRAN, { + DISPATCH_RETURN_LSE(return_lse, RETURN_LSE, { + DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(output_dtype, DTypeOut, { + constexpr int CTA_Q = 128; + constexpr int CTA_K = 64; + constexpr int WARP_Q = 32; + constexpr int WARP_K = 64; + + assert(value.size(0) == batch_size); + // assert(value.size(3) >= div_ceil(kv_len, CTA_K) * CTA_K); + + constexpr MaskMode mask_mode = IS_CAUSAL ? MaskMode::kCausal : MaskMode::kNone; + + if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q)); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K)); + } + else if constexpr (QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread)) + { + CHECK_SHAPE(query_scale, batch_size, num_qo_heads, div_ceil(qo_len, CTA_Q) * (CTA_Q / WARP_Q) * 8); + CHECK_SHAPE(key_scale, batch_size, num_kv_heads, div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K) * 4); + } + else + { + static_assert(QK_QUANT_GRAN == static_cast(QuantGranularity::kPerWarp) || QK_QUANT_GRAN == static_cast(QuantGranularity::kPerThread), "Unsupported quantization granularity"); + } + + uint32_t ROCWMMA_M = Params::ROCWMMA_M; + uint32_t ROCWMMA_N = Params::ROCWMMA_N; + uint32_t ROCWMMA_K = Params::ROCWMMA_K; + uint32_t BLOCKS_X = Params::BLOCKS_X; + uint32_t BLOCKS_Y = Params::BLOCKS_Y; + uint32_t TBLOCK_X = Params::TBLOCK_X; + uint32_t TBLOCK_Y = Params::TBLOCK_Y; + uint32_t WARP_SIZE = Params::WARP_SIZE; + + const size_t Mx = (size_t)ROCWMMA_M * BLOCKS_X * TBLOCK_X / WARP_SIZE; + const size_t My = (size_t)ROCWMMA_N * BLOCKS_Y * TBLOCK_Y; + size_t smem_max = ::max(2u * sizeof(int8_t) * (Mx + My) * ROCWMMA_K, 1u * sizeof(float) * Mx * My); + + hipFuncSetAttribute( + (const void*)qk_int_sv_f8_attn_kernel(QK_QUANT_GRAN), + static_cast(QK_QUANT_GRAN), + 1, float, false, DTypeOut, + ComputeUnit::kCudaCore, mask_mode, + RETURN_LSE, false, false, false>, + hipFuncAttributeMaxDynamicSharedMemorySize, + (int)smem_max); + + dim3 grid(div_ceil(M_pad, Mx), num_qo_heads, batch_size); + dim3 block(TBLOCK_X, TBLOCK_Y); + + hipLaunchKernelGGL((qk_int_sv_f8_attn_kernel(QK_QUANT_GRAN), + static_cast(QK_QUANT_GRAN), + HEAD_DIM / 64, float, false, DTypeOut, + ComputeUnit::kCudaCore, mask_mode, + RETURN_LSE, false, false, false>), + dim3(grid), dim3(block), smem_max, 0, + head_dim, head_dim, N_pad, head_dim, + Q_pad.data_ptr(), + K_pad.data_ptr(), + reinterpret_cast(V_pad.data_ptr()), + reinterpret_cast(O_pad.data_ptr()), + (RETURN_LSE) ? reinterpret_cast(lse.data_ptr()) : nullptr, + reinterpret_cast(query_scale.data_ptr()), + reinterpret_cast(key_scale.data_ptr()), + reinterpret_cast(value_scale.data_ptr()), + nullptr, + M_pad, N_pad, N, + num_kv_groups, + stride_bz_q, stride_seq_q, stride_h_q, + stride_bz_k, stride_seq_k, stride_h_k, + stride_bz_v, stride_h_v, stride_d_v, + stride_bz_o, stride_seq_o, stride_h_o, + sm_scale); + + C10_HIP_CHECK(hipGetLastError()); + auto O_valid = O_pad.narrow(/*dim=*/2, 0, M); + output.copy_(O_valid); + }); + }); + }); + }); + }); + return lse; +} \ No newline at end of file diff --git a/csrc/qattn/rocm/pybind_rocm.cpp b/csrc/qattn/rocm/pybind_rocm.cpp new file mode 100755 index 00000000..b9fe53bd --- /dev/null +++ b/csrc/qattn/rocm/pybind_rocm.cpp @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include "attn_rocm.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("qk_int8_sv_f8_accum_f32_attn", &launch_sgattn, "QK int8 sv f8 accum f32 attn"); +} diff --git a/csrc/qattn/rocm/sgattn.cu b/csrc/qattn/rocm/sgattn.cu new file mode 100755 index 00000000..6ad846cb --- /dev/null +++ b/csrc/qattn/rocm/sgattn.cu @@ -0,0 +1,901 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include "../attn_utils.cuh" +#include "args.hpp" + +using namespace rocwmma; + +#ifdef __ROCM_ARCH_GFX942 + using fp8_type = __hip_fp8_e4m3_fnuz; +#else + using fp8_type = __hip_fp8_e4m3; +#endif + +// namespace gfx9Params +// { +// enum kernelParams : uint32_t +// { +// ROCWMMA_M = 16u, +// ROCWMMA_N = 16u, +// ROCWMMA_K = 32u, +// BLOCKS_X = 1u, +// BLOCKS_Y = 4u, +// TBLOCK_X = 256u, +// TBLOCK_Y = 1u, +// WARP_SIZE = Constants::AMDGCN_WAVE_SIZE_64 +// }; +// } + +// namespace gfx11Params +// { +// enum kernelParams : uint32_t +// { +// ROCWMMA_M = 16u, +// ROCWMMA_N = 16u, +// ROCWMMA_K = 32u, +// BLOCKS_X = 1u, +// BLOCKS_Y = 4u, +// TBLOCK_X = 128u, +// TBLOCK_Y = 1u, +// WARP_SIZE = Constants::AMDGCN_WAVE_SIZE_32 +// }; +// } + +constexpr float log2e = 1.44269504088896340736f; +constexpr float log2e_recp = 1.0f / log2e; +#define div_ceil(M, N) (((M) + (N)-1) / (N)) +#define TRAP_IF(cond) do { if (cond) asm volatile("s_trap 0"); } while(0) + +#ifdef __ROCM_ARCH_GFX942 + using namespace gfx9Params; +#else + using namespace gfx11Params; +#endif + +/// +/// Types and Data Layouts +/// + +using InputT = int8_t; +using InputTV = fp8_type; +using OutputT = int32_t; +using ComputeT = int32_t; +using ComputeTV = float32_t; +using LDST = float32_t; +using LDST_new = fp8_type; + +using DataLayoutA = row_major; +using DataLayoutB = col_major; +using DataLayoutC = row_major; +using DataLayoutLds = col_major; +using DataLayoutV = col_major; + +using DataLayoutLds_new = row_major; +/// +/// Fragment types +/// + +// #if (ROCWMMA_ARCH_GFX9 || ROCWMMA_ARCH_GFX11) +// Warp tile: computed by each warp +constexpr uint32_t WARP_TILE_X = BLOCKS_X * ROCWMMA_M; +constexpr uint32_t WARP_TILE_Y = BLOCKS_Y * ROCWMMA_N; + +constexpr uint32_t els_per_thread = ROCWMMA_M * ROCWMMA_N / WARP_SIZE; +constexpr uint32_t threads_per_row = ROCWMMA_N / els_per_thread; + +// Macro Tile: computed by each thread block (workgroup) +// Note: TBLOCK_X must be multiple of WARP_SIZE. +constexpr uint32_t WARPS_X = TBLOCK_X / WARP_SIZE; +constexpr uint32_t WARPS_Y = TBLOCK_Y; +constexpr uint32_t MACRO_TILE_X = WARPS_X * WARP_TILE_X; +constexpr uint32_t MACRO_TILE_Y = WARPS_Y * WARP_TILE_Y; +constexpr uint32_t MACRO_TILE_K = WARPS_Y * BLOCKS_Y * ROCWMMA_K; + +// Mfma frags +using MfmaFragA = fragment; +using MfmaFragB = fragment; +using MfmaFragC = fragment; +using MfmaFragD = fragment; +using MfmaFragAcc = fragment; +using MfmaFragAccf32 = fragment; + +using MfmaFragS = fragment; +using MfmaFragV = fragment; + +// Global read (macro tile) +using GRBuffA = fragment; +using GRBuffB = fragment; + +using GRBuffS = fragment; +using GRBuffV = fragment; + +// Local write of global buffers (macro tile) +// - Must match Lds data layout. +// - Lds has transposed B frags. +using LWBuffA = ApplyDataLayout_t; +using LWBuffB = ApplyDataLayout_t, DataLayoutLds>; + +using LWBuffS = ApplyDataLayout_t; +using LWBuffV = ApplyDataLayout_t, DataLayoutLds_new>; + +// Local read (mfma frags) +// - Must match Lds data layout. +// - Lds has transposed B frags. +using LRFragA = ApplyDataLayout_t; +using LRFragB = ApplyDataLayout_t, DataLayoutLds>; + +using LRFragS = ApplyDataLayout_t; +using LRFragV = ApplyDataLayout_t, DataLayoutLds_new>; + +using LRFragAccf32 = ApplyDataLayout_t; +// #endif // (ROCWMMA_ARCH_GFX9 || ROCWMMA_ARCH_GFX11) + +/// +/// Wrapper functions: repeat mfma tile operations across entire warp tile. +/// + +// Cooperative global read / local write (Macro tile data movement) +// Loads / stores a global data fragment cooperatively across warps. Each participating warp is +// responsible for only a portion of the whole fragment. +// +// The cooperative operation is split into work items (SplitCount). Work items are consumed in +// a round robin fashion by warps in the range of [0, WaveCount). The wave index determines the +// order of the current wave in the collaboration pool. +// +// WaveCount, SplitCount and waveIndex parameters must match successive coop load / store calls +// to ensure the entire fragment remains coherent. + +// Global A reads in cooperative mode (macro tile) +template +ROCWMMA_DEVICE static inline void + globalReadCoopA(GRBuffA& grBuffA, InputT const* gAddrA, uint32_t lda, uint32_t waveIndexA) +{ + load_matrix_coop_sync(grBuffA, gAddrA, lda, waveIndexA); +} + +// Global B reads in cooperative mode (macro tile) +template +ROCWMMA_DEVICE static inline void + globalReadCoopB(GRBuffB& grBuffB, InputT const* gAddrB, uint32_t ldb, uint32_t waveIndexB) +{ + load_matrix_coop_sync(grBuffB, gAddrB, ldb, waveIndexB); +} + +template +ROCWMMA_DEVICE static inline void + globalReadCoopV(GRBuffV& grBuffV, InputTV const* gAddrV, uint32_t ldV, uint32_t waveIndexV) +{ + load_matrix_coop_sync(grBuffV, gAddrV, ldV, waveIndexV); +} + +// Local A writes in cooperative mode (macro tile) +template +ROCWMMA_DEVICE static inline void + localWriteCoopA(InputT* ldsAddr, GRBuffA const& grBuffA, uint32_t ldsld, uint32_t waveIndexA) +{ + // No transpose, but apply the lds data layout + store_matrix_coop_sync( + ldsAddr, applyDataLayout(grBuffA), ldsld, waveIndexA); +} + +// Local B writes in cooperative mode (macro tile) +template +ROCWMMA_DEVICE static inline void + localWriteCoopB(InputT* ldsAddr, GRBuffB const& grBuffB, uint32_t ldsld, uint32_t waveIndexB) +{ + // Transpose B and then apply lds data layout + store_matrix_coop_sync( + ldsAddr, applyDataLayout(applyTranspose(grBuffB)), ldsld, waveIndexB); +} + +template +ROCWMMA_DEVICE static inline void + localWriteCoopV(InputTV* ldsAddr, GRBuffV const& grBuffV, uint32_t ldsld, uint32_t waveIndexV) +{ + // Transpose B and then apply lds data layout + store_matrix_coop_sync( + ldsAddr, applyDataLayout(applyTranspose(grBuffV)), ldsld, waveIndexV); +} + +// Local A reads for warp tile gemm, non-cooperative +ROCWMMA_DEVICE static inline void + localReadA(MfmaFragA (&fragsA)[BLOCKS_X], InputT const* ldsAddrA, uint32_t ldsld) +{ + using FragShape = GetIOShape_t; + using Mapper1d = GetDataLayout_t; + + // Each A block is stacked vertically in LDS + auto blockStep = Mapper1d::fromMatrixCoord(make_coord2d(FragShape::BlockHeight, 0u), ldsld); + +#pragma unroll + for(int i = 0; i < BLOCKS_X; i++) + { + LRFragA tmp; + load_matrix_sync(tmp, ldsAddrA, ldsld); + fragsA[i] = applyDataLayout(tmp); + + ldsAddrA += blockStep; + } +} + +ROCWMMA_DEVICE static inline void + localReadS(MfmaFragS (&fragsS)[BLOCKS_X], InputTV const* ldsAddrS, uint32_t ldsld) +{ + using FragShape = GetIOShape_t; + using Mapper1d = GetDataLayout_t; + + // Each A block is stacked vertically in LDS + auto blockStep = Mapper1d::fromMatrixCoord(make_coord2d(FragShape::BlockHeight, 0u), ldsld); + +#pragma unroll + for(int i = 0; i < BLOCKS_X; i++) + { + LRFragS tmp; + load_matrix_sync(tmp, ldsAddrS, ldsld); + fragsS[i] = applyDataLayout(tmp); + + ldsAddrS += blockStep; + } +} + +// Local B reads for warp tile gemm, non-cooperative +ROCWMMA_DEVICE static inline void + localReadB(MfmaFragB (&fragsB)[BLOCKS_Y], InputT const* ldsAddrB, uint32_t ldsld) +{ + using FragShape = GetIOShape_t; + using Mapper1d = GetDataLayout_t; + + // Each B block is stacked vertically in LDS + auto blockStep = Mapper1d::fromMatrixCoord(make_coord2d(FragShape::BlockHeight, 0u), ldsld); + +#pragma unroll + for(int i = 0; i < BLOCKS_Y; i++) + { + LRFragB tmp; + load_matrix_sync(tmp, ldsAddrB, ldsld); + + // Transform back to MFMA tile + fragsB[i] = applyDataLayout(applyTranspose(tmp)); + + ldsAddrB += blockStep; + } +} + +ROCWMMA_DEVICE static inline void + localReadV(MfmaFragV (&fragsV)[BLOCKS_Y], InputTV const* ldsAddrV, uint32_t ldsld) +{ + using FragShape = GetIOShape_t; + using Mapper1d = GetDataLayout_t; + + // Each B block is stacked vertically in LDS + auto blockStep = Mapper1d::fromMatrixCoord(make_coord2d(FragShape::BlockHeight, 0u), ldsld); + +#pragma unroll + for(int i = 0; i < BLOCKS_Y; i++) + { + LRFragV tmp; + load_matrix_sync(tmp, ldsAddrV, ldsld); + + // Transform back to MFMA tile + fragsV[i] = applyDataLayout(applyTranspose(tmp)); + + ldsAddrV += blockStep; + } +} + + +// Global D reads for warp tile gemm, non-cooperative +ROCWMMA_DEVICE static inline void + globalWriteD(float* gAddrD, MfmaFragAccf32 const (&fragsD)[BLOCKS_X][BLOCKS_Y], uint32_t ldd) +{ + using FragShape = GetIOShape_t; + using Mapper1d = GetDataLayout_t; + + // Iterative offsets for each D block in the warp tile + auto blockStepX = Mapper1d::fromMatrixCoord(make_coord2d(FragShape::BlockHeight, 0u), ldd); + auto blockStepY = Mapper1d::fromMatrixCoord(make_coord2d(0u, FragShape::BlockWidth), ldd); + +#pragma unroll + for(int i = 0; i < BLOCKS_X; i++) + { + auto offsetY = 0u; +#pragma unroll + for(int j = 0; j < BLOCKS_Y; j++) + { + store_matrix_sync(gAddrD + offsetY, fragsD[i][j], ldd); + // printf("offsetY:%u\n",offsetY); + offsetY += blockStepY; + } + gAddrD += blockStepX; + } +} + +// Broadcast value to fragments in warp tile +template +ROCWMMA_DEVICE static inline void fill(FragT (&frags)[BLOCKS_X][BLOCKS_Y], + GetDataType_t value) +{ +#pragma unroll + for(int i = 0; i < BLOCKS_X; i++) + { +#pragma unroll + for(int j = 0; j < BLOCKS_Y; j++) + { + fill_fragment(frags[i][j], value); + } + } +} + +ROCWMMA_DEVICE static inline void + localWriteAcc(MfmaFragAccf32 (&fragsAcc)[BLOCKS_X][BLOCKS_Y], ComputeTV * ldsAddrAcc, uint32_t ldsld) +{ + using FragShape = GetIOShape_t; + using Mapper1d = GetDataLayout_t; + + auto blockStepX = Mapper1d::fromMatrixCoord(make_coord2d(FragShape::BlockHeight, 0u), ldsld); + auto blockStepY = Mapper1d::fromMatrixCoord(make_coord2d(0u, FragShape::BlockWidth), ldsld); + +#pragma unroll + for(int i = 0; i < BLOCKS_X; i++) + { + auto offsetY = 0u; +#pragma unroll + for(int j = 0; j < BLOCKS_Y; j++) + { + store_matrix_sync(ldsAddrAcc + offsetY, fragsAcc[i][j], ldsld); + offsetY += blockStepY; + } + ldsAddrAcc += blockStepX; + } +} + +// ROCWMMA_DEVICE static inline void +// localWriteOut(MfmaFragAccf32 (&fragsAcc)[BLOCKS_Y], ComputeTV * ldsAddrAcc, uint32_t ldsld) +// { +// using FragShape = GetIOShape_t; +// using Mapper1d = GetDataLayout_t; + +// // auto blockStepX = Mapper1d::fromMatrixCoord(make_coord2d(FragShape::BlockHeight, 0u), ldsld); +// auto blockStepY = Mapper1d::fromMatrixCoord(make_coord2d(0u, FragShape::BlockWidth), ldsld); + +// auto offsetY = 0u; +// #pragma unroll +// for(int j = 0; j < BLOCKS_Y; j++) +// { +// store_matrix_sync(ldsAddrAcc + offsetY, fragsAcc[j], ldsld); +// offsetY += blockStepY; +// } +// } + + +ROCWMMA_DEVICE static inline void + localReadOut(MfmaFragAccf32 (&fragsAcc)[BLOCKS_Y], ComputeTV const* ldsAddrAcc, uint32_t ldsld) +{ + using FragShape = GetIOShape_t; + using Mapper1d = GetDataLayout_t; + + auto blockStepY = Mapper1d::fromMatrixCoord(make_coord2d(0u, FragShape::BlockWidth), ldsld); + + auto offsetY = 0u; +#pragma unroll + for(int j = 0; j < BLOCKS_Y; j++) + { + load_matrix_sync(fragsAcc[j], ldsAddrAcc + offsetY, ldsld); + offsetY += blockStepY; + } +} + +ROCWMMA_DEVICE static inline void + localReadAcc(MfmaFragAccf32 (&fragsAcc)[BLOCKS_X][BLOCKS_Y], float const* ldsAddrAcc, uint32_t ldsld) +{ + using FragShape = GetIOShape_t; + using Mapper1d = GetDataLayout_t; + + auto blockStepX = Mapper1d::fromMatrixCoord(make_coord2d(FragShape::BlockHeight, 0u), ldsld); + auto blockStepY = Mapper1d::fromMatrixCoord(make_coord2d(0u, FragShape::BlockWidth), ldsld); + +#pragma unroll + for(int i = 0; i < BLOCKS_X; i++) + { + auto offsetY = 0u; +#pragma unroll + for(int j = 0; j < BLOCKS_Y; j++) + { + load_matrix_sync(fragsAcc[i][j], ldsAddrAcc + offsetY, ldsld); + offsetY += blockStepY; + } + ldsAddrAcc += blockStepX; + } +} + + +ROCWMMA_DEVICE static inline void +convertS32toF32(MfmaFragAcc (&frags_i32)[BLOCKS_X][BLOCKS_Y], + MfmaFragAccf32 (&frags_f32)[BLOCKS_X][BLOCKS_Y]) +{ +#pragma unroll + for (int i = 0; i < BLOCKS_X; ++i) { +#pragma unroll + for (int j = 0; j < BLOCKS_Y; ++j) { +// #pragma unroll + for (int k = 0; k < frags_i32[i][j].num_elements; ++k) { + frags_f32[i][j].x[k] = __int2float_rz(frags_i32[i][j].x[k]); + } + } + } +} + + +ROCWMMA_DEVICE static inline void svgemm(MfmaFragAccf32 (&fragsAccOut)[BLOCKS_X][BLOCKS_Y], + MfmaFragS const (&fragsA)[BLOCKS_X], + MfmaFragV const (&fragsB)[BLOCKS_Y], + MfmaFragAccf32 const (&fragsAccIn)[BLOCKS_X][BLOCKS_Y]) +{ +#pragma unroll + for(int i = 0; i < BLOCKS_X; i++) + { +#pragma unroll + for(int j = 0; j < BLOCKS_Y; j++) + { + mma_sync(fragsAccOut[i][j], fragsA[i], fragsB[j], fragsAccIn[i][j]); + } + } +} + + +ROCWMMA_DEVICE static inline void qkgemm(MfmaFragAcc (&fragsAccOut)[BLOCKS_X][BLOCKS_Y], + MfmaFragA const (&fragsA)[BLOCKS_X], + MfmaFragB const (&fragsB)[BLOCKS_Y], + MfmaFragAcc const (&fragsAccIn)[BLOCKS_X][BLOCKS_Y]) +{ +#pragma unroll + for(int i = 0; i < BLOCKS_X; i++) + { +#pragma unroll + for(int j = 0; j < BLOCKS_Y; j++) + { + mma_sync(fragsAccOut[i][j], fragsA[i], fragsB[j], fragsAccIn[i][j]); + } + } +} + +template +ROCWMMA_KERNEL void __launch_bounds__(256) qk_int_sv_f8_attn_kernel(uint32_t lda, + uint32_t ldb, + uint32_t ldv, + uint32_t ldd, + int8_t *__restrict__ Q, int8_t *__restrict__ K, InputTV *__restrict__ V, float *__restrict__ O, float *__restrict__ Lse, + float *__restrict__ Q_scale, float *__restrict__ K_scale, float *__restrict__ V_scale, float *__restrict__ V_mean, + const uint32_t qo_len, const uint32_t kv_len, const uint32_t true_len, const uint32_t num_kv_groups, + const uint32_t stride_bz_q, const uint32_t stride_seq_q, const uint32_t stride_h_q, + const uint32_t stride_bz_k, const uint32_t stride_seq_k, const uint32_t stride_h_k, + const uint32_t stride_bz_v, const uint32_t stride_h_v, const uint32_t stride_d_v, + const uint32_t stride_bz_o, const uint32_t stride_seq_o, const uint32_t stride_h_o, + float sm_scale) +{ + if constexpr(!ROCWMMA_ARCH_HOST) + { + + // Tile Sizes + constexpr auto warpTileSize = make_coord2d(WARP_TILE_X, WARP_TILE_Y); + constexpr auto macroTileSize = make_coord2d(MACRO_TILE_X, MACRO_TILE_Y); + + // Local warp coordinate relative to current threadblock (wg). + constexpr auto warpDims = make_coord2d(WARPS_X, WARPS_Y); + auto localWarpCoord = make_coord2d(threadIdx.x / WARP_SIZE, threadIdx.y); + auto localWarpOffset = localWarpCoord * warpTileSize; + constexpr auto warpCount = get<0>(warpDims) * get<1>(warpDims); + const auto warpIndex = get<0>(localWarpCoord) * get<1>(warpDims) + get<1>(localWarpCoord); + + const uint32_t batch_id = blockIdx.z; + const uint32_t bx = blockIdx.x; + const uint32_t num_qo_heads = gridDim.y; + const uint32_t head_id = blockIdx.y; + + using MfmaFragDMap1d = GetDataLayout_t; + + const uint32_t iterations = div_ceil( + kv_len, + MACRO_TILE_Y); + + MfmaFragAccf32 fragsOut[SV_ITERS][BLOCKS_X][BLOCKS_Y]; + for(int i = 0; i < SV_ITERS; i++){ + fill(fragsOut[i], 0.0f); + } + + HIP_DYNAMIC_SHARED(void*, localMemPtr); + + constexpr int NUM = ROCWMMA_M * ROCWMMA_N / WARP_SIZE; + float m[NUM], d[NUM]; + for(int i = 0; i < NUM; i++) + { + m[i] = -5000.f; + d[i] = 1.f; + } + + for(int iter = 0; iter < iterations; iter++) + { + MfmaFragAcc fragsAcc[BLOCKS_X][BLOCKS_Y]; + fill(fragsAcc, 0.0f); + auto macroTileCoord = make_coord2d(blockIdx.x, iter) * macroTileSize; + { + auto warpTileCoord = macroTileCoord + localWarpOffset; + + // Bounds check + auto warpTileBound = warpTileCoord + warpTileSize; + if(get<0>(warpTileBound) > qo_len || get<1>(warpTileBound) > kv_len) + { + continue; + } + + using GRBuffAMap1d = GetDataLayout_t; + using GRBuffBMap1d = GetDataLayout_t; + + // Initial globa read address offsets + auto globalReadOffsetA + = batch_id * stride_bz_q + head_id * stride_h_q + + GRBuffAMap1d::fromMatrixCoord(make_coord2d(get<0>(macroTileCoord), 0u), lda); + auto globalReadOffsetB + = batch_id * stride_bz_k + (head_id / num_kv_groups) * stride_h_k + + GRBuffBMap1d::fromMatrixCoord(make_coord2d(0u, get<1>(macroTileCoord)), ldb); + + // Incremental global read address offsets + auto kStepOffsetA = GRBuffAMap1d::fromMatrixCoord(make_coord2d(0u, ROCWMMA_K), lda); + auto kStepOffsetB = GRBuffBMap1d::fromMatrixCoord(make_coord2d(ROCWMMA_K, 0u), ldb); + + GRBuffA grBuffA; + GRBuffB grBuffB; + + globalReadCoopA(grBuffA, Q + globalReadOffsetA, lda, warpIndex); + globalReadCoopB(grBuffB, K + globalReadOffsetB, ldb, warpIndex); + + globalReadOffsetA += kStepOffsetA; + globalReadOffsetB += kStepOffsetB; + + using LWBuffAShape = GetIOShape_t; + using LWBuffBShape = GetIOShape_t; + using LWBuffAMap1d = GetDataLayout_t; + using LWBuffBMap1d = GetDataLayout_t; + + constexpr uint32_t ldsWidth = ROCWMMA_K; + constexpr uint32_t ldsHeight = LWBuffAShape::BlockHeight + LWBuffBShape::BlockHeight; + constexpr uint32_t sizeLds = ldsHeight * ldsWidth; + constexpr uint32_t ldsld = std::is_same_v ? ldsWidth : ldsHeight; + + auto* ldsPtrLo = reinterpret_cast(localMemPtr); + auto* ldsPtrHi = ldsPtrLo + sizeLds; + + auto ldsWriteOffsetA = 0u; + auto ldsWriteOffsetB + = LWBuffAMap1d::fromMatrixCoord(make_coord2d(LWBuffAShape::BlockHeight, 0u), ldsld); + + auto ldsReadOffsetA + = ldsWriteOffsetA + + LWBuffAMap1d::fromMatrixCoord(make_coord2d(get<0>(localWarpOffset), 0u), ldsld); + auto ldsReadOffsetB + = ldsWriteOffsetB + + LWBuffBMap1d::fromMatrixCoord(make_coord2d(get<1>(localWarpOffset), 0u), ldsld); + + localWriteCoopA(ldsPtrLo + ldsWriteOffsetA, grBuffA, ldsld, warpIndex); + localWriteCoopB(ldsPtrLo + ldsWriteOffsetB, grBuffB, ldsld, warpIndex); + + synchronize_workgroup(); + + for(uint32_t currentK = ROCWMMA_K; currentK < head_dim; currentK += ROCWMMA_K) + { + MfmaFragA fragsA[BLOCKS_X]; + MfmaFragB fragsB[BLOCKS_Y]; + + localReadA(fragsA, ldsPtrLo + ldsReadOffsetA, ldsld); + localReadB(fragsB, ldsPtrLo + ldsReadOffsetB, ldsld); + + globalReadCoopA(grBuffA, Q + globalReadOffsetA, lda, warpIndex); + globalReadCoopB(grBuffB, K + globalReadOffsetB, ldb, warpIndex); + + globalReadOffsetA += kStepOffsetA; + globalReadOffsetB += kStepOffsetB; + + qkgemm(fragsAcc, fragsA, fragsB, fragsAcc); + + localWriteCoopA(ldsPtrHi + ldsWriteOffsetA, grBuffA, ldsld, warpIndex); + localWriteCoopB(ldsPtrHi + ldsWriteOffsetB, grBuffB, ldsld, warpIndex); + + synchronize_workgroup(); + + auto* tmp = ldsPtrLo; + ldsPtrLo = ldsPtrHi; + ldsPtrHi = tmp; + } + + MfmaFragA fragsA[BLOCKS_X]; + MfmaFragB fragsB[BLOCKS_Y]; + + localReadA(fragsA, ldsPtrLo + ldsReadOffsetA, ldsld); + localReadB(fragsB, ldsPtrLo + ldsReadOffsetB, ldsld); + qkgemm(fragsAcc, fragsA, fragsB, fragsAcc); + } + + MfmaFragAccf32 fragsTmp[BLOCKS_X][BLOCKS_Y]; + convertS32toF32(fragsAcc,fragsTmp); + + auto* ldsPtr = reinterpret_cast(localMemPtr); + constexpr uint32_t ldsWidth_new = MACRO_TILE_Y; + constexpr uint32_t ldsHeight_new = 2 * MACRO_TILE_X; + constexpr uint32_t ldsld_new = std::is_same_v ? ldsWidth_new : ldsHeight_new; + + auto ldsReadOffsetAcc = get<0>(localWarpOffset) * ldsld_new + get<1>(localWarpOffset); + + float original_sm_scale = sm_scale; + uint32_t baseq_scale_idx, basek_scale_idx; + if constexpr (Q_GRAN == QuantGranularity::kPerThread) + { + const uint32_t num_warp_block_q = gridDim.x * MACRO_TILE_X / 32; + baseq_scale_idx = batch_id * num_qo_heads * (num_warp_block_q * 8) + head_id * (num_warp_block_q * 8) + + bx * (MACRO_TILE_X / 32 * 8) + get<0>(localWarpCoord) / 2 * 8; + + + } + if constexpr (K_GRAN == QuantGranularity::kPerThread) + { + const uint32_t num_warp_block_k = div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K); + basek_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * (num_warp_block_k * 4) + + (head_id / num_kv_groups) * (num_warp_block_k * 4) + iter * 4; + } + + //// Here,we do update mdo + for(int i = 0; i < BLOCKS_X; i++) + { + float m_tmp[NUM]; + for(int k = 0; k < els_per_thread; k++) + { + m_tmp[k] = -5000.f; + } + + // dequant and get max + for(int j = 0; j < BLOCKS_Y; j++) + { + for(int k = 0; k < els_per_thread; k++) + { + // dequant here + auto row = (threadIdx.x % 64) / ROCWMMA_N * 4 + k; + auto col = threadIdx.x % ROCWMMA_N; + + if(iter * MACRO_TILE_Y + threadIdx.y * WARP_TILE_Y + BLOCKS_Y * ROCWMMA_N + col < true_len) + { + fragsTmp[i][j].x[k] *= (Q_scale[baseq_scale_idx + row % 8] * log2e * sm_scale); + fragsTmp[i][j].x[k] *= K_scale[basek_scale_idx + (col % 8) / 2]; + } + else + { + fragsTmp[i][j].x[k] = -5000.f; + } + m_tmp[k] = fmaxf(m_tmp[k], fragsTmp[i][j].x[k]); + } + } + + for(int k = 0; k < els_per_thread; k++) + { + float v = m_tmp[k]; + for (int offset = ROCWMMA_N / 2; offset > 0; offset >>= 1) + { + float other = __shfl_xor(v, offset, ROCWMMA_N); + v = fmaxf(v, other); + } + m_tmp[k] = v; + } + + float o_scale[NUM]; + for(int k = 0; k < NUM; k++) + { + float m_prev = m[k]; + m[k] = fmaxf(m[k], m_tmp[k]); + + o_scale[k] = exp2f(m_prev - m[k]); + + d[k] *= o_scale[k]; + } + for(int j = 0; j < BLOCKS_Y; j++) + { + for(int k = 0; k < els_per_thread; k++) + { + float neg_m = -m[k]; + //// accumulate d + fragsTmp[i][j].x[k] = exp2f(fragsTmp[i][j].x[k] + neg_m); + + if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) + { + float d_local = fragsTmp[i][j].x[k]; + for (int offset = ROCWMMA_N / 2; offset > 0; offset >>= 1) { + float other = __shfl_xor(d_local, offset, ROCWMMA_N); + d_local += other; + } + d[k] += d_local; + } + } + } + //here require handle RO + for(int k = 0; k < SV_ITERS; k++) + { + for(int j = 0; j < BLOCKS_Y; j++) + { + for(int z = 0; z < fragsOut[k][i][j].num_elements; z++) + { + float o_true = o_scale[z]; + fragsOut[k][i][j].x[z] *= o_true; + } + } + } + } + + //// cast f32 to fp8 + localWriteAcc(fragsTmp,ldsPtr + ldsReadOffsetAcc,ldsld_new); + synchronize_workgroup(); + + InputTV RS[BLOCKS_X][BLOCKS_Y][els_per_thread]; + for(int i = 0; i < BLOCKS_X; i++) + { + for(int j = 0; j < BLOCKS_Y; j++) + { + auto baseoffset = ldsReadOffsetAcc + i * ROCWMMA_M * ldsld_new + j * ROCWMMA_N; + auto threadoffset = threadIdx.x % threads_per_row * els_per_thread + + (threadIdx.x) % WARP_SIZE / threads_per_row * ldsld_new; + for(int k = 0; k < els_per_thread; k++) + { + RS[i][j][k] = InputTV(ldsPtr[baseoffset + threadoffset + k]); + } + } + } + + // do we need sync here? + synchronize_workgroup(); + + auto* ldsPtrf8 = reinterpret_cast(localMemPtr); + constexpr uint32_t ldsWidth_fp8 = MACRO_TILE_Y; + constexpr uint32_t ldsHeight_fp8 = 2 * MACRO_TILE_X * sizeof(ComputeTV) / sizeof(InputTV); + constexpr uint32_t ldsld_fp8 = std::is_same_v ? ldsWidth_fp8 : ldsHeight_fp8; + + auto ldsReadOffsetsv = get<0>(localWarpOffset) * ldsld_fp8 + get<1>(localWarpOffset); + for(int i = 0; i < BLOCKS_X; i++) + { + for(int j = 0; j < BLOCKS_Y; j++) + { + auto baseoffset = ldsReadOffsetsv + i * ROCWMMA_M * ldsld_fp8 + j * ROCWMMA_N; + auto threadoffset = threadIdx.x % threads_per_row * els_per_thread + + (threadIdx.x) % WARP_SIZE / threads_per_row * ldsld_fp8; + + for(int k = 0; k < els_per_thread; k++) + { + ldsPtrf8[baseoffset + threadoffset + k] = RS[i][j][k]; + } + } + } + + // do we need sync here? + synchronize_workgroup(); + + constexpr size_t SV_CNT = MACRO_TILE_Y / ROCWMMA_K; + auto ldsReadOffsetS = get<0>(localWarpOffset) * ldsld_fp8; + MfmaFragS fragsS[SV_CNT][BLOCKS_X]; + for(int i = 0; i < SV_CNT; i++) + { + localReadS(fragsS[i], ldsPtrf8 + ldsReadOffsetS, ldsld_fp8); + ldsReadOffsetS += ROCWMMA_K; + } + + for(int sv_iter = 0; sv_iter < SV_ITERS; sv_iter++) + { + //// load v and calcualte sv + using GRBuffVMap1d = GetDataLayout_t; + auto globalReadOffsetV + = GRBuffVMap1d::fromMatrixCoord(make_coord2d(0u, get<1>(macroTileCoord)), 1u) + + sv_iter * MACRO_TILE_Y * ldv + batch_id * stride_bz_v + (head_id / num_kv_groups) * stride_h_v; + + auto kStepOffsetV = GRBuffVMap1d::fromMatrixCoord(make_coord2d(ROCWMMA_K, 0u), ldv); + + GRBuffV grBuffV; + + using LWBuffVShape = GetIOShape_t; + using LWBuffVMap1d = GetDataLayout_t; + + auto ldsWriteOffsetV = MACRO_TILE_X * MACRO_TILE_Y; + auto ldsReadOffsetV + = ldsWriteOffsetV + + LWBuffVMap1d::fromMatrixCoord(make_coord2d(get<1>(localWarpOffset), 0u), ldsld_fp8); + + for(uint32_t currentK = 0; currentK < SV_CNT; currentK++) + { + globalReadCoopV(grBuffV, V + globalReadOffsetV, ldv, warpIndex); + globalReadOffsetV += kStepOffsetV; + + localWriteCoopV(ldsPtrf8 + ldsWriteOffsetV, grBuffV, ldsld_fp8, warpIndex); + synchronize_workgroup(); + + MfmaFragV fragsV[BLOCKS_Y]; + + localReadV(fragsV, ldsPtrf8 + ldsReadOffsetV, ldsld_fp8); + svgemm(fragsOut[sv_iter], fragsS[currentK], fragsV, fragsOut[sv_iter]); + synchronize_workgroup(); + } + } + } + + //// normalize d + + float d_rcp[BLOCKS_X]; + + auto* ldsPtrOut = reinterpret_cast(localMemPtr); + constexpr uint32_t ldsWidth_Out = MACRO_TILE_Y; + constexpr uint32_t ldsHeight_Out = 2 * MACRO_TILE_X; + constexpr uint32_t ldsld_Out = std::is_same_v ? ldsWidth_Out : ldsHeight_Out; + + auto ldsReadOffsetAcc = get<0>(localWarpOffset) * ldsld_Out + get<1>(localWarpOffset); + +#pragma unroll + for(int sv_iter = 0; sv_iter < SV_ITERS; sv_iter++) + { +#pragma unroll + for(int i = 0; i < BLOCKS_X; i++) + { +#pragma unroll + for(int j = 0; j < BLOCKS_Y; j++) + { + auto base_offset = batch_id * (num_qo_heads / num_kv_groups) * head_dim + (head_id / num_kv_groups) * head_dim; + auto scale_idx = sv_iter * MACRO_TILE_Y + get<1>(localWarpOffset) + j * ROCWMMA_N; +#pragma unroll + for(int k = 0; k < fragsOut[sv_iter][i][j].num_elements; k++) + { + auto col = threadIdx.x % ROCWMMA_N; + fragsOut[sv_iter][i][j].x[k] /= d[k]; + fragsOut[sv_iter][i][j].x[k] *= V_scale[base_offset + scale_idx + col]; + } + } + } + } + + for(int i = 0; i < SV_ITERS; i++){ + auto Out_macroTileCoord = make_coord2d(blockIdx.x, i) * macroTileSize; + auto Out_warpTileCoord = Out_macroTileCoord + localWarpOffset; + + globalWriteD(O + batch_id * stride_bz_o + head_id * stride_h_o + + MfmaFragDMap1d::fromMatrixCoord(Out_warpTileCoord, ldd), fragsOut[i], ldd); + } + } +} + diff --git a/csrc/qattn/rocm/sgattn.hip b/csrc/qattn/rocm/sgattn.hip new file mode 100644 index 00000000..f92e77c5 --- /dev/null +++ b/csrc/qattn/rocm/sgattn.hip @@ -0,0 +1,902 @@ +// !!! This is a file automatically generated by hipify!!! +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include "../attn_utils_hip.cuh" +#include "args.hpp" + +using namespace rocwmma; + +#ifdef __ROCM_ARCH_GFX942 + using fp8_type = __hip_fp8_e4m3_fnuz; +#else + using fp8_type = __hip_fp8_e4m3; +#endif + +// namespace gfx9Params +// { +// enum kernelParams : uint32_t +// { +// ROCWMMA_M = 16u, +// ROCWMMA_N = 16u, +// ROCWMMA_K = 32u, +// BLOCKS_X = 1u, +// BLOCKS_Y = 4u, +// TBLOCK_X = 256u, +// TBLOCK_Y = 1u, +// WARP_SIZE = Constants::AMDGCN_WAVE_SIZE_64 +// }; +// } + +// namespace gfx11Params +// { +// enum kernelParams : uint32_t +// { +// ROCWMMA_M = 16u, +// ROCWMMA_N = 16u, +// ROCWMMA_K = 32u, +// BLOCKS_X = 1u, +// BLOCKS_Y = 4u, +// TBLOCK_X = 128u, +// TBLOCK_Y = 1u, +// WARP_SIZE = Constants::AMDGCN_WAVE_SIZE_32 +// }; +// } + +constexpr float log2e = 1.44269504088896340736f; +constexpr float log2e_recp = 1.0f / log2e; +#define div_ceil(M, N) (((M) + (N)-1) / (N)) +#define TRAP_IF(cond) do { if (cond) asm volatile("s_trap 0"); } while(0) + +#ifdef __ROCM_ARCH_GFX942 + using namespace gfx9Params; +#else + using namespace gfx11Params; +#endif + +/// +/// Types and Data Layouts +/// + +using InputT = int8_t; +using InputTV = fp8_type; +using OutputT = int32_t; +using ComputeT = int32_t; +using ComputeTV = float32_t; +using LDST = float32_t; +using LDST_new = fp8_type; + +using DataLayoutA = row_major; +using DataLayoutB = col_major; +using DataLayoutC = row_major; +using DataLayoutLds = col_major; +using DataLayoutV = col_major; + +using DataLayoutLds_new = row_major; +/// +/// Fragment types +/// + +// #if (ROCWMMA_ARCH_GFX9 || ROCWMMA_ARCH_GFX11) +// Warp tile: computed by each warp +constexpr uint32_t WARP_TILE_X = BLOCKS_X * ROCWMMA_M; +constexpr uint32_t WARP_TILE_Y = BLOCKS_Y * ROCWMMA_N; + +constexpr uint32_t els_per_thread = ROCWMMA_M * ROCWMMA_N / WARP_SIZE; +constexpr uint32_t threads_per_row = ROCWMMA_N / els_per_thread; + +// Macro Tile: computed by each thread block (workgroup) +// Note: TBLOCK_X must be multiple of WARP_SIZE. +constexpr uint32_t WARPS_X = TBLOCK_X / WARP_SIZE; +constexpr uint32_t WARPS_Y = TBLOCK_Y; +constexpr uint32_t MACRO_TILE_X = WARPS_X * WARP_TILE_X; +constexpr uint32_t MACRO_TILE_Y = WARPS_Y * WARP_TILE_Y; +constexpr uint32_t MACRO_TILE_K = WARPS_Y * BLOCKS_Y * ROCWMMA_K; + +// Mfma frags +using MfmaFragA = fragment; +using MfmaFragB = fragment; +using MfmaFragC = fragment; +using MfmaFragD = fragment; +using MfmaFragAcc = fragment; +using MfmaFragAccf32 = fragment; + +using MfmaFragS = fragment; +using MfmaFragV = fragment; + +// Global read (macro tile) +using GRBuffA = fragment; +using GRBuffB = fragment; + +using GRBuffS = fragment; +using GRBuffV = fragment; + +// Local write of global buffers (macro tile) +// - Must match Lds data layout. +// - Lds has transposed B frags. +using LWBuffA = ApplyDataLayout_t; +using LWBuffB = ApplyDataLayout_t, DataLayoutLds>; + +using LWBuffS = ApplyDataLayout_t; +using LWBuffV = ApplyDataLayout_t, DataLayoutLds_new>; + +// Local read (mfma frags) +// - Must match Lds data layout. +// - Lds has transposed B frags. +using LRFragA = ApplyDataLayout_t; +using LRFragB = ApplyDataLayout_t, DataLayoutLds>; + +using LRFragS = ApplyDataLayout_t; +using LRFragV = ApplyDataLayout_t, DataLayoutLds_new>; + +using LRFragAccf32 = ApplyDataLayout_t; +// #endif // (ROCWMMA_ARCH_GFX9 || ROCWMMA_ARCH_GFX11) + +/// +/// Wrapper functions: repeat mfma tile operations across entire warp tile. +/// + +// Cooperative global read / local write (Macro tile data movement) +// Loads / stores a global data fragment cooperatively across warps. Each participating warp is +// responsible for only a portion of the whole fragment. +// +// The cooperative operation is split into work items (SplitCount). Work items are consumed in +// a round robin fashion by warps in the range of [0, WaveCount). The wave index determines the +// order of the current wave in the collaboration pool. +// +// WaveCount, SplitCount and waveIndex parameters must match successive coop load / store calls +// to ensure the entire fragment remains coherent. + +// Global A reads in cooperative mode (macro tile) +template +ROCWMMA_DEVICE static inline void + globalReadCoopA(GRBuffA& grBuffA, InputT const* gAddrA, uint32_t lda, uint32_t waveIndexA) +{ + load_matrix_coop_sync(grBuffA, gAddrA, lda, waveIndexA); +} + +// Global B reads in cooperative mode (macro tile) +template +ROCWMMA_DEVICE static inline void + globalReadCoopB(GRBuffB& grBuffB, InputT const* gAddrB, uint32_t ldb, uint32_t waveIndexB) +{ + load_matrix_coop_sync(grBuffB, gAddrB, ldb, waveIndexB); +} + +template +ROCWMMA_DEVICE static inline void + globalReadCoopV(GRBuffV& grBuffV, InputTV const* gAddrV, uint32_t ldV, uint32_t waveIndexV) +{ + load_matrix_coop_sync(grBuffV, gAddrV, ldV, waveIndexV); +} + +// Local A writes in cooperative mode (macro tile) +template +ROCWMMA_DEVICE static inline void + localWriteCoopA(InputT* ldsAddr, GRBuffA const& grBuffA, uint32_t ldsld, uint32_t waveIndexA) +{ + // No transpose, but apply the lds data layout + store_matrix_coop_sync( + ldsAddr, applyDataLayout(grBuffA), ldsld, waveIndexA); +} + +// Local B writes in cooperative mode (macro tile) +template +ROCWMMA_DEVICE static inline void + localWriteCoopB(InputT* ldsAddr, GRBuffB const& grBuffB, uint32_t ldsld, uint32_t waveIndexB) +{ + // Transpose B and then apply lds data layout + store_matrix_coop_sync( + ldsAddr, applyDataLayout(applyTranspose(grBuffB)), ldsld, waveIndexB); +} + +template +ROCWMMA_DEVICE static inline void + localWriteCoopV(InputTV* ldsAddr, GRBuffV const& grBuffV, uint32_t ldsld, uint32_t waveIndexV) +{ + // Transpose B and then apply lds data layout + store_matrix_coop_sync( + ldsAddr, applyDataLayout(applyTranspose(grBuffV)), ldsld, waveIndexV); +} + +// Local A reads for warp tile gemm, non-cooperative +ROCWMMA_DEVICE static inline void + localReadA(MfmaFragA (&fragsA)[BLOCKS_X], InputT const* ldsAddrA, uint32_t ldsld) +{ + using FragShape = GetIOShape_t; + using Mapper1d = GetDataLayout_t; + + // Each A block is stacked vertically in LDS + auto blockStep = Mapper1d::fromMatrixCoord(make_coord2d(FragShape::BlockHeight, 0u), ldsld); + +#pragma unroll + for(int i = 0; i < BLOCKS_X; i++) + { + LRFragA tmp; + load_matrix_sync(tmp, ldsAddrA, ldsld); + fragsA[i] = applyDataLayout(tmp); + + ldsAddrA += blockStep; + } +} + +ROCWMMA_DEVICE static inline void + localReadS(MfmaFragS (&fragsS)[BLOCKS_X], InputTV const* ldsAddrS, uint32_t ldsld) +{ + using FragShape = GetIOShape_t; + using Mapper1d = GetDataLayout_t; + + // Each A block is stacked vertically in LDS + auto blockStep = Mapper1d::fromMatrixCoord(make_coord2d(FragShape::BlockHeight, 0u), ldsld); + +#pragma unroll + for(int i = 0; i < BLOCKS_X; i++) + { + LRFragS tmp; + load_matrix_sync(tmp, ldsAddrS, ldsld); + fragsS[i] = applyDataLayout(tmp); + + ldsAddrS += blockStep; + } +} + +// Local B reads for warp tile gemm, non-cooperative +ROCWMMA_DEVICE static inline void + localReadB(MfmaFragB (&fragsB)[BLOCKS_Y], InputT const* ldsAddrB, uint32_t ldsld) +{ + using FragShape = GetIOShape_t; + using Mapper1d = GetDataLayout_t; + + // Each B block is stacked vertically in LDS + auto blockStep = Mapper1d::fromMatrixCoord(make_coord2d(FragShape::BlockHeight, 0u), ldsld); + +#pragma unroll + for(int i = 0; i < BLOCKS_Y; i++) + { + LRFragB tmp; + load_matrix_sync(tmp, ldsAddrB, ldsld); + + // Transform back to MFMA tile + fragsB[i] = applyDataLayout(applyTranspose(tmp)); + + ldsAddrB += blockStep; + } +} + +ROCWMMA_DEVICE static inline void + localReadV(MfmaFragV (&fragsV)[BLOCKS_Y], InputTV const* ldsAddrV, uint32_t ldsld) +{ + using FragShape = GetIOShape_t; + using Mapper1d = GetDataLayout_t; + + // Each B block is stacked vertically in LDS + auto blockStep = Mapper1d::fromMatrixCoord(make_coord2d(FragShape::BlockHeight, 0u), ldsld); + +#pragma unroll + for(int i = 0; i < BLOCKS_Y; i++) + { + LRFragV tmp; + load_matrix_sync(tmp, ldsAddrV, ldsld); + + // Transform back to MFMA tile + fragsV[i] = applyDataLayout(applyTranspose(tmp)); + + ldsAddrV += blockStep; + } +} + + +// Global D reads for warp tile gemm, non-cooperative +ROCWMMA_DEVICE static inline void + globalWriteD(float* gAddrD, MfmaFragAccf32 const (&fragsD)[BLOCKS_X][BLOCKS_Y], uint32_t ldd) +{ + using FragShape = GetIOShape_t; + using Mapper1d = GetDataLayout_t; + + // Iterative offsets for each D block in the warp tile + auto blockStepX = Mapper1d::fromMatrixCoord(make_coord2d(FragShape::BlockHeight, 0u), ldd); + auto blockStepY = Mapper1d::fromMatrixCoord(make_coord2d(0u, FragShape::BlockWidth), ldd); + +#pragma unroll + for(int i = 0; i < BLOCKS_X; i++) + { + auto offsetY = 0u; +#pragma unroll + for(int j = 0; j < BLOCKS_Y; j++) + { + store_matrix_sync(gAddrD + offsetY, fragsD[i][j], ldd); + // printf("offsetY:%u\n",offsetY); + offsetY += blockStepY; + } + gAddrD += blockStepX; + } +} + +// Broadcast value to fragments in warp tile +template +ROCWMMA_DEVICE static inline void fill(FragT (&frags)[BLOCKS_X][BLOCKS_Y], + GetDataType_t value) +{ +#pragma unroll + for(int i = 0; i < BLOCKS_X; i++) + { +#pragma unroll + for(int j = 0; j < BLOCKS_Y; j++) + { + fill_fragment(frags[i][j], value); + } + } +} + +ROCWMMA_DEVICE static inline void + localWriteAcc(MfmaFragAccf32 (&fragsAcc)[BLOCKS_X][BLOCKS_Y], ComputeTV * ldsAddrAcc, uint32_t ldsld) +{ + using FragShape = GetIOShape_t; + using Mapper1d = GetDataLayout_t; + + auto blockStepX = Mapper1d::fromMatrixCoord(make_coord2d(FragShape::BlockHeight, 0u), ldsld); + auto blockStepY = Mapper1d::fromMatrixCoord(make_coord2d(0u, FragShape::BlockWidth), ldsld); + +#pragma unroll + for(int i = 0; i < BLOCKS_X; i++) + { + auto offsetY = 0u; +#pragma unroll + for(int j = 0; j < BLOCKS_Y; j++) + { + store_matrix_sync(ldsAddrAcc + offsetY, fragsAcc[i][j], ldsld); + offsetY += blockStepY; + } + ldsAddrAcc += blockStepX; + } +} + +// ROCWMMA_DEVICE static inline void +// localWriteOut(MfmaFragAccf32 (&fragsAcc)[BLOCKS_Y], ComputeTV * ldsAddrAcc, uint32_t ldsld) +// { +// using FragShape = GetIOShape_t; +// using Mapper1d = GetDataLayout_t; + +// // auto blockStepX = Mapper1d::fromMatrixCoord(make_coord2d(FragShape::BlockHeight, 0u), ldsld); +// auto blockStepY = Mapper1d::fromMatrixCoord(make_coord2d(0u, FragShape::BlockWidth), ldsld); + +// auto offsetY = 0u; +// #pragma unroll +// for(int j = 0; j < BLOCKS_Y; j++) +// { +// store_matrix_sync(ldsAddrAcc + offsetY, fragsAcc[j], ldsld); +// offsetY += blockStepY; +// } +// } + + +ROCWMMA_DEVICE static inline void + localReadOut(MfmaFragAccf32 (&fragsAcc)[BLOCKS_Y], ComputeTV const* ldsAddrAcc, uint32_t ldsld) +{ + using FragShape = GetIOShape_t; + using Mapper1d = GetDataLayout_t; + + auto blockStepY = Mapper1d::fromMatrixCoord(make_coord2d(0u, FragShape::BlockWidth), ldsld); + + auto offsetY = 0u; +#pragma unroll + for(int j = 0; j < BLOCKS_Y; j++) + { + load_matrix_sync(fragsAcc[j], ldsAddrAcc + offsetY, ldsld); + offsetY += blockStepY; + } +} + +ROCWMMA_DEVICE static inline void + localReadAcc(MfmaFragAccf32 (&fragsAcc)[BLOCKS_X][BLOCKS_Y], float const* ldsAddrAcc, uint32_t ldsld) +{ + using FragShape = GetIOShape_t; + using Mapper1d = GetDataLayout_t; + + auto blockStepX = Mapper1d::fromMatrixCoord(make_coord2d(FragShape::BlockHeight, 0u), ldsld); + auto blockStepY = Mapper1d::fromMatrixCoord(make_coord2d(0u, FragShape::BlockWidth), ldsld); + +#pragma unroll + for(int i = 0; i < BLOCKS_X; i++) + { + auto offsetY = 0u; +#pragma unroll + for(int j = 0; j < BLOCKS_Y; j++) + { + load_matrix_sync(fragsAcc[i][j], ldsAddrAcc + offsetY, ldsld); + offsetY += blockStepY; + } + ldsAddrAcc += blockStepX; + } +} + + +ROCWMMA_DEVICE static inline void +convertS32toF32(MfmaFragAcc (&frags_i32)[BLOCKS_X][BLOCKS_Y], + MfmaFragAccf32 (&frags_f32)[BLOCKS_X][BLOCKS_Y]) +{ +#pragma unroll + for (int i = 0; i < BLOCKS_X; ++i) { +#pragma unroll + for (int j = 0; j < BLOCKS_Y; ++j) { +// #pragma unroll + for (int k = 0; k < frags_i32[i][j].num_elements; ++k) { + frags_f32[i][j].x[k] = __int2float_rz(frags_i32[i][j].x[k]); + } + } + } +} + + +ROCWMMA_DEVICE static inline void svgemm(MfmaFragAccf32 (&fragsAccOut)[BLOCKS_X][BLOCKS_Y], + MfmaFragS const (&fragsA)[BLOCKS_X], + MfmaFragV const (&fragsB)[BLOCKS_Y], + MfmaFragAccf32 const (&fragsAccIn)[BLOCKS_X][BLOCKS_Y]) +{ +#pragma unroll + for(int i = 0; i < BLOCKS_X; i++) + { +#pragma unroll + for(int j = 0; j < BLOCKS_Y; j++) + { + mma_sync(fragsAccOut[i][j], fragsA[i], fragsB[j], fragsAccIn[i][j]); + } + } +} + + +ROCWMMA_DEVICE static inline void qkgemm(MfmaFragAcc (&fragsAccOut)[BLOCKS_X][BLOCKS_Y], + MfmaFragA const (&fragsA)[BLOCKS_X], + MfmaFragB const (&fragsB)[BLOCKS_Y], + MfmaFragAcc const (&fragsAccIn)[BLOCKS_X][BLOCKS_Y]) +{ +#pragma unroll + for(int i = 0; i < BLOCKS_X; i++) + { +#pragma unroll + for(int j = 0; j < BLOCKS_Y; j++) + { + mma_sync(fragsAccOut[i][j], fragsA[i], fragsB[j], fragsAccIn[i][j]); + } + } +} + +template +ROCWMMA_KERNEL void __launch_bounds__(256) qk_int_sv_f8_attn_kernel(uint32_t lda, + uint32_t ldb, + uint32_t ldv, + uint32_t ldd, + int8_t *__restrict__ Q, int8_t *__restrict__ K, InputTV *__restrict__ V, float *__restrict__ O, float *__restrict__ Lse, + float *__restrict__ Q_scale, float *__restrict__ K_scale, float *__restrict__ V_scale, float *__restrict__ V_mean, + const uint32_t qo_len, const uint32_t kv_len, const uint32_t true_len, const uint32_t num_kv_groups, + const uint32_t stride_bz_q, const uint32_t stride_seq_q, const uint32_t stride_h_q, + const uint32_t stride_bz_k, const uint32_t stride_seq_k, const uint32_t stride_h_k, + const uint32_t stride_bz_v, const uint32_t stride_h_v, const uint32_t stride_d_v, + const uint32_t stride_bz_o, const uint32_t stride_seq_o, const uint32_t stride_h_o, + float sm_scale) +{ + if constexpr(!ROCWMMA_ARCH_HOST) + { + + // Tile Sizes + constexpr auto warpTileSize = make_coord2d(WARP_TILE_X, WARP_TILE_Y); + constexpr auto macroTileSize = make_coord2d(MACRO_TILE_X, MACRO_TILE_Y); + + // Local warp coordinate relative to current threadblock (wg). + constexpr auto warpDims = make_coord2d(WARPS_X, WARPS_Y); + auto localWarpCoord = make_coord2d(threadIdx.x / WARP_SIZE, threadIdx.y); + auto localWarpOffset = localWarpCoord * warpTileSize; + constexpr auto warpCount = get<0>(warpDims) * get<1>(warpDims); + const auto warpIndex = get<0>(localWarpCoord) * get<1>(warpDims) + get<1>(localWarpCoord); + + const uint32_t batch_id = blockIdx.z; + const uint32_t bx = blockIdx.x; + const uint32_t num_qo_heads = gridDim.y; + const uint32_t head_id = blockIdx.y; + + using MfmaFragDMap1d = GetDataLayout_t; + + const uint32_t iterations = div_ceil( + kv_len, + MACRO_TILE_Y); + + MfmaFragAccf32 fragsOut[SV_ITERS][BLOCKS_X][BLOCKS_Y]; + for(int i = 0; i < SV_ITERS; i++){ + fill(fragsOut[i], 0.0f); + } + + HIP_DYNAMIC_SHARED(void*, localMemPtr); + + constexpr int NUM = ROCWMMA_M * ROCWMMA_N / WARP_SIZE; + float m[NUM], d[NUM]; + for(int i = 0; i < NUM; i++) + { + m[i] = -5000.f; + d[i] = 1.f; + } + + for(int iter = 0; iter < iterations; iter++) + { + MfmaFragAcc fragsAcc[BLOCKS_X][BLOCKS_Y]; + fill(fragsAcc, 0.0f); + auto macroTileCoord = make_coord2d(blockIdx.x, iter) * macroTileSize; + { + auto warpTileCoord = macroTileCoord + localWarpOffset; + + // Bounds check + auto warpTileBound = warpTileCoord + warpTileSize; + if(get<0>(warpTileBound) > qo_len || get<1>(warpTileBound) > kv_len) + { + continue; + } + + using GRBuffAMap1d = GetDataLayout_t; + using GRBuffBMap1d = GetDataLayout_t; + + // Initial globa read address offsets + auto globalReadOffsetA + = batch_id * stride_bz_q + head_id * stride_h_q + + GRBuffAMap1d::fromMatrixCoord(make_coord2d(get<0>(macroTileCoord), 0u), lda); + auto globalReadOffsetB + = batch_id * stride_bz_k + (head_id / num_kv_groups) * stride_h_k + + GRBuffBMap1d::fromMatrixCoord(make_coord2d(0u, get<1>(macroTileCoord)), ldb); + + // Incremental global read address offsets + auto kStepOffsetA = GRBuffAMap1d::fromMatrixCoord(make_coord2d(0u, ROCWMMA_K), lda); + auto kStepOffsetB = GRBuffBMap1d::fromMatrixCoord(make_coord2d(ROCWMMA_K, 0u), ldb); + + GRBuffA grBuffA; + GRBuffB grBuffB; + + globalReadCoopA(grBuffA, Q + globalReadOffsetA, lda, warpIndex); + globalReadCoopB(grBuffB, K + globalReadOffsetB, ldb, warpIndex); + + globalReadOffsetA += kStepOffsetA; + globalReadOffsetB += kStepOffsetB; + + using LWBuffAShape = GetIOShape_t; + using LWBuffBShape = GetIOShape_t; + using LWBuffAMap1d = GetDataLayout_t; + using LWBuffBMap1d = GetDataLayout_t; + + constexpr uint32_t ldsWidth = ROCWMMA_K; + constexpr uint32_t ldsHeight = LWBuffAShape::BlockHeight + LWBuffBShape::BlockHeight; + constexpr uint32_t sizeLds = ldsHeight * ldsWidth; + constexpr uint32_t ldsld = std::is_same_v ? ldsWidth : ldsHeight; + + auto* ldsPtrLo = reinterpret_cast(localMemPtr); + auto* ldsPtrHi = ldsPtrLo + sizeLds; + + auto ldsWriteOffsetA = 0u; + auto ldsWriteOffsetB + = LWBuffAMap1d::fromMatrixCoord(make_coord2d(LWBuffAShape::BlockHeight, 0u), ldsld); + + auto ldsReadOffsetA + = ldsWriteOffsetA + + LWBuffAMap1d::fromMatrixCoord(make_coord2d(get<0>(localWarpOffset), 0u), ldsld); + auto ldsReadOffsetB + = ldsWriteOffsetB + + LWBuffBMap1d::fromMatrixCoord(make_coord2d(get<1>(localWarpOffset), 0u), ldsld); + + localWriteCoopA(ldsPtrLo + ldsWriteOffsetA, grBuffA, ldsld, warpIndex); + localWriteCoopB(ldsPtrLo + ldsWriteOffsetB, grBuffB, ldsld, warpIndex); + + synchronize_workgroup(); + + for(uint32_t currentK = ROCWMMA_K; currentK < head_dim; currentK += ROCWMMA_K) + { + MfmaFragA fragsA[BLOCKS_X]; + MfmaFragB fragsB[BLOCKS_Y]; + + localReadA(fragsA, ldsPtrLo + ldsReadOffsetA, ldsld); + localReadB(fragsB, ldsPtrLo + ldsReadOffsetB, ldsld); + + globalReadCoopA(grBuffA, Q + globalReadOffsetA, lda, warpIndex); + globalReadCoopB(grBuffB, K + globalReadOffsetB, ldb, warpIndex); + + globalReadOffsetA += kStepOffsetA; + globalReadOffsetB += kStepOffsetB; + + qkgemm(fragsAcc, fragsA, fragsB, fragsAcc); + + localWriteCoopA(ldsPtrHi + ldsWriteOffsetA, grBuffA, ldsld, warpIndex); + localWriteCoopB(ldsPtrHi + ldsWriteOffsetB, grBuffB, ldsld, warpIndex); + + synchronize_workgroup(); + + auto* tmp = ldsPtrLo; + ldsPtrLo = ldsPtrHi; + ldsPtrHi = tmp; + } + + MfmaFragA fragsA[BLOCKS_X]; + MfmaFragB fragsB[BLOCKS_Y]; + + localReadA(fragsA, ldsPtrLo + ldsReadOffsetA, ldsld); + localReadB(fragsB, ldsPtrLo + ldsReadOffsetB, ldsld); + qkgemm(fragsAcc, fragsA, fragsB, fragsAcc); + } + + MfmaFragAccf32 fragsTmp[BLOCKS_X][BLOCKS_Y]; + convertS32toF32(fragsAcc,fragsTmp); + + auto* ldsPtr = reinterpret_cast(localMemPtr); + constexpr uint32_t ldsWidth_new = MACRO_TILE_Y; + constexpr uint32_t ldsHeight_new = 2 * MACRO_TILE_X; + constexpr uint32_t ldsld_new = std::is_same_v ? ldsWidth_new : ldsHeight_new; + + auto ldsReadOffsetAcc = get<0>(localWarpOffset) * ldsld_new + get<1>(localWarpOffset); + + float original_sm_scale = sm_scale; + uint32_t baseq_scale_idx, basek_scale_idx; + if constexpr (Q_GRAN == QuantGranularity::kPerThread) + { + const uint32_t num_warp_block_q = gridDim.x * MACRO_TILE_X / 32; + baseq_scale_idx = batch_id * num_qo_heads * (num_warp_block_q * 8) + head_id * (num_warp_block_q * 8) + + bx * (MACRO_TILE_X / 32 * 8) + get<0>(localWarpCoord) / 2 * 8; + + + } + if constexpr (K_GRAN == QuantGranularity::kPerThread) + { + const uint32_t num_warp_block_k = div_ceil(kv_len, CTA_K) * (CTA_K / WARP_K); + basek_scale_idx = batch_id * (num_qo_heads / num_kv_groups) * (num_warp_block_k * 4) + + (head_id / num_kv_groups) * (num_warp_block_k * 4) + iter * 4; + } + + //// Here,we do update mdo + for(int i = 0; i < BLOCKS_X; i++) + { + float m_tmp[NUM]; + for(int k = 0; k < els_per_thread; k++) + { + m_tmp[k] = -5000.f; + } + + // dequant and get max + for(int j = 0; j < BLOCKS_Y; j++) + { + for(int k = 0; k < els_per_thread; k++) + { + // dequant here + auto row = (threadIdx.x % 64) / ROCWMMA_N * 4 + k; + auto col = threadIdx.x % ROCWMMA_N; + + if(iter * MACRO_TILE_Y + threadIdx.y * WARP_TILE_Y + BLOCKS_Y * ROCWMMA_N + col < true_len) + { + fragsTmp[i][j].x[k] *= (Q_scale[baseq_scale_idx + row % 8] * log2e * sm_scale); + fragsTmp[i][j].x[k] *= K_scale[basek_scale_idx + (col % 8) / 2]; + } + else + { + fragsTmp[i][j].x[k] = -5000.f; + } + m_tmp[k] = fmaxf(m_tmp[k], fragsTmp[i][j].x[k]); + } + } + + for(int k = 0; k < els_per_thread; k++) + { + float v = m_tmp[k]; + for (int offset = ROCWMMA_N / 2; offset > 0; offset >>= 1) + { + float other = __shfl_xor(v, offset, ROCWMMA_N); + v = fmaxf(v, other); + } + m_tmp[k] = v; + } + + float o_scale[NUM]; + for(int k = 0; k < NUM; k++) + { + float m_prev = m[k]; + m[k] = fmaxf(m[k], m_tmp[k]); + + o_scale[k] = exp2f(m_prev - m[k]); + + d[k] *= o_scale[k]; + } + for(int j = 0; j < BLOCKS_Y; j++) + { + for(int k = 0; k < els_per_thread; k++) + { + float neg_m = -m[k]; + //// accumulate d + fragsTmp[i][j].x[k] = exp2f(fragsTmp[i][j].x[k] + neg_m); + + if constexpr (DenominatorAccumUnit == ComputeUnit::kCudaCore) + { + float d_local = fragsTmp[i][j].x[k]; + for (int offset = ROCWMMA_N / 2; offset > 0; offset >>= 1) { + float other = __shfl_xor(d_local, offset, ROCWMMA_N); + d_local += other; + } + d[k] += d_local; + } + } + } + //here require handle RO + for(int k = 0; k < SV_ITERS; k++) + { + for(int j = 0; j < BLOCKS_Y; j++) + { + for(int z = 0; z < fragsOut[k][i][j].num_elements; z++) + { + float o_true = o_scale[z]; + fragsOut[k][i][j].x[z] *= o_true; + } + } + } + } + + //// cast f32 to fp8 + localWriteAcc(fragsTmp,ldsPtr + ldsReadOffsetAcc,ldsld_new); + synchronize_workgroup(); + + InputTV RS[BLOCKS_X][BLOCKS_Y][els_per_thread]; + for(int i = 0; i < BLOCKS_X; i++) + { + for(int j = 0; j < BLOCKS_Y; j++) + { + auto baseoffset = ldsReadOffsetAcc + i * ROCWMMA_M * ldsld_new + j * ROCWMMA_N; + auto threadoffset = threadIdx.x % threads_per_row * els_per_thread + + (threadIdx.x) % WARP_SIZE / threads_per_row * ldsld_new; + for(int k = 0; k < els_per_thread; k++) + { + RS[i][j][k] = InputTV(ldsPtr[baseoffset + threadoffset + k]); + } + } + } + + // do we need sync here? + synchronize_workgroup(); + + auto* ldsPtrf8 = reinterpret_cast(localMemPtr); + constexpr uint32_t ldsWidth_fp8 = MACRO_TILE_Y; + constexpr uint32_t ldsHeight_fp8 = 2 * MACRO_TILE_X * sizeof(ComputeTV) / sizeof(InputTV); + constexpr uint32_t ldsld_fp8 = std::is_same_v ? ldsWidth_fp8 : ldsHeight_fp8; + + auto ldsReadOffsetsv = get<0>(localWarpOffset) * ldsld_fp8 + get<1>(localWarpOffset); + for(int i = 0; i < BLOCKS_X; i++) + { + for(int j = 0; j < BLOCKS_Y; j++) + { + auto baseoffset = ldsReadOffsetsv + i * ROCWMMA_M * ldsld_fp8 + j * ROCWMMA_N; + auto threadoffset = threadIdx.x % threads_per_row * els_per_thread + + (threadIdx.x) % WARP_SIZE / threads_per_row * ldsld_fp8; + + for(int k = 0; k < els_per_thread; k++) + { + ldsPtrf8[baseoffset + threadoffset + k] = RS[i][j][k]; + } + } + } + + // do we need sync here? + synchronize_workgroup(); + + constexpr size_t SV_CNT = MACRO_TILE_Y / ROCWMMA_K; + auto ldsReadOffsetS = get<0>(localWarpOffset) * ldsld_fp8; + MfmaFragS fragsS[SV_CNT][BLOCKS_X]; + for(int i = 0; i < SV_CNT; i++) + { + localReadS(fragsS[i], ldsPtrf8 + ldsReadOffsetS, ldsld_fp8); + ldsReadOffsetS += ROCWMMA_K; + } + + for(int sv_iter = 0; sv_iter < SV_ITERS; sv_iter++) + { + //// load v and calcualte sv + using GRBuffVMap1d = GetDataLayout_t; + auto globalReadOffsetV + = GRBuffVMap1d::fromMatrixCoord(make_coord2d(0u, get<1>(macroTileCoord)), 1u) + + sv_iter * MACRO_TILE_Y * ldv + batch_id * stride_bz_v + (head_id / num_kv_groups) * stride_h_v; + + auto kStepOffsetV = GRBuffVMap1d::fromMatrixCoord(make_coord2d(ROCWMMA_K, 0u), ldv); + + GRBuffV grBuffV; + + using LWBuffVShape = GetIOShape_t; + using LWBuffVMap1d = GetDataLayout_t; + + auto ldsWriteOffsetV = MACRO_TILE_X * MACRO_TILE_Y; + auto ldsReadOffsetV + = ldsWriteOffsetV + + LWBuffVMap1d::fromMatrixCoord(make_coord2d(get<1>(localWarpOffset), 0u), ldsld_fp8); + + for(uint32_t currentK = 0; currentK < SV_CNT; currentK++) + { + globalReadCoopV(grBuffV, V + globalReadOffsetV, ldv, warpIndex); + globalReadOffsetV += kStepOffsetV; + + localWriteCoopV(ldsPtrf8 + ldsWriteOffsetV, grBuffV, ldsld_fp8, warpIndex); + synchronize_workgroup(); + + MfmaFragV fragsV[BLOCKS_Y]; + + localReadV(fragsV, ldsPtrf8 + ldsReadOffsetV, ldsld_fp8); + svgemm(fragsOut[sv_iter], fragsS[currentK], fragsV, fragsOut[sv_iter]); + synchronize_workgroup(); + } + } + } + + //// normalize d + + float d_rcp[BLOCKS_X]; + + auto* ldsPtrOut = reinterpret_cast(localMemPtr); + constexpr uint32_t ldsWidth_Out = MACRO_TILE_Y; + constexpr uint32_t ldsHeight_Out = 2 * MACRO_TILE_X; + constexpr uint32_t ldsld_Out = std::is_same_v ? ldsWidth_Out : ldsHeight_Out; + + auto ldsReadOffsetAcc = get<0>(localWarpOffset) * ldsld_Out + get<1>(localWarpOffset); + +#pragma unroll + for(int sv_iter = 0; sv_iter < SV_ITERS; sv_iter++) + { +#pragma unroll + for(int i = 0; i < BLOCKS_X; i++) + { +#pragma unroll + for(int j = 0; j < BLOCKS_Y; j++) + { + auto base_offset = batch_id * (num_qo_heads / num_kv_groups) * head_dim + (head_id / num_kv_groups) * head_dim; + auto scale_idx = sv_iter * MACRO_TILE_Y + get<1>(localWarpOffset) + j * ROCWMMA_N; +#pragma unroll + for(int k = 0; k < fragsOut[sv_iter][i][j].num_elements; k++) + { + auto col = threadIdx.x % ROCWMMA_N; + fragsOut[sv_iter][i][j].x[k] /= d[k]; + fragsOut[sv_iter][i][j].x[k] *= V_scale[base_offset + scale_idx + col]; + } + } + } + } + + for(int i = 0; i < SV_ITERS; i++){ + auto Out_macroTileCoord = make_coord2d(blockIdx.x, i) * macroTileSize; + auto Out_warpTileCoord = Out_macroTileCoord + localWarpOffset; + + globalWriteD(O + batch_id * stride_bz_o + head_id * stride_h_o + + MfmaFragDMap1d::fromMatrixCoord(Out_warpTileCoord, ldd), fragsOut[i], ldd); + } + } +} + diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh index f6f74c45..d99ed18b 100644 --- a/csrc/reduction_utils.cuh +++ b/csrc/reduction_utils.cuh @@ -19,15 +19,31 @@ */ #pragma once + +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) +#include +#else +#include +#endif + #define FINAL_MASK 0xffffffff namespace vllm { +template +__device__ __forceinline__ T shfl_xor(T val, int laneMask) { +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) + return __shfl_xor(val, laneMask, warpSize); +#else + return __shfl_xor_sync(FINAL_MASK, val, laneMask, warpSize); +#endif +} + template __inline__ __device__ T warpReduceSum(T val) { #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val += __shfl_xor_sync(0xffffffff, val, mask, 32); + for (int mask = warpSize >> 1; mask > 0; mask >>= 1) + val += shfl_xor(val, mask); return val; } @@ -38,8 +54,8 @@ __inline__ __device__ T warpReduceSumV2(T* val) for (int i = 0; i < NUM; i++) { #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); + for (int mask = warpSize >> 1; mask > 0; mask >>= 1) + val[i] += shfl_xor(val[i], mask); } return (T) (0.0f); } @@ -47,9 +63,10 @@ __inline__ __device__ T warpReduceSumV2(T* val) /* Calculate the sum of all elements in a block */ template __inline__ __device__ T blockReduceSum(T val) { - static __shared__ T shared[32]; - int lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; + // static __shared__ T shared[warpSize]; + __shared__ T shared[warpSize]; + int lane = threadIdx.x & (warpSize - 1); + int wid = threadIdx.x / warpSize; val = warpReduceSum(val); @@ -58,9 +75,9 @@ __inline__ __device__ T blockReduceSum(T val) { __syncthreads(); - // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent - // blockDim.x is not divided by 32 - val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f); + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; + val = (lane < numWarps) ? shared[lane] : (T)(0.0f); val = warpReduceSum(val); return val; } @@ -68,9 +85,10 @@ __inline__ __device__ T blockReduceSum(T val) { /* Calculate the sum of all elements in a block */ template __inline__ __device__ T blockAllReduceSum(T val) { - static __shared__ T shared[32]; - int lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; + // static __shared__ T shared[warpSize]; + __shared__ T shared[warpSize]; + int lane = threadIdx.x & (warpSize - 1); + int wid = threadIdx.x / warpSize; val = warpReduceSum(val); @@ -79,9 +97,8 @@ __inline__ __device__ T blockAllReduceSum(T val) { __syncthreads(); - // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent - // blockDim.x is not divided by 32 - val = (lane < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f); + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; + val = (lane < numWarps) ? shared[lane] : (T)(0.0f); val = warpReduceSum(val); return val; } @@ -89,9 +106,10 @@ __inline__ __device__ T blockAllReduceSum(T val) { template __inline__ __device__ T blockReduceSumV2(T* val) { - static __shared__ T shared[NUM][33]; - int lane = threadIdx.x & 0x1f; - int wid = threadIdx.x >> 5; + // static __shared__ T shared[NUM][warpSize + 1]; + __shared__ T shared[NUM][warpSize + 1]; + int lane = threadIdx.x & (warpSize - 1); + int wid = threadIdx.x / warpSize; warpReduceSumV2(val); @@ -106,11 +124,12 @@ __inline__ __device__ T blockReduceSumV2(T* val) __syncthreads(); - bool is_mask = threadIdx.x < (blockDim.x / 32.f); + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; #pragma unroll for (int i = 0; i < NUM; i++) { - val[i] = is_mask ? shared[i][lane] : (T) (0.0f); + T tmp = (lane < numWarps) ? shared[i][lane] : T(0); + val[i] = tmp; } warpReduceSumV2(val); return (T) 0.0f; @@ -120,24 +139,27 @@ template __inline__ __device__ T warpReduceMax(T val) { #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val = max(val, __shfl_xor_sync(0xffffffff, val, mask, 32)); - return val; + for (int mask = warpSize >> 1; mask > 0; mask >>= 1) { + T other = shfl_xor(val, mask); + val = val > other ? val : other; + } + return val; } /* Calculate the maximum of all elements in a block */ template __inline__ __device__ T blockReduceMax(T val) { - static __shared__ T shared[32]; - int lane = threadIdx.x & 0x1f; // in-warp idx - int wid = threadIdx.x >> 5; // warp idx + // static __shared__ T shared[warpSize]; + __shared__ T shared[warpSize]; + int lane = threadIdx.x & (warpSize - 1); // in-warp idx + int wid = threadIdx.x / warpSize; // warp idx val = warpReduceMax(val); // get maxx in each warp if (lane == 0) // record in-warp maxx by warp Idx shared[wid] = val; __syncthreads(); - // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent - // blockDim.x is not divided by 32 - val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : -1e20f; + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; + val = (lane < numWarps) ? shared[lane] : -1e20f; val = warpReduceMax(val); return val; } @@ -146,9 +168,10 @@ __inline__ __device__ T blockReduceMax(T val) template __inline__ __device__ T blockAllReduceMax(T val) { - static __shared__ T shared[32]; - int lane = threadIdx.x & 0x1f; // in-warp idx - int wid = threadIdx.x >> 5; // warp idx + // static __shared__ T shared[warpSize]; + __shared__ T shared[warpSize]; + int lane = threadIdx.x & (warpSize - 1); // in-warp idx + int wid = threadIdx.x / warpSize; // warp idx val = warpReduceMax(val); // get maxx in each warp @@ -157,9 +180,8 @@ __inline__ __device__ T blockAllReduceMax(T val) __syncthreads(); - // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent - // blockDim.x is not divided by 32 - val = (lane < (blockDim.x / 32.f)) ? shared[lane] : -1e20f; + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; + val = (lane < numWarps) ? shared[lane] : -1e20f; val = warpReduceMax(val); return val; @@ -169,24 +191,27 @@ template __inline__ __device__ T warpReduceMin(T val) { #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val = min(val, __shfl_xor_sync(0xffffffff, val, mask, 32)); + for (int mask = warpSize >> 1; mask > 0; mask >>= 1) { + T other = shfl_xor(val, mask); + val = val < other ? val : other; + } return val; } /* Calculate the minimum of all elements in a block */ template __inline__ __device__ T blockReduceMin(T val) { - static __shared__ T shared[32]; - int lane = threadIdx.x & 0x1f; // in-warp idx - int wid = threadIdx.x >> 5; // warp idx + // static __shared__ T shared[warpSize]; + __shared__ T shared[warpSize]; + int lane = threadIdx.x & (warpSize - 1); // in-warp idx + int wid = threadIdx.x / warpSize; // warp idx val = warpReduceMin(val); // get minx in each warp if (lane == 0) // record in-warp minx by warp Idx shared[wid] = val; __syncthreads(); - // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent - // blockDim.x is not divided by 32 - val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : 1e20f; + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; + val = (lane < numWarps) ? shared[lane] : 1e20f; val = warpReduceMin(val); return val; } diff --git a/csrc/reduction_utils.h b/csrc/reduction_utils.h new file mode 100644 index 00000000..d63d19c4 --- /dev/null +++ b/csrc/reduction_utils.h @@ -0,0 +1,163 @@ +#pragma once +#include +#include +#include + +namespace vllm { + +__device__ __forceinline__ unsigned full_mask() { + return 0xFFFFFFFFu; +} + +template +__device__ __forceinline__ T shfl_xor(T v, int laneMask) { +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) + return __shfl_xor(v, laneMask, warpSize); +#else + return __shfl_xor_sync(full_mask(), v, laneMask, warpSize); +#endif +} + +template +__device__ __forceinline__ T warpReduceSum(T val) { + + for (int offset = warpSize >> 1; offset > 0; offset >>= 1) { + val += shfl_xor(val, offset); + } + return val; +} + +template +__device__ __forceinline__ void warpReduceSumV2(T* val) { +#pragma unroll + for (int i = 0; i < NUM; ++i) { + for (int offset = warpSize >> 1; offset > 0; offset >>= 1) { + val[i] += shfl_xor(val[i], offset); + } + } +} + +template +__device__ __forceinline__ T blockReduceSum(T val) { + __shared__ T shared[64]; // 64 >= max(warpSize) for NV(32)/AMD(64) + const int lane = threadIdx.x & (warpSize - 1); + const int wid = threadIdx.x / warpSize; + + val = warpReduceSum(val); + if (lane == 0) shared[wid] = val; + __syncthreads(); + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; + T agg = (lane < numWarps) ? shared[lane] : T(0); + agg = warpReduceSum(agg); + return agg; +} + +template +__device__ __forceinline__ T blockAllReduceSum(T val) { + __shared__ T shared[64]; + const int lane = threadIdx.x & (warpSize - 1); + const int wid = threadIdx.x / warpSize; + + val = warpReduceSum(val); + if (lane == 0) shared[wid] = val; + __syncthreads(); + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; + T agg = (lane < numWarps) ? shared[lane] : T(0); + agg = warpReduceSum(agg); + return agg; +} + +template +__device__ __forceinline__ void blockReduceSumV2(T* val) { + + __shared__ T shared[NUM][65]; + const int lane = threadIdx.x & (warpSize - 1); + const int wid = threadIdx.x / warpSize; + + warpReduceSumV2(val); + + if (lane == 0) { +#pragma unroll + for (int i = 0; i < NUM; ++i) shared[i][wid] = val[i]; + } + __syncthreads(); + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; +#pragma unroll + for (int i = 0; i < NUM; ++i) { + T tmp = (lane < numWarps) ? shared[i][lane] : T(0); + val[i] = tmp; + } + warpReduceSumV2(val); +} + +template +__device__ __forceinline__ T warpReduceMax(T val) { + for (int offset = warpSize >> 1; offset > 0; offset >>= 1) { + T other = shfl_xor(val, offset); + val = val > other ? val : other; + } + return val; +} + +template +__device__ __forceinline__ T blockReduceMax(T val) { + __shared__ T shared[64]; + const int lane = threadIdx.x & (warpSize - 1); + const int wid = threadIdx.x / warpSize; + + val = warpReduceMax(val); + if (lane == 0) shared[wid] = val; + __syncthreads(); + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; + + T agg = (lane < numWarps) ? shared[lane] : T(-1e20); + agg = warpReduceMax(agg); + return agg; +} + +template +__device__ __forceinline__ T blockAllReduceMax(T val) { + __shared__ T shared[64]; + const int lane = threadIdx.x & (warpSize - 1); + const int wid = threadIdx.x / warpSize; + + val = warpReduceMax(val); + if (lane == 0) shared[wid] = val; + __syncthreads(); + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; + T agg = (lane < numWarps) ? shared[lane] : T(-1e20); + agg = warpReduceMax(agg); + return agg; +} + +template +__device__ __forceinline__ T warpReduceMin(T val) { + for (int offset = warpSize >> 1; offset > 0; offset >>= 1) { + T other = shfl_xor(val, offset); + val = val < other ? val : other; + } + return val; +} + +template +__device__ __forceinline__ T blockReduceMin(T val) { + __shared__ T shared[64]; + const int lane = threadIdx.x & (warpSize - 1); + const int wid = threadIdx.x / warpSize; + + val = warpReduceMin(val); + if (lane == 0) shared[wid] = val; + __syncthreads(); + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; + T agg = (lane < numWarps) ? shared[lane] : T(1e20); + agg = warpReduceMin(agg); + return agg; +} + +} // namespace vllm diff --git a/csrc/reduction_utils_hip.cuh b/csrc/reduction_utils_hip.cuh new file mode 100644 index 00000000..b4c93936 --- /dev/null +++ b/csrc/reduction_utils_hip.cuh @@ -0,0 +1,220 @@ +// !!! This is a file automatically generated by hipify!!! +/* + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Modifications copyright (c) 2024 by SageAttention team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) +#include +#else +#include +#endif + +#define FINAL_MASK 0xffffffff + +namespace vllm { + +template +__device__ __forceinline__ T shfl_xor(T val, int laneMask) { +#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__) + return __shfl_xor(val, laneMask, warpSize); +#else + return __shfl_xor_sync(FINAL_MASK, val, laneMask, warpSize); +#endif +} + +template +__inline__ __device__ T warpReduceSum(T val) { +#pragma unroll + for (int mask = warpSize >> 1; mask > 0; mask >>= 1) + val += shfl_xor(val, mask); + return val; +} + +template +__inline__ __device__ T warpReduceSumV2(T* val) +{ +#pragma unroll + for (int i = 0; i < NUM; i++) + { +#pragma unroll + for (int mask = warpSize >> 1; mask > 0; mask >>= 1) + val[i] += shfl_xor(val[i], mask); + } + return (T) (0.0f); +} + +/* Calculate the sum of all elements in a block */ +template +__inline__ __device__ T blockReduceSum(T val) { + // static __shared__ T shared[warpSize]; + __shared__ T shared[warpSize]; + int lane = threadIdx.x & (warpSize - 1); + int wid = threadIdx.x / warpSize; + + val = warpReduceSum(val); + + if (lane == 0) + shared[wid] = val; + + __syncthreads(); + + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; + val = (lane < numWarps) ? shared[lane] : (T)(0.0f); + val = warpReduceSum(val); + return val; +} + +/* Calculate the sum of all elements in a block */ +template +__inline__ __device__ T blockAllReduceSum(T val) { + // static __shared__ T shared[warpSize]; + __shared__ T shared[warpSize]; + int lane = threadIdx.x & (warpSize - 1); + int wid = threadIdx.x / warpSize; + + val = warpReduceSum(val); + + if (lane == 0) + shared[wid] = val; + + __syncthreads(); + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; + val = (lane < numWarps) ? shared[lane] : (T)(0.0f); + val = warpReduceSum(val); + return val; +} + +template +__inline__ __device__ T blockReduceSumV2(T* val) +{ + // static __shared__ T shared[NUM][warpSize + 1]; + __shared__ T shared[NUM][warpSize + 1]; + int lane = threadIdx.x & (warpSize - 1); + int wid = threadIdx.x / warpSize; + + warpReduceSumV2(val); + + if (lane == 0) + { +#pragma unroll + for (int i = 0; i < NUM; i++) + { + shared[i][wid] = val[i]; + } + } + + __syncthreads(); + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; +#pragma unroll + for (int i = 0; i < NUM; i++) + { + T tmp = (lane < numWarps) ? shared[i][lane] : T(0); + val[i] = tmp; + } + warpReduceSumV2(val); + return (T) 0.0f; +} + +template +__inline__ __device__ T warpReduceMax(T val) +{ +#pragma unroll + for (int mask = warpSize >> 1; mask > 0; mask >>= 1) { + T other = shfl_xor(val, mask); + val = val > other ? val : other; + } + return val; +} +/* Calculate the maximum of all elements in a block */ +template +__inline__ __device__ T blockReduceMax(T val) +{ + // static __shared__ T shared[warpSize]; + __shared__ T shared[warpSize]; + int lane = threadIdx.x & (warpSize - 1); // in-warp idx + int wid = threadIdx.x / warpSize; // warp idx + val = warpReduceMax(val); // get maxx in each warp + if (lane == 0) // record in-warp maxx by warp Idx + shared[wid] = val; + __syncthreads(); + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; + val = (lane < numWarps) ? shared[lane] : -1e20f; + val = warpReduceMax(val); + return val; +} + +/* Calculate the maximum of all elements in a block */ +template +__inline__ __device__ T blockAllReduceMax(T val) +{ + // static __shared__ T shared[warpSize]; + __shared__ T shared[warpSize]; + int lane = threadIdx.x & (warpSize - 1); // in-warp idx + int wid = threadIdx.x / warpSize; // warp idx + + val = warpReduceMax(val); // get maxx in each warp + + if (lane == 0) // record in-warp maxx by warp Idx + shared[wid] = val; + + __syncthreads(); + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; + val = (lane < numWarps) ? shared[lane] : -1e20f; + val = warpReduceMax(val); + + return val; +} + +template +__inline__ __device__ T warpReduceMin(T val) +{ +#pragma unroll + for (int mask = warpSize >> 1; mask > 0; mask >>= 1) { + T other = shfl_xor(val, mask); + val = val < other ? val : other; + } + return val; +} +/* Calculate the minimum of all elements in a block */ +template +__inline__ __device__ T blockReduceMin(T val) +{ + // static __shared__ T shared[warpSize]; + __shared__ T shared[warpSize]; + int lane = threadIdx.x & (warpSize - 1); // in-warp idx + int wid = threadIdx.x / warpSize; // warp idx + val = warpReduceMin(val); // get minx in each warp + if (lane == 0) // record in-warp minx by warp Idx + shared[wid] = val; + __syncthreads(); + + const int numWarps = (blockDim.x + warpSize - 1) / warpSize; + val = (lane < numWarps) ? shared[lane] : 1e20f; + val = warpReduceMin(val); + return val; +} + +} // namespace vllm diff --git a/example/cogvideox_infer.py b/example/cogvideox_infer.py index 42c6269c..9aa273ad 100644 --- a/example/cogvideox_infer.py +++ b/example/cogvideox_infer.py @@ -7,15 +7,15 @@ from sageattention import sageattn import torch.nn.functional as F -prompt_path = "videos/testing_prompts.txt" +prompt_path = "/home/PR/SageAttention/example/videos/testing_prompts.txt" def parse_args(): parser = argparse.ArgumentParser(description="CogVideoX Inference") parser.add_argument("--model",choices=["cogvideox-2b", "cogvideox1.5-5b"], default="cogvideox-2b", help="CogVideoX model") parser.add_argument('--compile', action='store_true', help='Compile the model') - parser.add_argument('--attention_type', type=str, default='sdpa', choices=['sdpa', 'sage', 'fa3', 'fa3_fp8'], help='Attention type') + parser.add_argument('--attention_type', type=str, default='sage', choices=['sdpa', 'sage', 'fa3', 'fa3_fp8'], help='Attention type') parser.add_argument("--start", type=int, default=0, help="Starting prompt id of this run.") - parser.add_argument("--end", type=int, default=12, help="Ending prompt id of this run.") + parser.add_argument("--end", type=int, default=1, help="Ending prompt id of this run.") args = parser.parse_args() return args @@ -46,7 +46,9 @@ def parse_args(): with open(prompt_path, "r", encoding="utf-8") as file: prompts = file.readlines() selected_prompts = [p.strip() for p in prompts[args.start:args.end]] - + print("cwd =", os.getcwd()) + print("video_dir =", os.path.abspath(video_dir)) + print("num_prompts =", len(prompts), "selected =", len(selected_prompts), "start/end =", args.start, args.end) pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=torch_dtype) if args.compile: diff --git a/example/videos/testing_prompts.txt b/example/videos/testing_prompts.txt index 6c701f0f..79dd9f78 100644 --- a/example/videos/testing_prompts.txt +++ b/example/videos/testing_prompts.txt @@ -1,12 +1 @@ A bustling city street at night, filled with the glow of car headlights and the ambient light of streetlights. The scene is a blur of motion, with cars speeding by and pedestrians navigating the crosswalks. The cityscape is a mix of towering buildings and illuminated signs, creating a vibrant and dynamic atmosphere. The perspective of the video is from a high angle, providing a bird's eye view of the street and its surroundings. The overall style of the video is dynamic and energetic, capturing the essence of urban life at night. -A detailed wooden toy ship with intricately carved masts and sails is seen gliding smoothly over a plush, blue carpet that mimics the waves of the sea. The ship's hull is painted a rich brown, with tiny windows. The carpet, soft and textured, provides a perfect backdrop, resembling an oceanic expanse. Surrounding the ship are various other toys and children's items, hinting at a playful environment. The scene captures the innocence and imagination of childhood, with the toy ship's journey symbolizing endless adventures in a whimsical, indoor setting. -A majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene. The camera angle provides a bird's eye view of the waterfall, allowing viewers to appreciate the full height and grandeur of the waterfall. The video is a stunning representation of nature's power and beauty. -A serene night scene in a forested area. The first frame shows a tranquil lake reflecting the star-filled sky above. The second frame reveals a beautiful sunset, casting a warm glow over the landscape. The third frame showcases the night sky, filled with stars and a vibrant Milky Way galaxy. The video is a time-lapse, capturing the transition from day to night, with the lake and forest serving as a constant backdrop. The style of the video is naturalistic, emphasizing the beauty of the night sky and the peacefulness of the forest. -A serene underwater scene featuring a sea turtle swimming through a coral reef. The turtle, with its greenish-brown shell, is the main focus of the video, swimming gracefully towards the right side of the frame. The coral reef, teeming with life, is visible in the background, providing a vibrant and colorful backdrop to the turtle's journey. Several small fish, darting around the turtle, add a sense of movement and dynamism to the scene. The video is shot from a slightly elevated angle, providing a comprehensive view of the turtle's surroundings. The overall style of the video is calm and peaceful, capturing the beauty and tranquility of the underwater world. -A snowy forest landscape with a dirt road running through it. The road is flanked by trees covered in snow, and the ground is also covered in snow. The sun is shining, creating a bright and serene atmosphere. The road appears to be empty, and there are no people or animals visible in the video. The style of the video is a natural landscape shot, with a focus on the beauty of the snowy forest and the peacefulness of the road. -A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea. Seabirds can be seen taking flight around the cliff's precipices. As the drone slowly moves from different angles, the changing sunlight casts shifting shadows that highlight the rugged textures of the cliff and the surrounding calm sea. The water gently laps at the rock base and the greenery that clings to the top of the cliff, and the scene gives a sense of peaceful isolation at the fringes of the ocean. The video captures the essence of pristine natural beauty untouched by human structures. -A vibrant scene of a snowy mountain landscape. The sky is filled with a multitude of colorful hot air balloons, each floating at different heights, creating a dynamic and lively atmosphere. The balloons are scattered across the sky, some closer to the viewer, others further away, adding depth to the scene. Below, the mountainous terrain is blanketed in a thick layer of snow, with a few patches of bare earth visible here and there. The snow-covered mountains provide a stark contrast to the colorful balloons, enhancing the visual appeal of the scene. In the foreground, a few cars can be seen driving along a winding road that cuts through the mountains. The cars are small compared to the vastness of the landscape, emphasizing the grandeur of the surroundings. The overall style of the video is a mix of adventure and tranquility, with the hot air balloons adding a touch of whimsy to the otherwise serene mountain landscape. The video is likely shot during the day, as the lighting is bright and even, casting soft shadows on the snow-covered mountains. -A vibrant underwater scene. A group of blue fish, with yellow fins, are swimming around a coral reef. The coral reef is a mix of brown and green, providing a natural habitat for the fish. The water is a deep blue, indicating a depth of around 30 feet. The fish are swimming in a circular pattern around the coral reef, indicating a sense of motion and activity. The overall scene is a beautiful representation of marine life. -The camera follows behind a white vintage SUV with a black roof rack as it speeds up a steep dirt road surrounded by pine trees on a steep mountain slope, dust kicks up from its tires, the sunlight shines on the SUV as it speeds along the dirt road, casting a warm glow over the scene. The dirt road curves gently into the distance, with no other cars or vehicles in sight. The trees on either side of the road are redwoods, with patches of greenery scattered throughout. The car is seen from the rear following the curve with ease, making it seem as if it is on a rugged drive through the rugged terrain. The dirt road itself is surrounded by steep hills and mountains, with a clear blue sky above with wispy clouds. -The dynamic movement of tall, wispy grasses swaying in the wind. The sky above is filled with clouds, creating a dramatic backdrop. The sunlight pierces through the clouds, casting a warm glow on the scene. The grasses are a mix of green and brown, indicating a change in seasons. The overall style of the video is naturalistic, capturing the beauty of the landscape in a realistic manner. The focus is on the grasses and their movement, with the sky serving as a secondary element. The video does not contain any human or animal elements. -The vibrant beauty of a sunflower field. The sunflowers, with their bright yellow petals and dark brown centers, are in full bloom, creating a stunning contrast against the green leaves and stems. The sunflowers are arranged in neat rows, creating a sense of order and symmetry. The sun is shining brightly, casting a warm glow on the flowers and highlighting their intricate details. The video is shot from a low angle, looking up at the sunflowers, which adds a sense of grandeur and awe to the scene. The sunflowers are the main focus of the video, with no other objects or people present. The video is a celebration of nature's beauty and the simple joy of a sunny day in the countryside. diff --git a/sageattention/__init__.py b/sageattention/__init__.py index 73b0256d..f0eb0b67 100644 --- a/sageattention/__init__.py +++ b/sageattention/__init__.py @@ -1,3 +1,13 @@ +import torch + +def is_hip() -> bool: + return torch.version.hip is not None + + +def on_gfx942() -> bool: + GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName + return any(arch in GPU_ARCH for arch in ["gfx942"]) + from .core import sageattn, sageattn_varlen from .core import sageattn_qk_int8_pv_fp16_triton from .core import sageattn_qk_int8_pv_fp16_cuda diff --git a/sageattention/core.py b/sageattention/core.py index 96da4c02..336043ce 100644 --- a/sageattention/core.py +++ b/sageattention/core.py @@ -55,6 +55,7 @@ import subprocess import re +from . import is_hip, on_gfx942 def get_cuda_version(): try: @@ -141,7 +142,9 @@ def sageattn( """ arch = get_cuda_arch_versions()[q.device.index] - if arch == "sm80": + if is_hip(): + return sageattn_qk_int8_pv_fp8_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32") + elif arch == "sm80": return sageattn_qk_int8_pv_fp16_cuda(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse, pv_accum_dtype="fp32") elif arch == "sm86": return sageattn_qk_int8_pv_fp16_triton(q, k, v, tensor_layout=tensor_layout, is_causal=is_causal, sm_scale=sm_scale, return_lse=return_lse) @@ -722,7 +725,9 @@ def sageattn_qk_int8_pv_fp8_cuda( """ dtype = q.dtype - assert SM89_ENABLED, "SM89 kernel is not available. Make sure you GPUs with compute capability 8.9." + if not is_hip(): + assert SM89_ENABLED, "SM89 kernel is not available. Make sure you GPUs with compute capability 8.9." + assert q.is_cuda, "Input tensors must be on cuda." assert dtype in [torch.float16, torch.bfloat16], "Input tensors must be in dtype of torch.float16 or torch.bfloat16" assert qk_quant_gran in ["per_warp", "per_thread"], "qk_quant_gran must be either 'per_warp' or 'per_thread'." @@ -802,15 +807,32 @@ def sageattn_qk_int8_pv_fp8_cuda( warnings.warn("pv_accum_dtype is 'fp32+fp16', smooth_v will be ignored.") smooth_v = False - quant_v_scale_max = 448.0 + + if is_hip() and on_gfx942(): + quant_v_scale_max = 224.0 + else: + quant_v_scale_max = 448.0 + if pv_accum_dtype == 'fp32+fp16': quant_v_scale_max = 2.25 - + + if is_hip(): + kv_len = k.size(seq_dim) + v_pad_len = 128 - (kv_len % 128) if kv_len % 128 != 0 else 0 + if v_pad_len > 0: + if tensor_layout == "HND": + v = torch.cat([v, torch.zeros(v.size(0), v.size(1), v_pad_len, v.size(3), dtype=v.dtype, device=v.device)], dim=2) + else: + v = torch.cat([v, torch.zeros(v.size(0), v_pad_len, v.size(2), v.size(3), dtype=v.dtype, device=v.device)], dim=1) + v_fp8, v_scale, vm = per_channel_fp8(v, tensor_layout=tensor_layout, scale_max=quant_v_scale_max, smooth_v=smooth_v) if pv_accum_dtype == "fp32": - if smooth_v: - lse = sm89_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + if is_hip(): + from . import _qattn_rocm + lse = _qattn_rocm.qk_int8_sv_f8_accum_f32_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) + elif smooth_v: + lse = _qattn_sm89.qk_int8_sv_f8_accum_f32_fuse_v_scale_fuse_v_mean_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, vm, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) else: lse = sm89_compile.qk_int8_sv_f8_accum_f32_fuse_v_scale_attn(q_int8, k_int8, v_fp8, o, q_scale, k_scale, v_scale, _tensor_layout, _is_caual, _qk_quant_gran, sm_scale, _return_lse) elif pv_accum_dtype == "fp32+fp32": diff --git a/setup.py b/setup.py index 6b2c5b43..c04adc4a 100644 --- a/setup.py +++ b/setup.py @@ -22,17 +22,102 @@ from packaging.version import parse, Version from setuptools import setup, find_packages - +import torch +from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME # Skip CUDA build in CI or when explicitly requested SKIP_CUDA_BUILD = ( os.getenv("SAGEATTN_SKIP_CUDA_BUILD", "0").upper() in {"1", "TRUE", "YES"} or ("sdist" in sys.argv) ) +try: + import torch + IS_ROCM = getattr(torch.version, "hip", None) is not None +except ImportError: + IS_ROCM = False + torch = None + +def get_rocm_arch(): + try: + # get gfx arch + result = subprocess.run(['rocminfo'], capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError("rocminfo command failed") + + for line in result.stdout.splitlines(): + if "gfx" in line: + return line.split()[1] + except Exception as e: + print(f"Error detecting current architecture: {e}") + + return None + + ext_modules = [] cmdclass = {} -if not SKIP_CUDA_BUILD: +if IS_ROCM: + import torch + from torch.utils.cpp_extension import BuildExtension, CUDAExtension + + ROCM_HOME = os.environ.get("ROCM_HOME", "/opt/rocm") + + CXX_FLAGS = ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"] + ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0 + CXX_FLAGS.append(f"-D_GLIBCXX_USE_CXX11_ABI={ABI}") + + for var in ("HCC_AMDGPU_TARGET", "AMDGPU_TARGETS", "HIPCC_COMPILE_FLAGS_APPEND", "HIP_TARGETS", "ROCM_TARGET_LST"): + if var in os.environ: + del os.environ[var] + + # ROCm compile flag + rocm_arch = get_rocm_arch() + if rocm_arch: + os.environ['ROCM_ARCH'] = rocm_arch + debug = os.environ.get("SA_DEBUG", "0") == "1" + base_flags = ["-std=c++17", f"-D_GLIBCXX_USE_CXX11_ABI={ABI}", "-DUSE_ROCM=1"] + debug_flags = ["-O0", "-g3", "-ggdb", "-fno-inline", "-fno-omit-frame-pointer"] if debug else ["-O3"] + rocm_hipcc = base_flags + debug_flags + [f"--offload-arch={rocm_arch}"] + [f"-D__ROCM_ARCH_{rocm_arch.upper()}"] + rocm_cxx = base_flags + debug_flags + # rocm_hipcc = base_flags + debug_flags + ["--offload-arch=gfx942"] + ["-D__ROCM_ARCH_GFX942"] + + # ROCm lib path + torch_lib = os.path.join(torch.__path__[0], "lib") + rocm_libs = [os.path.join(ROCM_HOME, d) for d in ["lib", "lib64"]] + + # ROCm extension modules + ext_modules.extend([ + CUDAExtension( + "sageattention._qattn_rocm", + sources=[ + "csrc/qattn/rocm/pybind_rocm.cpp", + "csrc/qattn/rocm/launch_sgattn.cu", + "csrc/qattn/rocm/sgattn.cu" + ], + include_dirs=[os.path.join(ROCM_HOME, "include"), os.path.join(ROCM_HOME, "include", "hip")], + extra_compile_args={"cxx": rocm_cxx, "nvcc": rocm_hipcc}, + libraries=["amdhip64", "hiprtc", "rocblas", "hipblas", "c10", "torch", "torch_python"], + library_dirs=rocm_libs + [torch_lib], + runtime_library_dirs=rocm_libs + [torch_lib] + ), + CUDAExtension( + "sageattention._fused", + # sources=["csrc/fused/rocm/pybind_rocm.cpp", "csrc/fused/rocm/fused.cu"], + sources=["csrc/fused/pybind.cpp", "csrc/fused/fused.cu"], + include_dirs=[os.path.join(ROCM_HOME, "include"), os.path.join(ROCM_HOME, "include", "hip")], + extra_compile_args={"cxx": rocm_cxx, "nvcc": rocm_hipcc}, + libraries=["amdhip64", "hiprtc", "rocblas", "hipblas", "c10", "torch", "torch_python"], + library_dirs=rocm_libs + [torch_lib], + runtime_library_dirs=rocm_libs + [torch_lib] + ) + ]) + + cmdclass = {"build_ext": BuildExtension} if ext_modules else {} + print("Current ROCm architecture detected:", rocm_arch) + else: + print("Unable to detect current ROCm architecture.") + +elif not SKIP_CUDA_BUILD: import torch from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME @@ -269,6 +354,29 @@ def compile_new(*args, **kwargs): cmdclass = {"build_ext": BuildExtensionSeparateDir} if ext_modules else {} + def build_extension(self, ext): + with self.build_extension_patch_lock: + if not getattr(self.compiler, "_compile_separate_output_dir", False): + compile_orig = self.compiler.compile + + def compile_new(*args, **kwargs): + return compile_orig(*args, **{ + **kwargs, + "output_dir": os.path.join( + kwargs["output_dir"], + self.thread_ext_name_map[threading.current_thread().ident]), + }) + self.compiler.compile = compile_new + self.compiler._compile_separate_output_dir = True + self.thread_ext_name_map[threading.current_thread().ident] = ext.name + objects = super().build_extension(ext) + return objects + + cmdclass = {"build_ext": BuildExtensionSeparateDir} if ext_modules else {} + +else: + print("Skipping CUDA/ROCm extension build...") + setup( name='sageattention', version='2.2.0',