Skip to content

Commit

Permalink
Add build_model abstractmethod to ModelBuilder (#142)
Browse files Browse the repository at this point in the history
* adaptations to integrate with mmm

* adapted model_config and descriptions

* fixed ModuleNotFoundError from build

* small tweaks to make mmm tests work smoother

* new test for save, fit allows custom configs

* updating create_sample_input example
  • Loading branch information
michaelraczycki authored Apr 19, 2023
1 parent 5f1c2bb commit e38da06
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 37 deletions.
89 changes: 70 additions & 19 deletions pymc_experimental/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
import json
from abc import abstractmethod
from pathlib import Path
from typing import Dict, Union
from typing import Any, Dict, Union

import arviz as az
import numpy as np
import pandas as pd
import pymc as pm
from pymc.util import RandomState


class ModelBuilder:
Expand Down Expand Up @@ -100,7 +101,7 @@ def _data_setter(
@abstractmethod
def create_sample_input():
"""
Needs to be implemented by the user in the inherited class.
Needs to be implemented by the user in the child class.
Returns examples for data, model_config, sampler_config.
This is useful for understanding the required
data structures for the user model.
Expand All @@ -114,12 +115,15 @@ def create_sample_input():
>>> data = pd.DataFrame({'input': x, 'output': y})
>>> model_config = {
>>> 'a_loc': 7,
>>> 'a_scale': 3,
>>> 'b_loc': 5,
>>> 'b_scale': 3,
>>> 'obs_error': 2,
>>> }
>>> 'a' : {
>>> 'loc': 7,
>>> 'scale' : 3
>>> },
>>> 'b' : {
>>> 'loc': 3,
>>> 'scale': 5
>>> }
>>> 'obs_error': 2
>>> sampler_config = {
>>> 'draws': 1_000,
Expand All @@ -132,6 +136,31 @@ def create_sample_input():

raise NotImplementedError

@abstractmethod
def build_model(
    self,
    model_data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]],
    model_config: Dict[str, Union[int, float, Dict]],
) -> None:
    """
    Needs to be implemented by the user in the child class.
    Creates an instance of pm.Model based on provided model_data and model_config, and
    attaches it to self.

    Parameters
    ----------
    model_data : dict
        Preformatted data that is going to be used in the model. For efficiency reasons it
        should contain only the necessary data columns, not the entire available dataset,
        since it's going to be encoded into data used to recreate the model.
    model_config : dict
        Dictionary where keys are strings representing names of parameters of the model and
        values are dictionaries of parameters needed for creating model parameters
        (see the example in create_sample_input).

    Returns
    -------
    None

    Raises
    ------
    NotImplementedError
        Always, unless the method is overridden in the child class.
    """
    # `self` is required: the method is invoked on an instance with two
    # positional arguments, and the implementation attaches the model to self.
    raise NotImplementedError

def save(self, fname: str) -> None:
"""
Saves inference data of the model.
Expand All @@ -151,9 +180,11 @@ def save(self, fname: str) -> None:
>>> name = './mymodel.nc'
>>> model.save(name)
"""

file = Path(str(fname))
self.idata.to_netcdf(file)
if self.idata is not None and "fit_data" in self.idata:
file = Path(str(fname))
self.idata.to_netcdf(file)
else:
raise RuntimeError("The model hasn't been fit yet, call .fit() first")

@classmethod
def load(cls, fname: str):
Expand Down Expand Up @@ -191,7 +222,7 @@ def load(cls, fname: str):
data=idata.fit_data.to_dataframe(),
)
model_builder.idata = idata
model_builder.build()
model_builder.build_model(model_builder.data, model_builder.model_config)
if model_builder.id != idata.attrs["id"]:
raise ValueError(
f"The file '{fname}' does not contain an inference data of the same model or configuration as '{cls._model_type}'"
Expand All @@ -200,7 +231,12 @@ def load(cls, fname: str):
return model_builder

def fit(
self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None
self,
progressbar: bool = True,
random_seed: RandomState = None,
data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
*args: Any,
**kwargs: Any,
) -> az.InferenceData:
"""
Fit a model using the data passed as a parameter.
Expand All @@ -227,20 +263,22 @@ def fit(
# If a new data was provided, assign it to the model
if data is not None:
self.data = data

self.build()
self._data_setter(data)

self.model_data, model_config, sampler_config = self.create_sample_input(data=self.data)
if self.model_config is None:
self.model_config = model_config
if self.sampler_config is None:
self.sampler_config = sampler_config
self.build_model(self.model_data, self.model_config)
with self.model:
self.idata = pm.sample(**self.sampler_config)
self.idata = pm.sample(**self.sampler_config, **kwargs)
self.idata.extend(pm.sample_prior_predictive())
self.idata.extend(pm.sample_posterior_predictive(self.idata))

self.idata.attrs["id"] = self.id
self.idata.attrs["model_type"] = self._model_type
self.idata.attrs["version"] = self.version
self.idata.attrs["sampler_config"] = json.dumps(self.sampler_config)
self.idata.attrs["model_config"] = json.dumps(self.model_config)
self.idata.attrs["model_config"] = json.dumps(self._serializable_model_config)
self.idata.add_groups(fit_data=self.data.to_xarray())
return self.idata

Expand Down Expand Up @@ -351,6 +389,19 @@ def _extract_samples(post_pred: az.data.inference_data.InferenceData) -> Dict[st

return post_pred_dict

@property
@abstractmethod
def _serializable_model_config(self) -> Dict[str, Union[int, float, Dict]]:
    """
    Needs to be implemented by the user in the child class.
    Converts non-serializable values from model_config to their serializable, reversible
    equivalents. Data types like pandas DataFrame, Series or datetime aren't JSON
    serializable, so in order to save the model they need to be formatted.

    Returns
    -------
    model_config: dict

    Raises
    ------
    NotImplementedError
        Always, unless the property is overridden in the child class.
    """
    # Raise like the other abstract hooks: the enclosing class is declared
    # without abc.ABC, so @abstractmethod alone does not block instantiation,
    # and a missing override would otherwise silently yield None.
    raise NotImplementedError

@property
def id(self) -> str:
"""
Expand Down
46 changes: 28 additions & 18 deletions pymc_experimental/tests/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,26 @@ class test_ModelBuilder(ModelBuilder):
_model_type = "LinearModel"
version = "0.1"

def build(self):

def build_model(self, model_data, model_config):
with pm.Model() as self.model:
if self.data is not None:
x = pm.MutableData("x", self.data["input"].values)
y_data = pm.MutableData("y_data", self.data["output"].values)
if model_data is not None:
x = pm.MutableData("x", model_data["input"].values)
y_data = pm.MutableData("y_data", model_data["output"].values)

# prior parameters
a_loc = self.model_config["a_loc"]
a_scale = self.model_config["a_scale"]
b_loc = self.model_config["b_loc"]
b_scale = self.model_config["b_scale"]
obs_error = self.model_config["obs_error"]
a_loc = model_config["a"]["loc"]
a_scale = model_config["a"]["scale"]
b_loc = model_config["b"]["loc"]
b_scale = model_config["b"]["scale"]
obs_error = model_config["obs_error"]

# priors
a = pm.Normal("a", a_loc, sigma=a_scale)
b = pm.Normal("b", b_loc, sigma=b_scale)
obs_error = pm.HalfNormal("σ_model_fmc", obs_error)

# observed data
if self.data is not None:
if model_data is not None:
y_model = pm.Normal("y_model", a + b * x, obs_error, shape=x.shape, observed=y_data)

def _data_setter(self, data: pd.DataFrame):
Expand All @@ -57,18 +56,20 @@ def _data_setter(self, data: pd.DataFrame):
if "output" in data.columns:
pm.set_data({"y_data": data["output"].values})

@property
def _serializable_model_config(self):
    """Return self.model_config as-is, with no conversion applied."""
    config = self.model_config
    return config

@classmethod
def create_sample_input(self):
def create_sample_input(self, data=None):
x = np.linspace(start=0, stop=1, num=100)
y = 5 * x + 3
y = y + np.random.normal(0, 1, len(x))
data = pd.DataFrame({"input": x, "output": y})

model_config = {
"a_loc": 0,
"a_scale": 10,
"b_loc": 0,
"b_scale": 10,
"a": {"loc": 0, "scale": 10},
"b": {"loc": 0, "scale": 10},
"obs_error": 2,
}

Expand All @@ -94,7 +95,16 @@ def initial_build_and_fit(check_idata=True) -> ModelBuilder:
return model_builder


def test_empty_model_config():
def test_save_without_fit_raises_runtime_error():
    """Calling save() on a model that was never fitted must raise RuntimeError."""
    sample_data, sample_model_config, sample_sampler_config = test_ModelBuilder.create_sample_input()
    unfitted_builder = test_ModelBuilder(
        data=sample_data,
        model_config=sample_model_config,
        sampler_config=sample_sampler_config,
    )
    # No fit() call on purpose: save() is expected to refuse to write anything.
    with pytest.raises(RuntimeError):
        unfitted_builder.save("saved_model")


def test_empty_sampler_config_fit():
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
sampler_config = {}
model_builder = test_ModelBuilder(
Expand All @@ -105,7 +115,7 @@ def test_empty_model_config():
assert "posterior" in model_builder.idata.groups()


def test_empty_model_config():
def test_empty_model_config_fit():
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
model_config = {}
model_builder = test_ModelBuilder(
Expand Down

0 comments on commit e38da06

Please sign in to comment.