From d8d3900ecf99d764d41ca7b5904301e818bf6644 Mon Sep 17 00:00:00 2001
From: "Jungtaek Lim (HeartSaVioR)"
Date: Wed, 23 Oct 2019 05:19:00 +0900
Subject: [PATCH 1/3] [SPARK-29503][SQL] Remove problematic conversion CreateNamedStruct to CreateNamedStructUnsafe

---
 .../sql/catalyst/expressions/Projection.scala |  8 +---
 .../expressions/complexTypeCreator.scala      | 48 +++++--------
 .../sql/catalyst/optimizer/ComplexTypes.scala |  6 +--
 .../optimizer/NormalizeFloatingNumbers.scala  |  5 +-
 .../sql/catalyst/optimizer/expressions.scala  |  4 +-
 .../expressions/ComplexTypeSuite.scala        |  1 -
 .../scala/org/apache/spark/sql/Column.scala   |  2 +-
 .../spark/sql/DataFrameComplexTypeSuite.scala | 33 +++++++++++++
 8 files changed, 54 insertions(+), 53 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index eaaf94baac21..300f075d3276 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -127,12 +127,6 @@ object UnsafeProjection
     InterpretedUnsafeProjection.createProjection(in)
   }
 
-  protected def toUnsafeExprs(exprs: Seq[Expression]): Seq[Expression] = {
-    exprs.map(_ transform {
-      case CreateNamedStruct(children) => CreateNamedStructUnsafe(children)
-    })
-  }
-
   /**
    * Returns an UnsafeProjection for given StructType.
    *
@@ -153,7 +147,7 @@ object UnsafeProjection
    * Returns an UnsafeProjection for given sequence of bound Expressions.
    */
   def create(exprs: Seq[Expression]): UnsafeProjection = {
-    createObject(toUnsafeExprs(exprs))
+    createObject(exprs)
   }
 
   def create(expr: Expression): UnsafeProjection = create(Seq(expr))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index cae3c0528e13..3f722e8537c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -295,9 +295,20 @@ object CreateStruct extends FunctionBuilder {
 }
 
 /**
- * Common base class for both [[CreateNamedStruct]] and [[CreateNamedStructUnsafe]].
+ * Creates a struct with the given field names and values
+ *
+ * @param children Seq(name1, val1, name2, val2, ...)
  */
-trait CreateNamedStructLike extends Expression {
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_("a", 1, "b", 2, "c", 3);
+       {"a":1,"b":2,"c":3}
+  """)
+// scalastyle:on line.size.limit
+case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
   lazy val (nameExprs, valExprs) = children.grouped(2).map {
     case Seq(name, value) => (name, value)
   }.toList.unzip
@@ -348,23 +359,6 @@ trait CreateNamedStructLike extends Expression {
   override def eval(input: InternalRow): Any = {
     InternalRow(valExprs.map(_.eval(input)): _*)
   }
-}
-
-/**
- * Creates a struct with the given field names and values
- *
- * @param children Seq(name1, val1, name2, val2, ...)
- */
-// scalastyle:off line.size.limit
-@ExpressionDescription(
-  usage = "_FUNC_(name1, val1, name2, val2, ...) - Creates a struct with the given field names and values.",
-  examples = """
-    Examples:
-      > SELECT _FUNC_("a", 1, "b", 2, "c", 3);
-       {"a":1,"b":2,"c":3}
-  """)
-// scalastyle:on line.size.limit
-case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStructLike {
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val rowClass = classOf[GenericInternalRow].getName
@@ -397,22 +391,6 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc
   override def prettyName: String = "named_struct"
 }
 
-/**
- * Creates a struct with the given field names and values. This is a variant that returns
- * UnsafeRow directly. The unsafe projection operator replaces [[CreateStruct]] with
- * this expression automatically at runtime.
- *
- * @param children Seq(name1, val1, name2, val2, ...)
- */
-case class CreateNamedStructUnsafe(children: Seq[Expression]) extends CreateNamedStructLike {
-  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val eval = GenerateUnsafeProjection.createCode(ctx, valExprs)
-    ExprCode(code = eval.code, isNull = FalseLiteral, value = eval.value)
-  }
-
-  override def prettyName: String = "named_struct_unsafe"
-}
-
 /**
  * Creates a map after splitting the input text into key/value pairs using delimiters
  */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
index db7d6d3254bd..1743565ccb6c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.Rule
 
 /**
- * Simplify redundant [[CreateNamedStructLike]], [[CreateArray]] and [[CreateMap]] expressions.
+ * Simplify redundant [[CreateNamedStruct]], [[CreateArray]] and [[CreateMap]] expressions.
  */
 object SimplifyExtractValueOps extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -37,8 +37,8 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
     case a: Aggregate => a
     case p => p.transformExpressionsUp {
       // Remove redundant field extraction.
-      case GetStructField(createNamedStructLike: CreateNamedStructLike, ordinal, _) =>
-        createNamedStructLike.valExprs(ordinal)
+      case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
+        createNamedStruct.valExprs(ordinal)
 
       // Remove redundant array indexing.
       case GetArrayStructFields(CreateArray(elems), field, ordinal, _, _) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
index b036092cf1fc..ea01d9e63eef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateNamedStructUnsafe, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, KnownFloatingPointNormalized, LambdaFunction, NamedLambdaVariable, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CreateArray, CreateMap, CreateNamedStruct, CreateStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, KnownFloatingPointNormalized, LambdaFunction, NamedLambdaVariable, UnaryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery, Window}
@@ -114,9 +114,6 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
     case CreateNamedStruct(children) =>
       CreateNamedStruct(children.map(normalize))
 
-    case CreateNamedStructUnsafe(children) =>
-      CreateNamedStructUnsafe(children.map(normalize))
-
     case CreateArray(children) =>
       CreateArray(children.map(normalize))
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 0a6737ba4211..36ad796c08a3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -227,8 +227,8 @@ object OptimizeIn extends Rule[LogicalPlan] {
         if (newList.length == 1
           // TODO: `EqualTo` for structural types are not working. Until SPARK-24443 is addressed,
           // TODO: we exclude them in this rule.
-          && !v.isInstanceOf[CreateNamedStructLike]
-          && !newList.head.isInstanceOf[CreateNamedStructLike]) {
+          && !v.isInstanceOf[CreateNamedStruct]
+          && !newList.head.isInstanceOf[CreateNamedStruct]) {
           EqualTo(v, newList.head)
         } else if (newList.length > SQLConf.get.optimizerInSetConversionThreshold) {
           val hSet = newList.map(e => e.eval(EmptyRow))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 0c4438987cd2..9039cd645159 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -369,7 +369,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     val b = AttributeReference("b", IntegerType)()
     checkMetadata(CreateStruct(Seq(a, b)))
     checkMetadata(CreateNamedStruct(Seq("a", a, "b", b)))
-    checkMetadata(CreateNamedStructUnsafe(Seq("a", a, "b", b)))
   }
 
   test("StringToMap") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 7b903a3f7f14..ed10843b0859 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -200,7 +200,7 @@ class Column(val expr: Expression) extends Logging {
       UnresolvedAlias(a, Some(Column.generateAlias))
 
     // Wait until the struct is resolved. This will generate a nicer looking alias.
-    case struct: CreateNamedStructLike => UnresolvedAlias(struct)
+    case struct: CreateNamedStruct => UnresolvedAlias(struct)
 
     case expr: Expression => Alias(expr, toPrettySQL(expr))()
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
index e9179a39d3b6..10d09bc7350b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
@@ -17,9 +17,15 @@
 
 package org.apache.spark.sql
 
+import scala.collection.mutable
+
 import org.apache.spark.sql.catalyst.DefinedByConstructorParams
+import org.apache.spark.sql.catalyst.expressions.{Expression, GenericRowWithSchema}
+import org.apache.spark.sql.catalyst.expressions.objects.MapObjects
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.ArrayType
 
 /**
  * A test suite to test DataFrame/SQL functionalities with complex types (i.e. array, struct, map).
@@ -64,6 +70,33 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSparkSession {
     val ds100_5 = Seq(S100_5()).toDS()
     ds100_5.rdd.count
   }
+
+  test("SPARK-29503 nest unsafe struct inside safe array") {
+    withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
+      val exampleDS = spark.sparkContext.parallelize(Seq(Seq(1, 2, 3))).toDF("items")
+
+      // items: Seq[Int] => items.map { item => Seq(Struct(item)) }
+      val result = exampleDS.select(
+        new Column(MapObjects(
+          (item: Expression) => array(struct(new Column(item))).expr,
+          $"items".expr,
+          exampleDS.schema("items").dataType.asInstanceOf[ArrayType].elementType
+        )) as "items"
+      ).collect()
+
+      def getValueInsideDepth(result: Row, index: Int): Int = {
+        // expected output:
+        // WrappedArray([WrappedArray(WrappedArray([1]), WrappedArray([2]), WrappedArray([3]))])
+        result.getSeq[mutable.WrappedArray[_]](0)(index)(0)
+          .asInstanceOf[GenericRowWithSchema].getInt(0)
+      }
+
+      assert(result.size === 1)
+      assert(getValueInsideDepth(result.head, 0) === 1)
+      assert(getValueInsideDepth(result.head, 1) === 2)
+      assert(getValueInsideDepth(result.head, 2) === 3)
+    }
+  }
 }
 
 class S100(

From 44602333fa84cae8d0bc3702b3e5b5596c5d61a7 Mon Sep 17 00:00:00 2001
From: "Jungtaek Lim (HeartSaVioR)"
Date: Wed, 23 Oct 2019 13:29:32 +0900
Subject: [PATCH 2/3] reflect review comments

---
 .../spark/sql/DataFrameComplexTypeSuite.scala | 17 ++++-------------
 1 file changed, 4 insertions(+), 13 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
index 10d09bc7350b..128b0eac0e63 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
@@ -73,28 +73,19 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSparkSession {
 
   test("SPARK-29503 nest unsafe struct inside safe array") {
     withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") {
-      val exampleDS = spark.sparkContext.parallelize(Seq(Seq(1, 2, 3))).toDF("items")
+      val df = spark.sparkContext.parallelize(Seq(Seq(1, 2, 3))).toDF("items")
 
       // items: Seq[Int] => items.map { item => Seq(Struct(item)) }
-      val result = exampleDS.select(
+      val result = df.select(
         new Column(MapObjects(
           (item: Expression) => array(struct(new Column(item))).expr,
           $"items".expr,
-          exampleDS.schema("items").dataType.asInstanceOf[ArrayType].elementType
+          df.schema("items").dataType.asInstanceOf[ArrayType].elementType
         )) as "items"
       ).collect()
 
-      def getValueInsideDepth(result: Row, index: Int): Int = {
-        // expected output:
-        // WrappedArray([WrappedArray(WrappedArray([1]), WrappedArray([2]), WrappedArray([3]))])
-        result.getSeq[mutable.WrappedArray[_]](0)(index)(0)
-          .asInstanceOf[GenericRowWithSchema].getInt(0)
-      }
-
       assert(result.size === 1)
-      assert(getValueInsideDepth(result.head, 0) === 1)
-      assert(getValueInsideDepth(result.head, 1) === 2)
-      assert(getValueInsideDepth(result.head, 2) === 3)
+      assert(result === Row(Seq(Seq(Row(1)), Seq(Row(2)), Seq(Row(3)))) :: Nil)
     }
   }
 }

From f10042abe8cd0fdbea4604e57b54e5797279fe93 Mon Sep 17 00:00:00 2001
From: "Jungtaek Lim (HeartSaVioR)"
Date: Wed, 23 Oct 2019 13:43:22 +0900
Subject: [PATCH 3/3] fix nits

---
 .../org/apache/spark/sql/DataFrameComplexTypeSuite.scala | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
index 128b0eac0e63..4f2564290662 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala
@@ -17,10 +17,8 @@
 
 package org.apache.spark.sql
 
-import scala.collection.mutable
-
 import org.apache.spark.sql.catalyst.DefinedByConstructorParams
-import org.apache.spark.sql.catalyst.expressions.{Expression, GenericRowWithSchema}
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.objects.MapObjects
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
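
For readers who want to try the regression scenario outside the test harness, here is a minimal standalone sketch. It is illustrative only and not part of the patch series: the select/MapObjects body is lifted from the final version of the test above, while the object name (Spark29503Repro), the local-mode session setup, and the plain assert are assumptions added so it compiles on its own. Setting spark.sql.codegen.wholeStage to false forces the interpreted evaluation path, the path on which SPARK-29503 reported wrong results before this change.

import org.apache.spark.sql.{Column, Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.objects.MapObjects
import org.apache.spark.sql.functions.{array, struct}
import org.apache.spark.sql.types.ArrayType

// Hypothetical driver object; not part of the patch itself.
object Spark29503Repro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("SPARK-29503 repro")
      // Disable whole-stage codegen to exercise the interpreted path.
      .config("spark.sql.codegen.wholeStage", "false")
      .getOrCreate()
    import spark.implicits._

    val df = spark.sparkContext.parallelize(Seq(Seq(1, 2, 3))).toDF("items")

    // items: Seq[Int] => items.map { item => Seq(Struct(item)) }
    val result = df.select(
      new Column(MapObjects(
        (item: Expression) => array(struct(new Column(item))).expr,
        $"items".expr,
        df.schema("items").dataType.asInstanceOf[ArrayType].elementType
      )) as "items"
    ).collect()

    // Each nested struct should keep its own value once the conversion
    // to CreateNamedStructUnsafe is gone.
    assert(result.toSeq == Seq(Row(Seq(Seq(Row(1)), Seq(Row(2)), Seq(Row(3))))))
    spark.stop()
  }
}

Running this against a build without the patch should trip the assertion; with the patch applied it should complete silently.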