-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add probability distribution transformation APIs #40536
Changes from all commits
567b9ea
bdee8c6
09354bf
1e3dd2c
6e86fc5
ffefb43
8a71e25
c8ef422
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import paddle | ||
|
||
|
||
class Constraint(object):
    """Base class for constraints on the support of a random variable.

    A constraint object is a callable; subclasses override ``__call__`` to
    report whether a given value satisfies the constraint.
    """

    def __call__(self, value):
        """Check whether ``value`` satisfies this constraint.

        Args:
            value: The value to validate.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError
|
||
|
||
class Real(Constraint):
    """Constraint accepting any real value.

    ``x == x`` is False exactly when ``x`` is NaN (elementwise for
    tensors), so this check passes every non-NaN value.
    """

    def __call__(self, value):
        # NaN is the only value for which ``x == x`` is False.
        return value == value
|
||
|
||
class Range(Constraint):
    """Constraint restricting a variable to the closed interval [lower, upper].

    Args:
        lower: Inclusive lower bound.
        upper: Inclusive upper bound.
    """

    def __init__(self, lower, upper):
        self._lower = lower
        self._upper = upper
        super(Range, self).__init__()

    def __call__(self, value):
        """Return whether ``value`` lies in ``[lower, upper]``.

        The two comparisons are combined with ``&`` rather than a chained
        comparison: chaining implies ``and``, which calls ``bool()`` on the
        first comparison result and raises for tensors with more than one
        element. ``&`` combines the results elementwise and behaves
        identically for plain Python scalars.
        """
        return (self._lower <= value) & (value <= self._upper)
|
||
|
||
class Positive(Constraint):
    """Constraint for non-negative values.

    NOTE(review): despite the name, the check is ``>= 0.`` and therefore
    also admits 0 — confirm whether strict positivity (``> 0.``) was
    intended before depending on the boundary behavior.
    """

    def __call__(self, value):
        return value >= 0.
|
||
|
||
class Simplex(Constraint):
    """Constraint for values lying on the probability simplex.

    A value satisfies the constraint when every entry along the last axis
    is non-negative and the entries along that axis sum to 1 (within a
    1e-6 tolerance).
    """

    def __call__(self, value):
        # Combine both conditions elementwise with ``paddle.logical_and``.
        # The previous ``and`` coerced the first tensor to ``bool()``, which
        # raises for batched (multi-element) results and, even in the scalar
        # case, returned one operand rather than the conjunction of both
        # conditions.
        return paddle.logical_and(
            paddle.all(value >= 0, axis=-1),
            (value.sum(-1) - 1).abs() < 1e-6)
|
||
|
||
# Shared singleton constraint instances; stateless, so one of each suffices.
real = Real()
positive = Positive()
simplex = Simplex()
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,14 +27,14 @@ | |
import numpy as np | ||
import paddle | ||
from paddle import _C_ops | ||
|
||
from ..fluid import core | ||
from ..fluid.data_feeder import (check_dtype, check_type, | ||
check_variable_and_dtype, convert_dtype) | ||
from ..fluid.framework import _non_static_mode | ||
from ..fluid.layers import (control_flow, elementwise_add, elementwise_div, | ||
elementwise_mul, elementwise_sub, nn, ops, tensor) | ||
from ..tensor import arange, concat, gather_nd, multinomial | ||
from paddle.fluid import core | ||
from paddle.fluid.data_feeder import (check_dtype, check_type, | ||
check_variable_and_dtype, convert_dtype) | ||
from paddle.fluid.framework import _non_static_mode, in_dygraph_mode | ||
from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div, | ||
elementwise_mul, elementwise_sub, nn, ops, | ||
tensor) | ||
from paddle.tensor import arange, concat, gather_nd, multinomial | ||
|
||
|
||
class Distribution(object): | ||
|
@@ -78,10 +78,24 @@ def event_shape(self): | |
""" | ||
return self._event_shape | ||
|
||
    @property
    def mean(self):
        """Mean of distribution.

        Raises:
            NotImplementedError: Subclasses must override.
        """
        raise NotImplementedError
|
||
    @property
    def variance(self):
        """Variance of distribution.

        Raises:
            NotImplementedError: Subclasses must override.
        """
        raise NotImplementedError
|
||
    def sample(self, shape=()):
        """Sampling from the distribution.

        Args:
            shape: Extra sample shape to draw; defaults to a single sample.

        Raises:
            NotImplementedError: Subclasses must override.
        """
        raise NotImplementedError
