Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ import org.apache.spark.unsafe.UTF8StringBuilder
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
import org.apache.spark.unsafe.types.{ByteArray, CalendarInterval, UTF8String}
import org.apache.spark.util.collection.OpenHashSet

/**
* Base trait for [[BinaryExpression]]s with two arrays of the same element type and implicit
Expand Down Expand Up @@ -4109,32 +4108,38 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
@transient lazy val evalExcept: (ArrayData, ArrayData) => ArrayData = {
if (TypeUtils.typeWithProperEquals(elementType)) {
(array1, array2) =>
val hs = new OpenHashSet[Any]
var notFoundNullElement = true
val hs = new SQLOpenHashSet[Any]
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
val withArray2NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
(value: Any) => hs.add(value),
(valueNaN: Any) => {})
val withArray1NaNCheckFunc = SQLOpenHashSet.withNaNCheckFunc(elementType, hs,
(value: Any) =>
if (!hs.contains(value)) {
arrayBuffer += value
hs.add(value)
},
(valueNaN: Any) => arrayBuffer += valueNaN)
var i = 0
while (i < array2.numElements()) {
if (array2.isNullAt(i)) {
notFoundNullElement = false
hs.addNull()
} else {
val elem = array2.get(i, elementType)
hs.add(elem)
withArray2NaNCheckFunc(elem)
}
i += 1
}
val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
i = 0
while (i < array1.numElements()) {
if (array1.isNullAt(i)) {
if (notFoundNullElement) {
if (!hs.containsNull()) {
arrayBuffer += null
notFoundNullElement = false
hs.addNull()
}
} else {
val elem = array1.get(i, elementType)
if (!hs.contains(elem)) {
arrayBuffer += elem
hs.add(elem)
}
withArray1NaNCheckFunc(elem)
}
i += 1
}
Expand Down Expand Up @@ -4203,10 +4208,9 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
val ptName = CodeGenerator.primitiveTypeName(jt)

nullSafeCodeGen(ctx, ev, (array1, array2) => {
val notFoundNullElement = ctx.freshName("notFoundNullElement")
val nullElementIndex = ctx.freshName("nullElementIndex")
val builder = ctx.freshName("builder")
val openHashSet = classOf[OpenHashSet[_]].getName
val openHashSet = classOf[SQLOpenHashSet[_]].getName
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
val hashSet = ctx.freshName("hashSet")
val arrayBuilder = classOf[mutable.ArrayBuilder[_]].getName
Expand All @@ -4217,7 +4221,7 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
if (left.dataType.asInstanceOf[ArrayType].containsNull) {
s"""
|if ($array2.isNullAt($i)) {
| $notFoundNullElement = false;
| $hashSet.addNull();
|} else {
| $body
|}
Expand All @@ -4235,18 +4239,18 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
}

val writeArray2ToHashSet = withArray2NullCheck(
s"""
|$jt $value = ${genGetValue(array2, i)};
|$hashSet.add$hsPostFix($hsValueCast$value);
""".stripMargin)
s"$jt $value = ${genGetValue(array2, i)};" +
SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet,
s"$hashSet.add$hsPostFix($hsValueCast$value);",
(valueNaN: Any) => ""))

def withArray1NullAssignment(body: String) =
if (left.dataType.asInstanceOf[ArrayType].containsNull) {
s"""
|if ($array1.isNullAt($i)) {
| if ($notFoundNullElement) {
| if (!$hashSet.containsNull()) {
| $hashSet.addNull();
| $nullElementIndex = $size;
| $notFoundNullElement = false;
| $size++;
| $builder.$$plus$$eq($nullValueHolder);
| }
Expand All @@ -4258,22 +4262,29 @@ case class ArrayExcept(left: Expression, right: Expression) extends ArrayBinaryL
body
}

val processArray1 = withArray1NullAssignment(
val body =
s"""
|$jt $value = ${genGetValue(array1, i)};
|if (!$hashSet.contains($hsValueCast$value)) {
| if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
| break;
| }
| $hashSet.add$hsPostFix($hsValueCast$value);
| $builder.$$plus$$eq($value);
|}
""".stripMargin)
""".stripMargin

val processArray1 = withArray1NullAssignment(
s"$jt $value = ${genGetValue(array1, i)};" +
SQLOpenHashSet.withNaNCheckCode(elementType, value, hashSet, body,
(valueNaN: String) =>
s"""
|$size++;
|$builder.$$plus$$eq($valueNaN);
""".stripMargin))

// Only need to track the null element index when array1's elements are nullable.
val declareNullTrackVariables = if (left.dataType.asInstanceOf[ArrayType].containsNull) {
s"""
|boolean $notFoundNullElement = true;
|int $nullElementIndex = -1;
""".stripMargin
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2327,6 +2327,23 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
Seq(Float.NaN, null, 1f))
}

test("SPARK-36753: ArrayExcept should handle duplicated Double.NaN and Float.Nan") {
  // Double arrays without nulls: NaN appears on both sides, so only 1d is expected back.
  val doubleExcept = ArrayExcept(
    Literal.apply(Array(Double.NaN, 1d)),
    Literal.apply(Array(Double.NaN)))
  checkEvaluation(doubleExcept, Seq(1d))

  // Double arrays with nulls: both null and NaN appear on the right side,
  // so the repeated nulls and the NaN on the left are all expected to be dropped.
  val doubleExceptWithNulls = ArrayExcept(
    Literal.create(Seq(null, Double.NaN, null, 1d), ArrayType(DoubleType)),
    Literal.create(Seq(Double.NaN, null), ArrayType(DoubleType)))
  checkEvaluation(doubleExceptWithNulls, Seq(1d))

  // Float arrays without nulls: same expectation as the Double case.
  val floatExcept = ArrayExcept(
    Literal.apply(Array(Float.NaN, 1f)),
    Literal.apply(Array(Float.NaN)))
  checkEvaluation(floatExcept, Seq(1f))

  // Float arrays with nulls: same expectation as the nullable Double case.
  val floatExceptWithNulls = ArrayExcept(
    Literal.create(Seq(null, Float.NaN, null, 1f), ArrayType(FloatType)),
    Literal.create(Seq(Float.NaN, null), ArrayType(FloatType)))
  checkEvaluation(floatExceptWithNulls, Seq(1f))
}

test("SPARK-36754: ArrayIntersect should handle duplicated Double.NaN and Float.Nan") {
checkEvaluation(ArrayIntersect(
Literal.apply(Array(Double.NaN, 1d)), Literal.apply(Array(Double.NaN, 1d, 2d))),
Expand Down