diff --git a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaData.scala b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaData.scala
index 215c8256acbc3..34bdbcf133985 100644
--- a/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaData.scala
+++ b/sql/connect/client/jdbc/src/main/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaData.scala
@@ -20,7 +20,9 @@ package org.apache.spark.sql.connect.client.jdbc
 import java.sql.{Array => _, _}
 
 import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION}
+import org.apache.spark.sql.connect
 import org.apache.spark.sql.connect.client.jdbc.SparkConnectDatabaseMetaData._
+import org.apache.spark.sql.functions._
 import org.apache.spark.util.VersionUtils
 
 class SparkConnectDatabaseMetaData(conn: SparkConnectConnection) extends DatabaseMetaData {
@@ -277,6 +279,9 @@ class SparkConnectDatabaseMetaData(conn: SparkConnectConnection) extends Databas
 
   override def dataDefinitionIgnoredInTransactions: Boolean = false
 
+  private def isNullOrWildcard(pattern: String): Boolean =
+    pattern == null || pattern == "%"
+
   override def getProcedures(
       catalog: String,
       schemaPattern: String,
@@ -299,11 +304,59 @@ class SparkConnectDatabaseMetaData(conn: SparkConnectConnection) extends Databas
     new SparkConnectResultSet(df.collectResult())
   }
 
-  override def getSchemas: ResultSet =
-    throw new SQLFeatureNotSupportedException
+  override def getSchemas: ResultSet = {
+    conn.checkOpen()
 
-  override def getSchemas(catalog: String, schemaPattern: String): ResultSet =
-    throw new SQLFeatureNotSupportedException
+    getSchemas(null, null)
+  }
+
+  // Schema of the returned DataFrame is:
+  // |-- TABLE_SCHEM: string (nullable = false)
+  // |-- TABLE_CATALOG: string (nullable = false)
+  private def getSchemasDataFrame(
+      catalog: String, schemaPattern: String): connect.DataFrame = {
+
+    val schemaFilterClause =
+      if (isNullOrWildcard(schemaPattern)) "1=1" else s"TABLE_SCHEM LIKE '$schemaPattern'"
+
+    def internalGetSchemas(
+        catalogOpt: Option[String],
+        schemaFilterClause: String): connect.DataFrame = {
+      val catalog = catalogOpt.getOrElse(conn.getCatalog)
+      // Spark SQL supports a LIKE clause in the SHOW SCHEMAS command, but we can't use
+      // it here because its pattern syntax does not follow the SQL standard.
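+      // (SHOW SCHEMAS ... LIKE treats '*' as "any characters" and '|' as alternation,
+      // while the JDBC schemaPattern argument uses SQL LIKE semantics with '%' and '_',
+      // so we list all schemas here and apply a standard LIKE filter afterwards.)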
+ conn.spark.sql(s"SHOW SCHEMAS IN `$catalog`") + .select($"namespace".as("TABLE_SCHEM")) + .filter(schemaFilterClause) + .withColumn("TABLE_CATALOG", lit(catalog)) + } + + if (catalog == null) { + // search in all catalogs + val emptyDf = conn.spark.emptyDataFrame + .withColumn("TABLE_SCHEM", lit("")) + .withColumn("TABLE_CATALOG", lit("")) + conn.spark.catalog.listCatalogs().collect().map(_.name).map { catalog => + internalGetSchemas(Some(catalog), schemaFilterClause) + }.fold(emptyDf) { (l, r) => l.unionAll(r) } + } else if (catalog == "") { + // search only in current catalog + internalGetSchemas(None, schemaFilterClause) + .withColumn("TABLE_CATALOG", lit(conn.getCatalog)) + } else { + // search in the specific catalog + internalGetSchemas(Some(catalog), schemaFilterClause) + .withColumn("TABLE_CATALOG", lit(catalog)) + } + } + + override def getSchemas(catalog: String, schemaPattern: String): ResultSet = { + conn.checkOpen() + + val df = getSchemasDataFrame(catalog, schemaPattern) + .orderBy("TABLE_CATALOG", "TABLE_SCHEM") + new SparkConnectResultSet(df.collectResult()) + } override def getTableTypes: ResultSet = throw new SQLFeatureNotSupportedException diff --git a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaDataSuite.scala b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaDataSuite.scala index 42596b56f4c56..d462486ed38bd 100644 --- a/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaDataSuite.scala +++ b/sql/connect/client/jdbc/src/test/scala/org/apache/spark/sql/connect/client/jdbc/SparkConnectDatabaseMetaDataSuite.scala @@ -235,4 +235,106 @@ class SparkConnectDatabaseMetaDataSuite extends ConnectFunSuite with RemoteSpark } } } + + test("SparkConnectDatabaseMetaData getSchemas") { + + def verifyGetSchemas( + getSchemas: () => ResultSet)(verify: Seq[(String, String)] => Unit): Unit = { + Using.resource(getSchemas()) { rs => + val catalogDatabases = new Iterator[(String, String)] { + def hasNext: Boolean = rs.next() + def next(): (String, String) = + (rs.getString("TABLE_CATALOG"), rs.getString("TABLE_SCHEM")) + }.toSeq + verify(catalogDatabases) + } + } + + withConnection { conn => + implicit val spark: SparkSession = conn.asInstanceOf[SparkConnectConnection].spark + + registerCatalog("testcat", TEST_IN_MEMORY_CATALOG) + + spark.sql("USE testcat") + spark.sql("CREATE DATABASE IF NOT EXISTS testcat.t_db1") + spark.sql("CREATE DATABASE IF NOT EXISTS testcat.t_db2") + spark.sql("CREATE DATABASE IF NOT EXISTS testcat.test_db3") + + spark.sql("USE spark_catalog") + spark.sql("CREATE DATABASE IF NOT EXISTS spark_catalog.db1") + spark.sql("CREATE DATABASE IF NOT EXISTS spark_catalog.db2") + spark.sql("CREATE DATABASE IF NOT EXISTS spark_catalog.db_") + + val metadata = conn.getMetaData + withDatabase("testcat.t_db1", "testcat.t_db2", "testcat.test_db3", + "spark_catalog.db1", "spark_catalog.db2", "spark_catalog.db_") { + + // list schemas in all catalogs + val getSchemasInAllCatalogs = (() => metadata.getSchemas) :: + List(null, "%").map { database => () => metadata.getSchemas(null, database) } ::: Nil + + getSchemasInAllCatalogs.foreach { getSchemas => + verifyGetSchemas(getSchemas) { catalogDatabases => + // results are ordered by TABLE_CATALOG, TABLE_SCHEM + assert { + catalogDatabases === Seq( + ("spark_catalog", "db1"), + ("spark_catalog", "db2"), + ("spark_catalog", "db_"), + ("spark_catalog", "default"), + 
("testcat", "t_db1"), + ("testcat", "t_db2"), + ("testcat", "test_db3")) + } + } + } + + // list schemas in current catalog + assert(conn.getCatalog === "spark_catalog") + val getSchemasInCurrentCatalog = + List(null, "%").map { database => () => metadata.getSchemas("", database) } + getSchemasInCurrentCatalog.foreach { getSchemas => + verifyGetSchemas(getSchemas) { catalogDatabases => + // results are ordered by TABLE_CATALOG, TABLE_SCHEM + assert { + catalogDatabases === Seq( + ("spark_catalog", "db1"), + ("spark_catalog", "db2"), + ("spark_catalog", "db_"), + ("spark_catalog", "default")) + } + } + } + + // list schemas with SQL pattern + verifyGetSchemas { () => metadata.getSchemas(null, "db%") } { catalogDatabases => + // results are ordered by TABLE_CATALOG, TABLE_SCHEM + assert { + catalogDatabases === Seq( + ("spark_catalog", "db1"), + ("spark_catalog", "db2"), + ("spark_catalog", "db_")) + } + } + + verifyGetSchemas { () => metadata.getSchemas(null, "db_") } { catalogDatabases => + // results are ordered by TABLE_CATALOG, TABLE_SCHEM + assert { + catalogDatabases === Seq( + ("spark_catalog", "db1"), + ("spark_catalog", "db2"), + ("spark_catalog", "db_")) + } + } + + verifyGetSchemas { () => metadata.getSchemas(null, "db\\_") } { catalogDatabases => + // results are ordered by TABLE_CATALOG, TABLE_SCHEM + assert { + catalogDatabases === Seq( + ("spark_catalog", "db_")) + } + } + } + } + } } diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/SQLHelper.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/SQLHelper.scala index b8d1062c3b3b0..731550363fc0a 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/SQLHelper.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/test/SQLHelper.scala @@ -142,4 +142,16 @@ trait SQLHelper { spark.sql(s"DROP VIEW IF EXISTS $name") }) } + + /** + * Drops database `dbName` after calling `f`. + */ + protected def withDatabase(dbNames: String*)(f: => Unit): Unit = { + SparkErrorUtils.tryWithSafeFinally(f) { + dbNames.foreach { name => + spark.sql(s"DROP DATABASE IF EXISTS $name CASCADE") + } + spark.sql(s"USE default") + } + } }