diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index 88b436a0deb2..9eeeec749211 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -88,7 +88,8 @@ def test_lstmp():
     rtol, atol = 1e-2, 1e-2
     batch_size, seq_len = 7, 11
     input_size = 5
-    lstm_input = mx.nd.uniform(shape=(seq_len, batch_size, input_size), ctx=mx.gpu(0))
+    ctx = mx.gpu(0)
+    lstm_input = mx.nd.uniform(shape=(seq_len, batch_size, input_size), ctx=ctx)
     shapes = {'i2h_weight': (hidden_size*4, input_size),
               'h2h_weight': (hidden_size*4, projection_size),
               'i2h_bias': (hidden_size*4,),
@@ -101,8 +102,8 @@ def test_lstmp():
                                    projection_size=projection_size,
                                    input_size=input_size,
                                    prefix='lstm0_l0_')
-    lstm_layer.initialize(ctx=mx.gpu(0))
-    lstm_cell.initialize(ctx=mx.gpu(0))
+    lstm_layer.initialize(ctx=ctx)
+    lstm_cell.initialize(ctx=ctx)
     layer_params = lstm_layer.collect_params()
     cell_params = lstm_cell.collect_params()
     for k, v in weights.items():
@@ -121,14 +122,14 @@ def test_lstmp():
         print('checking gradient for {}'.format('lstm0_l0_'+k))
         assert_almost_equal(layer_grad.asnumpy(), cell_grad.asnumpy(),
                             rtol=rtol, atol=atol)
-    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5), mx.nd.ones((8, 3, 20)))
-    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5, bidirectional=True), mx.nd.ones((8, 3, 20)), [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))])
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5), mx.nd.ones((8, 3, 20)), ctx=ctx)
+    check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, projection_size=5, bidirectional=True), mx.nd.ones((8, 3, 20)), [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))], ctx=ctx)
     check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, dropout=0.5, projection_size=5), mx.nd.ones((8, 3, 20)),
-                            run_only=True)
+                            run_only=True, ctx=ctx)
     check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5, projection_size=5), mx.nd.ones((8, 3, 20)),
-                            [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))], run_only=True)
+                            [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))], run_only=True, ctx=ctx)
 
 
 @with_seed()