diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 0ba150ec1efb..4264627e0d9b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.objects._
+import org.apache.spark.sql.catalyst.optimizer.OptimizeUpdateFields
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
@@ -207,6 +208,11 @@ class Analyzer(
   lazy val batches: Seq[Batch] = Seq(
     Batch("Substitution", fixedPoint,
+      // This rule optimizes `UpdateFields` expression chains, so it looks more like an
+      // optimization rule. However, when manipulating deeply nested schemas, the
+      // `UpdateFields` expression tree could become very complex and make analysis
+      // impossible. Thus we need to optimize `UpdateFields` at the beginning of analysis.
+      OptimizeUpdateFields,
       CTESubstitution,
       WindowsSubstitution,
       EliminateUnions,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
index c1a9c9d3d9ba..b08e116642ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveUnion.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.optimizer.CombineUnions
+import org.apache.spark.sql.catalyst.optimizer.{CombineUnions, OptimizeUpdateFields}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.internal.SQLConf
@@ -88,13 +88,6 @@ object ResolveUnion extends Rule[LogicalPlan] {
     }
   }
 
-  def simplifyWithFields(expr: Expression): Expression = {
-    expr.transformUp {
-      case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) =>
-        UpdateFields(struct, fieldOps1 ++ fieldOps2)
-    }
-  }
-
   /**
    * Adds missing fields recursively into given `col` expression, based on the target `StructType`.
    * This is called by `compareAndAddFields` when we find two struct columns with same name but
@@ -119,7 +112,7 @@ object ResolveUnion extends Rule[LogicalPlan] {
     missingFieldsOpt.map { s =>
      val struct = addFieldsInto(col, s.fields)
      // Combines `WithFields`s to reduce expression tree.
-      val reducedStruct = simplifyWithFields(struct)
+      val reducedStruct = struct.transformUp(OptimizeUpdateFields.optimizeUpdateFields)
       val sorted = sortStructFieldsInWithFields(reducedStruct)
       sorted
     }.get
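Note on the two changes above: `ResolveUnion` now reuses the same partial function that the new rule applies, via a bottom-up `transformUp`. Below is a minimal, self-contained sketch of why a single bottom-up pass collapses an arbitrarily deep chain of nested update nodes; it uses a toy AST, not Catalyst's `Expression` hierarchy, and all `Toy*` names are made up for illustration.

object CollapseDemo extends App {
  sealed trait ToyExpr
  case class ToyLeaf(name: String) extends ToyExpr
  case class ToyOp(field: String) extends ToyExpr
  case class ToyUpdate(child: ToyExpr, ops: Seq[ToyOp]) extends ToyExpr

  // Bottom-up transform: rewrite children first, then apply the rule at this node.
  def transformUp(e: ToyExpr)(rule: PartialFunction[ToyExpr, ToyExpr]): ToyExpr = {
    val rewritten = e match {
      case ToyUpdate(child, ops) => ToyUpdate(transformUp(child)(rule), ops)
      case other => other
    }
    rule.applyOrElse(rewritten, identity[ToyExpr])
  }

  // Analogue of the adjacent-combining case in OptimizeUpdateFields:
  // UpdateFields(UpdateFields(struct, ops1), ops2) => UpdateFields(struct, ops1 ++ ops2)
  val combine: PartialFunction[ToyExpr, ToyExpr] = {
    case ToyUpdate(ToyUpdate(struct, ops1), ops2) => ToyUpdate(struct, ops1 ++ ops2)
  }

  val nested = ToyUpdate(
    ToyUpdate(ToyUpdate(ToyLeaf("a"), Seq(ToyOp("x"))), Seq(ToyOp("y"))), Seq(ToyOp("z")))

  // One pass yields a single node: ToyUpdate(ToyLeaf(a), List(ToyOp(x), ToyOp(y), ToyOp(z)))
  println(transformUp(nested)(combine))
}

Because the children are rewritten before the rule fires at each node, the chain stays collapsed as the pass moves upward; the same intuition explains the Analyzer placement, since collapsing chains first keeps the tree small for the rest of analysis.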
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
index 860219e55b05..2ac8f62b67b3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ComplexTypes.scala
@@ -46,7 +46,12 @@ object SimplifyExtractValueOps extends Rule[LogicalPlan] {
       // if the struct itself is null, then any value extracted from it (expr) will be null
       // so we don't need to wrap expr in If(IsNull(struct), Literal(null, expr.dataType), expr)
       case expr: GetStructField if expr.child.semanticEquals(structExpr) => expr
-      case expr => If(IsNull(structExpr), Literal(null, expr.dataType), expr)
+      case expr =>
+        if (structExpr.nullable) {
+          If(IsNull(structExpr), Literal(null, expr.dataType), expr)
+        } else {
+          expr
+        }
     }
     // Remove redundant array indexing.
     case GetArrayStructFields(CreateArray(elems, useStringTypeWhenEmpty), field, ordinal, _, _) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 7586bdf4392f..3e9a97419682 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -109,7 +109,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       RemoveRedundantAliases,
       UnwrapCastInBinaryComparison,
       RemoveNoopOperators,
-      CombineUpdateFields,
+      OptimizeUpdateFields,
       SimplifyExtractValueOps,
       OptimizeJsonExprs,
       CombineConcats) ++
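The `SimplifyExtractValueOps` change above adds the `IsNull` guard only when the struct can actually be null; for a non-nullable struct the extracted value can be used directly. Here is a toy model of that decision, using hypothetical stand-in classes rather than Catalyst's expression types:

object ExtractFoldDemo extends App {
  sealed trait Expr
  case class Lit(v: Int) extends Expr
  case class StructRef(nullable: Boolean) extends Expr
  case class Update(struct: Expr, field: String, value: Expr) extends Expr
  case class GetField(struct: Expr, field: String) extends Expr
  // Stand-in for If(IsNull(struct), Literal(null, ...), value).
  case class NullGuard(struct: Expr, value: Expr) extends Expr

  def nullable(e: Expr): Boolean = e match {
    case StructRef(n) => n
    case Update(s, _, _) => nullable(s)
    case _ => false
  }

  // Extracting a field that was just written folds to the written value;
  // the guard is added only when the underlying struct is nullable.
  def simplify(e: Expr): Expr = e match {
    case GetField(Update(struct, f1, value), f2) if f1 == f2 =>
      if (nullable(struct)) NullGuard(struct, value) else value
    case other => other
  }

  println(simplify(GetField(Update(StructRef(nullable = false), "b1", Lit(4)), "b1"))) // Lit(4)
  println(simplify(GetField(Update(StructRef(nullable = true), "b1", Lit(4)), "b1")))  // NullGuard(...)
}

This is also why the new `GetStructField` test further below uses a non-nullable relation: the fold can then produce a bare `Literal(4)` with no `If` wrapper left behind.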
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala
index c7154210e0c6..465d2efe2775 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UpdateFields.scala
@@ -17,19 +17,68 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.catalyst.expressions.UpdateFields
+import java.util.Locale
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{Expression, UpdateFields, WithField}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.SQLConf
 
 /**
- * Combines all adjacent [[UpdateFields]] expression into a single [[UpdateFields]] expression.
+ * Optimizes [[UpdateFields]] expression chains.
  */
-object CombineUpdateFields extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+object OptimizeUpdateFields extends Rule[LogicalPlan] {
+  private def canOptimize(names: Seq[String]): Boolean = {
+    if (SQLConf.get.caseSensitiveAnalysis) {
+      names.distinct.length != names.length
+    } else {
+      names.map(_.toLowerCase(Locale.ROOT)).distinct.length != names.length
+    }
+  }
+
+  val optimizeUpdateFields: PartialFunction[Expression, Expression] = {
+    case UpdateFields(structExpr, fieldOps)
+      if fieldOps.forall(_.isInstanceOf[WithField]) &&
+        canOptimize(fieldOps.map(_.asInstanceOf[WithField].name)) =>
+      val caseSensitive = SQLConf.get.caseSensitiveAnalysis
+
+      val withFields = fieldOps.map(_.asInstanceOf[WithField])
+      val names = withFields.map(_.name)
+      val values = withFields.map(_.valExpr)
+
+      val newNames = mutable.ArrayBuffer.empty[String]
+      val newValues = mutable.ArrayBuffer.empty[Expression]
+
+      if (caseSensitive) {
+        names.zip(values).reverse.foreach { case (name, value) =>
+          if (!newNames.contains(name)) {
+            newNames += name
+            newValues += value
+          }
+        }
+      } else {
+        val nameSet = mutable.HashSet.empty[String]
+        names.zip(values).reverse.foreach { case (name, value) =>
+          val lowercaseName = name.toLowerCase(Locale.ROOT)
+          if (!nameSet.contains(lowercaseName)) {
+            newNames += name
+            newValues += value
+            nameSet += lowercaseName
+          }
+        }
+      }
+
+      val newWithFields = newNames.reverse.zip(newValues.reverse).map(p => WithField(p._1, p._2))
+      UpdateFields(structExpr, newWithFields.toSeq)
+
     case UpdateFields(UpdateFields(struct, fieldOps1), fieldOps2) =>
       UpdateFields(struct, fieldOps1 ++ fieldOps2)
   }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions(optimizeUpdateFields)
 }
 
 /**
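The heart of the new rule is last-write-wins deduplication of `WithField` ops: it walks the (name, value) pairs in reverse so the last assignment to each field is kept, then reverses back to preserve the original order. A standalone sketch of just that step follows; `Int` values stand in for the value expressions, and `dedupLastWins` is a made-up helper, not Spark API:

import java.util.Locale
import scala.collection.mutable

object DedupDemo extends App {
  def dedupLastWins(ops: Seq[(String, Int)], caseSensitive: Boolean): Seq[(String, Int)] = {
    val seen = mutable.HashSet.empty[String]
    ops.reverse.filter { case (name, _) =>
      val key = if (caseSensitive) name else name.toLowerCase(Locale.ROOT)
      seen.add(key) // true only for the first (i.e. last-in-original-order) occurrence
    }.reverse
  }

  // Mirrors the new tests below: b1 := 4 then B1 := 5.
  println(dedupLastWins(Seq("b1" -> 4, "B1" -> 5), caseSensitive = false)) // List((B1,5))
  println(dedupLastWins(Seq("b1" -> 4, "B1" -> 5), caseSensitive = true))  // List((b1,4), (B1,5))
}

Note that this case only fires when every op is a `WithField` and `canOptimize` actually finds a duplicate name; a chain containing `DropField` ops is left to the plain adjacent-combining case.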
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineUpdateFieldsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala
similarity index 51%
rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineUpdateFieldsSuite.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala
index ff9c60a2fa5b..b093b39cc4b8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineUpdateFieldsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeWithFieldsSuite.scala
@@ -19,19 +19,21 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, UpdateFields, WithField}
+import org.apache.spark.sql.catalyst.expressions.{Alias, GetStructField, Literal, UpdateFields, WithField}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.internal.SQLConf
 
-
-class CombineUpdateFieldsSuite extends PlanTest {
+class OptimizeWithFieldsSuite extends PlanTest {
   object Optimize extends RuleExecutor[LogicalPlan] {
-    val batches = Batch("CombineUpdateFields", FixedPoint(10), CombineUpdateFields) :: Nil
+    val batches = Batch("OptimizeUpdateFields", FixedPoint(10),
+      OptimizeUpdateFields, SimplifyExtractValueOps) :: Nil
   }
 
   private val testRelation = LocalRelation('a.struct('a1.int))
+  private val testRelation2 = LocalRelation('a.struct('a1.int).notNull)
 
   test("combines two adjacent UpdateFields Expressions") {
     val originalQuery = testRelation
@@ -70,4 +72,58 @@ class CombineUpdateFieldsSuite extends PlanTest {
 
     comparePlans(optimized, correctAnswer)
   }
+
+  test("SPARK-32941: optimize WithFields followed by GetStructField") {
+    val originalQuery = testRelation2
+      .select(Alias(
+        GetStructField(UpdateFields('a,
+          WithField("b1", Literal(4)) :: Nil), 1), "out")())
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation2
+      .select(Alias(Literal(4), "out")())
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("SPARK-32941: optimize WithFields chain - case insensitive") {
+    val originalQuery = testRelation
+      .select(
+        Alias(UpdateFields('a,
+          WithField("b1", Literal(4)) :: WithField("b1", Literal(5)) :: Nil), "out1")(),
+        Alias(UpdateFields('a,
+          WithField("b1", Literal(4)) :: WithField("B1", Literal(5)) :: Nil), "out2")())
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .select(
+        Alias(UpdateFields('a, WithField("b1", Literal(5)) :: Nil), "out1")(),
+        Alias(UpdateFields('a, WithField("B1", Literal(5)) :: Nil), "out2")())
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("SPARK-32941: optimize WithFields chain - case sensitive") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
+      val originalQuery = testRelation
+        .select(
+          Alias(UpdateFields('a,
+            WithField("b1", Literal(4)) :: WithField("b1", Literal(5)) :: Nil), "out1")(),
+          Alias(UpdateFields('a,
+            WithField("b1", Literal(4)) :: WithField("B1", Literal(5)) :: Nil), "out2")())
+
+      val optimized = Optimize.execute(originalQuery.analyze)
+      val correctAnswer = testRelation
+        .select(
+          Alias(UpdateFields('a, WithField("b1", Literal(5)) :: Nil), "out1")(),
+          Alias(
+            UpdateFields('a,
+              WithField("b1", Literal(4)) :: WithField("B1", Literal(5)) :: Nil), "out2")())
+        .analyze
+
+      comparePlans(optimized, correctAnswer)
+    }
+  }
 }
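For context, here is how such duplicate-write chains typically arise from the user-facing API. This sketch is not part of the patch: it assumes the `Column.withField` API available from Spark 3.1 onward, and the data and column names are made up. Each `withField` call wraps the column in another `UpdateFields` node, which the new rule collapses during analysis:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object WithFieldChainDemo extends App {
  val spark = SparkSession.builder().master("local[1]").appName("demo").getOrCreate()

  val df = spark.range(1).select(struct(lit(1).as("a1"), lit(2).as("b1")).as("a"))

  // Two writes to the same field: after OptimizeUpdateFields, only b1 := 5 survives,
  // mirroring the "case insensitive" test above.
  val updated = df.select(col("a").withField("b1", lit(4)).withField("b1", lit(5)).as("out"))

  updated.explain(extended = true) // the analyzed plan should already show a single UpdateFields
  spark.stop()
}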
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
index d9cefdaf3fe7..9878969959bf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/complexTypesSuite.scala
@@ -44,7 +44,7 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
       BooleanSimplification,
       SimplifyConditionals,
       SimplifyBinaryComparison,
-      CombineUpdateFields,
+      OptimizeUpdateFields,
       SimplifyExtractValueOps) :: Nil
   }
 
@@ -698,7 +698,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
     val expected = structLevel2.select(
       UpdateFields('a1, Seq(
         // scalastyle:off line.size.limit
-        WithField("a2", UpdateFields(GetStructField('a1, 0), WithField("b3", 2) :: Nil)),
         WithField("a2", UpdateFields(GetStructField('a1, 0), WithField("b3", 2) :: WithField("c3", 3) :: Nil))
         // scalastyle:on line.size.limit
       )).as("a1"))
@@ -732,7 +731,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
     structLevel2.select(
       UpdateFields('a1, Seq(
-        WithField("a2", repeatedExpr),
         WithField("a2", UpdateFields(
           If(IsNull('a1), Literal(null, repeatedExprDataType), repeatedExpr),
           WithField("c3", Literal(3)) :: Nil))
@@ -763,7 +761,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
     val expected = structLevel2.select(
       UpdateFields('a1, Seq(
-        WithField("a2",
-          UpdateFields(GetStructField('a1, 0), Seq(DropField("b3")))),
         WithField("a2",
           UpdateFields(GetStructField('a1, 0), Seq(DropField("b3"), DropField("c3"))))
       )).as("a1"))
@@ -797,7 +794,6 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
     structLevel2.select(
       UpdateFields('a1, Seq(
-        WithField("a2", repeatedExpr),
         WithField("a2", UpdateFields(
           If(IsNull('a1), Literal(null, repeatedExprDataType), repeatedExpr),
           DropField("c3") :: Nil))
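The deletions in the four expected plans above all follow the same pattern: each plan used to contain two consecutive `WithField("a2", ...)` writes, and with last-write-wins deduplication only the final, fuller update to `a2` survives. A short schematic of that reduction (strings stand in for the value expressions; the helper is made up for illustration):

object LastWriteWinsDemo extends App {
  // Two consecutive writes to "a2", as in the old expected plans.
  val writes = Seq(
    "a2" -> "UpdateFields(GetStructField('a1, 0), WithField(b3))",
    "a2" -> "UpdateFields(GetStructField('a1, 0), WithField(b3), WithField(c3))")

  // Keep only the last write to each field, at the position of that last write.
  val reduced = writes.foldLeft(Vector.empty[(String, String)]) { (acc, w) =>
    acc.filterNot(_._1 == w._1) :+ w
  }

  // Vector((a2,UpdateFields(GetStructField('a1, 0), WithField(b3), WithField(c3))))
  println(reduced)
}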