Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-400] support string type for kvstore key in cpp-package #10792

Merged
merged 6 commits into from
Apr 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 201 additions & 0 deletions cpp-package/example/test_kvstore.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include "mxnet/c_api.h" // MXGetGPUCount()
#include "mxnet-cpp/MxNetCpp.h"

using namespace mxnet::cpp;

static bool test_single_key(const Context &context, const std::string &context_str) {
std::string key = "singlekeytest-" + context_str;

NDArray result(Shape(4), context);
NDArray result_cpu;

// initialize data
NDArray data_cpu({0.f, 233.f, -0.12f, 9.f}, Shape(4), Context::cpu());
NDArray data = data_cpu.Copy(context);
NDArray::WaitAll();

KVStore::Init(key, data);
NDArray::WaitAll();

// retrieve result
KVStore::Pull(key, &result);
NDArray::WaitAll();

result_cpu = result.Copy(Context::cpu());
NDArray::WaitAll();

// compare
for (size_t j=0; j < result_cpu.Size(); j++) {
if (result_cpu.GetData()[j] != data_cpu.GetData()[j]) {
LG << "Error: wrong initialized data in singlekeytest-" << context_str
<< ", expect " << data_cpu.GetData()[j]
<< " got " << result_cpu.GetData()[j];
return false;
}
}

// push gradient
NDArray grad_cpu({0.1f, -2.f, -4.4f, 0.f}, Shape(4), Context::cpu());
NDArray grad = grad_cpu.Copy(context);
NDArray::WaitAll();

KVStore::Push(key, grad);
wkcn marked this conversation as resolved.
Show resolved Hide resolved
NDArray::WaitAll();

// retrieve result
KVStore::Pull(key, &result);
NDArray::WaitAll();

result_cpu = result.Copy(Context::cpu());
NDArray::WaitAll();

// compare
for (size_t j=0; j < result_cpu.Size(); j++) {
if (result_cpu.GetData()[j] != grad_cpu.GetData()[j]) {
LG << "Error: wrong gradient data in singlekeytest-" << context_str
<< ", expect " << grad_cpu.GetData()[j]
<< " got " << result_cpu.GetData()[j];
return false;
}
}

return true;
}

static bool test_multiple_key(const Context &context, const std::string &context_str) {
std::vector<std::string> keys(2);
keys[0] = "multikeytest-0-" + context_str;
keys[1] = "multikeytest-1-" + context_str;

std::vector<NDArray> results(2);
results[0] = NDArray(Shape(4), context);
results[1] = NDArray(Shape(4), context);
std::vector<NDArray> results_cpu(2);

// initialize data
std::vector<NDArray> data_cpu(2);
data_cpu[0] = NDArray({0.f, 2.f, -3.12f, 4.f}, Shape(4), Context::cpu());
data_cpu[1] = NDArray({0.8f, -2.f, 6.6f, 77.f}, Shape(4), Context::cpu());
std::vector<NDArray> data(2);
data[0] = data_cpu[0].Copy(context);
data[1] = data_cpu[1].Copy(context);
NDArray::WaitAll();

KVStore::Init(keys, data);
NDArray::WaitAll();

// retrieve result
KVStore::Pull(keys, &results);
NDArray::WaitAll();

results_cpu[0] = results[0].Copy(Context::cpu());
results_cpu[1] = results[1].Copy(Context::cpu());
NDArray::WaitAll();

// compare
for (size_t i=0; i < results_cpu.size(); i++) {
for (size_t j=0; j < results_cpu[i].Size(); j++) {
if (results_cpu[i].GetData()[j] != data_cpu[i].GetData()[j]) {
LG << "Error: wrong initialized data in multikeytest-" << context_str
<< ", expect " << data_cpu[i].GetData()[j]
<< " got " << results_cpu[i].GetData()[j];
return false;
}
}
}

// push gradient, reduce for the second
std::vector<std::string> push_keys(3);
push_keys[0] = "multikeytest-0-" + context_str;
push_keys[1] = "multikeytest-1-" + context_str;
push_keys[2] = "multikeytest-1-" + context_str;
wkcn marked this conversation as resolved.
Show resolved Hide resolved

std::vector<NDArray> grads_cpu(3);
grads_cpu[0] = NDArray({0.2f, -0.3f, -1.1f, 0.0f}, Shape(4), Context::cpu());
grads_cpu[1] = NDArray({2.f, 4.f, -4.f, -5.f}, Shape(4), Context::cpu());
grads_cpu[2] = NDArray({-3.f, -0.2f, 12.f, -9.f}, Shape(4), Context::cpu());
std::vector<NDArray> grads(3);
grads[0] = grads_cpu[0].Copy(context);
grads[1] = grads_cpu[1].Copy(context);
grads[2] = grads_cpu[2].Copy(context);
NDArray::WaitAll();

KVStore::Push(push_keys, grads);
NDArray::WaitAll();

// retrieve result
KVStore::Pull(keys, &results);
NDArray::WaitAll();

results_cpu[0] = results[0].Copy(Context::cpu());
results_cpu[1] = results[1].Copy(Context::cpu());
NDArray::WaitAll();

// compare the first
for (size_t j=0; j < results_cpu[0].Size(); j++) {
if (results_cpu[0].GetData()[j] != grads_cpu[0].GetData()[j]) {
LG << "Error: wrong gradient data in multikeytest-" << context_str
<< ", expect " << grads_cpu[0].GetData()[j]
<< " got " << results_cpu[0].GetData()[j];
return false;
}
}

// compare the second
for (size_t j=0; j < results_cpu[1].Size(); j++) {
if (results_cpu[1].GetData()[j] != (grads_cpu[1].GetData()[j] + grads_cpu[2].GetData()[j])) {
LG << "Error: wrong reduced gradient data in multikeytest-" << context_str
<< ", expect " << (grads_cpu[1].GetData()[j] + grads_cpu[2].GetData()[j])
<< " got " << results_cpu[1].GetData()[j];
return false;
}
}

return true;
}

