diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index f0143fdb2347..5834f9bad4a1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -860,6 +860,17 @@ class Analyzer(
         lookupTempView(ident)
           .map(view => i.copy(table = view))
           .getOrElse(i)
+      // TODO (SPARK-27484): handle streaming write commands when we have them.
+      case write: V2WriteCommand =>
+        write.table match {
+          case UnresolvedRelation(ident, _, false) =>
+            lookupTempView(ident).map(EliminateSubqueryAliases(_)).map {
+              case r: DataSourceV2Relation => write.withNewTable(r)
+              case _ => throw new AnalysisException("Cannot write into temp view " +
+                s"${ident.quoted} as it's not a data source v2 relation.")
+            }.getOrElse(write)
+          case _ => write
+        }
       case u @ UnresolvedTable(ident) =>
         lookupTempView(ident).foreach { _ =>
           u.failAnalysis(s"${ident.quoted} is a temp view not table.")
@@ -942,6 +953,18 @@
           .map(v2Relation => i.copy(table = v2Relation))
           .getOrElse(i)
 
+      // TODO (SPARK-27484): handle streaming write commands when we have them.
+      case write: V2WriteCommand =>
+        write.table match {
+          case u: UnresolvedRelation if !u.isStreaming =>
+            lookupV2Relation(u.multipartIdentifier, u.options, false).map {
+              case r: DataSourceV2Relation => write.withNewTable(r)
+              case other => throw new IllegalStateException(
+                "[BUG] unexpected plan returned by `lookupV2Relation`: " + other)
+            }.getOrElse(write)
+          case _ => write
+        }
+
       case alter @ AlterTable(_, _, u: UnresolvedV2Relation, _) =>
         CatalogV2Util.loadRelation(u.catalog, u.tableName)
           .map(rel => alter.copy(table = rel))
@@ -1019,6 +1042,24 @@
           case other => i.copy(table = other)
         }
 
+      // TODO (SPARK-27484): handle streaming write commands when we have them.
+      case write: V2WriteCommand =>
+        write.table match {
+          case u: UnresolvedRelation if !u.isStreaming =>
+            lookupRelation(u.multipartIdentifier, u.options, false)
+              .map(EliminateSubqueryAliases(_))
+              .map {
+                case v: View => write.failAnalysis(
+                  s"Writing into a view is not allowed. View: ${v.desc.identifier}.")
+                case u: UnresolvedCatalogRelation => write.failAnalysis(
+                  "Cannot write into v1 table: " + u.tableMeta.identifier)
+                case r: DataSourceV2Relation => write.withNewTable(r)
+                case other => throw new IllegalStateException(
+                  "[BUG] unexpected plan returned by `lookupRelation`: " + other)
+              }.getOrElse(write)
+          case _ => write
+        }
+
       case u: UnresolvedRelation =>
         lookupRelation(u.multipartIdentifier, u.options, u.isStreaming)
           .map(resolveViews).getOrElse(u)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index ac91fa0b5811..33a5224ed293 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -108,6 +108,11 @@ trait CheckAnalysis extends PredicateHelper {
       case InsertIntoStatement(u: UnresolvedRelation, _, _, _, _) =>
         failAnalysis(s"Table not found: ${u.multipartIdentifier.quoted}")
 
+      // TODO (SPARK-27484): handle streaming write commands when we have them.
+      case write: V2WriteCommand if write.table.isInstanceOf[UnresolvedRelation] =>
+        val tblName = write.table.asInstanceOf[UnresolvedRelation].multipartIdentifier
+        write.table.failAnalysis(s"Table or view not found: ${tblName.quoted}")
+
       case u: UnresolvedV2Relation if isView(u.originalNameParts) =>
         u.failAnalysis(
           s"Invalid command: '${u.originalNameParts.quoted}' is a view not a table.")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index fb8a9be80385..94d4e7ecfac2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -53,6 +53,7 @@ trait V2WriteCommand extends Command {
   }
 
   def withNewQuery(newQuery: LogicalPlan): V2WriteCommand
+  def withNewTable(newTable: NamedRelation): V2WriteCommand
 }
 
 /**
@@ -64,6 +65,7 @@ case class AppendData(
     writeOptions: Map[String, String],
     isByName: Boolean) extends V2WriteCommand {
   override def withNewQuery(newQuery: LogicalPlan): AppendData = copy(query = newQuery)
+  override def withNewTable(newTable: NamedRelation): AppendData = copy(table = newTable)
 }
 
 object AppendData {
@@ -97,6 +99,9 @@ case class OverwriteByExpression(
   override def withNewQuery(newQuery: LogicalPlan): OverwriteByExpression = {
     copy(query = newQuery)
   }
+  override def withNewTable(newTable: NamedRelation): OverwriteByExpression = {
+    copy(table = newTable)
+  }
 }
 
 object OverwriteByExpression {
@@ -128,6 +133,9 @@ case class OverwritePartitionsDynamic(
   override def withNewQuery(newQuery: LogicalPlan): OverwritePartitionsDynamic = {
     copy(query = newQuery)
   }
+  override def withNewTable(newTable: NamedRelation): OverwritePartitionsDynamic = {
+    copy(table = newTable)
+  }
 }
 
 object OverwritePartitionsDynamic {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
index 87f35410172d..d55b5c310353 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
@@ -21,12 +21,11 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable
 
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException}
+import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NoSuchTableException, TableAlreadyExistsException, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Bucket, Days, Hours, Literal, Months, Years}
 import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelectStatement, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelectStatement}
 import org.apache.spark.sql.connector.expressions.{LogicalExpressions, NamedReference, Transform}
 import org.apache.spark.sql.execution.SQLExecution
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.types.IntegerType
 
 /**
@@ -38,21 +37,12 @@ import org.apache.spark.sql.types.IntegerType
 final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
     extends CreateTableWriter[T] {
 
-  import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
-  import org.apache.spark.sql.connector.catalog.CatalogV2Util._
-  import df.sparkSession.sessionState.analyzer.CatalogAndIdentifier
-
   private val df: DataFrame = ds.toDF()
 
   private val sparkSession = ds.sparkSession
 
   private val tableName = sparkSession.sessionState.sqlParser.parseMultipartIdentifier(table)
 
-  private val (catalog, identifier) = {
-    val CatalogAndIdentifier(catalog, identifier) = tableName
-    (catalog.asTableCatalog, identifier)
-  }
-
   private val logicalPlan = df.queryExecution.logical
 
   private var provider: Option[String] = None
@@ -153,15 +143,7 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
    */
   @throws(classOf[NoSuchTableException])
   def append(): Unit = {
-    val append = loadTable(catalog, identifier) match {
-      case Some(t) =>
-        AppendData.byName(
-          DataSourceV2Relation.create(t, Some(catalog), Some(identifier)),
-          logicalPlan, options.toMap)
-      case _ =>
-        throw new NoSuchTableException(identifier)
-    }
-
+    val append = AppendData.byName(UnresolvedRelation(tableName), logicalPlan, options.toMap)
     runCommand("append")(append)
   }
 
@@ -177,15 +159,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
    */
   @throws(classOf[NoSuchTableException])
   def overwrite(condition: Column): Unit = {
-    val overwrite = loadTable(catalog, identifier) match {
-      case Some(t) =>
-        OverwriteByExpression.byName(
-          DataSourceV2Relation.create(t, Some(catalog), Some(identifier)),
-          logicalPlan, condition.expr, options.toMap)
-      case _ =>
-        throw new NoSuchTableException(identifier)
-    }
-
+    val overwrite = OverwriteByExpression.byName(
+      UnresolvedRelation(tableName), logicalPlan, condition.expr, options.toMap)
     runCommand("overwrite")(overwrite)
   }
 
@@ -204,15 +179,8 @@ final class DataFrameWriterV2[T] private[sql](table: String, ds: Dataset[T])
    */
   @throws(classOf[NoSuchTableException])
   def overwritePartitions(): Unit = {
-    val dynamicOverwrite = loadTable(catalog, identifier) match {
-      case Some(t) =>
-        OverwritePartitionsDynamic.byName(
-          DataSourceV2Relation.create(t, Some(catalog), Some(identifier)),
-          logicalPlan, options.toMap)
-      case _ =>
-        throw new NoSuchTableException(identifier)
-    }
-
+    val dynamicOverwrite = OverwritePartitionsDynamic.byName(
+      UnresolvedRelation(tableName), logicalPlan, options.toMap)
     runCommand("overwritePartitions")(dynamicOverwrite)
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
index 8720c1f62056..de791383326f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
 import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform}
 import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
+import org.apache.spark.sql.sources.FakeSourceOne
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType, TimestampType}
 import org.apache.spark.sql.util.QueryExecutionListener
