Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 50 additions & 2 deletions testing/python/language/test_tilelang_language_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ def main(
return main


def run_reduce_sum_clear(M, N, dtype=T.float32):
program = reduce_sum_test_clear(M, N, dtype)
def run_reduce_sum_clear(M, N, dtype=T.float32, tl_func=reduce_sum_test_clear):
program = tl_func(M, N, dtype)
jit_kernel = tl.compile(program, out_idx=-1)

def ref_program(A):
Expand Down Expand Up @@ -219,5 +219,53 @@ def test_reduce_max_clear():
run_reduce_max_clear(256, 256, T.float16)


def reduce_sum_test_clear_B_shared(M, N, dtype=T.float32):
    # Build and return a TileLang prim_func that reduces an (M, N) input
    # along dim=1 into an (M,) output using T.reduce_sum(..., clear=False).
    # The reduce input is a buffer from T.alloc_fragment; the reduce output
    # is a buffer from T.alloc_shared.
    #
    # M, N:  extents of input tensor A; output tensor B has shape (M,).
    # dtype: element type used for A, B and both scratch buffers.
    #
    # NOTE: the `dtype` default above is evaluated at definition time, so it
    # uses whatever `T` is bound at module level; the import below rebinds
    # `T` only inside this function body.
    import tilelang.language as T

    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M,), dtype),
    ):
        # Single kernel instance launched with threads=32.
        with T.Kernel(1, threads=32) as _:
            A_local = T.alloc_fragment((M, N), dtype)
            B_shared = T.alloc_shared((M,), dtype)

            T.copy(A, A_local)
            # Pre-fill the output with 1 so that a reduce which wrongly
            # discards the existing contents becomes observable.
            T.fill(B_shared, 1)
            # clear=False: the existing contents of B_shared are meant to be
            # kept and combined with the row sums (presumably yielding
            # row_sum + 1 -- verify against the reference in
            # run_reduce_sum_clear).
            T.reduce_sum(A_local, B_shared, dim=1, clear=False)
            T.copy(B_shared, B)

    return main


def test_reduce_sum_clear_B_shared():
    """Check reduce_sum with clear=False when the output buffer is shared.

    Drives the common ``run_reduce_sum_clear`` harness on a 256x256 float32
    problem, substituting the ``reduce_sum_test_clear_B_shared`` kernel
    builder for the harness default.
    """
    run_reduce_sum_clear(
        256,
        256,
        dtype=T.float32,
        tl_func=reduce_sum_test_clear_B_shared,
    )


def reduce_sum_test_clear_AB_shared(M, N, dtype=T.float32):
    # Build and return a TileLang prim_func that reduces an (M, N) input
    # along dim=1 into an (M,) output using T.reduce_sum(..., clear=False).
    # Both the reduce input and the reduce output are buffers obtained from
    # T.alloc_shared.
    #
    # M, N:  extents of input tensor A; output tensor B has shape (M,).
    # dtype: element type used for A, B and both scratch buffers.
    #
    # NOTE: the `dtype` default above is evaluated at definition time, so it
    # uses whatever `T` is bound at module level; the import below rebinds
    # `T` only inside this function body.
    import tilelang.language as T

    @T.prim_func
    def main(
        A: T.Tensor((M, N), dtype),
        B: T.Tensor((M,), dtype),
    ):
        # Single kernel instance launched with threads=32.
        with T.Kernel(1, threads=32) as _:
            A_shared = T.alloc_shared((M, N), dtype)
            B_shared = T.alloc_shared((M,), dtype)

            # NOTE(review): disable_tma=True presumably forces a non-TMA
            # copy path for the global -> shared load -- confirm intent.
            T.copy(A, A_shared, disable_tma=True)
            # Pre-fill the output with 1 so that a reduce which wrongly
            # discards the existing contents becomes observable.
            T.fill(B_shared, 1)
            # clear=False: the existing contents of B_shared are meant to be
            # kept and combined with the row sums (presumably yielding
            # row_sum + 1 -- verify against the reference in
            # run_reduce_sum_clear).
            T.reduce_sum(A_shared, B_shared, dim=1, clear=False)
            T.copy(B_shared, B)

    return main


def test_reduce_sum_clear_AB_shared():
    """Check reduce_sum with clear=False when input and output are shared.

    Drives the common ``run_reduce_sum_clear`` harness on a 64x64 float32
    problem, substituting the ``reduce_sum_test_clear_AB_shared`` kernel
    builder for the harness default.
    """
    run_reduce_sum_clear(
        64,
        64,
        dtype=T.float32,
        tl_func=reduce_sum_test_clear_AB_shared,
    )


# When executed directly as a script, hand control to tilelang's test entry
# point (presumably it discovers and runs the test_* functions in this
# module -- confirm against tilelang.testing).
if __name__ == "__main__":
    tilelang.testing.main()
6 changes: 6 additions & 0 deletions tilelang/language/reduce_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int
IRBuilder.name(buffer.name + "_frag", red_frag_in)
IRBuilder.name(out.name + "_frag", red_frag_out)

if not clear:
copy(out, red_frag_out)

copy(buffer, red_frag_in)
tir.call_intrin(
"handle",
Expand Down Expand Up @@ -78,6 +81,9 @@ def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int
red_frag_out = alloc_fragment(out.shape, out.dtype)
IRBuilder.name(out.name + "_frag", red_frag_out)

if not clear:
copy(out, red_frag_out)

tir.call_intrin(
"handle",
tir.op.Op.get(_REDUCE_OP_KEY),
Expand Down
Loading