diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py
index 966ed2cc9964..a27c951c01b9 100644
--- a/python/mxnet/gluon/trainer.py
+++ b/python/mxnet/gluon/trainer.py
@@ -221,7 +221,7 @@ def _init_kvstore(self):
                                  "when sparse gradients are present.")
             update_on_kvstore = config['update_on_kvstore']
             # raise err if a custom kvstore is used for sparse training
-            if not isinstance(kvstore, KVStore):
+            if kvstore is not None and not isinstance(kvstore, KVStore):
                 raise ValueError("Cannot use {} for multi-device training with sparse gradients"
                                  .format(type(kvstore)))
 
diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py
index fbd04ee1beec..350700cc129f 100644
--- a/tests/python/unittest/test_gluon_trainer.py
+++ b/tests/python/unittest/test_gluon_trainer.py
@@ -46,6 +46,21 @@ def test_multi_trainer():
     # multiple trainers for a sparse Parameter is not allowed
     trainer1 = gluon.Trainer([x], 'sgd')
 
+@with_seed()
+def test_trainer_with_sparse_grad_on_single_context():
+    x = gluon.Parameter('x', shape=(10,), grad_stype='row_sparse')
+    x.initialize(ctx=[mx.cpu(0)], init='zeros')
+    trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 1.0, 'momentum': 0.5})
+    with mx.autograd.record():
+        for w in x.list_data():
+            y = w + 1
+            y.backward()
+    trainer.step(1)
+
+    assert trainer._update_on_kvstore is None
+    assert trainer._kvstore is None  # No kvstore created for single-device training
+    assert (x.data(mx.cpu(0)).asnumpy() == -1).all()
+
 @with_seed()
 def test_trainer_with_teststore():
     x = gluon.Parameter('x', shape=(10,))