Skip to content

Commit eaf2c64

Browse files
Kathryn-cat authored and junrushao committed
[MetaSchedule] Generate MetaSchedule Dataset
To build a dataset for improving the MetaSchedule cost model, this commit adds scripts for importing models into TVM, extracting tuning tasks, and sampling measure candidates. It also exposes several C++ methods to the Python side to support this process.
1 parent 6fca5c6 commit eaf2c64

File tree

5 files changed

+366
-0
lines changed

5 files changed

+366
-0
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""
2+
Import models to TVM.
3+
"""
4+
5+
import argparse
6+
import os
7+
from typing import List, Tuple
8+
from tqdm import tqdm # type: ignore
9+
10+
from tvm.meta_schedule.testing.relay_workload import get_network
11+
12+
13+
# pylint: disable=too-many-branches
14+
def _build_dataset() -> List[Tuple[str, List[int]]]:
15+
network_keys = []
16+
for name in [
17+
"resnet_18",
18+
"resnet_50",
19+
"mobilenet_v2",
20+
"mobilenet_v3",
21+
"wide_resnet_50",
22+
"resnext_50",
23+
"densenet_121",
24+
"vgg_16",
25+
]:
26+
for batch_size in [1, 4, 8]:
27+
for image_size in [224, 240, 256]:
28+
network_keys.append((name, [batch_size, 3, image_size, image_size]))
29+
# inception-v3
30+
for name in ["inception_v3"]:
31+
for batch_size in [1, 2, 4]:
32+
for image_size in [299]:
33+
network_keys.append((name, [batch_size, 3, image_size, image_size]))
34+
# resnet3d
35+
for name in ["resnet3d_18"]:
36+
for batch_size in [1, 2, 4]:
37+
for image_size in [112, 128, 144]:
38+
network_keys.append((name, [batch_size, 3, image_size, image_size, 16]))
39+
# bert
40+
for name in ["bert_tiny", "bert_base", "bert_medium", "bert_large"]:
41+
for batch_size in [1, 2, 4]:
42+
for seq_length in [64, 128, 256]:
43+
network_keys.append((name, [batch_size, seq_length]))
44+
# dcgan
45+
for name in ["dcgan"]:
46+
for batch_size in [1, 4, 8]:
47+
for image_size in [64]:
48+
network_keys.append((name, [batch_size, 3, image_size, image_size]))
49+
50+
return network_keys
51+
52+
53+
def cache_models(network_keys, cache_dir):
    """Download each listed model and cache it in the given directory.

    Parameters
    ----------
    network_keys : iterable of (name, input_shape)
        The networks to fetch; each pair is forwarded to ``get_network``.
    cache_dir : str
        Directory handed to ``get_network`` as its ``cache_dir``.
    """
    # tqdm only wraps the iterable to show progress; iteration order is unchanged.
    for key in tqdm(network_keys):
        model_name, shape = key
        get_network(name=model_name, input_shape=shape, cache_dir=cache_dir)
58+
59+
60+
if __name__ == "__main__":
    # Script entry point: parse the cache directory, make sure it exists, then
    # fetch every network listed by _build_dataset() into it.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_cache_dir",
        type=str,
        # Without a path os.makedirs(None) raises an unrelated TypeError, so let
        # argparse report the missing argument instead.
        required=True,
        help="Please provide the full path to the model cache dir.",
    )
    args = parser.parse_args()
    model_cache_dir = args.model_cache_dir

    try:
        os.makedirs(model_cache_dir, exist_ok=True)
    except OSError:
        print(f"Directory {model_cache_dir} cannot be created successfully.")
        # Nothing can be cached without the directory; stop instead of
        # continuing into the downloads.
        raise
    keys = _build_dataset()
    cache_models(keys, model_cache_dir)
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""
2+
Extract tuning tasks using MetaSchedule, and filter out spatial tasks.
3+
"""
4+
5+
import argparse
6+
import glob
7+
import json
8+
import os
9+
from tqdm import tqdm # type: ignore
10+
11+
import tvm
12+
from tvm import meta_schedule as ms
13+
from tvm.ir import save_json
14+
from tvm.meta_schedule.testing.relay_workload import _load_cache
15+
from tvm.runtime import load_param_dict
16+
17+
18+
def _parse_args():
19+
parser = argparse.ArgumentParser()
20+
parser.add_argument(
21+
"--model_cache_dir", type=str, help="Please provide the full path to the model cache dir."
22+
)
23+
parser.add_argument(
24+
"--task_cache_dir", type=str, help="Please provide the full path to save extracted tasks."
25+
)
26+
parser.add_argument(
27+
"--target", type=str, default="cuda", help="Please specify the target hardware for tuning."
28+
)
29+
return parser.parse_args()
30+
31+
32+
# pylint: disable=too-many-locals
33+
def extract_and_save_tasks(cache_file, is_spatial, args):
    """Extract tuning tasks and cache the nonspatial ones in the given directory.

    The output file is ``<task_cache_dir>/<model>_extracted_tasks.json`` and
    holds one JSON array ``[module_json, task_name]`` per line, with no
    trailing newline.

    Parameters
    ----------
    cache_file : str
        The filename of the cached model.
    is_spatial : PackedFunc
        The function for checking whether a task is spatial.
    args : argparse.Namespace
        The parsed arguments; ``model_cache_dir``, ``task_cache_dir`` and
        ``target`` are read.

    Returns
    -------
    None

    """

    mod, params_bytearray, _ = _load_cache(args.model_cache_dir, cache_file)
    params = load_param_dict(params_bytearray)
    try:
        extracted_tasks = ms.extract_task_from_relay(mod, target=args.target, params=params)
    except tvm._ffi.base.TVMError:  # pylint: disable=protected-access
        # Best effort: a model whose tasks cannot be extracted is skipped, but
        # say so instead of dropping it silently.
        print(f"Skipping {cache_file}: task extraction failed.")
        return
    task_cache_path = os.path.join(
        args.task_cache_dir, cache_file.split(".")[0] + "_extracted_tasks.json"
    )

    # Collect the lines first and join them afterwards, so the separator never
    # depends on whether the last extracted task happened to be spatial.
    json_lines = []
    for task in extracted_tasks:
        subgraph = task.dispatched[0]
        prim_func = subgraph[subgraph.get_global_vars()[0]]
        if is_spatial(prim_func):
            continue
        json_str = json.dumps([json.loads(save_json(subgraph)), task.task_name])
        assert "\n" not in json_str, "Failed to generate single line string."
        json_lines.append(json_str)

    with open(task_cache_path, "w", encoding="utf8") as file:
        file.write("\n".join(json_lines))
