@@ -216,6 +216,28 @@ def _create_type_metadata(input_type):
216216 "dtype" : str (input_type .dtype ),
217217 }
218218
def _iter_tensor_types(ret_type):
    """Yield the TensorType leaves of ``ret_type`` depth-first, left to right."""
    if isinstance(ret_type, tvm.ir.tensor_type.TensorType):
        yield ret_type
        return
    # Anything that is not a tensor is treated as a tuple-like type and is
    # expected to expose its members through ``fields``.
    for field in ret_type.fields:
        yield from _iter_tensor_types(field)


def _flatten_tuple_outputs(ret_type, predefined_names, offset=0):
    """Flatten a (possibly nested) tuple return type into named tensor types.

    Parameters
    ----------
    ret_type :
        Either a ``tvm.ir.tensor_type.TensorType`` or a tuple-like type whose
        ``fields`` are themselves tensors or further tuple-like types.
    predefined_names :
        Indexable collection of output names, or a falsy value (``None`` or
        empty) to fall back to generated ``output<N>`` names.
    offset : int
        Flattened position assigned to the first tensor found in ``ret_type``.

    Returns
    -------
    dict
        Maps each output name to its tensor type, in flattened order.

    Raises
    ------
    ValueError
        If two outputs resolve to the same name. A dict cannot hold both, so
        one output would otherwise be dropped silently.
    """
    outputs = {}
    # The position is counted over the flattened tensors themselves rather
    # than over the size of ``outputs``, so a repeated name cannot stall the
    # index and misalign every following output.
    for index, tensor_type in enumerate(_iter_tensor_types(ret_type), start=offset):
        name = predefined_names[index] if predefined_names else f"output{index}"
        if name in outputs:
            raise ValueError(f"Duplicate output tensor name {name!r} at output index {index}")
        outputs[name] = tensor_type
    return outputs
233+
def _get_outputs_from_ret_type(ret_type, predefined_names):
    """Map output names to tensor types for a function return type.

    A non-tensor return type is delegated to ``_flatten_tuple_outputs``.
    A single tensor is named after the first entry of ``predefined_names``
    when that collection is truthy, and ``"output"`` otherwise.
    """
    if not isinstance(ret_type, tvm.ir.tensor_type.TensorType):
        return _flatten_tuple_outputs(ret_type, predefined_names)

    if predefined_names:
        sole_name = predefined_names[0]
    else:
        sole_name = "output"
    return {sole_name: ret_type}
219241
220242def _build_function_memory_map (function_metadata ):
221243 """Build a simple map that shows how much workspace is required to execute
@@ -297,29 +319,21 @@ def _create_empty_entry(target_device_type):
297319 target_main_entries [int (target .get_target_device_type ())] = _create_empty_entry (
298320 int (target .get_target_device_type ())
299321 )
300- target_main_entries [int (target .get_target_device_type ())]["io_size_bytes" ] = int (
322+ target_main_on_device = target_main_entries [int (target .get_target_device_type ())]
323+ target_main_on_device ["io_size_bytes" ] = int (
301324 main_func_metadata .io_sizes [target ]
302325 )
303326
304- # Now, we also add the information about the size of each input and output of the main
305- # function (in bytes)
306- input_dict = {}
307- for input_param in main_func_metadata .relay_primfuncs [target ].params :
308- input_dict [input_param .name_hint ] = _create_type_metadata (input_param .checked_type )
309- target_main_entries [int (target .get_target_device_type ())]["inputs" ] = input_dict
310-
311- output_dict = {}
312- # For output, we dont have the name of the output, so we enumerate them
313- if isinstance (main_func_metadata .relay_primfuncs [target ].ret_type , tvm .ir .type .TupleType ):
314- output_list = _convert_tuple_to_outputs (
315- main_func_metadata .relay_primfuncs [target ].ret_type
316- )
317- for i , output_type in enumerate (output_list ):
318- output_dict [f"output{ i } " ] = _create_type_metadata (output_type )
319- else :
320- output_type = main_func_metadata .relay_primfuncs [target ].ret_type
321- output_dict ["output" ] = _create_type_metadata (output_type )
322- target_main_entries [int (target .get_target_device_type ())]["outputs" ] = output_dict
327+ main_relay_func = main_func_metadata .relay_primfuncs [target ]
328+ target_main_on_device ["inputs" ] = {
329+ input_param .name_hint : _create_type_metadata (input_param .checked_type )
330+ for input_param in main_relay_func .params
331+ }
332+ predefined_names = main_relay_func .attrs ["output_tensor_names" ] if "output_tensor_names" in main_relay_func .attrs else None
333+ target_main_on_device ["outputs" ] = {
334+ name : _create_type_metadata (output_type )
335+ for name , output_type in _get_outputs_from_ret_type (main_relay_func .ret_type , predefined_names ).items ()
336+ }
323337
324338 ret = {
325339 "operator_functions" : func_entries ,
@@ -328,30 +342,6 @@ def _create_empty_entry(target_device_type):
328342 return ret
329343
330344
def _get_main_relay_func(mod: executor_factory.ExecutorFactoryModule):
    """Return the main function's Relay function for its first listed target."""
    primfuncs_by_target = mod.function_metadata[MAIN_FUNC_NAME_STR].relay_primfuncs
    # Only the first key of the mapping is consulted.
    targets = list(primfuncs_by_target.keys())
    return primfuncs_by_target[targets[0]]
