Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added model_builder as directed in PR #6023 on pymc #64

Merged
merged 39 commits into from
Sep 13, 2022
Merged
Changes from 2 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
3c3dbda
added model_builder
5hv5hvnk Aug 9, 2022
9a2ca23
added explanation
5hv5hvnk Aug 28, 2022
a510ef4
Merge branch 'pymc-devs:main' into main
5hv5hvnk Sep 3, 2022
410617b
added tests
5hv5hvnk Sep 3, 2022
4e1d200
formatting
5hv5hvnk Sep 3, 2022
50cfc1c
change in save and load method
5hv5hvnk Sep 6, 2022
9dab7cc
updated save and load methods
5hv5hvnk Sep 6, 2022
d757d9c
fixed more errors
5hv5hvnk Sep 7, 2022
f6c0fab
fixed path variable in save and load
5hv5hvnk Sep 7, 2022
2e1e02f
Documentation
5hv5hvnk Sep 7, 2022
890a00f
Documentation
5hv5hvnk Sep 7, 2022
e1cd115
fixed formatting and tests
5hv5hvnk Sep 7, 2022
6a80412
fixed docstring
5hv5hvnk Sep 7, 2022
5912eed
fixed minor issues
5hv5hvnk Sep 8, 2022
dd5b25b
fixed minor issues
5hv5hvnk Sep 8, 2022
c43651b
Update pymc_experimental/model_builder.py
5hv5hvnk Sep 8, 2022
80c73b5
Update pymc_experimental/model_builder.py
5hv5hvnk Sep 8, 2022
8ab4228
fixed spelling errors
5hv5hvnk Sep 8, 2022
e272f1c
Merge branch 'main' of https://github.com/5hv5hvnk/pymc-experimental …
5hv5hvnk Sep 8, 2022
6ccfa1e
Update pymc_experimental/model_builder.py
5hv5hvnk Sep 8, 2022
43a06df
Update pymc_experimental/model_builder.py
twiecki Sep 8, 2022
6ce4838
Update pymc_experimental/model_builder.py
twiecki Sep 8, 2022
34f1aaf
added build method again
5hv5hvnk Sep 8, 2022
ad5fd70
Update pymc_experimental/tests/test_model_builder.py
twiecki Sep 9, 2022
2b42132
removed unecessary imports
5hv5hvnk Sep 9, 2022
8e0fd19
Merge branch 'pymc-devs:main' into main
5hv5hvnk Sep 10, 2022
d1467fe
changed arviz to az
5hv5hvnk Sep 10, 2022
0165fcb
fixed codes to pass build tests
5hv5hvnk Sep 10, 2022
b49199d
linespace -> linspace
twiecki Sep 11, 2022
63c4ed9
updated test_model_builder.py
5hv5hvnk Sep 11, 2022
fb51049
updated model_builder.py
5hv5hvnk Sep 11, 2022
0030a50
fixed overloading of test_fit()
5hv5hvnk Sep 11, 2022
e5f8e72
fixed indentation in docstring
5hv5hvnk Sep 11, 2022
b045f08
added some better examples
5hv5hvnk Sep 12, 2022
37f0bc3
fixed test.yml
5hv5hvnk Sep 12, 2022
63502a4
indetation
5hv5hvnk Sep 12, 2022
b4d1ca1
Apply suggestions from code review
twiecki Sep 13, 2022
70053e1
Update pymc_experimental/model_builder.py
twiecki Sep 13, 2022
c498b20
Update pymc_experimental/model_builder.py
twiecki Sep 13, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
260 changes: 260 additions & 0 deletions pymc_experimental/model_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
import hashlib
import pickle
from pathlib import Path
from typing import Dict, Union

import arviz
import arviz as az
import cloudpickle
import numpy as np
import pandas as pd
import pymc as pm

