
TensorFlow implementations of Wasserstein GAN with Gradient Penalty (WGAN-GP), Least Squares GAN (LSGAN), GANs with the hinge loss.

mariorioMa/WGAN-GP-TensorFlow

Wasserstein GANs with Gradient Penalty (WGAN-GP) in TensorFlow

Descriptions

This is my TensorFlow implementation of Wasserstein GANs with Gradient Penalty (WGAN-GP), proposed in Improved Training of Wasserstein GANs, as well as Least Squares GANs (LSGAN) and GANs with the hinge loss.
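As a rough illustration of how the three GAN variants differ, the sketch below computes discriminator and generator losses from raw discriminator outputs, using formulations common in the literature. It is not taken from this repository: the helper name `gan_losses`, the 0.5 factors and the 1/0 targets in the LSGAN terms are assumptions, and for `wgan-gp` the gradient penalty is left out and would be added to the discriminator loss separately.

```python
import numpy as np


def gan_losses(d_real, d_fake, gan_type):
    """Return (d_loss, g_loss) computed from raw discriminator outputs.

    Hypothetical helper: d_real holds D(x) for real images and d_fake holds
    D(G(z)) for generated images. For "wgan-gp" the gradient penalty is NOT
    included here.
    """
    d_real = np.asarray(d_real, dtype=float)
    d_fake = np.asarray(d_fake, dtype=float)
    if gan_type == "wgan-gp":
        # Wasserstein critic: raise real scores, lower fake scores.
        d_loss = d_fake.mean() - d_real.mean()
        g_loss = -d_fake.mean()
    elif gan_type == "lsgan":
        # Least squares with assumed targets 1 (real) and 0 (fake).
        d_loss = 0.5 * np.mean((d_real - 1.0) ** 2) + 0.5 * np.mean(d_fake ** 2)
        g_loss = 0.5 * np.mean((d_fake - 1.0) ** 2)
    elif gan_type == "hinge":
        # Hinge loss for D; the generator simply raises the fake score.
        d_loss = (np.mean(np.maximum(0.0, 1.0 - d_real))
                  + np.mean(np.maximum(0.0, 1.0 + d_fake)))
        g_loss = -d_fake.mean()
    else:
        raise ValueError(f"unknown gan_type: {gan_type}")
    return float(d_loss), float(g_loss)
```

For example, `gan_losses([1.0, 1.0], [0.0, 0.0], "hinge")` returns a discriminator loss of 1.0, because the fake scores of 0 still sit inside the hinge margin.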

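The key insight of WGAN-GP is as follows. To enforce the Lipschitz constraint in Wasserstein GAN, the original paper proposes to clip the weights of the discriminator (critic), which can lead to undesired behavior including exploding and vanishing gradients. Instead of weight clipping, this paper proposes to employ a gradient penalty term to constrain the gradient norm of the critic’s output with respect to its input, resulting in the learning objective:

$$L = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\left[D(\tilde{x})\right] - \mathbb{E}_{x \sim \mathbb{P}_r}\left[D(x)\right] + \lambda \, \mathbb{E}_{\hat{x} \sim \mathbb{P}_{\hat{x}}}\left[\left(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\right)^2\right]$$

where $\mathbb{P}_r$ and $\mathbb{P}_g$ are the real and generated data distributions, $\hat{x}$ is sampled uniformly along straight lines between pairs of real and generated samples, and $\lambda$ is the penalty coefficient.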
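To make the penalty term concrete, here is a framework-agnostic NumPy sketch of the WGAN-GP critic and generator losses. It uses a central-difference numerical gradient in place of automatic differentiation (a TensorFlow implementation would typically get the input gradient from `tf.gradients` or `tf.GradientTape`). The function names, the per-sample interpolation coefficient, and the default λ = 10 (the value used in Improved Training of Wasserstein GANs) are assumptions of this sketch, not details confirmed for this repository.

```python
import numpy as np


def numerical_grad(critic, x, eps=1e-4):
    """Central-difference gradient of sum(critic(x)) w.r.t. every element of x.

    Stands in for autodiff. Since each sample's score depends only on that
    sample, this yields the per-sample input gradients dD(x_i)/dx_i.
    """
    grad = np.zeros_like(x, dtype=float)
    for idx in np.ndindex(x.shape):
        step = np.zeros_like(x, dtype=float)
        step[idx] = eps
        grad[idx] = (critic(x + step).sum() - critic(x - step).sum()) / (2 * eps)
    return grad


def wgan_gp_losses(critic, real, fake, lam=10.0, rng=None):
    """Return (d_loss, g_loss, penalty) for a batch of real and fake samples."""
    rng = np.random.default_rng() if rng is None else rng
    # One interpolation coefficient per sample, broadcast over the other axes.
    alpha = rng.uniform(size=(real.shape[0],) + (1,) * (real.ndim - 1))
    x_hat = alpha * real + (1.0 - alpha) * fake
    grads = numerical_grad(critic, x_hat)
    # L2 norm of each sample's gradient, flattened over all non-batch axes.
    grad_norm = np.sqrt((grads.reshape(len(grads), -1) ** 2).sum(axis=1))
    penalty = lam * np.mean((grad_norm - 1.0) ** 2)
    d_loss = critic(fake).mean() - critic(real).mean() + penalty
    g_loss = -critic(fake).mean()
    return d_loss, g_loss, penalty
```

With a linear critic `D(x) = x @ w` and `w = [3, 4]`, every input gradient has norm 5, so the penalty is 10 × (5 − 1)² = 160 regardless of the samples. Numerical differentiation costs two critic evaluations per input element, so this is only meant to show what the penalty measures, not to train a model.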

The gradient penalty enables stable training of a variety of GAN models on a wide range of datasets. This implementation has been tested on several datasets, including LSUN bedroom, CelebA, CityScape (leftImg8bit_sequence_trainvaltest), ImageNet, CIFAR100, CIFAR10, Street View House Numbers (SVHN), MNIST, and Fashion_MNIST. Randomly sampled results are as follows.

*This code is still being developed and subject to change.

Prerequisites

Usage

Download datasets

python download.py --dataset bedroom celeba CIFAR10 CIFAR100 SVHN MNIST Fashion_MNIST
  • Downloading and extracting the LSUN bedroom dataset requires around 94 GB of disk space.
  • ImageNet can be downloaded from here
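Because the LSUN bedroom download needs roughly 94 GB, it can be worth checking free space first. The helper below is a hypothetical convenience, not part of download.py; it uses `shutil.disk_usage` from the Python standard library and treats 1 GB as 10^9 bytes.

```python
import shutil

# Approximate space needed to download and extract LSUN bedroom (from this README).
LSUN_BEDROOM_GB = 94


def has_enough_space(path=".", required_gb=LSUN_BEDROOM_GB):
    """Return True if the filesystem containing `path` has at least `required_gb` GB free."""
    free_bytes = shutil.disk_usage(path).free
    return free_bytes >= required_gb * 10**9


if __name__ == "__main__":
    print("enough space for LSUN bedroom:", has_enough_space("."))
```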

Train models with downloaded datasets:

python trainer.py --dataset [bedroom / celeba / CityScape / ImageNet / CIFAR10 / CIFAR100 / SVHN / MNIST / Fashion_MNIST] --batch_size 36 --num_dis_conv 6 --gan_type wgan-gp
  • Selected arguments (see config.py for more details)
    • --prefix: a nickname for the training.
    • --dataset: choose among bedroom, celeba, CityScape, ImageNet, CIFAR10, CIFAR100, SVHN, MNIST, and Fashion_MNIST. You can also add your own datasets.
    • --dataset_path: the path to your dataset (e.g. ImageNet).
    • Checkpoints: specify the path to a pre-trained checkpoint.
      • --checkpoint: load all the parameters, including those of the generator and the discriminator.
    • Logging
      • --log_step: the frequency of logging, e.g. [train step 10] D loss: 1.26449 G loss: 46.01093 (0.057 sec/batch, 558.933 instances/sec).
      • --ckpt_save_step: the frequency of saving a checkpoint.
      • --write_summary_step: the frequency of writing TensorBoard summaries (default 100).
    • Hyperparameters
      • --batch_size: the mini-batch size (default 8).
      • --max_steps: the max training iterations.
    • GAN
      • --gan_type: the type of GAN: wgan-gp, lsgan, or hinge.
      • --learning_rate_g / --learning_rate_d: the learning rates of the generator and the discriminator.
      • --deconv_type: the type of deconv layers.
      • --num_dis_conv: the number of convolutional layers in the discriminator.
      • --norm_type: the type of normalization.
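If you want to extract loss curves from the console output instead of TensorBoard, a regular expression over lines shaped like the logging example above is enough. `parse_log_line` is a hypothetical helper that assumes exactly that format; extra prefixes before the line are tolerated, but if the trainer's actual field names or layout differ, the pattern needs adjusting.

```python
import re

# Matches e.g. "[train step 10] D loss: 1.26449 G loss: 46.01093 (0.057 sec/batch, 558.933 instances/sec)"
LOG_PATTERN = re.compile(
    r"\[train step (?P<step>\d+)\]\s+"
    r"D loss: (?P<d_loss>[-+\d.eE]+)\s+"
    r"G loss: (?P<g_loss>[-+\d.eE]+)\s+"
    r"\((?P<sec_per_batch>[\d.]+) sec/batch, "
    r"(?P<instances_per_sec>[\d.]+) instances/sec\)"
)


def parse_log_line(line):
    """Return a dict of step, losses and throughput, or None if the line does not match."""
    match = LOG_PATTERN.search(line)
    if match is None:
        return None
    record = {key: float(value) for key, value in match.groupdict().items()}
    record["step"] = int(match.group("step"))
    return record
```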

Test models with saved checkpoints:

python evaler.py --dataset [DATASET] [--train_dir /path/to/the/training/dir/ OR --checkpoint /path/to/the/trained/model] --write_summary_image True --output_file output.hdf5
  • Selected arguments (see config.py for more details)
    • --output_file: dump generated images to an HDF5 file.
    • --write_summary_image: plot an n-by-n grid of generated images.
    • --summary_image_name: specify the output image name.
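A common way to build such an n-by-n summary image is to tile the first n² images of a batch into a single array. The NumPy sketch below shows the idea; `make_grid` is a hypothetical name, and the repository's own plotting code may pad, normalize, or order images differently. With 36 images it produces a 6-by-6 grid.

```python
import math

import numpy as np


def make_grid(images):
    """Tile the first n*n images of a [batch, h, w, c] array into one [n*h, n*w, c] image.

    n is the largest integer with n*n <= batch; leftover images are dropped.
    """
    images = np.asarray(images)
    n = math.isqrt(len(images))
    if n == 0:
        raise ValueError("need at least one image")
    _, h, w, c = images.shape
    tiles = images[: n * n].reshape(n, n, h, w, c)
    # (row, col, y, x, c) -> (row, y, col, x, c), then merge row/y and col/x.
    return tiles.transpose(0, 2, 1, 3, 4).reshape(n * h, n * w, c)
```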

Interpret TensorBoard

Launch TensorBoard and open the specified port. The scalars tab shows the different losses and the images tab shows the plotted images, which can be interpreted as follows.

  • fake_image: the generated images in the current batch
  • img:
    • Top-left: a real image
    • Top-right: a generated image
    • Bottom-left: a spatial map produced by the discriminator given the real image above it (D(real image)), reflecting how the discriminator judges different regions of this image. White: real; black: fake.
    • Bottom-right: a spatial map produced by the discriminator given the generated image above it (D(generated image)), reflecting how the discriminator judges different regions of this image. White: real; black: fake.

Train and test your own datasets:

  • Create a directory
$ mkdir datasets/YOUR_DATASET

Step 1: organize your data

With the HDF5 loader:

  • Store your data as an HDF5 file (e.g. written with h5py) at datasets/YOUR_DATASET/data.hdf5, in which each data point contains
    • 'image': has shape [h, w, c], where c is the number of channels (grayscale images: 1, color images: 3)
  • Maintain a list datasets/YOUR_DATASET/id.txt containing the ids of all data points
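The sketch below writes a set of images in the layout described above using h5py and NumPy. It assumes each data point is stored as an HDF5 group named by its id that contains an `image` dataset, and that id.txt holds one id per line; both are an interpretation of the description rather than a confirmed spec of the repository's loader, and `write_dataset` and its sequential id scheme are hypothetical.

```python
import os

import h5py
import numpy as np


def write_dataset(images, out_dir):
    """Write data.hdf5 and id.txt for an iterable of [h, w, c] image arrays.

    Hypothetical layout: one group per data point, named by its id, holding
    an 'image' dataset. Grayscale images must keep an explicit channel axis.
    """
    os.makedirs(out_dir, exist_ok=True)
    ids = []
    with h5py.File(os.path.join(out_dir, "data.hdf5"), "w") as f:
        for i, image in enumerate(images):
            image = np.asarray(image)
            if image.ndim != 3 or image.shape[-1] not in (1, 3):
                raise ValueError(f"expected [h, w, 1] or [h, w, 3], got {image.shape}")
            data_id = str(i)  # hypothetical id scheme: sequential integers
            f.create_group(data_id).create_dataset("image", data=image)
            ids.append(data_id)
    with open(os.path.join(out_dir, "id.txt"), "w") as f:
        f.write("\n".join(ids) + "\n")
    return ids
```

For example, `write_dataset(my_images, "datasets/YOUR_DATASET")` would produce both files in the directory created above.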

With the image loader:

  • Put all of your images under datasets/YOUR_DATASET
  • Valid image formats: .jpg, .jpeg, .JPEG, .webp, and .png.
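One minimal way to collect files for the image loader is to filter the directory listing by the extensions above. This is a hypothetical sketch, not the repository's loader: it scans only the top level of the directory and matches the listed extensions exactly, so a file ending in .PNG would be skipped, whereas the real loader may be more permissive.

```python
import os

# Extensions exactly as listed above (only ".JPEG" appears in upper case).
VALID_EXTENSIONS = (".jpg", ".jpeg", ".JPEG", ".webp", ".png")


def list_images(dataset_dir):
    """Return sorted paths of regular files directly under dataset_dir with a valid extension."""
    return sorted(
        os.path.join(dataset_dir, name)
        for name in os.listdir(dataset_dir)
        if name.endswith(VALID_EXTENSIONS)
        and os.path.isfile(os.path.join(dataset_dir, name))
    )
```

For example, `list_images("datasets/YOUR_DATASET")` returns the image paths in a stable order.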

Step 2: train and test

$ python trainer.py --dataset YOUR_DATASET --dataset_path datasets/YOUR_DATASET
$ python evaler.py --dataset YOUR_DATASET --dataset_path datasets/YOUR_DATASET --train_dir dir

Related Work

CLVR Lab

As part of the implementation series of the Cognitive Learning for Vision and Robotics Lab at the University of Southern California, our motivation is to accelerate (or sometimes delay) research in the AI community by promoting open-source projects. To this end, we implement state-of-the-art research papers and publicly share them with concise reports. Please visit our group's GitHub site.

This project is implemented by Shao-Hua Sun and reviewed by Youngwoon Lee.

Author

Shao-Hua Sun / @shaohua0116
