56 changes: 46 additions & 10 deletions python/paddle/tensor/stat.py
@@ -15,7 +15,7 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING, Any, Literal

from typing_extensions import TypeAlias, overload

@@ -29,6 +29,7 @@
ParamAliasDecorator,
param_two_alias,
param_two_alias_one_default,
use_first_signature,
)

from ..base.data_feeder import check_type, check_variable_and_dtype
@@ -285,16 +286,40 @@ def _replace_nan(out):
return result


@overload
def std(
x: Tensor,
axis: int | Sequence[int] | None = None,
unbiased: bool = True,
unbiased: bool | None = None,
keepdim: bool = False,
name: str | None = None,
) -> Tensor:
*,
correction: float = 1,
out: Tensor | None = None,
) -> Tensor: ...
Member: Only one signature? What is the point of using overload here?



@overload
def std(
input: Tensor,
dim: int | Sequence[int] | None = None,
*,
correction: float = 1,
keepdim: bool = False,
out: Tensor | None = None,
) -> Tensor: ...


@use_first_signature
def std(*args: Any, **kwargs: Any) -> Tensor:
Member: [image] This way, the signature will be lost.

"""
Computes the standard deviation of ``x`` along ``axis``.

.. note::
Alias Support:
1. The parameter name ``input`` can be used as an alias for ``x``.
2. The parameter name ``dim`` can be used as an alias for ``axis``.

Args:
x (Tensor): The input Tensor with data type float16, float32, float64.
axis (int|list|tuple|None, optional): The axis along which to perform
@@ -319,34 +344,45 @@ def std(
the output Tensor is squeezed in ``axis`` . Default is False.
name (str|None, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
correction (int|float, optional): Difference between the sample size and sample degrees of freedom.
Defaults to 1 (Bessel's correction). If unbiased is specified, this parameter is ignored.
out (Tensor|None, optional): Output tensor. Default is None.

Returns:
Tensor, results of the standard deviation along ``axis`` of ``x``, with the
same data type as ``x``.

Examples:
.. code-block:: python
.. code-block:: pycon

>>> import paddle

>>> x = paddle.to_tensor([[1.0, 2.0, 3.0], [1.0, 4.0, 5.0]])
>>> out1 = paddle.std(x)
>>> print(out1.numpy())
1.6329932

>>> out2 = paddle.std(x, unbiased=False)
>>> print(out2.numpy())
1.490712

>>> out3 = paddle.std(x, axis=1)
>>> print(out3.numpy())
[1. 2.081666]

>>> out4 = paddle.std(x=x, keepdim=True, correction=1.5)
>>> print(out4.numpy())
[[1.721326]]

>>> out5 = paddle.std(input=x, dim=[0, 1])
>>> print(out5.numpy())
1.6329932

"""
if not in_dynamic_or_pir_mode():
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'std'
)
out = var(**locals())
return paddle.sqrt(out)
variance = var(*args, **kwargs)
Contributor: It looks like the std and var APIs are not simple alias replacements.

They need a parameter-dispatch layer: following the group_norm and gather implementations, args/kwargs must be inspected here to decide which of the two overload signatures was used.

var is the underlying interface, so get var right first; std then only needs to be written as:

return sqrt(var(*args, **kwargs), out=out)

The default out=None fallback already exists at the bottom of every API (and will still exist after the logic is pushed down), so there is no need to duplicate that check here.

Contributor (author): That if branch probably cannot be removed: out has to be read from kwargs, and indexing a missing key would raise a KeyError. The current implementation also does not define a new wrapper the way gather does.

if 'out' in kwargs:
return paddle.sqrt(variance, out=kwargs['out'])
Comment on lines +382 to +384
Copilot AI, Dec 26, 2025:

The handling of the out parameter needs careful consideration. The current implementation extracts out from kwargs and passes it to paddle.sqrt. However, the var function is also called with all kwargs including out, which means out would first receive the variance result from var, and then be overwritten with the sqrt result. This could lead to incorrect behavior if the var function tries to assign to out. The implementation should either remove out from kwargs before calling var, or handle this more explicitly in the proper signature-based implementation.

Suggested change:
-    variance = var(*args, **kwargs)
-    if 'out' in kwargs:
-        return paddle.sqrt(variance, out=kwargs['out'])
+    out = kwargs.pop('out', None)
+    variance = var(*args, **kwargs)
+    if out is not None:
+        return paddle.sqrt(variance, out=out)

return paddle.sqrt(variance)


def numel(x: Tensor, name: str | None = None) -> Tensor:
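As an editorial aside, the out-forwarding pitfall flagged in the threads above can be captured in a few lines. The sketch below is illustrative only: it assumes paddle.var accepts the same compatibility kwargs as std and that paddle.sqrt takes out= (both of which this PR relies on), and std_sketch itself is a hypothetical stand-in, not the PR's implementation.

import paddle

def std_sketch(*args, **kwargs):
    # Detach `out` first so var() never writes into the caller's buffer
    # (the double-write Copilot flags); a missing key yields None, not a KeyError.
    out = kwargs.pop('out', None)
    variance = paddle.var(*args, **kwargs)  # forwards either overload unchanged
    if out is not None:
        return paddle.sqrt(variance, out=out)
    return paddle.sqrt(variance)

x = paddle.to_tensor([[1.0, 2.0, 3.0], [1.0, 4.0, 5.0]])
buf = paddle.empty([])
std_sketch(x, out=buf)
print(buf.numpy())  # 1.6329932, matching paddle.std(x)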
14 changes: 13 additions & 1 deletion python/paddle/utils/decorator_utils.py
@@ -19,7 +19,7 @@
import warnings
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast

from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, get_overloads

import paddle

@@ -943,3 +943,15 @@ def wrapper(*args, **kwargs) -> _RetT:
return wrapper

return decorator


def use_first_signature(
    func: Callable[_InputT, _RetT],
) -> Callable[_InputT, _RetT]:
    """Copy the first ``@overload`` signature onto a ``*args/**kwargs``
    implementation so ``inspect.signature`` and IDE tooltips keep showing
    the real parameters."""
    overloads = get_overloads(func)
    if not overloads:
        return func
    first_overload = overloads[0]
    sig = inspect.signature(first_overload)
    func.__signature__ = sig
    return func
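To see what use_first_signature buys (and why the reviewer's "signature will be lost" screenshot matters), here is a self-contained sketch with a hypothetical scale function; note that get_overloads only finds overloads registered through typing_extensions.overload (or typing.overload on Python 3.11+).

import inspect
from typing_extensions import get_overloads, overload

def use_first_signature(func):
    # Same logic as the decorator above, minus the typing generics.
    overloads = get_overloads(func)
    if not overloads:
        return func
    func.__signature__ = inspect.signature(overloads[0])
    return func

@overload
def scale(x: float, factor: float = 1.0) -> float: ...
@overload
def scale(data: "list[float]", factor: float = 1.0) -> "list[float]": ...

@use_first_signature
def scale(*args, **kwargs):
    raise NotImplementedError  # real dispatch would live here

# Without the decorator this prints "(*args, **kwargs)";
# with it, help() and IDEs see the first overload's parameters.
print(inspect.signature(scale))  # (x: float, factor: float = 1.0) -> float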
114 changes: 108 additions & 6 deletions test/legacy_test/test_std_layer.py
@@ -109,6 +109,108 @@ def test_alias(self):
paddle.enable_static()


class TestStdAPI_Compatibility(unittest.TestCase):
def setUp(self):
np.random.seed(2026)
self.dtype = 'float32'
self.shape = [1, 3, 4, 10]
self.x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
self.place = get_device_place()

def test_dygraph_compatibility(self):
paddle.disable_static()
x = paddle.tensor(self.x)
# input arg
out1_1 = paddle.std(x=x)
out1_2 = paddle.std(input=x)
np.testing.assert_allclose(out1_1.numpy(), out1_2.numpy(), rtol=1e-05)
# dim arg
out2_1 = paddle.std(x, axis=3)
out2_2 = paddle.std(x, dim=3)
np.testing.assert_allclose(out2_1.numpy(), out2_2.numpy(), rtol=1e-05)
# out arg
out3_1 = paddle.empty([])
out3_2 = paddle.std(x, out=out3_1)
np.testing.assert_allclose(out3_1.numpy(), out3_2.numpy(), rtol=1e-05)
paddle.enable_static()

def test_static_compatibility(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('x', self.shape, self.dtype)
# input arg
out1_1 = paddle.std(x=x)
out1_2 = paddle.std(input=x)
# dim arg
out2_1 = paddle.std(x, axis=3)
out2_2 = paddle.std(x, dim=3)
exe = paddle.static.Executor(self.place)
res = exe.run(
feed={'x': self.x}, fetch_list=[out1_1, out1_2, out2_1, out2_2]
)
np.testing.assert_allclose(res[0], res[1], rtol=1e-05)
np.testing.assert_allclose(res[2], res[3], rtol=1e-05)


class TestStdAPI_Correction(unittest.TestCase):
def setUp(self):
np.random.seed(2026)
self.dtype = 'float32'
self.shape = [1, 3, 4, 10]
self.set_attrs()
self.x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
if self.axis:
axis = tuple(self.axis)
self.ref_out = np.std(self.x, axis, ddof=self.correction)
else:
self.ref_out = np.std(self.x, ddof=self.correction)
self.place = get_device_place()

def set_attrs(self):
self.correction = 1
self.axis = None

def test_dygraph_correction(self):
paddle.disable_static()
x = paddle.tensor(self.x)
if self.axis:
out = paddle.std(x, self.axis, correction=self.correction)
else:
out = paddle.std(x, correction=self.correction)
np.testing.assert_allclose(out.numpy(), self.ref_out, rtol=1e-05)
paddle.enable_static()

def test_static_correction(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data('x', self.shape, self.dtype)
if self.axis:
out = paddle.std(x, self.axis, correction=self.correction)
else:
out = paddle.std(x, correction=self.correction)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'x': self.x}, fetch_list=[out])
np.testing.assert_allclose(res[0], self.ref_out, rtol=1e-05)


class TestStdAPI_Correction2(TestStdAPI_Correction):
def set_attrs(self):
self.correction = 2
self.axis = None


class TestStdAPI_CorrectionFloat(TestStdAPI_Correction):
def set_attrs(self):
self.correction = 1.5
self.axis = None


class TestStdAPI_CorrectionWithAxis(TestStdAPI_Correction):
def set_attrs(self):
self.correction = 0
self.axis = [1, 2]


class TestStdError(unittest.TestCase):
def test_error(self):
paddle.enable_static()
@@ -147,14 +249,14 @@ def init_data(self):
self.x_shape = []
# x = torch.tensor([])
# res = torch.std(x)  -> res is nan
self.expact_out = np.nan
self.expect_out = np.nan

def test_zerosize(self):
self.init_data()
paddle.disable_static()
x = paddle.to_tensor(np.random.random(self.x_shape))
out1 = paddle.std(x).numpy()
np.testing.assert_allclose(out1, self.expact_out, equal_nan=True)
np.testing.assert_allclose(out1, self.expect_out, equal_nan=True)
paddle.enable_static()


@@ -163,14 +265,14 @@ def init_data(self):
self.x_shape = [1]
# x = torch.randn([1])
# res = torch.std(x, correction=0)  -> res is 0.
self.expact_out = 0.0
self.expect_out = 0.0

def test_api(self):
self.init_data()
paddle.disable_static()
x = paddle.to_tensor(np.random.random(self.x_shape))
out1 = paddle.std(x, unbiased=False).numpy()
np.testing.assert_allclose(out1, self.expact_out, equal_nan=True)
np.testing.assert_allclose(out1, self.expect_out, equal_nan=True)
paddle.enable_static()


@@ -179,14 +281,14 @@ def init_data(self):
self.x_shape = [1]
# x = torch.randn([1])
# res = torch.std(x, correction=1)  -> res is nan
self.expact_out = np.nan
self.expect_out = np.nan

def test_api(self):
self.init_data()
paddle.disable_static()
x = paddle.to_tensor(np.random.random(self.x_shape))
out1 = paddle.std(x, unbiased=True).numpy()
np.testing.assert_allclose(out1, self.expact_out, equal_nan=True)
np.testing.assert_allclose(out1, self.expect_out, equal_nan=True)
paddle.enable_static()


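Finally, a quick numeric cross-check of the correction semantics the new tests exercise: correction (numpy's ddof) is subtracted from the sample size in the variance denominator, i.e. std = sqrt(sum((x - mean)^2) / (N - correction)). The snippet below uses only numpy and reproduces the docstring examples.

import numpy as np

x = np.array([[1.0, 2.0, 3.0], [1.0, 4.0, 5.0]])
for ddof in (0, 1, 1.5):
    print(ddof, np.std(x, ddof=ddof))
# 0    ~1.4907  (unbiased=False)
# 1    ~1.6330  (default: Bessel's correction)
# 1.5  ~1.7213  (the correction=1.5 docstring example)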