@@ -384,7 +384,14 @@ template size_t StringDictionary::getBulk(const std::vector<std::string>& string
 template <class T, class String>
 void StringDictionary::getOrAddBulk(const std::vector<String>& input_strings,
                                     T* output_string_ids) {
-  CHECK(!base_dict_) << "Not implemented";
+  if (base_dict_) {
+    auto missing_count =
+        base_dict_->getBulk(input_strings, output_string_ids, base_generation_);
+    if (!missing_count) {
+      return;
+    }
+  }
+
   if (g_enable_stringdict_parallel) {
     getOrAddBulkParallel(input_strings, output_string_ids);
     return;
@@ -395,6 +402,12 @@ void StringDictionary::getOrAddBulk(const std::vector<String>& input_strings,
   const size_t initial_str_count = str_count_;
   size_t idx = 0;
   for (const auto& input_string : input_strings) {
+    // Skip strings found in the base dictionary.
+    if (base_dict_ && output_string_ids[idx] != INVALID_STR_ID) {
+      ++idx;
+      continue;
+    }
+
     if (input_string.empty()) {
       output_string_ids[idx++] = inline_int_null_value<T>();
       continue;
@@ -427,7 +440,7 @@ void StringDictionary::getOrAddBulk(const std::vector<String>& input_strings,
     if (materialize_hashes_) {
       hash_cache_[str_count_] = input_string_hash;
     }
-    const int32_t string_id = static_cast<int32_t>(str_count_);
+    const int32_t string_id = indexToId(str_count_);
     string_id_uint32_table_[hash_bucket] = string_id;
     output_string_ids[idx++] = string_id;
     ++str_count_;
@@ -441,7 +454,6 @@ void StringDictionary::getOrAddBulk(const std::vector<String>& input_strings,
 template <class T, class String>
 void StringDictionary::getOrAddBulkParallel(const std::vector<String>& input_strings,
                                             T* output_string_ids) {
-  CHECK(!base_dict_) << "Not implemented";
   // Compute hashes of the input strings up front, and in parallel,
   // as the string hashing does not need to be behind the subsequent write_lock
   std::vector<uint32_t> input_strings_hashes(input_strings.size());
@@ -456,6 +468,12 @@ void StringDictionary::getOrAddBulkParallel(const std::vector<String>& input_str
   string_memory_ids.reserve(input_strings.size());
   size_t input_string_idx{0};
   for (const auto& input_string : input_strings) {
+    // Skip strings found in the base dictionary.
+    if (base_dict_ && output_string_ids[input_string_idx] != INVALID_STR_ID) {
+      ++input_string_idx;
+      continue;
+    }
+
     // Currently we make empty strings null
     if (input_string.empty()) {
       output_string_ids[input_string_idx++] = inline_int_null_value<T>();
@@ -500,11 +518,11 @@ void StringDictionary::getOrAddBulkParallel(const std::vector<String>& input_str
         << ") of Dictionary encoded Strings reached for this column";
     string_memory_ids.push_back(input_string_idx);
     sum_new_string_lengths += input_string.size();
-    string_id_uint32_table_[hash_bucket] = static_cast<int32_t>(shadow_str_count);
+    string_id_uint32_table_[hash_bucket] = indexToId(shadow_str_count);
     if (materialize_hashes_) {
       hash_cache_[shadow_str_count] = input_string_hash;
     }
-    output_string_ids[input_string_idx++] = shadow_str_count++;
+    output_string_ids[input_string_idx++] = indexToId(shadow_str_count++);
   }
   appendToStorageBulk(input_strings, string_memory_ids, sum_new_string_lengths);
   const size_t num_strings_added = shadow_str_count - str_count_;
@@ -530,6 +548,31 @@ template void StringDictionary::getOrAddBulk(
     const std::vector<std::string_view>& string_vec,
     int32_t* encoded_vec);
 
+template <class String>
+std::vector<int32_t> StringDictionary::getOrAddBulk(
+    const std::vector<String>& string_vec) {
+  std::vector<int32_t> res(string_vec.size());
+  getOrAddBulk(string_vec, res.data());
+  return res;
+}
+
+template std::vector<int32_t> StringDictionary::getOrAddBulk(
+    const std::vector<std::string>& string_vec);
+template std::vector<int32_t> StringDictionary::getOrAddBulk(
+    const std::vector<std::string_view>& string_vec);
+
+template <class String>
+std::vector<int32_t> StringDictionary::getBulk(const std::vector<String>& string_vec) {
+  std::vector<int32_t> res(string_vec.size());
+  getBulk(string_vec, res.data());
+  return res;
+}
+
+template std::vector<int32_t> StringDictionary::getBulk(
+    const std::vector<std::string>& string_vec);
+template std::vector<int32_t> StringDictionary::getBulk(
+    const std::vector<std::string_view>& string_vec);
+
 template <class String>
 int32_t StringDictionary::getIdOfString(const String& str) const {
   return getIdOfString(str, hash_string(str));
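
The vector-returning overloads added in the hunk above are thin wrappers over the existing pointer-based bulk APIs. A minimal usage sketch follows, assuming an already-constructed, writable StringDictionary instance named `dict` (a hypothetical name; construction details are outside this diff):

```cpp
// Hypothetical setup: `dict` is an existing StringDictionary instance.
std::vector<std::string> inputs{"red", "green", "blue", "red"};

// getOrAddBulk assigns ids to strings not yet known (to this dictionary or,
// after this patch, to its base dictionary) and returns one id per input,
// in input order.
const std::vector<int32_t> ids = dict.getOrAddBulk(inputs);

// getBulk is the lookup-only counterpart used above to probe the base
// dictionary: strings that are absent come back as invalid ids.
const std::vector<int32_t> lookups = dict.getBulk(inputs);
```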
@@ -1035,27 +1078,26 @@ void StringDictionary::increaseHashTableCapacityFromStorageAndMemory(
     const std::vector<String>& input_strings,
     const std::vector<size_t>& string_memory_ids,
     const std::vector<uint32_t>& input_strings_hashes) noexcept {
-  CHECK(!base_dict_) << "Not implemented";
   std::vector<int32_t> new_str_ids(string_id_uint32_table_.size() * 2, INVALID_STR_ID);
   if (materialize_hashes_) {
     for (size_t i = 0; i != str_count; ++i) {
       const uint32_t hash = hash_cache_[i];
       const uint32_t bucket = computeUniqueBucketWithHash(hash, new_str_ids);
-      new_str_ids[bucket] = i;
+      new_str_ids[bucket] = indexToId(i);
     }
     hash_cache_.resize(hash_cache_.size() * 2);
   } else {
     for (size_t storage_idx = 0; storage_idx != storage_high_water_mark; ++storage_idx) {
-      const auto storage_string = getOwnedStringChecked(storage_idx);
+      const auto storage_string = getOwnedStringChecked(indexToId(storage_idx));
       const uint32_t hash = hash_string(storage_string);
       const uint32_t bucket = computeUniqueBucketWithHash(hash, new_str_ids);
-      new_str_ids[bucket] = storage_idx;
+      new_str_ids[bucket] = indexToId(storage_idx);
     }
     for (size_t memory_idx = 0; memory_idx != string_memory_ids.size(); ++memory_idx) {
       const size_t string_memory_id = string_memory_ids[memory_idx];
       const uint32_t bucket = computeUniqueBucketWithHash(
           input_strings_hashes[string_memory_id], new_str_ids);
-      new_str_ids[bucket] = storage_high_water_mark + memory_idx;
+      new_str_ids[bucket] = indexToId(storage_high_water_mark + memory_idx);
     }
   }
   string_id_uint32_table_.swap(new_str_ids);
@@ -1112,21 +1154,21 @@ uint32_t StringDictionary::computeBucketFromStorageAndMemory(
     const size_t storage_high_water_mark,
     const std::vector<String>& input_strings,
     const std::vector<size_t>& string_memory_ids) const noexcept {
-  CHECK(!base_dict_) << "Not implemented";
   uint32_t bucket = input_string_hash & (string_id_uint32_table.size() - 1);
   while (true) {
     const int32_t candidate_string_id = string_id_uint32_table[bucket];
     if (candidate_string_id ==
         INVALID_STR_ID) {  // In this case it means the slot is available for use
       break;
     }
-    if (!materialize_hashes_ || (input_string_hash == hash_cache_[candidate_string_id])) {
-      if (candidate_string_id > 0 &&
-          static_cast<size_t>(candidate_string_id) >= storage_high_water_mark) {
+    auto candidate_string_index = idToIndex(candidate_string_id);
+    if (!materialize_hashes_ || (input_string_hash == hashById(candidate_string_id))) {
+      if (candidate_string_index > 0 &&
+          static_cast<size_t>(candidate_string_index) >= storage_high_water_mark) {
         // The candidate string is not in storage yet but in our string_memory_ids temp
         // buffer
         size_t memory_offset =
-            static_cast<size_t>(candidate_string_id - storage_high_water_mark);
+            static_cast<size_t>(candidate_string_index - storage_high_water_mark);
         const String candidate_string = input_strings[string_memory_ids[memory_offset]];
         if (input_string.size() == candidate_string.size() &&
             !memcmp(input_string.data(), candidate_string.data(), input_string.size())) {
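
A recurring change across these hunks is that the open-addressing table now stores dictionary ids via indexToId() instead of raw storage indices, and probes convert back with idToIndex() (with hashById() looking up the cached hash by id). The diff does not show those helpers; the following is a minimal sketch of one plausible mapping, assuming this dictionary's ids simply start after the id range covered by the base dictionary, tracked here as base_generation_ (an assumption, not taken verbatim from the source):

```cpp
#include <cstddef>
#include <cstdint>

// Sketch only: the real helpers are members of StringDictionary; the
// fixed-offset mapping below is an assumed implementation for illustration.
struct IdIndexMappingSketch {
  int64_t base_generation_{0};  // number of ids owned by the base dictionary, if any

  // Local storage index -> string id that is unique across base + transient dict.
  int32_t indexToId(size_t index) const {
    return static_cast<int32_t>(index + base_generation_);
  }

  // String id -> local storage index within this dictionary.
  size_t idToIndex(int32_t id) const {
    return static_cast<size_t>(id - base_generation_);
  }
};
```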