From 27310e7365a1cceac67c14b0a6f4c86789bbc59b Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Tue, 3 Feb 2026 16:30:34 +0800
Subject: [PATCH 1/5] [Feature] Implement ProveFragmentContains Function for
 Fragment Thread Validation

- Added the ProveFragmentContains function to check whether the threads
  accessing elements of a smaller fragment are a subset of those accessing a
  larger fragment.
- This function ensures valid access when transitioning from a smaller to a
  larger fragment layout.
- Updated layout.cc and utils.cc to incorporate this new functionality,
  enhancing the layout validation process.
- Removed the previous implementation of ProveFragmentContains from
  parallel.cc to streamline the codebase.

---
 src/layout/layout.cc |   6 +-
 src/layout/utils.cc  |  82 +++++++++++++
 src/layout/utils.h   |  23 ++++
 src/op/parallel.cc   | 266 ++++++++++++++++---------------------------
 src/op/parallel.h    |  25 ++--
 5 files changed, 220 insertions(+), 182 deletions(-)

diff --git a/src/layout/layout.cc b/src/layout/layout.cc
index 8b0d37cb8..26e4c7b04 100644
--- a/src/layout/layout.cc
+++ b/src/layout/layout.cc
@@ -236,7 +236,8 @@ Fragment FragmentNode::DeReplicate() const {
   PrimExpr new_forward_thread = Substitute(forward_thread_, vmap);
   Array<PrimExpr> new_forward_index = {FloorDiv(forward_index_[0], factor)};
   return Fragment(input_size_, new_forward_index, new_forward_thread,
-                  int(*rep_size) / factor, std::nullopt);
+                  int(*rep_size) / factor, std::nullopt)
+      ->BindThreadRange(Range(0, ThreadExtent()));
 }
 
 Fragment FragmentNode::BindThreadRange(Range thread_range) const {
@@ -554,7 +555,8 @@ Fragment::Fragment(Array<PrimExpr> input_size, Array<PrimExpr> forward_index,
 
 Fragment Fragment::FullyReplicated(Array<PrimExpr> shape, PrimExpr thread_extent) {
   return Fragment(shape, {}, ReplicationPlaceholder(), thread_extent,
-                  std::nullopt);
+                  std::nullopt)
+      ->BindThreadRange(Range(0, thread_extent));
 }
 
 // which means the forward_thread is rep_var -> lambda i, rep: rep
diff --git a/src/layout/utils.cc b/src/layout/utils.cc
index 860e746a7..733b73e69 100644
--- a/src/layout/utils.cc
+++ b/src/layout/utils.cc
@@ -377,5 +377,87 @@ Map<Var, Range> ToVMap(const Array<IterVar> &ivs) {
   return result;
 }
 
+// ProveFragmentContains checks whether the threads that access elements of a
+// smaller fragment (small_frag) are a subset of the threads that access
+// elements of a larger fragment (large_frag) for any given loop index. This
+// function ensures that if the small fragment's layout corresponds to the loop
+// itself, accessing the large fragment's elements is valid. Additionally, if
+// the small fragment is later updated to the large one, the originally valid
+// access remains valid. The proof is performed by:
+//
+// 1. Defining a variable `rep_small` to represent the replicate index of the
+//    small fragment that is being checked.
+// 2. Using the `small_frag_indices` and `rep_small` to derive the thread
+//    accessing the element in the small fragment.
+// 3. Using `large_frag_indices` to derive the physical index of the large
+//    fragment along with the thread information, and then feeding these into
+//    the inverse of the large fragment to obtain the logical index and
+//    replicate index.
+// 4. Verifying the mapping by checking whether the computed thread using the
+//    inverse layout corresponds to the original thread calculated for the
+//    small fragment. If they don't match, this indicates that the inverse
+//    layout's domain does not include the thread and thus the access is
+//    invalid.
+// Thanks @huanqicao for contributing this algorithm.
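+//
+// Illustrative sketch (not part of the build): for a loop fragment
+// `loop_frag` and a buffer fragment `buf_frag`, both indexed by the loop
+// variables {i, j}, a caller would check containment as
+//
+//   arith::Analyzer analyzer;
+//   bool ok = ProveFragmentContains(loop_frag, buf_frag, {i, j}, {i, j},
+//                                   analyzer);
+//
+// `loop_frag` and `buf_frag` are hypothetical names used only for this
+// example.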
+bool ProveFragmentContains(Fragment small_frag, Fragment large_frag,
+                           Array<PrimExpr> small_frag_indices,
+                           Array<PrimExpr> large_frag_indices,
+                           arith::Analyzer &analyzer, bool check_forward_index) {
+  bool large_physical_is_fully_replicated = large_frag->IsCompletedReplicated();
+  if (large_physical_is_fully_replicated) {
+    return true; // fully replicated fragments are always compatible
+  }
+
+  // When check_forward_index is true, verify that the physical indices
+  // (forward index) of both fragments are equal. This is required when
+  // validating a loop layout against a buffer fragment, as code generation
+  // needs to correctly derive buffer physical indices from the loop layout.
+  if (check_forward_index) {
+    auto small_physical = small_frag->Forward(small_frag_indices);
+    auto large_physical = large_frag->Forward(large_frag_indices);
+    // Dimension mismatch means they are not equal.
+    if (small_physical.size() != large_physical.size()) {
+      return false;
+    }
+    // Check each physical index component for equality.
+    for (size_t i = 0; i < small_physical.size(); i++) {
+      auto diff = analyzer.Simplify(small_physical[i] - large_physical[i]);
+      if (!is_zero(diff)) {
+        return false;
+      }
+    }
+  }
+
+  Var rep_small("__checking_frag_contains_rep");
+  analyzer.Bind(rep_small,
+                Range(IntImm(small_frag->ReplicateExtent()->dtype, 0),
+                      small_frag->ReplicateExtent()),
+                true); // Bind the replicate extent of small_frag.
+  // Derive the thread for small_frag.
+  auto thread = small_frag->ForwardThread(small_frag_indices, rep_small);
+
+  // Get the physical index and thread for large_frag.
+  auto large_frag_physical_and_thread = large_frag->Forward(large_frag_indices);
+  // Add small_frag's thread to the large fragment's thread info.
+  large_frag_physical_and_thread.push_back(thread);
+  // Get the inverse of the large fragment.
+  auto inv_large_frag = large_frag->Inverse();
+  // Compute the logical index and replicate index using the inverse layout.
+  auto inv_large_frag_logical_and_rep =
+      inv_large_frag->Forward(large_frag_physical_and_thread);
+
+  // Extract the replicate index from the result.
+  auto inv_large_frag_rep =
+      inv_large_frag_logical_and_rep[inv_large_frag_logical_and_rep.size() - 1];
+
+  // Calculate the thread based on the logical index and replicate index.
+  auto check_thread =
+      large_frag->ForwardThread(large_frag_indices, inv_large_frag_rep);
+
+  // Simplify the difference between the threads.
+  auto diff = analyzer.Simplify(thread - check_thread);
+  // If the difference is zero, the threads match and the access is valid.
+  return is_zero(diff);
+}
+
 } // namespace tl
 } // namespace tvm
diff --git a/src/layout/utils.h b/src/layout/utils.h
index 0f03a8617..cae9ecde5 100644
--- a/src/layout/utils.h
+++ b/src/layout/utils.h
@@ -10,6 +10,7 @@
 #include 
 
 #include "../support/ffi_aliases.h"
+#include "layout.h"
 
 namespace tvm {
 namespace tl {
@@ -66,6 +67,28 @@ Map<Var, Range> ToVMap(const Array<IterVar> &ivs);
  */
 Array<IterVar> ToIterVars(const Map<Var, Range> &vmap);
 
+/*!
+ * \brief Check whether the threads that access elements of a smaller fragment
+ * are a subset of the threads that access elements of a larger fragment.
+ *
+ * This function ensures that if the small fragment's layout corresponds to the
+ * loop itself, accessing the large fragment's elements is valid. Additionally,
+ * if the small fragment is later updated to the large one, the originally
+ * valid access remains valid.
+ *
+ * \param small_frag The smaller fragment to check
+ * \param large_frag The larger fragment to check against
+ * \param small_frag_indices The indices used to access small_frag
+ * \param large_frag_indices The indices used to access large_frag
+ * \param analyzer The analyzer used for simplification
+ * \param check_forward_index Whether to also check physical index equality
+ * \return true if small_frag's threads are contained in large_frag's threads
+ */
+bool ProveFragmentContains(Fragment small_frag, Fragment large_frag,
+                           Array<PrimExpr> small_frag_indices,
+                           Array<PrimExpr> large_frag_indices,
+                           arith::Analyzer &analyzer,
+                           bool check_forward_index = false);
+
 } // namespace tl
 } // namespace tvm
diff --git a/src/op/parallel.cc b/src/op/parallel.cc
index 35caa464f..0a2f64230 100644
--- a/src/op/parallel.cc
+++ b/src/op/parallel.cc
@@ -22,89 +22,7 @@ namespace tl {
 
 using namespace tir;
 
-// ProveFragmentContains checks whether the threads that access elements of a
-// smaller fragment (small_frag) are a subset of the threads that access
-// elements of a larger fragment (large_frag) for any given loop index. This
-// function ensures that if the small fragment's layout corresponds to the loop
-// itself, accessing the large fragment's elements is valid. Additionally, if
-// small is updated to large, the originally valid access remains valid. The
-// proof is performed by:
-//
-// 1. Defining a variable `rep_small` to represent the replicate index of the
-// small fragment that is being checked.
-// 2. Using the `small_frag_indices` and `rep_small` to derive the thread
-// accessing
-//    the element in the small fragment.
-// 3. Using `large_frag_indices` to derive the physical index of the large
-// fragment
-//    along with the thread information, and then feeding these into the inverse
-//    of the large fragment to obtain the logical index and replicate index.
-// 4. Verifying the mapping by checking whether the computed thread using the
-// inverse
-//    layout corresponds to the original thread calculated for the small
-//    fragment. If they don't match, this indicates that the inverse layout's
-//    domain does not include the thread and thus the access is invalid.
-bool ProveFragmentContains(Fragment small_frag, Fragment large_frag,
-                           Array<PrimExpr> small_frag_indices,
-                           Array<PrimExpr> large_frag_indices,
-                           arith::Analyzer &analyzer_,
-                           bool check_forward_index) {
-  // When check_forward_index is true, verify that the physical indices
-  // (forward index) of both fragments are equal. This is required when
-  // validating loop layout against buffer fragment, as code generation
-  // needs to correctly derive buffer physical indices from loop layout.
-  bool large_physical_is_fully_replicated = large_frag->IsCompletedReplicated();
-  if (large_physical_is_fully_replicated) {
-    return true; // fully replicated fragments are always compatible
-  }
-
-  if (check_forward_index) {
-    auto small_physical = small_frag->Forward(small_frag_indices);
-    auto large_physical = large_frag->Forward(large_frag_indices);
-    // Dimension mismatch means they are not equal.
-    if (small_physical.size() != large_physical.size()) {
-      return false;
-    }
-    // Check each physical index component for equality.
-    for (size_t i = 0; i < small_physical.size(); i++) {
-      auto diff = analyzer_.Simplify(small_physical[i] - large_physical[i]);
-      if (!is_zero(diff)) {
-        return false;
-      }
-    }
-  }
-
-  Var rep_small("__checking_frag_contains_rep");
-  analyzer_.Bind(rep_small,
-                 Range(IntImm(small_frag->ReplicateExtent()->dtype, 0),
-                       small_frag->ReplicateExtent()),
-                 true); // Bind the replicate extent of small_frag.
-  // Derive thread for small_frag.
-  auto thread = small_frag->ForwardThread(small_frag_indices, rep_small);
-
-  // Get physical index and thread for large_frag.
-  auto large_frag_physical_and_thread = large_frag->Forward(large_frag_indices);
-  // Add small_frag's thread to the large fragment's thread info.
-  large_frag_physical_and_thread.push_back(thread);
-  // Get the inverse of the large fragment.
-  auto inv_large_frag = large_frag->Inverse();
-  // Compute logical index and replicate index using inverse layout.
-  auto inv_large_frag_logical_and_rep =
-      inv_large_frag->Forward(large_frag_physical_and_thread);
-
-  // Extract replicate index from the result.
-  auto inv_large_frag_rep =
-      inv_large_frag_logical_and_rep[inv_large_frag_logical_and_rep.size() - 1];
-
-  // Calculate thread based on the logical index and replicate index.
-  auto check_thread =
-      large_frag->ForwardThread(large_frag_indices, inv_large_frag_rep);
-
-  // Simplify the difference between the threads.
-  auto diff = analyzer_.Simplify(thread - check_thread);
-  // If the difference is zero, the threads match and the access is valid.
-  return is_zero(diff);
-}
+namespace {
 
 class IfBufferRemapLoopGenerator : public StmtExprMutator {
 public:
@@ -145,6 +63,8 @@ class IfBufferRemapLoopGenerator : public StmtExprMutator {
   Map<Buffer, Layout> layout_map_;
 };
 
+} // anonymous namespace
+
 /**
 * @brief Handle a parallel For node during traversal, collecting loop metadata.
 *
@@ -197,6 +117,22 @@ ParallelOpNode::ParallelOpNode(For root) : root_(root), V(this) {
     annotated_predicate_ = Downcast<PrimExpr>(
         root_->annotations.Get(kParallelLoopPredicate).value());
   }
+  // Collect cross-thread access info and buffer store info.
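+  // A store to a shared/global buffer (or a load from one) crosses thread
+  // boundaries; fragment stores are recorded separately so that replication
+  // guards can be built for them once the loop layout is fixed (see
+  // BuildReplicationGuardsIfNeeded).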
+  PostOrderVisit(root_, [&](const ObjectRef &obj) {
+    if (const auto *store = obj.as<BufferStoreNode>()) {
+      auto buffer = store->buffer;
+      if (IsSharedBuffer(buffer) || IsGlobalBuffer(buffer)) {
+        has_cross_thread_access_ = true;
+        store_shared_global_buffers_.emplace_back(buffer);
+      } else if (IsFragmentBuffer(buffer)) {
+        store_fragment_buffers_.emplace_back(buffer);
+      }
+    } else if (const auto *load = obj.as<BufferLoadNode>()) {
+      if (IsSharedBuffer(load->buffer) || IsGlobalBuffer(load->buffer)) {
+        has_cross_thread_access_ = true;
+      }
+    }
+  });
 }
 
 TileOperator ParallelOpNode::Clone() const {
@@ -269,6 +205,22 @@ ParallelOpNode::GetAccessInfo(const Buffer &buffer) const {
   return it->second;
 }
 
+bool ParallelOpNode::IsBufferCompletelyReplicated(
+    const Buffer &buffer, const LayoutMap &layout_map) const {
+  if (!IsFragmentBuffer(buffer))
+    return false;
+  auto frag = layout_map[buffer].as<Fragment>().value();
+  // buffer indices should be IntImm
+  for (const auto &index : GetAccessInfo(buffer).indices) {
+    if (!index.as<IntImm>()) {
+      return false;
+    } else if (index.as<IntImm>()->value != 0) {
+      LOG(FATAL) << "buffer " << buffer << " is not completed replicated";
+    }
+  }
+  return frag->IsCompletedReplicated();
+}
+
 Stmt ParallelOpNode::Lower(const LowerArgs &T,
                            arith::Analyzer *analyzer) const {
   return root_;
@@ -364,20 +316,6 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
     return results;
   }
 
-  auto buffer_is_completed_replicated = [&](const Buffer &buffer) {
-    if (!IsFragmentBuffer(buffer))
-      return false;
-    auto frag = T.layout_map[buffer].as<Fragment>().value();
-    // buffer indices should be IntImm
-    for (const auto &index : GetAccessInfo(buffer).indices) {
-      if (!index.as<IntImm>()) {
-        return false;
-      } else if (index.as<IntImm>()->value != 0) {
-        LOG(FATAL) << "buffer " << buffer << " is not completed replicated";
-      }
-    }
-    return frag->IsCompletedReplicated();
-  };
   // Collect fragment buffers with const index and all fragment_buffers
   std::vector<Buffer> const_index_fragment_buffer, fragment_buffers;
   for (const auto &[buffer, access] : indice_map_) {
@@ -422,7 +360,8 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
       continue;
 
     auto frag = T.layout_map[buffer].as<Fragment>().value();
-    bool is_fully_replicated = buffer_is_completed_replicated(buffer);
+    bool is_fully_replicated =
+        IsBufferCompletelyReplicated(buffer, T.layout_map);
 
     if (access.is_write) {
       source_buffer = buffer;
@@ -466,37 +405,6 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
     loop_layout_ = ComputeLoopLayoutFromBuffer(source_buffer, T);
   } else if (!loop_layout_.defined() && level == InferLevel::kFree) {
     // For free layout inference
-    // If replication exists and buffer has cross-thread shared memory access,
-    // add predicate
-    bool has_cross_thread_access = false;
-    PostOrderVisit(root_, [&](const ObjectRef &obj) {
-      if (const auto *store = obj.as<BufferStoreNode>()) {
-        if (IsSharedBuffer(store->buffer) || IsGlobalBuffer(store->buffer)) {
-          has_cross_thread_access = true;
-        }
-      } else if (const auto *load = obj.as<BufferLoadNode>()) {
-        if (IsSharedBuffer(load->buffer) || IsGlobalBuffer(load->buffer)) {
-          has_cross_thread_access = true;
-        }
-      }
-    });
-
-    // check if loop body contains a "pure" buffer store (i.e., direct
-    // assignment, not compound update)
-    std::vector<Buffer> store_shared_global_buffers, store_fragment_buffers;
-    // Buffers that scope is above fragments.
-    // global, shared, shared.dyn
-    // which can be used to analysis replicate case
-    PostOrderVisit(root_, [&](const ObjectRef &obj) {
-      if (const auto *store = obj.as<BufferStoreNode>()) {
-        auto buffer = store->buffer;
-        if (IsSharedBuffer(buffer) || IsGlobalBuffer(buffer)) {
-          store_shared_global_buffers.emplace_back(buffer);
-        } else if (IsFragmentBuffer(buffer)) {
-          store_fragment_buffers.emplace_back(buffer);
-        }
-      }
-    });
     // In free inference, try two mechanisms and prefer the one that
     // minimizes replication while remaining compatible:
     // 1) compute_loop_layout_from_buffer (always correct but may
@@ -507,11 +415,14 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
     if (read_source_buffer.defined() && allow_layout_propgate) {
       candidate_from_buffer =
           ComputeLoopLayoutFromBuffer(read_source_buffer, T);
+      LOG(INFO) << "read_source_buffer: " << read_source_buffer;
+      LOG(INFO) << "candidate_from_buffer: " << candidate_from_buffer;
     }
 
     // try to infer loop layout with two mechanisms and choose the best one
     {
       candidate_from_plan = ComputePlanCandidate(T);
+      LOG(INFO) << "candidate_from_plan: " << candidate_from_plan;
     }
 
     // Choose the best candidate:
@@ -526,10 +437,6 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
       DLOG(INFO)
           << "[FreeInfer] only compute_from_buffer available, choose it.";
     }
-
-    BuildReplicationGuardsIfNeeded(
-        T, store_shared_global_buffers, store_fragment_buffers,
-        has_cross_thread_access, const_index_fragment_buffer);
   } else if (!loop_layout_.defined()) {
     // In non-free mode without a source buffer, if we don't have any layout
     // yet (e.g., no annotation), we have nothing to infer here.
@@ -566,39 +473,28 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
 
   // Step 2: Check that the loop's partition can correctly align with all source
-  // fragment, and infer layout only when it's not yet layout-ed
+  // fragment, and infer layout only when it's not yet layout-ed.
+  // Try DeReplicate first to reduce replication if possible.
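+  // The dereplicated candidate is adopted only if it still validates against
+  // every source fragment; otherwise the original loop_layout_ is kept and
+  // re-validated below with throw_on_error enabled.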
+  Fragment dereplicated_layout = loop_layout_->DeReplicate();
+  ;
+  if (ValidateCandidateAgainstFragments(
+          dereplicated_layout, T, /*throw_on_error=*/false,
+          /*check_forward_index=*/false, source_buffer)) {
+    loop_layout_ = dereplicated_layout;
+  }
+  ValidateCandidateAgainstFragments(loop_layout_, T, /*throw_on_error=*/true,
+                                    /*check_forward_index=*/false,
+                                    source_buffer);
+
+  // Step 3: Build replication guards
+  BuildReplicationGuardsIfNeeded(
+      T, store_shared_global_buffers_, store_fragment_buffers_,
+      has_cross_thread_access_, const_index_fragment_buffer);
+
+  // Step 4: Collect buffer fragments
   LayoutMap results;
   for (const auto &[buffer, access] : indice_map_) {
-    if (T.layout_map.count(buffer)) {
-      if (auto info = reducer_info_map_.Get(buffer->data);
-          info && info.value()->rep == ReducerRepType::ALL)
-        continue;
-      auto fragment = T.layout_map[buffer].as<Fragment>().value();
-      auto vars =
-          loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
-      std::ostringstream oss;
-      bool success = true;
-      if (access.is_read && !ProveFragmentContains(loop_layout_, fragment, vars,
-                                                   access.indices, analyzer_)) {
-        oss << "Layout infer conflict between " << buffer << " and "
-            << source_buffer << " in T.Parallel loop:" << '\n'
-            << "  loop " << loop_layout_->DebugOutput() << '\n'
-            << "  fragment " << fragment->DebugOutput() << '\n';
-        success = false;
-      }
-      if (access.is_write &&
-          !ProveFragmentContains(fragment, loop_layout_, access.indices, vars,
-                                 analyzer_)) {
-        oss << "Layout infer conflict between " << buffer << " and "
-            << source_buffer << " in T.Parallel loop:" << '\n'
-            << "  loop " << loop_layout_->DebugOutput() << '\n'
-            << "  fragment " << fragment->DebugOutput() << '\n';
-        success = false;
-      }
-      if (!success) {
-        throw LayoutConflictException(oss.str());
-      }
-    } else {
+    if (!T.layout_map.count(buffer)) {
       auto dst_layout =
           CompleteBufferFragment(buffer)->BindThreadRange(T.thread_bounds);
       results.Set(buffer, dst_layout);
@@ -673,7 +569,8 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const {
 TVM_FFI_STATIC_INIT_BLOCK() { ParallelOpNode::RegisterReflection(); }
 
 bool ParallelOpNode::ValidateCandidateAgainstFragments(
-    const Fragment &candidate, const LayoutInferArgs &T) const {
+    const Fragment &candidate, const LayoutInferArgs &T, bool throw_on_error,
+    bool check_forward_index, const Buffer &source_buffer) const {
   auto vars =
       loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); });
   for (const auto &[buffer, access] : indice_map_) {
@@ -683,16 +580,34 @@ bool ParallelOpNode::ValidateCandidateAgainstFragments(
     auto fragment = T.layout_map[buffer].as<Fragment>().value();
-    // check_forward_index=true: when validating loop layout against buffer
-    // fragment, we need to ensure physical indices match for correct code gen.
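+    // The diagnostic below is assembled only when throw_on_error is set;
+    // otherwise a failed containment proof simply returns false.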
+    std::ostringstream oss;
+    bool success = true;
     if (access.is_read &&
         !ProveFragmentContains(candidate, fragment, vars, access.indices,
-                               analyzer_, /*check_forward_index=*/false)) {
-      return false;
+                               analyzer_, check_forward_index)) {
+      if (throw_on_error) {
+        oss << "Layout infer conflict between " << buffer << " and "
+            << source_buffer << " in T.Parallel loop:" << '\n'
+            << "  loop " << candidate->DebugOutput() << '\n'
+            << "  fragment " << fragment->DebugOutput() << '\n';
+      }
+      success = false;
     }
     if (access.is_write &&
         !ProveFragmentContains(fragment, candidate, access.indices, vars,
-                               analyzer_, /*check_forward_index=*/false)) {
+                               analyzer_, check_forward_index)) {
+      if (throw_on_error) {
+        oss << "Layout infer conflict between " << buffer << " and "
+            << source_buffer << " in T.Parallel loop:" << '\n'
+            << "  loop " << candidate->DebugOutput() << '\n'
+            << "  fragment " << fragment->DebugOutput() << '\n';
+      }
+      success = false;
+    }
+    if (!success) {
+      if (throw_on_error) {
+        throw LayoutConflictException(oss.str());
+      }
       return false;
     }
   }
@@ -711,12 +626,19 @@ ParallelOpNode::ComputeLoopLayoutFromBuffer(const Buffer &buffer,
   if (IsCommonAccessIndice(buffer)) {
     result = src_layout;
   } else {
-    Var rep;
+    Var rep("_rep");
     auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep,
                             IterVarType::kDataPar);
+    LOG(INFO) << "rep: " << rep;
+    LOG(INFO) << "src_layout: " << src_layout->DebugOutput();
+    LOG(INFO) << "src_layout->DeReplicate(): "
+              << src_layout->DeReplicate()->DebugOutput();
+    LOG(INFO) << "Create rep_iter: " << rep_iter
+              << " from rep_extent: " << src_layout->ReplicateExtent();
     PrimExpr loop_var_to_thread =
         src_layout->ForwardThread(GetAccessInfo(buffer).indices, rep);
     loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread);
+    LOG(INFO) << "loop_var_to_thread after simplify: " << loop_var_to_thread;
     PostOrderVisit(loop_var_to_thread, [&](const ObjectRef &objref) {
       if (auto opt_var = objref.as<Var>();
          opt_var && inner_vars_.count(*opt_var)) {
     try {
       result = Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
                    ->BindThreadRange(T.thread_bounds);
+      LOG(INFO) << "result: " << result;
+      LOG(INFO) << "result->DeReplicate(): " << result->DeReplicate();
     } catch (const tvm::runtime::Error &err) {
       std::ostringstream msg;
       msg << "Layout inference for buffer `" << buffer->name
diff --git a/src/op/parallel.h b/src/op/parallel.h
index 3bdc46f81..751e14a22 100644
--- a/src/op/parallel.h
+++ b/src/op/parallel.h
@@ -12,6 +12,7 @@
 #include 
 
 #include "../layout/layout.h"
+#include "../layout/utils.h"
 #include "../transform/layout_reducer.h"
 #include "./operator.h"
 
@@ -26,12 +27,6 @@ namespace tl {
 
 using namespace tir;
 
-bool ProveFragmentContains(Fragment small_frag, Fragment large_frag,
-                           Array<PrimExpr> small_frag_indices,
-                           Array<PrimExpr> large_frag_indices,
-                           arith::Analyzer &analyzer_,
-                           bool check_forward_index = false);
-
 class ParallelOpNode;
 
 class ParallelLoopNestVisitor : public StmtExprVisitor {
@@ -130,11 +125,17 @@ class ParallelOpNode : public TileOperatorNode {
                                      bool is_write);
   // Access info lookup with validation.
   const BufferAccessInfo &GetAccessInfo(const Buffer &buffer) const;
+  // Check if a buffer is completely replicated (all threads hold same data).
+  bool IsBufferCompletelyReplicated(const Buffer &buffer,
+                                    const LayoutMap &layout_map) const;
   // Validate a candidate loop layout against all source fragments in
   // T.layout_map. Returns true if compatible with all fragments; otherwise
-  // false. Does not throw.
-  bool ValidateCandidateAgainstFragments(const Fragment &candidate,
-                                         const LayoutInferArgs &T) const;
+  // false. When throw_on_error is true, throws LayoutConflictException with
+  // detailed error message on failure.
+  bool ValidateCandidateAgainstFragments(
+      const Fragment &candidate, const LayoutInferArgs &T,
+      bool throw_on_error = false, bool check_forward_index = false,
+      const Buffer &source_buffer = Buffer()) const;
   // Choose the better loop layout from two candidates using validation,
   // containment and replication heuristic.
   Fragment ChooseBestCandidate(const Fragment &candidate_from_buffer,
@@ -179,6 +180,12 @@ class ParallelOpNode : public TileOperatorNode {
   mutable arith::Analyzer analyzer_;
   // Mapping from buffer to reducer info.
   Map reducer_info_map_;
+  // Whether the loop body has cross-thread shared/global memory access.
+  bool has_cross_thread_access_ = false;
+  // Buffers that are stored to shared/global memory in the loop body.
+  std::vector<Buffer> store_shared_global_buffers_;
+  // Fragment buffers that are stored to in the loop body.
+  std::vector<Buffer> store_fragment_buffers_;
 };
 
 class ParallelOp : public TileOperator {

From dc8005743a446c3fa50e6addb92bf91fa4221e0e Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Tue, 3 Feb 2026 16:34:12 +0800
Subject: [PATCH 2/5] fix

---
 src/op/parallel.cc | 12 ------------
 1 file changed, 12 deletions(-)

diff --git a/src/op/parallel.cc b/src/op/parallel.cc
index 0a2f64230..ef7722aa0 100644
--- a/src/op/parallel.cc
+++ b/src/op/parallel.cc
@@ -415,14 +415,11 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
     if (read_source_buffer.defined() && allow_layout_propgate) {
       candidate_from_buffer =
           ComputeLoopLayoutFromBuffer(read_source_buffer, T);
-      LOG(INFO) << "read_source_buffer: " << read_source_buffer;
-      LOG(INFO) << "candidate_from_buffer: " << candidate_from_buffer;
     }
 
     // try to infer loop layout with two mechanisms and choose the best one
     {
       candidate_from_plan = ComputePlanCandidate(T);
-      LOG(INFO) << "candidate_from_plan: " << candidate_from_plan;
     }
 
     // Choose the best candidate:
@@ -629,16 +626,9 @@ ParallelOpNode::ComputeLoopLayoutFromBuffer(const Buffer &buffer,
     Var rep("_rep");
     auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep,
                             IterVarType::kDataPar);
-    LOG(INFO) << "rep: " << rep;
-    LOG(INFO) << "src_layout: " << src_layout->DebugOutput();
-    LOG(INFO) << "src_layout->DeReplicate(): "
-              << src_layout->DeReplicate()->DebugOutput();
-    LOG(INFO) << "Create rep_iter: " << rep_iter
-              << " from rep_extent: " << src_layout->ReplicateExtent();
     PrimExpr loop_var_to_thread =
         src_layout->ForwardThread(GetAccessInfo(buffer).indices, rep);
     loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread);
-    LOG(INFO) << "loop_var_to_thread after simplify: " << loop_var_to_thread;
     PostOrderVisit(loop_var_to_thread, [&](const ObjectRef &objref) {
       if (auto opt_var = objref.as<Var>();
          opt_var && inner_vars_.count(*opt_var)) {
     try {
       result = Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter)
                    ->BindThreadRange(T.thread_bounds);
-      LOG(INFO) << "result: " << result;
-      LOG(INFO) << "result->DeReplicate(): " << result->DeReplicate();
     } catch (const tvm::runtime::Error &err) {
       std::ostringstream msg;
       msg << "Layout inference for buffer `" << buffer->name

From b473625b3a26fb81f7a9512477632cea71ee1b96 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Wed, 4 Feb 2026 10:43:04 +0800
Subject: [PATCH 3/5] Refactor ParallelOpNode Layout Handling
- Removed the initial DeReplicate attempt from InferLayout to streamline
  layout inference.
- Added DeReplicate logic to ComputeLoopLayoutFromBuffer to reduce replication
  when validating layout candidates.
- Updated test cases to disable caching and ensure proper functionality of
  loop layout kernels.

---
 src/op/parallel.cc                            | 17 +++++++++--------
 .../test_tilelang_annotate_loop_layout.py     | 10 +++++-----
 2 files changed, 14 insertions(+), 13 deletions(-)

diff --git a/src/op/parallel.cc b/src/op/parallel.cc
index ef7722aa0..e0bddda54 100644
--- a/src/op/parallel.cc
+++ b/src/op/parallel.cc
@@ -471,14 +471,6 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
 
   // Step 2: Check that the loop's partition can correctly align with all source
   // fragment, and infer layout only when it's not yet layout-ed.
-  // Try DeReplicate first to reduce replication if possible.
-  Fragment dereplicated_layout = loop_layout_->DeReplicate();
-  ;
-  if (ValidateCandidateAgainstFragments(
-          dereplicated_layout, T, /*throw_on_error=*/false,
-          /*check_forward_index=*/false, source_buffer)) {
-    loop_layout_ = dereplicated_layout;
-  }
   ValidateCandidateAgainstFragments(loop_layout_, T, /*throw_on_error=*/true,
                                     /*check_forward_index=*/false,
                                     source_buffer);
@@ -656,6 +648,15 @@ ParallelOpNode::ComputeLoopLayoutFromBuffer(const Buffer &buffer,
   }
   DLOG(INFO) << "[compute_loop_layout_from_buffer] ... and get "
              << result->DebugOutput() << '\n';
+  // Try DeReplicate first to reduce replication if possible.
+  Fragment dereplicated_layout = result->DeReplicate();
+  if (ValidateCandidateAgainstFragments(
+          dereplicated_layout, T, /*throw_on_error=*/false,
+          /*check_forward_index=*/false, /*source_buffer=*/buffer)) {
+    DLOG(INFO) << "[compute_loop_layout_from_buffer] DeReplicate success, get "
+               << dereplicated_layout->DebugOutput() << '\n';
+    result = dereplicated_layout;
+  }
   return result;
 }
 
diff --git a/testing/python/layout/test_tilelang_annotate_loop_layout.py b/testing/python/layout/test_tilelang_annotate_loop_layout.py
index 4f41ebf93..7eed9bdbf 100644
--- a/testing/python/layout/test_tilelang_annotate_loop_layout.py
+++ b/testing/python/layout/test_tilelang_annotate_loop_layout.py
@@ -3,7 +3,6 @@
 import tilelang.language as T
 
 
-# TODO(lei): replicate loop layout and more complicated layout cases
 @tilelang.jit
 def loop_layout_kernel(A, B, loop_layout):
     M, N = T.const("M, N")
@@ -41,9 +40,10 @@ def loop_layout_fn(i, j):
 
     M, N = 128, 32
     loop_layout = T.Fragment((M, N), forward_fn=loop_layout_fn)
+    tilelang.disable_cache()
    kernel = loop_layout_kernel.compile(M=M, N=N, loop_layout=loop_layout)
     code = kernel.get_kernel_source()
-
+    print(code)
     assert "*(float4*)(B + ((((int)threadIdx.x) * 32) + (i * 4))) = *(float4*)(A + ((((int)threadIdx.x) * 32) + (i * 4)));" in code
 
 
@@ -88,6 +88,7 @@ def replicate_loop_layout_kernel(A, B, loop_layout):
 
 
 @tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version(9, 0)
 def test_annotate_replicate_loop_layout_vec4():
     M, N = 128, 32
 
@@ -98,10 +99,8 @@ def loop_layout_fn(i, j, rep):
         return forward_thread, forward_local
 
     loop_layout = T.Fragment((M, N), forward_fn=loop_layout_fn, replicate=2)
-
     kernel = replicate_loop_layout_kernel.compile(M=M, N=N, loop_layout=loop_layout)
     code = kernel.get_kernel_source()
-
     assert (
         "*(float4*)(B + ((i * 256) + ((((int)threadIdx.x) & 63) * 4))) = *(float4*)(A + ((i * 256) + ((((int)threadIdx.x) & 63) * 4)));"
         in code
@@ -109,4 +108,5 @@ def loop_layout_fn(i, j, rep):
 
 if __name__ == "__main__":
-    tilelang.testing.main()
+    # tilelang.testing.main()
+    test_loop_layout_identity()

From b001c88b738c5907ca1929b87a140fda10a704 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Thu, 5 Feb 2026 17:24:50 +0800
Subject: [PATCH 4/5] fix

---
 src/op/parallel.cc | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/op/parallel.cc b/src/op/parallel.cc
index e0bddda54..b93467b54 100644
--- a/src/op/parallel.cc
+++ b/src/op/parallel.cc
@@ -648,15 +648,15 @@ ParallelOpNode::ComputeLoopLayoutFromBuffer(const Buffer &buffer,
   }
   DLOG(INFO) << "[compute_loop_layout_from_buffer] ... and get "
              << result->DebugOutput() << '\n';
-  // Try DeReplicate first to reduce replication if possible.
-  Fragment dereplicated_layout = result->DeReplicate();
-  if (ValidateCandidateAgainstFragments(
-          dereplicated_layout, T, /*throw_on_error=*/false,
-          /*check_forward_index=*/false, /*source_buffer=*/buffer)) {
-    DLOG(INFO) << "[compute_loop_layout_from_buffer] DeReplicate success, get "
-               << dereplicated_layout->DebugOutput() << '\n';
-    result = dereplicated_layout;
-  }
+  // Lei: This is a tradeoff, disable it for now.
+  // // Try DeReplicate first to reduce replication if possible.
+  // Fragment dereplicated_layout = candidate_from_buffer->DeReplicate();
+  // if (ValidateCandidateAgainstFragments(
+  //         dereplicated_layout, T, /*throw_on_error=*/false,
+  //         /*check_forward_index=*/false,
+  //         /*source_buffer=*/read_source_buffer)) {
+  //   candidate_from_buffer = dereplicated_layout;
+  // }
   return result;
 }
 

From 29a1bc6a9d6c3665df84f1685747987ad3721000 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Thu, 5 Feb 2026 17:26:44 +0800
Subject: [PATCH 5/5] Refactor Test Cases for Loop Layout

- Removed caching disablement and print statements from the loop layout
  identity test for cleaner output.
- Updated the main execution block to directly call the testing framework,
  enhancing test execution flow.

---
 testing/python/layout/test_tilelang_annotate_loop_layout.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/testing/python/layout/test_tilelang_annotate_loop_layout.py b/testing/python/layout/test_tilelang_annotate_loop_layout.py
index 7eed9bdbf..52653a9d1 100644
--- a/testing/python/layout/test_tilelang_annotate_loop_layout.py
+++ b/testing/python/layout/test_tilelang_annotate_loop_layout.py
@@ -40,10 +40,8 @@ def loop_layout_fn(i, j):
 
     M, N = 128, 32
     loop_layout = T.Fragment((M, N), forward_fn=loop_layout_fn)
-    tilelang.disable_cache()
     kernel = loop_layout_kernel.compile(M=M, N=N, loop_layout=loop_layout)
     code = kernel.get_kernel_source()
-    print(code)
     assert "*(float4*)(B + ((((int)threadIdx.x) * 32) + (i * 4))) = *(float4*)(A + ((((int)threadIdx.x) * 32) + (i * 4)));" in code
 
 
@@ -108,5 +106,4 @@ def loop_layout_fn(i, j, rep):
 
 if __name__ == "__main__":
-    # tilelang.testing.main()
-    test_loop_layout_identity()
+    tilelang.testing.main()