From ac149a52b88e7eb9ab5aa50d1cf341caf9f59faa Mon Sep 17 00:00:00 2001 From: "fqaiser94@gmail.com" Date: Sat, 15 Aug 2020 07:00:19 -0400 Subject: [PATCH 01/18] implement dropFields --- .../expressions/complexTypeCreator.scala | 120 ++- .../sql/catalyst/optimizer/ComplexTypes.scala | 41 +- .../sql/catalyst/optimizer/Optimizer.scala | 6 +- .../{WithFields.scala => UpdateFields.scala} | 16 +- ...e.scala => CombineUpdateFieldsSuite.scala} | 41 +- .../optimizer/complexTypesSuite.scala | 279 ++++++- .../scala/org/apache/spark/sql/Column.scala | 120 ++- .../spark/sql/ColumnExpressionSuite.scala | 745 ++++++++++++++++++ 8 files changed, 1236 insertions(+), 132 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/{WithFields.scala => UpdateFields.scala} (68%) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/{CombineWithFieldsSuite.scala => CombineUpdateFieldsSuite.scala} (65%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 563ce7133a3d..c7f02986ddaa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} +import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion, UnresolvedException} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ @@ -541,57 +541,105 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E } /** - * Adds/replaces field in struct by name. + * Represents an operation to be applied to the fields of a struct. */ -case class WithFields( - structExpr: Expression, - names: Seq[String], - valExprs: Seq[Expression]) extends Unevaluable { +trait StructFieldsOperation { - assert(names.length == valExprs.length) + val resolver: Resolver = SQLConf.get.resolver + + /** + * Returns an updated list of StructFields and Expressions that will ultimately be used + * as the fields argument for [[StructType]] and as the children argument for + * [[CreateNamedStruct]] respectively inside of [[UpdateFields]]. + */ + def apply(values: Seq[(StructField, Expression)]): Seq[(StructField, Expression)] +} + +/** + * Add or replace a field by name. + * + * We extend [[Unevaluable]] here to ensure that [[UpdateFields]] can include it as part of its + * children, and thereby enable the analyzer to resolve and transform valExpr as necessary. 
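+ *
+ * For example (an illustrative sketch, not part of the original change): in
+ * `UpdateFields('struct, WithField("b", 'a + 1) :: Nil)`, the unresolved value expression
+ * `'a + 1` is surfaced through `children`, so analysis rules can resolve `'a` before
+ * [[UpdateFields]] is rewritten into an evaluable expression.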
+ */ +case class WithField(name: String, valExpr: Expression) + extends Unevaluable with StructFieldsOperation { + + override def apply(values: Seq[(StructField, Expression)]): Seq[(StructField, Expression)] = { + val newFieldExpr = (StructField(name, valExpr.dataType, valExpr.nullable), valExpr) + if (values.exists { case (field, _) => resolver(field.name, name) }) { + values.map { + case (field, _) if resolver(field.name, name) => newFieldExpr + case x => x + } + } else { + values :+ newFieldExpr + } + } + + override def children: Seq[Expression] = valExpr :: Nil + + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + + override def prettyName: String = "WithField" +} + +/** + * Drop a field by name. + */ +case class DropField(name: String) extends StructFieldsOperation { + override def apply(values: Seq[(StructField, Expression)]): Seq[(StructField, Expression)] = + values.filterNot { case (field, _) => resolver(field.name, name) } +} + +/** + * Updates fields in struct by name. + */ +case class UpdateFields(structExpr: Expression, fieldOps: Seq[StructFieldsOperation]) + extends Unevaluable { override def checkInputDataTypes(): TypeCheckResult = { - if (!structExpr.dataType.isInstanceOf[StructType]) { - TypeCheckResult.TypeCheckFailure( - "struct argument should be struct type, got: " + structExpr.dataType.catalogString) + val dataType = structExpr.dataType + if (!dataType.isInstanceOf[StructType]) { + TypeCheckResult.TypeCheckFailure("struct argument should be struct type, got: " + + dataType.catalogString) + } else if (newExprs.isEmpty) { + TypeCheckResult.TypeCheckFailure("cannot drop all fields in struct") } else { TypeCheckResult.TypeCheckSuccess } } - override def children: Seq[Expression] = structExpr +: valExprs + override def children: Seq[Expression] = structExpr +: fieldOps.collect { + case e: Expression => e + } - override def dataType: StructType = evalExpr.dataType.asInstanceOf[StructType] + override def dataType: StructType = StructType(newFields) override def nullable: Boolean = structExpr.nullable - override def prettyName: String = "with_fields" + override def prettyName: String = "update_fields" - lazy val evalExpr: Expression = { - val existingExprs = structExpr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map { - case (name, i) => (name, GetStructField(KnownNotNull(structExpr), i).asInstanceOf[Expression]) + private lazy val existingFieldExprs: Seq[(StructField, Expression)] = + structExpr.dataType.asInstanceOf[StructType].fields.zipWithIndex.map { + case (field, i) => (field, GetStructField(structExpr, i)) } - val addOrReplaceExprs = names.zip(valExprs) - - val resolver = SQLConf.get.resolver - val newExprs = addOrReplaceExprs.foldLeft(existingExprs) { - case (resultExprs, newExpr @ (newExprName, _)) => - if (resultExprs.exists(x => resolver(x._1, newExprName))) { - resultExprs.map { - case (name, _) if resolver(name, newExprName) => newExpr - case x => x - } - } else { - resultExprs :+ newExpr - } - }.flatMap { case (name, expr) => Seq(Literal(name), expr) } + private lazy val newFieldExprs: Seq[(StructField, Expression)] = + fieldOps.foldLeft(existingFieldExprs)((exprs, op) => op(exprs)) - val expr = CreateNamedStruct(newExprs) - if (structExpr.nullable) { - If(IsNull(structExpr), Literal(null, expr.dataType), expr) - } else { - expr - } + private lazy val newFields: Seq[StructField] = newFieldExprs.map(_._1) + + lazy val newExprs: 
Seq[Expression] = newFieldExprs.map(_._2) + + private lazy val createNamedStructExpr = CreateNamedStruct(newFieldExprs.flatMap { + case (field, expr) => Seq(Literal(field.name), expr) + }) + + lazy val evalExpr: Expression = if (structExpr.nullable) { + If(IsNull(structExpr), Literal(null, dataType), createNamedStructExpr) + } else { + createNamedStructExpr } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala index 2aba4bae397c..aa992dc8d236 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.StructType /** * Simplify redundant [[CreateNamedStruct]], [[CreateArray]] and [[CreateMap]] expressions. @@ -39,19 +40,16 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] { // Remove redundant field extraction. case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) => createNamedStruct.valExprs(ordinal) - case GetStructField(w @ WithFields(struct, names, valExprs), ordinal, maybeName) => - val name = w.dataType(ordinal).name - val matches = names.zip(valExprs).filter(_._1 == name) - if (matches.nonEmpty) { - // return last matching element as that is the final value for the field being extracted. - // For example, if a user submits a query like this: - // `$"struct_col".withField("b", lit(1)).withField("b", lit(2)).getField("b")` - // we want to return `lit(2)` (and not `lit(1)`). - val expr = matches.last._2 - If(IsNull(struct), Literal(null, expr.dataType), expr) - } else { - GetStructField(struct, ordinal, maybeName) - } + case GetStructField(updateFields: UpdateFields, ordinal, _) => + val expr = updateFields.newExprs(ordinal) + val structExpr = updateFields.structExpr + if (isExprNestedInsideStruct(expr, structExpr)) { + // if the struct itself is null, then any value extracted from it (expr) will be null + // so we don't need to wrap expr in If(IsNull(struct), Literal(null, expr.dataType), expr) + expr + } else { + If(IsNull(ultimateStruct(structExpr)), Literal(null, expr.dataType), expr) + } // Remove redundant array indexing. 
case GetArrayStructFields(CreateArray(elems, useStringTypeWhenEmpty), field, ordinal, _, _) => // Instead of selecting the field on the entire array, select it from each member @@ -73,4 +71,21 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] { case GetMapValue(CreateMap(elems, _), key) => CaseKeyWhen(key, elems) } } + + @scala.annotation.tailrec + private def ultimateStruct(expr: Expression): Expression = expr match { + case e: UpdateFields => ultimateStruct(e.structExpr) + case e => e + } + + @scala.annotation.tailrec + private def isExprNestedInsideStruct(expr: Expression, struct: Expression): Boolean = { + require(struct.dataType.isInstanceOf[StructType]) + + expr match { + case e: GetStructField => + e.child.semanticEquals(struct) || isExprNestedInsideStruct(e.child, struct) + case _ => false + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 6033c01a60f4..e56375d32f11 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -109,8 +109,8 @@ abstract class Optimizer(catalogManager: CatalogManager) RemoveRedundantAliases, UnwrapCastInBinaryComparison, RemoveNoopOperators, - CombineWithFields, SimplifyExtractValueOps, + CombineUpdateFields, CombineConcats) ++ extendedOperatorOptimizationRules @@ -221,7 +221,7 @@ abstract class Optimizer(catalogManager: CatalogManager) RemoveNoopOperators) :+ // This batch must be executed after the `RewriteSubquery` batch, which creates joins. Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+ - Batch("ReplaceWithFieldsExpression", Once, ReplaceWithFieldsExpression) + Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression) // remove any batches with no rules. this may happen when subclasses do not add optional rules. batches.filter(_.rules.nonEmpty) @@ -255,7 +255,7 @@ abstract class Optimizer(catalogManager: CatalogManager) RewriteCorrelatedScalarSubquery.ruleName :: RewritePredicateSubquery.ruleName :: NormalizeFloatingNumbers.ruleName :: - ReplaceWithFieldsExpression.ruleName :: Nil + ReplaceUpdateFieldsExpression.ruleName :: Nil /** * Optimize all the subqueries inside expression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala similarity index 68% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala index 05c90864e4bb..c7154210e0c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala @@ -17,26 +17,26 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.expressions.WithFields +import org.apache.spark.sql.catalyst.expressions.UpdateFields import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule /** - * Combines all adjacent [[WithFields]] expression into a single [[WithFields]] expression. + * Combines all adjacent [[UpdateFields]] expression into a single [[UpdateFields]] expression. 
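+ *
+ * For example (illustrative):
+ * `UpdateFields(UpdateFields('struct, Seq(WithField("b", Literal(1)))), Seq(DropField("a")))`
+ * is collapsed into `UpdateFields('struct, Seq(WithField("b", Literal(1)), DropField("a")))`.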
*/ -object CombineWithFields extends Rule[LogicalPlan] { +object CombineUpdateFields extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) => - WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2) + case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) => + UpdateFields(struct, fieldOps1 ++ fieldOps2) } } /** - * Replaces [[WithFields]] expression with an evaluable expression. + * Replaces [[UpdateFields]] expression with an evaluable expression. */ -object ReplaceWithFieldsExpression extends Rule[LogicalPlan] { +object ReplaceUpdateFieldsExpression extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case w: WithFields => w.evalExpr + case u: UpdateFields => u.evalExpr } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineUpdateFieldsSuite.scala similarity index 65% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineUpdateFieldsSuite.scala index a3e0bbc57e63..ff9c60a2fa5b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineUpdateFieldsSuite.scala @@ -19,56 +19,53 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, WithFields} +import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, UpdateFields, WithField} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -class CombineWithFieldsSuite extends PlanTest { +class CombineUpdateFieldsSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { - val batches = Batch("CombineWithFields", FixedPoint(10), CombineWithFields) :: Nil + val batches = Batch("CombineUpdateFields", FixedPoint(10), CombineUpdateFields) :: Nil } private val testRelation = LocalRelation('a.struct('a1.int)) - test("combines two WithFields") { + test("combines two adjacent UpdateFields Expressions") { val originalQuery = testRelation .select(Alias( - WithFields( - WithFields( + UpdateFields( + UpdateFields( 'a, - Seq("b1"), - Seq(Literal(4))), - Seq("c1"), - Seq(Literal(5))), "out")()) + WithField("b1", Literal(4)) :: Nil), + WithField("c1", Literal(5)) :: Nil), "out")()) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select(Alias(WithFields('a, Seq("b1", "c1"), Seq(Literal(4), Literal(5))), "out")()) + .select(Alias(UpdateFields('a, WithField("b1", Literal(4)) :: WithField("c1", Literal(5)) :: + Nil), "out")()) .analyze comparePlans(optimized, correctAnswer) } - test("combines three WithFields") { + test("combines three adjacent UpdateFields Expressions") { val originalQuery = testRelation .select(Alias( - WithFields( - WithFields( - WithFields( + UpdateFields( + UpdateFields( + UpdateFields( 'a, - Seq("b1"), - Seq(Literal(4))), - Seq("c1"), - Seq(Literal(5))), - Seq("d1"), - Seq(Literal(6))), "out")()) + WithField("b1", Literal(4)) :: 
Nil), + WithField("c1", Literal(5)) :: Nil), + WithField("d1", Literal(6)) :: Nil), "out")()) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select(Alias(WithFields('a, Seq("b1", "c1", "d1"), Seq(4, 5, 6).map(Literal(_))), "out")()) + .select(Alias(UpdateFields('a, WithField("b1", Literal(4)) :: WithField("c1", Literal(5)) :: + WithField("d1", Literal(6)) :: Nil), "out")()) .analyze comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 00aed6a10cd6..c54bf8ae5e36 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -37,14 +37,15 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { object Optimizer extends RuleExecutor[LogicalPlan] { val batches = Batch("collapse projections", FixedPoint(10), - CollapseProject) :: + CollapseProject) :: Batch("Constant Folding", FixedPoint(10), - NullPropagation, - ConstantFolding, - BooleanSimplification, - SimplifyConditionals, - SimplifyBinaryComparison, - SimplifyExtractValueOps) :: Nil + NullPropagation, + ConstantFolding, + BooleanSimplification, + SimplifyConditionals, + SimplifyBinaryComparison, + SimplifyExtractValueOps, + CombineUpdateFields) :: Nil } private val idAtt = ('id).long.notNull @@ -453,28 +454,29 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](3, 4))), null) } - private val structAttr = 'struct1.struct('a.int).withNullability(false) + private val structAttr = 'struct1.struct('a.int, 'b.int).withNullability(false) private val testStructRelation = LocalRelation(structAttr) - private val nullableStructAttr = 'struct1.struct('a.int) + private val nullableStructAttr = 'struct1.struct('a.int, 'b.int) private val testNullableStructRelation = LocalRelation(nullableStructAttr) - test("simplify GetStructField on WithFields that is not changing the attribute being extracted") { - def query(relation: LocalRelation): LogicalPlan = relation.select( - GetStructField(WithFields('struct1, Seq("b"), Seq(Literal(1))), 0, Some("a")) as "outerAttr") + test("simplify GetStructField on UpdateFields that is not changing the attribute being " + + "extracted") { + def query(relation: LocalRelation): LogicalPlan = relation.select(GetStructField( + UpdateFields('struct1, Seq(WithField("b", Literal(1)))), 0, Some("a")) as "outerAttr") checkRule( query(testStructRelation), - testStructRelation.select(GetStructField('struct1, 0, Some("a")) as "outerAttr")) + testStructRelation.select(GetStructField('struct1, 0) as "outerAttr")) checkRule( query(testNullableStructRelation), - testNullableStructRelation.select(GetStructField('struct1, 0, Some("a")) as "outerAttr")) + testNullableStructRelation.select(GetStructField('struct1, 0) as "outerAttr")) } - test("simplify GetStructField on WithFields that is changing the attribute being extracted") { - def query(relation: LocalRelation): LogicalPlan = relation.select( - GetStructField(WithFields('struct1, Seq("b"), Seq(Literal(1))), 1, Some("b")) as "res") + test("simplify GetStructField on UpdateFields that is changing the attribute being extracted") { + def query(relation: LocalRelation): LogicalPlan = relation.select(GetStructField( + 
UpdateFields('struct1, Seq(WithField("b", Literal(1)))), 1, Some("b")) as "res") checkRule( query(testStructRelation), @@ -486,11 +488,11 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { If(IsNull('struct1), Literal(null, IntegerType), Literal(1)) as "res")) } - test( - "simplify GetStructField on WithFields that is changing the attribute being extracted twice") { - def query(relation: LocalRelation): LogicalPlan = relation.select( - GetStructField(WithFields('struct1, Seq("b", "b"), Seq(Literal(1), Literal(2))), 1, Some("b")) - as "outerAtt") + test("simplify GetStructField on UpdateFields that is changing the attribute being extracted " + + "twice") { + def query(relation: LocalRelation): LogicalPlan = relation.select(GetStructField( + UpdateFields('struct1, Seq(WithField("b", Literal(1)), WithField("b", Literal(2)))), + 1, Some("b")) as "outerAtt") checkRule( query(testStructRelation), @@ -502,9 +504,9 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { If(IsNull('struct1), Literal(null, IntegerType), Literal(2)) as "outerAtt")) } - test("collapse multiple GetStructField on the same WithFields") { + test("collapse multiple GetStructField on the same UpdateFields") { def query(relation: LocalRelation): LogicalPlan = relation - .select(WithFields('struct1, Seq("b"), Seq(Literal(2))) as "struct2") + .select(UpdateFields('struct1, Seq(WithField("b", Literal(2)))) as "struct2") .select( GetStructField('struct2, 0, Some("a")) as "struct1A", GetStructField('struct2, 1, Some("b")) as "struct1B") @@ -512,21 +514,21 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { checkRule( query(testStructRelation), testStructRelation.select( - GetStructField('struct1, 0, Some("a")) as "struct1A", + GetStructField('struct1, 0) as "struct1A", Literal(2) as "struct1B")) checkRule( query(testNullableStructRelation), testNullableStructRelation.select( - GetStructField('struct1, 0, Some("a")) as "struct1A", + GetStructField('struct1, 0) as "struct1A", If(IsNull('struct1), Literal(null, IntegerType), Literal(2)) as "struct1B")) } - test("collapse multiple GetStructField on different WithFields") { + test("collapse multiple GetStructField on different UpdateFields") { def query(relation: LocalRelation): LogicalPlan = relation .select( - WithFields('struct1, Seq("b"), Seq(Literal(2))) as "struct2", - WithFields('struct1, Seq("b"), Seq(Literal(3))) as "struct3") + UpdateFields('struct1, Seq(WithField("b", Literal(2)))) as "struct2", + UpdateFields('struct1, Seq(WithField("b", Literal(3)))) as "struct3") .select( GetStructField('struct2, 0, Some("a")) as "struct2A", GetStructField('struct2, 1, Some("b")) as "struct2B", @@ -537,18 +539,229 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { query(testStructRelation), testStructRelation .select( - GetStructField('struct1, 0, Some("a")) as "struct2A", + GetStructField('struct1, 0) as "struct2A", Literal(2) as "struct2B", - GetStructField('struct1, 0, Some("a")) as "struct3A", + GetStructField('struct1, 0) as "struct3A", Literal(3) as "struct3B")) checkRule( query(testNullableStructRelation), testNullableStructRelation .select( - GetStructField('struct1, 0, Some("a")) as "struct2A", + GetStructField('struct1, 0) as "struct2A", If(IsNull('struct1), Literal(null, IntegerType), Literal(2)) as "struct2B", - GetStructField('struct1, 0, Some("a")) as "struct3A", + GetStructField('struct1, 0) as "struct3A", If(IsNull('struct1), Literal(null, IntegerType), Literal(3)) as "struct3B")) } + + 
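+  // An illustrative sketch (not part of the original tests) of the rewrite the cases above
+  // verify: extracting a field that an UpdateFields expression sets folds the intermediate
+  // struct away entirely, e.g.
+  //   GetStructField(UpdateFields('struct1, Seq(WithField("b", Literal(2)))), 1)
+  //     ==> Literal(2)                                                // non-nullable 'struct1
+  //     ==> If(IsNull('struct1), Literal(null, IntegerType), Literal(2))  // nullable 'struct1
+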
test("simplify GetStructField on UpdateFields with multiple WithField and extract new column") { + def query(relation: LocalRelation): LogicalPlan = relation.select(GetStructField( + UpdateFields('struct1, ('a' to 'z').zipWithIndex.map { case (char, i) => + WithField(char.toString, Literal(i)) + }), 2) as "res") + + checkRule( + query(testStructRelation), + testStructRelation.select(Literal(2) as "res")) + + checkRule( + query(testNullableStructRelation), + testNullableStructRelation.select( + If(IsNull('struct1), Literal(null, IntegerType), Literal(2)) as "res")) + } + + test("Combine multiple UpdateFields and extract newly added field") { + def query(relation: LocalRelation): LogicalPlan = relation.select(GetStructField( + UpdateFields( + UpdateFields( + UpdateFields( + UpdateFields( + 'struct1, + Seq(WithField("col1", Literal(1)))), + Seq(WithField("col2", Literal(2)))), + Seq(WithField("col3", Literal(3)))), + Seq(WithField("col4", Literal(4)))), + 5, Some("col4")) as "res") + + checkRule( + query(testStructRelation), + testStructRelation.select(Literal(4) as "res")) + + checkRule( + query(testNullableStructRelation), + testNullableStructRelation.select( + If(IsNull('struct1), Literal(null, IntegerType), Literal(4)) as "res")) + } + + test("should correctly handle different WithField + DropField + GetStructField combinations") { + def check(fieldOps: Seq[StructFieldsOperation], ordinal: Int, expected: Expression): Unit = { + def query(relation: LocalRelation): LogicalPlan = + relation.select(GetStructField(UpdateFields('struct1, fieldOps), ordinal).as("res")) + + checkRule( + query(testStructRelation), + testStructRelation.select(expected.as("res"))) + + checkRule( + query(testNullableStructRelation), + testNullableStructRelation.select((expected match { + case expr: GetStructField if expr.child.fastEquals('struct1.expr) => expr + case expr => If(IsNull('struct1), Literal(null, expr.dataType), expr) + }).as("res"))) + } + + // add attribute, extract an attribute from the original struct + check(WithField("c", Literal(3)) :: Nil, 0, GetStructField('struct1, 0)) + + // drop attribute, extract an attribute from the original struct + check(DropField("b") :: Nil, 0, GetStructField('struct1, 0)) + check(DropField("a") :: Nil, 0, GetStructField('struct1, 1)) + + // drop attribute, add attribute, extract an attribute from the original struct + check(DropField("b") :: WithField("c", Literal(2)) :: Nil, 0, GetStructField('struct1, 0)) + check(DropField("a") :: WithField("c", Literal(2)) :: Nil, 0, GetStructField('struct1, 1)) + + // add attribute, drop attribute, extract an attribute from the original struct + check(WithField("c", Literal(3)) :: DropField("a") :: Nil, 0, GetStructField('struct1, 1)) + check(WithField("c", Literal(3)) :: DropField("b") :: Nil, 0, GetStructField('struct1, 0)) + check(WithField("b", Literal(3)) :: DropField("b") :: Nil, 0, GetStructField('struct1, 0)) + check(WithField("a", Literal(3)) :: DropField("a") :: Nil, 0, GetStructField('struct1, 1)) + + // add attribute, drop same attribute, extract an attribute from the original struct + check(WithField("c", Literal(3)) :: DropField("c") :: Nil, 0, GetStructField('struct1, 0)) + + // add attribute, drop another attribute, extract added attribute + check(WithField("b", Literal(3)) :: DropField("a") :: Nil, 0, Literal(3)) + check(WithField("a", Literal(3)) :: DropField("b") :: Nil, 0, Literal(3)) + + // drop attribute, add same attribute, extract added attribute + check(DropField("a") :: WithField("c", Literal(3)) :: Nil, 
1, Literal(3)) + check(DropField("b") :: WithField("c", Literal(3)) :: Nil, 1, Literal(3)) + check(DropField("b") :: WithField("b", Literal(3)) :: Nil, 1, Literal(3)) + check(DropField("a") :: WithField("a", Literal(3)) :: Nil, 1, Literal(3)) + check(DropField("c") :: WithField("c", Literal(3)) :: Nil, 2, Literal(3)) + check(WithField("c", Literal(3)) :: DropField("c") :: WithField("c", Literal(4)) :: Nil, 2, + Literal(4)) + + // drop earlier attribute, add attribute, extract added attribute + check(DropField("a") :: WithField("c", Literal(3)) :: Nil, 1, Literal(3)) + + // drop later attribute, add attribute, extract added attribute + check(DropField("b") :: WithField("c", Literal(3)) :: Nil, 1, Literal(3)) + } + + test("should only transform GetStructField calls if not GetStructField on same struct") { + val struct2 = 'struct2.struct('b.int) + val testStructRelation = LocalRelation(structAttr, struct2) + val testNullableStructRelation = LocalRelation(nullableStructAttr, struct2) + + def addFieldFromAnotherStructAndThenExtractIt(relation: LocalRelation): LogicalPlan = + relation.select(GetStructField( + UpdateFields('struct1, Seq(WithField("b", GetStructField('struct2, 0)))), 1).as("res")) + + checkRule( + addFieldFromAnotherStructAndThenExtractIt(testStructRelation), + testStructRelation.select(GetStructField('struct2, 0).as("res"))) + + checkRule( + addFieldFromAnotherStructAndThenExtractIt(testNullableStructRelation), + testNullableStructRelation.select( + If(IsNull('struct1), Literal(null, IntegerType), GetStructField('struct2, 0)).as("res"))) + + def addFieldFromSameStructAndThenExtractIt(relation: LocalRelation): LogicalPlan = + relation.select(GetStructField( + UpdateFields('struct1, Seq(WithField("b", GetStructField('struct1, 0)))), 1).as("res")) + + checkRule( + addFieldFromSameStructAndThenExtractIt(testStructRelation), + testStructRelation.select(GetStructField('struct1, 0).as("res"))) + + checkRule( + addFieldFromSameStructAndThenExtractIt(testNullableStructRelation), + testNullableStructRelation.select(GetStructField('struct1, 0).as("res"))) + } + + test("simplify if extract value is from inside struct") { + val nullableStructAttr = 'struct1a.struct( + 'struct2a.struct( + 'struct3a.struct('a.int, 'b.int))) + val structAttr = nullableStructAttr.withNullability(false) + val testStructRelation = LocalRelation(structAttr) + val testNullableStructRelation = LocalRelation(nullableStructAttr) + + Seq(testStructRelation, testNullableStructRelation).foreach { relation => + checkRule( + relation.select(GetStructField( + UpdateFields('struct1a, Seq( + WithField("struct2b", GetStructField(GetStructField('struct1a, 0), 0)))), 1).as("res")), + relation.select(GetStructField(GetStructField('struct1a, 0), 0).as("res"))) + } + } + + test("simplify add multiple nested fields") { + // this scenario is possible if users add multiple nested columns via the Column.withField API + // ideally, users should not be doing this. 
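+    // (each nested withField call re-extracts the inner struct, which is why the optimized
+    // plan in `expected` below still carries two WithField("a2", ...) operations instead of
+    // the single combined one in `idealExpected`)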
+ val nullableStructLevel2 = LocalRelation( + 'a1.struct( + 'a2.struct('a3.int)).withNullability(false)) + + val addB3toA1A2 = UpdateFields('a1, Seq(WithField("a2", + UpdateFields(GetStructField('a1, 0), Seq(WithField("b3", Literal(2))))))) + + val query = nullableStructLevel2.select( + UpdateFields( + addB3toA1A2, + Seq(WithField("a2", UpdateFields( + GetStructField(addB3toA1A2, 0), Seq(WithField("c3", Literal(3))))))).as("a1")) + + val expected = nullableStructLevel2.select( + UpdateFields('a1, Seq( + WithField("a2", UpdateFields(GetStructField('a1, 0), Seq(WithField("b3", 2)))), + WithField("a2", UpdateFields(GetStructField('a1, 0), Seq( + WithField("b3", 2), + WithField("c3", 3)))) + )).as("a1")) + + // TODO: how to make this a reality + val idealExpected = nullableStructLevel2.select( + UpdateFields('a1, Seq( + WithField("a2", UpdateFields(GetStructField('a1, 0), Seq( + WithField("b3", 2), + WithField("c3", 3)))) + )).as("a1")) + + checkRule(query, expected) + } + + test("simplify drop multiple nested fields") { + // this scenario is possible if users drop multiple nested columns via the Column.dropFields API + // ideally, users should not be doing this. + val df = LocalRelation( + 'a1.struct( + 'a2.struct('a3.int, 'b3.int, 'c3.int).withNullability(false) + ).withNullability(false)) + + val dropA1A2B = UpdateFields('a1, Seq(WithField("a2", UpdateFields( + GetStructField('a1, 0), Seq(DropField("b3")))))) + + val query = df.select( + UpdateFields( + dropA1A2B, + Seq(WithField("a2", UpdateFields( + GetStructField(dropA1A2B, 0), Seq(DropField("c3")))))).as("a1")) + + val expected = df.select( + UpdateFields('a1, Seq( + WithField("a2", UpdateFields(GetStructField('a1, 0), Seq(DropField("b3")))), + WithField("a2", UpdateFields(GetStructField('a1, 0), Seq(DropField("b3"), DropField("c3")))) + )).as("a1")) + + // TODO: how to make this a reality + val idealExpected = df.select( + UpdateFields('a1, Seq( + WithField("a2", UpdateFields(GetStructField('a1, 0), Seq(DropField("b3"), DropField("c3")))) + )).as("a1")) + + checkRule(query, expected) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index da542c67d9c5..f8daaf9c58c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -901,6 +901,23 @@ class Column(val expr: Expression) extends Logging { * // result: org.apache.spark.sql.AnalysisException: Ambiguous reference to fields * }}} * + * This method supports adding/replacing nested fields directly e.g. + * + * {{{ + * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col") + * df.select($"struct_col".withField("a.c", lit(3)).withField("a.d", lit(4))) + * // result: {"a":{"a":1,"b":2,"c":3,"d":4}} + * }}} + * + * However, if you are going to add/replace multiple nested fields, it is more optimal to extract + * out the nested struct before adding/replacing multiple fields e.g. 
+ * + * {{{ + * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col") + * df.select($"struct_col".withField("a", $"struct_col.a".withField("c", lit(3)).withField("d", lit(4)))) + * // result: {"a":{"a":1,"b":2,"c":3,"d":4}} + * }}} + * * @group expr_ops * @since 3.1.0 */ @@ -908,32 +925,101 @@ class Column(val expr: Expression) extends Logging { def withField(fieldName: String, col: Column): Column = withExpr { require(fieldName != null, "fieldName cannot be null") require(col != null, "col cannot be null") + updateFieldsHelper(expr, nameParts(fieldName), name => WithField(name, col.expr)) + } - val nameParts = if (fieldName.isEmpty) { + // scalastyle:off line.size.limit + /** + * An expression that drops fields in `StructType` by name. + * + * {{{ + * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col") + * df.select($"struct_col".dropFields("b")) + * // result: {"a":1} + * + * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col") + * df.select($"struct_col".dropFields("c")) + * // result: {"a":1,"b":2} + * + * val df = sql("SELECT named_struct('a', 1, 'b', 2, 'c', 3) struct_col") + * df.select($"struct_col".dropFields("b", "c")) + * // result: {"a":1} + * + * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col") + * df.select($"struct_col".dropFields("a", "b")) + * // result: org.apache.spark.sql.AnalysisException: cannot resolve 'update_fields(update_fields(`struct_col`))' due to data type mismatch: cannot drop all fields in struct + * + * val df = sql("SELECT CAST(NULL AS struct) struct_col") + * df.select($"struct_col".dropFields("b")) + * // result: null of type struct + * + * val df = sql("SELECT named_struct('a', 1, 'b', 2, 'b', 3) struct_col") + * df.select($"struct_col".dropFields("b")) + * // result: {"a":1} + * + * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col") + * df.select($"struct_col".dropFields("a.b")) + * // result: {"a":{"a":1}} + * + * val df = sql("SELECT named_struct('a', named_struct('b', 1), 'a', named_struct('c', 2)) struct_col") + * df.select($"struct_col".dropFields("a.c")) + * // result: org.apache.spark.sql.AnalysisException: Ambiguous reference to fields + * }}} + * + * This method supports dropping multiple nested fields directly e.g. + * + * {{{ + * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col") + * df.select($"struct_col".dropFields("a.b", "a.c")) + * // result: {"a":{"a":1}} + * }}} + * + * However, if you are going to drop multiple nested fields, it is more optimal to extract + * out the nested struct before dropping multiple fields from it e.g. 
+ * + * {{{ + * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col") + * df.select($"struct_col".withField("a", $"struct_col.a".dropFields("b", "c"))) + * // result: {"a":{"a":1}} + * }}} + * + * @group expr_ops + * @since 3.1.0 + */ + // scalastyle:on line.size.limit + def dropFields(fieldNames: String*): Column = withExpr { + def dropField(structExpr: Expression, fieldName: String): UpdateFields = + updateFieldsHelper(structExpr, nameParts(fieldName), name => DropField(name)) + + fieldNames.tail.foldLeft(dropField(expr, fieldNames.head)) { + (resExpr, fieldName) => dropField(resExpr, fieldName) + } + } + + private def nameParts(fieldName: String): Seq[String] = { + require(fieldName != null, "fieldName cannot be null") + + if (fieldName.isEmpty) { fieldName :: Nil } else { CatalystSqlParser.parseMultipartIdentifier(fieldName) } - withFieldHelper(expr, nameParts, Nil, col.expr) } - private def withFieldHelper( - struct: Expression, - namePartsRemaining: Seq[String], - namePartsDone: Seq[String], - value: Expression) : WithFields = { - val name = namePartsRemaining.head + private def updateFieldsHelper( + structExpr: Expression, + namePartsRemaining: Seq[String], + valueFunc: String => StructFieldsOperation): UpdateFields = { + + val fieldName = namePartsRemaining.head if (namePartsRemaining.length == 1) { - WithFields(struct, name :: Nil, value :: Nil) + UpdateFields(structExpr, valueFunc(fieldName) :: Nil) } else { - val newNamesRemaining = namePartsRemaining.tail - val newNamesDone = namePartsDone :+ name - val newValue = withFieldHelper( - struct = UnresolvedExtractValue(struct, Literal(name)), - namePartsRemaining = newNamesRemaining, - namePartsDone = newNamesDone, - value = value) - WithFields(struct, name :: Nil, newValue :: Nil) + val newValue = updateFieldsHelper( + structExpr = UnresolvedExtractValue(structExpr, Literal(fieldName)), + namePartsRemaining = namePartsRemaining.tail, + valueFunc = valueFunc) + UpdateFields(structExpr, WithField(fieldName, newValue) :: Nil) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 24419968c047..48872de9ad43 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -927,6 +927,9 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { expectedAnswer: Seq[Row], expectedSchema: StructType): Unit = { + df.explain(true) + df.printSchema() + df.show(false) checkAnswer(df, expectedAnswer) assert(df.schema == expectedSchema) } @@ -967,6 +970,15 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false))), nullable = false)))) + private lazy val nullStructLevel3: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(Row(Row(1, null, 3)))) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", structType, nullable = true))), + nullable = true))), + nullable = false)))) + test("withField should throw an exception if called on a non-StructType column") { intercept[AnalysisException] { testData.withColumn("key", $"key".withField("a", lit(2))) @@ -1451,6 +1463,17 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { sql("SELECT named_struct('a', named_struct('b', 1), 'a', named_struct('c', 2)) struct_col") .select($"struct_col".withField("a.c", lit(3))) 
}.getMessage should include("Ambiguous reference to fields") + + checkAnswer( + sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col") + .select($"struct_col".withField("a.c", lit(3)).withField("a.d", lit(4))), + Row(Row(Row(1, 2, 3, 4)))) + + checkAnswer( + sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col") + .select($"struct_col".withField("a", + $"struct_col.a".withField("c", lit(3)).withField("d", lit(4)))), + Row(Row(Row(1, 2, 3, 4)))) } test("SPARK-32641: extracting field from non-null struct column after withField should return " + @@ -1537,4 +1560,726 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { Row(3) :: Row(null):: Nil, StructType(Seq(StructField("a", IntegerType, nullable = true)))) } + + + test("dropFields should throw an exception if called on a non-StructType column") { + intercept[AnalysisException] { + testData.withColumn("key", $"key".dropFields("a")) + }.getMessage should include("struct argument should be struct type, got: int") + } + + test("dropFields should throw an exception if fieldName argument is null") { + intercept[IllegalArgumentException] { + structLevel1.withColumn("a", $"a".dropFields(null)) + }.getMessage should include("fieldName cannot be null") + } + + test("dropFields should throw an exception if any intermediate structs don't exist") { + intercept[AnalysisException] { + structLevel2.withColumn("a", 'a.dropFields("x.b")) + }.getMessage should include("No such struct field x in a") + + intercept[AnalysisException] { + structLevel3.withColumn("a", 'a.dropFields("a.x.b")) + }.getMessage should include("No such struct field x in a") + } + + test("dropFields should throw an exception if intermediate field is not a struct") { + intercept[AnalysisException] { + structLevel1.withColumn("a", 'a.dropFields("b.a")) + }.getMessage should include("struct argument should be struct type, got: int") + } + + test("dropFields should throw an exception if intermediate field reference is ambiguous") { + intercept[AnalysisException] { + val structLevel2: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(Row(1, null, 3), 4)) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", structType, nullable = false), + StructField("a", structType, nullable = false))), + nullable = false)))) + + structLevel2.withColumn("a", 'a.dropFields("a.b")) + }.getMessage should include("Ambiguous reference to fields") + } + + test("dropFields should drop field in struct") { + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.dropFields("b")), + Row(Row(1, 3)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = false)))) + } + + test("dropFields should drop field in null struct") { + checkAnswerAndSchema( + nullStructLevel1.withColumn("a", $"a".dropFields("b")), + Row(null) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = true)))) + } + + test("dropFields should drop multiple fields in struct") { + Seq( + structLevel1.withColumn("a", $"a".dropFields("b", "c")), + structLevel1.withColumn("a", 'a.dropFields("b").dropFields("c")) + ).foreach { df => + checkAnswerAndSchema( + df, + Row(Row(1)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false))), + nullable = false)))) + 
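+      // note: both forms above should collapse to a single UpdateFields expression once the
+      // CombineUpdateFields rule merges the adjacent UpdateFields produced by chained calls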
} + } + + test("dropFields should throw an exception if no fields will be left in struct") { + intercept[AnalysisException] { + structLevel1.withColumn("a", 'a.dropFields("a", "b", "c")) + }.getMessage should include("cannot drop all fields in struct") + } + + test("dropFields should drop field in nested struct") { + checkAnswerAndSchema( + structLevel2.withColumn("a", 'a.dropFields("a.b")), + Row(Row(Row(1, 3))) :: Nil, + StructType( + Seq(StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + } + + test("dropFields should drop multiple fields in nested struct") { + checkAnswerAndSchema( + structLevel2.withColumn("a", 'a.dropFields("a.b", "a.c")), + Row(Row(Row(1))) :: Nil, + StructType( + Seq(StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + } + + test("dropFields should drop field in nested null struct") { + checkAnswerAndSchema( + nullStructLevel2.withColumn("a", $"a".dropFields("a.b")), + Row(Row(null)) :: Nil, + StructType( + Seq(StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = true))), + nullable = false)))) + } + + test("dropFields should drop multiple fields in nested null struct") { + checkAnswerAndSchema( + nullStructLevel2.withColumn("a", $"a".dropFields("a.b", "a.c")), + Row(Row(null)) :: Nil, + StructType( + Seq(StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false))), + nullable = true))), + nullable = false)))) + } + + test("dropFields should drop field in deeply nested struct") { + checkAnswerAndSchema( + structLevel3.withColumn("a", 'a.dropFields("a.a.b")), + Row(Row(Row(Row(1, 3)))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = false))), + nullable = false))), + nullable = false)))) + } + + test("dropFields should drop all fields with given name in struct") { + val structLevel1 = spark.createDataFrame( + sparkContext.parallelize(Row(Row(1, 2, 3)) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.dropFields("b")), + Row(Row(1)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false))), + nullable = false)))) + } + + test("dropFields should drop field in struct even if casing is different") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkAnswerAndSchema( + mixedCaseStructLevel1.withColumn("a", 'a.dropFields("A")), + Row(Row(1)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("B", IntegerType, nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + mixedCaseStructLevel1.withColumn("a", 'a.dropFields("b")), + Row(Row(1)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false))), + nullable = 
false)))) + } + } + + test("dropFields should not drop field in struct because casing is different") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + checkAnswerAndSchema( + mixedCaseStructLevel1.withColumn("a", 'a.dropFields("A")), + Row(Row(1, 1)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("B", IntegerType, nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + mixedCaseStructLevel1.withColumn("a", 'a.dropFields("b")), + Row(Row(1, 1)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("B", IntegerType, nullable = false))), + nullable = false)))) + } + } + + test("dropFields should drop nested field in struct even if casing is different") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkAnswerAndSchema( + mixedCaseStructLevel2.withColumn("a", 'a.dropFields("A.a")), + Row(Row(Row(1), Row(1, 1))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("A", StructType(Seq( + StructField("b", IntegerType, nullable = false))), + nullable = false), + StructField("B", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + mixedCaseStructLevel2.withColumn("a", 'a.dropFields("b.a")), + Row(Row(Row(1, 1), Row(1))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false), + StructField("b", StructType(Seq( + StructField("b", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + } + } + + test("dropFields should throw an exception because casing is different") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + intercept[AnalysisException] { + mixedCaseStructLevel2.withColumn("a", 'a.dropFields("A.a")) + }.getMessage should include("No such struct field A in a, B") + + intercept[AnalysisException] { + mixedCaseStructLevel2.withColumn("a", 'a.dropFields("b.a")) + }.getMessage should include("No such struct field b in a, B") + } + } + + test("dropFields should drop only fields that exist") { + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.dropFields("d")), + Row(Row(1, null, 3)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + structLevel1.withColumn("a", 'a.dropFields("b", "d")), + Row(Row(1, 3)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + structLevel2.withColumn("a", $"a".dropFields("a.b", "a.d")), + Row(Row(Row(1, 3))) :: Nil, + StructType( + Seq(StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + } + + test("dropFields should drop multiple fields at arbitrary levels of nesting in a single call") { + val df: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(Row(1, null, 3), 4)) 
:: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", structType, nullable = false), + StructField("b", IntegerType, nullable = false))), + nullable = false)))) + + checkAnswerAndSchema( + df.withColumn("a", $"a".dropFields("a.b", "b")), + Row(Row(Row(1, 3))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("c", IntegerType, nullable = false))), nullable = false))), + nullable = false)))) + } + + test("dropFields user-facing examples") { + checkAnswer( + sql("SELECT named_struct('a', 1, 'b', 2) struct_col") + .select($"struct_col".dropFields("b")), + Row(Row(1))) + + checkAnswer( + sql("SELECT named_struct('a', 1, 'b', 2) struct_col") + .select($"struct_col".dropFields("c")), + Row(Row(1, 2))) + + checkAnswer( + sql("SELECT named_struct('a', 1, 'b', 2, 'c', 3) struct_col") + .select($"struct_col".dropFields("b", "c")), + Row(Row(1))) + + intercept[AnalysisException] { + sql("SELECT named_struct('a', 1, 'b', 2) struct_col") + .select($"struct_col".dropFields("a", "b")) + }.getMessage should include("cannot drop all fields in struct") + + checkAnswer( + sql("SELECT CAST(NULL AS struct) struct_col") + .select($"struct_col".dropFields("b")), + Row(null)) + + checkAnswer( + sql("SELECT named_struct('a', 1, 'b', 2, 'b', 3) struct_col") + .select($"struct_col".dropFields("b")), + Row(Row(1))) + + checkAnswer( + sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col") + .select($"struct_col".dropFields("a.b")), + Row(Row(Row(1)))) + + intercept[AnalysisException] { + sql("SELECT named_struct('a', named_struct('b', 1), 'a', named_struct('c', 2)) struct_col") + .select($"struct_col".dropFields("a.c")) + }.getMessage should include("Ambiguous reference to fields") + + checkAnswer( + sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2, 'c', 3)) struct_col") + .select($"struct_col".dropFields("a.b", "a.c")), + Row(Row(Row(1)))) + + checkAnswer( + sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2, 'c', 3)) struct_col") + .select($"struct_col".withField("a", $"struct_col.a".dropFields("b", "c"))), + Row(Row(Row(1)))) + } + + test("should correctly handle different dropField + withField + getField combinations") { + // TODO: maybe make another suite, where we run through these tests with and without optimizer + val structType = StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false))) + + val structLevel1: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(1, 2)) :: Nil), + StructType(Seq(StructField("a", structType, nullable = false)))) + + val nullStructLevel1: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(null) :: Nil), + StructType(Seq(StructField("a", structType, nullable = true)))) + + val nullableStructLevel1: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(1, 2)) :: Row(null) :: Nil), + StructType(Seq(StructField("a", structType, nullable = true)))) + + def check( + fieldOps: Column => Column, + getFieldName: String, + expectedValue: Option[Int]): Unit = { + + def query(df: DataFrame): DataFrame = + df.select(fieldOps(col("a")).getField(getFieldName).as("res")) + + checkAnswerAndSchema( + query(structLevel1), + Row(expectedValue.orNull) :: Nil, + StructType(Seq(StructField("res", IntegerType, nullable = expectedValue.isEmpty)))) + + checkAnswerAndSchema( + query(nullStructLevel1), + Row(null) :: Nil, + 
StructType(Seq(StructField("res", IntegerType, nullable = true)))) + + checkAnswerAndSchema( + query(nullableStructLevel1), + Row(expectedValue.orNull) :: Row(null) :: Nil, + StructType(Seq(StructField("res", IntegerType, nullable = true)))) + } + + // add attribute, extract an attribute from the original struct + check(_.withField("c", lit(3)), "a", Some(1)) + check(_.withField("c", lit(3)), "b", Some(2)) + + // add attribute, extract added attribute + check(_.withField("c", lit(3)), "c", Some(3)) + check(_.withField("c", col("a.a")), "c", Some(1)) + check(_.withField("c", col("a.b")), "c", Some(2)) + check(_.withField("c", lit(null).cast(IntegerType)), "c", None) + + // replace attribute, extract an attribute from the original struct + check(_.withField("b", lit(3)), "a", Some(1)) + + // replace attribute, extract replaced attribute + check(_.withField("b", lit(3)), "b", Some(3)) + check(_.withField("b", lit(null).cast(IntegerType)), "b", None) + + // drop attribute, extract an attribute from the original struct + check(_.dropFields("b"), "a", Some(1)) + check(_.dropFields("a"), "b", Some(2)) + + // drop attribute, add attribute, extract an attribute from the original struct + check(_.dropFields("b").withField("c", lit(3)), "a", Some(1)) + check(_.dropFields("a").withField("c", lit(3)), "b", Some(2)) + + // add attribute, drop attribute, extract an attribute from the original struct + check(_.withField("c", lit(3)).dropFields("a"), "b", Some(2)) + check(_.withField("c", lit(3)).dropFields("b"), "a", Some(1)) + check(_.withField("b", lit(3)).dropFields("b"), "a", Some(1)) + check(_.withField("a", lit(3)).dropFields("a"), "b", Some(2)) + + // add attribute, drop same attribute, extract an attribute from the original struct + check(_.withField("c", lit(3)).dropFields("c"), "a", Some(1)) + + // add attribute, drop another attribute, extract added attribute + check(_.withField("b", lit(3)).dropFields("a"), "b", Some(3)) + check(_.withField("a", lit(3)).dropFields("b"), "a", Some(3)) + check(_.withField("b", lit(null).cast(IntegerType)).dropFields("a"), "b", None) + check(_.withField("a", lit(null).cast(IntegerType)).dropFields("b"), "a", None) + + // drop attribute, add same attribute, extract added attribute + check(_.dropFields("a").withField("c", lit(3)), "c", Some(3)) + check(_.dropFields("b").withField("c", lit(3)), "c", Some(3)) + check(_.dropFields("b").withField("b", lit(3)), "b", Some(3)) + check(_.dropFields("a").withField("a", lit(3)), "a", Some(3)) + check(_.dropFields("b").withField("b", lit(null).cast(IntegerType)), "b", None) + check(_.dropFields("a").withField("a", lit(null).cast(IntegerType)), "a", None) + check(_.dropFields("c").withField("c", lit(3)), "c", Some(3)) + check(_.withField("c", lit(3)).dropFields("c").withField("c", lit(4)), "c", Some(4)) + } + + test("non-nullable field from a struct being added to a non-nullable struct and then extracted") { + val df = spark.createDataFrame( + sparkContext.parallelize(Seq(Row(Row(1, 2, 3), Row(4, 5, 6)))), + StructType(Seq( + StructField("a1", structType, nullable = false), + StructField("a2", structType, nullable = false)))) + + // add field from the same struct (a1) and then extract it + checkAnswerAndSchema( + df.select($"a1".withField("d", $"a1.b").getField("d").as("res")), + Seq(Row(2)), + StructType(Seq(StructField("res", IntegerType, nullable = true)))) + + // add field from another struct (a2) and then extract it + checkAnswerAndSchema( + df.select($"a1".withField("d", $"a2.b").getField("d").as("res")), + Seq(Row(5)), 
+ StructType(Seq(StructField("res", IntegerType, nullable = true)))) + } + + test("nullable field from a struct being added to a nullable struct and then extracted") { + val df = spark.createDataFrame( + sparkContext.parallelize(Seq( + Row(Row(1, 2, 3), Row(4, 5, 6)), + Row(null, Row(4, 5, 6)), + Row(Row(1, 2, 3), null), + Row(null, null))), + StructType(Seq( + StructField("a1", structType, nullable = true), + StructField("a2", structType, nullable = true)))) + + // TODO: This doesn't fit the test title? + // add field from the same struct (a1) and then extract it + checkAnswerAndSchema( + df.select($"a1".withField("d", $"a1.b").getField("d").as("res")), + Seq( + Row(2), + Row(null), + Row(2), + Row(null)), + StructType(Seq(StructField("res", IntegerType, nullable = true)))) + + // add field from another struct (a2) and then extract it + checkAnswerAndSchema( + df.select($"a1".withField("d", $"a2.b").getField("d").as("res")), + Seq( + Row(5), + Row(null), + Row(null), + Row(null)), + StructType(Seq(StructField("res", IntegerType, nullable = true)))) + } + + test("non-nullable field from a struct being added to a nullable struct and then extracted") { + val df = spark.createDataFrame( + sparkContext.parallelize(Seq( + Row(Row(1, 2, 3), Row(4, 5, 6)), + Row(null, Row(4, 5, 6)), + Row(Row(1, 2, 3), Row(4, 5, 6)), + Row(null, Row(4, 5, 6)))), + StructType(Seq( + StructField("a1", structType, nullable = true), + StructField("a2", structType, nullable = false)))) + + // add field from the same struct (a1) and then extract it + checkAnswerAndSchema( + df.select($"a1".withField("d", $"a1.b").getField("d").as("res")), + Seq( + Row(2), + Row(null), + Row(2), + Row(null)), + StructType(Seq(StructField("res", IntegerType, nullable = true)))) + + // add field from another struct (a2) and then extract it + checkAnswerAndSchema( + df.select($"a1".withField("d", $"a2.b").getField("d").as("res")), + Seq( + Row(5), + Row(5), + Row(null), + Row(null)), + StructType(Seq(StructField("res", IntegerType, nullable = true)))) + } + + test("nullable field from a struct being added to a non-nullable struct and then extracted") { + val df = spark.createDataFrame( + sparkContext.parallelize(Seq( + Row(Row(1, 2, 3), Row(4, 5, 6)), + Row(Row(1, 2, 3), Row(4, 5, 6)), + Row(Row(1, 2, 3), null), + Row(Row(1, 2, 3), null))), + StructType(Seq( + StructField("a1", structType, nullable = false), + StructField("a2", structType, nullable = true)))) + + // add field from the same struct (a1) and then extract it + checkAnswerAndSchema( + df.select($"a1".withField("d", $"a1.b").getField("d").as("res")), + Seq( + Row(2), + Row(2), + Row(2), + Row(2)), + StructType(Seq(StructField("res", IntegerType, nullable = true)))) + + // add field from another struct (a2) and then extract it + checkAnswerAndSchema( + df.select($"a1".withField("d", $"a2.b").getField("d").as("res")), + Seq( + Row(5), + Row(5), + Row(null), + Row(null)), + StructType(Seq(StructField("res", IntegerType, nullable = true)))) + } + + test("should add and drop many fields") { + // stress test withField and dropFields APIs + + val nullableStructLevel2: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(null)) :: Row(Row(Row(0))) :: Nil), + StructType(Seq( + StructField("a1", StructType(Seq( + StructField("a2", StructType(Seq( + StructField("col0", IntegerType, nullable = false))), + nullable = true))), + nullable = false)))) + + val maxNum = 100 + + // add many nested fields + checkAnswerAndSchema( + 
nullableStructLevel2.select(col("a1").withField("a2", (1 to maxNum).foldLeft(col("a1.a2")) { + (column, num) => column.withField(s"col$num", lit(num)) + }).as("a1")), + Row(Row(null)) :: Row(Row(Row(0 to maxNum: _*))) :: Nil, + StructType(Seq( + StructField("a1", StructType(Seq( + StructField("a2", StructType((0 to maxNum).map(num => + StructField(s"col$num", IntegerType, nullable = false))), + nullable = true))), + nullable = false)))) + + /** + * You might be tempted to write the above query like so: + * + * {{{ + * nullableStructLevel2.select((1 to maxNum).foldLeft(col("a1")) { + * (column, num) => column.withField(s"a2.col$num", lit(num)) + * }.as("a1")), + * }}} + * + * This query leverages Column.withField API's ability to add/replace nested columns directly. + * However this will likely stall at the analyzer phase with as little as `maxNum = 3`, + * depending on how deeply nested the Column is that you're adding/replacing. + * If you're going to add multiple nested columns, you are better off adding them at the level + * of nesting that you want to add them. + */ + + // add and drop many nested fields + checkAnswerAndSchema( + nullableStructLevel2.select(col("a1").withField("a2", (1 to maxNum).foldLeft(col("a1.a2")) { + (column, num) => column.withField(s"col$num", lit(num)).dropFields(s"col${num - 1}") + }).as("a1")), + Row(Row(null)) :: Row(Row(Row(100))) :: Nil, + StructType(Seq( + StructField("a1", StructType(Seq( + StructField("a2", StructType(Seq( + StructField("col100", IntegerType, nullable = false))), + nullable = true))), + nullable = false)))) + + // add and drop many fields and then extract one of the newly added fields + checkAnswerAndSchema( + nullableStructLevel2.select(col("a1").withField("a2", (1 to maxNum).foldLeft(col("a1.a2")) { + (column, num) => column.withField(s"col$num", lit(num)).dropFields(s"col${num - 1}") + }).getField("a2").getField(s"col$maxNum").as("res")), + Row(null) :: Row(100) :: Nil, + StructType(Seq(StructField("res", IntegerType, nullable = true)))) + + // TODO: add many fields at many different depths of nesting + } + + test("should move field up one level of nesting") { + val nullableStructLevel2: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(null)) :: Row(Row(Row(1, 2, 3))) :: Nil), + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", structType, nullable = true))), + nullable = true)))) + + // move a field up one level + checkAnswerAndSchema( + nullableStructLevel2.select( + col("a").withField("b", col("a.a.b")).dropFields("a.b").as("res")), + Row(Row(null, null)) :: Row(Row(Row(1, 3), 2)) :: Nil, + StructType(Seq( + StructField("res", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = true), + StructField("b", IntegerType, nullable = true))), + nullable = true)))) + + // move a field up one level and then extract it + checkAnswerAndSchema( + nullableStructLevel2.select(col("a").withField("b", col("a.a.b")).getField("b").as("res")), + Row(null) :: Row(2) :: Nil, + StructType(Seq(StructField("res", IntegerType, nullable = true)))) + } + + test("should be able to refer to newly added nested column") { + intercept[AnalysisException] { + structLevel1.select($"a".withField("d", lit(4)).withField("e", $"a.d" + 1).as("a")) + }.getMessage should include("No such struct field d in a, b, c") + + checkAnswerAndSchema( + structLevel1 + .select($"a".withField("d", lit(4)).as("a")) + 
.select($"a".withField("e", $"a.d" + 1).as("a")), + Row(Row(1, null, 3, 4, 5)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = false), + StructField("e", IntegerType, nullable = false))), + nullable = false)))) + } + + test("should be able to drop newly added nested column") { + Seq( + structLevel1.select($"a".withField("d", lit(4)).dropFields("d").as("a")), + structLevel1 + .select($"a".withField("d", lit(4)).as("a")) + .select($"a".dropFields("d").as("a")) + ).foreach { query => + checkAnswerAndSchema( + query, + Row(Row(1, null, 3)) :: Nil, + StructType(Seq( + StructField("a", structType, nullable = false)))) + } + } + + test("should still be able to refer to dropped column within the same select statement") { + // we can still access the nested column even after dropping it within the select statement + checkAnswerAndSchema( + structLevel1.select($"a".dropFields("c").withField("z", $"a.c").as("a")), + Row(Row(1, null, 3)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("z", IntegerType, nullable = false))), + nullable = false)))) + + // we can't access the nested column in another select statement after dropping it in a previous + // select statement + intercept[AnalysisException]{ + structLevel1 + .select($"a".dropFields("c").as("a")) + .select($"a".withField("z", $"a.c")).as("a") + }.getMessage should include("No such struct field c in a, b;") + } } From 2f16213ea0a658f481f8f745d759a813fc844dfc Mon Sep 17 00:00:00 2001 From: "fqaiser94@gmail.com" Date: Sat, 19 Sep 2020 17:01:07 -0400 Subject: [PATCH 02/18] clean up --- .../expressions/complexTypeCreator.scala | 2 +- .../sql/catalyst/optimizer/ComplexTypes.scala | 19 +- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../optimizer/complexTypesSuite.scala | 383 ++++++++---------- .../spark/sql/ColumnExpressionSuite.scala | 356 ++++------------ .../org/apache/spark/sql/QueryTest.scala | 9 + .../sql/UpdateFieldsPerformanceSuite.scala | 229 +++++++++++ 7 files changed, 501 insertions(+), 499 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsPerformanceSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index c7f02986ddaa..7ca625129bbf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -594,7 +594,7 @@ case class DropField(name: String) extends StructFieldsOperation { } /** - * Updates fields in struct by name. + * Updates fields in a struct. 
*/ case class UpdateFields(structExpr: Expression, fieldOps: Seq[StructFieldsOperation]) extends Unevaluable { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala index aa992dc8d236..7548861e716b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala @@ -41,14 +41,12 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] { case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) => createNamedStruct.valExprs(ordinal) case GetStructField(updateFields: UpdateFields, ordinal, _) => - val expr = updateFields.newExprs(ordinal) val structExpr = updateFields.structExpr - if (isExprNestedInsideStruct(expr, structExpr)) { + updateFields.newExprs(ordinal) match { // if the struct itself is null, then any value extracted from it (expr) will be null // so we don't need to wrap expr in If(IsNull(struct), Literal(null, expr.dataType), expr) - expr - } else { - If(IsNull(ultimateStruct(structExpr)), Literal(null, expr.dataType), expr) + case expr: GetStructField if expr.child.semanticEquals(structExpr) => expr + case expr => If(IsNull(ultimateStruct(structExpr)), Literal(null, expr.dataType), expr) } // Remove redundant array indexing. case GetArrayStructFields(CreateArray(elems, useStringTypeWhenEmpty), field, ordinal, _, _) => @@ -77,15 +75,4 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] { case e: UpdateFields => ultimateStruct(e.structExpr) case e => e } - - @scala.annotation.tailrec - private def isExprNestedInsideStruct(expr: Expression, struct: Expression): Boolean = { - require(struct.dataType.isInstanceOf[StructType]) - - expr match { - case e: GetStructField => - e.child.semanticEquals(struct) || isExprNestedInsideStruct(e.child, struct) - case _ => false - } - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index e56375d32f11..ac7a14716fbf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -109,8 +109,8 @@ abstract class Optimizer(catalogManager: CatalogManager) RemoveRedundantAliases, UnwrapCastInBinaryComparison, RemoveNoopOperators, - SimplifyExtractValueOps, CombineUpdateFields, + SimplifyExtractValueOps, CombineConcats) ++ extendedOperatorOptimizationRules diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index c54bf8ae5e36..7ba24b13cda4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, OneRowRelation, Project, 
Range} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ @@ -44,8 +44,8 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { BooleanSimplification, SimplifyConditionals, SimplifyBinaryComparison, - SimplifyExtractValueOps, - CombineUpdateFields) :: Nil + CombineUpdateFields, + SimplifyExtractValueOps) :: Nil } private val idAtt = ('id).long.notNull @@ -460,140 +460,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { private val nullableStructAttr = 'struct1.struct('a.int, 'b.int) private val testNullableStructRelation = LocalRelation(nullableStructAttr) - test("simplify GetStructField on UpdateFields that is not changing the attribute being " + - "extracted") { - def query(relation: LocalRelation): LogicalPlan = relation.select(GetStructField( - UpdateFields('struct1, Seq(WithField("b", Literal(1)))), 0, Some("a")) as "outerAttr") - - checkRule( - query(testStructRelation), - testStructRelation.select(GetStructField('struct1, 0) as "outerAttr")) - - checkRule( - query(testNullableStructRelation), - testNullableStructRelation.select(GetStructField('struct1, 0) as "outerAttr")) - } - - test("simplify GetStructField on UpdateFields that is changing the attribute being extracted") { - def query(relation: LocalRelation): LogicalPlan = relation.select(GetStructField( - UpdateFields('struct1, Seq(WithField("b", Literal(1)))), 1, Some("b")) as "res") - - checkRule( - query(testStructRelation), - testStructRelation.select(Literal(1) as "res")) - - checkRule( - query(testNullableStructRelation), - testNullableStructRelation.select( - If(IsNull('struct1), Literal(null, IntegerType), Literal(1)) as "res")) - } - - test("simplify GetStructField on UpdateFields that is changing the attribute being extracted " + - "twice") { - def query(relation: LocalRelation): LogicalPlan = relation.select(GetStructField( - UpdateFields('struct1, Seq(WithField("b", Literal(1)), WithField("b", Literal(2)))), - 1, Some("b")) as "outerAtt") - - checkRule( - query(testStructRelation), - testStructRelation.select(Literal(2) as "outerAtt")) - - checkRule( - query(testNullableStructRelation), - testNullableStructRelation.select( - If(IsNull('struct1), Literal(null, IntegerType), Literal(2)) as "outerAtt")) - } - - test("collapse multiple GetStructField on the same UpdateFields") { - def query(relation: LocalRelation): LogicalPlan = relation - .select(UpdateFields('struct1, Seq(WithField("b", Literal(2)))) as "struct2") - .select( - GetStructField('struct2, 0, Some("a")) as "struct1A", - GetStructField('struct2, 1, Some("b")) as "struct1B") - - checkRule( - query(testStructRelation), - testStructRelation.select( - GetStructField('struct1, 0) as "struct1A", - Literal(2) as "struct1B")) - - checkRule( - query(testNullableStructRelation), - testNullableStructRelation.select( - GetStructField('struct1, 0) as "struct1A", - If(IsNull('struct1), Literal(null, IntegerType), Literal(2)) as "struct1B")) - } - - test("collapse multiple GetStructField on different UpdateFields") { - def query(relation: LocalRelation): LogicalPlan = relation - .select( - UpdateFields('struct1, Seq(WithField("b", Literal(2)))) as "struct2", - UpdateFields('struct1, Seq(WithField("b", Literal(3)))) as "struct3") - .select( - GetStructField('struct2, 0, Some("a")) as "struct2A", - GetStructField('struct2, 1, Some("b")) as "struct2B", - GetStructField('struct3, 0, Some("a")) as "struct3A", - GetStructField('struct3, 1, 
Some("b")) as "struct3B") - - checkRule( - query(testStructRelation), - testStructRelation - .select( - GetStructField('struct1, 0) as "struct2A", - Literal(2) as "struct2B", - GetStructField('struct1, 0) as "struct3A", - Literal(3) as "struct3B")) - - checkRule( - query(testNullableStructRelation), - testNullableStructRelation - .select( - GetStructField('struct1, 0) as "struct2A", - If(IsNull('struct1), Literal(null, IntegerType), Literal(2)) as "struct2B", - GetStructField('struct1, 0) as "struct3A", - If(IsNull('struct1), Literal(null, IntegerType), Literal(3)) as "struct3B")) - } - - test("simplify GetStructField on UpdateFields with multiple WithField and extract new column") { - def query(relation: LocalRelation): LogicalPlan = relation.select(GetStructField( - UpdateFields('struct1, ('a' to 'z').zipWithIndex.map { case (char, i) => - WithField(char.toString, Literal(i)) - }), 2) as "res") - - checkRule( - query(testStructRelation), - testStructRelation.select(Literal(2) as "res")) - - checkRule( - query(testNullableStructRelation), - testNullableStructRelation.select( - If(IsNull('struct1), Literal(null, IntegerType), Literal(2)) as "res")) - } - - test("Combine multiple UpdateFields and extract newly added field") { - def query(relation: LocalRelation): LogicalPlan = relation.select(GetStructField( - UpdateFields( - UpdateFields( - UpdateFields( - UpdateFields( - 'struct1, - Seq(WithField("col1", Literal(1)))), - Seq(WithField("col2", Literal(2)))), - Seq(WithField("col3", Literal(3)))), - Seq(WithField("col4", Literal(4)))), - 5, Some("col4")) as "res") - - checkRule( - query(testStructRelation), - testStructRelation.select(Literal(4) as "res")) - - checkRule( - query(testNullableStructRelation), - testNullableStructRelation.select( - If(IsNull('struct1), Literal(null, IntegerType), Literal(4)) as "res")) - } - - test("should correctly handle different WithField + DropField + GetStructField combinations") { + test("simplify GetStructField on basic UpdateFields") { def check(fieldOps: Seq[StructFieldsOperation], ordinal: Int, expected: Expression): Unit = { def query(relation: LocalRelation): LogicalPlan = relation.select(GetStructField(UpdateFields('struct1, fieldOps), ordinal).as("res")) @@ -605,59 +472,112 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { checkRule( query(testNullableStructRelation), testNullableStructRelation.select((expected match { - case expr: GetStructField if expr.child.fastEquals('struct1.expr) => expr + case expr: GetStructField => expr case expr => If(IsNull('struct1), Literal(null, expr.dataType), expr) }).as("res"))) } + // scalastyle:off line.size.limit + // add attribute, extract an attribute from the original struct check(WithField("c", Literal(3)) :: Nil, 0, GetStructField('struct1, 0)) + check(WithField("c", Literal(3)) :: Nil, 1, GetStructField('struct1, 1)) + // add attribute, extract added attribute + check(WithField("c", Literal(3)) :: Nil, 2, Literal(3)) + + // replace attribute, extract an attribute from the original struct + check(WithField("a", Literal(1)) :: Nil, 1, GetStructField('struct1, 1)) + check(WithField("b", Literal(2)) :: Nil, 0, GetStructField('struct1, 0)) + // replace attribute, extract replaced attribute + check(WithField("a", Literal(1)) :: Nil, 0, Literal(1)) + check(WithField("b", Literal(2)) :: Nil, 1, Literal(2)) + + // add multiple attributes, extract an attribute from the original struct + check(WithField("c", Literal(3)) :: WithField("c", Literal(4)) :: Nil, 0, GetStructField('struct1, 0)) + 
check(WithField("c", Literal(3)) :: WithField("d", Literal(4)) :: Nil, 0, GetStructField('struct1, 0)) + check(WithField("c", Literal(3)) :: WithField("c", Literal(4)) :: Nil, 1, GetStructField('struct1, 1)) + check(WithField("c", Literal(3)) :: WithField("d", Literal(4)) :: Nil, 1, GetStructField('struct1, 1)) + // add multiple attributes, extract newly added attribute + check(WithField("c", Literal(3)) :: WithField("c", Literal(4)) :: Nil, 2, Literal(4)) + check(WithField("c", Literal(4)) :: WithField("c", Literal(3)) :: Nil, 2, Literal(3)) + check(WithField("c", Literal(3)) :: WithField("d", Literal(4)) :: Nil, 2, Literal(3)) + check(WithField("c", Literal(3)) :: WithField("d", Literal(4)) :: Nil, 3, Literal(4)) + check(WithField("d", Literal(4)) :: WithField("c", Literal(3)) :: Nil, 2, Literal(4)) + check(WithField("d", Literal(4)) :: WithField("c", Literal(3)) :: Nil, 3, Literal(3)) // drop attribute, extract an attribute from the original struct check(DropField("b") :: Nil, 0, GetStructField('struct1, 0)) check(DropField("a") :: Nil, 0, GetStructField('struct1, 1)) // drop attribute, add attribute, extract an attribute from the original struct - check(DropField("b") :: WithField("c", Literal(2)) :: Nil, 0, GetStructField('struct1, 0)) - check(DropField("a") :: WithField("c", Literal(2)) :: Nil, 0, GetStructField('struct1, 1)) + check(DropField("b") :: WithField("c", Literal(3)) :: Nil, 0, GetStructField('struct1, 0)) + check(DropField("a") :: WithField("c", Literal(3)) :: Nil, 0, GetStructField('struct1, 1)) + // drop attribute, add attribute, extract added attribute + check(DropField("b") :: WithField("c", Literal(3)) :: Nil, 1, Literal(3)) + check(DropField("a") :: WithField("c", Literal(3)) :: Nil, 1, Literal(3)) // add attribute, drop attribute, extract an attribute from the original struct check(WithField("c", Literal(3)) :: DropField("a") :: Nil, 0, GetStructField('struct1, 1)) check(WithField("c", Literal(3)) :: DropField("b") :: Nil, 0, GetStructField('struct1, 0)) + // add attribute, drop attribute, extract added attribute + check(WithField("c", Literal(3)) :: DropField("a") :: Nil, 1, Literal(3)) + check(WithField("c", Literal(3)) :: DropField("b") :: Nil, 1, Literal(3)) + + // replace attribute, drop same attribute, extract an attribute from the original struct check(WithField("b", Literal(3)) :: DropField("b") :: Nil, 0, GetStructField('struct1, 0)) check(WithField("a", Literal(3)) :: DropField("a") :: Nil, 0, GetStructField('struct1, 1)) // add attribute, drop same attribute, extract an attribute from the original struct check(WithField("c", Literal(3)) :: DropField("c") :: Nil, 0, GetStructField('struct1, 0)) + check(WithField("c", Literal(3)) :: DropField("c") :: Nil, 1, GetStructField('struct1, 1)) - // add attribute, drop another attribute, extract added attribute + // replace attribute, drop another attribute, extract added attribute check(WithField("b", Literal(3)) :: DropField("a") :: Nil, 0, Literal(3)) check(WithField("a", Literal(3)) :: DropField("b") :: Nil, 0, Literal(3)) + // drop attribute, add same attribute, extract attribute from the original struct + check(DropField("b") :: WithField("b", Literal(3)) :: Nil, 0, GetStructField('struct1, 0)) + check(DropField("a") :: WithField("a", Literal(3)) :: Nil, 0, GetStructField('struct1, 1)) // drop attribute, add same attribute, extract added attribute - check(DropField("a") :: WithField("c", Literal(3)) :: Nil, 1, Literal(3)) - check(DropField("b") :: WithField("c", Literal(3)) :: Nil, 1, Literal(3)) 
check(DropField("b") :: WithField("b", Literal(3)) :: Nil, 1, Literal(3)) check(DropField("a") :: WithField("a", Literal(3)) :: Nil, 1, Literal(3)) - check(DropField("c") :: WithField("c", Literal(3)) :: Nil, 2, Literal(3)) - check(WithField("c", Literal(3)) :: DropField("c") :: WithField("c", Literal(4)) :: Nil, 2, - Literal(4)) - // drop earlier attribute, add attribute, extract added attribute - check(DropField("a") :: WithField("c", Literal(3)) :: Nil, 1, Literal(3)) + // drop non-existent attribute, add same attribute, extract attribute from the original struct + check(DropField("c") :: WithField("c", Literal(3)) :: Nil, 0, GetStructField('struct1, 0)) + check(DropField("c") :: WithField("c", Literal(3)) :: Nil, 1, GetStructField('struct1, 1)) + // drop non-existent attribute, add same attribute, extract added attribute + check(DropField("c") :: WithField("c", Literal(3)) :: Nil, 2, Literal(3)) - // drop later attribute, add attribute, extract added attribute - check(DropField("b") :: WithField("c", Literal(3)) :: Nil, 1, Literal(3)) + // scalastyle:on line.size.limit } - test("should only transform GetStructField calls if not GetStructField on same struct") { + test("simplify GetStructField that is extracting a field nested inside a struct") { val struct2 = 'struct2.struct('b.int) val testStructRelation = LocalRelation(structAttr, struct2) val testNullableStructRelation = LocalRelation(nullableStructAttr, struct2) + // if the field being extracted is from the same struct that UpdateFields is modifying, + // we can just return GetStructField in both the non-nullable and nullable struct scenario + + def addFieldFromSameStructAndThenExtractIt(relation: LocalRelation): LogicalPlan = + relation.select(GetStructField( + UpdateFields('struct1, WithField("b", GetStructField('struct1, 0)) :: Nil), 1).as("res")) + + checkRule( + addFieldFromSameStructAndThenExtractIt(testStructRelation), + testStructRelation.select(GetStructField('struct1, 0).as("res"))) + + checkRule( + addFieldFromSameStructAndThenExtractIt(testNullableStructRelation), + testNullableStructRelation.select(GetStructField('struct1, 0).as("res"))) + + // if the field being extracted is from a different struct than the one UpdateFields is + // modifying, we must return GetStructField wrapped in If(IsNull(struct), null, GetStructField) + // in the nullable struct scenario + def addFieldFromAnotherStructAndThenExtractIt(relation: LocalRelation): LogicalPlan = relation.select(GetStructField( - UpdateFields('struct1, Seq(WithField("b", GetStructField('struct2, 0)))), 1).as("res")) + UpdateFields('struct1, WithField("b", GetStructField('struct2, 0)) :: Nil), 1).as("res")) checkRule( addFieldFromAnotherStructAndThenExtractIt(testStructRelation), @@ -667,73 +587,126 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { addFieldFromAnotherStructAndThenExtractIt(testNullableStructRelation), testNullableStructRelation.select( If(IsNull('struct1), Literal(null, IntegerType), GetStructField('struct2, 0)).as("res"))) + } - def addFieldFromSameStructAndThenExtractIt(relation: LocalRelation): LogicalPlan = - relation.select(GetStructField( - UpdateFields('struct1, Seq(WithField("b", GetStructField('struct1, 0)))), 1).as("res")) + test("simplify GetStructField on nested UpdateFields") { + def query(relation: LocalRelation, ordinal: Int): LogicalPlan = { + val nestedUpdateFields = + UpdateFields( + UpdateFields( + UpdateFields( + UpdateFields( + 'struct1, + WithField("c", Literal(1)) :: Nil), + WithField("d", Literal(2)) :: 
Nil), + WithField("e", Literal(3)) :: Nil), + WithField("f", Literal(4)) :: Nil) + + relation.select(GetStructField(nestedUpdateFields, ordinal) as "res") + } + + // extract newly added field checkRule( - addFieldFromSameStructAndThenExtractIt(testStructRelation), - testStructRelation.select(GetStructField('struct1, 0).as("res"))) + query(testStructRelation, 5), + testStructRelation.select(Literal(4) as "res")) checkRule( - addFieldFromSameStructAndThenExtractIt(testNullableStructRelation), - testNullableStructRelation.select(GetStructField('struct1, 0).as("res"))) + query(testNullableStructRelation, 5), + testNullableStructRelation.select( + If(IsNull('struct1), Literal(null, IntegerType), Literal(4)) as "res")) + + // extract field from original struct + + checkRule( + query(testStructRelation, 0), + testStructRelation.select(GetStructField('struct1, 0) as "res")) + + checkRule( + query(testNullableStructRelation, 0), + testNullableStructRelation.select(GetStructField('struct1, 0) as "res")) } - test("simplify if extract value is from inside struct") { - val nullableStructAttr = 'struct1a.struct( - 'struct2a.struct( - 'struct3a.struct('a.int, 'b.int))) - val structAttr = nullableStructAttr.withNullability(false) - val testStructRelation = LocalRelation(structAttr) - val testNullableStructRelation = LocalRelation(nullableStructAttr) + test("simplify multiple GetStructField on the same UpdateFields") { + def query(relation: LocalRelation): LogicalPlan = relation + .select(UpdateFields('struct1, WithField("b", Literal(2)) :: Nil) as "struct2") + .select( + GetStructField('struct2, 0, Some("a")) as "struct1A", + GetStructField('struct2, 1, Some("b")) as "struct1B") - Seq(testStructRelation, testNullableStructRelation).foreach { relation => - checkRule( - relation.select(GetStructField( - UpdateFields('struct1a, Seq( - WithField("struct2b", GetStructField(GetStructField('struct1a, 0), 0)))), 1).as("res")), - relation.select(GetStructField(GetStructField('struct1a, 0), 0).as("res"))) - } + checkRule( + query(testStructRelation), + testStructRelation.select( + GetStructField('struct1, 0) as "struct1A", + Literal(2) as "struct1B")) + + checkRule( + query(testNullableStructRelation), + testNullableStructRelation.select( + GetStructField('struct1, 0) as "struct1A", + If(IsNull('struct1), Literal(null, IntegerType), Literal(2)) as "struct1B")) + } + + test("simplify multiple GetStructField on different UpdateFields") { + def query(relation: LocalRelation): LogicalPlan = relation + .select( + UpdateFields('struct1, WithField("b", Literal(2)) :: Nil) as "struct2", + UpdateFields('struct1, WithField("b", Literal(3)) :: Nil) as "struct3") + .select( + GetStructField('struct2, 0, Some("a")) as "struct2A", + GetStructField('struct2, 1, Some("b")) as "struct2B", + GetStructField('struct3, 0, Some("a")) as "struct3A", + GetStructField('struct3, 1, Some("b")) as "struct3B") + + checkRule( + query(testStructRelation), + testStructRelation + .select( + GetStructField('struct1, 0) as "struct2A", + Literal(2) as "struct2B", + GetStructField('struct1, 0) as "struct3A", + Literal(3) as "struct3B")) + + checkRule( + query(testNullableStructRelation), + testNullableStructRelation + .select( + GetStructField('struct1, 0) as "struct2A", + If(IsNull('struct1), Literal(null, IntegerType), Literal(2)) as "struct2B", + GetStructField('struct1, 0) as "struct3A", + If(IsNull('struct1), Literal(null, IntegerType), Literal(3)) as "struct3B")) } - test("simplify add multiple nested fields") { + test("simplify add multiple 
nested fields to struct") { // this scenario is possible if users add multiple nested columns via the Column.withField API // ideally, users should not be doing this. val nullableStructLevel2 = LocalRelation( 'a1.struct( 'a2.struct('a3.int)).withNullability(false)) - val addB3toA1A2 = UpdateFields('a1, Seq(WithField("a2", - UpdateFields(GetStructField('a1, 0), Seq(WithField("b3", Literal(2))))))) + val query = { + val addB3toA1A2 = UpdateFields('a1, Seq(WithField("a2", + UpdateFields(GetStructField('a1, 0), Seq(WithField("b3", Literal(2))))))) - val query = nullableStructLevel2.select( - UpdateFields( - addB3toA1A2, - Seq(WithField("a2", UpdateFields( - GetStructField(addB3toA1A2, 0), Seq(WithField("c3", Literal(3))))))).as("a1")) + nullableStructLevel2.select( + UpdateFields( + addB3toA1A2, + Seq(WithField("a2", UpdateFields( + GetStructField(addB3toA1A2, 0), Seq(WithField("c3", Literal(3))))))).as("a1")) + } val expected = nullableStructLevel2.select( UpdateFields('a1, Seq( - WithField("a2", UpdateFields(GetStructField('a1, 0), Seq(WithField("b3", 2)))), - WithField("a2", UpdateFields(GetStructField('a1, 0), Seq( - WithField("b3", 2), - WithField("c3", 3)))) - )).as("a1")) - - // TODO: how to make this a reality - val idealExpected = nullableStructLevel2.select( - UpdateFields('a1, Seq( - WithField("a2", UpdateFields(GetStructField('a1, 0), Seq( - WithField("b3", 2), - WithField("c3", 3)))) + // scalastyle:off line.size.limit + WithField("a2", UpdateFields(GetStructField('a1, 0), WithField("b3", 2) :: Nil)), + WithField("a2", UpdateFields(GetStructField('a1, 0), WithField("b3", 2) :: WithField("c3", 3) :: Nil)) + // scalastyle:on line.size.limit )).as("a1")) checkRule(query, expected) } - test("simplify drop multiple nested fields") { + test("simplify drop multiple nested fields in struct") { // this scenario is possible if users drop multiple nested columns via the Column.dropFields API // ideally, users should not be doing this. 
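// A purely illustrative sketch (the DataFrame `df0` and the field names are hypothetical):
// the discouraged pattern drops nested fields one at a time,
//   df0.select($"a1".dropFields("a2.b3").dropFields("a2.c3").as("a1"))
// whereas the preferred pattern drops them in one call at the level of nesting where they
// live, producing roughly one UpdateFields expression per level of nesting:
//   df0.select($"a1".withField("a2", $"a1.a2".dropFields("b3", "c3")).as("a1"))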
val df = LocalRelation( @@ -741,14 +714,16 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { 'a2.struct('a3.int, 'b3.int, 'c3.int).withNullability(false) ).withNullability(false)) - val dropA1A2B = UpdateFields('a1, Seq(WithField("a2", UpdateFields( - GetStructField('a1, 0), Seq(DropField("b3")))))) + val query = { + val dropA1A2B = UpdateFields('a1, Seq(WithField("a2", UpdateFields( + GetStructField('a1, 0), Seq(DropField("b3")))))) - val query = df.select( - UpdateFields( - dropA1A2B, - Seq(WithField("a2", UpdateFields( - GetStructField(dropA1A2B, 0), Seq(DropField("c3")))))).as("a1")) + df.select( + UpdateFields( + dropA1A2B, + Seq(WithField("a2", UpdateFields( + GetStructField(dropA1A2B, 0), Seq(DropField("c3")))))).as("a1")) + } val expected = df.select( UpdateFields('a1, Seq( @@ -756,12 +731,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { WithField("a2", UpdateFields(GetStructField('a1, 0), Seq(DropField("b3"), DropField("c3")))) )).as("a1")) - // TODO: how to make this a reality - val idealExpected = df.select( - UpdateFields('a1, Seq( - WithField("a2", UpdateFields(GetStructField('a1, 0), Seq(DropField("b3"), DropField("c3")))) - )).as("a1")) - checkRule(query, expected) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 48872de9ad43..6423ede59350 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -922,18 +922,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { assert(inSet.sql === "('a' IN ('a', 'b'))") } - def checkAnswerAndSchema( - df: => DataFrame, - expectedAnswer: Seq[Row], - expectedSchema: StructType): Unit = { - - df.explain(true) - df.printSchema() - df.show(false) - checkAnswer(df, expectedAnswer) - assert(df.schema == expectedSchema) - } - private lazy val structType = StructType(Seq( StructField("a", IntegerType, nullable = false), StructField("b", IntegerType, nullable = true), @@ -970,15 +958,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false))), nullable = false)))) - private lazy val nullStructLevel3: DataFrame = spark.createDataFrame( - sparkContext.parallelize(Row(Row(Row(Row(1, null, 3)))) :: Nil), - StructType(Seq( - StructField("a", StructType(Seq( - StructField("a", StructType(Seq( - StructField("a", structType, nullable = true))), - nullable = true))), - nullable = false)))) - test("withField should throw an exception if called on a non-StructType column") { intercept[AnalysisException] { testData.withColumn("key", $"key".withField("a", lit(2))) @@ -1030,7 +1009,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("withField should add field with no name") { - checkAnswerAndSchema( + checkAnswer( structLevel1.withColumn("a", $"a".withField("", lit(4))), Row(Row(1, null, 3, 4)) :: Nil, StructType(Seq( @@ -1043,7 +1022,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("withField should add field to struct") { - checkAnswerAndSchema( + checkAnswer( structLevel1.withColumn("a", 'a.withField("d", lit(4))), Row(Row(1, null, 3, 4)) :: Nil, StructType(Seq( @@ -1056,7 +1035,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("withField should add field to null struct") { - checkAnswerAndSchema( + checkAnswer( 
nullStructLevel1.withColumn("a", $"a".withField("d", lit(4))), Row(null) :: Nil, StructType(Seq( @@ -1069,7 +1048,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("withField should add field to nested null struct") { - checkAnswerAndSchema( + checkAnswer( nullStructLevel2.withColumn("a", $"a".withField("a.d", lit(4))), Row(Row(null)) :: Nil, StructType( @@ -1084,7 +1063,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("withField should add null field to struct") { - checkAnswerAndSchema( + checkAnswer( structLevel1.withColumn("a", 'a.withField("d", lit(null).cast(IntegerType))), Row(Row(1, null, 3, null)) :: Nil, StructType(Seq( @@ -1097,7 +1076,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("withField should add multiple fields to struct") { - checkAnswerAndSchema( + checkAnswer( structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("e", lit(5))), Row(Row(1, null, 3, 4, 5)) :: Nil, StructType(Seq( @@ -1115,7 +1094,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { structLevel2.withColumn("a", 'a.withField("a.d", lit(4))), structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("d", lit(4)))) ).foreach { df => - checkAnswerAndSchema( + checkAnswer( df, Row(Row(Row(1, null, 3, 4))) :: Nil, StructType( @@ -1131,7 +1110,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("withField should add field to deeply nested struct") { - checkAnswerAndSchema( + checkAnswer( structLevel3.withColumn("a", 'a.withField("a.a.d", lit(4))), Row(Row(Row(Row(1, null, 3, 4)))) :: Nil, StructType(Seq( @@ -1148,7 +1127,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("withField should replace field in struct") { - checkAnswerAndSchema( + checkAnswer( structLevel1.withColumn("a", 'a.withField("b", lit(2))), Row(Row(1, 2, 3)) :: Nil, StructType(Seq( @@ -1160,7 +1139,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("withField should replace field in null struct") { - checkAnswerAndSchema( + checkAnswer( nullStructLevel1.withColumn("a", 'a.withField("b", lit("foo"))), Row(null) :: Nil, StructType(Seq( @@ -1172,7 +1151,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("withField should replace field in nested null struct") { - checkAnswerAndSchema( + checkAnswer( nullStructLevel2.withColumn("a", $"a".withField("a.b", lit("foo"))), Row(Row(null)) :: Nil, StructType( @@ -1186,7 +1165,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("withField should replace field with null value in struct") { - checkAnswerAndSchema( + checkAnswer( structLevel1.withColumn("a", 'a.withField("c", lit(null).cast(IntegerType))), Row(Row(1, null, null)) :: Nil, StructType(Seq( @@ -1198,7 +1177,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("withField should replace multiple fields in struct") { - checkAnswerAndSchema( + checkAnswer( structLevel1.withColumn("a", 'a.withField("a", lit(10)).withField("b", lit(20))), Row(Row(10, 20, 3)) :: Nil, StructType(Seq( @@ -1214,7 +1193,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { structLevel2.withColumn("a", $"a".withField("a.b", lit(2))), structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("b", lit(2)))) ).foreach { df => - checkAnswerAndSchema( + checkAnswer( df, Row(Row(Row(1, 2, 3))) :: Nil, 
StructType(Seq( @@ -1229,7 +1208,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("withField should replace field in deeply nested struct") { - checkAnswerAndSchema( + checkAnswer( structLevel3.withColumn("a", $"a".withField("a.a.b", lit(2))), Row(Row(Row(Row(1, 2, 3)))) :: Nil, StructType(Seq( @@ -1254,7 +1233,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { StructField("b", IntegerType, nullable = false))), nullable = false)))) - checkAnswerAndSchema( + checkAnswer( structLevel1.withColumn("a", 'a.withField("b", lit(100))), Row(Row(1, 100, 100)) :: Nil, StructType(Seq( @@ -1266,7 +1245,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("withField should replace fields in struct in given order") { - checkAnswerAndSchema( + checkAnswer( structLevel1.withColumn("a", 'a.withField("b", lit(2)).withField("b", lit(20))), Row(Row(1, 20, 3)) :: Nil, StructType(Seq( @@ -1278,7 +1257,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("withField should add field and then replace same field in struct") { - checkAnswerAndSchema( + checkAnswer( structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("d", lit(5))), Row(Row(1, null, 3, 5)) :: Nil, StructType(Seq( @@ -1302,7 +1281,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false))), nullable = false)))) - checkAnswerAndSchema( + checkAnswer( df.withColumn("a", 'a.withField("`a.b`.`e.f`", lit(2))), Row(Row(Row(1, 2, 3))) :: Nil, StructType(Seq( @@ -1329,7 +1308,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should replace field in struct even if casing is different") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { - checkAnswerAndSchema( + checkAnswer( mixedCaseStructLevel1.withColumn("a", 'a.withField("A", lit(2))), Row(Row(2, 1)) :: Nil, StructType(Seq( @@ -1338,7 +1317,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { StructField("B", IntegerType, nullable = false))), nullable = false)))) - checkAnswerAndSchema( + checkAnswer( mixedCaseStructLevel1.withColumn("a", 'a.withField("b", lit(2))), Row(Row(1, 2)) :: Nil, StructType(Seq( @@ -1351,7 +1330,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should add field to struct because casing is different") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { - checkAnswerAndSchema( + checkAnswer( mixedCaseStructLevel1.withColumn("a", 'a.withField("A", lit(2))), Row(Row(1, 1, 2)) :: Nil, StructType(Seq( @@ -1361,7 +1340,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { StructField("A", IntegerType, nullable = false))), nullable = false)))) - checkAnswerAndSchema( + checkAnswer( mixedCaseStructLevel1.withColumn("a", 'a.withField("b", lit(2))), Row(Row(1, 1, 2)) :: Nil, StructType(Seq( @@ -1389,7 +1368,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("withField should replace nested field in struct even if casing is different") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { - checkAnswerAndSchema( + checkAnswer( mixedCaseStructLevel2.withColumn("a", 'a.withField("A.a", lit(2))), Row(Row(Row(2, 1), Row(1, 1))) :: Nil, StructType(Seq( @@ -1404,7 +1383,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false))), nullable = false)))) - checkAnswerAndSchema( + checkAnswer( mixedCaseStructLevel2.withColumn("a", 
'a.withField("b.a", lit(2))), Row(Row(Row(1, 1), Row(2, 1))) :: Nil, StructType(Seq( @@ -1479,25 +1458,25 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("SPARK-32641: extracting field from non-null struct column after withField should return " + "field value") { // extract newly added field - checkAnswerAndSchema( + checkAnswer( structLevel1.withColumn("a", $"a".withField("d", lit(4)).getField("d")), Row(4) :: Nil, StructType(Seq(StructField("a", IntegerType, nullable = false)))) // extract newly replaced field - checkAnswerAndSchema( + checkAnswer( structLevel1.withColumn("a", $"a".withField("a", lit(4)).getField("a")), Row(4) :: Nil, StructType(Seq(StructField("a", IntegerType, nullable = false)))) // add new field, extract another field from original struct - checkAnswerAndSchema( + checkAnswer( structLevel1.withColumn("a", $"a".withField("d", lit(4)).getField("c")), Row(3):: Nil, StructType(Seq(StructField("a", IntegerType, nullable = false)))) // replace field, extract another field from original struct - checkAnswerAndSchema( + checkAnswer( structLevel1.withColumn("a", $"a".withField("a", lit(4)).getField("c")), Row(3):: Nil, StructType(Seq(StructField("a", IntegerType, nullable = false)))) @@ -1506,25 +1485,25 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("SPARK-32641: extracting field from null struct column after withField should return " + "null if the original struct was null") { // extract newly added field - checkAnswerAndSchema( + checkAnswer( nullStructLevel1.withColumn("a", $"a".withField("d", lit(4)).getField("d")), Row(null) :: Nil, StructType(Seq(StructField("a", IntegerType, nullable = true)))) // extract newly replaced field - checkAnswerAndSchema( + checkAnswer( nullStructLevel1.withColumn("a", $"a".withField("a", lit(4)).getField("a")), Row(null):: Nil, StructType(Seq(StructField("a", IntegerType, nullable = true)))) // add new field, extract another field from original struct - checkAnswerAndSchema( + checkAnswer( nullStructLevel1.withColumn("a", $"a".withField("d", lit(4)).getField("c")), Row(null):: Nil, StructType(Seq(StructField("a", IntegerType, nullable = true)))) // replace field, extract another field from original struct - checkAnswerAndSchema( + checkAnswer( nullStructLevel1.withColumn("a", $"a".withField("a", lit(4)).getField("c")), Row(null):: Nil, StructType(Seq(StructField("a", IntegerType, nullable = true)))) @@ -1537,25 +1516,25 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { StructType(Seq(StructField("a", structType, nullable = true)))) // extract newly added field - checkAnswerAndSchema( + checkAnswer( df.withColumn("a", $"a".withField("d", lit(4)).getField("d")), Row(4) :: Row(null) :: Nil, StructType(Seq(StructField("a", IntegerType, nullable = true)))) // extract newly replaced field - checkAnswerAndSchema( + checkAnswer( df.withColumn("a", $"a".withField("a", lit(4)).getField("a")), Row(4) :: Row(null):: Nil, StructType(Seq(StructField("a", IntegerType, nullable = true)))) // add new field, extract another field from original struct - checkAnswerAndSchema( + checkAnswer( df.withColumn("a", $"a".withField("d", lit(4)).getField("c")), Row(3) :: Row(null):: Nil, StructType(Seq(StructField("a", IntegerType, nullable = true)))) // replace field, extract another field from original struct - checkAnswerAndSchema( + checkAnswer( df.withColumn("a", $"a".withField("a", lit(4)).getField("c")), Row(3) :: Row(null):: Nil, StructType(Seq(StructField("a", 
IntegerType, nullable = true)))) @@ -1605,7 +1584,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("dropFields should drop field in struct") { - checkAnswerAndSchema( + checkAnswer( structLevel1.withColumn("a", 'a.dropFields("b")), Row(Row(1, 3)) :: Nil, StructType(Seq( @@ -1616,7 +1595,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("dropFields should drop field in null struct") { - checkAnswerAndSchema( + checkAnswer( nullStructLevel1.withColumn("a", $"a".dropFields("b")), Row(null) :: Nil, StructType(Seq( @@ -1631,7 +1610,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { structLevel1.withColumn("a", $"a".dropFields("b", "c")), structLevel1.withColumn("a", 'a.dropFields("b").dropFields("c")) ).foreach { df => - checkAnswerAndSchema( + checkAnswer( df, Row(Row(1)) :: Nil, StructType(Seq( @@ -1648,7 +1627,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("dropFields should drop field in nested struct") { - checkAnswerAndSchema( + checkAnswer( structLevel2.withColumn("a", 'a.dropFields("a.b")), Row(Row(Row(1, 3))) :: Nil, StructType( @@ -1661,7 +1640,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("dropFields should drop multiple fields in nested struct") { - checkAnswerAndSchema( + checkAnswer( structLevel2.withColumn("a", 'a.dropFields("a.b", "a.c")), Row(Row(Row(1))) :: Nil, StructType( @@ -1673,7 +1652,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("dropFields should drop field in nested null struct") { - checkAnswerAndSchema( + checkAnswer( nullStructLevel2.withColumn("a", $"a".dropFields("a.b")), Row(Row(null)) :: Nil, StructType( @@ -1686,7 +1665,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("dropFields should drop multiple fields in nested null struct") { - checkAnswerAndSchema( + checkAnswer( nullStructLevel2.withColumn("a", $"a".dropFields("a.b", "a.c")), Row(Row(null)) :: Nil, StructType( @@ -1698,7 +1677,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("dropFields should drop field in deeply nested struct") { - checkAnswerAndSchema( + checkAnswer( structLevel3.withColumn("a", 'a.dropFields("a.a.b")), Row(Row(Row(Row(1, 3)))) :: Nil, StructType(Seq( @@ -1722,7 +1701,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { StructField("b", IntegerType, nullable = false))), nullable = false)))) - checkAnswerAndSchema( + checkAnswer( structLevel1.withColumn("a", 'a.dropFields("b")), Row(Row(1)) :: Nil, StructType(Seq( @@ -1733,7 +1712,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("dropFields should drop field in struct even if casing is different") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { - checkAnswerAndSchema( + checkAnswer( mixedCaseStructLevel1.withColumn("a", 'a.dropFields("A")), Row(Row(1)) :: Nil, StructType(Seq( @@ -1741,7 +1720,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { StructField("B", IntegerType, nullable = false))), nullable = false)))) - checkAnswerAndSchema( + checkAnswer( mixedCaseStructLevel1.withColumn("a", 'a.dropFields("b")), Row(Row(1)) :: Nil, StructType(Seq( @@ -1753,7 +1732,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("dropFields should not drop field in struct because casing is different") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { - 
checkAnswerAndSchema( + checkAnswer( mixedCaseStructLevel1.withColumn("a", 'a.dropFields("A")), Row(Row(1, 1)) :: Nil, StructType(Seq( @@ -1762,7 +1741,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { StructField("B", IntegerType, nullable = false))), nullable = false)))) - checkAnswerAndSchema( + checkAnswer( mixedCaseStructLevel1.withColumn("a", 'a.dropFields("b")), Row(Row(1, 1)) :: Nil, StructType(Seq( @@ -1775,7 +1754,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("dropFields should drop nested field in struct even if casing is different") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { - checkAnswerAndSchema( + checkAnswer( mixedCaseStructLevel2.withColumn("a", 'a.dropFields("A.a")), Row(Row(Row(1), Row(1, 1))) :: Nil, StructType(Seq( @@ -1789,7 +1768,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false))), nullable = false)))) - checkAnswerAndSchema( + checkAnswer( mixedCaseStructLevel2.withColumn("a", 'a.dropFields("b.a")), Row(Row(Row(1, 1), Row(1))) :: Nil, StructType(Seq( @@ -1818,7 +1797,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("dropFields should drop only fields that exist") { - checkAnswerAndSchema( + checkAnswer( structLevel1.withColumn("a", 'a.dropFields("d")), Row(Row(1, null, 3)) :: Nil, StructType(Seq( @@ -1828,7 +1807,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { StructField("c", IntegerType, nullable = false))), nullable = false)))) - checkAnswerAndSchema( + checkAnswer( structLevel1.withColumn("a", 'a.dropFields("b", "d")), Row(Row(1, 3)) :: Nil, StructType(Seq( @@ -1837,7 +1816,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { StructField("c", IntegerType, nullable = false))), nullable = false)))) - checkAnswerAndSchema( + checkAnswer( structLevel2.withColumn("a", $"a".dropFields("a.b", "a.d")), Row(Row(Row(1, 3))) :: Nil, StructType( @@ -1858,7 +1837,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { StructField("b", IntegerType, nullable = false))), nullable = false)))) - checkAnswerAndSchema( + checkAnswer( df.withColumn("a", $"a".dropFields("a.b", "b")), Row(Row(Row(1, 3))) :: Nil, StructType(Seq( @@ -1922,7 +1901,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("should correctly handle different dropField + withField + getField combinations") { - // TODO: maybe make another suite, where we run through these tests with and without optimizer val structType = StructType(Seq( StructField("a", IntegerType, nullable = false), StructField("b", IntegerType, nullable = false))) @@ -1947,17 +1925,17 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { def query(df: DataFrame): DataFrame = df.select(fieldOps(col("a")).getField(getFieldName).as("res")) - checkAnswerAndSchema( + checkAnswer( query(structLevel1), Row(expectedValue.orNull) :: Nil, StructType(Seq(StructField("res", IntegerType, nullable = expectedValue.isEmpty)))) - checkAnswerAndSchema( + checkAnswer( query(nullStructLevel1), Row(null) :: Nil, StructType(Seq(StructField("res", IntegerType, nullable = true)))) - checkAnswerAndSchema( + checkAnswer( query(nullableStructLevel1), Row(expectedValue.orNull) :: Row(null) :: Nil, StructType(Seq(StructField("res", IntegerType, nullable = true)))) @@ -1975,10 +1953,13 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { // replace attribute, extract an 
attribute from the original struct check(_.withField("b", lit(3)), "a", Some(1)) + check(_.withField("a", lit(3)), "b", Some(2)) // replace attribute, extract replaced attribute check(_.withField("b", lit(3)), "b", Some(3)) check(_.withField("b", lit(null).cast(IntegerType)), "b", None) + check(_.withField("a", lit(3)), "a", Some(3)) + check(_.withField("a", lit(null).cast(IntegerType)), "a", None) // drop attribute, extract an attribute from the original struct check(_.dropFields("b"), "a", Some(1)) @@ -1988,14 +1969,25 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { check(_.dropFields("b").withField("c", lit(3)), "a", Some(1)) check(_.dropFields("a").withField("c", lit(3)), "b", Some(2)) + // drop attribute, add another attribute, extract added attribute + check(_.dropFields("a").withField("c", lit(3)), "c", Some(3)) + check(_.dropFields("b").withField("c", lit(3)), "c", Some(3)) + // add attribute, drop attribute, extract an attribute from the original struct check(_.withField("c", lit(3)).dropFields("a"), "b", Some(2)) check(_.withField("c", lit(3)).dropFields("b"), "a", Some(1)) + + // add attribute, drop another attribute, extract added attribute + check(_.withField("c", lit(3)).dropFields("a"), "c", Some(3)) + check(_.withField("c", lit(3)).dropFields("b"), "c", Some(3)) + + // replace attribute, drop same attribute, extract an attribute from the original struct check(_.withField("b", lit(3)).dropFields("b"), "a", Some(1)) check(_.withField("a", lit(3)).dropFields("a"), "b", Some(2)) // add attribute, drop same attribute, extract an attribute from the original struct check(_.withField("c", lit(3)).dropFields("c"), "a", Some(1)) + check(_.withField("c", lit(3)).dropFields("c"), "b", Some(2)) // add attribute, drop another attribute, extract added attribute check(_.withField("b", lit(3)).dropFields("a"), "b", Some(3)) @@ -2004,198 +1996,14 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { check(_.withField("a", lit(null).cast(IntegerType)).dropFields("b"), "a", None) // drop attribute, add same attribute, extract added attribute - check(_.dropFields("a").withField("c", lit(3)), "c", Some(3)) - check(_.dropFields("b").withField("c", lit(3)), "c", Some(3)) check(_.dropFields("b").withField("b", lit(3)), "b", Some(3)) check(_.dropFields("a").withField("a", lit(3)), "a", Some(3)) check(_.dropFields("b").withField("b", lit(null).cast(IntegerType)), "b", None) check(_.dropFields("a").withField("a", lit(null).cast(IntegerType)), "a", None) check(_.dropFields("c").withField("c", lit(3)), "c", Some(3)) - check(_.withField("c", lit(3)).dropFields("c").withField("c", lit(4)), "c", Some(4)) - } - test("non-nullable field from a struct being added to a non-nullable struct and then extracted") { - val df = spark.createDataFrame( - sparkContext.parallelize(Seq(Row(Row(1, 2, 3), Row(4, 5, 6)))), - StructType(Seq( - StructField("a1", structType, nullable = false), - StructField("a2", structType, nullable = false)))) - - // add field from the same struct (a1) and then extract it - checkAnswerAndSchema( - df.select($"a1".withField("d", $"a1.b").getField("d").as("res")), - Seq(Row(2)), - StructType(Seq(StructField("res", IntegerType, nullable = true)))) - - // add field from another struct (a2) and then extract it - checkAnswerAndSchema( - df.select($"a1".withField("d", $"a2.b").getField("d").as("res")), - Seq(Row(5)), - StructType(Seq(StructField("res", IntegerType, nullable = true)))) - } - - test("nullable field from a struct being added to a 
nullable struct and then extracted") { - val df = spark.createDataFrame( - sparkContext.parallelize(Seq( - Row(Row(1, 2, 3), Row(4, 5, 6)), - Row(null, Row(4, 5, 6)), - Row(Row(1, 2, 3), null), - Row(null, null))), - StructType(Seq( - StructField("a1", structType, nullable = true), - StructField("a2", structType, nullable = true)))) - - // TODO: This doesn't fit the test title? - // add field from the same struct (a1) and then extract it - checkAnswerAndSchema( - df.select($"a1".withField("d", $"a1.b").getField("d").as("res")), - Seq( - Row(2), - Row(null), - Row(2), - Row(null)), - StructType(Seq(StructField("res", IntegerType, nullable = true)))) - - // add field from another struct (a2) and then extract it - checkAnswerAndSchema( - df.select($"a1".withField("d", $"a2.b").getField("d").as("res")), - Seq( - Row(5), - Row(null), - Row(null), - Row(null)), - StructType(Seq(StructField("res", IntegerType, nullable = true)))) - } - - test("non-nullable field from a struct being added to a nullable struct and then extracted") { - val df = spark.createDataFrame( - sparkContext.parallelize(Seq( - Row(Row(1, 2, 3), Row(4, 5, 6)), - Row(null, Row(4, 5, 6)), - Row(Row(1, 2, 3), Row(4, 5, 6)), - Row(null, Row(4, 5, 6)))), - StructType(Seq( - StructField("a1", structType, nullable = true), - StructField("a2", structType, nullable = false)))) - - // add field from the same struct (a1) and then extract it - checkAnswerAndSchema( - df.select($"a1".withField("d", $"a1.b").getField("d").as("res")), - Seq( - Row(2), - Row(null), - Row(2), - Row(null)), - StructType(Seq(StructField("res", IntegerType, nullable = true)))) - - // add field from another struct (a2) and then extract it - checkAnswerAndSchema( - df.select($"a1".withField("d", $"a2.b").getField("d").as("res")), - Seq( - Row(5), - Row(5), - Row(null), - Row(null)), - StructType(Seq(StructField("res", IntegerType, nullable = true)))) - } - - test("nullable field from a struct being added to a non-nullable struct and then extracted") { - val df = spark.createDataFrame( - sparkContext.parallelize(Seq( - Row(Row(1, 2, 3), Row(4, 5, 6)), - Row(Row(1, 2, 3), Row(4, 5, 6)), - Row(Row(1, 2, 3), null), - Row(Row(1, 2, 3), null))), - StructType(Seq( - StructField("a1", structType, nullable = false), - StructField("a2", structType, nullable = true)))) - - // add field from the same struct (a1) and then extract it - checkAnswerAndSchema( - df.select($"a1".withField("d", $"a1.b").getField("d").as("res")), - Seq( - Row(2), - Row(2), - Row(2), - Row(2)), - StructType(Seq(StructField("res", IntegerType, nullable = true)))) - - // add field from another struct (a2) and then extract it - checkAnswerAndSchema( - df.select($"a1".withField("d", $"a2.b").getField("d").as("res")), - Seq( - Row(5), - Row(5), - Row(null), - Row(null)), - StructType(Seq(StructField("res", IntegerType, nullable = true)))) - } - - test("should add and drop many fields") { - // stress test withField and dropFields APIs - - val nullableStructLevel2: DataFrame = spark.createDataFrame( - sparkContext.parallelize(Row(Row(null)) :: Row(Row(Row(0))) :: Nil), - StructType(Seq( - StructField("a1", StructType(Seq( - StructField("a2", StructType(Seq( - StructField("col0", IntegerType, nullable = false))), - nullable = true))), - nullable = false)))) - - val maxNum = 100 - - // add many nested fields - checkAnswerAndSchema( - nullableStructLevel2.select(col("a1").withField("a2", (1 to maxNum).foldLeft(col("a1.a2")) { - (column, num) => column.withField(s"col$num", lit(num)) - }).as("a1")), - 
Row(Row(null)) :: Row(Row(Row(0 to maxNum: _*))) :: Nil, - StructType(Seq( - StructField("a1", StructType(Seq( - StructField("a2", StructType((0 to maxNum).map(num => - StructField(s"col$num", IntegerType, nullable = false))), - nullable = true))), - nullable = false)))) - - /** - * You might be tempted to write the above query like so: - * - * {{{ - * nullableStructLevel2.select((1 to maxNum).foldLeft(col("a1")) { - * (column, num) => column.withField(s"a2.col$num", lit(num)) - * }.as("a1")), - * }}} - * - * This query leverages Column.withField API's ability to add/replace nested columns directly. - * However this will likely stall at the analyzer phase with as little as `maxNum = 3`, - * depending on how deeply nested the Column is that you're adding/replacing. - * If you're going to add multiple nested columns, you are better off adding them at the level - * of nesting that you want to add them. - */ - - // add and drop many nested fields - checkAnswerAndSchema( - nullableStructLevel2.select(col("a1").withField("a2", (1 to maxNum).foldLeft(col("a1.a2")) { - (column, num) => column.withField(s"col$num", lit(num)).dropFields(s"col${num - 1}") - }).as("a1")), - Row(Row(null)) :: Row(Row(Row(100))) :: Nil, - StructType(Seq( - StructField("a1", StructType(Seq( - StructField("a2", StructType(Seq( - StructField("col100", IntegerType, nullable = false))), - nullable = true))), - nullable = false)))) - - // add and drop many fields and then extract one of the newly added fields - checkAnswerAndSchema( - nullableStructLevel2.select(col("a1").withField("a2", (1 to maxNum).foldLeft(col("a1.a2")) { - (column, num) => column.withField(s"col$num", lit(num)).dropFields(s"col${num - 1}") - }).getField("a2").getField(s"col$maxNum").as("res")), - Row(null) :: Row(100) :: Nil, - StructType(Seq(StructField("res", IntegerType, nullable = true)))) - - // TODO: add many fields at many different depths of nesting + // add attribute, drop same attribute, add same attribute again, extract added attribute + check(_.withField("c", lit(3)).dropFields("c").withField("c", lit(4)), "c", Some(4)) } test("should move field up one level of nesting") { @@ -2207,7 +2015,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = true)))) // move a field up one level - checkAnswerAndSchema( + checkAnswer( nullableStructLevel2.select( col("a").withField("b", col("a.a.b")).dropFields("a.b").as("res")), Row(Row(null, null)) :: Row(Row(Row(1, 3), 2)) :: Nil, @@ -2221,7 +2029,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = true)))) // move a field up one level and then extract it - checkAnswerAndSchema( + checkAnswer( nullableStructLevel2.select(col("a").withField("b", col("a.a.b")).getField("b").as("res")), Row(null) :: Row(2) :: Nil, StructType(Seq(StructField("res", IntegerType, nullable = true)))) @@ -2232,7 +2040,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { structLevel1.select($"a".withField("d", lit(4)).withField("e", $"a.d" + 1).as("a")) }.getMessage should include("No such struct field d in a, b, c") - checkAnswerAndSchema( + checkAnswer( structLevel1 .select($"a".withField("d", lit(4)).as("a")) .select($"a".withField("e", $"a.d" + 1).as("a")), @@ -2254,7 +2062,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { .select($"a".withField("d", lit(4)).as("a")) .select($"a".dropFields("d").as("a")) ).foreach { query => - checkAnswerAndSchema( + checkAnswer( query, Row(Row(1, null, 3)) :: Nil, 
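        // in each variant the added field "d" is dropped again, so the query returns the
        // original struct value (1, null, 3) unchanged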
StructType(Seq(
@@ -2264,7 +2072,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {

  test("should still be able to refer to dropped column within the same select statement") {
    // we can still access the nested column even after dropping it within the select statement
-    checkAnswerAndSchema(
+    checkAnswer(
      structLevel1.select($"a".dropFields("c").withField("z", $"a.c").as("a")),
      Row(Row(1, null, 3)) :: Nil,
      StructType(Seq(
@@ -2274,8 +2082,8 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
          StructField("z", IntegerType, nullable = false))),
        nullable = false))))

-    // we can't access the nested column in another select statement after dropping it in a previous
-    // select statement
+    // we can't access the nested column in a subsequent select statement after dropping it in a
+    // previous select statement
    intercept[AnalysisException]{
      structLevel1
        .select($"a".dropFields("c").as("a"))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 8469216901b0..12989b342dc1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.storage.StorageLevel

@@ -159,6 +160,14 @@ abstract class QueryTest extends PlanTest {
     checkAnswer(df, expectedAnswer.collect())
   }

+  protected def checkAnswer(
+      df: => DataFrame,
+      expectedAnswer: Seq[Row],
+      expectedSchema: StructType): Unit = {
+    checkAnswer(df, expectedAnswer)
+    assert(df.schema == expectedSchema)
+  }
+
   /**
    * Runs the plan and makes sure the answer is within absTol of the expected result.
    *
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsPerformanceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsPerformanceSuite.scala
new file mode 100644
index 000000000000..a839d079f1e0
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsPerformanceSuite.scala
@@ -0,0 +1,229 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql + +import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +class UpdateFieldsPerformanceSuite extends QueryTest with SharedSparkSession { + + private def colName(d: Int, colNum: Int): String = s"nested${d}Col$colNum" + + private def nestedStructType( + depths: Seq[Int], colNums: Seq[Int], nullable: Boolean): StructType = { + if (depths.length == 1) { + StructType(colNums.map { colNum => + StructField(colName(depths.head, colNum), IntegerType, nullable = false) + }) + } else { + val depth = depths.head + val fields = colNums.foldLeft(Seq.empty[StructField]) { + case (structFields, colNum) if colNum == 0 => + val nested = nestedStructType(depths.tail, colNums, nullable = nullable) + structFields :+ StructField(colName(depth, colNum), nested, nullable = nullable) + case (structFields, colNum) => + structFields :+ StructField(colName(depth, colNum), IntegerType, nullable = false) + } + StructType(fields) + } + } + + private def nestedRow(depths: Seq[Int], colNums: Seq[Int]): Row = { + if (depths.length == 1) { + Row.fromSeq(colNums) + } else { + val values = colNums.foldLeft(Seq.empty[Any]) { + case (values, colNum) if colNum == 0 => values :+ nestedRow(depths.tail, colNums) + case (values, colNum) => values :+ colNum + } + Row.fromSeq(values) + } + } + + /** + * Utility function for generating a DataFrame with nested columns. + * + * @param depth: The depth to which to create nested columns. + * @param numColsAtEachDepth: The number of columns to create at each depth. The columns names + * are in the format of nested${depth}Col${index}. The value of each + * column will be its index at that depth, or if the index of the column + * is 0, then the value could also be a struct. + * @param nullable: This value is used to set the nullability of StructType columns. 
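+   *
+   * For example (an illustrative call, mirroring the tests below): nestedDf(2, 2) returns a
+   * single row Row(Row(Row(0, 1), 1)), where nested1Col0 holds the depth-2 struct and
+   * nested1Col1 holds its own index, 1.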
+ */ + private def nestedDf( + depth: Int, numColsAtEachDepth: Int, nullable: Boolean = false): DataFrame = { + require(depth > 0) + require(numColsAtEachDepth > 0) + + val depths = 1 to depth + val colNums = 0 until numColsAtEachDepth + val nestedColumn = nestedRow(depths, colNums) + val nestedColumnDataType = nestedStructType(depths, colNums, nullable) + + spark.createDataFrame( + sparkContext.parallelize(Row(nestedColumn) :: Nil), + StructType(Seq(StructField(colName(0, 0), nestedColumnDataType, nullable = nullable)))) + } + + test("nestedDf should generate nested DataFrames") { + checkAnswer( + nestedDf(1, 1), + Row(Row(0)) :: Nil, + StructType(Seq(StructField("nested0Col0", StructType(Seq( + StructField("nested1Col0", IntegerType, nullable = false))), + nullable = false)))) + + checkAnswer( + nestedDf(1, 2), + Row(Row(0, 1)) :: Nil, + StructType(Seq(StructField("nested0Col0", StructType(Seq( + StructField("nested1Col0", IntegerType, nullable = false), + StructField("nested1Col1", IntegerType, nullable = false))), + nullable = false)))) + + checkAnswer( + nestedDf(2, 1), + Row(Row(Row(0))) :: Nil, + StructType(Seq(StructField("nested0Col0", StructType(Seq( + StructField("nested1Col0", StructType(Seq( + StructField("nested2Col0", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + + checkAnswer( + nestedDf(2, 2), + Row(Row(Row(0, 1), 1)) :: Nil, + StructType(Seq(StructField("nested0Col0", StructType(Seq( + StructField("nested1Col0", StructType(Seq( + StructField("nested2Col0", IntegerType, nullable = false), + StructField("nested2Col1", IntegerType, nullable = false))), + nullable = false), + StructField("nested1Col1", IntegerType, nullable = false))), + nullable = false)))) + + checkAnswer( + nestedDf(2, 2, nullable = true), + Row(Row(Row(0, 1), 1)) :: Nil, + StructType(Seq(StructField("nested0Col0", StructType(Seq( + StructField("nested1Col0", StructType(Seq( + StructField("nested2Col0", IntegerType, nullable = false), + StructField("nested2Col1", IntegerType, nullable = false))), + nullable = true), + StructField("nested1Col1", IntegerType, nullable = false))), + nullable = true)))) + } + + // simulates how a user would add/drop nested fields in a performant manner + private def addDropNestedColumns( + column: Column, + depths: Seq[Int], + colNumsToAdd: Seq[Int] = Seq.empty, + colNumsToDrop: Seq[Int] = Seq.empty): Column = { + val depth = depths.head + + // drop columns at the current depth + val dropped = if (colNumsToDrop.nonEmpty) { + column.dropFields(colNumsToDrop.map(num => colName(depth, num)): _*) + } else column + + // add columns at the current depth + val added = colNumsToAdd.foldLeft(dropped) { + (res, num) => res.withField(colName(depth, num), lit(num)) + } + + if (depths.length == 1) { + added + } else { + // add/drop columns at the next depth + val nestedColumn = col((0 to depth).map(d => s"`${colName(d, 0)}`").mkString(".")) + added.withField( + colName(depth, 0), + addDropNestedColumns(nestedColumn, depths.tail, colNumsToAdd, colNumsToDrop)) + } + } + + // check both nullable and non-nullable struct code paths are performant + Seq(true, false).foreach { nullable => + test("should add 5 columns at 20 different depths of nesting for a total of 100 columns " + + s"added, nullable = $nullable") { + val maxDepth = 20 + + // dataframe with nested*Col0 to nested*Col4 at each of 20 depths + val inputDf = nestedDf(maxDepth, 5, nullable = nullable) + + // add nested*Col5 through nested*Col9 at each depth + val resultDf = 
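+        // note: this builds a single Column expression containing 5 withField calls at each
+        // of the 20 depths, i.e. the 100 added columns referenced in the test title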
inputDf.select(addDropNestedColumns( + column = col(colName(0, 0)), + depths = 1 to maxDepth, + colNumsToAdd = 5 to 9).as("nested0Col0")) + + // dataframe with nested*Col0 to nested*Col9 at each of 20 depths + val expectedDf = nestedDf(maxDepth, 10, nullable = nullable) + checkAnswer(resultDf, expectedDf.collect(), expectedDf.schema) + } + + test("should drop 5 columns at 20 different depths of nesting for a total of 100 columns " + + s"dropped, nullable = $nullable") { + val maxDepth = 20 + + // dataframe with nested*Col0 to nested*Col9 at each of 20 depths + val inputDf = nestedDf(maxDepth, 10, nullable = nullable) + + // drop nested*Col5 to nested*Col9 at each of 20 depths + val resultDf = inputDf.select(addDropNestedColumns( + column = col(colName(0, 0)), + depths = 1 to maxDepth, + colNumsToDrop = 5 to 9).as("nested0Col0")) + + // dataframe with nested*Col0 to nested*Col4 at each of 20 depths + val expectedDf = nestedDf(maxDepth, 5, nullable = nullable) + checkAnswer(resultDf, expectedDf.collect(), expectedDf.schema) + } + + test("should add 5 columns and drop 5 columns at 20 different depths of nesting for a total " + + s"of 200 columns added/dropped, nullable = $nullable") { + val maxDepth = 20 + + // dataframe with nested*Col0 to nested*Col9 at each of 20 depths + val inputDf = nestedDf(maxDepth, 10, nullable = nullable) + + // add nested*Col10 through nested*Col14 at each depth + // drop nested*Col5 through nested*Col9 at each depth + val resultDf = inputDf.select(addDropNestedColumns( + column = col(colName(0, 0)), + depths = 1 to maxDepth, + colNumsToAdd = 10 to 14, + colNumsToDrop = 5 to 9).as("nested0Col0")) + + // dataframe with nested*Col0 to nested*Col4 and nested*Col10 to nested*Col14 + // at each of 20 depths + val expectedDf = { + val depths = 1 to maxDepth + val numCols = (0 to 4) ++ (10 to 14) + val nestedColumn = nestedRow(depths, numCols) + val nestedColumnDataType = nestedStructType(depths, numCols, nullable = nullable) + + spark.createDataFrame( + sparkContext.parallelize(Row(nestedColumn) :: Nil), + StructType(Seq(StructField(colName(0, 0), nestedColumnDataType, nullable = nullable)))) + } + checkAnswer(resultDf, expectedDf.collect(), expectedDf.schema) + } + } +} From 331af594a7b12c061a8aa5f01d8b6df83a707149 Mon Sep 17 00:00:00 2001 From: "fqaiser94@gmail.com" Date: Tue, 22 Sep 2020 17:42:33 -0400 Subject: [PATCH 03/18] add guard and remove ultimateStruct method --- .../sql/catalyst/optimizer/ComplexTypes.scala | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala index 7548861e716b..0e63cecdcecb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala @@ -40,13 +40,13 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] { // Remove redundant field extraction. 
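      // For example (a sketch of the idea, not code from this patch):
      // GetStructField(CreateNamedStruct("a", x, "b", y), ordinal = 1) simplifies to just y,
      // because the ordinal directly identifies the corresponding value expression.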
case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) => createNamedStruct.valExprs(ordinal) - case GetStructField(updateFields: UpdateFields, ordinal, _) => - val structExpr = updateFields.structExpr - updateFields.newExprs(ordinal) match { + case GetStructField(u: UpdateFields, ordinal, _) if !u.structExpr.isInstanceOf[UpdateFields] => + val structExpr = u.structExpr + u.newExprs(ordinal) match { // if the struct itself is null, then any value extracted from it (expr) will be null // so we don't need to wrap expr in If(IsNull(struct), Literal(null, expr.dataType), expr) case expr: GetStructField if expr.child.semanticEquals(structExpr) => expr - case expr => If(IsNull(ultimateStruct(structExpr)), Literal(null, expr.dataType), expr) + case expr => If(IsNull(structExpr), Literal(null, expr.dataType), expr) } // Remove redundant array indexing. case GetArrayStructFields(CreateArray(elems, useStringTypeWhenEmpty), field, ordinal, _, _) => @@ -69,10 +69,4 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] { case GetMapValue(CreateMap(elems, _), key) => CaseKeyWhen(key, elems) } } - - @scala.annotation.tailrec - private def ultimateStruct(expr: Expression): Expression = expr match { - case e: UpdateFields => ultimateStruct(e.structExpr) - case e => e - } } From 650d366b71982ff496b6f57af66fcf82d77603bf Mon Sep 17 00:00:00 2001 From: "fqaiser94@gmail.com" Date: Tue, 22 Sep 2020 18:03:39 -0400 Subject: [PATCH 04/18] minor cleanup --- .../sql/UpdateFieldsPerformanceSuite.scala | 48 +++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsPerformanceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsPerformanceSuite.scala index a839d079f1e0..7f0ec0a77ffc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsPerformanceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsPerformanceSuite.scala @@ -23,22 +23,22 @@ import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class UpdateFieldsPerformanceSuite extends QueryTest with SharedSparkSession { - private def colName(d: Int, colNum: Int): String = s"nested${d}Col$colNum" + private def nestedColName(d: Int, colNum: Int): String = s"nested${d}Col$colNum" private def nestedStructType( depths: Seq[Int], colNums: Seq[Int], nullable: Boolean): StructType = { if (depths.length == 1) { StructType(colNums.map { colNum => - StructField(colName(depths.head, colNum), IntegerType, nullable = false) + StructField(nestedColName(depths.head, colNum), IntegerType, nullable = false) }) } else { val depth = depths.head val fields = colNums.foldLeft(Seq.empty[StructField]) { case (structFields, colNum) if colNum == 0 => - val nested = nestedStructType(depths.tail, colNums, nullable = nullable) - structFields :+ StructField(colName(depth, colNum), nested, nullable = nullable) + val nested = nestedStructType(depths.tail, colNums, nullable) + structFields :+ StructField(nestedColName(depth, colNum), nested, nullable) case (structFields, colNum) => - structFields :+ StructField(colName(depth, colNum), IntegerType, nullable = false) + structFields :+ StructField(nestedColName(depth, colNum), IntegerType, nullable = false) } StructType(fields) } @@ -60,10 +60,10 @@ class UpdateFieldsPerformanceSuite extends QueryTest with SharedSparkSession { * Utility function for generating a DataFrame with nested columns. * * @param depth: The depth to which to create nested columns. 
- * @param numColsAtEachDepth: The number of columns to create at each depth. The columns names - * are in the format of nested${depth}Col${index}. The value of each - * column will be its index at that depth, or if the index of the column - * is 0, then the value could also be a struct. + * @param numColsAtEachDepth: The number of columns to create at each depth. The value of each + * column will be the same as its index (IntegerType) at that depth + * unless the index = 0, in which case it may be a StructType which + * represents the next depth. * @param nullable: This value is used to set the nullability of StructType columns. */ private def nestedDf( @@ -78,7 +78,7 @@ class UpdateFieldsPerformanceSuite extends QueryTest with SharedSparkSession { spark.createDataFrame( sparkContext.parallelize(Row(nestedColumn) :: Nil), - StructType(Seq(StructField(colName(0, 0), nestedColumnDataType, nullable = nullable)))) + StructType(Seq(StructField(nestedColName(0, 0), nestedColumnDataType, nullable)))) } test("nestedDf should generate nested DataFrames") { @@ -139,21 +139,21 @@ class UpdateFieldsPerformanceSuite extends QueryTest with SharedSparkSession { // drop columns at the current depth val dropped = if (colNumsToDrop.nonEmpty) { - column.dropFields(colNumsToDrop.map(num => colName(depth, num)): _*) + column.dropFields(colNumsToDrop.map(num => nestedColName(depth, num)): _*) } else column // add columns at the current depth val added = colNumsToAdd.foldLeft(dropped) { - (res, num) => res.withField(colName(depth, num), lit(num)) + (res, num) => res.withField(nestedColName(depth, num), lit(num)) } if (depths.length == 1) { added } else { // add/drop columns at the next depth - val nestedColumn = col((0 to depth).map(d => s"`${colName(d, 0)}`").mkString(".")) + val nestedColumn = col((0 to depth).map(d => s"`${nestedColName(d, 0)}`").mkString(".")) added.withField( - colName(depth, 0), + nestedColName(depth, 0), addDropNestedColumns(nestedColumn, depths.tail, colNumsToAdd, colNumsToDrop)) } } @@ -165,16 +165,16 @@ class UpdateFieldsPerformanceSuite extends QueryTest with SharedSparkSession { val maxDepth = 20 // dataframe with nested*Col0 to nested*Col4 at each of 20 depths - val inputDf = nestedDf(maxDepth, 5, nullable = nullable) + val inputDf = nestedDf(maxDepth, 5, nullable) // add nested*Col5 through nested*Col9 at each depth val resultDf = inputDf.select(addDropNestedColumns( - column = col(colName(0, 0)), + column = col(nestedColName(0, 0)), depths = 1 to maxDepth, colNumsToAdd = 5 to 9).as("nested0Col0")) // dataframe with nested*Col0 to nested*Col9 at each of 20 depths - val expectedDf = nestedDf(maxDepth, 10, nullable = nullable) + val expectedDf = nestedDf(maxDepth, 10, nullable) checkAnswer(resultDf, expectedDf.collect(), expectedDf.schema) } @@ -183,16 +183,16 @@ class UpdateFieldsPerformanceSuite extends QueryTest with SharedSparkSession { val maxDepth = 20 // dataframe with nested*Col0 to nested*Col9 at each of 20 depths - val inputDf = nestedDf(maxDepth, 10, nullable = nullable) + val inputDf = nestedDf(maxDepth, 10, nullable) // drop nested*Col5 to nested*Col9 at each of 20 depths val resultDf = inputDf.select(addDropNestedColumns( - column = col(colName(0, 0)), + column = col(nestedColName(0, 0)), depths = 1 to maxDepth, colNumsToDrop = 5 to 9).as("nested0Col0")) // dataframe with nested*Col0 to nested*Col4 at each of 20 depths - val expectedDf = nestedDf(maxDepth, 5, nullable = nullable) + val expectedDf = nestedDf(maxDepth, 5, nullable) checkAnswer(resultDf, 
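      // the three-argument checkAnswer overload (added to QueryTest earlier in this patch
      // series) additionally asserts that resultDf.schema matches expectedDf.schema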
expectedDf.collect(), expectedDf.schema) } @@ -201,12 +201,12 @@ class UpdateFieldsPerformanceSuite extends QueryTest with SharedSparkSession { val maxDepth = 20 // dataframe with nested*Col0 to nested*Col9 at each of 20 depths - val inputDf = nestedDf(maxDepth, 10, nullable = nullable) + val inputDf = nestedDf(maxDepth, 10, nullable) // add nested*Col10 through nested*Col14 at each depth // drop nested*Col5 through nested*Col9 at each depth val resultDf = inputDf.select(addDropNestedColumns( - column = col(colName(0, 0)), + column = col(nestedColName(0, 0)), depths = 1 to maxDepth, colNumsToAdd = 10 to 14, colNumsToDrop = 5 to 9).as("nested0Col0")) @@ -217,11 +217,11 @@ class UpdateFieldsPerformanceSuite extends QueryTest with SharedSparkSession { val depths = 1 to maxDepth val numCols = (0 to 4) ++ (10 to 14) val nestedColumn = nestedRow(depths, numCols) - val nestedColumnDataType = nestedStructType(depths, numCols, nullable = nullable) + val nestedColumnDataType = nestedStructType(depths, numCols, nullable) spark.createDataFrame( sparkContext.parallelize(Row(nestedColumn) :: Nil), - StructType(Seq(StructField(colName(0, 0), nestedColumnDataType, nullable = nullable)))) + StructType(Seq(StructField(nestedColName(0, 0), nestedColumnDataType, nullable)))) } checkAnswer(resultDf, expectedDf.collect(), expectedDf.schema) } From 6b301bbe06cbec2e260cf939a12e7ee02501d561 Mon Sep 17 00:00:00 2001 From: "fqaiser94@gmail.com" Date: Wed, 23 Sep 2020 20:52:49 -0400 Subject: [PATCH 05/18] implement benchmark --- .../UpdateFieldsBenchmark-results.txt | 36 ++ .../spark/sql/UpdateFieldsBenchmark.scala | 310 ++++++++++++++++++ .../sql/UpdateFieldsPerformanceSuite.scala | 229 ------------- 3 files changed, 346 insertions(+), 229 deletions(-) create mode 100644 sql/core/benchmarks/UpdateFieldsBenchmark-results.txt create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsBenchmark.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsPerformanceSuite.scala diff --git a/sql/core/benchmarks/UpdateFieldsBenchmark-results.txt b/sql/core/benchmarks/UpdateFieldsBenchmark-results.txt new file mode 100644 index 000000000000..4e12ba9fd727 --- /dev/null +++ b/sql/core/benchmarks/UpdateFieldsBenchmark-results.txt @@ -0,0 +1,36 @@ +================================================================================================ +Add 5 columns and drop 0 columns at 20 different depths of nesting +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_212-b03 on Mac OS X 10.14.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Add 5 columns and drop 0 columns at 20 different depths of nesting: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +-------------------------------------------------------------------------------------------------------------------------------------------------- +Non-Nullable StructTypes 239 299 57 0.0 239041387.0 1.0X +Nullable StructTypes 249 275 27 0.0 249397898.0 1.0X + + +================================================================================================ +Add 0 columns and drop 5 columns at 20 different depths of nesting +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_212-b03 on Mac OS X 10.14.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Add 0 columns and drop 5 columns at 20 different depths of nesting: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) 
Relative
+--------------------------------------------------------------------------------------------------------------------------------------------------
+Non-Nullable StructTypes 288 299 10 0.0 287999562.0 1.0X
+Nullable StructTypes 342 348 5 0.0 341891672.0 0.8X
+
+
+================================================================================================
+Add 5 columns and drop 5 columns at 20 different depths of nesting
+================================================================================================
+
+OpenJDK 64-Bit Server VM 1.8.0_212-b03 on Mac OS X 10.14.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+Add 5 columns and drop 5 columns at 20 different depths of nesting: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+--------------------------------------------------------------------------------------------------------------------------------------------------
+Non-Nullable StructTypes 325 328 4 0.0 324817445.0 1.0X
+Nullable StructTypes 374 395 25 0.0 373766295.0 0.9X
+
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsBenchmark.scala
new file mode 100644
index 000000000000..eafc70dbc907
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsBenchmark.scala
@@ -0,0 +1,310 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+
+/**
+ * Benchmark to measure Spark's performance analyzing and optimizing long UpdateFields chains.
+ *
+ * {{{
+ * To run this benchmark:
+ * 1. without sbt:
+ *    bin/spark-submit --class <this class> <spark sql test jar>
+ * 2. with sbt:
+ *    build/sbt "sql/test:runMain <this class>"
+ * 3. generate result:
+ *    SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain <this class>"
+ *    Results will be written to "benchmarks/UpdateFieldsBenchmark-results.txt".
+ * }}} + */ +object UpdateFieldsBenchmark extends SqlBasedBenchmark { + + private def nestedColName(d: Int, colNum: Int): String = s"nested${d}Col$colNum" + + private def nestedStructType( + colNums: Seq[Int], + nullable: Boolean, + maxDepth: Int, + currDepth: Int = 1): StructType = { + + if (currDepth == maxDepth) { + val fields = colNums.map { colNum => + val name = nestedColName(currDepth, colNum) + StructField(name, IntegerType, nullable = false) + } + StructType(fields) + } else { + val fields = colNums.foldLeft(Seq.empty[StructField]) { + case (structFields, colNum) if colNum == 0 => + val nested = nestedStructType(colNums, nullable, maxDepth, currDepth + 1) + structFields :+ StructField(nestedColName(currDepth, colNum), nested, nullable) + case (structFields, colNum) => + val name = nestedColName(currDepth, colNum) + structFields :+ StructField(name, IntegerType, nullable = false) + } + StructType(fields) + } + } + + private def nestedRow(colNums: Seq[Int], maxDepth: Int, currDepth: Int = 1): Row = { + if (currDepth == maxDepth) { + Row.fromSeq(colNums) + } else { + val values = colNums.foldLeft(Seq.empty[Any]) { + case (values, colNum) if colNum == 0 => + values :+ nestedRow(colNums, maxDepth, currDepth + 1) + case (values, colNum) => + values :+ colNum + } + Row.fromSeq(values) + } + } + + /** + * Utility function for generating a DataFrame with nested columns. + * + * @param maxDepth: The depth to which to create nested columns. + * @param numColsAtEachDepth: The number of columns to create at each depth. The value of each + * column will be the same as its index (IntegerType) at that depth + * unless the index = 0, in which case it may be a StructType which + * represents the next depth. + * @param nullable: This value is used to set the nullability of StructType columns. 
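+   *
+   * For instance (illustrative, matching the tests below): nestedDf(maxDepth = 1,
+   * numColsAtEachDepth = 2, nullable = false) yields the single row Row(Row(0, 1)).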
+ */ + def nestedDf(maxDepth: Int, numColsAtEachDepth: Int, nullable: Boolean): DataFrame = { + require(maxDepth > 0) + require(numColsAtEachDepth > 0) + + val colNums = 0 until numColsAtEachDepth + val nestedColumn = nestedRow(colNums, maxDepth) + val nestedColumnDataType = nestedStructType(colNums, nullable, maxDepth) + + spark.createDataFrame( + spark.sparkContext.parallelize(Row(nestedColumn) :: Nil), + StructType(Seq(StructField(nestedColName(0, 0), nestedColumnDataType, nullable)))) + } + + // simulates how a user would add/drop nested fields in a performant manner + def modifyNestedColumns( + column: Column, + numsToAdd: Seq[Int], + numsToDrop: Seq[Int], + maxDepth: Int, + currDepth: Int = 1): Column = { + + // drop columns at the current depth + val dropped = if (numsToDrop.nonEmpty) { + column.dropFields(numsToDrop.map(num => nestedColName(currDepth, num)): _*) + } else column + + // add columns at the current depth + val added = numsToAdd.foldLeft(dropped) { + (res, num) => res.withField(nestedColName(currDepth, num), lit(num)) + } + + if (currDepth == maxDepth) { + added + } else { + // add/drop columns at the next depth + val newValue = modifyNestedColumns( + column = col((0 to currDepth).map(d => nestedColName(d, 0)).mkString(".")), + numsToAdd = numsToAdd, + numsToDrop = numsToDrop, + currDepth = currDepth + 1, + maxDepth = maxDepth) + added.withField(nestedColName(currDepth, 0), newValue) + } + } + + def updateFieldsBenchmark( + maxDepth: Int, + initialNumberOfColumns: Int, + numsToAdd: Seq[Int] = Seq.empty, + numsToDrop: Seq[Int] = Seq.empty): Unit = { + + val name = s"Add ${numsToAdd.length} columns and drop ${numsToDrop.length} columns " + + s"at $maxDepth different depths of nesting" + + runBenchmark(name) { + val benchmark = new Benchmark( + name = name, + // Because the point of this benchmark is only to ensure Spark is able to analyze and + // optimize long UpdateFields chains quickly, this benchmark operates only over 1 row of + // data. 
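+        // (and because there is only a single row, the "Per Row(ns)" column in the generated
+        // results is effectively just the time per iteration)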
+ valuesPerIteration = 1, + output = output) + + val columnFunc = modifyNestedColumns( + col(nestedColName(0, 0)), + numsToAdd, + numsToDrop, + maxDepth + ).as(nestedColName(0, 0)) + + val nonNullableInputDf = nestedDf(maxDepth, initialNumberOfColumns, nullable = false) + val nullableInputDf = nestedDf(maxDepth, initialNumberOfColumns, nullable = true) + + benchmark.addCase("Non-Nullable StructTypes") { _ => + nonNullableInputDf.select(columnFunc).noop() + } + + benchmark.addCase("Nullable StructTypes") { _ => + nullableInputDf.select(columnFunc).noop() + } + + benchmark.run() + } + } + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + val maxDepth = 20 + + updateFieldsBenchmark( + maxDepth = maxDepth, + initialNumberOfColumns = 5, + numsToAdd = 5 to 9) + + updateFieldsBenchmark( + maxDepth = maxDepth, + initialNumberOfColumns = 10, + numsToDrop = 5 to 9) + + updateFieldsBenchmark( + maxDepth = maxDepth, + initialNumberOfColumns = 10, + numsToAdd = 10 to 14, + numsToDrop = 5 to 9) + } +} + +class UpdateFieldsBenchmark extends QueryTest with SharedSparkSession { + import UpdateFieldsBenchmark._ + + test("nestedDf should generate nested DataFrames") { + checkAnswer( + nestedDf(1, 1, nullable = false), + Row(Row(0)) :: Nil, + StructType(Seq(StructField("nested0Col0", StructType(Seq( + StructField("nested1Col0", IntegerType, nullable = false))), + nullable = false)))) + + checkAnswer( + nestedDf(1, 2, nullable = false), + Row(Row(0, 1)) :: Nil, + StructType(Seq(StructField("nested0Col0", StructType(Seq( + StructField("nested1Col0", IntegerType, nullable = false), + StructField("nested1Col1", IntegerType, nullable = false))), + nullable = false)))) + + checkAnswer( + nestedDf(2, 1, nullable = false), + Row(Row(Row(0))) :: Nil, + StructType(Seq(StructField("nested0Col0", StructType(Seq( + StructField("nested1Col0", StructType(Seq( + StructField("nested2Col0", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + + checkAnswer( + nestedDf(2, 2, nullable = false), + Row(Row(Row(0, 1), 1)) :: Nil, + StructType(Seq(StructField("nested0Col0", StructType(Seq( + StructField("nested1Col0", StructType(Seq( + StructField("nested2Col0", IntegerType, nullable = false), + StructField("nested2Col1", IntegerType, nullable = false))), + nullable = false), + StructField("nested1Col1", IntegerType, nullable = false))), + nullable = false)))) + + checkAnswer( + nestedDf(2, 2, nullable = true), + Row(Row(Row(0, 1), 1)) :: Nil, + StructType(Seq(StructField("nested0Col0", StructType(Seq( + StructField("nested1Col0", StructType(Seq( + StructField("nested2Col0", IntegerType, nullable = false), + StructField("nested2Col1", IntegerType, nullable = false))), + nullable = true), + StructField("nested1Col1", IntegerType, nullable = false))), + nullable = true)))) + } + + private val maxDepth = 3 + + test("modifyNestedColumns should add 5 columns at each depth of nesting") { + // dataframe with nested*Col0 to nested*Col4 at each depth + val inputDf = nestedDf(maxDepth, 5, nullable = false) + + // add nested*Col5 through nested*Col9 at each depth + val resultDf = inputDf.select(modifyNestedColumns( + column = col(nestedColName(0, 0)), + numsToAdd = 5 to 9, + numsToDrop = Seq.empty, + maxDepth = maxDepth + ).as("nested0Col0")) + + // dataframe with nested*Col0 to nested*Col9 at each depth + val expectedDf = nestedDf(maxDepth, 10, nullable = false) + checkAnswer(resultDf, expectedDf.collect(), expectedDf.schema) + } + + test("modifyNestedColumns should drop 5 columns at each depth 
of nesting") { + // dataframe with nested*Col0 to nested*Col9 at each depth + val inputDf = nestedDf(maxDepth, 10, nullable = false) + + // drop nested*Col5 to nested*Col9 at each of 20 depths + val resultDf = inputDf.select(modifyNestedColumns( + column = col(nestedColName(0, 0)), + numsToAdd = Seq.empty, + numsToDrop = 5 to 9, + maxDepth = maxDepth + ).as("nested0Col0")) + + // dataframe with nested*Col0 to nested*Col4 at each depth + val expectedDf = nestedDf(maxDepth, 5, nullable = false) + checkAnswer(resultDf, expectedDf.collect(), expectedDf.schema) + } + + test("modifyNestedColumns should add and drop 5 columns at each depth of nesting") { + // dataframe with nested*Col0 to nested*Col9 at each depth + val inputDf = nestedDf(maxDepth, 10, nullable = false) + + // drop nested*Col5 to nested*Col9 at each depth + val resultDf = inputDf.select(modifyNestedColumns( + column = col(nestedColName(0, 0)), + numsToAdd = 10 to 14, + numsToDrop = 5 to 9, + maxDepth = maxDepth + ).as("nested0Col0")) + + // dataframe with nested*Col0 to nested*Col4 and nested*Col10 to nested*Col14 at each depth + val expectedDf = { + val numCols = (0 to 4) ++ (10 to 14) + val nestedColumn = nestedRow(numCols, maxDepth) + val nestedColumnDataType = nestedStructType(numCols, nullable = false, maxDepth) + + spark.createDataFrame( + sparkContext.parallelize(Row(nestedColumn) :: Nil), + StructType(Seq(StructField(nestedColName(0, 0), nestedColumnDataType, nullable = false)))) + } + checkAnswer(resultDf, expectedDf.collect(), expectedDf.schema) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsPerformanceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsPerformanceSuite.scala deleted file mode 100644 index 7f0ec0a77ffc..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsPerformanceSuite.scala +++ /dev/null @@ -1,229 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql - -import org.apache.spark.sql.functions.{col, lit} -import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{IntegerType, StructField, StructType} - -class UpdateFieldsPerformanceSuite extends QueryTest with SharedSparkSession { - - private def nestedColName(d: Int, colNum: Int): String = s"nested${d}Col$colNum" - - private def nestedStructType( - depths: Seq[Int], colNums: Seq[Int], nullable: Boolean): StructType = { - if (depths.length == 1) { - StructType(colNums.map { colNum => - StructField(nestedColName(depths.head, colNum), IntegerType, nullable = false) - }) - } else { - val depth = depths.head - val fields = colNums.foldLeft(Seq.empty[StructField]) { - case (structFields, colNum) if colNum == 0 => - val nested = nestedStructType(depths.tail, colNums, nullable) - structFields :+ StructField(nestedColName(depth, colNum), nested, nullable) - case (structFields, colNum) => - structFields :+ StructField(nestedColName(depth, colNum), IntegerType, nullable = false) - } - StructType(fields) - } - } - - private def nestedRow(depths: Seq[Int], colNums: Seq[Int]): Row = { - if (depths.length == 1) { - Row.fromSeq(colNums) - } else { - val values = colNums.foldLeft(Seq.empty[Any]) { - case (values, colNum) if colNum == 0 => values :+ nestedRow(depths.tail, colNums) - case (values, colNum) => values :+ colNum - } - Row.fromSeq(values) - } - } - - /** - * Utility function for generating a DataFrame with nested columns. - * - * @param depth: The depth to which to create nested columns. - * @param numColsAtEachDepth: The number of columns to create at each depth. The value of each - * column will be the same as its index (IntegerType) at that depth - * unless the index = 0, in which case it may be a StructType which - * represents the next depth. - * @param nullable: This value is used to set the nullability of StructType columns. 
- */ - private def nestedDf( - depth: Int, numColsAtEachDepth: Int, nullable: Boolean = false): DataFrame = { - require(depth > 0) - require(numColsAtEachDepth > 0) - - val depths = 1 to depth - val colNums = 0 until numColsAtEachDepth - val nestedColumn = nestedRow(depths, colNums) - val nestedColumnDataType = nestedStructType(depths, colNums, nullable) - - spark.createDataFrame( - sparkContext.parallelize(Row(nestedColumn) :: Nil), - StructType(Seq(StructField(nestedColName(0, 0), nestedColumnDataType, nullable)))) - } - - test("nestedDf should generate nested DataFrames") { - checkAnswer( - nestedDf(1, 1), - Row(Row(0)) :: Nil, - StructType(Seq(StructField("nested0Col0", StructType(Seq( - StructField("nested1Col0", IntegerType, nullable = false))), - nullable = false)))) - - checkAnswer( - nestedDf(1, 2), - Row(Row(0, 1)) :: Nil, - StructType(Seq(StructField("nested0Col0", StructType(Seq( - StructField("nested1Col0", IntegerType, nullable = false), - StructField("nested1Col1", IntegerType, nullable = false))), - nullable = false)))) - - checkAnswer( - nestedDf(2, 1), - Row(Row(Row(0))) :: Nil, - StructType(Seq(StructField("nested0Col0", StructType(Seq( - StructField("nested1Col0", StructType(Seq( - StructField("nested2Col0", IntegerType, nullable = false))), - nullable = false))), - nullable = false)))) - - checkAnswer( - nestedDf(2, 2), - Row(Row(Row(0, 1), 1)) :: Nil, - StructType(Seq(StructField("nested0Col0", StructType(Seq( - StructField("nested1Col0", StructType(Seq( - StructField("nested2Col0", IntegerType, nullable = false), - StructField("nested2Col1", IntegerType, nullable = false))), - nullable = false), - StructField("nested1Col1", IntegerType, nullable = false))), - nullable = false)))) - - checkAnswer( - nestedDf(2, 2, nullable = true), - Row(Row(Row(0, 1), 1)) :: Nil, - StructType(Seq(StructField("nested0Col0", StructType(Seq( - StructField("nested1Col0", StructType(Seq( - StructField("nested2Col0", IntegerType, nullable = false), - StructField("nested2Col1", IntegerType, nullable = false))), - nullable = true), - StructField("nested1Col1", IntegerType, nullable = false))), - nullable = true)))) - } - - // simulates how a user would add/drop nested fields in a performant manner - private def addDropNestedColumns( - column: Column, - depths: Seq[Int], - colNumsToAdd: Seq[Int] = Seq.empty, - colNumsToDrop: Seq[Int] = Seq.empty): Column = { - val depth = depths.head - - // drop columns at the current depth - val dropped = if (colNumsToDrop.nonEmpty) { - column.dropFields(colNumsToDrop.map(num => nestedColName(depth, num)): _*) - } else column - - // add columns at the current depth - val added = colNumsToAdd.foldLeft(dropped) { - (res, num) => res.withField(nestedColName(depth, num), lit(num)) - } - - if (depths.length == 1) { - added - } else { - // add/drop columns at the next depth - val nestedColumn = col((0 to depth).map(d => s"`${nestedColName(d, 0)}`").mkString(".")) - added.withField( - nestedColName(depth, 0), - addDropNestedColumns(nestedColumn, depths.tail, colNumsToAdd, colNumsToDrop)) - } - } - - // check both nullable and non-nullable struct code paths are performant - Seq(true, false).foreach { nullable => - test("should add 5 columns at 20 different depths of nesting for a total of 100 columns " + - s"added, nullable = $nullable") { - val maxDepth = 20 - - // dataframe with nested*Col0 to nested*Col4 at each of 20 depths - val inputDf = nestedDf(maxDepth, 5, nullable) - - // add nested*Col5 through nested*Col9 at each depth - val resultDf = 
inputDf.select(addDropNestedColumns( - column = col(nestedColName(0, 0)), - depths = 1 to maxDepth, - colNumsToAdd = 5 to 9).as("nested0Col0")) - - // dataframe with nested*Col0 to nested*Col9 at each of 20 depths - val expectedDf = nestedDf(maxDepth, 10, nullable) - checkAnswer(resultDf, expectedDf.collect(), expectedDf.schema) - } - - test("should drop 5 columns at 20 different depths of nesting for a total of 100 columns " + - s"dropped, nullable = $nullable") { - val maxDepth = 20 - - // dataframe with nested*Col0 to nested*Col9 at each of 20 depths - val inputDf = nestedDf(maxDepth, 10, nullable) - - // drop nested*Col5 to nested*Col9 at each of 20 depths - val resultDf = inputDf.select(addDropNestedColumns( - column = col(nestedColName(0, 0)), - depths = 1 to maxDepth, - colNumsToDrop = 5 to 9).as("nested0Col0")) - - // dataframe with nested*Col0 to nested*Col4 at each of 20 depths - val expectedDf = nestedDf(maxDepth, 5, nullable) - checkAnswer(resultDf, expectedDf.collect(), expectedDf.schema) - } - - test("should add 5 columns and drop 5 columns at 20 different depths of nesting for a total " + - s"of 200 columns added/dropped, nullable = $nullable") { - val maxDepth = 20 - - // dataframe with nested*Col0 to nested*Col9 at each of 20 depths - val inputDf = nestedDf(maxDepth, 10, nullable) - - // add nested*Col10 through nested*Col14 at each depth - // drop nested*Col5 through nested*Col9 at each depth - val resultDf = inputDf.select(addDropNestedColumns( - column = col(nestedColName(0, 0)), - depths = 1 to maxDepth, - colNumsToAdd = 10 to 14, - colNumsToDrop = 5 to 9).as("nested0Col0")) - - // dataframe with nested*Col0 to nested*Col4 and nested*Col10 to nested*Col14 - // at each of 20 depths - val expectedDf = { - val depths = 1 to maxDepth - val numCols = (0 to 4) ++ (10 to 14) - val nestedColumn = nestedRow(depths, numCols) - val nestedColumnDataType = nestedStructType(depths, numCols, nullable) - - spark.createDataFrame( - sparkContext.parallelize(Row(nestedColumn) :: Nil), - StructType(Seq(StructField(nestedColName(0, 0), nestedColumnDataType, nullable)))) - } - checkAnswer(resultDf, expectedDf.collect(), expectedDf.schema) - } - } -} From 25ff6fa4f93de4f144860b5449985e8d67452baa Mon Sep 17 00:00:00 2001 From: "fqaiser94@gmail.com" Date: Thu, 24 Sep 2020 20:12:45 -0400 Subject: [PATCH 06/18] replace unresolved with illegal state exception --- .../spark/sql/catalyst/expressions/complexTypeCreator.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 7ca625129bbf..ecdfb303ed5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -578,9 +578,11 @@ case class WithField(name: String, valExpr: Expression) override def children: Seq[Expression] = valExpr :: Nil - override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def dataType: DataType = throw new IllegalStateException( + "WithField.dataType should not be called.") - override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override def nullable: Boolean = throw new IllegalStateException( + "WithField.nullable should not be called.") override def prettyName: String = "WithField" } From 
f7dd0f4edca6a38cfcbef9405a18433157d7e94c Mon Sep 17 00:00:00 2001 From: "fqaiser94@gmail.com" Date: Thu, 24 Sep 2020 20:15:44 -0400 Subject: [PATCH 07/18] use local variable --- .../expressions/complexTypeCreator.scala | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index ecdfb303ed5f..791d2d5b08bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -623,25 +623,28 @@ case class UpdateFields(structExpr: Expression, fieldOps: Seq[StructFieldsOperat override def prettyName: String = "update_fields" - private lazy val existingFieldExprs: Seq[(StructField, Expression)] = - structExpr.dataType.asInstanceOf[StructType].fields.zipWithIndex.map { - case (field, i) => (field, GetStructField(structExpr, i)) - } + private lazy val newFieldExprs: Seq[(StructField, Expression)] = { + val existingFieldExprs: Seq[(StructField, Expression)] = + structExpr.dataType.asInstanceOf[StructType].fields.zipWithIndex.map { + case (field, i) => (field, GetStructField(structExpr, i)) + } - private lazy val newFieldExprs: Seq[(StructField, Expression)] = fieldOps.foldLeft(existingFieldExprs)((exprs, op) => op(exprs)) + } private lazy val newFields: Seq[StructField] = newFieldExprs.map(_._1) lazy val newExprs: Seq[Expression] = newFieldExprs.map(_._2) - private lazy val createNamedStructExpr = CreateNamedStruct(newFieldExprs.flatMap { - case (field, expr) => Seq(Literal(field.name), expr) - }) + lazy val evalExpr: Expression = { + val createNamedStructExpr = CreateNamedStruct(newFieldExprs.flatMap { + case (field, expr) => Seq(Literal(field.name), expr) + }) - lazy val evalExpr: Expression = if (structExpr.nullable) { - If(IsNull(structExpr), Literal(null, dataType), createNamedStructExpr) - } else { - createNamedStructExpr + if (structExpr.nullable) { + If(IsNull(structExpr), Literal(null, dataType), createNamedStructExpr) + } else { + createNamedStructExpr + } } } From f727fac2e31867b0b4668b15df9c5f2f7d0dad61 Mon Sep 17 00:00:00 2001 From: "fqaiser94@gmail.com" Date: Thu, 24 Sep 2020 20:56:37 -0400 Subject: [PATCH 08/18] avoid iterating the values twice --- .../expressions/complexTypeCreator.scala | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 791d2d5b08bf..b457a25921b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion, UnresolvedException} +import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder} import org.apache.spark.sql.catalyst.expressions.codegen._ import 
org.apache.spark.sql.catalyst.expressions.codegen.Block._ @@ -566,14 +568,18 @@ case class WithField(name: String, valExpr: Expression) override def apply(values: Seq[(StructField, Expression)]): Seq[(StructField, Expression)] = { val newFieldExpr = (StructField(name, valExpr.dataType, valExpr.nullable), valExpr) - if (values.exists { case (field, _) => resolver(field.name, name) }) { - values.map { - case (field, _) if resolver(field.name, name) => newFieldExpr - case x => x + val result = ArrayBuffer.empty[(StructField, Expression)] + var hasMatch = false + for (existingFieldExpr @ (existingField, _) <- values) { + if (resolver(existingField.name, name)) { + hasMatch = true + result.append(newFieldExpr) + } else { + result.append(existingFieldExpr) } - } else { - values :+ newFieldExpr } + if (!hasMatch) result += newFieldExpr + result } override def children: Seq[Expression] = valExpr :: Nil From f77921d9bc0e2f4dfb356c11ba7229e5dd042104 Mon Sep 17 00:00:00 2001 From: "fqaiser94@gmail.com" Date: Thu, 24 Sep 2020 20:58:55 -0400 Subject: [PATCH 09/18] explicitly mention noop --- sql/core/src/main/scala/org/apache/spark/sql/Column.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index f8daaf9c58c1..dfc47a963f81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -931,6 +931,7 @@ class Column(val expr: Expression) extends Logging { // scalastyle:off line.size.limit /** * An expression that drops fields in `StructType` by name. + * This is a no-op if schema doesn't contain field name(s). * * {{{ * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col") From 9f3d7bccdcaad8525ca9c6d6823c779c614dda7f Mon Sep 17 00:00:00 2001 From: "fqaiser94@gmail.com" Date: Thu, 24 Sep 2020 21:18:07 -0400 Subject: [PATCH 10/18] add test to show dropFields should drop field with no name in struct --- .../spark/sql/ColumnExpressionSuite.scala | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 6423ede59350..ecffd21e8279 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -1626,6 +1626,24 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { }.getMessage should include("cannot drop all fields in struct") } + test("dropFields should drop field with no name in struct") { + val structType = StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("", IntegerType, nullable = false))) + + val structLevel1: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(Row(1, 2)) :: Nil), + StructType(Seq(StructField("a", structType, nullable = false)))) + + checkAnswer( + structLevel1.withColumn("a", $"a".dropFields("")), + Row(Row(1)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false))), + nullable = false)))) + } + test("dropFields should drop field in nested struct") { checkAnswer( structLevel2.withColumn("a", 'a.dropFields("a.b")), From 430b996c5a49ffabf7581342684dfa2e771d01e8 Mon Sep 17 00:00:00 2001 From: "fqaiser94@gmail.com" Date: Thu, 24 Sep 2020 21:21:19 -0400 Subject: [PATCH 11/18] 4 space indentation --- 
.../scala/org/apache/spark/sql/Column.scala | 6 ++--- .../org/apache/spark/sql/QueryTest.scala | 6 ++--- .../spark/sql/UpdateFieldsBenchmark.scala | 26 +++++++++---------- 3 files changed, 19 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index dfc47a963f81..a46d6c0bb228 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -1008,9 +1008,9 @@ class Column(val expr: Expression) extends Logging { } private def updateFieldsHelper( - structExpr: Expression, - namePartsRemaining: Seq[String], - valueFunc: String => StructFieldsOperation): UpdateFields = { + structExpr: Expression, + namePartsRemaining: Seq[String], + valueFunc: String => StructFieldsOperation): UpdateFields = { val fieldName = namePartsRemaining.head if (namePartsRemaining.length == 1) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 12989b342dc1..8ac6af2ac25d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -161,9 +161,9 @@ abstract class QueryTest extends PlanTest { } protected def checkAnswer( - df: => DataFrame, - expectedAnswer: Seq[Row], - expectedSchema: StructType): Unit = { + df: => DataFrame, + expectedAnswer: Seq[Row], + expectedSchema: StructType): Unit = { checkAnswer(df, expectedAnswer) assert(df.schema == expectedSchema) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsBenchmark.scala index eafc70dbc907..ac3888ad99be 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsBenchmark.scala @@ -42,10 +42,10 @@ object UpdateFieldsBenchmark extends SqlBasedBenchmark { private def nestedColName(d: Int, colNum: Int): String = s"nested${d}Col$colNum" private def nestedStructType( - colNums: Seq[Int], - nullable: Boolean, - maxDepth: Int, - currDepth: Int = 1): StructType = { + colNums: Seq[Int], + nullable: Boolean, + maxDepth: Int, + currDepth: Int = 1): StructType = { if (currDepth == maxDepth) { val fields = colNums.map { colNum => @@ -105,11 +105,11 @@ object UpdateFieldsBenchmark extends SqlBasedBenchmark { // simulates how a user would add/drop nested fields in a performant manner def modifyNestedColumns( - column: Column, - numsToAdd: Seq[Int], - numsToDrop: Seq[Int], - maxDepth: Int, - currDepth: Int = 1): Column = { + column: Column, + numsToAdd: Seq[Int], + numsToDrop: Seq[Int], + maxDepth: Int, + currDepth: Int = 1): Column = { // drop columns at the current depth val dropped = if (numsToDrop.nonEmpty) { @@ -136,10 +136,10 @@ object UpdateFieldsBenchmark extends SqlBasedBenchmark { } def updateFieldsBenchmark( - maxDepth: Int, - initialNumberOfColumns: Int, - numsToAdd: Seq[Int] = Seq.empty, - numsToDrop: Seq[Int] = Seq.empty): Unit = { + maxDepth: Int, + initialNumberOfColumns: Int, + numsToAdd: Seq[Int] = Seq.empty, + numsToDrop: Seq[Int] = Seq.empty): Unit = { val name = s"Add ${numsToAdd.length} columns and drop ${numsToDrop.length} columns " + s"at $maxDepth different depths of nesting" From 53d83b635b6e840eb1a66f1c89b453729f589500 Mon Sep 17 00:00:00 2001 From: "fqaiser94@gmail.com" Date: Thu, 24 Sep 2020 21:36:22 -0400 Subject: 
[PATCH 12/18] use += instead of append method for consistency --- .../spark/sql/catalyst/expressions/complexTypeCreator.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index b457a25921b2..5058bdd5264c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -573,9 +573,9 @@ case class WithField(name: String, valExpr: Expression) for (existingFieldExpr @ (existingField, _) <- values) { if (resolver(existingField.name, name)) { hasMatch = true - result.append(newFieldExpr) + result += newFieldExpr } else { - result.append(existingFieldExpr) + result += existingFieldExpr } } if (!hasMatch) result += newFieldExpr From dffc0e84bff935e3b4bdf3680ba62fc085441603 Mon Sep 17 00:00:00 2001 From: "fqaiser94@gmail.com" Date: Fri, 25 Sep 2020 14:07:06 -0400 Subject: [PATCH 13/18] use .queryExecution.optimizedPlan instead of .noop() --- .../benchmarks/UpdateFieldsBenchmark-results.txt | 12 ++++++------ .../org/apache/spark/sql/UpdateFieldsBenchmark.scala | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sql/core/benchmarks/UpdateFieldsBenchmark-results.txt b/sql/core/benchmarks/UpdateFieldsBenchmark-results.txt index 4e12ba9fd727..9e7cd88eac33 100644 --- a/sql/core/benchmarks/UpdateFieldsBenchmark-results.txt +++ b/sql/core/benchmarks/UpdateFieldsBenchmark-results.txt @@ -6,8 +6,8 @@ OpenJDK 64-Bit Server VM 1.8.0_212-b03 on Mac OS X 10.14.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Add 5 columns and drop 0 columns at 20 different depths of nesting: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------------------- -Non-Nullable StructTypes 239 299 57 0.0 239041387.0 1.0X -Nullable StructTypes 249 275 27 0.0 249397898.0 1.0X +Non-Nullable StructTypes 64 70 7 0.0 64023392.0 1.0X +Nullable StructTypes 65 69 4 0.0 65179698.0 1.0X ================================================================================================ @@ -18,8 +18,8 @@ OpenJDK 64-Bit Server VM 1.8.0_212-b03 on Mac OS X 10.14.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Add 0 columns and drop 5 columns at 20 different depths of nesting: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------------------- -Non-Nullable StructTypes 288 299 10 0.0 287999562.0 1.0X -Nullable StructTypes 342 348 5 0.0 341891672.0 0.8X +Non-Nullable StructTypes 42 47 6 0.0 42129670.0 1.0X +Nullable StructTypes 44 48 5 0.0 44021916.0 1.0X ================================================================================================ @@ -30,7 +30,7 @@ OpenJDK 64-Bit Server VM 1.8.0_212-b03 on Mac OS X 10.14.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz Add 5 columns and drop 5 columns at 20 different depths of nesting: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative -------------------------------------------------------------------------------------------------------------------------------------------------- -Non-Nullable StructTypes 325 328 4 0.0 324817445.0 
1.0X -Nullable StructTypes 374 395 25 0.0 373766295.0 0.9X +Non-Nullable StructTypes 74 77 2 0.0 73518396.0 1.0X +Nullable StructTypes 77 79 2 0.0 76520609.0 1.0X diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsBenchmark.scala index ac3888ad99be..94c172482b69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsBenchmark.scala @@ -153,7 +153,7 @@ object UpdateFieldsBenchmark extends SqlBasedBenchmark { valuesPerIteration = 1, output = output) - val columnFunc = modifyNestedColumns( + val modifiedColumn = modifyNestedColumns( col(nestedColName(0, 0)), numsToAdd, numsToDrop, @@ -164,11 +164,11 @@ object UpdateFieldsBenchmark extends SqlBasedBenchmark { val nullableInputDf = nestedDf(maxDepth, initialNumberOfColumns, nullable = true) benchmark.addCase("Non-Nullable StructTypes") { _ => - nonNullableInputDf.select(columnFunc).noop() + nonNullableInputDf.select(modifiedColumn).queryExecution.optimizedPlan } benchmark.addCase("Nullable StructTypes") { _ => - nullableInputDf.select(columnFunc).noop() + nullableInputDf.select(modifiedColumn).queryExecution.optimizedPlan } benchmark.run() From 2ef7bc877445fcb2e514b92056da76732235f1a4 Mon Sep 17 00:00:00 2001 From: "fqaiser94@gmail.com" Date: Fri, 25 Sep 2020 16:55:17 -0400 Subject: [PATCH 14/18] benchmark both performant and non-performant way, remove unnecessary tests and methods --- .../UpdateFieldsBenchmark-results.txt | 34 +- .../spark/sql/UpdateFieldsBenchmark.scala | 298 +++++++++--------- 2 files changed, 161 insertions(+), 171 deletions(-) diff --git a/sql/core/benchmarks/UpdateFieldsBenchmark-results.txt b/sql/core/benchmarks/UpdateFieldsBenchmark-results.txt index 9e7cd88eac33..5feca0e100bb 100644 --- a/sql/core/benchmarks/UpdateFieldsBenchmark-results.txt +++ b/sql/core/benchmarks/UpdateFieldsBenchmark-results.txt @@ -1,36 +1,26 @@ ================================================================================================ -Add 5 columns and drop 0 columns at 20 different depths of nesting +Add 2 columns and drop 2 columns at 3 different depths of nesting ================================================================================================ OpenJDK 64-Bit Server VM 1.8.0_212-b03 on Mac OS X 10.14.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz -Add 5 columns and drop 0 columns at 20 different depths of nesting: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative --------------------------------------------------------------------------------------------------------------------------------------------------- -Non-Nullable StructTypes 64 70 7 0.0 64023392.0 1.0X -Nullable StructTypes 65 69 4 0.0 65179698.0 1.0X +Add 2 columns and drop 2 columns at 3 different depths of nesting: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------------------- +To non-nullable StructTypes using performant method 10 11 2 0.0 Infinity 1.0X +To nullable StructTypes using performant method 9 10 1 0.0 Infinity 1.0X +To non-nullable StructTypes using non-performant method 2457 2464 10 0.0 Infinity 0.0X +To nullable StructTypes using non-performant method 42641 43804 1644 0.0 Infinity 0.0X ================================================================================================ -Add 0 columns and 
drop 5 columns at 20 different depths of nesting +Add 50 columns and drop 50 columns at 100 different depths of nesting ================================================================================================ OpenJDK 64-Bit Server VM 1.8.0_212-b03 on Mac OS X 10.14.6 Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz -Add 0 columns and drop 5 columns at 20 different depths of nesting: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative --------------------------------------------------------------------------------------------------------------------------------------------------- -Non-Nullable StructTypes 42 47 6 0.0 42129670.0 1.0X -Nullable StructTypes 44 48 5 0.0 44021916.0 1.0X - - -================================================================================================ -Add 5 columns and drop 5 columns at 20 different depths of nesting -================================================================================================ - -OpenJDK 64-Bit Server VM 1.8.0_212-b03 on Mac OS X 10.14.6 -Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz -Add 5 columns and drop 5 columns at 20 different depths of nesting: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative --------------------------------------------------------------------------------------------------------------------------------------------------- -Non-Nullable StructTypes 74 77 2 0.0 73518396.0 1.0X -Nullable StructTypes 77 79 2 0.0 76520609.0 1.0X +Add 50 columns and drop 50 columns at 100 different depths of nesting: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +----------------------------------------------------------------------------------------------------------------------------------------------------- +To non-nullable StructTypes using performant method 4595 4927 470 0.0 Infinity 1.0X +To nullable StructTypes using performant method 5185 5516 468 0.0 Infinity 0.9X diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsBenchmark.scala index 94c172482b69..31c77550ce90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsBenchmark.scala @@ -66,76 +66,102 @@ object UpdateFieldsBenchmark extends SqlBasedBenchmark { } } - private def nestedRow(colNums: Seq[Int], maxDepth: Int, currDepth: Int = 1): Row = { - if (currDepth == maxDepth) { - Row.fromSeq(colNums) - } else { - val values = colNums.foldLeft(Seq.empty[Any]) { - case (values, colNum) if colNum == 0 => - values :+ nestedRow(colNums, maxDepth, currDepth + 1) - case (values, colNum) => - values :+ colNum - } - Row.fromSeq(values) - } - } - /** - * Utility function for generating a DataFrame with nested columns. + * Utility function for generating an empty DataFrame with nested columns. * * @param maxDepth: The depth to which to create nested columns. - * @param numColsAtEachDepth: The number of columns to create at each depth. The value of each - * column will be the same as its index (IntegerType) at that depth - * unless the index = 0, in which case it may be a StructType which - * represents the next depth. - * @param nullable: This value is used to set the nullability of StructType columns. + * @param numColsAtEachDepth: The number of columns to create at each depth. + * @param nullable: This value is used to set the nullability of any StructType columns. 
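+   *
+   * For example (mirroring the emptyNestedDf tests later in this file),
+   * emptyNestedDf(maxDepth = 2, numColsAtEachDepth = 2, nullable = false) returns an empty
+   * DataFrame with the single column:
+   * nested0Col0: struct<nested1Col0: struct<nested2Col0: int, nested2Col1: int>, nested1Col1: int>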
*/ - def nestedDf(maxDepth: Int, numColsAtEachDepth: Int, nullable: Boolean): DataFrame = { + def emptyNestedDf(maxDepth: Int, numColsAtEachDepth: Int, nullable: Boolean): DataFrame = { require(maxDepth > 0) require(numColsAtEachDepth > 0) - val colNums = 0 until numColsAtEachDepth - val nestedColumn = nestedRow(colNums, maxDepth) - val nestedColumnDataType = nestedStructType(colNums, nullable, maxDepth) - + val nestedColumnDataType = nestedStructType(0 until numColsAtEachDepth, nullable, maxDepth) spark.createDataFrame( - spark.sparkContext.parallelize(Row(nestedColumn) :: Nil), + spark.sparkContext.emptyRDD[Row], StructType(Seq(StructField(nestedColName(0, 0), nestedColumnDataType, nullable)))) } - // simulates how a user would add/drop nested fields in a performant manner - def modifyNestedColumns( - column: Column, - numsToAdd: Seq[Int], - numsToDrop: Seq[Int], - maxDepth: Int, - currDepth: Int = 1): Column = { + private trait ModifyNestedColumns { + val name: String + def apply(column: Column, numsToAdd: Seq[Int], numsToDrop: Seq[Int], maxDepth: Int): Column + } - // drop columns at the current depth - val dropped = if (numsToDrop.nonEmpty) { - column.dropFields(numsToDrop.map(num => nestedColName(currDepth, num)): _*) - } else column + private object Performant extends ModifyNestedColumns { + override val name: String = "performant" + + override def apply( + column: Column, + numsToAdd: Seq[Int], + numsToDrop: Seq[Int], + maxDepth: Int): Column = helper(column, numsToAdd, numsToDrop, maxDepth, 1) + + private def helper( + column: Column, + numsToAdd: Seq[Int], + numsToDrop: Seq[Int], + maxDepth: Int, + currDepth: Int): Column = { + + // drop columns at the current depth + val dropped = if (numsToDrop.nonEmpty) { + column.dropFields(numsToDrop.map(num => nestedColName(currDepth, num)): _*) + } else column + + // add columns at the current depth + val added = numsToAdd.foldLeft(dropped) { + (res, num) => res.withField(nestedColName(currDepth, num), lit(num)) + } - // add columns at the current depth - val added = numsToAdd.foldLeft(dropped) { - (res, num) => res.withField(nestedColName(currDepth, num), lit(num)) + if (currDepth == maxDepth) { + added + } else { + // add/drop columns at the next depth + val newValue = helper( + column = col((0 to currDepth).map(d => nestedColName(d, 0)).mkString(".")), + numsToAdd = numsToAdd, + numsToDrop = numsToDrop, + currDepth = currDepth + 1, + maxDepth = maxDepth) + added.withField(nestedColName(currDepth, 0), newValue) + } } + } + + private object NonPerformant extends ModifyNestedColumns { + override val name: String = "non-performant" + + override def apply( + column: Column, + numsToAdd: Seq[Int], + numsToDrop: Seq[Int], + maxDepth: Int): Column = { + + val dropped = if (numsToDrop.nonEmpty) { + val colsToDrop = (1 to maxDepth).flatMap { depth => + numsToDrop.map(num => s"${prefix(depth)}${nestedColName(depth, num)}") + } + column.dropFields(colsToDrop: _*) + } else column + + val added = { + val colsToAdd = (1 to maxDepth).flatMap { depth => + numsToAdd.map(num => (s"${prefix(depth)}${nestedColName(depth, num)}", lit(num))) + } + colsToAdd.foldLeft(dropped)((col, add) => col.withField(add._1, add._2)) + } - if (currDepth == maxDepth) { added - } else { - // add/drop columns at the next depth - val newValue = modifyNestedColumns( - column = col((0 to currDepth).map(d => nestedColName(d, 0)).mkString(".")), - numsToAdd = numsToAdd, - numsToDrop = numsToDrop, - currDepth = currDepth + 1, - maxDepth = maxDepth) - 
added.withField(nestedColName(currDepth, 0), newValue)
     }
+
+    private def prefix(depth: Int): String =
+      if (depth == 1) ""
+      else (1 until depth).map(d => nestedColName(d, 0)).mkString("", ".", ".")
   }
 
-  def updateFieldsBenchmark(
+  private def updateFieldsBenchmark(
+      methods: Seq[ModifyNestedColumns],
       maxDepth: Int,
       initialNumberOfColumns: Int,
       numsToAdd: Seq[Int] = Seq.empty,
@@ -147,28 +173,29 @@ object UpdateFieldsBenchmark extends SqlBasedBenchmark {
     runBenchmark(name) {
       val benchmark = new Benchmark(
         name = name,
-        // Because the point of this benchmark is only to ensure Spark is able to analyze and
-        // optimize long UpdateFields chains quickly, this benchmark operates only over 1 row of
-        // data.
-        valuesPerIteration = 1,
+        // The purpose of this benchmark is only to ensure Spark can analyze and optimize long
+        // UpdateFields chains quickly, so it runs over 0 rows of data.
+        valuesPerIteration = 0,
         output = output)
 
+      val nonNullableStructsDf = emptyNestedDf(maxDepth, initialNumberOfColumns, nullable = false)
+      val nullableStructsDf = emptyNestedDf(maxDepth, initialNumberOfColumns, nullable = true)
 
-      val modifiedColumn = modifyNestedColumns(
-        col(nestedColName(0, 0)),
-        numsToAdd,
-        numsToDrop,
-        maxDepth
-      ).as(nestedColName(0, 0))
+      methods.foreach { method =>
+        val modifiedColumn = method(
+          column = col(nestedColName(0, 0)),
+          numsToAdd = numsToAdd,
+          numsToDrop = numsToDrop,
+          maxDepth = maxDepth
+        ).as(nestedColName(0, 0))
 
-      val nonNullableInputDf = nestedDf(maxDepth, initialNumberOfColumns, nullable = false)
-      val nullableInputDf = nestedDf(maxDepth, initialNumberOfColumns, nullable = true)
+        benchmark.addCase(s"To non-nullable StructTypes using ${method.name} method") { _ =>
+          nonNullableStructsDf.select(modifiedColumn).queryExecution.optimizedPlan
+        }
 
-      benchmark.addCase("Non-Nullable StructTypes") { _ =>
-        nonNullableInputDf.select(modifiedColumn).queryExecution.optimizedPlan
-      }
+        benchmark.addCase(s"To nullable StructTypes using ${method.name} method") { _ =>
+          nullableStructsDf.select(modifiedColumn).queryExecution.optimizedPlan
+        }
       }
 
       benchmark.run()
@@ -176,23 +203,24 @@
   }
 
   override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
-    val maxDepth = 20
-
+    // This benchmark compares the performant and non-performant methods of writing the same query.
+    // We use small values for maxDepth, numsToAdd, and numsToDrop because the NonPerformant method
+    // scales extremely poorly with the number of nested columns being added/dropped.
     updateFieldsBenchmark(
-      maxDepth = maxDepth,
+      methods = Seq(Performant, NonPerformant),
+      maxDepth = 3,
       initialNumberOfColumns = 5,
-      numsToAdd = 5 to 9)
+      numsToAdd = 5 to 6,
+      numsToDrop = 3 to 4)
 
+    // This benchmark shows that the performant method of writing a query scales nicely even when
+    // adding and dropping a large number of nested columns.
updateFieldsBenchmark( - maxDepth = maxDepth, - initialNumberOfColumns = 10, - numsToDrop = 5 to 9) - - updateFieldsBenchmark( - maxDepth = maxDepth, - initialNumberOfColumns = 10, - numsToAdd = 10 to 14, - numsToDrop = 5 to 9) + methods = Seq(Performant), + maxDepth = 100, + initialNumberOfColumns = 51, + numsToAdd = 51 to 100, + numsToDrop = 1 to 50) } } @@ -201,23 +229,23 @@ class UpdateFieldsBenchmark extends QueryTest with SharedSparkSession { test("nestedDf should generate nested DataFrames") { checkAnswer( - nestedDf(1, 1, nullable = false), - Row(Row(0)) :: Nil, + emptyNestedDf(1, 1, nullable = false), + Seq.empty[Row], StructType(Seq(StructField("nested0Col0", StructType(Seq( StructField("nested1Col0", IntegerType, nullable = false))), nullable = false)))) checkAnswer( - nestedDf(1, 2, nullable = false), - Row(Row(0, 1)) :: Nil, + emptyNestedDf(1, 2, nullable = false), + Seq.empty[Row], StructType(Seq(StructField("nested0Col0", StructType(Seq( StructField("nested1Col0", IntegerType, nullable = false), StructField("nested1Col1", IntegerType, nullable = false))), nullable = false)))) checkAnswer( - nestedDf(2, 1, nullable = false), - Row(Row(Row(0))) :: Nil, + emptyNestedDf(2, 1, nullable = false), + Seq.empty[Row], StructType(Seq(StructField("nested0Col0", StructType(Seq( StructField("nested1Col0", StructType(Seq( StructField("nested2Col0", IntegerType, nullable = false))), @@ -225,8 +253,8 @@ class UpdateFieldsBenchmark extends QueryTest with SharedSparkSession { nullable = false)))) checkAnswer( - nestedDf(2, 2, nullable = false), - Row(Row(Row(0, 1), 1)) :: Nil, + emptyNestedDf(2, 2, nullable = false), + Seq.empty[Row], StructType(Seq(StructField("nested0Col0", StructType(Seq( StructField("nested1Col0", StructType(Seq( StructField("nested2Col0", IntegerType, nullable = false), @@ -236,8 +264,8 @@ class UpdateFieldsBenchmark extends QueryTest with SharedSparkSession { nullable = false)))) checkAnswer( - nestedDf(2, 2, nullable = true), - Row(Row(Row(0, 1), 1)) :: Nil, + emptyNestedDf(2, 2, nullable = true), + Seq.empty[Row], StructType(Seq(StructField("nested0Col0", StructType(Seq( StructField("nested1Col0", StructType(Seq( StructField("nested2Col0", IntegerType, nullable = false), @@ -247,64 +275,36 @@ class UpdateFieldsBenchmark extends QueryTest with SharedSparkSession { nullable = true)))) } - private val maxDepth = 3 - - test("modifyNestedColumns should add 5 columns at each depth of nesting") { - // dataframe with nested*Col0 to nested*Col4 at each depth - val inputDf = nestedDf(maxDepth, 5, nullable = false) - - // add nested*Col5 through nested*Col9 at each depth - val resultDf = inputDf.select(modifyNestedColumns( - column = col(nestedColName(0, 0)), - numsToAdd = 5 to 9, - numsToDrop = Seq.empty, - maxDepth = maxDepth - ).as("nested0Col0")) - - // dataframe with nested*Col0 to nested*Col9 at each depth - val expectedDf = nestedDf(maxDepth, 10, nullable = false) - checkAnswer(resultDf, expectedDf.collect(), expectedDf.schema) - } - - test("modifyNestedColumns should drop 5 columns at each depth of nesting") { - // dataframe with nested*Col0 to nested*Col9 at each depth - val inputDf = nestedDf(maxDepth, 10, nullable = false) - - // drop nested*Col5 to nested*Col9 at each of 20 depths - val resultDf = inputDf.select(modifyNestedColumns( - column = col(nestedColName(0, 0)), - numsToAdd = Seq.empty, - numsToDrop = 5 to 9, - maxDepth = maxDepth - ).as("nested0Col0")) - - // dataframe with nested*Col0 to nested*Col4 at each depth - val expectedDf = nestedDf(maxDepth, 5, 
nullable = false) - checkAnswer(resultDf, expectedDf.collect(), expectedDf.schema) - } - - test("modifyNestedColumns should add and drop 5 columns at each depth of nesting") { - // dataframe with nested*Col0 to nested*Col9 at each depth - val inputDf = nestedDf(maxDepth, 10, nullable = false) - - // drop nested*Col5 to nested*Col9 at each depth - val resultDf = inputDf.select(modifyNestedColumns( - column = col(nestedColName(0, 0)), - numsToAdd = 10 to 14, - numsToDrop = 5 to 9, - maxDepth = maxDepth - ).as("nested0Col0")) - - // dataframe with nested*Col0 to nested*Col4 and nested*Col10 to nested*Col14 at each depth - val expectedDf = { - val numCols = (0 to 4) ++ (10 to 14) - val nestedColumn = nestedRow(numCols, maxDepth) - val nestedColumnDataType = nestedStructType(numCols, nullable = false, maxDepth) - - spark.createDataFrame( - sparkContext.parallelize(Row(nestedColumn) :: Nil), - StructType(Seq(StructField(nestedColName(0, 0), nestedColumnDataType, nullable = false)))) + Seq(Performant, NonPerformant).foreach { method => + Seq(false, true).foreach { nullable => + test(s"should add and drop 1 column at each depth of nesting using ${method.name} method, " + + s"nullable = $nullable") { + val maxDepth = 3 + + // dataframe with nested*Col0 to nested*Col2 at each depth + val inputDf = emptyNestedDf(maxDepth, 3, nullable) + + // add nested*Col3 and drop nested*Col2 + val modifiedColumn = method( + column = col(nestedColName(0, 0)), + numsToAdd = Seq(3), + numsToDrop = Seq(2), + maxDepth = maxDepth + ).as(nestedColName(0, 0)) + val resultDf = inputDf.select(modifiedColumn) + + // dataframe with nested*Col0, nested*Col1, nested*Col3 at each depth + val expectedDf = { + val colNums = Seq(0, 1, 3) + val nestedColumnDataType = nestedStructType(colNums, nullable, maxDepth) + + spark.createDataFrame( + spark.sparkContext.emptyRDD[Row], + StructType(Seq(StructField(nestedColName(0, 0), nestedColumnDataType, nullable)))) + } + + checkAnswer(resultDf, expectedDf.collect(), expectedDf.schema) + } } - checkAnswer(resultDf, expectedDf.collect(), expectedDf.schema) } } From 4fe48b4287c81e73276165453477811211e341d9 Mon Sep 17 00:00:00 2001 From: "fqaiser94@gmail.com" Date: Sat, 26 Sep 2020 11:42:13 -0400 Subject: [PATCH 15/18] add cases we're not able to optimize well currently to specs --- .../optimizer/complexTypesSuite.scala | 103 +++++++++++++++--- 1 file changed, 88 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 7ba24b13cda4..d9cefdaf3fe7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -677,25 +677,25 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { If(IsNull('struct1), Literal(null, IntegerType), Literal(3)) as "struct3B")) } - test("simplify add multiple nested fields to struct") { - // this scenario is possible if users add multiple nested columns via the Column.withField API - // ideally, users should not be doing this. 
- val nullableStructLevel2 = LocalRelation( + test("simplify add multiple nested fields to non-nullable struct") { + // this scenario is possible if users add multiple nested columns to a non-nullable struct + // using the Column.withField API in a non-performant way + val structLevel2 = LocalRelation( 'a1.struct( - 'a2.struct('a3.int)).withNullability(false)) + 'a2.struct('a3.int.notNull)).notNull) val query = { val addB3toA1A2 = UpdateFields('a1, Seq(WithField("a2", UpdateFields(GetStructField('a1, 0), Seq(WithField("b3", Literal(2))))))) - nullableStructLevel2.select( + structLevel2.select( UpdateFields( addB3toA1A2, Seq(WithField("a2", UpdateFields( GetStructField(addB3toA1A2, 0), Seq(WithField("c3", Literal(3))))))).as("a1")) } - val expected = nullableStructLevel2.select( + val expected = structLevel2.select( UpdateFields('a1, Seq( // scalastyle:off line.size.limit WithField("a2", UpdateFields(GetStructField('a1, 0), WithField("b3", 2) :: Nil)), @@ -706,26 +706,62 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { checkRule(query, expected) } - test("simplify drop multiple nested fields in struct") { - // this scenario is possible if users drop multiple nested columns via the Column.dropFields API - // ideally, users should not be doing this. - val df = LocalRelation( + test("simplify add multiple nested fields to nullable struct") { + // this scenario is possible if users add multiple nested columns to a nullable struct + // using the Column.withField API in a non-performant way + val structLevel2 = LocalRelation( 'a1.struct( - 'a2.struct('a3.int, 'b3.int, 'c3.int).withNullability(false) - ).withNullability(false)) + 'a2.struct('a3.int.notNull))) + + val query = { + val addB3toA1A2 = UpdateFields('a1, Seq(WithField("a2", + UpdateFields(GetStructField('a1, 0), Seq(WithField("b3", Literal(2))))))) + + structLevel2.select( + UpdateFields( + addB3toA1A2, + Seq(WithField("a2", UpdateFields( + GetStructField(addB3toA1A2, 0), Seq(WithField("c3", Literal(3))))))).as("a1")) + } + + val expected = { + val repeatedExpr = UpdateFields(GetStructField('a1, 0), WithField("b3", Literal(2)) :: Nil) + val repeatedExprDataType = StructType(Seq( + StructField("a3", IntegerType, nullable = false), + StructField("b3", IntegerType, nullable = false))) + + structLevel2.select( + UpdateFields('a1, Seq( + WithField("a2", repeatedExpr), + WithField("a2", UpdateFields( + If(IsNull('a1), Literal(null, repeatedExprDataType), repeatedExpr), + WithField("c3", Literal(3)) :: Nil)) + )).as("a1")) + } + + checkRule(query, expected) + } + + test("simplify drop multiple nested fields in non-nullable struct") { + // this scenario is possible if users drop multiple nested columns in a non-nullable struct + // using the Column.dropFields API in a non-performant way + val structLevel2 = LocalRelation( + 'a1.struct( + 'a2.struct('a3.int.notNull, 'b3.int.notNull, 'c3.int.notNull).notNull + ).notNull) val query = { val dropA1A2B = UpdateFields('a1, Seq(WithField("a2", UpdateFields( GetStructField('a1, 0), Seq(DropField("b3")))))) - df.select( + structLevel2.select( UpdateFields( dropA1A2B, Seq(WithField("a2", UpdateFields( GetStructField(dropA1A2B, 0), Seq(DropField("c3")))))).as("a1")) } - val expected = df.select( + val expected = structLevel2.select( UpdateFields('a1, Seq( WithField("a2", UpdateFields(GetStructField('a1, 0), Seq(DropField("b3")))), WithField("a2", UpdateFields(GetStructField('a1, 0), Seq(DropField("b3"), DropField("c3")))) @@ -733,4 +769,41 @@ class ComplexTypesSuite extends PlanTest 
with ExpressionEvalHelper { checkRule(query, expected) } + + test("simplify drop multiple nested fields in nullable struct") { + // this scenario is possible if users drop multiple nested columns in a nullable struct + // using the Column.dropFields API in a non-performant way + val structLevel2 = LocalRelation( + 'a1.struct( + 'a2.struct('a3.int.notNull, 'b3.int.notNull, 'c3.int.notNull) + )) + + val query = { + val dropA1A2B = UpdateFields('a1, Seq(WithField("a2", UpdateFields( + GetStructField('a1, 0), Seq(DropField("b3")))))) + + structLevel2.select( + UpdateFields( + dropA1A2B, + Seq(WithField("a2", UpdateFields( + GetStructField(dropA1A2B, 0), Seq(DropField("c3")))))).as("a1")) + } + + val expected = { + val repeatedExpr = UpdateFields(GetStructField('a1, 0), DropField("b3") :: Nil) + val repeatedExprDataType = StructType(Seq( + StructField("a3", IntegerType, nullable = false), + StructField("c3", IntegerType, nullable = false))) + + structLevel2.select( + UpdateFields('a1, Seq( + WithField("a2", repeatedExpr), + WithField("a2", UpdateFields( + If(IsNull('a1), Literal(null, repeatedExprDataType), repeatedExpr), + DropField("c3") :: Nil)) + )).as("a1")) + } + + checkRule(query, expected) + } } From 2e9902ebfeb9d6ef8d157530ddbd58cd7ec83987 Mon Sep 17 00:00:00 2001 From: "fqaiser94@gmail.com" Date: Sat, 26 Sep 2020 19:16:34 -0400 Subject: [PATCH 16/18] better test coverage for nullable path --- .../spark/sql/ColumnExpressionSuite.scala | 192 ++++++++++++++---- 1 file changed, 148 insertions(+), 44 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index ecffd21e8279..bf7a140c0c80 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -931,8 +931,8 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { sparkContext.parallelize(Row(Row(1, null, 3)) :: Nil), StructType(Seq(StructField("a", structType, nullable = false)))) - private lazy val nullStructLevel1: DataFrame = spark.createDataFrame( - sparkContext.parallelize(Row(null) :: Nil), + private lazy val nullableStructLevel1: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(null) :: Row(Row(1, null, 3)) :: Nil), StructType(Seq(StructField("a", structType, nullable = true)))) private lazy val structLevel2: DataFrame = spark.createDataFrame( @@ -942,12 +942,12 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { StructField("a", structType, nullable = false))), nullable = false)))) - private lazy val nullStructLevel2: DataFrame = spark.createDataFrame( - sparkContext.parallelize(Row(Row(null)) :: Nil), + private lazy val nullableStructLevel2: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(null) :: Row(Row(null)) :: Row(Row(Row(1, null, 3))) :: Nil), StructType(Seq( StructField("a", StructType(Seq( StructField("a", structType, nullable = true))), - nullable = false)))) + nullable = true)))) private lazy val structLevel3: DataFrame = spark.createDataFrame( sparkContext.parallelize(Row(Row(Row(Row(1, null, 3)))) :: Nil), @@ -1034,10 +1034,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) } - test("withField should add field to null struct") { + test("withField should add field to nullable struct") { checkAnswer( - nullStructLevel1.withColumn("a", $"a".withField("d", lit(4))), - 
Row(null) :: Nil, + nullableStructLevel1.withColumn("a", $"a".withField("d", lit(4))), + Row(null) :: Row(Row(1, null, 3, 4)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( StructField("a", IntegerType, nullable = false), @@ -1047,10 +1047,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = true)))) } - test("withField should add field to nested null struct") { + test("withField should add field to nested nullable struct") { checkAnswer( - nullStructLevel2.withColumn("a", $"a".withField("a.d", lit(4))), - Row(Row(null)) :: Nil, + nullableStructLevel2.withColumn("a", $"a".withField("a.d", lit(4))), + Row(null) :: Row(Row(null)) :: Row(Row(Row(1, null, 3, 4))) :: Nil, StructType( Seq(StructField("a", StructType(Seq( StructField("a", StructType(Seq( @@ -1059,7 +1059,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { StructField("c", IntegerType, nullable = false), StructField("d", IntegerType, nullable = false))), nullable = true))), - nullable = false)))) + nullable = true)))) } test("withField should add null field to struct") { @@ -1089,6 +1089,20 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) } + test("withField should add multiple fields to nullable struct") { + checkAnswer( + nullableStructLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("e", lit(5))), + Row(null) :: Row(Row(1, null, 3, 4, 5)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = false), + StructField("e", IntegerType, nullable = false))), + nullable = true)))) + } + test("withField should add field to nested struct") { Seq( structLevel2.withColumn("a", 'a.withField("a.d", lit(4))), @@ -1109,6 +1123,48 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } } + test("withField should add multiple fields to nested struct") { + Seq( + col("a").withField("a", $"a.a".withField("d", lit(4)).withField("e", lit(5))), + col("a").withField("a.d", lit(4)).withField("a.e", lit(5)) + ).foreach { column => + checkAnswer( + structLevel2.select(column.as("a")), + Row(Row(Row(1, null, 3, 4, 5))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = false), + StructField("e", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + } + } + + test("withField should add multiple fields to nested nullable struct") { + Seq( + col("a").withField("a", $"a.a".withField("d", lit(4)).withField("e", lit(5))), + col("a").withField("a.d", lit(4)).withField("a.e", lit(5)) + ).foreach { column => + checkAnswer( + nullableStructLevel2.select(column.as("a")), + Row(null) :: Row(Row(null)) :: Row(Row(Row(1, null, 3, 4, 5))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = false), + StructField("e", IntegerType, nullable = false))), + nullable = true))), + nullable = true)))) + } + } + test("withField should add 
field to deeply nested struct") { checkAnswer( structLevel3.withColumn("a", 'a.withField("a.a.d", lit(4))), @@ -1138,10 +1194,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) } - test("withField should replace field in null struct") { + test("withField should replace field in nullable struct") { checkAnswer( - nullStructLevel1.withColumn("a", 'a.withField("b", lit("foo"))), - Row(null) :: Nil, + nullableStructLevel1.withColumn("a", 'a.withField("b", lit("foo"))), + Row(null) :: Row(Row(1, "foo", 3)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( StructField("a", IntegerType, nullable = false), @@ -1150,10 +1206,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = true)))) } - test("withField should replace field in nested null struct") { + test("withField should replace field in nested nullable struct") { checkAnswer( - nullStructLevel2.withColumn("a", $"a".withField("a.b", lit("foo"))), - Row(Row(null)) :: Nil, + nullableStructLevel2.withColumn("a", $"a".withField("a.b", lit("foo"))), + Row(null) :: Row(Row(null)) :: Row(Row(Row(1, "foo", 3))) :: Nil, StructType( Seq(StructField("a", StructType(Seq( StructField("a", StructType(Seq( @@ -1161,7 +1217,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { StructField("b", StringType, nullable = false), StructField("c", IntegerType, nullable = false))), nullable = true))), - nullable = false)))) + nullable = true)))) } test("withField should replace field with null value in struct") { @@ -1188,6 +1244,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) } + test("withField should replace multiple fields in nullable struct") { + checkAnswer( + nullableStructLevel1.withColumn("a", 'a.withField("a", lit(10)).withField("b", lit(20))), + Row(null) :: Row(Row(10, 20, 3)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = true)))) + } + test("withField should replace field in nested struct") { Seq( structLevel2.withColumn("a", $"a".withField("a.b", lit(2))), @@ -1207,6 +1275,44 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } } + test("withField should replace multiple fields in nested struct") { + Seq( + col("a").withField("a", $"a.a".withField("a", lit(10)).withField("b", lit(20))), + col("a").withField("a.a", lit(10)).withField("a.b", lit(20)) + ).foreach { column => + checkAnswer( + structLevel2.select(column.as("a")), + Row(Row(Row(10, 20, 3))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + } + } + + test("withField should replace multiple fields in nested nullable struct") { + Seq( + col("a").withField("a", $"a.a".withField("a", lit(10)).withField("b", lit(20))), + col("a").withField("a.a", lit(10)).withField("a.b", lit(20)) + ).foreach { column => + checkAnswer( + nullableStructLevel2.select(column.as("a")), + Row(null) :: Row(Row(null)) :: Row(Row(Row(10, 20, 3))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", 
IntegerType, nullable = false), + StructField("c", IntegerType, nullable = false))), + nullable = true))), + nullable = true)))) + } + } + test("withField should replace field in deeply nested struct") { checkAnswer( structLevel3.withColumn("a", $"a".withField("a.a.b", lit(2))), @@ -1484,6 +1590,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { test("SPARK-32641: extracting field from null struct column after withField should return " + "null if the original struct was null") { + val nullStructLevel1 = spark.createDataFrame( + sparkContext.parallelize(Row(null) :: Nil), + StructType(Seq(StructField("a", structType, nullable = true)))) + // extract newly added field checkAnswer( nullStructLevel1.withColumn("a", $"a".withField("d", lit(4)).getField("d")), @@ -1594,10 +1704,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) } - test("dropFields should drop field in null struct") { + test("dropFields should drop field in nullable struct") { checkAnswer( - nullStructLevel1.withColumn("a", $"a".dropFields("b")), - Row(null) :: Nil, + nullableStructLevel1.withColumn("a", $"a".dropFields("b")), + Row(null) :: Row(Row(1, 3)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( StructField("a", IntegerType, nullable = false), @@ -1669,29 +1779,29 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) } - test("dropFields should drop field in nested null struct") { + test("dropFields should drop field in nested nullable struct") { checkAnswer( - nullStructLevel2.withColumn("a", $"a".dropFields("a.b")), - Row(Row(null)) :: Nil, + nullableStructLevel2.withColumn("a", $"a".dropFields("a.b")), + Row(null) :: Row(Row(null)) :: Row(Row(Row(1, 3))) :: Nil, StructType( Seq(StructField("a", StructType(Seq( StructField("a", StructType(Seq( StructField("a", IntegerType, nullable = false), StructField("c", IntegerType, nullable = false))), nullable = true))), - nullable = false)))) + nullable = true)))) } - test("dropFields should drop multiple fields in nested null struct") { + test("dropFields should drop multiple fields in nested nullable struct") { checkAnswer( - nullStructLevel2.withColumn("a", $"a".dropFields("a.b", "a.c")), - Row(Row(null)) :: Nil, + nullableStructLevel2.withColumn("a", $"a".dropFields("a.b", "a.c")), + Row(null) :: Row(Row(null)) :: Row(Row(Row(1))) :: Nil, StructType( Seq(StructField("a", StructType(Seq( StructField("a", StructType(Seq( StructField("a", IntegerType, nullable = false))), nullable = true))), - nullable = false)))) + nullable = true)))) } test("dropFields should drop field in deeply nested struct") { @@ -2025,31 +2135,25 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("should move field up one level of nesting") { - val nullableStructLevel2: DataFrame = spark.createDataFrame( - sparkContext.parallelize(Row(Row(null)) :: Row(Row(Row(1, 2, 3))) :: Nil), - StructType(Seq( - StructField("a", StructType(Seq( - StructField("a", structType, nullable = true))), - nullable = true)))) - // move a field up one level checkAnswer( nullableStructLevel2.select( - col("a").withField("b", col("a.a.b")).dropFields("a.b").as("res")), - Row(Row(null, null)) :: Row(Row(Row(1, 3), 2)) :: Nil, + col("a").withField("c", col("a.a.c")).dropFields("a.c").as("res")), + Row(null) :: Row(Row(null, null)) :: Row(Row(Row(1, null), 3)) :: Nil, StructType(Seq( StructField("res", StructType(Seq( StructField("a", StructType(Seq( StructField("a", 
IntegerType, nullable = false),
-            StructField("c", IntegerType, nullable = false))),
+            StructField("b", IntegerType, nullable = true))),
           nullable = true),
-          StructField("b", IntegerType, nullable = true))),
+          StructField("c", IntegerType, nullable = true))),
         nullable = true))))
 
     // move a field up one level and then extract it
     checkAnswer(
-      nullableStructLevel2.select(col("a").withField("b", col("a.a.b")).getField("b").as("res")),
-      Row(null) :: Row(2) :: Nil,
+      nullableStructLevel2.select(
+        col("a").withField("c", col("a.a.c")).dropFields("a.c").getField("c").as("res")),
+      Row(null) :: Row(null) :: Row(3) :: Nil,
       StructType(Seq(StructField("res", IntegerType, nullable = true))))
   }
 
@@ -2089,7 +2193,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
   }
 
   test("should still be able to refer to dropped column within the same select statement") {
-    // we can still access the nested column even after dropping it within the select statement
+    // we can still access the nested column even after dropping it within the same select statement
     checkAnswer(
       structLevel1.select($"a".dropFields("c").withField("z", $"a.c").as("a")),
       Row(Row(1, null, 3)) :: Nil,

From cca6f37d03df6c41561dc2b4cc127ebf1305a8ed Mon Sep 17 00:00:00 2001
From: "fqaiser94@gmail.com"
Date: Sat, 26 Sep 2020 16:48:57 -0400
Subject: [PATCH 17/18] fix indentation

---
 .../sql/catalyst/optimizer/ComplexTypes.scala | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
index 0e63cecdcecb..860219e55b05 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
@@ -40,14 +40,14 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
       // Remove redundant field extraction.
       case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
         createNamedStruct.valExprs(ordinal)
-    case GetStructField(u: UpdateFields, ordinal, _) if !u.structExpr.isInstanceOf[UpdateFields] =>
-    val structExpr = u.structExpr
-    u.newExprs(ordinal) match {
-      // if the struct itself is null, then any value extracted from it (expr) will be null
-      // so we don't need to wrap expr in If(IsNull(struct), Literal(null, expr.dataType), expr)
-      case expr: GetStructField if expr.child.semanticEquals(structExpr) => expr
-      case expr => If(IsNull(structExpr), Literal(null, expr.dataType), expr)
-    }
+      case GetStructField(u: UpdateFields, ordinal, _) if !u.structExpr.isInstanceOf[UpdateFields] =>
+        val structExpr = u.structExpr
+        u.newExprs(ordinal) match {
+          // if the struct itself is null, then any value extracted from it (expr) will be null
+          // so we don't need to wrap expr in If(IsNull(struct), Literal(null, expr.dataType), expr)
+          case expr: GetStructField if expr.child.semanticEquals(structExpr) => expr
+          case expr => If(IsNull(structExpr), Literal(null, expr.dataType), expr)
+        }
       // Remove redundant array indexing.
      case GetArrayStructFields(CreateArray(elems, useStringTypeWhenEmpty), field, ordinal, _, _) =>
        // Instead of selecting the field on the entire array, select it from each member

From 7e51f35580db72fda11153d76caf232b83e617cd Mon Sep 17 00:00:00 2001
From: "fqaiser94@gmail.com"
Date: Tue, 29 Sep 2020 16:45:43 -0400
Subject: [PATCH 18/18] move UpdateFieldsBenchmark tests to
 ColumnExpressionSuite

---
 .../spark/sql/ColumnExpressionSuite.scala     | 90 +++++++++++++++++
 .../org/apache/spark/sql/QueryTest.scala      |  9 --
 .../spark/sql/UpdateFieldsBenchmark.scala     | 96 +------------------
 3 files changed, 95 insertions(+), 100 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index bf7a140c0c80..b11f4c603dfd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.io.{LongWritable, Text}
 import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat}
 import org.scalatest.matchers.should.Matchers._
 
+import org.apache.spark.sql.UpdateFieldsBenchmark._
 import org.apache.spark.sql.catalyst.expressions.{InSet, Literal, NamedExpression}
 import org.apache.spark.sql.execution.ProjectExec
 import org.apache.spark.sql.functions._
@@ -922,6 +923,14 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
     assert(inSet.sql === "('a' IN ('a', 'b'))")
   }
 
+  def checkAnswer(
+      df: => DataFrame,
+      expectedAnswer: Seq[Row],
+      expectedSchema: StructType): Unit = {
+    checkAnswer(df, expectedAnswer)
+    assert(df.schema == expectedSchema)
+  }
+
   private lazy val structType = StructType(Seq(
     StructField("a", IntegerType, nullable = false),
     StructField("b", IntegerType, nullable = true),
@@ -2212,4 +2221,85 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
         .select($"a".withField("z", $"a.c")).as("a")
     }.getMessage should include("No such struct field c in a, b;")
   }
+
+  test("emptyNestedDf should generate empty nested DataFrames") {
+    checkAnswer(
+      emptyNestedDf(1, 1, nullable = false),
+      Seq.empty[Row],
+      StructType(Seq(StructField("nested0Col0", StructType(Seq(
+        StructField("nested1Col0", IntegerType, nullable = false))),
+        nullable = false))))
+
+    checkAnswer(
+      emptyNestedDf(1, 2, nullable = false),
+      Seq.empty[Row],
+      StructType(Seq(StructField("nested0Col0", StructType(Seq(
+        StructField("nested1Col0", IntegerType, nullable = false),
+        StructField("nested1Col1", IntegerType, nullable = false))),
+        nullable = false))))
+
+    checkAnswer(
+      emptyNestedDf(2, 1, nullable = false),
+      Seq.empty[Row],
+      StructType(Seq(StructField("nested0Col0", StructType(Seq(
+        StructField("nested1Col0", StructType(Seq(
StructField("nested2Col0", IntegerType, nullable = false), + StructField("nested2Col1", IntegerType, nullable = false))), + nullable = true), + StructField("nested1Col1", IntegerType, nullable = false))), + nullable = true)))) + } + + Seq(Performant, NonPerformant).foreach { method => + Seq(false, true).foreach { nullable => + test(s"should add and drop 1 column at each depth of nesting using ${method.name} method, " + + s"nullable = $nullable") { + val maxDepth = 3 + + // dataframe with nested*Col0 to nested*Col2 at each depth + val inputDf = emptyNestedDf(maxDepth, 3, nullable) + + // add nested*Col3 and drop nested*Col2 + val modifiedColumn = method( + column = col(nestedColName(0, 0)), + numsToAdd = Seq(3), + numsToDrop = Seq(2), + maxDepth = maxDepth + ).as(nestedColName(0, 0)) + val resultDf = inputDf.select(modifiedColumn) + + // dataframe with nested*Col0, nested*Col1, nested*Col3 at each depth + val expectedDf = { + val colNums = Seq(0, 1, 3) + val nestedColumnDataType = nestedStructType(colNums, nullable, maxDepth) + + spark.createDataFrame( + spark.sparkContext.emptyRDD[Row], + StructType(Seq(StructField(nestedColName(0, 0), nestedColumnDataType, nullable)))) + } + + checkAnswer(resultDf, expectedDf.collect(), expectedDf.schema) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 8ac6af2ac25d..8469216901b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.columnar.InMemoryRelation -import org.apache.spark.sql.types.StructType import org.apache.spark.storage.StorageLevel @@ -160,14 +159,6 @@ abstract class QueryTest extends PlanTest { checkAnswer(df, expectedAnswer.collect()) } - protected def checkAnswer( - df: => DataFrame, - expectedAnswer: Seq[Row], - expectedSchema: StructType): Unit = { - checkAnswer(df, expectedAnswer) - assert(df.schema == expectedSchema) - } - /** * Runs the plan and makes sure the answer is within absTol of the expected result. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsBenchmark.scala index 31c77550ce90..28af552fe586 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsBenchmark.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql import org.apache.spark.benchmark.Benchmark import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark import org.apache.spark.sql.functions.{col, lit} -import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.{IntegerType, StructField, StructType} /** @@ -39,9 +38,9 @@ import org.apache.spark.sql.types.{IntegerType, StructField, StructType} */ object UpdateFieldsBenchmark extends SqlBasedBenchmark { - private def nestedColName(d: Int, colNum: Int): String = s"nested${d}Col$colNum" + def nestedColName(d: Int, colNum: Int): String = s"nested${d}Col$colNum" - private def nestedStructType( + def nestedStructType( colNums: Seq[Int], nullable: Boolean, maxDepth: Int, @@ -83,12 +82,12 @@ object UpdateFieldsBenchmark extends SqlBasedBenchmark { StructType(Seq(StructField(nestedColName(0, 0), nestedColumnDataType, nullable)))) } - private trait ModifyNestedColumns { + trait ModifyNestedColumns { val name: String def apply(column: Column, numsToAdd: Seq[Int], numsToDrop: Seq[Int], maxDepth: Int): Column } - private object Performant extends ModifyNestedColumns { + object Performant extends ModifyNestedColumns { override val name: String = "performant" override def apply( @@ -129,7 +128,7 @@ object UpdateFieldsBenchmark extends SqlBasedBenchmark { } } - private object NonPerformant extends ModifyNestedColumns { + object NonPerformant extends ModifyNestedColumns { override val name: String = "non-performant" override def apply( @@ -223,88 +222,3 @@ object UpdateFieldsBenchmark extends SqlBasedBenchmark { numsToDrop = 1 to 50) } } - -class UpdateFieldsBenchmark extends QueryTest with SharedSparkSession { - import UpdateFieldsBenchmark._ - - test("nestedDf should generate nested DataFrames") { - checkAnswer( - emptyNestedDf(1, 1, nullable = false), - Seq.empty[Row], - StructType(Seq(StructField("nested0Col0", StructType(Seq( - StructField("nested1Col0", IntegerType, nullable = false))), - nullable = false)))) - - checkAnswer( - emptyNestedDf(1, 2, nullable = false), - Seq.empty[Row], - StructType(Seq(StructField("nested0Col0", StructType(Seq( - StructField("nested1Col0", IntegerType, nullable = false), - StructField("nested1Col1", IntegerType, nullable = false))), - nullable = false)))) - - checkAnswer( - emptyNestedDf(2, 1, nullable = false), - Seq.empty[Row], - StructType(Seq(StructField("nested0Col0", StructType(Seq( - StructField("nested1Col0", StructType(Seq( - StructField("nested2Col0", IntegerType, nullable = false))), - nullable = false))), - nullable = false)))) - - checkAnswer( - emptyNestedDf(2, 2, nullable = false), - Seq.empty[Row], - StructType(Seq(StructField("nested0Col0", StructType(Seq( - StructField("nested1Col0", StructType(Seq( - StructField("nested2Col0", IntegerType, nullable = false), - StructField("nested2Col1", IntegerType, nullable = false))), - nullable = false), - StructField("nested1Col1", IntegerType, nullable = false))), - nullable = false)))) - - checkAnswer( - emptyNestedDf(2, 2, nullable = true), - Seq.empty[Row], - StructType(Seq(StructField("nested0Col0", StructType(Seq( - StructField("nested1Col0", StructType(Seq( - 
StructField("nested2Col0", IntegerType, nullable = false), - StructField("nested2Col1", IntegerType, nullable = false))), - nullable = true), - StructField("nested1Col1", IntegerType, nullable = false))), - nullable = true)))) - } - - Seq(Performant, NonPerformant).foreach { method => - Seq(false, true).foreach { nullable => - test(s"should add and drop 1 column at each depth of nesting using ${method.name} method, " + - s"nullable = $nullable") { - val maxDepth = 3 - - // dataframe with nested*Col0 to nested*Col2 at each depth - val inputDf = emptyNestedDf(maxDepth, 3, nullable) - - // add nested*Col3 and drop nested*Col2 - val modifiedColumn = method( - column = col(nestedColName(0, 0)), - numsToAdd = Seq(3), - numsToDrop = Seq(2), - maxDepth = maxDepth - ).as(nestedColName(0, 0)) - val resultDf = inputDf.select(modifiedColumn) - - // dataframe with nested*Col0, nested*Col1, nested*Col3 at each depth - val expectedDf = { - val colNums = Seq(0, 1, 3) - val nestedColumnDataType = nestedStructType(colNums, nullable, maxDepth) - - spark.createDataFrame( - spark.sparkContext.emptyRDD[Row], - StructType(Seq(StructField(nestedColName(0, 0), nestedColumnDataType, nullable)))) - } - - checkAnswer(resultDf, expectedDf.collect(), expectedDf.schema) - } - } - } -}