-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
changes based on code reviews #176
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,7 @@ | |
#include <vector> | ||
#include <string> | ||
#include <utility> | ||
#include <functional> | ||
#include <algorithm> | ||
#include "./comm.h" | ||
|
||
|
@@ -85,7 +86,7 @@ class KVStoreLocal : public KVStore { | |
int priority) override { | ||
std::vector<int> uniq_keys; | ||
std::vector<std::vector<NDArray> > grouped_vals; | ||
GroupKVPairs(keys, values, &uniq_keys, &grouped_vals); | ||
GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals); | ||
|
||
for (size_t i = 0; i < uniq_keys.size(); ++i) { | ||
int key = uniq_keys[i]; | ||
|
@@ -114,7 +115,7 @@ class KVStoreLocal : public KVStore { | |
int priority) override { | ||
std::vector<int> uniq_keys; | ||
std::vector<std::vector<NDArray*> > grouped_vals; | ||
GroupKVPairs(keys, values, &uniq_keys, &grouped_vals); | ||
GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals); | ||
|
||
for (size_t i = 0; i < uniq_keys.size(); ++i) { | ||
int key = uniq_keys[i]; | ||
|
@@ -129,7 +130,7 @@ class KVStoreLocal : public KVStore { | |
int priority = 0) override { | ||
std::vector<int> uniq_keys; | ||
std::vector<std::vector<std::pair<NDArray*, NDArray>>> grouped_val_rowids; | ||
GroupKVPairs(keys, val_rowids, &uniq_keys, &grouped_val_rowids); | ||
GroupKVPairsPullRsp(keys, val_rowids, &uniq_keys, &grouped_val_rowids); | ||
for (size_t i = 0; i < uniq_keys.size(); ++i) { | ||
int key = uniq_keys[i]; | ||
const NDArray& local = local_[key]; | ||
|
@@ -174,13 +175,75 @@ class KVStoreLocal : public KVStore { | |
|
||
protected: | ||
/** | ||
* \brief group values on keys | ||
* \brief group values on keys for push | ||
*/ | ||
template <typename V> | ||
void GroupKVPairsPush(const std::vector<int>& keys, | ||
const std::vector<NDArray>& values, | ||
std::vector<int> *uniq_keys, | ||
std::vector<std::vector<NDArray>> *grouped_vals) { | ||
// check if the storage type of a value is valid | ||
auto validator = [this](const int key, const NDArray& nd) -> bool { | ||
auto stype = nd.storage_type(); | ||
// valid NDArray | ||
if (stype == kDefaultStorage || stype == kRowSparseStorage) return true; | ||
// invalid NDArray, abort | ||
LOG(FATAL) << "Unexpected storage type detected during kvstore push: " << stype; | ||
return false; | ||
}; | ||
GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator); | ||
} | ||
/** | ||
* \brief group values on keys for pull | ||
*/ | ||
void GroupKVPairsPull(const std::vector<int>& keys, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. GroupKVPairsForPull? |
||
const std::vector<NDArray*>& values, | ||
std::vector<int> *uniq_keys, | ||
std::vector<std::vector<NDArray*>> *grouped_vals) { | ||
// check if the storage type of a value is valid | ||
auto validator = [this](const int key, NDArray* nd) -> bool { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
// valid | ||
if (nd->storage_type() == kDefaultStorage) return true; | ||
// invalid, print warning messages once | ||
if (this->warnings_printed_.find(key) == this->warnings_printed_.end()) { | ||
LOG(INFO) << "Warning: non-default weights detected during kvstore pull. " | ||
<< "Please make sure to use row_sparse_pull with row_ids instead."; | ||
this->warnings_printed_.insert(key); | ||
} | ||
return false; | ||
}; | ||
GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator); | ||
} | ||
/** | ||
* \brief group values on keys for row_sparse_pull | ||
*/ | ||
void GroupKVPairsPullRsp(const std::vector<int>& keys, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. GroupKVPairsForPullRsp? |
||
const std::vector<std::pair<NDArray*, NDArray>>& values, | ||
std::vector<int> *uniq_keys, | ||
std::vector<std::vector<std::pair<NDArray*, NDArray>>> *grouped_vals) { | ||
// check if the storage type of a value is valid | ||
auto validator = [this](const int key, std::pair<NDArray*, NDArray> val_rowid) -> bool { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
auto val_stype = val_rowid.first->storage_type(); | ||
auto rowid_stype = val_rowid.second.storage_type(); | ||
// check storage types | ||
CHECK_EQ(val_stype, kRowSparseStorage) << "Expected row_sparse storage type for " | ||
<< "row_sparse_pull values, but detected storage type " << val_stype; | ||
CHECK_EQ(rowid_stype, kDefaultStorage) << "Expected default storage type for " | ||
<< "row_sparse_pull rowids, but detected storage type " << rowid_stype; | ||
return true; | ||
}; | ||
GroupKVPairs(keys, values, uniq_keys, grouped_vals, validator); | ||
} | ||
|
||
/** | ||
* \brief group values on keys with validation. | ||
* A value `v` is not included in the result if is_valid(v) returns false. | ||
*/ | ||
template <typename V, typename FValidate> | ||
void GroupKVPairs(const std::vector<int>& keys, | ||
const std::vector<V>& values, | ||
std::vector<int>* uniq_keys, | ||
std::vector<std::vector<V> >* grouped_vals) { | ||
std::vector<std::vector<V> >* grouped_vals, | ||
const FValidate& is_valid) { | ||
CHECK_EQ(keys.size(), values.size()); | ||
// TODO(mli) check if already sorted as an optimization | ||
using Idx = std::pair<int, int>; | ||
|
@@ -194,12 +257,14 @@ class KVStoreLocal : public KVStore { | |
|
||
int pre_key = idx[0].first - 1; | ||
for (auto i : idx) { | ||
if (i.first != pre_key) { | ||
uniq_keys->push_back(i.first); | ||
grouped_vals->push_back({values[i.second]}); | ||
pre_key = i.first;; | ||
} else { | ||
grouped_vals->back().push_back(values[i.second]); | ||
if (is_valid(i.first, values[i.second])) { | ||
if (i.first != pre_key) { | ||
uniq_keys->push_back(i.first); | ||
grouped_vals->push_back({values[i.second]}); | ||
pre_key = i.first; | ||
} else { | ||
grouped_vals->back().push_back(values[i.second]); | ||
} | ||
} | ||
} | ||
} | ||
|
@@ -246,6 +311,8 @@ class KVStoreLocal : public KVStore { | |
std::unordered_map<std::string, int> str_key_dict_; | ||
/// the next available integer for string->int key mapping | ||
int next_str_key_ = 0; | ||
/// whether printed warning due to mismatch stype in each key | ||
std::unordered_set<int> warnings_printed_; | ||
}; | ||
} // namespace kvstore | ||
} // namespace mxnet | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is
GroupKVPairsForPush
clearer?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A second thought is defining validator lambda function in the caller functions or a common place accessible to various callers and pass it to
GroupKVPairs
. In this way, there is no need to define extra interfaces of GroupKVPairsXXXX where XXXX stands for Push and Pull, respectively. Is this feasible?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I also thought about that. What file do you think is the best place to put the free function? I started with
GroupKVPairsForPushRsp
but it is too long (>100 char) causing lint to fail..