Skip to content

Commit 7f6da09

Browse files
zxybazhSunghyun Park
andauthored
[TIR] Fix Datatype in Lower TVM Builtin (#14347)
Fix data type and add minimal reproducible test. Co-authored-by: Sunghyun Park <[email protected]>
1 parent 36b3097 commit 7f6da09

File tree

2 files changed

+27
-9
lines changed

2 files changed

+27
-9
lines changed

src/tir/transforms/lower_tvm_builtin.cc

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,9 @@ class BuiltinLower : public StmtExprMutator {
239239
}
240240
}
241241
}
242-
PrimExpr total_bytes = make_const(op->extents[0].dtype(), nbytes);
242+
PrimExpr total_bytes = make_const(DataType::UInt(64), nbytes);
243243
for (size_t i = 0; i < op->extents.size(); ++i) {
244+
// set total_bytes to uint64 to avoid overflow
244245
total_bytes = total_bytes * op->extents[i];
245246
}
246247
ICHECK(device_type_.defined()) << "Unknown device type in current IR";
@@ -250,13 +251,13 @@ class BuiltinLower : public StmtExprMutator {
250251
Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}),
251252
throw_last_error),
252253
op->body});
253-
Stmt alloca = LetStmt(
254-
op->buffer_var,
255-
Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"),
256-
{cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_),
257-
cast(DataType::UInt(64), total_bytes), IntImm(DataType::Int(32), op->dtype.code()),
258-
IntImm(DataType::Int(32), op->dtype.bits())}),
259-
body);
254+
Stmt alloca =
255+
LetStmt(op->buffer_var,
256+
Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"),
257+
{cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_),
258+
total_bytes, IntImm(DataType::Int(32), op->dtype.code()),
259+
IntImm(DataType::Int(32), op->dtype.bits())}),
260+
body);
260261

261262
PrimExpr free_op = Call(DataType::Int(32), Op::Get("tir.TVMBackendFreeWorkspace"),
262263
{cast(DataType::Int(32), device_type_),

tests/python/unittest/test_tir_transform_lower_tvm_builtin.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
# under the License.
1717
import tvm
1818
from tvm import te
19+
from tvm.script import tir as T
1920
import numpy as np
20-
from tvm import testing
2121

2222

2323
@tvm.register_func("tvm.test_matmul")
@@ -172,6 +172,23 @@ def build_tir():
172172
tvm.testing.assert_allclose(a.numpy(), expected_value)
173173

174174

175+
def test_lower_overflow_int32():
176+
@T.prim_func
177+
def variance4(rxplaceholder: T.Buffer((T.int64(1), T.int64(32), T.int64(25690112)), "float32")):
178+
T.func_attr({"global_symbol": "variance4", "tir.noalias": True})
179+
rxplaceholder_red = T.allocate([32], "float32", "global")
180+
T_subtract = T.allocate([822083584], "float32", "global")
181+
rxplaceholder_red_1 = T.Buffer((T.int64(32),), data=rxplaceholder_red)
182+
rxplaceholder_1 = T.Buffer((T.int64(822083584),), data=rxplaceholder.data)
183+
T_subtract_1 = T.Buffer((T.int64(822083584),), data=T_subtract)
184+
for ax1, ax2 in T.grid(32, 25690112):
185+
cse_var_1: T.int32 = ax1 * 25690112 + ax2
186+
T_subtract_1[cse_var_1] = rxplaceholder_1[cse_var_1] - rxplaceholder_red_1[ax1]
187+
188+
func = variance4
189+
tvm.build(func, target="llvm") # should not crash
190+
191+
175192
if __name__ == "__main__":
176193
test_call_packed_return_non_i32()
177194
test_lower_packed_func()

0 commit comments

Comments
 (0)