@@ -49,16 +49,26 @@ def build(pipe_configs):
4949 Common interface for pipeline executor factory modules.
5050 """
5151 libs = {}
52- mod_n_configs = pipe_configs .get_config ()
52+ config = pipe_configs .get_config ()
53+ if "module_connection" not in config :
54+ raise RuntimeError ('"module_connection" is missing' )
55+ if "input_connection" not in config :
56+ raise RuntimeError ('"input_connection" is missing' )
57+
58+ mod_n_configs = config ["module_connection" ]
5359 config_len = len (mod_n_configs )
54- string_config = [{} for _ in range (config_len )]
60+ module_string_config = [{} for _ in range (config_len )]
61+ # Use hardware configurations to build backend modules for each subgraph.
5562 for ir_mod , mod_config in mod_n_configs .items ():
56- mconf = mod_config ["pipeline" ].copy ()
57- mod_idx = mconf ["mod_idx" ]
63+ pipe_config = mod_config ["pipeline" ].copy ()
64+ mod_idx = pipe_config ["mod_idx" ]
5865 dev = mod_config ["dev" ]
5966 target = mod_config ["target" ]
6067 build_func = relay .build
61- # Check whether there is a customized build function.
68+ # Callers may need to use a customized building function to wrap the pre-building logic
69+ # and the backend building logic. For example, in order to support a backend which only
70+ # can do "int8" computation, the caller may need to merge the "quantization" logic
71+ # into the building logic to creat a customized building function.
6272 if "build" in mod_config and mod_config ["build" ]:
6373 build_func = mod_config ["build" ]
6474
@@ -70,11 +80,20 @@ def build(pipe_configs):
7080 mod_name = mod_config ["mod_name" ],
7181 )
7282
73- mconf ["dev" ] = "{},{}" .format (dev .device_type , dev .device_id )
74- # Create a pipeline configuration.
75- string_config [mod_idx ] = mconf
83+ pipe_config ["dev" ] = "{},{}" .format (dev .device_type , dev .device_id )
84+ # Use "mod_idx" as the key to create a "module_connection" map which is not only
85+ # for the module index but also for the module connection used to build the pipeline.
86+ module_string_config [mod_idx ] = pipe_config
7687 libs [mod_idx ] = {"lib" : lib , "dev" : dev }
7788
89+ # Creating a text form configuration to record the "input_connection" and the
90+ # "module_connection" information. The "input_connection" is used to record the
91+ # map of global input and subgraph input, and the "module_connection" is used to
92+ # record module dependency.
93+ string_config = {}
94+ string_config ["input_connection" ] = config ["input_connection" ]
95+ string_config ["module_connection" ] = module_string_config
96+
7897 return PipelineExecutorFactoryModule (libs , string_config )
7998
8099
@@ -94,6 +113,17 @@ def __init__(self, module):
94113 self .module = module
95114 # Get the packed functions from the pipeline executor.
96115 self ._get_num_outputs = self .module ["get_num_outputs" ]
116+ self ._get_input_pipeline_map = self .module ["get_input_pipeline_map" ]
117+
118+ def get_input_pipeline_map (self , name ):
119+ """Using the "name" to get the corresponding subgraph index and also get the "input name"
120+ of the corresponding subgraph interface.
121+ Returns
122+ -------
123+ input map: Array[str]
124+ Returning the index and "input name" of the subgraph.
125+ """
126+ return self ._get_input_pipeline_map (name )
97127
98128 @property
99129 def num_outputs (self ):
@@ -199,12 +229,48 @@ def is_pipeline_executor_interface(self):
199229 return not isinstance (self .io_owner , PipelineConfig .ModuleWrapper )
200230
201231 def __repr__ (self ):
202- # Get all binding information.
203- ret = " |{}: " .format (self .name )
232+ # Geting the binding information in the form of text .
233+ str_format = " |{}: " .format (self .name )
204234 for binding in self .bindings :
205235 mname , dname = binding .get_name ()
206- ret += "{0}:{1} " .format (mname , dname )
207- return ret
236+ str_format += "{0}:{1} " .format (mname , dname )
237+
238+ return str_format
239+
240+ def check_binding_dict (self , connection_dict ):
241+ """Checking the binding dictionary.
242+ Parameter
243+ ---------
244+ connection_dict : Dict[str, Any]
245+ It is a dictionary of module connections.
246+ """
247+ if "interface_name" not in connection_dict :
248+ raise RuntimeError ('"inteface_name" is missing in global config!"' )
249+ if "connection" not in connection_dict :
250+ raise RuntimeError (f'"connection" is missing!"' )
251+ # The global interface mapping should be one-to-one.
252+ if not connection_dict ["connection" ]:
253+ raise RuntimeError ("The global interface map is empty!" )
254+ if len (connection_dict ["connection" ]) > 1 :
255+ raise RuntimeError ("A global interface maps multiple module interfaces!" )
256+ if "mod_idx" not in connection_dict ["connection" ][0 ]:
257+ raise RuntimeError ('"mod_idx" is missing!' )
258+
259+ def get_binding_dict (self ):
260+ """Returning the binding information in the form of dictionary.
261+ Returns
262+ -------
263+ data : Dict[str, Any]
264+ The binding information is in the form of dictionary.
265+ """
266+ dict_format = {"interface_name" : self .name , "connection" : []}
267+ for binding in self .bindings :
268+ _ , dname = binding .get_name ()
269+ midx = binding .get_owner_idx ()
270+ dict_format ["connection" ].append ({"mod_idx" : midx , "interface_name" : dname })
271+
272+ self .check_binding_dict (dict_format )
273+ return dict_format
208274
209275 def check_dag_acyclic (self , start , inputs ):
210276 """This is to check whether the DAG containing these input interfaces is acyclic.
@@ -243,30 +309,34 @@ def connect(self, binding):
243309
244310 # Check whether the binding setting is correct or not.
245311 if self .io_owner == binding .io_owner :
246- raise RuntimeError (f "Can not bind itself." )
312+ raise RuntimeError ("Can not bind itself." )
247313
248314 if not self .is_pipeline_executor_interface () and self .io_type == "input" :
249- raise RuntimeError (f "Module can only bind from output interface!" )
315+ raise RuntimeError ("Module can only bind from output interface!" )
250316
251317 if (
252318 not self .is_pipeline_executor_interface ()
253319 and not binding .is_pipeline_executor_interface ()
254320 and binding .io_type == "output"
255321 ):
256- raise RuntimeError (f "Can not bind module output with another module output!" )
322+ raise RuntimeError ("Can not bind module output with another module output!" )
257323
258324 if (
259325 not self .is_pipeline_executor_interface ()
260326 and binding .is_pipeline_executor_interface ()
261327 and binding .io_type == "input"
262328 ):
263- raise RuntimeError (f "Can not bind module output with pipeline input!" )
329+ raise RuntimeError ("Can not bind module output with pipeline input!" )
264330
265331 if self .is_pipeline_executor_interface () and self .io_type == "output" :
266- raise RuntimeError (f "Global output can not be used as binding start point." )
332+ raise RuntimeError ("Global output can not be used as binding start point." )
267333
268- if self .is_pipeline_executor_interface () and binding .io_type != "input" :
269- raise RuntimeError (f"Global input can only bind with module input." )
334+ if (
335+ self .is_pipeline_executor_interface ()
336+ and self .io_type == "input"
337+ and binding .io_type != "input"
338+ ):
339+ raise RuntimeError ("Global input can only bind with module input." )
270340
271341 self .bindings .append (binding )
272342 if not self .is_pipeline_executor_interface ():
@@ -288,7 +358,7 @@ def connect(self, binding):
288358 if not self .check_dag_acyclic (
289359 binding .io_owner , self .io_owner .input_bindings .bindings
290360 ):
291- raise RuntimeError (f "Illegal connection: Cause a cycle!" )
361+ raise RuntimeError ("Illegal connection: Cause a cycle!" )
292362
293363 class BindingList :
294364 """Container for bindings(input or output interface).
@@ -357,7 +427,9 @@ def __getitem__(self, key):
357427 if key == "output" :
358428 return self .output_bindings
359429
360- raise RuntimeError (f"{ key } not found!" )
430+ raise RuntimeError (f"{ key } not found!" )
431+
432+ raise RuntimeError ('The data type of "key" is not supported!' )
361433
362434 def get_data_type (self , key , interface_type ):
363435 """Get the module interface data type according to the key value and interface type.
@@ -468,6 +540,8 @@ def get_config(self):
468540 # Use topological sort to get the correct order of modules.
469541 self .dag_topology_sort ()
470542 mconfig = {}
543+ module_connection = {}
544+ input_connection = {}
471545 for mod in self .mod_wrapper :
472546 # Generate pipeline configuration.
473547 mconf = {}
@@ -495,7 +569,7 @@ def get_config(self):
495569 mconf ["mod_idx" ] = module .idx
496570 mconf ["output" ] = output_conf
497571
498- mconfig [mod ] = {
572+ module_connection [mod ] = {
499573 "pipeline" : mconf ,
500574 "target_host" : module .target_host ,
501575 "mod_name" : "default" ,
@@ -505,6 +579,22 @@ def get_config(self):
505579 "dev" : module .dev ,
506580 }
507581
582+ # Create a map of pipeline input and subgraph input.
583+ input_connection = []
584+ for input_name in self .input_bindings .bindings :
585+ input_dict = self .input_bindings .bindings [input_name ].get_binding_dict ()
586+ if "interface_name" not in input_dict ["connection" ][0 ]:
587+ raise RuntimeError ("interface_name is missing in connection config!" )
588+ # Creating the map of global interface and subgraph interface.
589+ input_map = {
590+ "global_interface_name" : input_dict ["interface_name" ],
591+ "mod_idx" : input_dict ["connection" ][0 ]["mod_idx" ],
592+ "module_interface_name" : input_dict ["connection" ][0 ]["interface_name" ],
593+ }
594+ input_connection .append (input_map )
595+
596+ mconfig ["module_connection" ] = module_connection
597+ mconfig ["input_connection" ] = input_connection
508598 return mconfig
509599
510600 def dag_topology_sort (self ):
@@ -601,11 +691,11 @@ def export_library(self, directory_path):
601691 Export the files to this directory.
602692 """
603693 if not self .pipeline_mods :
604- raise RuntimeError (f "The pipeline executor has not been initialized." )
694+ raise RuntimeError ("The pipeline executor has not been initialized." )
605695
606696 # Check if the directory_path exists.
607697 if not os .path .exists (directory_path ):
608- raise RuntimeError (f "The directory { directory_path } does not exist." )
698+ raise RuntimeError ("The directory {directory_path} does not exist." )
609699 # Create an load configuration.
610700 load_config_file_name = "{}/load_config" .format (directory_path )
611701 pipeline_config_file_name = "{}/pipeline_config" .format (directory_path )
0 commit comments