@@ -134,10 +134,13 @@ def main():
134134 ARGS .input_shape ,
135135 cache_dir = ARGS .cache_dir ,
136136 )
137+ input_info = {input_name : input_shape }
138+ inputs = []
137139 print (f"Workload: { ARGS .workload } " )
138- print (f" input_name: { input_name } " )
139- print (f" input_shape: { input_shape } " )
140- print (f" input_dtype: { input_dtype } " )
140+ for input_name , input_shape in input_info .items ():
141+ print (f" input_name: { input_name } " )
142+ print (f" input_shape: { input_shape } " )
143+ print (f" input_dtype: { input_dtype } " )
141144 tasks , task_weights = auto_scheduler .extract_tasks (
142145 mod ["main" ],
143146 params ,
@@ -170,19 +173,21 @@ def main():
170173 params = params ,
171174 )
172175 graph , rt_mod , params = lib .graph_json , lib .lib , lib .params
173- if input_dtype .startswith ("float" ):
174- input_data = np .random .uniform (size = input_shape ).astype (input_dtype )
175- else :
176- input_data = np .random .randint (low = 0 , high = 10000 , size = input_shape , dtype = input_dtype )
176+ for input_name , input_shape in input_info .items ():
177+ if input_dtype .startswith ("float" ):
178+ inputs .append (np .random .uniform (size = input_shape ).astype (input_dtype ))
179+ else :
180+ inputs .append (np .random .randint (low = 0 , high = 10000 , size = input_shape , dtype = input_dtype ))
177181
178- def f_timer (rt_mod , dev , input_data ):
182+ def f_timer (rt_mod , dev , inputs ):
179183 # pylint: disable=import-outside-toplevel
180184 from tvm .contrib .graph_executor import GraphModule
181185
182186 # pylint: enable=import-outside-toplevel
183187
184188 mod = GraphModule (rt_mod ["default" ](dev ))
185- mod .set_input (input_name , input_data )
189+ for index , (input_name , _ ) in enumerate (input_info .items ()):
190+ mod .set_input (input_name , inputs [index ])
186191 ftimer = mod .module .time_evaluator (
187192 "run" ,
188193 dev ,
@@ -196,17 +201,18 @@ def f_timer(rt_mod, dev, input_data):
196201 rpc_config = ARGS .rpc_config ,
197202 lib = lib ,
198203 dev_type = ARGS .target .kind .name ,
199- args = [ input_data ] ,
204+ args = inputs ,
200205 continuation = f_timer ,
201206 )
202207
203- def f_per_layer (rt_mod , dev , input_data ):
208+ def f_per_layer (rt_mod , dev , inputs ):
204209 # pylint: disable=import-outside-toplevel
205210 from tvm .contrib .debugger .debug_executor import create
206211
207212 # pylint: enable=import-outside-toplevel
208213 mod = create (graph , rt_mod , dev )
209- mod .set_input (input_name , input_data )
214+ for index , (input_name , _ ) in enumerate (input_info .items ()):
215+ mod .set_input (input_name , inputs [index ])
210216 graph_nodes = [n ["name" ] for n in json .loads (graph )["nodes" ]]
211217 graph_time = mod .run_individual (number = 10 , repeat = 1 , min_repeat_ms = 5000 )
212218 print ("|graph_nodes| = " , len (graph_nodes ))
@@ -219,7 +225,7 @@ def f_per_layer(rt_mod, dev, input_data):
219225 rpc_config = ARGS .rpc_config ,
220226 lib = rt_mod ,
221227 dev_type = ARGS .target .kind .name ,
222- args = [ input_data ] ,
228+ args = inputs ,
223229 continuation = f_per_layer ,
224230 )
225231
0 commit comments