Skip to content

Commit 709684f

Browse files
committed
Changed parser to support case when function.
1 parent c9ae79f commit 709684f

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,11 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
128128
protected val UNION = Keyword("UNION")
129129
protected val UPPER = Keyword("UPPER")
130130
protected val WHERE = Keyword("WHERE")
131+
protected val CASE = Keyword("CASE")
132+
protected val WHEN = Keyword("WHEN")
133+
protected val THEN = Keyword("THEN")
134+
protected val ELSE = Keyword("ELSE")
135+
protected val END = Keyword("END")
131136

132137
// Use reflection to find the reserved words defined in this class.
133138
protected val reservedWords =
@@ -333,6 +338,24 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
333338
IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ {
334339
case c ~ "," ~ t ~ "," ~ f => If(c,t,f)
335340
} |
341+
CASE ~> opt(expression) ~ (WHEN ~ expression ~ THEN ~ expression).* ~
342+
opt(ELSE ~> expression) <~ END ^^ {
343+
case c ~ l ~ el =>
344+
var caseWhenExpr = l.map{x =>
345+
x match {
346+
case w ~ we ~ t ~ te =>
347+
c match {
348+
case Some(e) => Seq(EqualTo(e, we), te)
349+
case None => Seq(we, te)
350+
}
351+
}
352+
}.toSeq.reduce(_ ++ _)
353+
caseWhenExpr = el match {
354+
case Some(e) => caseWhenExpr ++ Seq(e)
355+
case None => caseWhenExpr
356+
}
357+
CaseWhen(caseWhenExpr)
358+
} |
336359
(SUBSTR | SUBSTRING) ~> "(" ~> expression ~ "," ~ expression <~ ")" ^^ {
337360
case s ~ "," ~ p => Substring(s,p,Literal(Integer.MAX_VALUE))
338361
} |

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -680,9 +680,20 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
680680
sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"),
681681
("true", "false") :: Nil)
682682
}
683-
683+
684684
test("SPARK-3371 Renaming a function expression with group by gives error") {
685685
registerFunction("len", (s: String) => s.length)
686686
checkAnswer(
687-
sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1)}
687+
sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1)
688+
}
689+
690+
test("SPARK-3813 CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END") {
691+
checkAnswer(
692+
sql("SELECT CASE key WHEN 1 THEN 1 ELSE 0 END FROM testData WHERE key = 1 group by key"), 1)
693+
}
694+
695+
test("SPARK-3813 CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END") {
696+
checkAnswer(
697+
sql("SELECT CASE WHEN key=1 THEN 1 ELSE 2 END FROM testData WHERE key = 1 group by key"), 1)
698+
}
688699
}

0 commit comments

Comments
 (0)