Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -124,34 +124,43 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
(null, false, null) ::
(null, null, null) :: Nil)

test("basic IN predicate test") {
checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1),
private def checkInAndInSet(in: In, expected: Any): Unit = {
  // Check the IN expression itself first.
  checkEvaluation(in, expected)

  // Then check the equivalent InSet expression. Every element of in.list is expected to be a
  // Literal or NonFoldableLiteral; each is evaluated with no arguments to get its raw value.
  val listValues = in.list.map(_.eval())
  val equivalentInSet = InSet(in.value, HashSet() ++ listValues)
  checkEvaluation(equivalentInSet, expected)
}

// Exercises IN on integer and string literals, covering a null probed value, an empty list,
// and lists that contain a null element. Calls made through checkInAndInSet also check the
// equivalent InSet expression.
// NOTE(review): this span contains both checkEvaluation(In(...)) and checkInAndInSet(In(...))
// forms of the same cases — confirm which set of lines is meant to remain.
test("basic IN/INSET predicate test") {
checkInAndInSet(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1),
Literal(2))), null)
checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType),
checkInAndInSet(In(NonFoldableLiteral.create(null, IntegerType),
Seq(NonFoldableLiteral.create(null, IntegerType))), null)
checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty), null)
checkEvaluation(In(Literal(1), Seq.empty), false)
checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType))), null)
checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))),
checkInAndInSet(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty), null)
checkInAndInSet(In(Literal(1), Seq.empty), false)
checkInAndInSet(In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType))), null)
checkInAndInSet(In(Literal(1), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))),
true)
checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))),
checkInAndInSet(In(Literal(2), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))),
null)
checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true)
checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true)
checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false)
checkInAndInSet(In(Literal(1), Seq(Literal(1), Literal(2))), true)
checkInAndInSet(In(Literal(2), Seq(Literal(1), Literal(2))), true)
checkInAndInSet(In(Literal(3), Seq(Literal(1), Literal(2))), false)

// Conjunction of two membership checks, once with In and once with InSet.
checkEvaluation(
And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1),
Literal(2)))),
true)
checkEvaluation(
And(InSet(Literal(1), HashSet(1, 2)), InSet(Literal(2), Set(1, 2))),
true)

// String cases; ns is a non-foldable null literal of StringType.
val ns = NonFoldableLiteral.create(null, StringType)
checkEvaluation(In(ns, Seq(Literal("1"), Literal("2"))), null)
checkEvaluation(In(ns, Seq(ns)), null)
checkEvaluation(In(Literal("a"), Seq(ns)), null)
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"), ns)), true)
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true)
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false)

checkInAndInSet(In(ns, Seq(Literal("1"), Literal("2"))), null)
checkInAndInSet(In(ns, Seq(ns)), null)
checkInAndInSet(In(Literal("a"), Seq(ns)), null)
checkInAndInSet(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"), ns)), true)
checkInAndInSet(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true)
checkInAndInSet(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false)
}

test("IN with different types") {
Expand Down Expand Up @@ -187,11 +196,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
} else {
false
}
checkEvaluation(In(input(0), input.slice(1, 10)), expected)
checkInAndInSet(In(input(0), input.slice(1, 10)), expected)
}

val atomicTypes = DataTypeTestUtils.atomicTypes.filter { t =>
RandomDataGenerator.forType(t).isDefined && !t.isInstanceOf[DecimalType]
RandomDataGenerator.forType(t).isDefined &&
!t.isInstanceOf[DecimalType] && !t.isInstanceOf[BinaryType]
} ++ Seq(DecimalType.USER_DEFAULT)

val atomicArrayTypes = atomicTypes.map(ArrayType(_, containsNull = true))
Expand Down Expand Up @@ -252,93 +262,55 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(ctx.inlinedMutableStates.isEmpty)
}

