From 58eb6bc2395e9925c00b5c85f90b0ad85568f5fa Mon Sep 17 00:00:00 2001 From: Martin Leitner-Ankerl Date: Fri, 22 Dec 2023 18:20:52 +0100 Subject: [PATCH] Add extract(key) and extract(it) API. Similar to erase(), but the return value is the erased element, moved out. This has practically the same performance behavior as erase(), except for 2 additional moves. (once into a temporary variable), and then out. --- include/ankerl/unordered_dense.h | 58 ++++++++++++++++++++++++++++---- test/unit/extract.cpp | 36 ++++++++++++++++++++ 2 files changed, 87 insertions(+), 7 deletions(-) diff --git a/include/ankerl/unordered_dense.h b/include/ankerl/unordered_dense.h index e0168d00..a91635dd 100644 --- a/include/ankerl/unordered_dense.h +++ b/include/ankerl/unordered_dense.h @@ -86,6 +86,7 @@ # include // for pair, distance # include // for numeric_limits # include // for allocator, allocator_traits, shared_ptr +# include // for optional # include // for out_of_range # include // for basic_string # include // for basic_string_view, hash @@ -1008,7 +1009,8 @@ class table : public std::conditional_t, base_table_type_map, bas clear_and_fill_buckets_from_values(); } - void do_erase(value_idx_type bucket_idx) { + template + void do_erase(value_idx_type bucket_idx, Op handle_erased_value) { auto const value_idx_to_remove = at(m_buckets, bucket_idx).m_value_idx; // shift down until either empty or an element with correct spot is found @@ -1019,6 +1021,7 @@ class table : public std::conditional_t, base_table_type_map, bas bucket_idx = std::exchange(next_bucket_idx, next(next_bucket_idx)); } at(m_buckets, bucket_idx) = {}; + handle_erased_value(std::move(m_values[value_idx_to_remove])); // update m_values if (value_idx_to_remove != m_values.size() - 1) { @@ -1039,8 +1042,8 @@ class table : public std::conditional_t, base_table_type_map, bas m_values.pop_back(); } - template - auto do_erase_key(K&& key) -> size_t { + template + auto do_erase_key(K&& key, Op handle_erased_value) -> size_t { if (empty()) { return 0; } @@ -1056,7 +1059,7 @@ class table : public std::conditional_t, base_table_type_map, bas if (dist_and_fingerprint != at(m_buckets, bucket_idx).m_dist_and_fingerprint) { return 0; } - do_erase(bucket_idx); + do_erase(bucket_idx, handle_erased_value); return 1; } @@ -1619,15 +1622,37 @@ class table : public std::conditional_t, base_table_type_map, bas bucket_idx = next(bucket_idx); } - do_erase(bucket_idx); + do_erase(bucket_idx, [](value_type&& /*unused*/) { + }); return begin() + static_cast(value_idx_to_remove); } + auto extract(iterator it) -> value_type { + auto hash = mixed_hash(get_key(*it)); + auto bucket_idx = bucket_idx_from_hash(hash); + + auto const value_idx_to_remove = static_cast(it - cbegin()); + while (at(m_buckets, bucket_idx).m_value_idx != value_idx_to_remove) { + bucket_idx = next(bucket_idx); + } + + auto tmp = std::optional{}; + do_erase(bucket_idx, [&tmp](value_type&& val) { + tmp = std::move(val); + }); + return std::move(tmp).value(); + } + template , bool> = true> auto erase(const_iterator it) -> iterator { return erase(begin() + (it - cbegin())); } + template , bool> = true> + auto extract(const_iterator it) -> value_type { + return extract(begin() + (it - cbegin())); + } + auto erase(const_iterator first, const_iterator last) -> iterator { auto const idx_first = first - cbegin(); auto const idx_last = last - cbegin(); @@ -1653,12 +1678,31 @@ class table : public std::conditional_t, base_table_type_map, bas } auto erase(Key const& key) -> size_t { - return do_erase_key(key); + return do_erase_key(key, [](value_type&& /*unused*/) { + }); + } + + auto extract(Key const& key) -> std::optional { + auto tmp = std::optional{}; + do_erase_key(key, [&tmp](value_type&& val) { + tmp = std::move(val); + }); + return tmp; } template , bool> = true> auto erase(K&& key) -> size_t { - return do_erase_key(std::forward(key)); + return do_erase_key(std::forward(key), [](value_type&& /*unused*/) { + }); + } + + template , bool> = true> + auto extract(K&& key) -> std::optional { + auto tmp = std::optional{}; + do_erase_key(std::forward(key), [&tmp](value_type&& val) { + tmp = std::move(val); + }); + return tmp; } void swap(table& other) noexcept(noexcept(std::is_nothrow_swappable_v && diff --git a/test/unit/extract.cpp b/test/unit/extract.cpp index 75b7ac04..a5db77e0 100644 --- a/test/unit/extract.cpp +++ b/test/unit/extract.cpp @@ -3,6 +3,8 @@ #include #include +#include + TEST_CASE_MAP("extract", counter::obj, counter::obj) { auto counts = counter(); INFO(counts); @@ -24,3 +26,37 @@ TEST_CASE_MAP("extract", counter::obj, counter::obj) { REQUIRE(container[i].second.get() == i); } } + +TEST_CASE_MAP("extract_element", counter::obj, counter::obj) { + auto counts = counter(); + INFO(counts); + + counts("init"); + auto map = map_t(); + for (size_t i = 0; i < 100; ++i) { + map.try_emplace(counter::obj{i, counts}, i, counts); + } + + // extract(key) + for (size_t i = 0; i < 20; ++i) { + auto query = counter::obj{i, counts}; + counts("before remove 1"); + auto opt = map.extract(query); + counts("after remove 1"); + REQUIRE(opt); + REQUIRE(opt->first.get() == i); + REQUIRE(opt->second.get() == i); + } + REQUIRE(map.size() == 80); + + // extract iterator + for (size_t i = 20; i < 100; ++i) { + auto query = counter::obj{i, counts}; + auto it = map.find(query); + REQUIRE(it != map.end()); + auto opt = map.extract(it); + REQUIRE(opt.first.get() == i); + REQUIRE(opt.second.get() == i); + } + REQUIRE(map.empty()); +}