@@ -77,6 +77,15 @@ def assign_poolinfos_to_allocates_in_irmodule(mod, pool_infos):
7777 return ret
7878
7979
def _assign_targets_to_primfuncs_irmodule(mod, target):
    """Return a new IRModule with ``target`` attached to every PrimFunc of ``mod``.

    Parameters
    ----------
    mod
        Module whose ``functions`` mapping is scanned.
    target
        Value stored under the ``"target"`` attribute of each PrimFunc
        (the tests in this file pass a ``Target`` instance).

    Returns
    -------
    tvm.IRModule
        A freshly created module. Entries of ``mod`` that are not
        ``tvm.tir.PrimFunc`` instances are not copied into it.
    """
    annotated_mod = tvm.IRModule()
    for gvar, func in mod.functions.items():
        if not isinstance(func, tvm.tir.PrimFunc):
            # Only PrimFuncs are carried over into the result.
            continue
        annotated_mod[gvar] = func.with_attr("target", target)
    return annotated_mod
87+
88+
8089def _check_max_workspace_size (buffer_pool_allocations , pool_info , size ):
8190 max_workspace_size = 0
8291 for buffer_info , pool_allocation in buffer_pool_allocations .items ():
@@ -143,7 +152,7 @@ def tvmgen_default_fused_nn_max_pool2d_cast(placeholder_28: T.handle, T_cast_6:
143152 T .store (T_cast_7 .data , (((ax0_ax1_fused_5 * 3584 ) + (ax2_5 * 64 )) + ax3_3 ), T .cast (T .load ("uint8" , tensor_2 , (((ax0_ax1_fused_5 * 3584 ) + (ax2_5 * 64 )) + ax3_3 )), "int16" ), True )
144153
145154 @T .prim_func
146- def tvmgen_default_run_model (input : T .handle , output : T .handle ) -> None :
155+ def run_model (input : T .handle , output : T .handle ) -> None :
147156 # function attr dict
148157 T .func_attr ({"global_symbol" : "tvmgen_default_run_model" , "runner_function" : True })
149158 # body
@@ -159,19 +168,21 @@ def tvmgen_default_run_model(input: T.handle, output: T.handle) -> None:
159168
160169
161170def test_linear ():
171+ target = Target ("c" )
162172 fast_memory_pool = usmp_utils .PoolInfo (
163173 pool_name = "fast_memory" ,
164- target_access = {Target ( "c" ) : usmp_utils .PoolInfo .READ_WRITE_ACCESS },
174+ target_access = {target : usmp_utils .PoolInfo .READ_WRITE_ACCESS },
165175 size_hint_bytes = 200704 ,
166176 )
167177 slow_memory_pool = usmp_utils .PoolInfo (
168- pool_name = "slow_memory" , target_access = {Target ( "c" ) : usmp_utils .PoolInfo .READ_WRITE_ACCESS }
178+ pool_name = "slow_memory" , target_access = {target : usmp_utils .PoolInfo .READ_WRITE_ACCESS }
169179 )
170180 tir_mod = LinearStructure
181+ tir_mod = _assign_targets_to_primfuncs_irmodule (tir_mod , target )
171182 tir_mod = assign_poolinfos_to_allocates_in_irmodule (
172183 tir_mod , [fast_memory_pool , slow_memory_pool ]
173184 )
174- main_func = tir_mod ["tvmgen_default_run_model " ]
185+ main_func = tir_mod ["run_model " ]
175186 buffer_info_map = tvm .tir .usmp .analysis .extract_buffer_info (main_func , tir_mod )
176187
177188 fcreate_array_bi = tvm .get_global_func ("tir.usmp.CreateArrayBufferInfo" )
@@ -184,13 +195,15 @@ def test_linear():
184195 buffer_info_map_names [buf_info .name_hint ] = buf_info
185196
186197 # check conflicts
187- _verify_conflicts ("sid_8" , ["Conv2dOutput_7" , "tensor_2" ], buffer_info_map_names )
188- _verify_conflicts ("Conv2dOutput_7" , ["PaddedInput_7" , "sid_8" ], buffer_info_map_names )
189- _verify_conflicts ("PaddedInput_7" , ["sid_9" , "Conv2dOutput_7" ], buffer_info_map_names )
198+ _verify_conflicts ("PaddedInput_7" , ["sid_9" , "sid_8" , "Conv2dOutput_7" ], buffer_info_map_names )
190199 _verify_conflicts ("tensor_2" , ["sid_8" ], buffer_info_map_names )
191200 _verify_conflicts ("sid_9" , ["PaddedInput_7" ], buffer_info_map_names )
201+ _verify_conflicts (
202+ "sid_8" , ["PaddedInput_7" , "Conv2dOutput_7" , "tensor_2" ], buffer_info_map_names
203+ )
204+ _verify_conflicts ("Conv2dOutput_7" , ["sid_8" , "PaddedInput_7" ], buffer_info_map_names )
192205
193- _check_max_workspace_size (buffer_pool_allocations , slow_memory_pool , 802816 )
206+ _check_max_workspace_size (buffer_pool_allocations , slow_memory_pool , 1418528 )
194207 _check_max_workspace_size (buffer_pool_allocations , fast_memory_pool , 200704 )
195208
196209
@@ -316,11 +329,13 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place
316329
317330
318331def test_fanout ():
332+ target = Target ("c" )
319333 global_workspace_pool = usmp_utils .PoolInfo (
320334 pool_name = "global_workspace" ,
321- target_access = {Target ( "c" ) : usmp_utils .PoolInfo .READ_WRITE_ACCESS },
335+ target_access = {target : usmp_utils .PoolInfo .READ_WRITE_ACCESS },
322336 )
323337 tir_mod = ResnetStructure
338+ tir_mod = _assign_targets_to_primfuncs_irmodule (tir_mod , target )
324339 tir_mod = assign_poolinfos_to_allocates_in_irmodule (tir_mod , [global_workspace_pool ])
325340 main_func = tir_mod ["tvmgen_default_run_model" ]
326341 buffer_info_map = tvm .tir .usmp .analysis .extract_buffer_info (main_func , tir_mod )
@@ -336,35 +351,108 @@ def test_fanout():
336351
337352 # check conflicts
338353 _verify_conflicts (
339- "sid_6" ,
340- ["Conv2dOutput_2" , "sid_2" , "PaddedInput_3" , "Conv2dOutput_3" ],
354+ "Conv2dOutput_1" ,
355+ [
356+ "PaddedInput_1" ,
357+ "sid_7" ,
358+ ],
341359 buffer_info_map_names ,
342360 )
343- _verify_conflicts ("PaddedInput_1" , ["sid_8" , "sid_2" , "Conv2dOutput_1" ], buffer_info_map_names )
344- _verify_conflicts ("PaddedInput_2" , ["sid_7" , "sid_2" , "Conv2dOutput_2" ], buffer_info_map_names )
345- _verify_conflicts ("sid_8" , ["Conv2dOutput" , "sid_2" , "PaddedInput_1" ], buffer_info_map_names )
346361 _verify_conflicts (
347- "sid_2 " ,
362+ "sid_8 " ,
348363 [
349364 "PaddedInput" ,
350365 "Conv2dOutput" ,
351- "sid_8" ,
352366 "PaddedInput_1" ,
353- "Conv2dOutput_1" ,
367+ ],
368+ buffer_info_map_names ,
369+ )
370+ _verify_conflicts (
371+ "PaddedInput_2" ,
372+ [
354373 "sid_7" ,
374+ "sid_6" ,
375+ "Conv2dOutput_2" ,
376+ ],
377+ buffer_info_map_names ,
378+ )
379+ _verify_conflicts (
380+ "sid_2" ,
381+ [
382+ "PaddedInput" ,
383+ "PaddedInput_3" ,
384+ ],
385+ buffer_info_map_names ,
386+ )
387+ _verify_conflicts (
388+ "Conv2dOutput" ,
389+ [
390+ "sid_8" ,
391+ "PaddedInput" ,
392+ ],
393+ buffer_info_map_names ,
394+ )
395+ _verify_conflicts (
396+ "sid_7" ,
397+ [
398+ "Conv2dOutput_1" ,
399+ "PaddedInput_1" ,
400+ "PaddedInput_2" ,
401+ ],
402+ buffer_info_map_names ,
403+ )
404+ _verify_conflicts (
405+ "sid_6" ,
406+ [
355407 "PaddedInput_2" ,
356408 "Conv2dOutput_2" ,
409+ "Conv2dOutput_3" ,
410+ "PaddedInput_3" ,
411+ ],
412+ buffer_info_map_names ,
413+ )
414+ _verify_conflicts (
415+ "PaddedInput_3" ,
416+ [
417+ "sid_2" ,
418+ "Conv2dOutput_3" ,
357419 "sid_6" ,
420+ ],
421+ buffer_info_map_names ,
422+ )
423+ _verify_conflicts (
424+ "Conv2dOutput_3" ,
425+ [
358426 "PaddedInput_3" ,
427+ "sid_6" ,
428+ ],
429+ buffer_info_map_names ,
430+ )
431+ _verify_conflicts (
432+ "PaddedInput" ,
433+ [
434+ "sid_2" ,
435+ "sid_8" ,
436+ "Conv2dOutput" ,
437+ ],
438+ buffer_info_map_names ,
439+ )
440+ _verify_conflicts (
441+ "Conv2dOutput_2" ,
442+ [
443+ "sid_6" ,
444+ "PaddedInput_2" ,
359445 ],
360446 buffer_info_map_names ,
361447 )
362- _verify_conflicts ("PaddedInput" , ["sid_2" , "Conv2dOutput" ], buffer_info_map_names )
363- _verify_conflicts ("sid_7" , ["Conv2dOutput_1" , "sid_2" , "PaddedInput_2" ], buffer_info_map_names )
364- _verify_conflicts ("PaddedInput_3" , ["sid_6" , "sid_2" , "Conv2dOutput_3" ], buffer_info_map_names )
365- _verify_conflicts ("Conv2dOutput_3" , ["PaddedInput_3" , "sid_6" ], buffer_info_map_names )
366- _verify_conflicts ("Conv2dOutput" , ["PaddedInput" , "sid_2" , "sid_8" ], buffer_info_map_names )
367- _verify_conflicts ("Conv2dOutput_1" , ["PaddedInput_1" , "sid_2" , "sid_7" ], buffer_info_map_names )
368- _verify_conflicts ("Conv2dOutput_2" , ["PaddedInput_2" , "sid_2" , "sid_6" ], buffer_info_map_names )
369-
370- _check_max_workspace_size (buffer_pool_allocations , global_workspace_pool , 7920256 )
448+ _verify_conflicts (
449+ "PaddedInput_1" ,
450+ [
451+ "sid_8" ,
452+ "Conv2dOutput_1" ,
453+ "sid_7" ,
454+ ],
455+ buffer_info_map_names ,
456+ )
457+
458+ _check_max_workspace_size (buffer_pool_allocations , global_workspace_pool , 7200000 )
0 commit comments