Skip to content

Commit f453791

Browse files
gatorsmilecloud-fan
authored andcommitted
[SPARK-15187][SQL] Disallow Dropping Default Database
#### What changes were proposed in this pull request? In Hive Metastore, dropping default database is not allowed. However, in `InMemoryCatalog`, this is allowed. This PR is to disallow users to drop default database. #### How was this patch tested? Previously, we already have a test case in HiveDDLSuite. Now, we also add the same one in DDLSuite Author: gatorsmile <[email protected]> Closes #12962 from gatorsmile/dropDefaultDB.
1 parent 4b4344a commit f453791

File tree

4 files changed

+106
-52
lines changed

4 files changed

+106
-52
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala

Lines changed: 58 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ class SessionCatalog(
8282
CatalogDatabase(defaultName, "default database", conf.warehousePath, Map())
8383
// Initialize default database if it doesn't already exist
8484
createDatabase(defaultDbDefinition, ignoreIfExists = true)
85-
defaultName
85+
formatDatabaseName(defaultName)
8686
}
8787

8888
/**
@@ -92,6 +92,13 @@ class SessionCatalog(
9292
if (conf.caseSensitiveAnalysis) name else name.toLowerCase
9393
}
9494

95+
/**
96+
* Format database name, taking into account case sensitivity.
97+
*/
98+
protected[this] def formatDatabaseName(name: String): String = {
99+
if (conf.caseSensitiveAnalysis) name else name.toLowerCase
100+
}
101+
95102
/**
96103
* This method is used to make the given path qualified before we
97104
* store this path in the underlying external catalog. So, when a path
@@ -112,25 +119,33 @@ class SessionCatalog(
112119

113120
def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = {
114121
val qualifiedPath = makeQualifiedPath(dbDefinition.locationUri).toString
122+
val dbName = formatDatabaseName(dbDefinition.name)
115123
externalCatalog.createDatabase(
116-
dbDefinition.copy(locationUri = qualifiedPath),
124+
dbDefinition.copy(name = dbName, locationUri = qualifiedPath),
117125
ignoreIfExists)
118126
}
119127

120128
def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = {
121-
externalCatalog.dropDatabase(db, ignoreIfNotExists, cascade)
129+
val dbName = formatDatabaseName(db)
130+
if (dbName == "default") {
131+
throw new AnalysisException(s"Can not drop default database")
132+
}
133+
externalCatalog.dropDatabase(dbName, ignoreIfNotExists, cascade)
122134
}
123135

124136
def alterDatabase(dbDefinition: CatalogDatabase): Unit = {
125-
externalCatalog.alterDatabase(dbDefinition)
137+
val dbName = formatDatabaseName(dbDefinition.name)
138+
externalCatalog.alterDatabase(dbDefinition.copy(name = dbName))
126139
}
127140

128141
def getDatabaseMetadata(db: String): CatalogDatabase = {
129-
externalCatalog.getDatabase(db)
142+
val dbName = formatDatabaseName(db)
143+
externalCatalog.getDatabase(dbName)
130144
}
131145

132146
def databaseExists(db: String): Boolean = {
133-
externalCatalog.databaseExists(db)
147+
val dbName = formatDatabaseName(db)
148+
externalCatalog.databaseExists(dbName)
134149
}
135150

136151
def listDatabases(): Seq[String] = {
@@ -144,18 +159,19 @@ class SessionCatalog(
144159
def getCurrentDatabase: String = synchronized { currentDb }
145160

146161
def setCurrentDatabase(db: String): Unit = {
147-
if (!databaseExists(db)) {
148-
throw new AnalysisException(s"Database '$db' does not exist.")
162+
val dbName = formatDatabaseName(db)
163+
if (!databaseExists(dbName)) {
164+
throw new AnalysisException(s"Database '$dbName' does not exist.")
149165
}
150-
synchronized { currentDb = db }
166+
synchronized { currentDb = dbName }
151167
}
152168

153169
/**
154170
* Get the path for creating a non-default database when database location is not provided
155171
* by users.
156172
*/
157173
def getDefaultDBPath(db: String): String = {
158-
val database = if (conf.caseSensitiveAnalysis) db else db.toLowerCase
174+
val database = formatDatabaseName(db)
159175
new Path(new Path(conf.warehousePath), database + ".db").toString
160176
}
161177

