Skip to content

Commit cb6cb31

Browse files
mgaido91ueshin
authored andcommitted
[SPARK-23937][SQL] Add map_filter SQL function
## What changes were proposed in this pull request? The PR adds the high order function `map_filter`, which filters the entries of a map and returns a new map which contains only the entries which satisfied the filter function. ## How was this patch tested? added UTs Closes #21986 from mgaido91/SPARK-23937. Authored-by: Marco Gaido <[email protected]> Signed-off-by: Takuya UESHIN <[email protected]>
1 parent 298e80f commit cb6cb31

File tree

4 files changed

+221
-48
lines changed

4 files changed

+221
-48
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,7 @@ object FunctionRegistry {
442442
expression[ArrayRemove]("array_remove"),
443443
expression[ArrayDistinct]("array_distinct"),
444444
expression[ArrayTransform]("transform"),
445+
expression[MapFilter]("map_filter"),
445446
expression[ArrayFilter]("filter"),
446447
expression[ArrayAggregate]("aggregate"),
447448
CreateStruct.registryEntry,

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala

Lines changed: 125 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
2525
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
2626
import org.apache.spark.sql.catalyst.expressions.codegen._
2727
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
28-
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
28+
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
2929
import org.apache.spark.sql.types._
3030

3131
/**
@@ -133,7 +133,29 @@ trait HigherOrderFunction extends Expression {
133133
}
134134
}
135135

136-
trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes {
136+
object HigherOrderFunction {
137+
138+
def arrayArgumentType(dt: DataType): (DataType, Boolean) = {
139+
dt match {
140+
case ArrayType(elementType, containsNull) => (elementType, containsNull)
141+
case _ =>
142+
val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType
143+
(elementType, containsNull)
144+
}
145+
}
146+
147+
def mapKeyValueArgumentType(dt: DataType): (DataType, DataType, Boolean) = dt match {
148+
case MapType(kType, vType, vContainsNull) => (kType, vType, vContainsNull)
149+
case _ =>
150+
val MapType(kType, vType, vContainsNull) = MapType.defaultConcreteType
151+
(kType, vType, vContainsNull)
152+
}
153+
}
154+
155+
/**
156+
* Trait for functions having as input one argument and one function.
157+
*/
158+
trait SimpleHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes {
137159

138160
def input: Expression
139161

@@ -145,23 +167,33 @@ trait ArrayBasedHigherOrderFunction extends HigherOrderFunction with ExpectsInpu
145167

146168
def expectingFunctionType: AbstractDataType = AnyDataType
147169

148-
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, expectingFunctionType)
149-
150170
@transient lazy val functionForEval: Expression = functionsForEval.head
151-
}
152171

