Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #53 from dmlc/ps
Browse files Browse the repository at this point in the history
data aggregation over multiple devices
  • Loading branch information
mli committed Sep 9, 2015
2 parents ba28bd8 + 2c57ccb commit 44bd80a
Show file tree
Hide file tree
Showing 12 changed files with 774 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ doxygen:
doxygen doc/Doxyfile

clean:
$(RM) -r build lib/* *~ */*~ */*/*~ */*/*/*~
$(RM) -r build lib/lib* *~ */*~ */*/*~ */*/*/*~
cd $(DMLC_CORE); make clean; cd -

-include build/*.d
Expand Down
57 changes: 57 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -707,5 +707,62 @@ MXNET_DLL int MXDataIterGetData(DataIterHandle handle,
*/
MXNET_DLL int MXDataIterGetLabel(DataIterHandle handle,
NArrayHandle *out);
/*!
* \brief initialize the kvstore
* \param num_devs number of devices
* \param dev_masks the list of device masks
* \param dev_ids the list of device IDs
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXKVStoreInitDevices(mx_uint num_devs,
int *dev_masks,
int *dev_ids);
/*!
* \brief stop the kvstore
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXKVStoreStop();

/*!
* \brief Init (key,value) in kvstore
* \param num the number of key-value pairs
* \param keys the list of keys
* \param vals the list of values
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXKVStoreInit(int num,
int* keys,
NArrayHandle* vals);

/*!
* \brief Push (key,value) to kvstore
* \param num the number of key-value pairs
* \param keys the list of keys
* \param vals the list of values
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXKVStorePush(int num,
int* keys,
NArrayHandle* vals);


/*!
* \brief pull value from kvstore on the given key
* \param num the number of key-value pairs
* \param keys the list of keys
* \param vals the list of values
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXKVStorePull(int num,
int* keys,
NArrayHandle* vals);

typedef void (MXKVStoreUpdater)(NArrayHandle recv, NArrayHandle local);
/*!
* \brief register an push updater
* \param updater udpater function
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXKVStoreRegister(MXKVStoreUpdater updater);

#endif // MXNET_C_API_H_
19 changes: 18 additions & 1 deletion include/mxnet/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
*/
#ifndef MXNET_CONTEXT_H_
#define MXNET_CONTEXT_H_

#include <dmlc/io.h>
#include <dmlc/type_traits.h>
#include <sstream>
#include <string>
#include "./base.h"

namespace mxnet {
Expand Down Expand Up @@ -61,6 +62,22 @@ struct Context {
if (strm->Read(&dev_id, sizeof(int32_t)) != sizeof(int32_t)) return false;
return true;
}

/**
* \brief returns an unique ID
*/
inline uint64_t UID() const {
return static_cast<uint64_t>(dev_mask) << 32 | dev_id;
}

/**
* \brief returns an unique string name
*/
inline std::string Name() const {
std::stringstream ss;
ss << (dev_mask == cpu::kDevMask ? "cpu" : "gpu") << ":" << dev_id;
return ss.str();
}
};

/*!
Expand Down
204 changes: 204 additions & 0 deletions include/mxnet/kvstore.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
/*!
* Copyright (c) 2015 by Contributors
* \file kvstore.h
* \brief key-value store interface for mxnet
*/
#ifndef MXNET_KVSTORE_H_
#define MXNET_KVSTORE_H_
#include <dmlc/io.h>
#include <vector>
#if DMLC_USE_CXX11
#include <functional>
#endif // DMLC_USE_CXX11
#include "narray.h"
#include "dag_engine.h"

