Commit b98bc66
[Relax] Memory planning for "partially dynamic" shapes (#16466)
This PR improves memory planning so it can plan tensors with shapes such as `(m, n)` where only `m` has an integer upper bound (say `m <= 20`). Prior to this PR, such a tensor was not planned at all, even though the upper bound on `m` is known, so the upper-bound information went unused. With this PR, the planner first allocates a memory chunk of `20 * n` elements and then allocates the tensor out of that chunk, fully leveraging the upper-bound information. A unit test is provided.
1 parent 593a4bd commit b98bc66
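As a rough illustration of the size arithmetic (a minimal plain-Python sketch, not part of the patch; the concrete runtime value used for `n` below is made up), the planner multiplies the per-dimension upper bounds and the element size to get the chunk size:

# Illustrative only: mirrors the product-of-upper-bounds computation used by
# the planner once every dimension has either a concrete value or an upper bound.
def planned_chunk_bytes(upper_bounded_shape, dtype_bytes, lanes=1):
    size = 1
    for dim in upper_bounded_shape:
        size *= dim
    return size * dtype_bytes * lanes

# Shape (m, n) with m <= 20 and float32 elements; suppose n happens to be 128.
# The planner reserves 20 * n * 4 bytes up front, and the actual (m, n) tensor
# is then carved out of that chunk.
print(planned_chunk_bytes([20, 128], dtype_bytes=4))  # 10240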

File tree

2 files changed: +72 −4 lines changed
src/relax/transform/static_plan_block_memory.cc

Lines changed: 8 additions & 4 deletions
@@ -828,15 +828,19 @@ class StorageAllocationRewriter : public ExprMutator {
     const auto* shape = sinfo->shape.as<ShapeExprNode>();
     ICHECK_NOTNULL(shape);
     Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, &ana_);
-    if (!IsStaticShape(shape->values) && IsStaticShape(upper_bounded_shape)) {
+    if (!IsStaticShape(shape->values)) {
       ICHECK(!sinfo->IsUnknownDtype());
       ICHECK_EQ(sinfo->dtype, Downcast<DataTypeImm>(call->args[1])->value);
-      StorageToken token(upper_bounded_shape, sinfo->dtype);
+      PrimExpr bytes = upper_bounded_shape[0];
+      for (int i = 1; i < static_cast<int>(upper_bounded_shape.size()); ++i) {
+        bytes *= upper_bounded_shape[i];
+      }
+      bytes *= sinfo->dtype.bytes() * sinfo->dtype.lanes();
       Call alloc_storage(mem_alloc_storage,
-                         {/*size=*/ShapeExpr({tvm::IntImm(DataType::Int(64), token->bytes)}),
+                         {/*size=*/ShapeExpr({bytes}),
                           /*virtual_device_index=*/Downcast<PrimValue>(call->args[2]),
                           /*storage_scope=*/StringImm("global"),  //
-                          /*dtype=*/DataTypeImm(token->dtype)});
+                          /*dtype=*/DataTypeImm(sinfo->dtype)});
       Var storage = builder_->Emit(alloc_storage, "storage");
       return Call(mem_alloc_tensor, {storage,  //
                                      /*offset=*/PrimValue::Int64(0),
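The effect of the change above is that the storage size can remain a symbolic expression when only some dimensions have integer upper bounds. A hedged sketch of what that expression looks like (plain Python against `tvm.tir`, not the C++ path itself; variable names are illustrative):

# Sketch: a shape (n, m) with the bound n <= 20 becomes [20, m] after the
# upper-bound substitution; folding it into a byte count keeps m symbolic.
from tvm import tir

m = tir.Var("m", "int64")
upper_bounded_shape = [tir.IntImm("int64", 20), m]

size = upper_bounded_shape[0]
for dim in upper_bounded_shape[1:]:
    size = size * dim
size = size * 4  # float32: 4 bytes per element, 1 lane

# `size` is now an expression equivalent to 20 * m * 4, matching the storage
# size that appears in the Expected module of the unit test below.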

tests/python/relax/test_transform_static_plan_block_memory.py

Lines changed: 64 additions & 0 deletions
@@ -1171,6 +1171,70 @@ def main(s: R.Shape(["n"])) -> R.Tensor(("n",), dtype="float32"):
     tvm.ir.assert_structural_equal(mod, Expected)


+def test_call_tir_dyn_plan_partially_dynamic():
+    # fmt: off
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def tir_full(var_full: T.handle, n: T.int64, m: T.int64):
+            T.evaluate(0)
+
+        @T.prim_func
+        def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(s: R.Shape(["n", "m"])) -> R.Tensor(("n", "m"), dtype="float32"):
+            n = T.int64()
+            m = T.int64()
+            R.func_attr({"tir_var_upper_bound": {"n": 20}, "relax.force_pure": True, "relax.memory_plan_dynamic_func_output": True})
+            cls = Module
+            alloc: R.Tensor((n, m), dtype="float32") = R.builtin.alloc_tensor(R.shape([n, m]), R.dtype("float32"), R.prim_value(0))
+            _: R.Tuple = R.vm.call_tir_dyn(cls.tir_full, (alloc, R.shape([n, m])))
+            full: R.Tensor((n, m), dtype="float32") = alloc
+            alloc1: R.Tensor((n, m), dtype="float32") = R.builtin.alloc_tensor(R.shape([n, m]), R.dtype("float32"), R.prim_value(0))
+            _1: R.Tuple = cls.tir_exp(full, alloc1)
+            lv2: R.Tensor((n, m), dtype="float32") = alloc1
+            alloc2: R.Tensor((n, m), dtype="float32") = R.builtin.alloc_tensor(R.shape([n, m]), R.dtype("float32"), R.prim_value(0))
+            _2: R.Tuple = cls.tir_exp(lv2, alloc2)
+            lv3: R.Tensor((n, m), dtype="float32") = alloc2
+            return lv3
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def tir_full(var_full: T.handle, n: T.int64, m: T.int64):
+            T.evaluate(0)
+
+        @T.prim_func
+        def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(s: R.Shape(["n", "m"])) -> R.Tensor(("n", "m"), dtype="float32"):
+            n = T.int64()
+            m = T.int64()
+            R.func_attr({"relax.force_pure": True, "tir_var_upper_bound": {"n": 20}})
+            cls = Expected
+            storage: R.Object = R.memory.alloc_storage(R.shape([20 * m * 4]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc: R.Tensor((n, m), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([n, m]), R.dtype("float32"))
+            _: R.Tuple = R.vm.call_tir_dyn(cls.tir_full, (alloc, R.shape([n, m])))
+            full: R.Tensor((n, m), dtype="float32") = alloc
+            storage1: R.Object = R.memory.alloc_storage(R.shape([20 * m * 4]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc1: R.Tensor((n, m), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([n, m]), R.dtype("float32"))
+            _1: R.Tuple = cls.tir_exp(full, alloc1)
+            lv2: R.Tensor((n, m), dtype="float32") = alloc1
+            storage2: R.Object = R.memory.alloc_storage(R.shape([20 * m * 4]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc2: R.Tensor((n, m), dtype="float32") = R.memory.alloc_tensor(storage2, R.prim_value(0), R.shape([n, m]), R.dtype("float32"))
+            _2: R.Tuple = cls.tir_exp(lv2, alloc2)
+            lv3: R.Tensor((n, m), dtype="float32") = alloc2
+            return lv3
+    # fmt: on
+
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_function_independence():
     # fmt: off
     @tvm.script.ir_module