Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add build_model abstractmethod to ModelBuilder #142

Merged
merged 6 commits into from
Apr 19, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 51 additions & 16 deletions pymc_experimental/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@
import json
from abc import abstractmethod
from pathlib import Path
from typing import Dict, Union
from typing import Any, Dict, Union

import arviz as az
import numpy as np
import pandas as pd
import pymc as pm
from pymc.util import RandomState


class ModelBuilder:
Expand Down Expand Up @@ -100,7 +101,7 @@ def _data_setter(
@abstractmethod
def create_sample_input():
"""
Needs to be implemented by the user in the inherited class.
Needs to be implemented by the user in the child class.
Returns examples for data, model_config, sampler_config.
This is useful for understanding the required
data structures for the user model.
Expand All @@ -114,12 +115,15 @@ def create_sample_input():
>>> data = pd.DataFrame({'input': x, 'output': y})

>>> model_config = {
>>> 'a_loc': 7,
>>> 'a_scale': 3,
>>> 'b_loc': 5,
>>> 'b_scale': 3,
>>> 'obs_error': 2,
>>> }
>>> 'a' : {
>>> 'a_loc': 7,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this just be `loc`? The `a_` prefix seems redundant inside the `a` entry.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

true, good catch

>>> 'a_scale' : 3
>>> },
>>> 'b' : {
>>> 'b_loc': 3,
>>> 'b_scale': 5
>>> },
>>> 'obs_error': 2
>>> }

>>> sampler_config = {
>>> 'draws': 1_000,
Expand All @@ -132,6 +136,31 @@ def create_sample_input():

raise NotImplementedError

@abstractmethod
def build_model(
    self,
    model_data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]],
    model_config: Dict[str, Union[int, float, Dict]],
) -> None:
    """
    Needs to be implemented by the user in the child class.
    Creates an instance of pm.Model based on provided model_data and model_config, and
    attaches it to self.

    Parameters
    ----------
    model_data : dict
        Preformatted data that is going to be used in the model. For efficiency reasons it
        should contain only the necessary data columns, not the entire available dataset,
        since it's going to be encoded into data used to recreate the model.
    model_config : dict
        Dictionary where keys are strings representing names of parameters of the model and
        values are dictionaries of parameters needed for creating model parameters
        (see the example in create_sample_input).

    Returns
    -------
    None

    Raises
    ------
    NotImplementedError
        Always, unless the method is overridden in the child class.
    """
    # `self` is required: callers invoke this as `self.build_model(data, config)`,
    # so without it the base implementation would fail with a TypeError instead
    # of signalling that the method has to be overridden.
    raise NotImplementedError

def save(self, fname: str) -> None:
"""
Saves inference data of the model.
Expand Down Expand Up @@ -191,7 +220,7 @@ def load(cls, fname: str):
data=idata.fit_data.to_dataframe(),
)
model_builder.idata = idata
model_builder.build()
model_builder.idata = model_builder.fit()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if the model was already fit?

if model_builder.id != idata.attrs["id"]:
raise ValueError(
f"The file '{fname}' does not contain an inference data of the same model or configuration as '{cls._model_type}'"
Expand All @@ -200,7 +229,12 @@ def load(cls, fname: str):
return model_builder

def fit(
self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None
self,
progressbar: bool = True,
random_seed: RandomState = None,
data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
*args: Any,
**kwargs: Any,
) -> az.InferenceData:
"""
Fit a model using the data passed as a parameter.
Expand All @@ -227,20 +261,21 @@ def fit(
# If a new data was provided, assign it to the model
if data is not None:
self.data = data

self.build()
self._data_setter(data)

self.model_data, self.model_config, self.sampler_config = self.create_sample_input(
data=self.data
)
self.build_model(self.model_data, self.model_config)
self._data_setter(self.model_data)
with self.model:
self.idata = pm.sample(**self.sampler_config)
self.idata = pm.sample(**self.sampler_config, **kwargs)
self.idata.extend(pm.sample_prior_predictive())
self.idata.extend(pm.sample_posterior_predictive(self.idata))

self.idata.attrs["id"] = self.id
self.idata.attrs["model_type"] = self._model_type
self.idata.attrs["version"] = self.version
self.idata.attrs["sampler_config"] = json.dumps(self.sampler_config)
self.idata.attrs["model_config"] = json.dumps(self.model_config)
self.idata.attrs["model_config"] = json.dumps(self.serializable_model_config)
self.idata.add_groups(fit_data=self.data.to_xarray())
return self.idata

Expand Down
33 changes: 17 additions & 16 deletions pymc_experimental/tests/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,26 @@ class test_ModelBuilder(ModelBuilder):
_model_type = "LinearModel"
version = "0.1"

def build(self):

def build_model(self, model_data, model_config):
with pm.Model() as self.model:
if self.data is not None:
x = pm.MutableData("x", self.data["input"].values)
y_data = pm.MutableData("y_data", self.data["output"].values)
if model_data is not None:
x = pm.MutableData("x", model_data["input"].values)
y_data = pm.MutableData("y_data", model_data["output"].values)

# prior parameters
a_loc = self.model_config["a_loc"]
a_scale = self.model_config["a_scale"]
b_loc = self.model_config["b_loc"]
b_scale = self.model_config["b_scale"]
obs_error = self.model_config["obs_error"]
a_loc = model_config["a"]["loc"]
a_scale = model_config["a"]["scale"]
b_loc = model_config["b"]["loc"]
b_scale = model_config["b"]["scale"]
obs_error = model_config["obs_error"]

# priors
a = pm.Normal("a", a_loc, sigma=a_scale)
b = pm.Normal("b", b_loc, sigma=b_scale)
obs_error = pm.HalfNormal("σ_model_fmc", obs_error)

# observed data
if self.data is not None:
if model_data is not None:
y_model = pm.Normal("y_model", a + b * x, obs_error, shape=x.shape, observed=y_data)

def _data_setter(self, data: pd.DataFrame):
Expand All @@ -57,18 +56,20 @@ def _data_setter(self, data: pd.DataFrame):
if "output" in data.columns:
pm.set_data({"y_data": data["output"].values})

@property
def serializable_model_config(self):
    """Return the model configuration used when the config is dumped to JSON.

    This model's configuration is handed back unchanged (the very same object).
    """
    config = self.model_config
    return config

@classmethod
def create_sample_input(self):
def create_sample_input(self, data=None):
x = np.linspace(start=0, stop=1, num=100)
y = 5 * x + 3
y = y + np.random.normal(0, 1, len(x))
data = pd.DataFrame({"input": x, "output": y})

model_config = {
"a_loc": 0,
"a_scale": 10,
"b_loc": 0,
"b_scale": 10,
"a": {"loc": 0, "scale": 10},
"b": {"loc": 0, "scale": 10},
"obs_error": 2,
}

Expand Down