namespace mxnet {

/**
* \brief distributed key-value store
*
* A distributed key-value store for data synchronization over multiple
* devices/machines. It supports user-defined updater
*
* Example to implement allreduce
* \code
* NArray data;
* // init data...
* KVStore store;
* store.Push(0, data);
* store.Pull(0, &data);
* data.Wait();
* \endcode
*
* Example to implement asynchronous SGD
* \code
* Worker store;
* auto updater = [](const NArray& recv, NArray* weight) {
* *weight += 0.1 * recv; // recv is grad
* }
* store.Register(false, updater);
*
* NArray weight, grad;
* if (store.GetRank() == 0) {
* store.Init(0, weight);
* }
* store.Pull(0, &weight);
* // compute grad
* store.Push(0, grad);
*/
class KVStore {
public:
/**
* \brief get singleton instance
*/
static KVStore* Get() { static KVStore store; return &store; }

/**
* \brief Init with the local devices
*/
virtual void InitDevices(const std::vector<Context>& devices);

/**
* \brief data
*
* init a key-value pair. One must insert before push and pull
*/
virtual void Init(int key, const NArray& value) {
CHECK(impl_) << "call InitDevices first";
impl_->Init(key, value);
}

/*!
* \brief push data to the store
*
* Push the key-value pair (\a key, \a value) to the store. This
* function returns after adding a push operator to the engine. Any following
* operator requiring writing \a value will be blocked until the actual push is
* finished.
*
* One can wait the push is finished via `data.Wait()`
*
* For each push, a user-defined updater is called to merge
* the value sent to the one maintained by itself.
*
* For a given \a key, the \a value should be always has the same size over
*
* \param key the key for pushing
* \param value the value for pushing
*/
virtual void Push(int key, const NArray& value) {
CHECK(impl_) << "call InitDevices first";
impl_->Push(key, value);
}

/*!
* \brief pull data from the server nodes
*
* Pull the \a value associated with the \a key from the store. This
* function returns after adding a pull operator to the engine. Any following
* operator requiring reading \a data will be blocked until the actual pull is
* finished.
*
* One can wait the pull is finished via `data.Wait()`
*
* Before sending back the value, the store will wait all pushed issued by
* this worker on \a key have been applied (updater has been triggered)
* and \a value is initialized
*
* \param key the key for pulling
* \param value data for pulling, should be pre-allocated
*/
virtual void Pull(int key, NArray* value) {
CHECK(impl_) << "call InitDevices first";
impl_->Pull(key, value);
}

/**
* \brief clear all data stored, handles registered, and devices binded
*/
virtual void Stop() {
CHECK(impl_) << "call InitDevices first";
impl_->Stop();
Clear();
}

#if DMLC_USE_CXX11
/**
* \brief user-defined updater
*/
using Updater = std::function<void(const NArray&, NArray*)>;

/**
* \brief set an updater
*
* The server allows user-defined handle to modify the data. Given a key,
* assume \a x is the received value and \a y is the value stored on the server
* node. The server updates \a y by `h(x, &y)`. The default \a h is ASSIGN,
* namely `*y = x`.
*
* The handle is triggered in two ways:
*
* - online: \a h is called every time when \a x is received from a worker. It
* is often used for asynchronous optimization.
*
* - batch: \a h is called after data have been aggregated over all
* workers. Assume \f$ x_i \f$ is received from worker i. Then the server
* first computes \f$\sum_{i=0}^n x = x_i\f$, and then applies \a h. It is often
* used for synchronous optimization
*
* Must be called before \ref Init
* \param batch true for batch, false for online
* \param updt user-defined updater, default is assign
*/
void set_updater(const Updater& updater) { updater_ = updater; }
#endif // DMLC_USE_CXX11

/**
* \brief set aggregator
* The aggregator first aggregate all pushed data among all devices before
* applying the updater
*
* The aggregator is enabled in default
*
* \param aggregator false to disable
*/
void set_aggregator(bool aggregator) { aggregator_ = aggregator; }

/*! \brief Gets rank of this node in its group, which is in [0, GroupSize) */
int get_rank() { return rank_; }

/*! \brief Get the number of nodes in this group. */
int get_group_size() { return group_size_; }

protected:
virtual ~KVStore();
KVStore() : engine_(DAGEngine::Get()), impl_(NULL) { Clear(); }
DAGEngine* engine_;
int rank_;
int group_size_;
bool aggregator_;

#if DMLC_USE_CXX11
/*! \brief returns the default updater, which is ASSIGN */
Updater DefaultUpdater() {
return [](const NArray& a, NArray* b) { CopyFromTo(a, b); };
}
Updater updater_;
#endif // DMLC_USE_CXX11

private:
void Clear() {
delete impl_;
impl_ = NULL;
updater_ = DefaultUpdater();
aggregator_ = true;
rank_ = 0;
group_size_ = 1;
}
KVStore* impl_;
DISALLOW_COPY_AND_ASSIGN(KVStore);
};

} // namespace mxnet
#endif // MXNET_KVSTORE_H_
3 changes: 1 addition & 2 deletions python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from .base import MXNetError
from . import narray
from . import symbol
from . import kvstore
from . import io

__version__ = "0.1.0"


Loading

0 comments on commit 44bd80a

Please sign in to comment.