diff --git a/docs/api/python/gluon/contrib.md b/docs/api/python/gluon/contrib.md index b893d5841254..790f6b496516 100644 --- a/docs/api/python/gluon/contrib.md +++ b/docs/api/python/gluon/contrib.md @@ -54,6 +54,9 @@ In the rest of this document, we list routines provided by the `gluon.contrib` p Identity SparseEmbedding SyncBatchNorm + PixelShuffle1D + PixelShuffle2D + PixelShuffle3D ``` ### Recurrent neural network diff --git a/example/gluon/super_resolution/super_resolution.py b/example/gluon/super_resolution/super_resolution.py index 0f2f21f3c0a7..198f6fe0611b 100644 --- a/example/gluon/super_resolution/super_resolution.py +++ b/example/gluon/super_resolution/super_resolution.py @@ -16,19 +16,27 @@ # under the License. from __future__ import print_function -import argparse, tarfile + +import argparse import math import os +import shutil +import sys +import zipfile +from os import path + import numpy as np import mxnet as mx -import mxnet.ndarray as F -from mxnet import gluon +from mxnet import gluon, autograd as ag from mxnet.gluon import nn -from mxnet import autograd as ag -from mxnet.test_utils import download +from mxnet.gluon.contrib import nn as contrib_nn from mxnet.image import CenterCropAug, ResizeAug from mxnet.io import PrefetchingIter +from mxnet.test_utils import download + +this_dir = path.abspath(path.dirname(__file__)) +sys.path.append(path.join(this_dir, path.pardir)) from data import ImagePairIter @@ -51,19 +59,45 @@ batch_size, test_batch_size = opt.batch_size, opt.test_batch_size color_flag = 0 -# get data -dataset_path = "dataset" -dataset_url = "http://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz" -def get_dataset(prefetch=False): - image_path = os.path.join(dataset_path, "BSDS300/images") +# Get data +datasets_dir = path.expanduser(path.join("~", ".mxnet", "datasets")) +datasets_tmpdir = path.join(datasets_dir, "tmp") +dataset_url = "https://github.com/BIDS/BSDS500/archive/master.zip" +data_dir = path.expanduser(path.join(datasets_dir, "BSDS500")) +tmp_dir = path.join(data_dir, "tmp") - if not os.path.exists(image_path): - os.makedirs(dataset_path) - file_name = download(dataset_url) - with tarfile.open(file_name) as tar: - for item in tar: - tar.extract(item, dataset_path) - os.remove(file_name) +def get_dataset(prefetch=False): + """Download the BSDS500 dataset and return train and test iters.""" + + if path.exists(data_dir): + print( + "Directory {} already exists, skipping.\n" + "To force download and extraction, delete the directory and re-run." + "".format(data_dir), + file=sys.stderr, + ) + else: + print("Downloading dataset...", file=sys.stderr) + downloaded_file = download(dataset_url, dirname=datasets_tmpdir) + print("done", file=sys.stderr) + + print("Extracting files...", end="", file=sys.stderr) + os.makedirs(data_dir) + os.makedirs(tmp_dir) + with zipfile.ZipFile(downloaded_file) as archive: + archive.extractall(tmp_dir) + shutil.rmtree(datasets_tmpdir) + + shutil.copytree( + path.join(tmp_dir, "BSDS500-master", "BSDS500", "data", "images"), + path.join(data_dir, "images"), + ) + shutil.copytree( + path.join(tmp_dir, "BSDS500-master", "BSDS500", "data", "groundTruth"), + path.join(data_dir, "groundTruth"), + ) + shutil.rmtree(tmp_dir) + print("done", file=sys.stderr) crop_size = 256 crop_size -= crop_size % upscale_factor @@ -72,15 +106,26 @@ def get_dataset(prefetch=False): input_transform = [CenterCropAug((crop_size, crop_size)), ResizeAug(input_crop_size)] target_transform = [CenterCropAug((crop_size, crop_size))] - iters = (ImagePairIter(os.path.join(image_path, "train"), - (input_crop_size, input_crop_size), - (crop_size, crop_size), - batch_size, color_flag, input_transform, target_transform), - ImagePairIter(os.path.join(image_path, "test"), - (input_crop_size, input_crop_size), - (crop_size, crop_size), - test_batch_size, color_flag, - input_transform, target_transform)) + iters = ( + ImagePairIter( + path.join(data_dir, "images", "train"), + (input_crop_size, input_crop_size), + (crop_size, crop_size), + batch_size, + color_flag, + input_transform, + target_transform, + ), + ImagePairIter( + path.join(data_dir, "images", "test"), + (input_crop_size, input_crop_size), + (crop_size, crop_size), + test_batch_size, + color_flag, + input_transform, + target_transform, + ), + ) return [PrefetchingIter(i) for i in iters] if prefetch else iters @@ -90,33 +135,23 @@ def get_dataset(prefetch=False): ctx = [mx.gpu(0)] if opt.use_gpu else [mx.cpu()] -# define model -def _rearrange(raw, F, upscale_factor): - # (N, C * r^2, H, W) -> (N, C, r^2, H, W) - splitted = F.reshape(raw, shape=(0, -4, -1, upscale_factor**2, 0, 0)) - # (N, C, r^2, H, W) -> (N, C, r, r, H, W) - unflatten = F.reshape(splitted, shape=(0, 0, -4, upscale_factor, upscale_factor, 0, 0)) - # (N, C, r, r, H, W) -> (N, C, H, r, W, r) - swapped = F.transpose(unflatten, axes=(0, 1, 4, 2, 5, 3)) - # (N, C, H, r, W, r) -> (N, C, H*r, W*r) - return F.reshape(swapped, shape=(0, 0, -3, -3)) - - -class SuperResolutionNet(gluon.Block): +class SuperResolutionNet(gluon.HybridBlock): def __init__(self, upscale_factor): super(SuperResolutionNet, self).__init__() with self.name_scope(): - self.conv1 = nn.Conv2D(64, (5, 5), strides=(1, 1), padding=(2, 2)) - self.conv2 = nn.Conv2D(64, (3, 3), strides=(1, 1), padding=(1, 1)) - self.conv3 = nn.Conv2D(32, (3, 3), strides=(1, 1), padding=(1, 1)) + self.conv1 = nn.Conv2D(64, (5, 5), strides=(1, 1), padding=(2, 2), activation='relu') + self.conv2 = nn.Conv2D(64, (3, 3), strides=(1, 1), padding=(1, 1), activation='relu') + self.conv3 = nn.Conv2D(32, (3, 3), strides=(1, 1), padding=(1, 1), activation='relu') self.conv4 = nn.Conv2D(upscale_factor ** 2, (3, 3), strides=(1, 1), padding=(1, 1)) - self.upscale_factor = upscale_factor + self.pxshuf = contrib_nn.PixelShuffle2D(upscale_factor) - def forward(self, x): - x = F.Activation(self.conv1(x), act_type='relu') - x = F.Activation(self.conv2(x), act_type='relu') - x = F.Activation(self.conv3(x), act_type='relu') - return _rearrange(self.conv4(x), F, self.upscale_factor) + def hybrid_forward(self, F, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + x = self.pxshuf(x) + return x net = SuperResolutionNet(upscale_factor) metric = mx.metric.MSE() @@ -136,7 +171,7 @@ def test(ctx): avg_psnr += 10 * math.log10(1/metric.get()[1]) metric.reset() avg_psnr /= batches - print('validation avg psnr: %f'%avg_psnr) + print('validation avg psnr: %f' % avg_psnr) def train(epoch, ctx): @@ -168,13 +203,18 @@ def train(epoch, ctx): print('training mse at epoch %d: %s=%f'%(i, name, acc)) test(ctx) - net.save_parameters('superres.params') + net.save_parameters(path.join(this_dir, 'superres.params')) def resolve(ctx): from PIL import Image + if isinstance(ctx, list): ctx = [ctx[0]] - net.load_parameters('superres.params', ctx=ctx) + + img_basename = path.splitext(path.basename(opt.resolve_img))[0] + img_dirname = path.dirname(opt.resolve_img) + + net.load_parameters(path.join(this_dir, 'superres.params'), ctx=ctx) img = Image.open(opt.resolve_img).convert('YCbCr') y, cb, cr = img.split() data = mx.nd.expand_dims(mx.nd.expand_dims(mx.nd.array(y), axis=0), axis=0) @@ -186,7 +226,7 @@ def resolve(ctx): out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC) out_img = Image.merge('YCbCr', [out_img_y, out_img_cb, out_img_cr]).convert('RGB') - out_img.save('resolved.png') + out_img.save(path.join(img_dirname, '{}-resolved.png'.format(img_basename))) if opt.resolve_img: resolve(ctx) diff --git a/python/mxnet/gluon/contrib/nn/__init__.py b/python/mxnet/gluon/contrib/nn/__init__.py index 62440cda27e2..5eb46f6c08ec 100644 --- a/python/mxnet/gluon/contrib/nn/__init__.py +++ b/python/mxnet/gluon/contrib/nn/__init__.py @@ -17,7 +17,7 @@ # coding: utf-8 # pylint: disable=wildcard-import -"""Contrib recurrent neural network module.""" +"""Contributed neural network modules.""" from . import basic_layers diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py index 56f0809b345f..ebe136e30208 100644 --- a/python/mxnet/gluon/contrib/nn/basic_layers.py +++ b/python/mxnet/gluon/contrib/nn/basic_layers.py @@ -18,8 +18,10 @@ # coding: utf-8 # pylint: disable= arguments-differ """Custom neural network layers in model_zoo.""" + __all__ = ['Concurrent', 'HybridConcurrent', 'Identity', 'SparseEmbedding', - 'SyncBatchNorm'] + 'SyncBatchNorm', 'PixelShuffle1D', 'PixelShuffle2D', + 'PixelShuffle3D'] import warnings from .... import nd, test_utils @@ -238,3 +240,180 @@ def _get_num_devices(self): def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var): return F.contrib.SyncBatchNorm(x, gamma, beta, running_mean, running_var, name='fwd', **self._kwargs) + +class PixelShuffle1D(HybridBlock): + + r"""Pixel-shuffle layer for upsampling in 1 dimension. + + Pixel-shuffling is the operation of taking groups of values along + the *channel* dimension and regrouping them into blocks of pixels + along the ``W`` dimension, thereby effectively multiplying that dimension + by a constant factor in size. + + For example, a feature map of shape :math:`(fC, W)` is reshaped + into :math:`(C, fW)` by forming little value groups of size :math:`f` + and arranging them in a grid of size :math:`W`. + + Parameters + ---------- + factor : int or 1-tuple of int + Upsampling factor, applied to the ``W`` dimension. + + Inputs: + - **data**: Tensor of shape ``(N, f*C, W)``. + Outputs: + - **out**: Tensor of shape ``(N, C, W*f)``. + + Examples + -------- + >>> pxshuf = PixelShuffle1D(2) + >>> x = mx.nd.zeros((1, 8, 3)) + >>> pxshuf(x).shape + (1, 4, 6) + """ + + def __init__(self, factor): + super(PixelShuffle1D, self).__init__() + self._factor = int(factor) + + def hybrid_forward(self, F, x): + """Perform pixel-shuffling on the input.""" + f = self._factor + # (N, C*f, W) + x = F.reshape(x, (0, -4, -1, f, 0)) # (N, C, f, W) + x = F.transpose(x, (0, 1, 3, 2)) # (N, C, W, f) + x = F.reshape(x, (0, 0, -3)) # (N, C, W*f) + return x + + def __repr__(self): + return "{}({})".format(self.__class__.__name__, self._factor) + + +class PixelShuffle2D(HybridBlock): + + r"""Pixel-shuffle layer for upsampling in 2 dimensions. + + Pixel-shuffling is the operation of taking groups of values along + the *channel* dimension and regrouping them into blocks of pixels + along the ``H`` and ``W`` dimensions, thereby effectively multiplying + those dimensions by a constant factor in size. + + For example, a feature map of shape :math:`(f^2 C, H, W)` is reshaped + into :math:`(C, fH, fW)` by forming little :math:`f \times f` blocks + of pixels and arranging them in an :math:`H \times W` grid. + + Pixel-shuffling together with regular convolution is an alternative, + learnable way of upsampling an image by arbitrary factors. It is reported + to help overcome checkerboard artifacts that are common in upsampling with + transposed convolutions (also called deconvolutions). See the paper + `Real-Time Single Image and Video Super-Resolution Using an Efficient + Sub-Pixel Convolutional Neural Network `_ + for further details. + + Parameters + ---------- + factor : int or 2-tuple of int + Upsampling factors, applied to the ``H`` and ``W`` dimensions, + in that order. + + Inputs: + - **data**: Tensor of shape ``(N, f1*f2*C, H, W)``. + Outputs: + - **out**: Tensor of shape ``(N, C, H*f1, W*f2)``. + + Examples + -------- + >>> pxshuf = PixelShuffle2D((2, 3)) + >>> x = mx.nd.zeros((1, 12, 3, 5)) + >>> pxshuf(x).shape + (1, 2, 6, 15) + """ + + def __init__(self, factor): + super(PixelShuffle2D, self).__init__() + try: + self._factors = (int(factor),) * 2 + except TypeError: + self._factors = tuple(int(fac) for fac in factor) + assert len(self._factors) == 2, "wrong length {}".format(len(self._factors)) + + def hybrid_forward(self, F, x): + """Perform pixel-shuffling on the input.""" + f1, f2 = self._factors + # (N, f1*f2*C, H, W) + x = F.reshape(x, (0, -4, -1, f1 * f2, 0, 0)) # (N, C, f1*f2, H, W) + x = F.reshape(x, (0, 0, -4, f1, f2, 0, 0)) # (N, C, f1, f2, H, W) + x = F.transpose(x, (0, 1, 4, 2, 5, 3)) # (N, C, H, f1, W, f2) + x = F.reshape(x, (0, 0, -3, -3)) # (N, C, H*f1, W*f2) + return x + + def __repr__(self): + return "{}({})".format(self.__class__.__name__, self._factors) + + +class PixelShuffle3D(HybridBlock): + + r"""Pixel-shuffle layer for upsampling in 3 dimensions. + + Pixel-shuffling (or voxel-shuffling in 3D) is the operation of taking + groups of values along the *channel* dimension and regrouping them into + blocks of voxels along the ``D``, ``H`` and ``W`` dimensions, thereby + effectively multiplying those dimensions by a constant factor in size. + + For example, a feature map of shape :math:`(f^3 C, D, H, W)` is reshaped + into :math:`(C, fD, fH, fW)` by forming little :math:`f \times f \times f` + blocks of voxels and arranging them in a :math:`D \times H \times W` grid. + + Pixel-shuffling together with regular convolution is an alternative, + learnable way of upsampling an image by arbitrary factors. It is reported + to help overcome checkerboard artifacts that are common in upsampling with + transposed convolutions (also called deconvolutions). See the paper + `Real-Time Single Image and Video Super-Resolution Using an Efficient + Sub-Pixel Convolutional Neural Network `_ + for further details. + + Parameters + ---------- + factor : int or 3-tuple of int + Upsampling factors, applied to the ``D``, ``H`` and ``W`` + dimensions, in that order. + + Inputs: + - **data**: Tensor of shape ``(N, f1*f2*f3*C, D, H, W)``. + Outputs: + - **out**: Tensor of shape ``(N, C, D*f1, H*f2, W*f3)``. + + Examples + -------- + >>> pxshuf = PixelShuffle3D((2, 3, 4)) + >>> x = mx.nd.zeros((1, 48, 3, 5, 7)) + >>> pxshuf(x).shape + (1, 2, 6, 15, 28) + """ + + def __init__(self, factor): + super(PixelShuffle3D, self).__init__() + try: + self._factors = (int(factor),) * 3 + except TypeError: + self._factors = tuple(int(fac) for fac in factor) + assert len(self._factors) == 3, "wrong length {}".format(len(self._factors)) + + def hybrid_forward(self, F, x): + """Perform pixel-shuffling on the input.""" + # `transpose` doesn't support 8D, need other implementation + f1, f2, f3 = self._factors + # (N, C*f1*f2*f3, D, H, W) + x = F.reshape(x, (0, -4, -1, f1 * f2 * f3, 0, 0, 0)) # (N, C, f1*f2*f3, D, H, W) + x = F.swapaxes(x, 2, 3) # (N, C, D, f1*f2*f3, H, W) + x = F.reshape(x, (0, 0, 0, -4, f1, f2*f3, 0, 0)) # (N, C, D, f1, f2*f3, H, W) + x = F.reshape(x, (0, 0, -3, 0, 0, 0)) # (N, C, D*f1, f2*f3, H, W) + x = F.swapaxes(x, 3, 4) # (N, C, D*f1, H, f2*f3, W) + x = F.reshape(x, (0, 0, 0, 0, -4, f2, f3, 0)) # (N, C, D*f1, H, f2, f3, W) + x = F.reshape(x, (0, 0, 0, -3, 0, 0)) # (N, C, D*f1, H*f2, f3, W) + x = F.swapaxes(x, 4, 5) # (N, C, D*f1, H*f2, W, f3) + x = F.reshape(x, (0, 0, 0, 0, -3)) # (N, C, D*f1, H*f2, W*f3) + return x + + def __repr__(self): + return "{}({})".format(self.__class__.__name__, self._factors) diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py index a1cd8ea537d7..6901e8bd12fe 100644 --- a/tests/python/unittest/test_gluon_contrib.py +++ b/tests/python/unittest/test_gluon_contrib.py @@ -19,7 +19,9 @@ import mxnet as mx from mxnet.gluon import contrib from mxnet.gluon import nn -from mxnet.gluon.contrib.nn import Concurrent, HybridConcurrent, Identity, SparseEmbedding +from mxnet.gluon.contrib.nn import ( + Concurrent, HybridConcurrent, Identity, SparseEmbedding, PixelShuffle1D, + PixelShuffle2D, PixelShuffle3D) from mxnet.test_utils import almost_equal from common import setup_module, with_seed, teardown import numpy as np @@ -204,6 +206,89 @@ def test_sparse_embedding(): assert (layer.weight.grad().asnumpy()[:5] == 1).all() assert (layer.weight.grad().asnumpy()[5:] == 0).all() +def test_pixelshuffle1d(): + nchan = 2 + up_x = 2 + nx = 3 + shape_before = (1, nchan * up_x, nx) + shape_after = (1, nchan, nx * up_x) + layer = PixelShuffle1D(up_x) + x = mx.nd.arange(np.prod(shape_before)).reshape(shape_before) + y = layer(x) + assert y.shape == shape_after + assert_allclose( + y.asnumpy(), + [[[0, 3, 1, 4, 2, 5], + [6, 9, 7, 10, 8, 11]]] + ) + +def test_pixelshuffle2d(): + nchan = 2 + up_x = 2 + up_y = 3 + nx = 2 + ny = 3 + shape_before = (1, nchan * up_x * up_y, nx, ny) + shape_after = (1, nchan, nx * up_x, ny * up_y) + layer = PixelShuffle2D((up_x, up_y)) + x = mx.nd.arange(np.prod(shape_before)).reshape(shape_before) + y = layer(x) + assert y.shape == shape_after + # - Channels are reshaped to form 2x3 blocks + # - Within each block, the increment is `nx * ny` when increasing the column + # index by 1 + # - Increasing the block index adds an offset of 1 + # - Increasing the channel index adds an offset of `nx * up_x * ny * up_y` + assert_allclose( + y.asnumpy(), + [[[[ 0, 6, 12, 1, 7, 13, 2, 8, 14], + [18, 24, 30, 19, 25, 31, 20, 26, 32], + [ 3, 9, 15, 4, 10, 16, 5, 11, 17], + [21, 27, 33, 22, 28, 34, 23, 29, 35]], + + [[36, 42, 48, 37, 43, 49, 38, 44, 50], + [54, 60, 66, 55, 61, 67, 56, 62, 68], + [39, 45, 51, 40, 46, 52, 41, 47, 53], + [57, 63, 69, 58, 64, 70, 59, 65, 71]]]] + ) + +def test_pixelshuffle3d(): + nchan = 1 + up_x = 2 + up_y = 1 + up_z = 2 + nx = 2 + ny = 3 + nz = 4 + shape_before = (1, nchan * up_x * up_y * up_z, nx, ny, nz) + shape_after = (1, nchan, nx * up_x, ny * up_y, nz * up_z) + layer = PixelShuffle3D((up_x, up_y, up_z)) + x = mx.nd.arange(np.prod(shape_before)).reshape(shape_before) + y = layer(x) + assert y.shape == shape_after + # - Channels are reshaped to form 2x1x2 blocks + # - Within each block, the increment is `nx * ny * nz` when increasing the + # column index by 1, e.g. the block [[[ 0, 24]], [[48, 72]]] + # - Increasing the block index adds an offset of 1 + assert_allclose( + y.asnumpy(), + [[[[[ 0, 24, 1, 25, 2, 26, 3, 27], + [ 4, 28, 5, 29, 6, 30, 7, 31], + [ 8, 32, 9, 33, 10, 34, 11, 35]], + + [[48, 72, 49, 73, 50, 74, 51, 75], + [52, 76, 53, 77, 54, 78, 55, 79], + [56, 80, 57, 81, 58, 82, 59, 83]], + + [[12, 36, 13, 37, 14, 38, 15, 39], + [16, 40, 17, 41, 18, 42, 19, 43], + [20, 44, 21, 45, 22, 46, 23, 47]], + + [[60, 84, 61, 85, 62, 86, 63, 87], + [64, 88, 65, 89, 66, 90, 67, 91], + [68, 92, 69, 93, 70, 94, 71, 95]]]]] + ) + def test_datasets(): wikitext2_train = contrib.data.text.WikiText2(root='data/wikitext-2', segment='train') wikitext2_val = contrib.data.text.WikiText2(root='data/wikitext-2', segment='validation',