Skip to content

Commit 7c93c60

Browse files
committed
[TIR][USMP] Greedy algorithms for USMP
This commits removes commented out lines ,few trivial cleanups and few BufferInfo based tests to check the algorithm. Change-Id: I1a12b6a424370e9e4c4a55563dde0ad698b07ea3
1 parent 78e099b commit 7c93c60

File tree

4 files changed

+217
-32
lines changed

4 files changed

+217
-32
lines changed

python/tvm/tir/usmp/utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ def __init__(
114114
alignment,
115115
)
116116

117+
def set_conflicts(self, conflicts: list):
118+
"""Sets the the conflicting array of buffer info objects"""
119+
_ffi_api.BufferInfoSetConflicts(self, conflicts)
120+
117121

118122
@register_object("tir.usmp.PoolAllocation")
119123
class PoolAllocation(Object):

src/tir/usmp/algo/greedy.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
*/
1919

2020
/*!
21-
* \file tir/analysis/usmp/algo/greedy_by_size.cc
21+
* \file tir/analysis/usmp/algo/greedy.cc
2222
* \brief This source contains greedy algorithms for planning
2323
* memory for USMP. There are two algorithms present here :
2424
* 1) greedy_by_size and 2) greedy_by_conflicts.
@@ -89,17 +89,17 @@ class GreedyBase {
8989
* \brief Selects a pool for placement in the given set of ordered pool candidates
9090
*/
9191
PoolInfo SelectPlacementPool(
92-
const Array<PoolInfo>& pool_candidates,
92+
const BufferInfo& buf_info,
9393
const std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual>& pool_offsets) {
9494
// Here the pool candidates are ordered when it is consumed by the algorithm.
9595
// This could be from order the user has specified. However, schedulers are
9696
// welcome to change the order for performance reasons.
97-
for (const auto& pool_info : pool_candidates) {
97+
for (const auto& pool_info : buf_info->pool_candidates) {
9898
if (pool_offsets.count(pool_info)) {
9999
return pool_info;
100100
}
101101
}
102-
ICHECK(false) << "TVM USMP Internal Error: no candidate have been selected!";
102+
CHECK(false) << "TVM USMP Error: no candidate have been selected for " << buf_info;
103103
return PoolInfo();
104104
}
105105

@@ -141,7 +141,7 @@ class GreedyBase {
141141
}
142142
}
143143
}
144-
auto selected_pool = SelectPlacementPool(buf_info->pool_candidates, pool_offset_candidates);
144+
auto selected_pool = SelectPlacementPool(buf_info, pool_offset_candidates);
145145
pool_allocations.Set(
146146
buf_info, PoolAllocation(selected_pool, Integer(pool_offset_candidates[selected_pool])));
147147
}

src/tir/usmp/analysis/extract_buffer_info.cc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,6 @@ Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_
454454

455455
// Traverse the liveness events using a open set to track what
456456
// is live while updating the conflicts through out the linear traversal
457-
// std::unordered_set<BufferInfo, ObjectPtrHash, ObjectPtrEqual> open_set;
458457
std::unordered_map<BufferInfo, int, ObjectPtrHash, ObjectPtrEqual> open_set;
459458
for (const auto& le_event : le_events_timeline) {
460459
if (le_event.le_type == START) {
@@ -465,7 +464,6 @@ Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_
465464
le_event.buffer_info->conflicts.push_back(open_buffer_info);
466465
}
467466
}
468-
// open_set.insert(le_event.buffer_info);
469467
if (open_set.find(le_event.buffer_info) == open_set.end()) {
470468
open_set[le_event.buffer_info] = 1;
471469
} else {
@@ -477,7 +475,6 @@ Map<BufferInfo, tir::Stmt> BufferInfoExtractor::operator()(const PrimFunc& main_
477475
} else {
478476
open_set[le_event.buffer_info] -= 1;
479477
}
480-
// open_set.erase(le_event.buffer_info);
481478
}
482479
}
483480
return this->buffer_info_map_;

tests/python/unittest/test_tir_usmp_algo.py

Lines changed: 208 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def get_allocate(stmt):
5151
return allocates
5252

5353

54-
def assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos):
54+
def _assign_poolinfos_to_allocates_in_primfunc(primfunc, pool_infos):
5555
"""helper to assing poolinfos to allocate nodes in a tir.PrimFunc"""
5656

