Commit

Update MXNet for MKLDNN.
zheng-da committed Dec 17, 2017
1 parent 3f75f52 commit edf6842
Showing 6 changed files with 255 additions and 147 deletions.
99 changes: 17 additions & 82 deletions include/mxnet/ndarray.h
@@ -103,44 +103,8 @@ class NDArray {
NDArray(const NDArrayStorageType stype, const TShape &shape, Context ctx,
bool delay_alloc = true, int dtype = mshadow::default_type_flag,
std::vector<int> aux_types = {}, std::vector<TShape> aux_shapes = {},
TShape storage_shape = TShape(mshadow::Shape1(0)))
: shape_(shape), dtype_(dtype), storage_type_(stype),
entry_({nullptr, 0, 0}) {
// Assign default aux types if not given
if (aux_types.size() == 0) {
if (stype == kRowSparseStorage) {
aux_types = {mshadow::kInt64};
} else if (stype == kCSRStorage) {
aux_types = {mshadow::kInt64, mshadow::kInt64};
} else {
LOG(FATAL) << "Unknown storage type " << stype;
}
}
// Assign default shapes if not given
// unknown shapes are initialized as {0} such that Size() would return 0
if (aux_shapes.size() == 0) {
if (stype == kRowSparseStorage) {
aux_shapes = {TShape(mshadow::Shape1(0))};
} else if (stype == kCSRStorage) {
// aux shapes for indptr and indices
aux_shapes = {TShape(mshadow::Shape1(0)), TShape(mshadow::Shape1(0))};
} else {
LOG(FATAL) << "Unknown storage type " << stype;
}
}
if (storage_shape.Size() == 0) {
if (stype == kRowSparseStorage) {
storage_shape = shape;
storage_shape[0] = aux_shapes[rowsparse::kIdx][0];
} else if (stype == kCSRStorage) {
storage_shape = aux_shapes[csr::kIdx];
} else {
LOG(FATAL) << "Unknown storage type " << stype;
}
}
ptr_ = std::make_shared<Chunk>(stype, storage_shape, ctx, delay_alloc,
dtype, aux_types, aux_shapes);
}
TShape storage_shape = TShape(mshadow::Shape1(0)));

/*!
* \brief constructing a static NDArray that shares data with TBlob
* Use with caution: allocate ONLY ONE NDArray for each TBlob,
@@ -591,10 +555,6 @@ class NDArray {
std::shared_ptr<const mkldnn::memory> GetMKLDNNData(
const mkldnn::memory::primitive_desc &desc,
std::vector<mkldnn::primitive> &net) const;
std::shared_ptr<mkldnn::memory> GetMKLDNNData();
std::shared_ptr<mkldnn::memory> GetMKLDNNData(
const mkldnn::memory::primitive_desc &desc,
std::vector<mkldnn::primitive> &net);

std::shared_ptr<mkldnn::memory> CreateMKLDNNData(
const mkldnn::memory::primitive_desc &desc);
@@ -634,6 +594,12 @@ class NDArray {
for csr, aux_handles[0] = indptr, aux_handles[1] = indices
*/
std::vector<Storage::Handle> aux_handles;

#if MXNET_USE_MKLDNN == 1
/*! This is created when data is stored in MKLDNN format.
*/
std::shared_ptr<mkldnn::memory> Mkl_mem_;
#endif
/*! \brief variable from engine */
Engine::VarHandle var;
/*!
@@ -812,20 +778,14 @@ class NDArray {
// storage shape is also updated
// if data is already allocated, try reuse the storage. Otherwise, free the current one
// and allocate new storage
inline void CheckAndAllocData(const TShape &shape, int dtype) {
CHECK_NE(aux_shapes.size(), 0) << "data is expected to be allocated after aux_data";
auto dbytes = shape.Size() * mshadow::mshadow_sizeof(dtype);
if (shandle.size < dbytes) {
// free storage if necessary and alloc again
if (shandle.size > 0) Storage::Get()->Free(shandle);
// init storage
shandle = Storage::Get()->Alloc(dbytes, ctx);
}
// init shape
storage_shape = shape;
// delay_alloc is only set when data storage handle is present
delay_alloc = false;
}
void CheckAndAllocData(const TShape &shape, int dtype);

#if MXNET_USE_MKLDNN == 1
// Have MKL memory reference to the data in the default storage
// or create memory for MKLDNN.
void SetMKLMem(const TShape &shape, int dtype);
#endif

// create storage handle for aux data based on shape
// this function assumes ctx, aux shapes and aux types are set
// aux shape is also updated
@@ -866,33 +826,8 @@ class NDArray {
}
}; // struct Chunk

#if MXNET_USE_MKLDNN == 1
// Have MKL memory reference to the data in TBlob.
void SetMKLMem();
#endif
void SetTBlob() const;

void SetTBlob() const {
CHECK(ptr_ != nullptr);
TShape shape = shape_;
char *dptr = static_cast<char*>(ptr_->shandle.dptr);
auto stype = storage_type();
if (stype == kDefaultStorage) {
dptr += byte_offset_;
} else if (stype == kCSRStorage || stype == kRowSparseStorage) {
CHECK_EQ(byte_offset_, 0);
shape = storage_shape();
} else {
LOG(FATAL) << "unknown storage type " << stype;
}
tblob_.dptr_ = dptr;
tblob_.shape_ = shape;
tblob_.type_flag_ = dtype_;
tblob_.SetDLTensor(ptr_->shandle.ctx.dev_mask(), ptr_->shandle.ctx.dev_id);
}

#if MXNET_USE_MKLDNN == 1
std::shared_ptr<mkldnn::memory> Mkl_mem_;
#endif
/*! \brief internal data of NDArray */
std::shared_ptr<Chunk> ptr_{nullptr};
/*! \brief shape of current NDArray */
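The inline bodies deleted from include/mxnet/ndarray.h above presumably move out of the header into src/ndarray/ndarray.cc, one of the six changed files not expanded in this view. As a hedged sketch (the destination file is an assumption; the body is copied from the removed lines above), the out-of-line constructor definition would look roughly like this; note that the default arguments stay on the declaration in the header and are not repeated at the definition:

// Hypothetical out-of-line definition in src/ndarray/ndarray.cc (not shown in this diff).
NDArray::NDArray(const NDArrayStorageType stype, const TShape &shape, Context ctx,
                 bool delay_alloc, int dtype, std::vector<int> aux_types,
                 std::vector<TShape> aux_shapes, TShape storage_shape)
    : shape_(shape), dtype_(dtype), storage_type_(stype),
      entry_({nullptr, 0, 0}) {
  // Assign default aux types if not given
  if (aux_types.size() == 0) {
    if (stype == kRowSparseStorage) {
      aux_types = {mshadow::kInt64};
    } else if (stype == kCSRStorage) {
      aux_types = {mshadow::kInt64, mshadow::kInt64};
    } else {
      LOG(FATAL) << "Unknown storage type " << stype;
    }
  }
  // Assign default shapes if not given
  // unknown shapes are initialized as {0} such that Size() would return 0
  if (aux_shapes.size() == 0) {
    if (stype == kRowSparseStorage) {
      aux_shapes = {TShape(mshadow::Shape1(0))};
    } else if (stype == kCSRStorage) {
      // aux shapes for indptr and indices
      aux_shapes = {TShape(mshadow::Shape1(0)), TShape(mshadow::Shape1(0))};
    } else {
      LOG(FATAL) << "Unknown storage type " << stype;
    }
  }
  if (storage_shape.Size() == 0) {
    if (stype == kRowSparseStorage) {
      storage_shape = shape;
      storage_shape[0] = aux_shapes[rowsparse::kIdx][0];
    } else if (stype == kCSRStorage) {
      storage_shape = aux_shapes[csr::kIdx];
    } else {
      LOG(FATAL) << "Unknown storage type " << stype;
    }
  }
  ptr_ = std::make_shared<Chunk>(stype, storage_shape, ctx, delay_alloc,
                                 dtype, aux_types, aux_shapes);
}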
16 changes: 16 additions & 0 deletions src/common/utils.cc
@@ -41,5 +41,21 @@ void CastStorageDispatch<cpu>(const OpContext& ctx,
mxnet::op::CastStorageComputeImpl<cpu>(ctx, input, output);
}

std::string stype_string(const int x) {
switch (x) {
case kDefaultStorage:
return "default";
case kCSRStorage:
return "csr";
case kRowSparseStorage:
return "row_sparse";
#if MXNET_USE_MKLDNN == 1
case kMKLDNNStorage:
return "mkldnn";
#endif
}
return "unknown";
}

} // namespace common
} // namespace mxnet
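As a hedged usage sketch (not part of this commit), the new out-of-line stype_string turns a storage-type enum into a readable name, for example in log or error messages; the include paths and the Report helper below are illustrative assumptions:

#include <dmlc/logging.h>
#include <mxnet/ndarray.h>   // NDArray and the storage-type constants
#include "common/utils.h"    // illustrative path; depends on the caller's location

void Report(const mxnet::NDArray &arr) {
  // Prints "default", "csr", "row_sparse", or, in MKLDNN builds, "mkldnn".
  LOG(INFO) << "storage type: " << mxnet::common::stype_string(arr.storage_type());
}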
12 changes: 1 addition & 11 deletions src/common/utils.h
Expand Up @@ -327,17 +327,7 @@ inline std::string dispatch_mode_string(const DispatchMode x) {


/*! \brief get string representation of storage_type */
inline std::string stype_string(const int x) {
switch (x) {
case kDefaultStorage:
return "default";
case kCSRStorage:
return "csr";
case kRowSparseStorage:
return "row_sparse";
}
return "unknown";
}
std::string stype_string(const int x);

// heuristic to determine number of threads per GPU
inline int GetNumThreadPerGPU() {
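Tying the two ends of this commit together, the sketch below shows how an operator might ask an NDArray for its data in a specific MKLDNN layout through the GetMKLDNNData overload declared in ndarray.h above. It is a hedged illustration only: it assumes the MKL-DNN 0.x C++ API in use at the time, and the ReadAsMKLDNN helper, the f32/nchw layout choice, and the eager stream submission are assumptions, not code from this commit.

#include <mkldnn.hpp>
#include <mxnet/ndarray.h>
#include <vector>

void ReadAsMKLDNN(const mxnet::NDArray &arr) {
  mkldnn::engine cpu_engine(mkldnn::engine::cpu, 0);
  // Describe the desired layout; illustrative only, assumes a 4-D fp32 tensor.
  mkldnn::memory::dims dims;
  for (size_t i = 0; i < arr.shape().ndim(); ++i)
    dims.push_back(arr.shape()[i]);
  mkldnn::memory::desc md(dims, mkldnn::memory::data_type::f32,
                          mkldnn::memory::format::nchw);
  mkldnn::memory::primitive_desc pd(md, cpu_engine);
  std::vector<mkldnn::primitive> net;
  // GetMKLDNNData may append a reorder primitive to net if the stored layout differs.
  std::shared_ptr<const mkldnn::memory> mem = arr.GetMKLDNNData(pd, net);
  mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
  // mem->get_data_handle() now points at the data in the requested layout.
}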
