From cfc429da47e19afb3149823403f4616b5bbbc44a Mon Sep 17 00:00:00 2001 From: Lei Wang Date: Fri, 21 Nov 2025 21:07:44 +0800 Subject: [PATCH 1/5] [Refactor] Update Vectorization Functions to Accept Analyzer Parameter - Modified `VectorizeLoop` and related functions to accept an `arith::Analyzer` parameter, enhancing their capability to perform analysis during vectorization. - Updated multiple instances in `copy.cc`, `fill.cc`, `parallel.cc`, and layout inference files to utilize the new analyzer parameter for improved performance and correctness. - Ensured consistency across vectorization logic by integrating the analyzer into existing workflows, facilitating better optimization opportunities. --- 3rdparty/tvm | 2 +- src/op/copy.cc | 4 +- src/op/fill.cc | 6 +- src/op/parallel.cc | 3 +- src/transform/layout_inference.cc | 12 ++- src/transform/legalize_vectorized_loop.cc | 2 +- src/transform/loop_vectorize.cc | 92 +++++++++++++++-------- src/transform/loop_vectorize.h | 5 ++ 8 files changed, 82 insertions(+), 44 deletions(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 713e6ade5..cd2b2b601 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 713e6ade56eaa72cc85d58d9228dd9f34cc2d03e +Subproject commit cd2b2b6013d155b5822300b0a0740fa65320dd9e diff --git a/src/op/copy.cc b/src/op/copy.cc index 5d3529044..2584abced 100644 --- a/src/op/copy.cc +++ b/src/op/copy.cc @@ -852,7 +852,7 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T, auto par_op = ParallelOp(transformed_loop); if (is_cpu_target) { - vectorized_thread_loop = VectorizeLoop(transformed_loop); + vectorized_thread_loop = VectorizeLoop(transformed_loop, analyzer); } else { std::vector levels = {InferLevel::kCommon, InferLevel::kStrict, InferLevel::kFree}; @@ -865,7 +865,7 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T, auto thread_var = T.thread_var; auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout); - vectorized_thread_loop = 
VectorizeLoop(thread_loop); + vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer); } if (par_op->GetPredicate(T.thread_var).defined()) { diff --git a/src/op/fill.cc b/src/op/fill.cc index 83b0842dc..93b3bca07 100644 --- a/src/op/fill.cc +++ b/src/op/fill.cc @@ -207,7 +207,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { InferLevel::kFree); auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, par_op->GetLoopLayout()); - auto vectorized_thread_loop = VectorizeLoop(thread_loop); + auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer); if (par_op->GetPredicate(T.thread_var).defined()) { return IfThenElse(par_op->GetPredicate(T.thread_var).value(), vectorized_thread_loop); @@ -215,7 +215,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { return vectorized_thread_loop; } else if (dst.scope() == "local") { auto init_loop = MakeSIMTLoop(analyzer); - auto vectorized_thread_loop = VectorizeLoop(init_loop); + auto vectorized_thread_loop = VectorizeLoop(init_loop, analyzer); return vectorized_thread_loop; } else if (dst.scope() == "shared.dyn" || dst.scope() == "shared" || dst.scope() == "global") { @@ -225,7 +225,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const { InferLevel::kFree); auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, par_op->GetLoopLayout()); - auto vectorized_thread_loop = VectorizeLoop(thread_loop); + auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer); if (par_op->GetPredicate(T.thread_var).defined()) { return IfThenElse(par_op->GetPredicate(T.thread_var).value(), vectorized_thread_loop); diff --git a/src/op/parallel.cc b/src/op/parallel.cc index 81777aa53..0d09cc129 100644 --- a/src/op/parallel.cc +++ b/src/op/parallel.cc @@ -452,8 +452,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T, // As the pass will do post processing to the layout auto maybe_remapped_root_ 
= IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map); - int vector_size = GetVectorizeSize(maybe_remapped_root_); - + int vector_size = GetVectorizeSize(maybe_remapped_root_, T.analyzer); DLOG(INFO) << "[PlanLoopPartition] vector_size = " << vector_size << '\n'; PrimExpr loop_total_size = 1; diff --git a/src/transform/layout_inference.cc b/src/transform/layout_inference.cc index bd726b3db..be98b284d 100644 --- a/src/transform/layout_inference.cc +++ b/src/transform/layout_inference.cc @@ -12,6 +12,7 @@ #include #include +#include #include #include "../layout/utils.h" @@ -85,6 +86,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { auto &next = infer_list_[cur_infer_id]; auto iter_var = thread_var_vec_[cur_infer_id]; auto thread_bounds = thread_bounds_vec_[cur_infer_id]; + arith::Analyzer *cur_analyzer = analyzer_vec_[cur_infer_id].get(); auto buffer_oob = buffer_oob_vec_[cur_infer_id]; // Double-check that 'next' is valid ICHECK(next.defined()) << "infer_list_[" << cur_infer_id @@ -108,7 +110,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { // Run InferLayout auto updates = next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map, - &analyzer_, buffer_oob}, + cur_analyzer, buffer_oob}, level); // Process the returned updates for (const auto &[buffer, layout] : updates) { @@ -266,6 +268,9 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { ICHECK_EQ(thread_bounds_vec_.size(), infer_list_.size()) << "Size mismatch: thread_bounds_vec_ and infer_list_ must match in " "length."; + ICHECK_EQ(analyzer_vec_.size(), infer_list_.size()) + << "Size mismatch: analyzer_vec_ and infer_list_ must match in " + "length."; ICHECK_EQ(buffer_oob_vec_.size(), infer_list_.size()) << "Size mismatch: buffer_oob_vec_ and infer_list_ must match in " "length."; @@ -452,6 +457,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { } else { thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1)); } + 
analyzer_vec_.push_back(analyzer_.Clone()); // Compute buffer oob for each buffer in the op if (const auto *copy = p.as()) { @@ -542,6 +548,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { } else { thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1)); } + analyzer_vec_.push_back(analyzer_.Clone()); buffer_oob_vec_.push_back(false); } else { IRVisitorWithAnalyzer::VisitStmt(op->body); @@ -683,6 +690,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer { IterVarType::kDataPar); std::vector thread_var_vec_; std::vector thread_bounds_vec_; + std::vector> analyzer_vec_; std::vector buffer_oob_vec_; Target target_; LayoutMap annotated_layout_map_; @@ -1024,7 +1032,7 @@ class LayoutInferencer : public IRMutatorWithAnalyzer { }); if ((has_non_local || has_cast_operations) && !has_reducer) { - for_node = VectorizeLoop(for_node); + for_node = VectorizeLoop(for_node, analyzer_); } if (result_.predicate_map.count(root) && parallel_loop) { diff --git a/src/transform/legalize_vectorized_loop.cc b/src/transform/legalize_vectorized_loop.cc index aa461784a..4fd4ab91f 100644 --- a/src/transform/legalize_vectorized_loop.cc +++ b/src/transform/legalize_vectorized_loop.cc @@ -73,7 +73,7 @@ class LoopVectorizedLegalizer : IRMutatorWithAnalyzer { // Change the loop kind from vectorized to serial for_node.CopyOnWrite()->kind = ForKind::kSerial; // Apply vectorization transformation to the loop - return VectorizeLoop(for_node); + return VectorizeLoop(for_node, analyzer_); } }; diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index 45283d905..627341dfe 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -45,7 +45,7 @@ struct VectorizePlanResult { PrimExpr condition; }; -class VectorizeFindGlobalAccess : public arith::IRVisitorWithAnalyzer { +class VectorizeFindGlobalAccess : public StmtExprVisitor { public: VectorizeFindGlobalAccess() = default; @@ -60,19 +60,21 @@ class VectorizeFindGlobalAccess 
: public arith::IRVisitorWithAnalyzer { void VisitStmt_(const BufferStoreNode *node) final { if (node->buffer.scope() == "global") has_global_access_ = true; - return arith::IRVisitorWithAnalyzer::VisitStmt_(node); + return StmtExprVisitor::VisitStmt_(node); } void VisitExpr_(const BufferLoadNode *node) final { if (node->buffer.scope() == "global") has_global_access_ = true; - return arith::IRVisitorWithAnalyzer::VisitExpr_(node); + return StmtExprVisitor::VisitExpr_(node); } }; -class VectorizePlanner : public arith::IRVisitorWithAnalyzer { +class VectorizePlanner : public arith::IRMutatorWithAnalyzer { public: - VectorizePlanner() = default; + VectorizePlanner() : arith::IRMutatorWithAnalyzer(new arith::Analyzer()) {} + explicit VectorizePlanner(arith::Analyzer *analyzer) + : arith::IRMutatorWithAnalyzer(analyzer) {} int Plan(const For &node) { tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current(); @@ -92,21 +94,31 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer { } private: - void VisitStmt_(const ForNode *node) final { + Stmt VisitStmt_(const ForNode *node) final { inner_for_ = node; - auto extent_ptr = as_const_int(analyzer_.Simplify(node->extent)); - // Here I disable dynamic shape completely, - // In order to do it, the Planner should accept an analyzer with - // arithmetic info outside to prove the dividiblity of vector size - if (!extent_ptr) { - vector_size_ = 1; - return; + bool contains_nested_for = false; + // Must analysis vectorization on the innermost loop + PostOrderVisit(Downcast(node), [&](const ObjectRef &obj) { + if (obj.as()) { + contains_nested_for = true; + } + }); + + if (!contains_nested_for) { + auto extent_ptr = as_const_int(analyzer_->Simplify(node->extent)); + // Here I disable dynamic shape completely, + // In order to do it, the Planner should accept an analyzer with + // arithmetic info outside to prove the dividiblity of vector size + if (!extent_ptr) { + vector_size_ = 1; + return 
ffi::GetRef(node); + } + vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr); } - vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr); - arith::IRVisitorWithAnalyzer::VisitStmt_(node); + return arith::IRMutatorWithAnalyzer::VisitStmt_(node); } - void VisitExpr_(const BufferLoadNode *node) final { + PrimExpr VisitExpr_(const BufferLoadNode *node) final { if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" || node->buffer.scope() == "shared.dyn") has_nonlocal_memory_access_ = true; @@ -115,43 +127,44 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer { // constant buffer that tl hack to use as local register. auto boundary_check = node->buffer->shape[0].as(); if (boundary_check && boundary_check->value == 1) { - return arith::IRVisitorWithAnalyzer::VisitExpr_(node); + return arith::IRMutatorWithAnalyzer::VisitExpr_(node); } } UpdateVectorSize(node->indices, node->buffer); + return arith::IRMutatorWithAnalyzer::VisitExpr_(node); } - void VisitStmt_(const BufferStoreNode *node) final { + Stmt VisitStmt_(const BufferStoreNode *node) final { if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" || node->buffer.scope() == "shared.dyn") has_nonlocal_memory_access_ = true; UpdateVectorSize(node->indices, node->buffer); - return arith::IRVisitorWithAnalyzer::VisitExpr(node->value); + return arith::IRMutatorWithAnalyzer::VisitStmt_(node); } - void VisitStmt_(const IfThenElseNode *node) final { + Stmt VisitStmt_(const IfThenElseNode *node) final { CheckConditionVectorized(node->condition); - return arith::IRVisitorWithAnalyzer::VisitStmt_(node); + return arith::IRMutatorWithAnalyzer::VisitStmt_(node); } - void VisitExpr_(const CallNode *node) final { + PrimExpr VisitExpr_(const CallNode *node) final { if (node->op == builtin::if_then_else()) { CheckConditionVectorized(node->args[0]); } else if (node->op == builtin::call_extern()) { // do not vectorize extern calls vector_size_ = 1; } - return 
arith::IRVisitorWithAnalyzer::VisitExpr_(node); + return arith::IRMutatorWithAnalyzer::VisitExpr_(node); } void CheckConditionVectorized(const PrimExpr &cond) { // TODO: perform some checks here } - void VisitExpr_(const CastNode *node) final { + PrimExpr VisitExpr_(const CastNode *node) final { vector_size_ = arith::ZeroAwareGCD( vector_load_bits_max_ / node->dtype.bits(), vector_size_); - return arith::IRVisitorWithAnalyzer::VisitExpr_(node); + return arith::IRMutatorWithAnalyzer::VisitExpr_(node); } void UpdateVectorSize(const Array indices, const Buffer &buffer) { @@ -171,19 +184,16 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer { for (int i = 0; i < indices.size(); ++i) { elem_offset += indices[i] * strides[i]; } - // 2. If element offset is independent with loop_var, ignore it - if (CanProveIndependent(elem_offset, inner_for_->loop_var, &analyzer_)) { + if (CanProveIndependent(elem_offset, inner_for_->loop_var, analyzer_)) { return; } - // 3. Tight vectorize bound vector_size_ = arith::ZeroAwareGCD(vector_size_, vector_load_bits_max_ / buffer->dtype.bits()); - // 4. Try to vectorize buffer load while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var, - inner_for_->extent, vector_size_, &analyzer_)) { + inner_for_->extent, vector_size_, analyzer_)) { vector_size_ /= 2; } } @@ -237,6 +247,10 @@ class VectorizeRewriter : public StmtExprMutator { int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); } +int GetVectorizeSize(const For &loop, arith::Analyzer *analyzer) { + return VectorizePlanner(analyzer).Plan(loop); +} + bool CanProveIndependent(const PrimExpr &expr, Var var, arith::Analyzer *analyzer) { // 1. 
if var doesn't exist, it is independent @@ -274,10 +288,10 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var, if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_size_for_iter), 0)) return false; - + auto simplified_expr = analyzer->Simplify(Substitute(expr, {{var, zero}})); // The base offset must be divisible - if (!analyzer->CanProveEqual( - FloorMod(Substitute(expr, {{var, zero}}), target_size_for_expr), 0)) { + if (!analyzer->CanProveEqual(FloorMod(simplified_expr, target_size_for_expr), + zero)) { return false; } @@ -317,5 +331,17 @@ For VectorizeLoop(const For &loop, int vectorize_hint) { return Downcast(rewriter(loop)); } +For VectorizeLoop(const For &loop, arith::Analyzer *analyzer, + int vectorize_hint) { + if (vectorize_hint <= 0) { + VectorizePlanner planner(analyzer); + vectorize_hint = planner.Plan(loop); + } + if (vectorize_hint == 1) + return loop; + auto rewriter = VectorizeRewriter(vectorize_hint); + return Downcast(rewriter(loop)); +} + } // namespace tl } // namespace tvm diff --git a/src/transform/loop_vectorize.h b/src/transform/loop_vectorize.h index 4ab20c668..a63c4b450 100644 --- a/src/transform/loop_vectorize.h +++ b/src/transform/loop_vectorize.h @@ -35,8 +35,13 @@ using namespace tir; int GetVectorizeSize(const For &loop); +int GetVectorizeSize(const For &loop, arith::Analyzer *analyzer); + For VectorizeLoop(const For &loop, int vectorize_hint = -1); +For VectorizeLoop(const For &loop, arith::Analyzer *analyzer, + int vectorize_hint = -1); + // Can prove expr is independent with var, i.e. 
the value of expr doesn't change // when var changes bool CanProveIndependent(const PrimExpr &expr, Var var, From 7484f05f8430d1b180c9a1f88fd98df57953aee2 Mon Sep 17 00:00:00 2001 From: Lei Wang Date: Fri, 21 Nov 2025 21:59:11 +0800 Subject: [PATCH 2/5] [Fix] Corrected PostOrderVisit call in loop_vectorize.cc - Updated the PostOrderVisit call to traverse the body of the loop node instead of the node itself, ensuring that only loops nested inside it are detected during vectorization analysis. --- src/transform/loop_vectorize.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index 627341dfe..3f314b1d8 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -98,7 +98,7 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { inner_for_ = node; bool contains_nested_for = false; // Must analysis vectorization on the innermost loop - PostOrderVisit(Downcast(node), [&](const ObjectRef &obj) { + PostOrderVisit(Downcast(node->body), [&](const ObjectRef &obj) { if (obj.as()) { contains_nested_for = true; } From be0d3192739c50191e0489be0948e8c50c49d316 Mon Sep 17 00:00:00 2001 From: Lei Wang Date: Fri, 21 Nov 2025 23:13:35 +0800 Subject: [PATCH 3/5] fix --- src/transform/loop_vectorize.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index 3f314b1d8..b9180bc75 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -72,7 +72,8 @@ class VectorizeFindGlobalAccess : public StmtExprVisitor { class VectorizePlanner : public arith::IRMutatorWithAnalyzer { public: - VectorizePlanner() : arith::IRMutatorWithAnalyzer(new arith::Analyzer()) {} + VectorizePlanner() + : arith::IRMutatorWithAnalyzer(&owned_analyzer_), owned_analyzer_() {} explicit VectorizePlanner(arith::Analyzer *analyzer) : arith::IRMutatorWithAnalyzer(analyzer) {} @@ -203,6 +204,9 @@ class 
VectorizePlanner : public arith::IRMutatorWithAnalyzer { const ForNode *inner_for_{}; bool has_nonlocal_memory_access_ = false; int vector_size_ = 128; + +private: + arith::Analyzer owned_analyzer_; }; class VectorizeRewriter : public StmtExprMutator { From ab44b7433efc5f6c3de9bd60f23c2e8440a5b49b Mon Sep 17 00:00:00 2001 From: Lei Wang Date: Fri, 21 Nov 2025 23:29:42 +0800 Subject: [PATCH 4/5] lint fix --- src/transform/loop_vectorize.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index b9180bc75..52b153ed3 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -204,7 +204,7 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { const ForNode *inner_for_{}; bool has_nonlocal_memory_access_ = false; int vector_size_ = 128; - + private: arith::Analyzer owned_analyzer_; }; From 6221326278a69f2d845ea6b37d37ed2d19433117 Mon Sep 17 00:00:00 2001 From: Lei Wang Date: Sat, 22 Nov 2025 00:13:57 +0800 Subject: [PATCH 5/5] fix --- src/transform/loop_vectorize.cc | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/transform/loop_vectorize.cc b/src/transform/loop_vectorize.cc index 52b153ed3..e8a18b004 100644 --- a/src/transform/loop_vectorize.cc +++ b/src/transform/loop_vectorize.cc @@ -72,8 +72,6 @@ class VectorizeFindGlobalAccess : public StmtExprVisitor { class VectorizePlanner : public arith::IRMutatorWithAnalyzer { public: - VectorizePlanner() - : arith::IRMutatorWithAnalyzer(&owned_analyzer_), owned_analyzer_() {} explicit VectorizePlanner(arith::Analyzer *analyzer) : arith::IRMutatorWithAnalyzer(analyzer) {} @@ -204,9 +202,6 @@ class VectorizePlanner : public arith::IRMutatorWithAnalyzer { const ForNode *inner_for_{}; bool has_nonlocal_memory_access_ = false; int vector_size_ = 128; - -private: - arith::Analyzer owned_analyzer_; }; class VectorizeRewriter : public StmtExprMutator { @@ -249,7 +244,10 @@ class 
VectorizeRewriter : public StmtExprMutator { const int vector_size_; }; -int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); } +int GetVectorizeSize(const For &loop) { + arith::Analyzer analyzer; + return VectorizePlanner(&analyzer).Plan(loop); +} int GetVectorizeSize(const For &loop, arith::Analyzer *analyzer) { return VectorizePlanner(analyzer).Plan(loop); @@ -326,7 +324,8 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var, For VectorizeLoop(const For &loop, int vectorize_hint) { if (vectorize_hint <= 0) { - VectorizePlanner planner; + arith::Analyzer analyzer; + VectorizePlanner planner(&analyzer); vectorize_hint = planner.Plan(loop); } if (vectorize_hint == 1)