diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 3b60d1d88b3c..0f63717f9daf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -124,34 +124,43 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { (null, false, null) :: (null, null, null) :: Nil) - test("basic IN predicate test") { - checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1), + private def checkInAndInSet(in: In, expected: Any): Unit = { + // expecting all in.list are Literal or NonFoldableLiteral. + checkEvaluation(in, expected) + checkEvaluation(InSet(in.value, HashSet() ++ in.list.map(_.eval())), expected) + } + + test("basic IN/INSET predicate test") { + checkInAndInSet(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1), Literal(2))), null) - checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), + checkInAndInSet(In(NonFoldableLiteral.create(null, IntegerType), Seq(NonFoldableLiteral.create(null, IntegerType))), null) - checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty), null) - checkEvaluation(In(Literal(1), Seq.empty), false) - checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType))), null) - checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), + checkInAndInSet(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty), null) + checkInAndInSet(In(Literal(1), Seq.empty), false) + checkInAndInSet(In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType))), null) + checkInAndInSet(In(Literal(1), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), true) - checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), + checkInAndInSet(In(Literal(2), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))), null) - checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true) - checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true) - checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false) + checkInAndInSet(In(Literal(1), Seq(Literal(1), Literal(2))), true) + checkInAndInSet(In(Literal(2), Seq(Literal(1), Literal(2))), true) + checkInAndInSet(In(Literal(3), Seq(Literal(1), Literal(2))), false) + checkEvaluation( And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))), true) + checkEvaluation( + And(InSet(Literal(1), HashSet(1, 2)), InSet(Literal(2), Set(1, 2))), + true) val ns = NonFoldableLiteral.create(null, StringType) - checkEvaluation(In(ns, Seq(Literal("1"), Literal("2"))), null) - checkEvaluation(In(ns, Seq(ns)), null) - checkEvaluation(In(Literal("a"), Seq(ns)), null) - checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"), ns)), true) - checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true) - checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false) - + checkInAndInSet(In(ns, Seq(Literal("1"), Literal("2"))), null) + checkInAndInSet(In(ns, Seq(ns)), null) + checkInAndInSet(In(Literal("a"), Seq(ns)), null) + checkInAndInSet(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"), ns)), true) + checkInAndInSet(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true) + checkInAndInSet(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false) } test("IN with different types") { @@ -187,11 +196,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { } else { false } - checkEvaluation(In(input(0), input.slice(1, 10)), expected) + checkInAndInSet(In(input(0), input.slice(1, 10)), expected) } val atomicTypes = DataTypeTestUtils.atomicTypes.filter { t => - RandomDataGenerator.forType(t).isDefined && !t.isInstanceOf[DecimalType] + RandomDataGenerator.forType(t).isDefined && + !t.isInstanceOf[DecimalType] && !t.isInstanceOf[BinaryType] } ++ Seq(DecimalType.USER_DEFAULT) val atomicArrayTypes = atomicTypes.map(ArrayType(_, containsNull = true)) @@ -252,93 +262,55 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { assert(ctx.inlinedMutableStates.isEmpty) } - test("INSET") { - val hS = HashSet[Any]() + 1 + 2 - val nS = HashSet[Any]() + 1 + 2 + null - val one = Literal(1) - val two = Literal(2) - val three = Literal(3) - val nl = Literal(null) - checkEvaluation(InSet(one, hS), true) - checkEvaluation(InSet(two, hS), true) - checkEvaluation(InSet(two, nS), true) - checkEvaluation(InSet(three, hS), false) - checkEvaluation(InSet(three, nS), null) - checkEvaluation(InSet(nl, hS), null) - checkEvaluation(InSet(nl, nS), null) - - val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType, - LongType, BooleanType, DecimalType.USER_DEFAULT, TimestampType) - primitiveTypes.foreach { t => - val dataGen = RandomDataGenerator.forType(t, nullable = true).get - val inputData = Seq.fill(10) { - val value = dataGen.apply() - value match { - case d: Double if d.isNaN => 0.0d - case f: Float if f.isNaN => 0.0f - case _ => value - } - } - val input = inputData.map(Literal(_)) - val expected = if (inputData(0) == null) { - null - } else if (inputData.slice(1, 10).contains(inputData(0))) { - true - } else if (inputData.slice(1, 10).contains(null)) { - null - } else { - false - } - checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet), expected) - } - } - - test("INSET: binary") { - val hS = HashSet[Any]() + Array(1.toByte, 2.toByte) + Array(3.toByte) - val nS = HashSet[Any]() + Array(1.toByte, 2.toByte) + Array(3.toByte) + null + test("IN/INSET: binary") { val onetwo = Literal(Array(1.toByte, 2.toByte)) val three = Literal(Array(3.toByte)) val threefour = Literal(Array(3.toByte, 4.toByte)) - val nl = Literal(null, onetwo.dataType) - checkEvaluation(InSet(onetwo, hS), true) - checkEvaluation(InSet(three, hS), true) - checkEvaluation(InSet(three, nS), true) - checkEvaluation(InSet(threefour, hS), false) - checkEvaluation(InSet(threefour, nS), null) - checkEvaluation(InSet(nl, hS), null) - checkEvaluation(InSet(nl, nS), null) + val nl = NonFoldableLiteral.create(null, onetwo.dataType) + val hS = Seq(Literal(Array(1.toByte, 2.toByte)), Literal(Array(3.toByte))) + val nS = Seq(Literal(Array(1.toByte, 2.toByte)), Literal(Array(3.toByte)), + NonFoldableLiteral.create(null, onetwo.dataType)) + checkInAndInSet(In(onetwo, hS), true) + checkInAndInSet(In(three, hS), true) + checkInAndInSet(In(three, nS), true) + checkInAndInSet(In(threefour, hS), false) + checkInAndInSet(In(threefour, nS), null) + checkInAndInSet(In(nl, hS), null) + checkInAndInSet(In(nl, nS), null) } - test("INSET: struct") { - val hS = HashSet[Any]() + Literal.create((1, "a")).value + Literal.create((2, "b")).value - val nS = HashSet[Any]() + Literal.create((1, "a")).value + Literal.create((2, "b")).value + null + test("IN/INSET: struct") { val oneA = Literal.create((1, "a")) val twoB = Literal.create((2, "b")) val twoC = Literal.create((2, "c")) - val nl = Literal(null, oneA.dataType) - checkEvaluation(InSet(oneA, hS), true) - checkEvaluation(InSet(twoB, hS), true) - checkEvaluation(InSet(twoB, nS), true) - checkEvaluation(InSet(twoC, hS), false) - checkEvaluation(InSet(twoC, nS), null) - checkEvaluation(InSet(nl, hS), null) - checkEvaluation(InSet(nl, nS), null) + val nl = NonFoldableLiteral.create(null, oneA.dataType) + val hS = Seq(Literal.create((1, "a")), Literal.create((2, "b"))) + val nS = Seq(Literal.create((1, "a")), Literal.create((2, "b")), + NonFoldableLiteral.create(null, oneA.dataType)) + checkInAndInSet(In(oneA, hS), true) + checkInAndInSet(In(twoB, hS), true) + checkInAndInSet(In(twoB, nS), true) + checkInAndInSet(In(twoC, hS), false) + checkInAndInSet(In(twoC, nS), null) + checkInAndInSet(In(nl, hS), null) + checkInAndInSet(In(nl, nS), null) } - test("INSET: array") { - val hS = HashSet[Any]() + Literal.create(Seq(1, 2)).value + Literal.create(Seq(3)).value - val nS = HashSet[Any]() + Literal.create(Seq(1, 2)).value + Literal.create(Seq(3)).value + null + test("IN/INSET: array") { val onetwo = Literal.create(Seq(1, 2)) val three = Literal.create(Seq(3)) val threefour = Literal.create(Seq(3, 4)) - val nl = Literal(null, onetwo.dataType) - checkEvaluation(InSet(onetwo, hS), true) - checkEvaluation(InSet(three, hS), true) - checkEvaluation(InSet(three, nS), true) - checkEvaluation(InSet(threefour, hS), false) - checkEvaluation(InSet(threefour, nS), null) - checkEvaluation(InSet(nl, hS), null) - checkEvaluation(InSet(nl, nS), null) + val nl = NonFoldableLiteral.create(null, onetwo.dataType) + val hS = Seq(Literal.create(Seq(1, 2)), Literal.create(Seq(3))) + val nS = Seq(Literal.create(Seq(1, 2)), Literal.create(Seq(3)), + NonFoldableLiteral.create(null, onetwo.dataType)) + checkInAndInSet(In(onetwo, hS), true) + checkInAndInSet(In(three, hS), true) + checkInAndInSet(In(three, nS), true) + checkInAndInSet(In(threefour, hS), false) + checkInAndInSet(In(threefour, nS), null) + checkInAndInSet(In(nl, hS), null) + checkInAndInSet(In(nl, nS), null) } private case class MyStruct(a: Long, b: String)