From 8aa725584b635838ca9136787ce24b76662fc37f Mon Sep 17 00:00:00 2001 From: muli Date: Sun, 6 Sep 2015 11:43:02 -0400 Subject: [PATCH 01/18] init ps.h --- include/mxnet/ps.h | 103 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 include/mxnet/ps.h diff --git a/include/mxnet/ps.h b/include/mxnet/ps.h new file mode 100644 index 000000000000..2066af77d78f --- /dev/null +++ b/include/mxnet/ps.h @@ -0,0 +1,103 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file ps.h + * \brief parameter server interface for mxnet + */ +#ifndef MXNET_PS_H_ +#define MXNET_PS_H_ +#include "dmlc/io.h" +#include "narray.h" + +#if DMLC_USE_CXX11 == 0 +#error "C++11 was required for ps module." +#endif + +namespace mxnet { +namespace ps { + +/*! + * \brief A PS worker node + * + * a worker node can push data (gradient) to the servers and also pull data + * (weight) back + */ +class Worker { + public: + /*! + * \brief push \a data to the server nodes + * + * This function returns after adding a push operator to the engine. Any + * following operator requiring writing \a data will be blocked until the + * actual push is finished. + * + * \param data data for pushing + */ + void Push(const NArray& data); + + /*! + * \brief pull data from the server nodes + * + * This function returns after adding a pull operator to the engine. Any + * following operator requiring reading \a data will be blocked until the + * actual pull is finished. + * + * \param data data for pulling, should be pre-allocated + */ + void Pull(NArray& data); + + /** + * \brief wait until a push/pull finished + * + * Wait until data has already pushed to the servers or pulled back from the + * servers + * + * \param data data for waiting + */ + void Wait(const NArray& data); +}; + + + +/** + * \brief A PS server node + * + * a server node maintains data (weight), and allows user-defined handle to + * modify the data + */ +class Server { + public: + /** + * \brief constructor + * + * The server node triggers the user-defined handle in two ways: + * - online: the handle is called every time when data received from a + * worker. often used for asynchronous optimization + * - batch: the handle is called after data have been aggregated over all + * workers. often used for synchronous optimization + * + * \param batch true for batch, false for online + */ + explicit Server(bool batch = true); + + /** + * \brief Load from disk + */ + void Load(dmlc::Stream *fi); + + /** + * \brief Save to disk + */ + void Save(dmlc::Stream *fo); +}; + +/** + * \brief user-defined handle + * \param recv_data data (gradient) received from users + * \param my_data data (weight) maintained on the server + */ +void ServerHandle(const NArray& recv_data, NArray my_data); + + +} // namespace ps +} // namespace mxnet +#endif // MXNET_PS_H_ From 8b59d337bff10e29040906c5bc163751a2f39170 Mon Sep 17 00:00:00 2001 From: muli Date: Sun, 6 Sep 2015 15:10:00 -0400 Subject: [PATCH 02/18] refactor ps::worker --- include/mxnet/ps.h | 44 +++++++++++++++++++++----------------------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/include/mxnet/ps.h b/include/mxnet/ps.h index 2066af77d78f..9776a728e1e9 100644 --- a/include/mxnet/ps.h +++ b/include/mxnet/ps.h @@ -18,46 +18,44 @@ namespace ps { /*! * \brief A PS worker node * - * a worker node can push data (gradient) to the servers and also pull data - * (weight) back + * Worker node can push data (gradient) to the servers and pull data (aggregated + * gradient or weight) back. A worker is bind to a particular device, namely a + * worker can only push and pull data with the same \a device_id */ class Worker { public: /*! - * \brief push \a data to the server nodes + * \brief push data to the server nodes * - * This function returns after adding a push operator to the engine. Any - * following operator requiring writing \a data will be blocked until the - * actual push is finished. + * Push the key-value pair (\a key, \a value) to the server nodes. This + * function returns after adding a push operator to the engine. Any following + * operator requiring writing \a value will be blocked until the actual push is + * finished. * - * \param data data for pushing + * One can wait the push is finished via `data.Wait()` + * + * \param key the key for pushing + * \param value the value for pushing */ - void Push(const NArray& data); + void Push(int key, const NArray& value); /*! * \brief pull data from the server nodes * - * This function returns after adding a pull operator to the engine. Any - * following operator requiring reading \a data will be blocked until the - * actual pull is finished. - * - * \param data data for pulling, should be pre-allocated - */ - void Pull(NArray& data); - - /** - * \brief wait until a push/pull finished + * Pull the \a value associated with the \a key from the servers. This + * function returns after adding a pull operator to the engine. Any following + * operator requiring reading \a data will be blocked until the actual pull is + * finished. * - * Wait until data has already pushed to the servers or pulled back from the - * servers + * One can wait the pull is finished via `data.Wait()` * - * \param data data for waiting + * \param key the key for pulling + * \param value data for pulling, should be pre-allocated */ - void Wait(const NArray& data); + void Pull(int key, NArray* value); }; - /** * \brief A PS server node * From 75ee58a4f3c0311991dca03c5e392e4da29d5cef Mon Sep 17 00:00:00 2001 From: muli Date: Sun, 6 Sep 2015 16:43:29 -0400 Subject: [PATCH 03/18] update ps::worekr/server --- include/mxnet/ps.h | 92 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 73 insertions(+), 19 deletions(-) diff --git a/include/mxnet/ps.h b/include/mxnet/ps.h index 9776a728e1e9..d694c661b390 100644 --- a/include/mxnet/ps.h +++ b/include/mxnet/ps.h @@ -8,9 +8,9 @@ #include "dmlc/io.h" #include "narray.h" -#if DMLC_USE_CXX11 == 0 -#error "C++11 was required for ps module." -#endif +#if DMLC_USE_CXX11 +#include +#endif // DMLC_USE_CXX11 namespace mxnet { namespace ps { @@ -21,6 +21,43 @@ namespace ps { * Worker node can push data (gradient) to the servers and pull data (aggregated * gradient or weight) back. A worker is bind to a particular device, namely a * worker can only push and pull data with the same \a device_id + * + * Example to implement allreduce + * \code + * // on worker node: + * NArray data; + * // init data... + * Worker comm; + * comm.Push(0, data); + * comm.Pull(0, &data); + * data.Wait(); + * + * // on server node: + * Server store; + * \endcode + * + * Example to implement asynchronous SGD + * \code + * // on worker node: + * NArray weight, grad; + * if (NodeInfo::Root()) { + * // init weight ... + * comm.Push(0, weight); + * } + * comm.Pull(0, &weight); + * // compute grad + * comm.Push(0, grad); + * + * // on server node: + * auto updater = [](const NArray& recv, NArray* weight) { + * if (weight->Empty()) { + * *weight = recv; // recv is the init weight + * } else { + * *weight += 0.1 * recv; // recv is grad + * } + * } + * Server store(false, updater); + * \endcode */ class Worker { public: @@ -34,6 +71,10 @@ class Worker { * * One can wait the push is finished via `data.Wait()` * + * For each push, each server node will apply a user-defined server handle to merge + * the value sent to the one maintained by itself. See \ref Server for more + * details. + * * \param key the key for pushing * \param value the value for pushing */ @@ -49,33 +90,53 @@ class Worker { * * One can wait the pull is finished via `data.Wait()` * + * System will guarantee that the all pushes issued by this worker have been + * applied, namely the server handle has been triggered. + * * \param key the key for pulling * \param value data for pulling, should be pre-allocated */ void Pull(int key, NArray* value); + + private: }; +#if DMLC_USE_CXX11 /** * \brief A PS server node * - * a server node maintains data (weight), and allows user-defined handle to - * modify the data + * A server node maintains data (weight or aggregated gradient), and allows + * user-defined handle to modify the data */ class Server { public: + /** + * \brief user-defined handle + */ + using Handle = std::function; + /** * \brief constructor * - * The server node triggers the user-defined handle in two ways: - * - online: the handle is called every time when data received from a - * worker. often used for asynchronous optimization - * - batch: the handle is called after data have been aggregated over all - * workers. often used for synchronous optimization + * Given a key, assume \a x is the received value and \a y is the value stored + * on the server node. The server updates \a y by `h(x, &y)`. The default \a h + * is ASSIGN, namely `*y = x`. + * + * The handle is triggered in two ways: + * + * - online: \a h is called every time when \a x is received from a worker. It + * is often used for asynchronous optimization. + * + * - batch: \a h is called after data have been aggregated over all + * workers. Assume \f$ x_i \f$ is received from worker i. Then the server + * first computes \f$\sum_{i=0}^n x = x_i\f$, and then applies \a h. It is often + * used for synchronous optimization * * \param batch true for batch, false for online + * \param h user-defined handle, default is assign */ - explicit Server(bool batch = true); + explicit Server(bool batch = true, const Handle& h = Handle()); /** * \brief Load from disk @@ -87,14 +148,7 @@ class Server { */ void Save(dmlc::Stream *fo); }; - -/** - * \brief user-defined handle - * \param recv_data data (gradient) received from users - * \param my_data data (weight) maintained on the server - */ -void ServerHandle(const NArray& recv_data, NArray my_data); - +#endif // DMLC_USE_CXX11 } // namespace ps } // namespace mxnet From 022e6a4728aa8781cedd751b1e2f611db6c3ea78 Mon Sep 17 00:00:00 2001 From: muli Date: Sun, 6 Sep 2015 17:10:06 -0400 Subject: [PATCH 04/18] add ps::node --- include/mxnet/ps.h | 36 ++++++++++++++++++++++++++++++++++-- 1 file changed, 34 insertions(+), 2 deletions(-) diff --git a/include/mxnet/ps.h b/include/mxnet/ps.h index d694c661b390..b3f35bbd8741 100644 --- a/include/mxnet/ps.h +++ b/include/mxnet/ps.h @@ -40,7 +40,8 @@ namespace ps { * \code * // on worker node: * NArray weight, grad; - * if (NodeInfo::Root()) { + * Worker comm; + * if (comm.Rank() == 0) { * // init weight ... * comm.Push(0, weight); * } @@ -109,7 +110,7 @@ class Worker { * A server node maintains data (weight or aggregated gradient), and allows * user-defined handle to modify the data */ -class Server { +class Server : public Node { public: /** * \brief user-defined handle @@ -150,6 +151,37 @@ class Server { }; #endif // DMLC_USE_CXX11 +/** + * \brief A PS node + */ +class Node { + public: + Node() {} + virtual ~Node() {} + + /** + * \brief Gets rank of this node in its group + * + * The rank is an integer in [0, \ref WorldlSize). + */ + int Rank(); + + /*! \brief Get the size of the node group. */ + int GroupSize() { return IsWorker() ? NumWorkers() : NumServer(); } + + /*! \brief Returns the number of worker nodes */ + static int NumWorkers(); + + /*! \brief Returns the number of server nodes */ + static int NumServers(); + + /*! \brief Returns true if this process runs workers */ + static bool IsWorker(); + + /*!\brief Returns true if this process only run servers */ + static bool IsServer(); +}; + } // namespace ps } // namespace mxnet #endif // MXNET_PS_H_ From 78ac41bd2021054c09dd00700932ebcb98296f27 Mon Sep 17 00:00:00 2001 From: muli Date: Sun, 6 Sep 2015 17:27:58 -0400 Subject: [PATCH 05/18] more doc --- include/mxnet/ps.h | 62 ++++++++++++++++++++++------------------------ 1 file changed, 29 insertions(+), 33 deletions(-) diff --git a/include/mxnet/ps.h b/include/mxnet/ps.h index b3f35bbd8741..63b83176c14b 100644 --- a/include/mxnet/ps.h +++ b/include/mxnet/ps.h @@ -15,6 +15,31 @@ namespace mxnet { namespace ps { +/*! \brief A PS node */ +class Node { + public: + Node() {} + virtual ~Node() {} + + /*! \brief Gets rank of this node in its group, which is in [0, GroupSize) */ + int Rank(); + + /*! \brief Get the size of this node group. */ + int GroupSize() { return IsWorker() ? NumWorkers() : NumServer(); } + + /*! \brief Returns the number of worker nodes */ + static int NumWorkers(); + + /*! \brief Returns the number of server nodes */ + static int NumServers(); + + /*! \brief Returns true if this process runs workers */ + static bool IsWorker(); + + /*!\brief Returns true if this process only run servers */ + static bool IsServer(); +}; + /*! * \brief A PS worker node * @@ -60,7 +85,7 @@ namespace ps { * Server store(false, updater); * \endcode */ -class Worker { +class Worker : public Node { public: /*! * \brief push data to the server nodes @@ -76,6 +101,9 @@ class Worker { * the value sent to the one maintained by itself. See \ref Server for more * details. * + * For a given \a key, the \a value should be always has the same size over + * all workers. + * * \param key the key for pushing * \param value the value for pushing */ @@ -98,8 +126,6 @@ class Worker { * \param value data for pulling, should be pre-allocated */ void Pull(int key, NArray* value); - - private: }; @@ -151,36 +177,6 @@ class Server : public Node { }; #endif // DMLC_USE_CXX11 -/** - * \brief A PS node - */ -class Node { - public: - Node() {} - virtual ~Node() {} - - /** - * \brief Gets rank of this node in its group - * - * The rank is an integer in [0, \ref WorldlSize). - */ - int Rank(); - - /*! \brief Get the size of the node group. */ - int GroupSize() { return IsWorker() ? NumWorkers() : NumServer(); } - - /*! \brief Returns the number of worker nodes */ - static int NumWorkers(); - - /*! \brief Returns the number of server nodes */ - static int NumServers(); - - /*! \brief Returns true if this process runs workers */ - static bool IsWorker(); - - /*!\brief Returns true if this process only run servers */ - static bool IsServer(); -}; } // namespace ps } // namespace mxnet From dd97baae66ce031e242de996de29c447c4a69a82 Mon Sep 17 00:00:00 2001 From: muli Date: Sun, 6 Sep 2015 20:20:08 -0400 Subject: [PATCH 06/18] simplify ps.h --- include/mxnet/ps.h | 147 +++++++++++++++------------------------------ 1 file changed, 50 insertions(+), 97 deletions(-) diff --git a/include/mxnet/ps.h b/include/mxnet/ps.h index 63b83176c14b..d34c0ffd91d5 100644 --- a/include/mxnet/ps.h +++ b/include/mxnet/ps.h @@ -13,96 +13,69 @@ #endif // DMLC_USE_CXX11 namespace mxnet { -namespace ps { -/*! \brief A PS node */ -class Node { - public: - Node() {} - virtual ~Node() {} - - /*! \brief Gets rank of this node in its group, which is in [0, GroupSize) */ - int Rank(); - - /*! \brief Get the size of this node group. */ - int GroupSize() { return IsWorker() ? NumWorkers() : NumServer(); } - - /*! \brief Returns the number of worker nodes */ - static int NumWorkers(); - - /*! \brief Returns the number of server nodes */ - static int NumServers(); - - /*! \brief Returns true if this process runs workers */ - static bool IsWorker(); - - /*!\brief Returns true if this process only run servers */ - static bool IsServer(); -}; - -/*! - * \brief A PS worker node +/** + * \brief distributed key-value store * - * Worker node can push data (gradient) to the servers and pull data (aggregated - * gradient or weight) back. A worker is bind to a particular device, namely a - * worker can only push and pull data with the same \a device_id + * A distributed key-value store which synchronous data over multiple devices + * and multiple machines. It supports user-defined updaters * * Example to implement allreduce * \code - * // on worker node: * NArray data; * // init data... - * Worker comm; - * comm.Push(0, data); - * comm.Pull(0, &data); + * KVStore store; + * store.Push(0, data); + * store.Pull(0, &data); * data.Wait(); - * - * // on server node: - * Server store; * \endcode * * Example to implement asynchronous SGD * \code - * // on worker node: - * NArray weight, grad; - * Worker comm; - * if (comm.Rank() == 0) { - * // init weight ... - * comm.Push(0, weight); + * Worker store; + * auto updater = [](const NArray& recv, NArray* weight) { + * *weight += 0.1 * recv; // recv is grad * } - * comm.Pull(0, &weight); - * // compute grad - * comm.Push(0, grad); + * store.Register(false, updater); * - * // on server node: - * auto updater = [](const NArray& recv, NArray* weight) { - * if (weight->Empty()) { - * *weight = recv; // recv is the init weight - * } else { - * *weight += 0.1 * recv; // recv is grad - * } + * NArray weight, grad; + * if (store.GetRank() == 0) { + * store.Init(0, weight); * } - * Server store(false, updater); - * \endcode + * store.Pull(0, &weight); + * // compute grad + * store.Push(0, grad); */ -class Worker : public Node { +class KVStore { public: + + /*! \brief Gets rank of this node in its group, which is in [0, GroupSize) */ + int GetRank(); + + /*! \brief Get the size of this node group. */ + int GetGroupSize(); + + /** + * \brief Init data + * + * Init \a key with \a value. + */ + void Init(int key, const NArray& value); + /*! - * \brief push data to the server nodes + * \brief push data to the store * - * Push the key-value pair (\a key, \a value) to the server nodes. This + * Push the key-value pair (\a key, \a value) to the store. This * function returns after adding a push operator to the engine. Any following * operator requiring writing \a value will be blocked until the actual push is * finished. * * One can wait the push is finished via `data.Wait()` * - * For each push, each server node will apply a user-defined server handle to merge - * the value sent to the one maintained by itself. See \ref Server for more - * details. + * For each push, a user-defined updater is called to merge + * the value sent to the one maintained by itself. * * For a given \a key, the \a value should be always has the same size over - * all workers. * * \param key the key for pushing * \param value the value for pushing @@ -112,43 +85,35 @@ class Worker : public Node { /*! * \brief pull data from the server nodes * - * Pull the \a value associated with the \a key from the servers. This + * Pull the \a value associated with the \a key from the store. This * function returns after adding a pull operator to the engine. Any following * operator requiring reading \a data will be blocked until the actual pull is * finished. * * One can wait the pull is finished via `data.Wait()` * - * System will guarantee that the all pushes issued by this worker have been - * applied, namely the server handle has been triggered. + * Before sending back the value, the store will wait all pushed issued by + * this worker on \a key have been applied (updater has been triggered) + * and \a value is initialized * * \param key the key for pulling * \param value data for pulling, should be pre-allocated */ void Pull(int key, NArray* value); -}; - #if DMLC_USE_CXX11 -/** - * \brief A PS server node - * - * A server node maintains data (weight or aggregated gradient), and allows - * user-defined handle to modify the data - */ -class Server : public Node { - public: /** - * \brief user-defined handle + * \brief user-defined updater */ - using Handle = std::function; + using Updater = std::function; /** - * \brief constructor + * \brief register updater * - * Given a key, assume \a x is the received value and \a y is the value stored - * on the server node. The server updates \a y by `h(x, &y)`. The default \a h - * is ASSIGN, namely `*y = x`. + * The server allows user-defined handle to modify the data. Given a key, + * assume \a x is the received value and \a y is the value stored on the server + * node. The server updates \a y by `h(x, &y)`. The default \a h is ASSIGN, + * namely `*y = x`. * * The handle is triggered in two ways: * @@ -161,23 +126,11 @@ class Server : public Node { * used for synchronous optimization * * \param batch true for batch, false for online - * \param h user-defined handle, default is assign - */ - explicit Server(bool batch = true, const Handle& h = Handle()); - - /** - * \brief Load from disk + * \param updt user-defined updater, default is assign */ - void Load(dmlc::Stream *fi); - - /** - * \brief Save to disk - */ - void Save(dmlc::Stream *fo); -}; + void Register(bool batch = true, const Updater& updt = Updater()); #endif // DMLC_USE_CXX11 +}; - -} // namespace ps } // namespace mxnet #endif // MXNET_PS_H_ From da480241a627e4d980d99c074acedb1e2b41a747 Mon Sep 17 00:00:00 2001 From: muli Date: Sun, 6 Sep 2015 21:28:06 -0400 Subject: [PATCH 07/18] tiny --- include/mxnet/ps.h | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/include/mxnet/ps.h b/include/mxnet/ps.h index d34c0ffd91d5..b78e81c72248 100644 --- a/include/mxnet/ps.h +++ b/include/mxnet/ps.h @@ -49,18 +49,22 @@ namespace mxnet { class KVStore { public: - /*! \brief Gets rank of this node in its group, which is in [0, GroupSize) */ - int GetRank(); + /** + * \brief get singleton instance + */ + static KVStore* Get(); - /*! \brief Get the size of this node group. */ - int GetGroupSize(); + /** + * \brief Init with the local devices + */ + void Init(const std::vector& devices); /** - * \brief Init data + * \brief data * - * Init \a key with \a value. + * insert a key-value pair. One must insert before push and pull */ - void Init(int key, const NArray& value); + void Insert(int key, const NArray& value); /*! * \brief push data to the store @@ -130,6 +134,13 @@ class KVStore { */ void Register(bool batch = true, const Updater& updt = Updater()); #endif // DMLC_USE_CXX11 + + /*! \brief Gets rank of this node in its group, which is in [0, GroupSize) */ + int GetRank(); + + /*! \brief Get the number of nodes in this group. */ + int GetGroupSize(); + }; } // namespace mxnet From 3f199e50ab8164820e1bce92d6e9f1524bef7847 Mon Sep 17 00:00:00 2001 From: muli Date: Mon, 7 Sep 2015 14:45:44 -0400 Subject: [PATCH 08/18] add python test --- include/mxnet/{ps.h => kvstore.h} | 4 +- python/mxnet/kvstore.py | 58 ++++++++++++ tests/python/test_mlp_multi_devices.py | 124 +++++++++++++++++++++++++ 3 files changed, 184 insertions(+), 2 deletions(-) rename include/mxnet/{ps.h => kvstore.h} (96%) create mode 100644 python/mxnet/kvstore.py create mode 100644 tests/python/test_mlp_multi_devices.py diff --git a/include/mxnet/ps.h b/include/mxnet/kvstore.h similarity index 96% rename from include/mxnet/ps.h rename to include/mxnet/kvstore.h index b78e81c72248..c10dc1382c08 100644 --- a/include/mxnet/ps.h +++ b/include/mxnet/kvstore.h @@ -17,8 +17,8 @@ namespace mxnet { /** * \brief distributed key-value store * - * A distributed key-value store which synchronous data over multiple devices - * and multiple machines. It supports user-defined updaters + * A distributed key-value store for data synchronization over multiple + * devices/machines. It supports user-defined updater * * Example to implement allreduce * \code diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py new file mode 100644 index 000000000000..cf45569cb1ad --- /dev/null +++ b/python/mxnet/kvstore.py @@ -0,0 +1,58 @@ +# coding: utf-8 +""" KVStore in mxnet """ + +from __future__ import absolute_import + +def init(contexts): + """ Init key-value store with a list of context + + Parameters + ---------- + contexts : list of Context + The list of local devices used by this process + """ + +def insert(key, value): + """ Insert a key-value pair into the store + + Parameters + ---------- + key : int + The key + value : NArray + The value + """ + +def push(key, value): + """ Push a value into the store + + Parameters + ---------- + key : int + The key + value : NArray + The value + """ + +def pull(key, value): + """ Pull the value from the store + + Parameters + ---------- + key : int + The key + value : NArray + The value + """ + +def register(updater): + """ Register a updater into the store + + Example: + def Update(grad, weight): + weight[:] -= lr * grad / batch_size + + Parameters + ---------- + + """ diff --git a/tests/python/test_mlp_multi_devices.py b/tests/python/test_mlp_multi_devices.py new file mode 100644 index 000000000000..5859476dece5 --- /dev/null +++ b/tests/python/test_mlp_multi_devices.py @@ -0,0 +1,124 @@ +# pylint: skip-file +import sys +sys.path.append('../../python/') + +import mxnet as mx +import numpy as np +import os, gzip +import pickle as pickle +import get_data + +# symbol net +data = mx.symbol.Variable('data') +fc1 = mx.symbol.FullyConnected(data = data, name='fc1', nb_hidden=128) +act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu") +fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', nb_hidden = 64) +act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu") +fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', nb_hidden=10) +mlp = mx.symbol.Softmax(data = fc3, name = 'mlp') + +# use multiple devices +num_devs = 2 +devs = [mx.Context('cpu', i) for i in range(num_devs)] + +# infer shape +batch_size = 100 +input_shape = (batch_size / num_devs, 784) +param_shapes, out_shapes, aux_shapes = mlp.infer_shape(data=input_shape) +param_names = mlp.list_arguments() + +# allocate memory +params = [[mx.narray.create(s, d) for s in param_shapes] for d in devs]; +grads = [[mx.narray.create(s, d) for s in param_shapes] for d in devs]; + +# only need to init param on device 0 +mx.kvstore.init(devs) + +np.random.seed(0) +for i, v in enumerate(params[0]): + if "weight" in param_names[i]: + v.numpy[:, :] = np.random.uniform(-0.07, 0.07, v.numpy.shape) + mx.kvstore.insert(i, v) + if "bias" in param_names[i]: + v.numpy[:] = 0.0 + mx.kvstore.insert(i, v) + +# register param updater +def make_updater(env): + def updater(grad, weight): + eta = env['lr'] / sqrt(env['iter']) / env['batch_size'] + env['iter'] += 1 + weight[:] -= eta * grad + return updater + +mx.kvstore.register(make_updater( + {'lr' : 0.1, 'batch_size' : batch_size, 'wd' : .00004})) + +# create exector for each device + +req = ['write_to' for i in range(len(param_names))] +executors = [mlp.bind(devs[i], params[i], grads[i], req) for i in range(num_devs)] +forward_out = [mx.narray.create(e.heads()[0].shape) for e in executors] + +# data reader +get_data.GetMNIST_ubyte() +train_dataiter = mx.io.MNISTIter( + image="data/train-images-idx3-ubyte", + label="data/train-labels-idx1-ubyte", + batch_size=batch_size, shuffle=True, flat=True, silent=False, seed=10) +val_dataiter = mx.io.MNISTIter( + image="data/t10k-images-idx3-ubyte", + label="data/t10k-labels-idx1-ubyte", + batch_size=batch_size, shuffle=True, flat=True, silent=False) + +def cal_acc(out, label): + pred = np.argmax(out, axis=1) + return np.sum(pred == label) * 1.0 / out.shape[0] + +def test_mlp(): + epoch = 9 + acc_train = 0. + acc_val = 0. + for i in range(epoch): + # train + print("Epoch %d" % i) + train_acc = 0.0 + for data, label in train_dataiter: + data = data.numpy + label = label.numpy.flatten() + + for d in range(num_devs): + # feed input + k = batch_size / num_devs + idx = range(d*k, (d+1)*k) + params[d][param_names.index('data')].numpy[:] = data[idx,:] + params[d][param_names.index('mlp_label')].numpy[:] = label[idx] + + # pull weight + for j, m in enumerate(param_names): + if 'weight' in m or 'bias' in m: + mx.kvstore.pull(i, params[d][j]) + + # forward and backward + executors[d].forward() + # TODO copyto should not block execution? + executors[d].heads()[0].copyto(forward_out[d]) + executors[d].backward([forward_out[d]]) + + # push gradient + for j, m in enumerate(param_names): + if 'weight' in m or 'bias' in m: + mx.kvstore.pull(i, grads[d][j]) + + # TODO should evalute accuray here? otherwise the above forloop will not + # be paralleled? + + train_acc /= train_nbatch + train_nbatch += 1 + print("Train Acc: ", train_acc) + train_dataiter.reset() + + assert(acc_train > 0.98) + +if __name__ == "__main__": + test_mlp() From 13b4aba4615da33cbe3d90ebad6d0c2f86f86083 Mon Sep 17 00:00:00 2001 From: muli Date: Mon, 7 Sep 2015 15:18:11 -0400 Subject: [PATCH 09/18] update --- tests/python/test_mlp_multi_devices.py | 34 ++++++++++++-------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/tests/python/test_mlp_multi_devices.py b/tests/python/test_mlp_multi_devices.py index 5859476dece5..7c321decc6cf 100644 --- a/tests/python/test_mlp_multi_devices.py +++ b/tests/python/test_mlp_multi_devices.py @@ -32,16 +32,15 @@ grads = [[mx.narray.create(s, d) for s in param_shapes] for d in devs]; # only need to init param on device 0 -mx.kvstore.init(devs) - +mx.kvstore.init_devices(devs) +sync_keys = [i for i,m in enumerate(param_names) if "weight" in m or "bias" in m] np.random.seed(0) -for i, v in enumerate(params[0]): - if "weight" in param_names[i]: - v.numpy[:, :] = np.random.uniform(-0.07, 0.07, v.numpy.shape) - mx.kvstore.insert(i, v) - if "bias" in param_names[i]: - v.numpy[:] = 0.0 - mx.kvstore.insert(i, v) +for k in sync_keys: + if "weight" in param_names[k]: + params[0][k].numpy[:, :] = np.random.uniform(-0.07, 0.07, v.numpy.shape) + else: + params[0][k].numpy[:] = 0 +mx.kvstore.init([(k,params[0][k]) for k in sync_keys]) # register param updater def make_updater(env): @@ -86,32 +85,29 @@ def test_mlp(): for data, label in train_dataiter: data = data.numpy label = label.numpy.flatten() + k = batch_size / num_devs for d in range(num_devs): # feed input - k = batch_size / num_devs idx = range(d*k, (d+1)*k) params[d][param_names.index('data')].numpy[:] = data[idx,:] params[d][param_names.index('mlp_label')].numpy[:] = label[idx] # pull weight - for j, m in enumerate(param_names): - if 'weight' in m or 'bias' in m: - mx.kvstore.pull(i, params[d][j]) + mx.kvstore.pull([(k,params[d][k]) for k in sync_keys]) # forward and backward executors[d].forward() - # TODO copyto should not block execution? executors[d].heads()[0].copyto(forward_out[d]) executors[d].backward([forward_out[d]]) # push gradient - for j, m in enumerate(param_names): - if 'weight' in m or 'bias' in m: - mx.kvstore.pull(i, grads[d][j]) + mx.kvstore.push([(k, grads[d][k]) for k in sync_keys]) - # TODO should evalute accuray here? otherwise the above forloop will not - # be paralleled? + # evaluate. cannot put into the above for loop since it is blocked + # until all forwards are finished + for d in range(num_devs): + train_acc += cal_acc(forward_out[d].numpy, label[range(d*k, (d+1)*k)]) train_acc /= train_nbatch train_nbatch += 1 From 69576b8bb9f32ec282c90ce70869e5176f78237a Mon Sep 17 00:00:00 2001 From: muli Date: Mon, 7 Sep 2015 16:10:24 -0400 Subject: [PATCH 10/18] minior --- python/mxnet/kvstore.py | 22 ++++++++-------------- tests/python/test_mlp_multi_devices.py | 6 +++--- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index cf45569cb1ad..6db2a7a4e07f 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -3,8 +3,8 @@ from __future__ import absolute_import -def init(contexts): - """ Init key-value store with a list of context +def init_devices(contexts): + """ Init key-value store with a list of device contexts Parameters ---------- @@ -12,29 +12,23 @@ def init(contexts): The list of local devices used by this process """ -def insert(key, value): - """ Insert a key-value pair into the store +def init(kv_list): + """ Initialize a list of key-value pairs Parameters ---------- - key : int - The key - value : NArray - The value + kv_list : tuple or list/generator of tuples + a key-value tuple or a list of key-value tuples """ -def push(key, value): +def push(kv_list): """ Push a value into the store Parameters ---------- - key : int - The key - value : NArray - The value """ -def pull(key, value): +def pull(kv_list): """ Pull the value from the store Parameters diff --git a/tests/python/test_mlp_multi_devices.py b/tests/python/test_mlp_multi_devices.py index 7c321decc6cf..7a2e7ce1938a 100644 --- a/tests/python/test_mlp_multi_devices.py +++ b/tests/python/test_mlp_multi_devices.py @@ -40,7 +40,7 @@ params[0][k].numpy[:, :] = np.random.uniform(-0.07, 0.07, v.numpy.shape) else: params[0][k].numpy[:] = 0 -mx.kvstore.init([(k,params[0][k]) for k in sync_keys]) +mx.kvstore.init((k,params[0][k]) for k in sync_keys) # register param updater def make_updater(env): @@ -94,7 +94,7 @@ def test_mlp(): params[d][param_names.index('mlp_label')].numpy[:] = label[idx] # pull weight - mx.kvstore.pull([(k,params[d][k]) for k in sync_keys]) + mx.kvstore.pull((k,params[d][k]) for k in sync_keys) # forward and backward executors[d].forward() @@ -102,7 +102,7 @@ def test_mlp(): executors[d].backward([forward_out[d]]) # push gradient - mx.kvstore.push([(k, grads[d][k]) for k in sync_keys]) + mx.kvstore.push((k, grads[d][k]) for k in sync_keys) # evaluate. cannot put into the above for loop since it is blocked # until all forwards are finished From 3f5334dd5d0b9aab0d3e931bf0b9c3a8dd1704a1 Mon Sep 17 00:00:00 2001 From: muli Date: Tue, 8 Sep 2015 13:23:06 -0400 Subject: [PATCH 11/18] kvstore_local --- include/mxnet/context.h | 18 ++++- include/mxnet/kvstore.h | 23 ++++--- src/kvstore/kvstore.cc | 44 +++++++++++++ src/kvstore/kvstore_base.h | 130 +++++++++++++++++++++++++++++++++++++ 4 files changed, 206 insertions(+), 9 deletions(-) create mode 100644 src/kvstore/kvstore.cc create mode 100644 src/kvstore/kvstore_base.h diff --git a/include/mxnet/context.h b/include/mxnet/context.h index 02f10231df1d..3aa57f5b55f5 100644 --- a/include/mxnet/context.h +++ b/include/mxnet/context.h @@ -5,7 +5,7 @@ */ #ifndef MXNET_CONTEXT_H_ #define MXNET_CONTEXT_H_ - +#include #include #include #include "./base.h" @@ -61,6 +61,22 @@ struct Context { if (strm->Read(&dev_id, sizeof(int32_t)) != sizeof(int32_t)) return false; return true; } + + /** + * \brief returns an unique ID + */ + inline uint64_t UID() const { + return static_cast(dev_mask) << 32 | dev_id; + } + + /** + * \brief returns an unique string name + */ + inline std::string Name() const { + std::stringstream ss; + ss << (dev_mask == cpu::kDevMask ? "cpu" : "gpu") << ":" << dev_id; + return ss.str(); + } }; /*! diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h index c10dc1382c08..730727024981 100644 --- a/include/mxnet/kvstore.h +++ b/include/mxnet/kvstore.h @@ -1,7 +1,7 @@ /*! * Copyright (c) 2015 by Contributors - * \file ps.h - * \brief parameter server interface for mxnet + * \file kvstore.h + * \brief key-value store interface for mxnet */ #ifndef MXNET_PS_H_ #define MXNET_PS_H_ @@ -14,6 +14,9 @@ namespace mxnet { +/*! \brief forward declaration */ +class KVStoreBase; + /** * \brief distributed key-value store * @@ -48,23 +51,22 @@ namespace mxnet { */ class KVStore { public: - /** * \brief get singleton instance */ - static KVStore* Get(); + static KVStore* Get() { static KVStore store; return &store; } /** * \brief Init with the local devices */ - void Init(const std::vector& devices); + void InitDevices(const std::vector& devices); /** * \brief data * - * insert a key-value pair. One must insert before push and pull + * init a key-value pair. One must insert before push and pull */ - void Insert(int key, const NArray& value); + void Init(int key, const NArray& value); /*! * \brief push data to the store @@ -132,7 +134,7 @@ class KVStore { * \param batch true for batch, false for online * \param updt user-defined updater, default is assign */ - void Register(bool batch = true, const Updater& updt = Updater()); + void Register(bool batch = true, const Updater& updt = Updater()) { } #endif // DMLC_USE_CXX11 /*! \brief Gets rank of this node in its group, which is in [0, GroupSize) */ @@ -141,6 +143,11 @@ class KVStore { /*! \brief Get the number of nodes in this group. */ int GetGroupSize(); + private: + DISALLOW_COPY_AND_ASSIGN(KVStore); + KVStore() : store_(NULL) { } + ~KVStore(); + KVStoreBase* store_; }; } // namespace mxnet diff --git a/src/kvstore/kvstore.cc b/src/kvstore/kvstore.cc new file mode 100644 index 000000000000..146c0d14c6ac --- /dev/null +++ b/src/kvstore/kvstore.cc @@ -0,0 +1,44 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file kvstore.cc + * \brief implement kv_store + */ +#include "mxnet/kvstore.h" +#include "kvstore_base.h" +#include +#include "dmlc/logging.h" + +namespace mxnet { + +void KVStore::InitDevices(const std::vector& devices) { + char* num_worker = getenv("DMLC_NUM_WORKER"); + if (num_worker == NULL || atoi(num_worker) == 1) { + store_ = new KVStoreBase(); + // local model + } else { + LOG(FATAL) << "not implemented yet"; + } + store_->InitDevices(devices); +} + +void KVStore::Init(int key, const NArray& value) { + CHECK(store_ != NULL) << "call InitDevices first"; + store_->Push(key, value, true); +} + +void KVStore::Push(int key, const NArray& value) { + CHECK(store_ != NULL) << "call InitDevices first"; + store_->Push(key, value, false); +} + +void KVStore::Pull(int key, NArray* value) { + CHECK(store_ != NULL) << "call InitDevices first"; + store_->Pull(key, value); +} + +int KVStore::GetRank() { return store_->GetRank(); } +int KVStore::GetGroupSize() { return store_->GetGroupSize(); } + +KVStore::~KVStore() { delete store_; } + +} // namespace mxnet diff --git a/src/kvstore/kvstore_base.h b/src/kvstore/kvstore_base.h new file mode 100644 index 000000000000..ee2c3d3872f5 --- /dev/null +++ b/src/kvstore/kvstore_base.h @@ -0,0 +1,130 @@ +/** + * Copyright (c) 2015 by Contributors + * @file kvstore_base.h + * @brief local implementation + */ +#ifndef MXNET_KVSTORE_BASE_H_ +#define MXNET_KVSTORE_BASE_H_ +#include +#include +#include "mxnet/narray.h" +#include "mxnet/dag_engine.h" + +namespace mxnet { + +/** + * \brief store data in local machine + */ +class KVStoreBase { + public: + typedef int Key; + KVStoreBase() : inited_(false), engine_(DAGEngine::Get()), aggregate_(true) { } + virtual ~KVStoreBase() { } + virtual void InitDevices(const std::vector& devices) { + CHECK(!inited_) << "double intializatino"; + num_devs_ = 0; + for (auto d : devices) devs_[d.UID()] = num_devs_ ++; + inited_ = true; + } + + virtual void Push(Key key, const NArray& value, bool init) { + CHECK(inited_) << "call InitDevices first"; + auto it = local_.find(key); + if (init) { + CHECK(it == local_.end()) << "duplicate init of key = " << key; + Value val(num_devs_, value.Copy(local_ctx_)); + local_.insert({key, val}).first; + return; + } + CHECK(it != local_.end()) << "key " << key << " has not been inited"; + auto& local_val = it->second; + CHECK_EQ(local_val.arr.shape(), value.shape()) + << "shape mismatch: " << local_val.arr.shape() << ", " << value.shape(); + if (aggregate_) { + int dix = GetDevIdx(value.ctx()); + CHECK(!local_val.pending_push[dix]) + << "duplicate push on key " << key << "from " << value.ctx().Name(); + local_val.pending_push[dix] = true; + local_val.pending_push_arr.push_back(value); + if (local_val.pending_push_arr.size() == num_devs_) { + // do copy for the clossure + std::vector read; + std::swap(read, local_val.pending_push_arr); + std::vector read_val; + for (const auto& r : read) read_val.push_back(r.var()); + NArray write = local_val.arr; + + // issue push to engine + engine_->Push([this, read, write](RunContext rctx) mutable { + for (const auto& r : read) write += r; + }, local_ctx_, read_val, {write.var()}); + + // issue pull if necessary + for (auto& w : local_val.pending_pull_arr) { + CopyFromTo(local_val.arr, &w); + } + + // clean + local_val.pending_push.flip(); + local_val.pending_pull_arr.clear(); + } + } else { + LOG(FATAL) << "TODO"; + } + } + + virtual void Pull(Key key, NArray* value) { + CHECK(inited_) << "call InitDevices first"; + + auto it = local_.find(key); + CHECK(it != local_.end()) << "key " << key << " has not been inited"; + auto& local_val = it->second; + CHECK_EQ(local_val.arr.shape(), value->shape()) + << "shape mismatch: " << local_val.arr.shape() << ", " << value->shape(); + + if (aggregate_) { + int dix = GetDevIdx(value->ctx()); + if (local_val.pending_push[dix]) { + local_val.pending_pull_arr.push_back(*value); + return; + } + CopyFromTo(local_val.arr, value); + } + } + + virtual int GetRank() { return 0; } + virtual int GetGroupSize() { return 1; } + + protected: + /// get the continous device index starting from 0 + inline int GetDevIdx(const Context& ctx) { + auto it = devs_.find(ctx.UID()); + CHECK(it != devs_.end()) + << "unknow device " << ctx.Name(); + return it->second; + } + bool inited_; + DAGEngine* engine_; + bool aggregate_; + + /// map a device into an index + size_t num_devs_; + std::unordered_map devs_; + + /// internal storage of a value + struct Value { + Value() {} + Value(int num_devs, NArray data) + : pending_push(num_devs, false), pending_pull(num_devs, false) { + arr = data; + } + std::vector pending_push, pending_pull; + std::vector pending_push_arr, pending_pull_arr; + NArray arr; + }; + Context local_ctx_; + std::unordered_map local_; +}; + +} // namespace mxnet +#endif // MXNET_KVSTORE_BASE_H_ From 430cca7a0fc8501238acb5be2de02f542e811ac7 Mon Sep 17 00:00:00 2001 From: muli Date: Tue, 8 Sep 2015 16:49:03 -0400 Subject: [PATCH 12/18] passed test_kvstore.py --- include/mxnet/c_api.h | 37 ++++++++++++++++++++++++++++++++++++ python/mxnet/__init__.py | 3 +-- python/mxnet/kvstore.py | 36 +++++++++++++++++++++++++++++++++-- src/c_api.cc | 30 +++++++++++++++++++++++++++++ src/kvstore/kvstore.cc | 2 +- src/narray/narray.cc | 10 ++++++++-- tests/python/test_kvstore.py | 28 +++++++++++++++++++++++++++ 7 files changed, 139 insertions(+), 7 deletions(-) create mode 100644 tests/python/test_kvstore.py diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index d43e0576fab3..3d3461b99c01 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -707,5 +707,42 @@ MXNET_DLL int MXDataIterGetData(DataIterHandle handle, */ MXNET_DLL int MXDataIterGetLabel(DataIterHandle handle, NArrayHandle *out); +/*! + * \brief initialize the kvstore + * \param num_devs number of devices + * \param dev_masks the list of device masks + * \param dev_ids the list of device IDs + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXKVStoreInitDevices(mx_uint num_devs, + int *dev_masks, + int *dev_ids); +/*! + * \brief Init (key,value) in kvstore + * \param key the int key + * \param value the NArray value + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXKVStoreInit(mx_uint key, + NArrayHandle value); + +/*! + * \brief Push (key,value) to kvstore + * \param key the int key + * \param value the NArray value + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXKVStorePush(mx_uint key, + NArrayHandle value); + + +/*! + * \brief pull value from kvstore on the given key + * \param key the int key + * \param value the NArray value + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXKVStorePull(mx_uint key, + NArrayHandle value); #endif // MXNET_C_API_H_ diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index a8632bfa2ff8..78b986e9c6e0 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -12,8 +12,7 @@ from .base import MXNetError from . import narray from . import symbol +from . import kvstore from . import io __version__ = "0.1.0" - - diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index 6db2a7a4e07f..a92c0f194f2b 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -1,7 +1,11 @@ # coding: utf-8 """ KVStore in mxnet """ - from __future__ import absolute_import +import ctypes +from .narray import NArray +from .context import Context +from .base import _LIB +from .base import check_call, c_array def init_devices(contexts): """ Init key-value store with a list of device contexts @@ -11,6 +15,9 @@ def init_devices(contexts): contexts : list of Context The list of local devices used by this process """ + masks = c_array(ctypes.c_int, [c.device_mask for c in contexts]) + ids = c_array(ctypes.c_int, [c.device_id for c in contexts]) + check_call(_LIB.MXKVStoreInitDevices(len(contexts), masks, ids)) def init(kv_list): """ Initialize a list of key-value pairs @@ -18,8 +25,17 @@ def init(kv_list): Parameters ---------- kv_list : tuple or list/generator of tuples - a key-value tuple or a list of key-value tuples + a key-value tuple or a list of key-value tuples, where key is int and + key is """ + if isinstance(kv_list, tuple): + init([kv_list]) + else: + for kv in kv_list: + assert len(kv) == 2 + assert isinstance(kv[0], int) + assert isinstance(kv[1], NArray) + check_call(_LIB.MXKVStoreInit(kv[0], kv[1].handle)) def push(kv_list): """ Push a value into the store @@ -27,6 +43,14 @@ def push(kv_list): Parameters ---------- """ + if isinstance(kv_list, tuple): + push([kv_list]) + else: + for kv in kv_list: + assert len(kv) == 2 + assert isinstance(kv[0], int) + assert isinstance(kv[1], NArray) + check_call(_LIB.MXKVStorePush(kv[0], kv[1].handle)) def pull(kv_list): """ Pull the value from the store @@ -38,6 +62,14 @@ def pull(kv_list): value : NArray The value """ + if isinstance(kv_list, tuple): + pull([kv_list]) + else: + for kv in kv_list: + assert len(kv) == 2 + assert isinstance(kv[0], int) + assert isinstance(kv[1], NArray) + check_call(_LIB.MXKVStorePull(kv[0], kv[1].handle)) def register(updater): """ Register a updater into the store diff --git a/src/c_api.cc b/src/c_api.cc index 2e59613829bc..97fc83c2da68 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -844,3 +845,32 @@ int MXDataIterGetData(DataIterHandle handle, NArrayHandle *out) { *out = new NArray(db.data[0], 0); API_END(); } + +int MXKVStorePush(mx_uint key, NArrayHandle value) { + API_BEGIN(); + KVStore::Get()->Push(key, *static_cast(value)); + API_END(); +} + +int MXKVStoreInit(mx_uint key, NArrayHandle value) { + API_BEGIN(); + KVStore::Get()->Init(key, *static_cast(value)); + API_END(); +} + + +int MXKVStorePull(mx_uint key, NArrayHandle value) { + API_BEGIN(); + KVStore::Get()->Pull(key, static_cast(value)); + API_END(); +} + +int MXKVStoreInitDevices(mx_uint num_devs, int *dev_masks, int *dev_ids) { + API_BEGIN(); + std::vector devs; + for (mx_uint i = 0; i < num_devs; ++i) { + devs.push_back(Context(dev_masks[i], dev_ids[i])); + } + KVStore::Get()->InitDevices(devs); + API_END(); +} diff --git a/src/kvstore/kvstore.cc b/src/kvstore/kvstore.cc index 146c0d14c6ac..ed5298a99c8b 100644 --- a/src/kvstore/kvstore.cc +++ b/src/kvstore/kvstore.cc @@ -13,8 +13,8 @@ namespace mxnet { void KVStore::InitDevices(const std::vector& devices) { char* num_worker = getenv("DMLC_NUM_WORKER"); if (num_worker == NULL || atoi(num_worker) == 1) { - store_ = new KVStoreBase(); // local model + store_ = new KVStoreBase(); } else { LOG(FATAL) << "not implemented yet"; } diff --git a/src/narray/narray.cc b/src/narray/narray.cc index c9dda3f5a654..e65f4127e65d 100644 --- a/src/narray/narray.cc +++ b/src/narray/narray.cc @@ -25,12 +25,18 @@ template inline void BinaryOp(const NArray &lhs, const NArray &rhs, NArray *out) { - CHECK(lhs.ctx() == rhs.ctx()) << "operands context mismatch"; + // no check if both of them are on cpu + if (lhs.ctx().dev_mask != cpu::kDevMask || rhs.ctx().dev_mask != cpu::kDevMask) + CHECK(lhs.ctx() == rhs.ctx()) << "operands context mismatch"; // if out is none, allocate space if (out->is_none()) { *out = NArray(OP::GetShape(lhs.shape(), rhs.shape()), lhs.ctx(), true); } else { - CHECK(out->ctx() == lhs.ctx()) << "target context mismatch"; + // no check if both of them are on cpu + if (lhs.ctx().dev_mask != cpu::kDevMask || + out->ctx().dev_mask != cpu::kDevMask) { + CHECK(out->ctx() == lhs.ctx()) << "target context mismatch"; + } CHECK(out->shape() == OP::GetShape(lhs.shape(), rhs.shape())) << "target shape mismatch"; } diff --git a/tests/python/test_kvstore.py b/tests/python/test_kvstore.py new file mode 100644 index 000000000000..e0b2221b290a --- /dev/null +++ b/tests/python/test_kvstore.py @@ -0,0 +1,28 @@ +# pylint: skip-file +import sys +sys.path.append('../../python/') + +import mxnet as mx + +num_devs = 3 +devs = [mx.Context('cpu', i) for i in range(num_devs)] +mx.kvstore.init_devices(devs) + +s = (4,4) + +# init +a = mx.narray.empty(s,devs[0]) +a[:] = 1.0 +mx.kvstore.init((3, a)) + +# push +B = [mx.narray.empty(s,d) for d in devs] +for b in B: + b[:] = 2.0 + mx.kvstore.push((3, b)) + +# pull +C = [mx.narray.empty(s,d) for d in devs] +for c in C: + mx.kvstore.pull((3, c)) + print c.asnumpy() From 2d5ae84757129345da2e37e2178f283b953308c6 Mon Sep 17 00:00:00 2001 From: muli Date: Tue, 8 Sep 2015 21:54:29 -0400 Subject: [PATCH 13/18] update kvstore --- include/mxnet/c_api.h | 14 ++++ include/mxnet/kvstore.h | 30 +++++++- python/mxnet/kvstore.py | 16 ++++- src/c_api.cc | 17 +++++ src/kvstore/kvstore.cc | 14 ++++ src/kvstore/kvstore_base.h | 134 +++++++++++++++++++++-------------- tests/python/test_kvstore.py | 3 - 7 files changed, 169 insertions(+), 59 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 3d3461b99c01..9587f273ef84 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -717,6 +717,12 @@ MXNET_DLL int MXDataIterGetLabel(DataIterHandle handle, MXNET_DLL int MXKVStoreInitDevices(mx_uint num_devs, int *dev_masks, int *dev_ids); +/*! + * \brief clear the kvstore + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXKVStoreClear(); + /*! * \brief Init (key,value) in kvstore * \param key the int key @@ -745,4 +751,12 @@ MXNET_DLL int MXKVStorePush(mx_uint key, MXNET_DLL int MXKVStorePull(mx_uint key, NArrayHandle value); +typedef void (MXKVStoreUpdater)(NArrayHandle recv, NArrayHandle local); +/*! + * \brief register an push updater + * \param updater udpater function + * \return 0 when success, -1 when failure happens + */ +MXNET_DLL int MXKVStoreRegister(MXKVStoreUpdater updater); + #endif // MXNET_C_API_H_ diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h index 730727024981..c4cd9e124fe1 100644 --- a/include/mxnet/kvstore.h +++ b/include/mxnet/kvstore.h @@ -107,6 +107,11 @@ class KVStore { */ void Pull(int key, NArray* value); + /** + * \brief clear all data stored, handles registered, and devices binded + */ + void Clear(); + #if DMLC_USE_CXX11 /** * \brief user-defined updater @@ -114,7 +119,16 @@ class KVStore { using Updater = std::function; /** - * \brief register updater + * \brief returns the default updater, which is ASSIGN + */ + static Updater DefaultUpdater() { + return [](const NArray& a, NArray* b) { + CopyFromTo(a, b); + }; + } + + /** + * \brief set an updater * * The server allows user-defined handle to modify the data. Given a key, * assume \a x is the received value and \a y is the value stored on the server @@ -131,12 +145,24 @@ class KVStore { * first computes \f$\sum_{i=0}^n x = x_i\f$, and then applies \a h. It is often * used for synchronous optimization * + * Must be called before \ref Init * \param batch true for batch, false for online * \param updt user-defined updater, default is assign */ - void Register(bool batch = true, const Updater& updt = Updater()) { } + void SetUpdater(const Updater& updt); #endif // DMLC_USE_CXX11 + /** + * \brief set aggregator + * The aggregator first aggregate all pushed data among all devices before + * applying the updater + * + * The aggregator is enabled in default + * + * \param aggregator false to disable + */ + void SetAggregator(bool aggregator); + /*! \brief Gets rank of this node in its group, which is in [0, GroupSize) */ int GetRank(); diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index a92c0f194f2b..f7ac400ba7d6 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -5,7 +5,7 @@ from .narray import NArray from .context import Context from .base import _LIB -from .base import check_call, c_array +from .base import check_call, c_array, NArrayHandle def init_devices(contexts): """ Init key-value store with a list of device contexts @@ -71,6 +71,17 @@ def pull(kv_list): assert isinstance(kv[1], NArray) check_call(_LIB.MXKVStorePull(kv[0], kv[1].handle)) +def updater_wrapper(updater): + def updater_handle(lhs_handle, rhs_handle): + updater(NArray(lhs_handle), NArray(rhs_handle)) + return updater_handle + +def void_updater(lhs, rhs): + pass + +updater_proto = ctypes.CFUNCTYPE(None, NArrayHandle, NArrayHandle) +updater_func = updater_proto(updater_wrapper(void_updater)) + def register(updater): """ Register a updater into the store @@ -82,3 +93,6 @@ def Update(grad, weight): ---------- """ + global updater_func + updater_func = updater_proto(updater) + check_call(_LIB.MXKVStoreRegister(updater_func)) diff --git a/src/c_api.cc b/src/c_api.cc index 97fc83c2da68..4fa26d3b4ff2 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -19,6 +19,7 @@ #include #include #include +#include // macro hanlding for threadlocal variables #ifdef __GNUC__ @@ -874,3 +875,19 @@ int MXKVStoreInitDevices(mx_uint num_devs, int *dev_masks, int *dev_ids) { KVStore::Get()->InitDevices(devs); API_END(); } + +int MXKVStoreClear() { + API_BEGIN(); + KVStore::Get()->Clear(); + API_END(); +} + +int MXKVStoreRegister(MXKVStoreUpdater updater) { + API_BEGIN(); + auto updt = [updater](const NArray& recv, NArray* local) { + NArray recv_copy = recv; + updater(&recv_copy, local); + }; + // KVStore::Get()->Register(updt); + API_END(); +} diff --git a/src/kvstore/kvstore.cc b/src/kvstore/kvstore.cc index ed5298a99c8b..7d619b4291e5 100644 --- a/src/kvstore/kvstore.cc +++ b/src/kvstore/kvstore.cc @@ -36,9 +36,23 @@ void KVStore::Pull(int key, NArray* value) { store_->Pull(key, value); } +void KVStore::Clear() { + if (store_) store_->Clear(); +} + int KVStore::GetRank() { return store_->GetRank(); } int KVStore::GetGroupSize() { return store_->GetGroupSize(); } +void KVStore::SetUpdater(const Updater& updt) { + CHECK(store_ != NULL) << "call InitDevices first"; + store_->SetUpdater(updt); +} + +void KVStore::SetAggregator(bool aggregator) { + CHECK(store_ != NULL) << "call InitDevices first"; + store_->SetAggregator(aggregator); +} + KVStore::~KVStore() { delete store_; } } // namespace mxnet diff --git a/src/kvstore/kvstore_base.h b/src/kvstore/kvstore_base.h index ee2c3d3872f5..862bfd6cb041 100644 --- a/src/kvstore/kvstore_base.h +++ b/src/kvstore/kvstore_base.h @@ -18,8 +18,9 @@ namespace mxnet { class KVStoreBase { public: typedef int Key; - KVStoreBase() : inited_(false), engine_(DAGEngine::Get()), aggregate_(true) { } + KVStoreBase() : engine_(DAGEngine::Get()) { Clear(); } virtual ~KVStoreBase() { } + virtual void InitDevices(const std::vector& devices) { CHECK(!inited_) << "double intializatino"; num_devs_ = 0; @@ -27,71 +28,97 @@ class KVStoreBase { inited_ = true; } - virtual void Push(Key key, const NArray& value, bool init) { + virtual void Push(Key key, const NArray& val, bool init) { CHECK(inited_) << "call InitDevices first"; auto it = local_.find(key); if (init) { CHECK(it == local_.end()) << "duplicate init of key = " << key; - Value val(num_devs_, value.Copy(local_ctx_)); - local_.insert({key, val}).first; + Value lc_v(num_devs_, val.Copy(local_ctx_)); + local_.insert({key, lc_v}).first; return; } + CHECK(it != local_.end()) << "key " << key << " has not been inited"; - auto& local_val = it->second; - CHECK_EQ(local_val.arr.shape(), value.shape()) - << "shape mismatch: " << local_val.arr.shape() << ", " << value.shape(); - if (aggregate_) { - int dix = GetDevIdx(value.ctx()); - CHECK(!local_val.pending_push[dix]) - << "duplicate push on key " << key << "from " << value.ctx().Name(); - local_val.pending_push[dix] = true; - local_val.pending_push_arr.push_back(value); - if (local_val.pending_push_arr.size() == num_devs_) { - // do copy for the clossure - std::vector read; - std::swap(read, local_val.pending_push_arr); - std::vector read_val; - for (const auto& r : read) read_val.push_back(r.var()); - NArray write = local_val.arr; - - // issue push to engine - engine_->Push([this, read, write](RunContext rctx) mutable { - for (const auto& r : read) write += r; - }, local_ctx_, read_val, {write.var()}); - - // issue pull if necessary - for (auto& w : local_val.pending_pull_arr) { - CopyFromTo(local_val.arr, &w); - } + auto& lc_v = it->second; + CHECK_EQ(lc_v.val.shape(), val.shape()) + << "shape mismatch: " << lc_v.val.shape() << ", " << val.shape(); + + if (aggregator_) { + int dix = GetDevIdx(val.ctx()); + CHECK(!lc_v.pending_push[dix]) + << "duplicate push on key " << key << "from " << val.ctx().Name(); + lc_v.pending_push[dix] = true; + lc_v.num_pending_push ++; + + if (lc_v.agg_buf.is_none()) { + lc_v.agg_buf = NArray(lc_v.val.shape(), local_ctx_); + } + if (val.ctx().dev_mask == cpu::kDevMask) { + lc_v.agg_buf += val; + } else { + // copy to pinned memory + LOG(FATAL) << "TODO"; + } + + if (lc_v.num_pending_push == num_devs_) { + // apply updater + if (updater_) updater_(lc_v.agg_buf, &lc_v.val); // clean - local_val.pending_push.flip(); - local_val.pending_pull_arr.clear(); + lc_v.agg_buf = 0.0; + lc_v.pending_push.flip(); + lc_v.num_pending_push = 0; + + + // issue blocked pull + for (auto& w : lc_v.pending_pull_val) { + CopyFromTo(lc_v.val, &w); + } + lc_v.pending_pull_val.clear(); } } else { LOG(FATAL) << "TODO"; } } - virtual void Pull(Key key, NArray* value) { + virtual void Pull(Key key, NArray* val) { CHECK(inited_) << "call InitDevices first"; auto it = local_.find(key); CHECK(it != local_.end()) << "key " << key << " has not been inited"; - auto& local_val = it->second; - CHECK_EQ(local_val.arr.shape(), value->shape()) - << "shape mismatch: " << local_val.arr.shape() << ", " << value->shape(); - - if (aggregate_) { - int dix = GetDevIdx(value->ctx()); - if (local_val.pending_push[dix]) { - local_val.pending_pull_arr.push_back(*value); + auto& lc_v = it->second; + CHECK_EQ(lc_v.val.shape(), val->shape()) + << "shape mismatch: " << lc_v.val.shape() << ", " << val->shape(); + + if (aggregator_) { + int dix = GetDevIdx(val->ctx()); + if (lc_v.pending_push[dix]) { + lc_v.pending_pull_val.push_back(*val); return; } - CopyFromTo(local_val.arr, value); + CopyFromTo(lc_v.val, val); + } else { + LOG(FATAL) << "TODO"; } } + virtual void Clear() { + inited_ = false; + aggregator_ = true; + num_devs_ = 0; + updater_ = KVStore::DefaultUpdater(); + devs_.clear(); + local_.clear(); + } + + virtual void SetUpdater(const KVStore::Updater updt) { + updater_ = updt; + } + + virtual void SetAggregator(bool aggregator) { + aggregator_ = aggregator; + } + virtual int GetRank() { return 0; } virtual int GetGroupSize() { return 1; } @@ -99,31 +126,32 @@ class KVStoreBase { /// get the continous device index starting from 0 inline int GetDevIdx(const Context& ctx) { auto it = devs_.find(ctx.UID()); - CHECK(it != devs_.end()) - << "unknow device " << ctx.Name(); + CHECK(it != devs_.end()) << "unknow device " << ctx.Name(); return it->second; } - bool inited_; DAGEngine* engine_; - bool aggregate_; + bool inited_; + bool aggregator_; + KVStore::Updater updater_; - /// map a device into an index size_t num_devs_; + /// map a device into an index std::unordered_map devs_; /// internal storage of a value struct Value { Value() {} Value(int num_devs, NArray data) - : pending_push(num_devs, false), pending_pull(num_devs, false) { - arr = data; + : pending_push(num_devs, false), num_pending_push(0) { + val = data; } - std::vector pending_push, pending_pull; - std::vector pending_push_arr, pending_pull_arr; - NArray arr; + std::vector pending_push; + std::vector pending_push_val, pending_pull_val; + size_t num_pending_push; + NArray val, agg_buf; }; - Context local_ctx_; std::unordered_map local_; + Context local_ctx_; }; } // namespace mxnet diff --git a/tests/python/test_kvstore.py b/tests/python/test_kvstore.py index e0b2221b290a..3010c4514dcb 100644 --- a/tests/python/test_kvstore.py +++ b/tests/python/test_kvstore.py @@ -1,7 +1,4 @@ # pylint: skip-file -import sys -sys.path.append('../../python/') - import mxnet as mx num_devs = 3 From 5c5d5dbe1e5c480d266fa53c0a788a0cdc2c0361 Mon Sep 17 00:00:00 2001 From: muli Date: Wed, 9 Sep 2015 14:53:06 -0400 Subject: [PATCH 14/18] refactor kvstore --- include/mxnet/c_api.h | 4 +- include/mxnet/kvstore.h | 72 ++++++++++++------- python/mxnet/kvstore.py | 43 ++++++----- src/c_api.cc | 4 +- src/kvstore/kvstore.cc | 42 ++--------- .../{kvstore_base.h => kvstore_local.h} | 69 ++++++------------ tests/python/test_kvstore.py | 9 +-- 7 files changed, 108 insertions(+), 135 deletions(-) rename src/kvstore/{kvstore_base.h => kvstore_local.h} (69%) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 9587f273ef84..615b108ba4f6 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -718,10 +718,10 @@ MXNET_DLL int MXKVStoreInitDevices(mx_uint num_devs, int *dev_masks, int *dev_ids); /*! - * \brief clear the kvstore + * \brief stop the kvstore * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXKVStoreClear(); +MXNET_DLL int MXKVStoreStop(); /*! * \brief Init (key,value) in kvstore diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h index c4cd9e124fe1..04533f35254f 100644 --- a/include/mxnet/kvstore.h +++ b/include/mxnet/kvstore.h @@ -7,6 +7,7 @@ #define MXNET_PS_H_ #include "dmlc/io.h" #include "narray.h" +#include "dag_engine.h" #if DMLC_USE_CXX11 #include @@ -14,9 +15,6 @@ namespace mxnet { -/*! \brief forward declaration */ -class KVStoreBase; - /** * \brief distributed key-value store * @@ -59,14 +57,17 @@ class KVStore { /** * \brief Init with the local devices */ - void InitDevices(const std::vector& devices); + virtual void InitDevices(const std::vector& devices); /** * \brief data * * init a key-value pair. One must insert before push and pull */ - void Init(int key, const NArray& value); + virtual void Init(int key, const NArray& value) { + CHECK(impl_) << "call InitDevices first"; + impl_->Init(key, value); + } /*! * \brief push data to the store @@ -86,7 +87,10 @@ class KVStore { * \param key the key for pushing * \param value the value for pushing */ - void Push(int key, const NArray& value); + virtual void Push(int key, const NArray& value) { + CHECK(impl_) << "call InitDevices first"; + impl_->Push(key, value); + } /*! * \brief pull data from the server nodes @@ -105,12 +109,19 @@ class KVStore { * \param key the key for pulling * \param value data for pulling, should be pre-allocated */ - void Pull(int key, NArray* value); + virtual void Pull(int key, NArray* value) { + CHECK(impl_) << "call InitDevices first"; + impl_->Pull(key, value); + } /** * \brief clear all data stored, handles registered, and devices binded */ - void Clear(); + virtual void Stop() { + CHECK(impl_) << "call InitDevices first"; + impl_->Stop(); + Clear(); + } #if DMLC_USE_CXX11 /** @@ -118,15 +129,6 @@ class KVStore { */ using Updater = std::function; - /** - * \brief returns the default updater, which is ASSIGN - */ - static Updater DefaultUpdater() { - return [](const NArray& a, NArray* b) { - CopyFromTo(a, b); - }; - } - /** * \brief set an updater * @@ -149,7 +151,7 @@ class KVStore { * \param batch true for batch, false for online * \param updt user-defined updater, default is assign */ - void SetUpdater(const Updater& updt); + void set_updater(const Updater& updater) { updater_ = updater; } #endif // DMLC_USE_CXX11 /** @@ -161,19 +163,41 @@ class KVStore { * * \param aggregator false to disable */ - void SetAggregator(bool aggregator); + void set_aggregator(bool aggregator) { aggregator_ = aggregator; } /*! \brief Gets rank of this node in its group, which is in [0, GroupSize) */ - int GetRank(); + int get_rank() { return rank_; } /*! \brief Get the number of nodes in this group. */ - int GetGroupSize(); + int get_group_size() { return group_size_; } + + protected: + virtual ~KVStore(); + KVStore() : engine_(DAGEngine::Get()), impl_(NULL) { Clear(); } + DAGEngine* engine_; + int rank_; + int group_size_; + bool aggregator_; + +#if DMLC_USE_CXX11 + /*! \brief returns the default updater, which is ASSIGN */ + Updater DefaultUpdater() { + return [](const NArray& a, NArray* b) { CopyFromTo(a, b); }; + } + Updater updater_; +#endif // DMLC_USE_CXX11 private: DISALLOW_COPY_AND_ASSIGN(KVStore); - KVStore() : store_(NULL) { } - ~KVStore(); - KVStoreBase* store_; + void Clear() { + delete impl_; + impl_ = NULL; + updater_ = DefaultUpdater(); + aggregator_ = true; + rank_ = 0; + group_size_ = 1; + } + KVStore* impl_; }; } // namespace mxnet diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index f7ac400ba7d6..b9600564fc17 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -19,6 +19,10 @@ def init_devices(contexts): ids = c_array(ctypes.c_int, [c.device_id for c in contexts]) check_call(_LIB.MXKVStoreInitDevices(len(contexts), masks, ids)) +def stop(): + """ stop kvstore """ + check_call(_LIB.MXKVStoreStop()) + def init(kv_list): """ Initialize a list of key-value pairs @@ -71,28 +75,29 @@ def pull(kv_list): assert isinstance(kv[1], NArray) check_call(_LIB.MXKVStorePull(kv[0], kv[1].handle)) -def updater_wrapper(updater): - def updater_handle(lhs_handle, rhs_handle): - updater(NArray(lhs_handle), NArray(rhs_handle)) - return updater_handle -def void_updater(lhs, rhs): - pass +# def updater_wrapper(updater): +# def updater_handle(lhs_handle, rhs_handle): +# updater(NArray(lhs_handle), NArray(rhs_handle)) +# return updater_handle -updater_proto = ctypes.CFUNCTYPE(None, NArrayHandle, NArrayHandle) -updater_func = updater_proto(updater_wrapper(void_updater)) +# def void_updater(lhs, rhs): +# pass -def register(updater): - """ Register a updater into the store +# updater_proto = ctypes.CFUNCTYPE(None, NArrayHandle, NArrayHandle) +# updater_func = updater_proto(updater_wrapper(void_updater)) - Example: - def Update(grad, weight): - weight[:] -= lr * grad / batch_size +# def register(updater): +# """ Register a updater into the store - Parameters - ---------- +# Example: +# def Update(grad, weight): +# weight[:] -= lr * grad / batch_size - """ - global updater_func - updater_func = updater_proto(updater) - check_call(_LIB.MXKVStoreRegister(updater_func)) +# Parameters +# ---------- + +# """ +# global updater_func +# updater_func = updater_proto(updater) +# check_call(_LIB.MXKVStoreRegister(updater_func)) diff --git a/src/c_api.cc b/src/c_api.cc index 4fa26d3b4ff2..af3bb820e419 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -876,9 +876,9 @@ int MXKVStoreInitDevices(mx_uint num_devs, int *dev_masks, int *dev_ids) { API_END(); } -int MXKVStoreClear() { +int MXKVStoreStop() { API_BEGIN(); - KVStore::Get()->Clear(); + KVStore::Get()->Stop(); API_END(); } diff --git a/src/kvstore/kvstore.cc b/src/kvstore/kvstore.cc index 7d619b4291e5..b2e043fda95e 100644 --- a/src/kvstore/kvstore.cc +++ b/src/kvstore/kvstore.cc @@ -4,55 +4,23 @@ * \brief implement kv_store */ #include "mxnet/kvstore.h" -#include "kvstore_base.h" #include #include "dmlc/logging.h" +#include "kvstore_local.h" namespace mxnet { void KVStore::InitDevices(const std::vector& devices) { + CHECK(impl_ == NULL) << "double initialization, call Stop() first"; char* num_worker = getenv("DMLC_NUM_WORKER"); if (num_worker == NULL || atoi(num_worker) == 1) { - // local model - store_ = new KVStoreBase(); + impl_ = new KVStoreLocal(); } else { LOG(FATAL) << "not implemented yet"; } - store_->InitDevices(devices); + impl_->InitDevices(devices); } -void KVStore::Init(int key, const NArray& value) { - CHECK(store_ != NULL) << "call InitDevices first"; - store_->Push(key, value, true); -} - -void KVStore::Push(int key, const NArray& value) { - CHECK(store_ != NULL) << "call InitDevices first"; - store_->Push(key, value, false); -} - -void KVStore::Pull(int key, NArray* value) { - CHECK(store_ != NULL) << "call InitDevices first"; - store_->Pull(key, value); -} - -void KVStore::Clear() { - if (store_) store_->Clear(); -} - -int KVStore::GetRank() { return store_->GetRank(); } -int KVStore::GetGroupSize() { return store_->GetGroupSize(); } - -void KVStore::SetUpdater(const Updater& updt) { - CHECK(store_ != NULL) << "call InitDevices first"; - store_->SetUpdater(updt); -} - -void KVStore::SetAggregator(bool aggregator) { - CHECK(store_ != NULL) << "call InitDevices first"; - store_->SetAggregator(aggregator); -} - -KVStore::~KVStore() { delete store_; } +KVStore::~KVStore() { Clear(); } } // namespace mxnet diff --git a/src/kvstore/kvstore_base.h b/src/kvstore/kvstore_local.h similarity index 69% rename from src/kvstore/kvstore_base.h rename to src/kvstore/kvstore_local.h index 862bfd6cb041..9c33a805fbb9 100644 --- a/src/kvstore/kvstore_base.h +++ b/src/kvstore/kvstore_local.h @@ -1,43 +1,37 @@ /** * Copyright (c) 2015 by Contributors - * @file kvstore_base.h + * @file kvstore_local.h * @brief local implementation */ -#ifndef MXNET_KVSTORE_BASE_H_ -#define MXNET_KVSTORE_BASE_H_ +#ifndef MXNET_KVSTORE_LOCAL_H_ +#define MXNET_KVSTORE_LOCAL_H_ #include #include -#include "mxnet/narray.h" -#include "mxnet/dag_engine.h" +#include "mxnet/kvstore.h" namespace mxnet { /** * \brief store data in local machine */ -class KVStoreBase { +class KVStoreLocal : public KVStore { public: - typedef int Key; - KVStoreBase() : engine_(DAGEngine::Get()) { Clear(); } - virtual ~KVStoreBase() { } + KVStoreLocal() { Clear(); } + virtual ~KVStoreLocal() { Clear(); } virtual void InitDevices(const std::vector& devices) { - CHECK(!inited_) << "double intializatino"; num_devs_ = 0; for (auto d : devices) devs_[d.UID()] = num_devs_ ++; - inited_ = true; } - virtual void Push(Key key, const NArray& val, bool init) { - CHECK(inited_) << "call InitDevices first"; - auto it = local_.find(key); - if (init) { - CHECK(it == local_.end()) << "duplicate init of key = " << key; - Value lc_v(num_devs_, val.Copy(local_ctx_)); - local_.insert({key, lc_v}).first; - return; - } + virtual void Init(int key, const NArray& val) { + CHECK(local_.find(key) == local_.end()) << "duplicate init of key " << key; + Value lc_v(num_devs_, val.Copy(local_ctx_)); + local_.insert({key, lc_v}); + } + virtual void Push(int key, const NArray& val) { + auto it = local_.find(key); CHECK(it != local_.end()) << "key " << key << " has not been inited"; auto& lc_v = it->second; CHECK_EQ(lc_v.val.shape(), val.shape()) @@ -81,9 +75,7 @@ class KVStoreBase { } } - virtual void Pull(Key key, NArray* val) { - CHECK(inited_) << "call InitDevices first"; - + virtual void Pull(int key, NArray* val) { auto it = local_.find(key); CHECK(it != local_.end()) << "key " << key << " has not been inited"; auto& lc_v = it->second; @@ -102,38 +94,21 @@ class KVStoreBase { } } - virtual void Clear() { - inited_ = false; - aggregator_ = true; - num_devs_ = 0; - updater_ = KVStore::DefaultUpdater(); + virtual void Stop() { Clear(); } + + protected: + void Clear() { + num_devs_ = 0; devs_.clear(); local_.clear(); } - virtual void SetUpdater(const KVStore::Updater updt) { - updater_ = updt; - } - - virtual void SetAggregator(bool aggregator) { - aggregator_ = aggregator; - } - - virtual int GetRank() { return 0; } - virtual int GetGroupSize() { return 1; } - - protected: /// get the continous device index starting from 0 inline int GetDevIdx(const Context& ctx) { auto it = devs_.find(ctx.UID()); CHECK(it != devs_.end()) << "unknow device " << ctx.Name(); return it->second; } - DAGEngine* engine_; - bool inited_; - bool aggregator_; - KVStore::Updater updater_; - size_t num_devs_; /// map a device into an index std::unordered_map devs_; @@ -150,9 +125,9 @@ class KVStoreBase { size_t num_pending_push; NArray val, agg_buf; }; - std::unordered_map local_; + std::unordered_map local_; Context local_ctx_; }; } // namespace mxnet -#endif // MXNET_KVSTORE_BASE_H_ +#endif // MXNET_KVSTORE_LOCAL_H_ diff --git a/tests/python/test_kvstore.py b/tests/python/test_kvstore.py index 3010c4514dcb..a2c6a2de6761 100644 --- a/tests/python/test_kvstore.py +++ b/tests/python/test_kvstore.py @@ -13,13 +13,14 @@ mx.kvstore.init((3, a)) # push -B = [mx.narray.empty(s,d) for d in devs] -for b in B: - b[:] = 2.0 - mx.kvstore.push((3, b)) +# B = [mx.narray.empty(s,d) for d in devs] +# for b in B: +# b[:] = 2.0 +# mx.kvstore.push((3, b)) # pull C = [mx.narray.empty(s,d) for d in devs] for c in C: mx.kvstore.pull((3, c)) print c.asnumpy() +mx.kvstore.stop() From cdabe631348d4767d6e26adeea25e9f6d5c44df8 Mon Sep 17 00:00:00 2001 From: muli Date: Wed, 9 Sep 2015 15:08:12 -0400 Subject: [PATCH 15/18] fix bug in init agg_buf --- src/kvstore/kvstore_local.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index 9c33a805fbb9..254696020ad7 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -46,6 +46,7 @@ class KVStoreLocal : public KVStore { if (lc_v.agg_buf.is_none()) { lc_v.agg_buf = NArray(lc_v.val.shape(), local_ctx_); + lc_v.agg_buf = 0.0; } if (val.ctx().dev_mask == cpu::kDevMask) { lc_v.agg_buf += val; @@ -63,7 +64,6 @@ class KVStoreLocal : public KVStore { lc_v.pending_push.flip(); lc_v.num_pending_push = 0; - // issue blocked pull for (auto& w : lc_v.pending_pull_val) { CopyFromTo(lc_v.val, &w); From dfb5b70990f1120b7cde022bab0883c333b8655d Mon Sep 17 00:00:00 2001 From: muli Date: Wed, 9 Sep 2015 16:48:31 -0400 Subject: [PATCH 16/18] key-value list for kvstore in c api, better test_kvstore --- include/mxnet/c_api.h | 30 ++++++----- python/mxnet/kvstore.py | 101 +++++++++++++++++++---------------- src/c_api.cc | 19 ++++--- tests/python/test_kvstore.py | 63 ++++++++++++++-------- 4 files changed, 124 insertions(+), 89 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 615b108ba4f6..0fcb90acbb6a 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -725,31 +725,37 @@ MXNET_DLL int MXKVStoreStop(); /*! * \brief Init (key,value) in kvstore - * \param key the int key - * \param value the NArray value + * \param num the number of key-value pairs + * \param keys the list of keys + * \param vals the list of values * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXKVStoreInit(mx_uint key, - NArrayHandle value); +MXNET_DLL int MXKVStoreInit(int num, + int* keys, + NArrayHandle* vals); /*! * \brief Push (key,value) to kvstore - * \param key the int key - * \param value the NArray value + * \param num the number of key-value pairs + * \param keys the list of keys + * \param vals the list of values * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXKVStorePush(mx_uint key, - NArrayHandle value); +MXNET_DLL int MXKVStorePush(int num, + int* keys, + NArrayHandle* vals); /*! * \brief pull value from kvstore on the given key - * \param key the int key - * \param value the NArray value + * \param num the number of key-value pairs + * \param keys the list of keys + * \param vals the list of values * \return 0 when success, -1 when failure happens */ -MXNET_DLL int MXKVStorePull(mx_uint key, - NArrayHandle value); +MXNET_DLL int MXKVStorePull(int num, + int* keys, + NArrayHandle* vals); typedef void (MXKVStoreUpdater)(NArrayHandle recv, NArrayHandle local); /*! diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index b9600564fc17..5c6555b4d029 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -7,6 +7,31 @@ from .base import _LIB from .base import check_call, c_array, NArrayHandle +def _ctype_key_value(keys, vals): + """ parse key-value args into ctype""" + if isinstance(keys, int): + if isinstance(vals, NArray): + return (1, + c_array(ctypes.c_int, [keys]), + c_array(NArrayHandle, [vals.handle])) + else: + for v in vals: + assert(isinstance(v, NArray)) + return (len(vals), + c_array(ctypes.c_int, [keys] * len(vals)), + c_array(NArrayHandle, [v.handle for v in vals])) + else: + for k in keys: + assert(isinstance(k, int)) + if len(keys) == 1: + return parse_key_value(keys[0], vals) + assert(len(keys) == len(vals)) + for v in vals: + assert(isinstance(v, NArray)) + return (len(keys), + c_array(ctypes.c_int, keys), + c_array(NArrayHandle, [v.handle for v in vals])) + def init_devices(contexts): """ Init key-value store with a list of device contexts @@ -23,7 +48,8 @@ def stop(): """ stop kvstore """ check_call(_LIB.MXKVStoreStop()) -def init(kv_list): + +def init(keys, values): """ Initialize a list of key-value pairs Parameters @@ -32,31 +58,19 @@ def init(kv_list): a key-value tuple or a list of key-value tuples, where key is int and key is """ - if isinstance(kv_list, tuple): - init([kv_list]) - else: - for kv in kv_list: - assert len(kv) == 2 - assert isinstance(kv[0], int) - assert isinstance(kv[1], NArray) - check_call(_LIB.MXKVStoreInit(kv[0], kv[1].handle)) + num, ckeys, cvals = _ctype_key_value(keys, values) + check_call(_LIB.MXKVStoreInit(num, ckeys, cvals)) -def push(kv_list): +def push(keys, values): """ Push a value into the store Parameters ---------- """ - if isinstance(kv_list, tuple): - push([kv_list]) - else: - for kv in kv_list: - assert len(kv) == 2 - assert isinstance(kv[0], int) - assert isinstance(kv[1], NArray) - check_call(_LIB.MXKVStorePush(kv[0], kv[1].handle)) + num, ckeys, cvals = _ctype_key_value(keys, values) + check_call(_LIB.MXKVStorePush(num, ckeys, cvals)) -def pull(kv_list): +def pull(keys, values): """ Pull the value from the store Parameters @@ -66,38 +80,31 @@ def pull(kv_list): value : NArray The value """ - if isinstance(kv_list, tuple): - pull([kv_list]) - else: - for kv in kv_list: - assert len(kv) == 2 - assert isinstance(kv[0], int) - assert isinstance(kv[1], NArray) - check_call(_LIB.MXKVStorePull(kv[0], kv[1].handle)) - + num, ckeys, cvals = _ctype_key_value(keys, values) + check_call(_LIB.MXKVStorePull(num, ckeys, cvals)) -# def updater_wrapper(updater): -# def updater_handle(lhs_handle, rhs_handle): -# updater(NArray(lhs_handle), NArray(rhs_handle)) -# return updater_handle +def _updater_wrapper(updater): + def updater_handle(lhs_handle, rhs_handle): + updater(NArray(lhs_handle), NArray(rhs_handle)) + return updater_handle -# def void_updater(lhs, rhs): -# pass +def _void_updater(lhs, rhs): + pass -# updater_proto = ctypes.CFUNCTYPE(None, NArrayHandle, NArrayHandle) -# updater_func = updater_proto(updater_wrapper(void_updater)) +_updater_proto = ctypes.CFUNCTYPE(None, NArrayHandle, NArrayHandle) +_updater_func = _updater_proto(_updater_wrapper(_void_updater)) -# def register(updater): -# """ Register a updater into the store +def register(updater): + """ Register a updater into the store -# Example: -# def Update(grad, weight): -# weight[:] -= lr * grad / batch_size + Example: + def Update(grad, weight): + weight[:] -= lr * grad / batch_size -# Parameters -# ---------- + Parameters + ---------- -# """ -# global updater_func -# updater_func = updater_proto(updater) -# check_call(_LIB.MXKVStoreRegister(updater_func)) + """ + global _updater_func + updater_func = _updater_proto(updater) + check_call(_LIB.MXKVStoreRegister(updater_func)) diff --git a/src/c_api.cc b/src/c_api.cc index 3869c74d5cf5..9069582b51ac 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -845,22 +845,27 @@ int MXDataIterGetData(DataIterHandle handle, NArrayHandle *out) { API_END(); } -int MXKVStorePush(mx_uint key, NArrayHandle value) { +int MXKVStoreInit(int num, int* keys, NArrayHandle* vals) { API_BEGIN(); - KVStore::Get()->Push(key, *static_cast(value)); + for (int i = 0; i < num; ++i) { + KVStore::Get()->Init(keys[i], *static_cast(vals[i])); + } API_END(); } -int MXKVStoreInit(mx_uint key, NArrayHandle value) { +int MXKVStorePush(int num, int* keys, NArrayHandle* vals) { API_BEGIN(); - KVStore::Get()->Init(key, *static_cast(value)); + for (int i = 0; i < num; ++i) { + KVStore::Get()->Push(keys[i], *static_cast(vals[i])); + } API_END(); } - -int MXKVStorePull(mx_uint key, NArrayHandle value) { +int MXKVStorePull(int num, int* keys, NArrayHandle* vals) { API_BEGIN(); - KVStore::Get()->Pull(key, static_cast(value)); + for (int i = 0; i < num; ++i) { + KVStore::Get()->Pull(keys[i], static_cast(vals[i])); + } API_END(); } diff --git a/tests/python/test_kvstore.py b/tests/python/test_kvstore.py index a2c6a2de6761..020e6de01453 100644 --- a/tests/python/test_kvstore.py +++ b/tests/python/test_kvstore.py @@ -1,26 +1,43 @@ # pylint: skip-file import mxnet as mx +import numpy as np -num_devs = 3 -devs = [mx.Context('cpu', i) for i in range(num_devs)] -mx.kvstore.init_devices(devs) - -s = (4,4) - -# init -a = mx.narray.empty(s,devs[0]) -a[:] = 1.0 -mx.kvstore.init((3, a)) - -# push -# B = [mx.narray.empty(s,d) for d in devs] -# for b in B: -# b[:] = 2.0 -# mx.kvstore.push((3, b)) - -# pull -C = [mx.narray.empty(s,d) for d in devs] -for c in C: - mx.kvstore.pull((3, c)) - print c.asnumpy() -mx.kvstore.stop() +def check_diff_to_scalar(A, x): + """ assert A == x""" + assert(np.sum(np.abs((A - x).asnumpy())) == 0) + +def test_aggregator(): + + num_devs = 2 + devs = [mx.Context('cpu', i) for i in range(num_devs)] + mx.kvstore.init_devices(devs) + + shape = (4, 4) + keys = (5, 9) + + # init all key-value pairs + mx.kvstore.init(keys, [mx.narray.zeros(shape) for k in keys]) + + # first push and then pull on one key + vals = [mx.narray.ones(shape, d) for d in devs] + mx.kvstore.push(keys[0], vals) + out = mx.narray.empty(shape, devs[1]) + mx.kvstore.pull(keys[0], out) + + check_diff_to_scalar(out, num_devs) + + # interleave push and pull for each device + vals = [] + for d in devs: + vals.append([mx.narray.ones(shape, d) for k in keys]) + mx.kvstore.push(keys, vals[-1]) + mx.kvstore.pull(keys, vals[-1]) + + for v in vals: + for d in v: + check_diff_to_scalar(d, num_devs) + + mx.kvstore.stop() + +if __name__ == '__main__': + test_aggregator() From 3c12463527cb5897cbb24107786bf87d362ad2ea Mon Sep 17 00:00:00 2001 From: muli Date: Wed, 9 Sep 2015 16:51:33 -0400 Subject: [PATCH 17/18] rename mlp_multi_dev --- .../{test_mlp_multi_devices.py => test_mlp_multi_devices.py.bak} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/python/{test_mlp_multi_devices.py => test_mlp_multi_devices.py.bak} (100%) diff --git a/tests/python/test_mlp_multi_devices.py b/tests/python/test_mlp_multi_devices.py.bak similarity index 100% rename from tests/python/test_mlp_multi_devices.py rename to tests/python/test_mlp_multi_devices.py.bak From 2c57ccb96183d5f100365d52edf121931b8aed44 Mon Sep 17 00:00:00 2001 From: muli Date: Wed, 9 Sep 2015 17:04:33 -0400 Subject: [PATCH 18/18] pass lint --- Makefile | 2 +- include/mxnet/context.h | 3 ++- include/mxnet/kvstore.h | 16 +++++++------- python/mxnet/kvstore.py | 42 ++++++++++++++++++------------------- src/kvstore/kvstore_local.h | 11 +++++----- 5 files changed, 38 insertions(+), 36 deletions(-) diff --git a/Makefile b/Makefile index 1bbfc12655a5..5334e52ab52e 100644 --- a/Makefile +++ b/Makefile @@ -124,7 +124,7 @@ doxygen: doxygen doc/Doxyfile clean: - $(RM) -r build lib/* *~ */*~ */*/*~ */*/*/*~ + $(RM) -r build lib/lib* *~ */*~ */*/*~ */*/*/*~ cd $(DMLC_CORE); make clean; cd - -include build/*.d diff --git a/include/mxnet/context.h b/include/mxnet/context.h index 3aa57f5b55f5..33a0865bcd0b 100644 --- a/include/mxnet/context.h +++ b/include/mxnet/context.h @@ -5,9 +5,10 @@ */ #ifndef MXNET_CONTEXT_H_ #define MXNET_CONTEXT_H_ -#include #include #include +#include +#include #include "./base.h" namespace mxnet { diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h index 04533f35254f..db7cf4e2954a 100644 --- a/include/mxnet/kvstore.h +++ b/include/mxnet/kvstore.h @@ -3,15 +3,15 @@ * \file kvstore.h * \brief key-value store interface for mxnet */ -#ifndef MXNET_PS_H_ -#define MXNET_PS_H_ -#include "dmlc/io.h" -#include "narray.h" -#include "dag_engine.h" - +#ifndef MXNET_KVSTORE_H_ +#define MXNET_KVSTORE_H_ +#include +#include #if DMLC_USE_CXX11 #include #endif // DMLC_USE_CXX11 +#include "narray.h" +#include "dag_engine.h" namespace mxnet { @@ -188,7 +188,6 @@ class KVStore { #endif // DMLC_USE_CXX11 private: - DISALLOW_COPY_AND_ASSIGN(KVStore); void Clear() { delete impl_; impl_ = NULL; @@ -198,7 +197,8 @@ class KVStore { group_size_ = 1; } KVStore* impl_; + DISALLOW_COPY_AND_ASSIGN(KVStore); }; } // namespace mxnet -#endif // MXNET_PS_H_ +#endif // MXNET_KVSTORE_H_ diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py index 5c6555b4d029..bbdaee02f97e 100644 --- a/python/mxnet/kvstore.py +++ b/python/mxnet/kvstore.py @@ -1,9 +1,9 @@ # coding: utf-8 +# pylint: disable=invalid-name, """ KVStore in mxnet """ from __future__ import absolute_import import ctypes from .narray import NArray -from .context import Context from .base import _LIB from .base import check_call, c_array, NArrayHandle @@ -24,7 +24,7 @@ def _ctype_key_value(keys, vals): for k in keys: assert(isinstance(k, int)) if len(keys) == 1: - return parse_key_value(keys[0], vals) + return _ctype_key_value(keys[0], vals) assert(len(keys) == len(vals)) for v in vals: assert(isinstance(v, NArray)) @@ -83,28 +83,28 @@ def pull(keys, values): num, ckeys, cvals = _ctype_key_value(keys, values) check_call(_LIB.MXKVStorePull(num, ckeys, cvals)) -def _updater_wrapper(updater): - def updater_handle(lhs_handle, rhs_handle): - updater(NArray(lhs_handle), NArray(rhs_handle)) - return updater_handle +# def _updater_wrapper(updater): +# def updater_handle(lhs_handle, rhs_handle): +# updater(NArray(lhs_handle), NArray(rhs_handle)) +# return updater_handle -def _void_updater(lhs, rhs): - pass +# def _void_updater(lhs, rhs): +# pass -_updater_proto = ctypes.CFUNCTYPE(None, NArrayHandle, NArrayHandle) -_updater_func = _updater_proto(_updater_wrapper(_void_updater)) +# _updater_proto = ctypes.CFUNCTYPE(None, NArrayHandle, NArrayHandle) +# _updater_func = _updater_proto(_updater_wrapper(_void_updater)) -def register(updater): - """ Register a updater into the store +# def register(updater): +# """ Register a updater into the store - Example: - def Update(grad, weight): - weight[:] -= lr * grad / batch_size +# Example: +# def Update(grad, weight): +# weight[:] -= lr * grad / batch_size - Parameters - ---------- +# Parameters +# ---------- - """ - global _updater_func - updater_func = _updater_proto(updater) - check_call(_LIB.MXKVStoreRegister(updater_func)) +# """ +# global _updater_func +# updater_func = _updater_proto(updater) +# check_call(_LIB.MXKVStoreRegister(updater_func)) diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h index 254696020ad7..4115f868f547 100644 --- a/src/kvstore/kvstore_local.h +++ b/src/kvstore/kvstore_local.h @@ -3,10 +3,11 @@ * @file kvstore_local.h * @brief local implementation */ -#ifndef MXNET_KVSTORE_LOCAL_H_ -#define MXNET_KVSTORE_LOCAL_H_ +#ifndef MXNET_KVSTORE_KVSTORE_LOCAL_H_ +#define MXNET_KVSTORE_KVSTORE_LOCAL_H_ #include #include +#include #include "mxnet/kvstore.h" namespace mxnet { @@ -21,7 +22,7 @@ class KVStoreLocal : public KVStore { virtual void InitDevices(const std::vector& devices) { num_devs_ = 0; - for (auto d : devices) devs_[d.UID()] = num_devs_ ++; + for (auto d : devices) devs_[d.UID()] = num_devs_++; } virtual void Init(int key, const NArray& val) { @@ -42,7 +43,7 @@ class KVStoreLocal : public KVStore { CHECK(!lc_v.pending_push[dix]) << "duplicate push on key " << key << "from " << val.ctx().Name(); lc_v.pending_push[dix] = true; - lc_v.num_pending_push ++; + lc_v.num_pending_push++; if (lc_v.agg_buf.is_none()) { lc_v.agg_buf = NArray(lc_v.val.shape(), local_ctx_); @@ -130,4 +131,4 @@ class KVStoreLocal : public KVStore { }; } // namespace mxnet -#endif // MXNET_KVSTORE_LOCAL_H_ +#endif // MXNET_KVSTORE_KVSTORE_LOCAL_H_