diff --git a/.travis.yml b/.travis.yml index 5c7a5d2562a6..1d9e5bad4ed3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,6 +11,7 @@ env: - TASK=python CXX=g++ - TASK=python3 CXX=g++ - TASK=python_naive CXX=g++ + - TASK=python_perdev CXX=g++ - TASK=cpp_unittest CXX=g++ # dependent apt packages diff --git a/dmlc-core b/dmlc-core index 75f1950d386d..2e2d187efc43 160000 --- a/dmlc-core +++ b/dmlc-core @@ -1 +1 @@ -Subproject commit 75f1950d386d033b0b64919017515d27e698962a +Subproject commit 2e2d187efc43ee2df1d132c3690169575e830442 diff --git a/include/mxnet/base.h b/include/mxnet/base.h index 7f2fc7c07a0b..4c63b35bade5 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -1,12 +1,13 @@ /*! * Copyright (c) 2015 by Contributors * \file base.h - * \brief configuation of mxnet + * \brief configuation of mxnet as well as basic data structure. */ #ifndef MXNET_BASE_H_ #define MXNET_BASE_H_ #include +#include #include #include #include @@ -62,6 +63,84 @@ typedef mshadow::default_real_t real_t; typedef mshadow::TShape TShape; /*! \brief storage container type */ typedef mshadow::TBlob TBlob; + +/*! \brief Context information about the execution enviroment */ +struct Context { + /*! \brief the device type we run the op can be cpu::kDevMask or gpu::kDevMask */ + int32_t dev_mask; + /*! \brief device id we are going to run it on */ + int32_t dev_id; + /*! \brief constructor */ + Context() : dev_mask(cpu::kDevMask), dev_id(0) {} + /*! + * \brief constructor of context + * \param dev_mask the device mask + * \param dev_id the device id + */ + Context(int dev_mask, int dev_id) + : dev_mask(dev_mask), dev_id(dev_id) {} + /*! + * \brief check if current context equals another one + * \param b another context to compare + * \return whether dev mask and id are same + */ + inline bool operator==(const Context &b) const { + return dev_mask == b.dev_mask && dev_id == b.dev_id; + } + /*! + * \brief check if current context not equals another one + * \param b another context to compare + * \return whether they are not the same + */ + inline bool operator!=(const Context &b) const { + return !(*this == b); + } + /*! + * \brief save the content into binary stream + * \param strm the output stream + */ + void Save(dmlc::Stream *strm) const { + strm->Write(&dev_mask, sizeof(dev_mask)); + strm->Write(&dev_id, sizeof(dev_id)); + } + /*! + * \brief load the content from binary stream + * \param strm the output stream + * \return whether the load is successful + */ + bool Load(dmlc::Stream *strm) { + if (strm->Read(&dev_mask, sizeof(int32_t)) != sizeof(int32_t)) return false; + if (strm->Read(&dev_id, sizeof(int32_t)) != sizeof(int32_t)) return false; + return true; + } + /*! \brief the maximal device mask, cpu = 1, gpu = 2 */ + static const int32_t kMaxDevMask = 2; + /*! + * \brief A dedicate ID for pinned cpu memory. + * Any normal CPU ID should be less than this number. + */ + static const int32_t kPinnedMemoryID = 16; +}; + +/*! + * \brief execution time context. + * The information needed in runtime for actual execution. + */ +struct RunContext { + /*! + * \brief the stream of the device, can be NULL or Stream* in GPU mode + */ + void *stream; + /*! + * \brief get mshadow stream from Context + * \return the mshadow stream + * \tparam xpu the device type of the stream + */ + template + inline mshadow::Stream* get_stream() const { + return static_cast*>(stream); + } +}; } // namespace mxnet //! 
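For orientation on the Context and RunContext structs moved into base.h above, a minimal usage sketch (illustrative only, not part of the patch; it assumes the mxnet headers are on the include path and that MXNET_USE_CUDA is set appropriately):

#include <mxnet/base.h>

void ContextExample() {
  mxnet::Context cpu0;                           // defaults to cpu::kDevMask, device 0
  mxnet::Context gpu1(mxnet::gpu::kDevMask, 1);  // GPU number 1
  if (cpu0 != gpu1) {
    // contexts compare by (dev_mask, dev_id) and are used as keys for
    // streams, worker pools and memory throughout the engine
  }
}

// Inside an engine callback, RunContext carries the device stream:
void OnRun(mxnet::RunContext run_ctx) {
#if MXNET_USE_CUDA
  mshadow::Stream<mshadow::gpu> *s = run_ctx.get_stream<mshadow::gpu>();
  // ... launch mshadow / CUDA work on s ...
  (void)s;
#endif
}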
\cond Doxygen_Suppress diff --git a/include/mxnet/context.h b/include/mxnet/context.h deleted file mode 100644 index a7ed35d21263..000000000000 --- a/include/mxnet/context.h +++ /dev/null @@ -1,131 +0,0 @@ -/*! - * Copyright (c) 2015 by Contributors - * \file context.h - * \brief Context information and resources in mxnet. - */ -#ifndef MXNET_CONTEXT_H_ -#define MXNET_CONTEXT_H_ - -#include -#include -#include -#include -#include "./base.h" - -namespace mxnet { - -/*! \brief Context information about the execution enviroment */ -struct Context { - /*! \brief the device type we run the op can be cpu::kDevMask or gpu::kDevMask */ - int32_t dev_mask; - /*! \brief device id we are going to run it on */ - int32_t dev_id; - /*! \brief constructor */ - Context() : dev_mask(cpu::kDevMask), dev_id(0) {} - /*! - * \brief constructor of context - * \param dev_mask the device mask - * \param dev_id the device id - */ - Context(int dev_mask, int dev_id) - : dev_mask(dev_mask), dev_id(dev_id) {} - /*! - * \brief check if current context equals another one - * \param b another context to compare - * \return whether dev mask and id are same - */ - inline bool operator==(const Context &b) const { - return dev_mask == b.dev_mask && dev_id == b.dev_id; - } - /*! - * \brief check if current context not equals another one - * \param b another context to compare - * \return whether they are not the same - */ - inline bool operator!=(const Context &b) const { - return !(*this == b); - } - /*! - * \brief save the content into binary stream - * \param strm the output stream - */ - void Save(dmlc::Stream *strm) const { - strm->Write(&dev_mask, sizeof(dev_mask)); - strm->Write(&dev_id, sizeof(dev_id)); - } - /*! - * \brief load the content from binary stream - * \param strm the output stream - * \return whether the load is successful - */ - bool Load(dmlc::Stream *strm) { - if (strm->Read(&dev_mask, sizeof(int32_t)) != sizeof(int32_t)) return false; - if (strm->Read(&dev_id, sizeof(int32_t)) != sizeof(int32_t)) return false; - return true; - } - - /*! \brief the maximal device mask, cpu = 1, gpu = 2 */ - static const int32_t kMaxDevMask = 2; - - /*! - * \brief A dedicate ID for pinned cpu memory. - * - * Any normal CPU ID should be less than this number. - */ - static const int32_t kPinnedMemoryID = 16; -}; - -/*! - * \brief execution time context. - * The information needed in runtime for actual execution. - */ -struct RunContext { - /*! - * \brief the stream of the device, can be NULL or Stream* in GPU mode - */ - void *stream; - /*! - * \brief get mshadow stream from Context - * \return the mshadow stream - * \tparam xpu the device type of the stream - */ - template - inline mshadow::Stream* get_stream() const { - return static_cast*>(stream); - } -}; - -/*! - * \brief Additional resources - */ -struct Resource { - /*! \brief Resource type, indicating what the pointer type is */ - enum Type { - /*! \brief mshadow::Random object */ - kRandom, - /*! \brief Temporal space */ - kTempSpace - }; - /*! \brief pointer to the resource */ - void *ptr; -}; - -/*! - * \brief The resources that can be requested by Operator - */ -struct ResourceRequest { - /*! \brief type of resources */ - Resource::Type type; - /*! \brief size requirment if it is an temp space request */ - size_t space_size; - /*! \brief default constructor */ - ResourceRequest() {} - /*! 
- * \brief default constructor, allow implicit conversion - * \param type type of resources - */ - ResourceRequest(Resource::Type type) : type(type) {} // NOLINT(*) -}; - -} // namespace mxnet -#endif // MXNET_CONTEXT_H_ diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h index f185da8215c3..0db270fbb958 100644 --- a/include/mxnet/engine.h +++ b/include/mxnet/engine.h @@ -12,7 +12,6 @@ #endif #include #include "./base.h" -#include "./context.h" namespace mxnet { /*! \brief namespace of engine internal types. */ @@ -28,13 +27,14 @@ typedef Opr* OprHandle; } // namespace engine #if DMLC_USE_CXX11 - /*! \brief Function property, used to hint what action is pushed to engine. */ enum class FnProperty { /*! \brief Normal operation */ kNormal, - /*! \brief Copy operation between CPU and GPU */ - kCopy, + /*! \brief Copy operation from GPU to other devices */ + kCopyFromGPU, + /*! \brief Copy operation from CPU to other devices */ + kCopyToGPU, /*! \brief Asynchronous function call */ kAsync }; // enum class FnProperty diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h index 5f6b07680b92..dfdb6d874102 100644 --- a/include/mxnet/kvstore.h +++ b/include/mxnet/kvstore.h @@ -10,7 +10,7 @@ #if DMLC_USE_CXX11 #include #endif // DMLC_USE_CXX11 -#include "ndarray.h" +#include "./ndarray.h" namespace mxnet { diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index c8ec4528202e..15747a9bda02 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -13,10 +13,9 @@ #include #include #include "./base.h" -#include "./context.h" #include "./storage.h" -#include "./context.h" #include "./engine.h" + // check c++11 #if DMLC_USE_CXX11 == 0 #error "cxx11 was required for ndarray module" diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index 92b5f034a3c9..a62d97425da0 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -14,7 +14,7 @@ #include #include #include "./base.h" -#include "./context.h" +#include "./resource.h" namespace mxnet { /*! \brief operation request type to Forward and Backward */ @@ -230,18 +230,22 @@ class OperatorProperty { * \brief Declare additional resource required in forward pass. * These additional resources will be presented in OpContext.requested * in the same order of the returned Resource. + * \param in_shape The input shape to the operator, corresponds to shapes of in_data. * \return Additional resource request */ - virtual std::vector ForwardResource() const { + virtual std::vector ForwardResource( + const std::vector &in_shape) const { return std::vector(); } /*! * \brief Decalre additional resource required in backward pass. * These additional resources will be presented in OpContext.requested * in the same order of the returned Resource. + * \param in_shape The input shape to the operator, corresponds to shapes of in_data. * \return Additional resource request */ - virtual std::vector BackwardResource() const { + virtual std::vector BackwardResource( + const std::vector &in_shape) const { return std::vector(); } /*! diff --git a/include/mxnet/resource.h b/include/mxnet/resource.h new file mode 100644 index 000000000000..8d03b08ad44a --- /dev/null +++ b/include/mxnet/resource.h @@ -0,0 +1,85 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file resource.h + * \brief Global resource allocation handling. + */ +#ifndef MXNET_RESOURCE_H_ +#define MXNET_RESOURCE_H_ + +#include +#include "./base.h" +#include "./engine.h" + +namespace mxnet { + +/*! 
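The operator.h change above passes the input shapes into ForwardResource/BackwardResource, so an operator can size its temporary-space request from its inputs. A hedged sketch of an override in that style, mirroring the BatchNormProp change further down (HypotheticalOpProp and its sizing rule are made up for illustration):

class HypotheticalOpProp : public mxnet::OperatorProperty {
 public:
  std::vector<mxnet::ResourceRequest> BackwardResource(
      const std::vector<mxnet::TShape> &in_shape) const override {
    // size the scratch request from in_data[0], e.g. 3 rows per channel
    const mxnet::TShape &dshape = in_shape[0];
    size_t nreals = dshape[1] * 3;
    return {{mxnet::ResourceRequest::kTempSpace, nreals}};
  }
  // ... the rest of the OperatorProperty interface is omitted in this sketch ...
};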
+ * \brief The resources that can be requested by Operator + */ +struct ResourceRequest { + /*! \brief Resource type, indicating what the pointer type is */ + enum Type { + /*! \brief mshadow::Random object */ + kRandom, + /*! \brief Temporal space */ + kTempSpace + }; + /*! \brief type of resources */ + Type type; + /*! \brief size of space requested, in terms of number of reals */ + size_t space_num_reals; + /*! \brief default constructor */ + ResourceRequest() {} + /*! + * \brief constructor, allow implicit conversion + * \param type type of resources + */ + ResourceRequest(Type type, size_t space_num_reals = 0) // NOLINT(*) + : type(type), space_num_reals(space_num_reals) {} +}; + + +/*! + * \brief Resources used by mxnet operations. + * A resource is something special other than NDArray, + * but will still participate + */ +struct Resource { + /*! \brief The original request */ + ResourceRequest req; + /*! \brief engine variable */ + engine::VarHandle var; + /*! + * \brief pointer to the resource, do not use directly, + * access using member functions + */ + void *ptr_; + /*! + * \brief Get random number generator. + * \return the mshadow random number generator requested. + * \tparam xpu the device type of random number generator. + */ + template + inline mshadow::Random* get_random() const { + CHECK_EQ(req.type, ResourceRequest::kRandom); + return static_cast*>(ptr_); + } + /*! + * \brief Get space requested as mshadow Tensor. + * The resulting tensor must fit in space requsted. + * \param shape the Shape of returning tensor. + * \param stream the stream of retruning tensor. + * \return the mshadow tensor requested. + * \tparam xpu the device type of random number generator. + * \tparam ndim the number of dimension of the tensor requested. + */ + template + inline mshadow::Tensor get_space( + mshadow::Shape shape, mshadow::Stream *stream) const { + CHECK_EQ(req.type, ResourceRequest::kTempSpace); + CHECK_GE(req.space_num_reals, shape.Size()); + return mshadow::Tensor( + static_cast(ptr_), shape, shape[ndim - 1], stream); + } +}; +} // namespace mxnet +#endif // MXNET_RESOURCE_H_ diff --git a/include/mxnet/storage.h b/include/mxnet/storage.h index 575dc4cde1a2..5590c9f1cdad 100644 --- a/include/mxnet/storage.h +++ b/include/mxnet/storage.h @@ -7,8 +7,7 @@ #define MXNET_STORAGE_H_ #include -#include "base.h" -#include "context.h" +#include "./base.h" namespace mxnet { @@ -64,7 +63,5 @@ class Storage { std::unique_ptr impl_; DISALLOW_COPY_AND_ASSIGN(Storage); }; // class Storage - } // namespace mxnet - #endif // MXNET_STORAGE_H_ diff --git a/scripts/travis_script.sh b/scripts/travis_script.sh index 1b250afdf70b..07abd4881dce 100755 --- a/scripts/travis_script.sh +++ b/scripts/travis_script.sh @@ -40,7 +40,7 @@ if [ ${TASK} == "python3" ]; then make all || exit -1 export MXNET_ENGINE_TYPE=ThreadedEngine nosetests tests/python/unittest || exit -1 - nosetests tests/python/train || exit -1 + nosetests tests/python/train || exit -1 fi if [ ${TASK} == "python_naive" ]; then @@ -48,7 +48,15 @@ if [ ${TASK} == "python_naive" ]; then make all || exit -1 export MXNET_ENGINE_TYPE=NaiveEngine nosetests tests/python/unittest || exit -1 - nosetests tests/python/train || exit -1 + nosetests tests/python/train || exit -1 +fi + +if [ ${TASK} == "python_perdev" ]; then + echo "USE_CUDA=0" >> config.mk + make all || exit -1 + export MXNET_ENGINE_TYPE=ThreadedEnginePerDevice + nosetests tests/python/unittest || exit -1 + nosetests tests/python/train || exit -1 fi if [ ${TASK} == "cpp_unittest" ]; then 
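On the consuming side, the Resource struct in resource.h above hands the scratch space back through get_space. A rough fragment of how a Backward implementation would use it via OpContext::requested (not compiled here; num_channels is a stand-in for whatever dimension sized the request):

// inside a template<typename xpu> Operator subclass
void Backward(const OpContext &ctx /* , gradients and data elided */) {
  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
  // index 0 matches the first entry returned by BackwardResource()
  mshadow::Tensor<xpu, 2> workspace =
      ctx.requested[0].get_space<xpu>(mshadow::Shape2(3, num_channels), s);
  mshadow::Tensor<xpu, 1> gmean = workspace[0];  // each row is a 1-D scratch buffer
  mshadow::Tensor<xpu, 1> gvar  = workspace[1];
  // ... use gmean / gvar; no manual TensorContainer allocation or caching needed ...
}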
diff --git a/src/c_api.cc b/src/c_api.cc index 4a96d946d34f..6427a6357c90 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -323,12 +323,6 @@ int MXNDArrayListLoad(const char* fname, API_END(); } -int MXNDArrayWaitAll() { - API_BEGIN(); - Engine::Get()->WaitForAll(); - API_END(); -} - int MXNDArrayFree(NDArrayHandle handle) { API_BEGIN(); delete static_cast(handle); @@ -817,6 +811,9 @@ int MXDataIterBeforeFirst(DataIterHandle handle) { int MXDataIterNext(DataIterHandle handle, int *out) { API_BEGIN(); + // TODO(tianjun): remove this after having prefetcher by default. + // and call NArray.WaitForWrite instead. + Engine::Get()->WaitForAll(); *out = static_cast* >(handle)->Next(); API_END(); } diff --git a/src/engine/engine.cc b/src/engine/engine.cc index 75bb58c6f56a..6cd82f96050b 100644 --- a/src/engine/engine.cc +++ b/src/engine/engine.cc @@ -13,16 +13,20 @@ namespace engine { inline Engine* CreateEngine() { const char *type = getenv("MXNET_ENGINE_TYPE"); const bool default_engine = (type == nullptr); - if (type == nullptr) type = "ThreadedEngine"; + if (type == nullptr) type = "ThreadedEnginePerDevice"; std::string stype = type; + Engine *ret = nullptr; - if (stype == "ThreadedEngine") { - ret = CreateThreadedEngine(); - } else if (stype == "NaiveEngine") { + if (stype == "NaiveEngine") { ret = CreateNaiveEngine(); + } else if (stype == "ThreadedEngine") { + ret = CreateThreadedEnginePooled(); + } else if (stype == "ThreadedEnginePerDevice") { + ret = CreateThreadedEnginePerDevice(); } + CHECK_NE(ret, nullptr) - << "Cannot find Eine " << type << " in registry"; + << "Cannot find Engine " << type; if (!default_engine) { LOG(INFO) << "MXNet start using engine: " << type; } diff --git a/src/engine/engine_impl.h b/src/engine/engine_impl.h index e4c350656097..44452df7b9c5 100644 --- a/src/engine/engine_impl.h +++ b/src/engine/engine_impl.h @@ -65,11 +65,16 @@ inline T* Opr::Cast() { #endif } +/*! \brief Maximum number of GPUs */ +static constexpr std::size_t kMaxNumGPUs = 16; + // predeclare factory function for each type of engine /*! \return NaiveEngine instance */ Engine *CreateNaiveEngine(); -/*! \return ThreadedEngine instance */ -Engine *CreateThreadedEngine(); +/*! \return ThreadedEnginePooled instance */ +Engine *CreateThreadedEnginePooled(); +/*! \return ThreadedEnginePerDevie instance */ +Engine *CreateThreadedEnginePerDevice(); } // namespace engine } // namespace mxnet #endif // MXNET_ENGINE_ENGINE_IMPL_H_ diff --git a/src/engine/stream_manager.h b/src/engine/stream_manager.h index 7b2382d60df7..adc84cdf7e9e 100644 --- a/src/engine/stream_manager.h +++ b/src/engine/stream_manager.h @@ -5,15 +5,13 @@ #define MXNET_ENGINE_STREAM_MANAGER_H_ #include +#include #include #include #include -#include "mxnet/base.h" -#include "mxnet/context.h" #include "../common/cuda_utils.h" namespace mxnet { - namespace engine { /*! 
@@ -44,9 +42,9 @@ class StreamManager { template RunContext StreamManager::GetRunContext( Context const& ctx) { + RunContext ret; switch (ctx.dev_mask) { - case cpu::kDevMask: - return {nullptr}; + case cpu::kDevMask: ret.stream = nullptr; break; case gpu::kDevMask: { #if MXNET_USE_CUDA std::size_t use_counter; @@ -63,21 +61,22 @@ RunContext StreamManager::GetRunContext( use_counter = counter; counter = (counter + 1) % kStreams; } - return {gpu_streams_.at(ctx.dev_id).at(use_counter)}; -#else // MXNET_USE_CUDA - LOG(FATAL) << "Please compile with CUDA enabled"; + ret.stream = gpu_streams_.at(ctx.dev_id).at(use_counter); + break; +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; #endif // MXNET_USE_CUDA } } - return {nullptr}; + return ret; } template RunContext StreamManager::GetIORunContext( Context const& ctx) { + RunContext ret; switch (ctx.dev_mask) { - case cpu::kDevMask: - return {nullptr}; + case cpu::kDevMask: ret.stream = nullptr; break; case gpu::kDevMask: { #if MXNET_USE_CUDA CUDA_CALL(cudaSetDevice(ctx.dev_id)); @@ -87,13 +86,14 @@ RunContext StreamManager::GetIORunContext( gpu_io_streams_.at(ctx.dev_id) = mshadow::NewStream(false, false); } } - return {gpu_io_streams_.at(ctx.dev_id)}; -#else // MXNET_USE_CUDA - LOG(FATAL) << "Please compile with CUDA enabled"; + ret.stream = gpu_io_streams_.at(ctx.dev_id); + break; +#else + LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; #endif // MXNET_USE_CUDA } } - return {nullptr}; + return ret; } template diff --git a/src/engine/thread_pool.h b/src/engine/thread_pool.h index ef99a93e58d1..b88cddaa29c5 100644 --- a/src/engine/thread_pool.h +++ b/src/engine/thread_pool.h @@ -6,7 +6,7 @@ #include #include -#include +#include #include #include #include "mxnet/base.h" @@ -17,24 +17,30 @@ namespace engine { /*! * \brief Thread pool. */ -template class ThreadPool { public: /*! - * \brief Constructor takes function to run and its arguments. + * \brief Constructor takes function to run. + * \param size size of the thread pool. + * \param func the function to run on the thread pool. */ - template - explicit ThreadPool(Function&& func, Args&&... args); - /*! - * \brief Destructor. - */ - ~ThreadPool() noexcept(false); + explicit ThreadPool(size_t size, std::function func) + : worker_threads_(size) { + for (auto& i : worker_threads_) { + i = std::thread(func); + } + } + ~ThreadPool() noexcept(false) { + for (auto&& i : worker_threads_) { + i.join(); + } + } private: /*! * \brief Worker threads. */ - std::array worker_threads_; + std::vector worker_threads_; /*! * \brief Disallow default construction. */ @@ -44,23 +50,6 @@ class ThreadPool { */ DISALLOW_COPY_AND_ASSIGN(ThreadPool); }; - -template -template -ThreadPool::ThreadPool(Function&& func, Args&&... 
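The rewritten ThreadPool above takes an explicit size and a plain std::function instead of being templated on the thread count. A hedged sketch of the intended usage, pairing it with a dmlc::ConcurrentBlockingQueue the way the engines below do (fragment only; the queue and pool types are the ones introduced in this diff):

void RunWorkers() {
  dmlc::ConcurrentBlockingQueue<int> task_queue;
  {
    // four identical workers, each blocking in Pop() until the queue is killed
    ThreadPool pool(4, [&task_queue]() {
      int task;
      while (task_queue.Pop(&task)) {
        // ... process task ...
      }
    });
    task_queue.Push(42);
    task_queue.SignalForKill();  // wakes blocked Pop() calls so the workers can exit
  }                              // ThreadPool destructor joins the threads here
}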
args) { - for (auto&& i : worker_threads_) { - i = std::thread{std::forward(func), std::forward(args)...}; - } -} - -template -ThreadPool::~ThreadPool() noexcept(false) { - for (auto&& i : worker_threads_) { - i.join(); - } -} - } // namespace engine } // namespace mxnet - #endif // MXNET_ENGINE_THREAD_POOL_H_ diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index dd7662095097..8e59c59ab30e 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -83,7 +83,8 @@ void ThreadedVar::CompleteReadDependency(Dispatcher dispatcher) { template bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) { - VersionedVarBlock *old_pending_write, *end_of_dispatch_chain; + VersionedVarBlock *old_pending_write, *end_of_read_chain; + bool trigger_write = false; { // this is lock scope std::lock_guard lock{m_}; @@ -91,32 +92,31 @@ bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) { // detach pending write old_pending_write = pending_write_; // search for chains to trigger - VersionedVarBlock *p = old_pending_write->next; + end_of_read_chain = old_pending_write->next; assert(num_pending_reads_ == 0); - while (p->next != nullptr && p->write == false) { + while (end_of_read_chain->next != nullptr && + end_of_read_chain->write == false) { ++num_pending_reads_; - p = p->next; + end_of_read_chain = end_of_read_chain->next; } - // mark end of dispatch chain - end_of_dispatch_chain = p; - - if (p->next == nullptr) { + // check the states + if (end_of_read_chain->next == nullptr) { ready_to_read_ = true; pending_write_ = nullptr; - assert(p->trigger == nullptr); - assert(p->write ==false); } else { - assert(p->write == true); - pending_write_ = p; + assert(end_of_read_chain->write == true); + pending_write_ = end_of_read_chain; if (num_pending_reads_ == 0) { - if (--pending_write_->trigger->wait == 0) { - dispatcher(pending_write_->trigger); - } + trigger_write = true; } } } - // this is outside of lock scope - // the linked list is detached from variable + // This is outside of lock scope + // Be very carful, pending_write_ and num_pending_reads_ + // can change now, do not reply ont the two variables. + // The linked list \in [old_pending_write, end_of_read_chain) + // is already detached from this Var. + // So it is safe to modify these VersionedVarBlock *cur_head = old_pending_write->next; VersionedVarBlock::Delete(old_pending_write); if (to_delete_) { @@ -125,7 +125,7 @@ bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) { return true; } // dispatch all the events - while (cur_head != end_of_dispatch_chain) { + while (cur_head != end_of_read_chain) { if (--cur_head->trigger->wait == 0) { dispatcher(cur_head->trigger); } @@ -134,6 +134,13 @@ bool ThreadedVar::CompleteWriteDependency(Dispatcher dispatcher) { assert(cur_head != nullptr); VersionedVarBlock::Delete(prev); } + // Be careful, do not use pending_write_ or num_pending_reads_ here. + // As they can change, use end_of_read_chain + if (trigger_write) { + if (--end_of_read_chain->trigger->wait == 0) { + dispatcher(end_of_read_chain->trigger); + } + } return false; } diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc new file mode 100644 index 000000000000..0a3da50e69be --- /dev/null +++ b/src/engine/threaded_engine_perdevice.cc @@ -0,0 +1,176 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file threaded_engine_perdevice.cc + * \brief ThreadedEngine that uses fix amount of thread for each device. 
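The CompleteWriteDependency rewrite above follows a classic pattern: decide what to dispatch while holding the mutex, detach it from the shared structure, then run the dispatch after the lock is released without re-reading members that other threads may already be changing. A self-contained sketch of that pattern, independent of the engine types:

#include <functional>
#include <mutex>
#include <vector>

class Notifier {
 public:
  void AddWaiter(std::function<void()> cb) {
    std::lock_guard<std::mutex> lock(m_);
    waiters_.push_back(std::move(cb));
  }
  void CompleteAndNotify() {
    std::vector<std::function<void()>> to_run;
    {
      std::lock_guard<std::mutex> lock(m_);
      to_run.swap(waiters_);  // detach the list while locked
    }
    // outside the lock: waiters_ may change again concurrently,
    // so only the detached local list is touched from here on
    for (auto &cb : to_run) cb();
  }

 private:
  std::mutex m_;
  std::vector<std::function<void()>> waiters_;
};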
+ */ +#include +#include +#include +#include +#include +#include "./threaded_engine.h" +#include "./thread_pool.h" +#include "./stream_manager.h" + +namespace mxnet { +namespace engine { +/*! + * \brief ThreadedEngine uses per device threads. + * The policy of this Engine: + * - Execute Async operation immediately if pushed from Pusher. + * - Use fixed amount of threads for each device. + * - Use special threads for copy operations. + * - Each stream is allocated and binded to each of the thread. + */ +class ThreadedEnginePerDevice : public ThreadedEngine { + public: + ThreadedEnginePerDevice() noexcept(false) { + cpu_worker_nthreads_ = dmlc::GetEnv("MXNET_CPU_WORKER_NTHREADS", 2); + gpu_worker_nthreads_ = dmlc::GetEnv("MXNET_GPU_WORKER_NTHREADS", 2); + gpu_copy_nthreads_ = dmlc::GetEnv("MXNET_GPU_COPY_NTHREADS", 1); + + // create CPU task + auto *cpu_queue = &(cpu_worker_.task_queue); + cpu_worker_.pool.reset(new ThreadPool( + cpu_worker_nthreads_, [this, cpu_queue] { + this->CPUWorker(cpu_queue); + })); + // GPU tasks will be created lazily + } + ~ThreadedEnginePerDevice() noexcept(false) { + } + + protected: + void PushToExecute(OprBlock *opr_block, bool pusher_thread) override { + const Context& ctx = opr_block->ctx; + if (opr_block->opr->prop == FnProperty::kAsync && pusher_thread) { + if (ctx.dev_mask == gpu::kDevMask) { + #if MXNET_USE_CUDA + mshadow::SetDevice(ctx.dev_id); + #endif + } + RunContext run_ctx; + run_ctx.stream = nullptr; + this->ExecuteOprBlock(run_ctx, opr_block); + } else { + if (ctx.dev_mask == cpu::kDevMask) { + cpu_worker_.task_queue.Push(opr_block); + } else { + CHECK_EQ(ctx.dev_mask, gpu::kDevMask); + ThreadWorkerBlock* block = this->GetGPUWorkerBlock( + ctx.dev_id, opr_block->opr->prop); + block->task_queue.Push(opr_block); + } + } + } + + private: + // working unit for each of the task. + struct ThreadWorkerBlock { + // task queue on this task + dmlc::ConcurrentBlockingQueue task_queue; + // thread pool that works on this task + std::unique_ptr pool; + // destructor + ~ThreadWorkerBlock() noexcept(false) { + task_queue.SignalForKill(); + } + }; + /*! \brief number of concurrent thread cpu worker uses */ + int cpu_worker_nthreads_; + /*! \brief number of concurrent thread each gpu worker uses */ + int gpu_worker_nthreads_; + /*! \brief number of concurrent thread each gpu copy worker uses */ + int gpu_copy_nthreads_; + // mutex used when creating a ThreadWorkerBlock + std::mutex create_mutex_; + // cpu worker + ThreadWorkerBlock cpu_worker_; + // workers doing normal works on GPU + std::array, kMaxNumGPUs> gpu_normal_workers_; + // workers doing copy works from/to GPU + std::array, kMaxNumGPUs> gpu_copy_workers_; + /*! + * \brief get GPU Task Worker + * \param dev_id the device id + * \param prop The property of the function. + */ + inline ThreadWorkerBlock *GetGPUWorkerBlock(size_t dev_id, + FnProperty prop) { + bool is_copy = (prop == FnProperty::kCopyFromGPU || + prop == FnProperty::kCopyToGPU); + CHECK_LT(dev_id, kMaxNumGPUs) + << "GPU Device index " << dev_id + << " exceed bound " << kMaxNumGPUs; + std::array, kMaxNumGPUs> *workers; + if (is_copy) { + workers = &gpu_copy_workers_; + } else { + workers = &gpu_normal_workers_; + } + ThreadWorkerBlock *block = workers->at(dev_id).get(); + if (block != nullptr) return block; + { + // only lock when block is not available. 
+ std::lock_guard lock(create_mutex_); + // need to double check, because state can change + ThreadWorkerBlock *block = workers->at(dev_id).get(); + if (block != nullptr) return block; + int nthread = is_copy ? gpu_copy_nthreads_ : gpu_worker_nthreads_; + workers->at(dev_id).reset(new ThreadWorkerBlock()); + block = workers->at(dev_id).get(); + block->pool.reset(new ThreadPool(nthread, [this, dev_id, is_copy, block] () { + this->GPUWorker(dev_id, is_copy, &(block->task_queue)); + })); + return block; + } + } + /*! + * \brief GPU worker that performs operations on a certain device. + * \param dev_id The device id of the worker. + * \param is_copy_worker whether the worker only do copy job + * \param task_queue the device id of the worker. + */ + inline void GPUWorker(int dev_id, + bool is_copy_worker, + dmlc::ConcurrentBlockingQueue* task_queue) { + #if MXNET_USE_CUDA + // allocate stream + mshadow::SetDevice(dev_id); + RunContext run_ctx; + mshadow::Stream *stream; + if (is_copy_worker) { + stream = mshadow::NewStream(false, false); + } else { + stream = mshadow::NewStream(true, MXNET_USE_CUDNN != 0); + } + run_ctx.stream = stream; + // execute task + OprBlock* opr_block; + while (task_queue->Pop(&opr_block)) { + this->ExecuteOprBlock(run_ctx, opr_block); + } + mshadow::DeleteStream(stream); + #endif + } + /*! + * \brief CPU worker that performs operations on CPU. + * \param task_queue the device id of the worker. + */ + inline void CPUWorker(dmlc::ConcurrentBlockingQueue* task_queue) { + RunContext run_ctx; + run_ctx.stream = nullptr; + // execute task + OprBlock* opr_block; + while (task_queue->Pop(&opr_block)) { + this->ExecuteOprBlock(run_ctx, opr_block); + } + } +}; + +Engine *CreateThreadedEnginePerDevice() { + return new ThreadedEnginePerDevice(); +} +} // namespace engine +} // namespace mxnet + diff --git a/src/engine/threaded_engine_pooled.cc b/src/engine/threaded_engine_pooled.cc index 2dd2f27487eb..0978b32ea8d6 100644 --- a/src/engine/threaded_engine_pooled.cc +++ b/src/engine/threaded_engine_pooled.cc @@ -22,9 +22,9 @@ namespace engine { */ class ThreadedEnginePooled : public ThreadedEngine { public: - ThreadedEnginePooled() - : thread_pool_{[this]() { ThreadWorker(&task_queue_); }}, - io_thread_pool_{[this]() { ThreadWorker(&io_task_queue_); }} {} + ThreadedEnginePooled() : + thread_pool_(kNumWorkingThreads, [this]() { ThreadWorker(&task_queue_); }), + io_thread_pool_(1, [this]() { ThreadWorker(&io_task_queue_); }) {} ~ThreadedEnginePooled() noexcept(false) { task_queue_.SignalForKill(); @@ -59,8 +59,8 @@ class ThreadedEnginePooled : public ThreadedEngine { /*! * \brief Thread pools. */ - ThreadPool thread_pool_; - ThreadPool<1> io_thread_pool_; + ThreadPool thread_pool_; + ThreadPool io_thread_pool_; /*! * \brief Worker. * \param task_queue Queue to work on. @@ -86,7 +86,9 @@ class ThreadedEnginePooled : public ThreadedEngine { LOG(FATAL) << "Please compile with CUDA enabled"; #endif // MXNET_USE_CUDA } - auto&& rctx = opr_block->opr->prop == FnProperty::kCopy + bool is_copy = (opr_block->opr->prop == FnProperty::kCopyFromGPU || + opr_block->opr->prop == FnProperty::kCopyToGPU); + auto&& rctx = is_copy ? 
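GetGPUWorkerBlock above creates one worker block per GPU lazily, locking create_mutex_ only when the block is missing and re-checking after the lock is taken. The same per-slot lazy initialization can be written compactly with std::call_once; a self-contained alternative sketch of the idea (not the code in the patch):

#include <array>
#include <memory>
#include <mutex>

struct WorkerBlock { /* task queue + thread pool in the real engine */ };

class PerDeviceRegistry {
 public:
  WorkerBlock* Get(size_t dev_id) {
    std::call_once(flags_[dev_id], [this, dev_id]() {
      blocks_[dev_id].reset(new WorkerBlock());
    });
    return blocks_[dev_id].get();
  }

 private:
  static const size_t kMaxNumGPUs = 16;
  std::array<std::once_flag, kMaxNumGPUs> flags_;
  std::array<std::unique_ptr<WorkerBlock>, kMaxNumGPUs> blocks_;
};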
streams_.GetIORunContext(opr_block->ctx) : streams_.GetRunContext(opr_block->ctx); this->ExecuteOprBlock(rctx, opr_block); @@ -97,7 +99,8 @@ class ThreadedEnginePooled : public ThreadedEngine { */ void DoPushToQueue(OprBlock* opr_block) { switch (opr_block->opr->prop) { - case FnProperty::kCopy: { + case FnProperty::kCopyFromGPU: + case FnProperty::kCopyToGPU: { io_task_queue_.Push(opr_block); break; } @@ -109,7 +112,7 @@ class ThreadedEnginePooled : public ThreadedEngine { } }; -Engine *CreateThreadedEngine() { +Engine *CreateThreadedEnginePooled() { return new ThreadedEnginePooled(); } } // namespace engine diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index e9be7e445da6..feb3de61be2d 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -169,7 +169,7 @@ void CopyFromTo(const NDArray &from, NDArray *to) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); + from.ctx(), ret.ctx(), ctx); }, from.ctx(), const_vars, {ret.ptr_->var}); } else { #if MXNET_USE_CUDA @@ -178,28 +178,28 @@ void CopyFromTo(const NDArray &from, NDArray *to) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); + from.ctx(), ret.ctx(), ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); - }, ret.ctx(), const_vars, {ret.ptr_->var}); + }, ret.ctx(), const_vars, {ret.ptr_->var}, FnProperty::kCopyToGPU); } else if (a == gpu::kDevMask && b == cpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); + from.ctx(), ret.ctx(), ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); - }, from.ctx(), const_vars, {ret.ptr_->var}); + }, from.ctx(), const_vars, {ret.ptr_->var}, FnProperty::kCopyFromGPU); } else if (a == gpu::kDevMask && b == gpu::kDevMask) { Engine::Get()->PushSync([from, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Copy(from.data(), &tmp, - from.ctx(), ret.ctx(), ctx); + from.ctx(), ret.ctx(), ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); - }, from.ctx(), const_vars, {ret.ptr_->var}); + }, from.ctx(), const_vars, {ret.ptr_->var}, FnProperty::kCopyFromGPU); } else { LOG(FATAL) << "unknown device mask"; } diff --git a/src/ndarray/ndarray_function.h b/src/ndarray/ndarray_function.h index 94b03ab05a6f..0a0dc89ccdde 100644 --- a/src/ndarray/ndarray_function.h +++ b/src/ndarray/ndarray_function.h @@ -8,7 +8,6 @@ #include #include #include -#include namespace mxnet { /*! 
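CopyFromTo above now tags each PushSync with a direction-specific FnProperty so the per-device engine can route copies to the dedicated copy workers. A small hypothetical helper expressing the same case analysis (not in the patch; it uses only the dev_mask and FnProperty names from this diff):

inline mxnet::FnProperty CopyProperty(const mxnet::Context &from,
                                      const mxnet::Context &to) {
  if (from.dev_mask == mxnet::gpu::kDevMask) return mxnet::FnProperty::kCopyFromGPU;
  if (to.dev_mask == mxnet::gpu::kDevMask) return mxnet::FnProperty::kCopyToGPU;
  return mxnet::FnProperty::kNormal;  // CPU-to-CPU copies run as normal operations
}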
\brief namespace to support all possible Ndarray operator */ diff --git a/src/operator/batch_norm-inl.h b/src/operator/batch_norm-inl.h index 613913eb8284..3c70810fd52e 100644 --- a/src/operator/batch_norm-inl.h +++ b/src/operator/batch_norm-inl.h @@ -22,6 +22,7 @@ namespace op { enum BatchNormOpInputs {kData, kGamma, kBeta}; enum BatchNormOpOutputs {kOut, kOutNoAffine, kMean, kVar}; enum BatchNormOpAuxiliary {kMovingMean, kMovingVar}; +enum BatchNormBackResource {kTempSpace}; struct BatchNormParam : public dmlc::Parameter { float eps; @@ -37,7 +38,7 @@ struct BatchNormParam : public dmlc::Parameter { template class BatchNormOp : public Operator { public: - explicit BatchNormOp(BatchNormParam param) : is_init(false) { + explicit BatchNormOp(BatchNormParam param) { this->param_ = param; } @@ -137,16 +138,19 @@ class BatchNormOp : public Operator { out = out_data[kOut].get(s); out_no_affine = out_data[kOutNoAffine].get(s); } - this->Init(ctx, out.shape_); + Tensor mean = out_data[kMean].get(s); Tensor var = out_data[kVar].get(s); Tensor slope = in_data[kGamma].get(s); // Tensor bias = in_data[kBeta].get(s); Tensor gslope = in_grad[kGamma].get(s); Tensor gbias = in_grad[kBeta].get(s); - Tensor gmean = tmp_[0]; - Tensor gvar = tmp_[1]; - Tensor tmp = tmp_[2]; + // get requested temp space + Tensor workspace = ctx.requested[kTempSpace].get_space( + mshadow::Shape2(3, out.shape_[1]), s); + Tensor gmean = workspace[0]; + Tensor gvar = workspace[1]; + Tensor tmp = workspace[2]; // cal gvar = sumall_except_dim<1>((grad * broadcast<1>(slope, data.shape_)) * (data - broadcast<1>(mean, data.shape_)) * @@ -167,18 +171,7 @@ class BatchNormOp : public Operator { } private: - // TODO(bing): use global memory allocator - inline void Init(const OpContext &ctx, - const mshadow::Shape<4> &dshape) { - if (is_init) return; - is_init = true; - mshadow::Stream *s = ctx.get_stream(); - tmp_.set_stream(s); - tmp_.Resize(mshadow::Shape2(3, dshape[1])); - } - mshadow::TensorContainer tmp_; BatchNormParam param_; - bool is_init; }; // class BatchNormOp template @@ -239,6 +232,13 @@ class BatchNormProp : public OperatorProperty { return {{out_grad[kOut], in_grad[kData]}}; } + std::vector BackwardResource( + const std::vector &in_shape) const override { + const TShape &dshape = in_shape[0]; + size_t nspace = dshape[1] * 3; + return {{ResourceRequest::kTempSpace, nspace}}; + } + int NumVisibleOutputs() const override { return 1; } @@ -261,10 +261,6 @@ class BatchNormProp : public OperatorProperty { Operator* CreateOperator(Context ctx) const; - std::vector BackwardResource() const override { - return {Resource::kTempSpace}; - } - private: BatchNormParam param_; }; // class BatchNormProp diff --git a/src/operator/convolution-inl.h b/src/operator/convolution-inl.h index f8a29b204d60..b69c412c80fa 100644 --- a/src/operator/convolution-inl.h +++ b/src/operator/convolution-inl.h @@ -340,14 +340,6 @@ class ConvolutionProp : public OperatorProperty { Operator* CreateOperator(Context ctx) const; - std::vector ForwardResource() const override { - return {Resource::kTempSpace}; - } - - std::vector BackwardResource() const override { - return {Resource::kTempSpace}; - } - private: ConvolutionParam param_; }; // class ConvolutionProp diff --git a/src/operator/cudnn_convolution-inl.h b/src/operator/cudnn_convolution-inl.h index 8b81818304e1..38397a931096 100644 --- a/src/operator/cudnn_convolution-inl.h +++ b/src/operator/cudnn_convolution-inl.h @@ -149,9 +149,9 @@ class CuDNNConvolutionOp : public Operator { size_t expected = 
param_.no_bias ? 2 : 3; CHECK_EQ(in_data.size(), expected); CHECK_EQ(out_data.size(), 1); + temp_.set_stream(s); if (!init_cudnn_) { init_cudnn_ = true; - temp_.set_stream(s); size_t workspace = static_cast(param_.workspace); size_t back_size = 0; size_t back_size_w = 0; diff --git a/src/symbol/graph_executor.cc b/src/symbol/graph_executor.cc index 10318e39beb1..c1587a3657b4 100644 --- a/src/symbol/graph_executor.cc +++ b/src/symbol/graph_executor.cc @@ -4,6 +4,7 @@ * \brief Executor to execute the Graph. */ #include +#include #include #include #include "./graph_executor.h" @@ -77,11 +78,18 @@ class GraphExecutor::BackwardOpWrapper : public Operator { inline std::vector GraphExecutor::GetResource(uint32_t node_id) const { const StaticGraph::Node &node = graph_.nodes[node_id]; + // use input shape + std::vector in_shapes; + for (StaticGraph::DataEntry e : node.inputs) { + in_shapes.push_back(op_nodes_[e.source_id].outputs[e.index].shape); + } + if (node.is_forward()) { - return node.op->ForwardResource(); + return node.op->ForwardResource(in_shapes); } else { CHECK(node.is_backward()); - return graph_.nodes[node.backward_source_id].op->BackwardResource(); + return graph_.nodes[node.backward_source_id] + .op->BackwardResource(in_shapes); } } @@ -199,6 +207,10 @@ GraphExecutor::GetOpExecEntry(uint32_t nid) { } // start setup exec function. + for (const Resource& r : op_node.op_ctx.requested) { + exec.mutate_vars.push_back(static_cast(r.var)); + } + Operator* op = op_node.op.get(); OpContext* op_ctx_ptr = &op_node.op_ctx; bool is_gpu = op_node.ctx.dev_mask == gpu::kDevMask; @@ -374,7 +386,8 @@ void GraphExecutor::InitDataEntryMemory() { for (std::pair kv : inplace) { DataEntryInfo* in = kv.first; DataEntryInfo* out = kv.second; - if (in->temp_ref_count == 1 && + if (enable_inplace_allocation_ && + in->temp_ref_count == 1 && in->type == kInternalAllocated && out->type == kNotInitialized) { // we can only do inplace if we are last user of in @@ -398,6 +411,21 @@ void GraphExecutor::InitDataEntryMemory() { out->type = kInternalAllocated; } } + // resource + const std::vector& reqs = GetResource(nid); + op_nodes_[nid].resources.resize(reqs.size()); + for (uint32_t i = 0; i < reqs.size(); ++i) { + op_nodes_[nid].resources[i].resource.req = reqs[i]; + } + // allocate resource + for (ResourceEntry& entry : op_nodes_[nid].resources) { + if (entry.resource.req.type == ResourceRequest::kTempSpace) { + entry.storage_id = + allocator.Request(op_nodes_[nid].ctx, + mshadow::Shape1(entry.resource.req.space_num_reals), + nid); + } + } // then free inputs for (DataEntryInfo *in : in_data) { // temp_ref_count == 0 means it is taken by inplace op @@ -417,9 +445,15 @@ void GraphExecutor::InitDataEntryMemory() { allocator.Release(out->storage_id, nid); } } + // release the resource, as soon as the forward is finished we can release it. + for (ResourceEntry& res : op_nodes_[nid].resources) { + if (res.resource.req.type == ResourceRequest::kTempSpace) { + allocator.Release(res.storage_id, nid); + } + } } // one pass complete, allocate real memory - allocator.InitStorages(); + this->total_allocated_reals_ = allocator.InitStorages(); // get the real data NDArray into the DataEntryInfo for (size_t i = 0; i < topo_order_.size(); ++i) { uint32_t nid = topo_order_[i]; @@ -430,6 +464,21 @@ void GraphExecutor::InitDataEntryMemory() { out.data = allocator.Get(out.storage_id, out.shape); } } + // Get the resource of temporal space. 
+ for (ResourceEntry& entry : op_nodes_[nid].resources) { + if (entry.resource.req.type == ResourceRequest::kTempSpace) { + entry.data = allocator.Get(entry.storage_id, + mshadow::Shape1(entry.resource.req.space_num_reals)); + entry.resource.ptr_ = entry.data.data().dptr_; + entry.resource.var = entry.data.var(); + } else { + LOG(FATAL) << "resource type not yet supported"; + } + op_nodes_[nid].op_ctx.requested.resize(op_nodes_[nid].resources.size()); + for (size_t i = 0; i < op_nodes_[nid].resources.size(); ++i) { + op_nodes_[nid].op_ctx.requested[i] = op_nodes_[nid].resources[i].resource; + } + } } for (StaticGraph::DataEntry e : graph_.heads) { DataEntryInfo &info = op_nodes_[e.source_id].outputs[e.index]; @@ -510,7 +559,19 @@ std::string GraphExecutor::DebugStr() const { } os << '\n'; } + for (size_t j = 0; j < op_nodes_[nid].resources.size(); ++j) { + const ResourceEntry &entry = op_nodes_[nid].resources[j]; + os << "\tresource[" << j << "]: "; + if (entry.resource.req.type == ResourceRequest::kTempSpace) { + os << "type=TempSpace, size=" << entry.resource.req.space_num_reals + << ", storage_id=" << entry.storage_id; + } else if (entry.resource.req.type == ResourceRequest::kRandom) { + os << "type=RandomNumber"; + } + os << '\n'; + } } + os << "Total " << (total_allocated_reals_ >> 18UL) <<" MB allocated\n"; return os.str(); } diff --git a/src/symbol/graph_executor.h b/src/symbol/graph_executor.h index a97b34510b2e..a2aafa798669 100644 --- a/src/symbol/graph_executor.h +++ b/src/symbol/graph_executor.h @@ -32,6 +32,8 @@ class GraphExecutor : public Executor { const std::vector &arg_grad_store, const std::vector &grad_req_type, const std::vector &aux_states) { + enable_inplace_allocation_ = dmlc::GetEnv("MXNET_EXEC_ENABLE_INPLACE", true); + CHECK_EQ(grad_req_type.size(), arg_grad_store.size()); bool need_backward = false; for (auto req : grad_req_type) { @@ -90,6 +92,15 @@ class GraphExecutor : public Executor { storage_id(GraphStorageAllocator::kBadStorageID), temp_ref_count(0), ref_count(0) {} }; + // information of the resource + struct ResourceEntry { + /*! \brief the actual resource */ + Resource resource; + /*! \brief actual data for the entry if it is a temp space */ + NDArray data; + // storage id from allocator if it is a temp space + GraphStorageAllocator::StorageID storage_id; + }; // all the information needed to push the op to engine struct OpExecEntry { // execution function for @@ -111,6 +122,8 @@ class GraphExecutor : public Executor { std::vector outputs; // auxiliary data information of op std::vector aux_states; + // resource entry + std::vector resources; // The following parts are constructed in InitOpNodes // the real operator std::shared_ptr op; @@ -178,6 +191,10 @@ class GraphExecutor : public Executor { // topological order of nodes in computation graph // backward nodes always follow forward nodes std::vector topo_order_; + // whether to enable inplace space + bool enable_inplace_allocation_; + // total allocated space in #reals + size_t total_allocated_reals_; // number of forward nodes in the graph size_t num_forward_nodes_; // head gradient node in the graph, if there is backward pass diff --git a/src/symbol/graph_memory_allocator.h b/src/symbol/graph_memory_allocator.h index cd6dc0648cb4..5812e5c94b86 100644 --- a/src/symbol/graph_memory_allocator.h +++ b/src/symbol/graph_memory_allocator.h @@ -10,6 +10,7 @@ #include #include #include +#include namespace mxnet { /*! 
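A note on the ">> 18UL" in the DebugStr total printed above, since the shift is easy to misread (assuming real_t is the default mshadow::default_real_t, a 4-byte float):

// total_allocated_reals_ counts real_t elements, so with 4-byte reals:
//   bytes = total_allocated_reals_ * 4
//   MB    = bytes / 2^20 = total_allocated_reals_ / 2^18
// hence (total_allocated_reals_ >> 18UL) prints whole megabytes.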
@@ -47,8 +48,11 @@ class GraphStorageAllocator { * \param node_id the node id in the graph that is releasing the memory. */ void Release(StorageID id, uint32_t node_id); - /*! \brief Initialize all the memories requested */ - void InitStorages(); + /*! + * \brief Initialize all the memories requested + * \return size of memory allocated. + */ + size_t InitStorages(); /*! * \brief Get the the memory allocated in planning phase. * \param id the storage id allocated in planning phase. @@ -81,6 +85,8 @@ class GraphStorageAllocator { StaticGraph *graph_; /*! \brief all the resources available */ std::vector > data_; + /*! \brief scale used for rough match */ + size_t match_range_; /*! * \brief free list of storage entries, maps size to free list */ @@ -89,7 +95,9 @@ class GraphStorageAllocator { // put implementation in header files for now GraphStorageAllocator::GraphStorageAllocator(StaticGraph *graph) - : graph_(graph) {} + : graph_(graph) { + match_range_ = dmlc::GetEnv("MXNET_EXEC_MATCH_RANGE", 16); +} GraphStorageAllocator::StorageID GraphStorageAllocator::Alloc(Context ctx, size_t size) { @@ -104,16 +112,29 @@ GraphStorageAllocator::Alloc(Context ctx, size_t size) { GraphStorageAllocator::StorageID GraphStorageAllocator::Request(Context ctx, TShape shape, uint32_t node_id) { + // search memory block in [size / match_range_, size * match_range_) size_t size = shape.Size(); - auto begin = free_.lower_bound(size); - auto end = free_.upper_bound(size); - // vector of possible candidates - for (auto it = begin; it != end; ++it) { + auto begin = free_.lower_bound(size / match_range_); + auto mid = free_.lower_bound(size); + auto end = free_.upper_bound(size * match_range_); + // TODO(bing, min) consider better strategy + // search for memory blocks larger than requested + for (auto it = mid; it != end; ++it) { + StorageEntry *e = it->second; + if (e->ctx != ctx) continue; + // Use exect matching strategy + e->max_size = std::max(size, e->max_size); + // find a exact match, erase from map and return + free_.erase(it); + return e->id; + } + // then search for memory blocks smaller than requested space + for (auto it = mid; it != begin;) { + --it; StorageEntry *e = it->second; if (e->ctx != ctx) continue; // Use exect matching strategy - // TODO(bing): think of other strategies, for example, rough match. 
- if (e->max_size != size) continue; + e->max_size = std::max(size, e->max_size); // find a exact match, erase from map and return free_.erase(it); return e->id; @@ -128,12 +149,15 @@ void GraphStorageAllocator::Release(StorageID id, uint32_t node_id) { free_.insert({e->max_size, e}); } -void GraphStorageAllocator::InitStorages() { +size_t GraphStorageAllocator::InitStorages() { + size_t total = 0; for (size_t i = 0; i < data_.size(); ++i) { StorageEntry *e = data_[i].get(); TShape shape = mshadow::Shape1(e->max_size); e->data = NDArray(shape, e->ctx); + total += e->max_size; } + return total; } NDArray GraphStorageAllocator::Get(StorageID id, TShape shape) { diff --git a/src/symbol/static_graph.cc b/src/symbol/static_graph.cc index a93a45b05aba..53df58dd96f1 100644 --- a/src/symbol/static_graph.cc +++ b/src/symbol/static_graph.cc @@ -12,6 +12,11 @@ namespace mxnet { std::vector StaticGraph::TopoSort() const { + std::vector > stack; + std::unordered_set visited; + std::vector ret(nodes.size()); + std::vector head_node; + // out degree std::vector out_degree(nodes.size(), 0); for (const Node& n : nodes) { for (const DataEntry& e : n.inputs) { @@ -21,28 +26,33 @@ std::vector StaticGraph::TopoSort() const { ++out_degree[n.backward_source_id]; } } - std::vector ret(nodes.size()); - auto result = ret.rbegin(); - std::queue queue; for (size_t i = 0; i < nodes.size(); ++i) { if (out_degree[i] == 0) { - queue.push(static_cast(i)); + stack.push_back(std::make_pair(static_cast(i), 0)); } } - while (!queue.empty()) { - uint32_t node_id = queue.front(); - queue.pop(); - *result = node_id; - ++result; - const Node& n = nodes[node_id]; - for (const DataEntry& e : n.inputs) { - if (--out_degree[e.source_id] == 0) { - queue.push(e.source_id); + // heads + for (auto &head : head_node) { + stack.push_back(std::make_pair(head, 0)); + } + int count = 0; + while (!stack.empty()) { + std::pair& back = stack.back(); + const Node& n = nodes[back.first]; + if (back.second == n.inputs.size() + (n.is_backward() ? 
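The GraphStorageAllocator::Request change above replaces exact-size matching with a rough match inside [size / match_range_, size * match_range_), first scanning free blocks at least as large as the request and then smaller ones whose recorded size is grown. A self-contained sketch of that free-list search over a std::multimap keyed by block size (simplified: no storage ids, contexts reduced to an int):

#include <algorithm>
#include <cstddef>
#include <map>

struct FreeBlock { int ctx; std::size_t max_size; };

// Find a reusable block for `size` on device `ctx`; returns end() if none fits,
// in which case the caller allocates a fresh block instead.
std::multimap<std::size_t, FreeBlock*>::iterator RoughMatch(
    std::multimap<std::size_t, FreeBlock*> *free_list,
    int ctx, std::size_t size, std::size_t match_range) {
  auto begin = free_list->lower_bound(size / match_range);
  auto mid   = free_list->lower_bound(size);
  auto end   = free_list->upper_bound(size * match_range);
  // 1) smallest free block that is already at least `size`
  for (auto it = mid; it != end; ++it) {
    if (it->second->ctx != ctx) continue;
    it->second->max_size = std::max(size, it->second->max_size);
    return it;  // caller erases `it` from the free list and reuses the block
  }
  // 2) otherwise a somewhat smaller block, growing its recorded size
  for (auto it = mid; it != begin;) {
    --it;
    if (it->second->ctx != ctx) continue;
    it->second->max_size = std::max(size, it->second->max_size);
    return it;
  }
  return free_list->end();
}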
1 : 0)) { + ret[count++] = back.first; + visited.insert(back.first); + stack.pop_back(); + } else { + uint32_t input; + if (back.second == n.inputs.size() && n.is_backward()) { + input = n.backward_source_id; + back.second++; + } else { + input = n.inputs[back.second++].source_id; } - } - if (n.is_backward()) { - if (--out_degree[n.backward_source_id] == 0) { - queue.push(n.backward_source_id); + if (visited.count(input) == 0) { + stack.push_back(std::make_pair(input, 0)); } } } diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index 769cc5361f54..9a0b1e0e997d 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -54,25 +54,26 @@ inline bool Symbol::is_atomic() const { // implementation of template functions template inline void Symbol::DFSVisit(FVisit fvisit) const { - std::vector*> stack; + std::vector*, uint32_t> > stack; std::unordered_set visited; // put the head into the graph for (auto &head : heads_) { Node *ptr = head.source.get(); if (visited.count(ptr) == 0) { - stack.push_back(&head.source); - visited.insert(ptr); + stack.push_back(std::make_pair(&head.source, 0)); } } while (!stack.empty()) { - const std::shared_ptr *back = stack.back(); - stack.pop_back(); - fvisit(*back); - for (auto it = back->get()->inputs.rbegin(); it != back->get()->inputs.rend(); ++it) { - Node *ptr = it->source.get(); - if (visited.count(ptr) == 0) { - stack.push_back(&it->source); - visited.insert(ptr); + std::pair *, uint32_t>& back = stack.back(); + if (back.second == back.first->get()->inputs.size()) { + fvisit(*(back.first)); + visited.insert(back.first->get()); + stack.pop_back(); + } else { + std::vector& inputs = back.first->get()->inputs; + Symbol::DataEntry& input = inputs.at(back.second++); + if (visited.count(input.source.get()) == 0) { + stack.push_back(std::make_pair(&input.source, 0)); } } } diff --git a/tests/cpp/threaded_engine_unittest.cc b/tests/cpp/threaded_engine_unittest.cc index 35e0ca3124b0..ffe3ee4ad3da 100644 --- a/tests/cpp/threaded_engine_unittest.cc +++ b/tests/cpp/threaded_engine_unittest.cc @@ -72,7 +72,7 @@ TEST(Engine, basics) { Foo(ctx, 42); cb(); }, - {}, {var}, mxnet::FnProperty::kCopy)); + {}, {var}, mxnet::FnProperty::kCopyFromGPU)); engine->Push(oprs.at(0), mxnet::Context{}); LOG(INFO) << "IO operator pushed, should wait for 2 seconds."; engine->WaitForVar(var);
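Both StaticGraph::TopoSort and Symbol::DFSVisit above move from queue/recursion style traversals to an explicit stack of (node, next-input-index) pairs, i.e. an iterative post-order DFS that cannot overflow the call stack on deep graphs. A self-contained sketch of that traversal on a plain adjacency list (illustration only; node ids index into inputs):

#include <cstdint>
#include <unordered_set>
#include <utility>
#include <vector>

// inputs[n] lists the nodes that node n depends on.
// Returns node ids ordered so that every node appears after all of its inputs.
std::vector<uint32_t> TopoSort(const std::vector<std::vector<uint32_t>> &inputs,
                               const std::vector<uint32_t> &heads) {
  std::vector<uint32_t> order;
  std::unordered_set<uint32_t> visited;
  std::vector<std::pair<uint32_t, size_t>> stack;  // (node, next input index)
  for (uint32_t h : heads) stack.emplace_back(h, 0);
  while (!stack.empty()) {
    std::pair<uint32_t, size_t> &top = stack.back();
    if (top.second == inputs[top.first].size()) {
      // all inputs emitted: emit this node (post-order), then pop it
      if (visited.insert(top.first).second) order.push_back(top.first);
      stack.pop_back();
    } else {
      uint32_t next = inputs[top.first][top.second++];
      if (visited.count(next) == 0) stack.emplace_back(next, 0);
    }
  }
  return order;
}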