class ModelBuilder(pm.Model):
    """
    Extension of the ``pm.Model`` class to improve workflow.

    ModelBuilder can be used to play around with models with ease, using
    direct API calls for the tasks needed to deploy a model: building,
    fitting, predicting, saving and loading.

    Subclasses are expected to implement ``_build``, ``_data_setter`` and
    ``create_sample_input``; the base-class versions raise
    ``NotImplementedError``.
    """

    _model_type = 'BaseClass'
    version = 'None'

    def __init__(self, model_config: Dict, sampler_config: Dict):
        """
        Initialize the underlying ``pm.Model`` and store the configuration.

        Parameters
        ----------
        model_config: Dict
            Parameters for the model, e.g. prior parameters read by ``_build``.
        sampler_config: Dict
            Keyword arguments forwarded to ``pm.sample`` by ``fit``.
        """
        super().__init__()
        self.model_config = model_config  # parameters for priors etc.
        self.sample_config = sampler_config  # parameters for sampling
        self.idata = None  # inference data; assigned by fit() and load()

    def _build(self):
        """
        Needs to be implemented by the user in the inherited class.
        Builds user model. Requires suitable self.data and self.model_config.

        Example:
            def _build(self):
                # data
                x = pm.MutableData('x', self.data['input'].values)
                y_data = pm.MutableData('y_data', self.data['output'].values)

                # prior parameters
                a_loc = self.model_config['a_loc']
                a_scale = self.model_config['a_scale']
                b_loc = self.model_config['b_loc']
                b_scale = self.model_config['b_scale']
                obs_error = self.model_config['obs_error']

                # priors
                a = pm.Normal("a", a_loc, sigma=a_scale)
                b = pm.Normal("b", b_loc, sigma=b_scale)
                obs_error = pm.HalfNormal("σ_model_fmc", obs_error)

                # observed data
                y_model = pm.Normal('y_model', a + b * x, obs_error, observed=y_data)
        """
        raise NotImplementedError

    def _data_setter(
        self,
        data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]],
        x_only: bool = True,
    ):
        """
        Needs to be implemented by the user in the inherited class.
        Sets new data in the model.

        Parameters
        ----------
        data: Dictionary of string and either of numpy array, pandas dataframe or pandas Series
            The new data to set in the model.
        x_only: bool
            True if data only contains values of x and y is not present in the data.

        Example:
            def _data_setter(self, data : pd.DataFrame):
                with self.model:
                    pm.set_data({'x': data['input'].values})
                    try:  # if y values in new data
                        pm.set_data({'y_data': data['output'].values})
                    except KeyError:  # dummies otherwise
                        pm.set_data({'y_data': np.zeros(len(data))})
        """
        raise NotImplementedError

    @classmethod
    def create_sample_input(cls):
        """
        Needs to be implemented by the user in the inherited class.
        Returns examples for data, model_config, samples_config.
        This is useful for understanding the required
        data structures for the user model.
        """
        raise NotImplementedError

    def build(self):
        """Run the user-defined ``_build`` inside this model's context."""
        with self:
            self._build()

    def save(self, file_prefix, filepath, save_model=True, save_idata=True):
        """
        Saves the model as well as inference data of the model.

        The target names are formed by directly concatenating ``filepath`` and
        ``file_prefix`` (no separator is inserted), followed by ``.nc`` for the
        inference data and ``.pickle`` for the model.

        Parameters
        ----------
        file_prefix: string
            Name with which model and idata should be saved.
        filepath: string or Path
            Location at which model and idata should be saved. Should end with
            a path separator if it names a directory.
        save_model: bool
            Saves the model at given filepath with given file_prefix.
            Does not save the model if passed as False
        save_idata: bool
            Saves the idata at given filepath with given file_prefix.
            Does not save the idata if passed as False.
            Requires ``self.idata`` to be set (it is None until ``fit`` or
            ``load`` assigns it).
        """
        # str() so that both plain strings and Path objects can be concatenated
        target = str(filepath) + str(file_prefix)
        if save_idata:
            self.idata.to_netcdf(target + '.nc')
        if save_model:
            # context manager guarantees the handle is flushed and closed
            with open(target + '.pickle', 'wb') as pickle_file:
                pickle_file.write(cloudpickle.dumps(self))
        self.saved = True

    @classmethod
    def _load_pickled_model(cls, filename):
        """Unpickle ``filename`` and verify the result is an instance of ``cls``."""
        # SECURITY: unpickling can execute arbitrary code; only load files
        # that come from a trusted source.
        with open(filename, "rb") as pickle_file:
            model = pickle.load(pickle_file)
        if not isinstance(model, cls):
            raise ValueError(
                f"The route '{filename}' does not contain an object of the class '{cls.__name__}'"
            )
        return model

    def load_model(self, filename):
        """
        Loads the saved model from local system.
        Returns the unpickled model.

        Parameters
        ----------
        filename: string
            File name of saved model with its path if not present in current working directory.

        Raises
        ------
        ValueError
            If the file does not contain an instance of this object's class.
        """
        return type(self)._load_pickled_model(filename)

    @classmethod
    def load(cls, file_prefix, filepath, load_model=True, load_idata=True):
        """
        Loads the model and the idata saved for the model.

        File names are formed the same way as in ``save``: ``filepath`` and
        ``file_prefix`` are concatenated directly, followed by ``.pickle`` for
        the model and ``.nc`` for the inference data.

        Parameters
        ----------
        file_prefix: string
            Name with which model and idata were saved.
        filepath: string or Path
            Location from which model and idata should be loaded.
        load_model: bool
            Loads the model at given filepath with given file_prefix.
            Does not load the model if passed as False
        load_idata: bool
            Loads the idata at given filepath with given file_prefix.
            Does not load the idata if passed as False

        Returns
        -------
        The loaded model, with the loaded inference data attached as ``idata``
        when ``load_idata`` is True. If ``load_model`` is False there is no
        model to attach to, so the loaded inference data (or None) is returned.

        Raises
        ------
        ValueError
            If the pickle file does not contain an instance of ``cls``.
        """
        source = str(filepath) + str(file_prefix)
        model = cls._load_pickled_model(source + '.pickle') if load_model else None
        idata = az.from_netcdf(source + '.nc') if load_idata else None

        if model is None:
            return idata
        if idata is not None:
            model.idata = idata
        return model

    # fit and predict methods
    def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None):
        """
        Fits the model using the data that is passed as a parameter.
        Returns the inference data, which is also stored in ``self.idata``.

        Parameters
        ----------
        data: Dictionary of string and either of numpy array, pandas dataframe or pandas Series
            It is the data we need to train the model on. If None, whatever is
            already stored in ``self.data`` is used by ``_build``.
        """
        if data is not None:
            # NOTE(review): this only affects models that are built below; an
            # already-built model is not updated with the new data here.
            self.data = data

        if not self.basic_RVs:
            print('No model found, building model...')
            self.build()

        with self:
            self.idata = pm.sample(**self.sample_config)
            self.idata.extend(pm.sample_prior_predictive())
            self.idata.extend(pm.sample_posterior_predictive(self.idata))

        # metadata identifying the model that produced this inference data
        # NOTE(review): the two config entries are dicts; confirm that
        # to_netcdf can serialize dict-valued attrs before relying on save().
        self.idata.attrs['id'] = self.id()
        self.idata.attrs['model_type'] = self._model_type
        self.idata.attrs['version'] = self.version
        self.idata.attrs['sample_config'] = self.sample_config
        self.idata.attrs['model_config'] = self.model_config
        return self.idata

    def predict(
        self,
        data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
        point_estimate: bool = True,
    ):
        """
        Uses model to predict on unseen data and returns posterior prediction on the data.

        Parameters
        ----------
        data_prediction: Dictionary of string and either of numpy array, pandas dataframe or pandas Series
            It is the data we need to make prediction on using the model.
            If None, predictions are made on the data currently set in the model.
        point_estimate: bool
            If True, each predicted variable is averaged over the leading axis
            of its extracted samples; if False, the extracted samples are
            returned as they are.

        Returns
        -------
        Dict mapping each posterior-predictive variable name to a numpy array.
        """
        if data_prediction is not None:  # set new input data
            self._data_setter(data_prediction)

        try:
            with self:  # sample with new input data
                post_pred = pm.sample_posterior_predictive(self.idata.posterior)
        finally:
            # set back original data in model, even if sampling failed
            if data_prediction is not None:
                self._data_setter(self.data)

        # reshape output
        post_pred = self._extract_samples(post_pred)

        if point_estimate:  # average, if point-like estimate desired
            post_pred = {key: samples.mean(axis=0) for key, samples in post_pred.items()}

        return post_pred

    @staticmethod
    def _extract_samples(post_pred: az.InferenceData) -> Dict[str, np.ndarray]:
        """
        Returns dict of numpy arrays from InferenceData object

        Only the first entry along the leading axis of each variable is kept
        (presumably the first chain; confirm against the sampler output).

        Parameters
        ----------
        post_pred: arviz InferenceData object
        """
        predictive = post_pred.posterior_predictive
        return {key: predictive[key].to_numpy()[0] for key in predictive}

    def id(self):
        """
        Creates a hash value to match the model version, using the first 16
        characters of the SHA-256 hex digest of the model config values, the
        version, the model type and the sampler config values.
        """
        hasher = hashlib.sha256()
        hasher.update(str(self.model_config.values()).encode())
        hasher.update(self.version.encode())
        hasher.update(self._model_type.encode())
        hasher.update(str(self.sample_config.values()).encode())
        return hasher.hexdigest()[:16]