int main(int argc, char** argv) {
KVStore::SetType("local");

bool success1 = test_single_key(Context::cpu(), "cpu");
bool success2 = test_multiple_key(Context::cpu(), "cpu");

bool success3 = true;
bool success4 = true;

int gpu_count = 0;
if (MXGetGPUCount(&gpu_count) != 0) {
LG << "Error: MXGetGPUCount";

MXNotifyShutdown();
return 1;
}

if (gpu_count > 0) {
success3 = test_single_key(Context::gpu(), "gpu");
success4 = test_multiple_key(Context::gpu(), "gpu");
}

int ret = (success1 && success2 && success3 && success4) ? 0 : 1;

MXNotifyShutdown();
return ret;
}
13 changes: 11 additions & 2 deletions cpp-package/include/mxnet-cpp/kvstore.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,21 @@ class KVStore {
static void SetType(const std::string& type);
static void RunServer();
static void Init(int key, const NDArray& val);
static void Init(const std::string& key, const NDArray& val);
static void Init(const std::vector<int>& keys, const std::vector<NDArray>& vals);
static void Init(const std::vector<std::string>& keys, const std::vector<NDArray>& vals);
static void Push(int key, const NDArray& val, int priority = 0);
static void Push(const std::string& key, const NDArray& val, int priority = 0);
static void Push(const std::vector<int>& keys,
const std::vector<NDArray>& vals, int priority = 0);
const std::vector<NDArray>& vals, int priority = 0);
static void Push(const std::vector<std::string>& keys,
const std::vector<NDArray>& vals, int priority = 0);
static void Pull(int key, NDArray* out, int priority = 0);
static void Pull(const std::vector<int>& keys, std::vector<NDArray>* outs, int priority = 0);
static void Pull(const std::string& key, NDArray* out, int priority = 0);
wkcn marked this conversation as resolved.
Show resolved Hide resolved
static void Pull(const std::vector<int>& keys,
std::vector<NDArray>* outs, int priority = 0);
static void Pull(const std::vector<std::string>& keys,
std::vector<NDArray>* outs, int priority = 0);
// TODO(lx): put lr in optimizer or not?
static void SetOptimizer(std::unique_ptr<Optimizer> optimizer, bool local = false);
static std::string GetType();
Expand Down
78 changes: 75 additions & 3 deletions cpp-package/include/mxnet-cpp/kvstore.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ inline void KVStore::Init(int key, const NDArray& val) {
CHECK_EQ(MXKVStoreInit(get_kvstore()->get_handle(), 1, &key, &val_handle), 0);
}

inline void KVStore::Init(const std::string& key, const NDArray& val) {
const char* key_handle = key.c_str();
NDArrayHandle val_handle = val.GetHandle();
CHECK_EQ(MXKVStoreInitEx(get_kvstore()->get_handle(), 1, &key_handle, &val_handle), 0);
}

inline void KVStore::Init(const std::vector<int>& keys, const std::vector<NDArray>& vals) {
CHECK_EQ(keys.size(), vals.size());
std::vector<NDArrayHandle> val_handles(vals.size());
Expand All @@ -99,14 +105,36 @@ inline void KVStore::Init(const std::vector<int>& keys, const std::vector<NDArra
val_handles.data()), 0);
}

