This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit: fix build
TaoLv committed Sep 2, 2019
1 parent 06d98fe commit b27f4b9
Showing 3 changed files with 112 additions and 95 deletions.
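The commit moves these files from the MKL-DNN 0.x API (guarded by MXNET_USE_MKLDNN == 1) to the 1.0 API (guarded by MXNET_USE_MKLDNN == 100): memory::primitive_desc disappears, a memory object is now built from a memory::desc plus an engine, the format enum becomes format_tag, and primitives receive their inputs and outputs through an explicit argument map. The following is a minimal, standalone sketch of the 1.0-style memory creation the patch switches to; the shape, data type, and names are illustrative and not taken from the patch.

    #include <mkldnn.hpp>

    int main() {
      // CPU engine; inside MXNet this comes from CpuEngine::Get()->get_engine().
      mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);

      // MKL-DNN 1.0 drops memory::primitive_desc: a memory::desc plus an
      // engine is enough to construct a memory object.
      mkldnn::memory::desc md({8, 3, 224, 224},
                              mkldnn::memory::data_type::f32,
                              mkldnn::memory::format_tag::nchw);

      mkldnn::memory owned(md, eng);              // library-allocated buffer
      // mkldnn::memory view(md, eng, user_ptr);  // or wrap an existing buffer
      return 0;
    }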
155 changes: 80 additions & 75 deletions src/operator/nn/mkldnn/mkldnn_base.cc
@@ -17,7 +17,7 @@
* under the License.
*/

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100

#include <atomic>
#include "./mkldnn_base-inl.h"
@@ -54,113 +54,116 @@ void *AlignMem(void *mem, size_t size, size_t alignment, size_t *space) {
return reinterpret_cast<void *>(addr);
}

mkldnn::memory *TmpMemMgr::Alloc(const mkldnn::memory::primitive_desc &pd) {
mkldnn::memory *TmpMemMgr::Alloc(const mkldnn::memory::desc &md) {
// We need to include the size of the memory used for alignment.
this->est_size += pd.get_size() + alignment;
void *mem = AlignMem(this->curr_mem, pd.get_size(), alignment, &this->curr_size);
this->est_size += md.get_size() + alignment;
void *mem = AlignMem(this->curr_mem, md.get_size(), alignment, &this->curr_size);
if (mem) {
// The memory is allocated from the temporary memory space in the
// operator. It'll only become invalid after we exit from the operator.
mkldnn_mem_ptr ret(new mkldnn::memory(pd, mem));
mkldnn_mem_ptr ret(new mkldnn::memory(md, CpuEngine::Get()->get_engine(), mem));
MKLDNNStream::Get()->RegisterMem(ret);
CHECK_EQ(mem, mem);
this->curr_size -= pd.get_size();
this->curr_mem = static_cast<char *>(mem) + pd.get_size();
this->curr_size -= md.get_size();
this->curr_mem = static_cast<char *>(mem) + md.get_size();
return ret.get();
} else {
// If curr_mem has been initialized and we still reach here, it means
// the currently allocated memory isn't enough.
if (this->curr_mem && dmlc::GetEnv("MXNET_MKLDNN_DEBUG", false)) {
LOG(WARNING) << "Allocate " << pd.get_size()
LOG(WARNING) << "Allocate " << md.get_size()
<< " bytes with malloc directly";
}
mkldnn_mem_ptr ret(new mkldnn::memory(pd));
mkldnn_mem_ptr ret(new mkldnn::memory(md, CpuEngine::Get()->get_engine()));
MKLDNNStream::Get()->RegisterMem(ret);
return ret.get();
}
}

void MKLDNNCopy(const mkldnn::memory &mem, const mkldnn::memory* this_mem) {
MKLDNNStream *stream = MKLDNNStream::Get();
mkldnn::memory::desc from_desc = mem.get_desc();
mkldnn::memory::desc this_desc = this_mem->get_desc();
mkldnn_format_tag_t from_def_format = GetDefaultFormat(from_desc);
mkldnn_format_tag_t this_def_format = GetDefaultFormat(this_desc);

mkldnn::memory::primitive_desc from_pd = mem.get_primitive_desc();
mkldnn::memory::desc from_desc = from_pd.desc();
mkldnn::memory::primitive_desc this_pd = this_mem->get_primitive_desc();
mkldnn::memory::desc this_desc = this_pd.desc();
mkldnn_memory_format_t from_def_format = GetDefaultFormat(from_desc);
mkldnn_memory_format_t this_def_format = GetDefaultFormat(this_desc);
// It's possible that the memory and the NDArray don't have the same shape.
if (!same_shape(this_desc, from_desc)
// If the source memory uses the default layout, we can reshape directly.
&& from_def_format == from_desc.data.format) {
if (!same_shape(this_desc, from_desc) && IsDefaultFormat(from_desc)) {
// In this case, we can simply create a new MKLDNN memory for the required
// shape.
mkldnn::memory::dims dims(this_desc.data.dims,
this_desc.data.dims + this_desc.data.ndims);
auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
mkldnn::memory::desc data_md(dims, this_dtype, this_format);
mkldnn::memory::primitive_desc pd(data_md, from_pd.get_engine());
mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
mkldnn::memory::desc data_md(dims, this_dtype,
static_cast<mkldnn::memory::format_tag>(this_def_format));

mkldnn_mem_ptr tmp_mem(new mkldnn::memory(data_md, mem.get_engine(), mem.get_data_handle()));
stream->RegisterMem(tmp_mem);
stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
std::unordered_map<int, mkldnn::memory> args({{MKLDNN_ARG_FROM, *tmp_mem},
{MKLDNN_ARG_TO, *this_mem}});
stream->RegisterPrimArgs(mkldnn::reorder(*tmp_mem, *this_mem), args);
} else if (!same_shape(this_desc, from_desc)) {
// In this case, the source memory stores data in a customized layout. We
// need to reorganize the data in memory before we can reshape.
mkldnn::memory::primitive_desc def_pd = GetPrimitiveDesc(from_pd, from_def_format);
mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_pd);
stream->RegisterPrim(mkldnn::reorder(mem, *def_mem));
mkldnn::memory::desc def_desc = GetDesc(from_desc, from_def_format);
mkldnn::memory *def_mem = TmpMemMgr::Get()->Alloc(def_desc);
std::unordered_map<int, mkldnn::memory> args({{MKLDNN_ARG_FROM, mem},
{MKLDNN_ARG_TO, *def_mem}});
stream->RegisterPrimArgs(mkldnn::reorder(mem, *def_mem), args);

// Now we can reshape it
mkldnn::memory::dims dims(this_desc.data.dims,
this_desc.data.dims + this_desc.data.ndims);
auto this_dtype = static_cast<mkldnn::memory::data_type>(this_desc.data.data_type);
auto this_format = static_cast<mkldnn::memory::format>(GetDefaultFormat(this_desc));
mkldnn::memory::desc data_md(dims, this_dtype, this_format);
mkldnn::memory::primitive_desc pd(data_md, from_pd.get_engine());
mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, def_mem->get_data_handle()));
mkldnn_mem_ptr tmp_mem(new mkldnn::memory(this_desc,
mem.get_engine(), def_mem->get_data_handle()));
stream->RegisterMem(tmp_mem);
stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
} else if (from_pd == this_pd) {
args = {{MKLDNN_ARG_FROM, *tmp_mem}, {MKLDNN_ARG_TO, *this_mem}};
stream->RegisterPrimArgs(mkldnn::reorder(*tmp_mem, *this_mem), args);
} else if (this_desc == from_desc) {
std::unordered_map<int, mkldnn::memory> args({{MKLDNN_ARG_FROM, mem},
{MKLDNN_ARG_TO, *this_mem}});
// If the layout is the same, we can just copy data.
stream->RegisterPrim(mkldnn::reorder(mem, *this_mem));
} else {
stream->RegisterPrimArgs(mkldnn::reorder(mem, *this_mem), args);
} else {
// If neither memory uses the default layout, there isn't much we can do
// other than reorder the data layout directly.
if (this_def_format != this_desc.data.format
&& from_def_format != from_desc.data.format) {
stream->RegisterPrim(mkldnn::reorder(mem, *this_mem));
} else if (this_def_format == this_desc.data.format) {
if (!IsDefaultFormat(this_desc) && !IsDefaultFormat(from_desc)) {
std::unordered_map<int, mkldnn::memory> args({{MKLDNN_ARG_FROM, mem},
{MKLDNN_ARG_TO, *this_mem}});
stream->RegisterPrimArgs(mkldnn::reorder(mem, *this_mem), args);
} else if (IsDefaultFormat(this_desc)) {
// If the dest mem uses the default memory layout, we can simply use
// the default format of the source memory to improve perf of reorder.
mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(from_pd,
from_def_format);
mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, this_mem->get_data_handle()));
mkldnn::memory::desc desc = GetDesc(from_desc, from_def_format);
mkldnn_mem_ptr tmp_mem(new mkldnn::memory(desc,
mem.get_engine(), this_mem->get_data_handle()));
stream->RegisterMem(tmp_mem);
stream->RegisterPrim(mkldnn::reorder(mem, *tmp_mem));
std::unordered_map<int, mkldnn::memory> args({{MKLDNN_ARG_FROM, mem},
{MKLDNN_ARG_TO, *tmp_mem}});
stream->RegisterPrimArgs(mkldnn::reorder(mem, *tmp_mem), args);
} else {
// If the src mem uses the default memory layout, we can use
// the default format of the source memory to improve perf.
mkldnn::memory::primitive_desc pd = GetPrimitiveDesc(this_pd,
this_def_format);
mkldnn_mem_ptr tmp_mem(new mkldnn::memory(pd, mem.get_data_handle()));
mkldnn::memory::desc desc = GetDesc(this_desc, this_def_format);
mkldnn_mem_ptr tmp_mem(new mkldnn::memory(desc,
this_mem->get_engine(), mem.get_data_handle()));
stream->RegisterMem(tmp_mem);
stream->RegisterPrim(mkldnn::reorder(*tmp_mem, *this_mem));
std::unordered_map<int, mkldnn::memory> args({{MKLDNN_ARG_FROM, *tmp_mem},
{MKLDNN_ARG_TO, *this_mem}});
stream->RegisterPrimArgs(mkldnn::reorder(*tmp_mem, *this_mem), args);
}
}
}

bool CanWriteTo(const NDArray &out_arr,
const NDArray &in_arr,
const mkldnn::memory::primitive_desc &desc) {
const mkldnn::memory::desc &desc) {
auto in_mem = in_arr.GetMKLDNNData();
bool add_same = in_mem->get_data_handle() == out_arr.GetMKLDNNData()->get_data_handle();
bool pdesc_same = out_arr.GetMKLDNNData()->get_primitive_desc() == desc &&
in_mem->get_primitive_desc() == desc;
bool pdesc_same = out_arr.GetMKLDNNData()->get_desc() == desc &&
in_mem->get_desc() == desc;
return add_same && pdesc_same;
}

mkldnn_output_t CreateMKLDNNMem(const NDArray &out_arr,
const mkldnn::memory::primitive_desc &desc,
const mkldnn::memory::desc &desc,
OpReqType req,
const NDArray* in_arr) {
if (kAddTo == req) {
@@ -188,7 +191,7 @@ mkldnn_output_t CreateMKLDNNMem(const NDArray &out_arr,
}

mkldnn_output_t CreateMKLDNNWeightGrad(const NDArray &out_arr,
const mkldnn::memory::primitive_desc &desc,
const mkldnn::memory::desc &desc,
OpReqType req) {
if (kAddTo == req) {
auto tmp = TmpMemMgr::Get()->Alloc(desc);
@@ -197,10 +200,8 @@ mkldnn_output_t CreateMKLDNNWeightGrad(const NDArray &out_arr,
auto tmp = TmpMemMgr::Get()->Alloc(desc);
return mkldnn_output_t(OutDataOp::CopyBack, tmp);
} else {
auto _desc = desc;
auto def_format = GetDefaultFormat(_desc.desc());
mkldnn::memory *mem = nullptr;
if (def_format == _desc.desc().data.format) {
if (IsDefaultFormat(desc)) {
mem = const_cast<NDArray &>(out_arr).CreateMKLDNNData(desc);
}
if (mem == nullptr) {
Expand All @@ -217,8 +218,8 @@ void CommitOutput(const NDArray &arr, const mkldnn_output_t &res) {
const_cast<NDArray &>(arr).CopyFrom(*res.second);
} else if (res.first == AddBack) {
auto res_memory = res.second;
auto target_pd = arr.GetMKLDNNData()->get_primitive_desc();
auto mem = arr.GetMKLDNNData(res.second->get_primitive_desc());
auto target_pd = arr.GetMKLDNNData()->get_desc();
auto mem = arr.GetMKLDNNData(res.second->get_desc());
if (mem == nullptr) {
auto tmp_memory = TmpMemMgr::Get()->Alloc(target_pd);
MKLDNNCopy(*res_memory, tmp_memory);
@@ -232,12 +233,12 @@ void CommitOutput(const NDArray &arr, const mkldnn_output_t &res) {
const mkldnn::memory *GetWeights(const NDArray &arr, int num_groups) {
const auto type = get_mkldnn_type(arr.dtype());
auto tz = mkldnn::memory::dims{0};
auto format = mkldnn::memory::format::format_undef;
auto format_tag = mkldnn::memory::format_tag::undef;
auto engine = CpuEngine::Get()->get_engine();
const int O = 0, I = 1, H = 2, W = 3;
if (arr.shape().ndim() == 2) {
tz = mkldnn::memory::dims{static_cast<int>(arr.shape()[O]), static_cast<int>(arr.shape()[I])};
format = mkldnn::memory::format::oi;
format_tag = mkldnn::memory::format_tag::oi;
} else if (arr.shape().ndim() == 3) {
tz = num_groups > 1
? mkldnn::memory::dims{num_groups, static_cast<int>(arr.shape()[O] / num_groups),
@@ -246,7 +247,8 @@ const mkldnn::memory *GetWeights(const NDArray &arr, int num_groups) {
: mkldnn::memory::dims{static_cast<int>(arr.shape()[O]),
static_cast<int>(arr.shape()[I]),
static_cast<int>(arr.shape()[H])};
format = num_groups > 1 ? mkldnn::memory::format::goiw : mkldnn::memory::format::oiw;
format_tag = num_groups > 1 ? mkldnn::memory::format_tag::goiw
: mkldnn::memory::format_tag::oiw;
} else if (arr.shape().ndim() == 4) {
tz = num_groups > 1
? mkldnn::memory::dims{num_groups, static_cast<int>(arr.shape()[O] / num_groups),
@@ -256,26 +258,28 @@ const mkldnn::memory *GetWeights(const NDArray &arr, int num_groups) {
: mkldnn::memory::dims{
static_cast<int>(arr.shape()[O]), static_cast<int>(arr.shape()[I]),
static_cast<int>(arr.shape()[H]), static_cast<int>(arr.shape()[W])};
format = num_groups > 1 ? mkldnn::memory::format::goihw : mkldnn::memory::format::oihw;
format_tag = num_groups > 1 ? mkldnn::memory::format_tag::goihw
: mkldnn::memory::format_tag::oihw;
} else {
LOG(FATAL) << "The weight array has an unsupported number of dimensions";
}
const auto md = mkldnn::memory::desc{tz, type, format};
const auto pd = mkldnn::memory::primitive_desc{md, engine};
return arr.GetMKLDNNData(pd);
const auto md = mkldnn::memory::desc{tz, type, format_tag};
return arr.GetMKLDNNData(md);
}

const mkldnn::memory *GetWeights(const NDArray &arr,
const mkldnn::memory::primitive_desc &target_pd, int num_groups) {
const mkldnn::memory *mem = arr.GetMKLDNNData(target_pd);
const mkldnn::memory::desc &target_desc, int num_groups) {
const mkldnn::memory *mem = arr.GetMKLDNNData(target_desc);
// If the weight array already uses the target layout, simply return it directly.
if (mem) return mem;
mem = GetWeights(arr, num_groups);
if (mem == nullptr) mem = arr.GetMKLDNNDataReorder(target_pd);
if (mem->get_primitive_desc() == target_pd) return mem;
if (mem == nullptr) mem = arr.GetMKLDNNDataReorder(target_desc);
if (mem->get_desc() == target_desc) return mem;

auto ret = TmpMemMgr::Get()->Alloc(target_pd);
MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(*mem, *ret));
auto ret = TmpMemMgr::Get()->Alloc(target_desc);
std::unordered_map<int, mkldnn::memory> args({{MKLDNN_ARG_FROM, *mem},
{MKLDNN_ARG_TO, *ret}});
MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::reorder(*mem, *ret), args);
return ret;
}

@@ -437,10 +441,11 @@ static bool SimilarArray(const mxnet::NDArray &arr1, const mxnet::NDArray &arr2,
std::atomic<bool> success(true);
#pragma omp parallel for
#ifdef _MSC_VER
for (int64_t i = 0; i < arr1.shape().Size(); i++) {
for (int64_t i = 0; i < arr1.shape().Size(); i++)
#else
for (size_t i = 0; i < arr1.shape().Size(); i++) {
for (size_t i = 0; i < arr1.shape().Size(); i++)
#endif
{
if (std::abs(data1[i] - data2[i]) > atol + rtol * std::abs(data2[i]))
success.store(false);
}
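In the 1.0 API a primitive no longer owns its inputs and outputs; they are supplied at execution time through a map from argument kinds to memories, which is why the RegisterPrim() calls above become RegisterPrimArgs() with an unordered_map of MKLDNN_ARG_FROM / MKLDNN_ARG_TO entries. Below is a standalone sketch of that reorder pattern run directly against a 1.0 stream; MXNet instead defers execution until its MKLDNNStream is submitted, and the shapes and buffers here are invented for illustration.

    #include <unordered_map>
    #include <mkldnn.hpp>

    // Reorder src_buf (nchw) into dst_buf (nhwc) using the explicit argument
    // map required by MKL-DNN 1.0 primitives. Shapes are illustrative.
    void reorder_example(mkldnn::engine &eng, mkldnn::stream &strm,
                         void *src_buf, void *dst_buf) {
      mkldnn::memory::dims dims = {1, 16, 8, 8};
      mkldnn::memory src({dims, mkldnn::memory::data_type::f32,
                          mkldnn::memory::format_tag::nchw}, eng, src_buf);
      mkldnn::memory dst({dims, mkldnn::memory::data_type::f32,
                          mkldnn::memory::format_tag::nhwc}, eng, dst_buf);

      mkldnn::reorder prim(src, dst);
      std::unordered_map<int, mkldnn::memory> args =
          {{MKLDNN_ARG_FROM, src}, {MKLDNN_ARG_TO, dst}};
      prim.execute(strm, args);  // arguments are bound only at execution time
      strm.wait();
    }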
16 changes: 11 additions & 5 deletions src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -26,7 +26,6 @@
#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_OPS_INL_H_
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_OPS_INL_H_

#if MXNET_USE_MKLDNN == 1

#include <mxnet/io.h>
#include <mxnet/base.h>
@@ -36,11 +35,15 @@
#include <dmlc/logging.h>
#include <dmlc/optional.h>
#include <vector>

#if MXNET_USE_MKLDNN == 100
#include <mkldnn.hpp>
#endif

namespace mxnet {
namespace op {

#if MXNET_USE_MKLDNN == 1
/* For fully connected. */
void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &in_data,
@@ -110,9 +113,6 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
const NDArray &out_grad, const NDArray &in_data,
const OpReqType &req, const NDArray &in_grad);

void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
const mkldnn::memory &out);

void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const NDArray &data,
@@ -130,8 +130,14 @@ void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
const NDArray &input,
const OpReqType &req,
const NDArray &output);
#endif

#if MXNET_USE_MKLDNN == 100
void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
const mkldnn::memory &out);
#endif

} // namespace op
} // namespace mxnet
#endif // MXNET_USE_MKLDNN == 1

#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_OPS_INL_H_
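In the header, <mkldnn.hpp> is now included and MKLDNNSum is declared only when MXNET_USE_MKLDNN == 100, so the 1.0-only symbol cannot leak into 0.x builds, while the remaining operator declarations stay behind the == 1 guard. The fragment below is a hypothetical call site for that declaration; it assumes, as elsewhere in MXNet's MKL-DNN code, that queued primitives run when MKLDNNStream::Get()->Submit() is called, and the include paths, shapes, and buffers are illustrative.

    #if MXNET_USE_MKLDNN == 100
    #include <mkldnn.hpp>
    #include "./mkldnn_base-inl.h"  // CpuEngine, MKLDNNStream (hypothetical path)
    #include "./mkldnn_ops-inl.h"   // MKLDNNSum declaration

    // Accumulate two memories with identical descriptors into out.
    void sum_example(void *a, void *b, void *o) {
      mkldnn::engine eng = mxnet::CpuEngine::Get()->get_engine();
      mkldnn::memory::desc md({64, 128}, mkldnn::memory::data_type::f32,
                              mkldnn::memory::format_tag::oi);
      mkldnn::memory lhs(md, eng, a), rhs(md, eng, b), out(md, eng, o);

      mxnet::op::MKLDNNSum(lhs, rhs, out);   // queues a sum primitive
      mxnet::MKLDNNStream::Get()->Submit();  // assumption: drains the queue
    }
    #endif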