@@ -495,86 +495,29 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
       access_annotations_;
 };
 
+/*! \brief The storage alignment for a dimension */
+struct DimAlignInfo {
+  /*! \brief The factor of the alignment */
+  int align_factor{0};
+  /*! \brief The offset of the alignment */
+  int align_offset{0};
+};
+
+struct BufferAllocInfo {
+  /*! \brief The buffer access region. */
+  Region region;
+  /*! \brief The storage alignment information. */
+  std::vector<DimAlignInfo> dim_aligns;
+  /*!
+   * \brief The reallocated buffer with minimal size.
+   * \note The value is NullOpt if the buffer does not need to be reallocated (e.g. a parameter buffer).
+   */
+  Buffer new_buffer;
+};
+
 /*! \brief Reallocate the buffers with minimal region. */
 class BufferCompactor : public StmtExprMutator {
  public:
-  static Stmt Compact(
-      const PrimFunc& f,
-      const std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual>& regions,
-      const std::unordered_map<Var, StorageAlignAnnotation, ObjectPtrHash, ObjectPtrEqual>&
-          storage_align) {
-    // collect buffer allocation info for no-alias buffers
-    std::unordered_map<Var, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> buffer_info;
-    for (const auto& kv : regions) {
-      const Buffer& buffer = kv.first;
-
-      // set dim alignment info
-      Region region = kv.second;
-      BufferAllocInfo alloc_info;
-      auto it = storage_align.find(buffer->data);
-      if (it != storage_align.end()) {
-        std::vector<DimAlignInfo> dim_aligns(buffer->shape.size());
-        for (const StorageAlignTuple& dim_align : (*it).second) {
-          ICHECK(dim_align.size() == 4);
-          int dim = dim_align[1]->value;
-          int factor = dim_align[2]->value;
-          int offset = dim_align[3]->value;
-          dim_aligns.at(dim) = {factor, offset};
-        }
-        alloc_info.dim_aligns = std::move(dim_aligns);
-      }
-
-      // prepare new buffer
-      Array<PrimExpr> shape = region.Map([](const Range& range) { return range->extent; });
-      Array<PrimExpr> strides;
-      if (alloc_info.dim_aligns.size()) {
-        ICHECK(alloc_info.dim_aligns.size() == shape.size());
-        strides.resize(shape.size());
-        PrimExpr stride = make_const(shape[0].dtype(), 1);
-        for (size_t i = shape.size(); i != 0; --i) {
-          size_t dim = i - 1;
-          if (alloc_info.dim_aligns[dim].align_factor != 0) {
-            PrimExpr factor = make_const(stride.dtype(), alloc_info.dim_aligns[dim].align_factor);
-            PrimExpr offset = make_const(stride.dtype(), alloc_info.dim_aligns[dim].align_offset);
-            stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
-          }
-          strides.Set(dim, stride);
-          stride = stride * shape[dim];
-        }
-      }
-      ObjectPtr<BufferNode> n = make_object<BufferNode>(*buffer.get());
-      n->shape = std::move(shape);
-      n->strides = std::move(strides);
-      alloc_info.new_buffer = Buffer(std::move(n));
-      alloc_info.region = region;
-      buffer_info.emplace(buffer->data, std::move(alloc_info));
-    }
-    BufferCompactor compactor(std::move(buffer_info));
-    Stmt stmt = compactor(f->body);
-    return stmt;
-  }
-
- private:
-  /*! \brief The storage alignment for a dimension */
-  struct DimAlignInfo {
-    /*! \brief The factor of the alignment */
-    int align_factor{0};
-    /*! \brief The offset of the alignment */
-    int align_offset{0};
-  };
-
-  struct BufferAllocInfo {
-    /*! \brief The buffer access region. */
-    Region region;
-    /*! \brief The storage alignment information. */
-    std::vector<DimAlignInfo> dim_aligns;
-    /*!
-     * \brief The reallocated buffer with minimal size.
-     * \note The value is NullOpt if the buffer does not need to be reallocated (e.g. a parameter buffer).
-     */
-    Buffer new_buffer;
-  };
-
   explicit BufferCompactor(
       std::unordered_map<Var, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> buffer_info)
       : buffer_info_(std::move(buffer_info)) {}
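
The stride logic removed above (and reintroduced as CalcStrides in the next hunk) rounds each running stride up until stride % factor == offset before recording it. The following standalone sketch shows the same rounding with plain integers instead of PrimExpr; DimAlign and AlignedStrides are illustrative names only and are not part of the codebase.

// Sketch: scalar version of the alignment rounding used by the pass.
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

struct DimAlign {  // stands in for DimAlignInfo
  int64_t factor{0};
  int64_t offset{0};
};

// Computes row-major strides from the innermost dimension outwards; whenever a
// dimension has a non-zero factor, the running stride is bumped to the smallest
// value satisfying stride % factor == offset before it is recorded.
std::vector<int64_t> AlignedStrides(const std::vector<int64_t>& shape,
                                    const std::vector<DimAlign>& aligns) {
  assert(aligns.size() == shape.size());
  std::vector<int64_t> strides(shape.size());
  int64_t stride = 1;
  for (size_t i = shape.size(); i != 0; --i) {
    size_t dim = i - 1;
    if (aligns[dim].factor != 0) {
      int64_t factor = aligns[dim].factor;
      int64_t offset = aligns[dim].offset;
      stride += (factor + offset - stride % factor) % factor;
    }
    strides[dim] = stride;
    stride *= shape[dim];
  }
  return strides;
}

For example, shape {4, 17} with aligns {{16, 0}, {0, 0}} yields strides {32, 1}: the innermost stride stays 1, and the outer stride is padded from 17 up to 32, the next value congruent to 0 modulo 16.
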
@@ -709,13 +652,76 @@ class BufferCompactor : public StmtExprMutator {
   std::unordered_map<Var, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> buffer_info_;
 };
 
+Array<PrimExpr> CalcStrides(const BufferAllocInfo& alloc_info, const Array<PrimExpr>& shape) {
+  std::vector<PrimExpr> strides;
+  if (alloc_info.dim_aligns.size()) {
+    ICHECK(alloc_info.dim_aligns.size() == shape.size());
+    strides.resize(shape.size());
+    PrimExpr stride = make_const(shape[0].dtype(), 1);
+    for (size_t i = shape.size(); i != 0; --i) {
+      size_t dim = i - 1;
+      DimAlignInfo info = alloc_info.dim_aligns[dim];
+      int align_factor = info.align_factor;
+      int align_offset = info.align_offset;
+      if (align_factor != 0) {
+        PrimExpr factor = make_const(stride.dtype(), align_factor);
+        PrimExpr offset = make_const(stride.dtype(), align_offset);
+        stride = stride + indexmod(factor + offset - indexmod(stride, factor), factor);
+      }
+      strides[dim] = stride;
+      stride = stride * shape[dim];
+    }
+  }
+  return strides;
+}
+
+Stmt BufferCompactorCompact(
+    const PrimFunc& f,
+    const std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual>& regions,
+    const std::unordered_map<Var, StorageAlignAnnotation, ObjectPtrHash, ObjectPtrEqual>&
+        storage_align) {
+  // collect buffer allocation info for no-alias buffers
+  std::unordered_map<Var, BufferAllocInfo, ObjectPtrHash, ObjectPtrEqual> buffer_info;
+  for (const auto& kv : regions) {
+    const Buffer& buffer = kv.first;
+    // set dim alignment info
+    Region region = kv.second;
+    BufferAllocInfo alloc_info;
+    auto it = storage_align.find(buffer->data);
+    if (it != storage_align.end()) {
+      std::vector<DimAlignInfo> dim_aligns(buffer->shape.size());
+      for (const StorageAlignTuple& dim_align : (*it).second) {
+        ICHECK(dim_align.size() == 4);
+        int dim = dim_align[1]->value;
+        int factor = dim_align[2]->value;
+        int offset = dim_align[3]->value;
+        dim_aligns.at(dim) = {factor, offset};
+      }
+      alloc_info.dim_aligns = std::move(dim_aligns);
+    }
+
+    // prepare new buffer
+    Array<PrimExpr> shape = region.Map([](const Range& range) { return range->extent; });
+    Array<PrimExpr> strides = CalcStrides(alloc_info, shape);
+    ObjectPtr<BufferNode> n = make_object<BufferNode>(*buffer.get());
+    n->shape = std::move(shape);
+    n->strides = std::move(strides);
+    alloc_info.new_buffer = Buffer(std::move(n));
+    alloc_info.region = region;
+    buffer_info.emplace(buffer->data, std::move(alloc_info));
+  }
+  BufferCompactor compactor(std::move(buffer_info));
+  Stmt stmt = compactor(f->body);
+  return stmt;
+}
+
 PrimFunc CompactBufferAllocation(PrimFunc f, bool is_strict) {
   // Only apply this pass to TIR that is not from TE schedules
   if (!IsFromLegacyTESchedule(f)) {
     PrimFuncNode* fptr = f.CopyOnWrite();
     auto region = BufferAccessRegionCollector::Collect(f, /*collect_inbound=*/is_strict);
     auto storage_align = CollectStorageAlignAnnotation(f->body);
-    fptr->body = BufferCompactor::Compact(f, region, storage_align);
+    fptr->body = BufferCompactorCompact(f, region, storage_align);
     return f;
   } else {
     return f;
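
On the annotation format consumed by BufferCompactorCompact: each StorageAlignTuple holds four integers, read here as (buffer_index, dim, factor, offset), so only elements 1 through 3 are used to fill the per-dimension dim_aligns entry, while element 0 is assumed to identify the buffer within the block annotation. Below is a minimal standalone sketch of that decoding, with plain arrays in place of the TIR tuple type; AlignTuple and CollectDimAligns are hypothetical names, and DimAlign is redeclared so the sketch stands alone.

#include <array>
#include <cstddef>
#include <vector>

struct DimAlign {  // stands in for DimAlignInfo
  int factor{0};
  int offset{0};
};

// One storage_align entry: {buffer_index, dim, factor, offset}.
using AlignTuple = std::array<int, 4>;

// Mirrors the decoding loop in BufferCompactorCompact: every annotated
// dimension gets its {factor, offset}; dimensions without an annotation keep
// the default {0, 0}, which the stride computation treats as "no constraint".
std::vector<DimAlign> CollectDimAligns(const std::vector<AlignTuple>& annotations, size_t ndim) {
  std::vector<DimAlign> dim_aligns(ndim);
  for (const AlignTuple& t : annotations) {
    int dim = t[1];
    int factor = t[2];
    int offset = t[3];
    dim_aligns.at(dim) = {factor, offset};
  }
  return dim_aligns;
}

For a 2-D buffer with a single annotation {0, 0, 16, 0}, this produces dim_aligns = {{16, 0}, {0, 0}}; combined with the stride sketch earlier, an extent-17 inner dimension would then give the outer dimension a padded stride of 32.
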