add mkldnn support for concat #3

Merged
merged 12 commits on Dec 8, 2017
4 changes: 2 additions & 2 deletions src/common/utils.cc
@@ -50,8 +50,8 @@ std::string stype_string(const int x) {
case kRowSparseStorage:
return "row_sparse";
#if MXNET_USE_MKLDNN == 1
case kMKLDNNStorage:
return "mkldnn";
case kMKLDNNStorage:
return "mkldnn";
#endif
}
return "unknown";
6 changes: 5 additions & 1 deletion src/executor/graph_executor.cc
@@ -55,7 +55,11 @@ GraphExecutor::~GraphExecutor() {
}

inline bool SharableStorage(NDArrayStorageType stype) {
return stype == kDefaultStorage || stype == kMKLDNNStorage;
bool ret = stype == kDefaultStorage;
#if MXNET_USE_MKLDNN == 1
ret = ret || stype == kMKLDNNStorage;
#endif
return ret;
}
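For readers building without MKLDNN: the SharableStorage change above keeps kMKLDNNStorage behind the MXNET_USE_MKLDNN guard so the function still compiles when the flag is off. A minimal standalone sketch of that pattern (hypothetical enum and helper names, not part of this diff):

```cpp
#include <iostream>

#define MXNET_USE_MKLDNN 1  // set to 0 to simulate a build without MKLDNN

// Hypothetical stand-in for NDArrayStorageType: the MKLDNN value only exists
// when the feature is compiled in.
enum StorageType {
  kDefault,
  kRowSparse,
#if MXNET_USE_MKLDNN == 1
  kMKLDNN,
#endif
};

inline bool IsSharable(StorageType stype) {
  bool ret = (stype == kDefault);
#if MXNET_USE_MKLDNN == 1
  // Only referenced when MKLDNN is enabled, so non-MKLDNN builds still compile.
  ret = ret || (stype == kMKLDNN);
#endif
  return ret;
}

int main() {
  std::cout << IsSharable(kDefault) << ' ' << IsSharable(kRowSparse) << '\n';
}
```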

inline NDArray InitZeros(const NDArrayStorageType stype, const TShape &shape,
98 changes: 49 additions & 49 deletions src/ndarray/ndarray.cc
@@ -22,7 +22,6 @@
* \file ndarray.cc
* \brief ndarray module of mxnet
*/
#include <mkldnn.hpp>
#include <dmlc/io.h>
#include <dmlc/memory_io.h>
#include <dmlc/logging.h>
@@ -32,6 +31,7 @@
#include <mxnet/resource.h>
#include <mxnet/imperative.h>
#include <mshadow/tensor.h>
#include <mkldnn.hpp>
#include "./ndarray_function.h"
#include "../common/utils.h"
#include "../operator/tensor/matrix_op-inl.h"
@@ -48,10 +48,11 @@ DMLC_REGISTRY_ENABLE(::mxnet::NDArrayFunctionReg);

namespace mxnet {

static inline NDArrayStorageType DetermineSType(NDArrayStorageType stype, int dtype, const TShape &shape) {
static inline NDArrayStorageType DetermineSType(NDArrayStorageType stype,
int dtype, const TShape &shape) {
#if MXNET_USE_MKLDNN == 1
// We can't always generate a MKLDNN storage. If MKLDNN can't support the data type,
// we'll have to fall back to the default storage.
// We can't always generate a MKLDNN storage. If MKLDNN can't support
// the data type, we'll have to fall back to the default storage.
if (stype == kMKLDNNStorage && !SupportMKLDNNArray(dtype, shape))
return kDefaultStorage;
else
@@ -158,15 +159,14 @@ nnvm::Symbol NDArray::get_autograd_symbol() const {
#if MXNET_USE_MKLDNN == 1

static inline mkldnn_memory_format_t GetDefaultFormat(mkldnn::memory::desc desc) {
if (desc.data.ndims == 1)
if (desc.data.ndims == 1) {
return desc.data.format;
else if (desc.data.ndims == 2) {
} else if (desc.data.ndims == 2) {
if (desc.data.format == mkldnn_io)
return mkldnn_oi;
else
return desc.data.format;
}
else if (desc.data.ndims == 4) {
} else if (desc.data.ndims == 4) {
switch (desc.data.format) {
case mkldnn_nchw:
case mkldnn_nhwc:
@@ -194,8 +194,7 @@ static inline mkldnn_memory_format_t GetDefaultFormat(mkldnn::memory::desc desc)
LOG(FATAL) << "Unknown MKLDNN format for 4 dimensions: " << desc.data.format;
return mkldnn_format_undef;
}
}
else if (desc.data.ndims == 5) {
} else if (desc.data.ndims == 5) {
switch (desc.data.format) {
case mkldnn_goihw:
case mkldnn_gOIhw8i8o:
@@ -215,8 +214,7 @@ static inline mkldnn_memory_format_t GetDefaultFormat(mkldnn::memory::desc desc)
LOG(FATAL) << "Unknown MKLDNN format for 4 dimensions: " << desc.data.format;
return mkldnn_format_undef;
}
}
else {
} else {
LOG(FATAL) << "Unsupported dimensions: " << desc.data.ndims;
return mkldnn_format_undef;
}
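A rough illustration of what GetDefaultFormat() above is meant to produce: blocked MKLDNN layouts collapse to the plain layout for their rank, and other ranks are rejected. The strings below are illustrative only; the real function returns mkldnn_memory_format_t values (sketch, not part of this diff):

```cpp
#include <initializer_list>
#include <iostream>
#include <string>

// Plain-layout name expected for each supported rank; e.g. nChw8c or
// OIhw16i16o (4-D blocked layouts) map back to nchw/oihw.
std::string DefaultFormatForRank(int ndims) {
  switch (ndims) {
    case 1: return "x";
    case 2: return "nc (data) or oi (weights)";
    case 4: return "nchw (data) or oihw (weights)";
    case 5: return "goihw (grouped weights)";
    default: return "unsupported";
  }
}

int main() {
  for (int nd : {1, 2, 3, 4, 5})
    std::cout << nd << "-D -> " << DefaultFormatForRank(nd) << '\n';
}
```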
@@ -287,9 +285,9 @@ NDArray NDArray::Reshape(const TShape &shape) const {
auto def_format = GetDefaultFormat(this->ptr_->Mkl_mem_->get_primitive_desc().desc());
if (this->ptr_->Mkl_mem_->get_primitive_desc().desc().data.format != def_format) {
ret.ptr_->Mkl_mem_ = Reorder2Default(this->ptr_->Mkl_mem_);
}
else
} else {
ret.ptr_->Mkl_mem_ = this->ptr_->Mkl_mem_;
}
}
}, ctx(), {this->var()}, {ret.var()},
FnProperty::kNormal, 0, PROFILER_MESSAGE("SyncMKLDNN2Default"));
@@ -340,8 +338,7 @@ NDArray NDArray::Slice(index_t begin, index_t end) const {
auto def_format = GetDefaultFormat(this->ptr_->Mkl_mem_->get_primitive_desc().desc());
if (this->ptr_->Mkl_mem_->get_primitive_desc().desc().data.format != def_format) {
ret.ptr_->Mkl_mem_ = Reorder2Default(this->ptr_->Mkl_mem_);
}
else {
} else {
ret.ptr_->Mkl_mem_ = this->ptr_->Mkl_mem_;
}
}, ctx(), {this->var()}, {ret.var()},
@@ -376,11 +373,13 @@ NDArray NDArray::SliceWithRecord(index_t begin, index_t end) {
}

NDArray NDArray::At(index_t idx) const {
CHECK(storage_type() == kDefaultStorage
#if MXNET_USE_MKLDNN == 1
|| storage_type() == kMKLDNNStorage
CHECK(storage_type() == kDefaultStorage
|| storage_type() == kMKLDNNStorage)
#else
CHECK(storage_type() == kDefaultStorage)
#endif
) << "Storage type " << storage_type() << " doesn't support At()";
<< "Storage type " << storage_type() << " doesn't support At()";
NDArray ret = this->Slice(idx, idx+1);
if (shape_.ndim() > 1) {
return ret.Reshape(TShape(shape_.data()+1, shape_.data()+shape_.ndim()));
@@ -390,11 +389,13 @@ NDArray NDArray::At(index_t idx) const {
}

NDArray NDArray::AtWithRecord(index_t idx) {
CHECK(storage_type() == kDefaultStorage
#if MXNET_USE_MKLDNN == 1
|| storage_type() == kMKLDNNStorage
CHECK(storage_type() == kDefaultStorage
|| storage_type() == kMKLDNNStorage)
#else
CHECK(storage_type() == kDefaultStorage)
#endif
) << "Storage type " << storage_type() << " doesn't support At()";
<< "Storage type " << storage_type() << " doesn't support At()";
NDArray ret = this->SliceWithRecord(idx, idx+1);
if (shape_.ndim() > 1) {
return ret.ReshapeWithRecord(TShape(shape_.data()+1, shape_.data()+shape_.ndim()));
@@ -450,7 +451,7 @@ void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) {
// The shape of the array and the one of the MKL memory may mismatch.
// For example, if the array stores parameters, the MKL memory may store data
// in 5 dimensions while the NDArray stores data in 4 dimensions.
// TODO is it possible that the MKL memory is out-of-date?
// TODO(zhengda) is it possible that the MKL memory is out-of-date?
if (Mkl_mem_ && storage_type == kMKLDNNStorage) {
return;
}
@@ -462,22 +463,21 @@ void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) {
dims.resize(shape.ndim());
for (size_t i = 0; i < dims.size(); i++)
dims[i] = shape[i];
}
// If there are 3 dimensions, we'll force it to 4 dimensions.
else if (shape.ndim() == 3) {
} else if (shape.ndim() == 3) {
// If there are 3 dimensions, we'll force it to 4 dimensions.
dims.resize(shape.ndim() + 1);
dims[0] = 1;
for (size_t i = 0; i < shape.ndim(); i++)
dims[i + 1] = shape[i];
}
else
} else {
LOG(FATAL) << "MKLDNN doesn't support " << shape.ndim() << " dimensions";
}
mkldnn::memory::format layout = mkldnn::memory::format::format_undef;
switch (dims.size()) {
case 1: layout = mkldnn::memory::format::x; break;
case 2: layout = mkldnn::memory::format::nc; break;
case 4: layout = mkldnn::memory::format::nchw; break;
// TODO This isn't the right layout when the data has 5 dimensions in MXNet.
// This isn't the right layout when the data has 5 dimensions in MXNet.
// MXNet interprets 5 dimensions as ncdhw, but MKLDNN doesn't have
// a corresponding format.
case 5: layout = mkldnn::memory::format::goihw; break;
@@ -491,9 +491,8 @@ void NDArray::Chunk::SetMKLMem(const TShape &shape, int dtype) {
CheckAndAlloc();
Mkl_mem_.reset(new mkldnn::memory(mkldnn::memory::primitive_desc(data_md,
cpu_engine), shandle.dptr));
}
// If the array uses MKLDNN storage, we need to allocate memory here.
else if (storage_type == kMKLDNNStorage) {
} else if (storage_type == kMKLDNNStorage) {
// If the array uses MKLDNN storage, we need to allocate memory here.
Mkl_mem_.reset(new mkldnn::memory(mkldnn::memory::primitive_desc(data_md,
cpu_engine)));
}
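The shape handling in SetMKLMem above pads a 3-D MXNet shape to 4-D because no generic 3-D MKLDNN layout is used here; other supported ranks pass through, and 5-D falls into the goihw case. A standalone sketch of just the dimension handling (hypothetical helper name, not part of this diff):

```cpp
#include <cstddef>
#include <initializer_list>
#include <iostream>
#include <vector>

// Convert an MXNet-style shape to the dims MKLDNN would be given: 3-D shapes
// get a leading 1 prepended; other supported ranks are used as-is.
std::vector<int> ToMKLDNNDims(const std::vector<int> &shape) {
  if (shape.size() == 1 || shape.size() == 2 ||
      shape.size() == 4 || shape.size() == 5) {
    return shape;
  } else if (shape.size() == 3) {
    std::vector<int> dims(shape.size() + 1);
    dims[0] = 1;  // prepend a batch-like dimension
    for (size_t i = 0; i < shape.size(); i++)
      dims[i + 1] = shape[i];
    return dims;
  }
  return {};  // unsupported rank
}

int main() {
  for (int d : ToMKLDNNDims({3, 224, 224}))
    std::cout << d << ' ';  // prints: 1 3 224 224
  std::cout << '\n';
}
```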
@@ -528,9 +527,9 @@ std::shared_ptr<const mkldnn::memory> NDArray::GetMKLDNNData(
mkldnn_mem_ptr ret(new mkldnn::memory(desc, ptr_->Mkl_mem_->get_data_handle()));
MKLDNNStream::Instance().RegisterMem(ret);
return ret;
}
else
} else {
return nullptr;
}
}

std::shared_ptr<const mkldnn::memory> NDArray::GetMKLDNNDataReorder(
@@ -557,17 +556,15 @@ std::shared_ptr<const mkldnn::memory> NDArray::GetMKLDNNDataReorder(
mkldnn::memory::primitive_desc _desc = desc;
// Now we need to determine if we should reorder the memory.
// If both use the default formats, we think we don't need to reshape.
// TODO if the memory format isn't the default one, it may not work.
auto desc1 = ptr_->Mkl_mem_->get_primitive_desc().desc();
auto desc2 = _desc.desc();
if (desc1.data.format == GetDefaultFormat(desc1) &&
if (desc1.data.format == GetDefaultFormat(desc1) &&
desc2.data.format == GetDefaultFormat(desc2)) {
mkldnn_mem_ptr ret(new mkldnn::memory(desc, ptr_->Mkl_mem_->get_data_handle()));
stream.RegisterMem(ret);
return ret;
}
else {
// TODO we should manage the memory allocation here.
} else {
// TODO(zhengda) we should manage the memory allocation here.
mkldnn_mem_ptr ret(new mkldnn::memory(desc));
stream.RegisterMem(ret);
stream.RegisterPrim(mkldnn::reorder(*ptr_->Mkl_mem_, *ret));
@@ -576,14 +573,15 @@
}
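The reorder-vs-wrap decision in GetMKLDNNDataReorder above can be exercised on its own. A standalone sketch assuming the MKLDNN 0.x C++ API this PR builds against (not part of the diff; the NDArray code registers the primitive on an MKLDNNStream rather than submitting it eagerly):

```cpp
#include <mkldnn.hpp>
#include <vector>

int main() {
  using namespace mkldnn;
  engine cpu_engine(engine::cpu, 0);
  memory::dims dims = {1, 8, 4, 4};

  // Source in the plain nchw layout, requested layout is blocked nChw8c.
  memory::desc src_md(dims, memory::data_type::f32, memory::format::nchw);
  memory::desc dst_md(dims, memory::data_type::f32, memory::format::nChw8c);
  memory::primitive_desc src_pd(src_md, cpu_engine);
  memory::primitive_desc dst_pd(dst_md, cpu_engine);
  memory src(src_pd);  // MKLDNN allocates the buffers here
  memory dst(dst_pd);

  // The layouts differ, so a reorder primitive is needed; if they matched,
  // the existing handle could simply be wrapped in a new memory object.
  std::vector<primitive> net;
  net.push_back(reorder(src, dst));
  stream(stream::kind::eager).submit(net).wait();
  return 0;
}
```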

std::shared_ptr<const mkldnn::memory> NDArray::GetMKLDNNData() const {
CHECK(storage_type() == kMKLDNNStorage || storage_type() == kDefaultStorage);
ptr_->SetMKLMem(shape_, dtype_);
if (ptr_->Mkl_mem_) {
MKLDNNStream::Instance().RegisterMem(ptr_->Mkl_mem_);
return ptr_->Mkl_mem_;
}
else
// TODO We don't support converting sparse format.
} else {
// We don't support converting sparse format.
return nullptr;
}
}

void NDArray::CopyFrom(const mkldnn::memory &mem) {
@@ -607,18 +605,20 @@ void NDArray::CopyFrom(const mkldnn::memory &mem) {
if (!same_shape(shape_, from_desc.data.dims, from_desc.data.ndims)) {
// In this case, we can simply create a new MKLDNN memory for the required
// shape.
// TODO let's just hope it's the default format for now.
// TODO(zhengda) let's just hope it's the default format for now.
CHECK_EQ(GetDefaultFormat(from_desc), from_desc.data.format);
mkldnn::memory::dims dims(this_desc.data.dims, this_desc.data.dims + this_desc.data.ndims);
mkldnn::memory::desc data_md(dims, static_cast<mkldnn::memory::data_type>(this_desc.data.data_type),
static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc)));
mkldnn::memory::dims dims(this_desc.data.dims,
this_desc.data.dims + this_desc.data.ndims);
auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
mkldnn::memory::desc data_md(dims, this_dtype, this_format);
mkldnn::memory::primitive_desc pd(data_md, mem.get_primitive_desc().get_engine());
mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
stream.RegisterMem(tmp_mem);
stream.RegisterPrim(mkldnn::reorder(*tmp_mem, *ptr_->Mkl_mem_));
}
else
} else {
stream.RegisterPrim(mkldnn::reorder(mem, *ptr_->Mkl_mem_));
}
}

std::shared_ptr<mkldnn::memory> NDArray::CreateMKLDNNData(
@@ -668,7 +668,7 @@ void NDArray::SetTBlob() const {
ptr_->Mkl_mem_ = Reorder2Default(ptr_->Mkl_mem_);
else
ptr_->SetMKLMem(shape_, dtype_);
dptr = (char *) ptr_->Mkl_mem_->get_data_handle();
dptr = static_cast<char *>(ptr_->Mkl_mem_->get_data_handle());
#endif
} else {
LOG(FATAL) << "unknown storage type " << stype;
6 changes: 3 additions & 3 deletions src/operator/nn/concat-inl.h
@@ -23,8 +23,8 @@
* \brief
* \author Bing Xu
*/
#ifndef MXNET_OPERATOR_CONCAT_INL_H_
#define MXNET_OPERATOR_CONCAT_INL_H_
#ifndef MXNET_OPERATOR_NN_CONCAT_INL_H_
#define MXNET_OPERATOR_NN_CONCAT_INL_H_
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
@@ -156,4 +156,4 @@ void ConcatGradCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx,
} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_CONCAT_INL_H_
#endif // MXNET_OPERATOR_NN_CONCAT_INL_H_
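The guard rename at the top and bottom of concat-inl.h updates the include guard to mirror the header's current path under src/operator/nn/. A generic sketch of that convention (hypothetical file name, not part of this diff):

```cpp
// File: src/operator/nn/example-inl.h (hypothetical)
#ifndef MXNET_OPERATOR_NN_EXAMPLE_INL_H_
#define MXNET_OPERATOR_NN_EXAMPLE_INL_H_

// ... declarations for the operator would go here ...

#endif  // MXNET_OPERATOR_NN_EXAMPLE_INL_H_
```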