
WGAN-GP Trainer

Wasserstein GAN (WGAN) is a type of generative adversarial network that minimizes an approximation of the Earth Mover's (EM) distance rather than the Jensen-Shannon divergence used in the original GAN formulation.

WGAN-GP is a generative adversarial network that uses the Wasserstein loss formulation plus a gradient norm penalty to enforce the Lipschitz constraint on the critic.
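
For reference, here is a minimal sketch of how the gradient penalty term is commonly computed in PyTorch. Function and variable names (gradient_penalty, critic, lambda_gp) are illustrative, not necessarily the ones used in this repository.

import torch

def gradient_penalty(critic, real, fake, device="cpu"):
    # interpolate between real and fake samples
    batch_size = real.shape[0]
    eps = torch.rand(batch_size, 1, 1, 1, device=device)
    interpolated = (eps * real + (1 - eps) * fake).detach().requires_grad_(True)

    # gradient of the critic score with respect to the interpolated images
    mixed_scores = critic(interpolated)
    gradients = torch.autograd.grad(
        outputs=mixed_scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
    )[0].view(batch_size, -1)

    # penalize deviation of the gradient norm from 1
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

# critic loss sketch: loss = -(critic(real).mean() - critic(fake).mean()) + lambda_gp * gp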

Cat faces (generated samples)

Human faces (generated samples)

File Structure:

Creating your data folder

The root folder is the topmost one in the hierarchy (here, CatFaces); it contains the class folders, such as cats. The same logic applies to the celeba dataset. A data-loading sketch follows the directory tree below.

.
├── data
│   ├── CatFaces
│   │   └── cats
│   └── celeba_gan
│       └── celeb
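
This layout is what torchvision's ImageFolder expects: a root folder with one sub-folder per class. A minimal loading sketch, assuming ImageFolder is used; the actual transforms and batch size in train.py may differ.

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # map images to [-1, 1]
])

# the root folder is data/CatFaces; the cats sub-folder is read as the single class
dataset = datasets.ImageFolder(root="data/CatFaces", transform=transform)
loader = DataLoader(dataset, batch_size=64, shuffle=True)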

Your models

The architectures folder contains the Deep Convolutional GAN architectures for 64x64 and 128x128 images. An illustrative generator sketch follows the tree below.

├── models
│   └── architectures
│       ├── DeepConv_GAN_64.py
│       └── DeepConv_GAN_128.py
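
As a rough idea of what these files contain, here is an illustrative DCGAN-style generator for 64x64 images. The class name, layer sizes, and hyperparameters below are assumptions; the actual code in DeepConv_GAN_64.py may differ.

import torch.nn as nn

class Generator64(nn.Module):
    def __init__(self, z_dim=100, channels_img=3, features_g=64):
        super().__init__()
        def block(in_c, out_c, k, s, p):
            return nn.Sequential(
                nn.ConvTranspose2d(in_c, out_c, k, s, p, bias=False),
                nn.BatchNorm2d(out_c),
                nn.ReLU(),
            )
        self.net = nn.Sequential(
            block(z_dim, features_g * 8, 4, 1, 0),                   # 1x1   -> 4x4
            block(features_g * 8, features_g * 4, 4, 2, 1),          # 4x4   -> 8x8
            block(features_g * 4, features_g * 2, 4, 2, 1),          # 8x8   -> 16x16
            block(features_g * 2, features_g, 4, 2, 1),              # 16x16 -> 32x32
            nn.ConvTranspose2d(features_g, channels_img, 4, 2, 1),   # 32x32 -> 64x64
            nn.Tanh(),
        )

    def forward(self, z):
        # z has shape (batch_size, z_dim, 1, 1)
        return self.net(z)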

Auto-generated training folders

Each time a new training is started, the file structure shown below is created, and generated images are saved once every 100 batches. These images can be tracked in TensorBoard, which also shows the losses of the generator and the critic (discriminator). TensorBoard logs are stored in the logs folder, the model is saved in the model folder at the end of training, and the layout folder contains the hyperparameters used to train the model.

Each time a new training is started, a folder named train{ID} is created. The first training is always called train0; if a folder named train0 already exists, the ID is incremented and train1 is created, and the ID keeps incrementing as more trainings are run. Note that if there is no trains folder, it is created automatically.



The name of the training folder can be given with the --name parameter, for example: python train.py --name test_catfaces_64. A sketch of this folder-creation logic follows the directory tree below.

└── trains/
    ├── catfaces_64/
    │   ├── generated_grid_images/
    │   ├── generated_grid_images_fixed/
    │   ├── generated_images_fixed/
    │   ├── layout/
    │   ├── logs/
    │   │   ├── criticLoss/
    │   │   ├── fake_different/
    │   │   ├── fake_fixed/
    │   │   ├── fake_singular/
    │   │   ├── genLoss/
    │   │   └── real/
    │   └── model/
    │       └── catfaces_checkpoint.pth.tar
    └── celeba_128/
        ├── generated_grid_images/
        ├── generated_grid_images_fixed/
        .
        .
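
A sketch of the folder-creation logic described above (illustrative only; train.py may implement it differently):

import os

def make_train_dir(root="trains", name=None):
    os.makedirs(root, exist_ok=True)               # create trains/ if it is missing
    if name is None:                               # no --name given: fall back to train{ID}
        train_id = 0
        while os.path.exists(os.path.join(root, f"train{train_id}")):
            train_id += 1                          # train0 exists -> try train1, and so on
        name = f"train{train_id}"
    path = os.path.join(root, name)
    for sub in ("generated_grid_images", "generated_grid_images_fixed",
                "generated_images_fixed", "layout", "logs", "model"):
        os.makedirs(os.path.join(path, sub), exist_ok=True)
    return path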

To keep track of the losses and images while training, change into the directory that contains the logs folder and run tensorboard --logdir logs in a terminal.
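
A minimal logging sketch, assuming torch.utils.tensorboard is used. The tag names below mirror the log sub-folders in the tree above but are not guaranteed to match the trainer's exact tags.

import torchvision
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("trains/test_catfaces_64/logs")

def log_step(step, loss_gen, loss_critic, fake_batch):
    # scalar curves for the generator and critic losses
    writer.add_scalar("genLoss", loss_gen, global_step=step)
    writer.add_scalar("criticLoss", loss_critic, global_step=step)
    # grid of generated images, written once every 100 batches by the trainer
    grid = torchvision.utils.make_grid(fake_batch[:32], normalize=True)
    writer.add_image("fake_fixed", grid, global_step=step)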

Example:

Create your data folder -> add your data -> tune the hyperparameters -> start the training with:

$ python train.py --name test_catfaces_64

Keep track of the logs:

$ cd trains/test_catfaces_64
$ tensorboard --logdir logs

Evaluate the model:

$ python run.py --path (path to model)

You can specify the figure's size with the --size parameter (height and width, in pixels):

$ python run.py --path (path to model) --size 220
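
An illustrative sketch of what such an evaluation step can look like. The imported class name, checkpoint key, latent size, and figure handling are assumptions, not run.py's actual code.

import torch
import matplotlib.pyplot as plt
from models.architectures.DeepConv_GAN_64 import Generator  # assumed class name

checkpoint = torch.load("trains/test_catfaces_64/model/catfaces_checkpoint.pth.tar",
                        map_location="cpu")
gen = Generator()
gen.load_state_dict(checkpoint["gen"])                 # assumed checkpoint key
gen.eval()

with torch.no_grad():
    noise = torch.randn(1, 100, 1, 1)                  # assumed latent size of 100
    img = gen(noise)[0].permute(1, 2, 0) * 0.5 + 0.5   # undo [-1, 1] normalization

dpi = 100
plt.figure(figsize=(220 / dpi, 220 / dpi), dpi=dpi)    # --size given in pixels
plt.axis("off")
plt.imshow(img.numpy())
plt.show()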

Installing Dependencies

Anaconda:

Packages required for training and running:

$ conda install python=3.8
$ conda install pytorch torchvision cudatoolkit=10.2 -c pytorch
$ conda install tensorboard
$ conda install pyyaml
$ conda install matplotlib

Optional

$ conda install -c conda-forge torchinfo  # optional
