Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ Run the following command to install micromamba and set it up inside `bash`
source ~/.bashrc
```
After installing micromamba, a custom environment for FlashInfer+ROCm
development should be set up. The micromamba environment is used only to manage
the Python version; the rest of the dependencies are installed using `pip`.

```bash
Expand All @@ -100,14 +100,14 @@ Clone the latest trunk from https://github.com/ROCm/flashinfer.
git clone https://github.com/ROCm/flashinfer
cd flashinfer/
```
The FlashInfer+ROCm library can be built in two ways: with ahead-of-time (AOT)
compiled kernels, or without any AOT kernels.

Building the library with AOT kernels will take more time and local disk space
as several common configurations of the core Flashinfer kernels are built
Building the library with AOT kernels will take more time and local disk space
as several common configurations of the core Flashinfer kernels are built
during installation.

When building without AOT compilation, every kernel will be just-in-time (JIT)
When building without AOT compilation, every kernel will be just-in-time (JIT)
compiled at the time of first use.

* Instructions to build with AOT are as follows:
Expand Down
11 changes: 7 additions & 4 deletions aot_build_utils/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,14 @@ def write_if_different(path: Path, content: str) -> None:
f"use_logits_cap_{logits_soft_cap}_"
f"f16qk_{bool(use_fp16_qk_reduction)}"
)
final_list = single_decode_uris + batch_decode_uris + single_prefill_uris + batch_prefill_uris
print(final_list)
return (
final_list
final_list = (
single_decode_uris
+ batch_decode_uris
+ single_prefill_uris
+ batch_prefill_uris
)
print(final_list)
return final_list


if __name__ == "__main__":
Expand Down
6 changes: 5 additions & 1 deletion aot_build_utils/generate_dispatch_inc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
import argparse
from pathlib import Path

from .literal_map import bool_literal, mask_mode_literal, pos_encoding_mode_literal
from .literal_map import (
bool_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)


def get_dispatch_inc_str(args: argparse.Namespace) -> str:
Expand Down
6 changes: 5 additions & 1 deletion aot_build_utils/generate_single_prefill_inst.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
import sys
from pathlib import Path

from .literal_map import dtype_literal, mask_mode_literal, pos_encoding_mode_literal
from .literal_map import (
dtype_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)


def get_cu_file_str(
Expand Down
6 changes: 5 additions & 1 deletion aot_build_utils/generate_single_prefill_sm90_inst.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
import sys
from pathlib import Path

from .literal_map import dtype_literal, mask_mode_literal, pos_encoding_mode_literal
from .literal_map import (
dtype_literal,
mask_mode_literal,
pos_encoding_mode_literal,
)


def get_cu_file_str(
Expand Down
26 changes: 13 additions & 13 deletions cmake/utils/ConfigurePrebuitUris.cmake
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
# Generate flashinfer/__aot_prebuilt_uris__.py from its template and install it.
#
# Reads the global property FLASHINFER_PREBUILT_URIS (a CMake ;-list), joins
# its items with `", "` into PYTHON_URI, and substitutes that into the
# template via configure_file(@ONLY).
function(flashinfer_configure_prebuilt_uris)
  message(STATUS "Configuring prebuilt URIs")
  get_property(PREBUILT_URI_LIST GLOBAL PROPERTY FLASHINFER_PREBUILT_URIS)
  set(PYTHON_URI "")
  message(STATUS "PREBUILT_URI_LIST: ${PREBUILT_URI_LIST}")

  # Turn the CMake list separator into `", "` so the items can be dropped
  # between quotes in the template (presumably inside a Python list literal;
  # the template itself is not visible here).
  string(REPLACE ";" "\", \"" list_items "${PREBUILT_URI_LIST}")
  set(PYTHON_URI "${list_items}")

  message(STATUS "PYTHON_URI: ${PYTHON_URI}")

  set(TEMPLATE_FILE "${CMAKE_SOURCE_DIR}/templates/__aot_prebuilt_uris__.py.in")
  set(OUTPUT_FILE "${CMAKE_BINARY_DIR}/flashinfer/__aot_prebuilt_uris__.py")
  set(INSTALL_DIR "flashinfer")

  # @ONLY: substitute only @VAR@ references, leaving any ${...} text in the
  # template untouched.
  configure_file("${TEMPLATE_FILE}" "${OUTPUT_FILE}" @ONLY)

  install(FILES "${OUTPUT_FILE}" DESTINATION "${INSTALL_DIR}")
endfunction()
33 changes: 18 additions & 15 deletions examples/test_batch_decode_example.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
import torch

import flashinfer


def verify_tensors(tensor1, tensor2, rtol=1e-3, atol=1e-3):
    """Check that two tensors agree element-wise within a tolerance.

    Element ``(i, j)`` is a mismatch when
    ``|tensor1[i][j] - tensor2[i][j]| > atol + rtol * |tensor2[i][j]|``.
    The first mismatch in row-major order is printed and ends the check.

    Args:
        tensor1: Tensor under test (reported as "Got"). Its first two
            dimensions define the region that is compared.
        tensor2: Reference tensor (reported as "Expected").
        rtol: Relative tolerance, scaled by ``|tensor2|``.
        atol: Absolute tolerance.

    Returns:
        ``True`` if no element is out of tolerance, ``False`` otherwise.
    """
    # Compare only the region spanned by tensor1's first two dimensions.
    expected = tensor2[: tensor1.shape[0], : tensor1.shape[1]]

    # One vectorized comparison instead of a Python-level loop over every
    # element (each scalar comparison would otherwise be its own tensor op).
    mismatch = torch.abs(tensor1 - expected) > atol + rtol * torch.abs(expected)
    bad = torch.nonzero(mismatch)
    if bad.shape[0] == 0:
        return True

    # torch.nonzero returns indices sorted lexicographically, so row 0 is the
    # first mismatch in row-major order.
    i, j = bad[0].tolist()[:2]
    print(f"Error at {i}, {j}")
    print(f"Expected: {tensor2[i][j]}")
    print(f"Got: {tensor1[i][j]}")
    return False


def test_batch_decode_with_paged_kv_cache(
batch_size,
kv_len,
Expand Down Expand Up @@ -119,10 +124,7 @@ def test_batch_decode_with_paged_kv_cache(
dim=0,
).to(kv_dtype)
# print(qi.shape, ki.shape, vi.shape)
o_ref_i = flashinfer.single_decode_with_kv_cache(
qi,
ki,
vi)
o_ref_i = flashinfer.single_decode_with_kv_cache(qi, ki, vi)
# torch.testing.assert_close(o[i], o_ref_i, rtol=1e-3, atol=1e-3)
result += verify_tensors(o[i], o_ref_i, rtol=1e-3, atol=1e-3)

Expand All @@ -136,14 +138,15 @@ def test_batch_decode_with_paged_kv_cache(
else:
print("FAIL")


if __name__ == "__main__":

batch_size = 256
page_size = 8
page_size = 8

# # This configuration works
num_qo_heads = 32
num_kv_heads = 4
num_qo_heads = 32
num_kv_heads = 4
head_dim = 256
kv_len = 512

Expand All @@ -152,17 +155,17 @@ def test_batch_decode_with_paged_kv_cache(
# num_kv_heads = 8
# head_dim = 128
# kv_len = 54

kv_layout = "NHD"
pos_encoding_mode = "NONE"
logits_soft_cap = 0.0
return_lse = False
q_dtype = torch.float16
kv_dtype = torch.float16
contiguous_kv = True
num_qo_heads = 32
num_kv_heads = 4

num_qo_heads = 32
num_kv_heads = 4
head_dim = 256
kv_len = 512
test_batch_decode_with_paged_kv_cache(
Expand All @@ -178,5 +181,5 @@ def test_batch_decode_with_paged_kv_cache(
return_lse,
q_dtype,
kv_dtype,
contiguous_kv)

contiguous_kv,
)
1 change: 0 additions & 1 deletion scripts/run_hip_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,3 @@ python -m pytest ../tests/test_sliding_window_hip.py \
../tests/test_norm_hip.py \
../tests/test_logits_cap_hip.py \
../tests/test_non_contiguous_decode_hip.py \

7 changes: 3 additions & 4 deletions tests/test_batch_decode_kernels_hip.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@

import pytest
import torch
from jit_utils import (
jit_decode_attention_func_args
)
from jit_utils import jit_decode_attention_func_args

import flashinfer


@pytest.fixture(autouse=True, scope="module")
def warmup_jit():
if flashinfer.jit.has_prebuilt_ops:
Expand Down Expand Up @@ -318,6 +317,7 @@ def test_batch_decode_with_tuple_paged_kv_cache(
)
torch.testing.assert_close(o[i], o_ref_i, rtol=1e-3, atol=1e-3)


if __name__ == "__main__":
test_batch_decode_with_paged_kv_cache(
256,
Expand Down Expand Up @@ -395,4 +395,3 @@ def test_batch_decode_with_tuple_paged_kv_cache(
torch.float16,
True,
)

5 changes: 2 additions & 3 deletions tests/test_logits_cap_hip.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@

import pytest
import torch
from jit_utils import (
jit_decode_attention_func_args
)
from jit_utils import jit_decode_attention_func_args

import flashinfer

Expand Down Expand Up @@ -75,5 +73,6 @@ def test_single_decode_logits_soft_cap(
o_ref = attention_logits_soft_cap_torch(q.unsqueeze(0), k, v, soft_cap).squeeze(0)
torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3)


if __name__ == "__main__":
test_single_decode_logits_soft_cap(9, 32, 128, 30.0)
4 changes: 1 addition & 3 deletions tests/test_non_contiguous_decode_hip.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import pytest
import torch
from jit_utils import (
jit_decode_attention_func_args
)
from jit_utils import jit_decode_attention_func_args

import flashinfer

Expand Down
1 change: 0 additions & 1 deletion tests/test_norm_hip.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype, enable_pdl, contiguou
x = torch.randn(batch_size, hidden_size * 2, device="cuda").to(dtype)
x = x[:, :hidden_size]


residual = torch.randn_like(x)
weight = torch.randn(hidden_size, dtype=dtype, device="cuda")

Expand Down
1 change: 0 additions & 1 deletion tests/test_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,6 @@ def forward_cuda(
(256, 128, 4096, 9231, False, torch.bfloat16, "cuda", 3, 231, 4, 2),
],
)

def test_rope_cos_sin_cache(
head_size: int,
rotary_dim: int,
Expand Down
4 changes: 1 addition & 3 deletions tests/test_sliding_window_hip.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@

import pytest
import torch
from jit_utils import (
jit_decode_attention_func_args
)
from jit_utils import jit_decode_attention_func_args

import flashinfer

Expand Down