Skip to content

Commit

Permalink
Improvements&bugfixing (#118)
Browse files Browse the repository at this point in the history
* create serde folder if it is missing

* add generate random seed

* cleanup
  • Loading branch information
bcebere authored Jan 25, 2023
1 parent d3b1014 commit 227bd54
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 1 deletion.
9 changes: 9 additions & 0 deletions src/synthcity/plugins/core/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ def fit(self, X: Union[DataLoader, pd.DataFrame], *args: Any, **kwargs: Any) ->
if "cond" in kwargs and kwargs["cond"] is not None:
self.expecting_conditional = True

enable_reproducible_results(self.random_state)

self.data_info = X.info()

self._schema = Schema(
Expand Down Expand Up @@ -262,6 +264,7 @@ def generate(
self,
count: Optional[int] = None,
constraints: Optional[Constraints] = None,
random_state: Optional[int] = None,
**kwargs: Any,
) -> DataLoader:
"""Synthetic data generation method.
Expand Down Expand Up @@ -301,6 +304,9 @@ def generate(
>>>
>>> assert (syn_data["InterestingFeature"] == 0).all()
random_state: optional int.
Optional random seed to use.
Returns:
<count> synthetic samples
"""
Expand All @@ -310,6 +316,9 @@ def generate(
if self._schema is None:
raise RuntimeError("Fit the model first")

if random_state is not None:
enable_reproducible_results(random_state)

has_gen_cond = "cond" in kwargs and kwargs["cond"] is not None
if has_gen_cond and not self.expecting_conditional:
raise RuntimeError(
Expand Down
6 changes: 6 additions & 0 deletions src/synthcity/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ def load(buff: bytes) -> Any:


def save_to_file(path: Union[str, Path], model: Any) -> Any:
    """Serialize ``model`` with cloudpickle and write it to ``path``.

    Any missing parent directories of ``path`` are created before the file
    is opened.

    Args:
        path: Destination file, given as a string or a ``Path``.
        model: Object to serialize.

    Returns:
        Whatever ``cloudpickle.dump`` returns.
    """
    path = Path(path)

    # exist_ok=True already makes mkdir a no-op when the directory is present,
    # so a separate exists() check is redundant and only opens a
    # check-then-act race window.
    path.absolute().parent.mkdir(parents=True, exist_ok=True)

    with open(path, "wb") as f:
        return cloudpickle.dump(model, f)

Expand Down
2 changes: 1 addition & 1 deletion src/synthcity/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.1.8"
__version__ = "0.1.9"

MAJOR_VERSION = ".".join(__version__.split(".")[:-1])
MINOR_VERSION = __version__.split(".")[-1]
7 changes: 7 additions & 0 deletions tests/plugins/generic/test_ctgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ def test_plugin_generate(test_plugin: Plugin, serialize: bool) -> None:
assert len(X_gen) == 50
assert test_plugin.schema_includes(X_gen)

# generate with random seed
X_gen1 = test_plugin.generate(50, random_state=0)
X_gen2 = test_plugin.generate(50, random_state=0)
X_gen3 = test_plugin.generate(50)
assert (X_gen1.numpy() == X_gen2.numpy()).all()
assert (X_gen1.numpy() != X_gen3.numpy()).any()


@pytest.mark.parametrize(
"test_plugin", generate_fixtures(plugin_name, plugin, plugin_args)
Expand Down

0 comments on commit 227bd54

Please sign in to comment.