diff --git a/integration_tests/src/main/python/array_test.py b/integration_tests/src/main/python/array_test.py index 9362c2181da..2b7485b42eb 100644 --- a/integration_tests/src/main/python/array_test.py +++ b/integration_tests/src/main/python/array_test.py @@ -99,10 +99,15 @@ def test_orderby_array_of_structs(data_gen): def test_array_contains(data_gen): arr_gen = ArrayGen(data_gen) lit = gen_scalar(data_gen, force_no_nulls=True) - assert_gpu_and_cpu_are_equal_collect(lambda spark: two_col_df( - spark, arr_gen, data_gen).select(array_contains(col('a'), lit.cast(data_gen.data_type)), - array_contains(col('a'), col('b')), - array_contains(col('a'), col('a')[5])), no_nans_conf) + + def get_input(spark): + return two_col_df(spark, arr_gen, data_gen) + + assert_gpu_and_cpu_are_equal_collect(lambda spark: get_input(spark).select( + array_contains(col('a'), lit.cast(data_gen.data_type)), + array_contains(col('a'), col('b')), + array_contains(col('a'), col('a')[5]) + ), no_nans_conf) # Test array_contains() with a literal key that is extracted from the input array of doubles @@ -118,6 +123,7 @@ def main_df(spark): return df.select(array_contains(col('a'), chk_val)) assert_gpu_and_cpu_are_equal_collect(main_df) + @pytest.mark.skipif(is_before_spark_311(), reason="Only in Spark 3.1.1 + ANSI mode, array index throws on out of range indexes") @pytest.mark.parametrize('data_gen', array_gens_sample, ids=idfn) def test_get_array_item_ansi_fail(data_gen): diff --git a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala index 1aa32ce5894..47d90b4b72e 100644 --- a/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala +++ b/sql-plugin/src/main/scala/org/apache/spark/sql/rapids/complexTypeExtractors.scala @@ -223,8 +223,37 @@ case class GpuArrayContains(left: Expression, right: Expression) left.nullable || right.nullable || 
left.dataType.asInstanceOf[ArrayType].containsNull } - override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = - lhs.getBase.listContains(rhs.getBase) + /** + * Helper function to account for `libcudf`'s `listContains()` semantics. + * + * If a list row contains at least one null element, and is found not to contain + * the search key, `libcudf` returns false instead of null. SparkSQL expects to + * return null in those cases. + * + * This method determines the result's validity mask by ORing the output of + * `listContains()` with the NOT of `listContainsNulls()`. + * A result row is thus valid if either the search key is found in the list, + * or if the list does not contain any null elements. + */ + private def orNotContainsNull(containsResult: ColumnVector, + inputListsColumn: ColumnVector): ColumnVector = { + val notContainsNull = withResource(inputListsColumn.listContainsNulls) { + _.not + } + val containsKeyOrNotContainsNull = withResource(notContainsNull) { + containsResult.or(_) + } + withResource(containsKeyOrNotContainsNull) { + containsResult.copyWithBooleanColumnAsValidity(_) + } + } + + override def doColumnar(lhs: GpuColumnVector, rhs: GpuScalar): ColumnVector = { + val inputListsColumn = lhs.getBase + withResource(inputListsColumn.listContains(rhs.getBase)) { + orNotContainsNull(_, inputListsColumn) + } + } override def doColumnar(numRows: Int, lhs: GpuScalar, rhs: GpuScalar): ColumnVector = throw new IllegalStateException("This is not supported yet") @@ -232,8 +261,12 @@ case class GpuArrayContains(left: Expression, right: Expression) override def doColumnar(lhs: GpuScalar, rhs: GpuColumnVector): ColumnVector = throw new IllegalStateException("This is not supported yet") - override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector = - lhs.getBase.listContainsColumn(rhs.getBase) + override def doColumnar(lhs: GpuColumnVector, rhs: GpuColumnVector): ColumnVector = { + val inputListsColumn = lhs.getBase + 
withResource(inputListsColumn.listContainsColumn(rhs.getBase)) { + orNotContainsNull(_, inputListsColumn) + } + } override def prettyName: String = "array_contains" }