diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index ce17231265e49..7567fb9bd1c90 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -3550,6 +3550,25 @@ object ArrayBinaryLike {
   def throwUnionLengthOverflowException(length: Int): Unit = {
     throw QueryExecutionErrors.unionArrayWithElementsExceedLimitError(length)
   }
+
+  /**
+   * Returns true when `value` is a Double or Float holding NaN; any other value,
+   * including null, yields false.
+   */
+  def isNanElement(value: Any): Boolean = value match {
+    case d: Double => d.isNaN
+    case f: Float => f.isNaN
+    case _ => false
+  }
+
+  /**
+   * Returns true when `dataType` is DoubleType or FloatType, the element types for
+   * which ArrayUnion applies its NaN handling.
+   */
+  def containsNanElement(dataType: DataType): Boolean = dataType match {
+    case DoubleType | FloatType => true
+    case _ => false
+  }
 }
 
 
@@ -3577,6 +3596,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
         val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
         val hs = new OpenHashSet[Any]
         var foundNullElement = false
+        var foundNaNElement = false
         Seq(array1, array2).foreach { array =>
           var i = 0
           while (i < array.numElements()) {
@@ -3587,7 +3607,13 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
               }
             } else {
               val elem = array.get(i, elementType)
-              if (!hs.contains(elem)) {
+              if (ArrayBinaryLike.isNanElement(elem)) {
+                // NaN is deduplicated with a flag and never goes through the hash set.
+                if (!foundNaNElement) {
+                  arrayBuffer += elem
+                  foundNaNElement = true
+                }
+              } else if (!hs.contains(elem)) {
                 if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
                   ArrayBinaryLike.throwUnionLengthOverflowException(arrayBuffer.size)
                 }
@@ -3651,6 +3677,7 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
       nullSafeCodeGen(ctx, ev, (array1, array2) => {
         val foundNullElement = ctx.freshName("foundNullElement")
         val nullElementIndex = ctx.freshName("nullElementIndex")
+        val foundNanElement = ctx.freshName("foundNanElement")
         val builder = ctx.freshName("builder")
         val array = ctx.freshName("array")
         val arrays = ctx.freshName("arrays")
@@ -3660,6 +3687,8 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
         val hashSet = ctx.freshName("hashSet")
         val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
         val arrayBuilderClass = s"$arrayBuilder$$of$ptName"
+        // Name used by the generated code to call ArrayBinaryLike.isNanElement.
+        val arrayBinaryLike = ArrayBinaryLike.getClass.getName.stripSuffix("$")
 
         def withArrayNullAssignment(body: String) =
           if (dataType.asInstanceOf[ArrayType].containsNull) {
@@ -3679,17 +3708,38 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
             body
           }
 
+        // Float and double arrays may hold NaN: the first NaN is appended and flagged, later
+        // ones are skipped. Other element types run `body` unchanged.
+        def withArrayNanAssignment(body: String) =
+          if (ArrayBinaryLike.containsNanElement(dataType.asInstanceOf[ArrayType].elementType)) {
+            s"""
+               |if ($arrayBinaryLike.isNanElement($value)) {
+               |  if (!$foundNanElement) {
+               |    $foundNanElement = true;
+               |    $size++;
+               |    $builder.$$plus$$eq($value);
+               |  }
+               |} else {
+               |  $body
+               |}
+             """.stripMargin
+          } else {
+            body
+          }
+
         val processArray = withArrayNullAssignment(
-          s"""
-             |$jt $value = ${genGetValue(array, i)};
-             |if (!$hashSet.contains($hsValueCast$value)) {
-             |  if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
-             |    break;
-             |  }
-             |  $hashSet.add$hsPostFix($hsValueCast$value);
-             |  $builder.$$plus$$eq($value);
-             |}
+          s"""$jt $value = ${genGetValue(array, i)};""" +
+          withArrayNanAssignment(
+            s"""
+               |if (!$hashSet.contains($hsValueCast$value)) {
+               |  if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+               |    break;
+               |  }
+               |  $hashSet.add$hsPostFix($hsValueCast$value);
+               |  $builder.$$plus$$eq($value);
+               |}
            """.stripMargin)
+          )
 
         // Only need to track null element index when result array's element is nullable.
         val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) {
@@ -3701,9 +3751,20 @@ case class ArrayUnion(left: Expression, right: Expression) extends ArrayBinaryLi
           ""
         }
 
+        // Only need to track a NaN element when the result array's element type can hold NaN.
+        val declareNanTrackVariables = if (ArrayBinaryLike.containsNanElement(
+            dataType.asInstanceOf[ArrayType].elementType)) {
+          s"""
+             |boolean $foundNanElement = false;
+           """.stripMargin
+        } else {
+          ""
+        }
+
         s"""
            |$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag);
            |$declareNullTrackVariables
+           |$declareNanTrackVariables
            |int $size = 0;
            |$arrayBuilderClass $builder = new $arrayBuilderClass();
            |ArrayData[] $arrays = new ArrayData[]{$array1, $array2};
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 688ee61b88180..f4221a836bea0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -2309,4 +2309,21 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
       }
     }
   }
+
+  test("SPARK-36702: ArrayUnion should handle duplicated Double.NaN and Float.Nan") {
+    checkEvaluation(ArrayUnion(
+      Literal.apply(Array(Double.NaN, Double.NaN)), Literal.apply(Array(1d))),
+      Seq(Double.NaN, 1d))
+    checkEvaluation(ArrayUnion(
+      Literal.create(Seq(Double.NaN, null), ArrayType(DoubleType)),
+      Literal.create(Seq(Double.NaN, null, 1d), ArrayType(DoubleType))),
+      Seq(Double.NaN, null, 1d))
+    checkEvaluation(ArrayUnion(
+      Literal.apply(Array(Float.NaN, Float.NaN)), Literal.apply(Array(1f))),
+      Seq(Float.NaN, 1f))
+    checkEvaluation(ArrayUnion(
+      Literal.create(Seq(Float.NaN, null), ArrayType(FloatType)),
+      Literal.create(Seq(Float.NaN, null, 1f), ArrayType(FloatType))),
+      Seq(Float.NaN, null, 1f))
+  }
 }