|
||
    def rsample(self, shape=()):
        """Reparameterized sample, allowing gradients to flow through
        the sampling step.

        Args:
            shape: Extra sample shape to draw; defaults to a single sample.

        Raises:
            NotImplementedError: Subclasses must override.
        """
        raise NotImplementedError
|
||
    def entropy(self):
        """The entropy of the distribution.

        Raises:
            NotImplementedError: Subclasses must override.
        """
        raise NotImplementedError
|
@@ -96,7 +110,7 @@ def prob(self, value): | |
Args: | ||
value (Tensor): value which will be evaluated | ||
""" | ||
raise NotImplementedError | ||
return self.log_prob(value).exp() | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shall we add some abstract methods in base class, e.g. mean, variance and rsample There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已添加,这样能保持不同子类持有的方法一致;存量 normal, uniform, categorical目前是 |
||
def log_prob(self, value): | ||
"""Log probability density/mass function.""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from paddle.distribution import distribution | ||
|
||
|
||
class Independent(distribution.Distribution):
    r"""
    Reinterprets some of the batch dimensions of a distribution as event dimensions.

    This is mainly useful for changing the shape of the result of
    :meth:`log_prob`.

    Args:
        base (Distribution): The base distribution.
        reinterpreted_batch_rank (int): The number of batch dimensions to
            reinterpret as event dimensions.

    Examples:

        .. code-block:: python

            import paddle
            from paddle.distribution import independent

            beta = paddle.distribution.Beta(paddle.to_tensor([0.5, 0.5]), paddle.to_tensor([0.5, 0.5]))
            print(beta.batch_shape, beta.event_shape)
            # (2,) ()
            print(beta.log_prob(paddle.to_tensor(0.2)))
            # Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [-0.22843921, -0.22843921])
            reinterpreted_beta = independent.Independent(beta, 1)
            print(reinterpreted_beta.batch_shape, reinterpreted_beta.event_shape)
            # () (2,)
            print(reinterpreted_beta.log_prob(paddle.to_tensor([0.2, 0.2])))
            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            #        [-0.45687842])
    """

    def __init__(self, base, reinterpreted_batch_rank):
        if not isinstance(base, distribution.Distribution):
            raise TypeError(
                f"Expected type of 'base' is Distribution, but got {type(base)}")
        if not (0 < reinterpreted_batch_rank <= len(base.batch_shape)):
            raise ValueError(
                f"Expected 0 < reinterpreted_batch_rank <= {len(base.batch_shape)}, but got {reinterpreted_batch_rank}"
            )
        self._base = base
        self._reinterpreted_batch_rank = reinterpreted_batch_rank

        # Move the trailing ``reinterpreted_batch_rank`` batch dims of the
        # base distribution into the event shape.
        shape = base.batch_shape + base.event_shape
        super(Independent, self).__init__(
            batch_shape=shape[:len(base.batch_shape) -
                              reinterpreted_batch_rank],
            event_shape=shape[len(base.batch_shape) -
                              reinterpreted_batch_rank:])

    @property
    def mean(self):
        return self._base.mean

    @property
    def variance(self):
        return self._base.variance

    def sample(self, shape=()):
        return self._base.sample(shape)

    def rsample(self, shape=()):
        # Forward to the base distribution, mirroring ``sample``; the
        # abstract ``Distribution`` interface declares ``rsample``, and
        # reinterpreting batch dims as event dims does not change how
        # samples are drawn.
        return self._base.rsample(shape)

    def log_prob(self, value):
        # Sum the base log-prob over the reinterpreted (now event) dims so
        # the result has the new, smaller batch shape.
        return self._sum_rightmost(
            self._base.log_prob(value), self._reinterpreted_batch_rank)

    def prob(self, value):
        return self.log_prob(value).exp()

    def entropy(self):
        return self._sum_rightmost(self._base.entropy(),
                                   self._reinterpreted_batch_rank)

    def _sum_rightmost(self, value, n):
        # Sum over the last ``n`` axes; identity when n == 0.
        return value.sum(list(range(-n, 0))) if n > 0 else value
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里需要import *的原因是什么呢?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
distribution/transform.py
文件中定义了需要公开API的__all__列表，在distribution/__init__
中用import *
全部导出,并添加到__init__.py
的__all__列表中, 可以通过paddle.distribution.xxx
访问,访问路径和竞品保持一致