
Commit

fix
Zheng committed Jan 21, 2020
1 parent c9550a5 commit b7ee7d8
Showing 6 changed files with 21 additions and 17 deletions.
13 changes: 7 additions & 6 deletions python/mxnet/gluon/trainer.py
@@ -60,9 +60,11 @@ class Trainer(object):
Arguments would then be {'type':'2bit', 'threshold':0.5}
See mxnet.KVStore.set_gradient_compression method for more details on gradient compression.
update_on_kvstore : bool, default None
- Whether to perform parameter updates on kvstore. If None, then trainer will choose the more
- suitable option depending on the type of kvstore. If the `update_on_kvstore` argument is
- provided, environment variable `MXNET_UPDATE_ON_KVSTORE` will be ignored.
+ Whether to perform parameter updates on kvstore. If None and optimizer.aggregate_num <= 1,
+ then trainer will choose the more suitable option depending on the type of kvstore.
+ If None and optimizer.aggregate_num > 1, `update_on_kvstore` is set to False.
+ If the `update_on_kvstore` argument is provided,
+ environment variable `MXNET_UPDATE_ON_KVSTORE` will be ignored.
Properties
----------
@@ -107,9 +109,8 @@ def __init__(self, params, optimizer, optimizer_params=None, kvstore='device',
if update_on_kvstore:
raise ValueError("Cannot set update_on_kvstore=True "
"when optimizer.aggregate_num > 1.")
- if update_on_kvstore is None:
-     if self._optimizer.aggregate_num > 1:
-         update_on_kvstore = False
+ if update_on_kvstore is None and self._optimizer.aggregate_num > 1:
+     update_on_kvstore = False
self._kvstore_params = {'kvstore': kvstore, 'update_on_kvstore': update_on_kvstore}
self._kv_initialized = False
self._kvstore = None
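
The following is a minimal sketch of how the new fallback behaves from the user's side (not part of this commit; the model and hyperparameters are only illustrative, assuming the standard Gluon API):

    import mxnet as mx
    from mxnet.gluon import nn, Trainer

    net = nn.Dense(10)
    net.initialize()

    # An optimizer that aggregates several weights per fused update step.
    opt = mx.optimizer.SGD(learning_rate=0.1, aggregate_num=4)

    # update_on_kvstore is left as None; because aggregate_num > 1, the trainer
    # now falls back to update_on_kvstore=False instead of choosing based on
    # the kvstore type.
    trainer = Trainer(net.collect_params(), opt)

    # Passing update_on_kvstore=True together with aggregate_num > 1 still
    # raises ValueError, as before.
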
5 changes: 2 additions & 3 deletions python/mxnet/optimizer/lars.py
@@ -64,18 +64,17 @@ class LARS(Optimizer):
lazy_update : bool, default False
Default is False. If True, lazy updates are applied \
if the storage types of weight and grad are both ``row_sparse``.
- aggregate_num : int, default 4
+ aggregate_num : int, default 1
Number of weights to be aggregated in a list.
They are passed to the optimizer for a single optimization step.
- In default, all the weights are aggregated.
use_fused_step : bool, default True
Whether or not to use fused kernels for optimizer.
When use_fused_step=False, step is called,
otherwise, fused_step is called.
"""
def __init__(self, learning_rate=0.1, momentum=0.0, eta=0.001,
epsilon=1e-8, lazy_update=False, use_fused_step=True,
- aggregate_num=4, **kwargs):
+ aggregate_num=1, **kwargs):
super(LARS, self).__init__(learning_rate=learning_rate,
use_fused_step=use_fused_step,
aggregate_num=aggregate_num,
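
As a hedged illustration of the LARS change (this snippet is not in the diff; it only demonstrates the new default), aggregation is now opt-in:

    from mxnet import optimizer

    lars_default = optimizer.LARS(learning_rate=0.1, momentum=0.9)
    lars_batched = optimizer.LARS(learning_rate=0.1, momentum=0.9, aggregate_num=4)

    # With the new default, only one weight is handled per optimization step
    # unless aggregate_num is raised explicitly.
    assert lars_default.aggregate_num == 1
    assert lars_batched.aggregate_num == 4
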
5 changes: 3 additions & 2 deletions python/mxnet/optimizer/optimizer.py
@@ -77,7 +77,8 @@ class Optimizer(object):
aggregate_num : int, optional, default None
Number of weights to be aggregated in a list.
They are passed to the optimizer for a single optimization step.
- In default, all the weights are aggregated.
+ In default, only one weight is aggregated.
+ When `aggregate_num` is set to numpy.inf, all the weights are aggregated.
use_fused_step : bool, optional, default None
Whether or not to use fused kernels for optimizer.
@@ -118,7 +119,7 @@ def __init__(self, rescale_grad=1., param_idx2name=None, wd=0.,
self.multi_precision = multi_precision

if aggregate_num is None:
-     self.aggregate_num = numpy.inf
+     self.aggregate_num = 1
else:
self.aggregate_num = aggregate_num

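
A short sketch of the changed base-class default (illustrative only, nothing here comes from the diff; it assumes the base Optimizer can be constructed directly for inspection, as its __init__ above suggests):

    import numpy
    from mxnet.optimizer import Optimizer

    # aggregate_num=None (the default) now resolves to 1 ...
    assert Optimizer().aggregate_num == 1
    # ... and numpy.inf restores the old "aggregate all weights" behaviour.
    assert Optimizer(aggregate_num=numpy.inf).aggregate_num == numpy.inf
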
5 changes: 2 additions & 3 deletions python/mxnet/optimizer/sgd.py
@@ -87,17 +87,16 @@ class SGD(Optimizer):
True: makes internal 32-bit copy of the weights and applies gradients
in 32-bit precision even if actual weights used in the model have lower precision.
Turning this on can improve convergence and accuracy when training with float16.
- aggregate_num : int, default 4
+ aggregate_num : int, default 1
Number of weights to be aggregated in a list.
They are passed to the optimizer for a single optimization step.
- In default, all the weights are aggregated.
use_fused_step : bool, default True
Whether or not to use fused kernels for optimizer.
When use_fused_step=False, step is called,
otherwise, fused_step is called.
"""
def __init__(self, learning_rate=0.1, momentum=0.0, lazy_update=False,
- multi_precision=False, use_fused_step=True, aggregate_num=4, **kwargs):
+ multi_precision=False, use_fused_step=True, aggregate_num=1, **kwargs):
super(SGD, self).__init__(learning_rate=learning_rate,
multi_precision=multi_precision,
aggregate_num=aggregate_num,
4 changes: 3 additions & 1 deletion tests/python/unittest/test_contrib_optimizer.py
@@ -33,8 +33,10 @@ def test_group_adagrad():
eps_options = [{}, {'epsilon': 1e-8}]
cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
+ agg_options = [{}, {'aggregate_num': 0}, {'aggregate_num': 1},
+                {'aggregate_num': 4}, {'aggregate_num': np.inf}]
for dtype in [np.float32]:
- for options in itertools.product(eps_options, cg_options, rg_options):
+ for options in itertools.product(eps_options, cg_options, rg_options, agg_options):
kwarg = dict(wd=0.0)
for option in options:
kwarg.update(option)
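
For context, a tiny standalone sketch (names mirror the test, the values are arbitrary) of how the option grid above expands into optimizer kwargs:

    import itertools
    import numpy as np

    eps_options = [{}, {'epsilon': 1e-8}]
    agg_options = [{}, {'aggregate_num': 4}, {'aggregate_num': np.inf}]

    for options in itertools.product(eps_options, agg_options):
        kwarg = dict(wd=0.0)
        for option in options:
            kwarg.update(option)
        # e.g. {'wd': 0.0, 'epsilon': 1e-08, 'aggregate_num': 4}
        print(kwarg)
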
6 changes: 4 additions & 2 deletions tests/python/unittest/test_optimizer.py
@@ -686,7 +686,8 @@ def test_ftrl():
('multi_precision' not in kwarg or not kwarg['multi_precision'])):
continue
compare_optimizer(opt1(use_fused_step=False, **kwarg),
- opt2(use_fused_step=True, **kwarg), shapes, dtype)
+ opt2(use_fused_step=True, **kwarg), shapes, dtype,
+ rtol=1e-4, atol=1e-4)


@with_seed()
@@ -710,7 +711,8 @@ def test_sparse_ftrl():
('multi_precision' not in kwarg or not kwarg['multi_precision'])):
continue
compare_optimizer(opt1(**kwarg), opt2(**kwarg), shapes,
- dtype, w_stype='row_sparse', g_stype='row_sparse')
+ dtype, w_stype='row_sparse', g_stype='row_sparse',
+ rtol=1e-4, atol=1e-4)


@with_seed()
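
As an illustration of what tolerances of this size allow (using numpy.allclose as a stand-in, not the test helper itself; the arrays are made up):

    import numpy as np

    w_ref = np.array([1.0, 0.5])            # weights after step()
    w_fused = np.array([1.00009, 0.50004])  # weights after fused_step(), tiny drift
    # allclose passes when |a - b| <= atol + rtol * |b|
    assert np.allclose(w_fused, w_ref, rtol=1e-4, atol=1e-4)
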
