Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Refactor L2_normalization (#13059)
Browse files Browse the repository at this point in the history
* Refactor L2_normalization

* Fix windows build

* Fix windows build

* Move cpu optimization into l2_normalization.cc

* Retrigger CI

* Retrigger CI
  • Loading branch information
ZhennanQin authored and szha committed Nov 7, 2018
1 parent 6c4bbd8 commit 36eabfa
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/operator/l2_normalization-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ class L2NormalizationOp : public Operator {
}
}

private:
protected:
L2NormalizationParam param_;
}; // class L2NormalizationOp

Expand Down
102 changes: 100 additions & 2 deletions src/operator/l2_normalization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,119 @@
* \brief l2 normalization operator
*/
#include "./l2_normalization-inl.h"

/* Visual Studio's compiler only supports OpenMP 2.0, which lacks the
   collapse() clause, so define it away for MSVC builds. */
#ifdef _MSC_VER
#define collapse(x)
#endif

namespace mxnet {
namespace op {

template<typename DType>
/*!
 * \brief CPU specialization of L2NormalizationOp.
 *
 * Replaces the generic mshadow expression-template Forward pass with
 * explicit OpenMP-parallel loops over the three normalization modes
 * (instance, channel, spatial). Backward is inherited unchanged.
 */
class L2NormalizationOpCPU : public L2NormalizationOp<cpu, DType> {
 public:
  explicit L2NormalizationOpCPU(L2NormalizationParam p)
    : L2NormalizationOp<cpu, DType>(p) {}

