From 93228649340bcacb8056d47d8f6f8a78a9805ae4 Mon Sep 17 00:00:00 2001 From: Haozheng Fan Date: Mon, 28 Oct 2019 14:02:49 +0800 Subject: [PATCH] Infra for tvm op runtime dispatch (#16100) * infra for dispatch tvm op * fix ci and sanity error * disable shape with hint and fix coding style * rename to avoid conflict with original dot * update tvm and use soft link * config file moves to lib/ when using Makefile * add tvmop.conf to ci * fix rebase * fix rebase * use inspect to detect dispatchable func --- CMakeLists.txt | 2 +- Makefile | 4 +- benchmark/python/tvmop/benchmark_tvmop.py | 57 ++++++ ci/jenkins/Jenkins_steps.groovy | 20 +- contrib/tvmop/compile.py | 13 ++ contrib/tvmop/core/__init__.py | 2 +- contrib/tvmop/core/multiarray.py | 53 ++++++ contrib/tvmop/opdef.py | 42 ++++- contrib/tvmop/space.py | 212 ++++++++++++++++++++++ include/mxnet/c_api.h | 26 +++ python/mxnet/__init__.py | 2 + python/mxnet/_ctypes/space.py | 96 ++++++++++ python/mxnet/base.py | 6 - python/mxnet/libinfo.py | 31 ++++ python/mxnet/space.py | 1 + python/mxnet/tvmop.py | 37 ++++ src/c_api/c_api.cc | 23 +++ src/operator/contrib/tvmop/dot.cc | 153 ++++++++++++++++ src/operator/tvmop/op_module.cc | 12 +- src/operator/tvmop/op_module.h | 74 ++++++++ 20 files changed, 841 insertions(+), 25 deletions(-) create mode 100644 benchmark/python/tvmop/benchmark_tvmop.py create mode 100644 contrib/tvmop/core/multiarray.py create mode 100644 contrib/tvmop/space.py create mode 100644 python/mxnet/_ctypes/space.py create mode 120000 python/mxnet/space.py create mode 100644 python/mxnet/tvmop.py create mode 100644 src/operator/contrib/tvmop/dot.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 99ad34bed3dc..a06aa9dba485 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -763,7 +763,7 @@ if(USE_TVM_OP) endif() endif() - set(TVM_OP_COMPILE_OPTIONS "-o${CMAKE_CURRENT_BINARY_DIR}/libtvmop.so") + set(TVM_OP_COMPILE_OPTIONS "-o${CMAKE_CURRENT_BINARY_DIR}/libtvmop.so" "--config" "${CMAKE_CURRENT_BINARY_DIR}/tvmop.conf") if(CUDA_ARCH_BIN) set(TVM_OP_COMPILE_OPTIONS "${TVM_OP_COMPILE_OPTIONS}" "--cuda-arch" "${CUDA_ARCH_BIN}") endif() diff --git a/Makefile b/Makefile index be0d34051873..63a978d01d8a 100644 --- a/Makefile +++ b/Makefile @@ -622,6 +622,7 @@ DMLCCORE: lib/libtvm_runtime.so: echo "Compile TVM" + @mkdir -p $(@D) [ -e $(LLVM_PATH)/bin/llvm-config ] || sh $(ROOTDIR)/contrib/tvmop/prepare_tvm.sh; \ cd $(TVM_PATH)/build; \ cmake -DUSE_LLVM="$(LLVM_PATH)/bin/llvm-config" \ @@ -632,12 +633,13 @@ lib/libtvm_runtime.so: ls $(ROOTDIR)/lib; \ cd $(ROOTDIR) -TVM_OP_COMPILE_OPTIONS = -o $(ROOTDIR)/lib/libtvmop.so +TVM_OP_COMPILE_OPTIONS = -o $(ROOTDIR)/lib/libtvmop.so --config $(ROOTDIR)/lib/tvmop.conf ifneq ($(CUDA_ARCH),) TVM_OP_COMPILE_OPTIONS += --cuda-arch "$(CUDA_ARCH)" endif lib/libtvmop.so: lib/libtvm_runtime.so $(wildcard contrib/tvmop/*/*.py contrib/tvmop/*.py) echo "Compile TVM operators" + @mkdir -p $(@D) PYTHONPATH=$(TVM_PATH)/python:$(TVM_PATH)/topi/python:$(ROOTDIR)/contrib \ LD_LIBRARY_PATH=$(ROOTDIR)/lib \ python3 $(ROOTDIR)/contrib/tvmop/compile.py $(TVM_OP_COMPILE_OPTIONS) diff --git a/benchmark/python/tvmop/benchmark_tvmop.py b/benchmark/python/tvmop/benchmark_tvmop.py new file mode 100644 index 000000000000..14ad49bb09ef --- /dev/null +++ b/benchmark/python/tvmop/benchmark_tvmop.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import time +import mxnet as mx +import numpy as _np +from mxnet import np, npx + +def measure_cost(repeat, func_name, *args, **kwargs): + """Measure time cost of running a function + """ + mx.nd.waitall() + start = time.time() + for _ in range(repeat): + func_name(*args, **kwargs) + mx.nd.waitall() + end = time.time() + diff = end - start + return diff / repeat + + +def test_tvm_dot(): + # benchmark + for i in list(range(1000, 1100, 4)): + m = i + k = i + n = i + print("{} * {} X {} * {}".format(m, k, k, n)) + a = mx.nd.random.uniform(shape=(m, k), dtype='float32') + b = mx.nd.random.uniform(shape=(k, n), dtype='float32') + cost = measure_cost(2, mx.nd.contrib.tvm_dot, a, b) + print("dispatch cost: {} ms".format(cost * 1000)) + a = mx.nd.random.uniform(shape=(m, k), dtype='float32') + b = mx.nd.random.uniform(shape=(k, n), dtype='float32') + cost = measure_cost(2, mx.nd.contrib.tvm_dot_fallback, a, b) + print("fallback cost: {} ms".format(cost * 1000)) + a = mx.nd.random.uniform(shape=(m, k), dtype='float32') + b = mx.nd.random.uniform(shape=(k, n), dtype='float32') + cost = measure_cost(2, mx.nd.dot, a, b) + print("dot cost: {} ms".format(cost * 1000)) + +if __name__ == "__main__": + test_tvm_dot() diff --git a/ci/jenkins/Jenkins_steps.groovy b/ci/jenkins/Jenkins_steps.groovy index 7f6200a28016..d7c2b9679ca3 100644 --- a/ci/jenkins/Jenkins_steps.groovy +++ b/ci/jenkins/Jenkins_steps.groovy @@ -23,24 +23,24 @@ utils = load('ci/Jenkinsfile_utils.groovy') // mxnet libraries -mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' -mx_lib_cython = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' +mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' +mx_lib_cython = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' // Python wheels mx_pip = 'build/*.whl' // mxnet cmake libraries, in cmake builds we do not produce a libnvvm static library by default. 
-mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so' +mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so' mx_cmake_lib_no_tvm_op = 'build/libmxnet.so, build/libmxnet.a, build/libsample_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so' -mx_cmake_lib_cython = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' +mx_cmake_lib_cython = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' // mxnet cmake libraries, in cmake builds we do not produce a libnvvm static library by default. -mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/libsample_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests' -mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, build/3rdparty/mkldnn/src/libmkldnn.so.0' -mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, libsample_lib.so, lib/libiomp5.so, lib/libmkldnn.so.0, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' -mx_tensorrt_lib = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so' -mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' +mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/libsample_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests' +mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, build/3rdparty/mkldnn/src/libmkldnn.so.0' +mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, lib/libiomp5.so, lib/libmkldnn.so.0, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' +mx_tensorrt_lib = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so' 
+mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' mx_lib_cpp_examples_no_tvm_op = 'lib/libmxnet.so, lib/libmxnet.a, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so' -mx_lib_cpp_examples_cpu = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/cpp-package/example/*' +mx_lib_cpp_examples_cpu = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/cpp-package/example/*' // Python unittest for CPU // Python 2 diff --git a/contrib/tvmop/compile.py b/contrib/tvmop/compile.py index 3c0efdd6b806..b0254218077a 100644 --- a/contrib/tvmop/compile.py +++ b/contrib/tvmop/compile.py @@ -18,12 +18,15 @@ # coding: utf-8 """TVM Operator compile entry point""" import tvm +from tvm import autotvm import os import argparse import re +import json import logging from tvmop.opdef import __OP_DEF__ +from tvmop.space import ConfigSpaces, ConfigSpace from tvm.autotvm.measure.measure_methods import set_cuda_target_arch logging.basicConfig(level=logging.INFO) @@ -70,6 +73,8 @@ def get_cuda_arch(arch): help="Target path which stores compiled library") parser.add_argument('--cuda-arch', type=str, default=None, dest='cuda_arch', help='The cuda arch for compiling kernels for') + parser.add_argument("--config", action="store", required=True, dest="config_path", + help="Path which stores the config file") arguments = parser.parse_args() func_list_llvm = [] @@ -78,6 +83,7 @@ def get_cuda_arch(arch): # TODO: attach instruction features to the library, e.g., avx-512, etc. for operator_def in __OP_DEF__: for sch, args, name in operator_def.invoke_all(): + name = operator_def.get_op_name(name, args) if tvm.module.enabled(get_target(operator_def.target)): func_list = func_list_llvm if operator_def.target == "cpu" else func_list_cuda func_lower = tvm.lower(sch, args, @@ -96,3 +102,10 @@ def get_cuda_arch(arch): set_cuda_target_arch(cuda_arch) func_binary = tvm.build(lowered_funcs, name="tvmop") func_binary.export_library(arguments.target_path) + + config_spaces = ConfigSpaces() + for operator_def in __OP_DEF__: + for config_space, name in operator_def.get_config_spaces(): + config_spaces[name] = ConfigSpace.from_tvm(config_space) + with open(arguments.config_path, "w") as f: + json.dump(config_spaces.to_json_dict(), f) diff --git a/contrib/tvmop/core/__init__.py b/contrib/tvmop/core/__init__.py index 841d4ad9db27..e309f237df05 100644 --- a/contrib/tvmop/core/__init__.py +++ b/contrib/tvmop/core/__init__.py @@ -15,4 +15,4 @@ # specific language governing permissions and limitations # under the License. -from . import umath, fromnumeric +from . import umath, fromnumeric, multiarray diff --git a/contrib/tvmop/core/multiarray.py b/contrib/tvmop/core/multiarray.py new file mode 100644 index 000000000000..c8eed5b45368 --- /dev/null +++ b/contrib/tvmop/core/multiarray.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +import tvm +from tvm import autotvm +from .. import defop, AllTypes +from .. import assign_by_req, reduce_axes + +def compute_dot(A, B): + M = A.shape[0] + K = A.shape[1] + N = B.shape[1] + k = tvm.reduce_axis((0, K), 'k') + C = tvm.compute((M, N), + lambda x, y: tvm.sum(A[x, k] * B[k, y], axis=k), + name='C') + return C + + +@defop(name="dot", target="cpu", dtype=AllTypes) +def dot(dtype, fallback): + cfg = autotvm.get_config() + cfg.define_knob("bn", [64] if fallback else [64, 32]) + cfg.define_knob("factor", [4] if fallback else [4]) + M = tvm.var("M") + K = tvm.var("K") + N = tvm.var("N") + A = tvm.placeholder((M, K), name='A', dtype=dtype) + B = tvm.placeholder((K, N), name='B', dtype=dtype) + C = compute_dot(A, B) + s = tvm.create_schedule(C.op) + # Blocking by loop tiling + xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], cfg["bn"].val, cfg["bn"].val) + k, = s[C].op.reduce_axis + ko, ki = s[C].split(k, factor=cfg["factor"].val) + # Hoist reduction domain outside the blocking loop + s[C].reorder(xo, yo, ko, ki, xi, yi) + return s, [A, B, C] diff --git a/contrib/tvmop/opdef.py b/contrib/tvmop/opdef.py index 39c42f4dd465..1e0f34669b10 100644 --- a/contrib/tvmop/opdef.py +++ b/contrib/tvmop/opdef.py @@ -17,6 +17,8 @@ # coding: utf-8 import tvm +import inspect +from tvm import autotvm from itertools import product __OP_DEF__ = [] @@ -68,6 +70,7 @@ def __init__(self, func, name, target, auto_broadcast, **kwargs): self.name = name self.target = target self.auto_broadcast = auto_broadcast + self.dispatchable = 'fallback' in inspect.signature(self.func).parameters def __call__(self, *args, **kwargs): return self.func(*args, **kwargs) @@ -75,12 +78,41 @@ def __call__(self, *args, **kwargs): def invoke_all(self): for each_kwargs in self.arg_combination: if self.attrs_valid(**each_kwargs): - sch, args = self.func(**each_kwargs) name = self.name \ - + ''.join(["{}_{}".format(key, each_kwargs[key]) for key in self.attrs]) \ - + ''.join(["%s_%d" % (arg.dtype, len(arg.shape)) - for arg in args if hasattr(arg, 'shape')]) - yield sch, args, name + + ''.join(["{}_{}".format(key, each_kwargs[key]) for key in self.attrs]) + if self.dispatchable is False: + sch, args = self.func(**each_kwargs) + yield sch, args, name + else: + # register dispatch schedules + config_space = autotvm.ConfigSpace() + with autotvm.task.ApplyConfig(config_space): + sch, args = self.func(fallback=False, **each_kwargs) + for i in range(len(config_space)): + config_entity = config_space.get(i) + with autotvm.task.ApplyConfig(config_entity): + sch, args = self.func(fallback=False, **each_kwargs) + subname = name + "index_" + str(i) + yield sch, args, subname + # register fallback schedule + config_space = autotvm.ConfigSpace() + with autotvm.task.ApplyConfig(config_space): + sch, args = self.func(fallback=True, **each_kwargs) + 
subname = name + "fallback" + yield sch, args, subname + + def get_op_name(self, name, args): + return name + ''.join(["%s_%d" % (arg.dtype, len(arg.shape)) for arg in args if hasattr(arg, 'shape')]) + + def get_config_spaces(self): + for each_kwargs in self.arg_combination: + if self.attrs_valid(**each_kwargs) and self.dispatchable is True: + name = self.name \ + + ''.join(["{}_{}".format(key, each_kwargs[key]) for key in self.attrs]) + config_space = autotvm.ConfigSpace() + with autotvm.task.ApplyConfig(config_space): + self.func(fallback=False, **each_kwargs) + yield config_space, name def get_binds(self, args): if self.auto_broadcast: diff --git a/contrib/tvmop/space.py b/contrib/tvmop/space.py new file mode 100644 index 000000000000..589b931bc37d --- /dev/null +++ b/contrib/tvmop/space.py @@ -0,0 +1,212 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +"""ConfigSpace API.""" +from collections import OrderedDict +import numpy as _np + +class OtherOptionSpace(object): + """The parameter space for general option""" + def __init__(self, entities): + self.entities = [OtherOptionEntity(e) for e in entities] + + @classmethod + def from_tvm(cls, x): + return cls([e.val for e in x.entities]) + + def __len__(self): + return len(self.entities) + + def __repr__(self): + return "OtherOption(%s) len=%d" % (self.entities, len(self)) + + +class OtherOptionEntity(object): + """The parameter entity for general option, with a detailed value""" + def __init__(self, val): + self.val = val + + @classmethod + def from_tvm(cls, x): + """Build a OtherOptionEntity from autotvm.OtherOptionEntity + + Parameters + ---------- + cls: class + Calling class + x: autotvm.OtherOptionEntity + The source object + + Returns + ------- + ret: OtherOptionEntity + The corresponding OtherOptionEntity object + """ + return cls(x.val) + + def __repr__(self): + return str(self.val) + + +class ConfigSpace(object): + """The configuration space of a schedule.""" + def __init__(self, space_map, _entity_map): + self.space_map = space_map + self._entity_map = _entity_map + self._length = None + + @classmethod + def from_tvm(cls, x): + """Build a ConfigSpace from autotvm.ConfigSpace + + Parameters + ---------- + cls: class + Calling class + x: autotvm.ConfigSpace + The source object + + Returns + ------- + ret: ConfigSpace + The corresponding ConfigSpace object + """ + space_map = OrderedDict([(k, OtherOptionSpace.from_tvm(v)) for k, v in x.space_map.items()]) + _entity_map = OrderedDict([(k, OtherOptionEntity.from_tvm(v)) for k, v in x._entity_map.items()]) + return cls(space_map, _entity_map) + + def __len__(self): + + if self._length is None: + self._length = int(_np.prod([len(x) for x in self.space_map.values()])) + return self._length + + def __repr__(self): + res = "ConfigSpace 
(len=%d, space_map=\n" % len(self) + for i, (name, space) in enumerate(self.space_map.items()): + res += " %2d %s: %s\n" % (i, name, space) + return res + ")" + + def to_json_dict(self): + """convert to a json serializable dictionary + + Return + ------ + ret: dict + a json serializable dictionary + """ + ret = {} + entity_map = [] + for k, v in self._entity_map.items(): + if isinstance(v, OtherOptionEntity): + entity_map.append((k, 'ot', v.val)) + else: + raise RuntimeError("Invalid entity instance: " + v) + ret['e'] = entity_map + space_map = [] + for k, v in self.space_map.items(): + entities = [e.val for e in v.entities] + space_map.append((k, 'ot', entities)) + ret['s'] = space_map + return ret + + @classmethod + def from_json_dict(cls, json_dict): + """Build a ConfigSpace from json serializable dictionary + + Parameters + ---------- + cls: class + The calling class + json_dict: dict + Json serializable dictionary. + + Returns + ------- + ret: ConfigSpace + The corresponding ConfigSpace object + """ + entity_map = OrderedDict() + for item in json_dict["e"]: + key, knob_type, knob_args = item + if knob_type == 'ot': + entity = OtherOptionEntity(knob_args) + else: + raise RuntimeError("Invalid config knob type: " + knob_type) + entity_map[str(key)] = entity + space_map = OrderedDict() + for item in json_dict["s"]: + key, knob_type, knob_args = item + if knob_type == 'ot': + space = OtherOptionSpace(knob_args) + else: + raise RuntimeError("Invalid config knob type: " + knob_type) + space_map[str(key)] = space + return cls(space_map, entity_map) + + +class ConfigSpaces(object): + """The configuration spaces of all ops.""" + def __init__(self): + self.spaces = {} + + def __setitem__(self, name, space): + self.spaces[name] = space + + def __len__(self): + return len(self.spaces) + + def __repr__(self): + res = "ConfigSpaces (len=%d, config_space=\n" % len(self) + for i, (key, val) in enumerate(self.spaces.items()): + res += " %2d %s:\n %s\n" % (i, key, val) + return res + ")" + + def to_json_dict(self): + """convert to a json serializable dictionary + + Return + ------ + ret: dict + a json serializable dictionary + """ + ret = [] + for k, v in self.spaces.items(): + ret.append((k, v.to_json_dict())) + return ret + + @classmethod + def from_json_dict(cls, json_dict): + """Build a ConfigSpaces from json serializable dictionary + + Parameters + ---------- + cls: class + The calling class + json_dict: dict + Json serializable dictionary. 
+ + Returns + ------- + ret: ConfigSpaces + The corresponding ConfigSpaces object + """ + ret = cls() + for key, val in json_dict: + ret.spaces[key] = ConfigSpace.from_json_dict(val) + return ret diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 2463a5b75cfd..bbd67f059fdf 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -525,6 +525,32 @@ MXNET_DLL int MXGetVersion(int *out); */ #if MXNET_USE_TVM_OP MXNET_DLL int MXLoadTVMOp(const char *libpath); + +struct OtherOptionEntity { + int val; +}; + +struct OtherOptionSpace { + OtherOptionEntity* entities; + int entities_size; +}; + +struct ConfigSpace { + int entity_map_size; + char** entity_map_key; + OtherOptionEntity* entity_map_val; + int space_map_size; + char** space_map_key; + OtherOptionSpace* space_map_val; +}; + +typedef struct ConfigSpaces { + int spaces_size; + char** spaces_key; + ConfigSpace* spaces_val; +} ConfigSpaces; + +MXNET_DLL int MXLoadTVMConfig(ConfigSpaces config); #endif // MXNET_USE_TVM_OP diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index 3dc80cd5a0b4..a4ba4f3b6093 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -87,6 +87,8 @@ from . import gluon +from . import tvmop + __version__ = base.__version__ # Dist kvstore module which launches a separate process when role is set to "server". diff --git a/python/mxnet/_ctypes/space.py b/python/mxnet/_ctypes/space.py new file mode 100644 index 000000000000..d1ee6e555e25 --- /dev/null +++ b/python/mxnet/_ctypes/space.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# coding: utf-8 +"""ConfigSpace ctypes API.""" +import ctypes + +from ..base import _LIB +from ..base import c_str_array, c_array +from ..base import check_call + +class COtherOptionEntity(ctypes.Structure): + """ctypes data structure for OtherOptionEntity""" + _fields_ = [("val", ctypes.c_int)] + + +class COtherOptionSpace(ctypes.Structure): + """ctypes data structure for OtherOptionSpace""" + _fields_ = [("entities", ctypes.POINTER(COtherOptionEntity)), + ("entities_size", ctypes.c_int)] + + +class CConfigSpace(ctypes.Structure): + """ctypes data structure for ConfigSpace""" + _fields_ = [("entity_map_size", ctypes.c_int), + ("entity_map_key", ctypes.POINTER(ctypes.c_char_p)), + ("entity_map_val", ctypes.POINTER(COtherOptionEntity)), + ("space_map_size", ctypes.c_int), + ("space_map_key", ctypes.POINTER(ctypes.c_char_p)), + ("space_map_val", ctypes.POINTER(COtherOptionSpace))] + + +class CConfigSpaces(ctypes.Structure): + """ctypes data structure for ConfigSpaces""" + _fields_ = [("spaces_size", ctypes.c_int), + ("spaces_key", ctypes.POINTER(ctypes.c_char_p)), + ("spaces_val", ctypes.POINTER(CConfigSpace))] + + +def c_other_option_entity(x): + """constructor for OtherOptionEntity""" + ret = COtherOptionEntity() + ret.val = x.val + return ret + + +def c_other_option_space(x): + """constructor for OtherOptionSpace""" + ret = COtherOptionSpace() + ret.entities = c_array(COtherOptionEntity, + [c_other_option_entity(e) for e in x.entities]) + ret.entities_size = len(x.entities) + return ret + + +def c_config_space(x): + """constructor for ConfigSpace""" + ret = CConfigSpace() + ret.entity_map_key = c_str_array(x._entity_map.keys()) + ret.entity_map_val = c_array(COtherOptionEntity, + [c_other_option_entity(e) for e in x._entity_map.values()]) + ret.entity_map_size = len(x._entity_map) + ret.space_map_key = c_str_array(x.space_map.keys()) + ret.space_map_val = c_array(COtherOptionSpace, + [c_other_option_space(v) for v in x.space_map.values()]) + ret.space_map_size = len(x.space_map) + return ret + + +def c_config_spaces(x): + """constructor for ConfigSpaces""" + ret = CConfigSpaces() + ret.spaces_size = len(x.spaces) + ret.spaces_key = c_str_array(x.spaces.keys()) + ret.spaces_val = c_array(CConfigSpace, [c_config_space(c) for c in x.spaces.values()]) + return ret + + +def _set_tvm_op_config(x): + """ctypes implementation of populating the config singleton""" + check_call(_LIB.MXLoadTVMConfig(c_config_spaces(x))) + return x diff --git a/python/mxnet/base.py b/python/mxnet/base.py index db1fa29ab9b4..35acba3dbe53 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -743,12 +743,6 @@ def write_all_str(module_file, module_all_list): ctypes.pythonapi.PyCapsule_GetPointer.restype = ctypes.c_void_p -from .runtime import Features -if Features().is_enabled("TVM_OP"): - _LIB_TVM_OP = libinfo.find_lib_path("libtvmop") - check_call(_LIB.MXLoadTVMOp(c_str(_LIB_TVM_OP[0]))) - - _NP_OP_PREFIX = '_np_' _NP_OP_SUBMODULE_LIST = ['_random_', '_linalg_'] diff --git a/python/mxnet/libinfo.py b/python/mxnet/libinfo.py index 8e0ae05f3378..589a69ce60cc 100644 --- a/python/mxnet/libinfo.py +++ b/python/mxnet/libinfo.py @@ -110,5 +110,36 @@ def find_include_path(): ' or ' + src_incl_path + '\n') +def find_conf_path(prefix='tvmop'): + """Find TVM op config files. + + Returns + ------- + conf_path : string + Path to the config files. 
+ """ + conf_from_env = os.environ.get('MXNET_CONF_PATH') + if conf_from_env: + if os.path.isfile(conf_from_env): + if not os.path.isabs(conf_from_env): + logging.warning("MXNET_CONF_PATH should be an absolute path, instead of: %s", + conf_from_env) + else: + return conf_from_env + else: + logging.warning("MXNET_CONF_PATH '%s' doesn't exist", conf_from_env) + + curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) + makefile_path = os.path.join(curr_path, '../../lib/') + cmake_build_path = os.path.join(curr_path, '../../build/') + candidates_path = [makefile_path, cmake_build_path] + candidates_path = [p + prefix + '.conf' for p in candidates_path] + conf_path = [p for p in candidates_path if os.path.exists(p) and os.path.isfile(p)] + if len(conf_path) == 0: + raise RuntimeError('Cannot find the TVM op config.\n' + + 'List of candidates:\n' + str('\n'.join(candidates_path))) + return conf_path + + # current version __version__ = "1.6.0" diff --git a/python/mxnet/space.py b/python/mxnet/space.py new file mode 120000 index 000000000000..18f6a91e6055 --- /dev/null +++ b/python/mxnet/space.py @@ -0,0 +1 @@ +../../contrib/tvmop/space.py \ No newline at end of file diff --git a/python/mxnet/tvmop.py b/python/mxnet/tvmop.py new file mode 100644 index 000000000000..9ec278afb7a0 --- /dev/null +++ b/python/mxnet/tvmop.py @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# coding: utf-8 +"""Init tvm ops.""" +from .runtime import Features + +if Features().is_enabled("TVM_OP"): + import json + + from ._ctypes.space import _set_tvm_op_config + from .base import check_call, _LIB, c_str + from .space import ConfigSpaces + from .libinfo import find_lib_path, find_conf_path + + _LIB_TVM_OP = find_lib_path("libtvmop") + check_call(_LIB.MXLoadTVMOp(c_str(_LIB_TVM_OP[0]))) + + # op sch config + _CONF_TVM_OP = find_conf_path("tvmop") + with open(_CONF_TVM_OP[0], "r") as f: + ret = ConfigSpaces.from_json_dict(json.load(f)) + _set_tvm_op_config(ret) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 1696bf572e4c..fb88cae2cbba 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -181,6 +181,29 @@ int MXLoadTVMOp(const char *libpath) { tvm::runtime::TVMOpModule::Get()->Load(libpath); API_END(); } + +int MXLoadTVMConfig(ConfigSpaces config) { + API_BEGIN(); + for (int k = 0; k < config.spaces_size; ++k) { + tvm::runtime::TVMOpConfig& entry = ::dmlc::Registry::Get() + ->__REGISTER_OR_GET__(std::string(config.spaces_key[k])); + const ConfigSpace& c = config.spaces_val[k]; + for (int i = 0; i < c.entity_map_size; ++i) { + entry.add_entity(std::string(c.entity_map_key[i]), c.entity_map_val[i].val); + } + for (int i = 0; i < c.space_map_size; ++i) { + std::string name = std::string(c.space_map_key[i]); + std::vector entities; + for (int j = 0; j < c.space_map_val[i].entities_size; ++j) { + int val = c.space_map_val[i].entities[j].val; + entities.push_back(val); + } + entry.add_space(name, entities); + } + } + API_END(); +} + #endif // MXNET_USE_TVM_OP int MXNDArrayCreateNone(NDArrayHandle *out) { diff --git a/src/operator/contrib/tvmop/dot.cc b/src/operator/contrib/tvmop/dot.cc new file mode 100644 index 000000000000..c6bec570a7e8 --- /dev/null +++ b/src/operator/contrib/tvmop/dot.cc @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2019 by Contributors + * \file dot.cc + * \brief + * \author Haozheng Fan + */ +#ifdef MXNET_USE_TVM_OP +#include +#include +#include +#include +#include +#include "../../tensor/elemwise_binary_broadcast_op.h" +#include "../../tvmop/op_module.h" +#include "../../tensor/elemwise_binary_op.h" + +namespace mxnet { +namespace op { + +int SplitSch(const ::tvm::runtime::TVMOpConfig& config, + const ::std::string& name, + const std::vector& size) { + const ::tvm::runtime::OtherOptionSpace& space = config.get_space(name); + int weight = config.get_weight(name); + int num_space = space.size(), num_size = size.size(); + for (int i = 0; i < num_space; ++i) { + bool flag = true; + for (int j = 0; j < num_size; ++j) { + if (size[j] % space[i].get_val() != 0) { + flag = false; + break; + } + } + if (flag) { + return i * weight; + } + } + return -1; +} + +std::string DotSch(const std::string name, + const nnvm::NodeAttrs& attrs, + const mxnet::ShapeVector& in_attrs, + const mxnet::ShapeVector& out_attrs) { + const ::tvm::runtime::TVMOpConfig& config = tvm::runtime::GetOpConfig(name); + int m = in_attrs[0][0]; + int k = in_attrs[0][1]; + int n = in_attrs[1][1]; + int idx_bn = SplitSch(config, "bn", {m, n}); + int idx_factor = SplitSch(config, "factor", {k}); + int idx = idx_bn + idx_factor; + if (idx_bn == -1 || idx_factor == -1) { + return "fallback"; + } + return "index_" + std::to_string(idx); +} + +void TVMDotForward(const nnvm::NodeAttrs& attrs, + const mxnet::OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + std::string funcname = "dot"; + std::string sch = DotSch(funcname, attrs, {inputs[0].shape_, inputs[1].shape_}, + {outputs[0].shape_}); + tvm::runtime::TVMOpModule::Get()->Call(funcname + sch, ctx, {inputs[0], inputs[1], outputs[0]}); +} + +void TVMDotFallbackForward(const nnvm::NodeAttrs& attrs, + const mxnet::OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + std::string funcname = "dot"; + std::string sch = "fallback"; + tvm::runtime::TVMOpModule::Get()->Call(funcname + sch, ctx, {inputs[0], inputs[1], outputs[0]}); +} + +bool TVMDotShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + const mxnet::TShape& a_shape = in_attrs->at(0); + const mxnet::TShape& b_shape = in_attrs->at(1); + CHECK_EQ(a_shape.ndim(), 2U); + CHECK_EQ(b_shape.ndim(), 2U); + mxnet::TShape tmp_shape(2, -1); + tmp_shape[1] = b_shape[0]; + SHAPE_ASSIGN_CHECK(*in_attrs, 0, tmp_shape); + + tmp_shape[0] = a_shape[1]; + tmp_shape[1] = -1; + SHAPE_ASSIGN_CHECK(*in_attrs, 1, tmp_shape); + + tmp_shape[0] = a_shape[0]; + tmp_shape[1] = b_shape[1]; + SHAPE_ASSIGN_CHECK(*out_attrs, 0, tmp_shape); + return shape_is_known(*in_attrs) && shape_is_known(*out_attrs); +} + +NNVM_REGISTER_OP(_contrib_tvm_dot) + .set_num_inputs(2) + .set_num_outputs(1) + .add_argument("a", "NDArray-or-Symbol", "first input") + .add_argument("b", "NDArray-or-Symbol", "second input") + .set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"a", "b"}; + }) + .set_attr("FInferShape", TVMDotShape) + .set_attr("FInferType", mxnet::op::ElemwiseType<2, 1>) + .set_attr("FCompute", TVMDotForward); + +NNVM_REGISTER_OP(_contrib_tvm_dot_fallback) + .set_num_inputs(2) + 
.set_num_outputs(1) + .add_argument("a", "NDArray-or-Symbol", "first input") + .add_argument("b", "NDArray-or-Symbol", "second input") + .set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"a", "b"}; + }) + .set_attr("FInferShape", TVMDotShape) + .set_attr("FInferType", mxnet::op::ElemwiseType<2, 1>) + .set_attr("FCompute", TVMDotFallbackForward); + +} // namespace op +} // namespace mxnet +#endif // MXNET_USE_TVM_OP diff --git a/src/operator/tvmop/op_module.cc b/src/operator/tvmop/op_module.cc index ea7f73069698..b45df5dbdd4a 100644 --- a/src/operator/tvmop/op_module.cc +++ b/src/operator/tvmop/op_module.cc @@ -31,7 +31,9 @@ #include #include "op_module.h" -using namespace tvm::runtime; +namespace dmlc { + DMLC_REGISTRY_ENABLE(::tvm::runtime::TVMOpConfig); +} // namespace dmlc namespace tvm { namespace runtime { @@ -137,6 +139,14 @@ void TVMOpModule::CallEx(const std::string &func_name, #endif } +const TVMOpConfig& GetOpConfig(const std::string& name) { + const TVMOpConfig* ret = ::dmlc::Registry::Get()->Find(name); + CHECK(ret != NULL) + << "op " << name << "does not exist."; + return *ret; +} + } // namespace runtime } // namespace tvm + #endif // MXNET_USE_TVM_OP diff --git a/src/operator/tvmop/op_module.h b/src/operator/tvmop/op_module.h index d28dd1b5b0c2..269a0aa50c11 100644 --- a/src/operator/tvmop/op_module.h +++ b/src/operator/tvmop/op_module.h @@ -32,6 +32,7 @@ #include #include #include +#include namespace tvm { namespace runtime { @@ -72,6 +73,79 @@ class TVMOpModule { std::shared_ptr module_ptr_; }; +class OtherOptionEntity { + public: + explicit OtherOptionEntity(int val): val_(val) {} + OtherOptionEntity(): val_(0) {} + inline int get_val() const { + return val_; + } + private: + int val_; +}; + +class OtherOptionSpace { + public: + explicit OtherOptionSpace(const std::vector& entities) { + int size = entities.size(); + for (int i = 0; i < size; ++i) { + this->entities_.push_back(OtherOptionEntity(entities[i])); + } + } + + OtherOptionSpace() {} + + inline OtherOptionEntity &operator[] (int idx) { + return entities_[idx]; + } + + inline const OtherOptionEntity &operator[] (int idx) const { + return entities_[idx]; + } + + inline int size() const { + return entities_.size(); + } + + private: + std::vector entities_; +}; + +class TVMOpConfig { + public: + std::string name; + + inline TVMOpConfig& add_space(const std::string& name, const std::vector& val) { + int size = val.size(); + space_map_[name] = OtherOptionSpace(val); + weight_map_[name] = weight_acc_; + weight_acc_ *= size; + return *this; + } + inline TVMOpConfig& add_entity(const std::string& name, const int val) { + entity_map_[name] = OtherOptionEntity(val); + return *this; + } + + TVMOpConfig(): weight_acc_(1) {} + + inline const OtherOptionSpace& get_space(const std::string& name) const { + return space_map_.at(name); + } + + inline int get_weight(const std::string& name) const { + return weight_map_.at(name); + } + + private: + std::map entity_map_; + std::map space_map_; + std::map weight_map_; + int weight_acc_; +}; + +const TVMOpConfig& GetOpConfig(const std::string& name); + } // namespace runtime } // namespace tvm
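
For reference, a minimal sketch (not part of the patch itself) of how the dispatch configuration above round-trips between compile time and runtime: contrib/tvmop/compile.py serializes each dispatchable op's autotvm config space into tvmop.conf via ConfigSpaces.to_json_dict(), and python/mxnet/tvmop.py reads it back with ConfigSpaces.from_json_dict() before handing it to MXLoadTVMConfig through _set_tvm_op_config(). The knob names ("bn", "factor") and the "dot_example" key below are illustrative assumptions mirroring contrib/tvmop/core/multiarray.py, not names fixed by this patch; the import path assumes the python/mxnet/space.py symlink added here (equivalently, tvmop.space with contrib/ on PYTHONPATH, as when compile.py runs).

    # Illustrative sketch only: exercises the (de)serialization helpers
    # introduced in contrib/tvmop/space.py.
    import json
    from collections import OrderedDict

    from mxnet.space import (ConfigSpace, ConfigSpaces,
                             OtherOptionEntity, OtherOptionSpace)

    # A config space with two knobs, shaped like the one defined for "dot".
    space_map = OrderedDict([("bn", OtherOptionSpace([64, 32])),
                             ("factor", OtherOptionSpace([4]))])
    entity_map = OrderedDict([("bn", OtherOptionEntity(64)),
                              ("factor", OtherOptionEntity(4))])
    spaces = ConfigSpaces()
    spaces["dot_example"] = ConfigSpace(space_map, entity_map)

    # What compile.py writes to tvmop.conf ...
    blob = json.dumps(spaces.to_json_dict())
    # ... and what python/mxnet/tvmop.py loads before calling MXLoadTVMConfig.
    restored = ConfigSpaces.from_json_dict(json.loads(blob))
    assert len(restored) == len(spaces)

At runtime the schedule is then selected by name: opdef.py emits one compiled kernel per point in the space, suffixing its name with "index_<i>", plus a "fallback" variant, and DotSch() in src/operator/contrib/tvmop/dot.cc recomputes that index from the input shapes at dispatch time, returning "fallback" when no knob value evenly divides them.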