@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
2525import org .apache .spark .sql .catalyst .analysis .{TypeCheckResult , UnresolvedAttribute }
2626import org .apache .spark .sql .catalyst .expressions .codegen ._
2727import org .apache .spark .sql .catalyst .expressions .codegen .Block ._
28- import org .apache .spark .sql .catalyst .util .{ArrayData , GenericArrayData }
28+ import org .apache .spark .sql .catalyst .util .{ArrayBasedMapData , ArrayData , GenericArrayData , MapData }
2929import org .apache .spark .sql .types ._
3030
3131/**
@@ -133,7 +133,29 @@ trait HigherOrderFunction extends Expression {
133133 }
134134}
135135
136- trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes {
/**
 * Helpers for deriving lambda argument types from the data type of a higher-order
 * function's collection argument.
 */
object HigherOrderFunction {

  /**
   * Returns the element type and element nullability of `dt` when it is an array type;
   * otherwise returns those of [[ArrayType.defaultConcreteType]].
   */
  def arrayArgumentType(dt: DataType): (DataType, Boolean) = {
    // Pick the type to destructure first, then destructure it in a single place.
    val resolved = dt match {
      case _: ArrayType => dt
      case _ => ArrayType.defaultConcreteType
    }
    val ArrayType(elementType, containsNull) = resolved
    elementType -> containsNull
  }

  /**
   * Returns the key type, value type and value nullability of `dt` when it is a map type;
   * otherwise returns those of [[MapType.defaultConcreteType]].
   */
  def mapKeyValueArgumentType(dt: DataType): (DataType, DataType, Boolean) = {
    val resolved = dt match {
      case _: MapType => dt
      case _ => MapType.defaultConcreteType
    }
    val MapType(keyType, valueType, valueContainsNull) = resolved
    (keyType, valueType, valueContainsNull)
  }
}
154+
155+ /**
156+ * Trait for functions having as input one argument and one function.
157+ */
158+ trait SimpleHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes {
137159
138160 def input : Expression
139161
@@ -145,23 +167,33 @@ trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInpu
145167
146168 def expectingFunctionType : AbstractDataType = AnyDataType
147169
148- override def inputTypes : Seq [AbstractDataType ] = Seq (ArrayType , expectingFunctionType)
149-
150170 @ transient lazy val functionForEval : Expression = functionsForEval.head
151- }
152171
153- object ArrayBasedHigherOrderFunction {
/**
 * Called by [[eval]] with the already-evaluated input value, which is never null here.
 * If a subclass keeps the default nullability, it can override this method in order to
 * save null-check code.
 *
 * @param inputRow the row the expression is being evaluated against
 * @param input the non-null result of evaluating the input expression on `inputRow`
 * @return the result of the higher-order function for this row
 */
protected def nullSafeEval(inputRow: InternalRow, input: Any): Any =
  // Name the trait that actually declares this contract so the failure is traceable.
  sys.error("SimpleHigherOrderFunction must override either eval or nullSafeEval")
154178
155- def elementArgumentType ( dt : DataType ): ( DataType , Boolean ) = {
156- dt match {
157- case ArrayType (elementType, containsNull) => (elementType, containsNull)
158- case _ =>
159- val ArrayType (elementType, containsNull) = ArrayType .defaultConcreteType
160- (elementType, containsNull )
/**
 * Evaluates the input expression against `inputRow`; a null input short-circuits to null,
 * any other value is handed to [[nullSafeEval]].
 */
override def eval(inputRow: InternalRow): Any = {
  val evaluated = input.eval(inputRow)
  if (evaluated != null) nullSafeEval(inputRow, evaluated) else null
}
163187}
164188
/**
 * A [[SimpleHigherOrderFunction]] whose first argument is expected to be an array and whose
 * second argument is the function, typed by `expectingFunctionType`.
 */
trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
  override def inputTypes: Seq[AbstractDataType] = {
    val argumentType: AbstractDataType = ArrayType
    Seq(argumentType, expectingFunctionType)
  }
}
192+
/**
 * A [[SimpleHigherOrderFunction]] whose first argument is expected to be a map and whose
 * second argument is the function, typed by `expectingFunctionType`.
 */
trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
  override def inputTypes: Seq[AbstractDataType] = {
    val argumentType: AbstractDataType = MapType
    Seq(argumentType, expectingFunctionType)
  }
}
196+
165197/**
166198 * Transform elements in an array using the transform function. This is similar to
167199 * a `map` in functional programming.
@@ -179,14 +211,14 @@ object ArrayBasedHigherOrderFunction {
179211case class ArrayTransform (
180212 input : Expression ,
181213 function : Expression )
182- extends ArrayBasedHigherOrderFunction with CodegenFallback {
214+ extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
183215
184216 override def nullable : Boolean = input.nullable
185217
186218 override def dataType : ArrayType = ArrayType (function.dataType, function.nullable)
187219
188220 override def bind (f : (Expression , Seq [(DataType , Boolean )]) => LambdaFunction ): ArrayTransform = {
189- val elem = ArrayBasedHigherOrderFunction .elementArgumentType (input.dataType)
221+ val elem = HigherOrderFunction .arrayArgumentType (input.dataType)
190222 function match {
191223 case LambdaFunction (_, arguments, _) if arguments.size == 2 =>
192224 copy(function = f(function, elem :: (IntegerType , false ) :: Nil ))
@@ -205,29 +237,78 @@ case class ArrayTransform(
205237 (elementVar, indexVar)
206238 }
207239
208- override def eval (input : InternalRow ): Any = {
209- val arr = this .input.eval(input).asInstanceOf [ArrayData ]
210- if (arr == null ) {
211- null
212- } else {
213- val f = functionForEval
214- val result = new GenericArrayData (new Array [Any ](arr.numElements))
215- var i = 0
216- while (i < arr.numElements) {
217- elementVar.value.set(arr.get(i, elementVar.dataType))
218- if (indexVar.isDefined) {
219- indexVar.get.value.set(i)
220- }
221- result.update(i, f.eval(input))
222- i += 1
/**
 * Applies the lambda to every element of the (non-null) input array and collects the
 * results, position by position, into a new array of the same length.
 */
