@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
1919
2020import org .apache .spark .sql .catalyst .expressions .Explode
2121import org .apache .spark .sql .catalyst .plans .PlanTest
22- import org .apache .spark .sql .catalyst .plans .logical .{Project , LocalRelation , Generate , LogicalPlan }
22+ import org .apache .spark .sql .catalyst .plans .logical .{LocalRelation , LogicalPlan }
2323import org .apache .spark .sql .catalyst .rules .RuleExecutor
2424import org .apache .spark .sql .catalyst .dsl .expressions ._
2525import org .apache .spark .sql .catalyst .dsl .plans ._
@@ -35,12 +35,11 @@ class ColumnPruningSuite extends PlanTest {
3535 test(" Column pruning for Generate when Generate.join = false" ) {
3636 val input = LocalRelation (' a .int, ' b .array(StringType ))
3737
38- val query = Generate (Explode (' b ), false , false , None , ' s .string :: Nil , input).analyze
38+ val query = input.generate(Explode (' b ), join = false ).analyze
39+
3940 val optimized = Optimize .execute(query)
4041
41- val correctAnswer =
42- Generate (Explode (' b ), false , false , None , ' s .string :: Nil ,
43- Project (' b .attr :: Nil , input)).analyze
42+ val correctAnswer = input.select(' b ).generate(Explode (' b ), join = false ).analyze
4443
4544 comparePlans(optimized, correctAnswer)
4645 }
@@ -49,16 +48,19 @@ class ColumnPruningSuite extends PlanTest {
4948 val input = LocalRelation (' a .int, ' b .int, ' c .array(StringType ))
5049
5150 val query =
52- Project (Seq (' a , ' s ),
53- Generate (Explode (' c ), true , false , None , ' s .string :: Nil ,
54- input)).analyze
51+ input
52+ .generate(Explode (' c ), join = true , outputNames = " explode" :: Nil )
53+ .select(' a , ' explode )
54+ .analyze
55+
5556 val optimized = Optimize .execute(query)
5657
5758 val correctAnswer =
58- Project (Seq (' a , ' s ),
59- Generate (Explode (' c ), true , false , None , ' s .string :: Nil ,
60- Project (Seq (' a , ' c ),
61- input))).analyze
59+ input
60+ .select(' a , ' c )
61+ .generate(Explode (' c ), join = true , outputNames = " explode" :: Nil )
62+ .select(' a , ' explode )
63+ .analyze
6264
6365 comparePlans(optimized, correctAnswer)
6466 }
@@ -67,15 +69,18 @@ class ColumnPruningSuite extends PlanTest {
6769 val input = LocalRelation (' b .array(StringType ))
6870
6971 val query =
70- Project ((' s + 1 ).as(" s+1" ) :: Nil ,
71- Generate (Explode (' b ), true , false , None , ' s .string :: Nil ,
72- input)).analyze
72+ input
73+ .generate(Explode (' b ), join = true , outputNames = " explode" :: Nil )
74+ .select((' explode + 1 ).as(" result" ))
75+ .analyze
76+
7377 val optimized = Optimize .execute(query)
7478
7579 val correctAnswer =
76- Project ((' s + 1 ).as(" s+1" ) :: Nil ,
77- Generate (Explode (' b ), false , false , None , ' s .string :: Nil ,
78- input)).analyze
80+ input
81+ .generate(Explode (' b ), join = false , outputNames = " explode" :: Nil )
82+ .select((' explode + 1 ).as(" result" ))
83+ .analyze
7984
8085 comparePlans(optimized, correctAnswer)
8186 }
0 commit comments