  void Forward(const OpContext &ctx, const std::vector<TBlob> &in_data,
               const std::vector<OpReqType> &req,
               const std::vector<TBlob> &out_data,
               const std::vector<TBlob> &aux_args) override {
    using namespace mshadow;
    using namespace mshadow::expr;
    if (req[l2_normalization::kOut] == kNullOp) return;
    CHECK_EQ(req[l2_normalization::kOut], kWriteTo);
    CHECK_EQ(in_data.size(), 1U);
    CHECK_EQ(out_data.size(), 2U);  // kOut plus the kNorm side output
    Stream<cpu> *s = ctx.get_stream<cpu>();
    TShape orig_shape = in_data[l2_normalization::kData].shape_;
    auto omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
    if (this->param_.mode == l2_normalization::kInstance) {
      // Flatten to (batch, features): one norm per sample.
      Shape<2> dshape = Shape2(orig_shape[0],
                               orig_shape.ProdShape(1, orig_shape.ndim()));
      Tensor<cpu, 2, DType> data = in_data[l2_normalization::kData]
        .get_with_shape<cpu, 2, DType>(dshape, s);
      Tensor<cpu, 2, DType> out = out_data[l2_normalization::kOut]
        .get_with_shape<cpu, 2, DType>(dshape, s);
      Tensor<cpu, 1, DType> norm = out_data[l2_normalization::kNorm].get<cpu, 1, DType>(s);
#pragma omp parallel for num_threads(omp_threads)
      for (int shape0 = 0; shape0 < static_cast<int>(dshape[0]); shape0++) {
        // Accumulate into a local so the inner loop keeps the running sum in
        // a register instead of re-reading/writing norm[shape0] each step.
        // eps is added to the sum of squares before the sqrt, matching the
        // operator's existing definition of the stabilized norm.
        DType sum = DType(this->param_.eps);
        for (int shape1 = 0; shape1 < static_cast<int>(dshape[1]); shape1++) {
          sum += data[shape0][shape1] * data[shape0][shape1];
        }
        norm[shape0] = std::sqrt(sum);
        for (int shape1 = 0; shape1 < static_cast<int>(dshape[1]); shape1++) {
          out[shape0][shape1] = data[shape0][shape1] / norm[shape0];
        }
      }
    } else if (this->param_.mode == l2_normalization::kChannel) {
      // View as (batch, channel, spatial): normalize along the channel axis,
      // producing one norm per (batch, spatial) position.
      CHECK_GE(orig_shape.ndim(), 3U);
      Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
                               orig_shape.ProdShape(2, orig_shape.ndim()));
      Tensor<cpu, 3, DType> data = in_data[l2_normalization::kData]
        .get_with_shape<cpu, 3, DType>(dshape, s);
      Tensor<cpu, 3, DType> out = out_data[l2_normalization::kOut]
        .get_with_shape<cpu, 3, DType>(dshape, s);
      Shape<2> norm_shape = Shape2(dshape[0], dshape[2]);
      Tensor<cpu, 2, DType> norm = out_data[l2_normalization::kNorm]
        .get_with_shape<cpu, 2, DType>(norm_shape, s);
#pragma omp parallel for num_threads(omp_threads) collapse(2)
      for (int shape0 = 0; shape0 < static_cast<int>(dshape[0]); shape0++) {
        for (int shape2 = 0; shape2 < static_cast<int>(dshape[2]); shape2++) {
          DType sum = DType(this->param_.eps);
          for (int shape1 = 0; shape1 < static_cast<int>(dshape[1]); shape1++) {
            sum += data[shape0][shape1][shape2] * data[shape0][shape1][shape2];
          }
          norm[shape0][shape2] = std::sqrt(sum);
          for (int shape1 = 0; shape1 < static_cast<int>(dshape[1]); shape1++) {
            out[shape0][shape1][shape2] = data[shape0][shape1][shape2] / norm[shape0][shape2];
          }
        }
      }
    } else if (this->param_.mode == l2_normalization::kSpatial) {
      // View as (batch, channel, spatial): normalize along the spatial axis,
      // producing one norm per (batch, channel) pair.
      CHECK_GE(orig_shape.ndim(), 3U);
      Shape<3> dshape = Shape3(orig_shape[0], orig_shape[1],
                               orig_shape.ProdShape(2, orig_shape.ndim()));
      Tensor<cpu, 3, DType> data = in_data[l2_normalization::kData]
        .get_with_shape<cpu, 3, DType>(dshape, s);
      Tensor<cpu, 3, DType> out = out_data[l2_normalization::kOut]
        .get_with_shape<cpu, 3, DType>(dshape, s);
      Shape<2> norm_shape = Shape2(dshape[0], dshape[1]);
      Tensor<cpu, 2, DType> norm = out_data[l2_normalization::kNorm]
        .get_with_shape<cpu, 2, DType>(norm_shape, s);
#pragma omp parallel for num_threads(omp_threads) collapse(2)
      for (int shape0 = 0; shape0 < static_cast<int>(dshape[0]); shape0++) {
        for (int shape1 = 0; shape1 < static_cast<int>(dshape[1]); shape1++) {
          DType sum = DType(this->param_.eps);
          for (int shape2 = 0; shape2 < static_cast<int>(dshape[2]); shape2++) {
            sum += data[shape0][shape1][shape2] * data[shape0][shape1][shape2];
          }
          norm[shape0][shape1] = std::sqrt(sum);
          for (int shape2 = 0; shape2 < static_cast<int>(dshape[2]); shape2++) {
            out[shape0][shape1][shape2] = data[shape0][shape1][shape2] / norm[shape0][shape1];
          }
        }
      }
    } else {
      LOG(FATAL) << "Unexpected mode in l2 normalization";
    }
  }
};

template<>
Operator* CreateOp<cpu>(L2NormalizationParam param, int dtype) {
  Operator* op = nullptr;
  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
    // Use the OpenMP-optimized CPU specialization; allocating the generic
    // L2NormalizationOp first and then overwriting op would leak it.
    op = new L2NormalizationOpCPU<DType>(param);
  });
  return op;
}

// DO_BIND_DISPATCH comes from static_operator_common.h
// Dispatches to the CreateOp<cpu> factory above using the input's dtype.
Operator* L2NormalizationProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
                                                std::vector<int> *in_type) const {
  DO_BIND_DISPATCH(CreateOp, this->param_, in_type->at(0));
}

// Register the parameter struct with DMLC so its fields can be parsed from
// keyword arguments when the operator is created.
DMLC_REGISTER_PARAMETER(L2NormalizationParam);
Expand Down

0 comments on commit 36eabfa

Please sign in to comment.