diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala
index 063e2e2bc8b77..2c022c11950bc 100644
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala
+++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaRelationSuite.scala
@@ -624,7 +624,7 @@ class KafkaRelationSuiteV2 extends KafkaRelationSuiteBase {
     val topic = newTopic()
     val df = createDF(topic)
     assert(df.logicalPlan.collect {
-      case DataSourceV2Relation(_, _, _) => true
+      case _: DataSourceV2Relation => true
     }.nonEmpty)
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 36e558b0dc571..4cb73cb6977e4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -817,8 +817,8 @@ class Analyzer(
 
       case alter @ AlterTable(_, _, u: UnresolvedV2Relation, _) =>
         CatalogV2Util.loadRelation(u.catalog, u.tableName)
-        .map(rel => alter.copy(table = rel))
-        .getOrElse(alter)
+          .map(rel => alter.copy(table = rel))
+          .getOrElse(alter)
 
       case u: UnresolvedV2Relation =>
         CatalogV2Util.loadRelation(u.catalog, u.tableName).getOrElse(u)
@@ -831,7 +831,8 @@ class Analyzer(
       expandRelationName(identifier) match {
         case NonSessionCatalogAndIdentifier(catalog, ident) =>
           CatalogV2Util.loadTable(catalog, ident) match {
-            case Some(table) => Some(DataSourceV2Relation.create(table))
+            case Some(table) =>
+              Some(DataSourceV2Relation.create(table, Some(catalog), Some(ident)))
             case None => None
           }
         case _ => None
       }
@@ -923,7 +924,7 @@ class Analyzer(
             AnalysisContext.get.relationCache.getOrElseUpdate(
               key, v1SessionCatalog.getRelation(v1Table.v1Table))
           case table =>
-            DataSourceV2Relation.create(table)
+            DataSourceV2Relation.create(table, Some(catalog), Some(ident))
         }
       case _ => None
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
index 67726c7343524..7b2466bfe13e2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala
@@ -257,7 +257,7 @@ private[sql] object CatalogV2Util {
   }
 
   def loadRelation(catalog: CatalogPlugin, ident: Identifier): Option[NamedRelation] = {
-    loadTable(catalog, ident).map(DataSourceV2Relation.create)
+    loadTable(catalog, ident).map(DataSourceV2Relation.create(_, Some(catalog), Some(ident)))
   }
 
   def isSessionCatalog(catalog: CatalogPlugin): Boolean = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
index 87d3419e8115f..45d89498f5ae9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelat
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
 import org.apache.spark.sql.catalyst.util.truncatedString
-import org.apache.spark.sql.connector.catalog.{Table, TableCapability}
+import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, Table, TableCapability}
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, Statistics => V2Statistics, SupportsReportStatistics}
 import org.apache.spark.sql.connector.read.streaming.{Offset, SparkDataStream}
 import org.apache.spark.sql.connector.write.WriteBuilder
@@ -32,12 +32,17 @@ import org.apache.spark.util.Utils
  * A logical plan representing a data source v2 table.
  *
  * @param table The table that this relation represents.
+ * @param output the output attributes of this relation.
+ * @param catalog catalogPlugin for the table. None if no catalog is specified.
+ * @param identifier the identifier for the table. None if no identifier is defined.
  * @param options The options for this table operation. It's used to create fresh [[ScanBuilder]]
 *                and [[WriteBuilder]].
  */
 case class DataSourceV2Relation(
     table: Table,
     output: Seq[AttributeReference],
+    catalog: Option[CatalogPlugin],
+    identifier: Option[Identifier],
     options: CaseInsensitiveStringMap)
   extends LeafNode with MultiInstanceRelation with NamedRelation {
 
@@ -137,12 +142,20 @@ case class StreamingDataSourceV2Relation(
 }
 
 object DataSourceV2Relation {
-  def create(table: Table, options: CaseInsensitiveStringMap): DataSourceV2Relation = {
+  def create(
+      table: Table,
+      catalog: Option[CatalogPlugin],
+      identifier: Option[Identifier],
+      options: CaseInsensitiveStringMap): DataSourceV2Relation = {
     val output = table.schema().toAttributes
-    DataSourceV2Relation(table, output, options)
+    DataSourceV2Relation(table, output, catalog, identifier, options)
   }
 
-  def create(table: Table): DataSourceV2Relation = create(table, CaseInsensitiveStringMap.empty)
+  def create(
+      table: Table,
+      catalog: Option[CatalogPlugin],
+      identifier: Option[Identifier]): DataSourceV2Relation =
+    create(table, catalog, identifier, CaseInsensitiveStringMap.empty)
 
   /**
    * This is used to transform data source v2 statistics to logical.Statistics.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala
new file mode 100644
index 0000000000000..7a9a7f52ff8fd
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog
+
+import org.mockito.Mockito.{mock, when}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.types.StructType
+
+class CatalogV2UtilSuite extends SparkFunSuite {
+  test("Load relation should encode the identifiers for V2Relations") {
+    val testCatalog = mock(classOf[TableCatalog])
+    val ident = mock(classOf[Identifier])
+    val table = mock(classOf[Table])
+    when(table.schema()).thenReturn(mock(classOf[StructType]))
+    when(testCatalog.loadTable(ident)).thenReturn(table)
+    val r = CatalogV2Util.loadRelation(testCatalog, ident)
+    assert(r.isDefined)
+    assert(r.get.isInstanceOf[DataSourceV2Relation])
+    val v2Relation = r.get.asInstanceOf[DataSourceV2Relation]
+    assert(v2Relation.catalog.exists(_ == testCatalog))
+    assert(v2Relation.identifier.exists(_ == ident))
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 1af4931c553ee..b5d7bbca9064d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -195,6 +195,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
     }
 
     DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf).map { provider =>
+      val catalogManager = sparkSession.sessionState.catalogManager
       val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
         source = provider, conf = sparkSession.sessionState.conf)
       val pathsOption = if (paths.isEmpty) {
@@ -206,7 +207,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
       val finalOptions = sessionOptions ++ extraOptions.toMap ++ pathsOption
       val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava)
 
-      val table = provider match {
+      val (table, catalog, ident) = provider match {
        case _: SupportsCatalogOptions if userSpecifiedSchema.nonEmpty =>
          throw new IllegalArgumentException(
            s"$source does not support user specified schema. Please don't specify the schema.")
@@ -214,19 +215,22 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
           val ident = hasCatalog.extractIdentifier(dsOptions)
           val catalog = CatalogV2Util.getTableProviderCatalog(
             hasCatalog,
-            sparkSession.sessionState.catalogManager,
+            catalogManager,
             dsOptions)
-          catalog.loadTable(ident)
+          (catalog.loadTable(ident), Some(catalog), Some(ident))
         case _ =>
+          // TODO: Non-catalog paths for DSV2 are currently not well defined.
           userSpecifiedSchema match {
-            case Some(schema) => provider.getTable(dsOptions, schema)
-            case _ => provider.getTable(dsOptions)
+            case Some(schema) => (provider.getTable(dsOptions, schema), None, None)
+            case _ => (provider.getTable(dsOptions), None, None)
           }
       }
       import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
       table match {
         case _: SupportsRead if table.supports(BATCH_READ) =>
-          Dataset.ofRows(sparkSession, DataSourceV2Relation.create(table, dsOptions))
+          Dataset.ofRows(
+            sparkSession,
+            DataSourceV2Relation.create(table, catalog, ident, dsOptions))
 
         case _ => loadV1Source(paths: _*)
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 998ec9ebdff85..c041d14c8b8df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -258,20 +258,20 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       val dsOptions = new CaseInsensitiveStringMap(options.asJava)
 
       import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
+      val catalogManager = df.sparkSession.sessionState.catalogManager
       mode match {
         case SaveMode.Append | SaveMode.Overwrite =>
-          val table = provider match {
+          val (table, catalog, ident) = provider match {
             case supportsExtract: SupportsCatalogOptions =>
               val ident = supportsExtract.extractIdentifier(dsOptions)
-              val sessionState = df.sparkSession.sessionState
               val catalog = CatalogV2Util.getTableProviderCatalog(
-                supportsExtract, sessionState.catalogManager, dsOptions)
+                supportsExtract, catalogManager, dsOptions)
 
-              catalog.loadTable(ident)
+              (catalog.loadTable(ident), Some(catalog), Some(ident))
             case tableProvider: TableProvider =>
               val t = tableProvider.getTable(dsOptions)
               if (t.supports(BATCH_WRITE)) {
-                t
+                (t, None, None)
               } else {
                 // Streaming also uses the data source V2 API. So it may be that the data source
                 // implements v2, but has no v2 implementation for batch writes. In that case, we
@@ -280,7 +280,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
               }
           }
 
-          val relation = DataSourceV2Relation.create(table, dsOptions)
+          val relation = DataSourceV2Relation.create(table, catalog, ident, dsOptions)
           checkPartitioningMatchesV2Table(table)
           if (mode == SaveMode.Append) {
             runCommand(df.sparkSession, "save") {
@@ -299,9 +299,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
           provider match {
             case supportsExtract: SupportsCatalogOptions =>
               val ident = supportsExtract.extractIdentifier(dsOptions)
-              val sessionState = df.sparkSession.sessionState
               val catalog = CatalogV2Util.getTableProviderCatalog(
-                supportsExtract, sessionState.catalogManager, dsOptions)
+                supportsExtract, catalogManager, dsOptions)
 
               val location = Option(dsOptions.get("path")).map(TableCatalog.PROP_LOCATION -> _)
 
@@ -419,7 +418,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
      case _: V1Table =>
        return insertInto(TableIdentifier(ident.name(), ident.namespace().headOption))
      case t =>
-        DataSourceV2Relation.create(t)
+        DataSourceV2Relation.create(t, Some(catalog), Some(ident))
    }
 
    val command = mode match {
@@ -554,12 +553,13 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
     }
 
     val command = (mode, tableOpt) match {
-      case (_, Some(table: V1Table)) =>
+      case (_, Some(_: V1Table)) =>
         return saveAsTable(TableIdentifier(ident.name(), ident.namespace().headOption))
 
       case (SaveMode.Append, Some(table)) =>
         checkPartitioningMatchesV2Table(table)
-        AppendData.byName(DataSourceV2Relation.create(table), df.logicalPlan, extraOptions.toMap)
+        val v2Relation = DataSourceV2Relation.create(table, Some(catalog), Some(ident))
+        AppendData.byName(v2Relation, df.logicalPlan, extraOptions.toMap)
 
       case (SaveMode.Overwrite, _) =>
         ReplaceTableAsSelect(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
index f0758809bd749..f5dd7613d4103 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
@@ -158,7 +158,9 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
   def append(): Unit = {
     val append = loadTable(catalog, identifier) match {
       case Some(t) =>
-        AppendData.byName(DataSourceV2Relation.create(t), logicalPlan, options.toMap)
+        AppendData.byName(
+          DataSourceV2Relation.create(t, Some(catalog), Some(identifier)),
+          logicalPlan, options.toMap)
       case _ =>
         throw new NoSuchTableException(identifier)
     }
@@ -181,7 +183,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
     val overwrite = loadTable(catalog, identifier) match {
       case Some(t) =>
         OverwriteByExpression.byName(
-          DataSourceV2Relation.create(t), logicalPlan, condition.expr, options.toMap)
+          DataSourceV2Relation.create(t, Some(catalog), Some(identifier)),
+          logicalPlan, condition.expr, options.toMap)
       case _ =>
         throw new NoSuchTableException(identifier)
     }
@@ -207,7 +210,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
     val dynamicOverwrite = loadTable(catalog, identifier) match {
       case Some(t) =>
         OverwritePartitionsDynamic.byName(
-          DataSourceV2Relation.create(t), logicalPlan, options.toMap)
+          DataSourceV2Relation.create(t, Some(catalog), Some(identifier)),
+          logicalPlan, options.toMap)
       case _ =>
         throw new NoSuchTableException(identifier)
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index 75e11abaa3161..413bd7b29cf45 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -270,7 +270,7 @@ class CacheManager extends Logging {
           case _ => false
         }
 
-      case DataSourceV2Relation(fileTable: FileTable, _, _) =>
+      case DataSourceV2Relation(fileTable: FileTable, _, _, _, _) =>
         refreshFileIndexIfNecessary(fileTable.fileIndex, fs, qualifiedPath)
 
       case _ => false
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallBackFileSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallBackFileSourceV2.scala
index 5f6c3e8f7eeed..28a63c26604ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallBackFileSourceV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallBackFileSourceV2.scala
@@ -33,7 +33,8 @@ import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, File
  */
 class FallBackFileSourceV2(sparkSession: SparkSession) extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-    case i @ InsertIntoStatement(d @ DataSourceV2Relation(table: FileTable, _, _), _, _, _, _) =>
+    case i @
+        InsertIntoStatement(d @ DataSourceV2Relation(table: FileTable, _, _, _, _), _, _, _, _) =>
       val v1FileFormat = table.fallbackFileFormat.newInstance()
       val relation = HadoopFsRelation(
         table.fileIndex,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 448a4354ddd66..8b4b6fb64658a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -232,11 +232,11 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
     case desc @ DescribeNamespace(ResolvedNamespace(catalog, ns), extended) =>
       DescribeNamespaceExec(desc.output, catalog, ns, extended) :: Nil
 
-    case desc @ DescribeRelation(ResolvedTable(_, _, table), partitionSpec, isExtended) =>
+    case desc @ DescribeRelation(r: ResolvedTable, partitionSpec, isExtended) =>
       if (partitionSpec.nonEmpty) {
         throw new AnalysisException("DESCRIBE does not support partition for v2 tables.")
       }
-      DescribeTableExec(desc.output, table, isExtended) :: Nil
+      DescribeTableExec(desc.output, r.table, isExtended) :: Nil
 
     case DropTable(catalog, ident, ifExists) =>
       DropTableExec(catalog, ident, ifExists) :: Nil
@@ -284,8 +284,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
     case r: ShowCurrentNamespace =>
       ShowCurrentNamespaceExec(r.output, r.catalogManager) :: Nil
 
-    case r @ ShowTableProperties(ResolvedTable(_, _, table), propertyKey) =>
-      ShowTablePropertiesExec(r.output, table, propertyKey) :: Nil
+    case r @ ShowTableProperties(rt: ResolvedTable, propertyKey) =>
+      ShowTablePropertiesExec(r.output, rt.table, propertyKey) :: Nil
 
     case _ => Nil
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
index 4e6381aea3c31..d49dc58e93ddb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
@@ -22,11 +22,15 @@ import scala.collection.JavaConverters._
 import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException}
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic}
 import org.apache.spark.sql.connector.InMemoryTableCatalog
 import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
 import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform}
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
+import org.apache.spark.sql.util.QueryExecutionListener
 import org.apache.spark.util.Utils
 
 class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with BeforeAndAfter {
@@ -54,6 +58,45 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
     spark.sessionState.conf.clear()
   }
 
+  test("DataFrameWriterV2 encodes identifiers correctly") {
+    spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")
+
+    var plan: LogicalPlan = null
+    val listener = new QueryExecutionListener {
+      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+        plan = qe.analyzed
+
+      }
+      override def onFailure(funcName: String, qe: QueryExecution, error: Throwable): Unit = {}
+    }
+    spark.listenerManager.register(listener)
+
+    spark.table("source").writeTo("testcat.table_name").append()
+    sparkContext.listenerBus.waitUntilEmpty()
+    assert(plan.isInstanceOf[AppendData])
+    checkV2Identifiers(plan.asInstanceOf[AppendData].table)
+
+    spark.table("source").writeTo("testcat.table_name").overwrite(lit(true))
+    sparkContext.listenerBus.waitUntilEmpty()
+    assert(plan.isInstanceOf[OverwriteByExpression])
+    checkV2Identifiers(plan.asInstanceOf[OverwriteByExpression].table)
+
+    spark.table("source").writeTo("testcat.table_name").overwritePartitions()
+    sparkContext.listenerBus.waitUntilEmpty()
+    assert(plan.isInstanceOf[OverwritePartitionsDynamic])
+    checkV2Identifiers(plan.asInstanceOf[OverwritePartitionsDynamic].table)
+  }
+
+  private def checkV2Identifiers(
+      plan: LogicalPlan,
+      identifier: String = "table_name",
+      catalogPlugin: TableCatalog = catalog("testcat")): Unit = {
+    assert(plan.isInstanceOf[DataSourceV2Relation])
+    val v2 = plan.asInstanceOf[DataSourceV2Relation]
+    assert(v2.identifier.exists(_.name() == identifier))
+    assert(v2.catalog.exists(_ == catalogPlugin))
+  }
+
   test("Append: basic append") {
     spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
index 0148bb07ee967..cec48bb368aef 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SupportsCatalogOptionsSuite.scala
@@ -26,13 +26,16 @@ import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.sql.{DataFrame, QueryTest, SaveMode}
 import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression}
 import org.apache.spark.sql.connector.catalog.{Identifier, SupportsCatalogOptions, TableCatalog}
 import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
 import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
+import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{LongType, StructType}
-import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.sql.util.{CaseInsensitiveStringMap, QueryExecutionListener}
 
 class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with BeforeAndAfter {
 
@@ -196,11 +199,79 @@ class SupportsCatalogOptionsSuite extends QueryTest with SharedSparkSession with
     assert(e.getMessage.contains("not support user specified schema"))
   }
 
+  test("DataFrameReader creates v2Relation with identifiers") {
+    sql(s"create table $catalogName.t1 (id bigint) using $format")
+    val df = load("t1", Some(catalogName))
+    checkV2Identifiers(df.logicalPlan)
+  }
+
+  test("DataFrameWriter creates v2Relation with identifiers") {
+    sql(s"create table $catalogName.t1 (id bigint) using $format")
+
+    var plan: LogicalPlan = null
+    val listener = new QueryExecutionListener {
+      override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
+        plan = qe.analyzed
+      }
+      override def onFailure(funcName: String, qe: QueryExecution, error: Throwable): Unit = {}
+    }
+
+    spark.listenerManager.register(listener)
+
+    try {
+      // Test append
+      save("t1", SaveMode.Append, Some(catalogName))
+      sparkContext.listenerBus.waitUntilEmpty()
+      assert(plan.isInstanceOf[AppendData])
+      val appendRelation = plan.asInstanceOf[AppendData].table
+      checkV2Identifiers(appendRelation)
+
+      // Test overwrite
+      save("t1", SaveMode.Overwrite, Some(catalogName))
+      sparkContext.listenerBus.waitUntilEmpty()
+      assert(plan.isInstanceOf[OverwriteByExpression])
+      val overwriteRelation = plan.asInstanceOf[OverwriteByExpression].table
+      checkV2Identifiers(overwriteRelation)
+
+      // Test insert
+      spark.range(10).write.format(format).insertInto(s"$catalogName.t1")
+      sparkContext.listenerBus.waitUntilEmpty()
+      assert(plan.isInstanceOf[AppendData])
+      val insertRelation = plan.asInstanceOf[AppendData].table
+      checkV2Identifiers(insertRelation)
+
+      // Test saveAsTable append
+      spark.range(10).write.format(format).mode(SaveMode.Append).saveAsTable(s"$catalogName.t1")
+      sparkContext.listenerBus.waitUntilEmpty()
+      assert(plan.isInstanceOf[AppendData])
+      val saveAsTableRelation = plan.asInstanceOf[AppendData].table
+      checkV2Identifiers(saveAsTableRelation)
+    } finally {
+      spark.listenerManager.unregister(listener)
+    }
+  }
+
+  private def checkV2Identifiers(
+      plan: LogicalPlan,
+      identifier: String = "t1",
+      catalogPlugin: TableCatalog = catalog(catalogName)): Unit = {
+    assert(plan.isInstanceOf[DataSourceV2Relation])
+    val v2 = plan.asInstanceOf[DataSourceV2Relation]
+    assert(v2.identifier.exists(_.name() == identifier))
+    assert(v2.catalog.exists(_ == catalogPlugin))
+  }
+
   private def load(name: String, catalogOpt: Option[String]): DataFrame = {
-    val dfr = spark.read.format(format).option("name", "t1")
+    val dfr = spark.read.format(format).option("name", name)
     catalogOpt.foreach(cName => dfr.option("catalog", cName))
     dfr.load()
   }
+
+  private def save(name: String, mode: SaveMode, catalogOpt: Option[String]): Unit = {
+    val df = spark.range(10).write.format(format).option("name", name)
+    catalogOpt.foreach(cName => df.option("catalog", cName))
+    df.mode(mode).save()
+  }
 }
 
 class CatalogSupportingInMemoryTableProvider
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala
index ce6d56cf84df1..5196ca65276e4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext}
 import org.apache.spark.sql.catalyst.analysis.{AnalysisSuite, NamedRelation}
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.connector.catalog.{Table, TableCapability, TableProvider}
+import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, Table, TableCapability, TableProvider}
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, TableCapabilityCheck}
@@ -37,6 +37,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {
 
+  private val emptyMap = CaseInsensitiveStringMap.empty
   private def createStreamingRelation(table: Table, v1Relation: Option[StreamingRelation]) = {
     StreamingRelationV2(
       TestTableProvider,
@@ -53,9 +54,9 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {
 
   test("batch scan: check missing capabilities") {
     val e = intercept[AnalysisException] {
-      TableCapabilityCheck.apply(DataSourceV2Relation.create(
-        CapabilityTable(),
-        CaseInsensitiveStringMap.empty))
+      TableCapabilityCheck.apply(
+        DataSourceV2Relation.create(CapabilityTable(), None, None, emptyMap)
+      )
     }
     assert(e.message.contains("does not support batch scan"))
   }
@@ -88,7 +89,8 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {
 
   test("AppendData: check missing capabilities") {
     val plan = AppendData.byName(
-      DataSourceV2Relation.create(CapabilityTable(), CaseInsensitiveStringMap.empty), TestRelation)
+      DataSourceV2Relation.create(CapabilityTable(), None, None, emptyMap),
+      TestRelation)
 
     val exc = intercept[AnalysisException]{
       TableCapabilityCheck.apply(plan)
@@ -100,7 +102,7 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {
   test("AppendData: check correct capabilities") {
     Seq(BATCH_WRITE, V1_BATCH_WRITE).foreach { write =>
       val plan = AppendData.byName(
-        DataSourceV2Relation.create(CapabilityTable(write), CaseInsensitiveStringMap.empty),
+        DataSourceV2Relation.create(CapabilityTable(write), None, None, emptyMap),
         TestRelation)
 
       TableCapabilityCheck.apply(plan)
@@ -115,7 +117,8 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {
       CapabilityTable(OVERWRITE_BY_FILTER)).foreach { table =>
 
       val plan = OverwriteByExpression.byName(
-        DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation,
+        DataSourceV2Relation.create(table, None, None, emptyMap),
+        TestRelation,
         Literal(true))
 
       val exc = intercept[AnalysisException]{
@@ -133,7 +136,8 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {
       CapabilityTable(V1_BATCH_WRITE, OVERWRITE_BY_FILTER)).foreach { table =>
 
       val plan = OverwriteByExpression.byName(
-        DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation,
+        DataSourceV2Relation.create(table, None, None, emptyMap),
+        TestRelation,
         Literal(true))
 
       TableCapabilityCheck.apply(plan)
@@ -147,7 +151,8 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {
       CapabilityTable(OVERWRITE_BY_FILTER)).foreach { table =>
 
       val plan = OverwriteByExpression.byName(
-        DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation,
+        DataSourceV2Relation.create(table, None, None, emptyMap),
+        TestRelation,
         EqualTo(AttributeReference("x", LongType)(), Literal(5)))
 
       val exc = intercept[AnalysisException]{
@@ -162,7 +167,8 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {
     Seq(BATCH_WRITE, V1_BATCH_WRITE).foreach { write =>
       val table = CapabilityTable(write, OVERWRITE_BY_FILTER)
       val plan = OverwriteByExpression.byName(
-        DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation,
+        DataSourceV2Relation.create(table, None, None, emptyMap),
+        TestRelation,
         EqualTo(AttributeReference("x", LongType)(), Literal(5)))
 
       TableCapabilityCheck.apply(plan)
@@ -175,7 +181,8 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {
       CapabilityTable(OVERWRITE_DYNAMIC)).foreach { table =>
 
       val plan = OverwritePartitionsDynamic.byName(
-        DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation)
+        DataSourceV2Relation.create(table, None, None, emptyMap),
+        TestRelation)
 
       val exc = intercept[AnalysisException] {
         TableCapabilityCheck.apply(plan)
@@ -188,7 +195,8 @@ class TableCapabilityCheckSuite extends AnalysisSuite with SharedSparkSession {
   test("OverwritePartitionsDynamic: check correct capabilities") {
     val table = CapabilityTable(BATCH_WRITE, OVERWRITE_DYNAMIC)
     val plan = OverwritePartitionsDynamic.byName(
-      DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation)
+      DataSourceV2Relation.create(table, None, None, emptyMap),
+      TestRelation)
 
     TableCapabilityCheck.apply(plan)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
index 8f17ce7f32c82..70b9b7ec12ea2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala
@@ -26,11 +26,11 @@ import org.mockito.invocation.InvocationOnMock
 
 import org.apache.spark.sql.{AnalysisException, SaveMode}
 import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier}
-import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Analyzer, CTESubstitution, EmptyFunctionRegistry, NoSuchTableException, ResolveCatalogs, ResolvedTable, ResolveSessionCatalog, UnresolvedAttribute, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedV2Relation}
+import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Analyzer, CTESubstitution, EmptyFunctionRegistry, NoSuchTableException, ResolveCatalogs, ResolvedTable, ResolveInlineTables, ResolveSessionCatalog, UnresolvedAttribute, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedV2Relation}
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog}
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Expression, InSubquery, IntegerLiteral, ListQuery, StringLiteral}
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
-import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, Assignment, CreateTableAsSelect, CreateV2Table, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project, SubqueryAlias, UpdateAction, UpdateTable}
+import org.apache.spark.sql.catalyst.plans.logical.{AlterTable, Assignment, CreateTableAsSelect, CreateV2Table, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, InsertIntoStatement, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, Project, ShowTableProperties, SubqueryAlias, UpdateAction, UpdateTable}
 import org.apache.spark.sql.connector.InMemoryTableProvider
 import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogNotFoundException, Identifier, Table, TableCapability, TableCatalog, TableChange, V1Table}
 import org.apache.spark.sql.execution.datasources.CreateTable
@@ -104,6 +104,8 @@ class PlanResolutionSuite extends AnalysisTest {
         invocation.getArgument[String](0) match {
           case "testcat" =>
             testCat
+          case CatalogManager.SESSION_CATALOG_NAME =>
+            v2SessionCatalog
           case name =>
             throw new CatalogNotFoundException(s"No such catalog: $name")
         }
@@ -139,6 +141,7 @@ class PlanResolutionSuite extends AnalysisTest {
     // TODO: run the analyzer directly.
     val rules = Seq(
       CTESubstitution,
+      ResolveInlineTables(conf),
       analyzer.ResolveRelations,
       new ResolveCatalogs(catalogManager),
       new ResolveSessionCatalog(catalogManager, conf, _ == Seq("v")),
@@ -1072,6 +1075,54 @@ class PlanResolutionSuite extends AnalysisTest {
     }
   }
 
+  val DSV2ResolutionTests = {
+    val v2SessionCatalogTable = s"${CatalogManager.SESSION_CATALOG_NAME}.v2Table"
+    Seq(
+      ("ALTER TABLE testcat.tab ALTER COLUMN i TYPE bigint", false),
+      ("ALTER TABLE tab ALTER COLUMN i TYPE bigint", false),
+      (s"ALTER TABLE $v2SessionCatalogTable ALTER COLUMN i TYPE bigint", true),
+      ("INSERT INTO TABLE tab VALUES (1)", false),
+      ("INSERT INTO TABLE testcat.tab VALUES (1)", false),
+      (s"INSERT INTO TABLE $v2SessionCatalogTable VALUES (1)", true),
+      ("DESC TABLE tab", false),
+      ("DESC TABLE testcat.tab", false),
+      (s"DESC TABLE $v2SessionCatalogTable", true),
+      ("SHOW TBLPROPERTIES tab", false),
+      ("SHOW TBLPROPERTIES testcat.tab", false),
+      (s"SHOW TBLPROPERTIES $v2SessionCatalogTable", true),
+      ("SELECT * from tab", false),
+      ("SELECT * from testcat.tab", false),
+      (s"SELECT * from ${CatalogManager.SESSION_CATALOG_NAME}.v2Table", true)
+    )
+  }
+
+  DSV2ResolutionTests.foreach { case (sql, isSessionCatalog) =>
+    test(s"Data source V2 relation resolution '$sql'") {
+      val parsed = parseAndResolve(sql, withDefault = true)
+      val catalogIdent = if (isSessionCatalog) v2SessionCatalog else testCat
+      val tableIdent = if (isSessionCatalog) "v2Table" else "tab"
+      parsed match {
+        case AlterTable(_, _, r: DataSourceV2Relation, _) =>
+          assert(r.catalog.exists(_ == catalogIdent))
+          assert(r.identifier.exists(_.name() == tableIdent))
+        case Project(_, r: DataSourceV2Relation) =>
+          assert(r.catalog.exists(_ == catalogIdent))
+          assert(r.identifier.exists(_.name() == tableIdent))
+        case InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _) =>
+          assert(r.catalog.exists(_ == catalogIdent))
+          assert(r.identifier.exists(_.name() == tableIdent))
+        case DescribeRelation(r: ResolvedTable, _, _) =>
+          assert(r.catalog == catalogIdent)
+          assert(r.identifier.name() == tableIdent)
+        case ShowTableProperties(r: ResolvedTable, _) =>
+          assert(r.catalog == catalogIdent)
+          assert(r.identifier.name() == tableIdent)
+        case ShowTablePropertiesCommand(t: TableIdentifier, _) =>
+          assert(t.identifier == tableIdent)
+      }
+    }
+  }
+
   test("MERGE INTO TABLE") {
     def checkResolution(
         target: LogicalPlan,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
index 6497a1ceb5c0e..e63929470ce5f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
@@ -1281,7 +1281,7 @@ class ParquetV2PartitionDiscoverySuite extends ParquetPartitionDiscoverySuite {
       (1 to 10).map(i => (i, i.toString)).toDF("a", "b").write.parquet(dir.getCanonicalPath)
       val queryExecution = spark.read.parquet(dir.getCanonicalPath).queryExecution
       queryExecution.analyzed.collectFirst {
-        case DataSourceV2Relation(fileTable: FileTable, _, _) =>
+        case DataSourceV2Relation(fileTable: FileTable, _, _, _, _) =>
           assert(fileTable.fileIndex.partitionSpec() === PartitionSpec.emptySpec)
       }.getOrElse {
         fail(s"Expecting a matching DataSourceV2Relation, but got:\n$queryExecution")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
index 9bce7f3568e81..877965100f018 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala
@@ -657,7 +657,7 @@ class FileStreamSinkV2Suite extends FileStreamSinkSuite {
         // Verify that MetadataLogFileIndex is being used and the correct partitioning schema has
         // been inferred
         val table = df.queryExecution.analyzed.collect {
-          case DataSourceV2Relation(table: FileTable, _, _) => table
+          case DataSourceV2Relation(table: FileTable, _, _, _, _) => table
        }
        assert(table.size === 1)
        assert(table.head.fileIndex.isInstanceOf[MetadataLogFileIndex])
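
For orientation, a minimal sketch of how the widened DataSourceV2Relation.create signature above is intended to be used once a table has been resolved through a catalog. The relationFor helper and the catalog/ident values it takes are hypothetical illustrations, not code from this patch:

    import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
    import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
    import org.apache.spark.sql.util.CaseInsensitiveStringMap

    // Hypothetical helper: load `ident` from `catalog` and wrap it in a relation that
    // remembers its origin, mirroring what Analyzer and DataFrameReader now do above.
    def relationFor(catalog: TableCatalog, ident: Identifier): DataSourceV2Relation = {
      val table = catalog.loadTable(ident)
      // The catalog and identifier now travel with the relation, so downstream rules,
      // query-execution listeners, and tests can recover where the table came from.
      DataSourceV2Relation.create(table, Some(catalog), Some(ident), CaseInsensitiveStringMap.empty)
    }

The three-argument overload, create(table, catalog, identifier), defaults the options to CaseInsensitiveStringMap.empty, matching the call sites in DataFrameWriter and DataFrameWriterV2.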