test("INSET") {
  // Fixed cases: one candidate set without a null element and one with a null element.
  val plainSet = HashSet[Any]() + 1 + 2
  val setWithNull = HashSet[Any]() + 1 + 2 + null
  val lit1 = Literal(1)
  val lit2 = Literal(2)
  val lit3 = Literal(3)
  val nullLit = Literal(null)

  // (probed value, candidate set, expected result), checked in the listed order.
  Seq(
    (lit1, plainSet, true),
    (lit2, plainSet, true),
    (lit2, setWithNull, true),
    (lit3, plainSet, false),
    (lit3, setWithNull, null),
    (nullLit, plainSet, null),
    (nullLit, setWithNull, null)
  ).foreach { case (probe, candidates, expected) =>
    checkEvaluation(InSet(probe, candidates), expected)
  }

  // Randomized cases: for each type, draw ten nullable values; the first is the probed value
  // and the remaining nine form the candidate set.
  val randomizedTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
    LongType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
  for (dataType <- randomizedTypes) {
    val generator = RandomDataGenerator.forType(dataType, nullable = true).get
    // NaN samples are replaced by zero — presumably so the equality-based expected value
    // computed below stays well defined; confirm.
    val samples = Seq.fill(10) {
      generator.apply() match {
        case d: Double if d.isNaN => 0.0d
        case f: Float if f.isNaN => 0.0f
        case other => other
      }
    }
    val literals = samples.map(Literal(_))
    val probeValue = samples(0)
    val candidateValues = samples.slice(1, 10)
    // Null probe -> null; a match -> true; no match but a null candidate -> null; else false.
    val expected: Any = probeValue match {
      case null => null
      case v if candidateValues.contains(v) => true
      case _ if candidateValues.contains(null) => null
      case _ => false
    }
    checkEvaluation(InSet(literals(0), candidateValues.toSet), expected)
  }
}

