@@ -76,6 +76,8 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
   };
   // The scope of each allocation
   struct AllocEntry {
+    // The physical dimensionality of the allocation.
+    size_t num_physical_dimensions{0};
     // scope level
     size_t level{0};
     // allocation stmt
@@ -85,8 +87,16 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
   void VisitStmt_(const AllocateNode* op) final {
     size_t level = scope_.size();
     const VarNode* buf = op->buffer_var.get();
-    alloc_info_[buf].alloc = op;
-    alloc_info_[buf].level = level;
+
+    AllocEntry entry;
+    entry.alloc = op;
+    entry.level = level;
+    // Since StorageRewrite occurs after StorageFlatten/FlattenBuffer,
+    // all allocations specify the extent of each physical dimension,
+    // of which flat memory spaces have exactly one.
+    entry.num_physical_dimensions = op->extents.size();
+    alloc_info_[buf] = entry;
+
     StmtExprVisitor::VisitStmt_(op);
   }
 
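Because this visitor runs after StorageFlatten/FlattenBuffer, `op->extents.size()` is exactly the buffer's physical rank: one extent for flat memory, more for targets such as 2-d texture memory. Below is a minimal standalone sketch of that bookkeeping, using plain STL containers in place of TVM's visitor and `AllocateNode` (all names here are illustrative, not TVM API):

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Simplified stand-in for the visitor's AllocEntry bookkeeping.
struct AllocEntry {
  size_t num_physical_dimensions{0};
  size_t level{0};
};

int main() {
  std::map<std::string, AllocEntry> alloc_info;

  // After flattening, an allocation's extents list one extent per
  // *physical* dimension, not per logical axis.
  std::vector<int64_t> flat_extents = {1024};        // flat global memory
  std::vector<int64_t> texture_extents = {64, 256};  // e.g. 2-d texture memory

  alloc_info["A"] = AllocEntry{flat_extents.size(), /*level=*/0};
  alloc_info["B"] = AllocEntry{texture_extents.size(), /*level=*/0};

  for (const auto& [name, entry] : alloc_info) {
    std::cout << name << ": " << entry.num_physical_dimensions
              << " physical dimension(s)\n";
  }
}
```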
@@ -104,6 +114,12 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
     if (it != alloc_info_.end() && it->second.alloc) {
       ICHECK_LT(it->second.level, scope_.size());
       scope_[it->second.level].touched.push_back(buf);
+
+      ICHECK_EQ(op->buffer->axis_separators.size() + 1, it->second.num_physical_dimensions)
+          << "Buffer " << op->buffer->name << " is allocated with "
+          << it->second.num_physical_dimensions
+          << " physical dimensions, but is accessed as having "
+          << op->buffer->axis_separators.size() + 1 << " physical dimensions" << std::endl;
     }
     StmtEntry e = scope_.back();
     scope_.pop_back();
@@ -125,6 +141,12 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
     if (it != alloc_info_.end() && it->second.alloc) {
       ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
       scope_[it->second.level].touched.push_back(buf);
+
+      ICHECK_EQ(op->buffer->axis_separators.size() + 1, it->second.num_physical_dimensions)
+          << "Buffer " << op->buffer->name << " is allocated with "
+          << it->second.num_physical_dimensions
+          << " physical dimensions, but is accessed as having "
+          << op->buffer->axis_separators.size() + 1 << " physical dimensions" << std::endl;
     }
   }
 
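The two checks above encode the invariant behind `axis_separators`: N separators partition the logical axes into N + 1 groups, and flattening fuses each group into one physical dimension, so any access through a buffer with N separators must see N + 1 physical dimensions. Here is a self-contained sketch of that fusion rule (a simplified model, not TVM's actual flattening code):

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Fuse a logical shape into a physical shape: each group of logical axes
// between consecutive separators collapses (by product) into one physical
// axis, so N separators always yield N + 1 physical dimensions.
std::vector<int64_t> PhysicalShape(const std::vector<int64_t>& logical_shape,
                                   const std::vector<size_t>& axis_separators) {
  std::vector<int64_t> physical(axis_separators.size() + 1, 1);
  size_t group = 0;
  for (size_t axis = 0; axis < logical_shape.size(); ++axis) {
    while (group < axis_separators.size() && axis >= axis_separators[group]) {
      ++group;
    }
    physical[group] *= logical_shape[axis];
  }
  return physical;
}

int main() {
  // Logical 4-d buffer [2, 3, 4, 5] with one separator after axis 1:
  // physical shape is [2*3, 4*5] = [6, 20], i.e. 2 physical dimensions.
  for (int64_t extent : PhysicalShape({2, 3, 4, 5}, {2})) {
    std::cout << extent << " ";
  }
  std::cout << "\n";  // prints: 6 20
}
```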
@@ -530,6 +552,10 @@ class StoragePlanRewriter : public StmtExprMutator {
     uint64_t const_nbits{0};
     // The storage scope.
     StorageScope scope;
+    // The physical dimensionality of the allocations.  Since
+    // StorageRewrite is applied after StorageFlatten/FlattenBuffer,
+    // this is the size of `AllocateNode::extents`.  If moved
+    size_t ndim;
     // Allocs that shares this entry.
     std::vector<const AllocateNode*> allocs;
     // The children of this entry, not including itself.
@@ -629,8 +655,8 @@ class StoragePlanRewriter : public StmtExprMutator {
         // simply use the original allocation.
         PrimExpr sz = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
                             make_const(DataType::Int(32), 1), e->allocs[0]->extents);
-        e->new_alloc =
-            Allocate(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition, Evaluate(0));
+        e->new_alloc = Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents,
+                                e->allocs[0]->condition, Evaluate(0));
         if (IsSpecialTaggedMemory(e->scope)) {
           MemoryInfo info = GetMemoryInfo(e->scope.to_string());
           uint64_t total_elem = e->const_nbits / e->elem_type.bits();
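Previously every surviving allocation was re-emitted as a flat 1-d buffer of `{sz}` elements, where `sz` is the left fold of multiplication over the extents; the change forwards the original `extents`, so an N-d allocation keeps its physical shape. A small sketch of what the `foldl` computes, assuming constant integer extents (plain STL rather than `PrimExpr`):

```cpp
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // What the removed foldl computed: the flat element count, i.e. the
  // left fold of multiplication over all physical extents.
  std::vector<int64_t> extents = {64, 256};
  int64_t flat_size =
      std::accumulate(extents.begin(), extents.end(), int64_t{1},
                      [](int64_t a, int64_t b) { return a * b; });
  std::cout << flat_size << "\n";  // 16384

  // The patched code instead forwards `extents` unchanged to the new
  // Allocate node, so a {64, 256} texture allocation stays 2-d rather
  // than being collapsed to {16384}.
}
```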
@@ -641,8 +667,13 @@ class StoragePlanRewriter : public StmtExprMutator {
         // Build a merged allocation
         PrimExpr combo_size;
         for (const AllocateNode* op : e->allocs) {
-          PrimExpr sz = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
-                              make_const(DataType::Int(32), 1), op->extents);
+          ICHECK_EQ(op->extents.size(), 1)
+              << "Buffer var " << op->buffer_var->name_hint
+              << " was identified as a re-usable allocation, but has " << op->extents.size()
+              << " physical dimensions.  "
+              << "Currently, only flat 1-d memory spaces should be identified as re-usable "
+                 "allocations.";
+          PrimExpr sz = op->extents[0];
           auto nbits = op->dtype.bits() * op->dtype.lanes();
           if (const auto* imm = sz.as<IntImmNode>()) {
             if (imm->value > std::numeric_limits<int>::max() / nbits) {
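With the new ICHECK, every reuse candidate is known to be 1-d, so its element count is simply `extents[0]`, and the existing guard still rejects sizes whose bit count would overflow a 32-bit `int`. A standalone sketch of that guard (illustrative values, not TVM types):

```cpp
#include <cstdint>
#include <iostream>
#include <limits>

int main() {
  // Per-allocation size in bits for the merged pool: sz elements, each
  // dtype.bits() * dtype.lanes() bits wide.
  int64_t sz = 1 << 20;    // extents[0] of a 1-d reuse candidate
  int64_t nbits = 32 * 1;  // e.g. float32 with 1 lane

  // Guard mirroring the storage planner: check the multiplication
  // against int's maximum before performing it.
  if (sz > std::numeric_limits<int>::max() / nbits) {
    std::cerr << "allocation too large to merge\n";
    return 1;
  }
  std::cout << "size in bits: " << sz * nbits << "\n";
}
```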
@@ -790,7 +821,8 @@ class StoragePlanRewriter : public StmtExprMutator {
 
     for (const VarNode* var : it->second.gen) {
       ICHECK(alloc_info.count(var));
-      const AllocateNode* alloc = alloc_info.at(var).alloc;
+      const AllocEntry& entry = alloc_info.at(var);
+      const AllocateNode* alloc = entry.alloc;
       auto storage_scope = StorageScope::Create(GetPtrStorageScope(GetRef<Var>(var)));
       StorageEntry* dst_entry = nullptr;
       // inplace detection
@@ -818,7 +850,8 @@ class StoragePlanRewriter : public StmtExprMutator {
         }
       }
       if (dst_entry == nullptr) {
-        dst_entry = FindAlloc(alloc, thread_scope_, storage_scope);
+        dst_entry =
+            FindAlloc(alloc, thread_scope_, storage_scope, entry.num_physical_dimensions);
       }
       dst_entry->allocs.emplace_back(alloc);
       alloc_map_[var] = dst_entry;
@@ -871,24 +904,34 @@ class StoragePlanRewriter : public StmtExprMutator {
   }
 
   StorageEntry* FindAlloc(const AllocateNode* op, const Object* attach_scope,
-                          const StorageScope& scope) {
+                          const StorageScope& scope, size_t num_physical_dimensions) {
     ICHECK(op != nullptr);
     // skip plan for local variable,
     // compiler can do a better job with register allocation.
     const uint64_t match_range = 16;
     uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes();
     uint64_t const_nbits = static_cast<uint64_t>(op->ConstantAllocationSize() * op_elem_bits);
+
+    // If the size of the array isn't known at compile-time, it must
+    // have its own allocation with size determined at runtime.
+    bool is_known_size = (const_nbits != 0);
+
+    // Currently, only flat memory spaces can be re-used.  Packing
+    // into N-d space (e.g. 2-d texture memory on GPUs) will require
+    // more in-depth algorithms.
+    bool is_flat_memory_space = (num_physical_dimensions == 1);
+
     // disable reuse of small arrays, they will be lowered to registers in LLVM
     // This rules only apply if we are using non special memory
-    if (scope.tag.length() == 0) {
-      if (scope.rank >= StorageRank::kWarp || op->dtype.is_handle()) {
-        return NewAlloc(op, attach_scope, scope, const_nbits);
-      }
-      if (const_nbits > 0 && const_nbits <= 32) {
-        return NewAlloc(op, attach_scope, scope, const_nbits);
-      }
+    bool is_small_array =
+        (scope.tag.length() == 0) && (scope.rank >= StorageRank::kWarp || op->dtype.is_handle() ||
+                                      (is_known_size && const_nbits <= 32));
+
+    if (is_small_array || !is_flat_memory_space) {
+      return NewAlloc(op, attach_scope, scope, const_nbits);
     }
-    if (const_nbits != 0) {
+
+    if (is_known_size) {
       // constant allocation.
       auto begin = const_free_map_.lower_bound(const_nbits / match_range);
       auto mid = const_free_map_.lower_bound(const_nbits);
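For constant-size allocations, the planner then searches a free map keyed by size in bits, accepting a freed block only if its size is within a factor of `match_range` (16) of the request. The following is a simplified model of that best-fit search; the real planner also handles symbolic sizes, scope attachment, and entry splitting, which this sketch omits:

```cpp
#include <cstdint>
#include <iostream>
#include <iterator>
#include <map>

// Simplified model of const_free_map_: freed constant-size allocations
// keyed by their size in bits, searched for reuse within a factor of
// match_range of the requested size.
int main() {
  const uint64_t match_range = 16;
  std::multimap<uint64_t, const char*> free_map = {
      {256, "small"}, {4096, "medium"}, {1 << 20, "large"}};

  uint64_t const_nbits = 3000;

  // Candidates no smaller than const_nbits / match_range ...
  auto begin = free_map.lower_bound(const_nbits / match_range);
  // ... split at const_nbits between "smaller than requested" and
  // "at least as large as requested".
  auto mid = free_map.lower_bound(const_nbits);

  if (mid != free_map.end() && mid->first <= const_nbits * match_range) {
    std::cout << "reuse (grow into): " << mid->second << "\n";  // medium
  } else if (mid != begin) {
    std::cout << "reuse (shrink into): " << std::prev(mid)->second << "\n";
  } else {
    std::cout << "new allocation\n";
  }
}
```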