From 0791cd145b6708a3dfff6bdd976916dc1bfb2241 Mon Sep 17 00:00:00 2001 From: nihuini Date: Thu, 3 May 2018 12:41:53 +0800 Subject: [PATCH 01/10] support string type for kvstore key in cpp-package --- cpp-package/include/mxnet-cpp/kvstore.h | 9 ++- cpp-package/include/mxnet-cpp/kvstore.hpp | 71 ++++++++++++++++++++++- 2 files changed, 75 insertions(+), 5 deletions(-) diff --git a/cpp-package/include/mxnet-cpp/kvstore.h b/cpp-package/include/mxnet-cpp/kvstore.h index d5aa1509a8f0..31bbc17ea169 100644 --- a/cpp-package/include/mxnet-cpp/kvstore.h +++ b/cpp-package/include/mxnet-cpp/kvstore.h @@ -39,12 +39,17 @@ class KVStore { static void SetType(const std::string& type); static void RunServer(); static void Init(int key, const NDArray& val); + static void Init(const std::string& key, const NDArray& val); static void Init(const std::vector& keys, const std::vector& vals); + static void Init(const std::vector& keys, const std::vector& vals); static void Push(int key, const NDArray& val, int priority = 0); - static void Push(const std::vector& keys, - const std::vector& vals, int priority = 0); + static void Push(const std::string& key, const NDArray& val, int priority = 0); + static void Push(const std::vector& keys, const std::vector& vals, int priority = 0); + static void Push(const std::vector& keys, const std::vector& vals, int priority = 0); static void Pull(int key, NDArray* out, int priority = 0); + static void Pull(const std::string& key, NDArray* out, int priority = 0); static void Pull(const std::vector& keys, std::vector* outs, int priority = 0); + static void Pull(const std::vector& keys, std::vector* outs, int priority = 0); // TODO(lx): put lr in optimizer or not? static void SetOptimizer(std::unique_ptr optimizer, bool local = false); static std::string GetType(); diff --git a/cpp-package/include/mxnet-cpp/kvstore.hpp b/cpp-package/include/mxnet-cpp/kvstore.hpp index f2b5e74990ce..413b52bd5fc1 100644 --- a/cpp-package/include/mxnet-cpp/kvstore.hpp +++ b/cpp-package/include/mxnet-cpp/kvstore.hpp @@ -87,6 +87,11 @@ inline void KVStore::Init(int key, const NDArray& val) { CHECK_EQ(MXKVStoreInit(get_kvstore()->get_handle(), 1, &key, &val_handle), 0); } +inline void KVStore::Init(const std::string& key, const NDArray& val) { + NDArrayHandle val_handle = val.GetHandle(); + CHECK_EQ(MXKVStoreInitEx(get_kvstore()->get_handle(), 1, &key.c_str(), &val_handle), 0); +} + inline void KVStore::Init(const std::vector& keys, const std::vector& vals) { CHECK_EQ(keys.size(), vals.size()); std::vector val_handles(vals.size()); @@ -99,14 +104,34 @@ inline void KVStore::Init(const std::vector& keys, const std::vector& keys, const std::vector& vals) { + CHECK_EQ(keys.size(), vals.size()); + std::vector key_handles(keys.size()); + std::transform(keys.cbegin(), keys.cend(), key_handles.begin(), + [](const std::string& key) { + return key.c_str(); + }); + std::vector val_handles(vals.size()); + std::transform(vals.cbegin(), vals.cend(), val_handles.begin(), + [](const NDArray& val) { + return val.GetHandle(); + }); + + CHECK_EQ(MXKVStoreInitEx(get_kvstore()->get_handle(), key_handles.size(), key_handles.data(), + val_handles.data()), 0); +} + inline void KVStore::Push(int key, const NDArray& val, int priority) { NDArrayHandle val_handle = val.GetHandle(); CHECK_EQ(MXKVStorePush(get_kvstore()->get_handle(), 1, &key, &val_handle, priority), 0); } -inline void KVStore::Push(const std::vector& keys, - const std::vector& vals, - int priority) { +inline void KVStore::Push(const std::string& key, const NDArray& val, int priority) { + NDArrayHandle val_handle = val.GetHandle(); + CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), 1, &key.c_str(), &val_handle, priority), 0); +} + +inline void KVStore::Push(const std::vector& keys, const std::vector& vals, int priority) { CHECK_EQ(keys.size(), vals.size()); std::vector val_handles(vals.size()); std::transform(vals.cbegin(), vals.cend(), val_handles.begin(), @@ -118,11 +143,33 @@ inline void KVStore::Push(const std::vector& keys, val_handles.data(), priority), 0); } +inline void KVStore::Push(const std::vector& keys, const std::vector& vals, int priority) { + CHECK_EQ(keys.size(), vals.size()); + std::vector key_handles(keys.size()); + std::transform(keys.cbegin(), keys.cend(), key_handles.begin(), + [](const std::string& key) { + return key.c_str(); + }); + std::vector val_handles(vals.size()); + std::transform(vals.cbegin(), vals.cend(), val_handles.begin(), + [](const NDArray& val) { + return val.GetHandle(); + }); + + CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), key_handles.size(), key_handles.data(), + val_handles.data(), priority), 0); +} + inline void KVStore::Pull(int key, NDArray* out, int priority) { NDArrayHandle out_handle = out->GetHandle(); CHECK_EQ(MXKVStorePull(get_kvstore()->get_handle(), 1, &key, &out_handle, priority), 0); } +inline void KVStore::Pull(const std::string& key, NDArray* out, int priority) { + NDArrayHandle out_handle = out->GetHandle(); + CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), 1, &key.c_str(), &out_handle, priority), 0); +} + inline void KVStore::Pull(const std::vector& keys, std::vector* outs, int priority) { CHECK_EQ(keys.size(), outs->size()); @@ -136,6 +183,24 @@ inline void KVStore::Pull(const std::vector& keys, std::vector* ou out_handles.data(), priority), 0); } +inline void KVStore::Pull(const std::vector& keys, std::vector* outs, int priority) { + CHECK_EQ(keys.size(), outs->size()); + + std::vector key_handles(keys.size()); + std::transform(keys.cbegin(), keys.cend(), key_handles.begin(), + [](const std::string& key) { + return key.c_str(); + }); + std::vector out_handles(keys.size()); + std::transform(outs->cbegin(), outs->cend(), out_handles.begin(), + [](const NDArray& val) { + return val.GetHandle(); + }); + + CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), key_handles.size(), key_handles.data(), + out_handles.data(), priority), 0); +} + inline void KVStore::Updater(int key, NDArrayHandle recv, NDArrayHandle local, void* handle_) { Optimizer *opt = static_cast(handle_); From f0064c2099d13e2d313a3f72348f3e33b8907d41 Mon Sep 17 00:00:00 2001 From: nihuini Date: Thu, 3 May 2018 13:24:04 +0800 Subject: [PATCH 02/10] make lines short --- cpp-package/include/mxnet-cpp/kvstore.h | 12 ++++++++---- cpp-package/include/mxnet-cpp/kvstore.hpp | 18 ++++++++++++------ 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/cpp-package/include/mxnet-cpp/kvstore.h b/cpp-package/include/mxnet-cpp/kvstore.h index 31bbc17ea169..67f984fce0ee 100644 --- a/cpp-package/include/mxnet-cpp/kvstore.h +++ b/cpp-package/include/mxnet-cpp/kvstore.h @@ -44,12 +44,16 @@ class KVStore { static void Init(const std::vector& keys, const std::vector& vals); static void Push(int key, const NDArray& val, int priority = 0); static void Push(const std::string& key, const NDArray& val, int priority = 0); - static void Push(const std::vector& keys, const std::vector& vals, int priority = 0); - static void Push(const std::vector& keys, const std::vector& vals, int priority = 0); + static void Push(const std::vector& keys, + const std::vector& vals, int priority = 0); + static void Push(const std::vector& keys, + const std::vector& vals, int priority = 0); static void Pull(int key, NDArray* out, int priority = 0); static void Pull(const std::string& key, NDArray* out, int priority = 0); - static void Pull(const std::vector& keys, std::vector* outs, int priority = 0); - static void Pull(const std::vector& keys, std::vector* outs, int priority = 0); + static void Pull(const std::vector& keys, + std::vector* outs, int priority = 0); + static void Pull(const std::vector& keys, + std::vector* outs, int priority = 0); // TODO(lx): put lr in optimizer or not? static void SetOptimizer(std::unique_ptr optimizer, bool local = false); static std::string GetType(); diff --git a/cpp-package/include/mxnet-cpp/kvstore.hpp b/cpp-package/include/mxnet-cpp/kvstore.hpp index 413b52bd5fc1..d5d8e5f13a22 100644 --- a/cpp-package/include/mxnet-cpp/kvstore.hpp +++ b/cpp-package/include/mxnet-cpp/kvstore.hpp @@ -128,10 +128,12 @@ inline void KVStore::Push(int key, const NDArray& val, int priority) { inline void KVStore::Push(const std::string& key, const NDArray& val, int priority) { NDArrayHandle val_handle = val.GetHandle(); - CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), 1, &key.c_str(), &val_handle, priority), 0); + CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), 1, &key.c_str(), + &val_handle, priority), 0); } -inline void KVStore::Push(const std::vector& keys, const std::vector& vals, int priority) { +inline void KVStore::Push(const std::vector& keys, + const std::vector& vals, int priority) { CHECK_EQ(keys.size(), vals.size()); std::vector val_handles(vals.size()); std::transform(vals.cbegin(), vals.cend(), val_handles.begin(), @@ -143,7 +145,8 @@ inline void KVStore::Push(const std::vector& keys, const std::vector& keys, const std::vector& vals, int priority) { +inline void KVStore::Push(const std::vector& keys, + const std::vector& vals, int priority) { CHECK_EQ(keys.size(), vals.size()); std::vector key_handles(keys.size()); std::transform(keys.cbegin(), keys.cend(), key_handles.begin(), @@ -167,10 +170,12 @@ inline void KVStore::Pull(int key, NDArray* out, int priority) { inline void KVStore::Pull(const std::string& key, NDArray* out, int priority) { NDArrayHandle out_handle = out->GetHandle(); - CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), 1, &key.c_str(), &out_handle, priority), 0); + CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), 1, &key.c_str(), + &out_handle, priority), 0); } -inline void KVStore::Pull(const std::vector& keys, std::vector* outs, int priority) { +inline void KVStore::Pull(const std::vector& keys, + std::vector* outs, int priority) { CHECK_EQ(keys.size(), outs->size()); std::vector out_handles(keys.size()); @@ -183,7 +188,8 @@ inline void KVStore::Pull(const std::vector& keys, std::vector* ou out_handles.data(), priority), 0); } -inline void KVStore::Pull(const std::vector& keys, std::vector* outs, int priority) { +inline void KVStore::Pull(const std::vector& keys, + std::vector* outs, int priority) { CHECK_EQ(keys.size(), outs->size()); std::vector key_handles(keys.size()); From 5d4609a576b52e7ecf3d00181072ac5ce7766176 Mon Sep 17 00:00:00 2001 From: nihuini Date: Thu, 3 May 2018 14:02:21 +0800 Subject: [PATCH 03/10] fix build --- cpp-package/include/mxnet-cpp/kvstore.hpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/cpp-package/include/mxnet-cpp/kvstore.hpp b/cpp-package/include/mxnet-cpp/kvstore.hpp index d5d8e5f13a22..6cd405b91dd4 100644 --- a/cpp-package/include/mxnet-cpp/kvstore.hpp +++ b/cpp-package/include/mxnet-cpp/kvstore.hpp @@ -88,8 +88,9 @@ inline void KVStore::Init(int key, const NDArray& val) { } inline void KVStore::Init(const std::string& key, const NDArray& val) { + const char* key_handle = key.c_str(); NDArrayHandle val_handle = val.GetHandle(); - CHECK_EQ(MXKVStoreInitEx(get_kvstore()->get_handle(), 1, &key.c_str(), &val_handle), 0); + CHECK_EQ(MXKVStoreInitEx(get_kvstore()->get_handle(), 1, &key_handle, &val_handle), 0); } inline void KVStore::Init(const std::vector& keys, const std::vector& vals) { @@ -127,9 +128,9 @@ inline void KVStore::Push(int key, const NDArray& val, int priority) { } inline void KVStore::Push(const std::string& key, const NDArray& val, int priority) { + const char* key_handle = key.c_str(); NDArrayHandle val_handle = val.GetHandle(); - CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), 1, &key.c_str(), - &val_handle, priority), 0); + CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), 1, &key_handle, &val_handle, priority), 0); } inline void KVStore::Push(const std::vector& keys, @@ -169,9 +170,9 @@ inline void KVStore::Pull(int key, NDArray* out, int priority) { } inline void KVStore::Pull(const std::string& key, NDArray* out, int priority) { + const char* key_handle = key.c_str(); NDArrayHandle out_handle = out->GetHandle(); - CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), 1, &key.c_str(), - &out_handle, priority), 0); + CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), 1, &key_handle, &out_handle, priority), 0); } inline void KVStore::Pull(const std::vector& keys, From 08c0f8f50903d4acd298e9c946c5286be8596560 Mon Sep 17 00:00:00 2001 From: nihuini Date: Thu, 15 Nov 2018 11:03:20 +0800 Subject: [PATCH 04/10] add kvstore testcase --- Jenkinsfile | 2 +- cpp-package/example/test_kvstore.cpp | 149 +++++++++++++++++++++++++++ cpp-package/tests/ci_test.sh | 7 +- 3 files changed, 155 insertions(+), 3 deletions(-) create mode 100644 cpp-package/example/test_kvstore.cpp diff --git a/Jenkinsfile b/Jenkinsfile index 3f72843596e7..131c983cf51b 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -35,7 +35,7 @@ mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-c mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, build/3rdparty/mkldnn/src/libmkldnn.so.0' mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libiomp5.so, lib/libmkldnn.so.0, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' mx_tensorrt_lib = 'lib/libmxnet.so, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so' -mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/lenet, build/cpp-package/example/alexnet, build/cpp-package/example/googlenet, build/cpp-package/example/lenet_with_mxdataiter, build/cpp-package/example/resnet, build/cpp-package/example/mlp, build/cpp-package/example/mlp_cpu, build/cpp-package/example/mlp_gpu, build/cpp-package/example/test_score, build/cpp-package/example/test_optimizer' +mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/lenet, build/cpp-package/example/alexnet, build/cpp-package/example/googlenet, build/cpp-package/example/lenet_with_mxdataiter, build/cpp-package/example/resnet, build/cpp-package/example/mlp, build/cpp-package/example/mlp_cpu, build/cpp-package/example/mlp_gpu, build/cpp-package/example/test_score, build/cpp-package/example/test_optimizer, build/cpp-package/example/test_kvstore' mx_lib_cpp_examples_cpu = 'build/libmxnet.so, build/cpp-package/example/mlp_cpu' // timeout in minutes diff --git a/cpp-package/example/test_kvstore.cpp b/cpp-package/example/test_kvstore.cpp new file mode 100644 index 000000000000..34c76aecce77 --- /dev/null +++ b/cpp-package/example/test_kvstore.cpp @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "mxnet-cpp/MxNetCpp.h" +#include + +using namespace mxnet::cpp; + +static int test_single_key() { + int ret = 0; + + std::string key = "singlekeytest"; + + NDArray result(Shape(4), Context::cpu()); + + // initialize data + NDArray data(Shape(4), Context::cpu()); + for (size_t j=0; j keys(2); + keys[0] = "multikeytest-0"; + keys[1] = "multikeytest-1"; + + std::vector results(2); + results[0] = NDArray(Shape(10), Context::cpu()); + results[1] = NDArray(Shape(10), Context::cpu()); + + // initialize data + std::vector datas(2); + datas[0] = NDArray(Shape(10), Context::cpu()); + datas[1] = NDArray(Shape(10), Context::cpu()); + for (size_t i=0; i push_keys(3); + push_keys[0] = "multikeytest-0"; + push_keys[1] = "multikeytest-1"; + push_keys[2] = "multikeytest-1"; + + std::vector grads(3); + grads[0] = NDArray(Shape(10), Context::cpu()); + grads[1] = NDArray(Shape(10), Context::cpu()); + grads[2] = NDArray(Shape(10), Context::cpu()); + for (size_t i=0; i Date: Thu, 15 Nov 2018 11:24:34 +0800 Subject: [PATCH 05/10] no rand() use --- cpp-package/example/test_kvstore.cpp | 35 +++++++--------------------- 1 file changed, 9 insertions(+), 26 deletions(-) diff --git a/cpp-package/example/test_kvstore.cpp b/cpp-package/example/test_kvstore.cpp index 34c76aecce77..ed28f2b9cea3 100644 --- a/cpp-package/example/test_kvstore.cpp +++ b/cpp-package/example/test_kvstore.cpp @@ -17,7 +17,6 @@ * under the License. */ #include "mxnet-cpp/MxNetCpp.h" -#include using namespace mxnet::cpp; @@ -29,10 +28,7 @@ static int test_single_key() { NDArray result(Shape(4), Context::cpu()); // initialize data - NDArray data(Shape(4), Context::cpu()); - for (size_t j=0; j results(2); - results[0] = NDArray(Shape(10), Context::cpu()); - results[1] = NDArray(Shape(10), Context::cpu()); + results[0] = NDArray(Shape(4), Context::cpu()); + results[1] = NDArray(Shape(4), Context::cpu()); // initialize data std::vector datas(2); - datas[0] = NDArray(Shape(10), Context::cpu()); - datas[1] = NDArray(Shape(10), Context::cpu()); - for (size_t i=0; i grads(3); - grads[0] = NDArray(Shape(10), Context::cpu()); - grads[1] = NDArray(Shape(10), Context::cpu()); - grads[2] = NDArray(Shape(10), Context::cpu()); - for (size_t i=0; i Date: Thu, 15 Nov 2018 11:27:50 +0800 Subject: [PATCH 06/10] fix cpplint sanity check --- cpp-package/example/test_kvstore.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cpp-package/example/test_kvstore.cpp b/cpp-package/example/test_kvstore.cpp index ed28f2b9cea3..bb81e3a3031a 100644 --- a/cpp-package/example/test_kvstore.cpp +++ b/cpp-package/example/test_kvstore.cpp @@ -37,7 +37,7 @@ static int test_single_key() { NDArray::WaitAll(); // compare - for (size_t j=0; j Date: Thu, 3 May 2018 12:41:53 +0800 Subject: [PATCH 07/10] support string type for kvstore key in cpp-package --- cpp-package/include/mxnet-cpp/kvstore.h | 9 ++- cpp-package/include/mxnet-cpp/kvstore.hpp | 71 ++++++++++++++++++++++- 2 files changed, 75 insertions(+), 5 deletions(-) diff --git a/cpp-package/include/mxnet-cpp/kvstore.h b/cpp-package/include/mxnet-cpp/kvstore.h index d5aa1509a8f0..31bbc17ea169 100644 --- a/cpp-package/include/mxnet-cpp/kvstore.h +++ b/cpp-package/include/mxnet-cpp/kvstore.h @@ -39,12 +39,17 @@ class KVStore { static void SetType(const std::string& type); static void RunServer(); static void Init(int key, const NDArray& val); + static void Init(const std::string& key, const NDArray& val); static void Init(const std::vector& keys, const std::vector& vals); + static void Init(const std::vector& keys, const std::vector& vals); static void Push(int key, const NDArray& val, int priority = 0); - static void Push(const std::vector& keys, - const std::vector& vals, int priority = 0); + static void Push(const std::string& key, const NDArray& val, int priority = 0); + static void Push(const std::vector& keys, const std::vector& vals, int priority = 0); + static void Push(const std::vector& keys, const std::vector& vals, int priority = 0); static void Pull(int key, NDArray* out, int priority = 0); + static void Pull(const std::string& key, NDArray* out, int priority = 0); static void Pull(const std::vector& keys, std::vector* outs, int priority = 0); + static void Pull(const std::vector& keys, std::vector* outs, int priority = 0); // TODO(lx): put lr in optimizer or not? static void SetOptimizer(std::unique_ptr optimizer, bool local = false); static std::string GetType(); diff --git a/cpp-package/include/mxnet-cpp/kvstore.hpp b/cpp-package/include/mxnet-cpp/kvstore.hpp index f2b5e74990ce..413b52bd5fc1 100644 --- a/cpp-package/include/mxnet-cpp/kvstore.hpp +++ b/cpp-package/include/mxnet-cpp/kvstore.hpp @@ -87,6 +87,11 @@ inline void KVStore::Init(int key, const NDArray& val) { CHECK_EQ(MXKVStoreInit(get_kvstore()->get_handle(), 1, &key, &val_handle), 0); } +inline void KVStore::Init(const std::string& key, const NDArray& val) { + NDArrayHandle val_handle = val.GetHandle(); + CHECK_EQ(MXKVStoreInitEx(get_kvstore()->get_handle(), 1, &key.c_str(), &val_handle), 0); +} + inline void KVStore::Init(const std::vector& keys, const std::vector& vals) { CHECK_EQ(keys.size(), vals.size()); std::vector val_handles(vals.size()); @@ -99,14 +104,34 @@ inline void KVStore::Init(const std::vector& keys, const std::vector& keys, const std::vector& vals) { + CHECK_EQ(keys.size(), vals.size()); + std::vector key_handles(keys.size()); + std::transform(keys.cbegin(), keys.cend(), key_handles.begin(), + [](const std::string& key) { + return key.c_str(); + }); + std::vector val_handles(vals.size()); + std::transform(vals.cbegin(), vals.cend(), val_handles.begin(), + [](const NDArray& val) { + return val.GetHandle(); + }); + + CHECK_EQ(MXKVStoreInitEx(get_kvstore()->get_handle(), key_handles.size(), key_handles.data(), + val_handles.data()), 0); +} + inline void KVStore::Push(int key, const NDArray& val, int priority) { NDArrayHandle val_handle = val.GetHandle(); CHECK_EQ(MXKVStorePush(get_kvstore()->get_handle(), 1, &key, &val_handle, priority), 0); } -inline void KVStore::Push(const std::vector& keys, - const std::vector& vals, - int priority) { +inline void KVStore::Push(const std::string& key, const NDArray& val, int priority) { + NDArrayHandle val_handle = val.GetHandle(); + CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), 1, &key.c_str(), &val_handle, priority), 0); +} + +inline void KVStore::Push(const std::vector& keys, const std::vector& vals, int priority) { CHECK_EQ(keys.size(), vals.size()); std::vector val_handles(vals.size()); std::transform(vals.cbegin(), vals.cend(), val_handles.begin(), @@ -118,11 +143,33 @@ inline void KVStore::Push(const std::vector& keys, val_handles.data(), priority), 0); } +inline void KVStore::Push(const std::vector& keys, const std::vector& vals, int priority) { + CHECK_EQ(keys.size(), vals.size()); + std::vector key_handles(keys.size()); + std::transform(keys.cbegin(), keys.cend(), key_handles.begin(), + [](const std::string& key) { + return key.c_str(); + }); + std::vector val_handles(vals.size()); + std::transform(vals.cbegin(), vals.cend(), val_handles.begin(), + [](const NDArray& val) { + return val.GetHandle(); + }); + + CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), key_handles.size(), key_handles.data(), + val_handles.data(), priority), 0); +} + inline void KVStore::Pull(int key, NDArray* out, int priority) { NDArrayHandle out_handle = out->GetHandle(); CHECK_EQ(MXKVStorePull(get_kvstore()->get_handle(), 1, &key, &out_handle, priority), 0); } +inline void KVStore::Pull(const std::string& key, NDArray* out, int priority) { + NDArrayHandle out_handle = out->GetHandle(); + CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), 1, &key.c_str(), &out_handle, priority), 0); +} + inline void KVStore::Pull(const std::vector& keys, std::vector* outs, int priority) { CHECK_EQ(keys.size(), outs->size()); @@ -136,6 +183,24 @@ inline void KVStore::Pull(const std::vector& keys, std::vector* ou out_handles.data(), priority), 0); } +inline void KVStore::Pull(const std::vector& keys, std::vector* outs, int priority) { + CHECK_EQ(keys.size(), outs->size()); + + std::vector key_handles(keys.size()); + std::transform(keys.cbegin(), keys.cend(), key_handles.begin(), + [](const std::string& key) { + return key.c_str(); + }); + std::vector out_handles(keys.size()); + std::transform(outs->cbegin(), outs->cend(), out_handles.begin(), + [](const NDArray& val) { + return val.GetHandle(); + }); + + CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), key_handles.size(), key_handles.data(), + out_handles.data(), priority), 0); +} + inline void KVStore::Updater(int key, NDArrayHandle recv, NDArrayHandle local, void* handle_) { Optimizer *opt = static_cast(handle_); From 15b9d7ffd4388a86afada4778a0565f6e8ebb29f Mon Sep 17 00:00:00 2001 From: nihuini Date: Thu, 3 May 2018 13:24:04 +0800 Subject: [PATCH 08/10] make lines short --- cpp-package/include/mxnet-cpp/kvstore.h | 12 ++++++++---- cpp-package/include/mxnet-cpp/kvstore.hpp | 18 ++++++++++++------ 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/cpp-package/include/mxnet-cpp/kvstore.h b/cpp-package/include/mxnet-cpp/kvstore.h index 31bbc17ea169..67f984fce0ee 100644 --- a/cpp-package/include/mxnet-cpp/kvstore.h +++ b/cpp-package/include/mxnet-cpp/kvstore.h @@ -44,12 +44,16 @@ class KVStore { static void Init(const std::vector& keys, const std::vector& vals); static void Push(int key, const NDArray& val, int priority = 0); static void Push(const std::string& key, const NDArray& val, int priority = 0); - static void Push(const std::vector& keys, const std::vector& vals, int priority = 0); - static void Push(const std::vector& keys, const std::vector& vals, int priority = 0); + static void Push(const std::vector& keys, + const std::vector& vals, int priority = 0); + static void Push(const std::vector& keys, + const std::vector& vals, int priority = 0); static void Pull(int key, NDArray* out, int priority = 0); static void Pull(const std::string& key, NDArray* out, int priority = 0); - static void Pull(const std::vector& keys, std::vector* outs, int priority = 0); - static void Pull(const std::vector& keys, std::vector* outs, int priority = 0); + static void Pull(const std::vector& keys, + std::vector* outs, int priority = 0); + static void Pull(const std::vector& keys, + std::vector* outs, int priority = 0); // TODO(lx): put lr in optimizer or not? static void SetOptimizer(std::unique_ptr optimizer, bool local = false); static std::string GetType(); diff --git a/cpp-package/include/mxnet-cpp/kvstore.hpp b/cpp-package/include/mxnet-cpp/kvstore.hpp index 413b52bd5fc1..d5d8e5f13a22 100644 --- a/cpp-package/include/mxnet-cpp/kvstore.hpp +++ b/cpp-package/include/mxnet-cpp/kvstore.hpp @@ -128,10 +128,12 @@ inline void KVStore::Push(int key, const NDArray& val, int priority) { inline void KVStore::Push(const std::string& key, const NDArray& val, int priority) { NDArrayHandle val_handle = val.GetHandle(); - CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), 1, &key.c_str(), &val_handle, priority), 0); + CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), 1, &key.c_str(), + &val_handle, priority), 0); } -inline void KVStore::Push(const std::vector& keys, const std::vector& vals, int priority) { +inline void KVStore::Push(const std::vector& keys, + const std::vector& vals, int priority) { CHECK_EQ(keys.size(), vals.size()); std::vector val_handles(vals.size()); std::transform(vals.cbegin(), vals.cend(), val_handles.begin(), @@ -143,7 +145,8 @@ inline void KVStore::Push(const std::vector& keys, const std::vector& keys, const std::vector& vals, int priority) { +inline void KVStore::Push(const std::vector& keys, + const std::vector& vals, int priority) { CHECK_EQ(keys.size(), vals.size()); std::vector key_handles(keys.size()); std::transform(keys.cbegin(), keys.cend(), key_handles.begin(), @@ -167,10 +170,12 @@ inline void KVStore::Pull(int key, NDArray* out, int priority) { inline void KVStore::Pull(const std::string& key, NDArray* out, int priority) { NDArrayHandle out_handle = out->GetHandle(); - CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), 1, &key.c_str(), &out_handle, priority), 0); + CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), 1, &key.c_str(), + &out_handle, priority), 0); } -inline void KVStore::Pull(const std::vector& keys, std::vector* outs, int priority) { +inline void KVStore::Pull(const std::vector& keys, + std::vector* outs, int priority) { CHECK_EQ(keys.size(), outs->size()); std::vector out_handles(keys.size()); @@ -183,7 +188,8 @@ inline void KVStore::Pull(const std::vector& keys, std::vector* ou out_handles.data(), priority), 0); } -inline void KVStore::Pull(const std::vector& keys, std::vector* outs, int priority) { +inline void KVStore::Pull(const std::vector& keys, + std::vector* outs, int priority) { CHECK_EQ(keys.size(), outs->size()); std::vector key_handles(keys.size()); From dd1979932aebd44ba12ce8a2710292ed31fd0045 Mon Sep 17 00:00:00 2001 From: nihuini Date: Thu, 3 May 2018 14:02:21 +0800 Subject: [PATCH 09/10] fix build --- cpp-package/include/mxnet-cpp/kvstore.hpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/cpp-package/include/mxnet-cpp/kvstore.hpp b/cpp-package/include/mxnet-cpp/kvstore.hpp index d5d8e5f13a22..6cd405b91dd4 100644 --- a/cpp-package/include/mxnet-cpp/kvstore.hpp +++ b/cpp-package/include/mxnet-cpp/kvstore.hpp @@ -88,8 +88,9 @@ inline void KVStore::Init(int key, const NDArray& val) { } inline void KVStore::Init(const std::string& key, const NDArray& val) { + const char* key_handle = key.c_str(); NDArrayHandle val_handle = val.GetHandle(); - CHECK_EQ(MXKVStoreInitEx(get_kvstore()->get_handle(), 1, &key.c_str(), &val_handle), 0); + CHECK_EQ(MXKVStoreInitEx(get_kvstore()->get_handle(), 1, &key_handle, &val_handle), 0); } inline void KVStore::Init(const std::vector& keys, const std::vector& vals) { @@ -127,9 +128,9 @@ inline void KVStore::Push(int key, const NDArray& val, int priority) { } inline void KVStore::Push(const std::string& key, const NDArray& val, int priority) { + const char* key_handle = key.c_str(); NDArrayHandle val_handle = val.GetHandle(); - CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), 1, &key.c_str(), - &val_handle, priority), 0); + CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), 1, &key_handle, &val_handle, priority), 0); } inline void KVStore::Push(const std::vector& keys, @@ -169,9 +170,9 @@ inline void KVStore::Pull(int key, NDArray* out, int priority) { } inline void KVStore::Pull(const std::string& key, NDArray* out, int priority) { + const char* key_handle = key.c_str(); NDArrayHandle out_handle = out->GetHandle(); - CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), 1, &key.c_str(), - &out_handle, priority), 0); + CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), 1, &key_handle, &out_handle, priority), 0); } inline void KVStore::Pull(const std::vector& keys, From 6a5b2d3ce5aff3b56c12a085908d363a3180faf9 Mon Sep 17 00:00:00 2001 From: nihui Date: Fri, 29 Mar 2019 18:32:12 +0800 Subject: [PATCH 10/10] print error log --- cpp-package/example/test_kvstore.cpp | 59 +++++++++++++++++----------- 1 file changed, 36 insertions(+), 23 deletions(-) diff --git a/cpp-package/example/test_kvstore.cpp b/cpp-package/example/test_kvstore.cpp index bb81e3a3031a..f0fb8beb6719 100644 --- a/cpp-package/example/test_kvstore.cpp +++ b/cpp-package/example/test_kvstore.cpp @@ -20,9 +20,7 @@ using namespace mxnet::cpp; -static int test_single_key() { - int ret = 0; - +static bool test_single_key() { std::string key = "singlekeytest"; NDArray result(Shape(4), Context::cpu()); @@ -38,10 +36,12 @@ static int test_single_key() { // compare for (size_t j=0; j < result.Size(); j++) { - ret += (result.GetData()[j] == data.GetData()[j]) ? 0 : 1; + if (result.GetData()[j] != data.GetData()[j]) { + LG << "Error: wrong initialized data in singlekeytest, expect " + << data.GetData()[j] << " got " << result.GetData()[j]; + return false; + } } - if (ret != 0) - return ret; // push gradient NDArray grad({0.1f, -2.f, -4.4f, 0.f}, Shape(4), Context::cpu()); @@ -54,15 +54,17 @@ static int test_single_key() { // compare for (size_t j=0; j < result.Size(); j++) { - ret += (result.GetData()[j] == grad.GetData()[j]) ? 0 : 1; + if (result.GetData()[j] == grad.GetData()[j]) { + LG << "Error: wrong gradient data in singlekeytest, expect " + << grad.GetData()[j] << " got " << result.GetData()[j]; + return false; + } } - return ret; + return true; } -static int test_multiple_key() { - int ret = 0; - +static bool test_multiple_key() { std::vector keys(2); keys[0] = "multikeytest-0"; keys[1] = "multikeytest-1"; @@ -72,10 +74,10 @@ static int test_multiple_key() { results[1] = NDArray(Shape(4), Context::cpu()); // initialize data - std::vector datas(2); - datas[0] = NDArray({0.f, 2.f, -3.12f, 4.f}, Shape(4), Context::cpu()); - datas[1] = NDArray({0.8f, -2.f, 6.6f, 77.f}, Shape(4), Context::cpu()); - KVStore::Init(keys, datas); + std::vector data(2); + data[0] = NDArray({0.f, 2.f, -3.12f, 4.f}, Shape(4), Context::cpu()); + data[1] = NDArray({0.8f, -2.f, 6.6f, 77.f}, Shape(4), Context::cpu()); + KVStore::Init(keys, data); NDArray::WaitAll(); // retrieve result @@ -85,11 +87,13 @@ static int test_multiple_key() { // compare for (size_t i=0; i < results.size(); i++) { for (size_t j=0; j < results[i].Size(); j++) { - ret += (results[i].GetData()[j] == datas[i].GetData()[j]) ? 0 : 1; + if (results[i].GetData()[j] == data[i].GetData()[j]) { + LG << "Error: wrong initialized data in multikeytest, expect " + << data[i].GetData()[j] << " got " << results[i].GetData()[j]; + return false; + } } } - if (ret != 0) - return ret; // push gradient, reduce for the second std::vector push_keys(3); @@ -110,22 +114,31 @@ static int test_multiple_key() { // compare the first for (size_t j=0; j < results[0].Size(); j++) { - ret += (results[0].GetData()[j] == grads[0].GetData()[j]) ? 0 : 1; + if (results[0].GetData()[j] == grads[0].GetData()[j]) { + LG << "Error: wrong gradient data, expect " << grads[0].GetData()[j] + << " got " << result[0].GetData()[j]; + return false; + } } // compare the second for (size_t j=0; j < results[1].Size(); j++) { - ret += (results[1].GetData()[j] == (grads[1].GetData()[j] + grads[2].GetData()[j])) ? 0 : 1; + if (results[1].GetData()[j] == (grads[1].GetData()[j] + grads[2].GetData()[j])) { + LG << "Error: wrong reduced gradient data, expect " + << (grads[1].GetData()[j] + grads[2].GetData()[j]) + << " got " << result[1].GetData()[j]; + return false; + } } - return ret; + return true; } int main(int argc, char** argv) { KVStore::SetType("local"); - int ret1 = test_single_key(); - int ret2 = test_multiple_key(); + bool ret1 = test_single_key(); + bool ret2 = test_multiple_key(); MXNotifyShutdown(); return ret1 + ret2;