Skip to content

Commit

Permalink
Add pixelshuffle layers (apache#13571)
Browse files Browse the repository at this point in the history
* Add pixelshuffle layers, closes apache#13548

* Remove fmt comments

* Use explicit class in super()

* Add axis swapping to pixel shuffling, add tests

* Add documentation to pixel shuffle layers

* Use pixelshuffle layer and fix download in superres example

* Add pixelshuffle layers to API doc page
  • Loading branch information
kohr-h authored and stephenrawls committed Feb 16, 2019
1 parent ebb8599 commit 3fe2a72
Show file tree
Hide file tree
Showing 5 changed files with 362 additions and 55 deletions.
3 changes: 3 additions & 0 deletions docs/api/python/gluon/contrib.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ In the rest of this document, we list routines provided by the `gluon.contrib` p
Identity
SparseEmbedding
SyncBatchNorm
PixelShuffle1D
PixelShuffle2D
PixelShuffle3D
```

### Recurrent neural network
Expand Down
144 changes: 92 additions & 52 deletions example/gluon/super_resolution/super_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,27 @@
# under the License.

from __future__ import print_function
import argparse, tarfile

import argparse
import math
import os
import shutil
import sys
import zipfile
from os import path

import numpy as np

import mxnet as mx
import mxnet.ndarray as F
from mxnet import gluon
from mxnet import gluon, autograd as ag
from mxnet.gluon import nn
from mxnet import autograd as ag
from mxnet.test_utils import download
from mxnet.gluon.contrib import nn as contrib_nn
from mxnet.image import CenterCropAug, ResizeAug
from mxnet.io import PrefetchingIter
from mxnet.test_utils import download

this_dir = path.abspath(path.dirname(__file__))
sys.path.append(path.join(this_dir, path.pardir))

from data import ImagePairIter

Expand All @@ -51,19 +59,45 @@
batch_size, test_batch_size = opt.batch_size, opt.test_batch_size
color_flag = 0

# get data
dataset_path = "dataset"
dataset_url = "http://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz"
def get_dataset(prefetch=False):
image_path = os.path.join(dataset_path, "BSDS300/images")
# Get data
datasets_dir = path.expanduser(path.join("~", ".mxnet", "datasets"))
datasets_tmpdir = path.join(datasets_dir, "tmp")
dataset_url = "https://github.com/BIDS/BSDS500/archive/master.zip"
data_dir = path.expanduser(path.join(datasets_dir, "BSDS500"))
tmp_dir = path.join(data_dir, "tmp")

if not os.path.exists(image_path):
os.makedirs(dataset_path)
file_name = download(dataset_url)
with tarfile.open(file_name) as tar:
for item in tar:
tar.extract(item, dataset_path)
os.remove(file_name)
def get_dataset(prefetch=False):
"""Download the BSDS500 dataset and return train and test iters."""

if path.exists(data_dir):
print(
"Directory {} already exists, skipping.\n"
"To force download and extraction, delete the directory and re-run."
"".format(data_dir),
file=sys.stderr,
)
else:
print("Downloading dataset...", file=sys.stderr)
downloaded_file = download(dataset_url, dirname=datasets_tmpdir)
print("done", file=sys.stderr)

print("Extracting files...", end="", file=sys.stderr)
os.makedirs(data_dir)
os.makedirs(tmp_dir)
with zipfile.ZipFile(downloaded_file) as archive:
archive.extractall(tmp_dir)
shutil.rmtree(datasets_tmpdir)

shutil.copytree(
path.join(tmp_dir, "BSDS500-master", "BSDS500", "data", "images"),
path.join(data_dir, "images"),
)
shutil.copytree(
path.join(tmp_dir, "BSDS500-master", "BSDS500", "data", "groundTruth"),
path.join(data_dir, "groundTruth"),
)
shutil.rmtree(tmp_dir)
print("done", file=sys.stderr)

crop_size = 256
crop_size -= crop_size % upscale_factor
Expand All @@ -72,15 +106,26 @@ def get_dataset(prefetch=False):
input_transform = [CenterCropAug((crop_size, crop_size)), ResizeAug(input_crop_size)]
target_transform = [CenterCropAug((crop_size, crop_size))]

iters = (ImagePairIter(os.path.join(image_path, "train"),
(input_crop_size, input_crop_size),
(crop_size, crop_size),
batch_size, color_flag, input_transform, target_transform),
ImagePairIter(os.path.join(image_path, "test"),
(input_crop_size, input_crop_size),
(crop_size, crop_size),
test_batch_size, color_flag,
input_transform, target_transform))
iters = (
ImagePairIter(
path.join(data_dir, "images", "train"),
(input_crop_size, input_crop_size),
(crop_size, crop_size),
batch_size,
color_flag,
input_transform,
target_transform,
),
ImagePairIter(
path.join(data_dir, "images", "test"),
(input_crop_size, input_crop_size),
(crop_size, crop_size),
test_batch_size,
color_flag,
input_transform,
target_transform,
),
)

return [PrefetchingIter(i) for i in iters] if prefetch else iters

Expand All @@ -90,33 +135,23 @@ def get_dataset(prefetch=False):
ctx = [mx.gpu(0)] if opt.use_gpu else [mx.cpu()]


# define model
def _rearrange(raw, F, upscale_factor):
# (N, C * r^2, H, W) -> (N, C, r^2, H, W)
splitted = F.reshape(raw, shape=(0, -4, -1, upscale_factor**2, 0, 0))
# (N, C, r^2, H, W) -> (N, C, r, r, H, W)
unflatten = F.reshape(splitted, shape=(0, 0, -4, upscale_factor, upscale_factor, 0, 0))
# (N, C, r, r, H, W) -> (N, C, H, r, W, r)
swapped = F.transpose(unflatten, axes=(0, 1, 4, 2, 5, 3))
# (N, C, H, r, W, r) -> (N, C, H*r, W*r)
return F.reshape(swapped, shape=(0, 0, -3, -3))


class SuperResolutionNet(gluon.Block):
class SuperResolutionNet(gluon.HybridBlock):
def __init__(self, upscale_factor):
super(SuperResolutionNet, self).__init__()
with self.name_scope():
self.conv1 = nn.Conv2D(64, (5, 5), strides=(1, 1), padding=(2, 2))
self.conv2 = nn.Conv2D(64, (3, 3), strides=(1, 1), padding=(1, 1))
self.conv3 = nn.Conv2D(32, (3, 3), strides=(1, 1), padding=(1, 1))
self.conv1 = nn.Conv2D(64, (5, 5), strides=(1, 1), padding=(2, 2), activation='relu')
self.conv2 = nn.Conv2D(64, (3, 3), strides=(1, 1), padding=(1, 1), activation='relu')
self.conv3 = nn.Conv2D(32, (3, 3), strides=(1, 1), padding=(1, 1), activation='relu')
self.conv4 = nn.Conv2D(upscale_factor ** 2, (3, 3), strides=(1, 1), padding=(1, 1))
self.upscale_factor = upscale_factor
self.pxshuf = contrib_nn.PixelShuffle2D(upscale_factor)

def forward(self, x):
x = F.Activation(self.conv1(x), act_type='relu')
x = F.Activation(self.conv2(x), act_type='relu')
x = F.Activation(self.conv3(x), act_type='relu')
return _rearrange(self.conv4(x), F, self.upscale_factor)
def hybrid_forward(self, F, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.pxshuf(x)
return x

net = SuperResolutionNet(upscale_factor)
metric = mx.metric.MSE()
Expand All @@ -136,7 +171,7 @@ def test(ctx):
avg_psnr += 10 * math.log10(1/metric.get()[1])
metric.reset()
avg_psnr /= batches
print('validation avg psnr: %f'%avg_psnr)
print('validation avg psnr: %f' % avg_psnr)


def train(epoch, ctx):
Expand Down Expand Up @@ -168,13 +203,18 @@ def train(epoch, ctx):
print('training mse at epoch %d: %s=%f'%(i, name, acc))
test(ctx)

net.save_parameters('superres.params')
net.save_parameters(path.join(this_dir, 'superres.params'))

def resolve(ctx):
from PIL import Image

if isinstance(ctx, list):
ctx = [ctx[0]]
net.load_parameters('superres.params', ctx=ctx)

img_basename = path.splitext(path.basename(opt.resolve_img))[0]
img_dirname = path.dirname(opt.resolve_img)

net.load_parameters(path.join(this_dir, 'superres.params'), ctx=ctx)
img = Image.open(opt.resolve_img).convert('YCbCr')
y, cb, cr = img.split()
data = mx.nd.expand_dims(mx.nd.expand_dims(mx.nd.array(y), axis=0), axis=0)
Expand All @@ -186,7 +226,7 @@ def resolve(ctx):
out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC)
out_img = Image.merge('YCbCr', [out_img_y, out_img_cb, out_img_cr]).convert('RGB')

out_img.save('resolved.png')
out_img.save(path.join(img_dirname, '{}-resolved.png'.format(img_basename)))

if opt.resolve_img:
resolve(ctx)
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/gluon/contrib/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

# coding: utf-8
# pylint: disable=wildcard-import
"""Contrib recurrent neural network module."""
"""Contributed neural network modules."""

from . import basic_layers

Expand Down
Loading

0 comments on commit 3fe2a72

Please sign in to comment.