diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index 45a44d8eb3e4..a95417cf523b 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -95,10 +95,10 @@ def __init__(self, params, optimizer, optimizer_params=None, kvstore='device', if param._grad_stype != 'default': self._contains_sparse_grad = True self._compression_params = compression_params - optimizer_params = optimizer_params if optimizer_params else {} - self._scale = float(optimizer_params.get('rescale_grad', 1.0)) self._contexts = self._check_contexts() + optimizer_params = optimizer_params if optimizer_params else {} self._init_optimizer(optimizer, optimizer_params) + self._scale = self._optimizer.rescale_grad self._kvstore_params = {'kvstore': kvstore, 'update_on_kvstore': update_on_kvstore} self._kv_initialized = False self._kvstore = None