@@ -408,16 +408,6 @@ case class DataSource(
     val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
     PartitioningUtils.validatePartitionColumn(data.schema, partitionColumns, caseSensitive)

-    // SPARK-17230: Resolve the partition columns so InsertIntoHadoopFsRelationCommand does
-    // not need to have the query as child, to avoid to analyze an optimized query,
-    // because InsertIntoHadoopFsRelationCommand will be optimized first.
-    val partitionAttributes = partitionColumns.map { name =>
-      val plan = data.logicalPlan
-      plan.resolve(name :: Nil, data.sparkSession.sessionState.analyzer.resolver).getOrElse {
-        throw new AnalysisException(
-          s"Unable to resolve $name given [${plan.output.map(_.name).mkString(", ")}]")
-      }.asInstanceOf[Attribute]
-    }
     val fileIndex = catalogTable.map(_.identifier).map { tableIdent =>
       sparkSession.table(tableIdent).queryExecution.analyzed.collect {
         case LogicalRelation(t: HadoopFsRelation, _, _) => t.location
@@ -431,7 +421,7 @@ case class DataSource(
         outputPath = outputPath,
         staticPartitions = Map.empty,
         ifPartitionNotExists = false,
-        partitionColumns = partitionAttributes,
+        partitionColumns = partitionColumns,
         bucketSpec = bucketSpec,
         fileFormat = format,
         options = options,

@@ -188,15 +188,13 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport {
           "Cannot overwrite a path that is also being read from.")
       }

-      val partitionSchema = actualQuery.resolve(
-        t.partitionSchema, t.sparkSession.sessionState.analyzer.resolver)
       val staticPartitions = parts.filter(_._2.nonEmpty).map { case (k, v) => k -> v.get }

       InsertIntoHadoopFsRelationCommand(
         outputPath,
         staticPartitions,
         i.ifPartitionNotExists,
-        partitionSchema,
+        partitionColumns = t.partitionSchema.map(_.name),
         t.bucketSpec,
         t.fileFormat,
         t.options,

@@ -101,7 +101,7 @@ object FileFormatWriter extends Logging {
       committer: FileCommitProtocol,
       outputSpec: OutputSpec,
       hadoopConf: Configuration,
-      partitionColumns: Seq[Attribute],
+      partitionColumnNames: Seq[String],
       bucketSpec: Option[BucketSpec],
       refreshFunction: (Seq[TablePartitionSpec]) => Unit,
       options: Map[String, String]): Unit = {
@@ -111,9 +111,18 @@
     job.setOutputValueClass(classOf[InternalRow])
     FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath))

-    val allColumns = queryExecution.logical.output
+    val allColumns = queryExecution.executedPlan.output

[Review comment, contributor] This is problematic. The physical plan may have a different schema from the logical plan (the attribute names may differ), and the writer should respect the logical schema, since that is what users expect.

[Review comment, author] Yes. We should always use analyzed.output.

+    // Get the actual partition columns as attributes after matching them by name with
+    // the given columns names.
+    val partitionColumns = partitionColumnNames.map { col =>
+      val nameEquality = sparkSession.sessionState.conf.resolver
+      allColumns.find(f => nameEquality(f.name, col)).getOrElse {
+        throw new RuntimeException(
+          s"Partition column $col not found in schema ${queryExecution.executedPlan.schema}")
+      }
+    }
     val partitionSet = AttributeSet(partitionColumns)
-    val dataColumns = queryExecution.logical.output.filterNot(partitionSet.contains)
+    val dataColumns = allColumns.filterNot(partitionSet.contains)

     val bucketIdExpression = bucketSpec.map { spec =>
       val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
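
A side note on the name matching in the hunk above: sparkSession.sessionState.conf.resolver is simply a name-equality function selected by spark.sql.caseSensitive. A minimal sketch of that behaviour, using the two resolver functions Catalyst ships with (illustrative only, not part of this patch):

import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution}

// conf.resolver returns one of these two functions, depending on spark.sql.caseSensitive.
assert(caseInsensitiveResolution("Year", "year"))   // default: names match ignoring case
assert(!caseSensitiveResolution("Year", "year"))    // with spark.sql.caseSensitive=true they do not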

@@ -44,7 +44,7 @@ case class InsertIntoHadoopFsRelationCommand(
     outputPath: Path,
     staticPartitions: TablePartitionSpec,
     ifPartitionNotExists: Boolean,
-    partitionColumns: Seq[Attribute],
+    partitionColumns: Seq[String],
     bucketSpec: Option[BucketSpec],
     fileFormat: FileFormat,
     options: Map[String, String],
@@ -150,7 +150,7 @@ case class InsertIntoHadoopFsRelationCommand(
         outputSpec = FileFormatWriter.OutputSpec(
           qualifiedOutputPath.toString, customPartitionLocations),
         hadoopConf = hadoopConf,
-        partitionColumns = partitionColumns,
+        partitionColumnNames = partitionColumns,
         bucketSpec = bucketSpec,
         refreshFunction = refreshPartitionsCallback,
         options = options)
@@ -176,10 +176,10 @@ case class InsertIntoHadoopFsRelationCommand(
       customPartitionLocations: Map[TablePartitionSpec, String],
       committer: FileCommitProtocol): Unit = {
     val staticPartitionPrefix = if (staticPartitions.nonEmpty) {
-      "/" + partitionColumns.flatMap { p =>
-        staticPartitions.get(p.name) match {
+      "/" + partitionColumns.flatMap { col =>
+        staticPartitions.get(col) match {
           case Some(value) =>
-            Some(escapePathName(p.name) + "=" + escapePathName(value))
+            Some(escapePathName(col) + "=" + escapePathName(value))
           case None =>
             None
         }
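
For context on the string-based partition columns used above, a small sketch of the static-partition prefix they produce, with made-up values. The escapePathName import shown is an assumption about where that helper lives in Catalyst, not part of the patch:

import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName

// Hypothetical values, for illustration only.
val partitionColumns = Seq("year", "month")        // declared partition columns, in order
val staticPartitions = Map("year" -> "2017")       // static part of the INSERT's PARTITION clause
val staticPartitionPrefix = "/" + partitionColumns.flatMap { col =>
  staticPartitions.get(col).map(v => escapePathName(col) + "=" + escapePathName(v))
}.mkString("/")
// staticPartitionPrefix == "/year=2017"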

@@ -127,11 +127,11 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[LogicalPlan] {
       val resolver = sparkSession.sessionState.conf.resolver
       val tableCols = existingTable.schema.map(_.name)

-      // As we are inserting into an existing table, we should respect the existing schema and
-      // adjust the column order of the given dataframe according to it, or throw exception
-      // if the column names do not match.
+      // As we are inserting into an existing table, we should respect the existing schema, preserve
+      // the case and adjust the column order of the given DataFrame according to it, or throw
+      // an exception if the column names do not match.
       val adjustedColumns = tableCols.map { col =>
-        query.resolve(Seq(col), resolver).getOrElse {
+        query.resolve(Seq(col), resolver).map(Alias(_, col)()).getOrElse {

[Review comment, author] We need to add an alias to force the query to preserve the original name from the table schema, whose case could differ from the underlying query schema.

[Review comment, contributor] ah good catch!

          val inputColumns = query.schema.map(_.name).mkString(", ")
          throw new AnalysisException(
            s"cannot resolve '$col' given input columns: [$inputColumns]")
@@ -168,15 +168,9 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[LogicalPlan] {
          """.stripMargin)
      }

-      val newQuery = if (adjustedColumns != query.output) {
-        Project(adjustedColumns, query)
-      } else {
-        query
-      }
-
       c.copy(
         tableDesc = existingTable,
-        query = Some(newQuery))
+        query = Some(Project(adjustedColumns, query)))

     // Here we normalize partition, bucket and sort column names, w.r.t. the case sensitivity
     // config, and do various checks:
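
To make the Alias(_, col)() change discussed above concrete, a small sketch with hypothetical names: the table schema declares the column as ID, the incoming query resolves it as id, and the alias re-exposes the column under the table's casing. Illustrative only, not part of the patch:

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
import org.apache.spark.sql.types.IntegerType

val queryAttr = AttributeReference("id", IntegerType)()   // column name as the query resolved it
val tableCol = "ID"                                        // column name as stored in the table schema
val adjusted = Alias(queryAttr, tableCol)()                // output column now carries the table's name
assert(adjusted.name == "ID")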

@@ -111,23 +111,14 @@ class FileStreamSink(
       case _ => // Do nothing
     }

-    // Get the actual partition columns as attributes after matching them by name with
-    // the given columns names.
-    val partitionColumns: Seq[Attribute] = partitionColumnNames.map { col =>
-      val nameEquality = data.sparkSession.sessionState.conf.resolver
-      data.logicalPlan.output.find(f => nameEquality(f.name, col)).getOrElse {
-        throw new RuntimeException(s"Partition column $col not found in schema ${data.schema}")
-      }
-    }
-
     FileFormatWriter.write(
       sparkSession = sparkSession,
       queryExecution = data.queryExecution,
       fileFormat = fileFormat,
       committer = committer,
       outputSpec = FileFormatWriter.OutputSpec(path, Map.empty),
       hadoopConf = hadoopConf,
-      partitionColumns = partitionColumns,
+      partitionColumnNames = partitionColumnNames,
       bucketSpec = None,
       refreshFunction = _ => (),
       options = options)

@@ -314,21 +314,14 @@ case class InsertIntoHiveTable(
         outputPath = tmpLocation.toString,
         isAppend = false)

-    val partitionAttributes = partitionColumnNames.takeRight(numDynamicPartitions).map { name =>
-      query.resolve(name :: Nil, sparkSession.sessionState.analyzer.resolver).getOrElse {
-        throw new AnalysisException(
-          s"Unable to resolve $name given [${query.output.map(_.name).mkString(", ")}]")
-      }.asInstanceOf[Attribute]
-    }
-
     FileFormatWriter.write(
       sparkSession = sparkSession,
       queryExecution = Dataset.ofRows(sparkSession, query).queryExecution,
       fileFormat = new HiveFileFormat(fileSinkConf),
       committer = committer,
       outputSpec = FileFormatWriter.OutputSpec(tmpLocation.toString, Map.empty),
       hadoopConf = hadoopConf,
-      partitionColumns = partitionAttributes,
+      partitionColumnNames = partitionColumnNames.takeRight(numDynamicPartitions),
       bucketSpec = None,
       refreshFunction = _ => (),
       options = Map.empty)

@@ -468,6 +468,28 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter
     }
   }

+  test("SPARK-21165: the query schema of INSERT is changed after optimization") {
+    withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) {
+      withTable("tab1", "tab2") {
+        Seq(("a", "b", 3)).toDF("word", "first", "length").write.saveAsTable("tab1")
+
+        spark.sql(
+          """
+            |CREATE TABLE tab2 (word string, length int)
+            |PARTITIONED BY (first string)
+          """.stripMargin)
+
+        spark.sql(
+          """
+            |INSERT INTO TABLE tab2 PARTITION(first)
+            |SELECT word, length, cast(first as string) as first FROM tab1
+          """.stripMargin)
+
+        checkAnswer(spark.table("tab2"), Row("a", 3, "b"))
+      }
+    }
+  }
+
   testPartitionedTable("insertInto() should reject extra columns") {
     tableName =>
       sql("CREATE TABLE t (a INT, b INT, c INT, d INT, e INT)")
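
Background on why this test exercises SPARK-21165: cast(first as string) is a no-op on a string column, so the optimizer removes it and the optimized plan no longer looks like the analyzed plan that partition attributes were previously resolved against. A quick, illustrative way to observe that kind of divergence in spark-shell (assumes a running spark session; not part of the patch):

val df = spark.range(1)
  .selectExpr("cast(id as string) as first")
  .selectExpr("cast(first as string) as first")  // redundant cast on a string column

// The analyzed plan still carries the cast; the optimized plan has simplified it away,
// which is why code resolving attributes against one plan can mismatch the other.
println(df.queryExecution.analyzed.numberedTreeString)
println(df.queryExecution.optimizedPlan.numberedTreeString)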