From 227bd54df694bd78ea256ce1aa33337dbb59d1f0 Mon Sep 17 00:00:00 2001 From: Bogdan Cebere Date: Wed, 25 Jan 2023 19:05:10 +0200 Subject: [PATCH] Improvements&bugfixing (#118) * create serde folder if it is missing * add generate random seed * cleanup --- src/synthcity/plugins/core/plugin.py | 9 +++++++++ src/synthcity/utils/serialization.py | 6 ++++++ src/synthcity/version.py | 2 +- tests/plugins/generic/test_ctgan.py | 7 +++++++ 4 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/synthcity/plugins/core/plugin.py b/src/synthcity/plugins/core/plugin.py index 637064b3..a670631d 100644 --- a/src/synthcity/plugins/core/plugin.py +++ b/src/synthcity/plugins/core/plugin.py @@ -211,6 +211,8 @@ def fit(self, X: Union[DataLoader, pd.DataFrame], *args: Any, **kwargs: Any) -> if "cond" in kwargs and kwargs["cond"] is not None: self.expecting_conditional = True + enable_reproducible_results(self.random_state) + self.data_info = X.info() self._schema = Schema( @@ -262,6 +264,7 @@ def generate( self, count: Optional[int] = None, constraints: Optional[Constraints] = None, + random_state: Optional[int] = None, **kwargs: Any, ) -> DataLoader: """Synthetic data generation method. @@ -301,6 +304,9 @@ def generate( >>> >>> assert (syn_data["InterestingFeature"] == 0).all() + random_state: optional int. + Optional random seed to use. 
+ Returns: synthetic samples """ @@ -310,6 +316,9 @@ def generate( if self._schema is None: raise RuntimeError("Fit the model first") + if random_state is not None: + enable_reproducible_results(random_state) + has_gen_cond = "cond" in kwargs and kwargs["cond"] is not None if has_gen_cond and not self.expecting_conditional: raise RuntimeError( diff --git a/src/synthcity/utils/serialization.py b/src/synthcity/utils/serialization.py index 9f77f5d3..06c72d4f 100644 --- a/src/synthcity/utils/serialization.py +++ b/src/synthcity/utils/serialization.py @@ -17,6 +17,12 @@ def load(buff: bytes) -> Any: def save_to_file(path: Union[str, Path], model: Any) -> Any: + path = Path(path) + ppath = path.absolute().parent + + if not ppath.exists(): + ppath.mkdir(parents=True, exist_ok=True) + with open(path, "wb") as f: return cloudpickle.dump(model, f) diff --git a/src/synthcity/version.py b/src/synthcity/version.py index 03510e0b..1d2cac53 100644 --- a/src/synthcity/version.py +++ b/src/synthcity/version.py @@ -1,4 +1,4 @@ -__version__ = "0.1.8" +__version__ = "0.1.9" MAJOR_VERSION = ".".join(__version__.split(".")[:-1]) MINOR_VERSION = __version__.split(".")[-1] diff --git a/tests/plugins/generic/test_ctgan.py b/tests/plugins/generic/test_ctgan.py index 4d1e2079..8be4e479 100644 --- a/tests/plugins/generic/test_ctgan.py +++ b/tests/plugins/generic/test_ctgan.py @@ -75,6 +75,13 @@ def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None: assert len(X_gen) == 50 assert test_plugin.schema_includes(X_gen) + # generate with random seed + X_gen1 = test_plugin.generate(50, random_state=0) + X_gen2 = test_plugin.generate(50, random_state=0) + X_gen3 = test_plugin.generate(50) + assert (X_gen1.numpy() == X_gen2.numpy()).all() + assert (X_gen1.numpy() != X_gen3.numpy()).any() + @pytest.mark.parametrize( "test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)