diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index d76f3013f0c4..d20c04373b94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -199,7 +199,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI """.stripMargin } - val splittedGetValuesAndCardinalities = ctx.splitExpressions( + val splittedGetValuesAndCardinalities = ctx.splitExpressionsWithCurrentInputs( expressions = getValuesAndCardinalities, funcName = "getValuesAndCardinalities", returnType = "int", @@ -209,7 +209,7 @@ case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsI |return $biggestCardinality; """.stripMargin, foldFunctions = _.map(funcCall => s"$biggestCardinality = $funcCall;").mkString("\n"), - arguments = + extraArguments = ("ArrayData[]", arrVals) :: ("int", biggestCardinality) :: Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 4e5c1c56e267..d3c55b15a44f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -556,6 +556,17 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer(df8.selectExpr("arrays_zip(v1, v2)"), expectedValue8) } + test("SPARK-24633: arrays_zip splits input processing correctly") { + Seq("true", "false").foreach { wholestageCodegenEnabled => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholestageCodegenEnabled) { + val df = spark.range(1) + val exprs = (0 to 5).map(x => array($"id" + lit(x))) + checkAnswer(df.select(arrays_zip(exprs: _*)), + Row(Seq(Row(0, 1, 2, 3, 4, 5)))) + } + } + } + test("map size function") { val df = Seq( (Map[Int, Int](1 -> 1, 2 -> 2), "x"),