This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add mkldnn slice * fix lint * fix lint * mv SliceEx to matrix_op.cc * fix lint * optimize dispatch_mode * retrigger ci * fix indent
- Loading branch information
1 parent
5b011b3
commit 2616275
Showing
5 changed files
with
292 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file mkldnn_slice-inl.h | ||
* \brief | ||
* \author Zhiyuan Huang | ||
*/ | ||
|
||
#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_SLICE_INL_H_ | ||
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_SLICE_INL_H_ | ||
|
||
#if MXNET_USE_MKLDNN == 1 | ||
|
||
#include <dmlc/logging.h> | ||
#include <dmlc/parameter.h> | ||
#include <mxnet/operator.h> | ||
#include <utility> | ||
#include "../../operator_common.h" | ||
#include "../../tensor/slice-inl.h" | ||
#include "./mkldnn_base-inl.h" | ||
|
||
namespace mxnet { | ||
namespace op { | ||
|
||
class MKLDNNSliceFwd { | ||
public: | ||
MKLDNNSliceFwd(const SliceParam ¶m, | ||
const NDArray &in, | ||
const NDArray &out); | ||
void SetNewMem(const mkldnn::memory &input, const mkldnn::memory &output); | ||
const mkldnn::reorder &GetPd() const; | ||
|
||
private: | ||
std::shared_ptr<mkldnn::memory> data_; | ||
std::shared_ptr<mkldnn::memory> out_; | ||
std::shared_ptr<mkldnn::reorder> fwd_; | ||
}; | ||
|
||
typedef ParamOpSign<SliceParam> MKLDNNSliceSignature; | ||
MKLDNNSliceFwd &GetSliceForward(const SliceParam ¶m, const bool is_train, | ||
const NDArray &in_data, const NDArray &out_data); | ||
|
||
void MKLDNNSlice(const SliceParam ¶m, const OpContext& ctx, | ||
const NDArray &in, OpReqType req, const NDArray &out); | ||
|
||
} // namespace op | ||
} // namespace mxnet | ||
#endif // MXNET_USE_MKLDNN == 1 | ||
#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_SLICE_INL_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file mkldnn_slice.cc | ||
* \brief | ||
* \author Zhiyuan Huang | ||
*/ | ||
|
||
#if MXNET_USE_MKLDNN == 1 | ||
|
||
#include "./mkldnn_ops-inl.h" | ||
#include "./mkldnn_base-inl.h" | ||
#include "./mkldnn_slice-inl.h" | ||
|
||
namespace mxnet { | ||
namespace op { | ||
|
||
MKLDNNSliceFwd::MKLDNNSliceFwd(const SliceParam ¶m, | ||
const NDArray &in, | ||
const NDArray &out) { | ||
const TShape ishape = in.shape(); | ||
const TShape oshape = out.shape(); | ||
uint32_t N = ishape.ndim(); | ||
mkldnn::memory::dims dims(N); | ||
mkldnn::memory::dims offsets(N); | ||
for (uint32_t i = 0; i < N; ++i) { | ||
int s = 0; | ||
if (param.begin[i]) { | ||
s = *param.begin[i]; | ||
if (s < 0) s += ishape[i]; | ||
} | ||
dims[i] = oshape[i]; | ||
offsets[i] = s; | ||
} | ||
auto in_mem_pd = in.GetMKLDNNData()->get_primitive_desc(); | ||
auto out_mem_pd = out.GetMKLDNNData()->get_primitive_desc(); | ||
auto view_pd = mkldnn::view::primitive_desc(in_mem_pd, dims, offsets); | ||
auto reorder_pd = reorder::primitive_desc(view_pd.dst_primitive_desc(), out_mem_pd); | ||
this->data_ = std::make_shared<mkldnn::memory>(view_pd.dst_primitive_desc(), nullptr); | ||
this->out_ = std::make_shared<mkldnn::memory>(view_pd.dst_primitive_desc(), nullptr); | ||
this->fwd_ = std::make_shared<mkldnn::reorder>(reorder_pd, *this->data_, *this->out_); | ||
} | ||
|
||
void MKLDNNSliceFwd::SetNewMem(const mkldnn::memory &input, const mkldnn::memory &output) { | ||
this->data_->set_data_handle(input.get_data_handle()); | ||
this->out_->set_data_handle(output.get_data_handle()); | ||
} | ||
|
||
const mkldnn::reorder &MKLDNNSliceFwd::GetPd() const { | ||
return *fwd_; | ||
} | ||
|
||
MKLDNNSliceFwd &GetSliceForward(const SliceParam ¶m, const bool is_train, | ||
const NDArray &in_data, const NDArray &out_data) { | ||
#if DMLC_CXX11_THREAD_LOCAL | ||
static thread_local std::unordered_map<MKLDNNSliceSignature, MKLDNNSliceFwd, OpHash> fwds; | ||
#else | ||
static MX_THREAD_LOCAL std::unordered_map<MKLDNNSliceSignature, MKLDNNSliceFwd, OpHash> fwds; | ||
#endif | ||
MKLDNNSliceSignature key(param); | ||
key.AddSign(is_train); | ||
key.AddSign(in_data); | ||
key.AddSign(out_data); | ||
|
||
auto it = fwds.find(key); | ||
if (it == fwds.end()) { | ||
MKLDNNSliceFwd fwd(param, in_data, out_data); | ||
it = AddToCache(&fwds, key, fwd); | ||
} | ||
return it->second; | ||
} | ||
|
||
void MKLDNNSlice(const SliceParam ¶m, const OpContext& ctx, | ||
const NDArray &in, OpReqType req, const NDArray &out) { | ||
MKLDNNSliceFwd &fwd = GetSliceForward(param, ctx.is_train, in, out); | ||
auto in_mem = in.GetMKLDNNData(); | ||
auto out_mem_pd = out.GetMKLDNNData()->get_primitive_desc(); | ||
auto out_mem = CreateMKLDNNMem(out, out_mem_pd, req); | ||
fwd.SetNewMem(*in_mem, *out_mem.second); | ||
MKLDNNStream::Get()->RegisterPrim(fwd.GetPd()); | ||
CommitOutput(out, out_mem); | ||
MKLDNNStream::Get()->Submit(); | ||
} | ||
|
||
} // namespace op | ||
} // namespace mxnet | ||
#endif // MXNET_USE_MKLDNN == 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file slice-inl.h | ||
* \brief | ||
* \author Zhiyuan Huang | ||
*/ | ||
|
||
#ifndef MXNET_OPERATOR_TENSOR_SLICE_INL_H_ | ||
#define MXNET_OPERATOR_TENSOR_SLICE_INL_H_ | ||
|
||
#include <utility> | ||
#include <vector> | ||
#include <string> | ||
|
||
namespace mxnet { | ||
namespace op { | ||
|
||
struct SliceParam : public dmlc::Parameter<SliceParam> { | ||
nnvm::Tuple<dmlc::optional<int>> begin, end; | ||
nnvm::Tuple<dmlc::optional<int>> step; | ||
DMLC_DECLARE_PARAMETER(SliceParam) { | ||
DMLC_DECLARE_FIELD(begin) | ||
.describe("starting indices for the slice operation, supports negative indices."); | ||
DMLC_DECLARE_FIELD(end) | ||
.describe("ending indices for the slice operation, supports negative indices."); | ||
DMLC_DECLARE_FIELD(step) | ||
.set_default(nnvm::Tuple<dmlc::optional<int>>()) | ||
.describe("step for the slice operation, supports negative values."); | ||
} | ||
bool operator==(const SliceParam& other) const { | ||
return this->begin == other.begin && | ||
this->end == other.end && | ||
this->step == other.step; | ||
} | ||
}; | ||
|
||
} // namespace op | ||
} // namespace mxnet | ||
|
||
namespace std { | ||
template<> | ||
struct hash<mxnet::op::SliceParam> { | ||
size_t operator()(const mxnet::op::SliceParam& val) { | ||
size_t ret = 0; | ||
ret = dmlc::HashCombine(ret, val.begin); | ||
ret = dmlc::HashCombine(ret, val.end); | ||
ret = dmlc::HashCombine(ret, val.step); | ||
return ret; | ||
} | ||
}; | ||
} // namespace std | ||
|
||
#endif // MXNET_OPERATOR_TENSOR_SLICE_INL_H_ |