From 77c408786ae8aaeacbc35c644b421076a941f1c6 Mon Sep 17 00:00:00 2001
From: tqchen
Date: Sat, 19 Sep 2015 14:32:28 -0700
Subject: [PATCH] change def of context, change pinned to special device

---
 Makefile                                |   8 +-
 include/mxnet/base.h                    |  99 ++++++++++++++++-------
 include/mxnet/c_api.h                   |  20 ++---
 include/mxnet/ndarray.h                 |   9 ++-
 mshadow                                 |   2 +-
 python/mxnet/context.py                 |  10 +--
 python/mxnet/ndarray.py                 |   8 +-
 src/c_api.cc                            |  21 +++---
 src/engine/naive_engine.cc              |   2 +-
 src/engine/stream_manager.h             |   4 +-
 src/engine/threaded_engine.cc           |   4 +-
 src/engine/threaded_engine_perdevice.cc |   6 +-
 src/engine/threaded_engine_pooled.cc    |   2 +-
 src/kvstore/kvstore_local.h             |   7 +-
 src/ndarray/ndarray.cc                  |  40 +++++-----
 src/operator/operator_common.h          |   4 +-
 src/resource.cc                         |   8 +-
 src/storage/storage.cc                  |  40 +++++-----
 src/symbol/graph_executor.cc            |   2 +-
 19 files changed, 175 insertions(+), 121 deletions(-)

diff --git a/Makefile b/Makefile
index 0cef664871ee..a8712d526d46 100644
--- a/Makefile
+++ b/Makefile
@@ -27,12 +27,12 @@ WARNFLAGS= -Wall
 CFLAGS = -DMSHADOW_FORCE_STREAM $(WARNFLAGS)
 
 # CFLAGS for debug
-ifeq ($(DEBUG),0)
-	CFLAGS += -O3
-else
+ifeq ($(DEBUG), 1)
 	CFLAGS += -g -O0
+else
+	CFLAGS += -O3
 endif
-CFLAGS += -I./mshadow/ -I./dmlc-core/include -fPIC -Iinclude $(MSHADOW_CFLAGS) $(DMLC_CFLAGS)
+CFLAGS += -I./mshadow/ -I./dmlc-core/include -fPIC -Iinclude $(MSHADOW_CFLAGS)
 LDFLAGS = -pthread $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS)
 NVCCFLAGS = --use_fast_math -g -O3 -ccbin $(CXX) $(MSHADOW_NVCCFLAGS)
 ROOTDIR = $(CURDIR)
diff --git a/include/mxnet/base.h b/include/mxnet/base.h
index 2fdcf94ed4e7..b5670070ae1a 100644
--- a/include/mxnet/base.h
+++ b/include/mxnet/base.h
@@ -64,41 +64,41 @@ typedef mshadow::TShape TShape;
 /*! \brief storage container type */
 typedef mshadow::TBlob TBlob;
-
 /*! \brief Context information about the execution environment */
 struct Context {
-  /*! \brief the device type we run the op can be cpu::kDevMask or gpu::kDevMask */
-  int32_t dev_mask;
+  /*! \brief Type of device */
+  enum DeviceType {
+    kCPU = cpu::kDevMask,
+    kGPU = gpu::kDevMask,
+    kCPUPinned = 3
+  };
+  /*! \brief the device type we run the op on */
+  DeviceType dev_type;
   /*! \brief device id we are going to run it on */
   int32_t dev_id;
-  /*! \brief constructor */
-  Context() : dev_mask(cpu::kDevMask), dev_id(0) {}
+  /*! \brief default constructor */
+  Context() : dev_type(kCPU), dev_id(0) {}
   /*!
-   * \brief constructor of context
-   * \param dev_mask the device mask
-   * \param dev_id the device id
+   * \brief Get the corresponding device mask
+   * \return cpu::kDevMask or gpu::kDevMask
    */
-  Context(int dev_mask, int dev_id)
-      : dev_mask(dev_mask), dev_id(dev_id) {}
+  inline int dev_mask() const {
+    if (dev_type == kCPUPinned) return cpu::kDevMask;
+    return dev_type;
+  }
   /*!
   * \brief Comparator, used to enable Context as std::map key.
   * \param b another context to compare
   * \return compared result
   */
-  inline bool operator<(const Context &b) const {
-    if (dev_mask == b.dev_mask) {
-      return dev_id < b.dev_id;
-    } else {
-      return dev_mask < b.dev_mask;
-    }
-  }
+  inline bool operator<(const Context &b) const;
   /*!
   * \brief check if current context equals another one
   * \param b another context to compare
   * \return whether dev mask and id are same
   */
   inline bool operator==(const Context &b) const {
-    return dev_mask == b.dev_mask && dev_id == b.dev_id;
+    return dev_type == b.dev_type && dev_id == b.dev_id;
   }
   /*!
   * \brief check if current context not equals another one
@@ -112,8 +112,8 @@ struct Context {
   * \brief save the content into binary stream
   * \param strm the output stream
   */
-  void Save(dmlc::Stream *strm) const {
-    strm->Write(&dev_mask, sizeof(dev_mask));
+  inline void Save(dmlc::Stream *strm) const {
+    strm->Write(&dev_type, sizeof(dev_type));
     strm->Write(&dev_id, sizeof(dev_id));
   }
   /*!
@@ -121,18 +121,35 @@ struct Context {
   * \param strm the input stream
   * \return whether the load is successful
   */
-  bool Load(dmlc::Stream *strm) {
-    if (strm->Read(&dev_mask, sizeof(int32_t)) != sizeof(int32_t)) return false;
+  inline bool Load(dmlc::Stream *strm) {
+    if (strm->Read(&dev_type, sizeof(dev_type)) != sizeof(dev_type)) return false;
     if (strm->Read(&dev_id, sizeof(int32_t)) != sizeof(int32_t)) return false;
     return true;
   }
-  /*! \brief the maximal device mask, cpu = 1, gpu = 2 */
-  static const int32_t kMaxDevMask = 2;
+  /*! \brief the maximal device type */
+  static const int32_t kMaxDevType = 4;
+  /*! \brief the maximal device index */
+  static const int32_t kMaxDevID = 16;
+  /*!
+   * \brief Create a new context.
+   * \param dev_type device type.
+   * \param dev_id device id.
+   */
+  inline static Context Create(DeviceType dev_type, int32_t dev_id);
+  /*! \return CPU Context */
+  inline static Context CPU();
+  /*!
+   * \brief Create a GPU context.
+   * \param dev_id the device id.
+   * \return GPU Context.
+   */
+  inline static Context GPU(int32_t dev_id);
   /*!
-   * \brief A dedicate ID for pinned cpu memory.
-   *  Any normal CPU ID should be less than this number.
+   * \brief Create a pinned CPU context.
+   * \param dev_id the device id of the corresponding GPU.
+   * \return Pinned CPU context.
    */
-  static const int32_t kPinnedMemoryID = 16;
+  inline static Context CPUPinned(int32_t dev_id);
 };
 
 /*!
@@ -157,6 +174,34 @@ struct RunContext {
 }  // namespace mxnet
 
 //! \cond Doxygen_Suppress
+namespace mxnet {
+// implementing Context
+inline bool Context::operator<(const Context &b) const {
+  if (dev_type == b.dev_type) {
+    return dev_id < b.dev_id;
+  } else {
+    return dev_type < b.dev_type;
+  }
+}
+inline Context Context::Create(DeviceType dev_type, int32_t dev_id) {
+  Context ctx;
+  ctx.dev_type = dev_type;
+  ctx.dev_id = dev_id;
+  return ctx;
+}
+inline Context Context::CPU() {
+  return Create(kCPU, 0);
+}
+
+inline Context Context::CPUPinned(int32_t dev_id) {
+  return Create(kCPUPinned, dev_id);
+}
+
+inline Context Context::GPU(int32_t dev_id) {
+  return Create(kGPU, dev_id);
+}
+}  // namespace mxnet
+
 namespace dmlc {
 // Add a few patches to support TShape in dmlc/parameter.
 DMLC_DECLARE_TYPE_NAME(mxnet::TShape, "Shape(tuple)");
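Note: the following is a minimal standalone sketch (not mxnet code) of the Context contract introduced above, with the mshadow masks hard-coded as cpu = 1 and gpu = 2. It shows the two properties the rest of this patch relies on: dev_mask() folds kCPUPinned back to the CPU mask for execution, while operator< over (dev_type, dev_id) lets pinned buffers keep their own entries in per-context tables.

    #include <cassert>
    #include <map>
    #include <string>

    struct Context {
      // kCPU/kGPU mirror the mshadow masks; kCPUPinned is a distinct type.
      enum DeviceType { kCPU = 1, kGPU = 2, kCPUPinned = 3 };
      DeviceType dev_type;
      int dev_id;
      // Pinned memory is host memory, so code that only asks "where does
      // computation run?" sees it as CPU.
      int dev_mask() const { return dev_type == kCPUPinned ? kCPU : dev_type; }
      bool operator<(const Context &b) const {
        return dev_type == b.dev_type ? dev_id < b.dev_id : dev_type < b.dev_type;
      }
      static Context Create(DeviceType type, int id) {
        Context ctx;
        ctx.dev_type = type;
        ctx.dev_id = id;
        return ctx;
      }
      static Context CPU() { return Create(kCPU, 0); }
      static Context GPU(int id) { return Create(kGPU, id); }
      static Context CPUPinned(int id) { return Create(kCPUPinned, id); }
    };

    int main() {
      Context pinned = Context::CPUPinned(1);      // pinned host buffer for GPU 1
      assert(pinned.dev_mask() == Context::kCPU);  // executes as CPU memory
      std::map<Context, std::string> pools;        // per-context bookkeeping
      pools[Context::CPU()] = "cpu";
      pools[Context::GPU(1)] = "gpu1";
      pools[pinned] = "pinned1";                   // no reserved dev_id needed
      assert(pools.size() == 3);
      return 0;
    }

Compared with the old kPinnedMemoryID trick, ordering on (dev_type, dev_id) gives pinned memory its own keys without stealing a CPU device id.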
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index c421df8fab5f..17d2cf9ed086 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -77,7 +77,7 @@ MXNET_DLL int MXNDArrayCreateNone(NDArrayHandle *out);
  * \brief create a NDArray with specified shape
  * \param shape the pointer to the shape
  * \param ndim the dimension of the shape
- * \param dev_mask device mask, specify device we want to take
+ * \param dev_type device type, specify device we want to take
  * \param dev_id the device id of the specific device
  * \param delay_alloc whether to delay allocation until
  *    the narray is first mutated
@@ -86,7 +86,7 @@ MXNET_DLL int MXNDArrayCreateNone(NDArrayHandle *out);
  */
 MXNET_DLL int MXNDArrayCreate(const mx_uint *shape,
                               mx_uint ndim,
-                              int dev_mask,
+                              int dev_type,
                               int dev_id,
                               int delay_alloc,
                               NDArrayHandle *out);
@@ -198,8 +198,8 @@ MXNET_DLL int MXNDArrayFree(NDArrayHandle handle);
  * \return 0 when success, -1 when failure happens
  */
 MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
-                               mx_uint *out_dim,
-                               const mx_uint **out_pdata);
+                                mx_uint *out_dim,
+                                const mx_uint **out_pdata);
 /*!
  * \brief get the content of the data in NDArray
  * \param handle the handle to the narray
@@ -207,17 +207,17 @@ MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
  * \return 0 when success, -1 when failure happens
  */
 MXNET_DLL int MXNDArrayGetData(NDArrayHandle handle,
-                              mx_float **out_pdata);
+                               mx_float **out_pdata);
 /*!
  * \brief get the context of the NDArray
  * \param handle the handle to the narray
- * \param out_dev_mask the output device mask
+ * \param out_dev_type the output device type
  * \param out_dev_id the output device id
  * \return 0 when success, -1 when failure happens
  */
 MXNET_DLL int MXNDArrayGetContext(NDArrayHandle handle,
-                                 int *out_dev_mask,
-                                 int *out_dev_id);
+                                  int *out_dev_type,
+                                  int *out_dev_id);
 
 //--------------------------------
 // Part 2: functions on NDArray
@@ -547,7 +547,7 @@ MXNET_DLL int MXExecutorOutputs(ExecutorHandle handle,
  * \brief Generate Executor from symbol
  *
  * \param symbol_handle symbol handle
- * \param dev_mask device mask
+ * \param dev_type device type
  * \param dev_id device id
  * \param len length
  * \param in_args in args array
@@ -559,7 +559,7 @@ MXNET_DLL int MXExecutorOutputs(ExecutorHandle handle,
  * \return 0 when success, -1 when failure happens
  */
 MXNET_DLL int MXExecutorBind(SymbolHandle symbol_handle,
-                             int dev_mask,
+                             int dev_type,
                              int dev_id,
                              mx_uint len,
                              NDArrayHandle *in_args,
diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 26f9306ddc97..d56db097867e 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -62,7 +62,7 @@ class NDArray {
    */
   inline TBlob data() const {
     return TBlob(static_cast<real_t*>(ptr_->shandle.dptr) + offset_, \
-        shape_, ptr_->shandle.ctx.dev_mask);
+        shape_, ptr_->shandle.ctx.dev_mask());
   }
   /*!
   * \return the context of NDArray, this function is only valid when the NDArray is not empty
@@ -288,7 +288,12 @@ class NDArray {
       : static_data(true),
         delay_alloc(false) {
       var = Engine::Get()->NewVariable();
-      shandle.ctx = Context(data.dev_mask_, dev_id);
+      if (data.dev_mask_ == cpu::kDevMask) {
+        shandle.ctx = Context::CPU();
+      } else {
+        CHECK_EQ(data.dev_mask_, gpu::kDevMask);
+        shandle.ctx = Context::GPU(dev_id);
+      }
       shandle.dptr = data.dptr_;
       shandle.size = data.shape_.Size() * sizeof(real_t);
     }
diff --git a/mshadow b/mshadow
index 2b6c218f6f6f..c7e3c5eb1d0e 160000
--- a/mshadow
+++ b/mshadow
@@ -1 +1 @@
-Subproject commit 2b6c218f6f6fd677186eee9eb0a9ff64a57ead70
+Subproject commit c7e3c5eb1d0e6e6a754bb5e87dec46c3dac9b6f9
diff --git a/python/mxnet/context.py b/python/mxnet/context.py
index 1e5e4652aced..fff45dc7b895 100644
--- a/python/mxnet/context.py
+++ b/python/mxnet/context.py
@@ -6,8 +6,8 @@ class Context(object):
     """Context representing device and device id in mxnet"""
     # static class variable
     default_ctx = None
-    devmask2type = {1: 'cpu', 2: 'gpu'}
-    devtype2mask = {'cpu': 1, 'gpu': 2}
+    devtype2str = {1: 'cpu', 2: 'gpu'}
+    devstr2type = {'cpu': 1, 'gpu': 2}
 
     def __init__(self, device_type, device_id=0):
         """Constructing a context.
@@ -21,10 +21,10 @@ def __init__(self, device_type, device_id=0):
             the device id of the device, needed for GPU
         """
         if isinstance(device_type, Context):
-            self.device_mask = device_type.device_mask
+            self.device_typeid = device_type.device_typeid
             self.device_id = device_type.device_id
         else:
-            self.device_mask = Context.devtype2mask[device_type]
+            self.device_typeid = Context.devstr2type[device_type]
             self.device_id = device_id
         self._old_ctx = None
 
@@ -36,7 +36,7 @@ def device_type(self):
         -------
         device_type : str
         """
-        return Context.devmask2type[self.device_mask]
+        return Context.devtype2str[self.device_typeid]
 
     def __str__(self):
         return 'Context(device_type=%s, device_id=%d)' % (
diff --git a/python/mxnet/ndarray.py b/python/mxnet/ndarray.py
index 1bb9f465ac59..2cb77be8bcb2 100644
--- a/python/mxnet/ndarray.py
+++ b/python/mxnet/ndarray.py
@@ -36,10 +36,10 @@ def _new_alloc_handle(shape, ctx, delay_alloc):
         a new empty ndarray handle
     """
     hdl = NDArrayHandle()
     check_call(_LIB.MXNDArrayCreate(
         c_array(mx_uint, shape),
         len(shape),
-        ctx.device_mask,
+        ctx.device_typeid,
         ctx.device_id,
         int(delay_alloc),
         ctypes.byref(hdl)))
@@ -262,11 +262,11 @@ def context(self):
         context : mxnet.Context
             The context of current NDArray.
         """
-        dev_mask = ctypes.c_int()
+        dev_typeid = ctypes.c_int()
         dev_id = ctypes.c_int()
         check_call(_LIB.MXNDArrayGetContext(
-            self.handle, ctypes.byref(dev_mask), ctypes.byref(dev_id)))
-        return Context(Context.devmask2type[dev_mask.value], dev_id.value)
+            self.handle, ctypes.byref(dev_typeid), ctypes.byref(dev_id)))
+        return Context(Context.devtype2str[dev_typeid.value], dev_id.value)
 
     def asnumpy(self):
         """Return a copied numpy array of current array.
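Since the frontend now ships ctx.device_typeid through the C API as a plain int, the Python table devstr2type = {'cpu': 1, 'gpu': 2} must stay in lockstep with the C++ DeviceType enum. Below is a hedged sketch of the receiving side's cast; DecodeDevType is an illustrative helper, not part of the C API:

    #include <cassert>
    #include <cstdio>

    enum DeviceType { kCPU = 1, kGPU = 2, kCPUPinned = 3 };

    // A raw int crosses the C boundary; casting it back to the enum is only
    // safe after checking that it names a device type that actually exists.
    bool DecodeDevType(int dev_type, DeviceType *out) {
      if (dev_type < kCPU || dev_type > kCPUPinned) return false;
      *out = static_cast<DeviceType>(dev_type);
      return true;
    }

    int main() {
      DeviceType t;
      assert(DecodeDevType(2, &t) && t == kGPU);  // 'gpu' from devstr2type
      assert(!DecodeDevType(7, &t));              // unknown id is rejected
      printf("decoded ok\n");
      return 0;
    }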
diff --git a/src/c_api.cc b/src/c_api.cc
index 5f35be62bc1b..2dc48ce3d1ff 100644
--- a/src/c_api.cc
+++ b/src/c_api.cc
@@ -195,14 +195,15 @@ int MXNDArrayCreateNone(NDArrayHandle *out) {
 int MXNDArrayCreate(const mx_uint *shape,
                     mx_uint ndim,
-                    int dev_mask,
+                    int dev_type,
                     int dev_id,
                     int delay_alloc,
                     NDArrayHandle *out) {
   API_BEGIN();
-  *out = new NDArray(TShape(shape, shape + ndim),
-                     Context(dev_mask, dev_id),
-                     delay_alloc != 0);
+  *out = new NDArray(
+      TShape(shape, shape + ndim),
+      Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
+      delay_alloc != 0);
   API_END();
 }
@@ -336,7 +337,7 @@ int MXNDArrayGetData(NDArrayHandle handle,
   API_BEGIN();
   NDArray *arr = static_cast<NDArray*>(handle);
   if (!arr->is_none()) {
-    CHECK(arr->ctx().dev_mask == cpu::kDevMask)
+    CHECK(arr->ctx().dev_mask() == cpu::kDevMask)
         << "MXNDArrayGetData can only be called for NDArray on CPU";
     const TBlob &b = arr->data();
     CHECK(b.CheckContiguous());
@@ -348,16 +349,16 @@ int MXNDArrayGetData(NDArrayHandle handle,
 }
 
 int MXNDArrayGetContext(NDArrayHandle handle,
-                        int *out_dev_mask,
+                        int *out_dev_type,
                         int *out_dev_id) {
   API_BEGIN();
   NDArray *arr = static_cast<NDArray*>(handle);
   if (!arr->is_none()) {
     const Context &ctx = arr->ctx();
-    *out_dev_mask = ctx.dev_mask;
+    *out_dev_type = ctx.dev_type;
     *out_dev_id = ctx.dev_id;
   } else {
-    *out_dev_mask = 0;
+    *out_dev_type = 0;
     *out_dev_id = 0;
   }
   API_END();
@@ -764,7 +765,7 @@ int MXExecutorOutputs(ExecutorHandle handle,
 }
 
 int MXExecutorBind(SymbolHandle symbol_handle,
-                   int dev_mask,
+                   int dev_type,
                    int dev_id,
                    mx_uint len,
                    NDArrayHandle *in_args,
@@ -775,7 +776,7 @@ int MXExecutorBind(SymbolHandle symbol_handle,
                    ExecutorHandle *out) {
   API_BEGIN();
   Symbol *symb = static_cast<Symbol*>(symbol_handle);
-  Context ctx = Context(dev_mask, dev_id);
+  Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);
   NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args);
   NDArray **arg_grad_ptr = reinterpret_cast<NDArray**>(arg_grad_store);
   NDArray **aux_states_ptr = reinterpret_cast<NDArray**>(aux_states);
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc
index 9351e8a05163..e9d558013c1b 100644
--- a/src/engine/naive_engine.cc
+++ b/src/engine/naive_engine.cc
@@ -59,7 +59,7 @@ class NaiveEngine final : public Engine {
                      NaiveEngine::OnComplete, nullptr);
     this->req_completed_ = false;
-    if (exec_ctx.dev_mask == gpu::kDevMask) {
+    if (exec_ctx.dev_mask() == gpu::kDevMask) {
 #if MXNET_USE_CUDA
       size_t dev_id = static_cast<size_t>(exec_ctx.dev_id);
       mshadow::SetDevice<gpu>(exec_ctx.dev_id);
diff --git a/src/engine/stream_manager.h b/src/engine/stream_manager.h
index c392e7d6ce3c..dbaf9bb8dce6 100644
--- a/src/engine/stream_manager.h
+++ b/src/engine/stream_manager.h
@@ -47,7 +47,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(
     Context const& ctx) {
   RunContext ret;
   ret.stream = nullptr;
-  switch (ctx.dev_mask) {
+  switch (ctx.dev_mask()) {
     case cpu::kDevMask: break;
     case gpu::kDevMask: {
 #if MXNET_USE_CUDA
@@ -80,7 +80,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetIORunContext(
     Context const& ctx) {
   RunContext ret;
   ret.stream = nullptr;
-  switch (ctx.dev_mask) {
+  switch (ctx.dev_mask()) {
     case cpu::kDevMask: break;
     case gpu::kDevMask: {
 #if MXNET_USE_CUDA
diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc
index 1a3144e783ec..1b2ccdec796b 100644
--- a/src/engine/threaded_engine.cc
+++ b/src/engine/threaded_engine.cc
@@ -226,7 +226,7 @@ void ThreadedEngine::DeleteOperator(OprHandle op) {
                 threaded_opr->mutable_vars.end());
   this->PushSync([threaded_opr](RunContext) {
       ThreadedOpr::Delete(threaded_opr);
-    }, Context(), {}, deps, FnProperty::kAsync);
+    }, Context::CPU(), {}, deps, FnProperty::kAsync);
 }
 
 void ThreadedEngine::Push(OprHandle op, Context exec_ctx) {
@@ -281,7 +281,7 @@ void ThreadedEngine::WaitForVar(VarHandle var) {
         std::unique_lock<std::mutex> lock{finished_m_};
         done.store(true);
         finished_cv_.notify_all();
-      }, Context{}, {var}, {}, FnProperty::kNormal);
+      }, Context::CPU(), {var}, {}, FnProperty::kNormal);
     finished_cv_.wait(lock, [&done]() { return done.load(); });
   }
 }
diff --git a/src/engine/threaded_engine_perdevice.cc b/src/engine/threaded_engine_perdevice.cc
index 35d333b3f85b..3bb72606c341 100644
--- a/src/engine/threaded_engine_perdevice.cc
+++ b/src/engine/threaded_engine_perdevice.cc
@@ -46,7 +46,7 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
   void PushToExecute(OprBlock *opr_block, bool pusher_thread) override {
     const Context& ctx = opr_block->ctx;
     if (opr_block->opr->prop == FnProperty::kAsync && pusher_thread) {
-      if (ctx.dev_mask == gpu::kDevMask) {
+      if (ctx.dev_mask() == gpu::kDevMask) {
 #if MXNET_USE_CUDA
         mshadow::SetDevice<gpu>(ctx.dev_id);
 #endif
@@ -55,10 +55,10 @@ class ThreadedEnginePerDevice : public ThreadedEngine {
       run_ctx.stream = nullptr;
       this->ExecuteOprBlock(run_ctx, opr_block);
     } else {
-      if (ctx.dev_mask == cpu::kDevMask) {
+      if (ctx.dev_mask() == cpu::kDevMask) {
         cpu_worker_->task_queue.Push(opr_block);
       } else {
-        CHECK_EQ(ctx.dev_mask, gpu::kDevMask);
+        CHECK_EQ(ctx.dev_mask(), gpu::kDevMask);
         ThreadWorkerBlock* block = this->GetGPUWorkerBlock(
             ctx.dev_id, opr_block->opr->prop);
         block->task_queue.Push(opr_block);
diff --git a/src/engine/threaded_engine_pooled.cc b/src/engine/threaded_engine_pooled.cc
index 3a32623776b5..a7d44ea018a3 100644
--- a/src/engine/threaded_engine_pooled.cc
+++ b/src/engine/threaded_engine_pooled.cc
@@ -84,7 +84,7 @@ class ThreadedEnginePooled : public ThreadedEngine {
    */
   void DoExecute(OprBlock* opr_block) {
     assert(opr_block->wait.load() == 0);
-    if (opr_block->ctx.dev_mask == gpu::kDevMask) {
+    if (opr_block->ctx.dev_mask() == gpu::kDevMask) {
 #if MXNET_USE_CUDA
       CUDA_CALL(cudaSetDevice(opr_block->ctx.dev_id));
 #else   // MXNET_USE_CUDA
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
index 678539c44b9e..aca976f9aaa1 100644
--- a/src/kvstore/kvstore_local.h
+++ b/src/kvstore/kvstore_local.h
@@ -20,10 +20,11 @@ namespace mxnet {
 class KVStoreLocal : public KVStore {
  public:
   KVStoreLocal() {
+    // TODO(tianqi, mu) allocate pinned per GPU
 #if MXNET_USE_CUDA
-    pinned_ctx_ = Context(cpu::kDevMask, Context::kPinnedMemoryID);
+    pinned_ctx_ = Context::CPUPinned(0);
 #else
-    pinned_ctx_ = Context(cpu::kDevMask, 0);
+    pinned_ctx_ = Context::CPU();
 #endif
     Clear();
   }
@@ -128,7 +129,7 @@ class KVStoreLocal : public KVStore {
 
     for (size_t i = 1; i < val.size(); ++i) {
       const auto& v = val[i];
-      if (v.ctx().dev_mask == cpu::kDevMask) {
+      if (v.ctx().dev_mask() == cpu::kDevMask) {
         buf.merged += v;
       } else {
         int id = v.ctx().dev_id;
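For context, the merge loop above adds CPU-resident values straight into buf.merged and batches GPU values by dev_id, so each device's contribution can be copied (through the pinned context) in one go. A rough standalone sketch of that split, with NDArray reduced to a tagged scalar; all names here are illustrative:

    #include <cassert>
    #include <map>
    #include <vector>

    enum DeviceType { kCPU = 1, kGPU = 2, kCPUPinned = 3 };

    struct Value {          // stand-in for NDArray: device tag plus payload
      DeviceType dev_type;
      int dev_id;
      double data;
    };

    // CPU (and pinned, which masks as CPU) values fold into the running sum;
    // GPU values are staged per device for a later batched copy.
    double Merge(const std::vector<Value> &val,
                 std::map<int, std::vector<double> > *gpu_buf) {
      double merged = val[0].data;
      for (size_t i = 1; i < val.size(); ++i) {
        const Value &v = val[i];
        if (v.dev_type == kCPU || v.dev_type == kCPUPinned) {
          merged += v.data;
        } else {
          (*gpu_buf)[v.dev_id].push_back(v.data);
        }
      }
      return merged;
    }

    int main() {
      std::vector<Value> val = {{kCPU, 0, 1.0}, {kCPU, 0, 2.0}, {kGPU, 1, 4.0}};
      std::map<int, std::vector<double> > staged;
      assert(Merge(val, &staged) == 3.0);  // CPU parts reduced in place
      assert(staged[1].size() == 1);       // GPU 1's part staged for copy
      return 0;
    }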
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 21edc44580f7..210d4b7926f3 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -29,15 +29,15 @@ inline void BinaryOp(const NDArray &lhs,
                      const NDArray &rhs,
                      NDArray *out) {
   // no check if both of them are on cpu
-  if (lhs.ctx().dev_mask != cpu::kDevMask || rhs.ctx().dev_mask != cpu::kDevMask)
+  if (lhs.ctx().dev_mask() != cpu::kDevMask || rhs.ctx().dev_mask() != cpu::kDevMask)
     CHECK(lhs.ctx() == rhs.ctx()) << "operands context mismatch";
   // if out is none, allocate space
   if (out->is_none()) {
     *out = NDArray(OP::GetShape(lhs.shape(), rhs.shape()), lhs.ctx(), true);
   } else {
     // no check if both of them are on cpu
-    if (lhs.ctx().dev_mask != cpu::kDevMask ||
-        out->ctx().dev_mask != cpu::kDevMask) {
+    if (lhs.ctx().dev_mask() != cpu::kDevMask ||
+        out->ctx().dev_mask() != cpu::kDevMask) {
       CHECK(out->ctx() == lhs.ctx()) << "target context mismatch";
     }
     CHECK(out->shape() == OP::GetShape(lhs.shape(), rhs.shape()))
@@ -51,7 +51,7 @@ inline void BinaryOp(const NDArray &lhs,
   if (rhs.var() != ret.var()) const_vars.push_back(rhs.var());
   // redirect everything to mshadow operations
-  switch (lhs.ctx().dev_mask) {
+  switch (lhs.ctx().dev_mask()) {
     case cpu::kDevMask: {
       Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) {
           ret.CheckAndAlloc();
@@ -80,7 +80,7 @@ inline void SetValueOp(const real_t &rhs, NDArray *out) {
   CHECK_NE(out->is_none(), true) << "Set value target must not be empty";
   // important: callback must always capture by value
   NDArray ret = *out;
-  switch (ret.ctx().dev_mask) {
+  switch (ret.ctx().dev_mask()) {
     case cpu::kDevMask: {
       Engine::Get()->PushSync([rhs, ret](RunContext ctx) {
           ret.CheckAndAlloc();
@@ -128,7 +128,7 @@ inline void ScalarOp(const NDArray &lhs,
   if (lhs.var() != ret.var()) const_vars.push_back(lhs.var());
   // redirect everything to mshadow operations
-  switch (lhs.ctx().dev_mask) {
+  switch (lhs.ctx().dev_mask()) {
     case cpu::kDevMask: {
       Engine::Get()->PushSync([lhs, rhs, ret](RunContext ctx) {
          ret.CheckAndAlloc();
@@ -160,8 +160,8 @@ void CopyFromTo(const NDArray &from, NDArray *to) {
       << "source operands have zero dimension shape";
   // important: callback must always capture by value
   NDArray ret = *to;
-  int a = from.ctx().dev_mask;
-  int b = to->ctx().dev_mask;
+  int a = from.ctx().dev_mask();
+  int b = to->ctx().dev_mask();
 
   std::vector<Engine::VarHandle> const_vars;
   if (from.var() != ret.var()) const_vars.push_back(from.var());
@@ -222,7 +222,7 @@ inline void SampleOP(const real_t &a,
   // important: callback must always capture by value
   NDArray ret = *out;
   // redirect everything to mshadow operations
-  switch (out->ctx().dev_mask) {
+  switch (out->ctx().dev_mask()) {
    case cpu::kDevMask: {
      Engine::Get()->PushSync([a, b, resource, ret](RunContext ctx) {
          ret.CheckAndAlloc();
@@ -356,8 +356,8 @@ void NDArray::Save(dmlc::Stream *strm) const {
   ctx.Save(strm);
   TBlob save_data;
   NDArray temp;
-  if (ctx.dev_mask != cpu::kDevMask) {
-    temp = this->Copy(Context(cpu::kDevMask, 0));
+  if (ctx.dev_mask() != cpu::kDevMask) {
+    temp = this->Copy(Context::CPU());
     temp.WaitToRead();
     save_data = temp.data();
   } else {
@@ -390,14 +390,14 @@ bool NDArray::Load(dmlc::Stream *strm) {
   if (strm->Read(&type_flag, sizeof(type_flag)) != sizeof(type_flag)) return false;
   CHECK(type_flag == mshadow::DataType<real_t>::kFlag)
       << "Only support float NDArray so far";
-  // load data into CPUbu
-  NDArray temp(shape, Context(cpu::kDevMask, ctx.dev_id));
+  // load data into CPU
+  NDArray temp(shape, Context::CPU());
   TBlob load_data = temp.data();
   size_t type_size = sizeof(real_t);
   size_t nread = type_size * shape.Size();
 
   if (strm->Read(load_data.dptr_, nread) != nread) return false;
-  if (ctx.dev_mask == cpu::kDevMask) {
+  if (ctx.dev_mask() == cpu::kDevMask) {
     *this = std::move(temp); return true;
   } else {
     *this = temp.Copy(ctx); return true;
@@ -453,8 +453,8 @@ void NDArray::SyncCopyFromCPU(const real_t *data, size_t size) const {
   TBlob src((real_t*)data, dshape, cpu::kDevMask); // NOLINT(*)
 
   RunContext run_ctx;
-  if (ctx.dev_mask == cpu::kDevMask) {
-    ndarray::Copy<cpu, cpu>(src, &dst, Context(cpu::kDevMask, 0), ctx, run_ctx);
+  if (ctx.dev_mask() == cpu::kDevMask) {
+    ndarray::Copy<cpu, cpu>(src, &dst, Context::CPU(), ctx, run_ctx);
   } else {
 #if MXNET_USE_CUDA
     // use empty stream to do sync copy
@@ -462,7 +462,7 @@ void NDArray::SyncCopyFromCPU(const real_t *data, size_t size) const {
     // Maybe move to engine part
     mshadow::Stream<gpu> zero_stream;
     run_ctx.stream = &zero_stream;
-    ndarray::Copy<cpu, gpu>(src, &dst, Context(cpu::kDevMask, 0), ctx, run_ctx);
+    ndarray::Copy<cpu, gpu>(src, &dst, Context::CPU(), ctx, run_ctx);
 #else
     LOG(FATAL) << "GPU is not enabled";
 #endif
@@ -479,8 +479,8 @@ void NDArray::SyncCopyToCPU(real_t *data, size_t size) const {
   TBlob dst(data, dshape, cpu::kDevMask); // NOLINT(*)
 
   RunContext run_ctx;
-  if (ctx.dev_mask == cpu::kDevMask) {
-    ndarray::Copy<cpu, cpu>(src, &dst, ctx, Context(cpu::kDevMask, 0), run_ctx);
+  if (ctx.dev_mask() == cpu::kDevMask) {
+    ndarray::Copy<cpu, cpu>(src, &dst, ctx, Context::CPU(), run_ctx);
   } else {
 #if MXNET_USE_CUDA
     // use empty stream to do sync copy
@@ -488,7 +488,7 @@ void NDArray::SyncCopyToCPU(real_t *data, size_t size) const {
     // Maybe move to engine part
     mshadow::Stream<gpu> zero_stream;
     run_ctx.stream = &zero_stream;
-    ndarray::Copy<gpu, cpu>(src, &dst, ctx, Context(cpu::kDevMask, 0), run_ctx);
+    ndarray::Copy<gpu, cpu>(src, &dst, ctx, Context::CPU(), run_ctx);
 #else
     LOG(FATAL) << "GPU is not enabled";
 #endif
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index 64714f1e7633..3e45510418b6 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -81,14 +81,14 @@ struct InferShapeError {
 // helper macro to implement bind dispatch
 #if MXNET_USE_CUDA
 #define DO_BIND_DISPATCH(Method, ...)       \
-  if (ctx.dev_mask == cpu::kDevMask) {      \
+  if (ctx.dev_mask() == cpu::kDevMask) {    \
     return Method<cpu>(__VA_ARGS__);        \
   } else {                                  \
     return Method<gpu>(__VA_ARGS__);        \
   }
 #else
 #define DO_BIND_DISPATCH(Method, ...)       \
-  if (ctx.dev_mask == cpu::kDevMask) {      \
+  if (ctx.dev_mask() == cpu::kDevMask) {    \
     return Method<cpu>(__VA_ARGS__);        \
   } else {                                  \
     LOG(FATAL) << "GPU is not enabled";     \
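The macro above bridges a runtime device mask to a compile-time template argument. A compilable sketch of the same dispatch pattern, with mshadow's cpu/gpu reduced to bare tag types (simplified stand-ins, not the real headers):

    #include <cstdio>

    struct cpu { static const int kDevMask = 1; };  // tag types standing in
    struct gpu { static const int kDevMask = 2; };  // for mshadow's cpu/gpu

    struct Ctx {
      int mask;
      int dev_mask() const { return mask; }
    };

    // Same shape as DO_BIND_DISPATCH: branch once on the runtime mask, then
    // jump into the statically typed instantiation.
    #define DO_BIND_DISPATCH(Method, ...)      \
      if (ctx.dev_mask() == cpu::kDevMask) {   \
        return Method<cpu>(__VA_ARGS__);       \
      } else {                                 \
        return Method<gpu>(__VA_ARGS__);       \
      }

    template <typename Device>
    const char *CreateOp(int /*param*/) {
      return Device::kDevMask == cpu::kDevMask ? "cpu op" : "gpu op";
    }

    const char *Bind(Ctx ctx, int param) {
      DO_BIND_DISPATCH(CreateOp, param);
    }

    int main() {
      Ctx cpu_ctx = {1};
      Ctx gpu_ctx = {2};
      printf("%s\n", Bind(cpu_ctx, 0));  // -> cpu op
      printf("%s\n", Bind(gpu_ctx, 0));  // -> gpu op
      return 0;
    }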
diff --git a/src/resource.cc b/src/resource.cc
index a2d51c5516a0..2259c642d32e 100644
--- a/src/resource.cc
+++ b/src/resource.cc
@@ -24,9 +24,9 @@ class ResourceManagerImpl : public ResourceManager {
     gpu_temp_space_copy_ = dmlc::GetEnv("MXNET_GPU_TEMP_COPY", 4);
     engine_ref_ = Engine::_GetSharedRef();
     cpu_rand_.reset(new ResourceRandom<cpu>(
-        Context(cpu::kDevMask, 0), global_seed_));
+        Context::CPU(), global_seed_));
     cpu_space_.reset(new ResourceTempSpace<cpu>(
-        Context(cpu::kDevMask, 0), cpu_temp_space_copy_));
+        Context::CPU(), cpu_temp_space_copy_));
   }
   ~ResourceManagerImpl() {
     // need explicit delete, before engine get killed
@@ -45,14 +45,14 @@ class ResourceManagerImpl : public ResourceManager {
 
   // request resources
   Resource Request(Context ctx, const ResourceRequest &req) override {
-    if (ctx.dev_mask == cpu::kDevMask) {
+    if (ctx.dev_mask() == cpu::kDevMask) {
       switch (req.type) {
         case ResourceRequest::kRandom: return cpu_rand_->resource;
         case ResourceRequest::kTempSpace: return cpu_space_->GetNext();
         default: LOG(FATAL) << "Unknown supported type " << req.type;
       }
     } else {
-      CHECK_EQ(ctx.dev_mask, gpu::kDevMask);
+      CHECK_EQ(ctx.dev_mask(), gpu::kDevMask);
 #if MSHADOW_USE_CUDA
       switch (req.type) {
         case ResourceRequest::kRandom: {
diff --git a/src/storage/storage.cc b/src/storage/storage.cc
index e41f4e701f80..08af99621b40 100644
--- a/src/storage/storage.cc
+++ b/src/storage/storage.cc
@@ -21,18 +21,18 @@ namespace mxnet {
 // consider change storage as a pure abstract class
 struct Storage::Impl {
   static constexpr size_t kPoolThreshold = 4096 * 1024 * 1024ul;
-  static constexpr size_t kMaxNumberOfDevices = Context::kMaxDevMask + 1;
-  static constexpr size_t kMaxNumberOfDeviceIDs = Context::kPinnedMemoryID + 1;
+  static constexpr size_t kMaxNumberOfDevices = Context::kMaxDevType + 1;
+  static constexpr size_t kMaxNumberOfDeviceIDs = Context::kMaxDevID + 1;
 
   template <class DeviceStorage>
   using CurrentStorageManager =
       storage::PooledStorageManager<DeviceStorage, kPoolThreshold>;
 
   static void ActivateDevice(Context ctx) {
-    switch (ctx.dev_mask) {
-      case cpu::kDevMask:
-        break;
-      case gpu::kDevMask:
+    switch (ctx.dev_type) {
+      case Context::kCPU: break;
+      case Context::kGPU:
+      case Context::kCPUPinned:
 #if MXNET_USE_CUDA
         CUDA_CALL(cudaSetDevice(ctx.dev_id));
 #else   // MXNET_USE_CUDA
@@ -57,26 +57,28 @@ Storage::Handle Storage::Alloc(size_t size, Context ctx) {
   hd.size = size;
   {
     std::lock_guard<std::mutex> lock{impl_->m};
-    auto&& device = impl_->storage_managers.at(ctx.dev_mask);
+    auto&& device = impl_->storage_managers.at(ctx.dev_type);
     auto&& device_id_it = device.at(ctx.dev_id);
     // Allocate device if necessary.
     if (!device_id_it) {
-      switch (ctx.dev_mask) {
-        case cpu::kDevMask:
-          if (ctx.dev_id == Context::kPinnedMemoryID) {
-            device_id_it = common::MakeUnique<
+      switch (ctx.dev_type) {
+        case Context::kCPU: {
+          device_id_it = common::MakeUnique<
               Storage::Impl::CurrentStorageManager<
-                storage::PinnedMemoryStorage>>();
-          } else {
-            device_id_it = common::MakeUnique<
+                storage::CPUDeviceStorage>>();
+          break;
+        }
+        case Context::kCPUPinned: {
+          device_id_it = common::MakeUnique<
               Storage::Impl::CurrentStorageManager<
-                storage::CPUDeviceStorage>>();
-          }
+                storage::PinnedMemoryStorage>>();
           break;
-        case gpu::kDevMask:
+        }
+        case Context::kGPU: {
           device_id_it = common::MakeUnique<
-              Storage::Impl::CurrentStorageManager<storage::GPUDeviceStorage>>();
+              Storage::Impl::CurrentStorageManager<
+                storage::GPUDeviceStorage>>();
           break;
+        }
         default:
           LOG(FATAL) << "Unimplemented device";
       }
@@ -90,7 +92,7 @@ Storage::Handle Storage::Alloc(size_t size, Context ctx) {
 void Storage::Free(Storage::Handle handle) {
   std::lock_guard<std::mutex> lock{impl_->m};
   Impl::ActivateDevice(handle.ctx);
-  impl_->storage_managers.at(handle.ctx.dev_mask)
+  impl_->storage_managers.at(handle.ctx.dev_type)
       .at(handle.ctx.dev_id)
       ->Free(handle.dptr, handle.size);
 }
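With pinned memory promoted to a device type, the storage layer can key managers by (dev_type, dev_id) in a (kMaxDevType + 1) x (kMaxDevID + 1) table instead of hiding the pinned pool behind a reserved dev_id. A standalone sketch of that lazily filled table; Manager is a placeholder for the real CPU/pinned/GPU storage managers:

    #include <array>
    #include <cassert>
    #include <memory>
    #include <string>
    #include <utility>

    enum DeviceType { kCPU = 1, kGPU = 2, kCPUPinned = 3 };
    const int kMaxDevType = 4;
    const int kMaxDevID = 16;

    struct Manager {  // placeholder for a pooled storage manager backend
      explicit Manager(std::string k) : kind(std::move(k)) {}
      std::string kind;
    };

    // One slot per (dev_type, dev_id), created on first use: pinned memory
    // gets its own row rather than squatting on dev_id 16 of the CPU row.
    std::array<std::array<std::unique_ptr<Manager>, kMaxDevID + 1>,
               kMaxDevType + 1> managers;

    Manager *Get(DeviceType type, int id) {
      std::unique_ptr<Manager> &slot = managers.at(type).at(id);
      if (!slot) {
        switch (type) {
          case kCPU:       slot.reset(new Manager("cpu")); break;
          case kCPUPinned: slot.reset(new Manager("pinned")); break;
          case kGPU:       slot.reset(new Manager("gpu")); break;
          default:         break;
        }
      }
      return slot.get();
    }

    int main() {
      assert(Get(kCPU, 0)->kind == "cpu");
      assert(Get(kCPUPinned, 0)->kind == "pinned");  // own (type, id) slot
      assert(Get(kGPU, 1)->kind == "gpu");
      return 0;
    }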
diff --git a/src/symbol/graph_executor.cc b/src/symbol/graph_executor.cc
index 3ac19fa4469e..cf70826919fa 100644
--- a/src/symbol/graph_executor.cc
+++ b/src/symbol/graph_executor.cc
@@ -216,7 +216,7 @@ GraphExecutor::GetOpExecEntry(uint32_t nid) {
   Operator* op = op_node.op.get();
   OpContext* op_ctx_ptr = &op_node.op_ctx;
-  bool is_gpu = op_node.ctx.dev_mask == gpu::kDevMask;
+  bool is_gpu = op_node.ctx.dev_mask() == gpu::kDevMask;
   exec.exec_fun = [op, is_gpu, op_ctx_ptr, in_data, req, out_data, aux_states]
       (RunContext ctx, Engine::CallbackOnComplete on_complete) {
     op_ctx_ptr->run_ctx = ctx;