diff --git a/examples/autodd/tilelang_buggy.py b/examples/autodd/tilelang_buggy.py
index 47e71fe50..d2c5469bb 100644
--- a/examples/autodd/tilelang_buggy.py
+++ b/examples/autodd/tilelang_buggy.py
@@ -73,14 +73,10 @@ def get_grid_size(self):
         return grid_x, grid_y
 
     def get_shared_memory_size(self):
-        return get_memory_requirements(
-            self.M, self.N, self.K, self.block_M, self.block_N, self.block_K
-        )
+        return get_memory_requirements(self.M, self.N, self.K, self.block_M, self.block_N, self.block_K)
 
     def validate(self):
-        return validate_parameters(
-            self.M, self.N, self.K, self.block_M, self.block_N, self.block_K
-        )
+        return validate_parameters(self.M, self.N, self.K, self.block_M, self.block_N, self.block_K)
 
 
 def create_reference_output(a, b, activation="relu"):
@@ -107,6 +103,7 @@ def benchmark_pytorch(M, N, K, num_iters=10, warmup=5):
 
     # Benchmark
     import time
+
     start = time.time()
     for _ in range(num_iters):
         _ = a @ b
diff --git a/examples/autodd/tilelang_minimized_expected.py b/examples/autodd/tilelang_minimized_expected.py
index 2135f6fce..3dc88f992 100644
--- a/examples/autodd/tilelang_minimized_expected.py
+++ b/examples/autodd/tilelang_minimized_expected.py
@@ -13,7 +13,6 @@
 
 
 class MatmulConfig:
-
     def __init__(self, *args, **kwargs):
         self.M = 1
         self.N = 1
@@ -24,7 +23,6 @@ def __init__(self, *args, **kwargs):
 
 
 def buggy_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32, *args, **kwargs):
-
     @T.prim_func
     def matmul_kernel():
         with T.Kernel():
@@ -45,7 +43,7 @@ def main(*args, **kwargs):
     try:
         run_kernel(config)
     except Exception as e:
-        print(f'{e}')
+        print(f"{e}")
 
 
 main()
diff --git a/examples/flash_attention/example_gqa_bwd_tma_reduce.py b/examples/flash_attention/example_gqa_bwd_tma_reduce.py
index 4a5290c28..058b6c2f4 100644
--- a/examples/flash_attention/example_gqa_bwd_tma_reduce.py
+++ b/examples/flash_attention/example_gqa_bwd_tma_reduce.py
@@ -209,14 +209,6 @@ def flash_bwd(
         dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype)
         dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype)
 
-        T.annotate_layout(
-            {
-                dQ: make_dq_layout(dQ),
-                dK: make_dq_layout(dK),
-                dV: make_dq_layout(dV),
-            }
-        )
-
         T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared)
         T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared)
         T.clear(dv)
@@ -387,7 +379,6 @@ def maybe_contiguous(x):
         block_M = 128
         block_N = 32
         mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V)
-        mod_post = flashattn_bwd_postprocess(BATCH, H, HEAD_KV, N_CTX, D_HEAD_QK, D_HEAD_V)
         delta = mod_prep(o, do)
 
         if ctx.use_atomic:
@@ -401,11 +392,11 @@ def maybe_contiguous(x):
             dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device)
             dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device)
             kernel(q, k, v, do, lse, delta, dq, dk, dv)
-            dq, dk, dv = mod_post(dq, dk, dv)
         else:
             kernel = flashattn_bwd_split_novarlen(
                 BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups
             )
+            mod_post = flashattn_bwd_postprocess(BATCH, H, HEAD_KV, N_CTX, D_HEAD_QK, D_HEAD_V)
             shape_q = [BATCH, N_CTX, H, D_HEAD_QK]
             shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK]  # sum after kernel
             shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V]  # sum after kernel
diff --git a/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py
index 1bc8fd1eb..a8d0d153e 100644
--- a/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py
+++ b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py
@@ -286,14 +286,6 @@ def flash_bwd(
         q_current_seqlen = q_end_idx - q_start_idx
         k_current_seqlen = k_end_idx - k_start_idx
 
-        T.annotate_layout(
-            {
-                dQ: make_dq_layout(dQ),
-                dK: make_dq_layout(dK),
-                dV: make_dq_layout(dV),
-            }
-        )
-
         T.copy(K[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], K_shared)
         T.copy(V[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], V_shared)
 
@@ -541,7 +533,6 @@ def maybe_contiguous(x):
         block_M = 128
         block_N = 32
         mod_prep = flashattn_bwd_preprocess(BATCH, H, total_q, N_CTX, ctx.max_seqlen_q, D_HEAD_V)
-        mod_post = flashattn_bwd_postprocess(total_q, total_kv, H, HEAD_KV, D_HEAD_QK, D_HEAD_V)
         delta = mod_prep(o, do, cu_seqlens_q)
 
         if ctx.use_atomic:
@@ -565,7 +556,6 @@ def maybe_contiguous(x):
             dk = torch.zeros_like(k, dtype=torch.float32)
             dv = torch.zeros_like(v, dtype=torch.float32)
             kernel(q, k, v, do, lse_clone, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv)
-            dq, dk, dv = mod_post(dq, dk, dv)
         else:
             kernel = flashattn_bwd_split(
                 BATCH,
@@ -583,6 +573,7 @@ def maybe_contiguous(x):
                 num_stages=2,
                 groups=groups,
             )
+            mod_post = flashattn_bwd_postprocess(total_q, total_kv, H, HEAD_KV, D_HEAD_QK, D_HEAD_V)
             dq = torch.zeros_like(q, dtype=torch.float32)
             dk = torch.empty(groups, *k.shape, dtype=torch.float16, device=q.device)
             dv = torch.empty(groups, *v.shape, dtype=torch.float16, device=q.device)
diff --git a/src/op/atomic_add.cc b/src/op/atomic_add.cc
index c65641374..538f59fa9 100644
--- a/src/op/atomic_add.cc
+++ b/src/op/atomic_add.cc
@@ -5,6 +5,7 @@
  */
 
 #include "./atomic_add.h"
+#include "./copy.h"
 #include "utils.h"
 #include
 #include
@@ -303,39 +304,104 @@ For AtomicAddNode::MakeSIMTLoop(arith::Analyzer *analyzer) const {
   return Downcast<For>(body);
 }
 
+/**
+ * @brief Compute linear layout for shared tensor (used in TMA atomic add).
+ *
+ * Creates a tiled layout that splits each dimension into blocks of 256
+ * elements. The layout maps [i, j, ...] to [i // 256, j // 256, ..., i % 256,
+ * j % 256, ...].
+ *
+ * @param shared_tensor The shared memory buffer to compute layout for.
+ * @return Layout A tiled linear layout for the buffer.
+ */
+Layout AtomicAddNode::ComputeLinearLayout(const Buffer &shared_tensor) const {
+  Array<PrimExpr> input_size = shared_tensor->shape;
+  Array<PrimExpr> forward_vars;
+  for (size_t i = 0; i < input_size.size(); i++) {
+    forward_vars.push_back(InputPlaceholder(i));
+  }
+  // [i, j] -> [i // 256, j // 256, i % 256, j % 256]
+  Array<PrimExpr> forward_index;
+  for (size_t i = 0; i < input_size.size(); i++) {
+    forward_index.push_back(FloorDiv(forward_vars[i], 256));
+  }
+  for (size_t i = 0; i < input_size.size(); i++) {
+    forward_index.push_back(FloorMod(forward_vars[i], 256));
+  }
+  return Layout(input_size, forward_index);
+}
+
 /**
  * @brief Infer and return the layout map for the atomic add operator.
  *
- * Constructs a cached ParallelOp (by building the SIMT loop) if not already
- * present, validates that local.fragment layouts for src and dst match when
- * both are provided, and then delegates layout inference to the underlying
- * ParallelOp.
+ * For TMA atomic add operations (when use_tma=True):
+ *   - src is always shared memory, dst is always global memory
+ *   - Automatically applies a swizzle layout to the shared memory buffer when
+ *     the operation is not 1D, improving memory access efficiency
+ *
+ * For non-TMA atomic add operations:
+ *   - Returns an empty layout map (no layout inference needed)
  *
  * @param T Layout inference inputs, including an optional mapping of buffers to
  * layouts.
  * @param level Inference strictness level.
  * @return LayoutMap The inferred layout mapping for buffers used by this
  * operator.
- *
- * @note This method mutates the AtomicAddNode by creating and storing a
- * ParallelOp on first invocation.
- * @throws If both src and dst have layouts in `local.fragment` and their
- * fragment layouts differ, an ICHECK failure is raised with diagnostic output.
  */
 LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
                                      InferLevel level) const {
-  if (T.layout_map.count(src) && T.layout_map.count(dst)) {
-    if (IsFragmentBuffer(src) && IsFragmentBuffer(dst)) {
-      const FragmentNode *src_layout = T.layout_map[src].as<FragmentNode>();
-      const FragmentNode *dst_layout = T.layout_map[dst].as<FragmentNode>();
-      if (src_layout && dst_layout) {
-        ICHECK(src_layout->IsEqual(dst_layout, true))
-            << "Get different layout for " << src << " and " << dst
-            << "\nLHS = " << src_layout->DebugOutput()
-            << "\nRHS = " << dst_layout->DebugOutput()
-            << "\nYou may need to use a shared memory to transform the layout";
+  // Handle TMA atomic add layout inference
+  if (GetUseTMA()) {
+    Map<Buffer, Layout> result_map;
+
+    // For TMA atomic add: src is shared memory, dst is global memory
+    Buffer shared_tensor = src;
+    Array<Range> shared_range = src_range;
+
+    // Check if this is 1D TMA
+    bool is_tma_1d = shared_range.size() == 1;
+
+    if (is_tma_1d) {
+      // A 1D TMA atomic add with a single dimension cannot be swizzled
+      return result_map;
+    }
+
+    // For non-1D TMA atomic add, apply a swizzle layout if possible
+    if (level == InferLevel::kFree && !T.layout_map.count(shared_tensor)) {
+      // TMA atomic add is similar to TMA Store - we should perform swizzle if
+      // possible. Use the last two dimensions to analyze swizzling.
+      int dim = shared_tensor->shape.size();
+      const int64_t mat_stride = *as_const_int(shared_tensor->shape[dim - 2]);
+      const int64_t mat_continuous =
+          *as_const_int(shared_tensor->shape[dim - 1]);
+      Layout swizzle_layout =
+          makeGemmABLayoutHopper(mat_stride, mat_continuous, mat_continuous,
                                  shared_tensor->dtype.bits(), /*k_inner=*/true);
+      // If makeGemmABLayoutHopper returns a linear layout, fall back to
+      // ComputeLinearLayout, which handles arbitrary tensor shapes correctly.
+      if (StructuralEqual()(swizzle_layout, makeLinearLayout(Array<PrimExpr>{
+                                Integer(mat_stride),
+                                Integer(mat_continuous)}))) {
+        result_map.Set(shared_tensor, ComputeLinearLayout(shared_tensor));
+      } else {
+        result_map.Set(shared_tensor, swizzle_layout);
       }
     }
+
+    return result_map;
+  }
+
+  // For non-TMA atomic add, check that src and dst have the same layout if
+  // both are fragments
+  if (IsFragmentBuffer(src) && IsFragmentBuffer(dst)) {
+    if (T.layout_map.count(src) && T.layout_map.count(dst)) {
+      Layout src_layout = T.layout_map.at(src);
+      Layout dst_layout = T.layout_map.at(dst);
+      ICHECK(StructuralEqual()(src_layout, dst_layout))
+          << "AtomicAdd requires src and dst to have the same layout, but got "
+          << "src layout: " << src_layout << ", dst layout: " << dst_layout
+          << " for src buffer: " << src->name << ", dst buffer: " << dst->name;
+    }
   }
   return {};
 }
@@ -378,30 +444,217 @@ LayoutMap AtomicAddNode::InferLayout(const LayoutInferArgs &T,
 Stmt AtomicAddNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   Target target = T.target;
   if (GetUseTMA()) {
-    Array<PrimExpr> src_indices, dst_indices;
-    PrimExpr src_size, dst_size;
-    std::tie(src_indices, src_size) = ReturnIndicesAndSize(0);
-    std::tie(dst_indices, dst_size) = ReturnIndicesAndSize(1);
-    ICHECK(analyzer->CanProveEqual(src_size, dst_size))
-        << "src_size = " << src_size << ", dst_size = " << dst_size;
-    BufferLoad src_node = BufferLoad(src, src_indices);
-    BufferLoad dst_node = BufferLoad(dst, dst_indices);
-    Call address_of_src =
-        Call(DataType::Handle(), builtin::address_of(), {src_node});
-    Call address_of_dst =
-        Call(DataType::Handle(), builtin::address_of(), {dst_node});
-
-    int need_reduce = 1;
-    int eviction_policy = 0;
+    // For AtomicAdd with TMA: src is shared memory, dst is global memory
+    // Use cp.reduce.async.bulk.tensor instruction with tensor descriptor
+    Buffer shared_tensor = src;
+    Buffer global_tensor = dst;
+    Array<Range> shared_range = src_range;
+    Array<Range> global_range = dst_range;
+
+    // Build TMADesc for the global tensor
+    TMADesc desc;
+    desc.rank = global_tensor->shape.size();
+    ICHECK(desc.rank >= 1 && desc.rank <= 5)
+        << "TMA reduce only supports 1-5 dimensions, got " << desc.rank;
+
+    // Data type must match
+    ICHECK(global_tensor->dtype == shared_tensor->dtype)
+        << "AtomicAdd between buffer " << shared_tensor->name << " and "
+        << global_tensor->name << " with different data type "
+        << shared_tensor->dtype << " and " << global_tensor->dtype;
+
+    desc.data_type = to_CUtensorMapDataType(global_tensor->dtype);
+
+    // Global tensor shape and stride
+    desc.global_addr = global_tensor->data;
+    desc.global_shape = ReverseArray(global_tensor->shape);
+    Array<PrimExpr> global_coords =
+        ReverseArray(global_range.Map([](Range r) { return r->min; }));
+
+    if (!global_tensor->strides.empty()) {
+      desc.global_stride = ReverseArray(global_tensor->strides);
+    } else {
+      // Create stride from shape (row-major)
+      PrimExpr stride = 1;
+      desc.global_stride.reserve(desc.rank);
+      for (size_t i = 0; i < desc.rank; i++) {
+        desc.global_stride.push_back(stride);
+        stride *= desc.global_shape[i];
+      }
+    }
+    // Make global stride in bytes
+    desc.global_stride = desc.global_stride.Map([&](PrimExpr e) {
+      return cast(DataType::Int(64), e) * global_tensor->dtype.bytes();
+    });
+
+    // Shared memory box (copy extent)
+    desc.smem_box =
+        ReverseArray(global_range.Map([](Range r) { return r->extent; }));
+    desc.smem_stride = Array<PrimExpr>(desc.rank, PrimExpr(1));
+
+    // L2 & OOB settings
+    desc.l2_promotion =
+        static_cast<int>(CU_TENSOR_MAP_L2_PROMOTION_L2_128B);
+    desc.oob_fill = static_cast<int>(CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
+
+    // Detect smem layout for swizzle (similar to copy.cc)
+    // linear layout must be computed before remapping
+    auto linear_layout = makeLinearLayout(shared_tensor->shape);
+    desc.interleave = static_cast<int>(CU_TENSOR_MAP_INTERLEAVE_NONE);
+    Layout shared_layout;
+    if (T.layout_map.count(shared_tensor)) {
+      shared_layout = T.layout_map.at(shared_tensor);
+      ICHECK(T.buffer_remap.count(shared_tensor))
+          << "shared_tensor: " << shared_tensor->name
+          << " not found in buffer_remap";
+      shared_tensor = T.buffer_remap.at(shared_tensor);
+    }
+    if (!shared_layout.defined()) {
+      desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
+    } else if (StructuralEqual()(shared_layout, linear_layout)) {
+      desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
+    } else {
+      ICHECK(shared_layout->InputDim() == 2) << "Cannot detect TMA layout.";
+      auto stride = as_const_int(shared_layout->InputShape()[0]);
+      auto continuous = as_const_int(shared_layout->InputShape()[1]);
+      ICHECK(stride != nullptr && continuous != nullptr);
+      if (StructuralEqual()(shared_layout, makeQuarterBankSwizzleLayout(
+                                               *stride, *continuous,
+                                               shared_tensor->dtype.bits()))) {
+        desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_32B);
+      } else if (StructuralEqual()(
+                     shared_layout,
+                     makeHalfBankSwizzleLayout(*stride, *continuous,
+                                               shared_tensor->dtype.bits()))) {
+        desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B);
+      } else if (StructuralEqual()(
+                     shared_layout,
+                     makeFullBankSwizzleLayout(*stride, *continuous,
+                                               shared_tensor->dtype.bits()))) {
+        desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B);
+      } else if (StructuralEqual()(
+                     shared_layout,
+                     makeGemmABLayoutPadded(*stride, *continuous,
+                                            shared_tensor->dtype.bits()))) {
+        LOG(WARNING) << "AtomicAdd TMA cannot support a padded layout for src: "
+                     << src->name << ", dst: " << dst->name;
+        desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
+      } else {
+        LOG(WARNING) << "AtomicAdd TMA unsupported swizzle layout for src: "
+                     << src->name << ", dst: " << dst->name;
+        desc.swizzle = static_cast<int>(CU_TENSOR_MAP_SWIZZLE_NONE);
+      }
+    }
+
+    // Adjust instruction_dim based on swizzle type (similar to copy.cc)
+    auto inner_box_dim = as_const_int(desc.smem_box[0]);
+    ICHECK(inner_box_dim != nullptr)
+        << "inner_box_dim must be a constant integer for TMA atomic add";
+    int instruction_dim = *inner_box_dim;
+    if (desc.swizzle == static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B)) {
+      instruction_dim = 64 / shared_tensor->dtype.bytes();
+    } else if (desc.swizzle == static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B)) {
+      instruction_dim = 128 / shared_tensor->dtype.bytes();
+    }
+    if (instruction_dim > 256) {
+      ICHECK((*inner_box_dim) % 256 == 0)
+          << "inner_box_dim: " << *inner_box_dim << " is not divisible by 256";
+      instruction_dim = 256;
+    }
+    ICHECK((*inner_box_dim) % instruction_dim == 0)
+        << "inner_box_dim: " << *inner_box_dim
+        << " is not divisible by instruction_dim: " << instruction_dim;
+    desc.smem_box.Set(0, PrimExpr(instruction_dim));
+
+    int inner_box_dim_ = instruction_dim * shared_tensor->dtype.bytes();
+    // Check inner_box_dim_ for each swizzle type
+    struct SwizzleCheck {
+      int swizzle;
+      int max_dim;
+    };
+    static const std::vector<SwizzleCheck> swizzle_checks = {
+        {static_cast<int>(CU_TENSOR_MAP_SWIZZLE_32B), 32},
+        {static_cast<int>(CU_TENSOR_MAP_SWIZZLE_64B), 64},
+        {static_cast<int>(CU_TENSOR_MAP_SWIZZLE_128B), 128},
+    };
+    for (const auto &check : swizzle_checks) {
+      if (desc.swizzle == check.swizzle && inner_box_dim_ > check.max_dim) {
+        LOG(WARNING) << "AtomicAdd TMA cannot support swizzled layout with "
+                        "inner_box_dim_ > "
+                     << check.max_dim;
+      }
+    }
+
+    // Compute shared memory offset
+    Array<PrimExpr> shared_indices;
+    for (auto r : shared_range)
+      shared_indices.push_back(r->min);
+    std::vector<PrimExpr> shared_strides;
+    PrimExpr shared_stride = 1;
+    for (size_t i = 0; i < shared_tensor->shape.size(); i++) {
+      auto s = shared_tensor->shape[shared_tensor->shape.size() - i - 1];
+      shared_strides.insert(shared_strides.begin(), shared_stride);
+      shared_stride *= s;
+    }
+    PrimExpr shared_offset = 0;
+    for (size_t i = 0; i < shared_indices.size(); i++) {
+      shared_offset += shared_indices[i] * shared_strides[i];
+    }
+
+    // Create TMA descriptor
+    Call create_descriptor = Call(DataType::Handle(), create_tma_descriptor(),
+                                  desc.EncodeCallArgs());
+
+    // Compute total elements for access_ptr
+    PrimExpr total_elements = 1;
+    for (auto e : desc.smem_box)
+      total_elements *= e;
+
     // erase use_tma from annotations
-    auto annotations = this->annotations;
-    annotations.erase("use_tma");
-    auto body = Evaluate(Call(DataType::Handle(), tma_store(),
-                              {address_of_src, address_of_dst,
-                               ceildiv(src_size * src->dtype.bits(), 8),
-                               need_reduce, eviction_policy},
-                              annotations));
-    return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), body);
+    auto op_annotations = this->annotations;
+    op_annotations.erase("use_tma");
+
+    Stmt tma_reduce;
+    if ((*inner_box_dim) != instruction_dim) {
+      // Need to split the operation into multiple TMA calls
+      Var loop_var("i");
+      int loop_extent = (*inner_box_dim) / instruction_dim;
+
+      Array<PrimExpr> args;
+      args.reserve(desc.rank + 4);
+      args.push_back(create_descriptor);
+      PrimExpr shared_addr = shared_tensor.access_ptr(
+          1, DataType::Handle(), 1, shared_offset + total_elements * loop_var,
+          total_elements);
+      args.push_back(shared_addr);
+      Array<PrimExpr> loop_global_coords = global_coords;
+      loop_global_coords.Set(0, global_coords[0] + instruction_dim * loop_var);
+      for (auto coord : loop_global_coords)
+        args.push_back(coord);
+      int need_reduce = 1;
+      args.push_back(need_reduce);
+      int eviction_policy = 0;
+      args.push_back(eviction_policy);
+      tma_reduce = For(loop_var, 0, loop_extent, ForKind::kUnrolled,
+                       Evaluate(Call(DataType::Handle(), tma_store(), args,
+                                     op_annotations)));
+    } else {
+      Array<PrimExpr> args;
+      args.reserve(desc.rank + 4);
+      args.push_back(create_descriptor);
+      PrimExpr shared_addr = shared_tensor.access_ptr(
+          1, DataType::Handle(), 1, shared_offset, total_elements);
+      args.push_back(shared_addr);
+      for (auto coord : global_coords)
+        args.push_back(coord);
+      int need_reduce = 1;
+      args.push_back(need_reduce);
+      int eviction_policy = 0;
+      args.push_back(eviction_policy);
+      tma_reduce =
+          Evaluate(Call(DataType::Handle(), tma_store(), args, op_annotations));
+    }
+
+    return IfThenElse(EQ(T.thread_var, T.thread_bounds->min), tma_reduce);
   }
   auto simt_loop = MakeSIMTLoop(analyzer);
   auto fused_loop = Downcast<For>(ParallelLoopFuser::Fuse(simt_loop));
diff --git a/src/op/atomic_add.h b/src/op/atomic_add.h
index 56f48839f..f13e827a5 100644
--- a/src/op/atomic_add.h
+++ b/src/op/atomic_add.h
@@ -77,6 +77,8 @@ class AtomicAddNode : public TileOperatorNode {
   /// Create boundary predicate for memory safety
   PrimExpr MakePredicate(arith::Analyzer *analyzer, const Array<IterVar> &ivs,
                          Array<PrimExpr> extents, int src_dst) const;
+  /// Compute linear layout for shared tensor (used in TMA atomic add)
+  Layout ComputeLinearLayout(const Buffer &shared_tensor) const;
 };
 
 /// Wrapper class for atomic addition operations
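Note on the new ComputeLinearLayout helper declared above: the tiled index mapping it describes can be written out in plain Python. This is an illustrative sketch only (the helper name and the 256-element block size come from the patch; everything else is assumed for illustration):

    # Sketch of the fallback linear layout used for TMA atomic add:
    # [i, j] -> [i // 256, j // 256, i % 256, j % 256]
    def linear_layout_index(i, j, block=256):
        return (i // block, j // block, i % block, j % block)

    assert linear_layout_index(0, 0) == (0, 0, 0, 0)
    assert linear_layout_index(300, 17) == (1, 0, 44, 17)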
diff --git a/src/op/copy.cc b/src/op/copy.cc
index 070df4305..7f91d4c38 100644
--- a/src/op/copy.cc
+++ b/src/op/copy.cc
@@ -17,8 +17,6 @@
 #include "../transform/loop_vectorize.h"
 #include "utils.h"
-#include "../target/stubs/cuda.h"
-#include "../target/utils.h"
 #include "builtin.h"
 #include
 #include
 
@@ -30,75 +28,6 @@ namespace tl {
 
 using namespace tir;
 
-// Maps TVM DataType to CUDA's CUtensorMapDataType enum value.
-static int to_CUtensorMapDataType(DataType dtype) {
-  CUtensorMapDataType tp;
-  if (dtype.is_float()) {
-    switch (dtype.bits()) {
-    case 64:
-      tp = CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
-      break;
-    case 32:
-      tp = CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
-      break;
-    case 16:
-      tp = CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
-      break;
-    case 8:
-      tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
-      break;
-    default:
-      ICHECK(0) << dtype;
-    }
-  } else if (dtype.is_bfloat16()) {
-    tp = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
-  } else if (dtype.is_float8()) {
-    tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
-  } else if (dtype.is_int()) {
-    switch (dtype.bits()) {
-    case 64:
-      tp = CU_TENSOR_MAP_DATA_TYPE_INT64;
-      break;
-    case 32:
-      tp = CU_TENSOR_MAP_DATA_TYPE_INT32;
-      break;
-    case 16:
-      tp = CU_TENSOR_MAP_DATA_TYPE_UINT16;
-      break;
-    case 8:
-      tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
-      break;
-    default:
-      ICHECK(0) << dtype;
-    }
-  } else if (dtype.is_uint()) {
-    switch (dtype.bits()) {
-    case 64:
-      tp = CU_TENSOR_MAP_DATA_TYPE_UINT64;
-      break;
-    case 32:
-      tp = CU_TENSOR_MAP_DATA_TYPE_UINT32;
-      break;
-    case 16:
-      tp = CU_TENSOR_MAP_DATA_TYPE_UINT16;
-      break;
-    case 8:
-      tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
-      break;
-    default:
-      ICHECK(0) << dtype;
-    }
-  } else {
-    ICHECK(0) << dtype;
-  }
-  return static_cast<int>(tp);
-}
-
-// Reverses an array (used for row-major/column-major layout conversion).
-template <typename T> static Array<T> ReverseArray(Array<T> array) {
-  return Array<T>{array.rbegin(), array.rend()};
-}
-
 // Constructs a Copy operator node from call arguments and annotations.
 // args[0]: source region, args[1]: destination region
 // annotations: Map containing coalesced_width, disable_tma, eviction_policy,
diff --git a/src/op/utils.cc b/src/op/utils.cc
index 7e56ae8c7..042c38a0c 100644
--- a/src/op/utils.cc
+++ b/src/op/utils.cc
@@ -92,5 +92,69 @@ PrimExpr MakeAccessPtrFromRegion(const BufferRegion &region, int rw_mask,
   return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
 }
 
+// Maps TVM DataType to CUDA's CUtensorMapDataType enum value.
+int to_CUtensorMapDataType(DataType dtype) {
+  CUtensorMapDataType tp;
+  if (dtype.is_float()) {
+    switch (dtype.bits()) {
+    case 64:
+      tp = CU_TENSOR_MAP_DATA_TYPE_FLOAT64;
+      break;
+    case 32:
+      tp = CU_TENSOR_MAP_DATA_TYPE_FLOAT32;
+      break;
+    case 16:
+      tp = CU_TENSOR_MAP_DATA_TYPE_FLOAT16;
+      break;
+    case 8:
+      tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+      break;
+    default:
+      ICHECK(0) << dtype;
+    }
+  } else if (dtype.is_bfloat16()) {
+    tp = CU_TENSOR_MAP_DATA_TYPE_BFLOAT16;
+  } else if (dtype.is_float8()) {
+    tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+  } else if (dtype.is_int()) {
+    switch (dtype.bits()) {
+    case 64:
+      tp = CU_TENSOR_MAP_DATA_TYPE_INT64;
+      break;
+    case 32:
+      tp = CU_TENSOR_MAP_DATA_TYPE_INT32;
+      break;
+    case 16:
+      tp = CU_TENSOR_MAP_DATA_TYPE_UINT16;
+      break;
+    case 8:
+      tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+      break;
+    default:
+      ICHECK(0) << dtype;
+    }
+  } else if (dtype.is_uint()) {
+    switch (dtype.bits()) {
+    case 64:
+      tp = CU_TENSOR_MAP_DATA_TYPE_UINT64;
+      break;
+    case 32:
+      tp = CU_TENSOR_MAP_DATA_TYPE_UINT32;
+      break;
+    case 16:
+      tp = CU_TENSOR_MAP_DATA_TYPE_UINT16;
+      break;
+    case 8:
+      tp = CU_TENSOR_MAP_DATA_TYPE_UINT8;
+      break;
+    default:
+      ICHECK(0) << dtype;
+    }
+  } else {
+    ICHECK(0) << dtype;
+  }
+  return static_cast<int>(tp);
+}
+
 }  // namespace tl
 }  // namespace tvm
diff --git a/src/op/utils.h b/src/op/utils.h
index 8ff194cfb..fcbfee9e2 100644
--- a/src/op/utils.h
+++ b/src/op/utils.h
@@ -6,6 +6,7 @@
 #ifndef TVM_TL_OP_UTILS_H_
 #define TVM_TL_OP_UTILS_H_
 
+#include "../target/stubs/cuda.h"
 #include "./operator.h"
 #include "region.h"
 #include
@@ -16,6 +17,14 @@ namespace tl {
 
 using namespace tir;
 
+// Maps TVM DataType to CUDA's CUtensorMapDataType enum value.
+TVM_DLL int to_CUtensorMapDataType(DataType dtype);
+
+// Reverses an array (used for row-major/column-major layout conversion).
+template <typename T> Array<T> ReverseArray(Array<T> array) {
+  return Array<T>{array.rbegin(), array.rend()};
+}
+
 // Normalize an argument (BufferRegion/BufferLoad/tl.region)
 // to BufferRegion so ops can uniformly consume regions.
 // Note: tvm_access_ptr is no longer supported here.
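Note on the ReverseArray helper now exposed in utils.h: TMA tensor-map descriptors take shapes, strides, and coordinates innermost-dimension-first, while TVM buffers store them outermost-first, so the conversion is a plain reversal. A minimal Python sketch (illustrative only; the real helper is the C++ template above):

    def reverse_array(arr):
        # Outermost-first [B, M, N] becomes innermost-first [N, M, B]
        # before being encoded into a TMA descriptor.
        return list(reversed(arr))

    assert reverse_array([2, 64, 128]) == [128, 64, 2]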
diff --git a/src/tl_templates/cuda/copy_sm90.h b/src/tl_templates/cuda/copy_sm90.h
index 0b51450b3..3d5b3f414 100644
--- a/src/tl_templates/cuda/copy_sm90.h
+++ b/src/tl_templates/cuda/copy_sm90.h
@@ -262,6 +262,74 @@ TL_DEVICE void tma_store_add(float *const smem_ptr, float *gmem_ptr,
       : "memory");
 }
 
+TL_DEVICE void tma_store_add(const CUtensorMap &descriptor,
+                             void const *const smem_ptr, int32_t const &crd0) {
+  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
+  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
+  asm volatile(
+      "cp.reduce.async.bulk.tensor.1d.global.shared::cta.add.bulk_group "
+      "[%0, {%2}], [%1];"
+      :
+      : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0)
+      : "memory");
+}
+
+TL_DEVICE void tma_store_add(const CUtensorMap &descriptor,
+                             void const *const smem_ptr, int32_t const &crd0,
+                             int32_t const &crd1) {
+  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
+  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
+  asm volatile(
+      "cp.reduce.async.bulk.tensor.2d.global.shared::cta.add.bulk_group "
+      "[%0, {%2, %3}], [%1];"
+      :
+      : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1)
+      : "memory");
+}
+
+TL_DEVICE void tma_store_add(const CUtensorMap &descriptor,
+                             void const *const smem_ptr, int32_t const &crd0,
+                             int32_t const &crd1, int32_t const &crd2) {
+  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
+  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
+  asm volatile(
+      "cp.reduce.async.bulk.tensor.3d.global.shared::cta.add.bulk_group "
+      "[%0, {%2, %3, %4}], [%1];"
+      :
+      : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1), "r"(crd2)
+      : "memory");
+}
+
+TL_DEVICE void tma_store_add(const CUtensorMap &descriptor,
+                             void const *const smem_ptr, int32_t const &crd0,
+                             int32_t const &crd1, int32_t const &crd2,
+                             int32_t const &crd3) {
+  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
+  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
+  asm volatile(
+      "cp.reduce.async.bulk.tensor.4d.global.shared::cta.add.bulk_group "
+      "[%0, {%2, %3, %4, %5}], [%1];"
+      :
+      : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1), "r"(crd2),
+        "r"(crd3)
+      : "memory");
+}
+
+TL_DEVICE void tma_store_add(const CUtensorMap &descriptor,
+                             void const *const smem_ptr, int32_t const &crd0,
+                             int32_t const &crd1, int32_t const &crd2,
+                             int32_t const &crd3, int32_t const &crd4) {
+  uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
+  uint32_t smem_int_ptr = smem_ptr_to_uint(smem_ptr);
+  asm volatile(
+      "cp.reduce.async.bulk.tensor.5d.global.shared::cta.add.bulk_group "
+      "[%0, {%2, %3, %4, %5, %6}], [%1];"
+      :
+      : "l"(gmem_int_desc), "r"(smem_int_ptr), "r"(crd0), "r"(crd1), "r"(crd2),
+        "r"(crd3), "r"(crd4)
+      : "memory");
+}
+
 TL_DEVICE void prefetch_tma_descriptor(const CUtensorMap &descriptor) {
   uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(&descriptor);
   asm volatile("prefetch.tensormap [%0];" : : "l"(gmem_int_desc) : "memory");
diff --git a/testing/python/language/test_tilelang_language_atomic_add.py b/testing/python/language/test_tilelang_language_atomic_add.py
index b3c94a742..8b3253b95 100644
--- a/testing/python/language/test_tilelang_language_atomic_add.py
+++ b/testing/python/language/test_tilelang_language_atomic_add.py
@@ -1,5 +1,7 @@
 import tilelang.testing
+import tilelang.layout
 import tilelang.language as T
+import torch
 
 
 @tilelang.jit
@@ -350,6 +352,21 @@ def run_atomic_return_prev(M, N, block_M, block_N, dtype=T.float32):
     torch.testing.assert_close(B, initial_B + A, atol=1e-3, rtol=1e-3)
 
 
+@tilelang.jit
+def tma_atomic_add_program(out, explicit_swizzle=False):
+    out: T.Tensor[(16, 16), T.float32]
+
+    with T.Kernel(
+        1,
+    ):
+        out_shared = T.alloc_shared((16, 16), dtype=T.float32)
+        if explicit_swizzle:
+            T.annotate_layout({out_shared: tilelang.layout.make_swizzled_layout(out_shared)})
+        T.fill(out_shared, 1)
+        for _ in range(16):
+            T.atomic_add(out, out_shared, use_tma=True)
+
+
 @tilelang.testing.requires_cuda
 def test_atomic_different_memory_orders():
     run_atomic_different_memory_orders(32, 32, 8, 8, dtype=T.float32)
@@ -369,5 +386,20 @@ def test_tile_atomic_add():
     run_tile_atomic_add(8, 128, 128, 32, 32)
 
 
+@tilelang.testing.requires_cuda
+def test_tma_atomic_add():
+    out = torch.zeros((16, 16), dtype=torch.float32, device="cuda")
+    tma_atomic_add_program(out)
+    torch.testing.assert_close(out, torch.ones((16, 16), dtype=torch.float32, device="cuda") * 16)
+
+    kernel = tma_atomic_add_program.compile(out=T.Tensor[(16, 16), T.float32])
+    assert "tma_store_add" in kernel.get_kernel_source()
+    assert "desc" in kernel.get_kernel_source()  # Ensure cp.reduce.async.bulk.tensor is used
+
+    kernel_with_explicit_swizzle = tma_atomic_add_program.compile(out=T.Tensor[(16, 16), T.float32], explicit_swizzle=True)
+    # Ensure the automatically inferred swizzle matches the explicitly annotated one
+    assert kernel.get_kernel_source() == kernel_with_explicit_swizzle.get_kernel_source()
+
+
 if __name__ == "__main__":
     tilelang.testing.main()
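Closing note on the TMA lowering in AtomicAddNode::Lower (src/op/atomic_add.cc above): when the innermost shared-memory extent exceeds what one cp.reduce.async.bulk.tensor call (or one swizzle period) covers, the box is clamped and an unrolled loop of TMA reduce calls is emitted. A simplified Python sketch of that splitting rule, assuming 2-byte fp16 elements in the example (illustrative only, not the real implementation):

    def split_inner_dim(inner_box_dim, elem_bytes, swizzle):
        # Mirrors the clamping in AtomicAddNode::Lower (sketch, not the real code).
        if swizzle == "64B":
            instruction_dim = 64 // elem_bytes
        elif swizzle == "128B":
            instruction_dim = 128 // elem_bytes
        else:
            instruction_dim = inner_box_dim
        if instruction_dim > 256:
            assert inner_box_dim % 256 == 0
            instruction_dim = 256
        assert inner_box_dim % instruction_dim == 0
        # Returns (per-call box width, number of TMA reduce calls).
        return instruction_dim, inner_box_dim // instruction_dim

    # A 256-wide fp16 tile with 128B swizzle is issued as 4 calls of width 64.
    assert split_inner_dim(256, 2, "128B") == (64, 4)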