Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit a45c880

Browse files
committed
Support nested dictionaries in StringDictionary::getCompare.
Signed-off-by: ienkovich <[email protected]>
1 parent d962b78 commit a45c880

File tree

3 files changed

+203
-96
lines changed

3 files changed

+203
-96
lines changed

omniscidb/StringDictionary/StringDictionary.cpp

+83-88
Original file line numberDiff line numberDiff line change
@@ -750,83 +750,101 @@ std::vector<int32_t> StringDictionary::getLike(const std::string& pattern,
750750
return result;
751751
}
752752

753-
std::vector<int32_t> StringDictionary::getEquals(std::string pattern,
754-
std::string comp_operator,
755-
size_t generation) {
756-
CHECK(!base_dict_) << "Not implemented";
753+
std::vector<int32_t> StringDictionary::getEquals(const std::string& pattern,
754+
const std::string& comp_operator,
755+
int64_t generation) const {
756+
mapd_lock_guard<mapd_shared_mutex> write_lock(rw_mutex_);
757757
std::vector<int32_t> result;
758+
if (base_dict_) {
759+
result = base_dict_->getEquals(
760+
pattern, comp_operator, std::min(generation, base_generation_));
761+
if ((comp_operator == "=" && !result.empty()) || generation < base_generation_) {
762+
return result;
763+
}
764+
}
765+
758766
auto eq_id_itr = equal_cache_.find(pattern);
759-
int32_t eq_id = MAX_STRLEN + 1;
760-
int32_t cur_size = str_count_;
767+
int32_t eq_id = -1;
761768
if (eq_id_itr != equal_cache_.end()) {
762769
auto eq_id = eq_id_itr->second;
763770
if (comp_operator == "=") {
764-
result.push_back(eq_id);
771+
if (eq_id < generation) {
772+
result.push_back(eq_id);
773+
}
765774
} else {
766-
for (int32_t idx = 0; idx <= cur_size; idx++) {
767-
if (idx == eq_id) {
768-
continue;
775+
for (int32_t id = base_generation_; id < generation; id++) {
776+
if (id != eq_id) {
777+
result.push_back(id);
769778
}
770-
result.push_back(idx);
771779
}
772780
}
773781
} else {
774782
std::vector<std::thread> workers;
775783
int worker_count = cpu_threads();
776784
CHECK_GT(worker_count, 0);
777-
std::vector<std::vector<int32_t>> worker_results(worker_count);
778-
CHECK_LE(generation, str_count_);
779785
for (int worker_idx = 0; worker_idx < worker_count; ++worker_idx) {
780786
workers.emplace_back(
781-
[&worker_results, &pattern, generation, worker_idx, worker_count, this]() {
782-
for (size_t string_id = worker_idx; string_id < generation;
787+
[&eq_id, &pattern, generation, worker_idx, worker_count, this]() {
788+
for (int string_id = indexToId(worker_idx); string_id < generation;
783789
string_id += worker_count) {
784790
const auto str = getStringUnlocked(string_id);
785791
if (str == pattern) {
786-
worker_results[worker_idx].push_back(string_id);
792+
// Only one thread can find matching string, so no additional sync.
793+
eq_id = string_id;
794+
break;
787795
}
788796
}
789797
});
790798
}
791799
for (auto& worker : workers) {
792800
worker.join();
793801
}
794-
for (const auto& worker_result : worker_results) {
795-
result.insert(result.end(), worker_result.begin(), worker_result.end());
796-
}
797-
if (result.size() > 0) {
798-
const auto it_ok = equal_cache_.insert(std::make_pair(pattern, result[0]));
802+
if (eq_id >= 0) {
803+
const auto it_ok = equal_cache_.insert(std::make_pair(pattern, eq_id));
799804
CHECK(it_ok.second);
800-
eq_id = result[0];
801805
}
802806
if (comp_operator == "<>") {
803-
for (int32_t idx = 0; idx <= cur_size; idx++) {
804-
if (idx == eq_id) {
805-
continue;
807+
for (int32_t id = base_generation_; id < generation; id++) {
808+
if (id != eq_id) {
809+
result.push_back(id);
806810
}
807-
result.push_back(idx);
808811
}
812+
} else if (eq_id >= 0 && eq_id < generation) {
813+
result.push_back(eq_id);
809814
}
810815
}
811816
return result;
812817
}
813818

814819
std::vector<int32_t> StringDictionary::getCompare(const std::string& pattern,
815820
const std::string& comp_operator,
816-
const size_t generation) {
817-
CHECK(!base_dict_) << "Not implemented";
821+
int64_t generation) const {
822+
generation = generation >= 0 ? std::min(generation, static_cast<int64_t>(entryCount()))
823+
: static_cast<int64_t>(entryCount());
824+
{
825+
// The lock is used only to check cache.
826+
mapd_shared_lock<mapd_shared_mutex> read_lock(rw_mutex_);
827+
if ((sorted_cache.size() < str_count_) &&
828+
(comp_operator == "=" || comp_operator == "<>")) {
829+
read_lock.unlock();
830+
return getEquals(pattern, comp_operator, generation);
831+
}
832+
}
833+
818834
mapd_lock_guard<mapd_shared_mutex> write_lock(rw_mutex_);
819835
std::vector<int32_t> ret;
820-
if (str_count_ == 0) {
821-
return ret;
822-
}
823-
if (sorted_cache.size() < str_count_) {
824-
if (comp_operator == "=" || comp_operator == "<>") {
825-
return getEquals(pattern, comp_operator, generation);
836+
if (base_dict_) {
837+
ret = base_dict_->getCompare(
838+
pattern, comp_operator, std::min(generation, base_generation_));
839+
if ((comp_operator == "=" && !ret.empty()) || generation < base_generation_) {
840+
return ret;
826841
}
842+
}
827843

844+
if (sorted_cache.size() < str_count_) {
828845
buildSortedCache();
829846
}
847+
830848
auto cache_index = compare_cache_.get(pattern);
831849

832850
if (!cache_index) {
@@ -868,92 +886,72 @@ std::vector<int32_t> StringDictionary::getCompare(const std::string& pattern,
868886
// For < operator if the index that we have points to the element which is equal to
869887
// the pattern that we are searching for we simply get all the elements less than the
870888
// index. If the element pointed by the index is not equal to the pattern we are
871-
// comparing with we also need to include that index in result vector, except when the
872-
// index points to 0 and the pattern is lesser than the smallest value in the string
873-
// dictionary.
889+
// comparing with we also need to include that index in result vector.
874890

875891
if (comp_operator == "<") {
876892
size_t idx = cache_index->index;
877893
if (cache_index->diff) {
878894
idx = cache_index->index + 1;
879-
if (cache_index->index == 0 && cache_index->diff > 0) {
880-
idx = cache_index->index;
881-
}
882895
}
883896
for (size_t i = 0; i < idx; i++) {
884-
ret.push_back(sorted_cache[i]);
897+
if (sorted_cache[i] < generation) {
898+
ret.push_back(sorted_cache[i]);
899+
}
885900
}
886901

887-
// For <= operator if the index that we have points to the element which is equal to
888-
// the pattern that we are searching for we want to include the element pointed by
889-
// the index in the result set. If the element pointed by the index is not equal to
890-
// the pattern we are comparing with we just want to include all the ids with index
891-
// less than the index that is cached, except when pattern that we are searching for
892-
// is smaller than the smallest string in the dictionary.
893-
902+
// For <= operator we want to include the all elements less than the index and
903+
// the index itself since it cannot be greater than the pattern.
894904
} else if (comp_operator == "<=") {
895905
size_t idx = cache_index->index + 1;
896-
if (cache_index == 0 && cache_index->diff > 0) {
897-
idx = cache_index->index;
898-
}
899906
for (size_t i = 0; i < idx; i++) {
900-
ret.push_back(sorted_cache[i]);
907+
if (sorted_cache[i] < generation) {
908+
ret.push_back(sorted_cache[i]);
909+
}
901910
}
902911

903912
// For > operator we want to get all the elements with index greater than the index
904-
// that we have except, when the pattern we are searching for is lesser than the
905-
// smallest string in the dictionary we also want to include the id of the index
906-
// that we have.
907-
908913
} else if (comp_operator == ">") {
909914
size_t idx = cache_index->index + 1;
910-
if (cache_index->index == 0 && cache_index->diff > 0) {
911-
idx = cache_index->index;
912-
}
913915
for (size_t i = idx; i < sorted_cache.size(); i++) {
914-
ret.push_back(sorted_cache[i]);
916+
if (sorted_cache[i] < generation) {
917+
ret.push_back(sorted_cache[i]);
918+
}
915919
}
916920

917-
// For >= operator when the indexed element that we have points to element which is
918-
// equal to the pattern we are searching for we want to include that in the result
919-
// vector. If the index that we have does not point to the string which is equal to
920-
// the pattern we are searching we don't want to include that id into the result
921-
// vector except when the index is 0.
922-
921+
// For >= operator we want to get all the elements with index greater than the index.
922+
// We also include the index if it matches the pattern
923923
} else if (comp_operator == ">=") {
924924
size_t idx = cache_index->index;
925925
if (cache_index->diff) {
926926
idx = cache_index->index + 1;
927-
if (cache_index->index == 0 && cache_index->diff > 0) {
928-
idx = cache_index->index;
929-
}
930927
}
931928
for (size_t i = idx; i < sorted_cache.size(); i++) {
932-
ret.push_back(sorted_cache[i]);
929+
if (sorted_cache[i] < generation) {
930+
ret.push_back(sorted_cache[i]);
931+
}
933932
}
934933
} else if (comp_operator == "=") {
935934
if (!cache_index->diff) {
936-
ret.push_back(sorted_cache[cache_index->index]);
935+
if (sorted_cache[cache_index->index] < generation) {
936+
ret.push_back(sorted_cache[cache_index->index]);
937+
}
937938
}
938939

939940
// For <> operator it is simple matter of not including id of string which is equal
940941
// to pattern we are searching for.
941942
} else if (comp_operator == "<>") {
942943
if (!cache_index->diff) {
943-
size_t idx = cache_index->index;
944-
for (size_t i = 0; i < idx; i++) {
945-
ret.push_back(sorted_cache[i]);
946-
}
947-
++idx;
948-
for (size_t i = idx; i < sorted_cache.size(); i++) {
949-
ret.push_back(sorted_cache[i]);
944+
int eq_id = sorted_cache[cache_index->index];
945+
for (int id = base_generation_; id < generation; ++id) {
946+
if (id != eq_id) {
947+
ret.push_back(id);
948+
}
950949
}
951950
} else {
952-
for (size_t i = 0; i < sorted_cache.size(); i++) {
953-
ret.insert(ret.begin(), sorted_cache.begin(), sorted_cache.end());
951+
for (int id = base_generation_; id < generation; ++id) {
952+
ret.push_back(id);
954953
}
955954
}
956-
957955
} else {
958956
std::runtime_error("Unsupported string comparison operator");
959957
}
@@ -1375,20 +1373,18 @@ void StringDictionary::invalidateInvertedIndex() noexcept {
13751373
compare_cache_.invalidateInvertedIndex();
13761374
}
13771375

1378-
void StringDictionary::buildSortedCache() {
1379-
CHECK(!base_dict_) << "Not implemented";
1376+
void StringDictionary::buildSortedCache() const {
13801377
// This method is not thread-safe.
13811378
const auto cur_cache_size = sorted_cache.size();
13821379
std::vector<int32_t> temp_sorted_cache;
13831380
for (size_t i = cur_cache_size; i < str_count_; i++) {
1384-
temp_sorted_cache.push_back(i);
1381+
temp_sorted_cache.push_back(indexToId(i));
13851382
}
13861383
sortCache(temp_sorted_cache);
13871384
mergeSortedCache(temp_sorted_cache);
13881385
}
13891386

1390-
void StringDictionary::sortCache(std::vector<int32_t>& cache) {
1391-
CHECK(!base_dict_) << "Not implemented";
1387+
void StringDictionary::sortCache(std::vector<int32_t>& cache) const {
13921388
// This method is not thread-safe.
13931389

13941390
// this boost sort is creating some problems when we use UTF-8 encoded strings.
@@ -1401,8 +1397,7 @@ void StringDictionary::sortCache(std::vector<int32_t>& cache) {
14011397
});
14021398
}
14031399

1404-
void StringDictionary::mergeSortedCache(std::vector<int32_t>& temp_sorted_cache) {
1405-
CHECK(!base_dict_) << "Not implemented";
1400+
void StringDictionary::mergeSortedCache(std::vector<int32_t>& temp_sorted_cache) const {
14061401
// this method is not thread safe
14071402
std::vector<int32_t> updated_cache(temp_sorted_cache.size() + sorted_cache.size());
14081403
size_t t_idx = 0, s_idx = 0, idx = 0;

omniscidb/StringDictionary/StringDictionary.h

+8-8
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ class StringDictionary {
130130

131131
std::vector<int32_t> getCompare(const std::string& pattern,
132132
const std::string& comp_operator,
133-
const size_t generation);
133+
int64_t generation = -1) const;
134134

135135
std::vector<int32_t> getRegexpLike(const std::string& pattern,
136136
const char escape,
@@ -223,12 +223,12 @@ class StringDictionary {
223223
size_t& mem_size,
224224
const size_t min_capacity_requested = 0) noexcept;
225225
void invalidateInvertedIndex() noexcept;
226-
std::vector<int32_t> getEquals(std::string pattern,
227-
std::string comp_operator,
228-
size_t generation);
229-
void buildSortedCache();
230-
void sortCache(std::vector<int32_t>& cache);
231-
void mergeSortedCache(std::vector<int32_t>& temp_sorted_cache);
226+
std::vector<int32_t> getEquals(const std::string& pattern,
227+
const std::string& comp_operator,
228+
int64_t generation) const;
229+
void buildSortedCache() const;
230+
void sortCache(std::vector<int32_t>& cache) const;
231+
void mergeSortedCache(std::vector<int32_t>& temp_sorted_cache) const;
232232

233233
int indexToId(int string_idx) const { return string_idx + base_generation_; }
234234
int idToIndex(int string_id) const { return string_id - base_generation_; }
@@ -243,7 +243,7 @@ class StringDictionary {
243243
size_t collisions_;
244244
std::vector<int32_t> string_id_uint32_table_;
245245
std::vector<uint32_t> hash_cache_;
246-
std::vector<int32_t> sorted_cache;
246+
mutable std::vector<int32_t> sorted_cache;
247247
bool materialize_hashes_;
248248
StringIdxEntry* offset_map_;
249249
char* payload_map_;

0 commit comments

Comments
 (0)