
Commit e947c6f

Removes default value for ctx parameter in check_rnn_layer_forward and refactors tests
perdasilva committed Mar 26, 2019
1 parent f8a0dbc commit e947c6f
Showing 2 changed files with 26 additions and 23 deletions.
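In practice, the change makes the device an explicit, mandatory argument at every call site. A minimal sketch of the new calling convention (a sketch only: check_rnn_layer_forward is the test helper refactored in the diff below, and mx.context.num_gpus() is assumed to be available in this MXNet version):

```python
import mxnet as mx
from mxnet import gluon

# The caller now chooses the device; the helper no longer defaults to mx.cpu().
ctx = mx.gpu(0) if mx.context.num_gpus() > 0 else mx.cpu()

layer = gluon.rnn.LSTM(10, 2)
inputs = mx.nd.ones((8, 3, 20))  # (seq_len, batch, input_size), as in the tests below

# New signature: ctx is the third positional argument, before states.
check_rnn_layer_forward(layer, inputs, ctx)
```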
13 changes: 8 additions & 5 deletions tests/python/gpu/test_gluon_gpu.py
@@ -88,7 +88,8 @@ def test_lstmp():
     rtol, atol = 1e-2, 1e-2
     batch_size, seq_len = 7, 11
     input_size = 5
-    lstm_input = mx.nd.uniform(shape=(seq_len, batch_size, input_size), ctx=mx.gpu(0))
+    ctx=mx.gpu(0)
+    lstm_input = mx.nd.uniform(shape=(seq_len, batch_size, input_size), ctx=ctx)
     shapes = {'i2h_weight': (hidden_size*4, input_size),
               'h2h_weight': (hidden_size*4, projection_size),
               'i2h_bias': (hidden_size*4,),
@@ -101,8 +102,8 @@ def test_lstmp():
                                        projection_size=projection_size,
                                        input_size=input_size,
                                        prefix='lstm0_l0_')
-    lstm_layer.initialize(ctx=mx.gpu(0))
-    lstm_cell.initialize(ctx=mx.gpu(0))
+    lstm_layer.initialize(ctx=ctx)
+    lstm_cell.initialize(ctx=ctx)
     layer_params = lstm_layer.collect_params()
     cell_params = lstm_cell.collect_params()
     for k, v in weights.items():
@@ -121,13 +122,15 @@ def test_lstmp():
         print('checking gradient for {}'.format('lstm0_l0_'+k))
         assert_almost_equal(layer_grad.asnumpy(), cell_grad.asnumpy(),
                             rtol=rtol, atol=atol)
-    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5), mx.nd.ones((8, 3, 20)))
-    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5, bidirectional=True), mx.nd.ones((8, 3, 20)), [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))])
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5), mx.nd.ones((8, 3, 20)), ctx)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5, bidirectional=True), mx.nd.ones((8, 3, 20)), ctx, [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))])
 
     check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5, projection_size=5), mx.nd.ones((8, 3, 20)),
+                            ctx,
                             run_only=True)
     check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5, projection_size=5),
                             mx.nd.ones((8, 3, 20)),
+                            ctx,
                             [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))], run_only=True)
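A note on the state shapes above: with projection_size set, the recurrent (hidden) state is projected down to projection_size while the cell state keeps the full hidden_size, which is why the bidirectional test passes [ones((4, 3, 5)), ones((4, 3, 10))], with 4 = num_layers x num_directions. A standalone sketch (assuming a GPU is available, since these projection tests live in the GPU suite and the projection path was cuDNN-backed):

```python
import mxnet as mx
from mxnet import gluon

ctx = mx.gpu(0)  # assumption: projection_size requires the cuDNN (GPU) implementation

lstm = gluon.rnn.LSTM(10, 2, projection_size=5, bidirectional=True)
lstm.initialize(ctx=ctx)

x = mx.nd.ones((8, 3, 20), ctx=ctx)   # (seq_len, batch, input_size)
h0 = mx.nd.ones((4, 3, 5), ctx=ctx)   # hidden state, projected to projection_size=5
c0 = mx.nd.ones((4, 3, 10), ctx=ctx)  # cell state, full hidden_size=10
out, states = lstm(x, [h0, c0])
print(out.shape)  # (8, 3, 10): 2 directions * projection_size 5
```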


36 changes: 18 additions & 18 deletions tests/python/unittest/test_gluon_rnn.py
@@ -427,7 +427,7 @@ def hybrid_forward(self, F, seq):
     assert_almost_equal(output1.asnumpy(), output2.asnumpy())
 
 
-def check_rnn_layer_forward(layer, inputs, states=None, run_only=False, ctx=mx.cpu()):
+def check_rnn_layer_forward(layer, inputs, ctx, states=None, run_only=False):
     layer.collect_params().initialize(ctx=ctx)
     inputs = inputs.as_in_context(ctx)
     inputs.attach_grad()
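The fold hides the rest of the helper's body. For orientation, here is a condensed sketch of what a check like this typically does, reconstructed around the visible lines (the actual body may differ in details such as shape assertions and tolerances):

```python
import mxnet as mx

def check_rnn_layer_forward(layer, inputs, ctx, states=None, run_only=False):
    layer.collect_params().initialize(ctx=ctx)
    inputs = inputs.as_in_context(ctx)
    inputs.attach_grad()
    if states is not None:
        states = ([s.as_in_context(ctx) for s in states]
                  if isinstance(states, (list, tuple)) else states.as_in_context(ctx))

    # Imperative pass: record the forward and backprop to populate input grads.
    with mx.autograd.record():
        out = layer(inputs) if states is None else layer(inputs, states)[0]
        out = out.mean()
    out.backward()
    np_out, np_dx = out.asnumpy(), inputs.grad.asnumpy()

    # Hybridized pass: repeat and, unless run_only, check both passes agree.
    layer.hybridize()
    with mx.autograd.record():
        out2 = layer(inputs) if states is None else layer(inputs, states)[0]
        out2 = out2.mean()
    out2.backward()
    if not run_only:
        mx.test_utils.assert_almost_equal(np_out, out2.asnumpy())
        mx.test_utils.assert_almost_equal(np_dx, inputs.grad.asnumpy())
```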
@@ -476,27 +476,27 @@ def check_rnn_layer_forward(layer, inputs, states=None, run_only=False, ctx=mx.cpu()):

 def run_rnn_layers(dtype, dtype2, ctx=mx.cpu()):
 
-    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx=ctx)
-    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype), ctx=ctx)
-    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2,dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx=ctx)
-    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2,dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), [mx.nd.ones((4, 3, 10), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype)],ctx=ctx)
-    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dtype=dtype, ), mx.nd.ones((8, 3, 20), dtype=dtype),ctx=ctx)
-    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype),ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx)
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), ctx, mx.nd.ones((4, 3, 10), dtype=dtype))
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2,dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2,dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), ctx, [mx.nd.ones((4, 3, 10), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype)])
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx)
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dtype=dtype, bidirectional=True), mx.nd.ones((8, 3, 20), dtype=dtype), ctx, mx.nd.ones((4, 3, 10), dtype=dtype))
 
 
-    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype, dropout=0.5), mx.nd.ones((8, 3, 20), dtype=dtype),
-                            run_only=True, ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.RNN(10, 2, dtype=dtype, dropout=0.5), mx.nd.ones((8, 3, 20), dtype=dtype), ctx,
+                            run_only=True)
     check_rnn_layer_forward(gluon.rnn.RNN(10, 2, bidirectional=True, dropout=0.5, dtype=dtype),
-                            mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype), run_only=True, ctx=ctx)
-    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype),
-                            run_only=True, ctx=ctx)
+                            mx.nd.ones((8, 3, 20), dtype=dtype), ctx, mx.nd.ones((4, 3, 10), dtype=dtype), run_only=True)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx,
+                            run_only=True)
     check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5, dtype=dtype),
-                            mx.nd.ones((8, 3, 20), dtype=dtype),
-                            [mx.nd.ones((4, 3, 10), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype)], run_only=True, ctx=ctx)
-    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dropout=0.5, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype),
-                            run_only=True, ctx=ctx)
+                            mx.nd.ones((8, 3, 20), dtype=dtype), ctx,
+                            [mx.nd.ones((4, 3, 10), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype)], run_only=True)
+    check_rnn_layer_forward(gluon.rnn.GRU(10, 2, dropout=0.5, dtype=dtype), mx.nd.ones((8, 3, 20), dtype=dtype), ctx,
+                            run_only=True)
     check_rnn_layer_forward(gluon.rnn.GRU(10, 2, bidirectional=True, dropout=0.5, dtype=dtype),
-                            mx.nd.ones((8, 3, 20), dtype=dtype), mx.nd.ones((4, 3, 10), dtype=dtype), run_only=True, ctx=ctx)
+                            mx.nd.ones((8, 3, 20), dtype=dtype), ctx, mx.nd.ones((4, 3, 10), dtype=dtype), run_only=True)
 
     net = gluon.nn.Sequential()
     net.add(gluon.rnn.LSTM(10, bidirectional=True, dtype=dtype2))
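The churn above is mechanical: because ctx moved into third position, the optional states argument slides to fourth and the ctx=ctx keyword disappears from every call. Schematically (names as in the diff):

```python
# Before: states third, ctx a trailing keyword defaulting to mx.cpu()
check_rnn_layer_forward(layer, inputs, states, run_only=True, ctx=ctx)

# After: ctx third and mandatory, states fourth
check_rnn_layer_forward(layer, inputs, ctx, states, run_only=True)
```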
@@ -628,7 +628,7 @@ def test_cell_fill_shape():
 def test_layer_fill_shape():
     layer = gluon.rnn.LSTM(10)
     layer.hybridize()
-    check_rnn_layer_forward(layer, mx.nd.ones((3, 2, 7)))
+    check_rnn_layer_forward(layer, mx.nd.ones((3, 2, 7)), mx.cpu())
     print(layer)
     assert layer.l0_i2h_weight.shape[1] == 7, layer.l0_i2h_weight.shape[1]
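The final assertion leans on Gluon's deferred shape inference: LSTM(10) is constructed without an input size, and the first forward pass fills in the second dimension of i2h_weight from the data. A minimal standalone illustration:

```python
import mxnet as mx
from mxnet import gluon

layer = gluon.rnn.LSTM(10)        # input size left unspecified at construction
layer.initialize()
layer(mx.nd.ones((3, 2, 7)))      # first forward infers input size = 7
print(layer.l0_i2h_weight.shape)  # (40, 7): 4 LSTM gates * hidden 10, input 7
```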

