Skip to content

Commit

Permalink
new test for save, fit allows custom configs
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelraczycki committed Apr 19, 2023
1 parent 4c99877 commit e89724c
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 12 deletions.
34 changes: 25 additions & 9 deletions pymc_experimental/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,11 @@ def save(self, fname: str) -> None:
>>> name = './mymodel.nc'
>>> model.save(name)
"""

file = Path(str(fname))
self.idata.to_netcdf(file)
if self.idata is not None and "fit_data" in self.idata:
file = Path(str(fname))
self.idata.to_netcdf(file)
else:
raise RuntimeError("The model hasn't been fit yet, call .fit() first")

@classmethod
def load(cls, fname: str):
Expand Down Expand Up @@ -220,7 +222,7 @@ def load(cls, fname: str):
data=idata.fit_data.to_dataframe(),
)
model_builder.idata = idata
model_builder.idata = model_builder.fit()
model_builder.build_model(model_builder.data, model_builder.model_config)
if model_builder.id != idata.attrs["id"]:
raise ValueError(
f"The file '{fname}' does not contain an inference data of the same model or configuration as '{cls._model_type}'"
Expand Down Expand Up @@ -261,11 +263,12 @@ def fit(
# If a new data was provided, assign it to the model
if data is not None:
self.data = data
self.model_data, self.model_config, self.sampler_config = self.create_sample_input(
data=self.data
)
self.model_data, model_config, sampler_config = self.create_sample_input(data=self.data)
if self.model_config is None:
self.model_config = model_config
if self.sampler_config is None:
self.sampler_config = sampler_config
self.build_model(self.model_data, self.model_config)
self._data_setter(self.model_data)
with self.model:
self.idata = pm.sample(**self.sampler_config, **kwargs)
self.idata.extend(pm.sample_prior_predictive())
Expand All @@ -275,7 +278,7 @@ def fit(
self.idata.attrs["model_type"] = self._model_type
self.idata.attrs["version"] = self.version
self.idata.attrs["sampler_config"] = json.dumps(self.sampler_config)
self.idata.attrs["model_config"] = json.dumps(self.serializable_model_config)
self.idata.attrs["model_config"] = json.dumps(self._serializable_model_config)
self.idata.add_groups(fit_data=self.data.to_xarray())
return self.idata

Expand Down Expand Up @@ -386,6 +389,19 @@ def _extract_samples(post_pred: az.data.inference_data.InferenceData) -> Dict[st

return post_pred_dict

@property
@abstractmethod
def _serializable_model_config(self) -> Dict[str, Union[int, float, Dict]]:
"""
Converts non-serializable values from model_config to their serializable reversable equivalent.
Data types like pandas DataFrame, Series or datetime aren't JSON serializable,
so in order to save the model they need to be formatted.
Returns
-------
model_config: dict
"""

@property
def id(self) -> str:
"""
Expand Down
15 changes: 12 additions & 3 deletions pymc_experimental/tests/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def _data_setter(self, data: pd.DataFrame):
pm.set_data({"y_data": data["output"].values})

@property
def serializable_model_config(self):
def _serializable_model_config(self):
return self.model_config

@classmethod
Expand Down Expand Up @@ -95,7 +95,16 @@ def initial_build_and_fit(check_idata=True) -> ModelBuilder:
return model_builder


def test_empty_model_config():
def test_save_without_fit_raises_runtime_error():
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
model_builder = test_ModelBuilder(
model_config=model_config, sampler_config=sampler_config, data=data
)
with pytest.raises(RuntimeError):
model_builder.save("saved_model")


def test_empty_sampler_config_fit():
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
sampler_config = {}
model_builder = test_ModelBuilder(
Expand All @@ -106,7 +115,7 @@ def test_empty_model_config():
assert "posterior" in model_builder.idata.groups()


def test_empty_model_config():
def test_empty_model_config_fit():
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
model_config = {}
model_builder = test_ModelBuilder(
Expand Down

0 comments on commit e89724c

Please sign in to comment.