Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add probability distribution transformation APIs #40536

Merged
merged 8 commits into from
Mar 31, 2022
27 changes: 18 additions & 9 deletions python/paddle/distribution/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .beta import Beta
from .categorical import Categorical
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exponential_family import ExponentialFamily
from .kl import kl_divergence, register_kl
from .multinomial import Multinomial
from .normal import Normal
from .uniform import Uniform
from paddle.distribution import transform
from paddle.distribution.beta import Beta
from paddle.distribution.categorical import Categorical
from paddle.distribution.dirichlet import Dirichlet
from paddle.distribution.distribution import Distribution
from paddle.distribution.exponential_family import ExponentialFamily
from paddle.distribution.independent import Independent
from paddle.distribution.kl import kl_divergence, register_kl
from paddle.distribution.multinomial import Multinomial
from paddle.distribution.normal import Normal
from paddle.distribution.transform import * # noqa: F403
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里需要import *的原因是什么呢?

Copy link
Contributor Author

@cxxly cxxly Mar 31, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

distribution/transform.py 文件中定义了需要公开API的__ALL__列表,在distribution/__init__中用import * 全部导出,并添加到__init__.py 的__all__列表中, 可以通过paddle.distribution.xxx访问,访问路径和竞品保持一致

from paddle.distribution.transformed_distribution import \
TransformedDistribution
from paddle.distribution.uniform import Uniform

__all__ = [ # noqa
'Beta',
Expand All @@ -33,4 +38,8 @@
'Uniform',
'kl_divergence',
'register_kl',
'Independent',
'TransformedDistribution'
]

__all__.extend(transform.__all__)
9 changes: 4 additions & 5 deletions python/paddle/distribution/beta.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,10 @@
import numbers

import paddle
from paddle.distribution import dirichlet, exponential_family

from .dirichlet import Dirichlet
from .exponential_family import ExponentialFamily


class Beta(ExponentialFamily):
class Beta(exponential_family.ExponentialFamily):
r"""
Beta distribution parameterized by alpha and beta.

Expand Down Expand Up @@ -93,7 +91,8 @@ def __init__(self, alpha, beta):

self.alpha, self.beta = paddle.broadcast_tensors([alpha, beta])

self._dirichlet = Dirichlet(paddle.stack([self.alpha, self.beta], -1))
self._dirichlet = dirichlet.Dirichlet(
paddle.stack([self.alpha, self.beta], -1))

super(Beta, self).__init__(self._dirichlet._batch_shape)

Expand Down
24 changes: 12 additions & 12 deletions python/paddle/distribution/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@
import numpy as np
import paddle
from paddle import _C_ops

from ..fluid import core
from ..fluid.data_feeder import (check_dtype, check_type,
check_variable_and_dtype, convert_dtype)
from ..fluid.framework import _non_static_mode
from ..fluid.layers import (control_flow, elementwise_add, elementwise_div,
elementwise_mul, elementwise_sub, nn, ops, tensor)
from ..tensor import arange, concat, gather_nd, multinomial
from .distribution import Distribution


class Categorical(Distribution):
from paddle.distribution import distribution
from paddle.fluid import core
from paddle.fluid.data_feeder import (check_dtype, check_type,
check_variable_and_dtype, convert_dtype)
from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div,
elementwise_mul, elementwise_sub, nn, ops,
tensor)
from paddle.tensor import arange, concat, gather_nd, multinomial


class Categorical(distribution.Distribution):
r"""
Categorical distribution is a discrete probability distribution that
describes the possible results of a random variable that can take on
Expand Down
53 changes: 53 additions & 0 deletions python/paddle/distribution/constraint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle


class Constraint(object):
"""Constraint condition for random variable.
"""

def __call__(self, value):
raise NotImplementedError


class Real(Constraint):
def __call__(self, value):
return value == value


class Range(Constraint):
def __init__(self, lower, upper):
self._lower = lower
self._upper = upper
super(Range, self).__init__()

def __call__(self, value):
return self._lower <= value <= self._upper


class Positive(Constraint):
def __call__(self, value):
return value >= 0.


class Simplex(Constraint):
def __call__(self, value):
return paddle.all(value >= 0, axis=-1) and (
(value.sum(-1) - 1).abs() < 1e-6)


real = Real()
positive = Positive()
simplex = Simplex()
11 changes: 5 additions & 6 deletions python/paddle/distribution/dirichlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@
# limitations under the License.

import paddle
from paddle.distribution import exponential_family
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper

from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.framework import _non_static_mode
from ..fluid.layer_helper import LayerHelper
from .exponential_family import ExponentialFamily


class Dirichlet(ExponentialFamily):
class Dirichlet(exponential_family.ExponentialFamily):
r"""
Dirichlet distribution with parameter "concentration".

Expand Down
32 changes: 23 additions & 9 deletions python/paddle/distribution/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,14 @@
import numpy as np
import paddle
from paddle import _C_ops

from ..fluid import core
from ..fluid.data_feeder import (check_dtype, check_type,
check_variable_and_dtype, convert_dtype)
from ..fluid.framework import _non_static_mode
from ..fluid.layers import (control_flow, elementwise_add, elementwise_div,
elementwise_mul, elementwise_sub, nn, ops, tensor)
from ..tensor import arange, concat, gather_nd, multinomial
from paddle.fluid import core
from paddle.fluid.data_feeder import (check_dtype, check_type,
check_variable_and_dtype, convert_dtype)
from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div,
elementwise_mul, elementwise_sub, nn, ops,
tensor)
from paddle.tensor import arange, concat, gather_nd, multinomial


class Distribution(object):
Expand Down Expand Up @@ -78,10 +78,24 @@ def event_shape(self):
"""
return self._event_shape

