Skip to content

Commit

Permalink
[1.x] Backport of Fix LeakyRelu behaviour on empty input (apache#18934)…
Browse files Browse the repository at this point in the history
… (apache#19009)

* Fix LeakyRelu behaviour on empty input

* Remove duplicated declarations
  • Loading branch information
bgawrych committed Aug 26, 2020
1 parent dfefe87 commit bce4cc6
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 8 deletions.
2 changes: 2 additions & 0 deletions src/operator/leaky_relu-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ void LeakyReLUCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx, const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
if (inputs[0].Size() == 0U) return;
const LeakyReLUParam &param = nnvm::get<LeakyReLUParam>(attrs.parsed);
const std::vector<TBlob> no_use_but_adapt_origin_api;
size_t expected = param.act_type == leakyrelu::kPReLU ? 2 : 1;
Expand All @@ -352,6 +353,7 @@ void LeakyReLUGradCompute(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
if (inputs[0].Size() == 0U) return;
const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
const std::vector<TBlob> no_use_but_adapt_origin_api;
// inputs: out_grad, input_data, input_gamma, output, output_mask
Expand Down
2 changes: 2 additions & 0 deletions src/operator/leaky_relu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ static void LeakyReLUComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (inputs[0].shape().Size() == 0U) return;
const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
size_t expected = param.act_type == leakyrelu::kPReLU ? 2 : 1;
CHECK_EQ(inputs.size(), expected);
Expand All @@ -107,6 +108,7 @@ void LeakyReLUGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (inputs[0].shape().Size() == 0U) return;
const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
if (SupportMKLDNNLeakyRelu(param, inputs[0])) {
std::vector<NDArray> in_data{inputs[0], inputs[1]};
Expand Down
7 changes: 0 additions & 7 deletions src/operator/nn/mkldnn/mkldnn_act-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,6 @@ MKLDNNActForward &GetActForward(const MKLDNNActParam& param,
const OpContext &ctx, const NDArray &in_data,
const mkldnn::memory &in_mem);

void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const NDArray &in_data, const OpReqType &req,
const NDArray &out_data);
void MKLDNNLeakyReluForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const NDArray &in_data, const OpReqType &req,
const NDArray &out_data);

mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl(
const MKLDNNActParam &param, const mkldnn::memory &input_mem,
const mkldnn::memory &diff_dst_memory);
Expand Down
2 changes: 1 addition & 1 deletion src/operator/quantization/mkldnn/mkldnn_quantized_act.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
*/
#if MXNET_USE_MKLDNN == 1

#include "../../nn/mkldnn/mkldnn_act-inl.h"
#include "../../nn/mkldnn/mkldnn_ops-inl.h"
#include "../quantization_utils.h"

namespace mxnet {
Expand Down

0 comments on commit bce4cc6

Please sign in to comment.