Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,9 @@ class BuiltinLower : public StmtExprMutator {
}
}
}
PrimExpr total_bytes = make_const(op->extents[0].dtype(), nbytes);
PrimExpr total_bytes = make_const(DataType::UInt(64), nbytes);
for (size_t i = 0; i < op->extents.size(); ++i) {
// set total_bytes to uint64 to avoid overflow
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

possible for us to also add a guard check here to verify that total_bytes isn't negative? although...i guess it is kinda unlikely to alloc > (1 << 31) worth of space.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think technically it's possible to allocate 2G of memory and also possible to have an intermediate large result between passes. On the other hand each loop extent are supposed to be positive so unless overflow it's unlikely to produce any negative results. I also tried locally to multiply uint64 with negative int32, turns out it will trigger an error. Therefore I think we are good with current change here.

total_bytes = total_bytes * op->extents[i];
}
ICHECK(device_type_.defined()) << "Unknown device type in current IR";
Expand All @@ -250,13 +251,13 @@ class BuiltinLower : public StmtExprMutator {
Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}),
throw_last_error),
op->body});
Stmt alloca = LetStmt(
op->buffer_var,
Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"),
{cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_),
cast(DataType::UInt(64), total_bytes), IntImm(DataType::Int(32), op->dtype.code()),
IntImm(DataType::Int(32), op->dtype.bits())}),
body);
Stmt alloca =
LetStmt(op->buffer_var,
Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"),
{cast(DataType::Int(32), device_type_), cast(DataType::Int(32), device_id_),
total_bytes, IntImm(DataType::Int(32), op->dtype.code()),
IntImm(DataType::Int(32), op->dtype.bits())}),
body);

PrimExpr free_op = Call(DataType::Int(32), Op::Get("tir.TVMBackendFreeWorkspace"),
{cast(DataType::Int(32), device_type_),
Expand Down
19 changes: 18 additions & 1 deletion tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
# under the License.
import tvm
from tvm import te
from tvm.script import tir as T
import numpy as np
from tvm import testing


@tvm.register_func("tvm.test_matmul")
Expand Down Expand Up @@ -172,6 +172,23 @@ def build_tir():
tvm.testing.assert_allclose(a.numpy(), expected_value)


def test_lower_overflow_int32():
@T.prim_func
def variance4(rxplaceholder: T.Buffer((T.int64(1), T.int64(32), T.int64(25690112)), "float32")):
T.func_attr({"global_symbol": "variance4", "tir.noalias": True})
rxplaceholder_red = T.allocate([32], "float32", "global")
T_subtract = T.allocate([822083584], "float32", "global")
rxplaceholder_red_1 = T.Buffer((T.int64(32),), data=rxplaceholder_red)
rxplaceholder_1 = T.Buffer((T.int64(822083584),), data=rxplaceholder.data)
T_subtract_1 = T.Buffer((T.int64(822083584),), data=T_subtract)
for ax1, ax2 in T.grid(32, 25690112):
cse_var_1: T.int32 = ax1 * 25690112 + ax2
T_subtract_1[cse_var_1] = rxplaceholder_1[cse_var_1] - rxplaceholder_red_1[ax1]

func = variance4
tvm.build(func, target="llvm") # should not crash


if __name__ == "__main__":
test_call_packed_return_non_i32()
test_lower_packed_func()