This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Time major data layout support for RNN (#3497)
* rnn-cell demo (push to server for testing)

* a running example with cuDNN RNN cell

* ndarray concatenate

* fix lint errors

* allow batch_axis in executor_group

* add batch_axis parameter for all modules

* fix bug in copy slice implementation

* fix module examples

* use batch_axis if data iterator provided such information

* rnn cell example in time major

* fix init state names in rnn cell bucketing example

* sanity check stochastic depth mnist

* a cifar10 example (not tested)

* add description for sd cifar

* add doc for sd module

* add a simple random number queue

* add final numbers

* fix typo

* default layout mapper

* fix other modules for layout mapper

* fix typo

* softmax output mode that preserves the shape

* comments on run-time speed of time-major

* extend layout mapper to include other information

* fix data layout API change

* fix lint errors

* fix Travis CI numpy error on unit test
pluskid authored Oct 14, 2016
1 parent d97e360 commit 23bf60c
Showing 20 changed files with 879 additions and 131 deletions.
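
In MXNet terminology, "time-major" means the sequence (time) axis comes before the batch axis, so a batch of token ids has shape (seq_len, batch_size) instead of the usual batch-major (batch_size, seq_len); the provide_data shapes in the new bucket_io.py below follow exactly this convention, and the batch_axis / layout-mapper plumbing added by this commit tells modules which axis holds the batch. A minimal NumPy sketch of the two layouts (variable names are illustrative only, not part of this commit):

import numpy as np

batch_size, seq_len = 4, 10

# Batch-major: one row per sample, shape (batch_size, seq_len).
batch_major = np.random.randint(0, 100, size=(batch_size, seq_len))

# Time-major: one row per time step, shape (seq_len, batch_size).
# A transpose converts between the two layouts.
time_major = batch_major.T
assert time_major.shape == (seq_len, batch_size)

# In time-major layout all samples' tokens for a given time step are
# contiguous, which is the access pattern recurrent kernels such as the
# cuDNN RNN cell prefer (hence the run-time speed comments above).
step_0 = time_major[0]   # shape (batch_size,): every sample's first token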
11 changes: 7 additions & 4 deletions example/module/mnist_mlp.py
@@ -1,4 +1,5 @@
# pylint: skip-file
+import os
import mxnet as mx
import numpy as np
import logging
@@ -13,14 +14,16 @@

n_epoch = 2
batch_size = 100

+basedir = os.path.dirname(__file__)
train_dataiter = mx.io.MNISTIter(
-        image="../image-classification/mnist/train-images-idx3-ubyte",
-        label="../image-classification/mnist/train-labels-idx1-ubyte",
+        image=os.path.join(basedir, "../image-classification/mnist/train-images-idx3-ubyte"),
+        label=os.path.join(basedir, "../image-classification/mnist/train-labels-idx1-ubyte"),
        data_shape=(784,),
        batch_size=batch_size, shuffle=True, flat=True, silent=False, seed=10)
val_dataiter = mx.io.MNISTIter(
-        image="../image-classification/mnist/t10k-images-idx3-ubyte",
-        label="../image-classification/mnist/t10k-labels-idx1-ubyte",
+        image=os.path.join(basedir, "../image-classification/mnist/t10k-images-idx3-ubyte"),
+        label=os.path.join(basedir, "../image-classification/mnist/t10k-labels-idx1-ubyte"),
        data_shape=(784,),
        batch_size=batch_size, shuffle=True, flat=True, silent=False)

12 changes: 7 additions & 5 deletions example/module/sequential_module.py
@@ -1,10 +1,11 @@
# pylint: skip-file
+import os
import mxnet as mx
import numpy as np
import logging

# whether to demo model-parallelism + data parallelism
-demo_data_model_parallelism = False
+demo_data_model_parallelism = True

if demo_data_model_parallelism:
    contexts = [[mx.context.gpu(0), mx.context.gpu(1)], [mx.context.gpu(2), mx.context.gpu(3)]]
@@ -43,14 +44,15 @@
#--------------------------------------------------------------------------------
n_epoch = 2
batch_size = 100
+basedir = os.path.dirname(__file__)
train_dataiter = mx.io.MNISTIter(
-        image="../image-classification/mnist/train-images-idx3-ubyte",
-        label="../image-classification/mnist/train-labels-idx1-ubyte",
+        image=os.path.join(basedir, "../image-classification/mnist/train-images-idx3-ubyte"),
+        label=os.path.join(basedir, "../image-classification/mnist/train-labels-idx1-ubyte"),
        data_shape=(784,),
        batch_size=batch_size, shuffle=True, flat=True, silent=False, seed=10)
val_dataiter = mx.io.MNISTIter(
-        image="../image-classification/mnist/t10k-images-idx3-ubyte",
-        label="../image-classification/mnist/t10k-labels-idx1-ubyte",
+        image=os.path.join(basedir, "../image-classification/mnist/t10k-images-idx3-ubyte"),
+        label=os.path.join(basedir, "../image-classification/mnist/t10k-labels-idx1-ubyte"),
        data_shape=(784,),
        batch_size=batch_size, shuffle=True, flat=True, silent=False)

6 changes: 3 additions & 3 deletions example/module/train_cifar10.py
@@ -52,9 +52,9 @@ def _download(data_dir):
    os.chdir(data_dir)
    if (not os.path.exists('train.rec')) or \
       (not os.path.exists('test.rec')) :
-        os.system("wget http://data.dmlc.ml/mxnet/data/cifar10.zip")
-        os.system("unzip -u cifar10.zip")
-        os.system("mv cifar/* .; rm -rf cifar; rm cifar10.zip")
+        os.system("wget http://data.dmlc.ml/mxnet/data/cifar10.zip")
+        os.system("unzip -u cifar10.zip")
+        os.system("mv cifar/* .; rm -rf cifar; rm cifar10.zip")
    os.chdir(cwd)

# network
245 changes: 245 additions & 0 deletions example/rnn-time-major/bucket_io.py
@@ -0,0 +1,245 @@
# pylint: disable=C0111,too-many-arguments,too-many-instance-attributes,too-many-locals,redefined-outer-name,fixme
# pylint: disable=superfluous-parens, no-member, invalid-name
import sys
sys.path.insert(0, "../../python")
import numpy as np
import mxnet as mx

# The interface of a data iter that works for bucketing
#
# DataIter
# - default_bucket_key: the bucket key for the default symbol.
#
# DataBatch
# - provide_data: same as DataIter, but specific to this batch
# - provide_label: same as DataIter, but specific to this batch
# - bucket_key: the key for the bucket that should be used for this batch

def default_read_content(path):
    with open(path) as ins:
        content = ins.read()
        content = content.replace('\n', ' <eos> ').replace('. ', ' <eos> ')
        return content

def default_build_vocab(path):
    content = default_read_content(path)
    content = content.split(' ')
    idx = 1 # 0 is left for zero-padding
    the_vocab = {}
    the_vocab[' '] = 0 # put a dummy element here so that len(vocab) is correct
    for word in content:
        if len(word) == 0:
            continue
        if not word in the_vocab:
            the_vocab[word] = idx
            idx += 1
    return the_vocab

def default_text2id(sentence, the_vocab):
    words = sentence.split(' ')
    words = [the_vocab[w] for w in words if len(w) > 0]
    return words

def default_gen_buckets(sentences, batch_size, the_vocab):
    len_dict = {}
    max_len = -1
    for sentence in sentences:
        words = default_text2id(sentence, the_vocab)
        if len(words) == 0:
            continue
        if len(words) > max_len:
            max_len = len(words)
        if len(words) in len_dict:
            len_dict[len(words)] += 1
        else:
            len_dict[len(words)] = 1
    print(len_dict)

    tl = 0
    buckets = []
    for l, n in len_dict.items(): # TODO: There are better heuristic ways to do this
        if n + tl >= batch_size:
            buckets.append(l)
            tl = 0
        else:
            tl += n
    if tl > 0:
        buckets.append(max_len)
    return buckets

class SimpleBatch(object):
    def __init__(self, data_names, data, label_names, label, bucket_key):
        self.data = data
        self.label = label
        self.data_names = data_names
        self.label_names = label_names
        self.bucket_key = bucket_key

        self.pad = 0
        self.index = None # TODO: what is index?

    @property
    def provide_data(self):
        return [(n, x.shape) for n, x in zip(self.data_names, self.data)]

    @property
    def provide_label(self):
        return [(n, x.shape) for n, x in zip(self.label_names, self.label)]

class DummyIter(mx.io.DataIter):
    "A dummy iterator that always returns the same batch, used for speed testing"
    def __init__(self, real_iter):
        super(DummyIter, self).__init__()
        self.real_iter = real_iter
        self.provide_data = real_iter.provide_data
        self.provide_label = real_iter.provide_label
        self.batch_size = real_iter.batch_size

        for batch in real_iter:
            self.the_batch = batch
            break

    def __iter__(self):
        return self

    def next(self):
        return self.the_batch

class BucketSentenceIter(mx.io.DataIter):
    def __init__(self, path, vocab, buckets, batch_size,
                 init_states, data_name='data', label_name='label',
                 seperate_char=' <eos> ', text2id=None, read_content=None,
                 time_major=True):
        super(BucketSentenceIter, self).__init__()

        if text2id == None:
            self.text2id = default_text2id
        else:
            self.text2id = text2id
        if read_content == None:
            self.read_content = default_read_content
        else:
            self.read_content = read_content
        content = self.read_content(path)
        sentences = content.split(seperate_char)

        if len(buckets) == 0:
            buckets = default_gen_buckets(sentences, batch_size, vocab)

        self.vocab_size = len(vocab)
        self.data_name = data_name
        self.label_name = label_name
        self.time_major = time_major
        self.layout_mapper = mx.io.DefaultLayoutMapper(1 if time_major else 0)

        buckets.sort()
        self.buckets = buckets
        self.data = [[] for _ in buckets]

        # pre-allocate with the largest bucket for better memory sharing
        self.default_bucket_key = max(buckets)

        for sentence in sentences:
            sentence = self.text2id(sentence, vocab)
            if len(sentence) == 0:
                continue
            for i, bkt in enumerate(buckets):
                if bkt >= len(sentence):
                    self.data[i].append(sentence)
                    break
            # we just ignore the sentence if it is longer than the maximum
            # bucket size here

        # convert data into ndarrays for better speed during training
        data = [np.zeros((len(x), buckets[i])) for i, x in enumerate(self.data)]
        for i_bucket in range(len(self.buckets)):
            for j in range(len(self.data[i_bucket])):
                sentence = self.data[i_bucket][j]
                data[i_bucket][j, :len(sentence)] = sentence
        self.data = data

        # Get the size of each bucket, so that we could sample
        # uniformly from the bucket
        bucket_sizes = [len(x) for x in self.data]

        print("Summary of dataset ==================")
        for bkt, size in zip(buckets, bucket_sizes):
            print("bucket of len %3d : %d samples" % (bkt, size))

        self.batch_size = batch_size
        self.make_data_iter_plan()

        self.init_states = init_states
        self.init_state_arrays = [mx.nd.zeros(x[1]) for x in init_states]

        if self.time_major:
            self.provide_data = [('data', (self.default_bucket_key, batch_size))] + init_states
            self.provide_label = [('softmax_label', (self.default_bucket_key, batch_size))]
        else:
            self.provide_data = [('data', (batch_size, self.default_bucket_key))] + init_states
            self.provide_label = [('softmax_label', (self.batch_size, self.default_bucket_key))]

    def make_data_iter_plan(self):
        "make a random data iteration plan"
        # truncate each bucket into multiple of batch-size
        bucket_n_batches = []
        for i in range(len(self.data)):
            bucket_n_batches.append(len(self.data[i]) / self.batch_size)
            self.data[i] = self.data[i][:int(bucket_n_batches[i]*self.batch_size)]

        bucket_plan = np.hstack([np.zeros(n, int)+i for i, n in enumerate(bucket_n_batches)])
        np.random.shuffle(bucket_plan)

        bucket_idx_all = [np.random.permutation(len(x)) for x in self.data]

        self.bucket_plan = bucket_plan
        self.bucket_idx_all = bucket_idx_all
        self.bucket_curr_idx = [0 for x in self.data]

        self.data_buffer = []
        self.label_buffer = []
        for i_bucket in range(len(self.data)):
            if self.time_major:
                data = np.zeros((self.buckets[i_bucket], self.batch_size))
                label = np.zeros((self.buckets[i_bucket], self.batch_size))
            else:
                data = np.zeros((self.batch_size, self.buckets[i_bucket]))
                label = np.zeros((self.batch_size, self.buckets[i_bucket]))

            self.data_buffer.append(data)
            self.label_buffer.append(label)

    def __iter__(self):
        for i_bucket in self.bucket_plan:
            data = self.data_buffer[i_bucket]
            i_idx = self.bucket_curr_idx[i_bucket]
            idx = self.bucket_idx_all[i_bucket][i_idx:i_idx+self.batch_size]
            self.bucket_curr_idx[i_bucket] += self.batch_size

            init_state_names = [x[0] for x in self.init_states]

            if self.time_major:
                data[:] = self.data[i_bucket][idx].T
            else:
                data[:] = self.data[i_bucket][idx]

            label = self.label_buffer[i_bucket]
            if self.time_major:
                label[:-1, :] = data[1:, :]
                label[-1, :] = 0
            else:
                label[:, :-1] = data[:, 1:]
                label[:, -1] = 0

            data_all = [mx.nd.array(data)] + self.init_state_arrays
            label_all = [mx.nd.array(label)]
            data_names = ['data'] + init_state_names
            label_names = ['softmax_label']

            data_batch = SimpleBatch(data_names, data_all, label_names, label_all,
                                     self.buckets[i_bucket])
            yield data_batch


    def reset(self):
        self.bucket_curr_idx = [0 for x in self.data]
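
For orientation, here is a hypothetical usage sketch of the iterator defined above; the corpus path, bucket sizes, hidden size, and LSTM init-state names and shapes are illustrative assumptions borrowed from the existing RNN bucketing examples, not something this commit prescribes:

batch_size = 32
buckets = [10, 20, 30, 40]
num_hidden = 200
num_lstm_layer = 2

# Assumed init-state descriptors: (name, shape) pairs, one cell state and
# one hidden state per LSTM layer.
init_c = [('l%d_init_c' % l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
init_h = [('l%d_init_h' % l, (batch_size, num_hidden)) for l in range(num_lstm_layer)]
init_states = init_c + init_h

vocab = default_build_vocab("./data/ptb.train.txt")  # corpus path is an assumption

# time_major=True, so provide_data reports ('data', (bucket_len, batch_size)).
data_train = BucketSentenceIter("./data/ptb.train.txt", vocab, buckets, batch_size,
                                init_states, time_major=True)

for batch in data_train:
    # Each batch carries its bucket key plus shape descriptors that the
    # bucketing module uses to pick (or build) the matching executor.
    print(batch.bucket_key, batch.provide_data[0])
    break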
