Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added model_builder as directed in PR #6023 on pymc #64

Merged
merged 39 commits into from
Sep 13, 2022
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
3c3dbda
added model_builder
5hv5hvnk Aug 9, 2022
9a2ca23
added explanation
5hv5hvnk Aug 28, 2022
a510ef4
Merge branch 'pymc-devs:main' into main
5hv5hvnk Sep 3, 2022
410617b
added tests
5hv5hvnk Sep 3, 2022
4e1d200
formatting
5hv5hvnk Sep 3, 2022
50cfc1c
change in save and load method
5hv5hvnk Sep 6, 2022
9dab7cc
updated save and load methods
5hv5hvnk Sep 6, 2022
d757d9c
fixed more errors
5hv5hvnk Sep 7, 2022
f6c0fab
fixed path variable in save and load
5hv5hvnk Sep 7, 2022
2e1e02f
Documentation
5hv5hvnk Sep 7, 2022
890a00f
Documentation
5hv5hvnk Sep 7, 2022
e1cd115
fixed formatting and tests
5hv5hvnk Sep 7, 2022
6a80412
fixed docstring
5hv5hvnk Sep 7, 2022
5912eed
fixed minor issues
5hv5hvnk Sep 8, 2022
dd5b25b
fixed minor issues
5hv5hvnk Sep 8, 2022
c43651b
Update pymc_experimental/model_builder.py
5hv5hvnk Sep 8, 2022
80c73b5
Update pymc_experimental/model_builder.py
5hv5hvnk Sep 8, 2022
8ab4228
fixed spelling errors
5hv5hvnk Sep 8, 2022
e272f1c
Merge branch 'main' of https://github.com/5hv5hvnk/pymc-experimental …
5hv5hvnk Sep 8, 2022
6ccfa1e
Update pymc_experimental/model_builder.py
5hv5hvnk Sep 8, 2022
43a06df
Update pymc_experimental/model_builder.py
twiecki Sep 8, 2022
6ce4838
Update pymc_experimental/model_builder.py
twiecki Sep 8, 2022
34f1aaf
added build method again
5hv5hvnk Sep 8, 2022
ad5fd70
Update pymc_experimental/tests/test_model_builder.py
twiecki Sep 9, 2022
2b42132
removed unecessary imports
5hv5hvnk Sep 9, 2022
8e0fd19
Merge branch 'pymc-devs:main' into main
5hv5hvnk Sep 10, 2022
d1467fe
changed arviz to az
5hv5hvnk Sep 10, 2022
0165fcb
fixed codes to pass build tests
5hv5hvnk Sep 10, 2022
b49199d
linespace -> linspace
twiecki Sep 11, 2022
63c4ed9
updated test_model_builder.py
5hv5hvnk Sep 11, 2022
fb51049
updated model_builder.py
5hv5hvnk Sep 11, 2022
0030a50
fixed overloading of test_fit()
5hv5hvnk Sep 11, 2022
e5f8e72
fixed indentation in docstring
5hv5hvnk Sep 11, 2022
b045f08
added some better examples
5hv5hvnk Sep 12, 2022
37f0bc3
fixed test.yml
5hv5hvnk Sep 12, 2022
63502a4
indetation
5hv5hvnk Sep 12, 2022
b4d1ca1
Apply suggestions from code review
twiecki Sep 13, 2022
70053e1
Update pymc_experimental/model_builder.py
twiecki Sep 13, 2022
c498b20
Update pymc_experimental/model_builder.py
twiecki Sep 13, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ jobs:
# The ">-" in the next line replaces newlines with spaces (see https://stackoverflow.com/a/66809682).
run: >-
conda activate pymc-test-py37 &&
python -m pytest -vv --cov=pymc_experimental --doctest-modules pymc_experimental --cov-append --cov-report=xml --cov-report term --durations=50 %TEST_SUBSET%
python -m pytest -vv --cov=pymc_experimental --cov-append --cov-report=xml --cov-report term --durations=50 %TEST_SUBSET%
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v2
with:
Expand Down
309 changes: 309 additions & 0 deletions pymc_experimental/model_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,309 @@
import hashlib
from pathlib import Path
from typing import Dict, Union

import arviz as az
import numpy as np
import pandas as pd
import pymc as pm


