2 changes: 1 addition & 1 deletion src/op/atomic_add.cc
@@ -539,7 +539,7 @@ Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
return vectorized_thread_loop;
}

TIR_REGISTER_TL_OP(AtomicAdd, atomicadd)
TIR_REGISTER_TL_TILE_OP(AtomicAdd, atomicadd)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
4 changes: 2 additions & 2 deletions src/op/copy.cc
@@ -2037,7 +2037,7 @@ Array<PrimExpr> TMAIm2ColDesc::EncodeCallArgs() const {
// - Takes 5 inputs: src_buffer, dst_buffer, coalesced_width, disable_tma,
// eviction_policy
// - Marked as opaque since it has side effects (memory writes)
TIR_REGISTER_TL_OP(Copy, copy)
TIR_REGISTER_TL_TILE_OP(Copy, copy)
.set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
@@ -2062,7 +2062,7 @@ LayoutMap Conv2DIm2ColOpNode::InferLayout(const LayoutInferArgs &T,
// - Takes 9 inputs: src_buffer, dst_buffer, nhw_step, c_step, kernel, stride,
// dilation, padding, eviction_policy
// - Marked as opaque since it has side effects (memory writes)
TIR_REGISTER_TL_OP(Conv2DIm2ColOp, c2d_im2col)
TIR_REGISTER_TL_TILE_OP(Conv2DIm2ColOp, c2d_im2col)
.set_num_inputs(9)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
2 changes: 1 addition & 1 deletion src/op/fill.cc
@@ -209,7 +209,7 @@ LayoutMap FillNode::InferLayout(const LayoutInferArgs &T,
return {};
}

TIR_REGISTER_TL_OP(Fill, fill)
TIR_REGISTER_TL_TILE_OP(Fill, fill)
.set_num_inputs(2)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
2 changes: 1 addition & 1 deletion src/op/finalize_reducer.cc
@@ -159,7 +159,7 @@ TileOperator FinalizeReducerOpNode::Clone() const {
return TileOperator(node);
}

TIR_REGISTER_TL_OP(FinalizeReducerOp, finalize_reducer)
TIR_REGISTER_TL_TILE_OP(FinalizeReducerOp, finalize_reducer)
.set_num_inputs(1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
2 changes: 1 addition & 1 deletion src/op/gemm.cc
@@ -826,7 +826,7 @@ LayoutMap GemmNode::InferLayout(const LayoutInferArgs &T,
return results;
}

TIR_REGISTER_TL_OP(Gemm, gemm)
TIR_REGISTER_TL_TILE_OP(Gemm, gemm)
.set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
2 changes: 1 addition & 1 deletion src/op/gemm_py.cc
@@ -318,7 +318,7 @@ LayoutMap GemmPyNode::InferLayout(const LayoutInferArgs &T,
return results;
}

TIR_REGISTER_TL_OP(GemmPy, gemm_py)
TIR_REGISTER_TL_TILE_OP(GemmPy, gemm_py)
.set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
2 changes: 1 addition & 1 deletion src/op/gemm_sp.cc
@@ -302,7 +302,7 @@ LayoutMap GemmSPNode::InferLayout(const LayoutInferArgs &T,
return results;
}

TIR_REGISTER_TL_OP(GemmSP, gemm_sp)
TIR_REGISTER_TL_TILE_OP(GemmSP, gemm_sp)
.set_num_inputs(5)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
6 changes: 3 additions & 3 deletions src/op/operator.h
@@ -77,12 +77,12 @@ TileOperator ParseOperator(Stmt stmt);

using OpBuilderFunc = ffi::TypedFunction<TileOperator(Array<PrimExpr>)>;

#define TIR_REGISTER_TL_OP(Entry, OpName) \
#define TIR_REGISTER_TL_TILE_OP(Entry, OpName) \
const Op &Entry::Get() { \
static const Op &op = Op::Get("tl." #OpName); \
static const Op &op = Op::Get("tl.tileop." #OpName); \
return op; \
} \
TVM_REGISTER_OP("tl." #OpName) \
TVM_REGISTER_OP("tl.tileop." #OpName) \
.set_attr<TScriptPrinterName>("TScriptPrinterName", #OpName) \
.set_attr<OpBuilderFunc>( \
"TLOpBuilder", [](Array<PrimExpr> args) { return Entry(args); })
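The operator.h change above renames the registration macro from TIR_REGISTER_TL_OP to TIR_REGISTER_TL_TILE_OP and moves every tile-op from the "tl." prefix to "tl.tileop.". A minimal Python sketch (assuming a tilelang build that includes this PR, so the renamed ops are registered on import) of how the new names resolve in the TVM op registry and how the TLOpBuilder attribute attached by the macro separates tile-ops from ordinary TIR intrinsics:

from tvm import tir

import tilelang  # noqa: F401  (importing tilelang registers the "tl.tileop.*" ops)

# Tile-ops now live under the "tl.tileop." prefix and expose the
# TLOpBuilder attribute attached by TIR_REGISTER_TL_TILE_OP.
copy_op = tir.op.Op.get("tl.tileop.copy")
print(copy_op.get_attr("TScriptPrinterName"))       # "copy"
print(copy_op.get_attr("TLOpBuilder") is not None)  # True -> recognised as a tile-op

# A plain TIR intrinsic carries no TLOpBuilder attribute, so the new
# is_tile_op() check in nested_loop_checker.py returns False for it.
exp_op = tir.op.Op.get("tir.exp")
print(exp_op.get_attr("TLOpBuilder") is not None)   # False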
4 changes: 2 additions & 2 deletions src/op/reduce.cc
@@ -478,7 +478,7 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
return {};
}

TIR_REGISTER_TL_OP(ReduceOp, reduce)
TIR_REGISTER_TL_TILE_OP(ReduceOp, reduce)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
@@ -563,7 +563,7 @@ LayoutMap CumSumOpNode::InferLayout(const LayoutInferArgs &T,
return {};
}

TIR_REGISTER_TL_OP(CumSumOp, cumsum)
TIR_REGISTER_TL_TILE_OP(CumSumOp, cumsum)
.set_num_inputs(4)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));
12 changes: 1 addition & 11 deletions src/op/region.cc
@@ -76,17 +76,7 @@ LayoutMap RegionOpNode::InferLayout(const LayoutInferArgs &T,
return {};
}

const Op &RegionOp::Get() {
static const Op &op = Op::Get("tl.region");
return op;
}

TVM_REGISTER_OP("tl.region")
.set_attr<TScriptPrinterName>("TScriptPrinterName", "region")
.set_attr<OpBuilderFunc>("TLOpBuilder",
[](Array<PrimExpr> args) {
return RegionOp(args);
})
TIR_REGISTER_TL_TILE_OP(RegionOp, region)
.set_num_inputs(-1)
.set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kPure));
173 changes: 173 additions & 0 deletions testing/python/analysis/test_tilelang_nested_loop_checker.py
@@ -550,5 +550,178 @@ def test_mixed_pp():
    run_gemm_mixed_pp(order=[0, 1, 2], stage=[0, 0, 1])


"""
TiledOp in a T.Parallel is also not permitted.
"""


def matmul_with_parallel(
    M,
    N,
    K,
    block_M,
    block_N,
    block_K,
    in_dtype,
    out_dtype,
    accum_dtype,
    threads,
    order,
    stage,
):
    A_shape = (M, K)
    B_shape = (K, N)
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_K, block_N)

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), order=order, stage=stage):
                for i, j in T.Parallel(block_M, block_K):
                    A_shared[i, j] = A[by * block_M + i, k * block_K + j]
                for i, j in T.Parallel(block_K, block_N):
                    B_shared[i, j] = B[k * block_K + i, bx * block_N + j]

                # T.copy(A[by * block_M, k * block_K], A_shared)
                # T.copy(B[k * block_K, bx * block_N], B_shared)

                for _ in T.Parallel(1):
                    T.gemm(A_shared, B_shared, C_local, False, False)
            T.copy(C_local, C[by * block_M, bx * block_N])

    return main


def run_gemm_tiled_op_with_parallel(
    order,
    stage,
):
    M = 1024
    N = 1024
    K = 1024
    block_M = 128
    block_N = 128
    block_K = 32
    in_dtype = "float16"
    out_dtype = "float16"
    dtypeAccum = "float32"
    num_threads = 128

    program = matmul_nested_pipa(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_threads,
        order,
        stage,
    )

    kernel = tilelang.compile(
        program,
        out_idx=[2],
        pass_configs={
            tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
            tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        })
    profiler = kernel.get_profiler()

    def ref_program(A, B):
        import torch

        if in_dtype == "float32":
            # Convert float32 to tfloat32 because tfloat32 mma cannot truncate
            # float32 automatically; subtracting 0x1000 adjusts the mantissa bits
            # of the raw int32 representation.
            A = ((A.view(torch.int32) - 0x1000)).view(torch.float32)
            B = ((B.view(torch.int32) - 0x1000)).view(torch.float32)
        C = torch.matmul(A.to(torch.float), B.to(torch.float))
        C = C.to(torch.__getattribute__(out_dtype))
        return C

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)

    program1 = matmul_with_parallel(
        M,
        N,
        K,
        block_M,
        block_N,
        block_K,
        in_dtype,
        out_dtype,
        dtypeAccum,
        num_threads,
        order,
        stage,
    )
    with pytest.raises(ValueError):
        tilelang.compile(
            program1,
            out_idx=[2],
            pass_configs={
                tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
                tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
            })


@tilelang.jit(out_idx=[1])
def tir_op_with_parallel(length=256, block=16, dtype="float32"):

    @T.prim_func
    def main(
            A: T.Tensor((length,), dtype),
            B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block):
                for j in T.Parallel(block):
                    B[i * block + j] = T.max(A[i * block + j], 0.0)

    return main


@tilelang.jit(out_idx=[1])
def customize_op_with_parallel(length=256, block=16, dtype="float32"):

    @T.prim_func
    def main(
            A: T.Tensor((length,), dtype),
            B: T.Tensor((length,), dtype),
    ):
        with T.Kernel(1, threads=length) as _:
            for i in T.Parallel(length // block):
                for j in T.Parallel(block):
                    B[i * block + j] = A[i * block + j]
                    T.atomic_add(B[i * block + j], 1.0)

    return main


def test_tiled_op_with_parallel():
    run_gemm_tiled_op_with_parallel(order=[0, 1, 2], stage=[0, 0, 1])

    kernel1 = tir_op_with_parallel(length=256, block=16)
    data = _require_cuda_tensor((256,), torch.float32)
    result1 = kernel1(data)
    torch.testing.assert_close(result1, torch.relu(data), atol=1e-5, rtol=1e-5)
    kernel2 = customize_op_with_parallel(length=256, block=16)
    result2 = kernel2(data)
    torch.testing.assert_close(result2, data + 1, atol=1e-5, rtol=1e-5)


if __name__ == "__main__":
    tilelang.testing.main()
2 changes: 1 addition & 1 deletion tilelang/analysis/ast_printer.py
@@ -14,7 +14,7 @@ def pre_visit(statement: tir.Stmt) -> None:
Pre-order visitor to print all visited statements.
"""

print(f"Visiting statement: {type(statement)}")
print(f"Visiting statement: {type(statement)}, {statement}")

def pass_fn(func: PrimFunc, mod, ctx) -> PrimFunc:
new_body = ir_transform(func.body, pre_visit, None)
18 changes: 16 additions & 2 deletions tilelang/analysis/nested_loop_checker.py
@@ -1,6 +1,7 @@
from tvm import tir
from tvm.tir import (
For,
Call,
PrimFunc,
PyStmtExprVisitor,
)
@@ -17,6 +18,12 @@ def is_pipelined_for(op: For) -> bool:
    return any(key in op.annotations for key in anno_keys)


def is_tile_op(op: Call) -> bool:
    """Check if a call is a tile-op."""

    return op.op.get_attr("TLOpBuilder") is not None


@tir.functor.visitor
class _NestedLoopCheckVisitor(PyStmtExprVisitor):

@@ -39,7 +46,7 @@ def visit_for_(self, op: For) -> None:
"Nested parallel loops are not allowed. "
"Please check your loop structure.")
self.in_parallel_context = True
self.visit_stmt(child)
super().visit_for_(op)
self.in_parallel_context = False
return
elif is_pipelined_for(op):
@@ -48,7 +55,14 @@ def visit_for_(self, op: For) -> None:
"Pipelined loop cannot be nested inside a parallel loop. "
"Please check your loop structure.")

self.visit_stmt(op.body)
super().visit_for_(op)

    def visit_call_(self, op: Call) -> None:
        if self.in_parallel_context and is_tile_op(op):
            raise ValueError("[Tilelang Semantic Check] "
                             "Only elementwise operations are allowed inside a parallel loop. "
                             f"Got a tile-op \"{op.op}\".")


def NestedLoopChecker():
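Together with the registration change in operator.h, the new visit_call_ hook is what turns a tile-op inside T.Parallel into a compile-time error. A hedged sketch of the user-visible effect, mirroring the matmul_with_parallel test added above (bad_kernel is a hypothetical example; a CUDA-capable build of tilelang with this PR is assumed):

import pytest
import tilelang
import tilelang.language as T

@T.prim_func
def bad_kernel(A: T.Tensor((128, 128), "float16"), B: T.Tensor((128, 128), "float16"),
               C: T.Tensor((128, 128), "float16")):
    with T.Kernel(1, threads=128) as _:
        A_s = T.alloc_shared((128, 128), "float16")
        B_s = T.alloc_shared((128, 128), "float16")
        C_f = T.alloc_fragment((128, 128), "float32")
        T.copy(A, A_s)
        T.copy(B, B_s)
        T.clear(C_f)
        for _i in T.Parallel(1):
            T.gemm(A_s, B_s, C_f)  # tile-op inside T.Parallel -> rejected by visit_call_
        T.copy(C_f, C)

# PreLowerSemanticCheck runs NestedLoopChecker, so compilation fails early.
with pytest.raises(ValueError):
    tilelang.compile(bad_kernel, out_idx=[2])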
2 changes: 0 additions & 2 deletions tilelang/engine/phase.py
@@ -76,10 +76,8 @@ def PreLowerSemanticCheck(mod: IRModule) -> None:

# Debug
# tilelang.analysis.ASTPrinter()(mod)

# Check if there are any invalid nested loops.
tilelang.analysis.NestedLoopChecker()(mod)

# Check if there are any invalid symbolic T.Parallel + fragment access.
tilelang.analysis.FragmentLoopChecker()(mod)

4 changes: 2 additions & 2 deletions tilelang/language/atomic.py
@@ -212,9 +212,9 @@ def get_extent(data):
"return_prev is not supported for tile-region-based atomic operations")

if memory_order is None:
return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst, use_tma, 0)
return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma, 0)
else:
return T.call_intrin("handle", op.Op.get("tl.atomicadd"), value, dst, use_tma,
return T.call_intrin("handle", op.Op.get("tl.tileop.atomicadd"), value, dst, use_tma,
_MEMORY_ORDER_ID_MAP[memory_order])


4 changes: 2 additions & 2 deletions tilelang/language/copy.py
@@ -90,7 +90,7 @@ def get_extent(data):
eviction_policy = 0
else:
eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy]
return tir.call_intrin("handle", tir.op.Op.get("tl.copy"), src, dst, coalesced_width,
return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.copy"), src, dst, coalesced_width,
disable_tma, eviction_policy)


@@ -124,5 +124,5 @@ def c2d_im2col(img: tir.Buffer,
eviction_policy = {"evict_normal": 0, "evict_first": 1, "evict_last": 2}[eviction_policy]
img_region = to_buffer_region(img, access_type="r")
col_region = to_buffer_region(col, access_type="w")
return tir.call_intrin("handle", tir.op.Op.get("tl.c2d_im2col"), img_region, col_region,
return tir.call_intrin("handle", tir.op.Op.get("tl.tileop.c2d_im2col"), img_region, col_region,
nhw_step, c_step, kernel, stride, dilation, pad, eviction_policy)