Skip to content
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,7 @@ object FunctionRegistry {
expression[ArrayFilter]("filter"),
expression[ArrayExists]("exists"),
expression[ArrayAggregate]("aggregate"),
expression[TransformValues]("transform_values"),
expression[MapZipWith]("map_zip_with"),
CreateStruct.registryEntry,

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,53 @@ case class ArrayAggregate(
override def prettyName: String = "aggregate"
}

/**
* Returns a new map with the same keys as the input map, where each value is the result of
* applying the function to the corresponding (key, value) entry.
*/
@ExpressionDescription(
usage = "_FUNC_(expr, func) - Transforms values in the map using the function.",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: indent

examples = """
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto.

Examples:
> SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> v + 1);
map(array(1, 2, 3), array(2, 3, 4))
> SELECT _FUNC_(map(array(1, 2, 3), array(1, 2, 3)), (k, v) -> k + v);
map(array(1, 2, 3), array(2, 4, 6))
""",
since = "2.4.0")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto.

case class TransformValues(
argument: Expression,
function: Expression)
extends MapBasedSimpleHigherOrderFunction with CodegenFallback {

override def nullable: Boolean = argument.nullable

// Destructures the argument's data type into key type, value type and value nullability.
// The pattern only matches when `argument.dataType` is a MapType.
@transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType

// Value nullability comes from the lambda's result, not from the input map: a lambda such as
// `(k, v) -> null` yields null values even when the input map's values are non-nullable.
override def dataType: DataType = MapType(keyType, function.dataType, function.nullable)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the dataType be defined as MapType(keyType, function.dataType, function.nullable)?


/**
 * Binds the lambda by passing `f` the current function together with one
 * (DataType, Boolean) pair per lambda argument: the key type paired with `false`,
 * followed by the value type paired with the input map's `valueContainsNull`.
 */
override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction)
  : TransformValues = {
  val keyArg = (keyType, false)
  val valueArg = (valueType, valueContainsNull)
  copy(function = f(function, keyArg :: valueArg :: Nil))
}

// Extracts the lambda's two variables (key first, then value) from `function`.
// The pattern requires exactly two NamedLambdaVariable arguments.
@transient lazy val LambdaFunction(
_, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: indent


/**
 * Evaluates the lambda once per map entry and returns a new map that reuses the input's
 * key array and holds the lambda results as its values.
 */
override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
  val input = argumentValue.asInstanceOf[MapData]
  val keys = input.keyArray()
  val values = input.valueArray()
  val size = input.numElements
  val transformed = new GenericArrayData(new Array[Any](size))
  var idx = 0
  while (idx < size) {
    // Expose the current entry to the lambda through its variables before evaluating it.
    keyVar.value.set(keys.get(idx, keyVar.dataType))
    valueVar.value.set(values.get(idx, valueVar.dataType))
    transformed.update(idx, functionForEval.eval(inputRow))
    idx += 1
  }
  new ArrayBasedMapData(keys, transformed)
}
override def prettyName: String = "transform_values"
}

/**
* Merges two given maps into a single map by applying function to the pair of values with
* the same key.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,11 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
aggregate(expr, zero, merge, identity)
}

/**
 * Builds a TransformValues expression over `expr`, creating the lambda from `f` with the
 * key type (passed with `false`), value type and `valueContainsNull` of `expr`'s map type.
 */
def transformValues(expr: Expression, f: (Expression, Expression) => Expression): Expression = {
  val mapType = expr.dataType.asInstanceOf[MapType]
  val lambda =
    createLambda(mapType.keyType, false, mapType.valueType, mapType.valueContainsNull, f)
  TransformValues(expr, lambda)
}

test("ArrayTransform") {
val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType, containsNull = false))
val ai1 = Literal.create(Seq[Integer](1, null, 3), ArrayType(IntegerType, containsNull = true))
Expand Down Expand Up @@ -283,6 +288,74 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
15)
}

