Skip to content

Commit 5fdcbdc

Browse files
chenghao-intel authored and
marmbrus committed
[SPARK-4625] [SQL] Add sort by for DSL & SimpleSqlParser
Add `sort by` support for both DSL & SqlParser. This PR is relevant with #3386, either one merged, will cause the other rebased. Author: Cheng Hao <[email protected]> Closes #3481 from chenghao-intel/sortby and squashes the following commits: 041004f [Cheng Hao] Add sort by for DSL & SimpleSqlParser
1 parent cf50631 commit 5fdcbdc

File tree

5 files changed

+48
-2
lines changed

5 files changed

+48
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ class SqlParser extends AbstractSparkSQLParser {
8585
protected val ON = Keyword("ON")
8686
protected val OR = Keyword("OR")
8787
protected val ORDER = Keyword("ORDER")
88+
protected val SORT = Keyword("SORT")
8889
protected val OUTER = Keyword("OUTER")
8990
protected val OVERWRITE = Keyword("OVERWRITE")
9091
protected val REGEXP = Keyword("REGEXP")
@@ -140,7 +141,7 @@ class SqlParser extends AbstractSparkSQLParser {
140141
(WHERE ~> expression).? ~
141142
(GROUP ~ BY ~> rep1sep(expression, ",")).? ~
142143
(HAVING ~> expression).? ~
143-
(ORDER ~ BY ~> ordering).? ~
144+
sortType.? ~
144145
(LIMIT ~> expression).? ^^ {
145146
case d ~ p ~ r ~ f ~ g ~ h ~ o ~ l =>
146147
val base = r.getOrElse(NoRelation)
@@ -150,7 +151,7 @@ class SqlParser extends AbstractSparkSQLParser {
150151
.getOrElse(Project(assignAliases(p), withFilter))
151152
val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection)
152153
val withHaving = h.map(Filter(_, withDistinct)).getOrElse(withDistinct)
153-
val withOrder = o.map(Sort(_, withHaving)).getOrElse(withHaving)
154+
val withOrder = o.map(_(withHaving)).getOrElse(withHaving)
154155
val withLimit = l.map(Limit(_, withOrder)).getOrElse(withOrder)
155156
withLimit
156157
}
@@ -202,6 +203,11 @@ class SqlParser extends AbstractSparkSQLParser {
202203
| FULL ~ OUTER.? ^^^ FullOuter
203204
)
204205

206+
protected lazy val sortType: Parser[LogicalPlan => LogicalPlan] =
207+
( ORDER ~ BY ~> ordering ^^ { case o => l: LogicalPlan => Sort(o, l) }
208+
| SORT ~ BY ~> ordering ^^ { case o => l: LogicalPlan => SortPartitions(o, l) }
209+
)
210+
205211
protected lazy val ordering: Parser[Seq[SortOrder]] =
206212
( rep1sep(singleOrder, ",")
207213
| rep1sep(expression, ",") ~ direction.? ^^ {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,8 @@ package object dsl {
246246

247247
def orderBy(sortExprs: SortOrder*) = Sort(sortExprs, logicalPlan)
248248

249+
def sortBy(sortExprs: SortOrder*) = SortPartitions(sortExprs, logicalPlan)
250+
249251
def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*) = {
250252
val aliasedExprs = aggregateExprs.map {
251253
case ne: NamedExpression => ne

sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,19 @@ class SchemaRDD(
216216
def orderBy(sortExprs: SortOrder*): SchemaRDD =
217217
new SchemaRDD(sqlContext, Sort(sortExprs, logicalPlan))
218218

219+
/**
220+
* Sorts the results by the given expressions within partition.
221+
* {{{
222+
* schemaRDD.sortBy('a)
223+
* schemaRDD.sortBy('a, 'b)
224+
* schemaRDD.sortBy('a.asc, 'b.desc)
225+
* }}}
226+
*
227+
* @group Query
228+
*/
229+
def sortBy(sortExprs: SortOrder*): SchemaRDD =
230+
new SchemaRDD(sqlContext, SortPartitions(sortExprs, logicalPlan))
231+
219232
@deprecated("use limit with integer argument", "1.1.0")
220233
def limit(limitExpr: Expression): SchemaRDD =
221234
new SchemaRDD(sqlContext, Limit(limitExpr, logicalPlan))

sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,24 @@ class DslQuerySuite extends QueryTest {
120120
mapData.collect().sortBy(_.data(1)).reverse.toSeq)
121121
}
122122

123+
test("sorting #2") {
124+
checkAnswer(
125+
testData2.sortBy('a.asc, 'b.asc),
126+
Seq((1,1), (1,2), (2,1), (2,2), (3,1), (3,2)))
127+
128+
checkAnswer(
129+
testData2.sortBy('a.asc, 'b.desc),
130+
Seq((1,2), (1,1), (2,2), (2,1), (3,2), (3,1)))
131+
132+
checkAnswer(
133+
testData2.sortBy('a.desc, 'b.desc),
134+
Seq((3,2), (3,1), (2,2), (2,1), (1,2), (1,1)))
135+
136+
checkAnswer(
137+
testData2.sortBy('a.desc, 'b.asc),
138+
Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
139+
}
140+
123141
test("limit") {
124142
checkAnswer(
125143
testData.limit(10),

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
4242
TimeZone.setDefault(origZone)
4343
}
4444

45+
test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") {
46+
checkAnswer(
47+
sql("SELECT a FROM testData2 SORT BY a"),
48+
Seq(1, 1, 2 ,2 ,3 ,3).map(Seq(_))
49+
)
50+
}
51+
4552
test("grouping on nested fields") {
4653
jsonRDD(sparkContext.parallelize("""{"nested": {"attribute": 1}, "value": 2}""" :: Nil))
4754
.registerTempTable("rows")

0 commit comments

Comments
 (0)