diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b726231db5..3b027220032 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ - PR #6907 Add `replace_null` API with `replace_policy` parameter, `fixed_width` column support - PR #6885 Share `factorize` implementation with Index and cudf module - PR #6775 Implement cudf.DateOffset for months +- PR #7039 Support contains() on lists of primitives ## Improvements diff --git a/conda/recipes/libcudf/meta.yaml b/conda/recipes/libcudf/meta.yaml index ad9c72a10b2..894c8750f1a 100644 --- a/conda/recipes/libcudf/meta.yaml +++ b/conda/recipes/libcudf/meta.yaml @@ -125,6 +125,7 @@ test: - test -f $PREFIX/include/cudf/lists/detail/copying.hpp - test -f $PREFIX/include/cudf/lists/count_elements.hpp - test -f $PREFIX/include/cudf/lists/extract.hpp + - test -f $PREFIX/include/cudf/lists/contains.hpp - test -f $PREFIX/include/cudf/lists/gather.hpp - test -f $PREFIX/include/cudf/lists/lists_column_view.hpp - test -f $PREFIX/include/cudf/merge.hpp diff --git a/cpp/include/cudf/detail/iterator.cuh b/cpp/include/cudf/detail/iterator.cuh index 75a710d1d5c..e95d932920e 100644 --- a/cpp/include/cudf/detail/iterator.cuh +++ b/cpp/include/cudf/detail/iterator.cuh @@ -174,6 +174,21 @@ auto inline make_validity_iterator(column_device_view const& column) validity_accessor{column}); } +/** + * @brief Constructs a constant device iterator over a scalar's validity. + * + * Dereferencing the returned iterator returns a `bool`. + * + * For `p = *(iter + i)`, `p` is the validity of the scalar, read once at iterator construction. + * + * @param scalar_value The scalar whose validity the iterator returns + * @return auto Constant iterator that returns the scalar's validity + */ +auto inline make_validity_iterator(scalar const& scalar_value) +{ + return thrust::make_constant_iterator(scalar_value.is_valid()); +} + /** * @brief value accessor for scalar with valid data. * The unary functor returns data of Element type of the scalar.
diff --git a/cpp/include/cudf/lists/contains.hpp b/cpp/include/cudf/lists/contains.hpp new file mode 100644 index 00000000000..7cd40bb2f86 --- /dev/null +++ b/cpp/include/cudf/lists/contains.hpp @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include + +namespace cudf { +namespace lists { +/** + * @addtogroup lists_contains + * @{ + * @file + */ + +/** + * @brief Create a column of bool values indicating whether the specified scalar + * is an element of each row of a list column. + * + * The output column has as many elements as the input `lists` column. + * Output `column[i]` is set to true if the lists row `lists[i]` contains the value + * specified in `search_key`. Otherwise, it is set to false. + * + * Output `column[i]` is set to null if one or more of the following are true: + * 1. The search key `search_key` is null + * 2. The list row `lists[i]` is null + * 3. The list row `lists[i]` does not contain the search key, and contains at least + * one null element. + * + * @param lists Lists column whose `n` rows are to be searched + * @param search_key Scalar key to look up in each list row; must match the list element type + * @param mr Device memory resource used to allocate the returned column's device memory.
+ * @return std::unique_ptr BOOL8 column of `n` rows with the result of the lookup + */ +std::unique_ptr contains( + cudf::lists_column_view const& lists, + cudf::scalar const& search_key, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + +/** + * @brief Create a column of bool values indicating whether the list rows of the first + * column contain the corresponding values in the second column + * + * The output column has as many elements as the input `lists` column. + * Output `column[i]` is set to true if the lists row `lists[i]` contains the value + * in `search_keys[i]`. Otherwise, it is set to false. + * + * Output `column[i]` is set to null if one or more of the following are true: + * 1. The row `search_keys[i]` is null + * 2. The list row `lists[i]` is null + * 3. The list row `lists[i]` does not contain `search_keys[i]`, and contains at least + * one null element. + * + * @param lists Lists column whose `n` rows are to be searched + * @param search_keys Column of `n` keys (same type as the list elements), one per list row + * @param mr Device memory resource used to allocate the returned column's device memory.
+ * @return std::unique_ptr BOOL8 column of `n` rows with the result of the lookup + */ +std::unique_ptr contains( + cudf::lists_column_view const& lists, + cudf::column_view const& search_keys, + rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); + +/** @} */ // end of group +} // namespace lists +} // namespace cudf diff --git a/cpp/include/cudf/lists/list_device_view.cuh b/cpp/include/cudf/lists/list_device_view.cuh index 38708d4878e..824b10ced83 100644 --- a/cpp/include/cudf/lists/list_device_view.cuh +++ b/cpp/include/cudf/lists/list_device_view.cuh @@ -112,12 +112,82 @@ class list_device_view { */ CUDA_DEVICE_CALLABLE lists_column_device_view const& get_column() const { return lists_column; } + template + struct pair_accessor; + + template + using const_pair_iterator = + thrust::transform_iterator, thrust::counting_iterator>; + + /** + * @brief Fetcher for a pair iterator to the first element in the list_device_view. + * + * Dereferencing the returned iterator yields a `thrust::pair`. + * + * If the element at index `i` is valid, then for `p = iter[i]`, + * 1. `p.first` is the value of the element at `i` + * 2. `p.second == true` + * + * If the element at index `i` is null, + * 1. `p.first` is undefined + * 2. `p.second == false` + */ + template + CUDA_DEVICE_CALLABLE const_pair_iterator pair_begin() const + { + return const_pair_iterator{thrust::counting_iterator(0), pair_accessor{*this}}; + } + + /** + * @brief Fetcher for a pair iterator to one position past the last element in the + * list_device_view. + */ + template + CUDA_DEVICE_CALLABLE const_pair_iterator pair_end() const + { + return const_pair_iterator{thrust::counting_iterator(size()), + pair_accessor{*this}}; + } + private: lists_column_device_view const& lists_column; size_type _row_index{}; // Row index in the Lists column vector. size_type _size{}; // Number of elements in *this* list row.
size_type begin_offset; // Offset in list_column_device_view where this list begins. + + /** + * @brief pair accessor for elements in a `list_device_view` + * + * This unary functor returns a pair of: + * 1. data element at a specified index + * 2. boolean validity flag for that element + * + * @tparam T The element-type of the list row + */ + template + struct pair_accessor { + list_device_view const& list; + + /** + * @brief Constructor; keeps a reference to the list being accessed + * + * @param _list The `list_device_view` whose elements are being accessed. + */ + explicit CUDA_HOST_DEVICE_CALLABLE pair_accessor(list_device_view const& _list) : list{_list} {} + + /** + * @brief Accessor for the {data, validity} pair at the specified index + * + * @param i Index into the list_device_view + * @return Pair of the element and its validity flag (`true` when the element is not null). + */ + CUDA_DEVICE_CALLABLE + thrust::pair operator()(cudf::size_type i) const + { + return {list.element(i), !list.is_null(i)}; + } + }; }; } // namespace cudf diff --git a/cpp/include/doxygen_groups.h b/cpp/include/doxygen_groups.h index 1d796aca4b7..e732a13e67c 100644 --- a/cpp/include/doxygen_groups.h +++ b/cpp/include/doxygen_groups.h @@ -143,6 +143,8 @@ * @defgroup lists_apis Lists * @{ * @defgroup lists_extract Extracting + * @defgroup lists_contains Searching + * @defgroup lists_gather Gathering * @defgroup lists_elements Counting * @} * @defgroup nvtext_apis NVText diff --git a/cpp/src/lists/contains.cu b/cpp/src/lists/contains.cu new file mode 100644 index 00000000000..49f06d5acfd --- /dev/null +++ b/cpp/src/lists/contains.cu @@ -0,0 +1,251 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace cudf { +namespace lists { + +namespace { + +auto get_search_keys_device_iterable_view(cudf::column_view const& search_keys, + rmm::cuda_stream_view stream) +{ + return column_device_view::create(search_keys, stream); +} + +auto get_search_keys_device_iterable_view(cudf::scalar const& search_key, rmm::cuda_stream_view) +{ + return &search_key; +} + +template +auto get_pair_iterator(cudf::column_device_view const& d_search_keys) +{ + return d_search_keys.pair_begin(); +} + +template +auto get_pair_iterator(cudf::scalar const& search_key) +{ + return cudf::detail::make_pair_iterator(search_key); +} + +/** + * @brief Functor to search each list row for the specified search keys. + */ +template +struct lookup_functor { + template + struct is_supported { + static constexpr bool value = cudf::is_numeric() || + cudf::is_chrono() || + std::is_same::value; + }; + + template + std::enable_if_t::value, std::unique_ptr> operator()( + Args&&...) 
const + { + CUDF_FAIL("lists::contains() is only supported on numeric types, chrono types, and strings."); + } + + std::pair construct_null_mask(lists_column_view const& input_lists, + column_view const& result_validity, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + { + if (!search_keys_have_nulls && !input_lists.has_nulls() && !input_lists.child().has_nulls()) {/* No null keys, list rows or list elements: return an empty mask and a null count of 0. */ + return {rmm::device_buffer{0, stream, mr}, size_type{0}}; + } else {/* Otherwise derive the mask from the per-row flags in `result_validity`. */ + return cudf::detail::valid_if(result_validity.begin(), + result_validity.end(), + thrust::identity{}, + stream, + mr); + } + } + + template + void search_each_list_row(cudf::detail::lists_column_device_view const& d_lists, + SearchKeyPairIter search_key_pair_iter, + cudf::mutable_column_device_view mutable_ret_bools, + cudf::mutable_column_device_view mutable_ret_validity, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + {/* NOTE(review): `mr` is not referenced in this function body as shown - confirm it is needed. */ + thrust::for_each( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(d_lists.size()), + [d_lists, + search_key_pair_iter, + d_bools = mutable_ret_bools.data(), + d_validity = mutable_ret_validity.data()] __device__(auto row_index) { + auto search_key_and_validity = search_key_pair_iter[row_index]; + auto const& search_key_is_valid = search_key_and_validity.second; + + if (search_keys_have_nulls && !search_key_is_valid) {/* Null search key: the result row is null. */ + d_bools[row_index] = false; + d_validity[row_index] = false; + return; + } + + auto list = cudf::list_device_view(d_lists, row_index); + if (list.is_null()) {/* Null list row: the result row is null. */ + d_bools[row_index] = false; + d_validity[row_index] = false; + return; + } + + auto search_key = search_key_and_validity.first; + d_bools[row_index] = thrust::find_if(thrust::seq, + list.pair_begin(), + list.pair_end(), + [search_key] __device__(auto element_and_validity) { + return element_and_validity.second && + (element_and_validity.first == search_key); + }) != list.pair_end(); + d_validity[row_index] =/* Valid when the key was found, or when the list has no null elements. */ + d_bools[row_index] || +
thrust::none_of(thrust::seq, + thrust::make_counting_iterator(size_type{0}), + thrust::make_counting_iterator(list.size()), + [&list] __device__(auto const& i) { return list.is_null(i); }); + }); + } + + template + std::enable_if_t::value, std::unique_ptr> operator()( + cudf::lists_column_view const& lists, + SearchKeyType const& search_key, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) + { + using namespace cudf; + using namespace cudf::detail; + + CUDF_EXPECTS(!cudf::is_nested(lists.child().type()), + "Nested types not supported in lists::contains()"); + CUDF_EXPECTS(lists.child().type().id() == search_key.type().id(), + "Type of search key does not match list column element type."); + CUDF_EXPECTS(search_key.type().id() != type_id::EMPTY, "Type cannot be empty."); + + auto constexpr search_key_is_scalar = std::is_same::value; + + if (search_keys_have_nulls && search_key_is_scalar) {/* Null scalar search key: return an all-null BOOL8 column. */ + return make_fixed_width_column(data_type(type_id::BOOL8), + lists.size(), + cudf::create_null_mask(lists.size(), mask_state::ALL_NULL, mr), + lists.size(), + stream, + mr); + } + + auto const device_view = column_device_view::create(lists.parent(), stream); + auto const d_lists = lists_column_device_view(*device_view); + auto const d_skeys = get_search_keys_device_iterable_view(search_key, stream); + + auto const lists_column_has_nulls = lists.has_nulls() || lists.child().has_nulls();/* NOTE(review): not referenced again in this function as shown - confirm whether it can be removed. */ + + auto result_validity = make_fixed_width_column( + data_type{type_id::BOOL8}, lists.size(), cudf::mask_state::UNALLOCATED, stream, mr);/* Scratch BOOL8 column of per-row validity flags, filled by search_each_list_row. */ + auto result_bools = make_fixed_width_column( + data_type{type_id::BOOL8}, lists.size(), cudf::mask_state::UNALLOCATED, stream, mr); + auto mutable_result_bools = + mutable_column_device_view::create(result_bools->mutable_view(), stream); + auto mutable_result_validity = + mutable_column_device_view::create(result_validity->mutable_view(), stream); + auto search_key_iter = get_pair_iterator(*d_skeys); + + search_each_list_row( + d_lists,
search_key_iter, *mutable_result_bools, *mutable_result_validity, stream, mr); + + rmm::device_buffer null_mask; + size_type num_nulls; + + std::tie(null_mask, num_nulls) = + construct_null_mask(lists, result_validity->view(), stream, mr); + result_bools->set_null_mask(std::move(null_mask), num_nulls); + + return result_bools; + } +}; + +} // namespace + +namespace detail { + +std::unique_ptr contains(cudf::lists_column_view const& lists, + cudf::scalar const& search_key, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + return search_key.is_valid(stream) + ? cudf::type_dispatcher( + search_key.type(), lookup_functor{}, lists, search_key, stream, mr) + : cudf::type_dispatcher( + search_key.type(), lookup_functor{}, lists, search_key, stream, mr); +} + +std::unique_ptr contains(cudf::lists_column_view const& lists, + cudf::column_view const& search_keys, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + CUDF_EXPECTS(search_keys.size() == lists.size(), + "Number of search keys must match list column size."); + + return search_keys.has_nulls() + ? 
cudf::type_dispatcher( + search_keys.type(), lookup_functor{}, lists, search_keys, stream, mr) + : cudf::type_dispatcher( + search_keys.type(), lookup_functor{}, lists, search_keys, stream, mr); +} + +} // namespace detail + +std::unique_ptr contains(cudf::lists_column_view const& lists, + cudf::scalar const& search_key, + rmm::mr::device_memory_resource* mr) +{ + CUDF_FUNC_RANGE(); + return detail::contains(lists, search_key, rmm::cuda_stream_default, mr); +} + +std::unique_ptr contains(cudf::lists_column_view const& lists, + cudf::column_view const& search_keys, + rmm::mr::device_memory_resource* mr) +{ + CUDF_FUNC_RANGE(); + return detail::contains(lists, search_keys, rmm::cuda_stream_default, mr); +} + +} // namespace lists +} // namespace cudf diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 587a833c86e..a654b3e5cf6 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -654,6 +654,7 @@ ConfigureTest(AST_TEST "${AST_TEST_SRC}") # - lists tests ---------------------------------------------------------------------------------- set(LISTS_TEST_SRC + "${CMAKE_CURRENT_SOURCE_DIR}/lists/contains_tests.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/lists/count_elements_tests.cpp" "${CMAKE_CURRENT_SOURCE_DIR}/lists/extract_tests.cpp") diff --git a/cpp/tests/lists/contains_tests.cpp b/cpp/tests/lists/contains_tests.cpp new file mode 100644 index 00000000000..1885f626490 --- /dev/null +++ b/cpp/tests/lists/contains_tests.cpp @@ -0,0 +1,568 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace cudf { +namespace test { + +struct ContainsTest : public BaseFixture { +}; + +using ContainsTestTypes = Concat; + +template +struct TypedContainsTest : public ContainsTest { +}; + +TYPED_TEST_CASE(TypedContainsTest, ContainsTestTypes); + +namespace { +template (), void>* = nullptr> +auto create_scalar_search_key(T const& value) +{ + auto search_key = make_numeric_scalar(data_type{type_to_id()}); + search_key->set_valid(true); + static_cast*>(search_key.get())->set_value(value); + return search_key; +} + +template ::value, void>* = nullptr> +auto create_scalar_search_key(std::string const& value) +{ + return make_string_scalar(value); +} + +template (), void>* = nullptr> +auto create_scalar_search_key(typename T::rep const& value) +{ + auto search_key = make_timestamp_scalar(data_type{type_to_id()}); + search_key->set_valid(true); + static_cast*>(search_key.get())->set_value(value); + return search_key; +} + +template (), void>* = nullptr> +auto create_scalar_search_key(typename T::rep const& value) +{ + auto search_key = make_duration_scalar(data_type{type_to_id()}); + search_key->set_valid(true); + static_cast*>(search_key.get())->set_value(value); + return search_key; +} + +template (), void>* = nullptr> +auto create_null_search_key() +{ + auto search_key = make_numeric_scalar(data_type{type_to_id()}); + search_key->set_valid(false); + return search_key; +} + +template (), void>* = nullptr> +auto create_null_search_key() +{ + auto search_key = make_timestamp_scalar(data_type{type_to_id()}); + search_key->set_valid(false); + return search_key; +} + +template (), void>* = nullptr> +auto create_null_search_key() +{ + auto search_key = make_duration_scalar(data_type{type_to_id()}); + search_key->set_valid(false); + return search_key; +} + +} // namespace + 
+TYPED_TEST(TypedContainsTest, ListContainsScalarWithNoNulls) +{ + using T = TypeParam; + + auto search_space = lists_column_wrapper{ + {0, 1, 2}, + {3, 4, 5}, + {6, 7, 8}, + {9, 0, 1}, + {2, 3, 4}, + {5, 6, 7}, + {8, 9, 0}, + {}, + {1, 2, 3}, + {}}.release(); + + auto search_key_one = create_scalar_search_key(1); + + auto actual_result = lists::contains(search_space->view(), *search_key_one); + + auto expected_result = fixed_width_column_wrapper{1, 0, 0, 1, 0, 0, 0, 0, 1, 0}; + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result, *actual_result); +} + +TYPED_TEST(TypedContainsTest, ListContainsScalarWithNullLists) +{ + // Test List columns that have NULL list rows. + + using T = TypeParam; + + auto search_space = lists_column_wrapper{ + {{0, 1, 2}, + {3, 4, 5}, + {6, 7, 8}, + {}, + {9, 0, 1}, + {2, 3, 4}, + {5, 6, 7}, + {8, 9, 0}, + {}, + {1, 2, 3}, + {}}, + make_counting_transform_iterator(0, [](auto i) { + return (i != 3) && (i != 10); + })}.release(); + + auto search_key_one = create_scalar_search_key(1); + + auto actual_result = lists::contains(search_space->view(), *search_key_one); + + auto expected_result = fixed_width_column_wrapper{ + {1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0}, + make_counting_transform_iterator(0, [](auto i) { return (i != 3) && (i != 10); })}; + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result, *actual_result); +} + +TYPED_TEST(TypedContainsTest, ListContainsScalarNonNullListsWithNullValues) +{ + // Test List columns that have no NULL list rows, but NULL elements in some list rows. 
+ using T = TypeParam; + + auto numerals = fixed_width_column_wrapper{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4}, + make_counting_transform_iterator(0, [](auto i) -> bool { return i % 3; })}; + + auto search_space = + make_lists_column(8, + fixed_width_column_wrapper{0, 1, 3, 7, 7, 7, 10, 11, 15}.release(), + numerals.release(), + 0, + {}); + + auto search_key_one = create_scalar_search_key(1); + + auto actual_result = lists::contains(search_space->view(), *search_key_one); + + auto expected_result = + fixed_width_column_wrapper{{0, 1, 0, 0, 0, 0, 0, 1}, {0, 1, 0, 1, 1, 0, 1, 1}}; + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result, *actual_result); +} + +TYPED_TEST(TypedContainsTest, ListContainsScalarWithNullsInLists) +{ + using T = TypeParam; + + auto numerals = fixed_width_column_wrapper{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4}, + make_counting_transform_iterator(0, [](auto i) -> bool { return i % 3; })}; + + auto input_null_mask_iter = make_counting_transform_iterator(0, [](auto i) { return i != 4; }); + + auto search_space = make_lists_column( + 8, + fixed_width_column_wrapper{0, 1, 3, 7, 7, 7, 10, 11, 15}.release(), + numerals.release(), + 1, + cudf::test::detail::make_null_mask(input_null_mask_iter, input_null_mask_iter + 8)); + + auto search_key_one = create_scalar_search_key(1); + + auto actual_result = lists::contains(search_space->view(), *search_key_one); + + auto expected_result = + fixed_width_column_wrapper{{0, 1, 0, 0, 0, 0, 0, 1}, {0, 1, 0, 1, 0, 0, 1, 1}}; + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result, *actual_result); +} + +TEST_F(ContainsTest, BoolListContainsScalarWithNullsInLists) +{ + using T = bool; + + auto numerals = fixed_width_column_wrapper{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4}, + make_counting_transform_iterator(0, [](auto i) -> bool { return i % 3; })}; + + auto input_null_mask_iter = make_counting_transform_iterator(0, [](auto i) { return i != 4; }); + + auto search_space = make_lists_column( + 
8, + fixed_width_column_wrapper{0, 1, 3, 7, 7, 7, 10, 11, 15}.release(), + numerals.release(), + 1, + cudf::test::detail::make_null_mask(input_null_mask_iter, input_null_mask_iter + 8)); + + auto search_key_one = create_scalar_search_key(1); + + auto actual_result = lists::contains(search_space->view(), *search_key_one); + + auto expected_result = + fixed_width_column_wrapper{{0, 1, 1, 0, 0, 1, 0, 1}, {0, 1, 1, 1, 0, 1, 1, 1}}; + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result, *actual_result); +} + +TEST_F(ContainsTest, StringListContainsScalarWithNullsInLists) +{ + using T = std::string; + + auto strings = strings_column_wrapper{ + {"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "1", "2", "3", "4"}, + make_counting_transform_iterator(0, [](auto i) -> bool { return i % 3; })}; + + auto input_null_mask_iter = make_counting_transform_iterator(0, [](auto i) { return i != 4; }); + + auto search_space = make_lists_column( + 8, + fixed_width_column_wrapper{0, 1, 3, 7, 7, 7, 10, 11, 15}.release(), + strings.release(), + 1, + cudf::test::detail::make_null_mask(input_null_mask_iter, input_null_mask_iter + 8)); + + auto search_key_one = create_scalar_search_key("1"); + + auto actual_result = lists::contains(search_space->view(), *search_key_one); + + auto expected_result = + fixed_width_column_wrapper{{0, 1, 0, 0, 0, 0, 0, 1}, {0, 1, 0, 1, 0, 0, 1, 1}}; + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result, *actual_result); +} + +TYPED_TEST(TypedContainsTest, ContainsScalarNullSearchKey) +{ + using T = TypeParam; + + auto search_space = lists_column_wrapper{ + {{0, 1, 2}, + {3, 4, 5}, + {6, 7, 8}, + {}, + {9, 0, 1}, + {2, 3, 4}, + {5, 6, 7}, + {8, 9, 0}, + {}, + {1, 2, 3}, + {}}, + make_counting_transform_iterator(0, [](auto i) { + return (i != 3) && (i != 10); + })}.release(); + + auto search_key_null = create_null_search_key(); + + auto actual_result = lists::contains(search_space->view(), *search_key_null); + + auto expected_result = 
fixed_width_column_wrapper{ + {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + make_counting_transform_iterator(0, [](auto i) { return false; })}; + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result, *actual_result); +} + +TEST_F(ContainsTest, ScalarTypeRelatedExceptions) +{ + { + // Nested types unsupported. + auto list_of_lists = lists_column_wrapper{ + {{1, 2, 3}, {4, 5, 6}}, + {{1, 2, 3}, {4, 5, 6}}, + {{1, 2, 3}, + {4, 5, 6}}}.release(); + auto skey = create_scalar_search_key(10); + CUDF_EXPECT_THROW_MESSAGE(lists::contains(list_of_lists->view(), *skey), + "Nested types not supported in lists::contains()"); + } + + { + // Search key must match list elements in type. + auto list_of_ints = + lists_column_wrapper{ + {0, 1, 2}, + {3, 4, 5}, + } + .release(); + auto skey = create_scalar_search_key("Hello, World!"); + CUDF_EXPECT_THROW_MESSAGE(lists::contains(list_of_ints->view(), *skey), + "Type of search key does not match list column element type."); + } +} + +template +struct TypedVectorContainsTest : public ContainsTest { +}; + +using VectorContainsTestTypes = + cudf::test::Concat; + +TYPED_TEST_CASE(TypedVectorContainsTest, VectorContainsTestTypes); + +TYPED_TEST(TypedVectorContainsTest, ListContainsVectorWithNoNulls) +{ + using T = TypeParam; + + auto search_space = lists_column_wrapper{ + {0, 1, 2}, + {3, 4, 5}, + {6, 7, 8}, + {9, 0, 1}, + {2, 3, 4}, + {5, 6, 7}, + {8, 9, 0}, + {}, + {1, 2, 3}, + {}}.release(); + + auto search_key = fixed_width_column_wrapper{1, 2, 3, 1, 2, 3, 1, 2, 3, 1}; + + auto actual_result = lists::contains(search_space->view(), search_key); + + auto expected_result = fixed_width_column_wrapper{1, 0, 0, 1, 1, 0, 0, 0, 1, 0}; + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result, *actual_result); +} + +TYPED_TEST(TypedVectorContainsTest, ListContainsVectorWithNullLists) +{ + // Test List columns that have NULL list rows. 
+ + using T = TypeParam; + + auto search_space = lists_column_wrapper{ + {{0, 1, 2}, + {3, 4, 5}, + {6, 7, 8}, + {}, + {9, 0, 1}, + {2, 3, 4}, + {5, 6, 7}, + {8, 9, 0}, + {}, + {1, 2, 3}, + {}}, + make_counting_transform_iterator(0, [](auto i) { + return (i != 3) && (i != 10); + })}.release(); + + auto search_keys = fixed_width_column_wrapper{1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2}; + + auto actual_result = lists::contains(search_space->view(), search_keys); + + auto expected_result = fixed_width_column_wrapper{ + {1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0}, + make_counting_transform_iterator(0, [](auto i) { return (i != 3) && (i != 10); })}; + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result, *actual_result); +} + +TYPED_TEST(TypedVectorContainsTest, ListContainsVectorNonNullListsWithNullValues) +{ + // Test List columns that have no NULL list rows, but NULL elements in some list rows. + using T = TypeParam; + + auto numerals = fixed_width_column_wrapper{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4}, + make_counting_transform_iterator(0, [](auto i) -> bool { return i % 3; })}; + + auto search_space = + make_lists_column(8, + fixed_width_column_wrapper{0, 1, 3, 7, 7, 7, 10, 12, 15}.release(), + numerals.release(), + 0, + {}); + + auto search_keys = fixed_width_column_wrapper{1, 2, 3, 1, 2, 3, 1, 3}; + + auto actual_result = lists::contains(search_space->view(), search_keys); + + auto expected_result = + fixed_width_column_wrapper{{0, 1, 0, 0, 0, 0, 1, 1}, {0, 1, 0, 1, 1, 0, 1, 1}}; + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result, *actual_result); +} + +TYPED_TEST(TypedVectorContainsTest, ListContainsVectorWithNullsInLists) +{ + using T = TypeParam; + + auto numerals = fixed_width_column_wrapper{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4}, + make_counting_transform_iterator(0, [](auto i) -> bool { return i % 3; })}; + + auto input_null_mask_iter = make_counting_transform_iterator(0, [](auto i) { return i != 4; }); + + auto search_space = make_lists_column( + 8, 
+ fixed_width_column_wrapper{0, 1, 3, 7, 7, 7, 10, 12, 15}.release(), + numerals.release(), + 1, + cudf::test::detail::make_null_mask(input_null_mask_iter, input_null_mask_iter + 8)); + + auto search_keys = fixed_width_column_wrapper{1, 2, 3, 1, 2, 3, 1, 3}; + + auto actual_result = lists::contains(search_space->view(), search_keys); + + auto expected_result = + fixed_width_column_wrapper{{0, 1, 0, 0, 0, 0, 1, 1}, {0, 1, 0, 1, 0, 0, 1, 1}}; + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result, *actual_result); +} + +TYPED_TEST(TypedVectorContainsTest, ListContainsVectorWithNullsInListsAndInSearchKeys) +{ + using T = TypeParam; + + auto numerals = fixed_width_column_wrapper{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4}, + make_counting_transform_iterator(0, [](auto i) -> bool { return i % 3; })}; + + auto input_null_mask_iter = make_counting_transform_iterator(0, [](auto i) { return i != 4; }); + + auto search_space = make_lists_column( + 8, + fixed_width_column_wrapper{0, 1, 3, 7, 7, 7, 10, 12, 15}.release(), + numerals.release(), + 1, + cudf::test::detail::make_null_mask(input_null_mask_iter, input_null_mask_iter + 8)); + + auto search_keys = fixed_width_column_wrapper{ + {1, 2, 3, 1, 2, 3, 1, 3}, make_counting_transform_iterator(0, [](auto i) { return i != 6; })}; + + auto actual_result = lists::contains(search_space->view(), search_keys); + + auto expected_result = + fixed_width_column_wrapper{{0, 1, 0, 0, 0, 0, 0, 1}, {0, 1, 0, 1, 0, 0, 0, 1}}; + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result, *actual_result); +} + +TEST_F(ContainsTest, BoolListContainsVectorWithNullsInListsAndInSearchKeys) +{ + using T = bool; + + auto numerals = fixed_width_column_wrapper{ + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4}, + make_counting_transform_iterator(0, [](auto i) -> bool { return i % 3; })}; + + auto input_null_mask_iter = make_counting_transform_iterator(0, [](auto i) { return i != 4; }); + + auto search_space = make_lists_column( + 8, + 
fixed_width_column_wrapper{0, 1, 3, 7, 7, 7, 10, 12, 15}.release(), + numerals.release(), + 1, + cudf::test::detail::make_null_mask(input_null_mask_iter, input_null_mask_iter + 8)); + + auto search_keys = fixed_width_column_wrapper{ + {0, 1, 0, 1, 0, 0, 1, 1}, make_counting_transform_iterator(0, [](auto i) { return i != 6; })}; + + auto actual_result = lists::contains(search_space->view(), search_keys); + + auto expected_result = + fixed_width_column_wrapper{{0, 1, 0, 0, 0, 0, 0, 1}, {0, 1, 0, 1, 0, 0, 0, 1}}; + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result, *actual_result); +} + +TEST_F(ContainsTest, StringListContainsVectorWithNullsInListsAndInSearchKeys) +{ + auto numerals = strings_column_wrapper{ + {"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "0", "1", "2", "3", "4"}, + make_counting_transform_iterator(0, [](auto i) -> bool { return i % 3; })}; + + auto input_null_mask_iter = make_counting_transform_iterator(0, [](auto i) { return i != 4; }); + + auto search_space = make_lists_column( + 8, + fixed_width_column_wrapper{0, 1, 3, 7, 7, 7, 10, 12, 15}.release(), + numerals.release(), + 1, + cudf::test::detail::make_null_mask(input_null_mask_iter, input_null_mask_iter + 8)); + + auto search_keys = + strings_column_wrapper{{"1", "2", "3", "1", "2", "3", "1", "3"}, + make_counting_transform_iterator(0, [](auto i) { return i != 6; })}; + + auto actual_result = lists::contains(search_space->view(), search_keys); + + auto expected_result = + fixed_width_column_wrapper{{0, 1, 0, 0, 0, 0, 0, 1}, {0, 1, 0, 1, 0, 0, 0, 1}}; + + CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_result, *actual_result); +} + +TEST_F(ContainsTest, VectorTypeRelatedExceptions) +{ + { + // Nested types unsupported. 
+ auto list_of_lists = lists_column_wrapper{ + {{1, 2, 3}, {4, 5, 6}}, + {{1, 2, 3}, {4, 5, 6}}, + {{1, 2, 3}, + {4, 5, 6}}}.release(); + auto skey = fixed_width_column_wrapper{0, 1, 2}; + CUDF_EXPECT_THROW_MESSAGE(lists::contains(list_of_lists->view(), skey), + "Nested types not supported in lists::contains()"); + } + + { + // Search key must match list elements in type. + auto list_of_ints = + lists_column_wrapper{ + {0, 1, 2}, + {3, 4, 5}, + } + .release(); + auto skey = strings_column_wrapper{"Hello", "World"}; + CUDF_EXPECT_THROW_MESSAGE(lists::contains(list_of_ints->view(), skey), + "Type of search key does not match list column element type."); + } + + { + // Search key column size must match lists column size. + auto list_of_ints = lists_column_wrapper{{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}.release(); + + auto skey = fixed_width_column_wrapper{0, 1, 2, 3}; + CUDF_EXPECT_THROW_MESSAGE(lists::contains(list_of_ints->view(), skey), + "Number of search keys must match list column size."); + } +} + +} // namespace test + +} // namespace cudf