diff --git a/pymc_experimental/linearmodel.py b/pymc_experimental/linearmodel.py index 8f1b7efd..c2efbf33 100644 --- a/pymc_experimental/linearmodel.py +++ b/pymc_experimental/linearmodel.py @@ -20,16 +20,16 @@ def __init__(self, model_config: Dict = None, sampler_config: Dict = None, nsamp _model_type = "LinearModel" version = "0.1" - @property - def default_model_config(self): + @staticmethod + def get_default_model_config(): return { "intercept": {"loc": 0, "scale": 10}, "slope": {"loc": 0, "scale": 10}, "obs_error": 2, } - @property - def default_sampler_config(self): + @staticmethod + def get_default_sampler_config(): return { "draws": 1_000, "tune": 1_000, diff --git a/pymc_experimental/model_builder.py b/pymc_experimental/model_builder.py index 211ac053..e61e38bc 100644 --- a/pymc_experimental/model_builder.py +++ b/pymc_experimental/model_builder.py @@ -71,9 +71,11 @@ def __init__( >>> ... >>> model = MyModel(model_config, sampler_config) """ - sampler_config = self.default_sampler_config if sampler_config is None else sampler_config + sampler_config = ( + self.get_default_sampler_config() if sampler_config is None else sampler_config + ) self.sampler_config = sampler_config - model_config = self.default_model_config if model_config is None else model_config + model_config = self.get_default_model_config() if model_config is None else model_config self.model_config = model_config # parameters for priors etc. self.model = None # Set by build_model @@ -133,17 +135,17 @@ def output_var(self): """ raise NotImplementedError - @property + @staticmethod @abstractmethod - def default_model_config(self) -> Dict: + def get_default_model_config() -> Dict: """ Returns a class default config dict for model builder if no model_config is provided on class initialization Useful for understanding structure of required model_config to allow its customization by users Examples -------- - >>> @classmethod - >>> def default_model_config(self): - >>> Return { + >>> @staticmethod + >>> def default_model_config(): + >>> return { >>> 'a' : { >>> 'loc': 7, >>> 'scale' : 3 @@ -162,17 +164,17 @@ def default_model_config(self) -> Dict: """ raise NotImplementedError - @property + @staticmethod @abstractmethod - def default_sampler_config(self) -> Dict: + def get_default_sampler_config(self) -> Dict: """ Returns a class default sampler dict for model builder if no sampler_config is provided on class initialization Useful for understanding structure of required sampler_config to allow its customization by users Examples -------- - >>> @classmethod - >>> def default_sampler_config(self): - >>> Return { + >>> @staticmethod + >>> def default_sampler_config(): + >>> return { >>> 'draws': 1_000, >>> 'tune': 1_000, >>> 'chains': 1, diff --git a/pymc_experimental/tests/test_model_builder.py b/pymc_experimental/tests/test_model_builder.py index 27eaf1a8..a0b24346 100644 --- a/pymc_experimental/tests/test_model_builder.py +++ b/pymc_experimental/tests/test_model_builder.py @@ -74,7 +74,7 @@ def build_model(self, X: pd.DataFrame, y: pd.Series, model_config=None): self.generate_and_preprocess_model_data(X, y) with pm.Model(coords=coords) as self.model: if model_config is None: - model_config = self.default_model_config + model_config = self.model_config x = pm.MutableData("x", self.X["input"].values) y_data = pm.MutableData("y_data", self.y) @@ -114,8 +114,8 @@ def generate_and_preprocess_model_data(self, X: pd.DataFrame, y: pd.Series): self.X = X self.y = y - @property - def default_model_config(self) -> Dict: + @staticmethod + def get_default_model_config() -> Dict: return { "a": {"loc": 0, "scale": 10, "dims": ("numbers",)}, "b": {"loc": 0, "scale": 10}, @@ -128,8 +128,8 @@ def _generate_and_preprocess_model_data( self.X = X self.y = y - @property - def default_sampler_config(self) -> Dict: + @staticmethod + def get_default_sampler_config() -> Dict: return { "draws": 1_000, "tune": 1_000,