Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #102 from tqchen/master
Browse files Browse the repository at this point in the history
change def of context, change pinned to special device
  • Loading branch information
tqchen committed Sep 19, 2015
2 parents b7627d4 + 5b96c3d commit 5bc10be
Show file tree
Hide file tree
Showing 19 changed files with 177 additions and 122 deletions.
8 changes: 4 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ WARNFLAGS= -Wall
CFLAGS = -DMSHADOW_FORCE_STREAM $(WARNFLAGS)

# CFLAGS for debug
ifeq ($(DEBUG),0)
CFLAGS += -O3
else
ifeq ($(DEBUG), 1)
CFLAGS += -g -O0
else
CFLAGS += -O3
endif
CFLAGS += -I./mshadow/ -I./dmlc-core/include -fPIC -Iinclude $(MSHADOW_CFLAGS) $(DMLC_CFLAGS)
CFLAGS += -I./mshadow/ -I./dmlc-core/include -fPIC -Iinclude $(MSHADOW_CFLAGS)
LDFLAGS = -pthread $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS)
NVCCFLAGS = --use_fast_math -g -O3 -ccbin $(CXX) $(MSHADOW_NVCCFLAGS)
ROOTDIR = $(CURDIR)
Expand Down
99 changes: 72 additions & 27 deletions include/mxnet/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,41 +64,41 @@ typedef mshadow::TShape TShape;
/*! \brief storage container type */
typedef mshadow::TBlob TBlob;


