Optimize transpose operator with MKL-DNN (apache#14545)
* add mkldnn transpose

* general transpose

* support mkldnn format

* fix lint

* address comments

* add unit test

* add comments

* retrigger CI
TaoLv authored and larroy committed Apr 15, 2019
1 parent 69bc3a6 commit 2c17660
Showing 6 changed files with 233 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -175,12 +175,14 @@ struct ConvolutionParam;
struct DeconvolutionParam;
struct SoftmaxParam;
struct SoftmaxOutputParam;
struct TransposeParam;
bool SupportMKLDNNAct(const ActivationParam& param);
bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input);
bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input);
bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input);
bool SupportMKLDNNSoftmax(const SoftmaxParam& param);
bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
} // namespace op

static int GetTypeSize(int dtype) {
6 changes: 6 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -113,6 +113,12 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
const mkldnn::memory &out);

void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const NDArray &data,
const OpReqType &req,
const NDArray &output);

} // namespace op
} // namespace mxnet
#endif // MXNET_USE_MKLDNN == 1
161 changes: 161 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_transpose.cc
@@ -0,0 +1,161 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file mkldnn_transpose.cc
* \brief Implement transpose operator via MKL-DNN reorder primitive
* \author Tao Lv
*/

#if MXNET_USE_MKLDNN == 1

#include <mkldnn.hpp>
#include "../../tensor/matrix_op-inl.h"

namespace mxnet {
namespace op {

bool SupportMKLDNNTranspose(const TransposeParam& param,
const NDArray &data) {
auto data_ndim = data.shape().ndim();

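  // The MKL-DNN reorder path handles at most 4 dimensions and only float32;
  // anything else falls back to the default CPU implementation.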
if (data_ndim > 4 || data.dtype() != mshadow::kFloat32)
return false;

return true;
}

typedef ParamOpSign<TransposeParam> MKLDNNTransposeSignature;

class MKLDNNTransposeForward {
std::shared_ptr<mkldnn::memory> data_;
std::shared_ptr<mkldnn::memory> out_;
std::shared_ptr<mkldnn::memory::primitive_desc> dst_pd_;
std::shared_ptr<mkldnn::reorder> transpose_;

public:
MKLDNNTransposeForward(const TransposeParam& param,
const NDArray &data) {
auto shape = data.shape();
auto data_ndim = shape.ndim();
auto axes_ndim = param.axes.ndim();
auto axes = mxnet::TShape(data_ndim);
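    // An empty axes parameter means invert the order of all dimensions,
    // matching the documented default of the transpose operator.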
if (axes_ndim == 0) {
for (size_t i = 0; i < data_ndim; i++) {
axes[i] = data_ndim - i - 1;
}
} else {
axes = param.axes;
}

auto engine = CpuEngine::Get()->get_engine();
auto in_mem = data.GetMKLDNNData();
auto src_pd = in_mem->get_primitive_desc();
data_ = std::make_shared<mkldnn::memory>(src_pd, nullptr);

    // destination
    // Not every layout has a predefined, named format in MKL-DNN. For example,
    // transpose(NCHW, (0, 2, 1, 3)) -> NHCW has no named format, so to support
    // arbitrary permutations we need to create the destination format from scratch.
mkldnn_memory_desc_t dst_fmt;
dst_fmt.primitive_kind = mkldnn_memory;
dst_fmt.ndims = data_ndim;
dst_fmt.data_type = mkldnn_f32;
dst_fmt.format = mkldnn_blocked;

for (size_t i = 0; i < data_ndim; i++)
dst_fmt.dims[i] = shape[i];

unsigned int total_stride = 1;
for (int i = data_ndim - 1; i >= 0; i--) {
dst_fmt.layout_desc.blocking.padding_dims[i] = shape[i];
dst_fmt.layout_desc.blocking.block_dims[i] = 1;
      dst_fmt.layout_desc.blocking.offset_padding_to_data[i] = 0;
// strides[0]: stride between the first elements of adjacent blocks.
dst_fmt.layout_desc.blocking.strides[0][axes[i]] = total_stride;
// strides[1]: strides between elements in the same block.
dst_fmt.layout_desc.blocking.strides[1][axes[i]] = 1;

total_stride *= shape[axes[i]];
}
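
    // Example: shape = (2, 3, 4, 5) with axes = (0, 2, 1, 3) gives an output
    // of shape (2, 4, 3, 5). The loop above sets strides[0] = {60, 5, 15, 1}
    // over the *source* dims, so element (n, c, h, w) lands at offset
    // n*60 + h*15 + c*5 + w, which is exactly row-major order for the
    // permuted shape.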

dst_fmt.layout_desc.blocking.offset_padding = 0;
dst_pd_ = std::make_shared<mkldnn::memory::primitive_desc>(dst_fmt, engine);
out_ = std::make_shared<mkldnn::memory>(*dst_pd_, nullptr);

transpose_ = std::make_shared<mkldnn::reorder>(*data_, *out_);
}

void SetNewMem(const NDArray &data, const NDArray &output) {
if (data.IsMKLDNNData()) {
this->data_->set_data_handle(data.GetMKLDNNData()->get_data_handle());
} else {
MSHADOW_TYPE_SWITCH(data.dtype(), DTYPE, {
this->data_->set_data_handle(data.data().dptr<DTYPE>());
});
}

CHECK(!output.IsMKLDNNData());
MSHADOW_TYPE_SWITCH(output.dtype(), DTYPE, {
this->out_->set_data_handle(output.data().dptr<DTYPE>());
});
}

const mkldnn::reorder &GetFwd() const {
return *transpose_;
}
};

static MKLDNNTransposeForward &GetTransposeForward(const TransposeParam& param,
const NDArray &data) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<MKLDNNTransposeSignature,
MKLDNNTransposeForward, OpHash> fwds;
#else
static MX_THREAD_LOCAL std::unordered_map<MKLDNNTransposeSignature,
MKLDNNTransposeForward, OpHash> fwds;
#endif
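  // One cache per thread: SetNewMem rewrites the data handles of the cached
  // primitive, so sharing entries across threads would not be safe.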
MKLDNNTransposeSignature key(param);
key.AddSign(data);

auto it = fwds.find(key);
if (it == fwds.end()) {
MKLDNNTransposeForward fwd(param, data);
it = AddToCache(&fwds, key, fwd);
}
return it->second;
}

void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const NDArray &data,
const OpReqType &req,
const NDArray &output) {
const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);

auto stream = MKLDNNStream::Get();
auto fwd = GetTransposeForward(param, data);

fwd.SetNewMem(data, output);
stream->RegisterPrim(fwd.GetFwd());
stream->Submit();
}
} // namespace op
} // namespace mxnet
#endif
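
To make the stride construction above concrete, here is a minimal standalone sketch of the same computation (plain C++, no MKL-DNN dependency; `TransposeStrides` is a hypothetical helper name, not part of the commit):

    #include <cstdio>
    #include <vector>

    // Given a source shape and a permutation `axes`, compute the destination
    // strides that make a reorder act as a transpose: the output's innermost
    // dim (axes[ndim-1]) gets stride 1, and each outer dim's stride is the
    // product of the sizes of the dims nested inside it.
    std::vector<int> TransposeStrides(const std::vector<int> &shape,
                                      const std::vector<int> &axes) {
      std::vector<int> strides(shape.size());
      int total_stride = 1;
      for (int i = static_cast<int>(shape.size()) - 1; i >= 0; i--) {
        strides[axes[i]] = total_stride;
        total_stride *= shape[axes[i]];
      }
      return strides;
    }

    int main() {
      // shape (2, 3, 4, 5), axes (0, 2, 1, 3) -> prints "60 5 15 1"
      for (int s : TransposeStrides({2, 3, 4, 5}, {0, 2, 1, 3}))
        std::printf("%d ", s);
      std::printf("\n");
      return 0;
    }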
15 changes: 15 additions & 0 deletions src/operator/tensor/matrix_op-inl.h
@@ -238,6 +238,10 @@ struct TransposeParam : public dmlc::Parameter<TransposeParam> {
DMLC_DECLARE_FIELD(axes).set_default(mxnet::TShape())
.describe("Target axis order. By default the axes will be inverted.");
}

bool operator==(const TransposeParam &other) const {
return this->axes == other.axes;
}
};

template<typename xpu>
@@ -2841,4 +2845,15 @@ inline uint32_t SplitNumOutputs(const NodeAttrs& attrs) {
} // namespace op
} // namespace mxnet

namespace std {
template<>
struct hash<mxnet::op::TransposeParam> {
size_t operator()(const mxnet::op::TransposeParam& val) {
size_t ret = 0;
ret = dmlc::HashCombine(ret, val.axes);
return ret;
}
};
} // namespace std

#endif // MXNET_OPERATOR_TENSOR_MATRIX_OP_INL_H_
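
The `operator==` and the `std::hash` specialization above are what allow `TransposeParam` to serve as (part of) a hash-map key for the per-thread forward cache. A minimal sketch of the same pattern, using a hypothetical `FakeParam` stand-in rather than MXNet types:

    #include <string>
    #include <unordered_map>

    struct FakeParam {  // stand-in for TransposeParam
      int axes;         // stand-in for the real mxnet::TShape axes
      bool operator==(const FakeParam &other) const { return axes == other.axes; }
    };

    namespace std {
    template<>
    struct hash<FakeParam> {
      size_t operator()(const FakeParam &p) const { return hash<int>()(p.axes); }
    };
    }  // namespace std

    int main() {
      // With operator== and std::hash defined, the param can key a cache.
      std::unordered_map<FakeParam, std::string> cache;
      cache[{1}] = "cached primitive";
      return cache.count({1}) == 1 ? 0 : 1;
    }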
34 changes: 34 additions & 0 deletions src/operator/tensor/matrix_op.cc
@@ -339,6 +339,35 @@ Example::
})
.add_argument("data", "NDArray-or-Symbol", "Input array.");

#if MXNET_USE_MKLDNN == 1
static void TransposeComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
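  // The transpose is computed out of place via a reorder primitive, so only
  // kWriteTo is accepted.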
CHECK_EQ(req[0], kWriteTo) << "Transpose does not support inplace";
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);

if (SupportMKLDNNTranspose(param, inputs[0])) {
MKLDNNTransposeForward(attrs, ctx, inputs[0], req[0], outputs[0]);
return;
}
FallBackCompute(Transpose<cpu>, attrs, ctx, inputs, req, outputs);
}

inline static bool TransposeStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
}
#endif

NNVM_REGISTER_OP(transpose)
.describe(R"code(Permutes the dimensions of an array.
@@ -393,6 +422,11 @@ Examples::
}
})
.set_attr<FCompute>("FCompute<cpu>", Transpose<cpu>)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", TransposeComputeExCPU)
.set_attr<FInferStorageType>("FInferStorageType", TransposeStorageType)
#endif
.add_argument("data", "NDArray-or-Symbol", "Source input")
.add_arguments(TransposeParam::__FIELDS__());

15 changes: 15 additions & 0 deletions tests/python/mkl/test_mkldnn.py
@@ -473,6 +473,21 @@ def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
exec1 = custom.bind(mx.cpu(), args={'data': mx.nd.ones([10,3,96,96]), 'conv_weight': mx.nd.ones([8,3,5,5])})
exec1.forward()[0].wait_to_read()

@with_seed()
def test_conv_transpose():
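    # Transposing the output of a convolution exercises the MKL-DNN transpose
    # path on inputs kept in MKL-DNN's internal layout.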
axes = [(0,2,1,3), (0,2,3,1), (1,2,3,0), (3,2,1,0)]
a = np.random.rand(10, 16, 50, 50)
b = np.random.rand(32, 16, 3, 3)
x = mx.nd.array(a)
w = mx.nd.array(b)
y = mx.nd.Convolution(data=x, weight=w, kernel=(3, 3), num_group=1, num_filter=32, no_bias=True)
for axis in axes:
t = mx.nd.transpose(y, axis)
t.wait_to_read()
s = y.asnumpy()
n = np.transpose(s, axis)
        assert np.allclose(t.asnumpy(), n)


if __name__ == '__main__':
install.test_mkldnn_install()
