Skip to content

Commit 0a19cc4

Browse files
committed
Added Support for transform_keys function
1 parent 1a5e460 commit 0a19cc4

File tree

6 files changed

+320
-2
lines changed

6 files changed

+320
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,7 @@ object FunctionRegistry {
444444
expression[ArrayTransform]("transform"),
445445
expression[ArrayFilter]("filter"),
446446
expression[ArrayAggregate]("aggregate"),
447+
expression[TransformKeys]("transform_keys"),
447448
CreateStruct.registryEntry,
448449

449450
// misc functions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
2525
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
2626
import org.apache.spark.sql.catalyst.expressions.codegen._
2727
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
28-
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
28+
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
2929
import org.apache.spark.sql.types._
3030

3131
/**
@@ -365,3 +365,69 @@ case class ArrayAggregate(
365365

366366
override def prettyName: String = "aggregate"
367367
}
368+
369+
/**
 * Transforms the keys of a map by applying a lambda function to every (key, value) entry.
 * The values of the map are carried over unchanged.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr, func) - Transforms the keys of a map using the function.",
  examples = """
    Examples:
      > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + 1);
       map(array(2, 3, 4), array(1, 2, 3))
      > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v);
       map(array(2, 4, 6), array(1, 2, 3))
  """,
  since = "2.4.0")
case class TransformKeys(
    input: Expression,
    function: Expression)
  extends ArrayBasedHigherOrderFunction with CodegenFallback {

  /** The result is null exactly when the input map expression evaluates to null. */
  override def nullable: Boolean = input.nullable

  /**
   * The result is a map whose key type is the lambda's result type. The value type and the
   * value nullability are taken from the input map type, because the values are not modified.
   */
  override def dataType: DataType = {
    val mapType = input.dataType.asInstanceOf[MapType]
    // Use the map's own valueContainsNull: `input.nullable` describes whether the map itself
    // can be null, which is unrelated to whether its values can be null.
    MapType(function.dataType, mapType.valueType, mapType.valueContainsNull)
  }

  override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType)

  /**
   * Binds the lambda to two arguments: the key (declared non-nullable) and the value (nullable
   * according to the input map type). Falls back to the default concrete map type when the
   * input's data type is not a map.
   */
  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction):
    TransformKeys = {
    val (keyElementType, valueElementType, valueContainsNull) = input.dataType match {
      case MapType(keyType, valueType, containsNullValue) =>
        (keyType, valueType, containsNullValue)
      case _ =>
        val MapType(keyType, valueType, containsNullValue) = MapType.defaultConcreteType
        (keyType, valueType, containsNullValue)
    }
    copy(function =
      f(function, (keyElementType, false) :: (valueElementType, valueContainsNull) :: Nil))
  }

  /** The lambda variables that receive the current key and value during evaluation. */
  @transient lazy val (keyVar, valueVar) = {
    val LambdaFunction(
      _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function
    (keyVar, valueVar)
  }

  /**
   * Evaluates the input map and, for every entry, evaluates the lambda with the entry's key and
   * value to produce the new key. Returns null when the input map is null.
   */
  override def eval(input: InternalRow): Any = {
    val map = this.input.eval(input).asInstanceOf[MapData]
    if (map == null) {
      null
    } else {
      val f = functionForEval
      // Fetch the backing arrays and the size once instead of on every iteration.
      val keys = map.keyArray()
      val values = map.valueArray()
      val numElements = map.numElements()
      val resultKeys = new GenericArrayData(new Array[Any](numElements))
      var i = 0
      while (i < numElements) {
        keyVar.value.set(keys.get(i, keyVar.dataType))
        valueVar.value.set(values.get(i, valueVar.dataType))
        resultKeys.update(i, f.eval(input))
        i += 1
      }
      // The value array of the input map is reused as-is in the result.
      new ArrayBasedMapData(resultKeys, values)
    }
  }

  override def prettyName: String = "transform_keys"
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,12 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
5959
ArrayFilter(expr, createLambda(at.elementType, at.containsNull, f))
6060
}
6161

