forked from apache/mxnet
-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
TVM bridge support to JIT NDArray Function by TVM (apache#9880)
* TVM bridge support. Support wrap TVM compiled function as a NDArray function. * Testcases and CI to include TVM as dependency * address review comments * Add more comments, change to constexpr * change to log warn * update comment on the type code
- Loading branch information
Showing
12 changed files
with
349 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Submodule dlpack
updated
14 files
+1 −1 | .gitignore | |
+46 −0 | .travis.yml | |
+8 −4 | CMakeLists.txt | |
+21 −10 | Makefile | |
+17 −0 | NEWS.md | |
+18 −11 | README.md | |
+9 −9 | contrib/dlpack/dlpackcpp.h | |
+7 −0 | contrib/mock_c.c | |
+1 −0 | contrib/mock_main.cc | |
+42 −12 | include/dlpack/dlpack.h | |
+6 −0 | tests/scripts/task_build.sh | |
+23 −0 | tests/scripts/task_lint.sh | |
+12 −0 | tests/travis/run_test.sh | |
+5 −0 | tests/travis/setup.sh |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Submodule nnvm
updated
73 files
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file tvm_bridge.cc | ||
* \brief Bridge to run TVM's PackedFunc in MXNet's async engine. | ||
* | ||
* This bridge is mainly used to expose MXNet's async engine push to | ||
* TVM. It only uses TVM runtime in aheader only mode, which means | ||
* there is no link dependencies. | ||
* | ||
* Support for TVM is optional even when this code | ||
* is always compiled and built with the project. | ||
* We choose this strategy because we do not yet want | ||
* llvm as dependency(which TVM uses). So instead we expose hook | ||
* to TVM and let user use this feature when they have TVM installed. | ||
* | ||
* We do require TVM and MXNet to be built with same C++ ABI of std::function | ||
*/ | ||
#define TVM_RUNTIME_HEADER_ONLY 1 | ||
#include <tvm/runtime/packed_func.h> | ||
#include <mxnet/c_api.h> | ||
#include <mxnet/ndarray.h> | ||
#include <mxnet/engine.h> | ||
|
||
#include <memory> | ||
|
||
namespace mxnet { | ||
|
||
using tvm::runtime::PackedFunc; | ||
using tvm::runtime::TVMArgs; | ||
using tvm::runtime::TVMRetValue; | ||
|
||
/*! | ||
* \brief Async functor object | ||
* calling argument of the function. | ||
*/ | ||
class TVMFunctor { | ||
public: | ||
// constructor | ||
explicit TVMFunctor(PackedFunc func, PackedFunc fset_stream) | ||
: func_(func), fset_stream_(fset_stream) {} | ||
|
||
void Init(const TVMArgs& args, | ||
const std::vector<int>& const_loc, | ||
std::vector<Engine::VarHandle>* const_vars, | ||
std::vector<Engine::VarHandle>* mutate_vars) { | ||
values_.clear(); | ||
type_codes_.clear(); | ||
values_.insert(values_.end(), args.values, args.values + args.size()); | ||
type_codes_.insert( | ||
type_codes_.end(), args.type_codes, args.type_codes + args.size()); | ||
|
||
size_t const_loc_ptr = 0; | ||
for (int i = 0; i < args.size(); ++i) { | ||
if (args.type_codes[i] == kTVMNDArrayTypeCode) { | ||
const NDArray& nd = | ||
static_cast<NDArray*>(args.values[i].v_handle)[0]; | ||
// We cannot set the value until | ||
type_codes_[i] = kArrayHandle; | ||
array_data_.push_back(nd); | ||
array_loc_.push_back(i); | ||
// check if there is read or mutate | ||
// by default assume we mutate the array. | ||
if (const_loc_ptr < const_loc.size() && | ||
i == const_loc[const_loc_ptr]) { | ||
const_vars->push_back(nd.var()); | ||
++const_loc_ptr; | ||
} else { | ||
mutate_vars->push_back(nd.var()); | ||
} | ||
} else { | ||
CHECK_LT(args.type_codes[i], kTVMType) | ||
<< "Only allow POD type in mxnet async call"; | ||
} | ||
} | ||
} | ||
|
||
Context ctx() { | ||
return array_data_[0].ctx(); | ||
} | ||
|
||
void Run(const RunContext& rctx) { | ||
// setup DLTensor | ||
for (size_t i = 0; i < array_loc_.size(); ++i) { | ||
values_[array_loc_[i]].v_handle = | ||
const_cast<DLTensor*>(&(array_data_[i].data().dltensor())); | ||
} | ||
// run the packed function | ||
TVMRetValue rv; | ||
TVMArgs args(&values_[0], &type_codes_[0], values_.size()); | ||
if (ctx().dev_type == Context::kGPU) { | ||
#if MXNET_USE_CUDA | ||
// pass stream via last argument. | ||
void* strm = static_cast<void*>(rctx.get_stream<gpu>()->stream_); | ||
int dev_type = kDLGPU; | ||
fset_stream_(dev_type, rctx.ctx.dev_id, strm); | ||
func_.CallPacked(args, &rv); | ||
fset_stream_(dev_type, rctx.ctx.dev_id, nullptr); | ||
#else | ||
LOG(FATAL) << "Please compile with CUDA enabled for cuda features"; | ||
#endif | ||
} else { | ||
func_.CallPacked(args, &rv); | ||
} | ||
} | ||
|
||
private: | ||
/*! \brief The function */ | ||
PackedFunc func_; | ||
/*! \brief Set stream */ | ||
PackedFunc fset_stream_; | ||
/*! \brief Values field */ | ||
std::vector<TVMValue> values_; | ||
/*! \brief type code field */ | ||
std::vector<int> type_codes_; | ||
/*! \brief arrays field */ | ||
std::vector<NDArray> array_data_; | ||
/*! \brief position of array in arguments */ | ||
std::vector<int> array_loc_; | ||
}; | ||
|
||
|
||
// Wrap a TVM function to a function that invokes MXNet's Engine | ||
// It does two things: call the engine properly | ||
// set up the NDArray to DLTensor during invocation. | ||
void WrapAsyncCall(TVMArgs wrap_args, TVMRetValue* wrap_rv) { | ||
PackedFunc f = wrap_args[0]; | ||
PackedFunc fset_stream = wrap_args[1]; | ||
int num_const = wrap_args[2]; | ||
|
||
// sorted position of constant arguments | ||
std::vector<int> const_loc; | ||
for (int i = 0; i < num_const; ++i) { | ||
const_loc.push_back(wrap_args[i + 3].operator int()); | ||
} | ||
std::sort(const_loc.begin(), const_loc.end()); | ||
// wrapped function | ||
// This is the function that called by the user. | ||
auto wrapped = [f, fset_stream, const_loc](TVMArgs args, TVMRetValue* rv) { | ||
std::shared_ptr<TVMFunctor> func = | ||
std::make_shared<TVMFunctor>(f, fset_stream); | ||
std::vector<Engine::VarHandle> const_vars, mutate_vars; | ||
func->Init(args, const_loc, &const_vars, &mutate_vars); | ||
Engine *engine = Engine::Get(); | ||
engine->DeduplicateVarHandle(&const_vars, &mutate_vars); | ||
engine->PushSync([func](RunContext ctx) { | ||
func->Run(ctx); | ||
}, func->ctx(), const_vars, mutate_vars); | ||
}; | ||
*wrap_rv = PackedFunc(wrapped); | ||
} | ||
|
||
} // namespace mxnet | ||
|
||
// C callback that can be used by TVM to extract | ||
// the WrapAsyncCall function. | ||
extern "C" MXNET_DLL int MXTVMBridge(TVMFunctionHandle pregister) { | ||
using tvm::runtime::PackedFunc; | ||
const PackedFunc& fregister = | ||
*static_cast<PackedFunc*>(pregister); | ||
fregister("WrapAsyncCall", PackedFunc(mxnet::WrapAsyncCall)); | ||
return 0; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
#!/usr/bin/env bash | ||
|
||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
|
||
|
||
echo deb http://apt.llvm.org/xenial/ llvm-toolchain-xenial-5.0 main\ | ||
>> /etc/apt/sources.list.d/llvm.list | ||
echo deb-src http://apt.llvm.org/xenial/ llvm-toolchain-xenial-5.0 main\ | ||
>> /etc/apt/sources.list.d/llvm.list | ||
|
||
wget -O - http://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add - | ||
apt-get update && apt-get install -y --force-yes llvm-5.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
#!/usr/bin/env bash | ||
|
||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
|
||
# Build and install TVM | ||
cd /tmp | ||
git clone https://github.com/dmlc/tvm/ --recursive | ||
cd tvm | ||
|
||
# This is a stable tag that support MXNet TVM bridge. | ||
# We use this since support for mxnet bridge just checked | ||
# into master and there is yet a version tag | ||
git checkout 30eaf463e34d7c301357c31a010945d11df16537 | ||
|
||
cp make/config.mk | ||
echo USE_CUDA=1 >> config.mk | ||
echo LLVM_CONFIG=llvm-config-5.0 >> config.mk | ||
echo USE_RPC=1 >> config.mk | ||
echo USE_GRAPH_RUNTIME=1 >> config.mk | ||
echo CUDA_PATH=/usr/local/cuda >> config.mk | ||
make -j`nproc` | ||
|
||
cd python | ||
python setup.py install | ||
cd - | ||
|
||
cd topi/python | ||
python setup.py install | ||
cd - |
Oops, something went wrong.