
Implemented a Python SVRGModule for performing SVRG optimization logic. This version supports single-machine SVRG with a single CPU, a single GPU, or multiple GPUs.
StephanieYuan committed Sep 6, 2018
1 parent e290623 commit 3887fe9
Showing 16 changed files with 1,751 additions and 2 deletions.
86 changes: 86 additions & 0 deletions docs/api/python/contrib/svrg_optimization.md
@@ -0,0 +1,86 @@
# SVRG Optimization in Python Module API

## Overview
SVRG, which stands for Stochastic Variance Reduced Gradient, is an optimization technique first introduced in the 2013
paper _Accelerating Stochastic Gradient Descent using Predictive Variance Reduction_. It complements SGD
(Stochastic Gradient Descent), which is well suited to large-scale optimization but suffers from slow asymptotic
convergence due to its inherent variance. SGD approximates the full gradient using a small batch of data or a single
data sample, which introduces variance and therefore requires starting with a small learning rate to ensure
convergence. SVRG remedies this problem by keeping a snapshot of estimated weights that is close to the optimal
parameter values and maintaining an average of full gradients over a full pass of the data. The average of full
gradients is calculated with respect to the weights from the last m-th epoch of training. SVRG uses a different
update rule: the gradient w.r.t. the current parameter values, minus the gradient w.r.t. the parameters from the last
m-th epoch, plus the average of full gradients over all data.
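
Concretely (restating the update rule from the paper, with `w_t` the current weights, `w~` the snapshot weights from
the last m-th epoch, `mu~` the average of full gradients over all `n` samples, and `i_t` a randomly drawn sample
index):

```eval_rst
.. math::

    \tilde{\mu} = \frac{1}{n} \sum_{i=1}^{n} \nabla f_i(\tilde{w}), \qquad
    w_{t+1} = w_t - \eta \left( \nabla f_{i_t}(w_t) - \nabla f_{i_t}(\tilde{w}) + \tilde{\mu} \right)
```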

Key Characteristics of SVRG:
* Employs explicit variance reduction by using a different update rule compared to SGD.
* The ability to use a relatively large learning rate, which leads to faster convergence compared to SGD.
* Guaranteed fast convergence for smooth and strongly convex functions.

SVRG optimization is implemented as an SVRGModule in `mxnet.contrib.svrg_optimization`, which is an extension of the
existing `mxnet.module.Module` APIs and encapsulates the SVRG optimization logic within several new functions. For end
users, the API changes from Module to SVRGModule are minimal.

In distributed training, each worker gets the same snapshot weights from the last m-th epoch and calculates the full
gradients with respect to its own shard of data. Standard SVRG optimization requires building a global full gradient,
which is calculated by aggregating the full gradients from each worker and averaging over the number of workers. The
workaround is to keep an additional set of keys in the KVStore that map to the full gradients. The `_SVRGOptimizer` is
designed to wrap two optimizers: an `_AssignmentOptimizer`, which is used for full-gradient accumulation in the
KVStore, and a regular optimizer that performs the actual parameter updates. The `_SVRGOptimizer` and
`_AssignmentOptimizer` are designed to be used in `SVRGModule` only; the accumulation idea is sketched below.
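
For illustration only, the following minimal sketch shows how full gradients could be accumulated under a dedicated
KVStore key. The key name `full_grad_0`, the shapes, and the update logic are hypothetical stand-ins, not the actual
`_SVRGOptimizer` internals:

```python
import mxnet as mx

kv = mx.kv.create('local')

# Hypothetical dedicated key holding the accumulated full gradient for one
# parameter array; 'full_grad_0' is an illustrative name only.
kv.init('full_grad_0', mx.nd.zeros((2,)))

# Each worker pushes the full gradient computed on its own data shard;
# pushes to the same key are summed by the KVStore.
kv.push('full_grad_0', mx.nd.array([0.1, 0.2]))

# Pull the aggregated gradient and average over the number of workers.
agg = mx.nd.zeros((2,))
kv.pull('full_grad_0', out=agg)
avg_full_grad = agg / kv.num_workers
```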

```eval_rst
.. warning:: This package contains experimental APIs and may change in the near future.
```

This document lists the SVRGModule APIs in the MXNet contrib package:

```eval_rst
.. autosummary::
    :nosignatures:

    mxnet.contrib.svrg_optimization.svrg_module
```

### Intermediate Level API for SVRGModule

The only extra step in using an SVRGModule compared to using a Module is to check whether the current epoch should
update the full gradients over all data. The code snippet below demonstrates the suggested usage of SVRGModule with
the intermediate-level APIs.

```python
>>> mod = SVRGModule(symbol=model, update_freq=2, data_names=['data'], label_names=['lin_reg_label'])
>>> mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
>>> mod.init_params()
>>> mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.01), ), kvstore='local')
>>> for epoch in range(num_epochs):
... if epoch % mod.update_freq == 0:
... mod.update_full_grads(di)
... di.reset()
... for batch in di:
... mod.forward_backward(data_batch=batch)
... mod.update()
```
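
Because SVRGModule extends Module, training on one GPU or multiple GPUs should follow the same pattern as Module; the
`context` argument below is the standard Module constructor parameter, and its use here is an assumption based on that
inheritance rather than something documented by this snippet:

```python
>>> mod = SVRGModule(symbol=model, update_freq=2, data_names=['data'],
...                  label_names=['lin_reg_label'],
...                  context=[mx.gpu(0), mx.gpu(1)])  # assumes two GPUs are available
```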

### High Level API for SVRGModule

The high-level API usage of SVRGModule remains exactly the same as the Module API. The code snippet below gives an
example of the suggested high-level API usage.

```python
>>> mod = SVRGModule(symbol=model, update_freq=2, data_names=['data'], label_names=['lin_reg_label'])
>>> mod.fit(di, num_epochs=100, optimizer='sgd', optimizer_params=(('learning_rate', 0.01), ))
```

## API reference

<script type="text/javascript" src='../../../_static/js/auto_module_index.js'></script>

```eval_rst
.. automodule:: mxnet.contrib.svrg_optimization.svrg_module
.. autoclass:: mxnet.contrib.svrg_optimization.svrg_module.SVRGModule
    :members: init_optimizer, bind, forward, backward, reshape, update, update_full_grads, fit, prepare
```
<script>auto_index("api-reference");</script>
3 changes: 2 additions & 1 deletion docs/api/python/index.md
@@ -52,6 +52,7 @@ Code examples are placed throughout the API documentation and these can be run a
contrib/contrib.md
contrib/text.md
contrib/onnx.md
contrib/svrg_optimization.md
```

## Gluon API
@@ -176,4 +177,4 @@ Code examples are placed throughout the API documentation and these can be run a
:maxdepth: 1
symbol_in_pictures/symbol_in_pictures.md
```
```
2 changes: 1 addition & 1 deletion docs/api/python/module/module.md
@@ -207,4 +207,4 @@ additional functionality. We summarize them in this section.
:members:
```

<script>auto_index("api-reference");</script>
<script>auto_index("api-reference");</script>
33 changes: 33 additions & 0 deletions example/svrg_module/README.md
@@ -0,0 +1,33 @@
## SVRGModule Example
SVRGModule is an extension to the Module API that implements SVRG optimization, which stands for Stochastic
Variance Reduced Gradient. SVRG is an optimization technique that complements SGD and has several key
properties:
* Employs explicit variance reduction by using a different update rule compared to SGD.
* The ability to use a relatively large learning rate, which leads to faster convergence compared to SGD.
* Guaranteed fast convergence for smooth and strongly convex functions.

#### API Usage Example
SVRGModule provides both high-level and intermediate-level APIs while minimizing the changes relative to the Module API.
* example_api_train.py: provides the suggested usage of the SVRGModule high-level and intermediate-level APIs.
* example_inference.py: provides an example of SVRGModule inference usage.

#### Linear Regression
This example trains a linear regression model using SVRGModule on a real dataset, YearPredictionMSD. Logs of the
training results can be found in experiments.log, which is automatically generated when running the training script.

##### Dataset
YearPredictionMSD: a dataset for predicting the release year of a song from audio features. It has over
400,000 samples with 90 features. Please uncomment the data-downloading script in data_reader.py to download the data.
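
The data_reader.py script itself is not part of this documentation diff; a minimal sketch of such a download step,
assuming the dataset's usual UCI hosting URL (an assumption, not confirmed by this commit), could look like:

```python
import mxnet as mx

# Assumed UCI location of YearPredictionMSD; the URL and filename are
# illustrative stand-ins for whatever data_reader.py actually uses.
url = ('https://archive.ics.uci.edu/ml/machine-learning-databases/'
       '00203/YearPredictionMSD.txt.zip')
mx.test_utils.download(url, fname='YearPredictionMSD.txt.zip')
```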

#### Benchmarks
An initial set of benchmarks has been performed on YearPredictionMSD with a linear regression model.

* benchmark1.py: uses an lr_scheduler, which returns a new learning rate based on the number of updates that have been
performed. The training loss of SVRG is lower than that of SGD with an lr_scheduler over all 100 epochs.

* benchmark2.py: one drawback of SGD is that, in order to converge faster, the learning rate has to decay to zero, so
SGD needs to start with a small learning rate. The learning rate does not need to decay to zero for SVRG, so SVRG can
use a relatively larger learning rate. SGD with learning rates of 0.001 and 0.0025 and SVRG with a learning rate of
0.025 are benchmarked. Even though SVRG starts with a relatively large learning rate, it converges much faster than
SGD in both cases; a sketch of such a comparison setup follows below.
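
For reference, a minimal sketch of how such a comparison could be set up, assuming a `model` symbol and data iterator
`di` defined as in the API usage examples; the scheduler step and decay factor below are illustrative, not the exact
benchmark settings:

```python
import mxnet as mx
from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule

# SGD baseline: a small initial learning rate that decays by a factor of 0.9
# every 1000 updates (illustrative scheduler settings).
sgd_mod = mx.mod.Module(symbol=model, data_names=['data'], label_names=['lin_reg_label'])
sgd_mod.fit(di, num_epoch=100, eval_metric='mse', optimizer='sgd',
            optimizer_params={'learning_rate': 0.001,
                              'lr_scheduler': mx.lr_scheduler.FactorScheduler(step=1000, factor=0.9)})

# SVRG: a fixed, relatively large learning rate with no decay schedule.
di.reset()
svrg_mod = SVRGModule(symbol=model, update_freq=2, data_names=['data'],
                      label_names=['lin_reg_label'])
svrg_mod.fit(di, num_epoch=100, eval_metric='mse', optimizer='sgd',
             optimizer_params=(('learning_rate', 0.025),))
```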
124 changes: 124 additions & 0 deletions example/svrg_module/api_usage_example/example_api_train.py
@@ -0,0 +1,124 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


import mxnet as mx
import numpy as np
from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule


def test_svrg_intermediate_level_api(args):
"""Demonstrates intermediate level SVRGModule API where the training process
need to be explicitly defined. KVstore is not explicitly created.
Parameters
----------
args: args
Command line arguments
"""
num_epoch = args.epochs
batch_size = args.batch_size
update_freq = args.update_freq

di, mod = create_network(batch_size, update_freq)

mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
mod.init_params(initializer=mx.init.Uniform(0.01), allow_missing=False, force_init=False, allow_extra=False)
kv = mx.kv.create("local")
mod.init_optimizer(kvstore=kv, optimizer='sgd', optimizer_params=(('learning_rate', 0.025),))
metrics = mx.metric.create("mse")
for e in range(num_epoch):
metrics.reset()
if e % mod.update_freq == 0:
mod.update_full_grads(di)
di.reset()
for batch in di:
mod.forward_backward(data_batch=batch)
mod.update()
mod.update_metric(metrics, batch.label)
mod.logger.info('Epoch[%d] Train cost=%f', e, metrics.get()[1])


def test_svrg_high_level_api(args):
"""Demonstrates suggested usage of high level SVRGModule API. KVStore is explicitly created.
Parameters
----------
args: args
Command line arguments
"""
num_epoch = args.epochs
batch_size = args.batch_size
update_freq = args.update_freq

di, mod = create_network(batch_size, update_freq)
mod.fit(di, eval_metric='mse', optimizer='sgd', optimizer_params=(('learning_rate', 0.025),), num_epoch=num_epoch,
kvstore='local')


def create_network(batch_size, update_freq):
"""Create a linear regression network for performing SVRG optimization.
Parameters
----------
batch_size: int
Size of data split
update_freq: int
Update Frequency for calculating full gradients
Returns
----------
di: mx.io.NDArrayIter
Data iterator
update_freq: SVRGModule
An instance of SVRGModule for performing SVRG optimization
"""
import logging
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.INFO, format=head)

train_data = np.random.randint(1, 5, [1000, 2])
weights = np.array([1.0, 2.0])
train_label = train_data.dot(weights)

di = mx.io.NDArrayIter(train_data, train_label, batch_size=batch_size, shuffle=True, label_name='lin_reg_label')
X = mx.sym.Variable('data')
Y = mx.symbol.Variable('lin_reg_label')
fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1)
lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro")

mod = SVRGModule(
symbol=lro,
data_names=['data'],
label_names=['lin_reg_label'], update_freq=update_freq, logger=logging
)

return di, mod

# run as a script
if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-e', dest='epochs', default=100, type=int)
parser.add_argument('-bs', dest='batch_size', default=32, type=int)
parser.add_argument('-f', dest="update_freq", default=2, type=int)
args = parser.parse_args()

print("========================== Intermediate Level API ==========================")
test_svrg_intermediate_level_api(args)
print("========================== High Level API ==========================")
test_svrg_high_level_api(args)
106 changes: 106 additions & 0 deletions example/svrg_module/api_usage_example/example_inference.py
@@ -0,0 +1,106 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


import mxnet as mx
import numpy as np
import logging
from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule


def test_svrg_inference(args):
epoch = args.epochs
batch_size = args.batch_size
update_freq = args.update_freq

train_iter, val_iter, mod = create_network(batch_size, update_freq)
mod.fit(train_iter, eval_data=val_iter, eval_metric='mse', optimizer='sgd',
optimizer_params=(('learning_rate', 0.025),),
num_epoch=epoch)


def get_validation_score(args):
epoch = args.epochs
batch_size = args.batch_size
update_freq = args.update_freq

train_iter, val_iter, mod = create_network(batch_size, update_freq)
mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
mod.init_params(initializer=mx.init.Uniform(0.01), allow_missing=False, force_init=False, allow_extra=False)
mod.init_optimizer(kvstore='local', optimizer='sgd', optimizer_params=(('learning_rate', 0.025),))
metrics = mx.metric.create("mse")
for e in range(epoch):
metrics.reset()
if e % mod.update_freq == 0:
mod.update_full_grads(train_iter)
train_iter.reset()
for batch in train_iter:
mod.forward_backward(data_batch=batch)
mod.update()
mod.update_metric(metrics, batch.label)

y = mod.predict(val_iter)

# test-train data split, 20% test data out of 1000 data samples
assert y.shape == (200, 1)
score = mod.score(val_iter, ['mse'])
print("Training Loss on Validation Set is {}".format(score[0][1]))


def create_network(batch_size, update_freq):
"""Create a linear regression network for performing SVRG optimization.
:return: an instance of mx.io.NDArrayIter
:return: an instance of mx.mod.svrgmodule for performing SVRG optimization
"""
head = '%(asctime)-15s %(message)s'
logging.basicConfig(level=logging.INFO, format=head)
data = np.random.randint(1, 5, [1000, 2])

    # Train/test data split: 80% train, 20% validation
n_train = int(data.shape[0] * 0.8)
weights = np.array([1.0, 2.0])
label = data.dot(weights)

di = mx.io.NDArrayIter(data[:n_train, :], label[:n_train], batch_size=batch_size, shuffle=True, label_name='lin_reg_label')
val_iter = mx.io.NDArrayIter(data[n_train:, :], label[n_train:], batch_size=batch_size)

X = mx.sym.Variable('data')
Y = mx.symbol.Variable('lin_reg_label')
fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1)
lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro")

mod = SVRGModule(
symbol=lro,
data_names=['data'],
label_names=['lin_reg_label'], update_freq=update_freq, logger=logging)

return di, val_iter, mod


# run as a script
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('-e', dest='epochs', default=100, type=int)
parser.add_argument('-bs', dest='batch_size', default=32, type=int)
parser.add_argument('-f', dest="update_freq", default=2, type=int)
args = parser.parse_args()

print("========================== SVRG Module Inference ==========================")
test_svrg_inference(args)
print("========================SVRG Module Score ============================")
get_validation_score(args)
Binary file added example/svrg_module/benchmarks/benchmark1.png
Binary file added example/svrg_module/benchmarks/benchmark2.png
