diff --git a/pymc3/__init__.py b/pymc3/__init__.py index 846ac91b26..79c2180ac0 100644 --- a/pymc3/__init__.py +++ b/pymc3/__init__.py @@ -20,6 +20,7 @@ from .diagnostics import * from .backends.tracetab import * +from .backends import save_trace, load_trace from .plots import * from .tests import test diff --git a/pymc3/backends/__init__.py b/pymc3/backends/__init__.py index 95256bbf99..e52049c10f 100644 --- a/pymc3/backends/__init__.py +++ b/pymc3/backends/__init__.py @@ -113,7 +113,7 @@ For specific examples, see pymc3.backends.{ndarray,text,sqlite}.py. """ -from ..backends.ndarray import NDArray +from ..backends.ndarray import NDArray, save_trace, load_trace from ..backends.text import Text from ..backends.sqlite import SQLite from ..backends.hdf5 import HDF5 diff --git a/pymc3/backends/ndarray.py b/pymc3/backends/ndarray.py index 700722a787..1b0b075389 100644 --- a/pymc3/backends/ndarray.py +++ b/pymc3/backends/ndarray.py @@ -2,10 +2,126 @@ Store sampling values in memory as a NumPy array. """ +import glob +import json +import os + import numpy as np from ..backends import base +def save_trace(trace, directory='.pymc3.trace'): + """Save multitrace to file. + + TODO: Also save warnings. + + This is a custom data format for PyMC3 traces. Each chain goes inside + a directory, and each directory contains a metadata json file, and a + numpy compressed file. See https://docs.scipy.org/doc/numpy/neps/npy-format.html + for more information about this format. + + Parameters + ---------- + trace : pm.MultiTrace + trace to save to disk + directory : str (optional) + path to a directory to save the trace + + Returns + ------- + str, path to the directory where the trace was saved + """ + if not os.path.exists(directory): + os.makedirs(directory) + for chain, ndarray in trace._straces.items(): + SerializeNDArray(os.path.join(directory, str(chain))).save(ndarray) + return directory + + +def load_trace(directory, model=None): + """Loads a multitrace that has been written to file. + + A the model used for the trace must be passed in, or the command + must be run in a model context. + + Parameters + ---------- + directory : str + Path to a pymc3 serialized trace + model : pm.Model (optional) + Model used to create the trace. Can also be inferred from context + + Returns + ------- + pm.Multitrace that was saved in the directory + """ + straces = [] + for directory in glob.glob(os.path.join(directory, '*')): + if os.path.isdir(directory): + straces.append(SerializeNDArray(directory).load(model)) + return base.MultiTrace(straces) + + +class SerializeNDArray(object): + metadata_file = 'metadata.json' + samples_file = 'samples.npz' + + def __init__(self, directory): + """Helper to save and load NDArray objects""" + self.directory = directory + self.metadata_path = os.path.join(self.directory, self.metadata_file) + self.samples_path = os.path.join(self.directory, self.samples_file) + + @staticmethod + def to_metadata(ndarray): + """Extract ndarray metadata into json-serializable content""" + if ndarray._stats is None: + stats = ndarray._stats + else: + stats = [] + for stat in ndarray._stats: + stats.append({key: value.tolist() for key, value in stat.items()}) + + metadata = { + 'draw_idx': ndarray.draw_idx, + 'draws': ndarray.draws, + '_stats': stats, + 'chain': ndarray.chain, + } + return metadata + + def save(self, ndarray): + """Serialize a ndarray to file + + The goal here is to be modestly safer and more portable than a + pickle file. The expense is that the model code must be available + to reload the multitrace. + """ + if not isinstance(ndarray, NDArray): + raise TypeError('Can only save NDArray') + + if not os.path.exists(self.directory): + os.mkdir(self.directory) + + with open(self.metadata_path, 'w') as buff: + json.dump(SerializeNDArray.to_metadata(ndarray), buff) + + np.savez_compressed(self.samples_path, **ndarray.samples) + + def load(self, model): + """Load the saved ndarray from file""" + new_trace = NDArray(model=model) + with open(self.metadata_path, 'r') as buff: + metadata = json.load(buff) + + metadata['_stats'] = [{k: np.array(v) for k, v in stat.items()} for stat in metadata['_stats']] + + for key, value in metadata.items(): + setattr(new_trace, key, value) + new_trace.samples = dict(np.load(self.samples_path)) + return new_trace + + class NDArray(base.BaseTrace): """NDArray trace object diff --git a/pymc3/tests/test_ndarray_backend.py b/pymc3/tests/test_ndarray_backend.py index ba6a80982e..8059959904 100644 --- a/pymc3/tests/test_ndarray_backend.py +++ b/pymc3/tests/test_ndarray_backend.py @@ -2,6 +2,7 @@ import numpy.testing as npt from pymc3.tests import backend_fixtures as bf from pymc3.backends import base, ndarray +import pymc3 as pm import pytest @@ -165,3 +166,48 @@ def test_combine_true_squeeze_true(self): expected = np.concatenate([self.x, self.y]) result = base._squeeze_cat([self.x, self.y], True, True) npt.assert_equal(result, expected) + +class TestSaveLoad(object): + @staticmethod + def model(): + with pm.Model() as model: + x = pm.Normal('x', 0, 1) + y = pm.Normal('y', x, 1, observed=2) + z = pm.Normal('z', x + y, 1) + return model + + @classmethod + def setup_class(cls): + with TestSaveLoad.model(): + cls.trace = pm.sample() + + def test_save_and_load(self, tmpdir_factory): + directory = str(tmpdir_factory.mktemp('data')) + save_dir = pm.save_trace(self.trace, directory) + + assert save_dir == directory + + trace2 = pm.load_trace(directory, model=TestSaveLoad.model()) + + for var in ('x', 'z'): + assert (self.trace[var] == trace2[var]).all() + + def test_sample_ppc(self, tmpdir_factory): + directory = str(tmpdir_factory.mktemp('data')) + save_dir = pm.save_trace(self.trace, directory) + + assert save_dir == directory + + seed = 10 + np.random.seed(seed) + with TestSaveLoad.model(): + ppc = pm.sample_ppc(self.trace) + + seed = 10 + np.random.seed(seed) + with TestSaveLoad.model(): + trace2 = pm.load_trace(directory) + ppc2 = pm.sample_ppc(trace2) + + for key, value in ppc.items(): + assert (value == ppc2[key]).all()