Skip to content

Commit 313e1ba

Browse files
committed
[TIR][USMP] greedy_by_size usmp algo
* Adding targets to the PrimFuncs in the tests

Change-Id: Ic91947e23cbcc4fc0020eb62f0ed9df26cf919f9
1 parent e7d4fee commit 313e1ba

File tree

2 files changed

+119
-31
lines changed

2 files changed

+119
-31
lines changed

src/tir/usmp/algo/greedy_by_size.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,13 @@ namespace tir {
3434
namespace usmp {
3535
namespace algo {
3636

37-
size_t round_up_to_byte_alignment(const size_t& non_aligned_byte_offset,
38-
const int& byte_alignment) {
37+
/*!
 * \brief Round a byte offset up to the nearest multiple of an alignment.
 * \param non_aligned_byte_offset The offset to be aligned, in bytes.
 * \param byte_alignment The requested alignment, in bytes.
 * \return The smallest multiple of byte_alignment that is greater than or equal to
 *         non_aligned_byte_offset. When byte_alignment is 1 or less the offset is
 *         returned unchanged.
 */
static size_t round_up_to_byte_alignment(const size_t& non_aligned_byte_offset,
                                         const int& byte_alignment) {
  // A zero alignment would divide by zero, and a negative one would be converted to a
  // huge unsigned value in the arithmetic below. Such values are treated as
  // "no alignment requested" and the offset is handed back as is.
  if (byte_alignment <= 1) {
    return non_aligned_byte_offset;
  }
  const size_t alignment = static_cast<size_t>(byte_alignment);
  const size_t remainder = non_aligned_byte_offset % alignment;
  // Add only the missing padding instead of computing (offset + alignment - 1), so the
  // intermediate value cannot wrap around when the aligned result itself is representable.
  return remainder == 0 ? non_aligned_byte_offset
                        : non_aligned_byte_offset + (alignment - remainder);
}
4141

42-
bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset,
43-
const size_t& size_bytes) {
42+
static bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset,
43+
const size_t& size_bytes) {
4444
if (candidate_pool->size_hint_bytes == -1) {
4545
// this means pool is not bounded
4646
return true;
@@ -53,7 +53,7 @@ bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset,
5353
return false;
5454
}
5555

56-
PoolInfo SelectPlacementPool(
56+
static PoolInfo SelectPlacementPool(
5757
const Array<PoolInfo>& pool_candidates,
5858
const std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual>& pool_offsets) {
5959
for (const auto& pool_info : pool_candidates) {

tests/python/unittest/test_tir_usmp_algo_greedy_by_size.py

Lines changed: 114 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,15 @@ def assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos):
7777
return ret
7878

7979

80+
def _assign_targets_to_primfuncs_irmodule(mod, target):
    """Return a new IRModule whose PrimFuncs carry `target` as their "target" attribute.

    Only entries of `mod.functions` that are `tvm.tir.PrimFunc` instances are copied
    (with the attribute attached); any other function is left out of the result.
    """
    annotated_mod = tvm.IRModule()
    for gvar, func in mod.functions.items():
        if not isinstance(func, tvm.tir.PrimFunc):
            continue
        annotated_mod[gvar] = func.with_attr("target", target)
    return annotated_mod
87+
88+
8089
def _check_max_workspace_size(buffer_pool_allocations, pool_info, size):
8190
max_workspace_size = 0
8291
for buffer_info, pool_allocation in buffer_pool_allocations.items():
@@ -143,7 +152,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6:
143152
T.store(T_cast_7.data, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3), T.cast(T.load("uint8", tensor_2, (((ax0_ax1_fused_5*3584) + (ax2_5*64)) + ax3_3)), "int16"), True)
144153

145154
@T.prim_func
146-
def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None:
155+
def run_model(input: T.handle, output: T.handle) -> None:
147156
# function attr dict
148157
T.func_attr({"global_symbol": "tvmgen_default_run_model", "runner_function": True})
149158
# body
@@ -159,19 +168,21 @@ def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None:
159168

160169

161170
def test_linear():
171+
target = Target("c")
162172
fast_memory_pool = usmp_utils.PoolInfo(
163173
pool_name="fast_memory",
164-
target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS},
174+
target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
165175
size_hint_bytes=200704,
166176
)
167177
slow_memory_pool = usmp_utils.PoolInfo(
168-
pool_name="slow_memory", target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS}
178+
pool_name="slow_memory", target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS}
169179
)
170180
tir_mod = LinearStructure
181+
tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
171182
tir_mod = assign_poolinfos_to_allocates_in_irmodule(
172183
tir_mod, [fast_memory_pool, slow_memory_pool]
173184
)
174-
main_func = tir_mod["tvmgen_default_run_model"]
185+
main_func = tir_mod["run_model"]
175186
buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
176187

177188
fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo")
@@ -184,13 +195,15 @@ def test_linear():
184195
buffer_info_map_names[buf_info.name_hint] = buf_info
185196

186197
# check conflicts
187-
_verify_conflicts("sid_8", ["Conv2dOutput_7", "tensor_2"], buffer_info_map_names)
188-
_verify_conflicts("Conv2dOutput_7", ["PaddedInput_7", "sid_8"], buffer_info_map_names)
189-
_verify_conflicts("PaddedInput_7", ["sid_9", "Conv2dOutput_7"], buffer_info_map_names)
198+
_verify_conflicts("PaddedInput_7", ["sid_9", "sid_8", "Conv2dOutput_7"], buffer_info_map_names)
190199
_verify_conflicts("tensor_2", ["sid_8"], buffer_info_map_names)
191200
_verify_conflicts("sid_9", ["PaddedInput_7"], buffer_info_map_names)
201+
_verify_conflicts(
202+
"sid_8", ["PaddedInput_7", "Conv2dOutput_7", "tensor_2"], buffer_info_map_names
203+
)
204+
_verify_conflicts("Conv2dOutput_7", ["sid_8", "PaddedInput_7"], buffer_info_map_names)
192205

193-
_check_max_workspace_size(buffer_pool_allocations, slow_memory_pool, 802816)
206+
_check_max_workspace_size(buffer_pool_allocations, slow_memory_pool, 1418528)
194207
_check_max_workspace_size(buffer_pool_allocations, fast_memory_pool, 200704)
195208

196209

@@ -316,11 +329,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place
316329

317330

318331
def test_fanout():
332+
target = Target("c")
319333
global_workspace_pool = usmp_utils.PoolInfo(
320334
pool_name="global_workspace",
321-
target_access={Target("c"): usmp_utils.PoolInfo.READ_WRITE_ACCESS},
335+
target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
322336
)
323337
tir_mod = ResnetStructure
338+
tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target)
324339
tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool])
325340
main_func = tir_mod["tvmgen_default_run_model"]
326341
buffer_info_map = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod)
@@ -336,35 +351,108 @@ def test_fanout():
336351

