-
Notifications
You must be signed in to change notification settings - Fork 5.9k
[API Compatibility No.352、377] Add torch-style arg and alias for std -part
#77006
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
96560a1
712c4cf
b20b4c6
6e92cb7
c594a1e
df53d56
e6136e4
345ac53
e0d2265
04a19fb
bc34d71
1bcf808
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -15,7 +15,7 @@ | |||||||||||||||
| from __future__ import annotations | ||||||||||||||||
|
|
||||||||||||||||
| import warnings | ||||||||||||||||
| from typing import TYPE_CHECKING, Literal | ||||||||||||||||
| from typing import TYPE_CHECKING, Any, Literal | ||||||||||||||||
|
|
||||||||||||||||
| from typing_extensions import TypeAlias, overload | ||||||||||||||||
|
|
||||||||||||||||
|
|
@@ -29,6 +29,7 @@ | |||||||||||||||
| ParamAliasDecorator, | ||||||||||||||||
| param_two_alias, | ||||||||||||||||
| param_two_alias_one_default, | ||||||||||||||||
| use_first_signature, | ||||||||||||||||
| ) | ||||||||||||||||
|
|
||||||||||||||||
| from ..base.data_feeder import check_type, check_variable_and_dtype | ||||||||||||||||
|
|
@@ -285,16 +286,40 @@ def _replace_nan(out): | |||||||||||||||
| return result | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| @overload | ||||||||||||||||
| def std( | ||||||||||||||||
| x: Tensor, | ||||||||||||||||
| axis: int | Sequence[int] | None = None, | ||||||||||||||||
| unbiased: bool = True, | ||||||||||||||||
| unbiased: bool | None = None, | ||||||||||||||||
| keepdim: bool = False, | ||||||||||||||||
| name: str | None = None, | ||||||||||||||||
| ) -> Tensor: | ||||||||||||||||
| *, | ||||||||||||||||
| correction: float = 1, | ||||||||||||||||
| out: Tensor | None = None, | ||||||||||||||||
| ) -> Tensor: ... | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| @overload | ||||||||||||||||
| def std( | ||||||||||||||||
| input: Tensor, | ||||||||||||||||
| dim: int | Sequence[int] | None = None, | ||||||||||||||||
| *, | ||||||||||||||||
| correction: float = 1, | ||||||||||||||||
| keepdim: bool = False, | ||||||||||||||||
| out: Tensor | None = None, | ||||||||||||||||
| ) -> Tensor: ... | ||||||||||||||||
|
|
||||||||||||||||
|
|
||||||||||||||||
| @use_first_signature | ||||||||||||||||
| def std(*args: Any, **kwargs: Any) -> Tensor: | ||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||||||||||||||||
| """ | ||||||||||||||||
| Computes the standard-deviation of ``x`` along ``axis`` . | ||||||||||||||||
|
|
||||||||||||||||
| .. note:: | ||||||||||||||||
| Alias Support: | ||||||||||||||||
| 1. The parameter name ``input`` can be used as an alias for ``x``. | ||||||||||||||||
| 2. The parameter name ``dim`` can be used as an alias for ``axis``. | ||||||||||||||||
|
|
||||||||||||||||
| Args: | ||||||||||||||||
| x (Tensor): The input Tensor with data type float16, float32, float64. | ||||||||||||||||
| axis (int|list|tuple|None, optional): The axis along which to perform | ||||||||||||||||
|
|
@@ -319,34 +344,45 @@ def std( | |||||||||||||||
| the output Tensor is squeezed in ``axis`` . Default is False. | ||||||||||||||||
| name (str|None, optional): Name for the operation (optional, default is None). | ||||||||||||||||
| For more information, please refer to :ref:`api_guide_Name`. | ||||||||||||||||
| correction (int|float, optional): Difference between the sample size and sample degrees of freedom. | ||||||||||||||||
| Defaults to 1 (Bessel's correction). If unbiased is specified, this parameter is ignored. | ||||||||||||||||
| out (Tensor|None, optional): Output tensor. Default is None. | ||||||||||||||||
|
|
||||||||||||||||
| Returns: | ||||||||||||||||
| Tensor, results of standard-deviation along ``axis`` of ``x``, with the | ||||||||||||||||
| same data type as ``x``. | ||||||||||||||||
|
|
||||||||||||||||
| Examples: | ||||||||||||||||
| .. code-block:: python | ||||||||||||||||
| .. code-block:: pycon | ||||||||||||||||
|
|
||||||||||||||||
| >>> import paddle | ||||||||||||||||
|
|
||||||||||||||||
| >>> x = paddle.to_tensor([[1.0, 2.0, 3.0], [1.0, 4.0, 5.0]]) | ||||||||||||||||
| >>> out1 = paddle.std(x) | ||||||||||||||||
| >>> print(out1.numpy()) | ||||||||||||||||
| 1.6329932 | ||||||||||||||||
|
|
||||||||||||||||
| >>> out2 = paddle.std(x, unbiased=False) | ||||||||||||||||
| >>> print(out2.numpy()) | ||||||||||||||||
| 1.490712 | ||||||||||||||||
|
|
||||||||||||||||
| >>> out3 = paddle.std(x, axis=1) | ||||||||||||||||
| >>> print(out3.numpy()) | ||||||||||||||||
| [1. 2.081666] | ||||||||||||||||
|
|
||||||||||||||||
| >>> out4 = paddle.std(x=x, keepdim=True, correction=1.5) | ||||||||||||||||
| >>> print(out4.numpy()) | ||||||||||||||||
| [[1.721326]] | ||||||||||||||||
|
|
||||||||||||||||
| >>> out5 = paddle.std(input=x, dim=[0, 1]) | ||||||||||||||||
| >>> print(out5.numpy()) | ||||||||||||||||
| 1.6329932 | ||||||||||||||||
|
|
||||||||||||||||
| """ | ||||||||||||||||
| if not in_dynamic_or_pir_mode(): | ||||||||||||||||
| check_variable_and_dtype( | ||||||||||||||||
| x, 'x', ['float16', 'float32', 'float64'], 'std' | ||||||||||||||||
| ) | ||||||||||||||||
| out = var(**locals()) | ||||||||||||||||
| return paddle.sqrt(out) | ||||||||||||||||
| variance = var(*args, **kwargs) | ||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 看来std、var这两个API不是简单的别名替换。 需要一套参数分发的逻辑,参考group_norm、gather的写法,这里需要处理args/kwargs来判断overload的两种签名。 var是底层的接口,先处理好var,std只需要这么写就行: out=None的默认缺省逻辑在各API底层以及下沉后都会有,不需要重复判断。
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 那个if逻辑可能删不了,out需要从kwargs里面取,取不到的话会报KeyError。目前的实现也不涉及像gather一样新定义一个wrapper。 |
||||||||||||||||
| if 'out' in kwargs: | ||||||||||||||||
| return paddle.sqrt(variance, out=kwargs['out']) | ||||||||||||||||
|
Comment on lines
+382
to
+384
|
||||||||||||||||
| variance = var(*args, **kwargs) | |
| if 'out' in kwargs: | |
| return paddle.sqrt(variance, out=kwargs['out']) | |
| out = kwargs.pop('out', None) | |
| variance = var(*args, **kwargs) | |
| if out is not None: | |
| return paddle.sqrt(variance, out=out) |

There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
只有一种签名?这里用
overload的意义是什么?