5757
def set_poolinfos(stmt):
@@ -68,12 +68,12 @@ def set_poolinfos(stmt):
6868
return primfunc.with_body(stmt_functor.ir_transform(primfunc.body, None, set_poolinfos))
6969

7070

71-
def assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos):
71+
def _assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos):
7272
"""helper to assing poolinfos to allocate nodes in a IRModule"""
7373
ret = tvm.IRModule()
7474
for global_var, basefunc in mod.functions.items():
7575
if isinstance(basefunc, tvm.tir.PrimFunc):
76-
ret[global_var] = assign_poolinfos_to_allocates_in_primfunc(basefunc, pool_infos)
76+
ret[global_var] = _assign_poolinfos_to_allocates_in_primfunc(basefunc, pool_infos)
7777
return ret
7878

7979

@@ -96,9 +96,204 @@ def _check_max_workspace_size(buffer_pool_allocations, pool_info, size):
9696
assert max_workspace_size == size
9797

9898

99+
def test_no_pool_error():
100+
target = Target("c")
101+
tiny_workspace_pool = usmp_utils.PoolInfo(
102+
pool_name="tiny_workspace",
103+
target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
104+
size_hint_bytes=10,
105+
)
106+
bi_a = usmp_utils.BufferInfo(
107+
name_hint="bi_a", size_bytes=10, pool_candidates=[tiny_workspace_pool]
108+
)
109+
bi_b = usmp_utils.BufferInfo(
110+
name_hint="bi_b", size_bytes=10, pool_candidates=[tiny_workspace_pool]
111+
)
112+
bi_c = usmp_utils.BufferInfo(
113+
name_hint="bi_c", size_bytes=10, pool_candidates=[tiny_workspace_pool]
114+
)
115+
bi_a.set_conflicts([bi_b])
116+
bi_b.set_conflicts([bi_c])
117+
bi_c.set_conflicts([bi_a])
118+
buffer_info_arr = [bi_a, bi_b, bi_c]
119+
fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.greedy_by_size")
120+
with pytest.raises(
121+
tvm.TVMError, match="TVM USMP Error: no candidate have been selected for BufferInfoNode"
122+
):
123+
buffer_pool_allocations = fusmp_algo(buffer_info_arr)
124+
125+
126+
@pytest.mark.parametrize("algorithm", ["greedy_by_size", "greedy_by_conflicts"])
127+
def test_name_based_ordering(algorithm):
128+
""" This checks when the size and conlicts are same a stable result is generated"""
129+
130+
def _test():
131+
target = Target("c")
132+
global_workspace_pool = usmp_utils.PoolInfo(
133+
pool_name="global_workspace",
134+
target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
135+
)
136+
bi_a = usmp_utils.BufferInfo(
137+
name_hint="bi_a", size_bytes=10, pool_candidates=[global_workspace_pool]
138+
)
139+
bi_b = usmp_utils.BufferInfo(
140+
name_hint="bi_b", size_bytes=10, pool_candidates=[global_workspace_pool]
141+
)
142+
bi_c = usmp_utils.BufferInfo(
143+
name_hint="bi_c", size_bytes=10, pool_candidates=[global_workspace_pool]
144+
)
145+
bi_a.set_conflicts([bi_b])
146+
bi_b.set_conflicts([bi_c])
147+
bi_c.set_conflicts([bi_a])
148+
149+
buffer_info_arr = [bi_a, bi_b, bi_c]
150+
fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
151+
buffer_pool_allocations = fusmp_algo(buffer_info_arr)
152+
assert buffer_pool_allocations[bi_a].byte_offset == 0
153+
assert buffer_pool_allocations[bi_b].byte_offset == 20
154+
assert buffer_pool_allocations[bi_c].byte_offset == 10
155+
156+
# This is tested for several times to check stability
157+
for x in range(0, 10):
158+
_test()
159+
160+
161+
@pytest.mark.parametrize(
162+
["algorithm", "workspace_size"],
163+
[("greedy_by_size", 140), ("greedy_by_conflicts", 140)],
164+
)
165+
def test_linear(algorithm, workspace_size):
166+
"""
167+
The test case here represent BufferInfo objects
168+
that could get generated for a linear sequence
169+
such as :
170+
(Op A)
171+
|
172+
bi_a
173+
|
174+
(Op B)
175+
|
176+
bi_b
177+
|
178+
.
179+
.
180+
.
181+
(Op F)
182+
|
183+
bi_f
184+
"""
185+
target = Target("c")
186+
global_workspace_pool = usmp_utils.PoolInfo(
187+
pool_name="global_workspace",
188+
target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
189+
)
190+
bi_a = usmp_utils.BufferInfo(
191+
name_hint="bi_a", size_bytes=10, pool_candidates=[global_workspace_pool]
192+
)
193+
bi_b = usmp_utils.BufferInfo(
194+
name_hint="bi_b", size_bytes=20, pool_candidates=[global_workspace_pool]
195+
)
196+
bi_c = usmp_utils.BufferInfo(
197+
name_hint="bi_c", size_bytes=100, pool_candidates=[global_workspace_pool]
198+
)
199+
bi_d = usmp_utils.BufferInfo(
200+
name_hint="bi_d", size_bytes=40, pool_candidates=[global_workspace_pool]
201+
)
202+
bi_e = usmp_utils.BufferInfo(
203+
name_hint="bi_e", size_bytes=50, pool_candidates=[global_workspace_pool]
204+
)
205+
bi_f = usmp_utils.BufferInfo(
206+
name_hint="bi_f", size_bytes=50, pool_candidates=[global_workspace_pool]
207+
)
208+
209+
# Creating conflicts for a linear graph
210+
bi_a.set_conflicts([bi_b])
211+
bi_b.set_conflicts([bi_a, bi_c])
212+
bi_c.set_conflicts([bi_b, bi_d])
213+
bi_d.set_conflicts([bi_c, bi_e])
214+
bi_e.set_conflicts([bi_d, bi_f])
215+
bi_f.set_conflicts([bi_e])
216+
217+
buffer_info_arr = [bi_a, bi_b, bi_c, bi_d, bi_e, bi_f]
218+
fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
219+
buffer_pool_allocations = fusmp_algo(buffer_info_arr)
220+
_check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, workspace_size)
221+
222+
223+
@pytest.mark.parametrize(
224+
["algorithm", "workspace_size"],
225+
[("greedy_by_size", 190), ("greedy_by_conflicts", 320)],
226+
)
227+
def test_fanout(algorithm, workspace_size):
228+
"""
229+
The test case here represent BufferInfo objects
230+
that could get generated for a fanout topology
231+
such as :
232+
(Op A)
233+
|
234+
bi_a ---------
235+
| |
236+
(Op B) (Op C)
237+
| |
238+
bi_b bi_c
239+
| |
240+
(Op D) (Op E)
241+
| |
242+
bi_d bi_e
243+
| |
244+
(Op F) ------
245+
|
246+
bi_f
247+
|
248+
(Op G)
249+
|
250+
bi_g
251+
"""
252+
target = Target("c")
253+
global_workspace_pool = usmp_utils.PoolInfo(
254+
pool_name="global_workspace",
255+
target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
256+
)
257+
bi_a = usmp_utils.BufferInfo(
258+
name_hint="bi_a", size_bytes=10, pool_candidates=[global_workspace_pool]
259+
)
260+
bi_b = usmp_utils.BufferInfo(
261+
name_hint="bi_b", size_bytes=20, pool_candidates=[global_workspace_pool]
262+
)
263+
bi_c = usmp_utils.BufferInfo(
264+
name_hint="bi_c", size_bytes=100, pool_candidates=[global_workspace_pool]
265+
)
266+
bi_d = usmp_utils.BufferInfo(
267+
name_hint="bi_d", size_bytes=40, pool_candidates=[global_workspace_pool]
268+
)
269+
bi_e = usmp_utils.BufferInfo(
270+
name_hint="bi_e", size_bytes=50, pool_candidates=[global_workspace_pool]
271+
)
272+
bi_f = usmp_utils.BufferInfo(
273+
name_hint="bi_f", size_bytes=60, pool_candidates=[global_workspace_pool]
274+
)
275+
bi_g = usmp_utils.BufferInfo(
276+
name_hint="bi_g", size_bytes=70, pool_candidates=[global_workspace_pool]
277+
)
278+
279+
# Creating conflicts for a linear graph
280+
bi_a.set_conflicts([bi_b, bi_c])
281+
bi_b.set_conflicts([bi_a, bi_c, bi_e])
282+
bi_c.set_conflicts([bi_e, bi_a, bi_b, bi_d])
283+
bi_d.set_conflicts([bi_b, bi_f, bi_c, bi_e])
284+
bi_e.set_conflicts([bi_c, bi_f, bi_b, bi_d])
285+
bi_f.set_conflicts([bi_d, bi_e, bi_f])
286+
bi_g.set_conflicts([bi_f])
287+
288+
buffer_info_arr = [bi_a, bi_b, bi_c, bi_d, bi_e, bi_f, bi_g]
289+
fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
290+
buffer_pool_allocations = fusmp_algo(buffer_info_arr)
291+
_check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, workspace_size)
292+
293+
99294
# fmt: off
100295
@tvm.script.ir_module
101-
class LinearStructure:
296+
class MobilenetStructure:
102297
@T.prim_func
103298
def tvmgen_default_fused_cast_subtract(placeholder_2: T.handle, placeholder_3: T.handle, T_subtract: T.handle) -> None:
104299
# function attr dict
@@ -167,22 +362,11 @@ def run_model(input: T.handle, output: T.handle) -> None:
167362
# fmt: on
168363

