Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-580] Add SN-GAN example #12419

Merged
merged 19 commits into from
Sep 12, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions example/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ If your tutorial depends on specific packages, simply add them to this provision
* [Gluon Examples](gluon) - several examples using the Gluon API
* [Style Transfer](gluon/style_transfer) - a style transfer example using gluon
* [Word Language Model](gluon/word_language_model) - an example that trains a multi-layer RNN on the Penn Treebank language modeling benchmark
* [SN-GAN](gluon/sn_gan) - an example that uses spectral normalization to train a GAN (generative adversarial network) with the Gluon API
* [Image Classification with R](image-classification) - image classification on MNIST,CIFAR,ImageNet-1k,ImageNet-Full, with multiple GPU and distributed training.
* [Kaggle 1st national data science bowl](kaggle-ndsb1) - a MXnet example for Kaggle Nation Data Science Bowl 1
* [Kaggle 2nd national data science bowl](kaggle-ndsb2) - a tutorial for Kaggle Second Nation Data Science Bowl
Expand Down
44 changes: 44 additions & 0 deletions example/gluon/sn_gan/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Spectral Normalization GAN

This example implements [Spectral Normalization for Generative Adversarial Networks](https://arxiv.org/abs/1802.05957) based on [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.

## Usage

Example runs and the results:

```bash
python train.py --use-gpu --data-path=data
```

* Note that the program downloads the CIFAR10 dataset for you.

`python train.py --help` gives the following arguments:

```bash
optional arguments:
-h, --help show this help message and exit
--data-path DATA_PATH
path of data.
--batch-size BATCH_SIZE
training batch size. default is 64.
--epochs EPOCHS number of training epochs. default is 100.
--lr LR learning rate. default is 0.0001.
--lr-beta LR_BETA learning rate for the beta in margin based loss.
default is 0.5.
--use-gpu use gpu for training.
--clip_gr CLIP_GR Clip the gradient by projecting onto the box. default
is 10.0.
--z-dim Z_DIM dimension of the latent z vector. default is 100.
```

## Result

![SN-GAN](sn_gan_output.png)

## Learned Spectral Normalization

![Spectral normalization](https://github.com/taki0112/Spectral_Normalization-Tensorflow/blob/master/assests/sn.png?raw=true)

## Reference

[Simple Tensorflow Implementation](https://github.com/taki0112/Spectral_Normalization-Tensorflow)
42 changes: 42 additions & 0 deletions example/gluon/sn_gan/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# This example is inspired by https://github.com/jason71995/Keras-GAN-Library,
# https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb
# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py

import numpy as np

import mxnet as mx
from mxnet import gluon
from mxnet.gluon.data.vision import CIFAR10

IMAGE_SIZE = 64

def transformer(data, label):
    """Prepare one (image, label) sample for training.

    The image is resized to IMAGE_SIZE x IMAGE_SIZE, its last axis is moved
    to the front (channels first, assuming an HWC input image), and the
    values are cast to float32 and rescaled with ``x / 128 - 1``.
    The label is passed through unchanged.
    """
    resized = mx.image.imresize(data, IMAGE_SIZE, IMAGE_SIZE)
    channels_first = mx.nd.transpose(resized, (2, 0, 1))
    scaled = channels_first.astype(np.float32) / 128.0 - 1
    return scaled, label


def get_training_data(batch_size):
    """Return a shuffled DataLoader over the CIFAR10 training split.

    Every sample is passed through ``transformer``; a trailing batch smaller
    than ``batch_size`` is discarded.
    """
    dataset = CIFAR10(train=True, transform=transformer)
    loader = gluon.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, last_batch='discard')
    return loader
138 changes: 138 additions & 0 deletions example/gluon/sn_gan/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# This example is inspired by https://github.com/jason71995/Keras-GAN-Library,
# https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb
# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py

import mxnet as mx
from mxnet import nd
from mxnet import gluon
from mxnet.gluon import Block


EPSILON = 1e-08
POWER_ITERATION = 1

class SNConv2D(Block):
    """Bias-free 2-D convolution whose weight is spectrally normalized on every forward pass.

    The layer owns two parameters:

    * ``weight`` with shape (num_filter, in_channels, kernel_size, kernel_size)
    * ``u`` with shape (1, num_filter), the running vector used by the power
      iteration that estimates the largest singular value of the weight.

    Parameters
    ----------
    num_filter : int
        Number of output channels.
    kernel_size : int
        Side length of the (square) kernel.
    strides : int
        Stride applied to both spatial dimensions.
    padding : int
        Padding applied to both spatial dimensions.
    in_channels : int
        Number of input channels.
    ctx : mxnet.Context
        Context from which the parameters are read in ``forward``.
    iterations : int
        Number of power-iteration steps per forward pass (must be >= 1).
    """

    def __init__(self, num_filter, kernel_size,
                 strides, padding, in_channels,
                 ctx=mx.cpu(), iterations=1):

        super(SNConv2D, self).__init__()

        if iterations < 1:
            # with zero steps `_v` would never be computed in _spectral_norm
            raise ValueError('iterations must be >= 1, got %r' % (iterations,))

        self.num_filter = num_filter
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.in_channels = in_channels
        self.iterations = iterations
        self.ctx = ctx

        with self.name_scope():
            # init the weight
            self.weight = self.params.get('weight', shape=(
                num_filter, in_channels, kernel_size, kernel_size))
            self.u = self.params.get(
                'u', init=mx.init.Normal(), shape=(1, num_filter))

    def _spectral_norm(self):
        """Return the weight divided by an estimate of its spectral norm.

        The weight is flattened to a (num_filter, -1) matrix and
        ``self.iterations`` power-iteration steps are run starting from the
        stored ``u``. The refreshed ``u`` is written back into the parameter
        so that the next call continues from it.
        """
        w = self.weight.data(self.ctx)
        w_mat = nd.reshape(w, [w.shape[0], -1])

        _u = self.u.data(self.ctx)
        _v = None

        for _ in range(self.iterations):
            _v = nd.L2Normalization(nd.dot(_u, w_mat))
            _u = nd.L2Normalization(nd.dot(_v, w_mat.T))

        sigma = nd.sum(nd.dot(_u, w_mat) * _v)
        if sigma == 0.:
            # avoid dividing by zero for an all-zero weight
            sigma = EPSILON

        # ParameterDict.setattr('u', ...) would only attach an attribute named
        # 'u' to every parameter; set_data actually stores the new vector.
        # The write is done outside of autograd recording.
        with mx.autograd.pause():
            self.u.set_data(_u)

        return w / sigma

    def forward(self, x):
        """Convolve ``x`` with the spectrally normalized weight (no bias).

        ``x`` has shape batch_size x in_channels x height x width.
        """
        return nd.Convolution(
            data=x,
            weight=self._spectral_norm(),
            kernel=(self.kernel_size, self.kernel_size),
            pad=(self.padding, self.padding),
            stride=(self.strides, self.strides),
            num_filter=self.num_filter,
            no_bias=True
        )


def get_generator():
    """Build and return the generator network.

    The network is a stack of bias-free transposed convolutions (kernel 4).
    Every stage except the last is followed by BatchNorm and LeakyReLU(0.2);
    the last stage produces 3 channels and is followed by tanh.
    """
    # (channels, strides, padding) of the stages followed by BatchNorm + LeakyReLU
    hidden_stages = (
        (512, 1, 0),
        (256, 2, 1),
        (128, 2, 1),
        (64, 2, 1),
    )

    generator = gluon.nn.Sequential()
    with generator.name_scope():
        for channels, strides, padding in hidden_stages:
            generator.add(gluon.nn.Conv2DTranspose(
                channels=channels, kernel_size=4, strides=strides, padding=padding, use_bias=False))
            generator.add(gluon.nn.BatchNorm())
            generator.add(gluon.nn.LeakyReLU(0.2))

        # output stage: 3 channels squashed by tanh
        generator.add(gluon.nn.Conv2DTranspose(
            channels=3, kernel_size=4, strides=2, padding=1, use_bias=False))
        generator.add(gluon.nn.Activation('tanh'))

    return generator


def get_descriptor(ctx):
    """Build and return the descriptor (discriminator) network.

    Four stride-2 SNConv2D stages, each followed by LeakyReLU(0.2), are
    followed by a final single-filter SNConv2D without activation.
    ``ctx`` is forwarded to every SNConv2D layer.
    """
    # (num_filter, in_channels) of the stages followed by LeakyReLU
    hidden_stages = (
        (64, 3),
        (128, 64),
        (256, 128),
        (512, 256),
    )

    descriptor = gluon.nn.Sequential()
    with descriptor.name_scope():
        for num_filter, in_channels in hidden_stages:
            descriptor.add(SNConv2D(
                num_filter=num_filter, kernel_size=4, strides=2, padding=1,
                in_channels=in_channels, ctx=ctx))
            descriptor.add(gluon.nn.LeakyReLU(0.2))

        # output stage: one filter, no activation
        descriptor.add(SNConv2D(
            num_filter=1, kernel_size=4, strides=1, padding=0, in_channels=512, ctx=ctx))

    return descriptor
Binary file added example/gluon/sn_gan/sn_gan_output.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
149 changes: 149 additions & 0 deletions example/gluon/sn_gan/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# This example is inspired by https://github.com/jason71995/Keras-GAN-Library,
# https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb
# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py


import os
import random
import logging
import argparse

from data import get_training_data
from model import get_generator, get_descriptor
from utils import save_image

import mxnet as mx
from mxnet import nd, autograd
from mxnet import gluon

# CLI
parser = argparse.ArgumentParser(
    description='train a model for Spectral Normalization GAN.')
# NOTE: this path is used below as OUTPUT_DIR, i.e. the directory handed to
# save_image for the generated samples.
parser.add_argument('--data-path', type=str, default='./data',
                    help='path of data.')
parser.add_argument('--batch-size', type=int, default=64,
                    help='training batch size. default is 64.')
parser.add_argument('--epochs', type=int, default=100,
                    help='number of training epochs. default is 100.')
parser.add_argument('--lr', type=float, default=0.0001,
                    help='learning rate. default is 0.0001.')
# NOTE: this value is passed to the Adam optimizers as `beta1`.
parser.add_argument('--lr-beta', type=float, default=0.5,
                    help='learning rate for the beta in margin based loss. default is 0.5.')
parser.add_argument('--use-gpu', action='store_true',
                    help='use gpu for training.')
parser.add_argument('--clip_gr', type=float, default=10.0,
                    help='Clip the gradient by projecting onto the box. default is 10.0.')
# The default now matches the documented value of 100 (it used to be 10
# while the help text and the README both stated 100).
parser.add_argument('--z-dim', type=int, default=100,
                    help='dimension of the latent z vector. default is 100.')
opt = parser.parse_args()

# Module-level settings derived from the command line.
BATCH_SIZE = opt.batch_size
Z_DIM = opt.z_dim
NUM_EPOCHS = opt.epochs
LEARNING_RATE = opt.lr
BETA = opt.lr_beta
OUTPUT_DIR = opt.data_path
CTX = mx.gpu() if opt.use_gpu else mx.cpu()
CLIP_GRADIENT = opt.clip_gr
IMAGE_SIZE = 64


def facc(label, pred):
    """Return the fraction of predictions that match the labels.

    Both inputs are flattened; a prediction counts as positive when it is
    strictly greater than 0.5.
    NOTE(review): the training loop feeds the raw descriptor output (no
    sigmoid applied) into this metric - confirm that 0.5 is the intended
    threshold.
    """
    flat_pred = pred.ravel()
    flat_label = label.ravel()
    matches = (flat_pred > 0.5) == flat_label
    return matches.mean()


# setting
mx.random.seed(random.randint(1, 10000))
logging.basicConfig(level=logging.DEBUG)

# create output dir
try:
    os.makedirs(opt.data_path)
except OSError:
    # An already existing directory is fine; any other failure (e.g. missing
    # permissions, or the path being a regular file) must not be hidden.
    if not os.path.isdir(opt.data_path):
        raise

# get training data
train_data = get_training_data(opt.batch_size)

# get model
g_net = get_generator()
d_net = get_descriptor(CTX)

# define loss function
loss = gluon.loss.SigmoidBinaryCrossEntropyLoss()

# initialization
g_net.collect_params().initialize(mx.init.Xavier(), ctx=CTX)
d_net.collect_params().initialize(mx.init.Xavier(), ctx=CTX)
# both networks are trained with Adam using the same hyper-parameters
g_trainer = gluon.Trainer(
    g_net.collect_params(), 'Adam', {'learning_rate': LEARNING_RATE, 'beta1': BETA, 'clip_gradient': CLIP_GRADIENT})
d_trainer = gluon.Trainer(
    d_net.collect_params(), 'Adam', {'learning_rate': LEARNING_RATE, 'beta1': BETA, 'clip_gradient': CLIP_GRADIENT})
g_net.collect_params().zero_grad()
d_net.collect_params().zero_grad()
# define evaluation metric
metric = mx.metric.CustomMetric(facc)
# initialize labels
real_label = nd.ones(BATCH_SIZE, CTX)
fake_label = nd.zeros(BATCH_SIZE, CTX)

# Alternate one descriptor update and one generator update per batch.
for epoch in range(NUM_EPOCHS):
    for i, (d, _) in enumerate(train_data):
        # update D
        data = d.as_in_context(CTX)
        noise = nd.normal(loc=0, scale=1, shape=(
            BATCH_SIZE, Z_DIM, 1, 1), ctx=CTX)
        with autograd.record():
            # train with real image
            output = d_net(data).reshape((-1, 1))
            errD_real = loss(output, real_label)
            metric.update([real_label, ], [output, ])

            # train with fake image
            fake_image = g_net(noise)
            # detach so that this backward pass does not reach the generator
            output = d_net(fake_image.detach()).reshape((-1, 1))
            errD_fake = loss(output, fake_label)
            errD = errD_real + errD_fake
            errD.backward()
            metric.update([fake_label, ], [output, ])

        d_trainer.step(BATCH_SIZE)
        # update G
        with autograd.record():
            fake_image = g_net(noise)
            output = d_net(fake_image).reshape((-1, 1))
            # the generator is trained to make D label its images as real
            errG = loss(output, real_label)
            errG.backward()

        g_trainer.step(BATCH_SIZE)

        # print log information every 100 batches
        if i % 100 == 0:
            _, acc = metric.get()
            # adjacent literals instead of a backslash continuation, so the
            # message does not depend on how this source line is indented
            logging.info('discriminator loss = %f, generator loss = %f, '
                         'binary training acc = %f at iter %d epoch %d',
                         nd.mean(errD).asscalar(), nd.mean(errG).asscalar(), acc, i, epoch)
        if i == 0:
            # save the samples generated on the first batch of every epoch
            save_image(fake_image, epoch, IMAGE_SIZE, BATCH_SIZE, OUTPUT_DIR)

    metric.reset()
Loading