Skip to content

Commit 9bbaa3b

Browse files
committed
address comments
1 parent 3f88e2a commit 9bbaa3b

File tree

2 files changed

+46
-37
lines changed

2 files changed

+46
-37
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala

Lines changed: 42 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,22 @@ trait UnaryHigherOrderFunction extends HigherOrderFunction with ExpectsInputType
141141
def expectingFunctionType: AbstractDataType = AnyDataType
142142

143143
@transient lazy val functionForEval: Expression = functionsForEval.head
144+
145+
/**
146+
* Called by [[eval]]. If a subclass keeps the default nullability, it can override this method
147+
* in order to save null-check code.
148+
*/
149+
protected def nullSafeEval(inputRow: InternalRow, input: Any): Any =
150+
sys.error(s"UnaryHigherOrderFunction must override either eval or nullSafeEval")
151+
152+
override def eval(inputRow: InternalRow): Any = {
153+
val value = input.eval(inputRow)
154+
if (value == null) {
155+
null
156+
} else {
157+
nullSafeEval(inputRow, value)
158+
}
159+
}
144160
}
145161

146162
trait ArrayBasedUnaryHigherOrderFunction extends UnaryHigherOrderFunction {
@@ -199,24 +215,20 @@ case class ArrayTransform(
199215
(elementVar, indexVar)
200216
}
201217

202-
override def eval(input: InternalRow): Any = {
203-
val arr = this.input.eval(input).asInstanceOf[ArrayData]
204-
if (arr == null) {
205-
null
206-
} else {
207-
val f = functionForEval
208-
val result = new GenericArrayData(new Array[Any](arr.numElements))
209-
var i = 0
210-
while (i < arr.numElements) {
211-
elementVar.value.set(arr.get(i, elementVar.dataType))
212-
if (indexVar.isDefined) {
213-
indexVar.get.value.set(i)
214-
}
215-
result.update(i, f.eval(input))
216-
i += 1
218+
override def nullSafeEval(inputRow: InternalRow, inputValue: Any): Any = {
219+
val arr = inputValue.asInstanceOf[ArrayData]
220+
val f = functionForEval
221+
val result = new GenericArrayData(new Array[Any](arr.numElements))
222+
var i = 0
223+
while (i < arr.numElements) {
224+
elementVar.value.set(arr.get(i, elementVar.dataType))
225+
if (indexVar.isDefined) {
226+
indexVar.get.value.set(i)
217227
}
218-
result
228+
result.update(i, f.eval(inputRow))
229+
i += 1
219230
}
231+
result
220232
}
221233

222234
override def prettyName: String = "transform"
@@ -259,23 +271,20 @@ case class MapFilter(
259271

260272
override def nullable: Boolean = input.nullable
261273

262-
override def eval(input: InternalRow): Any = {
263-
val m = this.input.eval(input).asInstanceOf[MapData]
264-
if (m == null) {
265-
null
266-
} else {
267-
val retKeys = new mutable.ListBuffer[Any]
268-
val retValues = new mutable.ListBuffer[Any]
269-
m.foreach(keyType, valueType, (k, v) => {
270-
keyVar.value.set(k)
271-
valueVar.value.set(v)
272-
if (functionForEval.eval(input).asInstanceOf[Boolean]) {
273-
retKeys += k
274-
retValues += v
275-
}
276-
})
277-
ArrayBasedMapData(retKeys.toArray, retValues.toArray)
278-
}
274+
override def nullSafeEval(inputRow: InternalRow, value: Any): Any = {
275+
val m = value.asInstanceOf[MapData]
276+
val f = functionForEval
277+
val retKeys = new mutable.ListBuffer[Any]
278+
val retValues = new mutable.ListBuffer[Any]
279+
m.foreach(keyType, valueType, (k, v) => {
280+
keyVar.value.set(k)
281+
valueVar.value.set(v)
282+
if (f.eval(inputRow).asInstanceOf[Boolean]) {
283+
retKeys += k
284+
retValues += v
285+
}
286+
})
287+
ArrayBasedMapData(retKeys.toArray, retValues.toArray)
279288
}
280289

281290
override def dataType: DataType = input.dataType

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
112112
checkEvaluation(mapFilter(mii1, kGreaterThanV), Map())
113113
checkEvaluation(mapFilter(miin, kGreaterThanV), null)
114114

115-
val valueNull: (Expression, Expression) => Expression = (_, v) => v.isNull
115+
val valueIsNull: (Expression, Expression) => Expression = (_, v) => v.isNull
116116

117-
checkEvaluation(mapFilter(mii0, valueNull), Map())
118-
checkEvaluation(mapFilter(mii1, valueNull), Map(1 -> null, 3 -> null))
119-
checkEvaluation(mapFilter(miin, valueNull), null)
117+
checkEvaluation(mapFilter(mii0, valueIsNull), Map())
118+
checkEvaluation(mapFilter(mii1, valueIsNull), Map(1 -> null, 3 -> null))
119+
checkEvaluation(mapFilter(miin, valueIsNull), null)
120120

121121
val msi0 = Literal.create(Map("abcdf" -> 5, "abc" -> 10, "" -> 0),
122122
MapType(StringType, IntegerType, valueContainsNull = false))

0 commit comments

Comments
 (0)