169364

170-
def print_conflicts(buffer_info_map):
171-
"""_verify_conflicts("sid_8", ["Conv2dOutput_7", "tensor_2"], buffer_info_map)"""
172-
173-
for buffer_info_name, buf_info in buffer_info_map.items():
174-
conflict_str = "["
175-
for conflict in buf_info.conflicts:
176-
conflict_str += f'"{conflict.name_hint}", '
177-
conflict_str += "]"
178-
print(f'_verify_conflicts("{buffer_info_name}", {conflict_str}, buffer_info_map_names)')
179-
180-
181365
@pytest.mark.parametrize(
182366
["algorithm", "fast_memory_size", "slow_memory_size"],
183367
[("greedy_by_size", 200704, 1418528), ("greedy_by_conflicts", 200704, 1418528)],
184368
)
185-
def test_linear(algorithm, fast_memory_size, slow_memory_size):
369+
def test_mobilenet_subgraph(algorithm, fast_memory_size, slow_memory_size):
186370
target = Target("c")
187371
fast_memory_pool = usmp_utils.PoolInfo(
188372
pool_name="fast_memory",
@@ -192,18 +376,18 @@ def test_linear(algorithm, fast_memory_size, slow_memory_size):
192376
slow_memory_pool = usmp_utils.PoolInfo(
193377
pool_name="slow_memory", target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}
194378
)
195-
tir_mod = LinearStructure
379+
tir_mod = MobilenetStructure
196380
tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
197-
tir_mod = assign_poolinfos_to_allocates_in_irmodule(
381+
tir_mod = _assign_poolinfos_to_allocates_in_irmodule(
198382
tir_mod, [fast_memory_pool, slow_memory_pool]
199383
)
200384
main_func = tir_mod["run_model"]
201385
buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
202386

203387
fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
204388
buffer_info_arr = fcreate_array_bi(buffer_info_map)
205-
fusmp_algo_greedy_by_size = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
206-
buffer_pool_allocations = fusmp_algo_greedy_by_size(buffer_info_arr)
389+
fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
390+
buffer_pool_allocations = fusmp_algo(buffer_info_arr)
207391

208392
buffer_info_map_names = dict()
209393
for buf_info in buffer_info_arr:
@@ -346,22 +530,22 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place
346530
@pytest.mark.parametrize(
347531
["algorithm", "workspace_size"], [("greedy_by_size", 7920256), ("greedy_by_conflicts", 7200256)]
348532
)
349-
def test_fanout(algorithm, workspace_size):
533+
def test_resnet_subgraph(algorithm, workspace_size):
350534
target = Target("c")
351535
global_workspace_pool = usmp_utils.PoolInfo(
352536
pool_name="global_workspace",
353537
target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
354538
)
355539
tir_mod = ResnetStructure
356540
tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
357-
tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool])
541+
tir_mod = _assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool])
358542
main_func = tir_mod["tvmgen_default_run_model"]
359543
buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
360544

361545
fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
362546
buffer_info_arr = fcreate_array_bi(buffer_info_map)
363-
fusmp_algo_greedy_by_size = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
364-
buffer_pool_allocations = fusmp_algo_greedy_by_size(buffer_info_arr)
547+
fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
548+
buffer_pool_allocations = fusmp_algo(buffer_info_arr)
365549

366550
buffer_info_map_names = dict()
367551
for buf_info in buffer_info_arr:

0 commit comments

Comments
 (0)