Skip to content

Commit 1c08a3e

Browse files
author
YJ Shi
committed
fix issues for models with more than one inputs
1 parent 53fe596 commit 1c08a3e

File tree

2 files changed

+20
-14
lines changed

2 files changed

+20
-14
lines changed

python/tvm/meta_schedule/testing/custom_builder_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,4 +167,4 @@ def run_module_via_rpc(
167167
rt_mod = session.load_module(filename)
168168
dev = session.device(dev_type=dev_type, dev_id=0)
169169
args = [ndarray.array(arg, dev) for arg in args]
170-
return continuation(rt_mod, dev, *args)
170+
return continuation(rt_mod, dev, args)

python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,13 @@ def main():
103103
ARGS.input_shape,
104104
cache_dir=ARGS.cache_dir,
105105
)
106+
input_info = {input_name: input_shape}
107+
inputs = []
106108
print(f"Workload: {ARGS.workload}")
107-
print(f" input_name: {input_name}")
108-
print(f" input_shape: {input_shape}")
109-
print(f" input_dtype: {input_dtype}")
109+
for input_name, input_shape in input_info.items():
110+
print(f" input_name: {input_name}")
111+
print(f" input_shape: {input_shape}")
112+
print(f" input_dtype: {input_dtype}")
110113
alloc_repeat = 1
111114
runner = ms.runner.RPCRunner(
112115
rpc_config=ARGS.rpc_config,
@@ -133,19 +136,21 @@ def main():
133136
params=params,
134137
)
135138
graph, rt_mod, params = lib.graph_json, lib.lib, lib.params
136-
if input_dtype.startswith("float"):
137-
input_data = np.random.uniform(size=input_shape).astype(input_dtype)
138-
else:
139-
input_data = np.random.randint(low=0, high=10000, size=input_shape, dtype=input_dtype)
139+
for input_name, input_shape in input_info.items():
140+
if input_dtype.startswith("float"):
141+
inputs.append(np.random.uniform(size=input_shape).astype(input_dtype))
142+
else:
143+
inputs.append(np.random.randint(low=0, high=10000, size=input_shape, dtype=input_dtype))
140144

141-
def f_timer(rt_mod, dev, input_data):
145+
def f_timer(rt_mod, dev, inputs):
142146
# pylint: disable=import-outside-toplevel
143147
from tvm.contrib.graph_executor import GraphModule
144148

145149
# pylint: enable=import-outside-toplevel
146150

147151
mod = GraphModule(rt_mod["default"](dev))
148-
mod.set_input(input_name, input_data)
152+
for index, (input_name, _) in enumerate(input_info.items()):
153+
mod.set_input(input_name, inputs[index])
149154
ftimer = mod.module.time_evaluator(
150155
"run",
151156
dev,
@@ -159,17 +164,18 @@ def f_timer(rt_mod, dev, input_data):
159164
rpc_config=ARGS.rpc_config,
160165
lib=lib,
161166
dev_type=ARGS.target.kind.name,
162-
args=[input_data],
167+
args=inputs,
163168
continuation=f_timer,
164169
)
165170

166-
def f_per_layer(rt_mod, dev, input_data):
171+
def f_per_layer(rt_mod, dev, inputs):
167172
# pylint: disable=import-outside-toplevel
168173
from tvm.contrib.debugger.debug_executor import create
169174

170175
# pylint: enable=import-outside-toplevel
171176
mod = create(graph, rt_mod, dev)
172-
mod.set_input(input_name, input_data)
177+
for index, (input_name, _) in enumerate(input_info.items()):
178+
mod.set_input(input_name, inputs[index])
173179
graph_nodes = [n["name"] for n in json.loads(graph)["nodes"]]
174180
graph_time = mod.run_individual(number=10, repeat=1, min_repeat_ms=5000)
175181
print("|graph_nodes| = ", len(graph_nodes))
@@ -182,7 +188,7 @@ def f_per_layer(rt_mod, dev, input_data):
182188
rpc_config=ARGS.rpc_config,
183189
lib=rt_mod,
184190
dev_type=ARGS.target.kind.name,
185-
args=[input_data],
191+
args=inputs,
186192
continuation=f_per_layer,
187193
)
188194

0 commit comments

Comments
 (0)