@@ -51,7 +51,7 @@ def get_allocate(stmt):
5151 return allocates
5252
5353
54- def assign_poolinfos_to_allocates_in_primfunc (primfunc , pool_infos ):
54+ def _assign_poolinfos_to_allocates_in_primfunc (primfunc , pool_infos ):
5555 """helper to assign poolinfos to allocate nodes in a tir.PrimFunc"""
5656
5757 def set_poolinfos (stmt ):
@@ -68,12 +68,12 @@ def set_poolinfos(stmt):
6868 return primfunc .with_body (stmt_functor .ir_transform (primfunc .body , None , set_poolinfos ))
6969
7070
def _assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos):
    """Helper to assign poolinfos to allocate nodes in an IRModule.

    Builds a fresh ``tvm.IRModule`` that holds only the ``tvm.tir.PrimFunc``
    entries of ``mod``, each one rewritten by
    ``_assign_poolinfos_to_allocates_in_primfunc``.  Functions of any other
    kind are not copied into the returned module.
    """
    updated_mod = tvm.IRModule()
    for global_var, basefunc in mod.functions.items():
        if not isinstance(basefunc, tvm.tir.PrimFunc):
            continue
        updated_mod[global_var] = _assign_poolinfos_to_allocates_in_primfunc(
            basefunc, pool_infos
        )
    return updated_mod
7878
7979
@@ -96,9 +96,204 @@ def _check_max_workspace_size(buffer_pool_allocations, pool_info, size):
9696 assert max_workspace_size == size
9797
9898
def test_no_pool_error():
    """Check that the allocator raises when no pool candidate can be selected.

    Three 10-byte buffers share a single candidate pool whose
    ``size_hint_bytes`` is 10 and are chained by conflicts
    (a -> b, b -> c, c -> a).  Running ``greedy_by_size`` on them is expected
    to raise ``tvm.TVMError`` carrying the "no candidate have been selected"
    message.
    """
    target = Target("c")
    tiny_workspace_pool = usmp_utils.PoolInfo(
        pool_name="tiny_workspace",
        target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
        size_hint_bytes=10,
    )
    bi_a = usmp_utils.BufferInfo(
        name_hint="bi_a", size_bytes=10, pool_candidates=[tiny_workspace_pool]
    )
    bi_b = usmp_utils.BufferInfo(
        name_hint="bi_b", size_bytes=10, pool_candidates=[tiny_workspace_pool]
    )
    bi_c = usmp_utils.BufferInfo(
        name_hint="bi_c", size_bytes=10, pool_candidates=[tiny_workspace_pool]
    )
    bi_a.set_conflicts([bi_b])
    bi_b.set_conflicts([bi_c])
    bi_c.set_conflicts([bi_a])
    buffer_info_arr = [bi_a, bi_b, bi_c]
    # The algorithm name is fixed here, so a plain string is enough.
    fusmp_algo = tvm.get_global_func("tir.usmp.algo.greedy_by_size")
    with pytest.raises(
        tvm.TVMError, match="TVM USMP Error: no candidate have been selected for BufferInfoNode"
    ):
        # Only the raised error matters; the return value is deliberately dropped.
        fusmp_algo(buffer_info_arr)
124+
125+
@pytest.mark.parametrize("algorithm", ["greedy_by_size", "greedy_by_conflicts"])
def test_name_based_ordering(algorithm):
    """Check that a stable result is produced when sizes and conflicts are the same.

    Three 10-byte buffers with one conflict each are planned into a single
    pool; the expected byte offsets are ``bi_a == 0``, ``bi_b == 20`` and
    ``bi_c == 10``.  The planning is repeated to catch run-to-run variation.
    """

    def _test():
        target = Target("c")
        global_workspace_pool = usmp_utils.PoolInfo(
            pool_name="global_workspace",
            target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
        )
        bi_a = usmp_utils.BufferInfo(
            name_hint="bi_a", size_bytes=10, pool_candidates=[global_workspace_pool]
        )
        bi_b = usmp_utils.BufferInfo(
            name_hint="bi_b", size_bytes=10, pool_candidates=[global_workspace_pool]
        )
        bi_c = usmp_utils.BufferInfo(
            name_hint="bi_c", size_bytes=10, pool_candidates=[global_workspace_pool]
        )
        bi_a.set_conflicts([bi_b])
        bi_b.set_conflicts([bi_c])
        bi_c.set_conflicts([bi_a])

        buffer_info_arr = [bi_a, bi_b, bi_c]
        fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
        buffer_pool_allocations = fusmp_algo(buffer_info_arr)
        assert buffer_pool_allocations[bi_a].byte_offset == 0
        assert buffer_pool_allocations[bi_b].byte_offset == 20
        assert buffer_pool_allocations[bi_c].byte_offset == 10

    # Repeated several times to check that the result is stable.
    for _ in range(10):
        _test()
