Skip to content

Commit

Permalink
Kvstore strkey (#2)
Browse files Browse the repository at this point in the history
* support string type for kvstore key in cpp-package

* make lines short

* fix build

* add kvstore testcase

* no rand() use

* fix cpplint sanity check

* support string type for kvstore key in cpp-package

* make lines short

* fix build

* print error log
  • Loading branch information
nihui committed Mar 29, 2019
1 parent 84c2ae1 commit 24ca621
Show file tree
Hide file tree
Showing 4 changed files with 236 additions and 7 deletions.
145 changes: 145 additions & 0 deletions cpp-package/example/test_kvstore.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "mxnet-cpp/MxNetCpp.h"

using namespace mxnet::cpp;

static bool test_single_key() {
std::string key = "singlekeytest";

NDArray result(Shape(4), Context::cpu());

// initialize data
NDArray data({0.f, 233.f, -0.12f, 9.f}, Shape(4), Context::cpu());
KVStore::Init(key, data);
NDArray::WaitAll();

// retrieve result
KVStore::Pull(key, &result);
NDArray::WaitAll();

// compare
for (size_t j=0; j < result.Size(); j++) {
if (result.GetData()[j] != data.GetData()[j]) {
LG << "Error: wrong initialized data in singlekeytest, expect "
<< data.GetData()[j] << " got " << result.GetData()[j];
return false;
}
}

// push gradient
NDArray grad({0.1f, -2.f, -4.4f, 0.f}, Shape(4), Context::cpu());
KVStore::Push(key, grad);
NDArray::WaitAll();

// retrieve result
KVStore::Pull(key, &result);
NDArray::WaitAll();

// compare
for (size_t j=0; j < result.Size(); j++) {
if (result.GetData()[j] == grad.GetData()[j]) {
LG << "Error: wrong gradient data in singlekeytest, expect "
<< grad.GetData()[j] << " got " << result.GetData()[j];
return false;
}
}

return true;
}

static bool test_multiple_key() {
std::vector<std::string> keys(2);
keys[0] = "multikeytest-0";
keys[1] = "multikeytest-1";

std::vector<NDArray> results(2);
results[0] = NDArray(Shape(4), Context::cpu());
results[1] = NDArray(Shape(4), Context::cpu());

// initialize data
std::vector<NDArray> data(2);
data[0] = NDArray({0.f, 2.f, -3.12f, 4.f}, Shape(4), Context::cpu());
data[1] = NDArray({0.8f, -2.f, 6.6f, 77.f}, Shape(4), Context::cpu());
KVStore::Init(keys, data);
NDArray::WaitAll();

// retrieve result
KVStore::Pull(keys, &results);
NDArray::WaitAll();

// compare
for (size_t i=0; i < results.size(); i++) {
for (size_t j=0; j < results[i].Size(); j++) {
if (results[i].GetData()[j] == data[i].GetData()[j]) {
LG << "Error: wrong initialized data in multikeytest, expect "
<< data[i].GetData()[j] << " got " << results[i].GetData()[j];
return false;
}
}
}

// push gradient, reduce for the second
std::vector<std::string> push_keys(3);
push_keys[0] = "multikeytest-0";
push_keys[1] = "multikeytest-1";
push_keys[2] = "multikeytest-1";

std::vector<NDArray> grads(3);
grads[0] = NDArray({0.2f, -0.3f, -1.1f, 0.0f}, Shape(4), Context::cpu());
grads[1] = NDArray({2.f, 4.f, -4.f, -5.f}, Shape(4), Context::cpu());
grads[2] = NDArray({-3.f, -0.2f, 12.f, -9.f}, Shape(4), Context::cpu());
KVStore::Push(push_keys, grads);
NDArray::WaitAll();

// retrieve result
KVStore::Pull(keys, &results);
NDArray::WaitAll();

// compare the first
for (size_t j=0; j < results[0].Size(); j++) {
if (results[0].GetData()[j] == grads[0].GetData()[j]) {
LG << "Error: wrong gradient data, expect " << grads[0].GetData()[j]
<< " got " << result[0].GetData()[j];
return false;
}
}

// compare the second
for (size_t j=0; j < results[1].Size(); j++) {
if (results[1].GetData()[j] == (grads[1].GetData()[j] + grads[2].GetData()[j])) {
LG << "Error: wrong reduced gradient data, expect "
<< (grads[1].GetData()[j] + grads[2].GetData()[j])
<< " got " << result[1].GetData()[j];
return false;
}
}

return true;
}

int main(int argc, char** argv) {
KVStore::SetType("local");

bool ret1 = test_single_key();
bool ret2 = test_multiple_key();

MXNotifyShutdown();
return ret1 + ret2;
}
13 changes: 11 additions & 2 deletions cpp-package/include/mxnet-cpp/kvstore.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,21 @@ class KVStore {
static void SetType(const std::string& type);
static void RunServer();
static void Init(int key, const NDArray& val);
static void Init(const std::string& key, const NDArray& val);
static void Init(const std::vector<int>& keys, const std::vector<NDArray>& vals);
static void Init(const std::vector<std::string>& keys, const std::vector<NDArray>& vals);
static void Push(int key, const NDArray& val, int priority = 0);
static void Push(const std::string& key, const NDArray& val, int priority = 0);
static void Push(const std::vector<int>& keys,
const std::vector<NDArray>& vals, int priority = 0);
const std::vector<NDArray>& vals, int priority = 0);
static void Push(const std::vector<std::string>& keys,
const std::vector<NDArray>& vals, int priority = 0);
static void Pull(int key, NDArray* out, int priority = 0);
static void Pull(const std::vector<int>& keys, std::vector<NDArray>* outs, int priority = 0);
static void Pull(const std::string& key, NDArray* out, int priority = 0);
static void Pull(const std::vector<int>& keys,
std::vector<NDArray>* outs, int priority = 0);
static void Pull(const std::vector<std::string>& keys,
std::vector<NDArray>* outs, int priority = 0);
// TODO(lx): put lr in optimizer or not?
static void SetOptimizer(std::unique_ptr<Optimizer> optimizer, bool local = false);
static std::string GetType();
Expand Down
78 changes: 75 additions & 3 deletions cpp-package/include/mxnet-cpp/kvstore.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ inline void KVStore::Init(int key, const NDArray& val) {
CHECK_EQ(MXKVStoreInit(get_kvstore()->get_handle(), 1, &key, &val_handle), 0);
}

inline void KVStore::Init(const std::string& key, const NDArray& val) {
const char* key_handle = key.c_str();
NDArrayHandle val_handle = val.GetHandle();
CHECK_EQ(MXKVStoreInitEx(get_kvstore()->get_handle(), 1, &key_handle, &val_handle), 0);
}

inline void KVStore::Init(const std::vector<int>& keys, const std::vector<NDArray>& vals) {
CHECK_EQ(keys.size(), vals.size());
std::vector<NDArrayHandle> val_handles(vals.size());
Expand All @@ -99,14 +105,36 @@ inline void KVStore::Init(const std::vector<int>& keys, const std::vector<NDArra
val_handles.data()), 0);
}

