Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added model_builder as directed in PR #6023 on pymc #64

Merged
merged 39 commits into from
Sep 13, 2022
Merged
Changes from 1 commit
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
3c3dbda
added model_builder
5hv5hvnk Aug 9, 2022
9a2ca23
added explanation
5hv5hvnk Aug 28, 2022
a510ef4
Merge branch 'pymc-devs:main' into main
5hv5hvnk Sep 3, 2022
410617b
added tests
5hv5hvnk Sep 3, 2022
4e1d200
formatting
5hv5hvnk Sep 3, 2022
50cfc1c
change in save and load method
5hv5hvnk Sep 6, 2022
9dab7cc
updated save and load methods
5hv5hvnk Sep 6, 2022
d757d9c
fixed more errors
5hv5hvnk Sep 7, 2022
f6c0fab
fixed path variable in save and load
5hv5hvnk Sep 7, 2022
2e1e02f
Documentation
5hv5hvnk Sep 7, 2022
890a00f
Documentation
5hv5hvnk Sep 7, 2022
e1cd115
fixed formatting and tests
5hv5hvnk Sep 7, 2022
6a80412
fixed docstring
5hv5hvnk Sep 7, 2022
5912eed
fixed minor issues
5hv5hvnk Sep 8, 2022
dd5b25b
fixed minor issues
5hv5hvnk Sep 8, 2022
c43651b
Update pymc_experimental/model_builder.py
5hv5hvnk Sep 8, 2022
80c73b5
Update pymc_experimental/model_builder.py
5hv5hvnk Sep 8, 2022
8ab4228
fixed spelling errors
5hv5hvnk Sep 8, 2022
e272f1c
Merge branch 'main' of https://github.com/5hv5hvnk/pymc-experimental …
5hv5hvnk Sep 8, 2022
6ccfa1e
Update pymc_experimental/model_builder.py
5hv5hvnk Sep 8, 2022
43a06df
Update pymc_experimental/model_builder.py
twiecki Sep 8, 2022
6ce4838
Update pymc_experimental/model_builder.py
twiecki Sep 8, 2022
34f1aaf
added build method again
5hv5hvnk Sep 8, 2022
ad5fd70
Update pymc_experimental/tests/test_model_builder.py
twiecki Sep 9, 2022
2b42132
removed unecessary imports
5hv5hvnk Sep 9, 2022
8e0fd19
Merge branch 'pymc-devs:main' into main
5hv5hvnk Sep 10, 2022
d1467fe
changed arviz to az
5hv5hvnk Sep 10, 2022
0165fcb
fixed codes to pass build tests
5hv5hvnk Sep 10, 2022
b49199d
linespace -> linspace
twiecki Sep 11, 2022
63c4ed9
updated test_model_builder.py
5hv5hvnk Sep 11, 2022
fb51049
updated model_builder.py
5hv5hvnk Sep 11, 2022
0030a50
fixed overloading of test_fit()
5hv5hvnk Sep 11, 2022
e5f8e72
fixed indentation in docstring
5hv5hvnk Sep 11, 2022
b045f08
added some better examples
5hv5hvnk Sep 12, 2022
37f0bc3
fixed test.yml
5hv5hvnk Sep 12, 2022
63502a4
indetation
5hv5hvnk Sep 12, 2022
b4d1ca1
Apply suggestions from code review
twiecki Sep 13, 2022
70053e1
Update pymc_experimental/model_builder.py
twiecki Sep 13, 2022
c498b20
Update pymc_experimental/model_builder.py
twiecki Sep 13, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 112 additions & 19 deletions pymc_experimental/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,33 @@


class ModelBuilder(pm.Model):
"""
"""
5hv5hvnk marked this conversation as resolved.
Show resolved Hide resolved
Extension of pm.Model class to improve workflow.

ModelBuilder class can be used to play around with models with ease, using direct API calls
for the multiple tasks that one needs to perform to deploy a model.

Example:

"""
twiecki marked this conversation as resolved.
Show resolved Hide resolved

_model_type = "BaseClass"
version = "None"

def __init__(self, model_config: Dict, sampler_config: Dict):
"""
Initialises model configuration and sampler configuration for the model
Parameters
5hv5hvnk marked this conversation as resolved.
Show resolved Hide resolved
----------
model_config: Dictionary
dictionary of parameters that initialise model configuration. Generated by the user defined create_sample_input method.
5hv5hvnk marked this conversation as resolved.
Show resolved Hide resolved
sampler_config: Dictionary
dictionary of parameters that initialise sampler configuration. Generated by the user defined create_sample_input method.

