Skip to content

Commit 343ae27

Browse files
committed
[SPARK-4943][SQL] refactoring according to review
1 parent 29e5e55 commit 343ae27

File tree

15 files changed

+73
-69
lines changed

15 files changed

+73
-69
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -178,22 +178,9 @@ class SqlParser extends AbstractSparkSQLParser {
178178
joinedRelation | relationFactor
179179

180180
protected lazy val relationFactor: Parser[LogicalPlan] =
181-
(
182-
ident ~ ("." ~> ident) ~ ("." ~> ident) ~ ("." ~> ident) ~ (opt(AS) ~> opt(ident)) ^^ {
183-
case reserveName1 ~ reserveName2 ~ dbName ~ tableName ~ alias =>
184-
UnresolvedRelation(IndexedSeq(tableName, dbName, reserveName2, reserveName1), alias)
181+
( rep1sep(ident, ".") ~ (opt(AS) ~> opt(ident)) ^^ {
182+
case tableIdent ~ alias => UnresolvedRelation(tableIdent, alias)
185183
}
186-
| ident ~ ("." ~> ident) ~ ("." ~> ident) ~ (opt(AS) ~> opt(ident)) ^^ {
187-
case reserveName1 ~ dbName ~ tableName ~ alias =>
188-
UnresolvedRelation(IndexedSeq(tableName, dbName, reserveName1), alias)
189-
}
190-
| ident ~ ("." ~> ident) ~ (opt(AS) ~> opt(ident)) ^^ {
191-
case dbName ~ tableName ~ alias =>
192-
UnresolvedRelation(IndexedSeq(tableName, dbName), alias)
193-
}
194-
| ident ~ (opt(AS) ~> opt(ident)) ^^ {
195-
case tableName ~ alias => UnresolvedRelation(IndexedSeq(tableName), alias)
196-
}
197184
| ("(" ~> start <~ ")") ~ (AS.? ~> ident) ^^ { case s ~ a => Subquery(a, s) }
198185
)
199186

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -31,24 +31,35 @@ trait Catalog {
3131
def tableExists(tableIdentifier: Seq[String]): Boolean
3232

3333
def lookupRelation(
34-
tableIdentifier: Seq[String],
35-
alias: Option[String] = None): LogicalPlan
34+
tableIdentifier: Seq[String],
35+
alias: Option[String] = None): LogicalPlan
3636

3737
def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit
3838

3939
def unregisterTable(tableIdentifier: Seq[String]): Unit
4040

4141
def unregisterAllTables(): Unit
4242

43-
protected def processTableIdentifier(tableIdentifier: Seq[String]):
44-
Seq[String] = {
43+
protected def processTableIdentifier(tableIdentifier: Seq[String]): Seq[String] = {
4544
if (!caseSensitive) {
4645
tableIdentifier.map(_.toLowerCase)
4746
} else {
4847
tableIdentifier
4948
}
5049
}
5150

51+
protected def getDbTableName(tableIdent: Seq[String]): String = {
52+
val size = tableIdent.size
53+
if (size <= 2) {
54+
tableIdent.mkString(".")
55+
} else {
56+
tableIdent.slice(size - 2, size).mkString(".")
57+
}
58+
}
59+
60+
protected def getDBTable(tableIdent: Seq[String]) : (Option[String], String) = {
61+
(tableIdent.lift(tableIdent.size - 2), tableIdent.last)
62+
}
5263
}
5364

5465
class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
@@ -58,12 +69,12 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
5869
tableIdentifier: Seq[String],
5970
plan: LogicalPlan): Unit = {
6071
val tableIdent = processTableIdentifier(tableIdentifier)
61-
tables += ((tableIdent.mkString("."), plan))
72+
tables += ((getDbTableName(tableIdent), plan))
6273
}
6374

6475
override def unregisterTable(tableIdentifier: Seq[String]) = {
6576
val tableIdent = processTableIdentifier(tableIdentifier)
66-
tables -= tableIdent.mkString(".")
77+
tables -= getDbTableName(tableIdent)
6778
}
6879

6980
override def unregisterAllTables() = {
@@ -72,7 +83,7 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
7283

7384
override def tableExists(tableIdentifier: Seq[String]): Boolean = {
7485
val tableIdent = processTableIdentifier(tableIdentifier)
75-
tables.get(tableIdent.mkString(".")) match {
86+
tables.get(getDbTableName(tableIdent)) match {
7687
case Some(_) => true
7788
case None => false
7889
}
@@ -82,9 +93,9 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
8293
tableIdentifier: Seq[String],
8394
alias: Option[String] = None): LogicalPlan = {
8495
val tableIdent = processTableIdentifier(tableIdentifier)
85-
val tableFullName = tableIdent.mkString(".")
96+
val tableFullName = getDbTableName(tableIdent)
8697
val table = tables.getOrElse(tableFullName, sys.error(s"Table Not Found: $tableFullName"))
87-
val tableWithQualifiers = Subquery(tableIdent.head, table)
98+
val tableWithQualifiers = Subquery(tableIdent.last, table)
8899

89100
// If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
90101
// properly qualified with this alias.
@@ -101,11 +112,11 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
101112
trait OverrideCatalog extends Catalog {
102113

103114
// TODO: This doesn't work when the database changes...
104-
val overrides = new mutable.HashMap[String, LogicalPlan]()
115+
val overrides = new mutable.HashMap[(Option[String],String), LogicalPlan]()
105116

106117
abstract override def tableExists(tableIdentifier: Seq[String]): Boolean = {
107-
val tableIdent = processTableIdentifier(tableIdentifier).mkString(".")
108-
overrides.get(tableIdent) match {
118+
val tableIdent = processTableIdentifier(tableIdentifier)
119+
overrides.get(getDBTable(tableIdent)) match {
109120
case Some(_) => true
110121
case None => super.tableExists(tableIdentifier)
111122
}
@@ -115,8 +126,8 @@ trait OverrideCatalog extends Catalog {
115126
tableIdentifier: Seq[String],
116127
alias: Option[String] = None): LogicalPlan = {
117128
val tableIdent = processTableIdentifier(tableIdentifier)
118-
val overriddenTable = overrides.get(tableIdent.mkString("."))
119-
val tableWithQualifers = overriddenTable.map(r => Subquery(tableIdent.head, r))
129+
val overriddenTable = overrides.get(getDBTable(tableIdent))
130+
val tableWithQualifers = overriddenTable.map(r => Subquery(tableIdent.last, r))
120131

121132
// If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are
122133
// properly qualified with this alias.
@@ -129,13 +140,13 @@ trait OverrideCatalog extends Catalog {
129140
override def registerTable(
130141
tableIdentifier: Seq[String],
131142
plan: LogicalPlan): Unit = {
132-
val tableIdent = processTableIdentifier(tableIdentifier).mkString(".")
133-
overrides.put(tableIdent, plan)
143+
val tableIdent = processTableIdentifier(tableIdentifier)
144+
overrides.put(getDBTable(tableIdent), plan)
134145
}
135146

136147
override def unregisterTable(tableIdentifier: Seq[String]): Unit = {
137-
val tableIdent = processTableIdentifier(tableIdentifier).mkString(".")
138-
overrides.remove(tableIdent)
148+
val tableIdent = processTableIdentifier(tableIdentifier)
149+
overrides.remove(getDBTable(tableIdent))
139150
}
140151

141152
override def unregisterAllTables(): Unit = {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ package object dsl {
290290

291291
def insertInto(tableName: String, overwrite: Boolean = false) =
292292
InsertIntoTable(
293-
analysis.UnresolvedRelation(IndexedSeq(tableName)), Map.empty, logicalPlan, overwrite)
293+
analysis.UnresolvedRelation(Seq(tableName)), Map.empty, logicalPlan, overwrite)
294294

295295
def analyze = analysis.SimpleAnalyzer(logicalPlan)
296296
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
4444
AttributeReference("e", ShortType)())
4545

4646
before {
47-
caseSensitiveCatalog.registerTable(IndexedSeq("TaBlE"), testRelation)
48-
caseInsensitiveCatalog.registerTable(IndexedSeq("TaBlE"), testRelation)
47+
caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
48+
caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
4949
}
5050

5151
test("union project *") {
@@ -64,45 +64,45 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
6464
assert(
6565
caseSensitiveAnalyze(
6666
Project(Seq(UnresolvedAttribute("TbL.a")),
67-
UnresolvedRelation(IndexedSeq("TaBlE"), Some("TbL")))) ===
67+
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
6868
Project(testRelation.output, testRelation))
6969

7070
val e = intercept[TreeNodeException[_]] {
7171
caseSensitiveAnalyze(
7272
Project(Seq(UnresolvedAttribute("tBl.a")),
73-
UnresolvedRelation(IndexedSeq("TaBlE"), Some("TbL"))))
73+
UnresolvedRelation(Seq("TaBlE"), Some("TbL"))))
7474
}
7575
assert(e.getMessage().toLowerCase.contains("unresolved"))
7676

7777
assert(
7878
caseInsensitiveAnalyze(
7979
Project(Seq(UnresolvedAttribute("TbL.a")),
80-
UnresolvedRelation(IndexedSeq("TaBlE"), Some("TbL")))) ===
80+
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
8181
Project(testRelation.output, testRelation))
8282

8383
assert(
8484
caseInsensitiveAnalyze(
8585
Project(Seq(UnresolvedAttribute("tBl.a")),
86-
UnresolvedRelation(IndexedSeq("TaBlE"), Some("TbL")))) ===
86+
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
8787
Project(testRelation.output, testRelation))
8888
}
8989

9090
test("resolve relations") {
9191
val e = intercept[RuntimeException] {
92-
caseSensitiveAnalyze(UnresolvedRelation(IndexedSeq("tAbLe"), None))
92+
caseSensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None))
9393
}
9494
assert(e.getMessage == "Table Not Found: tAbLe")
9595

9696
assert(
97-
caseSensitiveAnalyze(UnresolvedRelation(IndexedSeq("TaBlE"), None)) ===
97+
caseSensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) ===
9898
testRelation)
9999

100100
assert(
101-
caseInsensitiveAnalyze(UnresolvedRelation(IndexedSeq("tAbLe"), None)) ===
101+
caseInsensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None)) ===
102102
testRelation)
103103

104104
assert(
105-
caseInsensitiveAnalyze(UnresolvedRelation(IndexedSeq("TaBlE"), None)) ===
105+
caseInsensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) ===
106106
testRelation)
107107
}
108108

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class DecimalPrecisionSuite extends FunSuite with BeforeAndAfter {
4141
val f: Expression = UnresolvedAttribute("f")
4242

4343
before {
44-
catalog.registerTable(IndexedSeq("table"), relation)
44+
catalog.registerTable(Seq("table"), relation)
4545
}
4646

4747
private def checkType(expression: Expression, expectedType: DataType): Unit = {

sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
276276
* @group userf
277277
*/
278278
def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = {
279-
catalog.registerTable(IndexedSeq(tableName), rdd.queryExecution.logical)
279+
catalog.registerTable(Seq(tableName), rdd.queryExecution.logical)
280280
}
281281

282282
/**
@@ -289,7 +289,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
289289
*/
290290
def dropTempTable(tableName: String): Unit = {
291291
tryUncacheQuery(table(tableName))
292-
catalog.unregisterTable(IndexedSeq(tableName))
292+
catalog.unregisterTable(Seq(tableName))
293293
}
294294

295295
/**
@@ -308,7 +308,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
308308

309309
/** Returns the specified table as a SchemaRDD */
310310
def table(tableName: String): SchemaRDD =
311-
new SchemaRDD(this, catalog.lookupRelation(IndexedSeq(tableName)))
311+
new SchemaRDD(this, catalog.lookupRelation(Seq(tableName)))
312312

313313
/**
314314
* :: DeveloperApi ::

sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ private[sql] trait SchemaRDDLike {
9797
*/
9898
@Experimental
9999
def insertInto(tableName: String, overwrite: Boolean): Unit =
100-
sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(IndexedSeq(tableName)),
100+
sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)),
101101
Map.empty, logicalPlan, overwrite)).toRdd
102102