@@ -177,7 +193,7 @@ class SessionCatalog(
177193
* If no such database is specified, create it in the current database.
178194
*/
179195
def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = {
180-
val db = tableDefinition.identifier.database.getOrElse(getCurrentDatabase)
196+
val db = formatDatabaseName(tableDefinition.identifier.database.getOrElse(getCurrentDatabase))
181197
val table = formatTableName(tableDefinition.identifier.table)
182198
val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db)))
183199
externalCatalog.createTable(db, newTableDefinition, ignoreIfExists)
@@ -193,7 +209,7 @@ class SessionCatalog(
193209
* this becomes a no-op.
194210
*/
195211
def alterTable(tableDefinition: CatalogTable): Unit = {
196-
val db = tableDefinition.identifier.database.getOrElse(getCurrentDatabase)
212+
val db = formatDatabaseName(tableDefinition.identifier.database.getOrElse(getCurrentDatabase))
197213
val table = formatTableName(tableDefinition.identifier.table)
198214
val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db)))
199215
externalCatalog.alterTable(db, newTableDefinition)
@@ -205,7 +221,7 @@ class SessionCatalog(
205221
* If the specified table is not found in the database then an [[AnalysisException]] is thrown.
206222
*/
207223
def getTableMetadata(name: TableIdentifier): CatalogTable = {
208-
val db = name.database.getOrElse(getCurrentDatabase)
224+
val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
209225
val table = formatTableName(name.table)
210226
externalCatalog.getTable(db, table)
211227
}
@@ -216,7 +232,7 @@ class SessionCatalog(
216232
* If the specified table is not found in the database then return None if it doesn't exist.
217233
*/
218234
def getTableMetadataOption(name: TableIdentifier): Option[CatalogTable] = {
219-
val db = name.database.getOrElse(getCurrentDatabase)
235+
val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
220236
val table = formatTableName(name.table)
221237
externalCatalog.getTableOption(db, table)
222238
}
@@ -231,7 +247,7 @@ class SessionCatalog(
231247
loadPath: String,
232248
isOverwrite: Boolean,
233249
holdDDLTime: Boolean): Unit = {
234-
val db = name.database.getOrElse(getCurrentDatabase)
250+
val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
235251
val table = formatTableName(name.table)
236252
externalCatalog.loadTable(db, table, loadPath, isOverwrite, holdDDLTime)
237253
}
@@ -249,14 +265,14 @@ class SessionCatalog(
249265
holdDDLTime: Boolean,
250266
inheritTableSpecs: Boolean,
251267
isSkewedStoreAsSubdir: Boolean): Unit = {
252-
val db = name.database.getOrElse(getCurrentDatabase)
268+
val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
253269
val table = formatTableName(name.table)
254270
externalCatalog.loadPartition(db, table, loadPath, partition, isOverwrite, holdDDLTime,
255271
inheritTableSpecs, isSkewedStoreAsSubdir)
256272
}
257273

258274
def defaultTablePath(tableIdent: TableIdentifier): String = {
259-
val dbName = tableIdent.database.getOrElse(getCurrentDatabase)
275+
val dbName = formatDatabaseName(tableIdent.database.getOrElse(getCurrentDatabase))
260276
val dbLocation = getDatabaseMetadata(dbName).locationUri
261277

262278
new Path(new Path(dbLocation), formatTableName(tableIdent.table)).toString
@@ -290,8 +306,8 @@ class SessionCatalog(
290306
* This assumes the database specified in `oldName` matches the one specified in `newName`.
291307
*/
292308
def renameTable(oldName: TableIdentifier, newName: TableIdentifier): Unit = synchronized {
293-
val db = oldName.database.getOrElse(currentDb)
294-
val newDb = newName.database.getOrElse(currentDb)
309+
val db = formatDatabaseName(oldName.database.getOrElse(currentDb))
310+
val newDb = formatDatabaseName(newName.database.getOrElse(currentDb))
295311
if (db != newDb) {
296312
throw new AnalysisException(
297313
s"RENAME TABLE source and destination databases do not match: '$db' != '$newDb'")
@@ -324,7 +340,7 @@ class SessionCatalog(
324340
* the same name, then, if that does not exist, drop the table from the current database.
325341
*/
326342
def dropTable(name: TableIdentifier, ignoreIfNotExists: Boolean): Unit = synchronized {
327-
val db = name.database.getOrElse(currentDb)
343+
val db = formatDatabaseName(name.database.getOrElse(currentDb))
328344
val table = formatTableName(name.table)
329345
if (name.database.isDefined || !tempTables.contains(table)) {
330346
// When ignoreIfNotExists is false, no exception is issued when the table does not exist.
@@ -348,7 +364,7 @@ class SessionCatalog(
348364
*/
349365
def lookupRelation(name: TableIdentifier, alias: Option[String] = None): LogicalPlan = {
350366
synchronized {
351-
val db = name.database.getOrElse(currentDb)
367+
val db = formatDatabaseName(name.database.getOrElse(currentDb))
352368
val table = formatTableName(name.table)
353369
val relation =
354370
if (name.database.isDefined || !tempTables.contains(table)) {
@@ -373,7 +389,7 @@ class SessionCatalog(
373389
* contain the table.
374390
*/
375391
def tableExists(name: TableIdentifier): Boolean = synchronized {
376-
val db = name.database.getOrElse(currentDb)
392+
val db = formatDatabaseName(name.database.getOrElse(currentDb))
377393
val table = formatTableName(name.table)
378394
if (name.database.isDefined || !tempTables.contains(table)) {
379395
externalCatalog.tableExists(db, table)
@@ -395,14 +411,15 @@ class SessionCatalog(
395411
/**
396412
* List all tables in the specified database, including temporary tables.
397413
*/
398-
def listTables(db: String): Seq[TableIdentifier] = listTables(db, "*")
414+
def listTables(db: String): Seq[TableIdentifier] = listTables(formatDatabaseName(db), "*")
399415

400416
/**
401417
* List all matching tables in the specified database, including temporary tables.
402418
*/
403419
def listTables(db: String, pattern: String): Seq[TableIdentifier] = {
420+
val dbName = formatDatabaseName(db)
404421
val dbTables =
405-
externalCatalog.listTables(db, pattern).map { t => TableIdentifier(t, Some(db)) }
422+
externalCatalog.listTables(dbName, pattern).map { t => TableIdentifier(t, Some(dbName)) }
406423
synchronized {
407424
val _tempTables = StringUtils.filterPattern(tempTables.keys.toSeq, pattern)
408425
.map { t => TableIdentifier(t) }
@@ -458,7 +475,7 @@ class SessionCatalog(
458475
tableName: TableIdentifier,
459476
parts: Seq[CatalogTablePartition],
460477
ignoreIfExists: Boolean): Unit = {
461-
val db = tableName.database.getOrElse(getCurrentDatabase)
478+
val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase))
462479
val table = formatTableName(tableName.table)
463480
externalCatalog.createPartitions(db, table, parts, ignoreIfExists)
464481
}
@@ -471,7 +488,7 @@ class SessionCatalog(
471488
tableName: TableIdentifier,
472489
parts: Seq[TablePartitionSpec],
473490
ignoreIfNotExists: Boolean): Unit = {
474-
val db = tableName.database.getOrElse(getCurrentDatabase)
491+
val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase))
475492
val table = formatTableName(tableName.table)
476493
externalCatalog.dropPartitions(db, table, parts, ignoreIfNotExists)
477494
}
@@ -486,7 +503,7 @@ class SessionCatalog(
486503
tableName: TableIdentifier,
487504
specs: Seq[TablePartitionSpec],
488505
newSpecs: Seq[TablePartitionSpec]): Unit = {
489-
val db = tableName.database.getOrElse(getCurrentDatabase)
506+
val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase))
490507
val table = formatTableName(tableName.table)
491508
externalCatalog.renamePartitions(db, table, specs, newSpecs)
492509
}
@@ -501,7 +518,7 @@ class SessionCatalog(
501518
* this becomes a no-op.
502519
*/
503520
def alterPartitions(tableName: TableIdentifier, parts: Seq[CatalogTablePartition]): Unit = {
504-
val db = tableName.database.getOrElse(getCurrentDatabase)
521+
val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase))
505522
val table = formatTableName(tableName.table)
506523
externalCatalog.alterPartitions(db, table, parts)
507524
}
@@ -511,7 +528,7 @@ class SessionCatalog(
511528
* If no database is specified, assume the table is in the current database.
512529
*/
513530
def getPartition(tableName: TableIdentifier, spec: TablePartitionSpec): CatalogTablePartition = {
514-
val db = tableName.database.getOrElse(getCurrentDatabase)
531+
val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase))
515532
val table = formatTableName(tableName.table)
516533
externalCatalog.getPartition(db, table, spec)
517534
}
@@ -526,7 +543,7 @@ class SessionCatalog(
526543
def listPartitions(
527544
tableName: TableIdentifier,
528545
partialSpec: Option[TablePartitionSpec] = None): Seq[CatalogTablePartition] = {
529-
val db = tableName.database.getOrElse(getCurrentDatabase)
546+
val db = formatDatabaseName(tableName.database.getOrElse(getCurrentDatabase))
530547
val table = formatTableName(tableName.table)
531548
externalCatalog.listPartitions(db, table, partialSpec)
532549
}
@@ -549,7 +566,7 @@ class SessionCatalog(
549566
* If no such database is specified, create it in the current database.
550567
*/
551568
def createFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = {
552-
val db = funcDefinition.identifier.database.getOrElse(getCurrentDatabase)
569+
val db = formatDatabaseName(funcDefinition.identifier.database.getOrElse(getCurrentDatabase))
553570
val identifier = FunctionIdentifier(funcDefinition.identifier.funcName, Some(db))
554571
val newFuncDefinition = funcDefinition.copy(identifier = identifier)
555572
if (!functionExists(identifier)) {
@@ -564,7 +581,7 @@ class SessionCatalog(
564581
* If no database is specified, assume the function is in the current database.
565582
*/
566583
def dropFunction(name: FunctionIdentifier, ignoreIfNotExists: Boolean): Unit = {
567-
val db = name.database.getOrElse(getCurrentDatabase)
584+
val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
568585
val identifier = name.copy(database = Some(db))
569586
if (functionExists(identifier)) {
570587
// TODO: registry should just take in FunctionIdentifier for type safety
@@ -588,15 +605,15 @@ class SessionCatalog(
588605
* If no database is specified, this will return the function in the current database.
589606
*/
590607
def getFunctionMetadata(name: FunctionIdentifier): CatalogFunction = {
591-
val db = name.database.getOrElse(getCurrentDatabase)
608+
val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
592609
externalCatalog.getFunction(db, name.funcName)
593610
}
594611

595612
/**
596613
* Check if the specified function exists.
597614
*/
598615
def functionExists(name: FunctionIdentifier): Boolean = {
599-
val db = name.database.getOrElse(getCurrentDatabase)
616+
val db = formatDatabaseName(name.database.getOrElse(getCurrentDatabase))
600617
functionRegistry.functionExists(name.unquotedString) ||
601618
externalCatalog.functionExists(db, name.funcName)
602619
}
@@ -661,7 +678,8 @@ class SessionCatalog(
661678
*/
662679
private[spark] def lookupFunctionInfo(name: FunctionIdentifier): ExpressionInfo = synchronized {
663680
// TODO: just make function registry take in FunctionIdentifier instead of duplicating this
664-
val qualifiedName = name.copy(database = name.database.orElse(Some(currentDb)))
681+
val database = name.database.orElse(Some(currentDb)).map(formatDatabaseName)
682+
val qualifiedName = name.copy(database = database)
665683
functionRegistry.lookupFunction(name.funcName)
666684
.orElse(functionRegistry.lookupFunction(qualifiedName.unquotedString))
667685
.getOrElse {
@@ -700,7 +718,8 @@ class SessionCatalog(
700718
}
701719

702720
// If the name itself is not qualified, add the current database to it.
703-
val qualifiedName = if (name.database.isEmpty) name.copy(database = Some(currentDb)) else name
721+
val database = name.database.orElse(Some(currentDb)).map(formatDatabaseName)
722+
val qualifiedName = name.copy(database = database)
704723

705724
if (functionRegistry.functionExists(qualifiedName.unquotedString)) {
706725
// This function has been already loaded into the function registry.
@@ -740,8 +759,9 @@ class SessionCatalog(
740759
* List all matching functions in the specified database, including temporary functions.
741760
*/
742761
def listFunctions(db: String, pattern: String): Seq[FunctionIdentifier] = {
743-
val dbFunctions =
744-
externalCatalog.listFunctions(db, pattern).map { f => FunctionIdentifier(f, Some(db)) }
762+
val dbName = formatDatabaseName(db)
763+
val dbFunctions = externalCatalog.listFunctions(dbName, pattern)
764+
.map { f => FunctionIdentifier(f, Some(dbName)) }
745765
val loadedFunctions = StringUtils.filterPattern(functionRegistry.listFunction(), pattern)
746766
.map { f => FunctionIdentifier(f) }
747767
dbFunctions ++ loadedFunctions

sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -644,16 +644,16 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
644644

645645
checkAnswer(
646646
sql("SHOW DATABASES LIKE '*db1A'"),
647-
Row("showdb1A") :: Nil)
647+
Row("showdb1a") :: Nil)
648648

649649
checkAnswer(
650650
sql("SHOW DATABASES LIKE 'showdb1A'"),
651-
Row("showdb1A") :: Nil)
651+
Row("showdb1a") :: Nil)
652652

653653
checkAnswer(
654654
sql("SHOW DATABASES LIKE '*db1A|*db2B'"),
655-
Row("showdb1A") ::
656-
Row("showdb2B") :: Nil)
655+
Row("showdb1a") ::
656+
Row("showdb2b") :: Nil)
657657

658658
checkAnswer(
659659
sql("SHOW DATABASES LIKE 'non-existentdb'"),
@@ -1000,4 +1000,24 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
10001000
Row("Usage: a ^ b - Bitwise exclusive OR.") :: Nil
10011001
)
10021002
}
1003+
1004+
test("drop default database") {
1005+
Seq("true", "false").foreach { caseSensitive =>
1006+
withSQLConf(SQLConf.CASE_SENSITIVE.key -> caseSensitive) {
1007+
var message = intercept[AnalysisException] {
1008+
sql("DROP DATABASE default")
1009+
}.getMessage
1010+
assert(message.contains("Can not drop default database"))
1011+
1012+
message = intercept[AnalysisException] {
1013+
sql("DROP DATABASE DeFault")
1014+
}.getMessage
1015+
if (caseSensitive == "true") {
1016+
assert(message.contains("Database 'DeFault' does not exist"))
1017+
} else {
1018+
assert(message.contains("Can not drop default database"))
1019+
}
1020+
}
1021+
}
1022+
}
10031023
}

0 commit comments

Comments
 (0)