Skip to content

Commit

Permalink
Use dill to serialize logp functions in DensityDist (#4053)
Browse files Browse the repository at this point in the history
* Use dill to serialize logp functions in DensityDist

* Update testenv based on yml file on travis

* Explicitly test pickling and unpickling of DensityDist

* Improve release notes

* Use conda activate in create testenv
  • Loading branch information
aseyboldt authored Aug 16, 2020
1 parent aaafa8d commit e08ad0c
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 3 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### Maintenance
- Mentioned the way to do any random walk with `theano.tensor.cumsum()` in `GaussianRandomWalk` docstrings (see [#4048](https://github.com/pymc-devs/pymc3/pull/4048)).
- Fixed numerical instability in ExGaussian's logp by preventing `logpow` from returning `-inf` (see [#4050](https://github.com/pymc-devs/pymc3/pull/4050)).
- Use dill to serialize user defined logp functions in `DensityDist`. The previous serialization code fails if it is used in notebooks on Windows and Mac. `dill` is now a required dependency. (see [#3844](https://github.com/pymc-devs/pymc3/issues/3844)).

### Documentation

Expand Down
2 changes: 1 addition & 1 deletion environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ dependencies:
- dataclasses # python_version < 3.7
- contextvars # python_version < 3.7
- mkl-service
- dill
- libblas=*=*mkl
- pip:
- black_nbconvert
- dill
14 changes: 14 additions & 0 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import numbers
import contextvars
import dill
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Callable
Expand Down Expand Up @@ -419,6 +420,19 @@ def __init__(
self.wrap_random_with_dist_shape = wrap_random_with_dist_shape
self.check_shape_in_random = check_shape_in_random

def __getstate__(self):
# We use dill to serialize the logp function, as this is almost
# always defined in the notebook and won't be pickled correctly.
# Fix https://github.com/pymc-devs/pymc3/issues/3844
logp = dill.dumps(self.logp)
vals = self.__dict__.copy()
vals['logp'] = logp
return vals

def __setstate__(self, vals):
vals['logp'] = dill.loads(vals['logp'])
self.__dict__ = vals

def random(self, point=None, size=None, **kwargs):
if self.rand is not None:
not_broadcast_kwargs = dict(point=point)
Expand Down
14 changes: 14 additions & 0 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@

from ..distributions import continuous
from pymc3.theanof import floatX
import pymc3 as pm
from numpy import array, inf, log, exp
from numpy.testing import assert_almost_equal, assert_allclose, assert_equal
import numpy.random as nr
Expand Down Expand Up @@ -1872,3 +1873,16 @@ def test_issue_3051(self, dims, dist_cls, kwargs):
assert isinstance(actual_a, np.ndarray)
assert actual_a.shape == (X.shape[0],)
pass


def test_serialize_density_dist():
def func(x):
return -2 * (x ** 2).sum()

with pm.Model():
pm.Normal('x')
y = pm.DensityDist('y', func)
pm.sample(draws=5, tune=1, mp_ctx="spawn")

import pickle
pickle.loads(pickle.dumps(y))
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ sphinx-autobuild==0.7.1
sphinx>=1.5.5
watermark
parameterized
dill
dill
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ h5py>=2.7.0
typing-extensions>=3.7.4
dataclasses; python_version < '3.7'
contextvars; python_version < '3.7'
dill
5 changes: 4 additions & 1 deletion scripts/create_testenv.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,18 @@ command -v conda >/dev/null 2>&1 || {
ENVNAME="${ENVNAME:-testenv}" # if no ENVNAME is specified, use testenv

if [ -z ${GLOBAL} ]; then
source $(dirname $(dirname $(which conda)))/etc/profile.d/conda.sh
if conda env list | grep -q ${ENVNAME}; then
echo "Environment ${ENVNAME} already exists, keeping up to date"
conda activate ${ENVNAME}
mamba env update -f environment-dev.yml
else
conda config --add channels conda-forge
conda config --set channel_priority strict
conda install -c conda-forge mamba --yes
mamba env create -f environment-dev.yml
conda activate ${ENVNAME}
fi
source activate ${ENVNAME}
fi

# Install editable using the setup.py
Expand Down

0 comments on commit e08ad0c

Please sign in to comment.