@property
def mean(self):
"""Mean of distribution"""
raise NotImplementedError

@property
def variance(self):
"""Variance of distribution"""
raise NotImplementedError

def sample(self, shape=()):
"""Sampling from the distribution."""
raise NotImplementedError

def rsample(self, shape=()):
"""reparameterized sample"""
raise NotImplementedError

def entropy(self):
"""The entropy of the distribution."""
raise NotImplementedError
Expand All @@ -96,7 +110,7 @@ def prob(self, value):
Args:
value (Tensor): value which will be evaluated
"""
raise NotImplementedError
return self.log_prob(value).exp()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we add some abstract methods in base class, e.g. mean, variance and rsample

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已添加,这样能保持不同子类持有的方法一致;存量 normal, uniform, categorical目前是NotImplementedError,按照设计文档计划后续统一更新

def log_prob(self, value):
"""Log probability density/mass function."""
Expand Down
7 changes: 3 additions & 4 deletions python/paddle/distribution/exponential_family.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@
# limitations under the License.

import paddle
from paddle.distribution import distribution
from paddle.fluid.framework import _non_static_mode, in_dygraph_mode

from ..fluid.framework import _non_static_mode
from .distribution import Distribution


class ExponentialFamily(Distribution):
class ExponentialFamily(distribution.Distribution):
r"""
ExponentialFamily is the base class for probability distributions belonging
to exponential family, whose probability mass/density function has the
Expand Down
92 changes: 92 additions & 0 deletions python/paddle/distribution/independent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle.distribution import distribution


class Independent(distribution.Distribution):
r"""
Reinterprets some of the batch dimensions of a distribution as event dimensions.

This is mainly useful for changing the shape of the result of
:meth:`log_prob`.

Args:
base (Distribution): The base distribution.
reinterpreted_batch_rank (int): The number of batch dimensions to
reinterpret as event dimensions.

Examples:

.. code-block:: python

import paddle
from paddle.distribution import independent

beta = paddle.distribution.Beta(paddle.to_tensor([0.5, 0.5]), paddle.to_tensor([0.5, 0.5]))
print(beta.batch_shape, beta.event_shape)
# (2,) ()
print(beta.log_prob(paddle.to_tensor(0.2)))
# Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [-0.22843921, -0.22843921])
reinterpreted_beta = independent.Independent(beta, 1)
print(reinterpreted_beta.batch_shape, reinterpreted_beta.event_shape)
# () (2,)
print(reinterpreted_beta.log_prob(paddle.to_tensor([0.2, 0.2])))
# Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [-0.45687842])
"""

def __init__(self, base, reinterpreted_batch_rank):
if not isinstance(base, distribution.Distribution):
raise TypeError(
f"Expected type of 'base' is Distribution, but got {type(base)}")
if not (0 < reinterpreted_batch_rank <= len(base.batch_shape)):
raise ValueError(
f"Expected 0 < reinterpreted_batch_rank <= {len(base.batch_shape)}, but got {reinterpreted_batch_rank}"
)
self._base = base
self._reinterpreted_batch_rank = reinterpreted_batch_rank

shape = base.batch_shape + base.event_shape
super(Independent, self).__init__(
batch_shape=shape[:len(base.batch_shape) -
reinterpreted_batch_rank],
event_shape=shape[len(base.batch_shape) -
reinterpreted_batch_rank:])

@property
def mean(self):
return self._base.mean

@property
def variance(self):
return self._base.variance

def sample(self, shape=()):
return self._base.sample(shape)

def log_prob(self, value):
return self._sum_rightmost(
self._base.log_prob(value), self._reinterpreted_batch_rank)

def prob(self, value):
return self.log_prob(value).exp()

def entropy(self):
return self._sum_rightmost(self._base.entropy(),
self._reinterpreted_batch_rank)

def _sum_rightmost(self, value, n):
return value.sum(list(range(-n, 0))) if n > 0 else value
18 changes: 8 additions & 10 deletions python/paddle/distribution/kl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,14 @@
import warnings

import paddle

from ..fluid.framework import _non_static_mode
from .beta import Beta
from .categorical import Categorical
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exponential_family import ExponentialFamily
from .normal import Normal
from .uniform import Uniform
from paddle.distribution.beta import Beta
from paddle.distribution.categorical import Categorical
from paddle.distribution.dirichlet import Dirichlet
from paddle.distribution.distribution import Distribution
from paddle.distribution.exponential_family import ExponentialFamily
from paddle.distribution.normal import Normal
from paddle.distribution.uniform import Uniform
from paddle.fluid.framework import _non_static_mode, in_dygraph_mode

__all__ = ["register_kl", "kl_divergence"]

Expand Down Expand Up @@ -207,5 +206,4 @@ def _kl_expfamily_expfamily(p, q):


def _sum_rightmost(value, n):
"""Sum elements along rightmost n dim"""
return value.sum(list(range(-n, 0))) if n > 0 else value
Loading