Skip to content

Commit 756debc

Browse files
authored
[TIR][USMP] Greedy memory planning algorithm (#9214)
This commit implements a greedy memory planning algorithms using the proposed USMP design.There are two greedy algorithms introduced here which use the size and number of conflicts as the criteria. - Adds few test cases checks for fan-out and linear structures. - Added a test case for ResNet sub-structure - Added a test case for MobileNet sub-structure - This includes a slight fix for buffer info extraction where non-linear network buffers owned by the main function should not show sporadic liveness.
1 parent 0cd6868 commit 756debc

File tree

5 files changed

+1341
-280
lines changed

5 files changed

+1341
-280
lines changed

python/tvm/tir/usmp/utils.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -114,14 +114,6 @@ def __init__(
114114
alignment,
115115
)
116116

117-
def set_pool_candidates(self, pool_candidates: list):
118-
"""Sets the pool candidate names"""
119-
_ffi_api.BufferInfoSetPoolCandidates(self, pool_candidates)
120-
121-
def set_pool_offsets(self, pool_name: str, pool_offset: int):
122-
"""Sets the pool offset by name"""
123-
_ffi_api.BufferInfoSetPoolOffset(self, pool_name, pool_offset)
124-
125117
def set_conflicts(self, conflicts: list):
126118
"""Sets the the conflicting array of buffer info objects"""
127119
_ffi_api.BufferInfoSetConflicts(self, conflicts)

src/tir/usmp/algo/greedy.cc

Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tir/analysis/usmp/algo/greedy.cc
22+
* \brief This source contains greedy algorithms for planning
23+
* memory for USMP. There are two algorithms present here :
24+
* 1) greedy_by_size and 2) greedy_by_conflicts.
25+
*
26+
* greedy_by_size : this algorithm prioritizes placing the
27+
* largest size buffer to the given pools. The BufferInfo objects
28+
* are sorted based on the size and placed on each pool adhering
29+
* to size_hint constraint.
30+
*
31+
* greedy_by_conflicts : this algorithm prioritizes placing the
32+
* the most liveness conflicted buffer to the given pools. The
33+
* BufferInfo objects are sorted based on the number of conflicts
34+
* and placed on each pool adhering to size_hint constraint.
35+
*/
36+
37+
#include <tvm/arith/analyzer.h>
38+
#include <tvm/runtime/device_api.h>
39+
#include <tvm/tir/builtin.h>
40+
#include <tvm/tir/function.h>
41+
#include <tvm/tir/stmt_functor.h>
42+
#include <tvm/tir/usmp/utils.h>
43+
44+
namespace tvm {
45+
namespace tir {
46+
namespace usmp {
47+
namespace algo {
48+
49+
/*!
50+
* \brief This is the base class for Greedy Algorithms where the sorting
51+
* is specialized in the extended classes based on the greedy criteria.
52+
*/
53+
class GreedyBase {
54+
public:
55+
GreedyBase() {}
56+
/*!
57+
* \brief This function should be implemented by the extended classes to sort the BufferInfo
58+
* objects based on a criteria and then calling PostSortAllocation.
59+
*/
60+
virtual Map<BufferInfo, PoolAllocation> PlanMemory(const Array<BufferInfo>& buffer_info_arr) = 0;
61+
62+
protected:
63+
/*!
64+
* \brief Rounds up the offset to satisfy the alignement requirement
65+
*/
66+
size_t round_up_to_byte_alignment(const size_t& non_aligned_byte_offset,
67+
const int& byte_alignment) {
68+
return ((non_aligned_byte_offset + byte_alignment - 1) / byte_alignment) * byte_alignment;
69+
}
70+
71+
/*!
72+
* \brief A helper function check whether a offset is valid given the constraints
73+
*/
74+
bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset,
75+
const size_t& size_bytes) {
76+
if (candidate_pool->size_hint_bytes == -1) {
77+
// this means pool is not bounded
78+
return true;
79+
}
80+
auto pool_size = static_cast<size_t>(candidate_pool->size_hint_bytes->value);
81+
auto max_address = next_offset + size_bytes;
82+
if (max_address <= pool_size) {
83+
return true;
84+
}
85+
return false;
86+
}
87+
88+
/*!
89+
* \brief Selects a pool for placement in the given set of ordered pool candidates
90+
*/
91+
PoolInfo SelectPlacementPool(
92+
const BufferInfo& buf_info,
93+
const std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual>& pool_offsets) {
94+
// Here the pool candidates are ordered when it is consumed by the algorithm.
95+
// This could be from order the user has specified. However, schedulers are
96+
// welcome to change the order for performance reasons.
97+
for (const auto& pool_info : buf_info->pool_candidates) {
98+
if (pool_offsets.count(pool_info)) {
99+
return pool_info;
100+
}
101+
}
102+
CHECK(false) << "TVM USMP Error: the space available in the provided pools exceeded when "
103+
"trying to allocate the buffer : "
104+
<< buf_info << "\n. Please increase the size_hints for memory pools.";
105+
return PoolInfo();
106+
}
107+
108+
/*!
109+
* \brief This is the base allocation function that works on sorted BufferInfo objects based
110+
* on the greedy heuristic. The sorting algorithm has to be called before calling this.
111+
*/
112+
Map<BufferInfo, PoolAllocation> PostSortAllocation(
113+
const std::vector<BufferInfo>& buffer_info_vec) {
114+
Map<BufferInfo, PoolAllocation> pool_allocations;
115+
for (const auto& buf_info : buffer_info_vec) {
116+
std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual> pool_offset_candidates;
117+
for (const auto& pool_info : buf_info->pool_candidates) {
118+
// Mark pool candidates that satisfy the size constraints.
119+
if (IsValidPlacement(pool_info, 0, buf_info->size_bytes->value)) {
120+
pool_offset_candidates[pool_info] = 0;
121+
}
122+
}
123+
124+
for (const auto& conflict_buf_info_obj : buf_info->conflicts) {
125+
auto conflict_buf_info = Downcast<BufferInfo>(conflict_buf_info_obj);
126+
size_t next_offset = 0;
127+
// We only look at already allocated BufferInfo in-terms of conflicts.
128+
if (pool_allocations.count(conflict_buf_info)) {
129+
auto pool_allocation = pool_allocations[conflict_buf_info];
130+
next_offset = pool_allocation->byte_offset + conflict_buf_info->size_bytes;
131+
next_offset =
132+
round_up_to_byte_alignment(next_offset, conflict_buf_info->alignment->value);
133+
// Checks whether the next offset in the same pool as the conflicting BufferInfo is valid.
134+
if (IsValidPlacement(pool_allocation->pool_info, next_offset,
135+
buf_info->size_bytes->value)) {
136+
// There could be multiple conflicting BufferInfo in the same pool.
137+
// Thus, we need to make sure we pick the largest offset of them all.
138+
if (next_offset > pool_offset_candidates[pool_allocation->pool_info]) {
139+
pool_offset_candidates[pool_allocation->pool_info] = next_offset;
140+
}
141+
} else {
142+
pool_offset_candidates.erase(pool_allocation->pool_info);
143+
}
144+
}
145+
}
146+
auto selected_pool = SelectPlacementPool(buf_info, pool_offset_candidates);
147+
pool_allocations.Set(
148+
buf_info, PoolAllocation(selected_pool, Integer(pool_offset_candidates[selected_pool])));
149+
}
150+
return pool_allocations;
151+
}
152+
};
153+
154+
/*!
155+
* \brief This class implements Greedy by the size of BufferInfo
156+
* greedy algorithm. Please refer to main documentation of the file
157+
* for more details.
158+
*/
159+
class GreedySize : public GreedyBase {
160+
public:
161+
GreedySize() {}
162+
Map<BufferInfo, PoolAllocation> PlanMemory(const Array<BufferInfo>& buffer_info_arr) {
163+
std::vector<BufferInfo> buffer_info_vec;
164+
Map<BufferInfo, PoolAllocation> pool_allocations;
165+
for (const auto& buffer_info : buffer_info_arr) {
166+
buffer_info_vec.push_back(std::move(buffer_info));
167+
}
168+
std::sort(buffer_info_vec.begin(), buffer_info_vec.end(),
169+
[](const BufferInfo& a, const BufferInfo& b) {
170+
if (a->size_bytes->value == b->size_bytes->value) {
171+
if (a->conflicts.size() == b->conflicts.size()) {
172+
return std::string(a->name_hint->data) > std::string(b->name_hint->data);
173+
} else {
174+
return a->conflicts.size() > b->conflicts.size();
175+
}
176+
}
177+
return a->size_bytes > b->size_bytes;
178+
});
179+
return PostSortAllocation(buffer_info_vec);
180+
}
181+
};
182+
183+
/*!
184+
* \brief This class implements Greedy by the number of conflicts of
185+
* BufferInfo greedy algorithm. Please refer to main documentation
186+
* of the file for more details.
187+
*/
188+
class GreedyConflicts : public GreedyBase {
189+
public:
190+
GreedyConflicts() {}
191+
Map<BufferInfo, PoolAllocation> PlanMemory(const Array<BufferInfo>& buffer_info_arr) {
192+
std::vector<BufferInfo> buffer_info_vec;
193+
Map<BufferInfo, PoolAllocation> pool_allocations;
194+
for (const auto& buffer_info : buffer_info_arr) {
195+
buffer_info_vec.push_back(std::move(buffer_info));
196+
}
197+
std::sort(buffer_info_vec.begin(), buffer_info_vec.end(),
198+
[](const BufferInfo& a, const BufferInfo& b) {
199+
if (a->conflicts.size() == b->conflicts.size()) {
200+
if (a->size_bytes->value == b->size_bytes->value) {
201+
return std::string(a->name_hint->data) > std::string(b->name_hint->data);
202+
} else {
203+
return a->size_bytes->value > b->size_bytes->value;
204+
}
205+
}
206+
return a->conflicts.size() > b->conflicts.size();
207+
});
208+
return PostSortAllocation(buffer_info_vec);
209+
}
210+
};
211+
212+
Map<BufferInfo, PoolAllocation> GreedyBySize(const Array<BufferInfo>& buffer_info_arr) {
213+
return GreedySize().PlanMemory(buffer_info_arr);
214+
}
215+
216+
Map<BufferInfo, PoolAllocation> GreedyByConflicts(const Array<BufferInfo>& buffer_info_arr) {
217+
return GreedyConflicts().PlanMemory(buffer_info_arr);
218+
}
219+
220+
TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_size")
221+
.set_body_typed([](Array<BufferInfo> buffer_info_arr) {
222+
return GreedyBySize(buffer_info_arr);
223+
});
224+
225+
TVM_REGISTER_GLOBAL("tir.usmp.algo.greedy_by_conflicts")
226+
.set_body_typed([](Array<BufferInfo> buffer_info_arr) {
227+
return GreedyByConflicts(buffer_info_arr);
228+
});
229+
230+
} // namespace algo
231+
} // namespace usmp
232+
} // namespace tir
233+
} // namespace tvm

0 commit comments

Comments
 (0)