diff --git a/java/src/main/java/ai/rapids/cudf/ColumnView.java b/java/src/main/java/ai/rapids/cudf/ColumnView.java index 7ae4540bafe..8bc764a078e 100644 --- a/java/src/main/java/ai/rapids/cudf/ColumnView.java +++ b/java/src/main/java/ai/rapids/cudf/ColumnView.java @@ -3260,12 +3260,11 @@ public final Table extractRe(String pattern) throws CudfException { } /** - * Extracts all strings that match the given regular expression and corresponds to the + * Extracts all strings that match the given regular expression and corresponds to the * regular expression group index. Any null inputs also result in null output entries. - * + * * For supported regex patterns refer to: * @link https://docs.rapids.ai/api/libcudf/nightly/md_regex.html - * @param pattern The regex pattern * @param idx The regex group index * @return A new column vector of extracted matches @@ -3313,7 +3312,7 @@ public final ColumnVector urlEncode() throws CudfException { } private static void assertIsSupportedMapKeyType(DType keyType) { - boolean isSupportedKeyType = + boolean isSupportedKeyType = !keyType.equals(DType.EMPTY) && !keyType.equals(DType.LIST) && !keyType.equals(DType.STRUCT); assert isSupportedKeyType : "Map lookup by STRUCT and LIST keys is not supported."; } @@ -3331,7 +3330,7 @@ public final ColumnVector getMapValue(ColumnView keys) { return new ColumnVector(mapLookupForKeys(getNativeView(), keys.getNativeView())); } - /** + /** * Given a column of type List> and a key of type X, return a column of type Y, * where each row in the output column is the Y value corresponding to the X key. * If the key is not found, the corresponding output value is null. @@ -3542,6 +3541,88 @@ public final ColumnVector listSortRows(boolean isDescending, boolean isNullSmall return new ColumnVector(listSortRows(getNativeView(), isDescending, isNullSmallest)); } + /** + * For each pair of lists from the input lists columns, check if they have any common non-null + * elements. 
+ * + * A null input row in any of the input columns will result in a null output row. During checking + * for common elements, nulls within each list are considered as different values while + * floating-point NaN values are considered as equal. + * + * The input lists columns must have the same size and same data type. + * + * @param lhs The input lists column for one side + * @param rhs The input lists column for the other side + * @return A column of type BOOL8 containing the check result + */ + public static ColumnVector listsHaveOverlap(ColumnView lhs, ColumnView rhs) { + assert lhs.getType().equals(DType.LIST) && rhs.getType().equals(DType.LIST) : + "Input columns type must be of type LIST"; + assert lhs.getRowCount() == rhs.getRowCount() : "Input columns must have the same size"; + return new ColumnVector(listsHaveOverlap(lhs.getNativeView(), rhs.getNativeView())); + } + + /** + * Find the intersection without duplicates between lists at each row of the given lists columns. + * + * A null input row in any of the input lists columns will result in a null output row. During + * finding list intersection, nulls and floating-point NaN values within each list are + * considered as equal values. + * + * The input lists columns must have the same size and same data type. + * + * @param lhs The input lists column for one side + * @param rhs The input lists column for the other side + * @return A lists column containing the intersection result + */ + public static ColumnVector listsIntersectDistinct(ColumnView lhs, ColumnView rhs) { + assert lhs.getType().equals(DType.LIST) && rhs.getType().equals(DType.LIST) : + "Input columns type must be of type LIST"; + assert lhs.getRowCount() == rhs.getRowCount() : "Input columns must have the same size"; + return new ColumnVector(listsIntersectDistinct(lhs.getNativeView(), rhs.getNativeView())); + } + + /** + * Find the union without duplicates between lists at each row of the given lists columns. 
+ * + * A null input row in any of the input lists columns will result in a null output row. During + * finding list union, nulls and floating-point NaN values within each list are considered as + * equal values. + * + * The input lists columns must have the same size and same data type. + * + * @param lhs The input lists column for one side + * @param rhs The input lists column for the other side + * @return A lists column containing the union result + */ + public static ColumnVector listsUnionDistinct(ColumnView lhs, ColumnView rhs) { + assert lhs.getType().equals(DType.LIST) && rhs.getType().equals(DType.LIST) : + "Input columns type must be of type LIST"; + assert lhs.getRowCount() == rhs.getRowCount() : "Input columns must have the same size"; + return new ColumnVector(listsUnionDistinct(lhs.getNativeView(), rhs.getNativeView())); + } + + /** + * Find the difference of lists of the left column against lists of the right column. + * Specifically, find the elements (without duplicates) from each list of the left column that + * do not exist in the corresponding list of the right column. + * + * A null input row in any of the input lists columns will result in a null output row. During + * finding, nulls and floating-point NaN values within each list are considered as equal values. + * + * The input lists columns must have the same size and same data type. 
+ * + * @param lhs The input lists column for one side + * @param rhs The input lists column for the other side + * @return A lists column containing the difference result + */ + public static ColumnVector listsDifferenceDistinct(ColumnView lhs, ColumnView rhs) { + assert lhs.getType().equals(DType.LIST) && rhs.getType().equals(DType.LIST) : + "Input columns type must be of type LIST"; + assert lhs.getRowCount() == rhs.getRowCount() : "Input columns must have the same size"; + return new ColumnVector(listsDifferenceDistinct(lhs.getNativeView(), rhs.getNativeView())); + } + /** * Generate list offsets from sizes of each list. * NOTICE: This API only works for INT32. Otherwise, the behavior is undefined. And no null and negative value is allowed. @@ -4089,6 +4170,14 @@ private static native long stringReplaceWithBackrefs(long columnView, String pat private static native long listSortRows(long nativeView, boolean isDescending, boolean isNullSmallest); + private static native long listsHaveOverlap(long lhsViewHandle, long rhsViewHandle); + + private static native long listsIntersectDistinct(long lhsViewHandle, long rhsViewHandle); + + private static native long listsUnionDistinct(long lhsViewHandle, long rhsViewHandle); + + private static native long listsDifferenceDistinct(long lhsViewHandle, long rhsViewHandle); + private static native long getElement(long nativeView, int index); private static native long castTo(long nativeHandle, int type, int scale); diff --git a/java/src/main/native/src/ColumnViewJni.cpp b/java/src/main/native/src/ColumnViewJni.cpp index 9cf1e74d84d..f8f7c79ddf0 100644 --- a/java/src/main/native/src/ColumnViewJni.cpp +++ b/java/src/main/native/src/ColumnViewJni.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -595,6 +596,72 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_generateListOffsets(JNIEn CATCH_STD(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_listsHaveOverlap(JNIEnv 
*env, jclass, + jlong lhs_handle, + jlong rhs_handle) { + JNI_NULL_CHECK(env, lhs_handle, "lhs_handle is null", 0) + JNI_NULL_CHECK(env, rhs_handle, "rhs_handle is null", 0) + try { + cudf::jni::auto_set_device(env); + auto const lhs = reinterpret_cast(lhs_handle); + auto const rhs = reinterpret_cast(rhs_handle); + auto overlap_result = + cudf::lists::have_overlap(cudf::lists_column_view{*lhs}, cudf::lists_column_view{*rhs}, + cudf::null_equality::UNEQUAL, cudf::nan_equality::ALL_EQUAL); + cudf::jni::post_process_list_overlap(*lhs, *rhs, overlap_result); + return release_as_jlong(overlap_result); + } + CATCH_STD(env, 0); +} + +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_listsIntersectDistinct(JNIEnv *env, jclass, + jlong lhs_handle, + jlong rhs_handle) { + JNI_NULL_CHECK(env, lhs_handle, "lhs_handle is null", 0) + JNI_NULL_CHECK(env, rhs_handle, "rhs_handle is null", 0) + try { + cudf::jni::auto_set_device(env); + auto const lhs = reinterpret_cast(lhs_handle); + auto const rhs = reinterpret_cast(rhs_handle); + return release_as_jlong(cudf::lists::intersect_distinct( + cudf::lists_column_view{*lhs}, cudf::lists_column_view{*rhs}, cudf::null_equality::EQUAL, + cudf::nan_equality::ALL_EQUAL)); + } + CATCH_STD(env, 0); +} + +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_listsUnionDistinct(JNIEnv *env, jclass, + jlong lhs_handle, + jlong rhs_handle) { + JNI_NULL_CHECK(env, lhs_handle, "lhs_handle is null", 0) + JNI_NULL_CHECK(env, rhs_handle, "rhs_handle is null", 0) + try { + cudf::jni::auto_set_device(env); + auto const lhs = reinterpret_cast(lhs_handle); + auto const rhs = reinterpret_cast(rhs_handle); + return release_as_jlong( + cudf::lists::union_distinct(cudf::lists_column_view{*lhs}, cudf::lists_column_view{*rhs}, + cudf::null_equality::EQUAL, cudf::nan_equality::ALL_EQUAL)); + } + CATCH_STD(env, 0); +} + +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_ColumnView_listsDifferenceDistinct(JNIEnv *env, jclass, + jlong lhs_handle, + jlong 
rhs_handle) { + JNI_NULL_CHECK(env, lhs_handle, "lhs_handle is null", 0) + JNI_NULL_CHECK(env, rhs_handle, "rhs_handle is null", 0) + try { + cudf::jni::auto_set_device(env); + auto const lhs = reinterpret_cast(lhs_handle); + auto const rhs = reinterpret_cast(rhs_handle); + return release_as_jlong(cudf::lists::difference_distinct( + cudf::lists_column_view{*lhs}, cudf::lists_column_view{*rhs}, cudf::null_equality::EQUAL, + cudf::nan_equality::ALL_EQUAL)); + } + CATCH_STD(env, 0); +} + JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_ColumnView_stringSplit(JNIEnv *env, jclass, jlong input_handle, jstring pattern_obj, diff --git a/java/src/main/native/src/ColumnViewJni.cu b/java/src/main/native/src/ColumnViewJni.cu index aa21b508040..a3f9ab5928d 100644 --- a/java/src/main/native/src/ColumnViewJni.cu +++ b/java/src/main/native/src/ColumnViewJni.cu @@ -14,6 +14,8 @@ * limitations under the License. */ +#include + #include #include #include @@ -22,12 +24,17 @@ #include #include #include +#include +#include #include #include #include #include #include +#include +#include #include +#include #include "ColumnViewJni.hpp" @@ -81,6 +88,89 @@ std::unique_ptr generate_list_offsets(cudf::column_view const &lis return offsets_column; } +namespace { + +/** + * @brief Check if the input list has any null elements. + * + * @param list The input list. + * @return The boolean result indicating if the input list has null elements. + */ +__device__ bool list_has_nulls(list_device_view list) { + return thrust::any_of(thrust::seq, thrust::make_counting_iterator(0), + thrust::make_counting_iterator(list.size()), + [&list](auto const idx) { return list.is_null(idx); }); +} + +} // namespace + +void post_process_list_overlap(cudf::column_view const &lhs, cudf::column_view const &rhs, + std::unique_ptr const &overlap_result, + rmm::cuda_stream_view stream) { + // If both of the input columns do not have nulls, we don't need to do anything here. 
+ if (!lists_column_view{lhs}.child().has_nulls() && !lists_column_view{rhs}.child().has_nulls()) { + return; + } + + auto const overlap_cv = overlap_result->view(); + auto const lhs_cdv_ptr = column_device_view::create(lhs, stream); + auto const rhs_cdv_ptr = column_device_view::create(rhs, stream); + auto const overlap_cdv_ptr = column_device_view::create(overlap_cv, stream); + + // Create a new bitmask to satisfy Spark's arrays_overlap's special behavior. + auto validity = rmm::device_uvector(overlap_cv.size(), stream); + thrust::tabulate(rmm::exec_policy(stream), validity.begin(), validity.end(), + [lhs = cudf::detail::lists_column_device_view{*lhs_cdv_ptr}, + rhs = cudf::detail::lists_column_device_view{*rhs_cdv_ptr}, + overlap_result = *overlap_cdv_ptr] __device__(auto const idx) { + if (overlap_result.is_null(idx) || + overlap_result.template element(idx)) { + return true; + } + + // `lhs_list` and `rhs_list` should not be null, otherwise + // `overlap_result[idx]` is null and that has been handled above. + auto const lhs_list = list_device_view{lhs, idx}; + auto const rhs_list = list_device_view{rhs, idx}; + + // Only proceed if both lists are non-empty. + if (lhs_list.size() == 0 || rhs_list.size() == 0) { + return true; + } + + // Only proceed if at least one list has nulls. + if (!list_has_nulls(lhs_list) && !list_has_nulls(rhs_list)) { + return true; + } + + // Here, the input lists satisfy all the conditions below so we output a + // null: + // - Both of the input lists have no non-null common element, and + // - They are both non-empty, and + // - Either of them contains null elements. + return false; + }); + + // Create a new nullmask from the validity data, on the same stream that populated it. + auto [new_null_mask, new_null_count] = + cudf::detail::valid_if(validity.begin(), validity.end(), thrust::identity{}, stream); + + if (new_null_count > 0) { + // If the `overlap_result` column is nullable, perform `bitmask_and` of its nullmask and the + // new nullmask. 
+ if (overlap_cv.nullable()) { + auto [null_mask, null_count] = cudf::detail::bitmask_and( + std::vector{ + overlap_cv.null_mask(), static_cast(new_null_mask.data())}, + std::vector{0, 0}, overlap_cv.size(), stream); + overlap_result->set_null_mask(std::move(null_mask), null_count); + } else { + // Just set the output nullmask as the new nullmask. + overlap_result->set_null_mask(std::move(new_null_mask), new_null_count); + } + } +} + std::unique_ptr lists_distinct_by_key(cudf::lists_column_view const &input, rmm::cuda_stream_view stream) { if (input.is_empty()) { diff --git a/java/src/main/native/src/ColumnViewJni.hpp b/java/src/main/native/src/ColumnViewJni.hpp index 1ad8923d5b3..2cbdb65653e 100644 --- a/java/src/main/native/src/ColumnViewJni.hpp +++ b/java/src/main/native/src/ColumnViewJni.hpp @@ -53,6 +53,28 @@ std::unique_ptr generate_list_offsets(cudf::column_view const &list_length, rmm::cuda_stream_view stream = cudf::default_stream_value); +/** + * @brief Perform a special treatment for the results of `cudf::lists::have_overlap` to produce the + * results that match with Spark's `arrays_overlap`. + * + * The function `arrays_overlap` of Apache Spark has a special behavior that needs to be addressed. + * In particular, the result of checking overlap between two lists will be a null element instead of + * a `false` value (as output by `cudf::lists::have_overlap`) if: + * - Both of the input lists have no non-null common element, and + * - They are both non-empty, and + * - Either of them contains null elements. + * + * This function performs post-processing on the results of `cudf::lists::have_overlap`, adding + * special treatment to produce an output column that matches with the behavior described above. + * + * @param lhs The input lists column for one side. + * @param rhs The input lists column for the other side. + * @param overlap_result The result column generated by checking list overlap in cudf. 
+ */ +void post_process_list_overlap(cudf::column_view const &lhs, cudf::column_view const &rhs, + std::unique_ptr const &overlap_result, + rmm::cuda_stream_view stream = cudf::default_stream_value); + /** * @brief Generates lists column by copying elements that are distinct by key from each input list * row to the corresponding output row. diff --git a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java index 05abe4958e2..5a9671ba311 100644 --- a/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java +++ b/java/src/test/java/ai/rapids/cudf/ColumnVectorTest.java @@ -4870,6 +4870,65 @@ void testListSortRowsWithStringChild() { } } + @Test + void testSetOperations() { + List lhsList1 = Arrays.asList(Double.NaN, 5.0, 0.0, 0.0, 0.0, 0.0, null, 0.0); + List lhsList2 = Arrays.asList(Double.NaN, 5.0, 0.0, 0.0, 0.0, 0.0, null, 1.0); + List lhsList3 = null; + List lhsList4 = Arrays.asList(Double.NaN, 5.0, 0.0, 0.0, 0.0, 0.0, null, 1.0); + + List rhsList1 = Arrays.asList(1.0, 0.5, null, 0.0, 0.0, null, Double.NaN); + List rhsList2 = Arrays.asList(2.0, 1.0, null, 0.0, 0.0, null); + List rhsList3 = Arrays.asList(2.0, 1.0, null, 0.0, 0.0, null); + List rhsList4 = null; + + // Set intersection result: + List expectedIntersectionList1 = Arrays.asList(null, 0.0, Double.NaN); + List expectedIntersectionList2 = Arrays.asList(null, 0.0, 1.0); + + // Set union result: + List expectedUnionList1 = Arrays.asList(null, 0.0, 0.5, 1.0, 5.0, Double.NaN); + List expectedUnionList2 = Arrays.asList(null, 0.0, 1.0, 2.0, 5.0, Double.NaN); + + // Set difference result: + List expectedDifferenceList1 = Arrays.asList(5.0); + List expectedDifferenceList2 = Arrays.asList(5.0, Double.NaN); + + try(ColumnVector lhs = makeListsColumn(DType.FLOAT64, lhsList1, lhsList2, lhsList3, lhsList4); + ColumnVector rhs = makeListsColumn(DType.FLOAT64, rhsList1, rhsList2, rhsList3, rhsList4)) { + + // Test listsHaveOverlap: + try(ColumnVector expected = 
ColumnVector.fromBoxedBooleans(true, true, null, null); + ColumnVector result = ColumnVector.listsHaveOverlap(lhs, rhs)) { + assertColumnsAreEqual(expected, result); + } + + // Test listsIntersectDistinct: + try(ColumnVector expected = makeListsColumn(DType.FLOAT64, expectedIntersectionList1, + expectedIntersectionList2, null, null); + ColumnVector result = ColumnVector.listsIntersectDistinct(lhs, rhs); + ColumnVector resultSorted = result.listSortRows(false, true)) { + assertColumnsAreEqual(expected, resultSorted); + } + + // Test listsUnionDistinct: + try(ColumnVector expected = makeListsColumn(DType.FLOAT64, expectedUnionList1, + expectedUnionList2, null, null); + ColumnVector result = ColumnVector.listsUnionDistinct(lhs, rhs); + ColumnVector resultSorted = result.listSortRows(false, true)) { + assertColumnsAreEqual(expected, resultSorted); + } + + // Test listsDifferenceDistinct: + try(ColumnVector expected = makeListsColumn(DType.FLOAT64, expectedDifferenceList1, + expectedDifferenceList2, null, null); + ColumnVector result = ColumnVector.listsDifferenceDistinct(lhs, rhs); + ColumnVector resultSorted = result.listSortRows(false, true)) { + assertColumnsAreEqual(expected, resultSorted); + } + } + } + @Test void testStringSplit() { String pattern = " ";