Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,7 @@ object FunctionRegistry {
expression[MapFilter]("map_filter"),
expression[ArrayFilter]("filter"),
expression[ArrayAggregate]("aggregate"),
expression[ArraysZipWith]("zip_with"),
CreateStruct.registryEntry,

// misc functions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,3 +442,91 @@ case class ArrayAggregate(

override def prettyName: String = "aggregate"
}

// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = "_FUNC_(left, right, func) - Merges the two given arrays, element-wise, into a single array using function. If one array is shorter, nulls are appended at the end to match the length of the longer array, before applying function.",
  examples = """
    Examples:
      > SELECT _FUNC_(array(1, 2, 3), array('a', 'b', 'c'), (x, y) -> (y, x));
       array(('a', 1), ('b', 2), ('c', 3))
      > SELECT _FUNC_(array(1, 2), array(3, 4), (x, y) -> x + y);
       array(4, 6)
      > SELECT _FUNC_(array('a', 'b', 'c'), array('d', 'e', 'f'), (x, y) -> concat(x, y));
       array('ad', 'be', 'cf')
  """,
  since = "2.4.0")
// scalastyle:on line.size.limit
case class ArraysZipWith(
left: Expression,
right: Expression,
function: Expression)
extends HigherOrderFunction with CodegenFallback with ExpectsInputTypes {

/** The two array arguments, in call order. */
override def inputs: Seq[Expression] = left :: right :: Nil

/** The lambda applied to each pair of elements. */
override def functions: Seq[Expression] = function :: Nil

/** Type expected for the function argument; any data type is accepted. */
def expectingFunctionType: AbstractDataType = AnyDataType
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to define this?

/** First entry of `functionsForEval`; this is what `eval` invokes for every index. */
@transient lazy val functionForEval: Expression = functionsForEval.head

/** Two array arguments followed by the function argument. */
override def inputTypes: Seq[AbstractDataType] =
  List(ArrayType, ArrayType, expectingFunctionType)

/** The result may be null as soon as either input array may be null. */
override def nullable: Boolean = left.nullable || right.nullable

/** An array of the lambda's result type; elements are nullable iff the lambda is. */
override def dataType: ArrayType =
  ArrayType(elementType = function.dataType, containsNull = function.nullable)

/**
 * Binds the two lambda parameters to the element types of `left` and `right`.
 *
 * Both parameters are always declared nullable, whatever the inputs' `containsNull` says:
 * `eval` sets the lambda variable of the shorter array to null for the trailing positions,
 * so either argument can be null even when its source array holds no null element.
 *
 * @param f binder that receives the function together with one (data type, nullable) pair
 *          per lambda argument and returns the bound [[LambdaFunction]]
 * @return a copy of this expression carrying the bound function
 */
override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArraysZipWith = {
  // Element type of an array-typed input; an input of any other type falls back to the
  // element type of `ArrayType.defaultConcreteType`.
  def elementTypeOf(input: Expression): DataType = input.dataType match {
    case ArrayType(elementType, _) => elementType
    case _ =>
      val ArrayType(elementType, _) = ArrayType.defaultConcreteType
      elementType
  }

  copy(function = f(function,
    (elementTypeOf(left), true) :: (elementTypeOf(right), true) :: Nil))
}

// The two lambda parameters of `function`: `arr1Var` receives the element taken from `left`,
// `arr2Var` the one taken from `right`. Fails with a MatchError unless `function` is a
// LambdaFunction with exactly two NamedLambdaVariable arguments.
@transient lazy val (arr1Var, arr2Var) = function match {
  case LambdaFunction(_, (first: NamedLambdaVariable) :: (second: NamedLambdaVariable) :: Nil, _) =>
    (first, second)
}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: the following should work:

@transient lazy val LambdaFunction(_,
  Seq(leftElemVar: NamedLambdaVariable, rightElemVar: NamedLambdaVariable), _) = function


/**
 * Applies the lambda to the elements of `left` and `right` that share an index.
 *
 * The result has as many entries as the longer input; for positions past the end of the
 * shorter input the corresponding lambda variable is set to null.
 *
 * @param input the row both array expressions and the lambda are evaluated against
 * @return a [[GenericArrayData]] with one lambda result per index, or null when either input
 *         array evaluates to null
 */
override def eval(input: InternalRow): Any = {
  val leftArr = left.eval(input).asInstanceOf[ArrayData]
  if (leftArr == null) {
    // A null `left` already decides the result, so `right` is not evaluated at all.
    null
  } else {
    val rightArr = right.eval(input).asInstanceOf[ArrayData]
    if (rightArr == null) {
      null
    } else {
      // The lengths do not change while iterating; read them once.
      val leftLength = leftArr.numElements()
      val rightLength = rightArr.numElements()
      val resultLength = math.max(leftLength, rightLength)
      val f = functionForEval
      val result = new GenericArrayData(new Array[Any](resultLength))
      var i = 0
      while (i < resultLength) {
        if (i < leftLength) {
          arr1Var.value.set(leftArr.get(i, arr1Var.dataType))
        } else {
          arr1Var.value.set(null)
        }
        if (i < rightLength) {
          arr2Var.value.set(rightArr.get(i, arr2Var.dataType))
        } else {
          arr2Var.value.set(null)
        }
        result.update(i, f.eval(input))
        i += 1
      }
      result
    }
  }
}

/** Reports "zip_with" as the pretty name of this expression. */
override def prettyName: String = "zip_with"
}
Original file line number Diff line number Diff line change
Expand Up @@ -2117,6 +2117,65 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type"))
}