62+
/**
 * Builds a [[TransformKeys]] expression over `expr`, wrapping `f` in a two-argument lambda
 * whose key argument is declared non-nullable and whose value argument is declared nullable.
 */
def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
  val mapType = expr.dataType.asInstanceOf[MapType]
  val lambda = createLambda(mapType.keyType, false, mapType.valueType, true, f)
  TransformKeys(expr, lambda)
}
67+
6268
def aggregate(
6369
expr: Expression,
6470
zero: Expression,
@@ -181,4 +187,46 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
181187
(acc, array) => coalesce(aggregate(array, acc, (acc, elem) => acc + elem), acc)),
182188
15)
183189
}
190+
191+
test("TransformKeys") {
  // Integer-keyed maps: one populated, one empty.
  val ai0 = Literal.create(
    Map(1 -> 1, 2 -> 2, 3 -> 3),
    MapType(IntegerType, IntegerType))
  val ai1 = Literal.create(
    Map.empty[Int, Int],
    MapType(IntegerType, IntegerType))

  val plusOne: (Expression, Expression) => Expression = (k, v) => k + 1
  val plusValue: (Expression, Expression) => Expression = (k, v) => k + v

  checkEvaluation(transformKeys(ai0, plusOne), Map(2 -> 1, 3 -> 2, 4 -> 3))
  checkEvaluation(transformKeys(ai0, plusValue), Map(2 -> 1, 4 -> 2, 6 -> 3))
  checkEvaluation(
    transformKeys(transformKeys(ai0, plusOne), plusValue), Map(3 -> 1, 5 -> 2, 7 -> 3))
  // Empty input must stay empty for each lambda, including nested application.
  checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int])
  checkEvaluation(transformKeys(ai1, plusValue), Map.empty[Int, Int])
  checkEvaluation(
    transformKeys(transformKeys(ai1, plusOne), plusValue), Map.empty[Int, Int])

  // String-keyed maps: one populated, one empty.
  val as0 = Literal.create(
    Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), MapType(StringType, StringType))
  val asn = Literal.create(Map.empty[String, String], MapType(StringType, StringType))

  val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v))
  // Changes the key type from string to int.
  val keyLengthPlusOne: (Expression, Expression) => Expression =
    (k, v) => Length(k) + 1

  checkEvaluation(
    transformKeys(as0, concatValue), Map("axy" -> "xy", "bbyz" -> "yz", "ccczx" -> "zx"))
  checkEvaluation(
    transformKeys(transformKeys(as0, concatValue), concatValue),
    Map("axyxy" -> "xy", "bbyzyz" -> "yz", "ccczxzx" -> "zx"))
  checkEvaluation(transformKeys(asn, concatValue), Map.empty[String, String])
  checkEvaluation(
    transformKeys(transformKeys(asn, concatValue), keyLengthPlusOne),
    Map.empty[Int, String])
  checkEvaluation(transformKeys(as0, keyLengthPlusOne),
    Map(2 -> "xy", 3 -> "yz", 4 -> "zx"))
  checkEvaluation(transformKeys(asn, keyLengthPlusOne), Map.empty[Int, String])
}
184232
}

sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,17 @@ select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as
4545

4646
-- Aggregate a null array
4747
select aggregate(cast(null as array<int>), 0, (a, y) -> a + y + 1, a -> a + 2) as v;
48+
49+
create or replace temporary view nested as values
50+
(1, map(1,1,2,2,3,3)),
51+
(2, map(4,4,5,5,6,6))
52+
as t(x, ys);
53+
54+
-- Identity Transform Keys in a map
55+
select transform_keys(ys, (k, v) -> k) as v from nested;
56+
57+
-- Transform Keys in a map by adding constant
58+
select transform_keys(ys, (k, v) -> k + 1) as v from nested;
59+
60+
-- Transform Keys in a map using values
61+
select transform_keys(ys, (k, v) -> k + v) as v from nested;

sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 15
2+
-- Number of queries: 19
33

44

