diff --git a/README.md b/README.md index 8f9f3a6279..07e4fc6178 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ Run the following command to install micromamba and set it up inside `bash` source ~/.bashrc ``` After installing micromamba a custom environment for FlashInfer+ROCm -development should be setup. The micromamba environment is used to only manage +development should be setup. The micromamba environment is used to only manage the Python version, rest of the dependencies are installed using `pip`. ```bash @@ -100,14 +100,14 @@ Clone the latest trunk from https://github.com/ROCm/flashinfer. git clone https://github.com/ROCm/flashinfer cd flashinfer/ ``` -The Flashinfer+ROCm library can be built in two ways: with ahead-of-time (AOT) +The Flashinfer+ROCm library can be built in two ways: with ahead-of-time (AOT) compiled kernels and without any AOT kernels. -Building the library with AOT kernels will take more time and local disk space -as several common configurations of the core Flashinfer kernels are built +Building the library with AOT kernels will take more time and local disk space +as several common configurations of the core Flashinfer kernels are built during installation. -When building without AOT compilation, every kernel will be just-in-time (JIT) +When building without AOT compilation, every kernel will be just-in-time (JIT) compiled at the time of first use. * Instructions to build with AOT are as follows: diff --git a/aot_build_utils/generate.py b/aot_build_utils/generate.py index 29ce92f56a..431ba48081 100644 --- a/aot_build_utils/generate.py +++ b/aot_build_utils/generate.py @@ -252,11 +252,14 @@ def write_if_different(path: Path, content: str) -> None: f"use_logits_cap_{logits_soft_cap}_" f"f16qk_{bool(use_fp16_qk_reduction)}" ) - final_list = single_decode_uris + batch_decode_uris + single_prefill_uris + batch_prefill_uris - print(final_list) - return ( - final_list + final_list = ( + single_decode_uris + + batch_decode_uris + + single_prefill_uris + + batch_prefill_uris ) + print(final_list) + return final_list if __name__ == "__main__": diff --git a/aot_build_utils/generate_dispatch_inc.py b/aot_build_utils/generate_dispatch_inc.py index 8a0cfa1432..e73a345c5a 100644 --- a/aot_build_utils/generate_dispatch_inc.py +++ b/aot_build_utils/generate_dispatch_inc.py @@ -17,7 +17,11 @@ import argparse from pathlib import Path -from .literal_map import bool_literal, mask_mode_literal, pos_encoding_mode_literal +from .literal_map import ( + bool_literal, + mask_mode_literal, + pos_encoding_mode_literal, +) def get_dispatch_inc_str(args: argparse.Namespace) -> str: diff --git a/aot_build_utils/generate_single_prefill_inst.py b/aot_build_utils/generate_single_prefill_inst.py index 14535c04e6..35da3a1818 100644 --- a/aot_build_utils/generate_single_prefill_inst.py +++ b/aot_build_utils/generate_single_prefill_inst.py @@ -18,7 +18,11 @@ import sys from pathlib import Path -from .literal_map import dtype_literal, mask_mode_literal, pos_encoding_mode_literal +from .literal_map import ( + dtype_literal, + mask_mode_literal, + pos_encoding_mode_literal, +) def get_cu_file_str( diff --git a/aot_build_utils/generate_single_prefill_sm90_inst.py b/aot_build_utils/generate_single_prefill_sm90_inst.py index 291aad8edd..7a531059ca 100644 --- a/aot_build_utils/generate_single_prefill_sm90_inst.py +++ b/aot_build_utils/generate_single_prefill_sm90_inst.py @@ -18,7 +18,11 @@ import sys from pathlib import Path -from .literal_map import dtype_literal, mask_mode_literal, pos_encoding_mode_literal +from .literal_map import ( + dtype_literal, + mask_mode_literal, + pos_encoding_mode_literal, +) def get_cu_file_str( diff --git a/cmake/utils/ConfigurePrebuitUris.cmake b/cmake/utils/ConfigurePrebuitUris.cmake index b6789b2cdb..4b31476356 100644 --- a/cmake/utils/ConfigurePrebuitUris.cmake +++ b/cmake/utils/ConfigurePrebuitUris.cmake @@ -1,19 +1,19 @@ function(flashinfer_configure_prebuilt_uris) - message(STATUS "Configuring prebuilt URIs") - get_property(PREBUILT_URI_LIST GLOBAL PROPERTY FLASHINFER_PREBUILT_URIS) - set(PYTHON_URI "") - message(STATUS "PREBUILT_URI_LIST: ${PREBUILT_URI_LIST}") + message(STATUS "Configuring prebuilt URIs") + get_property(PREBUILT_URI_LIST GLOBAL PROPERTY FLASHINFER_PREBUILT_URIS) + set(PYTHON_URI "") + message(STATUS "PREBUILT_URI_LIST: ${PREBUILT_URI_LIST}") - string(REPLACE ";" "\", \"" list_items "${PREBUILT_URI_LIST}") - set(PYTHON_URI "${list_items}") + string(REPLACE ";" "\", \"" list_items "${PREBUILT_URI_LIST}") + set(PYTHON_URI "${list_items}") - message(STATUS "PYTHON_URI: ${PYTHON_URI}") - - set(TEMPLATE_FILE "${CMAKE_SOURCE_DIR}/templates/__aot_prebuilt_uris__.py.in") - set(OUTPUT_FILE "${CMAKE_BINARY_DIR}/flashinfer/__aot_prebuilt_uris__.py") - set(INSTALL_DIR "flashinfer") + message(STATUS "PYTHON_URI: ${PYTHON_URI}") - configure_file("${TEMPLATE_FILE}" "${OUTPUT_FILE}" @ONLY) + set(TEMPLATE_FILE "${CMAKE_SOURCE_DIR}/templates/__aot_prebuilt_uris__.py.in") + set(OUTPUT_FILE "${CMAKE_BINARY_DIR}/flashinfer/__aot_prebuilt_uris__.py") + set(INSTALL_DIR "flashinfer") - install(FILES "${OUTPUT_FILE}" DESTINATION "${INSTALL_DIR}") + configure_file("${TEMPLATE_FILE}" "${OUTPUT_FILE}" @ONLY) + + install(FILES "${OUTPUT_FILE}" DESTINATION "${INSTALL_DIR}") endfunction() diff --git a/examples/test_batch_decode_example.py b/examples/test_batch_decode_example.py index 9c17356f97..8da21b4c9e 100644 --- a/examples/test_batch_decode_example.py +++ b/examples/test_batch_decode_example.py @@ -1,17 +1,22 @@ import torch + import flashinfer + def verify_tensors(tensor1, tensor2, rtol=1e-3, atol=1e-3): - + for i in range(tensor1.shape[0]): for j in range(tensor1.shape[1]): - if torch.abs(tensor1[i][j] - tensor2[i][j]) > atol + rtol * torch.abs(tensor2[i][j]): + if torch.abs(tensor1[i][j] - tensor2[i][j]) > atol + rtol * torch.abs( + tensor2[i][j] + ): print(f"Error at {i}, {j}") print(f"Expected: {tensor2[i][j]}") print(f"Got: {tensor1[i][j]}") return False return True + def test_batch_decode_with_paged_kv_cache( batch_size, kv_len, @@ -119,10 +124,7 @@ def test_batch_decode_with_paged_kv_cache( dim=0, ).to(kv_dtype) # print(qi.shape, ki.shape, vi.shape) - o_ref_i = flashinfer.single_decode_with_kv_cache( - qi, - ki, - vi) + o_ref_i = flashinfer.single_decode_with_kv_cache(qi, ki, vi) # torch.testing.assert_close(o[i], o_ref_i, rtol=1e-3, atol=1e-3) result += verify_tensors(o[i], o_ref_i, rtol=1e-3, atol=1e-3) @@ -136,14 +138,15 @@ def test_batch_decode_with_paged_kv_cache( else: print("FAIL") + if __name__ == "__main__": batch_size = 256 - page_size = 8 + page_size = 8 # # This configuration works - num_qo_heads = 32 - num_kv_heads = 4 + num_qo_heads = 32 + num_kv_heads = 4 head_dim = 256 kv_len = 512 @@ -152,7 +155,7 @@ def test_batch_decode_with_paged_kv_cache( # num_kv_heads = 8 # head_dim = 128 # kv_len = 54 - + kv_layout = "NHD" pos_encoding_mode = "NONE" logits_soft_cap = 0.0 @@ -160,9 +163,9 @@ def test_batch_decode_with_paged_kv_cache( q_dtype = torch.float16 kv_dtype = torch.float16 contiguous_kv = True - - num_qo_heads = 32 - num_kv_heads = 4 + + num_qo_heads = 32 + num_kv_heads = 4 head_dim = 256 kv_len = 512 test_batch_decode_with_paged_kv_cache( @@ -178,5 +181,5 @@ def test_batch_decode_with_paged_kv_cache( return_lse, q_dtype, kv_dtype, - contiguous_kv) - + contiguous_kv, + ) diff --git a/scripts/run_hip_tests.sh b/scripts/run_hip_tests.sh index 73a17db3e1..7061277da8 100755 --- a/scripts/run_hip_tests.sh +++ b/scripts/run_hip_tests.sh @@ -10,4 +10,3 @@ python -m pytest ../tests/test_sliding_window_hip.py \ ../tests/test_norm_hip.py \ ../tests/test_logits_cap_hip.py \ ../tests/test_non_contiguous_decode_hip.py \ - diff --git a/tests/test_batch_decode_kernels_hip.py b/tests/test_batch_decode_kernels_hip.py index d70b3b47ab..e0db60a548 100644 --- a/tests/test_batch_decode_kernels_hip.py +++ b/tests/test_batch_decode_kernels_hip.py @@ -16,12 +16,11 @@ import pytest import torch -from jit_utils import ( - jit_decode_attention_func_args -) +from jit_utils import jit_decode_attention_func_args import flashinfer + @pytest.fixture(autouse=True, scope="module") def warmup_jit(): if flashinfer.jit.has_prebuilt_ops: @@ -318,6 +317,7 @@ def test_batch_decode_with_tuple_paged_kv_cache( ) torch.testing.assert_close(o[i], o_ref_i, rtol=1e-3, atol=1e-3) + if __name__ == "__main__": test_batch_decode_with_paged_kv_cache( 256, @@ -395,4 +395,3 @@ def test_batch_decode_with_tuple_paged_kv_cache( torch.float16, True, ) - diff --git a/tests/test_logits_cap_hip.py b/tests/test_logits_cap_hip.py index c3a1f475c9..7c43284b29 100644 --- a/tests/test_logits_cap_hip.py +++ b/tests/test_logits_cap_hip.py @@ -18,9 +18,7 @@ import pytest import torch -from jit_utils import ( - jit_decode_attention_func_args -) +from jit_utils import jit_decode_attention_func_args import flashinfer @@ -75,5 +73,6 @@ def test_single_decode_logits_soft_cap( o_ref = attention_logits_soft_cap_torch(q.unsqueeze(0), k, v, soft_cap).squeeze(0) torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) + if __name__ == "__main__": test_single_decode_logits_soft_cap(9, 32, 128, 30.0) diff --git a/tests/test_non_contiguous_decode_hip.py b/tests/test_non_contiguous_decode_hip.py index a526eb48d5..1836c054eb 100644 --- a/tests/test_non_contiguous_decode_hip.py +++ b/tests/test_non_contiguous_decode_hip.py @@ -1,8 +1,6 @@ import pytest import torch -from jit_utils import ( - jit_decode_attention_func_args -) +from jit_utils import jit_decode_attention_func_args import flashinfer diff --git a/tests/test_norm_hip.py b/tests/test_norm_hip.py index 3c5fbf891d..41f4faca8e 100644 --- a/tests/test_norm_hip.py +++ b/tests/test_norm_hip.py @@ -103,7 +103,6 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype, enable_pdl, contiguou x = torch.randn(batch_size, hidden_size * 2, device="cuda").to(dtype) x = x[:, :hidden_size] - residual = torch.randn_like(x) weight = torch.randn(hidden_size, dtype=dtype, device="cuda") diff --git a/tests/test_rope.py b/tests/test_rope.py index dbca4670b4..4e0c40b1cc 100644 --- a/tests/test_rope.py +++ b/tests/test_rope.py @@ -301,7 +301,6 @@ def forward_cuda( (256, 128, 4096, 9231, False, torch.bfloat16, "cuda", 3, 231, 4, 2), ], ) - def test_rope_cos_sin_cache( head_size: int, rotary_dim: int, diff --git a/tests/test_sliding_window_hip.py b/tests/test_sliding_window_hip.py index 0699fe98b7..39fa15d5cd 100644 --- a/tests/test_sliding_window_hip.py +++ b/tests/test_sliding_window_hip.py @@ -16,9 +16,7 @@ import pytest import torch -from jit_utils import ( - jit_decode_attention_func_args -) +from jit_utils import jit_decode_attention_func_args import flashinfer