From 495be3b20b5a3f34f5b82ab231ae713ce58f5c12 Mon Sep 17 00:00:00 2001 From: Stephanie Jingyi Yuan Date: Tue, 28 Aug 2018 21:47:25 -0400 Subject: [PATCH] SVRG optimization in the python/contrib package; this version supports single-machine training on a single CPU, a single GPU, or multiple GPUs --- .../svrg_optimization_python/src/__init__.py | 0 .../tests/__init__.py | 21 ---- .../tests/test_svrg_module.py | 116 ------------------ .../tests/test_svrg_optimizer.py | 96 --------------- .../svrg_module}/benchmarks/benchmark1.png | Bin .../svrg_module}/benchmarks/benchmark2.png | Bin example/svrg_module/common.py | 78 ++++++++++++ example/svrg_module/data_reader.py | 44 +++++++ .../svrg_module/example_api_train.py | 10 +- .../svrg_module/example_inference.py | 9 +- example/svrg_module/train.py | 44 +++++++ .../contrib/svrg_optimization}/README.md | 10 +- .../contrib/svrg_optimization}/__init__.py | 8 +- .../contrib/svrg_optimization}/svrg_module.py | 94 +++++++------- .../svrg_optimization}/svrg_optimizer.py | 20 +-- .../unittest/test_contrib_svrg_module.py | 86 +++++++++++++ .../unittest/test_contrib_svrg_optimizer.py | 101 +++++++++++++++ 17 files changed, 433 insertions(+), 304 deletions(-) delete mode 100644 contrib/svrg_optimization_python/src/__init__.py delete mode 100644 contrib/svrg_optimization_python/tests/__init__.py delete mode 100644 contrib/svrg_optimization_python/tests/test_svrg_module.py delete mode 100644 contrib/svrg_optimization_python/tests/test_svrg_optimizer.py rename {contrib/svrg_optimization_python => example/svrg_module}/benchmarks/benchmark1.png (100%) rename {contrib/svrg_optimization_python => example/svrg_module}/benchmarks/benchmark2.png (100%) create mode 100644 example/svrg_module/common.py create mode 100644 example/svrg_module/data_reader.py rename contrib/svrg_optimization_python/test_svrg_train.py => example/svrg_module/example_api_train.py (90%) rename contrib/svrg_optimization_python/test_svrg_inference.py => example/svrg_module/example_inference.py (92%) create mode 100644 example/svrg_module/train.py rename {contrib/svrg_optimization_python => python/mxnet/contrib/svrg_optimization}/README.md (85%) rename {contrib/svrg_optimization_python => python/mxnet/contrib/svrg_optimization}/__init__.py (86%) rename {contrib/svrg_optimization_python/src => python/mxnet/contrib/svrg_optimization}/svrg_module.py (91%) rename {contrib/svrg_optimization_python/src => python/mxnet/contrib/svrg_optimization}/svrg_optimizer.py (92%) create mode 100644 tests/python/unittest/test_contrib_svrg_module.py create mode 100644 tests/python/unittest/test_contrib_svrg_optimizer.py diff --git a/contrib/svrg_optimization_python/src/__init__.py b/contrib/svrg_optimization_python/src/__init__.py deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/contrib/svrg_optimization_python/tests/__init__.py b/contrib/svrg_optimization_python/tests/__init__.py deleted file mode 100644 index b7a3e645e0d5..000000000000 --- a/contrib/svrg_optimization_python/tests/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from __future__ import absolute_import -from ..src.svrg_module import SVRGModule -from ..src.svrg_optimizer import SVRGOptimizer - diff --git a/contrib/svrg_optimization_python/tests/test_svrg_module.py b/contrib/svrg_optimization_python/tests/test_svrg_module.py deleted file mode 100644 index 5118ae1656fb..000000000000 --- a/contrib/svrg_optimization_python/tests/test_svrg_module.py +++ /dev/null @@ -1,116 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import unittest -from ..src.svrg_module import SVRGModule -import mxnet as mx -import numpy as np - - -class TestSVRGModule(unittest.TestCase): - def setUp(self): - mx.random.seed(42) - train_data = np.random.randint(1, 5, [1000, 2]) - weights = np.array([1.0, 2.0]) - train_label = train_data.dot(weights) - - self.di = mx.io.NDArrayIter(train_data, train_label, batch_size=32, shuffle=True, label_name='lin_reg_label') - X = mx.sym.Variable('data') - Y = mx.symbol.Variable('lin_reg_label') - fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1) - lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro") - - self.mod = SVRGModule( - symbol=lro, - data_names=['data'], - label_names=['lin_reg_label'], update_freq=2) - self.mod.bind(data_shapes=self.di.provide_data, label_shapes=self.di.provide_label) - self.mod.init_params(initializer=mx.init.Uniform(0.01), allow_missing=False, - force_init=False, allow_extra=False) - - def test_create_module(self): - self.assertTrue(self.mod._mod_aux is not None) - - def test_bind_module(self): - self.assertTrue(self.mod.binded) - self.assertTrue(self.mod._mod_aux.binded) - - def test_module_initializer(self): - def regression_model(m): - x = mx.symbol.var("data", stype='csr') - v = mx.symbol.var("v", shape=(m, 1), init=mx.init.Uniform(scale=.1), - stype='row_sparse') - model = mx.symbol.dot(lhs=x, rhs=v) - y = mx.symbol.Variable("label") - model = mx.symbol.LinearRegressionOutput(data=model, label=y, name="out") - return model - - n, m = 128, 100 - model = regression_model(m) - - data = mx.nd.zeros(shape=(n, m), stype='csr') - label = mx.nd.zeros((n, 1)) - iterator = mx.io.NDArrayIter(data=data, label={'label': label}, - batch_size=n, last_batch_handle='discard') - - # create module - mod = SVRGModule(symbol=model, data_names=['data'], label_names=['label'], update_freq=2) - 
mod.bind(data_shapes=iterator.provide_data, label_shapes=iterator.provide_label) - mod.init_params() - v = mod._arg_params['v'] - self.assertEqual(v.stype, 'row_sparse') - self.assertTrue(np.sum(v.asnumpy()) != 0) - - @unittest.skip("SVRGModule with Pure SGD will not be a release feature") - def test_svrg_calculations(self): - def calc_svrg_optimization(update_freq): - mx.random.seed(42) - train_data = np.random.randint(1, 5, [1000, 2]) - weights = np.array([1.0, 2.0]) - train_label = train_data.dot(weights) - - di = mx.io.NDArrayIter(train_data, train_label, batch_size=32, shuffle=True, label_name='lin_reg_label') - X = mx.sym.Variable('data') - Y = mx.symbol.Variable('lin_reg_label') - fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1) - lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro") - - mod = SVRGModule( - symbol=lro, - data_names=['data'], - label_names=['lin_reg_label'], update_freq=update_freq) - mod.bind(data_shapes=self.di.provide_data, label_shapes=self.di.provide_label) - mod.init_params(initializer=mx.init.Uniform(0.01), allow_missing=False, force_init=False, allow_extra=False) - mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.01),)) - num_epoch = 100 - - metrics = mx.metric.create("mse") - for e in range(1, num_epoch + 1): - if e % (mod.update_freq + 1) == 0: - mod.update_full_grads(di) - di.reset() - metrics.reset() - for batch in di: - mod.forward_backward(data_batch=batch) - mod.update() - mod.update_metric(metrics, batch.label) - return metrics.get()[1] - - svrg_mse = calc_svrg_optimization(update_freq=2) - sgd_mse = calc_svrg_optimization(update_freq=101) - - self.assertTrue(svrg_mse - sgd_mse < 0) diff --git a/contrib/svrg_optimization_python/tests/test_svrg_optimizer.py b/contrib/svrg_optimization_python/tests/test_svrg_optimizer.py deleted file mode 100644 index 36d44f64b758..000000000000 --- a/contrib/svrg_optimization_python/tests/test_svrg_optimizer.py +++ /dev/null @@ -1,96 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- - -import unittest -from ..src.svrg_optimizer import SVRGOptimizer -from ..src.svrg_module import SVRGModule -import mxnet as mx -import numpy as np -from numpy.testing import assert_array_equal - - -class TestSVRGOPtimizer(unittest.TestCase): - @staticmethod - def create_network(): - mx.random.seed(42) - train_data = np.random.randint(1, 5, [1000, 2]) - weights = np.array([1.0, 2.0]) - train_label = train_data.dot(weights) - - batch_size = 32 - - di = mx.io.NDArrayIter(train_data, train_label, batch_size=batch_size, shuffle=True, label_name='lin_reg_label') - X = mx.sym.Variable('data') - Y = mx.symbol.Variable('lin_reg_label') - fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1) - lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro") - - mod = SVRGModule( - symbol=lro, - data_names=['data'], - label_names=['lin_reg_label'], update_freq=2 - ) - - mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label) - mod.init_params(initializer=mx.init.Uniform(0.01), allow_missing=False, - force_init=False, allow_extra=False) - - return di, mod - - def test_init_svrg_optimizer(self): - di, mod = self.create_network() - - kv = mx.kv.create('local') - mod.init_optimizer(kvstore=kv, optimizer='sgd', optimizer_params=(('learning_rate', 0.01),), - force_init=False) - - self.assertEqual(type(mod._optimizer).__name__, SVRGOptimizer.__name__) - - def test_svrg_optimizer_constructor(self): - _, mod = self.create_network() - - kv = mx.kv.create('local') - svrg_optimizer = SVRGOptimizer(default_optimizer='sgd', learning_rate=1.0) - kv.set_optimizer(svrg_optimizer) - - self.assertEqual(svrg_optimizer.default_opt.lr, 1.0) - - def test_kvstore_init_aux_keys(self): - param_idx2name= {0: "weight", 1: "weight_full"} - - svrg_optimizer = SVRGOptimizer(default_optimizer='sgd', param_idx2name= param_idx2name, learning_rate=1.0) - kv = mx.kv.create('local') - kv.set_optimizer(svrg_optimizer) - - param_weight_init = mx.nd.array([0, 0, 0]) - param_weight_update = mx.nd.array([1, 1, 1]) - - kv.init(0, param_weight_init) - kv.push(0, param_weight_update) - kv.pull(0, param_weight_init) - - param_weight_full_init = mx.nd.array([1, 1, 1]) - param_weight_full_update = mx.nd.array([2, 2, 2]) - - # Use AssignmentOptimizer - kv.init(1, param_weight_full_init) - kv.push(1, param_weight_full_update) - kv.pull(1, param_weight_full_init) - - assert_array_equal(param_weight_init.asnumpy(), np.array([-1, -1, -1])) - assert_array_equal(param_weight_full_init.asnumpy(), np.array([2, 2, 2])) diff --git a/contrib/svrg_optimization_python/benchmarks/benchmark1.png b/example/svrg_module/benchmarks/benchmark1.png similarity index 100% rename from contrib/svrg_optimization_python/benchmarks/benchmark1.png rename to example/svrg_module/benchmarks/benchmark1.png diff --git a/contrib/svrg_optimization_python/benchmarks/benchmark2.png b/example/svrg_module/benchmarks/benchmark2.png similarity index 100% rename from contrib/svrg_optimization_python/benchmarks/benchmark2.png rename to example/svrg_module/benchmarks/benchmark2.png diff --git a/example/svrg_module/common.py b/example/svrg_module/common.py new file mode 100644 index 000000000000..ac630fe6a684 --- /dev/null +++ b/example/svrg_module/common.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import mxnet as mx +import logging +from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule + + +def create_lin_reg_network(train_features, train_labels, feature_dim, batch_size, update_freq, ctx, logger): +    # fit a linear regression model with mxnet SVRG +    print("Fitting linear regression with mxnet") +    train_iter = mx.io.NDArrayIter(train_features, train_labels, batch_size=batch_size, shuffle=True, +                                   data_name='data', label_name='label') +    data = mx.sym.Variable("data") +    label = mx.sym.Variable("label") +    weight = mx.sym.Variable("fc_weight", shape=(1, feature_dim)) +    net = mx.sym.dot(data, weight.transpose()) +    bias = mx.sym.Variable("fc_bias", shape=(1,), wd_mult=0.0, lr_mult=10.0) +    net = mx.sym.broadcast_plus(net, bias) +    net = mx.sym.LinearRegressionOutput(data=net, label=label) + +    mod = SVRGModule(symbol=net, context=ctx, data_names=['data'], label_names=['label'], logger=logger, +                     update_freq=update_freq) +    return train_iter, mod + + +def create_metrics(metrics): +    metric = mx.metric.create(metrics) +    return metric + + +def create_logger(): +    logger = logging.getLogger('sgd_svrg') +    logger.setLevel(logging.INFO) +    formatter = logging.Formatter('%(asctime)s - %(message)s') +    fh = logging.FileHandler('experiments_lr.log') +    fh.setFormatter(formatter) +    logger.addHandler(fh) +    return logger + + +def accumulate_grad(grad_dict, mod): +    param_names = mod._exec_group.param_names +    for i in range(len(param_names)): +        if param_names[i] not in grad_dict: +            grad_dict[param_names[i]] = mod._exec_group.grad_arrays[i][0].copy() +        else: +            grad_dict[param_names[i]] = mx.ndarray.concat(grad_dict[param_names[i]], mod._exec_group.grad_arrays[i][0], +                                                          dim=0) + + +def calc_expectation(grad_dict, count): +    # iterate over a snapshot of the keys, since new *_expectation entries are added to grad_dict +    for key in list(grad_dict.keys()): +        grad_dict[key + "_expectation"] = mx.ndarray.sum(grad_dict[key], axis=0) / count + +    return grad_dict + + +def calc_variance(grad_dict, count, param_names): +    for i in range(len(param_names)): +        diff_sqr = mx.ndarray.square(mx.nd.subtract(grad_dict[param_names[i]], +                                                    grad_dict[param_names[i] + "_expectation"])) +        grad_dict[param_names[i] + "_variance"] = mx.ndarray.sum(diff_sqr, axis=0) / count diff --git a/example/svrg_module/data_reader.py b/example/svrg_module/data_reader.py new file mode 100644 index 000000000000..c4edca9b10ad --- /dev/null +++ b/example/svrg_module/data_reader.py @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import numpy as np + + +def read_year_prediction_data(fileName): + # Download data file + # from subprocess import call + # call(['wget', 'https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression/YearPredictionMSD.bz2']) + # call(['bzip2', '-d', 'YearPredictionMSD.bz2']) + + from sklearn.datasets import load_svmlight_file + + feature_dim = 90 + print("Reading data from disk...") + train_features, train_labels = load_svmlight_file(fileName, n_features=feature_dim, dtype=np.float32) + train_features = train_features.todense() + + # normalize the data: subtract means and divide by standard deviations + label_mean = train_labels.mean() + label_std = np.sqrt(np.square(train_labels - label_mean).mean()) + feature_means = train_features.mean(axis=0) + feature_stds = np.sqrt(np.square(train_features - feature_means).mean(axis=0)) + + train_features = (train_features - feature_means) / feature_stds + train_labels = (train_labels - label_mean) / label_std + + return feature_dim, train_features, train_labels diff --git a/contrib/svrg_optimization_python/test_svrg_train.py b/example/svrg_module/example_api_train.py similarity index 90% rename from contrib/svrg_optimization_python/test_svrg_train.py rename to example/svrg_module/example_api_train.py index 36e8ce448731..bffc1534b3b9 100644 --- a/contrib/svrg_optimization_python/test_svrg_train.py +++ b/example/svrg_module/example_api_train.py @@ -18,7 +18,7 @@ import mxnet as mx import numpy as np -from src.svrg_module import SVRGModule +from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule def test_svrg_intermediate_level_api(num_epoch): @@ -40,7 +40,7 @@ def test_svrg_intermediate_level_api(num_epoch): mod.forward_backward(data_batch=batch) mod.update() mod.update_metric(metrics, batch.label) - print('Epoch[%d] Time cost=%.3f', e, metrics.get()) + mod.logger.info('Epoch[%d] Train cost=%f', e, metrics.get()[1]) def test_svrg_high_level_api(num_epoch): @@ -52,11 +52,13 @@ def test_svrg_high_level_api(num_epoch): def create_network(): + import logging """Create a linear regression network for performing SVRG optimization. 
:return: an instance of mx.io.NDArrayIter :return: an instance of mx.mod.svrgmodule for performing SVRG optimization """ - mx.random.seed(42) + head = '%(asctime)-15s %(message)s' + logging.basicConfig(level=logging.INFO, format=head) train_data = np.random.randint(1, 5, [1000, 2]) weights = np.array([1.0, 2.0]) train_label = train_data.dot(weights) @@ -70,7 +72,7 @@ def create_network(): mod = SVRGModule( symbol=lro, data_names=['data'], - label_names=['lin_reg_label'], update_freq=2 + label_names=['lin_reg_label'], update_freq=2, logger=logging ) return di, mod diff --git a/contrib/svrg_optimization_python/test_svrg_inference.py b/example/svrg_module/example_inference.py similarity index 92% rename from contrib/svrg_optimization_python/test_svrg_inference.py rename to example/svrg_module/example_inference.py index 0250cdec5899..994b95fd3f86 100644 --- a/contrib/svrg_optimization_python/test_svrg_inference.py +++ b/example/svrg_module/example_inference.py @@ -18,12 +18,14 @@ import mxnet as mx import numpy as np -from src.svrg_module import SVRGModule +import logging +from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule def test_svrg_inference(num_epoch): train_iter, val_iter, mod = create_network() - mod.fit(train_iter, eval_data=val_iter, eval_metric='mse', optimizer='sgd', optimizer_params=(('learning_rate', 0.025),), + mod.fit(train_iter, eval_data=val_iter, eval_metric='mse', optimizer='sgd', + optimizer_params=(('learning_rate', 0.025),), num_epoch=num_epoch) def test_score(num_epoch): @@ -53,7 +55,8 @@ def create_network(): :return: an instance of mx.io.NDArrayIter :return: an instance of mx.mod.svrgmodule for performing SVRG optimization """ - mx.random.seed(42) + head = '%(asctime)-15s %(message)s' + logging.basicConfig(level=logging.INFO, format=head) data = np.random.randint(1, 5, [1000, 2]) n_train = int(data.shape[0] * 0.8) weights = np.array([1.0, 2.0]) diff --git a/example/svrg_module/train.py b/example/svrg_module/train.py new file mode 100644 index 000000000000..6d5a6b71a16a --- /dev/null +++ b/example/svrg_module/train.py @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ + +import argparse +import mxnet as mx +from common import create_lin_reg_network, create_logger +from data_reader import read_year_prediction_data + +parser = argparse.ArgumentParser() +parser.add_argument('-e', dest='epochs', help='number of epochs for training phase', type=int, required=True) +parser.add_argument('-f', dest="updateFreq", help="update frequency for SVRGModule", type=int, default=2, required=False) +parser.add_argument('-b', dest="batch_size", help="define the batch size for training", type=int, +                    default=100, required=False) +parser.add_argument('-m', dest='metrics', help="create eval metric", type=str, required=False) +parser.add_argument('--gpus', type=str, help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu') +parser.add_argument('--kv-store', type=str, default='local', help='key-value store type') + +args = parser.parse_args() +# devices for training +ctx = mx.cpu() if args.gpus is None or args.gpus == "" else [mx.gpu(int(i)) for i in args.gpus.split(',')] +logger = create_logger() +kv = mx.kvstore.create(args.kv_store) + +feature_dim, train_features, train_labels = read_year_prediction_data('YearPredictionMSD') +train_iter, mod = create_lin_reg_network(train_features, train_labels, feature_dim, args.batch_size, args.updateFreq, +                                         ctx, logger) + +mod.fit(train_iter, eval_metric='mse', optimizer='sgd', +        optimizer_params=(('learning_rate', 0.025), ), num_epoch=args.epochs, kvstore=kv) diff --git a/contrib/svrg_optimization_python/README.md b/python/mxnet/contrib/svrg_optimization/README.md similarity index 85% rename from contrib/svrg_optimization_python/README.md rename to python/mxnet/contrib/svrg_optimization/README.md index cc3b6bc41357..534183368ca2 100644 --- a/contrib/svrg_optimization_python/README.md +++ b/python/mxnet/contrib/svrg_optimization/README.md @@ -17,14 +17,16 @@ from the last mth epoch, plus the average of gradients over all data. * Ability to use relatively large learning rate compared to SGD, which leads to faster convergence compared to SGD. #### Testing: -Functional Tests: -* test_svrg_train.py: test script that tests both high-level and intermediate-level api for using SVRG -* test_svrg_inferency.py: test script for testing SVRGModule inference Unit Tests: * test_svrg_module.py: unittests for SVRGModule API * test_svrg_optimizer.py: unittests for SVRGOptimizer API +#### Examples: +* examples/svrg_module/train.py: example script that trains a linear regression model with SVRG optimization on the YearPredictionMSD dataset +* examples/svrg_module/example_api_train.py: a demo script for training an SVRGModule using the intermediate-level and high-level APIs +* examples/svrg_module/example_inference.py: a demo script for SVRGModule inference + +#### Benchmarks: An initial set of benchmarks has been performed on YearPredictionDatasetMSD with linear regression model. @@ -36,4 +38,4 @@ thus SGD needs to start with a small learning rate. The learning rate does not need to decay to zero for SVRG, therefore we can use a relatively larger learning rate. SGD with learning rate of (0.001, 0.0025) and SVRG with learning rate of (0.025) are benchmarked. Even though SVRG starts with a relatively large learning rate, it converges much faster than SGD in both cases. -This particular experiment result aligns with what was stated in the SVRG paper section 5. \ No newline at end of file +This particular experiment result aligns with what was stated in the SVRG paper section 5. 
\ No newline at end of file diff --git a/contrib/svrg_optimization_python/__init__.py b/python/mxnet/contrib/svrg_optimization/__init__.py similarity index 86% rename from contrib/svrg_optimization_python/__init__.py rename to python/mxnet/contrib/svrg_optimization/__init__.py index 4acf63ef7a13..6e70009983c9 100644 --- a/contrib/svrg_optimization_python/__init__.py +++ b/python/mxnet/contrib/svrg_optimization/__init__.py @@ -14,7 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +"""SVRG optimization: imports the svrg_module and svrg_optimizer submodules. +""" -from __future__ import absolute_import -from .src.svrg_optimizer import SVRGOptimizer -from .src.svrg_module import SVRGModule + + +from . import svrg_module +from . import svrg_optimizer diff --git a/contrib/svrg_optimization_python/src/svrg_module.py b/python/mxnet/contrib/svrg_optimization/svrg_module.py similarity index 91% rename from contrib/svrg_optimization_python/src/svrg_module.py rename to python/mxnet/contrib/svrg_optimization/svrg_module.py index e587da00eb25..6d6149fd7fe4 100644 --- a/contrib/svrg_optimization_python/src/svrg_module.py +++ b/python/mxnet/contrib/svrg_optimization/svrg_module.py @@ -18,11 +18,11 @@ SVRG optimization logic. """ -import mxnet as mx import time import logging -from svrg_optimizer import SVRGOptimizer +import mxnet as mx from mxnet.module import Module +from .svrg_optimizer import SVRGOptimizer class SVRGModule(Module): @@ -62,7 +62,7 @@ class SVRGModule(Module): Examples -------- >>> # An example of declaring and using SVRGModule. -    >>> mod = mod = SVRGModule(symbol=lro, data_names=['data'], label_names=['lin_reg_label'], update_freq=2) +    >>> mod = SVRGModule(symbol=lro, data_names=['data'], label_names=['lin_reg_label'], update_freq=2) >>> mod.fit(di, eval_metric='mse', optimizer='sgd', optimizer_params=(('learning_rate', 0.025),), >>> num_epoch=num_epoch, kvstore='local') """ @@ -80,19 +80,20 @@ def __init__(self, symbol, data_names=('data',), label_names=('softmax_label',), if isinstance(update_freq, int): self.update_freq = update_freq else: -            raise TypeError("update_freq must be an integer") +            raise TypeError("update_freq in SVRGModule must be an integer specifying the frequency, in epochs, " +                            "at which full gradients are recalculated") self._mod_aux = mx.mod.Module(symbol, data_names, label_names, logger, context, work_load_list, fixed_param_names, state_names, group2ctxs, compression_params) -        self._param_dict = [{} for ctx in self._context] +        self._param_dict = None +        self._ctx_len = len(self._context) def _reset_bind(self): """Internal function to reset binded state.""" super(SVRGModule, self)._reset_bind() self._mod_aux._reset_bind() - def reshape(self, data_shapes, label_shapes=None): super(SVRGModule, self).reshape(data_shapes, label_shapes=label_shapes) self._mod_aux.reshape(data_shapes, label_shapes=label_shapes) @@ -117,15 +118,15 @@ def init_optimizer(self, kvstore='local', optimizer='sgd', optimizer in the case an optimizer is already installed. 
""" # Init dict for storing average of full gradients for each device - for i in range(len(self._context)): - self._param_dict[i] = {key: mx.nd.zeros(shape=value.shape, ctx=self._context[i]) - for key, value in self.get_params()[0].items()} + + self._param_dict = [{key: mx.nd.zeros(shape=value.shape, ctx=self._context[i]) + for key, value in self.get_params()[0].items()} for i in range(self._ctx_len)] svrg_optimizer = self._create_optimizer(SVRGOptimizer.__name__, default_opt=optimizer, kvstore=kvstore, optimizer_params=optimizer_params) super(SVRGModule, self).init_optimizer(kvstore=kvstore, optimizer=svrg_optimizer, - optimizer_params=optimizer_params, force_init=force_init) + optimizer_params=optimizer_params, force_init=force_init) # Init additional keys for accumulating full grads in KVStore if self._kvstore: @@ -153,7 +154,7 @@ def _create_optimizer(self, optimizer, default_opt, kvstore, optimizer_params): # code partially copied from mxnet module.init_optimizer() to accomodate svrg_optimizer batch_size = self._exec_group.batch_size - (kv_store, update_on_kvstore) = mx.model._create_kvstore(kvstore, len(self._context), self._arg_params) + (kv_store, update_on_kvstore) = mx.model._create_kvstore(kvstore, self._ctx_len, self._arg_params) if kv_store and 'dist' in kv_store.type and '_sync' in kv_store.type: batch_size *= kv_store.num_workers rescale_grad = 1.0 / batch_size @@ -162,8 +163,8 @@ def _create_optimizer(self, optimizer, default_opt, kvstore, optimizer_params): if update_on_kvstore: idx2name.update(enumerate(self._exec_group.param_names)) else: - for k in range(len(self._context)): - idx2name.update({i * len(self._context) + k: n + for k in range(self._ctx_len): + idx2name.update({i * self._ctx_len + k: n for i, n in enumerate(self._exec_group.param_names)}) # update idx2name to include new keys @@ -182,8 +183,7 @@ def _create_optimizer(self, optimizer, default_opt, kvstore, optimizer_params): return optimizer def bind(self, data_shapes, label_shapes=None, for_training=True, - inputs_need_grad=False, force_rebind=False, shared_module=None, - grad_req='write'): + inputs_need_grad=False, force_rebind=False, shared_module=None, grad_req='write'): """Binds the symbols to construct executors for both two modules. This is necessary before one can perform computation with the SVRGModule. 
@@ -301,26 +301,23 @@ def update_full_grads(self, train_data): self._mod_aux.backward() nbatch += 1 - for i in range(len(self._context)): - for j in range(len(param_names)): - grads = self._mod_aux._exec_group.grad_arrays[j][i] - self._param_dict[i][param_names[j]] = mx.nd.broadcast_add(self._param_dict[i][param_names[j]], - grads, axis=0) + for ctx_index, _ in enumerate(self._context): + for index, name in enumerate(param_names): + grads = self._mod_aux._exec_group.grad_arrays[index][ctx_index] + self._param_dict[ctx_index][name] = mx.nd.broadcast_add(self._param_dict[ctx_index][name], + grads, axis=0) padding = batch.pad # Average full gradients over number of batches, accumulate in the kvstore if kvstore is set - for i in range(len(self._context)): - for key in self._param_dict[i].keys(): - self._param_dict[i][key] /= (nbatch - padding / train_data.batch_size) + for i in range(self._ctx_len): + for name in param_names: + self._param_dict[i][name] /= (nbatch - padding / train_data.batch_size) if self._kvstore: # Push a list of gradients from each device in the KVStore - for key in self._param_dict[0].keys(): - grad_list = [] - for i in range(len(self._param_dict)): - grad_list.append(self._param_dict[i][key]) - - self._accumulate_kvstore(key, grad_list) + for name in param_names: + grad_list = list(self._param_dict[i][name] for i in range(self._ctx_len)) + self._accumulate_kvstore(name, grad_list) def _accumulate_kvstore(self, key, value): """Accumulate gradients over all data in the KVstore. In distributed setting, each worker sees a portion of @@ -337,9 +334,9 @@ def _accumulate_kvstore(self, key, value): # Accumulate full gradients for current epochs self._kvstore.push(key + "_full", value) - self._kvstore._barrier() self._kvstore.pull(key + "_full", value) + self._allocate_gradients(key, value) def _allocate_gradients(self, key, value): @@ -354,11 +351,11 @@ def _allocate_gradients(self, key, value): A list of average of the full gradients in the KVStore. """ - num_device = len(self._context) - for i in range(len(self._param_dict)): - self._param_dict[i][key] = value[i] / num_device + for i in range(self._ctx_len): + self._param_dict[i][key] = value[i] / self._ctx_len - def _svrg_grads_update_rule(self, g_curr_batch_curr_weight, g_curr_batch_special_weight, g_special_weight_all_batch): + def _svrg_grads_update_rule(self, g_curr_batch_curr_weight, g_curr_batch_special_weight, + g_special_weight_all_batch): """Calculates the gradient based on the SVRG update rule. Parameters ---------- @@ -374,22 +371,24 @@ def _svrg_grads_update_rule(self, g_curr_batch_curr_weight, g_curr_batch_special Gradients calculated using SVRG update rule: grads = g_curr_batch_curr_weight - g_curr_batch_special_weight + g_special_weight_all_batch """ - for i in range(len(g_curr_batch_curr_weight)): - g_curr_batch_curr_weight[i] -= g_curr_batch_special_weight[i] - g_curr_batch_curr_weight[i] += g_special_weight_all_batch[i] + + for index, grad in enumerate(g_curr_batch_curr_weight): + grad -= g_curr_batch_special_weight[index] + grad += g_special_weight_all_batch[index] return g_curr_batch_curr_weight def update_svrg_gradients(self): """Calculates gradients based on the SVRG update rule. 
""" param_names = self._exec_group.param_names - for i in range(len(self._context)): - for j in range(len(param_names)): - g_curr_batch_reg = self._exec_group.grad_arrays[j][i] - g_curr_batch_special = self._mod_aux._exec_group.grad_arrays[j][i] - g_special_weight_all_batch = self._param_dict[i][param_names[j]] - g_svrg = self._svrg_grads_update_rule(g_curr_batch_reg, g_curr_batch_special, g_special_weight_all_batch) - self._exec_group.grad_arrays[j][i] = g_svrg + for ctx in range(self._ctx_len): + for index, name in enumerate(param_names): + g_curr_batch_reg = self._exec_group.grad_arrays[index][ctx] + g_curr_batch_special = self._mod_aux._exec_group.grad_arrays[index][ctx] + g_special_weight_all_batch = self._param_dict[ctx][name] + g_svrg = self._svrg_grads_update_rule(g_curr_batch_reg, g_curr_batch_special, + g_special_weight_all_batch) + self._exec_group.grad_arrays[index][ctx] = g_svrg def fit(self, train_data, eval_data=None, eval_metric='acc', epoch_end_callback=None, batch_end_callback=None, kvstore='local', @@ -461,6 +460,7 @@ def fit(self, train_data, eval_data=None, eval_metric='acc', str -> NDArray. The resulting dict is used for pulling row_sparse parameters from the kvstore, where the str key is the name of the param, and the value is the row id of the param to pull. + validation_metric: """ assert num_epoch is not None, 'please specify number of epochs' @@ -533,7 +533,6 @@ def fit(self, train_data, eval_data=None, eval_metric='acc', self.logger.info('Epoch[%d] Train-%s=%f', epoch, name, val) toc = time.time() self.logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic)) - print('Epoch[%d] Time cost=%.3f', epoch, eval_metric.get()) # sync aux params across devices arg_params, aux_params = self.get_params() @@ -551,7 +550,6 @@ def fit(self, train_data, eval_data=None, eval_metric='acc', batch_end_callback=eval_batch_end_callback, epoch=epoch) for name, val in res: self.logger.info('Epoch[%d] Validation-%s=%f', epoch, name, val) - print('Epoch[%d] Validation-%s=%f', epoch, name, val) def prepare(self, data_batch, sparse_row_id_fn=None): """Prepares two modules for processing a data batch. @@ -577,5 +575,5 @@ def prepare(self, data_batch, sparse_row_id_fn=None): parameters from the kvstore, where the str key is the name of the param, and the value is the row id of the param to pull. 
""" - super(SVRGModule, self).prepare(data_batch, sparse_row_id_fn=sparse_row_id_fn) - self._mod_aux.prepare(data_batch=sparse_row_id_fn) + super(SVRGModule, self).prepare(data_batch, sparse_row_id_fn) + self._mod_aux.prepare(data_batch, sparse_row_id_fn) diff --git a/contrib/svrg_optimization_python/src/svrg_optimizer.py b/python/mxnet/contrib/svrg_optimization/svrg_optimizer.py similarity index 92% rename from contrib/svrg_optimization_python/src/svrg_optimizer.py rename to python/mxnet/contrib/svrg_optimization/svrg_optimizer.py index bf9cca975cce..d19804b4dc5b 100644 --- a/contrib/svrg_optimization_python/src/svrg_optimizer.py +++ b/python/mxnet/contrib/svrg_optimization/svrg_optimizer.py @@ -23,9 +23,13 @@ @mx.optimizer.register class AssignmentOptimizer(mx.optimizer.Optimizer): + """AssignmentOptimizer assigns gradients to be weights for SVRGModule full gradients + accumulation in the KVStore + """ def update(self, index, weight, grad, state): weight[:] = grad + @mx.optimizer.register class SVRGOptimizer(mx.optimizer.Optimizer): """SVRGOptimizer is a wrapper class for two optimizers: one for accumulating full gradients and the other @@ -36,9 +40,9 @@ class SVRGOptimizer(mx.optimizer.Optimizer): default_optimizer: optimizer passed-in when invoke on mx.mod.init_optimizer """ - def __init__(self, default_optimizer, **kwargs): + def __init__(self, default_optimizer, **kwargs): # Reconstruct kwargs to identify additional params for default optimizer - extra_param, default_param = self._check_params(**kwargs) + default_param = self._check_params(**kwargs) super(SVRGOptimizer, self).__init__(**default_param) if isinstance(default_optimizer, str): self.default_opt = mx.optimizer.create(default_optimizer, **kwargs) @@ -51,14 +55,13 @@ def _check_params(self, **kwargs): optimizer_param = dict(kwargs) base_params = ['rescale_grad', 'param_idx2name', 'wd', 'clip_gradient', 'learning_rate', 'lr_scheduler', 'sym', 'begin_num_update', 'multi_precision', 'param_dict'] - extra_param = {} + default_params = {} - for key in optimizer_param.keys(): + for key, _ in optimizer_param.items(): if key in base_params: default_params[key] = optimizer_param[key] - else: - extra_param[key] = optimizer_param[key] - return extra_param, default_params + + return default_params def update(self, index, weight, grad, state): """Updates the given parameter using the corresponding gradient and state. If key contains 'full', update with @@ -107,9 +110,8 @@ def create_state(self, index, weight): name = self._check_index(index) if "full".lower() in name: - return + return self.aux_opt.create_state(index, weight) else: - # use the default optimizer return self.default_opt.create_state(index, weight) def _check_index(self, index): diff --git a/tests/python/unittest/test_contrib_svrg_module.py b/tests/python/unittest/test_contrib_svrg_module.py new file mode 100644 index 000000000000..2a71f8a590a3 --- /dev/null +++ b/tests/python/unittest/test_contrib_svrg_module.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule +import mxnet as mx +import numpy as np + + +def set_up(): +    train_data = np.random.randint(1, 5, [1000, 2]) +    weights = np.array([1.0, 2.0]) +    train_label = train_data.dot(weights) + +    di = mx.io.NDArrayIter(train_data, train_label, batch_size=32, shuffle=True, label_name='lin_reg_label') +    X = mx.sym.Variable('data') +    Y = mx.symbol.Variable('lin_reg_label') +    fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1) +    lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro") + +    mod = SVRGModule( +        symbol=lro, +        data_names=['data'], +        label_names=['lin_reg_label'], update_freq=2) +    mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label) +    mod.init_params(initializer=mx.init.Uniform(0.01), allow_missing=False, +                    force_init=False, allow_extra=False) + +    return mod + + +def test_bind_module(): +    mod = set_up() +    assert mod.binded +    assert mod._mod_aux.binded + + +def test_module_init(): +    mod = set_up() +    assert mod._mod_aux is not None + + +def test_module_initializer(): +    def regression_model(m): +        x = mx.symbol.var("data", stype='csr') +        v = mx.symbol.var("v", shape=(m, 1), init=mx.init.Uniform(scale=.1), +                          stype='row_sparse') +        model = mx.symbol.dot(lhs=x, rhs=v) +        y = mx.symbol.Variable("label") +        model = mx.symbol.LinearRegressionOutput(data=model, label=y, name="out") +        return model + +    n, m = 128, 100 +    model = regression_model(m) + +    data = mx.nd.zeros(shape=(n, m), stype='csr') +    label = mx.nd.zeros((n, 1)) +    iterator = mx.io.NDArrayIter(data=data, label={'label': label}, +                                 batch_size=n, last_batch_handle='discard') + +    # create module +    mod = SVRGModule(symbol=model, data_names=['data'], label_names=['label'], update_freq=2) +    mod.bind(data_shapes=iterator.provide_data, label_shapes=iterator.provide_label) +    mod.init_params() +    v = mod._arg_params['v'] +    assert v.stype == 'row_sparse' +    assert np.sum(v.asnumpy()) != 0 + + +if __name__ == "__main__": +    import nose +    nose.runmodule() diff --git a/tests/python/unittest/test_contrib_svrg_optimizer.py b/tests/python/unittest/test_contrib_svrg_optimizer.py new file mode 100644 index 000000000000..bc42ed5991a8 --- /dev/null +++ b/tests/python/unittest/test_contrib_svrg_optimizer.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ + +from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule +from mxnet.contrib.svrg_optimization.svrg_optimizer import SVRGOptimizer +import mxnet as mx +import numpy as np +from mxnet.test_utils import same + + +def create_network(): +    mx.random.seed(42) +    train_data = np.random.randint(1, 5, [1000, 2]) +    weights = np.array([1.0, 2.0]) +    train_label = train_data.dot(weights) + +    batch_size = 32 + +    di = mx.io.NDArrayIter(train_data, train_label, batch_size=batch_size, shuffle=True, label_name='lin_reg_label') +    X = mx.sym.Variable('data') +    Y = mx.symbol.Variable('lin_reg_label') +    fully_connected_layer = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1) +    lro = mx.sym.LinearRegressionOutput(data=fully_connected_layer, label=Y, name="lro") + +    mod = SVRGModule( +        symbol=lro, +        data_names=['data'], +        label_names=['lin_reg_label'], update_freq=2 +    ) + +    mod.bind(data_shapes=di.provide_data, label_shapes=di.provide_label) +    mod.init_params(initializer=mx.init.Uniform(0.01), allow_missing=False, +                    force_init=False, allow_extra=False) + +    return di, mod + + +def test_init_svrg_optimizer(): +    di, mod = create_network() + +    kv = mx.kv.create('local') +    mod.init_optimizer(kvstore=kv, optimizer='sgd', optimizer_params=(('learning_rate', 0.01),), +                       force_init=False) + +    assert type(mod._optimizer).__name__ == SVRGOptimizer.__name__ + + +def test_svrg_optimizer_constructor(): +    _, mod = create_network() + +    kv = mx.kv.create('local') +    svrg_optimizer = SVRGOptimizer(default_optimizer='sgd', learning_rate=-1.0) +    kv.set_optimizer(svrg_optimizer) + +    assert svrg_optimizer.default_opt.lr == -1.0 + + +def test_kvstore_init_aux_keys(): +    param_idx2name = {0: "weight", 1: "weight_full"} + +    svrg_optimizer = SVRGOptimizer(default_optimizer='sgd', param_idx2name=param_idx2name, learning_rate=1.0) +    kv = mx.kv.create('local') +    kv.set_optimizer(svrg_optimizer) + +    param_weight_init = mx.nd.array([0, 0, 0]) +    param_weight_update = mx.nd.array([1, 1, 1]) + +    kv.init(0, param_weight_init) +    kv.push(0, param_weight_update) +    kv.pull(0, param_weight_init) + +    param_weight_full_init = mx.nd.array([1, 1, 1]) +    param_weight_full_update = mx.nd.array([2, 2, 2]) + +    # Use AssignmentOptimizer +    kv.init(1, param_weight_full_init) +    kv.push(1, param_weight_full_update) +    kv.pull(1, param_weight_full_init) + +    assert same(param_weight_init.asnumpy(), np.array([-1, -1, -1])) +    assert same(param_weight_full_init.asnumpy(), np.array([2, 2, 2])) + + +if __name__ == "__main__": +    import nose +    nose.runmodule() \ No newline at end of file
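A minimal end-to-end sketch of the high-level API introduced by this patch, adapted from example_api_train.py and the SVRGModule docstring (the synthetic regression data, update_freq=2, and the sgd hyperparameters below are illustrative choices, not requirements of the API):

import mxnet as mx
import numpy as np
from mxnet.contrib.svrg_optimization.svrg_module import SVRGModule

# Synthetic linear-regression problem: labels are train_data dot [1.0, 2.0]
train_data = np.random.randint(1, 5, [1000, 2])
train_label = train_data.dot(np.array([1.0, 2.0]))
di = mx.io.NDArrayIter(train_data, train_label, batch_size=32, shuffle=True, label_name='lin_reg_label')

X = mx.sym.Variable('data')
Y = mx.sym.Variable('lin_reg_label')
fc = mx.sym.FullyConnected(data=X, name='fc1', num_hidden=1)
lro = mx.sym.LinearRegressionOutput(data=fc, label=Y, name='lro')

# update_freq=2: the full gradient over the whole dataset is recomputed every 2 epochs
mod = SVRGModule(symbol=lro, data_names=['data'], label_names=['lin_reg_label'], update_freq=2)
mod.fit(di, eval_metric='mse', optimizer='sgd', optimizer_params=(('learning_rate', 0.025),),
        num_epoch=100, kvstore='local')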