Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 3333372

Browse files
committed
Support nested dictionaries in StringDictionary::getOrAddBulk.
Signed-off-by: ienkovich <[email protected]>
1 parent ba0b4a7 commit 3333372

File tree

3 files changed

+107
-23
lines changed

3 files changed

+107
-23
lines changed

omniscidb/StringDictionary/StringDictionary.cpp

+57-15
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,14 @@ template size_t StringDictionary::getBulk(const std::vector<std::string>& string
384384
template <class T, class String>
385385
void StringDictionary::getOrAddBulk(const std::vector<String>& input_strings,
386386
T* output_string_ids) {
387-
CHECK(!base_dict_) << "Not implemented";
387+
if (base_dict_) {
388+
auto missing_count =
389+
base_dict_->getBulk(input_strings, output_string_ids, base_generation_);
390+
if (!missing_count) {
391+
return;
392+
}
393+
}
394+
388395
if (g_enable_stringdict_parallel) {
389396
getOrAddBulkParallel(input_strings, output_string_ids);
390397
return;
@@ -395,6 +402,12 @@ void StringDictionary::getOrAddBulk(const std::vector<String>& input_strings,
395402
const size_t initial_str_count = str_count_;
396403
size_t idx = 0;
397404
for (const auto& input_string : input_strings) {
405+
// Skip strings found in the base dictionary.
406+
if (base_dict_ && output_string_ids[idx] != INVALID_STR_ID) {
407+
++idx;
408+
continue;
409+
}
410+
398411
if (input_string.empty()) {
399412
output_string_ids[idx++] = inline_int_null_value<T>();
400413
continue;
@@ -427,7 +440,7 @@ void StringDictionary::getOrAddBulk(const std::vector<String>& input_strings,
427440
if (materialize_hashes_) {
428441
hash_cache_[str_count_] = input_string_hash;
429442
}
430-
const int32_t string_id = static_cast<int32_t>(str_count_);
443+
const int32_t string_id = indexToId(str_count_);
431444
string_id_uint32_table_[hash_bucket] = string_id;
432445
output_string_ids[idx++] = string_id;
433446
++str_count_;
@@ -441,7 +454,6 @@ void StringDictionary::getOrAddBulk(const std::vector<String>& input_strings,
441454
template <class T, class String>
442455
void StringDictionary::getOrAddBulkParallel(const std::vector<String>& input_strings,
443456
T* output_string_ids) {
444-
CHECK(!base_dict_) << "Not implemented";
445457
// Compute hashes of the input strings up front, and in parallel,
446458
// as the string hashing does not need to be behind the subsequent write_lock
447459
std::vector<uint32_t> input_strings_hashes(input_strings.size());
@@ -456,6 +468,12 @@ void StringDictionary::getOrAddBulkParallel(const std::vector<String>& input_str
456468
string_memory_ids.reserve(input_strings.size());
457469
size_t input_string_idx{0};
458470
for (const auto& input_string : input_strings) {
471+
// Skip strings found in the base dictionary.
472+
if (base_dict_ && output_string_ids[input_string_idx] != INVALID_STR_ID) {
473+
++input_string_idx;
474+
continue;
475+
}
476+
459477
// Currently we make empty strings null
460478
if (input_string.empty()) {
461479
output_string_ids[input_string_idx++] = inline_int_null_value<T>();
@@ -500,11 +518,11 @@ void StringDictionary::getOrAddBulkParallel(const std::vector<String>& input_str
500518
<< ") of Dictionary encoded Strings reached for this column";
501519
string_memory_ids.push_back(input_string_idx);
502520
sum_new_string_lengths += input_string.size();
503-
string_id_uint32_table_[hash_bucket] = static_cast<int32_t>(shadow_str_count);
521+
string_id_uint32_table_[hash_bucket] = indexToId(shadow_str_count);
504522
if (materialize_hashes_) {
505523
hash_cache_[shadow_str_count] = input_string_hash;
506524
}
507-
output_string_ids[input_string_idx++] = shadow_str_count++;
525+
output_string_ids[input_string_idx++] = indexToId(shadow_str_count++);
508526
}
509527
appendToStorageBulk(input_strings, string_memory_ids, sum_new_string_lengths);
510528
const size_t num_strings_added = shadow_str_count - str_count_;
@@ -530,6 +548,31 @@ template void StringDictionary::getOrAddBulk(
530548
const std::vector<std::string_view>& string_vec,
531549
int32_t* encoded_vec);
532550

551+
// Convenience overload of getOrAddBulk that returns the ids by value.
// Allocates one int32_t slot per element of string_vec, hands that buffer to
// the pointer-based getOrAddBulk overload, and returns the filled vector.
template <class String>
552+
std::vector<int32_t> StringDictionary::getOrAddBulk(
553+
const std::vector<String>& string_vec) {
554+
std::vector<int32_t> res(string_vec.size());
555+
getOrAddBulk(string_vec, res.data());
556+
return res;
557+
}
558+
559+
template std::vector<int32_t> StringDictionary::getOrAddBulk(
560+
const std::vector<std::string>& string_vec);
561+
template std::vector<int32_t> StringDictionary::getOrAddBulk(
562+
const std::vector<std::string_view>& string_vec);
563+
564+
// Convenience overload of getBulk that returns the ids by value.
// Allocates one int32_t slot per element of string_vec and lets the
// pointer-based getBulk overload fill it. Whatever that overload returns is
// discarded here, so callers only see the per-string ids.
// NOTE(review): slots for strings that are not found presumably hold
// INVALID_STR_ID - confirm against the pointer-based overload.
template <class String>
565+
std::vector<int32_t> StringDictionary::getBulk(const std::vector<String>& string_vec) {
566+
std::vector<int32_t> res(string_vec.size());
567+
getBulk(string_vec, res.data());
568+
return res;
569+
}
570+
571+
template std::vector<int32_t> StringDictionary::getBulk(
572+
const std::vector<std::string>& string_vec);
573+
template std::vector<int32_t> StringDictionary::getBulk(
574+
const std::vector<std::string_view>& string_vec);
575+
533576
template <class String>
534577
int32_t StringDictionary::getIdOfString(const String& str) const {
535578
return getIdOfString(str, hash_string(str));
@@ -1035,27 +1078,26 @@ void StringDictionary::increaseHashTableCapacityFromStorageAndMemory(
10351078
const std::vector<String>& input_strings,
10361079
const std::vector<size_t>& string_memory_ids,
10371080
const std::vector<uint32_t>& input_strings_hashes) noexcept {
1038-
CHECK(!base_dict_) << "Not implemented";
10391081
std::vector<int32_t> new_str_ids(string_id_uint32_table_.size() * 2, INVALID_STR_ID);
10401082
if (materialize_hashes_) {
10411083
for (size_t i = 0; i != str_count; ++i) {
10421084
const uint32_t hash = hash_cache_[i];
10431085
const uint32_t bucket = computeUniqueBucketWithHash(hash, new_str_ids);
1044-
new_str_ids[bucket] = i;
1086+
new_str_ids[bucket] = indexToId(i);
10451087
}
10461088
hash_cache_.resize(hash_cache_.size() * 2);
10471089
} else {
10481090
for (size_t storage_idx = 0; storage_idx != storage_high_water_mark; ++storage_idx) {
1049-
const auto storage_string = getOwnedStringChecked(storage_idx);
1091+
const auto storage_string = getOwnedStringChecked(indexToId(storage_idx));
10501092
const uint32_t hash = hash_string(storage_string);
10511093
const uint32_t bucket = computeUniqueBucketWithHash(hash, new_str_ids);
1052-
new_str_ids[bucket] = storage_idx;
1094+
new_str_ids[bucket] = indexToId(storage_idx);
10531095
}
10541096
for (size_t memory_idx = 0; memory_idx != string_memory_ids.size(); ++memory_idx) {
10551097
const size_t string_memory_id = string_memory_ids[memory_idx];
10561098
const uint32_t bucket = computeUniqueBucketWithHash(
10571099
input_strings_hashes[string_memory_id], new_str_ids);
1058-
new_str_ids[bucket] = storage_high_water_mark + memory_idx;
1100+
new_str_ids[bucket] = indexToId(storage_high_water_mark + memory_idx);
10591101
}
10601102
}
10611103
string_id_uint32_table_.swap(new_str_ids);
@@ -1112,21 +1154,21 @@ uint32_t StringDictionary::computeBucketFromStorageAndMemory(
11121154
const size_t storage_high_water_mark,
11131155
const std::vector<String>& input_strings,
11141156
const std::vector<size_t>& string_memory_ids) const noexcept {
1115-
CHECK(!base_dict_) << "Not implemented";
11161157
uint32_t bucket = input_string_hash & (string_id_uint32_table.size() - 1);
11171158
while (true) {
11181159
const int32_t candidate_string_id = string_id_uint32_table[bucket];
11191160
if (candidate_string_id ==
11201161
INVALID_STR_ID) { // In this case it means the slot is available for use
11211162
break;
11221163
}
1123-
if (!materialize_hashes_ || (input_string_hash == hash_cache_[candidate_string_id])) {
1124-
if (candidate_string_id > 0 &&
1125-
static_cast<size_t>(candidate_string_id) >= storage_high_water_mark) {
1164+
auto candidate_string_index = idToIndex(candidate_string_id);
1165+
if (!materialize_hashes_ || (input_string_hash == hashById(candidate_string_id))) {
1166+
if (candidate_string_index > 0 &&
1167+
static_cast<size_t>(candidate_string_index) >= storage_high_water_mark) {
11261168
// The candidate string is not in storage yet but in our string_memory_ids temp
11271169
// buffer
11281170
size_t memory_offset =
1129-
static_cast<size_t>(candidate_string_id - storage_high_water_mark);
1171+
static_cast<size_t>(candidate_string_index - storage_high_water_mark);
11301172
const String candidate_string = input_strings[string_memory_ids[memory_offset]];
11311173
if (input_string.size() == candidate_string.size() &&
11321174
!memcmp(input_string.data(), candidate_string.data(), input_string.size())) {

omniscidb/StringDictionary/StringDictionary.h

+6-4
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@
3131

3232
#define USE_LEGACY_STR_DICT
3333

34-
extern bool g_enable_stringdict_parallel;
35-
3634
namespace legacy {
3735
class StringDictionary;
3836
}
@@ -112,8 +110,10 @@ class StringDictionary {
112110
const int64_t generation) const;
113111
template <class T, class String>
114112
void getOrAddBulk(const std::vector<String>& string_vec, T* encoded_vec);
115-
template <class T, class String>
116-
void getOrAddBulkParallel(const std::vector<String>& string_vec, T* encoded_vec);
113+
template <class String>
114+
std::vector<int32_t> getOrAddBulk(const std::vector<String>& string_vec);
115+
template <class String>
116+
std::vector<int32_t> getBulk(const std::vector<String>& string_vec);
117117
template <class String>
118118
int32_t getIdOfString(const String&) const;
119119
std::string getString(int32_t string_id) const;
@@ -184,6 +184,8 @@ class StringDictionary {
184184
std::string getStringUnlocked(int32_t string_id) const noexcept;
185185
std::string getOwnedStringChecked(const int string_id) const noexcept;
186186
std::pair<char*, size_t> getOwnedStringBytesChecked(const int string_id) const noexcept;
187+
template <class T, class String>
188+
void getOrAddBulkParallel(const std::vector<String>& string_vec, T* encoded_vec);
187189
template <class String>
188190
uint32_t computeBucket(
189191
const uint32_t hash,

omniscidb/Tests/StringDictionaryTest.cpp

+44-4
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include "TestHelpers.h"
1818

19+
#include "Shared/scope.h"
1920
#include "StringDictionary/StringDictionaryProxy.h"
2021

2122
#include <cstdio>
@@ -30,6 +31,7 @@
3031
using namespace std::string_literals;
3132

3233
EXTERN extern bool g_cache_string_hash;
34+
EXTERN extern bool g_enable_stringdict_parallel;
3335

3436
TEST(StringDictionary, AddAndGet) {
3537
const DictRef dict_ref(-1, 1);
@@ -460,6 +462,44 @@ TEST(NestedStringDictionary, GetBulk) {
460462
}
461463
}
462464

465+
// Checks getOrAddBulk on a dictionary layered over a base dictionary, for
// every combination of g_enable_stringdict_parallel and g_cache_string_hash.
// The base dictionary holds str1, str3, str5, str7, str9 with ids 0..4.
TEST(NestedStringDictionary, GetOrAddBulk) {
466+
std::shared_ptr<StringDictionary> dict1 =
467+
std::make_shared<StringDictionary>(DictRef{-1, 1}, g_cache_string_hash);
468+
ASSERT_EQ(dict1->getOrAdd("str1"), 0);
469+
ASSERT_EQ(dict1->getOrAdd("str3"), 1);
470+
ASSERT_EQ(dict1->getOrAdd("str5"), 2);
471+
ASSERT_EQ(dict1->getOrAdd("str7"), 3);
472+
ASSERT_EQ(dict1->getOrAdd("str9"), 4);
473+
474+
bool old_enable_stringdict_parallel = g_enable_stringdict_parallel;
475+
bool old_cache_string_hash = g_cache_string_hash;
476+
// Put both global flags back to their saved values when the guard is destroyed.
ScopeGuard g([old_enable_stringdict_parallel, old_cache_string_hash]() {
477+
g_enable_stringdict_parallel = old_enable_stringdict_parallel;
478+
g_cache_string_hash = old_cache_string_hash;
479+
});
480+
for (bool enable_stringdict_parallel : {true, false}) {
481+
for (bool cache_string_hash : {true, false}) {
482+
g_enable_stringdict_parallel = enable_stringdict_parallel;
483+
g_cache_string_hash = cache_string_hash;
484+
StringDictionary dict2(dict1, 3, g_cache_string_hash);
485+
std::vector<int> ids(10, -10);
486+
// Only str1, str3, and str5 should be used from the base dict (the second
// constructor argument, 3, is presumably the base generation limit); str7
// and str9 exist in the base dict but are expected to get new ids here.
487+
dict2.getOrAddBulk(std::vector<std::string>({"str1"s,
488+
"str2"s,
489+
"str3"s,
490+
"str4"s,
491+
"str5"s,
492+
"str6"s,
493+
"str7"s,
494+
"str8"s,
495+
"str9"s,
496+
"str10"s}),
497+
ids.data());
498+
ASSERT_EQ(ids, std::vector<int>({0, 3, 1, 4, 2, 5, 6, 7, 8, 9}));
499+
}
500+
}
501+
}
502+
463503
static std::shared_ptr<StringDictionary> create_and_fill_dictionary() {
464504
const DictRef dict_ref(-1, 1);
465505
std::shared_ptr<StringDictionary> string_dict =
@@ -481,10 +521,10 @@ static std::shared_ptr<StringDictionary> create_and_fill_dictionary() {
481521
return string_dict;
482522
}
483523

484-
TEST(StringDictionaryProxy, GetOrAddTransientBulk) {
524+
TEST(NestedStringDictionary, GetOrAddTransientBulk) {
485525
auto string_dict = create_and_fill_dictionary();
486526

487-
StringDictionaryProxy string_dict_proxy(string_dict, string_dict->storageEntryCount());
527+
StringDictionary string_dict_proxy(string_dict, string_dict->entryCount());
488528
{
489529
// First iteration is identical to first of the StringDictionary GetOrAddBulk test,
490530
// and results should be the same
@@ -533,10 +573,10 @@ TEST(StringDictionaryProxy, GetOrAddTransientBulk) {
533573
}
534574
}
535575

536-
TEST(StringDictionaryProxy, GetTransientBulk) {
576+
TEST(NestedStringDictionary, GetTransientBulk) {
537577
auto string_dict = create_and_fill_dictionary();
538578

539-
StringDictionaryProxy string_dict_proxy(string_dict, string_dict->storageEntryCount());
579+
StringDictionary string_dict_proxy(string_dict, string_dict->entryCount());
540580
{
541581
// First iteration is identical to first of the StringDictionaryProxy
542582
// GetOrAddTransientBulk test, and results should be the same

0 commit comments

Comments
 (0)