Skip to content

Commit c62cc34

Browse files
Yuanjing Shipfk-beta
authored andcommitted
[Meta Schedule] Refactor meta schedule testing utils (apache#10648)
This PR moves some utility testing classes into `meta_schedule/testing/utils` and updated the following tests involved: - test_meta_schedule_integration.py - test_meta_schedule_measure_callback.py - test_meta_schedule_search_strategy.py - test_meta_schedule_task_scheduler.py
1 parent 9d527d9 commit c62cc34

File tree

6 files changed

+119
-228
lines changed

6 files changed

+119
-228
lines changed

python/tvm/meta_schedule/testing/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""Testing utilities in meta schedule"""
18+
from .utils import DummyDatabase, DummyBuilder, DummyRunner, DummyRunnerFuture, DummyMutator
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Testing utilitiy functions in meta schedule"""
18+
from typing import List, Optional
19+
import random
20+
21+
import tvm
22+
23+
from tvm.meta_schedule import TuneContext # pylint: disable=unused-import
24+
from tvm.meta_schedule.utils import derived_object
25+
from tvm.meta_schedule.mutator.mutator import PyMutator
26+
from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord
27+
from tvm.meta_schedule.builder import PyBuilder, BuilderInput, BuilderResult
28+
from tvm.meta_schedule.runner import (
29+
RunnerInput,
30+
RunnerResult,
31+
RunnerFuture,
32+
PyRunnerFuture,
33+
PyRunner,
34+
)
35+
from tvm.ir import IRModule
36+
from tvm.tir.schedule import Trace
37+
38+
39+
@derived_object
40+
class DummyDatabase(PyDatabase):
41+
"""
42+
An in-memory database based on python list for testing.
43+
"""
44+
45+
def __init__(self):
46+
super().__init__()
47+
self.records = []
48+
self.workload_reg = []
49+
50+
def has_workload(self, mod: IRModule) -> bool:
51+
for workload in self.workload_reg:
52+
if tvm.ir.structural_equal(workload.mod, mod):
53+
return True
54+
return False
55+
56+
def commit_tuning_record(self, record: TuningRecord) -> None:
57+
self.records.append(record)
58+
59+
def commit_workload(self, mod: IRModule) -> Workload:
60+
for workload in self.workload_reg:
61+
if tvm.ir.structural_equal(workload.mod, mod):
62+
return workload
63+
workload = Workload(mod)
64+
self.workload_reg.append(workload)
65+
return workload
66+
67+
def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
68+
return list(
69+
filter(
70+
lambda x: x.workload == workload,
71+
sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
72+
)
73+
)[: int(top_k)]
74+
75+
def __len__(self) -> int:
76+
return len(self.records)
77+
78+
def print_results(self) -> None:
79+
print("\n".join([str(r) for r in self.records]))
80+
81+
82+
@derived_object
83+
class DummyRunnerFuture(PyRunnerFuture):
84+
def done(self) -> bool:
85+
return True
86+
87+
def result(self) -> RunnerResult:
88+
run_secs = [random.uniform(5, 30) for _ in range(random.randint(1, 10))]
89+
return RunnerResult(run_secs, None)
90+
91+
92+
@derived_object
93+
class DummyBuilder(PyBuilder):
94+
def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]:
95+
return [BuilderResult("test_path", None) for _ in build_inputs]
96+
97+
98+
@derived_object
99+
class DummyRunner(PyRunner):
100+
def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
101+
return [DummyRunnerFuture() for _ in runner_inputs] # type: ignore
102+
103+
104+
@derived_object
105+
class DummyMutator(PyMutator):
106+
"""Dummy Mutator for testing"""
107+
108+
def initialize_with_tune_context(self, context: "TuneContext") -> None:
109+
pass
110+
111+
def apply(self, trace: Trace, _) -> Optional[Trace]:
112+
return Trace(trace.insts, {})

tests/python/unittest/test_meta_schedule_integration.py

