Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
infra for dispatch tvm op
Browse files Browse the repository at this point in the history
  • Loading branch information
Fan committed Sep 5, 2019
1 parent d0fa8c0 commit c7d1920
Show file tree
Hide file tree
Showing 19 changed files with 854 additions and 29 deletions.
7 changes: 6 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,12 @@ if(USE_TVM_OP)
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH="${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/tvm/python:${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/tvm/topi/python:${CMAKE_CURRENT_SOURCE_DIR}/contrib"
LD_LIBRARY_PATH=${CMAKE_CURRENT_BINARY_DIR}:${CMAKE_CURRENT_BINARY_DIR}/3rdparty/tvm:$ENV{LD_LIBRARY_PATH}
${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/contrib/tvmop/compile.py -o${CMAKE_CURRENT_BINARY_DIR}/libtvmop.so
${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/contrib/tvmop/compile.py -o${CMAKE_CURRENT_BINARY_DIR}/libtvmop.so --config ${CMAKE_CURRENT_BINARY_DIR}/tvmop.conf
)

add_custom_command(TARGET mxnet POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy
${CMAKE_CURRENT_SOURCE_DIR}/contrib/tvmop/space.py ${CMAKE_CURRENT_SOURCE_DIR}/python/mxnet/space.py
)
endif()

Expand Down
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,7 @@ DMLCCORE:

lib/libtvm_runtime.so:
echo "Compile TVM"
@mkdir -p $(@D)
[ -e $(LLVM_PATH)/bin/llvm-config ] || sh $(ROOTDIR)/contrib/tvmop/prepare_tvm.sh; \
cd $(TVM_PATH)/build; \
cmake -DUSE_LLVM="$(LLVM_PATH)/bin/llvm-config" \
Expand All @@ -628,9 +629,11 @@ lib/libtvm_runtime.so:

lib/libtvmop.so: lib/libtvm_runtime.so $(wildcard contrib/tvmop/*/*.py contrib/tvmop/*.py)
echo "Compile TVM operators"
@mkdir -p $(@D)
PYTHONPATH=$(TVM_PATH)/python:$(TVM_PATH)/topi/python:$(ROOTDIR)/contrib \
LD_LIBRARY_PATH=$(ROOTDIR)/lib \
python3 $(ROOTDIR)/contrib/tvmop/compile.py -o $(ROOTDIR)/lib/libtvmop.so
python3 $(ROOTDIR)/contrib/tvmop/compile.py -o $(ROOTDIR)/lib/libtvmop.so --config $(ROOTDIR)/tvmop.conf
cp $(ROOTDIR)/contrib/tvmop/space.py $(ROOTDIR)/python/mxnet/space.py

NNVM_INC = $(wildcard $(NNVM_PATH)/include/*/*.h)
NNVM_SRC = $(wildcard $(NNVM_PATH)/src/*/*/*.cc $(NNVM_PATH)/src/*/*.cc $(NNVM_PATH)/src/*.cc)
Expand Down
57 changes: 57 additions & 0 deletions benchmark/python/tvmop/benchmark_tvmop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import time
import mxnet as mx
import numpy as _np
from mxnet import np, npx

def measure_cost(repeat, func_name, *args, **kwargs):
    """Measure the average wall-clock cost of running an mxnet function.

    Parameters
    ----------
    repeat : int
        Number of invocations; the returned cost is averaged over them.
    func_name : callable
        The function to benchmark. NOTE(review): despite the name this is
        the callable itself, not a string.
    *args, **kwargs
        Forwarded verbatim to ``func_name``.

    Returns
    -------
    float
        Average seconds per call. A trailing ``waitall`` flushes mxnet's
        asynchronous engine so queued work is included in the measurement.
    """
    mx.nd.waitall()  # drain pending async work so it isn't charged to us
    # perf_counter is monotonic and immune to system clock adjustments,
    # unlike time.time(), so it is the right clock for measuring intervals.
    start = time.perf_counter()
    for _ in range(repeat):
        func_name(*args, **kwargs)
    mx.nd.waitall()  # block until every queued op has actually executed
    end = time.perf_counter()
    return (end - start) / repeat


def test_tvm_dot():
    """Benchmark the TVM-dispatched dot against its fallback and mx.nd.dot.

    For a range of square matrix sizes, times each implementation on fresh
    random float32 inputs and prints the average per-call cost in ms.
    """
    # range() is already iterable; wrapping it in list() is unnecessary.
    for i in range(1000, 1100, 4):
        m = k = n = i
        print("{} * {} X {} * {}".format(m, k, k, n))
        # One (label, implementation) pair per candidate keeps the three
        # benchmark stanzas from being copy-pasted.
        candidates = [
            ("dispatch", mx.nd.contrib.tvm_dot),
            ("fallback", mx.nd.contrib.tvm_dot_fallback),
            ("dot", mx.nd.dot),
        ]
        for label, func in candidates:
            # Fresh random inputs per implementation, matching the original
            # benchmark's behavior of regenerating data for each candidate.
            a = mx.nd.random.uniform(shape=(m, k), dtype='float32')
            b = mx.nd.random.uniform(shape=(k, n), dtype='float32')
            cost = measure_cost(2, func, a, b)
            print("{} cost: {} ms".format(label, cost * 1000))

# Allow running the benchmark directly as a script.
if __name__ == "__main__":
    test_tvm_dot()
1 change: 1 addition & 0 deletions contrib/tvmop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@
from .utils import assign_by_req, reduce_axes

from . import basic
from . import core
53 changes: 45 additions & 8 deletions contrib/tvmop/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@
# coding: utf-8
"""TVM Operator compile entry point"""
import tvm
from tvm import autotvm

import os
import argparse
import json
from tvmop.opdef import __OP_DEF__
from tvmop.space import ConfigSpaces, ConfigSpace

def get_target(device):
if device == "cpu":
Expand All @@ -37,23 +40,57 @@ def get_target(device):
# Command-line interface: -o is the output path for the compiled operator
# library; --config is where the serialized config spaces are written.
parser = argparse.ArgumentParser(description="Generate tvm operators")
parser.add_argument("-o", action="store", required=True, dest="target_path",
                    help="Target path which stores compiled library")
parser.add_argument("--config", action="store", required=True, dest="config_path",
                    help="Path which stores the config file")
arguments = parser.parse_args()

func_list_llvm = []
func_list_cuda = []
config_spaces = ConfigSpaces()

# TODO: attach instruction features to the library, e.g., avx-512, etc.
# NOTE(review): this loop lowers every operator variant produced by
# invoke_all into the per-backend function lists for a single tvm.build.
for operator_def in __OP_DEF__:
    for sch, args, name in operator_def.invoke_all():
        # Only lower for backends this TVM build actually enables.
        if tvm.module.enabled(get_target(operator_def.target)):
            func_list = func_list_llvm if operator_def.target == "cpu" else func_list_cuda
            # Lower the schedule to a function; buffer binds (auto-broadcast,
            # if any) come from the operator definition.
            func_lower = tvm.lower(sch, args,
                                   name=name,
                                   binds=operator_def.get_binds(args))
            func_list.append(func_lower)
# Lower every operator variant. Dispatch-enabled ops emit one kernel per
# autotvm config entity plus a fallback kernel; plain ops emit one kernel.
for op in __OP_DEF__:
    # Skip operators whose backend (llvm for cpu, cuda otherwise) is not
    # enabled in this TVM build.
    if tvm.module.enabled(get_target(op.target)):
        func_list = func_list_llvm if op.target == "cpu" else func_list_cuda
        for each_kwargs in op.arg_combination:
            if (op.attrs_valid(**each_kwargs)):
                # Base kernel name encodes the op name plus its attribute
                # values so every variant gets a distinct symbol.
                name = op.name \
                    + ''.join(["{}_{}".format(key, each_kwargs[key]) for key in op.attrs])
                if op.dispatch is True:
                    # First pass with an empty ConfigSpace: building the op
                    # under ApplyConfig lets autotvm record each knob the
                    # op defines, populating the space to enumerate below.
                    config_space = autotvm.ConfigSpace()
                    with autotvm.task.ApplyConfig(config_space):
                        sch, args = op.func(fallback=False, **each_kwargs)
                    # register dispatch schedules: one kernel per concrete
                    # config entity, indexed so the runtime can pick one.
                    for i in range(len(config_space)):
                        config_entity = config_space.get(i)
                        with autotvm.task.ApplyConfig(config_entity):
                            sch, args = op.func(fallback=False, **each_kwargs)
                        # Suffix carries the config index plus each tensor
                        # argument's dtype and rank.
                        subname = name + "index_" + str(i) + \
                            ''.join(["%s_%d" % (arg.dtype, len(arg.shape)) for arg in args])
                        func_lower = tvm.lower(sch, args,
                                               name=subname,
                                               binds=op.get_binds(args))
                        func_list.append(func_lower)
                    # register config space (serialized to JSON at the end)
                    config_spaces[name] = ConfigSpace.from_tvm(config_space)
                    # register fallback schedule: a single pinned-knob kernel
                    # used when no tuned config applies.
                    config_space = autotvm.ConfigSpace()
                    with autotvm.task.ApplyConfig(config_space):
                        sch, args = op.func(fallback=True, **each_kwargs)
                    subname = name + "fallback" + \
                        ''.join(["%s_%d" % (arg.dtype, len(arg.shape)) for arg in args])
                    func_lower = tvm.lower(sch, args, name=subname, binds=op.get_binds(args))
                    func_list.append(func_lower)
                else:
                    # Non-dispatch ops: one schedule per argument combination.
                    sch, args = op.func(**each_kwargs)
                    subname = name + ''.join(["%s_%d" % (arg.dtype, len(arg.shape)) for arg in args])
                    func_lower = tvm.lower(sch, args, name=subname, binds=op.get_binds(args))
                    func_list.append(func_lower)

# Group the lowered functions by build target: CPU (llvm) always, CUDA only
# when any CUDA kernels were produced above.
targets_to_funcs = {get_target("cpu"): func_list_llvm}
if func_list_cuda:
    targets_to_funcs[get_target("cuda")] = func_list_cuda
# Build one shared module holding every kernel and write it to disk.
compiled_module = tvm.build(targets_to_funcs, name="tvmop")
compiled_module.export_library(arguments.target_path)
# Persist the per-op config spaces alongside the library so the runtime can
# dispatch among the compiled kernel variants.
with open(arguments.config_path, "w") as config_file:
    json.dump(config_spaces.to_json_dict(), config_file)
19 changes: 19 additions & 0 deletions contrib/tvmop/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
from . import multiarray
53 changes: 53 additions & 0 deletions contrib/tvmop/core/multiarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
import tvm
from tvm import autotvm
from .. import defop, AllTypes
from .. import assign_by_req, reduce_axes

def compute_dot(A, B):
    """Declare the TVM compute for C = A matmul B.

    A is (M, K) and B is (K, N); the result is an (M, N) tensor defined by
    a sum reduction over the shared K axis. No schedule is attached here —
    callers create and tune their own.
    """
    num_rows, shared_dim = A.shape[0], A.shape[1]
    num_cols = B.shape[1]
    # Named reduction axis over the shared dimension, consumed by tvm.sum.
    k = tvm.reduce_axis((0, shared_dim), 'k')
    return tvm.compute(
        (num_rows, num_cols),
        lambda x, y: tvm.sum(A[x, k] * B[k, y], axis=k),
        name='C')


@defop(name="dot", target="cpu", dispatch=True, dtype=AllTypes)
def dot(dtype, fallback):
    """Build a tiled CPU schedule for dense matmul C = A * B.

    With ``fallback=True`` the knobs are pinned to a single configuration
    (bn=64, factor=4) and the shapes stay fully symbolic; otherwise the
    knobs expose a small autotvm search space and each dimension is
    declared as a multiple of its tile size — presumably so the tiling
    divides the shapes evenly (TODO confirm against the dispatch runtime).

    Returns the schedule and the [A, B, C] tensor list expected by tvm.lower.
    """
    cfg = autotvm.get_config()
    # Tunable knobs: output tile size (bn) and reduction split factor.
    cfg.define_knob("bn", [64] if fallback else [64, 32])
    cfg.define_knob("factor", [4] if fallback else [4])
    # Symbolic shapes; in dispatch mode they are knob-value multiples of a var.
    M = tvm.var("M") if fallback else cfg["bn"].val * tvm.var("M")
    K = tvm.var("K") if fallback else cfg["factor"].val * tvm.var("K")
    N = tvm.var("N") if fallback else cfg["bn"].val * tvm.var("N")
    A = tvm.placeholder((M, K), name='A', dtype=dtype)
    B = tvm.placeholder((K, N), name='B', dtype=dtype)
    C = compute_dot(A, B)
    s = tvm.create_schedule(C.op)
    # Blocking by loop tiling
    xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], cfg["bn"].val, cfg["bn"].val)
    k, = s[C].op.reduce_axis
    ko, ki = s[C].split(k, factor=cfg["factor"].val)
    # Hoist reduction domain outside the blocking loop
    s[C].reorder(xo, yo, ko, ki, xi, yi)
    return s, [A, B, C]
17 changes: 5 additions & 12 deletions contrib/tvmop/opdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

# coding: utf-8
import tvm
from tvm import autotvm
from itertools import product

__OP_DEF__ = []
Expand Down Expand Up @@ -47,7 +48,7 @@ class OpDef:
without considering whether dimension size equals to one.
TVM maps buffer[i][j][k] -> buffer[i][0][k] if dimension i's shape equals 1.
"""
def __init__(self, func, name, target, auto_broadcast, **kwargs):
def __init__(self, func, name, target, auto_broadcast, dispatch, **kwargs):
# construct the value combination of the arguments
# e.g., ldtype=["float32", "int32"], rdtype=["float16", "int16"]
# arg_combination = [
Expand All @@ -68,27 +69,19 @@ def __init__(self, func, name, target, auto_broadcast, **kwargs):
self.name = name
self.target = target
self.auto_broadcast = auto_broadcast
self.dispatch = dispatch

def __call__(self, *args, **kwargs):
    # Delegate straight to the wrapped operator-definition function.
    return self.func(*args, **kwargs)

def invoke_all(self):
    """Yield (schedule, args, kernel_name) for every valid argument combo.

    The kernel name concatenates the op name, each attribute's value, and
    each tensor argument's dtype and rank, giving every variant a unique
    symbol in the compiled library.
    """
    for each_kwargs in self.arg_combination:
        if (self.attrs_valid(**each_kwargs)):
            sch, args = self.func(**each_kwargs)
            name = self.name \
                + ''.join(["{}_{}".format(key, each_kwargs[key]) for key in self.attrs]) \
                + ''.join(["%s_%d" % (arg.dtype, len(arg.shape)) for arg in args])
            yield sch, args, name

def get_binds(self, args):
    """Return buffer binds for tvm.lower, or None when not broadcasting.

    When auto_broadcast is enabled every tensor argument is declared as an
    "auto_broadcast" buffer, so TVM maps buffer[i][j][k] -> buffer[i][0][k]
    when dimension j has size 1 (see the class docstring's convention).
    """
    if self.auto_broadcast:
        return {arg: tvm.decl_buffer(arg.shape, arg.dtype, buffer_type="auto_broadcast")
                for arg in args}
    return None


def defop(name, target=None, auto_broadcast=False, **kwargs):
def defop(name, target=None, auto_broadcast=False, dispatch=False, **kwargs):
"""Decorator to define a tvm operator.
Parameters
----------
Expand All @@ -108,7 +101,7 @@ def defop(name, target=None, auto_broadcast=False, **kwargs):
assert name is not None and len(name) > 0
target = "cpu" if target is None else target
def _defop(func):
opdef = OpDef(func, name, target, auto_broadcast, **kwargs)
opdef = OpDef(func, name, target, auto_broadcast, dispatch, **kwargs)
__OP_DEF__.append(opdef)
return opdef
return _defop
Expand Down
Loading

0 comments on commit c7d1920

Please sign in to comment.