diff --git a/docs/api/python/contrib/svrg_optimization.md b/docs/api/python/contrib/svrg_optimization.md
new file mode 100644
index 000000000000..e6e1c3e23ee3
--- /dev/null
+++ b/docs/api/python/contrib/svrg_optimization.md
@@ -0,0 +1,86 @@
+# SVRG Optimization in Python Module API
+
+## Overview
+SVRG, which stands for Stochastic Variance Reduced Gradient, is an optimization technique that was first introduced
+in the 2013 paper _Accelerating Stochastic Gradient Descent using Predictive Variance Reduction_. It complements SGD
+(Stochastic Gradient Descent), which is widely used for large-scale optimization but suffers from slow asymptotic
+convergence due to its inherent variance. SGD approximates the full gradient using a small batch of data or a single
+data sample, which introduces variance and therefore requires starting with a small learning rate to ensure
+convergence. SVRG remedies this problem by keeping a snapshot of the estimated weights that is close to the optimal
+parameter values and maintaining an average of the full gradients over a full pass of the data. This average is
+calculated with respect to the weights saved at the last m-th epoch of training. SVRG uses a different update rule:
+the gradients w.r.t. the current parameter values, minus the gradients w.r.t. the parameters from the last m-th
+epoch, plus the average of the full gradients over all data.
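+
+The update rule is easiest to see in code. Below is a minimal, illustrative sketch with hypothetical names (the
+actual implementation lives in `SVRGModule._svrg_grads_update_rule`):
+
+```python
+import numpy as np
+
+def svrg_gradient(grad_curr, grad_snapshot, full_grad):
+    """SVRG update rule: the stochastic gradient at the current weights, corrected by
+    the same batch's gradient at the snapshot weights plus the average full gradient."""
+    return grad_curr - grad_snapshot + full_grad
+
+# toy example with scalar "gradients"
+g = svrg_gradient(np.array(0.9), np.array(1.1), np.array(1.0))  # -> 0.8
+```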
+
+Key Characteristics of SVRG:
+* Employs explicit variance reduction by using a different update rule compared to SGD.
+* Can use a relatively large learning rate, which leads to faster convergence compared to SGD.
+* Guarantees fast convergence for smooth and strongly convex functions.
+
+SVRG optimization is implemented as an SVRGModule in `mxnet.contrib.svrg_optimization`, which is an extension of the
+existing `mxnet.module.Module` APIs and encapsulates the SVRG optimization logic in several new functions. For end
+users, the SVRGModule API changes compared to the Module API are minimal.
+
+In distributed training, each worker gets the same snapshot of weights from the last m-th epoch and calculates the
+full gradients with respect to its own shard of data. Standard SVRG optimization requires building a global full
+gradient, which is calculated by aggregating the full gradients from each worker and averaging over the number of
+workers. The workaround is to keep an additional set of keys in the KVStore that map to the full gradients.
+The `_SVRGOptimizer` is designed to wrap two optimizers: an `_AssignmentOptimizer`, which is used for full-gradient
+accumulation in the KVStore, and a regular optimizer that performs the actual parameter updates.
+Both `_SVRGOptimizer` and `_AssignmentOptimizer` are designed to be used in `SVRGModule` only.
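+
+For illustration, the assignment update used for the full-gradient keys simply overwrites the stored value instead
+of taking a gradient step; a minimal sketch of the idea (not the internal API):
+
+```python
+import mxnet as mx
+
+# accumulated full gradient stored under a "_full" key in the KVStore
+accumulated = mx.nd.zeros((2,))
+full_grad = mx.nd.array([0.5, -0.3])
+accumulated[:] = full_grad  # assignment rather than a descent step
+```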
+
+```eval_rst
+.. warning:: This package contains experimental APIs and may change in the near future.
+```
+
+This document lists the SVRGModule APIs in the MXNet contrib package:
+
+```eval_rst
+.. autosummary::
+ :nosignatures:
+
+ mxnet.contrib.svrg_optimization.svrg_module
+```
+
+### Intermediate Level API for SVRGModule
+
+The only extra step in using an SVRGModule compared to using a Module is to check whether the current epoch should
+update the full gradients over all data. The code snippet below demonstrates the suggested usage of SVRGModule with
+the intermediate-level APIs.
+
+```python
+>>> mod = SVRGModule(symbol=model, update_freq=2, data_names=['data'], label_names=['lin_reg_label'])
+>>> mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
+>>> mod.init_params()
+>>> mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.01), ), kvstore='local')
+>>> for epoch in range(num_epochs):
+... if epoch % mod.update_freq == 0:
+... mod.update_full_grads(di)
+... di.reset()
+... for batch in di:
+... mod.forward_backward(data_batch=batch)
+... mod.update()
+```
+
+### High Level API for SVRGModule
+
+The high-level API usage of SVRGModule remains exactly the same as the Module API. The code snippet below gives an
+example of suggested high-level API usage.
+
+```python
+>>> mod = SVRGModule(symbol=model, update_freq=2, data_names=['data'], label_names=['lin_reg_label'])
+>>> mod.fit(di, num_epoch=100, optimizer='sgd', optimizer_params=(('learning_rate', 0.01), ))
+```
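+
+Prediction and scoring also follow the standard Module API. A short sketch, assuming a validation data iterator
+`val_iter` is available:
+
+```python
+>>> y = mod.predict(val_iter)
+>>> score = mod.score(val_iter, ['mse'])
+```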
+
+## API reference
+
+
+
+```eval_rst
+
+.. automodule:: mxnet.contrib.svrg_optimization.svrg_module
+.. autoclass:: mxnet.contrib.svrg_optimization.svrg_module.SVRGModule
+ :members: init_optimizer, bind, forward, backward, reshape, update, update_full_grads, fit, prepare
+
+```
+
\ No newline at end of file
diff --git a/docs/api/python/index.md b/docs/api/python/index.md
index 42c4af9e46b5..15d1045a93e4 100644
--- a/docs/api/python/index.md
+++ b/docs/api/python/index.md
@@ -52,6 +52,7 @@ Code examples are placed throughout the API documentation and these can be run a
contrib/contrib.md
contrib/text.md
contrib/onnx.md
+ contrib/svrg_optimization.md
```
## Gluon API
@@ -176,4 +177,4 @@ Code examples are placed throughout the API documentation and these can be run a
:maxdepth: 1
symbol_in_pictures/symbol_in_pictures.md
-```
+```
\ No newline at end of file
diff --git a/docs/api/python/module/module.md b/docs/api/python/module/module.md
index 86ed74db6c19..5a874ac6df02 100644
--- a/docs/api/python/module/module.md
+++ b/docs/api/python/module/module.md
@@ -207,4 +207,4 @@ additional functionality. We summarize them in this section.
:members:
```
-
+
\ No newline at end of file
diff --git a/example/svrg_module/README.md b/example/svrg_module/README.md
new file mode 100644
index 000000000000..7edce14fa103
--- /dev/null
+++ b/example/svrg_module/README.md
@@ -0,0 +1,33 @@
+## SVRGModule Example
+SVRGModule is an extension to the Module API that implements SVRG optimization, which stands for Stochastic
+Variance Reduced Gradient. SVRG is an optimization technique that complements SGD and has several key
+properties:
+* Employs explicit variance reduction by using a different update rule compared to SGD.
+* Can use a relatively large learning rate, which leads to faster convergence compared to SGD.
+* Guarantees fast convergence for smooth and strongly convex functions.
+
+#### API Usage Example
+SVRGModule provides both high-level and intermediate-level APIs while keeping the changes relative to the Module API
+minimal. Both scripts can be run from the command line, as shown after this list.
+* example_api_train.py: provides suggested usage of the SVRGModule high-level and intermediate-level APIs.
+* example_inference.py: provides example usage of SVRGModule for inference.
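+
+Assuming the scripts are run from this directory, an invocation might look like the following (the flags follow the
+argparse definitions inside each script: epochs, batch size, and full-gradient update frequency):
+
+```
+python example_api_train.py -e 100 -bs 32 -f 2
+python example_inference.py -e 100 -bs 32 -f 2
+```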
+
+#### Linear Regression
+This example trains a linear regression model with SVRGModule on a real dataset, YearPredictionMSD.
+Logs of the training results can be found in experiments.log, which is generated automatically when running the
+training script.
+
+##### Dataset
+YearPredictionMSD: predicts the release year of a song from audio features. It has over
+400,000 samples with 90 features. Please uncomment the data-downloading commands in data_reader.py to download the data.
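+
+The download commands commented out in data_reader.py are:
+
+```python
+from subprocess import call
+call(['wget', 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression/YearPredictionMSD.bz2'])
+call(['bzip2', '-d', 'YearPredictionMSD.bz2'])
+```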
+
+#### Benchmarks:
+An initial set of benchmarks has been performed on YearPredictionMSD with a linear regression model.
+
+* benchmark1.py: An lr_scheduler returns a new learning rate based on the number of updates that have been performed.
+The training loss of SVRG is less than that of SGD with an lr_scheduler over all of the 100 epochs.
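+
+![benchmark1](benchmarks/benchmark1.png)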
+
+* benchmark2.py: One drawback of SGD is that, in order to converge faster, the learning rate has to decay to zero,
+so SGD needs to start with a small learning rate. The learning rate does not need to decay to zero for SVRG,
+so a relatively larger learning rate can be used. SGD with learning rates of (0.001, 0.0025) and SVRG with a
+learning rate of (0.025) are benchmarked. Even though SVRG starts with a relatively large learning rate, it converges
+much faster than SGD in both cases.
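+
+![benchmark2](benchmarks/benchmark2.png)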
diff --git a/example/svrg_module/api_usage_example/example_api_train.py b/example/svrg_module/api_usage_example/example_api_train.py
new file mode 100644
index 000000000000..f6cd1b2e592c
--- /dev/null
+++ b/example/svrg_module/api_usage_example/example_api_train.py
@@ -0,0 +1,124 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import mxnet as mx
+import numpy as np
+from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule
+
+
+def test_svrg_intermediate_level_api(args):
+    """Demonstrates the intermediate-level SVRGModule API, where the training process
+    needs to be explicitly defined. A KVStore is explicitly created and passed to init_optimizer().
+
+ Parameters
+ ----------
+ args: args
+ Command line arguments
+ """
+ num_epoch = args.epochs
+ batch_size = args.batch_size
+ update_freq = args.update_freq
+
+ di, mod = create_network(batch_size, update_freq)
+
+ mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
+ mod.init_params(initializer=mx.init.Uniform(0.01), allow_missing=False, force_init=False, allow_extra=False)
+ kv = mx.kv.create("local")
+ mod.init_optimizer(kvstore=kv, optimizer='sgd', optimizer_params=(('learning_rate', 0.025),))
+ metrics = mx.metric.create("mse")
+ for e in range(num_epoch):
+ metrics.reset()
+ if e % mod.update_freq == 0:
+ mod.update_full_grads(di)
+ di.reset()
+ for batch in di:
+ mod.forward_backward(data_batch=batch)
+ mod.update()
+ mod.update_metric(metrics, batch.label)
+ mod.logger.info('Epoch[%d] Train cost=%f', e, metrics.get()[1])
+
+
+def test_svrg_high_level_api(args):
+    """Demonstrates suggested usage of the high-level SVRGModule API. The KVStore is created internally by fit()
+    from the 'local' string.
+
+ Parameters
+ ----------
+ args: args
+ Command line arguments
+ """
+ num_epoch = args.epochs
+ batch_size = args.batch_size
+ update_freq = args.update_freq
+
+ di, mod = create_network(batch_size, update_freq)
+ mod.fit(di, eval_metric='mse', optimizer='sgd', optimizer_params=(('learning_rate', 0.025),), num_epoch=num_epoch,
+ kvstore='local')
+
+
+def create_network(batch_size, update_freq):
+    """Create a linear regression network for performing SVRG optimization.
+
+    Parameters
+    ----------
+    batch_size: int
+        Size of data split
+    update_freq: int
+        Update frequency for calculating full gradients
+
+    Returns
+    ----------
+    di: mx.io.NDArrayIter
+        Data iterator
+    mod: SVRGModule
+        An instance of SVRGModule for performing SVRG optimization
+    """
+ import logging
+ head = '%(asctime)-15s %(message)s'
+ logging.basicConfig(level=logging.INFO, format=head)
+
+ train_data = np.random.randint(1, 5, [1000, 2])
+ weights = np.array([1.0, 2.0])
+ train_label = train_data.dot(weights)
+
+ di = mx.io.NDArrayIter(train_data, train_label, batch_size=batch_size, shuffle=True, label_name='lin_reg_label')
+ X = mx.sym.Variable('data')
+ Y = mx.symbol.Variable('lin_reg_label')
+ fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1)
+ lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro")
+
+ mod = SVRGModule(
+ symbol=lro,
+ data_names=['data'],
+ label_names=['lin_reg_label'], update_freq=update_freq, logger=logging
+ )
+
+ return di, mod
+
+# run as a script
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-e', dest='epochs', default=100, type=int)
+ parser.add_argument('-bs', dest='batch_size', default=32, type=int)
+ parser.add_argument('-f', dest="update_freq", default=2, type=int)
+ args = parser.parse_args()
+
+ print("========================== Intermediate Level API ==========================")
+ test_svrg_intermediate_level_api(args)
+ print("========================== High Level API ==========================")
+ test_svrg_high_level_api(args)
diff --git a/example/svrg_module/api_usage_example/example_inference.py b/example/svrg_module/api_usage_example/example_inference.py
new file mode 100644
index 000000000000..312f9796074d
--- /dev/null
+++ b/example/svrg_module/api_usage_example/example_inference.py
@@ -0,0 +1,106 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import mxnet as mx
+import numpy as np
+import logging
+from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule
+
+
+def test_svrg_inference(args):
+ epoch = args.epochs
+ batch_size = args.batch_size
+ update_freq = args.update_freq
+
+ train_iter, val_iter, mod = create_network(batch_size, update_freq)
+ mod.fit(train_iter, eval_data=val_iter, eval_metric='mse', optimizer='sgd',
+ optimizer_params=(('learning_rate', 0.025),),
+ num_epoch=epoch)
+
+
+def get_validation_score(args):
+ epoch = args.epochs
+ batch_size = args.batch_size
+ update_freq = args.update_freq
+
+ train_iter, val_iter, mod = create_network(batch_size, update_freq)
+ mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
+ mod.init_params(initializer=mx.init.Uniform(0.01), allow_missing=False, force_init=False, allow_extra=False)
+ mod.init_optimizer(kvstore='local', optimizer='sgd', optimizer_params=(('learning_rate', 0.025),))
+ metrics = mx.metric.create("mse")
+ for e in range(epoch):
+ metrics.reset()
+ if e % mod.update_freq == 0:
+ mod.update_full_grads(train_iter)
+ train_iter.reset()
+ for batch in train_iter:
+ mod.forward_backward(data_batch=batch)
+ mod.update()
+ mod.update_metric(metrics, batch.label)
+
+ y = mod.predict(val_iter)
+
+ # test-train data split, 20% test data out of 1000 data samples
+ assert y.shape == (200, 1)
+ score = mod.score(val_iter, ['mse'])
+    print("MSE on Validation Set is {}".format(score[0][1]))
+
+
+def create_network(batch_size, update_freq):
+    """Create a linear regression network for performing SVRG optimization.
+
+    :return: an instance of mx.io.NDArrayIter for training data
+    :return: an instance of mx.io.NDArrayIter for validation data
+    :return: an instance of SVRGModule for performing SVRG optimization
+    """
+ head = '%(asctime)-15s %(message)s'
+ logging.basicConfig(level=logging.INFO, format=head)
+ data = np.random.randint(1, 5, [1000, 2])
+
+    # train/test data split: 80% train, 20% test
+ n_train = int(data.shape[0] * 0.8)
+ weights = np.array([1.0, 2.0])
+ label = data.dot(weights)
+
+ di = mx.io.NDArrayIter(data[:n_train, :], label[:n_train], batch_size=batch_size, shuffle=True, label_name='lin_reg_label')
+ val_iter = mx.io.NDArrayIter(data[n_train:, :], label[n_train:], batch_size=batch_size)
+
+ X = mx.sym.Variable('data')
+ Y = mx.symbol.Variable('lin_reg_label')
+ fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1)
+ lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro")
+
+ mod = SVRGModule(
+ symbol=lro,
+ data_names=['data'],
+ label_names=['lin_reg_label'], update_freq=update_freq, logger=logging)
+
+ return di, val_iter, mod
+
+
+# run as a script
+if __name__ == "__main__":
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-e', dest='epochs', default=100, type=int)
+ parser.add_argument('-bs', dest='batch_size', default=32, type=int)
+ parser.add_argument('-f', dest="update_freq", default=2, type=int)
+ args = parser.parse_args()
+
+ print("========================== SVRG Module Inference ==========================")
+ test_svrg_inference(args)
+    print("========================== SVRG Module Score ==========================")
+ get_validation_score(args)
diff --git a/example/svrg_module/benchmarks/benchmark1.png b/example/svrg_module/benchmarks/benchmark1.png
new file mode 100644
index 000000000000..4217db5c93db
Binary files /dev/null and b/example/svrg_module/benchmarks/benchmark1.png differ
diff --git a/example/svrg_module/benchmarks/benchmark2.png b/example/svrg_module/benchmarks/benchmark2.png
new file mode 100644
index 000000000000..cccbf0a54c16
Binary files /dev/null and b/example/svrg_module/benchmarks/benchmark2.png differ
diff --git a/example/svrg_module/linear_regression/common.py b/example/svrg_module/linear_regression/common.py
new file mode 100644
index 000000000000..0b3d19b409c9
--- /dev/null
+++ b/example/svrg_module/linear_regression/common.py
@@ -0,0 +1,118 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import mxnet as mx
+import logging
+from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule
+
+
+def create_lin_reg_network(train_features, train_labels, feature_dim, batch_size, update_freq, ctx, logger):
+ # fit a linear regression model with mxnet SVRGModule
+ print("Fitting linear regression with mxnet")
+ train_iter = mx.io.NDArrayIter(train_features, train_labels, batch_size=batch_size, shuffle=True,
+ data_name='data', label_name='label')
+ data = mx.sym.Variable("data")
+ label = mx.sym.Variable("label")
+ weight = mx.sym.Variable("fc_weight", shape=(1, feature_dim))
+ net = mx.sym.dot(data, weight.transpose())
+ bias = mx.sym.Variable("fc_bias", shape=(1,), wd_mult=0.0, lr_mult=10.0)
+ net = mx.sym.broadcast_plus(net, bias)
+ net = mx.sym.LinearRegressionOutput(data=net, label=label)
+
+ mod = SVRGModule(symbol=net, context=ctx, data_names=['data'], label_names=['label'], logger=logger,
+ update_freq=update_freq)
+ return train_iter, mod
+
+
+def create_metrics(metrics):
+ metric = mx.metric.create(metrics)
+ return metric
+
+
+def create_logger():
+ logger = logging.getLogger('sgd_svrg')
+ logger.setLevel(logging.INFO)
+ formatter = logging.Formatter('%(asctime)s - %(message)s')
+ fh = logging.FileHandler('experiments.log')
+ fh.setFormatter(formatter)
+ logger.addHandler(fh)
+ return logger
+
+
+################################################################################
+# Functions below are for benchmark purposes: they calculate the expectation and
+# variance of the gradients per epoch for each parameter. These calculations will
+# be helpful when benchmarking SVRG optimization against other optimization
+# techniques, such as SGD. Currently they only calculate the expectation and
+# variance for a single context, but can be extended to multi-context in later
+# iterations.
+################################################################################
+
+def accumulate_grad(grad_dict, mod):
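+    """Accumulate the gradients from the current batch into grad_dict, concatenating along axis 0."""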
+ param_names = mod._exec_group.param_names
+
+ for index, name in enumerate(param_names):
+ if name not in grad_dict:
+ grad_dict[name] = mod._exec_group.grad_arrays[index][0].copy()
+ else:
+ grad_dict[name] = mx.ndarray.concat(grad_dict[name], mod._exec_group.grad_arrays[index][0], dim=0)
+
+
+def calc_expectation(grad_dict, num_batches):
+    """Calculates the expectation of the gradients per epoch for each parameter over the given number of batches
+
+ Parameters
+ ----------
+ grad_dict: dict
+ dictionary that maps parameter name to gradients in the mod executor group
+ num_batches: int
+ number of batches
+
+ Returns
+ ----------
+ grad_dict: dict
+ dictionary with new keys mapping to gradients expectations
+
+ """
+    for key in list(grad_dict.keys()):
+        grad_dict[key + "_expectation"] = mx.ndarray.sum(grad_dict[key], axis=0) / num_batches
+
+ return grad_dict
+
+
+def calc_variance(grad_dict, num_batches, param_names):
+    """Calculates the variance of the gradients per epoch for each parameter over the given number of batches
+
+ Parameters
+ ----------
+ grad_dict: dict
+ dictionary that maps parameter name to gradients in the mod executor group
+ num_batches: int
+ number of batches
+    param_names: list of str
+        parameter names in the module
+
+ Returns
+ ----------
+ grad_dict: dict
+ dictionary with new keys mapping to gradients variance
+
+ """
+    for i in range(len(param_names)):
+        diff_sqr = mx.ndarray.square(mx.nd.subtract(grad_dict[param_names[i]],
+                                                    grad_dict[param_names[i] + "_expectation"]))
+        grad_dict[param_names[i] + "_variance"] = mx.ndarray.sum(diff_sqr, axis=0) / num_batches
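+
+
+# Example usage of the benchmark helpers above (a sketch; assumes `mod` is a bound,
+# initialized SVRGModule and `di` is a data iterator):
+#
+#   grad_dict = {}
+#   num_batches = 0
+#   for batch in di:
+#       mod.forward_backward(data_batch=batch)
+#       accumulate_grad(grad_dict, mod)
+#       num_batches += 1
+#   calc_expectation(grad_dict, num_batches)
+#   calc_variance(grad_dict, num_batches, mod._exec_group.param_names)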
diff --git a/example/svrg_module/linear_regression/data_reader.py b/example/svrg_module/linear_regression/data_reader.py
new file mode 100644
index 000000000000..d1578fc153bf
--- /dev/null
+++ b/example/svrg_module/linear_regression/data_reader.py
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import numpy as np
+
+
+def read_year_prediction_data(file_name):
+ # Download data file
+ # from subprocess import call
+ # call(['wget', 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression/YearPredictionMSD.bz2'])
+ # call(['bzip2', '-d', 'YearPredictionMSD.bz2'])
+
+ from sklearn.datasets import load_svmlight_file
+
+ # YearPredictionMSD dataset: https://archive.ics.uci.edu/ml/datasets/yearpredictionmsd
+ feature_dim = 90
+ print("Reading data from disk...")
+    train_features, train_labels = load_svmlight_file(file_name, n_features=feature_dim, dtype=np.float32)
+ train_features = train_features.todense()
+
+ # normalize the data: subtract means and divide by standard deviations
+ label_mean = train_labels.mean()
+ label_std = np.sqrt(np.square(train_labels - label_mean).mean())
+ feature_means = train_features.mean(axis=0)
+ feature_stds = np.sqrt(np.square(train_features - feature_means).mean(axis=0))
+
+ train_features = (train_features - feature_means) / feature_stds
+ train_labels = (train_labels - label_mean) / label_std
+
+ return feature_dim, train_features, train_labels
diff --git a/example/svrg_module/linear_regression/train.py b/example/svrg_module/linear_regression/train.py
new file mode 100644
index 000000000000..b3d942973f19
--- /dev/null
+++ b/example/svrg_module/linear_regression/train.py
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import argparse
+import mxnet as mx
+from common import create_lin_reg_network, create_logger
+from data_reader import read_year_prediction_data
+
+parser = argparse.ArgumentParser()
+parser.add_argument('-e', dest='epochs', help='number of epochs for training phase', type=int, default=100)
+parser.add_argument('-f', dest="updateFreq", help="update frequency for SVRGModule", type=int, default=2)
+parser.add_argument('-b', dest="batch_size", help="define the batch size for training", type=int,
+ default=100, required=False)
+parser.add_argument('-m', dest='metrics', help="create eval metric", type=str, default='mse')
+parser.add_argument('--gpus', type=str, help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu')
+parser.add_argument('--kv-store', type=str, default='local', help='key-value store type')
+
+args = parser.parse_args()
+# devices for training
+ctx = mx.cpu() if args.gpus is None or args.gpus == "" else [mx.gpu(int(i)) for i in args.gpus.split(',')]
+
+logger = create_logger()
+kv = mx.kvstore.create(args.kv_store)
+
+feature_dim, train_features, train_labels = read_year_prediction_data('YearPredictionMSD')
+train_iter, mod = create_lin_reg_network(train_features, train_labels, feature_dim, args.batch_size, args.updateFreq,
+ ctx, logger)
+
+mod.fit(train_iter, eval_metric='mse', optimizer='sgd',
+ optimizer_params=(('learning_rate', 0.025), ), num_epoch=args.epochs, kvstore=kv)
diff --git a/python/mxnet/contrib/svrg_optimization/__init__.py b/python/mxnet/contrib/svrg_optimization/__init__.py
new file mode 100644
index 000000000000..6e70009983c9
--- /dev/null
+++ b/python/mxnet/contrib/svrg_optimization/__init__.py
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""SVRGModule, SVRGOptimization import.
+"""
+
+
+from . import svrg_module
+from . import svrg_optimizer
diff --git a/python/mxnet/contrib/svrg_optimization/svrg_module.py b/python/mxnet/contrib/svrg_optimization/svrg_module.py
new file mode 100644
index 000000000000..84f640e1f487
--- /dev/null
+++ b/python/mxnet/contrib/svrg_optimization/svrg_module.py
@@ -0,0 +1,586 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+"""A `SVRGModule` implements the `Module` API by wrapping an auxiliary module to perform
+SVRG optimization logic.
+"""
+
+import time
+import logging
+import mxnet as mx
+from mxnet.module import Module
+from .svrg_optimizer import _SVRGOptimizer
+
+
+class SVRGModule(Module):
+ """SVRGModule is a module that encapsulates two Modules to accommodate the SVRG optimization technique.
+ It is functionally the same as Module API, except it is implemented using SVRG optimization logic.
+    It is functionally the same as the Module API, except it is implemented using the SVRG optimization logic.
+ Parameters
+ ----------
+ symbol : Symbol
+ data_names : list of str
+ Defaults to `('data')` for a typical model used in image classification.
+ label_names : list of str
+ Defaults to `('softmax_label')` for a typical model used in image
+ classification.
+ logger : Logger
+ Defaults to `logging`.
+ context : Context or list of Context
+ Defaults to ``mx.cpu()``.
+ work_load_list : list of number
+ Default ``None``, indicating uniform workload.
+ fixed_param_names: list of str
+ Default ``None``, indicating no network parameters are fixed.
+ state_names : list of str
+ states are similar to data and label, but not provided by data iterator.
+ Instead they are initialized to 0 and can be set by `set_states()`.
+ group2ctxs : dict of str to context or list of context, or list of dict of str to context
+ Default is `None`. Mapping the `ctx_group` attribute to the context assignment.
+ compression_params : dict
+ Specifies type of gradient compression and additional arguments depending
+ on the type of compression being used. For example, 2bit compression requires a threshold.
+ Arguments would then be {'type':'2bit', 'threshold':0.5}
+ See mxnet.KVStore.set_gradient_compression method for more details on gradient compression.
+ update_freq: int
+        Specifies the frequency (in epochs) at which the full gradients used in the SVRG optimization are
+        recalculated. For instance, update_freq = 2 will recalculate the full gradients over all data every two
+        epochs.
+
+    Examples
+    --------
+ >>> # An example of declaring and using SVRGModule.
+ >>> mod = SVRGModule(symbol=lro, data_names=['data'], label_names=['lin_reg_label'], update_freq=2)
+ >>> mod.fit(di, eval_metric='mse', optimizer='sgd', optimizer_params=(('learning_rate', 0.025),),
+    ...         num_epoch=num_epoch, kvstore='local')
+ """
+
+ def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
+ logger=logging, context=mx.cpu(), work_load_list=None,
+ fixed_param_names=None, state_names=None, group2ctxs=None,
+ compression_params=None, update_freq=None):
+ super(SVRGModule, self).__init__(symbol, data_names=data_names, label_names=label_names, logger=logger,
+ context=context, work_load_list=work_load_list,
+ fixed_param_names=fixed_param_names, state_names=state_names,
+ group2ctxs=group2ctxs, compression_params=compression_params)
+
+ # Type check update_frequency
+ if isinstance(update_freq, int):
+ if update_freq <= 0:
+ raise ValueError("update_freq in SVRGModule must be a positive integer to represent the frequency for "
+ "calculating full gradients")
+ self.update_freq = update_freq
+ else:
+ raise TypeError("update_freq in SVRGModule must be an integer to represent the frequency for "
+ "calculating full gradients")
+
+ self._mod_aux = mx.mod.Module(symbol, data_names, label_names, logger, context, work_load_list,
+ fixed_param_names, state_names, group2ctxs, compression_params)
+
+ self._param_dict = None
+ self._ctx_len = len(self._context)
+
+ def _reset_bind(self):
+ """Internal function to reset binded state for both modules."""
+ super(SVRGModule, self)._reset_bind()
+ self._mod_aux._reset_bind()
+
+ def reshape(self, data_shapes, label_shapes=None):
+ """Reshapes both modules for new input shapes.
+
+ Parameters
+ ----------
+ data_shapes : list of (str, tuple)
+ Typically is ``data_iter.provide_data``.
+ label_shapes : list of (str, tuple)
+ Typically is ``data_iter.provide_label``.
+ """
+ super(SVRGModule, self).reshape(data_shapes, label_shapes=label_shapes)
+ self._mod_aux.reshape(data_shapes, label_shapes=label_shapes)
+
+ def init_optimizer(self, kvstore='local', optimizer='sgd',
+ optimizer_params=(('learning_rate', 0.01),), force_init=False):
+ """Installs and initializes SVRGOptimizer. The SVRGOptimizer is a wrapper class for a regular optimizer that is
+ passed in and a special AssignmentOptimizer to accumulate the full gradients. If KVStore is 'local' or None,
+ the full gradients will be accumulated locally without pushing to the KVStore. Otherwise, additional keys will
+ be pushed to accumulate the full gradients in the KVStore.
+
+ Parameters
+ ----------
+ kvstore : str or KVStore
+ Default `'local'`.
+ optimizer : str or Optimizer
+ Default `'sgd'`
+ optimizer_params : dict
+ Default `(('learning_rate', 0.01),)`. The default value is not a dictionary,
+ just to avoid pylint warning of dangerous default values.
+ force_init : bool
+ Default ``False``, indicating whether we should force re-initializing the
+ optimizer in the case an optimizer is already installed.
+ """
+ # Init dict for storing average of full gradients for each device
+
+ self._param_dict = [{key: mx.nd.zeros(shape=value.shape, ctx=self._context[i])
+ for key, value in self.get_params()[0].items()} for i in range(self._ctx_len)]
+
+ svrg_optimizer = self._create_optimizer(_SVRGOptimizer.__name__, default_opt=optimizer,
+ kvstore=kvstore, optimizer_params=optimizer_params)
+
+ super(SVRGModule, self).init_optimizer(kvstore=kvstore, optimizer=svrg_optimizer,
+ optimizer_params=optimizer_params, force_init=force_init)
+
+ # Init additional keys for accumulating full grads in KVStore
+ if self._kvstore:
+ for idx, param_on_devs in enumerate(self._exec_group.param_arrays):
+ name = self._exec_group.param_names[idx]
+ self._kvstore.init(name + "_full", mx.nd.zeros(shape=self._arg_params[name].shape))
+ if self._update_on_kvstore:
+ self._kvstore.pull(name + "_full", param_on_devs, priority=-idx)
+
+ def _create_optimizer(self, optimizer, default_opt, kvstore, optimizer_params):
+        """Helper function to create the SVRG optimizer. The SVRG optimizer encapsulates two optimizers and
+        will redirect update() to the correct optimizer.
+
+ Parameters
+ ----------
+ kvstore : str or KVStore
+ Default `'local'`.
+ optimizer: str
+ Name for SVRGOptimizer
+ default_opt : str or Optimizer that was passed in.
+ optimizer_params : dict
+ optimizer params that was passed in.
+ """
+
+        # code partially copied from mxnet module.init_optimizer() to accommodate svrg_optimizer
+ batch_size = self._exec_group.batch_size
+
+ (kv_store, update_on_kvstore) = mx.model._create_kvstore(kvstore, self._ctx_len, self._arg_params)
+ if kv_store and 'dist' in kv_store.type and '_sync' in kv_store.type:
+ batch_size *= kv_store.num_workers
+ rescale_grad = 1.0 / batch_size
+
+ idx2name = {}
+ if update_on_kvstore:
+ idx2name.update(enumerate(self._exec_group.param_names))
+ else:
+ for k in range(self._ctx_len):
+ idx2name.update({i * self._ctx_len + k: n
+ for i, n in enumerate(self._exec_group.param_names)})
+
+ # update idx2name to include new keys
+ for key in self._param_dict[0].keys():
+ max_key = max(list(idx2name.keys())) + 1
+ idx2name[max_key] = key + "_full"
+
+ optimizer_params = dict(optimizer_params)
+ if 'rescale_grad' not in optimizer_params:
+ optimizer_params['rescale_grad'] = rescale_grad
+ optimizer_params["default_optimizer"] = default_opt
+ optimizer_params["param_idx2name"] = idx2name
+ optimizer = mx.optimizer.create(optimizer, **optimizer_params)
+
+ return optimizer
+
+ def bind(self, data_shapes, label_shapes=None, for_training=True,
+ inputs_need_grad=False, force_rebind=False, shared_module=None, grad_req='write'):
+        """Binds the symbols to construct executors for both modules. This is necessary before one
+ can perform computation with the SVRGModule.
+
+ Parameters
+ ----------
+ data_shapes : list of (str, tuple)
+ Typically is ``data_iter.provide_data``.
+ label_shapes : list of (str, tuple)
+ Typically is ``data_iter.provide_label``.
+ for_training : bool
+ Default is ``True``. Whether the executors should be bound for training.
+ inputs_need_grad : bool
+ Default is ``False``. Whether the gradients to the input data need to be computed.
+ Typically this is not needed. But this might be needed when implementing composition
+ of modules.
+ force_rebind : bool
+ Default is ``False``. This function does nothing if the executors are already
+ bound. But with this ``True``, the executors will be forced to rebind.
+ shared_module : Module
+ Default is ``None``. This is used in bucketing. When not ``None``, the shared module
+ essentially corresponds to a different bucket -- a module with different symbol
+ but with the same sets of parameters (e.g. unrolled RNNs with different lengths).
+ """
+ # force rebinding is typically used when one want to switch from
+ # training to prediction phase.
+
+ super(SVRGModule, self).bind(data_shapes, label_shapes, for_training, inputs_need_grad, force_rebind,
+ shared_module, grad_req)
+
+ if for_training:
+ self._mod_aux.bind(data_shapes, label_shapes, for_training, inputs_need_grad, force_rebind, shared_module,
+ grad_req)
+
+ def forward(self, data_batch, is_train=None):
+        """Forward computation for both modules. It supports data batches with different shapes, such as
+ different batch sizes or different image sizes.
+ If reshaping of data batch relates to modification of symbol or module, such as
+ changing image layout ordering or switching from training to predicting, module
+ rebinding is required.
+
+ See Also
+ ----------
+ :meth:`BaseModule.forward`.
+
+ Parameters
+ ----------
+ data_batch : DataBatch
+ Could be anything with similar API implemented.
+ is_train : bool
+ Default is ``None``, which means ``is_train`` takes the value of ``self.for_training``.
+ """
+
+ super(SVRGModule, self).forward(data_batch, is_train)
+
+ if is_train:
+ self._mod_aux.forward(data_batch, is_train)
+
+ def backward(self, out_grads=None):
+ """Backward computation.
+
+ See Also
+ ----------
+ :meth:`BaseModule.backward`.
+
+ Parameters
+ ----------
+ out_grads : NDArray or list of NDArray, optional
+ Gradient on the outputs to be propagated back.
+ This parameter is only needed when bind is called
+ on outputs that are not a loss function.
+ """
+
+ super(SVRGModule, self).backward(out_grads)
+
+ if self._mod_aux.binded:
+ self._mod_aux.backward(out_grads)
+
+ def update(self):
+ """Updates parameters according to the installed optimizer and the gradients computed
+ in the previous forward-backward batch. The gradients in the _exec_group will be overwritten
+ using the gradients calculated by the SVRG update rule.
+
+ When KVStore is used to update parameters for multi-device or multi-machine training,
+ a copy of the parameters is stored in KVStore. Note that for `row_sparse` parameters,
+ this function does update the copy of parameters in KVStore, but doesn't broadcast the
+ updated parameters to all devices / machines. Please call `prepare` to broadcast
+ `row_sparse` parameters with the next batch of data.
+
+ See Also
+ ----------
+ :meth:`BaseModule.update`.
+ """
+
+ self._update_svrg_gradients()
+ super(SVRGModule, self).update()
+
+ def update_full_grads(self, train_data):
+        """Computes the gradients over all data w.r.t. the weights from the last m-th epoch. In a distributed
+        environment, it will accumulate the full gradients in the KVStore.
+
+ Parameters
+ ----------
+ train_data: DataIter
+ Train data iterator
+
+ """
+ param_names = self._exec_group.param_names
+ arg, aux = self.get_params()
+ self._mod_aux.set_params(arg_params=arg, aux_params=aux)
+ train_data.reset()
+ nbatch = 0
+ padding = 0
+ for batch in train_data:
+ self._mod_aux.forward(batch, is_train=True)
+ self._mod_aux.backward()
+ nbatch += 1
+
+ for ctx in range(self._ctx_len):
+ for index, name in enumerate(param_names):
+ grads = self._mod_aux._exec_group.grad_arrays[index][ctx]
+ self._param_dict[ctx][name] = mx.nd.broadcast_add(self._param_dict[ctx][name], grads, axis=0)
+ padding = batch.pad
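+        # exclude padded samples in the last batch from the effective number of batches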
+ true_num_batch = nbatch - padding / train_data.batch_size
+ for name in param_names:
+ grad_list = []
+ for i in range(self._ctx_len):
+ self._param_dict[i][name] /= true_num_batch
+ grad_list.append(self._param_dict[i][name])
+ if self._kvstore:
+ # If in distributed mode, push a list of gradients from each worker/device to the KVStore
+ self._accumulate_kvstore(name, grad_list)
+
+ def _accumulate_kvstore(self, key, value):
+ """Accumulate gradients over all data in the KVStore. In distributed setting, each worker sees a portion of
+ data. The full gradients will be aggregated from each worker in the KVStore.
+
+ Parameters
+ ----------
+
+ key: int or str
+ Key in the KVStore.
+ value: NDArray, RowSparseNDArray
+ Average of the full gradients.
+ """
+
+ # Accumulate full gradients for current epochs
+ self._kvstore.push(key + "_full", value)
+ self._kvstore._barrier()
+ self._kvstore.pull(key + "_full", value)
+
+ self._allocate_gradients(key, value)
+
+ def _allocate_gradients(self, key, value):
+ """Allocate average of full gradients accumulated in the KVStore to each device.
+
+ Parameters
+ ----------
+
+ key: int or str
+ Key in the kvstore.
+ value: List of NDArray, List of RowSparseNDArray
+ A list of average of the full gradients in the KVStore.
+ """
+
+ for i in range(self._ctx_len):
+ self._param_dict[i][key] = value[i] / self._ctx_len
+
+ def _svrg_grads_update_rule(self, g_curr_batch_curr_weight, g_curr_batch_special_weight,
+ g_special_weight_all_batch):
+        """Calculates the gradient based on the SVRG update rule.
+
+        Parameters
+ ----------
+        g_curr_batch_curr_weight : NDArray
+            gradients w.r.t. the current batch of data for the current weights of this module
+        g_curr_batch_special_weight: NDArray
+            gradients w.r.t. the current batch of data for the weights from the last m-th epoch (from self._mod_aux)
+        g_special_weight_all_batch: NDArray
+            average of the full gradients over a full pass of the data
+
+ Returns
+ ----------
+ Gradients calculated using SVRG update rule:
+ grads = g_curr_batch_curr_weight - g_curr_batch_special_weight + g_special_weight_all_batch
+ """
+
+ for index, grad in enumerate(g_curr_batch_curr_weight):
+ grad -= g_curr_batch_special_weight[index]
+ grad += g_special_weight_all_batch[index]
+ return g_curr_batch_curr_weight
+
+ def _update_svrg_gradients(self):
+ """Calculates gradients based on the SVRG update rule.
+ """
+ param_names = self._exec_group.param_names
+ for ctx in range(self._ctx_len):
+ for index, name in enumerate(param_names):
+ g_curr_batch_reg = self._exec_group.grad_arrays[index][ctx]
+ g_curr_batch_special = self._mod_aux._exec_group.grad_arrays[index][ctx]
+ g_special_weight_all_batch = self._param_dict[ctx][name]
+ g_svrg = self._svrg_grads_update_rule(g_curr_batch_reg, g_curr_batch_special,
+ g_special_weight_all_batch)
+ self._exec_group.grad_arrays[index][ctx] = g_svrg
+
+ def fit(self, train_data, eval_data=None, eval_metric='acc',
+ epoch_end_callback=None, batch_end_callback=None, kvstore='local',
+ optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
+ eval_end_callback=None,
+ eval_batch_end_callback=None, initializer=mx.init.Uniform(0.01),
+ arg_params=None, aux_params=None, allow_missing=False,
+ force_rebind=False, force_init=False, begin_epoch=0, num_epoch=None,
+ validation_metric=None, monitor=None, sparse_row_id_fn=None):
+        """Trains the module parameters.
+
+        Parameters
+ ----------
+ train_data : DataIter
+ Train DataIter.
+ eval_data : DataIter
+ If not ``None``, will be used as validation set and the performance
+ after each epoch will be evaluated.
+ eval_metric : str or EvalMetric
+ Defaults to 'accuracy'. The performance measure used to display during training.
+ Other possible predefined metrics are:
+ 'ce' (CrossEntropy), 'f1', 'mae', 'mse', 'rmse', 'top_k_accuracy'.
+ epoch_end_callback : function or list of functions
+ Each callback will be called with the current `epoch`, `symbol`, `arg_params`
+ and `aux_params`.
+ batch_end_callback : function or list of function
+ Each callback will be called with a `BatchEndParam`.
+ kvstore : str or KVStore
+ Defaults to 'local'.
+ optimizer : str or Optimizer
+ Defaults to 'sgd'.
+ optimizer_params : dict
+ Defaults to ``(('learning_rate', 0.01),)``. The parameters for
+ the optimizer constructor.
+ The default value is not a dict, just to avoid pylint warning on dangerous
+ default values.
+ eval_end_callback : function or list of function
+ These will be called at the end of each full evaluation, with the metrics over
+ the entire evaluation set.
+ eval_batch_end_callback : function or list of function
+ These will be called at the end of each mini-batch during evaluation.
+ initializer : Initializer
+ The initializer is called to initialize the module parameters when they are
+ not already initialized.
+ arg_params : dict
+ Defaults to ``None``, if not ``None``, should be existing parameters from a trained
+ model or loaded from a checkpoint (previously saved model). In this case,
+ the value here will be used to initialize the module parameters, unless they
+ are already initialized by the user via a call to `init_params` or `fit`.
+ `arg_params` has a higher priority than `initializer`.
+ aux_params : dict
+ Defaults to ``None``. Similar to `arg_params`, except for auxiliary states.
+ allow_missing : bool
+ Defaults to ``False``. Indicates whether to allow missing parameters when `arg_params`
+ and `aux_params` are not ``None``. If this is ``True``, then the missing parameters
+ will be initialized via the `initializer`.
+ force_rebind : bool
+ Defaults to ``False``. Whether to force rebinding the executors if already bound.
+ force_init : bool
+ Defaults to ``False``. Indicates whether to force initialization even if the
+ parameters are already initialized.
+ begin_epoch : int
+ Defaults to 0. Indicates the starting epoch. Usually, if resumed from a
+ checkpoint saved at a previous training phase at epoch N, then this value should be
+ N+1.
+ num_epoch : int
+ Number of epochs for training.
+ sparse_row_id_fn : A callback function
+ The function takes `data_batch` as an input and returns a dict of
+ str -> NDArray. The resulting dict is used for pulling row_sparse
+ parameters from the kvstore, where the str key is the name of the param,
+ and the value is the row id of the param to pull.
+        validation_metric : str or EvalMetric
+            The performance measure used on the validation set. Defaults to `eval_metric` when not specified.
+        monitor : Monitor
+            If not ``None``, it will be installed on the module for monitoring the training process.
+ """
+
+ assert num_epoch is not None, 'please specify number of epochs'
+
+ self.bind(data_shapes=train_data.provide_data, label_shapes=train_data.provide_label,
+ for_training=True, force_rebind=force_rebind)
+ if monitor is not None:
+ self.install_monitor(monitor)
+ self.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params,
+ allow_missing=allow_missing, force_init=force_init)
+ self.init_optimizer(kvstore=kvstore, optimizer=optimizer, optimizer_params=optimizer_params)
+
+ if validation_metric is None:
+ validation_metric = eval_metric
+ if not isinstance(eval_metric, mx.metric.EvalMetric):
+ eval_metric = mx.metric.create(eval_metric)
+
+ ################################################################################
+ # training loop
+ ################################################################################
+ for epoch in range(begin_epoch, num_epoch):
+ eval_metric.reset()
+ tic = time.time()
+ if epoch % self.update_freq == 0:
+ self.update_full_grads(train_data)
+
+ train_data.reset()
+ data_iter = iter(train_data)
+ end_of_batch = False
+ nbatch = 0
+ next_data_batch = next(data_iter)
+
+ while not end_of_batch:
+ data_batch = next_data_batch
+ if monitor is not None:
+ monitor.tic()
+
+ self.forward_backward(data_batch)
+ self.update()
+
+ if isinstance(data_batch, list):
+ self.update_metric(eval_metric, [db.label for db in data_batch], pre_sliced=True)
+ else:
+ self.update_metric(eval_metric, data_batch.label)
+
+ try:
+ # pre fetch next batch
+ next_data_batch = next(data_iter)
+ self.prepare(next_data_batch, sparse_row_id_fn=sparse_row_id_fn)
+ except StopIteration:
+ end_of_batch = True
+
+ if monitor is not None:
+ monitor.toc_print()
+
+ if end_of_batch:
+ eval_name_vals = eval_metric.get_name_value()
+
+ if batch_end_callback is not None:
+ batch_end_params = mx.model.BatchEndParam(epoch=epoch, nbatch=nbatch,
+ eval_metric=eval_metric, locals=locals())
+ for callback in mx.base._as_list(batch_end_callback):
+ callback(batch_end_params)
+
+ nbatch += 1
+ for name, val in eval_name_vals:
+ self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
+ toc = time.time()
+ self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))
+
+ # sync aux params across devices
+ arg_params, aux_params = self.get_params()
+ self.set_params(arg_params, aux_params)
+
+ if epoch_end_callback is not None:
+ for callback in mx.base._as_list(epoch_end_callback):
+ callback(epoch, self.symbol, arg_params, aux_params)
+
+ # ----------------------------------------
+ # evaluation on validation set
+ if eval_data:
+ res = self.score(eval_data, validation_metric,
+ score_end_callback=eval_end_callback,
+ batch_end_callback=eval_batch_end_callback, epoch=epoch)
+ for name, val in res:
+ self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name, val)
+
+ def prepare(self, data_batch, sparse_row_id_fn=None):
+ """Prepares two modules for processing a data batch.
+
+ Usually involves switching bucket and reshaping.
+ For modules that contain `row_sparse` parameters in KVStore,
+ it prepares the `row_sparse` parameters based on the sparse_row_id_fn.
+
+ When KVStore is used to update parameters for multi-device or multi-machine training,
+ a copy of the parameters are stored in KVStore. Note that for `row_sparse` parameters,
+ the `update()` updates the copy of parameters in KVStore, but doesn't broadcast
+ the updated parameters to all devices / machines. The `prepare` function is used to
+ broadcast `row_sparse` parameters with the next batch of data.
+
+ Parameters
+ ----------
+ data_batch : DataBatch
+ The current batch of data for forward computation.
+
+ sparse_row_id_fn : A callback function
+ The function takes `data_batch` as an input and returns a dict of
+ str -> NDArray. The resulting dict is used for pulling row_sparse
+ parameters from the kvstore, where the str key is the name of the param,
+ and the value is the row id of the param to pull.
+ """
+ super(SVRGModule, self).prepare(data_batch, sparse_row_id_fn=sparse_row_id_fn)
+ self._mod_aux.prepare(data_batch, sparse_row_id_fn=sparse_row_id_fn)
diff --git a/python/mxnet/contrib/svrg_optimization/svrg_optimizer.py b/python/mxnet/contrib/svrg_optimization/svrg_optimizer.py
new file mode 100644
index 000000000000..0f695a1b2ff0
--- /dev/null
+++ b/python/mxnet/contrib/svrg_optimization/svrg_optimizer.py
@@ -0,0 +1,171 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""A `_SVRGOptimizer` encapsulates two optimizers to support SVRGModule in single machine and distributed settings.
+Both `_AssignmentOptimizer` and `_SVRGOptimizer` are designed to be used with SVRGModule only.
+"""
+
+
+import mxnet as mx
+
+
+@mx.optimizer.register
+class _AssignmentOptimizer(mx.optimizer.Optimizer):
+ """_AssignmentOptimizer assigns gradients to weights for SVRGModule's full gradients
+ accumulation in the KVStore. It is a helper optimizer that is designed to be used with SVRGModule only.
+ """
+ def update(self, index, weight, grad, state):
+ """Assign the gradients to weight for accumulating full gradients in the KVStore across all devices and workers.
+
+ Parameters
+ ----------
+ index : int
+ The unique index of the parameter into the individual learning
+ rates and weight decays. Learning rates and weight decay
+ may be set via `set_lr_mult()` and `set_wd_mult()`, respectively.
+ weight : NDArray
+ The parameter to be updated.
+ grad : NDArray
+ The gradient of the objective with respect to this parameter.
+ state: any obj
+ AssignmentOptimizer will not need to be associated with state.
+ """
+
+ weight[:] = grad
+
+
+@mx.optimizer.register
+class _SVRGOptimizer(mx.optimizer.Optimizer):
+ """_SVRGOptimizer is a wrapper class for two optimizers: _AssignmentOptimizer for accumulating full gradients in the
+    KVStore and a default optimizer that is passed in as a parameter to `mod.init_optimizer()`.
+ The _SVRGOptimizer is designed to be used with SVRGModule only.
+
+ This optimizer accepts the following parameters in addition to those accepted by :class:`.Optimizer`.
+
+ Parameters
+ ----------
+ default_optimizer: str or Optimizer
+        Optimizer passed in when invoking init_optimizer() on an SVRGModule
+ """
+
+ def __init__(self, default_optimizer, **kwargs):
+ # Reconstruct kwargs to identify additional params for default optimizer
+ base_param = self._check_params(**kwargs)
+ super(_SVRGOptimizer, self).__init__(**base_param)
+ if isinstance(default_optimizer, str):
+ self.default_opt = mx.optimizer.create(default_optimizer, **kwargs)
+ else:
+ self.default_opt = default_optimizer
+ self.aux_opt = mx.optimizer.create(_AssignmentOptimizer.__name__)
+
+ @staticmethod
+ def _check_params(**kwargs):
+        """Reassembles kwargs to identify additional optimizer params for the default optimizer. base_params
+        contains all the param names in the base class Optimizer.
+
+ Parameters
+ ----------
+ kwargs: dict
+ Parameters for the default optimizer
+
+ Returns
+ ----------
+ default_params: dict
+ Optimizer parameters that are defined in base class Optimizer
+ """
+
+ optimizer_param = dict(kwargs)
+ base_params = ['rescale_grad', 'param_idx2name', 'wd', 'clip_gradient', 'learning_rate', 'lr_scheduler', 'sym',
+ 'begin_num_update', 'multi_precision', 'param_dict']
+
+ default_params = {}
+        for key, value in optimizer_param.items():
+            if key in base_params:
+                default_params[key] = value
+
+ return default_params
+
+ def update(self, index, weight, grad, state):
+        """Updates the given parameter using the corresponding gradient and state. If the key contains 'full', the
+        update is performed with `_AssignmentOptimizer`; otherwise, the default optimizer is used.
+
+ Parameters
+ ----------
+ index : int
+ The unique index of the parameter into the individual learning
+ rates and weight decays. Learning rates and weight decay
+ may be set via `set_lr_mult()` and `set_wd_mult()`, respectively.
+ weight : NDArray
+ The parameter to be updated.
+ grad : NDArray
+ The gradient of the objective with respect to this parameter.
+ state : any obj
+ The state returned by `create_state()`.
+ """
+
+ name = self._check_index(index)
+
+ if "full" in name:
+ self.aux_opt.update(index, weight, grad, state)
+ else:
+ # use the default optimizer
+ self.default_opt.update(index, weight, grad, state)
+
+ def create_state(self, index, weight):
+ """Creates auxiliary state for a given weight.
+        Some optimizers require additional states, e.g. momentum, in addition
+ to gradients in order to update weights. This function creates state
+ for a given weight which will be used in `update`. This function is
+ called only once for each weight.
+
+ Parameters
+ ----------
+        index : int
+            A unique index to identify the weight.
+        weight : NDArray
+            The weight.
+
+        Returns
+ -------
+ state : any obj
+ The state associated with the weight.
+ """
+
+ name = self._check_index(index)
+ if "full" in name:
+ return self.aux_opt.create_state(index, weight)
+        else:
+            return self.default_opt.create_state(index, weight)
+
+ def _check_index(self, index):
+        """Checks index in idx2name to get the corresponding parameter name.
+
+        Parameters
+        ----------
+        index : int or str
+            A unique index to identify the weight.
+ Returns
+ -------
+ name : str
+ Name of the Module parameter
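+
+        Examples
+        --------
+        With an illustrative ``idx2name={0: 'fc1_weight'}``, both ``_check_index(0)`` and
+        ``_check_index('fc1_weight')`` return ``'fc1_weight'``.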
+ """
+
+ if index in self.idx2name.values():
+ # index is a str
+ name = index
+ else:
+ # index is an int
+ name = self.idx2name[index]
+ return name
diff --git a/tests/python/unittest/test_contrib_svrg_module.py b/tests/python/unittest/test_contrib_svrg_module.py
new file mode 100644
index 000000000000..c470a1a7c566
--- /dev/null
+++ b/tests/python/unittest/test_contrib_svrg_module.py
@@ -0,0 +1,307 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+import numpy as np
+from common import with_seed, assertRaises
+from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule
+from mxnet.test_utils import *
+
+
+def setup():
+ train_data = np.random.randint(1, 5, [1000, 2])
+ weights = np.array([1.0, 2.0])
+ train_label = train_data.dot(weights)
+
+ di = mx.io.NDArrayIter(train_data, train_label, batch_size=32, shuffle=True, label_name='lin_reg_label')
+ X = mx.sym.Variable('data')
+ Y = mx.symbol.Variable('lin_reg_label')
+ fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1)
+ lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro")
+
+ mod = SVRGModule(
+ symbol=lro,
+ data_names=['data'],
+ label_names=['lin_reg_label'], update_freq=2)
+ mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
+ mod.init_params(initializer=mx.init.Uniform(0.01), allow_missing=False, force_init=False, allow_extra=False)
+
+ return di, mod
+
+
+def test_bind_module():
+ _, mod = setup()
+    assert mod.binded
+    assert mod._mod_aux.binded
+
+
+def test_module_init():
+ _, mod = setup()
+ assert mod._mod_aux is not None
+
+
+def test_module_initializer():
+ def regression_model(m):
+ x = mx.symbol.var("data", stype='csr')
+ v = mx.symbol.var("v", shape=(m, 1), init=mx.init.Uniform(scale=.1),
+ stype='row_sparse')
+ model = mx.symbol.dot(lhs=x, rhs=v)
+ y = mx.symbol.Variable("label")
+ model = mx.symbol.LinearRegressionOutput(data=model, label=y, name="out")
+ return model
+
+    # shape of the data
+ n, m = 128, 100
+ model = regression_model(m)
+
+ data = mx.nd.zeros(shape=(n, m), stype='csr')
+ label = mx.nd.zeros((n, 1))
+ iterator = mx.io.NDArrayIter(data=data, label={'label': label},
+ batch_size=n, last_batch_handle='discard')
+
+ # create module
+ mod = SVRGModule(symbol=model, data_names=['data'], label_names=['label'], update_freq=2)
+ mod.bind(data_shapes=iterator.provide_data, label_shapes=iterator.provide_label)
+ mod.init_params()
+ v = mod._arg_params['v']
+ assert v.stype == 'row_sparse'
+ assert np.sum(v.asnumpy()) != 0
+
+
+def test_module_bind():
+ x = mx.sym.Variable("data")
+ net = mx.sym.FullyConnected(x, num_hidden=1)
+
+ mod = SVRGModule(symbol=net, data_names=['data'], label_names=None, update_freq=2)
+ assertRaises(TypeError, mod.bind, data_shapes=['data', mx.nd.zeros(shape=(2, 1))])
+
+ mod.bind(data_shapes=[('data', (2, 1))])
+    assert mod.binded
+    assert mod._mod_aux.binded
+
+
+@with_seed()
+def test_module_save_load():
+ import tempfile
+ import os
+
+ x = mx.sym.Variable("data")
+ y = mx.sym.Variable("softmax_label")
+ net = mx.sym.FullyConnected(x, y, num_hidden=1)
+
+ mod = SVRGModule(symbol=net, data_names=['data'], label_names=['softmax_label'], update_freq=2)
+ mod.bind(data_shapes=[('data', (1, 1))])
+ mod.init_params()
+ mod.init_optimizer(optimizer='sgd', optimizer_params={'learning_rate': 0.1})
+ mod.update()
+
+ # Create tempfile
+ tmp = tempfile.mkdtemp()
+ tmp_file = os.path.join(tmp, 'svrg_test_output')
+ mod.save_checkpoint(tmp_file, 0, save_optimizer_states=True)
+
+ mod2 = SVRGModule.load(tmp_file, 0, load_optimizer_states=True, data_names=('data', ))
+ mod2.bind(data_shapes=[('data', (1, 1))])
+ mod2.init_optimizer(optimizer_params={'learning_rate': 0.1})
+ assert mod._symbol.tojson() == mod2._symbol.tojson()
+
+ # Multi-device
+ mod3 = SVRGModule(symbol=net, data_names=['data'], label_names=['softmax_label'], update_freq=3,
+ context=[mx.cpu(0), mx.cpu(1)])
+ mod3.bind(data_shapes=[('data', (10, 10))])
+ mod3.init_params()
+ mod3.init_optimizer(optimizer_params={'learning_rate': 1.0})
+ mod3.update()
+ mod3.save_checkpoint(tmp_file, 0, save_optimizer_states=True)
+
+ mod4 = SVRGModule.load(tmp_file, 0, load_optimizer_states=True, data_names=('data', ))
+ mod4.bind(data_shapes=[('data', (10, 10))])
+ mod4.init_optimizer(optimizer_params={'learning_rate': 1.0})
+ assert mod3._symbol.tojson() == mod4._symbol.tojson()
+
+
+@with_seed()
+def test_svrgmodule_reshape():
+ data = mx.sym.Variable("data")
+ sym = mx.sym.FullyConnected(data=data, num_hidden=4, name='fc')
+
+    dshape = (3, 4)
+ mod = SVRGModule(sym, data_names=["data"], label_names=None, context=[mx.cpu(0), mx.cpu(1)], update_freq=1)
+ mod.bind(data_shapes=[('data', dshape)])
+ mod.init_params()
+ mod._mod_aux.init_params()
+ mod.init_optimizer(optimizer_params={"learning_rate": 1.0})
+
+ data_batch = mx.io.DataBatch(data=[mx.nd.ones(dshape)], label=None)
+ mod.forward(data_batch)
+ mod.backward([mx.nd.ones(dshape)])
+ mod.update()
+ assert mod.get_outputs()[0].shape == dshape
+
+ dshape = (2, 4)
+ mod.reshape(data_shapes=[('data', dshape)])
+ mod.forward(mx.io.DataBatch(data=[mx.nd.ones(dshape)],
+ label=None))
+ mod.backward([mx.nd.ones(dshape)])
+ mod.update()
+ assert mod.get_outputs()[0].shape == dshape
+
+
+@with_seed()
+def test_update_full_grad():
+ def create_network():
+ train_data = np.random.randint(1, 5, [10, 2])
+ weights = np.array([1.0, 2.0])
+ train_label = train_data.dot(weights)
+
+ di = mx.io.NDArrayIter(train_data, train_label, batch_size=5, shuffle=True, label_name='lin_reg_label')
+ X = mx.sym.Variable('data')
+ Y = mx.symbol.Variable('lin_reg_label')
+ fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1)
+ lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro")
+
+ mod = SVRGModule(
+ symbol=lro,
+ data_names=['data'],
+ label_names=['lin_reg_label'], update_freq=2)
+ mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
+ mod.init_params(initializer=mx.init.One(), allow_missing=False, force_init=False, allow_extra=False)
+ mod.init_optimizer(kvstore='local', optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
+ force_init=False)
+ return di, mod
+
+ di, svrg_mod = create_network()
+
+    # Calculate the average of the full gradients over the number of batches
+ full_grads_weights = mx.nd.zeros(shape=svrg_mod.get_params()[0]['fc1_weight'].shape)
+ arg, aux = svrg_mod.get_params()
+ svrg_mod._mod_aux.set_params(arg_params=arg, aux_params=aux)
+ num_batch = 2
+
+ for batch in di:
+ svrg_mod.forward(batch)
+ svrg_mod.backward()
+ full_grads_weights = mx.nd.broadcast_add(svrg_mod._exec_group.grad_arrays[0][0], full_grads_weights, axis=0)
+ full_grads_weights /= num_batch
+
+ di.reset()
+ svrg_mod.update_full_grads(di)
+ assert same(full_grads_weights, svrg_mod._param_dict[0]['fc1_weight'])
+
+
+@with_seed()
+def test_svrg_with_sgd():
+ def create_module_with_sgd():
+ train_data = np.random.randint(1, 5, [100, 2])
+ weights = np.array([1.0, 2.0])
+ train_label = train_data.dot(weights)
+
+ di = mx.io.NDArrayIter(train_data, train_label, batch_size=10, shuffle=True, label_name='lin_reg_label')
+ X = mx.sym.Variable('data')
+ Y = mx.symbol.Variable('lin_reg_label')
+ fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1)
+ lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro")
+
+ reg_mod = mx.mod.Module(
+ symbol=lro,
+ data_names=['data'],
+ label_names=['lin_reg_label'])
+ reg_mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
+ reg_mod.init_params(initializer=mx.init.One(), allow_missing=False, force_init=False, allow_extra=False)
+ reg_mod.init_optimizer(kvstore='local', optimizer='sgd', optimizer_params=(('learning_rate', 0.01),))
+
+ svrg_mod = SVRGModule(symbol=lro,
+ data_names=['data'],
+ label_names=['lin_reg_label'],
+ update_freq=2)
+ svrg_mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
+ svrg_mod.init_params(initializer=mx.init.One(), allow_missing=False, force_init=False, allow_extra=False)
+ svrg_mod.init_optimizer(kvstore='local', optimizer='sgd', optimizer_params=(('learning_rate', 0.01),))
+
+        return di, reg_mod, svrg_mod
+
+ di, reg_mod, svrg_mod = create_module_with_sgd()
+ num_epoch = 10
+
+ # Use metric MSE
+ metrics = mx.metric.create("mse")
+
+ # Train with SVRGModule
+ for e in range(num_epoch):
+ metrics.reset()
+ if e % svrg_mod.update_freq == 0:
+ svrg_mod.update_full_grads(di)
+ di.reset()
+ for batch in di:
+ svrg_mod.forward_backward(data_batch=batch)
+ svrg_mod.update()
+ svrg_mod.update_metric(metrics, batch.label)
+ svrg_mse = metrics.get()[1]
+
+ # Train with SGD standard Module
+ di.reset()
+ for e in range(num_epoch):
+ metrics.reset()
+ di.reset()
+ for batch in di:
+ reg_mod.forward_backward(data_batch=batch)
+ reg_mod.update()
+ reg_mod.update_metric(metrics, batch.label)
+ sgd_mse = metrics.get()[1]
+
+ assert svrg_mse < sgd_mse
+
+
+@with_seed()
+def test_accumulate_kvstore():
+    # Test KVStore behavior when pushing a list of values
+ kv = mx.kv.create('local')
+ kv.init("fc1_weight", mx.nd.zeros(shape=(1, 2)))
+ kv.init("fc1_weight_full", mx.nd.zeros(shape=(1, 2)))
+ b = [mx.nd.ones(shape=(1, 2)) for i in range(4)]
+ a = mx.nd.zeros(shape=(1, 2))
+ kv.push("fc1_weight_full", b)
+ kv.pull("fc1_weight_full", out=a)
+ assert same(a, [mx.nd.array([4, 4])])
+ assert kv.num_workers == 1
+
+ # Test accumulate in KVStore and allocate gradients
+ kv_test = mx.kv.create('local')
+ _, svrg_mod = setup()
+ svrg_mod.init_optimizer(kvstore=kv_test, optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
+ force_init=False)
+ svrg_mod._accumulate_kvstore("fc1_weight", b)
+ assert len(svrg_mod._param_dict) == svrg_mod._ctx_len
+ assert same(svrg_mod._param_dict[0]["fc1_weight"], b[0])
+
+
+@with_seed()
+def test_fit():
+ di, mod = setup()
+ num_epoch = 100
+ metric = mx.metric.create("mse")
+ mod.fit(di, eval_metric=metric, optimizer='sgd', optimizer_params=(('learning_rate', 0.025),), num_epoch=num_epoch,
+ kvstore='local')
+
+    # Estimated MSE for the SGD optimizer with lr = 0.025; the SVRG MSE should be smaller
+ estimated_mse = 1e-5
+ assert metric.get()[1] < estimated_mse
+
+
+if __name__ == "__main__":
+ import nose
+ nose.runmodule()
diff --git a/tests/python/unittest/test_contrib_svrg_optimizer.py b/tests/python/unittest/test_contrib_svrg_optimizer.py
new file mode 100644
index 000000000000..f7d90d12872f
--- /dev/null
+++ b/tests/python/unittest/test_contrib_svrg_optimizer.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import mxnet as mx
+from mxnet.test_utils import same
+from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule
+from mxnet.contrib.svrg_optimization.svrg_optimizer import _SVRGOptimizer
+
+
+def create_network():
+ train_data = np.random.randint(1, 5, [1000, 2])
+ weights = np.array([1.0, 2.0])
+ train_label = train_data.dot(weights)
+
+ batch_size = 32
+
+ di = mx.io.NDArrayIter(train_data, train_label, batch_size=batch_size, shuffle=True, label_name='lin_reg_label')
+ X = mx.sym.Variable('data')
+ Y = mx.symbol.Variable('lin_reg_label')
+ fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1)
+ lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro")
+
+ mod = SVRGModule(
+ symbol=lro,
+ data_names=['data'],
+ label_names=['lin_reg_label'], update_freq=2
+ )
+
+ mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
+ mod.init_params(initializer=mx.init.Uniform(0.01), allow_missing=False,
+ force_init=False, allow_extra=False)
+
+ return di, mod
+
+
+def test_init_svrg_optimizer():
+ _, mod = create_network()
+
+ kv = mx.kv.create('local')
+ mod.init_optimizer(kvstore=kv, optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
+ force_init=False)
+
+ assert type(mod._optimizer).__name__ == _SVRGOptimizer.__name__
+
+
+def test_svrg_optimizer_constructor():
+ kv = mx.kv.create('local')
+ svrg_optimizer = _SVRGOptimizer(default_optimizer='sgd', learning_rate=-1.0)
+ kv.set_optimizer(svrg_optimizer)
+
+ assert svrg_optimizer.default_opt.lr == -1.0
+
+
+def test_kvstore_init_aux_keys():
+ param_idx2name = {0: "weight", 1: "weight_full"}
+
+    svrg_optimizer = _SVRGOptimizer(default_optimizer='sgd', param_idx2name=param_idx2name, learning_rate=1.0)
+ kv = mx.kv.create('local')
+ kv.set_optimizer(svrg_optimizer)
+
+ # Use default sgd optimizer
+ param_weight_init = mx.nd.array([0, 0, 0])
+ param_weight_update = mx.nd.array([1, 1, 1])
+
+ kv.init(0, param_weight_init)
+ kv.push(0, param_weight_update)
+ kv.pull(0, param_weight_init)
+
+ param_weight_full_init = mx.nd.array([1, 1, 1])
+ param_weight_full_update = mx.nd.array([2, 2, 2])
+
+ # Use AssignmentOptimizer
+ kv.init(1, param_weight_full_init)
+ kv.push(1, param_weight_full_update)
+ kv.pull(1, param_weight_full_init)
+
+ # updated weights using default sgd optimizer
+ assert same(param_weight_init.asnumpy(), np.array([-1, -1, -1]))
+ # updated with AssignmentOptimizer
+ assert same(param_weight_full_init.asnumpy(), np.array([2, 2, 2]))
+
+
+if __name__ == "__main__":
+ import nose
+ nose.runmodule()