# Patch summary for testing/python/language/test_tilelang_language_reduce.py
# (these lines precede the first "diff --git" header and are not part of any hunk).
#
# * run_reduce_sum_clear gains an optional `tl_func` parameter, defaulting to
#   reduce_sum_test_clear, and builds the program from `tl_func(M, N, dtype)`
#   instead of always calling reduce_sum_test_clear.
# * Two kernel factories are added, each paired with a test that passes the
#   factory to run_reduce_sum_clear as `tl_func`:
#     - reduce_sum_test_clear_B_shared: input copied into a buffer from
#       T.alloc_fragment, output buffer from T.alloc_shared; run at 256x256.
#     - reduce_sum_test_clear_AB_shared: input and output buffers both from
#       T.alloc_shared, input copied with disable_tma=True; run at 64x64.
#   Both fill the output buffer with 1, call T.reduce_sum(..., dim=1,
#   clear=False), then copy the output buffer to B.
# * The accompanying reduce_op.py hunks add `copy(out, red_frag_out)` under
#   `if not clear:`; presumably these tests exist to exercise that path.
#
# NOTE(review): the reason for disable_tma=True and for the smaller 64x64 shape
# in the AB_shared case is not visible in this patch -- confirm with the author.
# NOTE(review): the body of ref_program is outside the hunk context; confirm it
# accounts for the initial fill value of 1.
diff --git a/testing/python/language/test_tilelang_language_reduce.py b/testing/python/language/test_tilelang_language_reduce.py
index 1d9bf6130..f12c5bc4a 100644
--- a/testing/python/language/test_tilelang_language_reduce.py
+++ b/testing/python/language/test_tilelang_language_reduce.py
@@ -159,8 +159,8 @@ def main(
     return main
 
 
-def run_reduce_sum_clear(M, N, dtype=T.float32):
-    program = reduce_sum_test_clear(M, N, dtype)
+def run_reduce_sum_clear(M, N, dtype=T.float32, tl_func=reduce_sum_test_clear):
+    program = tl_func(M, N, dtype)
     jit_kernel = tl.compile(program, out_idx=-1)
 
     def ref_program(A):
@@ -219,5 +219,53 @@ def test_reduce_max_clear():
     run_reduce_max_clear(256, 256, T.float16)
 
 
+def reduce_sum_test_clear_B_shared(M, N, dtype=T.float32):
+    import tilelang.language as T
+
+    @T.prim_func
+    def main(
+        A: T.Tensor((M, N), dtype),
+        B: T.Tensor((M,), dtype),
+    ):
+        with T.Kernel(1, threads=32) as _:
+            A_local = T.alloc_fragment((M, N), dtype)
+            B_shared = T.alloc_shared((M,), dtype)
+
+            T.copy(A, A_local)
+            T.fill(B_shared, 1)
+            T.reduce_sum(A_local, B_shared, dim=1, clear=False)
+            T.copy(B_shared, B)
+
+    return main
+
+
+def test_reduce_sum_clear_B_shared():
+    run_reduce_sum_clear(256, 256, T.float32, reduce_sum_test_clear_B_shared)
+
+
+def reduce_sum_test_clear_AB_shared(M, N, dtype=T.float32):
+    import tilelang.language as T
+
+    @T.prim_func
+    def main(
+        A: T.Tensor((M, N), dtype),
+        B: T.Tensor((M,), dtype),
+    ):
+        with T.Kernel(1, threads=32) as _:
+            A_shared = T.alloc_shared((M, N), dtype)
+            B_shared = T.alloc_shared((M,), dtype)
+
+            T.copy(A, A_shared, disable_tma=True)
+            T.fill(B_shared, 1)
+            T.reduce_sum(A_shared, B_shared, dim=1, clear=False)
+            T.copy(B_shared, B)
+
+    return main
+
+
+def test_reduce_sum_clear_AB_shared():
+    run_reduce_sum_clear(64, 64, T.float32, reduce_sum_test_clear_AB_shared)
+
+
 if __name__ == "__main__":
     tilelang.testing.main()
diff --git a/tilelang/language/reduce_op.py b/tilelang/language/reduce_op.py
index 
9db56df0d..e6e69594e 100644 --- a/tilelang/language/reduce_op.py +++ b/tilelang/language/reduce_op.py @@ -49,6 +49,9 @@ def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int IRBuilder.name(buffer.name + "_frag", red_frag_in) IRBuilder.name(out.name + "_frag", red_frag_out) + if not clear: + copy(out, red_frag_out) + copy(buffer, red_frag_in) tir.call_intrin( "handle", @@ -78,6 +81,9 @@ def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int red_frag_out = alloc_fragment(out.shape, out.dtype) IRBuilder.name(out.name + "_frag", red_frag_out) + if not clear: + copy(out, red_frag_out) + tir.call_intrin( "handle", tir.op.Op.get(_REDUCE_OP_KEY),