test("TransformValues") {
  // Integer maps: non-null values, one null value, an empty map, and a null map literal.
  val ai0 = Literal.create(
    Map(1 -> 1, 2 -> 2, 3 -> 3),
    MapType(IntegerType, IntegerType, valueContainsNull = false))
  val ai1 = Literal.create(
    Map(1 -> 1, 2 -> null, 3 -> 3),
    MapType(IntegerType, IntegerType, valueContainsNull = true))
  val ai2 = Literal.create(
    Map.empty[Int, Int],
    MapType(IntegerType, IntegerType, valueContainsNull = true))
  val ai3 = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false))

  // `plusOne` depends only on the value; `valueUpdate` depends only on the key.
  val plusOne: (Expression, Expression) => Expression = (k, v) => v + 1
  val valueUpdate: (Expression, Expression) => Expression = (k, v) => k * k

  checkEvaluation(transformValues(ai0, plusOne), Map(1 -> 2, 2 -> 3, 3 -> 4))
  checkEvaluation(transformValues(ai0, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
  checkEvaluation(
    transformValues(transformValues(ai0, plusOne), valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
  checkEvaluation(transformValues(ai1, plusOne), Map(1 -> 2, 2 -> null, 3 -> 4))
  checkEvaluation(transformValues(ai1, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
  checkEvaluation(
    transformValues(transformValues(ai1, plusOne), valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
  checkEvaluation(transformValues(ai2, plusOne), Map.empty[Int, Int])
  checkEvaluation(transformValues(ai3, plusOne), null)

  // String maps with the same four shapes as the integer inputs above.
  val as0 = Literal.create(
    Map("a" -> "xy", "bb" -> "yz", "ccc" -> "zx"),
    MapType(StringType, StringType, valueContainsNull = false))
  val as1 = Literal.create(
    Map("a" -> "xy", "bb" -> null, "ccc" -> "zx"),
    MapType(StringType, StringType, valueContainsNull = true))
  // The Scala element type is String; StringType is only the Catalyst type passed alongside.
  val as2 = Literal.create(Map.empty[String, String],
    MapType(StringType, StringType, valueContainsNull = true))
  val as3 = Literal.create(null, MapType(StringType, StringType, valueContainsNull = true))

  // `concatValue` keeps the value type (string); `valueTypeUpdate` changes it to an integer.
  val concatValue: (Expression, Expression) => Expression = (k, v) => Concat(Seq(k, v))
  val valueTypeUpdate: (Expression, Expression) => Expression =
    (k, v) => Length(v) + 1

  checkEvaluation(
    transformValues(as0, concatValue), Map("a" -> "axy", "bb" -> "bbyz", "ccc" -> "ccczx"))
  checkEvaluation(transformValues(as0, valueTypeUpdate),
    Map("a" -> 3, "bb" -> 3, "ccc" -> 3))
  checkEvaluation(
    transformValues(transformValues(as0, concatValue), concatValue),
    Map("a" -> "aaxy", "bb" -> "bbbbyz", "ccc" -> "cccccczx"))
  checkEvaluation(transformValues(as1, concatValue),
    Map("a" -> "axy", "bb" -> null, "ccc" -> "ccczx"))
  checkEvaluation(transformValues(as1, valueTypeUpdate),
    Map("a" -> 3, "bb" -> null, "ccc" -> 3))
  checkEvaluation(
    transformValues(transformValues(as1, concatValue), concatValue),
    Map("a" -> "aaxy", "bb" -> null, "ccc" -> "cccccczx"))
  checkEvaluation(transformValues(as2, concatValue), Map.empty[String, String])
  checkEvaluation(transformValues(as2, valueTypeUpdate), Map.empty[String, Int])
  checkEvaluation(
    transformValues(transformValues(as2, concatValue), valueTypeUpdate),
    Map.empty[String, Int])
  checkEvaluation(transformValues(as3, concatValue), null)

  // Mixed key/value types: the key-only lambda replaces string values with integers.
  val ax0 = Literal.create(
    Map(1 -> "x", 2 -> "y", 3 -> "z"),
    MapType(IntegerType, StringType, valueContainsNull = false))

  checkEvaluation(transformValues(ax0, valueUpdate), Map(1 -> 1, 2 -> 4, 3 -> 9))
}

test("MapZipWith") {
def map_zip_with(
left: Expression,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,17 @@ select exists(ys, y -> y > 30) as v from nested;

-- Check for element existence in a null array
select exists(cast(null as array<int>), y -> y > 30) as v;

create or replace temporary view nested as values
(1, map(1, 1, 2, 2, 3, 3)),
(2, map(4, 4, 5, 5, 6, 6))
as t(x, ys);

-- Identity transform of the values in a map
select transform_values(ys, (k, v) -> v) as v from nested;

-- Transform values in a map by adding a constant
select transform_values(ys, (k, v) -> v + 1) as v from nested;

-- Transform values in a map using both keys and values
select transform_values(ys, (k, v) -> k + v) as v from nested;
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 15
-- Number of queries: 20


-- !query 0
Expand Down Expand Up @@ -163,3 +163,40 @@ select exists(cast(null as array<int>), y -> y > 30) as v
struct<v:boolean>
-- !query 16 output
NULL


-- !query 17
create or replace temporary view nested as values
(1, map(1, 1, 2, 2, 3, 3)),
(2, map(4, 4, 5, 5, 6, 6))
as t(x, ys)
-- !query 17 schema
struct<>
-- !query 17 output


-- !query 18
select transform_values(ys, (k, v) -> v) as v from nested
-- !query 18 schema
struct<v:map<int,int>>
-- !query 18 output
{1:1,2:2,3:3}
{4:4,5:5,6:6}


-- !query 19
select transform_values(ys, (k, v) -> v + 1) as v from nested
-- !query 19 schema
struct<v:map<int,int>>
-- !query 19 output
{1:2,2:3,3:4}
{4:5,5:6,6:7}


-- !query 20
select transform_values(ys, (k, v) -> k + v) as v from nested
-- !query 20 schema
struct<v:map<int,int>>
-- !query 20 output
{1:2,2:4,3:6}
{4:8,5:10,6:12}
Original file line number Diff line number Diff line change
Expand Up @@ -2302,6 +2302,177 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
assert(ex5.getMessage.contains("function map_zip_with does not support ordering on type map"))
}

// Checks transform_values over maps with several key/value type combinations. The same set of
// assertions is run twice: once before and once after the input DataFrames are cached.
test("transform values function - test primitive data types") {
// Int -> Int map.
val dfExample1 = Seq(
Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7)
).toDF("i")

// Boolean -> String map.
val dfExample2 = Seq(
Map[Boolean, String](false -> "abc", true -> "def")
).toDF("x")

// String -> Int map.
val dfExample3 = Seq(
Map[String, Int]("a" -> 1, "b" -> 2, "c" -> 3)
).toDF("y")

// Int -> Double map.
val dfExample4 = Seq(
Map[Int, Double](1 -> 1.0, 2 -> 1.40, 3 -> 1.70)
).toDF("z")

// Int -> Array[Int] map.
val dfExample5 = Seq(
Map[Int, Array[Int]](1 -> Array(1, 2))
).toDF("c")

// All assertions live in one helper so they can be repeated after caching.
def testMapOfPrimitiveTypesCombination(): Unit = {
checkAnswer(dfExample1.selectExpr("transform_values(i, (k, v) -> k + v)"),
Seq(Row(Map(1 -> 2, 9 -> 18, 8 -> 16, 7 -> 14))))

checkAnswer(dfExample2.selectExpr(
"transform_values(x, (k, v) -> if(k, v, CAST(k AS String)))"),
Seq(Row(Map(false -> "false", true -> "def"))))

// The lambda may change the value type: String values become Boolean here.
checkAnswer(dfExample2.selectExpr("transform_values(x, (k, v) -> NOT k AND v = 'abc')"),
Seq(Row(Map(false -> true, true -> false))))

checkAnswer(dfExample3.selectExpr("transform_values(y, (k, v) -> v * v)"),
Seq(Row(Map("a" -> 1, "b" -> 4, "c" -> 9))))

checkAnswer(dfExample3.selectExpr(
"transform_values(y, (k, v) -> k || ':' || CAST(v as String))"),
Seq(Row(Map("a" -> "a:1", "b" -> "b:2", "c" -> "c:3"))))

checkAnswer(
dfExample3.selectExpr("transform_values(y, (k, v) -> concat(k, cast(v as String)))"),
Seq(Row(Map("a" -> "a1", "b" -> "b2", "c" -> "c3"))))

// The lambda looks the key up in a map literal built inside the expression.
checkAnswer(
dfExample4.selectExpr(
"transform_values(" +
"z,(k, v) -> map_from_arrays(ARRAY(1, 2, 3), " +
"ARRAY('one', 'two', 'three'))[k] || '_' || CAST(v AS String))"),
Seq(Row(Map(1 -> "one_1.0", 2 -> "two_1.4", 3 ->"three_1.7"))))

// Expected values are the exact doubles produced by the subtraction.
checkAnswer(
dfExample4.selectExpr("transform_values(z, (k, v) -> k-v)"),
Seq(Row(Map(1 -> 0.0, 2 -> 0.6000000000000001, 3 -> 1.3))))

checkAnswer(
dfExample5.selectExpr("transform_values(c, (k, v) -> k + cardinality(v))"),
Seq(Row(Map(1 -> 3))))
}

// Test with local relation, the Project will be evaluated without codegen
testMapOfPrimitiveTypesCombination()
dfExample1.cache()
dfExample2.cache()
dfExample3.cache()
dfExample4.cache()
dfExample5.cache()
// Test with cached relation, the Project will be evaluated with codegen
testMapOfPrimitiveTypesCombination()
}

test("transform values function - test empty") {
  val emptyIntMapDf = Seq(
    Map.empty[Integer, Integer]
  ).toDF("i")

  val emptyBigIntMapDf = Seq(
    Map.empty[BigInt, String]
  ).toDF("j")

  // Lambda bodies applied to column `i`, each paired with the empty map expected back.
  val intColumnCases: Seq[(String, Map[Integer, Any])] = Seq(
    "transform_values(i, (k, v) -> NULL)" -> Map.empty[Integer, Integer],
    "transform_values(i, (k, v) -> k)" -> Map.empty[Integer, Integer],
    "transform_values(i, (k, v) -> v)" -> Map.empty[Integer, Integer],
    "transform_values(i, (k, v) -> 0)" -> Map.empty[Integer, Integer],
    "transform_values(i, (k, v) -> 'value')" -> Map.empty[Integer, String],
    "transform_values(i, (k, v) -> true)" -> Map.empty[Integer, Boolean])

  def testEmpty(): Unit = {
    intColumnCases.foreach { case (sqlExpr, expected) =>
      checkAnswer(emptyIntMapDf.selectExpr(sqlExpr), Seq(Row(expected)))
    }

    checkAnswer(
      emptyBigIntMapDf.selectExpr("transform_values(j, (k, v) -> k + cast(v as BIGINT))"),
      Seq(Row(Map.empty[BigInt, BigInt])))
  }

  // Run the checks before and after caching the inputs.
  testEmpty()
  emptyIntMapDf.cache()
  emptyBigIntMapDf.cache()
  testEmpty()
}

test("transform values function - test null values") {
  val integerValuesDf = Seq(
    Map[Int, Integer](1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4)
  ).toDF("a")

  val stringValuesDf = Seq(
    Map[Int, String](1 -> "a", 2 -> "b", 3 -> null)
  ).toDF("b")

  def testNullValue(): Unit = {
    // A lambda that always returns null replaces every value with null.
    val allNulls = Map[Int, Integer](1 -> null, 2 -> null, 3 -> null, 4 -> null)
    checkAnswer(
      integerValuesDf.selectExpr("transform_values(a, (k, v) -> null)"),
      Seq(Row(allNulls)))

    // A null input value is visible to the lambda and can be branched on.
    checkAnswer(
      stringValuesDf.selectExpr("transform_values(b, (k, v) -> IF(v IS NULL, k + 1, k + 2))"),
      Seq(Row(Map(1 -> 3, 2 -> 4, 3 -> 4))))
  }

  // Run the checks before and after caching the inputs.
  testNullValue()
  integerValuesDf.cache()
  stringValuesDf.cache()
  testNullValue()
}

test("transform values function - test invalid functions") {
val dfExample1 = Seq(
Map[Int, Int](1 -> 1, 9 -> 9, 8 -> 8, 7 -> 7)
).toDF("i")

val dfExample2 = Seq(
Map[String, String]("a" -> "b")
).toDF("j")

val dfExample3 = Seq(
Seq(1, 2, 3, 4)
).toDF("x")

def testInvalidLambdaFunctions(): Unit = {

val ex1 = intercept[AnalysisException] {
dfExample1.selectExpr("transform_values(i, k -> k )")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: remove an extra space after k -> k.

}
assert(ex1.getMessage.contains("The number of lambda function arguments '1' does not match"))

val ex2 = intercept[AnalysisException] {
dfExample2.selectExpr("transform_values(j, (k, v, x) -> k + 1)")
}
assert(ex2.getMessage.contains("The number of lambda function arguments '3' does not match"))

val ex3 = intercept[AnalysisException] {
dfExample3.selectExpr("transform_values(x, (k, v) -> k + 1)")
}
assert(ex3.getMessage.contains(
"data type mismatch: argument 1 requires map type"))
}

testInvalidLambdaFunctions()
dfExample1.cache()
dfExample2.cache()
dfExample3.cache()
testInvalidLambdaFunctions()
}

private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
Expand Down