diff --git a/src/operator/nn/group_norm-inl.h b/src/operator/nn/group_norm-inl.h new file mode 100644 index 000000000000..75022ba24c06 --- /dev/null +++ b/src/operator/nn/group_norm-inl.h @@ -0,0 +1,341 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file group_norm-inl.h + * \brief Implements Group Normalization (https://arxiv.org/abs/1803.08494). + * \author Hao Jin +*/ + +#ifndef MXNET_OPERATOR_NN_GROUP_NORM_INL_H_ +#define MXNET_OPERATOR_NN_GROUP_NORM_INL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "./moments-inl.h" +#include "../mshadow_op.h" +#include "../operator_common.h" +#include "../mxnet_op.h" +#include "../tensor/broadcast_reduce_op.h" + +namespace mxnet { +namespace op { + +namespace groupnorm { +enum GroupNormOpInputs {kData, kGamma, kBeta}; // kGamma: scaling parameters, kBeta: shift biases +enum GroupNormOpOutputs {kOut, kMean, kStd}; // req, out_data +} // namespace groupnorm + +struct GroupNormParam : public dmlc::Parameter { + int num_groups; + float eps; + bool output_mean_var; + DMLC_DECLARE_PARAMETER(GroupNormParam) { + DMLC_DECLARE_FIELD(num_groups).set_default(1) + .describe("Total number of groups."); + DMLC_DECLARE_FIELD(eps).set_default(1e-5f) + .describe("An `epsilon` parameter to prevent division by 0."); + DMLC_DECLARE_FIELD(output_mean_var).set_default(false) + .describe("Output the mean and std calculated along the given axis."); + } +}; + + +template +void GroupNormCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mxnet_op; + const GroupNormParam& param = nnvm::get(attrs.parsed); + const int num_groups = param.num_groups; + if (req[0] == kNullOp) return; + CHECK_NE(req[0], kAddTo); + + Stream *s = ctx.get_stream(); + const TBlob& data = inputs[groupnorm::kData]; + const TBlob& mean = outputs[groupnorm::kMean]; + const TBlob& std = outputs[groupnorm::kStd]; + const mxnet::TShape& data_shape = data.shape_; + CHECK_GE(data_shape.ndim(), 3U) + << "input should have at least 3 dims and " + << "the first 2 dims should be batch and channel respectively"; + CHECK_EQ(data_shape[1] % num_groups, 0) + << "number of channel should be divisible by num_groups."; + + mxnet::TShape temp_data_shape(data_shape.ndim() + 1, 1); + temp_data_shape[0] = data_shape[0]; + temp_data_shape[1] = num_groups; + temp_data_shape[2] = data_shape[1] / num_groups; + for (int i = 2; i < data_shape.ndim(); ++i) { + temp_data_shape[i+1] = data_shape[i]; + } + + TBlob data_ = data.reshape(temp_data_shape); + const TBlob& output = outputs[groupnorm::kOut].reshape(temp_data_shape); + + mxnet::TShape axes(temp_data_shape.ndim() - 2, 2); + for (int i = 0; i < axes.ndim(); ++i) { + axes[i] = i + 2; + } + MomentsForwardImpl( + ctx, {data_}, {req[1], req[2]}, {mean, std}, dmlc::optional(axes), false); + // Now std actually holds var + + mxnet::TShape moments_shape(temp_data_shape.ndim(), 1); + for (int i = 0; i < data.shape_.ndim(); ++i) { + moments_shape[i] = (i < mean.shape_.ndim()) ? mean.shape_[i] : 1; + } + const TBlob& mean_ = mean.reshape(moments_shape); + const TBlob& std_ = std.reshape(moments_shape); + // Calculate data = data - mean + BinaryBroadcastCompute( + attrs, ctx, {data_, mean_}, {kWriteTo}, {output}); + + // Calculate std = sqrt(var + eps) + MSHADOW_REAL_TYPE_SWITCH(std.type_flag_, DType, { + Tensor std_tensor = std.FlatTo1D(s); + std_tensor = F(std_tensor + scalar(param.eps)); + }); + + // Calculate data = data / std + BinaryBroadcastCompute(attrs, ctx, + {output, std_}, + {kWriteTo}, {output}); + + mxnet::TShape new_param_shape(data_shape.ndim() + 1, 1); + new_param_shape[1] = num_groups; + + const TBlob& gamma = inputs[groupnorm::kGamma].reshape(new_param_shape); + const TBlob& beta = inputs[groupnorm::kBeta].reshape(new_param_shape); + + // Calculate data = data * gamma + BinaryBroadcastCompute(attrs, ctx, + {output, gamma}, + {kWriteTo}, {output}); + // Calculate data = data + beta + BinaryBroadcastCompute(attrs, ctx, + {output, beta}, + {kWriteTo}, {output}); +} + +/* +Calculate the gradient of layer normalization. +We have the following gradient for gamma, beta and x: + +\bar{x} = (x - mean) / std +w = og * r / std +grad_gamma = sum(\bar{x} og, exclude_axis) +grad_beta = sum(og, exclude_axis) +grad_x = w - mean(w, axis) - \bar{x} * mean(w * \bar{x}, axis) +*/ +template +void GroupNormGradCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mxnet_op; + CHECK_EQ(inputs.size(), 5U); + CHECK_EQ(outputs.size(), 3U); + const GroupNormParam& param = nnvm::get(attrs.parsed); + const int num_groups = param.num_groups; + + const TBlob& data = inputs[1]; + const mxnet::TShape& dshape = data.shape_; + + mxnet::TShape temp_dshape(dshape.ndim() + 1, 1); + temp_dshape[0] = dshape[0]; + temp_dshape[1] = num_groups; + temp_dshape[2] = dshape[1] / num_groups; + for (int i = 2; i < dshape.ndim(); ++i) { + temp_dshape[i+1] = dshape[i]; + } + const TBlob& data_ = data.reshape(temp_dshape); + const TBlob& ograd = inputs[0].reshape(temp_dshape); + + Stream *s = ctx.get_stream(); + // Reshape gamma to be broadcastable + mxnet::TShape new_param_shape(dshape.ndim() + 1, 1); + new_param_shape[1] = num_groups; + + const TBlob& gamma = inputs[2].reshape(new_param_shape); + + const TBlob& mean = inputs[3]; + const TBlob& std = inputs[4]; + + mxnet::TShape moments_shape(temp_dshape.ndim(), 1); + for (int i = 0; i < dshape.ndim(); ++i) { + moments_shape[i] = (i < mean.shape_.ndim()) ? mean.shape_[i] : 1; + } + const TBlob& mean_ = mean.reshape(moments_shape); + const TBlob& std_ = std.reshape(moments_shape); + + std::cout << ograd.shape_ << std::endl; + std::cout << temp_dshape << std::endl; + std::cout << new_param_shape << std::endl; + std::cout << mean.shape_ << std::endl; + std::cout << mean_.shape_ << std::endl; + + // Prepare the necessary shapes for reduction + mxnet::TShape red_src_shape, red_dst_shape, red_exclude_src_shape, red_exclude_dst_shape; + std::cout << "before" << std::endl; + BroadcastReduceShapeCompact(temp_dshape, mean_.shape_, &red_src_shape, &red_dst_shape); + std::cout << "between" << std::endl; + BroadcastReduceShapeCompact(temp_dshape, gamma.shape_, + &red_exclude_src_shape, &red_exclude_dst_shape); + std::cout << red_src_shape << std::endl; + std::cout << red_dst_shape << std::endl; + std::cout << red_exclude_src_shape << std::endl; + std::cout << red_exclude_dst_shape << std::endl; + + int N = red_src_shape.Size() / red_dst_shape.Size(); + + // Initialize the workspace + Construct the temporary TBlobs + Tensor workspace; + size_t reduce_workspace_size = 0; + size_t data_size = 0; + size_t red_out_size = 0; + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + data_size = sizeof(DType) * data.Size(); + red_out_size = sizeof(DType) * mean.Size(); + // There are two types of reduction workloads: reduce over axis and reduce exclude axis + // We take the maximum of the workspace sizes required by these workloads. + // Also, we explicitly set the req_type=kAddto in case we want to use it. + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + reduce_workspace_size = + std::max(reduce_workspace_size, + broadcast::ReduceWorkspaceSize(s, red_src_shape, + kAddTo, red_dst_shape)); + }); + BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { + reduce_workspace_size = + std::max(reduce_workspace_size, + broadcast::ReduceWorkspaceSize(s, red_exclude_src_shape, kAddTo, + red_exclude_dst_shape)); + }); + }); + workspace = ctx.requested[0].get_space_typed( + Shape1(reduce_workspace_size + data_size * 2 + red_out_size), s); + const TBlob normalized_data = + TBlob(workspace.dptr_ + reduce_workspace_size, + data_.shape_, data.dev_mask(), data.type_flag_, data.dev_id()); + const TBlob ograd_mult = TBlob(workspace.dptr_ + reduce_workspace_size + data_size, + data_.shape_, ograd.dev_mask(), ograd.type_flag_, ograd.dev_id()); + const TBlob red_out = TBlob(workspace.dptr_ + reduce_workspace_size + data_size * 2, + mean_.shape_, mean.dev_mask(), mean.type_flag_, mean.dev_id()); + // Compute normalized_data = (data - mean) / std + BinaryBroadcastCompute(attrs, ctx, + {data_, mean_}, + {kWriteTo}, {normalized_data}); + BinaryBroadcastCompute(attrs, ctx, + {normalized_data, std_}, + {kWriteTo}, {normalized_data}); + // Calculate grad_beta + if (req[2] != kNullOp) { + MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { + broadcast::Reduce( + s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace, + ograd.reshape(red_exclude_src_shape)); + }); + }); + } + // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis) + ElemwiseBinaryOp::Compute(attrs, ctx, {normalized_data, ograd}, + {kWriteTo}, {ograd_mult}); + if (req[1] != kNullOp) { + MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { + broadcast::Reduce( + s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace, + ograd_mult.reshape(red_exclude_src_shape)); + }); + }); + } + + // Calculate grad_data: + // ograd_mult = ograd * gamma / std + // grad_data = ograd_mult - mean(ograd_mult, axis) + // + normalized_data * (-mean(normalized_data * ograd_mult, axis)) + if (req[0] != kNullOp) { + std::cout << req[0] << std::endl; + std::cout << "before mul" << std::endl; + std::cout << "N is " << N << std::endl; + const TBlob output_ = outputs[0].reshape(data_.shape_); + std::cout << "here0" << std::endl; + BinaryBroadcastCompute(attrs, ctx, + {ograd, gamma}, + {kWriteTo}, {ograd_mult}); + BinaryBroadcastCompute(attrs, ctx, + {ograd_mult, std_}, + {kWriteTo}, {ograd_mult}); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + broadcast::Reduce( + s, red_out.reshape(red_dst_shape), kWriteTo, workspace, + ograd_mult.reshape(red_src_shape)); + }); + Tensor red_out_tensor = red_out.FlatTo1D(s); + red_out_tensor /= scalar(N); + }); + std::cout << "here1" << std::endl; + std::cout << ograd_mult.shape_ << std::endl; + std::cout << red_out.shape_ << std::endl; + std::cout << normalized_data.shape_ << std::endl; + BinaryBroadcastCompute(attrs, ctx, + {ograd_mult, red_out}, + {req[0]}, {output_}); + ElemwiseBinaryOp::Compute(attrs, ctx, {ograd_mult, normalized_data}, + {kWriteTo}, {ograd_mult}); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + broadcast::Reduce( + s, red_out.reshape(red_dst_shape), kWriteTo, workspace, + ograd_mult.reshape(red_src_shape)); + }); + Tensor red_out_tensor = red_out.FlatTo1D(s); + red_out_tensor /= scalar(-N); + }); + std::cout << "here2" << std::endl; + std::cout << normalized_data.shape_ << std::endl; + std::cout << red_out.shape_ << std::endl; + std::cout << output_.shape_ << std::endl; + BinaryBroadcastCompute(attrs, ctx, + {normalized_data, red_out}, + {kAddTo}, {output_}); + } +} + +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_NN_GROUP_NORM_INL_H_ diff --git a/src/operator/nn/group_norm.cc b/src/operator/nn/group_norm.cc new file mode 100644 index 000000000000..7d4b4fec1cf7 --- /dev/null +++ b/src/operator/nn/group_norm.cc @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file group_norm.cc + * \brief Implements Group Normalization (https://arxiv.org/abs/1803.08494). +*/ + +#include "group_norm-inl.h" +#include +#include "../elemwise_op_common.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(GroupNormParam); + +static bool GroupNormShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_shape, + mxnet::ShapeVector *out_shape) { + const GroupNormParam& param = nnvm::get(attrs.parsed); + using namespace mshadow; + CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]"; + const mxnet::TShape &dshape = in_shape->at(groupnorm::kData); + CHECK_GE(dshape.ndim(), 3U); + const int num_groups = param.num_groups; + CHECK_EQ(dshape[1] % num_groups, 0) << "# of channels must be divisible by # of groups"; + + if (!mxnet::ndim_is_known(dshape)) { + return false; + } + + in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(num_groups)); + in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(num_groups)); + + out_shape->clear(); + out_shape->push_back(dshape); + + mxnet::TShape moments_shape(2, 1); + moments_shape[0] = dshape[0]; + moments_shape[1] = num_groups; + out_shape->push_back(moments_shape); + out_shape->push_back(moments_shape); + return true; +} + +NNVM_REGISTER_OP(GroupNorm) +.describe(R"code(Group normalization. + +Normalizes the channels of the input tensor by mean and variance, and applies a scale ``gamma`` as +well as offset ``beta``. + +Assume the input has more than one dimension and we normalize along axis 1. +We first compute the mean and variance along this axis and then +compute the normalized output, which has the same shape as input, as following: + +.. math:: + + out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis) + \epsilon}} * gamma + beta + +Both ``gamma`` and ``beta`` are learnable parameters. + +Unlike BatchNorm and InstanceNorm, the *mean* and *var* are computed along the channel dimension. + +Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta`` +have shape *(k,)*. If ``output_mean_var`` is set to be true, then outputs both ``data_mean`` and +``data_std``. Note that no gradient will be passed through these two outputs. + +The parameter ``axis`` specifies which axis of the input shape denotes +the 'channel' (separately normalized groups). The default is -1, which sets the channel +axis to be the last item in the input shape. + +)code" ADD_FILELINE) +.set_num_inputs(3) +.set_num_outputs(3) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "gamma", "beta"}; +}) +.set_attr("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector{"output", "mean", "std"}; +}) +.set_attr("FNumVisibleOutputs", + [](const NodeAttrs& attrs) { + const GroupNormParam& param = nnvm::get(attrs.parsed); + return param.output_mean_var ? 3 : 1; +}) +.set_attr("FInferShape", GroupNormShape) +.set_attr("FInferType", ElemwiseType<3, 3>) +.set_attr("FCompute", GroupNormCompute) +.set_attr("FGradient", [](const nnvm::NodePtr& n, + const std::vector& ograds) { + std::vector heads; + heads.push_back(ograds[0]); // ograd + heads.push_back(n->inputs[0]); // data + heads.push_back(n->inputs[1]); // gamma + heads.emplace_back(nnvm::NodeEntry{n, 1, 0}); // mean + heads.emplace_back(nnvm::NodeEntry{ n, 2, 0 }); // std + return MakeGradNode("_backward_GroupNorm", n, heads, n->attrs.dict); +}) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs) { + return std::vector >{{0, 0}}; +}) +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}) +.add_argument("data", "NDArray-or-Symbol", "Input data") +.add_argument("gamma", "NDArray-or-Symbol", "gamma array") +.add_argument("beta", "NDArray-or-Symbol", "beta array") +.add_arguments(GroupNormParam::__FIELDS__()); + + +NNVM_REGISTER_OP(_backward_GroupNorm) +.set_num_inputs(5) +.set_num_outputs(3) +.set_attr("TIsBackward", true) +.set_attr_parser(ParamParser) +.set_attr("FCompute", GroupNormGradCompute) +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/nn/group_norm.cu b/src/operator/nn/group_norm.cu new file mode 100644 index 000000000000..136c3337468c --- /dev/null +++ b/src/operator/nn/group_norm.cu @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file group_norm.cu + * \brief Implements Group Normalization (https://arxiv.org/abs/1803.08494). +*/ +#include "./group_norm-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(GroupNorm) +.set_attr("FCompute", GroupNormCompute); + +NNVM_REGISTER_OP(_backward_GroupNorm) +.set_attr("FCompute", GroupNormGradCompute); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/nn/moments-inl.h b/src/operator/nn/moments-inl.h new file mode 100644 index 000000000000..6a9bdc54b905 --- /dev/null +++ b/src/operator/nn/moments-inl.h @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file moments-inl.h + * \brief Moments operator + * \author Hao Jin +*/ + +#ifndef MXNET_OPERATOR_NN_MOMENTS_INL_H_ +#define MXNET_OPERATOR_NN_MOMENTS_INL_H_ + +#include +#include "../tensor/broadcast_reduce_op.h" + +namespace mxnet { +namespace op { + +struct MomentsParam : public dmlc::Parameter { + dmlc::optional axes; + bool keepdims; + DMLC_DECLARE_PARAMETER(MomentsParam) { + DMLC_DECLARE_FIELD(axes).set_default(dmlc::optional()) + .describe("Array of ints. Axes along which to compute mean and variance."); + DMLC_DECLARE_FIELD(keepdims).set_default(false) + .describe("produce moments with the same dimensionality as the input."); + } +}; + +inline bool MomentsShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { + const MomentsParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 2U); + + mxnet::TShape out_shape = + ReduceAxesShapeImpl((*in_attrs)[0], param.axes, param.keepdims, false); + if (!param.axes.has_value() || param.axes.value().ndim() == 0) { + LOG(FATAL) << "Empty axes is not supported, if you would like to do global moments, " + << "please pass all axes to axes argument"; + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, out_shape); + SHAPE_ASSIGN_CHECK(*out_attrs, 1, out_shape); + return true; +} + +inline bool MomentsType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 2U); + + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(1)); + return out_attrs->at(0) != -1 && out_attrs->at(1) != -1; +} + +struct VarBroadcastKernel { + template + MSHADOW_XINLINE static void Map(int i, + DType *out, + const DType *data, + const DType *mean, + mshadow::Shape<6> data_shape, + mshadow::Shape<6> mean_shape) { + size_t data_idx = i; + size_t mean_idx = i; + size_t data_stride = 1; + size_t mean_stride = 1; + for (int axis = 5; axis >= 0; --axis) { + size_t axis_idx = data_idx % data_shape[axis]; + mean_idx -= axis_idx * data_stride; + if (mean_shape[axis] != 1) { + mean_idx += axis_idx * mean_stride; + } + data_idx /= data_shape[axis]; + data_stride *= data_shape[axis]; + mean_stride *= mean_shape[axis]; + } + DType res = (data[i] - mean[mean_idx]); + out[i] = res * res; + } +}; + +template +inline void MomentsForwardImpl(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + const dmlc::optional& axes, + const bool keepdims) { + using namespace mshadow; + using namespace mshadow_op; + using namespace mxnet_op; + + Stream *s = ctx.get_stream(); + + const TBlob& data = inputs[0]; + const TBlob& mean = outputs[0]; + const TBlob& var = outputs[1]; + + mxnet::TShape small; + if (keepdims) { + small = outputs[0].shape_; + } else { + small = ReduceAxesShapeImpl(inputs[0].shape_, axes, true, false); + } + + ReduceAxesComputeImpl(ctx, {data}, {req[0]}, {mean}, small); + MSHADOW_TYPE_SWITCH(data.type_flag_, DType, { + Shape<6> data_shape, mean_shape; + for (int i = 0; i < 6; ++i) { + data_shape[i] = (i < data.shape_.ndim()) ? data.shape_[i] : 1; + mean_shape[i] = (i < small.ndim()) ? small[i] : 1; + } + Tensor temp_data = + ctx.requested[0].get_space_typed(Shape1(data.shape_.Size()), s);; + Kernel::Launch(s, data.shape_.Size(), temp_data.dptr_, + data.dptr(), mean.dptr(), data_shape, mean_shape); + ReduceAxesComputeImpl( + ctx, {TBlob(temp_data).reshape(data.shape_)}, {kWriteTo}, {var}, small); + }); +} + +template +inline void MomentsForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow_op; + using namespace mxnet_op; + + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 2U); + + const MomentsParam& param = nnvm::get(attrs.parsed); + + MomentsForwardImpl(ctx, inputs, req, outputs, param.axes, param.keepdims); +} + +template +struct VarBackwardKernel { + template + MSHADOW_XINLINE static void Map(int i, + DType *igrad, + const DType *ograd, + const DType *data, + const DType *mean, + mshadow::Shape<6> data_shape, + mshadow::Shape<6> mean_shape, + const float N, + const float ddof = 0.0f) { + size_t data_idx = i; + size_t mean_idx = i; + size_t data_stride = 1; + size_t mean_stride = 1; + for (int axis = 5; axis >= 0; --axis) { + size_t axis_idx = data_idx % data_shape[axis]; + mean_idx -= axis_idx * data_stride; + if (mean_shape[axis] != 1) { + mean_idx += axis_idx * mean_stride; + } + data_idx /= data_shape[axis]; + data_stride *= data_shape[axis]; + mean_stride *= mean_shape[axis]; + } + KERNEL_ASSIGN(igrad[i], req, ograd[mean_idx] * (data[i] - mean[mean_idx]) * 2 / (N - ddof)); + } +}; + +template +inline void MomentsBackwardImpl(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + const dmlc::optional& axes) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mshadow_op; + using namespace mxnet_op; + + Stream *s = ctx.get_stream(); + + const TBlob& mean_grad = inputs[0]; + const TBlob& var_grad = inputs[1]; + const TBlob& data = inputs[2]; + const TBlob& mean = inputs[3]; + const TBlob& var = inputs[4]; + const TBlob& data_grad = outputs[0]; + + mxnet::TShape small = ReduceAxesShapeImpl(data.shape_, axes, true, false); + BroadcastComputeImpl(attrs, ctx, {mean_grad}, req, outputs, small); + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + Tensor igrad = outputs[0].FlatTo1D(s); + igrad /= scalar(outputs[0].Size()/inputs[0].Size()); + }); + + Shape<6> data_shape, var_shape; + float N = data_grad.Size() / var.Size(); + for (int i = 0; i < 6; ++i) { + data_shape[i] = (i < data.shape_.ndim()) ? data.shape_[i] : 1; + var_shape[i] = (i < small.ndim()) ? small[i] : 1; + } + MSHADOW_TYPE_SWITCH(data_grad.type_flag_, DType, { + Kernel, xpu>::Launch( + s, data_grad.shape_.Size(), data_grad.dptr(), var_grad.dptr(), + data.dptr(), mean.dptr(), data_shape, var_shape, N); + }); +} + +template +inline void MomentsBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow_op; + using namespace mxnet_op; + + CHECK_EQ(inputs.size(), 5U); + CHECK_EQ(outputs.size(), 1U); + + const MomentsParam& param = nnvm::get(attrs.parsed); + + MomentsBackwardImpl(attrs, ctx, inputs, req, outputs, param.axes); +} + +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_NN_MOMENTS_INL_H_ diff --git a/src/operator/nn/moments.cc b/src/operator/nn/moments.cc new file mode 100644 index 000000000000..37b8cdf18750 --- /dev/null +++ b/src/operator/nn/moments.cc @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file moments.cc + * \brief Moments operator + * \author Hao Jin +*/ + +#include "./moments-inl.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(MomentsParam); + +NNVM_REGISTER_OP(moments) +.describe(R"code( +Calculate the mean and variance of `data`. + +The mean and variance are calculated by aggregating the contents of data across axes. +If x is 1-D and axes = [0] this is just the mean and variance of a vector. + +Example: + + x = [[1, 2, 3], [4, 5, 6]] + mean, var = moments(data=x, axes=[0]) + mean = [2.5, 3.5, 4.5] + var = [2.25, 2.25, 2.25] + mean, var = moments(data=x, axes=[1]) + mean = [2.0, 5.0] + var = [0.66666667, 0.66666667] + mean, var = moments(data=x, axis=[0, 1]) + mean = [3.5] + var = [2.9166667] + +)code" ADD_FILELINE) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(2) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data"}; + }) +.set_attr("FInferShape", MomentsShape) +.set_attr("FInferType", MomentsType) +.set_attr("FCompute", MomentsForward) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FGradient", ElemwiseGradUseInOut{"_backward_moments"}) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs) { + return std::vector >{{0, 0}}; + }) +.add_argument("data", "NDArray-or-Symbol", "Input ndarray") +.add_arguments(MomentsParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_moments) +.set_attr_parser(ParamParser) +.set_num_inputs(5) +.set_num_outputs(1) +.set_attr("TIsBackward", true) +.set_attr("FCompute", MomentsBackward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/nn/moments.cu b/src/operator/nn/moments.cu new file mode 100644 index 000000000000..a45ae33281be --- /dev/null +++ b/src/operator/nn/moments.cu @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file moments.cu + * \brief Moments operator + * \author Hao Jin +*/ + +#include "./moments-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(moments) +.set_attr("FCompute", MomentsForward); + +NNVM_REGISTER_OP(_backward_moments) +.set_attr("FCompute", MomentsBackward); + +} // namespace op +} // namespace mxnet diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 7db07596d7f8..e04030b710a8 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1786,6 +1786,98 @@ def _test_batchnorm_impl(op, shape, axis, cudnn_off, output_mean_var): cudnn_off, output_mean_var) +@with_seed() +def test_groupnorm(): + acc_types = {'float16': 'float32', 'float32': 'float64', 'float64': 'float64'} + def x_hat_helper(x, num_groups, eps): + dtype = x.dtype + dshape = x.shape + assert len(dshape) == 4 + acc_type = acc_types[str(dtype)] + new_shape = (dshape[0], num_groups, int(dshape[1] / num_groups), dshape[2], dshape[3]) + new_moments_shape = (dshape[0], num_groups, 1, 1, 1) + data = x.reshape(new_shape) + mean = np.mean(data, axis=(2, 3, 4), keepdims=False, dtype=acc_type).astype(dtype) + std = np.sqrt(np.var(data, axis=(2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype) + eps) + x_hat = (data - mean.reshape(new_moments_shape)) / std.reshape(new_moments_shape) + return x_hat, mean, std + + def np_groupnorm(data, gamma, beta, num_groups, eps): + new_param_shape = (1, num_groups, 1, 1, 1) + x_hat, mean, std = x_hat_helper(data, num_groups, eps) + out = x_hat * gamma.reshape(new_param_shape) + beta.reshape(new_param_shape) + return out.reshape(dshape), mean, std + + def np_groupnorm_grad(ograd, data, gamma, beta, mean, std, num_groups, eps): + x_hat, mean, std = x_hat_helper(data, num_groups, eps) + new_shape = x_hat.shape + dshape = data.shape + dtype = data.dtype + new_moments_shape = (new_shape[0], num_groups, 1, 1, 1) + new_param_shape = (1, num_groups, 1, 1, 1) + acc_type = acc_types[str(dtype)] + ograd = ograd.reshape(new_shape) + data = data.reshape(new_shape) + gamma = gamma.reshape(new_param_shape) + beta = beta.reshape(new_param_shape) + mean = mean.reshape(new_moments_shape) + std = std.reshape(new_moments_shape) + beta_grad = np.sum(ograd, axis=(0, 2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype) + gamma_grad = np.sum(x_hat * ograd, axis=(0, 2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype) + x_hat_grad = ograd * gamma + var_grad = -x_hat_grad * x_hat / 2. / std / std + mean_grad = -x_hat_grad / std + N = data.size / mean.size + data_grad = x_hat_grad / std + mean_grad / N + var_grad * 2. * (data - mean) / N + return data_grad.reshape(dshape), gamma_grad, beta_grad + + + batch_size = random.randint(1, 8) + num_groups = random.randint(2, 3) + num_channels = random.randint(2, 3) * num_groups + height = random.randint(1, 5) + width = random.randint(1, 5) + dshape = (batch_size, num_channels, height, width) + param_shape = (num_groups,) + temp_shape = (batch_size, num_groups, int(num_channels / num_groups), height, width) + np_data = np.random.uniform(0.2, 1.0, dshape) + np_gamma = np.random.uniform(-1.0, 1.0, param_shape) + np_beta = np.random.uniform(-1.0, 1.0, param_shape) + data_sym = mx.sym.Variable("data") + gamma_sym = mx.sym.Variable("gamma") + beta_sym = mx.sym.Variable("beta") + # for dtype in [np.float16]: + for dtype in [np.float64]:# np.float32, np.float64]: + eps = 1e-2 if dtype == np.float16 else 1e-5 + # rtol = 1e-2 if dtype == np.float16 else 1e-4 + # atol = 4e-3 if dtype == np.float16 else 1e-6 + rtol = 1e-2 if dtype == np.float16 else 1e-3 + atol = 2e-2 if dtype == np.float16 else 1e-5 + mx_data = mx.nd.array(np_data, dtype=dtype) + mx_gamma = mx.nd.array(np_gamma, dtype=dtype) + mx_beta = mx.nd.array(np_beta, dtype=dtype) + np_out, np_mean, np_std = np_groupnorm(np_data.astype(dtype), + np_gamma.astype(dtype), + np_beta.astype(dtype), + num_groups=num_groups, + eps=eps) + mx_sym = mx.sym.GroupNorm(data=data_sym, gamma=gamma_sym, beta=beta_sym, + num_groups=num_groups, eps=eps, output_mean_var=True) + check_symbolic_forward(mx_sym, [mx_data, mx_gamma, mx_beta], [np_out, np_mean, np_std], + rtol=rtol, atol=atol, dtype=dtype) + mx_sym = mx.sym.GroupNorm(data=data_sym, gamma=gamma_sym, beta=beta_sym, + num_groups=num_groups, eps=eps, output_mean_var=False) + np_data_grad, np_gamma_grad, np_beta_grad = np_groupnorm_grad(np.ones(np_out.shape, dtype=dtype), + np_data.astype(dtype), + np_gamma.astype(dtype), + np_beta.astype(dtype), + np_mean, np_std, + num_groups, eps) + check_symbolic_backward(mx_sym, [mx_data, mx_gamma, mx_beta], [mx.nd.ones(dshape)], + [np_data_grad, np_gamma_grad, np_beta_grad], + rtol=rtol, atol=atol, dtype=dtype) + + @with_seed() def test_convolution_grouping(): for dim in [1, 2, 3]: