
Commit

Don't disable MKLDNN

larroy committed Dec 3, 2018
1 parent e74f796 commit 77015a0
Showing 3 changed files with 21 additions and 73 deletions.
72 changes: 7 additions & 65 deletions src/operator/nn/activation.cc
@@ -30,7 +30,7 @@
#if MXNET_USE_MKLDNN == 1
#include "./mkldnn/mkldnn_base-inl.h"
#include "./mkldnn/mkldnn_ops-inl.h"
#endif // MXNET_USE_MKLDNN
#endif // MXNET_USE_MKLDNN == 1
#include "../operator_common.h"
#include "../../common/utils.h"

@@ -40,49 +40,22 @@ namespace op {
namespace activation {

int GradNumInputs(int act_type) {
#if MXNET_USE_CUDNN == 1
// check activation.cu \sa ActivationGradCompute
switch (act_type) {
case kReLU:
case kSoftReLU:
return 2;
case kSoftSign:
case kTanh:
case kSigmoid:
return 3;
default:
CHECK(false) << "missing activation type";
}
#elif MXNET_USE_MKLDNN == 1
// \sa ActivationGradComputeExCPU
switch (act_type) {
case kReLU:
return 2;
case kSigmoid:
case kTanh:
case kSoftReLU:
case kSoftSign:
return 3;
default:
CHECK(false) << "missing activation type";
}
#else
// check activation-inl.h \sa ActivationGradComputeImpl
switch (act_type) {
case kReLU:
case kSigmoid:
case kTanh:
case kSoftReLU:
return 2;
case kSoftSign:
case kSigmoid:
return 3;
default:
CHECK(false) << "missing activation type";
}
#endif
// unreachable
return -1;
}

} // namespace activation

DMLC_REGISTER_PARAMETER(ActivationParam);
@@ -99,52 +72,21 @@ struct ActivationGrad {
const NodeAttrs& attrs = n->attrs;
using namespace activation;
int act_type = dmlc::get<ActivationParam>(attrs.parsed).act_type;
#if MXNET_USE_CUDNN == 1
// for ReLU, no need to pass input data. This enables inplace optimization during the
// forward pass.
// check activation.cu \sa ActivationGradCompute
switch (act_type) {
case kReLU:
case kSoftReLU:
break;
case kSoftSign:
case kTanh:
case kSigmoid:
heads.push_back(n->inputs[activation::kData]);
break;
default:
CHECK(false) << "missing activation type";
}
#elif MXNET_USE_MKLDNN == 1
// \sa ActivationGradComputeExCPU
switch (act_type) {
case kReLU:
break;
case kSoftSign:
case kTanh:
case kSoftReLU:
case kSigmoid:
heads.push_back(n->inputs[activation::kData]);
break;
default:
CHECK(false) << "missing activation type";
}

#else
// check activation-inl.h \sa ActivationGradComputeImpl
switch (act_type) {
case kSoftSign:
heads.push_back(n->inputs[activation::kData]);
break;
case kReLU:
case kTanh:
case kSoftReLU:
case kSigmoid:
heads.push_back(n->inputs[activation::kData]);
break;
default:
CHECK(false) << "missing activation type";
}
#endif
return MakeGradNode(op_name, n, heads, n->attrs.dict);
}
};
@@ -177,9 +119,9 @@ void ActivationGradComputeExCPU(const nnvm::NodeAttrs& attrs,
MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
// XXX: for y = relu(x), y is passed as "in_data" to Backward()
const bool relu = param.act_type == activation::kReLU;
MKLDNNActivationBackward(attrs, ctx, inputs[0], relu ? inputs[1] : inputs[2], req[0],
MKLDNNActivationBackward(attrs, ctx, inputs.at(0), relu ? inputs.at(1) : inputs.at(2), req[0],
outputs[0]);
MKLDNN_OPCHECK_RUN(ActivationGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
MKLDNN_OPCHECK_RUN(ActivationGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
return;
}
FallBackCompute(ActivationGradComputeImpl<cpu>, attrs, ctx, inputs, req, outputs);
@@ -209,7 +151,7 @@ inline static bool BackwardActStorageType(const nnvm::NodeAttrs& attrs,
return MKLDNNStorageType(attrs, dev_mask, SupportMKLDNNAct(param),
dispatch_mode, in_attrs, out_attrs);
}
#endif
#endif // MXNET_USE_MKLDNN == 1


MXNET_OPERATOR_REGISTER_UNARY(Activation)
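The branching in GradNumInputs and ActivationGrad above comes down to whether an activation's gradient can be computed from the output gradient and the output alone (two backward inputs) or also needs the forward input (three). A minimal standalone sketch of that distinction, illustration only and not MXNet source:

    #include <cassert>
    #include <cmath>

    // ReLU: y = max(x, 0). Since (y > 0) <=> (x > 0), dx is computable from {dy, y} alone,
    // which is why ReLU's backward node can drop the input and allow the inplace forward.
    double relu_grad(double dy, double y) { return y > 0.0 ? dy : 0.0; }

    // SoftSign: y = x / (1 + |x|), dy/dx = 1 / (1 + |x|)^2 -- the input x is required,
    // so the backward node has to carry {dy, y, x}.
    double softsign_grad(double dy, double x) {
      const double d = 1.0 + std::fabs(x);
      return dy / (d * d);
    }

    int main() {
      assert(relu_grad(1.0, 2.0) == 1.0);       // forward input was positive
      assert(relu_grad(1.0, 0.0) == 0.0);       // forward input was <= 0
      assert(softsign_grad(1.0, 1.0) == 0.25);  // 1 / (1 + 1)^2
      return 0;
    }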
16 changes: 8 additions & 8 deletions src/operator/nn/activation.cu
@@ -85,20 +85,20 @@ void ActivationGradCompute<gpu>(const nnvm::NodeAttrs& attrs,
// both SoftReLU and SoftSign not supported by CUDNN yet
if (act_type == activation::kSoftReLU) {
ActivationBackward<gpu, mshadow_op::softrelu, mshadow_op::softrelu_grad>(
ctx, inputs[0], inputs[1], req[0], outputs[0]);
ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]);
} else if (act_type == activation::kSoftSign) {
ActivationBackward<gpu, mshadow_op::softsign, mshadow_op::softsign_grad>(
ctx, inputs[0], inputs[2], req[0], outputs[0]);
ctx, inputs.at(0), inputs.at(2), req[0], outputs[0]);
} else if (act_type == activation::kReLU) {
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MSHADOW_REAL_TYPE_SWITCH(inputs.at(0).type_flag_, DType, {
// XXX: for y = relu(x), y is passed as "in_data" to Backward()
get_cudnn_op<DType>(param).Backward(ctx, inputs[0], inputs[1],
inputs[1], req[0], outputs[0]);
get_cudnn_op<DType>(param).Backward(ctx, inputs.at(0), inputs.at(1),
inputs.at(1), req[0], outputs[0]);
});
} else {
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
get_cudnn_op<DType>(param).Backward(ctx, inputs[0], inputs[2],
inputs[1], req[0], outputs[0]);
MSHADOW_REAL_TYPE_SWITCH(inputs.at(0).type_flag_, DType, {
get_cudnn_op<DType>(param).Backward(ctx, inputs.at(0), inputs.at(2),
inputs.at(1), req[0], outputs[0]);
});
}
}
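Several hunks above replace inputs[i] with inputs.at(i). The practical difference is that std::vector::at() bounds-checks the index and throws std::out_of_range, whereas operator[] is undefined behaviour when the index is out of range. A small self-contained sketch of that behaviour (not MXNet code):

    #include <iostream>
    #include <stdexcept>
    #include <vector>

    int main() {
      std::vector<int> inputs = {1, 2};
      std::cout << inputs.at(1) << "\n";    // in range: prints 2
      try {
        std::cout << inputs.at(2) << "\n";  // out of range: throws instead of reading garbage
      } catch (const std::out_of_range& e) {
        std::cout << "caught: " << e.what() << "\n";
      }
      return 0;
    }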
6 changes: 6 additions & 0 deletions tests/cpp/operator/activation_perf.cc
@@ -53,6 +53,12 @@ TEST(ACTIVATION_PERF, ExecuteBidirectional) {
runner.RunBidirectional(false, { shape }, test::op::CoreOpExecutor<float>::ArgsWithOpName(
activation_args, "Activation", "_backward_Activation"), 1);
}
for (const string& activation : activations) {
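// Second pass with the runner's first argument flipped to true -- presumably the GPU
// selector of RunBidirectional (an assumption; the runner's signature is not shown in this diff).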
kwargs_t activation_args = {{"act_type", activation}};
test::op::CoreOperatorRunner<float> runner;
runner.RunBidirectional(true, { shape }, test::op::CoreOpExecutor<float>::ArgsWithOpName(
activation_args, "Activation", "_backward_Activation"), 1);
}
}

/*!
