diff --git a/contrib/svrg_optimization_python/README.md b/contrib/svrg_optimization_python/README.md
deleted file mode 100644
index cc3b6bc41357..000000000000
--- a/contrib/svrg_optimization_python/README.md
+++ /dev/null
@@ -1,39 +0,0 @@
-## SVRG Optimization in Python Module
-
-### Problem Description
-SVRG stands for Stochastic Variance Reduced Gradient, which was first introduced in the paper _Accelerating Stochastic
-Gradient Descent using Predicative Variance Reduction_ in 2013. It is an optimization technique that complements SGD.
-SGD is known for large scale optimization but it suffers from slow convergence asymptotically due to the inherent
-variance. SGD approximates the full gradient using a small batch of samples which introduces variance.
-In order to converge faster, SGD often needs to start with a smaller learning rate. SVRG remedies the problem by keeping
-a version of the estimated weights that is close to the optimal parameters and maintain average of full gradient over
-full pass of data. The average of full gradients of all data is calculated w.r.t to parameters of last mth epochs.
-It has provable guarantees for strongly convex smooth functions, and a more detailed proof can be found in section 3 of
-the paper. SVRG uses a different update rule: gradients w.r.t current parameters minus gradients w.r.t parameters
-from the last mth epoch, plus the average of gradients over all data.
-
-#### Key Characteristics of SVRG
-* Explicit Variance Reduction
-* Ability to use relatively large learning rate compared to SGD, which leads to faster convergence compared to SGD.
-
-#### Testing:
-Functional Tests:
-* test_svrg_train.py: test script that tests both high-level and intermediate-level api for using SVRG
-* test_svrg_inferency.py: test script for testing SVRGModule inference
-
-Unit Tests:
-* test_svrg_module.py: unittests for SVRGModule API
-* test_svrg_optimizer.py: unittests for SVRGOptimizer API
-
-#### Benchmarks:
-An initial set of benchmarks has been performed on YearPredictionDatasetMSD with linear regression model.
-
-* benchmark1.py: A lr_scheduler returns a new learning rate based on the number of updates that have been performed.
-The training loss of SVRG is less than SGD with lr_scheduler over all of the 100 epochs.
-
-* benchmark2.py: One drawback for SGD is that in order to converge faster, the learning rate has to decay to zero,
-thus SGD needs to start with a small learning rate. The learning rate does not need to decay to zero for SVRG,
-therefore we can use a relatively larger learning rate. SGD with learning rate of (0.001, 0.0025) and SVRG with
-learning rate of (0.025) are benchmarked. Even though SVRG starts with a relatively large learning rate, it converges
-much faster than SGD in both cases.
-This particular experiment result aligns with what was stated in the SVRG paper section 5.
\ No newline at end of file
diff --git a/contrib/svrg_optimization_python/src/__init__.py b/contrib/svrg_optimization_python/src/__init__.py
deleted file mode 100644
index e69de29bb2d1..000000000000
diff --git a/contrib/svrg_optimization_python/tests/__init__.py b/contrib/svrg_optimization_python/tests/__init__.py
deleted file mode 100644
index b7a3e645e0d5..000000000000
--- a/contrib/svrg_optimization_python/tests/__init__.py
+++ /dev/null
@@ -1,21 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from __future__ import absolute_import
-from ..src.svrg_module import SVRGModule
-from ..src.svrg_optimizer import SVRGOptimizer
-
diff --git a/contrib/svrg_optimization_python/tests/test_svrg_module.py b/contrib/svrg_optimization_python/tests/test_svrg_module.py
deleted file mode 100644
index 5118ae1656fb..000000000000
--- a/contrib/svrg_optimization_python/tests/test_svrg_module.py
+++ /dev/null
@@ -1,116 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import unittest
-from ..src.svrg_module import SVRGModule
-import mxnet as mx
-import numpy as np
-
-
-class TestSVRGModule(unittest.TestCase):
- def setUp(self):
- mx.random.seed(42)
- train_data = np.random.randint(1, 5, [1000, 2])
- weights = np.array([1.0, 2.0])
- train_label = train_data.dot(weights)
-
- self.di = mx.io.NDArrayIter(train_data, train_label, batch_size=32, shuffle=True, label_name='lin_reg_label')
- X = mx.sym.Variable('data')
- Y = mx.symbol.Variable('lin_reg_label')
- fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1)
- lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro")
-
- self.mod = SVRGModule(
- symbol=lro,
- data_names=['data'],
- label_names=['lin_reg_label'], update_freq=2)
- self.mod.bind(data_shapes=self.di.provide_data, label_shapes=self.di.provide_label)
- self.mod.init_params(initializer=mx.init.Uniform(0.01), allow_missing=False,
- force_init=False, allow_extra=False)
-
- def test_create_module(self):
- self.assertTrue(self.mod._mod_aux is not None)
-
- def test_bind_module(self):
- self.assertTrue(self.mod.binded)
- self.assertTrue(self.mod._mod_aux.binded)
-
- def test_module_initializer(self):
- def regression_model(m):
- x = mx.symbol.var("data", stype='csr')
- v = mx.symbol.var("v", shape=(m, 1), init=mx.init.Uniform(scale=.1),
- stype='row_sparse')
- model = mx.symbol.dot(lhs=x, rhs=v)
- y = mx.symbol.Variable("label")
- model = mx.symbol.LinearRegressionOutput(data=model, label=y, name="out")
- return model
-
- n, m = 128, 100
- model = regression_model(m)
-
- data = mx.nd.zeros(shape=(n, m), stype='csr')
- label = mx.nd.zeros((n, 1))
- iterator = mx.io.NDArrayIter(data=data, label={'label': label},
- batch_size=n, last_batch_handle='discard')
-
- # create module
- mod = SVRGModule(symbol=model, data_names=['data'], label_names=['label'], update_freq=2)
- mod.bind(data_shapes=iterator.provide_data, label_shapes=iterator.provide_label)
- mod.init_params()
- v = mod._arg_params['v']
- self.assertEqual(v.stype, 'row_sparse')
- self.assertTrue(np.sum(v.asnumpy()) != 0)
-
- @unittest.skip("SVRGModule with Pure SGD will not be a release feature")
- def test_svrg_calculations(self):
- def calc_svrg_optimization(update_freq):
- mx.random.seed(42)
- train_data = np.random.randint(1, 5, [1000, 2])
- weights = np.array([1.0, 2.0])
- train_label = train_data.dot(weights)
-
- di = mx.io.NDArrayIter(train_data, train_label, batch_size=32, shuffle=True, label_name='lin_reg_label')
- X = mx.sym.Variable('data')
- Y = mx.symbol.Variable('lin_reg_label')
- fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1)
- lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro")
-
- mod = SVRGModule(
- symbol=lro,
- data_names=['data'],
- label_names=['lin_reg_label'], update_freq=update_freq)
- mod.bind(data_shapes=self.di.provide_data, label_shapes=self.di.provide_label)
- mod.init_params(initializer=mx.init.Uniform(0.01), allow_missing=False, force_init=False, allow_extra=False)
- mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.01),))
- num_epoch = 100
-
- metrics = mx.metric.create("mse")
- for e in range(1, num_epoch + 1):
- if e % (mod.update_freq + 1) == 0:
- mod.update_full_grads(di)
- di.reset()
- metrics.reset()
- for batch in di:
- mod.forward_backward(data_batch=batch)
- mod.update()
- mod.update_metric(metrics, batch.label)
- return metrics.get()[1]
-
- svrg_mse = calc_svrg_optimization(update_freq=2)
- sgd_mse = calc_svrg_optimization(update_freq=101)
-
- self.assertTrue(svrg_mse - sgd_mse < 0)
diff --git a/contrib/svrg_optimization_python/tests/test_svrg_optimizer.py b/contrib/svrg_optimization_python/tests/test_svrg_optimizer.py
deleted file mode 100644
index 36d44f64b758..000000000000
--- a/contrib/svrg_optimization_python/tests/test_svrg_optimizer.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-
-import unittest
-from ..src.svrg_optimizer import SVRGOptimizer
-from ..src.svrg_module import SVRGModule
-import mxnet as mx
-import numpy as np
-from numpy.testing import assert_array_equal
-
-
-class TestSVRGOPtimizer(unittest.TestCase):
- @staticmethod
- def create_network():
- mx.random.seed(42)
- train_data = np.random.randint(1, 5, [1000, 2])
- weights = np.array([1.0, 2.0])
- train_label = train_data.dot(weights)
-
- batch_size = 32
-
- di = mx.io.NDArrayIter(train_data, train_label, batch_size=batch_size, shuffle=True, label_name='lin_reg_label')
- X = mx.sym.Variable('data')
- Y = mx.symbol.Variable('lin_reg_label')
- fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1)
- lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro")
-
- mod = SVRGModule(
- symbol=lro,
- data_names=['data'],
- label_names=['lin_reg_label'], update_freq=2
- )
-
- mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
- mod.init_params(initializer=mx.init.Uniform(0.01), allow_missing=False,
- force_init=False, allow_extra=False)
-
- return di, mod
-
- def test_init_svrg_optimizer(self):
- di, mod = self.create_network()
-
- kv = mx.kv.create('local')
- mod.init_optimizer(kvstore=kv, optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
- force_init=False)
-
- self.assertEqual(type(mod._optimizer).__name__, SVRGOptimizer.__name__)
-
- def test_svrg_optimizer_constructor(self):
- _, mod = self.create_network()
-
- kv = mx.kv.create('local')
- svrg_optimizer = SVRGOptimizer(default_optimizer='sgd', learning_rate=1.0)
- kv.set_optimizer(svrg_optimizer)
-
- self.assertEqual(svrg_optimizer.default_opt.lr, 1.0)
-
- def test_kvstore_init_aux_keys(self):
- param_idx2name= {0: "weight", 1: "weight_full"}
-
- svrg_optimizer = SVRGOptimizer(default_optimizer='sgd', param_idx2name= param_idx2name, learning_rate=1.0)
- kv = mx.kv.create('local')
- kv.set_optimizer(svrg_optimizer)
-
- param_weight_init = mx.nd.array([0, 0, 0])
- param_weight_update = mx.nd.array([1, 1, 1])
-
- kv.init(0, param_weight_init)
- kv.push(0, param_weight_update)
- kv.pull(0, param_weight_init)
-
- param_weight_full_init = mx.nd.array([1, 1, 1])
- param_weight_full_update = mx.nd.array([2, 2, 2])
-
- # Use AssignmentOptimizer
- kv.init(1, param_weight_full_init)
- kv.push(1, param_weight_full_update)
- kv.pull(1, param_weight_full_init)
-
- assert_array_equal(param_weight_init.asnumpy(), np.array([-1, -1, -1]))
- assert_array_equal(param_weight_full_init.asnumpy(), np.array([2, 2, 2]))
diff --git a/docs/api/python/contrib/svrg_optimization.md b/docs/api/python/contrib/svrg_optimization.md
new file mode 100644
index 000000000000..2b2c6812b864
--- /dev/null
+++ b/docs/api/python/contrib/svrg_optimization.md
@@ -0,0 +1,80 @@
+# SVRG Optimization in Python Module API
+
+## Overview
+SVRG, which stands for Stochastic Variance Reduced Gradient, is an optimization technique that complements SGD. It
+employs explicit variance reduction and converges much faster than SGD for smooth and strongly convex functions.
+
+SVRG optimization is implemented as an SVRGModule in `mxnet.contrib.svrg_optimization`, which extends the
+existing `mxnet.module.Module` APIs and encapsulates the SVRG optimization logic within several new functions. For end
+users, the SVRGModule API requires only minimal changes compared to the Module API.
+
+The current `SVRGModule` implements the standard SVRG optimization technique as described in _Accelerating Stochastic
+Gradient Descent using Predictive Variance Reduction_ by calculating the full gradients over all data
+every `update_freq` epochs during training. The SVRGModule update rule is: gradients w.r.t. the current parameters,
+minus gradients w.r.t. the parameters from the last mth epoch, plus the average of gradients over all data.
+
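+The update rule can be illustrated with a minimal NumPy sketch (the variable names below are illustrative only and not
+part of the API):
+
+```python
+>>> import numpy as np
+>>> grad_w = np.array([0.5, -0.2])           # gradient of a batch w.r.t. the current parameters
+>>> grad_w_snapshot = np.array([0.4, -0.1])  # gradient of the same batch w.r.t. the snapshot parameters
+>>> full_grad_avg = np.array([0.45, -0.15])  # average of full gradients at the snapshot parameters
+>>> svrg_grad = grad_w - grad_w_snapshot + full_grad_avg
+```
+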
+`SVRGOptimizer` wraps two optimizers: an AssignmentOptimizer, which is used to accumulate the full gradients in the
+KVStore, and a regular optimizer, which is specified as a parameter to `mod.init_optimizer()`.
+
+```eval_rst
+.. warning:: This package contains experimental APIs and may change in the near future.
+```
+
+This document lists the svrg_optimization APIs in mxnet:
+
+```eval_rst
+.. autosummary::
+ :nosignatures:
+
+ mxnet.contrib.svrg_optimization.SVRGModule
+ mxnet.contrib.svrg_optimization.SVRGOptimizer
+```
+
+### Intermediate Level API for SVRGModule
+
+The only extra step in using an SVRGModule, compared to using a Module, is to check whether the current epoch should
+update the full gradients over all data. The code snippet below demonstrates the suggested usage of SVRGModule using the
+intermediate-level APIs.
+
+```python
+>>> mod = SVRGModule(symbol=model, update_freq=2, data_names=['data'], label_names=['lin_reg_label'])
+>>> mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
+>>> mod.init_params()
+>>> mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.01), ))
+>>> for epoch in range(num_epochs):
+... if epoch % mod.update_freq == 0:
+... mod.update_full_grads(di)
+... di.reset()
+... for batch in di:
+... mod.forward_backward(data_batch=batch)
+... mod.update()
+```
+
+### High Level API for SVRGModule
+
+The high-level API usage of SVRGModule remains exactly the same as the Module API. The code snippet below gives an
+example of the suggested usage of the high-level API.
+
+```python
+>>> mod = SVRGModule(symbol=model, update_freq=2, data_names=['data'], label_names=['lin_reg_label'])
+>>> mod.fit(di, num_epoch=100, optimizer='sgd', optimizer_params=(('learning_rate', 0.01), ))
+```
+
+## API reference
+
+
+
+```eval_rst
+
+.. automodule:: mxnet.contrib.svrg_optimization.svrg_module
+ :members: init_optimizer, _create_optimizer, bind, forward, backward, update, update_full_grads,
+ _accumulate_kvstore, _allocate_gradients, _svrg_grads_update_rule, update_svrg_gradients, fit, prepare
+
+.. autoclass:: mxnet.contrib.svrg_optimization.svrg_optimizer.SVRGOptimizer
+ :members: _check_params, update, create_state, _check_index
+
+.. autoclass:: mxnet.contrib.svrg_optimization.svrg_optimizer.AssignmentOptimizer
+ :members: update
+
+```
+
\ No newline at end of file
diff --git a/docs/api/python/module/module.md b/docs/api/python/module/module.md
index 86ed74db6c19..662caa78ff93 100644
--- a/docs/api/python/module/module.md
+++ b/docs/api/python/module/module.md
@@ -58,6 +58,7 @@ The `module` package provides several modules:
BucketingModule
PythonModule
PythonLossModule
+ SVRGModule
```
We summarize the interface for each class in the following sections.
@@ -188,6 +189,23 @@ additional functionality. We summarize them in this section.
SequentialModule.add
```
+### Class `SVRGModule`
+SVRGModule is an extension of the Module API that implements SVRG (Stochastic Variance Reduced Gradient) optimization
+logic. A few extra functions are defined to assist the SVRG optimization update; however, these functions are encapsulated
+in Module's existing function calls and do not require explicit invocation by end users of the high-level API.
+
+```eval_rst
+.. autosummary::
+ :nosignatures:
+
+ SVRGModule.update_full_grads
+ SVRGModule.update_svrg_gradients
+ SVRGModule._svrg_grads_update_rule
+ SVRGModule._accumulate_kvstore
+ SVRGModule._allocate_gradients
+ SVRGModule._create_optimizer
+```
+
## API Reference
@@ -205,6 +223,8 @@ additional functionality. We summarize them in this section.
:members:
.. autoclass:: mxnet.module.PythonLossModule
:members:
+.. autoclass:: mxnet.contrib.svrg_optimization.SVRGModule
+ :members:
```
diff --git a/example/svrg_module/README.md b/example/svrg_module/README.md
new file mode 100644
index 000000000000..23515c30633e
--- /dev/null
+++ b/example/svrg_module/README.md
@@ -0,0 +1,26 @@
+## SVRG Module Example
+
+#### API Usage Example
+SVRGModule provides both high-level and intermediate-level APIs while minimizing the changes relative to the Module API.
+* example_api_train.py: provides suggested usage of the SVRGModule high-level and intermediate-level APIs.
+* example_inference.py: provides an example of SVRGModule inference.
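+
+Both example scripts accept command-line flags for the number of epochs (`-e`), batch size (`-bs`) and update frequency
+(`-f`), e.g. `python example_api_train.py -e 100 -bs 32 -f 2`.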
+
+#### Linear Regression
+This example trains a linear regression model using SVRGModule on a real dataset. Logs of the training results can be
+found in experiments.log, which will be automatically generated when running the training script.
+
+YearPredictionMSD: the dataset contains over 400,000 samples with 90 features. Please uncomment the data downloading
+code in data_reader.py to download the data.
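+
+For reference, the commented-out download code in data_reader.py is equivalent to the following sketch (URL taken from
+that file):
+
+```python
+from subprocess import call
+
+# download and decompress the YearPredictionMSD dataset (libsvm format)
+call(['wget', 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression/YearPredictionMSD.bz2'])
+call(['bzip2', '-d', 'YearPredictionMSD.bz2'])
+```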
+
+
+#### Benchmarks:
+An initial set of benchmarks has been performed on YearPredictionDatasetMSD with linear regression model.
+
+* benchmark1.py: An lr_scheduler returns a new learning rate based on the number of updates that have been performed.
+The training loss of SVRG is lower than that of SGD with an lr_scheduler over all 100 epochs.
+
+* benchmark2.py: One drawback of SGD is that, in order to converge faster, the learning rate has to decay to zero, so
+SGD needs to start with a small learning rate. The learning rate does not need to decay to zero for SVRG, so a
+relatively larger learning rate can be used. SGD with learning rates of 0.001 and 0.0025 and SVRG with a learning rate
+of 0.025 are benchmarked. Even though SVRG starts with a relatively large learning rate, it converges much faster than
+SGD in both cases.
\ No newline at end of file
diff --git a/contrib/svrg_optimization_python/test_svrg_train.py b/example/svrg_module/api_usage_example/example_api_train.py
similarity index 63%
rename from contrib/svrg_optimization_python/test_svrg_train.py
rename to example/svrg_module/api_usage_example/example_api_train.py
index 36e8ce448731..2dbb8c710a1c 100644
--- a/contrib/svrg_optimization_python/test_svrg_train.py
+++ b/example/svrg_module/api_usage_example/example_api_train.py
@@ -18,14 +18,19 @@
import mxnet as mx
import numpy as np
-from src.svrg_module import SVRGModule
+from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule
-def test_svrg_intermediate_level_api(num_epoch):
- """Test intermediate level svrgmodule API where the training process
+def test_svrg_intermediate_level_api(args):
+ """Demonstrates intermediate level SVRGModule API where the training process
need to be explicitly defined. KVstore is not explicitly created.
"""
- di, mod = create_network()
+ num_epoch = args.epochs
+ batch_size = args.batch_size
+ update_freq = args.update_freq
+
+ di, mod = create_network(batch_size, update_freq)
+
mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
mod.init_params(initializer=mx.init.Uniform(0.01), allow_missing=False, force_init=False, allow_extra=False)
kv = mx.kv.create("local")
@@ -33,35 +38,43 @@ def test_svrg_intermediate_level_api(num_epoch):
metrics = mx.metric.create("mse")
for e in range(num_epoch):
metrics.reset()
- if e % (mod.update_freq) == 0:
+ if e % mod.update_freq == 0:
+ print("Update")
mod.update_full_grads(di)
di.reset()
for batch in di:
mod.forward_backward(data_batch=batch)
mod.update()
mod.update_metric(metrics, batch.label)
- print('Epoch[%d] Time cost=%.3f', e, metrics.get())
+ mod.logger.info('Epoch[%d] Train cost=%f', e, metrics.get()[1])
-def test_svrg_high_level_api(num_epoch):
- """Test high level svrgmodule API. KVStore is explicitly created.
+def test_svrg_high_level_api(args):
+ """Demonstrates suggested usage of high level SVRGModule API. KVStore is explicitly created.
"""
- di, mod = create_network()
+ num_epoch = args.epochs
+ batch_size = args.batch_size
+ update_freq = args.update_freq
+
+ di, mod = create_network(batch_size, update_freq)
mod.fit(di, eval_metric='mse', optimizer='sgd', optimizer_params=(('learning_rate', 0.025),), num_epoch=num_epoch,
kvstore='local')
-def create_network():
+def create_network(batch_size, update_freq):
"""Create a linear regression network for performing SVRG optimization.
:return: an instance of mx.io.NDArrayIter
:return: an instance of mx.mod.svrgmodule for performing SVRG optimization
"""
- mx.random.seed(42)
+ import logging
+ head = '%(asctime)-15s %(message)s'
+ logging.basicConfig(level=logging.INFO, format=head)
+
train_data = np.random.randint(1, 5, [1000, 2])
weights = np.array([1.0, 2.0])
train_label = train_data.dot(weights)
- di = mx.io.NDArrayIter(train_data, train_label, batch_size=32, shuffle=True, label_name='lin_reg_label')
+ di = mx.io.NDArrayIter(train_data, train_label, batch_size=batch_size, shuffle=True, label_name='lin_reg_label')
X = mx.sym.Variable('data')
Y = mx.symbol.Variable('lin_reg_label')
fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1)
@@ -70,15 +83,23 @@ def create_network():
mod = SVRGModule(
symbol=lro,
data_names=['data'],
- label_names=['lin_reg_label'], update_freq=2
+ label_names=['lin_reg_label'], update_freq=update_freq, logger=logging
)
return di, mod
# run as a script
if __name__ == "__main__":
- num_epoch = 100
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-e', dest='epochs', default=100, type=int)
+ parser.add_argument('-bs', dest='batch_size', default=32, type=int)
+ parser.add_argument('-f', dest="update_freq", default=2, type=int)
+ args = parser.parse_args()
+
+
print("========================== Intermediate Level API ==========================")
- test_svrg_intermediate_level_api(num_epoch)
+ test_svrg_intermediate_level_api(args)
print("========================== High Level API ==========================")
- test_svrg_high_level_api(num_epoch)
+    test_svrg_high_level_api(args)
diff --git a/contrib/svrg_optimization_python/test_svrg_inference.py b/example/svrg_module/api_usage_example/example_inference.py
similarity index 61%
rename from contrib/svrg_optimization_python/test_svrg_inference.py
rename to example/svrg_module/api_usage_example/example_inference.py
index 0250cdec5899..312f9796074d 100644
--- a/contrib/svrg_optimization_python/test_svrg_inference.py
+++ b/example/svrg_module/api_usage_example/example_inference.py
@@ -18,23 +18,34 @@
import mxnet as mx
import numpy as np
-from src.svrg_module import SVRGModule
+import logging
+from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule
-def test_svrg_inference(num_epoch):
- train_iter, val_iter, mod = create_network()
- mod.fit(train_iter, eval_data=val_iter, eval_metric='mse', optimizer='sgd', optimizer_params=(('learning_rate', 0.025),),
- num_epoch=num_epoch)
+def test_svrg_inference(args):
+ epoch = args.epochs
+ batch_size = args.batch_size
+ update_freq = args.update_freq
-def test_score(num_epoch):
- train_iter, val_iter, mod = create_network()
+ train_iter, val_iter, mod = create_network(batch_size, update_freq)
+ mod.fit(train_iter, eval_data=val_iter, eval_metric='mse', optimizer='sgd',
+ optimizer_params=(('learning_rate', 0.025),),
+ num_epoch=epoch)
+
+
+def get_validation_score(args):
+ epoch = args.epochs
+ batch_size = args.batch_size
+ update_freq = args.update_freq
+
+ train_iter, val_iter, mod = create_network(batch_size, update_freq)
mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
mod.init_params(initializer=mx.init.Uniform(0.01), allow_missing=False, force_init=False, allow_extra=False)
- mod.init_optimizer(kvstore='local', optimizer='nag', optimizer_params=(('momentum', 0.9),))
+ mod.init_optimizer(kvstore='local', optimizer='sgd', optimizer_params=(('learning_rate', 0.025),))
metrics = mx.metric.create("mse")
- for e in range(num_epoch):
+ for e in range(epoch):
metrics.reset()
- if e % (mod.update_freq + 1) == 0:
+ if e % mod.update_freq == 0:
mod.update_full_grads(train_iter)
train_iter.reset()
for batch in train_iter:
@@ -43,25 +54,29 @@ def test_score(num_epoch):
mod.update_metric(metrics, batch.label)
y = mod.predict(val_iter)
+
+ # test-train data split, 20% test data out of 1000 data samples
assert y.shape == (200, 1)
score = mod.score(val_iter, ['mse'])
- print("Training Loss is %f", score[0][1])
+    print("MSE on validation set is {}".format(score[0][1]))
-def create_network():
+def create_network(batch_size, update_freq):
"""Create a linear regression network for performing SVRG optimization.
:return: an instance of mx.io.NDArrayIter
:return: an instance of mx.mod.svrgmodule for performing SVRG optimization
"""
- mx.random.seed(42)
+ head = '%(asctime)-15s %(message)s'
+ logging.basicConfig(level=logging.INFO, format=head)
data = np.random.randint(1, 5, [1000, 2])
+
+    # train-test data split
n_train = int(data.shape[0] * 0.8)
weights = np.array([1.0, 2.0])
label = data.dot(weights)
-
- di = mx.io.NDArrayIter(data[:n_train, :], label[:n_train], batch_size=32, shuffle=True, label_name='lin_reg_label')
- val_iter = mx.io.NDArrayIter(data[n_train:, :], label[n_train:], batch_size=32)
+ di = mx.io.NDArrayIter(data[:n_train, :], label[:n_train], batch_size=batch_size, shuffle=True, label_name='lin_reg_label')
+ val_iter = mx.io.NDArrayIter(data[n_train:, :], label[n_train:], batch_size=batch_size)
X = mx.sym.Variable('data')
Y = mx.symbol.Variable('lin_reg_label')
@@ -71,16 +86,21 @@ def create_network():
mod = SVRGModule(
symbol=lro,
data_names=['data'],
- label_names=['lin_reg_label'], update_freq=2
- )
+ label_names=['lin_reg_label'], update_freq=update_freq, logger=logging)
return di, val_iter, mod
# run as a script
if __name__ == "__main__":
- num_epoch = 100
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-e', dest='epochs', default=100, type=int)
+ parser.add_argument('-bs', dest='batch_size', default=32, type=int)
+ parser.add_argument('-f', dest="update_freq", default=2, type=int)
+ args = parser.parse_args()
+
print("========================== SVRG Module Inference ==========================")
- test_svrg_inference(num_epoch)
+ test_svrg_inference(args)
print("========================SVRG Module Score ============================")
- test_score(num_epoch)
+ get_validation_score(args)
diff --git a/contrib/svrg_optimization_python/benchmarks/benchmark1.png b/example/svrg_module/benchmarks/benchmark1.png
similarity index 100%
rename from contrib/svrg_optimization_python/benchmarks/benchmark1.png
rename to example/svrg_module/benchmarks/benchmark1.png
diff --git a/contrib/svrg_optimization_python/benchmarks/benchmark2.png b/example/svrg_module/benchmarks/benchmark2.png
similarity index 100%
rename from contrib/svrg_optimization_python/benchmarks/benchmark2.png
rename to example/svrg_module/benchmarks/benchmark2.png
diff --git a/example/svrg_module/linear_regression/common.py b/example/svrg_module/linear_regression/common.py
new file mode 100644
index 000000000000..0b3d19b409c9
--- /dev/null
+++ b/example/svrg_module/linear_regression/common.py
@@ -0,0 +1,118 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import mxnet as mx
+import logging
+from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule
+
+
+def create_lin_reg_network(train_features, train_labels, feature_dim, batch_size, update_freq, ctx, logger):
+ # fit a linear regression model with mxnet SVRGModule
+ print("Fitting linear regression with mxnet")
+ train_iter = mx.io.NDArrayIter(train_features, train_labels, batch_size=batch_size, shuffle=True,
+ data_name='data', label_name='label')
+ data = mx.sym.Variable("data")
+ label = mx.sym.Variable("label")
+ weight = mx.sym.Variable("fc_weight", shape=(1, feature_dim))
+ net = mx.sym.dot(data, weight.transpose())
+ bias = mx.sym.Variable("fc_bias", shape=(1,), wd_mult=0.0, lr_mult=10.0)
+ net = mx.sym.broadcast_plus(net, bias)
+ net = mx.sym.LinearRegressionOutput(data=net, label=label)
+
+ mod = SVRGModule(symbol=net, context=ctx, data_names=['data'], label_names=['label'], logger=logger,
+ update_freq=update_freq)
+ return train_iter, mod
+
+
+def create_metrics(metrics):
+ metric = mx.metric.create(metrics)
+ return metric
+
+
+def create_logger():
+ logger = logging.getLogger('sgd_svrg')
+ logger.setLevel(logging.INFO)
+ formatter = logging.Formatter('%(asctime)s - %(message)s')
+ fh = logging.FileHandler('experiments.log')
+ fh.setFormatter(formatter)
+ logger.addHandler(fh)
+ return logger
+
+
+################################################################################
+# The functions below are for benchmarking purposes, to calculate the expectation and variance of
+# gradients per epoch for each parameter. These calculations are helpful when benchmarking
+# SVRG optimization against other optimization techniques, such as SGD.
+# Currently they only calculate the expectation and variance for a single context, but
+# can be extended to multiple contexts in later iterations.
+################################################################################
+
+def accumulate_grad(grad_dict, mod):
+ param_names = mod._exec_group.param_names
+
+ for index, name in enumerate(param_names):
+ if name not in grad_dict:
+ grad_dict[name] = mod._exec_group.grad_arrays[index][0].copy()
+ else:
+ grad_dict[name] = mx.ndarray.concat(grad_dict[name], mod._exec_group.grad_arrays[index][0], dim=0)
+
+
+def calc_expectation(grad_dict, num_batches):
+ """Calculates the expectation of the gradients per epoch for each parameter w.r.t number of batches
+
+ Parameters
+ ----------
+ grad_dict: dict
+ dictionary that maps parameter name to gradients in the mod executor group
+ num_batches: int
+ number of batches
+
+ Returns
+ ----------
+ grad_dict: dict
+ dictionary with new keys mapping to gradients expectations
+
+ """
+ for key in grad_dict.keys():
+ grad_dict[str.format(key+"_expectation")] = mx.ndarray.sum(grad_dict[key], axis=0) / num_batches
+
+ return grad_dict
+
+
+def calc_variance(grad_dict, num_batches, param_names):
+ """Calculates the variance of the gradients per epoch for each parameter w.r.t number of batches
+
+ Parameters
+ ----------
+ grad_dict: dict
+ dictionary that maps parameter name to gradients in the mod executor group
+ num_batches: int
+ number of batches
+    param_names: list of str
+        parameter names in the module
+
+ Returns
+ ----------
+ grad_dict: dict
+ dictionary with new keys mapping to gradients variance
+
+ """
+ for i in range(len(param_names)):
+ diff_sqr = mx.ndarray.square(mx.nd.subtract(grad_dict[param_names[i]],
+ grad_dict[str.format(param_names[i]+"_expectation")]))
+ grad_dict[str.format(param_names[i] + "_variance")] = mx.ndarray.sum(diff_sqr, axis=0) / num_batches
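+
+
+# A minimal usage sketch of the benchmark helpers above (hypothetical loop; the actual benchmark
+# scripts may differ):
+#
+#     grad_dict, num_batches = {}, 0
+#     for batch in train_iter:
+#         mod.forward_backward(data_batch=batch)
+#         accumulate_grad(grad_dict, mod)
+#         num_batches += 1
+#     calc_expectation(grad_dict, num_batches)
+#     calc_variance(grad_dict, num_batches, mod._exec_group.param_names)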
diff --git a/example/svrg_module/linear_regression/data_reader.py b/example/svrg_module/linear_regression/data_reader.py
new file mode 100644
index 000000000000..d1578fc153bf
--- /dev/null
+++ b/example/svrg_module/linear_regression/data_reader.py
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import numpy as np
+
+
+def read_year_prediction_data(fileName):
+ # Download data file
+ # from subprocess import call
+ # call(['wget', 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression/YearPredictionMSD.bz2'])
+ # call(['bzip2', '-d', 'YearPredictionMSD.bz2'])
+
+ from sklearn.datasets import load_svmlight_file
+
+ # YearPredictionMSD dataset: https://archive.ics.uci.edu/ml/datasets/yearpredictionmsd
+ feature_dim = 90
+ print("Reading data from disk...")
+ train_features, train_labels = load_svmlight_file(fileName, n_features=feature_dim, dtype=np.float32)
+ train_features = train_features.todense()
+
+ # normalize the data: subtract means and divide by standard deviations
+ label_mean = train_labels.mean()
+ label_std = np.sqrt(np.square(train_labels - label_mean).mean())
+ feature_means = train_features.mean(axis=0)
+ feature_stds = np.sqrt(np.square(train_features - feature_means).mean(axis=0))
+
+ train_features = (train_features - feature_means) / feature_stds
+ train_labels = (train_labels - label_mean) / label_std
+
+ return feature_dim, train_features, train_labels
diff --git a/example/svrg_module/linear_regression/train.py b/example/svrg_module/linear_regression/train.py
new file mode 100644
index 000000000000..b3d942973f19
--- /dev/null
+++ b/example/svrg_module/linear_regression/train.py
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import argparse
+import mxnet as mx
+from common import create_lin_reg_network, create_logger
+from data_reader import read_year_prediction_data
+
+parser = argparse.ArgumentParser()
+parser.add_argument('-e', dest='epochs', help='number of epochs for training phase', type=int, default=100)
+parser.add_argument('-f', dest="updateFreq", help="update frequency for SVRGModule", type=int, default=2)
+parser.add_argument('-b', dest="batch_size", help="define the batch size for training", type=int,
+ default=100, required=False)
+parser.add_argument('-m', dest='metrics', help="create eval metric", type=str, default='mse')
+parser.add_argument('--gpus', type=str, help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu')
+parser.add_argument('--kv-store', type=str, default='local', help='key-value store type')
+
+args = parser.parse_args()
+# devices for training
+ctx = mx.cpu() if args.gpus is None or args.gpus == "" else [mx.gpu(int(i)) for i in args.gpus.split(',')]
+
+logger = create_logger()
+kv = mx.kvstore.create(args.kv_store)
+
+feature_dim, train_features, train_labels = read_year_prediction_data('YearPredictionMSD')
+train_iter, mod = create_lin_reg_network(train_features, train_labels, feature_dim, args.batch_size, args.updateFreq,
+ ctx, logger)
+
+mod.fit(train_iter, eval_metric='mse', optimizer='sgd',
+ optimizer_params=(('learning_rate', 0.025), ), num_epoch=args.epochs, kvstore=kv)
diff --git a/contrib/svrg_optimization_python/__init__.py b/python/mxnet/contrib/svrg_optimization/__init__.py
similarity index 86%
rename from contrib/svrg_optimization_python/__init__.py
rename to python/mxnet/contrib/svrg_optimization/__init__.py
index 4acf63ef7a13..6e70009983c9 100644
--- a/contrib/svrg_optimization_python/__init__.py
+++ b/python/mxnet/contrib/svrg_optimization/__init__.py
@@ -14,7 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+"""SVRGModule, SVRGOptimization import.
+"""
-from __future__ import absolute_import
-from .src.svrg_optimizer import SVRGOptimizer
-from .src.svrg_module import SVRGModule
+
+from . import svrg_module
+from . import svrg_optimizer
diff --git a/contrib/svrg_optimization_python/src/svrg_module.py b/python/mxnet/contrib/svrg_optimization/svrg_module.py
similarity index 87%
rename from contrib/svrg_optimization_python/src/svrg_module.py
rename to python/mxnet/contrib/svrg_optimization/svrg_module.py
index e587da00eb25..f78c9f363964 100644
--- a/contrib/svrg_optimization_python/src/svrg_module.py
+++ b/python/mxnet/contrib/svrg_optimization/svrg_module.py
@@ -18,11 +18,11 @@
SVRG optimization logic.
"""
-import mxnet as mx
import time
import logging
-from svrg_optimizer import SVRGOptimizer
+import mxnet as mx
from mxnet.module import Module
+from .svrg_optimizer import SVRGOptimizer
class SVRGModule(Module):
@@ -62,7 +62,7 @@ class SVRGModule(Module):
Examples
--------
>>> # An example of declaring and using SVRGModule.
- >>> mod = mod = SVRGModule(symbol=lro, data_names=['data'], label_names=['lin_reg_label'], update_freq=2)
+ >>> mod = SVRGModule(symbol=lro, data_names=['data'], label_names=['lin_reg_label'], update_freq=2)
>>> mod.fit(di, eval_metric='mse', optimizer='sgd', optimizer_params=(('learning_rate', 0.025),),
>>> num_epoch=num_epoch, kvstore='local')
"""
@@ -78,30 +78,35 @@ def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',),
# Type check update_frequency
if isinstance(update_freq, int):
+ if update_freq <= 0:
+ raise ValueError("update_freq in SVRGModule must be a positive integer to represent the frequency for "
+ "calculating full gradients")
self.update_freq = update_freq
else:
- raise TypeError("update_freq must be an integer")
+ raise TypeError("update_freq in SVRGModule must be a positive integer to represent the frequency for "
+ "calculating full gradients")
self._mod_aux = mx.mod.Module(symbol, data_names, label_names, logger, context, work_load_list,
fixed_param_names, state_names, group2ctxs, compression_params)
- self._param_dict = [{} for ctx in self._context]
+ self._param_dict = None
+ self._ctx_len = len(self._context)
def _reset_bind(self):
"""Internal function to reset binded state."""
super(SVRGModule, self)._reset_bind()
self._mod_aux._reset_bind()
-
def reshape(self, data_shapes, label_shapes=None):
super(SVRGModule, self).reshape(data_shapes, label_shapes=label_shapes)
self._mod_aux.reshape(data_shapes, label_shapes=label_shapes)
def init_optimizer(self, kvstore='local', optimizer='sgd',
optimizer_params=(('learning_rate', 0.01),), force_init=False):
- """Installs and initializes SVRG optimizers. The SVRGOptimizer is a wrapper for normal SGD optimizer
- and special AssignmentOptimizer to accumulate gradients. If KVStore exists, additional keys will be
- pushed to the kvstore for accumulating full grads.
+ """Installs and initializes SVRGOptimizer. The SVRGOptimizer is a wrapper class for a regular optimizer that is
+ passed in and a special AssignmentOptimizer to accumulate the full gradients. If KVStore is 'local' or None,
+ the full gradients will be accumulated locally without pushing to the KVStore. Otherwise, additional keys will
+ be pushed to accumulate the full gradients in the KVStore.
Parameters
----------
@@ -117,15 +122,15 @@ def init_optimizer(self, kvstore='local', optimizer='sgd',
optimizer in the case an optimizer is already installed.
"""
# Init dict for storing average of full gradients for each device
- for i in range(len(self._context)):
- self._param_dict[i] = {key: mx.nd.zeros(shape=value.shape, ctx=self._context[i])
- for key, value in self.get_params()[0].items()}
+
+ self._param_dict = [{key: mx.nd.zeros(shape=value.shape, ctx=self._context[i])
+ for key, value in self.get_params()[0].items()} for i in range(self._ctx_len)]
svrg_optimizer = self._create_optimizer(SVRGOptimizer.__name__, default_opt=optimizer,
kvstore=kvstore, optimizer_params=optimizer_params)
super(SVRGModule, self).init_optimizer(kvstore=kvstore, optimizer=svrg_optimizer,
- optimizer_params=optimizer_params, force_init=force_init)
+ optimizer_params=optimizer_params, force_init=force_init)
# Init additional keys for accumulating full grads in KVStore
if self._kvstore:
@@ -153,7 +158,7 @@ def _create_optimizer(self, optimizer, default_opt, kvstore, optimizer_params):
# code partially copied from mxnet module.init_optimizer() to accomodate svrg_optimizer
batch_size = self._exec_group.batch_size
- (kv_store, update_on_kvstore) = mx.model._create_kvstore(kvstore, len(self._context), self._arg_params)
+ (kv_store, update_on_kvstore) = mx.model._create_kvstore(kvstore, self._ctx_len, self._arg_params)
if kv_store and 'dist' in kv_store.type and '_sync' in kv_store.type:
batch_size *= kv_store.num_workers
rescale_grad = 1.0 / batch_size
@@ -162,8 +167,8 @@ def _create_optimizer(self, optimizer, default_opt, kvstore, optimizer_params):
if update_on_kvstore:
idx2name.update(enumerate(self._exec_group.param_names))
else:
- for k in range(len(self._context)):
- idx2name.update({i * len(self._context) + k: n
+ for k in range(self._ctx_len):
+ idx2name.update({i * self._ctx_len + k: n
for i, n in enumerate(self._exec_group.param_names)})
# update idx2name to include new keys
@@ -182,8 +187,7 @@ def _create_optimizer(self, optimizer, default_opt, kvstore, optimizer_params):
return optimizer
def bind(self, data_shapes, label_shapes=None, for_training=True,
- inputs_need_grad=False, force_rebind=False, shared_module=None,
- grad_req='write'):
+ inputs_need_grad=False, force_rebind=False, shared_module=None, grad_req='write'):
"""Binds the symbols to construct executors for both two modules. This is necessary before one
can perform computation with the SVRGModule.
@@ -259,7 +263,7 @@ def backward(self, out_grads=None):
super(SVRGModule, self).backward(out_grads)
if self._mod_aux.binded:
- self._mod_aux.backward()
+ self._mod_aux.backward(out_grads)
def update(self):
"""Updates parameters according to the installed optimizer and the gradients computed
@@ -301,49 +305,45 @@ def update_full_grads(self, train_data):
self._mod_aux.backward()
nbatch += 1
- for i in range(len(self._context)):
- for j in range(len(param_names)):
- grads = self._mod_aux._exec_group.grad_arrays[j][i]
- self._param_dict[i][param_names[j]] = mx.nd.broadcast_add(self._param_dict[i][param_names[j]],
- grads, axis=0)
+ for ctx in range(self._ctx_len):
+ for index, name in enumerate(param_names):
+ grads = self._mod_aux._exec_group.grad_arrays[index][ctx]
+ self._param_dict[ctx][name] = mx.nd.broadcast_add(self._param_dict[ctx][name], grads, axis=0)
padding = batch.pad
# Average full gradients over number of batches, accumulate in the kvstore if kvstore is set
- for i in range(len(self._context)):
- for key in self._param_dict[i].keys():
- self._param_dict[i][key] /= (nbatch - padding / train_data.batch_size)
+ for i in range(self._ctx_len):
+ for name in param_names:
+ self._param_dict[i][name] /= (nbatch - padding / train_data.batch_size)
if self._kvstore:
# Push a list of gradients from each device in the KVStore
- for key in self._param_dict[0].keys():
- grad_list = []
- for i in range(len(self._param_dict)):
- grad_list.append(self._param_dict[i][key])
-
- self._accumulate_kvstore(key, grad_list)
+ for name in param_names:
+ grad_list = list(self._param_dict[i][name] for i in range(self._ctx_len))
+ self._accumulate_kvstore(name, grad_list)
def _accumulate_kvstore(self, key, value):
- """Accumulate gradients over all data in the KVstore. In distributed setting, each worker sees a portion of
+ """Accumulate gradients over all data in the KVStore. In distributed setting, each worker sees a portion of
data. The full gradients will be aggregated from each worker in the KVStore.
Parameters
----------
key: int or str
- Key in the kvstore.
+ Key in the KVStore.
value: NDArray, RowSparseNDArray
Average of the full gradients.
"""
# Accumulate full gradients for current epochs
self._kvstore.push(key + "_full", value)
-
self._kvstore._barrier()
self._kvstore.pull(key + "_full", value)
+
self._allocate_gradients(key, value)
def _allocate_gradients(self, key, value):
- """Allocate averge of full gradients accumualted in the KVStore to each device.
+ """Allocate average of full gradients accumulated in the KVStore to each device.
Parameters
----------
@@ -354,11 +354,11 @@ def _allocate_gradients(self, key, value):
A list of average of the full gradients in the KVStore.
"""
- num_device = len(self._context)
- for i in range(len(self._param_dict)):
- self._param_dict[i][key] = value[i] / num_device
+ for i in range(self._ctx_len):
+ self._param_dict[i][key] = value[i] / self._ctx_len
- def _svrg_grads_update_rule(self, g_curr_batch_curr_weight, g_curr_batch_special_weight, g_special_weight_all_batch):
+ def _svrg_grads_update_rule(self, g_curr_batch_curr_weight, g_curr_batch_special_weight,
+ g_special_weight_all_batch):
"""Calculates the gradient based on the SVRG update rule.
Parameters
----------
@@ -374,22 +374,24 @@ def _svrg_grads_update_rule(self, g_curr_batch_curr_weight, g_curr_batch_special
Gradients calculated using SVRG update rule:
grads = g_curr_batch_curr_weight - g_curr_batch_special_weight + g_special_weight_all_batch
"""
- for i in range(len(g_curr_batch_curr_weight)):
- g_curr_batch_curr_weight[i] -= g_curr_batch_special_weight[i]
- g_curr_batch_curr_weight[i] += g_special_weight_all_batch[i]
+
+ for index, grad in enumerate(g_curr_batch_curr_weight):
+ grad -= g_curr_batch_special_weight[index]
+ grad += g_special_weight_all_batch[index]
return g_curr_batch_curr_weight
def update_svrg_gradients(self):
"""Calculates gradients based on the SVRG update rule.
"""
param_names = self._exec_group.param_names
- for i in range(len(self._context)):
- for j in range(len(param_names)):
- g_curr_batch_reg = self._exec_group.grad_arrays[j][i]
- g_curr_batch_special = self._mod_aux._exec_group.grad_arrays[j][i]
- g_special_weight_all_batch = self._param_dict[i][param_names[j]]
- g_svrg = self._svrg_grads_update_rule(g_curr_batch_reg, g_curr_batch_special, g_special_weight_all_batch)
- self._exec_group.grad_arrays[j][i] = g_svrg
+ for ctx in range(self._ctx_len):
+ for index, name in enumerate(param_names):
+ g_curr_batch_reg = self._exec_group.grad_arrays[index][ctx]
+ g_curr_batch_special = self._mod_aux._exec_group.grad_arrays[index][ctx]
+ g_special_weight_all_batch = self._param_dict[ctx][name]
+ g_svrg = self._svrg_grads_update_rule(g_curr_batch_reg, g_curr_batch_special,
+ g_special_weight_all_batch)
+ self._exec_group.grad_arrays[index][ctx] = g_svrg
def fit(self, train_data, eval_data=None, eval_metric='acc',
epoch_end_callback=None, batch_end_callback=None, kvstore='local',
@@ -461,6 +463,7 @@ def fit(self, train_data, eval_data=None, eval_metric='acc',
str -> NDArray. The resulting dict is used for pulling row_sparse
parameters from the kvstore, where the str key is the name of the param,
and the value is the row id of the param to pull.
+        validation_metric: str or EvalMetric
+            The performance measure used to evaluate on the validation set; defaults to ``eval_metric``.
"""
assert num_epoch is not None, 'please specify number of epochs'
@@ -471,8 +474,7 @@ def fit(self, train_data, eval_data=None, eval_metric='acc',
self.install_monitor(monitor)
self.init_params(initializer=initializer, arg_params=arg_params, aux_params=aux_params,
allow_missing=allow_missing, force_init=force_init)
- self.init_optimizer(kvstore=kvstore, optimizer=optimizer,
- optimizer_params=optimizer_params)
+ self.init_optimizer(kvstore=kvstore, optimizer=optimizer, optimizer_params=optimizer_params)
if validation_metric is None:
validation_metric = eval_metric
@@ -503,9 +505,7 @@ def fit(self, train_data, eval_data=None, eval_metric='acc',
self.update()
if isinstance(data_batch, list):
- self.update_metric(eval_metric,
- [db.label for db in data_batch],
- pre_sliced=True)
+ self.update_metric(eval_metric, [db.label for db in data_batch], pre_sliced=True)
else:
self.update_metric(eval_metric, data_batch.label)
@@ -533,7 +533,6 @@ def fit(self, train_data, eval_data=None, eval_metric='acc',
self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val)
toc = time.time()
self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))
- print('Epoch[%d] Time cost=%.3f', epoch, eval_metric.get())
# sync aux params across devices
arg_params, aux_params = self.get_params()
@@ -551,7 +550,6 @@ def fit(self, train_data, eval_data=None, eval_metric='acc',
batch_end_callback=eval_batch_end_callback, epoch=epoch)
for name, val in res:
self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name, val)
- print('Epoch[%d] Validation-%s=%f', epoch, name, val)
def prepare(self, data_batch, sparse_row_id_fn=None):
"""Prepares two modules for processing a data batch.
@@ -577,5 +575,5 @@ def prepare(self, data_batch, sparse_row_id_fn=None):
parameters from the kvstore, where the str key is the name of the param,
and the value is the row id of the param to pull.
"""
- super(SVRGModule, self).prepare(data_batch, sparse_row_id_fn=sparse_row_id_fn)
- self._mod_aux.prepare(data_batch=sparse_row_id_fn)
+ super(SVRGModule, self).prepare(data_batch, sparse_row_id_fn)
+ self._mod_aux.prepare(data_batch, sparse_row_id_fn)
diff --git a/contrib/svrg_optimization_python/src/svrg_optimizer.py b/python/mxnet/contrib/svrg_optimization/svrg_optimizer.py
similarity index 63%
rename from contrib/svrg_optimization_python/src/svrg_optimizer.py
rename to python/mxnet/contrib/svrg_optimization/svrg_optimizer.py
index bf9cca975cce..dd4b363c38d7 100644
--- a/contrib/svrg_optimization_python/src/svrg_optimizer.py
+++ b/python/mxnet/contrib/svrg_optimization/svrg_optimizer.py
@@ -23,46 +23,82 @@
@mx.optimizer.register
class AssignmentOptimizer(mx.optimizer.Optimizer):
+ """AssignmentOptimizer assigns gradients to weights for SVRGModule's full gradients
+ accumulation in the KVStore
+ """
def update(self, index, weight, grad, state):
+ """Assign the gradients to weight for accumulating full gradients in the KVStore across all devices and workers.
+
+ Parameters
+ ----------
+ index : int
+ The unique index of the parameter into the individual learning
+ rates and weight decays. Learning rates and weight decay
+ may be set via `set_lr_mult()` and `set_wd_mult()`, respectively.
+ weight : NDArray
+ The parameter to be updated.
+ grad : NDArray
+ The gradient of the objective with respect to this parameter.
+ state: any obj
+            AssignmentOptimizer does not need to be associated with any state.
+ """
+
weight[:] = grad
+
@mx.optimizer.register
class SVRGOptimizer(mx.optimizer.Optimizer):
- """SVRGOptimizer is a wrapper class for two optimizers: one for accumulating full gradients and the other
- one is the passed-in optimizer.
+ """SVRGOptimizer is a wrapper class for two optimizers: AssignmentOptimizer for accumulating full gradients in the
+ KVStore and a default optimizer that is passed in as a parameter in `mod.init_optimizer()`
+
+ This optimizer accepts the following parameters in addition to those accepted by :class:`.Optimizer`.
Parameters
----------
- default_optimizer: optimizer passed-in when invoke on mx.mod.init_optimizer
+ default_optimizer: str or Optimizer
+        Optimizer passed in when invoking mod.init_optimizer() in SVRGModule
"""
- def __init__(self, default_optimizer, **kwargs):
+ def __init__(self, default_optimizer, **kwargs):
# Reconstruct kwargs to identify additional params for default optimizer
- extra_param, default_param = self._check_params(**kwargs)
- super(SVRGOptimizer, self).__init__(**default_param)
+ base_param = self._check_params(**kwargs)
+ super(SVRGOptimizer, self).__init__(**base_param)
if isinstance(default_optimizer, str):
self.default_opt = mx.optimizer.create(default_optimizer, **kwargs)
else:
self.default_opt = default_optimizer
self.aux_opt = mx.optimizer.create(AssignmentOptimizer.__name__)
+ @staticmethod
+ def _check_params(**kwargs):
+ """ Reassemble kwargs to identify additional optimizer params for default optimizers. base_params contains
+ all the param names in base class Optimizer.
+
+ Parameters
+ ----------
+ kwargs: dict
+ Parameters for the default optimizer
+
+ Returns
+ ----------
+ default_params: dict
+ Optimizer parameters that are defined in base class Optimizer
+ """
- def _check_params(self, **kwargs):
optimizer_param = dict(kwargs)
base_params = ['rescale_grad', 'param_idx2name', 'wd', 'clip_gradient', 'learning_rate', 'lr_scheduler', 'sym',
'begin_num_update', 'multi_precision', 'param_dict']
- extra_param = {}
+
default_params = {}
- for key in optimizer_param.keys():
+ for key, _ in optimizer_param.items():
if key in base_params:
default_params[key] = optimizer_param[key]
- else:
- extra_param[key] = optimizer_param[key]
- return extra_param, default_params
+
+ return default_params
def update(self, index, weight, grad, state):
"""Updates the given parameter using the corresponding gradient and state. If key contains 'full', update with
- lr = -1 otherwise will use default optimizer.
+ AssignmentOptimizer otherwise will use default optimizer.
Parameters
----------
@@ -80,7 +116,7 @@ def update(self, index, weight, grad, state):
name = self._check_index(index)
- if "full".lower() in name:
+ if "full" in name:
self.aux_opt.update(index, weight, grad, state)
else:
# use the default optimizer
@@ -106,26 +142,28 @@ def create_state(self, index, weight):
"""
name = self._check_index(index)
- if "full".lower() in name:
- return
+ if "full" in name:
+ return self.aux_opt.create_state(index, weight)
else:
- # use the default optimizer
+            # use the default optimizer
return self.default_opt.create_state(index, weight)
def _check_index(self, index):
"""Check index in idx2name to get corresponding param_name
Parameters
----------
- index : int
+ index : int or str
An unique index to identify the weight.
Returns
-------
name : str
- Name of the Module parameter
+ Name of the Module parameter
"""
if index in self.idx2name.values():
+ # index is a str
name = index
else:
+ # index is an int
name = self.idx2name[index]
return name
diff --git a/tests/python/unittest/legacy_ndarray.v0 b/tests/python/unittest/legacy_ndarray.v0
deleted file mode 100644
index f4306d837202..000000000000
Binary files a/tests/python/unittest/legacy_ndarray.v0 and /dev/null differ
diff --git a/tests/python/unittest/save_000800.json b/tests/python/unittest/save_000800.json
deleted file mode 100644
index 7b385e2983d8..000000000000
--- a/tests/python/unittest/save_000800.json
+++ /dev/null
@@ -1,188 +0,0 @@
-{
- "nodes": [
- {
- "op": "null",
- "param": {},
- "name": "data",
- "inputs": [],
- "backward_source_id": -1,
- "attr": {
- "ctx_group": "stage1",
- "lr_mult": "0.2"
- }
- },
- {
- "op": "null",
- "param": {},
- "name": "fc1_weight",
- "inputs": [],
- "backward_source_id": -1,
- "attr": {
- "ctx_group": "stage1",
- "wd_mult": "0.3",
- "weight_lr_mult": "1.2"
- }
- },
- {
- "op": "null",
- "param": {},
- "name": "fc1_bias",
- "inputs": [],
- "backward_source_id": -1,
- "attr": {
- "ctx_group": "stage1",
- "wd_mult": "0.3",
- "weight_lr_mult": "1.2"
- }
- },
- {
- "op": "FullyConnected",
- "param": {
- "no_bias": "False",
- "num_hidden": "128"
- },
- "name": "fc1",
- "inputs": [[0, 0], [1, 0], [2, 0]],
- "backward_source_id": -1,
- "attr": {
- "ctx_group": "stage1",
- "wd_mult": "0.3",
- "weight_lr_mult": "1.2"
- }
- },
- {
- "op": "Activation",
- "param": {"act_type": "relu"},
- "name": "relu1",
- "inputs": [[3, 0]],
- "backward_source_id": -1,
- "attr": {"ctx_group": "stage1"}
- },
- {
- "op": "null",
- "param": {},
- "name": "fc2_weight",
- "inputs": [],
- "backward_source_id": -1,
- "attr": {
- "ctx_group": "stage2",
- "lr_mult": "0.01"
- }
- },
- {
- "op": "null",
- "param": {},
- "name": "fc2_bias",
- "inputs": [],
- "backward_source_id": -1,
- "attr": {
- "ctx_group": "stage2",
- "lr_mult": "0.01"
- }
- },
- {
- "op": "FullyConnected",
- "param": {
- "no_bias": "False",
- "num_hidden": "64"
- },
- "name": "fc2",
- "inputs": [[4, 0], [5, 0], [6, 0]],
- "backward_source_id": -1,
- "attr": {
- "ctx_group": "stage2",
- "lr_mult": "0.01"
- }
- },
- {
- "op": "Activation",
- "param": {"act_type": "relu"},
- "name": "relu2",
- "inputs": [[7, 0]],
- "backward_source_id": -1,
- "attr": {"ctx_group": "stage2"}
- },
- {
- "op": "null",
- "param": {},
- "name": "fc3_weight",
- "inputs": [],
- "backward_source_id": -1,
- "attr": {"ctx_group": "stage2"}
- },
- {
- "op": "null",
- "param": {},
- "name": "fc3_bias",
- "inputs": [],
- "backward_source_id": -1,
- "attr": {"ctx_group": "stage2"}
- },
- {
- "op": "FullyConnected",
- "param": {
- "no_bias": "False",
- "num_hidden": "10"
- },
- "name": "fc3",
- "inputs": [[8, 0], [9, 0], [10, 0]],
- "backward_source_id": -1,
- "attr": {"ctx_group": "stage2"}
- },
- {
- "op": "null",
- "param": {},
- "name": "batchnorm0_gamma",
- "inputs": [],
- "backward_source_id": -1,
- "attr": {"ctx_group": "stage2"}
- },
- {
- "op": "null",
- "param": {},
- "name": "batchnorm0_beta",
- "inputs": [],
- "backward_source_id": -1,
- "attr": {"ctx_group": "stage2"}
- },
- {
- "op": "BatchNorm",
- "param": {
- "eps": "0.001",
- "fix_gamma": "True",
- "momentum": "0.9",
- "use_global_stats": "False"
- },
- "name": "batchnorm0",
- "inputs": [[11, 0], [12, 0], [13, 0]],
- "backward_source_id": -1,
- "attr": {"ctx_group": "stage2"}
- },
- {
- "op": "null",
- "param": {},
- "name": "softmax_label",
- "inputs": [],
- "backward_source_id": -1,
- "attr": {"ctx_group": "stage2"}
- },
- {
- "op": "SoftmaxOutput",
- "param": {
- "grad_scale": "1",
- "ignore_label": "-1",
- "multi_output": "False",
- "normalization": "null",
- "out_grad": "False",
- "preserve_shape": "False",
- "use_ignore": "False"
- },
- "name": "softmax",
- "inputs": [[14, 0], [15, 0]],
- "backward_source_id": -1,
- "attr": {"ctx_group": "stage2"}
- }
- ],
- "arg_nodes": [0, 1, 2, 5, 6, 9, 10, 12, 13, 15],
- "heads": [[16, 0]]
-}
\ No newline at end of file
diff --git a/tests/python/unittest/test_contrib_svrg_module.py b/tests/python/unittest/test_contrib_svrg_module.py
new file mode 100644
index 000000000000..70c718824aee
--- /dev/null
+++ b/tests/python/unittest/test_contrib_svrg_module.py
@@ -0,0 +1,158 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import mxnet as mx
+import numpy as np
+from common import with_seed, assertRaises
+from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule
+
+
+def setup():
+ train_data = np.random.randint(1, 5, [1000, 2])
+ weights = np.array([1.0, 2.0])
+ train_label = train_data.dot(weights)
+
+ di = mx.io.NDArrayIter(train_data, train_label, batch_size=32, shuffle=True, label_name='lin_reg_label')
+ X = mx.sym.Variable('data')
+ Y = mx.symbol.Variable('lin_reg_label')
+ fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1)
+ lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro")
+
+ mod = SVRGModule(
+ symbol=lro,
+ data_names=['data'],
+ label_names=['lin_reg_label'], update_freq=2)
+ mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
+ mod.init_params(initializer=mx.init.Uniform(0.01), allow_missing=False, force_init=False, allow_extra=False)
+
+ return mod
+
+
+def test_bind_module():
+ mod = setup()
+ assert mod.binded
+ assert mod._mod_aux.binded
+
+
+def test_module_init():
+ mod = setup()
+ assert mod._mod_aux is not None
+
+
+def test_module_initializer():
+ def regression_model(m):
+ x = mx.symbol.var("data", stype='csr')
+ v = mx.symbol.var("v", shape=(m, 1), init=mx.init.Uniform(scale=.1),
+ stype='row_sparse')
+ model = mx.symbol.dot(lhs=x, rhs=v)
+ y = mx.symbol.Variable("label")
+ model = mx.symbol.LinearRegressionOutput(data=model, label=y, name="out")
+ return model
+
+ # shape of the data
+ n, m = 128, 100
+ model = regression_model(m)
+
+ data = mx.nd.zeros(shape=(n, m), stype='csr')
+ label = mx.nd.zeros((n, 1))
+ iterator = mx.io.NDArrayIter(data=data, label={'label': label},
+ batch_size=n, last_batch_handle='discard')
+
+ # create module
+ mod = SVRGModule(symbol=model, data_names=['data'], label_names=['label'], update_freq=2)
+ mod.bind(data_shapes=iterator.provide_data, label_shapes=iterator.provide_label)
+ mod.init_params()
+ v = mod._arg_params['v']
+ assert v.stype == 'row_sparse'
+ assert np.sum(v.asnumpy()) != 0
+
+
+def test_module_bind():
+ x = mx.sym.Variable("data")
+ net = mx.sym.FullyConnected(x, num_hidden=1)
+
+ mod = SVRGModule(symbol=net, data_names=['data'], label_names=None, update_freq=2)
+ assertRaises(TypeError, mod.bind, data_shapes=['data', mx.nd.zeros(shape=(2, 1))])
+
+ mod.bind(data_shapes=[('data', (2, 1))])
+ assert mod.binded
+ assert mod._mod_aux.binded
+
+@with_seed()
+def test_module_save_load():
+ x = mx.sym.Variable("data")
+ y = mx.sym.Variable("softmax_label")
+ net = mx.sym.FullyConnected(x, y, num_hidden=1)
+
+ mod = SVRGModule(symbol=net, data_names=['data'], label_names=['softmax_label'], update_freq=2)
+ mod.bind(data_shapes=[('data', (1, 1))])
+ mod.init_params()
+ mod.init_optimizer(optimizer='sgd', optimizer_params={'learning_rate': 0.1})
+ mod.update()
+ mod.save_checkpoint('test', 0, save_optimizer_states=True)
+
+ mod2 = SVRGModule.load('test', 0, load_optimizer_states=True, data_names=('data', ))
+ mod2.bind(data_shapes=[('data', (1, 1))])
+ mod2.init_optimizer(optimizer_params={'learning_rate': 0.1})
+ assert mod._symbol.tojson() == mod2._symbol.tojson()
+
+
+ # Multi-device
+ mod3 = SVRGModule(symbol=net, data_names=['data'], label_names=['softmax_label'], update_freq=3,
+ context=[mx.cpu(0), mx.cpu(1)])
+ mod3.bind(data_shapes=[('data', (10, 10))])
+ mod3.init_params()
+ mod3.init_optimizer(optimizer_params={'learning_rate': 1.0})
+ mod3.update()
+ mod3.save_checkpoint('test', 0, save_optimizer_states=True)
+
+ mod4 = SVRGModule.load('test', 0, load_optimizer_states=True, data_names=('data', ))
+ mod4.bind(data_shapes=[('data', (10, 10))])
+ mod4.init_optimizer(optimizer_params={'learning_rate': 1.0})
+ assert mod3._symbol.tojson() == mod4._symbol.tojson()
+
+@with_seed()
+def test_svrgmodule_reshape():
+ data = mx.sym.Variable("data")
+ sym = mx.sym.FullyConnected(data=data, num_hidden=4, name='fc')
+
+ dshape = (3, 4)
+ mod = SVRGModule(sym, data_names=["data"], label_names=None, context=[mx.cpu(0), mx.cpu(1)], update_freq=1)
+ mod.bind(data_shapes=[('data', dshape)])
+ mod.init_params()
+ mod._mod_aux.init_params()
+ mod.init_optimizer(optimizer_params={"learning_rate": 1.0})
+
+ data_batch = mx.io.DataBatch(data=[mx.nd.ones(dshape)], label=None)
+ mod.forward(data_batch)
+ mod.backward([mx.nd.ones(dshape)])
+ mod.update()
+ assert mod.get_outputs()[0].shape == dshape
+
+
+ dshape = (2, 4)
+ mod.reshape(data_shapes=[('data', dshape)])
+ mod.forward(mx.io.DataBatch(data=[mx.nd.ones(dshape)],
+ label=None))
+ mod.backward([mx.nd.ones(dshape)])
+ mod.update()
+ assert mod.get_outputs()[0].shape == dshape
+
+
+if __name__ == "__main__":
+ import nose
+ nose.runmodule()
diff --git a/tests/python/unittest/test_contrib_svrg_optimizer.py b/tests/python/unittest/test_contrib_svrg_optimizer.py
new file mode 100644
index 000000000000..fc89b2933a29
--- /dev/null
+++ b/tests/python/unittest/test_contrib_svrg_optimizer.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import mxnet as mx
+from mxnet.test_utils import same
+from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule
+from mxnet.contrib.svrg_optimization.svrg_optimizer import SVRGOptimizer
+
+
+def create_network():
+
+ train_data = np.random.randint(1, 5, [1000, 2])
+ weights = np.array([1.0, 2.0])
+ train_label = train_data.dot(weights)
+
+ batch_size = 32
+
+ di = mx.io.NDArrayIter(train_data, train_label, batch_size=batch_size, shuffle=True, label_name='lin_reg_label')
+ X = mx.sym.Variable('data')
+ Y = mx.symbol.Variable('lin_reg_label')
+ fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1)
+ lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro")
+
+ mod = SVRGModule(
+ symbol=lro,
+ data_names=['data'],
+ label_names=['lin_reg_label'], update_freq=2
+ )
+
+ mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label)
+ mod.init_params(initializer=mx.init.Uniform(0.01), allow_missing=False,
+ force_init=False, allow_extra=False)
+
+ return di, mod
+
+
+def test_init_svrg_optimizer():
+ _, mod = create_network()
+
+ kv = mx.kv.create('local')
+ mod.init_optimizer(kvstore=kv, optimizer='sgd', optimizer_params=(('learning_rate', 0.01),),
+ force_init=False)
+
+ assert type(mod._optimizer).__name__ == SVRGOptimizer.__name__
+
+
+def test_svrg_optimizer_constructor():
+ kv = mx.kv.create('local')
+ svrg_optimizer = SVRGOptimizer(default_optimizer='sgd', learning_rate=-1.0)
+ kv.set_optimizer(svrg_optimizer)
+
+ assert svrg_optimizer.default_opt.lr == -1.0
+
+
+def test_kvstore_init_aux_keys():
+ param_idx2name = {0: "weight", 1: "weight_full"}
+
+ svrg_optimizer = SVRGOptimizer(default_optimizer='sgd', param_idx2name=param_idx2name, learning_rate=1.0)
+ kv = mx.kv.create('local')
+ kv.set_optimizer(svrg_optimizer)
+
+ # Use default sgd optimizer
+ param_weight_init = mx.nd.array([0, 0, 0])
+ param_weight_update = mx.nd.array([1, 1, 1])
+
+ kv.init(0, param_weight_init)
+ kv.push(0, param_weight_update)
+ kv.pull(0, param_weight_init)
+
+ param_weight_full_init = mx.nd.array([1, 1, 1])
+ param_weight_full_update = mx.nd.array([2, 2, 2])
+
+ # Use AssignmentOptimizer
+ kv.init(1, param_weight_full_init)
+ kv.push(1, param_weight_full_update)
+ kv.pull(1, param_weight_full_init)
+
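+ # expected values: with learning_rate=1.0 the default sgd update moves param 0
+ # from [0, 0, 0] to 0 - 1.0 * 1 = [-1, -1, -1], while AssignmentOptimizer simply
+ # overwrites param 1 ("weight_full") with the pushed value [2, 2, 2]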
+ # updated weights using default sgd optimizer
+ assert same(param_weight_init.asnumpy(), np.array([-1, -1, -1]))
+ # updated with AssignmentOptimizer
+ assert same(param_weight_full_init.asnumpy(), np.array([2, 2, 2]))
+
+
+if __name__ == "__main__":
+ import nose
+ nose.runmodule()
\ No newline at end of file
diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py
index 39fcd81642d3..ef9beebe4b80 100644
--- a/tests/python/unittest/test_module.py
+++ b/tests/python/unittest/test_module.py
@@ -168,6 +168,7 @@ def test_module_layout():
assert mod.get_outputs()[0].shape == dshape
hdshape = (3, 4, 7)
+
for x in mod.get_outputs(merge_multi_context=False)[0]:
assert x.shape == hdshape