Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- Add `check_test_point` method to `pm.Model`
- Add `Ordered` Transformation and `OrderedLogistic` distribution
- Better warning message for `Mass matrix contains zeros on the diagonal. Some derivatives might always be zero`
- Save and load traces without `pickle` using `pm.save_trace` and `pm.load_trace`

### Fixes

Expand All @@ -33,7 +34,7 @@
- Plots of discrete distributions in the docstrings
- Add logitnormal distribution
- Densityplot: add support for discrete variables
- Fix the Binomial likelihood in `.glm.families.Binomial`, with the flexibility of specifying the `n`.
- Fix the Binomial likelihood in `.glm.families.Binomial`, with the flexibility of specifying the `n`.
- Add `offset` kwarg to `.glm`.
- Changed the `compare` function to accept a dictionary of model-trace pairs instead of two separate lists of models and traces.
- add test and support for creating multivariate mixture and mixture of mixtures
Expand Down Expand Up @@ -71,15 +72,15 @@
- Forestplot supports multiple traces (#2736)
- Add new plot, densityplot (#2741)
- DIC and BPIC calculations have been deprecated
- Refactor HMC and implemented new warning system (#2677, #2808)
- Refactor HMC and implemented new warning system (#2677, #2808)

### Fixes

- Fixed `compareplot` to use `loo` output.
- Improved `posteriorplot` to scale fonts
- `sample_ppc_w` now broadcasts
- `df_summary` function renamed to `summary`
- Add test for `model.logp_array` and `model.bijection` (#2724)
- Add test for `model.logp_array` and `model.bijection` (#2724)
- Fixed `sample_ppc` and `sample_ppc_w` to iterate all chains(#2633, #2748)
- Add Bayesian R2 score (for GLMs) `stats.r2_score` (#2696) and test (#2729).
- SMC works with transformed variables (#2755)
Expand Down
28 changes: 23 additions & 5 deletions pymc3/backends/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
import glob
import json
import os
import shutil

import numpy as np
from ..backends import base


def save_trace(trace, directory='.pymc3.trace'):
def save_trace(trace, directory=None, overwrite=False):
"""Save multitrace to file.

TODO: Also save warnings.
Expand All @@ -26,13 +27,28 @@ def save_trace(trace, directory='.pymc3.trace'):
trace to save to disk
directory : str (optional)
path to a directory to save the trace
overwrite : bool (default False)
whether to overwrite an existing directory.

Returns
-------
str, path to the directory where the trace was saved
"""
if not os.path.exists(directory):
os.makedirs(directory)
if directory is None:
directory = '.pymc_{}.trace'
idx = 1
while os.path.exists(directory.format(idx)):
idx += 1
directory = directory.format(idx)

if os.path.isdir(directory):
if overwrite:
shutil.rmtree(directory)
else:
raise OSError('Cautiously refusing to overwrite the already existing {}! Please supply '
'a different directory, or set `overwrite=True`'.format(directory))
os.makedirs(directory)

for chain, ndarray in trace._straces.items():
SerializeNDArray(os.path.join(directory, str(chain))).save(ndarray)
return directory
Expand Down Expand Up @@ -100,8 +116,10 @@ def save(self, ndarray):
if not isinstance(ndarray, NDArray):
raise TypeError('Can only save NDArray')

if not os.path.exists(self.directory):
os.mkdir(self.directory)
if os.path.isdir(self.directory):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we ask or alert before removing directories?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This scares me too, but my intention here is to treat self.directory as a single filesystem object that pymc3 owns. In this context, it is not unusual to overwrite without asking/alerting. I think this is fine, but I could see something ugly happening if you call pm.save_trace(trace, './'), intending to save a file to the current directory.

My best "safe" alternative that is not overly onerous is to write .pymc.manifest to the top level of the directory, with a list of all the files that pymc is allowed to delete in the future. Then this deletion will be just running over all the files in the existing manifest and deleting those. In that case, I would throw an exception then if any path already exists after deleting files in the manifest.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even more explicitly, by default, there would be a file ./.pymc3.trace/.pymc.manifest, whose contents are something like

0/samples.npz
0/metadata.json
1/samples.npz
1/metadata.json
...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively, we could do add overwrite=False flag that if False and the dir exists, we raise an exception that you have to set overwrite=True.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's reasonable, though I always appreciate programs where save "just works". what if i followed jupyter's lead and did .pymc3_1.trace, .pymc3_2.trace, etc. by default?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would imagine that this saving is done multiple times for a single model, so overwriting is probably a useful feature.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sorry, i agree with that! I am suggesting that the signature be

def save_trace(trace, filename=None, overwrite=False):

if filename is None:
    # find a unique filename to write to 
else:
    if filename exists:
        if overwrite is False:
            raise Exception
        else:
            # clear out directory

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, that sounds great!

shutil.rmtree(self.directory)

os.mkdir(self.directory)

with open(self.metadata_path, 'w') as buff:
json.dump(SerializeNDArray.to_metadata(ndarray), buff)
Expand Down
22 changes: 20 additions & 2 deletions pymc3/tests/test_ndarray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,27 @@ def setup_class(cls):
with TestSaveLoad.model():
cls.trace = pm.sample()

def test_save_new_model(self, tmpdir_factory):
directory = str(tmpdir_factory.mktemp('data'))
save_dir = pm.save_trace(self.trace, directory, overwrite=True)

assert save_dir == directory
with pm.Model() as model:
w = pm.Normal('w', 0, 1)
new_trace = pm.sample()

with pytest.raises(OSError):
_ = pm.save_trace(new_trace, directory)

_ = pm.save_trace(new_trace, directory, overwrite=True)
with model:
new_trace_copy = pm.load_trace(directory)

assert (new_trace['w'] == new_trace_copy['w']).all()

def test_save_and_load(self, tmpdir_factory):
directory = str(tmpdir_factory.mktemp('data'))
save_dir = pm.save_trace(self.trace, directory)
save_dir = pm.save_trace(self.trace, directory, overwrite=True)

assert save_dir == directory

Expand All @@ -194,7 +212,7 @@ def test_save_and_load(self, tmpdir_factory):

def test_sample_ppc(self, tmpdir_factory):
directory = str(tmpdir_factory.mktemp('data'))
save_dir = pm.save_trace(self.trace, directory)
save_dir = pm.save_trace(self.trace, directory, overwrite=True)

assert save_dir == directory

Expand Down