@@ -141,6 +141,22 @@ trait UnaryHigherOrderFunction extends HigherOrderFunction with ExpectsInputType
141141 def expectingFunctionType : AbstractDataType = AnyDataType
142142
143143 @ transient lazy val functionForEval : Expression = functionsForEval.head
144+
145+ /**
146+ * Called by [[eval ]]. If a subclass keeps the default nullability, it can override this method
147+ * in order to save null-check code.
148+ */
149+ protected def nullSafeEval (inputRow : InternalRow , input : Any ): Any =
150+ sys.error(s " UnaryHigherOrderFunction must override either eval or nullSafeEval " )
151+
152+ override def eval (inputRow : InternalRow ): Any = {
153+ val value = input.eval(inputRow)
154+ if (value == null ) {
155+ null
156+ } else {
157+ nullSafeEval(inputRow, value)
158+ }
159+ }
144160}
145161
146162trait ArrayBasedUnaryHigherOrderFunction extends UnaryHigherOrderFunction {
@@ -199,24 +215,20 @@ case class ArrayTransform(
199215 (elementVar, indexVar)
200216 }
201217
202- override def eval (input : InternalRow ): Any = {
203- val arr = this .input.eval(input).asInstanceOf [ArrayData ]
204- if (arr == null ) {
205- null
206- } else {
207- val f = functionForEval
208- val result = new GenericArrayData (new Array [Any ](arr.numElements))
209- var i = 0
210- while (i < arr.numElements) {
211- elementVar.value.set(arr.get(i, elementVar.dataType))
212- if (indexVar.isDefined) {
213- indexVar.get.value.set(i)
214- }
215- result.update(i, f.eval(input))
216- i += 1
218+ override def nullSafeEval (inputRow : InternalRow , inputValue : Any ): Any = {
219+ val arr = inputValue.asInstanceOf [ArrayData ]
220+ val f = functionForEval
221+ val result = new GenericArrayData (new Array [Any ](arr.numElements))
222+ var i = 0
223+ while (i < arr.numElements) {
224+ elementVar.value.set(arr.get(i, elementVar.dataType))
225+ if (indexVar.isDefined) {
226+ indexVar.get.value.set(i)
217227 }
218- result
228+ result.update(i, f.eval(inputRow))
229+ i += 1
219230 }
231+ result
220232 }
221233
222234 override def prettyName : String = " transform"
@@ -259,23 +271,20 @@ case class MapFilter(
259271
260272 override def nullable : Boolean = input.nullable
261273
262- override def eval (input : InternalRow ): Any = {
263- val m = this .input.eval(input).asInstanceOf [MapData ]
264- if (m == null ) {
265- null
266- } else {
267- val retKeys = new mutable.ListBuffer [Any ]
268- val retValues = new mutable.ListBuffer [Any ]
269- m.foreach(keyType, valueType, (k, v) => {
270- keyVar.value.set(k)
271- valueVar.value.set(v)
272- if (functionForEval.eval(input).asInstanceOf [Boolean ]) {
273- retKeys += k
274- retValues += v
275- }
276- })
277- ArrayBasedMapData (retKeys.toArray, retValues.toArray)
278- }
274+ override def nullSafeEval (inputRow : InternalRow , value : Any ): Any = {
275+ val m = value.asInstanceOf [MapData ]
276+ val f = functionForEval
277+ val retKeys = new mutable.ListBuffer [Any ]
278+ val retValues = new mutable.ListBuffer [Any ]
279+ m.foreach(keyType, valueType, (k, v) => {
280+ keyVar.value.set(k)
281+ valueVar.value.set(v)
282+ if (f.eval(inputRow).asInstanceOf [Boolean ]) {
283+ retKeys += k
284+ retValues += v
285+ }
286+ })
287+ ArrayBasedMapData (retKeys.toArray, retValues.toArray)
279288 }
280289
281290 override def dataType : DataType = input.dataType
0 commit comments