-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-23937][SQL] Add map_filter SQL function #21986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
3f88e2a
9bbaa3b
37e221c
b58a1de
9c25ae6
1823fb2
16d8b64
af79644
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow | |
| import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} | ||
| import org.apache.spark.sql.catalyst.expressions.codegen._ | ||
| import org.apache.spark.sql.catalyst.expressions.codegen.Block._ | ||
| import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} | ||
| import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} | ||
| import org.apache.spark.sql.types._ | ||
|
|
||
| /** | ||
|
|
@@ -133,7 +133,10 @@ trait HigherOrderFunction extends Expression { | |
| } | ||
| } | ||
|
|
||
| trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes { | ||
| /** | ||
| * Trait for functions having as input one argument and one function. | ||
| */ | ||
| trait UnaryHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes { | ||
|
|
||
| def input: Expression | ||
|
|
||
|
|
@@ -145,9 +148,31 @@ trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInpu | |
|
|
||
| def expectingFunctionType: AbstractDataType = AnyDataType | ||
|
|
||
| @transient lazy val functionForEval: Expression = functionsForEval.head | ||
|
|
||
| /** | ||
| * Called by [[eval]]. If a subclass keeps the default nullability, it can override this method | ||
| * in order to save null-check code. | ||
| */ | ||
| protected def nullSafeEval(inputRow: InternalRow, input: Any): Any = | ||
| sys.error(s"UnaryHigherOrderFunction must override either eval or nullSafeEval") | ||
|
|
||
| override def eval(inputRow: InternalRow): Any = { | ||
| val value = input.eval(inputRow) | ||
| if (value == null) { | ||
| null | ||
| } else { | ||
| nullSafeEval(inputRow, value) | ||
| } | ||
| } | ||
| } | ||
|
|
||
/**
 * A [[UnaryHigherOrderFunction]] whose first argument is expected to be an array; the second
 * argument is the function, typed by `expectingFunctionType`.
 */
trait ArrayBasedUnaryHigherOrderFunction extends UnaryHigherOrderFunction {
  override def inputTypes: Seq[AbstractDataType] = ArrayType :: expectingFunctionType :: Nil
}
|
|
||
| @transient lazy val functionForEval: Expression = functionsForEval.head | ||
/**
 * A [[UnaryHigherOrderFunction]] whose first argument is expected to be a map; the second
 * argument is the function, typed by `expectingFunctionType`.
 */
