@@ -177,37 +177,26 @@ def _build_function_memory_map(function_metadata):
177177 """
178178 device_max_workspace = dict ()
179179 main_func_metadata = function_metadata [MAIN_FUNC_NAME_STR ]
180- main_targets = dict (main_func_metadata .workspace_sizes ).keys ()
181- from tvm .driver import tvmc # pylint: disable=import-outside-toplevel
182-
183- external_codegens = tvmc .composite_target .get_codegen_names ()
184180 func_entries = []
185181 target_local_entries = dict ()
186- for main_target in main_targets :
187- device_max_workspace [main_target ] = 0
188- for func_name , finfo in function_metadata .items ():
189- if func_name == MAIN_FUNC_NAME_STR :
190- continue
191- target_local_entries [func_name ] = list ()
192182
193- for func_name , finfo in function_metadata .items ():
194- # Skip a few unsupported cases:
195- # 1. The main function metadata is exported elsewhere.
196- # 2. BYOC operator implementations do not currently export useful FunctionInfo.
197- if func_name == MAIN_FUNC_NAME_STR or not finfo .tir_primfuncs :
198- continue
199- if main_target in finfo .workspace_sizes .keys ():
200- workspace_size = finfo .workspace_sizes [main_target ]
201- target_entry = {
202- "device" : int (main_target .kind .device_type ),
203- "workspace_size_bytes" : int (workspace_size ),
204- }
205- target_local_entries [func_name ].append (target_entry )
206- if workspace_size > device_max_workspace .get (main_target , 0 ):
207- device_max_workspace [main_target ] = workspace_size
208- # TODO(Mousius) - Remove this massive hack when Targets are unified
209- if main_target .kind .name in external_codegens :
210- device_max_workspace [main_target ] += int (workspace_size )
183+ for func_name , finfo in function_metadata .items ():
184+ # Skip a few unsupported cases:
185+ # 1. The main function metadata is exported elsewhere.
186+ # 2. BYOC operator implementations do not currently export useful FunctionInfo.
187+ if func_name == MAIN_FUNC_NAME_STR or not finfo .tir_primfuncs :
188+ continue
189+ if func_name not in target_local_entries .keys ():
190+ target_local_entries [func_name ] = list ()
191+ for target in dict (finfo .workspace_sizes ).keys ():
192+ workspace_size = finfo .workspace_sizes [target ]
193+ target_entry = {
194+ "device" : int (target .kind .device_type ),
195+ "workspace_size_bytes" : int (workspace_size ),
196+ }
197+ target_local_entries [func_name ].append (target_entry )
198+ if workspace_size >= device_max_workspace .get (int (target .kind .device_type ), 0 ):
199+ device_max_workspace [int (target .kind .device_type )] = workspace_size
211200
212201 for func_name , target_entries_ in target_local_entries .items ():
213202 func_entry = {
@@ -216,32 +205,46 @@ def _build_function_memory_map(function_metadata):
216205 }
217206 func_entries .append (func_entry )
218207
219- target_main_entries = list ()
220- for main_target in main_targets :
221- main_func_local_workspace = main_func_metadata .workspace_sizes [main_target ]
222- main_func_constants = (
223- main_func_metadata .constant_sizes [main_target ]
224- if main_target in main_func_metadata .constant_sizes .keys ()
225- else 0
208+ target_main_entries = dict ()
209+
210+ def _create_empty_entry (target_device_type ):
211+ return {
212+ "device" : int (target_device_type ),
213+ "workspace_size_bytes" : 0 ,
214+ "constants_size_bytes" : 0 ,
215+ "io_size_bytes" : 0 ,
216+ }
217+
218+ for target in dict (main_func_metadata .workspace_sizes ).keys ():
219+ main_func_local_workspace = main_func_metadata .workspace_sizes [target ]
220+ target_main_entries [int (target .kind .device_type )] = _create_empty_entry (
221+ int (target .kind .device_type )
226222 )
227- main_func_io = (
228- main_func_metadata .io_sizes [main_target ]
229- if main_target in main_func_metadata .io_sizes .keys ()
230- else 0
223+ target_main_entries [int (target .kind .device_type )]["workspace_size_bytes" ] = int (
224+ device_max_workspace .get (int (target .kind .device_type ), 0 )
225+ ) + int (main_func_local_workspace )
226+
227+ for target in dict (main_func_metadata .constant_sizes ).keys ():
228+ if int (target .kind .device_type ) not in target_main_entries .keys ():
229+ target_main_entries [int (target .kind .device_type )] = _create_empty_entry (
230+ int (target .kind .device_type )
231+ )
232+ target_main_entries [int (target .kind .device_type )]["constants_size_bytes" ] = int (
233+ main_func_metadata .constant_sizes [target ]
231234 )
232- target_main_entries . append (
233- {
234- "device" : int (main_target .kind .device_type ),
235- "workspace_size_bytes" : int (device_max_workspace [ main_target ])
236- + int (main_func_local_workspace ),
237- "constants_size_bytes" : int ( main_func_constants ),
238- "io_size_bytes" : int (main_func_io ),
239- }
235+
236+ for target in dict ( main_func_metadata . io_sizes ). keys ():
237+ if int (target .kind .device_type ) not in target_main_entries . keys ():
238+ target_main_entries [ int (target . kind . device_type )] = _create_empty_entry (
239+ int (target . kind . device_type )
240+ )
241+ target_main_entries [ int ( target . kind . device_type )][ "io_size_bytes" ] = int (
242+ main_func_metadata . io_sizes [ target ]
240243 )
241244
242245 ret = {
243246 "operator_functions" : func_entries ,
244- "main" : target_main_entries ,
247+ "main" : list ( target_main_entries . values ()) ,
245248 }
246249 return ret
247250
0 commit comments