Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pymc3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from .diagnostics import *
from .backends.tracetab import *
from .backends import save_trace, load_trace

from .plots import *
from .tests import test
Expand Down
2 changes: 1 addition & 1 deletion pymc3/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@

For specific examples, see pymc3.backends.{ndarray,text,sqlite}.py.
"""
from ..backends.ndarray import NDArray
from ..backends.ndarray import NDArray, save_trace, load_trace
from ..backends.text import Text
from ..backends.sqlite import SQLite
from ..backends.hdf5 import HDF5
Expand Down
116 changes: 116 additions & 0 deletions pymc3/backends/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,126 @@

Store sampling values in memory as a NumPy array.
"""
import glob
import json
import os

import numpy as np
from ..backends import base


def save_trace(trace, directory='.pymc3.trace'):
    """Save multitrace to file.

    TODO: Also save warnings.

    This is a custom data format for PyMC3 traces. Each chain goes inside
    a directory, and each directory contains a metadata json file, and a
    numpy compressed file. See https://docs.scipy.org/doc/numpy/neps/npy-format.html
    for more information about this format.

    Parameters
    ----------
    trace : pm.MultiTrace
        trace to save to disk
    directory : str (optional)
        path to a directory to save the trace

    Returns
    -------
    str, path to the directory where the trace was saved
    """
    if not os.path.exists(directory):
        os.makedirs(directory)
    # One subdirectory per chain, named by its integer chain index.
    for chain_id, strace in trace._straces.items():
        serializer = SerializeNDArray(os.path.join(directory, str(chain_id)))
        serializer.save(strace)
    return directory


def load_trace(directory, model=None):
    """Loads a multitrace that has been written to file.

    The model used for the trace must be passed in, or the command
    must be run in a model context.

    Parameters
    ----------
    directory : str
        Path to a pymc3 serialized trace
    model : pm.Model (optional)
        Model used to create the trace. Can also be inferred from context

    Returns
    -------
    pm.Multitrace that was saved in the directory
    """
    straces = []
    # Each chain lives in its own subdirectory; iterate them without
    # shadowing the `directory` parameter (the original loop rebound it).
    for subdir in glob.glob(os.path.join(directory, '*')):
        if os.path.isdir(subdir):
            straces.append(SerializeNDArray(subdir).load(model))
    return base.MultiTrace(straces)


class SerializeNDArray(object):
    """Helper to save and load NDArray trace objects to/from a directory.

    The directory holds two files: a json metadata file and a numpy
    compressed archive containing the sampled values.
    """

    metadata_file = 'metadata.json'
    samples_file = 'samples.npz'

    def __init__(self, directory):
        """Helper to save and load NDArray objects

        Parameters
        ----------
        directory : str
            Directory in which the metadata and samples files live
        """
        self.directory = directory
        self.metadata_path = os.path.join(self.directory, self.metadata_file)
        self.samples_path = os.path.join(self.directory, self.samples_file)

    @staticmethod
    def to_metadata(ndarray):
        """Extract ndarray metadata into json-serializable content

        Parameters
        ----------
        ndarray : NDArray
            trace whose bookkeeping attributes should be extracted

        Returns
        -------
        dict that json.dump can serialize directly
        """
        if ndarray._stats is None:
            stats = ndarray._stats
        else:
            # json cannot serialize numpy arrays, so convert every stat
            # value to a plain (nested) list first.
            stats = []
            for stat in ndarray._stats:
                stats.append({key: value.tolist() for key, value in stat.items()})

        metadata = {
            'draw_idx': ndarray.draw_idx,
            'draws': ndarray.draws,
            '_stats': stats,
            'chain': ndarray.chain,
        }
        return metadata

    def save(self, ndarray):
        """Serialize a ndarray to file

        The goal here is to be modestly safer and more portable than a
        pickle file. The expense is that the model code must be available
        to reload the multitrace.

        Parameters
        ----------
        ndarray : NDArray
            trace to write to ``self.directory``

        Raises
        ------
        TypeError if `ndarray` is not an NDArray instance
        """
        if not isinstance(ndarray, NDArray):
            raise TypeError('Can only save NDArray')

        # makedirs (rather than mkdir) so that saving into a nested,
        # not-yet-existing directory works when this class is used
        # standalone rather than via save_trace.
        if not os.path.exists(self.directory):
            os.makedirs(self.directory)

        with open(self.metadata_path, 'w') as buff:
            json.dump(SerializeNDArray.to_metadata(ndarray), buff)

        np.savez_compressed(self.samples_path, **ndarray.samples)

    def load(self, model):
        """Load the saved ndarray from file

        Parameters
        ----------
        model : pm.Model
            model used to create the trace; passed through to NDArray

        Returns
        -------
        NDArray with metadata and samples restored from disk
        """
        new_trace = NDArray(model=model)
        with open(self.metadata_path, 'r') as buff:
            metadata = json.load(buff)

        # json stored the stats as plain lists; restore numpy arrays.
        metadata['_stats'] = [{k: np.array(v) for k, v in stat.items()} for stat in metadata['_stats']]

        for key, value in metadata.items():
            setattr(new_trace, key, value)
        new_trace.samples = dict(np.load(self.samples_path))
        return new_trace


class NDArray(base.BaseTrace):
"""NDArray trace object

Expand Down
46 changes: 46 additions & 0 deletions pymc3/tests/test_ndarray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy.testing as npt
from pymc3.tests import backend_fixtures as bf
from pymc3.backends import base, ndarray
import pymc3 as pm
import pytest


Expand Down Expand Up @@ -165,3 +166,48 @@ def test_combine_true_squeeze_true(self):
expected = np.concatenate([self.x, self.y])
result = base._squeeze_cat([self.x, self.y], True, True)
npt.assert_equal(result, expected)

class TestSaveLoad(object):
    @staticmethod
    def model():
        """Build the small reference model shared by all tests."""
        with pm.Model() as model:
            x = pm.Normal('x', 0, 1)
            y = pm.Normal('y', x, 1, observed=2)
            z = pm.Normal('z', x + y, 1)
        return model

    @classmethod
    def setup_class(cls):
        """Sample one reference trace for the whole test class."""
        with TestSaveLoad.model():
            cls.trace = pm.sample()

    def test_save_and_load(self, tmpdir_factory):
        path = str(tmpdir_factory.mktemp('data'))
        saved_to = pm.save_trace(self.trace, path)

        assert saved_to == path

        reloaded = pm.load_trace(path, model=TestSaveLoad.model())

        # Round-tripping through disk must preserve the sampled values.
        for name in ('x', 'z'):
            assert (self.trace[name] == reloaded[name]).all()

    def test_sample_ppc(self, tmpdir_factory):
        path = str(tmpdir_factory.mktemp('data'))
        saved_to = pm.save_trace(self.trace, path)

        assert saved_to == path

        seed = 10
        np.random.seed(seed)
        with TestSaveLoad.model():
            ppc = pm.sample_ppc(self.trace)

        seed = 10
        np.random.seed(seed)
        with TestSaveLoad.model():
            reloaded = pm.load_trace(path)
            ppc_reloaded = pm.sample_ppc(reloaded)

        # With identical seeds the PPC draws must match exactly.
        for key, value in ppc.items():
            assert (value == ppc_reloaded[key]).all()