Skip to content

Commit b9f47dd

Browse files
committed
[MetaSchedule][Minor] Organize Testing Scripts
1 parent d0650ba commit b9f47dd

File tree

8 files changed

+175
-35
lines changed

8 files changed

+175
-35
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=unused-import, redefined-builtin
18+
"""Testing utilities in auto scheduler."""
19+
20+
# NOTE: Do not import any module here by default

python/tvm/meta_schedule/testing/tune_onnx_auto_scheduler.py renamed to python/tvm/auto_scheduler/testing/tune_onnx.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,11 @@
2222
import numpy as np # type: ignore
2323
import onnx # type: ignore
2424
import tvm
25-
from tvm.relay.frontend import from_onnx
2625
from tvm import auto_scheduler
2726
from tvm import meta_schedule as ms
2827
from tvm import relay
2928
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
29+
from tvm.relay.frontend import from_onnx
3030

3131

3232
def _parse_args():
@@ -82,6 +82,26 @@ def _parse_args():
8282
type=str,
8383
required=True,
8484
)
85+
args.add_argument(
86+
"--number",
87+
type=int,
88+
default=3,
89+
)
90+
args.add_argument(
91+
"--repeat",
92+
type=int,
93+
default=1,
94+
)
95+
args.add_argument(
96+
"--min-repeat-ms",
97+
type=int,
98+
default=100,
99+
)
100+
args.add_argument(
101+
"--cpu-flush",
102+
type=bool,
103+
required=True,
104+
)
85105
parsed = args.parse_args()
86106
parsed.target = tvm.target.Target(parsed.target)
87107
parsed.input_shape = json.loads(parsed.input_shape)
@@ -105,10 +125,10 @@ def main():
105125
host=ARGS.rpc_host,
106126
port=ARGS.rpc_port,
107127
n_parallel=ARGS.rpc_workers,
108-
number=3,
109-
repeat=1,
110-
min_repeat_ms=100, # TODO
111-
enable_cpu_cache_flush=False, # TODO
128+
number=ARGS.number,
129+
repeat=ARGS.repeat,
130+
min_repeat_ms=ARGS.min_repeat_ms,
131+
enable_cpu_cache_flush=ARGS.cpu_flush,
112132
)
113133

114134
if ARGS.target.kind.name == "llvm":

python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py renamed to python/tvm/auto_scheduler/testing/tune_relay.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,26 @@ def _parse_args():
8080
type=str,
8181
default=None,
8282
)
83+
args.add_argument(
84+
"--number",
85+
type=int,
86+
default=3,
87+
)
88+
args.add_argument(
89+
"--repeat",
90+
type=int,
91+
default=1,
92+
)
93+
args.add_argument(
94+
"--min-repeat-ms",
95+
type=int,
96+
default=100,
97+
)
98+
args.add_argument(
99+
"--cpu-flush",
100+
type=bool,
101+
required=True,
102+
)
83103
parsed = args.parse_args()
84104
parsed.target = tvm.target.Target(parsed.target)
85105
parsed.input_shape = json.loads(parsed.input_shape)
@@ -103,10 +123,10 @@ def main():
103123
host=ARGS.rpc_host,
104124
port=ARGS.rpc_port,
105125
n_parallel=ARGS.rpc_workers,
106-
number=3,
107-
repeat=1,
108-
min_repeat_ms=100, # TODO
109-
enable_cpu_cache_flush=False, # TODO
126+
number=ARGS.number,
127+
repeat=ARGS.repeat,
128+
min_repeat_ms=ARGS.min_repeat_ms,
129+
enable_cpu_cache_flush=ARGS.cpu_flush,
110130
)
111131

112132
if ARGS.target.kind.name == "llvm":

