Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Adds context parameter to check_rnn_layer_forward calls in test_lstmp (
Browse files Browse the repository at this point in the history
  • Loading branch information
perdasilva authored and lanking520 committed Mar 27, 2019
1 parent 3d20f2a commit 67c10f9
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions tests/python/gpu/test_gluon_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ def test_lstmp():
rtol, atol = 1e-2, 1e-2
batch_size, seq_len = 7, 11
input_size = 5
lstm_input = mx.nd.uniform(shape=(seq_len, batch_size, input_size), ctx=mx.gpu(0))
ctx = mx.gpu(0)
lstm_input = mx.nd.uniform(shape=(seq_len, batch_size, input_size), ctx=ctx)
shapes = {'i2h_weight': (hidden_size*4, input_size),
'h2h_weight': (hidden_size*4, projection_size),
'i2h_bias': (hidden_size*4,),
Expand All @@ -101,8 +102,8 @@ def test_lstmp():
projection_size=projection_size,
input_size=input_size,
prefix='lstm0_l0_')
lstm_layer.initialize(ctx=mx.gpu(0))
lstm_cell.initialize(ctx=mx.gpu(0))
lstm_layer.initialize(ctx=ctx)
lstm_cell.initialize(ctx=ctx)
layer_params = lstm_layer.collect_params()
cell_params = lstm_cell.collect_params()
for k, v in weights.items():
Expand All @@ -121,14 +122,14 @@ def test_lstmp():
print('checking gradient for {}'.format('lstm0_l0_'+k))
assert_almost_equal(layer_grad.asnumpy(), cell_grad.asnumpy(),
rtol=rtol, atol=atol)
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5), mx.nd.ones((8, 3, 20)))
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5, bidirectional=True), mx.nd.ones((8, 3, 20)), [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))])
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5), mx.nd.ones((8, 3, 20)), ctx=ctx)
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5, bidirectional=True), mx.nd.ones((8, 3, 20)), [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))], ctx=ctx)

check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5, projection_size=5), mx.nd.ones((8, 3, 20)),
run_only=True)
run_only=True, ctx=ctx)
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5, projection_size=5),
mx.nd.ones((8, 3, 20)),
[mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))], run_only=True)
[mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))], run_only=True, ctx=ctx)


@with_seed()
Expand Down

0 comments on commit 67c10f9

Please sign in to comment.