test("zip_with function - arrays for primitive type not containing null") {
  // Int/Int pairs: equal lengths, empty arrays and a row where both arrays are null.
  val intPairs = Seq(
    (Seq(9001, 9002, 9003), Seq(4, 5, 6)),
    (Seq(1, 2), Seq(3, 4)),
    (Seq.empty[Int], Seq.empty[Int]),
    (null, null)
  ).toDF("val1", "val2")
  // Int/String pairs: the second row has a shorter right-hand array.
  val mixedPairs = Seq(
    (Seq(1, 2, 3), Seq("a", "b", "c")),
    (Seq(1, 2, 3), Seq("a", "b"))
  ).toDF("val1", "val2")

  val expectedSums = Seq(
    Row(Seq(9005, 9007, 9009)),
    Row(Seq(4, 6)),
    Row(Seq.empty),
    Row(null))
  // The missing third element of the shorter array shows up as null in the struct.
  val expectedSwapped = Seq(
    Row(Seq(Row("a", 1), Row("b", 2), Row("c", 3))),
    Row(Seq(Row("a", 1), Row("b", 2), Row(null, 3))))

  def verifyAll(): Unit = {
    checkAnswer(intPairs.selectExpr("zip_with(val1, val2, (x, y) -> x + y)"), expectedSums)
    checkAnswer(mixedPairs.selectExpr("zip_with(val1, val2, (x, y) -> (y, x))"), expectedSwapped)
  }

  // First pass: local relation, the Project is evaluated without codegen.
  verifyAll()
  // Second pass: cached relations, the Project is evaluated with codegen.
  Seq(intPairs, mixedPairs).foreach(_.cache())
  verifyAll()
}

test("zip_with function - arrays for primitive type containing null") {
  // Boxed Integer elements so that individual array entries can be null.
  val nullableInts = Seq[(Seq[Integer], Seq[Integer])](
    (Seq(9001, null, 9003), Seq(4, 5, 6)),
    (Seq(1, null, 2, 4), Seq(3, 4)),
    (Seq.empty, Seq.empty),
    (null, null)
  ).toDF("val1", "val2")

  // A null operand, or a position that exists in only one of the arrays, yields a null sum.
  val expectedSums = Seq(
    Row(Seq(9005, null, 9009)),
    Row(Seq(4, null, null, null)),
    Row(Seq.empty),
    Row(null))

  def verifySums(): Unit =
    checkAnswer(nullableInts.selectExpr("zip_with(val1, val2, (x, y) -> x + y)"), expectedSums)

  // First pass: local relation, the Project is evaluated without codegen.
  verifySums()
  // Second pass: cached relation, the Project is evaluated with codegen.
  nullableInts.cache()
  verifySums()
}


Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test for invalid cases?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also can you add tests to HigherOrderFunctionsSuite to check more explicit patterns?

private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
import DataFrameFunctionsSuite.CodegenFallbackExpr
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
Expand Down