From 29f26833641682de1bc0f8b7051bb4c01bb36460 Mon Sep 17 00:00:00 2001 From: kurisu6912 <227995639+kurisu6912@users.noreply.github.com> Date: Thu, 22 Jan 2026 16:19:11 +0800 Subject: [PATCH 01/14] [Layout] Add read write check in layout inference --- src/op/parallel.cc | 123 ++++++++++++++++++++++++++++++--------------- src/op/parallel.h | 26 +++++++--- 2 files changed, 102 insertions(+), 47 deletions(-) diff --git a/src/op/parallel.cc b/src/op/parallel.cc index 5fa90ba95..c3f9cc776 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -173,27 +173,14 @@ void ParallelLoopNestVisitor::VisitStmt_(const ForNode *op) { void ParallelLoopNestVisitor::VisitStmt_(const BufferStoreNode *op) { if (IsFragmentBuffer(op->buffer)) { - if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) { - ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices)) - << op->buffer << ": " << op->indices << " and " - << p->indice_map_.at(op->buffer); - } else { - p->indice_map_.Set(op->buffer, op->indices); - } - p->buffer_is_write_.insert(op->buffer); + p->RecordBufferAccess(op->buffer, op->indices, /*is_write=*/true); } StmtExprVisitor::VisitStmt_(op); } void ParallelLoopNestVisitor::VisitExpr_(const BufferLoadNode *op) { if (IsFragmentBuffer(op->buffer)) { - if (p->indice_map_.find(op->buffer) != p->indice_map_.end()) { - ICHECK(StructuralEqual()(p->indice_map_.at(op->buffer), op->indices)) - << op->buffer << ": " << op->indices << " and " - << p->indice_map_.at(op->buffer); - } else { - p->indice_map_.Set(op->buffer, op->indices); - } + p->RecordBufferAccess(op->buffer, op->indices, /*is_write=*/false); } StmtExprVisitor::VisitExpr_(op); } @@ -226,8 +213,8 @@ void ParallelOpNode::ExpandLetBindings( std::function expand = [&](const PrimExpr &expr) { PostOrderVisit(expr, [&](const ObjectRef &node) { if (auto bl = node.as()) { - if (IsFragmentBuffer(bl->buffer) && !indice_map_.count(bl->buffer)) { - indice_map_.Set(bl->buffer, bl->indices); + if (IsFragmentBuffer(bl->buffer)) { + RecordBufferAccess(bl->buffer, bl->indices, /*is_write=*/false); } } else if (auto var_node = node.as()) { auto var = tvm::ffi::GetRef(var_node); @@ -255,6 +242,33 @@ void ParallelOpNode::ExpandLetBindings( } } +void ParallelOpNode::RecordBufferAccess(const Buffer &buffer, + const Array &indices, + bool is_write) { + auto it = indice_map_.find(buffer); + if (it != indice_map_.end()) { + ICHECK(StructuralEqual()(it->second.indices, indices)) + << buffer << ": " << indices << " and " << it->second.indices; + } else { + BufferAccessInfo info; + info.indices = indices; + it = indice_map_.emplace(buffer, std::move(info)).first; + } + if (is_write) { + it->second.is_write = true; + } else { + it->second.is_read = true; + } +} + +const ParallelOpNode::BufferAccessInfo & +ParallelOpNode::GetAccessInfo(const Buffer &buffer) const { + auto it = indice_map_.find(buffer); + ICHECK(it != indice_map_.end()) + << "Missing access info for buffer " << buffer; + return it->second; +} + Stmt ParallelOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return root_; @@ -264,7 +278,7 @@ Stmt ParallelOpNode::Lower(const LowerArgs &T, bool ParallelOpNode::IsCommonAccessIndice(const Buffer &buffer) const { auto common_indice = loop_vars_.Map([](const auto &iv) { return iv->var; }); - return StructuralEqual()(indice_map_[buffer], common_indice); + return StructuralEqual()(GetAccessInfo(buffer).indices, common_indice); } /*! 
\brief Infer the layout for parallel operations based on different inference @@ -302,7 +316,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, // for i in T.Parallel(m): // fragment[0] = x[i] // then fragment[0] must be replicated on all threads. - for (const auto &[buffer, indices] : indice_map_) { + for (const auto &[buffer, access] : indice_map_) { if (T.layout_map.count(buffer)) { continue; } @@ -311,7 +325,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, // Check if all indices are zero bool all_indices_zero = true; - for (const auto &index : indices) { + for (const auto &index : access.indices) { if (const auto *imm = index.as()) { if (imm->value != 0) { all_indices_zero = false; @@ -355,7 +369,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, return false; auto frag = T.layout_map[buffer].as().value(); // buffer indices should be IntImm - for (const auto &index : indice_map_[buffer]) { + for (const auto &index : GetAccessInfo(buffer).indices) { if (!index.as()) { return false; } else if (index.as()->value != 0) { @@ -366,13 +380,13 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, }; // Collect fragment buffers with const index and all fragment_buffers std::vector const_index_fragment_buffer, fragment_buffers; - for (const auto &[buffer, indices] : indice_map_) { + for (const auto &[buffer, access] : indice_map_) { if (!IsFragmentBuffer(buffer)) continue; fragment_buffers.push_back(buffer); bool is_const_index = true; - for (const auto &index : indices) { + for (const auto &index : access.indices) { if (!index.as()) { is_const_index = false; break; @@ -400,7 +414,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, Buffer source_buffer, read_source_buffer; Buffer replicated_write_buffer; // Backup: fully replicated write buffer - for (const auto &[buffer, indices] : indice_map_) { + for (const auto &[buffer, access] : indice_map_) { if (T.layout_map.count(buffer)) { // skip reducers with rep=ALL if (auto info = reducer_info_map_.Get(buffer->data); @@ -410,7 +424,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, auto frag = T.layout_map[buffer].as().value(); bool is_fully_replicated = buffer_is_completed_replicated(buffer); - if (buffer_is_write_.count(buffer)) { + if (access.is_write) { source_buffer = buffer; } else { // Keep the buffer with largest number of indices @@ -419,8 +433,8 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, // if the buffer is completed replicated, we don't need to infer the // layout from this buffer. 
if ((!read_source_buffer.defined() || - indice_map_[buffer].size() > - indice_map_[read_source_buffer].size())) { + access.indices.size() > + GetAccessInfo(read_source_buffer).indices.size())) { read_source_buffer = buffer; } // If the buffer is not replicated and shape is equal to the @@ -554,18 +568,38 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, // Step 2: Check that the loop's partition can correctly align with all source // fragment, and infer layout only when it's not yet layout-ed LayoutMap results; - for (const auto &[buffer, _] : indice_map_) { + for (const auto &[buffer, access] : indice_map_) { if (T.layout_map.count(buffer)) { auto fragment = T.layout_map[buffer].as().value(); auto vars = loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); }); - if (!ProveFragmentContains(loop_layout_, fragment, vars, - indice_map_[buffer], analyzer_)) { - std::ostringstream oss; + std::ostringstream oss; + bool success = true; + if (access.is_read && !ProveFragmentContains(loop_layout_, fragment, vars, + access.indices, analyzer_)) { + oss << "Layout infer conflict between " << buffer << " and " + << source_buffer << " in T.Parallel loop:" << '\n' + << " loop " << loop_layout_->DebugOutput() << '\n' + << " fragment " << fragment->DebugOutput() << '\n'; + success = false; + } + if (access.is_write && + !ProveFragmentContains(fragment, loop_layout_, access.indices, vars, + analyzer_)) { oss << "Layout infer conflict between " << buffer << " and " << source_buffer << " in T.Parallel loop:" << '\n' << " loop " << loop_layout_->DebugOutput() << '\n' << " fragment " << fragment->DebugOutput() << '\n'; + success = false; + } + if (this->loop_vars_[0]->var->name_hint == "i_aaa") { + if (buffer->name == "col_sum2") { + LOG(INFO) << "Debug: Target=" << fragment << " Loop=" << loop_layout_ + << " result: " << success << " has_read: " << access.is_read + << " has_write: " << access.is_write; + } + } + if (!success) { throw LayoutConflictException(oss.str()); } } else { @@ -595,11 +629,12 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const { // them directly and avoid introducing a synthetic replicate dimension. { auto res2d = - arith::DetectIterMap(indice_map_[buffer], ToVMap(loop_vars_), 1, - arith::IterMapLevel::Bijective, + arith::DetectIterMap(GetAccessInfo(buffer).indices, ToVMap(loop_vars_), + 1, arith::IterMapLevel::Bijective, const_cast(&analyzer_)); if (res2d->errors.empty()) { - Layout ind_inv2d = Layout(loop_vars_, indice_map_[buffer])->Inverse(); + Layout ind_inv2d = + Layout(loop_vars_, GetAccessInfo(buffer).indices)->Inverse(); PrimExpr indice_rep_extent = 1; PrimExpr loop_rep_extent = loop_layout_->ReplicateExtent(); PrimExpr dest_buffer_rep_extent = indice_rep_extent * loop_rep_extent; @@ -616,9 +651,9 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const { } // Otherwise, infer an extra flattened iterator that captures truly-unused // pieces of the loop space (if any), then try inversion with it. 
-  PrimExpr rep_b = MakeFlattenedExpression(
-      DivideUnusedIterators(indice_map_[buffer], loop_vars_, &analyzer_));
-  auto bijective_indice = indice_map_[buffer];
+  PrimExpr rep_b = MakeFlattenedExpression(DivideUnusedIterators(
+      GetAccessInfo(buffer).indices, loop_vars_, &analyzer_));
+  auto bijective_indice = GetAccessInfo(buffer).indices;
   bijective_indice.push_back(rep_b);
   Layout ind_inv = Layout(loop_vars_, bijective_indice)->Inverse();
 
@@ -645,14 +680,20 @@ bool ParallelOpNode::ValidateCandidateAgainstFragments(
     const Fragment &candidate, const LayoutInferArgs &T) const {
   auto vars =
       loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
-  for (const auto &[buffer, _] : indice_map_) {
+  for (const auto &[buffer, access] : indice_map_) {
     if (!T.layout_map.count(buffer))
       continue;
     auto fragment = T.layout_map[buffer].as<Fragment>().value();
     // check_forward_index=true: when validating loop layout against buffer
     // fragment, we need to ensure physical indices match for correct code gen.
-    if (!ProveFragmentContains(candidate, fragment, vars, indice_map_[buffer],
-                               analyzer_, /*check_forward_index=*/true)) {
+    if (access.is_write &&
+        !ProveFragmentContains(candidate, fragment, vars, access.indices,
+                               analyzer_, /*check_forward_index=*/false)) {
+      return false;
+    }
+    if (access.is_read &&
+        !ProveFragmentContains(fragment, candidate, access.indices, vars,
+                               analyzer_, /*check_forward_index=*/false)) {
       return false;
     }
   }
@@ -674,7 +715,7 @@ ParallelOpNode::ComputeLoopLayoutFromBuffer(const Buffer &buffer,
   auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep,
                           IterVarType::kDataPar);
   PrimExpr loop_var_to_thread =
-      src_layout->ForwardThread(indice_map_[buffer], rep);
+      src_layout->ForwardThread(GetAccessInfo(buffer).indices, rep);
   loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread);
   PostOrderVisit(loop_var_to_thread, [&](const ObjectRef &objref) {
     if (auto opt_var = objref.as<Var>();
diff --git a/src/op/parallel.h b/src/op/parallel.h
index cff0c5f91..3bdc46f81 100644
--- a/src/op/parallel.h
+++ b/src/op/parallel.h
@@ -9,6 +9,8 @@
 #include 
 #include 
 
+#include <unordered_map>
+
 #include "../layout/layout.h"
 #include "../transform/layout_reducer.h"
 #include "./operator.h"
@@ -49,6 +51,15 @@ class ParallelLoopNestVisitor : public StmtExprVisitor {
 // predicates.
 class ParallelOpNode : public TileOperatorNode {
 public:
+  struct BufferAccessInfo {
+    Array<PrimExpr> indices;
+    bool is_read = false;
+    bool is_write = false;
+  };
+
+  using BufferIndiceMap =
+      std::unordered_map<Buffer, BufferAccessInfo, ObjectPtrHash, ObjectPtrEqual>;
+
   // The root For loop node.
   For root_;
   // The inferred layout for the loop, mutable to allow lazy inference.
@@ -101,8 +112,8 @@ class ParallelOpNode : public TileOperatorNode {
   Fragment GetLoopLayout() const { return loop_layout_; }
   // Get the root For loop.
   For GetRoot() const { return root_; }
-  // Get the mapping from buffer to access indices.
-  Map<Buffer, Array<PrimExpr>> GetIndiceMap() const { return indice_map_; }
+  // Get the mapping from buffer to access indices + access type.
+  const BufferIndiceMap &GetIndiceMap() const { return indice_map_; }
   // Get the predicate for a given thread variable.
   Optional<PrimExpr> GetPredicate(Var thread_var) const;
 
@@ -114,6 +125,11 @@ class ParallelOpNode : public TileOperatorNode {
   Fragment CompleteBufferFragment(const Buffer &buffer) const;
   // Check if the buffer is accessed with common indices (i.e., loop variables).
   bool IsCommonAccessIndice(const Buffer &buffer) const;
+  // Record buffer access and validate consistent indices.
+  void RecordBufferAccess(const Buffer &buffer, const Array<PrimExpr> &indices,
+                          bool is_write);
+  // Access info lookup with validation.
+  const BufferAccessInfo &GetAccessInfo(const Buffer &buffer) const;
   // Validate a candidate loop layout against all source fragments in
   // T.layout_map. Returns true if compatible with all fragments; otherwise
   // false. Does not throw.
@@ -153,10 +169,8 @@ class ParallelOpNode : public TileOperatorNode {
   // Visitor for collecting loop nest information.
   ParallelLoopNestVisitor V;
 
-  // Mapping from buffer to their access indices in the loop.
-  Map<Buffer, Array<PrimExpr>> indice_map_;
-  // Set of buffers that are written to in the loop.
-  std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_is_write_;
+  // Mapping from buffer to their access indices and access type in the loop.
+  BufferIndiceMap indice_map_;
   // The loop variables for the parallel loop nest.
   Array<IterVar> loop_vars_;
   // The inner_vars_

From 6bf71fc1696729218ec1fd1389dde85edede919a Mon Sep 17 00:00:00 2001
From: kurisu6912 <227995639+kurisu6912@users.noreply.github.com>
Date: Thu, 22 Jan 2026 16:23:19 +0800
Subject: [PATCH 02/14] remove debug stmts

---
 src/op/parallel.cc | 7 -------
 1 file changed, 7 deletions(-)

diff --git a/src/op/parallel.cc b/src/op/parallel.cc
index c3f9cc776..730f7723d 100644
--- a/src/op/parallel.cc
+++ b/src/op/parallel.cc
@@ -592,13 +592,6 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
               << " fragment " << fragment->DebugOutput() << '\n';
           success = false;
         }
-        if (this->loop_vars_[0]->var->name_hint == "i_aaa") {
-          if (buffer->name == "col_sum2") {
-            LOG(INFO) << "Debug: Target=" << fragment << " Loop=" << loop_layout_
-                      << " result: " << success << " has_read: " << access.is_read
-                      << " has_write: " << access.is_write;
-          }
-        }
         if (!success) {
           throw LayoutConflictException(oss.str());
         }

From 1525663297bd144b2e7b8898dd9edd93774c19b3 Mon Sep 17 00:00:00 2001
From: kurisu6912 <227995639+kurisu6912@users.noreply.github.com>
Date: Thu, 22 Jan 2026 16:50:27 +0800
Subject: [PATCH 03/14] fix typo in validator

---
 src/op/parallel.cc | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/op/parallel.cc b/src/op/parallel.cc
index 730f7723d..18f3a9a68 100644
--- a/src/op/parallel.cc
+++ b/src/op/parallel.cc
@@ -570,6 +570,9 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
   LayoutMap results;
   for (const auto &[buffer, access] : indice_map_) {
     if (T.layout_map.count(buffer)) {
+      if (auto info = reducer_info_map_.Get(buffer->data);
+          info && info.value()->rep == ReducerRepType::ALL)
+        continue;
       auto fragment = T.layout_map[buffer].as<Fragment>().value();
       auto vars =
           loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
@@ -676,15 +679,18 @@ bool ParallelOpNode::ValidateCandidateAgainstFragments(
   for (const auto &[buffer, access] : indice_map_) {
     if (!T.layout_map.count(buffer))
       continue;
+    if (auto info = reducer_info_map_.Get(buffer->data);
+        info && info.value()->rep == ReducerRepType::ALL)
+      continue;
     auto fragment = T.layout_map[buffer].as<Fragment>().value();
     // check_forward_index=true: when validating loop layout against buffer
     // fragment, we need to ensure physical indices match for correct code gen. 
-    if (access.is_write &&
+    if (access.is_read &&
         !ProveFragmentContains(candidate, fragment, vars, access.indices,
                                analyzer_, /*check_forward_index=*/false)) {
       return false;
     }
-    if (access.is_read &&
+    if (access.is_write &&
         !ProveFragmentContains(fragment, candidate, access.indices, vars,
                                analyzer_, /*check_forward_index=*/false)) {
       return false;

From fadf342a06c3e2f412b51fac1e7557e52105e21b Mon Sep 17 00:00:00 2001
From: kurisu6912 <227995639+kurisu6912@users.noreply.github.com>
Date: Fri, 23 Jan 2026 09:57:52 +0800
Subject: [PATCH 04/14] [Layout] Prevent widening of layout in ReduceOpNode::InferLayout

---
 src/op/reduce.cc                              |  9 ++++--
 .../python/issue/test_tilelang_issue_1719.py | 28 +++++++++++++++++++
 2 files changed, 34 insertions(+), 3 deletions(-)
 create mode 100644 testing/python/issue/test_tilelang_issue_1719.py

diff --git a/src/op/reduce.cc b/src/op/reduce.cc
index 7148cc076..12367b408 100644
--- a/src/op/reduce.cc
+++ b/src/op/reduce.cc
@@ -472,9 +472,12 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
         throw LayoutConflictException(oss.str());
       }
 
-      if (dst_rep > src_rep) {
-        return {{dst, dst_layout}};
-      }
+      // We shouldn't widen the layout here,
+      // because it may be written by another parallel for op
+      // So just keep the original layout
+      // if (dst_rep > src_rep) {
+      //   return {{dst, dst_layout}};
+      // }
     }
   }
   return {};
diff --git a/testing/python/issue/test_tilelang_issue_1719.py b/testing/python/issue/test_tilelang_issue_1719.py
new file mode 100644
index 000000000..fcd93e066
--- /dev/null
+++ b/testing/python/issue/test_tilelang_issue_1719.py
@@ -0,0 +1,28 @@
+import tilelang
+import tilelang.testing
+import tilelang.language as T
+
+
+def test_tilelang_issue_1719():
+    @tilelang.jit()
+    def _buggy_kernel(M: int, N: int) -> tilelang.JITKernel:
+        @T.prim_func
+        def kernel() -> None:
+            with T.Kernel():
+                tmp1 = T.alloc_fragment((N, M), T.float32)
+                tmp2 = T.alloc_fragment((N, M), T.float32)
+                tmp3 = T.alloc_fragment((N, M, M), T.float32)
+                for i, j, k in T.Parallel(N, M, M):
+                    tmp3[i, j, k] = 1
+                T.reduce_sum(tmp3, tmp2, dim=1)
+                for i, k in T.Parallel(N, M):
+                    tmp2[i, k] /= tmp1[i, k]
+
+        return kernel
+
+    kernel = _buggy_kernel(M=4, N=32)
+    assert "tmp2[(((int)threadIdx.x) & 3)]" not in kernel.get_kernel_source()
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()

From a081ce4f5055226b2dd0cf8eed23af63ff40cea2 Mon Sep 17 00:00:00 2001
From: kurisu6912 <227995639+kurisu6912@users.noreply.github.com>
Date: Fri, 23 Jan 2026 10:34:32 +0800
Subject: [PATCH 05/14] [Layout] Enhance layout conflict handling in ReduceOpNode::InferLayout

---
 src/op/reduce.cc | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/src/op/reduce.cc b/src/op/reduce.cc
index 12367b408..1a52a945e 100644
--- a/src/op/reduce.cc
+++ b/src/op/reduce.cc
@@ -475,9 +475,15 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T,
       // We shouldn't widen the layout here,
       // because it may be written by another parallel for op
       // So just keep the original layout
-      // if (dst_rep > src_rep) {
-      //   return {{dst, dst_layout}};
-      // }
+      if (dst_rep > src_rep) {
+        std::ostringstream oss;
+        oss << "Layout may conflict with ReduceOp for buffer " << dst << " vs. 
" + << src << "\nLHS = " << src_layout->DebugOutput() + << "\nRHS = " << orig_dst_layout->DebugOutput() + << "\nYou may need to use a shared memory to transform the " + "layout"; + throw LayoutConflictException(oss.str()); + } } } return {}; From a081ce4f5055226b2dd0cf8eed23af63ff40cea2 Mon Sep 17 00:00:00 2001 From: kurisu6912 <227995639+kurisu6912@users.noreply.github.com> Date: Tue, 27 Jan 2026 12:07:14 +0800 Subject: [PATCH 06/14] [Layout] Fix layout issue in parallel and reduce --- src/op/reduce.cc | 241 +++++++++--------- .../python/issue/test_tilelang_issue_1719.py | 127 ++++++++- 2 files changed, 253 insertions(+), 115 deletions(-) diff --git a/src/op/reduce.cc b/src/op/reduce.cc index 1a52a945e..1d0dcd9f4 100644 --- a/src/op/reduce.cc +++ b/src/op/reduce.cc @@ -16,6 +16,8 @@ #include "../target/utils.h" #include "../transform/loop_partition.h" #include "tir/transforms/ir_utils.h" +#include "tvm/ir/expr.h" +#include "tvm/tir/expr.h" #include "tvm/tir/stmt.h" #include "utils.h" @@ -149,6 +151,40 @@ std::string ReduceOpNode::MakeCodegenReducer() const { } } +static Array InputPlaceholders(size_t n) { + Array result; + result.reserve(n); + for (size_t i = 0; i < n; ++i) { + result.push_back(InputPlaceholder(i)); + } + return result; +} + +static Fragment ComputeReducerLayout(const Fragment &src_layout, int dim) { + PrimExpr src_rep_extent = src_layout->ReplicateExtent(); + PrimExpr indice_rep_extent = src_layout->InputShape()[dim]; + PrimExpr reducer_rep_extent = indice_rep_extent * src_rep_extent; + + auto fwd = InputPlaceholders(src_layout->InputDim() - 1); + fwd.insert(fwd.begin() + dim, + FloorMod(ReplicationPlaceholder(), indice_rep_extent)); + + auto thd = src_layout->ForwardThread( + fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent)); + + auto reducer_shape = src_layout->InputShape(); + reducer_shape.erase(reducer_shape.begin() + dim); + if (reducer_shape.size() == 0) { + reducer_shape.push_back(1); + } + + auto reducer_layout = + Fragment(reducer_shape, {}, thd, reducer_rep_extent, std::nullopt) + ->CondenseReplicateVar() + ->BindThreadRange(src_layout->ThreadRange()); + return reducer_layout; +} + /** * @brief Lower the Reduce operator to a TIR statement. 
 *
@@ -204,6 +240,10 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   size_t src_dim = src_layout->InputDim();
   size_t dst_dim = dst_layout->InputDim();
 
+  auto red_layout = ComputeReducerLayout(src_layout, dim);
+  auto dst_rep = *as_const_int(dst_layout->ReplicateExtent());
+  auto red_rep = *as_const_int(red_layout->ReplicateExtent());
+
   bool is_1d_reduce = src_dim == dst_dim && dst_dim == 1;
 
   if (is_1d_reduce) {
@@ -232,6 +272,8 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
       src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
   Array<PrimExpr> dst_indices = dst_layout->Forward(
       dst_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
+  Array<PrimExpr> red_indices = red_layout->Forward(
+      dst_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
 
   Array<Stmt> stmts;
 
@@ -244,25 +286,33 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   Buffer clear_buffer = dst_buffer;
   bool need_duplicate = false;
+  bool need_update = false;
   if ((this->type->isSum() || this->type->isAbsSum()) && !this->clear) {
     need_duplicate = true;
+    need_update = true;
   } else if (this->type->isBitAnd() && !this->clear) {
     need_duplicate = true;
+    need_update = true;
   } else if ((this->type->isBitOr() || this->type->isBitXor()) &&
              !this->clear) {
     need_duplicate = true;
+    need_update = true;
+  }
+
+  if (red_rep > dst_rep) {
+    need_duplicate = true;
   }
 
   if (need_duplicate) {
     // Create a new buffer with same shape and dtype as dst_buffer
-    clear_buffer = decl_buffer(dst_buffer->shape, dst_buffer->dtype,
+    clear_buffer = decl_buffer(red_layout->OutputShape(), dst_buffer->dtype,
                                dst_buffer->name + "_clear",
                                GetPtrStorageScope(dst_buffer->data));
   }
 
   // make reduce-init stmt
   if (require_init) {
     stmts.push_back(
-        BufferStore(clear_buffer, this->MakeInitValue(), dst_indices));
+        BufferStore(clear_buffer, this->MakeInitValue(), red_indices));
   }
 
   // make thread-local reduce
@@ -279,9 +329,9 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
 
   Stmt reduce_local = BufferStore(
       clear_buffer,
-      this->MakeReduce(BufferLoad(clear_buffer, dst_indices),
+      this->MakeReduce(BufferLoad(clear_buffer, red_indices),
                        BufferLoad(src_buffer, src_indice_compressed)),
-      dst_indices);
+      red_indices);
 
   for (int i = static_cast<int>(src_layout->OutputDim()) - 1; i >= 0; --i) {
     reduce_local =
@@ -321,7 +371,7 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
          << ">::run";
     }
     Array<PrimExpr> thread_reduce_args = {
-        StringImm(ss.str()), BufferLoad(clear_buffer, dst_indices)};
+        StringImm(ss.str()), BufferLoad(clear_buffer, red_indices)};
     if (reducing_threads > 32) {
       PrimExpr workspace = T.AddWorkspace(
           *as_const_int(T.thread_bounds->extent), clear_buffer->dtype);
@@ -329,26 +379,62 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     }
     auto call = Call(clear_buffer->dtype, builtin::call_extern(),
                      thread_reduce_args);
-    stmts.push_back(BufferStore(clear_buffer, call, dst_indices));
+    stmts.push_back(BufferStore(clear_buffer, call, red_indices));
   }
 }
 
+  // Layout status in the loop:
+  // clear_buffer: red_layout
+  // dst_buffer: dst_layout
+  // loop_layout: red_layout
+  // At each step of the loop, we do reduction on
+  // `clear_buffer[red_layout(loop_idx)]`
+  // and then transfer it to `dst_buffer[dst_layout(loop_idx)]`.
+  // However, since the red_layout is larger than dst_layout, not all write
+  // operations are valid. We need to add a predicate to guard the write
+  // operations.
+  PrimExpr 
predicate = Bool(true); + { + // dst_indices is the same as loop_indices + auto dst_th_indices = dst_indices; + dst_th_indices.push_back(T.thread_var); + // 1. compute loop_idx based on thread: [dst_indices, T.thread_var] => + // [loop_indices] + auto inv = dst_layout->Inverse()->Forward(dst_th_indices); + inv.pop_back(); // remove replicate var + // 2. ensure computed loop_idx maps back to the same [loop_indices] + for (int i = 0; i < static_cast(dst_layout->InputDim()); i++) { + predicate = predicate && (inv[i] == dst_vars[i]->var); + } + // 3. simplify predicate + predicate = analyzer->Simplify(predicate); + } if (need_duplicate) { - PrimExpr src_val = BufferLoad(clear_buffer, dst_indices); - PrimExpr dst_val = BufferLoad(dst_buffer, dst_indices); PrimExpr update; - if (this->type->isSum() || this->type->isAbsSum()) { - update = dst_val + src_val; - } else if (this->type->isBitAnd()) { - update = this->clear ? src_val : bitwise_and(dst_val, src_val); - } else if (this->type->isBitOr()) { - update = bitwise_or(dst_val, src_val); - } else if (this->type->isBitXor()) { - update = bitwise_xor(dst_val, src_val); + if (need_update) { + PrimExpr src_val = BufferLoad(clear_buffer, red_indices); + PrimExpr dst_val = BufferLoad(dst_buffer, dst_indices); + if (this->type->isSum() || this->type->isAbsSum()) { + update = dst_val + src_val; + } else if (this->type->isBitAnd()) { + update = this->clear ? src_val : bitwise_and(dst_val, src_val); + } else if (this->type->isBitOr()) { + update = bitwise_or(dst_val, src_val); + } else if (this->type->isBitXor()) { + update = bitwise_xor(dst_val, src_val); + } else { + LOG(FATAL) << "Unsupported reduce type: " << this->type->type; + } } else { - LOG(FATAL) << "Unsupported reduce type: " << this->type->type; + update = BufferLoad(clear_buffer, red_indices); + } + if (analyzer->CanProve(predicate)) { + stmts.push_back(BufferStore( + dst_buffer, BufferLoad(clear_buffer, red_indices), dst_indices)); + } else { + stmts.push_back(IfThenElse( + predicate, BufferStore(dst_buffer, update, dst_indices))); } - stmts.push_back(BufferStore(dst_buffer, update, dst_indices)); } Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0]; @@ -359,7 +445,7 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { if (dst_layout->InputDim() > 0) { body = PartitionLoop(Downcast(body), T.thread_var, analyzer, - dst_layout); + red_layout); } else { PrimExpr guard = (T.thread_var == T.thread_bounds->min); body = IfThenElse(guard, body); @@ -385,105 +471,32 @@ LayoutMap ReduceOpNode::InferLayout(const LayoutInferArgs &T, if (IsFragmentBuffer(src) && IsFragmentBuffer(dst) && T.layout_map.count(src)) { auto src_layout = T.layout_map[src].as().value(); + auto reducer_layout = ComputeReducerLayout(src_layout, this->dim); - PrimExpr indice_rep_extent = src->shape[dim]; - PrimExpr src_rep_extent = src_layout->ReplicateExtent(); - PrimExpr dest_buffer_rep_extent = indice_rep_extent * src_rep_extent; - - Array fwd; - for (int i = 0; i < static_cast(src->shape.size()); i++) { - if (i == dim) { - fwd.push_back(FloorMod(ReplicationPlaceholder(), indice_rep_extent)); - } else if (i < dim) { - fwd.push_back(InputPlaceholder(i)); - } else if (i > dim) { - fwd.push_back(InputPlaceholder(i - 1)); - } - } - auto thd = src_layout->ForwardThread( - fwd, FloorDiv(ReplicationPlaceholder(), indice_rep_extent)); - - // Ensure the thread count is divisible by the replicate extent. - // Otherwise, we cannot infer a valid fragment<->fragment layout. 
-    {
-      arith::Analyzer analyzer;
-      PrimExpr num_threads = T.thread_bounds->extent;
-      // Though the dest_buffer_rep_extent will be compressed at
-      // CondenseReplicateVar, we need to check the divisibility here to avoid
-      // the issue that the thread count is not divisible by the replicate
-      // extent.
-      if (!analyzer.CanProve(FloorMod(num_threads, dest_buffer_rep_extent) ==
-                             0) &&
-          !analyzer.CanProve(FloorMod(dest_buffer_rep_extent, num_threads) ==
-                             0)) {
-        ICHECK(false) << "ReduceOp fragment layout inference failed: "
-                         "num_threads % replicate_extent != 0. "
-                      << "This mapping requires the block's thread count to be "
-                         "divisible by the "
-                      << "replicate extent. "
-                      << "Try one of: (1) choose a thread block size divisible "
-                         "by replicate_extent; "
-                      << "(2) pick a different reduce dimension or adjust the "
-                         "source fragment layout; "
-                      << "Details: num_threads=" << num_threads
-                      << ", replicate_extent=" << indice_rep_extent
-                      << ", src=" << src << ", dst=" << dst;
-      }
+    if (!T.layout_map.count(dst)) {
+      return {{dst, reducer_layout}};
     }
 
-    Fragment dst_layout =
-        Fragment(dst->shape, {}, thd, dest_buffer_rep_extent, std::nullopt)
-            ->CondenseReplicateVar()
-            ->BindThreadRange(T.thread_bounds);
-
-    if (!T.layout_map.count(dst))
-      return {{dst, dst_layout}};
-    else {
-      // Check if computed layout is compatible with existing: the existing one
-      // must strictly contains the computed layout
-      auto orig_dst_layout =
-          T.layout_map.Get(dst).value().as<Fragment>().value();
-      ICHECK(dst_layout->InputDim() == orig_dst_layout->InputDim());
-      Array<PrimExpr> indices;
-      indices.reserve(dst_layout->InputDim());
-      arith::Analyzer inner_analyzer;
-      for (int i = 0; i < dst_layout->InputDim(); ++i) {
-        auto x = InputPlaceholder(i);
-        indices.push_back(x);
-        // should be literal - literal = 0, any analyzer will work
-        ICHECK(is_zero(inner_analyzer.Simplify(
-            dst_layout->InputShape()[i] - orig_dst_layout->InputShape()[i])));
-        inner_analyzer.Bind(x, Range(0, dst_layout->InputShape()[i]));
-      }
-
-      ICHECK(as_const_int(dst_layout->ReplicateExtent()));
-      ICHECK(as_const_int(src_layout->ReplicateExtent()));
-      auto dst_rep = *as_const_int(dst_layout->ReplicateExtent());
-      auto src_rep = *as_const_int(src_layout->ReplicateExtent());
-      if (dst_rep < src_rep ||
-          !ProveFragmentContains(orig_dst_layout, dst_layout, indices, indices,
-                                 inner_analyzer)) {
-        std::ostringstream oss;
-        oss << "Layout may conflict with ReduceOp for buffer " << dst << " vs. "
-            << src << "\nLHS = " << src_layout->DebugOutput()
-            << "\nRHS = " << orig_dst_layout->DebugOutput()
-            << "\nYou may need to use a shared memory to transform the "
-               "layout";
-        throw LayoutConflictException(oss.str());
-      }
+    auto orig_dst_layout = T.layout_map.Get(dst).value().as<Fragment>().value();
+    ICHECK(reducer_layout->InputDim() == orig_dst_layout->InputDim());
 
-      // We shouldn't widen the layout here,
-      // because it may be written by another parallel for op
-      // So just keep the original layout
-      if (dst_rep > src_rep) {
-        std::ostringstream oss;
-        oss << "Layout may conflict with ReduceOp for buffer " << dst << " vs. 
" - << src << "\nLHS = " << src_layout->DebugOutput() - << "\nRHS = " << orig_dst_layout->DebugOutput() - << "\nYou may need to use a shared memory to transform the " - "layout"; - throw LayoutConflictException(oss.str()); - } + auto indices = InputPlaceholders(reducer_layout->InputDim()); + arith::Analyzer analyzer; + for (size_t i = 0; i < indices.size(); i++) { + analyzer.Bind(Downcast(indices[i]), + Range(0, reducer_layout->InputShape()[i])); + } + if (!ProveFragmentContains(orig_dst_layout, reducer_layout, indices, + indices, analyzer)) { + std::ostringstream oss; + oss << "Layout may conflict with ReduceOp for buffer " << dst << " vs. " + << src << "\n" + << "src_layout = " << src_layout << "\n" + << "reducer_layout = " << reducer_layout << "\n" + << "orig_dst_layout = " << orig_dst_layout << "\n" + << "You may need to use a shared memory to transform the " + "layout"; + throw LayoutConflictException(oss.str()); } } return {}; diff --git a/testing/python/issue/test_tilelang_issue_1719.py b/testing/python/issue/test_tilelang_issue_1719.py index fcd93e066..e4db4e4e0 100644 --- a/testing/python/issue/test_tilelang_issue_1719.py +++ b/testing/python/issue/test_tilelang_issue_1719.py @@ -1,9 +1,31 @@ import tilelang +import torch import tilelang.testing import tilelang.language as T -def test_tilelang_issue_1719(): +def test_issue_1719_layout_1(): + @tilelang.jit() + def _buggy_kernel(): + @T.prim_func + def main(): + with T.Kernel(threads=32): + tmp1 = T.alloc_shared([16, 16], T.float16) + tmp2 = T.alloc_shared([16, 16], T.float16) + tmp3 = T.alloc_fragment([16, 16], T.float32) + tmp4 = T.alloc_fragment([16], T.float32) + T.gemm(tmp1, tmp2, tmp3, transpose_B=True) + T.reduce_max(tmp3, tmp4) + for i in T.Parallel(16): + tmp4[i] = 1 + + return main + + kernel = _buggy_kernel() + print(kernel.get_kernel_source()) + + +def test_issue_1719_layout_2(): @tilelang.jit() def _buggy_kernel(M: int, N: int) -> tilelang.JITKernel: @T.prim_func @@ -21,8 +43,111 @@ def kernel() -> None: return kernel kernel = _buggy_kernel(M=4, N=32) + print(kernel.get_kernel_source()) assert "tmp2[(((int)threadIdx.x) & 3)]" not in kernel.get_kernel_source() +def test_issue_1719_layout_3(): + @tilelang.jit(out_idx=-1) + def buggy_kernel(M, N, dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((M, N), dtype), + B: T.Tensor((M,), dtype), + ): + with T.Kernel(1, threads=32) as _: + A_local = T.alloc_fragment((M, N), dtype) + B_local = T.alloc_fragment((M,), dtype) + + T.copy(A, A_local) + T.reduce_sum(A_local, B_local, dim=1) + T.copy(B_local, B) + + return main + + M = 2 + N = 128 + kernel = buggy_kernel(M, N) + a = torch.randn(M, N, device="cuda") + b = kernel(a) + print(b, a.sum(dim=1)) + torch.testing.assert_close(b, a.sum(dim=1), atol=1e-2, rtol=1e-2) + + +def test_issue_1719_layout_4(): + @tilelang.jit() + def buggy_kernel(): + @T.prim_func + def main(): + with T.Kernel(threads=128): + Q_tail_shared = T.alloc_shared([32, 32], T.bfloat16) + K_tail_shared = T.alloc_shared([32, 32], T.bfloat16) + acc_s = T.alloc_fragment([32, 32], T.float32) + m_i = T.alloc_fragment([32], T.float32) + T.gemm(Q_tail_shared, K_tail_shared, acc_s, transpose_B=True) + T.reduce_max(acc_s, m_i) + + return main + + buggy_kernel() + + +def test_issue_1719_layout_5(): + @tilelang.jit(out_idx=-1) + def buggy_kernel(N, dtype=T.float32): + @T.prim_func + def main( + A: T.Tensor((1, N), dtype), + ): + with T.Kernel(1, threads=32) as _: + A_local = T.alloc_fragment((1, N), dtype) + B_local = T.alloc_fragment((1,), dtype) + + T.copy(A, 
A_local) + T.reduce_sum(A_local, B_local, dim=1) + + return main + + buggy_kernel(128) + + +def test_issue_1719_layout_6(): + @tilelang.jit() + def buggy_kernel(): + @T.prim_func + def kernel(): + with T.Kernel(): + tmp1 = T.alloc_fragment((1,), dtype=T.float32) + tmp2 = T.alloc_fragment((1,), dtype=T.float32) + tmp1[0] = 1 + T.reduce_sum(tmp1, tmp2) + tmp2[0] + + return kernel + + buggy_kernel() + + +def test_issue_1719_layout_7(): + @tilelang.jit() + def buggy_kernel(): + @T.prim_func + def main(): + with T.Kernel(threads=32): + tmp1 = T.alloc_fragment([1, 32], T.float16) + tmp2 = T.alloc_fragment([32], T.float32) + tmp3 = T.alloc_fragment([32], T.float32) + tmp4 = T.alloc_fragment([32], T.float32) + T.reduce_max(tmp1, tmp4, dim=0) + k = 0 + T.copy(tmp1[k, :], tmp2) + for i in T.Parallel(32): + tmp3[i] += tmp2[i] - tmp4[i] + + return main + + buggy_kernel() + + if __name__ == "__main__": tilelang.testing.main() From a97661586241de7a743631233a4304b7affce776 Mon Sep 17 00:00:00 2001 From: kurisu6912 <227995639+kurisu6912@users.noreply.github.com> Date: Tue, 27 Jan 2026 12:26:58 +0800 Subject: [PATCH 07/14] Fix lint error --- src/op/reduce.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/op/reduce.cc b/src/op/reduce.cc index 1d0dcd9f4..0d79f6bf0 100644 --- a/src/op/reduce.cc +++ b/src/op/reduce.cc @@ -174,7 +174,7 @@ static Fragment ComputeReducerLayout(const Fragment &src_layout, int dim) { auto reducer_shape = src_layout->InputShape(); reducer_shape.erase(reducer_shape.begin() + dim); - if (reducer_shape.size() == 0) { + if (reducer_shape.empty()) { reducer_shape.push_back(1); } From 29231e8ed906527270c080d270b030397cc3b63a Mon Sep 17 00:00:00 2001 From: kurisu6912 <227995639+kurisu6912@users.noreply.github.com> Date: Tue, 27 Jan 2026 12:35:11 +0800 Subject: [PATCH 08/14] Fix inconsistent predicate bug --- src/op/reduce.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/op/reduce.cc b/src/op/reduce.cc index 0d79f6bf0..c2e056034 100644 --- a/src/op/reduce.cc +++ b/src/op/reduce.cc @@ -428,12 +428,11 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { } else { update = BufferLoad(clear_buffer, red_indices); } + auto store = BufferStore(dst_buffer, update, dst_indices); if (analyzer->CanProve(predicate)) { - stmts.push_back(BufferStore( - dst_buffer, BufferLoad(clear_buffer, red_indices), dst_indices)); + stmts.push_back(store); } else { - stmts.push_back(IfThenElse( - predicate, BufferStore(dst_buffer, update, dst_indices))); + stmts.push_back(IfThenElse(predicate, store)); } } From e6c66d5ebfba7f81d251d2fd660833442e5cefbb Mon Sep 17 00:00:00 2001 From: kurisu6912 <227995639+kurisu6912@users.noreply.github.com> Date: Tue, 27 Jan 2026 12:54:07 +0800 Subject: [PATCH 09/14] Fix bug in test --- testing/python/issue/test_tilelang_issue_1719.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/testing/python/issue/test_tilelang_issue_1719.py b/testing/python/issue/test_tilelang_issue_1719.py index e4db4e4e0..e9a7df20b 100644 --- a/testing/python/issue/test_tilelang_issue_1719.py +++ b/testing/python/issue/test_tilelang_issue_1719.py @@ -10,13 +10,13 @@ def _buggy_kernel(): @T.prim_func def main(): with T.Kernel(threads=32): - tmp1 = T.alloc_shared([16, 16], T.float16) - tmp2 = T.alloc_shared([16, 16], T.float16) - tmp3 = T.alloc_fragment([16, 16], T.float32) - tmp4 = T.alloc_fragment([16], T.float32) + tmp1 = T.alloc_shared([32, 32], T.float16) + tmp2 = T.alloc_shared([32, 32], 
T.float16)
+                tmp3 = T.alloc_fragment([32, 32], T.float32)
+                tmp4 = T.alloc_fragment([32], T.float32)
                 T.gemm(tmp1, tmp2, tmp3, transpose_B=True)
                 T.reduce_max(tmp3, tmp4)
-                for i in T.Parallel(16):
+                for i in T.Parallel(32):
                     tmp4[i] = 1
 

From c1d41dbd53eebfcd72439eaf32011baa2a2a56c6 Mon Sep 17 00:00:00 2001
From: kurisu6912 <227995639+kurisu6912@users.noreply.github.com>
Date: Tue, 27 Jan 2026 13:12:58 +0800
Subject: [PATCH 10/14] Add cuda restriction on test

---
 testing/python/issue/test_tilelang_issue_1719.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/testing/python/issue/test_tilelang_issue_1719.py b/testing/python/issue/test_tilelang_issue_1719.py
index e9a7df20b..a884a266e 100644
--- a/testing/python/issue/test_tilelang_issue_1719.py
+++ b/testing/python/issue/test_tilelang_issue_1719.py
@@ -4,6 +4,7 @@
 import tilelang.language as T
 
 
+@tilelang.testing.requires_cuda
 def test_issue_1719_layout_1():

From 50462962c6ec828a688831d368f1ba33c6fce809 Mon Sep 17 00:00:00 2001
From: kurisu6912 <227995639+kurisu6912@users.noreply.github.com>
Date: Tue, 27 Jan 2026 13:13:46 +0800
Subject: [PATCH 11/14] Add cuda restriction on test

---
 testing/python/issue/test_tilelang_issue_1719.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/testing/python/issue/test_tilelang_issue_1719.py b/testing/python/issue/test_tilelang_issue_1719.py
index a884a266e..62c361e0e 100644
--- a/testing/python/issue/test_tilelang_issue_1719.py
+++ b/testing/python/issue/test_tilelang_issue_1719.py
@@ -48,6 +48,7 @@ def kernel() -> None:
     assert "tmp2[(((int)threadIdx.x) & 3)]" not in kernel.get_kernel_source()
 
 
+@tilelang.testing.requires_cuda
 def test_issue_1719_layout_3():

From 780b84fbeecb0c6c3d9bd3a237ae4ef126c676c7 Mon Sep 17 00:00:00 2001
From: kurisu6912 <227995639+kurisu6912@users.noreply.github.com>
Date: Tue, 27 Jan 2026 14:54:50 +0800
Subject: [PATCH 12/14] Fix comments

---
 src/op/reduce.cc                              |  53 +++--
 .../python/issue/test_tilelang_issue_1719.py | 191 ++++++++----------
 2 files changed, 107 insertions(+), 137 deletions(-)

diff --git a/src/op/reduce.cc b/src/op/reduce.cc
index c2e056034..16ef95168 100644
--- a/src/op/reduce.cc
+++ b/src/op/reduce.cc
@@ -233,18 +233,15 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   if (src_scope == "local.fragment" && dst_scope == "local.fragment") {
-    Buffer src_buffer = get_buffer(this->src);
-    Buffer dst_buffer = get_buffer(this->dst);
-    Fragment src_layout = T.layout_map[this->src].as<Fragment>().value();
-    Fragment dst_layout = T.layout_map[this->dst].as<Fragment>().value();
-    size_t src_dim = src_layout->InputDim();
-    size_t dst_dim = dst_layout->InputDim();
-
+    auto src_buffer = get_buffer(this->src);
+    auto dst_buffer = get_buffer(this->dst);
+    auto src_layout = T.layout_map[this->src].as<Fragment>().value();
+    auto dst_layout = T.layout_map[this->dst].as<Fragment>().value();
     auto red_layout = ComputeReducerLayout(src_layout, dim);
-    auto dst_rep = *as_const_int(dst_layout->ReplicateExtent());
-    auto red_rep = *as_const_int(red_layout->ReplicateExtent());
+    auto src_dim = src_layout->InputDim();
+    auto dst_dim = dst_layout->InputDim();
 
-    bool is_1d_reduce = src_dim == dst_dim && dst_dim == 1;
+    auto is_1d_reduce = src_dim == dst_dim && dst_dim == 1;
 
     if (is_1d_reduce) {
       ICHECK(is_one(dst_layout->OutputShape().back()))
@@ -268,25 +265,25 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   IterVar reduce_iv(reduce_dom, Var("rv"), 
IterVarType::kDataPar);
   src_vars.insert(src_vars.begin() + this->dim, reduce_iv);
 
-  Array<PrimExpr> src_indices = src_layout->Forward(
+  auto src_indices = src_layout->Forward(
       src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
-  Array<PrimExpr> dst_indices = dst_layout->Forward(
+  auto dst_indices = dst_layout->Forward(
      dst_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
-  Array<PrimExpr> red_indices = red_layout->Forward(
+  auto red_indices = red_layout->Forward(
      dst_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }));
 
   Array<Stmt> stmts;
 
-  bool require_init = this->clear;
+  auto require_init = this->clear;
   if (this->type->isSum() || this->type->isAbsSum() ||
       this->type->isBitAnd() || this->type->isBitOr() ||
       this->type->isBitXor()) {
     require_init = true;
   }
 
-  Buffer clear_buffer = dst_buffer;
-  bool need_duplicate = false;
-  bool need_update = false;
+  auto clear_buffer = dst_buffer;
+  auto need_duplicate = false;
+  auto need_update = false;
   if ((this->type->isSum() || this->type->isAbsSum()) && !this->clear) {
     need_duplicate = true;
     need_update = true;
@@ -299,7 +296,11 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     need_update = true;
   }
 
-  if (red_rep > dst_rep) {
+  // red_layout should always contain dst_layout. If we can prove they are
+  // the same, there is no need to duplicate the buffer; otherwise red_layout
+  // contains more replicated dimensions than dst_layout.
+  if (!analyzer->CanProve(dst_layout->ReplicateExtent() ==
+                          red_layout->ReplicateExtent())) {
     need_duplicate = true;
   }
 
@@ -319,10 +320,8 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   Array<PrimExpr> src_indice_compressed;
   Array<IterVar> src_var_compressed;
   for (size_t i = 0; i < src_layout->OutputDim(); ++i) {
-    PrimExpr expr;
-    IterVar var;
-    std::tie(expr, var) = CompressIterator(
-        src_indices[i], src_vars, src_vars[this->dim]->var, analyzer);
+    auto [expr, var] = CompressIterator(src_indices[i], src_vars,
+                                        src_vars[this->dim]->var, analyzer);
     src_indice_compressed.push_back(expr);
     src_var_compressed.push_back(var);
   }
@@ -341,7 +340,7 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   }
   stmts.push_back(reduce_local);
 
-  PrimExpr src_thread = src_layout->ForwardThread(
+  auto src_thread = src_layout->ForwardThread(
       src_vars.Map([](const auto &iv) { return PrimExpr(iv->var); }), {});
   auto iter_sum =
       arith::NormalizeToIterSum(src_thread, ToVMap(src_vars), analyzer);
@@ -412,8 +411,8 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
   if (need_duplicate) {
     PrimExpr update;
     if (need_update) {
-      PrimExpr src_val = BufferLoad(clear_buffer, red_indices);
-      PrimExpr dst_val = BufferLoad(dst_buffer, dst_indices);
+      auto src_val = BufferLoad(clear_buffer, red_indices);
+      auto dst_val = BufferLoad(dst_buffer, dst_indices);
       if (this->type->isSum() || this->type->isAbsSum()) {
         update = dst_val + src_val;
       } else if (this->type->isBitAnd()) {
@@ -436,7 +435,7 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
     }
   }
 
-  Stmt body = stmts.size() > 1 ? SeqStmt(stmts) : stmts[0];
+  auto body = stmts.size() > 1 ? 
SeqStmt(stmts) : stmts[0]; for (int i = static_cast(dst_layout->InputDim()) - 1; i >= 0; --i) { body = For(dst_vars[i]->var, 0, dst_vars[i]->dom->extent, ForKind::kParallel, body); @@ -446,7 +445,7 @@ Stmt ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { body = PartitionLoop(Downcast(body), T.thread_var, analyzer, red_layout); } else { - PrimExpr guard = (T.thread_var == T.thread_bounds->min); + auto guard = (T.thread_var == T.thread_bounds->min); body = IfThenElse(guard, body); } diff --git a/testing/python/issue/test_tilelang_issue_1719.py b/testing/python/issue/test_tilelang_issue_1719.py index 62c361e0e..ee90948db 100644 --- a/testing/python/issue/test_tilelang_issue_1719.py +++ b/testing/python/issue/test_tilelang_issue_1719.py @@ -6,70 +6,57 @@ @tilelang.testing.requires_cuda def test_issue_1719_layout_1(): - @tilelang.jit() + @tilelang.jit def _buggy_kernel(): - @T.prim_func - def main(): - with T.Kernel(threads=32): - tmp1 = T.alloc_shared([32, 32], T.float16) - tmp2 = T.alloc_shared([32, 32], T.float16) - tmp3 = T.alloc_fragment([32, 32], T.float32) - tmp4 = T.alloc_fragment([32], T.float32) - T.gemm(tmp1, tmp2, tmp3, transpose_B=True) - T.reduce_max(tmp3, tmp4) - for i in T.Parallel(32): - tmp4[i] = 1 - - return main - - kernel = _buggy_kernel() + with T.Kernel(threads=32): + tmp1 = T.alloc_shared([32, 32], T.float16) + tmp2 = T.alloc_shared([32, 32], T.float16) + tmp3 = T.alloc_fragment([32, 32], T.float32) + tmp4 = T.alloc_fragment([32], T.float32) + T.gemm(tmp1, tmp2, tmp3, transpose_B=True) + T.reduce_max(tmp3, tmp4) + for i in T.Parallel(32): + tmp4[i] = 1 + + kernel = _buggy_kernel.compile() print(kernel.get_kernel_source()) def test_issue_1719_layout_2(): - @tilelang.jit() - def _buggy_kernel(M: int, N: int) -> tilelang.JITKernel: - @T.prim_func - def kernel() -> None: - with T.Kernel(): - tmp1 = T.alloc_fragment((N, M), T.float32) - tmp2 = T.alloc_fragment((N, M), T.float32) - tmp3 = T.alloc_fragment((N, M, M), T.float32) - for i, j, k in T.Parallel(N, M, M): - tmp3[i, j, k] = 1 - T.reduce_sum(tmp3, tmp2, dim=1) - for i, k in T.Parallel(N, M): - tmp2[i, k] /= tmp1[i, k] - - return kernel - - kernel = _buggy_kernel(M=4, N=32) + @tilelang.jit + def _buggy_kernel(M: int, N: int): + with T.Kernel(): + tmp1 = T.alloc_fragment((N, M), T.float32) + tmp2 = T.alloc_fragment((N, M), T.float32) + tmp3 = T.alloc_fragment((N, M, M), T.float32) + for i, j, k in T.Parallel(N, M, M): + tmp3[i, j, k] = 1 + T.reduce_sum(tmp3, tmp2, dim=1) + for i, k in T.Parallel(N, M): + tmp2[i, k] /= tmp1[i, k] + + kernel = _buggy_kernel.compile(M=4, N=32) print(kernel.get_kernel_source()) assert "tmp2[(((int)threadIdx.x) & 3)]" not in kernel.get_kernel_source() @tilelang.testing.requires_cuda def test_issue_1719_layout_3(): - @tilelang.jit(out_idx=-1) - def buggy_kernel(M, N, dtype=T.float32): - @T.prim_func - def main( - A: T.Tensor((M, N), dtype), - B: T.Tensor((M,), dtype), - ): - with T.Kernel(1, threads=32) as _: - A_local = T.alloc_fragment((M, N), dtype) - B_local = T.alloc_fragment((M,), dtype) - - T.copy(A, A_local) - T.reduce_sum(A_local, B_local, dim=1) - T.copy(B_local, B) - - return main - - M = 2 - N = 128 - kernel = buggy_kernel(M, N) + @tilelang.jit + def _buggy_kernel(A, B, dtype=T.float32): + M, N = T.const("M, N") + A: T.Tensor[(M, N), dtype] + B: T.Tensor[(M,), dtype] + with T.Kernel(1, threads=32) as _: + A_local = T.alloc_fragment((M, N), dtype) + B_local = T.alloc_fragment((M,), dtype) + + T.copy(A, A_local) + T.reduce_sum(A_local, B_local, dim=1) + 
T.copy(B_local, B) + + M, N = 2, 128 + kernel = _buggy_kernel.compile(M=M, N=N) a = torch.randn(M, N, device="cuda") b = kernel(a) print(b, a.sum(dim=1)) @@ -77,78 +64,62 @@ def main( def test_issue_1719_layout_4(): - @tilelang.jit() - def buggy_kernel(): - @T.prim_func - def main(): - with T.Kernel(threads=128): - Q_tail_shared = T.alloc_shared([32, 32], T.bfloat16) - K_tail_shared = T.alloc_shared([32, 32], T.bfloat16) - acc_s = T.alloc_fragment([32, 32], T.float32) - m_i = T.alloc_fragment([32], T.float32) - T.gemm(Q_tail_shared, K_tail_shared, acc_s, transpose_B=True) - T.reduce_max(acc_s, m_i) - - return main + @tilelang.jit + def _buggy_kernel(): + with T.Kernel(threads=128): + Q_tail_shared = T.alloc_shared([32, 32], T.bfloat16) + K_tail_shared = T.alloc_shared([32, 32], T.bfloat16) + acc_s = T.alloc_fragment([32, 32], T.float32) + m_i = T.alloc_fragment([32], T.float32) + T.gemm(Q_tail_shared, K_tail_shared, acc_s, transpose_B=True) + T.reduce_max(acc_s, m_i) - buggy_kernel() + _buggy_kernel.compile() def test_issue_1719_layout_5(): - @tilelang.jit(out_idx=-1) - def buggy_kernel(N, dtype=T.float32): - @T.prim_func - def main( - A: T.Tensor((1, N), dtype), - ): - with T.Kernel(1, threads=32) as _: - A_local = T.alloc_fragment((1, N), dtype) - B_local = T.alloc_fragment((1,), dtype) + @tilelang.jit + def buggy_kernel(A, dtype=T.float32): + N = T.const("N") + A: T.Tensor[(1, N), dtype] + with T.Kernel(1, threads=32) as _: + A_local = T.alloc_fragment((1, N), dtype) + B_local = T.alloc_fragment((1,), dtype) - T.copy(A, A_local) - T.reduce_sum(A_local, B_local, dim=1) + T.copy(A, A_local) + T.reduce_sum(A_local, B_local, dim=1) - return main - - buggy_kernel(128) + buggy_kernel.compile(N=128) def test_issue_1719_layout_6(): - @tilelang.jit() + @tilelang.jit def buggy_kernel(): - @T.prim_func - def kernel(): - with T.Kernel(): - tmp1 = T.alloc_fragment((1,), dtype=T.float32) - tmp2 = T.alloc_fragment((1,), dtype=T.float32) - tmp1[0] = 1 - T.reduce_sum(tmp1, tmp2) - tmp2[0] - - return kernel + with T.Kernel(): + tmp1 = T.alloc_fragment((1,), dtype=T.float32) + tmp2 = T.alloc_fragment((1,), dtype=T.float32) + tmp1[0] = 1 + T.reduce_sum(tmp1, tmp2) + tmp2[0] - buggy_kernel() + buggy_kernel.compile() def test_issue_1719_layout_7(): - @tilelang.jit() + @tilelang.jit def buggy_kernel(): - @T.prim_func - def main(): - with T.Kernel(threads=32): - tmp1 = T.alloc_fragment([1, 32], T.float16) - tmp2 = T.alloc_fragment([32], T.float32) - tmp3 = T.alloc_fragment([32], T.float32) - tmp4 = T.alloc_fragment([32], T.float32) - T.reduce_max(tmp1, tmp4, dim=0) - k = 0 - T.copy(tmp1[k, :], tmp2) - for i in T.Parallel(32): - tmp3[i] += tmp2[i] - tmp4[i] - - return main - - buggy_kernel() + with T.Kernel(threads=32): + tmp1 = T.alloc_fragment([1, 32], T.float16) + tmp2 = T.alloc_fragment([32], T.float32) + tmp3 = T.alloc_fragment([32], T.float32) + tmp4 = T.alloc_fragment([32], T.float32) + T.reduce_max(tmp1, tmp4, dim=0) + k = 0 + T.copy(tmp1[k, :], tmp2) + for i in T.Parallel(32): + tmp3[i] += tmp2[i] - tmp4[i] + + buggy_kernel.compile() if __name__ == "__main__": From dde35e2def6f4eb0cc51a6eed01e63846e73b3be Mon Sep 17 00:00:00 2001 From: kurisu6912 <227995639+kurisu6912@users.noreply.github.com> Date: Tue, 27 Jan 2026 14:59:18 +0800 Subject: [PATCH 13/14] Add assertion --- src/op/reduce.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/op/reduce.cc b/src/op/reduce.cc index 16ef95168..765ec1786 100644 --- a/src/op/reduce.cc +++ b/src/op/reduce.cc @@ -303,6 +303,10 @@ Stmt 
ReduceOpNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
                              red_layout->ReplicateExtent())) {
     need_duplicate = true;
   }
+  ICHECK(!analyzer->CanProve(dst_layout->ReplicateExtent() >
+                             red_layout->ReplicateExtent()))
+      << "Inconsistent layouts between src and dst in ReduceOp: "
+      << "dst_layout=" << dst_layout << ", red_layout=" << red_layout;
 
   if (need_duplicate) {
     // Create a new buffer with same shape and dtype as dst_buffer

From cf633e17572bebff54967dcee83cd70c7e853900 Mon Sep 17 00:00:00 2001
From: kurisu6912 <227995639+kurisu6912@users.noreply.github.com>
Date: Tue, 27 Jan 2026 17:24:34 +0800
Subject: [PATCH 14/14] fix bug in tests

---
 testing/python/issue/test_tilelang_issue_1719.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/testing/python/issue/test_tilelang_issue_1719.py b/testing/python/issue/test_tilelang_issue_1719.py
index ee90948db..3d4dbc98f 100644
--- a/testing/python/issue/test_tilelang_issue_1719.py
+++ b/testing/python/issue/test_tilelang_issue_1719.py
@@ -43,10 +43,10 @@ def test_issue_1719_layout_3():
     @tilelang.jit
-    def _buggy_kernel(A, B, dtype=T.float32):
+    def _buggy_kernel(A, dtype=T.float32):
         M, N = T.const("M, N")
         A: T.Tensor[(M, N), dtype]
-        B: T.Tensor[(M,), dtype]
+        B = T.empty((M,), dtype)
         with T.Kernel(1, threads=32) as _:
             A_local = T.alloc_fragment((M, N), dtype)
             B_local = T.alloc_fragment((M,), dtype)

             T.copy(A, A_local)
             T.reduce_sum(A_local, B_local, dim=1)
             T.copy(B_local, B)
+            return B
 
     M, N = 2, 128
     kernel = _buggy_kernel.compile(M=M, N=N)