Lines changed: 1 addition & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from tvm.script import tir as T
3333
from tvm.target import Target
3434
from tvm.tir import Schedule
35+
from tvm.meta_schedule.testing import DummyDatabase
3536
from tvm.meta_schedule.testing.tlcbench import load_quantized_bert_base
3637
from tvm.meta_schedule.tune import extract_task_from_relay, Parse
3738

@@ -106,44 +107,6 @@ def test_meta_schedule_integration_extract_from_resnet():
106107

107108
@requires_torch
108109
def test_meta_schedule_integration_apply_history_best():
109-
@derived_object
110-
class DummyDatabase(PyDatabase):
111-
def __init__(self):
112-
super().__init__()
113-
self.records = []
114-
self.workload_reg = []
115-
116-
def has_workload(self, mod: IRModule) -> Workload:
117-
for workload in self.workload_reg:
118-
if tvm.ir.structural_equal(workload.mod, mod):
119-
return True
120-
return False
121-
122-
def commit_tuning_record(self, record: TuningRecord) -> None:
123-
self.records.append(record)
124-
125-
def commit_workload(self, mod: IRModule) -> Workload:
126-
for workload in self.workload_reg:
127-
if tvm.ir.structural_equal(workload.mod, mod):
128-
return workload
129-
workload = Workload(mod)
130-
self.workload_reg.append(workload)
131-
return workload
132-
133-
def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
134-
return list(
135-
filter(
136-
lambda x: x.workload == workload,
137-
sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
138-
)
139-
)[: int(top_k)]
140-
141-
def __len__(self) -> int:
142-
return len(self.records)
143-
144-
def print_results(self) -> None:
145-
print("\n".join([str(r) for r in self.records]))
146-
147110
mod, _, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
148111
database = DummyDatabase()
149112
env = ApplyHistoryBest(database)

tests/python/unittest/test_meta_schedule_measure_callback.py

Lines changed: 3 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,9 @@
2424
from tvm.ir import IRModule, assert_structural_equal
2525
from tvm.meta_schedule.builder import BuilderResult
2626
from tvm.meta_schedule.measure_callback import PyMeasureCallback
27-
from tvm.meta_schedule.builder import PyBuilder, BuilderInput, BuilderResult
28-
from tvm.meta_schedule.runner import (
29-
RunnerInput,
30-
RunnerResult,
31-
RunnerFuture,
32-
PyRunnerFuture,
33-
PyRunner,
34-
)
35-
from tvm.meta_schedule.database import PyDatabase, Workload, TuningRecord
27+
from tvm.meta_schedule.builder import BuilderResult
28+
from tvm.meta_schedule.runner import RunnerResult
29+
from tvm.meta_schedule.testing import DummyDatabase, DummyRunner, DummyBuilder
3630
from tvm.meta_schedule.search_strategy import MeasureCandidate
3731
from tvm.meta_schedule.task_scheduler import RoundRobin, TaskScheduler
3832
from tvm.meta_schedule.utils import derived_object
@@ -61,66 +55,6 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None:
6155
# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument
6256

6357

64-
@derived_object
65-
class DummyRunnerFuture(PyRunnerFuture):
66-
def done(self) -> bool:
67-
return True
68-
69-
def result(self) -> RunnerResult:
70-
return RunnerResult([random.uniform(5, 30) for _ in range(random.randint(1, 10))], None)
71-
72-
73-
@derived_object
74-
class DummyBuilder(PyBuilder):
75-
def build(self, build_inputs: List[BuilderInput]) -> List[BuilderResult]:
76-
return [BuilderResult("test_path", None) for _ in build_inputs]
77-
78-
79-
@derived_object
80-
class DummyRunner(PyRunner):
81-
def run(self, runner_inputs: List[RunnerInput]) -> List[RunnerFuture]:
82-
return [DummyRunnerFuture() for _ in runner_inputs]
83-
84-
85-
@derived_object
86-
class DummyDatabase(PyDatabase):
87-
def __init__(self):
88-
super().__init__()
89-
self.records = []
90-
self.workload_reg = []
91-
92-
def has_workload(self, mod: IRModule) -> Workload:
93-
for workload in self.workload_reg:
94-
if tvm.ir.structural_equal(workload.mod, mod):
95-
return True
96-
return False
97-
98-
def commit_tuning_record(self, record: TuningRecord) -> None:
99-
self.records.append(record)
100-
101-
def commit_workload(self, mod: IRModule) -> Workload:
102-
for workload in self.workload_reg:
103-
if tvm.ir.structural_equal(workload.mod, mod):
104-
return workload
105-
workload = Workload(mod)
106-
self.workload_reg.append(workload)
107-
return workload
108-
109-
def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
110-
return list(
111-
filter(
112-
lambda x: x.workload == workload,
113-
sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
114-
)
115-
)[: int(top_k)]
116-
117-
def __len__(self) -> int:
118-
return len(self.records)
119-
120-
def print_results(self) -> None:
121-
print("\n".join([str(r) for r in self.records]))
122-
123-
12458
def test_meta_schedule_measure_callback():
12559
@derived_object
12660
class FancyMeasureCallback(PyMeasureCallback):

