Skip to content

Commit ca348e5

Browse files
allisonwang-dbcloud-fan
authored andcommitted
[SPARK-36028][SQL] Allow Project to host outer references in scalar subqueries
### What changes were proposed in this pull request? This PR allows the `Project` node to host outer references in scalar subqueries when `decorrelateInnerQuery` is enabled. It is already supported by the new decorrelation framework and the `RewriteCorrelatedScalarSubquery` rule. Note currently by default all correlated subqueries will be decorrelated, which is not necessarily the most optimal approach. Consider `SELECT (SELECT c1) FROM t`. This should be optimized as `SELECT c1 FROM t` instead of rewriting it as a left outer join. This will be done in a separate PR to optimize correlated scalar/lateral subqueries with OneRowRelation. ### Why are the changes needed? To allow more types of correlated scalar subqueries. ### Does this PR introduce _any_ user-facing change? Yes. This PR allows outer query column references in the SELECT cluase of a correlated scalar subquery. For example: ```sql SELECT (SELECT c1) FROM t; ``` Before this change: ``` org.apache.spark.sql.AnalysisException: Expressions referencing the outer query are not supported outside of WHERE/HAVING clauses ``` After this change: ``` +------------------+ |scalarsubquery(c1)| +------------------+ |0 | |1 | +------------------+ ``` ### How was this patch tested? Added unit tests and SQL tests. Closes #33235 from allisonwang-db/spark-36028-outer-in-project. Authored-by: allisonwang-db <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent bad6f89 commit ca348e5

File tree

5 files changed

+144
-18
lines changed

5 files changed

+144
-18
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -725,9 +725,15 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
725725
s"Filter/Aggregate/Project and a few commands: $plan")
726726
}
727727
}
728+
// Validate to make sure the correlations appearing in the query are valid and
729+
// allowed by spark.
730+
checkCorrelationsInSubquery(expr.plan, isScalarOrLateral = true)
728731

729732
case _: LateralSubquery =>
730733
assert(plan.isInstanceOf[LateralJoin])
734+
// Validate to make sure the correlations appearing in the query are valid and
735+
// allowed by spark.
736+
checkCorrelationsInSubquery(expr.plan, isScalarOrLateral = true)
731737

732738
case inSubqueryOrExistsSubquery =>
733739
plan match {
@@ -736,11 +742,10 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
736742
failAnalysis(s"IN/EXISTS predicate sub-queries can only be used in" +
737743
s" Filter/Join and a few commands: $plan")
738744
}
745+
// Validate to make sure the correlations appearing in the query are valid and
746+
// allowed by spark.
747+
checkCorrelationsInSubquery(expr.plan)
739748
}
740-
741-
// Validate to make sure the correlations appearing in the query are valid and
742-
// allowed by spark.
743-
checkCorrelationsInSubquery(expr.plan, isLateral = plan.isInstanceOf[LateralJoin])
744749
}
745750

