Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

python updater for kvstore #54

Merged
merged 2 commits into from
Sep 10, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,6 @@ typedef void (MXKVStoreUpdater)(NArrayHandle recv, NArrayHandle local);
* \param updater updater function
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXKVStoreRegister(MXKVStoreUpdater updater);
MXNET_DLL int MXKVStoreSetUpdater(MXKVStoreUpdater updater);

#endif // MXNET_C_API_H_
145 changes: 56 additions & 89 deletions include/mxnet/kvstore.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,41 +11,14 @@
#include <functional>
#endif // DMLC_USE_CXX11
#include "narray.h"
#include "dag_engine.h"

namespace mxnet {

/**
* \brief distributed key-value store
*
* A distributed key-value store for data synchronization over multiple
* devices/machines. It supports user-defined updater
*
* Example to implement allreduce
* \code
* NArray data;
* // init data...
* KVStore store;
* store.Push(0, data);
* store.Pull(0, &data);
* data.Wait();
* \endcode
*
* Example to implement asynchronous SGD
* \code
* Worker store;
* auto updater = [](const NArray& recv, NArray* weight) {
* *weight += 0.1 * recv; // recv is grad
* }
* store.Register(false, updater);
*
* NArray weight, grad;
* if (store.GetRank() == 0) {
* store.Init(0, weight);
* }
* store.Pull(0, &weight);
* // compute grad
* store.Push(0, grad);
* devices/machines. It supports aggregator and user-defined updater.
*/
class KVStore {
public:
Expand All @@ -62,11 +35,11 @@ class KVStore {
/**
* \brief data
*
* init a key-value pair. One must insert before push and pull
* Initialize a key-value pair to the store. For any \a key, this function
* should only be called once.
*/
virtual void Init(int key, const NArray& value) {
CHECK(impl_) << "call InitDevices first";
impl_->Init(key, value);
get_impl()->Init(key, value);
}

/*!
Expand All @@ -77,124 +50,118 @@ class KVStore {
* operator requiring writing \a value will be blocked until the actual push is
* finished.
*
* One can wait the push is finished via `data.Wait()`
* One can wait until the push is finished via `data.WaitToWrite()`
*
* For each push, a user-defined updater is called to merge
* the value sent to the one maintained by itself.
* For each push, an updater is called to merge the value to the one
* stored. The default updater is Assign.
*
* For a given \a key, the \a value should be always has the same size over
* One must call Init() on \a key beforehand, and the \a value must always
* have the same size as the value it was initialized with.
*
* \param key the key for pushing
* \param value the value for pushing
*/
virtual void Push(int key, const NArray& value) {
CHECK(impl_) << "call InitDevices first";
impl_->Push(key, value);
get_impl()->Push(key, value);
}

/*!
* \brief pull data from the server nodes
* \brief pull data from the store
*
* Pull the \a value associated with the \a key from the store. This
* function returns after adding a pull operator to the engine. Any following
* operator requiring reading \a data will be blocked until the actual pull is
* finished.
*
* One can wait the pull is finished via `data.Wait()`
* One can wait until the pull is finished via `data.WaitToRead()`
*
* Before sending back the value, the store will wait until all pushes issued by
* this worker on \a key have been applied (updater has been triggered)
* and \a value is initialized
* this worker on \a key have been applied (updater has been applied). One
* must call Init() on \a key before.
*
* \param key the key for pulling
* \param value data for pulling, should be pre-allocated
*/
virtual void Pull(int key, NArray* value) {
CHECK(impl_) << "call InitDevices first";
impl_->Pull(key, value);
get_impl()->Pull(key, value);
}

/**
* \brief clear all data stored, handles registered, and devices binded
* \brief clear all stored key-value pairs, the updater, and the bound devices
*/
virtual void Stop() {
CHECK(impl_) << "call InitDevices first";
impl_->Stop();
Clear();
}
virtual void Stop() { get_impl()->Stop(); delete impl_; impl_ = NULL; }

#if DMLC_USE_CXX11
/**
* \brief user-defined updater
* \brief the prototype of user-defined updater
*/
using Updater = std::function<void(const NArray&, NArray*)>;

/*! \brief returns the default updater, which is ASSIGN */
Updater DefaultUpdater() {
return [](const NArray& a, NArray* b) { CopyFromTo(a, b); };
}

/**
* \brief set an updater
*
* The server allows user-defined handle to modify the data. Given a key,
* assume \a x is the received value and \a y is the value stored on the server
* node. The server updates \a y by `h(x, &y)`. The default \a h is ASSIGN,
* namely `*y = x`.
*
* The handle is triggered in two ways:
* Given a key, assume \a x is the received (pushed) value and \a y is the
* value stored on the store node. The store updates \a y by `h(x, &y)`. The
* default \a h is ASSIGN, namely `*y = x`.
*
* - online: \a h is called every time when \a x is received from a worker. It
* is often used for asynchronous optimization.
* The updater is applied in one of two ways, depending on whether there is an aggregator:
*
* - batch: \a h is called after data have been aggregated over all
* - yes: \a h is called after data have been aggregated over all
* workers. Assume \f$ x_i \f$ is received from worker i. Then the server
* first computes \f$ x = \sum_{i=0}^n x_i \f$, and then applies \a h. It is often
* used for synchronous optimization
*
* Must be called before \ref Init
* - no: \a h is called every time when \a x is received from a worker. It
* is often used for asynchronous optimization.
*
* \param batch true for batch, false for online
* \param updt user-defined updater, default is assign
*/
void set_updater(const Updater& updater) { updater_ = updater; }
virtual void set_updater(const Updater& updater) {
get_impl()->set_updater(updater);
}

#endif // DMLC_USE_CXX11

/**
* \brief set aggregator
* The aggregator first aggregates all pushed data among all devices before
* applying the updater
*
* The aggregator is enabled by default
*
* \param aggregator false to disable
*/
void set_aggregator(bool aggregator) { aggregator_ = aggregator; }

/*! \brief Gets rank of this node in its group, which is in [0, GroupSize) */
int get_rank() { return rank_; }
virtual void set_aggregator(bool aggregator) {
get_impl()->set_aggregator(aggregator);
}

/*! \brief Get the number of nodes in this group. */
int get_group_size() { return group_size_; }
/*!
* \brief Gets rank of this node in its group, which is in [0, GroupSize).
*/
virtual int get_rank() const {
return get_impl()->get_rank();
}

protected:
virtual ~KVStore();
KVStore() : engine_(DAGEngine::Get()), impl_(NULL) { Clear(); }
DAGEngine* engine_;
int rank_;
int group_size_;
bool aggregator_;

#if DMLC_USE_CXX11
/*! \brief returns the default updater, which is ASSIGN */
Updater DefaultUpdater() {
return [](const NArray& a, NArray* b) { CopyFromTo(a, b); };
/*!
* \brief Get the number of nodes in this group.
*/
virtual int get_group_size() const {
return get_impl()->get_group_size();
}
Updater updater_;
#endif // DMLC_USE_CXX11

protected:
KVStore() : impl_(NULL) { }
virtual ~KVStore() { delete impl_; impl_ = NULL; }

private:
void Clear() {
delete impl_;
impl_ = NULL;
updater_ = DefaultUpdater();
aggregator_ = true;
rank_ = 0;
group_size_ = 1;
inline KVStore* get_impl() const {
CHECK(impl_) << "call InitDevices() first";
return impl_;
}
KVStore* impl_;
DISALLOW_COPY_AND_ASSIGN(KVStore);
Expand Down
66 changes: 36 additions & 30 deletions python/mxnet/kvstore.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# coding: utf-8
# pylint: disable=invalid-name,
# pylint: disable=invalid-name, global-variable-undefined,
""" KVStore in mxnet """
from __future__ import absolute_import
import ctypes
Expand Down Expand Up @@ -45,7 +45,7 @@ def init_devices(contexts):
check_call(_LIB.MXKVStoreInitDevices(len(contexts), masks, ids))

def stop():
""" stop kvstore """
""" Stop kvstore """
check_call(_LIB.MXKVStoreStop())


Expand All @@ -54,9 +54,10 @@ def init(keys, values):

Parameters
----------
kv_list : tuple or list/generator of tuples
a key-value tuple or a list of key-value tuples, where key is int and
key is
keys: int or list of int
A single key or a list of keys
values: NArray or list of NArray
A single value or a list of values
"""
num, ckeys, cvals = _ctype_key_value(keys, values)
check_call(_LIB.MXKVStoreInit(num, ckeys, cvals))
Expand All @@ -66,6 +67,10 @@ def push(keys, values):

Parameters
----------
keys: int or list of int
A single key or a list of keys
values: NArray or list of NArray
A single value or a list of values
"""
num, ckeys, cvals = _ctype_key_value(keys, values)
check_call(_LIB.MXKVStorePush(num, ckeys, cvals))
Expand All @@ -75,36 +80,37 @@ def pull(keys, values):

Parameters
----------
key : int
The key
value : NArray
The value
keys: int or list of int
A single key or a list of keys
values: NArray or list of NArray
A single value or a list of values
"""
num, ckeys, cvals = _ctype_key_value(keys, values)
check_call(_LIB.MXKVStorePull(num, ckeys, cvals))

# def _updater_wrapper(updater):
# def updater_handle(lhs_handle, rhs_handle):
# updater(NArray(lhs_handle), NArray(rhs_handle))
# return updater_handle
def _updater_wrapper(updater):
""" a wrapper for the user-defined handle """
def updater_handle(lhs_handle, rhs_handle):
""" ctypes function """
lhs = NArray(NArrayHandle(lhs_handle))
rhs = NArray(NArrayHandle(rhs_handle))
updater(lhs, rhs)
return updater_handle

# def _void_updater(lhs, rhs):
# pass
def set_updater(updater):
""" set an updater into the store

# _updater_proto = ctypes.CFUNCTYPE(None, NArrayHandle, NArrayHandle)
# _updater_func = _updater_proto(_updater_wrapper(_void_updater))
Example:

# def register(updater):
# """ Register a updater into the store
def updater(recv, local):
local += recv
kvstore.set_updater(updater)

# Example:
# def Update(grad, weight):
# weight[:] -= lr * grad / batch_size

# Parameters
# ----------

# """
# global _updater_func
# updater_func = _updater_proto(updater)
# check_call(_LIB.MXKVStoreRegister(updater_func))
Parameters
----------
updater: function
"""
_updater_proto = ctypes.CFUNCTYPE(None, NArrayHandle, NArrayHandle)
global _updater_func
_updater_func = _updater_proto(_updater_wrapper(updater))
check_call(_LIB.MXKVStoreSetUpdater(_updater_func))
12 changes: 8 additions & 4 deletions src/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -885,12 +885,16 @@ int MXKVStoreStop() {
API_END();
}

int MXKVStoreRegister(MXKVStoreUpdater updater) {
int MXKVStoreSetUpdater(MXKVStoreUpdater updater) {
API_BEGIN();
auto updt = [updater](const NArray& recv, NArray* local) {
NArray recv_copy = recv;
updater(&recv_copy, local);
NArray* recv_copy = new NArray();
*recv_copy = recv;
NArray* local_copy = new NArray();
*local_copy = *local;
updater(recv_copy, local_copy);
};
// KVStore::Get()->Register(updt);

KVStore::Get()->set_updater(updt);
API_END();
}
2 changes: 0 additions & 2 deletions src/kvstore/kvstore.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,4 @@ void KVStore::InitDevices(const std::vector<Context>& devices) {
impl_->InitDevices(devices);
}

KVStore::~KVStore() { Clear(); }

} // namespace mxnet
Loading