From e486923b8f447107d9dd7c3625874f0ded2de87e Mon Sep 17 00:00:00 2001 From: enkilee Date: Mon, 1 Jul 2024 16:11:59 +0800 Subject: [PATCH 1/3] fix --- python/paddle/distribution/beta.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/python/paddle/distribution/beta.py b/python/paddle/distribution/beta.py index 7c6980efe65481..04af6374736359 100644 --- a/python/paddle/distribution/beta.py +++ b/python/paddle/distribution/beta.py @@ -11,11 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import numbers +from typing import TYPE_CHECKING, Sequence import paddle from paddle.distribution import dirichlet, exponential_family +if TYPE_CHECKING: + from paddle import Tensor + class Beta(exponential_family.ExponentialFamily): r""" @@ -86,7 +92,7 @@ class Beta(exponential_family.ExponentialFamily): [-1.91923141, -0.38095081]) """ - def __init__(self, alpha, beta): + def __init__(self, alpha: float | Tensor, beta: float | Tensor) -> None: if isinstance(alpha, numbers.Real): alpha = paddle.full(shape=[], fill_value=alpha) @@ -102,17 +108,17 @@ def __init__(self, alpha, beta): super().__init__(self._dirichlet._batch_shape) @property - def mean(self): + def mean(self) -> float: """Mean of beta distribution.""" return self.alpha / (self.alpha + self.beta) @property - def variance(self): + def variance(self) -> float: """Variance of beat distribution""" sum = self.alpha + self.beta return self.alpha * self.beta / (sum.pow(2) * (sum + 1)) - def prob(self, value): + def prob(self, value: Tensor) -> Tensor: """Probability density function evaluated at value Args: @@ -123,7 +129,7 @@ def prob(self, value): """ return paddle.exp(self.log_prob(value)) - def log_prob(self, value): + def log_prob(self, value: Tensor) -> Tensor: """Log probability density function evaluated at value Args: @@ -134,7 +140,7 @@ def log_prob(self, value): """ return self._dirichlet.log_prob(paddle.stack([value, 1.0 - value], -1)) - def sample(self, shape=()): + def sample(self, shape: Sequence[int] | None = None) -> Tensor: """Sample from beta distribution with sample shape. Args: @@ -146,7 +152,7 @@ def sample(self, shape=()): shape = shape if isinstance(shape, tuple) else tuple(shape) return paddle.squeeze(self._dirichlet.sample(shape)[..., 0], axis=-1) - def entropy(self): + def entropy(self) -> Tensor: """Entropy of dirichlet distribution Returns: @@ -155,8 +161,8 @@ def entropy(self): return self._dirichlet.entropy() @property - def _natural_parameters(self): + def _natural_parameters(self) -> tuple[Tensor, Tensor]: return (self.alpha, self.beta) - def _log_normalizer(self, x, y): + def _log_normalizer(self, x: Tensor, y: Tensor) -> Tensor: return paddle.lgamma(x) + paddle.lgamma(y) - paddle.lgamma(x + y) From d135c83489bdefe12c55f47bb82f983484f018db Mon Sep 17 00:00:00 2001 From: enkilee Date: Tue, 2 Jul 2024 08:59:34 +0800 Subject: [PATCH 2/3] fix --- python/paddle/distribution/beta.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/paddle/distribution/beta.py b/python/paddle/distribution/beta.py index 04af6374736359..7bc8aadf9ad4ea 100644 --- a/python/paddle/distribution/beta.py +++ b/python/paddle/distribution/beta.py @@ -14,7 +14,7 @@ from __future__ import annotations import numbers -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING import paddle from paddle.distribution import dirichlet, exponential_family @@ -91,6 +91,8 @@ class Beta(exponential_family.ExponentialFamily): Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, [-1.91923141, -0.38095081]) """ + alpha: float | Tensor + beta: float | Tensor def __init__(self, alpha: float | Tensor, beta: float | Tensor) -> None: if isinstance(alpha, numbers.Real): @@ -108,12 +110,12 @@ def __init__(self, alpha: float | Tensor, beta: float | Tensor) -> None: super().__init__(self._dirichlet._batch_shape) @property - def mean(self) -> float: + def mean(self) -> Tensor: """Mean of beta distribution.""" return self.alpha / (self.alpha + self.beta) @property - def variance(self) -> float: + def variance(self) -> Tensor: """Variance of beat distribution""" sum = self.alpha + self.beta return self.alpha * self.beta / (sum.pow(2) * (sum + 1)) @@ -140,7 +142,7 @@ def log_prob(self, value: Tensor) -> Tensor: """ return self._dirichlet.log_prob(paddle.stack([value, 1.0 - value], -1)) - def sample(self, shape: Sequence[int] | None = None) -> Tensor: + def sample(self, shape=()): """Sample from beta distribution with sample shape. Args: From dd41a20d52644dbea7e057c3347eafdee3e915c8 Mon Sep 17 00:00:00 2001 From: enkilee Date: Thu, 4 Jul 2024 09:01:29 +0800 Subject: [PATCH 3/3] fix --- python/paddle/distribution/beta.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/distribution/beta.py b/python/paddle/distribution/beta.py index 7bc8aadf9ad4ea..a32f41cd15078b 100644 --- a/python/paddle/distribution/beta.py +++ b/python/paddle/distribution/beta.py @@ -91,8 +91,8 @@ class Beta(exponential_family.ExponentialFamily): Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True, [-1.91923141, -0.38095081]) """ - alpha: float | Tensor - beta: float | Tensor + alpha: Tensor + beta: Tensor def __init__(self, alpha: float | Tensor, beta: float | Tensor) -> None: if isinstance(alpha, numbers.Real):