@@ -367,7 +367,7 @@ package object dsl {
def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan =
InsertIntoTable(
analysis.UnresolvedRelation(TableIdentifier(tableName)),
Map.empty, logicalPlan, OverwriteOptions(overwrite), false)
Map.empty, logicalPlan, overwrite, false)

def as(alias: String): LogicalPlan = logicalPlan match {
case UnresolvedRelation(tbl, _) => UnresolvedRelation(tbl, Option(alias))
@@ -177,15 +177,12 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
throw new ParseException(s"Dynamic partitions do not support IF NOT EXISTS. Specified " +
"partitions with value: " + dynamicPartitionKeys.keys.mkString("[", ",", "]"), ctx)
}
val overwrite = ctx.OVERWRITE != null
val staticPartitionKeys: Map[String, String] =
partitionKeys.filter(_._2.nonEmpty).map(t => (t._1, t._2.get))

InsertIntoTable(
UnresolvedRelation(tableIdent, None),
partitionKeys,
query,
OverwriteOptions(overwrite, if (overwrite) staticPartitionKeys else Map.empty),
ctx.OVERWRITE != null,
ctx.EXISTS != null)
}

@@ -17,10 +17,8 @@

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.catalog.CatalogTypes
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
@@ -346,22 +344,6 @@ case class BroadcastHint(child: LogicalPlan) extends UnaryNode {
override lazy val statistics: Statistics = super.statistics.copy(isBroadcastable = true)
}

/**
* Options for writing new data into a table.
*
* @param enabled whether to overwrite existing data in the table.
* @param staticPartitionKeys if non-empty, specifies that we only want to overwrite partitions
* that match this partial partition spec. If empty, all partitions
* will be overwritten.
*/
case class OverwriteOptions(
enabled: Boolean,
staticPartitionKeys: CatalogTypes.TablePartitionSpec = Map.empty) {
if (staticPartitionKeys.nonEmpty) {
assert(enabled, "Overwrite must be enabled when specifying specific partitions.")
}
}

/**
* Insert some data into a table.
*
@@ -382,14 +364,14 @@ case class InsertIntoTable(
table: LogicalPlan,
partition: Map[String, Option[String]],
child: LogicalPlan,
overwrite: OverwriteOptions,
overwrite: Boolean,
ifNotExists: Boolean)
extends LogicalPlan {

override def children: Seq[LogicalPlan] = child :: Nil
override def output: Seq[Attribute] = Seq.empty

assert(overwrite.enabled || !ifNotExists)
assert(overwrite || !ifNotExists)
assert(partition.values.forall(_.nonEmpty) || !ifNotExists)

override lazy val resolved: Boolean = childrenResolved && table.resolved
@@ -180,16 +180,7 @@ class PlanParserSuite extends PlanTest {
partition: Map[String, Option[String]],
overwrite: Boolean = false,
ifNotExists: Boolean = false): LogicalPlan =
InsertIntoTable(
table("s"), partition, plan,
OverwriteOptions(
overwrite,
if (overwrite && partition.nonEmpty) {
partition.map(kv => (kv._1, kv._2.get))
} else {
Map.empty
}),
ifNotExists)
InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists)

// Single inserts
assertEqual(s"insert overwrite table s $sql",
@@ -205,9 +196,9 @@
val plan2 = table("t").where('x > 5).select(star())
assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5",
InsertIntoTable(
table("s"), Map.empty, plan.limit(1), OverwriteOptions(false), ifNotExists = false).union(
table("s"), Map.empty, plan.limit(1), false, ifNotExists = false).union(
InsertIntoTable(
table("u"), Map.empty, plan2, OverwriteOptions(false), ifNotExists = false)))
table("u"), Map.empty, plan2, false, ifNotExists = false)))
}

test ("insert with if not exists") {
@@ -25,7 +25,7 @@ import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType}
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, OverwriteOptions}
import org.apache.spark.sql.catalyst.plans.logical.InsertIntoTable
import org.apache.spark.sql.execution.command.{AlterTableRecoverPartitionsCommand, DDLUtils}
import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, HadoopFsRelation}
import org.apache.spark.sql.types.StructType
@@ -259,7 +259,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
table = UnresolvedRelation(tableIdent),
partition = Map.empty[String, Option[String]],
child = df.logicalPlan,
overwrite = OverwriteOptions(mode == SaveMode.Overwrite),
overwrite = mode == SaveMode.Overwrite,
ifNotExists = false)).toRdd
}

@@ -478,7 +478,7 @@ case class DataSource(
val plan =
InsertIntoHadoopFsRelationCommand(
outputPath = outputPath,
staticPartitionKeys = Map.empty,
staticPartitions = Map.empty,
customPartitionLocations = Map.empty,
partitionColumns = columns,
bucketSpec = bucketSpec,
@@ -24,16 +24,15 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTablePartition, SimpleCatalogRelation}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union}
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
@@ -100,7 +99,7 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
None
} else if (potentialSpecs.size == 1) {
val partValue = potentialSpecs.head._2
Some(Alias(Cast(Literal(partValue), field.dataType), "_staticPart")())
Some(Alias(Cast(Literal(partValue), field.dataType), field.name)())
} else {
throw new AnalysisException(
s"Partition column ${field.name} have multiple values specified, " +
@@ -128,61 +127,75 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
projectList
}

/**
* Returns true if the [[InsertIntoTable]] plan has already been preprocessed by analyzer rule
* [[PreprocessTableInsertion]]. It is important that this rule([[DataSourceAnalysis]]) has to
* be run after [[PreprocessTableInsertion]], to normalize the column names in partition spec and
* fix the schema mismatch by adding Cast.
*/
private def hasBeenPreprocessed(
[Contributor] Add a comment about who preprocesses this?
tableOutput: Seq[Attribute],
partSchema: StructType,
partSpec: Map[String, Option[String]],
query: LogicalPlan): Boolean = {
val partColNames = partSchema.map(_.name).toSet
query.resolved && partSpec.keys.forall(partColNames.contains) && {
[Contributor] Is it necessary to check that the keys are all valid columns?
[Contributor] Ah, is the issue to avoid this running before PreprocessTableInsertion?
[Contributor Author] yup
val staticPartCols = partSpec.filter(_._2.isDefined).keySet
val expectedColumns = tableOutput.filterNot(a => staticPartCols.contains(a.name))
expectedColumns.toStructType.sameType(query.schema)
[Contributor] Similar question: when is this false?
[Contributor Author] This is to follow the previous condition: https://github.com/apache/spark/pull/15995/files#diff-d99813bd5bbc18277e4090475e4944cfL166. It can be false if users issue an invalid command, e.g. INSERT INTO src SELECT 1,2 while table src has 3 columns.
}
}
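To make the exchange above concrete, here is a minimal, self-contained sketch of the schema check that hasBeenPreprocessed performs once PreprocessTableInsertion has normalized the partition spec. This is plain Scala, not the Spark API; Field, compatible, and the sample table are invented for illustration only.

// Simplified stand-in for Spark's Attribute/StructField: just a name and a type tag.
case class Field(name: String, dataType: String)

// Mirrors the body of hasBeenPreprocessed: drop the static partition columns from the table
// output and require the remaining columns to line up, by position and type, with the query.
def compatible(
    tableOutput: Seq[Field],
    partSpec: Map[String, Option[String]],
    queryOutput: Seq[Field]): Boolean = {
  val staticPartCols = partSpec.filter(_._2.isDefined).keySet
  val expected = tableOutput.filterNot(f => staticPartCols.contains(f.name))
  expected.length == queryOutput.length &&
    expected.zip(queryOutput).forall { case (e, q) => e.dataType == q.dataType }
}

// Table src(a INT, b INT, c INT) partitioned by (b, c); INSERT ... PARTITION (b=2, c)
val src  = Seq(Field("a", "int"), Field("b", "int"), Field("c", "int"))
val spec = Map("b" -> Some("2"), "c" -> None)
compatible(src, spec, Seq(Field("x", "int"), Field("y", "int")))  // true: SELECT 1, 3 fits
compatible(src, spec, Seq(Field("x", "int")))                     // false: too few columns, rule won't match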

override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// If the InsertIntoTable command is for a partitioned HadoopFsRelation and
// the user has specified static partitions, we add a Project operator on top of the query
// to include those constant column values in the query result.
//
// Example:
// Let's say that we have a table "t", which is created by
// CREATE TABLE t (a INT, b INT, c INT) USING parquet PARTITIONED BY (b, c)
// The statement of "INSERT INTO TABLE t PARTITION (b=2, c) SELECT 1, 3"
// will be converted to "INSERT INTO TABLE t PARTITION (b, c) SELECT 1, 2, 3".
//
// Basically, we will put those partition columns having a assigned value back
// to the SELECT clause. The output of the SELECT clause is organized as
// normal_columns static_partitioning_columns dynamic_partitioning_columns.
// static_partitioning_columns are partitioning columns having assigned
// values in the PARTITION clause (e.g. b in the above example).
// dynamic_partitioning_columns are partitioning columns that do not assigned
// values in the PARTITION clause (e.g. c in the above example).
case insert @ logical.InsertIntoTable(
relation @ LogicalRelation(t: HadoopFsRelation, _, _), parts, query, overwrite, false)
if query.resolved && parts.exists(_._2.isDefined) =>

val projectList = convertStaticPartitions(
sourceAttributes = query.output,
providedPartitions = parts,
targetAttributes = relation.output,
targetPartitionSchema = t.partitionSchema)

// We will remove all assigned values to static partitions because they have been
// moved to the projectList.
insert.copy(partition = parts.map(p => (p._1, None)), child = Project(projectList, query))


case logical.InsertIntoTable(
l @ LogicalRelation(t: HadoopFsRelation, _, table), _, query, overwrite, false)
if query.resolved && t.schema.sameType(query.schema) =>

// Sanity checks
case InsertIntoTable(
l @ LogicalRelation(t: HadoopFsRelation, _, table), parts, query, overwrite, false)
if hasBeenPreprocessed(l.output, t.partitionSchema, parts, query) =>

// If the InsertIntoTable command is for a partitioned HadoopFsRelation and
// the user has specified static partitions, we add a Project operator on top of the query
// to include those constant column values in the query result.
//
// Example:
// Let's say that we have a table "t", which is created by
// CREATE TABLE t (a INT, b INT, c INT) USING parquet PARTITIONED BY (b, c)
// The statement of "INSERT INTO TABLE t PARTITION (b=2, c) SELECT 1, 3"
// will be converted to "INSERT INTO TABLE t PARTITION (b, c) SELECT 1, 2, 3".
//
// Basically, we will put those partition columns having a assigned value back
// to the SELECT clause. The output of the SELECT clause is organized as
// normal_columns static_partitioning_columns dynamic_partitioning_columns.
// static_partitioning_columns are partitioning columns having assigned
// values in the PARTITION clause (e.g. b in the above example).
// dynamic_partitioning_columns are partitioning columns that do not assigned
// values in the PARTITION clause (e.g. c in the above example).
val actualQuery = if (parts.exists(_._2.isDefined)) {
val projectList = convertStaticPartitions(
sourceAttributes = query.output,
providedPartitions = parts,
targetAttributes = l.output,
targetPartitionSchema = t.partitionSchema)
Project(projectList, query)
} else {
query
}
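For reference, a small standalone sketch of the reordering described in the comment above. The helper name and the string-based column representation are illustrative only; the real rule builds a projectList of Alias/Cast expressions via convertStaticPartitions.

// Rewrites the SELECT output to: normal columns, then static partition values, then dynamic
// partition columns, in the table's column order.
def withStaticPartitionValues(
    selectExprs: Seq[String],        // expressions from the user's SELECT, in order
    partitionCols: Seq[String],      // partition columns in table schema order, e.g. Seq("b", "c")
    staticValues: Map[String, String]): Seq[String] = {
  val numDynamic = partitionCols.count(c => !staticValues.contains(c))
  val (normal, dynamic) = selectExprs.splitAt(selectExprs.length - numDynamic)
  normal ++ partitionCols.flatMap(staticValues.get) ++ dynamic
}

// CREATE TABLE t (a INT, b INT, c INT) USING parquet PARTITIONED BY (b, c)
// INSERT INTO TABLE t PARTITION (b=2, c) SELECT 1, 3
withStaticPartitionValues(Seq("1", "3"), Seq("b", "c"), Map("b" -> "2"))
// => Seq("1", "2", "3"), i.e. the rewritten INSERT INTO TABLE t PARTITION (b, c) SELECT 1, 2, 3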

// Sanity check
if (t.location.rootPaths.size != 1) {
throw new AnalysisException(
"Can only write data to relations with a single path.")
throw new AnalysisException("Can only write data to relations with a single path.")
}

val outputPath = t.location.rootPaths.head
val inputPaths = query.collect {
val inputPaths = actualQuery.collect {
case LogicalRelation(r: HadoopFsRelation, _, _) => r.location.rootPaths
}.flatten

val mode = if (overwrite.enabled) SaveMode.Overwrite else SaveMode.Append
if (overwrite.enabled && inputPaths.contains(outputPath)) {
val mode = if (overwrite) SaveMode.Overwrite else SaveMode.Append
if (overwrite && inputPaths.contains(outputPath)) {
throw new AnalysisException(
"Cannot overwrite a path that is also being read from.")
}

val partitionSchema = query.resolve(
val partitionSchema = actualQuery.resolve(
t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver)
val partitionsTrackedByCatalog =
t.sparkSession.sessionState.conf.manageFilesourcePartitions &&
@@ -192,19 +205,13 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
var initialMatchingPartitions: Seq[TablePartitionSpec] = Nil
var customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty

val staticPartitionKeys: TablePartitionSpec = if (overwrite.enabled) {
overwrite.staticPartitionKeys.map { case (k, v) =>
(partitionSchema.map(_.name).find(_.equalsIgnoreCase(k)).get, v)
}
} else {
Map.empty
}
val staticPartitions = parts.filter(_._2.nonEmpty).map { case (k, v) => k -> v.get }
[Contributor Author, @cloud-fan, Dec 11, 2016] The column names in the partition spec are already normalized by the PreprocessTableInsertion rule, so we don't need to consider case sensitivity here. And the if-else is not needed, because:
1. staticPartitions is used to get matchingPartitions in this line, and matchingPartitions is used to decide which partitions need to be added to the metastore. Previously, if overwrite was false, we would get all partitions as matchingPartitions and issue a lot of unnecessary ADD PARTITION calls. After removing the if-else, that is fixed.
2. After we pass staticPartitions to InsertIntoHadoopFsRelationCommand, it is only used with Overwrite mode, so the if-else is unnecessary.
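A quick illustration of the expression this comment refers to; the sample values below are made up, and only the filter/map shape mirrors the diff line above.

// Partition spec parsed from: INSERT ... PARTITION (b=2, c), where b is static and c is dynamic.
val parts: Map[String, Option[String]] = Map("b" -> Some("2"), "c" -> None)

// Same expression as in the rule: keep only the entries with a defined value.
val staticPartitions: Map[String, String] =
  parts.filter(_._2.nonEmpty).map { case (k, v) => k -> v.get }
// staticPartitions == Map("b" -> "2"); it feeds the metastore lookup (listPartitions) and is
// passed to InsertIntoHadoopFsRelationCommand, where it only matters in Overwrite mode.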


// When partitions are tracked by the catalog, compute all custom partition locations that
// may be relevant to the insertion job.
if (partitionsTrackedByCatalog) {
val matchingPartitions = t.sparkSession.sessionState.catalog.listPartitions(
l.catalogTable.get.identifier, Some(staticPartitionKeys))
l.catalogTable.get.identifier, Some(staticPartitions))
initialMatchingPartitions = matchingPartitions.map(_.spec)
customPartitionLocations = getCustomPartitionLocations(
t.sparkSession, l.catalogTable.get, outputPath, matchingPartitions)
@@ -220,7 +227,7 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {
l.catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)),
ifNotExists = true).run(t.sparkSession)
}
if (overwrite.enabled) {
if (overwrite) {
val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions
if (deletedPartitions.nonEmpty) {
AlterTableDropPartitionCommand(
@@ -235,14 +242,14 @@ case class DataSourceAnalysis(conf: CatalystConf) extends Rule[LogicalPlan] {

val insertCmd = InsertIntoHadoopFsRelationCommand(
outputPath,
staticPartitionKeys,
staticPartitions,
customPartitionLocations,
partitionSchema,
t.bucketSpec,
t.fileFormat,
refreshPartitionsCallback,
t.options,
query,
actualQuery,
mode,
table)

@@ -305,7 +312,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
}

override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case i @ logical.InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _)
case i @ InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _)
if DDLUtils.isDatasourceTable(s.metadata) =>
i.copy(table = readDataSourceTable(sparkSession, s))

@@ -351,7 +358,7 @@ object DataSourceStrategy extends Strategy with Logging {
Map.empty,
None) :: Nil

case i @ logical.InsertIntoTable(l @ LogicalRelation(t: InsertableRelation, _, _),
case InsertIntoTable(l @ LogicalRelation(t: InsertableRelation, _, _),
part, query, overwrite, false) if part.isEmpty =>
ExecutedCommandExec(InsertIntoDataSourceCommand(l, query, overwrite)) :: Nil

@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OverwriteOptions}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.sources.InsertableRelation

@@ -30,7 +30,7 @@ import org.apache.spark.sql.sources.InsertableRelation
case class InsertIntoDataSourceCommand(
logicalRelation: LogicalRelation,
query: LogicalPlan,
overwrite: OverwriteOptions)
overwrite: Boolean)
extends RunnableCommand {

override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query)
@@ -40,7 +40,7 @@ case class InsertIntoDataSourceCommand(
val data = Dataset.ofRows(sparkSession, query)
// Apply the schema of the existing table to the new data.
val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema)
relation.insert(df, overwrite.enabled)
relation.insert(df, overwrite)

// Invalidate the cache.
sparkSession.sharedState.cacheManager.invalidateCache(logicalRelation)