trait MapBasedUnaryHigherOrderFunction extends UnaryHigherOrderFunction {
  override def inputTypes: Seq[AbstractDataType] = MapType :: expectingFunctionType :: Nil
}
|
|
||
| object ArrayBasedHigherOrderFunction { | ||
|
|
@@ -179,7 +204,7 @@ object ArrayBasedHigherOrderFunction { | |
| case class ArrayTransform( | ||
| input: Expression, | ||
| function: Expression) | ||
| extends ArrayBasedHigherOrderFunction with CodegenFallback { | ||
| extends ArrayBasedUnaryHigherOrderFunction with CodegenFallback { | ||
|
|
||
| override def nullable: Boolean = input.nullable | ||
|
|
||
|
|
@@ -205,29 +230,82 @@ case class ArrayTransform( | |
| (elementVar, indexVar) | ||
| } | ||
|
|
||
| override def eval(input: InternalRow): Any = { | ||
| val arr = this.input.eval(input).asInstanceOf[ArrayData] | ||
| if (arr == null) { | ||
| null | ||
| } else { | ||
| val f = functionForEval | ||
| val result = new GenericArrayData(new Array[Any](arr.numElements)) | ||
| var i = 0 | ||
| while (i < arr.numElements) { | ||
| elementVar.value.set(arr.get(i, elementVar.dataType)) | ||
| if (indexVar.isDefined) { | ||
| indexVar.get.value.set(i) | ||
| } | ||
| result.update(i, f.eval(input)) | ||
| i += 1 | ||
| override def nullSafeEval(inputRow: InternalRow, inputValue: Any): Any = { | ||
| val arr = inputValue.asInstanceOf[ArrayData] | ||
| val f = functionForEval | ||
| val result = new GenericArrayData(new Array[Any](arr.numElements)) | ||
| var i = 0 | ||
| while (i < arr.numElements) { | ||
| elementVar.value.set(arr.get(i, elementVar.dataType)) | ||
| if (indexVar.isDefined) { | ||
| indexVar.get.value.set(i) | ||
| } | ||
| result | ||
| result.update(i, f.eval(inputRow)) | ||
| i += 1 | ||
| } | ||
| result | ||
| } | ||
|
|
||
| override def prettyName: String = "transform" | ||
| } | ||
|
|
||
/**
 * Filters entries in a map using the provided function.
 *
 * The lambda is bound to two variables, the entry key and the entry value, and is expected to
 * produce a boolean (see `expectingFunctionType`). Entries for which it evaluates to true are
 * copied into the resulting map; all others are dropped.
 *
 * @param input expression producing the map to filter
 * @param function lambda evaluated once per (key, value) entry
 */
@ExpressionDescription(
  usage = "_FUNC_(expr, func) - Filters entries in a map using the function.",
  examples = """
    Examples:
      > SELECT _FUNC_(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v);
       [1 -> 0, 3 -> -1]
  """,
  since = "2.4.0")
case class MapFilter(
    input: Expression,
    function: Expression)
  extends MapBasedUnaryHigherOrderFunction with CodegenFallback {

  // Must be lazy. A plain `@transient val` would (a) evaluate `input.dataType` while the
  // expression is being constructed and (b) be left as null/false after Java deserialization,
  // since transient fields are not restored and a non-lazy val is never recomputed.
  // NOTE(review): the fallback to the default concrete map type presumably lets `bind` run
  // before a non-map input has been rejected by input type checking - confirm.
  @transient lazy val (keyType, valueType, valueContainsNull) = input.dataType match {
    case MapType(kType, vType, vContainsNull) => (kType, vType, vContainsNull)
    case _ =>
      val MapType(kType, vType, vContainsNull) = MapType.defaultConcreteType
      (kType, vType, vContainsNull)
  }

  // The first two lambda arguments are the key variable and the value variable, in that order.
  @transient lazy val (keyVar, valueVar) = {
    val args = function.asInstanceOf[LambdaFunction].arguments
    (args.head.asInstanceOf[NamedLambdaVariable], args.tail.head.asInstanceOf[NamedLambdaVariable])
  }

  /**
   * Binds the lambda to the key and value types of the input map. The key argument is declared
   * non-nullable; the value argument inherits the map's `valueContainsNull`.
   */
  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapFilter = {
    copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil))
  }

  override def nullable: Boolean = input.nullable

  /**
   * Evaluates the lambda against every entry of the given map and keeps the matching entries.
   *
   * @param inputRow row against which the lambda is evaluated
   * @param value the already-evaluated input, cast to [[MapData]]
   * @return a new map containing only the entries for which the lambda returned true
   */
  override def nullSafeEval(inputRow: InternalRow, value: Any): Any = {
    val m = value.asInstanceOf[MapData]
    val f = functionForEval
    val retKeys = new mutable.ListBuffer[Any]
    val retValues = new mutable.ListBuffer[Any]
    m.foreach(keyType, valueType, (k, v) => {
      // Expose the current entry to the lambda through its bound variables.
      keyVar.value.set(k)
      valueVar.value.set(v)
      if (f.eval(inputRow).asInstanceOf[Boolean]) {
        retKeys += k
        retValues += v
      }
    })
    ArrayBasedMapData(retKeys.toArray, retValues.toArray)
  }

  // Filtering never changes the key or value types, so the result type is the input type.
  override def dataType: DataType = input.dataType

  override def expectingFunctionType: AbstractDataType = BooleanType

  override def prettyName: String = "map_filter"
}
|
|
||
| /** | ||
| * Filters the input array using the given lambda function. | ||
| */ | ||
|
|
@@ -242,7 +320,7 @@ case class ArrayTransform( | |
| case class ArrayFilter( | ||
| input: Expression, | ||
| function: Expression) | ||
| extends ArrayBasedHigherOrderFunction with CodegenFallback { | ||
| extends ArrayBasedUnaryHigherOrderFunction with CodegenFallback { | ||
|
|
||
| override def nullable: Boolean = input.nullable | ||
|
|
||
|
|
@@ -257,23 +335,19 @@ case class ArrayFilter( | |
|
|
||
| @transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function | ||
|
|
||
| override def eval(input: InternalRow): Any = { | ||
| val arr = this.input.eval(input).asInstanceOf[ArrayData] | ||
| if (arr == null) { | ||
| null | ||
| } else { | ||
| val f = functionForEval | ||
| val buffer = new mutable.ArrayBuffer[Any](arr.numElements) | ||
| var i = 0 | ||
| while (i < arr.numElements) { | ||
| elementVar.value.set(arr.get(i, elementVar.dataType)) | ||
| if (f.eval(input).asInstanceOf[Boolean]) { | ||
| buffer += elementVar.value.get | ||
| } | ||
| i += 1 | ||
| override def nullSafeEval(inputRow: InternalRow, value: Any): Any = { | ||
| val arr = value.asInstanceOf[ArrayData] | ||
| val f = functionForEval | ||
| val buffer = new mutable.ArrayBuffer[Any](arr.numElements) | ||
| var i = 0 | ||
| while (i < arr.numElements) { | ||
| elementVar.value.set(arr.get(i, elementVar.dataType)) | ||
| if (f.eval(inputRow).asInstanceOf[Boolean]) { | ||
| buffer += elementVar.value.get | ||
| } | ||
| new GenericArrayData(buffer) | ||
| i += 1 | ||
| } | ||
| new GenericArrayData(buffer) | ||
| } | ||
|
|
||
| override def prettyName: String = "filter" | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like this trait, but I'm not sure whether we can say `"Unary"HigherOrderFunction` for this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Btw, how about defining `nullSafeEval` for `input` in this trait, like `UnaryExpression`? (`nullInputSafeEval`?)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I called it `Unary` as it gets one input and one function. Honestly, I can't think of a better name without becoming very verbose. If you have a better suggestion, I am happy to follow it. I will add the `nullSafeEval`, thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @hvanhovell for the naming?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We use the term `Unary` a lot, and this is different from the other uses. The name should convey a HigherOrderFunction that only uses a single (lambda) function, right? The only thing I can come up with is `SingleHigherOrderFunction`. `Simple` would probably also be fine.