Skip to content

Commit

Permalink
fix fp32 flatten issue (apache#15351) (apache#15802)
Browse files Browse the repository at this point in the history
* Fix flatten issue before slice op

* fix cpplint

* address comments

* retrigger CI

* trigger CI

* retrigger CI

* use SupportMKLDNNReshape and update operator list
  • Loading branch information
juliusshufan authored and TaoLv committed Aug 12, 2019
1 parent 52ce718 commit 386ad26
Show file tree
Hide file tree
Showing 9 changed files with 296 additions and 111 deletions.
2 changes: 2 additions & 0 deletions docs/tutorials/mkldnn/operator_list.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ To help users understand the MKL-DNN backend better, the following table summariz
| **elemwise_add** | 1D-4D input | Y | Y | Y |
| **Concat** | 1D-4D input | Y | Y | Y |
| **slice** | 1D-4D input | N | Y | N |
| **Reshape** | 1D-4D input | N | Y | N |
| **Flatten** | 1D-4D input | N | Y | N |
| **Quantization** | 1D-4D input | N | N | Y |
| **Dequantization** | 1D-4D input | N | N | Y |
| **Requantization** | 1D-4D input | N | N | Y |
Expand Down
2 changes: 2 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ struct DeconvolutionParam;
struct SoftmaxParam;
struct SoftmaxOutputParam;
struct TransposeParam;
struct ReshapeParam;
bool SupportMKLDNNAct(const ActivationParam& param);
bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input);
bool SupportQuantizedMKLDNNAct(const ActivationParam &param);
Expand All @@ -184,6 +185,7 @@ bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input)
bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const NDArray &output);
bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
bool SupportMKLDNNReshape(const ReshapeParam &param, const NDArray &data);
} // namespace op

static int GetTypeSize(int dtype) {
Expand Down
87 changes: 87 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_flatten.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file mkldnn_flatten.cc
* \brief Implement flatten operator by using mkldnn reorder primitive
* \author Wuxun Zhang
*/

#if MXNET_USE_MKLDNN == 1

#include "mkldnn_reshape-inl.h"

namespace mxnet {
namespace op {

class MKLDNNFlattenFwd : public MKLDNNReshapeFwd {
public:
explicit MKLDNNFlattenFwd(const OpReqType &req,
const NDArray &input,
const NDArray &output)
: MKLDNNReshapeFwd(req, input, output) {}
};

// Return a cached flatten forward functor for this (req, input)
// combination, creating and caching one on first use. The cache is
// thread-local, so no synchronization is needed.
static MKLDNNFlattenFwd &GetFlattenForward(const OpReqType &req,
                                           const NDArray &input,
                                           const NDArray &output) {
#if DMLC_CXX11_THREAD_LOCAL
  static thread_local std::unordered_map<OpSignature,
                                         MKLDNNFlattenFwd, OpHash> fwd_cache;
#else
  static MX_THREAD_LOCAL std::unordered_map<OpSignature,
                                            MKLDNNFlattenFwd, OpHash> fwd_cache;
#endif
  OpSignature key;
  key.AddSign(req);
  key.AddSign(input);

  auto iter = fwd_cache.find(key);
  if (iter != fwd_cache.end()) return iter->second;

  MKLDNNFlattenFwd fwd(req, input, output);
  return AddToCache(&fwd_cache, key, fwd)->second;
}

// Execute flatten via the MKL-DNN reorder-based reshape forward.
// Honors kNullOp (no-op) and rejects kAddTo, which the reorder path
// cannot express.
void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
                          const OpContext &ctx,
                          const NDArray &input,
                          const OpReqType &req,
                          const NDArray &output) {
  if (req == kNullOp) return;
  CHECK_NE(req, kAddTo) << "kAddTo is not supported yet";

  // Bind by reference: GetFlattenForward returns a reference into a
  // thread-local cache, and plain `auto` would copy the cached functor
  // (and its shared mkldnn memory handles) on every invocation,
  // defeating the cache.
  auto &fwd = GetFlattenForward(req, input, output);
  auto ws_size = fwd.GetWorkspaceSize();
  void *ws_ptr = nullptr;
  if (ws_size) {
    // The forward may need scratch space; request it from the operator's
    // temp-space resource instead of allocating directly.
    mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
    mshadow::Tensor<cpu, 1, char> ws = ctx.requested[0]
        .get_space_typed<cpu, 1, char>(mshadow::Shape1(ws_size), s);
    // char* converts to void* without reinterpret_cast.
    ws_ptr = static_cast<void *>(ws.dptr_);
  }

  fwd.Execute(input, output, ws_ptr);
}

} // namespace op
} // namespace mxnet

#endif
9 changes: 7 additions & 2 deletions src/operator/nn/mkldnn/mkldnn_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,17 @@ void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs,
const OpReqType &req,
const NDArray &output);

void MKLDNNReshapeForward(const nnvm::NodeAttrs &attrs,
void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const NDArray &data,
const NDArray &input,
const OpReqType &req,
const NDArray &output);

void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &input,
const OpReqType &req,
const NDArray &output);
} // namespace op
} // namespace mxnet
#endif // MXNET_USE_MKLDNN == 1
Expand Down
68 changes: 68 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_reshape-inl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file mkldnn_reshape-inl.h
* \brief Function definition of mkldnn reshape operator
*/

#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_

#if MXNET_USE_MKLDNN == 1
#include <vector>
#include "mkldnn_base-inl.h"
#include "../../tensor/matrix_op-inl.h"

namespace mxnet {
namespace op {

// Forward functor implementing reshape with MKL-DNN primitives.
// A cached instance is rebound to fresh tensors via SetNewMem and
// re-run via Execute. Member definitions live in mkldnn_reshape.cc.
class MKLDNNReshapeFwd {
 protected:
  // MKL-DNN memory handles used by the primitives below.
  std::shared_ptr<mkldnn::memory> data_;
  std::shared_ptr<mkldnn::memory> out_;
  std::shared_ptr<mkldnn::memory> temp_;
  // Primitives submitted when Execute() runs.
  std::vector<mkldnn::primitive> prims_;
  // NOTE(review): presumably marks that Execute() must invalidate the
  // input NDArray's MKL-DNN layout — confirm against the .cc definition.
  bool needInvalidateInput = false;

 public:
  // Build the primitive chain mapping `input` to `output` under write
  // request `req`.
  MKLDNNReshapeFwd(const OpReqType &req,
                   const NDArray &input,
                   const NDArray &output);
  // Scratch-space size Execute() needs (in bytes — it sizes a char
  // tensor at the call site); 0 if none.
  int GetWorkspaceSize();
  // Rebind the primitives' memory to a new input/output pair (and
  // optional scratch buffer) before re-executing a cached instance.
  void SetNewMem(const NDArray &input,
                 const NDArray &output,
                 void* workspace = nullptr);
  // Run the primitives, producing `output` from `input`.
  void Execute(const NDArray &input,
               const NDArray &output,
               void* workspace = nullptr);
};

// Cache-key type: reshape parameters combined with the op signature.
typedef ParamOpSign<ReshapeParam> MKLDNNReshapeSignature;
// Return a forward functor for this combination — presumably cached like
// the flatten path; confirm in mkldnn_reshape.cc.
MKLDNNReshapeFwd &GetReshapeForward(const ReshapeParam& param,
                                    const OpReqType &req,
                                    const NDArray &input,
                                    const NDArray &output);

} // namespace op
} // namespace mxnet

#endif // MXNET_USE_MKLDNN == 1
#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_RESHAPE_INL_H_
Loading

0 comments on commit 386ad26

Please sign in to comment.