diff --git a/python/paddle/distribution/beta.py b/python/paddle/distribution/beta.py
index 7c6980efe65481..a32f41cd15078b 100644
--- a/python/paddle/distribution/beta.py
+++ b/python/paddle/distribution/beta.py
@@ -11,11 +11,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
+
 import numbers
+from typing import TYPE_CHECKING
 
 import paddle
 from paddle.distribution import dirichlet, exponential_family
 
+if TYPE_CHECKING:
+    from paddle import Tensor
+
 
 class Beta(exponential_family.ExponentialFamily):
     r"""
@@ -85,8 +91,10 @@ class Beta(exponential_family.ExponentialFamily):
             Tensor(shape=[2], dtype=float32, place=Place(cpu), stop_gradient=True,
             [-1.91923141, -0.38095081])
     """
+    alpha: Tensor
+    beta: Tensor
 
-    def __init__(self, alpha, beta):
+    def __init__(self, alpha: float | Tensor, beta: float | Tensor) -> None:
         if isinstance(alpha, numbers.Real):
             alpha = paddle.full(shape=[], fill_value=alpha)
 
@@ -102,17 +110,17 @@ def __init__(self, alpha, beta):
         super().__init__(self._dirichlet._batch_shape)
 
     @property
-    def mean(self):
+    def mean(self) -> Tensor:
         """Mean of beta distribution."""
         return self.alpha / (self.alpha + self.beta)
 
     @property
-    def variance(self):
+    def variance(self) -> Tensor:
         """Variance of beat distribution"""
         sum = self.alpha + self.beta
         return self.alpha * self.beta / (sum.pow(2) * (sum + 1))
 
-    def prob(self, value):
+    def prob(self, value: Tensor) -> Tensor:
         """Probability density function evaluated at value
 
         Args:
@@ -123,7 +131,7 @@ def prob(self, value):
         """
         return paddle.exp(self.log_prob(value))
 
-    def log_prob(self, value):
+    def log_prob(self, value: Tensor) -> Tensor:
         """Log probability density function evaluated at value
 
         Args:
@@ -146,7 +154,7 @@ def sample(self, shape=()):
         shape = shape if isinstance(shape, tuple) else tuple(shape)
         return paddle.squeeze(self._dirichlet.sample(shape)[..., 0], axis=-1)
 
-    def entropy(self):
+    def entropy(self) -> Tensor:
         """Entropy of dirichlet distribution
 
         Returns:
@@ -155,8 +163,8 @@ def entropy(self):
         return self._dirichlet.entropy()
 
     @property
-    def _natural_parameters(self):
+    def _natural_parameters(self) -> tuple[Tensor, Tensor]:
         return (self.alpha, self.beta)
 
-    def _log_normalizer(self, x, y):
+    def _log_normalizer(self, x: Tensor, y: Tensor) -> Tensor:
         return paddle.lgamma(x) + paddle.lgamma(y) - paddle.lgamma(x + y)