Skip to content

Commit 1cbaf0c

Browse files
committed
addressed review comments
1 parent f7fd231 commit 1cbaf0c

File tree

3 files changed

+44
-128
lines changed

3 files changed

+44
-128
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -451,9 +451,9 @@ case class ArrayAggregate(
451451
usage = "_FUNC_(expr, func) - Transforms elements in a map using the function.",
452452
examples = """
453453
Examples:
454-
> SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k, v) -> k + 1);
454+
> SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + 1);
455455
map(array(2, 3, 4), array(1, 2, 3))
456-
> SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k, v) -> k + v);
456+
> SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v);
457457
map(array(2, 4, 6), array(1, 2, 3))
458458
""",
459459
since = "2.4.0")

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.SparkFunSuite
21+
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
2122
import org.apache.spark.sql.types._
2223

2324
class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -60,9 +61,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
6061
}
6162

6263
def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
63-
val valueType = expr.dataType.asInstanceOf[MapType].valueType
64-
val keyType = expr.dataType.asInstanceOf[MapType].keyType
65-
TransformKeys(expr, createLambda(keyType, false, valueType, true, f))
64+
val map = expr.dataType.asInstanceOf[MapType]
65+
TransformKeys(expr, createLambda(map.keyType, false, map.valueType, map.valueContainsNull, f))
6666
}
6767

6868
def aggregate(
@@ -239,35 +239,45 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
239239

240240
test("TransformKeys") {
241241
val ai0 = Literal.create(
242-
Map(1 -> 1, 2 -> 2, 3 -> 3),
243-
MapType(IntegerType, IntegerType))
242+
Map(1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4),
243+
MapType(IntegerType, IntegerType, valueContainsNull = false))
244244
val ai1 = Literal.create(
245245
Map.empty[Int, Int],
246-
MapType(IntegerType, IntegerType))
246+
MapType(IntegerType, IntegerType, valueContainsNull = true))
247247
val ai2 = Literal.create(
248248
Map(1 -> 1, 2 -> null, 3 -> 3),
249-
MapType(IntegerType, IntegerType))
249+
MapType(IntegerType, IntegerType, valueContainsNull = true))
250+
val ai3 = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false))
250251

251252
val plusOne: (Expression, Expression) => Expression = (k, v) => k + 1
252253
val plusValue: (Expression, Expression) => Expression = (k, v) => k + v
254+
val modKey: (Expression, Expression) => Expression = (k, v) => k % 3
253255

254-
checkEvaluation(transformKeys(ai0, plusOne), Map(2 -> 1, 3 -> 2, 4 -> 3))
255-
checkEvaluation(transformKeys(ai0, plusValue), Map(2 -> 1, 4 -> 2, 6 -> 3))
256+
checkEvaluation(transformKeys(ai0, plusOne), Map(2 -> 1, 3 -> 2, 4 -> 3, 5 -> 4))
257+
checkEvaluation(transformKeys(ai0, plusValue), Map(2 -> 1, 4 -> 2, 6 -> 3, 8 -> 4))
256258
checkEvaluation(
257-
transformKeys(transformKeys(ai0, plusOne), plusValue), Map(3 -> 1, 5 -> 2, 7 -> 3))
259+
transformKeys(transformKeys(ai0, plusOne), plusValue), Map(3 -> 1, 5 -> 2, 7 -> 3, 9 -> 4))
260+
checkEvaluation(transformKeys(ai0, modKey),
261+
ArrayBasedMapData(Array(1, 2, 0, 1), Array(1, 2, 3, 4)))
258262
checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int])
259263
checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int])
260264
checkEvaluation(
261265
transformKeys(transformKeys(ai1, plusOne), plusValue), Map.empty[Int, Int])
262266
checkEvaluation(transformKeys(ai2, plusOne), Map(2 -> 1, 3 -> null, 4 -> 3))
263267
checkEvaluation(
264268
transformKeys(transformKeys(ai2, plusOne), plusOne), Map(3 -> 1, 4 -> null, 5 -> 3))
269+
checkEvaluation(transformKeys(ai3, plusOne), null)
265270

266271
val as0 = Literal.create(
267-
Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), MapType(StringType, StringType))
272+
Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"),
273+
MapType(StringType, StringType, valueContainsNull = false))
268274
val as1 = Literal.create(
269-
Map("a" -> "xy", "bb" -> "yz", "ccc" -> null), MapType(StringType, StringType))
270-
val asn = Literal.create(Map.empty[StringType, StringType], MapType(StringType, StringType))
275+
Map("a" -> "xy", "bb" -> "yz", "ccc" -> null),
276+
MapType(StringType, StringType, valueContainsNull = true))
277+
val as2 = Literal.create(null,
278+
MapType(StringType, StringType, valueContainsNull = false))
279+
val asn = Literal.create(Map.empty[StringType, StringType],
280+
MapType(StringType, StringType, valueContainsNull = true))
271281

