Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #54 from mli/master
Browse files Browse the repository at this point in the history
python updater for kvstore
  • Loading branch information
mli committed Sep 10, 2015
2 parents 44bd80a + 7e1cd36 commit 44fa3b9
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 128 deletions.
2 changes: 1 addition & 1 deletion include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,6 @@ typedef void (MXKVStoreUpdater)(NArrayHandle recv, NArrayHandle local);
* \param updater udpater function
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXKVStoreRegister(MXKVStoreUpdater updater);
MXNET_DLL int MXKVStoreSetUpdater(MXKVStoreUpdater updater);

#endif // MXNET_C_API_H_
145 changes: 56 additions & 89 deletions include/mxnet/kvstore.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,41 +11,14 @@
#include <functional>
#endif // DMLC_USE_CXX11
#include "narray.h"
#include "dag_engine.h"

namespace mxnet {

/**
* \brief distributed key-value store
*
* A distributed key-value store for data synchronization over multiple
* devices/machines. It supports user-defined updater
*
* Example to implement allreduce
* \code
* NArray data;
* // init data...
* KVStore store;
* store.Push(0, data);
* store.Pull(0, &data);
* data.Wait();
* \endcode
*
* Example to implement asynchronous SGD
* \code
* Worker store;
* auto updater = [](const NArray& recv, NArray* weight) {
* *weight += 0.1 * recv; // recv is grad
* }
* store.Register(false, updater);
*
* NArray weight, grad;
* if (store.GetRank() == 0) {
* store.Init(0, weight);
* }
* store.Pull(0, &weight);
* // compute grad
* store.Push(0, grad);
* devices/machines. It supports aggregator and user-defined updater.
*/
class KVStore {
public:
Expand All @@ -62,11 +35,11 @@ class KVStore {
/**
* \brief data
*
* init a key-value pair. One must insert before push and pull
* Initialize a key-value pair to the store. For any \a key, this function
* should only be called once.
*/
virtual void Init(int key, const NArray& value) {
CHECK(impl_) << "call InitDevices first";
impl_->Init(key, value);
get_impl()->Init(key, value);
}

/*!
Expand All @@ -77,124 +50,118 @@ class KVStore {
* operator requiring writing \a value will be blocked until the actual push is
* finished.
*
* One can wait the push is finished via `data.Wait()`
* One can wait the push is finished via `data.WaitToWrite()`
*
* For each push, a user-defined updater is called to merge
* the value sent to the one maintained by itself.
* For each push, an updater is called to merge the value to the one
* stored. The default updater is Assign.
*
* For a given \a key, the \a value should be always has the same size over
* One must call Init() on \a key before. And the \a value should be always
* has the same size as being inited.
*
* \param key the key for pushing
* \param value the value for pushing
*/
virtual void Push(int key, const NArray& value) {
CHECK(impl_) << "call InitDevices first";
impl_->Push(key, value);
get_impl()->Push(key, value);
}

/*!
* \brief pull data from the server nodes
* \brief pull data from the store
*
* Pull the \a value associated with the \a key from the store. This
* function returns after adding a pull operator to the engine. Any following
* operator requiring reading \a data will be blocked until the actual pull is
* finished.
*
* One can wait the pull is finished via `data.Wait()`
* One can wait the pull is finished via `data.WaitToRead()`
*
* Before sending back the value, the store will wait all pushed issued by
* this worker on \a key have been applied (updater has been triggered)
* and \a value is initialized
* this worker on \a key have been applied (updater has been applied). One
* must call Init() on \a key before.
*
* \param key the key for pulling
* \param value data for pulling, should be pre-allocated
*/
virtual void Pull(int key, NArray* value) {
CHECK(impl_) << "call InitDevices first";
impl_->Pull(key, value);
get_impl()->Pull(key, value);
}

/**
* \brief clear all data stored, handles registered, and devices binded
* \brief clear all key-value pairs stored, updater, and devices binded
*/
virtual void Stop() {
CHECK(impl_) << "call InitDevices first";
impl_->Stop();
Clear();
}
virtual void Stop() { get_impl()->Stop(); delete impl_; impl_ = NULL; }

#if DMLC_USE_CXX11
/**
* \brief user-defined updater
* \brief the prototype of user-defined updater
*/
using Updater = std::function<void(const NArray&, NArray*)>;

/*! \brief returns the default updater, which is ASSIGN */
Updater DefaultUpdater() {
return [](const NArray& a, NArray* b) { CopyFromTo(a, b); };
}

/**
* \brief set an updater
*
* The server allows user-defined handle to modify the data. Given a key,
* assume \a x is the received value and \a y is the value stored on the server
* node. The server updates \a y by `h(x, &y)`. The default \a h is ASSIGN,
* namely `*y = x`.
*
* The handle is triggered in two ways:
* Given a key, assume \a x is the received (pushed) value and \a y is the
* value stored on the store node. The store updates \a y by `h(x, &y)`. The
* default \a h is ASSIGN, namely `*y = x`.
*
* - online: \a h is called every time when \a x is received from a worker. It
* is often used for asynchronous optimization.
* The updater is applied in two ways depends on whether there is an aggregator
*
* - batch: \a h is called after data have been aggregated over all
* - yes: \a h is called after data have been aggregated over all
* workers. Assume \f$ x_i \f$ is received from worker i. Then the server
* first computes \f$\sum_{i=0}^n x = x_i\f$, and then applies \a h. It is often
* used for synchronous optimization
*
* Must be called before \ref Init
* - no: \a h is called every time when \a x is received from a worker. It
* is often used for asynchronous optimization.
*
* \param batch true for batch, false for online
* \param updt user-defined updater, default is assign
*/
void set_updater(const Updater& updater) { updater_ = updater; }
virtual void set_updater(const Updater& updater) {
get_impl()->set_updater(updater);
}

#endif // DMLC_USE_CXX11

/**
* \brief set aggregator
* The aggregator first aggregate all pushed data among all devices before
* applying the updater
*
* The aggregator is enabled in default
*
* \param aggregator false to disable
*/
void set_aggregator(bool aggregator) { aggregator_ = aggregator; }

/*! \brief Gets rank of this node in its group, which is in [0, GroupSize) */
int get_rank() { return rank_; }
virtual void set_aggregator(bool aggregator) {
get_impl()->set_aggregator(aggregator);
}

/*! \brief Get the number of nodes in this group. */
int get_group_size() { return group_size_; }
/*!
* \brief Gets rank of this node in its group, which is in [0, GroupSize).
*/
virtual int get_rank() const {
return get_impl()->get_rank();
}

protected:
virtual ~KVStore();
KVStore() : engine_(DAGEngine::Get()), impl_(NULL) { Clear(); }
DAGEngine* engine_;
int rank_;
int group_size_;
bool aggregator_;

#if DMLC_USE_CXX11
/*! \brief returns the default updater, which is ASSIGN */
Updater DefaultUpdater() {
return [](const NArray& a, NArray* b) { CopyFromTo(a, b); };
/*!
* \brief Get the number of nodes in this group.
*/
virtual int get_group_size() const {
return get_impl()->get_group_size();
}
Updater updater_;
#endif // DMLC_USE_CXX11

protected:
KVStore() : impl_(NULL) { }
virtual ~KVStore() { delete impl_; impl_ = NULL; }

private:
void Clear() {
delete impl_;
impl_ = NULL;
updater_ = DefaultUpdater();
aggregator_ = true;
rank_ = 0;
group_size_ = 1;
inline KVStore* get_impl() const {
CHECK(impl_) << "call InitDevices() first";
return impl_;
}
KVStore* impl_;
DISALLOW_COPY_AND_ASSIGN(KVStore);
Expand Down
66 changes: 36 additions & 30 deletions python/mxnet/kvstore.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding: utf-8
# pylint: disable=invalid-name,
# pylint: disable=invalid-name, global-variable-undefined,
""" KVStore in mxnet """
from __future__ import absolute_import
import ctypes
Expand Down Expand Up @@ -45,7 +45,7 @@ def init_devices(contexts):
check_call(_LIB.MXKVStoreInitDevices(len(contexts), masks, ids))

def stop():
""" stop kvstore """
""" Stop kvstore """
check_call(_LIB.MXKVStoreStop())


Expand All @@ -54,9 +54,10 @@ def init(keys, values):
Parameters
----------
kv_list : tuple or list/generator of tuples
a key-value tuple or a list of key-value tuples, where key is int and
key is
keys: int or list of int
A single key or a list of keys
values: NArray or list of NArray
A single value of a list of values
"""
num, ckeys, cvals = _ctype_key_value(keys, values)
check_call(_LIB.MXKVStoreInit(num, ckeys, cvals))
Expand All @@ -66,6 +67,10 @@ def push(keys, values):
Parameters
----------
keys: int or list of int
A single key or a list of keys
values: NArray or list of NArray
A single value of a list of values
"""
num, ckeys, cvals = _ctype_key_value(keys, values)
check_call(_LIB.MXKVStorePush(num, ckeys, cvals))
Expand All @@ -75,36 +80,37 @@ def pull(keys, values):
Parameters
----------
key : int
The key
value : NArray
The value
keys: int or list of int
A single key or a list of keys
values: NArray or list of NArray
A single value of a list of values
"""
num, ckeys, cvals = _ctype_key_value(keys, values)
check_call(_LIB.MXKVStorePull(num, ckeys, cvals))

# def _updater_wrapper(updater):
# def updater_handle(lhs_handle, rhs_handle):
# updater(NArray(lhs_handle), NArray(rhs_handle))
# return updater_handle
def _updater_wrapper(updater):
""" a wrapper for the user-defined handle """
def updater_handle(lhs_handle, rhs_handle):
""" ctypes function """
lhs = NArray(NArrayHandle(lhs_handle))
rhs = NArray(NArrayHandle(rhs_handle))
updater(lhs, rhs)
return updater_handle

# def _void_updater(lhs, rhs):
# pass
def set_updater(updater):
""" set a updater into the store
# _updater_proto = ctypes.CFUNCTYPE(None, NArrayHandle, NArrayHandle)
# _updater_func = _updater_proto(_updater_wrapper(_void_updater))
Example:
# def register(updater):
# """ Register a updater into the store
def updater(recv, local):
local += recv
kvstore.set_updater(updater)
# Example:
# def Update(grad, weight):
# weight[:] -= lr * grad / batch_size

# Parameters
# ----------

# """
# global _updater_func
# updater_func = _updater_proto(updater)
# check_call(_LIB.MXKVStoreRegister(updater_func))
Parameters
----------
updater: functon
"""
_updater_proto = ctypes.CFUNCTYPE(None, NArrayHandle, NArrayHandle)
global _updater_func
_updater_func = _updater_proto(_updater_wrapper(updater))
check_call(_LIB.MXKVStoreSetUpdater(_updater_func))
12 changes: 8 additions & 4 deletions src/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -885,12 +885,16 @@ int MXKVStoreStop() {
API_END();
}

int MXKVStoreRegister(MXKVStoreUpdater updater) {
int MXKVStoreSetUpdater(MXKVStoreUpdater updater) {
API_BEGIN();
auto updt = [updater](const NArray& recv, NArray* local) {
NArray recv_copy = recv;
updater(&recv_copy, local);
NArray* recv_copy = new NArray();
*recv_copy = recv;
NArray* local_copy = new NArray();
*local_copy = *local;
updater(recv_copy, local_copy);
};
// KVStore::Get()->Register(updt);

KVStore::Get()->set_updater(updt);
API_END();
}
2 changes: 0 additions & 2 deletions src/kvstore/kvstore.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,4 @@ void KVStore::InitDevices(const std::vector<Context>& devices) {
impl_->InitDevices(devices);
}

KVStore::~KVStore() { Clear(); }

} // namespace mxnet
Loading

0 comments on commit 44fa3b9

Please sign in to comment.