diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 563ce7133a3d..5058bdd5264c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion} +import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCheckResult, TypeCoercion} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.{FUNC_ALIAS, FunctionBuilder} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ @@ -541,57 +543,114 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E } /** - * Adds/replaces field in struct by name. + * Represents an operation to be applied to the fields of a struct. */ -case class WithFields( - structExpr: Expression, - names: Seq[String], - valExprs: Seq[Expression]) extends Unevaluable { +trait StructFieldsOperation { + + val resolver: Resolver = SQLConf.get.resolver - assert(names.length == valExprs.length) + /** + * Returns an updated list of StructFields and Expressions that will ultimately be used + * as the fields argument for [[StructType]] and as the children argument for + * [[CreateNamedStruct]] respectively inside of [[UpdateFields]]. + */ + def apply(values: Seq[(StructField, Expression)]): Seq[(StructField, Expression)] +} + +/** + * Add or replace a field by name. + * + * We extend [[Unevaluable]] here to ensure that [[UpdateFields]] can include it as part of its + * children, and thereby enable the analyzer to resolve and transform valExpr as necessary. + */ +case class WithField(name: String, valExpr: Expression) + extends Unevaluable with StructFieldsOperation { + + override def apply(values: Seq[(StructField, Expression)]): Seq[(StructField, Expression)] = { + val newFieldExpr = (StructField(name, valExpr.dataType, valExpr.nullable), valExpr) + val result = ArrayBuffer.empty[(StructField, Expression)] + var hasMatch = false + for (existingFieldExpr @ (existingField, _) <- values) { + if (resolver(existingField.name, name)) { + hasMatch = true + result += newFieldExpr + } else { + result += existingFieldExpr + } + } + if (!hasMatch) result += newFieldExpr + result + } + + override def children: Seq[Expression] = valExpr :: Nil + + override def dataType: DataType = throw new IllegalStateException( + "WithField.dataType should not be called.") + + override def nullable: Boolean = throw new IllegalStateException( + "WithField.nullable should not be called.") + + override def prettyName: String = "WithField" +} + +/** + * Drop a field by name. + */ +case class DropField(name: String) extends StructFieldsOperation { + override def apply(values: Seq[(StructField, Expression)]): Seq[(StructField, Expression)] = + values.filterNot { case (field, _) => resolver(field.name, name) } +} + +/** + * Updates fields in a struct. 
+ */ +case class UpdateFields(structExpr: Expression, fieldOps: Seq[StructFieldsOperation]) + extends Unevaluable { override def checkInputDataTypes(): TypeCheckResult = { - if (!structExpr.dataType.isInstanceOf[StructType]) { - TypeCheckResult.TypeCheckFailure( - "struct argument should be struct type, got: " + structExpr.dataType.catalogString) + val dataType = structExpr.dataType + if (!dataType.isInstanceOf[StructType]) { + TypeCheckResult.TypeCheckFailure("struct argument should be struct type, got: " + + dataType.catalogString) + } else if (newExprs.isEmpty) { + TypeCheckResult.TypeCheckFailure("cannot drop all fields in struct") } else { TypeCheckResult.TypeCheckSuccess } } - override def children: Seq[Expression] = structExpr +: valExprs + override def children: Seq[Expression] = structExpr +: fieldOps.collect { + case e: Expression => e + } - override def dataType: StructType = evalExpr.dataType.asInstanceOf[StructType] + override def dataType: StructType = StructType(newFields) override def nullable: Boolean = structExpr.nullable - override def prettyName: String = "with_fields" + override def prettyName: String = "update_fields" - lazy val evalExpr: Expression = { - val existingExprs = structExpr.dataType.asInstanceOf[StructType].fieldNames.zipWithIndex.map { - case (name, i) => (name, GetStructField(KnownNotNull(structExpr), i).asInstanceOf[Expression]) - } + private lazy val newFieldExprs: Seq[(StructField, Expression)] = { + val existingFieldExprs: Seq[(StructField, Expression)] = + structExpr.dataType.asInstanceOf[StructType].fields.zipWithIndex.map { + case (field, i) => (field, GetStructField(structExpr, i)) + } - val addOrReplaceExprs = names.zip(valExprs) - - val resolver = SQLConf.get.resolver - val newExprs = addOrReplaceExprs.foldLeft(existingExprs) { - case (resultExprs, newExpr @ (newExprName, _)) => - if (resultExprs.exists(x => resolver(x._1, newExprName))) { - resultExprs.map { - case (name, _) if resolver(name, newExprName) => newExpr - case x => x - } - } else { - resultExprs :+ newExpr - } - }.flatMap { case (name, expr) => Seq(Literal(name), expr) } + fieldOps.foldLeft(existingFieldExprs)((exprs, op) => op(exprs)) + } + + private lazy val newFields: Seq[StructField] = newFieldExprs.map(_._1) + + lazy val newExprs: Seq[Expression] = newFieldExprs.map(_._2) + + lazy val evalExpr: Expression = { + val createNamedStructExpr = CreateNamedStruct(newFieldExprs.flatMap { + case (field, expr) => Seq(Literal(field.name), expr) + }) - val expr = CreateNamedStruct(newExprs) if (structExpr.nullable) { - If(IsNull(structExpr), Literal(null, expr.dataType), expr) + If(IsNull(structExpr), Literal(null, dataType), createNamedStructExpr) } else { - expr + createNamedStructExpr } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala index 2aba4bae397c..860219e55b05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.StructType /** * Simplify redundant [[CreateNamedStruct]], [[CreateArray]] and [[CreateMap]] expressions. 
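[Editor's note] To make the semantics of the expression changes above concrete: `UpdateFields` folds its `fieldOps` left-to-right over the struct's existing fields, with `WithField` replacing a matching field in place (or appending) and `DropField` filtering it out. Below is a minimal, self-contained sketch of that fold in plain Scala; it is not part of the patch, models field values as `Int`s instead of Catalyst `Expression`s, and uses exact name matching in place of the configurable resolver.

```scala
// Toy model of UpdateFields.newFieldExprs:
//   fieldOps.foldLeft(existingFieldExprs)((exprs, op) => op(exprs))
object FieldOpsDemo {
  sealed trait FieldOp {
    def apply(fields: Seq[(String, Int)]): Seq[(String, Int)]
  }

  // Mirrors WithField: replace the matching field in place, otherwise append at the end.
  final case class With(name: String, value: Int) extends FieldOp {
    def apply(fields: Seq[(String, Int)]): Seq[(String, Int)] =
      if (fields.exists(_._1 == name)) {
        fields.map { case (n, v) => if (n == name) (n, value) else (n, v) }
      } else {
        fields :+ (name -> value)
      }
  }

  // Mirrors DropField: remove every field with a matching name.
  final case class Drop(name: String) extends FieldOp {
    def apply(fields: Seq[(String, Int)]): Seq[(String, Int)] =
      fields.filterNot(_._1 == name)
  }

  def main(args: Array[String]): Unit = {
    val existing = Seq("a" -> 1, "b" -> 2)
    val ops = Seq(With("b", 20), Drop("a"), With("c", 3))
    // Same shape as the foldLeft in UpdateFields.newFieldExprs.
    val result = ops.foldLeft(existing)((fields, op) => op(fields))
    println(result) // List((b,20), (c,3))
  }
}
```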
@@ -39,18 +40,13 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
       // Remove redundant field extraction.
       case GetStructField(createNamedStruct: CreateNamedStruct, ordinal, _) =>
         createNamedStruct.valExprs(ordinal)
-      case GetStructField(w @ WithFields(struct, names, valExprs), ordinal, maybeName) =>
-        val name = w.dataType(ordinal).name
-        val matches = names.zip(valExprs).filter(_._1 == name)
-        if (matches.nonEmpty) {
-          // return last matching element as that is the final value for the field being extracted.
-          // For example, if a user submits a query like this:
-          // `$"struct_col".withField("b", lit(1)).withField("b", lit(2)).getField("b")`
-          // we want to return `lit(2)` (and not `lit(1)`).
-          val expr = matches.last._2
-          If(IsNull(struct), Literal(null, expr.dataType), expr)
-        } else {
-          GetStructField(struct, ordinal, maybeName)
+      case GetStructField(u: UpdateFields, ordinal, _) if !u.structExpr.isInstanceOf[UpdateFields] =>
+        val structExpr = u.structExpr
+        u.newExprs(ordinal) match {
+          // if expr is a field of the original struct, it is already null whenever the struct
+          // itself is null, so there is no need to wrap it in If(IsNull(struct), null, expr)
+          case expr: GetStructField if expr.child.semanticEquals(structExpr) => expr
+          case expr => If(IsNull(structExpr), Literal(null, expr.dataType), expr)
         }
       // Remove redundant array indexing.
       case GetArrayStructFields(CreateArray(elems, useStringTypeWhenEmpty), field, ordinal, _, _) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 6033c01a60f4..ac7a14716fbf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -109,7 +109,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
         RemoveRedundantAliases,
         UnwrapCastInBinaryComparison,
         RemoveNoopOperators,
-        CombineWithFields,
+        CombineUpdateFields,
         SimplifyExtractValueOps,
         CombineConcats) ++
         extendedOperatorOptimizationRules
@@ -221,7 +221,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       RemoveNoopOperators) :+
     // This batch must be executed after the `RewriteSubquery` batch, which creates joins.
     Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
-    Batch("ReplaceWithFieldsExpression", Once, ReplaceWithFieldsExpression)
+    Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression)
 
   // remove any batches with no rules. this may happen when subclasses do not add optional rules.
   batches.filter(_.rules.nonEmpty)
@@ -255,7 +255,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
     RewriteCorrelatedScalarSubquery.ruleName ::
     RewritePredicateSubquery.ruleName ::
     NormalizeFloatingNumbers.ruleName ::
-    ReplaceWithFieldsExpression.ruleName :: Nil
+    ReplaceUpdateFieldsExpression.ruleName :: Nil
 
   /**
    * Optimize all the subqueries inside expression.
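[Editor's note] For context on how the rules renamed below compose at runtime, here is a hedged end-to-end sketch. It assumes a Spark 3.1+ build with this patch on the classpath; the rendered output and plan shape are expectations, not captured output. Chained `withField` calls produce nested `UpdateFields` expressions; `CombineUpdateFields` collapses them into one, and `ReplaceUpdateFieldsExpression` lowers the result to an evaluable `CreateNamedStruct`.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.lit

object UpdateFieldsPlanDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()
    import spark.implicits._

    val df = spark.sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
      .select($"struct_col".withField("c", lit(3)).withField("d", lit(4)).as("struct_col"))

    df.show(truncate = false) // expected row: {1, 2, 3, 4}
    // The optimized plan should show a single named_struct(...) rather than two nested
    // struct rewrites, thanks to CombineUpdateFields + ReplaceUpdateFieldsExpression.
    df.explain(true)
    spark.stop()
  }
}
```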
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala
similarity index 68%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala
index 05c90864e4bb..c7154210e0c6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/WithFields.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala
@@ -17,26 +17,26 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.expressions.WithFields
+import org.apache.spark.sql.catalyst.expressions.UpdateFields
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 
 /**
- * Combines all adjacent [[WithFields]] expression into a single [[WithFields]] expression.
+ * Combines all adjacent [[UpdateFields]] expressions into a single [[UpdateFields]] expression.
  */
-object CombineWithFields extends Rule[LogicalPlan] {
+object CombineUpdateFields extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-    case WithFields(WithFields(struct, names1, valExprs1), names2, valExprs2) =>
-      WithFields(struct, names1 ++ names2, valExprs1 ++ valExprs2)
+    case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) =>
+      UpdateFields(struct, fieldOps1 ++ fieldOps2)
   }
 }
 
 /**
- * Replaces [[WithFields]] expression with an evaluable expression.
+ * Replaces [[UpdateFields]] expressions with evaluable expressions.
  */
-object ReplaceWithFieldsExpression extends Rule[LogicalPlan] {
+object ReplaceUpdateFieldsExpression extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-    case w: WithFields => w.evalExpr
+    case u: UpdateFields => u.evalExpr
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineUpdateFieldsSuite.scala
similarity index 65%
rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineUpdateFieldsSuite.scala
index a3e0bbc57e63..ff9c60a2fa5b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineWithFieldsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineUpdateFieldsSuite.scala
@@ -19,56 +19,53 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, WithFields}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, UpdateFields, WithField}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 
-class CombineWithFieldsSuite extends PlanTest {
+class CombineUpdateFieldsSuite extends PlanTest {
 
   object Optimize extends RuleExecutor[LogicalPlan] {
-    val batches = Batch("CombineWithFields", FixedPoint(10), CombineWithFields) :: Nil
+    val batches = Batch("CombineUpdateFields", FixedPoint(10), CombineUpdateFields) :: Nil
  }
 
   private val testRelation =
LocalRelation('a.struct('a1.int)) - test("combines two WithFields") { + test("combines two adjacent UpdateFields Expressions") { val originalQuery = testRelation .select(Alias( - WithFields( - WithFields( + UpdateFields( + UpdateFields( 'a, - Seq("b1"), - Seq(Literal(4))), - Seq("c1"), - Seq(Literal(5))), "out")()) + WithField("b1", Literal(4)) :: Nil), + WithField("c1", Literal(5)) :: Nil), "out")()) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select(Alias(WithFields('a, Seq("b1", "c1"), Seq(Literal(4), Literal(5))), "out")()) + .select(Alias(UpdateFields('a, WithField("b1", Literal(4)) :: WithField("c1", Literal(5)) :: + Nil), "out")()) .analyze comparePlans(optimized, correctAnswer) } - test("combines three WithFields") { + test("combines three adjacent UpdateFields Expressions") { val originalQuery = testRelation .select(Alias( - WithFields( - WithFields( - WithFields( + UpdateFields( + UpdateFields( + UpdateFields( 'a, - Seq("b1"), - Seq(Literal(4))), - Seq("c1"), - Seq(Literal(5))), - Seq("d1"), - Seq(Literal(6))), "out")()) + WithField("b1", Literal(4)) :: Nil), + WithField("c1", Literal(5)) :: Nil), + WithField("d1", Literal(6)) :: Nil), "out")()) val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select(Alias(WithFields('a, Seq("b1", "c1", "d1"), Seq(4, 5, 6).map(Literal(_))), "out")()) + .select(Alias(UpdateFields('a, WithField("b1", Literal(4)) :: WithField("c1", Literal(5)) :: + WithField("d1", Literal(6)) :: Nil), "out")()) .analyze comparePlans(optimized, correctAnswer) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala index 00aed6a10cd6..d9cefdaf3fe7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, OneRowRelation, Project, Range} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ @@ -37,14 +37,15 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { object Optimizer extends RuleExecutor[LogicalPlan] { val batches = Batch("collapse projections", FixedPoint(10), - CollapseProject) :: + CollapseProject) :: Batch("Constant Folding", FixedPoint(10), - NullPropagation, - ConstantFolding, - BooleanSimplification, - SimplifyConditionals, - SimplifyBinaryComparison, - SimplifyExtractValueOps) :: Nil + NullPropagation, + ConstantFolding, + BooleanSimplification, + SimplifyConditionals, + SimplifyBinaryComparison, + CombineUpdateFields, + SimplifyExtractValueOps) :: Nil } private val idAtt = ('id).long.notNull @@ -453,58 +454,182 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](3, 4))), null) } - private val structAttr = 'struct1.struct('a.int).withNullability(false) + private val structAttr = 
'struct1.struct('a.int, 'b.int).withNullability(false) private val testStructRelation = LocalRelation(structAttr) - private val nullableStructAttr = 'struct1.struct('a.int) + private val nullableStructAttr = 'struct1.struct('a.int, 'b.int) private val testNullableStructRelation = LocalRelation(nullableStructAttr) - test("simplify GetStructField on WithFields that is not changing the attribute being extracted") { - def query(relation: LocalRelation): LogicalPlan = relation.select( - GetStructField(WithFields('struct1, Seq("b"), Seq(Literal(1))), 0, Some("a")) as "outerAttr") + test("simplify GetStructField on basic UpdateFields") { + def check(fieldOps: Seq[StructFieldsOperation], ordinal: Int, expected: Expression): Unit = { + def query(relation: LocalRelation): LogicalPlan = + relation.select(GetStructField(UpdateFields('struct1, fieldOps), ordinal).as("res")) + + checkRule( + query(testStructRelation), + testStructRelation.select(expected.as("res"))) + + checkRule( + query(testNullableStructRelation), + testNullableStructRelation.select((expected match { + case expr: GetStructField => expr + case expr => If(IsNull('struct1), Literal(null, expr.dataType), expr) + }).as("res"))) + } + + // scalastyle:off line.size.limit + + // add attribute, extract an attribute from the original struct + check(WithField("c", Literal(3)) :: Nil, 0, GetStructField('struct1, 0)) + check(WithField("c", Literal(3)) :: Nil, 1, GetStructField('struct1, 1)) + // add attribute, extract added attribute + check(WithField("c", Literal(3)) :: Nil, 2, Literal(3)) + + // replace attribute, extract an attribute from the original struct + check(WithField("a", Literal(1)) :: Nil, 1, GetStructField('struct1, 1)) + check(WithField("b", Literal(2)) :: Nil, 0, GetStructField('struct1, 0)) + // replace attribute, extract replaced attribute + check(WithField("a", Literal(1)) :: Nil, 0, Literal(1)) + check(WithField("b", Literal(2)) :: Nil, 1, Literal(2)) + + // add multiple attributes, extract an attribute from the original struct + check(WithField("c", Literal(3)) :: WithField("c", Literal(4)) :: Nil, 0, GetStructField('struct1, 0)) + check(WithField("c", Literal(3)) :: WithField("d", Literal(4)) :: Nil, 0, GetStructField('struct1, 0)) + check(WithField("c", Literal(3)) :: WithField("c", Literal(4)) :: Nil, 1, GetStructField('struct1, 1)) + check(WithField("c", Literal(3)) :: WithField("d", Literal(4)) :: Nil, 1, GetStructField('struct1, 1)) + // add multiple attributes, extract newly added attribute + check(WithField("c", Literal(3)) :: WithField("c", Literal(4)) :: Nil, 2, Literal(4)) + check(WithField("c", Literal(4)) :: WithField("c", Literal(3)) :: Nil, 2, Literal(3)) + check(WithField("c", Literal(3)) :: WithField("d", Literal(4)) :: Nil, 2, Literal(3)) + check(WithField("c", Literal(3)) :: WithField("d", Literal(4)) :: Nil, 3, Literal(4)) + check(WithField("d", Literal(4)) :: WithField("c", Literal(3)) :: Nil, 2, Literal(4)) + check(WithField("d", Literal(4)) :: WithField("c", Literal(3)) :: Nil, 3, Literal(3)) + + // drop attribute, extract an attribute from the original struct + check(DropField("b") :: Nil, 0, GetStructField('struct1, 0)) + check(DropField("a") :: Nil, 0, GetStructField('struct1, 1)) + + // drop attribute, add attribute, extract an attribute from the original struct + check(DropField("b") :: WithField("c", Literal(3)) :: Nil, 0, GetStructField('struct1, 0)) + check(DropField("a") :: WithField("c", Literal(3)) :: Nil, 0, GetStructField('struct1, 1)) + // drop attribute, add attribute, extract added 
attribute + check(DropField("b") :: WithField("c", Literal(3)) :: Nil, 1, Literal(3)) + check(DropField("a") :: WithField("c", Literal(3)) :: Nil, 1, Literal(3)) + + // add attribute, drop attribute, extract an attribute from the original struct + check(WithField("c", Literal(3)) :: DropField("a") :: Nil, 0, GetStructField('struct1, 1)) + check(WithField("c", Literal(3)) :: DropField("b") :: Nil, 0, GetStructField('struct1, 0)) + // add attribute, drop attribute, extract added attribute + check(WithField("c", Literal(3)) :: DropField("a") :: Nil, 1, Literal(3)) + check(WithField("c", Literal(3)) :: DropField("b") :: Nil, 1, Literal(3)) + + // replace attribute, drop same attribute, extract an attribute from the original struct + check(WithField("b", Literal(3)) :: DropField("b") :: Nil, 0, GetStructField('struct1, 0)) + check(WithField("a", Literal(3)) :: DropField("a") :: Nil, 0, GetStructField('struct1, 1)) + + // add attribute, drop same attribute, extract an attribute from the original struct + check(WithField("c", Literal(3)) :: DropField("c") :: Nil, 0, GetStructField('struct1, 0)) + check(WithField("c", Literal(3)) :: DropField("c") :: Nil, 1, GetStructField('struct1, 1)) + + // replace attribute, drop another attribute, extract added attribute + check(WithField("b", Literal(3)) :: DropField("a") :: Nil, 0, Literal(3)) + check(WithField("a", Literal(3)) :: DropField("b") :: Nil, 0, Literal(3)) + + // drop attribute, add same attribute, extract attribute from the original struct + check(DropField("b") :: WithField("b", Literal(3)) :: Nil, 0, GetStructField('struct1, 0)) + check(DropField("a") :: WithField("a", Literal(3)) :: Nil, 0, GetStructField('struct1, 1)) + // drop attribute, add same attribute, extract added attribute + check(DropField("b") :: WithField("b", Literal(3)) :: Nil, 1, Literal(3)) + check(DropField("a") :: WithField("a", Literal(3)) :: Nil, 1, Literal(3)) + + // drop non-existent attribute, add same attribute, extract attribute from the original struct + check(DropField("c") :: WithField("c", Literal(3)) :: Nil, 0, GetStructField('struct1, 0)) + check(DropField("c") :: WithField("c", Literal(3)) :: Nil, 1, GetStructField('struct1, 1)) + // drop non-existent attribute, add same attribute, extract added attribute + check(DropField("c") :: WithField("c", Literal(3)) :: Nil, 2, Literal(3)) + + // scalastyle:on line.size.limit + } + + test("simplify GetStructField that is extracting a field nested inside a struct") { + val struct2 = 'struct2.struct('b.int) + val testStructRelation = LocalRelation(structAttr, struct2) + val testNullableStructRelation = LocalRelation(nullableStructAttr, struct2) + + // if the field being extracted is from the same struct that UpdateFields is modifying, + // we can just return GetStructField in both the non-nullable and nullable struct scenario + + def addFieldFromSameStructAndThenExtractIt(relation: LocalRelation): LogicalPlan = + relation.select(GetStructField( + UpdateFields('struct1, WithField("b", GetStructField('struct1, 0)) :: Nil), 1).as("res")) checkRule( - query(testStructRelation), - testStructRelation.select(GetStructField('struct1, 0, Some("a")) as "outerAttr")) + addFieldFromSameStructAndThenExtractIt(testStructRelation), + testStructRelation.select(GetStructField('struct1, 0).as("res"))) checkRule( - query(testNullableStructRelation), - testNullableStructRelation.select(GetStructField('struct1, 0, Some("a")) as "outerAttr")) - } + addFieldFromSameStructAndThenExtractIt(testNullableStructRelation), + 
testNullableStructRelation.select(GetStructField('struct1, 0).as("res"))) - test("simplify GetStructField on WithFields that is changing the attribute being extracted") { - def query(relation: LocalRelation): LogicalPlan = relation.select( - GetStructField(WithFields('struct1, Seq("b"), Seq(Literal(1))), 1, Some("b")) as "res") + // if the field being extracted is from a different struct than the one UpdateFields is + // modifying, we must return GetStructField wrapped in If(IsNull(struct), null, GetStructField) + // in the nullable struct scenario + + def addFieldFromAnotherStructAndThenExtractIt(relation: LocalRelation): LogicalPlan = + relation.select(GetStructField( + UpdateFields('struct1, WithField("b", GetStructField('struct2, 0)) :: Nil), 1).as("res")) checkRule( - query(testStructRelation), - testStructRelation.select(Literal(1) as "res")) + addFieldFromAnotherStructAndThenExtractIt(testStructRelation), + testStructRelation.select(GetStructField('struct2, 0).as("res"))) checkRule( - query(testNullableStructRelation), + addFieldFromAnotherStructAndThenExtractIt(testNullableStructRelation), testNullableStructRelation.select( - If(IsNull('struct1), Literal(null, IntegerType), Literal(1)) as "res")) + If(IsNull('struct1), Literal(null, IntegerType), GetStructField('struct2, 0)).as("res"))) } - test( - "simplify GetStructField on WithFields that is changing the attribute being extracted twice") { - def query(relation: LocalRelation): LogicalPlan = relation.select( - GetStructField(WithFields('struct1, Seq("b", "b"), Seq(Literal(1), Literal(2))), 1, Some("b")) - as "outerAtt") + test("simplify GetStructField on nested UpdateFields") { + def query(relation: LocalRelation, ordinal: Int): LogicalPlan = { + val nestedUpdateFields = + UpdateFields( + UpdateFields( + UpdateFields( + UpdateFields( + 'struct1, + WithField("c", Literal(1)) :: Nil), + WithField("d", Literal(2)) :: Nil), + WithField("e", Literal(3)) :: Nil), + WithField("f", Literal(4)) :: Nil) + + relation.select(GetStructField(nestedUpdateFields, ordinal) as "res") + } + + // extract newly added field checkRule( - query(testStructRelation), - testStructRelation.select(Literal(2) as "outerAtt")) + query(testStructRelation, 5), + testStructRelation.select(Literal(4) as "res")) checkRule( - query(testNullableStructRelation), + query(testNullableStructRelation, 5), testNullableStructRelation.select( - If(IsNull('struct1), Literal(null, IntegerType), Literal(2)) as "outerAtt")) + If(IsNull('struct1), Literal(null, IntegerType), Literal(4)) as "res")) + + // extract field from original struct + + checkRule( + query(testStructRelation, 0), + testStructRelation.select(GetStructField('struct1, 0) as "res")) + + checkRule( + query(testNullableStructRelation, 0), + testNullableStructRelation.select(GetStructField('struct1, 0) as "res")) } - test("collapse multiple GetStructField on the same WithFields") { + test("simplify multiple GetStructField on the same UpdateFields") { def query(relation: LocalRelation): LogicalPlan = relation - .select(WithFields('struct1, Seq("b"), Seq(Literal(2))) as "struct2") + .select(UpdateFields('struct1, WithField("b", Literal(2)) :: Nil) as "struct2") .select( GetStructField('struct2, 0, Some("a")) as "struct1A", GetStructField('struct2, 1, Some("b")) as "struct1B") @@ -512,21 +637,21 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { checkRule( query(testStructRelation), testStructRelation.select( - GetStructField('struct1, 0, Some("a")) as "struct1A", + GetStructField('struct1, 0) as 
"struct1A", Literal(2) as "struct1B")) checkRule( query(testNullableStructRelation), testNullableStructRelation.select( - GetStructField('struct1, 0, Some("a")) as "struct1A", + GetStructField('struct1, 0) as "struct1A", If(IsNull('struct1), Literal(null, IntegerType), Literal(2)) as "struct1B")) } - test("collapse multiple GetStructField on different WithFields") { + test("simplify multiple GetStructField on different UpdateFields") { def query(relation: LocalRelation): LogicalPlan = relation .select( - WithFields('struct1, Seq("b"), Seq(Literal(2))) as "struct2", - WithFields('struct1, Seq("b"), Seq(Literal(3))) as "struct3") + UpdateFields('struct1, WithField("b", Literal(2)) :: Nil) as "struct2", + UpdateFields('struct1, WithField("b", Literal(3)) :: Nil) as "struct3") .select( GetStructField('struct2, 0, Some("a")) as "struct2A", GetStructField('struct2, 1, Some("b")) as "struct2B", @@ -537,18 +662,148 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper { query(testStructRelation), testStructRelation .select( - GetStructField('struct1, 0, Some("a")) as "struct2A", + GetStructField('struct1, 0) as "struct2A", Literal(2) as "struct2B", - GetStructField('struct1, 0, Some("a")) as "struct3A", + GetStructField('struct1, 0) as "struct3A", Literal(3) as "struct3B")) checkRule( query(testNullableStructRelation), testNullableStructRelation .select( - GetStructField('struct1, 0, Some("a")) as "struct2A", + GetStructField('struct1, 0) as "struct2A", If(IsNull('struct1), Literal(null, IntegerType), Literal(2)) as "struct2B", - GetStructField('struct1, 0, Some("a")) as "struct3A", + GetStructField('struct1, 0) as "struct3A", If(IsNull('struct1), Literal(null, IntegerType), Literal(3)) as "struct3B")) } + + test("simplify add multiple nested fields to non-nullable struct") { + // this scenario is possible if users add multiple nested columns to a non-nullable struct + // using the Column.withField API in a non-performant way + val structLevel2 = LocalRelation( + 'a1.struct( + 'a2.struct('a3.int.notNull)).notNull) + + val query = { + val addB3toA1A2 = UpdateFields('a1, Seq(WithField("a2", + UpdateFields(GetStructField('a1, 0), Seq(WithField("b3", Literal(2))))))) + + structLevel2.select( + UpdateFields( + addB3toA1A2, + Seq(WithField("a2", UpdateFields( + GetStructField(addB3toA1A2, 0), Seq(WithField("c3", Literal(3))))))).as("a1")) + } + + val expected = structLevel2.select( + UpdateFields('a1, Seq( + // scalastyle:off line.size.limit + WithField("a2", UpdateFields(GetStructField('a1, 0), WithField("b3", 2) :: Nil)), + WithField("a2", UpdateFields(GetStructField('a1, 0), WithField("b3", 2) :: WithField("c3", 3) :: Nil)) + // scalastyle:on line.size.limit + )).as("a1")) + + checkRule(query, expected) + } + + test("simplify add multiple nested fields to nullable struct") { + // this scenario is possible if users add multiple nested columns to a nullable struct + // using the Column.withField API in a non-performant way + val structLevel2 = LocalRelation( + 'a1.struct( + 'a2.struct('a3.int.notNull))) + + val query = { + val addB3toA1A2 = UpdateFields('a1, Seq(WithField("a2", + UpdateFields(GetStructField('a1, 0), Seq(WithField("b3", Literal(2))))))) + + structLevel2.select( + UpdateFields( + addB3toA1A2, + Seq(WithField("a2", UpdateFields( + GetStructField(addB3toA1A2, 0), Seq(WithField("c3", Literal(3))))))).as("a1")) + } + + val expected = { + val repeatedExpr = UpdateFields(GetStructField('a1, 0), WithField("b3", Literal(2)) :: Nil) + val repeatedExprDataType = StructType(Seq( + 
StructField("a3", IntegerType, nullable = false), + StructField("b3", IntegerType, nullable = false))) + + structLevel2.select( + UpdateFields('a1, Seq( + WithField("a2", repeatedExpr), + WithField("a2", UpdateFields( + If(IsNull('a1), Literal(null, repeatedExprDataType), repeatedExpr), + WithField("c3", Literal(3)) :: Nil)) + )).as("a1")) + } + + checkRule(query, expected) + } + + test("simplify drop multiple nested fields in non-nullable struct") { + // this scenario is possible if users drop multiple nested columns in a non-nullable struct + // using the Column.dropFields API in a non-performant way + val structLevel2 = LocalRelation( + 'a1.struct( + 'a2.struct('a3.int.notNull, 'b3.int.notNull, 'c3.int.notNull).notNull + ).notNull) + + val query = { + val dropA1A2B = UpdateFields('a1, Seq(WithField("a2", UpdateFields( + GetStructField('a1, 0), Seq(DropField("b3")))))) + + structLevel2.select( + UpdateFields( + dropA1A2B, + Seq(WithField("a2", UpdateFields( + GetStructField(dropA1A2B, 0), Seq(DropField("c3")))))).as("a1")) + } + + val expected = structLevel2.select( + UpdateFields('a1, Seq( + WithField("a2", UpdateFields(GetStructField('a1, 0), Seq(DropField("b3")))), + WithField("a2", UpdateFields(GetStructField('a1, 0), Seq(DropField("b3"), DropField("c3")))) + )).as("a1")) + + checkRule(query, expected) + } + + test("simplify drop multiple nested fields in nullable struct") { + // this scenario is possible if users drop multiple nested columns in a nullable struct + // using the Column.dropFields API in a non-performant way + val structLevel2 = LocalRelation( + 'a1.struct( + 'a2.struct('a3.int.notNull, 'b3.int.notNull, 'c3.int.notNull) + )) + + val query = { + val dropA1A2B = UpdateFields('a1, Seq(WithField("a2", UpdateFields( + GetStructField('a1, 0), Seq(DropField("b3")))))) + + structLevel2.select( + UpdateFields( + dropA1A2B, + Seq(WithField("a2", UpdateFields( + GetStructField(dropA1A2B, 0), Seq(DropField("c3")))))).as("a1")) + } + + val expected = { + val repeatedExpr = UpdateFields(GetStructField('a1, 0), DropField("b3") :: Nil) + val repeatedExprDataType = StructType(Seq( + StructField("a3", IntegerType, nullable = false), + StructField("c3", IntegerType, nullable = false))) + + structLevel2.select( + UpdateFields('a1, Seq( + WithField("a2", repeatedExpr), + WithField("a2", UpdateFields( + If(IsNull('a1), Literal(null, repeatedExprDataType), repeatedExpr), + DropField("c3") :: Nil)) + )).as("a1")) + } + + checkRule(query, expected) + } } diff --git a/sql/core/benchmarks/UpdateFieldsBenchmark-results.txt b/sql/core/benchmarks/UpdateFieldsBenchmark-results.txt new file mode 100644 index 000000000000..5feca0e100bb --- /dev/null +++ b/sql/core/benchmarks/UpdateFieldsBenchmark-results.txt @@ -0,0 +1,26 @@ +================================================================================================ +Add 2 columns and drop 2 columns at 3 different depths of nesting +================================================================================================ + +OpenJDK 64-Bit Server VM 1.8.0_212-b03 on Mac OS X 10.14.6 +Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz +Add 2 columns and drop 2 columns at 3 different depths of nesting: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative +------------------------------------------------------------------------------------------------------------------------------------------------- +To non-nullable StructTypes using performant method 10 11 2 0.0 Infinity 1.0X +To nullable StructTypes using performant method 9 10 1 0.0 
Infinity 1.0X
+To non-nullable StructTypes using non-performant method 2457 2464 10 0.0 Infinity 0.0X
+To nullable StructTypes using non-performant method 42641 43804 1644 0.0 Infinity 0.0X
+
+
+================================================================================================
+Add 50 columns and drop 50 columns at 100 different depths of nesting
+================================================================================================
+
+OpenJDK 64-Bit Server VM 1.8.0_212-b03 on Mac OS X 10.14.6
+Intel(R) Core(TM) i7-7820HQ CPU @ 2.90GHz
+Add 50 columns and drop 50 columns at 100 different depths of nesting: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+-----------------------------------------------------------------------------------------------------------------------------------------------------
+To non-nullable StructTypes using performant method 4595 4927 470 0.0 Infinity 1.0X
+To nullable StructTypes using performant method 5185 5516 468 0.0 Infinity 0.9X
+
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index da542c67d9c5..a46d6c0bb228 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -901,6 +901,23 @@ class Column(val expr: Expression) extends Logging {
    * // result: org.apache.spark.sql.AnalysisException: Ambiguous reference to fields
    * }}}
    *
+   * This method supports adding/replacing nested fields directly e.g.
+   *
+   * {{{
+   * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+   * df.select($"struct_col".withField("a.c", lit(3)).withField("a.d", lit(4)))
+   * // result: {"a":{"a":1,"b":2,"c":3,"d":4}}
+   * }}}
+   *
+   * However, if you are going to add/replace multiple nested fields, it is more efficient to
+   * extract out the nested struct before adding/replacing multiple fields e.g.
+   *
+   * {{{
+   * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+   * df.select($"struct_col".withField("a", $"struct_col.a".withField("c", lit(3)).withField("d", lit(4))))
+   * // result: {"a":{"a":1,"b":2,"c":3,"d":4}}
+   * }}}
+   *
    * @group expr_ops
    * @since 3.1.0
    */
@@ -908,32 +925,102 @@
   def withField(fieldName: String, col: Column): Column = withExpr {
     require(fieldName != null, "fieldName cannot be null")
     require(col != null, "col cannot be null")
+    updateFieldsHelper(expr, nameParts(fieldName), name => WithField(name, col.expr))
+  }
 
-    val nameParts = if (fieldName.isEmpty) {
+  // scalastyle:off line.size.limit
+  /**
+   * An expression that drops fields in `StructType` by name.
+   * This is a no-op if the schema doesn't contain the field name(s).
+   *
+   * {{{
+   * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+   * df.select($"struct_col".dropFields("b"))
+   * // result: {"a":1}
+   *
+   * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+   * df.select($"struct_col".dropFields("c"))
+   * // result: {"a":1,"b":2}
+   *
+   * val df = sql("SELECT named_struct('a', 1, 'b', 2, 'c', 3) struct_col")
+   * df.select($"struct_col".dropFields("b", "c"))
+   * // result: {"a":1}
+   *
+   * val df = sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+   * df.select($"struct_col".dropFields("a", "b"))
+   * // result: org.apache.spark.sql.AnalysisException: cannot resolve 'update_fields(update_fields(`struct_col`))' due to data type mismatch: cannot drop all fields in struct
+   *
+   * val df = sql("SELECT CAST(NULL AS struct<a:int,b:int>) struct_col")
+   * df.select($"struct_col".dropFields("b"))
+   * // result: null of type struct<a:int>
+   *
+   * val df = sql("SELECT named_struct('a', 1, 'b', 2, 'b', 3) struct_col")
+   * df.select($"struct_col".dropFields("b"))
+   * // result: {"a":1}
+   *
+   * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+   * df.select($"struct_col".dropFields("a.b"))
+   * // result: {"a":{"a":1}}
+   *
+   * val df = sql("SELECT named_struct('a', named_struct('b', 1), 'a', named_struct('c', 2)) struct_col")
+   * df.select($"struct_col".dropFields("a.c"))
+   * // result: org.apache.spark.sql.AnalysisException: Ambiguous reference to fields
+   * }}}
+   *
+   * This method supports dropping multiple nested fields directly e.g.
+   *
+   * {{{
+   * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+   * df.select($"struct_col".dropFields("a.b", "a.c"))
+   * // result: {"a":{"a":1}}
+   * }}}
+   *
+   * However, if you are going to drop multiple nested fields, it is more efficient to extract
+   * out the nested struct before dropping multiple fields from it e.g.
+   *
+   * {{{
+   * val df = sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+   * df.select($"struct_col".withField("a", $"struct_col.a".dropFields("b", "c")))
+   * // result: {"a":{"a":1}}
+   * }}}
+   *
+   * @group expr_ops
+   * @since 3.1.0
+   */
+  // scalastyle:on line.size.limit
+  def dropFields(fieldNames: String*): Column = withExpr {
+    def dropField(structExpr: Expression, fieldName: String): UpdateFields =
+      updateFieldsHelper(structExpr, nameParts(fieldName), name => DropField(name))
+
+    fieldNames.tail.foldLeft(dropField(expr, fieldNames.head)) {
+      (resExpr, fieldName) => dropField(resExpr, fieldName)
+    }
+  }
+
+  private def nameParts(fieldName: String): Seq[String] = {
+    require(fieldName != null, "fieldName cannot be null")
+
+    if (fieldName.isEmpty) {
       fieldName :: Nil
     } else {
       CatalystSqlParser.parseMultipartIdentifier(fieldName)
     }
-    withFieldHelper(expr, nameParts, Nil, col.expr)
   }
 
-  private def withFieldHelper(
-      struct: Expression,
+  private def updateFieldsHelper(
+      structExpr: Expression,
       namePartsRemaining: Seq[String],
-      namePartsDone: Seq[String],
-      value: Expression) : WithFields = {
-    val name = namePartsRemaining.head
+      valueFunc: String => StructFieldsOperation): UpdateFields = {
+
+    val fieldName = namePartsRemaining.head
     if (namePartsRemaining.length == 1) {
-      WithFields(struct, name :: Nil, value :: Nil)
+      UpdateFields(structExpr, valueFunc(fieldName) :: Nil)
     } else {
-      val newNamesRemaining = namePartsRemaining.tail
-      val newNamesDone = namePartsDone :+ name
-      val newValue = withFieldHelper(
-        struct = UnresolvedExtractValue(struct, Literal(name)),
-        namePartsRemaining = newNamesRemaining,
-        namePartsDone = newNamesDone,
-        value = value)
-      WithFields(struct, name :: Nil, newValue :: Nil)
+      val newValue = updateFieldsHelper(
+        structExpr = UnresolvedExtractValue(structExpr, Literal(fieldName)),
+        namePartsRemaining = namePartsRemaining.tail,
+        valueFunc = valueFunc)
+      UpdateFields(structExpr, WithField(fieldName, newValue) :: Nil)
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 24419968c047..b11f4c603dfd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.io.{LongWritable, Text}
 import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat}
 import org.scalatest.matchers.should.Matchers._
 
+import org.apache.spark.sql.UpdateFieldsBenchmark._
 import org.apache.spark.sql.catalyst.expressions.{InSet, Literal, NamedExpression}
 import org.apache.spark.sql.execution.ProjectExec
 import org.apache.spark.sql.functions._
@@ -922,11 +923,11 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
     assert(inSet.sql === "('a' IN ('a', 'b'))")
   }
 
-  def checkAnswerAndSchema(
+  def checkAnswer(
       df: => DataFrame,
       expectedAnswer: Seq[Row],
       expectedSchema: StructType): Unit = {
     checkAnswer(df, expectedAnswer)
     assert(df.schema == expectedSchema)
   }
 
@@ -940,8 +941,8 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
     sparkContext.parallelize(Row(Row(1, null, 3)) :: Nil),
     StructType(Seq(StructField("a", structType, nullable = false))))
 
-  private lazy val nullStructLevel1: DataFrame = spark.createDataFrame(
-    sparkContext.parallelize(Row(null) :: Nil),
+  private lazy val nullableStructLevel1: DataFrame = spark.createDataFrame(
+
sparkContext.parallelize(Row(null) :: Row(Row(1, null, 3)) :: Nil), StructType(Seq(StructField("a", structType, nullable = true)))) private lazy val structLevel2: DataFrame = spark.createDataFrame( @@ -951,12 +951,12 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { StructField("a", structType, nullable = false))), nullable = false)))) - private lazy val nullStructLevel2: DataFrame = spark.createDataFrame( - sparkContext.parallelize(Row(Row(null)) :: Nil), + private lazy val nullableStructLevel2: DataFrame = spark.createDataFrame( + sparkContext.parallelize(Row(null) :: Row(Row(null)) :: Row(Row(Row(1, null, 3))) :: Nil), StructType(Seq( StructField("a", StructType(Seq( StructField("a", structType, nullable = true))), - nullable = false)))) + nullable = true)))) private lazy val structLevel3: DataFrame = spark.createDataFrame( sparkContext.parallelize(Row(Row(Row(Row(1, null, 3)))) :: Nil), @@ -1018,7 +1018,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("withField should add field with no name") { - checkAnswerAndSchema( + checkAnswer( structLevel1.withColumn("a", $"a".withField("", lit(4))), Row(Row(1, null, 3, 4)) :: Nil, StructType(Seq( @@ -1031,7 +1031,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("withField should add field to struct") { - checkAnswerAndSchema( + checkAnswer( structLevel1.withColumn("a", 'a.withField("d", lit(4))), Row(Row(1, null, 3, 4)) :: Nil, StructType(Seq( @@ -1043,10 +1043,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) } - test("withField should add field to null struct") { - checkAnswerAndSchema( - nullStructLevel1.withColumn("a", $"a".withField("d", lit(4))), - Row(null) :: Nil, + test("withField should add field to nullable struct") { + checkAnswer( + nullableStructLevel1.withColumn("a", $"a".withField("d", lit(4))), + Row(null) :: Row(Row(1, null, 3, 4)) :: Nil, StructType(Seq( StructField("a", StructType(Seq( StructField("a", IntegerType, nullable = false), @@ -1056,10 +1056,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = true)))) } - test("withField should add field to nested null struct") { - checkAnswerAndSchema( - nullStructLevel2.withColumn("a", $"a".withField("a.d", lit(4))), - Row(Row(null)) :: Nil, + test("withField should add field to nested nullable struct") { + checkAnswer( + nullableStructLevel2.withColumn("a", $"a".withField("a.d", lit(4))), + Row(null) :: Row(Row(null)) :: Row(Row(Row(1, null, 3, 4))) :: Nil, StructType( Seq(StructField("a", StructType(Seq( StructField("a", StructType(Seq( @@ -1068,11 +1068,11 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { StructField("c", IntegerType, nullable = false), StructField("d", IntegerType, nullable = false))), nullable = true))), - nullable = false)))) + nullable = true)))) } test("withField should add null field to struct") { - checkAnswerAndSchema( + checkAnswer( structLevel1.withColumn("a", 'a.withField("d", lit(null).cast(IntegerType))), Row(Row(1, null, 3, null)) :: Nil, StructType(Seq( @@ -1085,7 +1085,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("withField should add multiple fields to struct") { - checkAnswerAndSchema( + checkAnswer( structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("e", lit(5))), Row(Row(1, null, 3, 4, 5)) :: Nil, StructType(Seq( @@ -1098,12 +1098,26 @@ class ColumnExpressionSuite extends QueryTest with 
SharedSparkSession { nullable = false)))) } + test("withField should add multiple fields to nullable struct") { + checkAnswer( + nullableStructLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("e", lit(5))), + Row(null) :: Row(Row(1, null, 3, 4, 5)) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = false), + StructField("e", IntegerType, nullable = false))), + nullable = true)))) + } + test("withField should add field to nested struct") { Seq( structLevel2.withColumn("a", 'a.withField("a.d", lit(4))), structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("d", lit(4)))) ).foreach { df => - checkAnswerAndSchema( + checkAnswer( df, Row(Row(Row(1, null, 3, 4))) :: Nil, StructType( @@ -1118,8 +1132,50 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } } + test("withField should add multiple fields to nested struct") { + Seq( + col("a").withField("a", $"a.a".withField("d", lit(4)).withField("e", lit(5))), + col("a").withField("a.d", lit(4)).withField("a.e", lit(5)) + ).foreach { column => + checkAnswer( + structLevel2.select(column.as("a")), + Row(Row(Row(1, null, 3, 4, 5))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = false), + StructField("e", IntegerType, nullable = false))), + nullable = false))), + nullable = false)))) + } + } + + test("withField should add multiple fields to nested nullable struct") { + Seq( + col("a").withField("a", $"a.a".withField("d", lit(4)).withField("e", lit(5))), + col("a").withField("a.d", lit(4)).withField("a.e", lit(5)) + ).foreach { column => + checkAnswer( + nullableStructLevel2.select(column.as("a")), + Row(null) :: Row(Row(null)) :: Row(Row(Row(1, null, 3, 4, 5))) :: Nil, + StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", IntegerType, nullable = true), + StructField("c", IntegerType, nullable = false), + StructField("d", IntegerType, nullable = false), + StructField("e", IntegerType, nullable = false))), + nullable = true))), + nullable = true)))) + } + } + test("withField should add field to deeply nested struct") { - checkAnswerAndSchema( + checkAnswer( structLevel3.withColumn("a", 'a.withField("a.a.d", lit(4))), Row(Row(Row(Row(1, null, 3, 4)))) :: Nil, StructType(Seq( @@ -1136,7 +1192,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { } test("withField should replace field in struct") { - checkAnswerAndSchema( + checkAnswer( structLevel1.withColumn("a", 'a.withField("b", lit(2))), Row(Row(1, 2, 3)) :: Nil, StructType(Seq( @@ -1147,10 +1203,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession { nullable = false)))) } - test("withField should replace field in null struct") { - checkAnswerAndSchema( - nullStructLevel1.withColumn("a", 'a.withField("b", lit("foo"))), - Row(null) :: Nil, + test("withField should replace field in nullable struct") { + checkAnswer( + nullableStructLevel1.withColumn("a", 'a.withField("b", lit("foo"))), + Row(null) :: Row(Row(1, "foo", 3)) :: Nil, StructType(Seq( StructField("a", 
       StructType(Seq(
         StructField("a", IntegerType, nullable = false),
@@ -1159,10 +1215,10 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
             nullable = true))))
   }
 
-  test("withField should replace field in nested null struct") {
-    checkAnswerAndSchema(
-      nullStructLevel2.withColumn("a", $"a".withField("a.b", lit("foo"))),
-      Row(Row(null)) :: Nil,
+  test("withField should replace field in nested nullable struct") {
+    checkAnswer(
+      nullableStructLevel2.withColumn("a", $"a".withField("a.b", lit("foo"))),
+      Row(null) :: Row(Row(null)) :: Row(Row(Row(1, "foo", 3))) :: Nil,
       StructType(
         Seq(StructField("a", StructType(Seq(
           StructField("a", StructType(Seq(
@@ -1170,11 +1226,11 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
             StructField("b", StringType, nullable = false),
             StructField("c", IntegerType, nullable = false))),
             nullable = true))),
-          nullable = false))))
+          nullable = true))))
   }
 
   test("withField should replace field with null value in struct") {
-    checkAnswerAndSchema(
+    checkAnswer(
       structLevel1.withColumn("a", 'a.withField("c", lit(null).cast(IntegerType))),
       Row(Row(1, null, null)) :: Nil,
       StructType(Seq(
@@ -1186,7 +1242,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
   }
 
   test("withField should replace multiple fields in struct") {
-    checkAnswerAndSchema(
+    checkAnswer(
       structLevel1.withColumn("a", 'a.withField("a", lit(10)).withField("b", lit(20))),
       Row(Row(10, 20, 3)) :: Nil,
       StructType(Seq(
@@ -1197,12 +1253,24 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
         nullable = false))))
   }
 
+  test("withField should replace multiple fields in nullable struct") {
+    checkAnswer(
+      nullableStructLevel1.withColumn("a", 'a.withField("a", lit(10)).withField("b", lit(20))),
+      Row(null) :: Row(Row(10, 20, 3)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = false),
+          StructField("c", IntegerType, nullable = false))),
+          nullable = true))))
+  }
+
   test("withField should replace field in nested struct") {
     Seq(
       structLevel2.withColumn("a", $"a".withField("a.b", lit(2))),
       structLevel2.withColumn("a", 'a.withField("a", $"a.a".withField("b", lit(2))))
     ).foreach { df =>
-      checkAnswerAndSchema(
+      checkAnswer(
         df,
         Row(Row(Row(1, 2, 3))) :: Nil,
         StructType(Seq(
@@ -1216,8 +1284,46 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
     }
   }
 
+  test("withField should replace multiple fields in nested struct") {
+    Seq(
+      col("a").withField("a", $"a.a".withField("a", lit(10)).withField("b", lit(20))),
+      col("a").withField("a.a", lit(10)).withField("a.b", lit(20))
+    ).foreach { column =>
+      checkAnswer(
+        structLevel2.select(column.as("a")),
+        Row(Row(Row(10, 20, 3))) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("b", IntegerType, nullable = false),
+              StructField("c", IntegerType, nullable = false))),
+              nullable = false))),
+            nullable = false))))
+    }
+  }
+
+  test("withField should replace multiple fields in nested nullable struct") {
+    Seq(
+      col("a").withField("a", $"a.a".withField("a", lit(10)).withField("b", lit(20))),
+      col("a").withField("a.a", lit(10)).withField("a.b", lit(20))
+    ).foreach { column =>
+      checkAnswer(
+        nullableStructLevel2.select(column.as("a")),
+        Row(null) :: Row(Row(null)) :: Row(Row(Row(10, 20, 3))) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("b", IntegerType, nullable = false),
+              StructField("c", IntegerType, nullable = false))),
+              nullable = true))),
+            nullable = true))))
+    }
+  }
+
   test("withField should replace field in deeply nested struct") {
-    checkAnswerAndSchema(
+    checkAnswer(
       structLevel3.withColumn("a", $"a".withField("a.a.b", lit(2))),
       Row(Row(Row(Row(1, 2, 3)))) :: Nil,
       StructType(Seq(
@@ -1242,7 +1348,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
             StructField("b", IntegerType, nullable = false))),
           nullable = false))))
 
-    checkAnswerAndSchema(
+    checkAnswer(
       structLevel1.withColumn("a", 'a.withField("b", lit(100))),
       Row(Row(1, 100, 100)) :: Nil,
       StructType(Seq(
@@ -1254,7 +1360,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
   }
 
   test("withField should replace fields in struct in given order") {
-    checkAnswerAndSchema(
+    checkAnswer(
       structLevel1.withColumn("a", 'a.withField("b", lit(2)).withField("b", lit(20))),
       Row(Row(1, 20, 3)) :: Nil,
       StructType(Seq(
@@ -1266,7 +1372,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
   }
 
   test("withField should add field and then replace same field in struct") {
-    checkAnswerAndSchema(
+    checkAnswer(
       structLevel1.withColumn("a", 'a.withField("d", lit(4)).withField("d", lit(5))),
       Row(Row(1, null, 3, 5)) :: Nil,
       StructType(Seq(
@@ -1290,7 +1396,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
             nullable = false))),
           nullable = false))))
 
-    checkAnswerAndSchema(
+    checkAnswer(
       df.withColumn("a", 'a.withField("`a.b`.`e.f`", lit(2))),
       Row(Row(Row(1, 2, 3))) :: Nil,
       StructType(Seq(
@@ -1317,7 +1423,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
 
   test("withField should replace field in struct even if casing is different") {
     withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
-      checkAnswerAndSchema(
+      checkAnswer(
        mixedCaseStructLevel1.withColumn("a", 'a.withField("A", lit(2))),
         Row(Row(2, 1)) :: Nil,
         StructType(Seq(
@@ -1326,7 +1432,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
             StructField("B", IntegerType, nullable = false))),
           nullable = false))))
 
-      checkAnswerAndSchema(
+      checkAnswer(
         mixedCaseStructLevel1.withColumn("a", 'a.withField("b", lit(2))),
         Row(Row(1, 2)) :: Nil,
         StructType(Seq(
@@ -1339,7 +1445,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
     }
   }
 
   test("withField should add field to struct because casing is different") {
     withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
-      checkAnswerAndSchema(
+      checkAnswer(
         mixedCaseStructLevel1.withColumn("a", 'a.withField("A", lit(2))),
         Row(Row(1, 1, 2)) :: Nil,
         StructType(Seq(
@@ -1349,7 +1455,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
             StructField("A", IntegerType, nullable = false))),
           nullable = false))))
 
-      checkAnswerAndSchema(
+      checkAnswer(
         mixedCaseStructLevel1.withColumn("a", 'a.withField("b", lit(2))),
         Row(Row(1, 1, 2)) :: Nil,
         StructType(Seq(
@@ -1377,7 +1483,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
 
   test("withField should replace nested field in struct even if casing is different") {
     withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
-      checkAnswerAndSchema(
+      checkAnswer(
         mixedCaseStructLevel2.withColumn("a", 'a.withField("A.a", lit(2))),
         Row(Row(Row(2, 1), Row(1, 1))) :: Nil,
         StructType(Seq(
@@ -1392,7 +1498,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
             nullable = false))),
           nullable = false))))
 
-      checkAnswerAndSchema(
+      checkAnswer(
         mixedCaseStructLevel2.withColumn("a", 'a.withField("b.a", lit(2))),
         Row(Row(Row(1, 1), Row(2, 1))) :: Nil,
         StructType(Seq(
@@ -1451,30 +1557,41 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
       sql("SELECT named_struct('a', named_struct('b', 1), 'a', named_struct('c', 2)) struct_col")
         .select($"struct_col".withField("a.c", lit(3)))
     }.getMessage should include("Ambiguous reference to fields")
+
+    checkAnswer(
+      sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+        .select($"struct_col".withField("a.c", lit(3)).withField("a.d", lit(4))),
+      Row(Row(Row(1, 2, 3, 4))))
+
+    checkAnswer(
+      sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+        .select($"struct_col".withField("a",
+          $"struct_col.a".withField("c", lit(3)).withField("d", lit(4)))),
+      Row(Row(Row(1, 2, 3, 4))))
   }
 
   test("SPARK-32641: extracting field from non-null struct column after withField should return " +
     "field value") {
     // extract newly added field
-    checkAnswerAndSchema(
+    checkAnswer(
       structLevel1.withColumn("a", $"a".withField("d", lit(4)).getField("d")),
       Row(4) :: Nil,
       StructType(Seq(StructField("a", IntegerType, nullable = false))))
 
     // extract newly replaced field
-    checkAnswerAndSchema(
+    checkAnswer(
       structLevel1.withColumn("a", $"a".withField("a", lit(4)).getField("a")),
       Row(4) :: Nil,
       StructType(Seq(StructField("a", IntegerType, nullable = false))))
 
     // add new field, extract another field from original struct
-    checkAnswerAndSchema(
+    checkAnswer(
       structLevel1.withColumn("a", $"a".withField("d", lit(4)).getField("c")),
       Row(3) :: Nil,
       StructType(Seq(StructField("a", IntegerType, nullable = false))))
 
     // replace field, extract another field from original struct
-    checkAnswerAndSchema(
+    checkAnswer(
       structLevel1.withColumn("a", $"a".withField("a", lit(4)).getField("c")),
       Row(3) :: Nil,
       StructType(Seq(StructField("a", IntegerType, nullable = false))))
@@ -1482,26 +1599,30 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
 
   test("SPARK-32641: extracting field from null struct column after withField should return " +
     "null if the original struct was null") {
+    val nullStructLevel1 = spark.createDataFrame(
+      sparkContext.parallelize(Row(null) :: Nil),
+      StructType(Seq(StructField("a", structType, nullable = true))))
+
     // extract newly added field
-    checkAnswerAndSchema(
+    checkAnswer(
       nullStructLevel1.withColumn("a", $"a".withField("d", lit(4)).getField("d")),
       Row(null) :: Nil,
       StructType(Seq(StructField("a", IntegerType, nullable = true))))
 
     // extract newly replaced field
-    checkAnswerAndSchema(
+    checkAnswer(
       nullStructLevel1.withColumn("a", $"a".withField("a", lit(4)).getField("a")),
       Row(null) :: Nil,
       StructType(Seq(StructField("a", IntegerType, nullable = true))))
 
     // add new field, extract another field from original struct
-    checkAnswerAndSchema(
+    checkAnswer(
       nullStructLevel1.withColumn("a", $"a".withField("d", lit(4)).getField("c")),
       Row(null) :: Nil,
       StructType(Seq(StructField("a", IntegerType, nullable = true))))
 
     // replace field, extract another field from original struct
-    checkAnswerAndSchema(
+    checkAnswer(
       nullStructLevel1.withColumn("a", $"a".withField("a", lit(4)).getField("c")),
       Row(null) :: Nil,
       StructType(Seq(StructField("a", IntegerType, nullable = true))))
@@ -1514,27 +1635,671 @@ class ColumnExpressionSuite extends QueryTest with SharedSparkSession {
       StructType(Seq(StructField("a", structType, nullable = true))))
 
     // extract newly added field
-    checkAnswerAndSchema(
+    checkAnswer(
       df.withColumn("a", $"a".withField("d", lit(4)).getField("d")),
       Row(4) :: Row(null) :: Nil,
       StructType(Seq(StructField("a", IntegerType, nullable = true))))
 
     // extract newly replaced field
-    checkAnswerAndSchema(
+    checkAnswer(
       df.withColumn("a", $"a".withField("a", lit(4)).getField("a")),
       Row(4) :: Row(null) :: Nil,
       StructType(Seq(StructField("a", IntegerType, nullable = true))))
 
     // add new field, extract another field from original struct
-    checkAnswerAndSchema(
+    checkAnswer(
       df.withColumn("a", $"a".withField("d", lit(4)).getField("c")),
       Row(3) :: Row(null) :: Nil,
       StructType(Seq(StructField("a", IntegerType, nullable = true))))
 
     // replace field, extract another field from original struct
-    checkAnswerAndSchema(
+    checkAnswer(
       df.withColumn("a", $"a".withField("a", lit(4)).getField("c")),
       Row(3) :: Row(null) :: Nil,
       StructType(Seq(StructField("a", IntegerType, nullable = true))))
   }
+
+
+  test("dropFields should throw an exception if called on a non-StructType column") {
+    intercept[AnalysisException] {
+      testData.withColumn("key", $"key".dropFields("a"))
+    }.getMessage should include("struct argument should be struct type, got: int")
+  }
+
+  test("dropFields should throw an exception if fieldName argument is null") {
+    intercept[IllegalArgumentException] {
+      structLevel1.withColumn("a", $"a".dropFields(null))
+    }.getMessage should include("fieldName cannot be null")
+  }
+
+  test("dropFields should throw an exception if any intermediate structs don't exist") {
+    intercept[AnalysisException] {
+      structLevel2.withColumn("a", 'a.dropFields("x.b"))
+    }.getMessage should include("No such struct field x in a")
+
+    intercept[AnalysisException] {
+      structLevel3.withColumn("a", 'a.dropFields("a.x.b"))
+    }.getMessage should include("No such struct field x in a")
+  }
+
+  test("dropFields should throw an exception if intermediate field is not a struct") {
+    intercept[AnalysisException] {
+      structLevel1.withColumn("a", 'a.dropFields("b.a"))
+    }.getMessage should include("struct argument should be struct type, got: int")
+  }
+
+  test("dropFields should throw an exception if intermediate field reference is ambiguous") {
+    intercept[AnalysisException] {
+      val structLevel2: DataFrame = spark.createDataFrame(
+        sparkContext.parallelize(Row(Row(Row(1, null, 3), 4)) :: Nil),
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", structType, nullable = false),
+            StructField("a", structType, nullable = false))),
+            nullable = false))))
+
+      structLevel2.withColumn("a", 'a.dropFields("a.b"))
+    }.getMessage should include("Ambiguous reference to fields")
+  }
+
+  test("dropFields should drop field in struct") {
+    checkAnswer(
+      structLevel1.withColumn("a", 'a.dropFields("b")),
+      Row(Row(1, 3)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("c", IntegerType, nullable = false))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop field in nullable struct") {
+    checkAnswer(
+      nullableStructLevel1.withColumn("a", $"a".dropFields("b")),
+      Row(null) :: Row(Row(1, 3)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("c", IntegerType, nullable = false))),
+          nullable = true))))
+  }
+
+  test("dropFields should drop multiple fields in struct") {
+    Seq(
+      structLevel1.withColumn("a", $"a".dropFields("b", "c")),
+      structLevel1.withColumn("a", 'a.dropFields("b").dropFields("c"))
+    ).foreach { df =>
+      checkAnswer(
+        df,
+        Row(Row(1)) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false))),
+            nullable = false))))
+    }
+  }
+
+  test("dropFields should throw an exception if no fields will be left in struct") {
+    intercept[AnalysisException] {
+      structLevel1.withColumn("a", 'a.dropFields("a", "b", "c"))
+    }.getMessage should include("cannot drop all fields in struct")
+  }
+
+  test("dropFields should drop field with no name in struct") {
+    val structType = StructType(Seq(
+      StructField("a", IntegerType, nullable = false),
+      StructField("", IntegerType, nullable = false)))
+
+    val structLevel1: DataFrame = spark.createDataFrame(
+      sparkContext.parallelize(Row(Row(1, 2)) :: Nil),
+      StructType(Seq(StructField("a", structType, nullable = false))))
+
+    checkAnswer(
+      structLevel1.withColumn("a", $"a".dropFields("")),
+      Row(Row(1)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop field in nested struct") {
+    checkAnswer(
+      structLevel2.withColumn("a", 'a.dropFields("a.b")),
+      Row(Row(Row(1, 3))) :: Nil,
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("c", IntegerType, nullable = false))),
+            nullable = false))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop multiple fields in nested struct") {
+    checkAnswer(
+      structLevel2.withColumn("a", 'a.dropFields("a.b", "a.c")),
+      Row(Row(Row(1))) :: Nil,
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false))),
+            nullable = false))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop field in nested nullable struct") {
+    checkAnswer(
+      nullableStructLevel2.withColumn("a", $"a".dropFields("a.b")),
+      Row(null) :: Row(Row(null)) :: Row(Row(Row(1, 3))) :: Nil,
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("c", IntegerType, nullable = false))),
+            nullable = true))),
+          nullable = true))))
+  }
+
+  test("dropFields should drop multiple fields in nested nullable struct") {
+    checkAnswer(
+      nullableStructLevel2.withColumn("a", $"a".dropFields("a.b", "a.c")),
+      Row(null) :: Row(Row(null)) :: Row(Row(Row(1))) :: Nil,
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false))),
+            nullable = true))),
+          nullable = true))))
+  }
+
+  test("dropFields should drop field in deeply nested struct") {
+    checkAnswer(
+      structLevel3.withColumn("a", 'a.dropFields("a.a.b")),
+      Row(Row(Row(Row(1, 3)))) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("c", IntegerType, nullable = false))),
+              nullable = false))),
+            nullable = false))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop all fields with given name in struct") {
+    val structLevel1 = spark.createDataFrame(
+      sparkContext.parallelize(Row(Row(1, 2, 3)) :: Nil),
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = false))),
+          nullable = false))))
+
+    checkAnswer(
+      structLevel1.withColumn("a", 'a.dropFields("b")),
+      Row(Row(1)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop field in struct even if casing is different") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+      checkAnswer(
+        mixedCaseStructLevel1.withColumn("a", 'a.dropFields("A")),
+        Row(Row(1)) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("B", IntegerType, nullable = false))),
+            nullable = false))))
+
+      checkAnswer(
+        mixedCaseStructLevel1.withColumn("a", 'a.dropFields("b")),
+        Row(Row(1)) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false))),
+            nullable = false))))
+    }
+  }
+
+  test("dropFields should not drop field in struct because casing is different") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      checkAnswer(
+        mixedCaseStructLevel1.withColumn("a", 'a.dropFields("A")),
+        Row(Row(1, 1)) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("B", IntegerType, nullable = false))),
+            nullable = false))))
+
+      checkAnswer(
+        mixedCaseStructLevel1.withColumn("a", 'a.dropFields("b")),
+        Row(Row(1, 1)) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("B", IntegerType, nullable = false))),
+            nullable = false))))
+    }
+  }
+
+  test("dropFields should drop nested field in struct even if casing is different") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+      checkAnswer(
+        mixedCaseStructLevel2.withColumn("a", 'a.dropFields("A.a")),
+        Row(Row(Row(1), Row(1, 1))) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("A", StructType(Seq(
+              StructField("b", IntegerType, nullable = false))),
+              nullable = false),
+            StructField("B", StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("b", IntegerType, nullable = false))),
+              nullable = false))),
+            nullable = false))))
+
+      checkAnswer(
+        mixedCaseStructLevel2.withColumn("a", 'a.dropFields("b.a")),
+        Row(Row(Row(1, 1), Row(1))) :: Nil,
+        StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", StructType(Seq(
+              StructField("a", IntegerType, nullable = false),
+              StructField("b", IntegerType, nullable = false))),
+              nullable = false),
+            StructField("b", StructType(Seq(
+              StructField("b", IntegerType, nullable = false))),
+              nullable = false))),
+            nullable = false))))
+    }
+  }
+
+  test("dropFields should throw an exception because casing is different") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      intercept[AnalysisException] {
+        mixedCaseStructLevel2.withColumn("a", 'a.dropFields("A.a"))
+      }.getMessage should include("No such struct field A in a, B")
+
+      intercept[AnalysisException] {
+        mixedCaseStructLevel2.withColumn("a", 'a.dropFields("b.a"))
+      }.getMessage should include("No such struct field b in a, B")
+    }
+  }
+
+  test("dropFields should drop only fields that exist") {
+    checkAnswer(
+      structLevel1.withColumn("a", 'a.dropFields("d")),
+      Row(Row(1, null, 3)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = true),
+          StructField("c", IntegerType, nullable = false))),
+          nullable = false))))
+
+    checkAnswer(
+      structLevel1.withColumn("a", 'a.dropFields("b", "d")),
+      Row(Row(1, 3)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("c", IntegerType, nullable = false))),
+          nullable = false))))
+
+    checkAnswer(
+      structLevel2.withColumn("a", $"a".dropFields("a.b", "a.d")),
+      Row(Row(Row(1, 3))) :: Nil,
+      StructType(
+        Seq(StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("c", IntegerType, nullable = false))),
+            nullable = false))),
+          nullable = false))))
+  }
+
+  test("dropFields should drop multiple fields at arbitrary levels of nesting in a single call") {
+    val df: DataFrame = spark.createDataFrame(
+      sparkContext.parallelize(Row(Row(Row(1, null, 3), 4)) :: Nil),
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", structType, nullable = false),
+          StructField("b", IntegerType, nullable = false))),
+          nullable = false))))
+
+    checkAnswer(
+      df.withColumn("a", $"a".dropFields("a.b", "b")),
+      Row(Row(Row(1, 3))) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("c", IntegerType, nullable = false))), nullable = false))),
+          nullable = false))))
+  }
+
+  test("dropFields user-facing examples") {
+    checkAnswer(
+      sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+        .select($"struct_col".dropFields("b")),
+      Row(Row(1)))
+
+    checkAnswer(
+      sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+        .select($"struct_col".dropFields("c")),
+      Row(Row(1, 2)))
+
+    checkAnswer(
+      sql("SELECT named_struct('a', 1, 'b', 2, 'c', 3) struct_col")
+        .select($"struct_col".dropFields("b", "c")),
+      Row(Row(1)))
+
+    intercept[AnalysisException] {
+      sql("SELECT named_struct('a', 1, 'b', 2) struct_col")
+        .select($"struct_col".dropFields("a", "b"))
+    }.getMessage should include("cannot drop all fields in struct")
+
+    checkAnswer(
+      sql("SELECT CAST(NULL AS struct<a:int,b:int>) struct_col")
+        .select($"struct_col".dropFields("b")),
+      Row(null))
+
+    checkAnswer(
+      sql("SELECT named_struct('a', 1, 'b', 2, 'b', 3) struct_col")
+        .select($"struct_col".dropFields("b")),
+      Row(Row(1)))
+
+    checkAnswer(
+      sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2)) struct_col")
+        .select($"struct_col".dropFields("a.b")),
+      Row(Row(Row(1))))
+
+    intercept[AnalysisException] {
+      sql("SELECT named_struct('a', named_struct('b', 1), 'a', named_struct('c', 2)) struct_col")
+        .select($"struct_col".dropFields("a.c"))
+    }.getMessage should include("Ambiguous reference to fields")
+
+    checkAnswer(
+      sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2, 'c', 3)) struct_col")
+        .select($"struct_col".dropFields("a.b", "a.c")),
+      Row(Row(Row(1))))
+
+    checkAnswer(
+      sql("SELECT named_struct('a', named_struct('a', 1, 'b', 2, 'c', 3)) struct_col")
+        .select($"struct_col".withField("a", $"struct_col.a".dropFields("b", "c"))),
+      Row(Row(Row(1))))
+  }
+
+  test("should correctly handle different dropField + withField + getField combinations") {
+    val structType = StructType(Seq(
+      StructField("a", IntegerType, nullable = false),
+      StructField("b", IntegerType, nullable = false)))
+
+    val structLevel1: DataFrame = spark.createDataFrame(
+      sparkContext.parallelize(Row(Row(1, 2)) :: Nil),
+      StructType(Seq(StructField("a", structType, nullable = false))))
+
+    val nullStructLevel1: DataFrame = spark.createDataFrame(
+      sparkContext.parallelize(Row(null) :: Nil),
+      StructType(Seq(StructField("a", structType, nullable = true))))
+
+    val nullableStructLevel1: DataFrame = spark.createDataFrame(
+      sparkContext.parallelize(Row(Row(1, 2)) :: Row(null) :: Nil),
+      StructType(Seq(StructField("a", structType, nullable = true))))
+
+    def check(
+        fieldOps: Column => Column,
+        getFieldName: String,
+        expectedValue: Option[Int]): Unit = {
+
+      def query(df: DataFrame): DataFrame =
+        df.select(fieldOps(col("a")).getField(getFieldName).as("res"))
+
+      checkAnswer(
+        query(structLevel1),
+        Row(expectedValue.orNull) :: Nil,
+        StructType(Seq(StructField("res", IntegerType, nullable = expectedValue.isEmpty))))
+
+      checkAnswer(
+        query(nullStructLevel1),
+        Row(null) :: Nil,
+        StructType(Seq(StructField("res", IntegerType, nullable = true))))
+
+      checkAnswer(
+        query(nullableStructLevel1),
+        Row(expectedValue.orNull) :: Row(null) :: Nil,
+        StructType(Seq(StructField("res", IntegerType, nullable = true))))
+    }
+
+    // add attribute, extract an attribute from the original struct
+    check(_.withField("c", lit(3)), "a", Some(1))
+    check(_.withField("c", lit(3)), "b", Some(2))
+
+    // add attribute, extract added attribute
+    check(_.withField("c", lit(3)), "c", Some(3))
+    check(_.withField("c", col("a.a")), "c", Some(1))
+    check(_.withField("c", col("a.b")), "c", Some(2))
+    check(_.withField("c", lit(null).cast(IntegerType)), "c", None)
+
+    // replace attribute, extract an attribute from the original struct
+    check(_.withField("b", lit(3)), "a", Some(1))
+    check(_.withField("a", lit(3)), "b", Some(2))
+
+    // replace attribute, extract replaced attribute
+    check(_.withField("b", lit(3)), "b", Some(3))
+    check(_.withField("b", lit(null).cast(IntegerType)), "b", None)
+    check(_.withField("a", lit(3)), "a", Some(3))
+    check(_.withField("a", lit(null).cast(IntegerType)), "a", None)
+
+    // drop attribute, extract an attribute from the original struct
+    check(_.dropFields("b"), "a", Some(1))
+    check(_.dropFields("a"), "b", Some(2))
+
+    // drop attribute, add attribute, extract an attribute from the original struct
+    check(_.dropFields("b").withField("c", lit(3)), "a", Some(1))
+    check(_.dropFields("a").withField("c", lit(3)), "b", Some(2))
+
+    // drop attribute, add another attribute, extract added attribute
+    check(_.dropFields("a").withField("c", lit(3)), "c", Some(3))
+    check(_.dropFields("b").withField("c", lit(3)), "c", Some(3))
+
+    // add attribute, drop attribute, extract an attribute from the original struct
+    check(_.withField("c", lit(3)).dropFields("a"), "b", Some(2))
+    check(_.withField("c", lit(3)).dropFields("b"), "a", Some(1))
+
+    // add attribute, drop another attribute, extract added attribute
+    check(_.withField("c", lit(3)).dropFields("a"), "c", Some(3))
+    check(_.withField("c", lit(3)).dropFields("b"), "c", Some(3))
+
+    // replace attribute, drop same attribute, extract an attribute from the original struct
+    check(_.withField("b", lit(3)).dropFields("b"), "a", Some(1))
+    check(_.withField("a", lit(3)).dropFields("a"), "b", Some(2))
+
+    // add attribute, drop same attribute, extract an attribute from the original struct
+    check(_.withField("c", lit(3)).dropFields("c"), "a", Some(1))
+    check(_.withField("c", lit(3)).dropFields("c"), "b", Some(2))
+
+    // replace attribute, drop another attribute, extract replaced attribute
+    check(_.withField("b", lit(3)).dropFields("a"), "b", Some(3))
+    check(_.withField("a", lit(3)).dropFields("b"), "a", Some(3))
+    check(_.withField("b", lit(null).cast(IntegerType)).dropFields("a"), "b", None)
+    check(_.withField("a", lit(null).cast(IntegerType)).dropFields("b"), "a", None)
+
+    // drop attribute, add same attribute, extract added attribute
+    check(_.dropFields("b").withField("b", lit(3)), "b", Some(3))
+    check(_.dropFields("a").withField("a", lit(3)), "a", Some(3))
+    check(_.dropFields("b").withField("b", lit(null).cast(IntegerType)), "b", None)
+    check(_.dropFields("a").withField("a", lit(null).cast(IntegerType)), "a", None)
+    check(_.dropFields("c").withField("c", lit(3)), "c", Some(3))
+
+    // add attribute, drop same attribute, add same attribute again, extract added attribute
+    check(_.withField("c", lit(3)).dropFields("c").withField("c", lit(4)), "c", Some(4))
+  }
+
+  test("should move field up one level of nesting") {
+    // move a field up one level
+    checkAnswer(
+      nullableStructLevel2.select(
+        col("a").withField("c", col("a.a.c")).dropFields("a.c").as("res")),
+      Row(null) :: Row(Row(null, null)) :: Row(Row(Row(1, null), 3)) :: Nil,
+      StructType(Seq(
+        StructField("res", StructType(Seq(
+          StructField("a", StructType(Seq(
+            StructField("a", IntegerType, nullable = false),
+            StructField("b", IntegerType, nullable = true))),
+            nullable = true),
+          StructField("c", IntegerType, nullable = true))),
+          nullable = true))))
+
+    // move a field up one level and then extract it
+    checkAnswer(
+      nullableStructLevel2.select(
+        col("a").withField("c", col("a.a.c")).dropFields("a.c").getField("c").as("res")),
+      Row(null) :: Row(null) :: Row(3) :: Nil,
+      StructType(Seq(StructField("res", IntegerType, nullable = true))))
+  }
+
+  test("should be able to refer to newly added nested column") {
+    intercept[AnalysisException] {
+      structLevel1.select($"a".withField("d", lit(4)).withField("e", $"a.d" + 1).as("a"))
+    }.getMessage should include("No such struct field d in a, b, c")
+
+    checkAnswer(
+      structLevel1
+        .select($"a".withField("d", lit(4)).as("a"))
+        .select($"a".withField("e", $"a.d" + 1).as("a")),
+      Row(Row(1, null, 3, 4, 5)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = true),
+          StructField("c", IntegerType, nullable = false),
+          StructField("d", IntegerType, nullable = false),
+          StructField("e", IntegerType, nullable = false))),
+          nullable = false))))
+  }
+
+  test("should be able to drop newly added nested column") {
+    Seq(
+      structLevel1.select($"a".withField("d", lit(4)).dropFields("d").as("a")),
+      structLevel1
+        .select($"a".withField("d", lit(4)).as("a"))
+        .select($"a".dropFields("d").as("a"))
+    ).foreach { query =>
+      checkAnswer(
+        query,
+        Row(Row(1, null, 3)) :: Nil,
+        StructType(Seq(
+          StructField("a", structType, nullable = false))))
+    }
+  }
+
+  test("should still be able to refer to dropped column within the same select statement") {
+    // we can still access the nested column even after dropping it within the same select statement
+    checkAnswer(
+      structLevel1.select($"a".dropFields("c").withField("z", $"a.c").as("a")),
+      Row(Row(1, null, 3)) :: Nil,
+      StructType(Seq(
+        StructField("a", StructType(Seq(
+          StructField("a", IntegerType, nullable = false),
+          StructField("b", IntegerType, nullable = true),
+          StructField("z", IntegerType, nullable = false))),
+          nullable = false))))
+
+    // we can't access the nested column in a subsequent select statement after dropping it in a
+    // previous select statement
+    intercept[AnalysisException] {
+      structLevel1
+        .select($"a".dropFields("c").as("a"))
+        .select($"a".withField("z", $"a.c").as("a"))
+    }.getMessage should include("No such struct field c in a, b;")
+  }
+
+  test("nestedDf should generate nested DataFrames") {
+    checkAnswer(
+      emptyNestedDf(1, 1, nullable = false),
+      Seq.empty[Row],
+      StructType(Seq(StructField("nested0Col0", StructType(Seq(
+        StructField("nested1Col0", IntegerType, nullable = false))),
+        nullable = false))))
+
+    checkAnswer(
+      emptyNestedDf(1, 2, nullable = false),
+      Seq.empty[Row],
+      StructType(Seq(StructField("nested0Col0", StructType(Seq(
+        StructField("nested1Col0", IntegerType, nullable = false),
+        StructField("nested1Col1", IntegerType, nullable = false))),
+        nullable = false))))
+
+    checkAnswer(
+      emptyNestedDf(2, 1, nullable = false),
+      Seq.empty[Row],
+      StructType(Seq(StructField("nested0Col0", StructType(Seq(
+        StructField("nested1Col0", StructType(Seq(
+          StructField("nested2Col0", IntegerType, nullable = false))),
+          nullable = false))),
+        nullable = false))))
+
+    checkAnswer(
+      emptyNestedDf(2, 2, nullable = false),
+      Seq.empty[Row],
+      StructType(Seq(StructField("nested0Col0", StructType(Seq(
+        StructField("nested1Col0", StructType(Seq(
+          StructField("nested2Col0", IntegerType, nullable = false),
+          StructField("nested2Col1", IntegerType, nullable = false))),
+          nullable = false),
+        StructField("nested1Col1", IntegerType, nullable = false))),
+        nullable = false))))
+
+    checkAnswer(
+      emptyNestedDf(2, 2, nullable = true),
+      Seq.empty[Row],
+      StructType(Seq(StructField("nested0Col0", StructType(Seq(
+        StructField("nested1Col0", StructType(Seq(
+          StructField("nested2Col0", IntegerType, nullable = false),
+          StructField("nested2Col1", IntegerType, nullable = false))),
+          nullable = true),
+        StructField("nested1Col1", IntegerType, nullable = false))),
+        nullable = true))))
+  }
+
+  Seq(Performant, NonPerformant).foreach { method =>
+    Seq(false, true).foreach { nullable =>
+      test(s"should add and drop 1 column at each depth of nesting using ${method.name} method, " +
+        s"nullable = $nullable") {
+        val maxDepth = 3
+
+        // dataframe with nested*Col0 to nested*Col2 at each depth
+        val inputDf = emptyNestedDf(maxDepth, 3, nullable)
+
+        // add nested*Col3 and drop nested*Col2
+        val modifiedColumn = method(
+          column = col(nestedColName(0, 0)),
+          numsToAdd = Seq(3),
+          numsToDrop = Seq(2),
+          maxDepth = maxDepth
+        ).as(nestedColName(0, 0))
+        val resultDf = inputDf.select(modifiedColumn)
+
+        // dataframe with nested*Col0, nested*Col1, nested*Col3 at each depth
+        val expectedDf = {
+          val colNums = Seq(0, 1, 3)
+          val nestedColumnDataType = nestedStructType(colNums, nullable, maxDepth)
+
+          spark.createDataFrame(
+            spark.sparkContext.emptyRDD[Row],
+            StructType(Seq(StructField(nestedColName(0, 0), nestedColumnDataType, nullable))))
+        }
+
+        checkAnswer(resultDf, expectedDf.collect(), expectedDf.schema)
+      }
+    }
+  }
 }
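Note: the tests above pass an expected schema as a third argument to checkAnswer. That overload is
not shown in this diff; a minimal sketch of what such a helper could look like, assuming it sits
alongside the existing row-only checkAnswer in QueryTest (the exact name and placement here are
assumptions):

    // Verifies both the result rows and the result schema, including nullability.
    protected def checkAnswer(
        df: => DataFrame,
        expectedAnswer: Seq[Row],
        expectedSchema: StructType): Unit = {
      checkAnswer(df, expectedAnswer)      // reuse the existing row-level comparison
      assert(df.schema == expectedSchema)  // then require an exact schema match
    }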
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsBenchmark.scala
new file mode 100644
index 000000000000..28af552fe586
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UpdateFieldsBenchmark.scala
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.execution.benchmark.SqlBasedBenchmark
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+
+/**
+ * Benchmark to measure Spark's performance analyzing and optimizing long UpdateFields chains.
+ *
+ * {{{
+ * To run this benchmark:
+ *   1. without sbt:
+ *      bin/spark-submit --class <this class> <spark sql test jar>
+ *   2. with sbt:
+ *      build/sbt "sql/test:runMain <this class>"
+ *   3. generate result:
+ *      SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain <this class>"
+ *      Results will be written to "benchmarks/UpdateFieldsBenchmark-results.txt".
+ * }}}
+ */
+object UpdateFieldsBenchmark extends SqlBasedBenchmark {
+
+  def nestedColName(d: Int, colNum: Int): String = s"nested${d}Col$colNum"
+
+  def nestedStructType(
+      colNums: Seq[Int],
+      nullable: Boolean,
+      maxDepth: Int,
+      currDepth: Int = 1): StructType = {
+
+    if (currDepth == maxDepth) {
+      val fields = colNums.map { colNum =>
+        val name = nestedColName(currDepth, colNum)
+        StructField(name, IntegerType, nullable = false)
+      }
+      StructType(fields)
+    } else {
+      val fields = colNums.foldLeft(Seq.empty[StructField]) {
+        case (structFields, colNum) if colNum == 0 =>
+          val nested = nestedStructType(colNums, nullable, maxDepth, currDepth + 1)
+          structFields :+ StructField(nestedColName(currDepth, colNum), nested, nullable)
+        case (structFields, colNum) =>
+          val name = nestedColName(currDepth, colNum)
+          structFields :+ StructField(name, IntegerType, nullable = false)
+      }
+      StructType(fields)
+    }
+  }
+
+  /**
+   * Utility function for generating an empty DataFrame with nested columns.
+   *
+   * @param maxDepth: The depth to which to create nested columns.
+   * @param numColsAtEachDepth: The number of columns to create at each depth.
+   * @param nullable: This value is used to set the nullability of any StructType columns.
+   */
+  def emptyNestedDf(maxDepth: Int, numColsAtEachDepth: Int, nullable: Boolean): DataFrame = {
+    require(maxDepth > 0)
+    require(numColsAtEachDepth > 0)
+
+    val nestedColumnDataType = nestedStructType(0 until numColsAtEachDepth, nullable, maxDepth)
+    spark.createDataFrame(
+      spark.sparkContext.emptyRDD[Row],
+      StructType(Seq(StructField(nestedColName(0, 0), nestedColumnDataType, nullable))))
+  }
+
+  trait ModifyNestedColumns {
+    val name: String
+    def apply(column: Column, numsToAdd: Seq[Int], numsToDrop: Seq[Int], maxDepth: Int): Column
+  }
+
+  object Performant extends ModifyNestedColumns {
+    override val name: String = "performant"
+
+    override def apply(
+        column: Column,
+        numsToAdd: Seq[Int],
+        numsToDrop: Seq[Int],
+        maxDepth: Int): Column = helper(column, numsToAdd, numsToDrop, maxDepth, 1)
+
+    private def helper(
+        column: Column,
+        numsToAdd: Seq[Int],
+        numsToDrop: Seq[Int],
+        maxDepth: Int,
+        currDepth: Int): Column = {
+
+      // drop columns at the current depth
+      val dropped = if (numsToDrop.nonEmpty) {
+        column.dropFields(numsToDrop.map(num => nestedColName(currDepth, num)): _*)
+      } else column
+
+      // add columns at the current depth
+      val added = numsToAdd.foldLeft(dropped) {
+        (res, num) => res.withField(nestedColName(currDepth, num), lit(num))
+      }
+
+      if (currDepth == maxDepth) {
+        added
+      } else {
+        // add/drop columns at the next depth
+        val newValue = helper(
+          column = col((0 to currDepth).map(d => nestedColName(d, 0)).mkString(".")),
+          numsToAdd = numsToAdd,
+          numsToDrop = numsToDrop,
+          currDepth = currDepth + 1,
+          maxDepth = maxDepth)
+        added.withField(nestedColName(currDepth, 0), newValue)
+      }
+    }
+  }
+
+  object NonPerformant extends ModifyNestedColumns {
+    override val name: String = "non-performant"
+
+    override def apply(
+        column: Column,
+        numsToAdd: Seq[Int],
+        numsToDrop: Seq[Int],
+        maxDepth: Int): Column = {
+
+      val dropped = if (numsToDrop.nonEmpty) {
+        val colsToDrop = (1 to maxDepth).flatMap { depth =>
+          numsToDrop.map(num => s"${prefix(depth)}${nestedColName(depth, num)}")
+        }
+        column.dropFields(colsToDrop: _*)
+      } else column
+
+      val added = {
+        val colsToAdd = (1 to maxDepth).flatMap { depth =>
+          numsToAdd.map(num => (s"${prefix(depth)}${nestedColName(depth, num)}", lit(num)))
+        }
+        colsToAdd.foldLeft(dropped)((col, add) => col.withField(add._1, add._2))
+      }
+
+      added
+    }
+
+    private def prefix(depth: Int): String =
+      if (depth == 1) ""
+      else (1 until depth).map(d => nestedColName(d, 0)).mkString("", ".", ".")
+  }
+
+  private def updateFieldsBenchmark(
+      methods: Seq[ModifyNestedColumns],
+      maxDepth: Int,
+      initialNumberOfColumns: Int,
+      numsToAdd: Seq[Int] = Seq.empty,
+      numsToDrop: Seq[Int] = Seq.empty): Unit = {
+
+    val name = s"Add ${numsToAdd.length} columns and drop ${numsToDrop.length} columns " +
+      s"at $maxDepth different depths of nesting"
+
+    runBenchmark(name) {
+      val benchmark = new Benchmark(
+        name = name,
+        // The purpose of this benchmark is to ensure Spark is able to analyze and optimize long
+        // UpdateFields chains quickly, so it runs over 0 rows of data.
+        valuesPerIteration = 0,
+        output = output)
+
+      val nonNullableStructsDf = emptyNestedDf(maxDepth, initialNumberOfColumns, nullable = false)
+      val nullableStructsDf = emptyNestedDf(maxDepth, initialNumberOfColumns, nullable = true)
+
+      methods.foreach { method =>
+        val modifiedColumn = method(
+          column = col(nestedColName(0, 0)),
+          numsToAdd = numsToAdd,
+          numsToDrop = numsToDrop,
+          maxDepth = maxDepth
+        ).as(nestedColName(0, 0))
+
+        benchmark.addCase(s"To non-nullable StructTypes using ${method.name} method") { _ =>
+          nonNullableStructsDf.select(modifiedColumn).queryExecution.optimizedPlan
+        }
+
+        benchmark.addCase(s"To nullable StructTypes using ${method.name} method") { _ =>
+          nullableStructsDf.select(modifiedColumn).queryExecution.optimizedPlan
+        }
+      }
+
+      benchmark.run()
+    }
+  }
+
+  override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+    // This benchmark compares the performant and non-performant methods of writing the same query.
+    // We use small values for maxDepth, numsToAdd, and numsToDrop because the NonPerformant method
+    // scales extremely poorly with the number of nested columns being added/dropped.
+    updateFieldsBenchmark(
+      methods = Seq(Performant, NonPerformant),
+      maxDepth = 3,
+      initialNumberOfColumns = 5,
+      numsToAdd = 5 to 6,
+      numsToDrop = 3 to 4)
+
+    // This benchmark shows that the performant method of writing a query scales nicely even when
+    // we want to add and drop a large number of nested columns.
+    updateFieldsBenchmark(
+      methods = Seq(Performant),
+      maxDepth = 100,
+      initialNumberOfColumns = 51,
+      numsToAdd = 51 to 100,
+      numsToDrop = 1 to 50)
+  }
+}
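Note: to make the gap between the two benchmarked styles concrete, here is the column each method
builds for maxDepth = 2 when adding column 3 and dropping column 2 at every depth (a sketch derived
from the Performant and NonPerformant helpers above, not part of the patch). The performant style
rewrites each struct level exactly once, nesting the deeper update inside the shallower one; the
non-performant style addresses every nested field by its dotted path from the root, so each extra
call layers yet another rewrite of the whole top-level struct:

    // Performant: one chained rewrite per depth of nesting.
    col("nested0Col0")
      .dropFields("nested1Col2")
      .withField("nested1Col3", lit(3))
      .withField("nested1Col0",
        col("nested0Col0.nested1Col0")
          .dropFields("nested2Col2")
          .withField("nested2Col3", lit(3)))

    // Non-performant: dotted paths from the root; every withField/dropFields call
    // here expands into another rewrite of the entire top-level struct.
    col("nested0Col0")
      .dropFields("nested1Col2", "nested1Col0.nested2Col2")
      .withField("nested1Col3", lit(3))
      .withField("nested1Col0.nested2Col3", lit(3))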