This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[OP] Accelerate GPU version of LayerNorm(axis=-1) #14935

Merged
merged 1 commit on May 21, 2019
23 changes: 17 additions & 6 deletions src/operator/nn/layer_norm-inl.h
@@ -63,12 +63,17 @@ struct LayerNormParam : public dmlc::Parameter<LayerNormParam> {
}
};


template<typename xpu>
void LayerNormCompute(const nnvm::NodeAttrs& attrs,
                      const OpContext& ctx, const std::vector<TBlob>& inputs,
                      const std::vector<OpReqType>& req,
                      const std::vector<TBlob>& outputs);

template<typename xpu>
void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs,
                             const OpContext& ctx, const std::vector<TBlob>& inputs,
                             const std::vector<OpReqType>& req,
                             const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
@@ -146,6 +151,12 @@ void LayerNormCompute(const nnvm::NodeAttrs& attrs,
{kWriteTo}, {outputs[0]});
}

template<typename xpu>
void LayerNormGradCompute(const nnvm::NodeAttrs& attrs,
                          const OpContext& ctx, const std::vector<TBlob>& inputs,
                          const std::vector<OpReqType>& req,
                          const std::vector<TBlob>& outputs);

/*
Calculate the gradient of layer normalization.
We have the following gradient for gamma, beta and x:
@@ -157,10 +168,10 @@ grad_beta = sum(og, exclude_axis)
grad_x = w - mean(w, axis) - \bar{x} * mean(w * \bar{x}, axis)
*/
template<typename xpu>
void LayerNormGradComputeGeneral(const nnvm::NodeAttrs& attrs,
                                 const OpContext& ctx, const std::vector<TBlob>& inputs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mshadow::expr;
  CHECK_EQ(inputs.size(), 5U);
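For reference, the backward pass that LayerNormGradComputeGeneral implements can be written out in full. Assuming og is the incoming output gradient, \bar{x} = (x - mean) / std is the normalized input, and w = og * gamma / std (these definitions fall in the collapsed part of the comment, so they are restated here as assumptions), the standard LayerNorm gradients in the comment's notation are:

    grad_gamma = sum(og * \bar{x}, exclude_axis)
    grad_beta  = sum(og, exclude_axis)
    grad_x     = w - mean(w, axis) - \bar{x} * mean(w * \bar{x}, axis)

Written this way, grad_x needs only two reductions over the normalized axis, mean(w, axis) and mean(w * \bar{x}, axis), which is what makes a fused last-axis kernel attractive for the GPU path this PR targets.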
16 changes: 16 additions & 0 deletions src/operator/nn/layer_norm.cc
@@ -65,6 +65,22 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs,
}


template<>
void LayerNormCompute<cpu>(const nnvm::NodeAttrs& attrs,
                           const OpContext& ctx, const std::vector<TBlob>& inputs,
                           const std::vector<OpReqType>& req,
                           const std::vector<TBlob>& outputs) {
  return LayerNormComputeGeneral<cpu>(attrs, ctx, inputs, req, outputs);
}

template<>
void LayerNormGradCompute<cpu>(const nnvm::NodeAttrs& attrs,
                               const OpContext& ctx, const std::vector<TBlob>& inputs,
                               const std::vector<OpReqType>& req,
                               const std::vector<TBlob>& outputs) {
  return LayerNormGradComputeGeneral<cpu>(attrs, ctx, inputs, req, outputs);
}

NNVM_REGISTER_OP(LayerNorm)
.describe(R"code(Layer normalization.

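With LayerNormCompute and LayerNormGradCompute reduced to declarations in the header, each device supplies its own specialization; the CPU ones above simply forward to the *General implementations. A minimal sketch of how the matching GPU specialization in layer_norm.cu could dispatch to an accelerated last-axis path (per the PR title) and otherwise reuse the shared code; the fast-path kernel name and the exact dispatch logic are assumptions for illustration, not taken from this diff:

// Hypothetical sketch of the GPU forward specialization (layer_norm.cu).
template<>
void LayerNormCompute<gpu>(const nnvm::NodeAttrs& attrs,
                           const OpContext& ctx, const std::vector<TBlob>& inputs,
                           const std::vector<OpReqType>& req,
                           const std::vector<TBlob>& outputs) {
  const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
  int axis = param.axis;
  if (axis < 0) {
    axis += inputs[0].ndim();  // resolve a negative axis such as the default -1
  }
  if (axis == inputs[0].ndim() - 1) {
    // Fast path: the normalized axis is the innermost, contiguous one, so a
    // fused reduction kernel could be launched here, e.g. a hypothetical
    //   LayerNormGPUContig(param, ctx, inputs, req, outputs); return;
  }
  // Otherwise fall back to the device-agnostic implementation from layer_norm-inl.h.
  return LayerNormComputeGeneral<gpu>(attrs, ctx, inputs, req, outputs);
}

Restricting the fast path to the last axis keeps the reduction contiguous in memory, which is the case the PR title singles out for acceleration.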