@@ -19,14 +19,13 @@ package org.apache.spark.sql.catalyst.optimizer
1919
2020import org .apache .spark .sql .catalyst .dsl .expressions ._
2121import org .apache .spark .sql .catalyst .dsl .plans ._
22- import org .apache .spark .sql .catalyst .expressions .{GetStructField , MultiScalarSubquery , ScalarSubquery }
22+ import org .apache .spark .sql .catalyst .expressions .{CreateStruct , GetStructField , ScalarSubquery }
2323import org .apache .spark .sql .catalyst .expressions .aggregate .{CollectList , CollectSet }
2424import org .apache .spark .sql .catalyst .plans ._
2525import org .apache .spark .sql .catalyst .plans .logical ._
2626import org .apache .spark .sql .catalyst .rules ._
2727
2828class MergeScalarSubqueriesSuite extends PlanTest {
29-
/** Rule executor that runs only the `MergeScalarSubqueries` rule, in a single `Once` batch. */
private object Optimize extends RuleExecutor[LogicalPlan] {
  val batches = Batch("MergeScalarSubqueries", Once, MergeScalarSubqueries) :: Nil
}
@@ -35,82 +34,81 @@ class MergeScalarSubqueriesSuite extends PlanTest {
3534
test("Simple non-correlated scalar subquery merge") {
  // Two scalar subqueries over the same relation with the same grouping key,
  // differing only in the aggregate function they compute.
  val subquery1 = testRelation
    .groupBy('b)(max('a).as("max_a"))
  val subquery2 = testRelation
    .groupBy('b)(sum('a).as("sum_a"))
  val originalQuery = testRelation
    .select(ScalarSubquery(subquery1), ScalarSubquery(subquery2))

  // Expected shape after the rule: a single aggregate computing both values,
  // packed into one struct column aliased "mergedValue".
  val multiSubquery = testRelation
    .groupBy('b)(max('a).as("max_a"), sum('a).as("sum_a"))
    .select(CreateStruct(Seq('max_a, 'sum_a)).as("mergedValue"))
  // Each original subquery is replaced by a field extraction (ordinal 0 and 1)
  // from the shared merged subquery.
  val correctAnswer = testRelation
    .select(GetStructField(ScalarSubquery(multiSubquery), 0).as("scalarsubquery()"),
      GetStructField(ScalarSubquery(multiSubquery), 1).as("scalarsubquery()"))

  comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
}
5452
test("Aggregate and group expression merge") {
  // One subquery returns an aggregate value, the other returns the grouping
  // expression itself; both group the same relation by 'b.
  val subquery1 = testRelation
    .groupBy('b)(max('a).as("max_a"))
  val subquery2 = testRelation
    .groupBy('b)('b)
  val originalQuery = testRelation
    .select(ScalarSubquery(subquery1), ScalarSubquery(subquery2))

  // Expected shape after the rule: one aggregate producing both the aggregate
  // value and the grouping column, packed into a struct aliased "mergedValue".
  val multiSubquery = testRelation
    .groupBy('b)(max('a).as("max_a"), 'b)
    .select(CreateStruct(Seq('max_a, 'b)).as("mergedValue"))
  // Each original subquery becomes a field extraction (ordinal 0 and 1) from
  // the shared merged subquery.
  val correctAnswer = testRelation
    .select(GetStructField(ScalarSubquery(multiSubquery), 0).as("scalarsubquery()"),
      GetStructField(ScalarSubquery(multiSubquery), 1).as("scalarsubquery()"))

  comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
}
7370
test("Do not merge different aggregate implementations") {
  // Six subqueries in three pairs; the expected plan below merges each pair
  // with its partner only, never across pairs.

  // supports HashAggregate
  val subquery1 = testRelation
    .groupBy('b)(max('a).as("max_a"))
  val subquery2 = testRelation
    .groupBy('b)(min('a).as("min_a"))

  // supports ObjectHashAggregate
  val subquery3 = testRelation
    .groupBy('b)(CollectList('a).toAggregateExpression(isDistinct = false).as("collectlist_a"))
  val subquery4 = testRelation
    .groupBy('b)(CollectSet('a).toAggregateExpression(isDistinct = false).as("collectset_a"))

  // supports SortAggregate
  val subquery5 = testRelation
    .groupBy('b)(max('c).as("max_c"))
  val subquery6 = testRelation
    .groupBy('b)(min('c).as("min_c"))

  val originalQuery = testRelation
    .select(ScalarSubquery(subquery1), ScalarSubquery(subquery2), ScalarSubquery(subquery3),
      ScalarSubquery(subquery4), ScalarSubquery(subquery5), ScalarSubquery(subquery6))

  // Expected merged subquery for each pair: one aggregate computing both
  // values, packed into a struct column aliased "mergedValue".
  val hashAggregates = testRelation
    .groupBy('b)(max('a).as("max_a"), min('a).as("min_a"))
    .select(CreateStruct(Seq('max_a, 'min_a)).as("mergedValue"))
  val objectHashAggregates = testRelation
    .groupBy('b)(CollectList('a).toAggregateExpression(isDistinct = false).as("collectlist_a"),
      CollectSet('a).toAggregateExpression(isDistinct = false).as("collectset_a"))
    .select(CreateStruct(Seq('collectlist_a, 'collectset_a)).as("mergedValue"))
  val sortAggregates = testRelation
    .groupBy('b)(max('c).as("max_c"), min('c).as("min_c"))
    .select(CreateStruct(Seq('max_c, 'min_c)).as("mergedValue"))
  // Each original subquery is replaced by a field extraction (ordinal 0 or 1)
  // from the merged subquery of its own pair.
  val correctAnswer = testRelation
    .select(GetStructField(ScalarSubquery(hashAggregates), 0).as("scalarsubquery()"),
      GetStructField(ScalarSubquery(hashAggregates), 1).as("scalarsubquery()"),
      GetStructField(ScalarSubquery(objectHashAggregates), 0).as("scalarsubquery()"),
      GetStructField(ScalarSubquery(objectHashAggregates), 1).as("scalarsubquery()"),
      GetStructField(ScalarSubquery(sortAggregates), 0).as("scalarsubquery()"),
      GetStructField(ScalarSubquery(sortAggregates), 1).as("scalarsubquery()"))

  comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
}
116114}
0 commit comments