Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from e47e76 to 001022
2 changes: 1 addition & 1 deletion examples/attention_sink/example_mha_sink_fwd_bhsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def get_configs():
return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]


@autotune(configs=get_configs(), warmup=500, rep=100)
@autotune(configs=get_configs())
@tilelang.jit(
out_idx=[3],
pass_configs={
Expand Down
2 changes: 0 additions & 2 deletions examples/gemm/example_gemm_intrinsics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func


def make_swizzle_layout(shared_buf):
Expand All @@ -25,7 +24,6 @@ def transform_func(i, j):


@tilelang.jit(out_idx=[2])
@simplify_prim_func
def tl_matmul(
M,
N,
Expand Down
2 changes: 0 additions & 2 deletions examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from tilelang.intrinsics.mma_macro_generator import (
TensorCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func
from tilelang.utils.tensor import map_torch_type

tilelang.testing.set_random_seed(0)
Expand All @@ -29,7 +28,6 @@ def transform_func(i, j):


@tilelang.jit(out_idx=[2])
@simplify_prim_func
def tl_matmul(
M,
N,
Expand Down
22 changes: 11 additions & 11 deletions examples/lazy_jit/lazyjit.en.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def gemm(\n",
" A,\n",
" B,\n",
Expand Down Expand Up @@ -209,7 +209,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def gemm_dyn_K(A, B):\n",
" M, N, K = T.dynamic(\"M, N, K\")\n",
" A: T.Tensor[[M, K], T.float16]\n",
Expand Down Expand Up @@ -248,7 +248,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def as_contingious(A):\n",
" M, N, dM, dN = T.dynamic(\"M, N, dM, dN\")\n",
" A: T.StridedTensor[[M, N], [dM, dN], T.float32]\n",
Expand Down Expand Up @@ -307,7 +307,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def gemm_ptr(\n",
" A,\n",
" B,\n",
Expand Down Expand Up @@ -359,7 +359,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def gemm_ptr_dyn(A, B, M, N, K):\n",
" M: T.int32\n",
" N: T.int32\n",
Expand Down Expand Up @@ -421,7 +421,7 @@
}
],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def example_wrong_kernel(A):\n",
" M = T.const(\"M\")\n",
" A: T.Tensor[[M * 2, M * 3], T.float32]\n",
Expand Down Expand Up @@ -470,7 +470,7 @@
}
],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def dyn_annot(\n",
" A: T.ptr, # 1. T.ptr type annotation\n",
" is_2d=False,\n",
Expand Down Expand Up @@ -515,7 +515,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def add_one(X, data: T.float32 = 1):\n",
" M, N = T.const(\"M, N\")\n",
" X: T.Tensor[[M, N], T.float32]\n",
Expand Down Expand Up @@ -577,7 +577,7 @@
"B = torch.randn(128, 128, dtype=torch.float16, device=\"cuda\")\n",
"\n",
"\n",
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def dummy_kernel(A, B):\n",
" M, N = T.const(\"M, N\")\n",
" A: T.Tensor[[M, N], T.float16]\n",
Expand Down Expand Up @@ -797,7 +797,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def element_wise(A, fn):\n",
" N = T.dynamic(\"N\")\n",
" A: T.Tensor[[N], T.float32]\n",
Expand Down Expand Up @@ -857,7 +857,7 @@
" n31(x * 3 + 1, var)\n",
"\n",
"\n",
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def foo(A: T.Tensor[[1], T.int32], n: int):\n",
" with T.Kernel(1) as _:\n",
" n31(n, A[0])"
Expand Down
22 changes: 11 additions & 11 deletions examples/lazy_jit/lazyjit.zh.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def gemm(\n",
" A,\n",
" B,\n",
Expand Down Expand Up @@ -209,7 +209,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def gemm_dyn_K(A, B):\n",
" M, N, K = T.dynamic(\"M, N, K\")\n",
" A: T.Tensor[[M, K], T.float16]\n",
Expand Down Expand Up @@ -248,7 +248,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def as_contingious(A):\n",
" M, N, dM, dN = T.dynamic(\"M, N, dM, dN\")\n",
" A: T.StridedTensor[[M, N], [dM, dN], T.float32]\n",
Expand Down Expand Up @@ -307,7 +307,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def gemm_ptr(\n",
" A,\n",
" B,\n",
Expand Down Expand Up @@ -359,7 +359,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def gemm_ptr_dyn(A, B, M, N, K):\n",
" M: T.int32\n",
" N: T.int32\n",
Expand Down Expand Up @@ -421,7 +421,7 @@
}
],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def example_wrong_kernel(A):\n",
" M = T.const(\"M\")\n",
" A: T.Tensor[[M * 2, M * 3], T.float32]\n",
Expand Down Expand Up @@ -470,7 +470,7 @@
}
],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def dyn_annot(\n",
" A: T.ptr, # 1. T.ptr type annotation\n",
" is_2d=False,\n",
Expand Down Expand Up @@ -515,7 +515,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def add_one(X, data: T.float32 = 1):\n",
" M, N = T.const(\"M, N\")\n",
" X: T.Tensor[[M, N], T.float32]\n",
Expand Down Expand Up @@ -577,7 +577,7 @@
"B = torch.randn(128, 128, dtype=torch.float16, device=\"cuda\")\n",
"\n",
"\n",
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def dummy_kernel(A, B):\n",
" M, N = T.const(\"M, N\")\n",
" A: T.Tensor[[M, N], T.float16]\n",
Expand Down Expand Up @@ -797,7 +797,7 @@
"metadata": {},
"outputs": [],
"source": [
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def element_wise(A, fn):\n",
" N = T.dynamic(\"N\")\n",
" A: T.Tensor[[N], T.float32]\n",
Expand Down Expand Up @@ -857,7 +857,7 @@
" n31(x * 3 + 1, var)\n",
"\n",
"\n",
"@tilelang.lazy_jit\n",
"@tilelang.jit\n",
"def foo(A: T.Tensor[[1], T.int32], n: int):\n",
" with T.Kernel(1) as _:\n",
" n31(n, A[0])"
Expand Down
Loading
Loading