
Commit

Update language model (#19)
* update lmexp

* update word_language_model.py

* update word_language_model.py

* update with lm_decay

* update word_language_model.py with the updated hidden-state size and StandardRNN (tied and dropout arguments swapped); update base.py with the rnn_relu config; update lm.py with the awd_lstm_lm_1150 pretrained setting and new sentiment analysis and LM examples

* remove lm test file
cgraywang authored and szha committed Mar 27, 2018
1 parent 14d7499 commit 7b604ab
Showing 3 changed files with 9 additions and 7 deletions.
12 changes: 7 additions & 5 deletions example/gluon/word_language_model.py
@@ -22,7 +22,7 @@
 import mxnet as mx
 from mxnet import gluon, autograd
 from mxnet.gluon import data, text
-from mxnet.gluon.model_zoo.text.lm import SimpleRNN, AWDRNN
+from mxnet.gluon.model_zoo.text.lm import StandardRNN, AWDRNN
 
 parser = argparse.ArgumentParser(description='MXNet Autograd RNN/LSTM Language Model on Wikitext-2.')
 parser.add_argument('--model', type=str, default='lstm',
@@ -68,6 +68,8 @@
                     help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu. (the result of multi-gpu training might be slightly different compared to single-gpu training, still need to be finalized)')
 args = parser.parse_args()
 
+print(args)
+
 
 ###############################################################################
 # Load data
@@ -82,7 +84,7 @@
 def get_frequencies(dataset):
     return collections.Counter(x for tup in dataset for x in tup[0] if x)
 
-vocab = text.vocab.Vocabulary(get_frequencies(train_dataset))
+vocab = text.vocab.Vocabulary(get_frequencies(train_dataset), reserved_tokens=['<eos>', '<pad>'])
 def index_tokens(data, label):
     return vocab[data], vocab[label]
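Reserving '<eos>' and '<pad>' guarantees that the end-of-sentence and padding markers get stable indices even when they never appear in the frequency counter. A minimal, self-contained sketch of that behavior (the build_vocab helper below is a hypothetical stand-in, not the Vocabulary class from this branch):

    import collections

    def build_vocab(counter, reserved_tokens, unknown_token='<unk>'):
        # Reserved tokens come right after the unknown token, ahead of corpus words.
        idx_to_token = [unknown_token] + list(reserved_tokens)
        idx_to_token += [tok for tok, _ in counter.most_common() if tok not in idx_to_token]
        return {tok: idx for idx, tok in enumerate(idx_to_token)}

    counter = collections.Counter('the cat sat on the mat'.split())
    vocab = build_vocab(counter, reserved_tokens=['<eos>', '<pad>'])
    print(vocab['<pad>'])  # 2 -- fixed index, independent of corpus frequencies
    print(vocab['the'])    # 3 -- most frequent corpus word follows the reserved tokens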

@@ -124,8 +126,8 @@ def index_tokens(data, label):
     model = AWDRNN(args.model, len(vocab), args.emsize, args.nhid, args.nlayers,
                    args.tied, args.dropout, args.weight_dropout, args.dropout_h, args.dropout_i)
 else:
-    model = SimpleRNN(args.model, len(vocab), args.emsize, args.nhid, args.nlayers,
-                      args.tied, args.dropout)
+    model = StandardRNN(args.model, len(vocab), args.emsize, args.nhid, args.nlayers, args.dropout,
+                        args.tied)
 
 model.initialize(mx.init.Xavier(), ctx=context)
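StandardRNN takes dropout before the tie-weights flag, the reverse of the old SimpleRNN ordering, so positional call sites have to swap the last two arguments as above. A sketch of the safer keyword form (the parameter names dropout and tie_weights are assumptions, not confirmed by this diff):

    model = StandardRNN(args.model, len(vocab), args.emsize, args.nhid, args.nlayers,
                        dropout=args.dropout, tie_weights=args.tied)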

@@ -169,7 +171,7 @@ def train():
     for epoch in range(args.epochs):
         total_L = 0.0
         start_epoch_time = time.time()
-        hiddens = [model.begin_state(args.batch_size, func=mx.nd.zeros, ctx=ctx) for ctx in context]
+        hiddens = [model.begin_state(args.batch_size//len(context), func=mx.nd.zeros, ctx=ctx) for ctx in context]
         for i, (data, target) in enumerate(train_data):
             start_batch_time = time.time()
             data = data.T
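When training on multiple devices, each context only sees its slice of the batch, so the initial hidden state has to be sized to the per-device batch rather than the global one. A minimal sketch (assuming the batch is split with gluon.utils.split_and_load along the batch axis, and TNC layout after the data.T transpose):

    from mxnet import gluon, nd

    # Split one global batch of shape (seq_len, batch_size) across the devices.
    data_list = gluon.utils.split_and_load(data, context, batch_axis=1)
    hiddens = [model.begin_state(args.batch_size // len(context), func=nd.zeros, ctx=ctx)
               for ctx in context]
    # data_list[i] and hiddens[i] now share the same per-device batch dimension.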
2 changes: 1 addition & 1 deletion python/mxnet/gluon/model_zoo/text/base.py
@@ -118,7 +118,7 @@ def get_rnn_cell(mode, num_layers, input_size, hidden_size,
 def get_rnn_layer(mode, num_layers, input_size, hidden_size, dropout, weight_dropout):
     """create rnn layer given specs"""
     if mode == 'rnn_relu':
-        block = rnn.RNN(hidden_size, 'relu', num_layers, dropout=dropout,
+        block = rnn.RNN(hidden_size, num_layers, 'relu', dropout=dropout,
                         input_size=input_size)
     elif mode == 'rnn_tanh':
         block = rnn.RNN(hidden_size, num_layers, dropout=dropout,
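Gluon's rnn.RNN expects hidden_size, then num_layers, then the activation, so the old call passed 'relu' where the layer count belongs. A sketch of the equivalent call with the activation spelled out as a keyword (assuming the standard mxnet.gluon.rnn.RNN signature; the other names come from the surrounding get_rnn_layer arguments):

    from mxnet.gluon import rnn

    block = rnn.RNN(hidden_size, num_layers, activation='relu',
                    dropout=dropout, input_size=input_size)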
2 changes: 1 addition & 1 deletion python/mxnet/gluon/model_zoo/text/lm.py
@@ -209,7 +209,7 @@ def awd_lstm_lm_1150(dataset_name=None, vocab=None, pretrained=False, ctx=cpu(),
                        'tie_weights': True,
                        'dropout': 0.4,
                        'weight_drop': 0.5,
-                       'drop_h': 0.3,
+                       'drop_h': 0.2,
                        'drop_i': 0.65}
     assert all(k not in kwargs for k in predefined_args), \
         "Cannot override predefined model settings."
