Skip to content

Commit 2118212

Browse files
author
Andrew Or
committed
Refactor CatalogFunction to use FunctionIdentifier
1 parent dd1fbae commit 2118212

File tree

5 files changed

+24
-21
lines changed

5 files changed

+24
-21
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/TableIdentifier.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,14 @@ private[sql] object TableIdentifier {
4040
* If `database` is not defined, the current database is used.
4141
*/
4242
// TODO: reuse some code with TableIdentifier.
43-
private[sql] case class FunctionIdentifier(name: String, database: Option[String]) {
43+
private[sql] case class FunctionIdentifier(funcName: String, database: Option[String]) {
4444
def this(name: String) = this(name, None)
4545

4646
override def toString: String = quotedString
4747

48-
def quotedString: String = database.map(db => s"`$db`.`$name`").getOrElse(s"`$name`")
48+
def quotedString: String = database.map(db => s"`$db`.`$funcName`").getOrElse(s"`$funcName`")
4949

50-
def unquotedString: String = database.map(db => s"$db.$name").getOrElse(name)
50+
def unquotedString: String = database.map(db => s"$db.$funcName").getOrElse(funcName)
5151
}
5252

5353
private[sql] object FunctionIdentifier {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.catalog
2020
import scala.collection.mutable
2121

2222
import org.apache.spark.sql.AnalysisException
23-
import org.apache.spark.sql.catalyst.TableIdentifier
23+
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
2424

2525

2626
/**
@@ -294,10 +294,10 @@ class InMemoryCatalog extends ExternalCatalog {
294294

295295
override def createFunction(db: String, func: CatalogFunction): Unit = synchronized {
296296
requireDbExists(db)
297-
if (existsFunction(db, func.name)) {
297+
if (existsFunction(db, func.name.funcName)) {
298298
throw new AnalysisException(s"Function $func already exists in $db database")
299299
} else {
300-
catalog(db).functions.put(func.name, func)
300+
catalog(db).functions.put(func.name.funcName, func)
301301
}
302302
}
303303

@@ -308,14 +308,14 @@ class InMemoryCatalog extends ExternalCatalog {
308308

309309
override def renameFunction(db: String, oldName: String, newName: String): Unit = synchronized {
310310
requireFunctionExists(db, oldName)
311-
val newFunc = getFunction(db, oldName).copy(name = newName)
311+
val newFunc = getFunction(db, oldName).copy(name = FunctionIdentifier(newName, Some(db)))
312312
catalog(db).functions.remove(oldName)
313313
catalog(db).functions.put(newName, newFunc)
314314
}
315315

316316
override def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = synchronized {
317-
requireFunctionExists(db, funcDefinition.name)
318-
catalog(db).functions.put(funcDefinition.name, funcDefinition)
317+
requireFunctionExists(db, funcDefinition.name.funcName)
318+
catalog(db).functions.put(funcDefinition.name.funcName, funcDefinition)
319319
}
320320

321321
override def getFunction(db: String, funcName: String): CatalogFunction = synchronized {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -168,8 +168,7 @@ abstract class ExternalCatalog {
168168
* @param name name of the function
169169
* @param className fully qualified class name, e.g. "org.apache.spark.util.MyFunc"
170170
*/
171-
// TODO: use FunctionIdentifier here.
172-
case class CatalogFunction(name: String, className: String)
171+
case class CatalogFunction(name: FunctionIdentifier, className: String)
173172

174173

175174
/**

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import org.scalatest.BeforeAndAfterEach
2121

2222
import org.apache.spark.SparkFunSuite
2323
import org.apache.spark.sql.AnalysisException
24-
import org.apache.spark.sql.catalyst.TableIdentifier
24+
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
2525

2626

2727
/**
@@ -82,7 +82,7 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach {
8282
catalog
8383
}
8484

85-
private def newFunc(): CatalogFunction = CatalogFunction("funcname", funcClass)
85+
private def newFunc(): CatalogFunction = newFunc("funcName")
8686

8787
private def newDb(name: String): CatalogDatabase = {
8888
CatalogDatabase(name, name + " description", newUriForDatabase(), Map.empty)
@@ -97,7 +97,9 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach {
9797
partitionColumns = Seq(CatalogColumn("a", "int"), CatalogColumn("b", "string")))
9898
}
9999

100-
private def newFunc(name: String): CatalogFunction = CatalogFunction(name, funcClass)
100+
private def newFunc(name: String): CatalogFunction = {
101+
CatalogFunction(FunctionIdentifier(name, database = None), funcClass)
102+
}
101103

102104
/**
103105
* Whether the catalog's table partitions equal the ones given.
@@ -498,7 +500,8 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach {
498500

499501
test("get function") {
500502
val catalog = newBasicCatalog()
501-
assert(catalog.getFunction("db2", "func1") == newFunc("func1"))
503+
assert(catalog.getFunction("db2", "func1") ==
504+
CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass))
502505
intercept[AnalysisException] {
503506
catalog.getFunction("db2", "does_not_exist")
504507
}
@@ -517,7 +520,7 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach {
517520
assert(catalog.getFunction("db2", "func1").className == funcClass)
518521
catalog.renameFunction("db2", "func1", newName)
519522
intercept[AnalysisException] { catalog.getFunction("db2", "func1") }
520-
assert(catalog.getFunction("db2", newName).name == newName)
523+
assert(catalog.getFunction("db2", newName).name.funcName == newName)
521524
assert(catalog.getFunction("db2", newName).className == funcClass)
522525
intercept[AnalysisException] { catalog.renameFunction("db2", "does_not_exist", "me") }
523526
}

sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ import org.apache.hadoop.hive.ql.session.SessionState
3535
import org.apache.hadoop.security.UserGroupInformation
3636

3737
import org.apache.spark.{Logging, SparkConf, SparkException}
38-
import org.apache.spark.sql.catalyst.TableIdentifier
38+
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
3939
import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchPartitionException}
4040
import org.apache.spark.sql.catalyst.catalog._
4141
import org.apache.spark.sql.catalyst.expressions.Expression
@@ -545,13 +545,13 @@ private[hive] class HiveClientImpl(
545545
}
546546

547547
override def renameFunction(db: String, oldName: String, newName: String): Unit = withHiveState {
548-
val catalogFunc = getFunction(db, oldName).copy(name = newName)
548+
val catalogFunc = getFunction(db, oldName).copy(name = FunctionIdentifier(newName, Some(db)))
549549
val hiveFunc = toHiveFunction(catalogFunc, db)
550550
client.alterFunction(db, oldName, hiveFunc)
551551
}
552552

553553
override def alterFunction(db: String, func: CatalogFunction): Unit = withHiveState {
554-
client.alterFunction(db, func.name, toHiveFunction(func, db))
554+
client.alterFunction(db, func.name.funcName, toHiveFunction(func, db))
555555
}
556556

557557
override def getFunctionOption(
@@ -612,7 +612,7 @@ private[hive] class HiveClientImpl(
612612

613613
private def toHiveFunction(f: CatalogFunction, db: String): HiveFunction = {
614614
new HiveFunction(
615-
f.name,
615+
f.name.funcName,
616616
db,
617617
f.className,
618618
null,
@@ -623,7 +623,8 @@ private[hive] class HiveClientImpl(
623623
}
624624

625625
private def fromHiveFunction(hf: HiveFunction): CatalogFunction = {
626-
new CatalogFunction(hf.getFunctionName, hf.getClassName)
626+
val name = FunctionIdentifier(hf.getFunctionName, Option(hf.getDbName))
627+
new CatalogFunction(name, hf.getClassName)
627628
}
628629

629630
private def toHiveColumn(c: CatalogColumn): FieldSchema = {

0 commit comments

Comments
 (0)