From 696c54738f2ceca89c9eb7eac18060b942e6dfab Mon Sep 17 00:00:00 2001 From: Haibin Lin Date: Sat, 14 Dec 2019 08:28:57 -0800 Subject: [PATCH] [BUGFIX] Fix trainer param order (#17068) * fix trainer param order * Update trainer.py * Update trainer.py * Update trainer.py --- python/mxnet/gluon/trainer.py | 5 ++++- tests/python/unittest/test_gluon_trainer.py | 16 ++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index 01f76d637a97..1ab86af2b93f 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -71,8 +71,11 @@ class Trainer(object): """ def __init__(self, params, optimizer, optimizer_params=None, kvstore='device', compression_params=None, update_on_kvstore=None): + param_list = [] if isinstance(params, (dict, ParameterDict)): - params = list(params.values()) + for key in sorted(list(params.keys())): + param_list.append(params[key]) + params = param_list if not isinstance(params, (list, tuple)): raise ValueError( "First argument must be a list or dict of Parameters, " \ diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py index 2d5874a8b97b..9f02733d0a25 100644 --- a/tests/python/unittest/test_gluon_trainer.py +++ b/tests/python/unittest/test_gluon_trainer.py @@ -291,3 +291,19 @@ def test_trainer_lr_sched(): assert trainer.learning_rate == lr, (lr, trainer.learning_rate, i) lr *= factor mx.nd.waitall() + +@with_seed() +def test_gluon_trainer_param_order(): + net = mx.gluon.nn.Sequential() + # layers may be added in a random order for all workers + layers = {'ones_': 1, 'zeros_': 0} + for name, init in layers.items(): + net.add(mx.gluon.nn.Dense(10, in_units=10, weight_initializer=mx.init.Constant(init), + use_bias=False, prefix=name)) + params = net.collect_params() + net.initialize() + trainer = gluon.Trainer(params, 'sgd') + for name, init in layers.items(): + expected_idx = 0 if name == 'ones_' else 1 + expected_name = name + 'weight' + assert trainer._params[expected_idx].name == expected_name