diff --git a/cpp-package/example/test_kvstore.cpp b/cpp-package/example/test_kvstore.cpp new file mode 100644 index 000000000000..f0fb8beb6719 --- /dev/null +++ b/cpp-package/example/test_kvstore.cpp @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "mxnet-cpp/MxNetCpp.h" + +using namespace mxnet::cpp; + +static bool test_single_key() { + std::string key = "singlekeytest"; + + NDArray result(Shape(4), Context::cpu()); + + // initialize data + NDArray data({0.f, 233.f, -0.12f, 9.f}, Shape(4), Context::cpu()); + KVStore::Init(key, data); + NDArray::WaitAll(); + + // retrieve result + KVStore::Pull(key, &result); + NDArray::WaitAll(); + + // compare + for (size_t j=0; j < result.Size(); j++) { + if (result.GetData()[j] != data.GetData()[j]) { + LG << "Error: wrong initialized data in singlekeytest, expect " + << data.GetData()[j] << " got " << result.GetData()[j]; + return false; + } + } + + // push gradient + NDArray grad({0.1f, -2.f, -4.4f, 0.f}, Shape(4), Context::cpu()); + KVStore::Push(key, grad); + NDArray::WaitAll(); + + // retrieve result + KVStore::Pull(key, &result); + NDArray::WaitAll(); + + // compare + for (size_t j=0; j < result.Size(); j++) { + if (result.GetData()[j] == grad.GetData()[j]) { + LG << "Error: wrong gradient data in singlekeytest, expect " + << grad.GetData()[j] << " got " << result.GetData()[j]; + return false; + } + } + + return true; +} + +static bool test_multiple_key() { + std::vector keys(2); + keys[0] = "multikeytest-0"; + keys[1] = "multikeytest-1"; + + std::vector results(2); + results[0] = NDArray(Shape(4), Context::cpu()); + results[1] = NDArray(Shape(4), Context::cpu()); + + // initialize data + std::vector data(2); + data[0] = NDArray({0.f, 2.f, -3.12f, 4.f}, Shape(4), Context::cpu()); + data[1] = NDArray({0.8f, -2.f, 6.6f, 77.f}, Shape(4), Context::cpu()); + KVStore::Init(keys, data); + NDArray::WaitAll(); + + // retrieve result + KVStore::Pull(keys, &results); + NDArray::WaitAll(); + + // compare + for (size_t i=0; i < results.size(); i++) { + for (size_t j=0; j < results[i].Size(); j++) { + if (results[i].GetData()[j] == data[i].GetData()[j]) { + LG << "Error: wrong initialized data in multikeytest, expect " + << data[i].GetData()[j] << " got " << results[i].GetData()[j]; + return false; + } + } + } + + // push gradient, reduce for the second + std::vector push_keys(3); + push_keys[0] = "multikeytest-0"; + push_keys[1] = "multikeytest-1"; + push_keys[2] = "multikeytest-1"; + + std::vector grads(3); + grads[0] = NDArray({0.2f, -0.3f, -1.1f, 0.0f}, Shape(4), Context::cpu()); + grads[1] = NDArray({2.f, 4.f, -4.f, -5.f}, Shape(4), Context::cpu()); + grads[2] = NDArray({-3.f, -0.2f, 12.f, -9.f}, Shape(4), Context::cpu()); + KVStore::Push(push_keys, grads); + NDArray::WaitAll(); + + // retrieve result + KVStore::Pull(keys, &results); + NDArray::WaitAll(); + + // compare the first + for (size_t j=0; j < results[0].Size(); j++) { + if (results[0].GetData()[j] == grads[0].GetData()[j]) { + LG << "Error: wrong gradient data, expect " << grads[0].GetData()[j] + << " got " << result[0].GetData()[j]; + return false; + } + } + + // compare the second + for (size_t j=0; j < results[1].Size(); j++) { + if (results[1].GetData()[j] == (grads[1].GetData()[j] + grads[2].GetData()[j])) { + LG << "Error: wrong reduced gradient data, expect " + << (grads[1].GetData()[j] + grads[2].GetData()[j]) + << " got " << result[1].GetData()[j]; + return false; + } + } + + return true; +} + +int main(int argc, char** argv) { + KVStore::SetType("local"); + + bool ret1 = test_single_key(); + bool ret2 = test_multiple_key(); + + MXNotifyShutdown(); + return ret1 + ret2; +} diff --git a/cpp-package/include/mxnet-cpp/kvstore.h b/cpp-package/include/mxnet-cpp/kvstore.h index d5aa1509a8f0..67f984fce0ee 100644 --- a/cpp-package/include/mxnet-cpp/kvstore.h +++ b/cpp-package/include/mxnet-cpp/kvstore.h @@ -39,12 +39,21 @@ class KVStore { static void SetType(const std::string& type); static void RunServer(); static void Init(int key, const NDArray& val); + static void Init(const std::string& key, const NDArray& val); static void Init(const std::vector& keys, const std::vector& vals); + static void Init(const std::vector& keys, const std::vector& vals); static void Push(int key, const NDArray& val, int priority = 0); + static void Push(const std::string& key, const NDArray& val, int priority = 0); static void Push(const std::vector& keys, - const std::vector& vals, int priority = 0); + const std::vector& vals, int priority = 0); + static void Push(const std::vector& keys, + const std::vector& vals, int priority = 0); static void Pull(int key, NDArray* out, int priority = 0); - static void Pull(const std::vector& keys, std::vector* outs, int priority = 0); + static void Pull(const std::string& key, NDArray* out, int priority = 0); + static void Pull(const std::vector& keys, + std::vector* outs, int priority = 0); + static void Pull(const std::vector& keys, + std::vector* outs, int priority = 0); // TODO(lx): put lr in optimizer or not? static void SetOptimizer(std::unique_ptr optimizer, bool local = false); static std::string GetType(); diff --git a/cpp-package/include/mxnet-cpp/kvstore.hpp b/cpp-package/include/mxnet-cpp/kvstore.hpp index f2b5e74990ce..6cd405b91dd4 100644 --- a/cpp-package/include/mxnet-cpp/kvstore.hpp +++ b/cpp-package/include/mxnet-cpp/kvstore.hpp @@ -87,6 +87,12 @@ inline void KVStore::Init(int key, const NDArray& val) { CHECK_EQ(MXKVStoreInit(get_kvstore()->get_handle(), 1, &key, &val_handle), 0); } +inline void KVStore::Init(const std::string& key, const NDArray& val) { + const char* key_handle = key.c_str(); + NDArrayHandle val_handle = val.GetHandle(); + CHECK_EQ(MXKVStoreInitEx(get_kvstore()->get_handle(), 1, &key_handle, &val_handle), 0); +} + inline void KVStore::Init(const std::vector& keys, const std::vector& vals) { CHECK_EQ(keys.size(), vals.size()); std::vector val_handles(vals.size()); @@ -99,14 +105,36 @@ inline void KVStore::Init(const std::vector& keys, const std::vector& keys, const std::vector& vals) { + CHECK_EQ(keys.size(), vals.size()); + std::vector key_handles(keys.size()); + std::transform(keys.cbegin(), keys.cend(), key_handles.begin(), + [](const std::string& key) { + return key.c_str(); + }); + std::vector val_handles(vals.size()); + std::transform(vals.cbegin(), vals.cend(), val_handles.begin(), + [](const NDArray& val) { + return val.GetHandle(); + }); + + CHECK_EQ(MXKVStoreInitEx(get_kvstore()->get_handle(), key_handles.size(), key_handles.data(), + val_handles.data()), 0); +} + inline void KVStore::Push(int key, const NDArray& val, int priority) { NDArrayHandle val_handle = val.GetHandle(); CHECK_EQ(MXKVStorePush(get_kvstore()->get_handle(), 1, &key, &val_handle, priority), 0); } +inline void KVStore::Push(const std::string& key, const NDArray& val, int priority) { + const char* key_handle = key.c_str(); + NDArrayHandle val_handle = val.GetHandle(); + CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), 1, &key_handle, &val_handle, priority), 0); +} + inline void KVStore::Push(const std::vector& keys, - const std::vector& vals, - int priority) { + const std::vector& vals, int priority) { CHECK_EQ(keys.size(), vals.size()); std::vector val_handles(vals.size()); std::transform(vals.cbegin(), vals.cend(), val_handles.begin(), @@ -118,12 +146,37 @@ inline void KVStore::Push(const std::vector& keys, val_handles.data(), priority), 0); } +inline void KVStore::Push(const std::vector& keys, + const std::vector& vals, int priority) { + CHECK_EQ(keys.size(), vals.size()); + std::vector key_handles(keys.size()); + std::transform(keys.cbegin(), keys.cend(), key_handles.begin(), + [](const std::string& key) { + return key.c_str(); + }); + std::vector val_handles(vals.size()); + std::transform(vals.cbegin(), vals.cend(), val_handles.begin(), + [](const NDArray& val) { + return val.GetHandle(); + }); + + CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), key_handles.size(), key_handles.data(), + val_handles.data(), priority), 0); +} + inline void KVStore::Pull(int key, NDArray* out, int priority) { NDArrayHandle out_handle = out->GetHandle(); CHECK_EQ(MXKVStorePull(get_kvstore()->get_handle(), 1, &key, &out_handle, priority), 0); } -inline void KVStore::Pull(const std::vector& keys, std::vector* outs, int priority) { +inline void KVStore::Pull(const std::string& key, NDArray* out, int priority) { + const char* key_handle = key.c_str(); + NDArrayHandle out_handle = out->GetHandle(); + CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), 1, &key_handle, &out_handle, priority), 0); +} + +inline void KVStore::Pull(const std::vector& keys, + std::vector* outs, int priority) { CHECK_EQ(keys.size(), outs->size()); std::vector out_handles(keys.size()); @@ -136,6 +189,25 @@ inline void KVStore::Pull(const std::vector& keys, std::vector* ou out_handles.data(), priority), 0); } +inline void KVStore::Pull(const std::vector& keys, + std::vector* outs, int priority) { + CHECK_EQ(keys.size(), outs->size()); + + std::vector key_handles(keys.size()); + std::transform(keys.cbegin(), keys.cend(), key_handles.begin(), + [](const std::string& key) { + return key.c_str(); + }); + std::vector out_handles(keys.size()); + std::transform(outs->cbegin(), outs->cend(), out_handles.begin(), + [](const NDArray& val) { + return val.GetHandle(); + }); + + CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), key_handles.size(), key_handles.data(), + out_handles.data(), priority), 0); +} + inline void KVStore::Updater(int key, NDArrayHandle recv, NDArrayHandle local, void* handle_) { Optimizer *opt = static_cast(handle_); diff --git a/cpp-package/tests/ci_test.sh b/cpp-package/tests/ci_test.sh index 18fabea7a7f9..2d1f8e4f68e6 100755 --- a/cpp-package/tests/ci_test.sh +++ b/cpp-package/tests/ci_test.sh @@ -48,8 +48,11 @@ cp ../../build/cpp-package/example/mlp_cpu . cp ../../build/cpp-package/example/mlp_gpu . ./mlp_gpu - cp ../../build/cpp-package/example/test_optimizer . - ./test_optimizer +cp ../../build/cpp-package/example/test_optimizer . +./test_optimizer + +cp ../../build/cpp-package/example/test_kvstore . +./test_kvstore cp ../../build/cpp-package/example/test_score . ./test_score 0.93