diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py index 73fca6050acf..0c6aae921352 100644 --- a/python/mxnet/gluon/parameter.py +++ b/python/mxnet/gluon/parameter.py @@ -310,14 +310,16 @@ def _init_grad(self): self._grad, self.grad_req) def _reduce(self): - """Reduce data from multiple context.""" + """Reduce data from multiple context to cpu.""" + ctx = context.cpu() if self._stype == 'default': block = self.list_data() - data = ndarray.add_n(*(w.copyto(context.cpu()) for w in block)) / len(block) + data = ndarray.add_n(*(w.copyto(ctx) for w in block)) / len(block) else: # fetch all rows for 'row_sparse' param - all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=context.cpu()) - data = self.row_sparse_data(all_row_ids) + all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=ctx) + data = ndarray.zeros(self.shape, stype='row_sparse', ctx=ctx) + self._trainer._row_sparse_pull(self, data, all_row_ids) return data def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(), diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index ef20109021aa..02d68f0c39cb 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -152,7 +152,6 @@ def _reset_kvstore(self): def _init_kvstore(self): """Create kvstore.""" - arg_arrays = {} config = self._kvstore_params if self._contains_sparse: kvstore, update_on_kvstore = _create_sparse_kvstore(config['kvstore']) @@ -162,6 +161,7 @@ def _init_kvstore(self): "gradients and/or sparse weights are present for " "Parameter '%s'."%param.name) else: + arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params} kvstore, update_on_kvstore = _create_kvstore(config['kvstore'], len(self._contexts), arg_arrays) if config['update_on_kvstore'] is not None: diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index 06b91fadcee4..fcb7c97b9809 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -118,10 +118,14 @@ def split_and_load(data, ctx_list, batch_axis=0, even_split=True): def clip_global_norm(arrays, max_norm): """Rescales NDArrays so that the sum of their 2-norm is smaller than `max_norm`. """ + def _norm(array): + if array.stype == 'default': + x = array.reshape((-1,)) + return ndarray.dot(x, x) + return array.norm().square() assert len(arrays) > 0 ctx = arrays[0].context - total_norm = ndarray.add_n(*[ndarray.dot(x, x).as_in_context(ctx) - for x in (arr.reshape((-1,)) for arr in arrays)]) + total_norm = ndarray.add_n(*[_norm(arr).as_in_context(ctx) for arr in arrays]) total_norm = ndarray.sqrt(total_norm).asscalar() if not np.isfinite(total_norm): warnings.warn(UserWarning('nan or inf is detected. Clipping results will be undefined.'), diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index e9259fde4b38..e540657ed8f3 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -90,15 +90,16 @@ def test_parameter_invalid_access(): @with_seed() def test_paramdict(): + ctx = mx.cpu(1) params0 = gluon.ParameterDict('net_') params0.get('w0', shape=(10, 10)) params0.get('w1', shape=(10, 10), stype='row_sparse') - all_row_ids = mx.nd.arange(0, 10, ctx=mx.cpu()) + all_row_ids = mx.nd.arange(0, 10, ctx=ctx) # check param names assert list(params0.keys()) == ['net_w0', 'net_w1'] - params0.initialize(ctx=mx.cpu()) + params0.initialize(ctx=ctx) trainer0 = mx.gluon.Trainer(params0, 'sgd') - prev_w0 = params0.get('w0').data(mx.cpu()) + prev_w0 = params0.get('w0').data(ctx) prev_w1 = params0.get('w1').row_sparse_data(all_row_ids) # save params params0.save('test_paramdict.params') @@ -107,11 +108,11 @@ def test_paramdict(): params1 = gluon.ParameterDict('net_') params1.get('w0', shape=(10, 10)) params1.get('w1', shape=(10, 10), stype='row_sparse') - params1.load('test_paramdict.params', mx.cpu()) + params1.load('test_paramdict.params', ctx) trainer1 = mx.gluon.Trainer(params1, 'sgd') # compare the values before and after save/load - cur_w0 = params1.get('w0').data(mx.cpu()) + cur_w0 = params1.get('w0').data(ctx) cur_w1 = params1.get('w1').row_sparse_data(all_row_ids) mx.test_utils.assert_almost_equal(prev_w0.asnumpy(), cur_w0.asnumpy()) mx.test_utils.assert_almost_equal(prev_w1.asnumpy(), cur_w1.asnumpy()) @@ -121,11 +122,11 @@ def test_paramdict(): params2 = gluon.ParameterDict('net_') params2.get('w0', shape=(10, 10)) params2.get('w1', shape=(10, 10)) - params2.load('test_paramdict.params', mx.cpu()) + params2.load('test_paramdict.params', ctx) # compare the values before and after save/load - cur_w0 = params2.get('w0').data(mx.cpu()) - cur_w1 = params2.get('w1').data(mx.cpu()) + cur_w0 = params2.get('w0').data(ctx) + cur_w1 = params2.get('w1').data(ctx) mx.test_utils.assert_almost_equal(prev_w0.asnumpy(), cur_w0.asnumpy()) mx.test_utils.assert_almost_equal(prev_w1.asnumpy(), cur_w1.asnumpy()) @@ -731,19 +732,23 @@ def test_sequential_warning(): @with_seed() def test_global_norm_clip(): - x1 = mx.nd.ones((3,3)) - x2 = mx.nd.ones((4,4)) - norm = gluon.utils.clip_global_norm([x1, x2], 1.0) - assert norm == 5.0 - assert_almost_equal(x1.asnumpy(), np.ones((3,3))/5) - assert_almost_equal(x2.asnumpy(), np.ones((4,4))/5) - - x3 = mx.nd.array([1.0, 2.0, float('nan')]) - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - gluon.utils.clip_global_norm([x1, x3], 2.0) - assert len(w) == 1 - + stypes = ['default', 'row_sparse'] + def check_global_norm_clip(stype): + x1 = mx.nd.ones((3,3)).tostype(stype) + x2 = mx.nd.ones((4,4)).tostype(stype) + norm = gluon.utils.clip_global_norm([x1, x2], 1.0) + assert norm == 5.0 + assert_almost_equal(x1.asnumpy(), np.ones((3,3))/5) + assert_almost_equal(x2.asnumpy(), np.ones((4,4))/5) + + x3 = mx.nd.array([1.0, 2.0, float('nan')]).tostype(stype) + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + gluon.utils.clip_global_norm([x1, x3], 2.0) + assert len(w) == 1 + + for stype in stypes: + check_global_norm_clip(stype) @with_seed() def test_embedding(): diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py index c2e11ebb18ee..1c59ceaa093a 100644 --- a/tests/python/unittest/test_gluon_trainer.py +++ b/tests/python/unittest/test_gluon_trainer.py @@ -177,24 +177,30 @@ def test_trainer_save_load(): @with_seed() def test_trainer_reset_kv(): - params = gluon.ParameterDict() - x = params.get('x', shape=(10,), lr_mult=1.0) - params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') - trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}) - params.save('test_trainer_reset_kv.params') - with mx.autograd.record(): - for w in x.list_data(): - y = w + 1 - y.backward() - trainer.step(1) - # load would reset kvstore - params.load('test_trainer_reset_kv.params') - assert trainer._kvstore is None - assert trainer._kv_initialized is False - with mx.autograd.record(): - for w in x.list_data(): - y = w + 1 - y.backward() - trainer.step(1) - # the updated parameter should be based on the loaded checkpoint - assert (x.data(mx.cpu()) == -0.2).asnumpy().all() + def check_trainer_reset_kv(kv): + params = gluon.ParameterDict() + x = params.get('x', shape=(10,), lr_mult=1.0) + params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') + trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv) + params.save('test_trainer_reset_kv.params') + with mx.autograd.record(): + for w in x.list_data(): + y = w + 1 + y.backward() + trainer.step(1) + assert trainer._kvstore.type == kv + # load would reset kvstore + params.load('test_trainer_reset_kv.params') + assert trainer._kvstore is None + assert trainer._kv_initialized is False + with mx.autograd.record(): + for w in x.list_data(): + y = w + 1 + y.backward() + trainer.step(1) + # the updated parameter should be based on the loaded checkpoint + assert (x.data(mx.cpu()) == -0.2).asnumpy().all() + + kvs = ['local', 'device'] + for kv in kvs: + check_trainer_reset_kv(kv)