diff --git a/CMakeLists.txt b/CMakeLists.txt
index 16d365355ceb..d229cb0847d0 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -234,6 +234,7 @@ include_directories("include")
 include_directories("mshadow")
 include_directories("3rdparty/cub")
 include_directories("nnvm/include")
+include_directories("nnvm/tvm/include")
 include_directories("dmlc-core/include")
 include_directories("dlpack/include")
 
@@ -696,4 +697,3 @@ endif()
 set(LINT_DIRS "include src plugin cpp-package tests")
 set(EXCLUDE_PATH "src/operator/contrib/ctc_include")
 add_custom_target(mxnet_lint COMMAND ${CMAKE_COMMAND} -DMSVC=${MSVC} -DPYTHON_EXECUTABLE=${PYTHON_EXECUTABLE} -DLINT_DIRS=${LINT_DIRS} -DPROJECT_SOURCE_DIR=${CMAKE_CURRENT_SOURCE_DIR} -DPROJECT_NAME=mxnet -DEXCLUDE_PATH=${EXCLUDE_PATH} -P ${CMAKE_CURRENT_SOURCE_DIR}/dmlc-core/cmake/lint.cmake)
-
diff --git a/Jenkinsfile b/Jenkinsfile
index 81ddb73f324b..5c96661c983b 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -37,12 +37,12 @@ def init_git() {
   deleteDir()
   retry(5) {
     try {
-      // Make sure wait long enough for api.github.com request quota. Important: Don't increase the amount of 
+      // Make sure wait long enough for api.github.com request quota. Important: Don't increase the amount of
       // retries as this will increase the amount of requests and worsen the throttling
       timeout(time: 15, unit: 'MINUTES') {
         checkout scm
-        sh 'git submodule update --init'
-        sh 'git clean -d -f'
+        sh 'git submodule update --init --recursive'
+        sh 'git clean -d -f'
       }
     } catch (exc) {
       deleteDir()
@@ -60,8 +60,8 @@ def init_git_win() {
       // retries as this will increase the amount of requests and worsen the throttling
       timeout(time: 15, unit: 'MINUTES') {
         checkout scm
-        bat 'git submodule update --init'
-        bat 'git clean -d -f'
+        bat 'git submodule update --init --recursive'
+        bat 'git clean -d -f'
       }
     } catch (exc) {
       deleteDir()
diff --git a/Makefile b/Makefile
index cb3e63ba13b0..5d81c7fbb160 100644
--- a/Makefile
+++ b/Makefile
@@ -91,7 +91,7 @@ ifeq ($(DEBUG), 1)
 else
 CFLAGS += -O3 -DNDEBUG=1
 endif
-CFLAGS += -I$(ROOTDIR)/mshadow/ -I$(ROOTDIR)/dmlc-core/include -fPIC -I$(NNVM_PATH)/include -I$(DLPACK_PATH)/include -Iinclude $(MSHADOW_CFLAGS)
+CFLAGS += -I$(ROOTDIR)/mshadow/ -I$(ROOTDIR)/dmlc-core/include -fPIC -I$(NNVM_PATH)/include -I$(DLPACK_PATH)/include -I$(NNVM_PATH)/tvm/include -Iinclude $(MSHADOW_CFLAGS)
 LDFLAGS = -pthread $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS)
 ifeq ($(DEBUG), 1)
 NVCCFLAGS += -std=c++11 -Xcompiler -D_FORCE_INLINES -g -G -O0 -ccbin $(CXX) $(MSHADOW_NVCCFLAGS)
@@ -356,7 +356,7 @@ ifeq ($(USE_CUDA), 1)
 	LDFLAGS += -lcuda -lnvrtc
 	CFLAGS += -DMXNET_ENABLE_CUDA_RTC=1
 	endif
-	# Make sure to add stubs as fallback in order to be able to build 
+	# Make sure to add stubs as fallback in order to be able to build
 	# without full CUDA install (especially if run without nvidia-docker)
 	LDFLAGS += -L/usr/local/cuda/lib64/stubs
 	SCALA_PKG_PROFILE := $(SCALA_PKG_PROFILE)-gpu
diff --git a/dlpack b/dlpack
index a6e09b58dc00..10892ac964f1 160000
--- a/dlpack
+++ b/dlpack
@@ -1 +1 @@
-Subproject commit a6e09b58dc00ee0065f5b7879800e646fbb01d1e
+Subproject commit 10892ac964f1af7c81aae145cd3fab78bbccd297
diff --git a/include/mxnet/tensor_blob.h b/include/mxnet/tensor_blob.h
index 168ddcca24b7..59c1eacb2c58 100755
--- a/include/mxnet/tensor_blob.h
+++ b/include/mxnet/tensor_blob.h
@@ -36,8 +36,18 @@
 #include <iostream>
 #include <utility>
 #include <algorithm>
 #include "./base.h"
+
 namespace mxnet {
+// Redefine the DLPack enumeration to be backward compatible.
+constexpr const int kCPU = kDLCPU;
+constexpr const int kGPU = kDLGPU;
+// Extension type code under TVM function.
+// Currently NNVM reserves type codes 16 to 19 from TVM;
+// 16, 17, and 18 are already used by the NNVM compiler.
+// Pick code 19 for MXNet NDArray.
+constexpr const int kTVMNDArrayTypeCode = 19;
+
 /* Forward declaration for friend declaration in TBlob */
 class NDArray;
diff --git a/nnvm b/nnvm
index 7a052d678455..c342da72271c 160000
--- a/nnvm
+++ b/nnvm
@@ -1 +1 @@
-Subproject commit 7a052d678455f1c96538c1cc5a25f11115363558
+Subproject commit c342da72271c85e477480323f1d91997c6101ac0
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 4c5273fd40d3..d7eb9fa21467 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -174,8 +174,15 @@ class NDArray(NDArrayBase):
     __slots__ = []
     # make numpy functions return NDArray instead of numpy object array
     __array_priority__ = 1000.0
+    # Extension type code for TVM functions.
+    # See the C++ side of the definition (kTVMNDArrayTypeCode) in include/mxnet/tensor_blob.h.
+    _tvm_tcode = 19
     # pylint: disable= no-member, undefined-variable
+    @property
+    def _tvm_handle(self):
+        return self.handle.value
+
     def __repr__(self):
         """Returns a string representation of the array."""
         shape_info = 'x'.join(['%d' % x for x in self.shape])
diff --git a/src/nnvm/tvm_bridge.cc b/src/nnvm/tvm_bridge.cc
new file mode 100644
index 000000000000..06929984640d
--- /dev/null
+++ b/src/nnvm/tvm_bridge.cc
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm_bridge.cc
+ * \brief Bridge to run TVM's PackedFunc in MXNet's async engine.
+ *
+ * This bridge is mainly used to expose MXNet's async engine push to
+ * TVM. It only uses the TVM runtime in header-only mode, which means
+ * there are no link dependencies.
+ *
+ * Support for TVM is optional even though this code
+ * is always compiled and built with the project.
+ * We choose this strategy because we do not yet want
+ * LLVM as a dependency (which TVM uses). So instead we expose a hook
+ * to TVM and let users enable this feature when they have TVM installed.
+ *
+ * We do require TVM and MXNet to be built with the same C++ ABI for std::function.
+ */
+#define TVM_RUNTIME_HEADER_ONLY 1
+#include <tvm/runtime/packed_func.h>
+#include <mxnet/c_api.h>
+#include <mxnet/ndarray.h>
+#include <mxnet/engine.h>
+
+#include <memory>
+
+namespace mxnet {
+
+using tvm::runtime::PackedFunc;
+using tvm::runtime::TVMArgs;
+using tvm::runtime::TVMRetValue;
+
+/*!
+ * \brief Async functor object that holds the
+ *  calling arguments of the function.
+ */
+class TVMFunctor {
+ public:
+  // constructor
+  explicit TVMFunctor(PackedFunc func, PackedFunc fset_stream)
+      : func_(func), fset_stream_(fset_stream) {}
+
+  void Init(const TVMArgs& args,
+            const std::vector<int>& const_loc,
+            std::vector<Engine::VarHandle>* const_vars,
+            std::vector<Engine::VarHandle>* mutate_vars) {
+    values_.clear();
+    type_codes_.clear();
+    values_.insert(values_.end(), args.values, args.values + args.size());
+    type_codes_.insert(
+        type_codes_.end(), args.type_codes, args.type_codes + args.size());
+
+    size_t const_loc_ptr = 0;
+    for (int i = 0; i < args.size(); ++i) {
+      if (args.type_codes[i] == kTVMNDArrayTypeCode) {
+        const NDArray& nd =
+            static_cast<NDArray*>(args.values[i].v_handle)[0];
+        // We cannot set the DLTensor value until the engine runs the function;
+        // record the array and its argument position for later.
+        type_codes_[i] = kArrayHandle;
+        array_data_.push_back(nd);
+        array_loc_.push_back(i);
+        // Check whether the argument is read or mutated;
+        // by default, assume we mutate the array.
+        if (const_loc_ptr < const_loc.size() &&
+            i == const_loc[const_loc_ptr]) {
+          const_vars->push_back(nd.var());
+          ++const_loc_ptr;
+        } else {
+          mutate_vars->push_back(nd.var());
+        }
+      } else {
+        CHECK_LT(args.type_codes[i], kTVMType)
+            << "Only allow POD type in mxnet async call";
+      }
+    }
+  }
+
+  Context ctx() {
+    return array_data_[0].ctx();
+  }
+
+  void Run(const RunContext& rctx) {
+    // setup DLTensor arguments
+    for (size_t i = 0; i < array_loc_.size(); ++i) {
+      values_[array_loc_[i]].v_handle =
+          const_cast<DLTensor*>(&(array_data_[i].data().dltensor()));
+    }
+    // run the packed function
+    TVMRetValue rv;
+    TVMArgs args(&values_[0], &type_codes_[0], values_.size());
+    if (ctx().dev_type == Context::kGPU) {
+#if MXNET_USE_CUDA
+      // pass the stream via the last argument.
+      void* strm = static_cast<void*>(rctx.get_stream<gpu>()->stream_);
+      int dev_type = kDLGPU;
+      fset_stream_(dev_type, rctx.ctx.dev_id, strm);
+      func_.CallPacked(args, &rv);
+      fset_stream_(dev_type, rctx.ctx.dev_id, nullptr);
+#else
+      LOG(FATAL) << "Please compile with CUDA enabled for CUDA features";
+#endif
+    } else {
+      func_.CallPacked(args, &rv);
+    }
+  }
+
+ private:
+  /*! \brief The function */
+  PackedFunc func_;
+  /*! \brief The function to set the stream */
+  PackedFunc fset_stream_;
+  /*! \brief Argument values */
+  std::vector<TVMValue> values_;
+  /*! \brief Argument type codes */
+  std::vector<int> type_codes_;
+  /*! \brief NDArray arguments */
+  std::vector<NDArray> array_data_;
+  /*! \brief Positions of arrays in the argument list */
+  std::vector<int> array_loc_;
+};
+
+
+// Wrap a TVM function into a function that invokes MXNet's Engine.
+// It does two things: pushes the call to the engine with the proper
+// dependencies, and sets up the NDArray-to-DLTensor conversion during invocation.
+void WrapAsyncCall(TVMArgs wrap_args, TVMRetValue* wrap_rv) {
+  PackedFunc f = wrap_args[0];
+  PackedFunc fset_stream = wrap_args[1];
+  int num_const = wrap_args[2];
+
+  // sorted positions of constant arguments
+  std::vector<int> const_loc;
+  for (int i = 0; i < num_const; ++i) {
+    const_loc.push_back(wrap_args[i + 3].operator int());
+  }
+  std::sort(const_loc.begin(), const_loc.end());
+  // wrapped function
+  // This is the function that is called by the user.
+  auto wrapped = [f, fset_stream, const_loc](TVMArgs args, TVMRetValue* rv) {
+    std::shared_ptr<TVMFunctor> func =
+        std::make_shared<TVMFunctor>(f, fset_stream);
+    std::vector<Engine::VarHandle> const_vars, mutate_vars;
+    func->Init(args, const_loc, &const_vars, &mutate_vars);
+    Engine *engine = Engine::Get();
+    engine->DeduplicateVarHandle(&const_vars, &mutate_vars);
+    engine->PushSync([func](RunContext ctx) {
+      func->Run(ctx);
+    }, func->ctx(), const_vars, mutate_vars);
+  };
+  *wrap_rv = PackedFunc(wrapped);
+}
+
+}  // namespace mxnet
+
+// C callback that can be used by TVM to extract
+// the WrapAsyncCall function.
+extern "C" MXNET_DLL int MXTVMBridge(TVMFunctionHandle pregister) {
+  using tvm::runtime::PackedFunc;
+  const PackedFunc& fregister =
+      *static_cast<PackedFunc*>(pregister);
+  fregister("WrapAsyncCall", PackedFunc(mxnet::WrapAsyncCall));
+  return 0;
+}
diff --git a/tests/ci_build/Dockerfile.gpu b/tests/ci_build/Dockerfile.gpu
index 2483e62b99b5..bd1a00839167 100644
--- a/tests/ci_build/Dockerfile.gpu
+++ b/tests/ci_build/Dockerfile.gpu
@@ -12,3 +12,9 @@ COPY install/ubuntu_install_r.sh /install/
 RUN /install/ubuntu_install_r.sh
 COPY install/ubuntu_install_perl.sh /install/
 RUN /install/ubuntu_install_perl.sh
+
+COPY install/ubuntu_install_llvm.sh /install/
+RUN /install/ubuntu_install_llvm.sh
+
+COPY install/ubuntu_install_tvm.sh /install/
+RUN /install/ubuntu_install_tvm.sh
diff --git a/tests/ci_build/install/ubuntu_install_llvm.sh b/tests/ci_build/install/ubuntu_install_llvm.sh
new file mode 100755
index 000000000000..d3282e7a5fce
--- /dev/null
+++ b/tests/ci_build/install/ubuntu_install_llvm.sh
@@ -0,0 +1,28 @@
+#!/usr/bin/env bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+echo deb http://apt.llvm.org/xenial/ llvm-toolchain-xenial-5.0 main\
+     >> /etc/apt/sources.list.d/llvm.list
+echo deb-src http://apt.llvm.org/xenial/ llvm-toolchain-xenial-5.0 main\
+     >> /etc/apt/sources.list.d/llvm.list
+
+wget -O - http://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add -
+apt-get update && apt-get install -y --force-yes llvm-5.0
diff --git a/tests/ci_build/install/ubuntu_install_tvm.sh b/tests/ci_build/install/ubuntu_install_tvm.sh
new file mode 100755
index 000000000000..2729c7fe3bee
--- /dev/null
+++ b/tests/ci_build/install/ubuntu_install_tvm.sh
@@ -0,0 +1,44 @@
+#!/usr/bin/env bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# Build and install TVM
+cd /tmp
+git clone https://github.com/dmlc/tvm/ --recursive
+cd tvm
+
+# This is a stable commit that supports the MXNet TVM bridge.
+# We pin this commit because bridge support was only just checked
+# into master and there is not yet a version tag.
+git checkout 30eaf463e34d7c301357c31a010945d11df16537
+
+cp make/config.mk .
+echo USE_CUDA=1 >> config.mk
+echo LLVM_CONFIG=llvm-config-5.0 >> config.mk
+echo USE_RPC=1 >> config.mk
+echo USE_GRAPH_RUNTIME=1 >> config.mk
+echo CUDA_PATH=/usr/local/cuda >> config.mk
+make -j`nproc`
+
+cd python
+python setup.py install
+cd -
+
+cd topi/python
+python setup.py install
+cd -
diff --git a/tests/python/gpu/test_tvm_bridge.py b/tests/python/gpu/test_tvm_bridge.py
new file mode 100644
index 000000000000..292b9d91e5f7
--- /dev/null
+++ b/tests/python/gpu/test_tvm_bridge.py
@@ -0,0 +1,64 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Test the TVM bridge; only enabled when TVM is available."""
+import logging
+import mxnet as mx
+import numpy as np
+
+
+def test_tvm_bridge():
+    # only enable the test if TVM is available
+    try:
+        import tvm
+        import tvm.contrib.mxnet
+        import topi
+    except ImportError:
+        logging.warning("TVM bridge test skipped because TVM is missing...")
+        return
+
+    def check(target):
+        shape = (20,)
+        scale = tvm.var("scale", dtype="float32")
+        x = tvm.placeholder(shape)
+        y = tvm.placeholder(shape)
+        z = tvm.compute(shape, lambda i: x[i] + y[i])
+        zz = tvm.compute(shape, lambda *i: z(*i) * scale)
+        ctx = mx.gpu(0) if target == "cuda" else mx.cpu(0)
+        target = tvm.target.create(target)
+
+        # build the function
+        with target:
+            s = topi.generic.schedule_injective(zz)
+            f = tvm.build(s, [x, y, zz, scale])
+
+        # get an MXNet version of the function
+        mxf = tvm.contrib.mxnet.to_mxnet_func(f, const_loc=[0, 1])
+        xx = mx.nd.uniform(shape=shape, ctx=ctx)
+        yy = mx.nd.uniform(shape=shape, ctx=ctx)
+        zz = mx.nd.empty(shape=shape, ctx=ctx)
+        # invoke mxf: this runs in the MXNet engine
+        mxf(xx, yy, zz, 10.0)
+        np.testing.assert_allclose(
+            zz.asnumpy(), (xx.asnumpy() + yy.asnumpy()) * 10)
+
+    check("llvm")
+    check("cuda")
+
+
+if __name__ == "__main__":
+    test_tvm_bridge()
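
For readers who want to exercise the bridge outside the CI harness, the sketch below distills tests/python/gpu/test_tvm_bridge.py into a minimal CPU-only example. It is a sketch, not part of the patch: it assumes the TVM commit pinned in ubuntu_install_tvm.sh (hence the 0.x-era tvm/topi API used by the test) and an MXNet build that contains src/nnvm/tvm_bridge.cc. Swapping "llvm" for "cuda" and mx.cpu(0) for mx.gpu(0) exercises the GPU path, where TVMFunctor::Run forwards MXNet's CUDA stream through fset_stream.

# Minimal CPU-only sketch (not part of the patch) distilled from
# tests/python/gpu/test_tvm_bridge.py; assumes the pinned TVM commit above.
import mxnet as mx
import numpy as np
import topi
import tvm
import tvm.contrib.mxnet

shape = (20,)
scale = tvm.var("scale", dtype="float32")          # scalar POD argument
x = tvm.placeholder(shape, name="x")
y = tvm.placeholder(shape, name="y")
z = tvm.compute(shape, lambda i: x[i] + y[i])
zz = tvm.compute(shape, lambda *i: z(*i) * scale)  # (x + y) * scale

# Build the kernel for the CPU ("llvm") target.
with tvm.target.create("llvm"):
    s = topi.generic.schedule_injective(zz)
    f = tvm.build(s, [x, y, zz, scale])

# const_loc=[0, 1] marks arguments 0 and 1 as read-only, so WrapAsyncCall
# records them as const vars with the engine; argument 2 (the output) is mutated.
mxf = tvm.contrib.mxnet.to_mxnet_func(f, const_loc=[0, 1])

a = mx.nd.uniform(shape=shape)
b = mx.nd.uniform(shape=shape)
out = mx.nd.empty(shape=shape)
mxf(a, b, out, 3.0)  # pushed asynchronously onto MXNet's engine
np.testing.assert_allclose(out.asnumpy(), (a.asnumpy() + b.asnumpy()) * 3)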