
[MXNET-514] Add clip_global_norm(row_sparse_grad). Fix row_sparse_param.save(). Fix trainer init_kvstore #11266

Merged · 3 commits · Jun 18, 2018
10 changes: 6 additions & 4 deletions python/mxnet/gluon/parameter.py
@@ -310,14 +310,16 @@ def _init_grad(self):
                                          self._grad, self.grad_req)
 
     def _reduce(self):
-        """Reduce data from multiple context."""
+        """Reduce data from multiple context to cpu."""
+        ctx = context.cpu()
         if self._stype == 'default':
             block = self.list_data()
-            data = ndarray.add_n(*(w.copyto(context.cpu()) for w in block)) / len(block)
+            data = ndarray.add_n(*(w.copyto(ctx) for w in block)) / len(block)
         else:
             # fetch all rows for 'row_sparse' param
-            all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=context.cpu())
-            data = self.row_sparse_data(all_row_ids)
+            all_row_ids = ndarray.arange(0, self.shape[0], dtype='int64', ctx=ctx)
+            data = ndarray.zeros(self.shape, stype='row_sparse', ctx=ctx)
+            self._trainer._row_sparse_pull(self, data, all_row_ids)
         return data
 
     def initialize(self, init=None, ctx=None, default_init=initializer.Uniform(),
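For context, a minimal sketch of the save path this hunk fixes (it mirrors the updated test_paramdict test further down; it is not code from this PR): saving a 'row_sparse' parameter now requires an attached Trainer, because _reduce() gathers every row through the trainer's kvstore instead of calling row_sparse_data().

    import mxnet as mx
    from mxnet import gluon

    params = gluon.ParameterDict('net_')
    params.get('w', shape=(10, 10), stype='row_sparse')
    params.initialize(ctx=mx.cpu())
    trainer = gluon.Trainer(params, 'sgd')  # attaches itself to the sparse parameter
    params.save('net.params')               # internally calls _reduce() on each parameter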
2 changes: 1 addition & 1 deletion python/mxnet/gluon/trainer.py
@@ -152,7 +152,6 @@ def _reset_kvstore(self):
 
     def _init_kvstore(self):
         """Create kvstore."""
-        arg_arrays = {}
         config = self._kvstore_params
         if self._contains_sparse:
             kvstore, update_on_kvstore = _create_sparse_kvstore(config['kvstore'])
@@ -162,6 +161,7 @@ def _init_kvstore(self):
                                  "gradients and/or sparse weights are present for "
                                  "Parameter '%s'."%param.name)
         else:
+            arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params}
             kvstore, update_on_kvstore = _create_kvstore(config['kvstore'], len(self._contexts),
                                                          arg_arrays)
         if config['update_on_kvstore'] is not None:
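A hedged sketch of the behavior around this hunk (assuming the Trainer constructor at this commit exposes an update_on_kvstore argument, as the config lookup above suggests): sparse weights force updates onto the kvstore, so requesting update_on_kvstore=False triggers the ValueError whose message appears above, and arg_arrays is now built only in the dense branch, where _create_kvstore actually consumes it.

    import mxnet as mx
    from mxnet import gluon

    params = gluon.ParameterDict()
    params.get('w', shape=(10, 10), stype='row_sparse')
    params.initialize(ctx=mx.cpu())
    trainer = gluon.Trainer(params, 'sgd', update_on_kvstore=False)
    try:
        trainer._init_kvstore()  # normally invoked lazily on the first step()
    except ValueError as err:
        print(err)  # "... sparse weights are present for Parameter 'w' ..."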
8 changes: 6 additions & 2 deletions python/mxnet/gluon/utils.py
@@ -118,10 +118,14 @@ def split_and_load(data, ctx_list, batch_axis=0, even_split=True):
 def clip_global_norm(arrays, max_norm):
     """Rescales NDArrays so that the sum of their 2-norm is smaller than `max_norm`.
     """
+    def _norm(array):
+        if array.stype == 'default':
+            x = array.reshape((-1,))
+            return ndarray.dot(x, x)
+        return array.norm().square()
     assert len(arrays) > 0
     ctx = arrays[0].context
-    total_norm = ndarray.add_n(*[ndarray.dot(x, x).as_in_context(ctx)
-                                 for x in (arr.reshape((-1,)) for arr in arrays)])
+    total_norm = ndarray.add_n(*[_norm(arr).as_in_context(ctx) for arr in arrays])
     total_norm = ndarray.sqrt(total_norm).asscalar()
     if not np.isfinite(total_norm):
         warnings.warn(UserWarning('nan or inf is detected. Clipping results will be undefined.'),
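The new _norm helper computes the same squared 2-norm either way; the sparse branch uses array.norm().square(), presumably because reshape is not defined for 'row_sparse' storage. A worked check with the numbers the updated test asserts (a usage sketch, not code from this PR): ||x1||^2 = 9 and ||x2||^2 = 16, so the global norm is sqrt(9 + 16) = 5; since 5 > max_norm = 1.0, every element is rescaled in place by 1/5.

    import mxnet as mx
    from mxnet import gluon

    x1 = mx.nd.ones((3, 3)).tostype('row_sparse')
    x2 = mx.nd.ones((4, 4)).tostype('row_sparse')
    norm = gluon.utils.clip_global_norm([x1, x2], 1.0)
    print(norm)                # 5.0
    print(x1.asnumpy()[0, 0])  # 0.2, i.e. 1/5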
47 changes: 26 additions & 21 deletions tests/python/unittest/test_gluon.py
@@ -91,15 +91,16 @@ def test_parameter_invalid_access():
 
 @with_seed()
 def test_paramdict():
+    ctx = mx.cpu(1)
     params0 = gluon.ParameterDict('net_')
     params0.get('w0', shape=(10, 10))
     params0.get('w1', shape=(10, 10), stype='row_sparse')
-    all_row_ids = mx.nd.arange(0, 10, ctx=mx.cpu())
+    all_row_ids = mx.nd.arange(0, 10, ctx=ctx)
     # check param names
     assert list(params0.keys()) == ['net_w0', 'net_w1']
-    params0.initialize(ctx=mx.cpu())
+    params0.initialize(ctx=ctx)
     trainer0 = mx.gluon.Trainer(params0, 'sgd')
-    prev_w0 = params0.get('w0').data(mx.cpu())
+    prev_w0 = params0.get('w0').data(ctx)
     prev_w1 = params0.get('w1').row_sparse_data(all_row_ids)
     # save params
     params0.save('test_paramdict.params')
@@ -108,11 +109,11 @@ def test_paramdict():
     params1 = gluon.ParameterDict('net_')
     params1.get('w0', shape=(10, 10))
     params1.get('w1', shape=(10, 10), stype='row_sparse')
-    params1.load('test_paramdict.params', mx.cpu())
+    params1.load('test_paramdict.params', ctx)
     trainer1 = mx.gluon.Trainer(params1, 'sgd')
 
     # compare the values before and after save/load
-    cur_w0 = params1.get('w0').data(mx.cpu())
+    cur_w0 = params1.get('w0').data(ctx)
     cur_w1 = params1.get('w1').row_sparse_data(all_row_ids)
     mx.test_utils.assert_almost_equal(prev_w0.asnumpy(), cur_w0.asnumpy())
     mx.test_utils.assert_almost_equal(prev_w1.asnumpy(), cur_w1.asnumpy())
@@ -122,11 +123,11 @@ def test_paramdict():
     params2 = gluon.ParameterDict('net_')
     params2.get('w0', shape=(10, 10))
     params2.get('w1', shape=(10, 10))
-    params2.load('test_paramdict.params', mx.cpu())
+    params2.load('test_paramdict.params', ctx)
 
     # compare the values before and after save/load
-    cur_w0 = params2.get('w0').data(mx.cpu())
-    cur_w1 = params2.get('w1').data(mx.cpu())
+    cur_w0 = params2.get('w0').data(ctx)
+    cur_w1 = params2.get('w1').data(ctx)
     mx.test_utils.assert_almost_equal(prev_w0.asnumpy(), cur_w0.asnumpy())
     mx.test_utils.assert_almost_equal(prev_w1.asnumpy(), cur_w1.asnumpy())
@@ -728,19 +729,23 @@ def test_sequential_warning():
 
 @with_seed()
 def test_global_norm_clip():
-    x1 = mx.nd.ones((3,3))
-    x2 = mx.nd.ones((4,4))
-    norm = gluon.utils.clip_global_norm([x1, x2], 1.0)
-    assert norm == 5.0
-    assert_almost_equal(x1.asnumpy(), np.ones((3,3))/5)
-    assert_almost_equal(x2.asnumpy(), np.ones((4,4))/5)
-
-    x3 = mx.nd.array([1.0, 2.0, float('nan')])
-    with warnings.catch_warnings(record=True) as w:
-        warnings.simplefilter("always")
-        gluon.utils.clip_global_norm([x1, x3], 2.0)
-        assert len(w) == 1
+    stypes = ['default', 'row_sparse']
+    def check_global_norm_clip(stype):
+        x1 = mx.nd.ones((3,3)).tostype(stype)
+        x2 = mx.nd.ones((4,4)).tostype(stype)
+        norm = gluon.utils.clip_global_norm([x1, x2], 1.0)
+        assert norm == 5.0
+        assert_almost_equal(x1.asnumpy(), np.ones((3,3))/5)
+        assert_almost_equal(x2.asnumpy(), np.ones((4,4))/5)
+
+        x3 = mx.nd.array([1.0, 2.0, float('nan')]).tostype(stype)
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            gluon.utils.clip_global_norm([x1, x3], 2.0)
+            assert len(w) == 1
+
+    for stype in stypes:
+        check_global_norm_clip(stype)
 
 @with_seed()
 def test_embedding():
48 changes: 27 additions & 21 deletions tests/python/unittest/test_gluon_trainer.py
@@ -177,24 +177,30 @@ def test_trainer_save_load():
 
 @with_seed()
 def test_trainer_reset_kv():
-    params = gluon.ParameterDict()
-    x = params.get('x', shape=(10,), lr_mult=1.0)
-    params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
-    trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1})
-    params.save('test_trainer_reset_kv.params')
-    with mx.autograd.record():
-        for w in x.list_data():
-            y = w + 1
-            y.backward()
-    trainer.step(1)
-    # load would reset kvstore
-    params.load('test_trainer_reset_kv.params')
-    assert trainer._kvstore is None
-    assert trainer._kv_initialized is False
-    with mx.autograd.record():
-        for w in x.list_data():
-            y = w + 1
-            y.backward()
-    trainer.step(1)
-    # the updated parameter should be based on the loaded checkpoint
-    assert (x.data(mx.cpu()) == -0.2).asnumpy().all()
+    def check_trainer_reset_kv(kv):
+        params = gluon.ParameterDict()
+        x = params.get('x', shape=(10,), lr_mult=1.0)
+        params.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros')
+        trainer = gluon.Trainer(params, 'sgd', {'learning_rate': 0.1}, kvstore=kv)
+        params.save('test_trainer_reset_kv.params')
+        with mx.autograd.record():
+            for w in x.list_data():
+                y = w + 1
+                y.backward()
+        trainer.step(1)
+        assert trainer._kvstore.type == kv
+        # load would reset kvstore
+        params.load('test_trainer_reset_kv.params')
+        assert trainer._kvstore is None
+        assert trainer._kv_initialized is False
+        with mx.autograd.record():
+            for w in x.list_data():
+                y = w + 1
+                y.backward()
+        trainer.step(1)
+        # the updated parameter should be based on the loaded checkpoint
+        assert (x.data(mx.cpu()) == -0.2).asnumpy().all()
+
+    kvs = ['local', 'device']
+    for kv in kvs:
+        check_trainer_reset_kv(kv)
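A note on the expected value, since the reasoning is implicit in the test (assuming the usual kvstore behavior of summing gradients across devices): y = w + 1 yields a gradient of ones on each of the two cpu contexts, the kvstore aggregates them to 2, and step(1) applies w -= 0.1 * 2, taking the zero-initialized weight to -0.2. Because load() restores the checkpointed zeros and discards the kvstore state, the second step(1) again lands at -0.2 rather than -0.4, which is what the final assertion verifies.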