Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fixes test_gluon_gpu:test_lstmp
Browse files Browse the repository at this point in the history
  • Loading branch information
perdasilva committed Mar 26, 2019
1 parent 1aba52f commit e4e31c1
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions tests/python/gpu/test_gluon_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ def test_lstmp():
rtol, atol = 1e-2, 1e-2
batch_size, seq_len = 7, 11
input_size = 5
lstm_input = mx.nd.uniform(shape=(seq_len, batch_size, input_size), ctx=mx.gpu(0))
ctx=mx.gpu(0)
lstm_input = mx.nd.uniform(shape=(seq_len, batch_size, input_size), ctx=ctx)
shapes = {'i2h_weight': (hidden_size*4, input_size),
'h2h_weight': (hidden_size*4, projection_size),
'i2h_bias': (hidden_size*4,),
Expand Down Expand Up @@ -121,14 +122,14 @@ def test_lstmp():
print('checking gradient for {}'.format('lstm0_l0_'+k))
assert_almost_equal(layer_grad.asnumpy(), cell_grad.asnumpy(),
rtol=rtol, atol=atol)
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5), mx.nd.ones((8, 3, 20)))
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5, bidirectional=True), mx.nd.ones((8, 3, 20)), [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))])
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5), mx.nd.ones((8, 3, 20)), ctx=ctx)
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5, bidirectional=True), mx.nd.ones((8, 3, 20)), [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))], ctx=ctx)

check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5, projection_size=5), mx.nd.ones((8, 3, 20)),
run_only=True)
run_only=True, ctx=ctx)
check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5, projection_size=5),
mx.nd.ones((8, 3, 20)),
[mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))], run_only=True)
[mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))], run_only=True, ctx=ctx)


@with_seed()
Expand Down

0 comments on commit e4e31c1

Please sign in to comment.