// Membership checks over binary (byte-array) values. hS is the candidate collection without
// a null element; nS is the same collection plus a null.
// NOTE(review): hS, nS and nl are each declared twice in this span (raw values for InSet,
// expressions for In) — confirm which declarations are meant to remain.
test("INSET: binary") {
val hS = HashSet[Any]() + Array(1.toByte, 2.toByte) + Array(3.toByte)
val nS = HashSet[Any]() + Array(1.toByte, 2.toByte) + Array(3.toByte) + null
test("IN/INSET: binary") {
val onetwo = Literal(Array(1.toByte, 2.toByte))
val three = Literal(Array(3.toByte))
val threefour = Literal(Array(3.toByte, 4.toByte))
val nl = Literal(null, onetwo.dataType)
// Expected: true on a match, false on a miss against hS, null on a miss against nS,
// and null whenever the probed value itself is null.
checkEvaluation(InSet(onetwo, hS), true)
checkEvaluation(InSet(three, hS), true)
checkEvaluation(InSet(three, nS), true)
checkEvaluation(InSet(threefour, hS), false)
checkEvaluation(InSet(threefour, nS), null)
checkEvaluation(InSet(nl, hS), null)
checkEvaluation(InSet(nl, nS), null)
val nl = NonFoldableLiteral.create(null, onetwo.dataType)
val hS = Seq(Literal(Array(1.toByte, 2.toByte)), Literal(Array(3.toByte)))
val nS = Seq(Literal(Array(1.toByte, 2.toByte)), Literal(Array(3.toByte)),
NonFoldableLiteral.create(null, onetwo.dataType))
// Same expectations, checked through checkInAndInSet.
checkInAndInSet(In(onetwo, hS), true)
checkInAndInSet(In(three, hS), true)
checkInAndInSet(In(three, nS), true)
checkInAndInSet(In(threefour, hS), false)
checkInAndInSet(In(threefour, nS), null)
checkInAndInSet(In(nl, hS), null)
checkInAndInSet(In(nl, nS), null)
}

// Membership checks over struct values built from (Int, String) tuples. hS is the candidate
// collection without a null element; nS is the same collection plus a null.
// NOTE(review): hS, nS and nl are each declared twice in this span (raw values for InSet,
// expressions for In) — confirm which declarations are meant to remain.
test("INSET: struct") {
val hS = HashSet[Any]() + Literal.create((1, "a")).value + Literal.create((2, "b")).value
val nS = HashSet[Any]() + Literal.create((1, "a")).value + Literal.create((2, "b")).value + null
test("IN/INSET: struct") {
val oneA = Literal.create((1, "a"))
val twoB = Literal.create((2, "b"))
val twoC = Literal.create((2, "c"))
val nl = Literal(null, oneA.dataType)
// Expected: true on a match, false on a miss against hS, null on a miss against nS,
// and null whenever the probed value itself is null.
checkEvaluation(InSet(oneA, hS), true)
checkEvaluation(InSet(twoB, hS), true)
checkEvaluation(InSet(twoB, nS), true)
checkEvaluation(InSet(twoC, hS), false)
checkEvaluation(InSet(twoC, nS), null)
checkEvaluation(InSet(nl, hS), null)
checkEvaluation(InSet(nl, nS), null)
val nl = NonFoldableLiteral.create(null, oneA.dataType)
val hS = Seq(Literal.create((1, "a")), Literal.create((2, "b")))
val nS = Seq(Literal.create((1, "a")), Literal.create((2, "b")),
NonFoldableLiteral.create(null, oneA.dataType))
// Same expectations, checked through checkInAndInSet.
checkInAndInSet(In(oneA, hS), true)
checkInAndInSet(In(twoB, hS), true)
checkInAndInSet(In(twoB, nS), true)
checkInAndInSet(In(twoC, hS), false)
checkInAndInSet(In(twoC, nS), null)
checkInAndInSet(In(nl, hS), null)
checkInAndInSet(In(nl, nS), null)
}

// Membership checks over array values built from Seq[Int]. hS is the candidate collection
// without a null element; nS is the same collection plus a null.
// NOTE(review): hS, nS and nl are each declared twice in this span (raw values for InSet,
// expressions for In) — confirm which declarations are meant to remain.
test("INSET: array") {
val hS = HashSet[Any]() + Literal.create(Seq(1, 2)).value + Literal.create(Seq(3)).value
val nS = HashSet[Any]() + Literal.create(Seq(1, 2)).value + Literal.create(Seq(3)).value + null
test("IN/INSET: array") {
val onetwo = Literal.create(Seq(1, 2))
val three = Literal.create(Seq(3))
val threefour = Literal.create(Seq(3, 4))
val nl = Literal(null, onetwo.dataType)
// Expected: true on a match, false on a miss against hS, null on a miss against nS,
// and null whenever the probed value itself is null.
checkEvaluation(InSet(onetwo, hS), true)
checkEvaluation(InSet(three, hS), true)
checkEvaluation(InSet(three, nS), true)
checkEvaluation(InSet(threefour, hS), false)
checkEvaluation(InSet(threefour, nS), null)
checkEvaluation(InSet(nl, hS), null)
checkEvaluation(InSet(nl, nS), null)
val nl = NonFoldableLiteral.create(null, onetwo.dataType)
val hS = Seq(Literal.create(Seq(1, 2)), Literal.create(Seq(3)))
val nS = Seq(Literal.create(Seq(1, 2)), Literal.create(Seq(3)),
NonFoldableLiteral.create(null, onetwo.dataType))
// Same expectations, checked through checkInAndInSet.
checkInAndInSet(In(onetwo, hS), true)
checkInAndInSet(In(three, hS), true)
checkInAndInSet(In(three, nS), true)
checkInAndInSet(In(threefour, hS), false)
checkInAndInSet(In(threefour, nS), null)
checkInAndInSet(In(nl, hS), null)
checkInAndInSet(In(nl, nS), null)
}

// Private two-field record: a Long `a` and a String `b`.
private case class MyStruct(a: Long, b: String)
Expand Down