@@ -103,10 +103,13 @@ def main():
103103 ARGS .input_shape ,
104104 cache_dir = ARGS .cache_dir ,
105105 )
106+ input_info = {input_name : input_shape }
107+ inputs = []
106108 print (f"Workload: { ARGS .workload } " )
107- print (f" input_name: { input_name } " )
108- print (f" input_shape: { input_shape } " )
109- print (f" input_dtype: { input_dtype } " )
109+ for input_name , input_shape in input_info .items ():
110+ print (f" input_name: { input_name } " )
111+ print (f" input_shape: { input_shape } " )
112+ print (f" input_dtype: { input_dtype } " )
110113 alloc_repeat = 1
111114 runner = ms .runner .RPCRunner (
112115 rpc_config = ARGS .rpc_config ,
@@ -133,19 +136,21 @@ def main():
133136 params = params ,
134137 )
135138 graph , rt_mod , params = lib .graph_json , lib .lib , lib .params
136- if input_dtype .startswith ("float" ):
137- input_data = np .random .uniform (size = input_shape ).astype (input_dtype )
138- else :
139- input_data = np .random .randint (low = 0 , high = 10000 , size = input_shape , dtype = input_dtype )
139+ for input_name , input_shape in input_info .items ():
140+ if input_dtype .startswith ("float" ):
141+ inputs .append (np .random .uniform (size = input_shape ).astype (input_dtype ))
142+ else :
143+ inputs .append (np .random .randint (low = 0 , high = 10000 , size = input_shape , dtype = input_dtype ))
140144
141- def f_timer (rt_mod , dev , input_data ):
145+ def f_timer (rt_mod , dev , inputs ):
142146 # pylint: disable=import-outside-toplevel
143147 from tvm .contrib .graph_executor import GraphModule
144148
145149 # pylint: enable=import-outside-toplevel
146150
147151 mod = GraphModule (rt_mod ["default" ](dev ))
148- mod .set_input (input_name , input_data )
152+ for index , (input_name , _ ) in enumerate (input_info .items ()):
153+ mod .set_input (input_name , inputs [index ])
149154 ftimer = mod .module .time_evaluator (
150155 "run" ,
151156 dev ,
@@ -159,17 +164,18 @@ def f_timer(rt_mod, dev, input_data):
159164 rpc_config = ARGS .rpc_config ,
160165 lib = lib ,
161166 dev_type = ARGS .target .kind .name ,
162- args = [ input_data ] ,
167+ args = inputs ,
163168 continuation = f_timer ,
164169 )
165170
166- def f_per_layer (rt_mod , dev , input_data ):
171+ def f_per_layer (rt_mod , dev , inputs ):
167172 # pylint: disable=import-outside-toplevel
168173 from tvm .contrib .debugger .debug_executor import create
169174
170175 # pylint: enable=import-outside-toplevel
171176 mod = create (graph , rt_mod , dev )
172- mod .set_input (input_name , input_data )
177+ for index , (input_name , _ ) in enumerate (input_info .items ()):
178+ mod .set_input (input_name , inputs [index ])
173179 graph_nodes = [n ["name" ] for n in json .loads (graph )["nodes" ]]
174180 graph_time = mod .run_individual (number = 10 , repeat = 1 , min_repeat_ms = 5000 )
175181 print ("|graph_nodes| = " , len (graph_nodes ))
@@ -182,7 +188,7 @@ def f_per_layer(rt_mod, dev, input_data):
182188 rpc_config = ARGS .rpc_config ,
183189 lib = rt_mod ,
184190 dev_type = ARGS .target .kind .name ,
185- args = [ input_data ] ,
191+ args = inputs ,
186192 continuation = f_per_layer ,
187193 )
188194
0 commit comments