Skip to content

Commit

Permalink
Infra for tvm op runtime dispatch (apache#16100)
Browse files Browse the repository at this point in the history
* infra for dispatch tvm op

* fix ci and sanity error

* disable shape with hint and fix coding style

* rename to avoid conflict with original dot

* update tvm and use soft link

* config file moves to lib/ when using Makefile

* add tvmop.conf to ci

* fix rebase

* fix rebase

* use inspect to detect dispatchable func
  • Loading branch information
hzfan authored and yajiedesign committed Nov 6, 2019
1 parent 9d2d1a7 commit 0b1ec85
Show file tree
Hide file tree
Showing 20 changed files with 841 additions and 25 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ if(USE_TVM_OP)
endif()
endif()

set(TVM_OP_COMPILE_OPTIONS "-o${CMAKE_CURRENT_BINARY_DIR}/libtvmop.so")
set(TVM_OP_COMPILE_OPTIONS "-o${CMAKE_CURRENT_BINARY_DIR}/libtvmop.so" "--config" "${CMAKE_CURRENT_BINARY_DIR}/tvmop.conf")
# Forward the requested CUDA architecture(s) to the TVM operator compile script.
# The block only runs when CUDA_ARCH_BIN evaluates to true in if().
if(CUDA_ARCH_BIN)
  # list(APPEND) extends the option list in place instead of rebuilding it via set().
  list(APPEND TVM_OP_COMPILE_OPTIONS "--cuda-arch" "${CUDA_ARCH_BIN}")
endif()
Expand Down
4 changes: 3 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,7 @@ DMLCCORE:

lib/libtvm_runtime.so:
echo "Compile TVM"
@mkdir -p $(@D)
[ -e $(LLVM_PATH)/bin/llvm-config ] || sh $(ROOTDIR)/contrib/tvmop/prepare_tvm.sh; \
cd $(TVM_PATH)/build; \
cmake -DUSE_LLVM="$(LLVM_PATH)/bin/llvm-config" \
Expand All @@ -632,12 +633,13 @@ lib/libtvm_runtime.so:
ls $(ROOTDIR)/lib; \
cd $(ROOTDIR)

TVM_OP_COMPILE_OPTIONS = -o $(ROOTDIR)/lib/libtvmop.so
TVM_OP_COMPILE_OPTIONS = -o $(ROOTDIR)/lib/libtvmop.so --config $(ROOTDIR)/lib/tvmop.conf
ifneq ($(CUDA_ARCH),)
TVM_OP_COMPILE_OPTIONS += --cuda-arch "$(CUDA_ARCH)"
endif
lib/libtvmop.so: lib/libtvm_runtime.so $(wildcard contrib/tvmop/*/*.py contrib/tvmop/*.py)
echo "Compile TVM operators"
@mkdir -p $(@D)
PYTHONPATH=$(TVM_PATH)/python:$(TVM_PATH)/topi/python:$(ROOTDIR)/contrib \
LD_LIBRARY_PATH=$(ROOTDIR)/lib \
python3 $(ROOTDIR)/contrib/tvmop/compile.py $(TVM_OP_COMPILE_OPTIONS)
Expand Down
57 changes: 57 additions & 0 deletions benchmark/python/tvmop/benchmark_tvmop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import time
import mxnet as mx
import numpy as _np
from mxnet import np, npx

def measure_cost(repeat, func_name, *args, **kwargs):
    """Measure the average time cost of running a function

    Parameters
    ----------
    repeat : int
        Number of times the function is called inside the timed region.
    func_name : callable
        The function to benchmark (despite the parameter name, this is the
        callable itself, not a string).
    *args, **kwargs
        Forwarded unchanged to ``func_name`` on every call.

    Returns
    -------
    float
        Elapsed seconds for all calls divided by ``repeat``.
    """
    # Synchronize before starting the clock so earlier pending work is not
    # attributed to this measurement.
    mx.nd.waitall()
    # perf_counter() is monotonic and uses the highest-resolution clock
    # available, unlike time.time() which follows the adjustable system clock.
    start = time.perf_counter()
    for _ in range(repeat):
        func_name(*args, **kwargs)
    # Synchronize again so the work issued by the calls above is included.
    mx.nd.waitall()
    end = time.perf_counter()
    return (end - start) / repeat


def test_tvm_dot():
    """Print per-call timings of three dot implementations on square matrices.

    For every size in ``range(1000, 1100, 4)`` this times
    ``mx.nd.contrib.tvm_dot``, ``mx.nd.contrib.tvm_dot_fallback`` and
    ``mx.nd.dot`` (two calls each, via ``measure_cost``) and prints the
    average cost in milliseconds. Fresh random float32 inputs are drawn
    for each implementation.
    """
    # (report template, zero-argument resolver for the operator).
    # Operators are resolved lazily, right before they are timed.
    candidates = (
        ("dispatch cost: {} ms", lambda: mx.nd.contrib.tvm_dot),
        ("fallback cost: {} ms", lambda: mx.nd.contrib.tvm_dot_fallback),
        ("dot cost: {} ms", lambda: mx.nd.dot),
    )
    for size in range(1000, 1100, 4):
        m = k = n = size
        print("{} * {} X {} * {}".format(m, k, k, n))
        for template, resolve in candidates:
            lhs = mx.nd.random.uniform(shape=(m, k), dtype='float32')
            rhs = mx.nd.random.uniform(shape=(k, n), dtype='float32')
            seconds = measure_cost(2, resolve(), lhs, rhs)
            print(template.format(seconds * 1000))

if __name__ == "__main__":
test_tvm_dot()
20 changes: 10 additions & 10 deletions ci/jenkins/Jenkins_steps.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,24 @@
utils = load('ci/Jenkinsfile_utils.groovy')

// mxnet libraries
mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
mx_lib_cython = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so'
mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
mx_lib_cython = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so'

// Python wheels
mx_pip = 'build/*.whl'

// mxnet cmake libraries, in cmake builds we do not produce a libnnvm static library by default.
mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so'
mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so'
mx_cmake_lib_no_tvm_op = 'build/libmxnet.so, build/libmxnet.a, build/libsample_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so'
mx_cmake_lib_cython = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so'
mx_cmake_lib_cython = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so'
// mxnet cmake libraries, in cmake builds we do not produce a libnnvm static library by default.
mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/libsample_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests'
mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, build/3rdparty/mkldnn/src/libmkldnn.so.0'
mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, libsample_lib.so, lib/libiomp5.so, lib/libmkldnn.so.0, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
mx_tensorrt_lib = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so'
mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so'
mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/libsample_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests'
mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, build/3rdparty/mkldnn/src/libmkldnn.so.0'
mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, lib/libiomp5.so, lib/libmkldnn.so.0, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
mx_tensorrt_lib = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so'
mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so'
mx_lib_cpp_examples_no_tvm_op = 'lib/libmxnet.so, lib/libmxnet.a, libsample_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy2/*.so, python/mxnet/_cy3/*.so'
mx_lib_cpp_examples_cpu = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/cpp-package/example/*'
mx_lib_cpp_examples_cpu = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/cpp-package/example/*'

// Python unittest for CPU
// Python 2
Expand Down
13 changes: 13 additions & 0 deletions contrib/tvmop/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@
# coding: utf-8
"""TVM Operator compile entry point"""
import tvm
from tvm import autotvm

import os
import argparse
import re
import json
import logging
from tvmop.opdef import __OP_DEF__
from tvmop.space import ConfigSpaces, ConfigSpace
from tvm.autotvm.measure.measure_methods import set_cuda_target_arch

logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -70,6 +73,8 @@ def get_cuda_arch(arch):
help="Target path which stores compiled library")
parser.add_argument('--cuda-arch', type=str, default=None, dest='cuda_arch',
help='The cuda arch for compiling kernels for')
parser.add_argument("--config", action="store", required=True, dest="config_path",
help="Path which stores the config file")
arguments = parser.parse_args()

func_list_llvm = []
Expand All @@ -78,6 +83,7 @@ def get_cuda_arch(arch):
# TODO: attach instruction features to the library, e.g., avx-512, etc.
for operator_def in __OP_DEF__:
for sch, args, name in operator_def.invoke_all():
name = operator_def.get_op_name(name, args)
if tvm.module.enabled(get_target(operator_def.target)):
func_list = func_list_llvm if operator_def.target == "cpu" else func_list_cuda
func_lower = tvm.lower(sch, args,
Expand All @@ -96,3 +102,10 @@ def get_cuda_arch(arch):
set_cuda_target_arch(cuda_arch)
func_binary = tvm.build(lowered_funcs, name="tvmop")
func_binary.export_library(arguments.target_path)

config_spaces = ConfigSpaces()
for operator_def in __OP_DEF__:
for config_space, name in operator_def.get_config_spaces():
config_spaces[name] = ConfigSpace.from_tvm(config_space)
with open(arguments.config_path, "w") as f:
json.dump(config_spaces.to_json_dict(), f)
2 changes: 1 addition & 1 deletion contrib/tvmop/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@
# specific language governing permissions and limitations
# under the License.

from . import umath, fromnumeric
from . import umath, fromnumeric, multiarray
53 changes: 53 additions & 0 deletions contrib/tvmop/core/multiarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
import tvm
from tvm import autotvm
from .. import defop, AllTypes
from .. import assign_by_req, reduce_axes

def compute_dot(A, B):
    """Declare the matrix product of two 2-D tensors.

    The result ``C`` has shape ``(A.shape[0], B.shape[1])`` and is defined
    as ``C[x, y] = sum over k of A[x, k] * B[k, y]``, where ``k`` ranges over
    ``A.shape[1]``.
    """
    rows, inner = A.shape[0], A.shape[1]
    cols = B.shape[1]
    k = tvm.reduce_axis((0, inner), 'k')

    def _element(x, y):
        # One output element: reduce the products along the shared axis.
        return tvm.sum(A[x, k] * B[k, y], axis=k)

    return tvm.compute((rows, cols), _element, name='C')


@defop(name="dot", target="cpu", dtype=AllTypes)
def dot(dtype, fallback):
    """Build a tiled schedule for a 2-D dot product with symbolic shapes.

    Two knobs are declared on the current autotvm config: ``bn`` (tile size,
    ``[64]`` when ``fallback`` is true, otherwise ``[64, 32]``) and ``factor``
    (reduction split factor, always ``[4]``).

    Returns the schedule and the tensor list ``[A, B, C]``.
    """
    cfg = autotvm.get_config()
    # Knobs are declared in a fixed order: "bn" first, then "factor".
    bn_candidates = [64] if fallback else [64, 32]
    cfg.define_knob("bn", bn_candidates)
    # The candidate list for "factor" is the same with or without fallback.
    cfg.define_knob("factor", [4])

    M = tvm.var("M")
    K = tvm.var("K")
    N = tvm.var("N")
    A = tvm.placeholder((M, K), name='A', dtype=dtype)
    B = tvm.placeholder((K, N), name='B', dtype=dtype)
    C = compute_dot(A, B)
    s = tvm.create_schedule(C.op)

    # Blocking by loop tiling: both output axes use the same tile size.
    block = cfg["bn"].val
    row_axis, col_axis = C.op.axis[0], C.op.axis[1]
    xo, yo, xi, yi = s[C].tile(row_axis, col_axis, block, block)

    # Split the single reduction axis by the configured factor.
    (k,) = s[C].op.reduce_axis
    ko, ki = s[C].split(k, factor=cfg["factor"].val)

    # Hoist reduction domain outside the blocking loop
    s[C].reorder(xo, yo, ko, ki, xi, yi)
    return s, [A, B, C]
42 changes: 37 additions & 5 deletions contrib/tvmop/opdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

# coding: utf-8
import tvm
import inspect
from tvm import autotvm
from itertools import product

__OP_DEF__ = []
Expand Down Expand Up @@ -68,19 +70,49 @@ def __init__(self, func, name, target, auto_broadcast, **kwargs):
self.name = name
self.target = target
self.auto_broadcast = auto_broadcast
self.dispatchable = 'fallback' in inspect.signature(self.func).parameters

def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)

def invoke_all(self):
for each_kwargs in self.arg_combination:
if self.attrs_valid(**each_kwargs):
sch, args = self.func(**each_kwargs)
name = self.name \
+ ''.join(["{}_{}".format(key, each_kwargs[key]) for key in self.attrs]) \
+ ''.join(["%s_%d" % (arg.dtype, len(arg.shape))
for arg in args if hasattr(arg, 'shape')])
yield sch, args, name
+ ''.join(["{}_{}".format(key, each_kwargs[key]) for key in self.attrs])
if self.dispatchable is False:
sch, args = self.func(**each_kwargs)
yield sch, args, name
else:
# register dispatch schedules
config_space = autotvm.ConfigSpace()
with autotvm.task.ApplyConfig(config_space):
sch, args = self.func(fallback=False, **each_kwargs)
for i in range(len(config_space)):
config_entity = config_space.get(i)
with autotvm.task.ApplyConfig(config_entity):
sch, args = self.func(fallback=False, **each_kwargs)
subname = name + "index_" + str(i)
yield sch, args, subname
# register fallback schedule
config_space = autotvm.ConfigSpace()
with autotvm.task.ApplyConfig(config_space):
sch, args = self.func(fallback=True, **each_kwargs)
subname = name + "fallback"
yield sch, args, subname

def get_op_name(self, name, args):
    """Return ``name`` extended with a ``<dtype>_<ndim>`` tag per shaped argument.

    Arguments without a ``shape`` attribute contribute nothing; tags are
    concatenated in argument order with no separator.
    """
    tags = []
    for arg in args:
        if hasattr(arg, 'shape'):
            tags.append("%s_%d" % (arg.dtype, len(arg.shape)))
    return name + ''.join(tags)

def get_config_spaces(self):
    """Yield ``(config_space, name)`` for each valid argument combination.

    Nothing is yielded unless ``self.dispatchable`` is ``True``. For each
    combination accepted by ``attrs_valid``, the wrapped function is called
    once with ``fallback=False`` while a fresh ``autotvm.ConfigSpace`` is
    applied; presumably the knobs the function defines are recorded in that
    space (TODO confirm against autotvm). ``name`` is ``self.name`` followed
    by a ``<key>_<value>`` tag for every key in ``self.attrs``.
    """
    for kwargs in self.arg_combination:
        if not self.attrs_valid(**kwargs):
            continue
        if self.dispatchable is not True:
            continue
        attr_tags = ["{}_{}".format(key, kwargs[key]) for key in self.attrs]
        name = self.name + ''.join(attr_tags)
        space = autotvm.ConfigSpace()
        with autotvm.task.ApplyConfig(space):
            # The return value is discarded; the call is made only for its
            # effect on the applied config space.
            self.func(fallback=False, **kwargs)
        yield space, name

def get_binds(self, args):
if self.auto_broadcast:
Expand Down
Loading

0 comments on commit 0b1ec85

Please sign in to comment.