Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions flashinfer/csrc/pytorch_extension_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
#include <hip/hip_fp16.h>
#endif

#if defined(FLASHINFER_ENABLE_FP8_E4M3) || defined(FLASHINFER_ENABLE_FP8_E5M2)
#if defined(FLASHINFER_ENABLE_FP8)
#include <hip/hip_fp8.h>
#endif

Expand All @@ -46,7 +46,7 @@
#include <cuda_fp16.h>
#endif

#if defined(FLASHINFER_ENABLE_FP8_E4M3) || defined(FLASHINFER_ENABLE_FP8_E5M2)
#if defined(FLASHINFER_ENABLE_FP8)
#include <cuda_fp8.h>
#endif

Expand Down Expand Up @@ -95,7 +95,7 @@ using dtype_half = __half;
#ifdef FLASHINFER_ENABLE_BF16
using dtype_bfloat16 = __hip_bfloat16;
#endif
#if defined(FLASHINFER_ENABLE_FP8_E4M3) || defined(FLASHINFER_ENABLE_FP8_E5M2)
#if defined(FLASHINFER_ENABLE_FP8)
using dtype_fp8_e4m3 = __hip_fp8_e4m3_fnuz;
using dtype_fp8_e5m2 = __hip_fp8_e5m2_fnuz;
#endif
Expand All @@ -106,7 +106,7 @@ using dtype_half = nv_half;
#ifdef FLASHINFER_ENABLE_BF16
using dtype_bfloat16 = nv_bfloat16;
#endif
#if defined(FLASHINFER_ENABLE_FP8_E4M3) || defined(FLASHINFER_ENABLE_FP8_E5M2)
#if defined(FLASHINFER_ENABLE_FP8)
using dtype_fp8_e4m3 = nv_fp8_e4m3;
using dtype_fp8_e5m2 = nv_fp8_e5m2;
#endif
Expand Down Expand Up @@ -134,7 +134,7 @@ using dtype_fp8_e5m2 = nv_fp8_e5m2;

#ifdef FLASHINFER_ENABLE_FP8_E4M3
// Dispatch case binding `c_type` to dtype_fp8_e4m3 for FP8 E4M3 tensors.
// Matches the FNUZ scalar type used by ROCm PyTorch. The scraped diff kept
// both the old (Float8_e4m3fn) and new case labels, leaving the macro with
// unbalanced braces; only the merged FNUZ case belongs here.
// NOTE(review): Float8_e4m3fnuz is the HIP/ROCm FP8 format; confirm CUDA
// builds (where dtype_fp8_e4m3 is nv_fp8_e4m3) are not routed through this
// case, since their tensors carry Float8_e4m3fn.
#define _DISPATCH_CASE_FP8_E4M3(c_type, ...) \
  case at::ScalarType::Float8_e4m3fnuz: {    \
    using c_type = dtype_fp8_e4m3;           \
    return __VA_ARGS__();                    \
  }
Expand All @@ -144,7 +144,7 @@ using dtype_fp8_e5m2 = nv_fp8_e5m2;

#ifdef FLASHINFER_ENABLE_FP8_E5M2
// Dispatch case binding `c_type` to dtype_fp8_e5m2 for FP8 E5M2 tensors.
// Matches the FNUZ scalar type used by ROCm PyTorch. The scraped diff kept
// both the old (Float8_e5m2) and new case labels, leaving the macro with
// unbalanced braces; only the merged FNUZ case belongs here.
// NOTE(review): Float8_e5m2fnuz is the HIP/ROCm FP8 format; confirm CUDA
// builds (where dtype_fp8_e5m2 is nv_fp8_e5m2) are not routed through this
// case, since their tensors carry Float8_e5m2.
#define _DISPATCH_CASE_FP8_E5M2(c_type, ...) \
  case at::ScalarType::Float8_e5m2fnuz: {    \
    using c_type = dtype_fp8_e5m2;           \
    return __VA_ARGS__();                    \
  }
Expand Down Expand Up @@ -281,6 +281,6 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)