746751
/**
@@ -779,7 +784,9 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
779784
* Validates to make sure the outer references appearing inside the subquery
780785
* are allowed.
781786
*/
782-
private def checkCorrelationsInSubquery(sub: LogicalPlan, isLateral: Boolean = false): Unit = {
787+
private def checkCorrelationsInSubquery(
788+
sub: LogicalPlan,
789+
isScalarOrLateral: Boolean = false): Unit = {
783790
// Validate that correlated aggregate expression do not contain a mixture
784791
// of outer and local references.
785792
def checkMixedReferencesInsideAggregateExpr(expr: Expression): Unit = {
@@ -800,11 +807,11 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
800807
}
801808

802809
// Check whether the logical plan node can host outer references.
803-
// A `Project` can host outer references if it is inside a lateral subquery.
804-
// Otherwise, only Filter can only outer references.
810+
// A `Project` can host outer references if it is inside a scalar or a lateral subquery and
811+
// DecorrelateInnerQuery is enabled. Otherwise, only Filter can only outer references.
805812
def canHostOuter(plan: LogicalPlan): Boolean = plan match {
806813
case _: Filter => true
807-
case _: Project => isLateral
814+
case _: Project => isScalarOrLateral && SQLConf.get.decorrelateInnerQueryEnabled
808815
case _ => false
809816
}
810817

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -824,13 +824,6 @@ class AnalysisErrorSuite extends AnalysisTest {
824824
Project(ScalarSubquery(t0.select(star("t1"))).as("sub") :: Nil, t1),
825825
"Scalar subquery must return only one column, but got 2" :: Nil)
826826

827-
// array(t1.*) in the subquery should be resolved into array(outer(t1.a), outer(t1.b))
828-
val array = CreateArray(Seq(star("t1")))
829-
assertAnalysisError(
830-
Project(ScalarSubquery(t0.select(array)).as("sub") :: Nil, t1),
831-
"Expressions referencing the outer query are not supported outside" +
832-
" of WHERE/HAVING clauses" :: Nil)
833-
834827
// t2.* cannot be resolved and the error should be the initial analysis exception.
835828
assertAnalysisError(
836829
Project(ScalarSubquery(t0.select(star("t2"))).as("sub") :: Nil, t1),

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveSubquerySuite.scala

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.analysis
2020
import org.apache.spark.sql.AnalysisException
2121
import org.apache.spark.sql.catalyst.dsl.expressions._
2222
import org.apache.spark.sql.catalyst.dsl.plans._
23-
import org.apache.spark.sql.catalyst.expressions.{CreateArray, Expression, GetStructField, InSubquery, LateralSubquery, ListQuery, OuterReference}
23+
import org.apache.spark.sql.catalyst.expressions.{CreateArray, Expression, GetStructField, InSubquery, LateralSubquery, ListQuery, OuterReference, ScalarSubquery}
2424
import org.apache.spark.sql.catalyst.expressions.aggregate.Count
2525
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
2626
import org.apache.spark.sql.catalyst.plans.logical._
@@ -240,4 +240,28 @@ class ResolveSubquerySuite extends AnalysisTest {
240240
Inner, None)
241241
)
242242
}
243+
244+
test("SPARK-36028: resolve scalar subqueries with outer references in Project") {
245+
// SELECT (SELECT a) FROM t1
246+
checkAnalysis(
247+
Project(ScalarSubquery(t0.select('a)).as("sub") :: Nil, t1),
248+
Project(ScalarSubquery(Project(OuterReference(a) :: Nil, t0), Seq(a)).as("sub") :: Nil, t1)
249+
)
250+
// SELECT (SELECT a + b + c AS r FROM t2) FROM t1
251+
checkAnalysis(
252+
Project(ScalarSubquery(
253+
t2.select(('a + 'b + 'c).as("r"))).as("sub") :: Nil, t1),
254+
Project(ScalarSubquery(
255+
Project((OuterReference(a) + b + c).as("r") :: Nil, t2), Seq(a)).as("sub") :: Nil, t1)
256+
)
257+
// SELECT (SELECT array(t1.*) AS arr) FROM t1
258+
checkAnalysis(
259+
Project(ScalarSubquery(t0.select(
260+
CreateArray(Seq(star("t1"))).as("arr"))
261+
).as("sub") :: Nil, t1.as("t1")),
262+
Project(ScalarSubquery(Project(
263+
CreateArray(Seq(OuterReference(a), OuterReference(b))).as("arr") :: Nil, t0
264+
), Seq(a, b)).as("sub") :: Nil, t1)
265+
)
266+
}
243267
}

sql/core/src/test/resources/sql-tests/inputs/subquery/scalar-subquery/scalar-subquery-select.sql

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,4 +137,11 @@ SELECT t1a,
137137
(SELECT collect_list(t2d) FROM t2 WHERE t2a = t1a) collect_list_t2,
138138
(SELECT sort_array(collect_set(t2d)) FROM t2 WHERE t2a = t1a) collect_set_t2,
139139
(SELECT hex(count_min_sketch(t2d, 0.5d, 0.5d, 1)) FROM t2 WHERE t2a = t1a) collect_set_t2
140-
FROM t1;
140+
FROM t1;
141+
142+
-- SPARK-36028: Allow Project to host outer references in scalar subqueries
143+
SELECT t1c, (SELECT t1c) FROM t1;
144+
SELECT t1c, (SELECT t1c WHERE t1c = 8) FROM t1;
145+
SELECT t1c, t1d, (SELECT c + d FROM (SELECT t1c AS c, t1d AS d)) FROM t1;
146+
SELECT t1c, (SELECT SUM(c) FROM (SELECT t1c AS c)) FROM t1;
147+
SELECT t1a, (SELECT SUM(t2b) FROM t2 JOIN (SELECT t1a AS a) ON t2a = a) FROM t1;

sql/core/src/test/resources/sql-tests/results/subquery/scalar-subquery/scalar-subquery-select.sql.out

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
-- Automatically generated by SQLQueryTestSuite
2-
-- Number of queries: 12
2+
-- Number of queries: 17
33

44

55
-- !query
@@ -222,3 +222,98 @@ val1d 0 0 0 [] [] 0000000100000000000000000000000100000004000000005D8D6AB9000000
222222
val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000
223223
val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000
224224
val1e 1 1 1 [19] [19] 0000000100000000000000010000000100000004000000005D8D6AB90000000000000000000000000000000100000000000000000000000000000000
225+
226+
227+
-- !query
228+
SELECT t1c, (SELECT t1c) FROM t1
229+
-- !query schema
230+
struct<t1c:int,scalarsubquery(t1c):int>
231+
-- !query output
232+
12 12
233+
12 12
234+
16 16
235+
16 16
236+
16 16
237+
16 16
238+
8 8
239+
8 8
240+
NULL NULL
241+
NULL NULL
242+
NULL NULL
243+
NULL NULL
244+
245+
246+
-- !query
247+
SELECT t1c, (SELECT t1c WHERE t1c = 8) FROM t1
248+
-- !query schema
249+
struct<t1c:int,scalarsubquery(t1c, t1c):int>
250+
-- !query output
251+
12 NULL
252+
12 NULL
253+
16 NULL
254+
16 NULL
255+
16 NULL
256+
16 NULL
257+
8 8
258+
8 8
259+
NULL NULL
260+
NULL NULL
261+
NULL NULL
262+
NULL NULL
263+
264+
265+
-- !query
266+
SELECT t1c, t1d, (SELECT c + d FROM (SELECT t1c AS c, t1d AS d)) FROM t1
267+
-- !query schema
268+
struct<t1c:int,t1d:bigint,scalarsubquery(t1c, t1d):bigint>
269+
-- !query output
270+
12 10 22
271+
12 21 33
272+
16 19 35
273+
16 19 35
274+
16 19 35
275+
16 22 38
276+
8 10 18
277+
8 10 18
278+
NULL 12 NULL
279+
NULL 19 NULL
280+
NULL 19 NULL
281+
NULL 25 NULL
282+
283+
284+
-- !query
285+
SELECT t1c, (SELECT SUM(c) FROM (SELECT t1c AS c)) FROM t1
286+
-- !query schema
287+
struct<t1c:int,scalarsubquery(t1c):bigint>
288+
-- !query output
289+
12 12
290+
12 12
291+
16 16
292+
16 16
293+
16 16
294+
16 16
295+
8 8
296+
8 8
297+
NULL NULL
298+
NULL NULL
299+
NULL NULL
300+
NULL NULL
301+
302+
303+
-- !query
304+
SELECT t1a, (SELECT SUM(t2b) FROM t2 JOIN (SELECT t1a AS a) ON t2a = a) FROM t1
305+
-- !query schema
306+
struct<t1a:string,scalarsubquery(t1a):bigint>
307+
-- !query output
308+
val1a NULL
309+
val1a NULL
310+
val1a NULL
311+
val1a NULL
312+
val1b 36
313+
val1c 24
314+
val1d NULL
315+
val1d NULL
316+
val1d NULL
317+
val1e 8
318+
val1e 8
319+
val1e 8

0 commit comments

Comments
 (0)