Skip to content

Commit 512ac20

Browse files
committed
Add tuning scripts for tir, te & relay.
Co-authored-by: Junru Shao <[email protected]> Co-authored-by: Bohan Hou <[email protected]> Co-authored-by: Ruihang Lai <[email protected]> Co-authored-by: Hongyi Jin <[email protected]> Co-authored-by: Wuwei Lin <[email protected]> Co-authored-by: Siyuan Feng <[email protected]> Minor fix. Nits. Add back tests.
1 parent 1935341 commit 512ac20

File tree

14 files changed

+1274
-25
lines changed

14 files changed

+1274
-25
lines changed

python/tvm/meta_schedule/__init__.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,19 @@
1919
from . import database
2020
from . import builder
2121
from . import runner
22+
from . import mutator
23+
from . import postproc
24+
from . import schedule_rule
2225
from . import space_generator
2326
from . import search_strategy
24-
from . import schedule_rule
2527
from . import integration
2628
from . import feature_extractor
29+
from . import cost_model
30+
from .search_strategy import (
31+
EvolutionarySearchConfig,
32+
MeasureCandidate,
33+
ReplayFuncConfig,
34+
ReplayTraceConfig,
35+
)
36+
from .tune import tune_te, tune_tir, tune_relay
2737
from .tune_context import TuneContext
28-
from .search_strategy import MeasureCandidate

python/tvm/meta_schedule/integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def __init__(self, database) -> None:
184184
self.__init_handle_by_constructor__(_ffi_api.ApplyHistoryBest, database) # type: ignore # pylint: disable=no-member
185185

186186

187-
def extract_task(
187+
def extract_task_from_relay(
188188
mod: Union[IRModule, RelayFunc],
189189
target: Target,
190190
params: Optional[Dict[str, NDArray]] = None,

python/tvm/meta_schedule/testing/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616
# under the License.
1717
"""Testing utilities in meta schedule"""
1818
from .local_rpc import LocalRPC
19-
from .relay_workload import get_network
19+
from .relay_workload import MODEL_TYPE, MODEL_TYPES, get_network, get_torch_model

python/tvm/meta_schedule/testing/relay_workload.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,93 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""Workloads in Relay IR"""
18+
from enum import Enum
1819
from typing import Dict, Tuple
1920

2021
import tvm.relay.testing # pylint: disable=unused-import
2122
from tvm import relay
2223
from tvm.ir import IRModule
2324
from tvm.runtime import NDArray
2425

26+
# Model types supported in Torchvision
27+
class MODEL_TYPE(Enum): # pylint: disable=invalid-name
28+
IMAGE_CLASSIFICATION = (1,)
29+
VIDEO_CLASSIFICATION = (2,)
30+
SEGMENTATION = (3,)
31+
OBJECT_DETECTION = (4,)
32+
TEXT_CLASSIFICATION = (5,)
33+
34+
35+
# Specify the type of each model
36+
MODEL_TYPES = {
37+
"resnet18": MODEL_TYPE.IMAGE_CLASSIFICATION,
38+
"mobilenet_v2": MODEL_TYPE.IMAGE_CLASSIFICATION,
39+
"bert_base": MODEL_TYPE.TEXT_CLASSIFICATION,
40+
}
41+
42+
43+
def get_torch_model(
44+
model_name: str,
45+
input_shape: Tuple[int, ...],
46+
output_shape: Tuple[int, int], # pylint: disable=unused-argument
47+
dtype: str = "float32",
48+
) -> Tuple[IRModule, Dict[str, NDArray]]:
49+
"""Load model from torch model zoo
50+
Parameters
51+
----------
52+
model_name : str
53+
The name of the model to load
54+
input_shape: Tuple[int, ...]
55+
Tuple for input shape
56+
output_shape: Tuple[int, int]
57+
Tuple for output shape
58+
dtype: str
59+
Tensor data type
60+
"""
61+
62+
assert dtype == "float32"
63+
64+
import torch # type: ignore # pylint: disable=import-error,import-outside-toplevel
65+
from torchvision import models # type: ignore # pylint: disable=import-error,import-outside-toplevel
66+
import transformers # type: ignore # pylint: disable=import-error,import-outside-toplevel
67+
import os # type: ignore # pylint: disable=import-error,import-outside-toplevel
68+
69+
def do_trace(model, inp):
70+
model.eval()
71+
model_trace = torch.jit.trace(model, inp)
72+
model_trace.eval()
73+
return model_trace
74+
75+
# Load model from torchvision
76+
if MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION:
77+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
78+
model = transformers.BertModel(
79+
transformers.BertConfig(
80+
num_hidden_layers=12,
81+
hidden_size=768,
82+
intermediate_size=3072,
83+
num_attention_heads=12,
84+
return_dict=False,
85+
)
86+
)
87+
model.eval()
88+
input_data = torch.randint(10000, input_shape)
89+
shape_list = [("input_ids", input_shape)]
90+
scripted_model = torch.jit.trace(model, [input_data], strict=False)
91+
elif MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION:
92+
model = getattr(models, model_name)()
93+
# Setup input
94+
input_data = torch.randn(input_shape).type(torch.float32)
95+
shape_list = [("input0", input_shape)]
96+
# Get trace. Depending on the model type, wrapper may be necessary.
97+
scripted_model = do_trace(model, input_data)
98+
else:
99+
raise ValueError("Unsupported model in Torch model zoo.")
100+
101+
# Convert torch model to relay module
102+
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
103+
return mod, params
104+
25105

26106
def get_network(
27107
name: str,

0 commit comments

Comments
 (0)