Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kvstore strkey #2

Merged
merged 14 commits into from
Mar 29, 2019
145 changes: 145 additions & 0 deletions cpp-package/example/test_kvstore.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "mxnet-cpp/MxNetCpp.h"

using namespace mxnet::cpp;

static bool test_single_key() {
std::string key = "singlekeytest";

NDArray result(Shape(4), Context::cpu());

// initialize data
NDArray data({0.f, 233.f, -0.12f, 9.f}, Shape(4), Context::cpu());
KVStore::Init(key, data);
NDArray::WaitAll();

// retrieve result
KVStore::Pull(key, &result);
NDArray::WaitAll();

// compare
for (size_t j=0; j < result.Size(); j++) {
if (result.GetData()[j] != data.GetData()[j]) {
LG << "Error: wrong initialized data in singlekeytest, expect "
<< data.GetData()[j] << " got " << result.GetData()[j];
return false;
}
}

// push gradient
NDArray grad({0.1f, -2.f, -4.4f, 0.f}, Shape(4), Context::cpu());
KVStore::Push(key, grad);
NDArray::WaitAll();

// retrieve result
KVStore::Pull(key, &result);
NDArray::WaitAll();

// compare
for (size_t j=0; j < result.Size(); j++) {
if (result.GetData()[j] == grad.GetData()[j]) {
LG << "Error: wrong gradient data in singlekeytest, expect "
<< grad.GetData()[j] << " got " << result.GetData()[j];
return false;
}
}

return true;
}

static bool test_multiple_key() {
std::vector<std::string> keys(2);
keys[0] = "multikeytest-0";
keys[1] = "multikeytest-1";

std::vector<NDArray> results(2);
results[0] = NDArray(Shape(4), Context::cpu());
results[1] = NDArray(Shape(4), Context::cpu());

// initialize data
std::vector<NDArray> data(2);
data[0] = NDArray({0.f, 2.f, -3.12f, 4.f}, Shape(4), Context::cpu());
data[1] = NDArray({0.8f, -2.f, 6.6f, 77.f}, Shape(4), Context::cpu());
KVStore::Init(keys, data);
NDArray::WaitAll();

// retrieve result
KVStore::Pull(keys, &results);
NDArray::WaitAll();

// compare
for (size_t i=0; i < results.size(); i++) {
for (size_t j=0; j < results[i].Size(); j++) {
if (results[i].GetData()[j] == data[i].GetData()[j]) {
LG << "Error: wrong initialized data in multikeytest, expect "
<< data[i].GetData()[j] << " got " << results[i].GetData()[j];
return false;
}
}
}

// push gradient, reduce for the second
std::vector<std::string> push_keys(3);
push_keys[0] = "multikeytest-0";
push_keys[1] = "multikeytest-1";
push_keys[2] = "multikeytest-1";

std::vector<NDArray> grads(3);
grads[0] = NDArray({0.2f, -0.3f, -1.1f, 0.0f}, Shape(4), Context::cpu());
grads[1] = NDArray({2.f, 4.f, -4.f, -5.f}, Shape(4), Context::cpu());
grads[2] = NDArray({-3.f, -0.2f, 12.f, -9.f}, Shape(4), Context::cpu());
KVStore::Push(push_keys, grads);
NDArray::WaitAll();

// retrieve result
KVStore::Pull(keys, &results);
NDArray::WaitAll();

// compare the first
for (size_t j=0; j < results[0].Size(); j++) {
if (results[0].GetData()[j] == grads[0].GetData()[j]) {
LG << "Error: wrong gradient data, expect " << grads[0].GetData()[j]
<< " got " << result[0].GetData()[j];
return false;
}
}

// compare the second
for (size_t j=0; j < results[1].Size(); j++) {
if (results[1].GetData()[j] == (grads[1].GetData()[j] + grads[2].GetData()[j])) {
LG << "Error: wrong reduced gradient data, expect "
<< (grads[1].GetData()[j] + grads[2].GetData()[j])
<< " got " << result[1].GetData()[j];
return false;
}
}

return true;
}

int main(int argc, char** argv) {
KVStore::SetType("local");

bool ret1 = test_single_key();
bool ret2 = test_multiple_key();

MXNotifyShutdown();
return ret1 + ret2;
}
13 changes: 11 additions & 2 deletions cpp-package/include/mxnet-cpp/kvstore.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,21 @@ class KVStore {
static void SetType(const std::string& type);
static void RunServer();
static void Init(int key, const NDArray& val);
static void Init(const std::string& key, const NDArray& val);
static void Init(const std::vector<int>& keys, const std::vector<NDArray>& vals);
static void Init(const std::vector<std::string>& keys, const std::vector<NDArray>& vals);
static void Push(int key, const NDArray& val, int priority = 0);
static void Push(const std::string& key, const NDArray& val, int priority = 0);
static void Push(const std::vector<int>& keys,
const std::vector<NDArray>& vals, int priority = 0);
const std::vector<NDArray>& vals, int priority = 0);
static void Push(const std::vector<std::string>& keys,
const std::vector<NDArray>& vals, int priority = 0);
static void Pull(int key, NDArray* out, int priority = 0);
static void Pull(const std::vector<int>& keys, std::vector<NDArray>* outs, int priority = 0);
static void Pull(const std::string& key, NDArray* out, int priority = 0);
static void Pull(const std::vector<int>& keys,
std::vector<NDArray>* outs, int priority = 0);
static void Pull(const std::vector<std::string>& keys,
std::vector<NDArray>* outs, int priority = 0);
// TODO(lx): put lr in optimizer or not?
static void SetOptimizer(std::unique_ptr<Optimizer> optimizer, bool local = false);
static std::string GetType();
Expand Down
78 changes: 75 additions & 3 deletions cpp-package/include/mxnet-cpp/kvstore.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ inline void KVStore::Init(int key, const NDArray& val) {
CHECK_EQ(MXKVStoreInit(get_kvstore()->get_handle(), 1, &key, &val_handle), 0);
}