337352
# check conflicts
338353
_verify_conflicts(
339-
"sid_6",
340-
["Conv2dOutput_2", "sid_2", "PaddedInput_3", "Conv2dOutput_3"],
354+
"Conv2dOutput_1",
355+
[
356+
"PaddedInput_1",
357+
"sid_7",
358+
],
341359
buffer_info_map_names,
342360
)
343-
_verify_conflicts("PaddedInput_1", ["sid_8", "sid_2", "Conv2dOutput_1"], buffer_info_map_names)
344-
_verify_conflicts("PaddedInput_2", ["sid_7", "sid_2", "Conv2dOutput_2"], buffer_info_map_names)
345-
_verify_conflicts("sid_8", ["Conv2dOutput", "sid_2", "PaddedInput_1"], buffer_info_map_names)
346361
_verify_conflicts(
347-
"sid_2",
362+
"sid_8",
348363
[
349364
"PaddedInput",
350365
"Conv2dOutput",
351-
"sid_8",
352366
"PaddedInput_1",
353-
"Conv2dOutput_1",
367+
],
368+
buffer_info_map_names,
369+
)
370+
_verify_conflicts(
371+
"PaddedInput_2",
372+
[
354373
"sid_7",
374+
"sid_6",
375+
"Conv2dOutput_2",
376+
],
377+
buffer_info_map_names,
378+
)
379+
_verify_conflicts(
380+
"sid_2",
381+
[
382+
"PaddedInput",
383+
"PaddedInput_3",
384+
],
385+
buffer_info_map_names,
386+
)
387+
_verify_conflicts(
388+
"Conv2dOutput",
389+
[
390+
"sid_8",
391+
"PaddedInput",
392+
],
393+
buffer_info_map_names,
394+
)
395+
_verify_conflicts(
396+
"sid_7",
397+
[
398+
"Conv2dOutput_1",
399+
"PaddedInput_1",
400+
"PaddedInput_2",
401+
],
402+
buffer_info_map_names,
403+
)
404+
_verify_conflicts(
405+
"sid_6",
406+
[
355407
"PaddedInput_2",
356408
"Conv2dOutput_2",
409+
"Conv2dOutput_3",
410+
"PaddedInput_3",
411+
],
412+
buffer_info_map_names,
413+
)
414+
_verify_conflicts(
415+
"PaddedInput_3",
416+
[
417+
"sid_2",
418+
"Conv2dOutput_3",
357419
"sid_6",
420+
],
421+
buffer_info_map_names,
422+
)
423+
_verify_conflicts(
424+
"Conv2dOutput_3",
425+
[
358426
"PaddedInput_3",
427+
"sid_6",
428+
],
429+
buffer_info_map_names,
430+
)
431+
_verify_conflicts(
432+
"PaddedInput",
433+
[
434+
"sid_2",
435+
"sid_8",
436+
"Conv2dOutput",
437+
],
438+
buffer_info_map_names,
439+
)
440+
_verify_conflicts(
441+
"Conv2dOutput_2",
442+
[
443+
"sid_6",
444+
"PaddedInput_2",
359445
],
360446
buffer_info_map_names,
361447
)
362-
_verify_conflicts("PaddedInput", ["sid_2", "Conv2dOutput"], buffer_info_map_names)
363-
_verify_conflicts("sid_7", ["Conv2dOutput_1", "sid_2", "PaddedInput_2"], buffer_info_map_names)
364-
_verify_conflicts("PaddedInput_3", ["sid_6", "sid_2", "Conv2dOutput_3"], buffer_info_map_names)
365-
_verify_conflicts("Conv2dOutput_3", ["PaddedInput_3", "sid_6"], buffer_info_map_names)
366-
_verify_conflicts("Conv2dOutput", ["PaddedInput", "sid_2", "sid_8"], buffer_info_map_names)
367-
_verify_conflicts("Conv2dOutput_1", ["PaddedInput_1", "sid_2", "sid_7"], buffer_info_map_names)
368-
_verify_conflicts("Conv2dOutput_2", ["PaddedInput_2", "sid_2", "sid_6"], buffer_info_map_names)
369-
370-
_check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, 7920256)
448+
_verify_conflicts(
449+
"PaddedInput_1",
450+
[
451+
"sid_8",
452+
"Conv2dOutput_1",
453+
"sid_7",
454+
],
455+
buffer_info_map_names,
456+
)
457+
458+
_check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, 7200000)

0 commit comments

Comments
 (0)