diff --git a/testing/python/language/test_tilelang_language_lazy_jit.py b/testing/python/language/test_tilelang_language_eager_jit.py
similarity index 91%
rename from testing/python/language/test_tilelang_language_lazy_jit.py
rename to testing/python/language/test_tilelang_language_eager_jit.py
index d7eba6c1b..0d58f310d 100644
--- a/testing/python/language/test_tilelang_language_lazy_jit.py
+++ b/testing/python/language/test_tilelang_language_eager_jit.py
@@ -225,6 +225,25 @@ def copy6(A):
     assert torch.equal(A[:, 0, :, 0], B)
 
 
+def test_jit2_compile_with_consts():
+    @tilelang.jit
+    def transpose(X, Y, block_M, block_N):
+        M, N = T.const("M N")
+        X: T.Tensor[[M, N], T.float32]
+        Y: T.Tensor[[N, M], T.float32]
+
+        with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):
+            X_tile = T.alloc_shared((block_M, block_N), T.float32)
+            Y_tile = T.alloc_shared((block_N, block_M), T.float32)
+
+            T.copy(X[bx * block_M, by * block_N], X_tile)
+            for i, j in T.Parallel(block_M, block_N):
+                Y_tile[j, i] = X_tile[i, j]
+            T.copy(Y_tile, Y[by * block_N, bx * block_M])
+
+    transpose.compile(M=1024, N=1024, block_M=64, block_N=64)
+
+
 if __name__ == "__main__":
     # tilelang.testing.main()
     test_jit2_return()
diff --git a/testing/python/layout/test_tilelang_annotate_loop_layout.py b/testing/python/layout/test_tilelang_annotate_loop_layout.py
index fbe9bb8a8..4f41ebf93 100644
--- a/testing/python/layout/test_tilelang_annotate_loop_layout.py
+++ b/testing/python/layout/test_tilelang_annotate_loop_layout.py
@@ -24,10 +24,8 @@ def loop_layout_fn(i, j):
         return forward_thread, forward_local
 
     M, N = 128, 32
-    A = T.Tensor((M, N), T.float32)
-    B = T.Tensor((M, N), T.float32)
     loop_layout = T.Fragment((M, N), forward_fn=loop_layout_fn)
-    kernel = loop_layout_kernel.compile(A=A, B=B, loop_layout=loop_layout)
+    kernel = loop_layout_kernel.compile(M=M, N=N, loop_layout=loop_layout)
     code = kernel.get_kernel_source()
 
     # Expect vectorized copy along innermost dimension (float4)
@@ -42,10 +40,8 @@ def loop_layout_fn(i, j):
         return forward_thread, forward_local
 
     M, N = 128, 32
-    A = T.Tensor((M, N), T.float32)
-    B = T.Tensor((M, N), T.float32)
     loop_layout = T.Fragment((M, N), forward_fn=loop_layout_fn)
-    kernel = loop_layout_kernel.compile(A=A, B=B, loop_layout=loop_layout)
+    kernel = loop_layout_kernel.compile(M=M, N=N, loop_layout=loop_layout)
     code = kernel.get_kernel_source()
 
     assert "*(float4*)(B + ((((int)threadIdx.x) * 32) + (i * 4))) = *(float4*)(A + ((((int)threadIdx.x) * 32) + (i * 4)));" in code
@@ -70,13 +66,14 @@ def loop_layout_fn(i, j, rep):
         return fth, floc
 
     M, N = 128, 32
-    A = T.Tensor((M, N), T.float32)
-    B = T.Tensor((M, N), T.float32)
     loop_layout = T.Fragment((M, N), forward_fn=loop_layout_fn, replicate=2)
-    kernel = copy_with_layout_kernel.compile(A=A, B=B, loop_layout=loop_layout)
+    kernel = copy_with_layout_kernel.compile(M=M, N=N, loop_layout=loop_layout)
     code = kernel.get_kernel_source()
 
-    assert "*(float4*)(B + ((i * 512) + (((int)threadIdx.x) * 4))) = *(float4*)(A + ((i * 512) + (((int)threadIdx.x) * 4)));" in code
+    assert (
+        "*(float4*)(B + ((i * 256) + ((((int)threadIdx.x) & 63) * 4))) = *(float4*)(A + ((i * 256) + ((((int)threadIdx.x) & 63) * 4)));"
+        in code
+    )
 
 
 @tilelang.jit
@@ -93,8 +90,6 @@ def replicate_loop_layout_kernel(A, B, loop_layout):
 @tilelang.testing.requires_cuda
 def test_annotate_replicate_loop_layout_vec4():
     M, N = 128, 32
-    A = T.Tensor((M, N), T.float32)
-    B = T.Tensor((M, N), T.float32)
 
     def loop_layout_fn(i, j, rep):
         elems = i * 32 + j
@@ -104,7 +99,7 @@ def loop_layout_fn(i, j, rep):
 
     loop_layout = T.Fragment((M, N), forward_fn=loop_layout_fn, replicate=2)
-    kernel = replicate_loop_layout_kernel.compile(A, B, loop_layout=loop_layout)
+    kernel = replicate_loop_layout_kernel.compile(M=M, N=N, loop_layout=loop_layout)
     code = kernel.get_kernel_source()
 
     assert (
diff --git a/tilelang/language/eager/ast.py b/tilelang/language/eager/ast.py
index 230925a1d..378306128 100644
--- a/tilelang/language/eager/ast.py
+++ b/tilelang/language/eager/ast.py
@@ -472,6 +472,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef):
         node.body = stmts + node.body
         node.decorator_list.clear()
         name = node.name
