@@ -225,6 +225,25 @@ def copy6(A):
     assert torch.equal(A[:, 0, :, 0], B)
 
 
+def test_jit2_compile_with_consts():
+    @tilelang.jit
+    def transpose(X, Y, block_M, block_N):
+        M, N = T.const("M N")
+        X: T.Tensor[[M, N], T.float32]
+        Y: T.Tensor[[N, M], T.float32]
+
+        with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):
+            X_tile = T.alloc_shared((block_M, block_N), T.float32)
+            Y_tile = T.alloc_shared((block_N, block_M), T.float32)
+
+            T.copy(X[bx * block_M, by * block_N], X_tile)
+            for i, j in T.Parallel(block_M, block_N):
+                Y_tile[j, i] = X_tile[i, j]
+            T.copy(Y_tile, Y[by * block_N, bx * block_M])
+
+    transpose.compile(M=1024, N=1024, block_M=64, block_N=64)
+
+
 if __name__ == "__main__":
     # tilelang.testing.main()
     test_jit2_return()
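For context on the new const-based path: `T.const("M N")` declares symbolic dims that must be bound at compile time, and `.compile()` accepts them as keyword arguments. A minimal usage sketch, assuming `transpose` is in scope and that the compiled kernel takes tensors positionally (the call convention is not shown in this diff):

```python
import torch

# Hypothetical follow-up to transpose.compile(...) from the test above.
kernel = transpose.compile(M=1024, N=1024, block_M=64, block_N=64)

X = torch.randn(1024, 1024, device="cuda", dtype=torch.float32)
Y = torch.empty(1024, 1024, device="cuda", dtype=torch.float32)
kernel(X, Y)  # assumed: positional tensor arguments, as with other jit kernels
torch.testing.assert_close(Y, X.t())
```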
testing/python/layout/test_tilelang_annotate_loop_layout.py (21 changes: 8 additions & 13 deletions)
@@ -24,10 +24,8 @@ def loop_layout_fn(i, j):
         return forward_thread, forward_local
 
     M, N = 128, 32
-    A = T.Tensor((M, N), T.float32)
-    B = T.Tensor((M, N), T.float32)
     loop_layout = T.Fragment((M, N), forward_fn=loop_layout_fn)
-    kernel = loop_layout_kernel.compile(A=A, B=B, loop_layout=loop_layout)
+    kernel = loop_layout_kernel.compile(M=M, N=N, loop_layout=loop_layout)
     code = kernel.get_kernel_source()
 
     # Expect vectorized copy along innermost dimension (float4)
@@ -42,10 +40,8 @@ def loop_layout_fn(i, j):
         return forward_thread, forward_local
 
     M, N = 128, 32
-    A = T.Tensor((M, N), T.float32)
-    B = T.Tensor((M, N), T.float32)
     loop_layout = T.Fragment((M, N), forward_fn=loop_layout_fn)
-    kernel = loop_layout_kernel.compile(A=A, B=B, loop_layout=loop_layout)
+    kernel = loop_layout_kernel.compile(M=M, N=N, loop_layout=loop_layout)
     code = kernel.get_kernel_source()
 
     assert "*(float4*)(B + ((((int)threadIdx.x) * 32) + (i * 4))) = *(float4*)(A + ((((int)threadIdx.x) * 32) + (i * 4)));" in code
@@ -70,13 +66,14 @@ def loop_layout_fn(i, j, rep):
         return fth, floc
 
     M, N = 128, 32
-    A = T.Tensor((M, N), T.float32)
-    B = T.Tensor((M, N), T.float32)
     loop_layout = T.Fragment((M, N), forward_fn=loop_layout_fn, replicate=2)
-    kernel = copy_with_layout_kernel.compile(A=A, B=B, loop_layout=loop_layout)
+    kernel = copy_with_layout_kernel.compile(M=M, N=N, loop_layout=loop_layout)
     code = kernel.get_kernel_source()
 
-    assert "*(float4*)(B + ((i * 512) + (((int)threadIdx.x) * 4))) = *(float4*)(A + ((i * 512) + (((int)threadIdx.x) * 4)));" in code
+    assert (
+        "*(float4*)(B + ((i * 256) + ((((int)threadIdx.x) & 63) * 4))) = *(float4*)(A + ((i * 256) + ((((int)threadIdx.x) & 63) * 4)));"
+        in code
+    )
 
 
 @tilelang.jit
@@ -93,8 +90,6 @@ def replicate_loop_layout_kernel(A, B, loop_layout):
 @tilelang.testing.requires_cuda
 def test_annotate_replicate_loop_layout_vec4():
     M, N = 128, 32
-    A = T.Tensor((M, N), T.float32)
-    B = T.Tensor((M, N), T.float32)
 
     def loop_layout_fn(i, j, rep):
         elems = i * 32 + j
@@ -104,7 +99,7 @@ def loop_layout_fn(i, j, rep):
 
     loop_layout = T.Fragment((M, N), forward_fn=loop_layout_fn, replicate=2)
 
-    kernel = replicate_loop_layout_kernel.compile(A, B, loop_layout=loop_layout)
+    kernel = replicate_loop_layout_kernel.compile(M=M, N=N, loop_layout=loop_layout)
     code = kernel.get_kernel_source()
 
     assert (
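Aside on what these forward functions compute: with `replicate=2` and 128 threads, each element (i, j) maps to a (thread, local) pair per replica, and the `& 63` in the expected CUDA above suggests two 64-thread groups. A plain-Python sketch of one plausible mapping (the elided test body may differ):

```python
# Illustrative only -- one plausible replicated forward_fn for a 128x32
# fp32 tile, float4-vectorized across 64 threads per replica.
def loop_layout_fn(i, j, rep):
    elems = i * 32 + j            # linearize the (i, j) element index
    thread = (elems // 4) % 64    # 64 threads, each owning one float4
    local = elems % 4             # lane within the float4
    return thread + rep * 64, local   # replica 1 uses threads 64..127
```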
tilelang/language/eager/ast.py (22 changes: 14 additions & 8 deletions)
@@ -472,6 +472,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef):
         node.body = stmts + node.body
         node.decorator_list.clear()
         name = node.name
+        node.args.kwarg = ast.arg(arg="__kwargs")
         node = SpanAttacher("__tb_fl", "__tb_fn").visit(node)
         return quote1(
             f"def make_closure({', '.join(self.nonlocals.keys())}):\n"
@@ -503,22 +504,27 @@ def _parse_arg_annot(self, stmt: ast.stmt, arg_names: set[str]):
         if name not in arg_names:
             return
         annot = stmt.annotation
-        # case 1: subscript(attribute(T, Tensor), ...)
-        # case 2: attribute(T, float32)
+
+        # case 1: attribute(T, float32)
         if isinstance(annot, ast.Attribute) and annot.attr in dtypes._all_dtypes:
             eval_res = self._try_eval(annot)
             if isinstance(eval_res, dtypes.dtype):
                 self.extra_type_hints[name] = eval_res
                 return
+
+        # case 2: subscript(attribute(T, Tensor), ...) or call(attribute(T, Tensor), ...)
+        inner = None
+        if isinstance(annot, ast.Call) and isinstance(annot.func, ast.Attribute):
+            inner = annot.func
         if isinstance(annot, ast.Subscript) and isinstance(annot.value, ast.Attribute):
             inner = annot.value
-            if inner.attr in ["Tensor", "StridedTensor", "ptr"]:
-                eval_res = self._try_eval(inner)
-                from tilelang.language.proxy import TensorProxy, StridedTensorProxy, ptr
+        if inner is not None and inner.attr in ["Tensor", "StridedTensor", "ptr"]:
+            eval_res = self._try_eval(inner)
+            from tilelang.language.proxy import TensorProxy, StridedTensorProxy, ptr
 
-                if isinstance(eval_res, (TensorProxy, StridedTensorProxy)) or eval_res is ptr:
-                    self.extra_type_hints[name] = ptr
-                    return
+            if isinstance(eval_res, (TensorProxy, StridedTensorProxy)) or eval_res is ptr:
+                self.extra_type_hints[name] = ptr
+                return
 
     def visit_BoolOp(self, node: ast.BoolOp):
         node = self.generic_visit(node)
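For reference, the three annotation shapes this parser now distinguishes, shown as raw AST nodes (a minimal, self-contained illustration; dtype and proxy evaluation are stubbed out):

```python
import ast

samples = {
    "x: T.float32": "attribute",                    # case 1: attribute(T, float32)
    "x: T.Tensor[[M, N], T.float32]": "subscript",  # case 2: subscript form
    "x: T.Tensor((M, N), T.float32)": "call",       # case 2: call form (newly handled)
}

for src, expected in samples.items():
    annot = ast.parse(src).body[0].annotation
    if isinstance(annot, ast.Attribute):
        kind = "attribute"
    elif isinstance(annot, ast.Subscript) and isinstance(annot.value, ast.Attribute):
        kind = "subscript"
    elif isinstance(annot, ast.Call) and isinstance(annot.func, ast.Attribute):
        kind = "call"
    else:
        kind = "other"
    assert kind == expected, (src, kind)
```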
tilelang/language/eager/builder.py (35 changes: 22 additions & 13 deletions)
@@ -890,7 +890,7 @@ class TirTemplate(Generic[_P, _T]):
     """
 
     prim_func: PrimFunc[_P, _T]
-    matcher: dict[Var, tuple[tvm.tir.Var, str, int]] | None = None
+    matcher: dict[Var, tuple[tvm.tir.Var, str, int, str]] | None = None
     is_lazy_style: bool = False # True if from lazy-style (returns PrimFunc directly)
 
     @classmethod
@@ -899,10 +899,10 @@ def create(cls, prim_func: PrimFunc[_P, _T], constexpr: set[Var]) -> TirTemplate
         for k, v in prim_func.buffer_map.items():
             for i, s in enumerate(v.shape):
                 if s in constexpr and s not in matcher:
-                    matcher[s] = (k.name, "shape", i)
+                    matcher[s] = (k.name, "shape", i, s.name)
             for i, s in enumerate(v.strides):
                 if s in constexpr and s not in matcher:
-                    matcher[s] = (k.name, "stride", i)
+                    matcher[s] = (k.name, "stride", i, s.name)
         for s in constexpr:
             if s not in matcher:
                 shapes = {k: v.shape for k, v in prim_func.buffer_map.items()}
@@ -926,15 +926,24 @@ def _parse_phase2_key(self, **kwargs):
         if self.matcher is None:
             return ()
         result = []
-        for k, ty, i in self.matcher.values():
-            if ty == "shape":
-                result.append(kwargs[k].shape[i])
-            if ty == "stride":
-                v = kwargs[k]
-                if isinstance(v, Buffer):
-                    result.append(v.strides[i])
-                else:
-                    result.append(kwargs[k].stride()[i])
+        for k, ty, i, name in self.matcher.values():
+            if name in kwargs:
+                result.append(kwargs.get(name))
+            elif k in kwargs:
+                if ty == "shape":
+                    result.append(kwargs[k].shape[i])
+                elif ty == "stride":
+                    v = kwargs[k]
+                    if isinstance(v, Buffer):
+                        result.append(v.strides[i])
+                    else:
+                        result.append(kwargs[k].stride()[i])
+            else:
+                raise ValueError(
+                    f"Cannot find value for constexpr variable `{name}`\n"
+                    f"Please provide it as a keyword argument, e.g. `{name}=<value>`\n"
+                    f"Or provide the corresponding tensor argument `{k}`."
+                )
         return tuple(result)
 
     def get_tir(self, **kwargs):
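The resolution order above is: an explicit kwarg named after the constexpr var wins, then the shape/stride of the corresponding tensor argument, else a hard error. A stripped-down model of that lookup (string keys stand in for the real Var objects; names and shapes are illustrative):

```python
# matcher values are (buffer_arg_name, "shape"|"stride", dim_index, var_name)
matcher = {"M": ("A", "shape", 0, "M"), "N": ("A", "shape", 1, "N")}

def parse_phase2_key(**kwargs):
    result = []
    for k, ty, i, name in matcher.values():
        if name in kwargs:                 # explicit: compile(M=1024, ...)
            result.append(kwargs[name])
        elif k in kwargs:                  # implicit: read off the tensor arg
            result.append(kwargs[k].shape[i] if ty == "shape"
                          else kwargs[k].stride()[i])
        else:
            raise ValueError(f"Cannot find value for constexpr variable `{name}`")
    return tuple(result)

class FakeTensor:  # stand-in for a torch.Tensor
    shape = (1024, 512)

assert parse_phase2_key(M=128, N=256) == (128, 256)
assert parse_phase2_key(A=FakeTensor()) == (1024, 512)
```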
@@ -1015,7 +1024,7 @@ def foo(A, B):
                 self.p1_cache[p1_key] = TirTemplate.from_lazy_style(prim_func)
                 return True
             return False
-        except (JITNoBuilderError, EagerJITBuildError):
+        except (JITNoBuilderError, EagerJITBuildError, TypeError):
             # In eager mode, we construct AST directly without prim_func,
             # so there's no Builder available when the function is called.
             # When eager-only features like T.const() or T.Kernel() are used,