diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 827df04443e5..5d670ecdf1d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1466,6 +1466,12 @@ object EliminateSorts extends Rule[LogicalPlan] { plan match { case Sort(_, global, child) if canRemoveGlobalSort || !global => recursiveRemoveSort(child, canRemoveGlobalSort) + case Sort(sortOrder, true, child) => + // For this case, the upper sort is local so the ordering of present sort is unnecessary, + // so here we only preserve its output partitioning using `RepartitionByExpression`. + // We should use `None` as the optNumPartitions so AQE can coalesce shuffle partitions. + // This behavior is same with original global sort. + RepartitionByExpression(sortOrder, recursiveRemoveSort(child, true), None) case other if canEliminateSort(other) => other.withNewChildren(other.children.map(c => recursiveRemoveSort(c, canRemoveGlobalSort))) case other if canEliminateGlobalSort(other) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index 053bc1c21373..7ceac3b3000c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -426,16 +426,20 @@ class EliminateSortsSuite extends AnalysisTest { test("SPARK-39835: Fix EliminateSorts remove global sort below the local sort") { // global -> local val plan = testRelation.orderBy($"a".asc).sortBy($"c".asc).analyze - comparePlans(Optimize.execute(plan), plan) + val expect = RepartitionByExpression($"a".asc :: Nil, testRelation, None) + .sortBy($"c".asc).analyze + comparePlans(Optimize.execute(plan), expect) // global -> global -> local val plan2 = testRelation.orderBy($"a".asc).orderBy($"b".asc).sortBy($"c".asc).analyze - val expected2 = testRelation.orderBy($"b".asc).sortBy($"c".asc).analyze + val expected2 = RepartitionByExpression($"b".asc :: Nil, testRelation, None) + .sortBy($"c".asc).analyze comparePlans(Optimize.execute(plan2), expected2) // local -> global -> local val plan3 = testRelation.sortBy($"a".asc).orderBy($"b".asc).sortBy($"c".asc).analyze - val expected3 = testRelation.orderBy($"b".asc).sortBy($"c".asc).analyze + val expected3 = RepartitionByExpression($"b".asc :: Nil, testRelation, None) + .sortBy($"c".asc).analyze comparePlans(Optimize.execute(plan3), expected3) } }