// Returns true when `tensor` holds one of the FP8 dtypes this extension
// dispatches on. Uses the *fnuz scalar types to match the ROCm/HIP FP8
// formats selected by the dispatch macros in this header.
// The scraped diff retained both the pre-change (Float8_e4m3fn/Float8_e5m2)
// and post-change return statements; the first executed and the merged fnuz
// check was unreachable. Only the merged behavior is kept.
// NOTE(review): on CUDA builds tensors carry Float8_e4m3fn / Float8_e5m2 —
// confirm this helper is only consulted on the HIP path.
inline bool is_float8_tensor(const at::Tensor& tensor) {
  return tensor.scalar_type() == at::ScalarType::Float8_e4m3fnuz ||
         tensor.scalar_type() == at::ScalarType::Float8_e5m2fnuz;
}
4 changes: 4 additions & 0 deletions flashinfer/jit/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ def gen_jit_spec(
cflags += [
"--offload-arch=gfx942",
"-DFLASHINFER_ENABLE_HIP",
"-DFLASHINFER_ENABLE_FP8",
"-DFLASHINFER_ENABLE_FP8_E4M3",
"-DFLASHINFER_ENABLE_FP8_E5M2",
"-DHIP_ENABLE_WARP_SYNC_BUILTINS=1",
]
cuda_cflags = [
Expand All @@ -156,6 +159,7 @@ def gen_jit_spec(
"-use_fast_math",
"-DFLASHINFER_ENABLE_F16",
"-DFLASHINFER_ENABLE_BF16",
"-DFLASHINFER_ENABLE_FP8",
"-DFLASHINFER_ENABLE_FP8_E4M3",
"-DFLASHINFER_ENABLE_FP8_E5M2",
]
Expand Down
1 change: 1 addition & 0 deletions flashinfer/jit/cpp_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def generate_ninja_build_for_op(
"-DHIP_ENABLE_WARP_SYNC_BUILTINS=1",
"-DFLASHINFER_ENABLE_F16",
"-DFLASHINFER_ENABLE_BF16",
"-DFLASHINFER_ENABLE_FP8",
"-DFLASHINFER_ENABLE_FP8_E4M3",
"-DFLASHINFER_ENABLE_FP8_E5M2",
]
Expand Down
8 changes: 4 additions & 4 deletions flashinfer/jit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def wrapper(func, args):
dtype_map_hip = {
torch.float16: "__half",
torch.bfloat16: "__hip_bfloat16",
torch.float8_e4m3fn: "__hip_fp8_e4m3_fnuz",
torch.float8_e5m2: "__hip_fp8_e5m2_fnuz",
torch.float8_e4m3fnuz: "__hip_fp8_e4m3_fnuz",
torch.float8_e5m2fnuz: "__hip_fp8_e5m2_fnuz",
torch.int8: "int8_t",
torch.uint8: "uint8_t",
torch.int32: "int32_t",
Expand All @@ -87,8 +87,8 @@ def wrapper(func, args):
filename_safe_dtype_map = {
torch.float16: "f16",
torch.bfloat16: "bf16",
torch.float8_e4m3fn: "e4m3",
torch.float8_e5m2: "e5m2",
torch.float8_e4m3fnuz: "e4m3fnuz",
torch.float8_e5m2fnuz: "e5m2fnuz",
torch.int8: "i8",
torch.uint8: "u8",
torch.int32: "i32",
Expand Down
8 changes: 2 additions & 6 deletions libflashinfer/include/flashinfer/attention/generic/decode.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -598,11 +598,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const Params params) {
*/
constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeof_dtype) {
if (group_size == 8U) {
if (sizeof_dtype == 1U) {
return 256U; // not enough registers for 512 threads
} else {
return 512U;
}
return 512U;
} else {
// At 128 threads and 32 threads per warp, the CUDA implementation deploys 4 warps per block.
// We have 64 threads per wavefront so we use 256 threads
Expand Down Expand Up @@ -661,7 +657,7 @@ gpuError_t SingleDecodeWithKVCacheDispatched(Params params, typename Params::DTy
constexpr uint32_t bdz = num_threads / (bdx * bdy);

// AMD CDNA3 Reduce tile size to accomodate for CDNA3 architecture's hardware threshold.
constexpr uint32_t tile_size_per_bdx = (GROUP_SIZE == 1U) ? 2U : 1U;
constexpr uint32_t tile_size_per_bdx = (sizeof(DTypeKV) == 1 || GROUP_SIZE == 1) ? 2U : 1U;

// This has been hard coded to 2U. Previous implementation involved a macro redirection that
// always resulted in 2U for H100 or CDNA3 architecture. Please take a look at
Expand Down
71 changes: 71 additions & 0 deletions libflashinfer/utils/conversion_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,25 @@
#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>

namespace {

// Raw FP8 bits produced by __hip_cvt_float_to_fp8, wrapped in a proxy.
// C++ cannot overload on return type alone — the two original
// convert_float_to_fp8 overloads differed only in their return type
// (__hip_fp8_e5m2fnuz vs __hip_fp8_e4m3fnuz) with identical parameter
// lists, which is an ill-formed redefinition. The converter therefore
// returns this proxy, and the call site's declared result type selects
// the appropriate conversion operator.
// NOTE(review): assumes the FP8 class types expose their storage byte as
// `__x`, as the hip_fp8.h types do — confirm for these fnuz aliases.
struct fp8_bits_proxy {
  unsigned char bits;  // same representation as __hip_fp8_storage_t

  __host__ __device__ __inline__ operator __hip_fp8_e4m3fnuz() const {
    __hip_fp8_e4m3fnuz out;
    out.__x = bits;
    return out;
  }

  __host__ __device__ __inline__ operator __hip_fp8_e5m2fnuz() const {
    __hip_fp8_e5m2fnuz out;
    out.__x = bits;
    return out;
  }
};

// Convert `in` to FP8 with the given interpretation and saturation mode.
// The concrete FP8 class type is chosen by the assignment/return context
// via fp8_bits_proxy's implicit conversions, so existing call sites
// (e.g. `return convert_float_to_fp8(v, __HIP_E4M3_FNUZ, __HIP_SATURATE);`
// in a function returning __hip_fp8_e4m3fnuz) compile unchanged.
__host__ __device__ __inline__ fp8_bits_proxy convert_float_to_fp8(
    float in, __hip_fp8_interpretation_t interpret, __hip_saturation_t sat) {
  return fp8_bits_proxy{static_cast<unsigned char>(__hip_cvt_float_to_fp8(in, sat, interpret))};
}

// Convert an FP8 value back to float under the given interpretation.
// Deducing the FP8 class type (instead of the original `float` parameter,
// which converted fp8 -> float and then truncated float -> storage byte on
// the way into __hip_cvt_fp8_to_float) feeds the raw bits to the runtime
// conversion helper directly.
template <typename FP8T>
__host__ __device__ __inline__ float convert_fp8_to_float(FP8T in,
                                                          __hip_fp8_interpretation_t interpret) {
  return __hip_cvt_fp8_to_float(in.__x, interpret);
}

}  // namespace
namespace fi::con {
template <typename DTypeIn, typename DTypeOut>
__host__ __device__ __inline__ DTypeOut explicit_casting(DTypeIn value) {
Expand Down Expand Up @@ -50,4 +69,56 @@ __host__ __device__ __inline__ __hip_bfloat16 explicit_casting<__hip_bfloat16, _
__hip_bfloat16 value) {
return value;
}

// ---------------------------------------------------------------------------
// explicit_casting specializations between float/__half and the HIP FP8 FNUZ
// class types. Every conversion round-trips through float via the
// convert_float_to_fp8 / convert_fp8_to_float helpers in the anonymous
// namespace above; float -> fp8 conversions pass __HIP_SATURATE, so
// out-of-range inputs saturate rather than overflow.
// NOTE(review): the float->fp8 helper is selected here purely by the declared
// return type, and the fp8->float helper receives the value through an
// implicit conversion to its parameter type — verify the helper signatures
// resolve as intended for both E4M3 and E5M2 before relying on these.
// ---------------------------------------------------------------------------

// float -> e4m3 (saturating)
template <>
__host__ __device__ __inline__ __hip_fp8_e4m3fnuz explicit_casting<float, __hip_fp8_e4m3fnuz>(
float value) {
return convert_float_to_fp8(value, __HIP_E4M3_FNUZ, __HIP_SATURATE);
}

// e4m3 -> float
template <>
__host__ __device__ __inline__ float explicit_casting<__hip_fp8_e4m3fnuz, float>(
__hip_fp8_e4m3fnuz value) {
return convert_fp8_to_float(value, __HIP_E4M3_FNUZ);
}

// half -> e4m3: widen to float first, then saturating fp8 convert
template <>
__host__ __device__ __inline__ __hip_fp8_e4m3fnuz explicit_casting<__half, __hip_fp8_e4m3fnuz>(
__half value) {
float temp = __half2float(value);
return convert_float_to_fp8(temp, __HIP_E4M3_FNUZ, __HIP_SATURATE);
}

// e4m3 -> half: via float intermediate
template <>
__host__ __device__ __inline__ __half explicit_casting<__hip_fp8_e4m3fnuz, __half>(
__hip_fp8_e4m3fnuz value) {
float temp = convert_fp8_to_float(value, __HIP_E4M3_FNUZ);
return __float2half(temp);
}

// float -> e5m2 (saturating)
template <>
__host__ __device__ __inline__ __hip_fp8_e5m2fnuz explicit_casting<float, __hip_fp8_e5m2fnuz>(
float value) {
return convert_float_to_fp8(value, __HIP_E5M2_FNUZ, __HIP_SATURATE);
}

// e5m2 -> float
template <>
__host__ __device__ __inline__ float explicit_casting<__hip_fp8_e5m2fnuz, float>(
__hip_fp8_e5m2fnuz value) {
return convert_fp8_to_float(value, __HIP_E5M2_FNUZ);
}

// half -> e5m2: widen to float first, then saturating fp8 convert
template <>
__host__ __device__ __inline__ __hip_fp8_e5m2fnuz explicit_casting<__half, __hip_fp8_e5m2fnuz>(
__half value) {
float temp = __half2float(value);
return convert_float_to_fp8(temp, __HIP_E5M2_FNUZ, __HIP_SATURATE);
}

// e5m2 -> half: via float intermediate
template <>
__host__ __device__ __inline__ __half explicit_casting<__hip_fp8_e5m2fnuz, __half>(
__hip_fp8_e5m2fnuz value) {
float temp = convert_fp8_to_float(value, __HIP_E5M2_FNUZ);
return __float2half(temp);
}
} // namespace fi::con
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ FLASHINFER_BUILD_WHEELS="ON"
FLASHINFER_AOT_TORCH_EXTS = {env="FLASHINFER_AOT_TORCH_EXTS", default="OFF"}
FLASHINFER_ENABLE_F16="ON"
FLASHINFER_ENABLE_BF16="ON"
FLASHINFER_ENABLE_FP8="OFF"
FLASHINFER_ENABLE_FP8_E4M3="OFF"
FLASHINFER_ENABLE_FP8_E5M2="OFF"
FLASHINFER_ENABLE_FP8="ON"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Look at Options.cmake setting FLASHINFER_ENABLE_FP8 sets the FLASHINFER_ENABLE_FP8_E4M3 and FLASHINFER_ENABLE_FP8_E5M2 to true. So, we should only use the FLASHINFER_ENABLE_FP8 flag.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can be fixed later.

FLASHINFER_ENABLE_FP8_E4M3="ON"
FLASHINFER_ENABLE_FP8_E5M2="ON"
FLASHINFER_ENABLE_CUDA = {env="FLASHINFER_ENABLE_CUDA", default="OFF"}
FLASHINFER_ENABLE_HIP = {env="FLASHINFER_ENABLE_HIP", default="ON"}

Expand Down
1 change: 1 addition & 0 deletions scripts/run_hip_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

python -m pytest ../tests/test_sliding_window_hip.py \
../tests/test_batch_decode_kernels_hip.py \
../tests/test_batch_decode_kernels_hip_fp8.py \
../tests/test_batch_decode_vllm.py \
../tests/test_rope.py \
../tests/test_page.py \
Expand Down
Loading