335-
336-
def _convert_tuple_to_outputs(ret_type, offset=0):
    """Flatten the fields of a tuple type into a list of non-tuple types.

    Nested ``TupleType`` fields are expanded recursively in place, so the
    result keeps depth-first, left-to-right order. ``offset`` is only passed
    on to the recursive calls.
    """
    flattened = []
    fields = ret_type.fields
    for position in range(len(fields)):
        field = fields[position]
        if not isinstance(field, TupleType):
            flattened.append(field)
            continue
        nested_offset = offset + len(flattened)
        flattened.extend(_convert_tuple_to_outputs(field, nested_offset))
    return flattened
347-
348-
349- def _get_inputs_and_outputs_from_module (mod ):
350- inputs = [str (input_var .name ) for input_var in mod .executor_codegen_metadata .inputs ]
351- outputs = list (mod .executor_codegen_metadata .outputs )
352- return inputs , outputs
353-
354-
355345def _get_pools_from_module (mod ):
356346 return list (dict (mod .executor_codegen_metadata .pool_inputs ).values ())
357347
@@ -462,33 +452,30 @@ def _export_graph_model_library_format(
462452 if not include_path .exists ():
463453 include_path .mkdir ()
464454
465- inputs , outputs = _get_inputs_and_outputs_from_module (mod )
466455 devices = mod .get_devices ()
467456 pools = _get_pools_from_module (mod )
468457 io_pool_allocations = _get_io_pool_allocation_from_module (mod )
458+ main_func = metadata ["modules" ][mod .libmod_name ]["memory" ]["functions" ]["main" ][0 ]
469459 workspace_size = int (
470- metadata [ "modules" ][ mod . libmod_name ][ "memory" ][ "functions" ][ "main" ][ 0 ] [
460+ main_func [
471461 "workspace_size_bytes"
472462 ]
473463 )
474- inputs_sizes = metadata ["modules" ][mod .libmod_name ]["memory" ]["functions" ]["main" ][0 ][
475- "inputs"
476- ]
477- # Here, we merge the output sizes with the actual output names
478- output_sizes = {}
479- for i , key in enumerate (
480- metadata ["modules" ][mod .libmod_name ]["memory" ]["functions" ]["main" ][0 ][
481- "outputs"
482- ].keys ()
483- ):
484- output_sizes [outputs [i ]] = metadata ["modules" ][mod .libmod_name ]["memory" ][
485- "functions"
486- ]["main" ][0 ]["outputs" ][key ]
464+ inputs = main_func ["inputs" ]
465+ outputs = main_func ["outputs" ]
466+ inputs_sizes = {
467+ name : property_map ["size" ] for name , property_map in inputs .items ()
468+ }
469+ output_sizes = {
470+ name : property_map ["size" ] for name , property_map in outputs .items ()
471+ }
472+ input_names = list (inputs .keys ())
473+ output_names = list (outputs .keys ())
487474
488475 generate_c_interface_header (
489476 mod .libmod_name ,
490- inputs ,
491- outputs ,
477+ input_names ,
478+ output_names ,
492479 pools ,
493480 io_pool_allocations ,
494481 devices ,
0 commit comments