From cc1dde66e6d9af58b313b25129eb629acd526099 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 29 Dec 2025 20:07:28 +0800 Subject: [PATCH] [Refactor] Improve scalarization handling in vectorization logic * Enhanced the Vectorize function to retain the original body for scalarization if needed. * Updated VisitStmt and VisitStmt_ methods to return the original statement when scalarization is required, ensuring proper handling of statements during vectorization. * Added checks after visiting expressions and statements to maintain the integrity of the original structure when scalarization is triggered. This refactor aims to streamline the vectorization process and improve the handling of scalarization scenarios. --- src/transform/vectorize_loop.cc | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/transform/vectorize_loop.cc b/src/transform/vectorize_loop.cc index a7b31e1d7..7eb5da87a 100644 --- a/src/transform/vectorize_loop.cc +++ b/src/transform/vectorize_loop.cc @@ -214,7 +214,12 @@ class TLVectorizer : public StmtMutator, // the mutator invocation pattern at call sites. static Stmt Vectorize(const Var &var, const PrimExpr &var_lanes, Stmt body) { TLVectorizer vec{var, var_lanes}; + Stmt original_body = body; auto vec_stmt = vec(std::move(body)); + // If scalarization is needed, scalarize the entire original body + if (vec.need_scalarize_) { + return vec.Scalarize(original_body); + } return vec_stmt; } @@ -224,15 +229,12 @@ class TLVectorizer : public StmtMutator, } Stmt VisitStmt(const Stmt &stmt) final { - ICHECK(!need_scalarize_); - Stmt ret = StmtMutator::VisitStmt(stmt); + // If scalarization is already needed, return original stmt unchanged + // to let the top-level Vectorize handle it if (need_scalarize_) { - auto scalarized_stmt = Scalarize(stmt); - need_scalarize_ = false; - return scalarized_stmt; - } else { - return ret; + return stmt; } + return StmtMutator::VisitStmt(stmt); } PrimExpr VisitExpr(const PrimExpr &e) final {