-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-23957][SQL] Sorts in subqueries are redundant and can be removed #21853
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -164,10 +164,20 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) | |
| * Optimize all the subqueries inside expression. | ||
| */ | ||
| object OptimizeSubqueries extends Rule[LogicalPlan] { | ||
| private def removeTopLevelSorts(plan: LogicalPlan): LogicalPlan = { | ||
| plan match { | ||
| case Sort(_, _, child) => child | ||
| case Project(fields, child) => Project(fields, removeTopLevelSorts(child)) | ||
| case other => other | ||
| } | ||
| } | ||
| def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { | ||
| case s: SubqueryExpression => | ||
| val Subquery(newPlan) = Optimizer.this.execute(Subquery(s.plan)) | ||
| s.withNewPlan(newPlan) | ||
| // At this point we have an optimized subquery plan that we are going to attach | ||
| // to this subquery expression. Here we can safely remove any top level sorts | ||
|
||
| // in the plan as tuples produced by a subquery are un-ordered. | ||
| s.withNewPlan(removeTopLevelSorts(newPlan)) | ||
| } | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,7 +17,10 @@ | |
|
|
||
| package org.apache.spark.sql | ||
|
|
||
| import org.apache.spark.sql.catalyst.plans.logical.Join | ||
| import scala.collection.mutable.ArrayBuffer | ||
|
|
||
| import org.apache.spark.sql.catalyst.expressions.SubqueryExpression | ||
| import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort} | ||
| import org.apache.spark.sql.test.SharedSQLContext | ||
|
|
||
| class SubquerySuite extends QueryTest with SharedSQLContext { | ||
|
|
@@ -970,4 +973,300 @@ class SubquerySuite extends QueryTest with SharedSQLContext { | |
| Row("3", "b") :: Row("4", "b") :: Nil) | ||
| } | ||
| } | ||
|
|
||
| private def getNumSortsInQuery(query: String): Int = { | ||
| val plan = sql(query).queryExecution.optimizedPlan | ||
| getNumSorts(plan) + getSubqueryExpressions(plan).map{s => getNumSorts(s.plan)}.sum | ||
| } | ||
|
|
||
| private def getSubqueryExpressions(plan: LogicalPlan): Seq[SubqueryExpression] = { | ||
| val subqueryExpressions = ArrayBuffer.empty[SubqueryExpression] | ||
| plan transformAllExpressions { | ||
| case s: SubqueryExpression => | ||
| subqueryExpressions ++= (getSubqueryExpressions(s.plan) :+ s) | ||
| s | ||
| } | ||
| subqueryExpressions | ||
| } | ||
|
|
||
| private def getNumSorts(plan: LogicalPlan): Int = { | ||
| plan.collect { case s: Sort => s }.size | ||
| } | ||
|
|
||
| test("SPARK-23957 Remove redundant sort from subquery plan(in subquery)") { | ||
| withTempView("t1", "t2", "t3") { | ||
| Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") | ||
| Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") | ||
| Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3") | ||
|
|
||
| // Simple order by | ||
| val query1 = | ||
| """ | ||
| |SELECT c1 FROM t1 | ||
| |WHERE | ||
| |c1 IN (SELECT c1 FROM t2 ORDER BY c1) | ||
| """.stripMargin | ||
| assert(getNumSortsInQuery(query1) == 0) | ||
|
|
||
| // Nested order bys | ||
| val query2 = | ||
| """ | ||
| |SELECT c1 | ||
| |FROM t1 | ||
| |WHERE c1 IN (SELECT c1 | ||
| | FROM (SELECT * | ||
| | FROM t2 | ||
| | ORDER BY c2) | ||
| | ORDER BY c1) | ||
| """.stripMargin | ||
| assert(getNumSortsInQuery(query2) == 0) | ||
|
|
||
|
|
||
| // nested IN | ||
| val query3 = | ||
| """ | ||
| |SELECT c1 | ||
| |FROM t1 | ||
| |WHERE c1 IN (SELECT c1 | ||
| | FROM t2 | ||
| | WHERE c1 IN (SELECT c1 | ||
| | FROM t3 | ||
| | WHERE c1 = 1 | ||
| | ORDER BY c3) | ||
| | ORDER BY c2) | ||
| """.stripMargin | ||
| assert(getNumSortsInQuery(query3) == 0) | ||
|
|
||
| // Complex subplan and multiple sorts | ||
| val query4 = | ||
| """ | ||
| |SELECT c1 | ||
| |FROM t1 | ||
| |WHERE c1 IN (SELECT c1 | ||
| | FROM (SELECT c1, c2, count(*) | ||
| | FROM t2 | ||
| | GROUP BY c1, c2 | ||
| | HAVING count(*) > 0 | ||
| | ORDER BY c2) | ||
| | ORDER BY c1) | ||
| """.stripMargin | ||
| assert(getNumSortsInQuery(query4) == 0) | ||
|
|
||
| // Join in subplan | ||
| val query5 = | ||
| """ | ||
| |SELECT c1 FROM t1 | ||
| |WHERE | ||
| |c1 IN (SELECT t2.c1 FROM t2, t3 | ||
| | WHERE t2.c1 = t3.c1 | ||
| | ORDER BY t2.c1) | ||
| """.stripMargin | ||
| assert(getNumSortsInQuery(query5) == 0) | ||
|
|
||
| val query6 = | ||
| """ | ||
| |SELECT c1 | ||
| |FROM t1 | ||
| |WHERE (c1, c2) IN (SELECT c1, max(c2) | ||
| | FROM (SELECT c1, c2, count(*) | ||
| | FROM t2 | ||
| | GROUP BY c1, c2 | ||
| | HAVING count(*) > 0 | ||
| | ORDER BY c2) | ||
| | GROUP BY c1 | ||
| | HAVING max(c2) > 0 | ||
| | ORDER BY c1) | ||
| """.stripMargin | ||
| // The rule to remove redundant sorts is not able to remove the inner sort under | ||
| // an Aggregate operator. We only remove the top level sort. | ||
| assert(getNumSortsInQuery(query6) == 1) | ||
|
|
||
| // Cases when sort is not removed from the plan | ||
| // Limit on top of sort | ||
| val query7 = | ||
| """ | ||
| |SELECT c1 FROM t1 | ||
| |WHERE | ||
| |c1 IN (SELECT c1 FROM t2 ORDER BY c1 limit 1) | ||
| """.stripMargin | ||
| assert(getNumSortsInQuery(query7) == 1) | ||
|
|
||
| // Sort below a set operations (intersect, union) | ||
| val query8 = | ||
| """ | ||
| |SELECT c1 FROM t1 | ||
| |WHERE | ||
| |c1 IN (( | ||
| | SELECT c1 FROM t2 | ||
| | ORDER BY c1 | ||
| | ) | ||
| | UNION | ||
| | ( | ||
| | SELECT c1 FROM t2 | ||
| | ORDER BY c1 | ||
| | )) | ||
| """.stripMargin | ||
| assert(getNumSortsInQuery(query8) == 2) | ||
| } | ||
| } | ||
|
|
||
| test("SPARK-23957 Remove redundant sort from subquery plan(exists subquery)") { | ||
| withTempView("t1", "t2", "t3") { | ||
| Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") | ||
| Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") | ||
| Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3") | ||
|
|
||
| // Simple order by exists correlated | ||
| val query1 = | ||
| """ | ||
| |SELECT c1 FROM t1 | ||
| |WHERE | ||
| |EXISTS (SELECT t2.c1 FROM t2 WHERE t1.c1 = t2.c1 ORDER BY t2.c1) | ||
| """.stripMargin | ||
| assert(getNumSortsInQuery(query1) == 0) | ||
|
|
||
| // Nested order by and correlated. | ||
| val query2 = | ||
| """ | ||
| |SELECT c1 | ||
| |FROM t1 | ||
| |WHERE EXISTS (SELECT c1 | ||
| | FROM (SELECT * | ||
| | FROM t2 | ||
| | WHERE t2.c1 = t1.c1 | ||
| | ORDER BY t2.c2) t2 | ||
| | ORDER BY t2.c1) | ||
|
||
| """.stripMargin | ||
| assert(getNumSortsInQuery(query2) == 0) | ||
|
|
||
| // nested EXISTS | ||
| val query3 = | ||
| """ | ||
| |SELECT c1 | ||
| |FROM t1 | ||
| |WHERE EXISTS (SELECT c1 | ||
| | FROM t2 | ||
| | WHERE EXISTS (SELECT c1 | ||
| | FROM t3 | ||
| | WHERE t3.c1 = t2.c1 | ||
| | ORDER BY c3) | ||
| | AND t2.c1 = t1.c1 | ||
| | ORDER BY c2) | ||
| """.stripMargin | ||
| assert(getNumSortsInQuery(query3) == 0) | ||
|
|
||
| // Cases when sort is not removed from the plan | ||
| // Limit on top of sort | ||
| val query4 = | ||
| """ | ||
| |SELECT c1 FROM t1 | ||
| |WHERE | ||
| |EXISTS (SELECT t2.c1 FROM t2 WHERE t2.c1 = 1 ORDER BY t2.c1 limit 1) | ||
| """.stripMargin | ||
| assert(getNumSortsInQuery(query4) == 1) | ||
|
|
||
| // Sort below a set operations (intersect, union) | ||
| val query5 = | ||
| """ | ||
| |SELECT c1 FROM t1 | ||
| |WHERE | ||
| |EXISTS (( | ||
| | SELECT c1 FROM t2 | ||
| | WHERE t2.c1 = 1 | ||
| | ORDER BY t2.c1 | ||
| | ) | ||
| | UNION | ||
| | ( | ||
| | SELECT c1 FROM t2 | ||
| | WHERE t2.c1 = 2 | ||
| | ORDER BY t2.c1 | ||
| | )) | ||
| """.stripMargin | ||
| assert(getNumSortsInQuery(query5) == 2) | ||
| } | ||
| } | ||
|
|
||
| test("SPARK-23957 Remove redundant sort from subquery plan(scalar subquery)") { | ||
| withTempView("t1", "t2", "t3") { | ||
| Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t1") | ||
| Seq((1, 1), (2, 2)).toDF("c1", "c2").createOrReplaceTempView("t2") | ||
| Seq((1, 1, 1), (2, 2, 2)).toDF("c1", "c2", "c3").createOrReplaceTempView("t3") | ||
|
|
||
| // Two scalar subqueries in OR | ||
| val query1 = | ||
| """ | ||
| |SELECT * FROM t1 | ||
| |WHERE c1 = (SELECT max(t2.c1) | ||
| | FROM t2 | ||
| | ORDER BY max(t2.c1)) | ||
| |OR c2 = (SELECT min(t3.c2) | ||
| | FROM t3 | ||
| | WHERE t3.c1 = 1 | ||
| | ORDER BY min(t3.c2)) | ||
| """.stripMargin | ||
| assert(getNumSortsInQuery(query1) == 0) | ||
|
|
||
| // scalar subquery - groupby and having | ||
| val query2 = | ||
| """ | ||
| |SELECT * | ||
| |FROM t1 | ||
| |WHERE c1 = (SELECT max(t2.c1) | ||
| | FROM t2 | ||
| | GROUP BY t2.c1 | ||
| | HAVING count(*) >= 1 | ||
| | ORDER BY max(t2.c1)) | ||
| """.stripMargin | ||
| assert(getNumSortsInQuery(query2) == 0) | ||
|
|
||
| // nested scalar subquery | ||
| val query3 = | ||
| """ | ||
| |SELECT * | ||
| |FROM t1 | ||
| |WHERE c1 = (SELECT max(t2.c1) | ||
| | FROM t2 | ||
| | WHERE c1 = (SELECT max(t3.c1) | ||
| | FROM t3 | ||
| | WHERE t3.c1 = 1 | ||
| | GROUP BY t3.c1 | ||
| | ORDER BY max(t3.c1) | ||
| | ) | ||
| | GROUP BY t2.c1 | ||
| | HAVING count(*) >= 1 | ||
| | ORDER BY max(t2.c1)) | ||
| """.stripMargin | ||
| assert(getNumSortsInQuery(query3) == 0) | ||
|
|
||
| // Scalar subquery in projection | ||
| val query4 = | ||
| """ | ||
| |SELECT (SELECT min(c1) from t1 group by c1 order by c1) | ||
| |FROM t1 | ||
| |WHERE t1.c1 = 1 | ||
| """.stripMargin | ||
| assert(getNumSortsInQuery(query4) == 0) | ||
|
|
||
| // Limit on top of sort prevents it from being pruned. | ||
| val query5 = | ||
| """ | ||
| |SELECT * | ||
| |FROM t1 | ||
| |WHERE c1 = (SELECT max(t2.c1) | ||
| | FROM t2 | ||
| | WHERE c1 = (SELECT max(t3.c1) | ||
| | FROM t3 | ||
| | WHERE t3.c1 = 1 | ||
| | GROUP BY t3.c1 | ||
| | ORDER BY max(t3.c1) | ||
| | ) | ||
| | GROUP BY t2.c1 | ||
| | HAVING count(*) >= 1 | ||
| | ORDER BY max(t2.c1) | ||
| | LIMIT 1) | ||
| """.stripMargin | ||
| assert(getNumSortsInQuery(query5) == 1) | ||
| } | ||
| } | ||
| } | ||
|
|
||
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit:
removeTopLevelSort? (I think this func removes a single sort on the top?)