@@ -296,6 +296,82 @@ class StorageAllocatorBaseVisitor : public ExprVisitor {
 std::vector<const BindingBlockNode*> block_stack_;
};

+/*!
+ * \brief Set, in the given analyzer, the upper bounds of the TIR variables
+ * that appear in the input function signature.
+ * \param func The function to be analyzed.
+ * \param ana The analyzer which contains the TIR var upper bounds.
+ */
+void SetTIRVarUpperBound(Function func, arith::Analyzer* ana) {
+  // Use the attribute-annotated TIR var upper bounds as the TIR var values for
+  // memory planning.
+  // NOTE: we only apply the annotated upper bounds to the TIR variables that
+  // appear in the **function signature**.
+  Map<ObjectRef, ObjectRef> var_upper_bound_attr_raw =
+      func->GetAttr<Map<ObjectRef, ObjectRef>>("tir_var_upper_bound")
+          .value_or(Map<ObjectRef, ObjectRef>());
+  std::unordered_map<String, IntImm> var_upper_bound_attr;
+  // We manually check the value type to ensure the values are all positive IntImm.
+  for (auto it : var_upper_bound_attr_raw) {
+    const auto* key = it.first.as<StringObj>();
+    const auto* value = it.second.as<IntImmNode>();
+    CHECK(key != nullptr)
+        << "The entry key of attr `tir_var_upper_bound` should be string. However "
+        << it.first->GetTypeKey() << " is got.";
+    CHECK(value != nullptr)
+        << "The entry value of attr `tir_var_upper_bound` should be integer. However "
+        << it.second->GetTypeKey() << " is got.";
+    CHECK_GT(value->value, 0)
+        << "The entry value of attr `tir_var_upper_bound` should be a positive integer, while "
+        << value->value << " is got.";
+    var_upper_bound_attr[GetRef<String>(key)] = GetRef<IntImm>(value);
+  }
+  Array<tir::Var> var_in_signature = TIRVarsInStructInfo(GetStructInfo(func));
+  for (const tir::Var& tir_var : var_in_signature) {
+    auto it = var_upper_bound_attr.find(tir_var->name_hint);
+    if (it != var_upper_bound_attr.end()) {
+      ana->Bind(tir_var,
+                tvm::Range::FromMinExtent(tvm::IntImm(DataType::Int(64), 0),
+                                          tvm::IntImm(DataType::Int(64), (*it).second->value + 1)));
+    }
+  }
+}
+
+/*!
+ * \brief Use the upper bounds of TIR vars to compute the upper
+ * bound of a given shape.
+ * \param shape The input shape to be computed.
+ * \param ana The arithmetic analyzer that contains the upper bounds
+ * of TIR variables.
+ * \return The upper-bounded shape. When a dimension's upper bound
+ * cannot be determined, we keep the dimension unchanged.
+ */
+Array<PrimExpr> GetUpperBoundShape(Array<PrimExpr> shape, arith::Analyzer* ana) {
+  // Use the upper bounds of TIR vars as their values.
+  Array<PrimExpr> upper_bounded_shape;
+  upper_bounded_shape.reserve(shape.size());
+  for (const PrimExpr& dim_len : shape) {
+    int64_t max_bound = ana->const_int_bound(dim_len)->max_value;
+    if (max_bound == std::numeric_limits<int64_t>::max()) {
+      upper_bounded_shape.push_back(dim_len);
+    } else {
+      upper_bounded_shape.push_back(tvm::IntImm(DataType::Int(64), max_bound));
+    }
+  }
+  return upper_bounded_shape;
+}
+
+/*! \brief Check if a shape is static (i.e., has no TIR variable). */
+bool IsStaticShape(Array<PrimExpr> shape) {
+  for (const PrimExpr& dim : shape) {
+    const auto* int_len = dim.as<IntImmNode>();
+    if (!int_len) {
+      return false;
+    }
+  }
+  return true;
+}
+
 /*!
  * \brief The visitor class for storage token initialization.
  * \details It goes through the entire function to get the storage tokens
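
Taken together, these three helpers turn an annotated bound into a concrete allocation size: SetTIRVarUpperBound feeds ranges into the analyzer, GetUpperBoundShape queries them per dimension, and IsStaticShape decides whether the result is fully resolved. A minimal sketch of the analyzer interaction, assuming the usual tvm::arith::Analyzer API (the variable `n` and the function BoundSketch are illustrative, not part of this patch):

    #include <tvm/arith/analyzer.h>
    #include <tvm/tir/op.h>

    using namespace tvm;

    void BoundSketch() {
      arith::Analyzer ana;
      tir::Var n("n", DataType::Int(64));
      // Mirrors SetTIRVarUpperBound: an annotated bound of 64 becomes the
      // half-open range [0, 65), i.e. extent = upper_bound + 1.
      ana.Bind(n, Range::FromMinExtent(IntImm(DataType::Int(64), 0),
                                       IntImm(DataType::Int(64), 65)));
      // const_int_bound now reports max_value = 64 for `n`, and the bound
      // composes through arithmetic: n * 128 has max_value 8192.
      int64_t max_n = ana.const_int_bound(n)->max_value;          // 64
      int64_t max_len = ana.const_int_bound(n * 128)->max_value;  // 8192
    }
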
@@ -330,40 +406,8 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
   explicit StorageAllocatorInit(const IRModule& ctx_mod) : ctx_mod_(ctx_mod) {}

   void VisitExpr_(const FunctionNode* func) final {
-    // Use the attribute-annotated TIR var upper bounds as the TIR var values for
-    // memory planning.
-    // NOTE: we only apply the annotated upper bounds to the TIR variables that
-    // appear in the **function signature**.
-    Map<ObjectRef, ObjectRef> var_upper_bound_attr_raw =
-        func->GetAttr<Map<ObjectRef, ObjectRef>>("tir_var_upper_bound")
-            .value_or(Map<ObjectRef, ObjectRef>());
-    std::unordered_map<String, IntImm> var_upper_bound_attr;
-    // We manually check the value type to ensure the values are all positive IntImm.
-    for (auto it : var_upper_bound_attr_raw) {
-      const auto* key = it.first.as<StringObj>();
-      const auto* value = it.second.as<IntImmNode>();
-      CHECK(key != nullptr)
-          << "The entry key of attr `tir_var_upper_bound` should be string. However "
-          << it.first->GetTypeKey() << " is got.";
-      CHECK(value != nullptr)
-          << "The entry value of attr `tir_var_upper_bound` should be integer. However "
-          << it.second->GetTypeKey() << " is got.";
-      CHECK_GT(value->value, 0)
-          << "The entry value of attr `tir_var_upper_bound` should be a positive integer, while "
-          << value->value << " is got.";
-      var_upper_bound_attr[GetRef<String>(key)] = GetRef<IntImm>(value);
-    }
-    Array<tir::Var> var_in_signature = TIRVarsInStructInfo(GetStructInfo(GetRef<Function>(func)));
-    var_upper_bound_.clear();
-    for (const tir::Var& tir_var : var_in_signature) {
-      auto it = var_upper_bound_attr.find(tir_var->name_hint);
-      if (it != var_upper_bound_attr.end()) {
-        ana_.Bind(tir_var, tvm::Range::FromMinExtent(
-                               tvm::IntImm(DataType::Int(64), 0),
-                               tvm::IntImm(DataType::Int(64), (*it).second->value + 1)));
-      }
-    }
-
+    // Set the upper bound of TIR variables in the analyzer.
+    SetTIRVarUpperBound(GetRef<Function>(func), &ana_);
     // Recurse into the function to get its tokens.
     Tokens body_tokens = GetTokens(func->body);
     // Discard the tokens used by the function return value, as they are externally referenced.
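
For context, the `tir_var_upper_bound` attribute that SetTIRVarUpperBound consumes is a string-to-integer map attached to the Relax function. A hedged sketch of attaching it via the generic WithAttr utility (the bound of 64 for a signature var `n` is illustrative; the map layout follows the key/value checks above):

    Map<ObjectRef, ObjectRef> bounds;
    bounds.Set(String("n"), IntImm(DataType::Int(64), 64));
    func = WithAttr(std::move(func), "tir_var_upper_bound", bounds);
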
@@ -457,32 +501,20 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
     // - the tensor has known dtype;
     // - no storage token was created for this call before.
     const auto* sinfo = call->struct_info_.as<TensorStructInfoNode>();
-    const auto* shape = sinfo->shape.as<ShapeExprNode>();
     ICHECK_NOTNULL(sinfo);
+    const auto* shape = sinfo->shape.as<ShapeExprNode>();
     ICHECK_NOTNULL(shape);
     ICHECK(!sinfo->IsUnknownDtype());
     ICHECK(sinfo->dtype == Downcast<DataTypeImm>(call->args[1])->value);
     ICHECK(!token_map_.count(call));

     // Use the upper bounds of TIR vars as their values.
-    Array<PrimExpr> upper_bounded_shape;
-    upper_bounded_shape.reserve(shape->values.size());
-    for (const PrimExpr& dim_len : shape->values) {
-      int64_t max_bound = ana_.const_int_bound(dim_len)->max_value;
-      if (max_bound == std::numeric_limits<int64_t>::max()) {
-        upper_bounded_shape.push_back(dim_len);
-      } else {
-        upper_bounded_shape.push_back(tvm::IntImm(DataType::Int(64), max_bound));
-      }
-    }
+    Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, &ana_);

     // No support for TIR vars that are not bounded.
-    for (const PrimExpr& dim_len : upper_bounded_shape) {
-      const auto* int_len = dim_len.as<IntImmNode>();
-      if (!int_len) {
-        token_map_[call] = Tokens();
-        return Tokens();
-      }
+    if (!IsStaticShape(upper_bounded_shape)) {
+      token_map_[call] = Tokens();
+      return Tokens();
     }

     // Create and set token.
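
For example, with `n` bounded above by 64, a float32 tensor of shape (n, 256) upper-bounds to (64, 256), so the token created below covers 64 * 256 * 4 = 65536 bytes. If `n` carried no annotation, GetUpperBoundShape would leave the dimension as `n`, IsStaticShape would reject it, and the call would get no token, exactly as before the refactor.
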
@@ -558,8 +590,6 @@ class StorageAllocatorInit : public StorageAllocatorBaseVisitor {
   * a PrimFunc inside the IRModule.
   */
  const IRModule& ctx_mod_;
-  /*! \brief The mapping from TIR variables to their respective upper bound values. */
-  std::unordered_map<tir::Var, IntImm, ObjectPtrHash, ObjectPtrEqual> var_upper_bound_;
  /*! \brief The mapping from each token to the binding block where it is created. */
  std::unordered_map<const StorageTokenNode*, const BindingBlockNode*> token2block_;
  /*! \brief The mapping from each token to the Exprs that are using this token. */
@@ -729,8 +759,17 @@ class StorageAllocationRewriter : public ExprMutator {
       if (func_ == nullptr) {
         continue;
       }
+      constexpr static const char* plan_dyn_attr_ = "relax.memory_plan_dynamic_func_output";
+      plan_dynamic_output_ = static_cast<bool>(
+          func_->GetAttr<IntImm>(plan_dyn_attr_).value_or(IntImm(DataType::Int(32), 0))->value);
+      if (plan_dynamic_output_) {
+        SetTIRVarUpperBound(GetRef<Function>(func_), &ana_);
+      }
       token2storage_var_.clear();
       Function func = Downcast<Function>(this->VisitExpr_(func_));
+      if (plan_dynamic_output_) {
+        func = WithoutAttr(func, plan_dyn_attr_);
+      }
       builder_->UpdateFunction(gv, func);
     }
     return builder_->GetContextIRModule();
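
The dynamic-output planning path is opt-in per function. A hedged sketch of enabling it (the attribute key comes from `plan_dyn_attr_` above; encoding the flag as a 32-bit IntImm mirrors the GetAttr<IntImm> read, though the accepted value types may be broader):

    func = WithAttr(std::move(func), "relax.memory_plan_dynamic_func_output",
                    IntImm(DataType::Int(32), 1));
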
@@ -740,8 +779,13 @@ class StorageAllocationRewriter : public ExprMutator {
   using ExprMutator::VisitExpr_;

   Expr VisitExpr_(const CallNode* call) final {
+    static const Op& alloc_tensor_op = Op::Get("relax.builtin.alloc_tensor");
+    static const Op& mem_alloc_storage = Op::Get("relax.memory.alloc_storage");
+    static const Op& mem_alloc_tensor = Op::Get("relax.memory.alloc_tensor");
     auto it = alloc_tensor2token_.find(call);
     if (it != alloc_tensor2token_.end()) {
+      // Case 1. This `alloc_tensor` is planned for memory reuse.
+      ICHECK_EQ(call->op, alloc_tensor_op);
       const auto* sinfo = call->struct_info_.as<TensorStructInfoNode>();
       ICHECK_NOTNULL(sinfo);
       ICHECK_NOTNULL(sinfo->shape.as<ShapeExprNode>());
@@ -753,7 +797,6 @@ class StorageAllocationRewriter : public ExprMutator {
       Var storage_var{nullptr};
       auto it_token = token2storage_var_.find(token.get());
       if (it_token == token2storage_var_.end()) {
-        static const Op& mem_alloc_storage = Op::Get("relax.memory.alloc_storage");
         ShapeExpr size({tir::make_const(DataType::Int(64), token->bytes)});
         PrimValue virtual_device_index = runtime_device_index;
         std::string storage_scope = "global";
@@ -769,16 +812,46 @@ class StorageAllocationRewriter : public ExprMutator {
       }

       // And always create a `memory.alloc_tensor` for the old `builtin.alloc_tensor`.
-      static const Op& mem_alloc_tensor = Op::Get("relax.memory.alloc_tensor");
       PrimValue offset = PrimValue::Int64(0);
       DataType dtype = sinfo->dtype;
       return Call(mem_alloc_tensor, {storage_var, offset, sinfo->shape.value(), DataTypeImm(dtype)},
                   Attrs());
+    } else if (plan_dynamic_output_ && call->op == alloc_tensor_op) {
+      // Case 2. For an `alloc_tensor` that is not planned for memory reuse,
+      // we would still like to allocate **static** memory for the tensor.
+      // So when the tensor shape is dynamic but has an upper-bound
+      // estimation, we allocate a storage of the upper-bound size, and
+      // allocate a tensor out of it with the actual symbolic shape.
+
+      const auto* sinfo = call->struct_info_.as<TensorStructInfoNode>();
+      ICHECK_NOTNULL(sinfo);
+      const auto* shape = sinfo->shape.as<ShapeExprNode>();
+      ICHECK_NOTNULL(shape);
+      Array<PrimExpr> upper_bounded_shape = GetUpperBoundShape(shape->values, &ana_);
+      if (!IsStaticShape(shape->values) && IsStaticShape(upper_bounded_shape)) {
+        ICHECK(!sinfo->IsUnknownDtype());
+        ICHECK_EQ(sinfo->dtype, Downcast<DataTypeImm>(call->args[1])->value);
+        StorageToken token(upper_bounded_shape, sinfo->dtype);
+        Call alloc_storage(mem_alloc_storage,
+                           {/*size=*/ShapeExpr({tvm::IntImm(DataType::Int(64), token->bytes)}),
+                            /*virtual_device_index=*/Downcast<PrimValue>(call->args[2]),
+                            /*storage_scope=*/StringImm("global"),  //
+                            /*dtype=*/DataTypeImm(token->dtype)});
+        Var storage = builder_->Emit(alloc_storage, "storage");
+        return Call(mem_alloc_tensor, {storage,  //
+                                       /*offset=*/PrimValue::Int64(0),
+                                       /*shape=*/GetRef<ShapeExpr>(shape),  //
+                                       /*dtype=*/DataTypeImm(sinfo->dtype)});
+      }
     }

     return ExprMutator::VisitExpr_(call);
   }

+  /*! \brief The arithmetic analyzer. */
+  arith::Analyzer ana_;
+  /*! \brief A boolean indicating whether to plan dynamic-shape function output tensors. */
+  bool plan_dynamic_output_;
   /*!
    * \brief The mapping from each memory-reusable `builtin.alloc_tensor` to
    * its corresponding underlying storage token that it is using.
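
To illustrate the Case 2 rewrite end to end: with `n` bounded above by 64 via the function attribute, an unplanned dynamic output allocation of shape (n, 128) and dtype float32 is rewritten roughly as follows (TVMScript-like pseudo-IR; the exact printed form is an assumption, but the operators and arguments mirror the calls constructed above):

    Before:
      lv = R.builtin.alloc_tensor(R.shape([n, 128]), dtype="float32", runtime_device_index=0)
    After:
      storage = R.memory.alloc_storage(R.shape([32768]),  # 64 * 128 * 4 bytes
                                       virtual_device_index=0, storage_scope="global",
                                       dtype="float32")
      lv = R.memory.alloc_tensor(storage, offset=0, shape=R.shape([n, 128]), dtype="float32")
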