From 640f3fc864187ecffeb6c183104b8ea63cbe111f Mon Sep 17 00:00:00 2001
From: enkilee
Date: Wed, 10 Jul 2024 15:35:48 +0800
Subject: [PATCH 1/5] fix

---
 python/paddle/distribution/transform.py | 201 +++++++++++++-----------
 1 file changed, 110 insertions(+), 91 deletions(-)

diff --git a/python/paddle/distribution/transform.py b/python/paddle/distribution/transform.py
index 9bfa53a89fe653..bda0e6f25c92d3 100644
--- a/python/paddle/distribution/transform.py
+++ b/python/paddle/distribution/transform.py
@@ -11,10 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations

 import enum
 import math
 import typing
+from typing import TYPE_CHECKING, Sequence

 import paddle
 import paddle.nn.functional as F
@@ -25,6 +27,10 @@
     variable,
 )

+if TYPE_CHECKING:
+    from paddle import Tensor
+    from paddle.distribution import Distribution, TransformedDistribution
+
 __all__ = [
     'Transform',
     'AbsTransform',
@@ -129,7 +135,9 @@ def _is_injective(cls):
         """
         return Type.is_injective(cls._type)

-    def __call__(self, input):
+    def __call__(
+        self, input: Tensor | Distribution | Transform
+    ) -> Tensor | TransformedDistribution | ChainTransform:
         """Make this instance as a callable object. The return value is
         depending on the input type.

@@ -154,7 +162,7 @@ def __call__(self, input):
             return ChainTransform([self, input])
         return self.forward(input)

-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         """Forward transformation with mapping :math:`y = f(x)`.

         Useful for turning one random outcome into another.
@@ -179,7 +187,7 @@ def forward(self, x):
             )
         return self._forward(x)

-    def inverse(self, y):
+    def inverse(self, y: Tensor) -> Tensor:
         """Inverse transformation :math:`x = f^{-1}(y)`. It's useful for
         "reversing" a transformation to compute one probability in terms
         of another.
@@ -202,7 +210,7 @@ def inverse(self, y):
             )
         return self._inverse(y)

-    def forward_log_det_jacobian(self, x):
+    def forward_log_det_jacobian(self, x: Tensor) -> Tensor:
         """The log of the absolute value of the determinant of the matrix
         of all first-order partial derivatives of the inverse function.

@@ -235,7 +243,7 @@ def forward_log_det_jacobian(self, x):

         return self._call_forward_log_det_jacobian(x)

-    def inverse_log_det_jacobian(self, y):
+    def inverse_log_det_jacobian(self, y: Tensor) -> Tensor:
         """Compute :math:`log|det J_{f^{-1}}(y)|`. Note that
         ``forward_log_det_jacobian`` is the negative of this function,
         evaluated at :math:`f^{-1}(y)`.
@@ -258,7 +266,7 @@ def inverse_log_det_jacobian(self, y):
             )
         return self._call_inverse_log_det_jacobian(y)

-    def forward_shape(self, shape):
+    def forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
         """Infer the shape of forward transformation.

         Args:
@@ -273,7 +281,7 @@ def forward_shape(self, shape):
             )
         return self._forward_shape(shape)

-    def inverse_shape(self, shape):
+    def inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
         """Infer the shape of inverse transformation.

         Args:
@@ -298,19 +306,19 @@ def inverse_shape(self, shape):
         """The codomain of this transformation"""
         return variable.real

-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> Tensor:
         """Inner method for public API ``forward``, subclass should
         overwrite this method for supporting forward transformation.
         """
         raise NotImplementedError('Forward not implemented')

-    def _inverse(self, y):
+    def _inverse(self, y: Tensor) -> Tensor:
         """Inner method of public API ``inverse``, subclass should
         overwrite this method for supporting inverse transformation.
         """
         raise NotImplementedError('Inverse not implemented')

-    def _call_forward_log_det_jacobian(self, x):
+    def _call_forward_log_det_jacobian(self, x: Tensor) -> Tensor:
         """Inner method called by ``forward_log_det_jacobian``."""
         if hasattr(self, '_forward_log_det_jacobian'):
             return self._forward_log_det_jacobian(x)
@@ -321,7 +329,7 @@ def _call_forward_log_det_jacobian(self, x):
             'is implemented. One of them is required.'
         )

-    def _call_inverse_log_det_jacobian(self, y):
+    def _call_inverse_log_det_jacobian(self, y: Tensor) -> Tensor:
         """Inner method called by ``inverse_log_det_jacobian``"""
         if hasattr(self, '_inverse_log_det_jacobian'):
             return self._inverse_log_det_jacobian(y)
@@ -332,14 +340,14 @@ def _call_inverse_log_det_jacobian(self, y):
             'is implemented. One of them is required'
         )

-    def _forward_shape(self, shape):
+    def _forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
         """Inner method called by ``forward_shape``, which is used to infer
         the forward shape. Subclass should overwrite this method for
         supporting ``forward_shape``.
         """
         return shape

-    def _inverse_shape(self, shape):
+    def _inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
         """Inner method called by ``inverse_shape``, which is used to infer
         the inverse shape. Subclass should overwrite this method for
         supporting ``inverse_shape``.
@@ -400,22 +408,22 @@ class AbsTransform(Transform):
     """
     _type = Type.SURJECTION

-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> Tensor:
         return x.abs()

-    def _inverse(self, y):
+    def _inverse(self, y: Tensor) -> Tensor:
         return -y, y

-    def _inverse_log_det_jacobian(self, y):
+    def _inverse_log_det_jacobian(self, y: Tensor) -> Tensor:
         zero = paddle.zeros([], dtype=y.dtype)
         return zero, zero

     @property
-    def _domain(self):
+    def _domain(self) -> Tensor:
         return variable.real

     @property
-    def _codomain(self):
+    def _codomain(self) -> Tensor:
         return variable.positive
@@ -447,8 +455,10 @@ class AffineTransform(Transform):
             0.)
     """
     _type = Type.BIJECTION
+    loc: Tensor
+    scale: Tensor

-    def __init__(self, loc, scale):
+    def __init__(self, loc: Tensor, scale: Tensor) -> None:
         if not isinstance(
             loc, (paddle.base.framework.Variable, paddle.pir.Value)
         ):
@@ -464,23 +474,23 @@ def __init__(self, loc, scale):
         self._scale = scale
         super().__init__()

     @property
-    def loc(self):
+    def loc(self) -> Tensor:
         return self._loc

     @property
-    def scale(self):
+    def scale(self) -> Tensor:
         return self._scale

-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> Tensor:
         return self._loc + self._scale * x

-    def _inverse(self, y):
+    def _inverse(self, y: Tensor) -> Tensor:
         return (y - self._loc) / self._scale

-    def _forward_log_det_jacobian(self, x):
+    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
         return paddle.abs(self._scale).log()

-    def _forward_shape(self, shape):
+    def _forward_shape(self, shape: Sequence[int]) -> tuple[Sequence[int]]:
         return tuple(
             paddle.broadcast_shape(
                 paddle.broadcast_shape(shape, self._loc.shape),
@@ -488,7 +498,7 @@ def _forward_shape(self, shape):
             )
         )

-    def _inverse_shape(self, shape):
+    def _inverse_shape(self, shape: Sequence[int]) -> tuple[Sequence[int]]:
         return tuple(
             paddle.broadcast_shape(
                 paddle.broadcast_shape(shape, self._loc.shape),
@@ -497,11 +507,11 @@ def _inverse_shape(self, shape):
             )
         )

     @property
-    def _domain(self):
+    def _domain(self) -> Tensor:
         return variable.real

     @property
-    def _codomain(self):
+    def _codomain(self) -> Tensor:
         return variable.real
@@ -538,8 +548,9 @@ class ChainTransform(Transform):
         Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
         [ 0., -1., -2., -3.])
     """
+    transforms: Sequence[Transform]

-    def __init__(self, transforms):
+    def __init__(self, transforms: Sequence[Transform]) -> None:
         if not isinstance(transforms, typing.Sequence):
             raise TypeError(
                 f"Expected type of 'transforms' is Sequence, but got {type(transforms)}"
@@ -552,20 +563,20 @@ def __init__(self, transforms):
         self.transforms = transforms
         super().__init__()

-    def _is_injective(self):
+    def _is_injective(self) -> bool:
         return all(t._is_injective() for t in self.transforms)

-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> Tensor:
         for transform in self.transforms:
             x = transform.forward(x)
         return x

-    def _inverse(self, y):
+    def _inverse(self, y: Tensor) -> Tensor:
         for transform in reversed(self.transforms):
             y = transform.inverse(y)
         return y

-    def _forward_log_det_jacobian(self, x):
+    def _forward_log_det_jacobian(self, x: Tensor) -> float:
         value = 0.0
         event_rank = self._domain.event_rank
         for t in self.transforms:
             value += self._sum_rightmost(
@@ -576,17 +587,17 @@ def _forward_log_det_jacobian(self, x):
             event_rank += t._codomain.event_rank - t._domain.event_rank
         return value

-    def _forward_shape(self, shape):
+    def _forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
         for transform in self.transforms:
             shape = transform.forward_shape(shape)
         return shape

-    def _inverse_shape(self, shape):
+    def _inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
         for transform in self.transforms:
             shape = transform.inverse_shape(shape)
         return shape

-    def _sum_rightmost(self, value, n):
+    def _sum_rightmost(self, value: Tensor, n: int) -> Tensor:
         """sum value along rightmost n dim"""
         return value.sum(list(range(-n, 0))) if n > 0 else value
@@ -658,24 +669,24 @@ class ExpTransform(Transform):
     """
     _type = Type.BIJECTION

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()

     @property
-    def _domain(self):
+    def _domain(self) -> Tensor:
         return variable.real

     @property
-    def _codomain(self):
+    def _codomain(self) -> Tensor:
         return variable.positive

-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> Tensor:
         return x.exp()

-    def _inverse(self, y):
+    def _inverse(self, y: Tensor) -> Tensor:
         return y.log()

-    def _forward_log_det_jacobian(self, x):
+    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
         return x
@@ -722,8 +733,10 @@ class IndependentTransform(Transform):
         Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
         [6. , 15.])
     """
+    base: Transform
+    _reinterpreted_batch_rank: int

-    def __init__(self, base, reinterpreted_batch_rank):
+    def __init__(self, base: Transform, reinterpreted_batch_rank: int) -> None:
         if not isinstance(base, Transform):
             raise TypeError(
                 f"Expected 'base' is Transform type, but get {type(base)}"
@@ -737,28 +750,28 @@ def __init__(self, base, reinterpreted_batch_rank):
         self._reinterpreted_batch_rank = reinterpreted_batch_rank
         super().__init__()

-    def _is_injective(self):
+    def _is_injective(self) -> bool:
         return self._base._is_injective()

-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> Tensor:
         if x.dim() < self._domain.event_rank:
             raise ValueError("Input dimensions is less than event dimensions.")
         return self._base.forward(x)

-    def _inverse(self, y):
+    def _inverse(self, y: Tensor) -> Tensor:
         if y.dim() < self._codomain.event_rank:
             raise ValueError("Input dimensions is less than event dimensions.")
         return self._base.inverse(y)

-    def _forward_log_det_jacobian(self, x):
+    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
         return self._base.forward_log_det_jacobian(x).sum(
             list(range(-self._reinterpreted_batch_rank, 0))
         )

-    def _forward_shape(self, shape):
+    def _forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
         return self._base.forward_shape(shape)

-    def _inverse_shape(self, shape):
+    def _inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
         return self._base.inverse_shape(shape)

     @property
@@ -801,8 +814,9 @@ class PowerTransform(Transform):
         [0.69314718, 1.38629436])
     """
     _type = Type.BIJECTION
+    power: Tensor

-    def __init__(self, power):
+    def __init__(self, power: Tensor) -> None:
         if not isinstance(
             power, (paddle.base.framework.Variable, paddle.pir.Value)
         ):
@@ -813,30 +827,30 @@ def __init__(self, power):
         super().__init__()

     @property
-    def power(self):
+    def power(self) -> Tensor:
         return self._power

     @property
-    def _domain(self):
+    def _domain(self) -> Tensor:
         return variable.real

     @property
-    def _codomain(self):
+    def _codomain(self) -> Tensor:
         return variable.positive

-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> Tensor:
         return x.pow(self._power)

-    def _inverse(self, y):
+    def _inverse(self, y: Tensor) -> Tensor:
         return y.pow(1 / self._power)

-    def _forward_log_det_jacobian(self, x):
+    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
         return (self._power * x.pow(self._power - 1)).abs().log()

-    def _forward_shape(self, shape):
+    def _forward_shape(self, shape: Sequence[int]) -> tuple[Sequence[int]]:
         return tuple(paddle.broadcast_shape(shape, self._power.shape))

-    def _inverse_shape(self, shape):
+    def _inverse_shape(self, shape: Sequence[int]) -> tuple[Sequence[int]]:
         return tuple(paddle.broadcast_shape(shape, self._power.shape))
@@ -874,8 +888,12 @@ class ReshapeTransform(Transform):
         [0.])
     """
     _type = Type.BIJECTION
+    in_event_shape: Sequence[int]
+    out_event_shape: Sequence[int]

-    def __init__(self, in_event_shape, out_event_shape):
+    def __init__(
+        self, in_event_shape: Sequence[int], out_event_shape: Sequence[int]
+    ) -> None:
         if not isinstance(in_event_shape, typing.Sequence) or not isinstance(
             out_event_shape, typing.Sequence
         ):
@@ -901,11 +919,11 @@ def __init__(self, in_event_shape, out_event_shape):
         super().__init__()

     @property
-    def in_event_shape(self):
+    def in_event_shape(self) -> tuple[Sequence[int]]:
         return self._in_event_shape

     @property
-    def out_event_shape(self):
+    def out_event_shape(self) -> tuple[Sequence[int]]:
         return self._out_event_shape

     @property
@@ -916,19 +934,19 @@ def out_event_shape(self):
     def _domain(self):
         return variable.Independent(variable.real, len(self._in_event_shape))

     @property
     def _codomain(self):
         return variable.Independent(variable.real, len(self._out_event_shape))

-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> Tensor:
         return x.reshape(
             tuple(x.shape)[: x.dim() - len(self._in_event_shape)]
             + self._out_event_shape
         )

-    def _inverse(self, y):
+    def _inverse(self, y: Tensor) -> Tensor:
         return y.reshape(
             tuple(y.shape)[: y.dim() - len(self._out_event_shape)]
             + self._in_event_shape
         )

-    def _forward_shape(self, shape):
+    def _forward_shape(self, shape: Sequence[int]) -> tuple[Sequence[int]]:
         if len(shape) < len(self._in_event_shape):
             raise ValueError(
                 f"Expected length of 'shape' is not less than {len(self._in_event_shape)}, but got {len(shape)}"
@@ -943,7 +961,7 @@ def _forward_shape(self, shape):
             tuple(shape[: -len(self._in_event_shape)]) + self._out_event_shape
         )

-    def _inverse_shape(self, shape):
+    def _inverse_shape(self, shape: Sequence[int]) -> tuple[Sequence[int]]:
         if len(shape) < len(self._out_event_shape):
             raise ValueError(
                 f"Expected 'shape' length is not less than {len(self._out_event_shape)}, but got {len(shape)}"
@@ -958,7 +976,7 @@ def _inverse_shape(self, shape):
             tuple(shape[: -len(self._out_event_shape)]) + self._in_event_shape
         )

-    def _forward_log_det_jacobian(self, x):
+    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
         shape = x.shape[: x.dim() - len(self._in_event_shape)]
         return paddle.zeros(shape, dtype=x.dtype)
@@ -989,20 +1007,20 @@ class SigmoidTransform(Transform):
     """

     @property
-    def _domain(self):
+    def _domain(self) -> Tensor:
         return variable.real

     @property
     def _codomain(self):
         return variable.Variable(False, 0, constraint.Range(0.0, 1.0))

-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> Tensor:
         return F.sigmoid(x)

-    def _inverse(self, y):
+    def _inverse(self, y: Tensor) -> Tensor:
         return y.log() - (-y).log1p()

-    def _forward_log_det_jacobian(self, x):
+    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
         return -F.softplus(-x) - F.softplus(x)
@@ -1040,21 +1058,21 @@ def _domain(self):
         return variable.Independent(variable.real, 1)

     @property
     def _codomain(self):
         return variable.Variable(False, 1, constraint.simplex)

-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> Tensor:
         x = (x - x.max(-1, keepdim=True)[0]).exp()
         return x / x.sum(-1, keepdim=True)

-    def _inverse(self, y):
+    def _inverse(self, y: Tensor) -> Tensor:
         return y.log()

-    def _forward_shape(self, shape):
+    def _forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
         if len(shape) < 1:
             raise ValueError(
                 f"Expected length of shape is grater than 1, but got {len(shape)}"
             )
         return shape

-    def _inverse_shape(self, shape):
+    def _inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
         if len(shape) < 1:
             raise ValueError(
                 f"Expected length of shape is grater than 1, but got {len(shape)}"
@@ -1102,8 +1120,9 @@ class StackTransform(Transform):
         [2. , 1.38629436],
         [3. , 1.79175949]])
     """
+    transforms: Sequence[Transform]

-    def __init__(self, transforms, axis=0):
+    def __init__(self, transforms: Sequence[Transform], axis: int = 0):
         if not transforms or not isinstance(transforms, typing.Sequence):
             raise TypeError(
                 f"Expected 'transforms' is Sequence[Transform], but got {type(transforms)}."
@@ -1118,18 +1137,18 @@ def __init__(self, transforms, axis=0):
         self._transforms = transforms
         self._axis = axis

-    def _is_injective(self):
+    def _is_injective(self) -> bool:
         return all(t._is_injective() for t in self._transforms)

     @property
-    def transforms(self):
+    def transforms(self) -> Sequence[Transform]:
         return self._transforms

     @property
-    def axis(self):
+    def axis(self) -> int:
         return self._axis

-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> Tensor:
         self._check_size(x)
         return paddle.stack(
             [
@@ -1139,7 +1158,7 @@ def _forward(self, x):
             self._axis,
         )

-    def _inverse(self, y):
+    def _inverse(self, y: Tensor) -> Tensor:
         self._check_size(y)
         return paddle.stack(
             [
@@ -1149,7 +1168,7 @@ def _inverse(self, y):
             self._axis,
         )

-    def _forward_log_det_jacobian(self, x):
+    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
         self._check_size(x)
         return paddle.stack(
             [
@@ -1159,7 +1178,7 @@ def _forward_log_det_jacobian(self, x):
             self._axis,
         )

-    def _check_size(self, v):
+    def _check_size(self, v: Tensor) -> Tensor:
         if not (-v.dim() <= self._axis < v.dim()):
             raise ValueError(
                 f'Input dimensions {v.dim()} should be grater than stack '
@@ -1208,7 +1227,7 @@ class StickBreakingTransform(Transform):

     _type = Type.BIJECTION

-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> Tensor:
         offset = x.shape[-1] + 1 - paddle.ones([x.shape[-1]]).cumsum(-1)
         z = F.sigmoid(x - offset.log())
         z_cumprod = (1 - z).cumprod(-1)
@@ -1216,25 +1235,25 @@ def _forward(self, x):
             z_cumprod, [0] * 2 * (len(x.shape) - 1) + [1, 0], value=1
         )

-    def _inverse(self, y):
+    def _inverse(self, y: Tensor) -> Tensor:
         y_crop = y[..., :-1]
         offset = y.shape[-1] - paddle.ones([y_crop.shape[-1]]).cumsum(-1)
         sf = 1 - y_crop.cumsum(-1)
         x = y_crop.log() - sf.log() + offset.log()
         return x

-    def _forward_log_det_jacobian(self, x):
+    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
         y = self.forward(x)
         offset = x.shape[-1] + 1 - paddle.ones([x.shape[-1]]).cumsum(-1)
         x = x - offset.log()
         return (-x + F.log_sigmoid(x) + y[..., :-1].log()).sum(-1)

-    def _forward_shape(self, shape):
+    def _forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
         if not shape:
             raise ValueError(f"Expected 'shape' is not empty, but got {shape}")
         return shape[:-1] + (shape[-1] + 1,)

-    def _inverse_shape(self, shape):
+    def _inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
         if not shape:
             raise ValueError(f"Expected 'shape' is not empty, but got {shape}")
         return shape[:-1] + (shape[-1] - 1,)
@@ -1283,20 +1302,20 @@ class TanhTransform(Transform):
     _type = Type.BIJECTION

     @property
-    def _domain(self):
+    def _domain(self) -> Tensor:
         return variable.real

     @property
     def _codomain(self):
         return variable.Variable(False, 0, constraint.Range(-1.0, 1.0))

-    def _forward(self, x):
+    def _forward(self, x: Tensor) -> Tensor:
         return x.tanh()

-    def _inverse(self, y):
+    def _inverse(self, y: Tensor) -> Tensor:
         return y.atanh()

-    def _forward_log_det_jacobian(self, x):
+    def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
         """We implicitly rely on _forward_log_det_jacobian rather than
         explicitly implement ``_inverse_log_det_jacobian`` since directly using
         ``-tf.math.log1p(-tf.square(y))`` has lower numerical precision.
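
As a quick sanity check of the surface annotated in PATCH 1, the sketch below exercises the typed methods and the `__call__` dispatch. It is a minimal, hedged example: it only uses classes and methods that appear in this diff (`AffineTransform`, `ExpTransform`, `forward`, `inverse`, `forward_log_det_jacobian`, transform-on-transform chaining), and it follows the `AffineTransform` docstring in constructing `loc` and `scale` with `paddle.to_tensor`; exact printed values depend on the Paddle build.

    import paddle
    from paddle.distribution import AffineTransform, ExpTransform

    # Tensor in -> Tensor out, per the new `forward`/`inverse` annotations.
    affine = AffineTransform(paddle.to_tensor(0.0), paddle.to_tensor(2.0))
    x = paddle.to_tensor([1.0, 2.0])
    y = affine.forward(x)                     # loc + scale * x -> [2., 4.]
    x_back = affine.inverse(y)                # (y - loc) / scale -> [1., 2.]
    ldj = affine.forward_log_det_jacobian(x)  # log|scale| = log(2)

    # Transform in -> ChainTransform out, per the `__call__` dispatch;
    # the chain applies `affine` first, then `exp`.
    chain = affine(ExpTransform())
    z = chain.forward(x)                      # exp(loc + scale * x)
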
From 9c5331b3d045a07ee6486fa64a518c5f5fc1b95b Mon Sep 17 00:00:00 2001
From: enkilee
Date: Wed, 10 Jul 2024 15:49:21 +0800
Subject: [PATCH 2/5] fix

---
 .../distribution/transformed_distribution.py | 21 ++++++++++++++-----
 1 file changed, 16 insertions(+), 5 deletions(-)

diff --git a/python/paddle/distribution/transformed_distribution.py b/python/paddle/distribution/transformed_distribution.py
index c36f6af60f5a21..32a8e027258050 100644
--- a/python/paddle/distribution/transformed_distribution.py
+++ b/python/paddle/distribution/transformed_distribution.py
@@ -11,11 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations

 import typing
+from typing import TYPE_CHECKING, Sequence

 from paddle.distribution import distribution, independent, transform

+if TYPE_CHECKING:
+    from paddle import Tensor
+    from paddle.distribution.distribution import Distribution
+    from paddle.distribution.transform import Transform
+

 class TransformedDistribution(distribution.Distribution):
     r"""
@@ -48,8 +55,12 @@ class TransformedDistribution(distribution.Distribution):
             -1.64333570)
             >>> # doctest: -SKIP
     """
+    base: Distribution
+    transforms: Sequence[Transform]

-    def __init__(self, base, transforms):
+    def __init__(
+        self, base: Distribution, transforms: Sequence[Transform]
+    ) -> None:
         if not isinstance(base, distribution.Distribution):
             raise TypeError(
                 f"Expected type of 'base' is Distribution, but got {type(base)}."
@@ -92,7 +103,7 @@ def __init__(self, base, transforms):
             ],
         )

-    def sample(self, shape=()):
+    def sample(self, shape: Sequence[int] = ()) -> Tensor:
         """Sample from ``TransformedDistribution``.

         Args:
@@ -106,7 +117,7 @@ def sample(self, shape=()):
             x = t.forward(x)
         return x

-    def rsample(self, shape=()):
+    def rsample(self, shape: Sequence[int] = ()) -> Tensor:
         """Reparameterized sample from ``TransformedDistribution``.

         Args:
@@ -120,7 +131,7 @@ def rsample(self, shape=()):
             x = t.forward(x)
         return x

-    def log_prob(self, value):
+    def log_prob(self, value: Tensor) -> Tensor:
         """The log probability evaluated at value.

         Args:
@@ -145,5 +156,5 @@ def log_prob(self, value):
         return log_prob


-def _sum_rightmost(value, n):
+def _sum_rightmost(value: Tensor, n: int) -> Tensor:
     return value.sum(list(range(-n, 0))) if n > 0 else value

From e3317fa8e4916364dd2105b6f89d36c268cefa1e Mon Sep 17 00:00:00 2001
From: enkilee
Date: Fri, 12 Jul 2024 09:13:43 +0800
Subject: [PATCH 3/5] fix

---
 python/paddle/distribution/transform.py | 34 +++++++++++++++----------
 1 file changed, 20 insertions(+), 14 deletions(-)

diff --git a/python/paddle/distribution/transform.py b/python/paddle/distribution/transform.py
index bda0e6f25c92d3..28486be8e96909 100644
--- a/python/paddle/distribution/transform.py
+++ b/python/paddle/distribution/transform.py
@@ -16,7 +16,12 @@
 import enum
 import math
 import typing
-from typing import TYPE_CHECKING, Sequence
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Sequence,
+    overload,
+)

 import paddle
 import paddle.nn.functional as F
@@ -123,7 +128,7 @@ class Transform:
     """
     _type = Type.INJECTION

-    def __init__(self):
+    def __init__(self) -> None:
         super().__init__()

     @classmethod
@@ -135,9 +140,19 @@ def _is_injective(cls):
         """
         return Type.is_injective(cls._type)

-    def __call__(
-        self, input: Tensor | Distribution | Transform
-    ) -> Tensor | TransformedDistribution | ChainTransform:
+    @overload
+    def __call__(self, input: Tensor) -> Tensor:
+        ...
+
+    @overload
+    def __call__(self, input: Distribution) -> TransformedDistribution:
+        ...
+
+    @overload
+    def __call__(self, input: Transform) -> ChainTransform:
+        ...
+
+    def __call__(self, input) -> Any:
         """Make this instance as a callable object. The return value is
         depending on the input type.
@@ -455,8 +470,6 @@ class AffineTransform(Transform):
             0.)
     """
     _type = Type.BIJECTION
-    loc: Tensor
-    scale: Tensor

     def __init__(self, loc: Tensor, scale: Tensor) -> None:
         if not isinstance(
@@ -548,7 +561,6 @@ class ChainTransform(Transform):
         Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
         [ 0., -1., -2., -3.])
     """
-    transforms: Sequence[Transform]

     def __init__(self, transforms: Sequence[Transform]) -> None:
         if not isinstance(transforms, typing.Sequence):
@@ -733,8 +745,6 @@ class IndependentTransform(Transform):
         Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
         [6. , 15.])
     """
-    base: Transform
-    _reinterpreted_batch_rank: int

     def __init__(self, base: Transform, reinterpreted_batch_rank: int) -> None:
         if not isinstance(base, Transform):
@@ -814,7 +824,6 @@ class PowerTransform(Transform):
         [0.69314718, 1.38629436])
     """
     _type = Type.BIJECTION
-    power: Tensor

     def __init__(self, power: Tensor) -> None:
         if not isinstance(
@@ -888,8 +897,6 @@ class ReshapeTransform(Transform):
         [0.])
     """
     _type = Type.BIJECTION
-    in_event_shape: Sequence[int]
-    out_event_shape: Sequence[int]

     def __init__(
         self, in_event_shape: Sequence[int], out_event_shape: Sequence[int]
@@ -1120,7 +1127,6 @@ class StackTransform(Transform):
         [2. , 1.38629436],
         [3. , 1.79175949]])
     """
-    transforms: Sequence[Transform]

     def __init__(self, transforms: Sequence[Transform], axis: int = 0):
         if not transforms or not isinstance(transforms, typing.Sequence):

From 1bc08a7abf54b6037c047a524f73f9faf5428b75 Mon Sep 17 00:00:00 2001
From: SigureMo
Date: Sat, 20 Jul 2024 22:24:55 +0800
Subject: [PATCH 4/5] fix all `_domain` and `_codomain` typing

---
 python/paddle/distribution/constraint.py |  2 +-
 python/paddle/distribution/transform.py  | 60 ++++++++++++++----------
 2 files changed, 35 insertions(+), 27 deletions(-)

diff --git a/python/paddle/distribution/constraint.py b/python/paddle/distribution/constraint.py
index a339d47c9d164f..e59163dfd3e918 100644
--- a/python/paddle/distribution/constraint.py
+++ b/python/paddle/distribution/constraint.py
@@ -34,7 +34,7 @@ def __call__(self, value: Tensor) -> Tensor:


 class Range(Constraint):
-    def __init__(self, lower: Tensor, upper: Tensor) -> None:
+    def __init__(self, lower: float | Tensor, upper: float | Tensor) -> None:
         self._lower = lower
         self._upper = upper
         super().__init__()
diff --git a/python/paddle/distribution/transform.py b/python/paddle/distribution/transform.py
index 28486be8e96909..e8e04d40105c0f 100644
--- a/python/paddle/distribution/transform.py
+++ b/python/paddle/distribution/transform.py
@@ -126,6 +126,7 @@ class Transform:
           * _inverse_shape

     """
+
     _type = Type.INJECTION

     def __init__(self) -> None:
@@ -312,12 +313,12 @@ def inverse_shape(self, shape):
         return self._inverse_shape(shape)

     @property
-    def _domain(self):
+    def _domain(self) -> variable.Variable:
         """The domain of this transformation"""
         return variable.real

     @property
-    def _codomain(self):
+    def _codomain(self) -> variable.Variable:
         """The codomain of this transformation"""
         return variable.real
@@ -421,6 +422,7 @@ class AbsTransform(Transform):
             0.))

     """
+
     _type = Type.SURJECTION

     def _forward(self, x: Tensor) -> Tensor:
@@ -434,11 +436,11 @@ def _inverse_log_det_jacobian(self, y):
         return zero, zero

     @property
-    def _domain(self) -> Tensor:
+    def _domain(self) -> variable.Real:
         return variable.real

     @property
-    def _codomain(self) -> Tensor:
+    def _codomain(self) -> variable.Positive:
         return variable.positive
@@ -469,6 +471,7 @@ class AffineTransform(Transform):
         Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
         0.)
     """
+
     _type = Type.BIJECTION

     def __init__(self, loc: Tensor, scale: Tensor) -> None:
@@ -520,11 +523,11 @@ def _inverse_shape(self, shape):
         )

     @property
-    def _domain(self) -> Tensor:
+    def _domain(self) -> variable.Real:
         return variable.real

     @property
-    def _codomain(self) -> Tensor:
+    def _codomain(self) -> variable.Real:
         return variable.real
@@ -614,7 +617,7 @@ def _sum_rightmost(self, value, n):
         return value.sum(list(range(-n, 0))) if n > 0 else value

     @property
-    def _domain(self):
+    def _domain(self) -> variable.Independent:
         domain = self.transforms[0]._domain

         # Compute the lower bound of input dimensions for chain transform.
@@ -642,7 +645,7 @@ def _domain(self):
         return variable.Independent(domain, event_rank - domain.event_rank)

     @property
-    def _codomain(self):
+    def _codomain(self) -> variable.Independent:
         codomain = self.transforms[-1]._codomain

         event_rank = self.transforms[0]._domain.event_rank
@@ -679,17 +682,18 @@ class ExpTransform(Transform):
         Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
         [ 0. , -0.69314718, -1.09861231])
     """
+
     _type = Type.BIJECTION

     def __init__(self) -> None:
         super().__init__()

     @property
-    def _domain(self) -> Tensor:
+    def _domain(self) -> variable.Real:
         return variable.real

     @property
-    def _codomain(self) -> Tensor:
+    def _codomain(self) -> variable.Positive:
         return variable.positive

     def _forward(self, x: Tensor) -> Tensor:
@@ -785,13 +789,13 @@ def _inverse_shape(self, shape):
         return self._base.inverse_shape(shape)

     @property
-    def _domain(self):
+    def _domain(self) -> variable.Independent:
         return variable.Independent(
             self._base._domain, self._reinterpreted_batch_rank
         )

     @property
-    def _codomain(self):
+    def _codomain(self) -> variable.Independent:
         return variable.Independent(
             self._base._codomain, self._reinterpreted_batch_rank
         )
@@ -823,6 +827,7 @@ class PowerTransform(Transform):
         Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
         [0.69314718, 1.38629436])
     """
+
     _type = Type.BIJECTION

     def __init__(self, power: Tensor) -> None:
@@ -840,11 +845,11 @@ def power(self):
         return self._power

     @property
-    def _domain(self) -> Tensor:
+    def _domain(self) -> variable.Real:
         return variable.real

     @property
-    def _codomain(self) -> Tensor:
+    def _codomain(self) -> variable.Positive:
         return variable.positive

     def _forward(self, x: Tensor) -> Tensor:
@@ -896,6 +901,7 @@ class ReshapeTransform(Transform):
         Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
         [0.])
     """
+
     _type = Type.BIJECTION

     def __init__(
@@ -934,11 +940,11 @@ def out_event_shape(self):
         return self._out_event_shape

     @property
-    def _domain(self):
+    def _domain(self) -> variable.Independent:
         return variable.Independent(variable.real, len(self._in_event_shape))

     @property
-    def _codomain(self):
+    def _codomain(self) -> variable.Independent:
         return variable.Independent(variable.real, len(self._out_event_shape))

     def _forward(self, x: Tensor) -> Tensor:
@@ -1014,11 +1020,11 @@ class SigmoidTransform(Transform):
     """

     @property
-    def _domain(self) -> Tensor:
+    def _domain(self) -> variable.Real:
         return variable.real

     @property
-    def _codomain(self):
+    def _codomain(self) -> variable.Variable:
         return variable.Variable(False, 0, constraint.Range(0.0, 1.0))

     def _forward(self, x: Tensor) -> Tensor:
@@ -1055,14 +1061,15 @@ class SoftmaxTransform(Transform):
         [[-1.09861231, -1.09861231, -1.09861231],
          [-1.09861231, -1.09861231, -1.09861231]])
     """
+
     _type = Type.OTHER

     @property
-    def _domain(self):
+    def _domain(self) -> variable.Independent:
         return variable.Independent(variable.real, 1)

     @property
-    def _codomain(self):
+    def _codomain(self) -> variable.Variable:
         return variable.Variable(False, 1, constraint.simplex)

     def _forward(self, x: Tensor) -> Tensor:
@@ -1197,11 +1204,11 @@ def _check_size(self, v):
         )

     @property
-    def _domain(self):
+    def _domain(self) -> variable.Stack:
         return variable.Stack([t._domain for t in self._transforms], self._axis)

     @property
-    def _codomain(self):
+    def _codomain(self) -> variable.Stack:
         return variable.Stack(
             [t._codomain for t in self._transforms], self._axis
         )
@@ -1265,11 +1272,11 @@ def _inverse_shape(self, shape):
         return shape[:-1] + (shape[-1] - 1,)

     @property
-    def _domain(self):
+    def _domain(self) -> variable.Independent:
         return variable.Independent(variable.real, 1)

     @property
-    def _codomain(self):
+    def _codomain(self) -> variable.Variable:
         return variable.Variable(False, 1, constraint.simplex)
@@ -1305,14 +1312,15 @@ class TanhTransform(Transform):
         [6.61441946 , 8.61399269 , 10.61451530]])
         >>> # doctest: -SKIP
     """
+
     _type = Type.BIJECTION

     @property
-    def _domain(self) -> Tensor:
+    def _domain(self) -> variable.Real:
         return variable.real

     @property
-    def _codomain(self):
+    def _codomain(self) -> variable.Variable:
         return variable.Variable(False, 0, constraint.Range(-1.0, 1.0))

     def _forward(self, x: Tensor) -> Tensor:

From 8335d5143f573f8dcd98a0f06dc58cc35b94e21b Mon Sep 17 00:00:00 2001
From: SigureMo
Date: Sat, 20 Jul 2024 22:44:15 +0800
Subject: [PATCH 5/5] fix all `_forward_shape` and `_inverse_shape`

---
 python/paddle/distribution/transform.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/python/paddle/distribution/transform.py b/python/paddle/distribution/transform.py
index e8e04d40105c0f..1fbd71af5ee758 100644
--- a/python/paddle/distribution/transform.py
+++ b/python/paddle/distribution/transform.py
@@ -428,10 +428,10 @@ class AbsTransform(Transform):
     def _forward(self, x: Tensor) -> Tensor:
         return x.abs()

-    def _inverse(self, y: Tensor) -> Tensor:
+    def _inverse(self, y: Tensor) -> tuple[Tensor, Tensor]:
         return -y, y

-    def _inverse_log_det_jacobian(self, y: Tensor) -> Tensor:
+    def _inverse_log_det_jacobian(self, y: Tensor) -> tuple[Tensor, Tensor]:
         zero = paddle.zeros([], dtype=y.dtype)
         return zero, zero
@@ -506,7 +506,7 @@ def _inverse(self, y):
     def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
         return paddle.abs(self._scale).log()

-    def _forward_shape(self, shape: Sequence[int]) -> tuple[Sequence[int]]:
+    def _forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
         return tuple(
             paddle.broadcast_shape(
                 paddle.broadcast_shape(shape, self._loc.shape),
@@ -514,7 +514,7 @@ def _forward_shape(self, shape):
             )
         )

-    def _inverse_shape(self, shape: Sequence[int]) -> tuple[Sequence[int]]:
+    def _inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
         return tuple(
             paddle.broadcast_shape(
                 paddle.broadcast_shape(shape, self._loc.shape),
@@ -861,10 +861,10 @@ def _inverse(self, y):
     def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
         return (self._power * x.pow(self._power - 1)).abs().log()

-    def _forward_shape(self, shape: Sequence[int]) -> tuple[Sequence[int]]:
+    def _forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
         return tuple(paddle.broadcast_shape(shape, self._power.shape))

-    def _inverse_shape(self, shape: Sequence[int]) -> tuple[Sequence[int]]:
+    def _inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
         return tuple(paddle.broadcast_shape(shape, self._power.shape))
@@ -959,7 +959,7 @@ def _inverse(self, y):
             + self._in_event_shape
         )

-    def _forward_shape(self, shape: Sequence[int]) -> tuple[Sequence[int]]:
+    def _forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
         if len(shape) < len(self._in_event_shape):
             raise ValueError(
                 f"Expected length of 'shape' is not less than {len(self._in_event_shape)}, but got {len(shape)}"
@@ -974,7 +974,7 @@ def _forward_shape(self, shape):
             tuple(shape[: -len(self._in_event_shape)]) + self._out_event_shape
         )

-    def _inverse_shape(self, shape: Sequence[int]) -> tuple[Sequence[int]]:
+    def _inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
         if len(shape) < len(self._out_event_shape):
             raise ValueError(
                 f"Expected 'shape' length is not less than {len(self._out_event_shape)}, but got {len(shape)}"
@@ -1191,7 +1191,7 @@ def _forward_log_det_jacobian(self, x: Tensor) -> Tensor:
             self._axis,
         )

-    def _check_size(self, v: Tensor) -> Tensor:
+    def _check_size(self, v: Tensor) -> None:
         if not (-v.dim() <= self._axis < v.dim()):
             raise ValueError(
                 f'Input dimensions {v.dim()} should be grater than stack '
@@ -1264,12 +1264,12 @@ def _forward_log_det_jacobian(self, x):
     def _forward_shape(self, shape: Sequence[int]) -> Sequence[int]:
         if not shape:
             raise ValueError(f"Expected 'shape' is not empty, but got {shape}")
-        return shape[:-1] + (shape[-1] + 1,)
+        return (*shape[:-1], shape[-1] + 1)

     def _inverse_shape(self, shape: Sequence[int]) -> Sequence[int]:
         if not shape:
             raise ValueError(f"Expected 'shape' is not empty, but got {shape}")
-        return shape[:-1] + (shape[-1] - 1,)
+        return (*shape[:-1], shape[-1] - 1)

     @property
     def _domain(self) -> variable.Independent:
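
To see the `_forward_shape`/`_inverse_shape` contract settled in PATCH 5, here is a small hedged sketch using `StickBreakingTransform` (only names from this series; shapes follow `(*shape[:-1], shape[-1] + 1)` above): K unconstrained reals parameterize a point on the (K+1)-bin simplex, so the event dimension grows by one going forward and shrinks by one going back.

    import paddle
    from paddle.distribution import StickBreakingTransform

    t = StickBreakingTransform()
    x = paddle.rand([2, 3])
    y = t.forward(x)                                  # simplex points, shape [2, 4]
    assert tuple(y.shape) == t.forward_shape((2, 3))  # (2, 4)
    assert t.inverse_shape((2, 4)) == (2, 3)          # round-trips the event dim
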
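And a matching usage sketch for the `TransformedDistribution` typing from PATCH 2 — hedged as well: `Normal` is assumed as a convenient base distribution from `paddle.distribution` (it does not appear in this diff), and pushing it through `ExpTransform` yields a log-normal.

    import paddle
    from paddle.distribution import ExpTransform, Normal, TransformedDistribution

    # `sample`/`rsample` return a Tensor and `log_prob` takes one,
    # per the new annotations.
    d = TransformedDistribution(Normal(loc=0.0, scale=1.0), [ExpTransform()])
    s = d.sample([8])    # a Tensor; shape follows the base's batch shape
    lp = d.log_prob(s)   # base.log_prob(inverse(s)) minus the forward
                         # log-det-Jacobian, summed over event dims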