From acfa01603178bb199b913c20947927d7b76a8d64 Mon Sep 17 00:00:00 2001 From: Archermmt Date: Thu, 25 Jan 2024 06:27:40 +0800 Subject: [PATCH 1/2] reconstruct done --- .../contrib/msc/core/frontend/translate.py | 28 +- python/tvm/contrib/msc/core/ir/graph.py | 13 + python/tvm/contrib/msc/core/runtime/hook.py | 196 ++++++ python/tvm/contrib/msc/core/runtime/runner.py | 634 ++++++++++++------ .../msc/core/tools/distill/distiller.py | 15 +- .../contrib/msc/core/tools/prune/pruner.py | 82 ++- python/tvm/contrib/msc/core/tools/tool.py | 83 ++- .../contrib/msc/core/tools/track/method.py | 10 +- .../contrib/msc/core/tools/track/tracker.py | 12 +- .../tvm/contrib/msc/core/transform/pattern.py | 97 ++- .../contrib/msc/core/transform/transform.py | 4 +- python/tvm/contrib/msc/core/utils/__init__.py | 1 + .../tvm/contrib/msc/core/utils/arguments.py | 228 +++++++ python/tvm/contrib/msc/core/utils/dataset.py | 6 +- python/tvm/contrib/msc/core/utils/expr.py | 23 +- python/tvm/contrib/msc/core/utils/file.py | 55 ++ python/tvm/contrib/msc/core/utils/info.py | 240 ++----- python/tvm/contrib/msc/core/utils/log.py | 2 +- python/tvm/contrib/msc/core/utils/message.py | 2 +- python/tvm/contrib/msc/core/utils/register.py | 41 +- .../framework/tensorflow/runtime/runner.py | 91 ++- .../msc/framework/tensorrt/codegen/codegen.py | 48 +- .../msc/framework/tensorrt/runtime/runner.py | 47 +- .../framework/tensorrt/transform/pattern.py | 26 +- .../msc/framework/torch/runtime/runner.py | 132 +++- .../msc/framework/tvm/runtime/runner.py | 157 ++++- .../msc/framework/tvm/tools/track/tracker.py | 5 +- python/tvm/contrib/msc/pipeline/manager.py | 569 +++++++--------- src/contrib/msc/core/codegen/codegen_utils.h | 9 +- src/contrib/msc/core/codegen/py_codegen.h | 2 +- src/contrib/msc/core/ir/graph.cc | 38 ++ src/contrib/msc/core/ir/graph.h | 2 + src/contrib/msc/core/ir/graph_builder.cc | 348 ++++++---- src/contrib/msc/core/ir/graph_builder.h | 49 +- src/contrib/msc/core/transform/fuse_tuple.cc | 36 +- .../msc/core/transform/inline_params.cc | 192 ++++++ .../msc/core/transform/layout_utils.cc | 25 +- .../msc/core/transform/set_byoc_attrs.cc | 39 +- .../msc/core/transform/set_expr_name.cc | 105 +-- .../msc/framework/tensorflow/codegen_utils.h | 7 +- src/contrib/msc/framework/tensorrt/codegen.cc | 10 +- .../framework/tensorrt/transform_tensorrt.cc | 28 +- .../msc/framework/torch/codegen_utils.h | 7 +- tests/python/contrib/test_msc/test_manager.py | 18 +- tests/python/contrib/test_msc/test_runner.py | 22 +- tests/python/contrib/test_msc/test_tools.py | 18 +- .../python/contrib/test_msc/test_transform.py | 20 +- .../test_msc/test_translate_tensorrt.py | 10 +- 48 files changed, 2604 insertions(+), 1228 deletions(-) create mode 100644 python/tvm/contrib/msc/core/runtime/hook.py create mode 100644 python/tvm/contrib/msc/core/utils/arguments.py create mode 100644 src/contrib/msc/core/transform/inline_params.cc diff --git a/python/tvm/contrib/msc/core/frontend/translate.py b/python/tvm/contrib/msc/core/frontend/translate.py index 4a7710f382af..2eaae1335855 100644 --- a/python/tvm/contrib/msc/core/frontend/translate.py +++ b/python/tvm/contrib/msc/core/frontend/translate.py @@ -66,7 +66,12 @@ def _to_data(ref_t, data): return data weights = {t.name: _to_data(t, d) for t, d in t_weights.items() if graph.has_tensor(t.name)} - return weights + # sort the weights by graph weights + graph_weights = {} + for weight in graph.get_weights(): + assert weight.name in weights, "Missing weight " + str(weight) + graph_weights[weight.name] = weights[weight.name] + return graph_weights def from_relax( @@ -115,13 +120,10 @@ def from_relax( patterns = get_patterns_with_prefix("msc.") passes = [ msc_transform.SetExprName(), + msc_transform.SetExprLayout(trans_config.get("allow_layout_missing", True)), tvm.relax.transform.FuseOpsByPattern( patterns, bind_constants=False, annotate_codegen=False ), - msc_transform.SetExprName(entry_name=entry, target=trans_config.get("target", "")), - msc_transform.SetExprLayout( - trans_config.get("allow_layout_missing", True), entry_name=entry - ), ] mod = tvm.transform.Sequential(passes)(mod) graph = _ffi_api.BuildFromRelax(mod, entry, msc_utils.dump_dict(build_config)) @@ -309,13 +311,12 @@ def _partition_mod(mod, as_msc=True): patterns = get_patterns_with_prefix(target) passes = [ msc_transform.SetExprName(), + msc_transform.SetExprLayout(trans_config.get("allow_layout_missing", True)), tvm.relax.transform.FuseOpsByPattern(patterns, bind_constants=not as_msc), - msc_transform.BindShape(), + msc_transform.InlineParams(), msc_transform.FuseTuple(target), tvm.relax.transform.MergeCompositeFunctions(), msc_transform.SetBYOCAttrs(target), - msc_transform.SetExprName(target=target), - msc_transform.SetExprLayout(trans_config.get("allow_layout_missing", True)), ] return tvm.transform.Sequential(passes)(mod) @@ -331,9 +332,12 @@ def _is_target_func(func): assert len(func_names) == 1, "More than 1 target func is found: " + str(msc_mod) BYOCChecker().check(func_names, msc_mod[entry]) - graphs_info, all_weights = [], _ffi_api.GetRelaxWeights(msc_mod, entry) + ref_weights = _ffi_api.GetRelaxWeights(msc_mod, entry) + graphs, weights = [], {} for name in func_names: - build_config.update({"graph_name": msc_mod[name].attrs["byoc_name"], "byoc_entry": name}) + graph_name = msc_mod[name].attrs[_ffi_api.ToAttrKey("unique")] + build_config.update({"graph_name": graph_name, "byoc_entry": name}) graph = _ffi_api.BuildFromRelax(msc_mod, entry, msc_utils.dump_dict(build_config)) - graphs_info.append((graph, normalize_weights(all_weights, graph))) - return _partition_mod(mod, False), graphs_info + graphs.append(graph) + weights.update(normalize_weights(ref_weights, graph)) + return _partition_mod(mod, False), graphs, weights diff --git a/python/tvm/contrib/msc/core/ir/graph.py b/python/tvm/contrib/msc/core/ir/graph.py index 51727cd08969..5bfe1cec2a6f 100644 --- a/python/tvm/contrib/msc/core/ir/graph.py +++ b/python/tvm/contrib/msc/core/ir/graph.py @@ -660,6 +660,19 @@ def get_nodes(self) -> Iterable[MSCJoint]: for n in self.node_names: yield self.find_node(n) + def get_weights(self) -> Iterable[MSCTensor]: + """Get all the weights in the graph. + + Returns + ------- + weights: generator + The generator of weights. + """ + + for node in self.get_nodes(): + for weight in node.get_weights().values(): + yield weight + def input_at(self, idx: int) -> MSCTensor: """Get input at idx. diff --git a/python/tvm/contrib/msc/core/runtime/hook.py b/python/tvm/contrib/msc/core/runtime/hook.py new file mode 100644 index 000000000000..1229697a63fb --- /dev/null +++ b/python/tvm/contrib/msc/core/runtime/hook.py @@ -0,0 +1,196 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument, arguments-differ +"""tvm.contrib.msc.core.runtime.hook""" + +from typing import Dict, List, Tuple, Union, Any + +import tvm +from tvm.contrib.msc.core.ir import MSCGraph +from tvm.contrib.msc.core import utils as msc_utils + + +class RunnerHook(object): + """Hook for runner + + Parameters + ---------- + config: dict + The config of the func. + """ + + def __init__(self, config: dict): + self._config = config + + def __str__(self): + return "{}({})".format(self.name(), self._config) + + def apply(self, runner: object, *args, **kwargs) -> Any: + """Apply the hook + + Parameters + ---------- + runner: + The runner context. + args: list + The arguments for run method. + kwargs: dict + The key word arguments for run method. + + Returns + ------- + result: + The results. + """ + + kwargs.update({k: v for k, v in self._config.items() if k not in kwargs}) + return self._apply(runner, *args, **kwargs) + + def _apply(self, runner: object, *args, **kwargs): + """Apply the hook + + Parameters + ---------- + runner: + The runner context. + args: list + The arguments for run method. + kwargs: dict + The key word arguments for run method. + + Returns + ------- + result: + The results. + """ + + raise NotImplementedError("default_func is not supported in " + str(self.__class__)) + + @classmethod + def name(cls): + return "base" + + +class CustomizedHook(RunnerHook): + """Hook for customized func + + Parameters + ---------- + func: callable/str + The function. + config: dict + The config of the func. + """ + + def __init__(self, func: Union[str, callable], config: dict): + super(CustomizedHook, self).__init__(config) + self._func = msc_utils.load_callable(func) + + def __str__(self): + return "{} {}({})".format(self.name(), self._func, self._config) + + def _apply(self, runner: object, *args, **kwargs): + """Apply the hook + + Parameters + ---------- + runner: + The runner context. + args: list + The arguments for run method. + kwargs: dict + The key word arguments for run method. + + Returns + ------- + result: + The results. + """ + + return self._func(runner, *args, **kwargs) + + @classmethod + def name(cls): + return "customized" + + +class UpdateWeightsHook(RunnerHook): + """Hook for update weights""" + + def _apply( + self, + runner: object, + graphs: List[MSCGraph], + weights: Dict[str, tvm.nd.array], + weights_path: str, + ) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: + """Apply the default funcion + + Parameters + ------- + runner: + The runner context. + graphs: list + The translated graphs + weights: dict + The translated weights. + weights_path: str + The weights path. + + Returns + ------- + graphs: list + The updated graphs + weights: dict + The updated weights. + + """ + + with open(weights_path, "rb") as f: + new_weights = tvm.runtime.load_param_dict(f.read()) + weights.update({k: v for k, v in new_weights.items() if k in weights}) + return graphs, weights + + @classmethod + def name(cls): + return "update_weights" + + +def load_runner_hook(config: dict) -> Any: + """Load a registered hook + + Parameters + ---------- + config: dict + The config of the func. + + Returns + ------- + hook: RunnerHook + The hook + """ + + assert "hook" in config, "hook should be given to load hook" + hook_ref = config["hook"] + hook_config = {k: v for k, v in config.items() if k != "hook"} + hook_cls = msc_utils.get_registered_runner_hook(hook_ref) if isinstance(hook_ref, str) else None + if hook_cls: + return hook_cls(hook_config) + return CustomizedHook(hook_ref, hook_config) + + +msc_utils.register_runner_hook(UpdateWeightsHook) diff --git a/python/tvm/contrib/msc/core/runtime/runner.py b/python/tvm/contrib/msc/core/runtime/runner.py index 4b84037994ee..2849eb05ed83 100644 --- a/python/tvm/contrib/msc/core/runtime/runner.py +++ b/python/tvm/contrib/msc/core/runtime/runner.py @@ -31,6 +31,7 @@ from tvm.contrib.msc.core.utils.message import MSCStage from tvm.contrib.msc.core import utils as msc_utils from tvm.contrib.msc.core import _ffi_api +from .hook import load_runner_hook class BaseRunner(object): @@ -48,14 +49,16 @@ class BaseRunner(object): The config for translate IRModule to MSCGraph. codegen_config: dict The config for build MSCGraph to runnable model. + build_config: dict + The config for build runnable. + device: str + The device to build runnable. + training: bool + Whether compile model to trainable stage: str The stage of runner. name: str The name of the runner - device: str - The device of the model, cpu| cuda| cuda:0|... - is_training: bool - Whether use model in training debug_level: int The debug level. logger: logging.Logger @@ -68,10 +71,11 @@ def __init__( tools_config: Optional[Dict[str, Any]] = None, translate_config: Optional[Dict[str, str]] = None, generate_config: Optional[Dict[str, str]] = None, + build_config: Optional[Dict[str, str]] = None, + device: str = "cpu", + training: bool = False, stage: str = "default", name: str = "main", - device: str = "cpu", - is_training: bool = False, debug_level: int = 0, logger: logging.Logger = None, ): @@ -79,11 +83,12 @@ def __init__( self._tools_config = msc_utils.copy_dict(tools_config) self._translate_config = msc_utils.copy_dict(translate_config) self._generate_config = msc_utils.copy_dict(generate_config) + self._build_config = msc_utils.copy_dict(build_config) + self._device = device if self._device_enabled(device) else "cpu" self._stage = stage self._name = name - self._device = device if self._device_enabled(device) else "cpu" - self._is_training = is_training self._debug_level = debug_level + self._training, self._trained = training, training self._logger = logger or msc_utils.get_global_logger() self._logger.info( msc_utils.msg_block( @@ -102,7 +107,7 @@ def setup(self) -> dict: if "build_folder" not in self._generate_config: self._generate_config["build_folder"] = msc_utils.get_build_dir() - self._graphs, self._weights = [], [] + self._graphs, self._weights = [], {} self._model, self._model_info = None, {} self._runnable = None # Setup tools @@ -111,15 +116,20 @@ def setup(self) -> dict: self._update_codegen({"use_tools": True, "tools_tag": self._name}) for t_type, config in self._tools_config.items(): self._tools[t_type] = create_tool( - self.framework, t_type, self._name, stage=self._stage, **config + self.framework, + t_type, + self._name, + training=self._training, + stage=self._stage, + **config, ) return { "tools": {k: v.tool_style() for k, v in self._tools.items()}, "translate_config": self._translate_config, "generate_config": self._generate_config, - "name": self._name, + "build_config": self._build_config, "device": self._device, - "is_training": self._is_training, + "name": self._name, "debug_level": self._debug_level, } @@ -154,7 +164,7 @@ def build(self, cache_dir: msc_utils.MSCDirectory = None, force_build: bool = Fa """ if force_build: - self._graphs, self._weights = [], [] + self._graphs, self._weights = [], {} self._model, self._model_info = None, {} self._runnable = None if cache_dir and os.path.isfile(cache_dir.relpath("cache_info.json")): @@ -164,13 +174,23 @@ def build(self, cache_dir: msc_utils.MSCDirectory = None, force_build: bool = Fa # Load graphs from cache if not self._graphs and cache_info.get("graphs"): - self._graphs, self._weights = self._load_graphs(cache_dir, cache_info["graphs"]) - self._logger.info("Load %d graphs from %s", len(self._graphs), cache_dir) + self._graphs = self._load_graphs(cache_dir, cache_info["graphs"]) + assert "weights" in cache_info, "Missing weights in cache_info" + with open(cache_dir.relpath(cache_info["weights"]), "rb") as f: + self._weights = tvm.runtime.load_param_dict(f.read()) + self._logger.info( + "Load %d graphs %d weights from %s", + len(self._graphs), + len(self._weights), + cache_dir, + ) # Translate graphs from module if not self._graphs: - self._graphs, self._weights = self._translate() - self._logger.info("Translate %d graphs from module", len(self._graphs)) + self._graphs, self._weights = self.translate() + self._logger.info( + "Translate %d graphs %d weights from module", len(self._graphs), len(self._weights) + ) # Load model from cache if not self._model and cache_info.get("model"): @@ -184,31 +204,26 @@ def build(self, cache_dir: msc_utils.MSCDirectory = None, force_build: bool = Fa if distiller and not distiller.distilled: build_root = self._generate_config["build_folder"] - def _build_scope_model(scope: str): + def _build_scope_model(scope: str, apply_hooks: bool): self._update_codegen({"tools_scope": scope}) self._generate_config["build_folder"] = build_root.create_dir(scope) - return self._generate_model() + return self.generate_model(apply_hooks=apply_hooks) # Generate distill model - teacher_model = _build_scope_model(ToolScope.TEACHER) + teacher_model = _build_scope_model(ToolScope.TEACHER, False) self._graphs, self._weights = self.reset_tools(cache_dir=cache_dir) - student_model = _build_scope_model(ToolScope.STUDENT) + student_model = _build_scope_model(ToolScope.STUDENT, True) self._model = distiller.build_model(teacher_model, student_model) else: # Generate normal model self._graphs, self._weights = self.reset_tools(cache_dir=cache_dir) - self._model = self._generate_model() + self._model = self.generate_model() - # Log generate info generate_msg = "Generate model({})".format(self.framework) if self._tools: self._logger.info("%s with tools: %s", generate_msg, ",".join(self._tools.keys())) else: self._logger.info("%s without tools", generate_msg) - if "generator" in self._generate_config: - generator, generate_config = self._generate_config["generator"] - self._model = generator(self._model, **generate_config) - self._logger.info("%s by %s(%s)", generate_msg, generator, generate_config) # Inspect model self._model_info = self._inspect_model() @@ -216,9 +231,8 @@ def _build_scope_model(scope: str): self._logger.debug(msc_utils.msg_block("RUNNER.MODEL_INFO", self._model_info)) runnable_msg = "runnable({}, {}) @ {}".format( - self.framework, "train" if self._is_training else "eval", self._device + self.framework, "train" if self._training else "eval", self._device ) - # Load runnable from cache if not self._runnable and cache_info.get("runnable"): self._runnable = self._load_runnable(cache_dir, cache_info["runnable"]) @@ -226,10 +240,68 @@ def _build_scope_model(scope: str): # Build runnable if not self._runnable: - self._runnable = self._to_runnable(self._model, self._device, self._is_training) + self._runnable = self.build_runnable() self._logger.info("Build %s", runnable_msg) return self._runnable + def run( + self, inputs: Union[List[np.ndarray], Dict[str, np.ndarray]], ret_type="dict" + ) -> Union[List[np.ndarray], Dict[str, np.ndarray]]: + """Run the model to get outputs + + Parameters + ------- + inputs: list or dict + The inputs in list or dict. + ret_type: str + The return type list| dict + + Returns + ------- + outputs: dict + The outputs in dict. + """ + + model_inputs = self.get_inputs() + model_outputs = self.get_outputs() + if isinstance(inputs, (list, tuple)): + assert len(inputs) == len( + model_inputs + ), "inputs({}) mismatch with model inputs {}".format(len(inputs), model_inputs) + inputs = {info["name"]: data for info, data in zip(model_inputs, inputs)} + assert isinstance(inputs, dict), "Expect inputs as list or dict, get {}({})".format( + inputs, type(inputs) + ) + assert all( + isinstance(data, np.ndarray) for data in inputs.values() + ), "Expected all inputs as np.ndarray" + inputs = {i["name"]: inputs[i["name"]] for i in model_inputs} + outputs = self._call_runnable(self._runnable, inputs, self._device) + if ret_type == "native": + return outputs + if ret_type == "dict": + if isinstance(outputs, (list, tuple)): + assert len(outputs) == len( + model_outputs + ), "outputs({}) mismatch with model outputs {}".format(len(outputs), model_outputs) + outputs = {info["name"]: data for info, data in zip(model_outputs, outputs)} + if not isinstance(outputs, dict): + assert len(model_outputs) == 1, "Expect model_outputs with len 1, get " + str( + model_outputs + ) + outputs = {model_outputs[0]["name"]: outputs} + outputs = {name: msc_utils.cast_array(data) for name, data in outputs.items()} + elif ret_type == "list": + if isinstance(outputs, dict): + assert len(outputs) == len( + model_outputs + ), "outputs({}) mismatch with model outputs {}".format(len(outputs), model_outputs) + outputs = [outputs[o["name"]] for o in model_outputs] + if not isinstance(outputs, (list, tuple)): + outputs = [outputs] + outputs = [msc_utils.cast_array(data) for data in outputs] + return outputs + def save_cache( self, cache_dir: msc_utils.MSCDirectory, @@ -251,10 +323,13 @@ def save_cache( Whether to save tools. """ - cache_info = {"graphs": self._save_graphs(cache_dir)} - if save_model: + cache_info = {"graphs": self._save_graphs(cache_dir), "weights": "graph_weights.bin"} + with cache_dir: + with open(cache_info["weights"], "wb") as f_params: + f_params.write(tvm.runtime.save_param_dict(self._weights)) + if save_model and cache_info.get("graphs"): cache_info["model"] = self._save_model(cache_dir) - if save_runnable: + if save_runnable and cache_info.get("model"): cache_info["runnable"] = self._save_runnable(cache_dir) if save_tools: for t_type, tool in self._tools.items(): @@ -265,6 +340,50 @@ def save_cache( msc_utils.msg_block("RUNNER.SAVE_CACHE", {"folder": cache_dir, "info": cache_info}) ) + def translate(self, apply_hooks: bool = True) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: + """Translate IRModule to MSCgraphs + + Parameters + ------- + apply_hooks: bool + Whether to apply hooks. + + Returns + ------- + graphs: list + The translated graphs + weights: dict + The translated weights. + """ + + mod = self._mod + if apply_hooks: + for hook in self._translate_config.get("pre_hooks", []): + mod = self._apply_hook("before translate", hook, mod) + graphs, weights = self._translate(mod) + if apply_hooks: + for hook in self._translate_config.get("post_hooks", []): + graphs, weights = self._apply_hook("after translate", hook, graphs, weights) + return graphs, weights + + def _translate(self, mod: tvm.IRModule) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: + """Translate IRModule to MSCgraphs + + Parameters + ------- + mod: tvm.IRModule + The module to be translated. + + Returns + ------- + graphs: list + The translated graphs + weights: dict + The translated weights. + """ + + raise NotImplementedError("_translate is not implemented for " + str(self.__class__)) + def reset_tools( self, graphs: List[MSCGraph] = None, @@ -300,63 +419,116 @@ def reset_tools( graphs, weights = tool.reset(graphs, weights, cache_dir) return graphs, weights - def run( - self, inputs: Union[List[np.ndarray], Dict[str, np.ndarray]], ret_type="dict" - ) -> Union[List[np.ndarray], Dict[str, np.ndarray]]: - """Run the model to get outputs + def generate_model(self, apply_hooks: bool = True) -> Any: + """Codegen the model according to framework Parameters ------- - inputs: list or dict - The inputs in list or dict. - ret_type: str - The return type list| dict + apply_hooks: bool + Whether to apply hooks. Returns ------- - outputs: dict - The outputs in dict. + model: Any + The meta model """ - model_inputs = self.get_inputs() - model_outputs = self.get_outputs() - if isinstance(inputs, (list, tuple)): - assert len(inputs) == len( - model_inputs - ), "inputs({}) mismatch with model inputs {}".format(len(inputs), model_inputs) - inputs = {info["name"]: data for info, data in zip(model_inputs, inputs)} - assert isinstance(inputs, dict), "Expect inputs as list or dict, get {}({})".format( - inputs, type(inputs) - ) - assert all( - isinstance(data, np.ndarray) for data in inputs.values() - ), "Expected all inputs as np.ndarray" - inputs = {i["name"]: inputs[i["name"]] for i in model_inputs} - outputs = self._call_runnable(self._runnable, inputs, self._device) - if ret_type == "native": - return outputs - if ret_type == "dict": - if isinstance(outputs, (list, tuple)): - assert len(outputs) == len( - model_outputs - ), "outputs({}) mismatch with model outputs {}".format(len(outputs), model_outputs) - outputs = {info["name"]: data for info, data in zip(model_outputs, outputs)} - if not isinstance(outputs, dict): - assert len(model_outputs) == 1, "Expect model_outputs with len 1, get " + str( - model_outputs - ) - outputs = {model_outputs[0]["name"]: outputs} - outputs = {name: msc_utils.cast_array(data) for name, data in outputs.items()} - elif ret_type == "list": - if isinstance(outputs, dict): - assert len(outputs) == len( - model_outputs - ), "outputs({}) mismatch with model outputs {}".format(len(outputs), model_outputs) - outputs = [outputs[o["name"]] for o in model_outputs] - if not isinstance(outputs, (list, tuple)): - outputs = [outputs] - outputs = [msc_utils.cast_array(data) for data in outputs] - return outputs + graphs, weights = self._graphs, self._weights + if apply_hooks: + for hook in self._generate_config.get("pre_hooks", []): + graphs, weights = self._apply_hook("before generate", hook, graphs, weights) + model = self._generate_model(graphs, weights) + if apply_hooks: + for hook in self._generate_config.get("post_hooks", []): + model = self._apply_hook("after generate", hook, model) + return model + + def _generate_model(self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.array]) -> Any: + """Codegen the model according to framework + + Parameters + ------- + graphs: list + The msc graphs. + weights: dict + The weights. + + Returns + ------- + model: Any + The meta model + """ + + raise NotImplementedError("_load is not implemented for " + str(self.__class__)) + + def build_runnable(self, apply_hooks: bool = True) -> Any: + """Build runnable object + + Parameters + ------- + apply_hooks: bool + Whether to apply hooks. + + Returns + ------- + runnable: Any + The runnable + """ + + model = self._model + if apply_hooks: + for hook in self._build_config.get("pre_hooks", []): + model = self._apply_hook("before build", hook, model) + runnable = self._build_runnable(model) + if apply_hooks: + for hook in self._build_config.get("post_hooks", []): + runnable = self._apply_hook("after build", hook, runnable) + return runnable + + def _build_runnable(self, model: Any) -> Any: + """Build runnable object + + Parameters + ------- + model: Any + The meta model. + + Returns + ------- + runnable: Any + The runnable + """ + + raise NotImplementedError("_build_runnable is not implemented for " + str(self.__class__)) + + def train(self): + """Change status to train""" + + if not self._training: + self._training = True + for tool in self.get_tools(): + tool.train() + self._train() + + def _train(self): + """Change status to train""" + + self._runnable = self.build_runnable() + + def eval(self): + """Change status to eval""" + + if self._training: + self._trained = True + self._training = False + for tool in self.get_tools(): + tool.eval() + self._eval() + + def _eval(self): + """Change status to eval""" + + self._runnable = self.build_runnable() def get_tool_config(self, tool_type: str) -> dict: """Get tool by type @@ -455,6 +627,30 @@ def apply_tool(self, tool_type: str, data_loader: Any = None) -> str: self._logger.info("Save %d plan(%s) -> %s", len(plan), tool_type, plan_file) return plan_file + def _apply_hook(self, desc: str, hook_def: dict, *args, **kwargs) -> Any: + """Load a registered hook + + Parameters + ---------- + desc: str + The description of the hook + hook_def: dict + The function and config of the hook. + args: list + The arguments for run method. + kwargs: dict + The key word arguments for run method. + + Returns + ------- + result: + The result + """ + + hook = load_runner_hook(hook_def) + self._logger.info("Apply %s hook:\n %s", desc, hook) + return hook.apply(self, *args, **kwargs) + def _update_codegen(self, config: Dict[str, Any]): """Update the codegen in generate_config @@ -511,6 +707,30 @@ def get_outputs(self) -> List[Dict[str, str]]: return self._model_info["outputs"] + def get_weights(self, framework: str = None, device: str = None) -> Iterable[tvm.nd.array]: + """Get the weights from graphs + + Parameters + ------- + framework: str + The framework for weight. + device: str + The device for weight. + + Returns + ------- + weights: generator + The generator of weight datas. + """ + + device = device or self._device + for graph in self._graphs: + for weight in graph.get_weights(): + data = self._weights[weight.name] + if framework: + data = msc_utils.cast_array(data, framework, device) + yield data + def destory(self): """Destory runner""" @@ -522,23 +742,8 @@ def destory(self): tool.destory() remove_tools(self._name) - def _translate(self) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: - """Translate IRModule to MSCgraphs - - Returns - ------- - graph_list: list - The translated graphs - weights_list: list> - The translated weights - """ - - raise NotImplementedError("_translate is not implemented for " + str(self.__class__)) - - def _load_graphs( - self, cache_dir: msc_utils.MSCDirectory, cache_info: dict - ) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: - """Load MSCgraphs from cache + def _load_graphs(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict) -> List[MSCGraph]: + """Load MSCGraphs from cache Parameters ------- @@ -549,10 +754,8 @@ def _load_graphs( Returns ------- - graph_list: list + graphs: list The translated graphs - weights_list: list> - The translated weights """ raise NotImplementedError("_load_graphs is not implemented for " + str(self.__class__)) @@ -573,26 +776,6 @@ def _save_graphs(self, cache_dir: msc_utils.MSCDirectory) -> dict: raise NotImplementedError("_save_graphs is not implemented for " + str(self.__class__)) - def _generate_model( - self, graphs: List[MSCGraph] = None, weights: List[Dict[str, tvm.nd.array]] = None - ) -> Any: - """Codegen the model according to framework - - Parameters - ------- - graphs: list - The msc graphs. - weights: list> - The weights - - Returns - ------- - model: Any - The meta model - """ - - raise NotImplementedError("_load is not implemented for " + str(self.__class__)) - def _load_model(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict) -> Any: """Load the model from cache @@ -628,26 +811,6 @@ def _save_model(self, cache_dir: msc_utils.MSCDirectory) -> dict: # disable save model by default return {} - def _to_runnable(self, model: Any, device: str, is_training: bool) -> Any: - """Build runnable object - - Parameters - ------- - model: Any - The meta model. - device: str - The device for place model - is_training: bool - Whether to load model for training - - Returns - ------- - runnable: Any - The runnable - """ - - raise NotImplementedError("_to_runnable is not implemented for " + str(self.__class__)) - def _load_runnable(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict) -> Any: """Load the runnable from cache @@ -727,10 +890,6 @@ def _device_enabled(self, device: str) -> bool: return True - @classmethod - def support_tool(cls, tool_type: str) -> bool: - return True - @property def stage(self): return self._stage @@ -763,33 +922,93 @@ def codegen_func(self): def framework(self): return MSCFramework.MSC + @classmethod + def load_native(cls, model: Any) -> Any: + """Load the native model + + Parameters + ------- + model: + The native model. + + Returns + ------- + model: + The loaded native model. + """ + + return model, "cpu" + + @classmethod + def update_config(cls, stage: str, config: dict, model: Any = None) -> dict: + """Update the config for parse + + Parameters + ------- + stage: str + The stage to be updated + config: dict + The config for pipeline. + model: + The native model. + + Returns + ------- + config: dict + The updated config. + """ + + if stage not in config: + return config + if stage in (MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE): + run_config = config[stage].get("run_config", {}) + if "translate_config" not in run_config: + run_config["translate_config"] = {} + if "build" not in run_config["translate_config"]: + run_config["translate_config"]["build"] = {} + if "generate_config" not in run_config: + run_config["generate_config"] = {} + run_config["translate_config"]["build"]["input_aliases"] = [ + i[0] for i in config["inputs"] + ] + run_config["translate_config"]["build"]["output_aliases"] = config["outputs"] + config[stage]["run_config"] = run_config + return config + + @classmethod + def support_tool(cls, tool_type: str) -> bool: + return True + class ModelRunner(BaseRunner): """Model runner of MSC""" - def _translate(self) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: + def _translate(self, mod: tvm.IRModule) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: """Translate IRModule to MSCgraphs + Parameters + ------- + mod: tvm.IRModule + The module to be translated. + Returns ------- - graph_list: list + graphs: list The translated graphs - weights_list: list> - The translated weights + weights: dict + The translated weights. """ graph, weights = from_relax( - self._mod, + mod, trans_config=self._translate_config.get("transform"), build_config=self._translate_config.get("build"), opt_config=self._translate_config.get("optimize"), ) - return [graph], [weights] + return [graph], weights - def _load_graphs( - self, cache_dir: msc_utils.MSCDirectory, cache_info: dict - ) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: - """Load MSCgraphs from cache + def _load_graphs(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict) -> List[MSCGraph]: + """Load MSCGraphs from cache Parameters ------- @@ -800,17 +1019,13 @@ def _load_graphs( Returns ------- - graph_list: list + graphs: list The translated graphs - weights_list: list> - The translated weights """ assert "main" in cache_info, "main should be given in cache_info, get " + str(cache_info) graph = MSCGraph.from_json(cache_dir.relpath(cache_info["main"]["graph"])) - with open(cache_dir.relpath(cache_info["main"]["weights"]), "rb") as f: - weights = tvm.runtime.load_param_dict(f.read()) - return [graph], [weights] + return [graph] def _save_graphs(self, cache_dir: msc_utils.MSCDirectory) -> dict: """Save MSCgraphs to cache @@ -826,28 +1041,21 @@ def _save_graphs(self, cache_dir: msc_utils.MSCDirectory) -> dict: The cache info. """ - main_info = { - "graph": self._graphs[0].name + "_graph.json", - "weights": self._graphs[0].name + "_params.bin", - } + main_info = {"graph": self._graphs[0].name + "_graph.json"} with cache_dir: with open(main_info["graph"], "w") as f_graph: f_graph.write(self._graphs[0].to_json()) - with open(main_info["weights"], "wb") as f_params: - f_params.write(tvm.runtime.save_param_dict(self._weights[0])) return {"main": main_info} - def _generate_model( - self, graphs: List[MSCGraph] = None, weights: List[Dict[str, tvm.nd.array]] = None - ) -> Any: + def _generate_model(self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.array]) -> Any: """Codegen the model according to framework Parameters ------- graphs: list The msc graphs. - weights: list> - The weights + weights: dict + The weights. Returns ------- @@ -855,11 +1063,9 @@ def _generate_model( The runnable model """ - graph = graphs[0] if graphs else self._graphs[0] - weight = weights[0] if weights else self._weights[0] return self.codegen_func( - graph, - weight, + graphs[0], + weights, codegen_config=self._generate_config.get("codegen"), print_config=self._generate_config.get("print"), build_folder=self._generate_config["build_folder"], @@ -904,34 +1110,33 @@ def visualize(self, visual_dir: msc_utils.MSCDirectory): super().visualize(visual_dir) self._byoc_graph.visualize(visual_dir.relpath(self._byoc_graph.name + ".prototxt")) - def _translate(self) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: + def _translate(self, mod: tvm.IRModule) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: """Translate IRModule to MSCgraphs + Parameters + ------- + mod: tvm.IRModule + The module to be translated. + Returns ------- - graph_list: list + graphs: list The translated graphs - weights_list: list> - The translated weights + weights: dict + The translated weights. """ - self._byoc_mod, graph_infos = self.partition_func( - self._mod, + self._byoc_mod, graphs, weights = self.partition_func( + mod, trans_config=self._translate_config.get("transform"), build_config=self._translate_config.get("build"), ) - graphs, weights = [], [] - for graph, sub_weights in graph_infos: - graphs.append(graph) - weights.append(sub_weights) self._byoc_graph = _ffi_api.BuildFromRelax( self._byoc_mod, "main", msc_utils.dump_dict(self._translate_config.get("build")) ) return graphs, weights - def _load_graphs( - self, cache_dir: msc_utils.MSCDirectory, cache_info: dict - ) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: + def _load_graphs(self, cache_dir: msc_utils.MSCDirectory, cache_info: dict) -> List[MSCGraph]: """Load MSCgraphs from cache Parameters @@ -943,10 +1148,8 @@ def _load_graphs( Returns ------- - graph_list: list + graphs: list The translated graphs - weights_list: list> - The translated weights """ assert "byoc_mod" in cache_info, "byoc_mod should be given in cache_info, get " + str( @@ -958,16 +1161,11 @@ def _load_graphs( assert "sub_graphs" in cache_info, "sub_graphs should be given in cache_info, get " + str( cache_info ) - with open(cache_dir.relpath(cache_info["byoc_mod"]), "r") as f: self._byoc_mod = tvm.ir.load_json(f.read()) - graphs, weights = [], [] - for f_graph, f_weights in cache_info["sub_graphs"]: - graphs.append(MSCGraph.from_json(cache_dir.relpath(f_graph))) - with open(cache_dir.relpath(f_weights), "rb") as f: - weights = tvm.runtime.load_param_dict(f.read()) + graphs = [MSCGraph.from_json(cache_dir.relpath(g)) for g in cache_info["sub_graphs"]] self._byoc_graph = MSCGraph.from_json(cache_dir.relpath(cache_info["byoc_graph"])) - return graphs, weights + return graphs def _save_graphs(self, cache_dir: msc_utils.MSCDirectory) -> dict: """Save MSCgraphs to cache @@ -983,15 +1181,11 @@ def _save_graphs(self, cache_dir: msc_utils.MSCDirectory) -> dict: The cache info. """ - sub_graphs = [ - (graph.name + "_graph.info", graph.name + "_params.bin") for graph in self._graphs - ] + sub_graphs = [g.name + "_graph.info" for g in self._graphs] with cache_dir: - for graph, weights, info in zip(self._graphs, self._weights, sub_graphs): - with open(info[0], "w") as f_graph: + for graph, g_file in zip(self._graphs, sub_graphs): + with open(g_file, "w") as f_graph: f_graph.write(graph.to_json()) - with open(info[1], "wb") as f_params: - f_params.write(tvm.runtime.save_param_dict(weights)) with open("byoc_graph.json", "w") as f: f.write(self._byoc_graph.to_json()) with open("byoc_module.json", "w") as f: @@ -1002,17 +1196,15 @@ def _save_graphs(self, cache_dir: msc_utils.MSCDirectory) -> dict: "byoc_mod": "byoc_module.json", } - def _generate_model( - self, graphs: List[MSCGraph] = None, weights: List[Dict[str, tvm.nd.array]] = None - ) -> Any: + def _generate_model(self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.array]) -> Any: """Codegen the model according to framework Parameters ------- graphs: list The msc graphs. - weights: list> - The weights + weights: dict + The weights. Returns ------- @@ -1020,7 +1212,6 @@ def _generate_model( The relax module """ - graph_infos = list(zip(graphs or self._graphs, weights or self._weights)) extra_option = self._generate_config.get("extra_option", {}) if self._stage == MSCStage.COMPILE and not self.get_tool(ToolType.TRACKER): extra_option["tool_tag"] = "" @@ -1028,7 +1219,8 @@ def _generate_model( extra_option["tool_tag"] = self._name return self.codegen_func( self._byoc_mod, - graph_infos, + graphs, + weights, codegen_configs=self._generate_config.get("codegen"), print_configs=self._generate_config.get("print"), extra_options=extra_option, @@ -1036,17 +1228,13 @@ def _generate_model( output_folder=self._generate_config.get("output_folder", msc_utils.get_output_dir()), ) - def _to_runnable(self, model: Any, device: str, is_training: bool) -> Any: + def _build_runnable(self, model: Any) -> Any: """Build runnable object Parameters ------- model: Any - The runnable model on cpu. - device: str - The device for place model - is_training: bool - Whether to load model for training + The meta model. Returns ------- @@ -1055,12 +1243,12 @@ def _to_runnable(self, model: Any, device: str, is_training: bool) -> Any: """ model = tvm.relax.transform.LegalizeOps()(model) - if device == "cpu": + if self._device == "cpu": target = tvm.target.Target("llvm") with tvm.transform.PassContext(opt_level=3): relax_exec = tvm.relax.build(model, target) runnable = tvm.relax.VirtualMachine(relax_exec, tvm.cpu()) - elif device.startswith("cuda"): + elif self._device.startswith("cuda"): target = tvm.target.Target("cuda") with target: model = tvm.tir.transform.DefaultGPUSchedule()(model) @@ -1068,7 +1256,7 @@ def _to_runnable(self, model: Any, device: str, is_training: bool) -> Any: relax_exec = tvm.relax.build(model, target) runnable = tvm.relax.VirtualMachine(relax_exec, tvm.cuda()) else: - raise NotImplementedError("Unsupported device " + str(device)) + raise NotImplementedError("Unsupported device " + str(self._device)) return runnable def _call_runnable( diff --git a/python/tvm/contrib/msc/core/tools/distill/distiller.py b/python/tvm/contrib/msc/core/tools/distill/distiller.py index f5c2ca2f8849..58cf3fd2d953 100644 --- a/python/tvm/contrib/msc/core/tools/distill/distiller.py +++ b/python/tvm/contrib/msc/core/tools/distill/distiller.py @@ -45,23 +45,23 @@ def setup(self) -> dict: return super().setup() def _reset( - self, graphs: List[MSCGraph], weights: List[Dict[str, tvm.nd.array]] - ) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]: + self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.array] + ) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: """Reset the tool Parameters ---------- graphs: list The msc graphs. - weights: list> - The weights + weights: dict + The weights. Returns ------- graphs: list The msc graphs. - weights: list> - The weights + weights: dict + The weights. """ self._current_iter = 0 @@ -69,8 +69,7 @@ def _reset( if self._distilled: with open(self._weights_path, "rb") as f: distilled_weights = tvm.runtime.load_param_dict(f.read()) - for sub_weights in weights: - sub_weights.update({k: v for k, v in distilled_weights.items() if k in sub_weights}) + weights.update({k: v for k, v in distilled_weights.items() if k in weights}) self._logger.info("Update %d distilled weights", len(distilled_weights)) return super()._reset(graphs, weights) diff --git a/python/tvm/contrib/msc/core/tools/prune/pruner.py b/python/tvm/contrib/msc/core/tools/prune/pruner.py index a946ef1611e7..bb2ff9922073 100644 --- a/python/tvm/contrib/msc/core/tools/prune/pruner.py +++ b/python/tvm/contrib/msc/core/tools/prune/pruner.py @@ -87,28 +87,26 @@ def _update_stages(strategy): return super()._parse_strategys([_update_stages(s) for s in strategy_list]) def _reset( - self, graphs: List[MSCGraph], weights: List[Dict[str, tvm.nd.array]] - ) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]: + self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.array] + ) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: """Reset the tool Parameters ---------- graphs: list The msc graphs. - weights: list> - The weights + weights: dict + The weights. Returns ------- graphs: list The msc graphs. - weights: list> - The weights + weights: dict + The weights. """ - self._meta_weights = {} - for sub_weights in weights: - self._meta_weights.update(sub_weights) + self._meta_weights = weights graphs, weights = super()._reset(graphs, weights) if self._plan and self._enabled: return self.prune_graphs(graphs, weights) @@ -302,23 +300,23 @@ def _prunable(w_node: WeightJoint) -> bool: self._plan[w_node.name]["out_indices"] = [] def prune_graphs( - self, graphs: List[MSCGraph], weights: List[Dict[str, tvm.nd.array]] - ) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]: + self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.array] + ) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: """Reset the tool Parameters ---------- graphs: list The msc graphs. - weights: list> - The weights + weights: dict + The weights. Returns ------- graphs: list The msc graphs. - weights: list> - The weights + weights: dict + The weights. """ def _prune_by_shape(tensor: MSCTensor, shape: List[int]): @@ -331,17 +329,25 @@ def _prune_by_channel(tensor: MSCTensor, dim, channel_axis: int = None): shape[channel_axis] = dim return _prune_by_shape(tensor, shape) - new_graphs, new_weights = [], [] - pruned_weights_cnt = 0 - for graph, sub_weights in zip(graphs, weights): - pruned_tensors, pruned_weights = {}, {} + pruned_graphs, pruned_weights = [], {} + pruned_cnt = 0 + for graph in graphs: + pruned_tensors = {} for node in graph.get_nodes(): for weight in node.get_weights().values(): w_node, w_name = self.find_w_node(weight.name), weight.name - if w_name not in self._plan or w_node.get_attr("status", "") == "pruned": - pruned_weights[w_name] = sub_weights[w_name] + if w_name not in self._plan: + pruned_weights[w_name] = weights[w_name] + elif w_node.get_attr("pruned_shape", "") != "": + pruned_weights[w_name] = weights[w_name] + pruned_shape = [int(i) for i in w_node.get_attr("pruned_shape").split(",")] + assert pruned_shape == list( + pruned_weights[w_name].shape + ), "pruned_shape {} mismatch with data shape {}".format( + pruned_shape, pruned_weights[w_name].shape + ) else: - data = msc_utils.cast_array(sub_weights[w_name]) + data = msc_utils.cast_array(weights[w_name]) in_axis, out_axis = self._get_io_axes(self.find_w_node(w_name)) w_config = self._plan[w_name] if w_config["in_indices"]: @@ -350,8 +356,11 @@ def _prune_by_channel(tensor: MSCTensor, dim, channel_axis: int = None): data = PruneMethod.prune_axis(data, out_axis, w_config["out_indices"]) pruned_tensors[w_name] = _prune_by_shape(weight, data.shape) pruned_weights[w_name] = tvm.nd.array(data) - w_node.set_attr("status", "pruned") - pruned_weights_cnt += 1 + w_node.set_attr( + "pruned_shape", + ",".join([str(i) for i in pruned_tensors[w_name].get_shape()]), + ) + pruned_cnt += 1 if node.optype == "constant": if node.weight_at("const").name not in pruned_tensors: continue @@ -375,12 +384,15 @@ def _prune_by_channel(tensor: MSCTensor, dim, channel_axis: int = None): elif node.optype in self._relation_wtypes: for out in node.get_outputs(): w_node = self.find_w_node(out.name) - if out.name not in self._plan or w_node.get_attr("status", "") == "pruned": + if out.name not in self._plan or w_node.get_attr("pruned_shape", "") != "": continue pruned_tensors[out.name] = _prune_by_channel( out, len(self._plan[out.name]["out_indices"]) ) - w_node.set_attr("status", "pruned") + w_node.set_attr( + "pruned_shape", + ",".join([str(i) for i in pruned_tensors[out.name].get_shape()]), + ) elif node.get_inputs(): ref_input = node.input_at(0) if ref_input.name not in pruned_tensors or ref_input.layout_of("C") < 0: @@ -401,32 +413,28 @@ def _is_pruned(tensor: MSCTensor, graph: MSCGraph) -> bool: if pruned_tensors: pruned_graph = _ffi_api.PruneWeights(graph, pruned_tensors) - new_graphs.append(pruned_graph) + pruned_graphs.append(pruned_graph) else: - new_graphs.append(graph) - new_weights.append(pruned_weights) + pruned_graphs.append(graph) def _flatten_size(weights): - weight_size = 0 - for sub_weights in weights: - for w_data in sub_weights.values(): - weight_size += w_data.asnumpy().size + weight_size = sum([w.asnumpy().size for w in weights.values()]) return weight_size / 2**20 raw_size = _flatten_size(weights) # log compress rate - if pruned_weights_cnt > 0: - new_size = _flatten_size(new_weights) + if pruned_cnt > 0: + new_size = _flatten_size(pruned_weights) self._logger.info( "Prune %d weights, compress to %.2f%% (%.4f M->%.4f M)", - pruned_weights_cnt, + pruned_cnt, new_size * 100 / raw_size, raw_size, new_size, ) else: self._logger.info("No weights pruned, size %.4f M", raw_size) - return new_graphs, new_weights + return pruned_graphs, pruned_weights def get_meta_data(self, name: str) -> np.ndarray: """Get meta weight as np.ndarray diff --git a/python/tvm/contrib/msc/core/tools/tool.py b/python/tvm/contrib/msc/core/tools/tool.py index adeea3b2226a..7253841122ae 100644 --- a/python/tvm/contrib/msc/core/tools/tool.py +++ b/python/tvm/contrib/msc/core/tools/tool.py @@ -178,7 +178,7 @@ def apply(self, *args, **kwargs) -> Any: Returns ------- - plan or tensot: + plan or tensor: The plan generated by method or processed tensor. """ @@ -289,6 +289,8 @@ class BaseTool(object): The plan file path. strategys: list[dict] The strategys of the tool. + training: bool + Whether the tool is training. cache_processed: bool Whether to cache processed tensor. options: dict @@ -306,6 +308,7 @@ def __init__( stage: str, plan_file: str, strategys: List[dict], + training: bool = False, cache_processed: bool = True, options: dict = None, debug_level: int = 0, @@ -318,6 +321,7 @@ def __init__( else: self._plan = {} self._strategys = self._parse_strategys(msc_utils.copy_dict(strategys)) + self._training = training self._cache_processed = cache_processed self._options = options or {} self._debug_level = debug_level @@ -339,7 +343,7 @@ def setup(self) -> dict: """ self._tensor_cache = {} - self._enabled, self._is_training = True, True + self._enabled = True self._graphs, self._weights = [], {} self._graph_id, self._forward_cnt = 0, 0 self._processed_tensor = {} @@ -390,7 +394,11 @@ def _parse_strategys(self, strategy_list: List[dict]) -> Dict[str, ToolStrategy] if not method: method = msc_utils.get_registered_func(method_name) assert method, "Can not find method with " + str(method_name) - tensor_types = strategy.pop("tensor_types") if "tensor_types" in strategy else ["all"] + tensor_types = ( + strategy.pop("tensor_types") + if "tensor_types" in strategy + else ["input", "output", "weight"] + ) if "op_types" in strategy: op_types = strategy.pop("op_types") marks = [("{}.{}".format(s, t), t) for s, t in product(op_types, tensor_types)] @@ -399,9 +407,9 @@ def _parse_strategys(self, strategy_list: List[dict]) -> Dict[str, ToolStrategy] marks = [("{}.{}".format(s, t), t) for s, t in product(op_names, tensor_types)] elif "tensor_names" in strategy: tensor_names = strategy.pop("tensor_names") - marks = [(n, "all") for n in tensor_names] + marks = [(n, "tensor") for n in tensor_names] else: - marks = [("default", "all")] + marks = [("default", t) for t in ["input", "output", "weight"]] stages = strategy.pop("stages") if "stages" in strategy else ["default"] for mark, t_type in marks: if mark not in strategys: @@ -415,26 +423,26 @@ def _parse_strategys(self, strategy_list: List[dict]) -> Dict[str, ToolStrategy] def reset( self, graphs: List[MSCGraph], - weights: List[Dict[str, tvm.nd.array]], + weights: Dict[str, tvm.nd.array], cache_dir: msc_utils.MSCDirectory = None, - ) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]: + ) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: """Reset the tool with graphs and weights Parameters ---------- graphs: list The msc graphs. - weights: list> - The weights + weights: dict + The weights. cache_dir: MSCDirectory - cache path for save/load info + cache path for save/load info. Returns ------- graphs: list The msc graphs. - weights: list> - The weights + weights: dict + The weights. """ self._forward_cnt = 0 @@ -445,36 +453,33 @@ def reset( cache_info = {} if self.tool_type() in cache_info: self.load_cache(cache_dir, cache_info[self.tool_type()]) - self._graphs, weights = self._reset(graphs, weights) - self._weights = {} - for sub_weights in weights: - self._weights.update(sub_weights) + self._graphs, self._weights = self._reset(graphs, weights) self._logger.debug( "%s reset %d graphs, %d weights", self.tool_type(), len(self._graphs), len(self._weights), ) - return self._graphs, weights + return self._graphs, self._weights def _reset( - self, graphs: List[MSCGraph], weights: List[Dict[str, tvm.nd.array]] - ) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]: + self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.array] + ) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: """Reset the tool Parameters ---------- graphs: list The msc graphs. - weights: list> - The weights + weights: dict + The weights. Returns ------- graphs: list The msc graphs. - weights: list> - The weights + weights: dict + The weights. """ return graphs, weights @@ -889,12 +894,12 @@ def disable(self): def train(self): """Set the tool to train mode""" - self._is_training = True + self._training = True def eval(self): """Set the tool to eval mode""" - self._is_training = False + self._training = False def to_tensor_id(self, name: str, consumer: str) -> str: """Concat name to unique id @@ -1210,28 +1215,18 @@ def _get_tensor_strategys(self, name: str, consumer: str) -> List[ToolStrategy]: if mark not in self._tensor_cache.get(tensor_id, {}): if self.is_weight(name): consumer = self.find_node(consumer) - name_refs = [ - consumer.name + ".weight", - consumer.optype + ".weight", - consumer.optype + ".all", - ] + name_refs = [consumer.name + ".weight", consumer.optype + ".weight"] elif consumer == "exit": producer = self.find_producer(name) - name_refs = [ - producer.name + ".output", - producer.optype + ".output", - producer.optype + ".all", - ] + name_refs = [producer.name + ".output", producer.optype + ".output"] else: consumer = self.find_node(consumer) producer = self.find_producer(name) name_refs = [ producer.name + ".output", producer.optype + ".output", - producer.optype + ".all", consumer.name + ".input", consumer.optype + ".input", - consumer.optype + ".all", ] strategys = [] tensor_strategy = self._strategys.get(tensor_id) @@ -1303,25 +1298,23 @@ def setup(self) -> dict: return super().setup() def _reset( - self, graphs: List[MSCGraph], weights: List[Dict[str, tvm.nd.array]] - ) -> Tuple[List[MSCGraph], List[Dict[str, tvm.nd.array]]]: + self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.array] + ) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: """Reset the tool Parameters ---------- graphs: list The msc graphs. - weights: list> - The weights - as_cache: bool - Whether the graphs and weights are loaded from cache + weights: dict + The weights. Returns ------- graphs: list The msc graphs. - weights: list> - The weights + weights: dict + The weights. """ graphs, weights = super()._reset(graphs, weights) diff --git a/python/tvm/contrib/msc/core/tools/track/method.py b/python/tvm/contrib/msc/core/tools/track/method.py index aaa07b381226..a86a6af881f3 100644 --- a/python/tvm/contrib/msc/core/tools/track/method.py +++ b/python/tvm/contrib/msc/core/tools/track/method.py @@ -49,12 +49,8 @@ def save_compared( The name of the tensor. consumer: str The name of the consumer. - stage: str - The current stage of tool. compare_to: dict The compare config - dataset: MSCDirectory - The root dir Returns ------- @@ -76,12 +72,12 @@ def save_compared( continue golden = tracker._loaders[stage].load_data(name, tracker._forward_cnt) report = msc_utils.compare_arrays({name: golden}, {name: data}) - diff_msg = "{}diff to {} -> {}".format( - tracker.msg_mark(), stage, report["info"][name] + diff_msg = "{}{} to {} -> {}".format( + tracker.msg_mark(), name, stage, report["info"][name] ) if report["passed"] == 0: tracker._logger.info(diff_msg) - elif tracker.on_debug(3): + elif tracker.on_debug(): tracker._logger.debug(diff_msg) diffs[stage] = { "pass": report["passed"] == 1, diff --git a/python/tvm/contrib/msc/core/tools/track/tracker.py b/python/tvm/contrib/msc/core/tools/track/tracker.py index 442ac6f508e0..e43a390e850f 100644 --- a/python/tvm/contrib/msc/core/tools/track/tracker.py +++ b/python/tvm/contrib/msc/core/tools/track/tracker.py @@ -33,9 +33,16 @@ def setup(self) -> dict: The setup info. """ + # filter plan + def _filter_info(info: dict) -> dict: + return {k: v for k, v in info.items() if k != self._stage} + + self._plan = {k: _filter_info(v) for k, v in self._plan.items()} data_folder = msc_utils.get_dataset_dir().create_dir("Track") self._loaders = {} for folder in data_folder.listdir(): + if folder == self._stage: + continue if msc_utils.is_simple_dataset(data_folder.relpath(folder)): self._loaders[folder] = msc_utils.SimpleDataLoader(data_folder.relpath(folder)) self._saver = msc_utils.SimpleDataSaver(data_folder.relpath(self._stage)) @@ -163,12 +170,9 @@ def _track_tensor( if self._stage in self._plan.get(name, {}): return tensor - if name not in self._plan: - self._plan[name] = {} - plan = {} + plan = self._plan.setdefault(name, {}).setdefault(self._stage, {}) for strategy in strategys: plan.update(strategy(self, tensor, name, consumer)) - self._plan[name][self._stage] = plan return tensor @classmethod diff --git a/python/tvm/contrib/msc/core/transform/pattern.py b/python/tvm/contrib/msc/core/transform/pattern.py index 6d7ac5ba1311..fdc6a628310d 100644 --- a/python/tvm/contrib/msc/core/transform/pattern.py +++ b/python/tvm/contrib/msc/core/transform/pattern.py @@ -17,7 +17,7 @@ # pylint: disable=unused-argument """tvm.contrib.msc.core.transform.pattern""" -from typing import Mapping, Tuple, Dict +from typing import Mapping, Tuple, Dict, List from functools import partial import tvm @@ -29,10 +29,14 @@ from tvm.relay.op.contrib.register import register_pattern_table from tvm.contrib.msc.core.utils.namespace import MSCMap, MSCKey from tvm.contrib.msc.core import utils as msc_utils +from tvm.contrib.msc.core import _ffi_api def msc_attrs_getter( - annotated_expr: Dict[str, tvm.relax.Expr], anchor: str = "out" + annotated_expr: Dict[str, tvm.relax.Expr], + anchor: str = "out", + output: str = None, + inputs: List[str] = None, ) -> Dict[str, str]: """Get attributes for fused pattern @@ -49,6 +53,8 @@ def msc_attrs_getter( The extra attributes for msc. """ + attrs = {} + # get name fused_cnt = MSCMap.get(MSCKey.FUSED_CNT, 0) unique_name = "msc_fused_" + str(fused_cnt) if anchor in annotated_expr: @@ -56,7 +62,23 @@ def msc_attrs_getter( if name: unique_name = name MSCMap.set(MSCKey.FUSED_CNT, fused_cnt + 1) - return {"unique_name": unique_name} + attrs[_ffi_api.ToAttrKey("unique")] = unique_name + # get output layout + output = output or anchor + if output in annotated_expr: + attrs[_ffi_api.ToAttrKey("layout")] = msc_utils.get_expr_layout(annotated_expr[output]) + if inputs: + layouts = {} + for i in inputs: + if i not in annotated_expr: + continue + in_name = msc_utils.get_expr_name(annotated_expr[i]) + if not in_name: + continue + layouts[in_name] = msc_utils.get_expr_layout(annotated_expr[i]) + if layouts: + attrs[_ffi_api.ToAttrKey("input_layouts")] = layouts + return attrs def make_relax_conv_bias_pattern( @@ -87,7 +109,14 @@ def make_relax_conv_bias_pattern( shape = relax_pattern.wildcard() reshape = relax_pattern.is_op("relax.reshape")(bias, shape) out = relax_pattern.is_op("relax.add")(conv, reshape) - annotations = {"conv": conv, "bias": bias, "reshape": reshape, "out": out} + annotations = { + "data": data, + "weight": weight, + "conv": conv, + "bias": bias, + "reshape": reshape, + "out": out, + } return out, annotations @@ -126,7 +155,7 @@ def make_relax_linear_pattern() -> ( weight = relax_pattern.is_const() permute = relax_pattern.is_op("relax.permute_dims")(weight) out = relax_pattern.is_op("relax.matmul")(data, permute) - annotations = {"weight": weight, "permute": permute, "linear": out} + annotations = {"data": data, "weight": weight, "permute": permute, "linear": out} return out, annotations @@ -203,7 +232,7 @@ def make_relax_embedding_pattern() -> ( data = relax_pattern.wildcard() astype = relax_pattern.is_op("relax.astype")(data) out = relax_pattern.is_op("relax.take")(weight, astype) - annotations = {"weight": weight, "astype": astype, "take": out} + annotations = {"data": data, "weight": weight, "astype": astype, "take": out} return out, annotations @@ -250,6 +279,7 @@ def make_relax_reshape_embedding_pattern() -> ( expand_shape = relax_pattern.wildcard() out = relax_pattern.is_op("relax.reshape")(take, expand_shape) annotations = { + "data": data, "weight": weight, "astype": astype, "reduce_in": reduce_in, @@ -301,7 +331,15 @@ def make_relax_attention_pattern() -> ( k_trans = relax_pattern.is_op("relax.permute_dims")(weight_k) v_trans = relax_pattern.is_op("relax.permute_dims")(weight_v) out = relax_pattern.is_op("relax.nn.attention")(q_trans, k_trans, v_trans) - annotations = {"q_trans": q_trans, "k_trans": k_trans, "v_trans": v_trans, "attention": out} + annotations = { + "weight_q": weight_q, + "weight_k": weight_k, + "weight_v": weight_v, + "q_trans": q_trans, + "k_trans": k_trans, + "v_trans": v_trans, + "attention": out, + } return out, annotations @@ -341,7 +379,16 @@ def make_relax_mask_attention_pattern() -> ( k_trans = relax_pattern.is_op("relax.permute_dims")(weight_k) v_trans = relax_pattern.is_op("relax.permute_dims")(weight_v) out = relax_pattern.is_op("relax.nn.attention_bias")(q_trans, k_trans, v_trans, mask) - annotations = {"q_trans": q_trans, "k_trans": k_trans, "v_trans": v_trans, "attention": out} + annotations = { + "weight_q": weight_q, + "weight_k": weight_k, + "weight_v": weight_v, + "mask": mask, + "q_trans": q_trans, + "k_trans": k_trans, + "v_trans": v_trans, + "attention": out, + } return out, annotations @@ -421,7 +468,7 @@ def make_opt_relax_linear_pattern() -> ( data = relax_pattern.wildcard() weight = relax_pattern.is_const() out = relax_pattern.is_op("relax.matmul")(data, weight) - annotations = {"weight": weight, "linear": out} + annotations = {"data": data, "weight": weight, "linear": out} return out, annotations @@ -459,7 +506,7 @@ def make_opt_relax_linear_bias_pattern() -> ( linear = relax_pattern.is_op("relax.matmul")(data, weight) bias = relax_pattern.is_const() out = relax_pattern.is_op("relax.add")(linear, bias) - annotations = {"weight": weight, "bias": bias, "linear": linear, "out": out} + annotations = {"data": data, "weight": weight, "bias": bias, "linear": linear, "out": out} return out, annotations @@ -488,7 +535,7 @@ def _check_opt_relax_linear_bias(context: PatternCheckContext) -> bool: "relax.nn.conv1d", ), _check_opt_relax_conv_bias, - partial(msc_attrs_getter, anchor="conv"), + partial(msc_attrs_getter, anchor="conv", inputs=["data", "weight", "bias"]), ), ( "msc.conv2d_bias", @@ -496,19 +543,19 @@ def _check_opt_relax_linear_bias(context: PatternCheckContext) -> bool: "relax.nn.conv2d", ), _check_opt_relax_conv_bias, - partial(msc_attrs_getter, anchor="conv"), + partial(msc_attrs_getter, anchor="conv", inputs=["data", "weight", "bias"]), ), ( "msc.linear", *make_opt_relax_linear_pattern(), _check_opt_relax_linear, - partial(msc_attrs_getter, anchor="linear"), + partial(msc_attrs_getter, anchor="linear", inputs=["data", "weight"]), ), ( "msc.linear_bias", *make_opt_relax_linear_bias_pattern(), _check_opt_relax_linear_bias, - partial(msc_attrs_getter, anchor="linear"), + partial(msc_attrs_getter, anchor="linear", inputs=["data", "weight", "bias"]), ), ( "msc.conv1d_bias", @@ -516,7 +563,7 @@ def _check_opt_relax_linear_bias(context: PatternCheckContext) -> bool: "relax.nn.conv1d", ), _check_relax_conv_bias, - partial(msc_attrs_getter, anchor="conv"), + partial(msc_attrs_getter, anchor="conv", inputs=["data", "weight", "bias"]), ), ( "msc.conv2d_bias", @@ -524,43 +571,49 @@ def _check_opt_relax_linear_bias(context: PatternCheckContext) -> bool: "relax.nn.conv2d", ), _check_relax_conv_bias, - partial(msc_attrs_getter, anchor="conv"), + partial(msc_attrs_getter, anchor="conv", inputs=["data", "weight", "bias"]), ), ( "msc.linear", *make_relax_linear_pattern(), _check_relax_linear, - partial(msc_attrs_getter, anchor="linear"), + partial(msc_attrs_getter, anchor="linear", inputs=["data", "weight"]), ), ( "msc.linear_bias", *make_relax_linear_bias_pattern(), _check_relax_linear_bias, - partial(msc_attrs_getter, anchor="linear"), + partial(msc_attrs_getter, anchor="linear", inputs=["data", "weight", "bias"]), ), ( "msc.embedding", *make_relax_embedding_pattern(), _check_relax_embedding, - partial(msc_attrs_getter, anchor="take"), + partial(msc_attrs_getter, anchor="take", inputs=["data", "weight"]), ), ( "msc.embedding", *make_relax_reshape_embedding_pattern(), _check_relax_reshape_embedding, - partial(msc_attrs_getter, anchor="take"), + partial(msc_attrs_getter, anchor="take", output="out", inputs=["data", "weight"]), ), ( "msc.attention", *make_relax_attention_pattern(), _check_relax_attention, - partial(msc_attrs_getter, anchor="attention"), + partial( + msc_attrs_getter, anchor="attention", inputs=["weight_q", "weight_k", "weight_v"] + ), ), ( "msc.attention", *make_relax_mask_attention_pattern(), _check_relax_mask_attention, - partial(msc_attrs_getter, anchor="attention"), + partial( + msc_attrs_getter, + anchor="attention", + inputs=["weight_q", "weight_k", "weight_v", "mask"], + ), ), ] ) diff --git a/python/tvm/contrib/msc/core/transform/transform.py b/python/tvm/contrib/msc/core/transform/transform.py index 8bd4ca952177..ddcfffc210fa 100644 --- a/python/tvm/contrib/msc/core/transform/transform.py +++ b/python/tvm/contrib/msc/core/transform/transform.py @@ -86,7 +86,7 @@ def SetExprLayout(allow_missing: bool = True, entry_name: str = "main") -> tvm.i return relax_api.SetExprLayout(allow_missing, entry_name) # type: ignore -def BindShape(entry_name: str = "main") -> tvm.ir.transform.Pass: +def InlineParams(entry_name: str = "main") -> tvm.ir.transform.Pass: """Bind ShapeExpr to reshape Parameters @@ -99,7 +99,7 @@ def BindShape(entry_name: str = "main") -> tvm.ir.transform.Pass: ret: tvm.ir.transform.Pass """ - return relax_api.BindShape(entry_name) # type: ignore + return relax_api.InlineParams(entry_name) # type: ignore def FuseTuple(target, entry_name: str = "main") -> tvm.ir.transform.Pass: diff --git a/python/tvm/contrib/msc/core/utils/__init__.py b/python/tvm/contrib/msc/core/utils/__init__.py index a76659609d5e..d413ce1dbedd 100644 --- a/python/tvm/contrib/msc/core/utils/__init__.py +++ b/python/tvm/contrib/msc/core/utils/__init__.py @@ -24,3 +24,4 @@ from .dataset import * from .log import * from .message import * +from .arguments import * diff --git a/python/tvm/contrib/msc/core/utils/arguments.py b/python/tvm/contrib/msc/core/utils/arguments.py new file mode 100644 index 000000000000..dba54da3a4e8 --- /dev/null +++ b/python/tvm/contrib/msc/core/utils/arguments.py @@ -0,0 +1,228 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""tvm.contrib.msc.core.utils.arguments""" + +import os +import json +import copy +import numpy as np +from .info import MSCArray + + +def load_dict(str_dict: str, flavor: str = "json") -> dict: + """Load the string/file to dict. + + Parameters + ---------- + str_dict: string + The file_path or string object. + flavor: str + The flavor for load. + + Returns + ------- + dict_obj: dict + The loaded dict. + """ + + if isinstance(str_dict, str) and os.path.isfile(str_dict): + with open(str_dict, "r") as f: + dict_obj = json.load(f) + elif isinstance(str_dict, str): + dict_obj = json.loads(str_dict) + elif isinstance(str_dict, dict): + dict_obj = copy_dict(str_dict) + else: + raise Exception("Unexpected str_dict {}({})".format(str_dict, type(str_dict))) + assert flavor == "json", "Unexpected flavor for load_dict: " + str(flavor) + return dict_obj + + +def update_dict(src_dict: dict, new_dict: dict, soft_update: bool = True) -> dict: + """Update src_dict with new_dict. + + Parameters + ---------- + src_dict: dict + The source dict. + new_dict: dict + The new dict. + soft_update: bool + Whether to update the source dict, False to force update. + + Returns + ------- + dict_obj: dict + The updated dict. + """ + + assert isinstance(src_dict, dict) and isinstance( + new_dict, dict + ), "update_dict only support dict, get src {} and new {}".format(type(src_dict), type(new_dict)) + for k, v in new_dict.items(): + if isinstance(v, dict): + v = update_dict(src_dict.get(k, {}), v, soft_update) + src_dict[k] = v + elif not soft_update or k not in src_dict: + src_dict[k] = v + return src_dict + + +def dump_dict(dict_obj: dict, flavor: str = "dmlc") -> str: + """Dump the config to string. + + Parameters + ---------- + src_dict: dict + The source dict. + flavor: str + The flavor for dumps. + + Returns + ------- + str_dict: string + The dumped string. + """ + + if not dict_obj: + return "" + if flavor == "dmlc": + return json.dumps({k: int(v) if isinstance(v, bool) else v for k, v in dict_obj.items()}) + if flavor.startswith("table:"): + + def _get_lines(value, indent=2): + max_size = int(flavor.split(":")[1]) - indent - 2 + lines = [] + for k, v in value.items(): + if v is None: + continue + if isinstance(v, (dict, tuple, list)) and not v: + continue + if isinstance(v, dict) and len(str(k) + str(v)) > max_size: + lines.append("{}{}:".format(indent * " ", k)) + lines.extend(_get_lines(v, indent + 2)) + elif isinstance(v, (tuple, list)) and len(str(k) + str(v)) > max_size: + if all(isinstance(e, (int, float)) for e in v): + lines.append("{}{}: {}".format(indent * " ", k, MSCArray(v).abstract())) + else: + lines.append("{}{}:".format(indent * " ", k)) + lines.extend( + [ + "{}<{}>{}".format((indent + 2) * " ", idx, ele) + for idx, ele in enumerate(v) + ] + ) + elif isinstance(v, bool): + lines.append("{}{}: {}".format(indent * " ", k, "true" if v else "false")) + elif isinstance(v, np.ndarray): + lines.append("{}{}: {}".format(indent * " ", k, MSCArray(v).abstract())) + else: + lines.append("{}{}: {}".format(indent * " ", k, v)) + return lines + + lines = _get_lines(dict_obj) or [" {}: {}".format(k, v) for k, v in dict_obj.items()] + return "\n".join(lines) + return json.dumps(dict_obj) + + +def dict_equal(dict_a: dict, dict_b: dict) -> bool: + """Check if two dicts are the same. + + Parameters + ---------- + dict_a: dict + The A dict. + dict_b: dict + The B dict. + + Returns + ------- + equal: bool + Whether two dicts are the same. + """ + + if not isinstance(dict_a, dict) or not isinstance(dict_b, dict): + return False + if dict_a.keys() != dict_b.keys(): + return False + for k, v in dict_a.items(): + if not isinstance(v, type(dict_b[k])): + return False + if isinstance(v, dict) and not dict_equal(v, dict_b[k]): + return False + if v != dict_b[k]: + return False + return True + + +def copy_dict(dict_obj: dict) -> dict: + """Deepcopy dict object + + Parameters + ---------- + dict_obj: dict + The source dict. + + Returns + ------- + dict_obj: dict + The copied dict. + """ + + if not dict_obj: + return {} + try: + return copy.deepcopy(dict_obj) + except: # pylint: disable=bare-except + new_dict = {} + for k, v in dict_obj.items(): + if isinstance(v, (list, tuple)): + new_dict[k] = [copy_dict(e) for e in v] + elif isinstance(v, dict): + new_dict[k] = copy_dict(v) + else: + new_dict[k] = v + return new_dict + + +def map_dict(dict_obj: dict, mapper: callable) -> dict: + """Apply mapper to dict object + + Parameters + ---------- + dict_obj: dict + The source dict. + mapper: callable + The mapper function. + + Returns + ------- + new_dict: dict + The mapped dict. + """ + + if not dict_obj: + return {} + new_dict = {} + for k, v in dict_obj.items(): + if isinstance(v, (tuple, list)): + new_dict[k] = [map_dict(e, mapper) if isinstance(e, dict) else e for e in v] + elif isinstance(v, dict): + new_dict[k] = map_dict(v, mapper) + else: + new_dict[k] = mapper(v) + return new_dict diff --git a/python/tvm/contrib/msc/core/utils/dataset.py b/python/tvm/contrib/msc/core/utils/dataset.py index a96369b320f1..8ca8d8ae1a0d 100644 --- a/python/tvm/contrib/msc/core/utils/dataset.py +++ b/python/tvm/contrib/msc/core/utils/dataset.py @@ -23,7 +23,7 @@ from typing import List, Union, Dict, Any import numpy as np -from .info import load_dict +from .arguments import load_dict class BaseDataLoader(object): @@ -433,6 +433,8 @@ def finalize(self): info = self._info["inputs"][name] f.write("{} {} {}\n".format(name, info.get("save_name", name), info["bytes"])) for name in self._output_names: + if name not in self._info["outputs"]: + continue info = self._info["outputs"][name] f.write("{} {} {}\n".format(name, info.get("save_name", name), info["bytes"])) @@ -501,6 +503,8 @@ def save_batch( def is_io_dataset(folder: str) -> bool: """Check if a folder is IO dataset""" + if not isinstance(folder, str): + return False if not os.path.isfile(os.path.join(folder, "datas_info.json")): return False data_info = load_dict(os.path.join(folder, "datas_info.json")) diff --git a/python/tvm/contrib/msc/core/utils/expr.py b/python/tvm/contrib/msc/core/utils/expr.py index 9158381eb9e1..fa9f339a7524 100644 --- a/python/tvm/contrib/msc/core/utils/expr.py +++ b/python/tvm/contrib/msc/core/utils/expr.py @@ -25,7 +25,7 @@ def get_expr_name(expr: relax.Expr) -> str: - """Get name hint ofr expr + """Get name hint for expr Parameters ---------- @@ -38,7 +38,7 @@ def get_expr_name(expr: relax.Expr) -> str: The name_hint of expr """ - name = _ffi_api.SpanGetAttr(expr.span, "name") + name = _ffi_api.SpanGetAttr(expr.span, _ffi_api.ToAttrKey("name")) if not name and isinstance(expr, relax.Var): return expr.name_hint return name @@ -60,10 +60,27 @@ def set_expr_name(expr: relax.Expr, name: str): The expr with name. """ - expr.span = _ffi_api.SpanSetAttr(expr.span, "name", name) + expr.span = _ffi_api.SpanSetAttr(expr.span, _ffi_api.ToAttrKey("name"), name) return expr +def get_expr_layout(expr: relax.Expr) -> str: + """Get layout for expr + + Parameters + ---------- + expr: Expr + The Expr of relax. + + Returns + ------- + layout: str + The layout of expr + """ + + return _ffi_api.SpanGetAttr(expr.span, _ffi_api.ToAttrKey("layout")) + + def get_span_attrs(mod: tvm.IRModule) -> dict: """Extract the span attributes from relax.Function. diff --git a/python/tvm/contrib/msc/core/utils/file.py b/python/tvm/contrib/msc/core/utils/file.py index 5a6f6c0ff649..49912b4d041b 100644 --- a/python/tvm/contrib/msc/core/utils/file.py +++ b/python/tvm/contrib/msc/core/utils/file.py @@ -20,6 +20,7 @@ import shutil import tempfile import types +import subprocess from functools import partial from typing import List, Any, Union from importlib.machinery import SourceFileLoader @@ -28,6 +29,31 @@ from .register import get_registered_func +def is_callable(name: str, framework: str = MSCFramework.MSC) -> bool: + """Check if name is callable. + + Parameters + ---------- + name: string + The name of the registered func or path:f_name str. + framework: string + Should be from MSCFramework. + + Returns + ------- + is_callable: bool + Whether the name is callable + """ + + func = get_registered_func(name, framework) + if func: + return True + if ".py:" in name: + path, _ = name.split(":") + return os.path.isfile(path) + return False + + def load_callable(name: str, framework: str = MSCFramework.MSC) -> callable: """Load a callable object. @@ -358,6 +384,35 @@ def to_abs_path(path: str, root_dir: MSCDirectory = None, keep_history: bool = T return root_dir.relpath(path, keep_history) +def pack_folder(path: str, style="tar"): + """Pack the folder + + Parameters + ---------- + path: str + The path of the folder. + style: str + The pack style. + + Returns + ------- + pack_path: str + The packed path. + """ + + root = os.path.dirname(path) + if style == "tar": + cmd = "tar --exculde={0}.tar.gz -zcvf {0}.tar.gz {0} && rm -rf {0}".format(path) + else: + raise NotImplementedError("Pack style {} is not supported".format(style)) + if root: + with msc_dir(root): + retcode = subprocess.call(cmd, shell=True) + else: + retcode = subprocess.call(cmd, shell=True) + assert retcode == 0, "Failed to pack the folder {}({}): {}".format(path, style, retcode) + + get_build_dir = partial(get_workspace_subdir, name="Build") get_cache_dir = partial(get_workspace_subdir, name="Cache") get_config_dir = partial(get_workspace_subdir, name="Config") diff --git a/python/tvm/contrib/msc/core/utils/info.py b/python/tvm/contrib/msc/core/utils/info.py index d1b5cd1a2644..49d2bdd96a9b 100644 --- a/python/tvm/contrib/msc/core/utils/info.py +++ b/python/tvm/contrib/msc/core/utils/info.py @@ -16,9 +16,6 @@ # under the License. """tvm.contrib.msc.core.utils.info""" -import os -import json -import copy from typing import List, Tuple, Dict, Any, Union from distutils.version import LooseVersion import numpy as np @@ -38,26 +35,34 @@ class MSCArray(object): """ def __init__(self, data: Any): - self._type, self._data = self._analysis(data) + self._type, self._device, self._data = self._analysis(data) def __str__(self): return "<{}>{}".format(self._type, self.abstract()) - def _analysis(self, data: Any) -> Tuple[str, np.ndarray]: + def _analysis(self, data: Any) -> Tuple[str, str, np.ndarray]: if isinstance(data, (list, tuple)) and all(isinstance(d, (int, float)) for d in data): - return "list", np.array(data) + return "list", "cpu", np.array(data) if isinstance(data, np.ndarray): - return "np", data + return "np", "cpu", data if isinstance(data, tvm.runtime.NDArray): - return "tvm", data.asnumpy() + device = tvm.runtime.Device.MASK2STR[data.device.device_type] + if data.device.device_id: + device += ":{}".format(data.device.device_id) + return "tvm", device, data.asnumpy() if isinstance(data, tvm.relax.Var): shape = [int(s) for s in data.struct_info.shape] - return "var", np.zeros(shape, dtype=data.struct_info.dtype) + return "var", "cpu", np.zeros(shape, dtype=data.struct_info.dtype) try: import torch # pylint: disable=import-outside-toplevel if isinstance(data, torch.Tensor): - return "torch", data.detach().cpu().numpy() + ref_dev = data.device + if ref_dev.index: + device = "{}:{}".format(ref_dev.type, ref_dev.index) + else: + device = ref_dev.type + return "torch", device, data.detach().cpu().numpy() except: # pylint: disable=bare-except pass @@ -74,6 +79,38 @@ def abstract(self) -> str: self._data.sum() / self._data.size, ) + def cast(self, framework: str, device: str = None) -> Any: + """Cast np.ndarray to array like object + + Parameters + ---------- + framework: str + The target framework. + device: str + The device for tensor. + + Returns + ------- + output: + The output as framework tensor. + """ + + device = device or self._device + if framework == MSCFramework.TORCH: + import torch # pylint: disable=import-outside-toplevel + + return torch.from_numpy(self._data).to(torch.device(device)) + if framework == MSCFramework.TVM: + if device.startswith("cpu"): + t_device = tvm.cpu() + elif device.startswith("cuda"): + dev_id = int(device.split(":")[1]) if ":" in device else 0 + t_device = tvm.cuda(dev_id) + else: + raise NotImplementedError("device {} is not supported for tvm") + return tvm.nd.array(self._data, device=t_device) + return self._data + @classmethod def is_array(cls, data: Any) -> bool: """Check if the data is array like @@ -108,27 +145,37 @@ def is_array(cls, data: Any) -> bool: def type(self): return self._type + @property + def device(self): + return self._device + @property def data(self): return self._data -def cast_array(data: Any) -> np.ndarray: +def cast_array(data: Any, framework: str = None, device: str = None) -> Any: """Cast array like object to np.ndarray Parameters ---------- data: array_like: np.ndarray| torch.Tensor| tvm.ndarray| ... The data object. + framework: str + The target framework. + device: str + The device for tensor. Returns ------- output: np.ndarray - The output as numpy array. + The output as numpy array or framework tensor(if given). """ assert MSCArray.is_array(data), "{} is not array like".format(data) - return MSCArray(data).data + if not framework: + return MSCArray(data).data + return MSCArray(data).cast(framework, device) def inspect_array(data: Any, as_str: bool = True) -> Union[Dict[str, Any], str]: @@ -213,173 +260,6 @@ def compare_arrays( return report -def load_dict(str_dict: str, flavor: str = "json") -> dict: - """Load the string/file to dict. - - Parameters - ---------- - str_dict: string - The file_path or string object. - flavor: str - The flavor for load. - - Returns - ------- - dict_obj: dict - The loaded dict. - """ - - if isinstance(str_dict, str) and os.path.isfile(str_dict): - with open(str_dict, "r") as f: - dict_obj = json.load(f) - elif isinstance(str_dict, str): - dict_obj = json.loads(str_dict) - elif isinstance(str_dict, dict): - dict_obj = copy_dict(str_dict) - else: - raise Exception("Unexpected str_dict {}({})".format(str_dict, type(str_dict))) - assert flavor == "json", "Unexpected flavor for load_dict: " + str(flavor) - return dict_obj - - -def update_dict( - src_dict: dict, new_dict: dict, recursive: bool = True, soft_update: bool = True -) -> dict: - """Update src_dict with new_dict. - - Parameters - ---------- - src_dict: dict - The source dict. - new_dict: dict - The new dict. - recursive: bool - Whether to update the dict recursive. - soft_update: bool - Whether to update the source dict, False to force update. - - Returns - ------- - dict_obj: dict - The updated dict. - """ - - assert isinstance(src_dict, dict) and isinstance( - new_dict, dict - ), "update_dict only support dict, get src {} and new {}".format(type(src_dict), type(new_dict)) - for k, v in new_dict.items(): - if isinstance(v, dict): - v = update_dict(src_dict.get(k, {}), v, recursive, soft_update) - src_dict[k] = v - elif not soft_update or k not in src_dict: - src_dict[k] = v - return src_dict - - -def dump_dict(dict_obj: dict, flavor: str = "dmlc") -> str: - """Dump the config to string. - - Parameters - ---------- - src_dict: dict - The source dict. - flavor: str - The flavor for dumps. - - Returns - ------- - str_dict: string - The dumped string. - """ - - if not dict_obj: - return "" - if flavor == "dmlc": - return json.dumps({k: int(v) if isinstance(v, bool) else v for k, v in dict_obj.items()}) - if flavor.startswith("table:"): - - def _get_lines(value, indent=2): - max_size = int(flavor.split(":")[1]) - indent - 2 - lines = [] - for k, v in value.items(): - if isinstance(v, (dict, tuple, list)) and not v: - continue - if isinstance(v, dict) and len(str(k) + str(v)) > max_size: - lines.append("{}{}:".format(indent * " ", k)) - lines.extend(_get_lines(v, indent + 2)) - elif isinstance(v, (tuple, list)) and len(str(k) + str(v)) > max_size: - if all(isinstance(e, (int, float)) for e in v): - lines.append("{}{}: {}".format(indent * " ", k, MSCArray(v).abstract())) - else: - lines.append("{}{}:".format(indent * " ", k)) - lines.extend( - [ - "{}<{}>{}".format((indent + 2) * " ", idx, ele) - for idx, ele in enumerate(v) - ] - ) - elif isinstance(v, bool): - lines.append("{}{}: {}".format(indent * " ", k, "true" if v else "false")) - elif isinstance(v, np.ndarray): - lines.append("{}{}: {}".format(indent * " ", k, MSCArray(v).abstract())) - else: - lines.append("{}{}: {}".format(indent * " ", k, v)) - return lines - - lines = _get_lines(dict_obj) or [" {}: {}".format(k, v) for k, v in dict_obj.items()] - return "\n".join(lines) - return json.dumps(dict_obj) - - -def dict_equal(dict_a: dict, dict_b: dict) -> bool: - """Check if two dicts are the same. - - Parameters - ---------- - dict_a: dict - The A dict. - dict_b: dict - The B dict. - - Returns - ------- - equal: bool - Whether two dicts are the same. - """ - - if not isinstance(dict_a, dict) or not isinstance(dict_b, dict): - return False - if dict_a.keys() != dict_b.keys(): - return False - for k, v in dict_a.items(): - if not isinstance(v, type(dict_b[k])): - return False - if isinstance(v, dict) and not dict_equal(v, dict_b[k]): - return False - if v != dict_b[k]: - return False - return True - - -def copy_dict(dict_obj: dict) -> dict: - """Deepcopy dict object - - Parameters - ---------- - dict_obj: dict - The source dict. - - Returns - ------- - dict_obj: dict - The copied dict. - """ - - if not dict_obj: - return {} - return copy.deepcopy(dict_obj) - - def get_version(framework: str) -> List[int]: """Get the version list of framework. diff --git a/python/tvm/contrib/msc/core/utils/log.py b/python/tvm/contrib/msc/core/utils/log.py index f208ecde95db..916eb2468860 100644 --- a/python/tvm/contrib/msc/core/utils/log.py +++ b/python/tvm/contrib/msc/core/utils/log.py @@ -69,7 +69,7 @@ def create_file_logger(level: Union[str, int] = logging.INFO, path: str = None) """ if isinstance(level, str): - if level == "debug": + if level.startswith("debug"): level = logging.DEBUG elif level == "info": level = logging.INFO diff --git a/python/tvm/contrib/msc/core/utils/message.py b/python/tvm/contrib/msc/core/utils/message.py index 4f93d402a004..7ff0e187b05b 100644 --- a/python/tvm/contrib/msc/core/utils/message.py +++ b/python/tvm/contrib/msc/core/utils/message.py @@ -20,7 +20,7 @@ import logging from typing import List -from .info import dump_dict +from .arguments import dump_dict from .log import get_global_logger from .namespace import MSCMap, MSCKey diff --git a/python/tvm/contrib/msc/core/utils/register.py b/python/tvm/contrib/msc/core/utils/register.py index 50f0b8cd17b3..855c28f8b4b2 100644 --- a/python/tvm/contrib/msc/core/utils/register.py +++ b/python/tvm/contrib/msc/core/utils/register.py @@ -32,6 +32,7 @@ class MSCRegistery: GYM_AGENTS = "gym_agents" GYM_ENVS = "gym_envs" GYM_METHODS = "gym_agents_method" + RUNNER_HOOKS = "runner_hooks" @classmethod def register(cls, key: str, value: Any): @@ -157,9 +158,8 @@ def register_tool_method(method_cls: Any, method_style: str = "default"): """ tools_method = MSCRegistery.get(MSCRegistery.MSC_TOOLS_METHOD, {}) - assert hasattr(method_cls, "framework") and hasattr( - method_cls, "tool_type" - ), "framework and tool_type should be given to register tool method" + for key in ["framework", "tool_type"]: + assert hasattr(method_cls, key), "{} should be given to register tool method".format(key) if method_cls.framework() not in tools_method: tools_method[method_cls.framework()] = {} register_name = "{}.{}".format(method_cls.tool_type(), method_style) @@ -342,7 +342,7 @@ def register_gym_method(method: Any): def get_registered_gym_method(method_type: str) -> Any: - """Get the registered agent. + """Get the registered gym method. Parameters ---------- @@ -357,3 +357,36 @@ def get_registered_gym_method(method_type: str) -> Any: methods = MSCRegistery.get(MSCRegistery.GYM_METHODS, {}) return methods.get(method_type) + + +def register_runner_hook(hook: Any): + """Register a runner hook. + + Parameters + ---------- + hook: class + The hook class. + """ + + hooks = MSCRegistery.get(MSCRegistery.RUNNER_HOOKS, {}) + assert hasattr(hook, "name"), "name should be given to register hook" + hooks[hook.name()] = hook + MSCRegistery.register(MSCRegistery.RUNNER_HOOKS, hooks) + + +def get_registered_runner_hook(name: str) -> Any: + """Get the registered runner hook. + + Parameters + ---------- + name: str + The name hook. + + Returns + ------- + method: class + The method class. + """ + + hooks = MSCRegistery.get(MSCRegistery.RUNNER_HOOKS, {}) + return hooks.get(name) diff --git a/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py b/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py index 4617c5d351b6..c33fc89fa790 100644 --- a/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py +++ b/python/tvm/contrib/msc/framework/tensorflow/runtime/runner.py @@ -18,7 +18,7 @@ """tvm.contrib.msc.framework.tensorflow.runtime.runner""" import time -from typing import Dict, List, Union, Any +from typing import Dict, List, Union, Any, Tuple import numpy as np from tensorflow.python.client import device_lib @@ -27,7 +27,9 @@ import tvm from tvm.contrib.msc.core.ir import MSCGraph from tvm.contrib.msc.core.runtime import ModelRunner +from tvm.contrib.msc.core.utils.message import MSCStage from tvm.contrib.msc.core.utils.namespace import MSCFramework +from tvm.contrib.msc.framework.tensorflow.frontend import from_tensorflow from tvm.contrib.msc.framework.tensorflow.codegen import to_tensorflow from tvm.contrib.msc.framework.tensorflow import tf_v1 from tvm.contrib.msc.framework.tensorflow import tools @@ -85,20 +87,20 @@ def destory(self): super().destory() def _generate_model( - self, graphs: List[MSCGraph] = None, weights: List[Dict[str, tvm.nd.array]] = None - ) -> Any: + self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.array] + ) -> tf_v1.Graph: """Codegen the model according to framework Parameters ------- graphs: list The msc graphs. - weights: list> - The weights + weights: dict + The weights. Returns ------- - model: Any + model: tf_v1.Graph The runnable model """ @@ -109,17 +111,13 @@ def _generate_model( self._tf_outputs = super()._generate_model(graphs, weights) return self._tf_graph - def _to_runnable(self, model: Any, device: str, is_training: bool) -> Any: + def _build_runnable(self, model: Any) -> Any: """Build runnable object Parameters ------- model: Any The meta model. - device: str - The device for place model - is_training: bool - Whether to load model for training Returns ------- @@ -183,6 +181,72 @@ def codegen_func(self): def framework(self): return MSCFramework.TENSORFLOW + @classmethod + def load_native(cls, model: Any) -> Tuple[tf_v1.GraphDef, str, bool]: + """Load the native model + + Parameters + ------- + model: + The native model. + + Returns + ------- + model: tf_v1.GraphDef + The loaded native model. + device: str + The device of the model. + training: + Whether the model is for training. + """ + + if isinstance(model, tf_v1.GraphDef): + native_model = model + else: + raise NotImplementedError( + "Load native model {} with type {} is not supported".format(model, type(model)) + ) + device_protos = device_lib.list_local_devices() + if any(dev.device_type == "GPU" for dev in device_protos): + device = "cuda" + else: + device = "cpu" + return native_model, device, False + + @classmethod + def update_config(cls, stage: str, config: dict, model: Any = None) -> dict: + """Update the config for parse + + Parameters + ------- + stage: str + The stage to be updated + config: dict + The config for pipeline. + model: + The native model. + + Returns + ------- + config: dict + The updated config. + """ + + config = ModelRunner.update_config(stage, config, model) + if stage not in config: + return config + if stage == MSCStage.PARSE: + config["parse"]["parser"] = from_tensorflow + parse_config = config["parse"].get("parse_config", {}) + parse_config.update( + { + "shape_dict": {i[0]: i[1] for i in config["inputs"]}, + "outputs": config["outputs"], + } + ) + config["parse"]["parse_config"] = parse_config + return config + @classmethod def run_native( cls, @@ -192,7 +256,7 @@ def run_native( output_names: List[str], warm_up: int = 10, repeat: int = 0, - ) -> Dict[str, np.ndarray]: + ) -> Tuple[Dict[str, np.ndarray], float]: """Run the datas and get outputs Parameters @@ -210,11 +274,12 @@ def run_native( repeat: int The repeat num for profile. - Returns ------- outputs: dict The outputs in dict. + avg_time: float + The average time. """ feed_dict = {i_name + ":0": inputs[i_name] for i_name in input_names} diff --git a/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py b/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py index 585e1dc82584..d72b14cfd53e 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py +++ b/python/tvm/contrib/msc/framework/tensorrt/codegen/codegen.py @@ -18,7 +18,7 @@ import os import subprocess -from typing import Dict, Optional, Tuple, List, Union +from typing import Dict, Optional, List, Union import numpy as np import tvm @@ -33,7 +33,7 @@ def to_sub_tensorrt( graph: MSCGraph, - weights: Optional[Dict[str, tvm.nd.array]] = None, + weights: Dict[str, tvm.nd.array], codegen_config: Optional[Dict[str, str]] = None, print_config: Optional[Dict[str, str]] = None, build_folder: msc_utils.MSCDirectory = None, @@ -76,21 +76,20 @@ def to_sub_tensorrt( def _create_depends(folder: msc_utils.MSCDirectory) -> str: if weights: - # fill fake weights - runtime_weights = weights + # gather weights + engine_wts = {} for node in graph.get_nodes(): + for weight in node.get_weights().values(): + engine_wts[weight.name] = weights[weight.name] if node.optype in ("nn.conv2d", "msc.linear"): weight = node.weight_at("weight") bias = np.zeros([weight.dim_at("O")], dtype=weight.dtype_name) - runtime_weights[node.name + ".bias"] = bias + engine_wts[node.name + ".bias"] = bias # write weights file with open(folder.relpath(graph.name + ".wts"), "w") as f: - f.write("{}\n".format(len(runtime_weights))) - for name, data in runtime_weights.items(): - if isinstance(data, np.ndarray): - write_weight(name, data, f) - else: - write_weight(name, data.asnumpy(), f) + f.write("{}\n".format(len(engine_wts))) + for name, data in engine_wts.items(): + write_weight(name, msc_utils.cast_array(data), f) # save utils sources with folder.create_dir("utils") as utils_folder: for name, source in get_trt_sources().items(): @@ -115,12 +114,13 @@ def _build_engine(engine_name: str, folder: msc_utils.MSCDirectory) -> str: return folder.move(engine_name + ".trt", output_folder.relpath(engine_name + ".trt")) with build_folder as folder: + sub_folder = folder.create_dir(graph.name) codegen = CodeGen( graph, _ffi_api.GetTensorRTSources, codegen_config, print_config, - folder.create_dir(graph.name), + sub_folder, code_format="cpp", ) engine_file = codegen.load([], pre_load=_create_depends, post_load=_build_engine) @@ -133,7 +133,8 @@ def _build_engine(engine_name: str, folder: msc_utils.MSCDirectory) -> str: def to_tensorrt( mod: tvm.IRModule, - graph_infos: List[Tuple[str, MSCGraph, Dict[str, tvm.nd.array]]], + graphs: List[MSCGraph], + weights: Dict[str, tvm.nd.array], codegen_configs: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None, print_configs: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None, extra_options: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None, @@ -146,8 +147,10 @@ def to_tensorrt( ---------- mod: IRModule The IRModule of relax. - graph_infos: list - The translated graph. + graphs: list + The translated graphs. + weights: dict + The weights. codegen_configs: dict or list The config for codegen. print_configs: dict ot list @@ -167,14 +170,19 @@ def to_tensorrt( target_options = {} if not isinstance(codegen_configs, (list, tuple)): - codegen_configs = [codegen_configs] * len(graph_infos) + codegen_configs = [codegen_configs] * len(graphs) if not isinstance(print_configs, (list, tuple)): - print_configs = [print_configs] * len(graph_infos) + print_configs = [print_configs] * len(graphs) if not isinstance(extra_options, (list, tuple)): - extra_options = [extra_options] * len(graph_infos) - for idx, (graph, weights) in enumerate(graph_infos): + extra_options = [extra_options] * len(graphs) + for idx, graph in enumerate(graphs): options = to_sub_tensorrt( - graph, weights, codegen_configs[idx], print_configs[idx], build_folder, output_folder + graph, + weights, + codegen_configs[idx], + print_configs[idx], + build_folder, + output_folder, ) if extra_options[idx]: options.update(extra_options[idx]) diff --git a/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py b/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py index c66f8d145035..43e85b601579 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py +++ b/python/tvm/contrib/msc/framework/tensorrt/runtime/runner.py @@ -23,6 +23,7 @@ from tvm.contrib.msc.core.ir import MSCGraph from tvm.contrib.msc.core.runtime import BYOCRunner from tvm.contrib.msc.core.tools import ToolType +from tvm.contrib.msc.core.utils.message import MSCStage from tvm.contrib.msc.core.utils.namespace import MSCFramework from tvm.contrib.msc.core import utils as msc_utils from tvm.contrib.msc.framework.tensorrt.frontend import ( @@ -47,8 +48,14 @@ def setup(self) -> dict: if not self._device.startswith("cuda"): self._device = "cuda" + assert not self._training, "TensorRT only support eval" return super().setup() + def train(self): + """Change status to train""" + + raise Exception("TensorRT only support eval") + def apply_tool(self, tool_type: str, data_loader: Any = None) -> dict: """Execute tool and get plan @@ -66,22 +73,20 @@ def apply_tool(self, tool_type: str, data_loader: Any = None) -> dict: assert data_loader, "data_loader should be given to plan prune" for inputs in data_loader(): self.run(inputs) - self._generate_model() + self._generate_model(self._graphs, self._weights) quantizer.calibrate() assert quantizer.calibrated, "Failed to calibrate the tenosrrt quantizer" return super().apply_tool(tool_type, data_loader) - def _generate_model( - self, graphs: List[MSCGraph] = None, weights: List[Dict[str, tvm.nd.array]] = None - ) -> Any: + def _generate_model(self, graphs: List[MSCGraph], weights: Dict[str, tvm.nd.array]) -> Any: """Codegen the model according to framework Parameters ------- graphs: list The msc graphs. - weights: list> - The weights + weights: dict + The weights. Returns ------- @@ -125,3 +130,33 @@ def partition_func(self): @property def framework(self): return MSCFramework.TENSORRT + + @classmethod + def update_config(cls, stage: str, config: dict, model: Any = None) -> dict: + """Update the config for parse + + Parameters + ------- + stage: str + The stage to be updated + config: dict + The config for pipeline. + model: + The native model. + + Returns + ------- + config: dict + The updated config. + """ + + config = BYOCRunner.update_config(stage, config, model) + if stage not in config: + return config + if stage in (MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE): + run_config = config[stage].get("run_config", {}) + if "extra_option" not in run_config["generate_config"]: + run_config["generate_config"]["extra_option"] = {} + run_config["generate_config"]["extra_option"]["stage"] = stage + config[stage]["run_config"] = run_config + return config diff --git a/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py b/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py index 9fa1cf2142f4..effa86595dff 100644 --- a/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py +++ b/python/tvm/contrib/msc/framework/tensorrt/transform/pattern.py @@ -308,12 +308,13 @@ def get_patterns(target) -> List[Pattern]: patterns = [] # basic ops for op, in_types in basic_ops.items(): + inputs = ["input_" + str(i) for i in range(len(in_types))] patterns.append( ( target + "." + op, *basic_pattern("relax." + op, in_types), _basic_check, - partial(msc_pattern.msc_attrs_getter, anchor="out"), + partial(msc_pattern.msc_attrs_getter, anchor="out", inputs=inputs), ) ) # activation ops @@ -323,7 +324,7 @@ def get_patterns(target) -> List[Pattern]: target + "." + op, *basic_pattern("relax." + op, ["input"]), _basic_check, - partial(msc_pattern.msc_attrs_getter, anchor="out"), + partial(msc_pattern.msc_attrs_getter, anchor="out", inputs=["input_0"]), ) ) # reduce ops @@ -333,7 +334,7 @@ def get_patterns(target) -> List[Pattern]: target + "." + op, *basic_pattern("relax." + op, ["input"]), _basic_check, - partial(msc_pattern.msc_attrs_getter, anchor="out"), + partial(msc_pattern.msc_attrs_getter, anchor="out", inputs=["input_0"]), ) ) # unary ops @@ -343,7 +344,7 @@ def get_patterns(target) -> List[Pattern]: target + "." + op, *basic_pattern("relax." + op, ["input"]), _basic_check, - partial(msc_pattern.msc_attrs_getter, anchor="out"), + partial(msc_pattern.msc_attrs_getter, anchor="out", inputs=["input_0"]), ) ) # elemwise ops @@ -353,7 +354,7 @@ def get_patterns(target) -> List[Pattern]: target + "." + op, *elemwise_pattern("relax." + op), _elemwise_check, - partial(msc_pattern.msc_attrs_getter, anchor="out"), + partial(msc_pattern.msc_attrs_getter, anchor="out", inputs=["input_0", "input_1"]), ) ) # compare ops @@ -363,7 +364,7 @@ def get_patterns(target) -> List[Pattern]: target + "." + op, *elemwise_pattern("relax." + op), _compare_check, - partial(msc_pattern.msc_attrs_getter, anchor="out"), + partial(msc_pattern.msc_attrs_getter, anchor="out", inputs=["input_0", "input_1"]), ) ) @@ -374,25 +375,25 @@ def get_patterns(target) -> List[Pattern]: target + ".take", *basic_pattern("relax.take", ["input", "input"]), _take_check, - partial(msc_pattern.msc_attrs_getter, anchor="out"), + partial(msc_pattern.msc_attrs_getter, anchor="out", inputs=["input_0", "input_1"]), ), ( target + ".argmax", *argmaxmin_pattern("relax.argmax"), _argmaxmin_check, - partial(msc_pattern.msc_attrs_getter, anchor="out"), + partial(msc_pattern.msc_attrs_getter, anchor="out", inputs=["input"]), ), ( target + ".argmin", *argmaxmin_pattern("relax.argmin"), _argmaxmin_check, - partial(msc_pattern.msc_attrs_getter, anchor="out"), + partial(msc_pattern.msc_attrs_getter, anchor="out", inputs=["input"]), ), ( target + ".reshape", *basic_pattern("relax.reshape", ["input", "input"]), _reshape_check, - partial(msc_pattern.msc_attrs_getter, anchor="out"), + partial(msc_pattern.msc_attrs_getter, anchor="out", inputs=["input_0"]), ), ] ) @@ -403,10 +404,13 @@ def get_patterns(target) -> List[Pattern]: target + ".msc.conv2d_bias", *msc_pattern.make_opt_relax_conv_bias_pattern("relax.nn.conv2d"), wrap_basic_check(msc_pattern._check_opt_relax_conv_bias), - partial(msc_pattern.msc_attrs_getter, anchor="conv"), + partial( + msc_pattern.msc_attrs_getter, anchor="conv", inputs=["data", "weight", "bias"] + ), ), ] ) + return patterns diff --git a/python/tvm/contrib/msc/framework/torch/runtime/runner.py b/python/tvm/contrib/msc/framework/torch/runtime/runner.py index 9401e6047fa5..97dbdebcb3a9 100644 --- a/python/tvm/contrib/msc/framework/torch/runtime/runner.py +++ b/python/tvm/contrib/msc/framework/torch/runtime/runner.py @@ -17,6 +17,7 @@ # pylint: disable=unused-import """tvm.contrib.msc.framework.torch.runtime.runner""" +import os import time from typing import Dict, List, Union, Tuple, Any import numpy as np @@ -25,19 +26,26 @@ import tvm from tvm.contrib.msc.core.runtime import ModelRunner from tvm.contrib.msc.core.ir import MSCGraph +from tvm.contrib.msc.core.utils.message import MSCStage from tvm.contrib.msc.core.utils.namespace import MSCFramework +from tvm.contrib.msc.core import utils as msc_utils +from tvm.contrib.msc.framework.torch.frontend import from_torch from tvm.contrib.msc.framework.torch.codegen import to_torch from tvm.contrib.msc.framework.torch.frontend import set_weight_alias -from tvm.contrib.msc.core import utils as msc_utils from tvm.contrib.msc.framework.torch import tools class TorchRunner(ModelRunner): """Runner of Torch""" - def _translate(self) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: + def _translate(self, mod: tvm.IRModule) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: """Translate IRModule to MSCgraphs + Parameters + ------- + mod: tvm.IRModule + The module to be translated. + Returns ------- graph_list: list @@ -45,20 +53,16 @@ def _translate(self) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]: weights_list: list> The translated weights """ - graphs, weights = super()._translate() + graphs, weights = super()._translate(mod) return [set_weight_alias(graphs[0])], weights - def _to_runnable(self, model: Any, device: str, is_training: bool) -> Any: + def _build_runnable(self, model: Any) -> Any: """Build runnable object Parameters ------- model: Any The meta model. - device: str - The device for place model - is_training: bool - Whether to load model for training Returns ------- @@ -66,13 +70,13 @@ def _to_runnable(self, model: Any, device: str, is_training: bool) -> Any: The runnable """ - if device == "cpu": + if self._device.startswith("cpu"): pass - elif device.startswith("cuda"): - model = model.to(torch.device(device)) + elif self._device.startswith("cuda"): + model = model.to(torch.device(self._device)) else: - raise NotImplementedError("Unsupported device " + str(device)) - if is_training: + raise NotImplementedError("Unsupported device " + str(self._device)) + if self._training: model = model.train() else: model = model.eval() @@ -134,6 +138,80 @@ def codegen_func(self): def framework(self): return MSCFramework.TORCH + @classmethod + def load_native(cls, model: Any) -> Tuple[torch.nn.Module, str, bool]: + """Load the native model + + Parameters + ------- + model: + The native model. + + Returns + ------- + model: torch.nn.Module + The loaded native model. + device: str + The device of the model. + training: + Whether the model is for training. + """ + + if isinstance(model, dict) and "model" in model: + native_model = msc_utils.load_callable(model["model"]) + elif isinstance(model, torch.nn.Module): + native_model = model + else: + raise NotImplementedError( + "Load native model {} with type {} is not supported".format(model, type(model)) + ) + parameters = list(model.parameters()) + if parameters: + ref_device = parameters[0].device + if ref_device.index: + device = "{}:{}".format(ref_device.type, ref_device.index) + else: + device = ref_device.type + else: + device = "cpu" + return native_model, device, model.training + + @classmethod + def update_config(cls, stage: str, config: dict, model: Any = None) -> dict: + """Update the config for parse + + Parameters + ------- + stage: str + The stage to be updated + config: dict + The config for pipeline. + model: + The native model. + + Returns + ------- + config: dict + The updated config. + """ + + config = ModelRunner.update_config(stage, config, model) + if stage not in config: + return config + if stage == MSCStage.PARSE: + config["parse"]["parser"] = from_torch + parse_config = config["parse"].get("parse_config", {}) + parse_config.update( + { + "input_info": [ + [i[1], "float" if len(i) < 2 else i[2]] for i in config["inputs"] + ], + "input_names": [i[0] for i in config["inputs"]], + } + ) + config["parse"]["parse_config"] = parse_config + return config + @classmethod def run_native( cls, @@ -143,7 +221,7 @@ def run_native( output_names: List[str], warm_up: int = 10, repeat: int = 0, - ) -> Dict[str, np.ndarray]: + ) -> Tuple[Dict[str, np.ndarray], float]: """Run the datas and get outputs Parameters @@ -165,6 +243,8 @@ def run_native( ------- outputs: dict The outputs in dict. + avg_time: float + The average time. """ parameters = list(model.parameters()) @@ -172,9 +252,9 @@ def run_native( device = parameters[0].device else: device = torch.device("cpu") + torch_inputs = [torch.from_numpy(inputs[i_name]).to(device) for i_name in input_names] def _run_once(): - torch_inputs = [torch.from_numpy(inputs[i_name]).to(device) for i_name in input_names] return model(*torch_inputs) if repeat > 0: @@ -197,3 +277,25 @@ def _run_once(): o_name: msc_utils.cast_array(o_data) for o_name, o_data in zip(output_names, outputs) } return outputs, avg_time + + @classmethod + def dump_nativate(cls, model: torch.nn.Module, folder: msc_utils.MSCDirectory) -> str: + """Dump the nativate model + + Parameters + ------- + model: torch.nn.Module + The runnable model. + folder: MSCDirectory + The export folder. + + Returns + ------- + export_path: str + The exported path + """ + + graph_model = torch.fx.symbolic_trace(model) + exp_path = folder.create_dir("model") + graph_model.to_folder(exp_path.path, "native_model") + return {"model": exp_path.relpath("module.py") + ":native_model"} diff --git a/python/tvm/contrib/msc/framework/tvm/runtime/runner.py b/python/tvm/contrib/msc/framework/tvm/runtime/runner.py index 1e2b5257576b..690e146becfd 100644 --- a/python/tvm/contrib/msc/framework/tvm/runtime/runner.py +++ b/python/tvm/contrib/msc/framework/tvm/runtime/runner.py @@ -17,13 +17,16 @@ # pylint: disable=unused-import """tvm.contrib.msc.framework.runtime.tvm.runner""" -from typing import Dict, List, Union, Any +import time +from typing import Dict, List, Union, Any, Tuple import numpy as np import tvm from tvm.contrib.msc.core.runtime import ModelRunner from tvm.contrib.msc.core.tools import execute_step +from tvm.contrib.msc.core.utils.message import MSCStage from tvm.contrib.msc.core.utils.namespace import MSCFramework +from tvm.contrib.msc.core import utils as msc_utils from tvm.contrib.msc.framework.tvm.codegen import to_relax from tvm.contrib.msc.framework.tvm import tools @@ -33,6 +36,8 @@ class WrapRunnable(object): Parameters ------- + runner: ModelRunner + The runner context runnable: tvm.relax.VirtualMachine The virtual machine. entry: str @@ -52,17 +57,13 @@ def __call__(self, *inputs) -> List[tvm.nd.array]: class TVMRunner(ModelRunner): """Runner of Relax""" - def _to_runnable(self, model: Any, device: str, is_training: bool) -> Any: + def _build_runnable(self, model: Any) -> Any: """Build runnable object Parameters ------- model: Any The meta model. - device: str - The device for place model - is_training: bool - Whether to load model for training Returns ------- @@ -70,6 +71,10 @@ def _to_runnable(self, model: Any, device: str, is_training: bool) -> Any: The runnable """ + if self._training: + model = tvm.relax.transform.DecomposeOpsForTraining()(model) + else: + model = tvm.relax.transform.DecomposeOpsForInference()(model) if "builder" in self._generate_config: builder, build_config = self._generate_config["builder"] runnable = builder(model, **build_config) @@ -80,12 +85,12 @@ def _to_runnable(self, model: Any, device: str, is_training: bool) -> Any: ) else: model = tvm.relax.transform.LegalizeOps()(model) - if device == "cpu": + if self._device.startswith("cpu"): target = tvm.target.Target("llvm") with tvm.transform.PassContext(opt_level=3): relax_exec = tvm.relax.build(model, target) runnable = tvm.relax.VirtualMachine(relax_exec, tvm.cpu()) - elif device.startswith("cuda"): + elif self._device.startswith("cuda"): target = tvm.target.Target("cuda") with target: model = tvm.tir.transform.DefaultGPUSchedule()(model) @@ -93,7 +98,7 @@ def _to_runnable(self, model: Any, device: str, is_training: bool) -> Any: relax_exec = tvm.relax.build(model, target) runnable = tvm.relax.VirtualMachine(relax_exec, tvm.cuda()) else: - raise NotImplementedError("Unsupported device " + str(device)) + raise NotImplementedError("Unsupported device " + str(self._device)) return WrapRunnable(runnable) def _call_runnable( @@ -151,3 +156,137 @@ def codegen_func(self): @property def framework(self): return MSCFramework.TVM + + @classmethod + def load_native(cls, model: Any) -> tvm.IRModule: + """Load the native model + + Parameters + ------- + model: + The native model. + + Returns + ------- + model: tvm.IRModule + The loaded native model. + """ + + if isinstance(model, dict) and "model" in model: + with open(model["model"], "r") as f: + native_model = tvm.ir.load_json(f.read()) + elif isinstance(model, tvm.IRModule): + native_model = model + else: + raise NotImplementedError( + "Load native model {} with type {} is not supported".format(model, type(model)) + ) + if tvm.cuda().exist: + device = "cuda" + else: + device = "cpu" + return native_model, device, False + + @classmethod + def update_config(cls, stage: str, config: dict, model: Any = None) -> dict: + """Update the config for parse + + Parameters + ------- + stage: str + The stage to be updated + config: dict + The config for pipeline. + model: + The native model. + + Returns + ------- + config: dict + The updated config. + """ + + config = ModelRunner.update_config(stage, config, model) + if stage not in config: + return config + if stage == MSCStage.PARSE: + # pylint: disable=unused-argument + def passby(mod, *args, **kwargs): + return mod, None + + config["parse"]["parser"] = passby + return config + + @classmethod + def run_native( + cls, + model: tvm.IRModule, + inputs: Dict[str, np.ndarray], + input_names: List[str], + output_names: List[str], + warm_up: int = 10, + repeat: int = 0, + ) -> Tuple[Dict[str, np.ndarray], float]: + """Run the datas and get outputs + + Parameters + ------- + model: tvm.IRModule + The runnable model. + inputs: dict + The inputs in dict. + input_names: list + The input names. + output_names: list + The outut names. + warm_up: int + The warm_up num for profile. + repeat: int + The repeat num for profile. + + Returns + ------- + outputs: dict + The outputs in dict. + avg_time: float + The average time. + """ + + model = tvm.relax.transform.LegalizeOps()(model) + if tvm.cuda().exist: + target = tvm.target.Target("cuda") + with target: + model = tvm.tir.transform.DefaultGPUSchedule()(model) + with tvm.transform.PassContext(opt_level=3): + relax_exec = tvm.relax.build(model, target) + runnable = tvm.relax.VirtualMachine(relax_exec, tvm.cuda()) + tvm_inputs = [tvm.nd.array(inputs[i], device=tvm.cuda()) for i in input_names] + else: + target = tvm.target.Target("llvm") + with tvm.transform.PassContext(opt_level=3): + relax_exec = tvm.relax.build(model, target) + runnable = tvm.relax.VirtualMachine(relax_exec, tvm.cpu()) + tvm_inputs = [tvm.nd.array(inputs[i]) for i in input_names] + + def _run_once(): + return runnable["main"](*tvm_inputs) + + if repeat > 0: + for _ in range(warm_up): + _run_once() + start = time.time() + for _ in range(repeat): + outputs = _run_once() + avg_time = (time.time() - start) * 1000 / repeat + else: + outputs = _run_once() + avg_time = -1 + if isinstance(outputs, tvm.runtime.NDArray): + outputs = [outputs] + assert len(output_names) == len(outputs), "Outputs mismatch, {} with {}".format( + output_names, len(outputs) + ) + outputs = { + o_name: msc_utils.cast_array(o_data) for o_name, o_data in zip(output_names, outputs) + } + return outputs, avg_time diff --git a/python/tvm/contrib/msc/framework/tvm/tools/track/tracker.py b/python/tvm/contrib/msc/framework/tvm/tools/track/tracker.py index cf5ab49e8214..0054b7e77349 100644 --- a/python/tvm/contrib/msc/framework/tvm/tools/track/tracker.py +++ b/python/tvm/contrib/msc/framework/tvm/tools/track/tracker.py @@ -101,6 +101,9 @@ def _execute_after_forward( for data, name in zip(outputs[output_num:], self._track_names): consumer = self._track_tensors[name]["consumer"] strategys = self._get_tensor_strategys(name, consumer) + producer = self.find_producer(name) + if producer == "nn.batch_norm": + data = data[0] self._track_tensor(data, name, consumer, strategys) if output_num == 1: return super()._execute_after_forward(outputs[0]) @@ -136,7 +139,7 @@ def _process_tensor( """ if self.is_weight(name): - return self._track_tensor(self.get_data(name), name, consumer, strategys) + self._track_tensor(self.get_data(name), name, consumer, strategys) if name not in self._track_tensors: self._track_tensors[name] = {"consumer": consumer, "tensor": tensor} self._track_names.append(name) diff --git a/python/tvm/contrib/msc/pipeline/manager.py b/python/tvm/contrib/msc/pipeline/manager.py index 337503de5ba5..a8327a08cde3 100644 --- a/python/tvm/contrib/msc/pipeline/manager.py +++ b/python/tvm/contrib/msc/pipeline/manager.py @@ -25,9 +25,10 @@ import numpy as np import tvm +from tvm.contrib.msc.core import transform as msc_transform from tvm.contrib.msc.core.runtime import BaseRunner from tvm.contrib.msc.core.tools import ToolType -from tvm.contrib.msc.core.utils.namespace import MSCFramework, MSCMap +from tvm.contrib.msc.core.utils.namespace import MSCFramework, MSCMap, MSCKey from tvm.contrib.msc.core.utils.message import MSCStage from tvm.contrib.msc.core import utils as msc_utils from tvm.contrib.msc.core.gym.control import create_controller @@ -44,19 +45,28 @@ class BaseManager(object): The config for pipeline. """ - def __init__(self, model, config): - # check config - for stage in ["inputs", "outputs", "dataset", "prepare", "compile"]: + def __init__(self, model: Any, config: dict): + # check stage + for stage in ["inputs", "outputs", "dataset", MSCStage.PREPARE, MSCStage.COMPILE]: assert stage in config, "{} should be given to run the pipeline".format(stage) + MSCMap.reset() - self._model = model - self._workspace = msc_utils.set_workspace(config.get("workspace")) - log_path = config.get("log_path") or self._workspace.relpath("MSC_LOG", keep_history=False) - if config.get("debug_level", 0) > 0 and "verbose" not in config: - self._verbose = "debug" + self._model_type = config["model_type"] + self._model, self._device, self._training = self._get_runner_cls( + self._model_type + ).load_native(model) + use_cache = config.get("use_cache", True) + self._workspace = msc_utils.set_workspace(config.get("workspace"), use_cache) + self._verbose = config.get("verbose", "info") + if "logger" in config: + self._logger = config["logger"] + MSCMap.set(MSCKey.GLOBALE_LOGGER, self._logger) else: - self._verbose = config.get("verbose", "info") - self._logger = msc_utils.set_global_logger(self._verbose, log_path) + log_path = config.get("log_path") or self._workspace.relpath( + "MSC_LOG", keep_history=False + ) + self._logger = msc_utils.set_global_logger(self._verbose, log_path) + self._optimized, self._compiled = False, False msc_utils.time_stamp(MSCStage.SETUP) self._logger.info(msc_utils.msg_block("SETUP", self.setup(config))) @@ -74,18 +84,18 @@ def setup(self, config: dict) -> dict: The setup info. """ + self._meta_config = config + self._optimize_type = config.get(MSCStage.OPTIMIZE, {}).get("run_type", self._model_type) + self._compile_type = config.get(MSCStage.COMPILE, {}).get("run_type", self._model_type) self._config, self._debug_levels = self.update_config(config) self._tools_config = {} self._relax_mod, self._runner = None, None - self._data_loader, self._sample_inputs = None, None - self._model_type = self._config["model_type"] - self._optimize_type = self._config.get("optimize", {}).get("run_type", self._model_type) - self._compile_type = self._config.get("compile", {}).get("run_type", self._model_type) + self._sample_inputs = None self._report = { "success": False, "info": { "workspace": self._workspace.path, - "model_type": self._config["model_type"], + "model_type": "{}({})".format(self._model_type, self._device), }, "duration": {}, "profile": {}, @@ -110,14 +120,22 @@ def update_config(self, config: dict) -> dict: assert "inputs" in config, "inputs should be given to run manager" assert "outputs" in config, "outputs should be given to run manager" config, debug_levels = msc_utils.copy_dict(config), {} - for stage in ["prepare", "parse"]: + for stage in [MSCStage.PREPARE, MSCStage.PARSE]: if stage not in config: config[stage] = {} - config = self._update_prepare_config(config) - config = self._update_parse_config(config) - for stage in ["baseline", "optimize", "compile"]: - config = self._update_runner_config(config, stage) - config = self._update_tool_config(config) + config = self._get_runner_cls(self._model_type).update_config( + MSCStage.PARSE, config, self._model + ) + for stage in [MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE]: + if stage not in config: + continue + if "run_type" not in config[stage]: + config[stage]["run_type"] = self._model_type + config = self._get_runner_cls(config[stage]["run_type"]).update_config( + stage, config, self._model + ) + if MSCStage.OPTIMIZE in config: + config[MSCStage.OPTIMIZE] = self._update_tool_config(config[MSCStage.OPTIMIZE]) def _set_debug_level(stage: str, stage_config: dict, default: int = None) -> dict: if "debug_level" in stage_config: @@ -127,34 +145,44 @@ def _set_debug_level(stage: str, stage_config: dict, default: int = None) -> dic stage_config["debug_level"] = default return debug_levels - debug_level = config.get("debug_level") - for stage in ["baseline", "optimize", "compile"]: + if self._verbose.startswith("debug:"): + debug_level = int(self._verbose.split(":")[1]) + else: + debug_level = 0 + for stage in [MSCStage.BASELINE, MSCStage.OPTIMIZE, MSCStage.COMPILE]: if stage not in config: continue debug_levels = _set_debug_level(stage, config[stage]["run_config"], debug_level) - if "optimize" in config: + if MSCStage.OPTIMIZE in config: for t_type in ToolType.all_types(): - if t_type not in config["optimize"]: + if t_type not in config[MSCStage.OPTIMIZE]: continue debug_levels = _set_debug_level( - self._get_tool_stage(t_type), config["optimize"][t_type], debug_level + self._get_tool_stage(t_type), config[MSCStage.OPTIMIZE][t_type], debug_level ) ordered_keys = [ "model_type", "inputs", "outputs", "dataset", - "prepare", - "parse", - "baseline", - "optimize", - "compile", + MSCStage.PREPARE, + MSCStage.PARSE, + MSCStage.BASELINE, + MSCStage.OPTIMIZE, + MSCStage.COMPILE, ] return {k: config[k] for k in ordered_keys if k in config}, debug_levels - def run_pipe(self) -> dict: + def run_pipe(self, run_optimize: bool = True, run_compile: bool = True) -> dict: """Run the pipeline and return object. + Parameters + ---------- + run_optimize: bool + Whether to run the optimize. + run_compile: bool + Whether to run the compile. + Returns ------- report: @@ -162,113 +190,67 @@ def run_pipe(self) -> dict: """ err_msg = None - use_cache = self._config.get("use_cache", True) try: - self._data_loader, self._sample_inputs = self.prepare( - self._config["prepare"], use_cache - ) - self._relax_mod = self.parse(self._config["parse"], use_cache) - if "baseline" in self._config: - self._runner = self.baseline(self._config["baseline"], use_cache) - if "optimize" in self._config: - self._runner = self.optimize(self._config["optimize"], use_cache) - self._runner = self.compile(self._config["compile"], use_cache) + self.prepare() + self.parse() + if MSCStage.BASELINE in self._config: + self.baseline() + if run_optimize and MSCStage.OPTIMIZE in self._config: + self.optimize() + if run_compile: + self.compile() except Exception as exc: # pylint: disable=broad-exception-caught err_msg = "Pipeline failed:{}\nTrace: {}".format(exc, traceback.format_exc()) self.summary(err_msg) self._logger.info(msc_utils.msg_block("SUMMARY", self._report, 0)) return self._report - def prepare(self, stage_config: dict, use_cache: bool = False) -> Dict[str, np.ndarray]: + def prepare(self) -> Dict[str, np.ndarray]: """Prepare datas for the pipeline. - Parameters - ---------- - stage_config: dict - The config of this stage. - use_cache: bool - Whether to use cache. - Returns ------- + dataloader: + The dataloader sample_inputs: dict The sample inputs. """ msc_utils.time_stamp(MSCStage.PREPARE) - - # create data loader - source_loader = self._config["dataset"].get("loader") - max_batch = self._config["dataset"].get("max_batch", 5) - assert source_loader, "Dataset loader should be given for msc pipeline" - if source_loader.startswith("from_random"): - - def get_random(): - for _ in range(max_batch): - yield {i[0]: np.random.rand(*i[1]).astype(i[2]) for i in self._config["inputs"]} - - data_loader, source_type = get_random, "Random" - elif msc_utils.is_io_dataset(source_loader): - - def load_datas(): - for inputs, _ in msc_utils.IODataLoader(data_loader, end=max_batch): - yield inputs - - data_loader, source_type = load_datas, "IOData" - elif callable(source_loader): - - def get_source(): - for idx, inputs in enumerate(source_loader()): - if idx >= max_batch: - break - yield inputs - - data_loader, source_type = get_source, "Custom" - else: - raise TypeError( - "Unexpected source loader {}({})".format(source_loader, type(source_loader)) - ) - self._logger.info("Create data loader(%s) %s", source_type, data_loader) + stage_config = self._config[MSCStage.PREPARE] + use_cache = self._config.get("use_cache", True) + runner_cls = self._get_runner_cls(self._model_type) + run_func = runner_cls.run_native if hasattr(runner_cls, "run_native") else None + input_names = [i[0] for i in self._config["inputs"]] # create golden - golden_folder = msc_utils.get_dataset_dir().relpath("Golden", use_cache) - input_names, sample_inputs = [i[0] for i in self._config["inputs"]], None - report = {"golden_folder": golden_folder} - runner_cls = self._get_runner_cls(self._config["model_type"]) - run_func = runner_cls.run_native if hasattr(runner_cls, "run_native") else None - if use_cache and msc_utils.is_io_dataset(golden_folder): - golden_loader, source_type = msc_utils.IODataLoader(golden_folder), "Cache" - report["datas_info"] = golden_loader.info - sample_inputs = golden_loader[0][0] - self._logger.debug("Load %d cached golden from %s", len(golden_loader), golden_folder) + if "golden" in self._config["dataset"]: + golden_folder = self._config["dataset"]["golden"]["loader"] else: - # save golden - golden_cnt, max_golden = 0, self._config["dataset"].get("max_golden", 5) + golden_folder = msc_utils.get_dataset_dir().relpath("Golden", use_cache) + report = {"golden_folder": golden_folder} + if msc_utils.is_io_dataset(golden_folder): + loader, source_type = msc_utils.IODataLoader(golden_folder), "Cache" + self._sample_inputs = loader[0][0] + report["datas_info"] = loader.info + self._logger.debug("Load %d golden from %s", len(loader), golden_folder) + elif run_func: + loader, source_type = self._get_loader(MSCStage.PREPARE), "Native" saver_options = {"input_names": input_names, "output_names": self._config["outputs"]} - if run_func: - with msc_utils.IODataSaver(golden_folder, saver_options) as saver: - for inputs in data_loader(): - if golden_cnt >= max_golden: - break - if not sample_inputs: - sample_inputs = inputs - outputs, _ = run_func( - self._model, inputs, input_names, self._config["outputs"] - ) - golden_cnt = saver.save_batch(inputs, outputs) - report["datas_info"] = saver.info - elif isinstance(data_loader, msc_utils.IODataLoader): - with msc_utils.IODataSaver(golden_folder, saver_options) as saver: - for inputs, outputs in data_loader(): - if golden_cnt >= max_golden: - break - if not sample_inputs: - sample_inputs = inputs - golden_cnt = saver.save_batch(inputs, outputs) - report["datas_info"] = saver.info - else: - raise Exception("golden or runner should given in prepare to save golden") - self._logger.debug("Saved %d golden to %s", golden_cnt, golden_folder) + cnt, max_golden = 0, self._config["dataset"][MSCStage.PREPARE].get("max_golden", 5) + with msc_utils.IODataSaver(golden_folder, saver_options) as saver: + for inputs in loader(): + if cnt >= max_golden > 0: + break + if not self._sample_inputs: + self._sample_inputs = inputs + outputs, _ = run_func(self._model, inputs, input_names, self._config["outputs"]) + cnt = saver.save_batch(inputs, outputs) + report["datas_info"] = saver.info + self._logger.debug("Saved %d golden to %s", cnt, golden_folder) + else: + raise Exception("golden_folder or runner should given to save golden") + self._config["dataset"]["golden"] = {"loader": golden_folder, "max_batch": -1} def _to_abstract(info: dict) -> dict: def _to_tensor_str(info): @@ -281,30 +263,25 @@ def _to_tensor_str(info): } report["datas_info"] = _to_abstract(report["datas_info"]) - report["sample_inputs"] = sample_inputs + report["sample_inputs"] = self._sample_inputs self._logger.info(msc_utils.msg_block("GOLDEN({})".format(source_type), report)) # profile if "profile" in stage_config and run_func: benchmark = stage_config["profile"].get("benchmark", {}) - repeat = benchmark.get("repeat", 100) + benchmark["repeat"] = self._get_repeat(benchmark) self._logger.debug("Prepare profile with %s(%s)", run_func, benchmark) _, avg_time = run_func( - self._model, sample_inputs, input_names, self._config["outputs"], **benchmark + self._model, self._sample_inputs, input_names, self._config["outputs"], **benchmark ) - self._logger.info("Profile(prepare) {} times -> {:.2f} ms".format(repeat, avg_time)) - self._report["profile"]["prepare"] = {"latency": "{:.2f} ms".format(avg_time)} - return data_loader, sample_inputs + msg = "{:.2f} ms @ {}".format(avg_time, self._device) + self._report["profile"][MSCStage.PREPARE] = {"latency": msg} + self._logger.info("Profile(prepare) %d times -> %s", benchmark["repeat"], msg) - def parse(self, stage_config: dict, use_cache: bool = False) -> tvm.IRModule: - """Parse the model to IRModule. + return self._sample_inputs - Parameters - ---------- - stage_config: dict - The config of this stage. - use_cache: bool - Whether to use cache. + def parse(self) -> tvm.IRModule: + """Parse the model to IRModule. Returns ------- @@ -313,14 +290,17 @@ def parse(self, stage_config: dict, use_cache: bool = False) -> tvm.IRModule: """ msc_utils.time_stamp(MSCStage.PARSE) + stage_config = self._config[MSCStage.PARSE] + use_cache = self._config.get("use_cache", True) + cache_path = msc_utils.get_cache_dir().relpath("parsed_relax.json") if use_cache else None if cache_path and os.path.isfile(cache_path): with open(cache_path, "r") as f: - relax_mod = tvm.ir.load_json(f.read()) + self._relax_mod = tvm.ir.load_json(f.read()) self._logger.info("Load parsed mod from %s", cache_path) else: - parse_config = stage_config.get("parse_config", {}) - runner_cls = self._get_runner_cls(self._config["compile"]["run_type"]) + parse_config = msc_utils.copy_dict(stage_config.get("parse_config", {})) + runner_cls = self._get_runner_cls(self._config[MSCStage.COMPILE]["run_type"]) trans_func = ( runner_cls.target_transform if hasattr(runner_cls, "target_transform") else None ) @@ -330,25 +310,20 @@ def parse(self, stage_config: dict, use_cache: bool = False) -> tvm.IRModule: "trans_func": trans_func, } self._logger.info(msc_utils.msg_block("PARSE", parse_info)) - relax_mod, _ = stage_config["parser"](self._model, as_msc=False, **parse_config) + parse_config["as_msc"] = False + self._relax_mod, _ = stage_config["parser"](self._model, **parse_config) if trans_func: - relax_mod = trans_func(relax_mod) + self._relax_mod = trans_func(self._relax_mod) + self._relax_mod = msc_transform.SetExprName()(self._relax_mod) if cache_path: with open(cache_path, "w") as f: - f.write(tvm.ir.save_json(relax_mod)) + f.write(tvm.ir.save_json(self._relax_mod)) self._logger.debug("Save parsed mod to %s", cache_path) - return relax_mod + return self._relax_mod - def baseline(self, stage_config: dict, use_cache: bool = False) -> BaseRunner: + def baseline(self) -> BaseRunner: """Run the baseline. - Parameters - ---------- - stage_config: dict - The config of this stage. - use_cache: bool - Whether to use cache. - Returns ------- runner: BaseRunner @@ -356,17 +331,36 @@ def baseline(self, stage_config: dict, use_cache: bool = False) -> BaseRunner: """ msc_utils.time_stamp(MSCStage.BASELINE) - return self._create_runner(MSCStage.BASELINE, stage_config, use_cache=use_cache) + self._runner = self._create_runner( + MSCStage.BASELINE, + self._config[MSCStage.BASELINE], + use_cache=self._config.get("use_cache", True), + ) + return self._runner - def optimize(self, stage_config: dict, use_cache: bool = False) -> BaseRunner: + def optimize(self) -> BaseRunner: """Run the optimize and return object. - Parameters - ---------- - stage_config: dict - The config of this stage. - use_cache: bool - Whether to use cache. + Returns + ------- + runner: BaseRunner + The runner. + """ + + stage_config = self._config[MSCStage.OPTIMIZE] + self.apply_tools(stage_config) + msc_utils.time_stamp(MSCStage.OPTIMIZE) + self._runner = self._create_runner( + MSCStage.OPTIMIZE, + stage_config, + tools_config=self._tools_config, + use_cache=self._config.get("use_cache", True), + ) + self._optimized = True + return self._runner + + def compile(self) -> BaseRunner: + """Run the compile and return object. Returns ------- @@ -374,6 +368,27 @@ def optimize(self, stage_config: dict, use_cache: bool = False) -> BaseRunner: The runner. """ + stage_config = self._config[MSCStage.COMPILE] + self.apply_tools(stage_config) + msc_utils.time_stamp(MSCStage.COMPILE) + self._runner = self._create_runner( + MSCStage.COMPILE, + stage_config, + tools_config=self._tools_config, + use_cache=self._config.get("use_cache", True), + ) + self._compiled = True + return self._runner + + def apply_tools(self, stage_config: dict): + """Apply tools for a stage. + + Parameters + ---------- + stage_config: dict + The config of this stage. + """ + runner_cls = self._get_runner_cls(stage_config["run_type"]) def _tool_enabled(tool_type: str) -> bool: @@ -391,35 +406,6 @@ def _tool_enabled(tool_type: str) -> bool: if _tool_enabled(ToolType.DISTILLER): self._apply_tool(ToolType.DISTILLER, stage_config) - # optimize and get the runner - msc_utils.time_stamp(MSCStage.OPTIMIZE) - return self._create_runner( - MSCStage.OPTIMIZE, stage_config, tools_config=self._tools_config, use_cache=use_cache - ) - - def compile(self, stage_config: dict, use_cache: bool = False) -> BaseRunner: - """Run the compile and return object. - - Parameters - ---------- - stage_config: dict - The config of this stage. - use_cache: bool - Whether to use cache. - ret_type: str - The return type runner| model. - - Returns - ------- - runner: BaseRunner - The runner. - """ - - msc_utils.time_stamp(MSCStage.COMPILE) - return self._create_runner( - MSCStage.COMPILE, stage_config, tools_config=self._tools_config, use_cache=use_cache - ) - def summary(self, err_msg=None): """Summary the pipeline. @@ -501,7 +487,11 @@ def _create_runner( run_config["generate_config"]["build_folder"] = msc_utils.get_build_dir().create_dir( stage, cleanup=cleanup ) - opt_config = self._config.get("optimize", {}) + if "device" not in run_config: + run_config["device"] = self._device + if "training" not in run_config: + run_config["training"] = self._training + opt_config = self._config.get(MSCStage.OPTIMIZE, {}) if ToolType.TRACKER in opt_config and runner_cls.support_tool(ToolType.TRACKER): tools_config = {**tools_config, ToolType.TRACKER: opt_config[ToolType.TRACKER]} # Build runner @@ -572,10 +562,9 @@ def _apply_tool(self, tool_type: str, stage_config: dict, add_tool: bool = True) extra_config = { "env": { "runner": runner, - "data_loader": self._data_loader, + "data_loader": self._get_loader(tool_stage), "knowledge": knowledge, }, - "debug_level": runner.debug_level, "verbose": self._verbose, } controller = create_controller(runner.stage, config, extra_config) @@ -586,7 +575,7 @@ def _apply_tool(self, tool_type: str, stage_config: dict, add_tool: bool = True) "Gym save %d knowledge(%s) -> %s", len(knowledge), tool_type, plan_file ) return plan_file - return runner.apply_tool(tool_type, self._data_loader) + return runner.apply_tool(tool_type, self._get_loader(tool_stage)) def _profile_runner(self, runner: BaseRunner, stage_config: str) -> dict: """Profile the runner. @@ -612,7 +601,7 @@ def _profile_runner(self, runner: BaseRunner, stage_config: str) -> dict: # check accuracy check_config = profile_config.get("check", {}) if check_config: - loader = msc_utils.IODataLoader(msc_utils.get_dataset_dir().relpath("Golden")) + loader = msc_utils.IODataLoader(self._config["dataset"]["golden"]["loader"]) total, passed = 0, 0 acc_report = {"config": check_config} for idx, (inputs, outputs) in enumerate(loader): @@ -652,7 +641,7 @@ def _profile_runner(self, runner: BaseRunner, stage_config: str) -> dict: for _ in range(benchmark_config.get("warm_up", 10)): runner.run(self._sample_inputs) start = time.time() - repeat = benchmark_config.get("repeat", 100) + repeat = self._get_repeat(benchmark_config, runner.device) for _ in range(repeat): runner.run(self._sample_inputs) avg_time = (time.time() - start) * 1000 / repeat @@ -661,152 +650,30 @@ def _profile_runner(self, runner: BaseRunner, stage_config: str) -> dict: self._logger.info(msg) return report - def _update_prepare_config(self, config: dict) -> dict: - """Update prepare in stage config. - - Parameters - ---------- - config: dict - The config of a pipeline. - - Returns - ------- - config: dict - The updated config. - """ - - if config["model_type"] == MSCFramework.TORCH: - import torch - - assert isinstance( - self._model, torch.nn.Module - ), "Model for torch should be nn.Module, get {}({})".format( - self._model, type(self._model) - ) - elif config["model_type"] == MSCFramework.TENSORFLOW: - from tvm.contrib.msc.framework.tensorflow import tf_v1 - - assert isinstance( - self._model, tf_v1.GraphDef - ), "Model for tenosrflow should be tf.GraphDef, get {}({})".format( - self._model, type(self._model) - ) - else: - raise Exception("Unexpect model_type " + str(config["model_type"])) - return config - - def _update_parse_config(self, config: dict) -> dict: - """Update parse in stage config. - - Parameters - ---------- - config: dict - The config of a pipeline. - - Returns - ------- - config: dict - The updated config. - """ - - if config["model_type"] == MSCFramework.TORCH: - from tvm.contrib.msc.framework.torch.frontend import from_torch - - config["parse"]["parser"] = from_torch - parse_config = config["parse"].get("parse_config", {}) - parse_config.update( - { - "input_info": [[i[1], i[2]] for i in config["inputs"]], - "input_names": [i[0] for i in config["inputs"]], - } - ) - config["parse"]["parse_config"] = parse_config - elif config["model_type"] == MSCFramework.TENSORFLOW: - from tvm.contrib.msc.framework.tensorflow.frontend import from_tensorflow - - config["parse"]["parser"] = from_tensorflow - parse_config = config["parse"].get("parse_config", {}) - parse_config.update( - { - "shape_dict": {i[0]: i[1] for i in config["inputs"]}, - "outputs": config["outputs"], - } - ) - config["parse"]["parse_config"] = parse_config - else: - raise Exception("Unexpect model_type " + str(config["model_type"])) - return config - - def _update_runner_config(self, config: dict, stage: str) -> dict: - """Update runtime stage in stage config. - - Parameters - ---------- - config: dict - The config of a pipeline. - stage: str - The stage to be updated - """ - - if stage not in config: - return config - model_type = config["model_type"] - if "run_type" not in config[stage]: - config[stage]["run_type"] = model_type - # update run config - run_config = config[stage].get("run_config", {}) - if "translate_config" not in run_config: - run_config["translate_config"] = {} - if "build" not in run_config["translate_config"]: - run_config["translate_config"]["build"] = {} - if "generate_config" not in run_config: - run_config["generate_config"] = {} - run_config["translate_config"]["build"]["input_aliases"] = [i[0] for i in config["inputs"]] - run_config["translate_config"]["build"]["output_aliases"] = config["outputs"] - if model_type == MSCFramework.TORCH: - parameters = list(self._model.parameters()) - if parameters: - ref_device = parameters[0].device - if ref_device.type == "cpu": - device = "cpu" - else: - device = "{}:{}".format(ref_device.type, ref_device.index) - else: - device = "cpu" - run_config.update({"device": device, "is_training": self._model.training}) - if config[stage]["run_type"] == MSCFramework.TENSORRT: - if "extra_option" not in run_config["generate_config"]: - run_config["generate_config"]["extra_option"] = {} - run_config["generate_config"]["extra_option"]["stage"] = stage - config[stage]["run_config"] = run_config - return config - - def _update_tool_config(self, config: dict) -> dict: + def _update_tool_config(self, opt_config: dict) -> dict: """Update tool in stage config. Parameters ---------- - config: dict - The config of a pipeline. + opt_config: dict + The config of optimize. Returns ------- config: dict - The updated config. + The updated config of optimize. """ - if "optimize" not in config: - return config for tool_type in ToolType.all_types(): - if tool_type not in config["optimize"]: + if tool_type not in opt_config: continue - tool_config = config["optimize"][tool_type] + tool_config = opt_config[tool_type] if "plan_file" not in tool_config: tool_config["plan_file"] = "msc_{}.json".format(tool_type) tool_config["plan_file"] = msc_utils.to_abs_path( tool_config["plan_file"], msc_utils.get_config_dir() ) - return config + return opt_config def _get_tool_stage(self, tool_type: str) -> str: """Map the stage according to tool_type @@ -868,6 +735,66 @@ def _get_runner_cls(self, run_type: str) -> BaseRunner: raise NotImplementedError("_get_runner_cls is not implemented for BaseManager") + def _get_loader(self, name: str = MSCStage.PREPARE) -> Any: + """Get the data loader""" + + config = self._config["dataset"].get(name, self._config["dataset"][MSCStage.PREPARE]) + source_loader = config.get("loader") + max_batch = config.get("max_batch", 5) + assert source_loader, "Dataset loader should be given for msc pipeline" + if source_loader == "from_random": + max_batch = max(max_batch, 5) + + def get_random(): + for _ in range(max_batch): + yield {i[0]: np.random.rand(*i[1]).astype(i[2]) for i in self._config["inputs"]} + + loader, source_type = get_random, "Random" + elif msc_utils.is_io_dataset(source_loader): + + def load_datas(): + for inputs, _ in msc_utils.IODataLoader(source_loader, end=max_batch): + yield inputs + + loader, source_type = load_datas, "IOData" + elif callable(source_loader): + + def get_source(): + for idx, inputs in enumerate(source_loader()): + if idx >= max_batch > 0: + break + yield inputs + + loader, source_type = get_source, "Custom" + else: + raise TypeError( + "Unexpected source loader {}({})".format(source_loader, type(source_loader)) + ) + self._logger.debug("Create data loader(%s) %s(%s)", name, loader, source_type) + return loader + + def _get_repeat(self, benchmark: dict, device: str = None) -> int: + """Get the repeat number for benchmark + + Parameters + ---------- + benchmark: dict + The benchmark config. + device: str + The device name + + Returns + ------- + repeat: int + The repeat number. + """ + + device = device or self._device + repeat = benchmark.get("repeat", -1) + if repeat == -1: + repeat = 500 if device.startswith("cuda") else 10 + return repeat + @property def runner(self): return self._runner diff --git a/src/contrib/msc/core/codegen/codegen_utils.h b/src/contrib/msc/core/codegen/codegen_utils.h index bd5d543dc2b1..1af8df5ac1a4 100644 --- a/src/contrib/msc/core/codegen/codegen_utils.h +++ b/src/contrib/msc/core/codegen/codegen_utils.h @@ -40,8 +40,9 @@ namespace msc { using namespace tvm::script::printer; #define CODEGEN_CONFIG_MEMBERS \ - bool is_train{false}; \ + bool training{false}; \ bool use_tools{false}; \ + bool use_plugin{false}; \ bool need_test{true}; \ std::string tools_scope{""}; \ std::string tools_tag{"main"}; \ @@ -51,10 +52,12 @@ using namespace tvm::script::printer; std::vector version{0, 0, 0}; #define CODEGEN_CONFIG_PARSE \ - if (key == "is_train") { \ - reader->Read(&is_train); \ + if (key == "training") { \ + reader->Read(&training); \ } else if (key == "use_tools") { \ reader->Read(&use_tools); \ + } else if (key == "use_plugin") { \ + reader->Read(&use_plugin); \ } else if (key == "need_test") { \ reader->Read(&need_test); \ } else if (key == "tools_scope") { \ diff --git a/src/contrib/msc/core/codegen/py_codegen.h b/src/contrib/msc/core/codegen/py_codegen.h index 6b3120affe32..e1ceb716a278 100644 --- a/src/contrib/msc/core/codegen/py_codegen.h +++ b/src/contrib/msc/core/codegen/py_codegen.h @@ -86,7 +86,7 @@ class PyCodeGen : public BaseCodeGen { virtual void CodeGenHeader() { this->stack_.line("import os") .line("import numpy as np") - .line("from typing import List, Dict") + .line("from typing import List, Dict, Any") .line("import tvm"); if (this->config()->use_tools) { this->stack_.line("from tvm.contrib.msc.core import tools as msc_tools"); diff --git a/src/contrib/msc/core/ir/graph.cc b/src/contrib/msc/core/ir/graph.cc index 2cdb326e77a0..71f3208db94d 100644 --- a/src/contrib/msc/core/ir/graph.cc +++ b/src/contrib/msc/core/ir/graph.cc @@ -199,6 +199,25 @@ bool BaseJointNode::GetAttr(const String& key, bool* val) const { return false; } +bool BaseJointNode::GetAttr(const String& key, std::vector* val) const { + std::string val_str; + if (GetAttr(key, &val_str)) { + int pos = val_str.find(","); + if (pos < 0) { + return false; + } + try { + for (const auto& s : StringUtils::Split(val_str, ",")) { + (*val).push_back(std::string(s)); + } + return true; + } catch (const std::exception&) { + return false; + } + } + return false; +} + bool BaseJointNode::GetAttr(const String& key, std::vector* val) const { std::string val_str; if (GetAttr(key, &val_str)) { @@ -255,6 +274,25 @@ bool BaseJointNode::GetAttr(const String& key, std::vector* val) const { return false; } +bool BaseJointNode::GetAttr(const String& key, std::vector* val) const { + std::string val_str; + if (GetAttr(key, &val_str)) { + int pos = val_str.find(","); + if (pos < 0) { + return false; + } + try { + for (const auto& s : StringUtils::Split(val_str, ",")) { + (*val).push_back(std::stoi(s) != 0); + } + return true; + } catch (const std::exception&) { + return false; + } + } + return false; +} + MSCJoint::MSCJoint(int index, const String& name, const String& shared_ref, const String& optype, const Map& attrs, const Array& scope, const std::vector>& inputs, diff --git a/src/contrib/msc/core/ir/graph.h b/src/contrib/msc/core/ir/graph.h index fbcdeb4d0c3d..7005518f367b 100644 --- a/src/contrib/msc/core/ir/graph.h +++ b/src/contrib/msc/core/ir/graph.h @@ -404,9 +404,11 @@ class BaseJointNode : public Object { bool GetAttr(const String& key, int64_t* val) const; bool GetAttr(const String& key, float* val) const; bool GetAttr(const String& key, bool* val) const; + bool GetAttr(const String& key, std::vector* val) const; bool GetAttr(const String& key, std::vector* val) const; bool GetAttr(const String& key, std::vector* val) const; bool GetAttr(const String& key, std::vector* val) const; + bool GetAttr(const String& key, std::vector* val) const; /*! \brief Check and get the attribute by type. */ template const T GetTypeAttr(const String& key) const { diff --git a/src/contrib/msc/core/ir/graph_builder.cc b/src/contrib/msc/core/ir/graph_builder.cc index dab4ae813ea6..02b5a2ee671a 100644 --- a/src/contrib/msc/core/ir/graph_builder.cc +++ b/src/contrib/msc/core/ir/graph_builder.cc @@ -101,6 +101,31 @@ void RelaxFuncParamsFinder::VisitExpr_(const relax::CallNode* call_node) { } } +void RelaxLayoutsFinder::VisitBinding_(const relax::VarBindingNode* binding, + const relax::FunctionNode* val) { + local_funcs_.Set(binding->var, GetRef(val)); +} + +void RelaxLayoutsFinder::VisitExpr_(const relax::CallNode* call_node) { + RelaxExprVisitor::VisitExpr_(call_node); + relax::Function func; + if (const auto* v_node = call_node->op.as()) { + func = Downcast(ref_module_->Lookup(v_node->name_hint)); + VisitExpr(func); + } else if (call_node->op->IsInstance()) { + ICHECK(local_funcs_.count(call_node->op)) << "Can not find local func " << call_node->op; + func = local_funcs_[call_node->op]; + } + if (func.defined()) { + const auto& layouts_opt = func->GetAttr>(msc_attr::kInputLayouts); + if (layouts_opt.defined()) { + for (const auto& pair : layouts_opt.value()) { + layouts_.Set(pair.first, pair.second); + } + } + } +} + const MSCGraph RelaxGraphBuilder::Build(const relax::Function& func) { // Add input nodes and record inputs; Array input_names, output_names; @@ -109,6 +134,9 @@ const MSCGraph RelaxGraphBuilder::Build(const relax::Function& func) { if (expr_tensor_map_.count(p)) { continue; } + if (func_params_.count(p) && func_params_[p]->IsInstance()) { + continue; + } if (func_params_.count(p) && func_params_[p]->IsInstance()) { const auto& tuple = Downcast(func_params_[p]); Array tuple_names; @@ -202,18 +230,15 @@ const MSCGraph RelaxGraphBuilder::Build(const relax::Function& func) { const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional& binding_var, const String& name) { - String node_name = name.size() > 0 ? name : SpanUtils::GetAttr(expr->span, "name"); - const auto& shared_ref = SpanUtils::GetAttr(expr->span, "shared_ref"); - - // Get optype and node_name - String optype; - if (expr->IsInstance()) { - if (func_params_.count(expr) && func_params_[expr]->IsInstance()) { - optype = "constant"; - node_name = SpanUtils::GetAttr(func_params_[expr]->span, "name"); - } else { - optype = "input"; - } + // Get optype, node_name and layout + String node_name = name.size() > 0 ? name : SpanUtils::GetAttr(expr->span, msc_attr::kName); + String optype = "unknown"; + String layout = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); + if (func_params_.count(expr) && func_params_[expr]->IsInstance()) { + node_name = SpanUtils::GetAttr(func_params_[expr]->span, msc_attr::kName); + optype = "constant"; + } else if (expr->IsInstance()) { + optype = "input"; } else if (expr->IsInstance()) { optype = "constant"; } else if (expr->IsInstance()) { @@ -224,44 +249,54 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional optype = "tuple"; } else if (const auto* call_node = expr.as()) { if (const auto* op_node = call_node->op.as()) { - optype = StringUtils::Replace(op_node->name, "relax.", ""); - } else if (const auto* v_node = call_node->op.as()) { - const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); - const auto& byoc_name_opt = func->GetAttr("byoc_name"); - if (byoc_name_opt.defined()) { - node_name = byoc_name_opt.value(); - } - const auto& codegen_opt = func->GetAttr(relax::attr::kCodegen); - if (codegen_opt.defined()) { - optype = codegen_opt.value(); + if (op_node->name == "relax.call_dps_packed") { + optype = Downcast(call_node->args[0])->global_symbol; } else { - const auto& name_opt = func->GetAttr(relax::attr::kComposite); - ICHECK(name_opt.defined()) << "Unexpected global func without composite"; - optype = name_opt.value(); + optype = StringUtils::Replace(op_node->name, "relax.", ""); } + } else if (const auto* v_node = call_node->op.as()) { + const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); + std::tie(node_name, optype, layout) = ParseFunc(func); } else if (call_node->op->IsInstance()) { ICHECK(target_funcs_.count(call_node->op)) << "Can not find target func: " << call_node->op; - const auto& func = target_funcs_[call_node->op]; - const auto& name_opt = func->GetAttr(relax::attr::kComposite); - optype = StringUtils::Replace(name_opt.value(), config_.target + ".", ""); - } else if (const auto* f_node = call_node->op.as()) { - const auto& name_opt = f_node->GetAttr(relax::attr::kComposite); - ICHECK(name_opt.defined()) << "Unexpected func without composite"; - optype = name_opt.value(); - } else { - optype = "unknown_op"; + std::tie(node_name, optype, layout) = ParseFunc(target_funcs_[call_node->op]); + } else if (call_node->op->IsInstance()) { + std::tie(node_name, optype, layout) = ParseFunc(Downcast(call_node->op)); } - } else { - optype = "unknown_expr"; + } + if (layouts_.count(node_name)) { + layout = layouts_[node_name]; } - // Extract attributes + // get plugin + const auto& plugin = IsPlugin(optype) ? GetPlugin(optype) : Plugin(); + + // Extract normal attributes Map attrs; - if (const auto* call_node = expr.as()) { + if (plugin.defined()) { + const auto& op = Downcast(expr)->op; + if (target_funcs_.count(op)) { + const auto& opattrs_opt = target_funcs_[op]->GetAttr>(msc_attr::kOpattrs); + if (opattrs_opt.defined()) { + const auto& opattrs = opattrs_opt.value(); + ICHECK_EQ(opattrs.size(), plugin->attrs.size()) + << "opattrs " << opattrs << " size mismatch with " << plugin->attrs.size(); + for (size_t i = 0; i < opattrs.size(); i++) { + attrs.Set(plugin->attrs[i]->name, opattrs[i]); + } + } + } else { + const auto& args = GetPluginInputs(expr); + for (size_t i = 0; i < plugin->attrs.size(); i++) { + const auto& val = args[plugin->inputs.size() + i]; + attrs.Set(plugin->attrs[i]->name, StringUtils::ToString(val)); + } + } + } else if (const auto* call_node = expr.as()) { if (const auto* v_node = call_node->op.as()) { const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); - const auto& byoc_name_opt = func->GetAttr("byoc_name"); - if (!byoc_name_opt.defined()) { + const auto& name_opt = func->GetAttr(relax::attr::kComposite); + if (name_opt.defined()) { attrs = RelaxFuncAttrGetter().GetAttrs(func); } } else if (call_node->op->IsInstance()) { @@ -283,40 +318,59 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional attrs.Set("index", std::to_string(get_node->index)); } - // Get scope - Array scope; - if (optype != "input" && optype != "constant") { - scope = StringUtils::Split(scope_name_, "."); - } - // Build inputs and weights - Array input_names; - Map node_weights; - if (const auto* call_node = expr.as()) { - Array prim_values; - if (call_node->op->IsInstance()) { - ICHECK(target_funcs_.count(call_node->op)) << "Can not find target func: " << call_node->op; - prim_values = RelaxFuncValueGetter().GetValues(target_funcs_[call_node->op]); + // Extract attributes from arguments + Array input_types; + if (!plugin.defined() && expr->IsInstance()) { + const auto& call = Downcast(expr); + Array values; + if (call->op->IsInstance()) { + ICHECK(target_funcs_.count(call->op)) << "Can not find target func: " << call->op; + values = RelaxFuncValueGetter().GetValues(target_funcs_[call->op]); } - const auto& input_types = - ExprUtils::GetInputTypes(optype, call_node->args.size() + prim_values.size(), true); - for (size_t i = 0; i < call_node->args.size(); i++) { - const auto& arg = call_node->args[i]; + input_types = ExprUtils::GetInputTypes(optype, call->args.size() + values.size(), true); + for (size_t i = 0; i < call->args.size(); i++) { + const auto& arg = call->args[i]; if (const auto* s_node = arg.as()) { attrs.Set(input_types[i], StringUtils::ToString(s_node->values)); - continue; - } - if (func_params_.count(arg) && func_params_[arg]->IsInstance()) { + } else if (func_params_.count(arg) && func_params_[arg]->IsInstance()) { const auto* s_node = func_params_[arg].as(); attrs.Set(input_types[i], StringUtils::ToString(s_node->values)); ignore_nodes_.insert(Downcast(arg)->name_hint()); - continue; - } - if (const auto* s_node = arg.as()) { + } else if (const auto* s_node = arg.as()) { ICHECK(input_types[i] != "input") << i << " th PrimValue of " << optype << " should has special type, get " << input_types; attrs.Set(input_types[i], StringUtils::ToString(s_node->value)); + } + } + for (size_t i = call->args.size(); i < input_types.size(); i++) { + attrs.Set(input_types[i], values[i - call->args.size()]); + } + } + + // Build inputs and weights + Array input_names; + Map node_weights; + if (plugin.defined()) { + const auto& call = Downcast(expr); + if (call->args.size() == 1) { + ICHECK(expr_tensor_map_.count(call->args[0])) + << "Can not find tuple plugin input " << call->args[0]; + input_names = expr_tensor_map_[call->args[0]]; + } else { + const auto& args = GetPluginInputs(expr); + for (size_t i = 0; i < plugin->inputs.size(); i++) { + ICHECK(expr_tensor_map_.count(args[i])) << "Can not find plugin input " << args[i]; + for (const auto& in_name : expr_tensor_map_[args[i]]) { + input_names.push_back(in_name); + } + } + } + } else if (const auto* call_node = expr.as()) { + for (size_t i = 0; i < call_node->args.size(); i++) { + if (attrs.count(input_types[i])) { continue; } + const auto& arg = call_node->args[i]; Array arg_names; if (expr_tensor_map_.count(arg)) { arg_names = expr_tensor_map_[arg]; @@ -330,10 +384,10 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional } String weight_name; if (input_types[i] != "input" && arg->IsInstance()) { - weight_name = SpanUtils::GetAttr(arg->span, "name"); + weight_name = SpanUtils::GetAttr(arg->span, msc_attr::kName); } else if (input_types[i] != "input" && func_params_.count(arg) && func_params_[arg]->IsInstance()) { - weight_name = SpanUtils::GetAttr(func_params_[arg]->span, "name"); + weight_name = SpanUtils::GetAttr(func_params_[arg]->span, msc_attr::kName); ignore_nodes_.insert(Downcast(arg)->name_hint()); } // set weights or inputs @@ -370,10 +424,6 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional } } } - // add prim values to attributes - for (size_t i = call_node->args.size(); i < input_types.size(); i++) { - attrs.Set(input_types[i], prim_values[i - call_node->args.size()]); - } } else if (const auto* tuple_node = expr.as()) { for (const auto& f : tuple_node->fields) { ICHECK(expr_tensor_map_.count(f)) << "Can not find tuple field " << f; @@ -387,11 +437,10 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional input_names = expr_tensor_map_[getitem_node->tuple]; } else if (optype == "constant") { const auto& t_info = Downcast(relax::GetStructInfo(expr)); - const auto& opt_shape = t_info->GetShape(); - ICHECK(opt_shape.defined()) << "Constant shape is not defined"; - const auto& layout = SpanUtils::GetAttr(expr->span, "layout"); + const auto& shape_opt = t_info->GetShape(); + ICHECK(shape_opt.defined()) << "Constant shape is not defined"; const auto& weight = - MSCTensor(node_name, t_info->dtype, layout, ArrayUtils::Cast(opt_shape.value())); + MSCTensor(node_name, t_info->dtype, layout, ArrayUtils::Cast(shape_opt.value())); node_weights.Set("const", weight); } std::vector> inputs; @@ -399,47 +448,69 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional inputs.push_back(tensor_input_map_[i]); } - // Build outputs + // Redefine layout for special ops + if (optype == "tuple") { + layout = ""; + for (size_t i = 0; i < inputs.size(); i++) { + const auto& in_tensor = Downcast(inputs[i].first)->OutputAt(inputs[i].second); + layout = layout + in_tensor->layout.name(); + layout = layout + (i == inputs.size() - 1 ? "" : ","); + } + } else if (optype == "get_item") { + int idx = std::stoi(attrs["index"]); + const auto& in_tensor = Downcast(inputs[idx].first)->OutputAt(inputs[idx].second); + layout = in_tensor->layout.name(); + } + + // Build output tensor + auto build_output = [](const relax::StructInfo& sinfo, const String& node_name, + const String& layout) { + ICHECK(sinfo->IsInstance()) + << "sinfo should be TensorStructInfo, get " << sinfo->GetTypeKey(); + const auto& t_info = Downcast(sinfo); + const auto& shape_opt = t_info->GetShape(); + const auto& shape = + shape_opt.defined() ? ArrayUtils::Cast(shape_opt.value()) : Array(); + return MSCTensor(node_name, t_info->dtype, layout, shape); + }; + + // Gather outputs Array outputs; - const auto& layout = SpanUtils::GetAttr(expr->span, "layout"); const auto& sinfo = relax::GetStructInfo(expr); - if (const auto* t_info = sinfo.as()) { - const auto& opt_shape = t_info->GetShape(); - const auto& shape = - opt_shape.defined() ? ArrayUtils::Cast(opt_shape.value()) : Array(); - const auto& output = - MSCTensor(node_name + ":" + std::to_string(0), t_info->dtype, layout, shape); - outputs.push_back(output); + Array layouts = StringUtils::Split(layout, ","); + size_t num_output = 1; + if (const auto* tuple_sinfo = sinfo.as()) { + num_output = tuple_sinfo->fields.size(); + } + if (layouts.size() == 0) { + layouts = Array(num_output, ""); + } + ICHECK_EQ(layouts.size(), num_output) + << "Layouts " << layouts << " msimatch with output size " << num_output; + if (sinfo->IsInstance()) { + const auto& t_name = node_name + ":" + std::to_string(0); + outputs.push_back(build_output(sinfo, t_name, layouts[0])); } else if (const auto* s_sinfo = sinfo.as()) { Array shape{s_sinfo->ndim}; - const auto& output = MSCTensor(node_name + ":" + std::to_string(0), - DataType(runtime::String2DLDataType("int32")), layout, shape); - outputs.push_back(output); + const auto& t_name = node_name + ":" + std::to_string(0); + const auto& dtype = DataType(runtime::String2DLDataType("int32")); + outputs.push_back(MSCTensor(t_name, dtype, layouts[0], shape)); } else if (const auto* tuple_sinfo = sinfo.as()) { - Array layouts = StringUtils::Split(layout, ","); - if (layouts.size() == 0) { - layouts = Array(tuple_sinfo->fields.size(), ""); - } - ICHECK_EQ(layouts.size(), tuple_sinfo->fields.size()) - << "Layout " << layout << " msimatch with fileds size " << tuple_sinfo->fields.size(); - size_t field_size = tuple_sinfo->fields.size(); - if (optype == "nn.batch_norm") { - field_size = 1; - } + size_t field_size = optype == "nn.batch_norm" ? 1 : num_output; for (size_t i = 0; i < field_size; i++) { - const auto& t_info = Downcast(tuple_sinfo->fields[i]); - const auto& opt_shape = t_info->GetShape(); - const auto& shape = - opt_shape.defined() ? ArrayUtils::Cast(opt_shape.value()) : Array(); - const auto& output = - MSCTensor(node_name + ":" + std::to_string(i), t_info->dtype, layouts[i], shape); - outputs.push_back(output); + const auto& t_name = node_name + ":" + std::to_string(i); + outputs.push_back(build_output(tuple_sinfo->fields[i], t_name, layouts[i])); } } else { LOG(FATAL) << "Unexpected struct info (" << sinfo->GetTypeKey() << ")" << sinfo; } // Build node + Array scope; + if (optype != "input" && optype != "constant") { + scope = StringUtils::Split(scope_name_, "."); + } + const auto& shared_ref = SpanUtils::GetAttr(expr->span, msc_attr::kSharedRef); const auto& node = MSCJoint(nodes_.size(), node_name, shared_ref, optype, attrs, scope, inputs, outputs, node_weights); Array output_names; @@ -454,8 +525,23 @@ const MSCJoint RelaxGraphBuilder::AddNode(const Expr& expr, const Optional } void RelaxGraphBuilder::VisitBindingBlock(const relax::BindingBlock& block) { - scope_name_ = SpanUtils::GetAttr(block->span, "name"); + String block_name = SpanUtils::GetAttr(block->span, msc_attr::kName); + if (block_name.size() == 0) { + block_name = "block"; + } + const String& prefix = StringUtils::Join(block_stack_, "."); + if (setted_blocks_.count(prefix + "." + block_name)) { + int cnt = 1; + while (setted_blocks_.count(prefix + "." + block_name + "_" + std::to_string(cnt))) { + cnt++; + } + block_name = block_name + "_" + std::to_string(cnt); + } + scope_name_ = prefix + "." + block_name; + setted_blocks_.insert(scope_name_); + block_stack_.push_back(block_name); RelaxExprVisitor::VisitBindingBlock(block); + block_stack_.pop_back(); } void RelaxGraphBuilder::VisitExpr_(const relax::ConstantNode* op) { @@ -526,14 +612,50 @@ void RelaxGraphBuilder::VisitBinding_(const relax::VarBindingNode* binding, target_funcs_.Set(binding->var, GetRef(val)); } +const std::tuple RelaxGraphBuilder::ParseFunc(const relax::Function& func) { + String node_name, optype, layout; + const auto& name_opt = func->GetAttr(msc_attr::kUnique); + // get node_name + if (name_opt.defined()) { + node_name = name_opt.value(); + } + // get optype + const auto& codegen_opt = func->GetAttr(relax::attr::kCodegen); + const auto& optype_opt = func->GetAttr(msc_attr::kOptype); + const auto& composite_opt = func->GetAttr(relax::attr::kComposite); + if (codegen_opt.defined()) { + optype = codegen_opt.value(); + } else if (optype_opt.defined()) { + optype = optype_opt.value(); + } else if (composite_opt.defined()) { + optype = composite_opt.value(); + if (config_.target.size() > 0) { + optype = StringUtils::Replace(composite_opt.value(), config_.target + ".", ""); + } + } + // get layout + const auto& layout_opt = func->GetAttr(msc_attr::kLayout); + if (layout_opt.defined()) { + layout = layout_opt.value(); + } + return std::make_tuple(node_name, optype, layout); +} + +Array RelaxGraphBuilder::GetPluginInputs(const relax::Expr& expr) { + ICHECK(expr->IsInstance()) << "plugin expr should be call"; + const auto& call = Downcast(expr); + ICHECK(call->args[1]->IsInstance()) << "plugin argument 1 should be call"; + return Downcast(call->args[1])->fields; +} + Map RelaxWeightsExtractor::GetWeights(const relax::Function& func) { VisitExpr(func); return weights_; } void RelaxWeightsExtractor::VisitExpr_(const relax::ConstantNode* op) { - const auto& name = SpanUtils::GetAttr(op->span, "name"); - const auto& layout = SpanUtils::GetAttr(op->span, "layout"); + const auto& name = SpanUtils::GetAttr(op->span, msc_attr::kName); + const auto& layout = SpanUtils::GetAttr(op->span, msc_attr::kLayout); const auto& sinfo = relax::GetStructInfo(GetRef(op)); ICHECK(sinfo->IsInstance()) << "Constant StrcutInfo should be TensorStructInfo"; @@ -616,8 +738,8 @@ MSCGraph RelayGraphBuilder::Build(const relay::Function& func) { } MSCJoint RelayGraphBuilder::AddNode(const Expr& expr, const String& name) { - const auto& node_name = name.size() > 0 ? name : SpanUtils::GetAttr(expr->span, "name"); - const auto& shared_ref = SpanUtils::GetAttr(expr->span, "shared_ref"); + const auto& node_name = name.size() > 0 ? name : SpanUtils::GetAttr(expr->span, msc_attr::kName); + const auto& shared_ref = SpanUtils::GetAttr(expr->span, msc_attr::kSharedRef); // Get optype String optype; @@ -676,7 +798,7 @@ MSCJoint RelayGraphBuilder::AddNode(const Expr& expr, const String& name) { ICHECK(expr_tensor_map_.count(arg)) << "Missing argument " << arg; if (input_types[i] != "input" && arg->IsInstance()) { const auto& t_name = expr_tensor_map_[arg][0]; - const auto& weight_name = SpanUtils::GetAttr(arg->span, "name"); + const auto& weight_name = SpanUtils::GetAttr(arg->span, msc_attr::kName); const auto& pair = tensor_input_map_[t_name]; const auto& producer = Downcast(pair.first); if (!weights_.count(weight_name)) { @@ -738,7 +860,7 @@ MSCJoint RelayGraphBuilder::AddNode(const Expr& expr, const String& name) { ICHECK(checked_type.defined() && checked_type->IsInstance()) << "Constant checked_type is not defined"; const auto& t_info = Downcast(checked_type); - const auto& layout = SpanUtils::GetAttr(expr->span, "layout"); + const auto& layout = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); const auto& weight = MSCTensor(node_name, t_info->dtype, layout, ArrayUtils::Cast(t_info->shape)); node_weights.Set("const", weight); @@ -750,7 +872,7 @@ MSCJoint RelayGraphBuilder::AddNode(const Expr& expr, const String& name) { // Build outputs Array outputs; - const auto& layout = SpanUtils::GetAttr(expr->span, "layout"); + const auto& layout = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); Type checked_type = expr->checked_type_; if (checked_type.defined() && checked_type->IsInstance()) { checked_type = Downcast(checked_type)->ret_type; @@ -807,7 +929,7 @@ void RelayGraphBuilder::VisitExpr_(const relay::ConstantNode* op) { void RelayGraphBuilder::VisitExpr_(const relay::FunctionNode* op) { const auto& name_opt = op->GetAttr(relay::attr::kComposite); if (name_opt.defined()) { - StartFuncScope(SpanUtils::GetAttr(op->span, "name")); + StartFuncScope(SpanUtils::GetAttr(op->span, msc_attr::kName)); } RelayExprVisitor::VisitExpr_(op); if (HasFuncScope()) { @@ -872,8 +994,8 @@ Map RelayWeightsExtractor::GetWeights(const relay::Function& } void RelayWeightsExtractor::VisitExpr_(const relay::ConstantNode* op) { - const auto& name = SpanUtils::GetAttr(op->span, "name"); - const auto& layout = SpanUtils::GetAttr(op->span, "layout"); + const auto& name = SpanUtils::GetAttr(op->span, msc_attr::kName); + const auto& layout = SpanUtils::GetAttr(op->span, msc_attr::kLayout); const auto& t_info = op->tensor_type(); const auto& shape = ArrayUtils::Cast(t_info->shape); const auto& weight = MSCTensor(name, t_info->dtype, layout, shape); diff --git a/src/contrib/msc/core/ir/graph_builder.h b/src/contrib/msc/core/ir/graph_builder.h index 6edd77a2e4a9..4b042c5617e4 100644 --- a/src/contrib/msc/core/ir/graph_builder.h +++ b/src/contrib/msc/core/ir/graph_builder.h @@ -35,12 +35,14 @@ #include #include #include +#include #include #include #include #include "../utils.h" #include "graph.h" +#include "plugin.h" namespace tvm { namespace contrib { @@ -150,7 +152,7 @@ class RelaxFuncAttrGetter : public RelaxExprVisitor { public: /*! \brief Get the attributes as Map*/ Map GetAttrs(const Expr& expr) { - RelaxExprVisitor::VisitExpr(expr); + VisitExpr(expr); return attrs_; } @@ -166,7 +168,7 @@ class RelaxFuncValueGetter : public RelaxExprVisitor { public: /*! \brief Get the attributes from prim value as Map*/ Array GetValues(const Expr& expr) { - RelaxExprVisitor::VisitExpr(expr); + VisitExpr(expr); return values_; } @@ -179,7 +181,7 @@ class RelaxFuncValueGetter : public RelaxExprVisitor { class RelaxFuncParamsFinder : public RelaxExprVisitor { public: /*! - * \brief The constructor of RelaxGraphBuilder + * \brief The constructor of RelaxFuncParamsFinder * \param ref_module the reference module. */ explicit RelaxFuncParamsFinder(const IRModule& ref_module) : RelaxExprVisitor() { @@ -188,7 +190,7 @@ class RelaxFuncParamsFinder : public RelaxExprVisitor { /*! \brief Find the func params and bind with arguments*/ Map FindParams(const Expr& expr) { - RelaxExprVisitor::VisitExpr(expr); + VisitExpr(expr); return params_; } @@ -202,6 +204,32 @@ class RelaxFuncParamsFinder : public RelaxExprVisitor { Map local_funcs_; }; +class RelaxLayoutsFinder : public RelaxExprVisitor { + public: + /*! + * \brief The constructor of RelaxLayoutsFinder + * \param ref_module the reference module. + */ + explicit RelaxLayoutsFinder(const IRModule& ref_module) : RelaxExprVisitor() { + ref_module_ = ref_module; + } + + /*! \brief Find the layouts form attrs*/ + Map FindLayouts(const Expr& expr) { + VisitExpr(expr); + return layouts_; + } + + void VisitBinding_(const relax::VarBindingNode* binding, const relax::FunctionNode* val) final; + + void VisitExpr_(const relax::CallNode* op) final; + + private: + IRModule ref_module_; + Map layouts_; + Map local_funcs_; +}; + class RelaxGraphBuilder : public RelaxExprVisitor { public: /*! @@ -223,6 +251,7 @@ class RelaxGraphBuilder : public RelaxExprVisitor { if (config_.byoc_entry.size() > 0) { func_params_ = RelaxFuncParamsFinder(ref_module).FindParams(ref_module->Lookup(name)); } + layouts_ = RelaxLayoutsFinder(ref_module).FindLayouts(ref_module->Lookup(name)); } /*! \brief Build MSCGraph from relax function*/ @@ -257,15 +286,25 @@ class RelaxGraphBuilder : public RelaxExprVisitor { void VisitBinding_(const relax::VarBindingNode* binding, const relax::FunctionNode* val) final; private: + /*! \brief Get the node_name, optype, layout for func*/ + const std::tuple ParseFunc(const relax::Function& func); + + /*! \brief Get the plugin inputs*/ + Array GetPluginInputs(const relax::Expr& expr); + String name_; - String scope_name_; IRModule ref_module_; MSCRBuildConfig config_; + Map layouts_; Array nodes_; Map weights_; Map> expr_tensor_map_; std::unordered_map> tensor_input_map_; std::set ignore_nodes_; + // scope name + String scope_name_; + std::set setted_blocks_; + Array block_stack_; // BYOC maps Map target_funcs_; Map func_params_; diff --git a/src/contrib/msc/core/transform/fuse_tuple.cc b/src/contrib/msc/core/transform/fuse_tuple.cc index 37baff4fa7ba..be1a10718c98 100644 --- a/src/contrib/msc/core/transform/fuse_tuple.cc +++ b/src/contrib/msc/core/transform/fuse_tuple.cc @@ -67,20 +67,29 @@ class TupleFuser : public ExprMutator { } void VisitBinding_(const VarBindingNode* binding, const CallNode* val) final { - bool is_tuple_call = false; + bool has_tuple_arg = false; if (target_funcs_.count(val->op)) { - if (val->args.size() == 1 && val->args[0]->IsInstance()) { - const auto& func_call = AddFunc(val->args[0]); - const auto& tuple_out = builder_->Emit(func_call); - ICHECK(target_funcs_.count(func_call->op)) << "Can not find target func " << func_call->op; - target_funcs_.Set(tuple_out, target_funcs_[func_call->op]); - const auto& new_call = Call(val->op, {tuple_out}, val->attrs, val->sinfo_args, val->span); - ReEmitBinding(binding, builder_->Normalize(new_call)); - is_tuple_call = true; + Array new_args; + for (const auto& arg : val->args) { + if (arg->IsInstance()) { + const auto& func_call = AddFunc(arg); + const auto& tuple_out = builder_->Emit(func_call); + ICHECK(target_funcs_.count(func_call->op)) + << "Can not find target func " << func_call->op; + target_funcs_.Set(tuple_out, target_funcs_[func_call->op]); + has_tuple_arg = true; + new_args.push_back(tuple_out); + } else { + new_args.push_back(arg); + } + if (has_tuple_arg) { + const auto& new_call = Call(val->op, new_args, val->attrs, val->sinfo_args, val->span); + ReEmitBinding(binding, builder_->Normalize(new_call)); + } } target_funcs_.Set(binding->var, target_funcs_[val->op]); } - if (!is_tuple_call) { + if (!has_tuple_arg) { ExprMutator::VisitBinding_(binding, val); } } @@ -150,10 +159,11 @@ class TupleFuser : public ExprMutator { BindingBlock new_block = builder_->EndBlock(); Expr body = builder_->Normalize(output); body = builder_->Normalize(SeqExpr({new_block}, body)); + Map func_attrs; - func_attrs.Set(tvm::relax::attr::kComposite, target_ + func_name); - func_attrs.Set(tvm::relax::attr::kPrimitive, Integer(1)); - func_attrs.Set("unique_name", SpanUtils::GetAttr(expr->span, "name")); + func_attrs.Set(attr::kPrimitive, Integer(1)); + func_attrs.Set(attr::kComposite, target_ + func_name); + func_attrs.Set(msc_attr::kUnique, SpanUtils::GetAttr(expr->span, msc_attr::kName)); Function function = Function(/*params=*/params, // /*body=*/body, // diff --git a/src/contrib/msc/core/transform/inline_params.cc b/src/contrib/msc/core/transform/inline_params.cc new file mode 100644 index 000000000000..5e5ac113ef56 --- /dev/null +++ b/src/contrib/msc/core/transform/inline_params.cc @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/contrib/msc/core/transform/inline_params.cc + * \brief Pass for inline Exprs. + */ + +#include +#include +#include + +#include "../../../../relax/transform/utils.h" +#include "../utils.h" + +namespace tvm { +namespace relax { + +using namespace tvm::contrib::msc; + +/*! + * \brief Inline the exprs + */ +class ParamsInliner : public ExprMutator { + public: + explicit ParamsInliner(IRModule ctx_module, const String& entry_name) : ExprMutator(ctx_module) { + mod_ = ctx_module; + entry_name_ = entry_name; + } + + IRModule Bind() { + // update global functions + GlobalVar main_var; + for (const auto& [gv, func] : mod_->functions) { + if (gv->name_hint == entry_name_) { + main_var = gv; + continue; + } + if (func->IsInstance()) { + Array new_params; + Array attrs; + for (const auto& p : Downcast(func)->params) { + auto struct_info = GetStructInfo(p); + if (struct_info->IsInstance()) { + continue; + } + if (struct_info->IsInstance()) { + const auto& optype_opt = func->GetAttr(msc_attr::kOptype); + ICHECK(optype_opt.defined()) + << "Can not find attr " << msc_attr::kOptype << " form extern func"; + extern_types_.Set(p, optype_opt.value()); + continue; + } + if (const auto* tuple_info = struct_info.as()) { + Array new_fields; + for (const auto& i : tuple_info->fields) { + if (i->IsInstance()) { + new_fields.push_back(i); + } else if (const auto& p_info = i.as()) { + ICHECK(p_info->value.defined()) << "PrimStructInfo with undefined prim value " << i; + attrs.push_back(StringUtils::ToString(p_info->value.value())); + } + } + if (new_fields.size() < tuple_info->fields.size()) { + p->struct_info_ = TupleStructInfo(new_fields, tuple_info->span); + } + } + new_params.push_back(p); + } + if (new_params.size() == Downcast(func)->params.size()) { + continue; + } + const auto& new_func = Downcast(VisitExpr(func)); + Map func_attrs = new_func->attrs->dict; + if (attrs.size() > 0) { + func_attrs.Set(msc_attr::kOpattrs, attrs); + } + auto updated_func = Function(new_params, new_func->body, new_func->ret_struct_info, + new_func->is_pure, DictAttrs(func_attrs), new_func->span); + builder_->UpdateFunction(gv, updated_func); + } + } + // update main + ICHECK(main_var.defined()) << "Can not find entry func " << entry_name_; + const auto& new_func = Downcast(VisitExpr(mod_->Lookup(entry_name_))); + builder_->UpdateFunction(main_var, new_func); + return builder_->GetContextIRModule(); + } + + void VisitBinding_(const VarBindingNode* binding, const CallNode* call_node) final { + Array new_args; + bool has_inline = false; + for (const auto& a : call_node->args) { + auto struct_info = GetStructInfo(a); + if (a->IsInstance() && struct_info->IsInstance()) { + ICHECK(extern_types_.count(a)) << "Can not find extern type of " << a; + new_args.push_back(ExternFunc(extern_types_[a])); + has_inline = true; + } else if (call_node->op->IsInstance() && a->IsInstance()) { + has_inline = true; + } else if (a->IsInstance() && struct_info->IsInstance()) { + const auto& shape_opt = Downcast(GetStructInfo(a))->values; + ICHECK(shape_opt.defined()) << "Expected shape defined, get " << a; + new_args.push_back(ShapeExpr(shape_opt.value())); + has_inline = true; + } else if (call_node->op->IsInstance() && a->IsInstance()) { + has_inline = true; + } else if (call_node->op->IsInstance() && a->IsInstance()) { + const auto& tuple = Downcast(a); + Array new_fields; + Array new_infos; + + for (const auto& f : tuple->fields) { + if (f->IsInstance()) { + new_fields.push_back(f); + new_infos.push_back(GetStructInfo(f)); + } + } + if (new_fields.size() == tuple->fields.size()) { + new_args.push_back(a); + } else { + const auto& new_tuple = Tuple(new_fields, tuple->span); + new_tuple->struct_info_ = TupleStructInfo(new_infos); + new_args.push_back(new_tuple); + } + } else { + new_args.push_back(a); + } + } + if (!has_inline) { + ExprMutator::VisitBinding_(binding, call_node); + } else if (call_node->op->IsInstance()) { + const auto& new_call = + Call(call_node->op, new_args, call_node->attrs, call_node->sinfo_args, call_node->span); + ReEmitBinding(binding, builder_->Normalize(new_call)); + } else if (const auto* gv_node = call_node->op.as()) { + const auto& func_info = Downcast(gv_node->struct_info_); + Array params_info; + for (const auto& a : new_args) { + ICHECK(a->struct_info_.defined()) + << "Global func argument without defined struct info " << a; + params_info.push_back(Downcast(a->struct_info_.value())); + } + call_node->op->struct_info_ = + FuncStructInfo(params_info, func_info->ret, func_info->purity, func_info->span); + const auto& new_call = + Call(call_node->op, new_args, call_node->attrs, call_node->sinfo_args, call_node->span); + ReEmitBinding(binding, builder_->Normalize(new_call)); + } else { + LOG_FATAL << "Unexpected shape consumer " << GetRef(call_node); + } + } + + private: + IRModule mod_; + String entry_name_; + Map extern_types_; +}; + +IRModule InlineParams(IRModule mod, const String& entry_name) { + return ParamsInliner(mod, entry_name).Bind(); +} + +namespace transform { + +Pass InlineParams(const String& entry_name) { + runtime::TypedPackedFunc pass_func = + [=](IRModule m, PassContext pc) { return relax::InlineParams(m, entry_name); }; + return CreateModulePass(pass_func, 0, "InlineParams", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.InlineParams").set_body_typed(InlineParams); + +} // namespace transform +} // namespace relax +} // namespace tvm diff --git a/src/contrib/msc/core/transform/layout_utils.cc b/src/contrib/msc/core/transform/layout_utils.cc index 4c9eeb2ec176..317a39ab4e1a 100644 --- a/src/contrib/msc/core/transform/layout_utils.cc +++ b/src/contrib/msc/core/transform/layout_utils.cc @@ -57,15 +57,17 @@ LayoutDecision LayoutUtils::InferLayoutDecisionAt(const Expr& expr, } bool LayoutUtils::LayoutInfered(const Expr& expr) { - const String& layout = SpanUtils::GetAttr(expr->span, "layout"); + const String& layout = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); return layout.size() > 0; } bool LayoutUtils::SetLayout(const Expr& expr, const NLayout& layout) { - const String& saved_layout = SpanUtils::GetAttr(expr->span, "layout"); + const String& saved_layout = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); const auto& sinfo = GetStructInfo(expr); if (sinfo->IsInstance() || sinfo->IsInstance()) { - ICHECK(layout.IsLeaf()) << "Expr has tensor struct, but find nested layout " << expr; + if (!layout.IsLeaf()) { + return false; + } const auto& l_layout = layout.LeafValue()->layout; if (!l_layout.defined()) { return false; @@ -73,14 +75,17 @@ bool LayoutUtils::SetLayout(const Expr& expr, const NLayout& layout) { if (saved_layout == l_layout.name()) { return false; } - expr->span = SpanUtils::SetAttr(expr->span, "layout", l_layout.name()); + expr->span = SpanUtils::SetAttr(expr->span, msc_attr::kLayout, l_layout.name()); } else if (sinfo->IsInstance()) { - ICHECK(!layout.IsLeaf()) << "Expr has tuple struct, but find non-nested layout " << expr; + if (layout.IsLeaf()) { + return false; + } String layout_str; Array nested_layouts = layout.NestedArray(); for (size_t i = 0; i < nested_layouts.size(); i++) { - ICHECK(nested_layouts[i].IsLeaf()) - << "Expr input[" << i << "] has tensor struct, but find nested layout " << expr; + if (!nested_layouts[i].IsLeaf()) { + return false; + } const auto& l_layout = nested_layouts[i].LeafValue()->layout; if (!l_layout.defined()) { return false; @@ -90,7 +95,7 @@ bool LayoutUtils::SetLayout(const Expr& expr, const NLayout& layout) { if (saved_layout == layout_str) { return false; } - expr->span = SpanUtils::SetAttr(expr->span, "layout", layout_str); + expr->span = SpanUtils::SetAttr(expr->span, msc_attr::kLayout, layout_str); } return true; } @@ -101,10 +106,10 @@ const NLayout LayoutUtils::GetNLayout(const Expr& expr) { } auto sinfo = GetStructInfo(expr); if (sinfo->IsInstance()) { - return LayoutDecision(SpanUtils::GetAttr(expr->span, "layout")); + return LayoutDecision(SpanUtils::GetAttr(expr->span, msc_attr::kLayout)); } if (sinfo->IsInstance()) { - String layout_str = SpanUtils::GetAttr(expr->span, "layout"); + String layout_str = SpanUtils::GetAttr(expr->span, msc_attr::kLayout); std::vector output_layout; for (const auto& l : StringUtils::Split(layout_str, ",")) { output_layout.push_back(LayoutDecision(l)); diff --git a/src/contrib/msc/core/transform/set_byoc_attrs.cc b/src/contrib/msc/core/transform/set_byoc_attrs.cc index 4fa8ab584e3c..2ed3e35e81de 100644 --- a/src/contrib/msc/core/transform/set_byoc_attrs.cc +++ b/src/contrib/msc/core/transform/set_byoc_attrs.cc @@ -47,34 +47,49 @@ class ByocNameSetter : public ExprMutator { entry_name_ = entry_name; } - IRModule SetAttrs() { - GlobalVar main_var; + IRModule SetNames() { size_t func_cnt = 0; for (const auto& [gv, func] : mod_->functions) { if (gv->name_hint == entry_name_) { - main_var = gv; - } else { - const auto& name_opt = func->GetAttr(attr::kCodegen); - if (name_opt.defined() && name_opt.value() == target_) { - const auto& new_func = WithAttr(Downcast(func), "byoc_name", - target_ + "_" + std::to_string(func_cnt)); - builder_->UpdateFunction(gv, new_func); - func_cnt += 1; - } + continue; + } + const auto& name_opt = func->GetAttr(attr::kCodegen); + if (name_opt.defined() && name_opt.value() == target_) { + const String& func_name = target_ + "_" + std::to_string(func_cnt); + const auto& new_func = Downcast(VisitExpr(func)); + builder_->UpdateFunction(gv, WithAttr(new_func, msc_attr::kUnique, func_name)); + func_cnt += 1; } } return builder_->GetContextIRModule(); } + void VisitBinding_(const VarBindingNode* binding, const FunctionNode* val) final { + local_funcs_.Set(binding->var, GetRef(val)); + ExprMutator::VisitBinding_(binding, val); + } + + void VisitBinding_(const VarBindingNode* binding, const CallNode* val) final { + ExprMutator::VisitBinding_(binding, val); + if (val->op->IsInstance()) { + ICHECK(local_funcs_.count(val->op)) << "Can not find local func " << val->op; + const auto& name_opt = local_funcs_[val->op]->GetAttr(msc_attr::kUnique); + if (name_opt.defined()) { + val->span = SpanUtils::SetAttr(val->span, "name", name_opt.value()); + } + } + } + private: IRModule mod_; String target_; String entry_name_; Map new_funcs_; + Map local_funcs_; }; IRModule SetBYOCAttrs(IRModule mod, const String& target, const String& entry_name) { - return ByocNameSetter(mod, target, entry_name).SetAttrs(); + return ByocNameSetter(mod, target, entry_name).SetNames(); } namespace transform { diff --git a/src/contrib/msc/core/transform/set_expr_name.cc b/src/contrib/msc/core/transform/set_expr_name.cc index 97850c70e8e8..dfed1a242a50 100644 --- a/src/contrib/msc/core/transform/set_expr_name.cc +++ b/src/contrib/msc/core/transform/set_expr_name.cc @@ -49,7 +49,7 @@ class FuncNameGetter : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const CallNode* val) { if (name_.size() == 0) { - name_ = SpanUtils::GetAttr(val->span, "name"); + name_ = SpanUtils::GetAttr(val->span, msc_attr::kName); } if (name_.size() == 0) { ExprVisitor::VisitBinding_(binding, val); @@ -58,7 +58,7 @@ class FuncNameGetter : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const TupleNode* val) { if (name_.size() == 0) { - name_ = SpanUtils::GetAttr(val->span, "name"); + name_ = SpanUtils::GetAttr(val->span, msc_attr::kName); } if (name_.size() == 0) { ExprVisitor::VisitBinding_(binding, val); @@ -88,21 +88,22 @@ class RelaxExprNameSetter : public ExprVisitor { : ref_module_(ref_module), target_{target} {} void VisitBindingBlock(const BindingBlock& block) final { - String block_name = SpanUtils::GetAttr(block->span, "name"); + String block_name = SpanUtils::GetAttr(block->span, msc_attr::kName); if (block_name.size() == 0) { block_name = "block"; } - if (setted_blocks_.count(block_name)) { + const String& prefix = StringUtils::Join(block_stack_, "."); + if (setted_blocks_.count(prefix + "." + block_name)) { int cnt = 1; - while (setted_blocks_.count(block_name + "_" + std::to_string(cnt))) { + while (setted_blocks_.count(prefix + "." + block_name + "_" + std::to_string(cnt))) { cnt++; } block_name = block_name + "_" + std::to_string(cnt); } - setted_blocks_.insert(block_name); + setted_blocks_.insert(prefix + "." + block_name); block_stack_.push_back(block_name); const String& unique_name = StringUtils::Join(block_stack_, "."); - block->span = SpanUtils::SetAttr(block->span, "name", unique_name); + block->span = SpanUtils::SetAttr(block->span, msc_attr::kName, unique_name); ExprVisitor::VisitBindingBlock(block); block_stack_.pop_back(); } @@ -110,8 +111,8 @@ class RelaxExprNameSetter : public ExprVisitor { void VisitExpr_(const ConstantNode* val) { ExprVisitor::VisitExpr_(val); const String& unique_name = GetUniqueName(GetRef(val), "const"); - if (unique_name != SpanUtils::GetAttr(val->span, "name")) { - val->span = SpanUtils::SetAttr(val->span, "name", unique_name); + if (unique_name != SpanUtils::GetAttr(val->span, msc_attr::kName)) { + val->span = SpanUtils::SetAttr(val->span, msc_attr::kName, unique_name); } expr_names_.Set(GetRef(val), unique_name); } @@ -119,8 +120,8 @@ class RelaxExprNameSetter : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const ConstantNode* val) { ExprVisitor::VisitBinding_(binding, val); const String& unique_name = GetUniqueName(GetRef(val), "const"); - if (unique_name != SpanUtils::GetAttr(val->span, "name")) { - val->span = SpanUtils::SetAttr(val->span, "name", unique_name); + if (unique_name != SpanUtils::GetAttr(val->span, msc_attr::kName)) { + val->span = SpanUtils::SetAttr(val->span, msc_attr::kName, unique_name); } expr_names_.Set(binding->var, unique_name); } @@ -128,8 +129,8 @@ class RelaxExprNameSetter : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const ShapeExprNode* val) { ExprVisitor::VisitBinding_(binding, val); const String& unique_name = GetUniqueName(GetRef(val), "shape"); - if (unique_name != SpanUtils::GetAttr(val->span, "name")) { - val->span = SpanUtils::SetAttr(val->span, "name", unique_name); + if (unique_name != SpanUtils::GetAttr(val->span, msc_attr::kName)) { + val->span = SpanUtils::SetAttr(val->span, msc_attr::kName, unique_name); } expr_names_.Set(binding->var, unique_name); } @@ -137,8 +138,8 @@ class RelaxExprNameSetter : public ExprVisitor { void VisitBinding_(const VarBindingNode* binding, const TupleNode* val) { ExprVisitor::VisitBinding_(binding, val); const String& unique_name = GetUniqueName(GetRef(val), "tuple"); - if (unique_name != SpanUtils::GetAttr(val->span, "name")) { - val->span = SpanUtils::SetAttr(val->span, "name", unique_name); + if (unique_name != SpanUtils::GetAttr(val->span, msc_attr::kName)) { + val->span = SpanUtils::SetAttr(val->span, msc_attr::kName, unique_name); } expr_names_.Set(binding->var, unique_name); } @@ -151,8 +152,8 @@ class RelaxExprNameSetter : public ExprVisitor { } else if (const auto* v_node = val->tuple.as()) { unique_name = v_node->name_hint() + "." + std::to_string(val->index); } - if (unique_name != SpanUtils::GetAttr(val->span, "name")) { - val->span = SpanUtils::SetAttr(val->span, "name", unique_name); + if (unique_name != SpanUtils::GetAttr(val->span, msc_attr::kName)) { + val->span = SpanUtils::SetAttr(val->span, msc_attr::kName, unique_name); } expr_names_.Set(binding->var, unique_name); } @@ -171,9 +172,19 @@ class RelaxExprNameSetter : public ExprVisitor { bool use_unique = true; if (const auto* op_node = val->op.as()) { const std::string& op_name = op_node->name; - int rpos = op_name.rfind("."); - name_hint = op_name.substr(rpos + 1); - optype = StringUtils::Replace(op_node->name, "relax.", ""); + if (op_name == "relax.call_dps_packed" && val->args[0]->IsInstance()) { + const auto& func = Downcast(val->args[0]); + name_hint = func->global_symbol; + optype = func->global_symbol; + const String& input_name = GetUniqueName(val->args[1], "plugin_inputs"); + if (input_name != SpanUtils::GetAttr(val->args[1]->span, msc_attr::kName)) { + val->args[1]->span = SpanUtils::SetAttr(val->args[1]->span, msc_attr::kName, input_name); + } + } else { + int rpos = op_name.rfind("."); + name_hint = op_name.substr(rpos + 1); + optype = StringUtils::Replace(op_node->name, "relax.", ""); + } } else if (const auto* v_node = val->op.as()) { const auto& func = Downcast(ref_module_->Lookup(v_node->name_hint)); ExprVisitor::VisitExpr(func); @@ -190,8 +201,8 @@ class RelaxExprNameSetter : public ExprVisitor { // set name const String& unique_name = use_unique ? GetUniqueName(GetRef(val), name_hint) : name_hint; - if (unique_name != SpanUtils::GetAttr(val->span, "name")) { - val->span = SpanUtils::SetAttr(val->span, "name", unique_name); + if (unique_name != SpanUtils::GetAttr(val->span, msc_attr::kName)) { + val->span = SpanUtils::SetAttr(val->span, msc_attr::kName, unique_name); } // set constant consumer && shared_ref Array input_types; @@ -207,10 +218,10 @@ class RelaxExprNameSetter : public ExprVisitor { continue; } if (const auto* c_node = val->args[i].as()) { - const String& const_name = SpanUtils::GetAttr(c_node->span, "name"); + const String& const_name = SpanUtils::GetAttr(c_node->span, msc_attr::kName); if (constant_consumers_.count(const_name)) { - val->span = - SpanUtils::SetAttr(val->span, "shared_ref", constant_consumers_[const_name]); + val->span = SpanUtils::SetAttr(val->span, msc_attr::kSharedRef, + constant_consumers_[const_name]); } else { constant_consumers_.Set(const_name, unique_name); } @@ -222,7 +233,7 @@ class RelaxExprNameSetter : public ExprVisitor { private: const String GetUniqueName(const Expr& expr, const String& name_hint) { - String expr_name = SpanUtils::GetAttr(expr->span, "name"); + String expr_name = SpanUtils::GetAttr(expr->span, msc_attr::kName); if (expr_name.size() == 0) { expr_name = name_hint; } @@ -264,15 +275,8 @@ class RelaxExprNameSetter : public ExprVisitor { const String GetFuncName(const Call& call, const Function& func) { String name; - // get from byoc_name - if (target_.size() > 0) { - const auto& byoc_name_opt = func->GetAttr("byoc_name"); - if (byoc_name_opt.defined()) { - return byoc_name_opt.value(); - } - } - // get from attribute - const auto& name_opt = func->GetAttr("unique_name"); + // get from unique + const auto& name_opt = func->GetAttr(msc_attr::kUnique); if (name_opt.defined()) { return name_opt.value(); } @@ -336,25 +340,25 @@ class RelayExprNameSetter : public ExprVisitor { void VisitExpr_(const ConstantNode* op) final { ExprVisitor::VisitExpr_(op); const String& unique_name = GetUniqueName(GetRef(op), "const"); - if (unique_name != SpanUtils::GetAttr(op->span, "name")) { - op->span = SpanUtils::SetAttr(op->span, "name", unique_name); + if (unique_name != SpanUtils::GetAttr(op->span, msc_attr::kName)) { + op->span = SpanUtils::SetAttr(op->span, msc_attr::kName, unique_name); } } void VisitExpr_(const TupleNode* op) final { ExprVisitor::VisitExpr_(op); const String& unique_name = GetUniqueName(GetRef(op), "tuple"); - if (unique_name != SpanUtils::GetAttr(op->span, "name")) { - op->span = SpanUtils::SetAttr(op->span, "name", unique_name); + if (unique_name != SpanUtils::GetAttr(op->span, msc_attr::kName)) { + op->span = SpanUtils::SetAttr(op->span, msc_attr::kName, unique_name); } } void VisitExpr_(const TupleGetItemNode* op) final { ExprVisitor::VisitExpr_(op); - const String& tuple_name = SpanUtils::GetAttr(op->tuple->span, "name"); + const String& tuple_name = SpanUtils::GetAttr(op->tuple->span, msc_attr::kName); const String& unique_name = tuple_name + "." + std::to_string(op->index); - if (unique_name != SpanUtils::GetAttr(op->span, "name")) { - op->span = SpanUtils::SetAttr(op->span, "name", unique_name); + if (unique_name != SpanUtils::GetAttr(op->span, msc_attr::kName)) { + op->span = SpanUtils::SetAttr(op->span, msc_attr::kName, unique_name); } } @@ -363,8 +367,8 @@ class RelayExprNameSetter : public ExprVisitor { const auto& name_opt = op->GetAttr(attr::kComposite); const String& name_hint = name_opt.defined() ? name_opt.value() : "func"; const String& unique_name = GetUniqueName(GetRef(op), name_hint); - if (unique_name != SpanUtils::GetAttr(op->span, "name")) { - op->span = SpanUtils::SetAttr(op->span, "name", unique_name); + if (unique_name != SpanUtils::GetAttr(op->span, msc_attr::kName)) { + op->span = SpanUtils::SetAttr(op->span, msc_attr::kName, unique_name); } } @@ -391,8 +395,8 @@ class RelayExprNameSetter : public ExprVisitor { if (name_hint.size() > 0) { // set name const String& unique_name = GetUniqueName(GetRef(op), name_hint); - if (unique_name != SpanUtils::GetAttr(op->span, "name")) { - op->span = SpanUtils::SetAttr(op->span, "name", unique_name); + if (unique_name != SpanUtils::GetAttr(op->span, msc_attr::kName)) { + op->span = SpanUtils::SetAttr(op->span, msc_attr::kName, unique_name); } // set constant consumer && shared_ref Array input_types; @@ -408,9 +412,10 @@ class RelayExprNameSetter : public ExprVisitor { continue; } if (const auto* c_node = op->args[i].as()) { - const String& const_name = SpanUtils::GetAttr(c_node->span, "name"); + const String& const_name = SpanUtils::GetAttr(c_node->span, msc_attr::kName); if (constant_consumers_.count(const_name)) { - op->span = SpanUtils::SetAttr(op->span, "shared_ref", constant_consumers_[const_name]); + op->span = + SpanUtils::SetAttr(op->span, msc_attr::kSharedRef, constant_consumers_[const_name]); } else { constant_consumers_.Set(const_name, unique_name); } @@ -421,7 +426,7 @@ class RelayExprNameSetter : public ExprVisitor { private: const String GetUniqueName(const Expr& expr, const String& name_hint) { - String expr_name = SpanUtils::GetAttr(expr->span, "name"); + String expr_name = SpanUtils::GetAttr(expr->span, msc_attr::kName); if (expr_name.size() == 0) { expr_name = name_hint; } @@ -503,7 +508,7 @@ class RelayExprNameBinder : public ExprVisitor { valid_name = valid_name + "_" + std::to_string(cnt); } setted_names_.Set(valid_name, expr); - expr->span = SpanUtils::SetAttr(expr->span, "name", valid_name); + expr->span = SpanUtils::SetAttr(expr->span, msc_attr::kName, valid_name); } } diff --git a/src/contrib/msc/framework/tensorflow/codegen_utils.h b/src/contrib/msc/framework/tensorflow/codegen_utils.h index 85a37de05850..4c250f10609e 100644 --- a/src/contrib/msc/framework/tensorflow/codegen_utils.h +++ b/src/contrib/msc/framework/tensorflow/codegen_utils.h @@ -42,17 +42,12 @@ class TFV1CodeGenHelper : public BaseCodeGenHelper {}; * \brief CodeGen config for tensorflow codegen */ struct TensorflowCodeGenConfig { - bool is_training{false}; CODEGEN_CONFIG_MEMBERS void Load(dmlc::JSONReader* reader) { std::string key; reader->BeginObject(); while (reader->NextObjectItem(&key)) { - if (key == "is_training") { - reader->Read(&is_training); - } else { - CODEGEN_CONFIG_PARSE - } + CODEGEN_CONFIG_PARSE } } }; diff --git a/src/contrib/msc/framework/tensorrt/codegen.cc b/src/contrib/msc/framework/tensorrt/codegen.cc index c3d659015c18..13c231092a8f 100644 --- a/src/contrib/msc/framework/tensorrt/codegen.cc +++ b/src/contrib/msc/framework/tensorrt/codegen.cc @@ -576,12 +576,12 @@ Array MSCTensorRTCompiler(Array functions, Array compiled_functions; for (const auto& func : functions) { VLOG(1) << "MSC.TensorRT partition:" << std::endl << func; - const auto& byoc_name_opt = func->GetAttr("byoc_name"); - ICHECK(byoc_name_opt.defined()) << "Can not find byoc_name from attrs"; - const auto& byoc_name = byoc_name_opt.value(); + const auto& name_opt = func->GetAttr(msc_attr::kUnique); + ICHECK(name_opt.defined()) << "Can not find " << msc_attr::kUnique << " from attrs"; + const auto& name = name_opt.value(); std::string func_name = GetExtSymbol(func); - ICHECK(target_option.count(byoc_name)) << "Can not find target option for " << byoc_name; - const auto& options = Downcast(target_option[byoc_name]); + ICHECK(target_option.count(name)) << "Can not find target option for " << name; + const auto& options = Downcast(target_option[name]); MSCJSONSerializer serializer(constant_names, options); serializer.serialize(func); std::string graph_json = serializer.GetJSON(); diff --git a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc index ca01d5fbea3c..c71cb605013f 100644 --- a/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc +++ b/src/contrib/msc/framework/tensorrt/transform_tensorrt.cc @@ -41,8 +41,8 @@ const Array GetShape(const Expr& var) { } Var EmitCall(BlockBuilder builder, const Expr& expr, const Span& src_span, const String& suffix) { - const auto& name = SpanUtils::GetAttr(src_span, "name") + "_" + suffix; - expr->span = SpanUtils::SetAttr(expr->span, "name", name); + const auto& name = SpanUtils::GetAttr(src_span, msc_attr::kName) + "_" + suffix; + expr->span = SpanUtils::SetAttr(expr->span, msc_attr::kName, name); return builder->Emit(expr, name); } @@ -54,7 +54,7 @@ Var MakeCall(BlockBuilder builder, const Span& src_span, const String& suffix, E Expr MakeConstant(double value, const DataType& dtype, const String& name) { const auto& data = support::FloatImmToNDArray(FloatImm(dtype, value)); - const auto& span = SpanUtils::SetAttr(Span(), "name", name); + const auto& span = SpanUtils::SetAttr(Span(), msc_attr::kName, name); return Constant(data, NullOpt, span); } @@ -228,14 +228,15 @@ Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call Expr p_scale; if (src_attrs->scale.defined()) { const auto& scale = MakeConstant(static_cast(src_attrs->scale.value()->value), in_dtype, - SpanUtils::GetAttr(call->span, "name") + "_scale"); + SpanUtils::GetAttr(call->span, msc_attr::kName) + "_scale"); Array exp_shape(3, Integer(1)); const auto& exp_scale = MakeCall(builder, call->span, "exp_scale", reshape_op, {scale, ShapeExpr(exp_shape)}); p_scale = MakeCall(builder, call->span, "p_scale", multiply_op, {qk_prod, exp_scale}); } else { - const auto& scale = MakeConstant(static_cast(Downcast(head_dim)->value), - in_dtype, SpanUtils::GetAttr(call->span, "name") + "_scale"); + const auto& scale = + MakeConstant(static_cast(Downcast(head_dim)->value), in_dtype, + SpanUtils::GetAttr(call->span, msc_attr::kName) + "_scale"); Array exp_shape(3, Integer(1)); const auto& exp_scale = MakeCall(builder, call->span, "exp_scale", reshape_op, {scale, ShapeExpr(exp_shape)}); @@ -305,8 +306,8 @@ Expr RewriteBatchNorm(BlockBuilder builder, const Var& var, const Call& src_call exp_shape.Set(src_attrs->axis, input_shape[src_attrs->axis]); // create eps constant - const auto& eps = - MakeConstant(src_attrs->epsilon, in_dtype, SpanUtils::GetAttr(call->span, "name") + "_eps"); + const auto& eps = MakeConstant(src_attrs->epsilon, in_dtype, + SpanUtils::GetAttr(call->span, msc_attr::kName) + "_eps"); // create ops static const Op& add_op = Op::Get("relax.add"); @@ -419,8 +420,8 @@ Expr RewriteGroupNorm(BlockBuilder builder, const Var& var, const Call& src_call exp_shape.Set(axis, Integer(src_attrs->num_groups)); // create eps constant - const auto& eps = - MakeConstant(src_attrs->epsilon, in_dtype, SpanUtils::GetAttr(call->span, "name") + "_eps"); + const auto& eps = MakeConstant(src_attrs->epsilon, in_dtype, + SpanUtils::GetAttr(call->span, msc_attr::kName) + "_eps"); // create ops static const Op& add_op = Op::Get("relax.add"); @@ -487,8 +488,8 @@ Expr RewriteLayerNorm(BlockBuilder builder, const Var& var, const Call& src_call exp_shape.Set(index, input_shape[index]); } // create eps constant - const auto& eps = - MakeConstant(src_attrs->epsilon, in_dtype, SpanUtils::GetAttr(call->span, "name") + "_eps"); + const auto& eps = MakeConstant(src_attrs->epsilon, in_dtype, + SpanUtils::GetAttr(call->span, msc_attr::kName) + "_eps"); // create ops static const Op& add_op = Op::Get("relax.add"); @@ -578,7 +579,8 @@ Expr RewriteRsqrt(BlockBuilder builder, const Var& var, const Call& src_call, const auto& in_dtype = Downcast(GetStructInfo(call->args[0]))->dtype; Array exp_shape(input_shape.size(), Integer(1)); // create 1 constant - const auto& one = MakeConstant(1, in_dtype, SpanUtils::GetAttr(call->span, "name") + "_one"); + const auto& one = + MakeConstant(1, in_dtype, SpanUtils::GetAttr(call->span, msc_attr::kName) + "_one"); // create ops static const Op& reshape_op = Op::Get("relax.reshape"); diff --git a/src/contrib/msc/framework/torch/codegen_utils.h b/src/contrib/msc/framework/torch/codegen_utils.h index b80ea51f1528..c63de27519e0 100644 --- a/src/contrib/msc/framework/torch/codegen_utils.h +++ b/src/contrib/msc/framework/torch/codegen_utils.h @@ -53,17 +53,12 @@ class TorchCodeGenHelper : public BaseCodeGenHelper { * \brief CodeGen config for torch codegen */ struct TorchCodeGenConfig { - bool is_training{false}; CODEGEN_CONFIG_MEMBERS void Load(dmlc::JSONReader* reader) { std::string key; reader->BeginObject(); while (reader->NextObjectItem(&key)) { - if (key == "is_training") { - reader->Read(&is_training); - } else { - CODEGEN_CONFIG_PARSE - } + CODEGEN_CONFIG_PARSE } } }; diff --git a/tests/python/contrib/test_msc/test_manager.py b/tests/python/contrib/test_msc/test_manager.py index 04379af89a20..c07c59784832 100644 --- a/tests/python/contrib/test_msc/test_manager.py +++ b/tests/python/contrib/test_msc/test_manager.py @@ -38,11 +38,11 @@ def _get_config(model_type, compile_type, inputs, outputs, atol=1e-1, rtol=1e-1) path = "test_manager_{}_{}".format(model_type, compile_type) return { "workspace": msc_utils.msc_dir(path), - "verbose": "critical", + "verbose": "debug:1", "model_type": model_type, "inputs": inputs, "outputs": outputs, - "dataset": {"loader": "from_random", "max_iter": 5}, + "dataset": {"prepare": {"loader": "from_random", "max_iter": 5}}, "prepare": {"profile": {"benchmark": {"repeat": 10}}}, "baseline": { "run_type": model_type, @@ -55,7 +55,7 @@ def _get_config(model_type, compile_type, inputs, outputs, atol=1e-1, rtol=1e-1) } -def _get_torch_model(name, is_training=False): +def _get_torch_model(name, training=False): """Get model from torch vision""" # pylint: disable=import-outside-toplevel @@ -63,7 +63,7 @@ def _get_torch_model(name, is_training=False): import torchvision model = getattr(torchvision.models, name)() - if is_training: + if training: model = model.train() else: model = model.eval() @@ -111,8 +111,8 @@ def _check_manager(manager, expected_info): raise Exception("{}\nReport:{}".format(err, json.dumps(manager.report, indent=2))) -def _test_from_torch(compile_type, expected_info, is_training=False, atol=1e-1, rtol=1e-1): - torch_model = _get_torch_model("resnet50", is_training) +def _test_from_torch(compile_type, expected_info, training=False, atol=1e-1, rtol=1e-1): + torch_model = _get_torch_model("resnet50", training) if torch_model: if torch.cuda.is_available(): torch_model = torch_model.to(torch.device("cuda:0")) @@ -168,7 +168,7 @@ def test_tvm_manager(): "msc.linear_bias": 1, }, } - _test_from_torch(MSCFramework.TVM, model_info, is_training=True) + _test_from_torch(MSCFramework.TVM, model_info, training=False) model_info = { "inputs": [ @@ -222,7 +222,7 @@ def test_torch_manager(): "msc.linear_bias": 1, }, } - _test_from_torch(MSCFramework.TORCH, model_info, is_training=False) + _test_from_torch(MSCFramework.TORCH, model_info, training=False) def test_tensorflow_manager(): @@ -269,7 +269,7 @@ def test_tensorrt_manager(): "outputs": [{"name": "output", "shape": [1, 1000], "dtype": "float32", "layout": ""}], "nodes": {"total": 2, "input": 1, "msc_tensorrt": 1}, } - _test_from_torch(MSCFramework.TENSORRT, model_info, is_training=False) + _test_from_torch(MSCFramework.TENSORRT, model_info, training=False) if __name__ == "__main__": diff --git a/tests/python/contrib/test_msc/test_runner.py b/tests/python/contrib/test_msc/test_runner.py index e3d5bcf24503..3c88c8706a80 100644 --- a/tests/python/contrib/test_msc/test_runner.py +++ b/tests/python/contrib/test_msc/test_runner.py @@ -39,7 +39,7 @@ ) -def _get_torch_model(name, is_training=False): +def _get_torch_model(name, training=False): """Get model from torch vision""" # pylint: disable=import-outside-toplevel @@ -47,7 +47,7 @@ def _get_torch_model(name, is_training=False): import torchvision model = getattr(torchvision.models, name)() - if is_training: + if training: model = model.train() else: model = model.eval() @@ -78,10 +78,10 @@ def _get_tf_graph(): return None, None -def _test_from_torch(runner_cls, device, is_training=False, atol=1e-1, rtol=1e-1): +def _test_from_torch(runner_cls, device, training=False, atol=1e-1, rtol=1e-1): """Test runner from torch model""" - torch_model = _get_torch_model("resnet50", is_training) + torch_model = _get_torch_model("resnet50", training) if torch_model: path = "test_runner_torch_{}_{}".format(runner_cls.__name__, device) workspace = msc_utils.set_workspace(msc_utils.msc_dir(path)) @@ -94,7 +94,7 @@ def _test_from_torch(runner_cls, device, is_training=False, atol=1e-1, rtol=1e-1 with torch.no_grad(): golden = torch_model(*torch_datas) mod = from_fx(graph_model, input_info) - runner = runner_cls(mod, device=device, is_training=is_training) + runner = runner_cls(mod, device=device, training=training) runner.build() outputs = runner.run(datas, ret_type="list") golden = [msc_utils.cast_array(golden)] @@ -106,27 +106,31 @@ def _test_from_torch(runner_cls, device, is_training=False, atol=1e-1, rtol=1e-1 def test_tvm_runner_cpu(): """Test runner for tvm on cpu""" - _test_from_torch(TVMRunner, "cpu", is_training=True) + for training in [True, False]: + _test_from_torch(TVMRunner, "cpu", training=training) @tvm.testing.requires_cuda def test_tvm_runner_cuda(): """Test runner for tvm on cuda""" - _test_from_torch(TVMRunner, "cuda", is_training=True) + for training in [True, False]: + _test_from_torch(TVMRunner, "cuda", training=training) def test_torch_runner_cpu(): """Test runner for torch on cpu""" - _test_from_torch(TorchRunner, "cpu") + for training in [True, False]: + _test_from_torch(TorchRunner, "cpu", training=training) @tvm.testing.requires_cuda def test_torch_runner_cuda(): """Test runner for torch on cuda""" - _test_from_torch(TorchRunner, "cuda", atol=1e-1, rtol=1e-1) + for training in [True, False]: + _test_from_torch(TorchRunner, "cuda", training=training, atol=1e-1, rtol=1e-1) @requires_tensorrt diff --git a/tests/python/contrib/test_msc/test_tools.py b/tests/python/contrib/test_msc/test_tools.py index 8fa9e5cf10cc..7161b4b42f40 100644 --- a/tests/python/contrib/test_msc/test_tools.py +++ b/tests/python/contrib/test_msc/test_tools.py @@ -52,7 +52,7 @@ def _get_config( "model_type": model_type, "inputs": inputs, "outputs": outputs, - "dataset": {"loader": "from_random", "max_iter": 5}, + "dataset": {"prepare": {"loader": "from_random", "max_iter": 5}}, "prepare": {"profile": {"benchmark": {"repeat": 10}}}, "baseline": { "run_type": model_type, @@ -145,7 +145,7 @@ def get_tool_config(tool_type, use_distill=False): return {tool_type: config} -def _get_torch_model(name, is_training=False): +def _get_torch_model(name, training=False): """Get model from torch vision""" # pylint: disable=import-outside-toplevel @@ -153,7 +153,7 @@ def _get_torch_model(name, is_training=False): import torchvision model = getattr(torchvision.models, name)() - if is_training: + if training: model = model.train() else: model = model.eval() @@ -183,12 +183,12 @@ def _test_from_torch( compile_type, tools_config, expected_info, - is_training=False, + training=False, atol=1e-1, rtol=1e-1, optimize_type=None, ): - torch_model = _get_torch_model("resnet50", is_training) + torch_model = _get_torch_model("resnet50", training) if torch_model: if torch.cuda.is_available(): torch_model = torch_model.to(torch.device("cuda:0")) @@ -247,7 +247,7 @@ def test_tvm_tool(tool_type): tool_config = get_tool_config(tool_type) _test_from_torch( - MSCFramework.TVM, tool_config, get_model_info(MSCFramework.TVM), is_training=True + MSCFramework.TVM, tool_config, get_model_info(MSCFramework.TVM), training=False ) @@ -258,7 +258,7 @@ def test_tvm_distill(tool_type): tool_config = get_tool_config(tool_type, use_distill=True) _test_from_torch( - MSCFramework.TVM, tool_config, get_model_info(MSCFramework.TVM), is_training=True + MSCFramework.TVM, tool_config, get_model_info(MSCFramework.TVM), training=False ) @@ -280,7 +280,7 @@ def test_tensorrt_tool(tool_type): MSCFramework.TENSORRT, tool_config, get_model_info(MSCFramework.TENSORRT), - is_training=False, + training=False, atol=1e-1, rtol=1e-1, optimize_type=optimize_type, @@ -294,7 +294,7 @@ def test_tensorrt_distill(tool_type): tool_config = get_tool_config(tool_type, use_distill=True) _test_from_torch( - MSCFramework.TENSORRT, tool_config, get_model_info(MSCFramework.TENSORRT), is_training=False + MSCFramework.TENSORRT, tool_config, get_model_info(MSCFramework.TENSORRT), training=False ) diff --git a/tests/python/contrib/test_msc/test_transform.py b/tests/python/contrib/test_msc/test_transform.py index 37a3ad3fcd7e..ccc2723a24ca 100644 --- a/tests/python/contrib/test_msc/test_transform.py +++ b/tests/python/contrib/test_msc/test_transform.py @@ -25,8 +25,8 @@ from tvm.relay.expr_functor import ExprVisitor from tvm.relay.build_module import bind_params_by_name -from tvm.contrib.msc.core import _ffi_api from tvm.contrib.msc.core import transform as msc_transform +from tvm.contrib.msc.core import utils as msc_utils def test_relax_layout(): @@ -56,14 +56,12 @@ def check(self, expr): def visit_var_binding_(self, binding) -> None: super().visit_var_binding_(binding) - layout = _ffi_api.SpanGetAttr(binding.value.span, "layout") - if not layout: + if not msc_utils.get_expr_layout(binding.value): self._missing_exprs.append(binding.value) def visit_constant_(self, op) -> None: super().visit_constant_(op) - layout = _ffi_api.SpanGetAttr(op.span, "layout") - if not layout: + if not msc_utils.get_expr_layout(op): self._missing_exprs.append(op) torch_model = torchvision.models.resnet50() @@ -90,14 +88,12 @@ def check(self, expr): def visit_constant(self, expr): super().visit_constant(expr) - name = _ffi_api.SpanGetAttr(expr.span, "name") - if not name: + if not msc_utils.get_expr_name(expr): self._missing_exprs.append(expr) def visit_call(self, expr): super().visit_call(expr) - name = _ffi_api.SpanGetAttr(expr.span, "name") - if not name: + if not msc_utils.get_expr_name(expr): self._missing_exprs.append(expr) mod, params = testing.resnet.get_workload(num_layers=50, batch_size=1, dtype="float32") @@ -133,14 +129,12 @@ def check(self, expr): def visit_var_binding_(self, binding) -> None: super().visit_var_binding_(binding) - name = _ffi_api.SpanGetAttr(binding.value.span, "name") - if not name: + if not msc_utils.get_expr_name(binding.value): self._missing_exprs.append(binding.value) def visit_constant_(self, op) -> None: super().visit_constant_(op) - name = _ffi_api.SpanGetAttr(op.span, "name") - if not name: + if not msc_utils.get_expr_name(op): self._missing_exprs.append(op) torch_model = torchvision.models.resnet50() diff --git a/tests/python/contrib/test_msc/test_translate_tensorrt.py b/tests/python/contrib/test_msc/test_translate_tensorrt.py index 262cf40adc7f..81104e6fe0f2 100644 --- a/tests/python/contrib/test_msc/test_translate_tensorrt.py +++ b/tests/python/contrib/test_msc/test_translate_tensorrt.py @@ -69,8 +69,8 @@ def check(self, expr): def visit_function_(self, op: tvm.relax.Function) -> None: if "Composite" in op.attrs: - assert "unique_name" in op.attrs, "Can not find unique_name for func " + str(op) - name = str(op.attrs["unique_name"]) + assert "Unique" in op.attrs, "Can not find unique_name for func " + str(op) + name = str(op.attrs["Unique"]) assert name not in self._recorded_names, "Name {} is already in use".format(name) self._recorded_names.add(name) super().visit_function_(op) @@ -83,7 +83,7 @@ def _is_target_func(func): for _, func in mod.functions.items(): if not _is_target_func(func): continue - assert "byoc_name" in func.attrs, "Can not find byoc_name from function attributes" + assert "Unique" in func.attrs, "Can not find Unique from function attributes" NameChecker().check(func) @@ -100,13 +100,13 @@ def verify_model(torch_model, input_info, allow_incomplete=False): golden = [golden] golden = [g.detach().cpu().numpy() for g in golden] # partition module for tensorrt - mod, graph_infos = translate.partition_for_tensorrt( + mod, graphs, weights = translate.partition_for_tensorrt( mod, trans_config={"allow_incomplete": allow_incomplete} ) check_names(mod) output_folder = msc_utils.msc_dir() # tranalte to tensorrt - mod = codegen.to_tensorrt(mod, graph_infos, output_folder=output_folder) + mod = codegen.to_tensorrt(mod, graphs, weights, output_folder=output_folder) tvm_datas = [tvm.nd.array(i, device=tvm.cuda()) for i in datas] results = build_and_run(mod, tvm_datas) for gol, res in zip(golden, results): From 4da23df380f1b909f2826cf0c1ecb0e7ae9062c6 Mon Sep 17 00:00:00 2001 From: Archermmt Date: Thu, 25 Jan 2024 21:07:48 +0800 Subject: [PATCH 2/2] fix bug in distiller --- .../contrib/msc/framework/torch/tools/distill/distiller.py | 4 ++-- tests/python/contrib/test_msc/test_manager.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/contrib/msc/framework/torch/tools/distill/distiller.py b/python/tvm/contrib/msc/framework/torch/tools/distill/distiller.py index b2fa414aca63..ee5c895603e4 100644 --- a/python/tvm/contrib/msc/framework/torch/tools/distill/distiller.py +++ b/python/tvm/contrib/msc/framework/torch/tools/distill/distiller.py @@ -79,8 +79,8 @@ def build_model(self, teacher: Any, student: Any) -> Any: raise NotImplementedError("optimizer {} is not supported".format(optimizer)) # Get loss function - loss_strategy = self._strategys.get("loss.all") - assert loss_strategy, "Can not find loss.all in strategys" + loss_strategy = self._strategys.get("loss.output") + assert loss_strategy, "Can not find loss.output in strategys" def get_loss(teacher_outputs, student_outputs): return loss_strategy(self, teacher_outputs, student_outputs) diff --git a/tests/python/contrib/test_msc/test_manager.py b/tests/python/contrib/test_msc/test_manager.py index c07c59784832..bcd12b36b5a3 100644 --- a/tests/python/contrib/test_msc/test_manager.py +++ b/tests/python/contrib/test_msc/test_manager.py @@ -38,7 +38,7 @@ def _get_config(model_type, compile_type, inputs, outputs, atol=1e-1, rtol=1e-1) path = "test_manager_{}_{}".format(model_type, compile_type) return { "workspace": msc_utils.msc_dir(path), - "verbose": "debug:1", + "verbose": "critical", "model_type": model_type, "inputs": inputs, "outputs": outputs,