Skip to content

Commit d8e39fd

Browse files
authored
[TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes (#10172)
[TIR] Fix Ramp int32~64 mismatch in VectorizeLoop and NarrowDataType passes
1 parent d8d28bf commit d8e39fd

File tree

4 files changed

+51
-2
lines changed

4 files changed

+51
-2
lines changed

src/tir/transforms/narrow_datatype.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,23 @@ class DataTypeRewriter : public StmtExprMutator {
253253
return StmtExprMutator::VisitExpr_(op);
254254
}
255255

256+
PrimExpr VisitExpr_(const RampNode* op) final {
257+
PrimExpr base = VisitExpr(op->base);
258+
PrimExpr stride = VisitExpr(op->stride);
259+
if (base.same_as(op->base) && stride.same_as(op->stride)) {
260+
return GetRef<PrimExpr>(op);
261+
} else {
262+
if (base.dtype().is_int()) {
263+
ICHECK(stride.dtype().is_int()) << "Ramp base is int but stride is " << stride.dtype();
264+
int bits = std::max(base.dtype().bits(), stride.dtype().bits());
265+
DataType dtype = base.dtype().with_bits(bits);
266+
if (base.dtype() != dtype) base = cast(dtype, base);
267+
if (stride.dtype() != dtype) stride = cast(dtype, stride);
268+
}
269+
return Ramp(base, stride, op->lanes);
270+
}
271+
}
272+
256273
PrimExpr VisitExpr_(const SizeVarNode* op) final {
257274
if (visitor_.vmap.find(op) != visitor_.vmap.end()) {
258275
if (vmap_.find(op) == vmap_.end()) {

src/tir/transforms/vectorize_loop.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
101101
using StmtMutator::operator();
102102

103103
Vectorizer(Var var, int var_lanes) : var_(var), var_lanes_(var_lanes) {
104-
ramp_ = Ramp(0, 1, var_lanes);
104+
ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes);
105105
}
106106

107107
Stmt VisitStmt(const Stmt& stmt) final {

tests/python/unittest/test_tir_transform_narrow_datatype.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def lower_stmt(params, stmt, target_bits):
2727
return stmt
2828

2929

30-
def lower_sch(sch, args, target_bits):
30+
def lower_sch(sch, args, target_bits, extra_passes=None):
3131
binds = {}
3232
arg_list = []
3333
for x in args:
@@ -42,6 +42,9 @@ def lower_sch(sch, args, target_bits):
4242

4343
mod = schedule_to_module(sch, args)
4444
mod = tvm.tir.transform.StorageFlatten(64)(mod)
45+
if extra_passes:
46+
for p in extra_passes:
47+
mod = p(mod)
4548
return tvm.tir.transform.NarrowDataType(target_bits)(mod)["main"].body
4649

4750

@@ -255,6 +258,25 @@ def check(shape, index, target_bits, target_dtype):
255258
)
256259

257260

261+
def test_ramp_dtype_consistency():
262+
"""
263+
for (i :int64, (int64)0, (int64)4) {
264+
A[ramp(i*(int64)2, (int64)1, 2)] = cast(int64, 2 ** 31 - 1) * i;
265+
}
266+
The infer result:
267+
base: int64 -> int64 (since i is involved in another int64 expr)
268+
stride: int64 -> int32
269+
270+
Thus ramp should still use int64 for both stride and base after rewrite.
271+
"""
272+
n = tvm.tir.IntImm("int64", 4)
273+
m = tvm.tir.IntImm("int64", 2)
274+
A = te.compute((n, m), lambda i, j: tvm.tir.Cast("int64", 2 ** 31 - 1) * i, name="A")
275+
s = te.create_schedule(A.op)
276+
s[A].vectorize(A.op.axis[1])
277+
lower_sch(s, [A], 32, extra_passes=[tvm.tir.transform.VectorizeLoop()])
278+
279+
258280
if __name__ == "__main__":
259281
test_basic()
260282
test_thread_axis()
@@ -263,3 +285,4 @@ def check(shape, index, target_bits, target_dtype):
263285
test_slice()
264286
test_relay_basic()
265287
test_relay_take()
288+
test_ramp_dtype_consistency()

tests/python/unittest/test_tir_transform_vectorize.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,14 @@ def test_ir(A, B, C):
205205
assert expected in error_msg
206206

207207

208+
def test_vectorize_dtype_mismatch():
209+
n = tvm.tir.IntImm("int64", 4)
210+
A = te.compute((n,), lambda i: tvm.tir.IntImm("int64", 2 ** 31 - 1) + i, name="A")
211+
s = te.create_schedule(A.op)
212+
s[A].vectorize(A.op.axis[0])
213+
tvm.lower(s, [A], "llvm", simple_mode=True)
214+
215+
208216
if __name__ == "__main__":
209217
test_vectorize_vector()
210218
test_vectorize_with_if()
@@ -214,3 +222,4 @@ def test_ir(A, B, C):
214222
test_vectorize_with_ge_cond()
215223
test_vectorize_let()
216224
test_vectorize_while_fail()
225+
test_vectorize_dtype_mismatch()

0 commit comments

Comments
 (0)