55
-- !query 0
@@ -145,3 +145,40 @@ select aggregate(cast(null as array<int>), 0, (a, y) -> a + y + 1, a -> a + 2) a
145145
struct<v:int>
146146
-- !query 14 output
147147
NULL
148+
149+
150+
-- !query 15
151+
create or replace temporary view nested as values
152+
(1, map(1,1,2,2,3,3)),
153+
(2, map(4,4,5,5,6,6))
154+
as t(x, ys)
155+
-- !query 15 schema
156+
struct<>
157+
-- !query 15 output
158+
159+
160+
-- !query 16
161+
select transform_keys(ys, (k, v) -> k) as v from nested
162+
-- !query 16 schema
163+
struct<v:map<int,int>>
164+
-- !query 16 output
165+
{1:1,2:2,3:3}
166+
{4:4,5:5,6:6}
167+
168+
169+
-- !query 17
170+
select transform_keys(ys, (k, v) -> k + 1) as v from nested
171+
-- !query 17 schema
172+
struct<v:map<int,int>>
173+
-- !query 17 output
174+
{2:1,3:2,4:3}
175+
{5:4,6:5,7:6}
176+
177+
178+
-- !query 18
179+
select transform_keys(ys, (k, v) -> k + v) as v from nested
180+
-- !query 18 schema
181+
struct<v:map<int,int>>
182+
-- !query 18 output
183+
{10:5,12:6,8:4}
184+
{2:1,4:2,6:3}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2071,6 +2071,158 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
20712071
assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type"))
20722072
}
20732073

