@@ -446,6 +446,7 @@ object FunctionRegistry {
expression[ArrayFilter]("filter"),
expression[ArrayExists]("exists"),
expression[ArrayAggregate]("aggregate"),
expression[TransformKeys]("transform_keys"),
expression[MapZipWith]("map_zip_with"),
CreateStruct.registryEntry,

@@ -497,6 +497,62 @@ case class ArrayAggregate(
override def prettyName: String = "aggregate"
}

/**
 * Transforms each key of a map by applying the given lambda function to its
 * (key, value) pair. Returns a map with the transformed keys and the original values.
 */
@ExpressionDescription(
usage = "_FUNC_(expr, func) - Transforms elements in a map using the function.",
examples = """
Examples:
> SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + 1);
map(array(2, 3, 4), array(1, 2, 3))
> SELECT _FUNC_(map_from_arrays(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v);
map(array(2, 4, 6), array(1, 2, 3))
""",
since = "2.4.0")
case class TransformKeys(
argument: Expression,
function: Expression)
extends MapBasedSimpleHigherOrderFunction with CodegenFallback {
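// CodegenFallback: this expression is never compiled; both the interpreted and
// codegen paths end up in nullSafeEval below.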

override def nullable: Boolean = argument.nullable
Review comment (Contributor): I think this can be moved to SimpleHigherOrderFunction

Review comment (Member): Makes sense. Let's have wrap-up PRs for higher-order functions after the remaining two PRs are merged.
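For concreteness, a sketch of the hoist suggested above; it assumes SimpleHigherOrderFunction keeps its single argument member, and it illustrates the idea rather than the merged follow-up:

trait SimpleHigherOrderFunction extends HigherOrderFunction {
  def argument: Expression
  // The result is null exactly when the single map/array argument is null,
  // so concrete classes such as TransformKeys would no longer need their own override.
  override def nullable: Boolean = argument.nullable
  // ... remaining members unchanged ...
}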


@transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType

override def dataType: DataType = {
MapType(function.dataType, valueType, valueContainsNull)
Review comment (Contributor): nit: just in one line?

}

override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): TransformKeys = {
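// Map keys can never be null, so the key lambda variable is bound as non-nullable.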
copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil))
}

@transient lazy val LambdaFunction(
_, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function
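// The destructuring above recovers the key and value lambda variables installed by bind().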


override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
val map = argumentValue.asInstanceOf[MapData]
val f = functionForEval
Review comment (Contributor): Can't we use functionForEval directly?

val resultKeys = new GenericArrayData(new Array[Any](map.numElements))
var i = 0
while (i < map.numElements) {
keyVar.value.set(map.keyArray().get(i, keyVar.dataType))
valueVar.value.set(map.valueArray().get(i, valueVar.dataType))
val result = f.eval(inputRow)
if (result == null) {
throw new RuntimeException("Cannot use null as map key!")
}
resultKeys.update(i, result)
i += 1
}
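// Reuse the original value array: only the keys were rebuilt.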
new ArrayBasedMapData(resultKeys, map.valueArray())
}

override def prettyName: String = "transform_keys"
}
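For context, a minimal sketch of exercising the new function end to end once it is registered; it assumes a local SparkSession named spark and is illustrative, not part of this change:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val df = Seq(Map(1 -> "a", 2 -> "b")).toDF("m")
// Shift every key up by one; the values pass through untouched.
df.selectExpr("transform_keys(m, (k, v) -> k + 1)").collect()
// Returns a single Row whose map has keys 2 and 3 with values "a" and "b".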

/**
* Merges two given maps into a single map by applying function to the pair of values with
* the same key.
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types._

class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -74,6 +75,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
ArrayFilter(expr, createLambda(at.elementType, at.containsNull, f))
}

def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
val map = expr.dataType.asInstanceOf[MapType]
TransformKeys(expr, createLambda(map.keyType, false, map.valueType, map.valueContainsNull, f))
}

def aggregate(
expr: Expression,
zero: Expression,
@@ -283,6 +289,75 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
15)
}

test("TransformKeys") {
val ai0 = Literal.create(
Map(1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4),
MapType(IntegerType, IntegerType, valueContainsNull = false))
val ai1 = Literal.create(
Map.empty[Int, Int],
MapType(IntegerType, IntegerType, valueContainsNull = true))
val ai2 = Literal.create(
Map(1 -> 1, 2 -> null, 3 -> 3),
MapType(IntegerType, IntegerType, valueContainsNull = true))
val ai3 = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false))

val plusOne: (Expression, Expression) => Expression = (k, v) => k + 1
val plusValue: (Expression, Expression) => Expression = (k, v) => k + v
val modKey: (Expression, Expression) => Expression = (k, v) => k % 3

checkEvaluation(transformKeys(ai0, plusOne), Map(2 -> 1, 3 -> 2, 4 -> 3, 5 -> 4))
checkEvaluation(transformKeys(ai0, plusValue), Map(2 -> 1, 4 -> 2, 6 -> 3, 8 -> 4))
checkEvaluation(
transformKeys(transformKeys(ai0, plusOne), plusValue), Map(3 -> 1, 5 -> 2, 7 -> 3, 9 -> 4))
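// modKey sends both 1 and 4 to key 1; a Scala Map cannot hold duplicate keys,
// so the expected result below is spelled out as raw ArrayBasedMapData arrays.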
checkEvaluation(transformKeys(ai0, modKey),
ArrayBasedMapData(Array(1, 2, 0, 1), Array(1, 2, 3, 4)))
checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int])
checkEvaluation(
transformKeys(transformKeys(ai1, plusOne), plusValue), Map.empty[Int, Int])
checkEvaluation(transformKeys(ai2, plusOne), Map(2 -> 1, 3 -> null, 4 -> 3))
checkEvaluation(
transformKeys(transformKeys(ai2, plusOne), plusOne), Map(3 -> 1, 4 -> null, 5 -> 3))
checkEvaluation(transformKeys(ai3, plusOne), null)

val as0 = Literal.create(
Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"),
MapType(StringType, StringType, valueContainsNull = false))
val as1 = Literal.create(
Map("a" -> "xy", "bb" -> "yz", "ccc" -> null),
MapType(StringType, StringType, valueContainsNull = true))
val as2 = Literal.create(null,
MapType(StringType, StringType, valueContainsNull = false))
val as3 = Literal.create(Map.empty[String, String],
MapType(StringType, StringType, valueContainsNull = true))

val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v))
val convertKeyToKeyLength: (Expression, Expression) => Expression =
(k, v) => Length(k) + 1

checkEvaluation(
transformKeys(as0, concatValue), Map("axy" -> "xy", "bbyz" -> "yz", "ccczx" -> "zx"))
checkEvaluation(
transformKeys(transformKeys(as0, concatValue), concatValue),
Map("axyxy" -> "xy", "bbyzyz" -> "yz", "ccczxzx" -> "zx"))
checkEvaluation(transformKeys(as3, concatValue), Map.empty[String, String])
checkEvaluation(
transformKeys(transformKeys(as3, concatValue), convertKeyToKeyLength),
Map.empty[Int, String])
checkEvaluation(transformKeys(as0, convertKeyToKeyLength),
Map(2 -> "xy", 3 -> "yz", 4 -> "zx"))
checkEvaluation(transformKeys(as1, convertKeyToKeyLength),
Map(2 -> "xy", 3 -> "yz", 4 -> null))
checkEvaluation(transformKeys(as2, convertKeyToKeyLength), null)
checkEvaluation(transformKeys(as3, convertKeyToKeyLength), Map.empty[Int, String])

val ax0 = Literal.create(
Map(1 -> "x", 2 -> "y", 3 -> "z"),
MapType(IntegerType, StringType, valueContainsNull = false))

checkEvaluation(transformKeys(ax0, plusOne), Map(2 -> "x", 3 -> "y", 4 -> "z"))
}

test("MapZipWith") {
def map_zip_with(
left: Expression,
@@ -51,3 +51,17 @@ select exists(ys, y -> y > 30) as v from nested;

-- Check for element existence in a null array
select exists(cast(null as array<int>), y -> y > 30) as v;

create or replace temporary view nested as values
(1, map(1, 1, 2, 2, 3, 3)),
(2, map(4, 4, 5, 5, 6, 6))
as t(x, ys);

-- Identity transform of keys in a map
select transform_keys(ys, (k, v) -> k) as v from nested;

-- Transform keys in a map by adding a constant
select transform_keys(ys, (k, v) -> k + 1) as v from nested;

-- Transform keys in a map using the values
select transform_keys(ys, (k, v) -> k + v) as v from nested;
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 15
+-- Number of queries: 20


-- !query 0
@@ -163,3 +163,40 @@ select exists(cast(null as array<int>), y -> y > 30) as v
struct<v:boolean>
-- !query 16 output
NULL


-- !query 17
create or replace temporary view nested as values
(1, map(1, 1, 2, 2, 3, 3)),
(2, map(4, 4, 5, 5, 6, 6))
as t(x, ys)
-- !query 17 schema
struct<>
-- !query 17 output


-- !query 18
select transform_keys(ys, (k, v) -> k) as v from nested
-- !query 18 schema
struct<v:map<int,int>>
-- !query 18 output
{1:1,2:2,3:3}
{4:4,5:5,6:6}


-- !query 19
select transform_keys(ys, (k, v) -> k + 1) as v from nested
-- !query 19 schema
struct<v:map<int,int>>
-- !query 19 output
{2:1,3:2,4:3}
{5:4,6:5,7:6}


-- !query 20
select transform_keys(ys, (k, v) -> k + v) as v from nested
-- !query 20 schema
struct<v:map<int,int>>
-- !query 20 output
{10:5,12:6,8:4}
{2:1,4:2,6:3}
@@ -2302,6 +2302,100 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
assert(ex5.getMessage.contains("function map_zip_with does not support ordering on type map"))
}

test("transform keys function - primitive data types") {
val dfExample1 = Seq(
Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7)
).toDF("i")

val dfExample2 = Seq(
Map[Int, Double](1 -> 1.0, 2 -> 1.40, 3 -> 1.70)
).toDF("j")

val dfExample3 = Seq(
Map[Int, Boolean](25 -> true, 26 -> false)
).toDF("x")

val dfExample4 = Seq(
Map[Array[Int], Boolean](Array(1, 2) -> false)
).toDF("y")


def testMapOfPrimitiveTypesCombination(): Unit = {
checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"),
Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7))))

checkAnswer(dfExample2.selectExpr("transform_keys(j, " +
"(k, v) -> map_from_arrays(ARRAY(1, 2, 3), ARRAY('one', 'two', 'three'))[k])"),
Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7))))

checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> CAST(v * 2 AS BIGINT) + k)"),
Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7))))

checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> k + v)"),
Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7))))

checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"),
Seq(Row(Map(true -> true, true -> false))))

checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"),
Seq(Row(Map(50 -> true, 78 -> false))))

checkAnswer(dfExample4.selectExpr("transform_keys(y, (k, v) -> array_contains(k, 3) AND v)"),
Seq(Row(Map(false -> false))))
}
// Test with local relation, the Project will be evaluated without codegen
testMapOfPrimitiveTypesCombination()
dfExample1.cache()
dfExample2.cache()
dfExample3.cache()
dfExample4.cache()
// Test with cached relation, the Project will be evaluated with codegen
testMapOfPrimitiveTypesCombination()
Review comment (Contributor): Do we have to do that if the expression implements CodegenFallback?
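// Note on the comment above: CodegenFallback only skips code generation for the
// expression itself; the surrounding Project is still interpreted for the local
// relation and compiled for the cached one, which is what the two
// testMapOfPrimitiveTypesCombination() calls exercise.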

}

test("transform keys function - Invalid lambda functions and exceptions") {

val dfExample1 = Seq(
Map[String, String]("a" -> null)
).toDF("i")

val dfExample2 = Seq(
Seq(1, 2, 3, 4)
).toDF("j")

def testInvalidLambdaFunctions(): Unit = {
val ex1 = intercept[AnalysisException] {
dfExample1.selectExpr("transform_keys(i, k -> k)")
}
assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match"))

val ex2 = intercept[AnalysisException] {
dfExample1.selectExpr("transform_keys(i, (k, v, x) -> k + 1)")
}
assert(ex2.getMessage.contains(
"The number of lambda function arguments '3' does not match"))

val ex3 = intercept[RuntimeException] {
dfExample1.selectExpr("transform_keys(i, (k, v) -> v)").show()
}
assert(ex3.getMessage.contains("Cannot use null as map key!"))
Review comment (Member): Seems like we can do those tests only with dfExample3?


val ex4 = intercept[AnalysisException] {
dfExample2.selectExpr("transform_keys(j, (k, v) -> k + 1)")
}
assert(ex4.getMessage.contains(
"data type mismatch: argument 1 requires map type"))
}

testInvalidLambdaFunctions()
dfExample1.cache()
dfExample2.cache()
testInvalidLambdaFunctions()
Review comment (Member): We need dfExample3.cache() as well?

Review comment (Contributor): @ueshin I would like to ask you a generic question regarding higher-order functions. Is it necessary to perform checks with codegen paths if all the newly added functions extend from CodegenFallback? Eventually, is there a plan to add codegen for these functions in the future?

}

private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {