diff --git a/Makefile b/Makefile
index 1bbfc12655a5..5334e52ab52e 100644
--- a/Makefile
+++ b/Makefile
@@ -124,7 +124,7 @@ doxygen:
 	doxygen doc/Doxyfile
 
 clean:
-	$(RM) -r build lib/* *~ */*~ */*/*~ */*/*/*~
+	$(RM) -r build lib/lib* *~ */*~ */*/*~ */*/*/*~
 	cd $(DMLC_CORE); make clean; cd -
 
 -include build/*.d
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index d43e0576fab3..0fcb90acbb6a 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -707,5 +707,62 @@ MXNET_DLL int MXDataIterGetData(DataIterHandle handle,
  */
 MXNET_DLL int MXDataIterGetLabel(DataIterHandle handle,
                                  NArrayHandle *out);
+/*!
+ * \brief initialize the kvstore
+ * \param num_devs number of devices
+ * \param dev_masks the list of device masks
+ * \param dev_ids the list of device IDs
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXKVStoreInitDevices(mx_uint num_devs,
+                                   int *dev_masks,
+                                   int *dev_ids);
+/*!
+ * \brief stop the kvstore
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXKVStoreStop();
+
+/*!
+ * \brief Init (key,value) in kvstore
+ * \param num the number of key-value pairs
+ * \param keys the list of keys
+ * \param vals the list of values
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXKVStoreInit(int num,
+                            int* keys,
+                            NArrayHandle* vals);
+
+/*!
+ * \brief Push (key,value) to kvstore
+ * \param num the number of key-value pairs
+ * \param keys the list of keys
+ * \param vals the list of values
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXKVStorePush(int num,
+                            int* keys,
+                            NArrayHandle* vals);
+
+
+/*!
+ * \brief pull value from kvstore on the given key
+ * \param num the number of key-value pairs
+ * \param keys the list of keys
+ * \param vals the list of values
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXKVStorePull(int num,
+                            int* keys,
+                            NArrayHandle* vals);
+
+typedef void (MXKVStoreUpdater)(NArrayHandle recv, NArrayHandle local);
+/*!
+ * \brief register a push updater
+ * \param updater updater function
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXKVStoreRegister(MXKVStoreUpdater updater);
 
 #endif  // MXNET_C_API_H_
diff --git a/include/mxnet/context.h b/include/mxnet/context.h
index 02f10231df1d..33a0865bcd0b 100644
--- a/include/mxnet/context.h
+++ b/include/mxnet/context.h
@@ -5,9 +5,10 @@
  */
 #ifndef MXNET_CONTEXT_H_
 #define MXNET_CONTEXT_H_
-
 #include 
 #include 
+#include <sstream>
+#include <string>
 #include "./base.h"
 
 namespace mxnet {
@@ -61,6 +62,22 @@ struct Context {
     if (strm->Read(&dev_id, sizeof(int32_t)) != sizeof(int32_t)) return false;
     return true;
   }
+
+  /**
+   * \brief returns a unique ID
+   */
+  inline uint64_t UID() const {
+    return static_cast<uint64_t>(dev_mask) << 32 | dev_id;
+  }
+
+  /**
+   * \brief returns a unique string name
+   */
+  inline std::string Name() const {
+    std::stringstream ss;
+    ss << (dev_mask == cpu::kDevMask ? "cpu" : "gpu") << ":" << dev_id;
+    return ss.str();
+  }
 };
 
 /*!
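The Context::UID() and Context::Name() helpers added above give the kvstore a stable way to index devices; on the Python side the same (mask, id) pair is what init_devices forwards to MXKVStoreInitDevices. A minimal sketch of that mapping, assuming the existing Python Context class exposes device_mask and device_id exactly as used by kvstore.py later in this patch:

    import mxnet as mx

    devs = [mx.Context('cpu', i) for i in range(2)]
    masks = [d.device_mask for d in devs]   # corresponds to dev_mask in the C++ Context
    ids = [d.device_id for d in devs]       # corresponds to dev_id in the C++ Context
    mx.kvstore.init_devices(devs)           # forwards masks and ids to MXKVStoreInitDevices
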
diff --git a/include/mxnet/kvstore.h b/include/mxnet/kvstore.h
new file mode 100644
index 000000000000..db7cf4e2954a
--- /dev/null
+++ b/include/mxnet/kvstore.h
@@ -0,0 +1,204 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file kvstore.h
+ * \brief key-value store interface for mxnet
+ */
+#ifndef MXNET_KVSTORE_H_
+#define MXNET_KVSTORE_H_
+#include <vector>
+#include 
+#if DMLC_USE_CXX11
+#include <functional>
+#endif  // DMLC_USE_CXX11
+#include "narray.h"
+#include "dag_engine.h"
+
+namespace mxnet {
+
+/**
+ * \brief distributed key-value store
+ *
+ * A distributed key-value store for data synchronization over multiple
+ * devices/machines. It supports a user-defined updater.
+ *
+ * Example to implement allreduce
+ * \code
+ * NArray data;
+ * // init data...
+ * KVStore store;
+ * store.Push(0, data);
+ * store.Pull(0, &data);
+ * data.Wait();
+ * \endcode
+ *
+ * Example to implement asynchronous SGD
+ * \code
+ * KVStore store;
+ * auto updater = [](const NArray& recv, NArray* weight) {
+ *   *weight += 0.1 * recv;  // recv is grad
+ * };
+ * store.set_updater(updater);
+ *
+ * NArray weight, grad;
+ * if (store.get_rank() == 0) {
+ *   store.Init(0, weight);
+ * }
+ * store.Pull(0, &weight);
+ * // compute grad
+ * store.Push(0, grad);
+ * \endcode
+ */
+class KVStore {
+ public:
+  /**
+   * \brief get singleton instance
+   */
+  static KVStore* Get() { static KVStore store; return &store; }
+
+  /**
+   * \brief Init with the local devices
+   */
+  virtual void InitDevices(const std::vector<Context>& devices);
+
+  /**
+   * \brief init a key-value pair
+   *
+   * One must init a key before calling Push or Pull on it.
+   */
+  virtual void Init(int key, const NArray& value) {
+    CHECK(impl_) << "call InitDevices first";
+    impl_->Init(key, value);
+  }
+
+  /*!
+   * \brief push data to the store
+   *
+   * Push the key-value pair (\a key, \a value) to the store. This
+   * function returns after adding a push operator to the engine. Any following
+   * operator requiring writing \a value will be blocked until the actual push is
+   * finished.
+   *
+   * One can wait until the push is finished via `data.Wait()`.
+   *
+   * For each push, a user-defined updater is called to merge
+   * the pushed value into the one maintained by the store.
+   *
+   * For a given \a key, the pushed \a value should always have the same size.
+   *
+   * \param key the key for pushing
+   * \param value the value for pushing
+   */
+  virtual void Push(int key, const NArray& value) {
+    CHECK(impl_) << "call InitDevices first";
+    impl_->Push(key, value);
+  }
+
+  /*!
+   * \brief pull data from the server nodes
+   *
+   * Pull the \a value associated with the \a key from the store. This
+   * function returns after adding a pull operator to the engine. Any following
+   * operator requiring reading \a value will be blocked until the actual pull is
+   * finished.
+   *
+   * One can wait until the pull is finished via `data.Wait()`.
+   *
+   * Before sending back the value, the store will wait until all pushes issued by
+   * this worker on \a key have been applied (the updater has been triggered)
+   * and the value of \a key has been initialized.
+   *
+   * \param key the key for pulling
+   * \param value data for pulling, should be pre-allocated
+   */
+  virtual void Pull(int key, NArray* value) {
+    CHECK(impl_) << "call InitDevices first";
+    impl_->Pull(key, value);
+  }
+
+  /**
+   * \brief clear all data stored, handles registered, and devices bound
+   */
+  virtual void Stop() {
+    CHECK(impl_) << "call InitDevices first";
+    impl_->Stop();
+    Clear();
+  }
+
+#if DMLC_USE_CXX11
+  /**
+   * \brief user-defined updater
+   */
+  using Updater = std::function<void(const NArray&, NArray*)>;
+
+  /**
+   * \brief set an updater
+   *
+   * The server allows a user-defined handle to modify the data. Given a key,
+   * assume \a x is the received value and \a y is the value stored on the server
+   * node. The server updates \a y by `h(x, &y)`. The default \a h is ASSIGN,
+   * namely `*y = x`.
+   *
+   * The handle is triggered in two ways:
+   *
+   * - online: \a h is called every time \a x is received from a worker. It
+   *   is often used for asynchronous optimization.
+   *
+   * - batch: \a h is called after data have been aggregated over all
+   *   workers. Assume \f$ x_i \f$ is received from worker i. Then the server
+   *   first computes \f$ x = \sum_{i=0}^n x_i \f$, and then applies \a h. It is often
+   *   used for synchronous optimization.
+   *
+   * Must be called before \ref Init
+   * \param updater user-defined updater; the default is ASSIGN
+   */
+  void set_updater(const Updater& updater) { updater_ = updater; }
+#endif  // DMLC_USE_CXX11
+
+  /**
+   * \brief set aggregator
+   * The aggregator first aggregates all pushed data among all devices before
+   * applying the updater
+   *
+   * The aggregator is enabled by default
+   *
+   * \param aggregator false to disable
+   */
+  void set_aggregator(bool aggregator) { aggregator_ = aggregator; }
+
+  /*! \brief Gets rank of this node in its group, which is in [0, GroupSize) */
+  int get_rank() { return rank_; }
+
+  /*! \brief Get the number of nodes in this group. */
+  int get_group_size() { return group_size_; }
+
+ protected:
+  virtual ~KVStore();
+  KVStore() : engine_(DAGEngine::Get()), impl_(NULL) { Clear(); }
+  DAGEngine* engine_;
+  int rank_;
+  int group_size_;
+  bool aggregator_;
+
+#if DMLC_USE_CXX11
+  /*! \brief returns the default updater, which is ASSIGN */
+  Updater DefaultUpdater() {
+    return [](const NArray& a, NArray* b) { CopyFromTo(a, b); };
+  }
+  Updater updater_;
+#endif  // DMLC_USE_CXX11
+
+ private:
+  void Clear() {
+    delete impl_;
+    impl_ = NULL;
+    updater_ = DefaultUpdater();
+    aggregator_ = true;
+    rank_ = 0;
+    group_size_ = 1;
+  }
+  KVStore* impl_;
+  DISALLOW_COPY_AND_ASSIGN(KVStore);
+};
+
+}  // namespace mxnet
+#endif  // MXNET_KVSTORE_H_
diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py
index a8632bfa2ff8..78b986e9c6e0 100644
--- a/python/mxnet/__init__.py
+++ b/python/mxnet/__init__.py
@@ -12,8 +12,7 @@
 from .base import MXNetError
 from . import narray
 from . import symbol
+from . import kvstore
 from . import io
 
 __version__ = "0.1.0"
-
-
diff --git a/python/mxnet/kvstore.py b/python/mxnet/kvstore.py
new file mode 100644
index 000000000000..bbdaee02f97e
--- /dev/null
+++ b/python/mxnet/kvstore.py
@@ -0,0 +1,110 @@
+# coding: utf-8
+# pylint: disable=invalid-name,
+""" KVStore in mxnet """
+from __future__ import absolute_import
+import ctypes
+from .narray import NArray
+from .base import _LIB
+from .base import check_call, c_array, NArrayHandle
+
+def _ctype_key_value(keys, vals):
+    """ parse key-value args into ctypes arrays """
+    if isinstance(keys, int):
+        if isinstance(vals, NArray):
+            return (1,
+                    c_array(ctypes.c_int, [keys]),
+                    c_array(NArrayHandle, [vals.handle]))
+        else:
+            for v in vals:
+                assert(isinstance(v, NArray))
+            return (len(vals),
+                    c_array(ctypes.c_int, [keys] * len(vals)),
+                    c_array(NArrayHandle, [v.handle for v in vals]))
+    else:
+        for k in keys:
+            assert(isinstance(k, int))
+        if len(keys) == 1:
+            return _ctype_key_value(keys[0], vals)
+        assert(len(keys) == len(vals))
+        for v in vals:
+            assert(isinstance(v, NArray))
+        return (len(keys),
+                c_array(ctypes.c_int, keys),
+                c_array(NArrayHandle, [v.handle for v in vals]))
+
+def init_devices(contexts):
+    """ Init key-value store with a list of device contexts
+
+    Parameters
+    ----------
+    contexts : list of Context
+        The list of local devices used by this process
+    """
+    masks = c_array(ctypes.c_int, [c.device_mask for c in contexts])
+    ids = c_array(ctypes.c_int, [c.device_id for c in contexts])
+    check_call(_LIB.MXKVStoreInitDevices(len(contexts), masks, ids))
+
+def stop():
+    """ Stop the kvstore """
+    check_call(_LIB.MXKVStoreStop())
+
+
+def init(keys, values):
+    """ Initialize a list of key-value pairs
+
+    Parameters
+    ----------
+    keys : int or list of int
+        A single key or a list of keys
+    values : NArray or list of NArray
+        The value (or values) to initialize the key(s) with
+    """
+    num, ckeys, cvals = _ctype_key_value(keys, values)
+    check_call(_LIB.MXKVStoreInit(num, ckeys, cvals))
+
+def push(keys, values):
+    """ Push values into the store
+
+    Parameters
+    ----------
+    keys : int or list of int
+        A single key or a list of keys
+    values : NArray or list of NArray
+        The value (or values) to push
+    """
+    num, ckeys, cvals = _ctype_key_value(keys, values)
+    check_call(_LIB.MXKVStorePush(num, ckeys, cvals))
+
+def pull(keys, values):
+    """ Pull values from the store
+
+    Parameters
+    ----------
+    keys : int or list of int
+        A single key or a list of keys
+    values : NArray or list of NArray
+        Pre-allocated NArray(s) that receive the pulled value(s)
+    """
+    num, ckeys, cvals = _ctype_key_value(keys, values)
+    check_call(_LIB.MXKVStorePull(num, ckeys, cvals))
+
+# def _updater_wrapper(updater):
+#     def updater_handle(lhs_handle, rhs_handle):
+#         updater(NArray(lhs_handle), NArray(rhs_handle))
+#     return updater_handle
+
+# def _void_updater(lhs, rhs):
+#     pass
+
+# _updater_proto = ctypes.CFUNCTYPE(None, NArrayHandle, NArrayHandle)
+# _updater_func = _updater_proto(_updater_wrapper(_void_updater))
+
+# def register(updater):
+#     """ Register an updater into the store
+#
+#     Example:
+#         def update(grad, weight):
+#             weight[:] -= lr * grad / batch_size
+#
+#     Parameters
+#     ----------
+#     updater : callable
+#         The updater function, invoked as updater(recv, local)
+#     """
+#     global _updater_func
+#     updater_func = _updater_proto(updater)
+#     check_call(_LIB.MXKVStoreRegister(updater_func))
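A minimal end-to-end sketch of the module-level API defined above, mirroring tests/python/test_kvstore.py later in this patch (single process, two CPU devices, default ASSIGN updater, so a pull after an aggregated push returns the summed value):

    import mxnet as mx

    devs = [mx.Context('cpu', i) for i in range(2)]
    mx.kvstore.init_devices(devs)                  # bind the local devices
    shape = (4, 4)
    mx.kvstore.init(3, mx.narray.zeros(shape))     # a key must be inited before push/pull
    mx.kvstore.push(3, [mx.narray.ones(shape, d) for d in devs])  # one push per device
    out = mx.narray.empty(shape, devs[0])
    mx.kvstore.pull(3, out)                        # out becomes all 2s: pushes are summed, then assigned
    mx.kvstore.stop()
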
diff --git a/src/c_api.cc b/src/c_api.cc
index a01ba5cd75ee..9069582b51ac 100644
--- a/src/c_api.cc
+++ b/src/c_api.cc
@@ -13,11 +13,13 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
 #include 
 #include 
+#include <mxnet/kvstore.h>
 
 // macro handling for threadlocal variables
 #ifdef __GNUC__
@@ -842,3 +844,53 @@ int MXDataIterGetData(DataIterHandle handle, NArrayHandle *out) {
   *out = new NArray(db.data[0], 0);
   API_END();
 }
+
+int MXKVStoreInit(int num, int* keys, NArrayHandle* vals) {
+  API_BEGIN();
+  for (int i = 0; i < num; ++i) {
+    KVStore::Get()->Init(keys[i], *static_cast<NArray*>(vals[i]));
+  }
+  API_END();
+}
+
+int MXKVStorePush(int num, int* keys, NArrayHandle* vals) {
+  API_BEGIN();
+  for (int i = 0; i < num; ++i) {
+    KVStore::Get()->Push(keys[i], *static_cast<NArray*>(vals[i]));
+  }
+  API_END();
+}
+
+int MXKVStorePull(int num, int* keys, NArrayHandle* vals) {
+  API_BEGIN();
+  for (int i = 0; i < num; ++i) {
+    KVStore::Get()->Pull(keys[i], static_cast<NArray*>(vals[i]));
+  }
+  API_END();
+}
+
+int MXKVStoreInitDevices(mx_uint num_devs, int *dev_masks, int *dev_ids) {
+  API_BEGIN();
+  std::vector<Context> devs;
+  for (mx_uint i = 0; i < num_devs; ++i) {
+    devs.push_back(Context(dev_masks[i], dev_ids[i]));
+  }
+  KVStore::Get()->InitDevices(devs);
+  API_END();
+}
+
+int MXKVStoreStop() {
+  API_BEGIN();
+  KVStore::Get()->Stop();
+  API_END();
+}
+
+int MXKVStoreRegister(MXKVStoreUpdater updater) {
+  API_BEGIN();
+  auto updt = [updater](const NArray& recv, NArray* local) {
+    NArray recv_copy = recv;
+    updater(&recv_copy, local);
+  };
+  // KVStore::Get()->Register(updt);
+  API_END();
+}
diff --git a/src/kvstore/kvstore.cc b/src/kvstore/kvstore.cc
new file mode 100644
index 000000000000..b2e043fda95e
--- /dev/null
+++ b/src/kvstore/kvstore.cc
@@ -0,0 +1,26 @@
+/*!
+ * Copyright (c) 2015 by Contributors
+ * \file kvstore.cc
+ * \brief implement kv_store
+ */
+#include "mxnet/kvstore.h"
+#include <stdlib.h>
+#include "dmlc/logging.h"
+#include "kvstore_local.h"
+
+namespace mxnet {
+
+void KVStore::InitDevices(const std::vector<Context>& devices) {
+  CHECK(impl_ == NULL) << "double initialization, call Stop() first";
+  char* num_worker = getenv("DMLC_NUM_WORKER");
+  if (num_worker == NULL || atoi(num_worker) == 1) {
+    impl_ = new KVStoreLocal();
+  } else {
+    LOG(FATAL) << "not implemented yet";
+  }
+  impl_->InitDevices(devices);
+}
+
+KVStore::~KVStore() { Clear(); }
+
+}  // namespace mxnet
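KVStore::InitDevices picks the local implementation unless the DMLC_NUM_WORKER environment variable requests more than one worker; the distributed backend is not implemented yet. A hedged sketch of how a script might pin the local path (the variable name is taken from the code above; this is not part of the patch):

    import os
    import mxnet as mx

    # leave DMLC_NUM_WORKER unset (or set it to "1") so KVStoreLocal is chosen;
    # any other value currently aborts with "not implemented yet"
    os.environ.pop('DMLC_NUM_WORKER', None)
    mx.kvstore.init_devices([mx.Context('cpu', 0)])
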
diff --git a/src/kvstore/kvstore_local.h b/src/kvstore/kvstore_local.h
new file mode 100644
index 000000000000..4115f868f547
--- /dev/null
+++ b/src/kvstore/kvstore_local.h
@@ -0,0 +1,134 @@
+/**
+ * Copyright (c) 2015 by Contributors
+ * @file kvstore_local.h
+ * @brief local implementation
+ */
+#ifndef MXNET_KVSTORE_KVSTORE_LOCAL_H_
+#define MXNET_KVSTORE_KVSTORE_LOCAL_H_
+#include <unordered_map>
+#include <vector>
+#include 
+#include "mxnet/kvstore.h"
+
+namespace mxnet {
+
+/**
+ * \brief store data in local machine
+ */
+class KVStoreLocal : public KVStore {
+ public:
+  KVStoreLocal() { Clear(); }
+  virtual ~KVStoreLocal() { Clear(); }
+
+  virtual void InitDevices(const std::vector<Context>& devices) {
+    num_devs_ = 0;
+    for (auto d : devices) devs_[d.UID()] = num_devs_++;
+  }
+
+  virtual void Init(int key, const NArray& val) {
+    CHECK(local_.find(key) == local_.end()) << "duplicate init of key " << key;
+    Value lc_v(num_devs_, val.Copy(local_ctx_));
+    local_.insert({key, lc_v});
+  }
+
+  virtual void Push(int key, const NArray& val) {
+    auto it = local_.find(key);
+    CHECK(it != local_.end()) << "key " << key << " has not been inited";
+    auto& lc_v = it->second;
+    CHECK_EQ(lc_v.val.shape(), val.shape())
+        << "shape mismatch: " << lc_v.val.shape() << ", " << val.shape();
+
+    if (aggregator_) {
+      int dix = GetDevIdx(val.ctx());
+      CHECK(!lc_v.pending_push[dix])
+          << "duplicate push on key " << key << " from " << val.ctx().Name();
+      lc_v.pending_push[dix] = true;
+      lc_v.num_pending_push++;
+
+      if (lc_v.agg_buf.is_none()) {
+        lc_v.agg_buf = NArray(lc_v.val.shape(), local_ctx_);
+        lc_v.agg_buf = 0.0;
+      }
+      if (val.ctx().dev_mask == cpu::kDevMask) {
+        lc_v.agg_buf += val;
+      } else {
+        // copy to pinned memory
+        LOG(FATAL) << "TODO";
+      }
+
+      if (lc_v.num_pending_push == num_devs_) {
+        // apply updater
+        if (updater_) updater_(lc_v.agg_buf, &lc_v.val);
+
+        // clean
+        lc_v.agg_buf = 0.0;
+        lc_v.pending_push.flip();
+        lc_v.num_pending_push = 0;
+
+        // issue blocked pull
+        for (auto& w : lc_v.pending_pull_val) {
+          CopyFromTo(lc_v.val, &w);
+        }
+        lc_v.pending_pull_val.clear();
+      }
+    } else {
+      LOG(FATAL) << "TODO";
+    }
+  }
+
+  virtual void Pull(int key, NArray* val) {
+    auto it = local_.find(key);
+    CHECK(it != local_.end()) << "key " << key << " has not been inited";
+    auto& lc_v = it->second;
+    CHECK_EQ(lc_v.val.shape(), val->shape())
+        << "shape mismatch: " << lc_v.val.shape() << ", " << val->shape();
+
+    if (aggregator_) {
+      int dix = GetDevIdx(val->ctx());
+      if (lc_v.pending_push[dix]) {
+        lc_v.pending_pull_val.push_back(*val);
+        return;
+      }
+      CopyFromTo(lc_v.val, val);
+    } else {
+      LOG(FATAL) << "TODO";
+    }
+  }
+
+  virtual void Stop() { Clear(); }
+
+ protected:
+  void Clear() {
+    num_devs_ = 0;
+    devs_.clear();
+    local_.clear();
+  }
+
+  /// get the continuous device index starting from 0
+  inline int GetDevIdx(const Context& ctx) {
+    auto it = devs_.find(ctx.UID());
+    CHECK(it != devs_.end()) << "unknown device " << ctx.Name();
+    return it->second;
+  }
+  size_t num_devs_;
+  /// map a device into an index
+  std::unordered_map<uint64_t, int> devs_;
+
+  /// internal storage of a value
+  struct Value {
+    Value() {}
+    Value(int num_devs, NArray data)
+        : pending_push(num_devs, false), num_pending_push(0) {
+      val = data;
+    }
+    std::vector<bool> pending_push;
+    std::vector<NArray> pending_push_val, pending_pull_val;
+    size_t num_pending_push;
+    NArray val, agg_buf;
+  };
+  std::unordered_map<int, Value> local_;
+  Context local_ctx_;
+};
+
+}  // namespace mxnet
+#endif  // MXNET_KVSTORE_KVSTORE_LOCAL_H_
diff --git a/src/narray/narray.cc b/src/narray/narray.cc
index c9dda3f5a654..e65f4127e65d 100644
--- a/src/narray/narray.cc
+++ b/src/narray/narray.cc
@@ -25,12 +25,18 @@ template<typename OP>
 inline void BinaryOp(const NArray &lhs,
                      const NArray &rhs,
                      NArray *out) {
-  CHECK(lhs.ctx() == rhs.ctx()) << "operands context mismatch";
+  // no check if both of them are on cpu
+  if (lhs.ctx().dev_mask != cpu::kDevMask || rhs.ctx().dev_mask != cpu::kDevMask)
+    CHECK(lhs.ctx() == rhs.ctx()) << "operands context mismatch";
   // if out is none, allocate space
   if (out->is_none()) {
     *out = NArray(OP::GetShape(lhs.shape(), rhs.shape()), lhs.ctx(), true);
   } else {
-    CHECK(out->ctx() == lhs.ctx()) << "target context mismatch";
+    // no check if both of them are on cpu
+    if (lhs.ctx().dev_mask != cpu::kDevMask ||
+        out->ctx().dev_mask != cpu::kDevMask) {
+      CHECK(out->ctx() == lhs.ctx()) << "target context mismatch";
+    }
     CHECK(out->shape() == OP::GetShape(lhs.shape(), rhs.shape()))
         << "target shape mismatch";
   }
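The narray.cc relaxation above is what lets the local aggregator accumulate values that live on different CPU contexts (agg_buf sits on local_ctx_ while each pushed value sits on its own cpu:N). A small illustration of the relaxed rule, assuming the Python NArray binding exposes + over the same BinaryOp path:

    import mxnet as mx

    a = mx.narray.ones((2, 2), mx.Context('cpu', 0))
    b = mx.narray.ones((2, 2), mx.Context('cpu', 1))
    c = a + b              # previously rejected because the contexts differ; allowed now that both are CPU
    print(c.asnumpy())     # every element is 2; c lives on the lhs context (cpu:0)
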
diff --git a/tests/python/test_kvstore.py b/tests/python/test_kvstore.py
new file mode 100644
index 000000000000..020e6de01453
--- /dev/null
+++ b/tests/python/test_kvstore.py
@@ -0,0 +1,43 @@
+# pylint: skip-file
+import mxnet as mx
+import numpy as np
+
+def check_diff_to_scalar(A, x):
+    """ assert A == x"""
+    assert(np.sum(np.abs((A - x).asnumpy())) == 0)
+
+def test_aggregator():
+
+    num_devs = 2
+    devs = [mx.Context('cpu', i) for i in range(num_devs)]
+    mx.kvstore.init_devices(devs)
+
+    shape = (4, 4)
+    keys = (5, 9)
+
+    # init all key-value pairs
+    mx.kvstore.init(keys, [mx.narray.zeros(shape) for k in keys])
+
+    # first push and then pull on one key
+    vals = [mx.narray.ones(shape, d) for d in devs]
+    mx.kvstore.push(keys[0], vals)
+    out = mx.narray.empty(shape, devs[1])
+    mx.kvstore.pull(keys[0], out)
+
+    check_diff_to_scalar(out, num_devs)
+
+    # interleave push and pull for each device
+    vals = []
+    for d in devs:
+        vals.append([mx.narray.ones(shape, d) for k in keys])
+        mx.kvstore.push(keys, vals[-1])
+        mx.kvstore.pull(keys, vals[-1])
+
+    for v in vals:
+        for d in v:
+            check_diff_to_scalar(d, num_devs)
+
+    mx.kvstore.stop()
+
+if __name__ == '__main__':
+    test_aggregator()
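The updater registration path (MXKVStoreRegister / mx.kvstore.register) is still commented out in kvstore.py and c_api.cc. A hedged sketch of the intended usage once it is wired up, following the commented code and the .bak test below; these names are tentative, not a working API yet:

    import mxnet as mx

    def update(grad, weight):
        # simple SGD step applied on the store for every aggregated push
        weight[:] -= 0.01 * grad

    mx.kvstore.register(update)   # the store would then call update(recv, local) instead of ASSIGN
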
diff --git a/tests/python/test_mlp_multi_devices.py.bak b/tests/python/test_mlp_multi_devices.py.bak
new file mode 100644
index 000000000000..7a2e7ce1938a
--- /dev/null
+++ b/tests/python/test_mlp_multi_devices.py.bak
@@ -0,0 +1,120 @@
+# pylint: skip-file
+import sys
+sys.path.append('../../python/')
+
+import mxnet as mx
+import numpy as np
+import os, gzip
+import pickle as pickle
+from math import sqrt
+import get_data
+
+# symbol net
+data = mx.symbol.Variable('data')
+fc1 = mx.symbol.FullyConnected(data = data, name='fc1', nb_hidden=128)
+act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
+fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', nb_hidden = 64)
+act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
+fc3 = mx.symbol.FullyConnected(data = act2, name='fc3', nb_hidden=10)
+mlp = mx.symbol.Softmax(data = fc3, name = 'mlp')
+
+# use multiple devices
+num_devs = 2
+devs = [mx.Context('cpu', i) for i in range(num_devs)]
+
+# infer shape
+batch_size = 100
+input_shape = (batch_size / num_devs, 784)
+param_shapes, out_shapes, aux_shapes = mlp.infer_shape(data=input_shape)
+param_names = mlp.list_arguments()
+
+# allocate memory
+params = [[mx.narray.create(s, d) for s in param_shapes] for d in devs]
+grads = [[mx.narray.create(s, d) for s in param_shapes] for d in devs]
+
+# only need to init param on device 0
+mx.kvstore.init_devices(devs)
+sync_keys = [i for i, m in enumerate(param_names) if "weight" in m or "bias" in m]
+np.random.seed(0)
+for k in sync_keys:
+    if "weight" in param_names[k]:
+        params[0][k].numpy[:, :] = np.random.uniform(-0.07, 0.07, params[0][k].numpy.shape)
+    else:
+        params[0][k].numpy[:] = 0
+mx.kvstore.init(sync_keys, [params[0][k] for k in sync_keys])
+
+# register param updater
+def make_updater(env):
+    def updater(grad, weight):
+        eta = env['lr'] / sqrt(env['iter']) / env['batch_size']
+        env['iter'] += 1
+        weight[:] -= eta * grad
+    return updater
+
+mx.kvstore.register(make_updater(
+    {'lr' : 0.1, 'iter' : 1, 'batch_size' : batch_size, 'wd' : .00004}))
+
+# create executor for each device
+
+req = ['write_to' for i in range(len(param_names))]
+executors = [mlp.bind(devs[i], params[i], grads[i], req) for i in range(num_devs)]
+forward_out = [mx.narray.create(e.heads()[0].shape) for e in executors]
+
+# data reader
+get_data.GetMNIST_ubyte()
+train_dataiter = mx.io.MNISTIter(
+        image="data/train-images-idx3-ubyte",
+        label="data/train-labels-idx1-ubyte",
+        batch_size=batch_size, shuffle=True, flat=True, silent=False, seed=10)
+val_dataiter = mx.io.MNISTIter(
+        image="data/t10k-images-idx3-ubyte",
+        label="data/t10k-labels-idx1-ubyte",
+        batch_size=batch_size, shuffle=True, flat=True, silent=False)
+
+def cal_acc(out, label):
+    pred = np.argmax(out, axis=1)
+    return np.sum(pred == label) * 1.0 / out.shape[0]
+
+def test_mlp():
+    epoch = 9
+    acc_train = 0.
+    acc_val = 0.
+    for i in range(epoch):
+        # train
+        print("Epoch %d" % i)
+        train_acc = 0.0
+        train_nbatch = 0
+        for data, label in train_dataiter:
+            data = data.numpy
+            label = label.numpy.flatten()
+            k = batch_size / num_devs
+
+            for d in range(num_devs):
+                # feed input
+                idx = range(d*k, (d+1)*k)
+                params[d][param_names.index('data')].numpy[:] = data[idx, :]
+                params[d][param_names.index('mlp_label')].numpy[:] = label[idx]
+
+                # pull weight
+                mx.kvstore.pull(sync_keys, [params[d][j] for j in sync_keys])
+
+                # forward and backward
+                executors[d].forward()
+                executors[d].heads()[0].copyto(forward_out[d])
+                executors[d].backward([forward_out[d]])
+
+                # push gradient
+                mx.kvstore.push(sync_keys, [grads[d][j] for j in sync_keys])
+
+            # evaluate. cannot put into the above for loop since it is blocked
+            # until all forwards are finished
+            for d in range(num_devs):
+                train_acc += cal_acc(forward_out[d].numpy, label[range(d*k, (d+1)*k)])
+            train_nbatch += 1
+
+        train_acc /= train_nbatch
+        acc_train = train_acc
+        print("Train Acc: ", train_acc)
+        train_dataiter.reset()
+
+    assert(acc_train > 0.98)
+
+if __name__ == "__main__":
+    test_mlp()