diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 89c457eff4..634e5e9ef5 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -9,6 +9,7 @@ - Add `check_test_point` method to `pm.Model` - Add `Ordered` Transformation and `OrderedLogistic` distribution - Better warning message for `Mass matrix contains zeros on the diagonal. Some derivatives might always be zero` +- Save and load traces without `pickle` using `pm.save_trace` and `pm.load_trace` ### Fixes @@ -33,7 +34,7 @@ - Plots of discrete distributions in the docstrings - Add logitnormal distribution - Densityplot: add support for discrete variables -- Fix the Binomial likelihood in `.glm.families.Binomial`, with the flexibility of specifying the `n`. +- Fix the Binomial likelihood in `.glm.families.Binomial`, with the flexibility of specifying the `n`. - Add `offset` kwarg to `.glm`. - Changed the `compare` function to accept a dictionary of model-trace pairs instead of two separate lists of models and traces. - add test and support for creating multivariate mixture and mixture of mixtures @@ -71,7 +72,7 @@ - Forestplot supports multiple traces (#2736) - Add new plot, densityplot (#2741) - DIC and BPIC calculations have been deprecated -- Refactor HMC and implemented new warning system (#2677, #2808) +- Refactor HMC and implemented new warning system (#2677, #2808) ### Fixes @@ -79,7 +80,7 @@ - Improved `posteriorplot` to scale fonts - `sample_ppc_w` now broadcasts - `df_summary` function renamed to `summary` -- Add test for `model.logp_array` and `model.bijection` (#2724) +- Add test for `model.logp_array` and `model.bijection` (#2724) - Fixed `sample_ppc` and `sample_ppc_w` to iterate all chains(#2633, #2748) - Add Bayesian R2 score (for GLMs) `stats.r2_score` (#2696) and test (#2729). - SMC works with transformed variables (#2755) diff --git a/pymc3/backends/ndarray.py b/pymc3/backends/ndarray.py index c6621ef8fe..1c57bb02dc 100644 --- a/pymc3/backends/ndarray.py +++ b/pymc3/backends/ndarray.py @@ -5,12 +5,13 @@ import glob import json import os +import shutil import numpy as np from ..backends import base -def save_trace(trace, directory='.pymc3.trace'): +def save_trace(trace, directory=None, overwrite=False): """Save multitrace to file. TODO: Also save warnings. @@ -26,13 +27,28 @@ def save_trace(trace, directory='.pymc3.trace'): trace to save to disk directory : str (optional) path to a directory to save the trace + overwrite : bool (default False) + whether to overwrite an existing directory. Returns ------- str, path to the directory where the trace was saved """ - if not os.path.exists(directory): - os.makedirs(directory) + if directory is None: + directory = '.pymc_{}.trace' + idx = 1 + while os.path.exists(directory.format(idx)): + idx += 1 + directory = directory.format(idx) + + if os.path.isdir(directory): + if overwrite: + shutil.rmtree(directory) + else: + raise OSError('Cautiously refusing to overwrite the already existing {}! Please supply ' + 'a different directory, or set `overwrite=True`'.format(directory)) + os.makedirs(directory) + for chain, ndarray in trace._straces.items(): SerializeNDArray(os.path.join(directory, str(chain))).save(ndarray) return directory @@ -100,8 +116,10 @@ def save(self, ndarray): if not isinstance(ndarray, NDArray): raise TypeError('Can only save NDArray') - if not os.path.exists(self.directory): - os.mkdir(self.directory) + if os.path.isdir(self.directory): + shutil.rmtree(self.directory) + + os.mkdir(self.directory) with open(self.metadata_path, 'w') as buff: json.dump(SerializeNDArray.to_metadata(ndarray), buff) diff --git a/pymc3/tests/test_ndarray_backend.py b/pymc3/tests/test_ndarray_backend.py index 8059959904..455f503642 100644 --- a/pymc3/tests/test_ndarray_backend.py +++ b/pymc3/tests/test_ndarray_backend.py @@ -181,9 +181,27 @@ def setup_class(cls): with TestSaveLoad.model(): cls.trace = pm.sample() + def test_save_new_model(self, tmpdir_factory): + directory = str(tmpdir_factory.mktemp('data')) + save_dir = pm.save_trace(self.trace, directory, overwrite=True) + + assert save_dir == directory + with pm.Model() as model: + w = pm.Normal('w', 0, 1) + new_trace = pm.sample() + + with pytest.raises(OSError): + _ = pm.save_trace(new_trace, directory) + + _ = pm.save_trace(new_trace, directory, overwrite=True) + with model: + new_trace_copy = pm.load_trace(directory) + + assert (new_trace['w'] == new_trace_copy['w']).all() + def test_save_and_load(self, tmpdir_factory): directory = str(tmpdir_factory.mktemp('data')) - save_dir = pm.save_trace(self.trace, directory) + save_dir = pm.save_trace(self.trace, directory, overwrite=True) assert save_dir == directory @@ -194,7 +212,7 @@ def test_save_and_load(self, tmpdir_factory): def test_sample_ppc(self, tmpdir_factory): directory = str(tmpdir_factory.mktemp('data')) - save_dir = pm.save_trace(self.trace, directory) + save_dir = pm.save_trace(self.trace, directory, overwrite=True) assert save_dir == directory