Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Infra to use tvm write op kernels #15550

Merged
merged 9 commits into from
Jul 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 21935d to afd4b3
23 changes: 23 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ mxnet_option(USE_MXNET_LIB_NAMING "Use MXNet library naming conventions." ON)
mxnet_option(USE_GPROF "Compile with gprof (profiling) flag" OFF)
mxnet_option(USE_CXX14_IF_AVAILABLE "Build with C++14 if the compiler supports it" OFF)
mxnet_option(USE_VTUNE "Enable use of Intel Amplifier XE (VTune)" OFF) # one could set VTUNE_ROOT for search path
mxnet_option(USE_TVM_OP "Enable use of TVM operator build system." OFF)
mxnet_option(ENABLE_CUDA_RTC "Build with CUDA runtime compilation support" ON)
mxnet_option(BUILD_CPP_EXAMPLES "Build cpp examples" ON)
mxnet_option(INSTALL_EXAMPLES "Install the example source files." OFF)
Expand Down Expand Up @@ -734,6 +735,28 @@ if(USE_DIST_KVSTORE)
list(APPEND mxnet_LINKER_LIBS ${pslite_LINKER_LIBS})
endif()

if(USE_TVM_OP)
yzhliu marked this conversation as resolved.
Show resolved Hide resolved
add_definitions(-DMXNET_USE_TVM_OP=1)
list(APPEND mxnet_LINKER_LIBS ${CMAKE_CURRENT_BINARY_DIR}/3rdparty/tvm/libtvm_runtime.so)
include(cmake/BuildTVM.cmake)
add_subdirectory("3rdparty/tvm")

if(NOT Python3_EXECUTABLE)
find_package(PythonInterp 3 REQUIRED)
set(Python3_EXECUTABLE ${PYTHON_EXECUTABLE} CACHE FILEPATH "Path to the python3 executable")
if(NOT Python3_EXECUTABLE)
message(FATAL_ERROR "No python3 interpreter found to build TVM operators")
endif()
endif()

