ModelBuilder.load versatility improvements (#210)
* fixing dims format, enabling input param preservation

* adding additional test for implemented features

* fixing typehinting in linearmodel

* introducing dims to model_builder tests to check for dim format preservation
michaelraczycki authored Jul 11, 2023
1 parent dd3c44d commit 5f7b185
Showing 4 changed files with 92 additions and 42 deletions.
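Background for the dims change (not part of the diff below): ModelBuilder serializes model_config into idata.attrs as JSON, and JSON has no tuple type, so a dims entry saved as a tuple comes back as a list after loading. A minimal standalone sketch of that round trip and the coercion this commit introduces:

import json

model_config = {"a": {"loc": 0, "scale": 10, "dims": ("numbers",)}}

# json.dumps/json.loads silently turns the dims tuple into a list
restored = json.loads(json.dumps(model_config))
assert restored["a"]["dims"] == ["numbers"]

# _convert_dims_to_tuple (added in this commit) does the equivalent of this
# for every top-level parameter dict whose "dims" value is a list
for cfg in restored.values():
    if isinstance(cfg, dict) and isinstance(cfg.get("dims"), list):
        cfg["dims"] = tuple(cfg["dims"])
assert restored["a"]["dims"] == ("numbers",)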
4 changes: 4 additions & 0 deletions pymc_experimental/linearmodel.py
@@ -37,6 +37,10 @@ def default_sampler_config(self):
"target_accept": 0.95,
}

@property
def _serializable_model_config(self) -> Dict:
return self.model_config

@property
def output_var(self):
return "y_hat"
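LinearModel's model_config is already plain JSON-serializable data, so the new property can return it unchanged. For a subclass whose config holds objects json.dumps cannot handle, this property is the place to convert them first; a hypothetical sketch (the subclass name, the mu_prior key, and the numpy array are made up for illustration):

import numpy as np
from typing import Dict
from pymc_experimental.linearmodel import LinearModel

class HeavyConfigModel(LinearModel):
    """Hypothetical subclass whose config holds a numpy array."""

    @property
    def _serializable_model_config(self) -> Dict:
        config = dict(self.model_config)
        # hypothetical: an ndarray prior must become a plain list before json.dumps
        if isinstance(config.get("mu_prior"), np.ndarray):
            config["mu_prior"] = config["mu_prior"].tolist()
        return config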
22 changes: 20 additions & 2 deletions pymc_experimental/model_builder.py
@@ -297,7 +297,7 @@ def sample_model(self, **kwargs):
idata.extend(pm.sample_prior_predictive())
idata.extend(pm.sample_posterior_predictive(idata))

self.set_idata_attrs(idata)
idata = self.set_idata_attrs(idata)
return idata

def set_idata_attrs(self, idata=None):
@@ -338,6 +338,10 @@ def set_idata_attrs(self, idata=None):
idata.attrs["version"] = self.version
idata.attrs["sampler_config"] = json.dumps(self.sampler_config)
idata.attrs["model_config"] = json.dumps(self._serializable_model_config)
# Only classes with non-dataset parameters will implement save_input_params
if hasattr(self, "_save_input_params"):
self._save_input_params(idata)
return idata

def save(self, fname: str) -> None:
"""
@@ -375,6 +379,17 @@ def save(self, fname: str) -> None:
else:
raise RuntimeError("The model hasn't been fit yet, call .fit() first")

@classmethod
def _convert_dims_to_tuple(cls, model_config: Dict) -> Dict:
for key in model_config:
if (
isinstance(model_config[key], dict)
and "dims" in model_config[key]
and isinstance(model_config[key]["dims"], list)
):
model_config[key]["dims"] = tuple(model_config[key]["dims"])
return model_config

@classmethod
def load(cls, fname: str):
"""
@@ -403,8 +418,10 @@ def load(cls, fname: str):
"""
filepath = Path(str(fname))
idata = az.from_netcdf(filepath)
# needs to be converted, because json.loads was changing tuple to list
model_config = cls._convert_dims_to_tuple(json.loads(idata.attrs["model_config"]))
model = cls(
model_config=json.loads(idata.attrs["model_config"]),
model_config=model_config,
sampler_config=json.loads(idata.attrs["sampler_config"]),
)
model.idata = idata
@@ -480,6 +497,7 @@ def fit(
combined_data = pd.concat([X_df, y], axis=1)
assert all(combined_data.columns), "All columns must have non-empty names"
self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore

return self.idata # type: ignore

def predict(
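Taken together, the save/load round trip now keeps dims as tuples and leaves any subclass-stored input parameters available in idata.attrs. A rough usage sketch (MyModel, the scaling attribute, and the filename come from the hypothetical sketch above, not from this commit); note that load() restores model_config and sampler_config but does not feed the stored input parameter back into __init__:

import json

model.save("linear_model.nc")               # writes idata, including the JSON attrs
restored = MyModel.load("linear_model.nc")  # dims in model_config come back as tuples
scaling = json.loads(restored.idata.attrs["scaling"])  # stored input param, read back manually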
36 changes: 18 additions & 18 deletions pymc_experimental/tests/test_linearmodel.py
@@ -64,6 +64,24 @@ def fitted_linear_model_instance(toy_X, toy_y):
return model


@pytest.mark.skipif(
sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
)
def test_save_load(fitted_linear_model_instance):
model = fitted_linear_model_instance
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
model.save(temp.name)
model2 = LinearModel.load(temp.name)
assert model.idata.groups() == model2.idata.groups()

X_pred = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=100)})
pred1 = model.predict(X_pred, random_seed=423)
pred2 = model2.predict(X_pred, random_seed=423)
# Predictions should be identical
np.testing.assert_array_equal(pred1, pred2)
temp.close()


def test_save_without_fit_raises_runtime_error(toy_X, toy_y):
test_model = LinearModel()
with pytest.raises(RuntimeError):
@@ -83,24 +101,6 @@ def test_fit(fitted_linear_model_instance):
assert isinstance(post_pred, xr.DataArray)


@pytest.mark.skipif(
sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
)
def test_save_load(fitted_linear_model_instance):
model = fitted_linear_model_instance
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
model.save(temp.name)
model2 = LinearModel.load(temp.name)
assert model.idata.groups() == model2.idata.groups()

X_pred = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=100)})
pred1 = model.predict(X_pred, random_seed=423)
pred2 = model2.predict(X_pred, random_seed=423)
# Predictions should be identical
np.testing.assert_array_equal(pred1, pred2)
temp.close()


def test_predict(fitted_linear_model_instance):
model = fitted_linear_model_instance
X_pred = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=100)})
72 changes: 50 additions & 22 deletions pymc_experimental/tests/test_model_builder.py
@@ -13,6 +13,7 @@
# limitations under the License.

import hashlib
import json
import sys
import tempfile
from typing import Dict
@@ -43,29 +44,35 @@ def toy_y(toy_X):
@pytest.fixture(scope="module")
def fitted_model_instance(toy_X, toy_y):
sampler_config = {
"draws": 500,
"tune": 300,
"draws": 100,
"tune": 100,
"chains": 2,
"target_accept": 0.95,
}
model_config = {
"a": {"loc": 0, "scale": 10},
"a": {"loc": 0, "scale": 10, "dims": ("numbers",)},
"b": {"loc": 0, "scale": 10},
"obs_error": 2,
}
model = test_ModelBuilder(model_config=model_config, sampler_config=sampler_config)
model = test_ModelBuilder(
model_config=model_config, sampler_config=sampler_config, test_parameter="test_paramter"
)
model.fit(toy_X)
return model


class test_ModelBuilder(ModelBuilder):
def __init__(self, model_config=None, sampler_config=None, test_parameter=None):
self.test_parameter = test_parameter
super().__init__(model_config=model_config, sampler_config=sampler_config)

_model_type = "LinearModel"
_model_type = "test_model"
version = "0.1"

def build_model(self, X: pd.DataFrame, y: pd.Series, model_config=None):
coords = {"numbers": np.arange(len(X))}
self.generate_and_preprocess_model_data(X, y)
with pm.Model() as self.model:
with pm.Model(coords=coords) as self.model:
if model_config is None:
model_config = self.default_model_config
x = pm.MutableData("x", self.X["input"].values)
@@ -79,13 +86,16 @@ def build_model(self, X: pd.DataFrame, y: pd.Series, model_config=None):
obs_error = model_config["obs_error"]

# priors
a = pm.Normal("a", a_loc, sigma=a_scale)
a = pm.Normal("a", a_loc, sigma=a_scale, dims=model_config["a"]["dims"])
b = pm.Normal("b", b_loc, sigma=b_scale)
obs_error = pm.HalfNormal("σ_model_fmc", obs_error)

# observed data
output = pm.Normal("output", a + b * x, obs_error, shape=x.shape, observed=y_data)

def _save_input_params(self, idata):
idata.attrs["test_paramter"] = json.dumps(self.test_parameter)

@property
def output_var(self):
return "output"
@@ -107,7 +117,7 @@ def generate_and_preprocess_model_data(self, X: pd.DataFrame, y: pd.Series):
@property
def default_model_config(self) -> Dict:
return {
"a": {"loc": 0, "scale": 10},
"a": {"loc": 0, "scale": 10, "dims": ("numbers",)},
"b": {"loc": 0, "scale": 10},
"obs_error": 2,
}
@@ -122,6 +132,38 @@ def default_sampler_config(self) -> Dict:
}


def test_save_input_params(fitted_model_instance):
assert fitted_model_instance.idata.attrs["test_paramter"] == '"test_paramter"'


def test_save_load(fitted_model_instance):
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
fitted_model_instance.save(temp.name)
test_builder2 = test_ModelBuilder.load(temp.name)
assert fitted_model_instance.idata.groups() == test_builder2.idata.groups()
assert fitted_model_instance.id == test_builder2.id
x_pred = np.random.uniform(low=0, high=1, size=100)
prediction_data = pd.DataFrame({"input": x_pred})
pred1 = fitted_model_instance.predict(prediction_data["input"])
pred2 = test_builder2.predict(prediction_data["input"])
assert pred1.shape == pred2.shape
temp.close()


def test_convert_dims_to_tuple(fitted_model_instance):
model_config = {
"a": {
"loc": 0,
"scale": 10,
"dims": [
"x",
],
},
}
converted_model_config = fitted_model_instance._convert_dims_to_tuple(model_config)
assert converted_model_config["a"]["dims"] == ("x",)


def test_initial_build_and_fit(fitted_model_instance, check_idata=True) -> ModelBuilder:
if check_idata:
assert fitted_model_instance.idata is not None
@@ -162,20 +204,6 @@ def test_fit_no_y(toy_X):
@pytest.mark.skipif(
sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
)
def test_save_load(fitted_model_instance):
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
fitted_model_instance.save(temp.name)
test_builder2 = test_ModelBuilder.load(temp.name)
assert fitted_model_instance.idata.groups() == test_builder2.idata.groups()

x_pred = np.random.uniform(low=0, high=1, size=100)
prediction_data = pd.DataFrame({"input": x_pred})
pred1 = fitted_model_instance.predict(prediction_data["input"])
pred2 = test_builder2.predict(prediction_data["input"])
assert pred1.shape == pred2.shape
temp.close()


def test_predict(fitted_model_instance):
x_pred = np.random.uniform(low=0, high=1, size=100)
prediction_data = pd.DataFrame({"input": x_pred})