159+
160+
@pytest.mark.parametrize(
    ["algorithm", "workspace_size"],
    [("greedy_by_size", 140), ("greedy_by_conflicts", 140)],
)
def test_linear(algorithm, workspace_size):
    """
    Plan BufferInfo objects shaped like the ones a linear
    sequence of operators could produce, such as :
    (Op A)
    |
    bi_a
    |
    (Op B)
    |
    bi_b
    |
    .
    .
    .
    (Op F)
    |
    bi_f
    """
    target = Target("c")
    global_workspace_pool = usmp_utils.PoolInfo(
        pool_name="global_workspace",
        target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
    )

    # (name_hint, size_bytes) of every buffer, in chain order.
    buffer_specs = [
        ("bi_a", 10),
        ("bi_b", 20),
        ("bi_c", 100),
        ("bi_d", 40),
        ("bi_e", 50),
        ("bi_f", 50),
    ]
    buffer_info_arr = [
        usmp_utils.BufferInfo(
            name_hint=name, size_bytes=size_bytes, pool_candidates=[global_workspace_pool]
        )
        for name, size_bytes in buffer_specs
    ]

    # In a linear graph every buffer conflicts only with its immediate
    # neighbours: the predecessor (if any) first, then the successor (if any).
    for index, buffer_info in enumerate(buffer_info_arr):
        predecessor = buffer_info_arr[max(index - 1, 0) : index]
        successor = buffer_info_arr[index + 1 : index + 2]
        buffer_info.set_conflicts(predecessor + successor)

    fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
    buffer_pool_allocations = fusmp_algo(buffer_info_arr)
    _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, workspace_size)
221+
222+
@pytest.mark.parametrize(
    ["algorithm", "workspace_size"],
    [("greedy_by_size", 190), ("greedy_by_conflicts", 320)],
)
def test_fanout(algorithm, workspace_size):
    """
    Plan BufferInfo objects shaped like the ones a fanout
    topology could produce, such as :
    (Op A)
    |
    bi_a ---------
    |            |
    (Op B)     (Op C)
    |            |
    bi_b        bi_c
    |            |
    (Op D)     (Op E)
    |            |
    bi_d        bi_e
    |            |
    (Op F) ------
    |
    bi_f
    |
    (Op G)
    |
    bi_g
    """
    target = Target("c")
    global_workspace_pool = usmp_utils.PoolInfo(
        pool_name="global_workspace",
        target_access={target: usmp_utils.PoolInfo.READ_WRITE_ACCESS},
    )

    # name_hint -> size_bytes; insertion order is the order handed to the algorithm.
    sizes_by_name = {
        "bi_a": 10,
        "bi_b": 20,
        "bi_c": 100,
        "bi_d": 40,
        "bi_e": 50,
        "bi_f": 60,
        "bi_g": 70,
    }
    buffers = {
        name: usmp_utils.BufferInfo(
            name_hint=name, size_bytes=size_bytes, pool_candidates=[global_workspace_pool]
        )
        for name, size_bytes in sizes_by_name.items()
    }

    # Conflicts for the fanout graph, listed per buffer in the order they are set.
    conflicts_by_name = {
        "bi_a": ["bi_b", "bi_c"],
        "bi_b": ["bi_a", "bi_c", "bi_e"],
        "bi_c": ["bi_e", "bi_a", "bi_b", "bi_d"],
        "bi_d": ["bi_b", "bi_f", "bi_c", "bi_e"],
        "bi_e": ["bi_c", "bi_f", "bi_b", "bi_d"],
        # NOTE(review): bi_f lists itself and not bi_g, although bi_g lists
        # bi_f below. Kept unchanged because the expected workspace sizes
        # were obtained with exactly this input - confirm whether intended.
        "bi_f": ["bi_d", "bi_e", "bi_f"],
        "bi_g": ["bi_f"],
    }
    for name, conflicting_names in conflicts_by_name.items():
        buffers[name].set_conflicts([buffers[other] for other in conflicting_names])

    buffer_info_arr = list(buffers.values())
    fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
    buffer_pool_allocations = fusmp_algo(buffer_info_arr)
    _check_max_workspace_size(buffer_pool_allocations, global_workspace_pool, workspace_size)
