Skip to content

Commit 6d9fa2f

Browse files
author
Andrew Or
committed
Keep track of current database in SessionCatalog
This allows us to not pass it into every single method like we used to before this commit.
1 parent ff1c2c4 commit 6d9fa2f

File tree

2 files changed

+173
-172
lines changed

2 files changed

+173
-172
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala

Lines changed: 44 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,14 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
3535
import ExternalCatalog._
3636

3737
private[this] val tempTables = new ConcurrentHashMap[String, LogicalPlan]
38-
3938
private[this] val tempFunctions = new ConcurrentHashMap[String, CatalogFunction]
4039

40+
// Note: we track current database here because certain operations do not explicitly
41+
// specify the database (e.g. DROP TABLE my_table). In these cases we must first
42+
// check whether the temporary table or function exists, then, if not, operate on
43+
// the corresponding item in the current database.
44+
private[this] var currentDb = "default"
45+
4146
// ----------------------------------------------------------------------------
4247
// Databases
4348
// ----------------------------------------------------------------------------
@@ -72,6 +77,12 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
7277
externalCatalog.listDatabases(pattern)
7378
}
7479

80+
def getCurrentDatabase: String = currentDb
81+
82+
def setCurrentDatabase(db: String): Unit = {
83+
currentDb = db
84+
}
85+
7586
// ----------------------------------------------------------------------------
7687
// Tables
7788
// ----------------------------------------------------------------------------
@@ -89,10 +100,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
89100
* Create a metastore table in the database specified in `tableDefinition`.
90101
* If no such database is specified, create it in the current database.
91102
*/
92-
def createTable(
93-
currentDb: String,
94-
tableDefinition: CatalogTable,
95-
ignoreIfExists: Boolean): Unit = {
103+
def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = {
96104
val db = tableDefinition.name.database.getOrElse(currentDb)
97105
val newTableDefinition = tableDefinition.copy(
98106
name = TableIdentifier(tableDefinition.name.table, Some(db)))
@@ -108,7 +116,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
108116
* Note: If the underlying implementation does not support altering a certain field,
109117
* this becomes a no-op.
110118
*/
111-
def alterTable(currentDb: String, tableDefinition: CatalogTable): Unit = {
119+
def alterTable(tableDefinition: CatalogTable): Unit = {
112120
val db = tableDefinition.name.database.getOrElse(currentDb)
113121
val newTableDefinition = tableDefinition.copy(
114122
name = TableIdentifier(tableDefinition.name.table, Some(db)))
@@ -119,7 +127,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
119127
* Retrieve the metadata of an existing metastore table.
120128
* If no database is specified, assume the table is in the current database.
121129
*/
122-
def getTable(currentDb: String, name: TableIdentifier): CatalogTable = {
130+
def getTable(name: TableIdentifier): CatalogTable = {
123131
val db = name.database.getOrElse(currentDb)
124132
externalCatalog.getTable(db, name.table)
125133
}
@@ -150,10 +158,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
150158
*
151159
* This assumes the database specified in `oldName` matches the one specified in `newName`.
152160
*/
153-
def renameTable(
154-
currentDb: String,
155-
oldName: TableIdentifier,
156-
newName: TableIdentifier): Unit = {
161+
def renameTable(oldName: TableIdentifier, newName: TableIdentifier): Unit = {
157162
if (oldName.database != newName.database) {
158163
throw new AnalysisException("rename does not support moving tables across databases")
159164
}
@@ -173,10 +178,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
173178
* If no database is specified, this will first attempt to drop a temporary table with
174179
* the same name, then, if that does not exist, drop the table from the current database.
175180
*/
176-
def dropTable(
177-
currentDb: String,
178-
name: TableIdentifier,
179-
ignoreIfNotExists: Boolean): Unit = {
181+
def dropTable(name: TableIdentifier, ignoreIfNotExists: Boolean): Unit = {
180182
val db = name.database.getOrElse(currentDb)
181183
if (name.database.isDefined || !tempTables.containsKey(name.table)) {
182184
externalCatalog.dropTable(db, name.table, ignoreIfNotExists)
@@ -192,10 +194,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
192194
* If no database is specified, this will first attempt to return a temporary table with
193195
* the same name, then, if that does not exist, return the table from the current database.
194196
*/
195-
def lookupRelation(
196-
currentDb: String,
197-
name: TableIdentifier,
198-
alias: Option[String] = None): LogicalPlan = {
197+
def lookupRelation(name: TableIdentifier, alias: Option[String] = None): LogicalPlan = {
199198
val db = name.database.getOrElse(currentDb)
200199
val relation =
201200
if (name.database.isDefined || !tempTables.containsKey(name.table)) {
@@ -211,28 +210,25 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
211210
}
212211

213212
/**
214-
* List all tables in the current database, including temporary tables.
213+
* List all tables in the specified database, including temporary tables.
215214
*/
216-
def listTables(currentDb: String): Seq[TableIdentifier] = {
217-
val tablesInCurrentDb = externalCatalog.listTables(currentDb).map { t =>
218-
TableIdentifier(t, Some(currentDb))
219-
}
215+
def listTables(db: String): Seq[TableIdentifier] = {
216+
val dbTables = externalCatalog.listTables(db).map { t => TableIdentifier(t, Some(db)) }
220217
val _tempTables = tempTables.keys().asScala.map { t => TableIdentifier(t) }
221-
tablesInCurrentDb ++ _tempTables
218+
dbTables ++ _tempTables
222219
}
223220

224221
/**
225-
* List all matching tables in the current database, including temporary tables.
222+
* List all matching tables in the specified database, including temporary tables.
226223
*/
227-
def listTables(currentDb: String, pattern: String): Seq[TableIdentifier] = {
228-
val tablesInCurrentDb = externalCatalog.listTables(currentDb, pattern).map { t =>
229-
TableIdentifier(t, Some(currentDb))
230-
}
224+
def listTables(db: String, pattern: String): Seq[TableIdentifier] = {
225+
val dbTables =
226+
externalCatalog.listTables(db, pattern).map { t => TableIdentifier(t, Some(currentDb)) }
231227
val regex = pattern.replaceAll("\\*", ".*").r
232228
val _tempTables = tempTables.keys().asScala
233229
.filter { t => regex.pattern.matcher(t).matches() }
234230
.map { t => TableIdentifier(t) }
235-
tablesInCurrentDb ++ _tempTables
231+
dbTables ++ _tempTables
236232
}
237233

238234
/**
@@ -260,7 +256,6 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
260256
* If no database is specified, assume the table is in the current database.
261257
*/
262258
def createPartitions(
263-
currentDb: String,
264259
tableName: TableIdentifier,
265260
parts: Seq[CatalogTablePartition],
266261
ignoreIfExists: Boolean): Unit = {
@@ -273,7 +268,6 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
273268
* If no database is specified, assume the table is in the current database.
274269
*/
275270
def dropPartitions(
276-
currentDb: String,
277271
tableName: TableIdentifier,
278272
parts: Seq[TablePartitionSpec],
279273
ignoreIfNotExists: Boolean): Unit = {
@@ -288,7 +282,6 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
288282
* If no database is specified, assume the table is in the current database.
289283
*/
290284
def renamePartitions(
291-
currentDb: String,
292285
tableName: TableIdentifier,
293286
specs: Seq[TablePartitionSpec],
294287
newSpecs: Seq[TablePartitionSpec]): Unit = {
@@ -305,10 +298,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
305298
* Note: If the underlying implementation does not support altering a certain field,
306299
* this becomes a no-op.
307300
*/
308-
def alterPartitions(
309-
currentDb: String,
310-
tableName: TableIdentifier,
311-
parts: Seq[CatalogTablePartition]): Unit = {
301+
def alterPartitions(tableName: TableIdentifier, parts: Seq[CatalogTablePartition]): Unit = {
312302
val db = tableName.database.getOrElse(currentDb)
313303
externalCatalog.alterPartitions(db, tableName.table, parts)
314304
}
@@ -317,10 +307,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
317307
* Retrieve the metadata of a table partition, assuming it exists.
318308
* If no database is specified, assume the table is in the current database.
319309
*/
320-
def getPartition(
321-
currentDb: String,
322-
tableName: TableIdentifier,
323-
spec: TablePartitionSpec): CatalogTablePartition = {
310+
def getPartition(tableName: TableIdentifier, spec: TablePartitionSpec): CatalogTablePartition = {
324311
val db = tableName.database.getOrElse(currentDb)
325312
externalCatalog.getPartition(db, tableName.table, spec)
326313
}
@@ -329,9 +316,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
329316
* List all partitions in a table, assuming it exists.
330317
* If no database is specified, assume the table is in the current database.
331318
*/
332-
def listPartitions(
333-
currentDb: String,
334-
tableName: TableIdentifier): Seq[CatalogTablePartition] = {
319+
def listPartitions(tableName: TableIdentifier): Seq[CatalogTablePartition] = {
335320
val db = tableName.database.getOrElse(currentDb)
336321
externalCatalog.listPartitions(db, tableName.table)
337322
}
@@ -353,7 +338,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
353338
* Create a metastore function in the database specified in `funcDefinition`.
354339
* If no such database is specified, create it in the current database.
355340
*/
356-
def createFunction(currentDb: String, funcDefinition: CatalogFunction): Unit = {
341+
def createFunction(funcDefinition: CatalogFunction): Unit = {
357342
val db = funcDefinition.name.database.getOrElse(currentDb)
358343
val newFuncDefinition = funcDefinition.copy(
359344
name = FunctionIdentifier(funcDefinition.name.funcName, Some(db)))
@@ -364,7 +349,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
364349
* Drop a metastore function.
365350
* If no database is specified, assume the function is in the current database.
366351
*/
367-
def dropFunction(currentDb: String, name: FunctionIdentifier): Unit = {
352+
def dropFunction(name: FunctionIdentifier): Unit = {
368353
val db = name.database.getOrElse(currentDb)
369354
externalCatalog.dropFunction(db, name.funcName)
370355
}
@@ -378,7 +363,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
378363
* Note: If the underlying implementation does not support altering a certain field,
379364
* this becomes a no-op.
380365
*/
381-
def alterFunction(currentDb: String, funcDefinition: CatalogFunction): Unit = {
366+
def alterFunction(funcDefinition: CatalogFunction): Unit = {
382367
val db = funcDefinition.name.database.getOrElse(currentDb)
383368
val newFuncDefinition = funcDefinition.copy(
384369
name = FunctionIdentifier(funcDefinition.name.funcName, Some(db)))
@@ -393,9 +378,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
393378
* Create a temporary function.
394379
* This assumes no database is specified in `funcDefinition`.
395380
*/
396-
def createTempFunction(
397-
funcDefinition: CatalogFunction,
398-
ignoreIfExists: Boolean): Unit = {
381+
def createTempFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = {
399382
require(funcDefinition.name.database.isEmpty,
400383
"attempted to create a temporary function while specifying a database")
401384
val name = funcDefinition.name.funcName
@@ -428,10 +411,10 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
428411
*
429412
* This assumes the database specified in `oldName` matches the one specified in `newName`.
430413
*/
431-
def renameFunction(
432-
currentDb: String,
433-
oldName: FunctionIdentifier,
434-
newName: FunctionIdentifier): Unit = {
414+
def renameFunction(oldName: FunctionIdentifier, newName: FunctionIdentifier): Unit = {
415+
if (oldName.database != newName.database) {
416+
throw new AnalysisException("rename does not support moving functions across databases")
417+
}
435418
val db = oldName.database.getOrElse(currentDb)
436419
if (oldName.database.isDefined || !tempFunctions.containsKey(oldName.funcName)) {
437420
externalCatalog.renameFunction(db, oldName.funcName, newName.funcName)
@@ -449,7 +432,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
449432
* If no database is specified, this will first attempt to return a temporary function with
450433
* the same name, then, if that does not exist, return the function in the current database.
451434
*/
452-
def getFunction(currentDb: String, name: FunctionIdentifier): CatalogFunction = {
435+
def getFunction(name: FunctionIdentifier): CatalogFunction = {
453436
val db = name.database.getOrElse(currentDb)
454437
if (name.database.isDefined || !tempFunctions.containsKey(name.funcName)) {
455438
externalCatalog.getFunction(db, name.funcName)
@@ -461,17 +444,16 @@ class SessionCatalog(externalCatalog: ExternalCatalog) {
461444
// TODO: implement lookupFunction that returns something from the registry itself
462445

463446
/**
464-
* List all matching functions in the current database, including temporary functions.
447+
* List all matching functions in the specified database, including temporary functions.
465448
*/
466-
def listFunctions(currentDb: String, pattern: String): Seq[FunctionIdentifier] = {
467-
val functionsInCurrentDb = externalCatalog.listFunctions(currentDb, pattern).map { f =>
468-
FunctionIdentifier(f, Some(currentDb))
469-
}
449+
def listFunctions(db: String, pattern: String): Seq[FunctionIdentifier] = {
450+
val dbFunctions =
451+
externalCatalog.listFunctions(db, pattern).map { f => FunctionIdentifier(f, Some(db)) }
470452
val regex = pattern.replaceAll("\\*", ".*").r
471453
val _tempFunctions = tempFunctions.keys().asScala
472454
.filter { f => regex.pattern.matcher(f).matches() }
473455
.map { f => FunctionIdentifier(f) }
474-
functionsInCurrentDb ++ _tempFunctions
456+
dbFunctions ++ _tempFunctions
475457
}
476458

477459
/**

0 commit comments

Comments
 (0)