TVM bridge support to JIT NDArray Function by TVM #9880
| Changes | File |
|---|---|
| +1 −1 | .gitignore |
| +46 −0 | .travis.yml |
| +8 −4 | CMakeLists.txt |
| +21 −10 | Makefile |
| +17 −0 | NEWS.md |
| +18 −11 | README.md |
| +9 −9 | contrib/dlpack/dlpackcpp.h |
| +7 −0 | contrib/mock_c.c |
| +1 −0 | contrib/mock_main.cc |
| +42 −12 | include/dlpack/dlpack.h |
| +6 −0 | tests/scripts/task_build.sh |
| +23 −0 | tests/scripts/task_lint.sh |
| +12 −0 | tests/travis/run_test.sh |
| +5 −0 | tests/travis/setup.sh |
@@ -36,8 +36,15 @@
#include <utility>
#include <algorithm>
#include "./base.h"

namespace mxnet {

// redefine DLPack enumeration to be backward compatible.
const int kCPU = kDLCPU;
const int kGPU = kDLGPU;
// extension type code under TVM function.
const int kTVMNDArrayTypeCode = 19;
Review: Would it make sense to make it an enumerator? Mostly because the 19 seems arbitrary, and maybe extensible to other numbers in the future; in that case an enum could help manage accidental overlap.
Author: This enumerator is allocated on the TVM side and reserved for the MXNet and NNVM projects, so it is not arbitrarily chosen: https://github.com/dmlc/tvm/blob/master/include/tvm/runtime/c_runtime_api.h#L97. It is picked to be the last reserved enumerator for NNVM. Will add a comment about this in the updated code.
/* Forward declaration for friend declaration in TBlob */
class NDArray;
@@ -174,8 +174,14 @@ class NDArray(NDArrayBase):
    __slots__ = []
    # make numpy functions return NDArray instead of numpy object array
    __array_priority__ = 1000.0
    # used by tvm bridge
    _tvm_tcode = 19
    # pylint: disable= no-member, undefined-variable

    @property
    def _tvm_handle(self):
        return self.handle.value
Review: what's this for?
Author: This is a handle exposed for the PackedFunc convention interface of TVM, to allow arbitrary positional-argument calls without adding a new C API. Specifically, the wrapped function is a TVM PackedFunc that recognizes NDArray as an extension object and passes the address of the NDArray handle correctly to the arguments. It is later received here: https://github.com/apache/incubator-mxnet/pull/9880/files#diff-3aa2a3c799e125e086769bc1d5f6490aR74
    def __repr__(self):
        """Returns a string representation of the array."""
        shape_info = 'x'.join(['%d' % x for x in self.shape])
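For context on the reply above: TVM's PackedFunc FFI passes an extension object as a (type code, raw handle) pair, which is exactly what the `_tvm_tcode` class attribute and `_tvm_handle` property added here provide. Below is a minimal sketch of that convention using a hypothetical stand-in class rather than the real NDArray; only the two attribute names come from this diff, everything else is illustrative.

```python
# Illustrative stand-in that follows the same extension-object convention
# this diff adds to mxnet.nd.NDArray (hypothetical class, not MXNet code).
class FakeTensor(object):
    # class-level TVM type code; must match kTVMNDArrayTypeCode (19) in C++
    _tvm_tcode = 19

    def __init__(self, c_handle_value):
        # integer value of the underlying C handle (NDArrayHandle in MXNet)
        self._c_handle = c_handle_value

    @property
    def _tvm_handle(self):
        # TVM's argument setter reads this integer and forwards it as a raw
        # pointer together with _tvm_tcode; on the C++ side, tvm_bridge.cc
        # casts the pointer back to NDArray* when it sees type code 19.
        return self._c_handle
```

TVMFunctor::Init in tvm_bridge.cc (the next file in this diff) is where the matching type code is recognized and the handle is cast back to an NDArray.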
@@ -0,0 +1,180 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file tvm_bridge.cc
 * \brief Bridge to run TVM's PackedFunc in MXNet's async engine.
 *
 * This bridge is mainly used to expose MXNet's async engine push to
 * TVM. It only uses the TVM runtime in header-only mode, which means
 * there are no link dependencies.
 *
 * Support for TVM is optional even though this code
 * is always compiled and built with the project.
 * We choose this strategy because we do not yet want
 * LLVM as a dependency (which TVM uses). So instead we expose a hook
 * to TVM and let users use this feature when they have TVM installed.
 *
 * We do require TVM and MXNet to be built with the same C++ ABI of std::function.
 */
#define TVM_RUNTIME_HEADER_ONLY 1
Review: should you undefine it at the end?
Author: since it is in a .cc file, this is not necessary.
#include <tvm/runtime/packed_func.h>
#include <mxnet/c_api.h>
#include <mxnet/ndarray.h>
#include <mxnet/engine.h>

#include <memory>

namespace mxnet {

using tvm::runtime::PackedFunc;
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMRetValue;

/*!
 * \brief Async functor object
 *  calling argument of the function.
 */
class TVMFunctor {
 public:
  // constructor
  explicit TVMFunctor(PackedFunc func, PackedFunc fset_stream)
      : func_(func), fset_stream_(fset_stream) {}

  void Init(const TVMArgs& args,
            const std::vector<int>& const_loc,
            std::vector<Engine::VarHandle>* const_vars,
            std::vector<Engine::VarHandle>* mutate_vars) {
    values_.clear();
    type_codes_.clear();
    values_.insert(values_.end(), args.values, args.values + args.size());
    type_codes_.insert(
        type_codes_.end(), args.type_codes, args.type_codes + args.size());

    size_t const_loc_ptr = 0;
    for (int i = 0; i < args.size(); ++i) {
      if (args.type_codes[i] == kTVMNDArrayTypeCode) {
        const NDArray& nd =
            static_cast<NDArray*>(args.values[i].v_handle)[0];
        // We cannot set the value until
        type_codes_[i] = kArrayHandle;
        array_data_.push_back(nd);
        array_loc_.push_back(i);
        // check if there is read or mutate
        // by default assume we mutate the array.
        if (const_loc_ptr < const_loc.size() &&
            i == const_loc[const_loc_ptr]) {
          const_vars->push_back(nd.var());
Review: is this called a lot in performance-sensitive areas? Should we do a reserve()? (for all vectors here)
Author: we don't know the size of the vectors beforehand.
Review: ik
Author: ok
          ++const_loc_ptr;
        } else {
          mutate_vars->push_back(nd.var());
        }
      } else {
        CHECK_LT(args.type_codes[i], kTVMType)
            << "Only allow POD type in mxnet async call";
      }
    }
  }

  Context ctx() {
    return array_data_[0].ctx();
  }

  void Run(const RunContext& rctx) {
    // setup DLTensor
    for (size_t i = 0; i < array_loc_.size(); ++i) {
      values_[array_loc_[i]].v_handle =
          const_cast<DLTensor*>(&(array_data_[i].data().dltensor()));
    }
    // run the packed function
    TVMRetValue rv;
    TVMArgs args(&values_[0], &type_codes_[0], values_.size());
    if (ctx().dev_type == Context::kGPU) {
#if MXNET_USE_CUDA
      // pass stream via last argument.
      void* strm = static_cast<void*>(rctx.get_stream<gpu>()->stream_);
      int dev_type = kDLGPU;
      fset_stream_(dev_type, rctx.ctx.dev_id, strm);
      func_.CallPacked(args, &rv);
      fset_stream_(dev_type, rctx.ctx.dev_id, nullptr);
#else
      LOG(FATAL) << "Please compile with CUDA enabled for cuda features";
#endif
    } else {
      func_.CallPacked(args, &rv);
    }
  }

 private:
  /*! \brief The function */
  PackedFunc func_;
  /*! \brief Set stream */
  PackedFunc fset_stream_;
  /*! \brief Values field */
  std::vector<TVMValue> values_;
  /*! \brief type code field */
  std::vector<int> type_codes_;
  /*! \brief arrays field */
  std::vector<NDArray> array_data_;
  /*! \brief position of array in arguments */
  std::vector<int> array_loc_;
};

// Wrap a TVM function into a function that invokes MXNet's Engine.
// It does two things: call the engine properly, and
// set up the NDArray-to-DLTensor conversion during invocation.
void WrapAsyncCall(TVMArgs wrap_args, TVMRetValue* wrap_rv) {
  PackedFunc f = wrap_args[0];
  PackedFunc fset_stream = wrap_args[1];
  int num_const = wrap_args[2];

  // sorted position of constant arguments
  std::vector<int> const_loc;
  for (int i = 0; i < num_const; ++i) {
    const_loc.push_back(wrap_args[i + 3].operator int());
Review: reserve?
Author: This is not on the critical path (function construction instead of running).
  }
  std::sort(const_loc.begin(), const_loc.end());
  // wrapped function
  // This is the function that is called by the user.
  auto wrapped = [f, fset_stream, const_loc](TVMArgs args, TVMRetValue* rv) {
    std::shared_ptr<TVMFunctor> func =
        std::make_shared<TVMFunctor>(f, fset_stream);
    std::vector<Engine::VarHandle> const_vars, mutate_vars;
    func->Init(args, const_loc, &const_vars, &mutate_vars);
    Engine *engine = Engine::Get();
    engine->DeduplicateVarHandle(&const_vars, &mutate_vars);
    engine->PushSync([func](RunContext ctx) {
        func->Run(ctx);
      }, func->ctx(), const_vars, mutate_vars);
  };
  *wrap_rv = PackedFunc(wrapped);
}

}  // namespace mxnet

// C callback that can be used by TVM to extract
// the WrapAsyncCall function.
extern "C" MXNET_DLL int MXTVMBridge(TVMFunctionHandle pregister) {
Review: Is this a C API? Should it be put in the c_api folder?
Author: This is queried by TVM, so it is not a publicly facing C API. I feel it is better to put it here, but we can move it to the c_api folder.
  using tvm::runtime::PackedFunc;
  const PackedFunc& fregister =
      *static_cast<PackedFunc*>(pregister);
  fregister("WrapAsyncCall", PackedFunc(mxnet::WrapAsyncCall));
  return 0;
}
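Putting MXTVMBridge and WrapAsyncCall together, the intended flow is: TVM loads libmxnet, calls MXTVMBridge to fetch WrapAsyncCall, and uses it to wrap a compiled PackedFunc so that calling it with NDArrays pushes work onto MXNet's async engine. The sketch below is a hedged illustration of that flow from Python; the tvm.contrib.mxnet.to_mxnet_func helper and the pre-0.6 tvm.placeholder/tvm.compute API are assumptions based on TVM's companion bridge module, not part of this diff.

```python
# Hedged sketch: wrap a TVM-compiled kernel so it runs on MXNet NDArrays
# through MXNet's async engine. to_mxnet_func is an assumed TVM-side helper.
import tvm
import mxnet as mx
from tvm.contrib import mxnet as mxnet_bridge

# Build a trivial elementwise kernel with TVM.
n = 1024
A = tvm.placeholder((n,), name='A')
B = tvm.compute((n,), lambda i: A[i] + 1.0, name='B')
s = tvm.create_schedule(B.op)
fadd = tvm.build(s, [A, B], target='llvm')

# to_mxnet_func is expected to call MXTVMBridge to obtain WrapAsyncCall and
# wrap fadd; const_loc=[0] marks the first argument as read-only, so
# TVMFunctor::Init registers it as a const var and the output as a mutate var.
mx_fadd = mxnet_bridge.to_mxnet_func(fadd, const_loc=[0])

x = mx.nd.ones((n,))
y = mx.nd.zeros((n,))
mx_fadd(x, y)            # pushed asynchronously onto MXNet's engine
print(y.asnumpy()[:5])   # asnumpy() waits for the engine; should print 2.0s
```

The same wrapped function works on GPU arrays, in which case TVMFunctor::Run passes the CUDA stream to TVM through fset_stream before invoking the kernel.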
@@ -0,0 +1,28 @@
#!/usr/bin/env bash

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


echo deb http://apt.llvm.org/xenial/ llvm-toolchain-xenial-5.0 main\
    >> /etc/apt/sources.list.d/llvm.list
echo deb-src http://apt.llvm.org/xenial/ llvm-toolchain-xenial-5.0 main\
    >> /etc/apt/sources.list.d/llvm.list

wget -O - http://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add -
apt-get update && apt-get install -y --force-yes llvm-5.0
@@ -0,0 +1,44 @@
#!/usr/bin/env bash

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Build and install TVM
cd /tmp
git clone https://github.com/dmlc/tvm/ --recursive
Review: Are you aware that the result of this script is cached indefinitely? In that case it would be better to specify a stable version instead of master, as otherwise environments may differ on different slaves.
Author: I am aware of that; changed to use a fixed tag.
cd tvm

# This is a stable tag that supports the MXNet TVM bridge.
# We use this since support for the mxnet bridge was just checked
# into master and there is not yet a version tag.
git checkout 30eaf463e34d7c301357c31a010945d11df16537

cp make/config.mk .
echo USE_CUDA=1 >> config.mk
echo LLVM_CONFIG=llvm-config-5.0 >> config.mk
echo USE_RPC=1 >> config.mk
echo USE_GRAPH_RUNTIME=1 >> config.mk
echo CUDA_PATH=/usr/local/cuda >> config.mk
make -j`nproc`

cd python
python setup.py install
cd -

cd topi/python
python setup.py install
cd -
Review (on the const int constants added in the first hunk above): should this be constexpr? What keeps it from generating an integer in the data segment for each file compiled?
Author: good catch, will change to constexpr.