
Add probability distribution transformation APIs #40536

Merged: 8 commits merged into PaddlePaddle:develop on Mar 31, 2022

Conversation

@cxxly (Contributor) commented Mar 14, 2022

PR types

New features

PR changes

APIs

Describe

Adds 13 transformation APIs and 2 distribution APIs:

new transformation APIs:

  1. Transform
  2. AbsTransform
  3. AffineTransform
  4. ChainTransform
  5. ExpTransform
  6. IndependentTransform
  7. PowerTransform
  8. ReshapeTransform
  9. SigmoidTransform
  10. SoftmaxTransform
  11. StackTransform
  12. StickBreakingTransform
  13. TanhTransform

new distribution APIs:

  1. Independent
  2. TransformedDistribution

Examples:

import paddle

x = paddle.to_tensor([1., 2.])
affine = paddle.distribution.AffineTransform(paddle.to_tensor(0.), paddle.to_tensor(1.))

print(affine.forward(x))
# Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
#        [1., 2.])
print(affine.inverse(affine.forward(x)))
# Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
#        [1., 2.])
print(affine.forward_log_det_jacobian(x))
# Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
#        [0.])
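
To show how the two new distribution APIs compose with the transforms, here is a minimal sketch (constructor signatures are assumed from the API names above, not quoted from the merged code):

import paddle

# Base distribution: a standard normal.
base = paddle.distribution.Normal(paddle.to_tensor(0.), paddle.to_tensor(1.))

# y = exp(0. + 1. * x): a log-normal assembled by chaining transforms.
chain = paddle.distribution.ChainTransform((
    paddle.distribution.AffineTransform(paddle.to_tensor(0.), paddle.to_tensor(1.)),
    paddle.distribution.ExpTransform(),
))

# TransformedDistribution pushes base samples through the transform chain.
d = paddle.distribution.TransformedDistribution(base, [chain])

s = d.sample([4])     # four samples from the transformed distribution
print(d.log_prob(s))  # log density under the transformed distribution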

@paddle-bot-old (bot) commented Mar 14, 2022

✅ This PR's description meets the template requirements!
Please wait for other CI results.

@paddle-bot-old (bot) commented:

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@cxxly changed the title from "Add probability distribution transformation APIs for Paddle" to "Add probability distribution transformation APIs" on Mar 29, 2022
"""
BIJECTION = 'bijection' # bijective(injective and surjective)
INJECTION = 'injection' # injective-only
SURJECTION = 'surjection' # surjective-inly


typo: "surjective-inly" should be "surjective-only"

@cxxly (Contributor, Author) replied Mar 30, 2022

updated

iclementine previously approved these changes Mar 30, 2022
TCChenlong previously approved these changes Mar 30, 2022
@TCChenlong (Contributor) left a comment

LGTM
TODO: Add Chinese documentation

print(affine.forward(x))
# Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
# [1., 2.])
print(affine.inverse(power.forward(x)))

"power" means paddle.distribution.PowerTransform?

@cxxly (Contributor, Author) replied:

Updated. It was a paste error: an old code example had been pasted.

@jeff41404 (Contributor) commented:
In the Examples under Describe above, print(affine.inverse(power.forward(x))): does "power" mean paddle.distribution.PowerTransform, or should it be affine?

@cxxly (Contributor, Author) commented Mar 30, 2022

import paddle

x = paddle.to_tensor([1., 2.])
affine = paddle.distribution.AffineTransform(paddle.to_tensor(0.), paddle.to_tensor(1.))

print(affine.forward(x))
# Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
#        [1., 2.])
print(affine.inverse(affine.forward(x)))
# Tensor(shape=[2], dtype=float32, place=Place(gpu:0), stop_gradient=True,
#        [1., 2.])
print(affine.forward_log_det_jacobian(x))
# Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True,
#        [0.])

In the Examples under Describe above, print(affine.inverse(power.forward(x))): does "power" mean paddle.distribution.PowerTransform, or should it be affine?

Yes, it should be "affine"; it was a paste error. Both the documentation and the PR description have been updated.

@@ -96,7 +96,7 @@ def prob(self, value):
Args:
value (Tensor): value which will be evaluated
"""
raise NotImplementedError
return self.log_prob(value).exp()


Shall we add some abstract methods to the base class, e.g. mean, variance, and rsample?

@cxxly (Contributor, Author) replied:

Added. This keeps the methods consistent across subclasses. The existing normal, uniform, and categorical distributions currently raise NotImplementedError; per the design doc, they will be updated uniformly in a follow-up.
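
As a rough illustration of the shape such a base class could take (a sketch following the reviewer's suggestion; the exact signatures are assumptions, not the merged code):

class Distribution:
    # Subclasses override these members; the base class only fixes the interface.

    @property
    def mean(self):
        # Mean of the distribution.
        raise NotImplementedError

    @property
    def variance(self):
        # Variance of the distribution.
        raise NotImplementedError

    def rsample(self, shape=()):
        # Reparameterized sample, enabling pathwise gradients.
        raise NotImplementedError

    def log_prob(self, value):
        raise NotImplementedError

    def prob(self, value):
        # Density/mass derived from log_prob by default, as in the diff above.
        return self.log_prob(value).exp()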

@jeff41404 (Contributor) left a comment

LGTM

@XiaoguangHu01 (Contributor) left a comment

LGTM

from paddle.distribution.kl import kl_divergence, register_kl
from paddle.distribution.multinomial import Multinomial
from paddle.distribution.normal import Normal
from paddle.distribution.transform import * # noqa: F403

Why is import * needed here?

@cxxly (Contributor, Author) replied Mar 31, 2022

distribution/transform.py defines an __all__ list of the APIs to be made public. import * in distribution/__init__.py exports them all, and they are then added to the __all__ list of __init__.py, so they can be accessed as paddle.distribution.xxx, keeping the access path consistent with competing libraries.
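
A sketch of the export pattern described here (the symbol lists are illustrative, not the exact Paddle sources):

# distribution/transform.py -- every public API is named in __all__.
__all__ = [
    'Transform',
    'AbsTransform',
    'AffineTransform',
    # ... remaining transforms ...
]

# distribution/__init__.py -- re-export what transform.py marks public and
# extend the package-level __all__, so paddle.distribution.AffineTransform
# (and the rest) resolve without importing the submodule explicitly.
from paddle.distribution.transform import *  # noqa: F403
from paddle.distribution import transform

__all__ = ['kl_divergence', 'register_kl', 'Multinomial', 'Normal']
__all__.extend(transform.__all__)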

@iclementine merged commit 6735a37 into PaddlePaddle:develop on Mar 31, 2022