From 6cf15606aabfa186ee5d5b6118730e69d7e522da Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Tue, 20 Mar 2018 11:36:52 -0700 Subject: [PATCH] word language model end-to-end example (AWD/RNNModel, with fix) (#11) * word language model zoo * update names * fix * fix weight drop * fix one-off error in dataset --- .../train.py => word_language_model.py} | 134 ++++++---- example/gluon/word_language_model/README.md | 67 ----- .../gluon/word_language_model/get_ptb_data.sh | 43 ---- example/gluon/word_language_model/model.py | 64 ----- python/mxnet/gluon/contrib/nn/basic_layers.py | 4 +- python/mxnet/gluon/data/text/base.py | 8 +- python/mxnet/gluon/data/text/utils.py | 25 +- python/mxnet/gluon/model_zoo/__init__.py | 2 + python/mxnet/gluon/model_zoo/text/__init__.py | 76 ++++++ python/mxnet/gluon/model_zoo/text/base.py | 238 ++++++++++++++++++ python/mxnet/gluon/model_zoo/text/lm.py | 124 +++++++++ tests/python/unittest/test_gluon_data_text.py | 38 ++- tests/python/unittest/test_gluon_model_zoo.py | 24 +- 13 files changed, 592 insertions(+), 255 deletions(-) rename example/gluon/{word_language_model/train.py => word_language_model.py} (59%) delete mode 100644 example/gluon/word_language_model/README.md delete mode 100755 example/gluon/word_language_model/get_ptb_data.sh delete mode 100644 example/gluon/word_language_model/model.py create mode 100644 python/mxnet/gluon/model_zoo/text/__init__.py create mode 100644 python/mxnet/gluon/model_zoo/text/base.py create mode 100644 python/mxnet/gluon/model_zoo/text/lm.py diff --git a/example/gluon/word_language_model/train.py b/example/gluon/word_language_model.py similarity index 59% rename from example/gluon/word_language_model/train.py rename to example/gluon/word_language_model.py index c7323934a130..eb33f2abf61b 100644 --- a/example/gluon/word_language_model/train.py +++ b/example/gluon/word_language_model.py @@ -20,35 +20,41 @@ import time import math import mxnet as mx -from mxnet import gluon, autograd, contrib -from mxnet.gluon import data -import model +from mxnet import gluon, autograd +from mxnet.gluon import data, text +from mxnet.gluon.model_zoo.text.lm import RNNModel, AWDLSTM parser = argparse.ArgumentParser(description='MXNet Autograd RNN/LSTM Language Model on Wikitext-2.') parser.add_argument('--model', type=str, default='lstm', help='type of recurrent net (rnn_tanh, rnn_relu, lstm, gru)') -parser.add_argument('--emsize', type=int, default=200, +parser.add_argument('--emsize', type=int, default=400, help='size of word embeddings') -parser.add_argument('--nhid', type=int, default=200, +parser.add_argument('--nhid', type=int, default=1150, help='number of hidden units per layer') -parser.add_argument('--nlayers', type=int, default=2, +parser.add_argument('--nlayers', type=int, default=3, help='number of layers') -parser.add_argument('--lr', type=float, default=1.0, +parser.add_argument('--lr', type=float, default=30, help='initial learning rate') -parser.add_argument('--clip', type=float, default=0.2, +parser.add_argument('--clip', type=float, default=0.25, help='gradient clipping') -parser.add_argument('--epochs', type=int, default=40, +parser.add_argument('--epochs', type=int, default=750, help='upper epoch limit') -parser.add_argument('--batch_size', type=int, default=32, metavar='N', +parser.add_argument('--batch_size', type=int, default=80, metavar='N', help='batch size') parser.add_argument('--bptt', type=int, default=35, help='sequence length') -parser.add_argument('--dropout', type=float, default=0.2, +parser.add_argument('--dropout', type=float, default=0.4, help='dropout applied to layers (0 = no dropout)') +parser.add_argument('--dropout_h', type=float, default=0.3, + help='dropout applied to hidden layer (0 = no dropout)') +parser.add_argument('--dropout_i', type=float, default=0.4, + help='dropout applied to input layer (0 = no dropout)') +parser.add_argument('--dropout_e', type=float, default=0.1, + help='dropout applied to embedding layer (0 = no dropout)') +parser.add_argument('--weight_dropout', type=float, default=0.65, + help='weight dropout applied to h2h weight matrix (0 = no weight dropout)') parser.add_argument('--tied', action='store_true', help='tie the word embedding and softmax weights') -parser.add_argument('--cuda', action='store_true', - help='Whether to use gpu') parser.add_argument('--log-interval', type=int, default=200, metavar='N', help='report interval') parser.add_argument('--save', type=str, default='model.params', @@ -58,6 +64,10 @@ takes `2bit` or `none` for now.') parser.add_argument('--gcthreshold', type=float, default=0.5, help='threshold for 2bit gradient compression') +parser.add_argument('--eval_only', action='store_true', + help='Whether to only evaluate the trained model') +parser.add_argument('--gpus', type=str, + help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu. (the result of multi-gpu training might be slightly different compared to single-gpu training, still need to be finalized)') args = parser.parse_args() @@ -65,23 +75,20 @@ # Load data ############################################################################### +context = [mx.cpu()] if args.gpus is None or args.gpus == "" else \ + [mx.gpu(int(i)) for i in args.gpus.split(',')] -if args.cuda: - context = mx.gpu(0) -else: - context = mx.cpu(0) - -train_dataset = data.text.lm.WikiText2('./data', 'train', seq_len=args.bptt, +train_dataset = data.text.lm.WikiText2(segment='train', seq_len=args.bptt, eos='') def get_frequencies(dataset): return collections.Counter(x for tup in dataset for x in tup[0] if x) -vocab = contrib.text.vocab.Vocabulary(get_frequencies(train_dataset)) +vocab = text.vocab.Vocabulary(get_frequencies(train_dataset)) def index_tokens(data, label): - return vocab.to_indices(data), vocab.to_indices(label) + return vocab[data], vocab[label] -val_dataset, test_dataset = [data.text.lm.WikiText2('./data', segment, +val_dataset, test_dataset = [data.text.lm.WikiText2(segment=segment, seq_len=args.bptt, eos='') for segment in ['val', 'test']] @@ -114,9 +121,17 @@ def index_tokens(data, label): ntokens = len(vocab) -model = model.RNNModel(args.model, ntokens, args.emsize, args.nhid, - args.nlayers, args.dropout, args.tied) -model.collect_params().initialize(mx.init.Xavier(), ctx=context) + +if args.weight_dropout: + model = AWDLSTM(args.model, vocab, args.emsize, args.nhid, args.nlayers, + args.dropout, args.dropout_h, args.dropout_i, args.dropout_e, args.weight_dropout, + args.tied) +else: + model = RNNModel(args.model, vocab, args.emsize, args.nhid, + args.nlayers, args.dropout, args.tied) + +model.initialize(mx.init.Xavier(), ctx=context) + compression_params = None if args.gctype == 'none' else {'type': args.gctype, 'threshold': args.gcthreshold} trainer = gluon.Trainer(model.collect_params(), 'sgd', @@ -140,38 +155,46 @@ def detach(hidden): def eval(data_source): total_L = 0.0 ntotal = 0 - hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context) + hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context[0]) for i, (data, target) in enumerate(data_source): - data = data.as_in_context(context).T - target = target.as_in_context(context).T.reshape((-1, 1)) + data = data.as_in_context(context[0]).T + target= target.as_in_context(context[0]).T output, hidden = model(data, hidden) - L = loss(output, target) + L = loss(mx.nd.reshape(output, (-3, -1)), + mx.nd.reshape(target, (-1,))) total_L += mx.nd.sum(L).asscalar() ntotal += L.size return total_L / ntotal def train(): best_val = float("Inf") + start_train_time = time.time() for epoch in range(args.epochs): total_L = 0.0 - start_time = time.time() - hidden = model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=context) + start_epoch_time = time.time() + hiddens = [model.begin_state(func=mx.nd.zeros, batch_size=args.batch_size, ctx=ctx) for ctx in context] for i, (data, target) in enumerate(train_data): - data = data.as_in_context(context).T - target = target.as_in_context(context).T.reshape((-1, 1)) - hidden = detach(hidden) + start_batch_time = time.time() + data = data.T + target= target.T + data_list = gluon.utils.split_and_load(data, context, even_split=False) + target_list = gluon.utils.split_and_load(target, context, even_split=False) + hiddens = [detach(hidden) for hidden in hiddens] + Ls = [] with autograd.record(): - output, hidden = model(data, hidden) - L = loss(output, target) + for j, (X, y, h) in enumerate(zip(data_list, target_list, hiddens)): + output, h = model(X, h) + Ls.append(loss(mx.nd.reshape(output, (-3, -1)), mx.nd.reshape(y, (-1,)))) + hiddens[j] = h + for L in Ls: L.backward() - - grads = [p.grad(context) for p in model.collect_params().values()] - # Here gradient is for the whole batch. - # So we multiply max_norm by batch_size and bptt size to balance it. - gluon.utils.clip_global_norm(grads, args.clip * args.bptt * args.batch_size) + for ctx in context: + grads = [p.grad(ctx) for p in model.collect_params().values()] + gluon.utils.clip_global_norm(grads, args.clip * args.bptt * args.batch_size) trainer.step(args.batch_size) - total_L += mx.nd.sum(L).asscalar() + + total_L += sum([mx.nd.sum(L).asscalar() for L in Ls]) if i % args.log_interval == 0 and i > 0: cur_L = total_L / args.bptt / args.batch_size / args.log_interval @@ -179,26 +202,33 @@ def train(): epoch, i, cur_L, math.exp(cur_L))) total_L = 0.0 - val_L = eval(val_data) + print('[Epoch %d Batch %d] throughput %.2f samples/s'%( + epoch, i, args.batch_size / (time.time() - start_batch_time))) + mx.nd.waitall() + + print('[Epoch %d] throughput %.2f samples/s'%( + epoch, (args.batch_size * nbatch_train) / (time.time() - start_epoch_time))) + val_L = eval(val_data) print('[Epoch %d] time cost %.2fs, valid loss %.2f, valid ppl %.2f'%( - epoch, time.time()-start_time, val_L, math.exp(val_L))) + epoch, time.time()-start_epoch_time, val_L, math.exp(val_L))) if val_L < best_val: best_val = val_L test_L = eval(test_data) model.collect_params().save(args.save) print('test loss %.2f, test ppl %.2f'%(test_L, math.exp(test_L))) - else: - args.lr = args.lr*0.25 - trainer._init_optimizer('sgd', - {'learning_rate': args.lr, - 'momentum': 0, - 'wd': 0}) - model.collect_params().load(args.save, context) + + print('Total training throughput %.2f samples/s'%( + (args.batch_size * nbatch_train * args.epochs) / (time.time() - start_train_time))) if __name__ == '__main__': - train() + start_pipeline_time = time.time() + if not args.eval_only: + train() model.collect_params().load(args.save, context) + val_L = eval(val_data) test_L = eval(test_data) + print('Best validation loss %.2f, test ppl %.2f'%(val_L, math.exp(val_L))) print('Best test loss %.2f, test ppl %.2f'%(test_L, math.exp(test_L))) + print('Total time cost %.2fs'%(time.time()-start_pipeline_time)) diff --git a/example/gluon/word_language_model/README.md b/example/gluon/word_language_model/README.md deleted file mode 100644 index ff8ea56b2063..000000000000 --- a/example/gluon/word_language_model/README.md +++ /dev/null @@ -1,67 +0,0 @@ -# Word-level language modeling RNN - -This example trains a multi-layer RNN (Elman, GRU, or LSTM) on Penn Treebank (PTB) language modeling benchmark. - -The model obtains the state-of-the-art result on PTB using LSTM, getting a test perplexity of ~72. -And ~97 ppl in WikiText-2, outperform than basic LSTM(99.3) and reach Variational LSTM(96.3). - -The following techniques have been adopted for SOTA results: -- [LSTM for LM](https://arxiv.org/pdf/1409.2329.pdf) -- [Weight tying](https://arxiv.org/abs/1608.05859) between word vectors and softmax output embeddings - -## Data - -### PTB - -The PTB data is the processed version from [(Mikolov et al, 2010)](http://www.fit.vutbr.cz/research/groups/speech/publi/2010/mikolov_interspeech2010_IS100722.pdf): - -```bash -bash get_ptb_data.sh -python data.py -``` - -### Wiki Text - -The wikitext-2 data is downloaded from [(The wikitext long term dependency language modeling dataset)](https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/): - -```bash -bash get_wikitext2_data.sh -``` - - -## Usage - -Example runs and the results: - -``` -python train.py -data ./data/ptb. --cuda --tied --nhid 650 --emsize 650 --dropout 0.5 # Test ppl of 75.3 in ptb -python train.py -data ./data/ptb. --cuda --tied --nhid 1500 --emsize 1500 --dropout 0.65 # Test ppl of 72.0 in ptb -``` - -``` -python train.py -data ./data/wikitext-2/wiki. --cuda --tied --nhid 256 --emsize 256 # Test ppl of 97.07 in wikitext-2 -``` - - -
- -`python train.py --help` gives the following arguments: -``` -Optional arguments: - -h, --help show this help message and exit - --data DATA location of the data corpus - --model MODEL type of recurrent net (rnn_tanh, rnn_relu, lstm, gru) - --emsize EMSIZE size of word embeddings - --nhid NHID number of hidden units per layer - --nlayers NLAYERS number of layers - --lr LR initial learning rate - --clip CLIP gradient clipping - --epochs EPOCHS upper epoch limit - --batch_size N batch size - --bptt BPTT sequence length - --dropout DROPOUT dropout applied to layers (0 = no dropout) - --tied tie the word embedding and softmax weights - --cuda Whether to use gpu - --log-interval N report interval - --save SAVE path to save the final model -``` diff --git a/example/gluon/word_language_model/get_ptb_data.sh b/example/gluon/word_language_model/get_ptb_data.sh deleted file mode 100755 index 2dc4034a938c..000000000000 --- a/example/gluon/word_language_model/get_ptb_data.sh +++ /dev/null @@ -1,43 +0,0 @@ -#!/usr/bin/env bash - -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -echo -echo "NOTE: To continue, you need to review the licensing of the data sets used by this script" -echo "See https://catalog.ldc.upenn.edu/ldc99t42 for the licensing" -read -p "Please confirm you have reviewed the licensing [Y/n]:" -n 1 -r -echo - -if [ $REPLY != "Y" ] -then - echo "License was not reviewed, aborting script." - exit 1 -fi - -RNN_DIR=$(cd `dirname $0`; pwd) -DATA_DIR="${RNN_DIR}/data/" - -if [[ ! -d "${DATA_DIR}" ]]; then - echo "${DATA_DIR} doesn't exist, will create one"; - mkdir -p ${DATA_DIR} -fi - -wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.train.txt; -wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.valid.txt; -wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/ptb/ptb.test.txt; -wget -P ${DATA_DIR} https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/tinyshakespeare/input.txt; diff --git a/example/gluon/word_language_model/model.py b/example/gluon/word_language_model/model.py deleted file mode 100644 index 40e7926ef8d6..000000000000 --- a/example/gluon/word_language_model/model.py +++ /dev/null @@ -1,64 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import mxnet as mx -from mxnet import gluon -from mxnet.gluon import nn, rnn - -class RNNModel(gluon.Block): - """A model with an encoder, recurrent layer, and a decoder.""" - - def __init__(self, mode, vocab_size, num_embed, num_hidden, - num_layers, dropout=0.5, tie_weights=False, **kwargs): - super(RNNModel, self).__init__(**kwargs) - with self.name_scope(): - self.drop = nn.Dropout(dropout) - self.encoder = nn.Embedding(vocab_size, num_embed, - weight_initializer=mx.init.Uniform(0.1)) - if mode == 'rnn_relu': - self.rnn = rnn.RNN(num_hidden, 'relu', num_layers, dropout=dropout, - input_size=num_embed) - elif mode == 'rnn_tanh': - self.rnn = rnn.RNN(num_hidden, num_layers, dropout=dropout, - input_size=num_embed) - elif mode == 'lstm': - self.rnn = rnn.LSTM(num_hidden, num_layers, dropout=dropout, - input_size=num_embed) - elif mode == 'gru': - self.rnn = rnn.GRU(num_hidden, num_layers, dropout=dropout, - input_size=num_embed) - else: - raise ValueError("Invalid mode %s. Options are rnn_relu, " - "rnn_tanh, lstm, and gru"%mode) - - if tie_weights: - self.decoder = nn.Dense(vocab_size, in_units=num_hidden, - params=self.encoder.params) - else: - self.decoder = nn.Dense(vocab_size, in_units=num_hidden) - - self.num_hidden = num_hidden - - def forward(self, inputs, hidden): - emb = self.drop(self.encoder(inputs)) - output, hidden = self.rnn(emb, hidden) - output = self.drop(output) - decoded = self.decoder(output.reshape((-1, self.num_hidden))) - return decoded, hidden - - def begin_state(self, *args, **kwargs): - return self.rnn.begin_state(*args, **kwargs) diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py index 88708884c511..082a148012c5 100644 --- a/python/mxnet/gluon/contrib/nn/basic_layers.py +++ b/python/mxnet/gluon/contrib/nn/basic_layers.py @@ -108,5 +108,7 @@ class Identity(HybridBlock): def __init__(self, prefix=None, params=None): super(Identity, self).__init__(prefix=prefix, params=params) - def hybrid_forward(self, F, x): + def hybrid_forward(self, F, *x): + if x and len(x) == 1: + return x[0] return x diff --git a/python/mxnet/gluon/data/text/base.py b/python/mxnet/gluon/data/text/base.py index 3c024245023f..f67a18c810b1 100644 --- a/python/mxnet/gluon/data/text/base.py +++ b/python/mxnet/gluon/data/text/base.py @@ -27,7 +27,7 @@ from ..dataset import SimpleDataset from ..datareader import DataReader -from .utils import flatten_samples, collate +from .utils import flatten_samples, collate, collate_pad_length class CorpusReader(DataReader): """Text reader that reads a whole corpus and produces a dataset based on provided @@ -124,9 +124,9 @@ def read(self): samples = [self._process(s) for s in samples] if self._seq_len: samples = flatten_samples(samples) - if self._pad and len(samples) % self._seq_len: - pad_len = self._seq_len - len(samples) % self._seq_len + pad_len = collate_pad_length(len(samples), self._seq_len, 1) + if self._pad: samples.extend([self._pad] * pad_len) - samples = collate(samples, self._seq_len, 1) + samples = collate(samples, self._seq_len+1, 1) return SimpleDataset(samples).transform(lambda x: (x[:-1], x[1:])) diff --git a/python/mxnet/gluon/data/text/utils.py b/python/mxnet/gluon/data/text/utils.py index 057e8431303e..b923f748be73 100644 --- a/python/mxnet/gluon/data/text/utils.py +++ b/python/mxnet/gluon/data/text/utils.py @@ -52,5 +52,26 @@ def collate(flat_sample, seq_len, overlap=0): ------- List of samples, each of which has length equal to `seq_len`. """ - num_samples = len(flat_sample) // seq_len - return [flat_sample[i*seq_len:((i+1)*seq_len+overlap)] for i in range(num_samples)] + num_samples = (len(flat_sample)-seq_len) // (seq_len-overlap) + 1 + return [flat_sample[i*(seq_len-overlap):((i+1)*seq_len-i*overlap)] for i in range(num_samples)] + +def collate_pad_length(num_items, seq_len, overlap=0): + """Calculate the padding length needed for collated samples in order not to discard data. + + Parameters + ---------- + num_items : int + Number of items in dataset before collating. + seq_len : int + The length of each of the samples. + overlap : int, default 0 + The extra number of items in current sample that should overlap with the + next sample. + + Returns + ------- + Length of paddings. + """ + step = seq_len-overlap + span = num_items-seq_len + return (span // step + 1) * step - span diff --git a/python/mxnet/gluon/model_zoo/__init__.py b/python/mxnet/gluon/model_zoo/__init__.py index b8c32af38561..bde69ad3b3ab 100644 --- a/python/mxnet/gluon/model_zoo/__init__.py +++ b/python/mxnet/gluon/model_zoo/__init__.py @@ -21,3 +21,5 @@ from . import model_store from . import vision + +from . import text diff --git a/python/mxnet/gluon/model_zoo/text/__init__.py b/python/mxnet/gluon/model_zoo/text/__init__.py new file mode 100644 index 000000000000..9aabd8a6821e --- /dev/null +++ b/python/mxnet/gluon/model_zoo/text/__init__.py @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=wildcard-import, arguments-differ +r"""Module for pre-defined NLP models. + +This module contains definitions for the following model architectures: +- `AWD`_ + +You can construct a model with random weights by calling its constructor: + +.. code:: + + from mxnet.gluon.model_zoo import text + # TODO + awd = text.awd_variant() + +We provide pre-trained models for all the listed models. +These models can constructed by passing ``pretrained=True``: + +.. code:: + + from mxnet.gluon.model_zoo import text + # TODO + awd = text.awd_variant(pretrained=True) + +.. _AWD: https://arxiv.org/abs/1404.5997 +""" + +from .base import * + +from . import lm + +def get_model(name, **kwargs): + """Returns a pre-defined model by name + + Parameters + ---------- + name : str + Name of the model. + pretrained : bool + Whether to load the pretrained weights for model. + classes : int + Number of classes for the output layer. + ctx : Context, default CPU + The context in which to load the pretrained weights. + root : str, default '~/.mxnet/models' + Location for keeping the model parameters. + + Returns + ------- + HybridBlock + The model. + """ + #models = {'awd_variant': awd_variant} + name = name.lower() + if name not in models: + raise ValueError( + 'Model %s is not supported. Available options are\n\t%s'%( + name, '\n\t'.join(sorted(models.keys())))) + return models[name](**kwargs) diff --git a/python/mxnet/gluon/model_zoo/text/base.py b/python/mxnet/gluon/model_zoo/text/base.py new file mode 100644 index 000000000000..7967eb9c282b --- /dev/null +++ b/python/mxnet/gluon/model_zoo/text/base.py @@ -0,0 +1,238 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Building blocks and utility for models.""" + +from ... import Block, HybridBlock, Parameter, contrib, nn, rnn +from .... import nd + + +class _TextSeq2SeqModel(Block): + def __init__(self, src_vocab, tgt_vocab, **kwargs): + super(_TextSeq2SeqModel, self).__init__(**kwargs) + self._src_vocab = src_vocab + self._tgt_vocab = tgt_vocab + + def begin_state(self, *args, **kwargs): + return self.encoder.begin_state(*args, **kwargs) + + def forward(self, inputs, begin_state=None): # pylint: disable=arguments-differ + embedded_inputs = self.embedding(inputs) + if not begin_state: + begin_state = self.begin_state() + encoded, state = self.encoder(embedded_inputs, begin_state) + out = self.decoder(encoded) + return out, state + + +def apply_weight_drop(block, local_param_name, rate, axes=(), + weight_dropout_mode='training'): + if not rate: + return + + params = block.collect_params('.*{}'.format(local_param_name)) + for full_param_name, param in params.items(): + dropped_param = WeightDropParameter(param, rate, weight_dropout_mode, axes) + param_dicts, reg_param_dicts = _find_param(block, full_param_name, local_param_name) + for param_dict in param_dicts: + param_dict[full_param_name] = dropped_param + for reg_param_dict in reg_param_dicts: + reg_param_dict[local_param_name] = dropped_param + local_attr = getattr(block, local_param_name) + if local_attr == param: + super(Block, block).__setattr__(local_param_name, dropped_param) + else: + if isinstance(local_attr, (list, tuple)): + if isinstance(local_attr, tuple): + local_attr = list(local_attr) + for i, v in enumerate(local_attr): + if v == param: + local_attr[i] = dropped_param + elif isinstance(local_attr, dict): + for k, v in local_attr: + if v == param: + local_attr[k] = dropped_param + else: + continue + super(Block, block).__setattr__(local_param_name, local_attr) + + +def _find_param(block, full_param_name, local_param_name): + param_dict_results = [] + reg_dict_results = [] + params = block.params + + if full_param_name in block.params._params: + if isinstance(block, HybridBlock) and local_param_name in block._reg_params: + reg_dict_results.append(block._reg_params) + while params: + if full_param_name in params._params: + param_dict_results.append(params._params) + if params._shared: + params = params._shared + else: + break + + if block._children: + for c in block._children: + pd, rd = _find_param(c, full_param_name, local_param_name) + param_dict_results.extend(pd) + reg_dict_results.extend(rd) + + return param_dict_results, reg_dict_results + +def get_rnn_cell(mode, num_layers, num_embed, num_hidden, + dropout, weight_dropout, + var_drop_in, var_drop_state, var_drop_out): + """create rnn cell given specs""" + rnn_cell = rnn.SequentialRNNCell() + with rnn_cell.name_scope(): + for i in range(num_layers): + if mode == 'rnn_relu': + cell = rnn.RNNCell(num_hidden, 'relu', input_size=num_embed) + elif mode == 'rnn_tanh': + cell = rnn.RNNCell(num_hidden, 'tanh', input_size=num_embed) + elif mode == 'lstm': + cell = rnn.LSTMCell(num_hidden, input_size=num_embed) + elif mode == 'gru': + cell = rnn.GRUCell(num_hidden, input_size=num_embed) + if var_drop_in + var_drop_state + var_drop_out != 0: + cell = contrib.rnn.VariationalDropoutCell(cell, + var_drop_in, + var_drop_state, + var_drop_out) + + rnn_cell.add(cell) + if i != num_layers - 1 and dropout != 0: + rnn_cell.add(rnn.DropoutCell(dropout)) + + if weight_dropout: + apply_weight_drop(rnn_cell, 'h2h_weight', rate=weight_dropout) + + return rnn_cell + + +def get_rnn_layer(mode, num_layers, num_embed, num_hidden, dropout, weight_dropout): + """create rnn layer given specs""" + if mode == 'rnn_relu': + block = rnn.RNN(num_hidden, 'relu', num_layers, dropout=dropout, + input_size=num_embed) + elif mode == 'rnn_tanh': + block = rnn.RNN(num_hidden, num_layers, dropout=dropout, + input_size=num_embed) + elif mode == 'lstm': + block = rnn.LSTM(num_hidden, num_layers, dropout=dropout, + input_size=num_embed) + elif mode == 'gru': + block = rnn.GRU(num_hidden, num_layers, dropout=dropout, + input_size=num_embed) + if weight_dropout: + apply_weight_drop(block, 'h2h_weight', rate=weight_dropout) + + return block + + +class RNNCellLayer(Block): + """A block that takes an rnn cell and makes it act like rnn layer.""" + def __init__(self, rnn_cell, layout='TNC', **kwargs): + super(RNNCellBlock, self).__init__(**kwargs) + self.cell = rnn_cell + assert layout == 'TNC' or layout == 'NTC', \ + "Invalid layout %s; must be one of ['TNC' or 'NTC']"%layout + self._layout = layout + self._axis = layout.find('T') + self._batch_axis = layout.find('N') + + def forward(self, inputs, states=None): # pylint: disable=arguments-differ + batch_size = inputs.shape[self._batch_axis] + skip_states = states is None + if skip_states: + states = self.cell.begin_state(batch_size, ctx=inputs.context) + if isinstance(states, ndarray.NDArray): + states = [states] + for state, info in zip(states, self.cell.state_info(batch_size)): + if state.shape != info['shape']: + raise ValueError( + "Invalid recurrent state shape. Expecting %s, got %s."%( + str(info['shape']), str(state.shape))) + states = sum(zip(*((j for j in i) for i in states)), ()) + outputs, states = self.cell.unroll( + inputs.shape[self._axis], inputs, states, + layout=self._layout, merge_outputs=True) + + if skip_states: + return outputs + return outputs, states + +class ExtendedSequential(nn.Sequential): + def forward(self, *x): # pylint: disable=arguments-differ + for block in self._children: + x = block(*x) + return x + +class TransformerBlock(Block): + def __init__(self, *blocks, **kwargs): + super(TransformerBlock, self).__init__(**kwargs) + self._blocks = blocks + + def forward(self, *inputs): + return [block(data) if block else data for block, data in zip(self._blocks, inputs)] + + +class WeightDropParameter(Parameter): + """A Container holding parameters (weights) of Blocks and performs dropout. + parameter : Parameter + The parameter which drops out. + rate : float, default 0.0 + Fraction of the input units to drop. Must be a number between 0 and 1. + Dropout is not applied if dropout_rate is 0. + mode : str, default 'training' + Whether to only turn on dropout during training or to also turn on for inference. + Options are 'training' and 'always'. + axes : tuple of int, default () + Axes on which dropout mask is shared. + """ + def __init__(self, parameter, rate=0.0, mode='training', axes=()): + p = parameter + super(WeightDropParameter, self).__init__( + name=p.name, grad_req=p.grad_req, shape=p._shape, dtype=p.dtype, + lr_mult=p.lr_mult, wd_mult=p.wd_mult, init=p.init, + allow_deferred_init=p._allow_deferred_init, + differentiable=p._differentiable) + self._rate = rate + self._mode = mode + self._axes = axes + + def data(self, ctx=None): + """Returns a copy of this parameter on one context. Must have been + initialized on this context before. + Parameters + ---------- + ctx : Context + Desired context. + Returns + ------- + NDArray on ctx + """ + d = self._check_and_get(self._data, ctx) + if self._rate: + d = nd.Dropout(d, self._rate, self._mode, self._axes) + return d + + def __repr__(self): + s = 'WeightDropParameter {name} (shape={shape}, dtype={dtype}, rate={rate}, mode={mode})' + return s.format(name=self.name, shape=self.shape, dtype=self.dtype, + rate=self._rate, mode=self._mode) diff --git a/python/mxnet/gluon/model_zoo/text/lm.py b/python/mxnet/gluon/model_zoo/text/lm.py new file mode 100644 index 000000000000..3ada7d0f735d --- /dev/null +++ b/python/mxnet/gluon/model_zoo/text/lm.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Language models.""" + +from .base import _TextSeq2SeqModel, ExtendedSequential, TransformerBlock +from .base import get_rnn_layer, apply_weight_drop +from ... import nn +from .... import init + + +class AWDLSTM(_TextSeq2SeqModel): + """AWD language model.""" + def __init__(self, mode, vocab, embed_dim, hidden_dim, num_layers, + dropout=0.5, drop_h=0.5, drop_i=0.5, drop_e=0.1, weight_drop=0, + tie_weights=False, **kwargs): + super(AWDLSTM, self).__init__(vocab, vocab, **kwargs) + self._mode = mode + self._embed_dim = embed_dim + self._hidden_dim = hidden_dim + self._num_layers = num_layers + self._dropout = dropout + self._drop_h = drop_h + self._drop_i = drop_i + self._drop_e = drop_e + self._weight_drop = weight_drop + self._tie_weights = tie_weights + self.embedding = self._get_embedding() + self.encoder = self._get_encoder() + self.decoder = self._get_decoder() + + def _get_embedding(self): + embedding = nn.HybridSequential() + with embedding.name_scope(): + embedding_block = nn.Embedding(len(self._src_vocab), self._embed_dim, + weight_initializer=init.Uniform(0.1)) + if self._drop_e: + apply_weight_drop(embedding_block, 'weight', self._drop_e, axes=(1,)) + embedding.add(embedding_block) + if self._drop_i: + embedding.add(nn.Dropout(self._drop_i, axes=(0,))) + return embedding + + def _get_encoder(self): + encoder = ExtendedSequential() + with encoder.name_scope(): + for l in range(self._num_layers): + encoder.add(get_rnn_layer(self._mode, 1, self._embed_dim if l == 0 else + self._hidden_dim, self._hidden_dim if + l != self._num_layers - 1 or not self._tie_weights + else self._embed_dim, 0, self._weight_drop)) + if self._drop_h: + encoder.add(TransformerBlock(nn.Dropout(self._drop_h, axes=(0,)), None)) + return encoder + + def _get_decoder(self): + vocab_size = len(self._tgt_vocab) + if self._tie_weights: + output = nn.Dense(vocab_size, flatten=False, params=self.embedding.params) + else: + output = nn.Dense(vocab_size, flatten=False) + return output + + def begin_state(self, *args, **kwargs): + return self.encoder[0].begin_state(*args, **kwargs) + +class RNNModel(_TextSeq2SeqModel): + """Simple RNN language model.""" + def __init__(self, mode, vocab, embed_dim, hidden_dim, + num_layers, dropout=0.5, tie_weights=False, **kwargs): + super(RNNModel, self).__init__(vocab, vocab, **kwargs) + self._mode = mode + self._embed_dim = embed_dim + self._hidden_dim = hidden_dim + self._num_layers = num_layers + self._dropout = dropout + self._tie_weights = tie_weights + self.embedding = self._get_embedding() + self.encoder = self._get_encoder() + self.decoder = self._get_decoder() + + def _get_embedding(self): + embedding = nn.HybridSequential() + with embedding.name_scope(): + embedding.add(nn.Embedding(len(self._src_vocab), self._embed_dim, + weight_initializer=init.Uniform(0.1))) + if self._dropout: + embedding.add(nn.Dropout(self._dropout)) + return embedding + + def _get_encoder(self): + encoder = ExtendedSequential() + with encoder.name_scope(): + for l in range(self._num_layers): + encoder.add(get_rnn_layer(self._mode, 1, self._embed_dim if l == 0 else + self._hidden_dim, self._hidden_dim if + l != self._num_layers - 1 or not self._tie_weights + else self._embed_dim, 0, 0)) + + return encoder + + def _get_decoder(self): + vocab_size = len(self._tgt_vocab) + if self._tie_weights: + output = nn.Dense(vocab_size, flatten=False, params=self.embedding[0].params) + else: + output = nn.Dense(vocab_size, flatten=False) + return output + + def begin_state(self, *args, **kwargs): + return self.encoder[0].begin_state(*args, **kwargs) diff --git a/tests/python/unittest/test_gluon_data_text.py b/tests/python/unittest/test_gluon_data_text.py index 49a598875ea9..8888b75edc4e 100644 --- a/tests/python/unittest/test_gluon_data_text.py +++ b/tests/python/unittest/test_gluon_data_text.py @@ -18,33 +18,51 @@ from __future__ import print_function import collections import mxnet as mx -from mxnet.gluon import nn, data +from mxnet.gluon import text, contrib, nn +from mxnet.gluon import data as d from common import setup_module, with_seed def get_frequencies(dataset): return collections.Counter(x for tup in dataset for x in tup[0]+tup[1][-1:]) + def test_wikitext2(): - train = data.text.lm.WikiText2(root='data/wikitext-2', segment='train') - val = data.text.lm.WikiText2(root='data/wikitext-2', segment='val') - test = data.text.lm.WikiText2(root='data/wikitext-2', segment='test') + train = d.text.lm.WikiText2(root='data/wikitext-2', segment='train') + val = d.text.lm.WikiText2(root='data/wikitext-2', segment='val') + test = d.text.lm.WikiText2(root='data/wikitext-2', segment='test') train_freq, val_freq, test_freq = [get_frequencies(x) for x in [train, val, test]] - assert len(train) == 59306, len(train) - assert len(train_freq) == 33279, len(train_freq) + assert len(train) == 59305, len(train) + assert len(train_freq) == 33278, len(train_freq) assert len(val) == 6182, len(val) assert len(val_freq) == 13778, len(val_freq) - assert len(test) == 6975, len(test) - assert len(test_freq) == 14144, len(test_freq) + assert len(test) == 6974, len(test) + assert len(test_freq) == 14143, len(test_freq) assert test_freq['English'] == 33, test_freq['English'] assert len(train[0][0]) == 35, len(train[0][0]) - test_no_pad = data.text.lm.WikiText2(root='data/wikitext-2', segment='test', pad=None) + test_no_pad = d.text.lm.WikiText2(root='data/wikitext-2', segment='test', pad=None) assert len(test_no_pad) == 6974, len(test_no_pad) - train_paragraphs = data.text.lm.WikiText2(root='data/wikitext-2', segment='train', seq_len=None) + train_paragraphs = d.text.lm.WikiText2(root='data/wikitext-2', segment='train', seq_len=None) assert len(train_paragraphs) == 23767, len(train_paragraphs) assert len(train_paragraphs[0][0]) != 35, len(train_paragraphs[0][0]) + vocab = text.vocab.Vocabulary(get_frequencies(train)) + def index_tokens(data, label): + return vocab[data], vocab[label] + nbatch_train = len(train) // 80 + train_data = d.DataLoader(train.transform(index_tokens), + batch_size=80, + sampler=contrib.data.IntervalSampler(len(train), + nbatch_train), + last_batch='discard') + sampler = contrib.data.IntervalSampler(len(train), nbatch_train) + + for i, (data, target) in enumerate(train_data): + pass + + + if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_gluon_model_zoo.py b/tests/python/unittest/test_gluon_model_zoo.py index f89a8f701820..e97b3b51b5f0 100644 --- a/tests/python/unittest/test_gluon_model_zoo.py +++ b/tests/python/unittest/test_gluon_model_zoo.py @@ -17,7 +17,7 @@ from __future__ import print_function import mxnet as mx -from mxnet.gluon.model_zoo.vision import get_model +from mxnet.gluon.model_zoo.vision import get_model as get_vision_model import sys from common import setup_module, with_seed @@ -28,20 +28,20 @@ def eprint(*args, **kwargs): @with_seed() def test_models(): - all_models = ['resnet18_v1', 'resnet34_v1', 'resnet50_v1', 'resnet101_v1', 'resnet152_v1', - 'resnet18_v2', 'resnet34_v2', 'resnet50_v2', 'resnet101_v2', 'resnet152_v2', - 'vgg11', 'vgg13', 'vgg16', 'vgg19', - 'vgg11_bn', 'vgg13_bn', 'vgg16_bn', 'vgg19_bn', - 'alexnet', 'inceptionv3', - 'densenet121', 'densenet161', 'densenet169', 'densenet201', - 'squeezenet1.0', 'squeezenet1.1', - 'mobilenet1.0', 'mobilenet0.75', 'mobilenet0.5', 'mobilenet0.25', - 'mobilenetv2_1.0', 'mobilenetv2_0.75', 'mobilenetv2_0.5', 'mobilenetv2_0.25'] + vision_models = ['resnet18_v1', 'resnet34_v1', 'resnet50_v1', 'resnet101_v1', 'resnet152_v1', + 'resnet18_v2', 'resnet34_v2', 'resnet50_v2', 'resnet101_v2', 'resnet152_v2', + 'vgg11', 'vgg13', 'vgg16', 'vgg19', + 'vgg11_bn', 'vgg13_bn', 'vgg16_bn', 'vgg19_bn', + 'alexnet', 'inceptionv3', + 'densenet121', 'densenet161', 'densenet169', 'densenet201', + 'squeezenet1.0', 'squeezenet1.1', + 'mobilenet1.0', 'mobilenet0.75', 'mobilenet0.5', 'mobilenet0.25', + 'mobilenetv2_1.0', 'mobilenetv2_0.75', 'mobilenetv2_0.5', 'mobilenetv2_0.25'] pretrained_to_test = set(['squeezenet1.1']) - for model_name in all_models: + for model_name in vision_models: test_pretrain = model_name in pretrained_to_test - model = get_model(model_name, pretrained=test_pretrain, root='model/') + model = get_vision_model(model_name, pretrained=test_pretrain, root='model/') data_shape = (2, 3, 224, 224) if 'inception' not in model_name else (2, 3, 299, 299) eprint('testing forward for %s' % model_name) print(model)