Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
dd01ba1
Support contains() on lists of primitives
mythrocks Dec 17, 2020
4558261
Support contains() on lists of primitives
mythrocks Dec 18, 2020
a87ffb0
Support contains() on lists of primitives
mythrocks Dec 18, 2020
21aae8a
Support contains() on lists of primitives
mythrocks Dec 18, 2020
4e5b819
Support contains() on lists of primitives:
mythrocks Dec 19, 2020
752379a
Support contains() on lists of primitives:
mythrocks Dec 19, 2020
551c014
Support contains() on lists of primitives:
mythrocks Dec 19, 2020
36ffde8
Support contains() on lists of primitives:
mythrocks Dec 19, 2020
2e7d725
Merge remote-tracking branch 'origin/branch-0.18' into list-contains
mythrocks Dec 19, 2020
80510cf
Merge remote-tracking branch 'origin/branch-0.18' into list-contains
mythrocks Jan 4, 2021
f9bb045
Merge remote-tracking branch 'origin/branch-0.18' into list-contains
mythrocks Jan 5, 2021
f880b46
Support contains() on lists of primitives:
mythrocks Jan 5, 2021
ce92cc5
Support contains() on lists of primitives
mythrocks Jan 5, 2021
9c92bf8
Support contains() on lists of primitives
mythrocks Jan 6, 2021
542e39a
Support contains() on lists of primitives:
mythrocks Jan 6, 2021
0a60a12
Support contains() on lists of primitives
mythrocks Jan 7, 2021
63d1a4a
Chrono support in lists::contains()
mythrocks Jan 7, 2021
ed9269f
Added validity iterator for scalars.
mythrocks Jan 7, 2021
2d469b7
Collapsed construct_null_mask() single function
mythrocks Jan 7, 2021
a398bb6
Merge remote-tracking branch 'origin/branch-0.18' into list-contains
mythrocks Jan 7, 2021
6ebed08
Collapsed contains() implementation to one impl function
mythrocks Jan 8, 2021
e01e860
Renamed search_key_has_all_nulls
mythrocks Jan 13, 2021
21d06e9
Merge remote-tracking branch 'origin/branch-0.18' into list-contains
mythrocks Jan 14, 2021
792bd86
Fixed behaviour for lists containing nulls:
mythrocks Jan 15, 2021
bb800a7
Cleaned up SFINAE is_supported().
mythrocks Jan 17, 2021
1210f83
Remove names for unused function/template parameters.
mythrocks Jan 19, 2021
c77d1a1
Code formatting.
mythrocks Jan 19, 2021
21bd39d
Fixed documentation for make_validity_iterator(scalar)
mythrocks Jan 19, 2021
793863a
Added Doxygen directives.
mythrocks Jan 21, 2021
b0ad781
Merge remote-tracking branch 'origin/branch-0.18' into list-contains
mythrocks Jan 21, 2021
4225a62
Move namespace directives to file level.
mythrocks Jan 25, 2021
f7d4bad
Merge remote-tracking branch 'origin/branch-0.18' into list-contains
mythrocks Jan 25, 2021
d6ddee0
Fix doxygen_groups.h lists_contains group
mythrocks Jan 25, 2021
b3b5b6c
Move tests to namespace.
mythrocks Jan 25, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- PR #6907 Add `replace_null` API with `replace_policy` parameter, `fixed_width` column support

- PR #6775 Implement cudf.DateOffset for months
- PR #7039 Support contains() on lists of primitives

## Improvements

Expand Down
1 change: 1 addition & 0 deletions conda/recipes/libcudf/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ test:
- test -f $PREFIX/include/cudf/lists/detail/concatenate.hpp
- test -f $PREFIX/include/cudf/lists/detail/copying.hpp
- test -f $PREFIX/include/cudf/lists/extract.hpp
- test -f $PREFIX/include/cudf/lists/contains.hpp
- test -f $PREFIX/include/cudf/lists/lists_column_view.hpp
- test -f $PREFIX/include/cudf/merge.hpp
- test -f $PREFIX/include/cudf/null_mask.hpp
Expand Down
53 changes: 53 additions & 0 deletions cpp/include/cudf/lists/contains.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
Comment thread
mythrocks marked this conversation as resolved.
Outdated
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cudf/column/column.hpp>
#include <cudf/lists/lists_column_view.hpp>

