78 changes: 46 additions & 32 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -23,11 +23,12 @@ import scala.collection.JavaConverters._

import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation}
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType}
import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource}
import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.StructType

/**
@@ -364,7 +365,11 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
throw new AnalysisException("Cannot create hive serde table with saveAsTable API")
}

val tableExists = df.sparkSession.sessionState.catalog.tableExists(tableIdent)
val catalog = df.sparkSession.sessionState.catalog
val tableExists = catalog.tableExists(tableIdent)
val db = tableIdent.database.getOrElse(catalog.getCurrentDatabase)
val tableIdentWithDB = tableIdent.copy(database = Some(db))
val tableName = tableIdentWithDB.unquotedString

(tableExists, mode) match {
case (true, SaveMode.Ignore) =>
@@ -373,39 +378,48 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
case (true, SaveMode.ErrorIfExists) =>
throw new AnalysisException(s"Table $tableIdent already exists.")

case _ =>
val existingTable = if (tableExists) {
Some(df.sparkSession.sessionState.catalog.getTableMetadata(tableIdent))
} else {
None
case (true, SaveMode.Overwrite) =>
// Get all input data source relations of the query.
val srcRelations = df.logicalPlan.collect {
case LogicalRelation(src: BaseRelation, _, _) => src
}
val storage = if (tableExists) {
existingTable.get.storage
} else {
DataSource.buildStorageFormatFromOptions(extraOptions.toMap)
}
val tableType = if (tableExists) {
existingTable.get.tableType
} else if (storage.locationUri.isDefined) {
CatalogTableType.EXTERNAL
} else {
CatalogTableType.MANAGED
EliminateSubqueryAliases(catalog.lookupRelation(tableIdentWithDB)) match {
// Only do the check if the table is a data source table (the relation is a BaseRelation).
case LogicalRelation(dest: BaseRelation, _, _) if srcRelations.contains(dest) =>
throw new AnalysisException(
s"Cannot overwrite table $tableName that is also being read from")
case _ => // OK
}

val tableDesc = CatalogTable(
identifier = tableIdent,
tableType = tableType,
storage = storage,
schema = new StructType,
provider = Some(source),
partitionColumnNames = partitioningColumns.getOrElse(Nil),
bucketSpec = getBucketSpec
)
df.sparkSession.sessionState.executePlan(
CreateTable(tableDesc, mode, Some(df.logicalPlan))).toRdd
// Drop the existing table
catalog.dropTable(tableIdentWithDB, ignoreIfNotExists = true, purge = false)
createTable(tableIdent)

case _ => createTable(tableIdent)
}
}
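For context, a minimal sketch of the situation the new (true, SaveMode.Overwrite) branch rejects: reading a data source table and overwriting that same table in one query. The session and the table name t1 are hypothetical.

import org.apache.spark.sql.SaveMode

spark.range(10).write.saveAsTable("t1")        // create a data source table `t1`
val df = spark.table("t1")                     // the query reads `t1`
df.write.mode(SaveMode.Overwrite).saveAsTable("t1")
// expected to fail fast with an AnalysisException along the lines of
// "Cannot overwrite table default.t1 that is also being read from"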

private def createTable(tableIdent: TableIdentifier): Unit = {
val storage = DataSource.buildStorageFormatFromOptions(extraOptions.toMap)
val tableType = if (storage.locationUri.isDefined) {
Review thread on this line:
Contributor Author: I reverted #15983 here because it's not needed anymore after this refactor.
Contributor: Which part of that PR is reverted?
Contributor Author: All of it, except the test.
CatalogTableType.EXTERNAL
} else {
CatalogTableType.MANAGED
}

val tableDesc = CatalogTable(
identifier = tableIdent,
tableType = tableType,
storage = storage,
schema = new StructType,
provider = Some(source),
partitionColumnNames = partitioningColumns.getOrElse(Nil),
bucketSpec = getBucketSpec
)
df.sparkSession.sessionState.executePlan(
CreateTable(tableDesc, mode, Some(df.logicalPlan))).toRdd
}

/**
* Saves the content of the `DataFrame` to an external database table via JDBC. In the case the
* table already exists in the external database, behavior of this function depends on the
@@ -441,7 +455,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
assertNotPartitioned("jdbc")
assertNotBucketed("jdbc")
// connectionProperties should override settings in extraOptions.
this.extraOptions = this.extraOptions ++ connectionProperties.asScala
this.extraOptions ++= connectionProperties.asScala
// explicit url and dbtable should override all
this.extraOptions += ("url" -> url, "dbtable" -> table)
format("jdbc").save()
@@ -588,7 +602,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {

private var mode: SaveMode = SaveMode.ErrorIfExists

private var extraOptions = new scala.collection.mutable.HashMap[String, String]
private val extraOptions = new scala.collection.mutable.HashMap[String, String]

private var partitioningColumns: Option[Seq[String]] = None

sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.command

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources.BaseRelation
@@ -134,146 +133,39 @@ case class CreateDataSourceTableAsSelectCommand(
assert(table.provider.isDefined)
assert(table.schema.isEmpty)

val provider = table.provider.get
val sessionState = sparkSession.sessionState
val db = table.identifier.database.getOrElse(sessionState.catalog.getCurrentDatabase)
val tableIdentWithDB = table.identifier.copy(database = Some(db))
val tableName = tableIdentWithDB.unquotedString

var createMetastoreTable = false
// We may need to reorder the columns of the query to match the existing table.
var reorderedColumns = Option.empty[Seq[NamedExpression]]
if (sessionState.catalog.tableExists(tableIdentWithDB)) {
// Check if we need to throw an exception or just return.
mode match {
case SaveMode.ErrorIfExists =>
throw new AnalysisException(s"Table $tableName already exists. " +
s"If you are using saveAsTable, you can set SaveMode to SaveMode.Append to " +
s"insert data into the table or set SaveMode to SaveMode.Overwrite to overwrite" +
s"the existing data. " +
s"Or, if you are using SQL CREATE TABLE, you need to drop $tableName first.")
case SaveMode.Ignore =>
// Since the table already exists and the save mode is Ignore, we will just return.
return Seq.empty[Row]
case SaveMode.Append =>
val existingTable = sessionState.catalog.getTableMetadata(tableIdentWithDB)
val result = if (sessionState.catalog.tableExists(tableIdentWithDB)) {
assert(mode != SaveMode.Overwrite,
s"Expect the table $tableName has been dropped when the save mode is Overwrite")

if (existingTable.provider.get == DDLUtils.HIVE_PROVIDER) {
throw new AnalysisException(s"Saving data in the Hive serde table $tableName is " +
"not supported yet. Please use the insertInto() API as an alternative.")
}

// Check if the specified data source match the data source of the existing table.
val existingProvider = DataSource.lookupDataSource(existingTable.provider.get)
val specifiedProvider = DataSource.lookupDataSource(table.provider.get)
// TODO: Check that options from the resolved relation match the relation that we are
// inserting into (i.e. using the same compression).
if (existingProvider != specifiedProvider) {
throw new AnalysisException(s"The format of the existing table $tableName is " +
s"`${existingProvider.getSimpleName}`. It doesn't match the specified format " +
s"`${specifiedProvider.getSimpleName}`.")
}

if (query.schema.length != existingTable.schema.length) {
throw new AnalysisException(
s"The column number of the existing table $tableName" +
s"(${existingTable.schema.catalogString}) doesn't match the data schema" +
s"(${query.schema.catalogString})")
}

val resolver = sessionState.conf.resolver
val tableCols = existingTable.schema.map(_.name)

reorderedColumns = Some(existingTable.schema.map { f =>
query.resolve(Seq(f.name), resolver).getOrElse {
val inputColumns = query.schema.map(_.name).mkString(", ")
throw new AnalysisException(
s"cannot resolve '${f.name}' given input columns: [$inputColumns]")
}
})

// In `AnalyzeCreateTable`, we verified the consistency between the user-specified table
// definition(partition columns, bucketing) and the SELECT query, here we also need to
// verify the consistency between the user-specified table definition and the existing
// table definition.

// Check if the specified partition columns match the existing table.
val specifiedPartCols = CatalogUtils.normalizePartCols(
tableName, tableCols, table.partitionColumnNames, resolver)
if (specifiedPartCols != existingTable.partitionColumnNames) {
throw new AnalysisException(
s"""
|Specified partitioning does not match that of the existing table $tableName.
|Specified partition columns: [${specifiedPartCols.mkString(", ")}]
|Existing partition columns: [${existingTable.partitionColumnNames.mkString(", ")}]
""".stripMargin)
}

// Check if the specified bucketing match the existing table.
val specifiedBucketSpec = table.bucketSpec.map { bucketSpec =>
CatalogUtils.normalizeBucketSpec(tableName, tableCols, bucketSpec, resolver)
}
if (specifiedBucketSpec != existingTable.bucketSpec) {
val specifiedBucketString =
specifiedBucketSpec.map(_.toString).getOrElse("not bucketed")
val existingBucketString =
existingTable.bucketSpec.map(_.toString).getOrElse("not bucketed")
throw new AnalysisException(
s"""
|Specified bucketing does not match that of the existing table $tableName.
|Specified bucketing: $specifiedBucketString
|Existing bucketing: $existingBucketString
""".stripMargin)
}

case SaveMode.Overwrite =>
sessionState.catalog.dropTable(tableIdentWithDB, ignoreIfNotExists = true, purge = false)
// Need to create the table again.
createMetastoreTable = true
if (mode == SaveMode.ErrorIfExists) {
throw new AnalysisException(s"Table $tableName already exists. You need to drop it first.")
}
if (mode == SaveMode.Ignore) {
// Since the table already exists and the save mode is Ignore, we will just return.
return Seq.empty
}
} else {
// The table does not exist. We need to create it in metastore.
createMetastoreTable = true
}

val data = Dataset.ofRows(sparkSession, query)
val df = reorderedColumns match {
// Reorder the columns of the query to match the existing table.
case Some(cols) => data.select(cols.map(Column(_)): _*)
case None => data
}

val tableLocation = if (table.tableType == CatalogTableType.MANAGED) {
Some(sessionState.catalog.defaultTablePath(table.identifier))
saveDataIntoTable(sparkSession, table, table.storage.locationUri, query, mode)
} else {
table.storage.locationUri
}

// Create the relation based on the data of df.
val pathOption = tableLocation.map("path" -> _)
val dataSource = DataSource(
sparkSession,
className = provider,
partitionColumns = table.partitionColumnNames,
bucketSpec = table.bucketSpec,
options = table.storage.properties ++ pathOption,
catalogTable = Some(table))

val result = try {
dataSource.write(mode, df)
} catch {
case ex: AnalysisException =>
logError(s"Failed to write to table $tableName in $mode mode", ex)
throw ex
}
if (createMetastoreTable) {
val tableLocation = if (table.tableType == CatalogTableType.MANAGED) {
Some(sessionState.catalog.defaultTablePath(table.identifier))
} else {
table.storage.locationUri
}
val result = saveDataIntoTable(sparkSession, table, tableLocation, query, mode)
val newTable = table.copy(
storage = table.storage.copy(locationUri = tableLocation),
// We will use the schema of resolved.relation as the schema of the table (instead of
// the schema of df). It is important since the nullability may be changed by the relation
// provider (for example, see org.apache.spark.sql.parquet.DefaultSource).
schema = result.schema)
sessionState.catalog.createTable(newTable, ignoreIfExists = false)
result
}

result match {
@@ -289,4 +181,29 @@ case class CreateDataSourceTableAsSelectCommand(
sessionState.catalog.refreshTable(tableIdentWithDB)
Seq.empty[Row]
}

private def saveDataIntoTable(
session: SparkSession,
table: CatalogTable,
tableLocation: Option[String],
data: LogicalPlan,
mode: SaveMode): BaseRelation = {
// Create the relation based on the input logical plan: `data`.
val pathOption = tableLocation.map("path" -> _)
val dataSource = DataSource(
session,
className = table.provider.get,
partitionColumns = table.partitionColumnNames,
bucketSpec = table.bucketSpec,
options = table.storage.properties ++ pathOption,
catalogTable = Some(table))

try {
dataSource.write(mode, Dataset.ofRows(session, query))
} catch {
case ex: AnalysisException =>
logError(s"Failed to write to table ${table.identifier.unquotedString}", ex)
throw ex
}
}
}
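A rough usage sketch of the two paths run() now takes for a data source table, assuming saveAsTable plans into this command and the Overwrite case has already been handled before execution (as the assertion above requires); the table name is hypothetical.

spark.range(5).write.saveAsTable("events")
// table absent: write the data via saveDataIntoTable, then register `events` with the written schema

spark.range(5).write.mode(SaveMode.Append).saveAsTable("events")
// table exists: append into its existing location, then refresh the cached table metadata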