diff --git a/src/op/atomic_add.cc b/src/op/atomic_add.cc index 1a49b7706..4ae19bafc 100644 --- a/src/op/atomic_add.cc +++ b/src/op/atomic_add.cc @@ -539,7 +539,7 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return vectorized_thread_loop; } -TIR_REGISTER_TL_OP(AtomicAdd, atomicadd) +TIR_REGISTER_TL_TILE_OP(AtomicAdd, atomicadd) .set_num_inputs(2) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/copy.cc b/src/op/copy.cc index 1bd548bc5..93a0ff0e0 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -2037,7 +2037,7 @@ Array TMAIm2ColDesc::EncodeCallArgs() const { // - Takes 5 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma, // eviction_policy // - Marked as opaque since it has side effects (memory writes) -TIR_REGISTER_TL_OP(Copy, copy) +TIR_REGISTER_TL_TILE_OP(Copy, copy) .set_num_inputs(5) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); @@ -2062,7 +2062,7 @@ LayoutMap Conv2DIm2ColOpNode::InferLayout(const LayoutInferArgs &T, // - Takes 9 inputs: src_buffer, dst_buffer, nhw_step, c_step, kernel, stride, // dilation, padding, eviction_policy // - Marked as opaque since it has side effects (memory writes) -TIR_REGISTER_TL_OP(Conv2DIm2ColOp, c2d_im2col) +TIR_REGISTER_TL_TILE_OP(Conv2DIm2ColOp, c2d_im2col) .set_num_inputs(9) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/fill.cc b/src/op/fill.cc index 5a773768a..714e97ad2 100644 --- a/src/op/fill.cc +++ b/src/op/fill.cc @@ -209,7 +209,7 @@ LayoutMap FillNode::InferLayout(const LayoutInferArgs &T, return {}; } -TIR_REGISTER_TL_OP(Fill, fill) +TIR_REGISTER_TL_TILE_OP(Fill, fill) .set_num_inputs(2) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/finalize_reducer.cc b/src/op/finalize_reducer.cc index effc4baf0..f542b2d91 100644 --- a/src/op/finalize_reducer.cc +++ b/src/op/finalize_reducer.cc @@ -159,7 +159,7 @@ TileOperator 
FinalizeReducerOpNode::Clone() const { return TileOperator(node); } -TIR_REGISTER_TL_OP(FinalizeReducerOp, finalize_reducer) +TIR_REGISTER_TL_TILE_OP(FinalizeReducerOp, finalize_reducer) .set_num_inputs(1) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/gemm.cc b/src/op/gemm.cc index 5a98cba69..dd14eb746 100644 --- a/src/op/gemm.cc +++ b/src/op/gemm.cc @@ -826,7 +826,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T, return results; } -TIR_REGISTER_TL_OP(Gemm, gemm) +TIR_REGISTER_TL_TILE_OP(Gemm, gemm) .set_num_inputs(5) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/gemm_py.cc b/src/op/gemm_py.cc index aa6c02823..f12a2de5e 100644 --- a/src/op/gemm_py.cc +++ b/src/op/gemm_py.cc @@ -318,7 +318,7 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T, return results; } -TIR_REGISTER_TL_OP(GemmPy, gemm_py) +TIR_REGISTER_TL_TILE_OP(GemmPy, gemm_py) .set_num_inputs(5) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/gemm_sp.cc b/src/op/gemm_sp.cc index df923d0e9..bdabefaf2 100644 --- a/src/op/gemm_sp.cc +++ b/src/op/gemm_sp.cc @@ -302,7 +302,7 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T, return results; } -TIR_REGISTER_TL_OP(GemmSP, gemm_sp) +TIR_REGISTER_TL_TILE_OP(GemmSP, gemm_sp) .set_num_inputs(5) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/operator.h b/src/op/operator.h index 0d9f859a7..1453f9c1e 100644 --- a/src/op/operator.h +++ b/src/op/operator.h @@ -77,12 +77,12 @@ TileOperator ParseOperator(Stmt stmt); using OpBuilderFunc = ffi::TypedFunction)>; -#define TIR_REGISTER_TL_OP(Entry, OpName) \ +#define TIR_REGISTER_TL_TILE_OP(Entry, OpName) \ const Op &Entry::Get() { \ - static const Op &op = Op::Get("tl." #OpName); \ + static const Op &op = Op::Get("tl.tileop." #OpName); \ return op; \ } \ - TVM_REGISTER_OP("tl." #OpName) \ + TVM_REGISTER_OP("tl.tileop." 
#OpName) \ .set_attr("TScriptPrinterName", #OpName) \ .set_attr( \ "TLOpBuilder", [](Array args) { return Entry(args); }) diff --git a/src/op/reduce.cc b/src/op/reduce.cc index caf9198a7..40c9b83cd 100644 --- a/src/op/reduce.cc +++ b/src/op/reduce.cc @@ -478,7 +478,7 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T, return {}; } -TIR_REGISTER_TL_OP(ReduceOp, reduce) +TIR_REGISTER_TL_TILE_OP(ReduceOp, reduce) .set_num_inputs(4) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); @@ -563,7 +563,7 @@ LayoutMap CumSumOpNode::InferLayout(const LayoutInferArgs &T, return {}; } -TIR_REGISTER_TL_OP(CumSumOp, cumsum) +TIR_REGISTER_TL_TILE_OP(CumSumOp, cumsum) .set_num_inputs(4) .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); diff --git a/src/op/region.cc b/src/op/region.cc index 2a1f27456..25e78eba8 100644 --- a/src/op/region.cc +++ b/src/op/region.cc @@ -76,17 +76,7 @@ LayoutMap RegionOpNode::InferLayout(const LayoutInferArgs &T, return {}; } -const Op &RegionOp::Get() { - static const Op &op = Op::Get("tl.region"); - return op; -} - -TVM_REGISTER_OP("tl.region") - .set_attr("TScriptPrinterName", "region") - .set_attr("TLOpBuilder", - [](Array args) { - return RegionOp(args); - }) +TIR_REGISTER_TL_TILE_OP(RegionOp, region) .set_num_inputs(-1) .set_attr("TCallEffectKind", Integer(CallEffectKind::kPure)); diff --git a/testing/python/analysis/test_tilelang_nested_loop_checker.py b/testing/python/analysis/test_tilelang_nested_loop_checker.py index b572a707a..d3c2ec20e 100644 --- a/testing/python/analysis/test_tilelang_nested_loop_checker.py +++ b/testing/python/analysis/test_tilelang_nested_loop_checker.py @@ -550,5 +550,178 @@ def test_mixed_pp(): run_gemm_mixed_pp(order=[0, 1, 2], stage=[0, 0, 1]) +""" +TiledOp in a T.Parallel is also not permitted. 
+""" + + +def matmul_with_parallel( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + threads, + order, + stage, +): + A_shape = (M, K) + B_shape = (K, N) + A_shared_shape = (block_M, block_K) + B_shared_shape = (block_K, block_N) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype) + B_shared = T.alloc_shared(B_shared_shape, in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, block_K), order=order, stage=stage): + for i, j in T.Parallel(block_M, block_K): + A_shared[i, j] = A[by * block_M + i, k * block_K + j] + for i, j in T.Parallel(block_K, block_N): + B_shared[i, j] = B[k * block_K + i, bx * block_N + j] + + # T.copy(A[by * block_M, k * block_K], A_shared) + # T.copy(B[k * block_K, bx * block_N], B_shared) + + for _ in T.Parallel(1): + T.gemm(A_shared, B_shared, C_local, False, False) + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def run_gemm_tiled_op_with_parallel( + order, + stage, +): + M = 1024 + N = 1024 + K = 1024 + block_M = 128 + block_N = 128 + block_K = 32 + in_dtype = "float16" + out_dtype = "float16" + dtypeAccum = "float32" + num_threads = 128 + + program = matmul_nested_pipa( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + dtypeAccum, + num_threads, + order, + stage, + ) + + kernel = tilelang.compile( + program, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) + profiler = kernel.get_profiler() + + def ref_program(A, B): + import torch + + if in_dtype == "float32": + # Convert float32 to tfloat32 because tfloat32 mma cannot truncate + # float32 automatically, 
so subtract 0x1000 from the raw int32 bits first + A = ((A.view(torch.int32) - 0x1000)).view(torch.float32) + B = ((B.view(torch.int32) - 0x1000)).view(torch.float32) + C = torch.matmul(A.to(torch.float), B.to(torch.float)) + C = C.to(torch.__getattribute__(out_dtype)) + return C + + profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) + + program1 = matmul_with_parallel( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + dtypeAccum, + num_threads, + order, + stage, + ) + with pytest.raises(ValueError): + tilelang.compile( + program1, + out_idx=[2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }) + + +@tilelang.jit(out_idx=[1]) +def tir_op_with_parallel(length=256, block=16, dtype="float32"): + + @T.prim_func + def main( + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), + ): + with T.Kernel(1, threads=length) as _: + for i in T.Parallel(length // block): + for j in T.Parallel(block): + B[i * block + j] = T.max(A[i * block + j], 0.0) + + return main + + +@tilelang.jit(out_idx=[1]) +def customize_op_with_parallel(length=256, block=16, dtype="float32"): + + @T.prim_func + def main( + A: T.Tensor((length,), dtype), + B: T.Tensor((length,), dtype), + ): + with T.Kernel(1, threads=length) as _: + for i in T.Parallel(length // block): + for j in T.Parallel(block): + B[i * block + j] = A[i * block + j] + T.atomic_add(B[i * block + j], 1.0) + + return main + + +def test_tiled_op_with_parallel(): + run_gemm_tiled_op_with_parallel(order=[0, 1, 2], stage=[0, 0, 1]) + + kernel1 = tir_op_with_parallel(length=256, block=16) + data = _require_cuda_tensor((256,), torch.float32) + result1 = kernel1(data) + torch.testing.assert_close(result1, torch.relu(data), atol=1e-5, rtol=1e-5) + kernel2 = customize_op_with_parallel(length=256, block=16) + result2 = kernel2(data) + torch.testing.assert_close(result2, data + 1, atol=1e-5, rtol=1e-5) + + if __name__ == "__main__":
tilelang.testing.main() diff --git a/tilelang/analysis/ast_printer.py b/tilelang/analysis/ast_printer.py index c54ec5cf9..e634e0271 100644 --- a/tilelang/analysis/ast_printer.py +++ b/tilelang/analysis/ast_printer.py @@ -14,7 +14,7 @@ def pre_visit(statement: tir.Stmt) -> None: Pre-order visitor to print all visited statements. """ - print(f"Visiting statement: {type(statement)}") + print(f"Visiting statement: {type(statement)}, {statement}") def pass_fn(func: PrimFunc, mod, ctx) -> PrimFunc: new_body = ir_transform(func.body, pre_visit, None) diff --git a/tilelang/analysis/nested_loop_checker.py b/tilelang/analysis/nested_loop_checker.py index 7a0d94daa..eff0fc2db 100644 --- a/tilelang/analysis/nested_loop_checker.py +++ b/tilelang/analysis/nested_loop_checker.py @@ -1,6 +1,7 @@ from tvm import tir from tvm.tir import ( For, + Call, PrimFunc, PyStmtExprVisitor, ) @@ -17,6 +18,12 @@ def is_pipelined_for(op: For) -> bool: return any(key in op.annotations for key in anno_keys) +def is_tile_op(op: Call) -> bool: + """Check if a call is a tile-op""" + + return op.op.get_attr("TLOpBuilder") is not None + + @tir.functor.visitor class _NestedLoopCheckVisitor(PyStmtExprVisitor): @@ -39,7 +46,7 @@ def visit_for_(self, op: For) -> None: "Nested parallel loops are not allowed. " "Please check your loop structure.") self.in_parallel_context = True - self.visit_stmt(child) + super().visit_for_(op) self.in_parallel_context = False return elif is_pipelined_for(op): @@ -48,7 +55,14 @@ def visit_for_(self, op: For) -> None: "Pipelined loop cannot be nested inside a parallel loop. " "Please check your loop structure.") - self.visit_stmt(op.body) + super().visit_for_(op) + + def visit_call_(self, op: Call) -> None: + if self.in_parallel_context and is_tile_op(op): + raise ValueError("[Tilelang Semantic Check] " + "Only elementwise operations are allowed inside a parallel loop. " \ + f"Got a tile-op \"{op.op}\"." 
+ ) def NestedLoopChecker(): diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index dfa8050a3..1a98c8937 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -76,10 +76,8 @@ def PreLowerSemanticCheck(mod: IRModule) -> None: # Debug # tilelang.analysis.ASTPrinter()(mod) - # Check if there are any invalid nested loops. tilelang.analysis.NestedLoopChecker()(mod) - # Check if there are any invalid symbolic T.Parallel + fragment access. tilelang.analysis.FragmentLoopChecker()(mod) diff --git a/tilelang/language/atomic.py b/tilelang/language/atomic.py index 56f87473f..07e45bbc8 100644 --- a/tilelang/language/atomic.py +++ b/tilelang/language/atomic.py @@ -212,9 +212,9 @@ def get_extent(data): "return_prev is not supported for tile-region-based atomic operations") if memory_order is None: - return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst, use_tma, 0) + return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma, 0) else: - return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst, use_tma, + return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma, _MEMORY_ORDER_ID_MAP[memory_order]) diff --git a/tilelang/language/copy.py b/tilelang/language/copy.py index 965919fd4..cabc4a3e4 100644 --- a/tilelang/language/copy.py +++ b/tilelang/language/copy.py @@ -90,7 +90,7 @@ def get_extent(data): eviction_policy = 0 else: eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy] - return tir.call_intrin("handle", tir.op.Op.get("tl.copy"), src, dst, coalesced_width, + return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.copy"), src, dst, coalesced_width, disable_tma, eviction_policy) @@ -124,5 +124,5 @@ def c2d_im2col(img: tir.Buffer, eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy] img_region = to_buffer_region(img, access_type="r") col_region = to_buffer_region(col, access_type="w") - 
return tir.call_intrin("handle", tir.op.Op.get("tl.c2d_im2col"), img_region, col_region, + return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.c2d_im2col"), img_region, col_region, nhw_step, c_step, kernel, stride, dilation, pad, eviction_policy) diff --git a/tilelang/language/experimental/gemm_sp.py b/tilelang/language/experimental/gemm_sp.py index 7cc3d736d..b391d2d00 100644 --- a/tilelang/language/experimental/gemm_sp.py +++ b/tilelang/language/experimental/gemm_sp.py @@ -70,7 +70,7 @@ def legalize_arguments(arg: tir.Buffer | tir.Var): C_arg = to_buffer_region(C, access_type="rw") return tir.call_intrin( "handle", - tir.op.Op.get("tl.gemm_sp"), + tir.op.Op.get("tl.tileop.gemm_sp"), A_arg, E_arg, B_arg, diff --git a/tilelang/language/fill.py b/tilelang/language/fill.py index fbbcf1b63..b23733377 100644 --- a/tilelang/language/fill.py +++ b/tilelang/language/fill.py @@ -32,7 +32,7 @@ def fill(buffer: tir.Buffer | tir.BufferRegion | tir.BufferLoad, value: tir.Prim extents = [tir.IntImm("int32", 1) for _ in buffer.indices] else: extents = [] - return tir.call_intrin("handle", tir.op.Op.get("tl.fill"), + return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.fill"), to_buffer_region(buffer, access_type="w", extents=extents), value) diff --git a/tilelang/language/gemm.py b/tilelang/language/gemm.py index 2bfd3a0cf..db8e04aba 100644 --- a/tilelang/language/gemm.py +++ b/tilelang/language/gemm.py @@ -116,7 +116,7 @@ def gemm_v1( ): """GEMM v1: use op tl.gemm.""" return _gemm_impl( - "tl.gemm", + "tl.tileop.gemm", A, B, C, @@ -145,7 +145,7 @@ def gemm_v2( ): """GEMM v2: use op tl.gemm_py.""" return _gemm_impl( - "tl.gemm_py", + "tl.tileop.gemm_py", A, B, C, diff --git a/tilelang/language/reduce.py b/tilelang/language/reduce.py index 3c4d8187b..fb84b6d78 100644 --- a/tilelang/language/reduce.py +++ b/tilelang/language/reduce.py @@ -13,6 +13,9 @@ def _legalize_dim(buffer: tir.Buffer, dim: int): return dim +_REDUCE_OP_KEY = "tl.tileop.reduce" + + def reduce(buffer: 
tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool): """Perform a reduction operation on a buffer along a specified dimension. @@ -50,7 +53,7 @@ def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int copy(buffer, red_frag_in) tir.call_intrin( "handle", - tir.op.Op.get("tl.reduce"), + tir.op.Op.get(_REDUCE_OP_KEY), to_buffer_region(red_frag_in, access_type="r"), to_buffer_region(red_frag_out, access_type="w"), reduce_type, @@ -65,7 +68,7 @@ def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int copy(buffer, red_frag_in) tir.call_intrin( "handle", - tir.op.Op.get("tl.reduce"), + tir.op.Op.get(_REDUCE_OP_KEY), to_buffer_region(red_frag_in, access_type="r"), to_buffer_region(out, access_type="w"), reduce_type, @@ -78,7 +81,7 @@ def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int tir.call_intrin( "handle", - tir.op.Op.get("tl.reduce"), + tir.op.Op.get(_REDUCE_OP_KEY), to_buffer_region(buffer, access_type="r"), to_buffer_region(red_frag_out, access_type="w"), reduce_type, @@ -89,7 +92,7 @@ def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int elif is_fragment(buffer) and is_fragment(out): tir.call_intrin( "handle", - tir.op.Op.get("tl.reduce"), + tir.op.Op.get(_REDUCE_OP_KEY), to_buffer_region(buffer, access_type="r"), to_buffer_region(out, access_type="w"), reduce_type, @@ -245,7 +248,7 @@ def cumsum_fragment(src: tir.Buffer, dst: tir.Buffer, dim: int, reverse: bool) - copy(src, cumsum_smem) tir.call_intrin( "handle", - tir.op.Op.get("tl.cumsum"), + tir.op.Op.get("tl.tileop.cumsum"), to_buffer_region(cumsum_smem, access_type="r"), to_buffer_region(cumsum_smem, access_type="w"), dim, @@ -299,7 +302,7 @@ def cumsum(src: tir.Buffer, dst: tir.Buffer | None = None, dim: int = 0, reverse return cumsum_fragment(src, dst, dim, reverse) return tir.call_intrin( "handle", - tir.op.Op.get("tl.cumsum"), + tir.op.Op.get("tl.tileop.cumsum"), 
to_buffer_region(src, access_type="r"), to_buffer_region(dst, access_type="w"), dim, @@ -309,7 +312,7 @@ def cumsum(src: tir.Buffer, dst: tir.Buffer | None = None, dim: int = 0, reverse def finalize_reducer(reducer: tir.Buffer): """ - Finalize a reducer buffer by emitting the `tl.finalize_reducer` intrinsic. + Finalize a reducer buffer by emitting the `tl.tileop.finalize_reducer` intrinsic. This returns a TVM `tir.Call` handle that finalizes the given reducer using its writable pointer. The call does not modify Python objects directly; it produces the low-level intrinsic call used by the IR. @@ -322,7 +325,7 @@ def finalize_reducer(reducer: tir.Buffer): """ return tir.call_intrin( "handle", - tir.op.Op.get("tl.finalize_reducer"), + tir.op.Op.get("tl.tileop.finalize_reducer"), to_buffer_region(reducer, access_type="w"), ) diff --git a/tilelang/language/utils.py b/tilelang/language/utils.py index 75fea4c09..136bc0bac 100644 --- a/tilelang/language/utils.py +++ b/tilelang/language/utils.py @@ -7,7 +7,7 @@ def region(buffer: BufferLoad, access_type: str, *args: PrimExpr): """Create a tl.region call for a BufferLoad and extents.""" access_type = {"r": 1, "w": 2, "rw": 3}[access_type] - return T.call_intrin("handle", op.Op.get("tl.region"), buffer, access_type, *args) + return T.call_intrin("handle", op.Op.get("tl.tileop.region"), buffer, access_type, *args) def buffer_load_to_tile_region(load: BufferLoad, access_type: str, extents: list[PrimExpr]):