Add mkldnn OP for slice (apache#13730)
* add mkldnn slice

* fix lint

* fix lint

* mv SliceEx to matrix_op.cc

* fix lint

* optimize dispatch_mode

* retrigger ci

* fix indent
huangzhiyuan authored and haohuw committed Jun 23, 2019
1 parent 4e8ae72 commit 8c14c36
Showing 5 changed files with 292 additions and 16 deletions.
66 changes: 66 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_slice-inl.h
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file mkldnn_slice-inl.h
* \brief MKLDNN implementation of the slice operator
* \author Zhiyuan Huang
*/

#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_SLICE_INL_H_
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_SLICE_INL_H_

#if MXNET_USE_MKLDNN == 1

#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <utility>
#include "../../operator_common.h"
#include "../../tensor/slice-inl.h"
#include "./mkldnn_base-inl.h"

namespace mxnet {
namespace op {

class MKLDNNSliceFwd {
 public:
  MKLDNNSliceFwd(const SliceParam &param,
                 const NDArray &in,
                 const NDArray &out);
  void SetNewMem(const mkldnn::memory &input, const mkldnn::memory &output);
  const mkldnn::reorder &GetPd() const;

 private:
  std::shared_ptr<mkldnn::memory> data_;
  std::shared_ptr<mkldnn::memory> out_;
  std::shared_ptr<mkldnn::reorder> fwd_;
};

typedef ParamOpSign<SliceParam> MKLDNNSliceSignature;
MKLDNNSliceFwd &GetSliceForward(const SliceParam &param, const bool is_train,
                                const NDArray &in_data, const NDArray &out_data);

void MKLDNNSlice(const SliceParam &param, const OpContext& ctx,
                 const NDArray &in, OpReqType req, const NDArray &out);

} // namespace op
} // namespace mxnet
#endif // MXNET_USE_MKLDNN == 1
#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_SLICE_INL_H_
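For orientation: the header above only declares the cached forward class and the MKLDNNSlice entry point. Below is a minimal sketch of how such an entry point is typically invoked from an FComputeEx-style kernel; the commit does the real wiring in SliceExCPU in matrix_op.cc further down, so the wrapper name here is purely illustrative and assumes the MXNet operator headers are available.

// Illustrative wrapper only -- the committed equivalent is SliceExCPU in matrix_op.cc.
#if MXNET_USE_MKLDNN == 1
void ExampleSliceComputeEx(const nnvm::NodeAttrs& attrs,
                           const OpContext& ctx,
                           const std::vector<NDArray>& inputs,
                           const std::vector<OpReqType>& req,
                           const std::vector<NDArray>& outputs) {
  const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
  // Dense input with trivial step: hand the copy to the MKLDNN reorder path.
  MKLDNNSlice(param, ctx, inputs[0], req[0], outputs[0]);
}
#endif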
104 changes: 104 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_slice.cc
@@ -0,0 +1,104 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file mkldnn_slice.cc
* \brief MKLDNN slice operator implementation
* \author Zhiyuan Huang
*/

#if MXNET_USE_MKLDNN == 1

#include "./mkldnn_ops-inl.h"
#include "./mkldnn_base-inl.h"
#include "./mkldnn_slice-inl.h"

namespace mxnet {
namespace op {

MKLDNNSliceFwd::MKLDNNSliceFwd(const SliceParam &param,
                               const NDArray &in,
                               const NDArray &out) {
  const TShape ishape = in.shape();
  const TShape oshape = out.shape();
  uint32_t N = ishape.ndim();
  mkldnn::memory::dims dims(N);
  mkldnn::memory::dims offsets(N);
  for (uint32_t i = 0; i < N; ++i) {
    int s = 0;
    if (param.begin[i]) {
      s = *param.begin[i];
      if (s < 0) s += ishape[i];
    }
    dims[i] = oshape[i];
    offsets[i] = s;
  }
  auto in_mem_pd = in.GetMKLDNNData()->get_primitive_desc();
  auto out_mem_pd = out.GetMKLDNNData()->get_primitive_desc();
  auto view_pd = mkldnn::view::primitive_desc(in_mem_pd, dims, offsets);
  auto reorder_pd = reorder::primitive_desc(view_pd.dst_primitive_desc(), out_mem_pd);
  this->data_ = std::make_shared<mkldnn::memory>(view_pd.dst_primitive_desc(), nullptr);
  this->out_ = std::make_shared<mkldnn::memory>(view_pd.dst_primitive_desc(), nullptr);
  this->fwd_ = std::make_shared<mkldnn::reorder>(reorder_pd, *this->data_, *this->out_);
}

void MKLDNNSliceFwd::SetNewMem(const mkldnn::memory &input, const mkldnn::memory &output) {
  this->data_->set_data_handle(input.get_data_handle());
  this->out_->set_data_handle(output.get_data_handle());
}

const mkldnn::reorder &MKLDNNSliceFwd::GetPd() const {
  return *fwd_;
}

MKLDNNSliceFwd &GetSliceForward(const SliceParam &param, const bool is_train,
                                const NDArray &in_data, const NDArray &out_data) {
#if DMLC_CXX11_THREAD_LOCAL
  static thread_local std::unordered_map<MKLDNNSliceSignature, MKLDNNSliceFwd, OpHash> fwds;
#else
  static MX_THREAD_LOCAL std::unordered_map<MKLDNNSliceSignature, MKLDNNSliceFwd, OpHash> fwds;
#endif
  MKLDNNSliceSignature key(param);
  key.AddSign(is_train);
  key.AddSign(in_data);
  key.AddSign(out_data);

  auto it = fwds.find(key);
  if (it == fwds.end()) {
    MKLDNNSliceFwd fwd(param, in_data, out_data);
    it = AddToCache(&fwds, key, fwd);
  }
  return it->second;
}

void MKLDNNSlice(const SliceParam &param, const OpContext& ctx,
                 const NDArray &in, OpReqType req, const NDArray &out) {
  MKLDNNSliceFwd &fwd = GetSliceForward(param, ctx.is_train, in, out);
  auto in_mem = in.GetMKLDNNData();
  auto out_mem_pd = out.GetMKLDNNData()->get_primitive_desc();
  auto out_mem = CreateMKLDNNMem(out, out_mem_pd, req);
  fwd.SetNewMem(*in_mem, *out_mem.second);
  MKLDNNStream::Get()->RegisterPrim(fwd.GetPd());
  CommitOutput(out, out_mem);
  MKLDNNStream::Get()->Submit();
}

} // namespace op
} // namespace mxnet
#endif // MXNET_USE_MKLDNN == 1
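As a concrete illustration of the offset computation in the constructor above: for an input of shape (4, 6) sliced with begin = (1, -4) and end = (3, 6), the negative begin index is wrapped by the axis length, so the view offsets become (1, 2) while the view dims equal the output shape (2, 4). The self-contained sketch below mirrors that normalization outside MXNet using only STL types; it is illustrative, not the committed code.

#include <iostream>
#include <vector>

// Standalone sketch of the begin-offset normalization used by MKLDNNSliceFwd:
// a negative begin index is wrapped by the axis length; the view extent per
// axis is simply the output shape on that axis.
int main() {
  std::vector<int> ishape = {4, 6};   // input shape
  std::vector<int> oshape = {2, 4};   // output (sliced) shape
  std::vector<int> begin  = {1, -4};  // slice begin, negative allowed

  std::vector<int> offsets(ishape.size()), dims(ishape.size());
  for (size_t i = 0; i < ishape.size(); ++i) {
    int s = begin[i];
    if (s < 0) s += ishape[i];        // wrap negative index
    offsets[i] = s;                   // -> {1, 2}
    dims[i] = oshape[i];              // -> {2, 4}
  }
  for (size_t i = 0; i < ishape.size(); ++i)
    std::cout << "axis " << i << ": offset=" << offsets[i]
              << " dim=" << dims[i] << "\n";
  return 0;
}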
37 changes: 22 additions & 15 deletions src/operator/tensor/matrix_op-inl.h
@@ -37,6 +37,7 @@
#include "broadcast_reduce_op.h"
#include "./init_op.h"
#include "../../common/static_array.h"
#include "./slice-inl.h"

#if MXNET_USE_CUDA
#include <thrust/device_vector.h>
@@ -398,19 +399,15 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs,
  return true;
}

struct SliceParam : public dmlc::Parameter<SliceParam> {
  nnvm::Tuple<dmlc::optional<int>> begin, end;
  nnvm::Tuple<dmlc::optional<int>> step;
  DMLC_DECLARE_PARAMETER(SliceParam) {
    DMLC_DECLARE_FIELD(begin)
    .describe("starting indices for the slice operation, supports negative indices.");
    DMLC_DECLARE_FIELD(end)
    .describe("ending indices for the slice operation, supports negative indices.");
    DMLC_DECLARE_FIELD(step)
    .set_default(nnvm::Tuple<dmlc::optional<int>>())
    .describe("step for the slice operation, supports negative values.");
};

// Currently MKLDNN only supports step = 1 or step has no value
inline bool SupportMKLDNNSlice(const SliceParam& param) {
  if (param.step.ndim() == 0U) return true;
  for (uint32_t i = 0; i < param.step.ndim(); ++i) {
    if (param.step[i].has_value() && param.step[i].value() != 1)
      return false;
  }
  return true;
}

inline bool SliceForwardInferStorageType(const nnvm::NodeAttrs& attrs,
                                         const int dev_mask,
@@ -432,9 +429,19 @@ inline bool SliceForwardInferStorageType(const nnvm::NodeAttrs& attrs,
      && (!param.step[0].has_value() || param.step[0].value() == 1)) {
    trivial_step = true;
  }
  if (!dispatched && in_stype == kDefaultStorage) {
    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
                                     dispatch_mode, DispatchMode::kFCompute);

  if (in_stype == kDefaultStorage) {
#if MXNET_USE_MKLDNN == 1
    if (dev_mask == Context::kCPU && MKLDNNEnvSet()
        && SupportMKLDNNSlice(param)) {
      dispatched = storage_type_assign(&out_stype, kDefaultStorage,
                                       dispatch_mode, dispatch_ex);
    }
#endif
    if (!dispatched) {
      dispatched = storage_type_assign(&out_stype, kDefaultStorage,
                                       dispatch_mode, DispatchMode::kFCompute);
    }
  }

  if (!dispatched && in_stype == kCSRStorage && trivial_step) {
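The storage-type inference above only routes a dense slice to the MKLDNN path when SupportMKLDNNSlice accepts the parameters, i.e. when every step entry is either unset or exactly 1; any strided slice stays on the plain FCompute path. A small standalone check mirroring that rule (plain C++ with std::pair standing in for dmlc::optional, illustrative values only):

#include <iostream>
#include <utility>
#include <vector>

// Standalone mirror of the SupportMKLDNNSlice rule: a slice is MKLDNN-eligible
// only if no axis has an explicit step other than 1. std::pair<bool, int>
// stands in for dmlc::optional<int> (first = has_value, second = value).
static bool StepIsMKLDNNFriendly(const std::vector<std::pair<bool, int>>& step) {
  if (step.empty()) return true;  // step omitted entirely
  for (const auto& s : step) {
    if (s.first && s.second != 1) return false;
  }
  return true;
}

int main() {
  std::cout << StepIsMKLDNNFriendly({}) << "\n";                       // 1: no step given
  std::cout << StepIsMKLDNNFriendly({{true, 1}, {false, 0}}) << "\n";  // 1: step 1 / unset
  std::cout << StepIsMKLDNNFriendly({{true, 2}}) << "\n";              // 0: strided slice
  return 0;
}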
30 changes: 29 additions & 1 deletion src/operator/tensor/matrix_op.cc
@@ -27,6 +27,7 @@
#include "./elemwise_unary_op.h"
#include "../nn/mkldnn/mkldnn_ops-inl.h"
#include "../nn/mkldnn/mkldnn_base-inl.h"
#include "../nn/mkldnn/mkldnn_slice-inl.h"

namespace mxnet {
namespace op {
@@ -420,6 +421,30 @@ will return a new array with shape ``(2,1,3,4)``.
.add_argument("data", "NDArray-or-Symbol", "Source input")
.add_arguments(ExpandDimParam::__FIELDS__());

void SliceExCPU(const nnvm::NodeAttrs& attrs,
                const OpContext& ctx,
                const std::vector<NDArray>& inputs,
                const std::vector<OpReqType>& req,
                const std::vector<NDArray>& outputs) {
  CHECK_EQ(inputs.size(), 1);
  CHECK_EQ(outputs.size(), 1);
  const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
  auto in_stype = inputs[0].storage_type();
  if (in_stype == kCSRStorage) {
    SliceCsrImpl<cpu>(param, ctx, inputs[0], req[0], outputs[0]);
#if MXNET_USE_MKLDNN == 1
  } else if (in_stype == kDefaultStorage) {
    if (SupportMKLDNN(inputs[0])) {
      MKLDNNSlice(param, ctx, inputs[0], req[0], outputs[0]);
    } else {
      FallBackCompute(SliceOpForward<cpu>, attrs, ctx, inputs, req, outputs);
    }
#endif
  } else {
    LOG(FATAL) << "Slice not implemented for storage type " << in_stype;
  }
}

NNVM_REGISTER_OP(slice)
MXNET_ADD_SPARSE_OP_ALIAS(slice)
.add_alias("crop")
@@ -478,7 +503,10 @@ Example::
.set_attr<FInferStorageType>("FInferStorageType", SliceForwardInferStorageType)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_slice"})
.set_attr<FCompute>("FCompute<cpu>", SliceOpForward<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", SliceEx<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", SliceExCPU)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
#endif
.add_argument("data", "NDArray-or-Symbol", "Source input")
.add_arguments(SliceParam::__FIELDS__());

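As a reminder of the semantics being registered here: slice (alias crop) copies the sub-tensor [begin, end) along each axis, and the MKLDNN path above only covers the step == 1 case. A plain-C++ reference of that copy for a 2-D row-major buffer, handy for sanity-checking results (illustrative code, not part of MXNet):

#include <iostream>
#include <vector>

// Reference 2-D slice with step == 1 on a row-major buffer:
// out[r][c] = in[r + begin0][c + begin1] for r in [0, end0 - begin0), c in [0, end1 - begin1).
std::vector<float> SliceRef2D(const std::vector<float>& in, int cols,
                              int begin0, int end0, int begin1, int end1) {
  std::vector<float> out;
  out.reserve(static_cast<size_t>(end0 - begin0) * (end1 - begin1));
  for (int r = begin0; r < end0; ++r)
    for (int c = begin1; c < end1; ++c)
      out.push_back(in[r * cols + c]);
  return out;
}

int main() {
  // 3x4 input holding 0..11; slicing rows [1, 3) and cols [0, 2) yields 4 5 8 9.
  std::vector<float> in(12);
  for (int i = 0; i < 12; ++i) in[i] = static_cast<float>(i);
  for (float v : SliceRef2D(in, 4, 1, 3, 0, 2)) std::cout << v << " ";
  std::cout << "\n";
  return 0;
}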
71 changes: 71 additions & 0 deletions src/operator/tensor/slice-inl.h
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file slice-inl.h
* \brief Shared SliceParam definition for the slice operator
* \author Zhiyuan Huang
*/

#ifndef MXNET_OPERATOR_TENSOR_SLICE_INL_H_
#define MXNET_OPERATOR_TENSOR_SLICE_INL_H_

#include <utility>
#include <vector>
#include <string>

namespace mxnet {
namespace op {

struct SliceParam : public dmlc::Parameter<SliceParam> {
  nnvm::Tuple<dmlc::optional<int>> begin, end;
  nnvm::Tuple<dmlc::optional<int>> step;
  DMLC_DECLARE_PARAMETER(SliceParam) {
    DMLC_DECLARE_FIELD(begin)
    .describe("starting indices for the slice operation, supports negative indices.");
    DMLC_DECLARE_FIELD(end)
    .describe("ending indices for the slice operation, supports negative indices.");
    DMLC_DECLARE_FIELD(step)
    .set_default(nnvm::Tuple<dmlc::optional<int>>())
    .describe("step for the slice operation, supports negative values.");
  }
  bool operator==(const SliceParam& other) const {
    return this->begin == other.begin &&
           this->end == other.end &&
           this->step == other.step;
  }
};

} // namespace op
} // namespace mxnet

namespace std {
template<>
struct hash<mxnet::op::SliceParam> {
  size_t operator()(const mxnet::op::SliceParam& val) {
    size_t ret = 0;
    ret = dmlc::HashCombine(ret, val.begin);
    ret = dmlc::HashCombine(ret, val.end);
    ret = dmlc::HashCombine(ret, val.step);
    return ret;
  }
};
} // namespace std

#endif // MXNET_OPERATOR_TENSOR_SLICE_INL_H_
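The operator== and the std::hash specialization above exist so that SliceParam can be folded into the MKLDNNSliceSignature key of the per-thread forward cache in mkldnn_slice.cc. A minimal sketch of that key pattern with a stand-in struct (ToyParam is hypothetical, not MXNet code, and the hash combiner is simplified relative to dmlc::HashCombine):

#include <cstddef>
#include <functional>
#include <iostream>
#include <unordered_map>

// Stand-in parameter struct: equality plus std::hash make it usable as a map
// key, which is the role SliceParam plays for the MKLDNN forward-primitive cache.
struct ToyParam {
  int begin, end, step;
  bool operator==(const ToyParam& o) const {
    return begin == o.begin && end == o.end && step == o.step;
  }
};

namespace std {
template <>
struct hash<ToyParam> {
  size_t operator()(const ToyParam& p) const {
    size_t ret = std::hash<int>()(p.begin);
    ret = ret * 31 + std::hash<int>()(p.end);   // simplified combine
    ret = ret * 31 + std::hash<int>()(p.step);
    return ret;
  }
};
}  // namespace std

int main() {
  std::unordered_map<ToyParam, int> cache;      // analogous to the fwds cache
  cache[{0, 4, 1}] = 42;                        // "primitive" cached under its params
  std::cout << cache.count({0, 4, 1}) << "\n";  // 1: same params -> cache hit
  std::cout << cache.count({1, 4, 1}) << "\n";  // 0: different params -> miss
  return 0;
}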
