Support contains() on lists of primitives #7039
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Commits (34), all by mythrocks:
- dd01ba1 Support contains() on lists of primitives
- 4558261 Support contains() on lists of primitives
- a87ffb0 Support contains() on lists of primitives
- 21aae8a Support contains() on lists of primitives
- 4e5b819 Support contains() on lists of primitives:
- 752379a Support contains() on lists of primitives:
- 551c014 Support contains() on lists of primitives:
- 36ffde8 Support contains() on lists of primitives:
- 2e7d725 Merge remote-tracking branch 'origin/branch-0.18' into list-contains
- 80510cf Merge remote-tracking branch 'origin/branch-0.18' into list-contains
- f9bb045 Merge remote-tracking branch 'origin/branch-0.18' into list-contains
- f880b46 Support contains() on lists of primitives:
- ce92cc5 Support contains() on lists of primitives
- 9c92bf8 Support contains() on lists of primitives
- 542e39a Support contains() on lists of primitives:
- 0a60a12 Support contains() on lists of primitives
- 63d1a4a Chrono support in lists::contains()
- ed9269f Added validity iterator for scalars.
- 2d469b7 Collapsed construct_null_mask() single function
- a398bb6 Merge remote-tracking branch 'origin/branch-0.18' into list-contains
- 6ebed08 Collapsed contains() implementation to one impl function
- e01e860 Renamed search_key_has_all_nulls
- 21d06e9 Merge remote-tracking branch 'origin/branch-0.18' into list-contains
- 792bd86 Fixed behaviour for lists containing nulls:
- bb800a7 Cleaned up SFINAE is_supported().
- 1210f83 Remove names for unused function/template parameters.
- c77d1a1 Code formatting.
- 21bd39d Fixed documentation for make_validity_iterator(scalar)
- 793863a Added Doxygen directives.
- b0ad781 Merge remote-tracking branch 'origin/branch-0.18' into list-contains
- 4225a62 Move namespace directives to file level.
- f7d4bad Merge remote-tracking branch 'origin/branch-0.18' into list-contains
- d6ddee0 Fix doxygen_groups.h lists_contains group
- b3b5b6c Move tests to namespace.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
| /* | ||
| * Copyright (c) 2020, NVIDIA CORPORATION. | ||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| #pragma once | ||
|
|
||
| #include <cudf/column/column.hpp> | ||
| #include <cudf/lists/lists_column_view.hpp> | ||
|
|
||
| namespace cudf { | ||
| namespace lists { | ||
|
|
||
|
|
||
| /** | ||
| * @brief Create a column of bool values indicating whether the specified scalar | ||
| * is an element of each row of a list column. | ||
| * | ||
| * The output column has as many elements as the input `lists` column. | ||
| * Output `column[i]` is set to true if the lists row `lists[i]` contains the value | ||
| * specified in `skey`. Otherwise, it is set to false. | ||
|
|
||
| * | ||
| * Output `column[i]` is set to null if one or more of the following are true: | ||
|
|
||
| * 1. The search key `skey` is null | ||
| * 2. The list row `lists[i]` is null | ||
| * 3. The list row `lists[i]` contains one or more nulls | ||
|
|
||
| * | ||
| * @param lists Lists column whose `n` rows are to be searched | ||
| * @param skey The scalar key to be looked up in each list row | ||
| * @param mr Device memory resource used to allocate the returned column's device memory. | ||
| * @return std::unique_ptr<column> BOOL8 column of `n` rows with the result of the lookup | ||
| */ | ||
| std::unique_ptr<column> contains( | ||
| cudf::lists_column_view const& lists, | ||
| cudf::scalar const& skey, | ||
|
|
||
| rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); | ||
|
|
||
| std::unique_ptr<column> contains( | ||
| cudf::lists_column_view const& lists, | ||
| cudf::column_view const& skeys, | ||
| rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource()); | ||
|
|
||
| } // namespace lists | ||
| } // namespace cudf | ||
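
The null rules in the doc comment above can be modelled on the host. The following is a minimal sketch, not the cuDF implementation: `std::optional` stands in for cuDF nulls, the element type is fixed to `int`, and `contains_row`, `contains_reference` and `demo` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Host-side model only: std::nullopt stands in for a null key, row, or element.
using element  = std::optional<int>;
using list_row = std::optional<std::vector<element>>;

// Result for one row: nullopt (null), true, or false.
std::optional<bool> contains_row(list_row const& row, element const& skey)
{
  if (!skey.has_value()) { return std::nullopt; }  // rule 1: search key is null
  if (!row.has_value()) { return std::nullopt; }   // rule 2: list row is null
  bool found = false;
  for (auto const& e : *row) {
    if (!e.has_value()) { return std::nullopt; }   // rule 3: row contains a null
    if (*e == *skey) { found = true; }
  }
  return found;
}

// One output entry per input row, mirroring the "as many elements as the
// input `lists` column" statement in the doc comment.
std::vector<std::optional<bool>> contains_reference(std::vector<list_row> const& lists,
                                                    element const& skey)
{
  std::vector<std::optional<bool>> out;
  out.reserve(lists.size());
  for (auto const& row : lists) { out.push_back(contains_row(row, skey)); }
  return out;
}

// Small usage check covering each rule.
void demo()
{
  std::vector<list_row> lists{std::vector<element>{1, 2, 3},
                              std::vector<element>{4, 5},
                              std::nullopt,
                              std::vector<element>{2, std::nullopt}};
  auto const out = contains_reference(lists, 2);
  assert(out[0].has_value() && *out[0]);   // key present
  assert(out[1].has_value() && !*out[1]);  // key absent
  assert(!out[2].has_value());             // null row
  assert(!out[3].has_value());             // row contains a null
}
```

In this model a row that contains a null yields null even when the key also appears in it, which is how rule 3 reads as written.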
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,160 @@ | ||
| /* | ||
| * Copyright (c) 2020, NVIDIA CORPORATION. | ||
|
|
||
| * | ||
| * Licensed under the Apache License, Version 2.0 (the "License"); | ||
| * you may not use this file except in compliance with the License. | ||
| * You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| #include <thrust/logical.h> | ||
| #include <cudf/column/column_factories.hpp> | ||
| #include <cudf/detail/valid_if.cuh> | ||
| #include <cudf/lists/contains.hpp> | ||
| #include <cudf/lists/list_device_view.cuh> | ||
| #include <cudf/lists/lists_column_device_view.cuh> | ||
| #include <cudf/lists/lists_column_view.hpp> | ||
| #include <cudf/scalar/scalar.hpp> | ||
| #include <cudf/scalar/scalar_device_view.cuh> | ||
| #include <cudf/utilities/type_dispatcher.hpp> | ||
| #include <rmm/exec_policy.hpp> | ||
| #include <type_traits> | ||
|
|
||
| namespace cudf { | ||
| namespace lists { | ||
|
|
||
| namespace { | ||
|
|
||
| auto CUDA_HOST_DEVICE_CALLABLE counting_iter(size_type n) | ||
| { | ||
| return thrust::make_counting_iterator(n); | ||
| } | ||
|
|
||
|
|
||
| std::pair<rmm::device_buffer, size_type> construct_null_mask( | ||
|
|
||
| cudf::detail::lists_column_device_view const& d_lists, | ||
| cudf::scalar const& skey, | ||
| bool input_column_has_nulls, | ||
| rmm::cuda_stream_view stream, | ||
| rmm::mr::device_memory_resource* mr) | ||
| { | ||
| using namespace cudf; | ||
| using namespace cudf::detail; | ||
|
|
||
| auto const skey_is_valid = skey.is_valid(stream); | ||
| auto const output_has_nulls = !skey_is_valid || input_column_has_nulls; | ||
|
|
||
| if (!output_has_nulls) { return std::make_pair(rmm::device_buffer{0, stream, mr}, size_type{0}); } | ||
|
|
||
|
|
||
| if (!skey_is_valid) { | ||
| return std::make_pair(cudf::create_null_mask(d_lists.size(), mask_state::ALL_NULL, mr), | ||
| d_lists.size()); | ||
| } | ||
|
|
||
| return cudf::detail::valid_if( | ||
| counting_iter(0), counting_iter(d_lists.size()), [d_lists] __device__(auto const& row_index) { | ||
|
|
||
| auto list = cudf::list_device_view(d_lists, row_index); | ||
| if (list.is_null()) { return false; } | ||
| return thrust::none_of(thrust::seq, | ||
| counting_iter(0), | ||
| counting_iter(list.size()), | ||
| [&list] __device__(auto const& i) { return list.is_null(i); }); | ||
| }); | ||
| } | ||
|
|
||
| struct lookup_functor { | ||
| template <typename T, typename... Args> | ||
| std::enable_if_t<!cudf::is_numeric<T>() && !std::is_same<T, cudf::string_view>::value, void> | ||
| operator()(Args&&...) const | ||
| { | ||
| CUDF_FAIL("lists::contains() is only supported on numeric types and strings."); | ||
| } | ||
|
|
||
| template <typename T> | ||
| std::enable_if_t<cudf::is_numeric<T>() || std::is_same<T, cudf::string_view>::value, void> | ||
| operator()(cudf::detail::lists_column_device_view const& d_lists, | ||
| cudf::scalar const& skey, | ||
| cudf::mutable_column_device_view output_bools, | ||
|
|
||
| rmm::cuda_stream_view stream, | ||
| rmm::mr::device_memory_resource* mr) const | ||
| { | ||
| assert(skey.is_valid() && "skey should have been checked for nulls by this point."); | ||
|
|
||
| auto h_scalar = static_cast<cudf::scalar_type_t<T> const&>(skey); | ||
| auto d_scalar = cudf::get_scalar_device_view(h_scalar); | ||
|
|
||
| thrust::transform(rmm::exec_policy(stream), | ||
| counting_iter(0), | ||
| counting_iter(d_lists.size()), | ||
| output_bools.data<bool>(), | ||
| [d_lists, d_scalar] __device__(auto row_index) { | ||
| auto list = cudf::list_device_view(d_lists, row_index); | ||
| if (list.is_null()) { return false; } | ||
| for (size_type i{0}; i < list.size(); ++i) { | ||
|
|
||
| if (list.is_null(i)) { return false; } | ||
| auto list_element = list.template element<T>(i); | ||
| if (list_element == d_scalar.value()) { return true; } | ||
| } | ||
| return false; | ||
| }); | ||
| } | ||
| }; | ||
|
|
||
| } // namespace | ||
|
|
||
| namespace detail { | ||
|
|
||
| std::unique_ptr<column> contains(cudf::lists_column_view const& lists, | ||
|
|
||
| cudf::scalar const& skey, | ||
| rmm::cuda_stream_view stream, | ||
| rmm::mr::device_memory_resource* mr) | ||
| { | ||
| using namespace cudf; | ||
| using namespace cudf::detail; | ||
|
|
||
| CUDF_EXPECTS(!cudf::is_nested(lists.child().type()), | ||
| "Nested types not supported in lists::contains()"); | ||
| CUDF_EXPECTS(lists.child().type().id() == skey.type().id(), | ||
| "Type of search key does not match list column element type."); | ||
| CUDF_EXPECTS(skey.type().id() != type_id::EMPTY, "Type cannot be empty."); | ||
|
|
||
| auto const device_view = column_device_view::create(lists.parent(), stream); | ||
| auto const d_lists = lists_column_device_view(*device_view); | ||
|
|
||
| rmm::device_buffer null_mask; | ||
| size_type num_nulls; | ||
|
|
||
| std::tie(null_mask, num_nulls) = | ||
| construct_null_mask(d_lists, skey, lists.has_nulls(), stream, mr); | ||
|
|
||
| auto ret_bools = make_fixed_width_column( | ||
| data_type{type_id::BOOL8}, lists.size(), std::move(null_mask), num_nulls, stream, mr); | ||
|
|
||
| auto ret_bools_mutable_device_view = | ||
|
|
||
| mutable_column_device_view::create(ret_bools->mutable_view(), stream); | ||
|
|
||
| if (skey.is_valid()) { | ||
| cudf::type_dispatcher( | ||
| skey.type(), lookup_functor{}, d_lists, skey, *ret_bools_mutable_device_view, stream, mr); | ||
| } | ||
|
|
||
| return ret_bools; | ||
| } | ||
|
|
||
| } // namespace detail | ||
|
|
||
| std::unique_ptr<column> contains(cudf::lists_column_view const& lists, | ||
| cudf::scalar const& skey, | ||
| rmm::mr::device_memory_resource* mr) | ||
| { | ||
| return detail::contains(lists, skey, rmm::cuda_stream_default, mr); | ||
|
|
||
| } | ||
|
|
||
| } // namespace lists | ||
| } // namespace cudf | ||
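
The `lookup_functor` above relies on a pair of `std::enable_if_t` overloads, so that unsupported element types fail with a message while supported ones run the search. A host-only sketch of that pattern follows. It assumes `std::is_arithmetic` and `std::string` as stand-ins for `cudf::is_numeric` and `cudf::string_view`. The names `is_supported_v`, `lookup_sketch` and `demo_dispatch` are hypothetical, and the real functor writes into a device column rather than returning a `bool`.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

// Hypothetical trait: stand-in for the numeric-or-string test in the diff.
template <typename T>
constexpr bool is_supported_v =
  std::is_arithmetic<T>::value || std::is_same<T, std::string>::value;

struct lookup_sketch {
  // Chosen when T is unsupported: accept any arguments and fail at run time.
  template <typename T, typename... Args>
  std::enable_if_t<!is_supported_v<T>, bool> operator()(Args&&...) const
  {
    throw std::logic_error("contains() is only supported on numeric types and strings.");
  }

  // Chosen when T is supported: linear search of a single list row.
  template <typename T>
  std::enable_if_t<is_supported_v<T>, bool> operator()(std::vector<T> const& row,
                                                       T const& key) const
  {
    for (auto const& e : row) {
      if (e == key) { return true; }
    }
    return false;
  }
};

// Usage check: T is named explicitly, as a type dispatcher would do.
void demo_dispatch()
{
  lookup_sketch const f{};
  std::vector<int> const row{1, 2, 3};
  bool const hit  = f.operator()<int>(row, 2);
  bool const miss = f.operator()<int>(row, 7);
  assert(hit);
  assert(!miss);

  // A nested element type is unsupported and takes the throwing overload.
  bool threw = false;
  try {
    std::vector<std::vector<int>> const nested{};
    f.operator()<std::vector<int>>(nested, std::vector<int>{});
  } catch (std::logic_error const&) {
    threw = true;
  }
  assert(threw);
}
```

Because exactly one overload survives substitution for any given `T`, the caller never sees an ambiguity, and adding a supported type only requires changing the trait.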