add a test in gluon.
zheng-da committed May 22, 2018
1 parent 767857f commit 471e6d2
Showing 1 changed file with 41 additions and 21 deletions.
tests/python/unittest/test_gluon_rnn.py (41 additions, 21 deletions)
@@ -18,9 +18,10 @@
 import mxnet as mx
 from mxnet import gluon
 import numpy as np
+import copy
 from numpy.testing import assert_allclose
 import unittest
-from mxnet.test_utils import almost_equal
+from mxnet.test_utils import almost_equal, assert_almost_equal


 def test_rnn():
@@ -36,33 +37,52 @@ def test_rnn():
     assert outs == [(10, 100), (10, 100), (10, 100)]


-class RNNLayer(gluon.HybridBlock):
-    def __init__(self, prefix=None, params=None):
-        super(RNNLayer, self).__init__(prefix=prefix, params=params)
-        self.cell = gluon.contrib.rnn.RNNCell(100, prefix='rnn_')
+class TestRNNLayer(gluon.HybridBlock):
+    def __init__(self, hidden_size, prefix=None, params=None):
+        super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
+        self.cell = gluon.rnn.RNNCell(hidden_size, prefix='rnn_')

-    def hybrid_forward(self, F, inputs, states=None):
-        return self.cell.unroll(inputs, states)
+    def hybrid_forward(self, F, inputs, states):
+        states = [states]
+        out, states = F.contrib.foreach(self.cell, inputs, states)
+        return out

 def test_contrib_rnn():
     contrib_cell = gluon.contrib.rnn.RNNCell(100, prefix='rnn_')
     inputs = mx.sym.Variable('rnn_data')
     contrib_outputs, _ = contrib_cell.unroll(inputs)
     assert sorted(contrib_cell.collect_params().keys()) == ['rnn_h2h_bias', 'rnn_h2h_weight',
                                                             'rnn_i2h_bias', 'rnn_i2h_weight']

     args, outs, auxs = contrib_outputs.infer_shape(rnn_data=(3, 10, 50))
     assert outs == [(3, 10, 100)]

-    rnn_data = mx.nd.normal(loc=0, scale=1, shape=(3, 10, 50))
-    layer = RNNLayer()
+    batch_size = 10
+    hidden_size = 100
+    rnn_data = mx.nd.normal(loc=0, scale=1, shape=(5, batch_size, 50))
+    states = mx.nd.normal(loc=0, scale=1, shape=(batch_size, hidden_size))
+    layer = TestRNNLayer(hidden_size)
     layer.initialize(ctx=mx.cpu(0))
-    res1 = layer(rnn_data)
+    res1 = layer(rnn_data, states)
+    params1 = layer.collect_params()
+    orig_params1 = copy.deepcopy(params1)
+
+    trainer = gluon.Trainer(params1, 'sgd', {'learning_rate' : 0.03})
+    with mx.autograd.record():
+        res1 = layer(rnn_data, states)
+    res1.backward()
+    trainer.step(batch_size)

-    layer = RNNLayer()
+    layer = TestRNNLayer(hidden_size)
     layer.initialize(ctx=mx.cpu(0))
     layer.hybridize()
-    res2 = layer(rnn_data)
+    res2 = layer(rnn_data, states)
+    params2 = layer.collect_params()
+    for key, val in orig_params1.items():
+        params2[key].set_data(val.data())
+
+    trainer = gluon.Trainer(params2, 'sgd', {'learning_rate' : 0.03})
+    with mx.autograd.record():
+        res2 = layer(rnn_data, states)
+    assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
+    res2.backward()
+    trainer.step(batch_size)
+
+    for key, val in params1.items():
+        weight1 = val.data()
+        weight2 = params2[key].data()
+        assert_almost_equal(weight1.asnumpy(), weight2.asnumpy(), rtol=0.001, atol=0.0001)


 def test_lstm():
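For readers unfamiliar with the operator under test: F.contrib.foreach scans a callable (here the RNNCell) over axis 0 of the input, threading the state list through each step, so the hybridized graph contains a single loop operator instead of an unrolled chain of cells. Below is a minimal imperative sketch of the same pattern, assuming the NDArray-mode counterpart mx.nd.contrib.foreach; it is illustrative only and not part of the commit.

    import mxnet as mx
    from mxnet import gluon

    # Illustrative sketch (not from the commit): unroll an RNNCell with
    # contrib.foreach in imperative mode, mirroring hybrid_forward above.
    cell = gluon.rnn.RNNCell(100, prefix='rnn_')
    cell.initialize()

    # Shapes mirror the test: (seq_len, batch_size, input_size).
    data = mx.nd.normal(loc=0, scale=1, shape=(5, 10, 50))
    init_h = mx.nd.zeros((10, 100))  # a vanilla RNNCell carries one state array

    # foreach slices `data` along axis 0, calls cell(step_input, states) at
    # each step, and stacks the per-step outputs along a new leading axis.
    outputs, final_states = mx.nd.contrib.foreach(cell, data, [init_h])
    print(outputs.shape)  # (5, 10, 100)

The test itself then checks that the imperative path and the hybridized (symbolic foreach) path agree on their outputs and, after one SGD step from identical initial weights, on the updated parameters.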
