Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add pixelshuffle layers #13571

Merged
merged 7 commits into from
Feb 14, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/api/python/gluon/contrib.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ In the rest of this document, we list routines provided by the `gluon.contrib` p
Identity
SparseEmbedding
SyncBatchNorm
PixelShuffle1D
PixelShuffle2D
PixelShuffle3D
```

### Recurrent neural network
Expand Down
144 changes: 92 additions & 52 deletions example/gluon/super_resolution/super_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,27 @@
# under the License.

from __future__ import print_function
import argparse, tarfile

import argparse
import math
import os
import shutil
import sys
import zipfile
from os import path

import numpy as np

import mxnet as mx
import mxnet.ndarray as F
from mxnet import gluon
from mxnet import gluon, autograd as ag
from mxnet.gluon import nn
from mxnet import autograd as ag
from mxnet.test_utils import download
from mxnet.gluon.contrib import nn as contrib_nn
from mxnet.image import CenterCropAug, ResizeAug
from mxnet.io import PrefetchingIter
from mxnet.test_utils import download

this_dir = path.abspath(path.dirname(__file__))
sys.path.append(path.join(this_dir, path.pardir))

from data import ImagePairIter

Expand All @@ -51,19 +59,45 @@
batch_size, test_batch_size = opt.batch_size, opt.test_batch_size
color_flag = 0

# get data
dataset_path = "dataset"
dataset_url = "http://www2.eecs.berkeley.edu/Research/Projects/CS/vision/bsds/BSDS300-images.tgz"
def get_dataset(prefetch=False):
image_path = os.path.join(dataset_path, "BSDS300/images")
# Get data
datasets_dir = path.expanduser(path.join("~", ".mxnet", "datasets"))
datasets_tmpdir = path.join(datasets_dir, "tmp")
dataset_url = "https://github.com/BIDS/BSDS500/archive/master.zip"
data_dir = path.expanduser(path.join(datasets_dir, "BSDS500"))
tmp_dir = path.join(data_dir, "tmp")

if not os.path.exists(image_path):
os.makedirs(dataset_path)
file_name = download(dataset_url)
with tarfile.open(file_name) as tar:
for item in tar:
tar.extract(item, dataset_path)
os.remove(file_name)
def get_dataset(prefetch=False):
"""Download the BSDS500 dataset and return train and test iters."""

if path.exists(data_dir):
print(
"Directory {} already exists, skipping.\n"
"To force download and extraction, delete the directory and re-run."
"".format(data_dir),
file=sys.stderr,
)
else:
print("Downloading dataset...", file=sys.stderr)
downloaded_file = download(dataset_url, dirname=datasets_tmpdir)
print("done", file=sys.stderr)

print("Extracting files...", end="", file=sys.stderr)
os.makedirs(data_dir)
os.makedirs(tmp_dir)
with zipfile.ZipFile(downloaded_file) as archive:
archive.extractall(tmp_dir)
shutil.rmtree(datasets_tmpdir)

shutil.copytree(
path.join(tmp_dir, "BSDS500-master", "BSDS500", "data", "images"),
path.join(data_dir, "images"),
)
shutil.copytree(
path.join(tmp_dir, "BSDS500-master", "BSDS500", "data", "groundTruth"),
path.join(data_dir, "groundTruth"),
)
shutil.rmtree(tmp_dir)
print("done", file=sys.stderr)

crop_size = 256
crop_size -= crop_size % upscale_factor
Expand All @@ -72,15 +106,26 @@ def get_dataset(prefetch=False):
input_transform = [CenterCropAug((crop_size, crop_size)), ResizeAug(input_crop_size)]
target_transform = [CenterCropAug((crop_size, crop_size))]

iters = (ImagePairIter(os.path.join(image_path, "train"),
(input_crop_size, input_crop_size),
(crop_size, crop_size),
batch_size, color_flag, input_transform, target_transform),
ImagePairIter(os.path.join(image_path, "test"),
(input_crop_size, input_crop_size),
(crop_size, crop_size),
test_batch_size, color_flag,
input_transform, target_transform))
iters = (
ImagePairIter(
path.join(data_dir, "images", "train"),
(input_crop_size, input_crop_size),
(crop_size, crop_size),
batch_size,
color_flag,
input_transform,
target_transform,
),
ImagePairIter(
path.join(data_dir, "images", "test"),
(input_crop_size, input_crop_size),
(crop_size, crop_size),
test_batch_size,
color_flag,
input_transform,
target_transform,
),
)

return [PrefetchingIter(i) for i in iters] if prefetch else iters

Expand All @@ -90,33 +135,23 @@ def get_dataset(prefetch=False):
ctx = [mx.gpu(0)] if opt.use_gpu else [mx.cpu()]


# define model
def _rearrange(raw, F, upscale_factor):
# (N, C * r^2, H, W) -> (N, C, r^2, H, W)
splitted = F.reshape(raw, shape=(0, -4, -1, upscale_factor**2, 0, 0))
# (N, C, r^2, H, W) -> (N, C, r, r, H, W)
unflatten = F.reshape(splitted, shape=(0, 0, -4, upscale_factor, upscale_factor, 0, 0))
# (N, C, r, r, H, W) -> (N, C, H, r, W, r)
swapped = F.transpose(unflatten, axes=(0, 1, 4, 2, 5, 3))
# (N, C, H, r, W, r) -> (N, C, H*r, W*r)
return F.reshape(swapped, shape=(0, 0, -3, -3))


class SuperResolutionNet(gluon.Block):
class SuperResolutionNet(gluon.HybridBlock):
def __init__(self, upscale_factor):
super(SuperResolutionNet, self).__init__()
with self.name_scope():
self.conv1 = nn.Conv2D(64, (5, 5), strides=(1, 1), padding=(2, 2))
self.conv2 = nn.Conv2D(64, (3, 3), strides=(1, 1), padding=(1, 1))
self.conv3 = nn.Conv2D(32, (3, 3), strides=(1, 1), padding=(1, 1))
self.conv1 = nn.Conv2D(64, (5, 5), strides=(1, 1), padding=(2, 2), activation='relu')
self.conv2 = nn.Conv2D(64, (3, 3), strides=(1, 1), padding=(1, 1), activation='relu')
self.conv3 = nn.Conv2D(32, (3, 3), strides=(1, 1), padding=(1, 1), activation='relu')
self.conv4 = nn.Conv2D(upscale_factor ** 2, (3, 3), strides=(1, 1), padding=(1, 1))
self.upscale_factor = upscale_factor
self.pxshuf = contrib_nn.PixelShuffle2D(upscale_factor)

def forward(self, x):
x = F.Activation(self.conv1(x), act_type='relu')
x = F.Activation(self.conv2(x), act_type='relu')
x = F.Activation(self.conv3(x), act_type='relu')
return _rearrange(self.conv4(x), F, self.upscale_factor)
def hybrid_forward(self, F, x):
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.pxshuf(x)
return x

net = SuperResolutionNet(upscale_factor)
metric = mx.metric.MSE()
Expand All @@ -136,7 +171,7 @@ def test(ctx):
avg_psnr += 10 * math.log10(1/metric.get()[1])
metric.reset()
avg_psnr /= batches
print('validation avg psnr: %f'%avg_psnr)
print('validation avg psnr: %f' % avg_psnr)


def train(epoch, ctx):
Expand Down Expand Up @@ -168,13 +203,18 @@ def train(epoch, ctx):
print('training mse at epoch %d: %s=%f'%(i, name, acc))
test(ctx)

net.save_parameters('superres.params')
net.save_parameters(path.join(this_dir, 'superres.params'))

def resolve(ctx):
from PIL import Image

if isinstance(ctx, list):
ctx = [ctx[0]]
net.load_parameters('superres.params', ctx=ctx)

img_basename = path.splitext(path.basename(opt.resolve_img))[0]
img_dirname = path.dirname(opt.resolve_img)

net.load_parameters(path.join(this_dir, 'superres.params'), ctx=ctx)
img = Image.open(opt.resolve_img).convert('YCbCr')
y, cb, cr = img.split()
data = mx.nd.expand_dims(mx.nd.expand_dims(mx.nd.array(y), axis=0), axis=0)
Expand All @@ -186,7 +226,7 @@ def resolve(ctx):
out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC)
out_img = Image.merge('YCbCr', [out_img_y, out_img_cb, out_img_cr]).convert('RGB')

out_img.save('resolved.png')
out_img.save(path.join(img_dirname, '{}-resolved.png'.format(img_basename)))

if opt.resolve_img:
resolve(ctx)
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/gluon/contrib/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

# coding: utf-8
# pylint: disable=wildcard-import
"""Contrib recurrent neural network module."""
"""Contributed neural network modules."""

from . import basic_layers

Expand Down
Loading