Skip to content

Commit 727575f

Browse files
committed
[Unity][Transform] Memory planning for dynamic-shape func return
This PR enhances the static block memory planning pass. Prior to this PR, memory planning only worked on memory allocations that are not externally referenced. In dynamic-shape settings, such memory allocations are not fully static and may lead to memory fragmentation. This PR enhances the behavior so that, for such a memory allocation, we first allocate a storage with regard to its estimated upper bound (when known), and then allocate the tensor with the actual dynamic shape out of that storage. This ensures static memory allocation and avoids memory fragmentation.
1 parent 5c87bfe commit 727575f

File tree

2 files changed

+190
-55
lines changed

2 files changed

+190
-55
lines changed

src/relax/transform/static_plan_block_memory.cc

Lines changed: 128 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,82 @@ class StorageAllocatorBaseVisitor : public ExprVisitor {
296296
std::vector<const BindingBlockNode*> block_stack_;
297297
};
298298

299+
/*!
300+
* \brief Set the upper bound of the TIR variables that appear in
301+
* the input function signature in the analyzer.
302+
* \param func The function to be analyzed.
303+
* \param ana The analyzer which contains the TIR var upper bounds.
304+
*/
305+
void SetTIRVarUpperBound(Function func, arith::Analyzer* ana) {
306+
// Use the attribute-annotated TIR var upper bounds as the TIR var values for
307+
// memory planning.
308+
// NOTE: we only apply the annotated upper bounds to the TIR variables that
309+
// appear in the **function signature**.
310+
Map<ObjectRef, ObjectRef> var_upper_bound_attr_raw =
311+
func->GetAttr<Map<ObjectRef, ObjectRef>>("tir_var_upper_bound")
312+
.value_or(Map<ObjectRef, ObjectRef>());
313+
std::unordered_map<String, IntImm> var_upper_bound_attr;
314+
// We manually check the value type to ensure the values are all positive IntImm.
315+
for (auto it : var_upper_bound_attr_raw) {
316+
const auto* key = it.first.as<StringObj>();
317+
const auto* value = it.second.as<IntImmNode>();
318+
CHECK(key != nullptr)
319+
<< "The entry key of attr `tir_var_upper_bound` should be string. However "
320+
<< it.first->GetTypeKey() << " is got.";
321+
CHECK(value != nullptr)
322+
<< "The entry value of attr `tir_var_upper_bound` should be integer. However "
323+
<< it.second->GetTypeKey() << " is got.";
324+
CHECK_GT(value->value, 0)
325+
<< "The entry value of attr `tir_var_upper_bound` should be a positive integer, while "
326+
<< value->value << " is got.";
327+
var_upper_bound_attr[GetRef<String>(key)] = GetRef<IntImm>(value);
328+
}
329+
Array<tir::Var> var_in_signature = TIRVarsInStructInfo(GetStructInfo(func));
330+
for (const tir::Var& tir_var : var_in_signature) {
331+
auto it = var_upper_bound_attr.find(tir_var->name_hint);
332+
if (it != var_upper_bound_attr.end()) {
333+
ana->Bind(tir_var,
334+
tvm::Range::FromMinExtent(tvm::IntImm(DataType::Int(64), 0),
335+
tvm::IntImm(DataType::Int(64), (*it).second->value + 1)));
336+
}
337+
}
338+
}
339+
340+
/*!
341+
* \brief Use the upper bounds of TIR vars to compute the upper
342+
* bound of a given shape.
343+
* \param shape The input shape to be computed.
344+
* \param ana The arithmetic analyzer that contains the upper bounds
345+
* of TIR variables
346+
* \return The upper-bounded shape. When a dimension's upper bound
347+
* cannot be determined, we keep the dimension unchanged.
348+
*/
349+
Array<PrimExpr> GetUpperBoundShape(Array<PrimExpr> shape, arith::Analyzer* ana) {
350+
// Use the upper bounds of TIR vars as their values.
351+
Array<PrimExpr> upper_bounded_shape;
352+
upper_bounded_shape.reserve(shape.size());
353+
for (const PrimExpr& dim_len : shape) {
354+
int64_t max_bound = ana->const_int_bound(dim_len)->max_value;
355+
if (max_bound == std::numeric_limits<int64_t>::max()) {
356+
upper_bounded_shape.push_back(dim_len);
357+
} else {
358+
upper_bounded_shape.push_back(tvm::IntImm(DataType::Int(64), max_bound));
359+
}
360+
}
361+
return upper_bounded_shape;
362+
}
363+
364+
/*! \brief Check if a shape is static (a.k.a., has no TIR variable). */
365+
bool IsStaticShape(Array<PrimExpr> shape) {
366+
for (const PrimExpr& dim : shape) {
367+
const auto* int_len = dim.as<IntImmNode>();
368+
if (!int_len) {
369+
return false;
370+
}
371+
}
372+
return true;
373+
}
374+
299375
/*!
300376
* \brief The visitor class for storage token initialization.
301377
* \details It goes through the entire function to get the storage tokens
@@ -330,40 +406,8 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
330406
explicit StorageAllocatorInit(const IRModule& ctx_mod) : ctx_mod_(ctx_mod) {}
331407

332408
void VisitExpr_(const FunctionNode* func) final {
333-
// Use the attribute-annotated TIR var upper bounds as the TIR var values for
334-
// memory planning.
335-
// NOTE: we only apply the annotated upper bounds to the TIR variables that
336-
// appear in the **function signature**.
337-
Map<ObjectRef, ObjectRef> var_upper_bound_attr_raw =
338-
func->GetAttr<Map<ObjectRef, ObjectRef>>("tir_var_upper_bound")
339-
.value_or(Map<ObjectRef, ObjectRef>());
340-
std::unordered_map<String, IntImm> var_upper_bound_attr;
341-
// We manually check the value type to ensure the values are all positive IntImm.
342-
for (auto it : var_upper_bound_attr_raw) {
343-
const auto* key = it.first.as<StringObj>();
344-
const auto* value = it.second.as<IntImmNode>();
345-
CHECK(key != nullptr)
346-
<< "The entry key of attr `tir_var_upper_bound` should be string. However "
347-
<< it.first->GetTypeKey() << " is got.";
348-
CHECK(value != nullptr)
349-
<< "The entry value of attr `tir_var_upper_bound` should be integer. However "
350-
<< it.second->GetTypeKey() << " is got.";
351-
CHECK_GT(value->value, 0)
352-
<< "The entry value of attr `tir_var_upper_bound` should be a positive integer, while "
353-
<< value->value << " is got.";
354-
var_upper_bound_attr[GetRef<String>(key)] = GetRef<IntImm>(value);
355-
}
356-
Array<tir::Var> var_in_signature = TIRVarsInStructInfo(GetStructInfo(GetRef<Function>(func)));
357-
var_upper_bound_.clear();
358-
for (const tir::Var& tir_var : var_in_signature) {
359-
auto it = var_upper_bound_attr.find(tir_var->name_hint);
360-
if (it != var_upper_bound_attr.end()) {
361-
ana_.Bind(tir_var, tvm::Range::FromMinExtent(
362-
tvm::IntImm(DataType::Int(64), 0),
363-
tvm::IntImm(DataType::Int(64), (*it).second->value + 1)));
364-
}
365-
}
366-
409+
// Set the upper bound of TIR variables in the analyzer.
410+
SetTIRVarUpperBound(GetRef<Function>(func), &ana_);
367411
// Recurse into the function to get its tokens.
368412
Tokens body_tokens = GetTokens(func->body);
369413
// Discard the tokens used by the function return value, as they are external referenced.
@@ -457,32 +501,20 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
457501
// - the tensor has known dtype;
458502
// - no storage token was created for this call before.
459503
const auto* sinfo = call->struct_info_.as<TensorStructInfoNode>();
460-
const auto* shape = sinfo->shape.as<ShapeExprNode>();
461504
ICHECK_NOTNULL(sinfo);
505+
const auto* shape = sinfo->shape.as<ShapeExprNode>();
462506
ICHECK_NOTNULL(shape);
463507
ICHECK(!sinfo->IsUnknownDtype());
464508
ICHECK(sinfo->dtype == Downcast<DataTypeImm>(call->args[1])->value);
465509
ICHECK(!token_map_.count(call));
466510

467511
// Use the upper bounds of TIR vars as their values.
468-
Array<PrimExpr> upper_bounded_shape;
469-
upper_bounded_shape.reserve(shape->values.size());
470-
for (const PrimExpr& dim_len : shape->values) {
471-
int64_t max_bound = ana_.const_int_bound(dim_len)->max_value;
472-
if (max_bound == std::numeric_limits<int64_t>::max()) {
473-
upper_bounded_shape.push_back(dim_len);
474-
} else {
475-
upper_bounded_shape.push_back(tvm::IntImm(DataType::Int(64), max_bound));
476-
}
477-
}
512+
Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, &ana_);
478513

479514
// No support for TIR vars that are not bounded.
480-
for (const PrimExpr& dim_len : upper_bounded_shape) {
481-
const auto* int_len = dim_len.as<IntImmNode>();
482-
if (!int_len) {
483-
token_map_[call] = Tokens();
484-
return Tokens();
485-
}
515+
if (!IsStaticShape(upper_bounded_shape)) {
516+
token_map_[call] = Tokens();
517+
return Tokens();
486518
}
487519

488520
// Create and set token.
@@ -558,8 +590,6 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
558590
* a PrimFunc inside the IRModule.
559591
*/
560592
const IRModule& ctx_mod_;
561-
/*! \brief The mapping from TIR variables to their respective upper bound values. */
562-
std::unordered_map<tir::Var, IntImm, ObjectPtrHash, ObjectPtrEqual> var_upper_bound_;
563593
/*! \brief The mapping from each token to the binding block where it is created. */
564594
std::unordered_map<const StorageTokenNode*, const BindingBlockNode*> token2block_;
565595
/*! \brief The mapping from each token to the Exprs that are using this token. */
@@ -729,8 +759,17 @@ class StorageAllocationRewriter : public ExprMutator {
729759
if (func_ == nullptr) {
730760
continue;
731761
}
762+
constexpr static const char* plan_dyn_attr_ = "relax.memory_plan_dynamic_func_output";
763+
plan_dynamic_output_ = static_cast<bool>(
764+
func_->GetAttr<IntImm>(plan_dyn_attr_).value_or(IntImm(DataType::Int(32), 0))->value);
765+
if (plan_dynamic_output_) {
766+
SetTIRVarUpperBound(GetRef<Function>(func_), &ana_);
767+
}
732768
token2storage_var_.clear();
733769
Function func = Downcast<Function>(this->VisitExpr_(func_));
770+
if (plan_dynamic_output_) {
771+
func = WithoutAttr(func, plan_dyn_attr_);
772+
}
734773
builder_->UpdateFunction(gv, func);
735774
}
736775
return builder_->GetContextIRModule();
@@ -740,8 +779,13 @@ class StorageAllocationRewriter : public ExprMutator {
740779
using ExprMutator::VisitExpr_;
741780

742781
Expr VisitExpr_(const CallNode* call) final {
782+
static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
783+
static const Op& mem_alloc_storage = Op::Get("relax.memory.alloc_storage");
784+
static const Op& mem_alloc_tensor = Op::Get("relax.memory.alloc_tensor");
743785
auto it = alloc_tensor2token_.find(call);
744786
if (it != alloc_tensor2token_.end()) {
787+
// Case 1. This `alloc_tensor` is planned for memory reuse.
788+
ICHECK_EQ(call->op, alloc_tensor_op);
745789
const auto* sinfo = call->struct_info_.as<TensorStructInfoNode>();
746790
ICHECK_NOTNULL(sinfo);
747791
ICHECK_NOTNULL(sinfo->shape.as<ShapeExprNode>());
@@ -753,7 +797,6 @@ class StorageAllocationRewriter : public ExprMutator {
753797
Var storage_var{nullptr};
754798
auto it_token = token2storage_var_.find(token.get());
755799
if (it_token == token2storage_var_.end()) {
756-
static const Op& mem_alloc_storage = Op::Get("relax.memory.alloc_storage");
757800
ShapeExpr size({tir::make_const(DataType::Int(64), token->bytes)});
758801
PrimValue virtual_device_index = runtime_device_index;
759802
std::string storage_scope = "global";
@@ -769,16 +812,46 @@ class StorageAllocationRewriter : public ExprMutator {
769812
}
770813

771814
// And always create a `memory.alloc_tensor` for the old `builtin.alloc_tensor`.
772-
static const Op& mem_alloc_tensor = Op::Get("relax.memory.alloc_tensor");
773815
PrimValue offset = PrimValue::Int64(0);
774816
DataType dtype = sinfo->dtype;
775817
return Call(mem_alloc_tensor, {storage_var, offset, sinfo->shape.value(), DataTypeImm(dtype)},
776818
Attrs());
819+
} else if (plan_dynamic_output_ && call->op == alloc_tensor_op) {
820+
// Case 2. For a `alloc_tensor` that is not planned for memory reuse,
821+
// we would still like to allocate **static** memory for the tensor.
822+
// So in case the tensor shape is dynamic but has an upper bound
823+
// estimation, we allocate a storage to its upper bound size, and
824+
// allocate a tensor out from it with the actual symbolic shape.
825+
826+
const auto* sinfo = call->struct_info_.as<TensorStructInfoNode>();
827+
ICHECK_NOTNULL(sinfo);
828+
const auto* shape = sinfo->shape.as<ShapeExprNode>();
829+
ICHECK_NOTNULL(shape);
830+
Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, &ana_);
831+
if (!IsStaticShape(shape->values) && IsStaticShape(upper_bounded_shape)) {
832+
ICHECK(!sinfo->IsUnknownDtype());
833+
ICHECK_EQ(sinfo->dtype, Downcast<DataTypeImm>(call->args[1])->value);
834+
StorageToken token(upper_bounded_shape, sinfo->dtype);
835+
Call alloc_storage(mem_alloc_storage,
836+
{/*size=*/ShapeExpr({tvm::IntImm(DataType::Int(64), token->bytes)}),
837+
/*virtual_device_index=*/Downcast<PrimValue>(call->args[2]),
838+
/*storage_scope=*/StringImm("global"), //
839+
/*dtype=*/DataTypeImm(token->dtype)});
840+
Var storage = builder_->Emit(alloc_storage, "storage");
841+
return Call(mem_alloc_tensor, {storage, //
842+
/*offset=*/PrimValue::Int64(0),
843+
/*shape=*/GetRef<ShapeExpr>(shape), //
844+
/*dtype=*/DataTypeImm(sinfo->dtype)});
845+
}
777846
}
778847

779848
return ExprMutator::VisitExpr_(call);
780849
}
781850

851+
/*! \brief The arithmetic analyzer. */
852+
arith::Analyzer ana_;
853+
/*! \brief A boolean indicating whether to plan dynamic-shape function output tensors. */
854+
bool plan_dynamic_output_;
782855
/*!
783856
* \brief The mapping from each memory-reusable `builtin.alloc_tensor` to
784857
its corresponding underlying storage token that it is using.

tests/python/relax/test_transform_static_plan_block_memory.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1109,6 +1109,68 @@ def main(s: R.Shape(["n"])) -> R.Tensor(("n",), dtype="float32"):
11091109
tvm.ir.assert_structural_equal(mod, Expected)
11101110

11111111

1112+
def test_call_tir_dyn_plan_dynamic_func_output():
1113+
# fmt: off
1114+
@I.ir_module
1115+
class Module:
1116+
@T.prim_func
1117+
def tir_full(var_full: T.handle, n: T.int64):
1118+
T.evaluate(0)
1119+
1120+
@T.prim_func
1121+
def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
1122+
T.evaluate(0)
1123+
1124+
@R.function
1125+
def main(s: R.Shape(["n"])) -> R.Tensor(("n",), dtype="float32"):
1126+
n = T.int64()
1127+
R.func_attr({"tir_var_upper_bound": {"n": 20}, "relax.force_pure": True, "relax.memory_plan_dynamic_func_output": True})
1128+
cls = Module
1129+
alloc: R.Tensor((n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([n]), R.dtype("float32"), R.prim_value(0))
1130+
_: R.Tuple = R.vm.call_tir_dyn(cls.tir_full, (alloc, R.shape([n])))
1131+
full: R.Tensor((n,), dtype="float32") = alloc
1132+
alloc1: R.Tensor((n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([n]), R.dtype("float32"), R.prim_value(0))
1133+
_1: R.Tuple = cls.tir_exp(full, alloc1)
1134+
lv2: R.Tensor((n,), dtype="float32") = alloc1
1135+
alloc2: R.Tensor((n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([n]), R.dtype("float32"), R.prim_value(0))
1136+
_2: R.Tuple = cls.tir_exp(lv2, alloc2)
1137+
lv3: R.Tensor((n,), dtype="float32") = alloc2
1138+
return lv3
1139+
1140+
@I.ir_module
1141+
class Expected:
1142+
@T.prim_func
1143+
def tir_exp(var_rxplaceholder: T.handle, var_compute: T.handle):
1144+
T.evaluate(0)
1145+
1146+
@T.prim_func
1147+
def tir_full(var_full: T.handle, n: T.int64):
1148+
T.evaluate(0)
1149+
1150+
@R.function
1151+
def main(s: R.Shape(["n"])) -> R.Tensor(("n",), dtype="float32"):
1152+
n = T.int64()
1153+
R.func_attr({"tir_var_upper_bound": {"n": 20}, "relax.force_pure": True})
1154+
cls = Expected
1155+
storage: R.Object = R.memory.alloc_storage(R.shape([80]), R.prim_value(0), R.str("global"), R.dtype("float32"))
1156+
alloc: R.Tensor((n,), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([n]), R.dtype("float32"))
1157+
_: R.Tuple = R.vm.call_tir_dyn(cls.tir_full, (alloc, R.shape([n])))
1158+
full: R.Tensor((n,), dtype="float32") = alloc
1159+
storage1: R.Object = R.memory.alloc_storage(R.shape([80]), R.prim_value(0), R.str("global"), R.dtype("float32"))
1160+
alloc1: R.Tensor((n,), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([n]), R.dtype("float32"))
1161+
_1: R.Tuple = cls.tir_exp(full, alloc1)
1162+
lv2: R.Tensor((n,), dtype="float32") = alloc1
1163+
storage2: R.Object = R.memory.alloc_storage(R.shape([80]), R.prim_value(0), R.str("global"), R.dtype("float32"))
1164+
alloc2: R.Tensor((n,), dtype="float32") = R.memory.alloc_tensor(storage2, R.prim_value(0), R.shape([n]), R.dtype("float32"))
1165+
_2: R.Tuple = cls.tir_exp(lv2, alloc2)
1166+
lv3: R.Tensor((n,), dtype="float32") = alloc2
1167+
return lv3
1168+
# fmt: on
1169+
1170+
mod = relax.transform.StaticPlanBlockMemory()(Module)
1171+
tvm.ir.assert_structural_equal(mod, Expected)
1172+
1173+
11121174
def test_function_independence():
11131175
# fmt: off
11141176
@tvm.script.ir_module

0 commit comments

Comments
 (0)