292+
293+
99294# fmt: off
100295@tvm .script .ir_module
101- class LinearStructure :
296+ class MobilenetStructure :
102297 @T .prim_func
103298 def tvmgen_default_fused_cast_subtract (placeholder_2 : T .handle , placeholder_3 : T .handle , T_subtract : T .handle ) -> None :
104299 # function attr dict
@@ -167,22 +362,11 @@ def run_model(input: T.handle, output: T.handle) -> None:
167362# fmt: on
168363
169364
def print_conflicts(buffer_info_map):
    """Print one ready-to-paste ``_verify_conflicts`` call per buffer.

    For every ``name -> buffer info`` entry a line of this form is printed::

        _verify_conflicts("sid_8", ["Conv2dOutput_7", "tensor_2", ], buffer_info_map_names)

    Parameters
    ----------
    buffer_info_map : mapping
        Maps a buffer name to an object whose ``conflicts`` attribute is an
        iterable of objects exposing ``name_hint``.
    """
    for buffer_info_name, buf_info in buffer_info_map.items():
        # Every entry keeps its trailing ", " so the printed text stays
        # identical to what the incremental string building produced.
        quoted_names = "".join(f'"{conflict.name_hint}", ' for conflict in buf_info.conflicts)
        print(f'_verify_conflicts("{buffer_info_name}", [{quoted_names}], buffer_info_map_names)')
179-
180-
181365@pytest .mark .parametrize (
182366 ["algorithm" , "fast_memory_size" , "slow_memory_size" ],
183367 [("greedy_by_size" , 200704 , 1418528 ), ("greedy_by_conflicts" , 200704 , 1418528 )],
184368)
185- def test_linear (algorithm , fast_memory_size , slow_memory_size ):
369+ def test_mobilenet_subgraph (algorithm , fast_memory_size , slow_memory_size ):
186370 target = Target ("c" )
187371 fast_memory_pool = usmp_utils .PoolInfo (
188372 pool_name = "fast_memory" ,
@@ -192,18 +376,18 @@ def test_linear(algorithm, fast_memory_size, slow_memory_size):
192376 slow_memory_pool = usmp_utils .PoolInfo (
193377 pool_name = "slow_memory" , target_access = {target : usmp_utils .PoolInfo .READ_WRITE_ACCESS }
194378 )
195- tir_mod = LinearStructure
379+ tir_mod = MobilenetStructure
196380 tir_mod = _assign_targets_to_primfuncs_irmodule (tir_mod , target )
197- tir_mod = assign_poolinfos_to_allocates_in_irmodule (
381+ tir_mod = _assign_poolinfos_to_allocates_in_irmodule (
198382 tir_mod , [fast_memory_pool , slow_memory_pool ]
199383 )
200384 main_func = tir_mod ["run_model" ]
201385 buffer_info_map = tvm .tir .usmp .analysis .extract_buffer_info (main_func , tir_mod )
202386
203387 fcreate_array_bi = tvm .get_global_func ("tir.usmp.CreateArrayBufferInfo" )
204388 buffer_info_arr = fcreate_array_bi (buffer_info_map )
205- fusmp_algo_greedy_by_size = tvm .get_global_func (f"tir.usmp.algo.{ algorithm } " )
206- buffer_pool_allocations = fusmp_algo_greedy_by_size (buffer_info_arr )
389+ fusmp_algo = tvm .get_global_func (f"tir.usmp.algo.{ algorithm } " )
390+ buffer_pool_allocations = fusmp_algo (buffer_info_arr )
207391
208392 buffer_info_map_names = dict ()
209393 for buf_info in buffer_info_arr :
@@ -346,22 +530,22 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place
346530@pytest .mark .parametrize (
347531 ["algorithm" , "workspace_size" ], [("greedy_by_size" , 7920256 ), ("greedy_by_conflicts" , 7200256 )]
348532)
349- def test_fanout (algorithm , workspace_size ):
533+ def test_resnet_subgraph (algorithm , workspace_size ):
350534 target = Target ("c" )
351535 global_workspace_pool = usmp_utils .PoolInfo (
352536 pool_name = "global_workspace" ,
353537 target_access = {target : usmp_utils .PoolInfo .READ_WRITE_ACCESS },
354538 )
355539 tir_mod = ResnetStructure
356540 tir_mod = _assign_targets_to_primfuncs_irmodule (tir_mod , target )
357- tir_mod = assign_poolinfos_to_allocates_in_irmodule (tir_mod , [global_workspace_pool ])
541+ tir_mod = _assign_poolinfos_to_allocates_in_irmodule (tir_mod , [global_workspace_pool ])
358542 main_func = tir_mod ["tvmgen_default_run_model" ]
359543 buffer_info_map = tvm .tir .usmp .analysis .extract_buffer_info (main_func , tir_mod )
360544
361545 fcreate_array_bi = tvm .get_global_func ("tir.usmp.CreateArrayBufferInfo" )
362546 buffer_info_arr = fcreate_array_bi (buffer_info_map )
363- fusmp_algo_greedy_by_size = tvm .get_global_func (f"tir.usmp.algo.{ algorithm } " )
364- buffer_pool_allocations = fusmp_algo_greedy_by_size (buffer_info_arr )
547+ fusmp_algo = tvm .get_global_func (f"tir.usmp.algo.{ algorithm } " )
548+ buffer_pool_allocations = fusmp_algo (buffer_info_arr )
365549
366550 buffer_info_map_names = dict ()
367551 for buf_info in buffer_info_arr :
0 commit comments