diff --git a/src/operator/nn/moments-inl.h b/src/operator/nn/moments-inl.h new file mode 100644 index 000000000000..e2aa3f58e472 --- /dev/null +++ b/src/operator/nn/moments-inl.h @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file moments-inl.h + * \brief Moments operator + * \author Hao Jin +*/ + +#ifndef MXNET_OPERATOR_NN_MOMENTS_INL_H_ +#define MXNET_OPERATOR_NN_MOMENTS_INL_H_ + +#include +#include "../tensor/broadcast_reduce_op.h" + +namespace mxnet { +namespace op { + +struct MomentsParam : public dmlc::Parameter { + dmlc::optional axes; + bool keepdims; + DMLC_DECLARE_PARAMETER(MomentsParam) { + DMLC_DECLARE_FIELD(axes).set_default(dmlc::optional()) + .describe("Array of ints. Axes along which to compute mean and variance."); + DMLC_DECLARE_FIELD(keepdims).set_default(false) + .describe("produce moments with the same dimensionality as the input."); + } +}; + +inline bool MomentsShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { + const MomentsParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 2U); + + mxnet::TShape out_shape = + ReduceAxesShapeImpl((*in_attrs)[0], param.axes, param.keepdims, false); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, out_shape); + SHAPE_ASSIGN_CHECK(*out_attrs, 1, out_shape); + return true; +} + +inline bool MomentsType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 2U); + + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(1)); + return out_attrs->at(0) != -1 && out_attrs->at(1) != -1; +} + +struct VarBroadcastKernel { + template + MSHADOW_XINLINE static void Map(int i, + DType *out, + const DType *data, + const DType *mean, + mshadow::Shape<5> data_shape, + mshadow::Shape<5> mean_shape) { + size_t data_idx = i; + size_t mean_idx = i; + size_t data_stride = 1; + size_t mean_stride = 1; + for (int axis = 4; axis >= 0; --axis) { + size_t axis_idx = data_idx % data_shape[axis]; + mean_idx -= axis_idx * data_stride; + if (mean_shape[axis] != 1) { + mean_idx += axis_idx * mean_stride; + } + data_idx /= data_shape[axis]; + data_stride *= data_shape[axis]; + mean_stride *= mean_shape[axis]; + } + DType res = (data[i] - mean[mean_idx]); + out[i] = res * res; + } +}; + +template +inline void MomentsForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow_op; + using namespace mxnet_op; + + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 2U); + + const MomentsParam& param = nnvm::get(attrs.parsed); + Stream *s = ctx.get_stream(); + + const TBlob& data = inputs[0]; + const TBlob& mean = outputs[0]; + const TBlob& var = outputs[1]; + + mxnet::TShape small; + if (param.keepdims) { + small = outputs[0].shape_; + } else { + small = ReduceAxesShapeImpl(inputs[0].shape_, param.axes, true, false); + } + + ReduceAxesComputeImpl(ctx, {data}, {req[0]}, {mean}, small); + MSHADOW_TYPE_SWITCH(data.type_flag_, DType, { + Shape<5> data_shape, mean_shape; + for (int i = 0; i < 5; ++i) { + data_shape[i] = (i < data.shape_.ndim()) ? data.shape_[i] : 1; + mean_shape[i] = (i < small.ndim()) ? small[i] : 1; + } + Tensor temp_data = + ctx.requested[0].get_space_typed(Shape1(data.shape_.Size()), s);; + Kernel::Launch(s, data.shape_.Size(), temp_data.dptr_, + data.dptr(), mean.dptr(), data_shape, mean_shape); + ReduceAxesComputeImpl( + ctx, {TBlob(temp_data).reshape(data.shape_)}, {kWriteTo}, {var}, small); + }); +} + +template +struct VarBackwardKernel { + template + MSHADOW_XINLINE static void Map(int i, + DType *igrad, + const DType *ograd, + const DType *data, + const DType *mean, + mshadow::Shape<5> data_shape, + mshadow::Shape<5> mean_shape, + const float N, + const float ddof = 0.0f) { + size_t data_idx = i; + size_t mean_idx = i; + size_t data_stride = 1; + size_t mean_stride = 1; + for (int axis = 4; axis >= 0; --axis) { + size_t axis_idx = data_idx % data_shape[axis]; + mean_idx -= axis_idx * data_stride; + if (mean_shape[axis] != 1) { + mean_idx += axis_idx * mean_stride; + } + data_idx /= data_shape[axis]; + data_stride *= data_shape[axis]; + mean_stride *= mean_shape[axis]; + } + KERNEL_ASSIGN(igrad[i], req, ograd[mean_idx] * (data[i] - mean[mean_idx]) * 2 / (N - ddof)); + } +}; + +template +inline void MomentsBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mshadow_op; + using namespace mxnet_op; + + CHECK_EQ(inputs.size(), 5U); + CHECK_EQ(outputs.size(), 1U); + + const MomentsParam& param = nnvm::get(attrs.parsed); + Stream *s = ctx.get_stream(); + + const TBlob& mean_grad = inputs[0]; + const TBlob& var_grad = inputs[1]; + const TBlob& data = inputs[2]; + const TBlob& mean = inputs[3]; + const TBlob& var = inputs[4]; + const TBlob& data_grad = outputs[0]; + + mxnet::TShape small = ReduceAxesShapeImpl(data.shape_, param.axes, true, false); + BroadcastComputeImpl(attrs, ctx, {mean_grad}, req, outputs, small); + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + Tensor igrad = outputs[0].FlatTo1D(s); + igrad /= scalar(outputs[0].Size()/inputs[0].Size()); + }); + + Shape<5> data_shape, var_shape; + float N = data_grad.Size() / var.Size(); + for (int i = 0; i < 5; ++i) { + data_shape[i] = (i < data.shape_.ndim()) ? data.shape_[i] : 1; + var_shape[i] = (i < small.ndim()) ? small[i] : 1; + } + MSHADOW_TYPE_SWITCH(data_grad.type_flag_, DType, { + Kernel, xpu>::Launch( + s, data_grad.shape_.Size(), data_grad.dptr(), var_grad.dptr(), + data.dptr(), mean.dptr(), data_shape, var_shape, N); + }); +} + +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_NN_MOMENTS_INL_H_ diff --git a/src/operator/nn/moments.cc b/src/operator/nn/moments.cc new file mode 100644 index 000000000000..37b8cdf18750 --- /dev/null +++ b/src/operator/nn/moments.cc @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file moments.cc + * \brief Moments operator + * \author Hao Jin +*/ + +#include "./moments-inl.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(MomentsParam); + +NNVM_REGISTER_OP(moments) +.describe(R"code( +Calculate the mean and variance of `data`. + +The mean and variance are calculated by aggregating the contents of data across axes. +If x is 1-D and axes = [0] this is just the mean and variance of a vector. + +Example: + + x = [[1, 2, 3], [4, 5, 6]] + mean, var = moments(data=x, axes=[0]) + mean = [2.5, 3.5, 4.5] + var = [2.25, 2.25, 2.25] + mean, var = moments(data=x, axes=[1]) + mean = [2.0, 5.0] + var = [0.66666667, 0.66666667] + mean, var = moments(data=x, axis=[0, 1]) + mean = [3.5] + var = [2.9166667] + +)code" ADD_FILELINE) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(2) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data"}; + }) +.set_attr("FInferShape", MomentsShape) +.set_attr("FInferType", MomentsType) +.set_attr("FCompute", MomentsForward) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FGradient", ElemwiseGradUseInOut{"_backward_moments"}) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs) { + return std::vector >{{0, 0}}; + }) +.add_argument("data", "NDArray-or-Symbol", "Input ndarray") +.add_arguments(MomentsParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_moments) +.set_attr_parser(ParamParser) +.set_num_inputs(5) +.set_num_outputs(1) +.set_attr("TIsBackward", true) +.set_attr("FCompute", MomentsBackward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/nn/moments.cu b/src/operator/nn/moments.cu new file mode 100644 index 000000000000..a45ae33281be --- /dev/null +++ b/src/operator/nn/moments.cu @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file moments.cu + * \brief Moments operator + * \author Hao Jin +*/ + +#include "./moments-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(moments) +.set_attr("FCompute", MomentsForward); + +NNVM_REGISTER_OP(_backward_moments) +.set_attr("FCompute", MomentsBackward); + +} // namespace op +} // namespace mxnet