From be16454e4440b0f9f1cf4e2886578096a6481f1a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 23 May 2023 00:41:51 +0800 Subject: [PATCH 1/6] refactor default column value resolution --- .../main/resources/error/error-classes.json | 4 +- .../write/SupportsCustomSchemaWrite.java | 38 - .../sql/catalyst/analysis/Analyzer.scala | 141 +--- .../catalyst/analysis/AssignmentUtils.scala | 8 +- .../analysis/ColumnResolutionHelper.scala | 8 + .../ResolveColumnDefaultInInsert.scala | 133 ++++ .../analysis/ResolveDefaultColumns.scala | 681 ------------------ .../analysis/ResolveInsertionBase.scala | 70 ++ .../analysis/ResolveReferencesInUpdate.scala | 70 ++ .../analysis/TableOutputResolver.scala | 43 +- .../util/ResolveDefaultColumnsUtil.scala | 48 ++ .../sql/execution/datasources/rules.scala | 16 +- .../apache/spark/sql/SQLInsertTestSuite.scala | 24 - .../analysis/ResolveDefaultColumnsSuite.scala | 224 ++---- .../command/AlignMergeAssignmentsSuite.scala | 37 +- .../apache/spark/sql/hive/InsertSuite.scala | 21 +- 16 files changed, 510 insertions(+), 1056 deletions(-) delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsCustomSchemaWrite.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInsertionBase.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index 24f972a5006b..7212750c7a18 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -3176,14 +3176,14 @@ "_LEGACY_ERROR_TEMP_1202" : { "message" : [ "Cannot write to '', too many data columns:", - "Table columns: .", + "Table columns (excluding columns with static partition values): .", "Data columns: ." ] }, "_LEGACY_ERROR_TEMP_1203" : { "message" : [ "Cannot write to '', not enough data columns:", - "Table columns: .", + "Table columns (excluding columns with static partition values): .", "Data columns: ." ] }, diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsCustomSchemaWrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsCustomSchemaWrite.java deleted file mode 100644 index 9435625a1c4a..000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsCustomSchemaWrite.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector.write; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.types.StructType; - -/** - * Trait for tables that support custom schemas for write operations including INSERT INTO commands - * whose target table columns have explicit or implicit default values. - * - * @since 3.4.1 - */ -@Evolving -public interface SupportsCustomSchemaWrite { - /** - * Represents a table with a custom schema to use for resolving DEFAULT column references when - * inserting into the table. For example, this can be useful for excluding hidden pseudocolumns. - * - * @return the new schema to use for this process. - */ - StructType customSchemaForInserts(); -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index dbc9da1ea22e..161e6398a5e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.trees.AlwaysProcess import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.catalyst.trees.TreePattern._ -import org.apache.spark.sql.catalyst.util.{toPrettySQL, AUTO_GENERATED_ALIAS, CharVarcharUtils, StringUtils} +import org.apache.spark.sql.catalyst.util.{toPrettySQL, AUTO_GENERATED_ALIAS, CharVarcharUtils} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ import org.apache.spark.sql.connector.catalog.{View => _, _} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -55,8 +55,7 @@ import org.apache.spark.sql.internal.SQLConf.{PartitionOverwriteMode, StoreAssig import org.apache.spark.sql.internal.connector.V1Function import org.apache.spark.sql.types._ import org.apache.spark.sql.types.DayTimeIntervalType.DAY -import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils} -import org.apache.spark.util.collection.{Utils => CUtils} +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * A trivial [[Analyzer]] with a dummy [[SessionCatalog]] and @@ -280,7 +279,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor KeepLegacyOutputs), Batch("Resolution", fixedPoint, new ResolveCatalogs(catalogManager) :: - ResolveUserSpecifiedColumns :: ResolveInsertInto :: ResolveRelations :: ResolvePartitionSpec :: @@ -313,7 +311,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor TimeWindowing :: SessionWindowing :: ResolveWindowTime :: - ResolveDefaultColumns(ResolveRelations.resolveRelationOrTempView) :: ResolveInlineTables :: ResolveLambdaVariables :: ResolveTimeZone :: @@ -1080,7 +1077,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor def apply(plan: LogicalPlan) : LogicalPlan = plan.resolveOperatorsUpWithPruning(AlwaysProcess.fn, ruleId) { - case i @ InsertIntoStatement(table, _, _, _, _, _) if i.query.resolved => + case i @ InsertIntoStatement(table, _, _, _, _, _) => val relation = table match { case u: UnresolvedRelation if !u.isStreaming => resolveRelation(u).getOrElse(u) @@ -1278,53 +1275,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } /** Handle 
INSERT INTO for DSv2 */ - object ResolveInsertInto extends Rule[LogicalPlan] { - - /** Add a project to use the table column names for INSERT INTO BY NAME */ - private def createProjectForByNameQuery(i: InsertIntoStatement): LogicalPlan = { - SchemaUtils.checkColumnNameDuplication(i.userSpecifiedCols, resolver) - - if (i.userSpecifiedCols.size != i.query.output.size) { - throw QueryCompilationErrors.writeTableWithMismatchedColumnsError( - i.userSpecifiedCols.size, i.query.output.size, i.query) - } - val projectByName = i.userSpecifiedCols.zip(i.query.output) - .map { case (userSpecifiedCol, queryOutputCol) => - val resolvedCol = i.table.resolve(Seq(userSpecifiedCol), resolver) - .getOrElse( - throw QueryCompilationErrors.unresolvedAttributeError( - "UNRESOLVED_COLUMN", userSpecifiedCol, i.table.output.map(_.name), i.origin)) - (queryOutputCol.dataType, resolvedCol.dataType) match { - case (input: StructType, expected: StructType) => - // Rename inner fields of the input column to pass the by-name INSERT analysis. - Alias(Cast(queryOutputCol, renameFieldsInStruct(input, expected)), resolvedCol.name)() - case _ => - Alias(queryOutputCol, resolvedCol.name)() - } - } - Project(projectByName, i.query) - } - - private def renameFieldsInStruct(input: StructType, expected: StructType): StructType = { - if (input.length == expected.length) { - val newFields = input.zip(expected).map { case (f1, f2) => - (f1.dataType, f2.dataType) match { - case (s1: StructType, s2: StructType) => - f1.copy(name = f2.name, dataType = renameFieldsInStruct(s1, s2)) - case _ => - f1.copy(name = f2.name) - } - } - StructType(newFields) - } else { - input - } - } - + object ResolveInsertInto extends ResolveInsertionBase { override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning( AlwaysProcess.fn, ruleId) { - case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _, _) - if i.query.resolved => + case i @ InsertIntoStatement(r: DataSourceV2Relation, _, _, _, _, _) if i.query.resolved => // ifPartitionNotExists is append with validation, but validation is not supported if (i.ifPartitionNotExists) { throw QueryCompilationErrors.unsupportedIfNotExistsError(r.table.name) @@ -1527,6 +1481,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { + // Don't wait other rules to resolve the child plans of `InsertIntoStatement` as we need + // to resolve column "DEFAULT" in the child plans so that they must be unresolved. + case i: InsertIntoStatement => ResolveColumnDefaultInInsert(i) + // Wait for other rules to resolve child plans first case p: LogicalPlan if !p.childrenResolved => p @@ -1646,6 +1604,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor // implementation and should be resolved based on the table schema. 
o.copy(deleteExpr = resolveExpressionByPlanOutput(o.deleteExpr, o.table)) + case u: UpdateTable => ResolveReferencesInUpdate(u) + case m @ MergeIntoTable(targetTable, sourceTable, _, _, _, _) if !m.resolved && targetTable.resolved && sourceTable.resolved => @@ -1796,7 +1756,17 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case MergeResolvePolicy.SOURCE => Project(Nil, mergeInto.sourceTable) case MergeResolvePolicy.TARGET => Project(Nil, mergeInto.targetTable) } - resolveMergeExprOrFail(c, resolvePlan) + resolvedKey match { + case attr: AttributeReference => + val resolvedExpr = resolveExprInAssignment(c, resolvePlan) match { + case u: UnresolvedAttribute if isExplicitDefaultColumn(u) => + getDefaultValueExpr(attr).getOrElse(Literal(null, attr.dataType)) + case other => other + } + checkResolvedMergeExpr(resolvedExpr, resolvePlan) + resolvedExpr + case _ => resolveMergeExprOrFail(c, resolvePlan) + } case o => o } Assignment(resolvedKey, resolvedValue) @@ -1804,15 +1774,13 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } private def resolveMergeExprOrFail(e: Expression, p: LogicalPlan): Expression = { - val resolved = resolveExpressionByPlanChildren(e, p) - resolved.references.filter { attribute: Attribute => - !attribute.resolved && - // We exclude attribute references named "DEFAULT" from consideration since they are - // handled exclusively by the ResolveDefaultColumns analysis rule. That rule checks the - // MERGE command for such references and either replaces each one with a corresponding - // value, or returns a custom error message. - normalizeFieldName(attribute.name) != normalizeFieldName(CURRENT_DEFAULT_COLUMN_NAME) - }.foreach { a => + val resolved = resolveExprInAssignment(e, p) + checkResolvedMergeExpr(resolved, p) + resolved + } + + private def checkResolvedMergeExpr(e: Expression, p: LogicalPlan): Unit = { + e.references.filter(!_.resolved).foreach { a => // Note: This will throw error only on unresolved attribute issues, // not other resolution errors like mismatched data types. val cols = p.inputSet.toSeq.map(_.sql).mkString(", ") @@ -1822,10 +1790,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor "sqlExpr" -> a.sql, "cols" -> cols)) } - resolved match { - case Alias(child: ExtractValue, _) => child - case other => other - } } // Expand the star expression using the input plan first. If failed, try resolve @@ -3346,53 +3310,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } } - /** - * A special rule to reorder columns for DSv1 when users specify a column list in INSERT INTO. - * DSv2 is handled by [[ResolveInsertInto]] separately. 
- */ - object ResolveUserSpecifiedColumns extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning( - AlwaysProcess.fn, ruleId) { - case i: InsertIntoStatement if !i.table.isInstanceOf[DataSourceV2Relation] && - i.table.resolved && i.query.resolved && i.userSpecifiedCols.nonEmpty => - val resolved = resolveUserSpecifiedColumns(i) - val projection = addColumnListOnQuery(i.table.output, resolved, i.query) - i.copy(userSpecifiedCols = Nil, query = projection) - } - - private def resolveUserSpecifiedColumns(i: InsertIntoStatement): Seq[NamedExpression] = { - SchemaUtils.checkColumnNameDuplication(i.userSpecifiedCols, resolver) - - i.userSpecifiedCols.map { col => - i.table.resolve(Seq(col), resolver).getOrElse { - val candidates = i.table.output.map(_.qualifiedName) - val orderedCandidates = StringUtils.orderSuggestedIdentifiersBySimilarity(col, candidates) - throw QueryCompilationErrors - .unresolvedAttributeError("UNRESOLVED_COLUMN", col, orderedCandidates, i.origin) - } - } - } - - private def addColumnListOnQuery( - tableOutput: Seq[Attribute], - cols: Seq[NamedExpression], - query: LogicalPlan): LogicalPlan = { - if (cols.size != query.output.size) { - throw QueryCompilationErrors.writeTableWithMismatchedColumnsError( - cols.size, query.output.size, query) - } - val nameToQueryExpr = CUtils.toMap(cols, query.output) - // Static partition columns in the table output should not appear in the column list - // they will be handled in another rule ResolveInsertInto - val reordered = tableOutput.flatMap { nameToQueryExpr.get(_).orElse(None) } - if (reordered == query.output) { - query - } else { - Project(reordered, query) - } - } - } - private def validateStoreAssignmentPolicy(): Unit = { // SPARK-28730: LEGACY store assignment policy is disallowed in data source v2. 
if (conf.storeAssignmentPolicy == StoreAssignmentPolicy.LEGACY) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala index 6d8118548fb4..f3f2a94c7478 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, Expression, GetStructField, Literal} import org.apache.spark.sql.catalyst.plans.logical.Assignment import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getDefaultValueExprOrNullLiteral import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.{DataType, StructType} @@ -103,8 +104,11 @@ object AssignmentUtils extends SQLConfHelper with CastSupport { case assignment if assignment.key.semanticEquals(attr) => assignment } val resolvedValue = if (matchingAssignments.isEmpty) { - errors += s"No assignment for '${attr.name}'" - attr + val defaultExpr = getDefaultValueExprOrNullLiteral(attr, conf) + if (defaultExpr.isEmpty) { + errors += s"No assignment for '${attr.name}'" + } + defaultExpr.getOrElse(attr) } else if (matchingAssignments.length > 1) { val conflictingValuesStr = matchingAssignments.map(_.value.sql).mkString(", ") errors += s"Multiple assignments for '${attr.name}': $conflictingValuesStr" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala index 318a23c36afa..98cbdea72d53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala @@ -384,6 +384,14 @@ trait ColumnResolutionHelper extends Logging { allowOuter = allowOuter) } + def resolveExprInAssignment(expr: Expression, hostPlan: LogicalPlan): Expression = { + resolveExpressionByPlanChildren(expr, hostPlan) match { + // Assignment key and value does not need the alias when resolving nested columns. + case Alias(child: ExtractValue, _) => child + case other => other + } + } + private def resolveExpressionByPlanId( e: Expression, q: LogicalPlan): Expression = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala new file mode 100644 index 000000000000..e9c37094f5fa --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_ATTRIBUTE +import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.{getDefaultValueExpr, isExplicitDefaultColumn} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.StructField + +/** + * A virtual rule to resolve column "DEFAULT" in [[Project]] and [[UnresolvedInlineTable]] under + * [[InsertIntoStatement]]. It's only used by the real rule `ResolveReferences`. + * + * This virtual rule is triggered if: + * 1. The column "DEFAULT" can't be resolved normally by `ResolveReferences`. This is guaranteed as + * `ResolveReferences` resolves the query plan bottom up. This means that when we reach here to + * resolve [[InsertIntoStatement]], its child plans have already been resolved by + * `ResolveReferences`. + * 2. The plan nodes between [[Project]]/[[UnresolvedInlineTable]] and [[InsertIntoStatement]] are + * all unary nodes that inherit the output columns from its child. + */ +case object ResolveColumnDefaultInInsert extends SQLConfHelper with ColumnResolutionHelper { + // TODO: support v2 write commands as well. + def apply(plan: LogicalPlan): LogicalPlan = plan match { + case i: InsertIntoStatement if conf.enableDefaultColumns && i.table.resolved && + i.query.containsPattern(UNRESOLVED_ATTRIBUTE) => + val staticPartCols = i.partitionSpec.filter(_._2.isDefined).keys.map(normalizeFieldName).toSet + val expectedQuerySchema = i.table.schema.filter { field => + !staticPartCols.contains(normalizeFieldName(field.name)) + } + if (i.userSpecifiedCols.isEmpty) { + i.withNewChildren(Seq(resolveColumnDefault(i.query, expectedQuerySchema))) + } else { + // Reorder the fields in `expectedQuerySchema` according to the user-specified column list + // of the INSERT command. 
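+ // For illustration (hypothetical table, assumed for this example only): given
+ //   CREATE TABLE t(a INT, b INT DEFAULT 42);
+ //   INSERT INTO t (b, a) VALUES (DEFAULT, 1);
+ // the expected schema is reordered to (b, a), so the DEFAULT reference in the first
+ // position of the VALUES row resolves against column `b` and becomes its default, 42.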
+ val colNamesToFields: Map[String, StructField] = expectedQuerySchema.map { field => + normalizeFieldName(field.name) -> field + }.toMap + val reorder = i.userSpecifiedCols.map { col => + colNamesToFields.get(normalizeFieldName(col)) + } + if (reorder.forall(_.isDefined)) { + i.withNewChildren(Seq(resolveColumnDefault(i.query, reorder.flatten))) + } else { + i + } + } + + case _ => plan + } + + private def resolveColumnDefault( + plan: LogicalPlan, + expectedQuerySchema: Seq[StructField]): LogicalPlan = { + plan match { + case _: GlobalLimit | _: LocalLimit | _: Offset | _: SubqueryAlias | _: Sort => + plan.mapChildren(resolveColumnDefault(_, expectedQuerySchema)) + + case p: Project if p.child.resolved && p.containsPattern(UNRESOLVED_ATTRIBUTE) && + p.projectList.length <= expectedQuerySchema.length => + var changed = false + val newProjectList = p.projectList.zipWithIndex.map { + case (u: UnresolvedAttribute, i) if isExplicitDefaultColumn(u) => + changed = true + val field = expectedQuerySchema(i) + Alias(getDefaultValueExpr(field).getOrElse(Literal(null, field.dataType)), u.name)() + case (other, _) => other + } + if (changed) { + val newProj = p.copy(projectList = newProjectList) + newProj.copyTagsFrom(p) + newProj + } else { + p + } + + case inlineTable: UnresolvedInlineTable + if inlineTable.containsPattern(UNRESOLVED_ATTRIBUTE) && + inlineTable.rows.forall(exprs => exprs.length <= expectedQuerySchema.length) => + var changed = false + val newRows = inlineTable.rows.map { exprs => + exprs.zipWithIndex.map { + case (u: UnresolvedAttribute, i) if isExplicitDefaultColumn(u) => + changed = true + val field = expectedQuerySchema(i) + getDefaultValueExpr(field).getOrElse(Literal(null, field.dataType)) + case (other, _) => other + } + } + if (changed) { + val newInlineTable = inlineTable.copy(rows = newRows) + newInlineTable.copyTagsFrom(inlineTable) + newInlineTable + } else { + inlineTable + } + + case other => other + } + } + + /** + * Normalizes a schema field name suitable for use in looking up into maps keyed by schema field + * names. + * @param str the field name to normalize + * @return the normalized result + */ + def normalizeFieldName(str: String): String = { + if (SQLConf.get.caseSensitiveAnalysis) { + str + } else { + str.toLowerCase() + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala deleted file mode 100644 index 13e9866645aa..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala +++ /dev/null @@ -1,681 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.analysis - -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.sql.catalyst.catalog.UnresolvedCatalogRelation -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ -import org.apache.spark.sql.connector.write.SupportsCustomSchemaWrite -import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types._ - -/** - * This is a rule to process DEFAULT columns in statements such as CREATE/REPLACE TABLE. - * - * Background: CREATE TABLE and ALTER TABLE invocations support setting column default values for - * later operations. Following INSERT, UPDATE, and MERGE commands may then reference the value - * using the DEFAULT keyword as needed. - * - * Example: - * CREATE TABLE T(a INT DEFAULT 4, b INT NOT NULL DEFAULT 5); - * INSERT INTO T VALUES (1, 2); - * INSERT INTO T VALUES (1, DEFAULT); - * INSERT INTO T VALUES (DEFAULT, 6); - * SELECT * FROM T; - * (1, 2) - * (1, 5) - * (4, 6) - * - * @param resolveRelation function to resolve relations from the catalog. This should generally map - * to the 'resolveRelationOrTempView' method of the ResolveRelations rule. - */ -case class ResolveDefaultColumns( - resolveRelation: UnresolvedRelation => LogicalPlan) extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = { - plan.resolveOperatorsWithPruning( - (_ => SQLConf.get.enableDefaultColumns), ruleId) { - case i: InsertIntoStatement if insertsFromInlineTable(i) => - resolveDefaultColumnsForInsertFromInlineTable(i) - case i: InsertIntoStatement if insertsFromProject(i).isDefined => - resolveDefaultColumnsForInsertFromProject(i) - case u: UpdateTable => - resolveDefaultColumnsForUpdate(u) - case m: MergeIntoTable => - resolveDefaultColumnsForMerge(m) - } - } - - /** - * Checks if a logical plan is an INSERT INTO command where the inserted data comes from a VALUES - * list, with possible projection(s), aggregate(s), and/or alias(es) in between. - */ - private def insertsFromInlineTable(i: InsertIntoStatement): Boolean = { - var query = i.query - while (query.children.size == 1) { - query match { - case _: Project | _: Aggregate | _: SubqueryAlias => - query = query.children(0) - case _ => - return false - } - } - query match { - case u: UnresolvedInlineTable - if u.rows.nonEmpty && u.rows.forall(_.size == u.rows(0).size) => - true - case r: LocalRelation - if r.data.nonEmpty && r.data.forall(_.numFields == r.data(0).numFields) => - true - case _ => - false - } - } - - /** - * Checks if a logical plan is an INSERT INTO command where the inserted data comes from a SELECT - * list, with possible other unary operators like sorting and/or alias(es) in between. - */ - private def insertsFromProject(i: InsertIntoStatement): Option[Project] = { - var node = i.query - def matches(node: LogicalPlan): Boolean = node match { - case _: GlobalLimit | _: LocalLimit | _: Offset | _: SubqueryAlias | _: Sort => true - case _ => false - } - while (matches(node)) { - node = node.children.head - } - node match { - case p: Project => Some(p) - case _ => None - } - } - - /** - * Resolves DEFAULT column references for an INSERT INTO command satisfying the - * [[insertsFromInlineTable]] method. 
- */ - private def resolveDefaultColumnsForInsertFromInlineTable(i: InsertIntoStatement): LogicalPlan = { - val children = mutable.Buffer.empty[LogicalPlan] - var node = i.query - while (node.children.size == 1) { - children.append(node) - node = node.children(0) - } - val insertTableSchemaWithoutPartitionColumns: Option[StructType] = - getInsertTableSchemaWithoutPartitionColumns(i) - insertTableSchemaWithoutPartitionColumns.map { schema: StructType => - val regenerated: InsertIntoStatement = - regenerateUserSpecifiedCols(i, schema) - val (expanded: LogicalPlan, addedDefaults: Boolean) = - addMissingDefaultValuesForInsertFromInlineTable(node, schema, i.userSpecifiedCols.size) - val replaced: Option[LogicalPlan] = - replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded, addedDefaults) - replaced.map { r: LogicalPlan => - node = r - for (child <- children.reverse) { - node = child.withNewChildren(Seq(node)) - } - regenerated.copy(query = node) - }.getOrElse(i) - }.getOrElse(i) - } - - /** - * Resolves DEFAULT column references for an INSERT INTO command whose query is a general - * projection. - */ - private def resolveDefaultColumnsForInsertFromProject(i: InsertIntoStatement): LogicalPlan = { - val insertTableSchemaWithoutPartitionColumns: Option[StructType] = - getInsertTableSchemaWithoutPartitionColumns(i) - insertTableSchemaWithoutPartitionColumns.map { schema => - val regenerated: InsertIntoStatement = regenerateUserSpecifiedCols(i, schema) - val project: Project = insertsFromProject(i).get - if (project.projectList.exists(_.isInstanceOf[Star])) { - i - } else { - val (expanded: Project, addedDefaults: Boolean) = - addMissingDefaultValuesForInsertFromProject(project, schema, i.userSpecifiedCols.size) - val replaced: Option[LogicalPlan] = - replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded, addedDefaults) - replaced.map { r => - // Replace the INSERT INTO source relation, copying unary operators until we reach the - // original projection which we replace with the new projection with new values. - def replace(plan: LogicalPlan): LogicalPlan = plan match { - case _: Project => r - case u: UnaryNode => u.withNewChildren(Seq(replace(u.child))) - } - regenerated.copy(query = replace(regenerated.query)) - }.getOrElse(i) - } - }.getOrElse(i) - } - - /** - * Resolves DEFAULT column references for an UPDATE command. - */ - private def resolveDefaultColumnsForUpdate(u: UpdateTable): LogicalPlan = { - // Return a more descriptive error message if the user tries to use a DEFAULT column reference - // inside an UPDATE command's WHERE clause; this is not allowed. - u.condition.foreach { c: Expression => - if (c.find(isExplicitDefaultColumn).isDefined) { - throw QueryCompilationErrors.defaultReferencesNotAllowedInUpdateWhereClause() - } - } - val schemaForTargetTable: Option[StructType] = getSchemaForTargetTable(u.table) - schemaForTargetTable.map { schema => - val defaultExpressions: Seq[Expression] = schema.fields.map { - case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) => analyze(f, "UPDATE") - case _ => Literal(null) - } - // Create a map from each column name in the target table to its DEFAULT expression. - val columnNamesToExpressions: Map[String, Expression] = - mapStructFieldNamesToExpressions(schema, defaultExpressions) - // For each assignment in the UPDATE command's SET clause with a DEFAULT column reference on - // the right-hand side, look up the corresponding expression from the above map. 
- val newAssignments: Option[Seq[Assignment]] = - replaceExplicitDefaultValuesForUpdateAssignments( - u.assignments, CommandType.Update, columnNamesToExpressions) - newAssignments.map { n => - u.copy(assignments = n) - }.getOrElse(u) - }.getOrElse(u) - } - - /** - * Resolves DEFAULT column references for a MERGE INTO command. - */ - private def resolveDefaultColumnsForMerge(m: MergeIntoTable): LogicalPlan = { - val schema: StructType = getSchemaForTargetTable(m.targetTable).getOrElse(return m) - // Return a more descriptive error message if the user tries to use a DEFAULT column reference - // inside an UPDATE command's WHERE clause; this is not allowed. - m.mergeCondition.foreach { c: Expression => - if (c.find(isExplicitDefaultColumn).isDefined) { - throw QueryCompilationErrors.defaultReferencesNotAllowedInMergeCondition() - } - } - val columnsWithDefaults = ArrayBuffer.empty[String] - val defaultExpressions: Seq[Expression] = schema.fields.map { - case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) => - columnsWithDefaults.append(normalizeFieldName(f.name)) - analyze(f, "MERGE") - case _ => Literal(null) - } - val columnNamesToExpressions: Map[String, Expression] = - mapStructFieldNamesToExpressions(schema, defaultExpressions) - var replaced = false - val newMatchedActions: Seq[MergeAction] = m.matchedActions.map { action: MergeAction => - replaceExplicitDefaultValuesInMergeAction(action, columnNamesToExpressions).map { r => - replaced = true - r - }.getOrElse(action) - } - val newNotMatchedActions: Seq[MergeAction] = m.notMatchedActions.map { action: MergeAction => - val expanded = addMissingDefaultValuesForMergeAction(action, m, columnsWithDefaults.toSeq) - replaceExplicitDefaultValuesInMergeAction(expanded, columnNamesToExpressions).map { r => - replaced = true - r - }.getOrElse(action) - } - val newNotMatchedBySourceActions: Seq[MergeAction] = - m.notMatchedBySourceActions.map { action: MergeAction => - replaceExplicitDefaultValuesInMergeAction(action, columnNamesToExpressions).map { r => - replaced = true - r - }.getOrElse(action) - } - if (replaced) { - m.copy(matchedActions = newMatchedActions, - notMatchedActions = newNotMatchedActions, - notMatchedBySourceActions = newNotMatchedBySourceActions) - } else { - m - } - } - - /** Adds a new expressions to a merge action to generate missing default column values. 
*/ - def addMissingDefaultValuesForMergeAction( - action: MergeAction, - m: MergeIntoTable, - columnNamesWithDefaults: Seq[String]): MergeAction = { - action match { - case i: InsertAction => - val targetColumns: Set[String] = i.assignments.map(_.key).flatMap { expr => - expr match { - case a: AttributeReference => Seq(normalizeFieldName(a.name)) - case u: UnresolvedAttribute => Seq(u.nameParts.map(normalizeFieldName).mkString(".")) - case _ => Seq() - } - }.toSet - val targetTable: String = m.targetTable match { - case SubqueryAlias(id, _) => id.name - case d: DataSourceV2Relation => d.name - } - val missingColumnNamesWithDefaults = columnNamesWithDefaults.filter { name => - !targetColumns.contains(normalizeFieldName(name)) && - !targetColumns.contains( - s"${normalizeFieldName(targetTable)}.${normalizeFieldName(name)}") - } - val newAssignments: Seq[Assignment] = missingColumnNamesWithDefaults.map { key => - Assignment(UnresolvedAttribute(key), UnresolvedAttribute(CURRENT_DEFAULT_COLUMN_NAME)) - } - i.copy(assignments = i.assignments ++ newAssignments) - case _ => - action - } - } - - /** - * Replaces unresolved DEFAULT column references with corresponding values in one action of a - * MERGE INTO command. - */ - private def replaceExplicitDefaultValuesInMergeAction( - action: MergeAction, - columnNamesToExpressions: Map[String, Expression]): Option[MergeAction] = { - action match { - case u: UpdateAction => - val replaced: Option[Seq[Assignment]] = - replaceExplicitDefaultValuesForUpdateAssignments( - u.assignments, CommandType.Merge, columnNamesToExpressions) - replaced.map { r => - Some(u.copy(assignments = r)) - }.getOrElse(None) - case i: InsertAction => - val replaced: Option[Seq[Assignment]] = - replaceExplicitDefaultValuesForUpdateAssignments( - i.assignments, CommandType.Merge, columnNamesToExpressions) - replaced.map { r => - Some(i.copy(assignments = r)) - }.getOrElse(None) - case _ => Some(action) - } - } - - /** - * Regenerates user-specified columns of an InsertIntoStatement based on the names in the - * insertTableSchemaWithoutPartitionColumns field of this class. - */ - private def regenerateUserSpecifiedCols( - i: InsertIntoStatement, - insertTableSchemaWithoutPartitionColumns: StructType): InsertIntoStatement = { - if (i.userSpecifiedCols.nonEmpty) { - i.copy( - userSpecifiedCols = insertTableSchemaWithoutPartitionColumns.fields.map(_.name)) - } else { - i - } - } - - /** - * Returns true if an expression is an explicit DEFAULT column reference. - */ - private def isExplicitDefaultColumn(expr: Expression): Boolean = expr match { - case u: UnresolvedAttribute if u.name.equalsIgnoreCase(CURRENT_DEFAULT_COLUMN_NAME) => true - case _ => false - } - - /** - * Updates an inline table to generate missing default column values. - * Returns the resulting plan plus a boolean indicating whether such values were added. 
- */ - def addMissingDefaultValuesForInsertFromInlineTable( - node: LogicalPlan, - insertTableSchemaWithoutPartitionColumns: StructType, - numUserSpecifiedColumns: Int): (LogicalPlan, Boolean) = { - val schema = insertTableSchemaWithoutPartitionColumns - val newDefaultExpressions: Seq[UnresolvedAttribute] = - getNewDefaultExpressionsForInsert(schema, numUserSpecifiedColumns, node.output.size) - val newNames: Seq[String] = schema.fields.map(_.name) - val resultPlan: LogicalPlan = node match { - case _ if newDefaultExpressions.isEmpty => - node - case table: UnresolvedInlineTable => - table.copy( - names = newNames, - rows = table.rows.map { row => row ++ newDefaultExpressions }) - case local: LocalRelation => - val newDefaultExpressionsRow = new GenericInternalRow( - // Note that this code path only runs when there is a user-specified column list of fewer - // column than the target table; otherwise, the above 'newDefaultExpressions' is empty and - // we match the first case in this list instead. - schema.fields.drop(local.output.size).map { - case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) => - analyze(f, "INSERT") match { - case lit: Literal => lit.value - case _ => null - } - case _ => null - }) - LocalRelation( - output = schema.toAttributes, - data = local.data.map { row => - new JoinedRow(row, newDefaultExpressionsRow) - }) - case _ => - node - } - (resultPlan, newDefaultExpressions.nonEmpty) - } - - /** - * Adds a new expressions to a projection to generate missing default column values. - * Returns the logical plan plus a boolean indicating if such defaults were added. - */ - private def addMissingDefaultValuesForInsertFromProject( - project: Project, - insertTableSchemaWithoutPartitionColumns: StructType, - numUserSpecifiedColumns: Int): (Project, Boolean) = { - val schema = insertTableSchemaWithoutPartitionColumns - val newDefaultExpressions: Seq[Expression] = - getNewDefaultExpressionsForInsert(schema, numUserSpecifiedColumns, project.projectList.size) - val newAliases: Seq[NamedExpression] = - newDefaultExpressions.zip(schema.fields).map { - case (expr, field) => Alias(expr, field.name)() - } - (project.copy(projectList = project.projectList ++ newAliases), - newDefaultExpressions.nonEmpty) - } - - /** - * This is a helper for the addMissingDefaultValuesForInsertFromInlineTable methods above. - */ - private def getNewDefaultExpressionsForInsert( - insertTableSchemaWithoutPartitionColumns: StructType, - numUserSpecifiedColumns: Int, - numProvidedValues: Int): Seq[UnresolvedAttribute] = { - val remainingFields: Seq[StructField] = if (numUserSpecifiedColumns > 0) { - insertTableSchemaWithoutPartitionColumns.fields.drop(numUserSpecifiedColumns) - } else { - Seq.empty - } - val numDefaultExpressionsToAdd = getStructFieldsForDefaultExpressions(remainingFields).size - // Limit the number of new DEFAULT expressions to the difference of the number of columns in - // the target table and the number of provided values in the source relation. This clamps the - // total final number of provided values to the number of columns in the target table. - .min(insertTableSchemaWithoutPartitionColumns.size - numProvidedValues) - Seq.fill(numDefaultExpressionsToAdd)(UnresolvedAttribute(CURRENT_DEFAULT_COLUMN_NAME)) - } - - /** - * This is a helper for the getDefaultExpressionsForInsert methods above. 
- */ - private def getStructFieldsForDefaultExpressions(fields: Seq[StructField]): Seq[StructField] = { - if (SQLConf.get.useNullsForMissingDefaultColumnValues) { - fields - } else { - fields.takeWhile(_.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY)) - } - } - - /** - * Replaces unresolved DEFAULT column references with corresponding values in an INSERT INTO - * command from a logical plan. - */ - private def replaceExplicitDefaultValuesForInputOfInsertInto( - insertTableSchemaWithoutPartitionColumns: StructType, - input: LogicalPlan, - addedDefaults: Boolean): Option[LogicalPlan] = { - val schema = insertTableSchemaWithoutPartitionColumns - val defaultExpressions: Seq[Expression] = schema.fields.map { - case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) => analyze(f, "INSERT") - case _ => Literal(null) - } - // Check the type of `input` and replace its expressions accordingly. - // If necessary, return a more descriptive error message if the user tries to nest the DEFAULT - // column reference inside some other expression, such as DEFAULT + 1 (this is not allowed). - // - // Note that we don't need to check if "SQLConf.get.useNullsForMissingDefaultColumnValues" after - // this point because this method only takes responsibility to replace *existing* DEFAULT - // references. In contrast, the "getDefaultExpressionsForInsert" method will check that config - // and add new NULLs if needed. - input match { - case table: UnresolvedInlineTable => - replaceExplicitDefaultValuesForInlineTable(defaultExpressions, table) - case project: Project => - replaceExplicitDefaultValuesForProject(defaultExpressions, project) - case local: LocalRelation => - if (addedDefaults) { - Some(local) - } else { - None - } - } - } - - /** - * Replaces unresolved DEFAULT column references with corresponding values in an inline table. - */ - private def replaceExplicitDefaultValuesForInlineTable( - defaultExpressions: Seq[Expression], - table: UnresolvedInlineTable): Option[LogicalPlan] = { - var replaced = false - val updated: Seq[Seq[Expression]] = { - table.rows.map { row: Seq[Expression] => - for { - i <- row.indices - expr = row(i) - defaultExpr = if (i < defaultExpressions.size) defaultExpressions(i) else Literal(null) - } yield replaceExplicitDefaultReferenceInExpression( - expr, defaultExpr, CommandType.Insert, addAlias = false).map { e => - replaced = true - e - }.getOrElse(expr) - } - } - if (replaced) { - Some(table.copy(rows = updated)) - } else { - None - } - } - - /** - * Replaces unresolved DEFAULT column references with corresponding values in a projection. - */ - private def replaceExplicitDefaultValuesForProject( - defaultExpressions: Seq[Expression], - project: Project): Option[LogicalPlan] = { - var replaced = false - val updated: Seq[NamedExpression] = { - for { - i <- project.projectList.indices - projectExpr = project.projectList(i) - defaultExpr = if (i < defaultExpressions.size) defaultExpressions(i) else Literal(null) - } yield replaceExplicitDefaultReferenceInExpression( - projectExpr, defaultExpr, CommandType.Insert, addAlias = true).map { e => - replaced = true - e.asInstanceOf[NamedExpression] - }.getOrElse(projectExpr) - } - if (replaced) { - Some(project.copy(projectList = updated)) - } else { - None - } - } - - /** - * Represents a type of command we are currently processing. - */ - private object CommandType extends Enumeration { - val Insert, Update, Merge = Value - } - - /** - * Checks if a given input expression is an unresolved "DEFAULT" attribute reference. 
- * - * @param input the input expression to examine. - * @param defaultExpr the default to return if [[input]] is an unresolved "DEFAULT" reference. - * @param isInsert the type of command we are currently processing. - * @param addAlias if true, wraps the result with an alias of the original default column name. - * @return [[defaultExpr]] if [[input]] is an unresolved "DEFAULT" attribute reference. - */ - private def replaceExplicitDefaultReferenceInExpression( - input: Expression, - defaultExpr: Expression, - command: CommandType.Value, - addAlias: Boolean): Option[Expression] = { - input match { - case a@Alias(u: UnresolvedAttribute, _) - if isExplicitDefaultColumn(u) => - Some(Alias(defaultExpr, a.name)()) - case u: UnresolvedAttribute - if isExplicitDefaultColumn(u) => - if (addAlias) { - Some(Alias(defaultExpr, u.name)()) - } else { - Some(defaultExpr) - } - case expr@_ - if expr.find(isExplicitDefaultColumn).isDefined => - command match { - case CommandType.Insert => - throw QueryCompilationErrors - .defaultReferencesNotAllowedInComplexExpressionsInInsertValuesList() - case CommandType.Update => - throw QueryCompilationErrors - .defaultReferencesNotAllowedInComplexExpressionsInUpdateSetClause() - case CommandType.Merge => - throw QueryCompilationErrors - .defaultReferencesNotAllowedInComplexExpressionsInMergeInsertsOrUpdates() - } - case _ => - None - } - } - - /** - * Looks up the schema for the table object of an INSERT INTO statement from the catalog. - */ - private def getInsertTableSchemaWithoutPartitionColumns( - enclosingInsert: InsertIntoStatement): Option[StructType] = { - val target: StructType = getSchemaForTargetTable(enclosingInsert.table).getOrElse(return None) - val schema: StructType = StructType(target.fields.dropRight(enclosingInsert.partitionSpec.size)) - // Rearrange the columns in the result schema to match the order of the explicit column list, - // if any. - val userSpecifiedCols: Seq[String] = enclosingInsert.userSpecifiedCols - if (userSpecifiedCols.isEmpty) { - return Some(schema) - } - val colNamesToFields: Map[String, StructField] = mapStructFieldNamesToFields(schema) - val userSpecifiedFields: Seq[StructField] = - userSpecifiedCols.map { - name: String => colNamesToFields.getOrElse(normalizeFieldName(name), return None) - } - val userSpecifiedColNames: Set[String] = userSpecifiedCols.toSet - .map(normalizeFieldName) - val nonUserSpecifiedFields: Seq[StructField] = - schema.fields.filter { - field => !userSpecifiedColNames.contains( - normalizeFieldName( - field.name - ) - ) - } - Some(StructType(userSpecifiedFields ++ - getStructFieldsForDefaultExpressions(nonUserSpecifiedFields))) - } - - /** - * Returns a map of the names of fields in a schema to the fields themselves. - */ - private def mapStructFieldNamesToFields(schema: StructType): Map[String, StructField] = { - schema.fields.map { - field: StructField => normalizeFieldName(field.name) -> field - }.toMap - } - - /** - * Returns a map of the names of fields in a schema to corresponding expressions. - */ - private def mapStructFieldNamesToExpressions( - schema: StructType, - expressions: Seq[Expression]): Map[String, Expression] = { - schema.fields.zip(expressions).map { - case (field: StructField, expression: Expression) => - normalizeFieldName(field.name) -> expression - }.toMap - } - - /** - * Returns the schema for the target table of a DML command, looking into the catalog if needed. 
- */ - private def getSchemaForTargetTable(table: LogicalPlan): Option[StructType] = { - val resolved = table match { - case r: UnresolvedRelation if !r.skipSchemaResolution && !r.isStreaming => - resolveRelation(r) - case other => - other - } - resolved.collectFirst { - case r: UnresolvedCatalogRelation => - r.tableMeta.schema - case DataSourceV2Relation(table: SupportsCustomSchemaWrite, _, _, _, _) => - table.customSchemaForInserts - case r: NamedRelation if !r.skipSchemaResolution => - r.schema - case v: View if v.isTempViewStoringAnalyzedPlan => - v.schema - } - } - - /** - * Replaces unresolved DEFAULT column references with corresponding values in a series of - * assignments in an UPDATE assignment, either comprising an UPDATE command or as part of a MERGE. - */ - private def replaceExplicitDefaultValuesForUpdateAssignments( - assignments: Seq[Assignment], - command: CommandType.Value, - columnNamesToExpressions: Map[String, Expression]): Option[Seq[Assignment]] = { - var replaced = false - val newAssignments: Seq[Assignment] = - for (assignment <- assignments) yield { - val destColName = assignment.key match { - case a: AttributeReference => a.name - case u: UnresolvedAttribute => u.nameParts.last - case _ => "" - } - val adjusted: String = normalizeFieldName(destColName) - val lookup: Option[Expression] = columnNamesToExpressions.get(adjusted) - val newValue: Expression = lookup.map { defaultExpr => - val updated: Option[Expression] = - replaceExplicitDefaultReferenceInExpression( - assignment.value, - defaultExpr, - command, - addAlias = false) - updated.map { e => - replaced = true - e - }.getOrElse(assignment.value) - }.getOrElse(assignment.value) - assignment.copy(value = newValue) - } - if (replaced) { - Some(newAssignments) - } else { - None - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInsertionBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInsertionBase.scala new file mode 100644 index 000000000000..71d368679510 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInsertionBase.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{Alias, Cast} +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoStatement, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.SchemaUtils + +abstract class ResolveInsertionBase extends Rule[LogicalPlan] { + def resolver: Resolver = conf.resolver + + /** Add a project to use the table column names for INSERT INTO BY NAME */ + protected def createProjectForByNameQuery(i: InsertIntoStatement): LogicalPlan = { + SchemaUtils.checkColumnNameDuplication(i.userSpecifiedCols, resolver) + + if (i.userSpecifiedCols.size != i.query.output.size) { + throw QueryCompilationErrors.writeTableWithMismatchedColumnsError( + i.userSpecifiedCols.size, i.query.output.size, i.query) + } + val projectByName = i.userSpecifiedCols.zip(i.query.output) + .map { case (userSpecifiedCol, queryOutputCol) => + val resolvedCol = i.table.resolve(Seq(userSpecifiedCol), resolver) + .getOrElse( + throw QueryCompilationErrors.unresolvedAttributeError( + "UNRESOLVED_COLUMN", userSpecifiedCol, i.table.output.map(_.name), i.origin)) + (queryOutputCol.dataType, resolvedCol.dataType) match { + case (input: StructType, expected: StructType) => + // Rename inner fields of the input column to pass the by-name INSERT analysis. + Alias(Cast(queryOutputCol, renameFieldsInStruct(input, expected)), resolvedCol.name)() + case _ => + Alias(queryOutputCol, resolvedCol.name)() + } + } + Project(projectByName, i.query) + } + + private def renameFieldsInStruct(input: StructType, expected: StructType): StructType = { + if (input.length == expected.length) { + val newFields = input.zip(expected).map { case (f1, f2) => + (f1.dataType, f2.dataType) match { + case (s1: StructType, s2: StructType) => + f1.copy(name = f2.name, dataType = renameFieldsInStruct(s1, s2)) + case _ => + f1.copy(name = f2.name) + } + } + StructType(newFields) + } else { + input + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala new file mode 100644 index 000000000000..fd701992b9c0 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{Assignment, UpdateTable} +import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.{getDefaultValueExpr, isExplicitDefaultColumn} + +/** + * A virtual rule to resolve [[UnresolvedAttribute]] in [[UpdateTable]]. It's only used by the real + * rule `ResolveReferences`. The column resolution order for [[UpdateTable]] is: + * 1. Resolves the column to [[AttributeReference]] with the output of the child plan. This + * includes metadata columns as well. + * 2. Resolves the column to a literal function which is allowed to be invoked without braces, e.g. + * `SELECT col, current_date FROM t`. + * 3. Resolves the column to the default value expression, if the column is the assignment value + * and the corresponding assignment key is a top-level column. + */ +case object ResolveReferencesInUpdate extends SQLConfHelper with ColumnResolutionHelper { + + def apply(u: UpdateTable): UpdateTable = { + assert(u.table.resolved) + if (u.resolved) return u + + val newAssignments = u.assignments.map { assign => + val resolvedKey = assign.key match { + case c if !c.resolved => + resolveExprInAssignment(c, u) + case o => o + } + val resolvedValue = assign.value match { + case c if !c.resolved => + val resolved = resolveExprInAssignment(c, u) + resolvedKey match { + case attr: AttributeReference if conf.enableDefaultColumns => + resolved match { + case u: UnresolvedAttribute if isExplicitDefaultColumn(u) => + getDefaultValueExpr(attr).getOrElse(Literal(null, attr.dataType)) + case other => other + } + case _ => resolved + } + case o => o + } + val resolved = Assignment(resolvedKey, resolvedValue) + resolved.copyTagsFrom(assign) + resolved + } + + val newUpdate = u.copy(assignments = newAssignments) + newUpdate.copyTagsFrom(u) + newUpdate + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala index ae159cd349f0..2051487f1a68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getDefaultValueExprOrNullLiteral import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -36,7 +37,9 @@ object TableOutputResolver { expected: Seq[Attribute], query: LogicalPlan, byName: Boolean, - conf: SQLConf): LogicalPlan = { + conf: SQLConf, + // TODO: Only DS v1 writing will set it to true. We should enable in for DS v2 as well. 
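+ // If true and the input query provides no value for a target column, the column's DEFAULT
+ // value (or a NULL literal, when `useNullsForMissingDefaultColumnValues` is enabled) is used
+ // instead of raising an error. For illustration, with a hypothetical table
+ // t(a INT, b INT DEFAULT 42), `INSERT INTO t VALUES (1)` would write the row (1, 42).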
+ supportColDefaultValue: Boolean = false): LogicalPlan = { val actualExpectedCols = expected.map { attr => attr.withDataType(CharVarcharUtils.getRawType(attr.metadata).getOrElse(attr.dataType)) @@ -49,14 +52,32 @@ object TableOutputResolver { val errors = new mutable.ArrayBuffer[String]() val resolved: Seq[NamedExpression] = if (byName) { - reorderColumnsByName(query.output, actualExpectedCols, conf, errors += _) + // If a top-level column does not have a corresponding value in the input query, fill with + // the column's default value. We need to pass `fillDefaultValue` as true here, if the + // `supportColDefaultValue` parameter is also true. + reorderColumnsByName( + query.output, + actualExpectedCols, + conf, + errors += _, + fillDefaultValue = supportColDefaultValue) } else { - if (actualExpectedCols.size > query.output.size) { + // If the target table needs more columns than the input query, fill them with + // the columns' default values, if the `supportColDefaultValue` parameter is true. + val fillDefaultValue = supportColDefaultValue && actualExpectedCols.size > query.output.size + val queryOutputCols = if (fillDefaultValue) { + query.output ++ actualExpectedCols.drop(query.output.size).flatMap { expectedCol => + getDefaultValueExprOrNullLiteral(expectedCol, conf) + } + } else { + query.output + } + if (actualExpectedCols.size > queryOutputCols.size) { throw QueryCompilationErrors.cannotWriteNotEnoughColumnsToTableError( tableName, actualExpectedCols, query) } - resolveColumnsByPosition(query.output, actualExpectedCols, conf, errors += _) + resolveColumnsByPosition(queryOutputCols, actualExpectedCols, conf, errors += _) } if (errors.nonEmpty) { @@ -156,14 +177,22 @@ object TableOutputResolver { expectedCols: Seq[Attribute], conf: SQLConf, addError: String => Unit, - colPath: Seq[String] = Nil): Seq[NamedExpression] = { + colPath: Seq[String] = Nil, + fillDefaultValue: Boolean = false): Seq[NamedExpression] = { val matchedCols = mutable.HashSet.empty[String] val reordered = expectedCols.flatMap { expectedCol => val matched = inputCols.filter(col => conf.resolver(col.name, expectedCol.name)) val newColPath = colPath :+ expectedCol.name if (matched.isEmpty) { - addError(s"Cannot find data for output column '${newColPath.quoted}'") - None + val defaultExpr = if (fillDefaultValue) { + getDefaultValueExprOrNullLiteral(expectedCol, conf) + } else { + None + } + if (defaultExpr.isEmpty) { + addError(s"Cannot find data for output column '${newColPath.quoted}'") + } + defaultExpr } else if (matched.length > 1) { addError(s"Ambiguous column name in the input data: '${newColPath.quoted}'") None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala index 8c7e2ad4f1df..c6375d0a947e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala @@ -155,6 +155,54 @@ object ResolveDefaultColumns { } } + /** + * Returns true if the unresolved column is an explicit DEFAULT column reference. + */ + def isExplicitDefaultColumn(col: UnresolvedAttribute): Boolean = { + col.name.equalsIgnoreCase(CURRENT_DEFAULT_COLUMN_NAME) + } + + /** + * Generates the expression of the default value for the given field. If there is no + * user-specified default value for this field, returns None. 
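+ *
+ * For example, a hypothetical field declared as `b INT DEFAULT 42` yields the analyzed
+ * default expression (a literal 42), while a field without DEFAULT metadata yields None.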
+ */ + def getDefaultValueExpr(field: StructField): Option[Expression] = { + if (field.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY)) { + Some(analyze(field, "INSERT")) + } else { + None + } + } + + /** + * Generates the expression of the default value for the given column. If there is no + * user-specified default value for this field, returns None. + */ + def getDefaultValueExpr(attr: Attribute): Option[Expression] = { + if (attr.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY)) { + val field = StructField(attr.name, attr.dataType, attr.nullable, attr.metadata) + Some(analyze(field, "INSERT")) + } else { + None + } + } + + /** + * Generates the expression of the default value for the given column. If there is no + * user-specified default value for this column, returns a null literal. + */ + def getDefaultValueExprOrNullLiteral(attr: Attribute, conf: SQLConf): Option[NamedExpression] = { + val defaultExprOpt = if (attr.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY)) { + val field = StructField(attr.name, attr.dataType, attr.nullable, attr.metadata) + Some(analyze(field, "INSERT")) + } else if (conf.useNullsForMissingDefaultColumnValues) { + Some(Literal(null, attr.dataType)) + } else { + None + } + defaultExprOpt.map(expr => Alias(expr, attr.name)()) + } + /** * Parses and analyzes the DEFAULT column text in `field`, returning an error upon failure. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 2564d7e50a21..cedb58d8cbfd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -362,7 +362,7 @@ case class PreprocessTableCreation(catalog: SessionCatalog) extends Rule[Logical * table. It also does data type casting and field renaming, to make sure that the columns to be * inserted have the correct data type and fields have the correct names. */ -object PreprocessTableInsertion extends Rule[LogicalPlan] { +object PreprocessTableInsertion extends ResolveInsertionBase { private def preprocess( insert: InsertIntoStatement, tblName: String, @@ -375,11 +375,6 @@ object PreprocessTableInsertion extends Rule[LogicalPlan] { val staticPartCols = normalizedPartSpec.filter(_._2.isDefined).keySet val expectedColumns = insert.table.output.filterNot(a => staticPartCols.contains(a.name)) - if (expectedColumns.length != insert.query.schema.length) { - throw QueryCompilationErrors.mismatchedInsertedDataColumnNumberError( - tblName, insert, staticPartCols) - } - val partitionsTrackedByCatalog = catalogTable.isDefined && catalogTable.get.partitionColumnNames.nonEmpty && catalogTable.get.tracksPartitionsInCatalog @@ -392,8 +387,15 @@ object PreprocessTableInsertion extends Rule[LogicalPlan] { } } + // Create a project if this INSERT has a user-specified column list. 
+ val isByName = insert.userSpecifiedCols.nonEmpty + val query = if (isByName) { + createProjectForByNameQuery(insert) + } else { + insert.query + } val newQuery = TableOutputResolver.resolveOutputColumns( - tblName, expectedColumns, insert.query, byName = false, conf) + tblName, expectedColumns, query, byName = isByName, conf, supportColDefaultValue = true) if (normalizedPartSpec.nonEmpty) { if (normalizedPartSpec.size != partColNames.length) { throw QueryCompilationErrors.requestedPartitionsMismatchTablePartitionsError( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala index 1997fce0f5cf..904980d58d60 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLInsertTestSuite.scala @@ -201,30 +201,6 @@ trait SQLInsertTestSuite extends QueryTest with SQLTestUtils { } } - test("insert with column list - missing columns") { - val v2Msg = "Cannot write incompatible data to table 'testcat.t1'" - val cols = Seq("c1", "c2", "c3", "c4") - - withTable("t1") { - createTable("t1", cols, Seq.fill(4)("int")) - val e1 = intercept[AnalysisException](sql(s"INSERT INTO t1 values(1)")) - assert(e1.getMessage.contains("target table has 4 column(s) but the inserted data has 1") || - e1.getMessage.contains("expected 4 columns but found 1") || - e1.getMessage.contains("not enough data columns") || - e1.getMessage.contains(v2Msg)) - } - - withTable("t1") { - createTable("t1", cols, Seq.fill(4)("int"), cols.takeRight(2)) - val e1 = intercept[AnalysisException] { - sql(s"INSERT INTO t1 partition(c3=3, c4=4) values(1)") - } - assert(e1.getMessage.contains("target table has 4 column(s) but the inserted data has 3") || - e1.getMessage.contains("not enough data columns") || - e1.getMessage.contains(v2Msg)) - } - } - test("SPARK-34223: static partition with null raise NPE") { withTable("t") { sql(s"CREATE TABLE t(i STRING, c string) USING PARQUET PARTITIONED BY (c)") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumnsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumnsSuite.scala index ba52ac995b7f..f184d66cacff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumnsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumnsSuite.scala @@ -17,69 +17,84 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.{AnalysisException, QueryTest} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ -import org.apache.spark.sql.connector.catalog.{Table, TableCapability} -import org.apache.spark.sql.connector.write.SupportsCustomSchemaWrite -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{MetadataBuilder, StringType, StructField, StructType, TimestampType} -import org.apache.spark.sql.util.CaseInsensitiveStringMap class ResolveDefaultColumnsSuite extends QueryTest with SharedSparkSession { - val rule = ResolveDefaultColumns(null) - // This is the internal storage for the timestamp 2020-12-31 00:00:00.0. 
- val literal = Literal(1609401600000000L, TimestampType) - val table = UnresolvedInlineTable( - names = Seq("attr1"), - rows = Seq(Seq(literal))) - val localRelation = ResolveInlineTables(table).asInstanceOf[LocalRelation] - - def asLocalRelation(result: LogicalPlan): LocalRelation = result match { - case r: LocalRelation => r - case _ => fail(s"invalid result operator type: $result") - } + test("column without default value defined (null as default)") { + withTable("t") { + sql("create table t(c1 timestamp, c2 timestamp) using parquet") - test("SPARK-43018: Add DEFAULTs for INSERT from VALUES list with user-defined columns") { - // Call the 'addMissingDefaultValuesForInsertFromInlineTable' method with one user-specified - // column. We add a default value of NULL to the row as a result. - val insertTableSchemaWithoutPartitionColumns = StructType(Seq( - StructField("c1", TimestampType), - StructField("c2", TimestampType))) - val (result: LogicalPlan, _: Boolean) = - rule.addMissingDefaultValuesForInsertFromInlineTable( - localRelation, insertTableSchemaWithoutPartitionColumns, numUserSpecifiedColumns = 1) - val relation = asLocalRelation(result) - assert(relation.output.map(_.name) == Seq("c1", "c2")) - val data: Seq[Seq[Any]] = relation.data.map { row => - row.toSeq(StructType(relation.output.map(col => StructField(col.name, col.dataType)))) - } - assert(data == Seq(Seq(literal.value, null))) - } + // INSERT with user-defined columns + sql("insert into t (c2) values (timestamp'2020-12-31')") + checkAnswer(spark.table("t"), + sql("select null, timestamp'2020-12-31'").collect().head) + sql("truncate table t") + sql("insert into t (c1) values (timestamp'2020-12-31')") + checkAnswer(spark.table("t"), + sql("select timestamp'2020-12-31', null").collect().head) - test("SPARK-43018: Add no DEFAULTs for INSERT from VALUES list with no user-defined columns") { - // Call the 'addMissingDefaultValuesForInsertFromInlineTable' method with zero user-specified - // columns. The table is unchanged because there are no default columns to add in this case. 
- val insertTableSchemaWithoutPartitionColumns = StructType(Seq( - StructField("c1", TimestampType), - StructField("c2", TimestampType))) - val (result: LogicalPlan, _: Boolean) = - rule.addMissingDefaultValuesForInsertFromInlineTable( - localRelation, insertTableSchemaWithoutPartitionColumns, numUserSpecifiedColumns = 0) - assert(asLocalRelation(result) == localRelation) + // INSERT without user-defined columns + sql("truncate table t") + sql("insert into t values (timestamp'2020-12-31')") + checkAnswer(spark.table("t"), + sql("select timestamp'2020-12-31', null").collect().head) + } } - test("SPARK-43018: INSERT timestamp values into a table with column DEFAULTs") { + test("column with default value defined") { withTable("t") { - sql("create table t(id int, ts timestamp) using parquet") - sql("insert into t (ts) values (timestamp'2020-12-31')") + sql("create table t(c1 timestamp DEFAULT timestamp'2020-01-01', " + + "c2 timestamp DEFAULT timestamp'2020-01-01') using parquet") + + // INSERT with user-defined columns + sql("insert into t (c1) values (timestamp'2020-12-31')") checkAnswer(spark.table("t"), - sql("select null, timestamp'2020-12-31'").collect().head) + sql("select timestamp'2020-12-31', timestamp'2020-01-01'").collect().head) + sql("truncate table t") + sql("insert into t (c2) values (timestamp'2020-12-31')") + checkAnswer(spark.table("t"), + sql("select timestamp'2020-01-01', timestamp'2020-12-31'").collect().head) + + // INSERT without user-defined columns + sql("truncate table t") + sql("insert into t values (timestamp'2020-12-31')") + checkAnswer(spark.table("t"), + sql("select timestamp'2020-12-31', timestamp'2020-01-01'").collect().head) } } + test("INSERT into partitioned tables") { + sql("create table t(c1 int, c2 int, c3 int, c4 int) using parquet partitioned by (c3, c4)") + + // INSERT without static partitions + sql("insert into t values (1, 2, 3)") + checkAnswer(spark.table("t"), Row(1, 2, 3, null)) + + // INSERT without static partitions but with column list + sql("truncate table t") + sql("insert into t (c2, c1, c4) values (1, 2, 3)") + checkAnswer(spark.table("t"), Row(2, 1, null, 3)) + + // INSERT with static partitions + sql("truncate table t") + sql("insert into t partition(c3=3, c4=4) values (1)") + checkAnswer(spark.table("t"), Row(1, null, 3, 4)) + + // INSERT with static partitions and with column list + sql("truncate table t") + sql("insert into t partition(c3=3, c4=4) (c2) values (1)") + checkAnswer(spark.table("t"), Row(null, 1, 3, 4)) + + // INSERT with partial static partitions + sql("truncate table t") + sql("insert into t partition(c3=3, c4) values (1, 2)") + checkAnswer(spark.table("t"), Row(1, 2, 3, null)) + + // INSERT with partial static partitions and with column list is not allowed + intercept[AnalysisException](sql("insert into t partition(c3=3, c4) (c1) values (1, 4)")) + } + test("SPARK-43085: Column DEFAULT assignment for target tables with multi-part names") { withDatabase("demos") { sql("create database demos") @@ -129,111 +144,4 @@ class ResolveDefaultColumnsSuite extends QueryTest with SharedSparkSession { } } } - - /** - * This is a new relation type that defines the 'customSchemaForInserts' method. - * Its implementation drops the last table column as it represents an internal pseudocolumn. 
- */ - case class TableWithCustomInsertSchema(output: Seq[Attribute], numMetadataColumns: Int) - extends Table with SupportsCustomSchemaWrite { - override def name: String = "t" - override def schema: StructType = StructType.fromAttributes(output) - override def capabilities(): java.util.Set[TableCapability] = - new java.util.HashSet[TableCapability]() - override def customSchemaForInserts: StructType = - StructType(schema.fields.dropRight(numMetadataColumns)) - } - - /** Helper method to generate a DSV2 relation using the above table type. */ - private def relationWithCustomInsertSchema( - output: Seq[AttributeReference], numMetadataColumns: Int): DataSourceV2Relation = { - DataSourceV2Relation( - TableWithCustomInsertSchema(output, numMetadataColumns), - output, - catalog = None, - identifier = None, - options = CaseInsensitiveStringMap.empty) - } - - test("SPARK-43313: Add missing default values for MERGE INSERT actions") { - val testRelation = SubqueryAlias( - "testRelation", - relationWithCustomInsertSchema(Seq( - AttributeReference( - "a", - StringType, - true, - new MetadataBuilder() - .putString(CURRENT_DEFAULT_COLUMN_METADATA_KEY, "'a'") - .putString(EXISTS_DEFAULT_COLUMN_METADATA_KEY, "'a'") - .build())(), - AttributeReference( - "b", - StringType, - true, - new MetadataBuilder() - .putString(CURRENT_DEFAULT_COLUMN_METADATA_KEY, "'b'") - .putString(EXISTS_DEFAULT_COLUMN_METADATA_KEY, "'b'") - .build())(), - AttributeReference( - "c", - StringType, - true, - new MetadataBuilder() - .putString(CURRENT_DEFAULT_COLUMN_METADATA_KEY, "'c'") - .putString(EXISTS_DEFAULT_COLUMN_METADATA_KEY, "'c'") - .build())(), - AttributeReference( - "pseudocolumn", - StringType, - true, - new MetadataBuilder() - .putString(CURRENT_DEFAULT_COLUMN_METADATA_KEY, "'pseudocolumn'") - .putString(EXISTS_DEFAULT_COLUMN_METADATA_KEY, "'pseudocolumn'") - .build())()), - numMetadataColumns = 1)) - val testRelation2 = - SubqueryAlias( - "testRelation2", - relationWithCustomInsertSchema(Seq( - AttributeReference("d", StringType)(), - AttributeReference("e", StringType)(), - AttributeReference("f", StringType)()), - numMetadataColumns = 0)) - val mergePlan = MergeIntoTable( - targetTable = testRelation, - sourceTable = testRelation2, - mergeCondition = EqualTo(testRelation.output.head, testRelation2.output.head), - matchedActions = Seq(DeleteAction(None)), - notMatchedActions = Seq( - InsertAction( - condition = None, - assignments = Seq( - Assignment( - key = UnresolvedAttribute("a"), - value = UnresolvedAttribute("DEFAULT")), - Assignment( - key = UnresolvedAttribute(Seq("testRelation", "b")), - value = Literal("xyz"))))), - notMatchedBySourceActions = Seq(DeleteAction(None))) - // Run the 'addMissingDefaultValuesForMergeAction' method of the 'ResolveDefaultColumns' rule - // on an MERGE INSERT action with two assignments, one to the target table's column 'a' and - // another to the target table's column 'b'. - val columnNamesWithDefaults = Seq("a", "b", "c") - val actualMergeAction = - rule.apply(mergePlan).asInstanceOf[MergeIntoTable].notMatchedActions.head - val expectedMergeAction = - InsertAction( - condition = None, - assignments = Seq( - Assignment(key = UnresolvedAttribute("a"), value = Literal("a")), - Assignment(key = UnresolvedAttribute(Seq("testRelation", "b")), value = Literal("xyz")), - Assignment(key = UnresolvedAttribute("c"), value = Literal("c")))) - assert(expectedMergeAction == actualMergeAction) - // Run the same method on another MERGE DELETE action. 
There is no change because this method - // only operates on MERGE INSERT actions. - assert(rule.addMissingDefaultValuesForMergeAction( - mergePlan.matchedActions.head, mergePlan, columnNamesWithDefaults) == - mergePlan.matchedActions.head) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala index e56916663390..a479e810e462 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlignMergeAssignmentsSuite.scala @@ -602,14 +602,6 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuite { } test("invalid INSERT assignments") { - assertAnalysisException( - """MERGE INTO primitive_table t USING primitive_table src - |ON t.i = src.i - |WHEN NOT MATCHED THEN - | INSERT (i, txt) VALUES (src.i, src.txt) - |""".stripMargin, - "No assignment for 'l'") - assertAnalysisException( """MERGE INTO primitive_table t USING primitive_table src |ON t.i = src.i @@ -624,10 +616,7 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuite { |WHEN NOT MATCHED THEN | INSERT (s.n_i) VALUES (1) |""".stripMargin, - "INSERT assignment keys cannot be nested fields: t.s.`n_i` = 1", - "No assignment for 'i'", - "No assignment for 's'", - "No assignment for 'txt'") + "INSERT assignment keys cannot be nested fields: t.s.`n_i` = 1") } test("updates to nested structs in arrays") { @@ -866,6 +855,8 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuite { |ON t.b = s.b |WHEN MATCHED THEN | UPDATE SET t.i = DEFAULT + |WHEN NOT MATCHED AND (s.i = 1) THEN + | INSERT (b) VALUES (false) |WHEN NOT MATCHED THEN | INSERT (i, b) VALUES (DEFAULT, false) |WHEN NOT MATCHED BY SOURCE THEN @@ -889,8 +880,26 @@ class AlignMergeAssignmentsSuite extends AlignAssignmentsSuite { fail(s"Unexpected actions: $other") } - notMatchedActions match { - case Seq(InsertAction(None, assignments)) => + assert(notMatchedActions.length == 2) + notMatchedActions(0) match { + case InsertAction(Some(_), assignments) => + assignments match { + case Seq( + Assignment(b: AttributeReference, BooleanLiteral(false)), + Assignment(i: AttributeReference, IntegerLiteral(42))) => + + assert(b.name == "b") + assert(i.name == "i") + + case other => + fail(s"Unexpected assignments: $other") + } + + case other => + fail(s"Unexpected actions: $other") + } + notMatchedActions(1) match { + case InsertAction(None, assignments) => assignments match { case Seq( Assignment(b: AttributeReference, BooleanLiteral(false)), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala index ea1e9a7e0486..9cc26d894baa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala @@ -359,8 +359,7 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter val e = intercept[AnalysisException] { sql(s"INSERT INTO TABLE $tableName PARTITION(b=1, c=2) SELECT 1, 2, 3") } - assert(e.message.contains( - "target table has 4 column(s) but the inserted data has 5 column(s)")) + assert(e.message.contains("Cannot write to") && e.message.contains("too many data columns")) } testPartitionedTable("SPARK-16037: INSERT statement should match columns by position") { @@ -382,6 +381,9 @@ class 
InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter sql(s"INSERT INTO TABLE $tableName PARTITION (c=11, b=10) SELECT 9, 12") + // The data is missing a column. The default value for the missing column is null. + sql(s"INSERT INTO TABLE $tableName PARTITION (c=15, b=16) SELECT 13") + // c is defined twice. Analyzer will complain. intercept[AnalysisException] { sql(s"INSERT INTO TABLE $tableName PARTITION (b=14, c=15, c=16) SELECT 13") @@ -397,11 +399,6 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter sql(s"INSERT INTO TABLE $tableName PARTITION (b=14, c=15, d=16) SELECT 13") } - // The data is missing a column. - intercept[AnalysisException] { - sql(s"INSERT INTO TABLE $tableName PARTITION (c=15, b=16) SELECT 13") - } - // d is not a partitioning column. intercept[AnalysisException] { sql(s"INSERT INTO TABLE $tableName PARTITION (b=15, d=15) SELECT 13, 14") @@ -436,6 +433,7 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter Row(5, 6, 7, 8) :: Row(9, 10, 11, 12) :: Row(13, 14, 15, 16) :: + Row(13, 16, 15, null) :: Row(17, 18, 19, 20) :: Row(21, 22, 23, 24) :: Row(25, 26, 27, 28) :: Nil @@ -473,13 +471,14 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter } } - testPartitionedTable("insertInto() should reject missing columns") { + testPartitionedTable("insertInto() should reject missing columns if null default is disabled") { tableName => withTable("t") { sql("CREATE TABLE t (a INT, b INT)") - - intercept[AnalysisException] { - spark.table("t").write.insertInto(tableName) + withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "false") { + intercept[AnalysisException] { + spark.table("t").write.insertInto(tableName) + } } } } From 5c83feb08f6836289b0cb0cbcb9f07321e509f64 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 23 May 2023 22:07:43 +0800 Subject: [PATCH 2/6] fix --- .../catalyst/analysis/ResolveColumnDefaultInInsert.scala | 2 +- .../sql/catalyst/analysis/ResolveReferencesInUpdate.scala | 4 +++- .../sql/catalyst/analysis/V2WriteAnalysisSuite.scala | 8 ++++---- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala index e9c37094f5fa..9aa791f71f8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.types.StructField * all unary nodes that inherit the output columns from its child. */ case object ResolveColumnDefaultInInsert extends SQLConfHelper with ColumnResolutionHelper { - // TODO: support v2 write commands as well. + // TODO (SPARK-43752): support v2 write commands as well. 
def apply(plan: LogicalPlan): LogicalPlan = plan match { case i: InsertIntoStatement if conf.enableDefaultColumns && i.table.resolved && i.query.containsPattern(UNRESOLVED_ATTRIBUTE) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala index fd701992b9c0..672bc4c68961 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala @@ -63,7 +63,9 @@ case object ResolveReferencesInUpdate extends SQLConfHelper with ColumnResolutio resolved } - val newUpdate = u.copy(assignments = newAssignments) + val newUpdate = u.copy( + assignments = newAssignments, + condition = u.condition.map(resolveExpressionByPlanChildren(_, u))) newUpdate.copyTagsFrom(u) newUpdate } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala index 0b698ae07a28..8a01c1a602b1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala @@ -445,7 +445,7 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( "Cannot write", "'table-name'", "too many data columns", - "Table columns: 'x', 'y'", + "Table columns (excluding columns with static partition values): 'x', 'y'", "Data columns: 'x', 'y', 'z'")) } @@ -525,7 +525,7 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( "Cannot write", "'table-name'", "not enough data columns", - "Table columns: 'x', 'y'", + "Table columns (excluding columns with static partition values): 'x', 'y'", "Data columns: 'y'")) } @@ -539,7 +539,7 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( "Cannot write", "'table-name'", "not enough data columns", - "Table columns: 'x', 'y'", + "Table columns (excluding columns with static partition values): 'x', 'y'", "Data columns: 'y'")) } @@ -574,7 +574,7 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( "Cannot write", "'table-name'", "too many data columns", - "Table columns: 'x', 'y'", + "Table columns (excluding columns with static partition values): 'x', 'y'", "Data columns: 'a', 'b', 'c'")) } From 3544605983eb33a1de4464fdbcc14376104b37b2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 24 May 2023 18:40:25 +0800 Subject: [PATCH 3/6] fix tests --- .../main/resources/error/error-classes.json | 27 +- .../sql/catalyst/analysis/Analyzer.scala | 3 + .../sql/catalyst/analysis/CheckAnalysis.scala | 23 +- .../ResolveColumnDefaultInInsert.scala | 65 ++-- .../analysis/ResolveReferencesInUpdate.scala | 6 +- .../util/ResolveDefaultColumnsUtil.scala | 10 + .../sql/errors/QueryCompilationErrors.scala | 25 +- .../analysis/V2WriteAnalysisSuite.scala | 8 +- .../datasources/DataSourceStrategy.scala | 3 +- .../sql/execution/datasources/rules.scala | 17 +- .../postgreSQL/numeric.sql.out | 8 +- .../results/postgreSQL/numeric.sql.out | 8 +- 
.../sql/connector/DataSourceV2SQLSuite.scala | 33 +- .../command/PlanResolutionSuite.scala | 153 ++++----- .../spark/sql/sources/InsertSuite.scala | 318 ++++++++---------- .../sql/hive/execution/HiveQuerySuite.scala | 8 +- 16 files changed, 329 insertions(+), 386 deletions(-) diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json index 7212750c7a18..455cdee37060 100644 --- a/core/src/main/resources/error/error-classes.json +++ b/core/src/main/resources/error/error-classes.json @@ -785,7 +785,18 @@ }, "INSERT_COLUMN_ARITY_MISMATCH" : { "message" : [ - " requires that the data to be inserted have the same number of columns as the target table: target table has column(s) but the inserted data has column(s), including partition column(s) having constant value(s)." + "Cannot write to '', :", + "Table columns: .", + "Data columns: ." + ], + "sqlState" : "21S01" + }, + "INSERT_PARTITION_COLUMN_ARITY_MISMATCH" : { + "message" : [ + "Cannot write to '', :", + "Table columns: .", + "Partition columns with static values: .", + "Data columns: ." ], "sqlState" : "21S01" }, @@ -3173,20 +3184,6 @@ "Cannot resolve column name \"\" among ()." ] }, - "_LEGACY_ERROR_TEMP_1202" : { - "message" : [ - "Cannot write to '', too many data columns:", - "Table columns (excluding columns with static partition values): .", - "Data columns: ." - ] - }, - "_LEGACY_ERROR_TEMP_1203" : { - "message" : [ - "Cannot write to '', not enough data columns:", - "Table columns (excluding columns with static partition values): .", - "Data columns: ." - ] - }, "_LEGACY_ERROR_TEMP_1204" : { "message" : [ "Cannot write incompatible data to table '':", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 161e6398a5e7..0f739f1c4946 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1761,6 +1761,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor val resolvedExpr = resolveExprInAssignment(c, resolvePlan) match { case u: UnresolvedAttribute if isExplicitDefaultColumn(u) => getDefaultValueExpr(attr).getOrElse(Literal(null, attr.dataType)) + case other if containsExplicitDefaultColumn(other) => + throw QueryCompilationErrors + .defaultReferencesNotAllowedInComplexExpressionsInMergeInsertsOrUpdates() case other => other } checkResolvedMergeExpr(resolvedExpr, resolvePlan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index b67b4ee9912c..0322db6994d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -159,6 +159,21 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB } def checkAnalysis0(plan: LogicalPlan): Unit = { + // The target table is not a child plan of the insert command. We should report errors for table + // not found first, instead of errors in the input query of the insert command, by doing a + // top-down traversal. 
+ plan.foreach { + case InsertIntoStatement(u: UnresolvedRelation, _, _, _, _, _) => + u.tableNotFound(u.multipartIdentifier) + + // TODO (SPARK-27484): handle streaming write commands when we have them. + case write: V2WriteCommand if write.table.isInstanceOf[UnresolvedRelation] => + val tblName = write.table.asInstanceOf[UnresolvedRelation].multipartIdentifier + write.table.tableNotFound(tblName) + + case _ => + } + // We transform up and order the rules so as to catch the first possible failure instead // of the result of cascading resolution failures. plan.foreachUp { @@ -195,14 +210,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB errorClass = "_LEGACY_ERROR_TEMP_2313", messageParameters = Map("name" -> u.name)) - case InsertIntoStatement(u: UnresolvedRelation, _, _, _, _, _) => - u.tableNotFound(u.multipartIdentifier) - - // TODO (SPARK-27484): handle streaming write commands when we have them. - case write: V2WriteCommand if write.table.isInstanceOf[UnresolvedRelation] => - val tblName = write.table.asInstanceOf[UnresolvedRelation].multipartIdentifier - write.table.tableNotFound(tblName) - case command: V2PartitionCommand => command.table match { case r @ ResolvedTable(_, _, table, _) => table match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala index 9aa791f71f8c..1f0f452fbd85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala @@ -21,7 +21,8 @@ import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_ATTRIBUTE -import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.{getDefaultValueExpr, isExplicitDefaultColumn} +import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.{containsExplicitDefaultColumn, getDefaultValueExpr, isExplicitDefaultColumn} +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructField @@ -34,8 +35,10 @@ import org.apache.spark.sql.types.StructField * `ResolveReferences` resolves the query plan bottom up. This means that when we reach here to * resolve [[InsertIntoStatement]], its child plans have already been resolved by * `ResolveReferences`. - * 2. The plan nodes between [[Project]]/[[UnresolvedInlineTable]] and [[InsertIntoStatement]] are + * 2. The plan nodes between [[Project]] and [[InsertIntoStatement]] are * all unary nodes that inherit the output columns from its child. + * 3. The plan nodes between [[UnresolvedInlineTable]] and [[InsertIntoStatement]] are either + * [[Project]], or [[Aggregate]], or [[SubqueryAlias]]. */ case object ResolveColumnDefaultInInsert extends SQLConfHelper with ColumnResolutionHelper { // TODO (SPARK-43752): support v2 write commands as well. 
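For context on the rewritten rule below, a minimal sketch of the user-visible behavior it is meant to enforce, assuming a local SparkSession with the default-column feature left enabled; the table `t` and its schema are invented for illustration. A bare DEFAULT in the VALUES list resolves to the column's default value, while DEFAULT buried inside a larger expression is rejected during analysis, matching the new QueryCompilationErrors paths introduced in the hunk that follows.

import org.apache.spark.sql.{AnalysisException, SparkSession}

object ColumnDefaultInInsertSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    // Hypothetical table: c2 declares an explicit default value.
    spark.sql("CREATE TABLE t (c1 INT, c2 INT DEFAULT 42) USING parquet")
    // A top-level DEFAULT in the VALUES list is replaced by the column default (42).
    spark.sql("INSERT INTO t VALUES (1, DEFAULT)")
    // DEFAULT inside a complex expression is not a plain column reference and fails analysis.
    try spark.sql("INSERT INTO t VALUES (1, DEFAULT + 1)")
    catch { case e: AnalysisException => println(e.getMessage) }
    spark.stop()
  }
}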
@@ -69,49 +72,55 @@ case object ResolveColumnDefaultInInsert extends SQLConfHelper with ColumnResolu private def resolveColumnDefault( plan: LogicalPlan, - expectedQuerySchema: Seq[StructField]): LogicalPlan = { + expectedQuerySchema: Seq[StructField], + acceptProject: Boolean = true, + acceptInlineTable: Boolean = true): LogicalPlan = { plan match { - case _: GlobalLimit | _: LocalLimit | _: Offset | _: SubqueryAlias | _: Sort => - plan.mapChildren(resolveColumnDefault(_, expectedQuerySchema)) + case _: SubqueryAlias => + plan.mapChildren( + resolveColumnDefault(_, expectedQuerySchema, acceptProject, acceptInlineTable)) - case p: Project if p.child.resolved && p.containsPattern(UNRESOLVED_ATTRIBUTE) && + case _: GlobalLimit | _: LocalLimit | _: Offset | _: Sort if acceptProject => + plan.mapChildren( + resolveColumnDefault(_, expectedQuerySchema, acceptInlineTable = false)) + + case p: Project if acceptProject && p.child.resolved && + p.containsPattern(UNRESOLVED_ATTRIBUTE) && p.projectList.length <= expectedQuerySchema.length => - var changed = false val newProjectList = p.projectList.zipWithIndex.map { case (u: UnresolvedAttribute, i) if isExplicitDefaultColumn(u) => - changed = true val field = expectedQuerySchema(i) Alias(getDefaultValueExpr(field).getOrElse(Literal(null, field.dataType)), u.name)() + case (other, _) if containsExplicitDefaultColumn(other) => + throw QueryCompilationErrors + .defaultReferencesNotAllowedInComplexExpressionsInInsertValuesList() case (other, _) => other } - if (changed) { - val newProj = p.copy(projectList = newProjectList) - newProj.copyTagsFrom(p) - newProj - } else { - p - } + val newChild = resolveColumnDefault(p.child, expectedQuerySchema, acceptProject = false) + val newProj = p.copy(projectList = newProjectList, child = newChild) + newProj.copyTagsFrom(p) + newProj - case inlineTable: UnresolvedInlineTable - if inlineTable.containsPattern(UNRESOLVED_ATTRIBUTE) && - inlineTable.rows.forall(exprs => exprs.length <= expectedQuerySchema.length) => - var changed = false + case _: Project | _: Aggregate if acceptInlineTable => + plan.mapChildren(resolveColumnDefault(_, expectedQuerySchema, acceptProject = false)) + + case inlineTable: UnresolvedInlineTable if acceptInlineTable && + inlineTable.containsPattern(UNRESOLVED_ATTRIBUTE) && + inlineTable.rows.forall(exprs => exprs.length <= expectedQuerySchema.length) => val newRows = inlineTable.rows.map { exprs => exprs.zipWithIndex.map { case (u: UnresolvedAttribute, i) if isExplicitDefaultColumn(u) => - changed = true val field = expectedQuerySchema(i) getDefaultValueExpr(field).getOrElse(Literal(null, field.dataType)) + case (other, _) if containsExplicitDefaultColumn(other) => + throw QueryCompilationErrors + .defaultReferencesNotAllowedInComplexExpressionsInInsertValuesList() case (other, _) => other } } - if (changed) { - val newInlineTable = inlineTable.copy(rows = newRows) - newInlineTable.copyTagsFrom(inlineTable) - newInlineTable - } else { - inlineTable - } + val newInlineTable = inlineTable.copy(rows = newRows) + newInlineTable.copyTagsFrom(inlineTable) + newInlineTable case other => other } @@ -123,7 +132,7 @@ case object ResolveColumnDefaultInInsert extends SQLConfHelper with ColumnResolu * @param str the field name to normalize * @return the normalized result */ - def normalizeFieldName(str: String): String = { + private def normalizeFieldName(str: String): String = { if (SQLConf.get.caseSensitiveAnalysis) { str } else { diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala index 672bc4c68961..d486cc285530 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal} import org.apache.spark.sql.catalyst.plans.logical.{Assignment, UpdateTable} -import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.{getDefaultValueExpr, isExplicitDefaultColumn} +import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.{containsExplicitDefaultColumn, getDefaultValueExpr, isExplicitDefaultColumn} +import org.apache.spark.sql.errors.QueryCompilationErrors /** * A virtual rule to resolve [[UnresolvedAttribute]] in [[UpdateTable]]. It's only used by the real @@ -52,6 +53,9 @@ case object ResolveReferencesInUpdate extends SQLConfHelper with ColumnResolutio resolved match { case u: UnresolvedAttribute if isExplicitDefaultColumn(u) => getDefaultValueExpr(attr).getOrElse(Literal(null, attr.dataType)) + case other if containsExplicitDefaultColumn(other) => + throw QueryCompilationErrors + .defaultReferencesNotAllowedInComplexExpressionsInUpdateSetClause() case other => other } case _ => resolved diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala index c6375d0a947e..fbf37588173e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala @@ -162,6 +162,16 @@ object ResolveDefaultColumns { col.name.equalsIgnoreCase(CURRENT_DEFAULT_COLUMN_NAME) } + /** + * Returns true if the given expression contains an explicit DEFAULT column reference. + */ + def containsExplicitDefaultColumn(expr: Expression): Boolean = { + expr.exists { + case u: UnresolvedAttribute => isExplicitDefaultColumn(u) + case _ => false + } + } + /** * Generates the expression of the default value for the given field. If there is no * user-specified default value for this field, returns None. 
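To make the helpers above concrete, a small hedged sketch built on the Catalyst classes this patch already touches; the expression shapes are invented for illustration. isExplicitDefaultColumn only matches a bare DEFAULT reference, while containsExplicitDefaultColumn also fires when DEFAULT is nested inside a larger expression, which is the case the new error paths reject.

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{GreaterThan, Literal}
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.{containsExplicitDefaultColumn, isExplicitDefaultColumn}

object ExplicitDefaultColumnSketch {
  def main(args: Array[String]): Unit = {
    // A bare DEFAULT reference, as the parser produces for `SET col = DEFAULT`.
    val bare = UnresolvedAttribute("DEFAULT")
    // DEFAULT nested inside a larger expression, e.g. `SET col = (DEFAULT > 0)`.
    val nested = GreaterThan(UnresolvedAttribute("DEFAULT"), Literal(0))

    assert(isExplicitDefaultColumn(bare))                            // eligible for substitution
    assert(!isExplicitDefaultColumn(UnresolvedAttribute("other")))   // ordinary column name
    assert(containsExplicitDefaultColumn(nested))                    // rejected with a dedicated error
  }
}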
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index ad0a17ef4f46..87b7e110fad8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, CreateMap, CreateStruct, Expression, GroupingID, NamedExpression, SpecifiedWindowFrame, WindowFrame, WindowFunction, WindowSpecDefinition} import org.apache.spark.sql.catalyst.expressions.aggregate.AnyValue import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.catalyst.plans.logical.{Assignment, InsertIntoStatement, Join, LogicalPlan, SerdeInfo, Window} +import org.apache.spark.sql.catalyst.plans.logical.{Assignment, Join, LogicalPlan, SerdeInfo, Window} import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode} import org.apache.spark.sql.catalyst.util.{quoteIdentifier, FailFastMode, ParseMode, PermissiveMode} import org.apache.spark.sql.connector.catalog._ @@ -1731,17 +1731,6 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { "normalizedPartCols" -> normalizedPartCols.mkString(", "))) } - def mismatchedInsertedDataColumnNumberError( - tableName: String, insert: InsertIntoStatement, staticPartCols: Set[String]): Throwable = { - new AnalysisException( - errorClass = "INSERT_COLUMN_ARITY_MISMATCH", - messageParameters = Map( - "tableName" -> tableName, - "targetColumns" -> insert.table.output.size.toString, - "insertedColumns" -> (insert.query.output.length + staticPartCols.size).toString, - "staticPartCols" -> staticPartCols.size.toString)) - } - def requestedPartitionsMismatchTablePartitionsError( tableName: String, normalizedPartSpec: Map[String, Option[String]], @@ -1751,7 +1740,7 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { messageParameters = Map( "tableName" -> tableName, "normalizedPartSpec" -> normalizedPartSpec.keys.mkString(","), - "partColNames" -> partColNames.mkString(","))) + "partColNames" -> partColNames.map(_.name).mkString(","))) } def ddlWithoutHiveSupportEnabledError(detail: String): Throwable = { @@ -2074,11 +2063,14 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { } def cannotWriteTooManyColumnsToTableError( - tableName: String, expected: Seq[Attribute], query: LogicalPlan): Throwable = { + tableName: String, + expected: Seq[Attribute], + query: LogicalPlan): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1202", + errorClass = "INSERT_COLUMN_ARITY_MISMATCH", messageParameters = Map( "tableName" -> tableName, + "reason" -> "too many data columns", "tableColumns" -> expected.map(c => s"'${c.name}'").mkString(", "), "dataColumns" -> query.output.map(c => s"'${c.name}'").mkString(", "))) } @@ -2086,9 +2078,10 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { def cannotWriteNotEnoughColumnsToTableError( tableName: String, expected: Seq[Attribute], query: LogicalPlan): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1203", + errorClass = "INSERT_COLUMN_ARITY_MISMATCH", messageParameters = Map( "tableName" -> tableName, + "reason" -> "not enough data columns", "tableColumns" -> expected.map(c => 
s"'${c.name}'").mkString(", "), "dataColumns" -> query.output.map(c => s"'${c.name}'").mkString(", "))) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala index 8a01c1a602b1..0b698ae07a28 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala @@ -445,7 +445,7 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( "Cannot write", "'table-name'", "too many data columns", - "Table columns (excluding columns with static partition values): 'x', 'y'", + "Table columns: 'x', 'y'", "Data columns: 'x', 'y', 'z'")) } @@ -525,7 +525,7 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( "Cannot write", "'table-name'", "not enough data columns", - "Table columns (excluding columns with static partition values): 'x', 'y'", + "Table columns: 'x', 'y'", "Data columns: 'y'")) } @@ -539,7 +539,7 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( "Cannot write", "'table-name'", "not enough data columns", - "Table columns (excluding columns with static partition values): 'x', 'y'", + "Table columns: 'x', 'y'", "Data columns: 'y'")) } @@ -574,7 +574,7 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( "Cannot write", "'table-name'", "too many data columns", - "Table columns (excluding columns with static partition values): 'x', 'y'", + "Table columns: 'x', 'y'", "Data columns: 'a', 'b', 'c'")) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 69c7624605b4..dd79e9b26d47 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -164,7 +164,8 @@ object DataSourceAnalysis extends Rule[LogicalPlan] { InsertIntoDataSourceDirCommand(storage, provider.get, query, overwrite) case i @ InsertIntoStatement( - l @ LogicalRelation(t: HadoopFsRelation, _, table, _), parts, _, query, overwrite, _) => + l @ LogicalRelation(t: HadoopFsRelation, _, table, _), parts, _, query, overwrite, _) + if query.resolved => // If the InsertIntoTable command is for a partitioned HadoopFsRelation and // the user has specified static partitions, we add a Project operator on top of the query // to include those constant column values in the query result. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index cedb58d8cbfd..430399aac8dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -394,8 +394,21 @@ object PreprocessTableInsertion extends ResolveInsertionBase { } else { insert.query } - val newQuery = TableOutputResolver.resolveOutputColumns( - tblName, expectedColumns, query, byName = isByName, conf, supportColDefaultValue = true) + val newQuery = try { + TableOutputResolver.resolveOutputColumns( + tblName, expectedColumns, query, byName = isByName, conf, supportColDefaultValue = true) + } catch { + case e: AnalysisException if staticPartCols.nonEmpty && + e.getErrorClass == "INSERT_COLUMN_ARITY_MISMATCH" => + val newException = e.copy( + errorClass = Some("INSERT_PARTITION_COLUMN_ARITY_MISMATCH"), + messageParameters = e.messageParameters ++ Map( + "tableColumns" -> insert.table.output.map(c => s"'${c.name}'").mkString(", "), + "staticPartCols" -> staticPartCols.toSeq.sorted.map(c => s"'$c'").mkString(", ") + )) + newException.setStackTrace(e.getStackTrace) + throw newException + } if (normalizedPartSpec.nonEmpty) { if (normalizedPartSpec.size != partColNames.length) { throw QueryCompilationErrors.requestedPartitionsMismatchTablePartitionsError( diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/numeric.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/numeric.sql.out index d32e2abe1565..a6408f945799 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/numeric.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/postgreSQL/numeric.sql.out @@ -3844,10 +3844,10 @@ org.apache.spark.sql.AnalysisException "errorClass" : "INSERT_COLUMN_ARITY_MISMATCH", "sqlState" : "21S01", "messageParameters" : { - "insertedColumns" : "5", - "staticPartCols" : "0", - "tableName" : "`spark_catalog`.`default`.`num_result`", - "targetColumns" : "3" + "dataColumns" : "'id', 'id', 'val', 'val', '(val * val)'", + "reason" : "too many data columns", + "tableColumns" : "'id1', 'id2', 'result'", + "tableName" : "`spark_catalog`.`default`.`num_result`" } } diff --git a/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out b/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out index db81160bf033..5840e1164fa1 100644 --- a/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/postgreSQL/numeric.sql.out @@ -3835,10 +3835,10 @@ org.apache.spark.sql.AnalysisException "errorClass" : "INSERT_COLUMN_ARITY_MISMATCH", "sqlState" : "21S01", "messageParameters" : { - "insertedColumns" : "5", - "staticPartCols" : "0", - "tableName" : "`spark_catalog`.`default`.`num_result`", - "targetColumns" : "3" + "dataColumns" : "'id', 'id', 'val', 'val', '(val * val)'", + "reason" : "too many data columns", + "tableColumns" : "'id1', 'id2', 'result'", + "tableName" : "`spark_catalog`.`default`.`num_result`" } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index 968e91e31bdd..7893675a790c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -1117,15 +1117,8 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"INSERT INTO $t1(data, data) VALUES(5)") }, - errorClass = "_LEGACY_ERROR_TEMP_2305", - parameters = Map( - "numCols" -> "3", - "rowSize" -> "2", - "ri" -> "0"), - context = ExpectedContext( - fragment = s"INSERT INTO $t1(data, data)", - start = 0, - stop = 26)) + errorClass = "COLUMN_ALREADY_EXISTS", + parameters = Map("columnName" -> "`data`")) } } @@ -1151,15 +1144,8 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"INSERT OVERWRITE $t1(data, data) VALUES(5)") }, - errorClass = "_LEGACY_ERROR_TEMP_2305", - parameters = Map( - "numCols" -> "3", - "rowSize" -> "2", - "ri" -> "0"), - context = ExpectedContext( - fragment = s"INSERT OVERWRITE $t1(data, data)", - start = 0, - stop = 31)) + errorClass = "COLUMN_ALREADY_EXISTS", + parameters = Map("columnName" -> "`data`")) } } @@ -1186,15 +1172,8 @@ class DataSourceV2SQLSuiteV1Filter exception = intercept[AnalysisException] { sql(s"INSERT OVERWRITE $t1(data, data) VALUES(5)") }, - errorClass = "_LEGACY_ERROR_TEMP_2305", - parameters = Map( - "numCols" -> "4", - "rowSize" -> "3", - "ri" -> "0"), - context = ExpectedContext( - fragment = s"INSERT OVERWRITE $t1(data, data)", - start = 0, - stop = 31)) + errorClass = "COLUMN_ALREADY_EXISTS", + parameters = Map("columnName" -> "`data`")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 013e7227aefc..5646c837ce5f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.{AliasIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AnalysisTest, Analyzer, EmptyFunctionRegistry, NoSuchTableException, ResolvedFieldName, ResolvedIdentifier, ResolvedTable, ResolveSessionCatalog, UnresolvedAttribute, UnresolvedInlineTable, UnresolvedRelation, UnresolvedSubqueryColumnAliases, UnresolvedTable} import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType, InMemoryCatalog, SessionCatalog} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, EqualTo, EvalMode, Expression, InSubquery, IntegerLiteral, ListQuery, Literal, StringLiteral} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, EqualTo, Expression, InSubquery, IntegerLiteral, ListQuery, Literal, StringLiteral} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} import org.apache.spark.sql.catalyst.plans.logical.{AlterColumn, AnalysisOnlyCommand, AppendData, Assignment, CreateTable, CreateTableAsSelect, DeleteAction, DeleteFromTable, DescribeRelation, DropTable, InsertAction, InsertIntoStatement, LocalRelation, LogicalPlan, MergeIntoTable, OneRowRelation, OverwriteByExpression, OverwritePartitionsDynamic, Project, SetTableLocation, SetTableProperties, ShowTableProperties, SubqueryAlias, UnsetTableProperties, UpdateAction, UpdateTable} @@ -151,6 +151,7 @@ class PlanResolutionSuite extends AnalysisTest { 
case "defaultvalues" => defaultValues case "defaultvalues2" => defaultValues2 case "tablewithcolumnnameddefault" => tableWithColumnNamedDefault + case "v2TableWithAcceptAnySchemaCapability" => tableWithAcceptAnySchemaCapability case name => throw new NoSuchTableException(Seq(name)) } }) @@ -167,7 +168,6 @@ class PlanResolutionSuite extends AnalysisTest { case "v1HiveTable" => createV1TableMock(ident, provider = "hive") case "v2Table" => table case "v2Table1" => table1 - case "v2TableWithAcceptAnySchemaCapability" => tableWithAcceptAnySchemaCapability case "view" => createV1TableMock(ident, tableType = CatalogTableType.VIEW) case name => throw new NoSuchTableException(Seq(name)) } @@ -1023,12 +1023,12 @@ class PlanResolutionSuite extends AnalysisTest { val sql5 = s"UPDATE $tblName SET name=DEFAULT, age=DEFAULT" // Note: 'i' and 's' are the names of the columns in 'tblName'. val sql6 = s"UPDATE $tblName SET i=DEFAULT, s=DEFAULT" - val sql7 = s"UPDATE defaultvalues SET i=DEFAULT, s=DEFAULT" - val sql8 = s"UPDATE $tblName SET name='Robert', age=32 WHERE p=DEFAULT" - val sql9 = s"UPDATE defaultvalues2 SET i=DEFAULT" - // Note: 'i' is the correct column name, but since the table has ACCEPT_ANY_SCHEMA capability, - // DEFAULT column resolution should skip this table. - val sql10 = s"UPDATE v2TableWithAcceptAnySchemaCapability SET i=DEFAULT" + val sql7 = s"UPDATE testcat.defaultvalues SET i=DEFAULT, s=DEFAULT" + // UPDATE condition won't resolve column "DEFAULT" + val sql8 = s"UPDATE testcat.defaultvalues SET i=DEFAULT, s=DEFAULT WHERE i=DEFAULT" + val sql9 = s"UPDATE testcat.defaultvalues2 SET i=DEFAULT" + // Table with ACCEPT_ANY_SCHEMA can also resolve the column DEFAULT. + val sql10 = s"UPDATE testcat.v2TableWithAcceptAnySchemaCapability SET i=DEFAULT" val parsed1 = parseAndResolve(sql1) val parsed2 = parseAndResolve(sql2) @@ -1036,8 +1036,8 @@ class PlanResolutionSuite extends AnalysisTest { val parsed4 = parseAndResolve(sql4) val parsed5 = parseAndResolve(sql5) val parsed6 = parseAndResolve(sql6) - val parsed7 = parseAndResolve(sql7, true) - val parsed9 = parseAndResolve(sql9, true) + val parsed7 = parseAndResolve(sql7) + val parsed9 = parseAndResolve(sql9) val parsed10 = parseAndResolve(sql10) parsed1 match { @@ -1116,12 +1116,9 @@ class PlanResolutionSuite extends AnalysisTest { // Note that when resolving DEFAULT column references, the analyzer will insert literal // NULL values if the corresponding table does not define an explicit default value for // that column. This is intended. 
- Assignment(i: AttributeReference, - cast1 @ Cast(Literal(null, _), IntegerType, _, EvalMode.ANSI)), - Assignment(s: AttributeReference, - cast2 @ Cast(Literal(null, _), StringType, _, EvalMode.ANSI))), - None) if cast1.getTagValue(Cast.BY_TABLE_INSERTION).isDefined && - cast2.getTagValue(Cast.BY_TABLE_INSERTION).isDefined => + Assignment(i: AttributeReference, Literal(null, IntegerType)), + Assignment(s: AttributeReference, Literal(null, StringType))), + None) => assert(i.name == "i") assert(s.name == "s") @@ -1143,38 +1140,34 @@ class PlanResolutionSuite extends AnalysisTest { checkError( exception = intercept[AnalysisException] { - parseAndResolve(sql8) + parseAndResolve(sql8, checkAnalysis = true) }, - errorClass = "_LEGACY_ERROR_TEMP_1341", - parameters = Map.empty) + errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + parameters = Map("objectName" -> "`DEFAULT`", "proposal" -> "`i`, `s`"), + context = ExpectedContext( + fragment = "DEFAULT", + start = 62, + stop = 68)) parsed9 match { case UpdateTable( - _, - Seq(Assignment(i: AttributeReference, - cast @ Cast(Literal(null, _), StringType, _, EvalMode.ANSI))), - None) if cast.getTagValue(Cast.BY_TABLE_INSERTION).isDefined => + _, + Seq(Assignment(i: AttributeReference, Literal(null, StringType))), + None) => assert(i.name == "i") case _ => fail("Expect UpdateTable, but got:\n" + parsed9.treeString) } parsed10 match { - case u: UpdateTable => - assert(u.assignments.size == 1) - u.assignments(0).key match { - case i: AttributeReference => - assert(i.name == "i") - } - u.assignments(0).value match { - case d: UnresolvedAttribute => - assert(d.name == "DEFAULT") - } + case UpdateTable( + _, + Seq(Assignment(i: AttributeReference, Literal(null, IntegerType))), + None) => + assert(i.name == "i") - case _ => - fail("Expect UpdateTable, but got:\n" + parsed10.treeString) + case _ => fail("Expect UpdateTable, but got:\n" + parsed10.treeString) } - } val sql1 = "UPDATE non_existing SET id=1" @@ -1766,22 +1759,16 @@ class PlanResolutionSuite extends AnalysisTest { second match { case UpdateAction(Some(EqualTo(_: AttributeReference, StringLiteral("update"))), Seq( - Assignment(_: AttributeReference, - cast @ Cast(Literal(null, _), StringType, _, EvalMode.ANSI)), - Assignment(_: AttributeReference, _: AttributeReference))) - if cast.getTagValue(Cast.BY_TABLE_INSERTION).isDefined => + Assignment(_: AttributeReference, Literal(null, StringType)), + Assignment(_: AttributeReference, _: AttributeReference))) => case other => fail("unexpected second matched action " + other) } assert(m.notMatchedActions.length == 1) val negative = m.notMatchedActions(0) negative match { case InsertAction(Some(EqualTo(_: AttributeReference, StringLiteral("insert"))), - Seq(Assignment(i: AttributeReference, - cast1 @ Cast(Literal(null, _), IntegerType, _, EvalMode.ANSI)), - Assignment(s: AttributeReference, - cast2 @ Cast(Literal(null, _), StringType, _, EvalMode.ANSI)))) - if cast1.getTagValue(Cast.BY_TABLE_INSERTION).isDefined && - cast2.getTagValue(Cast.BY_TABLE_INSERTION).isDefined => + Seq(Assignment(i: AttributeReference, Literal(null, IntegerType)), + Assignment(s: AttributeReference, Literal(null, StringType)))) => assert(i.name == "i") assert(s.name == "s") case other => fail("unexpected not matched action " + other) @@ -1793,9 +1780,7 @@ class PlanResolutionSuite extends AnalysisTest { } m.notMatchedBySourceActions(1) match { case UpdateAction(Some(EqualTo(_: AttributeReference, StringLiteral("update"))), - Seq(Assignment(_: AttributeReference, - cast @ 
Cast(Literal(null, _), StringType, _, EvalMode.ANSI)))) - if cast.getTagValue(Cast.BY_TABLE_INSERTION).isDefined => + Seq(Assignment(_: AttributeReference, Literal(null, StringType)))) => case other => fail("unexpected second not matched by source action " + other) } @@ -1805,8 +1790,8 @@ class PlanResolutionSuite extends AnalysisTest { } // DEFAULT column reference in the merge condition: - // This MERGE INTO command includes an ON clause with a DEFAULT column reference. This is - // invalid and returns an error message. + // This MERGE INTO command includes an ON clause with a DEFAULT column reference. This + // DEFAULT column won't be resolved. val mergeWithDefaultReferenceInMergeCondition = s"""MERGE INTO testcat.tab AS target |USING testcat.tab1 AS source @@ -1821,14 +1806,19 @@ class PlanResolutionSuite extends AnalysisTest { | THEN UPDATE SET target.s = DEFAULT""".stripMargin checkError( exception = intercept[AnalysisException] { - parseAndResolve(mergeWithDefaultReferenceInMergeCondition) + parseAndResolve(mergeWithDefaultReferenceInMergeCondition, checkAnalysis = true) }, - errorClass = "_LEGACY_ERROR_TEMP_1342", - parameters = Map.empty) + errorClass = "UNRESOLVED_COLUMN.WITH_SUGGESTION", + parameters = Map("objectName" -> "`DEFAULT`", + "proposal" -> "`target`.`i`, `source`.`i`, `target`.`s`, `source`.`s`"), + context = ExpectedContext( + fragment = "DEFAULT", + start = 76, + stop = 82)) // DEFAULT column reference within a complex expression: // This MERGE INTO command includes a WHEN MATCHED clause with a DEFAULT column reference as - // of a complex expression (DEFAULT + 1). This is invalid and returns an error message. + // of a complex expression (DEFAULT + 1). This is invalid and column won't be resolved. val mergeWithDefaultReferenceAsPartOfComplexExpression = s"""MERGE INTO testcat.tab AS target |USING testcat.tab1 AS source @@ -1890,7 +1880,7 @@ class PlanResolutionSuite extends AnalysisTest { // values. This test case covers that behavior. 
val mergeDefaultWithExplicitDefaultColumns = s""" - |MERGE INTO defaultvalues AS target + |MERGE INTO testcat.defaultvalues AS target |USING testcat.tab1 AS source |ON target.i = source.i |WHEN MATCHED AND (target.s = 31) THEN DELETE @@ -1902,7 +1892,7 @@ class PlanResolutionSuite extends AnalysisTest { |WHEN NOT MATCHED BY SOURCE AND (target.s = 31) | THEN UPDATE SET target.s = DEFAULT """.stripMargin - parseAndResolve(mergeDefaultWithExplicitDefaultColumns, true) match { + parseAndResolve(mergeDefaultWithExplicitDefaultColumns) match { case m: MergeIntoTable => val cond = m.mergeCondition cond match { @@ -2218,54 +2208,29 @@ class PlanResolutionSuite extends AnalysisTest { } } - test("MERGE INTO TABLE - skip resolution on v2 tables that accept any schema") { + test("MERGE INTO TABLE - skip filling missing cols on v2 tables that accept any schema") { val sql = s""" - |MERGE INTO v2TableWithAcceptAnySchemaCapability AS target + |MERGE INTO testcat.v2TableWithAcceptAnySchemaCapability AS target |USING v2Table AS source |ON target.i = source.i - |WHEN MATCHED AND (target.s='delete')THEN DELETE - |WHEN MATCHED AND (target.s='update') THEN UPDATE SET target.s = source.s - |WHEN NOT MATCHED AND (target.s=DEFAULT) - | THEN INSERT (target.i, target.s) values (source.i, source.s) - |WHEN NOT MATCHED BY SOURCE AND (target.s='delete') THEN DELETE - |WHEN NOT MATCHED BY SOURCE AND (target.s='update') THEN UPDATE SET target.s = target.i + |WHEN MATCHED THEN DELETE + |WHEN NOT MATCHED THEN INSERT (target.i) values (DEFAULT) """.stripMargin parseAndResolve(sql) match { case MergeIntoTable( SubqueryAlias(AliasIdentifier("target", Seq()), AsDataSourceV2Relation(_)), SubqueryAlias(AliasIdentifier("source", Seq()), AsDataSourceV2Relation(_)), - EqualTo(l: UnresolvedAttribute, r: UnresolvedAttribute), - Seq( - DeleteAction(Some(EqualTo(dl: UnresolvedAttribute, StringLiteral("delete")))), - UpdateAction( - Some(EqualTo(ul: UnresolvedAttribute, StringLiteral("update"))), - firstUpdateAssigns)), - Seq( - InsertAction( - Some(EqualTo(il: UnresolvedAttribute, UnresolvedAttribute(Seq("DEFAULT")))), - insertAssigns)), - Seq( - DeleteAction(Some(EqualTo(ndl: UnresolvedAttribute, StringLiteral("delete")))), - UpdateAction( - Some(EqualTo(nul: UnresolvedAttribute, StringLiteral("update"))), - secondUpdateAssigns))) => - assert(l.name == "target.i" && r.name == "source.i") - assert(dl.name == "target.s") - assert(ul.name == "target.s") - assert(il.name == "target.s") - assert(ndl.name == "target.s") - assert(nul.name == "target.s") - assert(firstUpdateAssigns.size == 1) - assert(firstUpdateAssigns.head.key.asInstanceOf[UnresolvedAttribute].name == "target.s") - assert(firstUpdateAssigns.head.value.asInstanceOf[UnresolvedAttribute].name == "source.s") - assert(insertAssigns.size == 2) + _, + Seq(DeleteAction(None)), + Seq(InsertAction(None, insertAssigns)), + Nil) => + // There is only one assignment, the missing col is not filled with default value + assert(insertAssigns.size == 1) + // Special case: Spark does not resolve any columns in MERGE if table accepts any schema. 
assert(insertAssigns.head.key.asInstanceOf[UnresolvedAttribute].name == "target.i") - assert(insertAssigns.head.value.asInstanceOf[UnresolvedAttribute].name == "source.i") - assert(secondUpdateAssigns.size == 1) - assert(secondUpdateAssigns.head.key.asInstanceOf[UnresolvedAttribute].name == "target.s") - assert(secondUpdateAssigns.head.value.asInstanceOf[UnresolvedAttribute].name == "target.i") + assert(insertAssigns.head.value.asInstanceOf[UnresolvedAttribute].name == "DEFAULT") case l => fail("Expected unresolved MergeIntoTable, but got:\n" + l.treeString) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 3c469b989184..be2399a46ccf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -142,17 +142,6 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { ) } - test("SELECT clause generating a different number of columns is not allowed.") { - val message = intercept[AnalysisException] { - sql( - s""" - |INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt - """.stripMargin) - }.getMessage - assert(message.contains("target table has 2 column(s) but the inserted data has 1 column(s)") - ) - } - test("INSERT OVERWRITE a JSONRelation multiple times") { sql( s""" @@ -642,16 +631,8 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { msg = intercept[AnalysisException] { sql("insert into t select 1, 2.0D, 3") }.getMessage - assert(msg.contains("`t` requires that the data to be inserted have the same number of " + - "columns as the target table: target table has 2 column(s)" + - " but the inserted data has 3 column(s)")) - - msg = intercept[AnalysisException] { - sql("insert into t select 1") - }.getMessage - assert(msg.contains("`t` requires that the data to be inserted have the same number of " + - "columns as the target table: target table has 2 column(s)" + - " but the inserted data has 1 column(s)")) + assert(msg.contains( + "Cannot write to '`spark_catalog`.`default`.`t`', too many data columns")) // Insert into table successfully. 
sql("insert into t select 1, 2.0D") @@ -863,42 +844,39 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { } test("Allow user to insert specified columns into insertable view") { - withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "true") { - sql("INSERT OVERWRITE TABLE jsonTable SELECT a, DEFAULT FROM jt") - checkAnswer( - sql("SELECT a, b FROM jsonTable"), - (1 to 10).map(i => Row(i, null)) - ) + sql("INSERT OVERWRITE TABLE jsonTable SELECT a, DEFAULT FROM jt") + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i, null)) + ) - sql("INSERT OVERWRITE TABLE jsonTable(a) SELECT a FROM jt") - checkAnswer( - sql("SELECT a, b FROM jsonTable"), - (1 to 10).map(i => Row(i, null)) - ) + sql("INSERT OVERWRITE TABLE jsonTable(a) SELECT a FROM jt") + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i, null)) + ) - sql("INSERT OVERWRITE TABLE jsonTable(b) SELECT b FROM jt") - checkAnswer( - sql("SELECT a, b FROM jsonTable"), - (1 to 10).map(i => Row(null, s"str$i")) - ) - } + sql("INSERT OVERWRITE TABLE jsonTable(b) SELECT b FROM jt") + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(null, s"str$i")) + ) - val message = intercept[AnalysisException] { - sql("INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt") - }.getMessage - assert(message.contains("target table has 2 column(s) but the inserted data has 1 column(s)")) + withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "false") { + val message = intercept[AnalysisException] { + sql("INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt") + }.getMessage + assert(message.contains("Cannot write to 'unknown', not enough data columns")) + } } test("SPARK-38336 INSERT INTO statements with tables with default columns: positive tests") { - // When the USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES configuration is enabled, and no - // explicit DEFAULT value is available when the INSERT INTO statement provides fewer - // values than expected, NULL values are appended in their place. - withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "true") { - withTable("t") { - sql("create table t(i boolean, s bigint) using parquet") - sql("insert into t(i) values(true)") - checkAnswer(spark.table("t"), Row(true, null)) - } + // When the INSERT INTO statement provides fewer values than expected, NULL values are appended + // in their place. + withTable("t") { + sql("create table t(i boolean, s bigint) using parquet") + sql("insert into t(i) values(true)") + checkAnswer(spark.table("t"), Row(true, null)) } // The default value for the DEFAULT keyword is the NULL literal. withTable("t") { @@ -924,6 +902,11 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { sql("insert into t(i) values(1)") checkAnswer(sql("select s + x from t where i = 1"), Seq(85L).map(i => Row(i))) } + withTable("t") { + sql("create table t(i int, s bigint default 42, x bigint) using parquet") + sql("insert into t values(1)") + checkAnswer(spark.table("t"), Row(1, 42L, null)) + } // The table has a partitioning column and a default value is injected. withTable("t") { sql("create table t(i boolean, s bigint, q int default 42) using parquet partitioned by (i)") @@ -998,45 +981,43 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { // There are three column types exercising various combinations of implicit and explicit // default column value references in the 'insert into' statements. 
Note these tests depend on // enabling the configuration to use NULLs for missing DEFAULT column values. - withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "true") { - for (useDataFrames <- Seq(false, true)) { - withTable("t1", "t2") { - sql("create table t1(j int, s bigint default 42, x bigint default 43) using parquet") - if (useDataFrames) { - Seq((1, 42, 43)).toDF.write.insertInto("t1") - Seq((2, 42, 43)).toDF.write.insertInto("t1") - Seq((3, 42, 43)).toDF.write.insertInto("t1") - Seq((4, 44, 43)).toDF.write.insertInto("t1") - Seq((5, 44, 43)).toDF.write.insertInto("t1") - } else { - sql("insert into t1(j) values(1)") - sql("insert into t1(j, s) values(2, default)") - sql("insert into t1(j, s, x) values(3, default, default)") - sql("insert into t1(j, s) values(4, 44)") - sql("insert into t1(j, s, x) values(5, 44, 45)") - } - sql("create table t2(j int, s bigint default 42, x bigint default 43) using parquet") - if (useDataFrames) { - spark.table("t1").where("j = 1").write.insertInto("t2") - spark.table("t1").where("j = 2").write.insertInto("t2") - spark.table("t1").where("j = 3").write.insertInto("t2") - spark.table("t1").where("j = 4").write.insertInto("t2") - spark.table("t1").where("j = 5").write.insertInto("t2") - } else { - sql("insert into t2(j) select j from t1 where j = 1") - sql("insert into t2(j, s) select j, default from t1 where j = 2") - sql("insert into t2(j, s, x) select j, default, default from t1 where j = 3") - sql("insert into t2(j, s) select j, s from t1 where j = 4") - sql("insert into t2(j, s, x) select j, s, default from t1 where j = 5") - } - checkAnswer( - spark.table("t2"), - Row(1, 42L, 43L) :: - Row(2, 42L, 43L) :: - Row(3, 42L, 43L) :: - Row(4, 44L, 43L) :: - Row(5, 44L, 43L) :: Nil) + for (useDataFrames <- Seq(false, true)) { + withTable("t1", "t2") { + sql("create table t1(j int, s bigint default 42, x bigint default 43) using parquet") + if (useDataFrames) { + Seq((1, 42, 43)).toDF.write.insertInto("t1") + Seq((2, 42, 43)).toDF.write.insertInto("t1") + Seq((3, 42, 43)).toDF.write.insertInto("t1") + Seq((4, 44, 43)).toDF.write.insertInto("t1") + Seq((5, 44, 43)).toDF.write.insertInto("t1") + } else { + sql("insert into t1(j) values(1)") + sql("insert into t1(j, s) values(2, default)") + sql("insert into t1(j, s, x) values(3, default, default)") + sql("insert into t1(j, s) values(4, 44)") + sql("insert into t1(j, s, x) values(5, 44, 45)") + } + sql("create table t2(j int, s bigint default 42, x bigint default 43) using parquet") + if (useDataFrames) { + spark.table("t1").where("j = 1").write.insertInto("t2") + spark.table("t1").where("j = 2").write.insertInto("t2") + spark.table("t1").where("j = 3").write.insertInto("t2") + spark.table("t1").where("j = 4").write.insertInto("t2") + spark.table("t1").where("j = 5").write.insertInto("t2") + } else { + sql("insert into t2(j) select j from t1 where j = 1") + sql("insert into t2(j, s) select j, default from t1 where j = 2") + sql("insert into t2(j, s, x) select j, default, default from t1 where j = 3") + sql("insert into t2(j, s) select j, s from t1 where j = 4") + sql("insert into t2(j, s, x) select j, s, default from t1 where j = 5") } + checkAnswer( + spark.table("t2"), + Row(1, 42L, 43L) :: + Row(2, 42L, 43L) :: + Row(3, 42L, 43L) :: + Row(4, 44L, 43L) :: + Row(5, 44L, 43L) :: Nil) } } } @@ -1113,7 +1094,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { sql("insert into t select t1.id, t2.id, t1.val, t2.val, t1.val * t2.val " + "from num_data t1, num_data t2") 
}.getMessage.contains( - "requires that the data to be inserted have the same number of columns as the target")) + "Cannot write to '`spark_catalog`.`default`.`t`', too many data columns")) } // The default value is disabled per configuration. withTable("t") { @@ -1123,13 +1104,6 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { }.getMessage.contains("Support for DEFAULT column values is not allowed")) } } - // There is one trailing default value referenced implicitly by the INSERT INTO statement. - withTable("t") { - sql("create table t(i int, s bigint default 42, x bigint) using parquet") - assert(intercept[AnalysisException] { - sql("insert into t values(1)") - }.getMessage.contains("target table has 3 column(s) but the inserted data has 1 column(s)")) - } // The table has a partitioning column with a default value; this is not allowed. withTable("t") { sql("create table t(i boolean default true, s bigint, q int default 42) " + @@ -1152,7 +1126,8 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { sql("create table t(i boolean, s bigint) using parquet") assert(intercept[AnalysisException] { sql("insert into t values(true)") - }.getMessage.contains("target table has 2 column(s) but the inserted data has 1 column(s)")) + }.getMessage.contains( + "Cannot write to '`spark_catalog`.`default`.`t`', not enough data columns")) } } } @@ -1185,48 +1160,43 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { Row(4, 43, false), Row(4, 42, false))) } - // When the USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES configuration is enabled, and no - // explicit DEFAULT value is available when the INSERT INTO statement provides fewer + // If no explicit DEFAULT value is available when the INSERT INTO statement provides fewer // values than expected, NULL values are appended in their place. 
- withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "true") { - withTable("t") { - sql("create table t(i boolean, s bigint) using parquet") - sql("insert into t (i) values (true)") - checkAnswer(spark.table("t"), Row(true, null)) - } - withTable("t") { - sql("create table t(i boolean default true, s bigint) using parquet") - sql("insert into t (i) values (default)") - checkAnswer(spark.table("t"), Row(true, null)) - } - withTable("t") { - sql("create table t(i boolean, s bigint default 42) using parquet") - sql("insert into t (s) values (default)") - checkAnswer(spark.table("t"), Row(null, 42L)) - } - withTable("t") { - sql("create table t(i boolean, s bigint, q int) using parquet partitioned by (i)") - sql("insert into t partition(i='true') (s) values(5)") - sql("insert into t partition(i='false') (q) select 43") - sql("insert into t partition(i='false') (q) select default") - checkAnswer(spark.table("t"), - Seq(Row(5, null, true), - Row(null, 43, false), - Row(null, null, false))) - } + withTable("t") { + sql("create table t(i boolean, s bigint) using parquet") + sql("insert into t (i) values (true)") + checkAnswer(spark.table("t"), Row(true, null)) + } + withTable("t") { + sql("create table t(i boolean default true, s bigint) using parquet") + sql("insert into t (i) values (default)") + checkAnswer(spark.table("t"), Row(true, null)) + } + withTable("t") { + sql("create table t(i boolean, s bigint default 42) using parquet") + sql("insert into t (s) values (default)") + checkAnswer(spark.table("t"), Row(null, 42L)) + } + withTable("t") { + sql("create table t(i boolean, s bigint, q int) using parquet partitioned by (i)") + sql("insert into t partition(i='true') (s) values(5)") + sql("insert into t partition(i='false') (q) select 43") + sql("insert into t partition(i='false') (q) select default") + checkAnswer(spark.table("t"), + Seq(Row(5, null, true), + Row(null, 43, false), + Row(null, null, false))) } } test("SPARK- 38795 INSERT INTO with user specified columns and defaults: negative tests") { - val addOneColButExpectedTwo = "target table has 2 column(s) but the inserted data has 1 col" - val addTwoColButExpectedThree = "target table has 3 column(s) but the inserted data has 2 col" // The missing columns in these INSERT INTO commands do not have explicit default values. 
withTable("t") { sql("create table t(i boolean, s bigint, q int default 43) using parquet") assert(intercept[AnalysisException] { sql("insert into t (i, q) select true from (select 1)") }.getMessage.contains("Cannot write to table due to mismatched user specified column " + - "size(3) and data column size(2)")) + "size(2) and data column size(1)")) } // When the USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES configuration is disabled, and no // explicit DEFAULT value is available when the INSERT INTO statement provides fewer @@ -1236,37 +1206,37 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { sql("create table t(i boolean, s bigint) using parquet") assert(intercept[AnalysisException] { sql("insert into t (i) values (true)") - }.getMessage.contains(addOneColButExpectedTwo)) + }.getMessage.contains("Cannot find data for output column 's'")) } withTable("t") { sql("create table t(i boolean default true, s bigint) using parquet") assert(intercept[AnalysisException] { sql("insert into t (i) values (default)") - }.getMessage.contains(addOneColButExpectedTwo)) + }.getMessage.contains("Cannot find data for output column 's'")) } withTable("t") { sql("create table t(i boolean, s bigint default 42) using parquet") assert(intercept[AnalysisException] { sql("insert into t (s) values (default)") - }.getMessage.contains(addOneColButExpectedTwo)) + }.getMessage.contains("Cannot find data for output column 'i'")) } withTable("t") { sql("create table t(i boolean, s bigint, q int) using parquet partitioned by (i)") assert(intercept[AnalysisException] { sql("insert into t partition(i='true') (s) values(5)") - }.getMessage.contains(addTwoColButExpectedThree)) + }.getMessage.contains("Cannot find data for output column 'q'")) } withTable("t") { sql("create table t(i boolean, s bigint, q int) using parquet partitioned by (i)") assert(intercept[AnalysisException] { sql("insert into t partition(i='false') (q) select 43") - }.getMessage.contains(addTwoColButExpectedThree)) + }.getMessage.contains("Cannot find data for output column 's'")) } withTable("t") { sql("create table t(i boolean, s bigint, q int) using parquet partitioned by (i)") assert(intercept[AnalysisException] { sql("insert into t partition(i='false') (q) select default") - }.getMessage.contains(addTwoColButExpectedThree)) + }.getMessage.contains("Cannot find data for output column 's'")) } } // When the CASE_SENSITIVE configuration is enabled, then using different cases for the required @@ -1311,6 +1281,13 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { sql("insert into t(i) values(1)") checkAnswer(spark.table("t"), Row(1, 42, 43)) } + withTable("t") { + sql(createTableIntCol) + sql("alter table t add column s bigint default 42") + sql("alter table t add column x bigint") + sql("insert into t values(1)") + checkAnswer(spark.table("t"), Row(1, 42, null)) + } // The table has a partitioning column and a default value is injected. withTable("t") { sql("create table t(i boolean, s bigint) using parquet partitioned by (i)") @@ -1368,31 +1345,29 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { // There are three column types exercising various combinations of implicit and explicit // default column value references in the 'insert into' statements. Note these tests depend on // enabling the configuration to use NULLs for missing DEFAULT column values. 
- withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "true") { - withTable("t1", "t2") { - sql("create table t1(j int) using parquet") - sql("alter table t1 add column s bigint default 42") - sql("alter table t1 add column x bigint default 43") - sql("insert into t1(j) values(1)") - sql("insert into t1(j, s) values(2, default)") - sql("insert into t1(j, s, x) values(3, default, default)") - sql("insert into t1(j, s) values(4, 44)") - sql("insert into t1(j, s, x) values(5, 44, 45)") - sql("create table t2(j int) using parquet") - sql("alter table t2 add columns s bigint default 42, x bigint default 43") - sql("insert into t2(j) select j from t1 where j = 1") - sql("insert into t2(j, s) select j, default from t1 where j = 2") - sql("insert into t2(j, s, x) select j, default, default from t1 where j = 3") - sql("insert into t2(j, s) select j, s from t1 where j = 4") - sql("insert into t2(j, s, x) select j, s, default from t1 where j = 5") - checkAnswer( - spark.table("t2"), - Row(1, 42L, 43L) :: - Row(2, 42L, 43L) :: - Row(3, 42L, 43L) :: - Row(4, 44L, 43L) :: - Row(5, 44L, 43L) :: Nil) - } + withTable("t1", "t2") { + sql("create table t1(j int) using parquet") + sql("alter table t1 add column s bigint default 42") + sql("alter table t1 add column x bigint default 43") + sql("insert into t1(j) values(1)") + sql("insert into t1(j, s) values(2, default)") + sql("insert into t1(j, s, x) values(3, default, default)") + sql("insert into t1(j, s) values(4, 44)") + sql("insert into t1(j, s, x) values(5, 44, 45)") + sql("create table t2(j int) using parquet") + sql("alter table t2 add columns s bigint default 42, x bigint default 43") + sql("insert into t2(j) select j from t1 where j = 1") + sql("insert into t2(j, s) select j, default from t1 where j = 2") + sql("insert into t2(j, s, x) select j, default, default from t1 where j = 3") + sql("insert into t2(j, s) select j, s from t1 where j = 4") + sql("insert into t2(j, s, x) select j, s, default from t1 where j = 5") + checkAnswer( + spark.table("t2"), + Row(1, 42L, 43L) :: + Row(2, 42L, 43L) :: + Row(3, 42L, 43L) :: + Row(4, 44L, 43L) :: + Row(5, 44L, 43L) :: Nil) } } @@ -1439,15 +1414,6 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { }.getMessage.contains("Support for DEFAULT column values is not allowed")) } } - // There is one trailing default value referenced implicitly by the INSERT INTO statement. 
- withTable("t") { - sql("create table t(i int) using parquet") - sql("alter table t add column s bigint default 42") - sql("alter table t add column x bigint") - assert(intercept[AnalysisException] { - sql("insert into t values(1)") - }.getMessage.contains("target table has 3 column(s) but the inserted data has 1 column(s)")) - } } test("SPARK-38838 INSERT INTO with defaults set by ALTER TABLE ALTER COLUMN: positive tests") { @@ -2263,14 +2229,12 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { checkAnswer(spark.table("t1"), Row(1, "str1")) } - withSQLConf(SQLConf.USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES.key -> "true") { - withTable("t1") { - sql("CREATE TABLE t1(c1 int, c2 string, c3 int) using parquet") - sql("INSERT INTO TABLE t1(c1, c2) select * from jt where a=1") - checkAnswer(spark.table("t1"), Row(1, "str1", null)) - sql("INSERT INTO TABLE t1 select *, 2 from jt where a=2") - checkAnswer(spark.table("t1"), Seq(Row(1, "str1", null), Row(2, "str2", 2))) - } + withTable("t1") { + sql("CREATE TABLE t1(c1 int, c2 string, c3 int) using parquet") + sql("INSERT INTO TABLE t1(c1, c2) select * from jt where a=1") + checkAnswer(spark.table("t1"), Row(1, "str1", null)) + sql("INSERT INTO TABLE t1 select *, 2 from jt where a=2") + checkAnswer(spark.table("t1"), Seq(Row(1, "str1", null), Row(2, "str2", 2))) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 10f18a9ef2e8..4eae3933bf51 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -1258,13 +1258,11 @@ class HiveQuerySuite extends HiveComparisonTest with SQLTestUtils with BeforeAnd """INSERT INTO TABLE dp_test PARTITION(dp) |SELECT key, value, key % 5 FROM src""".stripMargin) }, - errorClass = "INSERT_COLUMN_ARITY_MISMATCH", - sqlState = "21S01", + errorClass = "_LEGACY_ERROR_TEMP_1169", parameters = Map( "tableName" -> "`spark_catalog`.`default`.`dp_test`", - "targetColumns" -> "4", - "insertedColumns" -> "3", - "staticPartCols" -> "0")) + "normalizedPartSpec" -> "dp", + "partColNames" -> "dp,sp")) sql("SET hive.exec.dynamic.partition.mode=nonstrict") From a1a9cc85abc2cd0fb2e0269ff169c27a27f1d5bb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 25 May 2023 23:12:02 +0800 Subject: [PATCH 4/6] address comments --- .../sql/catalyst/analysis/Analyzer.scala | 20 +++++-------- .../ResolveColumnDefaultInInsert.scala | 30 +++++++++++++++++-- .../analysis/ResolveReferencesInUpdate.scala | 21 ++++--------- .../util/ResolveDefaultColumnsUtil.scala | 26 ++++++++++++++++ 4 files changed, 66 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0f739f1c4946..74c602db09f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1756,20 +1756,14 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case MergeResolvePolicy.SOURCE => Project(Nil, mergeInto.sourceTable) case MergeResolvePolicy.TARGET => Project(Nil, mergeInto.targetTable) } - resolvedKey match { - case attr: AttributeReference => - val resolvedExpr = resolveExprInAssignment(c, 
resolvePlan) match { - case u: UnresolvedAttribute if isExplicitDefaultColumn(u) => - getDefaultValueExpr(attr).getOrElse(Literal(null, attr.dataType)) - case other if containsExplicitDefaultColumn(other) => - throw QueryCompilationErrors - .defaultReferencesNotAllowedInComplexExpressionsInMergeInsertsOrUpdates() - case other => other - } - checkResolvedMergeExpr(resolvedExpr, resolvePlan) - resolvedExpr - case _ => resolveMergeExprOrFail(c, resolvePlan) + val resolvedExpr = resolveExprInAssignment(c, resolvePlan) + val withDefaultResolved = if (conf.enableDefaultColumns) { + resolveColumnDefaultInAssignmentValue(resolvedKey, resolvedExpr) + } else { + resolvedExpr } + checkResolvedMergeExpr(withDefaultResolved, resolvePlan) + withDefaultResolved case o => o } Assignment(resolvedKey, resolvedValue) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala index 1f0f452fbd85..b6e77efb8efd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala @@ -45,15 +45,24 @@ case object ResolveColumnDefaultInInsert extends SQLConfHelper with ColumnResolu def apply(plan: LogicalPlan): LogicalPlan = plan match { case i: InsertIntoStatement if conf.enableDefaultColumns && i.table.resolved && i.query.containsPattern(UNRESOLVED_ATTRIBUTE) => - val staticPartCols = i.partitionSpec.filter(_._2.isDefined).keys.map(normalizeFieldName).toSet + val staticPartCols = i.partitionSpec.filter(_._2.isDefined).keySet.map(normalizeFieldName) + // For INSERT with static partitions, such as `INSERT INTO t PARTITION(c=1) SELECT ...`, the + // input query schema should match the table schema excluding columns with static + // partition values. val expectedQuerySchema = i.table.schema.filter { field => !staticPartCols.contains(normalizeFieldName(field.name)) } + // Normally, we should match the query schema with the table schema by position. If the n-th + // column of the query is the DEFAULT column, we should get the default value expression + // defined for the n-th column of the table. However, if the INSERT has a column list, such as + // `INSERT INTO t(b, c, a)`, the matching should be by name. For example, the first column of + // the query should match the column 'b' of the table. + // To simplify the implementation, `resolveColumnDefault` always does by-position match. If + // the INSERT has a column list, we reorder the table schema w.r.t. the column list and pass + // the reordered schema as the expected schema to `resolveColumnDefault`. if (i.userSpecifiedCols.isEmpty) { i.withNewChildren(Seq(resolveColumnDefault(i.query, expectedQuerySchema))) } else { - // Reorder the fields in `expectedQuerySchema` according to the user-specified column list - // of the INSERT command. val colNamesToFields: Map[String, StructField] = expectedQuerySchema.map { field => normalizeFieldName(field.name) -> field }.toMap @@ -70,6 +79,21 @@ case object ResolveColumnDefaultInInsert extends SQLConfHelper with ColumnResolu case _ => plan } + /** + * Resolves the column "DEFAULT" in [[Project]] and [[UnresolvedInlineTable]]. A column is a + * "DEFAULT" column if all the following conditions are met: + * 1. 
The expression inside project list or inline table expressions is a single
+   *    [[UnresolvedAttribute]] with name "DEFAULT". This means `SELECT DEFAULT, ...` is valid but
+   *    `SELECT DEFAULT + 1, ...` is not.
+   * 2. The project list or inline table expressions have fewer elements than the expected schema.
+   *    To find the default value definition, we need to find the matching column for expressions
+   *    inside project list or inline table expressions. This matching is by position and it
+   *    doesn't make sense if we have more expressions than the columns of the expected schema.
+   * 3. The plan nodes between [[Project]] and [[InsertIntoStatement]] are
+   *    all unary nodes that inherit the output columns from their children.
+   * 4. The plan nodes between [[UnresolvedInlineTable]] and [[InsertIntoStatement]] are either
+   *    [[Project]], or [[Aggregate]], or [[SubqueryAlias]].
+   */
   private def resolveColumnDefault(
       plan: LogicalPlan,
       expectedQuerySchema: Seq[StructField],
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala
index d486cc285530..b6383b5112e7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala
@@ -18,15 +18,13 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.sql.catalyst.SQLConfHelper
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal}
 import org.apache.spark.sql.catalyst.plans.logical.{Assignment, UpdateTable}
-import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.{containsExplicitDefaultColumn, getDefaultValueExpr, isExplicitDefaultColumn}
-import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.resolveColumnDefaultInAssignmentValue
 
 /**
  * A virtual rule to resolve [[UnresolvedAttribute]] in [[UpdateTable]]. It's only used by the real
  * rule `ResolveReferences`. The column resolution order for [[UpdateTable]] is:
- * 1. Resolves the column to [[AttributeReference]] with the output of the child plan. This
+ * 1. Resolves the column to `AttributeReference` with the output of the child plan. This
  *    includes metadata columns as well.
  * 2. Resolves the column to a literal function which is allowed to be invoked without braces, e.g.
  *    `SELECT col, current_date FROM t`.
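To make the by-position matching described in the ResolveColumnDefaultInInsert comments above concrete, here is a small illustrative sketch. It is not part of the patch; the table name `t` and its columns are hypothetical, and it assumes a local SparkSession on a Spark build where DEFAULT column support is enabled (the default since 3.4):

import org.apache.spark.sql.SparkSession

object DefaultColumnInsertSketch extends App {
  val spark = SparkSession.builder()
    .master("local[1]")
    .appName("default-column-insert-sketch")
    .getOrCreate()

  spark.sql("CREATE TABLE t(a INT, b INT DEFAULT 42, c INT) USING parquet")
  // No column list: the VALUES expressions match the table schema by position, so the DEFAULT in
  // the second slot resolves to the default value of the second table column (b -> 42).
  spark.sql("INSERT INTO t VALUES (1, DEFAULT, 3)")          // stores (a=1, b=42, c=3)
  // With a column list, the expected schema is first reordered to (c, a, b), so the DEFAULT in the
  // third slot again resolves against column b.
  spark.sql("INSERT INTO t(c, a, b) VALUES (3, 1, DEFAULT)") // also stores (a=1, b=42, c=3)
  spark.table("t").show()
  spark.stop()
}

A statement with more expressions than remaining columns, or with DEFAULT buried inside a larger expression, is rejected rather than resolved, matching conditions 1 and 2 above.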
@@ -48,17 +46,10 @@ case object ResolveReferencesInUpdate extends SQLConfHelper with ColumnResolutio val resolvedValue = assign.value match { case c if !c.resolved => val resolved = resolveExprInAssignment(c, u) - resolvedKey match { - case attr: AttributeReference if conf.enableDefaultColumns => - resolved match { - case u: UnresolvedAttribute if isExplicitDefaultColumn(u) => - getDefaultValueExpr(attr).getOrElse(Literal(null, attr.dataType)) - case other if containsExplicitDefaultColumn(other) => - throw QueryCompilationErrors - .defaultReferencesNotAllowedInComplexExpressionsInUpdateSetClause() - case other => other - } - case _ => resolved + if (conf.enableDefaultColumns) { + resolveColumnDefaultInAssignmentValue(resolvedKey, resolved) + } else { + resolved } case o => o } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala index fbf37588173e..d5a06531429d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala @@ -172,6 +172,32 @@ object ResolveDefaultColumns { } } + /** + * Resolves the column "DEFAULT" in UPDATE/MERGE assignment value expression if the following + * conditions are met: + * 1. The assignment value expression is a single `UnresolvedAttribute` with name "DEFAULT". This + * means `key = DEFAULT` is allowed but `key = DEFAULT + 1` is not. + * 2. The assignment key expression is a top-level column. This means `col = DEFAULT` is allowed + * but `col.field = DEFAULT` is not. + * + * The column "DEFAULT" will be resolved to the default value expression defined for the column of + * the assignment key. + */ + def resolveColumnDefaultInAssignmentValue(key: Expression, value: Expression): Expression = { + key match { + case attr: AttributeReference => + value match { + case u: UnresolvedAttribute if isExplicitDefaultColumn(u) => + getDefaultValueExpr(attr).getOrElse(Literal(null, attr.dataType)) + case other if containsExplicitDefaultColumn(other) => + throw QueryCompilationErrors + .defaultReferencesNotAllowedInComplexExpressionsInUpdateSetClause() + case other => other + } + case _ => value + } + } + /** * Generates the expression of the default value for the given field. If there is no * user-specified default value for this field, returns None. 
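Since this commit centralizes the SET-clause handling in `resolveColumnDefaultInAssignmentValue`, a minimal sketch of its decision logic may help. It uses simplified stand-in types instead of Catalyst expressions, and all names below are hypothetical; it is not Spark's actual code:

object AssignmentDefaultSketch {
  sealed trait Expr
  case class Column(name: String, default: Option[Expr] = None) extends Expr
  case class NestedField(path: Seq[String]) extends Expr
  case object DefaultRef extends Expr                       // the unresolved column "DEFAULT"
  case class Lit(value: Any) extends Expr
  case class Add(left: Expr, right: Expr) extends Expr

  // True if the value expression mentions the special column "DEFAULT" anywhere inside it.
  def containsDefault(e: Expr): Boolean = e match {
    case DefaultRef => true
    case Add(l, r) => containsDefault(l) || containsDefault(r)
    case _ => false
  }

  // Mirrors the shape of resolveColumnDefaultInAssignmentValue: only `topLevelColumn = DEFAULT`
  // is rewritten, to the column's default if one is defined and to a NULL literal otherwise;
  // DEFAULT inside a bigger expression is an error; any other key (e.g. a nested field) leaves
  // the value untouched, so resolution fails later with a normal unresolved-column error.
  def resolveAssignmentValue(key: Expr, value: Expr): Expr = key match {
    case col: Column =>
      value match {
        case DefaultRef => col.default.getOrElse(Lit(null))
        case other if containsDefault(other) =>
          throw new IllegalArgumentException("DEFAULT is not allowed in a complex expression")
        case other => other
      }
    case _ => value
  }

  def main(args: Array[String]): Unit = {
    val s = Column("s", default = Some(Lit("abc")))
    val i = Column("i")
    println(resolveAssignmentValue(s, DefaultRef)) // Lit(abc): the declared default
    println(resolveAssignmentValue(i, DefaultRef)) // Lit(null): no default declared
  }
}

The real rule additionally threads a caller-specific error through, as the last commit in this series shows, so UPDATE and MERGE each report their own error class for DEFAULT in a complex expression.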
From e1d03ab7cf537886e754abae26c9bf22cb719f97 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 26 May 2023 14:15:28 +0800 Subject: [PATCH 5/6] address comments --- .../catalyst/analysis/AssignmentUtils.scala | 4 +- .../ResolveColumnDefaultInInsert.scala | 10 ++-- .../analysis/TableOutputResolver.scala | 6 +-- .../util/ResolveDefaultColumnsUtil.scala | 52 +++++++++---------- .../spark/sql/sources/InsertSuite.scala | 19 ++++--- 5 files changed, 48 insertions(+), 43 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala index f3f2a94c7478..069cef6b3610 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, Expression, GetStructField, Literal} import org.apache.spark.sql.catalyst.plans.logical.Assignment import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getDefaultValueExprOrNullLiteral +import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getDefaultValueExprOrNullLit import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.{DataType, StructType} @@ -104,7 +104,7 @@ object AssignmentUtils extends SQLConfHelper with CastSupport { case assignment if assignment.key.semanticEquals(attr) => assignment } val resolvedValue = if (matchingAssignments.isEmpty) { - val defaultExpr = getDefaultValueExprOrNullLiteral(attr, conf) + val defaultExpr = getDefaultValueExprOrNullLit(attr, conf) if (defaultExpr.isEmpty) { errors += s"No assignment for '${attr.name}'" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala index b6e77efb8efd..f79196649266 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveColumnDefaultInInsert.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.SQLConfHelper -import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} +import org.apache.spark.sql.catalyst.expressions.Alias import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_ATTRIBUTE -import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.{containsExplicitDefaultColumn, getDefaultValueExpr, isExplicitDefaultColumn} +import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.{containsExplicitDefaultColumn, getDefaultValueExprOrNullLit, isExplicitDefaultColumn} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructField @@ -113,8 +113,7 @@ case object ResolveColumnDefaultInInsert extends SQLConfHelper with ColumnResolu p.projectList.length <= expectedQuerySchema.length => val newProjectList = p.projectList.zipWithIndex.map { case (u: UnresolvedAttribute, i) if 
isExplicitDefaultColumn(u) => - val field = expectedQuerySchema(i) - Alias(getDefaultValueExpr(field).getOrElse(Literal(null, field.dataType)), u.name)() + Alias(getDefaultValueExprOrNullLit(expectedQuerySchema(i)), u.name)() case (other, _) if containsExplicitDefaultColumn(other) => throw QueryCompilationErrors .defaultReferencesNotAllowedInComplexExpressionsInInsertValuesList() @@ -134,8 +133,7 @@ case object ResolveColumnDefaultInInsert extends SQLConfHelper with ColumnResolu val newRows = inlineTable.rows.map { exprs => exprs.zipWithIndex.map { case (u: UnresolvedAttribute, i) if isExplicitDefaultColumn(u) => - val field = expectedQuerySchema(i) - getDefaultValueExpr(field).getOrElse(Literal(null, field.dataType)) + getDefaultValueExprOrNullLit(expectedQuerySchema(i)) case (other, _) if containsExplicitDefaultColumn(other) => throw QueryCompilationErrors .defaultReferencesNotAllowedInComplexExpressionsInInsertValuesList() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala index 2051487f1a68..b9aca30c754d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getDefaultValueExprOrNullLiteral +import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getDefaultValueExprOrNullLit import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -67,7 +67,7 @@ object TableOutputResolver { val fillDefaultValue = supportColDefaultValue && actualExpectedCols.size > query.output.size val queryOutputCols = if (fillDefaultValue) { query.output ++ actualExpectedCols.drop(query.output.size).flatMap { expectedCol => - getDefaultValueExprOrNullLiteral(expectedCol, conf) + getDefaultValueExprOrNullLit(expectedCol, conf) } } else { query.output @@ -185,7 +185,7 @@ object TableOutputResolver { val newColPath = colPath :+ expectedCol.name if (matched.isEmpty) { val defaultExpr = if (fillDefaultValue) { - getDefaultValueExprOrNullLiteral(expectedCol, conf) + getDefaultValueExprOrNullLit(expectedCol, conf) } else { None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala index d5a06531429d..645759656831 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala @@ -188,7 +188,7 @@ object ResolveDefaultColumns { case attr: AttributeReference => value match { case u: UnresolvedAttribute if isExplicitDefaultColumn(u) => - getDefaultValueExpr(attr).getOrElse(Literal(null, attr.dataType)) + getDefaultValueExprOrNullLit(attr) case other if containsExplicitDefaultColumn(other) => throw QueryCompilationErrors 
.defaultReferencesNotAllowedInComplexExpressionsInUpdateSetClause() @@ -198,11 +198,7 @@ object ResolveDefaultColumns { } } - /** - * Generates the expression of the default value for the given field. If there is no - * user-specified default value for this field, returns None. - */ - def getDefaultValueExpr(field: StructField): Option[Expression] = { + private def getDefaultValueExprOpt(field: StructField): Option[Expression] = { if (field.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY)) { Some(analyze(field, "INSERT")) } else { @@ -211,32 +207,36 @@ object ResolveDefaultColumns { } /** - * Generates the expression of the default value for the given column. If there is no - * user-specified default value for this field, returns None. + * Generates the expression of the default value for the given field. If there is no + * user-specified default value for this field, returns null literal. */ - def getDefaultValueExpr(attr: Attribute): Option[Expression] = { - if (attr.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY)) { - val field = StructField(attr.name, attr.dataType, attr.nullable, attr.metadata) - Some(analyze(field, "INSERT")) - } else { - None - } + def getDefaultValueExprOrNullLit(field: StructField): Expression = { + getDefaultValueExprOpt(field).getOrElse(Literal(null, field.dataType)) } /** * Generates the expression of the default value for the given column. If there is no - * user-specified default value for this column, returns a null literal. + * user-specified default value for this field, returns null literal. */ - def getDefaultValueExprOrNullLiteral(attr: Attribute, conf: SQLConf): Option[NamedExpression] = { - val defaultExprOpt = if (attr.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY)) { - val field = StructField(attr.name, attr.dataType, attr.nullable, attr.metadata) - Some(analyze(field, "INSERT")) - } else if (conf.useNullsForMissingDefaultColumnValues) { - Some(Literal(null, attr.dataType)) - } else { - None - } - defaultExprOpt.map(expr => Alias(expr, attr.name)()) + def getDefaultValueExprOrNullLit(attr: Attribute): Expression = { + val field = StructField(attr.name, attr.dataType, attr.nullable, attr.metadata) + getDefaultValueExprOrNullLit(field) + } + + /** + * Generates the aliased expression of the default value for the given column. If there is no + * user-specified default value for this column, returns a null literal or None w.r.t. the config + * `USE_NULLS_FOR_MISSING_DEFAULT_COLUMN_VALUES`. 
+ */ + def getDefaultValueExprOrNullLit(attr: Attribute, conf: SQLConf): Option[NamedExpression] = { + val field = StructField(attr.name, attr.dataType, attr.nullable, attr.metadata) + getDefaultValueExprOpt(field).orElse { + if (conf.useNullsForMissingDefaultColumnValues) { + Some(Literal(null, attr.dataType)) + } else { + None + } + }.map(expr => Alias(expr, attr.name)()) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index be2399a46ccf..312b2723fffd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -850,6 +850,12 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { (1 to 10).map(i => Row(i, null)) ) + sql("INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt") + checkAnswer( + sql("SELECT a, b FROM jsonTable"), + (1 to 10).map(i => Row(i, null)) + ) + sql("INSERT OVERWRITE TABLE jsonTable(a) SELECT a FROM jt") checkAnswer( sql("SELECT a, b FROM jsonTable"), @@ -1190,6 +1196,7 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { } test("SPARK- 38795 INSERT INTO with user specified columns and defaults: negative tests") { + val missingColError = "Cannot find data for output column " // The missing columns in these INSERT INTO commands do not have explicit default values. withTable("t") { sql("create table t(i boolean, s bigint, q int default 43) using parquet") @@ -1206,37 +1213,37 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { sql("create table t(i boolean, s bigint) using parquet") assert(intercept[AnalysisException] { sql("insert into t (i) values (true)") - }.getMessage.contains("Cannot find data for output column 's'")) + }.getMessage.contains(missingColError + "'s'")) } withTable("t") { sql("create table t(i boolean default true, s bigint) using parquet") assert(intercept[AnalysisException] { sql("insert into t (i) values (default)") - }.getMessage.contains("Cannot find data for output column 's'")) + }.getMessage.contains(missingColError + "'s'")) } withTable("t") { sql("create table t(i boolean, s bigint default 42) using parquet") assert(intercept[AnalysisException] { sql("insert into t (s) values (default)") - }.getMessage.contains("Cannot find data for output column 'i'")) + }.getMessage.contains(missingColError + "'i'")) } withTable("t") { sql("create table t(i boolean, s bigint, q int) using parquet partitioned by (i)") assert(intercept[AnalysisException] { sql("insert into t partition(i='true') (s) values(5)") - }.getMessage.contains("Cannot find data for output column 'q'")) + }.getMessage.contains(missingColError + "'q'")) } withTable("t") { sql("create table t(i boolean, s bigint, q int) using parquet partitioned by (i)") assert(intercept[AnalysisException] { sql("insert into t partition(i='false') (q) select 43") - }.getMessage.contains("Cannot find data for output column 's'")) + }.getMessage.contains(missingColError + "'s'")) } withTable("t") { sql("create table t(i boolean, s bigint, q int) using parquet partitioned by (i)") assert(intercept[AnalysisException] { sql("insert into t partition(i='false') (q) select default") - }.getMessage.contains("Cannot find data for output column 's'")) + }.getMessage.contains(missingColError + "'s'")) } } // When the CASE_SENSITIVE configuration is enabled, then using different cases for the required From 69c3ae2d78dbd95b88f1df9fd6aebd8082ea7a34 Mon Sep 
17 00:00:00 2001 From: Wenchen Fan Date: Fri, 26 May 2023 23:19:10 +0800 Subject: [PATCH 6/6] fix test --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 6 +++++- .../sql/catalyst/analysis/ResolveReferencesInUpdate.scala | 7 ++++++- .../sql/catalyst/util/ResolveDefaultColumnsUtil.scala | 8 +++++--- 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 74c602db09f0..3644a50acf10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1758,7 +1758,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } val resolvedExpr = resolveExprInAssignment(c, resolvePlan) val withDefaultResolved = if (conf.enableDefaultColumns) { - resolveColumnDefaultInAssignmentValue(resolvedKey, resolvedExpr) + resolveColumnDefaultInAssignmentValue( + resolvedKey, + resolvedExpr, + QueryCompilationErrors + .defaultReferencesNotAllowedInComplexExpressionsInMergeInsertsOrUpdates()) } else { resolvedExpr } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala index b6383b5112e7..cebc1e25f921 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInUpdate.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.plans.logical.{Assignment, UpdateTable} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.resolveColumnDefaultInAssignmentValue +import org.apache.spark.sql.errors.QueryCompilationErrors /** * A virtual rule to resolve [[UnresolvedAttribute]] in [[UpdateTable]]. It's only used by the real @@ -47,7 +48,11 @@ case object ResolveReferencesInUpdate extends SQLConfHelper with ColumnResolutio case c if !c.resolved => val resolved = resolveExprInAssignment(c, u) if (conf.enableDefaultColumns) { - resolveColumnDefaultInAssignmentValue(resolvedKey, resolved) + resolveColumnDefaultInAssignmentValue( + resolvedKey, + resolved, + QueryCompilationErrors + .defaultReferencesNotAllowedInComplexExpressionsInUpdateSetClause()) } else { resolved } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala index 645759656831..c006dde13922 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ResolveDefaultColumnsUtil.scala @@ -183,15 +183,17 @@ object ResolveDefaultColumns { * The column "DEFAULT" will be resolved to the default value expression defined for the column of * the assignment key. 
*/ - def resolveColumnDefaultInAssignmentValue(key: Expression, value: Expression): Expression = { + def resolveColumnDefaultInAssignmentValue( + key: Expression, + value: Expression, + invalidColumnDefaultException: Throwable): Expression = { key match { case attr: AttributeReference => value match { case u: UnresolvedAttribute if isExplicitDefaultColumn(u) => getDefaultValueExprOrNullLit(attr) case other if containsExplicitDefaultColumn(other) => - throw QueryCompilationErrors - .defaultReferencesNotAllowedInComplexExpressionsInUpdateSetClause() + throw invalidColumnDefaultException case other => other } case _ => value