diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py index 72c01acb2652..b4bfe4c47f00 100644 --- a/tests/python/unittest/test_gluon_trainer.py +++ b/tests/python/unittest/test_gluon_trainer.py @@ -74,12 +74,13 @@ def dict_equ(a, b): if trainer._update_on_kvstore: dict_equ(trainer._kvstore._updater.states, states) assert trainer._optimizer == trainer._kvstore._updater.optimizer + # invalid usage of update and allreduce_grads if update_on_kvstore + assert_raises(AssertionError, trainer.update, 1) + assert_raises(AssertionError, trainer.allreduce_grads) else: for updater in trainer._updaters: dict_equ(updater.states, states) assert trainer._optimizer == trainer._updaters[0].optimizer - assert_raises(AssertionError, trainer.update, 1) - assert_raises(AssertionError, trainer.allreduce_grads) x = gluon.Parameter('x', shape=(10,)) x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') @@ -193,8 +194,10 @@ def check_trainer_reset_kv(kv): # load would reset kvstore mx.nd.waitall() params.load('test_trainer_reset_kv.params') - assert trainer._kvstore is None - assert trainer._kv_initialized is False + if trainer._update_on_kvstore: + # drop kvstore state if new parameters are loaded + assert trainer._kvstore is None + assert trainer._kv_initialized is False with mx.autograd.record(): for w in x.list_data(): y = w + 1