add_custom_command(TARGET mxnet POST_BUILD
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH="${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/tvm/python:${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/tvm/topi/python:${CMAKE_CURRENT_SOURCE_DIR}/contrib"
LD_LIBRARY_PATH="${CMAKE_CURRENT_BINARY_DIR}/3rdparty/tvm/build"
${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/contrib/tvmop/compile.py -o${CMAKE_CURRENT_BINARY_DIR}/libtvmop.so
)
endif()

target_link_libraries(mxnet PUBLIC ${mxnet_LINKER_LIBS})

if(USE_PLUGINS_WARPCTC)
Expand Down
38 changes: 38 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ ifndef AMALGAMATION_PATH
AMALGAMATION_PATH = $(ROOTDIR)/amalgamation
endif

ifndef TVM_PATH
TVM_PATH = $(TPARTYDIR)/tvm
endif

ifndef LLVM_PATH
LLVM_PATH = $(TVM_PATH)/build/llvm
endif

ifneq ($(USE_OPENMP), 1)
export NO_OPENMP = 1
endif
Expand Down Expand Up @@ -589,6 +597,35 @@ $(DMLC_CORE)/libdmlc.a: DMLCCORE
DMLCCORE:
+ cd $(DMLC_CORE); $(MAKE) libdmlc.a USE_SSE=$(USE_SSE) config=$(ROOTDIR)/$(config); cd $(ROOTDIR)

ifeq ($(USE_TVM_OP), 1)
LIB_DEP += lib/libtvm_runtime.so lib/libtvmop.so
CFLAGS += -I$(TVM_PATH)/include -DMXNET_USE_TVM_OP=1
LDFLAGS += -L$(TVM_PATH)/build -ltvm_runtime

TVM_USE_CUDA := OFF
yzhliu marked this conversation as resolved.
Show resolved Hide resolved
ifeq ($(USE_CUDA), 1)
TVM_USE_CUDA := ON
ifneq ($(USE_CUDA_PATH), NONE)
TVM_USE_CUDA := $(USE_CUDA_PATH)
endif
endif
lib/libtvm_runtime.so:
echo "Compile TVM"
[ -e $(LLVM_PATH)/bin/llvm-config ] || sh $(ROOTDIR)/contrib/tvmop/prepare_tvm.sh; \
cd $(TVM_PATH)/build; \
cmake -DUSE_LLVM="$(LLVM_PATH)/bin/llvm-config" \
-DUSE_SORT=OFF -DUSE_CUDA=$(TVM_USE_CUDA) -DUSE_CUDNN=OFF ..; \
$(MAKE) VERBOSE=1; \
cp $(TVM_PATH)/build/libtvm_runtime.so $(ROOTDIR)/lib/libtvm_runtime.so; \
cd $(ROOTDIR)

lib/libtvmop.so: lib/libtvm_runtime.so $(wildcard contrib/tvmop/*/*.py contrib/tvmop/*.py)
yzhliu marked this conversation as resolved.
Show resolved Hide resolved
echo "Compile TVM operators"
PYTHONPATH=$(TVM_PATH)/python:$(TVM_PATH)/topi/python:$(ROOTDIR)/contrib:$PYTHONPATH \
LD_LIBRARY_PATH=lib \
python3 $(ROOTDIR)/contrib/tvmop/compile.py -o $(ROOTDIR)/lib/libtvmop.so
endif

NNVM_INC = $(wildcard $(NNVM_PATH)/include/*/*.h)
NNVM_SRC = $(wildcard $(NNVM_PATH)/src/*/*/*.cc $(NNVM_PATH)/src/*/*.cc $(NNVM_PATH)/src/*.cc)
$(NNVM_PATH)/lib/libnnvm.a: $(NNVM_INC) $(NNVM_SRC)
Expand Down Expand Up @@ -726,6 +763,7 @@ clean: rclean cyclean $(EXTRA_PACKAGES_CLEAN)
cd $(DMLC_CORE); $(MAKE) clean; cd -
cd $(PS_PATH); $(MAKE) clean; cd -
cd $(NNVM_PATH); $(MAKE) clean; cd -
cd $(TVM_PATH); $(MAKE) clean; cd -
cd $(AMALGAMATION_PATH); $(MAKE) clean; cd -
$(RM) -r $(patsubst %, %/*.d, $(EXTRA_OPERATORS)) $(patsubst %, %/*/*.d, $(EXTRA_OPERATORS))
$(RM) -r $(patsubst %, %/*.o, $(EXTRA_OPERATORS)) $(patsubst %, %/*/*.o, $(EXTRA_OPERATORS))
Expand Down
135 changes: 135 additions & 0 deletions cmake/BuildTVM.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

message(STATUS "Prepare external packages for TVM...")
execute_process(COMMAND "${CMAKE_CURRENT_SOURCE_DIR}/contrib/tvmop/prepare_tvm.sh")

# Whether enable ROCM runtime
#
# Possible values:
# - ON: enable ROCM with cmake's auto search
# - OFF: disable ROCM
# - /path/to/rocm: use specific path to rocm
set(USE_ROCM OFF)

# Whether enable SDAccel runtime
set(USE_SDACCEL OFF)

# Whether enable Intel FPGA SDK for OpenCL (AOCL) runtime
set(USE_AOCL OFF)

# Whether enable OpenCL runtime
set(USE_OPENCL OFF)

# Whether enable Metal runtime
set(USE_METAL OFF)

# Whether enable Vulkan runtime
#
# Possible values:
# - ON: enable Vulkan with cmake's auto search
# - OFF: disable vulkan
# - /path/to/vulkan-sdk: use specific path to vulkan-sdk
set(USE_VULKAN OFF)

# Whether enable OpenGL runtime
set(USE_OPENGL OFF)

# Whether to enable SGX runtime
#
# Possible values for USE_SGX:
# - /path/to/sgxsdk: path to Intel SGX SDK
# - OFF: disable SGX
#
# SGX_MODE := HW|SIM
set(USE_SGX OFF)
set(SGX_MODE "SIM")
set(RUST_SGX_SDK "/path/to/rust-sgx-sdk")

# Whether enable RPC runtime
set(USE_RPC ON)

# Whether embed stackvm into the runtime
set(USE_STACKVM_RUNTIME OFF)

# Whether enable tiny embedded graph runtime.
set(USE_GRAPH_RUNTIME ON)

# Whether enable additional graph debug functions
set(USE_GRAPH_RUNTIME_DEBUG OFF)

# Whether build with LLVM support
# Requires LLVM version >= 4.0
#
# Possible values:
# - ON: enable llvm with cmake's find search
# - OFF: disable llvm
# - /path/to/llvm-config: enable specific LLVM when multiple llvm-dev is available.
set(USE_LLVM "${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/tvm/build/llvm/bin/llvm-config")

#---------------------------------------------
# Contrib libraries
#---------------------------------------------
# Whether use BLAS, choices: openblas, mkl, atlas, apple
set(USE_BLAS none)

# /path/to/mkl: mkl root path when use mkl blas library
# set(USE_MKL_PATH /opt/intel/mkl) for UNIX
# set(USE_MKL_PATH ../IntelSWTools/compilers_and_libraries_2018/windows/mkl) for WIN32
set(USE_MKL_PATH none)

# Whether use contrib.random in runtime
set(USE_RANDOM OFF)

# Whether use NNPack
set(USE_NNPACK OFF)

# Whether use CuDNN
if(USE_CUDNN AND USE_CUDA)
detect_cuDNN()
if(HAVE_CUDNN)
set(USE_CUDNN ON)
else()
set(USE_CUDNN OFF)
endif()
else()
set(USE_CUDNN OFF)
endif()

# Whether use cuBLAS
set(USE_CUBLAS OFF)

# Whether use MIOpen
set(USE_MIOPEN OFF)

# Whether use MPS
set(USE_MPS OFF)

# Whether use rocBlas
set(USE_ROCBLAS OFF)

# Whether use contrib sort
set(USE_SORT OFF)

# Build ANTLR parser for Relay text format
set(USE_ANTLR OFF)

# Build TSIM for VTA
set(USE_VTA_TSIM OFF)

# Whether use Relay debug mode
set(USE_RELAY_DEBUG OFF)
22 changes: 22 additions & 0 deletions contrib/tvmop/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
from .opdef import defop
from .utils import AllTypes, RealTypes

from . import basic
19 changes: 19 additions & 0 deletions contrib/tvmop/basic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
from . import ufunc
50 changes: 50 additions & 0 deletions contrib/tvmop/basic/ufunc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Licensed to the Apache Software Foundation (ASF) under one
yzhliu marked this conversation as resolved.
Show resolved Hide resolved
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
import tvm
from .. import defop, AllTypes

def compute_add(dtype, ndim):
A = tvm.placeholder([tvm.var() for _ in range(ndim)], name='A', dtype=dtype)
B = tvm.placeholder([tvm.var() for _ in range(ndim)], name='B', dtype=dtype)
C = tvm.compute([tvm.var() for _ in range(ndim)],
lambda *index: A[index] + B[index], name='C')
s = tvm.create_schedule(C.op)
return s, A, B, C

@defop(name="vadd", target="cpu", auto_broadcast=True,
dtype=AllTypes, ndim=list(range(1, 6)))
def vadd(dtype, ndim):
s, A, B, C = compute_add(dtype, ndim)
axes = [axis for axis in C.op.axis]
fused = s[C].fuse(*axes)
s[C].parallel(fused)

return s, [A, B, C]

@defop(name="cuda_vadd", target="cuda", auto_broadcast=True,
dtype=["float32", "float64"], ndim=list(range(1, 6)))
def vadd_gpu(dtype, ndim):
s, A, B, C = compute_add(dtype, ndim)
s = tvm.create_schedule(C.op)
axes = [axis for axis in C.op.axis]
fused = s[C].fuse(*axes)
bx, tx = s[C].split(fused, factor=64)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
return s, [A, B, C]
59 changes: 59 additions & 0 deletions contrib/tvmop/compile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
"""TVM Operator compile entry point"""
import tvm

import os
import argparse
from tvmop.opdef import __OP_DEF__

def get_target(device):
if device == "cpu":
return "llvm"
elif device == "cuda" or device == "gpu":
return "cuda"
assert False, "Unknown device " + device


if __name__ == "__main__":
import sys
sys.path.append(os.path.dirname(sys.path[0]))
parser = argparse.ArgumentParser(description="Generate tvm operators")
parser.add_argument("-o", action="store", required=True, dest="target_path",
help="Target path which stores compiled library")
arguments = parser.parse_args()

func_list_llvm = []
func_list_cuda = []

# TODO: attach instruction features to the library, e.g., avx-512, etc.
for operator_def in __OP_DEF__:
for sch, args in operator_def.invoke_all():
if tvm.module.enabled(get_target(operator_def.target)):
func_list = func_list_llvm if operator_def.target == "cpu" else func_list_cuda
func_lower = tvm.lower(sch, args,
name=operator_def.get_op_name(args),
binds=operator_def.get_binds(args))
func_list.append(func_lower)

lowered_funcs = {get_target("cpu") : func_list_llvm}
if len(func_list_cuda) > 0:
lowered_funcs[get_target("cuda")] = func_list_cuda
func_binary = tvm.build(lowered_funcs, name="tvmop")
yzhliu marked this conversation as resolved.
Show resolved Hide resolved
func_binary.export_library(arguments.target_path)
Loading