This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[BUGFIX] Fix trainer param order #17068

Merged · 4 commits · Dec 14, 2019
Changes shown below are from 1 of the 4 commits.
7 changes: 5 additions & 2 deletions python/mxnet/gluon/trainer.py
@@ -71,8 +71,11 @@ class Trainer(object):
     """
     def __init__(self, params, optimizer, optimizer_params=None, kvstore='device',
                  compression_params=None, update_on_kvstore=None):
-        if isinstance(params, (dict, ParameterDict)):
-            params = list(params.values())
+        param_list = []
+        if isinstance(params, mx.gluon.ParameterDict):
+            for key in sorted(list(params.keys())):
+                param_list.append(params[key])
+            params = param_list
         if not isinstance(params, (list, tuple)):
             raise ValueError(
                 "First argument must be a list or dict of Parameters, " \
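For context, here is a minimal standalone sketch (not part of the patch) of the ordering guarantee the new code relies on: iterating keys in sorted order yields the same parameter list on every worker, regardless of the order in which the entries were inserted. The plain dict of name/value pairs below is only an illustrative stand-in for gluon's ParameterDict.

# Illustrative stand-in for the change in Trainer.__init__: build the
# parameter list from sorted keys so insertion order no longer matters.
def ordered_params(params):
    return [params[key] for key in sorted(params.keys())]

# Two "workers" that happened to register the same parameters in different orders.
worker_a = {'ones_weight': 1, 'zeros_weight': 0}
worker_b = {'zeros_weight': 0, 'ones_weight': 1}

# Both produce the identical, deterministic ordering: ones_weight before zeros_weight.
assert ordered_params(worker_a) == ordered_params(worker_b) == [1, 0]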
16 changes: 16 additions & 0 deletions tests/python/unittest/test_gluon_trainer.py
@@ -291,3 +291,19 @@ def test_trainer_lr_sched():
         assert trainer.learning_rate == lr, (lr, trainer.learning_rate, i)
         lr *= factor
     mx.nd.waitall()
+
+@with_seed()
+def test_gluon_trainer_param_order():
+    net = mx.gluon.nn.Sequential()
+    # layers may be added in a random order for all workers
+    layers = {'ones_': 1, 'zeros_': 0}
+    for name, init in layers.items():
+        net.add(mx.gluon.nn.Dense(10, in_units=10, weight_initializer=mx.init.Constant(init),
+                                  use_bias=False, prefix=name))
+    params = net.collect_params()
+    net.initialize()
+    trainer = gluon.Trainer(params, 'sgd')
+    for name, init in layers.items():
+        expected_idx = 0 if name == 'ones_' else 1
+        expected_name = name + 'weight'
+        assert trainer._params[expected_idx].name == expected_name
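Beyond the unit test above, a hedged usage sketch of what the fix guarantees in practice (assuming MXNet 1.x Gluon; the helper name and layer prefixes below are illustrative, not from the PR): two networks whose layers were added in opposite orders end up with Trainers whose parameter lists are ordered identically. Like the test, it inspects the internal trainer._params attribute.

import mxnet as mx
from mxnet import gluon

def trainer_param_names(layer_order):
    # Build a small network with layers added in the given order, then
    # return the parameter names in the order the Trainer holds them.
    net = gluon.nn.Sequential()
    for prefix in layer_order:
        net.add(gluon.nn.Dense(10, in_units=10, use_bias=False, prefix=prefix))
    net.initialize()
    trainer = gluon.Trainer(net.collect_params(), 'sgd')
    return [p.name for p in trainer._params]

# With the fix, insertion order no longer leaks into the trainer's parameter order.
assert trainer_param_names(['ones_', 'zeros_']) == trainer_param_names(['zeros_', 'ones_'])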