diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
index eabbc7fc74f5..bf7491625fa0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala
@@ -439,9 +439,7 @@ case class InMemoryRelation(
   override def innerChildren: Seq[SparkPlan] = Seq(cachedPlan)
 
   override def doCanonicalize(): logical.LogicalPlan =
-    copy(output = output.map(QueryPlan.normalizeExpressions(_, output)),
-      cacheBuilder,
-      outputOrdering)
+    withOutput(output.map(QueryPlan.normalizeExpressions(_, output)))
 
   @transient val partitionStatistics = new PartitionStatistics(output)
 
@@ -469,8 +467,13 @@ case class InMemoryRelation(
     }
   }
 
-  def withOutput(newOutput: Seq[Attribute]): InMemoryRelation =
-    InMemoryRelation(newOutput, cacheBuilder, outputOrdering, statsOfPlanToCache)
+  def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = {
+    val map = AttributeMap(output.zip(newOutput))
+    val newOutputOrdering = outputOrdering
+      .map(_.transform { case a: Attribute => map(a) })
+      .asInstanceOf[Seq[SortOrder]]
+    InMemoryRelation(newOutput, cacheBuilder, newOutputOrdering, statsOfPlanToCache)
+  }
 
   override def newInstance(): this.type = {
     InMemoryRelation(
@@ -487,6 +490,12 @@ case class InMemoryRelation(
     cloned
   }
 
+  override def makeCopy(newArgs: Array[AnyRef]): LogicalPlan = {
+    val copied = super.makeCopy(newArgs).asInstanceOf[InMemoryRelation]
+    copied.statsOfPlanToCache = this.statsOfPlanToCache
+    copied
+  }
+
   override def simpleString(maxFields: Int): String =
     s"InMemoryRelation [${truncatedString(output, ", ", maxFields)}], ${cacheBuilder.storageLevel}"
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
index 280fe1068d81..4493d1a6e689 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Attribute, AttributeMap, AttributeSet, BitwiseAnd, Empty2Null, Expression, HiveHash, Literal, NamedExpression, Pmod, SortOrder}
+import org.apache.spark.sql.catalyst.optimizer.{EliminateSorts, FoldablePropagation}
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Sort}
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -97,13 +98,15 @@ object V1Writes extends Rule[LogicalPlan] {
     assert(empty2NullPlan.output.length == query.output.length)
     val attrMap = AttributeMap(query.output.zip(empty2NullPlan.output))
 
-    // Rewrite the attribute references in the required ordering to use the new output.
-    val requiredOrdering = write.requiredOrdering.map(_.transform {
-      case a: Attribute => attrMap.getOrElse(a, a)
-    }.asInstanceOf[SortOrder])
-    val outputOrdering = empty2NullPlan.outputOrdering
-    val orderingMatched = isOrderingMatched(requiredOrdering.map(_.child), outputOrdering)
-    if (orderingMatched) {
+    // Rewrite the attribute references in the required ordering to use the new output,
+    // then eliminate foldable ordering expressions.
+    val requiredOrdering = {
+      val ordering = write.requiredOrdering.map(_.transform {
+        case a: Attribute => attrMap.getOrElse(a, a)
+      }.asInstanceOf[SortOrder])
+      eliminateFoldableOrdering(ordering, empty2NullPlan).outputOrdering
+    }
+    if (isOrderingMatched(requiredOrdering.map(_.child), empty2NullPlan.outputOrdering)) {
       empty2NullPlan
     } else {
       Sort(requiredOrdering, global = false, empty2NullPlan)
@@ -199,6 +202,15 @@ object V1WritesUtils {
     expressions.exists(_.exists(_.isInstanceOf[Empty2Null]))
   }
 
+  // SPARK-53738: the required ordering inferred from the table spec (partitioning, bucketing,
+  // etc.) may contain foldable sort ordering expressions, which causes a mismatch with the
+  // optimized query's output ordering. Here we compute the required ordering more accurately
+  // by creating a fake Sort node over the input query and removing the foldable sort orders.
+  def eliminateFoldableOrdering(ordering: Seq[SortOrder], query: LogicalPlan): LogicalPlan =
+    EliminateSorts(FoldablePropagation(Sort(ordering, global = false, query)))
+
+  // The comparison ignores SortDirection and NullOrdering since they don't matter
+  // for writing cases.
   def isOrderingMatched(
       requiredOrdering: Seq[Expression],
       outputOrdering: Seq[SortOrder]): Boolean = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
index 80d771428d90..a46afcef3cdb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/V1WriteCommandSuite.scala
@@ -63,10 +63,23 @@ trait V1WriteCommandSuiteBase extends SQLTestUtils with AdaptiveSparkPlanHelper
       hasLogicalSort: Boolean,
       orderingMatched: Boolean,
       hasEmpty2Null: Boolean = false)(query: => Unit): Unit = {
-    var optimizedPlan: LogicalPlan = null
+    executeAndCheckOrderingAndCustomValidate(
+      hasLogicalSort, Some(orderingMatched), hasEmpty2Null)(query)(_ => ())
+  }
+
+  /**
+   * Execute a write query and check the ordering of the plan, then run custom validation.
+   */
+  protected def executeAndCheckOrderingAndCustomValidate(
+      hasLogicalSort: Boolean,
+      orderingMatched: Option[Boolean],
+      hasEmpty2Null: Boolean = false)(query: => Unit)(
+      customValidate: LogicalPlan => Unit): Unit = {
+    @volatile var optimizedPlan: LogicalPlan = null
 
     val listener = new QueryExecutionListener {
       override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+        val conf = qe.sparkSession.sessionState.conf
         qe.optimizedPlan match {
           case w: V1WriteCommand =>
             if (hasLogicalSort && conf.getConf(SQLConf.PLANNED_WRITE_ENABLED)) {
@@ -85,9 +98,12 @@ trait V1WriteCommandSuiteBase extends SQLTestUtils with AdaptiveSparkPlanHelper
 
     query
 
-    // Check whether the output ordering is matched before FileFormatWriter executes rdd.
-    assert(FileFormatWriter.outputOrderingMatched == orderingMatched,
-      s"Expect: $orderingMatched, Actual: ${FileFormatWriter.outputOrderingMatched}")
+    orderingMatched.foreach { matched =>
+      // Check whether the output ordering is matched before FileFormatWriter executes rdd.
+      assert(FileFormatWriter.outputOrderingMatched == matched,
+        s"Expect orderingMatched: $matched, " +
+          s"Actual: ${FileFormatWriter.outputOrderingMatched}")
+    }
 
     sparkContext.listenerBus.waitUntilEmpty()
 
@@ -103,6 +119,8 @@ trait V1WriteCommandSuiteBase extends SQLTestUtils with AdaptiveSparkPlanHelper
     assert(empty2nullExpr == hasEmpty2Null,
       s"Expect hasEmpty2Null: $hasEmpty2Null, Actual: $empty2nullExpr. Plan:\n$optimizedPlan")
 
+    customValidate(optimizedPlan)
+
     spark.listenerManager.unregister(listener)
   }
 }
@@ -391,4 +409,33 @@ class V1WriteCommandSuite extends QueryTest with SharedSparkSession with V1Write
       }
     }
   }
+
+  test("v1 write with sort by literal column preserves custom order") {
+    withPlannedWrite { enabled =>
+      withTable("t") {
+        sql(
+          """
+            |CREATE TABLE t(i INT, j INT, k STRING) USING PARQUET
+            |PARTITIONED BY (k)
+            |""".stripMargin)
+        // Skip checking orderingMatched temporarily to avoid touching `FileFormatWriter`,
+        // see details at https://github.com/apache/spark/pull/52584#issuecomment-3407716019
+        executeAndCheckOrderingAndCustomValidate(
+          hasLogicalSort = true, orderingMatched = None) {
+          sql(
+            """
+              |INSERT OVERWRITE t
+              |SELECT i, j, '0' as k FROM t0 SORT BY k, i
+              |""".stripMargin)
+        } { optimizedPlan =>
+          assert {
+            optimizedPlan.outputOrdering.exists {
+              case SortOrder(attr: AttributeReference, _, _, _) => attr.name == "i"
+              case _ => false
+            }
+          }
+        }
+      }
+    }
+  }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/V1WriteHiveCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/V1WriteHiveCommandSuite.scala
index e0e056be5987..a3e864ee55c6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/V1WriteHiveCommandSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/command/V1WriteHiveCommandSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.hive.execution.command
 
 import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SortOrder}
 import org.apache.spark.sql.execution.datasources.V1WriteCommandSuiteBase
 import org.apache.spark.sql.hive.HiveUtils._
 import org.apache.spark.sql.hive.test.TestHiveSingleton
@@ -126,4 +127,37 @@ class V1WriteHiveCommandSuite
       }
     }
   }
+
+  test("v1 write to hive table with sort by literal column preserves custom order") {
+    withCovnertMetastore { _ =>
+      withPlannedWrite { enabled =>
+        withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") {
+          withTable("t") {
+            sql(
+              """
+                |CREATE TABLE t(i INT, j INT, k STRING) STORED AS PARQUET
+                |PARTITIONED BY (k)
+                |""".stripMargin)
+            // Skip checking orderingMatched temporarily to avoid touching `FileFormatWriter`,
+            // see details at https://github.com/apache/spark/pull/52584#issuecomment-3407716019
+            executeAndCheckOrderingAndCustomValidate(
+              hasLogicalSort = true, orderingMatched = None) {
+              sql(
+                """
+                  |INSERT OVERWRITE t
+                  |SELECT i, j, '0' as k FROM t0 SORT BY k, i
+                  |""".stripMargin)
+            } { optimizedPlan =>
+              assert {
+                optimizedPlan.outputOrdering.exists {
+                  case SortOrder(attr: AttributeReference, _, _, _) => attr.name == "i"
+                  case _ => false
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
 }
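
Illustrative sketch (not part of the patch): a minimal, self-contained Scala example of the idea behind `eliminateFoldableOrdering`, built directly on Catalyst APIs. The values `i`, `k`, `query`, `ordering`, and `cleaned` are invented for illustration and mirror the `SELECT i, '0' AS k ... SORT BY k, i` shape from the new tests; the exact rule behavior may differ slightly across Spark versions.

    import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, Literal, SortOrder}
    import org.apache.spark.sql.catalyst.optimizer.{EliminateSorts, FoldablePropagation}
    import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project, Sort}
    import org.apache.spark.sql.types.IntegerType

    // Shape of the problematic query: SELECT i, '0' AS k ... SORT BY k, i
    val i = AttributeReference("i", IntegerType)()
    val k = Alias(Literal("0"), "k")()                 // constant partition column
    val query = Project(Seq(i, k), LocalRelation(i))
    val ordering = Seq(SortOrder(k.toAttribute, Ascending), SortOrder(i, Ascending))

    // Fake Sort plus optimizer rules, mirroring V1WritesUtils.eliminateFoldableOrdering:
    // FoldablePropagation rewrites the reference to `k` into the literal '0', then
    // EliminateSorts drops the now-foldable sort key, leaving only the ordering on `i`.
    val cleaned = EliminateSorts(FoldablePropagation(Sort(ordering, global = false, query)))
    assert(cleaned.outputOrdering.map(_.child) == Seq(i))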