Skip to content

Commit

Permalink
Add logging to test_bucketing.py
Browse files: browse the repository at this point in the history
  • Loading branch information
reminisce committed Jun 3, 2017
1 parent be02b28 commit 1236d7e
Showing 1 changed file with 29 additions and 21 deletions.
50 changes: 29 additions & 21 deletions tests/python/train/test_bucketing.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,14 +4,19 @@
import random
from random import randint


def test_bucket_module():
import logging
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.DEBUG, format=head)
console = logging.StreamHandler()
console.setLevel(logging.DEBUG)
logging.getLogger('').addHandler(console)

class DummySentenceIter(mx.rnn.BucketSentenceIter):
"""Dummy sentence iterator to output sentences the same as input.
"""

def __init__(self, sentences, batch_size, buckets=None, invalid_label=-1,
data_name='data', label_name='l2_label', dtype='float32',
layout='NTC'):
Expand Down Expand Up @@ -40,7 +45,7 @@ def reset(self):
num_layers = 2
len_vocab = 100
buckets = [10, 20, 30, 40, 50, 60]

invalid_label = 0
num_sentence = 2500

Expand All @@ -57,14 +62,14 @@ def reset(self):
train_sent.append(train_sentence)
val_sent.append(val_sentence)

data_train = DummySentenceIter(train_sent, batch_size, buckets=buckets,
invalid_label=invalid_label)
data_val = DummySentenceIter(val_sent, batch_size, buckets=buckets,
invalid_label=invalid_label)
data_train = DummySentenceIter(train_sent, batch_size, buckets=buckets,
invalid_label=invalid_label)
data_val = DummySentenceIter(val_sent, batch_size, buckets=buckets,
invalid_label=invalid_label)

stack = mx.rnn.SequentialRNNCell()
for i in range(num_layers):
stack.add(mx.rnn.LSTMCell(num_hidden=num_hidden, prefix='lstm_l%d_'%i))
stack.add(mx.rnn.LSTMCell(num_hidden=num_hidden, prefix='lstm_l%d_' % i))

def sym_gen(seq_len):
data = mx.sym.Variable('data')
Expand All @@ -77,31 +82,34 @@ def sym_gen(seq_len):

pred = mx.sym.Reshape(outputs, shape=(-1, num_hidden))
pred = mx.sym.FullyConnected(data=pred, num_hidden=1, name='pred')
pred = mx.sym.reshape(pred, shape= (batch_size, -1))
pred = mx.sym.reshape(pred, shape=(batch_size, -1))
loss = mx.sym.LinearRegressionOutput(pred, label, name='l2_loss')

return loss, ('data',), ('l2_label',)

contexts = mx.cpu(0)

model = mx.mod.BucketingModule(
sym_gen = sym_gen,
default_bucket_key = data_train.default_bucket_key,
context = contexts)
sym_gen=sym_gen,
default_bucket_key=data_train.default_bucket_key,
context=contexts)

logging.info('Begin fit...')
model.fit(
train_data = data_train,
eval_data = data_val,
eval_metric = mx.metric.MSE(),
kvstore = 'device',
optimizer = 'sgd',
optimizer_params = { 'learning_rate': 0.01,
'momentum': 0,
'wd': 0.00001 },
initializer = mx.init.Xavier(factor_type="in", magnitude=2.34),
num_epoch = num_epochs,
batch_end_callback = mx.callback.Speedometer(batch_size, 50))
train_data=data_train,
eval_data=data_val,
eval_metric=mx.metric.MSE(),
kvstore='device',
optimizer='sgd',
optimizer_params={'learning_rate': 0.01,
'momentum': 0,
'wd': 0.00001},
initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
num_epoch=num_epochs,
batch_end_callback=mx.callback.Speedometer(batch_size, 50))
logging.info('Finished fit...')
assert model.score(data_val, mx.metric.MSE())[0][1] < 15, "High mean square error."


if __name__ == "__main__":
test_bucket_module()

0 comments on commit 1236d7e

Please sign in to comment.