2074+
test("transform keys function - test various primitive data types combinations") {
  val dfExample1 = Seq(
    Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7)
  ).toDF("i")

  val dfExample2 = Seq(
    Map[Int, String](1 -> "a", 2 -> "b", 3 -> "c")
  ).toDF("x")

  val dfExample3 = Seq(
    Map[String, Int]("a" -> 1, "b" -> 2, "c" -> 3)
  ).toDF("y")

  val dfExample4 = Seq(
    Map[Int, Double](1 -> 1.0E0, 2 -> 1.4E0, 3 -> 1.7E0)
  ).toDF("z")

  val dfExample5 = Seq(
    Map[Int, Boolean](25 -> true, 26 -> false)
  ).toDF("a")

  val dfExample6 = Seq(
    Map[Int, String](25 -> "ab", 26 -> "cd")
  ).toDF("b")

  val dfExample7 = Seq(
    Map[Array[Int], Boolean](Array(1, 2) -> false)
  ).toDF("c")

  // Runs every transform_keys scenario once; invoked for both the local and cached relations.
  def testMapOfPrimitiveTypesCombination(): Unit = {
    checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"),
      Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7))))

    checkAnswer(dfExample2.selectExpr("transform_keys(x, (k, v) -> k + 1)"),
      Seq(Row(Map(2 -> "a", 3 -> "b", 4 -> "c"))))

    checkAnswer(dfExample3.selectExpr("transform_keys(y, (k, v) -> v * v)"),
      Seq(Row(Map(1 -> 1, 4 -> 2, 9 -> 3))))

    checkAnswer(dfExample3.selectExpr("transform_keys(y, (k, v) -> length(k) + v)"),
      Seq(Row(Map(2 -> 1, 3 -> 2, 4 -> 3))))

    checkAnswer(
      dfExample3.selectExpr("transform_keys(y, (k, v) -> concat(k, cast(v as String)))"),
      Seq(Row(Map("a1" -> 1, "b2" -> 2, "c3" -> 3))))

    checkAnswer(dfExample4.selectExpr("transform_keys(z, " +
      "(k, v) -> map_from_arrays(ARRAY(1, 2, 3), ARRAY('one', 'two', 'three'))[k])"),
      Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7))))

    checkAnswer(dfExample4.selectExpr("transform_keys(z, (k, v) -> CAST(v * 2 AS BIGINT) + k)"),
      Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7))))

    checkAnswer(dfExample4.selectExpr("transform_keys(z, (k, v) -> k + v)"),
      Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7))))

    checkAnswer(dfExample5.selectExpr("transform_keys(a, (k, v) -> k % 2 = 0 OR v)"),
      Seq(Row(Map(true -> true, true -> false))))

    checkAnswer(dfExample5.selectExpr("transform_keys(a, (k, v) -> if(v, 2 * k, 3 * k))"),
      Seq(Row(Map(50 -> true, 78 -> false))))

    checkAnswer(dfExample6.selectExpr(
      "transform_keys(b, (k, v) -> concat(conv(k, 10, 16) , substr(v, 1, 1)))"),
      Seq(Row(Map("19a" -> "ab", "1Ac" -> "cd"))))

    checkAnswer(dfExample7.selectExpr("transform_keys(c, (k, v) -> array_contains(k, 3) AND v)"),
      Seq(Row(Map(false -> false))))
  }

  // Test with local relation, the Project will be evaluated without codegen
  testMapOfPrimitiveTypesCombination()
  // Cache every input used above so that no scenario is left out of the second pass.
  Seq(dfExample1, dfExample2, dfExample3, dfExample4, dfExample5, dfExample6, dfExample7)
    .foreach(_.cache())
  // Test with cached relation, the Project will be evaluated with codegen
  testMapOfPrimitiveTypesCombination()
}
2158+
2159+
test("transform keys function - test empty") {
  val emptyIntMapDf = Seq(Map.empty[Int, Int]).toDF("i")
  val emptyBigIntMapDf = Seq(Map.empty[BigInt, String]).toDF("j")

  // Each scenario: (input relation, expression to select, the single expected row).
  val scenarios = Seq(
    (emptyIntMapDf, "transform_keys(i, (k, v) -> NULL)", Row(Map.empty[Null, Null])),
    (emptyIntMapDf, "transform_keys(i, (k, v) -> k)", Row(Map.empty[Null, Null])),
    (emptyIntMapDf, "transform_keys(i, (k, v) -> v)", Row(Map.empty[Null, Null])),
    (emptyIntMapDf, "transform_keys(i, (k, v) -> 0)", Row(Map.empty[Int, Null])),
    (emptyIntMapDf, "transform_keys(i, (k, v) -> 'key')", Row(Map.empty[String, Null])),
    (emptyIntMapDf, "transform_keys(i, (k, v) -> true)", Row(Map.empty[Boolean, Null])),
    (emptyBigIntMapDf, "transform_keys(j, (k, v) -> k + cast(v as BIGINT))",
      Row(Map.empty[BigInt, Null])),
    (emptyBigIntMapDf, "transform_keys(j, (k, v) -> v)", Row(Map()))
  )

  // Checks all scenarios in declaration order.
  def testEmpty(): Unit = {
    scenarios.foreach { case (df, expr, expectedRow) =>
      checkAnswer(df.selectExpr(expr), Seq(expectedRow))
    }
  }

  // First pass against the uncached relations, second pass after caching them.
  testEmpty()
  emptyIntMapDf.cache()
  emptyBigIntMapDf.cache()
  testEmpty()
}
2198+
2199+
test("transform keys function - Invalid lambda functions") {
  val intMapDf = Seq(Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7)).toDF("i")
  val stringMapDf = Seq(Map[String, String]("a" -> "b")).toDF("j")

  // transform_keys must be given a two-argument lambda; other arities fail analysis.
  def testInvalidLambdaFunctions(): Unit = {
    val oneArgError = intercept[AnalysisException] {
      intMapDf.selectExpr("transform_keys(i, k -> k )")
    }
    assert(
      oneArgError.getMessage.contains(
        "The number of lambda function arguments '1' does not match"))

    val threeArgError = intercept[AnalysisException] {
      stringMapDf.selectExpr("transform_keys(j, (k, v, x) -> k + 1)")
    }
    assert(
      threeArgError.getMessage.contains(
        "The number of lambda function arguments '3' does not match"))
  }

  // Run once before and once after caching the inputs.
  testInvalidLambdaFunctions()
  intMapDf.cache()
  stringMapDf.cache()
  testInvalidLambdaFunctions()
}
2225+
20742226
private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
20752227
import DataFrameFunctionsSuite.CodegenFallbackExpr
20762228
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {

0 commit comments

Comments
 (0)