Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Integrate MKL-DNN leakyrelu #16075

Merged
merged 9 commits into from
Sep 24, 2019
192 changes: 38 additions & 154 deletions src/operator/leaky_relu-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -332,166 +332,50 @@ class LeakyReLUOp : public Operator {
}; // class LeakyReLUOp

template<typename xpu>
Operator* CreateOp(LeakyReLUParam type, int dtype);
void LeakyReLUCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx, const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const LeakyReLUParam &param = nnvm::get<LeakyReLUParam>(attrs.parsed);
const std::vector<TBlob> no_use_but_adapt_origin_api;
size_t expected = param.act_type == leakyrelu::kPReLU ? 2 : 1;
CHECK_EQ(inputs.size(), expected);

#if DMLC_USE_CXX11
class LeakyReLUProp : public OperatorProperty {
public:
void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
}

std::map<std::string, std::string> GetParams() const override {
return param_.__DICT__();
}

bool InferShape(mxnet::ShapeVector *in_shape,
mxnet::ShapeVector *out_shape,
mxnet::ShapeVector *aux_shape) const override {
using namespace mshadow;
if (param_.act_type == leakyrelu::kPReLU) {
CHECK_EQ(in_shape->size(), 2U) << "Input:[data, gamma]";
} else {
CHECK_EQ(in_shape->size(), 1U) << "Input:[data]";
}
const mxnet::TShape &dshape = in_shape->at(leakyrelu::kData);
if (!mxnet::ndim_is_known(dshape)) return false;
if (param_.act_type == leakyrelu::kPReLU) {
const mxnet::TShape &gshape = in_shape->at(leakyrelu::kGamma);
if (!mxnet::ndim_is_known(gshape)) {
in_shape->at(leakyrelu::kGamma) = mxnet::TShape(Shape1(dshape[1]));
}
if (dshape == gshape) {
SHAPE_ASSIGN_CHECK(*out_shape, 0, dshape);
}
}
out_shape->clear();
out_shape->push_back(dshape);
if (param_.act_type == leakyrelu::kRReLU) {
out_shape->push_back(dshape);
}
return true;
}

bool InferType(std::vector<int> *in_type,
std::vector<int> *out_type,
std::vector<int> *aux_type) const override {
int dtype = -1;
for (const int& type : *in_type) {
type_assign(&dtype, type);
}
for (const int& type : *out_type) {
type_assign(&dtype, type);
}

for (size_t i = 0; i < in_type->size(); ++i) {
TYPE_ASSIGN_CHECK(*in_type, i, dtype);
}
for (size_t i = 0; i < out_type->size(); ++i) {
TYPE_ASSIGN_CHECK(*out_type, i, dtype);
}
return dtype != -1;
}

OperatorProperty* Copy() const override {
auto ptr = new LeakyReLUProp();
ptr->param_ = param_;
return ptr;
}

std::string TypeString() const override {
return "LeakyReLU";
}

// decalre dependency and inplace optimization options
std::vector<int> DeclareBackwardDependency(
const std::vector<int> &out_grad,
const std::vector<int> &in_data,
const std::vector<int> &out_data) const override {
if (param_.act_type == leakyrelu::kPReLU) {
return {out_grad[leakyrelu::kOut],
out_data[leakyrelu::kOut],
in_data[leakyrelu::kData],
in_data[leakyrelu::kGamma]};
} else if (param_.act_type == leakyrelu::kRReLU) {
return {out_grad[leakyrelu::kOut], out_data[leakyrelu::kMask], out_data[leakyrelu::kOut]};
} else {
return {out_grad[leakyrelu::kOut], out_data[leakyrelu::kData]};
}
}
MSHADOW_REAL_TYPE_SWITCH(inputs[leakyrelu::kData].type_flag_, DType, {
LeakyReLUOp<xpu, DType> op(param);
op.Forward(ctx, inputs, req, outputs, no_use_but_adapt_origin_api);
});
}

std::vector<std::pair<int, void*> > BackwardInplaceOption(
const std::vector<int> &out_grad,
const std::vector<int> &in_data,
const std::vector<int> &out_data,
const std::vector<void*> &in_grad) const override {
return {{out_grad[leakyrelu::kOut], in_grad[leakyrelu::kData]}};
}

std::vector<std::pair<int, void*> > ForwardInplaceOption(
const std::vector<int> &in_data,
const std::vector<void*> &out_data) const override {
if (param_.act_type == leakyrelu::kPReLU) {
return {};
} else {
return {{in_data[leakyrelu::kData], out_data[leakyrelu::kOut]}};
}
}

std::vector<std::string> ListArguments() const override {
if (param_.act_type == leakyrelu::kPReLU) {
return {"data", "gamma"};
} else {
return {"data"};
}
}

std::vector<std::string> ListOutputs() const override {
if (param_.act_type == leakyrelu::kRReLU) {
return {"output", "mask"};
} else {
return {"output"};
}
}

int NumOutputs() const override {
if (param_.act_type == leakyrelu::kRReLU) {
return 2;
} else {
return 1;
}
}

int NumVisibleOutputs() const override {
return 1;
}

std::vector<ResourceRequest> ForwardResource(
const mxnet::ShapeVector &in_shape) const override {
if (param_.act_type == leakyrelu::kRReLU) {
return {ResourceRequest::kRandom};
} else {
return std::vector<ResourceRequest>();
}
}
template<typename xpu>
void LeakyReLUGradCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
const std::vector<TBlob> no_use_but_adapt_origin_api;
// inputs: out_grad, input_data, input_gamma, output, output_mask
size_t expected_in = param.act_type == leakyrelu::kPReLU ? 2 : 1;
size_t expected_out = param.act_type == leakyrelu::kRReLU ? 2 : 1;

std::vector<ResourceRequest> BackwardResource(
const mxnet::ShapeVector &in_shape) const override {
return {ResourceRequest::kTempSpace};
}
CHECK_GE(inputs.size(), 1 + expected_in + expected_out);
std::vector<TBlob> out_grad{inputs[0]};
xinyu-intel marked this conversation as resolved.
Show resolved Hide resolved
std::vector<TBlob> in_data(inputs.begin() + 1,
inputs.begin() + 1 + expected_in);
std::vector<TBlob> out_data(inputs.begin() + 1 + expected_in,
inputs.begin() + 1 + expected_in + expected_out);

Operator* CreateOperator(Context ctx) const override {
LOG(FATAL) << "Not Implemented.";
return NULL;
}
CHECK_EQ(req.size(), outputs.size());
int dtype = inputs[0].type_flag_;
const std::vector<TBlob> &in_grad = outputs;

Operator* CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape,
std::vector<int> *in_type) const override;
MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
LeakyReLUOp<xpu, DType> op(param);
op.Backward(ctx, out_grad, in_data, out_data, req, in_grad, no_use_but_adapt_origin_api);
});
}

private:
LeakyReLUParam param_;
};
#endif // DMLC_USE_CXX11
} // namespace op
} // namespace mxnet

Expand Down
Loading