103103
/**

sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,8 +302,8 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
302302
upperCaseData.where('N <= 4).registerTempTable("left")
303303
upperCaseData.where('N >= 3).registerTempTable("right")
304304

305-
val left = UnresolvedRelation(IndexedSeq("left"), None)
306-
val right = UnresolvedRelation(IndexedSeq("right"), None)
305+
val left = UnresolvedRelation(Seq("left"), None)
306+
val right = UnresolvedRelation(Seq("right"), None)
307307

308308
checkAnswer(
309309
left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)),

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
124124
* in the Hive metastore.
125125
*/
126126
def analyze(tableName: String) {
127-
val relation = EliminateAnalysisOperators(catalog.lookupRelation(IndexedSeq(tableName)))
127+
val relation = EliminateAnalysisOperators(catalog.lookupRelation(Seq(tableName)))
128128

129129
relation match {
130130
case relation: MetastoreRelation =>

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,9 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
6060

6161
def tableExists(tableIdentifier: Seq[String]): Boolean = {
6262
val tableIdent = processTableIdentifier(tableIdentifier)
63-
val (databaseName, tblName) =
64-
(tableIdent.lift(1).getOrElse(hive.sessionState.getCurrentDatabase), tableIdent.head)
63+
val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse(
64+
hive.sessionState.getCurrentDatabase)
65+
val tblName = tableIdent.last
6566
try {
6667
client.getTable(databaseName, tblName) != null
6768
} catch {
@@ -73,8 +74,9 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
7374
tableIdentifier: Seq[String],
7475
alias: Option[String]): LogicalPlan = synchronized {
7576
val tableIdent = processTableIdentifier(tableIdentifier)
76-
val (databaseName, tblName) =
77-
(tableIdent.lift(1).getOrElse(hive.sessionState.getCurrentDatabase), tableIdent.head)
77+
val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse(
78+
hive.sessionState.getCurrentDatabase)
79+
val tblName = tableIdent.last
7880
val table = client.getTable(databaseName, tblName)
7981
if (table.isView) {
8082
// if the unresolved relation is from hive view
@@ -296,7 +298,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
296298
val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase)
297299

298300
// Get the CreateTableDesc from Hive SemanticAnalyzer
299-
val desc: Option[CreateTableDesc] = if (tableExists(IndexedSeq(tblName, databaseName))) {
301+
val desc: Option[CreateTableDesc] = if (tableExists(Seq(databaseName, tblName))) {
300302
None
301303
} else {
302304
val sa = new SemanticAnalyzer(hive.hiveconf) {

0 commit comments

Comments (0)