73+
74+
75+
if __name__ == "__main__":
    # Script entry point: for every cached model (*.json) in the model cache
    # directory, extract its tuning tasks and store the nonspatial ones.
    parsed_args = _parse_args()
    if not os.path.isdir(parsed_args.model_cache_dir):
        raise Exception("Please provide a correct model cache dir.")
    try:
        os.makedirs(parsed_args.task_cache_dir, exist_ok=True)
    except OSError:
        print(f"Directory {parsed_args.task_cache_dir} cannot be created successfully.")
        # The extracted tasks have nowhere to go; stop instead of failing later
        # on the first open().
        raise

    check_spatial_fn = tvm.get_global_func("tir.schedule.IsSpatialPrimFunc")
    cache_paths = glob.glob(os.path.join(parsed_args.model_cache_dir, "*.json"))
    for cache_path in tqdm(cache_paths):
        # basename() honours the platform's path separator, unlike split("/").
        filename = os.path.basename(cache_path)
        extract_and_save_tasks(filename, check_spatial_fn, parsed_args)
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
"""
2+
Sample measure candidates for each tuning task by evolutionary search.
3+
"""
4+
5+
import argparse
6+
import glob
7+
import json
8+
import os
9+
from typing import List
10+
from tqdm import tqdm # type: ignore
11+
12+
import tvm
13+
from tvm import meta_schedule as ms
14+
from tvm.ir import load_json
15+
from tvm.meta_schedule import TuneContext
16+
from tvm.meta_schedule.database import TuningRecord, Workload
17+
from tvm.meta_schedule.search_strategy import EvolutionarySearch
18+
from tvm.meta_schedule.space_generator import PostOrderApply
19+
from tvm.meta_schedule.testing.utils import DummyDatabase
20+
from tvm.meta_schedule.tune import DefaultCUDA, DefaultLLVM
21+
from tvm.target import Target
22+
23+
24+
def _parse_args():
25+
parser = argparse.ArgumentParser()
26+
parser.add_argument(
27+
"--task_cache_dir", type=str, help="Please provide the full path to the extracted tasks."
28+
)
29+
parser.add_argument(
30+
"--candidate_cache_dir",
31+
type=str,
32+
help="Please provide the full path to save the sampled candidates.",
33+
)
34+
parser.add_argument(
35+
"--target",
36+
type=str,
37+
default="nvidia/geforce-rtx-3070",
38+
help="Please specify the target hardware for tuning.\
39+
Note: for generating dataset, the hardware does not need to be present.",
40+
)
41+
parser.add_argument(
42+
"--init_population_size",
43+
type=int,
44+
default=256,
45+
help="The initial population size used in evolutionary search.",
46+
)
47+
parser.add_argument(
48+
"--num_samples_per_task",
49+
type=int,
50+
default=400,
51+
help="The number of samples to gather per tuning task.",
52+
)
53+
parser.add_argument(
54+
"--num_trials_per_iter",
55+
type=int,
56+
default=64,
57+
help="The number of trials per iteration in evolutionary search.",
58+
)
59+
parser.add_argument(
60+
"--max_trials_per_task",
61+
type=int,
62+
default=400,
63+
help="The maximum number of trials per task in evolutionary search.",
64+
)
65+
parser.add_argument(
66+
"--max_retry_per_task",
67+
type=int,
68+
default=10,
69+
help="The maximum number of retry attempts allowed.",
70+
)
71+
parser.add_argument(
72+
"--file_group",
73+
type=int,
74+
default=0,
75+
help="To enable running multiple scripts in parallel, files [idx * 10 : (idx + 1) * 10]\
76+
in the sorted file list from the given directory will be run.",
77+
)
78+
return parser.parse_args()
79+
80+
81+
# pylint: disable=too-many-locals
82+
def sample_candidates(task, task_name, model_name):
    """Sample schedule candidates for one task and write them to disk as tuning records.

    Reads the module-level names ``args``, ``sample_init_population`` and
    ``evolve_with_cost_model`` that the ``__main__`` section of this script
    defines. Output goes to ``<candidate_cache_dir>/<model_name>/<task_name>.json``,
    one JSON-encoded tuning record per line with no trailing newline.

    Parameters
    ----------
    task : IRModule
        The initial ir module used for generating the search space.
    task_name : str
        The name of the task.
    model_name : str
        The name of the model.

    Returns
    -------
    None

    """

    evo_search = EvolutionarySearch(
        num_trials_per_iter=args.num_trials_per_iter,
        max_trials_per_task=args.max_trials_per_task,
    )
    # Only the literal "llvm" target selects the LLVM defaults; anything else
    # uses the CUDA defaults.
    if args.target == "llvm":
        default_config = DefaultLLVM
    else:
        default_config = DefaultCUDA
    # pylint: disable=protected-access
    context = TuneContext(
        mod=task,
        target=Target(args.target),
        space_generator=PostOrderApply(),
        search_strategy=evo_search,
        sch_rules=default_config._sch_rules(),  # type: ignore
        postprocs=default_config._postproc(),  # type: ignore
        mutator_probs=default_config._mutator_probs(),  # type: ignore
        task_name=task_name,
    )
    context.initialize()
    design_spaces = context.space_generator.generate_design_space(context.mod)
    # type: ignore
    evo_search.pre_tuning(
        design_spaces, database=DummyDatabase(), cost_model=ms.cost_model.RandomModel()
    )

    collected: List[tvm.tir.schedule.schedule.Schedule] = []
    empty_rounds = 0
    iteration = 0
    population = sample_init_population(evo_search, args.init_population_size)
    # Stop once enough samples are gathered, or after too many consecutive
    # rounds that produced nothing.
    while not (
        len(collected) >= args.num_samples_per_task or empty_rounds >= args.max_retry_per_task
    ):
        population = evolve_with_cost_model(evo_search, population, len(population))
        collected.extend(population)
        if len(population) > 0:
            empty_rounds = 0
        else:
            # Nothing came back: restart from a fresh initial population.
            population = sample_init_population(evo_search, args.init_population_size)
            empty_rounds += 1
        print(f"iter: {iteration}, number of states sampled: {len(collected)}")
        iteration += 1
    collected = collected[: args.num_samples_per_task]

    workload = Workload(context.mod)
    file_path = os.path.join(args.candidate_cache_dir, model_name, task_name + ".json")
    with open(file_path, "w", encoding="utf8") as file:
        last_index = len(collected) - 1
        for index, schedule in enumerate(collected):
            record = TuningRecord(schedule.trace, workload)
            line = json.dumps(record.as_json())
            assert "\n" not in line, "Failed to generate single line string."
            # Newline-separated records; the final one is left unterminated.
            file.write(line if index == last_index else line + "\n")