tests/python/unittest/test_meta_schedule_search_strategy.py

Lines changed: 1 addition & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@
2626
from tvm.meta_schedule import TuneContext
2727
from tvm.meta_schedule.builder import LocalBuilder
2828
from tvm.meta_schedule.cost_model import RandomModel
29-
from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
30-
from tvm.meta_schedule.mutator.mutator import PyMutator
3129
from tvm.meta_schedule.runner import LocalRunner, RunnerResult
3230
from tvm.meta_schedule.search_strategy import (
3331
EvolutionarySearch,
@@ -38,6 +36,7 @@
3836
from tvm.meta_schedule.space_generator import ScheduleFn
3937
from tvm.meta_schedule.task_scheduler import RoundRobin
4038
from tvm.meta_schedule.utils import derived_object
39+
from tvm.meta_schedule.testing import DummyDatabase, DummyMutator
4140
from tvm.script import tir as T
4241
from tvm.tir.schedule import Schedule, Trace
4342

@@ -117,56 +116,6 @@ def test_meta_schedule_replay_func(TestClass: SearchStrategy): # pylint: disabl
117116

118117

119118
def test_meta_schedule_evolutionary_search(): # pylint: disable = invalid-name]
120-
@derived_object
121-
class DummyMutator(PyMutator):
122-
"""Dummy Mutator for testing"""
123-
124-
def initialize_with_tune_context(self, context: "TuneContext") -> None:
125-
pass
126-
127-
def apply(self, trace: Trace, _) -> Optional[Trace]:
128-
return Trace(trace.insts, {})
129-
130-
@derived_object
131-
class DummyDatabase(PyDatabase):
132-
"""Dummy Database for testing"""
133-
134-
def __init__(self):
135-
super().__init__()
136-
self.records = []
137-
self.workload_reg = []
138-
139-
def has_workload(self, mod: IRModule) -> bool:
140-
for workload in self.workload_reg:
141-
if tvm.ir.structural_equal(workload.mod, mod):
142-
return True
143-
return False
144-
145-
def commit_tuning_record(self, record: TuningRecord) -> None:
146-
self.records.append(record)
147-
148-
def commit_workload(self, mod: IRModule) -> Workload:
149-
for workload in self.workload_reg:
150-
if tvm.ir.structural_equal(workload.mod, mod):
151-
return workload
152-
workload = Workload(mod)
153-
self.workload_reg.append(workload)
154-
return workload
155-
156-
def get_top_k(self, workload: Workload, top_k: int) -> List[TuningRecord]:
157-
return list(
158-
filter(
159-
lambda x: x.workload == workload,
160-
sorted(self.records, key=lambda x: sum(x.run_secs) / len(x.run_secs)),
161-
)
162-
)[: int(top_k)]
163-
164-
def __len__(self) -> int:
165-
return len(self.records)
166-
167-
def print_results(self) -> None:
168-
print("\n".join([str(r) for r in self.records]))
169-
170119
num_trials_per_iter = 10
171120
num_trials_total = 100
172121

0 commit comments

Comments
 (0)