Skip to content

Commit 291c047

Browse files
CharlieFRuan and tqchen authored
[TIR] Fix Bug in VectorizeLoop (#17039)
* [TIR] Fix Bug in VectorizeLoop This PR fixes a bug in VectorizeLoop that was introduced by a recent change. Visiting the condition of an IfThenElse can set the need_scalarize flag to true, and the subsequent visit to the then case can then trigger an ICHECK. Visiting the value of a let can also set the need_scalarize flag, in which case we need to scalarize immediately. * Add unit test --------- Co-authored-by: tqchen <[email protected]>
1 parent 71f7af7 commit 291c047

File tree

2 files changed

+37
-4
lines changed

2 files changed

+37
-4
lines changed

src/tir/transforms/vectorize_loop.cc

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -676,12 +676,16 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
676676
Stmt VisitStmt_(const IfThenElseNode* op) final {
677677
ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector());
678678
PrimExpr condition = this->VisitExpr(op->condition);
679+
// need scalarize can be marked as true during visit of condition
680+
bool cond_need_scalarize = false;
681+
std::swap(cond_need_scalarize, need_scalarize_);
682+
// temp clear need_scalarize flag, so VisitStmt
683+
// won't trigger an ICHECK error
679684
Stmt then_case = this->VisitStmt(op->then_case);
680685
Optional<Stmt> else_case = NullOpt;
681686
if (op->else_case) {
682687
else_case = this->VisitStmt(op->else_case.value());
683688
}
684-
685689
// Check if we can rewrite the condition with predicated buffers
686690
if (EnableBufferLevelPredication(target_) &&
687691
condition.dtype().is_scalable_or_fixed_length_vector() && !else_case.defined()) {
@@ -693,7 +697,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
693697
}
694698
}
695699

696-
if (condition.dtype().is_scalable_or_fixed_length_vector()) {
700+
if (cond_need_scalarize || condition.dtype().is_scalable_or_fixed_length_vector()) {
697701
return Scalarize(GetRef<Stmt>(op));
698702
}
699703
if (condition.same_as(op->condition) && then_case.same_as(op->then_case) &&
@@ -710,6 +714,12 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
710714
// LetStmt
711715
Stmt VisitStmt_(const LetStmtNode* op) final {
712716
PrimExpr value = this->VisitExpr(op->value);
717+
// if visit of value triggers need scalarize
718+
// we need to scalarize the let
719+
if (need_scalarize_) {
720+
need_scalarize_ = false;
721+
Scalarize(GetRef<Stmt>(op));
722+
}
713723
ICHECK(!let_binding_.count(op->var)) << "SSA violation, a single var is binded twice";
714724
let_binding_[op->var] = value;
715725

tests/python/tir-transform/test_tir_transform_vectorize.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
import pytest
18+
1719
import tvm
1820
import tvm.testing
1921
from tvm import te
2022
from tvm.script import ir as I
2123
from tvm.script import tir as T
22-
import pytest
23-
2424

2525
simple_target = tvm.target.Target("llvm -mtriple=x86_64-linux-gnu")
2626
sve_target = tvm.target.Target("llvm -device=arm_cpu -mtriple=aarch64-linux-gnu -mattr=+v8.2a,+sve")
@@ -312,6 +312,29 @@ def main(A: T.Buffer((25,), "float32"), n: T.int32):
312312
tvm.ir.assert_structural_equal(mod, After)
313313

314314

315+
def test_vectorize_let_if_then_else():
316+
@I.ir_module
317+
class Before:
318+
@T.prim_func
319+
def main():
320+
for i in T.vectorized(4):
321+
if i < 2:
322+
result: T.int32 = T.if_then_else(i < 1, 1, 2)
323+
324+
@I.ir_module
325+
class After:
326+
@T.prim_func
327+
def main():
328+
for i_s in range(4):
329+
if i_s < 2:
330+
result: T.int32 = T.if_then_else(i_s < 1, 1, 2)
331+
T.evaluate(0)
332+
333+
with tvm.target.Target(simple_target):
334+
mod = tvm.tir.transform.VectorizeLoop()(Before)
335+
tvm.ir.assert_structural_equal(mod, After)
336+
337+
315338
def test_vectorize_while_fail():
316339
"""A while loop inside a vectorized loop should fail."""
317340

0 commit comments

Comments
 (0)