diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 96f2e38946f1c..50c0612035af5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -451,8 +451,24 @@ class Analyzer(
     }
 
     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-      case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
+      case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _, _) if child.resolved =>
         i.copy(table = EliminateSubqueryAliases(lookupTableFromCatalog(u)))
+      case i @ InsertIntoTable(table, parts, child, _, _, None) if child.resolved =>
+        val staticPartCols = parts.filter(_._2.isDefined).keySet
+
+        val resolvedPartCols = staticPartCols.map { p =>
+          table.resolve(Seq(p), resolver).getOrElse {
+            throw new AnalysisException(
+              s"Can't resolve static partition column $p on table $table")
+          }
+        }
+
+        val expectedColumns = if (table.output.isEmpty) {
+          None
+        } else {
+          Some(table.output.filterNot(a => resolvedPartCols.exists(a.semanticEquals(_))))
+        }
+        i.copy(expectedColumns = expectedColumns)
       case u: UnresolvedRelation =>
         val table = u.tableIdentifier
         if (table.database.isDefined && conf.runSQLonFile &&
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 7b451baaa02b9..edce9c4e22bfa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -313,7 +313,7 @@ trait CheckAnalysis extends PredicateHelper {
               |${s.catalogTable.identifier}
             """.stripMargin)
 
-        case InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _) =>
+        case InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _, _) =>
           failAnalysis(
             s"""
               |Hive support is required to insert into the following tables:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index ff3dcbc957ac1..380a067fa9554 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -359,32 +359,23 @@ case class InsertIntoTable(
     partition: Map[String, Option[String]],
     child: LogicalPlan,
     overwrite: Boolean,
-    ifNotExists: Boolean)
+    ifNotExists: Boolean,
+    expectedColumns: Option[Seq[Attribute]] = None)
   extends LogicalPlan {
 
   override def children: Seq[LogicalPlan] = child :: Nil
   override def output: Seq[Attribute] = Seq.empty
 
-  private[spark] lazy val expectedColumns = {
-    if (table.output.isEmpty) {
-      None
-    } else {
-      // Note: The parser (visitPartitionSpec in AstBuilder) already turns
-      // keys in partition to their lowercase forms.
-      val staticPartCols = partition.filter(_._2.isDefined).keySet
-      Some(table.output.filterNot(a => staticPartCols.contains(a.name)))
-    }
-  }
-
   assert(overwrite || !ifNotExists)
   assert(partition.values.forall(_.nonEmpty) || !ifNotExists)
 
   override lazy val resolved: Boolean =
-    childrenResolved && table.resolved && expectedColumns.forall { expected =>
-      child.output.size == expected.size && child.output.zip(expected).forall {
-        case (childAttr, tableAttr) =>
+    childrenResolved && table.resolved && (expectedColumns.isDefined || table.output.isEmpty) &&
+      expectedColumns.forall { expected =>
+        child.output.size == expected.size && child.output.zip(expected).forall {
+          case (childAttr, tableAttr) =>
             DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType)
+        }
       }
-    }
 }
 
 /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 102c78bd72111..d56f1b23e184b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.analysis
 
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
@@ -377,4 +378,31 @@ class AnalysisSuite extends AnalysisTest {
     assertExpressionType(sum(Divide(Decimal(1), 2.0)), DoubleType)
     assertExpressionType(sum(Divide(1.0, Decimal(2.0))), DoubleType)
   }
+
+  test("InsertIntoTable's expectedColumns should support case-insensitive resolution") {
+    val data = LocalRelation(
+      AttributeReference("a", StringType)(),
+      AttributeReference("b", StringType)(),
+      AttributeReference("c", DoubleType)(),
+      AttributeReference("d", DecimalType(10, 2))())
+
+    val insertIntoTable = InsertIntoTable(testRelation2,
+      Map("E" -> Some("1")), data, overwrite = false, ifNotExists = false)
+
+    val caseSensitiveAnalyzer = getAnalyzer(true)
+    val caseInsensitiveAnalyzer = getAnalyzer(false)
+    intercept[AnalysisException] {
+      caseSensitiveAnalyzer.execute(insertIntoTable)
+    }
+    val caseInsensitiveAnalysisAttempt = caseInsensitiveAnalyzer.execute(insertIntoTable)
+    caseInsensitiveAnalyzer.checkAnalysis(caseInsensitiveAnalysisAttempt)
+
+    val expectedColumns =
+      caseInsensitiveAnalysisAttempt.asInstanceOf[InsertIntoTable].expectedColumns
+
+    assert(expectedColumns.isDefined && expectedColumns.get.length == data.output.length)
+    expectedColumns.get.zip(data.output).foreach { case (expected, output) =>
+      assert(DataType.equalsIgnoreCompatibleNullability(expected.dataType, output.dataType))
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 2b4786542c72f..436e7d10bc4d6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -46,7 +46,7 @@ import org.apache.spark.unsafe.types.UTF8String
 private[sql] object DataSourceAnalysis extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case i @ logical.InsertIntoTable(
-          l @ LogicalRelation(t: HadoopFsRelation, _, _), part, query, overwrite, false)
+          l @ LogicalRelation(t: HadoopFsRelation, _, _), part, query, overwrite, false, _)
         if query.resolved && t.schema.asNullable == query.schema.asNullable =>
 
       // Sanity checks
@@ -110,7 +110,7 @@ private[sql] class FindDataSourceTable(sparkSession: SparkSession) extends Rule[
   }
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case i @ logical.InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _)
+    case i @ logical.InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _, _)
         if DDLUtils.isDatasourceTable(s.metadata) =>
       i.copy(table = readDataSourceTable(sparkSession, s.metadata))
 