@@ -57,6 +58,7 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
   }
 
   after {
+    spark.sessionState.catalog.reset()
     spark.sessionState.catalogManager.reset()
     spark.sessionState.conf.clear()
   }
@@ -118,6 +120,18 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
       Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c"), Row(4L, "d"), Row(5L, "e"), Row(6L, "f")))
   }
 
+  test("Append: write to a temp view of v2 relation") {
+    spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")
+    spark.table("testcat.table_name").createOrReplaceTempView("temp_view")
+    spark.table("source").writeTo("temp_view").append()
+    checkAnswer(
+      spark.table("testcat.table_name"),
+      Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c")))
+    checkAnswer(
+      spark.table("temp_view"),
+      Seq(Row(1L, "a"), Row(2L, "b"), Row(3L, "c")))
+  }
+
   test("Append: by name not position") {
     spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")
 
@@ -136,11 +150,36 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
   }
 
   test("Append: fail if table does not exist") {
-    val exc = intercept[NoSuchTableException] {
+    val exc = intercept[AnalysisException] {
       spark.table("source").writeTo("testcat.table_name").append()
     }
 
-    assert(exc.getMessage.contains("table_name"))
+    assert(exc.getMessage.contains("Table or view not found: testcat.table_name"))
+  }
+
+  test("Append: fail if it writes to a temp view that is not v2 relation") {
+    spark.range(10).createOrReplaceTempView("temp_view")
+    val exc = intercept[AnalysisException] {
+      spark.table("source").writeTo("temp_view").append()
+    }
+    assert(exc.getMessage.contains("Cannot write into temp view temp_view as it's not a " +
+      "data source v2 relation"))
+  }
+
+  test("Append: fail if it writes to a view") {
+    spark.sql("CREATE VIEW v AS SELECT 1")
+    val exc = intercept[AnalysisException] {
+      spark.table("source").writeTo("v").append()
+    }
+    assert(exc.getMessage.contains("Writing into a view is not allowed"))
+  }
+
+  test("Append: fail if it writes to a v1 table") {
+    sql(s"CREATE TABLE table_name USING ${classOf[FakeSourceOne].getName}")
+    val exc = intercept[AnalysisException] {
+      spark.table("source").writeTo("table_name").append()
+    }
+    assert(exc.getMessage.contains("Cannot write into v1 table: `default`.`table_name`"))
   }
 
   test("Overwrite: overwrite by expression: true") {
@@ -181,6 +220,20 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
       Seq(Row(1L, "a"), Row(2L, "b"), Row(4L, "d"), Row(5L, "e"), Row(6L, "f")))
   }
 
+  test("Overwrite: write to a temp view of v2 relation") {
+    spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")
+    spark.table("source").writeTo("testcat.table_name").append()
+    spark.table("testcat.table_name").createOrReplaceTempView("temp_view")
+
+    spark.table("source2").writeTo("testcat.table_name").overwrite(lit(true))
+    checkAnswer(
+      spark.table("testcat.table_name"),
+      Seq(Row(4L, "d"), Row(5L, "e"), Row(6L, "f")))
+    checkAnswer(
+      spark.table("temp_view"),
+      Seq(Row(4L, "d"), Row(5L, "e"), Row(6L, "f")))
+  }
+
   test("Overwrite: by name not position") {
     spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")
 
@@ -200,11 +253,36 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
   }
 
   test("Overwrite: fail if table does not exist") {
-    val exc = intercept[NoSuchTableException] {
+    val exc = intercept[AnalysisException] {
       spark.table("source").writeTo("testcat.table_name").overwrite(lit(true))
     }
 
-    assert(exc.getMessage.contains("table_name"))
+    assert(exc.getMessage.contains("Table or view not found: testcat.table_name"))
+  }
+
+  test("Overwrite: fail if it writes to a temp view that is not v2 relation") {
+    spark.range(10).createOrReplaceTempView("temp_view")
+    val exc = intercept[AnalysisException] {
+      spark.table("source").writeTo("temp_view").overwrite(lit(true))
+    }
+    assert(exc.getMessage.contains("Cannot write into temp view temp_view as it's not a " +
+      "data source v2 relation"))
+  }
+
+  test("Overwrite: fail if it writes to a view") {
+    spark.sql("CREATE VIEW v AS SELECT 1")
+    val exc = intercept[AnalysisException] {
+      spark.table("source").writeTo("v").overwrite(lit(true))
+    }
+    assert(exc.getMessage.contains("Writing into a view is not allowed"))
+  }
+
+  test("Overwrite: fail if it writes to a v1 table") {
+    sql(s"CREATE TABLE table_name USING ${classOf[FakeSourceOne].getName}")
+    val exc = intercept[AnalysisException] {
+      spark.table("source").writeTo("table_name").overwrite(lit(true))
+    }
+    assert(exc.getMessage.contains("Cannot write into v1 table: `default`.`table_name`"))
   }
 
   test("OverwritePartitions: overwrite conflicting partitions") {
@@ -245,6 +323,20 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
       Seq(Row(4L, "d"), Row(5L, "e"), Row(6L, "f")))
   }
 
+  test("OverwritePartitions: write to a temp view of v2 relation") {
+    spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")
+    spark.table("source").writeTo("testcat.table_name").append()
+    spark.table("testcat.table_name").createOrReplaceTempView("temp_view")
+
+    spark.table("source2").writeTo("testcat.table_name").overwritePartitions()
+    checkAnswer(
+      spark.table("testcat.table_name"),
+      Seq(Row(4L, "d"), Row(5L, "e"), Row(6L, "f")))
+    checkAnswer(
+      spark.table("temp_view"),
+      Seq(Row(4L, "d"), Row(5L, "e"), Row(6L, "f")))
+  }
+
   test("OverwritePartitions: by name not position") {
     spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")
 
@@ -264,11 +356,36 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo
   }
 
   test("OverwritePartitions: fail if table does not exist") {
-    val exc = intercept[NoSuchTableException] {
+    val exc = intercept[AnalysisException] {
       spark.table("source").writeTo("testcat.table_name").overwritePartitions()
     }
 
-    assert(exc.getMessage.contains("table_name"))
+    assert(exc.getMessage.contains("Table or view not found: testcat.table_name"))
+  }
+
+  test("OverwritePartitions: fail if it writes to a temp view that is not v2 relation") {
+    spark.range(10).createOrReplaceTempView("temp_view")
+    val exc = intercept[AnalysisException] {
+      spark.table("source").writeTo("temp_view").overwritePartitions()
+    }
+    assert(exc.getMessage.contains("Cannot write into temp view temp_view as it's not a " +
+      "data source v2 relation"))
+  }
+
+  test("OverwritePartitions: fail if it writes to a view") {
+    spark.sql("CREATE VIEW v AS SELECT 1")
+    val exc = intercept[AnalysisException] {
+      spark.table("source").writeTo("v").overwritePartitions()
+    }
+    assert(exc.getMessage.contains("Writing into a view is not allowed"))
+  }
+
+  test("OverwritePartitions: fail if it writes to a v1 table") {
+    sql(s"CREATE TABLE table_name USING ${classOf[FakeSourceOne].getName}")
+    val exc = intercept[AnalysisException] {
+      spark.table("source").writeTo("table_name").overwritePartitions()
+    }
+    assert(exc.getMessage.contains("Cannot write into v1 table: `default`.`table_name`"))
   }
 
   test("Create: basic behavior") {