Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -367,11 +367,29 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
}

@transient lazy val set: Set[Any] = child.dataType match {
case _: AtomicType => hset
case t: AtomicType if !t.isInstanceOf[BinaryType] => hset
case _: NullType => hset
case _ =>
val ord = TypeUtils.getInterpretedOrdering(child.dataType)
val ordering = if (hasNull) {
new Ordering[Any] {
override def compare(x: Any, y: Any): Int = {
Copy link
Member

@viirya viirya Nov 29, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

InSet overrides nullSafeEval, and for codegen we look into set only if !ev.isNull, so I think we only need to consider the case the value from set is null.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or simply filter out null from the tree set as @cloud-fan's idea.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! yeah, I'm updating as @cloud-fan's idea.
Also we can use nullSafeCodeGen for codegen path, I'll update it as well.

if (x == null && y == null) {
0
} else if (x == null) {
-1
} else if (y == null) {
1
} else {
ord.compare(x, y)
}
}
}
} else {
ord
}
// for structs use interpreted ordering to be able to compare UnsafeRows with non-UnsafeRows
TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ hset
TreeSet.empty(ordering) ++ hset
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we just filter out nulls when building the tree set?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and udpate eval to

if (value == null) {
  null
} else if (set.contains(value)) {
  true
} else if (hasNull) {
  null
} else {
  false
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually we are using nullSafeEval, so we don't need to update it.
Instead, I'm updating to use nullSafeCodeGen for codegen path.

}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(InSet(nl, nS), null)

val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
LongType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
primitiveTypes.foreach { t =>
val dataGen = RandomDataGenerator.forType(t, nullable = true).get
val inputData = Seq.fill(10) {
Expand All @@ -293,6 +293,54 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}

test("INSET: binary") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Regarding the semantics, InSet is equal to In. Could we combine the test cases? Test both?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good idea! we should test In and InSet together

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I'll do it later. Thanks!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Submitted #23187.

val hS = HashSet[Any]() + Array(1.toByte, 2.toByte) + Array(3.toByte)
val nS = HashSet[Any]() + Array(1.toByte, 2.toByte) + Array(3.toByte) + null
val onetwo = Literal(Array(1.toByte, 2.toByte))
val three = Literal(Array(3.toByte))
val threefour = Literal(Array(3.toByte, 4.toByte))
val nl = Literal(null, onetwo.dataType)
checkEvaluation(InSet(onetwo, hS), true)
checkEvaluation(InSet(three, hS), true)
checkEvaluation(InSet(three, nS), true)
checkEvaluation(InSet(threefour, hS), false)
checkEvaluation(InSet(threefour, nS), null)
checkEvaluation(InSet(nl, hS), null)
checkEvaluation(InSet(nl, nS), null)
}

test("INSET: struct") {
val hS = HashSet[Any]() + Literal.create((1, "a")).value + Literal.create((2, "b")).value
val nS = HashSet[Any]() + Literal.create((1, "a")).value + Literal.create((2, "b")).value + null
val oneA = Literal.create((1, "a"))
val twoB = Literal.create((2, "b"))
val twoC = Literal.create((2, "c"))
val nl = Literal(null, oneA.dataType)
checkEvaluation(InSet(oneA, hS), true)
checkEvaluation(InSet(twoB, hS), true)
checkEvaluation(InSet(twoB, nS), true)
checkEvaluation(InSet(twoC, hS), false)
checkEvaluation(InSet(twoC, nS), null)
checkEvaluation(InSet(nl, hS), null)
checkEvaluation(InSet(nl, nS), null)
}

test("INSET: array") {
val hS = HashSet[Any]() + Literal.create(Seq(1, 2)).value + Literal.create(Seq(3)).value
val nS = HashSet[Any]() + Literal.create(Seq(1, 2)).value + Literal.create(Seq(3)).value + null
val onetwo = Literal.create(Seq(1, 2))
val three = Literal.create(Seq(3))
val threefour = Literal.create(Seq(3, 4))
val nl = Literal(null, onetwo.dataType)
checkEvaluation(InSet(onetwo, hS), true)
checkEvaluation(InSet(three, hS), true)
checkEvaluation(InSet(three, nS), true)
checkEvaluation(InSet(threefour, hS), false)
checkEvaluation(InSet(threefour, nS), null)
checkEvaluation(InSet(nl, hS), null)
checkEvaluation(InSet(nl, nS), null)
}

private case class MyStruct(a: Long, b: String)
private case class MyStruct2(a: MyStruct, b: Array[Int])
private val udt = new ExamplePointUDT
Expand Down