Skip to content

Commit

Permalink
Add insert_if tests
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Aug 7, 2024
1 parent 26ce0b7 commit f933cb5
Showing 1 changed file with 69 additions and 36 deletions.
105 changes: 69 additions & 36 deletions tests/static_multimap/insert_if_test.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -26,51 +26,84 @@

#include <catch2/catch_template_test_macros.hpp>

template <typename Key, typename Map, typename PairIt, typename KeyIt>
void test_insert_if(Map& map, PairIt pair_begin, KeyIt key_begin, std::size_t size)
template <typename Map>
void test_insert_if(Map& map, std::size_t size)
{
using Key = typename Map::key_type;
using Value = typename Map::mapped_type;

// 50% insertion
auto pred_lambda = [] __device__(Key k) { return k % 2 == 0; };
auto const pred = [] __device__(Key k) { return k % 2 == 0; };
auto const keys_begin = thrust::counting_iterator<Key>{0};

SECTION("Count of n / 2 insertions should be n / 2.")
{
auto const pairs_begin = thrust::make_transform_iterator(
keys_begin, cuda::proclaim_return_type<cuco::pair<Key, Value>>([] __device__(auto i) {
return cuco::pair<Key, Value>{i, i};
}));

auto const num = map.insert_if(pairs_begin, pairs_begin + size, keys_begin, pred);
REQUIRE(num * 2 == size);

map.insert_if(pair_begin, pair_begin + size, key_begin, pred_lambda);
auto const count = map.count(keys_begin, keys_begin + size);
REQUIRE(count * 2 == size);
}

auto res = map.get_size();
REQUIRE(res * 2 == size);
SECTION("Inserting the same element n / 2 times should return n / 2.")
{
auto const pairs_begin = thrust::constant_iterator<cuco::pair<Key, Value>>{{1, 1}};

auto num = map.count(key_begin, key_begin + size);
REQUIRE(num * 2 == size);
auto const num = map.insert_if(pairs_begin, pairs_begin + size, keys_begin, pred);
REQUIRE(num * 2 == size);

auto const count = map.count(keys_begin, keys_begin + size);
REQUIRE(count * 2 == size);
}
}

TEMPLATE_TEST_CASE_SIG(
"Tests of insert_if",
"static_multimap insert_if",
"",
((typename Key, typename Value, cuco::test::probe_sequence Probe), Key, Value, Probe),
(int32_t, int32_t, cuco::test::probe_sequence::linear_probing),
(int32_t, int64_t, cuco::test::probe_sequence::linear_probing),
(int64_t, int64_t, cuco::test::probe_sequence::linear_probing),
(int32_t, int32_t, cuco::test::probe_sequence::double_hashing),
(int32_t, int64_t, cuco::test::probe_sequence::double_hashing),
(int64_t, int64_t, cuco::test::probe_sequence::double_hashing))
((typename Key, typename Value, cuco::test::probe_sequence Probe, int CGSize),
Key,
Value,
Probe,
CGSize),
(int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 1),
(int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 1),
(int32_t, int32_t, cuco::test::probe_sequence::double_hashing, 2),
(int32_t, int64_t, cuco::test::probe_sequence::double_hashing, 2),
(int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 1),
(int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 1),
(int64_t, int32_t, cuco::test::probe_sequence::double_hashing, 2),
(int64_t, int64_t, cuco::test::probe_sequence::double_hashing, 2),
(int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 1),
(int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 1),
(int32_t, int32_t, cuco::test::probe_sequence::linear_probing, 2),
(int32_t, int64_t, cuco::test::probe_sequence::linear_probing, 2),
(int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 1),
(int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 1),
(int64_t, int32_t, cuco::test::probe_sequence::linear_probing, 2),
(int64_t, int64_t, cuco::test::probe_sequence::linear_probing, 2))
{
constexpr std::size_t num_keys{1'000};

thrust::device_vector<Key> d_keys(num_keys);
thrust::device_vector<cuco::pair<Key, Value>> d_pairs(num_keys);

thrust::sequence(thrust::device, d_keys.begin(), d_keys.end());
// multiplicity = 1
thrust::transform(thrust::device,
thrust::counting_iterator<int>(0),
thrust::counting_iterator<int>(num_keys),
d_pairs.begin(),
[] __device__(auto i) { return cuco::pair<Key, Value>{i, i}; });

using probe =
std::conditional_t<Probe == cuco::test::probe_sequence::linear_probing,
cuco::legacy::linear_probing<1, cuco::default_hash_function<Key>>,
cuco::legacy::double_hashing<8, cuco::default_hash_function<Key>>>;

cuco::static_multimap<Key, Value, cuda::thread_scope_device, cuco::cuda_allocator<char>, probe>
map{num_keys * 2, cuco::empty_key<Key>{-1}, cuco::empty_value<Value>{-1}};
test_insert_if<Key>(map, d_pairs.begin(), d_keys.begin(), num_keys);
using extent_type = cuco::extent<std::size_t>;
using probe = std::conditional_t<
Probe == cuco::test::probe_sequence::linear_probing,
cuco::linear_probing<CGSize, cuco::murmurhash3_32<Key>>,
cuco::double_hashing<CGSize, cuco::murmurhash3_32<Key>, cuco::murmurhash3_32<Key>>>;

auto map = cuco::experimental::static_multimap<Key,
Value,
extent_type,
cuda::thread_scope_device,
thrust::equal_to<Key>,
probe,
cuco::cuda_allocator<std::byte>,
cuco::storage<2>>{
num_keys * 2, cuco::empty_key<Key>{-1}, cuco::empty_value<Value>{-1}};

test_insert_if(map, num_keys);
}

0 comments on commit f933cb5

Please sign in to comment.