153-
object ArrayBasedHigherOrderFunction {
172+
/**
173+
* Called by [[eval]]. If a subclass keeps the default nullability, it can override this method
174+
* in order to save null-check code.
175+
*/
176+
protected def nullSafeEval(inputRow: InternalRow, input: Any): Any =
177+
sys.error(s"UnaryHigherOrderFunction must override either eval or nullSafeEval")
154178

155-
def elementArgumentType(dt: DataType): (DataType, Boolean) = {
156-
dt match {
157-
case ArrayType(elementType, containsNull) => (elementType, containsNull)
158-
case _ =>
159-
val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType
160-
(elementType, containsNull)
179+
override def eval(inputRow: InternalRow): Any = {
180+
val value = input.eval(inputRow)
181+
if (value == null) {
182+
null
183+
} else {
184+
nullSafeEval(inputRow, value)
161185
}
162186
}
163187
}
164188

189+
trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
190+
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, expectingFunctionType)
191+
}
192+
193+
trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
194+
override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType)
195+
}
196+
165197
/**
166198
* Transform elements in an array using the transform function. This is similar to
167199
* a `map` in functional programming.
@@ -179,14 +211,14 @@ object ArrayBasedHigherOrderFunction {
179211
case class ArrayTransform(
180212
input: Expression,
181213
function: Expression)
182-
extends ArrayBasedHigherOrderFunction with CodegenFallback {
214+
extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
183215

184216
override def nullable: Boolean = input.nullable
185217

186218
override def dataType: ArrayType = ArrayType(function.dataType, function.nullable)
187219

188220
override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayTransform = {
189-
val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType)
221+
val elem = HigherOrderFunction.arrayArgumentType(input.dataType)
190222
function match {
191223
case LambdaFunction(_, arguments, _) if arguments.size == 2 =>
192224
copy(function = f(function, elem :: (IntegerType, false) :: Nil))
@@ -205,29 +237,78 @@ case class ArrayTransform(
205237
(elementVar, indexVar)
206238
}
207239

208-
override def eval(input: InternalRow): Any = {
209-
val arr = this.input.eval(input).asInstanceOf[ArrayData]
210-
if (arr == null) {
211-
null
212-
} else {
213-
val f = functionForEval
214-
val result = new GenericArrayData(new Array[Any](arr.numElements))
215-
var i = 0
216-
while (i < arr.numElements) {
217-
elementVar.value.set(arr.get(i, elementVar.dataType))
218-
if (indexVar.isDefined) {
219-
indexVar.get.value.set(i)
220-
}
221-
result.update(i, f.eval(input))
222-
i += 1
240+
override def nullSafeEval(inputRow: InternalRow, inputValue: Any): Any = {
241+
val arr = inputValue.asInstanceOf[ArrayData]
242+
val f = functionForEval
243+
val result = new GenericArrayData(new Array[Any](arr.numElements))
244+
var i = 0
245+
while (i < arr.numElements) {
246+
elementVar.value.set(arr.get(i, elementVar.dataType))
247+
if (indexVar.isDefined) {
248+
indexVar.get.value.set(i)
223249
}
224-
result
250+
result.update(i, f.eval(inputRow))
251+
i += 1
225252
}
253+
result
226254
}
227255

228256
override def prettyName: String = "transform"
229257
}
230258

259+
/**
260+
* Filters entries in a map using the provided function.
261+
*/
262+
@ExpressionDescription(
263+
usage = "_FUNC_(expr, func) - Filters entries in a map using the function.",
264+
examples = """
265+
Examples:
266+
> SELECT _FUNC_(map(1, 0, 2, 2, 3, -1), (k, v) -> k > v);
267+
[1 -> 0, 3 -> -1]
268+
""",
269+
since = "2.4.0")
270+
case class MapFilter(
271+
input: Expression,
272+
function: Expression)
273+
extends MapBasedSimpleHigherOrderFunction with CodegenFallback {
274+
275+
@transient lazy val (keyVar, valueVar) = {
276+
val args = function.asInstanceOf[LambdaFunction].arguments
277+
(args.head.asInstanceOf[NamedLambdaVariable], args.tail.head.asInstanceOf[NamedLambdaVariable])
278+
}
279+
280+
@transient val (keyType, valueType, valueContainsNull) =
281+
HigherOrderFunction.mapKeyValueArgumentType(input.dataType)
282+
283+
override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapFilter = {
284+
copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil))
285+
}
286+
287+
override def nullable: Boolean = input.nullable
288+
289+
override def nullSafeEval(inputRow: InternalRow, value: Any): Any = {
290+
val m = value.asInstanceOf[MapData]
291+
val f = functionForEval
292+
val retKeys = new mutable.ListBuffer[Any]
293+
val retValues = new mutable.ListBuffer[Any]
294+
m.foreach(keyType, valueType, (k, v) => {
295+
keyVar.value.set(k)
296+
valueVar.value.set(v)
297+
if (f.eval(inputRow).asInstanceOf[Boolean]) {
298+
retKeys += k
299+
retValues += v
300+
}
301+
})
302+
ArrayBasedMapData(retKeys.toArray, retValues.toArray)
303+
}
304+
305+
override def dataType: DataType = input.dataType
306+
307+
override def expectingFunctionType: AbstractDataType = BooleanType
308+
309+
override def prettyName: String = "map_filter"
310+
}
311+
231312
/**
232313
* Filters the input array using the given lambda function.
233314
*/
@@ -242,7 +323,7 @@ case class ArrayTransform(
242323
case class ArrayFilter(
243324
input: Expression,
244325
function: Expression)
245-
extends ArrayBasedHigherOrderFunction with CodegenFallback {
326+
extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
246327

247328
override def nullable: Boolean = input.nullable
248329

@@ -251,29 +332,25 @@ case class ArrayFilter(
251332
override def expectingFunctionType: AbstractDataType = BooleanType
252333

253334
override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = {
254-
val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType)
335+
val elem = HigherOrderFunction.arrayArgumentType(input.dataType)
255336
copy(function = f(function, elem :: Nil))
256337
}
257338

258339
@transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function
259340

260-
override def eval(input: InternalRow): Any = {
261-
val arr = this.input.eval(input).asInstanceOf[ArrayData]
262-
if (arr == null) {
263-
null
264-
} else {
265-
val f = functionForEval
266-
val buffer = new mutable.ArrayBuffer[Any](arr.numElements)
267-
var i = 0
268-
while (i < arr.numElements) {
269-
elementVar.value.set(arr.get(i, elementVar.dataType))
270-
if (f.eval(input).asInstanceOf[Boolean]) {
271-
buffer += elementVar.value.get
272-
}
273-
i += 1
341+
override def nullSafeEval(inputRow: InternalRow, value: Any): Any = {
342+
val arr = value.asInstanceOf[ArrayData]
343+
val f = functionForEval
344+
val buffer = new mutable.ArrayBuffer[Any](arr.numElements)
345+
var i = 0
346+
while (i < arr.numElements) {
347+
elementVar.value.set(arr.get(i, elementVar.dataType))
348+
if (f.eval(inputRow).asInstanceOf[Boolean]) {
349+
buffer += elementVar.value.get
274350
}
275-
new GenericArrayData(buffer)
351+
i += 1
276352
}
353+
new GenericArrayData(buffer)
277354
}
278355

279356
override def prettyName: String = "filter"
@@ -334,7 +411,7 @@ case class ArrayAggregate(
334411
override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayAggregate = {
335412
// Be very conservative with nullable. We cannot be sure that the accumulator does not
336413
// evaluate to null. So we always set nullable to true here.
337-
val elem = ArrayBasedHigherOrderFunction.elementArgumentType(input.dataType)
414+
val elem = HigherOrderFunction.arrayArgumentType(input.dataType)
338415
val acc = zero.dataType -> true
339416
val newMerge = f(merge, acc :: elem :: Nil)
340417
val newFinish = f(finish, acc :: Nil)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,55 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
121121
Seq("[1, 3, 5]", null, "[4, 6]"))
122122
}
123123

124+
test("MapFilter") {
125+
def mapFilter(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
126+
val mt = expr.dataType.asInstanceOf[MapType]
127+
MapFilter(expr, createLambda(mt.keyType, false, mt.valueType, mt.valueContainsNull, f))
128+
}
129+
val mii0 = Literal.create(Map(1 -> 0, 2 -> 10, 3 -> -1),
130+
MapType(IntegerType, IntegerType, valueContainsNull = false))
131+
val mii1 = Literal.create(Map(1 -> null, 2 -> 10, 3 -> null),
132+
MapType(IntegerType, IntegerType, valueContainsNull = true))
133+
val miin = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false))
134+
135+
val kGreaterThanV: (Expression, Expression) => Expression = (k, v) => k > v
136+
137+
checkEvaluation(mapFilter(mii0, kGreaterThanV), Map(1 -> 0, 3 -> -1))
138+
checkEvaluation(mapFilter(mii1, kGreaterThanV), Map())
139+
checkEvaluation(mapFilter(miin, kGreaterThanV), null)
140+
141+
val valueIsNull: (Expression, Expression) => Expression = (_, v) => v.isNull
142+
143+
checkEvaluation(mapFilter(mii0, valueIsNull), Map())
144+
checkEvaluation(mapFilter(mii1, valueIsNull), Map(1 -> null, 3 -> null))
145+
checkEvaluation(mapFilter(miin, valueIsNull), null)
146+
147+
val msi0 = Literal.create(Map("abcdf" -> 5, "abc" -> 10, "" -> 0),
148+
MapType(StringType, IntegerType, valueContainsNull = false))
149+
val msi1 = Literal.create(Map("abcdf" -> 5, "abc" -> 10, "" -> null),
150+
MapType(StringType, IntegerType, valueContainsNull = true))
151+
val msin = Literal.create(null, MapType(StringType, IntegerType, valueContainsNull = false))
152+
153+
val isLengthOfKey: (Expression, Expression) => Expression = (k, v) => Length(k) === v
154+
155+
checkEvaluation(mapFilter(msi0, isLengthOfKey), Map("abcdf" -> 5, "" -> 0))
156+
checkEvaluation(mapFilter(msi1, isLengthOfKey), Map("abcdf" -> 5))
157+
checkEvaluation(mapFilter(msin, isLengthOfKey), null)
158+
159+
val mia0 = Literal.create(Map(1 -> Seq(0, 1, 2), 2 -> Seq(10), -3 -> Seq(-1, 0, -2, 3)),
160+
MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = false))
161+
val mia1 = Literal.create(Map(1 -> Seq(0, 1, 2), 2 -> null, -3 -> Seq(-1, 0, -2, 3)),
162+
MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = true))
163+
val mian = Literal.create(
164+
null, MapType(IntegerType, ArrayType(IntegerType), valueContainsNull = false))
165+
166+
val customFunc: (Expression, Expression) => Expression = (k, v) => Size(v) + k > 3
167+
168+
checkEvaluation(mapFilter(mia0, customFunc), Map(1 -> Seq(0, 1, 2)))
169+
checkEvaluation(mapFilter(mia1, customFunc), Map(1 -> Seq(0, 1, 2)))
170+
checkEvaluation(mapFilter(mian, customFunc), null)
171+
}
172+
124173
test("ArrayFilter") {
125174
val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
126175
val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1854,6 +1854,52 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
18541854
assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type"))
18551855
}
18561856

1857+
test("map_filter") {
1858+
val dfInts = Seq(
1859+
Map(1 -> 10, 2 -> 20, 3 -> 30),
1860+
Map(1 -> -1, 2 -> -2, 3 -> -3),
1861+
Map(1 -> 10, 2 -> 5, 3 -> -3)).toDF("m")
1862+
1863+
checkAnswer(dfInts.selectExpr(
1864+
"map_filter(m, (k, v) -> k * 10 = v)", "map_filter(m, (k, v) -> k = -v)"),
1865+
Seq(
1866+
Row(Map(1 -> 10, 2 -> 20, 3 -> 30), Map()),
1867+
Row(Map(), Map(1 -> -1, 2 -> -2, 3 -> -3)),
1868+
Row(Map(1 -> 10), Map(3 -> -3))))
1869+
1870+
val dfComplex = Seq(
1871+
Map(1 -> Seq(Some(1)), 2 -> Seq(Some(1), Some(2)), 3 -> Seq(Some(1), Some(2), Some(3))),
1872+
Map(1 -> null, 2 -> Seq(Some(-2), Some(-2)), 3 -> Seq[Option[Int]](None))).toDF("m")
1873+
1874+
checkAnswer(dfComplex.selectExpr(
1875+
"map_filter(m, (k, v) -> k = v[0])", "map_filter(m, (k, v) -> k = size(v))"),
1876+
Seq(
1877+
Row(Map(1 -> Seq(1)), Map(1 -> Seq(1), 2 -> Seq(1, 2), 3 -> Seq(1, 2, 3))),
1878+
Row(Map(), Map(2 -> Seq(-2, -2)))))
1879+
1880+
// Invalid use cases
1881+
val df = Seq(
1882+
(Map(1 -> "a"), 1),
1883+
(Map.empty[Int, String], 2),
1884+
(null, 3)
1885+
).toDF("s", "i")
1886+
1887+
val ex1 = intercept[AnalysisException] {
1888+
df.selectExpr("map_filter(s, (x, y, z) -> x + y + z)")
1889+
}
1890+
assert(ex1.getMessage.contains("The number of lambda function arguments '3' does not match"))
1891+
1892+
val ex2 = intercept[AnalysisException] {
1893+
df.selectExpr("map_filter(s, x -> x)")
1894+
}
1895+
assert(ex2.getMessage.contains("The number of lambda function arguments '1' does not match"))
1896+
1897+
val ex3 = intercept[AnalysisException] {
1898+
df.selectExpr("map_filter(i, (k, v) -> k > v)")
1899+
}
1900+
assert(ex3.getMessage.contains("data type mismatch: argument 1 requires map type"))
1901+
}
1902+
18571903
test("filter function - array for primitive type not containing null") {
18581904
val df = Seq(
18591905
Seq(1, 9, 8, 7),

0 commit comments

Comments
 (0)