Skip to content

Commit 642f1e8

Browse files
committed
enable tvm-ffi for metal
1 parent 0aea438 commit 642f1e8

File tree

3 files changed

+4
-2
lines changed

3 files changed

+4
-2
lines changed

testing/python/metal/test_metal_codegen.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66

77

8-
@tilelang.jit(execution_backend='torch')
8+
@tilelang.jit(execution_backend='tvm_ffi')
99
def matmul(M, N, K, block_M, block_N, block_K, dtype="float32", accum_dtype="float"):
1010

1111
@T.prim_func

tilelang/engine/lower.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ def device_codegen(device_mod: tvm.IRModule, target: Target) -> tvm.IRModule:
161161
device_mod = tvm.ffi.get_global_func("target.build.tilelang_cuda")(device_mod, target)
162162
elif target.kind.name == "hip":
163163
device_mod = tvm.ffi.get_global_func("target.build.tilelang_hip")(device_mod, target)
164+
elif target.kind.name == "metal":
165+
device_mod = tvm.ffi.get_global_func("target.build.metal")(device_mod, target)
164166
else:
165167
raise ValueError(f"Target {target.kind.name} is not supported")
166168

tilelang/jit/execution_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def allowed_backends_for_target(target: Target, *, include_unavailable: bool = T
3535
elif kind == "hip":
3636
allowed = ["tvm_ffi", "cython", "ctypes"]
3737
elif kind == "metal":
38-
allowed = ["torch"]
38+
allowed = ["tvm_ffi", "torch"]
3939
elif kind == "c": # CPU C backend
4040
allowed = ["cython", "ctypes", "tvm_ffi"]
4141
else:

0 commit comments

Comments
 (0)