diff --git a/flashinfer/csrc/pytorch_extension_utils.h b/flashinfer/csrc/pytorch_extension_utils.h
index 129c5b8210..74e896b496 100644
--- a/flashinfer/csrc/pytorch_extension_utils.h
+++ b/flashinfer/csrc/pytorch_extension_utils.h
@@ -30,7 +30,7 @@
 #include
 #endif
 
-#if defined(FLASHINFER_ENABLE_FP8_E4M3) || defined(FLASHINFER_ENABLE_FP8_E5M2)
+#if defined(FLASHINFER_ENABLE_FP8)
 #include
 #endif
 
@@ -46,7 +46,7 @@
 #include
 #endif
 
-#if defined(FLASHINFER_ENABLE_FP8_E4M3) || defined(FLASHINFER_ENABLE_FP8_E5M2)
+#if defined(FLASHINFER_ENABLE_FP8)
 #include
 #endif
 
@@ -95,7 +95,7 @@ using dtype_half = __half;
 #ifdef FLASHINFER_ENABLE_BF16
 using dtype_bfloat16 = __hip_bfloat16;
 #endif
-#if defined(FLASHINFER_ENABLE_FP8_E4M3) || defined(FLASHINFER_ENABLE_FP8_E5M2)
+#if defined(FLASHINFER_ENABLE_FP8)
 using dtype_fp8_e4m3 = __hip_fp8_e4m3_fnuz;
 using dtype_fp8_e5m2 = __hip_fp8_e5m2_fnuz;
 #endif
@@ -106,7 +106,7 @@ using dtype_half = nv_half;
 #ifdef FLASHINFER_ENABLE_BF16
 using dtype_bfloat16 = nv_bfloat16;
 #endif
-#if defined(FLASHINFER_ENABLE_FP8_E4M3) || defined(FLASHINFER_ENABLE_FP8_E5M2)
+#if defined(FLASHINFER_ENABLE_FP8)
 using dtype_fp8_e4m3 = nv_fp8_e4m3;
 using dtype_fp8_e5m2 = nv_fp8_e5m2;
 #endif
@@ -134,7 +134,7 @@ using dtype_fp8_e5m2 = nv_fp8_e5m2;
 
 #ifdef FLASHINFER_ENABLE_FP8_E4M3
 #define _DISPATCH_CASE_FP8_E4M3(c_type, ...) \
-  case at::ScalarType::Float8_e4m3fn: { \
+  case at::ScalarType::Float8_e4m3fnuz: { \
     using c_type = dtype_fp8_e4m3; \
     return __VA_ARGS__(); \
   }
@@ -144,7 +144,7 @@ using dtype_fp8_e5m2 = nv_fp8_e5m2;
 
 #ifdef FLASHINFER_ENABLE_FP8_E5M2
 #define _DISPATCH_CASE_FP8_E5M2(c_type, ...) \
-  case at::ScalarType::Float8_e5m2: { \
+  case at::ScalarType::Float8_e5m2fnuz: { \
     using c_type = dtype_fp8_e5m2; \
     return __VA_ARGS__(); \
   }
@@ -281,6 +281,6 @@ inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
 #define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)
 
 inline bool is_float8_tensor(const at::Tensor& tensor) {
-  return tensor.scalar_type() == at::ScalarType::Float8_e4m3fn ||
-         tensor.scalar_type() == at::ScalarType::Float8_e5m2;
+  return tensor.scalar_type() == at::ScalarType::Float8_e4m3fnuz ||
+         tensor.scalar_type() == at::ScalarType::Float8_e5m2fnuz;
 }
diff --git a/flashinfer/jit/core.py b/flashinfer/jit/core.py
index 0147fda2dc..c00a5e5b54 100644
--- a/flashinfer/jit/core.py
+++ b/flashinfer/jit/core.py
@@ -148,6 +148,9 @@ def gen_jit_spec(
         cflags += [
             "--offload-arch=gfx942",
             "-DFLASHINFER_ENABLE_HIP",
+            "-DFLASHINFER_ENABLE_FP8",
+            "-DFLASHINFER_ENABLE_FP8_E4M3",
+            "-DFLASHINFER_ENABLE_FP8_E5M2",
             "-DHIP_ENABLE_WARP_SYNC_BUILTINS=1",
         ]
     cuda_cflags = [
@@ -156,6 +159,7 @@
         "-use_fast_math",
         "-DFLASHINFER_ENABLE_F16",
         "-DFLASHINFER_ENABLE_BF16",
+        "-DFLASHINFER_ENABLE_FP8",
        "-DFLASHINFER_ENABLE_FP8_E4M3",
         "-DFLASHINFER_ENABLE_FP8_E5M2",
     ]
diff --git a/flashinfer/jit/cpp_ext.py b/flashinfer/jit/cpp_ext.py
index 2b450ac0ce..7bac1f3adb 100644
--- a/flashinfer/jit/cpp_ext.py
+++ b/flashinfer/jit/cpp_ext.py
@@ -71,6 +71,7 @@ def generate_ninja_build_for_op(
         "-DHIP_ENABLE_WARP_SYNC_BUILTINS=1",
         "-DFLASHINFER_ENABLE_F16",
         "-DFLASHINFER_ENABLE_BF16",
+        "-DFLASHINFER_ENABLE_FP8",
         "-DFLASHINFER_ENABLE_FP8_E4M3",
         "-DFLASHINFER_ENABLE_FP8_E5M2",
     ]
diff --git a/flashinfer/jit/utils.py b/flashinfer/jit/utils.py
index 0d64e49492..ea95bbd375 100644
--- a/flashinfer/jit/utils.py
+++ b/flashinfer/jit/utils.py
@@ -74,8 +74,8 @@ def wrapper(func, args):
 dtype_map_hip = {
     torch.float16: "__half",
     torch.bfloat16: "__hip_bfloat16",
-    torch.float8_e4m3fn: "__hip_fp8_e4m3_fnuz",
-    torch.float8_e5m2: "__hip_fp8_e5m2_fnuz",
+    torch.float8_e4m3fnuz: "__hip_fp8_e4m3_fnuz",
+    torch.float8_e5m2fnuz: "__hip_fp8_e5m2_fnuz",
     torch.int8: "int8_t",
     torch.uint8: "uint8_t",
     torch.int32: "int32_t",
@@ -87,8 +87,8 @@ def wrapper(func, args):
 filename_safe_dtype_map = {
     torch.float16: "f16",
     torch.bfloat16: "bf16",
-    torch.float8_e4m3fn: "e4m3",
-    torch.float8_e5m2: "e5m2",
+    torch.float8_e4m3fnuz: "e4m3fnuz",
+    torch.float8_e5m2fnuz: "e5m2fnuz",
     torch.int8: "i8",
     torch.uint8: "u8",
     torch.int32: "i32",
diff --git a/libflashinfer/include/flashinfer/attention/generic/decode.cuh b/libflashinfer/include/flashinfer/attention/generic/decode.cuh
index 4244355e4d..db933cd78d 100644
--- a/libflashinfer/include/flashinfer/attention/generic/decode.cuh
+++ b/libflashinfer/include/flashinfer/attention/generic/decode.cuh
@@ -598,11 +598,7 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const Params params) {
  */
 constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeof_dtype) {
   if (group_size == 8U) {
-    if (sizeof_dtype == 1U) {
-      return 256U;  // not enough registers for 512 threads
-    } else {
-      return 512U;
-    }
+    return 512U;
   } else {
     // At 128 threads and 32 threads per warp, the CUDA implementation deploys 4 warps per block.
     // We have 64 threads per wavefront so we use 256 threads
@@ -661,7 +657,7 @@ gpuError_t SingleDecodeWithKVCacheDispatched(Params params, typename Params::DTy
   constexpr uint32_t bdz = num_threads / (bdx * bdy);
 
   // AMD CDNA3 Reduce tile size to accomodate for CDNA3 architecture's hardware threshold.
-  constexpr uint32_t tile_size_per_bdx = (GROUP_SIZE == 1U) ? 2U : 1U;
+  constexpr uint32_t tile_size_per_bdx = (sizeof(DTypeKV) == 1 || GROUP_SIZE == 1) ? 2U : 1U;
   // This has been hard coded to 2U. Previous implementation involved a macro redirection that
   // always resulted in 2U for H100 or CDNA3 architecture. Please take a look at
diff --git a/libflashinfer/utils/conversion_utils.h b/libflashinfer/utils/conversion_utils.h
index 0263722cb3..05828344df 100644
--- a/libflashinfer/utils/conversion_utils.h
+++ b/libflashinfer/utils/conversion_utils.h
@@ -8,6 +8,25 @@
 #include
 #include
 
+namespace {
+
+__host__ __device__ __inline__ __hip_fp8_e5m2fnuz convert_float_to_fp8(
+    float in, __hip_fp8_interpretation_t interpret, __hip_saturation_t sat) {
+  return __hip_cvt_float_to_fp8(in, sat, interpret);
+}
+
+__host__ __device__ __inline__ __hip_fp8_e4m3fnuz convert_float_to_fp8(
+    float in, __hip_fp8_interpretation_t interpret, __hip_saturation_t sat) {
+  return __hip_cvt_float_to_fp8(in, sat, interpret);
+}
+
+__host__ __device__ __inline__ float convert_fp8_to_float(float in,
+                                                           __hip_fp8_interpretation_t interpret) {
+  float hf = __hip_cvt_fp8_to_float(in, interpret);
+  return hf;
+}
+
+}  // namespace
 namespace fi::con {
 template <typename DTypeIn, typename DTypeOut>
 __host__ __device__ __inline__ DTypeOut explicit_casting(DTypeIn value) {
@@ -50,4 +69,56 @@ __host__ __device__ __inline__ __hip_bfloat16 explicit_casting<__hip_bfloat16, _
     __hip_bfloat16 value) {
   return value;
 }
+
+template <>
+__host__ __device__ __inline__ __hip_fp8_e4m3fnuz explicit_casting<float, __hip_fp8_e4m3fnuz>(
+    float value) {
+  return convert_float_to_fp8(value, __HIP_E4M3_FNUZ, __HIP_SATURATE);
+}
+
+template <>
+__host__ __device__ __inline__ float explicit_casting<__hip_fp8_e4m3fnuz, float>(
+    __hip_fp8_e4m3fnuz value) {
+  return convert_fp8_to_float(value, __HIP_E4M3_FNUZ);
+}
+
+template <>
+__host__ __device__ __inline__ __hip_fp8_e4m3fnuz explicit_casting<__half, __hip_fp8_e4m3fnuz>(
+    __half value) {
+  float temp = __half2float(value);
+  return convert_float_to_fp8(temp, __HIP_E4M3_FNUZ, __HIP_SATURATE);
+}
+
+template <>
+__host__ __device__ __inline__ __half explicit_casting<__hip_fp8_e4m3fnuz, __half>(
+    __hip_fp8_e4m3fnuz value) {
+  float temp = convert_fp8_to_float(value, __HIP_E4M3_FNUZ);
+  return __float2half(temp);
+}
+
+template <>
+__host__ __device__ __inline__ __hip_fp8_e5m2fnuz explicit_casting<float, __hip_fp8_e5m2fnuz>(
+    float value) {
+  return convert_float_to_fp8(value, __HIP_E5M2_FNUZ, __HIP_SATURATE);
+}
+
+template <>
+__host__ __device__ __inline__ float explicit_casting<__hip_fp8_e5m2fnuz, float>(
+    __hip_fp8_e5m2fnuz value) {
+  return convert_fp8_to_float(value, __HIP_E5M2_FNUZ);
+}
+
+template <>
+__host__ __device__ __inline__ __hip_fp8_e5m2fnuz explicit_casting<__half, __hip_fp8_e5m2fnuz>(
+    __half value) {
+  float temp = __half2float(value);
+  return convert_float_to_fp8(temp, __HIP_E5M2_FNUZ, __HIP_SATURATE);
+}
+
+template <>
+__host__ __device__ __inline__ __half explicit_casting<__hip_fp8_e5m2fnuz, __half>(
+    __hip_fp8_e5m2fnuz value) {
+  float temp = convert_fp8_to_float(value, __HIP_E5M2_FNUZ);
+  return __float2half(temp);
+}
 } // namespace fi::con
diff --git a/pyproject.toml b/pyproject.toml
index e7044680f3..53f249b885 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -88,9 +88,9 @@ FLASHINFER_BUILD_WHEELS="ON"
 FLASHINFER_AOT_TORCH_EXTS = {env="FLASHINFER_AOT_TORCH_EXTS", default="OFF"}
 FLASHINFER_ENABLE_F16="ON"
 FLASHINFER_ENABLE_BF16="ON"
-FLASHINFER_ENABLE_FP8="OFF"
-FLASHINFER_ENABLE_FP8_E4M3="OFF"
-FLASHINFER_ENABLE_FP8_E5M2="OFF"
+FLASHINFER_ENABLE_FP8="ON"
+FLASHINFER_ENABLE_FP8_E4M3="ON"
+FLASHINFER_ENABLE_FP8_E5M2="ON"
 FLASHINFER_ENABLE_CUDA = {env="FLASHINFER_ENABLE_CUDA", default="OFF"}
 FLASHINFER_ENABLE_HIP = {env="FLASHINFER_ENABLE_HIP", default="ON"}
diff --git a/scripts/run_hip_tests.sh b/scripts/run_hip_tests.sh
index 7061277da8..547fc9b1cf 100755
--- a/scripts/run_hip_tests.sh
+++ b/scripts/run_hip_tests.sh
@@ -4,6 +4,7 @@
 python -m pytest ../tests/test_sliding_window_hip.py \
 ../tests/test_batch_decode_kernels_hip.py \
+../tests/test_batch_decode_kernels_hip_fp8.py \
 ../tests/test_batch_decode_vllm.py \
 ../tests/test_rope.py \
 ../tests/test_page.py \
diff --git a/tests/test_batch_decode_kernels_hip_fp8.py b/tests/test_batch_decode_kernels_hip_fp8.py
new file mode 100644
index 0000000000..4963fe191c
--- /dev/null
+++ b/tests/test_batch_decode_kernels_hip_fp8.py
@@ -0,0 +1,386 @@
+# SPDX-FileCopyrightText : 2023-2055 FlashInfer team.
+# SPDX-FileCopyrightText : 2025 Advanced Micro Devices, Inc.
+# SPDX-License-Identifier : Apache-2.0
+
+import pytest
+import torch
+from jit_utils import jit_decode_attention_func_args
+
+import flashinfer
+
+
+@pytest.fixture(autouse=True, scope="module")
+def warmup_jit():
+    if flashinfer.jit.has_prebuilt_ops:
+        yield
+    else:
+        try:
+            flashinfer.jit.parallel_load_modules(
+                jit_decode_attention_func_args(
+                    [torch.float16],  # q_dtypes
+                    [
+                        torch.float16,
+                        torch.float8_e4m3fnuz,
+                    ],  # kv_dtypes
+                    [64, 128, 256],  # head_dims
+                    [0],  # pos_encoding_modes
+                    [False],  # use_sliding_windows
+                    [False],  # use_logits_soft_caps
+                )
+            )
+        except Exception as e:
+            # abort the test session if warmup fails
+            pytest.exit(str(e))
+        finally:
+            yield
+
+
+@pytest.mark.parametrize("batch_size", [12, 17, 64, 87])
+@pytest.mark.parametrize("kv_len", [54, 97, 512, 1024, 2048, 4096, 8192, 16384, 32768])
+@pytest.mark.parametrize("page_size", [1, 8, 16])
+@pytest.mark.parametrize("num_kv_heads", [4])
+@pytest.mark.parametrize("num_qo_heads", [4, 32])
+@pytest.mark.parametrize("head_dim", [128, 256])
+@pytest.mark.parametrize("kv_layout", ["NHD"])
+@pytest.mark.parametrize("pos_encoding_mode", ["NONE"])
+@pytest.mark.parametrize("logits_soft_cap", [0.0])
+@pytest.mark.parametrize("return_lse", [True])
+@pytest.mark.parametrize("q_dtype", [torch.float16])
+@pytest.mark.parametrize("kv_dtype", [torch.float8_e4m3fnuz])
+@pytest.mark.parametrize("contiguous_kv", [True])
+def test_batch_decode_with_paged_kv_cache(
+    batch_size,
+    kv_len,
+    page_size,
+    num_kv_heads,
+    num_qo_heads,
+    head_dim,
+    kv_layout,
+    pos_encoding_mode,
+    logits_soft_cap,
+    return_lse,
+    q_dtype,
+    kv_dtype,
+    contiguous_kv,
+):
+    q = torch.randn(batch_size, num_qo_heads, head_dim, device="cuda:0", dtype=q_dtype)
+    num_pages_per_seq = (kv_len + page_size - 1) // page_size
+    total_num_pages = num_pages_per_seq * batch_size
+    if kv_layout == "HND":
+        kv_shape = [total_num_pages, 2, num_kv_heads, page_size, head_dim]
+    else:
+        kv_shape = [total_num_pages, 2, page_size, num_kv_heads, head_dim]
+    if not contiguous_kv:
+        tmp = [kv_shape[0]]
+        for v in kv_shape[1:]:
+            tmp.append(2)
+            tmp.append(v)
+        kv_shape = tmp
+        kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0")
+        kv_data = kv_data_fp32.to(kv_dtype)
+        kv_data = kv_data[:, 1, :, 1, :, 1, :, 1, :]
+        kv_data_fp32 = kv_data_fp32[:, 1, :, 1, :, 1, :, 1, :]
+        # actual data is stored in non-contiguous memory
+        assert (
+            kv_data.stride(-4)
+            != kv_data.shape[-3] * kv_data.shape[-2] * kv_data.shape[-1]
+        )
+    else:
+        kv_data_fp32 = torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0")
+        kv_data = kv_data_fp32.to(kv_dtype)
+    kv_indptr = (
+        torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32)
+        * num_pages_per_seq
+    )
+    kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32)
+    kv_last_page_len = torch.full(
+        (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0"
+    )
+
+    workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
+    wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
+        workspace_buffer, kv_layout
+    )
+    wrapper.plan(
+        kv_indptr,
+        kv_indices,
+        kv_last_page_len,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim,
+        page_size,
+        logits_soft_cap=logits_soft_cap,
+        pos_encoding_mode=pos_encoding_mode,
+        data_type=kv_dtype,
+        q_data_type=q_dtype,
+    )
+    if return_lse:
+        o, _ = wrapper.run(q, kv_data, return_lse=True)
+    else:
+        o = wrapper.run(q, kv_data)
+
+    for i in range(batch_size):
+        perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3]
+        perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2]
+        qi = q[i]
+        ki = torch.cat(
+            [
+                kv_data_fp32[kv_indptr[i] : kv_indptr[i + 1] - 1, 0]
+                .permute(*perm_dims)
+                .reshape(-1, num_kv_heads, head_dim),
+                (
+                    kv_data_fp32[kv_indptr[i + 1] - 1, 0, :, : kv_last_page_len[i]]
+                    if kv_layout == "HND"
+                    else kv_data_fp32[kv_indptr[i + 1] - 1, 0, : kv_last_page_len[i], :]
+                )
+                .permute(*perm_dims_last)
+                .reshape(-1, num_kv_heads, head_dim),
+            ],
+            dim=0,
+        ).to(kv_dtype)
+        vi = torch.cat(
+            [
+                kv_data_fp32[kv_indptr[i] : kv_indptr[i + 1] - 1, 1]
+                .permute(*perm_dims)
+                .reshape(-1, num_kv_heads, head_dim),
+                (
+                    kv_data_fp32[kv_indptr[i + 1] - 1, 1, :, : kv_last_page_len[i]]
+                    if kv_layout == "HND"
+                    else kv_data_fp32[kv_indptr[i + 1] - 1, 1, : kv_last_page_len[i], :]
+                )
+                .permute(*perm_dims_last)
+                .reshape(-1, num_kv_heads, head_dim),
+            ],
+            dim=0,
+        ).to(kv_dtype)
+        o_ref_i = flashinfer.decode.single_decode_with_kv_cache(
+            qi,
+            ki,
+            vi,
+            pos_encoding_mode=pos_encoding_mode,
+            logits_soft_cap=logits_soft_cap,
+        )
+        torch.testing.assert_close(o[i], o_ref_i, rtol=1e-3, atol=1e-3)
+
+    # test user-allocated output
+    o_buffer = torch.empty_like(o)
+    wrapper.run(q, kv_data, out=o_buffer)
+    torch.testing.assert_close(o, o_buffer, rtol=1e-3, atol=1e-3)
+
+
+@pytest.mark.parametrize("batch_size", [12, 17, 64, 87])
+@pytest.mark.parametrize("kv_len", [54, 97, 512, 1024, 2048, 4096, 8192, 16384, 32768])
+@pytest.mark.parametrize("page_size", [1, 8, 16])
+@pytest.mark.parametrize("num_kv_heads", [4])
+@pytest.mark.parametrize("num_qo_heads", [4, 32])
+@pytest.mark.parametrize("head_dim", [128, 256])
+@pytest.mark.parametrize("kv_layout", ["NHD"])
+@pytest.mark.parametrize("pos_encoding_mode", ["NONE"])
+@pytest.mark.parametrize("logits_soft_cap", [0.0])
+@pytest.mark.parametrize("return_lse", [True])
+@pytest.mark.parametrize("q_dtype", [torch.float16])
+@pytest.mark.parametrize("kv_dtype", [torch.float8_e4m3fnuz])
+@pytest.mark.parametrize("contiguous_kv", [True])
+def test_batch_decode_with_tuple_paged_kv_cache(
+    batch_size,
+    kv_len,
+    page_size,
+    num_kv_heads,
+    num_qo_heads,
+    head_dim,
+    kv_layout,
+    pos_encoding_mode,
+    logits_soft_cap,
+    return_lse,
+    q_dtype,
+    kv_dtype,
+    contiguous_kv,
+):
+    q = torch.randn(batch_size, num_qo_heads, head_dim, device="cuda:0", dtype=q_dtype)
+    num_pages_per_seq = (kv_len + page_size - 1) // page_size
+    total_num_pages = num_pages_per_seq * batch_size
+    if kv_layout == "HND":
+        kv_shape = [total_num_pages, num_kv_heads, page_size, head_dim]
+    else:
+        kv_shape = [total_num_pages, page_size, num_kv_heads, head_dim]
+    if not contiguous_kv:
+        tmp = [kv_shape[0]]
+        for v in kv_shape[1:]:
+            tmp.append(2)
+            tmp.append(v)
+        kv_shape = tmp
+        kv_data_fp32 = [
+            torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0")
+            for _ in range(2)
+        ]
+        kv_data = [kv_data_fp32[i].to(kv_dtype) for i in range(2)]
+        for i in range(2):
+            kv_data_fp32[i] = kv_data_fp32[i][:, 1, :, 1, :, 1, :]
+            kv_data[i] = kv_data[i][:, 1, :, 1, :, 1, :]
+            # actual data is stored in non-contiguous memory
+            assert (
+                kv_data[i].stride(-4)
+                != kv_data[i].shape[-3] * kv_data[i].shape[-2] * kv_data[i].shape[-1]
+            )
+    else:
+        kv_data_fp32 = [
+            torch.randn(*kv_shape, dtype=torch.float32, device="cuda:0")
+            for _ in range(2)
+        ]
+        kv_data = [kv_data_fp32[i].to(kv_dtype) for i in range(2)]
+    kv_data = tuple(kv_data)
+    kv_indptr = (
+        torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32)
+        * num_pages_per_seq
+    )
+    kv_indices = torch.arange(0, total_num_pages, device="cuda:0", dtype=torch.int32)
+    kv_last_page_len = torch.full(
+        (batch_size,), (kv_len - 1) % page_size + 1, dtype=torch.int32, device="cuda:0"
+    )
+
+    workspace_buffer = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
+    wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(
+        workspace_buffer, kv_layout
+    )
+    wrapper.plan(
+        kv_indptr,
+        kv_indices,
+        kv_last_page_len,
+        num_qo_heads,
+        num_kv_heads,
+        head_dim,
+        page_size,
+        logits_soft_cap=logits_soft_cap,
+        pos_encoding_mode=pos_encoding_mode,
+        data_type=kv_dtype,
+        q_data_type=q_dtype,
+    )
+    if return_lse:
+        o, _ = wrapper.run(q, kv_data, return_lse=True)
+    else:
+        o = wrapper.run(q, kv_data)
+
+    k_cache, v_cache = kv_data_fp32
+    for i in range(batch_size):
+        perm_dims = [0, 2, 1, 3] if kv_layout == "HND" else [0, 1, 2, 3]
+        perm_dims_last = [1, 0, 2] if kv_layout == "HND" else [0, 1, 2]
+        qi = q[i]
+        ki = torch.cat(
+            [
+                k_cache[kv_indptr[i] : kv_indptr[i + 1] - 1]
+                .permute(*perm_dims)
+                .reshape(-1, num_kv_heads, head_dim),
+                (
+                    k_cache[kv_indptr[i + 1] - 1, :, : kv_last_page_len[i]]
+                    if kv_layout == "HND"
+                    else k_cache[kv_indptr[i + 1] - 1, : kv_last_page_len[i], :]
+                )
+                .permute(*perm_dims_last)
+                .reshape(-1, num_kv_heads, head_dim),
+            ],
+            dim=0,
+        ).to(kv_dtype)
+        vi = torch.cat(
+            [
+                v_cache[kv_indptr[i] : kv_indptr[i + 1] - 1]
+                .to(torch.float32)  # torch.cat does not support some fp8 types
+                .permute(*perm_dims)
+                .reshape(-1, num_kv_heads, head_dim),
+                (
+                    v_cache[kv_indptr[i + 1] - 1, :, : kv_last_page_len[i]]
+                    if kv_layout == "HND"
+                    else v_cache[kv_indptr[i + 1] - 1, : kv_last_page_len[i], :]
+                )
+                .permute(*perm_dims_last)
+                .reshape(-1, num_kv_heads, head_dim),
+            ],
+            dim=0,
+        ).to(kv_dtype)
+        o_ref_i = flashinfer.decode.single_decode_with_kv_cache(
+            qi,
+            ki,
+            vi,
+            pos_encoding_mode=pos_encoding_mode,
+            logits_soft_cap=logits_soft_cap,
+        )
+        torch.testing.assert_close(o[i], o_ref_i, rtol=1e-3, atol=1e-3)
+
+
+if __name__ == "__main__":
+    test_batch_decode_with_paged_kv_cache(
+        256,
+        54,
+        8,
+        8,
+        8,
+        128,
+        "NHD",
+        "NONE",
+        0.0,
+        False,
+        torch.float16,
+        torch.float8_e4m3fnuz,
+        True,
+    )
+    test_batch_decode_with_tuple_paged_kv_cache(
+        256,
+        54,
+        8,
+        8,
+        8,
+        128,
+        "NHD",
+        "NONE",
+        0.0,
+        False,
+        torch.float16,
+        torch.float8_e4m3fnuz,
+        True,
+    )
+    test_batch_decode_with_paged_kv_cache(
+        12,
+        2048,
+        8,
+        8,
+        8,
+        128,
+        "NHD",
+        "NONE",
+        0.0,
+        False,
+        torch.float16,
+        torch.float8_e4m3fnuz,
+        True,
+    )
+
+    test_batch_decode_with_paged_kv_cache(
+        12,
+        54,
+        1,
+        8,
+        8,
+        128,
+        "HND",
+        "NONE",
+        0.0,
+        True,
+        torch.float16,
+        torch.float8_e5m2fnuz,
+        True,
+    )
+
+    test_batch_decode_with_paged_kv_cache(
+        12,
+        54,
+        1,
+        8,
+        8,
+        128,
+        "HND",
+        "NONE",
+        0.0,
+        True,
+        torch.float16,
+        torch.float8_e5m2fnuz,
+        True,
+    )
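
Usage sketch (illustrative only, not part of the patch): the snippet below mirrors the setup in
tests/test_batch_decode_kernels_hip_fp8.py and shows how the FP8 paged-KV decode path would be
driven on a ROCm build, where the KV cache is stored as torch.float8_e4m3fnuz (mapped to
__hip_fp8_e4m3_fnuz by dtype_map_hip). All shapes and sizes are arbitrary example values.

    # Hypothetical driver for FP8 (e4m3fnuz) paged-KV batch decode on a gfx942 GPU.
    import torch
    import flashinfer

    batch_size, page_size, pages_per_seq = 4, 16, 8
    num_qo_heads, num_kv_heads, head_dim = 32, 4, 128
    total_pages = batch_size * pages_per_seq

    # fp16 queries; fp8 KV cache in NHD layout: [pages, 2 (K/V), page_size, kv_heads, head_dim]
    q = torch.randn(batch_size, num_qo_heads, head_dim, device="cuda:0", dtype=torch.float16)
    kv_data = torch.randn(
        total_pages, 2, page_size, num_kv_heads, head_dim, device="cuda:0"
    ).to(torch.float8_e4m3fnuz)

    # One contiguous run of full pages per sequence; last page fully occupied.
    kv_indptr = (
        torch.arange(0, batch_size + 1, device="cuda:0", dtype=torch.int32) * pages_per_seq
    )
    kv_indices = torch.arange(0, total_pages, device="cuda:0", dtype=torch.int32)
    kv_last_page_len = torch.full((batch_size,), page_size, dtype=torch.int32, device="cuda:0")

    workspace = torch.empty(32 * 1024 * 1024, dtype=torch.int8, device="cuda:0")
    wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")
    wrapper.plan(
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        data_type=torch.float8_e4m3fnuz,  # KV-cache dtype
        q_data_type=torch.float16,
    )
    o = wrapper.run(q, kv_data)  # attention output of shape [batch_size, num_qo_heads, head_dim]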