Skip to content

Commit

Permalink
[MXNET-580] Add SN-GAN example (apache#12419)
Browse files Browse the repository at this point in the history
* update sn-gan example

* fix naming

* add more comments

* fix naming and refine comments

* make power iteration as one hyperparameter

* deal with divided by zero problem

* replace 0.00000001 with EPSILON

* refactor the example

* add README

* address the feedback

* refine the composing

* fix the typo, delete the redundant piece of code and update the result image

* update folder name to align with others

* update image name

* add the variable back

* remove the redundant piece of code and fix typo
  • Loading branch information
stu1130 authored and anirudh2290 committed Sep 19, 2018
1 parent b04f802 commit 2a4c865
Show file tree
Hide file tree
Showing 7 changed files with 423 additions and 0 deletions.
1 change: 1 addition & 0 deletions example/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ If your tutorial depends on specific packages, simply add them to this provision
* [Gluon Examples](gluon) - several examples using the Gluon API
* [Style Transfer](gluon/style_transfer) - a style transfer example using gluon
* [Word Language Model](gluon/word_language_model) - an example that trains a multi-layer RNN on the Penn Treebank language modeling benchmark
* [SN-GAN](gluon/sn-gan) - an example that utilizes spectral normalization to train GAN(Generative adversarial network) using Gluon API
* [Image Classification with R](image-classification) - image classification on MNIST,CIFAR,ImageNet-1k,ImageNet-Full, with multiple GPU and distributed training.
* [Kaggle 1st national data science bowl](kaggle-ndsb1) - a MXnet example for Kaggle Nation Data Science Bowl 1
* [Kaggle 2nd national data science bowl](kaggle-ndsb2) - a tutorial for Kaggle Second Nation Data Science Bowl
Expand Down
44 changes: 44 additions & 0 deletions example/gluon/sn_gan/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Spectral Normalization GAN

This example implements [Spectral Normalization for Generative Adversarial Networks](https://arxiv.org/abs/1802.05957) based on [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.

## Usage

Example runs and the results:

```python
python train.py --use-gpu --data-path=data
```

* Note that the program would download the CIFAR10 for you

`python train.py --help` gives the following arguments:

```bash
optional arguments:
-h, --help show this help message and exit
--data-path DATA_PATH
path of data.
--batch-size BATCH_SIZE
training batch size. default is 64.
--epochs EPOCHS number of training epochs. default is 100.
--lr LR learning rate. default is 0.0001.
--lr-beta LR_BETA learning rate for the beta in margin based loss.
default is 0.5.
--use-gpu use gpu for training.
--clip_gr CLIP_GR Clip the gradient by projecting onto the box. default
is 10.0.
--z-dim Z_DIM dimension of the latent z vector. default is 100.
```
## Result
![SN-GAN](sn_gan_output.png)
## Learned Spectral Normalization
![alt text](https://github.com/taki0112/Spectral_Normalization-Tensorflow/blob/master/assests/sn.png)
## Reference
[Simple Tensorflow Implementation](https://github.com/taki0112/Spectral_Normalization-Tensorflow)
42 changes: 42 additions & 0 deletions example/gluon/sn_gan/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# This example is inspired by https://github.com/jason71995/Keras-GAN-Library,
# https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb
# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py

import numpy as np

import mxnet as mx
from mxnet import gluon
from mxnet.gluon.data.vision import CIFAR10

IMAGE_SIZE = 64

def transformer(data, label):
""" data preparation """
data = mx.image.imresize(data, IMAGE_SIZE, IMAGE_SIZE)
data = mx.nd.transpose(data, (2, 0, 1))
data = data.astype(np.float32) / 128.0 - 1
return data, label


def get_training_data(batch_size):
""" helper function to get dataloader"""
return gluon.data.DataLoader(
CIFAR10(train=True, transform=transformer),
batch_size=batch_size, shuffle=True, last_batch='discard')
138 changes: 138 additions & 0 deletions example/gluon/sn_gan/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# This example is inspired by https://github.com/jason71995/Keras-GAN-Library,
# https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb
# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py

import mxnet as mx
from mxnet import nd
from mxnet import gluon
from mxnet.gluon import Block


EPSILON = 1e-08
POWER_ITERATION = 1

class SNConv2D(Block):
""" Customized Conv2D to feed the conv with the weight that we apply spectral normalization """

def __init__(self, num_filter, kernel_size,
strides, padding, in_channels,
ctx=mx.cpu(), iterations=1):

super(SNConv2D, self).__init__()

self.num_filter = num_filter
self.kernel_size = kernel_size
self.strides = strides
self.padding = padding
self.in_channels = in_channels
self.iterations = iterations
self.ctx = ctx

with self.name_scope():
# init the weight
self.weight = self.params.get('weight', shape=(
num_filter, in_channels, kernel_size, kernel_size))
self.u = self.params.get(
'u', init=mx.init.Normal(), shape=(1, num_filter))

def _spectral_norm(self):
""" spectral normalization """
w = self.params.get('weight').data(self.ctx)
w_mat = nd.reshape(w, [w.shape[0], -1])

_u = self.u.data(self.ctx)
_v = None

for _ in range(POWER_ITERATION):
_v = nd.L2Normalization(nd.dot(_u, w_mat))
_u = nd.L2Normalization(nd.dot(_v, w_mat.T))

sigma = nd.sum(nd.dot(_u, w_mat) * _v)
if sigma == 0.:
sigma = EPSILON

self.params.setattr('u', _u)

return w / sigma

def forward(self, x):
# x shape is batch_size x in_channels x height x width
return nd.Convolution(
data=x,
weight=self._spectral_norm(),
kernel=(self.kernel_size, self.kernel_size),
pad=(self.padding, self.padding),
stride=(self.strides, self.strides),
num_filter=self.num_filter,
no_bias=True
)


def get_generator():
""" construct and return generator """
g_net = gluon.nn.Sequential()
with g_net.name_scope():

g_net.add(gluon.nn.Conv2DTranspose(
channels=512, kernel_size=4, strides=1, padding=0, use_bias=False))
g_net.add(gluon.nn.BatchNorm())
g_net.add(gluon.nn.LeakyReLU(0.2))

g_net.add(gluon.nn.Conv2DTranspose(
channels=256, kernel_size=4, strides=2, padding=1, use_bias=False))
g_net.add(gluon.nn.BatchNorm())
g_net.add(gluon.nn.LeakyReLU(0.2))

g_net.add(gluon.nn.Conv2DTranspose(
channels=128, kernel_size=4, strides=2, padding=1, use_bias=False))
g_net.add(gluon.nn.BatchNorm())
g_net.add(gluon.nn.LeakyReLU(0.2))

g_net.add(gluon.nn.Conv2DTranspose(
channels=64, kernel_size=4, strides=2, padding=1, use_bias=False))
g_net.add(gluon.nn.BatchNorm())
g_net.add(gluon.nn.LeakyReLU(0.2))

g_net.add(gluon.nn.Conv2DTranspose(channels=3, kernel_size=4, strides=2, padding=1, use_bias=False))
g_net.add(gluon.nn.Activation('tanh'))

return g_net


def get_descriptor(ctx):
""" construct and return descriptor """
d_net = gluon.nn.Sequential()
with d_net.name_scope():

d_net.add(SNConv2D(num_filter=64, kernel_size=4, strides=2, padding=1, in_channels=3, ctx=ctx))
d_net.add(gluon.nn.LeakyReLU(0.2))

d_net.add(SNConv2D(num_filter=128, kernel_size=4, strides=2, padding=1, in_channels=64, ctx=ctx))
d_net.add(gluon.nn.LeakyReLU(0.2))

d_net.add(SNConv2D(num_filter=256, kernel_size=4, strides=2, padding=1, in_channels=128, ctx=ctx))
d_net.add(gluon.nn.LeakyReLU(0.2))

d_net.add(SNConv2D(num_filter=512, kernel_size=4, strides=2, padding=1, in_channels=256, ctx=ctx))
d_net.add(gluon.nn.LeakyReLU(0.2))

d_net.add(SNConv2D(num_filter=1, kernel_size=4, strides=1, padding=0, in_channels=512, ctx=ctx))

return d_net
Binary file added example/gluon/sn_gan/sn_gan_output.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
149 changes: 149 additions & 0 deletions example/gluon/sn_gan/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# This example is inspired by https://github.com/jason71995/Keras-GAN-Library,
# https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb
# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py


import os
import random
import logging
import argparse

from data import get_training_data
from model import get_generator, get_descriptor
from utils import save_image

import mxnet as mx
from mxnet import nd, autograd
from mxnet import gluon

# CLI
parser = argparse.ArgumentParser(
description='train a model for Spectral Normalization GAN.')
parser.add_argument('--data-path', type=str, default='./data',
help='path of data.')
parser.add_argument('--batch-size', type=int, default=64,
help='training batch size. default is 64.')
parser.add_argument('--epochs', type=int, default=100,
help='number of training epochs. default is 100.')
parser.add_argument('--lr', type=float, default=0.0001,
help='learning rate. default is 0.0001.')
parser.add_argument('--lr-beta', type=float, default=0.5,
help='learning rate for the beta in margin based loss. default is 0.5.')
parser.add_argument('--use-gpu', action='store_true',
help='use gpu for training.')
parser.add_argument('--clip_gr', type=float, default=10.0,
help='Clip the gradient by projecting onto the box. default is 10.0.')
parser.add_argument('--z-dim', type=int, default=10,
help='dimension of the latent z vector. default is 100.')
opt = parser.parse_args()

BATCH_SIZE = opt.batch_size
Z_DIM = opt.z_dim
NUM_EPOCHS = opt.epochs
LEARNING_RATE = opt.lr
BETA = opt.lr_beta
OUTPUT_DIR = opt.data_path
CTX = mx.gpu() if opt.use_gpu else mx.cpu()
CLIP_GRADIENT = opt.clip_gr
IMAGE_SIZE = 64


def facc(label, pred):
""" evaluate accuracy """
pred = pred.ravel()
label = label.ravel()
return ((pred > 0.5) == label).mean()


# setting
mx.random.seed(random.randint(1, 10000))
logging.basicConfig(level=logging.DEBUG)

# create output dir
try:
os.makedirs(opt.data_path)
except OSError:
pass

# get training data
train_data = get_training_data(opt.batch_size)

# get model
g_net = get_generator()
d_net = get_descriptor(CTX)

# define loss function
loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()

# initialization
g_net.collect_params().initialize(mx.init.Xavier(), ctx=CTX)
d_net.collect_params().initialize(mx.init.Xavier(), ctx=CTX)
g_trainer = gluon.Trainer(
g_net.collect_params(), 'Adam', {'learning_rate': LEARNING_RATE, 'beta1': BETA, 'clip_gradient': CLIP_GRADIENT})
d_trainer = gluon.Trainer(
d_net.collect_params(), 'Adam', {'learning_rate': LEARNING_RATE, 'beta1': BETA, 'clip_gradient': CLIP_GRADIENT})
g_net.collect_params().zero_grad()
d_net.collect_params().zero_grad()
# define evaluation metric
metric = mx.metric.CustomMetric(facc)
# initialize labels
real_label = nd.ones(BATCH_SIZE, CTX)
fake_label = nd.zeros(BATCH_SIZE, CTX)

for epoch in range(NUM_EPOCHS):
for i, (d, _) in enumerate(train_data):
# update D
data = d.as_in_context(CTX)
noise = nd.normal(loc=0, scale=1, shape=(
BATCH_SIZE, Z_DIM, 1, 1), ctx=CTX)
with autograd.record():
# train with real image
output = d_net(data).reshape((-1, 1))
errD_real = loss(output, real_label)
metric.update([real_label, ], [output, ])

# train with fake image
fake_image = g_net(noise)
output = d_net(fake_image.detach()).reshape((-1, 1))
errD_fake = loss(output, fake_label)
errD = errD_real + errD_fake
errD.backward()
metric.update([fake_label, ], [output, ])

d_trainer.step(BATCH_SIZE)
# update G
with autograd.record():
fake_image = g_net(noise)
output = d_net(fake_image).reshape(-1, 1)
errG = loss(output, real_label)
errG.backward()

g_trainer.step(BATCH_SIZE)

# print log infomation every 100 batches
if i % 100 == 0:
name, acc = metric.get()
logging.info('discriminator loss = %f, generator loss = %f, \
binary training acc = %f at iter %d epoch %d',
nd.mean(errD).asscalar(), nd.mean(errG).asscalar(), acc, i, epoch)
if i == 0:
save_image(fake_image, epoch, IMAGE_SIZE, BATCH_SIZE, OUTPUT_DIR)

metric.reset()
Loading

0 comments on commit 2a4c865

Please sign in to comment.