147+
148+
149+
if __name__ == "__main__":
    # Script entry point: pick one group of ten extracted-task files and sample
    # candidates for every task in each of them. `args`,
    # `sample_init_population` and `evolve_with_cost_model` are deliberately
    # module-level because sample_candidates() reads them as globals.
    args = _parse_args()
    if not os.path.isdir(args.task_cache_dir):
        raise Exception("Please provide a correct task cache dir.")
    try:
        os.makedirs(args.candidate_cache_dir, exist_ok=True)
    except OSError:
        print(f"Directory {args.candidate_cache_dir} cannot be created successfully.")
        # The candidates have nowhere to go; stop instead of failing later.
        raise

    sample_init_population = tvm.get_global_func(
        "meta_schedule.SearchStrategyEvolutionarySearchSampleInitPopulation"
    )
    evolve_with_cost_model = tvm.get_global_func(
        "meta_schedule.SearchStrategyEvolutionarySearchEvolveWithCostModel"
    )
    # Sorting makes the grouping deterministic, so parallel runs with different
    # --file_group values work on disjoint slices.
    task_paths = sorted(glob.glob(os.path.join(args.task_cache_dir, "*.json")))[
        args.file_group * 10 : (args.file_group + 1) * 10
    ]
    print(f"Selected models: {task_paths}")
    for num, task_path in enumerate(task_paths):
        print(f"Processing model {num} ...")
        with open(task_path, "rb") as f:
            tasks = f.readlines()
        # Strip the "relay-" prefix and "_extracted_tasks.json" suffix from the
        # file name; basename() honours the platform's path separator.
        model_n = os.path.basename(task_path)[len("relay-") :][: -len("_extracted_tasks.json")]
        os.makedirs(os.path.join(args.candidate_cache_dir, model_n), exist_ok=True)
        for task_str in tqdm(tasks):
            if not task_str.strip():
                # Tolerate blank lines (e.g. a trailing newline) in the task file.
                continue
            task_mod, task_n = json.loads(task_str)
            # The module was stored as nested JSON; re-serialise it for load_json.
            task_mod = load_json(json.dumps(task_mod))
            sample_candidates(task_mod, task_n, model_n)

