Commit 897019d

[Pass][Bugfix] Disable re-use of non-flat buffers in StorageRewrite. (#10787)
* [Pass][Bugfix] Disable re-use of non-flat buffers in StorageRewrite.

  As a follow-up from #9727, this restricts StorageRewrite to only modify flat memory buffers. When rewriting, the existing algorithm in StorageRewrite flattens N-d allocations into 1-d allocations, preventing them from being exposed to the codegen.

* Bugfix: flattening of Allocate/AllocateConst extents.

  Previously, these extents were ignored entirely. This worked so long as all allocations were 1-d, as `StorageRewrite` erroneously flattened merged arrays into 1-d.
1 parent b5cad84 commit 897019d
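
For reference, the flattening rule that both passes rely on is small enough to state directly: a buffer's axis_separators split an allocation's extents into contiguous groups, each group multiplies out to one physical dimension, and an allocation is flat exactly when a single physical dimension remains. Below is a minimal standalone sketch in plain C++ (illustrative names, not TVM's API):

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Collapse N logical extents into one extent per physical dimension.
// Each axis separator starts a new physical dimension.
std::vector<int64_t> FlattenExtents(const std::vector<int64_t>& extents,
                                    const std::vector<size_t>& axis_separators) {
  std::vector<int64_t> physical_shape;
  size_t sep_index = 0;
  int64_t group = 1;
  for (size_t i = 0; i < extents.size(); i++) {
    if (sep_index < axis_separators.size() && axis_separators[sep_index] == i && i != 0) {
      physical_shape.push_back(group);
      group = 1;
      sep_index++;
    }
    group *= extents[i];
  }
  physical_shape.push_back(group);
  return physical_shape;
}

int main() {
  // extents = [4, 16, 32] with axis_separators = [1] -> physical shape
  // [4, 512]: two physical dimensions (e.g. a 2-d texture), not flat.
  for (int64_t dim : FlattenExtents({4, 16, 32}, {1})) {
    std::cout << dim << " ";  // prints: 4 512
  }
  std::cout << "\n";
  // No separators -> one flat physical dimension of 2048 elements.
  for (int64_t dim : FlattenExtents({4, 16, 32}, {})) {
    std::cout << dim << " ";  // prints: 2048
  }
  std::cout << "\n";
}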

File tree: 2 files changed, +155 -19 lines changed

src/tir/transforms/storage_flatten.cc

Lines changed: 95 additions & 2 deletions
@@ -1405,12 +1405,25 @@ class StorageFlattener : public StmtExprMutator {
   // rather than a buffer_var.
   Stmt VisitStmt_(const AllocateNode* op) final {
     buffer_var_defines_.insert(op->buffer_var.get());
-    return StmtExprMutator::VisitStmt_(op);
+    auto stmt = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
+    return Allocate(stmt->buffer_var, stmt->dtype, FlattenExtents(stmt), stmt->condition,
+                    stmt->body, stmt->annotations, stmt->span);
   }

   Stmt VisitStmt_(const AllocateConstNode* op) final {
     buffer_var_defines_.insert(op->buffer_var.get());
-    return StmtExprMutator::VisitStmt_(op);
+    auto stmt = Downcast<AllocateConst>(StmtExprMutator::VisitStmt_(op));
+    ObjectRef data_or_idx;
+    if (stmt->data) {
+      data_or_idx = stmt->data.value();
+    } else if (stmt->irmod_storage_idx) {
+      data_or_idx = stmt->irmod_storage_idx.value();
+    } else {
+      LOG(FATAL) << "Neither data array nor data index specified for allocation of const "
+                 << op->buffer_var->name_hint;
+    }
+    return AllocateConst(stmt->buffer_var, stmt->dtype, FlattenExtents(stmt), data_or_idx,
+                         stmt->body, stmt->span);
   }

   Stmt VisitStmt_(const LetStmtNode* op) final {
@@ -1598,6 +1611,82 @@ class StorageFlattener : public StmtExprMutator {
   }

  private:
+  // Helper function for visiting Allocate and AllocateConst. If, in
+  // the future, these are updated to hold a buffer (Buffer) object
+  // rather than a buffer_var (Var), this function can be replaced
+  // with a call to GetBufferEntry.
+  template <typename Node>
+  Array<PrimExpr> FlattenExtents(const Node& node) {
+    arith::Analyzer analyzer;
+
+    // Check whether an allocation's extents match the buffer's shape.
+    auto is_compatible_buffer = [&](const Buffer& buffer) {
+      if (buffer->shape.size() != node->extents.size()) {
+        return false;
+      }
+      for (size_t i = 0; i < buffer->shape.size(); i++) {
+        if (!analyzer.CanProveEqual(buffer->shape[i], node->extents[i])) {
+          return false;
+        }
+      }
+
+      return true;
+    };
+
+    auto int_array_equal = [](const Array<IntImm>& a, const Array<IntImm>& b) {
+      if (a.size() != b.size()) {
+        return false;
+      }
+
+      for (size_t i = 0; i < a.size(); i++) {
+        if (a[i]->value != b[i]->value) {
+          return false;
+        }
+      }
+
+      return true;
+    };
+
+    Array<IntImm> axis_separators;
+    auto it = buffer_var_map_.find(node->buffer_var.get());
+    if (it != buffer_var_map_.end()) {
+      const auto& buffers = it->second;
+      if (buffers.size() == 0) {
+        // No buffers use this allocation, treat as flat and optimize
+        // out later.
+      } else if (buffers.size() == 1) {
+        // Only one buffer uses this allocation, so use its axis
+        // separators.
+        axis_separators = buffers[0]->axis_separators;
+      } else {
+        // Try to find a buffer using this allocation with a matching
+        // shape.
+        Buffer compatible_buffer;
+        for (const auto& buffer : buffers) {
+          if (is_compatible_buffer(buffer)) {
+            ICHECK(!compatible_buffer.defined() ||
+                   int_array_equal(compatible_buffer->axis_separators, buffer->axis_separators))
+                << "Cannot determine axis separators to use when flattening "
+                << node->buffer_var->name_hint
+                << ", multiple buffer objects found with conflicting axis separators";
+            compatible_buffer = buffer;
+          }
+        }
+        ICHECK(compatible_buffer.defined())
+            << "Cannot determine axis separators to use when flattening "
+            << node->buffer_var->name_hint << ", no buffers found with matching shape";
+        axis_separators = compatible_buffer->axis_separators;
+      }
+    }
+
+    // Use GetFlattenedBuffer to determine the flattened shape of the
+    // output. We only need the shape and axis separators defined,
+    // everything else can be dummy values.
+    Buffer dummy_buffer =
+        decl_buffer(node->extents, DataType::Float(32), "buffer", "", axis_separators);
+    return dummy_buffer.GetFlattenedBuffer()->shape;
+  }
+
   // The buffer entry in the flatten map
   struct DimAlignInfo {
     int align_factor{0};
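
The case analysis in FlattenExtents above can be paraphrased outside of TVM's types. A simplified standalone model in plain C++ (illustrative names; exact shape equality stands in for analyzer.CanProveEqual):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

struct BufferView {
  std::vector<int64_t> shape;
  std::vector<size_t> axis_separators;
};

std::vector<size_t> SelectAxisSeparators(const std::vector<int64_t>& extents,
                                         const std::vector<BufferView>& aliases) {
  if (aliases.empty()) return {};                               // no users: treat as flat
  if (aliases.size() == 1) return aliases[0].axis_separators;   // unique user: take its separators

  // Several aliasing buffers: require a shape-compatible one whose
  // axis separators are unambiguous.
  const BufferView* compatible = nullptr;
  for (const BufferView& buf : aliases) {
    if (buf.shape != extents) continue;
    assert((!compatible || compatible->axis_separators == buf.axis_separators) &&
           "conflicting axis separators");
    compatible = &buf;
  }
  assert(compatible && "no aliasing buffer with a matching shape");
  return compatible->axis_separators;
}

int main() {
  BufferView a{{4, 16, 32}, {1}};
  BufferView b{{4, 16, 32}, {1}};
  std::vector<size_t> seps = SelectAxisSeparators({4, 16, 32}, {a, b});
  assert((seps == std::vector<size_t>{1}));
  return 0;
}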
@@ -1665,6 +1754,10 @@ class StorageFlattener : public StmtExprMutator {
   // Set of vars that have occurred in an AllocateNode, but haven't
   // yet occurred in a BufferLoad/BufferStore.
   std::unordered_set<const VarNode*> buffer_var_defines_;
+  // Map from an allocation variable to the buffer(s) that it backs.
+  // Used to determine the axis_separators that should be used for
+  // flattening the extents of an AllocateNode.
+  std::unordered_map<const VarNode*, std::vector<Buffer>> buffer_var_map_;
   // Buffer map
   std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;
   // The extern buffer map, updated to include flattened buffers.

src/tir/transforms/storage_rewrite.cc

Lines changed: 60 additions & 17 deletions
@@ -76,6 +76,8 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
   };
   // The scope of each allocation
   struct AllocEntry {
+    // The number of physical dimensions of the allocation.
+    size_t num_physical_dimensions{0};
     // scope level
     size_t level{0};
     // allocation stmt
@@ -85,8 +87,16 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
   void VisitStmt_(const AllocateNode* op) final {
     size_t level = scope_.size();
     const VarNode* buf = op->buffer_var.get();
-    alloc_info_[buf].alloc = op;
-    alloc_info_[buf].level = level;
+
+    AllocEntry entry;
+    entry.alloc = op;
+    entry.level = level;
+    // Since StorageRewrite occurs after StorageFlatten/FlattenBuffer,
+    // all allocations specify the extents of their physical dimensions;
+    // a flat memory space has exactly one such extent.
+    entry.num_physical_dimensions = op->extents.size();
+    alloc_info_[buf] = entry;
+
     StmtExprVisitor::VisitStmt_(op);
   }

@@ -104,6 +114,12 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
     if (it != alloc_info_.end() && it->second.alloc) {
       ICHECK_LT(it->second.level, scope_.size());
       scope_[it->second.level].touched.push_back(buf);
+
+      ICHECK_EQ(op->buffer->axis_separators.size() + 1, it->second.num_physical_dimensions)
+          << "Buffer " << op->buffer->name << " is allocated with "
+          << it->second.num_physical_dimensions
+          << " physical dimensions, but is accessed as having "
+          << op->buffer->axis_separators.size() + 1 << " physical dimensions" << std::endl;
     }
     StmtEntry e = scope_.back();
     scope_.pop_back();
@@ -125,6 +141,12 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
     if (it != alloc_info_.end() && it->second.alloc) {
       ICHECK_LT(it->second.level, scope_.size()) << "Load memory in places other than store.";
       scope_[it->second.level].touched.push_back(buf);
+
+      ICHECK_EQ(op->buffer->axis_separators.size() + 1, it->second.num_physical_dimensions)
+          << "Buffer " << op->buffer->name << " is allocated with "
+          << it->second.num_physical_dimensions
+          << " physical dimensions, but is accessed as having "
+          << op->buffer->axis_separators.size() + 1 << " physical dimensions" << std::endl;
     }
   }

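Both checks enforce the same invariant: a buffer view with k axis_separators addresses k + 1 physical dimensions, which must equal the number of extents on the (already flattened) allocation backing it. A minimal plain-C++ illustration (names are illustrative, not TVM's API):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
  // A 2-d physical allocation, e.g. the backing store of a texture.
  std::vector<int64_t> alloc_extents = {4, 512};
  // A buffer view with one axis separator addresses 1 + 1 = 2
  // physical dimensions, so this access is consistent.
  std::vector<size_t> axis_separators = {1};
  size_t num_physical_dimensions = alloc_extents.size();
  assert(axis_separators.size() + 1 == num_physical_dimensions);
  return 0;
}
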
@@ -530,6 +552,10 @@ class StoragePlanRewriter : public StmtExprMutator {
     uint64_t const_nbits{0};
     // The storage scope.
     StorageScope scope;
+    // The physical dimensionality of the allocations. Since
+    // StorageRewrite is applied after StorageFlatten/FlattenBuffer,
+    // this is the size of `AllocateNode::extents`. If moved
+    size_t ndim;
     // Allocs that share this entry.
     std::vector<const AllocateNode*> allocs;
     // The children of this entry, not including itself.
@@ -629,8 +655,8 @@ class StoragePlanRewriter : public StmtExprMutator {
       // simply use the original allocation.
       PrimExpr sz = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
                           make_const(DataType::Int(32), 1), e->allocs[0]->extents);
-      e->new_alloc =
-          Allocate(e->alloc_var, alloc_type, {sz}, e->allocs[0]->condition, Evaluate(0));
+      e->new_alloc = Allocate(e->alloc_var, alloc_type, e->allocs[0]->extents,
+                              e->allocs[0]->condition, Evaluate(0));
       if (IsSpecialTaggedMemory(e->scope)) {
         MemoryInfo info = GetMemoryInfo(e->scope.to_string());
         uint64_t total_elem = e->const_nbits / e->elem_type.bits();
@@ -641,8 +667,13 @@ class StoragePlanRewriter : public StmtExprMutator {
       // Build a merged allocation
       PrimExpr combo_size;
       for (const AllocateNode* op : e->allocs) {
-        PrimExpr sz = foldl([](PrimExpr a, PrimExpr b, Span span) { return mul(a, b, span); },
-                            make_const(DataType::Int(32), 1), op->extents);
+        ICHECK_EQ(op->extents.size(), 1)
+            << "Buffer var " << op->buffer_var->name_hint
+            << " was identified as a re-usable allocation, but has " << op->extents.size()
+            << " physical dimensions. "
+            << "Currently, only flat 1-d memory spaces should be identified as re-usable "
+               "allocations.";
+        PrimExpr sz = op->extents[0];
         auto nbits = op->dtype.bits() * op->dtype.lanes();
         if (const auto* imm = sz.as<IntImmNode>()) {
           if (imm->value > std::numeric_limits<int>::max() / nbits) {
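
This check is why only 1-d allocations may be merged: packing allocations into one shared region reasons about linear element counts, which is only meaningful when each allocation occupies a single contiguous physical dimension. A hedged standalone sketch in plain C++ (illustrative names; the real pass additionally rescales for differing element widths and handles dynamic extents):

#include <algorithm>
#include <cstdint>
#include <vector>

struct Alloc {
  int64_t extent;  // flat element count (extents[0] after flattening)
  int bits;        // element width in bits
};

// Size in bits needed so one region can back each allocation in turn,
// assuming their live ranges do not overlap.
int64_t CombinedBits(const std::vector<Alloc>& allocs) {
  int64_t combined = 0;
  for (const Alloc& a : allocs) {
    combined = std::max(combined, a.extent * a.bits);
  }
  return combined;
}

int main() {
  // A float32[128] and an int8[300] that are live at disjoint times can
  // share one 4096-bit region; a 2-d texture allocation could not be
  // folded into this computation without discarding its layout.
  return CombinedBits({{128, 32}, {300, 8}}) == 4096 ? 0 : 1;
}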
@@ -790,7 +821,8 @@ class StoragePlanRewriter : public StmtExprMutator {

       for (const VarNode* var : it->second.gen) {
         ICHECK(alloc_info.count(var));
-        const AllocateNode* alloc = alloc_info.at(var).alloc;
+        const AllocEntry& entry = alloc_info.at(var);
+        const AllocateNode* alloc = entry.alloc;
         auto storage_scope = StorageScope::Create(GetPtrStorageScope(GetRef<Var>(var)));
         StorageEntry* dst_entry = nullptr;
         // inplace detection
@@ -818,7 +850,8 @@ class StoragePlanRewriter : public StmtExprMutator {
         }
       }
       if (dst_entry == nullptr) {
-        dst_entry = FindAlloc(alloc, thread_scope_, storage_scope);
+        dst_entry =
+            FindAlloc(alloc, thread_scope_, storage_scope, entry.num_physical_dimensions);
       }
       dst_entry->allocs.emplace_back(alloc);
       alloc_map_[var] = dst_entry;
@@ -871,24 +904,34 @@ class StoragePlanRewriter : public StmtExprMutator {
   }

   StorageEntry* FindAlloc(const AllocateNode* op, const Object* attach_scope,
-                          const StorageScope& scope) {
+                          const StorageScope& scope, size_t num_physical_dimensions) {
     ICHECK(op != nullptr);
     // skip plan for local variable,
     // compiler can do a better job with register allocation.
     const uint64_t match_range = 16;
     uint64_t op_elem_bits = op->dtype.bits() * op->dtype.lanes();
     uint64_t const_nbits = static_cast<uint64_t>(op->ConstantAllocationSize() * op_elem_bits);
+
+    // If the size of the array isn't known at compile-time, it must
+    // have its own allocation with size determined at runtime.
+    bool is_known_size = (const_nbits != 0);
+
+    // Currently, only flat memory spaces can be re-used. Packing
+    // into N-d space (e.g. 2-d texture memory on GPUs) will require
+    // more in-depth algorithms.
+    bool is_flat_memory_space = (num_physical_dimensions == 1);
+
     // disable reuse of small arrays, they will be lowered to registers in LLVM
     // This rule only applies if we are using non-special memory
-    if (scope.tag.length() == 0) {
-      if (scope.rank >= StorageRank::kWarp || op->dtype.is_handle()) {
-        return NewAlloc(op, attach_scope, scope, const_nbits);
-      }
-      if (const_nbits > 0 && const_nbits <= 32) {
-        return NewAlloc(op, attach_scope, scope, const_nbits);
-      }
+    bool is_small_array =
+        (scope.tag.length() == 0) && (scope.rank >= StorageRank::kWarp || op->dtype.is_handle() ||
+                                      (is_known_size && const_nbits <= 32));
+
+    if (is_small_array || !is_flat_memory_space) {
+      return NewAlloc(op, attach_scope, scope, const_nbits);
     }
-    if (const_nbits != 0) {
+
+    if (is_known_size) {
       // constant allocation.
       auto begin = const_free_map_.lower_bound(const_nbits / match_range);
       auto mid = const_free_map_.lower_bound(const_nbits);