diff --git a/src/op/parallel.cc b/src/op/parallel.cc index 6290c3361..2adb24cd0 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -390,50 +390,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, } } } - auto compute_loop_layout_from_buffer = [&](const Buffer &buffer) { - Fragment src_layout = T.layout_map[buffer].as().value(); - DLOG(INFO) << "[compute_loop_layout_from_buffer] infer from buffer `" - << buffer << "` of layout " << src_layout->DebugOutput() << '\n'; - - Fragment result; - if (IsCommonAccessIndice(buffer)) { - result = src_layout; - } else { - Var rep; - auto rep_iter = IterVar({0, src_layout->ReplicateExtent()}, rep, - IterVarType::kDataPar); - PrimExpr loop_var_to_thread = - src_layout->ForwardThread(indice_map_[buffer], rep); - loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread); - PostOrderVisit(loop_var_to_thread, [&](const ObjectRef &objref) { - if (auto opt_var = objref.as(); - opt_var && inner_vars_.count(*opt_var)) { - std::ostringstream oss; - oss << "loop_var_to_thread = " << loop_var_to_thread - << "contains inner var" << *opt_var; - throw LayoutConflictException(oss.str()); - } - }); - - try { - result = Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter) - ->BindThreadRange(T.thread_bounds); - } catch (const tvm::runtime::Error &err) { - std::ostringstream msg; - msg << "Layout inference for buffer `" << buffer->name - << "` failed inside `T.parallel` loop."; - - msg << "\nUnderlying TVM error: " << err.what(); - msg << "\nProblematic loop AST:\n " << root_; - msg << "\nHint: ensure the loop extent divides the thread binding or " - "adjust the fragment mapping."; - LOG(FATAL) << msg.str(); - } - } - DLOG(INFO) << "[compute_loop_layout_from_buffer] ... and get " - << result->DebugOutput() << '\n'; - return result; - }; + // moved to ComputeLoopLayoutFromBuffer // Try to infer loop layout from buffers in order of preference: // 1. Non-replicated write buffer (most reliable) @@ -442,7 +399,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, // 4. 
Free inference mode (no source buffer) if (source_buffer.defined() && allow_layout_propgate) { - loop_layout_ = compute_loop_layout_from_buffer(source_buffer); + loop_layout_ = ComputeLoopLayoutFromBuffer(source_buffer, T); } else if (level == InferLevel::kFree) { // For free layout inference // If replication exists and buffer has cross-thread shared memory access, @@ -483,122 +440,39 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, } } }); + // In free inference, try two mechanisms and prefer the one that + // minimizes replication while remaining compatible: + // 1) compute_loop_layout_from_buffer (always correct but may + // over-replicate) 2) PlanLoopPartition (often smaller replication) + Fragment candidate_from_buffer; + Fragment candidate_from_plan; + if (read_source_buffer.defined() && allow_layout_propgate) { - loop_layout_ = compute_loop_layout_from_buffer(read_source_buffer); + candidate_from_buffer = + ComputeLoopLayoutFromBuffer(read_source_buffer, T); } - if (!loop_layout_.defined()) { - // No source buffer available, use free mode inference - // Vectorize Size must be aware of the buffer_remap - // As the pass will do post processing to the layout - auto maybe_remapped_root_ = - IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map); - int vector_size = GetVectorizeSize(maybe_remapped_root_, T.analyzer); - DLOG(INFO) << "[PlanLoopPartition] vector_size = " << vector_size << '\n'; - - PrimExpr loop_total_size = 1; - for (Stmt l = root_; l.as().has_value(); - l = l.as().value()->body) - loop_total_size = loop_total_size * l.as().value()->extent; - DLOG(INFO) << "[PlanLoopPartition] loop_total_size = " << loop_total_size - << '\n'; - while (!analyzer_.CanProve( - floormod(loop_total_size, - T.thread_bounds->extent * vector_size) == 0) && - vector_size > 1) - vector_size /= 2; - DLOG(INFO) << "[PlanLoopPartition] after adjust: vector_size = " - << vector_size << '\n'; - - // Check if coalesced_width is defined - if (auto coalesced_width = - root_->annotations.Get(attr::kCoalescedWidth)) { - if (const auto *imm = coalesced_width->as()) { - int expected = imm->value; - // Verify that vector_size is divisible by expected - if (vector_size % expected != 0) { - LOG(FATAL) << "Vector size " << vector_size - << " is not divisible by coalesced width " << expected; - } - vector_size = expected; - } else { - LOG(FATAL) << "coalesced_width should be an IntImmNode."; - } - } - DLOG(INFO) << "[PlanLoopPartition] root_ = " << root_ - << " ############# vector_size = " << vector_size - << ", thread_bounds = " << T.thread_bounds << '\n'; - loop_layout_ = PlanLoopPartition(root_, vector_size, T.thread_bounds); - DLOG(INFO) << "[PlanLoopPartition] loop_layout_ = " - << loop_layout_->DebugOutput() << '\n'; + // try to infer loop layout with two mechanisms and choose the best one + { + candidate_from_plan = ComputePlanCandidate(T); } - // Lambda that guards replicated accesses: - // - When a loop layout replicates a fragment buffer (rep > 1), each thread - // observes the same fragment elements. Blindly storing to shared/global - // memory in that case would add the same value multiple times. - // - We therefore restrict the store so that only the replica with rep == 0 - // performs the update (e.g. global[i] += fragment[i] only fires once). - // Trigger conditions for this guard: - // 1) There are cross-thread stores targeting shared/global memory (no - // fragment stores in this branch; atomic_add and similar remain TODO). 
- // 2) The loop layout replicate extent is greater than 1, inferred from the - // thread bounds captured in the layout. - - [this, &store_shared_global_buffers, &store_fragment_buffers, - &has_cross_thread_access, &const_index_fragment_buffer, &T]() { - if (is_one(loop_layout_->ReplicateExtent())) - return; - if (!has_cross_thread_access) - return; - - if (!store_fragment_buffers.empty()) { - // Iterate replicated fragment stores: when the fragment index is a - // constant (e.g. fragment[0]), every thread touches the same slot, so - // the rep == 0 predicate is unnecessary. Example: for i in - // T.Parallel(...): - // shared[i] = ... - // fragment[0] = ... - bool replicate_is_from_dynamic_index_fragment = false; - for (const auto &fragment : store_fragment_buffers) { - if (!T.layout_map.count(fragment)) { - continue; - } - - auto fragment_layout = T.layout_map[fragment].as().value(); - if (is_one(fragment_layout->ReplicateExtent())) - continue; - - if (analyzer_.CanProveEqual(fragment_layout->ReplicateExtent(), - loop_layout_->ReplicateExtent())) - continue; - if (std::find(const_index_fragment_buffer.begin(), - const_index_fragment_buffer.end(), - fragment) == const_index_fragment_buffer.end()) { - replicate_is_from_dynamic_index_fragment = true; - } - } - - if (!replicate_is_from_dynamic_index_fragment) - return; + // Choose the best candidate: + if (candidate_from_buffer.defined() && candidate_from_plan.defined()) { + loop_layout_ = + ChooseBestCandidate(candidate_from_buffer, candidate_from_plan, T); + } else if (candidate_from_plan.defined()) { + loop_layout_ = candidate_from_plan; + DLOG(INFO) << "[FreeInfer] only PlanLoopPartition available, choose it."; + } else if (candidate_from_buffer.defined()) { + loop_layout_ = candidate_from_buffer; + DLOG(INFO) + << "[FreeInfer] only compute_from_buffer available, choose it."; + } - ICHECK(store_shared_global_buffers.empty()) - << "Invalid layout: cannot have both fragment and shared store " - "buffers " - "in replicated loop layout."; - return; - } else { - // Now, store is global or shared - // or T.call_extern or T.call_intrin ... 
- auto inv = loop_layout_->Inverse(); - Array fwd; - for (size_t i = 0; i < loop_layout_->OutputDim(); i++) - fwd.push_back(0); - fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min); - auto rep = inv->Forward(fwd).back(); - AddPredicate(EQ(rep, 0)); - } - }(); + BuildReplicationGuardsIfNeeded( + T, store_shared_global_buffers, store_fragment_buffers, + has_cross_thread_access, const_index_fragment_buffer); } else { return {}; } @@ -720,5 +594,223 @@ Fragment ParallelOpNode::CompleteBufferFragment(const Buffer &buffer) const { TVM_FFI_STATIC_INIT_BLOCK() { ParallelOpNode::RegisterReflection(); } +bool ParallelOpNode::ValidateCandidateAgainstFragments( + const Fragment &candidate, const LayoutInferArgs &T) const { + auto vars = + loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); }); + for (const auto &[buffer, _] : indice_map_) { + if (!T.layout_map.count(buffer)) + continue; + auto fragment = T.layout_map[buffer].as().value(); + if (!ProveFragmentContains(candidate, fragment, vars, indice_map_[buffer], + analyzer_)) { + return false; + } + } + return true; +} + +Fragment +ParallelOpNode::ComputeLoopLayoutFromBuffer(const Buffer &buffer, + const LayoutInferArgs &T) const { + Fragment src_layout = T.layout_map[buffer].as().value(); + DLOG(INFO) << "[compute_loop_layout_from_buffer] infer from buffer `" + << buffer << "` of layout " << src_layout->DebugOutput() << '\n'; + + Fragment result; + if (IsCommonAccessIndice(buffer)) { + result = src_layout; + } else { + Var rep; + auto rep_iter = + IterVar({0, src_layout->ReplicateExtent()}, rep, IterVarType::kDataPar); + PrimExpr loop_var_to_thread = + src_layout->ForwardThread(indice_map_[buffer], rep); + loop_var_to_thread = analyzer_.Simplify(loop_var_to_thread); + PostOrderVisit(loop_var_to_thread, [&](const ObjectRef &objref) { + if (auto opt_var = objref.as(); + opt_var && inner_vars_.count(*opt_var)) { + std::ostringstream oss; + oss << "loop_var_to_thread = " << loop_var_to_thread + << "contains inner var" << *opt_var; + throw LayoutConflictException(oss.str()); + } + }); + + try { + result = Fragment(loop_vars_, {}, loop_var_to_thread, rep_iter) + ->BindThreadRange(T.thread_bounds); + } catch (const tvm::runtime::Error &err) { + std::ostringstream msg; + msg << "Layout inference for buffer `" << buffer->name + << "` failed inside `T.parallel` loop."; + + msg << "\nUnderlying TVM error: " << err.what(); + msg << "\nProblematic loop AST:\n " << root_; + msg << "\nHint: ensure the loop extent divides the thread binding or " + "adjust the fragment mapping."; + LOG(FATAL) << msg.str(); + } + } + DLOG(INFO) << "[compute_loop_layout_from_buffer] ... 
and get " + << result->DebugOutput() << '\n'; + return result; +} + +Fragment ParallelOpNode::ComputePlanCandidate(const LayoutInferArgs &T) const { + // Vectorize Size must be aware of the buffer_remap + // As the pass will do post processing to the layout + auto maybe_remapped_root_ = + IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map); + int vector_size = GetVectorizeSize(maybe_remapped_root_, T.analyzer); + DLOG(INFO) << "[PlanLoopPartition] vector_size = " << vector_size << '\n'; + + PrimExpr loop_total_size = 1; + for (Stmt l = root_; l.as().has_value(); l = l.as().value()->body) + loop_total_size = loop_total_size * l.as().value()->extent; + DLOG(INFO) << "[PlanLoopPartition] loop_total_size = " << loop_total_size + << '\n'; + while (!analyzer_.CanProve(floormod(loop_total_size, T.thread_bounds->extent * + vector_size) == 0) && + vector_size > 1) + vector_size /= 2; + DLOG(INFO) << "[PlanLoopPartition] after adjust: vector_size = " + << vector_size << '\n'; + + // Check if coalesced_width is defined + if (auto coalesced_width = root_->annotations.Get(attr::kCoalescedWidth)) { + if (const auto *imm = coalesced_width->as()) { + int expected = imm->value; + // Verify that vector_size is divisible by expected + if (vector_size % expected != 0) { + LOG(FATAL) << "Vector size " << vector_size + << " is not divisible by coalesced width " << expected; + } + vector_size = expected; + } else { + LOG(FATAL) << "coalesced_width should be an IntImmNode."; + } + } + DLOG(INFO) << "[PlanLoopPartition] root_ = " << root_ + << " ############# vector_size = " << vector_size + << ", thread_bounds = " << T.thread_bounds << '\n'; + auto plan = PlanLoopPartition(root_, vector_size, T.thread_bounds); + DLOG(INFO) << "[PlanLoopPartition] candidate = " << plan->DebugOutput() + << '\n'; + return plan; +} + +void ParallelOpNode::BuildReplicationGuardsIfNeeded( + const LayoutInferArgs &T, + const std::vector &store_shared_global_buffers, + const std::vector &store_fragment_buffers, + bool has_cross_thread_access, + const std::vector &const_index_fragment_buffer) const { + if (is_one(loop_layout_->ReplicateExtent())) + return; + if (!has_cross_thread_access) + return; + + if (!store_fragment_buffers.empty()) { + bool replicate_is_from_dynamic_index_fragment = false; + for (const auto &fragment : store_fragment_buffers) { + if (!T.layout_map.count(fragment)) { + continue; + } + + auto fragment_layout = T.layout_map[fragment].as().value(); + if (is_one(fragment_layout->ReplicateExtent())) + continue; + + if (analyzer_.CanProveEqual(fragment_layout->ReplicateExtent(), + loop_layout_->ReplicateExtent())) + continue; + if (std::find(const_index_fragment_buffer.begin(), + const_index_fragment_buffer.end(), + fragment) == const_index_fragment_buffer.end()) { + replicate_is_from_dynamic_index_fragment = true; + } + } + + if (!replicate_is_from_dynamic_index_fragment) + return; + + ICHECK(store_shared_global_buffers.empty()) + << "Invalid layout: cannot have both fragment and shared store buffers " + "in replicated loop layout."; + return; + } else { + auto inv = loop_layout_->Inverse(); + Array fwd; + for (size_t i = 0; i < loop_layout_->OutputDim(); i++) + fwd.push_back(0); + fwd.push_back(InputPlaceholder(0) - T.thread_bounds->min); + auto rep = inv->Forward(fwd).back(); + AddPredicate(EQ(rep, 0)); + } +} +Fragment +ParallelOpNode::ChooseBestCandidate(const Fragment &candidate_from_buffer, + const Fragment &candidate_from_plan, + const LayoutInferArgs &T) const { + // Strategy overview: + // 1) 
Validate each candidate against all known source fragments. If only one + // is compatible, choose it immediately. + // 2) If both are compatible, compare their containment relation: + // - If buffer-based contains plan-based, prefer plan (usually smaller + // rep). + // - If plan-based contains buffer-based, prefer buffer. + // 3) If neither contains the other, prefer the one with provably smaller or + // equal replication extent; otherwise fall back to buffer-based candidate. + // Note: Final global validation happens after selection elsewhere. + auto vars = + loop_vars_.Map([](const IterVar &iv) { return PrimExpr(iv->var); }); + auto contains = [&](const Fragment &big, const Fragment &small) { + // contains(A, B) means: for any loop index, the threads that access + // B's elements are a subset of those that access A's elements. + return ProveFragmentContains(small, big, vars, vars, analyzer_); + }; + + bool buf_ok = ValidateCandidateAgainstFragments(candidate_from_buffer, T); + bool plan_ok = ValidateCandidateAgainstFragments(candidate_from_plan, T); + + if (buf_ok && !plan_ok) { + DLOG(INFO) + << "[FreeInfer] prefer compute_from_buffer (only valid candidate)."; + return candidate_from_buffer; + } + if (plan_ok && !buf_ok) { + DLOG(INFO) + << "[FreeInfer] prefer PlanLoopPartition (only valid candidate)."; + return candidate_from_plan; + } + if (!(buf_ok && plan_ok)) { + // Both invalid here; let the caller continue to final validation/throw. + // Returning buffer-based candidate keeps behavior deterministic. + return candidate_from_buffer; // arbitrary; caller will catch later + } + + bool buf_contains_plan = contains(candidate_from_buffer, candidate_from_plan); + bool plan_contains_buf = contains(candidate_from_plan, candidate_from_buffer); + + auto rep_buf = candidate_from_buffer->ReplicateExtent(); + auto rep_plan = candidate_from_plan->ReplicateExtent(); + + // Prefer the contained candidate (tends to minimize replication while + // respecting access coverage): + if (buf_contains_plan && !plan_contains_buf) { + return candidate_from_plan; + } + if (plan_contains_buf && !buf_contains_plan) { + return candidate_from_buffer; + } + // Neither strictly contains the other; prefer the one with smaller/equal rep. + if (analyzer_.CanProve(rep_plan <= rep_buf)) { + return candidate_from_plan; + } + // Safe fallback: buffer-based candidate is always correct. + return candidate_from_buffer; +} + } // namespace tl } // namespace tvm diff --git a/src/op/parallel.h b/src/op/parallel.h index 88dd1debf..6b132a552 100644 --- a/src/op/parallel.h +++ b/src/op/parallel.h @@ -101,6 +101,29 @@ class ParallelOpNode : public TileOperatorNode { Fragment CompleteBufferFragment(const Buffer &buffer) const; // Check if the buffer is accessed with common indices (i.e., loop variables). bool IsCommonAccessIndice(const Buffer &buffer) const; + // Validate a candidate loop layout against all source fragments in + // T.layout_map. Returns true if compatible with all fragments; otherwise + // false. Does not throw. + bool ValidateCandidateAgainstFragments(const Fragment &candidate, + const LayoutInferArgs &T) const; + // Choose the better loop layout from two candidates using validation, + // containment and replication heuristic. + Fragment ChooseBestCandidate(const Fragment &candidate_from_buffer, + const Fragment &candidate_from_plan, + const LayoutInferArgs &T) const; + // Compute loop layout from a source buffer's fragment mapping. 
+  Fragment ComputeLoopLayoutFromBuffer(const Buffer &buffer,
+                                       const LayoutInferArgs &T) const;
+  // Compute plan-based loop layout candidate using vectorization and thread
+  // bounds.
+  Fragment ComputePlanCandidate(const LayoutInferArgs &T) const;
+  // Add replication guard predicates when needed for cross-thread stores.
+  void BuildReplicationGuardsIfNeeded(
+      const LayoutInferArgs &T,
+      const std::vector<Buffer> &store_shared_global_buffers,
+      const std::vector<Buffer> &store_fragment_buffers,
+      bool has_cross_thread_access,
+      const std::vector<Buffer> &const_index_fragment_buffer) const;
   // Add a predicate to the current predicate expression.
   void AddPredicate(const PrimExpr &expr) const {
     predicate_ = predicate_.defined() ? And(expr, predicate_.value()) : expr;
diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc
index b9ec1e952..daaa7b4cc 100644
--- a/src/transform/layout_inference.cc
+++ b/src/transform/layout_inference.cc
@@ -333,7 +333,6 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
                        in_queue);
   // step 3: relax constraints to free and re-run
   InferInFreeMode(layout_map, strict_layout_map);
-
   // step 4: finalize alias layouts by Var
   // For each storage var, if any buffer in the group has a layout,
   // propagate (reshape if needed) to the rest to ensure completeness.
@@ -448,17 +447,17 @@
     return buffer_map;
   }
 
-  // Return true if all buffers that this op (idx) touches already have
-  // inferred layouts in layout_map. Used to prioritize enqueue order.
-  bool ShouldPrioritize(int idx, const LayoutMap &layout_map) const {
+  // Return true if any buffer that this op (idx) touches already has
+  // an inferred layout in layout_map. Used to prioritize enqueue order.
+  bool HasKnownLayoutAnchor(int idx, const LayoutMap &layout_map) const {
     auto it = op_touched_buffers_.find(idx);
     if (it == op_touched_buffers_.end() || it->second.empty())
       return false;
     for (const auto &buf : it->second) {
-      if (!layout_map.count(buf))
-        return false;
+      if (layout_map.count(buf))
+        return true;
     }
-    return true;
+    return false;
   }
 
   // Enqueue idx to q with priority if all its buffers already
@@ -473,7 +472,7 @@
     if (in_queue[idx])
       return;
     in_queue[idx] = true;
-    if (ShouldPrioritize(idx, layout_map)) {
+    if (HasKnownLayoutAnchor(idx, layout_map)) {
       q.push_front(idx);
     } else {
       q.push_back(idx);
@@ -602,15 +601,10 @@
     // Track which buffers this op (infer_idx) touches for prioritization.
     // Avoid duplicates.
     auto &vec = op_touched_buffers_[infer_idx];
-    bool exists = false;
-    for (const auto &b : vec) {
-      if (b.same_as(buffer)) {
-        exists = true;
-        break;
-      }
-    }
-    if (!exists)
+    if (std::none_of(vec.begin(), vec.end(),
+                     [&](const Buffer &b) { return b.same_as(buffer); })) {
       vec.push_back(buffer);
+    }
   }
 
   void VisitStmt_(const ForNode *op) final {
diff --git a/testing/python/issue/test_tilelang_issue_layout.py b/testing/python/issue/test_tilelang_issue_layout.py
new file mode 100644
index 000000000..831d9d8ef
--- /dev/null
+++ b/testing/python/issue/test_tilelang_issue_layout.py
@@ -0,0 +1,36 @@
+import tilelang
+import tilelang.testing
+from tilelang import language as T
+
+
+@tilelang.jit
+def _tilelang_issue_layout_free_inference_choose_smallest_replication():
+    @T.prim_func
+    def main(A: T.Tensor((128, 4), T.float), B: T.Tensor((128, 4), T.float)):
+        with T.Kernel(1, threads=128) as _:
+            A_frag = T.alloc_fragment((128, 4), T.float)
+            B_frag = T.alloc_fragment((128, 4), T.float)
+            S_frag = T.alloc_fragment((4,), T.float)
+            T.annotate_layout(
+                {
+                    A_frag: T.Fragment(A_frag.shape, lambda i, j: (i, j)),
+                }
+            )
+            for i, j in T.Parallel(128, 4):
+                A_frag[i, j] = S_frag[j]
+            for i, j in T.Parallel(128, 4):
+                B_frag[i, j] = S_frag[j]
+
+    return main
+
+
+def test_tilelang_issue_layout_free_inference_choose_smallest_replication():
+    kernel = _tilelang_issue_layout_free_inference_choose_smallest_replication()
+    source = kernel.get_kernel_source()
+    assert "float S_frag[4];" in source, "S_frag is not in the source"
+    assert "float B_frag[4];" in source, "B_frag is not in the source"
+    assert "float A_frag[4];" in source, "A_frag is not in the source"
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
diff --git a/tilelang/language/annotations.py b/tilelang/language/annotations.py
index 43ca9c051..6e95cdafe 100644
--- a/tilelang/language/annotations.py
+++ b/tilelang/language/annotations.py
@@ -2,7 +2,8 @@
 
 from typing import Callable
 
-from tilelang.layout import Layout
+from tilelang.layout import Fragment, Layout
+from tilelang.utils.language import is_fragment
 from tvm.script.parser.tir import attr, block_attr
 from tvm.tir import FloatImm
 
@@ -27,6 +28,8 @@ def annotate_layout(layout_map: dict):
     """Annotate the layout of the buffer."""
     _layout_map = {}
     for buffer, layout in layout_map.items():
+        if is_fragment(buffer):
+            assert isinstance(layout, Fragment), f"for Fragment {buffer}, layout must be a Fragment, but got {type(layout)}"
         if isinstance(layout, Layout):
            _layout_map[buffer.data] = layout
        elif isinstance(layout, Callable):
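The `annotate_layout` change above adds an early type check: a buffer allocated as a fragment must be annotated with a `Fragment`, not a plain `Layout`. A minimal sketch of the user-facing effect, assuming the same public API the new test relies on (string dtypes and the buffer shapes here are illustrative, not from the patch):

```python
import tilelang
from tilelang import language as T


@tilelang.jit
def annotated_fragment_kernel():

    @T.prim_func
    def main(A: T.Tensor((128, 4), "float32"), B: T.Tensor((128, 4), "float32")):
        with T.Kernel(1, threads=128) as _:
            A_frag = T.alloc_fragment((128, 4), "float32")
            # A fragment buffer now has to be annotated with a Fragment;
            # passing a plain tilelang.layout.Layout for A_frag would trip
            # the new assert inside annotate_layout.
            T.annotate_layout({
                A_frag: T.Fragment(A_frag.shape, lambda i, j: (i, j)),
            })
            for i, j in T.Parallel(128, 4):
                A_frag[i, j] = A[i, j]
            T.copy(A_frag, B)

    return main
```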
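The new test compiles a kernel in which parallel loops broadcast a 4-element fragment into (128, 4) fragments and then inspects the per-thread array sizes in the generated source. A closely related sketch that also materializes the result, assuming the same API surface (`tilelang.jit`, `T.Parallel`, `T.alloc_fragment`, `get_kernel_source`); exact generated identifiers and sizes are only indicative:

```python
import tilelang
from tilelang import language as T


@tilelang.jit
def broadcast_row_kernel():

    @T.prim_func
    def main(A: T.Tensor((4,), "float32"), B: T.Tensor((128, 4), "float32")):
        with T.Kernel(1, threads=128) as _:
            S_frag = T.alloc_fragment((4,), "float32")
            B_frag = T.alloc_fragment((128, 4), "float32")
            # Load the 4-element row; free inference must pick a loop layout.
            for j in T.Parallel(4):
                S_frag[j] = A[j]
            # Broadcast into a per-thread fragment: the plan-based candidate
            # keeps S_frag at 4 elements per thread instead of replicating
            # the whole (128, 4) iteration space onto every thread.
            for i, j in T.Parallel(128, 4):
                B_frag[i, j] = S_frag[j]
            T.copy(B_frag, B)

    return main


if __name__ == "__main__":
    # Small per-thread allocations indicate the low-replication candidate won,
    # mirroring the assertions in the new test.
    src = broadcast_row_kernel().get_kernel_source()
    print("float S_frag[4];" in src, "float B_frag[4];" in src)
```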
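The replication guard that moves into `BuildReplicationGuardsIfNeeded` targets patterns like the `global[i] += fragment[i]` case mentioned in the comments: when the loop layout replicates a fragment, every thread holds the same values, so an unguarded accumulation would apply the update once per replica instead of once per element. A hypothetical kernel of that shape (not taken from the patch's tests):

```python
import tilelang
from tilelang import language as T


@tilelang.jit
def guarded_accumulate_kernel():

    @T.prim_func
    def main(Out: T.Tensor((4,), "float32")):
        with T.Kernel(1, threads=128) as _:
            S_frag = T.alloc_fragment((4,), "float32")
            for i in T.Parallel(4):
                S_frag[i] = 1.0
            # The 4-element loop is replicated across the 128 threads, so every
            # replica sees the same S_frag values. The inferred rep == 0
            # predicate is what keeps the accumulation below from firing once
            # per replica.
            for i in T.Parallel(4):
                Out[i] = Out[i] + S_frag[i]

    return main
```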
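`ChooseBestCandidate`'s policy can be restated without the TVM types. The following is a plain-Python model of the documented strategy; `is_valid`, `contains`, and `rep` are stand-ins for `ValidateCandidateAgainstFragments`, `ProveFragmentContains`, and `ReplicateExtent` rather than real bindings:

```python
def choose_best_candidate(from_buffer, from_plan, is_valid, contains, rep):
    """Model of ParallelOpNode::ChooseBestCandidate.

    from_buffer / from_plan: the two loop-layout candidates.
    is_valid(c):    c is compatible with every known source fragment.
    contains(a, b): threads accessing b's elements are a subset of a's.
    rep(c):         replication extent of candidate c.
    """
    buf_ok, plan_ok = is_valid(from_buffer), is_valid(from_plan)
    if buf_ok and not plan_ok:
        return from_buffer          # only valid candidate
    if plan_ok and not buf_ok:
        return from_plan            # only valid candidate
    if not (buf_ok and plan_ok):
        return from_buffer          # both invalid; caller re-validates later
    buf_contains_plan = contains(from_buffer, from_plan)
    plan_contains_buf = contains(from_plan, from_buffer)
    if buf_contains_plan and not plan_contains_buf:
        return from_plan            # the contained candidate replicates less
    if plan_contains_buf and not buf_contains_plan:
        return from_buffer
    if rep(from_plan) <= rep(from_buffer):
        return from_plan            # provably smaller or equal replication
    return from_buffer              # safe fallback: buffer-derived layout
```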
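`ComputePlanCandidate` trims the vectorization width until `threads * vector_size` divides the total loop extent, then optionally pins it to the `kCoalescedWidth` annotation. A plain-Python restatement of that adjustment (a model, not the real implementation):

```python
def adjust_vector_size(loop_total_size, num_threads, vector_size, coalesced_width=None):
    """Model of the vector-size adjustment in ComputePlanCandidate."""
    # Halve the width until threads * vector_size tiles the loop exactly.
    while vector_size > 1 and loop_total_size % (num_threads * vector_size) != 0:
        vector_size //= 2
    # An explicit coalesced_width must divide the chosen width, then wins.
    if coalesced_width is not None:
        if vector_size % coalesced_width != 0:
            raise ValueError(f"Vector size {vector_size} is not divisible by "
                             f"coalesced width {coalesced_width}")
        vector_size = coalesced_width
    return vector_size


# A 128 * 4 element loop on 128 threads with an initial width of 8 drops to 4.
assert adjust_vector_size(128 * 4, 128, 8) == 4
```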
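The `layout_inference.cc` change relaxes the enqueue priority from "all touched buffers already have layouts" to "at least one does". A simplified Python model of the worklist behaviour (a stand-in for `BufferUseDefCollector`'s queue, using a deque and made-up buffer names, not the real class):

```python
from collections import deque


def push_to_infer_queue(idx, touched_buffers, layout_map, queue, in_queue):
    """Model of the updated PushToInferQueue / HasKnownLayoutAnchor pair.

    Ops that touch at least one buffer with a known layout go to the front so
    they can anchor inference for their neighbours; everything else goes to
    the back.
    """
    if in_queue[idx]:
        return
    in_queue[idx] = True
    has_anchor = any(buf in layout_map for buf in touched_buffers.get(idx, ()))
    if has_anchor:
        queue.appendleft(idx)
    else:
        queue.append(idx)


if __name__ == "__main__":
    q, in_q = deque(), {0: False, 1: False}
    touched = {0: ["A_frag"], 1: ["B_frag"]}
    layouts = {"A_frag": "annotated-fragment-layout"}
    push_to_infer_queue(1, touched, layouts, q, in_q)  # no anchor -> back
    push_to_infer_queue(0, touched, layouts, q, in_q)  # anchor -> front
    print(list(q))  # [0, 1]
```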