Skip to content

Commit

Permalink
Image generation support (#135)
Browse files Browse the repository at this point in the history
## Description
- [x] Add image dataloader, with the ability to resize the images internally to predefined shapes.
- [x] Extend the core plugin for non-tabular inputs
- [x] Add ConvNet classifier support, with some predefined architectures.
- [x] Add Generator and Discriminators with predefined architectures for various image input sizes.
- [x] Add Image GAN support on top of the Generators and Discriminators.
- [x] Add Conditional GAN support for the Generators and Discriminators
- [x] New metrics: FID score
- [x] Adapt the detection and performance metrics for images
- [x] Add Image Conditional GAN plugin
- [x] Add Image AdsGAN plugin
- [x] Add tutorial on top of MedMNIST
- [x] Update docs and README
  • Loading branch information
bcebere authored Feb 27, 2023
1 parent 0e39fe2 commit 9a2ac98
Show file tree
Hide file tree
Showing 59 changed files with 4,481 additions and 127 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ jobs:
- uses: actions/checkout@v2
with:
submodules: true
- uses: gautamkrishnar/keepalive-workflow@v1
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/test_tutorials.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ jobs:
- uses: actions/checkout@v2
with:
submodules: true
- uses: gautamkrishnar/keepalive-workflow@v1
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
Expand Down
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,5 @@ data
checkpoints
lightning_logs
generated
MNIST
cifar-10*
49 changes: 45 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@
- :cyclone: Several evaluation metrics for correctness and privacy.
- :fire: Several reference models, by type:
- General purpose: GAN-based (AdsGAN, CTGAN, PATEGAN, DP-GAN),VAE-based(TVAE, RTVAE), Normalizing flows, Bayesian Networks(PrivBayes, BN).
- Time Series generators: TimeGAN, FourierFlows, Probabilistic autoregressive.
- Survival Analysis: SurvivalGAN, SurVAE.
- Time Series & Time-Series Survival generators: TimeGAN, FourierFlows, TimeVAE.
- Static Survival Analysis: SurvivalGAN, SurVAE.
- Privacy-focused: DECAF, DP-GAN, AdsGAN, PATEGAN, PrivBayes.
- Domain adaptation: RadialGAN.
- Images: Image ConditionalGAN, Image AdsGAN.
- :book: [Read the docs !](https://synthcity.readthedocs.io/)
- :airplane: [Checkout the tutorials!](https://github.com/vanderschaarlab/synthcity#-tutorials)

Expand Down Expand Up @@ -110,7 +111,7 @@ score = Benchmarks.evaluate(
Benchmarks.print(score)
```

### Survival analysis
### Static Survival analysis

* List the available generators dedicated to survival analysis

Expand Down Expand Up @@ -172,6 +173,37 @@ syn_model = Plugins().get("timegan")
syn_model.fit(data)

syn_model.generate(count=10)
```
### Images

__Note__ : The architectures used for generators are not state-of-the-art. For other architectures, consider extending the `suggest_image_generator_discriminator_arch` method from the `convnet.py` module.

* List the available generators

```python
from synthcity.plugins import Plugins

Plugins(categories=["images"]).list()
```

* Generate new data
```python
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import ImageDataLoader
from torchvision import datasets


dataset = datasets.MNIST(".", download=True)
loader = ImageDataLoader(dataset).sample(100)

syn_model = Plugins().get("image_cgan")

syn_model.fit(loader)

syn_img, syn_labels = syn_model.generate(count=10).unpack().numpy()

print(syn_img.shape)

```

### Serialization
Expand Down Expand Up @@ -225,7 +257,8 @@ assert syn_model.name() == reloaded.name()
- [![Test In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Wa2CPsbXzbKMPC5fSBhKl00Gi7QqVkse?usp=sharing) [ Tutorial 3: Generating Survival Analysis data](https://github.com/vanderschaarlab/synthcity/blob/main/tutorials/tutorial3_survival_analysis.ipynb)
- [![Test In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jN36GCAKEkjzDlczmQfR7Wbh3yF3cIz5?usp=sharing) [ Tutorial 4: Generating Time Series](https://github.com/vanderschaarlab/synthcity/blob/main/tutorials/tutorial4_time_series.ipynb)
- [![Test In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1Nf8d3Y6sXr1uco8MsJA4wb33iFvReL59?usp=sharing) [ Tutorial 5: Generating Data with Differential Privacy Guarantees](https://github.com/vanderschaarlab/synthcity/blob/main/tutorials/tutorial5_differential_privacy.ipynb)

- [ Tutorial 6 for using Custom Time series data](https://github.com/vanderschaarlab/synthcity/blob/main/tutorials/tutorial6_time_series_data_preparation.ipynb)
- [![Test In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1gEUSGxVmts9C0cDBKbq7Ees-wvQpUBoo?usp=sharing) [ Tutorial 7: Image generation](https://github.com/vanderschaarlab/synthcity/blob/main/tutorials/tutorial7_image_generation_using_mednist.ipynb)

## 🔑 Methods

Expand Down Expand Up @@ -291,6 +324,13 @@ assert syn_model.name() == reloaded.name()
|--- | --- | --- |
|**radialgan** | Training complex machine learning models for prediction often requires a large amount of data that is not always readily available. Leveraging these external datasets from related but different sources is, therefore, an essential task if good predictive models are to be built for deployment in settings where data can be rare. RadialGAN is an approach to the problem in which multiple GAN architectures are used to learn to translate from one dataset to another, thereby allowing to augment the target dataset effectively and learning better predictive models than just the target dataset. | [RadialGAN: Leveraging multiple datasets to improve target-specific predictive models using Generative Adversarial Networks](https://arxiv.org/abs/1802.06403) |

### Images

| Method | Description | Reference |
|--- | --- | --- |
|**image_cgan**| Conditional GAN for generating images| --- |
|**image_adsgan**| The AdsGAN method adapted for image generation| --- |


### Debug methods

Expand Down Expand Up @@ -328,6 +368,7 @@ The following table contains the available evaluation metrics:
|**prdc**| Computes precision, recall, density, and coverage given two manifolds. | --- |
|**alpha_precision**|Evaluate the alpha-precision, beta-recall, and authenticity scores. | --- |
|**survival_km_distance**|The distance between two Kaplan-Meier plots(survival analysis). | --- |
|**fid**|The Frechet Inception Distance (FID) calculates the distance between two distributions of images. | --- |



Expand Down
60 changes: 57 additions & 3 deletions docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
- |:cyclone:| Several evaluation metrics for correctness and privacy.
- |:fire:| Several reference models, by type:
- General purpose: GAN-based (AdsGAN, CTGAN, PATEGAN, DP-GAN),VAE-based(TVAE, RTVAE), Normalizing flows, Bayesian Networks(PrivBayes, BN).
- Time Series generators: TimeGAN, FourierFlows, Probabilistic autoregressive.
- Survival Analysis: SurvivalGAN, SurVAE.
- Time Series & Time-Series Survival generators: TimeGAN, FourierFlows, TimeVAE.
- Static Survival Analysis: SurvivalGAN, SurVAE.
- Privacy-focused: DECAF, DP-GAN, AdsGAN, PATEGAN, PrivBayes.
- Domain adaptation: RadialGAN.
- Images: Image ConditionalGAN, Image AdsGAN.
- |:book:| [Read the docs !](https://synthcity.readthedocs.io/)
- |:airplane:| [Checkout the tutorials!](https://github.com/vanderschaarlab/synthcity#-tutorials)

Expand Down Expand Up @@ -152,6 +153,37 @@ syn_model = Plugins().get("timegan")
syn_model.fit(data)

syn_model.generate(count=10)
```
### Images

__Note__ : The architectures used for generators are not state-of-the-art. For other architectures, consider extending the `suggest_image_generator_discriminator_arch` method from the `convnet.py` module.

* List the available generators

```python
from synthcity.plugins import Plugins

Plugins(categories=["images"]).list()
```

* Generate new data
```python
from synthcity.plugins import Plugins
from synthcity.plugins.core.dataloader import ImageDataLoader
from torchvision import datasets


dataset = datasets.MNIST(".", download=True)
loader = ImageDataLoader(dataset).sample(100)

syn_model = Plugins().get("image_cgan")

syn_model.fit(loader)

syn_img, syn_labels = syn_model.generate(count=10).unpack().numpy()

print(syn_img.shape)

```

### Serialization
Expand All @@ -170,6 +202,20 @@ reloaded = load(buff)
assert syn_model.name() == reloaded.name()
```

* Saving and loading models from disk

```python
from synthcity.utils.serialization import save_to_file, load_from_file
from synthcity.plugins import Plugins

syn_model = Plugins().get("adsgan")

save_to_file('./adsgan_10_epochs.pkl', syn_model)
reloaded = load_from_file('./adsgan_10_epochs.pkl')

assert syn_model.name() == reloaded.name()
```

* Using the Serializable interface

```python
Expand Down Expand Up @@ -247,6 +293,13 @@ assert syn_model.name() == reloaded.name()
|--- | --- | --- |
|**radialgan** | Training complex machine learning models for prediction often requires a large amount of data that is not always readily available. Leveraging these external datasets from related but different sources is, therefore, an essential task if good predictive models are to be built for deployment in settings where data can be rare. RadialGAN is an approach to the problem in which multiple GAN architectures are used to learn to translate from one dataset to another, thereby allowing to augment the target dataset effectively and learning better predictive models than just the target dataset. | [RadialGAN: Leveraging multiple datasets to improve target-specific predictive models using Generative Adversarial Networks](https://arxiv.org/abs/1802.06403) |

### Images

| Method | Description | Reference |
|--- | --- | --- |
|**image_cgan**| Conditional GAN for generating images| --- |
|**image_adsgan**| The AdsGAN method adapted for image generation| --- |


### Debug methods

Expand Down Expand Up @@ -284,6 +337,7 @@ The following table contains the available evaluation metrics:
|**prdc**| Computes precision, recall, density, and coverage given two manifolds. | --- |
|**alpha_precision**|Evaluate the alpha-precision, beta-recall, and authenticity scores. | --- |
|**survival_km_distance**|The distance between two Kaplan-Meier plots(survival analysis). | --- |
|**fid**|The Frechet Inception Distance (FID) calculates the distance between two distributions of images. | --- |



Expand All @@ -293,7 +347,7 @@ The following table contains the available evaluation metrics:
|--- | --- | --- |
|**performance.xgb**|Train an XGBoost classifier/regressor/survival model on real data(gt) and synthetic data(syn), and evaluate the performance on the test set. | 1 for ideal performance, 0 for worst performance |
|**performance.linear**|Train a Linear classifier/regressor/survival model on real data(gt) and the synthetic data and evaluate the performance on test data.| 1 for ideal performance, 0 for worst performance |
|**performance.mlp**|Train a Neural Net classifier/regressor/survival model on the read data and the synthetic data and evaluate the performance on test data.| 1 for ideal performance, 0 for worst performance |
|**performance.mlp**|Train a Neural Net classifier/regressor/survival model on the real data and the synthetic data and evaluate the performance on test data.| 1 for ideal performance, 0 for worst performance |
|**performance.feat_rank_distance**| Train a model on the synthetic data and a model on the real data. Compute the feature importance of the models on the same test data, and compute the rank distance between the importance(kendalltau or spearman)| 1: similar ranks in the feature importance. 0: uncorrelated feature importance |
|**detection_gmm**|Train a GaussianMixture model to differentiate the synthetic data from the real data.|0: The datasets are indistinguishable. <br/>1: The datasets are totally distinguishable.|
|**detection_xgb**|Train an XGBoost model to differentiate the synthetic data from the real data.|0: The datasets are indistinguishable. <br/>1: The datasets are totally distinguishable.|
Expand Down
10 changes: 9 additions & 1 deletion docs/advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ Time-series survival models
:maxdepth: 2

Time-series CoxPH <generated/synthcity.plugins.core.models.time_series_survival.ts_surv_coxph.rst>
DeepCoxPH <generated/synthcity.plugins.core.models.time_series_survival.ts_surv_deep_coxph.rst>
Dynamic DeepHit <generated/synthcity.plugins.core.models.time_series_survival.ts_surv_dynamic_deephit.rst>
Time-Series XGBoost <generated/synthcity.plugins.core.models.time_series_survival.ts_surv_xgb.rst>

Expand All @@ -63,3 +62,12 @@ Time-to-event models

DATE <generated/synthcity.plugins.core.models.time_to_event.tte_date.rst>
Survival function regression <generated/synthcity.plugins.core.models.time_to_event.tte_survival_function_regression.rst>

Images
-------------------------------
.. toctree::
:glob:
:maxdepth: 2

ConvNets <generated/synthcity.plugins.core.models.convnet.rst>
ImageGAN <generated/synthcity.plugins.core.models.image_gan.rst>
14 changes: 0 additions & 14 deletions docs/conf.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,3 @@
# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
import subprocess
Expand Down Expand Up @@ -152,7 +140,6 @@
"pycox",
"pykeops",
"pyod",
"pydantic",
"rdt",
"redis",
"scikit-learn",
Expand All @@ -165,7 +152,6 @@
"joblib",
"sdv",
"shap",
"torch",
"tsai",
"xgboost",
"xgbse",
Expand Down
11 changes: 11 additions & 0 deletions docs/dataloaders.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
Datasets and DataLoaders
=========================

Dataloaders and Datasets
-------------------------
.. toctree::
:glob:
:maxdepth: 2

Dataloaders <generated/synthcity.plugins.core.dataloader.rst>
Datasets <generated/synthcity.plugins.core.dataset.rst>
11 changes: 11 additions & 0 deletions docs/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ Getting started
Generating Survival Analysis Data <tutorials/tutorial3_survival_analysis.ipynb>
Generating Time Series <tutorials/tutorial4_time_series.ipynb>
Generating Data with Differential Privacy Guarantees <tutorials/tutorial5_differential_privacy.ipynb>
Using custom Time-series Datasets <tutorials/tutorial6_time_series_data_preparation.ipynb>
Generating Images <tutorials/tutorial7_image_generation_using_mednist.ipynb>

General-purpose generators
---------------------------
Expand Down Expand Up @@ -55,3 +57,12 @@ Domain adaptation generators
:maxdepth: 2

RadialGAN <tutorials/plugins/domain_adaptation/plugin_radialgan.ipynb>

Images
------------------------------
.. toctree::
:glob:
:maxdepth: 2

Image CGAN <tutorials/plugins/images/plugin_image_cgan.ipynb>
Image AdsGAN <tutorials/plugins/images/plugin_image_adsgan.ipynb>
10 changes: 9 additions & 1 deletion docs/generators.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,13 @@ Time-series & Time-Series Survival Analysis

TimeGAN <generated/synthcity.plugins.time_series.plugin_timegan.rst>
FourierFlows <generated/synthcity.plugins.time_series.plugin_fflows.rst>
Probabilistic AutoRegressive <generated/synthcity.plugins.time_series.plugin_probabilistic_ar.rst>
TimeVAE <generated/synthcity.plugins.time_series.plugin_timevae.rst>

Images
----------------------------------------------
.. toctree::
:glob:
:maxdepth: 2

ImageCGAN <generated/synthcity.plugins.images.plugin_image_cgan.rst>
Image AdsGAN <generated/synthcity.plugins.images.plugin_image_adsgan.rst>
6 changes: 3 additions & 3 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ Examples

examples.rst

Dataloaders
============
Dataloaders and Datasets
==========================
.. toctree::
:glob:
:maxdepth: 3

Dataloaders <generated/synthcity.plugins.core.dataloader.rst>
dataloaders.rst

Generators
==========
Expand Down
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ platforms = any
classifiers =
Programming Language :: Python :: 3
Topic :: Scientific/Engineering :: Artificial Intelligence
Intended Audience :: Healthcare Industry
Intended Audience :: Science/Research
Operating System :: OS Independent

[options]
Expand Down Expand Up @@ -55,6 +55,7 @@ install_requires =
xgbse
pykeops
fflows
monai
tsai; python_version>"3.7"
importlib-metadata; python_version<"3.8"

Expand Down
2 changes: 1 addition & 1 deletion src/synthcity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@

warnings.simplefilter(action="ignore")

logger.add(sink=sys.stderr, level="ERROR")
logger.add(sink=sys.stderr, level="CRITICAL")
9 changes: 5 additions & 4 deletions src/synthcity/benchmark/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
# third party
import numpy as np
import pandas as pd
import torch
from IPython.display import display
from pydantic import validate_arguments

Expand All @@ -20,7 +19,7 @@
from synthcity.plugins import Plugins
from synthcity.plugins.core.constraints import Constraints
from synthcity.plugins.core.dataloader import DataLoader
from synthcity.utils.reproducibility import enable_reproducible_results
from synthcity.utils.reproducibility import clear_cache, enable_reproducible_results
from synthcity.utils.serialization import load_from_file, save_to_file


Expand Down Expand Up @@ -96,7 +95,9 @@ def evaluate(
workspace.mkdir(parents=True, exist_ok=True)

plugin_cats = ["generic", "privacy"]
if task_type == "survival_analysis":
if X.type() == "images":
plugin_cats.append("images")
elif task_type == "survival_analysis":
plugin_cats.append("survival_analysis")
elif task_type == "time_series" or task_type == "time_series_survival":
plugin_cats.append("time_series")
Expand All @@ -123,7 +124,7 @@ def evaluate(
kwargs["workspace"] = workspace
kwargs["random_state"] = repeat

torch.cuda.empty_cache()
clear_cache()

cache_file = (
workspace
Expand Down
Loading

0 comments on commit 9a2ac98

Please sign in to comment.