Commit 6bf8b60

Author: YJ Shi (committed)

fix auto schedule

1 parent 1c08a3e · commit 6bf8b60

File tree: 1 file changed (+19 −13 lines)

python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py

Lines changed: 19 additions & 13 deletions
@@ -134,10 +134,13 @@ def main():
         ARGS.input_shape,
         cache_dir=ARGS.cache_dir,
     )
+    input_info = {input_name: input_shape}
+    inputs = []
     print(f"Workload: {ARGS.workload}")
-    print(f"  input_name: {input_name}")
-    print(f"  input_shape: {input_shape}")
-    print(f"  input_dtype: {input_dtype}")
+    for input_name, input_shape in input_info.items():
+        print(f"  input_name: {input_name}")
+        print(f"  input_shape: {input_shape}")
+        print(f"  input_dtype: {input_dtype}")
     tasks, task_weights = auto_scheduler.extract_tasks(
         mod["main"],
         params,
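The hunk above wraps the single (input_name, input_shape) pair into an input_info dict so the rest of the script can treat single- and multi-input workloads the same way. A minimal sketch of that mapping, with made-up names and shapes that are not taken from the commit:

# Hypothetical contents of input_info; names and shapes are illustrative only.
input_info = {"data": (1, 3, 224, 224)}  # single-input workload
# A multi-input workload would simply carry more entries, e.g.
# {"input_ids": (1, 128), "attention_mask": (1, 128)}
for input_name, input_shape in input_info.items():
    print(f"  input_name: {input_name}")
    print(f"  input_shape: {input_shape}")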
@@ -170,19 +173,21 @@ def main():
         params=params,
     )
     graph, rt_mod, params = lib.graph_json, lib.lib, lib.params
-    if input_dtype.startswith("float"):
-        input_data = np.random.uniform(size=input_shape).astype(input_dtype)
-    else:
-        input_data = np.random.randint(low=0, high=10000, size=input_shape, dtype=input_dtype)
+    for input_name, input_shape in input_info.items():
+        if input_dtype.startswith("float"):
+            inputs.append(np.random.uniform(size=input_shape).astype(input_dtype))
+        else:
+            inputs.append(np.random.randint(low=0, high=10000, size=input_shape, dtype=input_dtype))
 
-    def f_timer(rt_mod, dev, input_data):
+    def f_timer(rt_mod, dev, inputs):
         # pylint: disable=import-outside-toplevel
         from tvm.contrib.graph_executor import GraphModule
 
         # pylint: enable=import-outside-toplevel
 
         mod = GraphModule(rt_mod["default"](dev))
-        mod.set_input(input_name, input_data)
+        for index, (input_name, _) in enumerate(input_info.items()):
+            mod.set_input(input_name, inputs[index])
         ftimer = mod.module.time_evaluator(
             "run",
             dev,
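In this hunk the random data generation moves inside a loop over input_info, appending one array per input to the inputs list; note that input_dtype remains a single value shared by all inputs, as in the original script. A self-contained sketch of that pattern, with an illustrative name, shape, and dtype that are not from the commit:

# Sketch of the per-input data-generation loop; values below are assumptions.
import numpy as np

input_info = {"data": (1, 3, 224, 224)}  # illustrative only
input_dtype = "float32"                  # illustrative only
inputs = []
for _input_name, input_shape in input_info.items():
    if input_dtype.startswith("float"):
        inputs.append(np.random.uniform(size=input_shape).astype(input_dtype))
    else:
        inputs.append(np.random.randint(low=0, high=10000, size=input_shape, dtype=input_dtype))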
@@ -196,17 +201,18 @@ def f_timer(rt_mod, dev, input_data):
         rpc_config=ARGS.rpc_config,
         lib=lib,
         dev_type=ARGS.target.kind.name,
-        args=[input_data],
+        args=inputs,
         continuation=f_timer,
     )
 
-    def f_per_layer(rt_mod, dev, input_data):
+    def f_per_layer(rt_mod, dev, inputs):
         # pylint: disable=import-outside-toplevel
         from tvm.contrib.debugger.debug_executor import create
 
         # pylint: enable=import-outside-toplevel
         mod = create(graph, rt_mod, dev)
-        mod.set_input(input_name, input_data)
+        for index, (input_name, _) in enumerate(input_info.items()):
+            mod.set_input(input_name, inputs[index])
         graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]]
         graph_time = mod.run_individual(number=10, repeat=1, min_repeat_ms=5000)
         print("|graph_nodes| = ", len(graph_nodes))
@@ -219,7 +225,7 @@ def f_per_layer(rt_mod, dev, input_data):
         rpc_config=ARGS.rpc_config,
         lib=rt_mod,
         dev_type=ARGS.target.kind.name,
-        args=[input_data],
+        args=inputs,
         continuation=f_per_layer,
     )
 