-
Notifications
You must be signed in to change notification settings - Fork 29.2k
[SPARK-23932][SQL] Higher order function zip_with #22031
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
03d19ce
6f91777
cc0752a
f20d646
f8c0320
14ef371
c7e2dee
35d2cbc
d6c44a6
92cb34a
0342ed9
16516ec
248bccf
2388130
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -442,3 +442,91 @@ case class ArrayAggregate( | |
|
|
||
| override def prettyName: String = "aggregate" | ||
| } | ||
|
|
||
| // scalastyle:off line.size.limit | ||
| @ExpressionDescription( | ||
| usage = "_FUNC_(left, right, func) - Merges the two given arrays, element-wise, into a single array using function. If one array is shorter, nulls are appended at the end to match the length of the longer array, before applying function.", | ||
| examples = """ | ||
| Examples: | ||
| > SELECT _FUNC_(array(1, 2, 3), array('a', 'b', 'c'), (x, y) -> (y, x)); | ||
| array(('a', 1), ('b', 3), ('c', 5)) | ||
| > SELECT _FUNC_(array(1, 2), array(3, 4), (x, y) -> x + y)); | ||
| array(4, 6) | ||
| > SELECT _FUNC_(array('a', 'b', 'c'), array('d', 'e', 'f'), (x, y) -> concat(x, y)); | ||
| array('ad', 'be', 'cf') | ||
| """, | ||
| since = "2.4.0") | ||
| // scalastyle:on line.size.limit | ||
| case class ArraysZipWith( | ||
| left: Expression, | ||
| right: Expression, | ||
| function: Expression) | ||
| extends HigherOrderFunction with CodegenFallback with ExpectsInputTypes { | ||
|
|
||
| override def inputs: Seq[Expression] = List(left, right) | ||
|
|
||
| override def functions: Seq[Expression] = List(function) | ||
|
|
||
| def expectingFunctionType: AbstractDataType = AnyDataType | ||
| @transient lazy val functionForEval: Expression = functionsForEval.head | ||
|
|
||
| override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, ArrayType, expectingFunctionType) | ||
|
|
||
| override def nullable: Boolean = inputs.exists(_.nullable) | ||
|
|
||
| override def dataType: ArrayType = ArrayType(function.dataType, function.nullable) | ||
|
|
||
| override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArraysZipWith = { | ||
| val (leftElementType, leftContainsNull) = left.dataType match { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can utilize
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This comment is not valid anymore. The method has been removed by #22075. |
||
| case ArrayType(elementType, containsNull) => (elementType, containsNull) | ||
| case _ => | ||
| val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType | ||
| (elementType, containsNull) | ||
| } | ||
| val (rightElementType, rightContainsNull) = right.dataType match { | ||
| case ArrayType(elementType, containsNull) => (elementType, containsNull) | ||
| case _ => | ||
| val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType | ||
| (elementType, containsNull) | ||
| } | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now we can do: val ArrayType(leftElementType, leftContainsNull) = left.dataType
val ArrayType(rightElementType, rightContainsNull) = right.dataType |
||
| copy(function = f(function, | ||
| (leftElementType, leftContainsNull) :: (rightElementType, rightContainsNull) :: Nil)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you want to support different size of input arrays (The jira ticket says: "Both arrays must be the same length."), what about the scenario when one array is empty and the second has elements? Shouldn't we use
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we append |
||
| } | ||
|
|
||
| @transient lazy val (arr1Var, arr2Var) = { | ||
| val LambdaFunction(_, | ||
| (arr1Var: NamedLambdaVariable):: (arr2Var: NamedLambdaVariable) :: Nil, _) = function | ||
| (arr1Var, arr2Var) | ||
| } | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: the following should work: @transient lazy val LambdaFunction(_,
Seq(leftElemVar: NamedLambdaVariable, rightElemVar: NamedLambdaVariable), _) = function |
||
|
|
||
| override def eval(input: InternalRow): Any = { | ||
| val leftArr = left.eval(input).asInstanceOf[ArrayData] | ||
| val rightArr = right.eval(input).asInstanceOf[ArrayData] | ||
|
|
||
| if (leftArr == null || rightArr == null) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If |
||
| null | ||
| } else { | ||
| val resultLength = math.max(leftArr.numElements(), rightArr.numElements()) | ||
| val f = functionForEval | ||
| val result = new GenericArrayData(new Array[Any](resultLength)) | ||
| var i = 0 | ||
| while (i < resultLength) { | ||
| if (i < leftArr.numElements()) { | ||
| arr1Var.value.set(leftArr.get(i, arr1Var.dataType)) | ||
| } else { | ||
| arr1Var.value.set(null) | ||
| } | ||
| if(i < rightArr.numElements()) { | ||
| arr2Var.value.set(rightArr.get(i, arr2Var.dataType)) | ||
| } else { | ||
| arr2Var.value.set(null) | ||
| } | ||
| result.update(i, f.eval(input)) | ||
| i += 1 | ||
| } | ||
| result | ||
| } | ||
| } | ||
|
|
||
| override def prettyName: String = "zip_with" | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2117,6 +2117,65 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { | |
| assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type")) | ||
| } | ||
|
|
||
| test("zip_with function - arrays for primitive type not containing null") { | ||
| val df1 = Seq( | ||
| (Seq(9001, 9002, 9003), Seq(4, 5, 6)), | ||
| (Seq(1, 2), Seq(3, 4)), | ||
| (Seq.empty[Int], Seq.empty[Int]), | ||
| (null, null) | ||
| ).toDF("val1", "val2") | ||
| val df2 = Seq( | ||
| (Seq(1, 2, 3), Seq("a", "b", "c")), | ||
| (Seq(1, 2, 3), Seq("a", "b")) | ||
| ).toDF("val1", "val2") | ||
|
|
||
| def testArrayOfPrimitiveTypeNotContainsNull(): Unit = { | ||
| val expectedValue1 = Seq( | ||
| Row(Seq(9005, 9007, 9009)), | ||
| Row(Seq(4, 6)), | ||
| Row(Seq.empty), | ||
| Row(null)) | ||
| checkAnswer(df1.selectExpr("zip_with(val1, val2, (x, y) -> x + y)"), expectedValue1) | ||
|
|
||
| val expectedValue2 = Seq( | ||
| Row(Seq(Row("a", 1), Row("b", 2), Row("c", 3))), | ||
| Row(Seq(Row("a", 1), Row("b", 2), Row(null, 3)))) | ||
| checkAnswer(df2.selectExpr("zip_with(val1, val2, (x, y) -> (y, x))"), expectedValue2) | ||
| } | ||
|
|
||
| // Test with local relation, the Project will be evaluated without codegen | ||
| testArrayOfPrimitiveTypeNotContainsNull() | ||
| // Test with cached relation, the Project will be evaluated with codegen | ||
| df1.cache() | ||
| df2.cache() | ||
| testArrayOfPrimitiveTypeNotContainsNull() | ||
| } | ||
|
|
||
| test("zip_with function - arrays for primitive type containing null") { | ||
| val df1 = Seq[(Seq[Integer], Seq[Integer])]( | ||
| (Seq(9001, null, 9003), Seq(4, 5, 6)), | ||
| (Seq(1, null, 2, 4), Seq(3, 4)), | ||
| (Seq.empty, Seq.empty), | ||
| (null, null) | ||
| ).toDF("val1", "val2") | ||
|
|
||
| def testArrayOfPrimitiveTypeContainsNull(): Unit = { | ||
| val expectedValue1 = Seq( | ||
| Row(Seq(9005, null, 9009)), | ||
| Row(Seq(4, null, null, null)), | ||
| Row(Seq.empty), | ||
| Row(null)) | ||
| checkAnswer(df1.selectExpr("zip_with(val1, val2, (x, y) -> x + y)"), expectedValue1) | ||
| } | ||
|
|
||
| // Test with local relation, the Project will be evaluated without codegen | ||
| testArrayOfPrimitiveTypeContainsNull() | ||
| // Test with cached relation, the Project will be evaluated with codegen | ||
| df1.cache() | ||
| testArrayOfPrimitiveTypeContainsNull() | ||
| } | ||
|
|
||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a test for invalid cases?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also can you add tests to |
||
| private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = { | ||
| import DataFrameFunctionsSuite.CodegenFallbackExpr | ||
| for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't need to define this?