[SPARK-23931][SQL] Adds arrays_zip function to sparksql #21045
Changes from 32 commits
python/pyspark/sql/functions.py

@@ -2394,6 +2394,23 @@ def array_repeat(col, count):
    return Column(sc._jvm.functions.array_repeat(_to_java_column(col), count))


@since(2.4)
def zip(*cols):
    """
    Collection function: Merge two columns into one, such that the M-th element of the N-th
    argument will be the N-th field of the M-th output element.

    :param cols: columns in input

    >>> from pyspark.sql.functions import zip as spark_zip
    >>> df = spark.createDataFrame([([1, 2, 3], [2, 3, 4])], ['vals1', 'vals2'])
    >>> df.select(spark_zip(df.vals1, df.vals2).alias('zipped')).collect()
    [Row(zipped=[Row(vals1=1, vals2=2), Row(vals1=2, vals2=3), Row(vals1=3, vals2=4)])]
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.zip(_to_seq(sc, cols, _to_java_column)))


# ---------------------------- User Defined Function ----------------------------------

class PandasUDFType(object):
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

@@ -128,6 +128,173 @@ case class MapKeys(child: Expression)
  override def prettyName: String = "map_keys"
}

@ExpressionDescription(
  usage = """_FUNC_(a1, a2, ...) - Returns a merged array containing in the N-th position the
    N-th value of each array given.""",
  examples = """
    Examples:
      > SELECT _FUNC_(array(1, 2, 3), array(2, 3, 4));
        [[1, 2], [2, 3], [3, 4]]
      > SELECT _FUNC_(array(1, 2), array(2, 3), array(3, 4));
        [[1, 2, 3], [2, 3, 4]]
  """,
  since = "2.4.0")
case class Zip(children: Seq[Expression]) extends Expression with ExpectsInputTypes {

  override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.length)(ArrayType)

  override def dataType: DataType = ArrayType(mountSchema)
Contributor: Just a quick follow-up question... Under what circumstances can the output array contain null values?

Contributor (Author): At first I tried to define the nullability of the output at runtime, but I thought it was not possible, since I can't eval every result before defining the schema. What do you think?

Contributor: I understand that fields of the nested structs can be null, but the structs themselves should never be null, right?

Contributor (Author): Hmmm, you are correct then; I don't think such a scenario could happen (correctly, at least). Does that mean the dataType should always reject null values?

Member: Yeah, it seems the struct that is the element of the array is never null, so the data type would be ArrayType(mountSchema, containsNull = false).

Contributor: Ok, I will fix it as a part of #21352.

Contributor: The struct can be null if any of the input arrays is null, IIUC. So probably containsNull needs to stay true?

Member: In that case the array itself will be null, and nullable already covers that.

Contributor: Yes, you're right, sorry!
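To make the thread's conclusion concrete, here is a small illustrative sketch (an editorial assumption about what the #21352 follow-up tightens, not code from this PR): the struct fields stay nullable because shorter inputs are padded with null field values, while the struct elements of the result array are never null themselves.

import org.apache.spark.sql.types._

// Sketch: result type of zip(array<int>, array<string>) as discussed above.
// Field names "0" and "1" follow mountSchema's positional fallback.
val elementType = StructType(
  StructField("0", IntegerType, nullable = true) ::  // padded positions become null fields
  StructField("1", StringType, nullable = true) :: Nil)

// The structs themselves are never null, so containsNull can be false.
val resultType = ArrayType(elementType, containsNull = false)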
  override def nullable: Boolean = children.exists(_.nullable)

  private lazy val arrayTypes = children.map(_.dataType.asInstanceOf[ArrayType])

  private lazy val arrayElementTypes = arrayTypes.map(_.elementType)

  def mountSchema: StructType = {
    val fields = children.zip(arrayElementTypes).zipWithIndex.map {
      case ((expr: NamedExpression, elementType), _) =>
        StructField(expr.name, elementType, nullable = true)
      case ((_, elementType), idx) =>
        StructField(idx.toString, elementType, nullable = true)
    }
    StructType(fields)
  }
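  // Illustrative note (editorial): named children keep their column names as
  // struct field names, while non-NamedExpression children fall back to their
  // position, e.g. zip($"a", $"a" + 1) yields array<struct<a: ..., 1: ...>>.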
  @transient lazy val numberOfArrays: Int = children.length

  @transient lazy val genericArrayData = classOf[GenericArrayData].getName

  def emptyInputGenCode(ev: ExprCode): ExprCode = {
    ev.copy(code"""
      |${CodeGenerator.javaType(dataType)} ${ev.value} = new $genericArrayData(new Object[0]);
      |boolean ${ev.isNull} = false;
    """.stripMargin)
  }

  def nonEmptyInputGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val genericInternalRow = classOf[GenericInternalRow].getName
    val arrVals = ctx.freshName("arrVals")
    val arrCardinality = ctx.freshName("arrCardinality")
    val biggestCardinality = ctx.freshName("biggestCardinality")

    val currentRow = ctx.freshName("currentRow")
    val j = ctx.freshName("j")
    val i = ctx.freshName("i")
    val args = ctx.freshName("args")

    val evals = children.map(_.genCode(ctx))
    val getValuesAndCardinalities = evals.zipWithIndex.map { case (eval, index) =>
      s"""
        |if ($biggestCardinality != -1) {
        |  ${eval.code}
        |  if (!${eval.isNull}) {
        |    $arrVals[$index] = ${eval.value};
        |    $arrCardinality[$index] = ${eval.value}.numElements();
        |    $biggestCardinality = Math.max($biggestCardinality, $arrCardinality[$index]);
        |  } else {
        |    $biggestCardinality = -1;
        |  }
        |}
      """.stripMargin
    }

    val splittedGetValuesAndCardinalities = ctx.splitExpressions(
      expressions = getValuesAndCardinalities,
      funcName = "getValuesAndCardinalities",
      returnType = "int",
      makeSplitFunction = body =>
        s"""
          |$body
          |return $biggestCardinality;
        """.stripMargin,
      foldFunctions = _.map(funcCall => s"$biggestCardinality = $funcCall;").mkString("\n"),
      arguments =
        ("ArrayData[]", arrVals) ::
        ("int[]", arrCardinality) ::
        ("int", biggestCardinality) :: Nil)
    val getValueForType = arrayElementTypes.zipWithIndex.map { case (eleType, idx) =>
      val g = CodeGenerator.getValue(s"$arrVals[$idx]", eleType, i)
      s"""
        |if ($i < $arrCardinality[$idx] && !$arrVals[$idx].isNullAt($i)) {
        |  $currentRow[$idx] = $g;
        |} else {
        |  $currentRow[$idx] = null;
        |}
      """.stripMargin
    }

    val getValueForTypeSplitted = ctx.splitExpressions(
      expressions = getValueForType,
      funcName = "extractValue",
      arguments =
        ("int", i) ::
        ("Object[]", currentRow) ::
        ("int[]", arrCardinality) ::
        ("ArrayData[]", arrVals) :: Nil)

    val initVariables = s"""
      |ArrayData[] $arrVals = new ArrayData[$numberOfArrays];
      |int[] $arrCardinality = new int[$numberOfArrays];
      |int $biggestCardinality = 0;
      |${CodeGenerator.javaType(dataType)} ${ev.value} = null;
    """.stripMargin

    ev.copy(code"""
      |$initVariables
      |$splittedGetValuesAndCardinalities
      |boolean ${ev.isNull} = $biggestCardinality == -1;
      |if (!${ev.isNull}) {
      |  Object[] $args = new Object[$biggestCardinality];
Member: We usually don't set a value if the result is null.
      |  for (int $i = 0; $i < $biggestCardinality; $i++) {
      |    Object[] $currentRow = new Object[$numberOfArrays];
      |    $getValueForTypeSplitted
      |    $args[$i] = new $genericInternalRow($currentRow);
      |  }
      |  ${ev.value} = new $genericArrayData($args);
      |}
    """.stripMargin)
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    if (numberOfArrays == 0) {
      emptyInputGenCode(ev)
    } else {
      nonEmptyInputGenCode(ctx, ev)
    }
  }

  override def eval(input: InternalRow): Any = {
    val inputArrays = children.map(_.eval(input).asInstanceOf[ArrayData])
    if (inputArrays.contains(null)) {
      null
    } else {
      val biggestCardinality = if (inputArrays.isEmpty) {
        0
      } else {
        inputArrays.map(_.numElements()).max
      }

      val result = new Array[InternalRow](biggestCardinality)
      val zippedArrs: Seq[(ArrayData, Int)] = inputArrays.zipWithIndex

      for (i <- 0 until biggestCardinality) {
        val currentLayer: Seq[Object] = zippedArrs.map { case (arr, index) =>
          if (i < arr.numElements() && !arr.isNullAt(i)) {
            arr.get(i, arrayElementTypes(index))
          } else {
            null
          }
        }

        result(i) = InternalRow.apply(currentLayer: _*)
      }
      new GenericArrayData(result)
    }
  }
Member: We need …

Contributor (Author): Done!
}

/**
 * Returns an unordered array containing the values of the map.
 */
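As a reading aid, the interpreted eval above boils down to the following plain-Scala model (an editorial sketch, not part of the diff): a null input array nulls the entire result; otherwise every input is padded with nulls up to the longest cardinality and zipped positionally.

// Reference model of Zip.eval over ordinary Scala collections.
def zipPad(arrays: Seq[Seq[Any]]): Seq[Seq[Any]] = {
  if (arrays.contains(null)) {
    null  // a single null input array makes the whole result null
  } else {
    val longest = if (arrays.isEmpty) 0 else arrays.map(_.length).max
    // Take the i-th element of each input, padding short arrays with null.
    (0 until longest).map(i => arrays.map(a => if (i < a.length) a(i) else null))
  }
}

zipPad(Seq(Seq(1, 2, 3), Seq("a")))
// => Vector(List(1, a), List(2, null), List(3, null))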
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala

@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._

@@ -315,6 +316,91 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
      Some(Literal.create(null, StringType))), null)
  }

  test("Zip") {
    val literals = Seq(
      Literal.create(Seq(9001, 9002, 9003, null), ArrayType(IntegerType)),
      Literal.create(Seq(null, 1L, null, 4L, 11L), ArrayType(LongType)),
      Literal.create(Seq(-1, -3, 900, null), ArrayType(IntegerType)),
      Literal.create(Seq("a", null, "c"), ArrayType(StringType)),
      Literal.create(Seq(null, false, true), ArrayType(BooleanType)),
      Literal.create(Seq(1.1, null, 1.3, null), ArrayType(DoubleType)),
      Literal.create(Seq(), ArrayType(NullType)),
      Literal.create(Seq(null), ArrayType(NullType)),
      Literal.create(Seq(192.toByte), ArrayType(ByteType)),
      Literal.create(
        Seq(Seq(1, 2, 3), null, Seq(4, 5), Seq(1, null, 3)), ArrayType(ArrayType(IntegerType))),
      Literal.create(Seq(Array[Byte](1.toByte, 5.toByte)), ArrayType(BinaryType))
    )

    checkEvaluation(Zip(Seq(literals(0), literals(1))),
      List(Row(9001, null), Row(9002, 1L), Row(9003, null), Row(null, 4L), Row(null, 11L)))

    checkEvaluation(Zip(Seq(literals(0), literals(2))),
      List(Row(9001, -1), Row(9002, -3), Row(9003, 900), Row(null, null)))

    checkEvaluation(Zip(Seq(literals(0), literals(3))),
      List(Row(9001, "a"), Row(9002, null), Row(9003, "c"), Row(null, null)))

    checkEvaluation(Zip(Seq(literals(0), literals(4))),
      List(Row(9001, null), Row(9002, false), Row(9003, true), Row(null, null)))

    checkEvaluation(Zip(Seq(literals(0), literals(5))),
      List(Row(9001, 1.1), Row(9002, null), Row(9003, 1.3), Row(null, null)))

    checkEvaluation(Zip(Seq(literals(0), literals(6))),
      List(Row(9001, null), Row(9002, null), Row(9003, null), Row(null, null)))

    checkEvaluation(Zip(Seq(literals(0), literals(7))),
      List(Row(9001, null), Row(9002, null), Row(9003, null), Row(null, null)))

    checkEvaluation(Zip(Seq(literals(0), literals(1), literals(2), literals(3))),
      List(
        Row(9001, null, -1, "a"),
        Row(9002, 1L, -3, null),
        Row(9003, null, 900, "c"),
        Row(null, 4L, null, null),
        Row(null, 11L, null, null)))

    checkEvaluation(Zip(Seq(literals(4), literals(5), literals(6), literals(7), literals(8))),
      List(
        Row(null, 1.1, null, null, 192.toByte),
        Row(false, null, null, null, null),
        Row(true, 1.3, null, null, null),
        Row(null, null, null, null, null)))

    checkEvaluation(Zip(Seq(literals(9), literals(0))),
      List(
        Row(List(1, 2, 3), 9001),
        Row(null, 9002),
        Row(List(4, 5), 9003),
        Row(List(1, null, 3), null)))

    checkEvaluation(Zip(Seq(literals(7), literals(10))),
      List(Row(null, Array[Byte](1.toByte, 5.toByte))))

    val longLiteral =
      Literal.create((0 to 1000).toSeq, ArrayType(IntegerType))

    checkEvaluation(Zip(Seq(literals(0), longLiteral)),
      List(Row(9001, 0), Row(9002, 1), Row(9003, 2)) ++
        (3 to 1000).map { Row(null, _) }.toList)

    val manyLiterals = (0 to 1000).map { _ =>
      Literal.create(Seq(1), ArrayType(IntegerType))
    }.toSeq

    val numbers = List(
      Row(Seq(9001) ++ (0 to 1000).map { _ => 1 }.toSeq: _*),
      Row(Seq(9002) ++ (0 to 1000).map { _ => null }.toSeq: _*),
      Row(Seq(9003) ++ (0 to 1000).map { _ => null }.toSeq: _*),
      Row(Seq(null) ++ (0 to 1000).map { _ => null }.toSeq: _*))
    checkEvaluation(Zip(Seq(literals(0)) ++ manyLiterals),
      List(numbers(0), numbers(1), numbers(2), numbers(3)))

    checkEvaluation(Zip(Seq(literals(0), Literal.create(null, ArrayType(IntegerType)))), null)
    checkEvaluation(Zip(Seq()), List())
  }

  test("Array Min") {
    checkEvaluation(ArrayMin(Literal.create(Seq(-11, 10, 2), ArrayType(IntegerType))), -11)
    checkEvaluation(
sql/core/src/main/scala/org/apache/spark/sql/functions.scala

@@ -3508,6 +3508,14 @@ object functions {
   */
  def map_entries(e: Column): Column = withExpr { MapEntries(e.expr) }

  /**
   * Merge multiple columns into a resulting one.
   *
   * @group collection_funcs
   * @since 2.4.0
   */
  def zip(e: Column*): Column = withExpr { Zip(e.map(_.expr)) }

  //////////////////////////////////////////////////////////////////////////////////////////////
  // Mask functions
  //////////////////////////////////////////////////////////////////////////////////////////////
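For context, a quick usage sketch of the new DataFrame API (editorial, not part of the diff; it assumes an active SparkSession named spark, and uses the PR's current name zip, which the PR title tracks as arrays_zip):

import org.apache.spark.sql.functions._

val spark: org.apache.spark.sql.SparkSession = ???  // assumed active session
import spark.implicits._

val df = Seq((Seq(1, 2, 3), Seq("a", "b"))).toDF("xs", "ys")

// Shorter arrays are padded with nulls up to the longest input:
df.select(zip($"xs", $"ys").as("zipped")).show(false)
// zipped: [[1, a], [2, b], [3, null]]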
sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

@@ -479,6 +479,53 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
    )
  }

  test("dataframe zip function") {
    val df1 = Seq((Seq(9001, 9002, 9003), Seq(4, 5, 6))).toDF("val1", "val2")
    val df2 = Seq((Seq("a", "b"), Seq(true, false), Seq(10, 11))).toDF("val1", "val2", "val3")
    val df3 = Seq((Seq("a", "b"), Seq(4, 5, 6))).toDF("val1", "val2")
    val df4 = Seq((Seq("a", "b", null), Seq(4L))).toDF("val1", "val2")
    val df5 = Seq((Seq(-1), Seq(null), Seq(), Seq(null, null))).toDF("val1", "val2", "val3", "val4")
    val df6 = Seq((Seq(192.toByte, 256.toByte), Seq(1.1), Seq(), Seq(null, null)))
      .toDF("v1", "v2", "v3", "v4")
    val df7 = Seq((Seq(Seq(1, 2, 3), Seq(4, 5)), Seq(1.1, 2.2))).toDF("v1", "v2")
    val df8 = Seq((Seq(Array[Byte](1.toByte, 5.toByte)), Seq(null))).toDF("v1", "v2")

    val expectedValue1 = Row(Seq(Row(9001, 4), Row(9002, 5), Row(9003, 6)))
    checkAnswer(df1.select(zip($"val1", $"val2")), expectedValue1)
    checkAnswer(df1.selectExpr("zip(val1, val2)"), expectedValue1)

    val expectedValue2 = Row(Seq(Row("a", true, 10), Row("b", false, 11)))
    checkAnswer(df2.select(zip($"val1", $"val2", $"val3")), expectedValue2)
    checkAnswer(df2.selectExpr("zip(val1, val2, val3)"), expectedValue2)

    val expectedValue3 = Row(Seq(Row("a", 4), Row("b", 5), Row(null, 6)))
    checkAnswer(df3.select(zip($"val1", $"val2")), expectedValue3)
    checkAnswer(df3.selectExpr("zip(val1, val2)"), expectedValue3)

    val expectedValue4 = Row(Seq(Row("a", 4L), Row("b", null), Row(null, null)))
    checkAnswer(df4.select(zip($"val1", $"val2")), expectedValue4)
    checkAnswer(df4.selectExpr("zip(val1, val2)"), expectedValue4)

    val expectedValue5 = Row(Seq(Row(-1, null, null, null), Row(null, null, null, null)))
    checkAnswer(df5.select(zip($"val1", $"val2", $"val3", $"val4")), expectedValue5)
    checkAnswer(df5.selectExpr("zip(val1, val2, val3, val4)"), expectedValue5)

    val expectedValue6 = Row(Seq(
      Row(192.toByte, 1.1, null, null), Row(256.toByte, null, null, null)))
    checkAnswer(df6.select(zip($"v1", $"v2", $"v3", $"v4")), expectedValue6)
    checkAnswer(df6.selectExpr("zip(v1, v2, v3, v4)"), expectedValue6)

    val expectedValue7 = Row(Seq(
      Row(Seq(1, 2, 3), 1.1), Row(Seq(4, 5), 2.2)))
    checkAnswer(df7.select(zip($"v1", $"v2")), expectedValue7)
    checkAnswer(df7.selectExpr("zip(v1, v2)"), expectedValue7)

    val expectedValue8 = Row(Seq(
      Row(Array[Byte](1.toByte, 5.toByte), null)))
    checkAnswer(df8.select(zip($"v1", $"v2")), expectedValue8)
    checkAnswer(df8.selectExpr("zip(v1, v2)"), expectedValue8)
  }

  test("map size function") {
    val df = Seq(
      (Map[Int, Int](1 -> 1, 2 -> 2), "x"),
nit (on the Python ":param cols" doc): columns of arrays to be merged.