Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/tvm/auto_scheduler/testing/tune_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from tvm import relay
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.relay.frontend import from_onnx
from tvm.support import describe


def _parse_args():
Expand Down Expand Up @@ -152,6 +153,7 @@ def main():
else:
raise NotImplementedError(f"Unsupported target {ARGS.target}")

describe()
print(f"Workload: {ARGS.model_name}")
onnx_model = onnx.load(ARGS.onnx_path)
shape_dict = {}
Expand Down
5 changes: 4 additions & 1 deletion python/tvm/auto_scheduler/testing/tune_relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from tvm import relay
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.meta_schedule.testing.relay_workload import get_network
from tvm.support import describe


def _parse_args():
Expand Down Expand Up @@ -149,14 +150,16 @@ def main():
)
else:
raise NotImplementedError(f"Unsupported target {ARGS.target}")

describe()
print(f"Workload: {ARGS.workload}")
mod, params, (input_name, input_shape, input_dtype) = get_network(
ARGS.workload,
ARGS.input_shape,
cache_dir=ARGS.cache_dir,
)
input_info = {input_name: input_shape}
input_data = {}
print(f"Workload: {ARGS.workload}")
for input_name, input_shape in input_info.items():
print(f" input_name: {input_name}")
print(f" input_shape: {input_shape}")
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/auto_scheduler/testing/tune_te.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import tvm
from tvm import auto_scheduler
from tvm.meta_schedule.testing.te_workload import CONFIGS
from tvm.support import describe


def _parse_args():
Expand Down Expand Up @@ -94,6 +95,8 @@ def _parse_args():


def main():
describe()
print(f"Workload: {ARGS.workload}")
log_file = os.path.join(ARGS.work_dir, f"{ARGS.workload}.json")
workload_func, params = CONFIGS[ARGS.workload]
params = params[0] # type: ignore
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/meta_schedule/testing/tune_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from tvm import meta_schedule as ms
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.relay.frontend import from_onnx
from tvm.support import describe


def _parse_args():
Expand Down Expand Up @@ -120,6 +121,7 @@ def _parse_args():


def main():
describe()
print(f"Workload: {ARGS.model_name}")
onnx_model = onnx.load(ARGS.onnx_path)
shape_dict = {}
Expand Down
4 changes: 3 additions & 1 deletion python/tvm/meta_schedule/testing/tune_relay.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tvm import meta_schedule as ms
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.meta_schedule.testing.relay_workload import get_network
from tvm.support import describe


def _parse_args():
Expand Down Expand Up @@ -118,14 +119,15 @@ def _parse_args():


def main():
describe()
print(f"Workload: {ARGS.workload}")
mod, params, (input_name, input_shape, input_dtype) = get_network(
ARGS.workload,
ARGS.input_shape,
cache_dir=ARGS.cache_dir,
)
input_info = {input_name: input_shape}
input_data = {}
print(f"Workload: {ARGS.workload}")
for input_name, input_shape in input_info.items():
print(f" input_name: {input_name}")
print(f" input_shape: {input_shape}")
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/meta_schedule/testing/tune_te.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from tvm import meta_schedule as ms
from tvm import tir
from tvm.meta_schedule.testing.te_workload import create_te_workload
from tvm.support import describe


def _parse_args():
Expand Down Expand Up @@ -107,6 +108,8 @@ def _parse_args():


def main():
describe()
print(f"Workload: {ARGS.workload}")
runner = ms.runner.RPCRunner(
rpc_config=ARGS.rpc_config,
evaluator_config=ms.runner.EvaluatorConfig(
Expand Down