@@ -446,6 +446,7 @@ object FunctionRegistry {
expression[ArrayFilter]("filter"),
expression[ArrayExists]("exists"),
expression[ArrayAggregate]("aggregate"),
expression[TransformKeys]("transform_keys"),
expression[MapZipWith]("map_zip_with"),
CreateStruct.registryEntry,

@@ -497,6 +497,65 @@ case class ArrayAggregate(
override def prettyName: String = "aggregate"
}

/**
 * Transforms the keys of a map by applying the given lambda function to each entry.
 * Returns a map with the transformed keys and the original values.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr, func) - Transforms elements in a map using the function.",
  examples = """
    Examples:
      > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + 1);
       map(array(2, 3, 4), array(1, 2, 3))
      > SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v);
       map(array(2, 4, 6), array(1, 2, 3))
  """,
  since = "2.4.0")
case class TransformKeys(
    argument: Expression,
    function: Expression)
  extends MapBasedSimpleHigherOrderFunction with CodegenFallback {

  override def nullable: Boolean = argument.nullable
Contributor: I think this can be moved to SimpleHigherOrderFunction.

Member: Makes sense. Let's have wrap-up PRs for higher-order functions after the remaining two PRs are merged.
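
A minimal sketch of the suggested refactoring (hypothetical; it assumes SimpleHigherOrderFunction is the shared trait and already declares argument), so that concrete functions such as TransformKeys could drop this override:

trait SimpleHigherOrderFunction extends HigherOrderFunction {
  // The map or array being transformed.
  def argument: Expression

  // Default nullability derived from the argument, as suggested in the comment above.
  override def nullable: Boolean = argument.nullable
}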

  override def dataType: DataType = {
    val map = argument.dataType.asInstanceOf[MapType]
    MapType(function.dataType, map.valueType, map.valueContainsNull)
Member: We can use valueType and valueContainsNull from the following val?

Member: What about this?

  }
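
A possible rewrite along the lines of this thread (a sketch only; it reuses the valueType and valueContainsNull bound by the pattern-matched val just below):

  // Sketch: build the result type from the lambda's key type plus the value
  // type and nullability already extracted from the argument's MapType.
  override def dataType: DataType = MapType(function.dataType, valueType, valueContainsNull)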

  @transient val MapType(keyType, valueType, valueContainsNull) = argument.dataType
Member: lazy val? Also, could you add a test for the case where the argument is not a map to the invalid cases in DataFrameFunctionsSuite?

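A sketch of the requested negative test for the invalid cases in DataFrameFunctionsSuite (hypothetical; the exact analysis error message is an assumption):

    // Sketch: transform_keys over a non-map column should fail analysis.
    val df = Seq(1).toDF("i")
    val ex = intercept[AnalysisException] {
      df.selectExpr("transform_keys(i, (k, v) -> k)")
    }
    assert(ex.getMessage.contains("data type mismatch"))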

  override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): TransformKeys = {
    copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil))
  }

  @transient lazy val (keyVar, valueVar) = {
    val LambdaFunction(
      _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function
    (keyVar, valueVar)
  }
Member: nit: how about:

@transient lazy val LambdaFunction(_,
  (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function

Member: Sorry, I meant we don't need to surround it with:

@transient lazy val (keyVar, valueVar) = {
  ...
  (keyVar, valueVar)
}

just

@transient lazy val LambdaFunction(_,
  (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function

should work.

  override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
    val map = argumentValue.asInstanceOf[MapData]
    val f = functionForEval
Contributor: Can't we use functionForEval directly?

    val resultKeys = new GenericArrayData(new Array[Any](map.numElements))
    var i = 0
    while (i < map.numElements) {
      keyVar.value.set(map.keyArray().get(i, keyVar.dataType))
      valueVar.value.set(map.valueArray().get(i, valueVar.dataType))
      val result = f.eval(inputRow)
      if (result == null) {
Member: nit: extra space between == and null.

        throw new RuntimeException("Cannot use null as map key!")
      }
      resultKeys.update(i, result)
      i += 1
    }
    new ArrayBasedMapData(resultKeys, map.valueArray())
  }

  override def prettyName: String = "transform_keys"
}
Member: nit: indent

/**
 * Merges two given maps into a single map by applying function to the pair of values with
 * the same key.
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types._

class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -74,6 +75,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
ArrayFilter(expr, createLambda(at.elementType, at.containsNull, f))
}

  def transformKeys(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
    val map = expr.dataType.asInstanceOf[MapType]
    TransformKeys(expr, createLambda(map.keyType, false, map.valueType, map.valueContainsNull, f))
  }

def aggregate(
expr: Expression,
zero: Expression,
@@ -283,6 +289,75 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
15)
}

test("TransformKeys") {
val ai0 = Literal.create(
Map(1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4),
MapType(IntegerType, IntegerType, valueContainsNull = false))
val ai1 = Literal.create(
Map.empty[Int, Int],
MapType(IntegerType, IntegerType, valueContainsNull = true))
val ai2 = Literal.create(
Map(1 -> 1, 2 -> null, 3 -> 3),
MapType(IntegerType, IntegerType, valueContainsNull = true))
val ai3 = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false))

val plusOne: (Expression, Expression) => Expression = (k, v) => k + 1
val plusValue: (Expression, Expression) => Expression = (k, v) => k + v
val modKey: (Expression, Expression) => Expression = (k, v) => k % 3

checkEvaluation(transformKeys(ai0, plusOne), Map(2 -> 1, 3 -> 2, 4 -> 3, 5 -> 4))
checkEvaluation(transformKeys(ai0, plusValue), Map(2 -> 1, 4 -> 2, 6 -> 3, 8 -> 4))
checkEvaluation(
transformKeys(transformKeys(ai0, plusOne), plusValue), Map(3 -> 1, 5 -> 2, 7 -> 3, 9 -> 4))
checkEvaluation(transformKeys(ai0, modKey),
ArrayBasedMapData(Array(1, 2, 0, 1), Array(1, 2, 3, 4)))
checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int])
checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int])
checkEvaluation(
transformKeys(transformKeys(ai1, plusOne), plusValue), Map.empty[Int, Int])
checkEvaluation(transformKeys(ai2, plusOne), Map(2 -> 1, 3 -> null, 4 -> 3))
checkEvaluation(
transformKeys(transformKeys(ai2, plusOne), plusOne), Map(3 -> 1, 4 -> null, 5 -> 3))
checkEvaluation(transformKeys(ai3, plusOne), null)

val as0 = Literal.create(
Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"),
MapType(StringType, StringType, valueContainsNull = false))
val as1 = Literal.create(
Map("a" -> "xy", "bb" -> "yz", "ccc" -> null),
MapType(StringType, StringType, valueContainsNull = true))
val as2 = Literal.create(null,
MapType(StringType, StringType, valueContainsNull = false))
val asn = Literal.create(Map.empty[StringType, StringType],
Member: as3? (to match as0, as1, as2)

      MapType(StringType, StringType, valueContainsNull = true))

    val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v))
    val convertKeyToKeyLength: (Expression, Expression) => Expression =
      (k, v) => Length(k) + 1

    checkEvaluation(
      transformKeys(as0, concatValue), Map("axy" -> "xy", "bbyz" -> "yz", "ccczx" -> "zx"))
    checkEvaluation(
      transformKeys(transformKeys(as0, concatValue), concatValue),
      Map("axyxy" -> "xy", "bbyzyz" -> "yz", "ccczxzx" -> "zx"))
    checkEvaluation(transformKeys(asn, concatValue), Map.empty[String, String])
    checkEvaluation(
      transformKeys(transformKeys(asn, concatValue), convertKeyToKeyLength),
      Map.empty[Int, String])
    checkEvaluation(transformKeys(as0, convertKeyToKeyLength),
      Map(2 -> "xy", 3 -> "yz", 4 -> "zx"))
    checkEvaluation(transformKeys(as1, convertKeyToKeyLength),
      Map(2 -> "xy", 3 -> "yz", 4 -> null))
    checkEvaluation(transformKeys(as2, convertKeyToKeyLength), null)
    checkEvaluation(transformKeys(asn, convertKeyToKeyLength), Map.empty[Int, String])

    val ax0 = Literal.create(
      Map(1 -> "x", 2 -> "y", 3 -> "z"),
      MapType(IntegerType, StringType, valueContainsNull = false))

    checkEvaluation(transformKeys(ax0, plusOne), Map(2 -> "x", 3 -> "y", 4 -> "z"))
  }

test("MapZipWith") {
def map_zip_with(
left: Expression,
@@ -51,3 +51,17 @@ select exists(ys, y -> y > 30) as v from nested;

-- Check for element existence in a null array
select exists(cast(null as array<int>), y -> y > 30) as v;

create or replace temporary view nested as values
(1, map(1,1,2,2,3,3)),
(2, map(4,4,5,5,6,6))
Member: nit:

  (1, map(1, 1, 2, 2, 3, 3)),
  (2, map(4, 4, 5, 5, 6, 6))

as t(x, ys);

-- Identity Transform Keys in a map
select transform_keys(ys, (k, v) -> k) as v from nested;

-- Transform Keys in a map by adding constant
select transform_keys(ys, (k, v) -> k + 1) as v from nested;

-- Transform Keys in a map using values
select transform_keys(ys, (k, v) -> k + v) as v from nested;
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 15
-- Number of queries: 20


-- !query 0
@@ -163,3 +163,40 @@ select exists(cast(null as array<int>), y -> y > 30) as v
struct<v:boolean>
-- !query 16 output
NULL


-- !query 17
create or replace temporary view nested as values
(1, map(1,1,2,2,3,3)),
(2, map(4,4,5,5,6,6))
as t(x, ys)
-- !query 17 schema
struct<>
-- !query 17 output


-- !query 18
select transform_keys(ys, (k, v) -> k) as v from nested
-- !query 18 schema
struct<v:map<int,int>>
-- !query 18 output
{1:1,2:2,3:3}
{4:4,5:5,6:6}


-- !query 19
select transform_keys(ys, (k, v) -> k + 1) as v from nested
-- !query 19 schema
struct<v:map<int,int>>
-- !query 19 output
{2:1,3:2,4:3}
{5:4,6:5,7:6}


-- !query 20
select transform_keys(ys, (k, v) -> k + v) as v from nested
-- !query 20 schema
struct<v:map<int,int>>
-- !query 20 output
{10:5,12:6,8:4}
{2:1,4:2,6:3}
@@ -2302,6 +2302,97 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
assert(ex5.getMessage.contains("function map_zip_with does not support ordering on type map"))
}

test("transform keys function - test various primitive data types combinations") {
Member: We don't need so many cases here; we only need to verify that the API works end to end. Evaluation checks of the function should be in HigherOrderFunctionsSuite.

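A minimal end-to-end check along the lines of this comment (a sketch; the DataFrame and expected Row are illustrative only):

    // Sketch: one round trip through the SQL API is enough here; detailed
    // evaluation coverage lives in HigherOrderFunctionsSuite.
    val df = Seq(Map(1 -> 1, 2 -> 2)).toDF("m")
    checkAnswer(
      df.selectExpr("transform_keys(m, (k, v) -> k + 1)"),
      Seq(Row(Map(2 -> 1, 3 -> 2))))
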
    val dfExample1 = Seq(
      Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7)
    ).toDF("i")

    val dfExample2 = Seq(
      Map[Int, Double](1 -> 1.0E0, 2 -> 1.4E0, 3 -> 1.7E0)
Member: Do we need E0?

).toDF("j")

val dfExample3 = Seq(
Map[Int, Boolean](25 -> true, 26 -> false)
).toDF("x")

val dfExample4 = Seq(
Map[Array[Int], Boolean](Array(1, 2) -> false)
).toDF("y")


def testMapOfPrimitiveTypesCombination(): Unit = {
checkAnswer(dfExample1.selectExpr("transform_keys(i, (k, v) -> k + v)"),
Seq(Row(Map(2 -> 1, 18 -> 9, 16 -> 8, 14 -> 7))))

checkAnswer(dfExample2.selectExpr("transform_keys(j, " +
"(k, v) -> map_from_arrays(ARRAY(1, 2, 3), ARRAY('one', 'two', 'three'))[k])"),
Seq(Row(Map("one" -> 1.0, "two" -> 1.4, "three" -> 1.7))))

checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> CAST(v * 2 AS BIGINT) + k)"),
Seq(Row(Map(3 -> 1.0, 4 -> 1.4, 6 -> 1.7))))

checkAnswer(dfExample2.selectExpr("transform_keys(j, (k, v) -> k + v)"),
Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7))))

checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"),
Seq(Row(Map(true -> true, true -> false))))

checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"),
Seq(Row(Map(50 -> true, 78 -> false))))

checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"),
Seq(Row(Map(50 -> true, 78 -> false))))

checkAnswer(dfExample4.selectExpr("transform_keys(y, (k, v) -> array_contains(k, 3) AND v)"),
Seq(Row(Map(false -> false))))
}
// Test with local relation, the Project will be evaluated without codegen
testMapOfPrimitiveTypesCombination()
dfExample1.cache()
dfExample2.cache()
dfExample3.cache()
dfExample4.cache()
// Test with cached relation, the Project will be evaluated with codegen
testMapOfPrimitiveTypesCombination()
Contributor: Do we have to do that if the expression implements CodegenFallback?

  }

test("transform keys function - Invalid lambda functions and exceptions") {
    val dfExample1 = Seq(
      Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7)
    ).toDF("i")

    val dfExample2 = Seq(
      Map[String, String]("a" -> "b")
    ).toDF("j")

    val dfExample3 = Seq(
      Map[String, String]("a" -> null)
    ).toDF("x")

    def testInvalidLambdaFunctions(): Unit = {
      val ex1 = intercept[AnalysisException] {
        dfExample1.selectExpr("transform_keys(i, k -> k )")
Member: nit: extra space after k -> k.

      }
      assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match"))

      val ex2 = intercept[AnalysisException] {
        dfExample2.selectExpr("transform_keys(j, (k, v, x) -> k + 1)")
      }
      assert(ex2.getMessage.contains(
        "The number of lambda function arguments '3' does not match"))
Member: nit: indent

      val ex3 = intercept[RuntimeException] {
        dfExample3.selectExpr("transform_keys(x, (k, v) -> v)").show()
      }
      assert(ex3.getMessage.contains("Cannot use null as map key!"))
Member: It seems we can do these tests with only dfExample3?

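A sketch of the suggested consolidation (hypothetical; all three checks are driven from the String-keyed dfExample3 defined above):

    def testInvalidLambdaFunctions(): Unit = {
      val ex1 = intercept[AnalysisException] {
        dfExample3.selectExpr("transform_keys(x, k -> k)")
      }
      assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match"))

      val ex2 = intercept[AnalysisException] {
        dfExample3.selectExpr("transform_keys(x, (k, v, w) -> k)")
      }
      assert(ex2.getMessage.contains("The number of lambda function arguments '3' does not match"))

      val ex3 = intercept[RuntimeException] {
        dfExample3.selectExpr("transform_keys(x, (k, v) -> v)").show()
      }
      assert(ex3.getMessage.contains("Cannot use null as map key!"))
    }
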
    }

    testInvalidLambdaFunctions()
    dfExample1.cache()
    dfExample2.cache()
    testInvalidLambdaFunctions()
Member: We need dfExample3.cache() as well?

Contributor: @ueshin I would like to ask you a generic question regarding higher-order functions. Is it necessary to perform checks with the codegen paths if all the newly added functions extend CodegenFallback? Eventually, is there a plan to add codegen for these functions in the future?

  }

private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {