-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-23937][SQL] Add map_filter SQL function #21986
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
3f88e2a
9bbaa3b
37e221c
b58a1de
9c25ae6
1823fb2
16d8b64
af79644
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,10 +19,12 @@ package org.apache.spark.sql.catalyst.expressions | |
|
|
||
| import java.util.concurrent.atomic.AtomicReference | ||
|
|
||
| import scala.collection.mutable | ||
|
|
||
| import org.apache.spark.sql.catalyst.InternalRow | ||
| import org.apache.spark.sql.catalyst.expressions.codegen._ | ||
| import org.apache.spark.sql.catalyst.expressions.codegen.Block._ | ||
| import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} | ||
| import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} | ||
| import org.apache.spark.sql.types._ | ||
|
|
||
| /** | ||
|
|
@@ -123,7 +125,10 @@ trait HigherOrderFunction extends Expression { | |
| } | ||
| } | ||
|
|
||
| trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes { | ||
| /** | ||
| * Trait for functions having as input one argument and one function. | ||
| */ | ||
| trait UnaryHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes { | ||
|
|
||
| def input: Expression | ||
|
|
||
|
|
@@ -135,9 +140,15 @@ trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInpu | |
|
|
||
| def expectingFunctionType: AbstractDataType = AnyDataType | ||
|
|
||
| @transient lazy val functionForEval: Expression = functionsForEval.head | ||
| } | ||
|
|
||
| trait ArrayBasedUnaryHigherOrderFunction extends UnaryHigherOrderFunction { | ||
| override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, expectingFunctionType) | ||
| } | ||
|
|
||
| @transient lazy val functionForEval: Expression = functionsForEval.head | ||
| trait MapBasedUnaryHigherOrderFunction extends UnaryHigherOrderFunction { | ||
| override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType) | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -157,7 +168,7 @@ trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInpu | |
| case class ArrayTransform( | ||
| input: Expression, | ||
| function: Expression) | ||
| extends ArrayBasedHigherOrderFunction with CodegenFallback { | ||
| extends ArrayBasedUnaryHigherOrderFunction with CodegenFallback { | ||
|
|
||
| override def nullable: Boolean = input.nullable | ||
|
|
||
|
|
@@ -210,3 +221,66 @@ case class ArrayTransform( | |
|
|
||
| override def prettyName: String = "transform" | ||
| } | ||
|
|
||
| /** | ||
| * Filters entries in a map using the provided function. | ||
| */ | ||
| @ExpressionDescription( | ||
| usage = "_FUNC_(expr, func) - Filters entries in a map using the function.", | ||
| examples = """ | ||
| Examples: | ||
| > SELECT _FUNC_(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v); | ||
| [1 -> 0, 3 -> -1] | ||
| """, | ||
| since = "2.4.0") | ||
| case class MapFilter( | ||
| input: Expression, | ||
| function: Expression) | ||
| extends MapBasedUnaryHigherOrderFunction with CodegenFallback { | ||
|
|
||
| @transient val (keyType, valueType, valueContainsNull) = input.dataType match { | ||
|
||
| case MapType(kType, vType, vContainsNull) => (kType, vType, vContainsNull) | ||
| case _ => | ||
| val MapType(kType, vType, vContainsNull) = MapType.defaultConcreteType | ||
| (kType, vType, vContainsNull) | ||
| } | ||
|
||
|
|
||
| @transient lazy val (keyVar, valueVar) = { | ||
| val args = function.asInstanceOf[LambdaFunction].arguments | ||
| (args.head.asInstanceOf[NamedLambdaVariable], args.tail.head.asInstanceOf[NamedLambdaVariable]) | ||
| } | ||
|
|
||
| override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapFilter = { | ||
| function match { | ||
| case LambdaFunction(_, _, _) => | ||
|
||
| copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) | ||
| } | ||
| } | ||
|
|
||
| override def nullable: Boolean = input.nullable | ||
|
|
||
| override def eval(input: InternalRow): Any = { | ||
| val m = this.input.eval(input).asInstanceOf[MapData] | ||
| if (m == null) { | ||
| null | ||
| } else { | ||
| val retKeys = new mutable.ListBuffer[Any] | ||
| val retValues = new mutable.ListBuffer[Any] | ||
|
||
| m.foreach(keyType, valueType, (k, v) => { | ||
| keyVar.value.set(k) | ||
| valueVar.value.set(v) | ||
| if (functionForEval.eval(input).asInstanceOf[Boolean]) { | ||
| retKeys += k | ||
| retValues += v | ||
| } | ||
| }) | ||
| ArrayBasedMapData(retKeys.toArray, retValues.toArray) | ||
| } | ||
| } | ||
|
|
||
| override def dataType: DataType = input.dataType | ||
|
|
||
| override def expectingFunctionType: AbstractDataType = BooleanType | ||
|
|
||
| override def prettyName: String = "map_filter" | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -94,4 +94,53 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper | |
| checkEvaluation(transform(aai, array => Cast(transform(array, plusIndex), StringType)), | ||
| Seq("[1, 3, 5]", null, "[4, 6]")) | ||
| } | ||
|
|
||
| test("MapFilter") { | ||
| def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = { | ||
| val mt = expr.dataType.asInstanceOf[MapType] | ||
| MapFilter(expr, createLambda(mt.keyType, false, mt.valueType, mt.valueContainsNull, f)) | ||
| } | ||
| val mii0 = Literal.create(Map(1 -> 0, 2 -> 10, 3 -> -1), | ||
| MapType(IntegerType, IntegerType, valueContainsNull = false)) | ||
| val mii1 = Literal.create(Map(1 -> null, 2 -> 10, 3 -> null), | ||
| MapType(IntegerType, IntegerType, valueContainsNull = true)) | ||
| val miin = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false)) | ||
|
|
||
| val kGreaterThanV: (Expression, Expression) => Expression = (k, v) => k > v | ||
|
|
||
| checkEvaluation(mapFilter(mii0, kGreaterThanV), Map(1 -> 0, 3 -> -1)) | ||
| checkEvaluation(mapFilter(mii1, kGreaterThanV), Map()) | ||
| checkEvaluation(mapFilter(miin, kGreaterThanV), null) | ||
|
|
||
| val valueNull: (Expression, Expression) => Expression = (_, v) => v.isNull | ||
|
||
|
|
||
| checkEvaluation(mapFilter(mii0, valueNull), Map()) | ||
| checkEvaluation(mapFilter(mii1, valueNull), Map(1 -> null, 3 -> null)) | ||
| checkEvaluation(mapFilter(miin, valueNull), null) | ||
|
|
||
| val msi0 = Literal.create(Map("abcdf" -> 5, "abc" -> 10, "" -> 0), | ||
| MapType(StringType, IntegerType, valueContainsNull = false)) | ||
| val msi1 = Literal.create(Map("abcdf" -> 5, "abc" -> 10, "" -> null), | ||
| MapType(StringType, IntegerType, valueContainsNull = true)) | ||
| val msin = Literal.create(null, MapType(StringType, IntegerType, valueContainsNull = false)) | ||
|
|
||
| val isLengthOfKey: (Expression, Expression) => Expression = (k, v) => Length(k) === v | ||
|
|
||
| checkEvaluation(mapFilter(msi0, isLengthOfKey), Map("abcdf" -> 5, "" -> 0)) | ||
| checkEvaluation(mapFilter(msi1, isLengthOfKey), Map("abcdf" -> 5)) | ||
| checkEvaluation(mapFilter(msin, isLengthOfKey), null) | ||
|
|
||
| val mia0 = Literal.create(Map(1 -> Seq(0, 1, 2), 2 -> Seq(10), -3 -> Seq(-1, 0, -2, 3)), | ||
| MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = false)) | ||
| val mia1 = Literal.create(Map(1 -> Seq(0, 1, 2), 2 -> null, -3 -> Seq(-1, 0, -2, 3)), | ||
| MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = true)) | ||
| val mian = Literal.create( | ||
| null, MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = false)) | ||
|
|
||
| val customFunc: (Expression, Expression) => Expression = (k, v) => Size(v) + k > 3 | ||
|
|
||
| checkEvaluation(mapFilter(mia0, customFunc), Map(1 -> Seq(0, 1, 2))) | ||
| checkEvaluation(mapFilter(mia1, customFunc), Map(1 -> Seq(0, 1, 2))) | ||
| checkEvaluation(mapFilter(mian, customFunc), null) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like this trait but I'm not sure whether we can say
"Unary"HigherOrderFunctionfor this.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Btw, how about defining
nullSafeEvalforinputin this trait likeUnaryExpression? (nullInputSafeEval?)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I called it
Unaryas it gets one input and one function. Honestly I can't think of a better name without becoming very verbose. if you have a better suggestion I am happy to follow it. I will add thenullSafeEval, thanks!There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @hvanhovell for the naming?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We use the term
Unarya lot and this is different from the other uses. The name should convey a HigherOrderFunction that only uses a single (lambda) function right? The only thing I can come up with isSingleHigherOrderFunction.Simplewould probably also be fine.