diff --git a/cpp-package/example/test_kvstore.cpp b/cpp-package/example/test_kvstore.cpp new file mode 100644 index 000000000000..d9e0400a5ac8 --- /dev/null +++ b/cpp-package/example/test_kvstore.cpp @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "mxnet/c_api.h" // MXGetGPUCount() +#include "mxnet-cpp/MxNetCpp.h" + +using namespace mxnet::cpp; + +static bool test_single_key(const Context &context, const std::string &context_str) { + std::string key = "singlekeytest-" + context_str; + + NDArray result(Shape(4), context); + NDArray result_cpu; + + // initialize data + NDArray data_cpu({0.f, 233.f, -0.12f, 9.f}, Shape(4), Context::cpu()); + NDArray data = data_cpu.Copy(context); + NDArray::WaitAll(); + + KVStore::Init(key, data); + NDArray::WaitAll(); + + // retrieve result + KVStore::Pull(key, &result); + NDArray::WaitAll(); + + result_cpu = result.Copy(Context::cpu()); + NDArray::WaitAll(); + + // compare + for (size_t j=0; j < result_cpu.Size(); j++) { + if (result_cpu.GetData()[j] != data_cpu.GetData()[j]) { + LG << "Error: wrong initialized data in singlekeytest-" << context_str + << ", expect " << data_cpu.GetData()[j] + << " got " << result_cpu.GetData()[j]; + return false; + } + } + + // push gradient + NDArray grad_cpu({0.1f, -2.f, -4.4f, 0.f}, Shape(4), Context::cpu()); + NDArray grad = grad_cpu.Copy(context); + NDArray::WaitAll(); + + KVStore::Push(key, grad); + NDArray::WaitAll(); + + // retrieve result + KVStore::Pull(key, &result); + NDArray::WaitAll(); + + result_cpu = result.Copy(Context::cpu()); + NDArray::WaitAll(); + + // compare + for (size_t j=0; j < result_cpu.Size(); j++) { + if (result_cpu.GetData()[j] != grad_cpu.GetData()[j]) { + LG << "Error: wrong gradient data in singlekeytest-" << context_str + << ", expect " << grad_cpu.GetData()[j] + << " got " << result_cpu.GetData()[j]; + return false; + } + } + + return true; +} + +static bool test_multiple_key(const Context &context, const std::string &context_str) { + std::vector keys(2); + keys[0] = "multikeytest-0-" + context_str; + keys[1] = "multikeytest-1-" + context_str; + + std::vector results(2); + results[0] = NDArray(Shape(4), context); + results[1] = NDArray(Shape(4), context); + std::vector results_cpu(2); + + // initialize data + std::vector data_cpu(2); + data_cpu[0] = NDArray({0.f, 2.f, -3.12f, 4.f}, Shape(4), Context::cpu()); + data_cpu[1] = NDArray({0.8f, -2.f, 6.6f, 77.f}, Shape(4), Context::cpu()); + std::vector data(2); + data[0] = data_cpu[0].Copy(context); + data[1] = data_cpu[1].Copy(context); + NDArray::WaitAll(); + + KVStore::Init(keys, data); + NDArray::WaitAll(); + + // retrieve result + KVStore::Pull(keys, &results); + NDArray::WaitAll(); + + results_cpu[0] = results[0].Copy(Context::cpu()); + results_cpu[1] = results[1].Copy(Context::cpu()); + NDArray::WaitAll(); + + // compare + for (size_t i=0; i < results_cpu.size(); i++) { + for (size_t j=0; j < results_cpu[i].Size(); j++) { + if (results_cpu[i].GetData()[j] != data_cpu[i].GetData()[j]) { + LG << "Error: wrong initialized data in multikeytest-" << context_str + << ", expect " << data_cpu[i].GetData()[j] + << " got " << results_cpu[i].GetData()[j]; + return false; + } + } + } + + // push gradient, reduce for the second + std::vector push_keys(3); + push_keys[0] = "multikeytest-0-" + context_str; + push_keys[1] = "multikeytest-1-" + context_str; + push_keys[2] = "multikeytest-1-" + context_str; + + std::vector grads_cpu(3); + grads_cpu[0] = NDArray({0.2f, -0.3f, -1.1f, 0.0f}, Shape(4), Context::cpu()); + grads_cpu[1] = NDArray({2.f, 4.f, -4.f, -5.f}, Shape(4), Context::cpu()); + grads_cpu[2] = NDArray({-3.f, -0.2f, 12.f, -9.f}, Shape(4), Context::cpu()); + std::vector grads(3); + grads[0] = grads_cpu[0].Copy(context); + grads[1] = grads_cpu[1].Copy(context); + grads[2] = grads_cpu[2].Copy(context); + NDArray::WaitAll(); + + KVStore::Push(push_keys, grads); + NDArray::WaitAll(); + + // retrieve result + KVStore::Pull(keys, &results); + NDArray::WaitAll(); + + results_cpu[0] = results[0].Copy(Context::cpu()); + results_cpu[1] = results[1].Copy(Context::cpu()); + NDArray::WaitAll(); + + // compare the first + for (size_t j=0; j < results_cpu[0].Size(); j++) { + if (results_cpu[0].GetData()[j] != grads_cpu[0].GetData()[j]) { + LG << "Error: wrong gradient data in multikeytest-" << context_str + << ", expect " << grads_cpu[0].GetData()[j] + << " got " << results_cpu[0].GetData()[j]; + return false; + } + } + + // compare the second + for (size_t j=0; j < results_cpu[1].Size(); j++) { + if (results_cpu[1].GetData()[j] != (grads_cpu[1].GetData()[j] + grads_cpu[2].GetData()[j])) { + LG << "Error: wrong reduced gradient data in multikeytest-" << context_str + << ", expect " << (grads_cpu[1].GetData()[j] + grads_cpu[2].GetData()[j]) + << " got " << results_cpu[1].GetData()[j]; + return false; + } + } + + return true; +} + +int main(int argc, char** argv) { + KVStore::SetType("local"); + + bool success1 = test_single_key(Context::cpu(), "cpu"); + bool success2 = test_multiple_key(Context::cpu(), "cpu"); + + bool success3 = true; + bool success4 = true; + + int gpu_count = 0; + if (MXGetGPUCount(&gpu_count) != 0) { + LG << "Error: MXGetGPUCount"; + + MXNotifyShutdown(); + return 1; + } + + if (gpu_count > 0) { + success3 = test_single_key(Context::gpu(), "gpu"); + success4 = test_multiple_key(Context::gpu(), "gpu"); + } + + int ret = (success1 && success2 && success3 && success4) ? 0 : 1; + + MXNotifyShutdown(); + return ret; +} diff --git a/cpp-package/include/mxnet-cpp/kvstore.h b/cpp-package/include/mxnet-cpp/kvstore.h index d5aa1509a8f0..67f984fce0ee 100644 --- a/cpp-package/include/mxnet-cpp/kvstore.h +++ b/cpp-package/include/mxnet-cpp/kvstore.h @@ -39,12 +39,21 @@ class KVStore { static void SetType(const std::string& type); static void RunServer(); static void Init(int key, const NDArray& val); + static void Init(const std::string& key, const NDArray& val); static void Init(const std::vector& keys, const std::vector& vals); + static void Init(const std::vector& keys, const std::vector& vals); static void Push(int key, const NDArray& val, int priority = 0); + static void Push(const std::string& key, const NDArray& val, int priority = 0); static void Push(const std::vector& keys, - const std::vector& vals, int priority = 0); + const std::vector& vals, int priority = 0); + static void Push(const std::vector& keys, + const std::vector& vals, int priority = 0); static void Pull(int key, NDArray* out, int priority = 0); - static void Pull(const std::vector& keys, std::vector* outs, int priority = 0); + static void Pull(const std::string& key, NDArray* out, int priority = 0); + static void Pull(const std::vector& keys, + std::vector* outs, int priority = 0); + static void Pull(const std::vector& keys, + std::vector* outs, int priority = 0); // TODO(lx): put lr in optimizer or not? static void SetOptimizer(std::unique_ptr optimizer, bool local = false); static std::string GetType(); diff --git a/cpp-package/include/mxnet-cpp/kvstore.hpp b/cpp-package/include/mxnet-cpp/kvstore.hpp index f2b5e74990ce..6cd405b91dd4 100644 --- a/cpp-package/include/mxnet-cpp/kvstore.hpp +++ b/cpp-package/include/mxnet-cpp/kvstore.hpp @@ -87,6 +87,12 @@ inline void KVStore::Init(int key, const NDArray& val) { CHECK_EQ(MXKVStoreInit(get_kvstore()->get_handle(), 1, &key, &val_handle), 0); } +inline void KVStore::Init(const std::string& key, const NDArray& val) { + const char* key_handle = key.c_str(); + NDArrayHandle val_handle = val.GetHandle(); + CHECK_EQ(MXKVStoreInitEx(get_kvstore()->get_handle(), 1, &key_handle, &val_handle), 0); +} + inline void KVStore::Init(const std::vector& keys, const std::vector& vals) { CHECK_EQ(keys.size(), vals.size()); std::vector val_handles(vals.size()); @@ -99,14 +105,36 @@ inline void KVStore::Init(const std::vector& keys, const std::vector& keys, const std::vector& vals) { + CHECK_EQ(keys.size(), vals.size()); + std::vector key_handles(keys.size()); + std::transform(keys.cbegin(), keys.cend(), key_handles.begin(), + [](const std::string& key) { + return key.c_str(); + }); + std::vector val_handles(vals.size()); + std::transform(vals.cbegin(), vals.cend(), val_handles.begin(), + [](const NDArray& val) { + return val.GetHandle(); + }); + + CHECK_EQ(MXKVStoreInitEx(get_kvstore()->get_handle(), key_handles.size(), key_handles.data(), + val_handles.data()), 0); +} + inline void KVStore::Push(int key, const NDArray& val, int priority) { NDArrayHandle val_handle = val.GetHandle(); CHECK_EQ(MXKVStorePush(get_kvstore()->get_handle(), 1, &key, &val_handle, priority), 0); } +inline void KVStore::Push(const std::string& key, const NDArray& val, int priority) { + const char* key_handle = key.c_str(); + NDArrayHandle val_handle = val.GetHandle(); + CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), 1, &key_handle, &val_handle, priority), 0); +} + inline void KVStore::Push(const std::vector& keys, - const std::vector& vals, - int priority) { + const std::vector& vals, int priority) { CHECK_EQ(keys.size(), vals.size()); std::vector val_handles(vals.size()); std::transform(vals.cbegin(), vals.cend(), val_handles.begin(), @@ -118,12 +146,37 @@ inline void KVStore::Push(const std::vector& keys, val_handles.data(), priority), 0); } +inline void KVStore::Push(const std::vector& keys, + const std::vector& vals, int priority) { + CHECK_EQ(keys.size(), vals.size()); + std::vector key_handles(keys.size()); + std::transform(keys.cbegin(), keys.cend(), key_handles.begin(), + [](const std::string& key) { + return key.c_str(); + }); + std::vector val_handles(vals.size()); + std::transform(vals.cbegin(), vals.cend(), val_handles.begin(), + [](const NDArray& val) { + return val.GetHandle(); + }); + + CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), key_handles.size(), key_handles.data(), + val_handles.data(), priority), 0); +} + inline void KVStore::Pull(int key, NDArray* out, int priority) { NDArrayHandle out_handle = out->GetHandle(); CHECK_EQ(MXKVStorePull(get_kvstore()->get_handle(), 1, &key, &out_handle, priority), 0); } -inline void KVStore::Pull(const std::vector& keys, std::vector* outs, int priority) { +inline void KVStore::Pull(const std::string& key, NDArray* out, int priority) { + const char* key_handle = key.c_str(); + NDArrayHandle out_handle = out->GetHandle(); + CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), 1, &key_handle, &out_handle, priority), 0); +} + +inline void KVStore::Pull(const std::vector& keys, + std::vector* outs, int priority) { CHECK_EQ(keys.size(), outs->size()); std::vector out_handles(keys.size()); @@ -136,6 +189,25 @@ inline void KVStore::Pull(const std::vector& keys, std::vector* ou out_handles.data(), priority), 0); } +inline void KVStore::Pull(const std::vector& keys, + std::vector* outs, int priority) { + CHECK_EQ(keys.size(), outs->size()); + + std::vector key_handles(keys.size()); + std::transform(keys.cbegin(), keys.cend(), key_handles.begin(), + [](const std::string& key) { + return key.c_str(); + }); + std::vector out_handles(keys.size()); + std::transform(outs->cbegin(), outs->cend(), out_handles.begin(), + [](const NDArray& val) { + return val.GetHandle(); + }); + + CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), key_handles.size(), key_handles.data(), + out_handles.data(), priority), 0); +} + inline void KVStore::Updater(int key, NDArrayHandle recv, NDArrayHandle local, void* handle_) { Optimizer *opt = static_cast(handle_); diff --git a/cpp-package/tests/ci_test.sh b/cpp-package/tests/ci_test.sh index 18fabea7a7f9..2d1f8e4f68e6 100755 --- a/cpp-package/tests/ci_test.sh +++ b/cpp-package/tests/ci_test.sh @@ -48,8 +48,11 @@ cp ../../build/cpp-package/example/mlp_cpu . cp ../../build/cpp-package/example/mlp_gpu . ./mlp_gpu - cp ../../build/cpp-package/example/test_optimizer . - ./test_optimizer +cp ../../build/cpp-package/example/test_optimizer . +./test_optimizer + +cp ../../build/cpp-package/example/test_kvstore . +./test_kvstore cp ../../build/cpp-package/example/test_score . ./test_score 0.93