Skip to content

Commit f161409

Browse files
codeatri authored and ueshin committed
[SPARK-23940][SQL] Add transform_values SQL function
## What changes were proposed in this pull request?

This PR adds the `transform_values` function, which applies the given function to each entry of the map and transforms the values.

```sql
> SELECT transform_values(map(array(1, 2, 3), array(1, 2, 3)), (k,v) -> v + 1);
 map(1->2, 2->3, 3->4)
> SELECT transform_values(map(array(1, 2, 3), array(1, 2, 3)), (k,v) -> k + v);
 map(1->2, 2->4, 3->6)
```

## How was this patch tested?

New tests added to `DataFrameFunctionsSuite`, `HigherOrderFunctionsSuite`, and `SQLQueryTestSuite`.

Closes #22045 from codeatri/SPARK-23940.

Authored-by: codeatri <[email protected]>
Signed-off-by: Takuya UESHIN <[email protected]>
1 parent 9251c61 commit f161409

File tree

6 files changed

+332
-3
lines changed

6 files changed

+332
-3
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,7 @@ object FunctionRegistry {
446446
expression[ArrayFilter]("filter"),
447447
expression[ArrayExists]("exists"),
448448
expression[ArrayAggregate]("aggregate"),
449+
expression[TransformValues]("transform_values"),
449450
expression[TransformKeys]("transform_keys"),
450451
expression[MapZipWith]("map_zip_with"),
451452
expression[ZipWith]("zip_with"),

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ case class TransformKeys(
527527
}
528528

529529
@transient lazy val LambdaFunction(
530-
_, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function
530+
_, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function
531531

532532

533533
override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
@@ -550,6 +550,54 @@ case class TransformKeys(
550550
override def prettyName: String = "transform_keys"
551551
}
552552

553+
/**
554+
* Returns a map with the same keys as the input map, where each value is the result of applying
* the function to the corresponding (key, value) pair.
555+
*/
556+
@ExpressionDescription(
557+
usage = "_FUNC_(expr, func) - Transforms values in the map using the function.",
558+
examples = """
559+
Examples:
560+
> SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> v + 1);
561+
map(array(1, 2, 3), array(2, 3, 4))
562+
> SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v);
563+
map(array(1, 2, 3), array(2, 4, 6))
564+
""",
565+
since = "2.4.0")
566+
case class TransformValues(
    argument: Expression,
    function: Expression)
  extends MapBasedSimpleHigherOrderFunction with CodegenFallback {

  // Nullability of the result mirrors the nullability of the input map expression.
  override def nullable: Boolean = argument.nullable

  // Key type, value type and value nullability, destructured from the argument's map type.
  @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType

  // The key type is carried over unchanged; the value type and the value nullability are
  // taken from the lambda function.
  override def dataType: DataType = MapType(keyType, function.dataType, function.nullable)

  // Binds the lambda to two arguments: the key (declared non-nullable) followed by the value
  // (nullable iff the input map's values may be null).
  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction)
    : TransformValues = {
    val lambdaArguments = (keyType, false) :: (valueType, valueContainsNull) :: Nil
    copy(function = f(function, lambdaArguments))
  }

  // The two lambda variables (key first, value second) extracted from the bound function.
  @transient lazy val LambdaFunction(
    _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function

  // Evaluates the lambda once per map entry and pairs the results with the original keys.
  override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
    val inputMap = argumentValue.asInstanceOf[MapData]
    val keys = inputMap.keyArray()
    val values = inputMap.valueArray()
    val numEntries = inputMap.numElements
    val transformed = new GenericArrayData(new Array[Any](numEntries))
    var idx = 0
    while (idx < numEntries) {
      // Expose the current entry to the lambda through its variables before evaluating it.
      keyVar.value.set(keys.get(idx, keyVar.dataType))
      valueVar.value.set(values.get(idx, valueVar.dataType))
      transformed.update(idx, functionForEval.eval(inputRow))
      idx += 1
    }
    // The key array of the input map is reused as-is for the result.
    new ArrayBasedMapData(keys, transformed)
  }

  override def prettyName: String = "transform_values"
}
600+
553601
/**
554602
* Merges two given maps into a single map by applying function to the pair of values with
555603
* the same key.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
101101
aggregate(expr, zero, merge, identity)
102102
}
103103

104+
// Test helper: wraps `f` in a two-argument lambda typed after the key and value types of
// `expr` (key declared non-nullable) and builds the TransformValues expression.
def transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
  val mapType = expr.dataType.asInstanceOf[MapType]
  val lambda = createLambda(
    mapType.keyType, false, mapType.valueType, mapType.valueContainsNull, f)
  TransformValues(expr, lambda)
}
108+
104109
test("ArrayTransform") {
105110
val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
106111
val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
@@ -358,6 +363,74 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
358363
checkEvaluation(transformKeys(ax0, plusOne), Map(2 -> "x", 3 -> "y", 4 -> "z"))
359364
}
360365

366+
test("TransformValues") {
  // Int -> Int inputs: all values non-null, one null value, an empty map, and a null map.
  val ai0 = Literal.create(
    Map(1 -> 1, 2 -> 2, 3 -> 3),
    MapType(IntegerType, IntegerType, valueContainsNull = false))
  val ai1 = Literal.create(
    Map(1 -> 1, 2 -> null, 3 -> 3),
    MapType(IntegerType, IntegerType, valueContainsNull = true))
  val ai2 = Literal.create(
    Map.empty[Int, Int],
    MapType(IntegerType, IntegerType, valueContainsNull = true))
  val ai3 = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false))

  // plusOne depends only on the value; valueUpdate depends only on the key.
  val plusOne: (Expression, Expression) => Expression = (k, v) => v + 1
  val valueUpdate: (Expression, Expression) => Expression = (k, v) => k * k

  checkEvaluation(transformValues(ai0, plusOne), Map(1 -> 2, 2 -> 3, 3 -> 4))
  checkEvaluation(transformValues(ai0, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
  checkEvaluation(
    transformValues(transformValues(ai0, plusOne), valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
  checkEvaluation(transformValues(ai1, plusOne), Map(1 -> 2, 2 -> null, 3 -> 4))
  checkEvaluation(transformValues(ai1, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
  checkEvaluation(
    transformValues(transformValues(ai1, plusOne), valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
  checkEvaluation(transformValues(ai2, plusOne), Map.empty[Int, Int])
  checkEvaluation(transformValues(ai3, plusOne), null)

  // String -> String inputs mirroring the integer cases above.
  val as0 = Literal.create(
    Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"),
    MapType(StringType, StringType, valueContainsNull = false))
  val as1 = Literal.create(
    Map("a" -> "xy", "bb" -> null, "ccc" -> "zx"),
    MapType(StringType, StringType, valueContainsNull = true))
  // The Scala-side map is typed with String, not the Catalyst StringType descriptor; the
  // Catalyst type is supplied separately through MapType.
  val as2 = Literal.create(
    Map.empty[String, String],
    MapType(StringType, StringType, valueContainsNull = true))
  val as3 = Literal.create(null, MapType(StringType, StringType, valueContainsNull = true))

  // concatValue keeps the value type (string); valueTypeUpdate changes it to an integer.
  val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v))
  val valueTypeUpdate: (Expression, Expression) => Expression =
    (k, v) => Length(v) + 1

  checkEvaluation(
    transformValues(as0, concatValue), Map("a" -> "axy", "bb" -> "bbyz", "ccc" -> "ccczx"))
  checkEvaluation(transformValues(as0, valueTypeUpdate),
    Map("a" -> 3, "bb" -> 3, "ccc" -> 3))
  checkEvaluation(
    transformValues(transformValues(as0, concatValue), concatValue),
    Map("a" -> "aaxy", "bb" -> "bbbbyz", "ccc" -> "cccccczx"))
  checkEvaluation(transformValues(as1, concatValue),
    Map("a" -> "axy", "bb" -> null, "ccc" -> "ccczx"))
  checkEvaluation(transformValues(as1, valueTypeUpdate),
    Map("a" -> 3, "bb" -> null, "ccc" -> 3))
  checkEvaluation(
    transformValues(transformValues(as1, concatValue), concatValue),
    Map("a" -> "aaxy", "bb" -> null, "ccc" -> "cccccczx"))
  checkEvaluation(transformValues(as2, concatValue), Map.empty[String, String])
  checkEvaluation(transformValues(as2, valueTypeUpdate), Map.empty[String, Int])
  checkEvaluation(
    transformValues(transformValues(as2, concatValue), valueTypeUpdate),
    Map.empty[String, Int])
  checkEvaluation(transformValues(as3, concatValue), null)

  // Mixed key/value types: the result value type follows the lambda, not the input map.
  val ax0 = Literal.create(
    Map(1 -> "x", 2 -> "y", 3 -> "z"),
    MapType(IntegerType, StringType, valueContainsNull = false))

  checkEvaluation(transformValues(ax0, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
}
433+
361434
test("MapZipWith") {
362435
def map_zip_with(
363436
left: Expression,

sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,12 @@ select transform_keys(ys, (k, v) -> k + 1) as v from nested;
7474

7575
-- Transform Keys in a map using values
7676
select transform_keys(ys, (k, v) -> k + v) as v from nested;
77+
78+
-- Identity Transform values in a map
79+
select transform_values(ys, (k, v) -> v) as v from nested;
80+
81+
-- Transform values in a map by adding constant
82+
select transform_values(ys, (k, v) -> v + 1) as v from nested;
83+
84+
-- Transform values in a map using keys
85+
select transform_values(ys, (k, v) -> k + v) as v from nested;

sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 20
2+
-- Number of queries: 27
33

44

55
-- !query 0
@@ -226,3 +226,30 @@ struct<v:map<int,int>>
226226
-- !query 23 output
227227
{10:5,12:6,8:4}
228228
{2:1,4:2,6:3}
229+
230+
231+
-- !query 24
232+
select transform_values(ys, (k, v) -> v) as v from nested
233+
-- !query 24 schema
234+
struct<v:map<int,int>>
235+
-- !query 24 output
236+
{1:1,2:2,3:3}
237+
{4:4,5:5,6:6}
238+
239+
240+
-- !query 25
241+
select transform_values(ys, (k, v) -> v + 1) as v from nested
242+
-- !query 25 schema
243+
struct<v:map<int,int>>
244+
-- !query 25 output
245+
{1:2,2:3,3:4}
246+
{4:5,5:6,6:7}
247+
248+
249+
-- !query 26
250+
select transform_values(ys, (k, v) -> k + v) as v from nested
251+
-- !query 26 schema
252+
struct<v:map<int,int>>
253+
-- !query 26 output
254+
{1:2,2:4,3:6}
255+
{4:8,5:10,6:12}

0 commit comments

Comments
 (0)