src/meta_schedule/search_strategy/evolutionary_search.cc

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -716,9 +716,36 @@ class EvolutionarySearch : public SearchStrategy {
716716
EvolutionarySearchNode);
717717
};
718718

719+
/*!
 * \brief Sample an initial population from the search state and return it as an Array.
 * \param self The evolutionary search strategy whose state is sampled.
 * \param num Forwarded unchanged to State::SampleInitPopulation.
 * \return The sampled schedules, in the order the state returned them.
 * \note Dereferences self->state_ without a check; presumably the strategy must
 *       already have been pre-tuned so that the state exists -- confirm.
 */
Array<Schedule> EvolutionarySearchSampleInitPopulation(EvolutionarySearch self, int num) {
  const std::vector<Schedule> sampled = self->state_->SampleInitPopulation(num);
  Array<Schedule> population(sampled.begin(), sampled.end());
  return population;
}
723+
724+
/*!
 * \brief Evolve a population with the cost model and keep only schedules whose
 *        module has not been recorded in the state's measured workloads.
 * \param self The evolutionary search strategy. Taking the typed reference (as
 *        EvolutionarySearchSampleInitPopulation does) lets the typed FFI
 *        conversion reject a wrong-typed argument instead of this function
 *        dereferencing a null node pointer.
 * \param population The current population to evolve.
 * \param num Forwarded unchanged to State::EvolveWithCostModel.
 * \return The evolved schedules whose module was not yet in measured_workloads_.
 *         Each returned module is added to measured_workloads_ as a side effect.
 * \note Dereferences self->state_ without a check; presumably the strategy must
 *       already have been pre-tuned so that the state exists -- confirm.
 */
Array<Schedule> EvolutionarySearchEvolveWithCostModel(EvolutionarySearch self,
                                                      Array<Schedule> population, int num) {
  Array<Schedule> result;
  std::vector<Schedule> population_vec(population.begin(), population.end());
  std::vector<Schedule> schs = self->state_->EvolveWithCostModel(population_vec, num);
  // Iterate by const reference: copying a Schedule would touch its reference
  // count on every iteration for no benefit.
  for (const Schedule& sch : schs) {
    IRModule mod = sch->mod();
    size_t shash = StructuralHash()(mod);
    if (!self->state_->measured_workloads_.Has(mod, shash)) {
      self->state_->measured_workloads_.Add(mod, shash);
      result.push_back(sch);
    }
  }
  return result;
}
741+
719742
TVM_REGISTER_NODE_TYPE(EvolutionarySearchNode);
720743
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearch")
721744
.set_body_typed(SearchStrategy::EvolutionarySearch);
745+
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearchSampleInitPopulation")
746+
.set_body_typed(EvolutionarySearchSampleInitPopulation);
747+
TVM_REGISTER_GLOBAL("meta_schedule.SearchStrategyEvolutionarySearchEvolveWithCostModel")
748+
.set_body_typed(EvolutionarySearchEvolveWithCostModel);
722749

723750
} // namespace meta_schedule
724751
} // namespace tvm

src/tir/schedule/analysis/analysis.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2234,6 +2234,7 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
22342234
return TensorizeInfo(ret);
22352235
}
22362236

2237+
TVM_REGISTER_GLOBAL("tir.schedule.IsSpatialPrimFunc").set_body_typed(IsSpatialPrimFunc);
22372238
TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping")
22382239
.set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func) {
22392240
return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func);

0 commit comments

Comments
 (0)