Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions src/transform/vectorize_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,12 @@ class TLVectorizer : public StmtMutator,
// the mutator invocation pattern at call sites.
static Stmt Vectorize(const Var &var, const PrimExpr &var_lanes, Stmt body) {
TLVectorizer vec{var, var_lanes};
Stmt original_body = body;
auto vec_stmt = vec(std::move(body));
// If scalarization is needed, scalarize the entire original body
if (vec.need_scalarize_) {
return vec.Scalarize(original_body);
}
return vec_stmt;
}

Expand All @@ -224,15 +229,12 @@ class TLVectorizer : public StmtMutator,
}

Stmt VisitStmt(const Stmt &stmt) final {
ICHECK(!need_scalarize_);
Stmt ret = StmtMutator::VisitStmt(stmt);
// If scalarization is already needed, return original stmt unchanged
// to let the top-level Vectorize handle it
if (need_scalarize_) {
auto scalarized_stmt = Scalarize(stmt);
need_scalarize_ = false;
return scalarized_stmt;
} else {
return ret;
return stmt;
}
return StmtMutator::VisitStmt(stmt);
}

PrimExpr VisitExpr(const PrimExpr &e) final {
Expand Down
Loading