12 changes: 9 additions & 3 deletions python/paddle/nn/functional/norm.py
@@ -681,28 +681,34 @@ def local_response_norm(
     return res


+@param_two_alias(["x", "input"], ["epsilon", "eps"])
 def group_norm(
     x: Tensor,
     num_groups: int,
-    epsilon: float = 1e-05,
     weight: Tensor | None = None,
     bias: Tensor | None = None,
+    epsilon: float = 1e-05,
     data_format: DataLayout1D | DataLayout2D | DataLayout3D = 'NCHW',
     name: str | None = None,
 ) -> Tensor:
     """
     nn.GroupNorm is recommended.
     For more information, please refer to :ref:`api_paddle_nn_GroupNorm` .

+    .. note::
+        Alias Support: The parameter name ``input`` can be used as an alias for ``x``.
+        For example, ``group_norm(input=tensor_x, ...)`` is equivalent to ``group_norm(x=tensor_x, ...)``.
+
     Parameters:
         x(Tensor): Input Tensor with shape :attr:`(batch, num_features, *)`.
+            alias: ``input``.
         num_groups(int): The number of groups that the channels are divided into.
-        epsilon(float, optional): The small value added to the variance to prevent
-            division by zero. Default: 1e-05.
         weight(Tensor, optional): The weight Tensor of group_norm, with shape :attr:`[num_channels]`.
             Default: None.
         bias(Tensor, optional): The bias Tensor of group_norm, with shape :attr:`[num_channels]`.
             Default: None.
+        epsilon(float, optional): The small value added to the variance to prevent
+            division by zero. Default: 1e-05.
         data_format(str, optional): Specifies the input data format. Supports "NCL", "NCHW", "NCDHW", "NLC", "NHWC" or "NDHWC". Default: "NCHW".
         name(str|None, optional): Name for the GroupNorm, default is None. For more information, please refer to :ref:`api_guide_Name`.

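For reviewers, a minimal usage sketch of the resulting API (shapes and values are illustrative; assumes a build with this patch applied):

```python
import paddle
import paddle.nn.functional as F

x = paddle.randn([2, 6, 4, 4])  # NCHW: batch=2, channels=6

# Canonical spelling.
out = F.group_norm(x, num_groups=3, epsilon=1e-5)

# Aliases accepted after this change: input -> x, eps -> epsilon.
out_alias = F.group_norm(input=x, num_groups=3, eps=1e-5)

# New positional order: x, num_groups, weight, bias, epsilon.
weight = paddle.ones([6])
bias = paddle.zeros([6])
out_pos = F.group_norm(x, 3, weight, bias, 1e-5)
```

The reordered signature lines up with ``torch.nn.functional.group_norm(input, num_groups, weight=None, bias=None, eps=1e-05)``, which appears to be the motivation for moving ``epsilon`` after ``bias``.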
2 changes: 1 addition & 1 deletion python/paddle/nn/layer/norm.py
@@ -562,9 +562,9 @@ def forward(self, input: Tensor) -> Tensor:
         return group_norm(
             input,
             self._num_groups,
-            self._epsilon,
             self.weight,
             self.bias,
+            self._epsilon,
             self._data_format,
         )

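Because ``GroupNorm.forward`` passes its arguments positionally, this one-line reorder is what keeps the layer in sync with the new functional signature. A small equivalence check (a sketch, not part of the PR; ``_epsilon`` is the private attribute already used in this diff):

```python
import numpy as np
import paddle

layer = paddle.nn.GroupNorm(num_groups=3, num_channels=6)
x = paddle.randn([2, 6, 4, 4])

# The layer path should match a direct functional call using the new order.
out_layer = layer(x)
out_func = paddle.nn.functional.group_norm(
    x, 3, layer.weight, layer.bias, layer._epsilon
)
np.testing.assert_allclose(out_layer.numpy(), out_func.numpy(), rtol=1e-6)
```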
51 changes: 51 additions & 0 deletions test/legacy_test/test_group_norm_op_v2.py
@@ -618,5 +618,56 @@ def test_group_norm_cpu_with_optional_grad_nhwc(self):
         np.testing.assert_equal(dx.numpy(), dx_ref.numpy())


+class TestGroupNormParam(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(42)
+        self.x_np = np.random.randn(2, 6, 4, 4).astype('float32')
+        self.weight_np = np.random.randn(6).astype('float32')
+        self.bias_np = np.random.randn(6).astype('float32')
+
+    def test_alias_input_for_x(self):
+        """test parameter alias input/x"""
+        x_tensor = paddle.to_tensor(self.x_np)
+        weight_tensor = paddle.to_tensor(self.weight_np)
+        bias_tensor = paddle.to_tensor(self.bias_np)
+
+        out_with_input = paddle.nn.functional.group_norm(
+            input=x_tensor,
+            num_groups=3,
+            weight=weight_tensor,
+            bias=bias_tensor,
+            eps=1e-5,
+        )
+        out_with_x = paddle.nn.functional.group_norm(
+            x=x_tensor,
+            num_groups=3,
+            weight=weight_tensor,
+            bias=bias_tensor,
+            eps=1e-5,
+        )
+
+        np.testing.assert_array_equal(
+            out_with_input.numpy(), out_with_x.numpy()
+        )
+
+    def test_param_order(self):
+        """test order of parameters"""
+        x_tensor = paddle.to_tensor(self.x_np)
+        weight_tensor = paddle.to_tensor(self.weight_np)
+        bias_tensor = paddle.to_tensor(self.bias_np)
+
+        try:
+            out = paddle.nn.functional.group_norm(
+                x_tensor,  # x
+                3,  # num_groups
+                weight_tensor,  # weight
+                bias_tensor,  # bias
+                1e-5,  # epsilon
+            )
+            self.assertTrue(True, "Function call succeeded without error")
+        except Exception as e:
+            self.fail(f"Function raised an unexpected exception: {e}")
+
+
 if __name__ == '__main__':
     unittest.main()
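To run just the new cases from a source checkout (assuming a pytest-equipped environment; the path is as shown in this PR):

```
python -m pytest test/legacy_test/test_group_norm_op_v2.py -k TestGroupNormParam -v
```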