@@ -20,7 +20,9 @@ package org.apache.spark.sql.connect.client.jdbc
import java.sql.{Array => _, _}

import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION}
import org.apache.spark.sql.connect
import org.apache.spark.sql.connect.client.jdbc.SparkConnectDatabaseMetaData._
import org.apache.spark.sql.functions._
import org.apache.spark.util.VersionUtils

class SparkConnectDatabaseMetaData(conn: SparkConnectConnection) extends DatabaseMetaData {
@@ -277,6 +279,9 @@ class SparkConnectDatabaseMetaData(conn: SparkConnectConnection) extends DatabaseMetaData {

override def dataDefinitionIgnoredInTransactions: Boolean = false

private def isNullOrWildcard(pattern: String): Boolean =
pattern == null || pattern == "%"
Member Author:
This is used to check whether a fooPattern argument matches everything.

https://docs.oracle.com/en/java/javase/17/docs/api/java.sql/java/sql/DatabaseMetaData.html

Some DatabaseMetaData methods take arguments that are String patterns. These arguments all have names such as fooPattern. Within a pattern String, "%" means match any substring of 0 or more characters, and "_" means match any one character. Only metadata entries matching the search pattern are returned. If a search pattern argument is set to null, that argument's criterion will be dropped from the search.
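
For illustration, a minimal sketch of what these semantics mean for a pattern argument; the patternFilter helper below is hypothetical and not part of this change:

// Sketch only: how a JDBC fooPattern argument maps to an optional SQL filter.
// null or a bare "%" matches everything, so the criterion is dropped entirely;
// any other pattern is applied with LIKE ("%" = any substring, "_" = any single char).
def patternFilter(column: String, pattern: String): Option[String] =
  if (pattern == null || pattern == "%") None         // matches all entries
  else Some(s"$column LIKE '$pattern'")                // quoting/escaping still needed

// patternFilter("TABLE_SCHEM", null)  => None
// patternFilter("TABLE_SCHEM", "%")   => None
// patternFilter("TABLE_SCHEM", "db_") => Some("TABLE_SCHEM LIKE 'db_'")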


override def getProcedures(
catalog: String,
schemaPattern: String,
@@ -299,11 +304,59 @@ class SparkConnectDatabaseMetaData(conn: SparkConnectConnection) extends DatabaseMetaData {
new SparkConnectResultSet(df.collectResult())
}

-  override def getSchemas: ResultSet =
-    throw new SQLFeatureNotSupportedException
+  override def getSchemas: ResultSet = {
+    conn.checkOpen()
+
+    getSchemas(null, null)
+  }

-  override def getSchemas(catalog: String, schemaPattern: String): ResultSet =
-    throw new SQLFeatureNotSupportedException

// Schema of the returned DataFrame is:
// |-- TABLE_SCHEM: string (nullable = false)
// |-- TABLE_CATALOG: string (nullable = false)
private def getSchemasDataFrame(
catalog: String, schemaPattern: String): connect.DataFrame = {

val schemaFilterClause =
if (isNullOrWildcard(schemaPattern)) "1=1" else s"TABLE_SCHEM LIKE '$schemaPattern'"
Contributor:
What happens if schemaPattern contains a single quote?
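
One possible mitigation, shown only as a sketch (the escapeSqlLiteral helper below is hypothetical, not part of this change): double any embedded single quotes before interpolating the pattern into the LIKE literal.

// Sketch: double single quotes so the pattern cannot terminate the string literal.
// e.g. "o'brien%" becomes "o''brien%", producing ... LIKE 'o''brien%'
def escapeSqlLiteral(pattern: String): String =
  pattern.replace("'", "''")

val schemaFilterClause =
  if (isNullOrWildcard(schemaPattern)) "1=1"
  else s"TABLE_SCHEM LIKE '${escapeSqlLiteral(schemaPattern)}'"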


def internalGetSchemas(
catalogOpt: Option[String],
Contributor:
Suggested change:
-      catalogOpt: Option[String],
+      catalog: String,

Contributor:
It seems that all three call sites pass concrete values, so there's no need to wrap them in an Option, right?
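
If the Option were dropped, a sketch under that assumption: the current-catalog call site could resolve the catalog before the call.

// Sketch: without the Option wrapper, resolve the catalog at each call site.
def internalGetSchemas(catalog: String, schemaFilterClause: String): connect.DataFrame =
  conn.spark.sql(s"SHOW SCHEMAS IN `$catalog`")
    .select($"namespace".as("TABLE_SCHEM"))
    .filter(schemaFilterClause)
    .withColumn("TABLE_CATALOG", lit(catalog))

// "" (empty catalog) means "search only in the current catalog":
internalGetSchemas(conn.getCatalog, schemaFilterClause)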

schemaFilterClause: String): connect.DataFrame = {
val catalog = catalogOpt.getOrElse(conn.getCatalog)
// Spark SQL supports a LIKE clause in the SHOW SCHEMAS command, but we can't use it
// because its LIKE pattern syntax does not follow the SQL standard.
conn.spark.sql(s"SHOW SCHEMAS IN `$catalog`")
Contributor:
What would happen if the catalog name contains backticks?
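
For reference, a common guard here, shown only as a sketch (quoteIdentifier below is a hypothetical local helper): double any backticks inside the name before wrapping it in backticks.

// Sketch: a backtick inside a quoted identifier is escaped by doubling it,
// so a catalog literally named weird`cat would be referenced as `weird``cat`.
def quoteIdentifier(name: String): String =
  "`" + name.replace("`", "``") + "`"

conn.spark.sql(s"SHOW SCHEMAS IN ${quoteIdentifier(catalog)}")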

.select($"namespace".as("TABLE_SCHEM"))
.filter(schemaFilterClause)
.withColumn("TABLE_CATALOG", lit(catalog))
}

if (catalog == null) {
// search in all catalogs
val emptyDf = conn.spark.emptyDataFrame
.withColumn("TABLE_SCHEM", lit(""))
.withColumn("TABLE_CATALOG", lit(""))
conn.spark.catalog.listCatalogs().collect().map(_.name).map { catalog =>
Contributor:
Suggested change:
-    conn.spark.catalog.listCatalogs().collect().map(_.name).map { catalog =>
+    conn.spark.catalog.listCatalogs().collect().map(_.name).map { catalogName =>

Contributor:
Otherwise it shares the same name as the outer catalog variable, which hurts readability.

internalGetSchemas(Some(catalog), schemaFilterClause)
}.fold(emptyDf) { (l, r) => l.unionAll(r) }
} else if (catalog == "") {
// search only in current catalog
internalGetSchemas(None, schemaFilterClause)
.withColumn("TABLE_CATALOG", lit(conn.getCatalog))
} else {
// search in the specific catalog
internalGetSchemas(Some(catalog), schemaFilterClause)
.withColumn("TABLE_CATALOG", lit(catalog))
}
}

override def getSchemas(catalog: String, schemaPattern: String): ResultSet = {
conn.checkOpen()

val df = getSchemasDataFrame(catalog, schemaPattern)
.orderBy("TABLE_CATALOG", "TABLE_SCHEM")
new SparkConnectResultSet(df.collectResult())
}

override def getTableTypes: ResultSet =
throw new SQLFeatureNotSupportedException
@@ -235,4 +235,106 @@ class SparkConnectDatabaseMetaDataSuite extends ConnectFunSuite with RemoteSpark
}
}
}

test("SparkConnectDatabaseMetaData getSchemas") {

def verifyGetSchemas(
getSchemas: () => ResultSet)(verify: Seq[(String, String)] => Unit): Unit = {
Using.resource(getSchemas()) { rs =>
val catalogDatabases = new Iterator[(String, String)] {
def hasNext: Boolean = rs.next()
def next(): (String, String) =
(rs.getString("TABLE_CATALOG"), rs.getString("TABLE_SCHEM"))
}.toSeq
verify(catalogDatabases)
}
}

withConnection { conn =>
implicit val spark: SparkSession = conn.asInstanceOf[SparkConnectConnection].spark

registerCatalog("testcat", TEST_IN_MEMORY_CATALOG)

spark.sql("USE testcat")
spark.sql("CREATE DATABASE IF NOT EXISTS testcat.t_db1")
spark.sql("CREATE DATABASE IF NOT EXISTS testcat.t_db2")
spark.sql("CREATE DATABASE IF NOT EXISTS testcat.test_db3")

spark.sql("USE spark_catalog")
spark.sql("CREATE DATABASE IF NOT EXISTS spark_catalog.db1")
spark.sql("CREATE DATABASE IF NOT EXISTS spark_catalog.db2")
spark.sql("CREATE DATABASE IF NOT EXISTS spark_catalog.db_")

val metadata = conn.getMetaData
withDatabase("testcat.t_db1", "testcat.t_db2", "testcat.test_db3",
"spark_catalog.db1", "spark_catalog.db2", "spark_catalog.db_") {

// list schemas in all catalogs
val getSchemasInAllCatalogs = (() => metadata.getSchemas) ::
List(null, "%").map { database => () => metadata.getSchemas(null, database) } ::: Nil

getSchemasInAllCatalogs.foreach { getSchemas =>
verifyGetSchemas(getSchemas) { catalogDatabases =>
// results are ordered by TABLE_CATALOG, TABLE_SCHEM
assert {
catalogDatabases === Seq(
("spark_catalog", "db1"),
("spark_catalog", "db2"),
("spark_catalog", "db_"),
("spark_catalog", "default"),
("testcat", "t_db1"),
("testcat", "t_db2"),
("testcat", "test_db3"))
}
}
}

// list schemas in current catalog
assert(conn.getCatalog === "spark_catalog")
val getSchemasInCurrentCatalog =
List(null, "%").map { database => () => metadata.getSchemas("", database) }
getSchemasInCurrentCatalog.foreach { getSchemas =>
verifyGetSchemas(getSchemas) { catalogDatabases =>
// results are ordered by TABLE_CATALOG, TABLE_SCHEM
assert {
catalogDatabases === Seq(
("spark_catalog", "db1"),
("spark_catalog", "db2"),
("spark_catalog", "db_"),
("spark_catalog", "default"))
}
}
}

// list schemas with SQL pattern
verifyGetSchemas { () => metadata.getSchemas(null, "db%") } { catalogDatabases =>
// results are ordered by TABLE_CATALOG, TABLE_SCHEM
assert {
catalogDatabases === Seq(
("spark_catalog", "db1"),
("spark_catalog", "db2"),
("spark_catalog", "db_"))
}
}

verifyGetSchemas { () => metadata.getSchemas(null, "db_") } { catalogDatabases =>
// results are ordered by TABLE_CATALOG, TABLE_SCHEM
assert {
catalogDatabases === Seq(
("spark_catalog", "db1"),
("spark_catalog", "db2"),
("spark_catalog", "db_"))
}
}

verifyGetSchemas { () => metadata.getSchemas(null, "db\\_") } { catalogDatabases =>
// results are ordered by TABLE_CATALOG, TABLE_SCHEM
assert {
catalogDatabases === Seq(
("spark_catalog", "db_"))
}
}
}
}
}
}
@@ -142,4 +142,16 @@ trait SQLHelper {
spark.sql(s"DROP VIEW IF EXISTS $name")
})
}

/**
* Drops the databases `dbNames` after calling `f`.
*/
protected def withDatabase(dbNames: String*)(f: => Unit): Unit = {
SparkErrorUtils.tryWithSafeFinally(f) {
dbNames.foreach { name =>
spark.sql(s"DROP DATABASE IF EXISTS $name CASCADE")
}
spark.sql(s"USE default")
}
}
}