
SVRG optimization in python/contrib package, this version supports single machine single cpu, single gpu and multi-gpus
StephanieYuan committed Sep 5, 2018
1 parent 6fdfd89 commit fa1753c
Showing 22 changed files with 976 additions and 400 deletions.
39 changes: 0 additions & 39 deletions contrib/svrg_optimization_python/README.md

This file was deleted.

Empty file.
21 changes: 0 additions & 21 deletions contrib/svrg_optimization_python/tests/__init__.py

This file was deleted.

116 changes: 0 additions & 116 deletions contrib/svrg_optimization_python/tests/test_svrg_module.py

This file was deleted.

96 changes: 0 additions & 96 deletions contrib/svrg_optimization_python/tests/test_svrg_optimizer.py

This file was deleted.

90 changes: 90 additions & 0 deletions docs/api/python/contrib/svrg_optimization.md
@@ -0,0 +1,90 @@
# SVRG Optimization in Python Module API

## Overview
SVRG, which stands for Stochastic Variance Reduced Gradient, is an optimization technique that complements SGD. It
employs explicit variance reduction and converges much faster than SGD for smooth and strongly convex functions.

Key Characteristics of SVRG:
* Employs explicit variance reduction by using a different update rule than SGD.
* Can use a relatively large learning rate, which leads to faster convergence than SGD.
* Guarantees fast convergence for smooth and strongly convex functions.

SVRG optimization is implemented as an SVRGModule in `mxnet.contrib.svrg_optimization`, an extension of the existing
`mxnet.module.Module` APIs that encapsulates the SVRG optimization logic in several new functions. For end users, the
API changes from Module to SVRGModule are minimal.

The current `SVRGModule` implements the standard SVRG optimization technique described in _Accelerating Stochastic
Gradient Descent using Predictive Variance Reduction_ by calculating the full gradients over all data
every `update_freq` epochs during training. The SVRGModule update rule is: the gradient w.r.t. the current parameters,
minus the gradient w.r.t. the parameters from the last m-th epoch, plus the average of the gradients over all data.
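
To make the rule concrete, the following is a minimal sketch in plain NumPy; it is not the SVRGModule internals, and
`grad_f` (returning the gradient of the current batch loss w.r.t. a given set of parameters) is a hypothetical
stand-in:

```python
import numpy as np

def svrg_update(w, w_snapshot, full_grad_avg, grad_f, lr):
    """One SVRG step. grad_f(params) is assumed to return the gradient of
    the current mini-batch loss w.r.t. the supplied parameters."""
    g_current = grad_f(w)            # gradient w.r.t. the current parameters
    g_snapshot = grad_f(w_snapshot)  # gradient w.r.t. the parameters from the last m-th epoch
    # variance-reduced gradient: current minus snapshot plus full-data average
    g_svrg = g_current - g_snapshot + full_grad_avg
    return w - lr * g_svrg

# toy check on a quadratic loss f(w) = 0.5 * ||w||^2, whose gradient is simply w
w = np.array([1.0, -2.0])
w_tilde = np.array([0.5, -1.0])   # parameters saved at the last snapshot epoch
mu = w_tilde.copy()               # full-gradient average evaluated at the snapshot
w_new = svrg_update(w, w_tilde, mu, lambda p: p, lr=0.1)
```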

In distributed training, each worker receives the same snapshot of the parameters from the last m-th epoch and
calculates the full gradients with respect to its own shard of data. Standard SVRG optimization requires building a
global full gradient, which sums the full gradients from all workers and averages over the number of workers. The
solution is to keep an additional set of keys in the KVStore that maps to the full gradients. The `_SVRGOptimizer` is
designed to wrap two optimizers: an `_AssignmentOptimizer`, which is used for full-gradient accumulation in the
KVStore, and a regular optimizer that performs the actual update to the parameters. The `_SVRGOptimizer` and
`_AssignmentOptimizer` are designed to be used in `SVRGModule` only.
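
As a rough sketch of that bookkeeping (the key naming and the explicit averaging below are illustrative, not the
actual `_SVRGOptimizer` implementation), each parameter key can be paired with a companion KVStore key reserved for
its full gradient:

```python
import mxnet as mx

kv = mx.kv.create('local')

# illustrative convention: parameter key "weight" gets a companion key
# "full_grad_weight" that is used only to accumulate full gradients
kv.init('weight', mx.nd.zeros((2, 3)))
kv.init('full_grad_weight', mx.nd.zeros((2, 3)))

# each worker pushes the full gradient computed over its own data shard
worker_full_grad = mx.nd.ones((2, 3))
kv.push('full_grad_weight', worker_full_grad)

# pull the accumulated value and average over the number of workers to
# obtain the global full gradient used by the SVRG update rule
accumulated = mx.nd.zeros((2, 3))
kv.pull('full_grad_weight', out=accumulated)
global_full_grad = accumulated / kv.num_workers
```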

```eval_rst
.. warning:: This package contains experimental APIs and may change in the near future.
```

This document lists the svrg_optimization APIs in mxnet:

```eval_rst
.. autosummary::
:nosignatures:
mxnet.contrib.svrg_optimization.SVRGModule
mxnet.contrib.svrg_optimization._SVRGOptimizer
mxnet.contrib.svrg_optimization._AssignmentOptimizer
```

### Intermediate Level API for SVRGModule

The only extra step in using an SVRGModule, compared to using a Module, is to check whether the current epoch should
update the full gradients over all data. The code snippet below demonstrates the suggested usage of SVRGModule with
the intermediate-level APIs.

```python
>>> mod = SVRGModule(symbol=model, update_freq=2, data_names=['data'], label_names=['lin_reg_label'])
>>> mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
>>> mod.init_params()
>>> mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.01), ))
>>> for epoch in range(num_epochs):
... if epoch % mod.update_freq == 0:
... mod.update_full_grads(di)
... di.reset()
... for batch in di:
... mod.forward_backward(data_batch=batch)
... mod.update()
```

### High Level API for SVRGModule

The high-level API usage of SVRGModule is exactly the same as the Module API. The code snippet below gives an example
of the suggested usage.

```python
>>> mod = SVRGModule(symbol=model, update_freq=2, data_names=['data'], label_names=['lin_reg_label'])
>>> mod.fit(di, num_epochs=100, optimizer='sgd', optimizer_params=(('learning_rate', 0.01), ))
```
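
Both snippets above assume a symbol `model` and a data iterator `di` defined elsewhere; a minimal, purely illustrative
setup (the toy shapes, sizes, and names here are ours, not part of the contrib package) could look like:

```python
import numpy as np
import mxnet as mx

# toy linear-regression symbol matching data_names=['data'] and
# label_names=['lin_reg_label'] used in the snippets above
data = mx.sym.Variable('data')
label = mx.sym.Variable('lin_reg_label')
fc = mx.sym.FullyConnected(data, num_hidden=1, name='fc')
model = mx.sym.LinearRegressionOutput(fc, label=label, name='lin_reg')

# toy dataset: 100 samples with 5 features each
X = np.random.rand(100, 5).astype('float32')
y = np.random.rand(100, 1).astype('float32')
di = mx.io.NDArrayIter(X, y, batch_size=10, label_name='lin_reg_label')
```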

## API reference

<script type="text/javascript" src='../../../_static/js/auto_module_index.js'></script>

```eval_rst
.. autoclass:: mxnet.contrib.svrg_optimization.svrg_module.SVRGModule
:members: init_optimizer, _create_optimizer, bind, forward, backward, update, update_full_grads,
_accumulate_kvstore, _allocate_gradients, _svrg_grads_update_rule, update_svrg_gradients, fit, prepare
.. autoclass:: mxnet.contrib.svrg_optimization.svrg_optimizer._SVRGOptimizer
:members: _check_params, update, create_state, _check_index
.. autoclass:: mxnet.contrib.svrg_optimization.svrg_optimizer._AssignmentOptimizer
:members: update
```
<script>auto_index("api-reference");</script>
1 change: 1 addition & 0 deletions docs/api/python/index.md
@@ -158,4 +158,5 @@ Code examples are placed throughout the API documentation and these can be run a
contrib/contrib.md
contrib/text.md
contrib/onnx.md
contrib/svrg_optimization.md
```
