diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 72aa4ad121..9aefed3585 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -7,6 +7,8 @@ - Add `logit_p` keyword to `pm.Bernoulli`, so that users can specify the logit of the success probability. This is faster and more stable than using `p=tt.nnet.sigmoid(logit_p)`. +- Add `random` keyword to `pm.DensityDist` thus enabling users to pass custom random method + which in turn makes sampling from a `DensityDist` possible. ### Deprecations diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index c0627254fb..a8a17fb40f 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -178,12 +178,21 @@ def __init__(self, shape=(), dtype=None, defaults=('median', 'mean', 'mode'), class DensityDist(Distribution): """Distribution based on a given log density function.""" - def __init__(self, logp, shape=(), dtype=None, testval=0, *args, **kwargs): + def __init__(self, logp, shape=(), dtype=None, testval=0, random=None, *args, **kwargs): if dtype is None: dtype = theano.config.floatX super(DensityDist, self).__init__( shape, dtype, testval, *args, **kwargs) self.logp = logp + self.rand = random + + def random(self, *args, **kwargs): + if self.rand is not None: + return self.rand(*args, **kwargs) + else: + raise ValueError("Distribution was not passed any random method " + "Define a custom random method and pass it as kwarg random") + def draw_values(params, point=None): diff --git a/pymc3/tests/test_distributions_random.py b/pymc3/tests/test_distributions_random.py index 829709e01d..b622b9b65d 100644 --- a/pymc3/tests/test_distributions_random.py +++ b/pymc3/tests/test_distributions_random.py @@ -688,3 +688,47 @@ def ref_rand(size, w, mu, sd): 'sd': Domain([[1.5, 2., 3.]], edges=(None, None))}, size=1000, ref_rand=ref_rand) + + def test_density_dist(self): + def ref_rand(size, mu, sd): + return st.norm.rvs(size=size, loc=mu, scale=sd) + + class TestDensityDist(pm.DensityDist): + + def __init__(self, **kwargs): + norm_dist = pm.Normal.dist() + super(TestDensityDist, self).__init__(logp=norm_dist.logp, random=norm_dist.random) + + pymc3_random(TestDensityDist, {},ref_rand=ref_rand) + + def check_model_samplability(self): + model = pm.Model() + with model: + normal_dist = pm.Normal.dist() + density_dist = pm.DensityDist('density_dist', normal_dist.logp, random=normal_dist.random) + step = pm.Metropolis() + trace = pm.sample(100, step, tuning=0) + + try: + ppc = pm.sample_ppc(trace, samples=500, model=model, size=100) + if len(ppc) == 0: + npt.assert_true(len(ppc) == 0, 'length of ppc sample is zero') + except: + assert False + + def check_scipy_distributions(self): + model = pm.Model() + with model: + norm_dist_logp = st.norm.logpdf + norm_dist_random = np.random.normal + density_dist = pm.DensityDist('density_dist', normal_dist_logp, random=normal_dist_random) + step = pm.Metropolis() + trace = pm.sample(100, step, tuning=0) + + try: + ppc = pm.sample_ppc(trace, samples=500, model=model, size=100) + if len(ppc) == 0: + npt.assert_true(len(ppc) == 0, 'length of ppc sample is zero') + except: + assert False + \ No newline at end of file