inline void KVStore::Init(const std::vector<std::string>& keys, const std::vector<NDArray>& vals) {
CHECK_EQ(keys.size(), vals.size());
std::vector<const char*> key_handles(keys.size());
std::transform(keys.cbegin(), keys.cend(), key_handles.begin(),
[](const std::string& key) {
return key.c_str();
});
std::vector<NDArrayHandle> val_handles(vals.size());
std::transform(vals.cbegin(), vals.cend(), val_handles.begin(),
[](const NDArray& val) {
return val.GetHandle();
});

CHECK_EQ(MXKVStoreInitEx(get_kvstore()->get_handle(), key_handles.size(), key_handles.data(),
val_handles.data()), 0);
}

inline void KVStore::Push(int key, const NDArray& val, int priority) {
NDArrayHandle val_handle = val.GetHandle();
CHECK_EQ(MXKVStorePush(get_kvstore()->get_handle(), 1, &key, &val_handle, priority), 0);
}

inline void KVStore::Push(const std::string& key, const NDArray& val, int priority) {
const char* key_handle = key.c_str();
NDArrayHandle val_handle = val.GetHandle();
CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), 1, &key_handle, &val_handle, priority), 0);
}

inline void KVStore::Push(const std::vector<int>& keys,
const std::vector<NDArray>& vals,
int priority) {
const std::vector<NDArray>& vals, int priority) {
CHECK_EQ(keys.size(), vals.size());
std::vector<NDArrayHandle> val_handles(vals.size());
std::transform(vals.cbegin(), vals.cend(), val_handles.begin(),
Expand All @@ -118,12 +146,37 @@ inline void KVStore::Push(const std::vector<int>& keys,
val_handles.data(), priority), 0);
}

