vanilla_gan
In this article, I want to share some characteristics of GANs and practical tricks that help stabilize the training procedure. In particular, it is easy to reproduce the corresponding experiments with the code in this repository; the relevant code is in gan.py.
Generative adversarial networks (GANs) are one of the most powerful families of generative models and are able to generate high-resolution, realistic images.
Typically, a GAN consists of a generator that produces fake images from a prior distribution such as a Gaussian and a discriminator that distinguishes real images from fake ones. These two networks are updated alternately to optimize the following objective (note that the discriminator is in the inner loop):
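$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]$$

This is the standard minimax objective from the original GAN paper (Goodfellow et al., 2014).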
In practice, the generator maximizes $\log D(G(z))$ instead of minimizing $\log(1 - D(G(z)))$, which prevents vanishing gradients and speeds up training.
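For illustration, here is a minimal sketch of one alternating update with this non-saturating generator loss. It assumes a plain PyTorch setup in which the discriminator outputs a probability per image; the function and its arguments are illustrative, not the exact code in gan.py.

```python
import torch
import torch.nn.functional as F

def gan_training_step(G, D, optG, optD, real, z_dim=100):
    """One alternating update: discriminator first, then generator (non-saturating loss)."""
    batch_size, device = real.size(0), real.device
    ones = torch.ones(batch_size, 1, device=device)
    zeros = torch.zeros(batch_size, 1, device=device)

    # Discriminator step: push D(real) -> 1 and D(fake) -> 0.
    z = torch.randn(batch_size, z_dim, device=device)
    fake = G(z).detach()  # do not backpropagate into the generator here
    d_loss = F.binary_cross_entropy(D(real), ones) + F.binary_cross_entropy(D(fake), zeros)
    optD.zero_grad()
    d_loss.backward()
    optD.step()

    # Generator step: maximize log D(G(z)) by pushing D(G(z)) -> 1.
    z = torch.randn(batch_size, z_dim, device=device)
    g_loss = F.binary_cross_entropy(D(G(z)), ones)
    optG.zero_grad()
    g_loss.backward()
    optG.step()

    return d_loss.item(), g_loss.item()
```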
While GANs can generate realistic images, they are difficult to train to convergence. The hyperparameters of the networks often need careful tuning to make the generator converge. Another common problem of GANs is mode collapse, where only part of the training distribution is ever generated.
Some drawbacks of GANs can be summarized as follows:

1. Sensitive to the learning rate.
2. Sensitive to model complexity.
3. Mode collapse.
4. Requires enough data.
Some practical tricks that help stabilize training (a minimal sketch applying them follows this list):

- Use a fully convolutional architecture.
- Normalize input values to [-1, 1] and use tanh as the activation of the last layer of the generator.
- Use batch normalization in both the generator and the discriminator.
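Below is a minimal sketch of a generator that applies all three tricks for 28×28 MNIST images. It assumes a plain PyTorch setup; the class name `ConvGenerator` and its arguments are illustrative and not the exact architecture used in this repository.

```python
import torch.nn as nn

class ConvGenerator(nn.Module):
    """DCGAN-style generator: fully convolutional, batch norm, tanh output in [-1, 1]."""

    def __init__(self, z_dim=100, base_channels=64):
        super().__init__()
        self.net = nn.Sequential(
            # (N, z_dim, 1, 1) -> (N, 4*base, 7, 7)
            nn.ConvTranspose2d(z_dim, base_channels * 4, kernel_size=7, stride=1, padding=0),
            nn.BatchNorm2d(base_channels * 4),
            nn.ReLU(inplace=True),
            # -> (N, 2*base, 14, 14)
            nn.ConvTranspose2d(base_channels * 4, base_channels * 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(base_channels * 2),
            nn.ReLU(inplace=True),
            # -> (N, 1, 28, 28)
            nn.ConvTranspose2d(base_channels * 2, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),  # outputs in [-1, 1], matching the input normalization
        )

    def forward(self, z):
        # Reshape the latent vector to a 1x1 spatial map before the transposed convolutions.
        return self.net(z.view(z.size(0), z.size(1), 1, 1))
```

On the input side, with torchvision, applying `transforms.Normalize((0.5,), (0.5,))` after `transforms.ToTensor()` maps MNIST pixels from [0, 1] to [-1, 1], which matches the tanh output range.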
In this article, I run these experiments on MNIST for 20 epochs for efficiency. For each experiment, I sweep the learning rates of the generator and the discriminator over 1e-3, 5e-4, and 1e-4 in all combinations.
The command to run this experiment is:

```
python run.py -m model=gan datamodule=mnist networks=mlp networks.encoder.batch_norm=false networks.decoder.batch_norm=false model.lrG=1e-3,5e-4,1e-4 model.lrD=1e-3,5e-4,1e-4 exp_name='noBN_lrG_${model.lrG}_lrD_${model.lrD}'
```
The `-m` option runs all combinations of the different `lrG` and `lrD` values in one multirun. You can also run the experiments on other datasets such as CelebA by simply setting `datamodule=celeba`.
The experiment results are shown in the following table.
learning rate | lrD=1e-3 | lrD=5e-4 | lrD=1e-4 |
---|---|---|---|
lrG=1e-3 | (generated samples) | (generated samples) | (generated samples) |
lrG=5e-4 | (generated samples) | (generated samples) | (generated samples) |
lrG=1e-4 | (generated samples) | (generated samples) | (generated samples) |
As you can see, only the top-left four experiments generate high-quality images, while the other experiments either do not converge to a meaningful result or suffer from mode collapse and generate only the digit 1. If your GAN does not converge, first tune the learning rates of the generator and the discriminator once you are sure there are no bugs in your code.
The same sweep with the MLP networks, this time without disabling batch normalization:

```
python run.py -m model=gan datamodule=mnist networks=mlp model.lrG=1e-3,5e-4,1e-4 model.lrD=1e-3,5e-4,1e-4 exp_name='mlp_lrG_${model.lrG}_lrD_${model.lrD}'
```
learning rate | lrD=1e-3 | lrD=5e-4 | lrD=1e-4 |
---|---|---|---|
lrG=1e-3 | (generated samples) | (generated samples) | (generated samples) |
lrG=5e-4 | (generated samples) | (generated samples) | (generated samples) |
lrG=1e-4 | (generated samples) | (generated samples) | (generated samples) |
The sweep with the convolutional networks and batch normalization disabled:

```
python run.py -m model=gan datamodule=mnist networks=conv_mnist networks.encoder.batch_norm=false networks.decoder.batch_norm=false model.lrG=1e-3,5e-4,1e-4 model.lrD=1e-3,5e-4,1e-4 exp_name='conv_noBN_lrG_${model.lrG}_lrD_${model.lrD}'
```
learning rate | lrD=1e-3 | lrD=5e-4 | lrD=1e-4 |
---|---|---|---|
lrG=1e-3 | (generated samples) | (generated samples) | (generated samples) |
lrG=5e-4 | (generated samples) | (generated samples) | (generated samples) |
lrG=1e-4 | (generated samples) | (generated samples) | (generated samples) |
The sweep with the convolutional networks and batch normalization enabled:

```
python run.py -m model=gan datamodule=mnist networks=conv_mnist model.lrG=1e-3,5e-4,1e-4 model.lrD=1e-3,5e-4,1e-4 exp_name='Conv_bn_lrG_${model.lrG}_lrD_${model.lrD}'
```
learning rate | lrD=1e-3 | lrD=5e-4 | lrD=1e-4 |
---|---|---|---|
lrG=1e-3 | (generated samples) | (generated samples) | (generated samples) |
lrG=5e-4 | (generated samples) | (generated samples) | (generated samples) |
lrG=1e-4 | (generated samples) | (generated samples) | (generated samples) |
It can be observed that, with batch normalization, GANs are much more stable across different learning rates.