namespace cudf {
namespace lists {

Comment thread
mythrocks marked this conversation as resolved.
/**
* @brief Create a column of bool values indicating whether the specified scalar
* is an element of each row of a list column.
*
* The output column has as many elements as the input `lists` column.
* Output `column[i]` is set to true if the lists row `lists[i]` contains the value
* specified in skey. Otherwise, it is set to false.
Comment thread
mythrocks marked this conversation as resolved.
Outdated
*
* Output `column[i]` is set to null if even one of the following holds true:
Comment thread
mythrocks marked this conversation as resolved.
Outdated
* 1. The search key `skey` is null
* 2. The list row `lists[i]` is null
* 3. The list row `lists[i]` contains even *one* null
Comment thread
mythrocks marked this conversation as resolved.
Outdated
*
* @param lists Lists column whose `n` rows are to be searched
* @param skey The scalar key to be looked up in each list row
* @param mr Device memory resource used to allocate the returned column's device memory.
* @return std::unique_ptr<column> BOOL8 column of `n` rows with the result of the lookup
*/
std::unique_ptr<column> contains(
cudf::lists_column_view const& lists,
cudf::scalar const& skey,
Comment thread
mythrocks marked this conversation as resolved.
Outdated
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
 * @brief Overload of `contains()` that takes a column of search keys instead of a
 * single scalar key.
 *
 * NOTE(review): presumably output row `i` reports whether `lists[i]` contains
 * `skeys[i]`, with the same null rules as the scalar overload — confirm. No definition
 * of this overload is visible next to the scalar implementation; verify that it is
 * defined before callers depend on it.
 *
 * @param lists Lists column whose rows are to be searched
 * @param skeys Column of search keys
 * @param mr Device memory resource; defaults to the current device resource
 * @return std::unique_ptr<column> Result column (TODO confirm: expected to be BOOL8,
 * one row per row of `lists`)
 */
std::unique_ptr<column> contains(
cudf::lists_column_view const& lists,
cudf::column_view const& skeys,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

} // namespace lists
} // namespace cudf
160 changes: 160 additions & 0 deletions cpp/src/lists/contains.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION.
Comment thread
mythrocks marked this conversation as resolved.
Outdated
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <thrust/logical.h>
#include <cudf/column/column_factories.hpp>
#include <cudf/detail/valid_if.cuh>
#include <cudf/lists/contains.hpp>
#include <cudf/lists/list_device_view.cuh>
#include <cudf/lists/lists_column_device_view.cuh>
#include <cudf/lists/lists_column_view.hpp>
#include <cudf/scalar/scalar.hpp>
#include <cudf/scalar/scalar_device_view.cuh>
#include <cudf/utilities/type_dispatcher.hpp>
#include <rmm/exec_policy.hpp>
#include <type_traits>

namespace cudf {
namespace lists {

namespace {

auto CUDA_HOST_DEVICE_CALLABLE counting_iter(size_type n)
{
return thrust::make_counting_iterator(n);
}
Comment thread
mythrocks marked this conversation as resolved.
Outdated

/**
 * @brief Builds the null mask and null count for the result of `lists::contains()`.
 *
 * When a mask is computed per row, row `i` is marked null if `d_lists[i]` is null or
 * if `d_lists[i]` holds at least one null element. If the search key itself is null,
 * every row is marked null.
 *
 * @param d_lists Device view of the lists column being searched
 * @param skey Scalar search key; only its validity is read here
 * @param input_column_has_nulls Caller-supplied flag saying the input may contain nulls.
 *        When it is false and `skey` is valid, no mask is built at all.
 * @param stream CUDA stream used for the validity read, allocations and kernel launch
 * @param mr Device memory resource used to allocate the returned mask
 * @return Pair of (null mask buffer, null count); the buffer is empty and the count is
 *         zero when no output row can be null
 */
std::pair<rmm::device_buffer, size_type> construct_null_mask(
  cudf::detail::lists_column_device_view const& d_lists,
  cudf::scalar const& skey,
  bool input_column_has_nulls,
  rmm::cuda_stream_view stream,
  rmm::mr::device_memory_resource* mr)
{
  auto const skey_is_valid    = skey.is_valid(stream);
  auto const output_has_nulls = !skey_is_valid || input_column_has_nulls;

  // Nothing can be null: hand back an empty mask.
  if (!output_has_nulls) { return std::make_pair(rmm::device_buffer{0, stream, mr}, size_type{0}); }

  // A null search key makes every output row null.
  // NOTE(review): this call does not forward `stream`; confirm whether
  // create_null_mask() accepts a stream argument in this cudf version.
  if (!skey_is_valid) {
    return std::make_pair(cudf::create_null_mask(d_lists.size(), mask_state::ALL_NULL, mr),
                          d_lists.size());
  }

  // A row is valid only if the list itself is non-null and none of its elements are null.
  // `stream` and `mr` are forwarded so the mask is produced on the caller's stream and
  // allocated from the caller's memory resource rather than the defaults.
  return cudf::detail::valid_if(
    counting_iter(0),
    counting_iter(d_lists.size()),
    [d_lists] __device__(auto const& row_index) {
      auto list = cudf::list_device_view(d_lists, row_index);
      if (list.is_null()) { return false; }
      return thrust::none_of(thrust::seq,
                             counting_iter(0),
                             counting_iter(list.size()),
                             [&list] __device__(auto const& i) { return list.is_null(i); });
    },
    stream,
    mr);
}

struct lookup_functor {
template <typename T, typename... Args>
std::enable_if_t<!cudf::is_numeric<T>() && !std::is_same<T, cudf::string_view>::value, void>
operator()(Args&&...) const
{
CUDF_FAIL("lists::contains() is only supported on numeric types and strings.");
}

template <typename T>
std::enable_if_t<cudf::is_numeric<T>() || std::is_same<T, cudf::string_view>::value, void>
operator()(cudf::detail::lists_column_device_view const& d_lists,
cudf::scalar const& skey,
cudf::mutable_column_device_view output_bools,
Comment thread
mythrocks marked this conversation as resolved.
Outdated
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr) const
{
assert(skey.is_valid() && "skey should have been checked for nulls by this point.");

auto h_scalar = static_cast<cudf::scalar_type_t<T> const&>(skey);
auto d_scalar = cudf::get_scalar_device_view(h_scalar);

thrust::transform(rmm::exec_policy(stream),
counting_iter(0),
counting_iter(d_lists.size()),
output_bools.data<bool>(),
[d_lists, d_scalar] __device__(auto row_index) {
auto list = cudf::list_device_view(d_lists, row_index);
if (list.is_null()) { return false; }
for (size_type i{0}; i < list.size(); ++i) {
Comment thread
mythrocks marked this conversation as resolved.
Outdated
if (list.is_null(i)) { return false; }
auto list_element = list.template element<T>(i);
if (list_element == d_scalar.value()) { return true; }
}
return false;
});
}
};

} // namespace

namespace detail {

/**
 * @brief Stream-aware implementation of `lists::contains()` for a scalar search key.
 *
 * Produces a BOOL8 column with one row per row of `lists`. The null mask is computed
 * first; the per-row lookup then runs only when the search key is valid (otherwise
 * every row is null and the data values are irrelevant).
 *
 * @throws cudf::logic_error if the child of `lists` is a nested type, if the type of
 * `skey` differs from the list element type, or if that type is EMPTY.
 *
 * @param lists Lists column whose rows are to be searched
 * @param skey Scalar key to look up in each list row
 * @param stream CUDA stream used for device memory operations and kernel launches
 * @param mr Device memory resource used to allocate the returned column's device memory
 * @return BOOL8 column holding the result of the lookup
 */
std::unique_ptr<column> contains(cudf::lists_column_view const& lists,
                                 cudf::scalar const& skey,
                                 rmm::cuda_stream_view stream,
                                 rmm::mr::device_memory_resource* mr)
{
  using namespace cudf;
  using namespace cudf::detail;

  CUDF_EXPECTS(!cudf::is_nested(lists.child().type()),
               "Nested types not supported in lists::contains()");
  CUDF_EXPECTS(lists.child().type().id() == skey.type().id(),
               "Type of search key does not match list column element type.");
  CUDF_EXPECTS(skey.type().id() != type_id::EMPTY, "Type cannot be empty.");

  auto const device_view = column_device_view::create(lists.parent(), stream);
  auto const d_lists     = lists_column_device_view(*device_view);

  // A row must come out null when the list row is null OR when it holds a null element.
  // `lists.has_nulls()` only reflects null rows, so the child column's nulls have to be
  // checked as well; otherwise no mask is built for lists that merely contain nulls.
  auto const input_has_nulls = lists.has_nulls() || lists.child().has_nulls();

  rmm::device_buffer null_mask;
  size_type num_nulls{0};

  std::tie(null_mask, num_nulls) =
    construct_null_mask(d_lists, skey, input_has_nulls, stream, mr);

  auto ret_bools = make_fixed_width_column(
    data_type{type_id::BOOL8}, lists.size(), std::move(null_mask), num_nulls, stream, mr);

  auto ret_bools_mutable_device_view =
    mutable_column_device_view::create(ret_bools->mutable_view(), stream);

  // Read the key's validity on the same stream as the rest of the work.
  if (skey.is_valid(stream)) {
    cudf::type_dispatcher(
      skey.type(), lookup_functor{}, d_lists, skey, *ret_bools_mutable_device_view, stream, mr);
  }

  return ret_bools;
}

} // namespace detail

/**
 * @brief Public entry point: forwards to `detail::contains()` on the default CUDA stream.
 *
 * @param lists Lists column whose rows are to be searched
 * @param skey Scalar key to look up in each list row
 * @param mr Device memory resource used to allocate the returned column's device memory
 * @return Column produced by `detail::contains()`
 */
std::unique_ptr<column> contains(cudf::lists_column_view const& lists,
                                 cudf::scalar const& skey,
                                 rmm::mr::device_memory_resource* mr)
{
  auto const default_stream = rmm::cuda_stream_default;
  return detail::contains(lists, skey, default_stream, mr);
}

} // namespace lists
} // namespace cudf
3 changes: 2 additions & 1 deletion cpp/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,8 @@ ConfigureTest(AST_TEST "${AST_TEST_SRC}")
# - lists tests ----------------------------------------------------------------------------------

set(LISTS_TEST_SRC
"${CMAKE_CURRENT_SOURCE_DIR}/lists/extract_tests.cpp")
"${CMAKE_CURRENT_SOURCE_DIR}/lists/extract_tests.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/lists/contains_tests.cpp")

ConfigureTest(LISTS_TEST "${LISTS_TEST_SRC}")

Expand Down
Loading