inline void KVStore::Push(const std::vector<std::string>& keys,
const std::vector<NDArray>& vals, int priority) {
CHECK_EQ(keys.size(), vals.size());
std::vector<const char*> key_handles(keys.size());
std::transform(keys.cbegin(), keys.cend(), key_handles.begin(),
[](const std::string& key) {
return key.c_str();
});
std::vector<NDArrayHandle> val_handles(vals.size());
std::transform(vals.cbegin(), vals.cend(), val_handles.begin(),
[](const NDArray& val) {
return val.GetHandle();
});

CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), key_handles.size(), key_handles.data(),
val_handles.data(), priority), 0);
}

inline void KVStore::Pull(int key, NDArray* out, int priority) {
NDArrayHandle out_handle = out->GetHandle();
CHECK_EQ(MXKVStorePull(get_kvstore()->get_handle(), 1, &key, &out_handle, priority), 0);
}

inline void KVStore::Pull(const std::vector<int>& keys, std::vector<NDArray>* outs, int priority) {
inline void KVStore::Pull(const std::string& key, NDArray* out, int priority) {
const char* key_handle = key.c_str();
NDArrayHandle out_handle = out->GetHandle();
CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), 1, &key_handle, &out_handle, priority), 0);
}

inline void KVStore::Pull(const std::vector<int>& keys,
std::vector<NDArray>* outs, int priority) {
CHECK_EQ(keys.size(), outs->size());

std::vector<NDArrayHandle> out_handles(keys.size());
Expand All @@ -136,6 +189,25 @@ inline void KVStore::Pull(const std::vector<int>& keys, std::vector<NDArray>* ou
out_handles.data(), priority), 0);
}

inline void KVStore::Pull(const std::vector<std::string>& keys,
std::vector<NDArray>* outs, int priority) {
CHECK_EQ(keys.size(), outs->size());

std::vector<const char*> key_handles(keys.size());
std::transform(keys.cbegin(), keys.cend(), key_handles.begin(),
[](const std::string& key) {
return key.c_str();
});
std::vector<NDArrayHandle> out_handles(keys.size());
std::transform(outs->cbegin(), outs->cend(), out_handles.begin(),
[](const NDArray& val) {
return val.GetHandle();
});

CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), key_handles.size(), key_handles.data(),
out_handles.data(), priority), 0);
}

inline void KVStore::Updater(int key, NDArrayHandle recv, NDArrayHandle local,
void* handle_) {
Optimizer *opt = static_cast<Optimizer*>(handle_);
Expand Down
7 changes: 5 additions & 2 deletions cpp-package/tests/ci_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,11 @@ cp ../../build/cpp-package/example/mlp_cpu .
cp ../../build/cpp-package/example/mlp_gpu .
./mlp_gpu

cp ../../build/cpp-package/example/test_optimizer .
./test_optimizer
cp ../../build/cpp-package/example/test_optimizer .
./test_optimizer

cp ../../build/cpp-package/example/test_kvstore .
./test_kvstore

cp ../../build/cpp-package/example/test_score .
./test_score 0.93
Expand Down

0 comments on commit 24ca621

Please sign in to comment.