vectorize_loop: scalarize the whole loop body when vectorization bails out

Summary of what this patch changes (derived from the hunks below):

  * TLVectorizer::Vectorize keeps a second handle to the incoming body
    before handing it to the vectorizer. If need_scalarize_ is set once
    the visit returns, the vectorized result is dropped and
    Scalarize(original_body) is returned instead.
  * TLVectorizer::VisitStmt no longer asserts ICHECK(!need_scalarize_) on
    entry, and no longer scalarizes the individual statement and resets
    the flag. Once need_scalarize_ is set it returns the statement it was
    given unchanged; otherwise it delegates to StmtMutator::VisitStmt.

NOTE(review): line breaks were reconstructed from the hunk line counts
(hunk 1: 7 context + 5 added; hunk 2: 8 context + 7 removed + 4 added).
Indentation is assumed to be the 2-space style - confirm against the
target file, or apply with whitespace-insensitive matching.

diff --git a/src/transform/vectorize_loop.cc b/src/transform/vectorize_loop.cc
index a7b31e1d7..7eb5da87a 100644
--- a/src/transform/vectorize_loop.cc
+++ b/src/transform/vectorize_loop.cc
@@ -214,7 +214,12 @@ class TLVectorizer : public StmtMutator,
   // the mutator invocation pattern at call sites.
   static Stmt Vectorize(const Var &var, const PrimExpr &var_lanes, Stmt body) {
     TLVectorizer vec{var, var_lanes};
+    Stmt original_body = body;
     auto vec_stmt = vec(std::move(body));
+    // If scalarization is needed, scalarize the entire original body
+    if (vec.need_scalarize_) {
+      return vec.Scalarize(original_body);
+    }
     return vec_stmt;
   }
 
@@ -224,15 +229,12 @@ class TLVectorizer : public StmtMutator,
   }
 
   Stmt VisitStmt(const Stmt &stmt) final {
-    ICHECK(!need_scalarize_);
-    Stmt ret = StmtMutator::VisitStmt(stmt);
+    // If scalarization is already needed, return original stmt unchanged
+    // to let the top-level Vectorize handle it
     if (need_scalarize_) {
-      auto scalarized_stmt = Scalarize(stmt);
-      need_scalarize_ = false;
-      return scalarized_stmt;
-    } else {
-      return ret;
+      return stmt;
     }
+    return StmtMutator::VisitStmt(stmt);
   }
 
   PrimExpr VisitExpr(const PrimExpr &e) final {