Skip to content

Commit b4d194d

Browse files
committed
[FlashInfer] Update include path and interface
This PR updates the include path for FlashInfer JIT compilation, and also updates the plan-function interface for attention prefill computation, to align with the recent interface change in flashinfer-ai/flashinfer#1661 (which adds a `window_left` parameter before the copy stream).
1 parent 70e9164 commit b4d194d

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

python/tvm/relax/backend/cuda/flashinfer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ def get_object_file_path(src: Path) -> Path:
141141
)
142142
include_paths += [
143143
Path(tvm_home).resolve() / "include",
144-
Path(tvm_home).resolve() / "ffi" / "include",
145-
Path(tvm_home).resolve() / "ffi" / "3rdparty" / "dlpack" / "include",
144+
Path(tvm_home).resolve() / "3rdparty" / "tvm-ffi" / "include",
145+
Path(tvm_home).resolve() / "3rdparty" / "tvm-ffi" / "3rdparty" / "dlpack" / "include",
146146
Path(tvm_home).resolve() / "3rdparty" / "dmlc-core" / "include",
147147
]
148148
else:
@@ -160,8 +160,8 @@ def get_object_file_path(src: Path) -> Path:
160160
# The package is installed from source.
161161
include_paths += [
162162
tvm_package_path.parent.parent / "include",
163-
tvm_package_path.parent.parent / "ffi" / "include",
164-
tvm_package_path.parent.parent / "ffi" / "3rdparty" / "dlpack" / "include",
163+
tvm_package_path.parent.parent / "3rdparty" / "tvm-ffi" / "include",
164+
tvm_package_path.parent.parent / "3rdparty" / "tvm-ffi" / "3rdparty" / "dlpack" / "include",
165165
tvm_package_path.parent.parent / "3rdparty" / "dmlc-core" / "include",
166166
]
167167
else:

src/runtime/vm/attn_backend.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,8 @@ class FlashInferPagedPrefillFunc : public PagedPrefillFunc {
176176
plan_func_(float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer,
177177
qo_indptr->as_tensor(), page_indptr->as_tensor(), IntTuple(std::move(kv_len)),
178178
total_qo_len, batch_size, num_qo_heads, num_kv_heads, page_size,
179-
/*enable_cuda_graph=*/false, qk_head_dim, v_head_dim, causal, copy_stream)
179+
/*enable_cuda_graph=*/false, qk_head_dim, v_head_dim, causal,
180+
/*window_left=*/-1, copy_stream)
180181
.cast<IntTuple>();
181182
} else if (attn_kind == AttnKind::kMLA) {
182183
plan_info_vec =
@@ -280,7 +281,8 @@ class FlashInferRaggedPrefillFunc : public RaggedPrefillFunc {
280281
plan_func_(float_workspace_buffer, int_workspace_buffer, page_locked_int_workspace_buffer,
281282
qo_indptr->as_tensor(), kv_indptr->as_tensor(), IntTuple(std::move(kv_len)),
282283
total_qo_len, batch_size, num_qo_heads, num_kv_heads, /*page_size=*/1,
283-
/*enable_cuda_graph=*/false, qk_head_dim, v_head_dim, causal, copy_stream)
284+
/*enable_cuda_graph=*/false, qk_head_dim, v_head_dim, causal,
285+
/*window_left=*/-1, copy_stream)
284286
.cast<IntTuple>();
285287
}
286288

0 commit comments

Comments (0)