diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 9f304f7d..9f9362f5 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -44,6 +44,7 @@ jobs:
- uses: actions/checkout@v2
with:
submodules: true
+ - uses: gautamkrishnar/keepalive-workflow@v1
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
diff --git a/.github/workflows/test_tutorials.yml b/.github/workflows/test_tutorials.yml
index 08285c14..4be2288b 100644
--- a/.github/workflows/test_tutorials.yml
+++ b/.github/workflows/test_tutorials.yml
@@ -20,6 +20,7 @@ jobs:
- uses: actions/checkout@v2
with:
submodules: true
+ - uses: gautamkrishnar/keepalive-workflow@v1
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
diff --git a/.gitignore b/.gitignore
index 062f9112..41f36b84 100644
--- a/.gitignore
+++ b/.gitignore
@@ -65,3 +65,5 @@ data
checkpoints
lightning_logs
generated
+MNIST
+cifar-10*
diff --git a/README.md b/README.md
index 52c4794b..8c09c72a 100644
--- a/README.md
+++ b/README.md
@@ -32,10 +32,11 @@
- :cyclone: Several evaluation metrics for correctness and privacy.
- :fire: Several reference models, by type:
- General purpose: GAN-based (AdsGAN, CTGAN, PATEGAN, DP-GAN),VAE-based(TVAE, RTVAE), Normalizing flows, Bayesian Networks(PrivBayes, BN).
- - Time Series generators: TimeGAN, FourierFlows, Probabilistic autoregressive.
- - Survival Analysis: SurvivalGAN, SurVAE.
+ - Time Series & Time-Series Survival generators: TimeGAN, FourierFlows, TimeVAE.
+ - Static Survival Analysis: SurvivalGAN, SurVAE.
- Privacy-focused: DECAF, DP-GAN, AdsGAN, PATEGAN, PrivBayes.
- Domain adaptation: RadialGAN.
+ - Images: Image ConditionalGAN, Image AdsGAN.
- :book: [Read the docs !](https://synthcity.readthedocs.io/)
- :airplane: [Checkout the tutorials!](https://github.com/vanderschaarlab/synthcity#-tutorials)
@@ -110,7 +111,7 @@ score = Benchmarks.evaluate(
Benchmarks.print(score)
```
-### Survival analysis
+### Static Survival analysis
* List the available generators dedicated to survival analysis
@@ -172,6 +173,37 @@ syn_model = Plugins().get("timegan")
syn_model.fit(data)
syn_model.generate(count=10)
+```
+### Images
+
+__Note__ : The architectures used for generators are not state-of-the-art. For other architectures, consider extending the `suggest_image_generator_discriminator_arch` method from the `convnet.py` module.
+
+* List the available generators
+
+```python
+from synthcity.plugins import Plugins
+
+Plugins(categories=["images"]).list()
+```
+
+* Generate new data
+```python
+from synthcity.plugins import Plugins
+from synthcity.plugins.core.dataloader import ImageDataLoader
+from torchvision import datasets
+
+
+dataset = datasets.MNIST(".", download=True)
+loader = ImageDataLoader(dataset).sample(100)
+
+syn_model = Plugins().get("image_cgan")
+
+syn_model.fit(loader)
+
+syn_img, syn_labels = syn_model.generate(count=10).unpack().numpy()
+
+print(syn_img.shape)
+
```
### Serialization
@@ -225,7 +257,8 @@ assert syn_model.name() == reloaded.name()
- [data:image/s3,"s3://crabby-images/e7985/e79852128a5f83c92496b9d734ca52d01e009a39" alt="Test In Colab"](https://colab.research.google.com/drive/1Wa2CPsbXzbKMPC5fSBhKl00Gi7QqVkse?usp=sharing) [ Tutorial 3: Generating Survival Analysis data](https://github.com/vanderschaarlab/synthcity/blob/main/tutorials/tutorial3_survival_analysis.ipynb)
- [data:image/s3,"s3://crabby-images/e7985/e79852128a5f83c92496b9d734ca52d01e009a39" alt="Test In Colab"](https://colab.research.google.com/drive/1jN36GCAKEkjzDlczmQfR7Wbh3yF3cIz5?usp=sharing) [ Tutorial 4: Generating Time Series](https://github.com/vanderschaarlab/synthcity/blob/main/tutorials/tutorial4_time_series.ipynb)
- [data:image/s3,"s3://crabby-images/e7985/e79852128a5f83c92496b9d734ca52d01e009a39" alt="Test In Colab"](https://colab.research.google.com/drive/1Nf8d3Y6sXr1uco8MsJA4wb33iFvReL59?usp=sharing) [ Tutorial 5: Generating Data with Differential Privacy Guarantees](https://github.com/vanderschaarlab/synthcity/blob/main/tutorials/tutorial5_differential_privacy.ipynb)
-
+ - [ Tutorial 6 for using Custom Time series data](https://github.com/vanderschaarlab/synthcity/blob/main/tutorials/tutorial6_time_series_data_preparation.ipynb)
+ - [data:image/s3,"s3://crabby-images/e7985/e79852128a5f83c92496b9d734ca52d01e009a39" alt="Test In Colab"](https://colab.research.google.com/drive/1gEUSGxVmts9C0cDBKbq7Ees-wvQpUBoo?usp=sharing) [ Tutorial 7: Image generation](https://github.com/vanderschaarlab/synthcity/blob/main/tutorials/tutorial7_image_generation_using_mednist.ipynb)
## 🔑 Methods
@@ -291,6 +324,13 @@ assert syn_model.name() == reloaded.name()
|--- | --- | --- |
|**radialgan** | Training complex machine learning models for prediction often requires a large amount of data that is not always readily available. Leveraging these external datasets from related but different sources is, therefore, an essential task if good predictive models are to be built for deployment in settings where data can be rare. RadialGAN is an approach to the problem in which multiple GAN architectures are used to learn to translate from one dataset to another, thereby allowing to augment the target dataset effectively and learning better predictive models than just the target dataset. | [RadialGAN: Leveraging multiple datasets to improve target-specific predictive models using Generative Adversarial Networks](https://arxiv.org/abs/1802.06403) |
+### Images
+
+| Method | Description | Reference |
+|--- | --- | --- |
+|**image_cgan**| Conditional GAN for generating images| --- |
+|**image_adsgan**| The AdsGAN method adapted for image generation| --- |
+
### Debug methods
@@ -328,6 +368,7 @@ The following table contains the available evaluation metrics:
|**prdc**| Computes precision, recall, density, and coverage given two manifolds. | --- |
|**alpha_precision**|Evaluate the alpha-precision, beta-recall, and authenticity scores. | --- |
|**survival_km_distance**|The distance between two Kaplan-Meier plots(survival analysis). | --- |
+|**fid**|The Frechet Inception Distance (FID) calculates the distance between two distributions of images. | --- |
diff --git a/docs/README.md b/docs/README.md
index 7bf01012..267feb14 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -12,10 +12,11 @@
- |:cyclone:| Several evaluation metrics for correctness and privacy.
- |:fire:| Several reference models, by type:
- General purpose: GAN-based (AdsGAN, CTGAN, PATEGAN, DP-GAN),VAE-based(TVAE, RTVAE), Normalizing flows, Bayesian Networks(PrivBayes, BN).
- - Time Series generators: TimeGAN, FourierFlows, Probabilistic autoregressive.
- - Survival Analysis: SurvivalGAN, SurVAE.
+ - Time Series & Time-Series Survival generators: TimeGAN, FourierFlows, TimeVAE.
+ - Static Survival Analysis: SurvivalGAN, SurVAE.
- Privacy-focused: DECAF, DP-GAN, AdsGAN, PATEGAN, PrivBayes.
- Domain adaptation: RadialGAN.
+ - Images: Image ConditionalGAN, Image AdsGAN.
- |:book:| [Read the docs !](https://synthcity.readthedocs.io/)
- |:airplane:| [Checkout the tutorials!](https://github.com/vanderschaarlab/synthcity#-tutorials)
@@ -152,6 +153,37 @@ syn_model = Plugins().get("timegan")
syn_model.fit(data)
syn_model.generate(count=10)
+```
+### Images
+
+__Note__ : The architectures used for generators are not state-of-the-art. For other architectures, consider extending the `suggest_image_generator_discriminator_arch` method from the `convnet.py` module.
+
+* List the available generators
+
+```python
+from synthcity.plugins import Plugins
+
+Plugins(categories=["images"]).list()
+```
+
+* Generate new data
+```python
+from synthcity.plugins import Plugins
+from synthcity.plugins.core.dataloader import ImageDataLoader
+from torchvision import datasets
+
+
+dataset = datasets.MNIST(".", download=True)
+loader = ImageDataLoader(dataset).sample(100)
+
+syn_model = Plugins().get("image_cgan")
+
+syn_model.fit(loader)
+
+syn_img, syn_labels = syn_model.generate(count=10).unpack().numpy()
+
+print(syn_img.shape)
+
```
### Serialization
@@ -170,6 +202,20 @@ reloaded = load(buff)
assert syn_model.name() == reloaded.name()
```
+* Saving and loading models from disk
+
+```python
+from synthcity.utils.serialization import save_to_file, load_from_file
+from synthcity.plugins import Plugins
+
+syn_model = Plugins().get("adsgan")
+
+save_to_file('./adsgan_10_epochs.pkl', syn_model)
+reloaded = load_from_file('./adsgan_10_epochs.pkl')
+
+assert syn_model.name() == reloaded.name()
+```
+
* Using the Serializable interface
```python
@@ -247,6 +293,13 @@ assert syn_model.name() == reloaded.name()
|--- | --- | --- |
|**radialgan** | Training complex machine learning models for prediction often requires a large amount of data that is not always readily available. Leveraging these external datasets from related but different sources is, therefore, an essential task if good predictive models are to be built for deployment in settings where data can be rare. RadialGAN is an approach to the problem in which multiple GAN architectures are used to learn to translate from one dataset to another, thereby allowing to augment the target dataset effectively and learning better predictive models than just the target dataset. | [RadialGAN: Leveraging multiple datasets to improve target-specific predictive models using Generative Adversarial Networks](https://arxiv.org/abs/1802.06403) |
+### Images
+
+| Method | Description | Reference |
+|--- | --- | --- |
+|**image_cgan**| Conditional GAN for generating images| --- |
+|**image_adsgan**| The AdsGAN method adapted for image generation| --- |
+
### Debug methods
@@ -284,6 +337,7 @@ The following table contains the available evaluation metrics:
|**prdc**| Computes precision, recall, density, and coverage given two manifolds. | --- |
|**alpha_precision**|Evaluate the alpha-precision, beta-recall, and authenticity scores. | --- |
|**survival_km_distance**|The distance between two Kaplan-Meier plots(survival analysis). | --- |
+|**fid**|The Frechet Inception Distance (FID) calculates the distance between two distributions of images. | --- |
@@ -293,7 +347,7 @@ The following table contains the available evaluation metrics:
|--- | --- | --- |
|**performance.xgb**|Train an XGBoost classifier/regressor/survival model on real data(gt) and synthetic data(syn), and evaluate the performance on the test set. | 1 for ideal performance, 0 for worst performance |
|**performance.linear**|Train a Linear classifier/regressor/survival model on real data(gt) and the synthetic data and evaluate the performance on test data.| 1 for ideal performance, 0 for worst performance |
-|**performance.mlp**|Train a Neural Net classifier/regressor/survival model on the read data and the synthetic data and evaluate the performance on test data.| 1 for ideal performance, 0 for worst performance |
+|**performance.mlp**|Train a Neural Net classifier/regressor/survival model on the real data and the synthetic data and evaluate the performance on test data.| 1 for ideal performance, 0 for worst performance |
|**performance.feat_rank_distance**| Train a model on the synthetic data and a model on the real data. Compute the feature importance of the models on the same test data, and compute the rank distance between the importance(kendalltau or spearman)| 1: similar ranks in the feature importance. 0: uncorrelated feature importance |
|**detection_gmm**|Train a GaussianMixture model to differentiate the synthetic data from the real data.|0: The datasets are indistinguishable.
1: The datasets are totally distinguishable.|
|**detection_xgb**|Train an XGBoost model to differentiate the synthetic data from the real data.|0: The datasets are indistinguishable.
1: The datasets are totally distinguishable.|
diff --git a/docs/advanced.rst b/docs/advanced.rst
index bed1e89f..886fae58 100644
--- a/docs/advanced.rst
+++ b/docs/advanced.rst
@@ -51,7 +51,6 @@ Time-series survival models
:maxdepth: 2
Time-series CoxPH
- DeepCoxPH
Dynamic DeepHit
Time-Series XGBoost
@@ -63,3 +62,12 @@ Time-to-event models
DATE
Survival function regression
+
+Images
+-------------------------------
+.. toctree::
+ :glob:
+ :maxdepth: 2
+
+ ConvNets
+ ImageGAN
diff --git a/docs/conf.py b/docs/conf.py
index 35ab9b8b..25e964ac 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -1,15 +1,3 @@
-# Configuration file for the Sphinx documentation builder.
-#
-# This file only contains a selection of the most common options. For a full
-# list see the documentation:
-# https://www.sphinx-doc.org/en/master/usage/configuration.html
-
-# -- Path setup --------------------------------------------------------------
-
-# If extensions (or modules to document with autodoc) are in another directory,
-# add these directories to sys.path here. If the directory is relative to the
-# documentation root, use os.path.abspath to make it absolute, like shown here.
-#
import os
import sys
import subprocess
@@ -152,7 +140,6 @@
"pycox",
"pykeops",
"pyod",
- "pydantic",
"rdt",
"redis",
"scikit-learn",
@@ -165,7 +152,6 @@
"joblib",
"sdv",
"shap",
- "torch",
"tsai",
"xgboost",
"xgbse",
diff --git a/docs/dataloaders.rst b/docs/dataloaders.rst
new file mode 100644
index 00000000..a55a3b64
--- /dev/null
+++ b/docs/dataloaders.rst
@@ -0,0 +1,11 @@
+Datasets and DataLoaders
+=========================
+
+Dataloaders and Datasets
+-------------------------
+.. toctree::
+ :glob:
+ :maxdepth: 2
+
+ Dataloaders
+ Datasets
diff --git a/docs/examples.rst b/docs/examples.rst
index 4116eecc..034281c2 100644
--- a/docs/examples.rst
+++ b/docs/examples.rst
@@ -13,6 +13,8 @@ Getting started
Generating Survival Analysis Data
Generating Time Series
Generating Data with Differential Privacy Guarantees
+ Using custom Time-series Datasets
+ Generating Images
General-purpose generators
---------------------------
@@ -55,3 +57,12 @@ Domain adaptation generators
:maxdepth: 2
RadialGAN
+
+Images
+------------------------------
+.. toctree::
+ :glob:
+ :maxdepth: 2
+
+ Image CGAN
+ Image AdsGAN
diff --git a/docs/generators.rst b/docs/generators.rst
index 0c4572c9..d78a30fd 100644
--- a/docs/generators.rst
+++ b/docs/generators.rst
@@ -52,5 +52,13 @@ Time-series & Time-Series Survival Analysis
TimeGAN
FourierFlows
- Probabilistic AutoRegressive
TimeVAE
+
+Images
+----------------------------------------------
+.. toctree::
+ :glob:
+ :maxdepth: 2
+
+ ImageCGAN
+ Image AdsGAN
diff --git a/docs/index.rst b/docs/index.rst
index d8215099..e2768326 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -20,13 +20,13 @@ Examples
examples.rst
-Dataloaders
-============
+Dataloaders and Datasets
+==========================
.. toctree::
:glob:
:maxdepth: 3
- Dataloaders
+ dataloaders.rst
Generators
==========
diff --git a/setup.cfg b/setup.cfg
index 9b25e1d1..929aafa6 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -19,7 +19,7 @@ platforms = any
classifiers =
Programming Language :: Python :: 3
Topic :: Scientific/Engineering :: Artificial Intelligence
- Intended Audience :: Healthcare Industry
+ Intended Audience :: Science/Research
Operating System :: OS Independent
[options]
@@ -55,6 +55,7 @@ install_requires =
xgbse
pykeops
fflows
+ monai
tsai; python_version>"3.7"
importlib-metadata; python_version<"3.8"
diff --git a/src/synthcity/__init__.py b/src/synthcity/__init__.py
index 24672e5e..3d8d7988 100644
--- a/src/synthcity/__init__.py
+++ b/src/synthcity/__init__.py
@@ -23,4 +23,4 @@
warnings.simplefilter(action="ignore")
-logger.add(sink=sys.stderr, level="ERROR")
+logger.add(sink=sys.stderr, level="CRITICAL")
diff --git a/src/synthcity/benchmark/__init__.py b/src/synthcity/benchmark/__init__.py
index 9efaabd8..08bb024a 100644
--- a/src/synthcity/benchmark/__init__.py
+++ b/src/synthcity/benchmark/__init__.py
@@ -9,7 +9,6 @@
# third party
import numpy as np
import pandas as pd
-import torch
from IPython.display import display
from pydantic import validate_arguments
@@ -20,7 +19,7 @@
from synthcity.plugins import Plugins
from synthcity.plugins.core.constraints import Constraints
from synthcity.plugins.core.dataloader import DataLoader
-from synthcity.utils.reproducibility import enable_reproducible_results
+from synthcity.utils.reproducibility import clear_cache, enable_reproducible_results
from synthcity.utils.serialization import load_from_file, save_to_file
@@ -96,7 +95,9 @@ def evaluate(
workspace.mkdir(parents=True, exist_ok=True)
plugin_cats = ["generic", "privacy"]
- if task_type == "survival_analysis":
+ if X.type() == "images":
+ plugin_cats.append("images")
+ elif task_type == "survival_analysis":
plugin_cats.append("survival_analysis")
elif task_type == "time_series" or task_type == "time_series_survival":
plugin_cats.append("time_series")
@@ -123,7 +124,7 @@ def evaluate(
kwargs["workspace"] = workspace
kwargs["random_state"] = repeat
- torch.cuda.empty_cache()
+ clear_cache()
cache_file = (
workspace
diff --git a/src/synthcity/metrics/eval.py b/src/synthcity/metrics/eval.py
index fb22c9d8..89c03142 100644
--- a/src/synthcity/metrics/eval.py
+++ b/src/synthcity/metrics/eval.py
@@ -43,6 +43,7 @@
from .eval_statistical import (
AlphaPrecision,
ChiSquaredTest,
+ FrechetInceptionDistance,
InverseKLDivergence,
JensenShannonDistance,
KolmogorovSmirnovTest,
@@ -70,6 +71,7 @@
PRDCScore,
AlphaPrecision,
SurvivalKMDistance,
+ FrechetInceptionDistance,
# performance tests
PerformanceEvaluatorLinear,
PerformanceEvaluatorMLP,
diff --git a/src/synthcity/metrics/eval_detection.py b/src/synthcity/metrics/eval_detection.py
index 7629fa05..b48a848f 100644
--- a/src/synthcity/metrics/eval_detection.py
+++ b/src/synthcity/metrics/eval_detection.py
@@ -4,6 +4,7 @@
# third party
import numpy as np
+import torch
from pydantic import validate_arguments
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
@@ -15,7 +16,10 @@
import synthcity.logger as log
from synthcity.metrics.core import MetricEvaluator
from synthcity.plugins.core.dataloader import DataLoader
+from synthcity.plugins.core.dataset import NumpyDataset
+from synthcity.plugins.core.models.convnet import suggest_image_classifier_arch
from synthcity.plugins.core.models.mlp import MLP
+from synthcity.utils.reproducibility import clear_cache
from synthcity.utils.serialization import load_from_file, save_to_file
@@ -49,8 +53,16 @@ def type() -> str:
def direction() -> str:
return "minimize"
+ @staticmethod
+ def name() -> str:
+ raise NotImplementedError()
+
+ @validate_arguments(config=dict(arbitrary_types_allowed=True))
+ def evaluate(self, X_gt: DataLoader, X_syn: DataLoader) -> Dict:
+ raise NotImplementedError()
+
@validate_arguments(config=dict(arbitrary_types_allowed=True))
- def _evaluate_detection(
+ def _evaluate_detection_generic(
self,
model_template: Any,
X_gt: DataLoader,
@@ -63,13 +75,15 @@ def _evaluate_detection(
)
if self.use_cache(cache_file):
results = load_from_file(cache_file)
- log.info(f" Detection eval for {self.name()} : {results}")
+ log.info(
+ f" Synthetic-real data discrimination using {self.name()}. AUCROC : {results}"
+ )
return results
- arr_gt = X_gt.numpy()
+ arr_gt = X_gt.numpy().reshape(len(X_gt), -1)
labels_gt = np.asarray([0] * len(X_gt))
- arr_syn = X_syn.numpy()
+ arr_syn = X_syn.numpy().reshape(len(X_syn), -1)
labels_syn = np.asarray([1] * len(X_syn))
data = np.concatenate([arr_gt, arr_syn])
@@ -96,7 +110,9 @@ def _evaluate_detection(
res.append(score)
results = {self._reduction: float(self.reduction()(res))}
- log.info(f" Detection eval for {self.name()} : {results}")
+ log.info(
+ f" Synthetic-real data discrimination using {self.name()}. AUCROC : {results}"
+ )
save_to_file(cache_file, results)
@@ -143,7 +159,9 @@ def evaluate(self, X_gt: DataLoader, X_syn: DataLoader) -> Dict:
"random_state": self._random_state,
}
- return self._evaluate_detection(model_template, X_gt, X_syn, **model_args)
+ return self._evaluate_detection_generic(
+ model_template, X_gt, X_syn, **model_args
+ )
class SyntheticDetectionMLP(DetectionEvaluator):
@@ -168,15 +186,73 @@ def __init__(self, **kwargs: Any) -> None:
def name() -> str:
return "detection_mlp"
+ @validate_arguments(config=dict(arbitrary_types_allowed=True))
+ def _evaluate_image_detection(self, X_gt: DataLoader, X_syn: DataLoader) -> Dict:
+ clear_cache()
+
+ cache_file = (
+ self._workspace
+ / f"sc_metric_cache_{self.type()}_{self.name()}_{X_gt.hash()}_{X_syn.hash()}_{self._reduction}_{platform.python_version()}.bkp"
+ )
+ if self.use_cache(cache_file):
+ results = load_from_file(cache_file)
+ log.info(
+ f" Synthetic-real data discrimination using {self.name()}. AUCROC : {results}"
+ )
+ return results
+
+ data_gt = X_gt.numpy()
+ data_syn = X_syn.numpy()
+ data = np.concatenate([data_gt, data_syn], axis=0)
+
+ labels_gt = np.asarray([0] * len(X_gt))
+ labels_syn = np.asarray([1] * len(X_syn))
+ labels = np.concatenate([labels_gt, labels_syn])
+
+ skf = StratifiedKFold(
+ n_splits=self._n_folds, shuffle=True, random_state=self._random_state
+ )
+ res = []
+ for train_idx, test_idx in skf.split(data, labels):
+ train_X = data[train_idx]
+ train_y = labels[train_idx]
+ test_X = data[test_idx]
+ test_y = labels[test_idx]
+
+ clf = suggest_image_classifier_arch(
+ n_channels=X_gt.info()["channels"],
+ height=X_gt.info()["height"],
+ width=X_gt.info()["width"],
+ classes=2,
+ )
+ train_dataset = NumpyDataset(train_X, train_y)
+
+ clf.fit(train_dataset)
+ test_pred = clf.predict_proba(torch.from_numpy(test_X))[:, 1].cpu().numpy()
+
+ score = roc_auc_score(test_y, test_pred)
+ res.append(score)
+
+ results = {self._reduction: float(self.reduction()(res))}
+ log.info(
+ f" Synthetic-real data discrimination using {self.name()}. AUCROC : {results}"
+ )
+
+ save_to_file(cache_file, results)
+ return results
+
@validate_arguments(config=dict(arbitrary_types_allowed=True))
def evaluate(self, X_gt: DataLoader, X_syn: DataLoader) -> Dict:
+ if X_gt.type() == "images":
+ return self._evaluate_image_detection(X_gt, X_syn)
+
model_args = {
"task_type": "classification",
"n_units_in": X_gt.shape[1],
"n_units_out": 2,
"random_state": self._random_state,
}
- return self._evaluate_detection(
+ return self._evaluate_detection_generic(
MLP,
X_gt,
X_syn,
@@ -213,7 +289,7 @@ def evaluate(self, X_gt: DataLoader, X_syn: DataLoader) -> Dict:
"n_jobs": -1,
"max_iter": 10000,
}
- return self._evaluate_detection(
+ return self._evaluate_detection_generic(
LogisticRegression,
X_gt,
X_syn,
@@ -253,7 +329,7 @@ def evaluate(
"n_components": min(10, len(X_gt)),
"random_state": self._random_state,
}
- return self._evaluate_detection(
+ return self._evaluate_detection_generic(
GaussianMixture,
X_gt,
X_syn,
diff --git a/src/synthcity/metrics/eval_performance.py b/src/synthcity/metrics/eval_performance.py
index 2c4923f0..89803984 100644
--- a/src/synthcity/metrics/eval_performance.py
+++ b/src/synthcity/metrics/eval_performance.py
@@ -7,6 +7,7 @@
import numpy as np
import pandas as pd
import shap
+import torch
from pydantic import validate_arguments
from scipy.stats import kendalltau, spearmanr
from sklearn.linear_model import LinearRegression, LogisticRegression
@@ -20,6 +21,7 @@
from synthcity.metrics._utils import evaluate_auc
from synthcity.metrics.core import MetricEvaluator
from synthcity.plugins.core.dataloader import DataLoader
+from synthcity.plugins.core.models.convnet import suggest_image_classifier_arch
from synthcity.plugins.core.models.mlp import MLP
from synthcity.plugins.core.models.survival_analysis import (
CoxPHSurvivalAnalysis,
@@ -159,6 +161,9 @@ def _evaluate_standard_performance(
Returns:
gt and syn performance scores
"""
+ if X_gt.type() == "images":
+ raise ValueError("Standard evaluation not supported for images")
+
cache_file = (
self._workspace
/ f"sc_metric_cache_{self.type()}_{self.name()}_{X_gt.hash()}_{X_syn.hash()}_{self._reduction}_{platform.python_version()}_{platform.python_version()}.bkp"
@@ -232,6 +237,9 @@ def _evaluate_survival_model(
Returns:
gt and syn performance scores
"""
+ if X_gt.type() == "images":
+ raise ValueError("Survival analysis evaluation not supported for images")
+
if X_gt.type() != "survival_analysis" or X_syn.type() != "survival_analysis":
raise ValueError(
f"Invalid data types. gt = {X_gt.type()} syn = {X_syn.type()}"
@@ -339,6 +347,9 @@ def _evaluate_time_series_performance(
Returns:
gt and syn performance scores
"""
+ if X_gt.type() == "images":
+ raise ValueError("Time series evaluation not supported for images")
+
if X_gt.type() != "time_series" or X_syn.type() != "time_series":
raise ValueError(
f"Invalid data type gt = {X_gt.type()} syn = {X_syn.type()}"
@@ -757,6 +768,92 @@ class PerformanceEvaluatorMLP(PerformanceEvaluator):
def name() -> str:
return "mlp"
+ @validate_arguments(config=dict(arbitrary_types_allowed=True))
+ def _evaluate_image_clf(
+ self,
+ train_data: torch.utils.data.Dataset,
+ test_data: torch.utils.data.Dataset,
+ input_info: Dict,
+ n_classes: int,
+ ) -> float:
+ _, train_Y = train_data.numpy()
+ test_X, test_Y = test_data.numpy()
+
+ clf = suggest_image_classifier_arch(
+ n_channels=input_info["channels"],
+ height=input_info["height"],
+ width=input_info["width"],
+ classes=n_classes,
+ )
+
+ clf.fit(train_data)
+ test_pred = clf.predict_proba(torch.from_numpy(test_X)).cpu().numpy()
+
+ score, _ = evaluate_auc(test_Y, test_pred)
+ return score
+
+ @validate_arguments(config=dict(arbitrary_types_allowed=True))
+ def _evaluate_images(
+ self,
+ X_gt: DataLoader,
+ X_syn: DataLoader,
+ ) -> Dict:
+ cache_file = (
+ self._workspace
+ / f"sc_metric_cache_{self.type()}_{self.name()}_{X_gt.hash()}_{X_syn.hash()}_{self._reduction}_{platform.python_version()}_{platform.python_version()}.bkp"
+ )
+ if self.use_cache(cache_file):
+ return load_from_file(cache_file)
+
+ id_gt = X_gt.train().unpack()
+ id_X_gt, id_y_gt = id_gt.numpy()
+
+ n_classes = len(np.unique(id_y_gt))
+
+ ood_gt = X_gt.test().unpack()
+ iter_syn = X_syn.unpack()
+ iter_X_syn, iter_y_syn = iter_syn.numpy()
+
+ skf = StratifiedKFold(
+ n_splits=self._n_folds, shuffle=True, random_state=self._random_state
+ )
+
+ real_scores = []
+ syn_scores_id = []
+ syn_scores_ood = []
+
+ for train_idx, test_idx in skf.split(id_X_gt, id_y_gt):
+ train_data = id_gt.filter_indices(train_idx)
+ test_data = id_gt.filter_indices(test_idx)
+
+ real_score = self._evaluate_image_clf(
+ train_data, test_data, X_gt.info(), n_classes=n_classes
+ )
+ synth_score_id = self._evaluate_image_clf(
+ iter_syn,
+ test_data,
+ X_syn.info(),
+ n_classes=n_classes,
+ ) # data seen by the generator
+
+ synth_score_ood = self._evaluate_image_clf(
+ iter_syn, ood_gt, X_syn.info(), n_classes=n_classes
+ ) # data not seen by the generator
+
+ real_scores.append(real_score)
+ syn_scores_id.append(synth_score_id)
+ syn_scores_ood.append(synth_score_ood)
+
+ results = {
+ "gt": float(self.reduction()(real_scores)),
+ "syn_id": float(self.reduction()(syn_scores_id)),
+ "syn_ood": float(self.reduction()(syn_scores_ood)),
+ }
+
+ save_to_file(cache_file, results)
+
+ return results
+
@validate_arguments(config=dict(arbitrary_types_allowed=True))
def evaluate(
self,
@@ -769,6 +866,9 @@ def evaluate(
)
elif self._task_type == "classification" or self._task_type == "regression":
+ if X_gt.type() == "images":
+ return self._evaluate_images(X_gt, X_syn)
+
mlp_args = {
"n_units_in": X_gt.shape[1] - 1,
"n_units_out": 1,
diff --git a/src/synthcity/metrics/eval_privacy.py b/src/synthcity/metrics/eval_privacy.py
index d2ad143f..4531321d 100644
--- a/src/synthcity/metrics/eval_privacy.py
+++ b/src/synthcity/metrics/eval_privacy.py
@@ -99,6 +99,9 @@ def evaluate_data(self, X: DataLoader) -> int:
@validate_arguments(config=dict(arbitrary_types_allowed=True))
def _evaluate(self, X_gt: DataLoader, X_syn: DataLoader) -> Dict:
+ if X_gt.type() == "images":
+ raise ValueError(f"Metric {self.name()} doesn't support images")
+
return {
"gt": self.evaluate_data(X_gt),
"syn": (self.evaluate_data(X_syn) + 1e-8),
@@ -149,6 +152,9 @@ def evaluate_data(self, X: DataLoader) -> int:
@validate_arguments(config=dict(arbitrary_types_allowed=True))
def _evaluate(self, X_gt: DataLoader, X_syn: DataLoader) -> Dict:
+ if X_gt.type() == "images":
+ raise ValueError(f"Metric {self.name()} doesn't support images")
+
return {
"gt": self.evaluate_data(X_gt),
"syn": (self.evaluate_data(X_syn) + 1e-8),
@@ -178,6 +184,9 @@ def direction() -> str:
@validate_arguments(config=dict(arbitrary_types_allowed=True))
def _evaluate(self, X_gt: DataLoader, X_syn: DataLoader) -> Dict:
+ if X_gt.type() == "images":
+ raise ValueError(f"Metric {self.name()} doesn't support images")
+
features = get_features(X_gt, X_gt.sensitive_features)
values = []
@@ -220,6 +229,9 @@ def direction() -> str:
@validate_arguments(config=dict(arbitrary_types_allowed=True))
def _evaluate(self, X_gt: DataLoader, X_syn: DataLoader) -> Dict:
+ if X_gt.type() == "images":
+ raise ValueError(f"Metric {self.name()} doesn't support images")
+
features = get_features(X_gt, X_gt.sensitive_features)
values = []
@@ -303,8 +315,8 @@ def _compute_scores(
Returns:
WD_value: Wasserstein distance
"""
- X_gt_ = X_gt.numpy()
- X_syn_ = X_syn.numpy()
+ X_gt_ = X_gt.numpy().reshape(len(X_gt), -1)
+ X_syn_ = X_syn.numpy().reshape(len(X_syn), -1)
if emb == "OC":
emb = f"_{emb}"
@@ -321,7 +333,7 @@ def compute_entropy(labels: np.ndarray) -> np.ndarray:
return entropy(counts)
# Parameters
- no, x_dim = X_gt.shape
+ no, x_dim = X_gt_.shape
# Weights
W = np.zeros(
@@ -331,7 +343,7 @@ def compute_entropy(labels: np.ndarray) -> np.ndarray:
)
for i in range(x_dim):
- W[i] = compute_entropy(X_gt.numpy()[:, i])
+ W[i] = compute_entropy(X_gt_[:, i])
# Normalization
X_hat = X_gt_
diff --git a/src/synthcity/metrics/eval_sanity.py b/src/synthcity/metrics/eval_sanity.py
index ff7ca6de..ffaaaf31 100644
--- a/src/synthcity/metrics/eval_sanity.py
+++ b/src/synthcity/metrics/eval_sanity.py
@@ -8,6 +8,7 @@
from sklearn.neighbors import NearestNeighbors
# synthcity absolute
+import synthcity.logger as log
from synthcity.metrics.core import MetricEvaluator
from synthcity.plugins.core.dataloader import DataLoader
@@ -24,10 +25,15 @@ def __init__(self, **kwargs: Any) -> None:
@staticmethod
def _helper_nearest_neighbor(X_gt: DataLoader, X_syn: DataLoader) -> np.ndarray:
try:
- estimator = NearestNeighbors(n_neighbors=5).fit(X_syn.numpy())
- dist, _ = estimator.kneighbors(X_gt.numpy(), 1, return_distance=True)
+ estimator = NearestNeighbors(n_neighbors=5).fit(
+ X_syn.numpy().reshape(len(X_syn), -1)
+ )
+ dist, _ = estimator.kneighbors(
+ X_gt.numpy().reshape(len(X_gt), -1), 1, return_distance=True
+ )
return dist.squeeze()
except BaseException:
+ log.error("NearestNeighbors failed")
return np.asarray([999])
@staticmethod
@@ -42,6 +48,22 @@ def evaluate_default(
) -> float:
return self.evaluate(X_gt, X_syn)[self._default_metric]
+ @staticmethod
+ def name() -> str:
+ raise NotImplementedError()
+
+ @staticmethod
+ def direction() -> str:
+ raise NotImplementedError()
+
+ @validate_arguments(config=dict(arbitrary_types_allowed=True))
+ def evaluate(
+ self,
+ X_gt: DataLoader,
+ X_syn: DataLoader,
+ ) -> float:
+ raise NotImplementedError()
+
class DataMismatchScore(BasicMetricEvaluator):
"""
@@ -68,6 +90,12 @@ def direction() -> str:
@validate_arguments(config=dict(arbitrary_types_allowed=True))
def evaluate(self, X_gt: DataLoader, X_syn: DataLoader) -> Dict:
+ if X_gt.type() != X_syn.type():
+ raise ValueError("Incompatible dataloader")
+
+ if X_gt.type() == "images":
+ return {"score": 0}
+
if len(X_gt.columns) != len(X_syn.columns):
raise ValueError(f"Incompatible dataframe {X_gt.shape} and {X_syn.shape}")
diff --git a/src/synthcity/metrics/eval_statistical.py b/src/synthcity/metrics/eval_statistical.py
index 9f2ae193..2caf5651 100644
--- a/src/synthcity/metrics/eval_statistical.py
+++ b/src/synthcity/metrics/eval_statistical.py
@@ -9,6 +9,7 @@
import torch
from geomloss import SamplesLoss
from pydantic import validate_arguments
+from scipy import linalg
from scipy.spatial.distance import jensenshannon
from scipy.special import kl_div
from scipy.stats import chisquare, ks_2samp
@@ -17,12 +18,14 @@
from sklearn.preprocessing import MinMaxScaler
# synthcity absolute
+import synthcity.logger as log
from synthcity.metrics._utils import get_frequency
from synthcity.metrics.core import MetricEvaluator
from synthcity.plugins.core.dataloader import DataLoader
from synthcity.plugins.core.models.survival_analysis.metrics import (
nonparametric_distance,
)
+from synthcity.utils.reproducibility import clear_cache
from synthcity.utils.serialization import load_from_file, save_to_file
@@ -53,6 +56,7 @@ def evaluate(self, X_gt: DataLoader, X_syn: DataLoader) -> Dict:
if self.use_cache(cache_file):
return load_from_file(cache_file)
+ clear_cache()
results = self._evaluate(X_gt, X_syn)
save_to_file(cache_file, results)
return results
@@ -173,7 +177,10 @@ def _evaluate(self, X_gt: DataLoader, X_syn: DataLoader) -> Dict:
gt_freq, synth_freq = freqs[col]
try:
_, pvalue = chisquare(gt_freq, synth_freq)
+ if np.isnan(pvalue):
+ pvalue = 0
except BaseException:
+ log.error("chisquare failed")
pvalue = 0
res.append(pvalue)
@@ -229,9 +236,21 @@ def _evaluate(
MMD using rbf (gaussian) kernel (i.e., k(x,y) = exp(-gamma * ||x-y||^2 / 2))
"""
gamma = 1.0
- XX = metrics.pairwise.rbf_kernel(X_gt.numpy(), X_gt.numpy(), gamma)
- YY = metrics.pairwise.rbf_kernel(X_syn.numpy(), X_syn.numpy(), gamma)
- XY = metrics.pairwise.rbf_kernel(X_gt.numpy(), X_syn.numpy(), gamma)
+ XX = metrics.pairwise.rbf_kernel(
+ X_gt.numpy().reshape(len(X_gt), -1),
+ X_gt.numpy().reshape(len(X_gt), -1),
+ gamma,
+ )
+ YY = metrics.pairwise.rbf_kernel(
+ X_syn.numpy().reshape(len(X_syn), -1),
+ X_syn.numpy().reshape(len(X_syn), -1),
+ gamma,
+ )
+ XY = metrics.pairwise.rbf_kernel(
+ X_gt.numpy().reshape(len(X_gt), -1),
+ X_syn.numpy().reshape(len(X_syn), -1),
+ gamma,
+ )
score = XX.mean() + YY.mean() - 2 * XY.mean()
elif self.kernel == "polynomial":
"""
@@ -241,13 +260,25 @@ def _evaluate(
gamma = 1
coef0 = 0
XX = metrics.pairwise.polynomial_kernel(
- X_gt.numpy(), X_gt.numpy(), degree, gamma, coef0
+ X_gt.numpy().reshape(len(X_gt), -1),
+ X_gt.numpy().reshape(len(X_gt), -1),
+ degree,
+ gamma,
+ coef0,
)
YY = metrics.pairwise.polynomial_kernel(
- X_syn.numpy(), X_syn.numpy(), degree, gamma, coef0
+ X_syn.numpy().reshape(len(X_syn), -1),
+ X_syn.numpy().reshape(len(X_syn), -1),
+ degree,
+ gamma,
+ coef0,
)
XY = metrics.pairwise.polynomial_kernel(
- X_gt.numpy(), X_syn.numpy(), degree, gamma, coef0
+ X_gt.numpy().reshape(len(X_gt), -1),
+ X_syn.numpy().reshape(len(X_syn), -1),
+ degree,
+ gamma,
+ coef0,
)
score = XX.mean() + YY.mean() - 2 * XY.mean()
else:
@@ -348,8 +379,9 @@ def _evaluate(
X: DataLoader,
X_syn: DataLoader,
) -> Dict:
- X_ = X.numpy()
- X_syn_ = X_syn.numpy()
+ X_ = X.numpy().reshape(len(X), -1)
+ X_syn_ = X_syn.numpy().reshape(len(X_syn), -1)
+
if len(X_) > len(X_syn_):
X_syn_ = np.concatenate(
[X_syn_, np.zeros((len(X_) - len(X_syn_), X_.shape[1]))]
@@ -398,8 +430,8 @@ def _evaluate(
X: DataLoader,
X_syn: DataLoader,
) -> Dict:
- X_ = X.numpy()
- X_syn_ = X_syn.numpy()
+ X_ = X.numpy().reshape(len(X), -1)
+ X_syn_ = X_syn.numpy().reshape(len(X_syn), -1)
# Default representation
results = self._compute_prdc(X_, X_syn_)
@@ -628,8 +660,8 @@ def _evaluate(
results = {}
- X_ = X.numpy()
- X_syn_ = X_syn.numpy()
+ X_ = X.numpy().reshape(len(X), -1)
+ X_syn_ = X_syn.numpy().reshape(len(X_syn), -1)
# OneClass representation
emb = "_OC"
@@ -699,3 +731,125 @@ def _evaluate(
"abs_optimism": abs_optimism,
"sightedness": sightedness,
}
+
+
+class FrechetInceptionDistance(StatisticalEvaluator):
+ """
+ .. inheritance-diagram:: synthcity.metrics.eval_statistical.FrechetInceptionDistance
+ :parts: 1
+
+ Calculates the Frechet Inception Distance (FID) to evalulate GANs.
+
+ Paper: GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium.
+
+ The FID metric calculates the distance between two distributions of images.
+ Typically, we have summary statistics (mean & covariance matrix) of one of these distributions, while the 2nd distribution is given by a GAN.
+
+ Adapted by Boris van Breugel(bv292@cam.ac.uk)
+ """
+
+ def __init__(self, **kwargs: Any) -> None:
+ super().__init__(**kwargs)
+
+ @staticmethod
+ def name() -> str:
+ return "fid"
+
+ @staticmethod
+ def direction() -> str:
+ return "minimize"
+
+ def _fit_gaussian(self, act: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """Calculation of the statistics used by the FID.
+ Params:
+ -- act : activations
+ Returns:
+ -- mu : The mean over samples of the activations
+ -- sigma : The covariance matrix of the activations
+ """
+ mu = np.mean(act, axis=0)
+ sigma = np.cov(act.T)
+ return mu, sigma
+
+ def _calculate_frechet_distance(
+ self,
+ mu1: np.ndarray,
+ sigma1: np.ndarray,
+ mu2: np.ndarray,
+ sigma2: np.ndarray,
+ eps: float = 1e-6,
+ ) -> float:
+ """Numpy implementation of the Frechet Distance.
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
+ and X_2 ~ N(mu_2, C_2) is
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+
+ Stable version by Dougal J. Sutherland.
+ Params:
+ -- mu1 : Numpy array containing the activations of the pool_3 layer of the
+ inception net ( like returned by the function 'get_predictions')
+ for generated samples.
+ -- mu2 : The sample mean over activations of the pool_3 layer, precalcualted
+ on an representive data set.
+ -- sigma1: The covariance matrix over activations of the pool_3 layer for
+ generated samples.
+ -- sigma2: The covariance matrix over activations of the pool_3 layer,
+ precalcualted on an representive data set.
+ Returns:
+ -- : The Frechet Distance.
+ """
+
+ mu1 = np.atleast_1d(mu1)
+ mu2 = np.atleast_1d(mu2)
+
+ sigma1 = np.atleast_2d(sigma1)
+ sigma2 = np.atleast_2d(sigma2)
+
+ if mu1.shape != mu2.shape:
+ raise RuntimeError("Training and test mean vectors have different lengths")
+
+ if sigma1.shape != sigma2.shape:
+ raise RuntimeError(
+ "Training and test covariances have different dimensions"
+ )
+
+ diff = mu1 - mu2
+
+ # product might be almost singular
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+ if not np.isfinite(covmean).all():
+ offset = np.eye(sigma1.shape[0]) * eps
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+ # numerical error might give slight imaginary component
+ if np.iscomplexobj(covmean):
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=2e-3):
+ m = np.max(np.abs(covmean.imag))
+ raise ValueError("Imaginary component {}".format(m))
+ covmean = covmean.real
+
+ tr_covmean = np.trace(covmean)
+
+ return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
+
+ @validate_arguments(config=dict(arbitrary_types_allowed=True))
+ def _evaluate(
+ self,
+ X: DataLoader,
+ X_syn: DataLoader,
+ ) -> Dict:
+ if X.type() != "images":
+ raise RuntimeError(
+ f"The metric is valid only for image tasks, but got datasets {X.type()} and {X_syn.type()}"
+ )
+
+ X1 = X.numpy().reshape(len(X), -1)
+ X2 = X_syn.numpy().reshape(len(X_syn), -1)
+
+ mu1, cov1 = self._fit_gaussian(X1)
+ mu2, cov2 = self._fit_gaussian(X2)
+ score = self._calculate_frechet_distance(mu1, cov1, mu2, cov2)
+
+ return {
+ "score": score,
+ }
diff --git a/src/synthcity/plugins/__init__.py b/src/synthcity/plugins/__init__.py
index a4d5e934..f1c64a27 100644
--- a/src/synthcity/plugins/__init__.py
+++ b/src/synthcity/plugins/__init__.py
@@ -14,6 +14,7 @@
"survival_analysis",
"time_series",
"domain_adaptation",
+ "images",
]
plugins = {}
diff --git a/src/synthcity/plugins/core/dataloader.py b/src/synthcity/plugins/core/dataloader.py
index 8619b866..6ccae8a8 100644
--- a/src/synthcity/plugins/core/dataloader.py
+++ b/src/synthcity/plugins/core/dataloader.py
@@ -6,12 +6,16 @@
# third party
import numpy as np
import pandas as pd
+import PIL
+import torch
from pydantic import validate_arguments
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
+from torchvision import transforms
# synthcity absolute
from synthcity.plugins.core.constraints import Constraints
+from synthcity.plugins.core.dataset import FlexibleDataset, TensorDataset
from synthcity.plugins.core.models.data_encoder import DatetimeEncoder
from synthcity.utils.compression import compress_dataset, decompress_dataset
from synthcity.utils.serialization import dataframe_hash
@@ -66,7 +70,7 @@ def __init__(
self,
data_type: str,
data: Any,
- static_features: List[str],
+ static_features: List[str] = [],
temporal_features: List[str] = [],
sensitive_features: List[str] = [],
important_features: List[str] = [],
@@ -255,6 +259,10 @@ def decode(
return self.from_info(decoded, self.info())
+ @abstractmethod
+ def is_tabular(self) -> bool:
+ ...
+
class GenericDataLoader(DataLoader):
"""
@@ -409,7 +417,7 @@ def from_info(data: pd.DataFrame, info: dict) -> "GenericDataLoader":
train_size=info["train_size"],
)
- def __getitem__(self, feature: Union[str, list]) -> Any:
+ def __getitem__(self, feature: Union[str, list, int]) -> Any:
return self.data[feature]
def __setitem__(self, feature: str, val: Any) -> None:
@@ -441,12 +449,39 @@ def fillna(self, value: Any) -> "DataLoader":
self.data = self.data.fillna(value)
return self
+ def is_tabular(self) -> bool:
+ return True
+
class SurvivalAnalysisDataLoader(DataLoader):
"""
.. inheritance-diagram:: synthcity.plugins.core.dataloader.SurvivalAnalysisDataLoader
:parts: 1
+ Data Loader for Survival Analysis Data
+
+ Constructor Args:
+ data: Union[pd.DataFrame, list, np.ndarray]
+ The dataset. Either a Pandas DataFrame or a Numpy Array.
+ time_to_event_column: str
+ Survival Analysis specific time-to-event feature
+ target_column: str
+ The outcome: event or censoring.
+ sensitive_features: List[str]
+ Name of sensitive features.
+ important_features: List[str]
+ Default: None. Only relevant for SurvivalGAN method.
+ target_column: str
+ The feature name that provides labels for downstream tasks.
+ domain_column: Optional[str]
+ Optional domain label, used for domain adaptation algorithms.
+ random_state: int
+ Defaults to zero.
+ train_size: float
+ The ratio to use for train splits.
+
+ Example:
+ >>> TODO
"""
@validate_arguments(config=dict(arbitrary_types_allowed=True))
@@ -588,7 +623,7 @@ def from_info(data: pd.DataFrame, info: dict) -> "DataLoader":
time_horizons=info["time_horizons"],
)
- def __getitem__(self, feature: Union[str, list]) -> Any:
+ def __getitem__(self, feature: Union[str, list, int]) -> Any:
return self.data[feature]
def __setitem__(self, feature: str, val: Any) -> None:
@@ -619,6 +654,9 @@ def fillna(self, value: Any) -> "DataLoader":
self.data = self.data.fillna(value)
return self
+ def is_tabular(self) -> bool:
+ return True
+
class TimeSeriesDataLoader(DataLoader):
"""
@@ -643,6 +681,8 @@ class TimeSeriesDataLoader(DataLoader):
random_state: int
Defaults to zero.
+ Example:
+ >>> TODO
"""
@validate_arguments(config=dict(arbitrary_types_allowed=True))
@@ -861,7 +901,7 @@ def unpack(self, as_numpy: bool = False, pad: bool = False) -> Any:
outcome,
)
- def __getitem__(self, feature: Union[str, list]) -> Any:
+ def __getitem__(self, feature: Union[str, list, int]) -> Any:
return self.data["seq_data"][feature]
def __setitem__(self, feature: str, val: Any) -> None:
@@ -1276,6 +1316,9 @@ def unpack_raw_data(
return static_df, temporal_data, observation_times, outcome_df
+ def is_tabular(self) -> bool:
+ return True
+
class TimeSeriesSurvivalDataLoader(TimeSeriesDataLoader):
"""
@@ -1302,6 +1345,9 @@ class TimeSeriesSurvivalDataLoader(TimeSeriesDataLoader):
random_state. int
Defaults to zero.
+ Example:
+ >>> TODO
+
"""
@validate_arguments(config=dict(arbitrary_types_allowed=True))
@@ -1468,8 +1514,222 @@ def test(self) -> "DataLoader":
return self.unpack_and_decorate(self.filter_ids(test_ids))
+class ImageDataLoader(DataLoader):
+ """
+ .. inheritance-diagram:: synthcity.plugins.core.dataloader.ImageDataLoader
+ :parts: 1
+
+ Data loader for generic image data.
+
+ Constructor Args:
+ data: torch.utils.data.Dataset or torch.Tensor
+ The image dataset or a tuple of (tensor images, tensor labels)
+ random_state: int
+ Defaults to zero.
+ height: int. Default = 32
+ Height to use internally
+ width: Optional[int]
+ Optional width to use internally. If None, it is used the same value as height.
+ train_size: float = 0.8
+ Train dataset ratio.
+ Example:
+ >>> dataset = datasets.MNIST(".", download=True)
+ >>>
+ >>> loader = ImageDataLoader(
+ >>> data=dataset,
+ >>> train_size=0.8,
+ >>> height=32,
+ >>> width=w32,
+ >>> )
+
+ """
+
+ @validate_arguments(config=dict(arbitrary_types_allowed=True))
+ def __init__(
+ self,
+ data: Union[torch.utils.data.Dataset, Tuple[torch.Tensor, torch.Tensor]],
+ height: int = 32,
+ width: Optional[int] = None,
+ random_state: int = 0,
+ train_size: float = 0.8,
+ **kwargs: Any,
+ ) -> None:
+ if width is None:
+ width = height
+
+ if isinstance(data, tuple):
+ X, y = data
+ data = TensorDataset(images=X, targets=y)
+
+ self.data_transform = None
+
+ dummy, _ = data[0]
+ img_transform = []
+ if not isinstance(dummy, PIL.Image.Image):
+ img_transform = [transforms.ToPILImage()]
+
+ img_transform.extend(
+ [
+ transforms.Resize((height, width)),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=(0.5,), std=(0.5,)),
+ ]
+ )
+
+ self.data_transform = transforms.Compose(img_transform)
+ data = FlexibleDataset(data, transform=self.data_transform)
+
+ self.height = height
+ self.width = width
+ self.channels = data.shape()[1]
+
+ super().__init__(
+ data_type="images",
+ data=data,
+ random_state=random_state,
+ train_size=train_size,
+ **kwargs,
+ )
+
+ @property
+ def shape(self) -> tuple:
+ return self.data.shape()
+
+ def unpack(self, as_numpy: bool = False, pad: bool = False) -> Any:
+ return self.data
+
+ def numpy(self) -> np.ndarray:
+ x, _ = self.data.numpy()
+
+ return x
+
+ def dataframe(self) -> pd.DataFrame:
+ x = self.numpy().reshape(len(self), -1)
+
+ x = pd.DataFrame(x)
+ x.columns = x.columns.astype(str)
+
+ return x
+
+ def info(self) -> dict:
+ return {
+ "data_type": self.data_type,
+ "len": len(self),
+ "train_size": self.train_size,
+ "height": self.height,
+ "width": self.width,
+ "channels": self.channels,
+ "random_state": self.random_state,
+ }
+
+ def __len__(self) -> int:
+ return len(self.data)
+
+ def decorate(self, data: Any) -> "DataLoader":
+ return ImageDataLoader(
+ data,
+ random_state=self.random_state,
+ train_size=self.train_size,
+ height=self.height,
+ width=self.width,
+ )
+
+ def sample(self, count: int, random_state: int = 0) -> "DataLoader":
+ idxs = np.random.choice(len(self), count, replace=False)
+ subset = FlexibleDataset(self.data.data, indices=idxs)
+ return self.decorate(subset)
+
+ @staticmethod
+ def from_info(data: torch.utils.data.Dataset, info: dict) -> "ImageDataLoader":
+ if not isinstance(data, torch.utils.data.Dataset):
+ raise ValueError(f"Invalid data type {type(data)}")
+
+ return ImageDataLoader(
+ data,
+ train_size=info["train_size"],
+ height=info["height"],
+ width=info["width"],
+ random_state=info["random_state"],
+ )
+
+ def __getitem__(self, index: Union[list, int, str]) -> Any:
+ if isinstance(index, str):
+ return self.dataframe()[index]
+
+ return self.numpy()[index]
+
+ def _train_test_split(self) -> Tuple:
+ indices = np.arange(len(self.data))
+ _, stratify = self.data.numpy()
+
+ return train_test_split(
+ indices,
+ train_size=self.train_size,
+ random_state=self.random_state,
+ stratify=stratify,
+ )
+
+ def train(self) -> "DataLoader":
+ train_idx, _ = self._train_test_split()
+ subset = FlexibleDataset(self.data.data, indices=train_idx)
+ return self.decorate(subset)
+
+ def test(self) -> "DataLoader":
+ _, test_idx = self._train_test_split()
+ subset = FlexibleDataset(self.data.data, indices=test_idx)
+ return self.decorate(subset)
+
+ def compress(
+ self,
+ ) -> Tuple["DataLoader", Dict]:
+ return self, {}
+
+ def decompress(self, context: Dict) -> "DataLoader":
+ return self
+
+ def encode(
+ self,
+ encoders: Optional[Dict[str, Any]] = None,
+ ) -> Tuple["DataLoader", Dict]:
+ return self, {}
+
+ def decode(
+ self,
+ encoders: Dict[str, Any],
+ ) -> "DataLoader":
+ return self
+
+ def is_tabular(self) -> bool:
+ return False
+
+ @property
+ def columns(self) -> list:
+ return list(self.dataframe().columns)
+
+ def satisfies(self, constraints: Constraints) -> bool:
+ return True
+
+ def match(self, constraints: Constraints) -> "DataLoader":
+ return self
+
+ def compression_protected_features(self) -> list:
+ raise NotImplementedError("Images do not support the compression call")
+
+ def drop(self, columns: list = []) -> "DataLoader":
+ raise NotImplementedError()
+
+ def __setitem__(self, feature: str, val: Any) -> None:
+ raise NotImplementedError()
+
+ def fillna(self, value: Any) -> "DataLoader":
+ raise NotImplementedError()
+
+
@validate_arguments(config=dict(arbitrary_types_allowed=True))
-def create_from_info(data: pd.DataFrame, info: dict) -> "DataLoader":
+def create_from_info(
+ data: Union[pd.DataFrame, torch.utils.data.Dataset], info: dict
+) -> "DataLoader":
+ """Helper for creating a DataLoader from existing information."""
if info["data_type"] == "generic":
return GenericDataLoader.from_info(data, info)
elif info["data_type"] == "survival_analysis":
@@ -1478,5 +1738,7 @@ def create_from_info(data: pd.DataFrame, info: dict) -> "DataLoader":
return TimeSeriesDataLoader.from_info(data, info)
elif info["data_type"] == "time_series_survival":
return TimeSeriesSurvivalDataLoader.from_info(data, info)
+ elif info["data_type"] == "images":
+ return ImageDataLoader.from_info(data, info)
else:
raise RuntimeError(f"invalid datatype {info}")
diff --git a/src/synthcity/plugins/core/dataset.py b/src/synthcity/plugins/core/dataset.py
new file mode 100644
index 00000000..9b177839
--- /dev/null
+++ b/src/synthcity/plugins/core/dataset.py
@@ -0,0 +1,188 @@
+# stdlib
+from typing import List, Optional, Tuple
+
+# third party
+import numpy as np
+import torch
+
+# synthcity absolute
+from synthcity.utils.constants import DEVICE
+
+
+class FlexibleDataset(torch.utils.data.Dataset):
+ """Helper dataset wrapper for post-processing or transforming another dataset. Used for controlling the image sizes for the synthcity models.
+
+ The class supports adding custom transforms to existing datasets, and to subsample a set of indices.
+
+ Args:
+ data: torch.Dataset
+ transform: An optional list of transforms
+ indices: An optional list of indices to subsample
+ """
+
+ def __init__(
+ self,
+ data: torch.utils.data.Dataset,
+ transform: Optional[torch.nn.Module] = None,
+ indices: Optional[list] = None,
+ ) -> None:
+ super().__init__()
+
+ if indices is None:
+ indices = np.arange(len(data))
+
+ self.indices = np.asarray(indices)
+ self.data = data
+ self.transform = transform
+ self.ndarrays: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+
+ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ x, y = self.data[self.indices[index]]
+ if self.transform:
+ x = self.transform(x)
+ return x, y
+
+ def __len__(self) -> int:
+ return len(self.indices)
+
+ def shape(self) -> Tuple:
+ x, _ = self[self.indices[0]]
+
+ return (len(self), *x.shape)
+
+ def numpy(self) -> Tuple[np.ndarray, np.ndarray]:
+ if self.ndarrays is not None:
+ return self.ndarrays
+
+ x_buff = []
+ y_buff = []
+ for idx in range(len(self)):
+ x_local, y_local = self[idx]
+ x_buff.append(x_local.unsqueeze(0).cpu().numpy())
+ y_buff.append(y_local)
+
+ x = np.concatenate(x_buff, axis=0)
+ y = np.asarray(y_buff)
+
+ self.ndarrays = (x, y)
+ return x, y
+
+ def tensors(self) -> Tuple[torch.Tensor, torch.Tensor]:
+ x, y = self.numpy()
+
+ return torch.from_numpy(x), torch.from_numpy(y)
+
+ def labels(self) -> np.ndarray:
+ labels = []
+ for idx in self.indices:
+ _, y = self.data[idx]
+ labels.append(y)
+
+ return np.asarray(labels)
+
+ def filter_indices(self, indices: List[int]) -> "FlexibleDataset":
+ for idx in indices:
+ if idx >= len(self.indices):
+ raise ValueError(
+ "Invalid filtering list. {idx} not found in the current list of indices"
+ )
+ return FlexibleDataset(
+ data=self.data, transform=self.transform, indices=self.indices[indices]
+ )
+
+
+class TensorDataset(torch.utils.data.Dataset):
+ """Helper dataset for wrapping existing tensors
+
+ Args:
+ images: Tensor
+ targets: Tensor
+ """
+
+ def __init__(
+ self,
+ images: torch.Tensor,
+ targets: Optional[torch.Tensor],
+ ) -> None:
+ super().__init__()
+
+ if targets is not None and len(targets) != len(images):
+ raise ValueError("Invalid input")
+
+ self.images = images
+ self.targets = targets
+
+ def __getitem__(self, index: int) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ y: Optional[torch.Tensor] = None
+ x = self.images[index]
+
+ if self.targets is not None:
+ y = self.targets[index]
+
+ return x, y
+
+ def __len__(self) -> int:
+ return len(self.images)
+
+ def labels(self) -> Optional[np.ndarray]:
+ if self.targets is None:
+ return None
+
+ return self.targets.cpu().numpy()
+
+
+class ConditionalDataset(torch.utils.data.Dataset):
+ """Helper dataset for wrapping existing datasets with custom tensors
+
+ Args:
+ data: torch.Dataset
+ cond: Optional Tensor
+ """
+
+ def __init__(
+ self,
+ data: torch.utils.data.Dataset,
+ cond: Optional[torch.Tensor] = None,
+ ) -> None:
+ super().__init__()
+
+ if cond is not None and len(cond) != len(data):
+ raise ValueError("Invalid input")
+
+ self.data = data
+ self.cond = cond
+
+ def __getitem__(self, index: int) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ cond: Optional[torch.Tensor] = None
+ x = self.data[index][0]
+
+ if self.cond is not None:
+ cond = self.cond[index]
+
+ return x, cond
+
+ def __len__(self) -> int:
+ return len(self.data)
+
+
+class NumpyDataset(torch.utils.data.Dataset):
+ """Helper class for wrapping Numpy arrays in torch Datasets
+ Args:
+ X: np.ndarray
+ y: np.ndarray
+ """
+
+ def __init__(self, X: np.ndarray, y: np.ndarray) -> None:
+ super().__init__()
+
+ self.X = X
+ self.y = y
+
+ def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ x = self.X[index]
+ y = self.y[index]
+
+ return torch.from_numpy(x).to(DEVICE), y
+
+ def __len__(self) -> int:
+ return len(self.X)
diff --git a/src/synthcity/plugins/core/models/convnet.py b/src/synthcity/plugins/core/models/convnet.py
new file mode 100644
index 00000000..48066fbb
--- /dev/null
+++ b/src/synthcity/plugins/core/models/convnet.py
@@ -0,0 +1,583 @@
+# stdlib
+from typing import Any, Optional, Tuple
+
+# third party
+import numpy as np
+import torch
+from monai.networks.layers.factories import Act
+from monai.networks.nets import Classifier, Discriminator, Generator
+from pydantic import validate_arguments
+from torch import nn
+
+# synthcity absolute
+import synthcity.logger as log
+from synthcity.utils.constants import DEVICE
+from synthcity.utils.reproducibility import enable_reproducible_results
+
+
+def map_nonlin(nonlin: str) -> Act:
+ if nonlin == "relu":
+ return Act.RELU
+ elif nonlin == "elu":
+ return Act.ELU
+ elif nonlin == "prelu":
+ return Act.PRELU
+ elif nonlin == "leaky_relu":
+ return Act.LEAKYRELU
+ elif nonlin == "sigmoid":
+ return Act.SIGMOID
+ elif nonlin == "softmax":
+ return Act.SOFTMAX
+ elif nonlin == "tanh":
+ return Act.TANH
+
+ raise ValueError(f"Unknown activation {nonlin}")
+
+
+class ConvNet(nn.Module):
+ """
+ Wrapper for convolutional nets for classification and regression.
+
+ Parameters
+ ----------
+ task_type: str
+ classifier or regression
+ model: nn.Module
+ classification or regression model implementation
+ lr: float
+ learning rate for optimizer.
+ weight_decay: float
+ l2 (ridge) penalty for the weights.
+ n_iter: int
+ Maximum number of iterations.
+ batch_size: int
+ Batch size
+ n_iter_print: int
+ Number of iterations after which to print updates and check the validation loss.
+ random_state: int
+ random_state used
+ patience: int
+ Number of iterations to wait before early stopping after decrease in validation loss
+ n_iter_min: int
+ Minimum number of iterations to go through before starting early stopping
+ clipping_value: int, default 1
+ Gradients clipping value
+ early_stopping: bool
+ Enable/disable early stopping
+ """
+
+ @validate_arguments(config=dict(arbitrary_types_allowed=True))
+ def __init__(
+ self,
+ task_type: str,
+ model: nn.Module, # classification/regression
+ lr: float = 1e-3,
+ weight_decay: float = 1e-3,
+ opt_betas: tuple = (0.9, 0.999),
+ n_iter: int = 1000,
+ batch_size: int = 500,
+ n_iter_print: int = 100,
+ random_state: int = 0,
+ patience: int = 10,
+ n_iter_min: int = 100,
+ clipping_value: int = 1,
+ early_stopping: bool = True,
+ device: Any = DEVICE,
+ ) -> None:
+ super(ConvNet, self).__init__()
+
+ if task_type not in ["classification", "regression"]:
+ raise ValueError(f"Invalid task type {task_type}")
+
+ enable_reproducible_results(random_state)
+
+ self.task_type = task_type
+ self.device = device
+ self.model = model
+ self.random_state = random_state
+
+ # optimizer
+ self.lr = lr
+ self.weight_decay = weight_decay
+ self.opt_betas = opt_betas
+ self.optimizer = torch.optim.Adam(
+ self.parameters(),
+ lr=self.lr,
+ weight_decay=self.weight_decay,
+ betas=self.opt_betas,
+ )
+
+ # training
+ self.n_iter = n_iter
+ self.n_iter_print = n_iter_print
+ self.n_iter_min = n_iter_min
+ self.batch_size = batch_size
+ self.patience = patience
+ self.clipping_value = clipping_value
+ self.early_stopping = early_stopping
+ if task_type == "classification":
+ self.loss = nn.CrossEntropyLoss()
+ else:
+ self.loss = nn.MSELoss()
+
+ def fit(self, X: torch.utils.data.Dataset) -> "ConvNet":
+ train_size = int(0.8 * len(X))
+ test_size = len(X) - train_size
+ train_dataset, test_dataset = torch.utils.data.random_split(
+ X, [train_size, test_size]
+ )
+ train_loader = torch.utils.data.DataLoader(
+ train_dataset, batch_size=self.batch_size, pin_memory=False
+ )
+ test_loader = torch.utils.data.DataLoader(
+ test_dataset, batch_size=len(test_dataset)
+ )
+
+ # Setup the network and optimizer
+
+ val_loss_best = 999999
+ patience = 0
+
+ # do training
+ for i in range(self.n_iter):
+ train_loss = self._train_epoch(train_loader)
+
+ if self.early_stopping or i % self.n_iter_print == 0:
+ with torch.no_grad():
+ X_val, y_val = next(iter(test_loader))
+ X_val = self._check_tensor(X_val)
+ y_val = self._check_tensor(y_val).long()
+
+ preds = self.forward(X_val).squeeze()
+ val_loss = self.loss(preds, y_val)
+
+ if self.early_stopping:
+ if val_loss_best > val_loss:
+ val_loss_best = val_loss
+ patience = 0
+ else:
+ patience += 1
+
+ if patience > self.patience and i > self.n_iter_min:
+ break
+
+ if i % self.n_iter_print == 0:
+ log.debug(
+ f"Epoch: {i}, val loss: {val_loss}, train_loss: {train_loss}"
+ )
+
+ return self
+
+ @validate_arguments(config=dict(arbitrary_types_allowed=True))
+ def predict_proba(self, X: torch.Tensor) -> torch.Tensor:
+ if self.task_type != "classification":
+ raise ValueError(f"Invalid task type for predict_proba {self.task_type}")
+
+ with torch.no_grad():
+ Xt = self._check_tensor(X)
+
+ yt = self.forward(Xt)
+
+ return yt.cpu()
+
+ @validate_arguments(config=dict(arbitrary_types_allowed=True))
+ def predict(self, X: torch.Tensor) -> torch.Tensor:
+ with torch.no_grad():
+ Xt = self._check_tensor(X)
+
+ yt = self.forward(Xt)
+
+ if self.task_type == "classification":
+ return torch.argmax(yt.cpu(), -1).squeeze()
+ else:
+ return yt.cpu()
+
+ def score(self, X: torch.Tensor, y: torch.Tensor) -> float:
+ y_pred = self.predict(X)
+ if self.task_type == "classification":
+ return torch.mean(y_pred == y)
+ else:
+ return torch.mean(torch.inner(y - y_pred, y - y_pred) / 2.0)
+
+ @validate_arguments(config=dict(arbitrary_types_allowed=True))
+ def forward(self, X: torch.Tensor) -> torch.Tensor:
+ X = self._check_tensor(X)
+
+ return self.model(X.float())
+
+ def _train_epoch(self, loader: torch.utils.data.DataLoader) -> float:
+ train_loss = []
+
+ for batch_ndx, sample in enumerate(loader):
+ self.optimizer.zero_grad()
+
+ X_next, y_next = sample
+
+ X_next = self._check_tensor(X_next)
+ y_next = self._check_tensor(y_next).long()
+
+ if len(X_next) < 2:
+ continue
+
+ preds = self.forward(X_next).squeeze()
+
+ batch_loss = self.loss(preds, y_next)
+
+ batch_loss.backward()
+
+ if self.clipping_value > 0:
+ torch.nn.utils.clip_grad_norm_(self.parameters(), self.clipping_value)
+
+ self.optimizer.step()
+
+ train_loss.append(batch_loss.detach())
+
+ return torch.mean(torch.Tensor(train_loss))
+
+ def _check_tensor(self, X: torch.Tensor) -> torch.Tensor:
+ if isinstance(X, torch.Tensor):
+ return X.to(self.device)
+ else:
+ return torch.from_numpy(np.asarray(X)).to(self.device)
+
+ def __len__(self) -> int:
+ return len(self.model)
+
+
+class ConditionalGenerator(nn.Module):
+ """Wrapper for making existing CNN generator conditional. Useful for Conditional GANs
+
+ Args:
+ model: nn.Module
+ Core model.
+ n_channels: int
+ Number of channels in images
+ n_units_latent: int
+ Noise size for the input
+ cond: torch.Tensor
+ The reference conditional
+ cond_embedding_n_units_hidden: int
+ Size of the conditional embedding layer
+ """
+
+ def __init__(
+ self,
+ model: nn.Module,
+ n_channels: int,
+ n_units_latent: int,
+ cond: Optional[torch.Tensor] = None,
+ cond_embedding_n_units_hidden: int = 100,
+ device: Any = DEVICE,
+ ) -> None:
+ super(ConditionalGenerator, self).__init__()
+
+ self.model = model
+ self.cond = cond
+ self.n_channels = n_channels
+ self.n_units_latent = n_units_latent
+ self.device = device
+
+ self.label_conditioned_generator: Optional[nn.Module] = None
+ if cond is not None:
+ classes = torch.unique(self.cond)
+ self.label_conditioned_generator = nn.Sequential(
+ nn.Embedding(len(classes), cond_embedding_n_units_hidden),
+ nn.Linear(cond_embedding_n_units_hidden, n_channels * n_units_latent),
+ ).to(device)
+
+ def forward(
+ self, noise: torch.Tensor, cond: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ if cond is None and self.cond is not None:
+ perm = torch.randperm(self.cond.size(0))
+ cond = self.cond[perm[: len(noise)]]
+
+ if self.label_conditioned_generator is not None and cond is not None:
+ cond_emb = self.label_conditioned_generator(cond.long()).view(
+ -1, self.n_units_latent, self.n_channels, 1
+ )
+ noise = torch.cat(
+ (noise, cond_emb), dim=2
+ ) # add another channel with the conditional
+ return self.model(noise)
+
+
+class ConditionalDiscriminator(nn.Module):
+ """
+
+ Args:
+ model: nn.Module
+ Core model.
+ n_channels: int
+ Number of channels in images
+ height: int
+ Image height
+ width: int
+ Image width
+ cond: torch.Tensor
+ The reference conditional
+ cond_embedding_n_units_hidden: int
+ Size of the conditional embedding layer
+ """
+
+ def __init__(
+ self,
+ model: nn.Module,
+ n_channels: int,
+ height: int,
+ width: int,
+ cond: Optional[torch.Tensor] = None,
+ cond_embedding_n_units_hidden: int = 100,
+ device: Any = DEVICE,
+ ) -> None:
+ super(ConditionalDiscriminator, self).__init__()
+ self.model = model
+ self.cond = cond
+
+ self.n_channels = n_channels
+ self.height = height
+ self.width = width
+
+ self.device = device
+
+ self.label_conditioned_generator: Optional[nn.Module] = None
+ if cond is not None:
+ classes = torch.unique(self.cond)
+ self.label_conditioned_generator = nn.Sequential(
+ nn.Embedding(len(classes), cond_embedding_n_units_hidden),
+ nn.Linear(cond_embedding_n_units_hidden, n_channels * height * width),
+ ).to(device)
+
+ def forward(
+ self, X: torch.Tensor, cond: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ if cond is None and self.cond is not None:
+ perm = torch.randperm(self.cond.size(0))
+ cond = self.cond[perm[: len(X)]]
+
+ if self.label_conditioned_generator is not None and cond is not None:
+ cond_emb = self.label_conditioned_generator(cond.long()).view(
+ -1, self.n_channels, self.height, self.width
+ )
+ X = torch.cat(
+ (X, cond_emb), dim=1
+ ) # add another channel with the conditional
+
+ return self.model(X)
+
+
+def suggest_image_generator_discriminator_arch(
+ n_units_latent: int,
+ n_channels: int,
+ height: int,
+ width: int,
+ generator_dropout: float = 0.2,
+ generator_nonlin: str = "prelu",
+ generator_n_residual_units: int = 2,
+ discriminator_dropout: float = 0.2,
+ discriminator_nonlin: str = "prelu",
+ discriminator_n_residual_units: int = 2,
+ device: Any = DEVICE,
+ strategy: str = "predefined",
+ cond: Optional[torch.Tensor] = None,
+ cond_embedding_n_units_hidden: int = 100,
+) -> Tuple[ConditionalGenerator, ConditionalDiscriminator]:
+ """Helper for selecting compatible architecture for image generators and discriminators.
+
+ Args:
+ n_units_latent: int,
+ Input size for the generator
+ n_channels: int
+ Number of channels in the image
+ height: int
+ Image height
+ width: int
+ Image width
+ generator_dropout: float = 0.2
+ Dropout value for the generator
+ generator_nonlin: str
+ name of the activation activation layers in the generator. Can be relu, elu, prelu or leaky_relu
+ generator_n_residual_units: int
+ integer stating number of convolutions in residual units for the generator, 0 means no residual units
+ discriminator_dropout: float = 0.2
+ Dropout value for the discriminator
+ discriminator_nonlin: str
+ name of the activation activation layers in the discriminator. Can be relu, elu, prelu or leaky_relu
+ discriminator_n_residual_units: int
+ integer stating number of convolutions in residual units for the discriminator, 0 means no residual units
+ device: str
+ PyTorch device. cpu, cuda
+ strategy: str
+ Which suggestion to use. Options:
+ - predefined: a few hardcoded architectures for certain image shapes.
+ - ...
+ """
+ cond_weight = 1 if cond is None else 2
+ if strategy == "predefined":
+ if height == 32 and width == 32:
+ start_shape_gen = 4
+ start_stride_disc = 2
+ elif height == 64 and width == 64:
+ start_shape_gen = 8
+ start_stride_disc = 4
+ elif height == 128 and width == 128:
+ start_shape_gen = 16
+ start_stride_disc = 8
+ else:
+ raise ValueError(
+ f"Unsupported predefined arch : ({n_channels}, {height}, {width})"
+ )
+
+ generator = nn.Sequential(
+ Generator(
+ latent_shape=(n_units_latent, cond_weight * n_channels),
+ start_shape=(64, start_shape_gen, start_shape_gen),
+ channels=[64, 32, 16, n_channels],
+ strides=[2, 2, 2, 1],
+ kernel_size=3,
+ dropout=generator_dropout,
+ act=map_nonlin(generator_nonlin),
+ num_res_units=generator_n_residual_units,
+ ),
+ nn.Tanh(),
+ ).to(device)
+ discriminator = Discriminator(
+ in_shape=(cond_weight * n_channels, height, width),
+ channels=[16, 32, 64, 1],
+ strides=[start_stride_disc, 2, 2, 2],
+ kernel_size=3,
+ last_act=None,
+ dropout=discriminator_dropout,
+ act=map_nonlin(generator_nonlin),
+ num_res_units=discriminator_n_residual_units,
+ ).to(device)
+
+ return ConditionalGenerator(
+ model=generator,
+ n_channels=n_channels,
+ n_units_latent=n_units_latent,
+ cond=cond,
+ cond_embedding_n_units_hidden=cond_embedding_n_units_hidden,
+ device=device,
+ ), ConditionalDiscriminator(
+ discriminator,
+ n_channels=n_channels,
+ height=height,
+ width=width,
+ cond=cond,
+ cond_embedding_n_units_hidden=cond_embedding_n_units_hidden,
+ device=device,
+ )
+
+ raise ValueError(f"unsupported image arch : ({n_channels}, {height}, {width})")
+
+
+def suggest_image_classifier_arch(
+ n_channels: int,
+ height: int,
+ width: int,
+ classes: int,
+ n_residual_units: int = 2,
+ nonlin: str = "prelu",
+ dropout: float = 0.2,
+ last_nonlin: str = "softmax",
+ device: Any = DEVICE,
+ strategy: str = "predefined",
+ # training
+ lr: float = 1e-3,
+ weight_decay: float = 1e-3,
+ opt_betas: tuple = (0.9, 0.999),
+ n_iter: int = 1000,
+ batch_size: int = 500,
+ n_iter_print: int = 100,
+ random_state: int = 0,
+ patience: int = 10,
+ n_iter_min: int = 100,
+ clipping_value: int = 1,
+ early_stopping: bool = True,
+) -> ConvNet:
+ """Helper for selecting compatible architecture for image classifiers.
+
+ Args:
+ n_channels: int
+ Number of channels in the image
+ height: int
+ Image height
+ width: int
+ Image width
+ classes: int
+ Number of output classes
+ nonlin: str
+ name of the activation activation layers. Can be relu, elu, prelu or leaky_relu
+ last_act: str
+ output activation
+ dropout: float = 0.2
+ Dropout value
+ n_residual_units: int
+ integer stating number of convolutions in residual units, 0 means no residual units
+ device: str
+ PyTorch device. cpu, cuda
+ # Training
+ lr: float
+ learning rate for optimizer.
+ weight_decay: float
+ l2 (ridge) penalty for the weights.
+ n_iter: int
+ Maximum number of iterations.
+ batch_size: int
+ Batch size
+ n_iter_print: int
+ Number of iterations after which to print updates and check the validation loss.
+ random_state: int
+ random_state used
+ patience: int
+ Number of iterations to wait before early stopping after decrease in validation loss
+ n_iter_min: int
+ Minimum number of iterations to go through before starting early stopping
+ clipping_value: int, default 1
+ Gradients clipping value
+ early_stopping: bool
+ Enable/disable early stopping
+
+
+ """
+ if strategy == "predefined":
+ if height == 32 and width == 32:
+ start_stride = 2
+ elif height == 64 and width == 64:
+ start_stride = 4
+ elif height == 128 and width == 128:
+ start_stride = 8
+ else:
+ raise ValueError(
+ f"Unsupported predefined arch : ({n_channels}, {height}, {width})"
+ )
+
+ clf = Classifier(
+ in_shape=(n_channels, height, width),
+ classes=classes,
+ channels=[16, 32, 64, 1],
+ strides=[start_stride, 2, 2, 2],
+ act=map_nonlin(nonlin),
+ last_act=map_nonlin(last_nonlin),
+ dropout=dropout,
+ num_res_units=n_residual_units,
+ ).to(device)
+ return ConvNet(
+ task_type="classification",
+ model=clf,
+ device=device,
+ lr=lr,
+ weight_decay=weight_decay,
+ opt_betas=opt_betas,
+ n_iter=n_iter,
+ batch_size=batch_size,
+ n_iter_print=n_iter_print,
+ random_state=random_state,
+ patience=patience,
+ n_iter_min=n_iter_min,
+ clipping_value=clipping_value,
+ early_stopping=early_stopping,
+ )
+
+ raise ValueError(f"unsupported image arch : ({n_channels}, {height}, {width})")
diff --git a/src/synthcity/plugins/core/models/image_gan.py b/src/synthcity/plugins/core/models/image_gan.py
new file mode 100644
index 00000000..24252097
--- /dev/null
+++ b/src/synthcity/plugins/core/models/image_gan.py
@@ -0,0 +1,686 @@
+# stdlib
+from typing import Any, Callable, List, Optional, Tuple
+
+# third party
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from opacus import PrivacyEngine
+from pydantic import validate_arguments
+from torch import nn
+from torch.utils.data import DataLoader, sampler
+from tqdm import tqdm
+
+# synthcity absolute
+import synthcity.logger as log
+from synthcity.metrics.weighted_metrics import WeightedMetrics
+from synthcity.plugins.core.dataloader import ImageDataLoader
+from synthcity.plugins.core.dataset import ConditionalDataset, FlexibleDataset
+from synthcity.utils.constants import DEVICE
+from synthcity.utils.reproducibility import clear_cache, enable_reproducible_results
+
+
+def display_imgs(imgs: List[np.ndarray], title: Optional[str] = None) -> None:
+ for i in range(len(imgs)):
+ plt.subplot(1, len(imgs), i + 1)
+ plt.tight_layout()
+ imgs[i] = imgs[i] / 2 + 0.5 # unnormalize
+ plt.imshow(np.transpose(imgs[i].cpu().numpy(), (1, 2, 0)))
+
+ if title is not None:
+ plt.title(title)
+ plt.show()
+
+
+def weights_init(m: nn.Module) -> None:
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+ torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
+ if isinstance(m, nn.BatchNorm2d):
+ torch.nn.init.normal_(m.weight, mean=0.0, std=0.02)
+ torch.nn.init.constant_(m.bias, val=0)
+
+
+class ImageGAN(nn.Module):
+ """
+ .. inheritance-diagram:: synthcity.plugins.core.models.image_gan.ImageGAN
+ :parts: 1
+
+ Basic GAN implementation.
+
+ Args:
+ image_generator: nn.Module
+ Generator model
+ image_discriminator: nn.Module
+ Discriminator model
+ n_units_latent: int
+ Number of latent units
+ n_channels: int
+ Number of channels in the image
+ generator_n_iter: int
+ Maximum number of iterations in the Generator.
+ generator_lr: float = 2e-4
+ Generator learning rate, used by the Adam optimizer
+ generator_weight_decay: float = 1e-3
+ Generator weight decay, used by the Adam optimizer
+ generator_opt_betas: tuple = (0.9, 0.999)
+ Generator initial decay rates, used by the Adam Optimizer
+ generator_extra_penalties: list
+ Additional penalties for the generator. Values: "identifiability_penalty"
+ generator_extra_penalty_cbks: List[Callable]
+ Additional loss callabacks for the generator. Used by the TabularGAN for the conditional loss
+ discriminator_n_iter: int
+ Maximum number of iterations in the discriminator.
+ discriminator_lr: float
+ Discriminator learning rate, used by the Adam optimizer
+ discriminator_weight_decay: float
+ Discriminator weight decay, used by the Adam optimizer
+ discriminator_opt_betas: tuple
+ Initial weight decays for the Adam optimizer
+ batch_size: int
+ Batch size
+ random_state: int
+ random_state used
+ clipping_value: int, default 0
+ Gradients clipping value. Zero disables the feature
+ lambda_gradient_penalty: float = 10
+ Weight for the gradient penalty
+ lambda_identifiability_penalty: float = 0.1
+ Weight for the identifiability penalty, if enabled
+ dataloader_sampler: Optional[sampler.Sampler]
+ Optional sampler for the dataloader, useful for conditional sampling
+ device: Any = DEVICE
+ CUDA/CPU
+ # early stopping
+ n_iter_print: int
+ Number of iterations after which to print updates and check the validation loss.
+ n_iter_min: int
+ Minimum number of iterations to go through before starting early stopping
+ patience: int
+ Max number of iterations without any improvement before early stopping is trigged.
+ patience_metric: Optional[WeightedMetrics]
+ If not None, the metric is used for evaluation the criterion for early stopping.
+ # privacy settings
+ dp_enabled: bool
+ Train the discriminator with Differential Privacy guarantees
+ dp_delta: Optional[float]
+ Optional DP delta: the probability of information accidentally being leaked. Usually 1 / len(dataset)
+ dp_epsilon: float = 3
+ DP epsilon: privacy budget, which is a measure of the amount of privacy that is preserved by a given algorithm. Epsilon is a number that represents the maximum amount of information that an adversary can learn about an individual from the output of a differentially private algorithm. The smaller the value of epsilon, the more private the algorithm is. For example, an algorithm with an epsilon of 0.1 preserves more privacy than an algorithm with an epsilon of 1.0.
+ dp_max_grad_norm: float
+ max grad norm used for gradient clipping
+ dp_secure_mode: bool = False,
+ if True uses noise generation approach robust to floating point arithmetic attacks.
+ """
+
+ @validate_arguments(config=dict(arbitrary_types_allowed=True))
+ def __init__(
+ self,
+ image_generator: nn.Module,
+ image_discriminator: nn.Module,
+ n_units_latent: int,
+ n_channels: int,
+ # generator
+ generator_n_iter: int = 500,
+ generator_lr: float = 2e-4,
+ generator_weight_decay: float = 1e-3,
+ generator_opt_betas: tuple = (0.9, 0.999),
+ generator_extra_penalties: list = [], # "identifiability_penalty"
+ generator_extra_penalty_cbks: List[Callable] = [],
+ # discriminator
+ discriminator_n_iter: int = 1,
+ discriminator_lr: float = 2e-4,
+ discriminator_weight_decay: float = 1e-3,
+ discriminator_opt_betas: tuple = (0.9, 0.999),
+ # training
+ batch_size: int = 100,
+ random_state: int = 0,
+ clipping_value: int = 1,
+ lambda_gradient_penalty: float = 10,
+ lambda_identifiability_penalty: float = 0.1,
+ device: Any = DEVICE,
+ n_iter_min: int = 100,
+ n_iter_print: int = 1,
+ plot_progress: int = False,
+ patience: int = 20,
+ patience_metric: Optional[WeightedMetrics] = None,
+ dataloader_sampler: Optional[sampler.Sampler] = None,
+ # privacy settings
+ dp_enabled: bool = False,
+ dp_delta: Optional[float] = None,
+ dp_epsilon: float = 3,
+ dp_max_grad_norm: float = 2,
+ dp_secure_mode: bool = False,
+ ) -> None:
+ super(ImageGAN, self).__init__()
+
+ extra_penalty_list = ["identifiability_penalty"]
+ for penalty in generator_extra_penalties:
+ if penalty not in extra_penalty_list:
+ raise ValueError(f"Unsupported generator penalty {penalty}")
+
+ log.info(f"Training ImageGAN on device {device}.")
+ self.device = device
+ self.generator_extra_penalties = generator_extra_penalties
+ self.generator_extra_penalty_cbks = generator_extra_penalty_cbks
+
+ self.generator = image_generator.apply(weights_init)
+ self.discriminator = image_discriminator.apply(weights_init)
+
+ self.n_units_latent = n_units_latent
+ self.n_channels = n_channels
+ self.plot_progress = plot_progress
+
+ # training
+ self.generator_n_iter = generator_n_iter
+ self.generator_lr = generator_lr
+ self.generator_weight_decay = generator_weight_decay
+ self.generator_opt_betas = generator_opt_betas
+
+ self.discriminator_n_iter = discriminator_n_iter
+ self.discriminator_lr = discriminator_lr
+ self.discriminator_weight_decay = discriminator_weight_decay
+ self.discriminator_opt_betas = discriminator_opt_betas
+
+ self.n_iter_print = n_iter_print
+ self.n_iter_min = n_iter_min
+ self.patience = patience
+ self.patience_metric = patience_metric
+ self.batch_size = batch_size
+ self.clipping_value = clipping_value
+
+ self.lambda_gradient_penalty = lambda_gradient_penalty
+ self.lambda_identifiability_penalty = lambda_identifiability_penalty
+
+ self.random_state = random_state
+ enable_reproducible_results(random_state)
+
+ def gen_fake_labels(X: torch.Tensor) -> torch.Tensor:
+ return torch.zeros((len(X),), device=self.device)
+
+ def gen_true_labels(X: torch.Tensor) -> torch.Tensor:
+ return torch.ones((len(X),), device=self.device)
+
+ self.fake_labels_generator = gen_fake_labels
+ self.true_labels_generator = gen_true_labels
+ self.dataloader_sampler = dataloader_sampler
+
+ # privacy
+ self.dp_enabled = dp_enabled
+ self.dp_delta = dp_delta
+ self.dp_epsilon = dp_epsilon
+ self.dp_max_grad_norm = dp_max_grad_norm
+ self.dp_secure_mode = dp_secure_mode
+
+ def _get_noise(self, n_samples: int) -> torch.Tensor:
+ """
+ Generate noise vectors from the random normal distribution with dimensions (n_samples, noise_dim),
+ where
+ n_samples: the number of samples to generate based on batch_size
+ """
+
+ return torch.randn(
+ n_samples, self.n_units_latent, self.n_channels, 1, device=self.device
+ )
+
+ def fit(
+ self,
+ X: FlexibleDataset,
+ cond: Optional[torch.Tensor] = None,
+ fake_labels_generator: Optional[Callable] = None,
+ true_labels_generator: Optional[Callable] = None,
+ ) -> "ImageGAN":
+ clear_cache()
+
+ self.with_conditional = False
+ if cond is not None:
+ cond = self._check_tensor(cond)
+ self.with_conditional = True
+
+ self._train(
+ X,
+ cond=cond,
+ fake_labels_generator=fake_labels_generator,
+ true_labels_generator=true_labels_generator,
+ )
+
+ return self
+
+ def generate(
+ self,
+ count: int,
+ cond: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ clear_cache()
+ self.generator.eval()
+ with torch.no_grad():
+ return self(count, cond=cond).detach()
+
+ @validate_arguments(config=dict(arbitrary_types_allowed=True))
+ def forward(
+ self,
+ count: int,
+ cond: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ fixed_noise = self._get_noise(count)
+
+ if cond is not None:
+ cond = self._check_tensor(cond)
+
+ return self.generator(fixed_noise, cond=cond)
+
+ def dataloader(
+ self, dataset: torch.utils.data.Dataset, cond: Optional[torch.Tensor] = None
+ ) -> DataLoader:
+ if cond is None:
+ return DataLoader(
+ dataset,
+ batch_size=self.batch_size,
+ sampler=self.dataloader_sampler,
+ pin_memory=False,
+ shuffle=True,
+ )
+
+ cond_dataset = ConditionalDataset(dataset, cond)
+ return DataLoader(
+ cond_dataset,
+ batch_size=self.batch_size,
+ sampler=self.dataloader_sampler,
+ pin_memory=False,
+ shuffle=True,
+ )
+
+ def _train_epoch_generator(
+ self,
+ X: torch.Tensor,
+ fake_labels_generator: Callable,
+ true_labels_generator: Callable,
+ cond: Optional[torch.Tensor] = None,
+ ) -> float:
+ # Update the G network
+ self.generator.train()
+
+ real_X = X.to(self.device)
+ batch_size = len(real_X)
+
+ noise = self._get_noise(batch_size)
+
+ fake = self.generator(noise, cond=cond)
+
+ output = self.discriminator(fake, cond=cond).squeeze().float()
+ # Calculate G's loss based on this output
+ errG = -torch.mean(output)
+ for extra_loss in self.generator_extra_penalty_cbks:
+ errG += extra_loss(
+ real_X,
+ fake,
+ )
+
+ errG += self._extra_penalties(
+ self.generator_extra_penalties,
+ real_samples=real_X,
+ fake_samples=fake,
+ batch_size=batch_size,
+ )
+
+ # Calculate gradients for G
+ self.generator_optimizer.zero_grad()
+ errG.backward()
+
+ # Update G
+ if self.clipping_value > 0:
+ torch.nn.utils.clip_grad_norm_(
+ self.generator.parameters(), self.clipping_value
+ )
+ self.generator_optimizer.step()
+
+ if torch.isnan(errG):
+ raise RuntimeError("NaNs detected in the generator loss")
+
+ # Return loss
+ return errG.item()
+
+ def _train_epoch_discriminator(
+ self,
+ X: torch.Tensor,
+ fake_labels_generator: Callable,
+ true_labels_generator: Callable,
+ cond: Optional[torch.Tensor] = None,
+ ) -> float:
+ # Update the D network
+ self.discriminator.train()
+
+ errors = []
+
+ batch_size = min(self.batch_size, len(X))
+
+ for epoch in range(self.discriminator_n_iter):
+ # Train with all-real batch
+ real_X = X.to(self.device)
+
+ real_labels = true_labels_generator(X).to(self.device).squeeze()
+ real_output = self.discriminator(real_X, cond=cond).squeeze().float()
+
+ # Train with all-fake batch
+ noise = self._get_noise(batch_size)
+
+ fake = self.generator(noise, cond=cond)
+
+ fake_labels = fake_labels_generator(fake).to(self.device).squeeze().float()
+ fake_output = self.discriminator(fake.detach(), cond=cond).squeeze()
+
+ # Compute errors. Some fake inputs might be marked as real for privacy guarantees.
+
+ real_real_output = real_output[(real_labels * real_output) != 0]
+ real_fake_output = fake_output[(fake_labels * fake_output) != 0]
+ errD_real = torch.mean(torch.concat((real_real_output, real_fake_output)))
+
+ fake_real_output = real_output[((1 - real_labels) * real_output) != 0]
+ fake_fake_output = fake_output[((1 - fake_labels) * fake_output) != 0]
+ errD_fake = torch.mean(torch.concat((fake_real_output, fake_fake_output)))
+
+ penalty = self._loss_gradient_penalty(
+ real_samples=real_X,
+ fake_samples=fake,
+ cond=cond,
+ )
+ errD = -errD_real + errD_fake
+
+ self.discriminator_optimizer.zero_grad()
+ if self.dp_enabled:
+ # Adversarial loss
+ # 1. split fwd-bkwd on fake and real images into two explicit blocks.
+ # 2. no need to compute per_sample_gardients on fake data, disable hooks.
+ # 3. re-enable hooks to obtain per_sample_gardients for real data.
+ # fake fwd-bkwd
+ self.discriminator.disable_hooks()
+ penalty.backward(retain_graph=True)
+ errD_fake.backward(retain_graph=True)
+
+ self.discriminator.enable_hooks()
+ errD_real.backward() # HACK: calling bkwd without zero_grad() accumulates param gradients
+ else:
+ errD += penalty
+ errD.backward()
+
+ # Update D
+ if self.clipping_value > 0:
+ torch.nn.utils.clip_grad_norm_(
+ self.discriminator.parameters(), self.clipping_value
+ )
+ self.discriminator_optimizer.step()
+
+ errors.append(errD.item())
+
+ if np.isnan(np.mean(errors)):
+ raise RuntimeError("NaNs detected in the discriminator loss")
+
+ return np.mean(errors)
+
+ def _train_epoch(
+ self,
+ loader: DataLoader,
+ fake_labels_generator: Optional[Callable] = None,
+ true_labels_generator: Optional[Callable] = None,
+ ) -> Tuple[float, float]:
+ if fake_labels_generator is None:
+ fake_labels_generator = self.fake_labels_generator
+ if true_labels_generator is None:
+ true_labels_generator = self.true_labels_generator
+
+ G_losses = []
+ D_losses = []
+
+ for i, data in enumerate(loader):
+ cond: Optional[torch.Tensor] = None
+ if self.with_conditional:
+ X, cond = data
+ else:
+ X = data[0]
+
+ D_losses.append(
+ self._train_epoch_discriminator(
+ X,
+ fake_labels_generator=fake_labels_generator,
+ true_labels_generator=true_labels_generator,
+ cond=cond,
+ )
+ )
+ G_losses.append(
+ self._train_epoch_generator(
+ X,
+ fake_labels_generator=fake_labels_generator,
+ true_labels_generator=true_labels_generator,
+ cond=cond,
+ )
+ )
+
+ return np.mean(G_losses), np.mean(D_losses)
+
+ def _init_patience_score(self) -> float:
+ if self.patience_metric is None:
+ return 0
+
+ if self.patience_metric.direction() == "minimize":
+ return np.inf
+ else:
+ return -np.inf
+
+ def _evaluate_patience_metric(
+ self,
+ X: torch.Tensor,
+ cond: Optional[torch.Tensor],
+ prev_score: float,
+ patience: int,
+ ) -> Tuple[float, int, bool]:
+ save = False
+ if self.patience_metric is None:
+ return prev_score, patience, save
+
+ X_syn = self.generate(len(X), cond=cond)
+ new_score = self.patience_metric.evaluate(
+ ImageDataLoader(ConditionalDataset(X)),
+ ImageDataLoader(ConditionalDataset(X_syn)),
+ )
+ score = prev_score
+ if self.patience_metric.direction() == "minimize":
+ if new_score >= prev_score:
+ patience += 1
+ else:
+ patience = 0
+ score = new_score
+ save = True
+ else:
+ if new_score <= prev_score:
+ patience += 1
+ else:
+ patience = 0
+ score = new_score
+ save = True
+
+ return score, patience, save
+
+ def _train_test_split(
+ self, X: FlexibleDataset, cond: Optional[torch.Tensor] = None
+ ) -> Tuple:
+ if self.patience_metric is None:
+ return X, cond, None, None
+
+ if self.dataloader_sampler is not None:
+ train_idx, test_idx = self.dataloader_sampler.train_test()
+ else:
+ total = np.arange(0, len(X))
+ np.random.shuffle(total)
+ split = int(len(total) * 0.8)
+ train_idx, test_idx = total[:split], total[split:]
+
+ X_train, X_val = X.filter_indices(train_idx), X.filter_indices(test_idx)
+ cond_train, cond_val = None, None
+ if cond is not None:
+ cond_train, cond_val = cond[train_idx], cond[test_idx]
+ return X_train, cond_train, X_val, cond_val
+
+ def _train(
+ self,
+ X: FlexibleDataset,
+ cond: Optional[torch.Tensor] = None,
+ fake_labels_generator: Optional[Callable] = None,
+ true_labels_generator: Optional[Callable] = None,
+ ) -> "ImageGAN":
+ self.train()
+
+ X, cond, X_val, cond_val = self._train_test_split(X, cond)
+
+ # Load Dataset
+ loader = self.dataloader(X, cond)
+
+ # Create the optimizers
+ self.generator_optimizer = torch.optim.Adam(
+ self.generator.parameters(),
+ lr=self.generator_lr,
+ weight_decay=self.generator_weight_decay,
+ betas=self.generator_opt_betas,
+ )
+ self.discriminator_optimizer = torch.optim.Adam(
+ self.discriminator.parameters(),
+ lr=self.discriminator_lr,
+ weight_decay=self.discriminator_weight_decay,
+ betas=self.discriminator_opt_betas,
+ )
+
+ # Privacy
+ if self.dp_enabled:
+ if self.dp_delta is None:
+ self.dp_delta = 1 / len(X)
+
+ privacy_engine = PrivacyEngine(secure_mode=self.dp_secure_mode)
+
+ (
+ self.discriminator,
+ self.discriminator_optimizer,
+ loader,
+ ) = privacy_engine.make_private_with_epsilon(
+ module=self.discriminator,
+ optimizer=self.discriminator_optimizer,
+ data_loader=loader,
+ epochs=self.generator_n_iter,
+ target_epsilon=self.dp_epsilon,
+ target_delta=self.dp_delta,
+ max_grad_norm=self.dp_max_grad_norm,
+ poisson_sampling=False,
+ )
+
+ # Train loop
+ patience_score = self._init_patience_score()
+ patience = 0
+ best_state_dict = None
+
+ for i in tqdm(range(self.generator_n_iter)):
+ g_loss, d_loss = self._train_epoch(
+ loader,
+ fake_labels_generator=fake_labels_generator,
+ true_labels_generator=true_labels_generator,
+ )
+ # Check how the generator is doing by saving G's output on fixed_noise
+ if (i + 1) % self.n_iter_print == 0:
+ if self.plot_progress:
+ display_imgs(self.generate(5), title="synthetic samples")
+ display_imgs(
+ torch.from_numpy(X_val.numpy()[0])[:5], title="real samples"
+ )
+
+ log.debug(
+ f"[{i}/{self.generator_n_iter}]\tLoss_D: {d_loss}\tLoss_G: {g_loss} Patience score: {patience_score} Patience : {patience}"
+ )
+ if self.dp_enabled:
+ log.debug(
+ f"[{i}/{self.generator_n_iter}] Privacy budget: epsilon = {privacy_engine.get_epsilon(self.dp_delta)} delta = {self.dp_delta}"
+ )
+
+ if self.patience_metric is not None:
+ patience_score, patience, save = self._evaluate_patience_metric(
+ torch.from_numpy(X_val.numpy()[0]),
+ cond_val,
+ patience_score,
+ patience,
+ )
+ if save:
+ best_state_dict = self.state_dict()
+
+ if patience >= self.patience and i >= self.n_iter_min:
+ log.debug(f"[{i}/{self.generator_n_iter}] Early stopping")
+ break
+
+ if best_state_dict is not None:
+ self.load_state_dict(best_state_dict)
+
+ return self
+
+ def _check_tensor(self, X: torch.Tensor) -> torch.Tensor:
+ if isinstance(X, torch.Tensor):
+ return X.to(self.device)
+ else:
+ return torch.from_numpy(np.asarray(X)).to(self.device)
+
+ def _extra_penalties(
+ self,
+ penalties: list,
+ real_samples: torch.tensor,
+ fake_samples: torch.Tensor,
+ batch_size: int,
+ ) -> torch.Tensor:
+ """Calculates additional penalties for the training"""
+ err: torch.Tensor = 0
+ for penalty in penalties:
+ if penalty == "identifiability_penalty":
+ err += self._loss_identifiability_penalty(
+ real_samples=real_samples,
+ fake_samples=fake_samples,
+ )
+ else:
+ raise RuntimeError(f"unknown penalty {penalty}")
+ return err
+
+ def _loss_gradient_penalty(
+ self,
+ real_samples: torch.Tensor,
+ fake_samples: torch.Tensor,
+ cond: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ batch_size, channel, height, width = real_samples.shape
+ # alpha is selected randomly between 0 and 1
+ alpha = (
+ torch.rand(batch_size, 1, 1, 1)
+ .repeat(1, channel, height, width)
+ .to(self.device)
+ )
+ # interpolated image=randomly weighted average between a real and fake image
+ interpolatted_samples = (alpha * real_samples) + (1 - alpha) * fake_samples
+
+ # calculate the critic score on the interpolated image
+ interpolated_score = self.discriminator(interpolatted_samples, cond=cond)
+
+ # take the gradient of the score wrt to the interpolated image
+ gradient = torch.autograd.grad(
+ inputs=interpolatted_samples,
+ outputs=interpolated_score,
+ grad_outputs=torch.ones_like(interpolated_score),
+ create_graph=True,
+ retain_graph=True,
+ only_inputs=True,
+ allow_unused=True,
+ )[0]
+ gradient = gradient.view(gradient.shape[0], -1)
+ gradient_norm = gradient.norm(2, dim=1)
+ gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
+ return self.lambda_gradient_penalty * gradient_penalty
+
+ def _loss_identifiability_penalty(
+ self,
+ real_samples: torch.tensor,
+ fake_samples: torch.Tensor,
+ ) -> torch.Tensor:
+ """Calculates the identifiability penalty. Section C in the paper"""
+ return (
+ -self.lambda_identifiability_penalty
+ * (real_samples - fake_samples).square().sum(dim=-1).sqrt().mean()
+ )
diff --git a/src/synthcity/plugins/core/models/tabular_vae.py b/src/synthcity/plugins/core/models/tabular_vae.py
index fcb88ccc..892433e2 100644
--- a/src/synthcity/plugins/core/models/tabular_vae.py
+++ b/src/synthcity/plugins/core/models/tabular_vae.py
@@ -60,7 +60,7 @@ class TabularVAE(nn.Module):
encoder_dropout: float
Dropout value for the encoder. If 0, the dropout is not used.
lr: float
- learning rate for optimizer. step_size equivalent in the JAX version.
+ learning rate for optimizer.
weight_decay: float
l2 (ridge) penalty for the weights.
batch_size: int
diff --git a/src/synthcity/plugins/core/models/ts_gan.py b/src/synthcity/plugins/core/models/ts_gan.py
index 6bdd372b..2ca677a2 100644
--- a/src/synthcity/plugins/core/models/ts_gan.py
+++ b/src/synthcity/plugins/core/models/ts_gan.py
@@ -71,7 +71,7 @@ class TimeSeriesGAN(nn.Module):
discriminator_dropout: float. Default = 0.1
Dropout value for the discriminator. If 0, the dropout is not used.
discriminator_lr: float. Default = 2e-4
- learning rate for discriminator optimizer. step_size equivalent in the JAX version.
+ learning rate for discriminator optimizer.
discriminator_weight_decay: float. Default = 1e-3
l2 (ridge) penalty for the discriminator weights.
batch_size: int. Default = 64
diff --git a/src/synthcity/plugins/core/models/ts_tabular_gan.py b/src/synthcity/plugins/core/models/ts_tabular_gan.py
index c20d54b4..dc05bfa7 100644
--- a/src/synthcity/plugins/core/models/ts_tabular_gan.py
+++ b/src/synthcity/plugins/core/models/ts_tabular_gan.py
@@ -66,7 +66,7 @@ class TimeSeriesTabularGAN(torch.nn.Module):
discriminator_dropout: float. Default = 0.1
Dropout value for the discriminator. If 0, the dropout is not used.
discriminator_lr: float. Default = 2e-4
- learning rate for discriminator optimizer. step_size equivalent in the JAX version.
+ learning rate for discriminator optimizer.
discriminator_weight_decay: float. Default = 1e-3
l2 (ridge) penalty for the discriminator weights.
batch_size: int. Default = 64
diff --git a/src/synthcity/plugins/core/models/ts_tabular_vae.py b/src/synthcity/plugins/core/models/ts_tabular_vae.py
index f596445e..ce5b778d 100644
--- a/src/synthcity/plugins/core/models/ts_tabular_vae.py
+++ b/src/synthcity/plugins/core/models/ts_tabular_vae.py
@@ -54,7 +54,7 @@ class TimeSeriesTabularVAE(torch.nn.Module):
encoder_dropout: float
Dropout value for the encoder. If 0, the dropout is not used.
lr: float
- learning rate for optimizer. step_size equivalent in the JAX version.
+ learning rate for optimizer.
weight_decay: float
l2 (ridge) penalty for the weights.
batch_size: int
diff --git a/src/synthcity/plugins/core/plugin.py b/src/synthcity/plugins/core/plugin.py
index a8eb0199..725974af 100644
--- a/src/synthcity/plugins/core/plugin.py
+++ b/src/synthcity/plugins/core/plugin.py
@@ -221,18 +221,19 @@ def fit(self, X: Union[DataLoader, pd.DataFrame], *args: Any, **kwargs: Any) ->
random_state=self.random_state,
)
- X, self._data_encoders = X.encode()
- if self.compress_dataset:
- X_hash = X.hash()
- bkp_file = (
- self.workspace
- / f"compressed_df_{X_hash}_{platform.python_version()}.bkp"
- )
- if not bkp_file.exists():
- X_compressed_context = X.compress()
- save_to_file(bkp_file, X_compressed_context)
+ if X.is_tabular():
+ X, self._data_encoders = X.encode()
+ if self.compress_dataset:
+ X_hash = X.hash()
+ bkp_file = (
+ self.workspace
+ / f"compressed_df_{X_hash}_{platform.python_version()}.bkp"
+ )
+ if not bkp_file.exists():
+ X_compressed_context = X.compress()
+ save_to_file(bkp_file, X_compressed_context)
- X, self.compress_context = load_from_file(bkp_file)
+ X, self.compress_context = load_from_file(bkp_file)
self._training_schema = Schema(
data=X,
@@ -336,10 +337,12 @@ def generate(
syn_schema = Schema.from_constraints(gen_constraints)
X_syn = self._generate(count=count, syn_schema=syn_schema, **kwargs)
- if self.compress_dataset:
- X_syn = X_syn.decompress(self.compress_context)
- if self._data_encoders is not None:
- X_syn = X_syn.decode(self._data_encoders)
+
+ if X_syn.is_tabular():
+ if self.compress_dataset:
+ X_syn = X_syn.decompress(self.compress_context)
+ if self._data_encoders is not None:
+ X_syn = X_syn.decode(self._data_encoders)
# The dataset is decompressed here, we can use the public schema
gen_constraints = self.schema().as_constraints()
@@ -470,6 +473,14 @@ def _safe_generate_time_series(
data_synth = self.training_schema().adapt_dtypes(data_synth)
return create_from_info(data_synth, data_info)
+ @validate_arguments(config=dict(arbitrary_types_allowed=True))
+ def _safe_generate_images(
+ self, gen_cbk: Callable, count: int, syn_schema: Schema, **kwargs: Any
+ ) -> DataLoader:
+ data_synth = gen_cbk(count, **kwargs)
+
+ return create_from_info(data_synth, self.data_info)
+
@validate_arguments(config=dict(arbitrary_types_allowed=True))
def schema_includes(self, other: Union[DataLoader, pd.DataFrame]) -> bool:
"""Helper method to test if the reference schema includes a Dataset
diff --git a/src/synthcity/plugins/domain_adaptation/plugin_radialgan.py b/src/synthcity/plugins/domain_adaptation/plugin_radialgan.py
index 60a7c5f5..9287b84e 100644
--- a/src/synthcity/plugins/domain_adaptation/plugin_radialgan.py
+++ b/src/synthcity/plugins/domain_adaptation/plugin_radialgan.py
@@ -70,7 +70,7 @@ class RadialGAN(nn.Module):
discriminator_dropout: float
Dropout value for the discriminator. If 0, the dropout is not used.
lr: float
- learning rate for optimizer. step_size equivalent in the JAX version.
+ learning rate for optimizer.
weight_decay: float
l2 (ridge) penalty for the weights.
batch_size: int
@@ -594,7 +594,7 @@ class TabularRadialGAN(torch.nn.Module):
discriminator_dropout: float
Dropout value for the discriminator. If 0, the dropout is not used.
lr: float
- learning rate for optimizer. step_size equivalent in the JAX version.
+ learning rate for optimizer..
weight_decay: float
l2 (ridge) penalty for the weights.
batch_size: int
@@ -754,7 +754,7 @@ class RadialGANPlugin(Plugin):
discriminator_dropout: float
Dropout value for the discriminator. If 0, the dropout is not used.
lr: float
- learning rate for optimizer. step_size equivalent in the JAX version.
+ learning rate for optimizer.
weight_decay: float
l2 (ridge) penalty for the weights.
batch_size: int
diff --git a/src/synthcity/plugins/generic/plugin_ctgan.py b/src/synthcity/plugins/generic/plugin_ctgan.py
index d284491f..f2816280 100644
--- a/src/synthcity/plugins/generic/plugin_ctgan.py
+++ b/src/synthcity/plugins/generic/plugin_ctgan.py
@@ -59,7 +59,7 @@ class CTGANPlugin(Plugin):
discriminator_dropout: float
Dropout value for the discriminator. If 0, the dropout is not used.
lr: float
- learning rate for optimizer. step_size equivalent in the JAX version.
+ learning rate for optimizer.
weight_decay: float
l2 (ridge) penalty for the weights.
batch_size: int
diff --git a/src/synthcity/plugins/generic/plugin_rtvae.py b/src/synthcity/plugins/generic/plugin_rtvae.py
index 731bcd65..edc6d9de 100644
--- a/src/synthcity/plugins/generic/plugin_rtvae.py
+++ b/src/synthcity/plugins/generic/plugin_rtvae.py
@@ -56,7 +56,7 @@ class RTVAEPlugin(Plugin):
n_iter: int
Maximum number of iterations in the encoder.
lr: float
- learning rate for optimizer. step_size equivalent in the JAX version.
+ learning rate for optimizer.
weight_decay: float
l2 (ridge) penalty for the weights.
batch_size: int
diff --git a/src/synthcity/plugins/generic/plugin_tvae.py b/src/synthcity/plugins/generic/plugin_tvae.py
index 3eb92e63..34ad5271 100644
--- a/src/synthcity/plugins/generic/plugin_tvae.py
+++ b/src/synthcity/plugins/generic/plugin_tvae.py
@@ -51,7 +51,7 @@ class TVAEPlugin(Plugin):
n_iter: int
Maximum number of iterations in the encoder.
lr: float
- learning rate for optimizer. step_size equivalent in the JAX version.
+ learning rate for optimizer.
weight_decay: float
l2 (ridge) penalty for the weights.
batch_size: int
diff --git a/src/synthcity/plugins/images/__init__.py b/src/synthcity/plugins/images/__init__.py
new file mode 100644
index 00000000..b494cf09
--- /dev/null
+++ b/src/synthcity/plugins/images/__init__.py
@@ -0,0 +1,18 @@
+# stdlib
+import glob
+from os.path import basename, dirname, isfile, join
+
+# synthcity absolute
+from synthcity.plugins.core.plugin import Plugin, PluginLoader # noqa: F401,E402
+
+plugins = glob.glob(join(dirname(__file__), "plugin*.py"))
+
+
+class ImagePlugins(PluginLoader):
+ def __init__(self) -> None:
+ super().__init__(plugins, Plugin, ["privacy"])
+
+
+__all__ = [basename(f)[:-3] for f in plugins if isfile(f)] + [
+ "ImagePlugins",
+]
diff --git a/src/synthcity/plugins/images/plugin_image_adsgan.py b/src/synthcity/plugins/images/plugin_image_adsgan.py
new file mode 100644
index 00000000..a47758ee
--- /dev/null
+++ b/src/synthcity/plugins/images/plugin_image_adsgan.py
@@ -0,0 +1,338 @@
+# stdlib
+from pathlib import Path
+from typing import Any, List, Optional
+
+# third party
+import numpy as np
+import torch
+
+# Necessary packages
+from pydantic import validate_arguments
+from torch import nn
+
+# synthcity absolute
+import synthcity.logger as log
+from synthcity.metrics.weighted_metrics import WeightedMetrics
+from synthcity.plugins.core.dataloader import DataLoader
+from synthcity.plugins.core.dataset import TensorDataset
+from synthcity.plugins.core.distribution import (
+ CategoricalDistribution,
+ Distribution,
+ FloatDistribution,
+)
+from synthcity.plugins.core.models.convnet import (
+ suggest_image_classifier_arch,
+ suggest_image_generator_discriminator_arch,
+)
+from synthcity.plugins.core.models.image_gan import ImageGAN
+from synthcity.plugins.core.plugin import Plugin
+from synthcity.plugins.core.schema import Schema
+from synthcity.utils.constants import DEVICE
+
+
+class ImageAdsGANPlugin(Plugin):
+ """
+ .. inheritance-diagram:: synthcity.plugins.images.plugin_image_adsgan.ImageAdsGANPlugin
+ :parts: 1
+
+ Image AdsGAN - Anonymization through Data Synthesis using Generative Adversarial Networks.
+
+ Args:
+ n_units_latent: int
+ The noise units size used by the generator.
+ n_iter: int
+ Maximum number of iterations in the Generator.
+ generator_nonlin: string, default 'leaky_relu'
+ Nonlinearity to use in the generator. Can be 'elu', 'relu', 'selu' or 'leaky_relu'.
+ generator_dropout: float
+ Dropout value. If 0, the dropout is not used.
+ generator_n_residual_units: int
+ The number of convolutions in residual units for the generator, 0 means no residual units
+ discriminator_nonlin: string, default 'leaky_relu'
+ Nonlinearity to use in the discriminator. Can be 'elu', 'relu', 'selu' or 'leaky_relu'.
+ discriminator_n_iter: int
+ Maximum number of iterations in the discriminator.
+ discriminator_dropout: float
+ Dropout value for the discriminator. If 0, the dropout is not used.
+ discriminator_n_residual_units: int
+ The number of convolutions in residual units for the discriminator, 0 means no residual units
+ # training parameters
+ lr: float
+ learning rate for optimizer
+ weight_decay: float
+ l2 (ridge) penalty for the weights.
+ batch_size: int
+ Batch size
+ random_state: int
+ random seed to use
+ clipping_value: int, default 0
+ Gradients clipping value. Zero disables the feature
+ lambda_gradient_penalty: float = 10
+ Weight for the gradient penalty
+ lambda_identifiability_penalty: float = 0.1
+ Weight for the identifiability penalty
+ device: torch device
+ Device: cpu or cuda
+ plot_progress: bool
+ Plot some synthetic samples every `n_iter_print`
+ # early stopping
+ n_iter_print: int
+ Number of iterations after which to print updates and check the validation loss.
+ n_iter_min: int
+ Minimum number of iterations to go through before starting early stopping
+ early_stopping: bool
+ Evaluate the quality of the synthetic data using `patience_metric`, and stop after `patience` iteration with no improvement.
+ patience: int
+ Max number of iterations without any improvement before training early stopping is trigged.
+ patience_metric: Optional[WeightedMetrics]
+ If not None, the metric is used for evaluation the criterion for training early stopping.
+ # Core Plugin arguments
+ workspace: Path.
+ Optional Path for caching intermediary results.
+
+ Example:
+ >>> from torchvision import datasets
+ >>> from synthcity.plugins import Plugins
+ >>> from synthcity.plugins.core.dataloader import ImageDataLoader
+ >>>
+ >>> model = Plugins().get("image_adsgan", n_iter = 10)
+ >>>
+ >>> dataset = datasets.MNIST(".", download=True)
+ >>> dataloader = ImageDataLoader(dataset).sample(100)
+ >>>
+ >>> model.fit(dataloader)
+ >>>
+ >>> X_gen = model.generate(50)
+ >>> assert len(X_gen) == 50
+ """
+
+ @validate_arguments(config=dict(arbitrary_types_allowed=True))
+ def __init__(
+ self,
+ n_units_latent: int = 100,
+ n_iter: int = 1000,
+ generator_nonlin: str = "relu",
+ generator_dropout: float = 0.1,
+ generator_n_residual_units: int = 2,
+ discriminator_nonlin: str = "leaky_relu",
+ discriminator_n_iter: int = 5,
+ discriminator_dropout: float = 0.1,
+ discriminator_n_residual_units: int = 2,
+ # training
+ lr: float = 2e-4,
+ weight_decay: float = 1e-3,
+ opt_betas: tuple = (0.5, 0.999),
+ batch_size: int = 200,
+ random_state: int = 0,
+ clipping_value: int = 1,
+ lambda_gradient_penalty: float = 10,
+ lambda_identifiability_penalty: float = 0.1,
+ device: Any = DEVICE,
+ # early stopping
+ patience: int = 5,
+ patience_metric: Optional[WeightedMetrics] = None,
+ n_iter_print: int = 50,
+ n_iter_min: int = 100,
+ plot_progress: int = False,
+ early_stopping: bool = True,
+ # core plugin arguments
+ workspace: Path = Path("workspace"),
+ sampling_patience: int = 500,
+ **kwargs: Any
+ ) -> None:
+ super().__init__(
+ device=device,
+ random_state=random_state,
+ sampling_patience=sampling_patience,
+ workspace=workspace,
+ compress_dataset=False,
+ **kwargs
+ )
+ if patience_metric is None:
+ patience_metric = WeightedMetrics(
+ metrics=[("detection", "detection_mlp")],
+ weights=[1],
+ workspace=workspace,
+ )
+
+ self.n_units_latent = n_units_latent
+ self.n_iter = n_iter
+ self.generator_nonlin = generator_nonlin
+ self.generator_dropout = generator_dropout
+ self.generator_n_residual_units = generator_n_residual_units
+ self.discriminator_nonlin = discriminator_nonlin
+ self.discriminator_n_iter = discriminator_n_iter
+ self.discriminator_dropout = discriminator_dropout
+ self.discriminator_n_residual_units = discriminator_n_residual_units
+
+ self.lr = lr
+ self.weight_decay = weight_decay
+ self.opt_betas = opt_betas
+
+ self.batch_size = batch_size
+ self.random_state = random_state
+ self.clipping_value = clipping_value
+ self.lambda_gradient_penalty = lambda_gradient_penalty
+ self.lambda_identifiability_penalty = lambda_identifiability_penalty
+
+ self.device = device
+ self.patience = patience
+ self.patience_metric = patience_metric
+ self.early_stopping = early_stopping
+ self.n_iter_min = n_iter_min
+ self.n_iter_print = n_iter_print
+ self.plot_progress = plot_progress
+
+ @staticmethod
+ def name() -> str:
+ return "image_adsgan"
+
+ @staticmethod
+ def type() -> str:
+ return "images"
+
+ @staticmethod
+ def hyperparameter_space(**kwargs: Any) -> List[Distribution]:
+ return [
+ CategoricalDistribution(
+ name="generator_nonlin", choices=["relu", "leaky_relu", "tanh", "elu"]
+ ),
+ FloatDistribution(name="generator_dropout", low=0, high=0.2),
+ CategoricalDistribution(
+ name="discriminator_nonlin",
+ choices=["relu", "leaky_relu", "tanh", "elu"],
+ ),
+ FloatDistribution(name="discriminator_dropout", low=0, high=0.2),
+ CategoricalDistribution(name="lr", choices=[1e-3, 2e-4, 1e-4]),
+ CategoricalDistribution(name="weight_decay", choices=[1e-3, 1e-4]),
+ ]
+
+ def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "ImageAdsGANPlugin":
+ if X.type() != "images":
+ raise RuntimeError("Invalid dataloader type for image generators")
+
+ labels = X.unpack().labels()
+ self.classes = np.unique(labels)
+
+ cond = labels
+ if "cond" in kwargs:
+ cond = kwargs["cond"]
+
+ cond = self._prepare_cond(cond)
+
+ # synthetic images
+ (
+ image_generator,
+ image_discriminator,
+ ) = suggest_image_generator_discriminator_arch(
+ n_units_latent=self.n_units_latent,
+ n_channels=X.info()["channels"],
+ height=X.info()["height"],
+ width=X.info()["width"],
+ generator_dropout=self.generator_dropout,
+ generator_nonlin=self.generator_nonlin,
+ generator_n_residual_units=self.generator_n_residual_units,
+ discriminator_dropout=self.discriminator_dropout,
+ discriminator_nonlin=self.discriminator_nonlin,
+ discriminator_n_residual_units=self.discriminator_n_residual_units,
+ device=self.device,
+ strategy="predefined",
+ cond=cond,
+ cond_embedding_n_units_hidden=self.n_units_latent,
+ )
+
+ log.debug("Training the image generator")
+ self.image_generator = ImageGAN(
+ image_generator=image_generator,
+ image_discriminator=image_discriminator,
+ n_units_latent=self.n_units_latent,
+ n_channels=X.info()["channels"],
+ # generator
+ generator_n_iter=self.n_iter,
+ generator_lr=self.lr,
+ generator_weight_decay=self.weight_decay,
+ generator_opt_betas=self.opt_betas,
+ generator_extra_penalties=["identifiability_penalty"],
+ # discriminator
+ discriminator_n_iter=self.discriminator_n_iter,
+ discriminator_lr=self.lr,
+ discriminator_weight_decay=self.weight_decay,
+ discriminator_opt_betas=self.opt_betas,
+ # training
+ batch_size=self.batch_size,
+ random_state=self.random_state,
+ clipping_value=self.clipping_value,
+ lambda_gradient_penalty=self.lambda_gradient_penalty,
+ lambda_identifiability_penalty=self.lambda_identifiability_penalty,
+ device=self.device,
+ n_iter_min=self.n_iter_min,
+ n_iter_print=self.n_iter_print,
+ plot_progress=self.plot_progress,
+ patience=self.patience,
+ patience_metric=self.patience_metric,
+ )
+ self.image_generator.fit(X.unpack(), cond=cond)
+
+ # synthetic labels
+ self.label_generator: Optional[nn.Module] = None
+
+ if labels is not None: # TODO: handle regression
+ log.debug("Training the labels generator")
+ self.label_generator = suggest_image_classifier_arch(
+ n_channels=X.info()["channels"],
+ height=X.info()["height"],
+ width=X.info()["width"],
+ classes=len(np.unique(labels)),
+ n_residual_units=self.generator_n_residual_units,
+ nonlin=self.generator_nonlin,
+ dropout=self.generator_dropout,
+ last_nonlin="softmax",
+ device=self.device,
+ strategy="predefined",
+ # training
+ lr=self.lr,
+ weight_decay=self.weight_decay,
+ opt_betas=self.opt_betas,
+ n_iter=self.n_iter,
+ batch_size=self.batch_size,
+ n_iter_print=self.n_iter_print,
+ random_state=self.random_state,
+ patience=self.patience,
+ n_iter_min=self.n_iter_min,
+ clipping_value=self.clipping_value,
+ early_stopping=self.early_stopping,
+ )
+ self.label_generator.fit(X.unpack())
+
+ return self
+
+ def _generate(self, count: int, syn_schema: Schema, **kwargs: Any) -> DataLoader:
+ def _sample(count: int) -> TensorDataset:
+ cond: Optional[torch.Tensor] = None
+ if "cond" in kwargs:
+ cond = self._prepare_cond(kwargs["cond"])
+ elif self.classes is not None:
+ cond = np.random.choice(self.classes, count)
+ cond = torch.from_numpy(cond).to(self.device)
+
+ sampled_images = self.image_generator.generate(count, cond=cond)
+ sampled_labels: Optional[torch.Tensor] = None
+ if self.label_generator is not None:
+ sampled_labels = self.label_generator.predict(sampled_images)
+
+ return TensorDataset(images=sampled_images, targets=sampled_labels)
+
+ return self._safe_generate_images(_sample, count, syn_schema)
+
+ def _prepare_cond(self, cond: Any) -> Optional[torch.Tensor]:
+ if cond is None:
+ return None
+
+ cond = np.asarray(cond)
+ if len(cond.shape) == 1:
+ cond = cond.reshape(-1, 1)
+
+ return torch.from_numpy(cond).to(self.device)
+
+
+plugin = ImageAdsGANPlugin
diff --git a/src/synthcity/plugins/images/plugin_image_cgan.py b/src/synthcity/plugins/images/plugin_image_cgan.py
new file mode 100644
index 00000000..8b818d49
--- /dev/null
+++ b/src/synthcity/plugins/images/plugin_image_cgan.py
@@ -0,0 +1,331 @@
+# stdlib
+from pathlib import Path
+from typing import Any, List, Optional
+
+# third party
+import numpy as np
+import torch
+
+# Necessary packages
+from pydantic import validate_arguments
+from torch import nn
+
+# synthcity absolute
+import synthcity.logger as log
+from synthcity.metrics.weighted_metrics import WeightedMetrics
+from synthcity.plugins.core.dataloader import DataLoader
+from synthcity.plugins.core.dataset import TensorDataset
+from synthcity.plugins.core.distribution import (
+ CategoricalDistribution,
+ Distribution,
+ FloatDistribution,
+)
+from synthcity.plugins.core.models.convnet import (
+ suggest_image_classifier_arch,
+ suggest_image_generator_discriminator_arch,
+)
+from synthcity.plugins.core.models.image_gan import ImageGAN
+from synthcity.plugins.core.plugin import Plugin
+from synthcity.plugins.core.schema import Schema
+from synthcity.utils.constants import DEVICE
+
+
+class ImageCGANPlugin(Plugin):
+ """
+ .. inheritance-diagram:: synthcity.plugins.images.plugin_image_cgan.ImageCGANPlugin
+ :parts: 1
+
+ Image (Conditional) GAN
+
+ Args:
+ n_units_latent: int
+ The noise units size used by the generator.
+ n_iter: int
+ Maximum number of iterations in the Generator.
+ generator_nonlin: string, default 'leaky_relu'
+ Nonlinearity to use in the generator. Can be 'elu', 'relu', 'selu' or 'leaky_relu'.
+ generator_dropout: float
+ Dropout value. If 0, the dropout is not used.
+ generator_n_residual_units: int
+ The number of convolutions in residual units for the generator, 0 means no residual units
+ discriminator_nonlin: string, default 'leaky_relu'
+ Nonlinearity to use in the discriminator. Can be 'elu', 'relu', 'selu' or 'leaky_relu'.
+ discriminator_n_iter: int
+ Maximum number of iterations in the discriminator.
+ discriminator_dropout: float
+ Dropout value for the discriminator. If 0, the dropout is not used.
+ discriminator_n_residual_units: int
+ The number of convolutions in residual units for the discriminator, 0 means no residual units
+ # training parameters
+ lr: float
+ learning rate for optimizer..
+ weight_decay: float
+ l2 (ridge) penalty for the weights.
+ batch_size: int
+ Batch size
+ random_state: int
+ random seed to use
+ clipping_value: int, default 0
+ Gradients clipping value. Zero disables the feature
+ device: torch device
+ Device: cpu or cuda
+ plot_progress: bool
+ Plot some synthetic samples every `n_iter_print`
+ # early stopping
+ n_iter_print: int
+ Number of iterations after which to print updates and check the validation loss.
+ n_iter_min: int
+ Minimum number of iterations to go through before starting early stopping
+ patience: int
+ Max number of iterations without any improvement before training early stopping is trigged.
+ patience_metric: Optional[WeightedMetrics]
+ If not None, the metric is used for evaluation the criterion for training early stopping.
+ early_stopping: bool
+ Evaluate the quality of the synthetic data using `patience_metric`, and stop after `patience` iteration with no improvement.
+ # Core Plugin arguments
+ workspace: Path.
+ Optional Path for caching intermediary results.
+
+ Example:
+ >>> from torchvision import datasets
+ >>> from synthcity.plugins import Plugins
+ >>> from synthcity.plugins.core.dataloader import ImageDataLoader
+ >>>
+ >>> model = Plugins().get("image_cgan", n_iter = 10)
+ >>>
+ >>> dataset = datasets.MNIST(".", download=True)
+ >>> dataloader = ImageDataLoader(dataset).sample(100)
+ >>>
+ >>> model.fit(dataloader)
+ >>>
+ >>> X_gen = model.generate(50)
+ >>> assert len(X_gen) == 50
+ """
+
+ @validate_arguments(config=dict(arbitrary_types_allowed=True))
+ def __init__(
+ self,
+ n_units_latent: int = 100,
+ n_iter: int = 1000,
+ generator_nonlin: str = "relu",
+ generator_dropout: float = 0.1,
+ generator_n_residual_units: int = 2,
+ discriminator_nonlin: str = "leaky_relu",
+ discriminator_n_iter: int = 5,
+ discriminator_dropout: float = 0.1,
+ discriminator_n_residual_units: int = 2,
+ # training
+ lr: float = 2e-4,
+ weight_decay: float = 1e-3,
+ opt_betas: tuple = (0.5, 0.999),
+ batch_size: int = 200,
+ random_state: int = 0,
+ clipping_value: int = 1,
+ lambda_gradient_penalty: float = 10,
+ device: Any = DEVICE,
+ # early stopping
+ patience: int = 5,
+ patience_metric: Optional[WeightedMetrics] = None,
+ n_iter_print: int = 50,
+ n_iter_min: int = 100,
+ plot_progress: int = False,
+ early_stopping: bool = True,
+ # core plugin arguments
+ workspace: Path = Path("workspace"),
+ sampling_patience: int = 500,
+ **kwargs: Any
+ ) -> None:
+ super().__init__(
+ device=device,
+ random_state=random_state,
+ sampling_patience=sampling_patience,
+ workspace=workspace,
+ compress_dataset=False,
+ **kwargs
+ )
+ if patience_metric is None:
+ patience_metric = WeightedMetrics(
+ metrics=[("detection", "detection_mlp")],
+ weights=[1],
+ workspace=workspace,
+ )
+
+ self.n_units_latent = n_units_latent
+ self.n_iter = n_iter
+ self.generator_nonlin = generator_nonlin
+ self.generator_dropout = generator_dropout
+ self.generator_n_residual_units = generator_n_residual_units
+ self.discriminator_nonlin = discriminator_nonlin
+ self.discriminator_n_iter = discriminator_n_iter
+ self.discriminator_dropout = discriminator_dropout
+ self.discriminator_n_residual_units = discriminator_n_residual_units
+
+ self.lr = lr
+ self.weight_decay = weight_decay
+ self.opt_betas = opt_betas
+
+ self.batch_size = batch_size
+ self.random_state = random_state
+ self.clipping_value = clipping_value
+ self.lambda_gradient_penalty = lambda_gradient_penalty
+
+ self.device = device
+ self.patience = patience
+ self.patience_metric = patience_metric
+ self.early_stopping = early_stopping
+ self.n_iter_min = n_iter_min
+ self.n_iter_print = n_iter_print
+ self.plot_progress = plot_progress
+
+ @staticmethod
+ def name() -> str:
+ return "image_cgan"
+
+ @staticmethod
+ def type() -> str:
+ return "images"
+
+ @staticmethod
+ def hyperparameter_space(**kwargs: Any) -> List[Distribution]:
+ return [
+ CategoricalDistribution(
+ name="generator_nonlin", choices=["relu", "leaky_relu", "tanh", "elu"]
+ ),
+ FloatDistribution(name="generator_dropout", low=0, high=0.2),
+ CategoricalDistribution(
+ name="discriminator_nonlin",
+ choices=["relu", "leaky_relu", "tanh", "elu"],
+ ),
+ FloatDistribution(name="discriminator_dropout", low=0, high=0.2),
+ CategoricalDistribution(name="lr", choices=[1e-3, 2e-4, 1e-4]),
+ CategoricalDistribution(name="weight_decay", choices=[1e-3, 1e-4]),
+ ]
+
+ def _fit(self, X: DataLoader, *args: Any, **kwargs: Any) -> "ImageCGANPlugin":
+ if X.type() != "images":
+ raise RuntimeError("Invalid dataloader type for image generators")
+
+ labels = X.unpack().labels()
+ self.classes = np.unique(labels)
+
+ cond = labels
+ if "cond" in kwargs:
+ cond = kwargs["cond"]
+
+ cond = self._prepare_cond(cond)
+
+ # synthetic images
+ (
+ image_generator,
+ image_discriminator,
+ ) = suggest_image_generator_discriminator_arch(
+ n_units_latent=self.n_units_latent,
+ n_channels=X.info()["channels"],
+ height=X.info()["height"],
+ width=X.info()["width"],
+ generator_dropout=self.generator_dropout,
+ generator_nonlin=self.generator_nonlin,
+ generator_n_residual_units=self.generator_n_residual_units,
+ discriminator_dropout=self.discriminator_dropout,
+ discriminator_nonlin=self.discriminator_nonlin,
+ discriminator_n_residual_units=self.discriminator_n_residual_units,
+ device=self.device,
+ strategy="predefined",
+ cond=cond,
+ cond_embedding_n_units_hidden=self.n_units_latent,
+ )
+
+ log.debug("Training the image generator")
+ self.image_generator = ImageGAN(
+ image_generator=image_generator,
+ image_discriminator=image_discriminator,
+ n_units_latent=self.n_units_latent,
+ n_channels=X.info()["channels"],
+ # generator
+ generator_n_iter=self.n_iter,
+ generator_lr=self.lr,
+ generator_weight_decay=self.weight_decay,
+ generator_opt_betas=self.opt_betas,
+ generator_extra_penalties=[],
+ # discriminator
+ discriminator_n_iter=self.discriminator_n_iter,
+ discriminator_lr=self.lr,
+ discriminator_weight_decay=self.weight_decay,
+ discriminator_opt_betas=self.opt_betas,
+ # training
+ batch_size=self.batch_size,
+ random_state=self.random_state,
+ clipping_value=self.clipping_value,
+ lambda_gradient_penalty=self.lambda_gradient_penalty,
+ device=self.device,
+ n_iter_min=self.n_iter_min,
+ n_iter_print=self.n_iter_print,
+ plot_progress=self.plot_progress,
+ patience=self.patience,
+ patience_metric=self.patience_metric,
+ )
+ self.image_generator.fit(X.unpack(), cond=cond)
+
+ # synthetic labels
+ self.label_generator: Optional[nn.Module] = None
+
+ if labels is not None: # TODO: handle regression
+ log.debug("Training the labels generator")
+ self.label_generator = suggest_image_classifier_arch(
+ n_channels=X.info()["channels"],
+ height=X.info()["height"],
+ width=X.info()["width"],
+ classes=len(np.unique(labels)),
+ n_residual_units=self.generator_n_residual_units,
+ nonlin=self.generator_nonlin,
+ dropout=self.generator_dropout,
+ last_nonlin="softmax",
+ device=self.device,
+ strategy="predefined",
+ # training
+ lr=self.lr,
+ weight_decay=self.weight_decay,
+ opt_betas=self.opt_betas,
+ n_iter=self.n_iter,
+ batch_size=self.batch_size,
+ n_iter_print=self.n_iter_print,
+ random_state=self.random_state,
+ patience=self.patience,
+ n_iter_min=self.n_iter_min,
+ clipping_value=self.clipping_value,
+ early_stopping=self.early_stopping,
+ )
+ self.label_generator.fit(X.unpack())
+
+ return self
+
+ def _generate(self, count: int, syn_schema: Schema, **kwargs: Any) -> DataLoader:
+ def _sample(count: int) -> TensorDataset:
+ cond: Optional[torch.Tensor] = None
+ if "cond" in kwargs:
+ cond = self._prepare_cond(kwargs["cond"])
+ elif self.classes is not None:
+ cond = np.random.choice(self.classes, count)
+ cond = torch.from_numpy(cond).to(self.device)
+
+ sampled_images = self.image_generator.generate(count, cond=cond)
+ sampled_labels: Optional[torch.Tensor] = None
+ if self.label_generator is not None:
+ sampled_labels = self.label_generator.predict(sampled_images)
+
+ return TensorDataset(images=sampled_images, targets=sampled_labels)
+
+ return self._safe_generate_images(_sample, count, syn_schema)
+
+ def _prepare_cond(self, cond: Any) -> Optional[torch.Tensor]:
+ if cond is None:
+ return None
+
+ cond = np.asarray(cond)
+ if len(cond.shape) == 1:
+ cond = cond.reshape(-1, 1)
+
+ return torch.from_numpy(cond).to(self.device)
+
+
+plugin = ImageCGANPlugin
diff --git a/src/synthcity/plugins/privacy/plugin_adsgan.py b/src/synthcity/plugins/privacy/plugin_adsgan.py
index 2cb423db..a9fbf275 100644
--- a/src/synthcity/plugins/privacy/plugin_adsgan.py
+++ b/src/synthcity/plugins/privacy/plugin_adsgan.py
@@ -61,7 +61,7 @@ class AdsGANPlugin(Plugin):
discriminator_dropout: float
Dropout value for the discriminator. If 0, the dropout is not used.
lr: float
- learning rate for optimizer. step_size equivalent in the JAX version.
+ learning rate for optimizer.
weight_decay: float
l2 (ridge) penalty for the weights.
batch_size: int
@@ -74,6 +74,10 @@ class AdsGANPlugin(Plugin):
The max number of clusters to create for continuous columns when encoding
adjust_inference_sampling: bool
Adjust the marginal probabilities in the synthetic data to closer match the training set. Active only with the ConditionalSampler
+ lambda_gradient_penalty: float = 10
+ Weight for the gradient penalty
+ lambda_identifiability_penalty: float = 0.1
+ Weight for the identifiability penalty, if enabled
# early stopping
n_iter_print: int
Number of iterations after which to print updates and check the validation loss.
diff --git a/src/synthcity/plugins/privacy/plugin_dpgan.py b/src/synthcity/plugins/privacy/plugin_dpgan.py
index e39fcab1..50395449 100644
--- a/src/synthcity/plugins/privacy/plugin_dpgan.py
+++ b/src/synthcity/plugins/privacy/plugin_dpgan.py
@@ -57,7 +57,7 @@ class DPGANPlugin(Plugin):
discriminator_dropout: float
Dropout value for the discriminator. If 0, the dropout is not used.
lr: float
- learning rate for optimizer. step_size equivalent in the JAX version.
+ learning rate for optimizer.
weight_decay: float
l2 (ridge) penalty for the weights.
batch_size: int
diff --git a/src/synthcity/plugins/privacy/plugin_pategan.py b/src/synthcity/plugins/privacy/plugin_pategan.py
index 0535e874..e8178052 100644
--- a/src/synthcity/plugins/privacy/plugin_pategan.py
+++ b/src/synthcity/plugins/privacy/plugin_pategan.py
@@ -364,7 +364,7 @@ class PATEGANPlugin(Plugin):
discriminator_dropout: float
Dropout value for the discriminator. If 0, the dropout is not used.
lr: float
- learning rate for optimizer. step_size equivalent in the JAX version.
+ learning rate for optimizer.
weight_decay: float
l2 (ridge) penalty for the weights.
batch_size: int
diff --git a/src/synthcity/plugins/time_series/plugin_timegan.py b/src/synthcity/plugins/time_series/plugin_timegan.py
index 667ea2ac..a529a912 100644
--- a/src/synthcity/plugins/time_series/plugin_timegan.py
+++ b/src/synthcity/plugins/time_series/plugin_timegan.py
@@ -65,7 +65,7 @@ class TimeGANPlugin(Plugin):
discriminator_dropout: float
Dropout value for the discriminator. If 0, the dropout is not used.
lr: float
- learning rate for optimizer. step_size equivalent in the JAX version.
+ learning rate for optimizer.
weight_decay: float
l2 (ridge) penalty for the weights.
batch_size: int
diff --git a/src/synthcity/plugins/time_series/plugin_timevae.py b/src/synthcity/plugins/time_series/plugin_timevae.py
index 19181efd..3d852d5d 100644
--- a/src/synthcity/plugins/time_series/plugin_timevae.py
+++ b/src/synthcity/plugins/time_series/plugin_timevae.py
@@ -60,7 +60,7 @@ class TimeVAEPlugin(Plugin):
encoder_dropout: float
Dropout value for the encoder. If 0, the dropout is not used.
lr: float
- learning rate for optimizer. step_size equivalent in the JAX version.
+ learning rate for optimizer.
weight_decay: float
l2 (ridge) penalty for the weights.
batch_size: int
diff --git a/src/synthcity/utils/reproducibility.py b/src/synthcity/utils/reproducibility.py
index 9912cdcf..f9069870 100644
--- a/src/synthcity/utils/reproducibility.py
+++ b/src/synthcity/utils/reproducibility.py
@@ -8,9 +8,15 @@
def enable_reproducible_results(random_state: int = 0) -> None:
np.random.seed(random_state)
- torch.manual_seed(random_state)
+ try:
+ torch.manual_seed(random_state)
+ except BaseException:
+ pass
random.seed(random_state)
def clear_cache() -> None:
- torch.cuda.empty_cache()
+ try:
+ torch.cuda.empty_cache()
+ except BaseException:
+ pass
diff --git a/src/synthcity/version.py b/src/synthcity/version.py
index 1d2cac53..0d89953c 100644
--- a/src/synthcity/version.py
+++ b/src/synthcity/version.py
@@ -1,4 +1,4 @@
-__version__ = "0.1.9"
+__version__ = "0.2.1"
MAJOR_VERSION = ".".join(__version__.split(".")[:-1])
MINOR_VERSION = __version__.split(".")[-1]
diff --git a/tests/metrics/test_detection.py b/tests/metrics/test_detection.py
index 30588d83..f4a8341d 100644
--- a/tests/metrics/test_detection.py
+++ b/tests/metrics/test_detection.py
@@ -6,6 +6,7 @@
import pandas as pd
import pytest
from sklearn.datasets import load_iris
+from torchvision import datasets
# synthcity absolute
from synthcity.metrics.eval_detection import (
@@ -15,7 +16,11 @@
SyntheticDetectionXGB,
)
from synthcity.plugins import Plugin, Plugins
-from synthcity.plugins.core.dataloader import GenericDataLoader, TimeSeriesDataLoader
+from synthcity.plugins.core.dataloader import (
+ GenericDataLoader,
+ ImageDataLoader,
+ TimeSeriesDataLoader,
+)
from synthcity.utils.datasets.time_series.google_stocks import GoogleStocksDataloader
@@ -147,3 +152,22 @@ def test_detect_synth_timeseries(test_plugin: Plugin, evaluator_t: Type) -> None
assert evaluator.type() == "detection"
assert evaluator.direction() == "minimize"
+
+
+def test_image_support_detection() -> None:
+ dataset = datasets.MNIST(".", download=True)
+
+ X1 = ImageDataLoader(dataset).sample(100)
+ X2 = ImageDataLoader(dataset).sample(100)
+
+ for evaluator in [
+ SyntheticDetectionGMM,
+ SyntheticDetectionLinear,
+ SyntheticDetectionXGB,
+ SyntheticDetectionMLP,
+ ]:
+ score = evaluator().evaluate(X1, X2)
+ assert isinstance(score, dict)
+ for k in score:
+ assert score[k] >= 0
+ assert not np.isnan(score[k])
diff --git a/tests/metrics/test_performance.py b/tests/metrics/test_performance.py
index 45edeac6..f7bd91db 100644
--- a/tests/metrics/test_performance.py
+++ b/tests/metrics/test_performance.py
@@ -8,6 +8,7 @@
import pytest
from lifelines.datasets import load_rossi
from sklearn.datasets import load_diabetes, load_iris
+from torchvision import datasets
# synthcity absolute
from synthcity.metrics.eval_performance import (
@@ -19,6 +20,7 @@
from synthcity.plugins import Plugin, Plugins
from synthcity.plugins.core.dataloader import (
GenericDataLoader,
+ ImageDataLoader,
SurvivalAnalysisDataLoader,
TimeSeriesDataLoader,
TimeSeriesSurvivalDataLoader,
@@ -462,3 +464,19 @@ def test_evaluate_performance_time_series_survival(
def_score = evaluator.evaluate_default(data, data_gen)
assert def_score == good_score["syn_id.c_index"] - good_score["syn_id.brier_score"]
+
+
+def test_image_support_perf() -> None:
+ dataset = datasets.MNIST(".", download=True)
+
+ X1 = ImageDataLoader(dataset).sample(100)
+ X2 = ImageDataLoader(dataset).sample(100)
+
+ for evaluator in [
+ PerformanceEvaluatorMLP,
+ ]:
+ score = evaluator().evaluate(X1, X2)
+ assert isinstance(score, dict)
+ for k in score:
+ assert score[k] >= 0
+ assert not np.isnan(score[k])
diff --git a/tests/metrics/test_privacy.py b/tests/metrics/test_privacy.py
index e6f9449f..30391dc0 100644
--- a/tests/metrics/test_privacy.py
+++ b/tests/metrics/test_privacy.py
@@ -2,8 +2,10 @@
from typing import Type
# third party
+import numpy as np
import pytest
from sklearn.datasets import load_diabetes
+from torchvision import datasets
# synthcity absolute
from synthcity.metrics.eval_privacy import (
@@ -14,7 +16,7 @@
lDiversityDistinct,
)
from synthcity.plugins import Plugin, Plugins
-from synthcity.plugins.core.dataloader import GenericDataLoader
+from synthcity.plugins.core.dataloader import GenericDataLoader, ImageDataLoader
@pytest.mark.parametrize(
@@ -49,3 +51,19 @@ def test_evaluator(evaluator_t: Type, test_plugin: Plugin) -> None:
def_score = evaluator.evaluate_default(Xloader, X_gen)
assert isinstance(def_score, (float, int))
+
+
+def test_image_support() -> None:
+ dataset = datasets.MNIST(".", download=True)
+
+ X1 = ImageDataLoader(dataset).sample(100)
+ X2 = ImageDataLoader(dataset).sample(100)
+
+ for evaluator in [
+ IdentifiabilityScore,
+ ]:
+ score = evaluator().evaluate(X1, X2)
+ assert isinstance(score, dict)
+ for k in score:
+ assert score[k] >= 0
+ assert not np.isnan(score[k])
diff --git a/tests/metrics/test_sanity.py b/tests/metrics/test_sanity.py
index a2a225b7..b75c0ca6 100644
--- a/tests/metrics/test_sanity.py
+++ b/tests/metrics/test_sanity.py
@@ -6,6 +6,7 @@
import pandas as pd
import pytest
from sklearn.datasets import load_iris
+from torchvision import datasets
# synthcity absolute
from synthcity.metrics.eval_sanity import (
@@ -16,7 +17,11 @@
NearestSyntheticNeighborDistance,
)
from synthcity.plugins import Plugin, Plugins
-from synthcity.plugins.core.dataloader import DataLoader, GenericDataLoader
+from synthcity.plugins.core.dataloader import (
+ DataLoader,
+ GenericDataLoader,
+ ImageDataLoader,
+)
def _eval_plugin(cbk: Callable, X: DataLoader, X_syn: DataLoader) -> Tuple:
@@ -187,3 +192,23 @@ def test_evaluate_distant_values(test_plugin: Plugin) -> None:
def_score = evaluator.evaluate_default(Xloader, X_gen)
assert isinstance(def_score, float)
+
+
+def test_image_support() -> None:
+ dataset = datasets.MNIST(".", download=True)
+
+ X1 = ImageDataLoader(dataset).sample(100)
+ X2 = ImageDataLoader(dataset).sample(100)
+
+ for evaluator in [
+ CloseValuesProbability,
+ CommonRowsProportion,
+ DataMismatchScore,
+ DistantValuesProbability,
+ NearestSyntheticNeighborDistance,
+ ]:
+ score = evaluator().evaluate(X1, X2)
+ assert isinstance(score, dict)
+ for k in score:
+ assert score[k] >= 0
+ assert not np.isnan(score[k])
diff --git a/tests/metrics/test_statistical.py b/tests/metrics/test_statistical.py
index 34fba695..d86489a2 100644
--- a/tests/metrics/test_statistical.py
+++ b/tests/metrics/test_statistical.py
@@ -7,11 +7,13 @@
import pytest
from lifelines.datasets import load_rossi
from sklearn.datasets import load_iris
+from torchvision import datasets
# synthcity absolute
from synthcity.metrics.eval_statistical import (
AlphaPrecision,
ChiSquaredTest,
+ FrechetInceptionDistance,
InverseKLDivergence,
JensenShannonDistance,
KolmogorovSmirnovTest,
@@ -24,6 +26,7 @@
from synthcity.plugins.core.dataloader import (
DataLoader,
GenericDataLoader,
+ ImageDataLoader,
SurvivalAnalysisDataLoader,
create_from_info,
)
@@ -263,3 +266,28 @@ def test_evaluate_survival_km_distance(test_plugin: Plugin) -> None:
assert SurvivalKMDistance.name() == "survival_km_distance"
assert SurvivalKMDistance.type() == "stats"
assert SurvivalKMDistance.direction() == "minimize"
+
+
+def test_image_support() -> None:
+ dataset = datasets.MNIST(".", download=True)
+
+ X1 = ImageDataLoader(dataset).sample(100)
+ X2 = ImageDataLoader(dataset).sample(100)
+
+ for evaluator in [
+ AlphaPrecision,
+ ChiSquaredTest,
+ InverseKLDivergence,
+ JensenShannonDistance,
+ KolmogorovSmirnovTest,
+ MaximumMeanDiscrepancy,
+ PRDCScore,
+ WassersteinDistance,
+ FrechetInceptionDistance,
+ ]:
+ score = evaluator().evaluate(X1, X2)
+ print(score)
+ assert isinstance(score, dict), evaluator
+ for k in score:
+ assert score[k] >= 0, evaluator
+ assert not np.isnan(score[k]), evaluator
diff --git a/tests/plugins/core/models/test_convnet.py b/tests/plugins/core/models/test_convnet.py
new file mode 100644
index 00000000..578c1c82
--- /dev/null
+++ b/tests/plugins/core/models/test_convnet.py
@@ -0,0 +1,84 @@
+# third party
+import numpy as np
+import pytest
+import torch
+from torch.utils.data import Subset
+from torchvision import datasets, transforms
+
+# synthcity absolute
+from synthcity.plugins.core.models.convnet import (
+ map_nonlin,
+ suggest_image_classifier_arch,
+ suggest_image_generator_discriminator_arch,
+)
+
+
+@pytest.mark.parametrize("nonlin", ["relu", "elu", "prelu", "leaky_relu"])
+def test_get_nonlin(nonlin: str) -> None:
+ assert map_nonlin(nonlin) is not None
+
+
+@pytest.mark.parametrize("n_channels", [1, 3])
+@pytest.mark.parametrize("height", [32, 64, 128])
+def test_suggest_gan(n_channels: int, height: int) -> None:
+ n_units_latent = 100
+ gen, disc = suggest_image_generator_discriminator_arch(
+ n_units_latent=n_units_latent,
+ n_channels=n_channels,
+ height=height,
+ width=height,
+ )
+
+ dummy_noise = torch.rand((10, n_units_latent, n_channels, 1))
+ gen(dummy_noise)
+
+ dummy_in = torch.rand((10, n_channels, height, height))
+ disc(dummy_in)
+
+
+@pytest.mark.parametrize("n_channels", [1, 3])
+@pytest.mark.parametrize("height", [32, 64, 128])
+def test_suggest_clf(n_channels: int, height: int) -> None:
+ classes = 13
+ clf = suggest_image_classifier_arch(
+ n_channels=n_channels,
+ height=height,
+ width=height,
+ classes=classes,
+ )
+
+ dummy_input = torch.rand((10, n_channels, height, height))
+ out = clf(dummy_input)
+
+ assert out.shape == (10, classes)
+
+
+def test_train_clf() -> None:
+ IMG_SIZE = 32
+ data_transform = transforms.Compose(
+ [
+ transforms.Resize(IMG_SIZE),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=(0.5,), std=(0.5,)),
+ ]
+ )
+ dataset = datasets.MNIST(".", download=True, transform=data_transform)
+ dataset = Subset(dataset, np.arange(len(dataset))[:100])
+
+ classes = 10
+
+ clf = suggest_image_classifier_arch(
+ n_channels=1,
+ height=IMG_SIZE,
+ width=IMG_SIZE,
+ classes=classes,
+ n_iter=100,
+ n_iter_print=10,
+ batch_size=40,
+ )
+
+ clf.fit(dataset)
+
+ test_X, test_y = next(iter(dataset))
+
+ print(clf.predict(test_X), test_y)
diff --git a/tests/plugins/core/models/test_image_gan.py b/tests/plugins/core/models/test_image_gan.py
new file mode 100644
index 00000000..0c735c22
--- /dev/null
+++ b/tests/plugins/core/models/test_image_gan.py
@@ -0,0 +1,169 @@
+# third party
+import numpy as np
+import pytest
+import torch
+from torch.utils.data import Subset
+from torchvision import datasets, transforms
+
+# synthcity absolute
+from synthcity.plugins.core.dataloader import FlexibleDataset
+from synthcity.plugins.core.models.convnet import (
+ ConditionalDiscriminator,
+ ConditionalGenerator,
+ suggest_image_generator_discriminator_arch,
+)
+from synthcity.plugins.core.models.image_gan import ImageGAN
+from synthcity.utils.constants import DEVICE
+
+IMG_SIZE = 32
+data_transform = transforms.Compose(
+ [
+ transforms.Resize(IMG_SIZE),
+ transforms.ToTensor(),
+ transforms.Normalize(mean=(0.5,), std=(0.5,)),
+ ]
+)
+# Load MNIST dataset as tensors
+batch_size = 128
+dataset = datasets.MNIST(".", download=True, transform=data_transform)
+dataset = Subset(dataset, np.arange(len(dataset))[:100])
+dataset = FlexibleDataset(dataset)
+
+
+def test_network_config() -> None:
+ noise_dim = 123
+ (
+ image_generator,
+ image_discriminator,
+ ) = suggest_image_generator_discriminator_arch(
+ n_units_latent=noise_dim,
+ n_channels=1,
+ height=IMG_SIZE,
+ width=IMG_SIZE,
+ device=DEVICE,
+ generator_dropout=0.2,
+ )
+
+ net = ImageGAN(
+ image_generator=image_generator,
+ image_discriminator=image_discriminator,
+ n_units_latent=noise_dim,
+ n_channels=1,
+ # Generator
+ generator_n_iter=1001,
+ generator_lr=1e-3,
+ generator_weight_decay=1e-3,
+ # Discriminator
+ discriminator_n_iter=1002,
+ discriminator_lr=1e-3,
+ discriminator_weight_decay=1e-3,
+ # Training
+ batch_size=64,
+ n_iter_print=100,
+ random_state=77,
+ n_iter_min=100,
+ clipping_value=1,
+ lambda_gradient_penalty=2,
+ lambda_identifiability_penalty=3,
+ )
+
+ assert isinstance(net.generator, ConditionalGenerator)
+ assert isinstance(net.discriminator, ConditionalDiscriminator)
+ assert net.batch_size == 64
+ assert net.generator_n_iter == 1001
+ assert net.discriminator_n_iter == 1002
+ assert net.random_state == 77
+ assert net.lambda_gradient_penalty == 2
+ assert net.lambda_identifiability_penalty == 3
+
+
+@pytest.mark.parametrize("n_iter", [10])
+@pytest.mark.parametrize("lr", [1e-3, 3e-4])
+def test_basic_network(
+ n_iter: int,
+ lr: float,
+) -> None:
+ noise_dim = 123
+ (
+ image_generator,
+ image_discriminator,
+ ) = suggest_image_generator_discriminator_arch(
+ n_units_latent=noise_dim,
+ n_channels=1,
+ height=IMG_SIZE,
+ width=IMG_SIZE,
+ device=DEVICE,
+ generator_dropout=0.2,
+ )
+ net = ImageGAN(
+ image_generator=image_generator,
+ image_discriminator=image_discriminator,
+ n_units_latent=noise_dim,
+ n_channels=1,
+ generator_n_iter=n_iter,
+ discriminator_n_iter=n_iter,
+ generator_lr=lr,
+ discriminator_lr=lr,
+ )
+
+ assert net.generator_n_iter == n_iter
+ assert net.discriminator_n_iter == n_iter
+
+
+@pytest.mark.parametrize("generator_extra_penalties", [[], ["identifiability_penalty"]])
+def test_image_gan_generation(generator_extra_penalties: list) -> None:
+ noise_dim = 123
+ (
+ image_generator,
+ image_discriminator,
+ ) = suggest_image_generator_discriminator_arch(
+ n_units_latent=noise_dim,
+ n_channels=1,
+ height=IMG_SIZE,
+ width=IMG_SIZE,
+ device=DEVICE,
+ generator_dropout=0.2,
+ )
+ model = ImageGAN(
+ image_generator=image_generator,
+ image_discriminator=image_discriminator,
+ n_units_latent=noise_dim,
+ n_channels=1,
+ generator_n_iter=10,
+ generator_extra_penalties=generator_extra_penalties,
+ )
+ model.fit(dataset)
+
+ generated = model.generate(10)
+
+ assert generated.shape == (10, 1, IMG_SIZE, IMG_SIZE)
+
+
+def test_image_gan_conditional_generation() -> None:
+ noise_dim = 123
+ cond = dataset.labels()
+ (
+ image_generator,
+ image_discriminator,
+ ) = suggest_image_generator_discriminator_arch(
+ n_units_latent=noise_dim,
+ n_channels=1,
+ height=IMG_SIZE,
+ width=IMG_SIZE,
+ device=DEVICE,
+ generator_dropout=0.2,
+ )
+
+ model = ImageGAN(
+ image_generator=image_generator,
+ image_discriminator=image_discriminator,
+ n_units_latent=noise_dim,
+ n_channels=1,
+ generator_n_iter=10,
+ )
+ model.fit(dataset, cond=cond)
+
+ cnt = 10
+ generated = model.generate(cnt, cond=torch.ones(cnt).to(DEVICE))
+
+ assert generated.shape == (10, 1, IMG_SIZE, IMG_SIZE)
diff --git a/tests/plugins/core/test_dataloader.py b/tests/plugins/core/test_dataloader.py
index 0e0db277..658287f2 100644
--- a/tests/plugins/core/test_dataloader.py
+++ b/tests/plugins/core/test_dataloader.py
@@ -6,17 +6,21 @@
import numpy as np
import pandas as pd
import pytest
+import torch
from lifelines.datasets import load_rossi
from sklearn.datasets import load_breast_cancer
+from torchvision import datasets, transforms
# synthcity absolute
from synthcity.plugins.core.dataloader import (
GenericDataLoader,
+ ImageDataLoader,
SurvivalAnalysisDataLoader,
TimeSeriesDataLoader,
TimeSeriesSurvivalDataLoader,
create_from_info,
)
+from synthcity.plugins.core.dataset import FlexibleDataset, TensorDataset
from synthcity.utils.datasets.time_series.google_stocks import GoogleStocksDataloader
from synthcity.utils.datasets.time_series.pbc import PBCDataloader
from synthcity.utils.datasets.time_series.sine import SineDataloader
@@ -29,6 +33,7 @@ def test_generic_dataloader_sanity() -> None:
loader = GenericDataLoader(X, target_column="target")
+ assert loader.is_tabular()
assert loader.raw().shape == X.shape
assert loader.type() == "generic"
assert loader.shape == X.shape
@@ -160,6 +165,7 @@ def test_survival_dataloader_sanity() -> None:
assert (loader.dataframe().values == df.values).all()
assert loader.raw().shape == df.shape
+ assert loader.is_tabular()
assert loader.type() == "survival_analysis"
assert loader.shape == df.shape
assert sorted(list(loader.dataframe().columns)) == sorted(list(df.columns))
@@ -299,6 +305,8 @@ def test_time_series_dataloader_sanity(source: Any) -> None:
)
assert len(loader.raw()) == 5
+ assert loader.is_tabular()
+
feat_cnt = temporal_data[0].shape[1] + 2 # id, time_id
if static_data is not None:
feat_cnt += static_data.shape[1]
@@ -625,3 +633,120 @@ def test_time_series_survival_pack_unpack_padding(as_numpy: bool) -> None:
for idx, item in enumerate(unp_temporal):
assert len(unp_temporal[idx]) == max_window_len
assert len(unp_observation_times[idx]) == max_window_len
+
+
+@pytest.mark.parametrize("height", [55, 64])
+@pytest.mark.parametrize("width", [32, 22])
+def test_image_dataloader_sanity(height: int, width: int) -> None:
+ dataset = datasets.MNIST(".", download=True)
+
+ loader = ImageDataLoader(
+ data=dataset,
+ train_size=0.8,
+ height=height,
+ width=width,
+ )
+ channels = 1
+
+ assert loader.shape == (len(dataset), channels, height, width)
+ assert loader.info()["height"] == height
+ assert loader.info()["width"] == width
+ assert loader.info()["channels"] == channels
+ assert loader.info()["len"] == len(dataset)
+ assert not loader.is_tabular()
+
+ assert isinstance(loader.unpack(), torch.utils.data.Dataset)
+
+ assert loader.sample(5).shape == (5, channels, height, width)
+
+ assert loader[0].shape == (channels, height, width)
+
+ assert loader.hash() != ""
+
+ assert loader.train().shape == (0.8 * len(dataset), channels, height, width)
+ assert loader.test().shape == (0.2 * len(dataset), channels, height, width)
+
+ x_np = loader.numpy()
+ assert x_np.shape == (len(dataset), channels, height, width)
+ assert isinstance(x_np, np.ndarray)
+
+ df = loader.dataframe()
+ assert df.shape == (len(dataset), channels * height * width)
+ assert isinstance(df, pd.DataFrame)
+
+ assert loader.unpack().labels().shape == (len(loader),)
+
+
+def test_image_dataloader_create_from_info() -> None:
+ dataset = datasets.MNIST(".", download=True)
+
+ loader = ImageDataLoader(
+ data=dataset,
+ train_size=0.8,
+ height=32,
+ )
+
+ data = loader.unpack()
+
+ reloaded = create_from_info(data, loader.info())
+
+ for key in loader.info():
+ assert reloaded.info()[key] == loader.info()[key]
+
+
+def test_image_dataloader_create_from_tensor() -> None:
+ X = torch.randn((100, 10, 10))
+ y = torch.randn((100,))
+
+ loader = ImageDataLoader(
+ data=(X, y),
+ train_size=0.8,
+ height=32,
+ )
+
+ assert len(loader) == len(X)
+ assert loader.shape == (100, 1, 32, 32)
+
+
+def test_image_datasets() -> None:
+ size = 100
+ X = torch.rand(size, 10, 10)
+ y = torch.rand(size)
+
+ gen_dataset = TensorDataset(images=X, targets=y)
+ assert (gen_dataset[0][0] == X[0]).all()
+ assert (gen_dataset[0][1] == y[0]).all()
+
+ img_transform = transforms.Compose(
+ [
+ transforms.ToPILImage(),
+ transforms.Resize((20, 20)),
+ transforms.ToTensor(),
+ ]
+ )
+
+ transform_dataset = FlexibleDataset(gen_dataset, transform=img_transform)
+ assert transform_dataset.shape() == (size, 1, 20, 20)
+ assert transform_dataset[0][0].shape == (1, 20, 20)
+ assert transform_dataset[0][1] == y[0]
+
+ gen_dataset = TensorDataset(images=X, targets=None)
+ assert (gen_dataset[0][0] == X[0]).all()
+ assert gen_dataset[0][1] is None
+
+ transform_dataset = FlexibleDataset(gen_dataset, transform=img_transform)
+ assert transform_dataset.shape() == (size, 1, 20, 20)
+ assert transform_dataset[0][0].shape == (1, 20, 20)
+ assert transform_dataset[0][1] is None
+
+ transform_dataset = transform_dataset.filter_indices([0, 1, 2])
+ assert len(transform_dataset) == 3
+ assert (transform_dataset.indices == [0, 1, 2]).all()
+
+ transform_dataset = transform_dataset.filter_indices([1, 2])
+ assert len(transform_dataset) == 2
+ assert (transform_dataset.indices == [1, 2]).all()
+
+ transform_dataset = transform_dataset.filter_indices([0])
+ assert len(transform_dataset) == 1
+ assert (transform_dataset.indices == [1]).all()
diff --git a/tests/plugins/images/img_helpers.py b/tests/plugins/images/img_helpers.py
new file mode 100644
index 00000000..21fc5625
--- /dev/null
+++ b/tests/plugins/images/img_helpers.py
@@ -0,0 +1,20 @@
+# stdlib
+from typing import Dict, List, Type
+
+# synthcity absolute
+from synthcity.plugins import Plugin, Plugins
+from synthcity.utils.serialization import load, save
+
+
+def generate_fixtures(name: str, plugin: Type, plugin_args: Dict = {}) -> List:
+ def from_api() -> Plugin:
+ return Plugins().get(name, **plugin_args)
+
+ def from_module() -> Plugin:
+ return plugin(**plugin_args)
+
+ def from_serde() -> Plugin:
+ buff = save(plugin(**plugin_args))
+ return load(buff)
+
+ return [from_api(), from_module(), from_serde()]
diff --git a/tests/plugins/images/test_image_adsgan.py b/tests/plugins/images/test_image_adsgan.py
new file mode 100644
index 00000000..a7102258
--- /dev/null
+++ b/tests/plugins/images/test_image_adsgan.py
@@ -0,0 +1,90 @@
+# third party
+import numpy as np
+import pytest
+from img_helpers import generate_fixtures
+from torchvision import datasets
+
+# synthcity absolute
+from synthcity.plugins import Plugin
+from synthcity.plugins.core.dataloader import ImageDataLoader
+from synthcity.plugins.images.plugin_image_adsgan import plugin
+
+plugin_name = "image_adsgan"
+
+dataset = datasets.MNIST(".", download=True)
+
+
+@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
+def test_plugin_sanity(test_plugin: Plugin) -> None:
+ assert test_plugin is not None
+
+
+@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
+def test_plugin_name(test_plugin: Plugin) -> None:
+ assert test_plugin.name() == plugin_name
+
+
+@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
+def test_plugin_type(test_plugin: Plugin) -> None:
+ assert test_plugin.type() == "images"
+
+
+@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
+def test_plugin_hyperparams(test_plugin: Plugin) -> None:
+ assert len(test_plugin.hyperparameter_space()) == 6
+
+
+def test_plugin_fit() -> None:
+ test_plugin = plugin(n_iter=5)
+
+ X = ImageDataLoader(dataset).sample(100)
+
+ test_plugin.fit(X)
+
+
+def test_plugin_generate() -> None:
+ test_plugin = plugin(n_iter=10, n_units_latent=13)
+
+ X = ImageDataLoader(dataset).sample(100)
+
+ test_plugin.fit(X)
+
+ X_gen = test_plugin.generate()
+ assert len(X_gen) == len(X)
+ assert X_gen.shape == X.shape
+
+ X_gen = test_plugin.generate(50)
+ assert len(X_gen) == 50
+
+
+def test_plugin_generate_with_conditional() -> None:
+ test_plugin = plugin(n_iter=10, n_units_latent=13)
+
+ X = ImageDataLoader(dataset).sample(100)
+ cond = X.unpack().labels()
+
+ test_plugin.fit(X, cond=cond)
+
+ cnt = 50
+ X_gen = test_plugin.generate(cnt, cond=np.ones(cnt))
+ assert len(X_gen) == 50
+
+
+def test_plugin_generate_with_stop_conditional() -> None:
+ test_plugin = plugin(n_iter=10, n_units_latent=13, n_iter_print=2)
+
+ X = ImageDataLoader(dataset).sample(100)
+ cond = X.unpack().labels()
+
+ test_plugin.fit(X, cond=cond)
+
+ cnt = 50
+ X_gen = test_plugin.generate(cnt, cond=np.ones(cnt))
+ assert len(X_gen) == 50
+
+
+def test_sample_hyperparams() -> None:
+ for i in range(100):
+ args = plugin.sample_hyperparameters()
+
+ assert plugin(**args) is not None
diff --git a/tests/plugins/images/test_image_cgan.py b/tests/plugins/images/test_image_cgan.py
new file mode 100644
index 00000000..c9cccec9
--- /dev/null
+++ b/tests/plugins/images/test_image_cgan.py
@@ -0,0 +1,91 @@
+# third party
+import numpy as np
+import pytest
+from img_helpers import generate_fixtures
+from torchvision import datasets
+
+# synthcity absolute
+from synthcity.plugins import Plugin
+from synthcity.plugins.core.dataloader import ImageDataLoader
+from synthcity.plugins.images.plugin_image_cgan import plugin
+
+plugin_name = "image_cgan"
+
+dataset = datasets.MNIST(".", download=True)
+
+
+@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
+def test_plugin_sanity(test_plugin: Plugin) -> None:
+ assert test_plugin is not None
+
+
+@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
+def test_plugin_name(test_plugin: Plugin) -> None:
+ assert test_plugin.name() == plugin_name
+
+
+@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
+def test_plugin_type(test_plugin: Plugin) -> None:
+ assert test_plugin.type() == "images"
+
+
+@pytest.mark.parametrize("test_plugin", generate_fixtures(plugin_name, plugin))
+def test_plugin_hyperparams(test_plugin: Plugin) -> None:
+ assert len(test_plugin.hyperparameter_space()) == 6
+
+
+@pytest.mark.parametrize("height", [32, 64, 128])
+def test_plugin_fit(height: int) -> None:
+ test_plugin = plugin(n_iter=5)
+
+ X = ImageDataLoader(dataset, height=height).sample(100)
+
+ test_plugin.fit(X)
+
+
+def test_plugin_generate() -> None:
+ test_plugin = plugin(n_iter=10, n_units_latent=13)
+
+ X = ImageDataLoader(dataset).sample(100)
+
+ test_plugin.fit(X)
+
+ X_gen = test_plugin.generate()
+ assert len(X_gen) == len(X)
+ assert X_gen.shape == X.shape
+
+ X_gen = test_plugin.generate(50)
+ assert len(X_gen) == 50
+
+
+def test_plugin_generate_with_conditional() -> None:
+ test_plugin = plugin(n_iter=10, n_units_latent=13)
+
+ X = ImageDataLoader(dataset).sample(100)
+ cond = X.unpack().labels()
+
+ test_plugin.fit(X, cond=cond)
+
+ cnt = 50
+ X_gen = test_plugin.generate(cnt, cond=np.ones(cnt))
+ assert len(X_gen) == 50
+
+
+def test_plugin_generate_with_stop_conditional() -> None:
+ test_plugin = plugin(n_iter=10, n_units_latent=13, n_iter_print=2)
+
+ X = ImageDataLoader(dataset).sample(100)
+ cond = X.unpack().labels()
+
+ test_plugin.fit(X, cond=cond)
+
+ cnt = 50
+ X_gen = test_plugin.generate(cnt, cond=np.ones(cnt))
+ assert len(X_gen) == 50
+
+
+def test_sample_hyperparams() -> None:
+ for i in range(100):
+ args = plugin.sample_hyperparameters()
+
+ assert plugin(**args) is not None
diff --git a/tutorials/plugins/images/plugin_image_adsgan.ipynb b/tutorials/plugins/images/plugin_image_adsgan.ipynb
new file mode 100644
index 00000000..d6cb3204
--- /dev/null
+++ b/tutorials/plugins/images/plugin_image_adsgan.ipynb
@@ -0,0 +1,179 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "d30e7633",
+ "metadata": {},
+ "source": [
+ "# Image AdsGAN Example"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "fda150d4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# stdlib\n",
+ "import warnings\n",
+ "\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "\n",
+ "# synthcity absolute\n",
+ "from synthcity.plugins import Plugins\n",
+ "\n",
+ "eval_plugin = \"image_adsgan\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "28dcf470",
+ "metadata": {},
+ "source": [
+ "### Load dataset\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "82bb1afd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# third party\n",
+ "from torchvision import datasets, transforms\n",
+ "\n",
+ "# synthcity absolute\n",
+ "from synthcity.plugins.core.dataloader import ImageDataLoader\n",
+ "\n",
+ "IMG_SIZE = 32\n",
+ "\n",
+ "dataset = datasets.MNIST(\".\", download=True)\n",
+ "loader = ImageDataLoader(\n",
+ " dataset,\n",
+ " height=IMG_SIZE,\n",
+ ").sample(1000)\n",
+ "\n",
+ "loader.shape"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e597a396",
+ "metadata": {},
+ "source": [
+ "### Train the generator\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8d846eee",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# synthcity absolute\n",
+ "from synthcity.plugins import Plugins\n",
+ "\n",
+ "syn_model = Plugins().get(eval_plugin, batch_size=100, plot_progress=True, n_iter=100)\n",
+ "\n",
+ "syn_model.fit(loader)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4cc54562",
+ "metadata": {},
+ "source": [
+ "### Generate new samples\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "821cff43",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# third party\n",
+ "import torch\n",
+ "\n",
+ "# synthcity absolute\n",
+ "from synthcity.plugins.core.models.image_gan import display_imgs\n",
+ "\n",
+ "syn_samples, syn_labels = syn_model.generate(count=5).unpack().tensors()\n",
+ "\n",
+ "display_imgs(syn_samples)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2be8a8da",
+ "metadata": {},
+ "source": [
+ "### Benchmarks"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e7e19494",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# synthcity absolute\n",
+ "from synthcity.benchmark import Benchmarks\n",
+ "\n",
+ "score = Benchmarks.evaluate(\n",
+ " [\n",
+ " (eval_plugin, eval_plugin, {\"n_iter\": 50})\n",
+ " ], # (testname, plugin, plugin_args) REPLACE {\"n_iter\" : 50} with {} for better performance\n",
+ " loader,\n",
+ " repeats=2,\n",
+ " metrics={\"detection\": [\"detection_mlp\"]}, # DELETE THIS LINE FOR ALL METRICS\n",
+ " task_type=\"classification\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b25c1964",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "Benchmarks.print(score)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "81a0507e",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.16"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/tutorials/plugins/images/plugin_image_cgan.ipynb b/tutorials/plugins/images/plugin_image_cgan.ipynb
new file mode 100644
index 00000000..65984ea7
--- /dev/null
+++ b/tutorials/plugins/images/plugin_image_cgan.ipynb
@@ -0,0 +1,179 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "d30e7633",
+ "metadata": {},
+ "source": [
+ "# Conditional Image GAN"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "fda150d4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# stdlib\n",
+ "import warnings\n",
+ "\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "\n",
+ "# synthcity absolute\n",
+ "from synthcity.plugins import Plugins\n",
+ "\n",
+ "eval_plugin = \"image_cgan\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "28dcf470",
+ "metadata": {},
+ "source": [
+ "### Load dataset\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "82bb1afd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# third party\n",
+ "from torchvision import datasets, transforms\n",
+ "\n",
+ "# synthcity absolute\n",
+ "from synthcity.plugins.core.dataloader import ImageDataLoader\n",
+ "\n",
+ "IMG_SIZE = 32\n",
+ "\n",
+ "dataset = datasets.MNIST(\".\", download=True)\n",
+ "loader = ImageDataLoader(\n",
+ " dataset,\n",
+ " height=IMG_SIZE,\n",
+ ").sample(1000)\n",
+ "\n",
+ "loader.shape"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e597a396",
+ "metadata": {},
+ "source": [
+ "### Train the generator\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8d846eee",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# synthcity absolute\n",
+ "from synthcity.plugins import Plugins\n",
+ "\n",
+ "syn_model = Plugins().get(eval_plugin, batch_size=100, plot_progress=True, n_iter=100)\n",
+ "\n",
+ "syn_model.fit(loader)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4cc54562",
+ "metadata": {},
+ "source": [
+ "### Generate new samples\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "821cff43",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# third party\n",
+ "import torch\n",
+ "\n",
+ "# synthcity absolute\n",
+ "from synthcity.plugins.core.models.image_gan import display_imgs\n",
+ "\n",
+ "syn_samples, syn_labels = syn_model.generate(count=5).unpack().tensors()\n",
+ "\n",
+ "display_imgs(syn_samples)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "2be8a8da",
+ "metadata": {},
+ "source": [
+ "### Benchmarks"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e7e19494",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# synthcity absolute\n",
+ "from synthcity.benchmark import Benchmarks\n",
+ "\n",
+ "score = Benchmarks.evaluate(\n",
+ " [\n",
+ " (eval_plugin, eval_plugin, {\"n_iter\": 50})\n",
+ " ], # (testname, plugin, plugin_args) REPLACE {\"n_iter\" : 50} with {} for better performance\n",
+ " loader,\n",
+ " repeats=2,\n",
+ " metrics={\"detection\": [\"detection_mlp\"]}, # DELETE THIS LINE FOR ALL METRICS\n",
+ " task_type=\"classification\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b25c1964",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "Benchmarks.print(score)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "81a0507e",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.16"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/tutorials/tutorial6_time_series_data_preparation.ipynb b/tutorials/tutorial6_time_series_data_preparation.ipynb
index 5caf311b..fdefbbba 100644
--- a/tutorials/tutorial6_time_series_data_preparation.ipynb
+++ b/tutorials/tutorial6_time_series_data_preparation.ipynb
@@ -35,7 +35,9 @@
{
"cell_type": "markdown",
"metadata": {},
- "source": ["We simulate our data as two dataframes: a dataframe containing data (eg, age, sex) and a dataframe containing temporal data (eg, body temperature over time)"]
+ "source": [
+ "We simulate our data as two dataframes: a dataframe containing data (eg, age, sex) and a dataframe containing temporal data (eg, body temperature over time)"
+ ]
},
{
"cell_type": "code",
@@ -43,10 +45,13 @@
"metadata": {},
"outputs": [],
"source": [
+ "# stdlib\n",
+ "import datetime\n",
+ "import uuid\n",
+ "\n",
+ "# third party\n",
"# import libraries for generating simulated data\n",
"import numpy as np\n",
- "import uuid\n",
- "import datetime\n",
"import pandas as pd\n",
"\n",
"# set the number of individuals and observations per individual you want to generate\n",
@@ -55,21 +60,29 @@
"\n",
"# generate static data\n",
"ids = [uuid.uuid4().hex[:6].upper() for i in range(num_subj)]\n",
- "static_data = pd.DataFrame({'id': ids,\n",
- " 'var_a': np.random.randint(2, size=(num_subj)),\n",
- " 'var_b': np.random.normal(loc=2, scale=0.5, size=(num_subj)),\n",
- " 'outcome': np.random.binomial(1, 0.7, size=(num_subj))})\n",
+ "static_data = pd.DataFrame(\n",
+ " {\n",
+ " \"id\": ids,\n",
+ " \"var_a\": np.random.randint(2, size=(num_subj)),\n",
+ " \"var_b\": np.random.normal(loc=2, scale=0.5, size=(num_subj)),\n",
+ " \"outcome\": np.random.binomial(1, 0.7, size=(num_subj)),\n",
+ " }\n",
+ ")\n",
"\n",
"# generate temporal data\n",
"temp_len = num_obs * len(ids)\n",
"temp_ids = ids * num_obs\n",
"timepoints = (pd.date_range(datetime.date.today(), periods=num_obs).tolist()) * num_subj\n",
"\n",
- "temporal_data = pd.DataFrame({'id': temp_ids,\n",
- " 'temp_a': np.random.normal(loc=0, scale=0.2, size=(temp_len)),\n",
- " 'temp_b': np.random.normal(loc=5, scale=1, size=(temp_len)),\n",
- " 'temp_c': np.random.binomial(1, 0.5, size=(temp_len)),\n",
- " 'timepoint': timepoints})"
+ "temporal_data = pd.DataFrame(\n",
+ " {\n",
+ " \"id\": temp_ids,\n",
+ " \"temp_a\": np.random.normal(loc=0, scale=0.2, size=(temp_len)),\n",
+ " \"temp_b\": np.random.normal(loc=5, scale=1, size=(temp_len)),\n",
+ " \"temp_c\": np.random.binomial(1, 0.5, size=(temp_len)),\n",
+ " \"timepoint\": timepoints,\n",
+ " }\n",
+ ")"
]
},
{
@@ -95,16 +108,16 @@
"outputs": [],
"source": [
"# rearrange static data\n",
- "outcome_data = static_data[['outcome']]\n",
- "static_data = static_data.drop(columns=['outcome'])\n",
+ "outcome_data = static_data[[\"outcome\"]]\n",
+ "static_data = static_data.drop(columns=[\"outcome\"])\n",
"\n",
"# rearrange temporal data\n",
"observation_data, temporal_dataframes = ([] for i in range(2))\n",
- "for id in static_data['id'].unique():\n",
- " temp_df = temporal_data[temporal_data['id'] == id]\n",
- " observations = temp_df['timepoint'].tolist()\n",
- " temp_df.set_index('timepoint', inplace=True)\n",
- " temp_df = temp_df.drop(columns=['id'])\n",
+ "for id in static_data[\"id\"].unique():\n",
+ " temp_df = temporal_data[temporal_data[\"id\"] == id]\n",
+ " observations = temp_df[\"timepoint\"].tolist()\n",
+ " temp_df.set_index(\"timepoint\", inplace=True)\n",
+ " temp_df = temp_df.drop(columns=[\"id\"])\n",
" # add each to list\n",
" observation_data.append(observations)\n",
" temporal_dataframes.append(temp_df)\n",
@@ -114,7 +127,8 @@
" temporal_data=temporal_dataframes,\n",
" observation_times=observation_data,\n",
" static_data=static_data,\n",
- " outcome=outcome_data)"
+ " outcome=outcome_data,\n",
+ ")"
]
},
{
@@ -132,7 +146,7 @@
"source": [
"syn_model = Plugins().get(\"timegan\")\n",
"\n",
- "syn_model.fit(loader)\n"
+ "syn_model.fit(loader)"
]
},
{
@@ -206,23 +220,23 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 2",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
- "name": "python2"
+ "name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
- "version": 2
+ "version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
- "pygments_lexer": "ipython2",
- "version": "2.7.6"
+ "pygments_lexer": "ipython3",
+ "version": "3.9.15"
}
},
"nbformat": 4,
- "nbformat_minor": 0
+ "nbformat_minor": 1
}
diff --git a/tutorials/tutorial7_image_generation_using_mednist.ipynb b/tutorials/tutorial7_image_generation_using_mednist.ipynb
new file mode 100644
index 00000000..d0f01b32
--- /dev/null
+++ b/tutorials/tutorial7_image_generation_using_mednist.ipynb
@@ -0,0 +1,365 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "41a16e01",
+ "metadata": {},
+ "source": [
+ "# Tutorial 7: Image generation\n",
+ "\n",
+ "Synthcity supports generating synthetic images. In this tutorial, we will train a generator based on the [MedNIST dataset](https://medmnist.com/). The Tutorial is adapted from a [MONAI example](https://github.com/Project-MONAI/tutorials/blob/main/2d_classification/mednist_tutorial.ipynb).\n",
+ "\n",
+ "The main components are:\n",
+ " - Creating an `ImageDataloader` on top of the MedNIST dataset.\n",
+ " - Training a Conditional GAN on the resulted dataloader.\n",
+ " - Benchmarking the quality of the synthetic images.\n",
+ " \n",
+ "__Disclaimer__: The models used for the Generators and the Discriminators are not state of the art. For adding better architectures, please update the `suggest_image_generator_discriminator_arch` options from the `convnet.py` module."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "889c3c45",
+ "metadata": {},
+ "source": [
+ "## Load MedNIST\n",
+ "\n",
+ "The dataset is downloaded using [MONAI](https://monai.io/)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b7c34bea",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Download MedNIST\n",
+ "# stdlib\n",
+ "import os\n",
+ "from pathlib import Path\n",
+ "\n",
+ "# third party\n",
+ "import PIL\n",
+ "from monai.apps import download_and_extract\n",
+ "\n",
+ "workspace = Path(\"workspace\")\n",
+ "workspace.mkdir(parents=True, exist_ok=True)\n",
+ "\n",
+ "resource = \"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MedNIST.tar.gz\"\n",
+ "md5 = \"0bc7306e7427e00ad1c5526a6677552d\"\n",
+ "\n",
+ "compressed_file = workspace / \"MedNIST.tar.gz\"\n",
+ "data_dir = workspace / \"MedNIST\"\n",
+ "\n",
+ "if not data_dir.exists():\n",
+ " download_and_extract(resource, compressed_file, workspace, md5)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a0608986",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "LIMIT = 1000 # samples per class\n",
+ "\n",
+ "class_names = sorted(\n",
+ " x for x in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, x))\n",
+ ")\n",
+ "num_class = len(class_names)\n",
+ "image_files = [\n",
+ " [\n",
+ " os.path.join(data_dir, class_names[i], x)\n",
+ " for x in os.listdir(os.path.join(data_dir, class_names[i]))\n",
+ " ]\n",
+ " for i in range(num_class)\n",
+ "]\n",
+ "num_each = [len(image_files[i]) for i in range(num_class)]\n",
+ "image_files_list = []\n",
+ "image_class = []\n",
+ "for i in range(num_class):\n",
+ " image_files_list.extend(image_files[i][:LIMIT])\n",
+ " image_class.extend([i] * min(num_each[i], LIMIT))\n",
+ "num_total = len(image_class)\n",
+ "image_width, image_height = PIL.Image.open(image_files_list[0]).size\n",
+ "\n",
+ "print(f\"Total image count: {num_total}\")\n",
+ "print(f\"Image dimensions: {image_width} x {image_height}\")\n",
+ "print(f\"Label names: {class_names}\")\n",
+ "print(f\"Label counts: {num_each}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1453db8c",
+ "metadata": {},
+ "source": [
+ "## Visualize random samples"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8a7b2717",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# third party\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "\n",
+ "plt.subplots(3, 3, figsize=(8, 8))\n",
+ "for i, k in enumerate(np.random.randint(num_total, size=9)):\n",
+ " im = PIL.Image.open(image_files_list[k])\n",
+ " arr = np.array(im)\n",
+ " plt.subplot(3, 3, i + 1)\n",
+ " plt.xlabel(class_names[image_class[k]])\n",
+ " plt.imshow(arr, cmap=\"gray\", vmin=0, vmax=255)\n",
+ "plt.tight_layout()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5424a2c5",
+ "metadata": {},
+ "source": [
+ "## Create the ImageDataLoader\n",
+ "\n",
+ "The ImageDataLoader prepares the image dataset for the `synthcity` generators.\n",
+ "\n",
+ "Internally, the dataloader will resize the data to the `(height, width)` parameters."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "75f1f078",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# third party\n",
+ "import torch\n",
+ "\n",
+ "# synthcity absolute\n",
+ "from synthcity.plugins.core.dataloader import ImageDataLoader\n",
+ "\n",
+ "IMG_SIZE = 64\n",
+ "\n",
+ "\n",
+ "class MedNISTDataset(torch.utils.data.Dataset):\n",
+ " def __init__(self, image_files, labels):\n",
+ " self.image_files = image_files\n",
+ " self.image_cache = {}\n",
+ " self.labels = labels\n",
+ "\n",
+ " def __len__(self):\n",
+ " return len(self.image_files)\n",
+ "\n",
+ " def __getitem__(self, index):\n",
+ " if index in self.image_cache:\n",
+ " img = self.image_cache[index]\n",
+ " else:\n",
+ " img = PIL.Image.open(self.image_files[index])\n",
+ " img = np.asarray(img)\n",
+ " self.image_cache[index] = img\n",
+ "\n",
+ " return img, self.labels[index]\n",
+ "\n",
+ "\n",
+ "dataset = MedNISTDataset(image_files_list, labels=image_class)\n",
+ "\n",
+ "dataloader = ImageDataLoader(\n",
+ " dataset,\n",
+ " height=IMG_SIZE,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5ef1a423",
+ "metadata": {},
+ "source": [
+ "## Load a generator - Conditional GAN\n",
+ "\n",
+ "For this experiment, we are using the `image_cgan` plugin."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a95785a1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# synthcity absolute\n",
+ "from synthcity.plugins import Plugins\n",
+ "\n",
+ "generator = Plugins().get(\"image_cgan\", batch_size=100, plot_progress=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ca77190b",
+ "metadata": {},
+ "source": [
+ "## Train the generator\n",
+ "\n",
+ "For the training, we are using the `ImageDataLoader` object previously created.\n",
+ "\n",
+ "At the same time, we are using a conditional(`cond`) with the labels of the images. This way, at inference time, we can request from the generator only samples from a specific class."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "bc6794cf",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "generator.fit(dataloader, cond=image_class)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f6b6d5dc",
+ "metadata": {},
+ "source": [
+ "## Generate new samples"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0cec4b56",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# third party\n",
+ "import torch\n",
+ "\n",
+ "# synthcity absolute\n",
+ "from synthcity.plugins.core.models.image_gan import display_imgs\n",
+ "\n",
+ "syn_samples, syn_labels = generator.generate(count=5).unpack().tensors()\n",
+ "\n",
+ "display_imgs(syn_samples)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "3e2428fa",
+ "metadata": {},
+ "source": [
+ "## Generate new samples using a conditional\n",
+ "\n",
+ "We can also generate instances from a specific class, using the conditional we used at training time(`cond`).\n",
+ "\n",
+ "__Disclaimer__ : Other architectures for the Generator and the Discriminator could improve the results. These architectures can be tweaked in the `suggest_image_generator_discriminator_arch` function."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8b538951",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for cls_idx, cls in enumerate(class_names):\n",
+ " print(\"Class\", cls)\n",
+ " syn_samples, syn_labels = (\n",
+ " generator.generate(count=5, cond=np.ones(5) * cls_idx).unpack().tensors()\n",
+ " )\n",
+ "\n",
+ " display_imgs(syn_samples)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "62eb522b",
+ "metadata": {},
+ "source": [
+ "## Benchmarks\n",
+ "\n",
+ "`synthcity` allows us to compare multiple generators on the same dataset, with a wide range of metrics."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7d895cdd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# synthcity absolute\n",
+ "from synthcity.benchmark import Benchmarks\n",
+ "\n",
+ "score = Benchmarks.evaluate(\n",
+ " [\n",
+ " (f\"test_{model}\", model, {}) for model in [\"image_cgan\", \"image_adsgan\"]\n",
+ " ], # (testname, plugin, plugin_args) REPLACE {\"n_iter\" : 50} with {} for better performance\n",
+ " dataloader,\n",
+ " repeats=3,\n",
+ " metrics={\n",
+ " \"detection\": [\"detection_mlp\"],\n",
+ " \"performance\": [\"mlp\"],\n",
+ " \"stats\": [\"fid\"],\n",
+ " },\n",
+ " task_type=\"classification\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "243b9890",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "Benchmarks.print(score)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "211a549e",
+ "metadata": {},
+ "source": [
+ "## Congratulations!\n",
+ "\n",
+ "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement towards Machine learning and AI for medicine, you can do so in the following ways!\n",
+ "\n",
+ "### Star [Synthcity](https://github.com/vanderschaarlab/synthcity) on GitHub\n",
+ "\n",
+ "- The easiest way to help our community is just by starring the Repos! This helps raise awareness of the tools we're building.\n",
+ "\n",
+ "\n",
+ "### Checkout other projects from vanderschaarlab\n",
+ "- [HyperImpute](https://github.com/vanderschaarlab/hyperimpute)\n",
+ "- [AutoPrognosis](https://github.com/vanderschaarlab/autoprognosis)\n"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.15"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}