python/tvm/meta_schedule/testing/tune_te_auto_scheduler.py renamed to python/tvm/auto_scheduler/testing/tune_te.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# software distributed under the License is distributed on an
1313
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
1414
# KIND, either express or implied. See the License for the
15-
# specific language governing permissions and limitatios
15+
# specific language governing permissions and limitations
1616
# under the License.
1717
# pylint: disable=missing-docstring
1818
import argparse
@@ -61,10 +61,30 @@ def _parse_args():
6161
required=True,
6262
)
6363
args.add_argument(
64-
"--log-dir",
64+
"--work-dir",
6565
type=str,
6666
required=True,
6767
)
68+
args.add_argument(
69+
"--number",
70+
type=int,
71+
default=3,
72+
)
73+
args.add_argument(
74+
"--repeat",
75+
type=int,
76+
default=1,
77+
)
78+
args.add_argument(
79+
"--min-repeat-ms",
80+
type=int,
81+
default=100,
82+
)
83+
args.add_argument(
84+
"--cpu-flush",
85+
type=bool,
86+
required=True,
87+
)
6888
parsed = args.parse_args()
6989
parsed.target = tvm.target.Target(parsed.target)
7090
return parsed
@@ -74,7 +94,7 @@ def _parse_args():
7494

7595

7696
def main():
77-
log_file = os.path.join(ARGS.log_dir, f"{ARGS.workload}.json")
97+
log_file = os.path.join(ARGS.work_dir, f"{ARGS.workload}.json")
7898
workload_func, params = CONFIGS[ARGS.workload]
7999
params = params[0] # type: ignore
80100
workload_func = auto_scheduler.register_workload(workload_func)
@@ -110,10 +130,10 @@ def main():
110130
host=ARGS.rpc_host,
111131
port=ARGS.rpc_port,
112132
n_parallel=ARGS.rpc_workers,
113-
number=3,
114-
repeat=1,
115-
min_repeat_ms=100,
116-
enable_cpu_cache_flush=False,
133+
number=ARGS.number,
134+
repeat=ARGS.repeat,
135+
min_repeat_ms=ARGS.min_repeat_ms,
136+
enable_cpu_cache_flush=ARGS.cpu_flush,
117137
)
118138

119139
# Inspect the computational graph