Example:
>>> class LinearModel(ModelBuilder)
...
>>> model = LinearModel(model_config, sampler_config)
"""

super().__init__()
self.model_config = model_config # parameters for priors etc.
self.sample_config = sampler_config # parameters for sampling
Expand All @@ -59,13 +72,13 @@ def _data_setter(

Example:
5hv5hvnk marked this conversation as resolved.
Show resolved Hide resolved

def _data_setter(self, data : pd.DataFrame):
with self.model:
pm.set_data({'x': data['input'].values})
try: # if y values in new data
pm.set_data({'y_data': data['output'].values})
except: # dummies otherwise
pm.set_data({'y_data': np.zeros(len(data))})
>>>def _data_setter(self, data : pd.DataFrame):
5hv5hvnk marked this conversation as resolved.
Show resolved Hide resolved
>>> with self.model:
>>> pm.set_data({'x': data['input'].values})
>>> try: # if y values in new data
>>> pm.set_data({'y_data': data['output'].values})
>>> except: # dummies otherwise
>>> pm.set_data({'y_data': np.zeros(len(data))})

"""
raise NotImplementedError
Expand All @@ -74,22 +87,52 @@ def _data_setter(self, data : pd.DataFrame):
def create_sample_input(cls):
"""
Needs to be implemented by the user in the inherited class.
Returns examples for data, model_config, samples_config.
Returns examples for data, model_config, sampler_config.
This is useful for understanding the required
data structures for the user model.

Example:
>>>@classmethod
>>>def create_sample_input(cls):
>>> x = np.linspace(start=1, stop=50, num=100)
>>> y = 5 * x + 3 + np.random.normal(0, 1, len(x)) * np.random.rand(100)*10 + np.random.rand(100)*6.4
>>> data = pd.DataFrame({'input': x, 'output': y})

>>> model_config = {
twiecki marked this conversation as resolved.
Show resolved Hide resolved
>>> 'a_loc': 7,
>>> 'a_scale': 3,
>>> 'b_loc': 5,
>>> 'b_scale': 3,
>>> 'obs_error': 2,
>>> }

>>> sampler_config = {
>>> 'draws': 1_000,
>>> 'tune': 1_000,
>>> 'chains': 1,
>>> 'target_accept': 0.95,
>>> }

>>> return data, model_config, sampler_config
"""
raise NotImplementedError

def save(self, file_prefix, filepath):
5hv5hvnk marked this conversation as resolved.
Show resolved Hide resolved
"""
Saves the model as well as inference data of the model.
Saves inference data of the model.

Parameters
----------
file_prefix: string
Passed which denotes the name with which model and idata should be saved.
filepath: string
Used as path at which model and idata should be saved

Example
------
5hv5hvnk marked this conversation as resolved.
Show resolved Hide resolved
>>> name = 'mymodel'
>>> path = '.'
>>> model.save(name,path)

"""
file = Path(filepath + str(file_prefix) + ".nc")
Expand All @@ -99,7 +142,7 @@ def save(self, file_prefix, filepath):
@classmethod
def load(cls, file_prefix, filepath):
"""
Loads model and the idata of used for model.
Loads inference data for the model.

Parameters
----------
Expand All @@ -108,23 +151,56 @@ def load(cls, file_prefix, filepath):
filepath: string
Used as path at which model and idata should be loaded from.

"""
Return
------
5hv5hvnk marked this conversation as resolved.
Show resolved Hide resolved
Returns the inference data that is loaded from local system.

Example
-------
>>> class LinearModel
...
>>> name = 'mymodel'
>>> path = '.'
>>> imported_model = LinearModel.load(name,path)

"""

filepath = Path(str(filepath) + str(file_prefix) + ".nc")
data = az.from_netcdf(filepath)
self.idata = data
return self
# if model.idata.attrs is not None:
twiecki marked this conversation as resolved.
Show resolved Hide resolved
# if model.idata.attrs['id'] == self.idata.attrs['id']:
# self = model
# self.idata = data
# return self
# else:
# raise ValueError(
# f"The route '{file}' does not contain an inference data of the same model '{self.__name__}'"
# )
return self.idata

# fit and predict methods
5hv5hvnk marked this conversation as resolved.
Show resolved Hide resolved
def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None):
"""
As the name suggests fit can be used to fit a model using the data that is passed as a parameter.
It returns the inference data.
Sets attrs to inference data of the model

Parameter
---------
data: Dictionary of string and either of numpy array, pandas dataframe or pandas Series
It is the data we need to train the model on.

Returns
--------
5hv5hvnk marked this conversation as resolved.
Show resolved Hide resolved
returns infernece data of the fitted model.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
returns infernece data of the fitted model.
returns inference data of the fitted model.


Example
-------
>>>data, model_config, sampler_config = LinearModel.create_sample_input()
5hv5hvnk marked this conversation as resolved.
Show resolved Hide resolved
>>>model = LinearModel(model_config, sampler_config)
>>>idata = model.fit(data)
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
"""
if data is not None:
self.data = data
5hv5hvnk marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -151,7 +227,7 @@ def predict(
point_estimate: bool = True,
):
"""
Uses model to predict on unseen data and returns posterioir prediction on the data.
Uses model to predict on unseen data.

Parameters
---------
Expand All @@ -160,6 +236,17 @@ def predict(
point_estimate: bool
Adds point like estimate used as mean passed as

Returns
-------
returns sample's posterior predict

Example
-------
>>> prediction_data = pd.DataFrame({'input':x_pred})
# only point estimate
>>>pred_mean = imported_model.predict(prediction_data)
# samples
>>>pred_samples = imported_model.predict(prediction_data, point_estimate=False)
"""
if data_prediction is not None: # set new input data
self._data_setter(data_prediction)
Expand All @@ -182,11 +269,13 @@ def predict(
@staticmethod
def _extract_samples(post_pred: arviz.data.inference_data.InferenceData) -> Dict[str, np.array]:
"""
Returns dict of numpy arrays from InferenceData object

Parameters
----------
post_pred: arviz InferenceData object

Returns
-------
Dictionary of numpy arrays from InferenceData object
"""
post_pred_dict = dict()
for key in post_pred.posterior_predictive:
Expand All @@ -197,6 +286,10 @@ def _extract_samples(post_pred: arviz.data.inference_data.InferenceData) -> Dict
def id(self):
"""
It creates a hash value to match the model version using last 16 characters of hash encoding.

Return
------
5hv5hvnk marked this conversation as resolved.
Show resolved Hide resolved
Returns string of length 16 characters containing unique hash of the model
"""
hasher = hashlib.sha256()
hasher.update(str(self.model_config.values()).encode())
Expand Down