Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from bc31e7 to cd2b2b
4 changes: 2 additions & 2 deletions src/op/copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -852,7 +852,7 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T,
auto par_op = ParallelOp(transformed_loop);

if (is_cpu_target) {
vectorized_thread_loop = VectorizeLoop(transformed_loop);
vectorized_thread_loop = VectorizeLoop(transformed_loop, analyzer);
} else {
std::vector<InferLevel> levels = {InferLevel::kCommon, InferLevel::kStrict,
InferLevel::kFree};
Expand All @@ -865,7 +865,7 @@ Stmt CopyNode::LowerNormalCopy(const LowerArgs &T,
auto thread_var = T.thread_var;
auto thread_loop =
PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer, loop_layout);
vectorized_thread_loop = VectorizeLoop(thread_loop);
vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer);
}

if (par_op->GetPredicate(T.thread_var).defined()) {
Expand Down
6 changes: 3 additions & 3 deletions src/op/fill.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,15 +207,15 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
InferLevel::kFree);
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
par_op->GetLoopLayout());
auto vectorized_thread_loop = VectorizeLoop(thread_loop);
auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer);
if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
vectorized_thread_loop);
}
return vectorized_thread_loop;
} else if (dst.scope() == "local") {
auto init_loop = MakeSIMTLoop(analyzer);
auto vectorized_thread_loop = VectorizeLoop(init_loop);
auto vectorized_thread_loop = VectorizeLoop(init_loop, analyzer);
return vectorized_thread_loop;
} else if (dst.scope() == "shared.dyn" || dst.scope() == "shared" ||
dst.scope() == "global") {
Expand All @@ -225,7 +225,7 @@ Stmt FillNode::Lower(const LowerArgs &T, arith::Analyzer *analyzer) const {
InferLevel::kFree);
auto thread_loop = PartitionLoop(par_op->GetRoot(), T.thread_var, analyzer,
par_op->GetLoopLayout());
auto vectorized_thread_loop = VectorizeLoop(thread_loop);
auto vectorized_thread_loop = VectorizeLoop(thread_loop, analyzer);
if (par_op->GetPredicate(T.thread_var).defined()) {
return IfThenElse(par_op->GetPredicate(T.thread_var).value(),
vectorized_thread_loop);
Expand Down
3 changes: 1 addition & 2 deletions src/op/parallel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -452,8 +452,7 @@ LayoutMap ParallelOpNode::InferLayout(const LayoutInferArgs &T,
// As the pass will do post processing to the layout
auto maybe_remapped_root_ =
IfBufferRemapLoopGenerator::run(root_, T.buffer_remap, T.layout_map);
int vector_size = GetVectorizeSize(maybe_remapped_root_);

int vector_size = GetVectorizeSize(maybe_remapped_root_, T.analyzer);
DLOG(INFO) << "[PlanLoopPartition] vector_size = " << vector_size << '\n';

PrimExpr loop_total_size = 1;
Expand Down
12 changes: 10 additions & 2 deletions src/transform/layout_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <tvm/tir/utils.h>

#include <algorithm>
#include <memory>
#include <queue>

#include "../layout/utils.h"
Expand Down Expand Up @@ -85,6 +86,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
auto &next = infer_list_[cur_infer_id];
auto iter_var = thread_var_vec_[cur_infer_id];
auto thread_bounds = thread_bounds_vec_[cur_infer_id];
arith::Analyzer *cur_analyzer = analyzer_vec_[cur_infer_id].get();
auto buffer_oob = buffer_oob_vec_[cur_infer_id];
// Double-check that 'next' is valid
ICHECK(next.defined()) << "infer_list_[" << cur_infer_id
Expand All @@ -108,7 +110,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
// Run InferLayout
auto updates =
next->InferLayout(LayoutInferArgs{target_, thread_bounds, layout_map,
&analyzer_, buffer_oob},
cur_analyzer, buffer_oob},
level);
// Process the returned updates
for (const auto &[buffer, layout] : updates) {
Expand Down Expand Up @@ -266,6 +268,9 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
ICHECK_EQ(thread_bounds_vec_.size(), infer_list_.size())
<< "Size mismatch: thread_bounds_vec_ and infer_list_ must match in "
"length.";
ICHECK_EQ(analyzer_vec_.size(), infer_list_.size())
<< "Size mismatch: analyzer_vec_ and infer_list_ must match in "
"length.";
ICHECK_EQ(buffer_oob_vec_.size(), infer_list_.size())
<< "Size mismatch: buffer_oob_vec_ and infer_list_ must match in "
"length.";
Expand Down Expand Up @@ -452,6 +457,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
} else {
thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
}
analyzer_vec_.push_back(analyzer_.Clone());

// Compute buffer oob for each buffer in the op
if (const auto *copy = p.as<CopyNode>()) {
Expand Down Expand Up @@ -542,6 +548,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
} else {
thread_bounds_vec_.push_back(Range::FromMinExtent(0, 1));
}
analyzer_vec_.push_back(analyzer_.Clone());
buffer_oob_vec_.push_back(false);
} else {
IRVisitorWithAnalyzer::VisitStmt(op->body);
Expand Down Expand Up @@ -683,6 +690,7 @@ class BufferUseDefCollector : public IRVisitorWithAnalyzer {
IterVarType::kDataPar);
std::vector<IterVar> thread_var_vec_;
std::vector<Range> thread_bounds_vec_;
std::vector<std::unique_ptr<arith::Analyzer>> analyzer_vec_;
std::vector<bool> buffer_oob_vec_;
Target target_;
LayoutMap annotated_layout_map_;
Expand Down Expand Up @@ -1024,7 +1032,7 @@ class LayoutInferencer : public IRMutatorWithAnalyzer {
});

if ((has_non_local || has_cast_operations) && !has_reducer) {
for_node = VectorizeLoop(for_node);
for_node = VectorizeLoop(for_node, analyzer_);
}

if (result_.predicate_map.count(root) && parallel_loop) {
Expand Down
2 changes: 1 addition & 1 deletion src/transform/legalize_vectorized_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class LoopVectorizedLegalizer : IRMutatorWithAnalyzer {
// Change the loop kind from vectorized to serial
for_node.CopyOnWrite()->kind = ForKind::kSerial;
// Apply vectorization transformation to the loop
return VectorizeLoop(for_node);
return VectorizeLoop(for_node, analyzer_);
}
};

Expand Down
99 changes: 64 additions & 35 deletions src/transform/loop_vectorize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ struct VectorizePlanResult {
PrimExpr condition;
};

class VectorizeFindGlobalAccess : public arith::IRVisitorWithAnalyzer {
class VectorizeFindGlobalAccess : public StmtExprVisitor {
public:
VectorizeFindGlobalAccess() = default;

Expand All @@ -60,19 +60,20 @@ class VectorizeFindGlobalAccess : public arith::IRVisitorWithAnalyzer {
void VisitStmt_(const BufferStoreNode *node) final {
if (node->buffer.scope() == "global")
has_global_access_ = true;
return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
return StmtExprVisitor::VisitStmt_(node);
}

void VisitExpr_(const BufferLoadNode *node) final {
if (node->buffer.scope() == "global")
has_global_access_ = true;
return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
return StmtExprVisitor::VisitExpr_(node);
}
};

class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
class VectorizePlanner : public arith::IRMutatorWithAnalyzer {
public:
VectorizePlanner() = default;
explicit VectorizePlanner(arith::Analyzer *analyzer)
: arith::IRMutatorWithAnalyzer(analyzer) {}

int Plan(const For &node) {
tvm::transform::PassContext ctxt = tvm::transform::PassContext::Current();
Expand All @@ -92,21 +93,31 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
}

private:
void VisitStmt_(const ForNode *node) final {
Stmt VisitStmt_(const ForNode *node) final {
inner_for_ = node;
auto extent_ptr = as_const_int(analyzer_.Simplify(node->extent));
// Here I disable dynamic shape completely,
// In order to do it, the Planner should accept an analyzer with
// arithmetic info outside to prove the dividiblity of vector size
if (!extent_ptr) {
vector_size_ = 1;
return;
bool contains_nested_for = false;
// Must analysis vectorization on the innermost loop
PostOrderVisit(Downcast<Stmt>(node->body), [&](const ObjectRef &obj) {
if (obj.as<ForNode>()) {
Copy link

Copilot AI Nov 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic for detecting nested loops incorrectly includes the current ForNode itself. PostOrderVisit will visit node and set contains_nested_for = true even when there are no nested loops. This causes the vectorization analysis to be skipped for all loops. Consider checking if the visited ForNode is different from node before setting the flag.

Suggested change
if (obj.as<ForNode>()) {
const ForNode* for_node = obj.as<ForNode>();
if (for_node && for_node != node) {

Copilot uses AI. Check for mistakes.
contains_nested_for = true;
}
});

if (!contains_nested_for) {
auto extent_ptr = as_const_int(analyzer_->Simplify(node->extent));
// Here I disable dynamic shape completely,
// In order to do it, the Planner should accept an analyzer with
// arithmetic info outside to prove the dividiblity of vector size
if (!extent_ptr) {
vector_size_ = 1;
return ffi::GetRef<Stmt>(node);
}
vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr);
}
vector_size_ = arith::ZeroAwareGCD(vector_size_, *extent_ptr);
arith::IRVisitorWithAnalyzer::VisitStmt_(node);
return arith::IRMutatorWithAnalyzer::VisitStmt_(node);
}

void VisitExpr_(const BufferLoadNode *node) final {
PrimExpr VisitExpr_(const BufferLoadNode *node) final {
if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
node->buffer.scope() == "shared.dyn")
has_nonlocal_memory_access_ = true;
Expand All @@ -115,43 +126,44 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
// constant buffer that tl hack to use as local register.
auto boundary_check = node->buffer->shape[0].as<IntImmNode>();
if (boundary_check && boundary_check->value == 1) {
return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
}
}
UpdateVectorSize(node->indices, node->buffer);
return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
}

void VisitStmt_(const BufferStoreNode *node) final {
Stmt VisitStmt_(const BufferStoreNode *node) final {
if (node->buffer.scope() == "shared" || node->buffer.scope() == "global" ||
node->buffer.scope() == "shared.dyn")
has_nonlocal_memory_access_ = true;
UpdateVectorSize(node->indices, node->buffer);
return arith::IRVisitorWithAnalyzer::VisitExpr(node->value);
return arith::IRMutatorWithAnalyzer::VisitStmt_(node);
}

void VisitStmt_(const IfThenElseNode *node) final {
Stmt VisitStmt_(const IfThenElseNode *node) final {
CheckConditionVectorized(node->condition);
return arith::IRVisitorWithAnalyzer::VisitStmt_(node);
return arith::IRMutatorWithAnalyzer::VisitStmt_(node);
}

void VisitExpr_(const CallNode *node) final {
PrimExpr VisitExpr_(const CallNode *node) final {
if (node->op == builtin::if_then_else()) {
CheckConditionVectorized(node->args[0]);
} else if (node->op == builtin::call_extern()) {
// do not vectorize extern calls
vector_size_ = 1;
}
return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
}

void CheckConditionVectorized(const PrimExpr &cond) {
// TODO: perform some checks here
}

void VisitExpr_(const CastNode *node) final {
PrimExpr VisitExpr_(const CastNode *node) final {
vector_size_ = arith::ZeroAwareGCD(
vector_load_bits_max_ / node->dtype.bits(), vector_size_);
return arith::IRVisitorWithAnalyzer::VisitExpr_(node);
return arith::IRMutatorWithAnalyzer::VisitExpr_(node);
}

void UpdateVectorSize(const Array<PrimExpr> indices, const Buffer &buffer) {
Expand All @@ -171,19 +183,16 @@ class VectorizePlanner : public arith::IRVisitorWithAnalyzer {
for (int i = 0; i < indices.size(); ++i) {
elem_offset += indices[i] * strides[i];
}

// 2. If element offset is independent with loop_var, ignore it
if (CanProveIndependent(elem_offset, inner_for_->loop_var, &analyzer_)) {
if (CanProveIndependent(elem_offset, inner_for_->loop_var, analyzer_)) {
return;
}

// 3. Tight vectorize bound
vector_size_ = arith::ZeroAwareGCD(vector_size_, vector_load_bits_max_ /
buffer->dtype.bits());

// 4. Try to vectorize buffer load
while (!IndiceCanVectorize(elem_offset, inner_for_->loop_var,
inner_for_->extent, vector_size_, &analyzer_)) {
inner_for_->extent, vector_size_, analyzer_)) {
vector_size_ /= 2;
}
}
Expand Down Expand Up @@ -235,7 +244,14 @@ class VectorizeRewriter : public StmtExprMutator {
const int vector_size_;
};

int GetVectorizeSize(const For &loop) { return VectorizePlanner().Plan(loop); }
int GetVectorizeSize(const For &loop) {
arith::Analyzer analyzer;
return VectorizePlanner(&analyzer).Plan(loop);
}

int GetVectorizeSize(const For &loop, arith::Analyzer *analyzer) {
return VectorizePlanner(analyzer).Plan(loop);
}

bool CanProveIndependent(const PrimExpr &expr, Var var,
arith::Analyzer *analyzer) {
Expand Down Expand Up @@ -274,10 +290,10 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var,
if (!analyzer->CanProveEqual(FloorMod(iter_var_size, target_size_for_iter),
0))
return false;

auto simplified_expr = analyzer->Simplify(Substitute(expr, {{var, zero}}));
// The base offset must be divisible
if (!analyzer->CanProveEqual(
FloorMod(Substitute(expr, {{var, zero}}), target_size_for_expr), 0)) {
if (!analyzer->CanProveEqual(FloorMod(simplified_expr, target_size_for_expr),
zero)) {
Copy link

Copilot AI Nov 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comparing FloorMod(simplified_expr, target_size_for_expr) against zero is incorrect. The result of FloorMod has the same dtype as simplified_expr (i.e., expr.dtype()), but zero has type var.dtype(). These types may differ, causing the comparison to fail. Use make_const(expr.dtype(), 0) instead of zero.

Suggested change
zero)) {
make_const(expr.dtype(), 0))) {

Copilot uses AI. Check for mistakes.
return false;
}

Expand Down Expand Up @@ -308,7 +324,20 @@ bool IndiceCanVectorize(const PrimExpr &expr, Var var,

For VectorizeLoop(const For &loop, int vectorize_hint) {
if (vectorize_hint <= 0) {
VectorizePlanner planner;
arith::Analyzer analyzer;
VectorizePlanner planner(&analyzer);
vectorize_hint = planner.Plan(loop);
}
if (vectorize_hint == 1)
return loop;
auto rewriter = VectorizeRewriter(vectorize_hint);
return Downcast<For>(rewriter(loop));
}

For VectorizeLoop(const For &loop, arith::Analyzer *analyzer,
int vectorize_hint) {
if (vectorize_hint <= 0) {
VectorizePlanner planner(analyzer);
vectorize_hint = planner.Plan(loop);
}
if (vectorize_hint == 1)
Expand Down
5 changes: 5 additions & 0 deletions src/transform/loop_vectorize.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,13 @@ using namespace tir;

int GetVectorizeSize(const For &loop);

int GetVectorizeSize(const For &loop, arith::Analyzer *analyzer);

For VectorizeLoop(const For &loop, int vectorize_hint = -1);

For VectorizeLoop(const For &loop, arith::Analyzer *analyzer,
int vectorize_hint = -1);

// Can prove expr is independent with var, i.e. the value of expr doesn't change
// when var changes
bool CanProveIndependent(const PrimExpr &expr, Var var,
Expand Down
Loading