This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Infra to use TVM to write op kernels #15550

Merged (9 commits) on Jul 22, 2019
Changes from 1 commit
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 21935d to afd4b3
39 changes: 39 additions & 0 deletions Makefile
@@ -52,6 +52,14 @@ ifndef AMALGAMATION_PATH
AMALGAMATION_PATH = $(ROOTDIR)/amalgamation
endif

ifndef TVM_PATH
TVM_PATH = $(TPARTYDIR)/tvm
endif

ifndef LLVM_PATH
LLVM_PATH = $(TVM_PATH)/build/llvm
endif

ifneq ($(USE_OPENMP), 1)
export NO_OPENMP = 1
endif
@@ -101,6 +109,12 @@ endif
CFLAGS += -I$(TPARTYDIR)/mshadow/ -I$(TPARTYDIR)/dmlc-core/include -fPIC -I$(NNVM_PATH)/include -I$(DLPACK_PATH)/include -I$(TPARTYDIR)/tvm/include -Iinclude $(MSHADOW_CFLAGS)
LDFLAGS = -pthread $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS)

ifeq ($(USE_TVM_OP), 1)
LIB_DEP += lib/libtvm_runtime.so lib/libtvmop.so
CFLAGS += -I$(TVM_PATH)/include -DMXNET_USE_TVM_OP=1
LDFLAGS += -L$(TVM_PATH)/build -ltvm_runtime
endif
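
As a usage sketch (the exact invocation is an assumption, not spelled out in this diff): building with something like make USE_TVM_OP=1 should pull lib/libtvm_runtime.so and lib/libtvmop.so into LIB_DEP and add the TVM include and link flags shown above; without the flag, the TVM operator path stays disabled.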

ifeq ($(ENABLE_TESTCOVERAGE), 1)
CFLAGS += --coverage
LDFLAGS += --coverage
@@ -589,6 +603,30 @@ $(DMLC_CORE)/libdmlc.a: DMLCCORE
DMLCCORE:
+ cd $(DMLC_CORE); $(MAKE) libdmlc.a USE_SSE=$(USE_SSE) config=$(ROOTDIR)/$(config); cd $(ROOTDIR)

TVM_USE_CUDA := OFF
ifeq ($(USE_CUDA), 1)
TVM_USE_CUDA := ON
ifneq ($(USE_CUDA_PATH), NONE)
TVM_USE_CUDA := $(USE_CUDA_PATH)
endif
endif

lib/libtvm_runtime.so:
echo "Compile TVM"
[ -e $(LLVM_PATH)/bin/llvm-config ] || sh $(ROOTDIR)/contrib/tvmop/prepare_tvm.sh; \
cd $(TVM_PATH)/build; \
cmake -DUSE_LLVM="$(LLVM_PATH)/bin/llvm-config" \
-DUSE_SORT=OFF -DUSE_CUDA=$(TVM_USE_CUDA) -DUSE_CUDNN=OFF ..; \
$(MAKE) VERBOSE=1; \
cp $(TVM_PATH)/build/libtvm_runtime.so $(ROOTDIR)/lib/libtvm_runtime.so; \
cd $(ROOTDIR)

lib/libtvmop.so: lib/libtvm_runtime.so $(wildcard contrib/tvmop/*/*.py contrib/tvmop/*.py)
echo "Compile TVM operators"
PYTHONPATH=$(TVM_PATH)/python:$(TVM_PATH)/topi/python:$(ROOTDIR)/contrib:$$PYTHONPATH \
LD_LIBRARY_PATH=lib \
python3 $(ROOTDIR)/contrib/tvmop/compile.py -o $(ROOTDIR)/lib/libtvmop.so

NNVM_INC = $(wildcard $(NNVM_PATH)/include/*/*.h)
NNVM_SRC = $(wildcard $(NNVM_PATH)/src/*/*/*.cc $(NNVM_PATH)/src/*/*.cc $(NNVM_PATH)/src/*.cc)
$(NNVM_PATH)/lib/libnnvm.a: $(NNVM_INC) $(NNVM_SRC)
@@ -726,6 +764,7 @@ clean: rclean cyclean $(EXTRA_PACKAGES_CLEAN)
cd $(DMLC_CORE); $(MAKE) clean; cd -
cd $(PS_PATH); $(MAKE) clean; cd -
cd $(NNVM_PATH); $(MAKE) clean; cd -
cd $(TVM_PATH); $(MAKE) clean; cd -
cd $(AMALGAMATION_PATH); $(MAKE) clean; cd -
$(RM) -r $(patsubst %, %/*.d, $(EXTRA_OPERATORS)) $(patsubst %, %/*/*.d, $(EXTRA_OPERATORS))
$(RM) -r $(patsubst %, %/*.o, $(EXTRA_OPERATORS)) $(patsubst %, %/*/*.o, $(EXTRA_OPERATORS))
22 changes: 22 additions & 0 deletions contrib/tvmop/__init__.py
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
from .opdef import defop
from .utils import AllTypes, RealTypes

from . import basic
19 changes: 19 additions & 0 deletions contrib/tvmop/basic/__init__.py
@@ -0,0 +1,19 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
from . import ufunc
50 changes: 50 additions & 0 deletions contrib/tvmop/basic/ufunc.py
@@ -0,0 +1,50 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
import tvm
from .. import defop, AllTypes

def compute_add(dtype, ndim):
A = tvm.placeholder([tvm.var() for _ in range(ndim)], name='A', dtype=dtype)
B = tvm.placeholder([tvm.var() for _ in range(ndim)], name='B', dtype=dtype)
C = tvm.compute([tvm.var() for _ in range(ndim)],
lambda *index: A[index] + B[index], name='C')
s = tvm.create_schedule(C.op)
return s, A, B, C

@defop(name="vadd", target="cpu", auto_broadcast=True,
dtype=AllTypes, ndim=list(range(1, 6)))
def vadd(dtype, ndim):
s, A, B, C = compute_add(dtype, ndim)
axes = [axis for axis in C.op.axis]
fused = s[C].fuse(*axes)
s[C].parallel(fused)

return s, [A, B, C]

@defop(name="cuda_vadd", target="cuda", auto_broadcast=True,
dtype=["float32", "float64"], ndim=list(range(1, 6)))
def vadd_gpu(dtype, ndim):
s, A, B, C = compute_add(dtype, ndim)
s = tvm.create_schedule(C.op)
axes = [axis for axis in C.op.axis]
fused = s[C].fuse(*axes)
bx, tx = s[C].split(fused, factor=64)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
return s, [A, B, C]
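
For illustration, a minimal sketch (an assumption based on this diff, not code from the PR) of calling one of these generated kernels from Python once compile.py has produced lib/libtvmop.so. It relies on the pre-0.7 tvm.module API that this change targets, and the exported symbol name follows OpDef.get_op_name, which appends "<dtype>_<ndim>" once per argument:

import numpy as np
import tvm

# Load the library produced by compile.py (the path is an assumption).
mod = tvm.module.load("lib/libtvmop.so")

# OpDef.get_op_name appends "<dtype>_<ndim>" per argument, so the 1-D
# float32 variant of "vadd" is exported as "vaddfloat32_1float32_1float32_1".
vadd = mod["vaddfloat32_1float32_1float32_1"]

ctx = tvm.cpu(0)
a = tvm.nd.array(np.random.rand(8).astype("float32"), ctx)
b = tvm.nd.array(np.random.rand(8).astype("float32"), ctx)
c = tvm.nd.array(np.zeros(8, dtype="float32"), ctx)
vadd(a, b, c)
np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
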
58 changes: 58 additions & 0 deletions contrib/tvmop/compile.py
@@ -0,0 +1,58 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
"""TVM Operator compile entry point"""
import tvm

import os
import argparse
from tvmop.opdef import __OP_DEF__

def get_target(device):
if device == "cpu":
return "llvm"
elif device == "cuda" or device == "gpu":
return "cuda"
assert False, "Unknown device " + device


if __name__ == "__main__":
import sys
sys.path.append(os.path.dirname(sys.path[0]))
parser = argparse.ArgumentParser(description="Generate tvm operators")
parser.add_argument("-o", action="store", required=True, dest="target_path",
help="Target path which stores compiled library")
arguments = parser.parse_args()

func_list_llvm = []
func_list_cuda = []

for operator_def in __OP_DEF__:
for sch, args in operator_def.invoke_all():
if tvm.module.enabled(get_target(operator_def.target)):
func_list = func_list_llvm if operator_def.target == "cpu" else func_list_cuda
func_lower = tvm.lower(sch, args,
name=operator_def.get_op_name(args),
binds=operator_def.get_binds(args))
func_list.append(func_lower)

lowered_funcs = {get_target("cpu") : func_list_llvm}
if len(func_list_cuda) > 0:
lowered_funcs[get_target("cuda")] = func_list_cuda
func_binary = tvm.build(lowered_funcs, name="tvmop")
func_binary.export_library(arguments.target_path)
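
As a quick sanity check of what this entry point iterates over, a small sketch (not part of this PR; it assumes tvm is importable and contrib/ is on PYTHONPATH, as in the Makefile rule) that lists the kernel names and targets that would be compiled:

# Enumerate the registered @defop definitions without building anything.
import tvmop  # importing the package runs the @defop registrations
from tvmop.opdef import __OP_DEF__

for op in __OP_DEF__:
    for sch, args in op.invoke_all():
        print(op.target, op.get_op_name(args))
# prints lines such as: cpu vaddfloat32_1float32_1float32_1
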
106 changes: 106 additions & 0 deletions contrib/tvmop/opdef.py
@@ -0,0 +1,106 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
import tvm
from itertools import product

__OP_DEF__ = []

class OpDef:
"""Specify the properties of an operator and
construct the value combination of the arguments
e.g., ldtype=["float32", "int32"], rdtype=["float16", "int16"],
then the argument combination is
[
{"ldtype": "float32", "rdtype": "float16"},
{"ldtype": "float32", "rdtype": "int16"},
{"ldtype": "int32", "rdtype": "float16"},
{"ldtype": "int32", "rdtype": "int16"},
]
Parameters
----------
func : function
The function to define the operator (in tvm compute and schedule).
It will get the argument combination extracted by this class.
name : str
function name.
target : str
{"cpu", "gpu", "cuda"}
auto_broadcast : bool
auto_broadcast=True allows one to implement broadcast computation
without considering whether a dimension's size equals one.
TVM maps buffer[i][j][k] -> buffer[i][0][k] when dimension j has extent 1.
"""
def __init__(self, func, name, target, auto_broadcast, **kwargs):
# construct the value combination of the arguments
# e.g., ldtype=["float32", "int32"], rdtype=["float16", "int16"]
# arg_combination = [
# {"ldtype": "float32", "rdtype": "float16"},
# {"ldtype": "float32", "rdtype": "int16"},
# {"ldtype": "int32", "rdtype": "float16"},
# {"ldtype": "int32", "rdtype": "int16"},
# ]
args = [k for k in kwargs]
values = [kwargs[k] if isinstance(kwargs[k], (list, tuple)) else [kwargs[k]]
for k in args]
cart_product = product(*values)
self.arg_combination = [{k: v for k, v in zip(args, comb_values)}
for comb_values in cart_product]
self.func = func
self.name = name
self.target = target
self.auto_broadcast = auto_broadcast

def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)

def invoke_all(self):
for each_kwargs in self.arg_combination:
yield self.func(**each_kwargs)

def get_op_name(self, args):
return self.name + ''.join(["%s_%d" % (arg.dtype, len(arg.shape)) for arg in args])

def get_binds(self, args):
if self.auto_broadcast:
return {arg: tvm.decl_buffer(arg.shape, arg.dtype, buffer_type="auto_broadcast")
for arg in args}
return None


def defop(name, target=None, auto_broadcast=False, **kwargs):
"""Decorator to define a tvm operator.
Parameters
----------
name : str
function name
target : str
{"cpu", "gpu", "cuda"}
auto_broadcast : bool
auto_broadcast=True allows one to implement broadcast computation
without considering whether a dimension's size equals one.
TVM maps buffer[i][j][k] -> buffer[i][0][k] when dimension j has extent 1.
"""
assert name is not None and len(name) > 0
target = "cpu" if target is None else target
def _defop(func):
opdef = OpDef(func, name, target, auto_broadcast, **kwargs)
__OP_DEF__.append(opdef)
return opdef
return _defop
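
To make the combination logic above concrete, here is a hypothetical registration (the operator name and keyword lists are made up for illustration and are not part of this PR) showing the cartesian product that OpDef builds from list-valued keyword arguments:

from tvmop.opdef import defop, __OP_DEF__

@defop(name="demo", target="cpu",
       ldtype=["float32", "int32"], rdtype=["float16", "int16"])
def demo(ldtype, rdtype):
    # A real definition would build tvm tensors and return (schedule, [tensors]);
    # returning None is enough to inspect the generated combinations.
    return None

print(__OP_DEF__[-1].arg_combination)
# [{'ldtype': 'float32', 'rdtype': 'float16'},
#  {'ldtype': 'float32', 'rdtype': 'int16'},
#  {'ldtype': 'int32', 'rdtype': 'float16'},
#  {'ldtype': 'int32', 'rdtype': 'int16'}]
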

63 changes: 63 additions & 0 deletions contrib/tvmop/prepare_tvm.sh
@@ -0,0 +1,63 @@
#!/bin/sh

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

LLVM_VERSION="8.0.0"
LLVM_ROOT="http://releases.llvm.org/${LLVM_VERSION}/"
LLVM_PKG="clang+llvm-${LLVM_VERSION}-x86_64-linux-gnu"

os=`uname`
if [ "$os" = "Linux" ] && [ "$(arch)" = "x86_64" ]; then
DISTRIB_ID=$(cat /etc/*-release | grep DISTRIB_ID | sed 's/DISTRIB_ID=//g' | tr '[:upper:]' '[:lower:]')
DISTRIB_RELEASE=$(cat /etc/*-release | grep DISTRIB_RELEASE | sed 's/DISTRIB_RELEASE=//g' | tr '[:upper:]' '[:lower:]')
if [ "$DISTRIB_ID" = "ubuntu" ]; then
LLVM_PKG=${LLVM_PKG}-${DISTRIB_ID}-${DISTRIB_RELEASE}
else
echo "Downloading LLVM only supports Ubuntu. Please download manually."
exit 1
fi
else
echo "Cannot identify operating system. Try downloading package manually."
yzhliu marked this conversation as resolved.
Show resolved Hide resolved
exit 1
fi

LLVM_URL=${LLVM_ROOT}${LLVM_PKG}.tar.xz

TVM_PATH=`dirname $0`/../../3rdparty/tvm
DST=${TVM_PATH}/build
rm -rf $DST
mkdir -p $DST
DST=`cd $DST; pwd`

if [ -x "$(command -v curl)" ]; then
curl -L -o "${DST}/${LLVM_PKG}.tar.xz" "$LLVM_URL"
elif [ -x "$(command -v wget)" ]; then
wget -O "${DST}/${LLVM_PKG}.tar.xz" "$LLVM_URL"
else
echo "curl or wget not available"
exit 1
fi

if [ $? -ne 0 ]; then
echo "Download from $LLVM_URL to $DST failed"
exit 1
fi

tar -xf "$DST/${LLVM_PKG}.tar.xz" -C $DST
mv $DST/${LLVM_PKG} $DST/llvm
echo "Downloaded and unpacked LLVM libraries to $DST"