Commit
More tutorials + Colab links (#91)
* user conditionals for survival analysis

* update tests

* debug test

* rework conditionals

* cleanup

* bugfixing

* update decaf

* lower condition

* cleanup

* Update README.md

* cleanup

* Update README.md

* Update README.md

* Update README.md

* survival analysis tutorial

* time series tutorial

* Update README.md

* Update README.md

* bugfixing

* bugfixing

* add DP tutorial

* Update README.md

* lint readme
bcebere authored Jan 8, 2023
1 parent e2c24e6 commit 06daf21
Showing 11 changed files with 1,093 additions and 168 deletions.
11 changes: 7 additions & 4 deletions README.md
@@ -13,7 +13,7 @@
[![Tutorials](https://github.com/vanderschaarlab/synthcity/actions/workflows/test_tutorials.yml/badge.svg)](https://github.com/vanderschaarlab/synthcity/actions/workflows/test_tutorials.yml)
[![Documentation Status](https://readthedocs.org/projects/synthcity/badge/?version=latest)](https://synthcity.readthedocs.io/en/latest/?badge=latest)


[![Test In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Vr2PJswgfFYBkJCm3hhVkuH-9dXnHeYV?usp=sharing)
[![about](https://img.shields.io/badge/about-The%20van%20der%20Schaar%20Lab-blue)](https://www.vanderschaar-lab.com/)
[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://github.com/vanderschaarlab/synthcity/blob/main/LICENSE)
[![Python 3.7+](https://img.shields.io/badge/python-3.7+-blue.svg)](https://www.python.org/downloads/release/python-370/)
@@ -214,9 +214,12 @@ assert syn_model.name() == reloaded.name()

## 📓 Tutorials

- [Tutorial 0: Basics](tutorials/tutorial0_basic_examples.ipynb)
- [Tutorial 1: Write a new plugin](tutorials/tutorial1_add_a_new_plugin.ipynb)
- [Tutorial 2: Benchmarks](tutorials/tutorial2_benchmarks.ipynb)
- [![Test In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Vr2PJswgfFYBkJCm3hhVkuH-9dXnHeYV?usp=sharing) [ Tutorial 0: Getting started with tabular data](https://github.com/vanderschaarlab/synthcity/blob/use_cases/tutorials/tutorial0_basic_examples.ipynb)
- [![Test In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1rTTvV4FT-Ut-rIHoBPXQimiBlZ7zCv59?usp=sharing) [ Tutorial 1: Writing a new plugin](https://github.com/vanderschaarlab/synthcity/blob/use_cases/tutorials/tutorial1_add_a_new_plugin.ipynb)
- [![Test In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1FXpnQ9bpHzEgJgD-9pf_PPN4D80ENilE?usp=sharing) [ Tutorial 2: Benchmarking models](https://github.com/vanderschaarlab/synthcity/blob/use_cases/tutorials/tutorial2_benchmarks.ipynb)
- [![Test In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Wa2CPsbXzbKMPC5fSBhKl00Gi7QqVkse?usp=sharing) [ Tutorial 3: Generating Survival Analysis data](https://github.com/vanderschaarlab/synthcity/blob/use_cases/tutorials/tutorial3_survival_analysis.ipynb)
- [![Test In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jN36GCAKEkjzDlczmQfR7Wbh3yF3cIz5?usp=sharing) [ Tutorial 4: Generating Time Series](https://github.com/vanderschaarlab/synthcity/blob/use_cases/tutorials/tutorial4_time_series.ipynb)
- [![Test In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Nf8d3Y6sXr1uco8MsJA4wb33iFvReL59?usp=sharing) [ Tutorial 5: Generating Data with Differential Privacy Guarantees](https://github.com/vanderschaarlab/synthcity/blob/use_cases/tutorials/tutorial5_differential_privacy.ipynb)
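
The Colab badges above open runnable copies of each notebook. For orientation, below is a minimal sketch of the basic loop all of the tutorials build on; the diabetes demo dataset and `GenericDataLoader` usage are assumptions carried over from Tutorial 0, not part of this diff.

```python
# Minimal sketch of the synthcity workflow the tutorials walk through.
# Assumptions: the sklearn diabetes dataset and GenericDataLoader, as in Tutorial 0.
from sklearn.datasets import load_diabetes

from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import GenericDataLoader

X, y = load_diabetes(return_X_y=True, as_frame=True)
X["target"] = y

loader = GenericDataLoader(X, target_column="target")

syn_model = Plugins().get("marginal_distributions")  # any generic plugin name works here
syn_model.fit(loader)

generated = syn_model.generate(count=100)  # returns a DataLoader
print(generated.dataframe().head())
```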


## 🔑 Methods
44 changes: 44 additions & 0 deletions src/synthcity/benchmark/__init__.py
@@ -7,6 +7,7 @@
from typing import Any, Dict, List, Optional, Tuple

# third party
import numpy as np
import pandas as pd
import torch
from IPython.display import display
@@ -207,3 +208,46 @@ def print(

display(results[plugin].drop(columns=["direction"]))
print()

@staticmethod
@validate_arguments(config=dict(arbitrary_types_allowed=True))
def highlight(
results: Dict,
) -> Any:
pd.set_option("display.max_rows", None, "display.max_columns", None)
means = []
for plugin in results:
data = results[plugin]["mean"]
directions = results[plugin]["direction"].to_dict()
means.append(data)

out = pd.concat(means, axis=1)
out.set_axis(results.keys(), axis=1, inplace=True)

bad_highlight = "background-color: lightcoral;"
ok_highlight = "background-color: green;"
default = ""

def highlights(row: pd.Series) -> Any:
metric = row.name
if directions[metric] == "minimize":
best_val = np.min(row.values)
worst_val = np.max(row)
else:
best_val = np.max(row.values)
worst_val = np.min(row)

styles = []
for val in row.values:
if val == best_val:
styles.append(ok_highlight)
elif val == worst_val:
styles.append(bad_highlight)
else:
styles.append(default)

return styles

return out.style.apply(highlights, axis=1)
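
The new `highlight` helper complements `Benchmarks.print`: it gathers the per-plugin mean scores into one table and colours the best value per metric green and the worst light coral. A rough usage sketch, assuming a notebook context where the returned Styler renders inline (the plugin list and `loader` are illustrative):

```python
# Rough usage sketch for the new Benchmarks.highlight helper (notebook context assumed).
from synthcity.benchmark import Benchmarks

score = Benchmarks.evaluate(
    [
        ("marginal_distributions", "marginal_distributions", {}),
        ("dummy_sampler", "dummy_sampler", {}),
    ],
    loader,  # a DataLoader built earlier, e.g. GenericDataLoader(...)
    synthetic_size=1000,
    repeats=2,
)

Benchmarks.print(score)      # one table per plugin
Benchmarks.highlight(score)  # one combined table, best/worst value per metric colour-coded
```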
6 changes: 3 additions & 3 deletions src/synthcity/plugins/core/models/ts_gan.py
@@ -208,7 +208,7 @@ def __init__(
output_shape=[n_temporal_window, n_temporal_units],
nonlin_out=generator_temporal_nonlin_out,
**rnn_generator_extra_args,
)
).to(self.device)

# Temporal generator from the latent space: Z_temporal -> E_temporal
self.temporal_generator = TimeSeriesModel(
@@ -218,7 +218,7 @@
n_temporal_window=n_temporal_window,
output_shape=[n_temporal_window, n_temporal_units_latent],
**rnn_generator_extra_args,
)
).to(self.device)

# Temporal supervisor: Generate the next sequence: E_temporal -> fake_next_temporal_embeddings_temporal
self.temporal_supervisor = TimeSeriesModel(
@@ -228,7 +228,7 @@
n_temporal_window=n_temporal_window,
output_shape=[n_temporal_window, n_temporal_units_latent],
**rnn_generator_extra_args,
)
).to(self.device)

# Discriminate the original and synthetic time-series data.
self.discriminator = TimeSeriesModel(
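
The three hunks above move each freshly constructed temporal generator/supervisor onto `self.device`, so its parameters live on the same device as the tensors it will receive. The pattern itself is plain PyTorch; a small self-contained illustration follows (the GRU module is a stand-in, not synthcity's `TimeSeriesModel`):

```python
# Illustration of the device-placement pattern added in the hunks above.
# The nn.GRU is a stand-in for synthcity's TimeSeriesModel.
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Construct the submodule and immediately move its parameters to the target device.
sub_module = nn.GRU(input_size=8, hidden_size=16, batch_first=True).to(device)

x = torch.randn(4, 10, 8, device=device)  # (batch, window, features) on the same device
out, _ = sub_module(x)                    # no cross-device mismatch at forward time
print(out.device)
```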
8 changes: 4 additions & 4 deletions src/synthcity/plugins/generic/plugin_dpgan.py
@@ -121,8 +121,8 @@ def __init__(
dataloader_sampler: Optional[sampler.Sampler] = None,
device: Any = DEVICE,
# privacy settings
dp_epsilon: float = 4,
dp_delta: Optional[float] = None,
epsilon: float = 1,
delta: Optional[float] = None,
dp_max_grad_norm: float = 2,
dp_secure_mode: bool = False,
# early stopping
@@ -167,9 +167,9 @@ def __init__(
self.n_iter_print = n_iter_print

# privacy
self.dp_epsilon = epsilon
self.dp_delta = delta
self.dp_enabled = True
self.dp_epsilon = dp_epsilon
self.dp_delta = dp_delta
self.dp_max_grad_norm = dp_max_grad_norm
self.dp_secure_mode = dp_secure_mode

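
This file reworks the plugin's privacy arguments: the budget is passed as `epsilon`/`delta` and stored on the instance alongside `dp_max_grad_norm` and `dp_secure_mode`. A hedged sketch of requesting the plugin with a custom budget follows; the keyword names are read off this hunk and the plugin name "dpgan" is an assumption, so verify both against the installed release.

```python
# Hedged sketch: plugin name and keyword names are assumptions based on the hunk above.
from synthcity.plugins import Plugins

syn_model = Plugins().get(
    "dpgan",
    epsilon=1.0,           # total privacy budget (smaller = stronger privacy)
    delta=1e-5,            # failure probability of the (epsilon, delta)-DP guarantee
    dp_max_grad_norm=2.0,  # per-sample gradient clipping norm for DP-SGD
)

syn_model.fit(loader)      # `loader` is a previously built DataLoader
generated = syn_model.generate(count=100)
```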
2 changes: 1 addition & 1 deletion src/synthcity/version.py
@@ -1,4 +1,4 @@
__version__ = "0.0.2"
__version__ = "0.0.3"

MAJOR_VERSION = ".".join(__version__.split(".")[:-1])
MINOR_VERSION = __version__.split(".")[-1]
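
For the new `0.0.3` string, the two derived constants below the bump evaluate as follows (plain string manipulation, no synthcity APIs involved):

```python
# What the derived version constants evaluate to after the bump.
__version__ = "0.0.3"

MAJOR_VERSION = ".".join(__version__.split(".")[:-1])  # "0.0"
MINOR_VERSION = __version__.split(".")[-1]             # "3"

assert (MAJOR_VERSION, MINOR_VERSION) == ("0.0", "3")
```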
149 changes: 62 additions & 87 deletions tutorials/tutorial0_basic_examples.ipynb
@@ -5,7 +5,9 @@
"id": "97e2d93c",
"metadata": {},
"source": [
"# Tutorial 0: Basic examples"
"# Tutorial 0: Basic examples\n",
"\n",
"`synthcity` supports generating tabular data for various modalities. In this tutorial we will cover the general-purpose models."
]
},
{
@@ -120,7 +122,9 @@
"id": "adf672a5",
"metadata": {},
"source": [
"## Generate new data under some constraints"
"## Generate new data using conditionals\n",
"\n",
"We will condition the generated data using the target data(y)."
]
},
{
@@ -131,16 +135,11 @@
"outputs": [],
"source": [
"# synthcity absolute\n",
"# Constraint: target <= 100\n",
"from synthcity.plugins.core.constraints import Constraints\n",
"\n",
"constraints = Constraints(rules=[(\"target\", \"<=\", 100)])\n",
"\n",
"generated = syn_model.generate(count=10, constraints=constraints)\n",
"from synthcity.plugins import Plugins\n",
"\n",
"assert (generated[\"target\"] <= 100).any()\n",
"syn_model = Plugins().get(\"adsgan\")\n",
"\n",
"generated.dataframe()"
"syn_model.fit(loader, cond=y.to_frame())"
]
},
{
@@ -150,18 +149,12 @@
"metadata": {},
"outputs": [],
"source": [
"# Constraint: target > 150\n",
"\n",
"# synthcity absolute\n",
"from synthcity.plugins.core.constraints import Constraints\n",
"\n",
"constraints = Constraints(rules=[(\"target\", \">\", 150)])\n",
"\n",
"generated = syn_model.generate(count=10, constraints=constraints)\n",
"\n",
"assert (generated[\"target\"] > 150).any()\n",
"import numpy as np\n",
"\n",
"generated.dataframe()"
"count = 10\n",
"syn_model.generate(\n",
" count=count, cond=np.ones(count)\n",
").dataframe() # generate only patients with the outcome = 1"
]
},
{
@@ -225,12 +218,54 @@
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "35e6a9ae",
"cell_type": "markdown",
"id": "64bacb40",
"metadata": {},
"outputs": [],
"source": []
"source": [
"## Benchmarking metrics\n",
"\n",
"| **Metric** | **Description** |\n",
"|----------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------|\n",
"| sanity.data\\_mismatch.score | Data types mismatch between the real//synthetic features |\n",
"| sanity.common\\_rows\\_proportion.score | Real data copy-paste in the synthetic data |\n",
"| sanity.nearest\\_syn\\_neighbor\\_distance.mean | Computes the \\textless{}reduction\\textgreater{}(distance) from the real data to the closest neighbor in the synthetic data |\n",
"| sanity.close\\_values\\_probability.score | the probability of close values between the real and synthetic data. |\n",
"| sanity.distant\\_values\\_probability.score | the probability of distant values between the real and synthetic data. |\n",
"| stats.jensenshannon\\_dist.marginal | the average Jensen-Shannon distance |\n",
"| stats.chi\\_squared\\_test.marginal | the one-way chi-square test. |\n",
"| stats.feature\\_corr.joint | the correlation/strength-of-association of features in data-set with both categorical and continuous features |\n",
"| stats.inv\\_kl\\_divergence.marginal | the average inverse of the Kullback–Leibler Divergence metric. |\n",
"| stats.ks\\_test.marginal | the Kolmogorov-Smirnov test for goodness of fit. |\n",
"| stats.max\\_mean\\_discrepancy.joint | Empirical maximum mean discrepancy. The lower the result the more evidence that distributions are the same. |\n",
"| stats.prdc.precision | precision between the two manifolds |\n",
"| stats.prdc.recall | recall between the two manifolds |\n",
"| stats.prdc.density | density between the two manifolds |\n",
"| stats.prdc.coverage | coverage between the two manifolds |\n",
"| stats.alpha\\_precision.delta\\_precision\\_alpha\\_OC | Delta precision |\n",
"| stats.alpha\\_precision.delta\\_coverage\\_beta\\_OC | Delta coverage |\n",
"| stats.alpha\\_precision.authenticity\\_OC | Authetnticity |\n",
"| performance.linear\\_model.gt.aucroc | Train on real, test on the test real data using LogisticRegression: AUCROC |\n",
"| performance.linear\\_model.syn\\_id.aucroc | Train on synthetic, test on the train real data using LogisticRegression: AUCROC |\n",
"| performance.linear\\_model.syn\\_ood.aucroc | Train on synthetic, test on the test real data using LogisticRegression: AUCROC |\n",
"| performance.mlp.gt.aucroc | Train on real, test on the test real data using NN: AUCROC |\n",
"| performance.mlp.syn\\_id.aucroc | Train on synthetic, test on the train real data using NN: AUCROC |\n",
"| performance.mlp.syn\\_ood.aucroc | Train on synthetic, test on the test real data using NN: AUCROC |\n",
"| performance.xgb.gt.aucroc | Train on real, test on the test real data using XGB: AUCROC |\n",
"| performance.xgb.syn\\_id.aucroc | Train on synthetic, test on the train real data using XGB: AUCROC |\n",
"| performance.xgb.syn\\_ood.aucroc | Train on synthetic, test on the test real data using XGB: AUCROC |\n",
"| performance.feat\\_rank\\_distance.corr | Correlation for the rank distances between the feature importance on real and synthetic data |\n",
"| performance.feat\\_rank\\_distance.pvalue | p-vale for the rank distances between the feature importance on real and synthetic data |\n",
"| detection.detection\\_xgb.mean | The average AUCROC score for detecting synthetic data using an XGBoost. |\n",
"| detection.detection\\_mlp.mean | The average AUCROC score for detecting synthetic data using a NN. |\n",
"| detection.detection\\_gmm.mean | The average AUCROC score for detecting synthetic data using a GMM. |\n",
"| privacy.delta-presence.score | the maximum re-identification probability on the real dataset from the synthetic dataset. |\n",
"| privacy.k-anonymization.gt | the k-anon for the real data |\n",
"| privacy.k-anonymization.syn | the k-anon for the synthetic data |\n",
"| privacy.k-map.score | the minimum value k that satisfies the k-map rule. |\n",
"| privacy.distinct l-diversity.gt | the l-diversity for the real data |\n",
"| privacy.distinct l-diversity.syn | the l-diversity for the synthetic data |\n",
"| privacy.identifiability\\_score.score | the re-identification score on the real dataset from the synthetic dataset. |"
]
},
{
"cell_type": "markdown",
@@ -252,16 +287,13 @@
"# synthcity absolute\n",
"from synthcity.benchmark import Benchmarks\n",
"\n",
"constraints = Constraints(rules=[(\"target\", \"ge\", 150)])\n",
"\n",
"score = Benchmarks.evaluate(\n",
" [\n",
" (\"marginal_distributions\", \"marginal_distributions\", {}),\n",
" (\"dummy_sampler\", \"dummy_sampler\", {}),\n",
" ],\n",
" loader,\n",
" synthetic_size=1000,\n",
" synthetic_constraints=constraints,\n",
" repeats=2,\n",
")"
]
@@ -278,63 +310,6 @@
"Benchmarks.print(score)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7fa5f89a",
"metadata": {},
"outputs": [],
"source": [
"# third party\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"means = []\n",
"for plugin in score:\n",
" data = score[plugin][\"mean\"]\n",
" directions = score[plugin][\"direction\"].to_dict()\n",
" means.append(data)\n",
"\n",
"out = pd.concat(means, axis=1)\n",
"out.set_axis(score.keys(), axis=1, inplace=True)\n",
"\n",
"bad_highlight = \"background-color: lightcoral;\"\n",
"ok_highlight = \"background-color: green;\"\n",
"default = \"\"\n",
"\n",
"\n",
"def highlights(row):\n",
" metric = row.name\n",
" if directions[metric] == \"minimize\":\n",
" best_val = np.min(row.values)\n",
" worst_val = np.max(row)\n",
" else:\n",
" best_val = np.max(row.values)\n",
" worst_val = np.min(row)\n",
"\n",
" styles = []\n",
" for val in row.values:\n",
" if val == best_val:\n",
" styles.append(ok_highlight)\n",
" elif val == worst_val:\n",
" styles.append(bad_highlight)\n",
" else:\n",
" styles.append(default)\n",
"\n",
" return styles\n",
"\n",
"\n",
"out.style.apply(highlights, axis=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f2912f1",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "cbdb2a3e",
@@ -360,7 +335,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.9.15"
}
},
"nbformat": 4,
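
The reworked Tutorial 0 drops the constraint-based cells in favour of conditional generation: the model is fitted with the target as a conditional and then sampled with the condition fixed. Assembled from the notebook cells above (the `loader` and the Series `y` are assumed to come from the earlier data-loading cells), the new flow is roughly:

```python
# Conditional generation flow from the reworked Tutorial 0 (assembled from the cells above).
# Assumption: `loader` wraps the tabular dataset and `y` is its target column as a Series.
import numpy as np

from synthcity.plugins import Plugins

syn_model = Plugins().get("adsgan")

# Fit with the target as the conditioning variable.
syn_model.fit(loader, cond=y.to_frame())

# Generate only records with outcome = 1 by fixing the condition at sampling time.
count = 10
generated = syn_model.generate(count=count, cond=np.ones(count))
generated.dataframe()
```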
18 changes: 1 addition & 17 deletions tutorials/tutorial1_add_a_new_plugin.ipynb
@@ -296,22 +296,6 @@
"\n",
"generated.dataframe()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0c09b449",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "d3fff88a",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -330,7 +314,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.9.15"
}
},
"nbformat": 4,