class ModelBuilder(pm.Model):
    """
    Extension of ``pm.Model`` that wraps a common modelling workflow.

    A subclass describes the model once (``build_model``), how new data is
    swapped in (``_data_setter``) and, optionally, example inputs
    (``create_sample_input``). This base class then offers ``fit``,
    ``predict``, ``save`` and ``load`` on top of those hooks.
    """

    # Both values are mixed into the hash returned by ``id`` and are stored in
    # the attrs of the inference data produced by ``fit``. Subclasses are
    # expected to override them.
    _model_type = "BaseClass"
    version = "None"

    def __init__(
        self,
        model_config: Dict,
        sampler_config: Dict,
        data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
    ):
        """
        Store the model and sampler configuration and build the model.

        Parameters
        ----------
        model_config : Dictionary
            Parameters that initialise the model configuration (priors etc.).
            Handed to ``build_model``.
        sampler_config : Dictionary
            Keyword arguments forwarded to ``pm.sample`` by ``fit``.
        data : Dictionary
            The data the model is trained on. Handed to ``build_model``.

        Examples
        --------
        >>> class LinearModel(ModelBuilder):
        ...     ...
        >>> model = LinearModel(model_config, sampler_config)
        """

        super().__init__()
        self.model_config = model_config  # parameters for priors etc.
        # NOTE: the attribute is called ``sample_config`` although the
        # argument is ``sampler_config``; the name is kept for compatibility.
        self.sample_config = sampler_config  # parameters for sampling
        self.idata = None  # inference data object, populated by ``fit``
        self.data = data
        self.build()

    def build(self):
        """
        Build the defined model.

        Enters this model's context and delegates to ``build_model`` with the
        stored ``model_config`` and ``data``.
        """

        with self:
            self.build_model(self.model_config, self.data)

    def build_model(self, model_config, data=None):
        """
        Define the model's variables. Must be overridden by the subclass.

        ``build`` calls this inside the model context with
        ``self.model_config`` and ``self.data``.

        Parameters
        ----------
        model_config : Dictionary
            The configuration that was passed to ``__init__``.
        data : Dictionary
            The data currently stored on the model (may be ``None``).

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """

        raise NotImplementedError

    def _data_setter(
        self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]], x_only: bool = True
    ):
        """
        Set new data in the model. Must be overridden by the subclass.

        ``fit`` and ``predict`` call this with a single positional argument
        whenever they are given data.

        Parameters
        ----------
        data : Dictionary of string and either of numpy array, pandas dataframe or pandas Series
            The new data to place into the model.
        x_only : bool
            Whether ``data`` only contains values of x, i.e. y is not present.

        Raises
        ------
        NotImplementedError
            Always, in this base class.

        Examples
        --------
        >>> def _data_setter(self, data: pd.DataFrame):
        ...     with self.model:
        ...         pm.set_data({'x': data['input'].values})
        ...         try:  # if y values in new data
        ...             pm.set_data({'y_data': data['output'].values})
        ...         except:  # dummies otherwise
        ...             pm.set_data({'y_data': np.zeros(len(data))})
        """

        raise NotImplementedError

    @classmethod
    def create_sample_input(cls):
        """
        Return example ``data``, ``model_config`` and ``sampler_config``.

        Needs to be implemented by the user in the inherited class. It is
        useful for understanding the data structures the user model requires.

        Raises
        ------
        NotImplementedError
            Always, in this base class.

        Examples
        --------
        >>> @classmethod
        ... def create_sample_input(cls):
        ...     x = np.linspace(start=1, stop=50, num=100)
        ...     y = 5 * x + 3 + np.random.normal(0, 1, len(x)) * np.random.rand(100)*10 + np.random.rand(100)*6.4
        ...     data = pd.DataFrame({'input': x, 'output': y})
        ...
        ...     model_config = {
        ...         'a_loc': 7,
        ...         'a_scale': 3,
        ...         'b_loc': 5,
        ...         'b_scale': 3,
        ...         'obs_error': 2,
        ...     }
        ...
        ...     sampler_config = {
        ...         'draws': 1_000,
        ...         'tune': 1_000,
        ...         'chains': 1,
        ...         'target_accept': 0.95,
        ...     }
        ...     return data, model_config, sampler_config
        """

        raise NotImplementedError

    def save(self, fname):
        """
        Save the inference data of the model to a netCDF file.

        Parameters
        ----------
        fname : string
            The name, with path, of the file the inference data is written to.

        Raises
        ------
        RuntimeError
            If the model has not been fit yet, so there is nothing to save.

        Examples
        --------
        >>> class LinearModel(ModelBuilder):
        ...     ...
        >>> data, model_config, sampler_config = LinearModel.create_sample_input()
        >>> model = LinearModel(model_config, sampler_config)
        >>> idata = model.fit(data)
        >>> name = './mymodel.nc'
        >>> model.save(name)
        """

        # Without this guard an unfitted model fails with an opaque
        # "'NoneType' object has no attribute 'to_netcdf'".
        if self.idata is None:
            raise RuntimeError("The model hasn't been fit yet, call .fit() first")

        self.idata.to_netcdf(Path(str(fname)))

    @classmethod
    def load(cls, fname):
        """
        Load inference data that was written by ``save``.

        Parameters
        ----------
        fname : string
            The name, with path, of the file the inference data is read from.

        Returns
        -------
        The inference data loaded from the local file system. Note that this
        is the inference data only, not a rebuilt model instance.

        Examples
        --------
        >>> class LinearModel(ModelBuilder):
        ...     ...
        >>> name = './mymodel.nc'
        >>> imported_idata = LinearModel.load(name)
        """

        # NOTE: checking that the stored attrs["id"] matches this model is
        # intentionally not done yet, because attrs are not reliably persisted
        # in netCDF files; see https://github.com/arviz-devs/arviz/issues/2109
        filepath = Path(str(fname))
        return az.from_netcdf(filepath)

    def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None):
        """
        Fit the model, optionally on new data, and return the inference data.

        Samples the posterior with the stored sampler configuration, extends
        the result with prior and posterior predictive samples, and records
        identifying attrs on the inference data.

        Parameters
        ----------
        data : Dictionary of string and either of numpy array, pandas dataframe or pandas Series
            The data to train the model on. When given, it replaces the stored
            data and is pushed into the model through ``_data_setter``.

        Returns
        -------
        The inference data of the fitted model (also stored as ``self.idata``).

        Examples
        --------
        >>> data, model_config, sampler_config = LinearModel.create_sample_input()
        >>> model = LinearModel(model_config, sampler_config)
        >>> idata = model.fit(data)
        Auto-assigning NUTS sampler...
        Initializing NUTS using jitter+adapt_diag...
        """

        if data is not None:
            self.data = data
            self._data_setter(data)

        # Rebuild if the model currently holds no random variables.
        if not self.basic_RVs:
            self.build()

        with self:
            self.idata = pm.sample(**self.sample_config)
            self.idata.extend(pm.sample_prior_predictive())
            self.idata.extend(pm.sample_posterior_predictive(self.idata))

        attrs = self.idata.attrs
        attrs["id"] = self.id()
        attrs["model_type"] = self._model_type
        attrs["version"] = self.version
        # Key was previously misspelled as "sample_conifg".
        attrs["sample_config"] = self.sample_config
        attrs["model_config"] = self.model_config
        return self.idata

    def predict(
        self,
        data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
        point_estimate: bool = True,
    ):
        """
        Use the fitted model to predict on unseen data.

        Parameters
        ----------
        data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series
            The data to make predictions for. When given, it is pushed into
            the model through ``_data_setter`` before sampling.
        point_estimate : bool
            If True, return the mean over the extracted samples for each
            variable instead of the samples themselves.

        Returns
        -------
        Dictionary mapping each posterior predictive variable name to a numpy
        array of samples, or to their mean when ``point_estimate`` is True.

        Raises
        ------
        RuntimeError
            If the model has not been fit yet.

        Examples
        --------
        >>> data, model_config, sampler_config = LinearModel.create_sample_input()
        >>> model = LinearModel(model_config, sampler_config)
        >>> idata = model.fit(data)
        >>> x_pred = []
        >>> prediction_data = pd.DataFrame({'input': x_pred})
        >>> # only point estimate
        >>> pred_mean = model.predict(prediction_data)
        >>> # samples
        >>> pred_samples = model.predict(prediction_data, point_estimate=False)
        """

        # Posterior predictive sampling needs the trace produced by ``fit``.
        if self.idata is None:
            raise RuntimeError("The model hasn't been fit yet, call .fit() first")

        if data_prediction is not None:  # set new input data
            self._data_setter(data_prediction)

        with self:  # sample with new input data
            post_pred = pm.sample_posterior_predictive(self.idata)

        # reshape output
        post_pred = self._extract_samples(post_pred)
        if point_estimate:  # average, if point-like estimate desired
            post_pred = {key: samples.mean(axis=0) for key, samples in post_pred.items()}

        return post_pred

    @staticmethod
    def _extract_samples(post_pred: az.data.inference_data.InferenceData) -> Dict[str, np.ndarray]:
        """
        Extract the posterior predictive samples as plain numpy arrays.

        Parameters
        ----------
        post_pred : arviz InferenceData object
            Must contain a ``posterior_predictive`` group.

        Returns
        -------
        Dictionary of numpy arrays, one per variable in the group.

        Notes
        -----
        Only index 0 along the leading axis of each variable is kept.
        NOTE(review): that axis is presumably the chain dimension, in which
        case samples from any additional chains are dropped; confirm whether
        this is intended for multi-chain sampler configurations.
        """

        return {
            key: post_pred.posterior_predictive[key].to_numpy()[0]
            for key in post_pred.posterior_predictive
        }

    def id(self):
        """
        Create a hash value identifying this model configuration.

        The hash covers the model config values, the version, the model type
        and the sampler config values, in that order. Only the dictionaries'
        values are hashed, not their keys.

        Returns
        -------
        String holding the first 16 hexadecimal characters of the SHA-256
        digest.
        """

        hasher = hashlib.sha256()
        hasher.update(str(self.model_config.values()).encode())
        hasher.update(self.version.encode())
        hasher.update(self._model_type.encode())
        hasher.update(str(self.sample_config.values()).encode())
        return hasher.hexdigest()[:16]
Loading