Repository info

This repository contains a PyTorch implementation of "Generative models with kernel distance in data space" by Szymon Knop, Marcin Mazur, Przemysław Spurek, Jacek Tabor, and Igor Podolak (2020).

Contents of the repository

|-- src/ - implementation of the models proposed in the paper, allowing the experiments from the original paper to be reproduced
|---- architecture/ - files containing the architectures proposed in the paper
|---- lightning_callbacks/ - implementation of evaluators of the metrics reported in our experiments
|---- factories/ - factories used to create the proper objects based on command line arguments. Subfolders contain factories for specific models
|---- lightning_modules/ - implementation of the experiments in PyTorch Lightning
|---- metrics/ - implementation of all of the metrics used in the paper
|---- modules/ - custom neural network layers used in models
|---- train_autoencoder.py - the main script to run all of the experiments
|-- results/ - directory that will be created to store the results of conducted experiments
|-- data/ - default directory that will be used as a source of data and place to download datasets

Experiments are written in PyTorch Lightning to decouple the science code from the engineering. The LightningModule implementations are in the src/lightning_modules/{autoencoder|generator|latent_generator}_module.py files. For more details, refer to the PyTorch Lightning documentation.

Conducting the experiments

To reproduce the experiments described in Table 4 of the paper, run the script src/reproduce_table4.sh.

The repository also supports running CWAE and reuses code provided with the SWAE paper. All implementations are based on the respective papers and repositories.

Browsing the results

Results are stored in TensorBoard format. To browse them, run the following command: tensorboard --logdir results

Datasets

The repository uses the default datasets provided by PyTorch for MNIST, FashionMNIST, and CelebA. To convert CelebA to 64x64 images, we first center-crop the images to 140x140 and then resize them to 64x64.
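The crop step can be illustrated with a few lines of arithmetic. The sketch below is not taken from the repository: the helper name center_crop_box is hypothetical, and the 178x218 input size is an assumption about the aligned CelebA images.

```python
def center_crop_box(width, height, crop_size):
    """Return (left, top, right, bottom) of a centered square crop.

    Hypothetical helper for illustration only; it is not part of this repository.
    """
    if crop_size > min(width, height):
        raise ValueError("crop_size must not exceed the image dimensions")
    left = (width - crop_size) // 2
    top = (height - crop_size) // 2
    return left, top, left + crop_size, top + crop_size


# Assumption: aligned CelebA images are 178 pixels wide and 218 pixels high.
box = center_crop_box(178, 218, 140)

# Factor by which the 140x140 crop is scaled to reach 64x64.
scale = 64 / 140

print(box, round(scale, 3))
```

With torchvision, the same preprocessing could plausibly be expressed as transforms.CenterCrop(140) followed by transforms.Resize(64); whether the repository uses exactly these transforms is an assumption.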

Silverman rule of thumb values

As mentioned in the paper, we use precalculated values of Silverman's rule of thumb. The table below lists the precalculated values for the datasets used. To compute these values, use the compute_cw_dataset_statistics.py script.

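| Batch size | MNIST | F-MNIST | KMNIST | SVHN | CIFAR-10 | CELEBA |
|---|---|---|---|---|---|---|
| 64 | 0.0202 | 0.0265 | 0.0258 | 0.0084 | 0.0134 | 0.0166 |
| 128 | 0.0153 | 0.0201 | 0.0196 | 0.0064 | 0.0102 | 0.0124 |
| 256 | 0.0116 | 0.0152 | 0.0148 | 0.0049 | 0.0077 | 0.0094 |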
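To show how the tabulated values depend on batch size, here is a minimal sketch that assumes the common form of Silverman's rule, gamma = variance * (4 / (3n))^(2/5), where n is the batch size. This form is an assumption and may differ from what compute_cw_dataset_statistics.py actually does; the function names are hypothetical, and only the MNIST numbers are copied from the table.

```python
# MNIST column of the table above, keyed by batch size.
MNIST_GAMMA = {64: 0.0202, 128: 0.0153, 256: 0.0116}


def silverman_gamma(variance, n):
    """Assumed rule-of-thumb form: variance * (4 / (3n)) ** (2/5)."""
    return variance * (4.0 / (3.0 * n)) ** 0.4


def implied_variance(gamma, n):
    """Invert the assumed formula to recover the variance term behind a table entry."""
    return gamma / (4.0 / (3.0 * n)) ** 0.4


# If the assumed form holds, every batch size should imply the same variance.
for n, gamma in MNIST_GAMMA.items():
    print(n, round(implied_variance(gamma, n), 4))
```

Under this assumption, all three MNIST entries imply a variance term of roughly 0.095, which is consistent with the table scaling as n^(-2/5) across batch sizes.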

Stacked MNIST experiment

To run the Stacked MNIST experiment, first train a classifier with train_mnist_classifier.py:

python -m train_mnist_classifier --model classifier --dataset stacked_mnist --dataroot <dataroot_path>

Stop training once the classifier reaches the desired accuracy, then reuse its checkpoint when training the models:

CWAE

python -m train_autoencoder --model cwae --dataset stacked_mnist --dataroot <dataroot_path> --latent_dim 24 --classifier_checkpoint <classifier_ckpt_path> --gpus 1

CW2

python -m train_autoencoder --model cw2_dynamic --dataset stacked_mnist --dataroot <dataroot_path> --latent_dim 24 --classifier_checkpoint <classifier_ckpt_path> --gpus 1 --save_checkpoint --verbose

LCW

python -m train_latent_generator --model cwg_dynamic --dataset stacked_mnist --dataroot <dataroot_path> --noise_dim 24 --classifier_checkpoint <classifier_ckpt_path> --gpus 1 --ae_ckpt <cw2_checkpoint_path>
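The classifier checkpoint is passed to the training scripts so that generated samples can be scored. As a rough illustration, the sketch below assumes the usual Stacked MNIST protocol, in which each of the three channels is classified as a digit and the resulting triple identifies one of 1000 modes. The function names are hypothetical, and the repository's own evaluation may differ.

```python
def mode_id(digits):
    """Map per-channel digit predictions (d0, d1, d2) to a mode index in 0..999."""
    d0, d1, d2 = digits
    for d in (d0, d1, d2):
        if not 0 <= d <= 9:
            raise ValueError("each prediction must be a digit between 0 and 9")
    return 100 * d0 + 10 * d1 + d2


def count_modes(predictions):
    """Count the distinct modes covered by a collection of digit triples."""
    return len({mode_id(p) for p in predictions})


# Hypothetical classifier outputs for three generated images.
sample_predictions = [(1, 2, 3), (1, 2, 3), (0, 0, 7)]
print(count_modes(sample_predictions))
```

A generator that covers all 1000 modes would yield a count of 1000 given enough samples; a lower count indicates missing modes.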

Environment

The code was built and executed on Python 3.9.13 with the dependencies specified in requirements.txt.

To create a virtual environment, first verify with python --version that you are using Python 3.9.13, then execute:

python -m venv .venv

Then activate the environment using source .venv/bin/activate on Linux or .\.venv\Scripts\Activate.ps1 on Windows, and install the dependencies:

pip install -r .\requirements.txt

Please note that the torch dependencies require pip's --find-links option pointing to:

https://download.pytorch.org/whl/torch_stable.html

License

This implementation is licensed under the MIT License.