From 68392e31d86f26663fbb8e5badac82b356081f47 Mon Sep 17 00:00:00 2001 From: codeatri Date: Wed, 8 Aug 2018 11:42:36 -0700 Subject: [PATCH 1/7] Added transform_values function --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/higherOrderFunctions.scala | 58 +++++ .../HigherOrderFunctionsSuite.scala | 61 ++++++ .../inputs/higher-order-functions.sql | 14 ++ .../results/higher-order-functions.sql.out | 39 +++- .../spark/sql/DataFrameFunctionsSuite.scala | 204 ++++++++++++++++++ 6 files changed, 376 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 390debd865ee..fb38a51f3aa0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -445,6 +445,7 @@ object FunctionRegistry { expression[MapFilter]("map_filter"), expression[ArrayFilter]("filter"), expression[ArrayAggregate]("aggregate"), + expression[TransformValues]("transform_values"), CreateStruct.registryEntry, // misc functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index d20673359129..c3d9a2e11c9d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -442,3 +442,61 @@ case class ArrayAggregate( override def prettyName: String = "aggregate" } + +/** + * Transform Values for every entry of the map by applying transform_values function. + * Returns map wth transformed values + */ +@ExpressionDescription( +usage = "_FUNC_(expr, func) - Transforms values in the map using the function.", +examples = """ + Examples: + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k,v) -> k + 1); + map(array(1, 2, 3), array(2, 3, 4)) + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k, v) -> k + v); + map(array(1, 2, 3), array(2, 4, 6)) + """, +since = "2.4.0") +case class TransformValues( + input: Expression, + function: Expression) + extends MapBasedSimpleHigherOrderFunction with CodegenFallback { + + override def nullable: Boolean = input.nullable + + override def dataType: DataType = { + val map = input.dataType.asInstanceOf[MapType] + MapType(map.keyType, function.dataType, map.valueContainsNull) + } + + override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType) + + @transient val (keyType, valueType, valueContainsNull) = + HigherOrderFunction.mapKeyValueArgumentType(input.dataType) + + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): + TransformValues = { + copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) + } + + @transient lazy val (keyVar, valueVar) = { + val LambdaFunction( + _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function + (keyVar, valueVar) + } + + override def nullSafeEval(inputRow: InternalRow, value: Any): Any = { + val map = value.asInstanceOf[MapData] + val f = functionForEval + val resultValues = new GenericArrayData(new Array[Any](map.numElements)) + var i = 0 + while (i < map.numElements) { + keyVar.value.set(map.keyArray().get(i, keyVar.dataType)) + valueVar.value.set(map.valueArray().get(i, valueVar.dataType)) + resultValues.update(i, f.eval(inputRow)) + i += 1 + } + new ArrayBasedMapData(map.keyArray(), resultValues) + } + override def prettyName: String = "transform_values" +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index f7e84b875791..8b7cb00f173e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -80,6 +80,12 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper aggregate(expr, zero, merge, identity) } + def transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = { + val valueType = expr.dataType.asInstanceOf[MapType].valueType + val keyType = expr.dataType.asInstanceOf[MapType].keyType + TransformValues(expr, createLambda(keyType, false, valueType, true, f)) + } + test("ArrayTransform") { val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false)) val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true)) @@ -230,4 +236,59 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper (acc, array) => coalesce(aggregate(array, acc, (acc, elem) => acc + elem), acc)), 15) } + + test("TransformValues") { + val ai0 = Literal.create( + Map(1 -> 1, 2 -> 2, 3 -> 3), + MapType(IntegerType, IntegerType)) + val ai1 = Literal.create( + Map(1 -> 1, 2 -> null, 3 -> 3), + MapType(IntegerType, IntegerType)) + val ain = Literal.create( + Map.empty[Int, Int], + MapType(IntegerType, IntegerType)) + + val plusOne: (Expression, Expression) => Expression = (k, v) => v + 1 + val valueUpdate: (Expression, Expression) => Expression = (k, v) => k * k + + checkEvaluation(transformValues(ai0, plusOne), Map(1 -> 2, 2 -> 3, 3 -> 4)) + checkEvaluation(transformValues(ai0, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + checkEvaluation( + transformValues(transformValues(ai0, plusOne), valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + checkEvaluation(transformValues(ai1, plusOne), Map(1 -> 2, 2 -> null, 3 -> 4)) + checkEvaluation(transformValues(ai1, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + checkEvaluation( + transformValues(transformValues(ai1, plusOne), valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + checkEvaluation(transformValues(ain, plusOne), Map.empty[Int, Int]) + + val as0 = Literal.create( + Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), MapType(StringType, StringType)) + val as1 = Literal.create( + Map("a" -> "xy", "bb" -> null, "ccc" -> "zx"), MapType(StringType, StringType)) + val asn = Literal.create(Map.empty[StringType, StringType], MapType(StringType, StringType)) + + val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v)) + val valueTypeUpdate: (Expression, Expression) => Expression = + (k, v) => Length(v) + 1 + + checkEvaluation( + transformValues(as0, concatValue), Map("a" -> "axy", "bb" -> "bbyz", "ccc" -> "ccczx")) + checkEvaluation(transformValues(as0, valueTypeUpdate), + Map("a" -> 3, "bb" -> 3, "ccc" -> 3)) + checkEvaluation( + transformValues(transformValues(as0, concatValue), concatValue), + Map("a" -> "aaxy", "bb" -> "bbbbyz", "ccc" -> "cccccczx")) + checkEvaluation(transformValues(as1, concatValue), + Map("a" -> "axy", "bb" -> null, "ccc" -> "ccczx")) + checkEvaluation(transformValues(as1, valueTypeUpdate), + Map("a" -> 3, "bb" -> null, "ccc" -> 3)) + checkEvaluation( + transformValues(transformValues(as1, concatValue), concatValue), + Map("a" -> "aaxy", "bb" -> null, "ccc" -> "cccccczx")) + checkEvaluation(transformValues(asn, concatValue), Map.empty[String, String]) + checkEvaluation(transformValues(asn, valueTypeUpdate), Map.empty[String, Int]) + checkEvaluation( + transformValues(transformValues(asn, concatValue), valueTypeUpdate), + Map.empty[String, Int]) + } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql index 136396d9553d..4e8d9bc2aa08 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql @@ -45,3 +45,17 @@ select transform(zs, z -> aggregate(z, 1, (acc, val) -> acc * val * size(z))) as -- Aggregate a null array select aggregate(cast(null as array), 0, (a, y) -> a + y + 1, a -> a + 2) as v; + +create or replace temporary view nested as values + (1, map(1,1,2,2,3,3)), + (2, map(4,4,5,5,6,6)) + as t(x, ys); + +-- Identity Transform Keys in a map +select transform_values(ys, (k, v) -> v) as v from nested; + +-- Transform Keys in a map by adding constant +select transform_values(ys, (k, v) -> v + 1) as v from nested; + +-- Transform Keys in a map using values +select transform_values(ys, (k, v) -> k + v) as v from nested; \ No newline at end of file diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out index e6f62f2e1bb6..a109710fcc94 100644 --- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 15 +-- Number of queries: 18 -- !query 0 @@ -145,3 +145,40 @@ select aggregate(cast(null as array), 0, (a, y) -> a + y + 1, a -> a + 2) a struct -- !query 14 output NULL + + +-- !query 15 +create or replace temporary view nested as values + (1, map(1,1,2,2,3,3)), + (2, map(4,4,5,5,6,6)) + as t(x, ys) +-- !query 15 schema +struct<> +-- !query 15 output + + +-- !query 16 +select transform_values(ys, (k, v) -> v) as v from nested +-- !query 16 schema +struct> +-- !query 16 output +{1:1,2:2,3:3} +{4:4,5:5,6:6} + + +-- !query 17 +select transform_values(ys, (k, v) -> v + 1) as v from nested +-- !query 17 schema +struct> +-- !query 17 output +{1:2,2:3,3:4} +{4:5,5:6,6:7} + + +-- !query 18 +select transform_values(ys, (k, v) -> k + v) as v from nested +-- !query 18 schema +struct> +-- !query 18 output +{1:2,2:4,3:6} +{4:8,5:10,6:12} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 24091f212804..830d21e36f98 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2117,6 +2117,210 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type")) } + test("transform values function - test various primitive data types combinations") { + val dfExample1 = Seq( + Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7) + ).toDF("i") + + val dfExample2 = Seq( + Map[Boolean, String](false -> "abc", true -> "def") + ).toDF("x") + + val dfExample3 = Seq( + Map[String, Int]("a" -> 1, "b" -> 2, "c" -> 3) + ).toDF("y") + + val dfExample4 = Seq( + Map[Int, Double](1 -> 1.0E0, 2 -> 1.4E0, 3 -> 1.7E0) + ).toDF("z") + + val dfExample5 = Seq( + Map[Int, Boolean](25 -> true, 26 -> false) + ).toDF("a") + + val dfExample6 = Seq( + Map[Int, String](25 -> "ab", 26 -> "cd") + ).toDF("b") + + val dfExample7 = Seq( + Map[Int, Array[Int]](1 -> Array(1, 2)) + ).toDF("c") + + val dfExample8 = Seq( + Map[Int, Double](25 -> 26.1E0, 26 -> 31.2E0, 27 -> 37.1E0) + ).toDF("d") + + val dfExample10 = Seq( + Map[String, String]("s0" -> "abc", "s1" -> "def") + ).toDF("f") + + def testMapOfPrimitiveTypesCombination(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> k + v)"), + Seq(Row(Map(1 -> 2, 9 -> 18, 8 -> 16, 7 -> 14)))) + + checkAnswer(dfExample2.selectExpr( + "transform_values(x, (k, v) -> if(k, v, CAST(k AS String)))"), + Seq(Row(Map(false -> "false", true -> "def")))) + + checkAnswer(dfExample2.selectExpr("transform_values(x, (k, v) -> NOT k AND v = 'abc')"), + Seq(Row(Map(false -> true, true -> false)))) + + checkAnswer(dfExample3.selectExpr("transform_values(y, (k, v) -> v * v)"), + Seq(Row(Map("a" -> 1, "b" -> 4, "c" -> 9)))) + + checkAnswer(dfExample3.selectExpr( + "transform_values(y, (k, v) -> k || ':' || CAST(v as String))"), + Seq(Row(Map("a" -> "a:1", "b" -> "b:2", "c" -> "c:3")))) + + checkAnswer( + dfExample3.selectExpr("transform_values(y, (k, v) -> concat(k, cast(v as String)))"), + Seq(Row(Map("a" -> "a1", "b" -> "b2", "c" -> "c3")))) + + checkAnswer( + dfExample4.selectExpr( + "transform_values(" + + "z,(k, v) -> map_from_arrays(ARRAY(1, 2, 3), " + + "ARRAY('one', 'two', 'three'))[k] || '_' || CAST(v AS String))"), + Seq(Row(Map(1 -> "one_1.0", 2 -> "two_1.4", 3 ->"three_1.7")))) + + checkAnswer( + dfExample4.selectExpr("transform_values(z, (k, v) -> k-v)"), + Seq(Row(Map(1 -> 0.0, 2 -> 0.6000000000000001, 3 -> 1.3)))) + + checkAnswer( + dfExample5.selectExpr("transform_values(a, (k, v) -> if(v, k + 1, k + 2))"), + Seq(Row(Map(25 -> 26, 26 -> 28)))) + + checkAnswer( + dfExample5.selectExpr("transform_values(a, (k, v) -> v AND k = 25)"), + Seq(Row(Map(25 -> true, 26 -> false)))) + + checkAnswer( + dfExample5.selectExpr("transform_values(a, (k, v) -> v OR k = 26)"), + Seq(Row(Map(25 -> true, 26 -> true)))) + + checkAnswer( + dfExample6.selectExpr("transform_values(b, (k, v) -> k + length(v))"), + Seq(Row(Map(25 -> 27, 26 -> 28)))) + + checkAnswer( + dfExample7.selectExpr("transform_values(c, (k, v) -> k + cardinality(v))"), + Seq(Row(Map(1 -> 3)))) + + checkAnswer( + dfExample8.selectExpr("transform_values(d, (k, v) -> CAST(v - k AS BIGINT))"), + Seq(Row(Map(25 -> 1, 26 -> 5, 27 -> 10)))) + + checkAnswer( + dfExample10.selectExpr("transform_values(f, (k, v) -> k || ':' || v)"), + Seq(Row(Map("s0" -> "s0:abc", "s1" -> "s1:def")))) + } + + // Test with local relation, the Project will be evaluated without codegen + testMapOfPrimitiveTypesCombination() + dfExample1.cache() + dfExample2.cache() + dfExample3.cache() + dfExample4.cache() + dfExample5.cache() + dfExample6.cache() + dfExample7.cache() + dfExample8.cache() + dfExample10.cache() + // Test with cached relation, the Project will be evaluated with codegen + testMapOfPrimitiveTypesCombination() + } + + test("transform values function - test empty") { + val dfExample1 = Seq( + Map.empty[Int, Int] + ).toDF("i") + + val dfExample2 = Seq( + Map.empty[BigInt, String] + ).toDF("j") + + def testEmpty(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> NULL)"), + Seq(Row(Map.empty[Int, Null]))) + + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> k)"), + Seq(Row(Map.empty[Int, Int]))) + + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> v)"), + Seq(Row(Map.empty[Int, Int]))) + + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> 0)"), + Seq(Row(Map.empty[Int, Int]))) + + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> 'value')"), + Seq(Row(Map.empty[Int, String]))) + + checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> true)"), + Seq(Row(Map.empty[Int, Boolean]))) + + checkAnswer(dfExample2.selectExpr("transform_values(j, (k, v) -> k + cast(v as BIGINT))"), + Seq(Row(Map.empty[BigInt, BigInt]))) + } + + testEmpty() + dfExample1.cache() + dfExample2.cache() + testEmpty() + } + + test("transform values function - test null values") { + val dfExample1 = Seq( + Map[Int, Int](1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4) + ).toDF("a") + + val dfExample2 = Seq( + Map[Int, String](1 -> "a", 2 -> "b", 3 -> null) + ).toDF("b") + + def testNullValue(): Unit = { + checkAnswer(dfExample1.selectExpr("transform_values(a, (k, v) -> null)"), + Seq(Row(Map(1 -> null, 2 -> null, 3 -> null, 4 -> null)))) + + checkAnswer(dfExample2.selectExpr( + "transform_values(b, (k, v) -> IF(v IS NULL, k + 1, k + 2))"), + Seq(Row(Map(1 -> 3, 2 -> 4, 3 -> 4)))) + } + + testNullValue() + dfExample1.cache() + dfExample2.cache() + testNullValue() + } + + test("transform values function - test invalid functions") { + val dfExample1 = Seq( + Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7) + ).toDF("i") + + val dfExample2 = Seq( + Map[String, String]("a" -> "b") + ).toDF("j") + + def testInvalidLambdaFunctions(): Unit = { + + val ex1 = intercept[AnalysisException] { + dfExample1.selectExpr("transform_values(i, k -> k )") + } + assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match")) + + val ex2 = intercept[AnalysisException] { + dfExample2.selectExpr("transform_values(j, (k, v, x) -> k + 1)") + } + assert(ex2.getMessage.contains("The number of lambda function arguments '3' does not match")) + } + + testInvalidLambdaFunctions() + dfExample1.cache() + dfExample2.cache() + testInvalidLambdaFunctions() + } + private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { import DataFrameFunctionsSuite.CodegenFallbackExpr for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { From fd056454f52c84ada529c5432b1d34b0c0a66367 Mon Sep 17 00:00:00 2001 From: codeatri Date: Wed, 8 Aug 2018 15:05:54 -0700 Subject: [PATCH 2/7] added test improvements --- .../spark/sql/DataFrameFunctionsSuite.scala | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 830d21e36f98..950b5299aa4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2233,7 +2233,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("transform values function - test empty") { val dfExample1 = Seq( - Map.empty[Int, Int] + Map.empty[Integer, Integer] ).toDF("i") val dfExample2 = Seq( @@ -2242,22 +2242,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { def testEmpty(): Unit = { checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> NULL)"), - Seq(Row(Map.empty[Int, Null]))) + Seq(Row(Map.empty[Integer, Integer]))) checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> k)"), - Seq(Row(Map.empty[Int, Int]))) + Seq(Row(Map.empty[Integer, Integer]))) checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> v)"), - Seq(Row(Map.empty[Int, Int]))) + Seq(Row(Map.empty[Integer, Integer]))) checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> 0)"), - Seq(Row(Map.empty[Int, Int]))) + Seq(Row(Map.empty[Integer, Integer]))) checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> 'value')"), - Seq(Row(Map.empty[Int, String]))) + Seq(Row(Map.empty[Integer, String]))) checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> true)"), - Seq(Row(Map.empty[Int, Boolean]))) + Seq(Row(Map.empty[Integer, Boolean]))) checkAnswer(dfExample2.selectExpr("transform_values(j, (k, v) -> k + cast(v as BIGINT))"), Seq(Row(Map.empty[BigInt, BigInt]))) @@ -2271,7 +2271,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { test("transform values function - test null values") { val dfExample1 = Seq( - Map[Int, Int](1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4) + Map[Int, Integer](1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4) ).toDF("a") val dfExample2 = Seq( @@ -2280,7 +2280,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { def testNullValue(): Unit = { checkAnswer(dfExample1.selectExpr("transform_values(a, (k, v) -> null)"), - Seq(Row(Map(1 -> null, 2 -> null, 3 -> null, 4 -> null)))) + Seq(Row(Map[Int, Integer](1 -> null, 2 -> null, 3 -> null, 4 -> null)))) checkAnswer(dfExample2.selectExpr( "transform_values(b, (k, v) -> IF(v IS NULL, k + 1, k + 2))"), From cdecd32e1cc2253fb85697b389b328ddc554f9b7 Mon Sep 17 00:00:00 2001 From: codeatri Date: Wed, 8 Aug 2018 16:07:16 -0700 Subject: [PATCH 3/7] addressed review comments --- .../catalyst/expressions/higherOrderFunctions.scala | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index c3d9a2e11c9d..6a25baac87d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -444,14 +444,13 @@ case class ArrayAggregate( } /** - * Transform Values for every entry of the map by applying transform_values function. - * Returns map wth transformed values + * Returns a map that applies the function to each value of the map. */ @ExpressionDescription( usage = "_FUNC_(expr, func) - Transforms values in the map using the function.", examples = """ Examples: - > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k,v) -> k + 1); + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k, v) -> v + 1); map(array(1, 2, 3), array(2, 3, 4)) > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k, v) -> k + v); map(array(1, 2, 3), array(2, 4, 6)) @@ -466,16 +465,14 @@ case class TransformValues( override def dataType: DataType = { val map = input.dataType.asInstanceOf[MapType] - MapType(map.keyType, function.dataType, map.valueContainsNull) + MapType(map.keyType, function.dataType, function.nullable) } - override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType) - @transient val (keyType, valueType, valueContainsNull) = HigherOrderFunction.mapKeyValueArgumentType(input.dataType) - override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): - TransformValues = { + override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction) + : TransformValues = { copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) } From b73106d43000972ab9adae3d3b463a0dada2b9cc Mon Sep 17 00:00:00 2001 From: codeatri Date: Tue, 14 Aug 2018 16:39:20 -0700 Subject: [PATCH 4/7] Merge master Refactoring changes --- .../catalyst/expressions/higherOrderFunctions.scala | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 8646872453ec..63cf5fc36a5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -511,19 +511,18 @@ examples = """ """, since = "2.4.0") case class TransformValues( - input: Expression, + argument: Expression, function: Expression) extends MapBasedSimpleHigherOrderFunction with CodegenFallback { - override def nullable: Boolean = input.nullable + override def nullable: Boolean = argument.nullable override def dataType: DataType = { - val map = input.dataType.asInstanceOf[MapType] + val map = argument.dataType.asInstanceOf[MapType] MapType(map.keyType, function.dataType, function.nullable) } - @transient val (keyType, valueType, valueContainsNull) = - HigherOrderFunction.mapKeyValueArgumentType(input.dataType) + @transient val MapType(keyType, valueType, valueContainsNull) = argument.dataType override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction) : TransformValues = { @@ -536,8 +535,8 @@ case class TransformValues( (keyVar, valueVar) } - override def nullSafeEval(inputRow: InternalRow, value: Any): Any = { - val map = value.asInstanceOf[MapData] + override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { + val map = argumentValue.asInstanceOf[MapData] val f = functionForEval val resultValues = new GenericArrayData(new Array[Any](map.numElements)) var i = 0 From daf793599a6da5c11dbc4a6bd6e5dea3e0d47afd Mon Sep 17 00:00:00 2001 From: codeatri Date: Wed, 15 Aug 2018 14:47:46 -0700 Subject: [PATCH 5/7] review comments --- .../expressions/higherOrderFunctions.scala | 21 +++---- .../HigherOrderFunctionsSuite.scala | 44 ++++++++----- .../inputs/higher-order-functions.sql | 4 +- .../results/higher-order-functions.sql.out | 4 +- .../spark/sql/DataFrameFunctionsSuite.scala | 61 +++++-------------- 5 files changed, 53 insertions(+), 81 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 63cf5fc36a5a..5c7fb9a675ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -504,9 +504,9 @@ case class ArrayAggregate( usage = "_FUNC_(expr, func) - Transforms values in the map using the function.", examples = """ Examples: - > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k, v) -> v + 1); + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> v + 1); map(array(1, 2, 3), array(2, 3, 4)) - > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3), (k, v) -> k + v); + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); map(array(1, 2, 3), array(2, 4, 6)) """, since = "2.4.0") @@ -517,33 +517,26 @@ case class TransformValues( override def nullable: Boolean = argument.nullable - override def dataType: DataType = { - val map = argument.dataType.asInstanceOf[MapType] - MapType(map.keyType, function.dataType, function.nullable) - } + @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType - @transient val MapType(keyType, valueType, valueContainsNull) = argument.dataType + override def dataType: DataType = MapType(keyType, function.dataType, valueContainsNull) override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction) : TransformValues = { copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil)) } - @transient lazy val (keyVar, valueVar) = { - val LambdaFunction( - _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function - (keyVar, valueVar) - } + @transient lazy val LambdaFunction( + _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { val map = argumentValue.asInstanceOf[MapData] - val f = functionForEval val resultValues = new GenericArrayData(new Array[Any](map.numElements)) var i = 0 while (i < map.numElements) { keyVar.value.set(map.keyArray().get(i, keyVar.dataType)) valueVar.value.set(map.valueArray().get(i, valueVar.dataType)) - resultValues.update(i, f.eval(inputRow)) + resultValues.update(i, functionForEval.eval(inputRow)) i += 1 } new ArrayBasedMapData(map.keyArray(), resultValues) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 60ed04672cce..8c3ff3ed918a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -96,9 +96,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper } def transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = { - val valueType = expr.dataType.asInstanceOf[MapType].valueType - val keyType = expr.dataType.asInstanceOf[MapType].keyType - TransformValues(expr, createLambda(keyType, false, valueType, true, f)) + val map = expr.dataType.asInstanceOf[MapType] + TransformValues(expr, createLambda(map.keyType, false, map.valueType, map.valueContainsNull, f)) } test("ArrayTransform") { @@ -292,13 +291,14 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper test("TransformValues") { val ai0 = Literal.create( Map(1 -> 1, 2 -> 2, 3 -> 3), - MapType(IntegerType, IntegerType)) + MapType(IntegerType, IntegerType, valueContainsNull = false)) val ai1 = Literal.create( Map(1 -> 1, 2 -> null, 3 -> 3), - MapType(IntegerType, IntegerType)) - val ain = Literal.create( + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val ai2 = Literal.create( Map.empty[Int, Int], - MapType(IntegerType, IntegerType)) + MapType(IntegerType, IntegerType, valueContainsNull = true)) + val ai3 = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false)) val plusOne: (Expression, Expression) => Expression = (k, v) => v + 1 val valueUpdate: (Expression, Expression) => Expression = (k, v) => k * k @@ -311,13 +311,18 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(transformValues(ai1, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) checkEvaluation( transformValues(transformValues(ai1, plusOne), valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) - checkEvaluation(transformValues(ain, plusOne), Map.empty[Int, Int]) + checkEvaluation(transformValues(ai2, plusOne), Map.empty[Int, Int]) + checkEvaluation(transformValues(ai3, plusOne), null) val as0 = Literal.create( - Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), MapType(StringType, StringType)) + Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"), + MapType(StringType, StringType, valueContainsNull = false)) val as1 = Literal.create( - Map("a" -> "xy", "bb" -> null, "ccc" -> "zx"), MapType(StringType, StringType)) - val asn = Literal.create(Map.empty[StringType, StringType], MapType(StringType, StringType)) + Map("a" -> "xy", "bb" -> null, "ccc" -> "zx"), + MapType(StringType, StringType, valueContainsNull = true)) + val as2 = Literal.create(Map.empty[StringType, StringType], + MapType(StringType, StringType, valueContainsNull = true)) + val as3 = Literal.create(null, MapType(StringType, StringType, valueContainsNull = true)) val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v)) val valueTypeUpdate: (Expression, Expression) => Expression = @@ -337,13 +342,20 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation( transformValues(transformValues(as1, concatValue), concatValue), Map("a" -> "aaxy", "bb" -> null, "ccc" -> "cccccczx")) - checkEvaluation(transformValues(asn, concatValue), Map.empty[String, String]) - checkEvaluation(transformValues(asn, valueTypeUpdate), Map.empty[String, Int]) + checkEvaluation(transformValues(as2, concatValue), Map.empty[String, String]) + checkEvaluation(transformValues(as2, valueTypeUpdate), Map.empty[String, Int]) checkEvaluation( - transformValues(transformValues(asn, concatValue), valueTypeUpdate), + transformValues(transformValues(as2, concatValue), valueTypeUpdate), Map.empty[String, Int]) - } - + checkEvaluation(transformValues(as3, concatValue), null) + + val ax0 = Literal.create( + Map(1 -> "x", 2 -> "y", 3 -> "z"), + MapType(IntegerType, StringType, valueContainsNull = false)) + + checkEvaluation(transformValues(ax0, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9)) + } + test("MapZipWith") { def map_zip_with( left: Expression, diff --git a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql index aad5dcbeda30..bdb884ae9ab2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/higher-order-functions.sql @@ -53,8 +53,8 @@ select exists(ys, y -> y > 30) as v from nested; select exists(cast(null as array), y -> y > 30) as v; create or replace temporary view nested as values - (1, map(1,1,2,2,3,3)), - (2, map(4,4,5,5,6,6)) + (1, map(1, 1, 2, 2, 3, 3)), + (2, map(4, 4, 5, 5, 6, 6)) as t(x, ys); -- Identity Transform Keys in a map diff --git a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out index 770e7875e051..06e0e231ec71 100644 --- a/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/higher-order-functions.sql.out @@ -167,8 +167,8 @@ NULL -- !query 17 create or replace temporary view nested as values - (1, map(1,1,2,2,3,3)), - (2, map(4,4,5,5,6,6)) + (1, map(1, 1, 2, 2, 3, 3)), + (2, map(4, 4, 5, 5, 6, 6)) as t(x, ys) -- !query 17 schema struct<> diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 28bfc8c2cf4d..ae5e9424a59e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2302,7 +2302,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { assert(ex5.getMessage.contains("function map_zip_with does not support ordering on type map")) } - test("transform values function - test various primitive data types combinations") { + test("transform values function - test primitive data types") { val dfExample1 = Seq( Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7) ).toDF("i") @@ -2316,29 +2316,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ).toDF("y") val dfExample4 = Seq( - Map[Int, Double](1 -> 1.0E0, 2 -> 1.4E0, 3 -> 1.7E0) + Map[Int, Double](1 -> 1.0, 2 -> 1.40, 3 -> 1.70) ).toDF("z") val dfExample5 = Seq( - Map[Int, Boolean](25 -> true, 26 -> false) - ).toDF("a") - - val dfExample6 = Seq( - Map[Int, String](25 -> "ab", 26 -> "cd") - ).toDF("b") - - val dfExample7 = Seq( Map[Int, Array[Int]](1 -> Array(1, 2)) ).toDF("c") - val dfExample8 = Seq( - Map[Int, Double](25 -> 26.1E0, 26 -> 31.2E0, 27 -> 37.1E0) - ).toDF("d") - - val dfExample10 = Seq( - Map[String, String]("s0" -> "abc", "s1" -> "def") - ).toDF("f") - def testMapOfPrimitiveTypesCombination(): Unit = { checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> k + v)"), Seq(Row(Map(1 -> 2, 9 -> 18, 8 -> 16, 7 -> 14)))) @@ -2373,32 +2357,8 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Seq(Row(Map(1 -> 0.0, 2 -> 0.6000000000000001, 3 -> 1.3)))) checkAnswer( - dfExample5.selectExpr("transform_values(a, (k, v) -> if(v, k + 1, k + 2))"), - Seq(Row(Map(25 -> 26, 26 -> 28)))) - - checkAnswer( - dfExample5.selectExpr("transform_values(a, (k, v) -> v AND k = 25)"), - Seq(Row(Map(25 -> true, 26 -> false)))) - - checkAnswer( - dfExample5.selectExpr("transform_values(a, (k, v) -> v OR k = 26)"), - Seq(Row(Map(25 -> true, 26 -> true)))) - - checkAnswer( - dfExample6.selectExpr("transform_values(b, (k, v) -> k + length(v))"), - Seq(Row(Map(25 -> 27, 26 -> 28)))) - - checkAnswer( - dfExample7.selectExpr("transform_values(c, (k, v) -> k + cardinality(v))"), + dfExample5.selectExpr("transform_values(c, (k, v) -> k + cardinality(v))"), Seq(Row(Map(1 -> 3)))) - - checkAnswer( - dfExample8.selectExpr("transform_values(d, (k, v) -> CAST(v - k AS BIGINT))"), - Seq(Row(Map(25 -> 1, 26 -> 5, 27 -> 10)))) - - checkAnswer( - dfExample10.selectExpr("transform_values(f, (k, v) -> k || ':' || v)"), - Seq(Row(Map("s0" -> "s0:abc", "s1" -> "s1:def")))) } // Test with local relation, the Project will be evaluated without codegen @@ -2408,10 +2368,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { dfExample3.cache() dfExample4.cache() dfExample5.cache() - dfExample6.cache() - dfExample7.cache() - dfExample8.cache() - dfExample10.cache() // Test with cached relation, the Project will be evaluated with codegen testMapOfPrimitiveTypesCombination() } @@ -2487,6 +2443,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { Map[String, String]("a" -> "b") ).toDF("j") + val dfExample3 = Seq( + Seq(1, 2, 3, 4) + ).toDF("x") + def testInvalidLambdaFunctions(): Unit = { val ex1 = intercept[AnalysisException] { @@ -2498,11 +2458,18 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { dfExample2.selectExpr("transform_values(j, (k, v, x) -> k + 1)") } assert(ex2.getMessage.contains("The number of lambda function arguments '3' does not match")) + + val ex3 = intercept[AnalysisException] { + dfExample3.selectExpr("transform_values(x, (k, v) -> k + 1)") + } + assert(ex3.getMessage.contains( + "data type mismatch: argument 1 requires map type")) } testInvalidLambdaFunctions() dfExample1.cache() dfExample2.cache() + dfExample3.cache() testInvalidLambdaFunctions() } From 56d08ef37531f8e25ae2c7fe3996cf7657384a80 Mon Sep 17 00:00:00 2001 From: codeatri Date: Thu, 16 Aug 2018 00:15:37 -0700 Subject: [PATCH 6/7] post review --- .../catalyst/expressions/higherOrderFunctions.scala | 12 ++++++------ .../apache/spark/sql/DataFrameFunctionsSuite.scala | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 5c7fb9a675ae..b91ffd3c9f64 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -501,15 +501,15 @@ case class ArrayAggregate( * Returns a map that applies the function to each value of the map. */ @ExpressionDescription( -usage = "_FUNC_(expr, func) - Transforms values in the map using the function.", -examples = """ + usage = "_FUNC_(expr, func) - Transforms values in the map using the function.", + examples = """ Examples: - > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> v + 1); + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> v + 1); map(array(1, 2, 3), array(2, 3, 4)) - > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); + > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v); map(array(1, 2, 3), array(2, 4, 6)) """, -since = "2.4.0") + since = "2.4.0") case class TransformValues( argument: Expression, function: Expression) @@ -527,7 +527,7 @@ case class TransformValues( } @transient lazy val LambdaFunction( - _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function + _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { val map = argumentValue.asInstanceOf[MapData] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index ae5e9424a59e..7c7ce2ed1a18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -2450,7 +2450,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { def testInvalidLambdaFunctions(): Unit = { val ex1 = intercept[AnalysisException] { - dfExample1.selectExpr("transform_values(i, k -> k )") + dfExample1.selectExpr("transform_values(i, k -> k)") } assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match")) From 3382e1a5396c8e5a94802d92a7106eacf627617c Mon Sep 17 00:00:00 2001 From: codeatri Date: Thu, 16 Aug 2018 13:01:50 -0700 Subject: [PATCH 7/7] adding new line --- .../apache/spark/sql/catalyst/analysis/FunctionRegistry.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 15cdd96a2c2a..77860e1584f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -450,6 +450,7 @@ object FunctionRegistry { expression[TransformKeys]("transform_keys"), expression[MapZipWith]("map_zip_with"), expression[ZipWith]("zip_with"), + CreateStruct.registryEntry, // misc functions