Skip to content

Commit af55afd

Browse files
committed
Fix SPARK-12139: REGEX Column Specification for Hive Queries
1 parent e3d2022 commit af55afd

File tree

3 files changed

+152
-7
lines changed

3 files changed

+152
-7
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,33 @@ case class UnresolvedTableValuedFunction(
8383
override lazy val resolved = false
8484
}
8585

86+
/**
87+
* Represents all of the input attributes to a given relational operator, for example in
88+
* "SELECT `(id)?+.+` FROM ...".
89+
*
90+
* @param table an optional table that should be the target of the expansion. If omitted all
91+
* tables' columns are produced.
92+
*/
93+
/**
 * Represents all of the input attributes to a given relational operator that match a
 * regular expression, for example in "SELECT `(id)?+.+` FROM ...".
 *
 * @param expr a Java regular expression; input attributes whose names match it are expanded.
 * @param table an optional table that should be the target of the expansion. If omitted all
 *              tables' columns are considered.
 */
case class UnresolvedRegex(expr: String, table: Option[String]) extends Star with Unevaluable {
  override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = {
    // The filtered attributes are already NamedExpressions, so they can be returned
    // directly: no aliasing pass is needed (the former zip-with-input.output/Alias
    // mapping was dead code — every element matched the NamedExpression case).
    table match {
      // If there is no table specified, use all input attributes that match expr.
      case None =>
        input.output.filter(_.name.matches(expr))
      // If there is a table, restrict to attributes qualified by that table (the
      // resolver decides case sensitivity), then keep those whose names match expr.
      case Some(t) =>
        input.output
          .filter(_.qualifier.exists(resolver(_, t)))
          .filter(_.name.matches(expr))
    }
  }

  override def toString: String = table.map(_ + ".").getOrElse("") + expr
}
112+
86113
/**
87114
* Holds the name of an attribute that has yet to be resolved.
88115
*/

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,9 @@ import javax.xml.bind.DatatypeConverter
2323

2424
import scala.collection.JavaConverters._
2525
import scala.collection.mutable.ArrayBuffer
26-
2726
import org.antlr.v4.runtime.{ParserRuleContext, Token}
2827
import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode}
29-
28+
import org.apache.spark.SparkEnv
3029
import org.apache.spark.internal.Logging
3130
import org.apache.spark.sql.AnalysisException
3231
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
@@ -1229,25 +1228,56 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
12291228
CaseWhen(branches, Option(ctx.elseExpression).map(expression))
12301229
}
12311230

1231+
/**
 * Whether Hive-style regex column specification is enabled, i.e. the
 * "hive.support.quoted.identifiers" flag is set to true in the active SparkConf.
 * Defensive against a missing SparkEnv or conf (e.g. in some test setups):
 * either being absent yields false.
 */
def enableHiveSupportQuotedIdentifiers(): Boolean = {
  Option(SparkEnv.get)
    .flatMap(env => Option(env.conf))
    .exists(_.getBoolean("hive.support.quoted.identifiers", false))
}
1236+
12321237
/**
1233-
* Create a dereference expression. The return type depends on the type of the parent, this can
1234-
* either be a [[UnresolvedAttribute]] (if the parent is an [[UnresolvedAttribute]]), or an
1235-
* [[UnresolvedExtractValue]] if the parent is some expression.
1238+
* Create a dereference expression. The return type depends on the type of the parent.
1239+
* If the parent is an [[UnresolvedAttribute]], it can be a [[UnresolvedAttribute]] or
1240+
* a [[UnresolvedRegex]] for regex quoted in ``; if the parent is some other expression,
1241+
* it can be [[UnresolvedExtractValue]].
12361242
*/
12371243
override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) {
12381244
val attr = ctx.fieldName.getText
12391245
expression(ctx.base) match {
1240-
case UnresolvedAttribute(nameParts) =>
1246+
case unresolved_attr @ UnresolvedAttribute(nameParts) =>
1247+
if (enableHiveSupportQuotedIdentifiers) {
1248+
val escapedIdentifier = "`(.+)`".r
1249+
val ret = Option(ctx.fieldName.getStart).map(_.getText match {
1250+
case r@escapedIdentifier(i) =>
1251+
UnresolvedRegex(i, Some(unresolved_attr.name))
1252+
case _ =>
1253+
UnresolvedAttribute(nameParts :+ attr)
1254+
})
1255+
return ret.get
1256+
}
1257+
12411258
UnresolvedAttribute(nameParts :+ attr)
12421259
case e =>
12431260
UnresolvedExtractValue(e, Literal(attr))
12441261
}
12451262
}
12461263

