12 changes: 8 additions & 4 deletions src/relax/transform/static_plan_block_memory.cc

@@ -828,15 +828,19 @@ class StorageAllocationRewriter : public ExprMutator {
     const auto* shape = sinfo->shape.as<ShapeExprNode>();
     ICHECK_NOTNULL(shape);
     Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, &ana_);
-    if (!IsStaticShape(shape->values) && IsStaticShape(upper_bounded_shape)) {
+    if (!IsStaticShape(shape->values)) {
       ICHECK(!sinfo->IsUnknownDtype());
       ICHECK_EQ(sinfo->dtype, Downcast<DataTypeImm>(call->args[1])->value);
-      StorageToken token(upper_bounded_shape, sinfo->dtype);
+      PrimExpr bytes = upper_bounded_shape[0];
+      for (int i = 1; i < static_cast<int>(upper_bounded_shape.size()); ++i) {
+        bytes *= upper_bounded_shape[i];
+      }
+      bytes *= sinfo->dtype.bytes() * sinfo->dtype.lanes();
       Call alloc_storage(mem_alloc_storage,
-                         {/*size=*/ShapeExpr({tvm::IntImm(DataType::Int(64), token->bytes)}),
+                         {/*size=*/ShapeExpr({bytes}),
                           /*virtual_device_index=*/Downcast<PrimValue>(call->args[2]),
                           /*storage_scope=*/StringImm("global"),  //
-                          /*dtype=*/DataTypeImm(token->dtype)});
+                          /*dtype=*/DataTypeImm(sinfo->dtype)});
       Var storage = builder_->Emit(alloc_storage, "storage");
       return Call(mem_alloc_tensor, {storage,  //
                                      /*offset=*/PrimValue::Int64(0),
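For reference, a minimal sketch (not TVM code) of the size expression the rewritten branch now builds: the upper-bounded extents are multiplied together and then scaled by the element size, so a dimension with no registered upper bound stays symbolic in the result. With `n` bounded by 20, `m` unbounded, and `float32` elements (4 bytes), this produces the `20 * m * 4` storage size asserted in the new test below. `storage_size_expr` is a hypothetical helper used only for illustration:

```python
# Hypothetical helper (not part of TVM) mirroring the expression built by
# StorageAllocationRewriter: prod(upper_bounded_shape) * dtype.bytes() * lanes.
# Bounded extents are ints; still-symbolic extents are represented as strings.
def storage_size_expr(upper_bounded_shape, dtype_bytes, dtype_lanes=1):
    factors = [str(extent) for extent in upper_bounded_shape]
    factors.append(str(dtype_bytes * dtype_lanes))
    return " * ".join(factors)

# n has upper bound 20, m has none, float32 is 4 bytes per lane:
assert storage_size_expr([20, "m"], dtype_bytes=4) == "20 * m * 4"
```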
64 changes: 64 additions & 0 deletions tests/python/relax/test_transform_static_plan_block_memory.py

@@ -1171,6 +1171,70 @@ def main(s: R.Shape(["n"])) -> R.Tensor(("n",), dtype="float32"):
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_call_tir_dyn_plan_partially_dynamic():
+    # fmt: off
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def tir_full(var_full: T.handle, n: T.int64, m: T.int64):
+            T.evaluate(0)
+
+        @T.prim_func
+        def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(s: R.Shape(["n", "m"])) -> R.Tensor(("n", "m"), dtype="float32"):
+            n = T.int64()
+            m = T.int64()
+            R.func_attr({"tir_var_upper_bound": {"n": 20}, "relax.force_pure": True, "relax.memory_plan_dynamic_func_output": True})
+            cls = Module
+            alloc: R.Tensor((n, m), dtype="float32") = R.builtin.alloc_tensor(R.shape([n, m]), R.dtype("float32"), R.prim_value(0))
+            _: R.Tuple = R.vm.call_tir_dyn(cls.tir_full, (alloc, R.shape([n, m])))
+            full: R.Tensor((n, m), dtype="float32") = alloc
+            alloc1: R.Tensor((n, m), dtype="float32") = R.builtin.alloc_tensor(R.shape([n, m]), R.dtype("float32"), R.prim_value(0))
+            _1: R.Tuple = cls.tir_exp(full, alloc1)
+            lv2: R.Tensor((n, m), dtype="float32") = alloc1
+            alloc2: R.Tensor((n, m), dtype="float32") = R.builtin.alloc_tensor(R.shape([n, m]), R.dtype("float32"), R.prim_value(0))
+            _2: R.Tuple = cls.tir_exp(lv2, alloc2)
+            lv3: R.Tensor((n, m), dtype="float32") = alloc2
+            return lv3
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def tir_full(var_full: T.handle, n: T.int64, m: T.int64):
+            T.evaluate(0)
+
+        @T.prim_func
+        def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(s: R.Shape(["n", "m"])) -> R.Tensor(("n", "m"), dtype="float32"):
+            n = T.int64()
+            m = T.int64()
+            R.func_attr({"relax.force_pure": True, "tir_var_upper_bound": {"n": 20}})
+            cls = Expected
+            storage: R.Object = R.memory.alloc_storage(R.shape([20 * m * 4]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc: R.Tensor((n, m), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([n, m]), R.dtype("float32"))
+            _: R.Tuple = R.vm.call_tir_dyn(cls.tir_full, (alloc, R.shape([n, m])))
+            full: R.Tensor((n, m), dtype="float32") = alloc
+            storage1: R.Object = R.memory.alloc_storage(R.shape([20 * m * 4]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc1: R.Tensor((n, m), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([n, m]), R.dtype("float32"))
+            _1: R.Tuple = cls.tir_exp(full, alloc1)
+            lv2: R.Tensor((n, m), dtype="float32") = alloc1
+            storage2: R.Object = R.memory.alloc_storage(R.shape([20 * m * 4]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc2: R.Tensor((n, m), dtype="float32") = R.memory.alloc_tensor(storage2, R.prim_value(0), R.shape([n, m]), R.dtype("float32"))
+            _2: R.Tuple = cls.tir_exp(lv2, alloc2)
+            lv3: R.Tensor((n, m), dtype="float32") = alloc2
+            return lv3
+    # fmt: on
+
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_function_independence():
     # fmt: off
     @tvm.script.ir_module