Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/programming_guides/autotuning.md
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def matmul_configs(M, N, K):

Tune compile‑time options explicitly:
- `target='auto'|'cuda'|'hip'|'metal'` (normalized to a TVM Target)
- `execution_backend='auto'|'tvm_ffi'|'ctypes'|'cython'|'nvrtc'|'torch'`
- `execution_backend='auto'|'tvm_ffi'|'cython'|'nvrtc'|'torch'`
- `pass_configs={...}` to toggle TileLang/TVM passes for experiments

On CUDA with multiple GPUs, the tuner sets the current device per worker thread
Expand Down
2 changes: 1 addition & 1 deletion testing/python/cpu/test_tilelang_cpu_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def matmul(
block_M, block_N, block_K = M // 4, N // 4, K // 4
cpu_func = matmul_jit_test(M, N, K, block_M, block_N, block_K)
with tvm.target.Target("c"):
complied_fun = tilelang.compile(cpu_func, -1, execution_backend="ctypes")
complied_fun = tilelang.compile(cpu_func, -1, execution_backend="cython")

in_dtype = T.float16
A = torch.randn(M, K, dtype=torch.__getattribute__(in_dtype))
Expand Down
8 changes: 0 additions & 8 deletions testing/python/jit/test_tilelang_jit_gemm_cython.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,24 +256,16 @@ def run_cython_kernel_do_bench(
)

cython_matmul_kernel = tilelang.compile(program, execution_backend="cython")
ctypes_matmul_kernel = tilelang.compile(program, execution_backend="ctypes")

cython_profiler = cython_matmul_kernel.get_profiler()
ctypes_profiler = ctypes_matmul_kernel.get_profiler()

cython_latency = cython_profiler.do_bench(func=cython_matmul_kernel)
print(f"cython Latency: {cython_latency} ms")

# assert ctypes_latency is not None

tvm_latency = cython_profiler.do_bench()
print(f"TVM Latency: {tvm_latency} ms")

assert tvm_latency is not None

ctypes_latency = ctypes_profiler.do_bench(func=ctypes_matmul_kernel)
print(f"ctypes Latency: {ctypes_latency} ms")

assert cython_latency is not None


Expand Down
4 changes: 2 additions & 2 deletions tilelang/autotuner/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class CompileArgs:
"""

out_idx: list[int] | int | None = None
execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto"
execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch"] = "auto"
target: Literal["auto", "cuda", "hip"] = "auto"
target_host: str | Target = None
verbose: bool = False
Expand Down Expand Up @@ -265,7 +265,7 @@ def _load_kernel_from_disk(
target: str | Target = "auto",
target_host: str | Target = None,
out_idx: list[int] | int | None = None,
execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "tvm_ffi",
execution_backend: Literal["tvm_ffi", "cython", "nvrtc", "torch"] = "tvm_ffi",
pass_configs: dict = None,
compile_flags: list[str] | str | None = None,
func: Callable = None,
Expand Down
4 changes: 2 additions & 2 deletions tilelang/autotuner/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def set_compile_args(
self,
out_idx: list[int] | int | None = None,
target: Literal["auto", "cuda", "hip", "metal"] = "auto",
execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] = "auto",
execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch"] = "auto",
target_host: str | Target = None,
verbose: bool = False,
pass_configs: dict[str, Any] | None = None,
Expand Down Expand Up @@ -725,7 +725,7 @@ def autotune( # This is the new public interface
Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto".
target_host : Union[str, Target], optional
Target host for cross-compilation. Defaults to None.
execution_backend : Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"], optional
execution_backend : Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch"], optional
Backend for kernel execution and argument passing. Use "auto" to pick a sensible
default per target (cuda->tvm_ffi, metal->torch, others->cython).
verbose : bool, optional
Expand Down
2 changes: 1 addition & 1 deletion tilelang/cache/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def cached(
*args,
target: str | Target = "auto",
target_host: str | Target = None,
execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch"] | None = "auto",
execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch"] | None = "auto",
verbose: bool | None = False,
pass_configs: dict | None = None,
compile_flags: list[str] | str | None = None,
Expand Down
8 changes: 4 additions & 4 deletions tilelang/cache/kernel_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class KernelCache:
_instance = None # For implementing singleton pattern
_lock = threading.Lock() # For thread safety
_memory_cache = {} # In-memory cache dictionary
execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "tvm_ffi"
execution_backend: Literal["tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] = "tvm_ffi"

def __new__(cls):
"""
Expand Down Expand Up @@ -77,7 +77,7 @@ def _generate_key(
self,
func: Callable,
out_idx: list[int],
execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "tvm_ffi",
execution_backend: Literal["tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] = "tvm_ffi",
args=None,
target: str | Target = "auto",
target_host: str | Target = None,
Expand Down Expand Up @@ -123,7 +123,7 @@ def cached(
*args,
target: str | Target = "auto",
target_host: str | Target = None,
execution_backend: Literal["auto", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "auto",
execution_backend: Literal["auto", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] = "auto",
verbose: bool = False,
pass_configs: dict = None,
compile_flags: list[str] | str | None = None,
Expand Down Expand Up @@ -389,7 +389,7 @@ def _load_kernel_from_disk(
target: str | Target = "auto",
target_host: str | Target | None = None,
out_idx: list[int] | None = None,
execution_backend: Literal["tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "tvm_ffi",
execution_backend: Literal["tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] = "tvm_ffi",
pass_configs: dict | None = None,
compile_flags: list[str] | str | None = None,
func: Callable | None = None,
Expand Down
16 changes: 8 additions & 8 deletions tilelang/jit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
def compile(
func: PrimFunc[_KP, _T] = None,
out_idx: list[int] | int | None = None,
execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "auto",
execution_backend: Literal["auto", "dlpack", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] = "auto",
target: str | Target = "auto",
target_host: str | Target | None = None,
verbose: bool = False,
Expand All @@ -64,7 +64,7 @@ def compile(
The TileLang TIR function to compile and wrap.
out_idx : Union[List[int], int], optional
Index(es) of the output tensors to return (default: None).
execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"], optional
execution_backend : Literal["auto", "dlpack", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"], optional
Execution backend to use for kernel execution. Use "auto" to pick a sensible
default per target (cuda->tvm_ffi, metal->torch, others->cython).
target : Union[str, Target], optional
Expand Down Expand Up @@ -118,7 +118,7 @@ def compile(
def par_compile(
funcs: Iterable[PrimFunc[_KP, _T]],
out_idx: list[int] | int | None = None,
execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"] = "auto",
execution_backend: Literal["auto", "dlpack", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"] = "auto",
target: str | Target = "auto",
target_host: str | Target | None = None,
verbose: bool = False,
Expand All @@ -135,7 +135,7 @@ def par_compile(
The TileLang TIR functions to compile and wrap.
out_idx : Union[List[int], int], optional
Index(es) of the output tensors to return (default: None).
execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"], optional
execution_backend : Literal["auto", "dlpack", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"], optional
Execution backend to use for kernel execution. Use "auto" to pick a sensible
default per target (cuda->tvm_ffi, metal->torch, others->cython).
target : Union[str, Target], optional
Expand Down Expand Up @@ -202,7 +202,7 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
out_idx : list[int] | int | None
Which output tensor(s) of the compiled kernel should be returned to the
caller. Accepts a single index, a list of indices, or None to return all.
execution_backend : Literal["dlpack", "ctypes", "cython"]
execution_backend : Literal["auto", "dlpack", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"]
Backend used for exchanging arguments and executing the generated kernel.
target : str | tvm.target.Target
TVM compilation target (e.g. "cuda", "llvm", or "auto").
Expand Down Expand Up @@ -256,7 +256,7 @@ class JITImpl(Generic[_P, _KP, _T, _Ret]):
"""

out_idx: list[int] | int | None
execution_backend: Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"]
execution_backend: Literal["auto", "dlpack", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"]
target: str | Target
target_host: str | Target
verbose: bool
Expand Down Expand Up @@ -424,7 +424,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _Ret:
return kernel


ExecutionBackend = Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"]
ExecutionBackend = Literal["auto", "dlpack", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"]


@overload
Expand Down Expand Up @@ -473,7 +473,7 @@ def jit( # This is the new public interface
Compilation target for TVM (e.g., "cuda", "llvm"). Defaults to "auto".
target_host : Union[str, Target], optional
Target host for cross-compilation. Defaults to None.
execution_backend : Literal["auto", "dlpack", "tvm_ffi", "ctypes", "cython", "nvrtc", "torch", "cutedsl"], optional
execution_backend : Literal["auto", "dlpack", "tvm_ffi", "cython", "nvrtc", "torch", "cutedsl"], optional
Backend for kernel execution and argument passing. Use "auto" to pick a sensible
default per target (cuda->tvm_ffi, metal->torch, others->cython).
verbose : bool, optional
Expand Down
1 change: 0 additions & 1 deletion tilelang/jit/adapter/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .base import BaseKernelAdapter # noqa: F401
from .tvm_ffi import TVMFFIKernelAdapter # noqa: F401
from .ctypes import CtypesKernelAdapter # noqa: F401
from .cython import CythonKernelAdapter # noqa: F401
from .nvrtc import NVRTCKernelAdapter # noqa: F401
from .torch import MetalKernelAdapter # noqa: F401
Expand Down
1 change: 0 additions & 1 deletion tilelang/jit/adapter/ctypes/__init__.py

This file was deleted.

Loading
Loading