inline void KVStore::Init(const std::vector<std::string>& keys, const std::vector<NDArray>& vals) {
CHECK_EQ(keys.size(), vals.size());
std::vector<const char*> key_handles(keys.size());
std::transform(keys.cbegin(), keys.cend(), key_handles.begin(),
[](const std::string& key) {
return key.c_str();
});
std::vector<NDArrayHandle> val_handles(vals.size());
std::transform(vals.cbegin(), vals.cend(), val_handles.begin(),
[](const NDArray& val) {
return val.GetHandle();
});

CHECK_EQ(MXKVStoreInitEx(get_kvstore()->get_handle(), key_handles.size(), key_handles.data(),
val_handles.data()), 0);
}

inline void KVStore::Push(int key, const NDArray& val, int priority) {
NDArrayHandle val_handle = val.GetHandle();
CHECK_EQ(MXKVStorePush(get_kvstore()->get_handle(), 1, &key, &val_handle, priority), 0);
}

inline void KVStore::Push(const std::string& key, const NDArray& val, int priority) {
const char* key_handle = key.c_str();
NDArrayHandle val_handle = val.GetHandle();
CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), 1, &key_handle, &val_handle, priority), 0);
}

inline void KVStore::Push(const std::vector<int>& keys,
const std::vector<NDArray>& vals,
int priority) {
const std::vector<NDArray>& vals, int priority) {
CHECK_EQ(keys.size(), vals.size());
std::vector<NDArrayHandle> val_handles(vals.size());
std::transform(vals.cbegin(), vals.cend(), val_handles.begin(),
Expand All @@ -118,12 +146,37 @@ inline void KVStore::Push(const std::vector<int>& keys,
val_handles.data(), priority), 0);
}

inline void KVStore::Push(const std::vector<std::string>& keys,
const std::vector<NDArray>& vals, int priority) {
CHECK_EQ(keys.size(), vals.size());
std::vector<const char*> key_handles(keys.size());
std::transform(keys.cbegin(), keys.cend(), key_handles.begin(),
[](const std::string& key) {
return key.c_str();
});
std::vector<NDArrayHandle> val_handles(vals.size());
std::transform(vals.cbegin(), vals.cend(), val_handles.begin(),
[](const NDArray& val) {
return val.GetHandle();
});

CHECK_EQ(MXKVStorePushEx(get_kvstore()->get_handle(), key_handles.size(), key_handles.data(),
val_handles.data(), priority), 0);
}

inline void KVStore::Pull(int key, NDArray* out, int priority) {
NDArrayHandle out_handle = out->GetHandle();
CHECK_EQ(MXKVStorePull(get_kvstore()->get_handle(), 1, &key, &out_handle, priority), 0);
}

inline void KVStore::Pull(const std::vector<int>& keys, std::vector<NDArray>* outs, int priority) {
inline void KVStore::Pull(const std::string& key, NDArray* out, int priority) {
const char* key_handle = key.c_str();
NDArrayHandle out_handle = out->GetHandle();
CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), 1, &key_handle, &out_handle, priority), 0);
}

inline void KVStore::Pull(const std::vector<int>& keys,
std::vector<NDArray>* outs, int priority) {
CHECK_EQ(keys.size(), outs->size());

std::vector<NDArrayHandle> out_handles(keys.size());
Expand All @@ -136,6 +189,25 @@ inline void KVStore::Pull(const std::vector<int>& keys, std::vector<NDArray>* ou
out_handles.data(), priority), 0);
}

inline void KVStore::Pull(const std::vector<std::string>& keys,
std::vector<NDArray>* outs, int priority) {
CHECK_EQ(keys.size(), outs->size());

std::vector<const char*> key_handles(keys.size());
std::transform(keys.cbegin(), keys.cend(), key_handles.begin(),
[](const std::string& key) {
return key.c_str();
});
std::vector<NDArrayHandle> out_handles(keys.size());
std::transform(outs->cbegin(), outs->cend(), out_handles.begin(),
[](const NDArray& val) {
return val.GetHandle();
});

CHECK_EQ(MXKVStorePullEx(get_kvstore()->get_handle(), key_handles.size(), key_handles.data(),
out_handles.data(), priority), 0);
}

inline void KVStore::Updater(int key, NDArrayHandle recv, NDArrayHandle local,
void* handle_) {
Optimizer *opt = static_cast<Optimizer*>(handle_);
Expand Down
7 changes: 5 additions & 2 deletions cpp-package/tests/ci_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,11 @@ cp ../../build/cpp-package/example/mlp_cpu .
cp ../../build/cpp-package/example/mlp_gpu .
./mlp_gpu

cp ../../build/cpp-package/example/test_optimizer .
./test_optimizer
cp ../../build/cpp-package/example/test_optimizer .
./test_optimizer

cp ../../build/cpp-package/example/test_kvstore .
./test_kvstore

cp ../../build/cpp-package/example/test_score .
./test_score 0.93
Expand Down