
Analysis on Vanilla GAN

In this article, I want to share some characteristics of GANs and practical tricks that help stabilize the training procedure. It is easy to reproduce the corresponding experiments with the code in this repository; the relevant code is in gan.py.

Brief Introduction

Generative adversarial networks (GANs) are among the most powerful generative models and are able to generate realistic, high-resolution images.

Typically, a GAN consists of a generator that produces fake images from a prior distribution (e.g., a Gaussian) and a discriminator that distinguishes real images from fake ones. The two networks are updated alternately to optimize the following minimax objective (note that the discriminator is in the inner loop):

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$

In practice, the generator is trained to maximize $\log D(G(z))$ instead of minimizing $\log(1 - D(G(z)))$, which prevents vanishing gradients early in training and speeds up the training procedure.
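As a minimal sketch of this alternating update with the non-saturating generator loss (not the code in gan.py; a generator G, a discriminator D with a sigmoid output of shape (batch, 1), their optimizers, and a batch of real images scaled to [-1, 1] are assumed to exist):

```python
import torch
import torch.nn.functional as F

def train_step(G, D, opt_G, opt_D, real, latent_dim=100):
    """One alternating GAN update; D is assumed to output a probability."""
    batch_size = real.size(0)
    ones = torch.ones(batch_size, 1)
    zeros = torch.zeros(batch_size, 1)

    # Discriminator step: maximize log D(x) + log(1 - D(G(z))).
    z = torch.randn(batch_size, latent_dim)
    fake = G(z).detach()  # block gradients from flowing into G
    d_loss = F.binary_cross_entropy(D(real), ones) + \
             F.binary_cross_entropy(D(fake), zeros)
    opt_D.zero_grad()
    d_loss.backward()
    opt_D.step()

    # Generator step: maximize log D(G(z)) (non-saturating loss),
    # implemented as minimizing BCE against the "real" label.
    z = torch.randn(batch_size, latent_dim)
    g_loss = F.binary_cross_entropy(D(G(z)), ones)
    opt_G.zero_grad()
    g_loss.backward()
    opt_G.step()

    return d_loss.item(), g_loss.item()
```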

Drawbacks of GANs

While GANs can generate realistic images, they are hard to train stably: the hyperparameters usually have to be tuned carefully to make the generator converge. Another common problem of GANs is mode collapse, where the generator only covers a subset of the training distribution.

Some drawbacks of GANs can be summarized as follows:

  1. Sensitive to the learning rate.

  2. Sensitive to model complexity.

  3. Prone to mode collapse.

  4. Requires a sufficient amount of data.

Training tricks that help stabilize the training procedure

  1. Use a fully convolutional architecture.

  2. Normalize input values to [-1, 1] and use tanh as the activation in the last layer of the generator.

  3. Use batch normalization in both the generator and the discriminator (a sketch of tricks 2 and 3 follows this list).
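As an illustration of tricks 2 and 3, here is a minimal PyTorch sketch (the layer sizes and latent dimension are arbitrary choices for illustration, not the networks=mlp configuration used in this repository):

```python
import torch.nn as nn
from torchvision import transforms

# Trick 2: scale MNIST pixels from [0, 1] to [-1, 1] so that real images
# match the tanh output range of the generator.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),                  # PIL image -> tensor in [0, 1]
    transforms.Normalize((0.5,), (0.5,)),   # [0, 1] -> [-1, 1]
])

# Tricks 2 and 3: batch normalization in the hidden layers and tanh as the
# final activation of the generator (hypothetical layer sizes).
generator = nn.Sequential(
    nn.Linear(100, 256),
    nn.BatchNorm1d(256),
    nn.ReLU(inplace=True),
    nn.Linear(256, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(inplace=True),
    nn.Linear(512, 28 * 28),
    nn.Tanh(),                              # outputs in [-1, 1]
)
```

Trick 1 corresponds to replacing the linear layers with (transposed) convolutions, as in the conv_mnist experiments below.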

Experiments

For efficiency, I conduct these experiments on MNIST for 20 epochs. For each experiment, I sweep the learning rates of the generator and the discriminator over all combinations of 1e-3, 5e-4, and 1e-4.

MLP network without batchnorm

The command to run this experiment is:

python run.py -m model=gan datamodule=mnist networks=mlp networks.encoder.batch_norm=false networks.decoder.batch_norm=false model.lrG=1e-3,5e-4,1e-4 model.lrD=1e-3,5e-4,1e-4 exp_name='noBN_lrG_${model.lrG}_lrD_${model.lrD}'

The -m (multirun) option launches one run for every combination of the listed lrG and lrD values. You can also conduct experiments on other datasets such as CelebA by simply setting datamodule=celeba.
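For reference, the comma-separated values define a 3 × 3 grid, so this command launches nine runs in total; in plain Python the swept combinations are:

```python
from itertools import product

# The -m (multirun) flag sweeps the Cartesian product of the comma-separated
# values, i.e. one run per (lrG, lrD) pair -- nine runs for this command.
lrGs = [1e-3, 5e-4, 1e-4]
lrDs = [1e-3, 5e-4, 1e-4]

for lrG, lrD in product(lrGs, lrDs):
    print(f"run: model.lrG={lrG} model.lrD={lrD}")
```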

The experiment results are shown in the following table.

| learning rate | lrD=1e-3 | lrD=5e-4 | lrD=1e-4 |
| --- | --- | --- | --- |
| lrG=1e-3 | mlp+noBN lrG 0.001 lrD 0.001 | mlp+noBN lrG 0.001 lrD 0.0005 | mlp+noBN lrG 0.001 lrD 0.0001 |
| lrG=5e-4 | mlp+noBN lrG 0.0005 lrD 0.001 | mlp+noBN lrG 0.0005 lrD 0.0005 | mlp+noBN lrG 0.0005 lrD 0.0001 |
| lrG=1e-4 | mlp+noBN lrG 0.0001 lrD 0.001 | mlp+noBN lrG 0.0001 lrD 0.0005 | mlp+noBN lrG 0.0001 lrD 0.0001 |

As you can see, only the top-left four experiments generate high-quality images, while the other experiments either do not converge to a meaningful result or suffer from mode collapse, generating only the digit 1. If your GAN does not converge and you are sure there are no bugs in your code, first try tuning the learning rates of the generator and the discriminator.

MLP network with batchnorm

python run.py -m model=gan datamodule=mnist networks=mlp model.lrG=1e-3,5e-4,1e-4 model.lrD=1e-3,5e-4,1e-4 exp_name='mlp_lrG_${model.lrG}_lrD_${model.lrD}'

| learning rate | lrD=1e-3 | lrD=5e-4 | lrD=1e-4 |
| --- | --- | --- | --- |
| lrG=1e-3 | lrG 0.001 lrD 0.001 | lrG 0.001 lrD 0.0005 | lrG 0.001 lrD 0.0001 |
| lrG=5e-4 | lrG 0.0005 lrD 0.001 | lrG 0.0005 lrD 0.0005 | lrG 0.0005 lrD 0.0001 |
| lrG=1e-4 | lrG 0.0001 lrD 0.001 | lrG 0.0001 lrD 0.0005 | lrG 0.0001 lrD 0.0001 |

Convolutional network without batchnorm

python run.py -m model=gan datamodule=mnist networks=conv_mnist networks.encoder.batch_norm=false networks.decoder.batch_norm=false model.lrG=1e-3,5e-4,1e-4 model.lrD=1e-3,5e-4,1e-4 exp_name='conv_noBN_lrG_${model.lrG}_lrD_${model.lrD}'
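For reference, a fully convolutional (DCGAN-style) MNIST generator without batch normalization might look like the following sketch (hypothetical layer sizes; not the networks=conv_mnist architecture from this repository):

```python
import torch.nn as nn

# A DCGAN-style generator for 28x28 MNIST images, without batch normalization
# to match the experiment above. The latent vector is reshaped to
# (N, 100, 1, 1) and upsampled with transposed convolutions.
conv_generator = nn.Sequential(
    nn.ConvTranspose2d(100, 128, kernel_size=7, stride=1, padding=0),  # 1x1 -> 7x7
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),   # 7x7 -> 14x14
    nn.ReLU(inplace=True),
    nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),     # 14x14 -> 28x28
    nn.Tanh(),                                                         # outputs in [-1, 1]
)
```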

| learning rate | lrD=1e-3 | lrD=5e-4 | lrD=1e-4 |
| --- | --- | --- | --- |
| lrG=1e-3 | noBN lrG 0.001 lrD 0.001 | noBN lrG 0.001 lrD 0.0005 | noBN lrG 0.001 lrD 0.0001 |
| lrG=5e-4 | noBN lrG 0.0005 lrD 0.001 | noBN lrG 0.0005 lrD 0.0005 | noBN lrG 0.0005 lrD 0.0001 |
| lrG=1e-4 | noBN lrG 0.0001 lrD 0.001 | noBN lrG 0.0001 lrD 0.0005 | noBN lrG 0.0001 lrD 0.0001 |

Convolutional network with batchnorm

python run.py -m model=gan datamodule=mnist networks=conv_mnist model.lrG=1e-3,5e-4,1e-4 model.lrD=1e-3,5e-4,1e-4 exp_name='Conv_bn_lrG_${model.lrG}_lrD_${model.lrD}'

| learning rate | lrD=1e-3 | lrD=5e-4 | lrD=1e-4 |
| --- | --- | --- | --- |
| lrG=1e-3 | lrG 0.001 lrD 0.001 | lrG 0.001 lrD 0.0005 | lrG 0.001 lrD 0.0001 |
| lrG=5e-4 | lrG 0.0005 lrD 0.001 | lrG 0.0005 lrD 0.0005 | lrG 0.0005 lrD 0.0001 |
| lrG=1e-4 | lrG 0.0001 lrD 0.001 | lrG 0.0001 lrD 0.0005 | lrG 0.0001 lrD 0.0001 |

It can be observed that with batch normalization, GANs are much more stable across different learning rates.
