diff --git a/python/mxnet/contrib/amp/lists/symbol.py b/python/mxnet/contrib/amp/lists/symbol.py
index 9a587dfa73c1..c6cc3d1b1f00 100644
--- a/python/mxnet/contrib/amp/lists/symbol.py
+++ b/python/mxnet/contrib/amp/lists/symbol.py
@@ -471,6 +471,7 @@
     'log_softmax',
     'InstanceNorm',
     'LayerNorm',
+    'GroupNorm',
     'L2Normalization',
     'LRN',
     'SoftmaxActivation',
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 3d6976c32740..b1482ce6dd82 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -19,7 +19,8 @@
 # pylint: disable= arguments-differ
 """Basic neural network layers."""
 __all__ = ['Sequential', 'HybridSequential', 'Dense', 'Dropout', 'Embedding',
-           'BatchNorm', 'InstanceNorm', 'LayerNorm', 'Flatten', 'Lambda', 'HybridLambda']
+           'BatchNorm', 'InstanceNorm', 'LayerNorm', 'GroupNorm',
+           'Flatten', 'Lambda', 'HybridLambda']
 
 import warnings
 import numpy as np
@@ -616,6 +617,94 @@ def __repr__(self):
                                  for k, v in self._kwargs.items()]))
 
 
+class GroupNorm(HybridBlock):
+    r"""
+    Applies group normalization to the n-dimensional input array.
+    This operator takes an n-dimensional input array where the leftmost 2 axes are
+    `batch` and `channel` respectively:
+
+    .. math::
+
+      x = x.reshape((N, num_groups, C // num_groups, ...))
+      axis = (2, ...)
+      out = \frac{x - mean[x, axis]}{ \sqrt{Var[x, axis] + \epsilon}} * gamma + beta
+
+    Parameters
+    ----------
+    num_groups: int, default 1
+        Number of groups to separate the channel axis into.
+    epsilon: float, default 1e-5
+        Small float added to variance to avoid dividing by zero.
+    center: bool, default True
+        If True, add offset of `beta` to normalized tensor.
+        If False, `beta` is ignored.
+    scale: bool, default True
+        If True, multiply by `gamma`. If False, `gamma` is not used.
+    beta_initializer: str or `Initializer`, default 'zeros'
+        Initializer for the beta weight.
+    gamma_initializer: str or `Initializer`, default 'ones'
+        Initializer for the gamma weight.
+
+
+    Inputs:
+        - **data**: input tensor with shape (N, C, ...).
+
+    Outputs:
+        - **out**: output tensor with the same shape as `data`.
+
+    References
+    ----------
+        `Group Normalization
+        <https://arxiv.org/abs/1803.08494>`_
+
+    Examples
+    --------
+    >>> # Input of shape (2, 3, 4)
+    >>> x = mx.nd.array([[[ 0,  1,  2,  3],
+                          [ 4,  5,  6,  7],
+                          [ 8,  9, 10, 11]],
+                         [[12, 13, 14, 15],
+                          [16, 17, 18, 19],
+                          [20, 21, 22, 23]]])
+    >>> # Group normalization is calculated with the above formula
+    >>> layer = GroupNorm()
+    >>> layer.initialize(ctx=mx.cpu(0))
+    >>> layer(x)
+    [[[-1.5932543 -1.3035717 -1.0138891 -0.7242065]
+      [-0.4345239 -0.1448413  0.1448413  0.4345239]
+      [ 0.7242065  1.0138891  1.3035717  1.5932543]]
+     [[-1.5932543 -1.3035717 -1.0138891 -0.7242065]
+      [-0.4345239 -0.1448413  0.1448413  0.4345239]
+      [ 0.7242065  1.0138891  1.3035717  1.5932543]]]
+    """
+    def __init__(self, num_groups=1, epsilon=1e-5, center=True, scale=True,
+                 beta_initializer='zeros', gamma_initializer='ones',
+                 prefix=None, params=None):
+        super(GroupNorm, self).__init__(prefix=prefix, params=params)
+        self._kwargs = {'eps': epsilon, 'num_groups': num_groups, 'center': center, 'scale': scale}
+        self._num_groups = num_groups
+        self._epsilon = epsilon
+        self._center = center
+        self._scale = scale
+        self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
+                                     shape=(num_groups,), init=gamma_initializer,
+                                     allow_deferred_init=True)
+        self.beta = self.params.get('beta', grad_req='write' if center else 'null',
+                                    shape=(num_groups,), init=beta_initializer,
+                                    allow_deferred_init=True)
+
+    def hybrid_forward(self, F, data, gamma, beta):
+        norm_data = F.GroupNorm(data, gamma=gamma, beta=beta,
+                                num_groups=self._num_groups, eps=self._epsilon)
+        return norm_data
+
+    def __repr__(self):
+        s = '{name}({content})'
+        return s.format(name=self.__class__.__name__,
+                        content=', '.join(['='.join([k, v.__repr__()])
+                                           for k, v in self._kwargs.items()]))
+
+
 class Lambda(Block):
     r"""Wraps an operator or an expression as a Block object.
 
diff --git a/src/operator/nn/group_norm-inl.h b/src/operator/nn/group_norm-inl.h
new file mode 100644
index 000000000000..69d5a304dc2c
--- /dev/null
+++ b/src/operator/nn/group_norm-inl.h
@@ -0,0 +1,347 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file group_norm-inl.h
+ * \brief Implements Group Normalization (https://arxiv.org/abs/1803.08494).
+ * \author Hao Jin +*/ + +#ifndef MXNET_OPERATOR_NN_GROUP_NORM_INL_H_ +#define MXNET_OPERATOR_NN_GROUP_NORM_INL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "./moments-inl.h" +#include "../mshadow_op.h" +#include "../operator_common.h" +#include "../mxnet_op.h" +#include "../tensor/broadcast_reduce_op.h" + +namespace mxnet { +namespace op { + +namespace groupnorm { +enum GroupNormOpInputs {kData, kGamma, kBeta}; // kGamma: scaling parameters, kBeta: shift biases +enum GroupNormOpOutputs {kOut, kMean, kStd}; // req, out_data +} // namespace groupnorm + +struct GroupNormParam : public dmlc::Parameter { + int num_groups; + float eps; + bool output_mean_var; + DMLC_DECLARE_PARAMETER(GroupNormParam) { + DMLC_DECLARE_FIELD(num_groups).set_default(1) + .describe("Total number of groups."); + DMLC_DECLARE_FIELD(eps).set_default(1e-5f) + .describe("An `epsilon` parameter to prevent division by 0."); + DMLC_DECLARE_FIELD(output_mean_var).set_default(false) + .describe("Output the mean and std calculated along the given axis."); + } +}; + + +template +void GroupNormCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mxnet_op; + const GroupNormParam& param = nnvm::get(attrs.parsed); + const int num_groups = param.num_groups; + if (req[0] == kNullOp) return; + CHECK_NE(req[0], kAddTo); + + Stream *s = ctx.get_stream(); + const TBlob& data = inputs[groupnorm::kData]; + const TBlob& mean = outputs[groupnorm::kMean]; + const TBlob& std = outputs[groupnorm::kStd]; + const mxnet::TShape& data_shape = data.shape_; + CHECK_GE(data_shape.ndim(), 3U) + << "input should have at least 3 dims and " + << "the first 2 dims should be batch and channel respectively"; + CHECK_EQ(data_shape[1] % num_groups, 0) + << "number of channel should be divisible by num_groups."; + + mxnet::TShape temp_data_shape(data_shape.ndim() + 1, 1); + temp_data_shape[0] = data_shape[0]; + temp_data_shape[1] = num_groups; + temp_data_shape[2] = data_shape[1] / num_groups; + for (int i = 2; i < data_shape.ndim(); ++i) { + temp_data_shape[i+1] = data_shape[i]; + } + + mxnet::TShape moments_shape(temp_data_shape.ndim(), 1); + for (int i = 0; i < data.shape_.ndim(); ++i) { + moments_shape[i] = (i < mean.shape_.ndim()) ? 
mean.shape_[i] : 1; + } + + mxnet::TShape red_src_shape, red_dst_shape; + BroadcastReduceShapeCompact(temp_data_shape, moments_shape, &red_src_shape, &red_dst_shape); + int channel_size = red_src_shape.Size() / red_dst_shape.Size(); + + TBlob data_ = data.reshape(red_src_shape); + const TBlob& mean_ = mean.reshape(red_dst_shape); + const TBlob& std_ = std.reshape(red_dst_shape); + + Tensor workspace; + + size_t workspace_size = 0; + MSHADOW_REAL_TYPE_SWITCH(data.type_flag_, DType, { + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + workspace_size = + broadcast::ReduceWorkspaceSize(s, red_dst_shape, req[0], red_src_shape); + }); + }); + + workspace = ctx.requested[0].get_space_typed(Shape1(workspace_size), s); + + // Calculate mean + MSHADOW_REAL_TYPE_SWITCH(data.type_flag_, DType, { + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + broadcast::Reduce( + s, mean_, req[0], workspace, data_); + Tensor mean_data_tensor = mean_.FlatTo1D(s); + mean_data_tensor /= scalar(channel_size); + }); + }); + + TBlob data_grp = data.reshape(temp_data_shape); + const TBlob& mean_grp = mean.reshape(moments_shape); + const TBlob& std_grp = std.reshape(moments_shape); + const TBlob& output = outputs[groupnorm::kOut].reshape(temp_data_shape); + + // Calculate data = data - mean + BinaryBroadcastCompute(attrs, ctx, + {data_grp, mean_grp}, + {kWriteTo}, {output}); + + // Calculate std + const TBlob centered_out = outputs[groupnorm::kOut].reshape(red_src_shape); + MSHADOW_REAL_TYPE_SWITCH(output.type_flag_, DType, { + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + broadcast::Reduce( + s, std_, req[0], workspace, centered_out); + Tensor std_data_tensor = std_.FlatTo1D(s); + std_data_tensor = F(std_data_tensor / scalar(channel_size) + + scalar(param.eps)); + }); + }); + + // Calculate data = data / std + BinaryBroadcastCompute(attrs, ctx, + {output, std_grp}, + {kWriteTo}, {output}); + + mxnet::TShape new_param_shape(data_shape.ndim() + 1, 1); + new_param_shape[1] = num_groups; + + const TBlob& gamma = inputs[groupnorm::kGamma].reshape(new_param_shape); + const TBlob& beta = inputs[groupnorm::kBeta].reshape(new_param_shape); + + // Calculate data = data * gamma + BinaryBroadcastCompute(attrs, ctx, + {output, gamma}, + {kWriteTo}, {output}); + // Calculate data = data + beta + BinaryBroadcastCompute(attrs, ctx, + {output, beta}, + {kWriteTo}, {output}); +} + +/* +Calculate the gradient of group normalization. 
+We have the following gradient for gamma, beta and x: + +\bar{x} = (x - mean) / std +w = og * r / std +grad_gamma = sum(\bar{x} og, exclude_axis) +grad_beta = sum(og, exclude_axis) +grad_x = w - mean(w, axis) - \bar{x} * mean(w * \bar{x}, axis) +*/ +template +void GroupNormGradCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mshadow::expr; + using namespace mxnet_op; + CHECK_EQ(inputs.size(), 5U); + CHECK_EQ(outputs.size(), 3U); + const GroupNormParam& param = nnvm::get(attrs.parsed); + const int num_groups = param.num_groups; + + const TBlob& data = inputs[1]; + const mxnet::TShape& dshape = data.shape_; + + mxnet::TShape temp_dshape(dshape.ndim() + 1, 1); + temp_dshape[0] = dshape[0]; + temp_dshape[1] = num_groups; + temp_dshape[2] = dshape[1] / num_groups; + for (int i = 2; i < dshape.ndim(); ++i) { + temp_dshape[i+1] = dshape[i]; + } + const TBlob& data_ = data.reshape(temp_dshape); + const TBlob& ograd = inputs[0].reshape(temp_dshape); + + Stream *s = ctx.get_stream(); + // Reshape gamma to be broadcastable + mxnet::TShape new_param_shape(dshape.ndim() + 1, 1); + new_param_shape[1] = num_groups; + + const TBlob& gamma = inputs[2].reshape(new_param_shape); + + const TBlob& mean = inputs[3]; + const TBlob& std = inputs[4]; + + mxnet::TShape moments_shape(temp_dshape.ndim(), 1); + for (int i = 0; i < dshape.ndim(); ++i) { + moments_shape[i] = (i < mean.shape_.ndim()) ? mean.shape_[i] : 1; + } + const TBlob& mean_ = mean.reshape(moments_shape); + const TBlob& std_ = std.reshape(moments_shape); + + // Prepare the necessary shapes for reduction + mxnet::TShape red_src_shape, red_dst_shape, red_exclude_src_shape, red_exclude_dst_shape; + BroadcastReduceShapeCompact(temp_dshape, mean_.shape_, &red_src_shape, &red_dst_shape); + BroadcastReduceShapeCompact(temp_dshape, gamma.shape_, + &red_exclude_src_shape, &red_exclude_dst_shape); + + int N = red_src_shape.Size() / red_dst_shape.Size(); + + // Initialize the workspace + Construct the temporary TBlobs + Tensor workspace; + size_t reduce_workspace_size = 0; + size_t data_size = 0; + size_t red_out_size = 0; + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + data_size = sizeof(DType) * data.Size(); + red_out_size = sizeof(DType) * mean.Size(); + // There are two types of reduction workloads: reduce over axis and reduce exclude axis + // We take the maximum of the workspace sizes required by these workloads. + // Also, we explicitly set the req_type=kAddto in case we want to use it. 
+ BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + reduce_workspace_size = + std::max(reduce_workspace_size, + broadcast::ReduceWorkspaceSize(s, red_dst_shape, + kAddTo, red_src_shape)); + }); + BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { + reduce_workspace_size = + std::max(reduce_workspace_size, + broadcast::ReduceWorkspaceSize(s, red_exclude_dst_shape, kAddTo, + red_exclude_src_shape)); + }); + }); + workspace = ctx.requested[0].get_space_typed( + Shape1(reduce_workspace_size + data_size * 2 + red_out_size), s); + const TBlob normalized_data = + TBlob(workspace.dptr_ + reduce_workspace_size, + data_.shape_, data.dev_mask(), data.type_flag_, data.dev_id()); + const TBlob ograd_mult = TBlob(workspace.dptr_ + reduce_workspace_size + data_size, + data_.shape_, ograd.dev_mask(), ograd.type_flag_, ograd.dev_id()); + const TBlob red_out = TBlob(workspace.dptr_ + reduce_workspace_size + data_size * 2, + mean_.shape_, mean.dev_mask(), mean.type_flag_, mean.dev_id()); + // Compute normalized_data = (data - mean) / std + BinaryBroadcastCompute(attrs, ctx, + {data_, mean_}, + {kWriteTo}, {normalized_data}); + BinaryBroadcastCompute(attrs, ctx, + {normalized_data, std_}, + {kWriteTo}, {normalized_data}); + // Calculate grad_beta + if (req[2] != kNullOp) { + MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { + broadcast::Reduce( + s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace, + ograd.reshape(red_exclude_src_shape)); + }); + }); + } + // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis) + ElemwiseBinaryOp::Compute(attrs, ctx, {normalized_data, ograd}, + {kWriteTo}, {ograd_mult}); + if (req[1] != kNullOp) { + MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, { + broadcast::Reduce( + s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace, + ograd_mult.reshape(red_exclude_src_shape)); + }); + }); + } + + // Calculate grad_data: + // ograd_mult = ograd * gamma / std + // grad_data = ograd_mult - mean(ograd_mult, axis) + // + normalized_data * (-mean(normalized_data * ograd_mult, axis)) + if (req[0] != kNullOp) { + const TBlob output_ = outputs[0].reshape(data_.shape_); + BinaryBroadcastCompute(attrs, ctx, + {ograd, gamma}, + {kWriteTo}, {ograd_mult}); + BinaryBroadcastCompute(attrs, ctx, + {ograd_mult, std_}, + {kWriteTo}, {ograd_mult}); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + broadcast::Reduce( + s, red_out.reshape(red_dst_shape), kWriteTo, workspace, + ograd_mult.reshape(red_src_shape)); + }); + Tensor red_out_tensor = red_out.FlatTo1D(s); + red_out_tensor /= scalar(N); + }); + BinaryBroadcastCompute(attrs, ctx, + {ograd_mult, red_out}, + {req[0]}, {output_}); + ElemwiseBinaryOp::Compute(attrs, ctx, {ograd_mult, normalized_data}, + {kWriteTo}, {ograd_mult}); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, { + broadcast::Reduce( + s, red_out.reshape(red_dst_shape), kWriteTo, workspace, + ograd_mult.reshape(red_src_shape)); + }); + Tensor red_out_tensor = red_out.FlatTo1D(s); + red_out_tensor /= scalar(-N); + }); + BinaryBroadcastCompute(attrs, ctx, + {normalized_data, red_out}, + {kAddTo}, {output_}); + } +} + +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_NN_GROUP_NORM_INL_H_ diff --git a/src/operator/nn/group_norm.cc 
b/src/operator/nn/group_norm.cc new file mode 100644 index 000000000000..b4698abeff83 --- /dev/null +++ b/src/operator/nn/group_norm.cc @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file group_norm.cc + * \brief Implements Group Normalization (https://arxiv.org/abs/1803.08494). +*/ + +#include "group_norm-inl.h" +#include +#include "../elemwise_op_common.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(GroupNormParam); + +static bool GroupNormShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_shape, + mxnet::ShapeVector *out_shape) { + const GroupNormParam& param = nnvm::get(attrs.parsed); + using namespace mshadow; + CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]"; + const mxnet::TShape &dshape = in_shape->at(groupnorm::kData); + CHECK_GE(dshape.ndim(), 3U); + const int num_groups = param.num_groups; + CHECK_EQ(dshape[1] % num_groups, 0) << "# of channels must be divisible by # of groups"; + + if (!mxnet::ndim_is_known(dshape)) { + return false; + } + + in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(num_groups)); + in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(num_groups)); + + out_shape->clear(); + out_shape->push_back(dshape); + + mxnet::TShape moments_shape(2, 1); + moments_shape[0] = dshape[0]; + moments_shape[1] = num_groups; + out_shape->push_back(moments_shape); + out_shape->push_back(moments_shape); + return true; +} + +NNVM_REGISTER_OP(GroupNorm) +.describe(R"code(Group normalization. + +The input channels are separated into ``num_groups`` groups, each containing ``num_channels / num_groups`` channels. +The mean and standard-deviation are calculated separately over the each group. + +.. math:: + + data = data.reshape((N, num_groups, C // num_groups, ...)) + out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis) + \epsilon}} * gamma + beta + +Both ``gamma`` and ``beta`` are learnable parameters. + +)code" ADD_FILELINE) +.set_num_inputs(3) +.set_num_outputs(3) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "gamma", "beta"}; +}) +.set_attr("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector{"output", "mean", "std"}; +}) +.set_attr("FNumVisibleOutputs", + [](const NodeAttrs& attrs) { + const GroupNormParam& param = nnvm::get(attrs.parsed); + return param.output_mean_var ? 
3 : 1; +}) +.set_attr("FInferShape", GroupNormShape) +.set_attr("FInferType", ElemwiseType<3, 3>) +.set_attr("FCompute", GroupNormCompute) +.set_attr("FGradient", [](const nnvm::NodePtr& n, + const std::vector& ograds) { + std::vector heads; + heads.push_back(ograds[0]); // ograd + heads.push_back(n->inputs[0]); // data + heads.push_back(n->inputs[1]); // gamma + heads.emplace_back(nnvm::NodeEntry{n, 1, 0}); // mean + heads.emplace_back(nnvm::NodeEntry{ n, 2, 0 }); // std + return MakeGradNode("_backward_GroupNorm", n, heads, n->attrs.dict); +}) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs) { + return std::vector >{{0, 0}}; +}) +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}) +.add_argument("data", "NDArray-or-Symbol", "Input data") +.add_argument("gamma", "NDArray-or-Symbol", "gamma array") +.add_argument("beta", "NDArray-or-Symbol", "beta array") +.add_arguments(GroupNormParam::__FIELDS__()); + + +NNVM_REGISTER_OP(_backward_GroupNorm) +.set_num_inputs(5) +.set_num_outputs(3) +.set_attr("TIsBackward", true) +.set_attr_parser(ParamParser) +.set_attr("FCompute", GroupNormGradCompute) +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/nn/group_norm.cu b/src/operator/nn/group_norm.cu new file mode 100644 index 000000000000..136c3337468c --- /dev/null +++ b/src/operator/nn/group_norm.cu @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file group_norm.cu + * \brief Implements Group Normalization (https://arxiv.org/abs/1803.08494). 
+*/ +#include "./group_norm-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(GroupNorm) +.set_attr("FCompute", GroupNormCompute); + +NNVM_REGISTER_OP(_backward_GroupNorm) +.set_attr("FCompute", GroupNormGradCompute); + +} // namespace op +} // namespace mxnet diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index d52e7f8bc832..b59ce2d0864c 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -743,6 +743,15 @@ def test_layernorm(): check_layer_forward(layer, (2, 10, 10, 10)) +@with_seed() +def test_groupnorm(): + layer = nn.GroupNorm() + check_layer_forward(layer, (2, 10, 10, 10)) + layer = nn.GroupNorm(num_groups=2) + check_layer_forward(layer, (2, 10, 10, 10)) + layer = nn.GroupNorm(num_groups=5) + check_layer_forward(layer, (2, 10, 10, 10)) + @with_seed() def test_reflectionpad(): layer = nn.ReflectionPad2D(3) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index aeddc7a893df..749f0f2bed23 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -1830,6 +1830,97 @@ def _test_batchnorm_impl(op, shape, axis, cudnn_off, output_mean_var): cudnn_off, output_mean_var) +@with_seed() +def test_groupnorm(): + acc_types = {'float16': 'float32', 'float32': 'float64', 'float64': 'float64'} + def x_hat_helper(x, num_groups, eps): + dtype = x.dtype + dshape = x.shape + assert len(dshape) == 4 + acc_type = acc_types[str(dtype)] + new_shape = (dshape[0], num_groups, int(dshape[1] / num_groups), dshape[2], dshape[3]) + new_moments_shape = (dshape[0], num_groups, 1, 1, 1) + data = x.reshape(new_shape) + mean = np.mean(data, axis=(2, 3, 4), keepdims=False, dtype=acc_type).astype(dtype) + std = np.sqrt(np.var(data, axis=(2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype) + eps) + x_hat = (data - mean.reshape(new_moments_shape)) / std.reshape(new_moments_shape) + return x_hat, mean, std + + def np_groupnorm(data, gamma, beta, num_groups, eps): + new_param_shape = (1, num_groups, 1, 1, 1) + x_hat, mean, std = x_hat_helper(data, num_groups, eps) + out = x_hat * gamma.reshape(new_param_shape) + beta.reshape(new_param_shape) + return out.reshape(dshape), mean, std + + def np_groupnorm_grad(ograd, data, gamma, beta, mean, std, num_groups, eps): + x_hat, mean, std = x_hat_helper(data, num_groups, eps) + new_shape = x_hat.shape + dshape = data.shape + dtype = data.dtype + new_moments_shape = (new_shape[0], num_groups, 1, 1, 1) + new_param_shape = (1, num_groups, 1, 1, 1) + acc_type = acc_types[str(dtype)] + ograd = ograd.reshape(new_shape) + data = data.reshape(new_shape) + gamma = gamma.reshape(new_param_shape) + beta = beta.reshape(new_param_shape) + mean = mean.reshape(new_moments_shape) + std = std.reshape(new_moments_shape) + beta_grad = np.sum(ograd, axis=(0, 2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype) + gamma_grad = np.sum(x_hat * ograd, axis=(0, 2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype) + x_hat_grad = ograd * gamma + ograd_mult = x_hat_grad / std + red_out = np.mean(ograd_mult, axis=(2, 3, 4), dtype=acc_type, keepdims=True).astype(dtype) + data_grad = ograd_mult - red_out + red_out = np.mean(ograd_mult * x_hat, axis=(2, 3, 4), dtype=acc_type, keepdims=True).astype(dtype) + data_grad = data_grad - x_hat * red_out + return data_grad.reshape(dshape), gamma_grad, beta_grad + + + batch_size = random.randint(1, 8) + num_groups = random.randint(2, 3) + num_channels = random.randint(2, 3) * num_groups + 
height = random.randint(1, 5) + width = random.randint(1, 5) + dshape = (batch_size, num_channels, height, width) + param_shape = (num_groups,) + temp_shape = (batch_size, num_groups, int(num_channels / num_groups), height, width) + np_data = np.random.uniform(0.2, 1.0, dshape) + np_gamma = np.random.uniform(-1.0, 1.0, param_shape) + np_beta = np.random.uniform(-1.0, 1.0, param_shape) + data_sym = mx.sym.Variable("data") + gamma_sym = mx.sym.Variable("gamma") + beta_sym = mx.sym.Variable("beta") + for dtype in [np.float16, np.float32, np.float64]: + eps = 1e-2 if dtype == np.float16 else 1e-5 + mx_data = mx.nd.array(np_data, dtype=dtype) + mx_gamma = mx.nd.array(np_gamma, dtype=dtype) + mx_beta = mx.nd.array(np_beta, dtype=dtype) + np_out, np_mean, np_std = np_groupnorm(np_data.astype(dtype), + np_gamma.astype(dtype), + np_beta.astype(dtype), + num_groups=num_groups, + eps=eps) + mx_sym = mx.sym.GroupNorm(data=data_sym, gamma=gamma_sym, beta=beta_sym, + num_groups=num_groups, eps=eps, output_mean_var=True) + check_symbolic_forward(mx_sym, [mx_data, mx_gamma, mx_beta], [np_out, np_mean, np_std], + rtol=1e-2 if dtype == np.float16 else 1e-3, + atol=5e-3 if dtype == np.float16 else 1e-5, dtype=dtype) + mx_sym = mx.sym.GroupNorm(data=data_sym, gamma=gamma_sym, beta=beta_sym, + num_groups=num_groups, eps=eps, output_mean_var=False) + np_ograd = np.random.uniform(-1.0, 1.0, dshape).astype(dtype) + np_data_grad, np_gamma_grad, np_beta_grad = np_groupnorm_grad(np_ograd, + np_data.astype(dtype), + np_gamma.astype(dtype), + np_beta.astype(dtype), + np_mean, np_std, + num_groups, eps) + check_symbolic_backward(mx_sym, [mx_data, mx_gamma, mx_beta], [mx.nd.array(np_ograd)], + [np_data_grad, np_gamma_grad, np_beta_grad], + rtol=1e-2 if dtype == np.float16 else 1e-3, + atol=5e-2 if dtype == np.float16 else 1e-5, dtype=dtype) + + @with_seed() def test_convolution_grouping(): for dim in [1, 2, 3]:
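Reviewer note, not part of the patch: a minimal forward-pass sanity check of the new Gluon block, assuming an MXNet build that already contains this change. It mirrors the reshape-normalize-affine formula from the GroupNorm docstring with plain NumPy; the shapes and the 1e-5 epsilon (the layer default) are illustrative.

    import mxnet as mx
    import numpy as np

    N, C, H, W, G = 2, 6, 4, 4, 3
    x = mx.nd.random.uniform(shape=(N, C, H, W))

    layer = mx.gluon.nn.GroupNorm(num_groups=G)  # gamma defaults to ones, beta to zeros
    layer.initialize()
    y = layer(x).asnumpy()

    # NumPy reference: reshape to (N, G, C // G, H, W) and normalize within each group,
    # using the layer's default epsilon of 1e-5.
    xr = x.asnumpy().reshape(N, G, C // G, H, W)
    mean = xr.mean(axis=(2, 3, 4), keepdims=True)
    var = xr.var(axis=(2, 3, 4), keepdims=True)
    ref = ((xr - mean) / np.sqrt(var + 1e-5)).reshape(N, C, H, W)

    np.testing.assert_allclose(y, ref, rtol=1e-4, atol=1e-5)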
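At the operator level, the registration exposes only `output` by default (`FNumVisibleOutputs` returns 1); with `output_mean_var=True` it also returns the per-group `mean` and `std`, which `GroupNormShape` gives shape (N, num_groups). A hedged NDArray-API sketch, assuming the imperative front end exposes the registered op as `mx.nd.GroupNorm` in the usual way:

    import mxnet as mx

    x = mx.nd.random.uniform(shape=(2, 6, 4, 4))
    gamma = mx.nd.ones((3,))   # one scale per group
    beta = mx.nd.zeros((3,))   # one shift per group

    out, mean, std = mx.nd.GroupNorm(x, gamma, beta, num_groups=3,
                                     eps=1e-5, output_mean_var=True)
    print(out.shape)   # (2, 6, 4, 4), same as the input
    print(mean.shape)  # (2, 3): one statistic per (sample, group)
    print(std.shape)   # (2, 3)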
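The backward comment in group_norm-inl.h states grad_beta = sum(og, exclude_axis). A quick autograd spot check of that identity, again a sketch that assumes this patch is built; with no explicit head gradient, backward() on a non-scalar output uses ones.

    import mxnet as mx

    N, C, H, W, G = 2, 4, 3, 3, 2
    x = mx.nd.random.uniform(shape=(N, C, H, W))
    layer = mx.gluon.nn.GroupNorm(num_groups=G)
    layer.initialize()

    with mx.autograd.record():
        y = layer(x)
    y.backward()  # head gradient defaults to ones

    # With an all-ones ograd, grad_beta = sum over every axis except the group axis,
    # i.e. N * (C // G) * H * W = 2 * 2 * 3 * 3 = 36 for each of the two groups.
    print(layer.beta.grad())  # expect values close to [36. 36.]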