diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index ae82670a63d5c..af7ddd756ae89 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -22,22 +22,22 @@ import java.util.{Locale, Properties, UUID}
 import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.Stable
-import org.apache.spark.sql.catalog.v2.{CatalogPlugin, Identifier}
+import org.apache.spark.sql.catalog.v2.{CatalogPlugin, Identifier, TableCatalog}
+import org.apache.spark.sql.catalog.v2.expressions._
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, NoSuchTableException, UnresolvedRelation}
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.expressions.Literal
-import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic}
+import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, InsertIntoTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, ReplaceTableAsSelect}
 import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, DataSourceUtils, LogicalRelation}
 import org.apache.spark.sql.execution.datasources.v2._
 import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
 import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister}
-import org.apache.spark.sql.sources.BaseRelation
 import org.apache.spark.sql.sources.v2._
 import org.apache.spark.sql.sources.v2.TableCapability._
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{IntegerType, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 /**
@@ -360,6 +360,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
    */
   def insertInto(tableName: String): Unit = {
     import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, CatalogObjectIdentifier}
+    import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._
 
     assertNotBucketed("insertInto")
 
@@ -374,8 +375,12 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
 
     df.sparkSession.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
       case CatalogObjectIdentifier(Some(catalog), ident) =>
         insertInto(catalog, ident)
+      // TODO(SPARK-28667): Support the V2SessionCatalog
       case AsTableIdentifier(tableIdentifier) =>
         insertInto(tableIdentifier)
+      case other =>
+        throw new AnalysisException(
+          s"Couldn't find a catalog to handle the identifier ${other.quoted}.")
     }
   }
 
@@ -485,7 +490,70 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
    * @since 1.4.0
    */
   def saveAsTable(tableName: String): Unit = {
-    saveAsTable(df.sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName))
+    import df.sparkSession.sessionState.analyzer.{AsTableIdentifier, CatalogObjectIdentifier}
+    import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._
+
+    val session = df.sparkSession
+
+    session.sessionState.sqlParser.parseMultipartIdentifier(tableName) match {
+      case CatalogObjectIdentifier(Some(catalog), ident) =>
+        saveAsTable(catalog.asTableCatalog, ident, modeForDSV2)
+      // TODO(SPARK-28666): This should go through V2SessionCatalog
+
+      case AsTableIdentifier(tableIdentifier) =>
+        saveAsTable(tableIdentifier)
+
+      case other =>
+        throw new AnalysisException(
+          s"Couldn't find a catalog to handle the identifier ${other.quoted}.")
+    }
+  }
+
+
+  private def saveAsTable(catalog: TableCatalog, ident: Identifier, mode: SaveMode): Unit = {
+    val partitioning = partitioningColumns.map { colNames =>
+      colNames.map(name => IdentityTransform(FieldReference(name)))
+    }.getOrElse(Seq.empty[Transform])
+    val bucketing = bucketColumnNames.map { cols =>
+      Seq(BucketTransform(LiteralValue(numBuckets.get, IntegerType), cols.map(FieldReference(_))))
+    }.getOrElse(Seq.empty[Transform])
+    val partitionTransforms = partitioning ++ bucketing
+
+    val tableOpt = try Option(catalog.loadTable(ident)) catch {
+      case _: NoSuchTableException => None
+    }
+
+    val command = (mode, tableOpt) match {
+      case (SaveMode.Append, Some(table)) =>
+        AppendData.byName(DataSourceV2Relation.create(table), df.logicalPlan)
+
+      case (SaveMode.Overwrite, _) =>
+        ReplaceTableAsSelect(
+          catalog,
+          ident,
+          partitionTransforms,
+          df.queryExecution.analyzed,
+          Map.empty, // properties can't be specified through this API
+          extraOptions.toMap,
+          orCreate = true) // Create the table if it doesn't exist
+
+      case (other, _) =>
+        // We have a potential race condition here in AppendMode, if the table suddenly gets
+        // created between our existence check and physical execution, but this can't be helped
+        // in any case.
+        CreateTableAsSelect(
+          catalog,
+          ident,
+          partitionTransforms,
+          df.queryExecution.analyzed,
+          Map.empty,
+          extraOptions.toMap,
+          ignoreIfExists = other == SaveMode.Ignore)
+    }
+
+    runCommand(df.sparkSession, "saveAsTable") {
+      command
+    }
   }
 
   private def saveAsTable(tableIdent: TableIdentifier): Unit = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2DataFrameSuite.scala
index 755cabc620023..8909c41ddaa8f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2DataFrameSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.sources.v2
 
 import org.scalatest.BeforeAndAfter
 
-import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.{QueryTest, Row}
 import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode}
 import org.apache.spark.sql.test.SharedSQLContext
 
@@ -141,4 +141,66 @@ class DataSourceV2DataFrameSuite extends QueryTest with SharedSQLContext with Be
       }
     }
   }
+
+  testQuietly("saveAsTable: table doesn't exist => create table") {
+    val t1 = "testcat.ns1.ns2.tbl"
+    withTable(t1) {
+      val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
+      df.write.saveAsTable(t1)
+      checkAnswer(spark.table(t1), df)
+    }
+  }
+
+  testQuietly("saveAsTable: table exists => append by name") {
+    val t1 = "testcat.ns1.ns2.tbl"
+    withTable(t1) {
+      sql(s"CREATE TABLE $t1 (id bigint, data string) USING foo")
+      val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data")
+      // Default saveMode is append, therefore this doesn't throw a table already exists exception
+      df.write.saveAsTable(t1)
+      checkAnswer(spark.table(t1), df)
+
+      // also appends are by name not by position
+      df.select('data, 'id).write.saveAsTable(t1)
+      checkAnswer(spark.table(t1), df.union(df))
+    }
+  }
+
testQuietly("saveAsTable: table overwrite and table doesn't exist => create table") { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") + df.write.mode("overwrite").saveAsTable(t1) + checkAnswer(spark.table(t1), df) + } + } + + testQuietly("saveAsTable: table overwrite and table exists => replace table") { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + sql(s"CREATE TABLE $t1 USING foo AS SELECT 'c', 'd'") + val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") + df.write.mode("overwrite").saveAsTable(t1) + checkAnswer(spark.table(t1), df) + } + } + + testQuietly("saveAsTable: ignore mode and table doesn't exist => create table") { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") + df.write.mode("ignore").saveAsTable(t1) + checkAnswer(spark.table(t1), df) + } + } + + testQuietly("saveAsTable: ignore mode and table exists => do nothing") { + val t1 = "testcat.ns1.ns2.tbl" + withTable(t1) { + val df = Seq((1L, "a"), (2L, "b"), (3L, "c")).toDF("id", "data") + sql(s"CREATE TABLE $t1 USING foo AS SELECT 'c', 'd'") + df.write.mode("ignore").saveAsTable(t1) + checkAnswer(spark.table(t1), Seq(Row("c", "d"))) + } + } }