/*! \brief Context information about the execution environment */
struct Context {
/*! \brief the device type we run the op can be cpu::kDevMask or gpu::kDevMask */
int32_t dev_mask;
/*! \brief Type of device */
enum DeviceType {
kCPU = cpu::kDevMask,
kGPU = gpu::kDevMask,
kCPUPinned = 3
};
/*! \brief the device type we run the op on */
DeviceType dev_type;
/*! \brief device id we are going to run it on */
int32_t dev_id;
/*! \brief constructor */
Context() : dev_mask(cpu::kDevMask), dev_id(0) {}
/*! \brief default constructor */
Context() : dev_type(kCPU), dev_id(0) {}
/*!
* \brief constructor of context
* \param dev_mask the device mask
* \param dev_id the device id
* \brief Get corresponding device mask
* \return cpu::kDevMask or gpu::kDevMask
*/
Context(int dev_mask, int dev_id)
: dev_mask(dev_mask), dev_id(dev_id) {}
inline int dev_mask() const {
if (dev_type == kCPUPinned) return cpu::kDevMask;
return dev_type;
}
/*!
* \brief Comparator, used to enable Context as std::map key.
* \param b another context to compare
* \return compared result
*/
inline bool operator<(const Context &b) const {
if (dev_mask == b.dev_mask) {
return dev_id < b.dev_id;
} else {
return dev_mask < b.dev_mask;
}
}
inline bool operator<(const Context &b) const;
/*!
* \brief check if current context equals another one
* \param b another context to compare
* \return whether dev mask and id are same
*/
inline bool operator==(const Context &b) const {
return dev_mask == b.dev_mask && dev_id == b.dev_id;
return dev_type == b.dev_type && dev_id == b.dev_id;
}
/*!
* \brief check if current context not equals another one
Expand All @@ -112,27 +112,44 @@ struct Context {
* \brief save the content into binary stream
* \param strm the output stream
*/
void Save(dmlc::Stream *strm) const {
strm->Write(&dev_mask, sizeof(dev_mask));
inline void Save(dmlc::Stream *strm) const {
strm->Write(&dev_type, sizeof(dev_type));
strm->Write(&dev_id, sizeof(dev_id));
}
/*!
* \brief load the content from binary stream
* \param strm the output stream
* \return whether the load is successful
*/
bool Load(dmlc::Stream *strm) {
if (strm->Read(&dev_mask, sizeof(int32_t)) != sizeof(int32_t)) return false;
inline bool Load(dmlc::Stream *strm) {
if (strm->Read(&dev_type, sizeof(dev_type)) != sizeof(dev_type)) return false;
if (strm->Read(&dev_id, sizeof(int32_t)) != sizeof(int32_t)) return false;
return true;
}
/*! \brief the maximal device mask, cpu = 1, gpu = 2 */
static const int32_t kMaxDevMask = 2;
/*! \brief the maximal device type */
static const int32_t kMaxDevType = 4;
/*! \brief the maximal device index */
static const int32_t kMaxDevID = 16;
/*!
* \brief Create a new context.
* \param dev_type device type.
* \param dev_id device id.
*/
inline static Context Create(DeviceType dev_type, int32_t dev_id);
/*! \return CPU Context */
inline static Context CPU();
/*!
* \brief Create a GPU context.
* \param dev_id the device id.
* \return GPU Context.
*/
inline static Context GPU(int32_t dev_id);
/*!
* \brief A dedicate ID for pinned cpu memory.
* Any normal CPU ID should be less than this number.
* Create a pinned CPU context.
* \param dev_id the device id for corresponding GPU.
* \return Pinned CPU context.
*/
static const int32_t kPinnedMemoryID = 16;
inline static Context CPUPinned(int32_t dev_id);
};

/*!
Expand All @@ -157,6 +174,34 @@ struct RunContext {
} // namespace mxnet

//! \cond Doxygen_Suppress
namespace mxnet {
// Inline implementations of the members declared on Context.

/*!
 * \brief Strict weak ordering: device type first, then device id.
 *  Enables Context to be used as a std::map key.
 */
inline bool Context::operator<(const Context &b) const {
  return (dev_type == b.dev_type) ? (dev_id < b.dev_id)
                                  : (dev_type < b.dev_type);
}
/*! \brief Build a context for the given device type/id pair. */
inline Context Context::Create(DeviceType dev_type, int32_t dev_id) {
  Context ctx;
  ctx.dev_type = dev_type;
  ctx.dev_id = dev_id;
  return ctx;
}
/*! \brief The CPU context always uses device id 0. */
inline Context Context::CPU() {
  return Create(kCPU, 0);
}
/*! \brief Pinned CPU memory is associated with the GPU given by dev_id. */
inline Context Context::CPUPinned(int32_t dev_id) {
  return Create(kCPUPinned, dev_id);
}
/*! \brief GPU context on the given device id. */
inline Context Context::GPU(int32_t dev_id) {
  return Create(kGPU, dev_id);
}
} // namespace mxnet

namespace dmlc {
// Add a few patches to support TShape in dmlc/parameter.
DMLC_DECLARE_TYPE_NAME(mxnet::TShape, "Shape(tuple)");
Expand Down
20 changes: 10 additions & 10 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ MXNET_DLL int MXNDArrayCreateNone(NDArrayHandle *out);
* \brief create a NDArray with specified shape
* \param shape the pointer to the shape
* \param ndim the dimension of the shape
* \param dev_mask device mask, specify device we want to take
* \param dev_type device type, specify device we want to take
* \param dev_id the device id of the specific device
* \param delay_alloc whether to delay allocation until
* the narray is first mutated
Expand All @@ -86,7 +86,7 @@ MXNET_DLL int MXNDArrayCreateNone(NDArrayHandle *out);
*/
MXNET_DLL int MXNDArrayCreate(const mx_uint *shape,
mx_uint ndim,
int dev_mask,
int dev_type,
int dev_id,
int delay_alloc,
NDArrayHandle *out);
Expand Down Expand Up @@ -198,26 +198,26 @@ MXNET_DLL int MXNDArrayFree(NDArrayHandle handle);
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayGetShape(NDArrayHandle handle,
mx_uint *out_dim,
const mx_uint **out_pdata);
mx_uint *out_dim,
const mx_uint **out_pdata);
/*!
* \brief get the content of the data in NDArray
* \param handle the handle to the narray
* \param out_pdata pointer holder to get pointer of data
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayGetData(NDArrayHandle handle,
mx_float **out_pdata);
mx_float **out_pdata);
/*!
* \brief get the context of the NDArray
* \param handle the handle to the narray
* \param out_dev_mask the output device mask
* \param out_dev_type the output device type
* \param out_dev_id the output device id
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXNDArrayGetContext(NDArrayHandle handle,
int *out_dev_mask,
int *out_dev_id);
int *out_dev_type,
int *out_dev_id);

//--------------------------------
// Part 2: functions on NDArray
Expand Down Expand Up @@ -547,7 +547,7 @@ MXNET_DLL int MXExecutorOutputs(ExecutorHandle handle,
* \brief Generate Executor from symbol
*
* \param symbol_handle symbol handle
* \param dev_mask device mask
* \param dev_type device type
* \param dev_id device id
* \param len length
* \param in_args in args array
Expand All @@ -559,7 +559,7 @@ MXNET_DLL int MXExecutorOutputs(ExecutorHandle handle,
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXExecutorBind(SymbolHandle symbol_handle,
int dev_mask,
int dev_type,
int dev_id,
mx_uint len,
NDArrayHandle *in_args,
Expand Down
9 changes: 7 additions & 2 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class NDArray {
*/
inline TBlob data() const {
return TBlob(static_cast<real_t*>(ptr_->shandle.dptr) + offset_, \
shape_, ptr_->shandle.ctx.dev_mask);
shape_, ptr_->shandle.ctx.dev_mask());
}
/*!
* \return the context of NDArray, this function is only valid when the NDArray is not empty
Expand Down Expand Up @@ -288,7 +288,12 @@ class NDArray {
: static_data(true),
delay_alloc(false) {
var = Engine::Get()->NewVariable();
shandle.ctx = Context(data.dev_mask_, dev_id);
if (data.dev_mask_ == cpu::kDevMask) {
shandle.ctx = Context::CPU();
} else {
CHECK_EQ(data.dev_mask_, gpu::kDevMask);
shandle.ctx = Context::GPU(dev_id);
}
shandle.dptr = data.dptr_;
shandle.size = data.shape_.Size() * sizeof(real_t);
}
Expand Down
10 changes: 5 additions & 5 deletions python/mxnet/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ class Context(object):
"""Context representing device and device id in mxnet"""
# static class variable
default_ctx = None
devmask2type = {1: 'cpu', 2: 'gpu'}
devtype2mask = {'cpu': 1, 'gpu': 2}
devtype2str = {1: 'cpu', 2: 'gpu'}
devstr2type = {'cpu': 1, 'gpu': 2}

def __init__(self, device_type, device_id=0):
"""Constructing a context.
Expand All @@ -21,10 +21,10 @@ def __init__(self, device_type, device_id=0):
the device id of the device, needed for GPU
"""
if isinstance(device_type, Context):
self.device_mask = device_type.device_mask
self.device_typeid = device_type.device_typeid
self.device_id = device_type.device_id
else:
self.device_mask = Context.devtype2mask[device_type]
self.device_typeid = Context.devstr2type[device_type]
self.device_id = device_id
self._old_ctx = None

Expand All @@ -36,7 +36,7 @@ def device_type(self):
-------
device_type : str
"""
return Context.devmask2type[self.device_mask]
return Context.devtype2str[self.device_typeid]

def __str__(self):
return 'Context(device_type=%s, device_id=%d)' % (
Expand Down
9 changes: 5 additions & 4 deletions python/mxnet/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,11 @@ def _new_alloc_handle(shape, ctx, delay_alloc):
a new empty ndarray handle
"""
hdl = NDArrayHandle()
print ctx.device_typeid
check_call(_LIB.MXNDArrayCreate(
c_array(mx_uint, shape),
len(shape),
ctx.device_mask,
ctx.device_typeid,
ctx.device_id,
int(delay_alloc),
ctypes.byref(hdl)))
Expand Down Expand Up @@ -262,11 +263,11 @@ def context(self):
context : mxnet.Context
The context of current NDArray.
"""
dev_mask = ctypes.c_int()
dev_typeid = ctypes.c_int()
dev_id = ctypes.c_int()
check_call(_LIB.MXNDArrayGetContext(
self.handle, ctypes.byref(dev_mask), ctypes.byref(dev_id)))
return Context(Context.devmask2type[dev_mask.value], dev_id.value)
self.handle, ctypes.byref(dev_typeid), ctypes.byref(dev_id)))
return Context(Context.devtype2str[dev_typeid.value], dev_id.value)

def asnumpy(self):
"""Return a copied numpy array of current array.
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,8 +525,8 @@ def bind(self, ctx, args, args_grad=None, grad_req='write', aux_states=None):

handle = ExecutorHandle()
check_call(_LIB.MXExecutorBind(self.handle,
mx_uint(ctx.device_mask),
mx_uint(ctx.device_id),
ctx.device_typeid,
ctx.device_id,
len(args),
args_handle,
args_grad_handle,
Expand Down
21 changes: 11 additions & 10 deletions src/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,14 +195,15 @@ int MXNDArrayCreateNone(NDArrayHandle *out) {

int MXNDArrayCreate(const mx_uint *shape,
mx_uint ndim,
int dev_mask,
int dev_type,
int dev_id,
int delay_alloc,
NDArrayHandle *out) {
API_BEGIN();
*out = new NDArray(TShape(shape, shape + ndim),
Context(dev_mask, dev_id),
delay_alloc != 0);
*out = new NDArray(
TShape(shape, shape + ndim),
Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
delay_alloc != 0);
API_END();
}

Expand Down Expand Up @@ -336,7 +337,7 @@ int MXNDArrayGetData(NDArrayHandle handle,
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
if (!arr->is_none()) {
CHECK(arr->ctx().dev_mask == cpu::kDevMask)
CHECK(arr->ctx().dev_mask() == cpu::kDevMask)
<< "MXNDArrayGetData can only be called for NDArray on CPU";
const TBlob &b = arr->data();
CHECK(b.CheckContiguous());
Expand All @@ -348,16 +349,16 @@ int MXNDArrayGetData(NDArrayHandle handle,
}

int MXNDArrayGetContext(NDArrayHandle handle,
int *out_dev_mask,
int *out_dev_type,
int *out_dev_id) {
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
if (!arr->is_none()) {
const Context &ctx = arr->ctx();
*out_dev_mask = ctx.dev_mask;
*out_dev_type = ctx.dev_type;
*out_dev_id = ctx.dev_id;
} else {
*out_dev_mask = 0;
*out_dev_type = 0;
*out_dev_id = 0;
}
API_END();
Expand Down Expand Up @@ -764,7 +765,7 @@ int MXExecutorOutputs(ExecutorHandle handle,
}

int MXExecutorBind(SymbolHandle symbol_handle,
int dev_mask,
int dev_type,
int dev_id,
mx_uint len,
NDArrayHandle *in_args,
Expand All @@ -775,7 +776,7 @@ int MXExecutorBind(SymbolHandle symbol_handle,
ExecutorHandle *out) {
API_BEGIN();
Symbol *symb = static_cast<Symbol*>(symbol_handle);
Context ctx = Context(dev_mask, dev_id);
Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);
NDArray **in_args_ptr = reinterpret_cast<NDArray**>(in_args);
NDArray **arg_grad_ptr = reinterpret_cast<NDArray**>(arg_grad_store);
NDArray **aux_states_ptr = reinterpret_cast<NDArray**>(aux_states);
Expand Down
Loading

0 comments on commit 5bc10be

Please sign in to comment.