-
Notifications
You must be signed in to change notification settings - Fork 1k
Strong index types for equality comparator #10883
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
50b8891
b9ed4d7
d67f17e
464ed2b
b26b318
1fd199d
18bd9f0
b5b8b39
4060b4f
73c4b27
9cdbe27
c8a38fe
8b5ef34
77f85b4
529e944
fb0e192
56d99ba
c5998b7
b78d978
3aea8d4
bbaf360
4a1d7aa
09c5661
290323f
4c69edd
ea8c223
857f570
d9f63f0
12f7a8b
fbd5b90
3db6484
3e81b53
1834095
9cb656b
c766bf3
a311bcc
b935835
ff26024
f779bff
157abbc
a2ac19d
75249e8
f50faf5
bed1162
6930952
2781af1
8b239d4
17bd96c
ae77f68
d6b5eb9
5f39a28
bf3555c
df98698
b67d070
4429cc4
893db8a
7e69d3b
c51b053
934ee73
cf996f0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -646,6 +646,7 @@ namespace equality { | |
| template <typename Nullate> | ||
| class device_row_comparator { | ||
| friend class self_comparator; | ||
| friend class two_table_comparator; | ||
|
|
||
| public: | ||
| /** | ||
|
|
@@ -855,6 +856,7 @@ struct preprocessed_table { | |
|
|
||
| private: | ||
| friend class self_comparator; | ||
| friend class two_table_comparator; | ||
| friend class hash::row_hasher; | ||
|
|
||
| using table_device_view_owner = | ||
|
|
@@ -923,6 +925,98 @@ class self_comparator { | |
| std::shared_ptr<preprocessed_table> d_t; | ||
| }; | ||
|
|
||
| template <typename Comparator> | ||
| struct strong_index_comparator_adapter { | ||
| __device__ constexpr bool operator()(lhs_index_type const lhs_index, | ||
| rhs_index_type const rhs_index) const noexcept | ||
| { | ||
| return comparator(static_cast<cudf::size_type>(lhs_index), | ||
| static_cast<cudf::size_type>(rhs_index)); | ||
| } | ||
|
|
||
| __device__ constexpr bool operator()(rhs_index_type const rhs_index, | ||
| lhs_index_type const lhs_index) const noexcept | ||
| { | ||
| return this->operator()(lhs_index, rhs_index); | ||
| } | ||
|
|
||
| Comparator const comparator; | ||
| }; | ||
|
|
||
| /** | ||
| * @brief An owning object that can be used to equality compare rows of two different tables. | ||
| * | ||
| * This class takes two table_views and preprocesses certain columns to allow for equality | ||
| * comparison. The preprocessed table and temporary data required for the comparison are created and | ||
| * owned by this class. | ||
| * | ||
| * Alternatively, `two_table_comparator` can be constructed from two existing | ||
| * `shared_ptr<preprocessed_table>`s when sharing the same tables among multiple comparators. | ||
| * | ||
| * This class can then provide a functor object that can used on the device. | ||
| * The object of this class must outlive the usage of the device functor. | ||
| */ | ||
| class two_table_comparator { | ||
| public: | ||
| /** | ||
| * @brief Construct an owning object for performing equality comparisons between two rows from two | ||
| * tables. | ||
| * | ||
| * The left and right table are expected to have the same number of columns and data types for | ||
| * each column. | ||
| * | ||
| * @param left The left table to compare. | ||
| * @param right The right table to compare. | ||
| * @param stream The stream to construct this object on. Not the stream that will be used for | ||
| * comparisons using this object. | ||
| */ | ||
| two_table_comparator(table_view const& left, | ||
| table_view const& right, | ||
| rmm::cuda_stream_view stream); | ||
|
|
||
| /** | ||
| * @brief Construct an owning object for performing equality comparisons between two rows from two | ||
| * tables. | ||
| * | ||
| * This constructor allows independently constructing a `preprocessed_table` and sharing it among | ||
| * multiple comparators. | ||
| * | ||
| * @param left The left table preprocessed for equality comparison. | ||
| * @param right The right table preprocessed for equality comparison. | ||
| */ | ||
| two_table_comparator(std::shared_ptr<preprocessed_table> left, | ||
| std::shared_ptr<preprocessed_table> right) | ||
| : d_left_table{std::move(left)}, d_right_table{std::move(right)} | ||
| { | ||
| } | ||
|
|
||
| /** | ||
| * @brief Return the binary operator for comparing rows in the table. | ||
| * | ||
| * Returns a binary callable, `F`, with signatures `bool F(lhs_index_type, rhs_index_type)` and | ||
| * `bool F(rhs_index_type, lhs_index_type)`. | ||
| * | ||
| * `F(lhs_index_type i, rhs_index_type j)` returns true if and only if row `i` of the left table | ||
| * compares equal to row `j` of the right table. | ||
| * | ||
| * Similarly, `F(rhs_index_type i, lhs_index_type j)` returns true if and only if row `i` of the | ||
| * right table compares equal to row `j` of the left table. | ||
| * | ||
| * @tparam Nullate A cudf::nullate type describing whether to check for nulls. | ||
| */ | ||
| template <typename Nullate> | ||
| auto device_comparator(Nullate nullate = {}, | ||
| null_equality nulls_are_equal = null_equality::EQUAL) const | ||
| { | ||
| return strong_index_comparator_adapter<device_row_comparator<Nullate>>{ | ||
| device_row_comparator<Nullate>(nullate, *d_left_table, *d_right_table, nulls_are_equal)}; | ||
|
Comment on lines
+1011
to
+1012
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fyi, CTAD is your friend. |
||
| } | ||
|
|
||
| private: | ||
| std::shared_ptr<preprocessed_table> d_left_table; | ||
| std::shared_ptr<preprocessed_table> d_right_table; | ||
| }; | ||
|
|
||
| } // namespace equality | ||
|
|
||
| namespace hash { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,9 +14,9 @@ | |
| * limitations under the License. | ||
| */ | ||
|
|
||
| #include <cudf/detail/structs/utilities.hpp> | ||
| #include <cudf/scalar/scalar_device_view.cuh> | ||
| #include <cudf/structs/detail/contains.hpp> | ||
| #include <cudf/table/experimental/row_operators.cuh> | ||
| #include <cudf/table/row_operators.cuh> | ||
| #include <cudf/table/table_device_view.cuh> | ||
| #include <cudf/table/table_view.hpp> | ||
|
|
@@ -35,52 +35,32 @@ bool contains(structs_column_view const& haystack, | |
| scalar const& needle, | ||
| rmm::cuda_stream_view stream) | ||
| { | ||
| CUDF_EXPECTS(haystack.type() == needle.type(), "scalar and column types must match"); | ||
| auto const haystack_tv = table_view{{haystack}}; | ||
| // Create a (structs) column_view of one row having children given from the input scalar. | ||
| auto const needle_tv = static_cast<struct_scalar const*>(&needle)->view(); | ||
| auto const needle_as_col = | ||
| column_view(data_type{type_id::STRUCT}, | ||
| 1, | ||
| nullptr, | ||
| nullptr, | ||
| 0, | ||
| 0, | ||
| std::vector<column_view>{needle_tv.begin(), needle_tv.end()}); | ||
|
Comment on lines
+39
to
+48
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should probably add overloads of the comparators that take a scalar for the lhs or rhs that do this automaticaly.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm glad you mentioned this. I was thinking that would be a helpful next step as well, but I hadn't yet spent the time to identify how many times that pattern occurs.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This would be an interesting follow up PR.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Filed an issue: #10892
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also filed a related FEA issue: #10893 |
||
|
|
||
| auto const scalar_table = static_cast<struct_scalar const*>(&needle)->view(); | ||
| CUDF_EXPECTS(haystack.num_children() == scalar_table.num_columns(), | ||
| "struct scalar and structs column must have the same number of children"); | ||
| for (size_type i = 0; i < haystack.num_children(); ++i) { | ||
| CUDF_EXPECTS(haystack.child(i).type() == scalar_table.column(i).type(), | ||
| "scalar and column children types must match"); | ||
| } | ||
| // Haystack and needle compatibility is checked by the table comparator constructor. | ||
| auto const comparator = cudf::experimental::row::equality::two_table_comparator( | ||
| haystack_tv, table_view{{needle_as_col}}, stream); | ||
| auto const has_nulls = has_nested_nulls(haystack_tv) || has_nested_nulls(needle_tv); | ||
| auto const d_comp = comparator.device_comparator(nullate::DYNAMIC{has_nulls}); | ||
|
|
||
| // Prepare to flatten the structs column and scalar. | ||
| auto const has_null_elements = has_nested_nulls(table_view{std::vector<column_view>{ | ||
| haystack.child_begin(), haystack.child_end()}}) || | ||
| has_nested_nulls(scalar_table); | ||
| auto const flatten_nullability = has_null_elements | ||
| ? structs::detail::column_nullability::FORCE | ||
| : structs::detail::column_nullability::MATCH_INCOMING; | ||
|
|
||
| // Flatten the input structs column, only materialize the bitmask if there is null in the input. | ||
| auto const haystack_flattened = | ||
| structs::detail::flatten_nested_columns(table_view{{haystack}}, {}, {}, flatten_nullability); | ||
| auto const needle_flattened = | ||
| structs::detail::flatten_nested_columns(scalar_table, {}, {}, flatten_nullability); | ||
|
|
||
| // The struct scalar only contains the struct member columns. | ||
| // Thus, if there is any null in the input, we must exclude the first column in the flattened | ||
| // table of the input column from searching because that column is the materialized bitmask of | ||
| // the input structs column. | ||
| auto const haystack_flattened_content = haystack_flattened.flattened_columns(); | ||
| auto const haystack_flattened_children = table_view{std::vector<column_view>{ | ||
| haystack_flattened_content.begin() + static_cast<size_type>(has_null_elements), | ||
| haystack_flattened_content.end()}}; | ||
|
|
||
| auto const d_haystack_children_ptr = | ||
| table_device_view::create(haystack_flattened_children, stream); | ||
| auto const d_needle_ptr = table_device_view::create(needle_flattened, stream); | ||
|
|
||
| auto const start_iter = thrust::make_counting_iterator<size_type>(0); | ||
| auto const start_iter = cudf::experimental::row::lhs_iterator(0); | ||
| auto const end_iter = start_iter + haystack.size(); | ||
| auto const comp = row_equality_comparator(nullate::DYNAMIC{has_null_elements}, | ||
| *d_haystack_children_ptr, | ||
| *d_needle_ptr, | ||
| null_equality::EQUAL); | ||
| using cudf::experimental::row::rhs_index_type; | ||
|
|
||
| auto const found_iter = thrust::find_if( | ||
| rmm::exec_policy(stream), start_iter, end_iter, [comp] __device__(auto const idx) { | ||
| return comp(idx, 0); // compare haystack[idx] == val[0]. | ||
| rmm::exec_policy(stream), start_iter, end_iter, [d_comp] __device__(auto const idx) { | ||
| // Compare haystack[idx] == needle_as_col[0]. | ||
| return d_comp(idx, static_cast<rhs_index_type>(0)); | ||
| }); | ||
|
|
||
| return found_iter != end_iter; | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.