diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/utils/CatalogV2Util.scala b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/utils/CatalogV2Util.scala
index cd9bcc0f44f7..d5079202c8fe 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/utils/CatalogV2Util.scala
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/utils/CatalogV2Util.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.sql.catalog.v2.{CatalogPlugin, Identifier, NamespaceChange, TableChange}
 import org.apache.spark.sql.catalog.v2.TableChange.{AddColumn, DeleteColumn, RemoveProperty, RenameColumn, SetProperty, UpdateColumnComment, UpdateColumnType}
-import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException}
 import org.apache.spark.sql.sources.v2.Table
 import org.apache.spark.sql.types.{ArrayType, MapType, StructField, StructType}
 
@@ -219,5 +219,7 @@ object CatalogV2Util {
       Option(catalog.asTableCatalog.loadTable(ident))
     } catch {
       case _: NoSuchTableException => None
+      case _: NoSuchDatabaseException => None
+      case _: NoSuchNamespaceException => None
     }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/internal/UnresolvedTable.scala b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/internal/UnresolvedTable.scala
new file mode 100644
index 000000000000..8813d0ab840d
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/internal/UnresolvedTable.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.internal
+
+import java.util
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalog.v2.expressions.{LogicalExpressions, Transform}
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.catalog.CatalogTable
+import org.apache.spark.sql.sources.v2.{Table, TableCapability}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * An implementation of catalog v2 `Table` to expose v1 table metadata.
+ */
+case class UnresolvedTable(v1Table: CatalogTable) extends Table {
+  implicit class IdentifierHelper(identifier: TableIdentifier) {
+    def quoted: String = {
+      identifier.database match {
+        case Some(db) =>
+          Seq(db, identifier.table).map(quote).mkString(".")
+        case _ =>
+          quote(identifier.table)
+
+      }
+    }
+
+    private def quote(part: String): String = {
+      if (part.contains(".") || part.contains("`")) {
+        s"`${part.replace("`", "``")}`"
+      } else {
+        part
+      }
+    }
+  }
+
+  def catalogTable: CatalogTable = v1Table
+
+  lazy val options: Map[String, String] = {
+    v1Table.storage.locationUri match {
+      case Some(uri) =>
+        v1Table.storage.properties + ("path" -> uri.toString)
+      case _ =>
+        v1Table.storage.properties
+    }
+  }
+
+  override lazy val properties: util.Map[String, String] = v1Table.properties.asJava
+
+  override lazy val schema: StructType = v1Table.schema
+
+  override lazy val partitioning: Array[Transform] = {
+    val partitions = new mutable.ArrayBuffer[Transform]()
+
+    v1Table.partitionColumnNames.foreach { col =>
+      partitions += LogicalExpressions.identity(col)
+    }
+
+    v1Table.bucketSpec.foreach { spec =>
+      partitions += LogicalExpressions.bucket(spec.numBuckets, spec.bucketColumnNames: _*)
+    }
+
+    partitions.toArray
+  }
+
+  override def name: String = v1Table.identifier.quoted
+
+  override def capabilities: util.Set[TableCapability] = new util.HashSet[TableCapability]()
+
+  override def toString: String = s"UnresolvedTable($name)"
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 7267ad8d37c8..3a72988f8345 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
 import org.apache.spark.sql.sources.v2.Table
+import org.apache.spark.sql.sources.v2.internal.UnresolvedTable
 import org.apache.spark.sql.types._
 
 /**
@@ -650,8 +651,14 @@ class Analyzer(
           if catalog.isTemporaryTable(ident) =>
         u // temporary views take precedence over catalog table names
 
-      case u @ UnresolvedRelation(CatalogObjectIdentifier(Some(catalogPlugin), ident)) =>
-        loadTable(catalogPlugin, ident).map(DataSourceV2Relation.create).getOrElse(u)
+      case u @ UnresolvedRelation(CatalogObjectIdentifier(maybeCatalog, ident)) =>
+        maybeCatalog.orElse(sessionCatalog)
+          .flatMap(loadTable(_, ident))
+          .map {
+            case unresolved: UnresolvedTable => u
+            case resolved => DataSourceV2Relation.create(resolved)
+          }
+          .getOrElse(u)
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index af7ddd756ae8..0b49cf24e6c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -37,6 +37,7 @@ import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
 import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister}
 import org.apache.spark.sql.sources.v2._
 import org.apache.spark.sql.sources.v2.TableCapability._
+import org.apache.spark.sql.sources.v2.internal.UnresolvedTable
 import org.apache.spark.sql.types.{IntegerType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
@@ -251,19 +252,13 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
     assertNotBucketed("save")
 
     val session = df.sparkSession
-    val useV1Sources =
-      session.sessionState.conf.useV1SourceWriterList.toLowerCase(Locale.ROOT).split(",")
     val cls = DataSource.lookupDataSource(source, session.sessionState.conf)
-    val shouldUseV1Source = cls.newInstance() match {
-      case d: DataSourceRegister if useV1Sources.contains(d.shortName()) => true
-      case _ => useV1Sources.contains(cls.getCanonicalName.toLowerCase(Locale.ROOT))
-    }
+    val canUseV2 = canUseV2Source(session, cls) && partitioningColumns.isEmpty
 
     // In Data Source V2 project, partitioning is still under development.
     // Here we fallback to V1 if partitioning columns are specified.
     // TODO(SPARK-26778): use V2 implementations when partitioning feature is supported.
-    if (!shouldUseV1Source && classOf[TableProvider].isAssignableFrom(cls) &&
-        partitioningColumns.isEmpty) {
+    if (canUseV2) {
       val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider]
       val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
         provider, session.sessionState.conf)
@@ -493,13 +488,20 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
     import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, CatalogObjectIdentifier}
     import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._
-    import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._
 
     val session = df.sparkSession
+    val provider = DataSource.lookupDataSource(source, session.sessionState.conf)
+    val canUseV2 = canUseV2Source(session, provider)
+    val sessionCatalogOpt = session.sessionState.analyzer.sessionCatalog
+
     session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
       case CatalogObjectIdentifier(Some(catalog), ident) =>
         saveAsTable(catalog.asTableCatalog, ident, modeForDSV2)
 
-      // TODO(SPARK-28666): This should go through V2SessionCatalog
+
+      case CatalogObjectIdentifier(None, ident)
+          if canUseV2 && sessionCatalogOpt.isDefined && ident.namespace().length <= 1 =>
+        // We pass in the modeForDSV1, as using the V2 session catalog should maintain compatibility
+        // for now.
+        saveAsTable(sessionCatalogOpt.get.asTableCatalog, ident, modeForDSV1)
 
       case AsTableIdentifier(tableIdentifier) =>
         saveAsTable(tableIdentifier)
@@ -525,6 +527,9 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
     }
 
     val command = (mode, tableOpt) match {
+      case (_, Some(table: UnresolvedTable)) =>
+        return saveAsTable(TableIdentifier(ident.name(), ident.namespace().headOption))
+
       case (SaveMode.Append, Some(table)) =>
         AppendData.byName(DataSourceV2Relation.create(table), df.logicalPlan)
 
@@ -830,6 +835,16 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
 
   private def modeForDSV2 = mode.getOrElse(SaveMode.Append)
 
+  private def canUseV2Source(session: SparkSession, providerClass: Class[_]): Boolean = {
+    val useV1Sources =
+      session.sessionState.conf.useV1SourceWriterList.toLowerCase(Locale.ROOT).split(",")
+    val shouldUseV1Source = providerClass.newInstance() match {
+      case d: DataSourceRegister if useV1Sources.contains(d.shortName()) => true
+      case _ => useV1Sources.contains(providerClass.getCanonicalName.toLowerCase(Locale.ROOT))
+    }
+    !shouldUseV1Source && classOf[TableProvider].isAssignableFrom(providerClass)
+  }
+
   ///////////////////////////////////////////////////////////////////////////////////////
   // Builder pattern config options
   ///////////////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala
index 4791fe5fb525..48b504a6545f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala
@@ -31,11 +31,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, CreateV
 import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DeleteFromStatement, DescribeColumnStatement, DescribeTableStatement, DropTableStatement, DropViewStatement, QualifiedColType, ReplaceTableAsSelectStatement, ReplaceTableStatement}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableSetLocationCommand, AlterTableSetPropertiesCommand, AlterTableUnsetPropertiesCommand, DescribeColumnCommand, DescribeTableCommand, DropTableCommand}
-import org.apache.spark.sql.execution.datasources.v2.{CatalogTableAsV2, DataSourceV2Relation}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.v2.TableProvider
 import org.apache.spark.sql.types.{HIVE_TYPE_STRING, HiveStringType, MetadataBuilder, StructField, StructType}
-import org.apache.spark.sql.util.SchemaUtils
 
 case class DataSourceResolution(
     conf: SQLConf,
@@ -183,8 +181,6 @@ case class DataSourceResolution(
       val aliased = delete.tableAlias.map(SubqueryAlias(_, relation)).getOrElse(relation)
       DeleteFromTable(aliased, delete.condition)
-
-    case DataSourceV2Relation(CatalogTableAsV2(catalogTable), _, _) =>
-      UnresolvedCatalogRelation(catalogTable)
   }
 
   object V1WriteProvider {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala
index 79ea8756721e..6dcebe29537d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala
@@ -31,7 +31,8 @@ import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchT
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType, CatalogUtils, SessionCatalog}
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.internal.SessionState
-import org.apache.spark.sql.sources.v2.{Table, TableCapability}
+import org.apache.spark.sql.sources.v2.Table
+import org.apache.spark.sql.sources.v2.internal.UnresolvedTable
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
@@ -70,7 +71,7 @@ class V2SessionCatalog(sessionState: SessionState) extends TableCatalog {
       throw new NoSuchTableException(ident)
     }
 
-    CatalogTableAsV2(catalogTable)
+    UnresolvedTable(catalogTable)
   }
 
   override def invalidateTable(ident: Identifier): Unit = {
@@ -179,66 +180,6 @@ class V2SessionCatalog(sessionState: SessionState) extends TableCatalog {
   override def toString: String = s"V2SessionCatalog($name)"
 }
 
-/**
- * An implementation of catalog v2 [[Table]] to expose v1 table metadata.
- */
-case class CatalogTableAsV2(v1Table: CatalogTable) extends Table {
-  implicit class IdentifierHelper(identifier: TableIdentifier) {
-    def quoted: String = {
-      identifier.database match {
-        case Some(db) =>
-          Seq(db, identifier.table).map(quote).mkString(".")
-        case _ =>
-          quote(identifier.table)
-
-      }
-    }
-
-    private def quote(part: String): String = {
-      if (part.contains(".") || part.contains("`")) {
-        s"`${part.replace("`", "``")}`"
-      } else {
-        part
-      }
-    }
-  }
-
-  def catalogTable: CatalogTable = v1Table
-
-  lazy val options: Map[String, String] = {
-    v1Table.storage.locationUri match {
-      case Some(uri) =>
-        v1Table.storage.properties + ("path" -> uri.toString)
-      case _ =>
-        v1Table.storage.properties
-    }
-  }
-
-  override lazy val properties: util.Map[String, String] = v1Table.properties.asJava
-
-  override lazy val schema: StructType = v1Table.schema
-
-  override lazy val partitioning: Array[Transform] = {
-    val partitions = new mutable.ArrayBuffer[Transform]()
-
-    v1Table.partitionColumnNames.foreach { col =>
-      partitions += LogicalExpressions.identity(col)
-    }
-
-    v1Table.bucketSpec.foreach { spec =>
-      partitions += LogicalExpressions.bucket(spec.numBuckets, spec.bucketColumnNames: _*)
-    }
-
-    partitions.toArray
-  }
-
-  override def name: String = v1Table.identifier.quoted
-
-  override def capabilities: util.Set[TableCapability] = new util.HashSet[TableCapability]()
-
-  override def toString: String = s"CatalogTableAsV2($name)"
-}
-
 private[sql] object V2SessionCatalog {
   /**
    * Convert v2 Transforms to v1 partition columns and an optional bucket spec.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2DataFrameSessionCatalogSuite.scala
new file mode 100644
index 000000000000..2ef2df3345e8
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2DataFrameSessionCatalogSuite.scala
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2
+
+import java.util
+import java.util.concurrent.ConcurrentHashMap
+
+import scala.collection.JavaConverters._
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.{DataFrame, QueryTest, SaveMode}
+import org.apache.spark.sql.catalog.v2.Identifier
+import org.apache.spark.sql.catalog.v2.expressions.Transform
+import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
+import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
+
+class DataSourceV2DataFrameSessionCatalogSuite
+  extends QueryTest
+  with SharedSQLContext
+  with BeforeAndAfter {
+  import testImplicits._
+
+  private val v2Format = classOf[InMemoryTableProvider].getName
+
+  before {
+    spark.conf.set(SQLConf.V2_SESSION_CATALOG.key, classOf[TestV2SessionCatalog].getName)
+  }
+
+  override def afterEach(): Unit = {
+    super.afterEach()
+    spark.catalog("session").asInstanceOf[TestV2SessionCatalog].clearTables()
+    spark.conf.set(SQLConf.V2_SESSION_CATALOG.key, classOf[V2SessionCatalog].getName)
+  }
+
+  private def verifyTable(tableName: String, expected: DataFrame): Unit = {
+    checkAnswer(spark.table(tableName), expected)
+    checkAnswer(sql(s"SELECT * FROM $tableName"), expected)
+    checkAnswer(sql(s"SELECT * FROM default.$tableName"), expected)
+    checkAnswer(sql(s"TABLE $tableName"), expected)
+  }
+
+  test("saveAsTable: v2 table - table doesn't exist and default mode (ErrorIfExists)") {
+    val t1 = "tbl"
+    val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
+    df.write.format(v2Format).saveAsTable(t1)
+    verifyTable(t1, df)
+  }
+
+  test("saveAsTable: v2 table - table doesn't exist and append mode") {
+    val t1 = "tbl"
+    val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
+    df.write.format(v2Format).mode("append").saveAsTable(t1)
+    verifyTable(t1, df)
+  }
+
+  test("saveAsTable: Append mode should not fail if the table not exists " +
+    "but a same-name temp view exist") {
+    withTable("same_name") {
+      withTempView("same_name") {
+        spark.range(10).createTempView("same_name")
+        spark.range(20).write.format(v2Format).mode(SaveMode.Append).saveAsTable("same_name")
+        assert(
+          spark.sessionState.catalog.tableExists(TableIdentifier("same_name", Some("default"))))
+      }
+    }
+  }
+
+  test("saveAsTable: Append mode should not fail if the table already exists " +
+    "and a same-name temp view exist") {
+    withTable("same_name") {
+      withTempView("same_name") {
+        val format = spark.sessionState.conf.defaultDataSourceName
+        sql(s"CREATE TABLE same_name(id LONG) USING $format")
+        spark.range(10).createTempView("same_name")
+        spark.range(20).write.format(v2Format).mode(SaveMode.Append).saveAsTable("same_name")
+        checkAnswer(spark.table("same_name"), spark.range(10).toDF())
+        checkAnswer(spark.table("default.same_name"), spark.range(20).toDF())
+      }
+    }
+  }
+
+  test("saveAsTable: v2 table - table exists") {
+    val t1 = "tbl"
+    val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
+    spark.sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format")
+    intercept[TableAlreadyExistsException] {
+      df.select("id", "data").write.format(v2Format).saveAsTable(t1)
+    }
+    df.write.format(v2Format).mode("append").saveAsTable(t1)
+    verifyTable(t1, df)
+
+    // Check that appends are by name
+    df.select('data, 'id).write.format(v2Format).mode("append").saveAsTable(t1)
+    verifyTable(t1, df.union(df))
+  }
+
+  test("saveAsTable: v2 table - table overwrite and table doesn't exist") {
+    val t1 = "tbl"
+    val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
+    df.write.format(v2Format).mode("overwrite").saveAsTable(t1)
+    verifyTable(t1, df)
+  }
+
+  test("saveAsTable: v2 table - table overwrite and table exists") {
+    val t1 = "tbl"
+    val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
+    spark.sql(s"CREATE TABLE $t1 USING $v2Format AS SELECT 'c', 'd'")
+    df.write.format(v2Format).mode("overwrite").saveAsTable(t1)
+    verifyTable(t1, df)
+  }
+
+  test("saveAsTable: Overwrite mode should not drop the temp view if the table not exists " +
+    "but a same-name temp view exist") {
+    withTable("same_name") {
+      withTempView("same_name") {
+        spark.range(10).createTempView("same_name")
+        spark.range(20).write.format(v2Format).mode(SaveMode.Overwrite).saveAsTable("same_name")
+        assert(spark.sessionState.catalog.getTempView("same_name").isDefined)
+        assert(
+          spark.sessionState.catalog.tableExists(TableIdentifier("same_name", Some("default"))))
+      }
+    }
+  }
+
+  test("saveAsTable with mode Overwrite should not fail if the table already exists " +
+    "and a same-name temp view exist") {
+    withTable("same_name") {
+      withTempView("same_name") {
+        sql(s"CREATE TABLE same_name(id LONG) USING $v2Format")
+        spark.range(10).createTempView("same_name")
+        spark.range(20).write.format(v2Format).mode(SaveMode.Overwrite).saveAsTable("same_name")
+        checkAnswer(spark.table("same_name"), spark.range(10).toDF())
+        checkAnswer(spark.table("default.same_name"), spark.range(20).toDF())
+      }
+    }
+  }
+
+  test("saveAsTable: v2 table - ignore mode and table doesn't exist") {
+    val t1 = "tbl"
+    val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
+    df.write.format(v2Format).mode("ignore").saveAsTable(t1)
+    verifyTable(t1, df)
+  }
+
+  test("saveAsTable: v2 table - ignore mode and table exists") {
+    val t1 = "tbl"
+    val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
+    spark.sql(s"CREATE TABLE $t1 USING $v2Format AS SELECT 'c', 'd'")
+    df.write.format(v2Format).mode("ignore").saveAsTable(t1)
+    verifyTable(t1, Seq(("c", "d")).toDF("id", "data"))
+  }
+}
+
+class InMemoryTableProvider extends TableProvider {
+  override def getTable(options: CaseInsensitiveStringMap): Table = {
+    throw new UnsupportedOperationException("D'oh!")
+  }
+}
+
+/** A SessionCatalog that always loads an in memory Table, so we can test write code paths. */
+class TestV2SessionCatalog extends V2SessionCatalog {
+
+  protected val tables: util.Map[Identifier, InMemoryTable] =
+    new ConcurrentHashMap[Identifier, InMemoryTable]()
+
+  private def fullIdentifier(ident: Identifier): Identifier = {
+    if (ident.namespace().isEmpty) {
+      Identifier.of(Array("default"), ident.name())
+    } else {
+      ident
+    }
+  }
+
+  override def loadTable(ident: Identifier): Table = {
+    val fullIdent = fullIdentifier(ident)
+    if (tables.containsKey(fullIdent)) {
+      tables.get(fullIdent)
+    } else {
+      // Table was created through the built-in catalog
+      val t = super.loadTable(fullIdent)
+      val table = new InMemoryTable(t.name(), t.schema(), t.partitioning(), t.properties())
+      tables.put(fullIdent, table)
+      table
+    }
+  }
+
+  override def createTable(
+      ident: Identifier,
+      schema: StructType,
+      partitions: Array[Transform],
+      properties: util.Map[String, String]): Table = {
+    val created = super.createTable(ident, schema, partitions, properties)
+    val t = new InMemoryTable(created.name(), schema, partitions, properties)
+    val fullIdent = fullIdentifier(ident)
+    tables.put(fullIdent, t)
+    t
+  }
+
+  def clearTables(): Unit = {
+    assert(!tables.isEmpty, "Tables were empty, maybe didn't use the session catalog code path?")
+    tables.keySet().asScala.foreach(super.dropTable)
+    tables.clear()
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala
index 9b1a23a1f2bb..cfa6506a95e7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.datasources.v2.V2SessionCatalog
 import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode, V2_SESSION_CATALOG}
+import org.apache.spark.sql.sources.v2.internal.UnresolvedTable
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.{ArrayType, DoubleType, IntegerType, LongType, MapType, Metadata, StringType, StructField, StructType, TimestampType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -493,8 +494,12 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn
 
     sparkSession.sql(s"CREATE TABLE table_name USING parquet AS SELECT id, data FROM source")
 
-    // use the catalog name to force loading with the v2 catalog
-    checkAnswer(sparkSession.sql(s"TABLE session.table_name"), sparkSession.table("source"))
+    checkAnswer(sparkSession.sql(s"TABLE default.table_name"), sparkSession.table("source"))
+    // The fact that the following line doesn't throw an exception means, the session catalog
+    // can load the table.
+    val t = sparkSession.catalog("session").asTableCatalog
+      .loadTable(Identifier.of(Array.empty, "table_name"))
+    assert(t.isInstanceOf[UnresolvedTable], "V1 table wasn't returned as an unresolved table")
   }
 
   test("DropTable: basic") {
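
A minimal usage sketch of the path this patch enables, drawn from the new test suite rather than adding any new API; it assumes a V2 `TableProvider` class name held in `v2Format` and a `TableCatalog` implementation configured via `SQLConf.V2_SESSION_CATALOG`:

  // DataFrame writes to a single-part table name are now routed through the V2 session catalog.
  val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
  df.write.format(v2Format).mode("append").saveAsTable("tbl")

  // The resulting table is visible through both the session catalog and plain SQL lookups.
  checkAnswer(spark.table("tbl"), df)
  checkAnswer(sql("SELECT * FROM default.tbl"), df)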