@@ -152,7 +152,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
         l.output, toCatalystRDD(l, baseRelation.buildScan()), baseRelation) :: Nil
 
     case i @ logical.InsertIntoTable(l @ LogicalRelation(t: InsertableRelation, _, _),
-      part, query, overwrite, false) if part.isEmpty =>
+      part, query, overwrite, false, _) if part.isEmpty =>
       ExecutedCommandExec(InsertIntoDataSourceCommand(l, query, overwrite)) :: Nil
 
     case _ => Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index 10425af3e1f18..00953e26450a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -128,7 +128,8 @@ private[sql] case class PreprocessTableInsertion(conf: SQLConf) extends Rule[Log
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case i @ InsertIntoTable(table, partition, child, _, _) if table.resolved && child.resolved =>
+    case i @ InsertIntoTable(table, partition, child, _, _, _)
+        if table.resolved && child.resolved =>
       table match {
         case relation: CatalogRelation =>
           val metadata = relation.catalogTable
@@ -156,7 +157,7 @@ private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog)
     plan.foreach {
       case i @ logical.InsertIntoTable(
           l @ LogicalRelation(t: InsertableRelation, _, _),
-          partition, query, overwrite, ifNotExists) =>
+          partition, query, overwrite, ifNotExists, _) =>
         // Right now, we do not support insert into a data source table with partition specs.
         if (partition.nonEmpty) {
          failAnalysis(s"Insert into a partition is not allowed because $l is not partitioned.")
@@ -174,7 +175,7 @@ private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog)
         }
 
       case logical.InsertIntoTable(
-          LogicalRelation(r: HadoopFsRelation, _, _), part, query, overwrite, _) =>
+          LogicalRelation(r: HadoopFsRelation, _, _), part, query, overwrite, _, _) =>
         // We need to make sure the partition columns specified by users do match partition
         // columns of the relation.
         val existingPartitionColumns = r.partitionSchema.fieldNames.toSet
@@ -202,11 +203,11 @@ private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog)
           // OK
         }
 
-      case logical.InsertIntoTable(l: LogicalRelation, _, _, _, _) =>
+      case logical.InsertIntoTable(l: LogicalRelation, _, _, _, _, _) =>
         // The relation in l is not an InsertableRelation.
         failAnalysis(s"$l does not allow insertion.")
 
-      case logical.InsertIntoTable(t, _, _, _, _) =>
+      case logical.InsertIntoTable(t, _, _, _, _, _) =>
         if (!t.isInstanceOf[LeafNode] || t == OneRowRelation || t.isInstanceOf[LocalRelation]) {
           failAnalysis(s"Inserting into an RDD-based table is not allowed.")
         } else {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 2e0b5d59b5783..ab6aed832c44c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -367,10 +367,12 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
 
     plan transformUp {
       // Write path
-      case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists)
+      case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists,
+          expectedCols)
         // Inserting into partitioned table is not supported in Parquet data source (yet).
         if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreParquet(r) =>
-        InsertIntoTable(convertToParquetRelation(r), partition, child, overwrite, ifNotExists)
+        InsertIntoTable(convertToParquetRelation(r), partition, child, overwrite, ifNotExists,
+          expectedCols)
 
       // Write path
       case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists)
@@ -411,10 +413,12 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
 
     plan transformUp {
       // Write path
-      case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists)
+      case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists,
+          expectedCols)
         // Inserting into partitioned table is not supported in Orc data source (yet).
         if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreOrc(r) =>
-        InsertIntoTable(convertToOrcRelation(r), partition, child, overwrite, ifNotExists)
+        InsertIntoTable(convertToOrcRelation(r), partition, child, overwrite, ifNotExists,
+          expectedCols)
 
       // Write path
       case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 71b180e55b58c..30a37e81be85c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -43,7 +43,7 @@ private[hive] trait HiveStrategies {
   object DataSinks extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.InsertIntoTable(
-          table: MetastoreRelation, partition, child, overwrite, ifNotExists) =>
+          table: MetastoreRelation, partition, child, overwrite, ifNotExists, _) =>
         execution.InsertIntoHiveTable(
           table, partition, planLater(child), overwrite, ifNotExists) :: Nil
       case hive.InsertIntoHiveTable(