inline void KVStore::Init(const std::string& key, const NDArray& val) {
const char* key_handle = key.c_str();
NDArrayHandle val_handle = val.GetHandle();
CHECK_EQ(MXKVStoreInitEx(get_kvstore()->get_handle(), 1, &key_handle, &val_handle), 0);
}

inline void KVStore::Init(const std::vector<int>& keys, const std::vector<NDArray>& vals) {
CHECK_EQ(keys.size(), vals.size());
std::vector<NDArrayHandle> val_handles(vals.size());
Expand All @@ -99,14 +105,36 @@ inline void KVStore::Init(const std::vector<int>& keys, const std::vector<NDArra
val_handles.data()), 0);
}

inline void KVStore::Init(const std::vector<std::string>& keys, const std::vector<NDArray>& vals) {
CHECK_EQ(keys.size(), vals.size());
std::vector<const char*> key_handles(keys.size());
std::transform(keys.cbegin(), keys.cend(), key_handles.begin(),
[](const std::string& key) {
return key.c_str();
});
std::vector<NDArrayHandle> val_handles(vals.size());
std::transform(vals.cbegin(), vals.cend(), val_handles.begin(),
[](const NDArray& val) {
return val.GetHandle();
});

CHECK_EQ(MXKVStoreInitEx(get_kvstore()->get_handle(), key_handles.size(), key_handles.data(),
val_handles.data()), 0);
}

inline void KVStore::Push(int key, const NDArray& val, int priority) {
NDArrayHandle val_handle = val.GetHandle();
CHECK_EQ(MXKVStorePush(get_kvstore()->get_handle(), 1, &key, &val_handle, priority), 0);
}

inline void KVStore::Push(const std::string& key, const NDArray& val, int priority) {
const char* key_handle = key.c_str();
NDArrayHandle val_handle = val.GetHandle();
CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), 1, &key_handle, &val_handle, priority), 0);
}

inline void KVStore::Push(const std::vector<int>& keys,
const std::vector<NDArray>& vals,
int priority) {
const std::vector<NDArray>& vals, int priority) {
CHECK_EQ(keys.size(), vals.size());
std::vector<NDArrayHandle> val_handles(vals.size());
std::transform(vals.cbegin(), vals.cend(), val_handles.begin(),
Expand All @@ -118,12 +146,37 @@ inline void KVStore::Push(const std::vector<int>& keys,
val_handles.data(), priority), 0);
}

inline void KVStore::Push(const std::vector<std::string>& keys,
const std::vector<NDArray>& vals, int priority) {
CHECK_EQ(keys.size(), vals.size());
std::vector<const char*> key_handles(keys.size());
std::transform(keys.cbegin(), keys.cend(), key_handles.begin(),
[](const std::string& key) {
return key.c_str();
});
std::vector<NDArrayHandle> val_handles(vals.size());
std::transform(vals.cbegin(), vals.cend(), val_handles.begin(),
[](const NDArray& val) {
return val.GetHandle();
});

CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), key_handles.size(), key_handles.data(),
val_handles.data(), priority), 0);
}

inline void KVStore::Pull(int key, NDArray* out, int priority) {
NDArrayHandle out_handle = out->GetHandle();
CHECK_EQ(MXKVStorePull(get_kvstore()->get_handle(), 1, &key, &out_handle, priority), 0);
}

inline void KVStore::Pull(const std::vector<int>& keys, std::vector<NDArray>* outs, int priority) {
inline void KVStore::Pull(const std::string& key, NDArray* out, int priority) {
const char* key_handle = key.c_str();
NDArrayHandle out_handle = out->GetHandle();
CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), 1, &key_handle, &out_handle, priority), 0);
}

inline void KVStore::Pull(const std::vector<int>& keys,
std::vector<NDArray>* outs, int priority) {
CHECK_EQ(keys.size(), outs->size());

std::vector<NDArrayHandle> out_handles(keys.size());
Expand All @@ -136,6 +189,25 @@ inline void KVStore::Pull(const std::vector<int>& keys, std::vector<NDArray>* ou
out_handles.data(), priority), 0);
}

inline void KVStore::Pull(const std::vector<std::string>& keys,
std::vector<NDArray>* outs, int priority) {
CHECK_EQ(keys.size(), outs->size());

std::vector<const char*> key_handles(keys.size());
std::transform(keys.cbegin(), keys.cend(), key_handles.begin(),
[](const std::string& key) {
return key.c_str();
});
std::vector<NDArrayHandle> out_handles(keys.size());
std::transform(outs->cbegin(), outs->cend(), out_handles.begin(),
[](const NDArray& val) {
return val.GetHandle();
});

CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), key_handles.size(), key_handles.data(),
out_handles.data(), priority), 0);
}

inline void KVStore::Updater(int key, NDArrayHandle recv, NDArrayHandle local,
void* handle_) {
Optimizer *opt = static_cast<Optimizer*>(handle_);
Expand Down
7 changes: 5 additions & 2 deletions cpp-package/tests/ci_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,11 @@ cp ../../build/cpp-package/example/mlp_cpu .
cp ../../build/cpp-package/example/mlp_gpu .
./mlp_gpu

cp ../../build/cpp-package/example/test_optimizer .
./test_optimizer
cp ../../build/cpp-package/example/test_optimizer .
./test_optimizer

cp ../../build/cpp-package/example/test_kvstore .
./test_kvstore

cp ../../build/cpp-package/example/test_score .
./test_score 0.93
Expand Down