73 changes: 71 additions & 2 deletions python/paddle/nn/functional/norm.py
@@ -14,8 +14,11 @@

from __future__ import annotations

import inspect
import numbers
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from typing_extensions import overload

import paddle
from paddle import _C_ops, in_dynamic_mode
@@ -681,6 +684,7 @@ def local_response_norm(
    return res


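# Overload 1: the original paddle signature.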
@overload
def group_norm(
    x: Tensor,
    num_groups: int,
@@ -689,16 +693,40 @@ def group_norm(
    bias: Tensor | None = None,
    data_format: DataLayout1D | DataLayout2D | DataLayout3D = 'NCHW',
    name: str | None = None,
) -> Tensor:
) -> Tensor: ...


@overload
def group_norm(
    input: Tensor,
    num_groups: int,
    weight: Tensor | None = None,
    bias: Tensor | None = None,
    eps: float = 1e-05,
) -> Tensor: ...
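# Overload 2: the torch-compatible signature (``input``/``eps`` aliases).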


def group_norm(*args: Any, **kwargs: Any) -> Tensor:
    """
    nn.GroupNorm is recommended.
    For more information, please refer to :ref:`api_paddle_nn_GroupNorm`.

    This function has two functionalities, depending on the parameters passed:

    1. ``group_norm(Tensor input, int num_groups, Tensor weight = None, Tensor bias = None, float eps = 1e-05)``:
       PyTorch-compatible group_norm.

    2. ``group_norm(Tensor x, int num_groups, float epsilon = 1e-05, Tensor weight = None, Tensor bias = None,
       DataLayout1D | DataLayout2D | DataLayout3D data_format = 'NCHW', str | None name = None)``:
       the original paddle.nn.functional.group_norm; see the docs below.

    Parameters:
        x(Tensor): Input Tensor with shape :attr:`(batch, num_features, *)`.
            alias: ``input``.
        num_groups(int): The number of groups into which the channels are divided.
        epsilon(float, optional): The small value added to the variance to prevent
            division by zero. Default: 1e-05.
            alias: ``eps``.
        weight(Tensor, optional): The weight Tensor of group_norm, with shape :attr:`[num_channels]`.
            Default: None.
        bias(Tensor, optional): The bias Tensor of group_norm, with shape :attr:`[num_channels]`.
@@ -744,6 +772,44 @@ def group_norm(
            [[-1.34163547, -0.44721183],
             [ 0.44721183, 1.34163547]]]])
    """

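    # Normalize both call styles onto the paddle-style parameters accepted
    # by ``_group_norm_wrapper``.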
    len_args = len(args)
    if len_args + len(kwargs) < 2:
        raise TypeError(
            f"Too few arguments in the function call: {len_args}, {len(kwargs)}. Expect one of: \n"
            " - (Tensor input, int num_groups, Tensor weight = None, Tensor bias = None, float eps = 1e-05)\n"
            " - (Tensor x, int num_groups, float epsilon = 1e-05, Tensor weight = None, Tensor bias = None, "
            "DataLayout1D | DataLayout2D | DataLayout3D data_format = 'NCHW', str | None name = None)"
        )

    def safe_set_param(key: str, value: Any):
        if key in kwargs:
            raise TypeError(f"got multiple values for argument '{key}'")
        kwargs[key] = value

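    # Map the torch-style keyword aliases onto the paddle parameter names;
    # ``safe_set_param`` rejects duplicates such as passing both ``input`` and ``x``.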
    if 'input' in kwargs:
        safe_set_param('x', kwargs.pop('input'))

    if 'eps' in kwargs:
        safe_set_param('epsilon', kwargs.pop('eps'))

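    # Disambiguate positional arguments: in the torch-style signature the
    # third positional argument is ``weight`` (a Tensor), while in the
    # paddle signature it is ``epsilon`` (a float).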
    if len_args >= 3 and not isinstance(args[2], float):
        param_keys = ["weight", "bias", "epsilon"]
        for idx in range(min(len_args - 2, len(param_keys))):
            safe_set_param(param_keys[idx], args[idx + 2])
        args = args[:2]
    return _group_norm_wrapper(*args, **kwargs)


def _group_norm_wrapper(
    x: Tensor,
    num_groups: int,
    epsilon: float = 1e-05,
    weight: Tensor | None = None,
    bias: Tensor | None = None,
    data_format: DataLayout1D | DataLayout2D | DataLayout3D = 'NCHW',
    name: str | None = None,
) -> Tensor:
    if data_format not in ['NCL', 'NCHW', 'NCDHW', 'NLC', 'NHWC', 'NDHWC']:
        raise ValueError("unsupported data layout:" + data_format)

@@ -794,3 +860,6 @@ def group_norm(
    )

    return helper.append_activation(group_norm_out)


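# Present the paddle-style signature of ``_group_norm_wrapper`` to
# ``inspect.signature``/``help`` instead of the opaque (*args, **kwargs).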
group_norm.__signature__ = inspect.signature(_group_norm_wrapper)
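For reference, a minimal sketch of how the two accepted call styles line up after this change (assuming a build that includes this patch; the tensor shapes mirror those used in the tests below):

import paddle
import paddle.nn.functional as F

x = paddle.randn([2, 6, 4, 4], dtype='float32')
w = paddle.randn([6], dtype='float32')
b = paddle.randn([6], dtype='float32')

# Original paddle signature: epsilon is the third positional parameter.
out_paddle = F.group_norm(x, 3, 1e-5, weight=w, bias=b)

# Torch-style signature: ``input`` and ``eps`` are aliases for ``x`` and ``epsilon``.
out_torch = F.group_norm(input=x, num_groups=3, weight=w, bias=b, eps=1e-5)

# Both paths dispatch to _group_norm_wrapper with identical arguments.
assert paddle.allclose(out_paddle, out_torch)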
125 changes: 125 additions & 0 deletions test/legacy_test/test_group_norm_op_v2.py
@@ -618,5 +618,130 @@ def test_group_norm_cpu_with_optional_grad_nhwc(self):
        np.testing.assert_equal(dx.numpy(), dx_ref.numpy())


class TestGroupNormParam(unittest.TestCase):
    def setUp(self):
        self.x_tensor = paddle.randn([2, 6, 4, 4], dtype='float32')
        self.weight_tensor = paddle.randn([6], dtype='float32')
        self.bias_tensor = paddle.randn([6], dtype='float32')

    def test_alias_input_for_x(self):
        """test parameter alias input/x"""
        out_with_input = paddle.nn.functional.group_norm(
            input=self.x_tensor,
            num_groups=3,
            weight=self.weight_tensor,
            bias=self.bias_tensor,
            eps=1e-5,
        )
        out_with_x = paddle.nn.functional.group_norm(
            x=self.x_tensor,
            num_groups=3,
            weight=self.weight_tensor,
            bias=self.bias_tensor,
            eps=1e-5,
        )

        np.testing.assert_array_equal(
            out_with_input.numpy(), out_with_x.numpy()
        )

    def test_params_consistency(self):
        """test that both the paddle and torch formats work"""
        out_old = paddle.nn.functional.group_norm(
            self.x_tensor,
            3,
            1e-5,
            weight=self.weight_tensor,
            bias=self.bias_tensor,
        )

        out_new = paddle.nn.functional.group_norm(
            x=self.x_tensor,
            num_groups=3,
            weight=self.weight_tensor,
            bias=self.bias_tensor,
            eps=1e-5,
        )

        np.testing.assert_array_equal(out_old.numpy(), out_new.numpy())

    def test_params_1(self):
        """test all positional args with torch format"""
        try:
            out = paddle.nn.functional.group_norm(
                self.x_tensor,
                3,
                self.weight_tensor,
                self.bias_tensor,
                1e-5,
            )
            self.assertTrue(True, "Function call succeeded without error")
        except Exception as e:
            self.fail(f"Function raised an unexpected exception: {e}")

    def test_params_2(self):
        """test all kwargs, mixing the ``input`` alias with paddle's ``epsilon``"""
        try:
            out = paddle.nn.functional.group_norm(
                input=self.x_tensor,
                num_groups=3,
                weight=self.weight_tensor,
                bias=self.bias_tensor,
                epsilon=1e-5,
            )
            self.assertTrue(True, "Function call succeeded without error")
        except Exception as e:
            self.fail(f"Function raised an unexpected exception: {e}")

    def test_params_3(self):
        """test passing a mix of positional and keyword arguments"""
        try:
            out1 = paddle.nn.functional.group_norm(
                self.x_tensor,
                3,
                weight=self.weight_tensor,
                bias=self.bias_tensor,
                epsilon=1e-5,
            )
            out2 = paddle.nn.functional.group_norm(
                self.x_tensor,
                3,
                1e-5,
                weight=self.weight_tensor,
                bias=self.bias_tensor,
            )
            self.assertTrue(True, "Function call succeeded without error")
        except Exception as e:
            self.fail(f"Function raised an unexpected exception: {e}")

    def test_params_4(self):
        """test default parameters"""
        try:
            out1 = paddle.nn.functional.group_norm(
                self.x_tensor,
                3,
                self.weight_tensor,
            )
            out2 = paddle.nn.functional.group_norm(self.x_tensor, 3, 1e-5)
            self.assertTrue(True, "Function call succeeded without error")
        except Exception as e:
            self.fail(f"Function raised an unexpected exception: {e}")

    def test_params_5(self):
        """test duplicate parameters"""
        with self.assertRaises(TypeError):
            out_1 = paddle.nn.functional.group_norm(
                x=self.x_tensor,
                input=self.x_tensor,
                num_groups=3,
            )
        with self.assertRaises(TypeError):
            out_2 = paddle.nn.functional.group_norm(
                self.x_tensor,
                input=self.x_tensor,
                num_groups=3,
            )


if __name__ == '__main__':
    unittest.main()