diff --git a/src/op/builtin.cc b/src/op/builtin.cc
index e7e86f2f5..ced86cfaa 100644
--- a/src/op/builtin.cc
+++ b/src/op/builtin.cc
@@ -341,5 +341,30 @@ TIR_DEFINE_TL_BUILTIN(tcgen05_mma_arrive)
     .set_attr<TCallEffectKind>("TCallEffectKind",
                                Integer(CallEffectKind::kOpaque));
 
+TIR_DEFINE_TL_BUILTIN(warp_reduce_sum)
+    .set_num_inputs(1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_TL_BUILTIN(warp_reduce_max)
+    .set_num_inputs(1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_TL_BUILTIN(warp_reduce_min)
+    .set_num_inputs(1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_TL_BUILTIN(warp_reduce_bitand)
+    .set_num_inputs(1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
+TIR_DEFINE_TL_BUILTIN(warp_reduce_bitor)
+    .set_num_inputs(1)
+    .set_attr<TCallEffectKind>("TCallEffectKind",
+                               Integer(CallEffectKind::kOpaque));
+
 } // namespace tl
 } // namespace tvm
diff --git a/src/op/builtin.h b/src/op/builtin.h
index f5c7d9edc..7ae638f1a 100644
--- a/src/op/builtin.h
+++ b/src/op/builtin.h
@@ -571,6 +571,31 @@ TVM_DLL const Op &device_assert();
  */
 TVM_DLL const Op &device_assert_with_msg();
 
+/*!
+ * \brief tilelang intrinsic for warp reduction sum.
+ */
+TVM_DLL const Op &warp_reduce_sum();
+
+/*!
+ * \brief tilelang intrinsic for warp reduction max.
+ */
+TVM_DLL const Op &warp_reduce_max();
+
+/*!
+ * \brief tilelang intrinsic for warp reduction min.
+ */
+TVM_DLL const Op &warp_reduce_min();
+
+/*!
+ * \brief tilelang intrinsic for warp reduction bitand.
+ */
+TVM_DLL const Op &warp_reduce_bitand();
+
+/*!
+ * \brief tilelang intrinsic for warp reduction bitor.
+ */
+TVM_DLL const Op &warp_reduce_bitor();
+
 } // namespace tl
 } // namespace tvm
diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc
index dda969253..99512b8be 100644
--- a/src/target/codegen_cuda.cc
+++ b/src/target/codegen_cuda.cc
@@ -2609,6 +2609,16 @@ void CodeGenTileLangCUDA::VisitExpr_(const CallNode *op, std::ostream &os) {
     std::string func_name = math_func(op->dtype, "fdiv", rounding_mode);
     os << func_name << "(" << PrintExpr(op->args[0]) << ", "
        << PrintExpr(op->args[1]) << ")";
+  } else if (op->op.same_as(tl::warp_reduce_sum())) {
+    os << "tl::warp_reduce_sum(" << PrintExpr(op->args[0]) << ")";
+  } else if (op->op.same_as(tl::warp_reduce_max())) {
+    os << "tl::warp_reduce_max(" << PrintExpr(op->args[0]) << ")";
+  } else if (op->op.same_as(tl::warp_reduce_min())) {
+    os << "tl::warp_reduce_min(" << PrintExpr(op->args[0]) << ")";
+  } else if (op->op.same_as(tl::warp_reduce_bitand())) {
+    os << "tl::warp_reduce_bitand(" << PrintExpr(op->args[0]) << ")";
+  } else if (op->op.same_as(tl::warp_reduce_bitor())) {
+    os << "tl::warp_reduce_bitor(" << PrintExpr(op->args[0]) << ")";
   } else {
     CodeGenC::VisitExpr_(op, os);
   }
diff --git a/src/tl_templates/cuda/reduce.h b/src/tl_templates/cuda/reduce.h
index a083c7119..458242649 100644
--- a/src/tl_templates/cuda/reduce.h
+++ b/src/tl_templates/cuda/reduce.h
@@ -250,4 +250,35 @@ template struct CumSum2D {
   }
 };
 
+template <typename T, typename ReduceOp>
+TL_DEVICE T warp_reduce(T value, ReduceOp op) {
+  constexpr uint32_t mask = 0xffffffff;
+  value = op(value, __shfl_xor_sync(mask, value, 16));
+  value = op(value, __shfl_xor_sync(mask, value, 8));
+  value = op(value, __shfl_xor_sync(mask, value, 4));
+  value = op(value, __shfl_xor_sync(mask, value, 2));
+  value = op(value, __shfl_xor_sync(mask, value, 1));
+  return value;
+}
+
+template <typename T> TL_DEVICE T warp_reduce_sum(T value) {
+  return warp_reduce(value, SumOp());
+}
+
+template <typename T> TL_DEVICE T warp_reduce_max(T value) {
+  return warp_reduce(value, MaxOp());
+}
+
+template <typename T> TL_DEVICE T warp_reduce_min(T value) {
+  return warp_reduce(value, MinOp());
+}
+
+template <typename T> TL_DEVICE T warp_reduce_bitand(T value) {
+  return warp_reduce(value, BitAndOp());
+}
+
+template <typename T> TL_DEVICE T warp_reduce_bitor(T value) {
+  return warp_reduce(value, BitOrOp());
+}
+
 } // namespace tl
diff --git a/testing/python/language/test_tilelang_language_warp_reduce.py b/testing/python/language/test_tilelang_language_warp_reduce.py
new file mode 100644
index 000000000..681b23470
--- /dev/null
+++ b/testing/python/language/test_tilelang_language_warp_reduce.py
@@ -0,0 +1,83 @@
+import torch
+
+import tilelang
+import tilelang.testing
+import tilelang.language as T
+
+
+@tilelang.jit
+def get_kernel(reduce_op: str, dtype: str):
+
+    assert reduce_op in ["sum", "max", "min", "bitand", "bitor"]
+
+    @T.prim_func
+    def main(x: T.Tensor((32,), dtype)):
+        with T.Kernel(1, threads=32):
+            tx = T.get_thread_binding(0)
+            local_val = T.alloc_local([1], dtype)
+            local_val[0] = x[tx]
+            reduced_val = T.alloc_local([1], dtype)
+            if reduce_op == "sum":
+                reduced_val[0] = T.warp_reduce_sum(local_val[0])
+            elif reduce_op == "max":
+                reduced_val[0] = T.warp_reduce_max(local_val[0])
+            elif reduce_op == "min":
+                reduced_val[0] = T.warp_reduce_min(local_val[0])
+            elif reduce_op == "bitand":
+                reduced_val[0] = T.warp_reduce_bitand(local_val[0])
+            elif reduce_op == "bitor":
+                reduced_val[0] = T.warp_reduce_bitor(local_val[0])
+            x[tx] = reduced_val[0]
+
+    return main
+
+
+def test_warp_reduce_sum():
+    a = torch.randn((32,), dtype=torch.float32, device='cuda')
+    kernel = get_kernel('sum', 'float32')
+    ref = torch.full_like(a, a.sum())
+    kernel(a)
+    torch.testing.assert_close(a, ref)
+
+
+def test_warp_reduce_max():
+    a = torch.randn((32,), dtype=torch.float32, device='cuda')
+    kernel = get_kernel("max", 'float32')
+    print(kernel.get_kernel_source())
+    ref = torch.full_like(a, a.max())
+    kernel(a)
+    torch.testing.assert_close(a, ref)
+
+
+def test_warp_reduce_min():
+    a = torch.randn((32,), dtype=torch.float32, device='cuda')
+    kernel = get_kernel("min", 'float32')
+    ref = torch.full_like(a, a.min())
+    kernel(a)
+    torch.testing.assert_close(a, ref)
+
+
+def test_warp_reduce_bitand():
+    a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device='cuda')
+    kernel = get_kernel("bitand", 'int32')
+    ref_val = a[0]
+    for i in range(1, a.shape[0]):
+        ref_val = ref_val & a[i]
+    ref = torch.full_like(a, ref_val)
+    kernel(a)
+    torch.testing.assert_close(a, ref)
+
+
+def test_warp_reduce_bitor():
+    a = torch.randint(0, 100, size=(32,), dtype=torch.int32, device='cuda')
+    kernel = get_kernel("bitor", 'int32')
+    ref_val = a[0]
+    for i in range(1, a.shape[0]):
+        ref_val = ref_val | a[i]
+    ref = torch.full_like(a, ref_val)
+    kernel(a)
+    torch.testing.assert_close(a, ref)
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
diff --git a/tilelang/language/__init__.py b/tilelang/language/__init__.py
index 95488bdfc..75d8d0b4f 100644
--- a/tilelang/language/__init__.py
+++ b/tilelang/language/__init__.py
@@ -65,6 +65,11 @@
     reduce_bitxor,  # noqa: F401
     cumsum,  # noqa: F401
     finalize_reducer,  # noqa: F401
+    warp_reduce_sum,  # noqa: F401
+    warp_reduce_max,  # noqa: F401
+    warp_reduce_min,  # noqa: F401
+    warp_reduce_bitand,  # noqa: F401
+    warp_reduce_bitor,  # noqa: F401
 )
 from .print import print, device_assert  # noqa: F401
 from .customize import (
diff --git a/tilelang/language/reduce.py b/tilelang/language/reduce.py
index 09289559d..23bb6d054 100644
--- a/tilelang/language/reduce.py
+++ b/tilelang/language/reduce.py
@@ -325,3 +325,83 @@ def finalize_reducer(reducer: tir.Buffer):
         tir.op.Op.get("tl.finalize_reducer"),
         reducer.access_ptr("w"),
     )
+
+
+def warp_reduce_sum(value: tir.PrimExpr):
+    """Perform warp reduction sum on a register value.
+
+    This function reduces a value across all threads in a warp using shuffle operations.
+    Each thread provides a register `value`, and after the reduction, all threads
+    will have the sum of all values across the warp.
+
+    Args:
+        value (tir.PrimExpr): The input register value to reduce
+
+    Returns:
+        tir.PrimExpr: The reduced sum value (same on all threads in the warp)
+    """
+    return tir.call_intrin(value.dtype, tir.op.Op.get("tl.warp_reduce_sum"), value)
+
+
+def warp_reduce_max(value: tir.PrimExpr):
+    """Perform warp reduction max on a register value.
+
+    This function reduces a value across all threads in a warp using shuffle operations.
+    Each thread provides a register `value`, and after the reduction, all threads
+    will have the max of all values across the warp.
+
+    Args:
+        value (tir.PrimExpr): The input register value to reduce
+
+    Returns:
+        tir.PrimExpr: The reduced max value (same on all threads in the warp)
+    """
+    return tir.call_intrin(value.dtype, tir.op.Op.get("tl.warp_reduce_max"), value)
+
+
+def warp_reduce_min(value: tir.PrimExpr):
+    """Perform warp reduction min on a register value.
+
+    This function reduces a value across all threads in a warp using shuffle operations.
+    Each thread provides a register `value`, and after the reduction, all threads
+    will have the min of all values across the warp.
+
+    Args:
+        value (tir.PrimExpr): The input register value to reduce
+
+    Returns:
+        tir.PrimExpr: The reduced min value (same on all threads in the warp)
+    """
+    return tir.call_intrin(value.dtype, tir.op.Op.get("tl.warp_reduce_min"), value)
+
+
+def warp_reduce_bitand(value: tir.PrimExpr):
+    """Perform warp reduction bitwise-and on a register value.
+
+    This function reduces a value across all threads in a warp using shuffle operations.
+    Each thread provides a register `value`, and after the reduction, all threads
+    will have the bitwise-and of all values across the warp.
+
+    Args:
+        value (tir.PrimExpr): The input register value to reduce
+
+    Returns:
+        tir.PrimExpr: The reduced bitwise-and value (same on all threads in the warp)
+    """
+    return tir.call_intrin(value.dtype, tir.op.Op.get("tl.warp_reduce_bitand"), value)
+
+
+def warp_reduce_bitor(value: tir.PrimExpr):
+    """Perform warp reduction bitwise-or on a register value.
+
+    This function reduces a value across all threads in a warp using shuffle operations.
+    Each thread provides a register `value`, and after the reduction, all threads
+    will have the bitwise-or of all values across the warp.
+
+    Args:
+        value (tir.PrimExpr): The input register value to reduce
+
+    Returns:
+        tir.PrimExpr: The reduced bitwise-or value (same on all threads in the warp)
+    """
+    return tir.call_intrin(value.dtype, tir.op.Op.get("tl.warp_reduce_bitor"), value)
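
Note: the new tl::warp_reduce helper in reduce.h is the classic xor-shuffle butterfly. Each of the five rounds exchanges values between lanes whose indices differ by the offsets 16, 8, 4, 2, and 1, so after the last round every lane of the 32-thread warp holds the reduction over all 32 inputs, which is exactly what the Python docstrings promise. The standalone Python sketch below is not part of the patch; the simulate_warp_reduce name and the list-of-lanes model are illustrative only, and it just mimics the exchange pattern to show why five rounds suffice.

# Standalone simulation of the xor-shuffle butterfly used by tl::warp_reduce.
# Each "lane" starts with its own value; one round at a given offset corresponds
# to __shfl_xor_sync(0xffffffff, value, offset) followed by the reduction op.
def simulate_warp_reduce(values, op):
    assert len(values) == 32
    lanes = list(values)
    for offset in (16, 8, 4, 2, 1):
        # Every lane combines its value with the value held by lane (i ^ offset),
        # reading the pre-round snapshot, just like a synchronous warp shuffle.
        lanes = [op(lanes[i], lanes[i ^ offset]) for i in range(32)]
    return lanes


if __name__ == "__main__":
    data = list(range(32))
    out = simulate_warp_reduce(data, lambda a, b: a + b)
    # After the five rounds, every lane holds the full-warp sum (496 here).
    assert out == [sum(data)] * 32
    out = simulate_warp_reduce(data, max)
    assert out == [31] * 32

The same pattern explains the tests above: after the kernel runs, every element of the 32-element tensor is overwritten with the same reduced value, so the reference is torch.full_like(a, <reduction of a>).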