diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 0229e9be15a0..7f1a7d09e585 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -398,25 +398,25 @@ def _allreduce_grads(self):
            return
        for i, param in enumerate(self._params):
            if param.grad_req != 'null':
-
+                idx = self._param2idx[param._uuid]
                grad_list = param.list_grad()
                # sparse gradients, call push and pull separately
                if grad_list[0].stype != 'default':
-                    self._kvstore.push(i, grad_list, priority=-i)
+                    self._kvstore.push(idx, grad_list, priority=-i)
                    if param._stype == 'default':
                        if self._update_on_kvstore:
                            pull_list = param.list_data()
                        else:
                            pull_list = param.list_grad()
-                        self._kvstore.pull(i, pull_list, priority=-i,
+                        self._kvstore.pull(idx, pull_list, priority=-i,
                                           ignore_sparse=self._distributed)
                else:
                    # allreduce dense gradients if not update_on_kvstore,
                    # otherwise push dense gradients, pull dense weights
                    if self._update_on_kvstore:
-                        self._kvstore.pushpull(i, grad_list, out=param.list_data(), priority=-i)
+                        self._kvstore.pushpull(idx, grad_list, out=param.list_data(), priority=-i)
                    else:
-                        self._kvstore.pushpull(i, grad_list, priority=-i)
+                        self._kvstore.pushpull(idx, grad_list, priority=-i)

    def update(self, batch_size, ignore_stale_grad=False):
        """Makes one step of parameter update.
diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py
index a83d1046af47..5c94fc8d003c 100644
--- a/tests/python/unittest/test_gluon_trainer.py
+++ b/tests/python/unittest/test_gluon_trainer.py
@@ -359,3 +359,43 @@ def test_trainer_allreduce_hybridsequential():
    out = net(mx.nd.ones((1, 1), ctx=ctx))
    out.backward()
    trainer.allreduce_grads()
+
+
+def test_trainer_share_parameters():
+    class Net(gluon.Block):
+        def __init__(self, **kwargs):
+            super(Net, self).__init__(**kwargs)
+            self.dense1 = gluon.nn.Dense(5, in_units=2, use_bias=False)
+            params = self.dense1.collect_params()
+            self.dense2 = gluon.nn.Dense(5, in_units=2,
+                                         use_bias=False).share_parameters(params)
+            self.dense3 = gluon.nn.Dense(5, in_units=5, use_bias=False)
+
+        def forward(self, x):
+            hidden = self.dense1(x) + self.dense2(x)
+            out = self.dense3(hidden)
+            return out
+
+    net = Net()
+    ctxes = [mx.cpu(0), mx.cpu(1)]
+    net.initialize(mx.init.One(), ctx=ctxes)
+    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 1})
+    data = mx.nd.array([[1, 1], [1, 1]])
+    xs = gluon.utils.split_and_load(data, ctxes)
+    ys = []
+    with mx.autograd.record():
+        for x in xs:
+            y = net(x)
+            ys.append(y)
+    for y in ys:
+        y.backward()
+    trainer.step(1)
+    params = net.collect_params()
+    shared_params = []
+    for param in params.values():
+        p = param.data(mx.cpu(0)).asnumpy()
+        if p.shape[1] == 2:
+            shared_params.append(p)
+
+    assert((shared_params[0] == shared_params[1]).all())
+
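
Note on the fix (not part of the patch): the diff replaces the positional loop index `i` with `idx = self._param2idx[param._uuid]` as the kvstore key. This suggests the Trainer registers each distinct parameter on the kvstore under a key recorded in `self._param2idx`, keyed by the parameter's `_uuid`; once shared parameters are deduplicated, the `enumerate` position in `_allreduce_grads` can drift from that registered key, so push/pull would target the wrong kvstore entry. Below is a toy sketch of that mismatch; `FakeParam` and the lists are hypothetical stand-ins, not MXNet code, and only the bookkeeping pattern (dedup by uuid at registration, look keys up by uuid later) mirrors the Trainer.

# Toy sketch of the index mismatch the patch fixes (hypothetical stand-ins).
class FakeParam:
    def __init__(self, uuid):
        self._uuid = uuid

shared = FakeParam("w-shared")     # one parameter object used by two layers
own = FakeParam("w-own")           # a parameter used by one layer
collected = [shared, shared, own]  # the shared parameter appears twice

# Registration: each distinct parameter gets a key equal to its first
# position; duplicates of a shared parameter are skipped.
param2idx, params = {}, []
for i, p in enumerate(collected):
    if p._uuid in param2idx:
        continue
    param2idx[p._uuid] = i
    params.append(p)

# Allreduce loop: the positional index of 'own' is now 1, but its
# registered key is 2. Looking up by uuid recovers the right key.
for i, p in enumerate(params):
    idx = param2idx[p._uuid]
    print(i, idx)                  # prints: 0 0, then 1 2

The new test exercises exactly this situation: `dense1` and `dense2` share one weight, so `trainer.step(1)` must route both replicas' gradients to the same kvstore entry, and the assertion checks that the two collected copies of the shared weight stay identical after the update.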