This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add mkldnn OP for slice #13730

Merged (11 commits, Jan 16, 2019)
Changes from 1 commit
66 changes: 66 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_slice-inl.h
@@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file mkldnn_slice-inl.h
* \brief MKL-DNN slice operator declarations
* \author Zhiyuan Huang
*/

#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_SLICE_INL_H_
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_SLICE_INL_H_

#if MXNET_USE_MKLDNN == 1

#include <mkldnn.hpp>
#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <utility>
#include "../../operator_common.h"
#include "../../tensor/slice-inl.h"
#include "./mkldnn_base-inl.h"

namespace mxnet {
namespace op {

class MKLDNNSliceFwd {
public:
MKLDNNSliceFwd(const SliceParam &param,
const NDArray &in,
const NDArray &out);
void SetNewMem(const mkldnn::memory &input, const mkldnn::memory &output);
const mkldnn::reorder &GetPd() const;

private:
// Source/destination memory placeholders rebound per call via SetNewMem,
// plus the cached reorder primitive.
std::shared_ptr<mkldnn::memory> data_;
std::shared_ptr<mkldnn::memory> out_;
std::shared_ptr<mkldnn::reorder> fwd_;
};

typedef ParamOpSign<SliceParam> MKLDNNSliceSignature;
MKLDNNSliceFwd &GetSliceForward(const SliceParam &param,
const NDArray &in_data, const NDArray &out_data);

void MKLDNNSlice(const SliceParam &param, const OpContext& ctx,
const NDArray &in, OpReqType req, const NDArray &out);

} // namespace op
} // namespace mxnet
#endif // MXNET_USE_MKLDNN == 1
#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_SLICE_INL_H_
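Taken together, the header follows MXNet's usual cached-primitive pattern: the reorder is built once per unique parameter/shape signature and its data handles are rebound on every call. A condensed sketch of the intended call sequence, using only names declared above (the stream API comes from mkldnn_base-inl.h); MKLDNNSlice in mkldnn_slice.cc below implements exactly this flow:

MKLDNNSliceFwd &fwd = GetSliceForward(param, in, out);  // cached per signature
fwd.SetNewMem(in_mem, out_mem);                         // rebind current buffers
MKLDNNStream::Get()->RegisterPrim(fwd.GetPd());         // enqueue the reorder
MKLDNNStream::Get()->Submit();                          // execute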
103 changes: 103 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_slice.cc
@@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file mkldnn_slice.cc
* \brief MKL-DNN slice operator implementation
* \author Zhiyuan Huang
*/

#if MXNET_USE_MKLDNN == 1

#include "./mkldnn_ops-inl.h"
#include "./mkldnn_base-inl.h"
#include "./mkldnn_slice-inl.h"

namespace mxnet {
namespace op {

MKLDNNSliceFwd::MKLDNNSliceFwd(const SliceParam &param,
const NDArray &in,
const NDArray &out) {
const TShape ishape = in.shape();
const TShape oshape = out.shape();
uint32_t N = ishape.ndim();
mkldnn::memory::dims dims(N);
mkldnn::memory::dims offsets(N);
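// For each axis, the view size is the output extent along that axis and the
// view offset is the begin index, normalized so negative indices count from the end.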
for (uint32_t i = 0; i < N; ++i) {
int s = 0;
if (param.begin[i]) {
s = *param.begin[i];
if (s < 0) s += ishape[i];
}
dims[i] = oshape[i];
offsets[i] = s;
}
auto in_mem = in.GetMKLDNNData();
auto in_mem_pd = in_mem->get_primitive_desc();
Contributor: Please merge in_mem and in_mem_pd into one line, as was done for out_mem_pd, since in_mem is not used elsewhere here.

Contributor (Author): This is just the overall code style; I can modify it if you insist.

Member: What is the "overall code style"?

Contributor (Author): in_mem and in_mem_pd have been merged into one line in the new commit.

auto out_mem_pd = out.GetMKLDNNData()->get_primitive_desc();
// A view selects the sliced sub-region of the input; a reorder then copies it
// into the output memory, whose layout may differ from the view's.
auto view_pd = mkldnn::view::primitive_desc(in_mem_pd, dims, offsets);
auto reorder_pd = reorder::primitive_desc(view_pd.dst_primitive_desc(), out_mem_pd);
this->data_ = std::make_shared<mkldnn::memory>(view_pd.dst_primitive_desc(), nullptr);
this->out_ = std::make_shared<mkldnn::memory>(out_mem_pd, nullptr);
fwd_.reset(new mkldnn::reorder(reorder_pd, *data_, *out_));
}

void MKLDNNSliceFwd::SetNewMem(const mkldnn::memory &input, const mkldnn::memory &output) {
this->data_->set_data_handle(input.get_data_handle());
this->out_->set_data_handle(output.get_data_handle());
}

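// Note: despite its name, GetPd() returns the cached reorder primitive itself
// (ready to register on the MKL-DNN stream), not a primitive_desc.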
const mkldnn::reorder &MKLDNNSliceFwd::GetPd() const {
return *fwd_;
}

MKLDNNSliceFwd &GetSliceForward(const SliceParam &param,
const NDArray &in_data, const NDArray &out_data) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<MKLDNNSliceSignature, MKLDNNSliceFwd, OpHash> fwds;
#else
static MX_THREAD_LOCAL std::unordered_map<MKLDNNSliceSignature, MKLDNNSliceFwd, OpHash> fwds;
#endif
MKLDNNSliceSignature key(param);
key.AddSign(in_data);
Member: Do we need to include is_train and out_data in the key?

Contributor (Author): in_data is enough for the key to cache on, but I have added is_train and out_data to the key in the new commit.

auto it = fwds.find(key);
if (it == fwds.end()) {
MKLDNNSliceFwd fwd(param, in_data, out_data);
it = AddToCache(&fwds, key, fwd);
}
return it->second;
}

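// Forward entry point: fetch (or build) the cached reorder for this
// parameter/shape signature, rebind it to the current input/output buffers,
// and run it on the MKL-DNN stream.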
void MKLDNNSlice(const SliceParam &param, const OpContext& ctx,
const NDArray &in, OpReqType req, const NDArray &out) {
MKLDNNSliceFwd &fwd = GetSliceForward(param, in, out);
auto in_mem = in.GetMKLDNNData();
auto out_mem_pd = out.GetMKLDNNData()->get_primitive_desc();
auto out_mem = CreateMKLDNNMem(out, out_mem_pd, req);
fwd.SetNewMem(*in_mem, *out_mem.second);
MKLDNNStream::Get()->RegisterPrim(fwd.GetPd());
CommitOutput(out, out_mem);
MKLDNNStream::Get()->Submit();
}

} // namespace op
} // namespace mxnet
#endif // MXNET_USE_MKLDNN == 1
69 changes: 37 additions & 32 deletions src/operator/tensor/matrix_op-inl.h
@@ -37,6 +37,8 @@
#include "broadcast_reduce_op.h"
#include "./init_op.h"
#include "../../common/static_array.h"
#include "./slice-inl.h"
#include "../nn/mkldnn/mkldnn_slice-inl.h"

#if MXNET_USE_CUDA
#include <thrust/device_vector.h>
@@ -398,20 +400,6 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs,
return true;
}

struct SliceParam : public dmlc::Parameter<SliceParam> {
nnvm::Tuple<dmlc::optional<int>> begin, end;
nnvm::Tuple<dmlc::optional<int>> step;
DMLC_DECLARE_PARAMETER(SliceParam) {
DMLC_DECLARE_FIELD(begin)
.describe("starting indices for the slice operation, supports negative indices.");
DMLC_DECLARE_FIELD(end)
.describe("ending indices for the slice operation, supports negative indices.");
DMLC_DECLARE_FIELD(step)
.set_default(nnvm::Tuple<dmlc::optional<int>>())
.describe("step for the slice operation, supports negative values.");
}
};

inline bool SliceForwardInferStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
@@ -432,7 +420,16 @@ inline bool SliceForwardInferStorageType(const nnvm::NodeAttrs& attrs,
&& (!param.step[0].has_value() || param.step[0].value() == 1)) {
trivial_step = true;
}
if (!dispatched && in_stype == kDefaultStorage) {
if (!dispatched && in_stype == kDefaultStorage && trivial_step) {
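// With a trivial (unit) step and default storage, slice can be dispatched to
// the MKL-DNN FComputeEx path; otherwise it stays on the plain FCompute kernel.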
#if MXNET_USE_MKLDNN == 1
dispatched = storage_type_assign(&out_stype, kDefaultStorage,
dispatch_mode, dispatch_ex);
#else
dispatched = storage_type_assign(&out_stype, kDefaultStorage,
dispatch_mode, DispatchMode::kFCompute);
#endif
} else if (!dispatched && in_stype == kDefaultStorage) {
dispatched = storage_type_assign(&out_stype, kDefaultStorage,
dispatch_mode, DispatchMode::kFCompute);
}
@@ -604,23 +601,6 @@ void SliceCsrImpl(const SliceParam &param, const OpContext& ctx,
}
}

template<typename xpu>
void SliceEx(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 1);
CHECK_EQ(outputs.size(), 1);
const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
auto in_stype = inputs[0].storage_type();
if (in_stype == kCSRStorage) {
SliceCsrImpl<xpu>(param, ctx, inputs[0], req[0], outputs[0]);
} else {
LOG(FATAL) << "Slice not implemented for storage type" << in_stype;
}
}

template<int ndim>
inline void GetIndexRange(const TShape& dshape,
const nnvm::Tuple<dmlc::optional<int>>& param_begin,
@@ -829,6 +809,31 @@ void SliceOpForward(const nnvm::NodeAttrs& attrs,
})
}

template<typename xpu>
void SliceEx(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 1);
CHECK_EQ(outputs.size(), 1);
const SliceParam& param = nnvm::get<SliceParam>(attrs.parsed);
auto in_stype = inputs[0].storage_type();
if (in_stype == kCSRStorage) {
SliceCsrImpl<xpu>(param, ctx, inputs[0], req[0], outputs[0]);
#if MXNET_USE_MKLDNN == 1
} else if (in_stype == kDefaultStorage) {  // For default storage, detect whether to use MKL-DNN or fall back
if (SupportMKLDNN(inputs[0])) {
MKLDNNSlice(param, ctx, inputs[0], req[0], outputs[0]);
} else {
FallBackCompute(SliceOpForward<xpu>, attrs, ctx, inputs, req, outputs);
}
#endif
} else {
LOG(FATAL) << "Slice not implemented for storage type" << in_stype;
}
}

template<int ndim, int req, typename xpu>
struct slice_assign;

3 changes: 3 additions & 0 deletions src/operator/tensor/matrix_op.cc
@@ -479,6 +479,9 @@ Example::
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_slice"})
.set_attr<FCompute>("FCompute<cpu>", SliceOpForward<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", SliceEx<cpu>)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
#endif
.add_argument("data", "NDArray-or-Symbol", "Source input")
.add_arguments(SliceParam::__FIELDS__());

71 changes: 71 additions & 0 deletions src/operator/tensor/slice-inl.h
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file slice-inl.h
* \brief Parameters for the slice operator, shared by the generic and MKL-DNN implementations
* \author Zhiyuan Huang
*/

#ifndef MXNET_OPERATOR_TENSOR_SLICE_INL_H_
#define MXNET_OPERATOR_TENSOR_SLICE_INL_H_

#include <dmlc/parameter.h>  // dmlc::Parameter / DMLC_DECLARE_*
#include <dmlc/optional.h>   // dmlc::optional
#include <nnvm/tuple.h>      // nnvm::Tuple
#include <utility>
#include <vector>
#include <string>

namespace mxnet {
namespace op {

struct SliceParam : public dmlc::Parameter<SliceParam> {
nnvm::Tuple<dmlc::optional<int>> begin, end;
nnvm::Tuple<dmlc::optional<int>> step;
DMLC_DECLARE_PARAMETER(SliceParam) {
DMLC_DECLARE_FIELD(begin)
.describe("starting indices for the slice operation, supports negative indices.");
DMLC_DECLARE_FIELD(end)
.describe("ending indices for the slice operation, supports negative indices.");
DMLC_DECLARE_FIELD(step)
.set_default(nnvm::Tuple<dmlc::optional<int>>())
.describe("step for the slice operation, supports negative values.");
}
bool operator==(const SliceParam& other) const {
return this->begin == other.begin &&
this->end == other.end &&
this->step == other.step;
}
};

} // namespace op
} // namespace mxnet

namespace std {
template<>
struct hash<mxnet::op::SliceParam> {
size_t operator()(const mxnet::op::SliceParam& val) const {
size_t ret = 0;
ret = dmlc::HashCombine(ret, val.begin);
ret = dmlc::HashCombine(ret, val.end);
ret = dmlc::HashCombine(ret, val.step);
return ret;
}
};
} // namespace std

#endif // MXNET_OPERATOR_TENSOR_SLICE_INL_H_
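With operator== and the std::hash specialization in place, SliceParam can serve directly as a key in hashed containers. A minimal hypothetical illustration (not how the PR uses it; the MKL-DNN path goes through ParamOpSign instead):

#include <unordered_map>

std::unordered_map<mxnet::op::SliceParam, int> seen;
mxnet::op::SliceParam p;  // default-constructed begin/end/step tuples
++seen[p];                // uses the std::hash specialization and operator== above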