diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala
index 6afb51ba81e1..630a85e46229 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala
@@ -107,10 +107,10 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPl
     insertTableSchemaWithoutPartitionColumns.map { schema: StructType =>
       val regenerated: InsertIntoStatement = regenerateUserSpecifiedCols(i, schema)
-      val expanded: LogicalPlan =
+      val (expanded: LogicalPlan, addedDefaults: Boolean) =
         addMissingDefaultValuesForInsertFromInlineTable(node, schema, i.userSpecifiedCols.size)
       val replaced: Option[LogicalPlan] =
-        replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded)
+        replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded, addedDefaults)
       replaced.map { r: LogicalPlan =>
         node = r
         for (child <- children.reverse) {
@@ -131,10 +131,10 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPl
     insertTableSchemaWithoutPartitionColumns.map { schema =>
       val regenerated: InsertIntoStatement = regenerateUserSpecifiedCols(i, schema)
       val project: Project = i.query.asInstanceOf[Project]
-      val expanded: Project =
+      val (expanded: Project, addedDefaults: Boolean) =
         addMissingDefaultValuesForInsertFromProject(project, schema, i.userSpecifiedCols.size)
       val replaced: Option[LogicalPlan] =
-        replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded)
+        replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded, addedDefaults)
       replaced.map { r =>
         regenerated.copy(query = r)
       }.getOrElse(i)
@@ -270,67 +270,83 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPl
   /**
    * Updates an inline table to generate missing default column values.
+   * Returns the resulting plan plus a boolean indicating whether such values were added.
    */
-  private def addMissingDefaultValuesForInsertFromInlineTable(
+  def addMissingDefaultValuesForInsertFromInlineTable(
       node: LogicalPlan,
       insertTableSchemaWithoutPartitionColumns: StructType,
-      numUserSpecifiedColumns: Int): LogicalPlan = {
+      numUserSpecifiedColumns: Int): (LogicalPlan, Boolean) = {
     val schema = insertTableSchemaWithoutPartitionColumns
-    val newDefaultExpressions: Seq[Expression] =
-      getDefaultExpressionsForInsert(schema, numUserSpecifiedColumns)
-    val newNames: Seq[String] = if (numUserSpecifiedColumns > 0) {
-      schema.fields.drop(numUserSpecifiedColumns).map(_.name)
-    } else {
-      schema.fields.map(_.name)
-    }
-    node match {
-      case _ if newDefaultExpressions.isEmpty => node
+    val newDefaultExpressions: Seq[UnresolvedAttribute] =
+      getNewDefaultExpressionsForInsert(schema, numUserSpecifiedColumns, node.output.size)
+    val newNames: Seq[String] = schema.fields.map(_.name)
+    val resultPlan: LogicalPlan = node match {
+      case _ if newDefaultExpressions.isEmpty =>
+        node
       case table: UnresolvedInlineTable =>
         table.copy(
-          names = table.names ++ newNames,
+          names = newNames,
           rows = table.rows.map { row => row ++ newDefaultExpressions })
       case local: LocalRelation =>
-        // Note that we have consumed a LocalRelation but return an UnresolvedInlineTable, because
-        // addMissingDefaultValuesForInsertFromProject must replace unresolved DEFAULT references.
-        UnresolvedInlineTable(
-          local.output.map(_.name) ++ newNames,
-          local.data.map { row =>
-            val colTypes = StructType(local.output.map(col => StructField(col.name, col.dataType)))
-            row.toSeq(colTypes).map(Literal(_)) ++ newDefaultExpressions
+        val newDefaultExpressionsRow = new GenericInternalRow(
+          // Note that this code path only runs when there is a user-specified column list of fewer
+          // columns than the target table; otherwise, the above 'newDefaultExpressions' is empty and
+          // we match the first case in this list instead.
+          schema.fields.drop(local.output.size).map {
+            case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) =>
+              analyze(f, "INSERT") match {
+                case lit: Literal => lit.value
+                case _ => null
+              }
+            case _ => null
           })
-      case _ => node
+        LocalRelation(
+          output = schema.toAttributes,
+          data = local.data.map { row =>
+            new JoinedRow(row, newDefaultExpressionsRow)
+          })
+      case _ =>
+        node
     }
+    (resultPlan, newDefaultExpressions.nonEmpty)
   }
 
   /**
    * Adds a new expressions to a projection to generate missing default column values.
+   * Returns the logical plan plus a boolean indicating if such defaults were added.
    */
   private def addMissingDefaultValuesForInsertFromProject(
       project: Project,
       insertTableSchemaWithoutPartitionColumns: StructType,
-      numUserSpecifiedColumns: Int): Project = {
+      numUserSpecifiedColumns: Int): (Project, Boolean) = {
     val schema = insertTableSchemaWithoutPartitionColumns
     val newDefaultExpressions: Seq[Expression] =
-      getDefaultExpressionsForInsert(schema, numUserSpecifiedColumns)
+      getNewDefaultExpressionsForInsert(schema, numUserSpecifiedColumns, project.projectList.size)
     val newAliases: Seq[NamedExpression] =
       newDefaultExpressions.zip(schema.fields).map {
        case (expr, field) => Alias(expr, field.name)()
      }
-    project.copy(projectList = project.projectList ++ newAliases)
+    (project.copy(projectList = project.projectList ++ newAliases),
+      newDefaultExpressions.nonEmpty)
   }
 
   /**
    * This is a helper for the addMissingDefaultValuesForInsertFromInlineTable methods above.
    */
-  private def getDefaultExpressionsForInsert(
-      schema: StructType,
-      numUserSpecifiedColumns: Int): Seq[Expression] = {
+  private def getNewDefaultExpressionsForInsert(
+      insertTableSchemaWithoutPartitionColumns: StructType,
+      numUserSpecifiedColumns: Int,
+      numProvidedValues: Int): Seq[UnresolvedAttribute] = {
     val remainingFields: Seq[StructField] = if (numUserSpecifiedColumns > 0) {
-      schema.fields.drop(numUserSpecifiedColumns)
+      insertTableSchemaWithoutPartitionColumns.fields.drop(numUserSpecifiedColumns)
     } else {
       Seq.empty
     }
     val numDefaultExpressionsToAdd = getStructFieldsForDefaultExpressions(remainingFields).size
+      // Limit the number of new DEFAULT expressions to the difference of the number of columns in
+      // the target table and the number of provided values in the source relation. This clamps the
+      // total final number of provided values to the number of columns in the target table.
+      .min(insertTableSchemaWithoutPartitionColumns.size - numProvidedValues)
     Seq.fill(numDefaultExpressionsToAdd)(UnresolvedAttribute(CURRENT_DEFAULT_COLUMN_NAME))
   }
@@ -351,7 +367,8 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPl
    */
   private def replaceExplicitDefaultValuesForInputOfInsertInto(
       insertTableSchemaWithoutPartitionColumns: StructType,
-      input: LogicalPlan): Option[LogicalPlan] = {
+      input: LogicalPlan,
+      addedDefaults: Boolean): Option[LogicalPlan] = {
     val schema = insertTableSchemaWithoutPartitionColumns
     val defaultExpressions: Seq[Expression] = schema.fields.map {
       case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) => analyze(f, "INSERT")
@@ -371,7 +388,11 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPl
       case project: Project =>
         replaceExplicitDefaultValuesForProject(defaultExpressions, project)
       case local: LocalRelation =>
-        Some(local)
+        if (addedDefaults) {
+          Some(local)
+        } else {
+          None
+        }
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumnsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumnsSuite.scala
new file mode 100644
index 000000000000..fc540e65593d
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumnsSuite.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{StructField, StructType, TimestampType}
+
+class ResolveDefaultColumnsSuite extends QueryTest with SharedSparkSession {
+  val rule = ResolveDefaultColumns(catalog = null)
+  // This is the internal storage for the timestamp 2020-12-31 00:00:00.0.
+  val literal = Literal(1609401600000000L, TimestampType)
+  val table = UnresolvedInlineTable(
+    names = Seq("attr1"),
+    rows = Seq(Seq(literal)))
+  val localRelation = ResolveInlineTables(table).asInstanceOf[LocalRelation]
+
+  def asLocalRelation(result: LogicalPlan): LocalRelation = result match {
+    case r: LocalRelation => r
+    case _ => fail(s"invalid result operator type: $result")
+  }
+
+  test("SPARK-43018: Add DEFAULTs for INSERT from VALUES list with user-defined columns") {
+    // Call the 'addMissingDefaultValuesForInsertFromInlineTable' method with one user-specified
+    // column. We add a default value of NULL to the row as a result.
+    val insertTableSchemaWithoutPartitionColumns = StructType(Seq(
+      StructField("c1", TimestampType),
+      StructField("c2", TimestampType)))
+    val (result: LogicalPlan, _: Boolean) =
+      rule.addMissingDefaultValuesForInsertFromInlineTable(
+        localRelation, insertTableSchemaWithoutPartitionColumns, numUserSpecifiedColumns = 1)
+    val relation = asLocalRelation(result)
+    assert(relation.output.map(_.name) == Seq("c1", "c2"))
+    val data: Seq[Seq[Any]] = relation.data.map { row =>
+      row.toSeq(StructType(relation.output.map(col => StructField(col.name, col.dataType))))
+    }
+    assert(data == Seq(Seq(literal.value, null)))
+  }
+
+  test("SPARK-43018: Add no DEFAULTs for INSERT from VALUES list with no user-defined columns") {
+    // Call the 'addMissingDefaultValuesForInsertFromInlineTable' method with zero user-specified
+    // columns. The table is unchanged because there are no default columns to add in this case.
+    val insertTableSchemaWithoutPartitionColumns = StructType(Seq(
+      StructField("c1", TimestampType),
+      StructField("c2", TimestampType)))
+    val (result: LogicalPlan, _: Boolean) =
+      rule.addMissingDefaultValuesForInsertFromInlineTable(
+        localRelation, insertTableSchemaWithoutPartitionColumns, numUserSpecifiedColumns = 0)
+    assert(asLocalRelation(result) == localRelation)
+  }
+
+  test("SPARK-43018: INSERT timestamp values into a table with column DEFAULTs") {
+    withTable("t") {
+      sql("create table t(id int, ts timestamp) using parquet")
+      sql("insert into t (ts) values (timestamp'2020-12-31')")
+      checkAnswer(spark.table("t"),
+        sql("select null, timestamp'2020-12-31'").collect().head)
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index bb3843e3fee8..053afb84c103 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -1100,9 +1100,15 @@ class DataSourceV2SQLSuiteV1Filter
         exception = intercept[AnalysisException] {
           sql(s"INSERT INTO $t1(data, data) VALUES(5)")
         },
-        errorClass = "COLUMN_ALREADY_EXISTS",
-        parameters = Map("columnName" -> "`data`")
-      )
+        errorClass = "_LEGACY_ERROR_TEMP_2305",
+        parameters = Map(
+          "numCols" -> "3",
+          "rowSize" -> "2",
+          "ri" -> "0"),
+        context = ExpectedContext(
+          fragment = s"INSERT INTO $t1(data, data)",
+          start = 0,
+          stop = 26))
     }
   }
 
@@ -1123,14 +1129,20 @@ class DataSourceV2SQLSuiteV1Filter
       assert(intercept[AnalysisException] {
        sql(s"INSERT OVERWRITE $t1 VALUES(4)")
      }.getMessage.contains("not enough data columns"))
-      // Duplicate columns
+      // Duplicate columns
      checkError(
        exception = intercept[AnalysisException] {
          sql(s"INSERT OVERWRITE $t1(data, data) VALUES(5)")
        },
-        errorClass = "COLUMN_ALREADY_EXISTS",
-        parameters = Map("columnName" -> "`data`")
-      )
+        errorClass = "_LEGACY_ERROR_TEMP_2305",
+        parameters = Map(
+          "numCols" -> "3",
+          "rowSize" -> "2",
+          "ri" -> "0"),
+        context = ExpectedContext(
+          fragment = s"INSERT OVERWRITE $t1(data, data)",
+          start = 0,
+          stop = 31))
     }
   }
 
@@ -1152,14 +1164,20 @@ class DataSourceV2SQLSuiteV1Filter
      assert(intercept[AnalysisException] {
        sql(s"INSERT OVERWRITE $t1 VALUES('a', 4)")
      }.getMessage.contains("not enough data columns"))
-      // Duplicate columns
+      // Duplicate columns
      checkError(
        exception = intercept[AnalysisException] {
          sql(s"INSERT OVERWRITE $t1(data, data) VALUES(5)")
        },
-        errorClass = "COLUMN_ALREADY_EXISTS",
-        parameters = Map("columnName" -> "`data`")
-      )
+        errorClass = "_LEGACY_ERROR_TEMP_2305",
+        parameters = Map(
+          "numCols" -> "4",
+          "rowSize" -> "3",
+          "ri" -> "0"),
+        context = ExpectedContext(
+          fragment = s"INSERT OVERWRITE $t1(data, data)",
+          start = 0,
+          stop = 31))
     }
   }