Commit
Implemented a python SVRGModule for performing SVRG Optimization Logic. This version supports single machine SVRG with single cpu, gpu and multi-gpus.
1 parent e290623 · commit c446dc0
Showing 17 changed files with 1,752 additions and 2 deletions.
@@ -0,0 +1,90 @@
# SVRG Optimization in Python Module API

## Overview
SVRG, which stands for Stochastic Variance Reduced Gradient, is an optimization technique that complements SGD. It
employs explicit variance reduction and converges much faster than SGD for smooth and strongly convex functions.

Key characteristics of SVRG:
* Employs explicit variance reduction by using an update rule different from SGD's.
* Can use a relatively large learning rate, which leads to faster convergence compared to SGD.
* Guarantees fast convergence for smooth and strongly convex functions.

SVRG optimization is implemented as `SVRGModule` in `mxnet.contrib.svrg_optimization`, an extension of the
existing `mxnet.module.Module` API that encapsulates the SVRG optimization logic within several new functions. For
end users, the API changes compared to the Module API are minimal.

The current `SVRGModule` implements the standard SVRG optimization technique described in _Accelerating Stochastic
Gradient Descent using Predictive Variance Reduction_ by calculating the gradients over all data
every `update_freq` epochs during training. The SVRGModule update rule is: the gradient w.r.t. the current
parameters, minus the gradient w.r.t. the parameters from the last m-th epoch, plus the average of the gradients
over all data.
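
A minimal sketch of this update rule (illustrative names only, not the actual `SVRGModule` internals):

```python
# Minimal NumPy sketch of the SVRG update rule described above. All names
# here (svrg_update, grad_at_snapshot, full_grad, lr) are illustrative.
import numpy as np

def svrg_update(w, grad_at_w, grad_at_snapshot, full_grad, lr):
    # Variance-reduced gradient: gradient at the current parameters, minus
    # the gradient at the snapshot parameters from the last m-th epoch,
    # plus the average gradient over all data (computed at the snapshot).
    variance_reduced_grad = grad_at_w - grad_at_snapshot + full_grad
    return w - lr * variance_reduced_grad
```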

In distributed training, each worker gets the same special weights from the last m-th epoch and calculates the full
gradients with respect to its own shard of data. Standard SVRG optimization requires building a global full
gradient, which sums the full gradients from each worker and averages them over the number of workers. The solution
is to keep an additional set of keys in the KVStore that map to the full gradients. The `_SVRGOptimizer` is designed
to wrap two optimizers: an `_AssignmentOptimizer`, which is used for full-gradient accumulation in the KVStore, and a
regular optimizer that performs the actual update to the parameters. The `_SVRGOptimizer` and `_AssignmentOptimizer`
are designed to be used in `SVRGModule` only.
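
A minimal sketch of the aggregation idea (pure NumPy, for illustration; the real implementation accumulates into
extra KVStore keys through the `_AssignmentOptimizer`):

```python
# Illustrative sketch only: build the global full gradient by summing each
# worker's per-shard full gradient and averaging over the number of workers.
import numpy as np

def global_full_grad(worker_grads):
    # worker_grads: list of per-worker full gradients (one per data shard)
    return sum(worker_grads) / len(worker_grads)

# e.g. two workers, each holding the full gradient over its own data shard
g = global_full_grad([np.array([0.2, 0.4]), np.array([0.4, 0.8])])
```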

```eval_rst
.. warning:: This package contains experimental APIs and may change in the near future.
```

This document lists the svrg_optimization APIs in mxnet:

```eval_rst
.. autosummary::
    :nosignatures:

    mxnet.contrib.svrg_optimization.SVRGModule
    mxnet.contrib.svrg_optimization._SVRGOptimizer
    mxnet.contrib.svrg_optimization._AssignmentOptimizer
```

### Intermediate Level API for SVRGModule

The only extra step in using a SVRGModule compared to using a Module is to check whether the current epoch should
update the full gradients over all data. The code snippet below demonstrates the suggested usage of SVRGModule with
intermediate-level APIs.

```python
>>> mod = SVRGModule(symbol=model, update_freq=2, data_names=['data'], label_names=['lin_reg_label'])
>>> mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
>>> mod.init_params()
>>> mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.01), ))
>>> for epoch in range(num_epochs):
...     if epoch % mod.update_freq == 0:
...         mod.update_full_grads(di)
...     di.reset()
...     for batch in di:
...         mod.forward_backward(data_batch=batch)
...         mod.update()
```

### High Level API for SVRGModule

The high-level API usage of SVRGModule remains exactly the same as the Module API. The code snippet below gives an
example of the suggested usage of the high-level API.

```python
>>> mod = SVRGModule(symbol=model, update_freq=2, data_names=['data'], label_names=['lin_reg_label'])
>>> mod.fit(di, num_epoch=100, optimizer='sgd', optimizer_params=(('learning_rate', 0.01), ))
```

## API reference

<script type="text/javascript" src='../../../_static/js/auto_module_index.js'></script>

```eval_rst
.. autoclass:: mxnet.contrib.svrg_optimization.svrg_module.SVRGModule
    :members: init_optimizer, _create_optimizer, bind, forward, backward, update, update_full_grads,
        _accumulate_kvstore, _allocate_gradients, _svrg_grads_update_rule, update_svrg_gradients, fit, prepare

.. autoclass:: mxnet.contrib.svrg_optimization.svrg_optimizer._SVRGOptimizer
    :members: _check_params, update, create_state, _check_index

.. autoclass:: mxnet.contrib.svrg_optimization.svrg_optimizer._AssignmentOptimizer
    :members: update
```

<script>auto_index("api-reference");</script>
@@ -0,0 +1,33 @@
## SVRGModule Example
SVRGModule is an extension to the Module API that implements SVRG optimization, which stands for Stochastic
Variance Reduced Gradient. SVRG is an optimization technique that complements SGD and has several key
properties:
* Employs explicit variance reduction by using an update rule different from SGD's.
* Can use a relatively large learning rate, which leads to faster convergence compared to SGD.
* Guarantees fast convergence for smooth and strongly convex functions.

#### API Usage Example
SVRGModule provides both high-level and intermediate-level APIs while minimizing the changes relative to the Module API.
* example_api_train.py: provides suggested usage of the SVRGModule high-level and intermediate-level APIs.
* example_inference.py: provides example usage of SVRGModule inference; a minimal, hedged sketch follows below.

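A minimal sketch of what inference with a trained SVRGModule can look like, using the standard Module API it
inherits (`mod` and `eval_iter` are assumed to already exist; see example_inference.py for the actual script):

```python
# Hedged sketch: score a trained SVRGModule on held-out data. `mod` (a
# trained, bound SVRGModule) and `eval_iter` (a data iterator) are assumed
# to exist; SVRGModule inherits Module's score().
import mxnet as mx

metrics = mx.metric.create('mse')
mod.score(eval_iter, metrics)
print(metrics.get())
```
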
#### Linear Regression
This example trains a linear regression model using SVRGModule on a real dataset, YearPredictionMSD.
Logs of the training results can be found in experiments.log, which is automatically generated when running the
training script.

##### Dataset
YearPredictionMSD: prediction of the release year of a song from audio features. It has over
400,000 samples with 90 features. Please uncomment the data downloading script in data_reader.py to download the
data.

#### Benchmarks
An initial set of benchmarks was performed on the YearPredictionMSD dataset with a linear regression model.

* benchmark1.py: An lr_scheduler returns a new learning rate based on the number of updates that have been
performed. The training loss of SVRG is lower than that of SGD with an lr_scheduler over all 100 epochs.

* benchmark2.py: One drawback of SGD is that, in order to converge faster, the learning rate has to decay to zero,
so SGD needs to start with a small learning rate. The learning rate does not need to decay to zero for SVRG,
so SVRG can use a relatively larger learning rate. SGD with learning rates of (0.001, 0.0025) and SVRG with a
learning rate of 0.025 are benchmarked. Even though SVRG starts with a relatively large learning rate, it converges
much faster than SGD in both cases. A hedged configuration sketch follows below.
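
A sketch of the kind of optimizer settings compared above (the scheduler step/factor values are assumptions, not
necessarily the benchmark scripts' exact values):

```python
# Illustrative optimizer settings only; step/factor values are assumed,
# not taken from benchmark1.py or benchmark2.py.
import mxnet as mx

# SGD typically needs its learning rate to decay; FactorScheduler multiplies
# the learning rate by `factor` every `step` updates.
sgd_params = (('learning_rate', 0.0025),
              ('lr_scheduler', mx.lr_scheduler.FactorScheduler(step=1000, factor=0.5)))

# SVRG can instead keep a relatively large, constant learning rate.
svrg_params = (('learning_rate', 0.025),)
```
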
124 changes: 124 additions & 0 deletions
example/svrg_module/api_usage_example/example_api_train.py
@@ -0,0 +1,124 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.


import mxnet as mx
import numpy as np
from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule


def test_svrg_intermediate_level_api(args):
    """Demonstrates the intermediate-level SVRGModule API, where the training process
    needs to be explicitly defined. KVStore is explicitly created.

    Parameters
    ----------
    args: args
        Command line arguments
    """
    num_epoch = args.epochs
    batch_size = args.batch_size
    update_freq = args.update_freq

    di, mod = create_network(batch_size, update_freq)

    mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
    mod.init_params(initializer=mx.init.Uniform(0.01), allow_missing=False, force_init=False, allow_extra=False)
    kv = mx.kv.create("local")
    mod.init_optimizer(kvstore=kv, optimizer='sgd', optimizer_params=(('learning_rate', 0.025),))
    metrics = mx.metric.create("mse")
    for e in range(num_epoch):
        metrics.reset()
        # Recalculate the full gradients over all data every update_freq epochs
        if e % mod.update_freq == 0:
            mod.update_full_grads(di)
        di.reset()
        for batch in di:
            mod.forward_backward(data_batch=batch)
            mod.update()
            mod.update_metric(metrics, batch.label)
        mod.logger.info('Epoch[%d] Train cost=%f', e, metrics.get()[1])


def test_svrg_high_level_api(args):
    """Demonstrates suggested usage of the high-level SVRGModule API. KVStore is not explicitly created.

    Parameters
    ----------
    args: args
        Command line arguments
    """
    num_epoch = args.epochs
    batch_size = args.batch_size
    update_freq = args.update_freq

    di, mod = create_network(batch_size, update_freq)
    mod.fit(di, eval_metric='mse', optimizer='sgd', optimizer_params=(('learning_rate', 0.025),), num_epoch=num_epoch,
            kvstore='local')


def create_network(batch_size, update_freq):
    """Create a linear regression network for performing SVRG optimization.

    Parameters
    ----------
    batch_size: int
        Size of data split
    update_freq: int
        Update frequency for calculating full gradients

    Returns
    -------
    di: mx.io.NDArrayIter
        Data iterator
    mod: SVRGModule
        An instance of SVRGModule for performing SVRG optimization
    """
    import logging
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=logging.INFO, format=head)

    # Synthetic linear-regression data: labels are train_data . [1.0, 2.0]
    train_data = np.random.randint(1, 5, [1000, 2])
    weights = np.array([1.0, 2.0])
    train_label = train_data.dot(weights)

    di = mx.io.NDArrayIter(train_data, train_label, batch_size=batch_size, shuffle=True, label_name='lin_reg_label')
    X = mx.sym.Variable('data')
    Y = mx.symbol.Variable('lin_reg_label')
    fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1)
    lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro")

    mod = SVRGModule(
        symbol=lro,
        data_names=['data'],
        label_names=['lin_reg_label'],
        update_freq=update_freq,
        logger=logging
    )

    return di, mod


# run as a script
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-e', dest='epochs', default=100, type=int)
    parser.add_argument('-bs', dest='batch_size', default=32, type=int)
    parser.add_argument('-f', dest="update_freq", default=2, type=int)
    args = parser.parse_args()

    print("========================== Intermediate Level API ==========================")
    test_svrg_intermediate_level_api(args)
    print("========================== High Level API ==========================")
    test_svrg_high_level_api(args)