diff --git a/src/op/atomic_add.cc b/src/op/atomic_add.cc
index 02d93cad1..b90347a3f 100644
--- a/src/op/atomic_add.cc
+++ b/src/op/atomic_add.cc
@@ -13,7 +13,6 @@

 #include "../layout/layout.h"
 #include "../target/utils.h"
-#include "../transform/atomicadd_vectorize.h"
 #include "../transform/common/loop_fusion_utils.h"
 #include "../transform/loop_partition.h"
 #include "builtin.h"
diff --git a/src/tl_templates/cuda/atomic.h b/src/tl_templates/cuda/atomic.h
index f6096cc9d..46f5813a7 100644
--- a/src/tl_templates/cuda/atomic.h
+++ b/src/tl_templates/cuda/atomic.h
@@ -327,8 +327,8 @@ TL_DEVICE T1 AtomicAddRet(T1 *address, T2 val,
   }
 }

-// TODO add memory_order for vectorized atomic add
-TL_DEVICE void AtomicAddx2(half_t *ref, half_t *val,
+template <typename src_type>
+TL_DEVICE void AtomicAddx2(half_t *ref, src_type *val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
   if (memory_order == int(cuda::memory_order_relaxed)) {
     atomicAdd(reinterpret_cast<half2 *>(ref),
@@ -374,8 +374,9 @@ TL_DEVICE void AtomicAddx2(half_t *ref, half_t *val,
   }
 }

+template <typename src_type>
 TL_DEVICE half2
-AtomicAddx2Ret(half_t *ref, half_t *val,
+AtomicAddx2Ret(half_t *ref, src_type *val,
                int memory_order = int(cuda::memory_order_relaxed)) {
   if (memory_order == int(cuda::memory_order_relaxed)) {
     return atomicAdd(reinterpret_cast<half2 *>(ref),
@@ -419,7 +420,8 @@ AtomicAddx2Ret(half_t *ref, half_t *val,
 }

 #if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ > 750))
-TL_DEVICE void AtomicAddx2(bfloat16_t *ref, bfloat16_t *val,
+template <typename src_type>
+TL_DEVICE void AtomicAddx2(bfloat16_t *ref, src_type *val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
   if (memory_order == int(cuda::memory_order_relaxed)) {
     atomicAdd(
@@ -458,8 +460,9 @@ TL_DEVICE void AtomicAddx2(bfloat16_t *ref, bfloat16_t *val,
   }
 }

+template <typename src_type>
 TL_DEVICE __nv_bfloat162
-AtomicAddx2Ret(bfloat16_t *ref, bfloat16_t *val,
+AtomicAddx2Ret(bfloat16_t *ref, src_type *val,
                int memory_order = int(cuda::memory_order_relaxed)) {
   if (memory_order == int(cuda::memory_order_relaxed)) {
     return atomicAdd(
@@ -502,13 +505,19 @@ AtomicAddx2Ret(bfloat16_t *ref, bfloat16_t *val,
 #endif

 #if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
-TL_DEVICE void AtomicAddx2(float *ref, float *val,
+template <typename T> TL_DEVICE float2 ToFloat2(T *val) {
+  return *reinterpret_cast<float2 *>(val);
+}
+
+TL_DEVICE float2 ToFloat2(float2 val) { return val; }
+
+template <typename ValType>
+TL_DEVICE void AtomicAddx2(float *ref, ValType val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
+  float2 add_val = ToFloat2(val);
   if (memory_order == int(cuda::memory_order_relaxed)) {
-    atomicAdd(reinterpret_cast<float2 *>(ref),
-              static_cast<float2>(*reinterpret_cast<float2 *>(val)));
+    atomicAdd(reinterpret_cast<float2 *>(ref), add_val);
   } else {
-    float2 add_val = *reinterpret_cast<float2 *>(val);
     unsigned long long ref_addr = reinterpret_cast<unsigned long long>(ref);
     float2 ret_val;
     if (memory_order == int(cuda::memory_order_release) ||
@@ -532,14 +541,14 @@ TL_DEVICE void AtomicAddx2(float *ref, float *val,
   }
 }

+template <typename ValType>
 TL_DEVICE float2
-AtomicAddx2Ret(float *ref, float *val,
+AtomicAddx2Ret(float *ref, ValType val,
                int memory_order = int(cuda::memory_order_relaxed)) {
+  float2 add_val = ToFloat2(val);
   if (memory_order == int(cuda::memory_order_relaxed)) {
-    return atomicAdd(reinterpret_cast<float2 *>(ref),
-                     static_cast<float2>(*reinterpret_cast<float2 *>(val)));
+    return atomicAdd(reinterpret_cast<float2 *>(ref), add_val);
   } else {
-    float2 add_val = *reinterpret_cast<float2 *>(val);
     unsigned long long ref_addr = reinterpret_cast<unsigned long long>(ref);
     float2 ret_val;
     if (memory_order == int(cuda::memory_order_release) ||
@@ -564,16 +573,22 @@ AtomicAddx2Ret(float *ref, float *val,
   }
 }

-TL_DEVICE void AtomicAddx4(float *ref, float *val,
+template <typename T> TL_DEVICE float4 ToFloat4(T *val) {
+  return *reinterpret_cast<float4 *>(val);
+}
+
+TL_DEVICE float4 ToFloat4(float4 val) { return val; }
+
+template <typename dst_dtype, typename ValType>
+TL_DEVICE void AtomicAddx4(dst_dtype *ref, ValType val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
+  float4 add_val = ToFloat4(val);
   if (memory_order == int(cuda::memory_order_relaxed)) {
-    atomicAdd(reinterpret_cast<float4 *>(ref),
-              static_cast<float4>(*reinterpret_cast<float4 *>(val)));
+    atomicAdd(reinterpret_cast<float4 *>(ref), add_val);
   } else {
     // Since atomicAdd does not support memory order, atomic_ref does not
     // support vectorized atomic operation we can only inline ptx code here
     // Note: Vectorized atomic operations only support global space
-    float4 add_val = *reinterpret_cast<float4 *>(val);
     unsigned long long ref_addr = reinterpret_cast<unsigned long long>(ref);
     float4 ret_val;
     if (memory_order == int(cuda::memory_order_release) ||
@@ -606,14 +621,14 @@ TL_DEVICE void AtomicAddx4(float *ref, float *val,
   }
 }

+template <typename dst_dtype, typename ValType>
 TL_DEVICE float4
-AtomicAddx4Ret(float *ref, float *val,
+AtomicAddx4Ret(dst_dtype *ref, ValType val,
                int memory_order = int(cuda::memory_order_relaxed)) {
+  float4 add_val = ToFloat4(val);
   if (memory_order == int(cuda::memory_order_relaxed)) {
-    return atomicAdd(reinterpret_cast<float4 *>(ref),
-                     static_cast<float4>(*reinterpret_cast<float4 *>(val)));
+    return atomicAdd(reinterpret_cast<float4 *>(ref), add_val);
   } else {
-    float4 add_val = *reinterpret_cast<float4 *>(val);
     unsigned long long ref_addr = reinterpret_cast<unsigned long long>(ref);
     float4 ret_val;
     if (memory_order == int(cuda::memory_order_release) ||
@@ -647,40 +662,56 @@ AtomicAddx4Ret(float *ref, float *val,
   }
 }
 #else
-TL_DEVICE void AtomicAddx2(float *ref, float *val,
+template <typename T> TL_DEVICE float2 ToFloat2(T *val) {
+  return *reinterpret_cast<float2 *>(val);
+}
+
+TL_DEVICE float2 ToFloat2(float2 val) { return val; }
+
+template <typename T> TL_DEVICE float4 ToFloat4(T *val) {
+  return *reinterpret_cast<float4 *>(val);
+}
+
+TL_DEVICE float4 ToFloat4(float4 val) { return val; }
+
+template <typename ValType>
+TL_DEVICE void AtomicAddx2(float *ref, ValType val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
   (void)memory_order;
-  float2 add_val = *reinterpret_cast<float2 *>(val);
+  float2 add_val = ToFloat2(val);
   atomicAdd(ref + 0, add_val.x);
   atomicAdd(ref + 1, add_val.y);
 }

+template <typename ValType>
 TL_DEVICE float2
-AtomicAddx2Ret(float *ref, float *val,
+AtomicAddx2Ret(float *ref, ValType val,
                int memory_order = int(cuda::memory_order_relaxed)) {
   (void)memory_order;
-  float2 add_val = *reinterpret_cast<float2 *>(val);
+  float2 add_val = ToFloat2(val);
   float2 ret;
   ret.x = atomicAdd(ref + 0, add_val.x);
   ret.y = atomicAdd(ref + 1, add_val.y);
   return ret;
 }

-TL_DEVICE void AtomicAddx4(float *ref, float *val,
+template <typename dst_dtype, typename ValType>
+TL_DEVICE void AtomicAddx4(dst_dtype *ref, ValType val,
                           int memory_order = int(cuda::memory_order_relaxed)) {
   (void)memory_order;
-  float4 add_val = *reinterpret_cast<float4 *>(val);
+  float4 add_val = ToFloat4(val);
   atomicAdd(ref + 0, add_val.x);
   atomicAdd(ref + 1, add_val.y);
   atomicAdd(ref + 2, add_val.z);
   atomicAdd(ref + 3, add_val.w);
 }

+template <typename dst_dtype, typename ValType>
 TL_DEVICE float4
-AtomicAddx4Ret(float *ref, float *val,
+AtomicAddx4Ret(dst_dtype *ref, ValType val,
                int memory_order = int(cuda::memory_order_relaxed)) {
   (void)memory_order;
-  float4 add_val = *reinterpret_cast<float4 *>(val);
+  float4 add_val = ToFloat4(val);
   float4 ret;
   ret.x = atomicAdd(ref + 0, add_val.x);
   ret.y = atomicAdd(ref + 1, add_val.y);
diff --git a/src/transform/atomicadd_vectorize.cc b/src/transform/atomicadd_vectorize.cc
deleted file mode 100644
index 8b4826e2f..000000000
--- a/src/transform/atomicadd_vectorize.cc
+++ /dev/null
@@ -1,162 +0,0 @@
-/*!
- * \file atomicadd_vectorize.cc
- * \brief Automatic vectorization pass for atomic add operations.
- *
- * This pass detects atomic_add_elem_op inside vectorized loops and converts
- * them to vectorized versions (atomic_addx2_elem_op or atomic_addx4_elem_op).
- */
-
-#include "atomicadd_vectorize.h"
-
-namespace tvm {
-namespace tl {
-
-using namespace tir;
-
-namespace {
-
-/*!
- * \brief Extract BufferLoad from an expression that may be wrapped in
- * address_of.
- */
-Optional<BufferLoad> ExtractBufferLoad(const PrimExpr &expr) {
-  if (const auto *load = expr.as<BufferLoadNode>()) {
-    return tvm::ffi::GetRef<BufferLoad>(load);
-  }
-  if (const auto *call = expr.as<CallNode>()) {
-    if (call->op.same_as(builtin::address_of()) && !call->args.empty()) {
-      if (const auto *load = call->args[0].as<BufferLoadNode>()) {
-        return tvm::ffi::GetRef<BufferLoad>(load);
-      }
-    }
-  }
-  return Optional<BufferLoad>();
-}
-
-/*!
- * \brief Get the vectorized atomic add op based on vector size.
- */
-Op GetVectorizedAtomicOp(int vector_size) {
-  switch (vector_size) {
-  case 4:
-    return atomic_addx4_elem_op();
-  case 2:
-    return atomic_addx2_elem_op();
-  default:
-    return atomic_add_elem_op();
-  }
-}
-
-/*!
- * \brief Rewriter that transforms atomic_add_elem_op inside vectorized loops.
- *
- * Strategy: Detect ForKind::kVectorized loops, use their extent as vector size,
- * and convert atomic_add_elem_op to the corresponding vectorized version.
- */
-class AtomicAddVectorizeRewriter : public StmtExprMutator {
-public:
-  explicit AtomicAddVectorizeRewriter(Target target) : target_(target) {}
-
-private:
-  /*!
-   * \brief Get the max vector size supported by the given dtype.
-   */
-  int GetMaxVectorSize(DataType dtype) const {
-    if (dtype.is_float16() || dtype.is_bfloat16()) {
-      return 2;
-    }
-    if (dtype.is_float() && dtype.bits() == 32 &&
-        TargetHasSMVersionGE(target_, 90)) {
-      return 4;
-    }
-    return 1;
-  }
-
-  Stmt VisitStmt_(const ForNode *node) final {
-    // Check if this is a vectorized loop
-    if (node->kind == ForKind::kVectorized) {
-      auto extent_ptr = as_const_int(node->extent);
-      if (!extent_ptr) {
-        return StmtExprMutator::VisitStmt_(node);
-      }
-
-      int vec_size = static_cast<int>(*extent_ptr);
-      // Push vectorized context
-      vectorized_loop_ = node;
-      vector_size_ = vec_size;
-      Stmt body = VisitStmt(node->body);
-      // If we successfully vectorized atomic ops, transform the loop
-      if (has_vectorized_atomic_) {
-        has_vectorized_atomic_ = false;
-        vectorized_loop_ = nullptr;
-        vector_size_ = 1;
-        // Change loop extent to 1 since atomic op now handles all elements
-        return For(node->loop_var, node->min, Integer(1), node->kind, body,
-                   node->thread_binding, node->annotations, node->step,
-                   node->span);
-      }
-
-      vectorized_loop_ = nullptr;
-      vector_size_ = 1;
-
-      if (body.same_as(node->body)) {
-        return tvm::ffi::GetRef<For>(node);
-      }
-      return For(node->loop_var, node->min, node->extent, node->kind, body,
-                 node->thread_binding, node->annotations, node->step,
-                 node->span);
-    }
-    return StmtExprMutator::VisitStmt_(node);
-  }
-
-  PrimExpr VisitExpr_(const CallNode *node) final {
-    if (node->op != atomic_add_elem_op() || node->args.size() < 2) {
-      return StmtExprMutator::VisitExpr_(node);
-    }
-
-    // Must be inside a vectorized loop
-    if (!vectorized_loop_ || vector_size_ <= 1) {
-      return StmtExprMutator::VisitExpr_(node);
-    }
-
-    auto dst_load = ExtractBufferLoad(node->args[0]);
-    auto src_load = ExtractBufferLoad(node->args[1]);
-
-    if (!dst_load.defined() || !src_load.defined()) {
-      return StmtExprMutator::VisitExpr_(node);
-    }
-
-    // Check if dtype supports this vector size
-    DataType dtype = dst_load.value()->buffer->dtype;
-    if (vector_size_ > GetMaxVectorSize(dtype)) {
-      return StmtExprMutator::VisitExpr_(node);
-    }
-
-    // Mark that we have vectorized an atomic op
-    has_vectorized_atomic_ = true;
-
-    // Create vectorized atomic op
-    Call addr_dst(DataType::Handle(), builtin::address_of(),
-                  {dst_load.value()});
-    Call addr_src(DataType::Handle(), builtin::address_of(),
-                  {src_load.value()});
-
-    return Call(node->dtype, GetVectorizedAtomicOp(vector_size_),
-                {addr_dst, addr_src});
-  }
-
-  Target target_;
-  const ForNode *vectorized_loop_ = nullptr;
-  int vector_size_ = 1;
-  bool has_vectorized_atomic_ = false;
-};
-
-} // namespace
-
-For VectorizeAtomicAdd(const For &for_node) {
-  Target target = Target::Current(false);
-  return Downcast<For>(AtomicAddVectorizeRewriter(target)(for_node));
-}
-
-} // namespace tl
-} // namespace tvm
diff --git a/src/transform/atomicadd_vectorize.h b/src/transform/atomicadd_vectorize.h
deleted file mode 100644
index 470814a92..000000000
--- a/src/transform/atomicadd_vectorize.h
+++ /dev/null
@@ -1,34 +0,0 @@
-/*!
- * \file atomicadd_vectorize.h
- * \brief Vectorization pass for atomic add operations.
- */
-
-#ifndef TVM_TL_ATOMICADD_VECTORIZE_H_
-#define TVM_TL_ATOMICADD_VECTORIZE_H_
-
-#include "../op/builtin.h"
-#include "../target/utils.h"
-#include
-#include
-
-namespace tvm {
-namespace tl {
-
-using namespace tir;
-
-/*!
- * \brief Vectorize atomic add operations inside vectorized loops.
- *
- * This function detects atomic_add_elem_op inside ForKind::kVectorized loops
- * and converts them to vectorized versions (atomic_addx2_elem_op or
- * atomic_addx4_elem_op) based on the loop extent and data type.
- *
- * \param for_node The For loop to process.
- * \return The transformed For loop.
- */
-For VectorizeAtomicAdd(const For &for_node);
-
-} // namespace tl
-} // namespace tvm
-
-#endif // TVM_TL_ATOMICADD_VECTORIZE_H_
diff --git a/src/transform/loop_partition.cc b/src/transform/loop_partition.cc
index b7b15e4b5..35166e4b4 100644
--- a/src/transform/loop_partition.cc
+++ b/src/transform/loop_partition.cc
@@ -29,7 +29,6 @@
 #include

 #include "../op/utils.h"
-#include "atomicadd_vectorize.h"
 #include "loop_vectorize.h"

 namespace tvm {
@@ -297,11 +296,7 @@ Stmt LowerParallelLoop(For loop, const Fragment &loop_layout, Var thread_var,
   if (should_vectorize) {
     result_loop = VectorizeLoop(result_loop, saved_analyzer.get(), layout_map);
   }
-
-  // Step 3: Vectorize atomic add operations
-  result_loop = VectorizeAtomicAdd(result_loop);
-
-  // Step 4: Wrap with predicate if provided and this is a parallel loop
+  // Step 3: Wrap with predicate if provided and this is a parallel loop
   if (predicate.defined() && parallel_loop) {
     return IfThenElse(predicate.value(), result_loop);
   }
diff --git a/src/transform/vectorize_loop.cc b/src/transform/vectorize_loop.cc
index 7eb5da87a..e494fc3b1 100644
--- a/src/transform/vectorize_loop.cc
+++ b/src/transform/vectorize_loop.cc
@@ -37,8 +37,11 @@
 #include
 #include

+#include "../op/builtin.h"
+#include "../target/utils.h"
 #include "arith/scalable_expression.h"
 #include "tir/analysis/check_contains.h"
+#include "tvm/ffi/cast.h"

 namespace tvm {
 namespace tl {
@@ -118,6 +121,52 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) {
   return Broadcast(e, CreateNewLanes(is_scalable, lanes));
 }

+/*!
+ * \brief Extract BufferLoad from an expression that may be wrapped in
+ * address_of.
+ */
+inline Optional<BufferLoad> ExtractBufferLoadForAtomic(const PrimExpr &expr) {
+  if (const auto *load = expr.as<BufferLoadNode>()) {
+    return tvm::ffi::GetRef<BufferLoad>(load);
+  }
+  if (const auto *call = expr.as<CallNode>()) {
+    if (call->op.same_as(builtin::address_of()) && !call->args.empty()) {
+      if (const auto *load = call->args[0].as<BufferLoadNode>()) {
+        return tvm::ffi::GetRef<BufferLoad>(load);
+      }
+    }
+  }
+  return Optional<BufferLoad>();
+}
+
+/*!
+ * \brief Get the vectorized atomic add op based on vector size.
+ */
+inline Op GetVectorizedAtomicOp(int vector_size) {
+  switch (vector_size) {
+  case 4:
+    return atomic_addx4_elem_op();
+  case 2:
+    return atomic_addx2_elem_op();
+  default:
+    return atomic_add_elem_op();
+  }
+}
+
+/*!
+ * \brief Get the max vector size supported by the given dtype for atomic ops.
+ */
+inline int GetMaxAtomicVectorSize(DataType dtype, Target target) {
+  if (dtype.is_float16() || dtype.is_bfloat16()) {
+    return 2;
+  }
+  if (dtype.is_float() && dtype.bits() == 32 &&
+      TargetHasSMVersionGE(target, 90)) {
+    return 4;
+  }
+  return 1;
+}
+
 // Rewrite vectorized allocation access
 // This is necessary for making each vector component containing its own
 // workspace. Originates from Halide's loop vectorizer
@@ -449,6 +498,32 @@ class TLVectorizer : public StmtMutator,
       }
     }
   }
+  // Address of: remove vectorized var from indices to get base address
+  // e.g., T.address_of(buf[base + vec]) -> T.address_of(buf[base])
+  PrimExpr MutateAddressOfCall_(const CallNode *op) {
+    ICHECK(op->op.same_as(builtin::address_of()));
+    ICHECK_EQ(op->args.size(), 1);
+
+    auto buffer_load = op->args[0].as<BufferLoadNode>();
+    if (!buffer_load) {
+      return tvm::ffi::GetRef<PrimExpr>(op);
+    }
+
+    // Remove the vectorized var from indices by substituting var_ with 0
+    Array<PrimExpr> new_indices;
+    for (const auto &index : buffer_load->indices) {
+      PrimExpr new_index = Substitute(index, {{var_, IntImm(var_->dtype, 0)}});
+      new_indices.push_back(analyzer_.Simplify(new_index));
+    }
+
+    BufferLoad new_load = GetRef<BufferLoad>(buffer_load);
+    if (!new_indices.same_as(buffer_load->indices)) {
+      auto writer = new_load.CopyOnWrite();
+      writer->indices = new_indices;
+    }
+
+    return Call(op->dtype, op->op, {new_load});
+  }
   // Reinterpret expr
   PrimExpr MutateReinterpretExpr_(const CallNode *op) {
     ICHECK(op->op.same_as(builtin::reinterpret()));
@@ -465,6 +540,37 @@ class TLVectorizer : public StmtMutator,
       }
     }
   }
+  // Atomic add vectorization
+  PrimExpr MutateAtomicAddExpr_(const CallNode *op) {
+    ICHECK(op->op.same_as(atomic_add_elem_op()));
+
+    // Must have at least 2 args (dst_ptr and src)
+    if (op->args.size() < 2) {
+      return tvm::ffi::GetRef<PrimExpr>(op);
+    }
+
+    // Get the vector size from var_lanes_
+    auto lanes_ptr = as_const_int(var_lanes_);
+    if (!lanes_ptr || *lanes_ptr <= 1) {
+      // Not in vectorized context or vector size is 1
+      return tvm::ffi::GetRef<PrimExpr>(op);
+    }
+    int vector_size = static_cast<int>(*lanes_ptr);
+    auto dst = VisitExpr(op->args[0]);
+    auto src = VisitExpr(op->args[1]);
+    // Check if dtype supports this vector size
+    auto dst_buffer_load = ExtractBufferLoadForAtomic(dst);
+    Target target = Target::Current(false);
+    int max_vec_size =
+        GetMaxAtomicVectorSize(dst_buffer_load.value()->buffer->dtype, target);
+    if (vector_size > max_vec_size) {
+      // Vector size not supported for this dtype, cannot vectorize
+      return tvm::ffi::GetRef<PrimExpr>(op);
+    }
+
+    // Return the vectorized atomic op
+    return Call(op->dtype, GetVectorizedAtomicOp(vector_size), {dst, src});
+  }
   // Call
   PrimExpr VisitExpr_(const CallNode *op) final {
     if (op->op.same_as(builtin::if_then_else())) {
@@ -486,6 +592,11 @@ class TLVectorizer : public StmtMutator,
       return Call(op->dtype.with_lanes(lane), op->op, new_args);
     } else if (op->op.same_as(builtin::reinterpret())) {
       return MutateReinterpretExpr_(op);
+    } else if (op->op.same_as(atomic_add_elem_op())) {
+      // Handle vectorization of atomic_add_elem_op
+      return MutateAtomicAddExpr_(op);
+    } else if (op->op.same_as(builtin::address_of())) {
+      return MutateAddressOfCall_(op);
     }
     auto optional_op = op->op.as<Op>();
     bool vectorizable = optional_op &&
diff --git a/testing/python/language/test_tilelang_language_atomic.py b/testing/python/language/test_tilelang_language_atomic.py
index a96d0ac6f..44ee97e49 100644
--- a/testing/python/language/test_tilelang_language_atomic.py
+++ b/testing/python/language/test_tilelang_language_atomic.py
@@ -476,5 +476,43 @@ def test_tma_atomic_add():
     assert kernel.get_kernel_source() == kernel_with_explicit_swizzle.get_kernel_source()


+def run_atomic_add_auto_vectorized(K, M, N, block_M, block_N, dtype=T.float32):
+    kernel = atomic_add_program(K, M, N, block_M, block_N, dtype=dtype)
+    assert "AtomicAddx4" in kernel.get_kernel_source()
+
+
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
+def test_atomic_add_auto_vectorized():
+    run_atomic_add_auto_vectorized(8, 128, 128, 32, 32)
+
+
+@tilelang.jit
+def atomic_add_complicated_parallel_program(K, M, N, block_M, block_N, dtype=T.float32):
+    @T.prim_func
+    def atomic_add(A: T.Tensor((K, M, N), dtype), B: T.Tensor((M, N), dtype)):
+        with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), K, threads=32) as (bx, by, bz):
+            A_shared = T.alloc_shared((block_M, block_N), dtype)
+
+            T.copy(A[bz, bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N], A_shared)
+
+            for i, j in T.Parallel(block_M, block_N):
+                value = A_shared[i, j]
+                T.atomic_add(B[bx * block_M + i, by * block_N + j], value)
+
+    return atomic_add
+
+
+def run_atomic_add_complicated_parallel(K, M, N, block_M, block_N, dtype=T.float32):
+    kernel = atomic_add_complicated_parallel_program(K, M, N, block_M, block_N, dtype=dtype)
+    assert "float4 value" in kernel.get_kernel_source()
+    assert "AtomicAddx4" in kernel.get_kernel_source()
+
+
+@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
+def test_atomic_add_complicated_parallel():
+    run_atomic_add_complicated_parallel(8, 128, 128, 32, 32)
+
+
 if __name__ == "__main__":
     tilelang.testing.main()