12471264
/**
1248-
* Create an [[UnresolvedAttribute]] expression.
1265+
* Create an [[UnresolvedAttribute]] expression or a [[UnresolvedRegex]] if it is a regex
1266+
* quoted in ``
12491267
*/
12501268
override def visitColumnReference(ctx: ColumnReferenceContext): Expression = withOrigin(ctx) {
1269+
if (enableHiveSupportQuotedIdentifiers) {
1270+
val escapedIdentifier = "`(.+)`".r
1271+
val ret = Option(ctx.getStart).map(_.getText match {
1272+
case r @ escapedIdentifier(i) =>
1273+
UnresolvedRegex(i, None)
1274+
case _ =>
1275+
UnresolvedAttribute.quoted(ctx.getText)
1276+
})
1277+
1278+
return ret.get
1279+
}
1280+
12511281
UnresolvedAttribute.quoted(ctx.getText)
12521282
}
12531283

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2624,4 +2624,92 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
26242624
val e = intercept[AnalysisException](sql("SELECT nvl(1, 2, 3)"))
26252625
assert(e.message.contains("Invalid number of arguments"))
26262626
}
2627+
2628+
/**
 * SPARK-12139: backquoted regex column specification (Hive compatibility).
 * The feature is gated on "hive.support.quoted.identifiers"; the flag is restored
 * in a finally block so this test cannot leak state into other tests in the suite.
 */
test("SPARK-12139: REGEX Column Specification for Hive Queries") {
  // hive.support.quoted.identifiers is turned off by default.
  checkAnswer(
    sql(
      """
        |SELECT b
        |FROM testData2
        |WHERE a = 1
      """.stripMargin),
    Row(1) :: Row(2) :: Nil)

  checkAnswer(
    sql(
      """
        |SELECT t.b
        |FROM testData2 t
        |WHERE a = 1
      """.stripMargin),
    Row(1) :: Row(2) :: Nil)

  // With the flag off, a backquoted regex is parsed as a literal column name
  // and must fail to resolve.
  intercept[AnalysisException] {
    sql(
      """
        |SELECT `(a)?+.+`
        |FROM testData2
        |WHERE a = 1
      """.stripMargin)
  }

  intercept[AnalysisException] {
    sql(
      """
        |SELECT t.`(a)?+.+`
        |FROM testData2 t
        |WHERE a = 1
      """.stripMargin)
  }

  // Now turn on hive.support.quoted.identifiers; restore it afterwards so the
  // setting does not leak into subsequent tests.
  sparkContext.conf.set("hive.support.quoted.identifiers", "true")
  try {
    checkAnswer(
      sql(
        """
          |SELECT b
          |FROM testData2
          |WHERE a = 1
        """.stripMargin),
      Row(1) :: Row(2) :: Nil)

    checkAnswer(
      sql(
        """
          |SELECT t.b
          |FROM testData2 t
          |WHERE a = 1
        """.stripMargin),
      Row(1) :: Row(2) :: Nil)

    checkAnswer(
      sql(
        """
          |SELECT `(a)?+.+`
          |FROM testData2
          |WHERE a = 1
        """.stripMargin),
      Row(1) :: Row(2) :: Nil)

    checkAnswer(
      sql(
        """
          |SELECT t.`(a)?+.+`
          |FROM testData2 t
          |WHERE a = 1
        """.stripMargin),
      Row(1) :: Row(2) :: Nil)

    checkAnswer(
      sql(
        """
          |SELECT p.`(key)?+.+`, b, testdata2.`(b)?+.+`
          |FROM testData p join testData2
          |ON p.key = testData2.a
          |WHERE key < 3
        """.stripMargin),
      Row("1", 1, 1) :: Row("1", 2, 1) :: Row("2", 1, 2) :: Row("2", 2, 2) :: Nil)
  } finally {
    // Restore the default so other tests see the feature disabled again.
    sparkContext.conf.set("hive.support.quoted.identifiers", "false")
  }
}
26272715
}

0 commit comments

Comments
 (0)