272282
val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v))
273283
val convertKeyToKeyLength: (Expression, Expression) => Expression =
@@ -286,6 +296,13 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
286296
Map(2 -> "xy", 3 -> "yz", 4 -> "zx"))
287297
checkEvaluation(transformKeys(as1, convertKeyToKeyLength),
288298
Map(2 -> "xy", 3 -> "yz", 4 -> null))
299+
checkEvaluation(transformKeys(as2, convertKeyToKeyLength), null)
289300
checkEvaluation(transformKeys(asn, convertKeyToKeyLength), Map.empty[Int, String])
301+
302+
val ax0 = Literal.create(
303+
Map(1 -> "x", 2 -> "y", 3 -> "z"),
304+
MapType(IntegerType, StringType, valueContainsNull = false))
305+
306+
checkEvaluation(transformKeys(ax0, plusOne), Map(2 -> "x", 3 -> "y", 4 -> "z"))
290307
}
291308
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 12 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -2123,71 +2123,42 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
21232123
).toDF("i")
21242124

21252125
val dfExample2 = Seq(
2126-
Map[Int, String](1 -> "a", 2 -> "b", 3 -> "c")
2127-
).toDF("x")
2128-
2129-
val dfExample3 = Seq(
2130-
Map[String, Int]("a" -> 1, "b" -> 2, "c" -> 3)
2131-
).toDF("y")
2132-
2133-
val dfExample4 = Seq(
21342126
Map[Int, Double](1 -> 1.0E0, 2 -> 1.4E0, 3 -> 1.7E0)
2135-
).toDF("z")
2127+
).toDF("j")
21362128

2137-
val dfExample5 = Seq(
2129+
val dfExample3 = Seq(
21382130
Map[Int, Boolean](25 -> true, 26 -> false)
2139-
).toDF("a")
2140-
2141-
val dfExample6 = Seq(
2142-
Map[Int, String](25 -> "ab", 26 -> "cd")
2143-
).toDF("b")
2131+
).toDF("x")
21442132

2145-
val dfExample7 = Seq(
2133+
val dfExample4 = Seq(
21462134
Map[Array[Int], Boolean](Array(1, 2) -> false)
2147-
).toDF("c")
2135+
).toDF("y")
21482136

21492137

21502138
def testMapOfPrimitiveTypesCombination(): Unit = {
21512139
checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"),
21522140
Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7))))
21532141

2154-
checkAnswer(dfExample2.selectExpr("transform_keys(x, (k, v) -> k + 1)"),
2155-
Seq(Row(Map(2 -> "a", 3 -> "b", 4 -> "c"))))
2156-
2157-
checkAnswer(dfExample3.selectExpr("transform_keys(y, (k, v) -> v * v)"),
2158-
Seq(Row(Map(1 -> 1, 4 -> 2, 9 -> 3))))
2159-
2160-
checkAnswer(dfExample3.selectExpr("transform_keys(y, (k, v) -> length(k) + v)"),
2161-
Seq(Row(Map(2 -> 1, 3 -> 2, 4 -> 3))))
2162-
2163-
checkAnswer(
2164-
dfExample3.selectExpr("transform_keys(y, (k, v) -> concat(k, cast(v as String)))"),
2165-
Seq(Row(Map("a1" -> 1, "b2" -> 2, "c3" -> 3))))
2166-
2167-
checkAnswer(dfExample4.selectExpr("transform_keys(z, " +
2142+
checkAnswer(dfExample2.selectExpr("transform_keys(j, " +
21682143
"(k, v) -> map_from_arrays(ARRAY(1, 2, 3), ARRAY('one', 'two', 'three'))[k])"),
21692144
Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7))))
21702145

2171-
checkAnswer(dfExample4.selectExpr("transform_keys(z, (k, v) -> CAST(v * 2 AS BIGINT) + k)"),
2146+
checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> CAST(v * 2 AS BIGINT) + k)"),
21722147
Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7))))
21732148

2174-
checkAnswer(dfExample4.selectExpr("transform_keys(z, (k, v) -> k + v)"),
2149+
checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> k + v)"),
21752150
Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7))))
21762151

2177-
checkAnswer(dfExample5.selectExpr("transform_keys(a, (k, v) -> k % 2 = 0 OR v)"),
2152+
checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"),
21782153
Seq(Row(Map(true -> true, true -> false))))
21792154

