Skip to content

Commit

Permalink
Upgrade the PolyaGamma class to v4 design
Browse files Browse the repository at this point in the history
  • Loading branch information
zoj613 committed Mar 25, 2021
1 parent c8bc440 commit 041246a
Showing 1 changed file with 48 additions and 43 deletions.
91 changes: 48 additions & 43 deletions pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
beta,
cauchy,
exponential,
RandomVariable,
gamma,
halfcauchy,
halfnormal,
Expand Down Expand Up @@ -4242,6 +4243,26 @@ def logcdf(self, value):
)


class PolyaGammaRV(RandomVariable):
"""Polya-Gamma random variable."""

name = "polyagamma"
ndim_supp = 0
ndims_params = [0, 0]
dtype = "floatX"
# inplace = True
_print_name = ("PG", "\\operatorname{PG}")

@classmethod
def rng_fn(cls, rng, h, z, size=None):
# handle the kind of rng passed to the sampler
bg = rng._bit_generator if isinstance(rng, np.random.RandomState) else rng
return random_polyagamma(h, z, size=size, random_state=bg)


polyagamma = PolyaGammaRV()


class PolyaGamma(PositiveContinuous):
r"""
The Polya-Gamma distribution.
Expand All @@ -4266,23 +4287,22 @@ class PolyaGamma(PositiveContinuous):
import matplotlib.pyplot as plt
import numpy as np
from polyagamma import random_polyagamma
import scipy.stats as st
import seaborn as sns
plt.style.use('seaborn-darkgrid')
hs = [1., 5., 10., 15.]
zs = [0., 1., 5., 10.]
density = st.kde.gaussian_kde(data)
for h, z in zip(mus, zs):
x = np.arange(np.min(data), np.max(data), .1)
plt.plot(x, density(x), label=r'h = {}, z = {}'.format(h, z))
data = {}
for h, z in zip(hs, zs):
data[f'h = {h}, z = {z}'] = random_polyagamma(h, z, size=25000)
sns.kdeplot(data=data)
plt.xlabel('x', fontsize=12)
plt.ylabel('f(x)', fontsize=12)
plt.legend(loc=1)
plt.show()
======== =============================
Support :math:`x \in (0, \infty)`
Mean :math:`\dfrac{tanh(z/2)h}{2z}`
Variance :math:`\dfrac{h(sinh(z) - z)(1 - tanh^2(z/2))}{4z^3}`
Mean :math:`dfrac{h}{4} if :math:`z=0`, :math:`\dfrac{tanh(z/2)h}{2z}` otherwise.
Variance :math:`0.041666688h` if :math:`z=0`, :math:`\dfrac{h(sinh(z) - z)(1 - tanh^2(z/2))}{4z^3}` otherwise.
======== =============================
Parameters
Expand All @@ -4292,6 +4312,18 @@ class PolyaGamma(PositiveContinuous):
z: float, optional
The exponential tilting parameter of the distribution.
Examples
--------
.. code-block:: python
rng = np.random.default_rng()
with pm.Model():
x = pm.PolyaGamma('x', h=1, z=5.5)
with pm.Model():
x = pm.PolyaGamma('x', h=25, z=-2.3, rng=rng, size=(100, 5))
References
----------
.. [1] Polson, Nicholas G., James G. Scott, and Jesse Windle.
Expand All @@ -4306,42 +4338,15 @@ class PolyaGamma(PositiveContinuous):
Volume 79, Issue 21, (2009): 2251-2259.
"""
rv_op = polyagamma

def __init__(self, h=1.0, z=0.0, *args, **kwargs):
self.h = aet.as_tensor_variable(floatX(h))
self.z = aet.as_tensor_variable(floatX(z))
@classmethod
def dist(cls, h=1.0, z=0.0, rng=None, size=None, **kwargs):
hh = aet.as_tensor_variable(floatX(h))
zz = aet.as_tensor_variable(floatX(z))

msg = f"The variable {h} specified for PolyaGamma has non-positive "
msg = f"The variable {hh} specified for PolyaGamma has non-positive "
msg += "values, making it unsuitable for this parameter."
Assert(msg)(h, aet.all(aet.gt(h, 0.00001)))

z_zero = aet.eq(self.z, 0)
x = aet.tanh(0.5 * self.z)
self.mean = aet.switch(z_zero, 0.25 * self.h, 0.5 * self.h * x / self.z)
self.variance = aet.switch(
z_zero,
0.041666688 * self.h,
0.25 * h * (aet.sinh(self.z) - self.z) * (1 - x * x) / (self.z * self.z * self.z),
)
super().__init__(*args, **kwargs)

def _random(self, h, z, size=None):
return random_polyagamma(h, z, size=size, disable_checks=True, random_state=self.rng)

def random(self, point=None, size=None):
"""
Draw random values from PolyaGamma distribution.
Parameters
----------
point: dict, optional
Dict of variable values on which random values are to be
conditioned (uses default point if not specified).
size: int, optional
Desired size of random sample (returns one sample if not
specified).
Assert(msg)(hh, aet.all(aet.gt(hh, 0.00001)))

Returns
-------
array
"""
return super().dist([h, z], size=size, rng=rng, **kwargs)

0 comments on commit 041246a

Please sign in to comment.