From 9cd5bc64262589bc12f341c80a22e91a6a92ff23 Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Sat, 11 Jan 2020 15:08:57 +0800 Subject: [PATCH] fix lstm layer with projection save params (#17266) --- python/mxnet/gluon/rnn/rnn_layer.py | 2 +- tests/python/gpu/test_gluon_gpu.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py index b3cc596282a7..11d45815e37f 100644 --- a/python/mxnet/gluon/rnn/rnn_layer.py +++ b/python/mxnet/gluon/rnn/rnn_layer.py @@ -124,7 +124,7 @@ def __repr__(self): def _collect_params_with_prefix(self, prefix=''): if prefix: prefix += '.' - pattern = re.compile(r'(l|r)(\d)_(i2h|h2h)_(weight|bias)\Z') + pattern = re.compile(r'(l|r)(\d)_(i2h|h2h|h2r)_(weight|bias)\Z') def convert_key(m, bidirectional): # for compatibility with old parameter format d, l, g, t = [m.group(i) for i in range(1, 5)] if bidirectional: diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py index fc650294a538..d6070d656046 100644 --- a/tests/python/gpu/test_gluon_gpu.py +++ b/tests/python/gpu/test_gluon_gpu.py @@ -137,6 +137,8 @@ def test_lstmp(): check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5, projection_size=5), mx.nd.ones((8, 3, 20)), [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))], run_only=True, ctx=ctx) + lstm_layer.save_parameters('gpu_tmp.params') + lstm_layer.load_parameters('gpu_tmp.params') @with_seed()