python/tvm/meta_schedule/testing/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,5 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""Testing utilities in meta schedule"""
18+
19+
# NOTE: Do not import any module here by default

python/tvm/meta_schedule/testing/tune_onnx_meta_schedule.py renamed to python/tvm/meta_schedule/testing/tune_onnx.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818
import argparse
1919
import json
2020
import logging
21+
2122
import numpy as np # type: ignore
2223
import onnx # type: ignore
2324
import tvm
24-
from tvm.relay.frontend import from_onnx
2525
from tvm import meta_schedule as ms
2626
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
27+
from tvm.relay.frontend import from_onnx
2728

2829

2930
def _parse_args():
@@ -79,6 +80,26 @@ def _parse_args():
7980
type=str,
8081
required=True,
8182
)
83+
args.add_argument(
84+
"--number",
85+
type=int,
86+
default=3,
87+
)
88+
args.add_argument(
89+
"--repeat",
90+
type=int,
91+
default=1,
92+
)
93+
args.add_argument(
94+
"--min-repeat-ms",
95+
type=int,
96+
default=100,
97+
)
98+
args.add_argument(
99+
"--cpu-flush",
100+
type=bool,
101+
required=True,
102+
)
82103
parsed = args.parse_args()
83104
parsed.target = tvm.target.Target(parsed.target)
84105
parsed.input_shape = json.loads(parsed.input_shape)
@@ -108,16 +129,15 @@ def main():
108129
print(f" input_dtype: {item['dtype']}")
109130
shape_dict[item["name"]] = item["shape"]
110131
mod, params = from_onnx(onnx_model, shape_dict, freeze_params=True)
111-
alloc_repeat = 1
112132
runner = ms.runner.RPCRunner(
113133
rpc_config=ARGS.rpc_config,
114134
evaluator_config=ms.runner.EvaluatorConfig(
115-
number=3,
116-
repeat=1,
117-
min_repeat_ms=100,
118-
enable_cpu_cache_flush=False,
135+
number=ARGS.number,
136+
repeat=ARGS.repeat,
137+
min_repeat_ms=ARGS.min_repeat_ms,
138+
enable_cpu_cache_flush=ARGS.cpu_flush,
119139
),
120-
alloc_repeat=alloc_repeat,
140+
alloc_repeat=1,
121141
max_workers=ARGS.rpc_workers,
122142
)
123143
lib = ms.tune_relay(

python/tvm/meta_schedule/testing/tune_relay_meta_schedule.py renamed to python/tvm/meta_schedule/testing/tune_relay.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,26 @@ def _parse_args():
7878
type=str,
7979
default=None,
8080
)
81+
args.add_argument(
82+
"--number",
83+
type=int,
84+
default=3,
85+
)
86+
args.add_argument(
87+
"--repeat",
88+
type=int,
89+
default=1,
90+
)
91+
args.add_argument(
92+
"--min-repeat-ms",
93+
type=int,
94+
default=100,
95+
)
96+
args.add_argument(
97+
"--cpu-flush",
98+
type=bool,
99+
required=True,
100+
)
81101
parsed = args.parse_args()
82102
parsed.target = tvm.target.Target(parsed.target)
83103
parsed.input_shape = json.loads(parsed.input_shape)
@@ -110,16 +130,15 @@ def main():
110130
print(f" input_name: {input_name}")
111131
print(f" input_shape: {input_shape}")
112132
print(f" input_dtype: {input_dtype}")
113-
alloc_repeat = 1
114133
runner = ms.runner.RPCRunner(
115134
rpc_config=ARGS.rpc_config,
116135
evaluator_config=ms.runner.EvaluatorConfig(
117-
number=3,
118-
repeat=1,
119-
min_repeat_ms=100,
120-
enable_cpu_cache_flush=False,
136+
number=ARGS.number,
137+
repeat=ARGS.repeat,
138+
min_repeat_ms=ARGS.min_repeat_ms,
139+
enable_cpu_cache_flush=ARGS.cpu_flush,
121140
),
122-
alloc_repeat=alloc_repeat,
141+
alloc_repeat=1,
123142
max_workers=ARGS.rpc_workers,
124143
)
125144
with ms.Profiler() as profiler:

python/tvm/meta_schedule/testing/tune_te_meta_schedule.py renamed to python/tvm/meta_schedule/testing/tune_te.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,26 @@ def _parse_args():
6868
type=str,
6969
required=True,
7070
)
71+
args.add_argument(
72+
"--number",
73+
type=int,
74+
default=3,
75+
)
76+
args.add_argument(
77+
"--repeat",
78+
type=int,
79+
default=1,
80+
)
81+
args.add_argument(
82+
"--min-repeat-ms",
83+
type=int,
84+
default=100,
85+
)
86+
args.add_argument(
87+
"--cpu-flush",
88+
type=bool,
89+
required=True,
90+
)
7191
parsed = args.parse_args()
7292
parsed.target = tvm.target.Target(parsed.target)
7393
parsed.rpc_config = ms.runner.RPCConfig(
@@ -87,16 +107,15 @@ def _parse_args():
87107

88108

89109
def main():
90-
alloc_repeat = 1
91110
runner = ms.runner.RPCRunner(
92111
rpc_config=ARGS.rpc_config,
93112
evaluator_config=ms.runner.EvaluatorConfig(
94-
number=3,
95-
repeat=1,
96-
min_repeat_ms=100,
97-
enable_cpu_cache_flush=False,
113+
number=ARGS.number,
114+
repeat=ARGS.repeat,
115+
min_repeat_ms=ARGS.min_repeat_ms,
116+
enable_cpu_cache_flush=ARGS.cpu_flush,
98117
),
99-
alloc_repeat=alloc_repeat,
118+
alloc_repeat=1,
100119
max_workers=ARGS.rpc_workers,
101120
)
102121
sch: Optional[tir.Schedule] = ms.tune_tir(

0 commit comments

Comments
 (0)