+        node.args.kwarg = ast.arg(arg="__kwargs")
         node = SpanAttacher("__tb_fl", "__tb_fn").visit(node)
         return quote1(
             f"def make_closure({', '.join(self.nonlocals.keys())}):\n"
@@ -503,22 +504,27 @@ def _parse_arg_annot(self, stmt: ast.stmt, arg_names: set[str]):
         if name not in arg_names:
             return
         annot = stmt.annotation
-        # case 1: subscript(attribute(T, Tensor), ...)
-        # case 2: attribute(T, float32)
+
+        # case 1: attribute(T, float32)
         if isinstance(annot, ast.Attribute) and annot.attr in dtypes._all_dtypes:
             eval_res = self._try_eval(annot)
             if isinstance(eval_res, dtypes.dtype):
                 self.extra_type_hints[name] = eval_res
             return
+
+        # case 2: subscript(attribute(T, Tensor), ...) or call(attribute(T, Tensor), ...)
+        inner = None
+        if isinstance(annot, ast.Call) and isinstance(annot.func, ast.Attribute):
+            inner = annot.func
         if isinstance(annot, ast.Subscript) and isinstance(annot.value, ast.Attribute):
             inner = annot.value
-            if inner.attr in ["Tensor", "StridedTensor", "ptr"]:
-                eval_res = self._try_eval(inner)
-                from tilelang.language.proxy import TensorProxy, StridedTensorProxy, ptr
+        if inner is not None and inner.attr in ["Tensor", "StridedTensor", "ptr"]:
+            eval_res = self._try_eval(inner)
+            from tilelang.language.proxy import TensorProxy, StridedTensorProxy, ptr
 
-                if isinstance(eval_res, (TensorProxy, StridedTensorProxy)) or eval_res is ptr:
-                    self.extra_type_hints[name] = ptr
-                    return
+            if isinstance(eval_res, (TensorProxy, StridedTensorProxy)) or eval_res is ptr:
+                self.extra_type_hints[name] = ptr
+                return
 
     def visit_BoolOp(self, node: ast.BoolOp):
         node = self.generic_visit(node)
diff --git a/tilelang/language/eager/builder.py b/tilelang/language/eager/builder.py
index f7377a222..fbd950262 100644
--- a/tilelang/language/eager/builder.py
+++ b/tilelang/language/eager/builder.py
@@ -890,7 +890,7 @@ class TirTemplate(Generic[_P, _T]):
     """
 
     prim_func: PrimFunc[_P, _T]
-    matcher: dict[Var, tuple[tvm.tir.Var, str, int]] | None = None
+    matcher: dict[Var, tuple[tvm.tir.Var, str, int, str]] | None = None
     is_lazy_style: bool = False  # True if from lazy-style (returns PrimFunc directly)
 
     @classmethod
@@ -899,10 +899,10 @@ def create(cls, prim_func: PrimFunc[_P, _T], constexpr: set[Var]) -> TirTemplate
         for k, v in prim_func.buffer_map.items():
             for i, s in enumerate(v.shape):
                 if s in constexpr and s not in matcher:
-                    matcher[s] = (k.name, "shape", i)
+                    matcher[s] = (k.name, "shape", i, s.name)
             for i, s in enumerate(v.strides):
                 if s in constexpr and s not in matcher:
-                    matcher[s] = (k.name, "stride", i)
+                    matcher[s] = (k.name, "stride", i, s.name)
         for s in constexpr:
             if s not in matcher:
                 shapes = {k: v.shape for k, v in prim_func.buffer_map.items()}
@@ -926,15 +926,24 @@ def _parse_phase2_key(self, **kwargs):
         if self.matcher is None:
             return ()
         result = []
-        for k, ty, i in self.matcher.values():
-            if ty == "shape":
-                result.append(kwargs[k].shape[i])
-            if ty == "stride":
-                v = kwargs[k]
-                if isinstance(v, Buffer):
-                    result.append(v.strides[i])
-                else:
-                    result.append(kwargs[k].stride()[i])
+        for k, ty, i, name in self.matcher.values():
+            if name in kwargs:
+                result.append(kwargs.get(name))
+            elif k in kwargs:
+                if ty == "shape":
+                    result.append(kwargs[k].shape[i])
+                elif ty == "stride":
+                    v = kwargs[k]
+                    if isinstance(v, Buffer):
+                        result.append(v.strides[i])
+                    else:
+                        result.append(kwargs[k].stride()[i])
+            else:
+                raise ValueError(
+                    f"Cannot find value for constexpr variable `{name}`\n"
+                    f"Please provide it as a keyword argument, e.g. `{name}=`\n"
+                    f"Or provide the corresponding tensor argument `{k}`."
+                )
         return tuple(result)
 
     def get_tir(self, **kwargs):
@@ -1015,7 +1024,7 @@ def foo(A, B):
                 self.p1_cache[p1_key] = TirTemplate.from_lazy_style(prim_func)
                 return True
             return False
-        except (JITNoBuilderError, EagerJITBuildError):
+        except (JITNoBuilderError, EagerJITBuildError, TypeError):
             # In eager mode, we construct AST directly without prim_func,
             # so there's no Builder available when the function is called.
             # When eager-only features like T.const() or T.Kernel() are used,