Python 3.10 support & bugfixing (#95)
* Python 3.10

* improve docs

* update req

* update README

* upgrade optuna

* bugfixing

* rework caching flows

* cleanup args

* update args

* cleanup benchmark args

* update tests

* add citation file
bcebere authored Jan 19, 2023
1 parent 3563c8e commit 9a74dd6
Showing 37 changed files with 585 additions and 102 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -38,7 +38,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ['3.7', '3.8', '3.9']
python-version: ['3.7', '3.8', '3.9', '3.10']
os: [macos-latest, ubuntu-latest, windows-latest]
steps:
- uses: actions/checkout@v2
2 changes: 1 addition & 1 deletion .github/workflows/test_tutorials.yml
@@ -14,7 +14,7 @@ jobs:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ["3.7", "3.8", "3.9"]
python-version: ["3.7", "3.8", "3.9", "3.10"]
os: [ubuntu-latest]
steps:
- uses: actions/checkout@v2
28 changes: 28 additions & 0 deletions CITATION.cff
@@ -0,0 +1,28 @@
# This CITATION.cff file was generated with cffinit.
# Visit https://bit.ly/cffinit to generate yours today!

cff-version: 1.2.0
title: >-
Synthcity: facilitating innovative use cases of synthetic
data in different data modalities
message: >-
If you use this software, please cite it using the
metadata from this file.
type: software
authors:
- given-names: Zhaozhi
family-names: Qian
- given-names: Bogdan-Constantin
family-names: Cebere
- given-names: Mihaela
family-names: van der Schaar
identifiers:
- type: doi
value: 10.48550/ARXIV.2301.07573
repository-code: 'https://github.com/vanderschaarlab/synthcity'
url: 'https://arxiv.org/abs/2301.07573'
keywords:
- Machine Learning (cs.LG)
- Artificial Intelligence (cs.AI)
- 'FOS: Computer and information sciences'
license: Apache-2.0
21 changes: 18 additions & 3 deletions README.md
@@ -1,5 +1,5 @@
<h2 align="center">
synthcity BETA
synthcity
</h2>

<h4 align="center">
@@ -18,6 +18,8 @@
[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://github.com/vanderschaarlab/synthcity/blob/main/LICENSE)
[![Python 3.7+](https://img.shields.io/badge/python-3.7+-blue.svg)](https://www.python.org/downloads/release/python-370/)

![image](https://github.com/vanderschaarlab/synthcity/raw/main/docs/arch.png "Synthcity")

</div>


@@ -33,8 +35,6 @@
- :book: [Read the docs !](https://synthcity.readthedocs.io/)
- :airplane: [Checkout the tutorials!](https://github.com/vanderschaarlab/synthcity#-tutorials)

:rotating_light: __NOTE__: Python 3.10 is __NOT__ supported yet.

## :rocket: Installation

The library can be installed from PyPI using
@@ -353,3 +353,18 @@ The tests can be executed using
```bash
pytest -vsx
```
## Citing

If you use this code, please cite the associated paper:

```
@misc{https://doi.org/10.48550/arxiv.2301.07573,
doi = {10.48550/ARXIV.2301.07573},
url = {https://arxiv.org/abs/2301.07573},
author = {Qian, Zhaozhi and Cebere, Bogdan-Constantin and van der Schaar, Mihaela},
keywords = {Machine Learning (cs.LG), Artificial Intelligence (cs.AI), FOS: Computer and information sciences, FOS: Computer and information sciences},
title = {Synthcity: facilitating innovative use cases of synthetic data in different data modalities},
year = {2023},
copyright = {Creative Commons Attribution 4.0 International}
}
```
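
For orientation, here is a minimal end-to-end sketch of the library (a sketch only: the plugin name, dataset, and `count` value are illustrative and not part of this commit):

```python
# Minimal usage sketch: fit a generator plugin on a DataFrame and sample synthetic rows.
from sklearn.datasets import load_diabetes
from synthcity.plugins import Plugins

X, y = load_diabetes(return_X_y=True, as_frame=True)
X["target"] = y

model = Plugins().get("marginal_distributions")  # any registered plugin name works here
model.fit(X)
X_syn = model.generate(count=100)
```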
Binary file added docs/arch.png
4 changes: 2 additions & 2 deletions setup.cfg
@@ -40,7 +40,8 @@ install_requires =
lifelines>=0.27
opacus>=1.3
decaf-synthetic-data>=0.1.5
rdt>=1.2.1
rdt>=1.3
optuna>=3.1
diffprivlib
shap
tqdm
@@ -53,7 +54,6 @@ install_requires =
geomloss
deepecho
pgmpy
optuna
redis
pycox
xgbse
7 changes: 7 additions & 0 deletions src/synthcity/benchmark/__init__.py
@@ -103,6 +103,9 @@ def evaluate(

for testcase, plugin, kwargs in tests:
log.info(f"Testcase : {testcase}")
if not isinstance(kwargs, dict):
raise ValueError(f"'kwargs' must be a dict for {testcase}:{plugin}")

scores = ScoreEvaluator()

kwargs_hash = ""
@@ -116,6 +119,10 @@

for repeat in repeats_list:
enable_reproducible_results(repeat)

kwargs["workspace"] = workspace
kwargs["random_state"] = repeat

torch.cuda.empty_cache()

cache_file = (
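
A sketch of the caller-side contract enforced above (the `repeats` argument and the `GenericDataLoader` wiring are assumptions; only the `(testcase, plugin, kwargs)` triple shape is taken from this hunk):

```python
# Each test case is a (name, plugin, plugin_kwargs) triple; plugin_kwargs must be a dict,
# and the benchmark now injects `workspace` and `random_state` into it on every repeat.
from synthcity.benchmark import Benchmarks
from synthcity.plugins.core.dataloader import GenericDataLoader

loader = GenericDataLoader(X)  # X: a pandas DataFrame, assumed to be defined elsewhere
report = Benchmarks.evaluate(
    [
        ("bn_default", "bayesian_network", {}),                          # an empty dict is valid
        ("bn_shallow", "bayesian_network", {"struct_max_indegree": 2}),  # plugin-specific kwargs
    ],
    loader,
    repeats=2,
)
```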
12 changes: 12 additions & 0 deletions src/synthcity/metrics/eval_detection.py
@@ -126,6 +126,9 @@ class SyntheticDetectionXGB(DetectionEvaluator):
1: The datasets are totally distinguishable.
"""

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

@staticmethod
def name() -> str:
return "detection_xgb"
@@ -158,6 +161,9 @@ class SyntheticDetectionMLP(DetectionEvaluator):
1: The datasets are totally distinguishable.
"""

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

@staticmethod
def name() -> str:
return "detection_mlp"
@@ -193,6 +199,9 @@ class SyntheticDetectionLinear(DetectionEvaluator):
1: The datasets are totally distinguishable.
"""

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

@staticmethod
def name() -> str:
return "detection_linear"
@@ -227,6 +236,9 @@ class SyntheticDetectionGMM(DetectionEvaluator):
1: The datasets are totally distinguishable.
"""

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

@staticmethod
def name() -> str:
return "detection_gmm"
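
The added constructors simply forward keyword arguments to `DetectionEvaluator`; a rough usage sketch (the commented `evaluate` call and its loader arguments are assumptions, not shown in this diff):

```python
# Detection metrics train a classifier to separate real from synthetic rows;
# scores near 0 mean indistinguishable, scores near 1 mean fully distinguishable.
from synthcity.metrics.eval_detection import SyntheticDetectionXGB

evaluator = SyntheticDetectionXGB()  # **kwargs are passed through to DetectionEvaluator
assert evaluator.name() == "detection_xgb"
# score = evaluator.evaluate(real_loader, synthetic_loader)  # assumed DataLoader arguments
```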
6 changes: 1 addition & 5 deletions src/synthcity/plugins/core/dataloader.py
@@ -245,11 +245,7 @@ def decode(
decoded = self.dataframe().copy()

for col in encoders:
decoded_local = decoded[[col]].copy()
decoded_local[f"{col}.value"] = decoded_local[col]
decoded[col] = encoders[col].reverse_transform(
decoded_local[[f"{col}.value"]].astype(float)
)
decoded[col] = encoders[col].reverse_transform(decoded[[col]].astype(float))

return self.from_info(decoded, self.info())

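
The simplified `decode` feeds each encoded column straight into its encoder's `reverse_transform`; a round-trip sketch (the `encode`/`decode` pairing is assumed from the surrounding class and is not shown in this hunk):

```python
# Encode a loader, then decode it back; decode() now calls reverse_transform per column directly.
from synthcity.plugins.core.dataloader import GenericDataLoader

loader = GenericDataLoader(df)       # df: a pandas DataFrame, assumed to be defined elsewhere
encoded, encoders = loader.encode()  # encoders: dict of per-column encoders, keyed by column name
restored = encoded.decode(encoders)  # applies encoders[col].reverse_transform(...) per column
```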
34 changes: 9 additions & 25 deletions src/synthcity/plugins/core/models/tabular_encoder.py
@@ -4,9 +4,7 @@
"""

# stdlib
import platform
from collections import namedtuple
from pathlib import Path
from typing import Any, List, Optional, Sequence, Tuple

# third party
@@ -19,7 +17,7 @@

# synthcity absolute
import synthcity.logger as log
from synthcity.utils.serialization import dataframe_hash, load_from_file, save_to_file
from synthcity.utils.serialization import dataframe_hash

ColumnTransformInfo = namedtuple(
"ColumnTransformInfo",
@@ -67,7 +65,7 @@ def _fit_continuous(self, data: pd.Series) -> ColumnTransformInfo:
max_clusters=min(self.max_clusters, len(data)),
enforce_min_max_values=True,
)
gm.fit(data.to_frame(), [column_name])
gm.fit(data.to_frame(), column_name)
num_components = sum(gm.valid_component_indicator)

output_columns = [f"{column_name}.normalized"] + [
@@ -111,7 +109,7 @@ def _transform_continuous(
) -> pd.Series:
column_name = data.name
gm = column_transform_info.transform
transformed = gm.transform(data.to_frame(), [column_name])
transformed = gm.transform(data.to_frame()[[column_name]])

return transformed[f"{column_name}.component"].to_numpy().astype(int)

@@ -147,7 +145,6 @@ def __init__(
max_clusters: int = 10,
categorical_limit: int = 10,
whitelist: list = [],
workspace: Path = Path("workspace"),
) -> None:
"""Create a data transformer.
@@ -158,8 +155,6 @@ def __init__(
self.max_clusters = max_clusters
self.categorical_limit = categorical_limit
self.whitelist = whitelist
self.workspace = workspace
self.workspace.mkdir(parents=True, exist_ok=True)

@validate_arguments(config=dict(arbitrary_types_allowed=True))
def _fit_continuous(self, data: pd.Series) -> ColumnTransformInfo:
@@ -179,7 +174,7 @@ def _fit_continuous(self, data: pd.Series) -> ColumnTransformInfo:
max_clusters=min(len(data), self.max_clusters),
enforce_min_max_values=True,
)
gm.fit(data.to_frame(), [column_name])
gm.fit(data.to_frame(), column_name)
num_components = sum(gm.valid_component_indicator)

output_columns = [f"{column_name}.normalized"] + [
@@ -243,27 +238,16 @@ def fit(
self._column_raw_dtypes = raw_data.infer_objects().dtypes
self._column_transform_info_list = []

self.workspace.mkdir(parents=True, exist_ok=True)

for column_name in raw_data.columns:
if column_name in self.whitelist:
continue
column_hash = dataframe_hash(raw_data[[column_name]])
bkp_file = (
self.workspace
/ f"encoder_cache_{column_hash}_{column_name[:50]}_{self.max_clusters}_{self.categorical_limit}_{platform.python_version()}.bkp"
)

log.info(f"Encoding {column_name} {column_hash}")

if bkp_file.exists():
column_transform_info = load_from_file(bkp_file)
if column_name in discrete_columns:
column_transform_info = self._fit_discrete(raw_data[column_name])
else:
if column_name in discrete_columns:
column_transform_info = self._fit_discrete(raw_data[column_name])
else:
column_transform_info = self._fit_continuous(raw_data[column_name])
save_to_file(bkp_file, column_transform_info)
column_transform_info = self._fit_continuous(raw_data[column_name])

self.output_dimensions += column_transform_info.output_dimensions
self._column_transform_info_list.append(column_transform_info)
@@ -275,7 +259,7 @@ def _transform_continuous(
) -> pd.DataFrame:
column_name = data.name
gm = column_transform_info.transform
transformed = gm.transform(data.to_frame(), [column_name])
transformed = gm.transform(data.to_frame()[[column_name]])

# Converts the transformed data to the appropriate output format.
# The first column (ending in '.normalized') stays the same,
@@ -341,7 +325,7 @@ def _inverse_transform_continuous(
column_data.values[:, :2], columns=list(gm.get_output_sdtypes())
)
data.iloc[:, 1] = np.argmax(column_data.values[:, 1:], axis=1)
return gm.reverse_transform(data, [column_transform_info.column_name])
return gm.reverse_transform(data)

@validate_arguments(config=dict(arbitrary_types_allowed=True))
def _inverse_transform_discrete(
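
The hunks above drop the per-column on-disk cache and switch to the newer rdt call style; a standalone sketch of that pattern (the `ClusterBasedNormalizer` import path is an assumption based on the surrounding module):

```python
# rdt >= 1.3 style used above: fit takes the column name as a plain string,
# transform takes a single-column DataFrame, and reverse_transform takes no column list.
import numpy as np
import pandas as pd
from rdt.transformers import ClusterBasedNormalizer

data = pd.Series(np.random.rand(200), name="feature")
gm = ClusterBasedNormalizer(max_clusters=min(10, len(data)), enforce_min_max_values=True)
gm.fit(data.to_frame(), data.name)
encoded = gm.transform(data.to_frame()[[data.name]])  # -> feature.normalized, feature.component
decoded = gm.reverse_transform(encoded)
```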
4 changes: 4 additions & 0 deletions src/synthcity/plugins/core/plugin.py
@@ -35,6 +35,7 @@
from synthcity.plugins.core.schema import Schema
from synthcity.plugins.core.serializable import Serializable
from synthcity.utils.constants import DEVICE
from synthcity.utils.reproducibility import enable_reproducible_results
from synthcity.utils.serialization import load_from_file, save_to_file


@@ -86,6 +87,9 @@ def __init__(
sampling_strategy: str = "marginal", # uniform, marginal
) -> None:
super().__init__()

enable_reproducible_results(random_state)

self._schema: Optional[Schema] = None
self._training_schema: Optional[Schema] = None
self._data_encoders: Optional[Dict] = None
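
The base constructor now seeds the RNGs up front; a rough sketch of what `enable_reproducible_results` is assumed to do (the actual implementation lives in `synthcity.utils.reproducibility`):

```python
# Assumed behaviour: seed the common RNG sources so repeated runs give repeatable results.
import random

import numpy as np
import torch

def enable_reproducible_results(random_state: int = 0) -> None:
    random.seed(random_state)
    np.random.seed(random_state)
    torch.manual_seed(random_state)
```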
25 changes: 22 additions & 3 deletions src/synthcity/plugins/generic/plugin_bayesian_network.py
@@ -4,6 +4,7 @@
"""

# stdlib
from pathlib import Path
from typing import Any, List

# third party
@@ -38,9 +39,16 @@ class BayesianNetworkPlugin(Plugin):
The maximum number of parents for each node.
encoder_max_clusters: int = 10
Data encoding clusters.
encoder_noise_scale: float
encoder_noise_scale: float.
Small noise to add to the final data, to prevent data leakage.
workspace: Path.
Optional Path for caching intermediary results.
compress_dataset: bool. Default = False.
Drop redundant features before training the generator.
random_state: int.
Random seed.
sampling_patience: int.
Max inference iterations to wait for the generated data to match the training schema.
Example:
>>> from sklearn.datasets import load_iris
@@ -64,9 +72,20 @@ def __init__(
struct_max_indegree: int = 4,
encoder_max_clusters: int = 10,
encoder_noise_scale: float = 0.1,
# core plugin
workspace: Path = Path("workspace"),
compress_dataset: bool = False,
random_state: int = 0,
sampling_patience: int = 500,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
super().__init__(
random_state=random_state,
sampling_patience=sampling_patience,
workspace=workspace,
compress_dataset=compress_dataset,
**kwargs,
)

self.struct_learning_n_iter = struct_learning_n_iter
self.struct_learning_search_method = struct_learning_search_method
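
A sketch of instantiating the plugin with the core arguments it now forwards explicitly (the `load_iris` wiring follows the docstring's truncated example; the `fit`/`generate` calls are assumptions):

```python
# Illustrative: pass the core plugin arguments alongside the Bayesian-network-specific ones.
from pathlib import Path

from sklearn.datasets import load_iris
from synthcity.plugins import Plugins

X, y = load_iris(return_X_y=True, as_frame=True)
X["target"] = y

plugin = Plugins().get(
    "bayesian_network",
    struct_max_indegree=4,
    workspace=Path("workspace"),
    compress_dataset=False,
    random_state=0,
    sampling_patience=500,
)
plugin.fit(X)
X_syn = plugin.generate(count=50)
```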