[Bug Fix] Fix GroupNorm Implementation (apache#18199)
* init

* add in_channels
hgt312 authored and AntiZpvoh committed Jul 6, 2020
1 parent 3d7907e commit 2a48960
Showing 4 changed files with 30 additions and 26 deletions.
11 changes: 7 additions & 4 deletions python/mxnet/gluon/nn/basic_layers.py
@@ -820,26 +820,29 @@ class GroupNorm(HybridBlock):
     """
     def __init__(self, num_groups=1, epsilon=1e-5, center=True, scale=True,
                  beta_initializer='zeros', gamma_initializer='ones',
-                 prefix=None, params=None):
+                 in_channels=0, prefix=None, params=None):
         super(GroupNorm, self).__init__(prefix=prefix, params=params)
         self._kwargs = {'eps': epsilon, 'num_groups': num_groups, 'center': center, 'scale': scale}
         self._num_groups = num_groups
         self._epsilon = epsilon
         self._center = center
         self._scale = scale
         self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
-                                     shape=(num_groups,), init=gamma_initializer,
+                                     shape=(in_channels,), init=gamma_initializer,
                                      allow_deferred_init=True)
         self.beta = self.params.get('beta', grad_req='write' if center else 'null',
-                                    shape=(num_groups,), init=beta_initializer,
+                                    shape=(in_channels,), init=beta_initializer,
                                     allow_deferred_init=True)

     def hybrid_forward(self, F, data, gamma, beta):
         norm_data = F.GroupNorm(data, gamma=gamma, beta=beta, num_groups=self._num_groups, eps=self._epsilon)
         return norm_data

     def __repr__(self):
-        s = '{name}({content})'
+        s = '{name}({content}'
+        in_channels = self.gamma.shape[0]
+        s += ', in_channels={0}'.format(in_channels)
+        s += ')'
         return s.format(name=self.__class__.__name__,
                         content=', '.join(['='.join([k, v.__repr__()])
                                            for k, v in self._kwargs.items()]))
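Note (illustration, not part of the commit): with this change GroupNorm carries one gamma and one beta per channel instead of per group, and the constructor takes an optional in_channels so the parameter shapes can be fixed eagerly. A minimal usage sketch, assuming an MXNet 1.x Gluon build that includes this patch:

import mxnet as mx
from mxnet.gluon import nn

# 32 channels split into 4 groups; gamma/beta are per channel after this fix.
net = nn.GroupNorm(num_groups=4, in_channels=32)
net.initialize()

x = mx.nd.random.uniform(shape=(2, 32, 8, 8))  # (N, C, H, W)
y = net(x)

print(net.gamma.shape)  # (32,) -- one scale per channel, no longer (num_groups,)
print(net.beta.shape)   # (32,)
print(y.shape)          # (2, 32, 8, 8)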
25 changes: 13 additions & 12 deletions src/operator/nn/group_norm-inl.h
@@ -136,16 +136,16 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs,
   TBlob data_grp = data.reshape(temp_data_shape);
   const TBlob& mean_grp = mean.reshape(moments_shape);
   const TBlob& std_grp = std.reshape(moments_shape);
-  const TBlob& output = outputs[groupnorm::kOut].reshape(temp_data_shape);
+  const TBlob& output_grp = outputs[groupnorm::kOut].reshape(temp_data_shape);

   // Calculate data = data - mean
   BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
                                                      {data_grp, mean_grp},
-                                                     {kWriteTo}, {output});
+                                                     {kWriteTo}, {output_grp});

   // Calculate std
   const TBlob centered_out = outputs[groupnorm::kOut].reshape(red_src_shape);
-  MSHADOW_REAL_TYPE_SWITCH(output.type_flag_, DType, {
+  MSHADOW_REAL_TYPE_SWITCH(output_grp.type_flag_, DType, {
     BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
       broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::square, true>(
         s, std_, req[0], workspace, centered_out);
@@ -157,11 +157,12 @@ void GroupNormCompute(const nnvm::NodeAttrs& attrs,

   // Calculate data = data / std
   BinaryBroadcastCompute<xpu, mshadow_op::div>(attrs, ctx,
-                                               {output, std_grp},
-                                               {kWriteTo}, {output});
+                                               {output_grp, std_grp},
+                                               {kWriteTo}, {output_grp});

-  mxnet::TShape new_param_shape(data_shape.ndim() + 1, 1);
-  new_param_shape[1] = num_groups;
+  const TBlob& output = outputs[groupnorm::kOut];
+  mxnet::TShape new_param_shape(data_shape.ndim(), 1);
+  new_param_shape[1] = data_shape[1];

   const TBlob& gamma = inputs[groupnorm::kGamma].reshape(new_param_shape);
   const TBlob& beta = inputs[groupnorm::kBeta].reshape(new_param_shape);
@@ -215,8 +216,8 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,

   Stream<xpu> *s = ctx.get_stream<xpu>();
   // Reshape gamma to be broadcastable
-  mxnet::TShape new_param_shape(dshape.ndim() + 1, 1);
-  new_param_shape[1] = num_groups;
+  mxnet::TShape new_param_shape(dshape.ndim(), 1);
+  new_param_shape[1] = dshape[1];

   const TBlob& gamma = inputs[2].reshape(new_param_shape);

@@ -233,7 +234,7 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
   // Prepare the necessary shapes for reduction
   mxnet::TShape red_src_shape, red_dst_shape, red_exclude_src_shape, red_exclude_dst_shape;
   BroadcastReduceShapeCompact(temp_dshape, mean_.shape_, &red_src_shape, &red_dst_shape);
-  BroadcastReduceShapeCompact(temp_dshape, gamma.shape_,
+  BroadcastReduceShapeCompact(dshape, gamma.shape_,
                               &red_exclude_src_shape, &red_exclude_dst_shape);

   int N = red_src_shape.Size() / red_dst_shape.Size();
@@ -308,8 +309,8 @@ void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
   if (req[0] != kNullOp) {
     const TBlob output_ = outputs[0].reshape(data_.shape_);
     BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
-                                                     {ograd, gamma},
-                                                     {kWriteTo}, {ograd_mult});
+                                                     {inputs[0], gamma},
+                                                     {kWriteTo}, {ograd_mult.reshape(data.shape_)});
     BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
                                                      {ograd_mult, std_},
                                                      {kWriteTo}, {ograd_mult});
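Note (illustration, not part of the commit): in the backward pass gamma is broadcast per channel the same way, so the parameter gradients reduce over the batch and spatial axes and come out with shape (C,). A short NumPy sketch of those reductions; group_norm_param_grads is a made-up helper mirroring the updated unit test below:

import numpy as np

def group_norm_param_grads(ograd, x_hat):
    # ograd: upstream gradient, x_hat: normalized input, both (N, C, H, W).
    # Reducing over N, H, W leaves one gradient entry per channel.
    beta_grad = ograd.sum(axis=(0, 2, 3))
    gamma_grad = (x_hat * ograd).sum(axis=(0, 2, 3))
    return gamma_grad, beta_grad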
4 changes: 2 additions & 2 deletions src/operator/nn/group_norm.cc
@@ -47,8 +47,8 @@ static bool GroupNormShape(const nnvm::NodeAttrs& attrs,
     return false;
   }

-  in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(num_groups));
-  in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(num_groups));
+  in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(dshape[1]));
+  in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(dshape[1]));

   out_shape->clear();
   out_shape->push_back(dshape);
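Note (illustration, not part of the commit): shape inference now ties gamma and beta to the channel dimension of the data rather than to num_groups. A small symbolic check, assuming an MXNet 1.x build with this patch:

import mxnet as mx

data = mx.sym.Variable('data')
gamma = mx.sym.Variable('gamma')
beta = mx.sym.Variable('beta')
out = mx.sym.GroupNorm(data, gamma, beta, num_groups=4)

arg_shapes, _, _ = out.infer_shape(data=(2, 32, 8, 8))
print(dict(zip(out.list_arguments(), arg_shapes)))
# expected: {'data': (2, 32, 8, 8), 'gamma': (32,), 'beta': (32,)}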
16 changes: 8 additions & 8 deletions tests/python/unittest/test_operator.py
@@ -1960,28 +1960,28 @@ def x_hat_helper(x, num_groups, eps):
         return x_hat, mean, std

     def np_groupnorm(data, gamma, beta, num_groups, eps):
-        new_param_shape = (1, num_groups, 1, 1, 1)
+        new_param_shape = (1, dshape[1], 1, 1)
         x_hat, mean, std = x_hat_helper(data, num_groups, eps)
-        out = x_hat * gamma.reshape(new_param_shape) + beta.reshape(new_param_shape)
-        return out.reshape(dshape), mean, std
+        out = x_hat.reshape(dshape) * gamma.reshape(new_param_shape) + beta.reshape(new_param_shape)
+        return out, mean, std

     def np_groupnorm_grad(ograd, data, gamma, beta, mean, std, num_groups, eps):
         x_hat, mean, std = x_hat_helper(data, num_groups, eps)
         new_shape = x_hat.shape
         dshape = data.shape
         dtype = data.dtype
         new_moments_shape = (new_shape[0], num_groups, 1, 1, 1)
-        new_param_shape = (1, num_groups, 1, 1, 1)
+        new_param_shape = (1, dshape[1], 1, 1)
         acc_type = acc_types[str(dtype)]
         ograd = ograd.reshape(new_shape)
         data = data.reshape(new_shape)
         gamma = gamma.reshape(new_param_shape)
         beta = beta.reshape(new_param_shape)
         mean = mean.reshape(new_moments_shape)
         std = std.reshape(new_moments_shape)
-        beta_grad = np.sum(ograd, axis=(0, 2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype)
-        gamma_grad = np.sum(x_hat * ograd, axis=(0, 2, 3, 4), dtype=acc_type, keepdims=False).astype(dtype)
-        x_hat_grad = ograd * gamma
+        beta_grad = np.sum(ograd, axis=(0, 3, 4), dtype=acc_type, keepdims=False).astype(dtype).flatten()
+        gamma_grad = np.sum(x_hat * ograd, axis=(0, 3, 4), dtype=acc_type, keepdims=False).astype(dtype).flatten()
+        x_hat_grad = ograd * gamma.reshape(1, num_groups, dshape[1] // num_groups, 1, 1)
         ograd_mult = x_hat_grad / std
         red_out = np.mean(ograd_mult, axis=(2, 3, 4), dtype=acc_type, keepdims=True).astype(dtype)
         data_grad = ograd_mult - red_out
@@ -1996,7 +1996,7 @@ def np_groupnorm_grad(ograd, data, gamma, beta, mean, std, num_groups, eps):
     height = random.randint(1, 5)
     width = random.randint(1, 5)
     dshape = (batch_size, num_channels, height, width)
-    param_shape = (num_groups,)
+    param_shape = (num_channels,)
    temp_shape = (batch_size, num_groups, int(num_channels / num_groups), height, width)
     np_data = np.random.uniform(0.2, 1.0, dshape)
     np_gamma = np.random.uniform(-1.0, 1.0, param_shape)
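Note (illustration, not part of the commit): the updated test exercises exactly this contract -- the operator must accept per-channel gamma/beta of shape (num_channels,). A hedged imperative spot check along the same lines, assuming an MXNet 1.x build with this patch:

import mxnet as mx

num_groups, num_channels = 4, 8
data = mx.nd.random.uniform(0.2, 1.0, shape=(2, num_channels, 5, 5))
gamma = mx.nd.random.uniform(-1.0, 1.0, shape=(num_channels,))  # (C,), not (num_groups,)
beta = mx.nd.random.uniform(-1.0, 1.0, shape=(num_channels,))

out, mean, std = mx.nd.GroupNorm(data, gamma, beta, num_groups=num_groups,
                                 eps=1e-5, output_mean_var=True)
print(out.shape)  # (2, 8, 5, 5)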
