Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions python/tvm/meta_schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,19 @@
from . import database
from . import builder
from . import runner
from . import mutator
from . import postproc
from . import schedule_rule
from . import space_generator
from . import search_strategy
from . import schedule_rule
from . import integration
from . import feature_extractor
from . import cost_model
from .search_strategy import (
EvolutionarySearchConfig,
MeasureCandidate,
ReplayFuncConfig,
ReplayTraceConfig,
)
from .tune import tune_te, tune_tir, tune_relay
from .tune_context import TuneContext
from .search_strategy import MeasureCandidate
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def __init__(self, database) -> None:
self.__init_handle_by_constructor__(_ffi_api.ApplyHistoryBest, database) # type: ignore # pylint: disable=no-member


def extract_task(
def extract_task_from_relay(
mod: Union[IRModule, RelayFunc],
target: Target,
params: Optional[Dict[str, NDArray]] = None,
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""Testing utilities in meta schedule"""
from .local_rpc import LocalRPC
from .relay_workload import get_network
from .byoc_trt import relay_build_with_tensorrt
from .local_rpc import LocalRPC
from .relay_workload import MODEL_TYPE, MODEL_TYPES, get_network, get_torch_model
80 changes: 80 additions & 0 deletions python/tvm/meta_schedule/testing/relay_workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,93 @@
# specific language governing permissions and limitations
# under the License.
"""Workloads in Relay IR"""
from enum import Enum
from typing import Dict, Tuple

import tvm.relay.testing # pylint: disable=unused-import
from tvm import relay
from tvm.ir import IRModule
from tvm.runtime import NDArray

# Model types supported in Torchvision
class MODEL_TYPE(Enum): # pylint: disable=invalid-name
IMAGE_CLASSIFICATION = (1,)
VIDEO_CLASSIFICATION = (2,)
SEGMENTATION = (3,)
OBJECT_DETECTION = (4,)
TEXT_CLASSIFICATION = (5,)


# Specify the type of each model
MODEL_TYPES = {
"resnet18": MODEL_TYPE.IMAGE_CLASSIFICATION,
"mobilenet_v2": MODEL_TYPE.IMAGE_CLASSIFICATION,
"bert_base": MODEL_TYPE.TEXT_CLASSIFICATION,
}


def get_torch_model(
model_name: str,
input_shape: Tuple[int, ...],
output_shape: Tuple[int, int], # pylint: disable=unused-argument
dtype: str = "float32",
) -> Tuple[IRModule, Dict[str, NDArray]]:
"""Load model from torch model zoo
Parameters
----------
model_name : str
The name of the model to load
input_shape: Tuple[int, ...]
Tuple for input shape
output_shape: Tuple[int, int]
Tuple for output shape
dtype: str
Tensor data type
"""

assert dtype == "float32"

import torch # type: ignore # pylint: disable=import-error,import-outside-toplevel
from torchvision import models # type: ignore # pylint: disable=import-error,import-outside-toplevel
import transformers # type: ignore # pylint: disable=import-error,import-outside-toplevel
import os # type: ignore # pylint: disable=import-error,import-outside-toplevel

def do_trace(model, inp):
model.eval()
model_trace = torch.jit.trace(model, inp)
model_trace.eval()
return model_trace

# Load model from torchvision
if MODEL_TYPES[model_name] == MODEL_TYPE.TEXT_CLASSIFICATION:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
model = transformers.BertModel(
transformers.BertConfig(
num_hidden_layers=12,
hidden_size=768,
intermediate_size=3072,
num_attention_heads=12,
return_dict=False,
)
)
model.eval()
input_data = torch.randint(10000, input_shape)
shape_list = [("input_ids", input_shape)]
scripted_model = torch.jit.trace(model, [input_data], strict=False)
elif MODEL_TYPES[model_name] == MODEL_TYPE.IMAGE_CLASSIFICATION:
model = getattr(models, model_name)()
# Setup input
input_data = torch.randn(input_shape).type(torch.float32)
shape_list = [("input0", input_shape)]
# Get trace. Depending on the model type, wrapper may be necessary.
scripted_model = do_trace(model, input_data)
else:
raise ValueError("Unsupported model in Torch model zoo.")

# Convert torch model to relay module
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
return mod, params


def get_network(
name: str,
Expand Down
Loading