override def nullSafeEval(inputRow: InternalRow, inputValue: Any): Any = {
  val source = inputValue.asInstanceOf[ArrayData]
  val lambda = functionForEval
  val transformed = new GenericArrayData(new Array[Any](source.numElements))
  var idx = 0
  while (idx < source.numElements) {
    // Expose the current element (and its index, when the lambda takes one) to the lambda.
    elementVar.value.set(source.get(idx, elementVar.dataType))
    indexVar.foreach(_.value.set(idx))
    transformed.update(idx, lambda.eval(inputRow))
    idx += 1
  }
  transformed
}
227255
228256 override def prettyName : String = " transform"
229257}
230258
/**
 * Filters entries in a map using the provided function.
 *
 * The lambda is evaluated once per entry with the entry's key and value bound to its two
 * arguments; entries for which it yields `true` are kept. The result has the same data type
 * as the input map, and a null input map produces a null result.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr, func) - Filters entries in a map using the function.",
  examples = """
    Examples:
      > SELECT _FUNC_(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v);
       [1 -> 0, 3 -> -1]
  """,
  since = "2.4.0")
case class MapFilter(
    input: Expression,
    function: Expression)
  extends MapBasedSimpleHigherOrderFunction with CodegenFallback {

  // Lambda variables that carry the current entry's key and value during evaluation.
  @transient lazy val (keyVar, valueVar) = {
    val args = function.asInstanceOf[LambdaFunction].arguments
    (args.head.asInstanceOf[NamedLambdaVariable], args.tail.head.asInstanceOf[NamedLambdaVariable])
  }

  // Must be lazy. A strict @transient val would (a) read input.dataType while the expression
  // is still being constructed, i.e. possibly before `input` is resolved, and (b) not be
  // restored after deserialization, leaving null types for nullSafeEval to pass to
  // MapData.foreach. A lazy val is computed on first use instead.
  @transient lazy val (keyType, valueType, valueContainsNull) =
    HigherOrderFunction.mapKeyValueArgumentType(input.dataType)

  /**
   * Binds the lambda with the key argument (declared non-nullable) and the value argument
   * (nullable according to the input map type).
   */
  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapFilter = {
    copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil))
  }

  override def nullable: Boolean = input.nullable

  /**
   * Evaluates the lambda for every entry of the (non-null) input map and returns a new map
   * holding only the entries for which the lambda returned true, in iteration order.
   */
  override def nullSafeEval(inputRow: InternalRow, value: Any): Any = {
    val m = value.asInstanceOf[MapData]
    val f = functionForEval
    val retKeys = new mutable.ListBuffer[Any]
    val retValues = new mutable.ListBuffer[Any]
    m.foreach(keyType, valueType, (k, v) => {
      // Expose the current entry to the lambda before evaluating it.
      keyVar.value.set(k)
      valueVar.value.set(v)
      if (f.eval(inputRow).asInstanceOf[Boolean]) {
        retKeys += k
        retValues += v
      }
    })
    ArrayBasedMapData(retKeys.toArray, retValues.toArray)
  }

  override def dataType: DataType = input.dataType

  override def expectingFunctionType: AbstractDataType = BooleanType

  override def prettyName: String = "map_filter"
}
311+
231312/**
232313 * Filters the input array using the given lambda function.
233314 */
@@ -242,7 +323,7 @@ case class ArrayTransform(
242323case class ArrayFilter (
243324 input : Expression ,
244325 function : Expression )
245- extends ArrayBasedHigherOrderFunction with CodegenFallback {
326+ extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
246327
247328 override def nullable : Boolean = input.nullable
248329
@@ -251,29 +332,25 @@ case class ArrayFilter(
251332 override def expectingFunctionType : AbstractDataType = BooleanType
252333
/**
 * Binds the lambda with a single argument whose type and nullability are derived from the
 * input's data type via `HigherOrderFunction.arrayArgumentType`.
 */
override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = {
  val argumentTypes = HigherOrderFunction.arrayArgumentType(input.dataType) :: Nil
  val boundFunction = f(function, argumentTypes)
  copy(function = boundFunction)
}
257338
258339 @ transient lazy val LambdaFunction (_, Seq (elementVar : NamedLambdaVariable ), _) = function
259340
260- override def eval (input : InternalRow ): Any = {
261- val arr = this .input.eval(input).asInstanceOf [ArrayData ]
262- if (arr == null ) {
263- null
264- } else {
265- val f = functionForEval
266- val buffer = new mutable.ArrayBuffer [Any ](arr.numElements)
267- var i = 0
268- while (i < arr.numElements) {
269- elementVar.value.set(arr.get(i, elementVar.dataType))
270- if (f.eval(input).asInstanceOf [Boolean ]) {
271- buffer += elementVar.value.get
272- }
273- i += 1
/**
 * Evaluates the predicate lambda for every element of the (non-null) input array and
 * returns a new array with the elements for which it returned true, in their original order.
 */
override def nullSafeEval(inputRow: InternalRow, value: Any): Any = {
  val source = value.asInstanceOf[ArrayData]
  val predicate = functionForEval
  val kept = new mutable.ArrayBuffer[Any](source.numElements)
  var idx = 0
  while (idx < source.numElements) {
    // Expose the current element to the lambda before evaluating it.
    elementVar.value.set(source.get(idx, elementVar.dataType))
    val keep = predicate.eval(inputRow).asInstanceOf[Boolean]
    if (keep) {
      kept += elementVar.value.get
    }
    idx += 1
  }
  new GenericArrayData(kept)
}
278355
279356 override def prettyName : String = " filter"
@@ -334,7 +411,7 @@ case class ArrayAggregate(
334411 override def bind (f : (Expression , Seq [(DataType , Boolean )]) => LambdaFunction ): ArrayAggregate = {
335412 // Be very conservative with nullable. We cannot be sure that the accumulator does not
336413 // evaluate to null. So we always set nullable to true here.
337- val elem = ArrayBasedHigherOrderFunction .elementArgumentType (input.dataType)
414+ val elem = HigherOrderFunction .arrayArgumentType (input.dataType)
338415 val acc = zero.dataType -> true
339416 val newMerge = f(merge, acc :: elem :: Nil )
340417 val newFinish = f(finish, acc :: Nil )
0 commit comments