Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 16 additions & 12 deletions python/tvm/contrib/msc/core/frontend/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,12 @@ def _to_data(ref_t, data):
return data

weights = {t.name: _to_data(t, d) for t, d in t_weights.items() if graph.has_tensor(t.name)}
return weights
# sort the weights by graph weights
graph_weights = {}
for weight in graph.get_weights():
assert weight.name in weights, "Missing weight " + str(weight)
graph_weights[weight.name] = weights[weight.name]
return graph_weights


def from_relax(
Expand Down Expand Up @@ -115,13 +120,10 @@ def from_relax(
patterns = get_patterns_with_prefix("msc.")
passes = [
msc_transform.SetExprName(),
msc_transform.SetExprLayout(trans_config.get("allow_layout_missing", True)),
tvm.relax.transform.FuseOpsByPattern(
patterns, bind_constants=False, annotate_codegen=False
),
msc_transform.SetExprName(entry_name=entry, target=trans_config.get("target", "")),
msc_transform.SetExprLayout(
trans_config.get("allow_layout_missing", True), entry_name=entry
),
]
mod = tvm.transform.Sequential(passes)(mod)
graph = _ffi_api.BuildFromRelax(mod, entry, msc_utils.dump_dict(build_config))
Expand Down Expand Up @@ -309,13 +311,12 @@ def _partition_mod(mod, as_msc=True):
patterns = get_patterns_with_prefix(target)
passes = [
msc_transform.SetExprName(),
msc_transform.SetExprLayout(trans_config.get("allow_layout_missing", True)),
tvm.relax.transform.FuseOpsByPattern(patterns, bind_constants=not as_msc),
msc_transform.BindShape(),
msc_transform.InlineParams(),
msc_transform.FuseTuple(target),
tvm.relax.transform.MergeCompositeFunctions(),
msc_transform.SetBYOCAttrs(target),
msc_transform.SetExprName(target=target),
msc_transform.SetExprLayout(trans_config.get("allow_layout_missing", True)),
]
return tvm.transform.Sequential(passes)(mod)

Expand All @@ -331,9 +332,12 @@ def _is_target_func(func):
assert len(func_names) == 1, "More than 1 target func is found: " + str(msc_mod)
BYOCChecker().check(func_names, msc_mod[entry])

graphs_info, all_weights = [], _ffi_api.GetRelaxWeights(msc_mod, entry)
ref_weights = _ffi_api.GetRelaxWeights(msc_mod, entry)
graphs, weights = [], {}
for name in func_names:
build_config.update({"graph_name": msc_mod[name].attrs["byoc_name"], "byoc_entry": name})
graph_name = msc_mod[name].attrs[_ffi_api.ToAttrKey("unique")]
build_config.update({"graph_name": graph_name, "byoc_entry": name})
graph = _ffi_api.BuildFromRelax(msc_mod, entry, msc_utils.dump_dict(build_config))
graphs_info.append((graph, normalize_weights(all_weights, graph)))
return _partition_mod(mod, False), graphs_info
graphs.append(graph)
weights.update(normalize_weights(ref_weights, graph))
return _partition_mod(mod, False), graphs, weights
13 changes: 13 additions & 0 deletions python/tvm/contrib/msc/core/ir/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,19 @@ def get_nodes(self) -> Iterable[MSCJoint]:
for n in self.node_names:
yield self.find_node(n)

def get_weights(self) -> Iterable[MSCTensor]:
"""Get all the weights in the graph.

Returns
-------
weights: generator<MSCTensor>
The generator of weights.
"""

for node in self.get_nodes():
for weight in node.get_weights().values():
yield weight

def input_at(self, idx: int) -> MSCTensor:
"""Get input at idx.

Expand Down
196 changes: 196 additions & 0 deletions python/tvm/contrib/msc/core/runtime/hook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument, arguments-differ
"""tvm.contrib.msc.core.runtime.hook"""

from typing import Any, Callable, Dict, List, Tuple, Union

import tvm
from tvm.contrib.msc.core.ir import MSCGraph
from tvm.contrib.msc.core import utils as msc_utils


class RunnerHook(object):
    """Base class for runner hooks.

    A hook bundles a config dict with an operation on a runner. Subclasses
    implement `_apply`; `apply` merges the stored config into the call's
    keyword arguments (explicit kwargs take precedence) before delegating.

    Parameters
    ----------
    config: dict
        The config of the func.
    """

    def __init__(self, config: dict):
        self._config = config

    def __str__(self):
        return "{}({})".format(self.name(), self._config)

    def apply(self, runner: object, *args, **kwargs) -> Any:
        """Apply the hook

        Parameters
        ----------
        runner:
            The runner context.
        args: list<Any>
            The arguments for run method.
        kwargs: dict<Any>
            The key word arguments for run method.

        Returns
        -------
        result:
            The results.
        """

        # Config entries act as defaults: explicitly passed kwargs win.
        kwargs.update({k: v for k, v in self._config.items() if k not in kwargs})
        return self._apply(runner, *args, **kwargs)

    def _apply(self, runner: object, *args, **kwargs) -> Any:
        """Apply the hook (subclass responsibility)

        Parameters
        ----------
        runner:
            The runner context.
        args: list<Any>
            The arguments for run method.
        kwargs: dict<Any>
            The key word arguments for run method.

        Returns
        -------
        result:
            The results.
        """

        # Fixed message: the abstract method is `_apply`, not `default_func`.
        raise NotImplementedError("_apply is not supported in " + str(self.__class__))

    @classmethod
    def name(cls):
        """Return the registry name of this hook class."""
        return "base"


class CustomizedHook(RunnerHook):
    """Hook for customized func

    Parameters
    ----------
    func: callable/str
        The function, or a string reference resolvable by the utils loader.
    config: dict
        The config of the func.
    """

    def __init__(self, func: Union[str, Callable], config: dict):
        super(CustomizedHook, self).__init__(config)
        # Accept either a callable or a string reference; normalize to a callable.
        self._func = msc_utils.load_callable(func)

    def __str__(self):
        return "{} {}({})".format(self.name(), self._func, self._config)

    def _apply(self, runner: object, *args, **kwargs) -> Any:
        """Apply the hook

        Parameters
        ----------
        runner:
            The runner context.
        args: list<Any>
            The arguments for run method.
        kwargs: dict<Any>
            The key word arguments for run method.

        Returns
        -------
        result:
            The results.
        """

        return self._func(runner, *args, **kwargs)

    @classmethod
    def name(cls):
        """Return the registry name of this hook class."""
        return "customized"


class UpdateWeightsHook(RunnerHook):
    """Hook that overwrites translated weights with values saved on disk."""

    def _apply(
        self,
        runner: object,
        graphs: List[MSCGraph],
        weights: Dict[str, tvm.nd.array],
        weights_path: str,
    ) -> Tuple[List[MSCGraph], Dict[str, tvm.nd.array]]:
        """Load a saved param dict and update the matching weights.

        Parameters
        ----------
        runner:
            The runner context.
        graphs: list<MSCGraph>
            The translated graphs
        weights: dict<str, tvm.nd.array>
            The translated weights.
        weights_path: str
            The weights path.

        Returns
        -------
        graphs: list<MSCGraph>
            The updated graphs
        weights: dict<str, tvm.nd.array>
            The updated weights.
        """

        with open(weights_path, "rb") as f:
            saved = tvm.runtime.load_param_dict(f.read())
        # Only overwrite weights already known to the graphs; any extra
        # entries in the saved dict are ignored.
        for w_name, w_data in saved.items():
            if w_name in weights:
                weights[w_name] = w_data
        return graphs, weights

    @classmethod
    def name(cls):
        """Return the registry name of this hook class."""
        return "update_weights"


def load_runner_hook(config: dict) -> Any:
    """Load a registered hook

    Parameters
    ----------
    config: dict
        The config of the func. Must contain a "hook" entry; all other
        entries become the hook's own config.

    Returns
    -------
    hook: RunnerHook
        The hook
    """

    assert "hook" in config, "hook should be given to load hook"
    hook_ref = config["hook"]
    hook_config = {k: v for k, v in config.items() if k != "hook"}
    # A string ref may name a registered hook class; anything else (or an
    # unregistered string) is wrapped as a customized hook.
    if isinstance(hook_ref, str):
        registered = msc_utils.get_registered_runner_hook(hook_ref)
        if registered:
            return registered(hook_config)
    return CustomizedHook(hook_ref, hook_config)


# Register the built-in hook at import time so load_runner_hook can resolve
# it by its name() key ("update_weights").
msc_utils.register_runner_hook(UpdateWeightsHook)
Loading