Skip to content

Commit ebce953

Browse files
committed
Small fixes
1 parent 861713c commit ebce953

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

python/tvm/relax/backend/cuda/flashinfer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ def gen_sampling_module(target: Target, num_threads: int = 8):
426426
"FlashInfer is not installed. Please follow instructions "
427427
"in https://docs.flashinfer.ai to install FlashInfer."
428428
)
429-
uri, source_paths = gen_sampling_tvm_binding()
429+
uri, source_paths = gen_sampling_tvm_binding(uri="sampling")
430430
object_files = _compile_flashinfer_kernels(uri, source_paths, target, num_threads)
431431
modules = _load_flashinfer_modules(object_files)
432432
return modules

0 commit comments

Comments
 (0)