2180-
checkAnswer(dfExample5.selectExpr("transform_keys(a, (k, v) -> if(v, 2 * k, 3 * k))"),
2155+
checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"),
21812156
Seq(Row(Map(50 -> true, 78 -> false))))
21822157

2183-
checkAnswer(dfExample5.selectExpr("transform_keys(a, (k, v) -> if(v, 2 * k, 3 * k))"),
2158+
checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"),
21842159
Seq(Row(Map(50 -> true, 78 -> false))))
21852160

2186-
checkAnswer(dfExample6.selectExpr(
2187-
"transform_keys(b, (k, v) -> concat(conv(k, 10, 16) , substr(v, 1, 1)))"),
2188-
Seq(Row(Map("19a" -> "ab", "1Ac" -> "cd"))))
2189-
2190-
checkAnswer(dfExample7.selectExpr("transform_keys(c, (k, v) -> array_contains(k, 3) AND v)"),
2161+
checkAnswer(dfExample4.selectExpr("transform_keys(y, (k, v) -> array_contains(k, 3) AND v)"),
21912162
Seq(Row(Map(false -> false))))
21922163
}
21932164
// Test with local relation, the Project will be evaluated without codegen
@@ -2196,52 +2167,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
21962167
dfExample2.cache()
21972168
dfExample3.cache()
21982169
dfExample4.cache()
2199-
dfExample5.cache()
2200-
dfExample6.cache()
22012170
// Test with cached relation, the Project will be evaluated with codegen
22022171
testMapOfPrimitiveTypesCombination()
22032172
}
22042173

2205-
test("transform keys function - test empty") {
2206-
val dfExample1 = Seq(
2207-
Map.empty[Int, Int]
2208-
).toDF("i")
2209-
2210-
val dfExample2 = Seq(
2211-
Map.empty[BigInt, String]
2212-
).toDF("j")
2213-
2214-
def testEmpty(): Unit = {
2215-
checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> NULL)"),
2216-
Seq(Row(Map.empty[Null, Null])))
2217-
2218-
checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k)"),
2219-
Seq(Row(Map.empty[Null, Null])))
2220-
2221-
checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> v)"),
2222-
Seq(Row(Map.empty[Null, Null])))
2223-
2224-
checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> 0)"),
2225-
Seq(Row(Map.empty[Int, Null])))
2226-
2227-
checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> 'key')"),
2228-
Seq(Row(Map.empty[String, Null])))
2229-
2230-
checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> true)"),
2231-
Seq(Row(Map.empty[Boolean, Null])))
2232-
2233-
checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> k + cast(v as BIGINT))"),
2234-
Seq(Row(Map.empty[BigInt, Null])))
2235-
2236-
checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> v)"),
2237-
Seq(Row(Map())))
2238-
}
2239-
testEmpty()
2240-
dfExample1.cache()
2241-
dfExample2.cache()
2242-
testEmpty()
2243-
}
2244-
22452174
test("transform keys function - Invalid lambda functions and exceptions") {
22462175
val dfExample1 = Seq(
22472176
Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7)
@@ -2279,36 +2208,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
22792208
testInvalidLambdaFunctions()
22802209
}
22812210

2282-
test("transform keys function - test null") {
2283-
val dfExample1 = Seq(
2284-
Map[Boolean, Integer](true -> 1, false -> null)
2285-
).toDF("a")
2286-
2287-
def testNullValues(): Unit = {
2288-
checkAnswer(dfExample1.selectExpr("transform_keys(a, (k, v) -> if(k, NOT k, v IS NULL))"),
2289-
Seq(Row(Map(false -> 1, true -> null))))
2290-
}
2291-
2292-
testNullValues()
2293-
dfExample1.cache()
2294-
testNullValues()
2295-
}
2296-
2297-
test("transform keys function - test duplicate keys") {
2298-
val dfExample1 = Seq(
2299-
Map[Int, String](1 -> "a", 2 -> "b", 3 -> "c", 4 -> "d")
2300-
).toDF("a")
2301-
2302-
def testNullValues(): Unit = {
2303-
checkAnswer(dfExample1.selectExpr("transform_keys(a, (k, v) -> k%3)"),
2304-
Seq(Row(Map(1 -> "a", 2 -> "b", 0 -> "c", 1 -> "d"))))
2305-
}
2306-
2307-
testNullValues()
2308-
dfExample1.cache()
2309-
testNullValues()
2310-
}
2311-
23122211
private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
23132212
import DataFrameFunctionsSuite.CodegenFallbackExpr
23142213
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {

0 commit comments

Comments
 (0)