diff --git a/example/README.md b/example/README.md index 6b9a086ff5e1..2123104a1487 100644 --- a/example/README.md +++ b/example/README.md @@ -95,6 +95,7 @@ If your tutorial depends on specific packages, simply add them to this provision * [Gluon Examples](gluon) - several examples using the Gluon API * [Style Transfer](gluon/style_transfer) - a style transfer example using gluon * [Word Language Model](gluon/word_language_model) - an example that trains a multi-layer RNN on the Penn Treebank language modeling benchmark + * [SN-GAN](gluon/sn-gan) - an example that utilizes spectral normalization to train GAN(Generative adversarial network) using Gluon API * [Image Classification with R](image-classification) - image classification on MNIST,CIFAR,ImageNet-1k,ImageNet-Full, with multiple GPU and distributed training. * [Kaggle 1st national data science bowl](kaggle-ndsb1) - a MXnet example for Kaggle Nation Data Science Bowl 1 * [Kaggle 2nd national data science bowl](kaggle-ndsb2) - a tutorial for Kaggle Second Nation Data Science Bowl diff --git a/example/gluon/sn_gan/README.md b/example/gluon/sn_gan/README.md new file mode 100644 index 000000000000..5b2a750e4efb --- /dev/null +++ b/example/gluon/sn_gan/README.md @@ -0,0 +1,44 @@ +# Spectral Normalization GAN + +This example implements [Spectral Normalization for Generative Adversarial Networks](https://arxiv.org/abs/1802.05957) based on [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset. + +## Usage + +Example runs and the results: + +```python +python train.py --use-gpu --data-path=data +``` + +* Note that the program would download the CIFAR10 for you + +`python train.py --help` gives the following arguments: + +```bash +optional arguments: + -h, --help show this help message and exit + --data-path DATA_PATH + path of data. + --batch-size BATCH_SIZE + training batch size. default is 64. + --epochs EPOCHS number of training epochs. default is 100. + --lr LR learning rate. default is 0.0001. + --lr-beta LR_BETA learning rate for the beta in margin based loss. + default is 0.5. + --use-gpu use gpu for training. + --clip_gr CLIP_GR Clip the gradient by projecting onto the box. default + is 10.0. + --z-dim Z_DIM dimension of the latent z vector. default is 100. +``` + +## Result + +![SN-GAN](sn_gan_output.png) + +## Learned Spectral Normalization + +![alt text](https://github.com/taki0112/Spectral_Normalization-Tensorflow/blob/master/assests/sn.png) + +## Reference + +[Simple Tensorflow Implementation](https://github.com/taki0112/Spectral_Normalization-Tensorflow) \ No newline at end of file diff --git a/example/gluon/sn_gan/data.py b/example/gluon/sn_gan/data.py new file mode 100644 index 000000000000..333125dbe9fe --- /dev/null +++ b/example/gluon/sn_gan/data.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This example is inspired by https://github.com/jason71995/Keras-GAN-Library, +# https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb +# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py + +import numpy as np + +import mxnet as mx +from mxnet import gluon +from mxnet.gluon.data.vision import CIFAR10 + +IMAGE_SIZE = 64 + +def transformer(data, label): + """ data preparation """ + data = mx.image.imresize(data, IMAGE_SIZE, IMAGE_SIZE) + data = mx.nd.transpose(data, (2, 0, 1)) + data = data.astype(np.float32) / 128.0 - 1 + return data, label + + +def get_training_data(batch_size): + """ helper function to get dataloader""" + return gluon.data.DataLoader( + CIFAR10(train=True, transform=transformer), + batch_size=batch_size, shuffle=True, last_batch='discard') diff --git a/example/gluon/sn_gan/model.py b/example/gluon/sn_gan/model.py new file mode 100644 index 000000000000..38f87ebddc8a --- /dev/null +++ b/example/gluon/sn_gan/model.py @@ -0,0 +1,138 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This example is inspired by https://github.com/jason71995/Keras-GAN-Library, +# https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb +# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py + +import mxnet as mx +from mxnet import nd +from mxnet import gluon +from mxnet.gluon import Block + + +EPSILON = 1e-08 +POWER_ITERATION = 1 + +class SNConv2D(Block): + """ Customized Conv2D to feed the conv with the weight that we apply spectral normalization """ + + def __init__(self, num_filter, kernel_size, + strides, padding, in_channels, + ctx=mx.cpu(), iterations=1): + + super(SNConv2D, self).__init__() + + self.num_filter = num_filter + self.kernel_size = kernel_size + self.strides = strides + self.padding = padding + self.in_channels = in_channels + self.iterations = iterations + self.ctx = ctx + + with self.name_scope(): + # init the weight + self.weight = self.params.get('weight', shape=( + num_filter, in_channels, kernel_size, kernel_size)) + self.u = self.params.get( + 'u', init=mx.init.Normal(), shape=(1, num_filter)) + + def _spectral_norm(self): + """ spectral normalization """ + w = self.params.get('weight').data(self.ctx) + w_mat = nd.reshape(w, [w.shape[0], -1]) + + _u = self.u.data(self.ctx) + _v = None + + for _ in range(POWER_ITERATION): + _v = nd.L2Normalization(nd.dot(_u, w_mat)) + _u = nd.L2Normalization(nd.dot(_v, w_mat.T)) + + sigma = nd.sum(nd.dot(_u, w_mat) * _v) + if sigma == 0.: + sigma = EPSILON + + self.params.setattr('u', _u) + + return w / sigma + + def forward(self, x): + # x shape is batch_size x in_channels x height x width + return nd.Convolution( + data=x, + weight=self._spectral_norm(), + kernel=(self.kernel_size, self.kernel_size), + pad=(self.padding, self.padding), + stride=(self.strides, self.strides), + num_filter=self.num_filter, + no_bias=True + ) + + +def get_generator(): + """ construct and return generator """ + g_net = gluon.nn.Sequential() + with g_net.name_scope(): + + g_net.add(gluon.nn.Conv2DTranspose( + channels=512, kernel_size=4, strides=1, padding=0, use_bias=False)) + g_net.add(gluon.nn.BatchNorm()) + g_net.add(gluon.nn.LeakyReLU(0.2)) + + g_net.add(gluon.nn.Conv2DTranspose( + channels=256, kernel_size=4, strides=2, padding=1, use_bias=False)) + g_net.add(gluon.nn.BatchNorm()) + g_net.add(gluon.nn.LeakyReLU(0.2)) + + g_net.add(gluon.nn.Conv2DTranspose( + channels=128, kernel_size=4, strides=2, padding=1, use_bias=False)) + g_net.add(gluon.nn.BatchNorm()) + g_net.add(gluon.nn.LeakyReLU(0.2)) + + g_net.add(gluon.nn.Conv2DTranspose( + channels=64, kernel_size=4, strides=2, padding=1, use_bias=False)) + g_net.add(gluon.nn.BatchNorm()) + g_net.add(gluon.nn.LeakyReLU(0.2)) + + g_net.add(gluon.nn.Conv2DTranspose(channels=3, kernel_size=4, strides=2, padding=1, use_bias=False)) + g_net.add(gluon.nn.Activation('tanh')) + + return g_net + + +def get_descriptor(ctx): + """ construct and return descriptor """ + d_net = gluon.nn.Sequential() + with d_net.name_scope(): + + d_net.add(SNConv2D(num_filter=64, kernel_size=4, strides=2, padding=1, in_channels=3, ctx=ctx)) + d_net.add(gluon.nn.LeakyReLU(0.2)) + + d_net.add(SNConv2D(num_filter=128, kernel_size=4, strides=2, padding=1, in_channels=64, ctx=ctx)) + d_net.add(gluon.nn.LeakyReLU(0.2)) + + d_net.add(SNConv2D(num_filter=256, kernel_size=4, strides=2, padding=1, in_channels=128, ctx=ctx)) + d_net.add(gluon.nn.LeakyReLU(0.2)) + + d_net.add(SNConv2D(num_filter=512, kernel_size=4, strides=2, padding=1, in_channels=256, ctx=ctx)) + d_net.add(gluon.nn.LeakyReLU(0.2)) + + d_net.add(SNConv2D(num_filter=1, kernel_size=4, strides=1, padding=0, in_channels=512, ctx=ctx)) + + return d_net diff --git a/example/gluon/sn_gan/sn_gan_output.png b/example/gluon/sn_gan/sn_gan_output.png new file mode 100644 index 000000000000..428c33315023 Binary files /dev/null and b/example/gluon/sn_gan/sn_gan_output.png differ diff --git a/example/gluon/sn_gan/train.py b/example/gluon/sn_gan/train.py new file mode 100644 index 000000000000..1cba1f57d0a0 --- /dev/null +++ b/example/gluon/sn_gan/train.py @@ -0,0 +1,149 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This example is inspired by https://github.com/jason71995/Keras-GAN-Library, +# https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb +# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py + + +import os +import random +import logging +import argparse + +from data import get_training_data +from model import get_generator, get_descriptor +from utils import save_image + +import mxnet as mx +from mxnet import nd, autograd +from mxnet import gluon + +# CLI +parser = argparse.ArgumentParser( + description='train a model for Spectral Normalization GAN.') +parser.add_argument('--data-path', type=str, default='./data', + help='path of data.') +parser.add_argument('--batch-size', type=int, default=64, + help='training batch size. default is 64.') +parser.add_argument('--epochs', type=int, default=100, + help='number of training epochs. default is 100.') +parser.add_argument('--lr', type=float, default=0.0001, + help='learning rate. default is 0.0001.') +parser.add_argument('--lr-beta', type=float, default=0.5, + help='learning rate for the beta in margin based loss. default is 0.5.') +parser.add_argument('--use-gpu', action='store_true', + help='use gpu for training.') +parser.add_argument('--clip_gr', type=float, default=10.0, + help='Clip the gradient by projecting onto the box. default is 10.0.') +parser.add_argument('--z-dim', type=int, default=10, + help='dimension of the latent z vector. default is 100.') +opt = parser.parse_args() + +BATCH_SIZE = opt.batch_size +Z_DIM = opt.z_dim +NUM_EPOCHS = opt.epochs +LEARNING_RATE = opt.lr +BETA = opt.lr_beta +OUTPUT_DIR = opt.data_path +CTX = mx.gpu() if opt.use_gpu else mx.cpu() +CLIP_GRADIENT = opt.clip_gr +IMAGE_SIZE = 64 + + +def facc(label, pred): + """ evaluate accuracy """ + pred = pred.ravel() + label = label.ravel() + return ((pred > 0.5) == label).mean() + + +# setting +mx.random.seed(random.randint(1, 10000)) +logging.basicConfig(level=logging.DEBUG) + +# create output dir +try: + os.makedirs(opt.data_path) +except OSError: + pass + +# get training data +train_data = get_training_data(opt.batch_size) + +# get model +g_net = get_generator() +d_net = get_descriptor(CTX) + +# define loss function +loss = gluon.loss.SigmoidBinaryCrossEntropyLoss() + +# initialization +g_net.collect_params().initialize(mx.init.Xavier(), ctx=CTX) +d_net.collect_params().initialize(mx.init.Xavier(), ctx=CTX) +g_trainer = gluon.Trainer( + g_net.collect_params(), 'Adam', {'learning_rate': LEARNING_RATE, 'beta1': BETA, 'clip_gradient': CLIP_GRADIENT}) +d_trainer = gluon.Trainer( + d_net.collect_params(), 'Adam', {'learning_rate': LEARNING_RATE, 'beta1': BETA, 'clip_gradient': CLIP_GRADIENT}) +g_net.collect_params().zero_grad() +d_net.collect_params().zero_grad() +# define evaluation metric +metric = mx.metric.CustomMetric(facc) +# initialize labels +real_label = nd.ones(BATCH_SIZE, CTX) +fake_label = nd.zeros(BATCH_SIZE, CTX) + +for epoch in range(NUM_EPOCHS): + for i, (d, _) in enumerate(train_data): + # update D + data = d.as_in_context(CTX) + noise = nd.normal(loc=0, scale=1, shape=( + BATCH_SIZE, Z_DIM, 1, 1), ctx=CTX) + with autograd.record(): + # train with real image + output = d_net(data).reshape((-1, 1)) + errD_real = loss(output, real_label) + metric.update([real_label, ], [output, ]) + + # train with fake image + fake_image = g_net(noise) + output = d_net(fake_image.detach()).reshape((-1, 1)) + errD_fake = loss(output, fake_label) + errD = errD_real + errD_fake + errD.backward() + metric.update([fake_label, ], [output, ]) + + d_trainer.step(BATCH_SIZE) + # update G + with autograd.record(): + fake_image = g_net(noise) + output = d_net(fake_image).reshape(-1, 1) + errG = loss(output, real_label) + errG.backward() + + g_trainer.step(BATCH_SIZE) + + # print log infomation every 100 batches + if i % 100 == 0: + name, acc = metric.get() + logging.info('discriminator loss = %f, generator loss = %f, \ + binary training acc = %f at iter %d epoch %d', + nd.mean(errD).asscalar(), nd.mean(errG).asscalar(), acc, i, epoch) + if i == 0: + save_image(fake_image, epoch, IMAGE_SIZE, BATCH_SIZE, OUTPUT_DIR) + + metric.reset() diff --git a/example/gluon/sn_gan/utils.py b/example/gluon/sn_gan/utils.py new file mode 100644 index 000000000000..d3f1b8626a1a --- /dev/null +++ b/example/gluon/sn_gan/utils.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This example is inspired by https://github.com/jason71995/Keras-GAN-Library, +# https://github.com/kazizzad/DCGAN-Gluon-MxNet/blob/master/MxnetDCGAN.ipynb +# https://github.com/apache/incubator-mxnet/blob/master/example/gluon/dcgan.py + +import math + +import numpy as np +import imageio + +def save_image(data, epoch, image_size, batch_size, output_dir, padding=2): + """ save image """ + data = data.asnumpy().transpose((0, 2, 3, 1)) + datanp = np.clip( + (data - np.min(data))*(255.0/(np.max(data) - np.min(data))), 0, 255).astype(np.uint8) + x_dim = min(8, batch_size) + y_dim = int(math.ceil(float(batch_size) / x_dim)) + height, width = int(image_size + padding), int(image_size + padding) + grid = np.zeros((height * y_dim + 1 + padding // 2, width * + x_dim + 1 + padding // 2, 3), dtype=np.uint8) + k = 0 + for y in range(y_dim): + for x in range(x_dim): + if k >= batch_size: + break + start_y = y * height + 1 + padding // 2 + end_y = start_y + height - padding + start_x = x * width + 1 + padding // 2 + end_x = start_x + width - padding + np.copyto(grid[start_y:end_y, start_x:end_x, :], datanp[k]) + k += 1 + imageio.imwrite( + '{}/fake_samples_epoch_{}.png'.format(output_dir, epoch), grid)