@@ -107,10 +107,10 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPl
insertTableSchemaWithoutPartitionColumns.map { schema: StructType =>
val regenerated: InsertIntoStatement =
regenerateUserSpecifiedCols(i, schema)
val expanded: LogicalPlan =
val (expanded: LogicalPlan, addedDefaults: Boolean) =
addMissingDefaultValuesForInsertFromInlineTable(node, schema, i.userSpecifiedCols.size)
val replaced: Option[LogicalPlan] =
replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded)
replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded, addedDefaults)
replaced.map { r: LogicalPlan =>
node = r
for (child <- children.reverse) {
@@ -131,10 +131,10 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPl
insertTableSchemaWithoutPartitionColumns.map { schema =>
val regenerated: InsertIntoStatement = regenerateUserSpecifiedCols(i, schema)
val project: Project = i.query.asInstanceOf[Project]
val expanded: Project =
val (expanded: Project, addedDefaults: Boolean) =
addMissingDefaultValuesForInsertFromProject(project, schema, i.userSpecifiedCols.size)
val replaced: Option[LogicalPlan] =
replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded)
replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded, addedDefaults)
replaced.map { r =>
regenerated.copy(query = r)
}.getOrElse(i)
@@ -270,67 +270,83 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPl

/**
* Updates an inline table to generate missing default column values.
* Returns the resulting plan plus a boolean indicating whether such values were added.
*/
private def addMissingDefaultValuesForInsertFromInlineTable(
def addMissingDefaultValuesForInsertFromInlineTable(
node: LogicalPlan,
insertTableSchemaWithoutPartitionColumns: StructType,
numUserSpecifiedColumns: Int): LogicalPlan = {
numUserSpecifiedColumns: Int): (LogicalPlan, Boolean) = {
val schema = insertTableSchemaWithoutPartitionColumns
val newDefaultExpressions: Seq[Expression] =
getDefaultExpressionsForInsert(schema, numUserSpecifiedColumns)
val newNames: Seq[String] = if (numUserSpecifiedColumns > 0) {
schema.fields.drop(numUserSpecifiedColumns).map(_.name)
} else {
schema.fields.map(_.name)
}
node match {
case _ if newDefaultExpressions.isEmpty => node
val newDefaultExpressions: Seq[UnresolvedAttribute] =
getNewDefaultExpressionsForInsert(schema, numUserSpecifiedColumns, node.output.size)
val newNames: Seq[String] = schema.fields.map(_.name)
val resultPlan: LogicalPlan = node match {
case _ if newDefaultExpressions.isEmpty =>
node
case table: UnresolvedInlineTable =>
table.copy(
names = table.names ++ newNames,
names = newNames,
rows = table.rows.map { row => row ++ newDefaultExpressions })
case local: LocalRelation =>
// Note that we have consumed a LocalRelation but return an UnresolvedInlineTable, because
// addMissingDefaultValuesForInsertFromProject must replace unresolved DEFAULT references.
UnresolvedInlineTable(
local.output.map(_.name) ++ newNames,
local.data.map { row =>
val colTypes = StructType(local.output.map(col => StructField(col.name, col.dataType)))
row.toSeq(colTypes).map(Literal(_)) ++ newDefaultExpressions
val newDefaultExpressionsRow = new GenericInternalRow(
// Note that this code path only runs when there is a user-specified column list with fewer
// columns than the target table; otherwise, the above 'newDefaultExpressions' is empty and
// we match the first case in this list instead.
schema.fields.drop(local.output.size).map {
case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) =>
analyze(f, "INSERT") match {
case lit: Literal => lit.value
case _ => null
}
case _ => null
})
case _ => node
LocalRelation(
output = schema.toAttributes,
data = local.data.map { row =>
new JoinedRow(row, newDefaultExpressionsRow)
})
case _ =>
node
}
(resultPlan, newDefaultExpressions.nonEmpty)
}
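
For orientation, a minimal sketch of the inline-table path, mirroring the test suite further down (the two-column schema and the null catalog are illustrative assumptions, not part of the patch):

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{StructField, StructType, TimestampType}

val rule = ResolveDefaultColumns(catalog = null)
// Hypothetical two-column target schema; each VALUES row provides one value.
val targetSchema = StructType(Seq(
  StructField("c1", TimestampType),
  StructField("c2", TimestampType)))
val inlineTable = UnresolvedInlineTable(
  names = Seq("c1"),
  rows = Seq(Seq(Literal(1609401600000000L, TimestampType))))
val (expanded, addedDefaults) =
  rule.addMissingDefaultValuesForInsertFromInlineTable(
    inlineTable, targetSchema, numUserSpecifiedColumns = 1)
// addedDefaults is true: the names become Seq("c1", "c2") and each row gains a
// trailing unresolved DEFAULT reference, to be resolved later by
// replaceExplicitDefaultValuesForInputOfInsertInto.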

/**
* Adds new expressions to a projection to generate missing default column values.
* Returns the resulting plan plus a boolean indicating whether such defaults were added.
*/
private def addMissingDefaultValuesForInsertFromProject(
project: Project,
insertTableSchemaWithoutPartitionColumns: StructType,
numUserSpecifiedColumns: Int): Project = {
numUserSpecifiedColumns: Int): (Project, Boolean) = {
val schema = insertTableSchemaWithoutPartitionColumns
val newDefaultExpressions: Seq[Expression] =
getDefaultExpressionsForInsert(schema, numUserSpecifiedColumns)
getNewDefaultExpressionsForInsert(schema, numUserSpecifiedColumns, project.projectList.size)
val newAliases: Seq[NamedExpression] =
newDefaultExpressions.zip(schema.fields).map {
case (expr, field) => Alias(expr, field.name)()
}
project.copy(projectList = project.projectList ++ newAliases)
(project.copy(projectList = project.projectList ++ newAliases),
newDefaultExpressions.nonEmpty)
}
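
A matching sketch for the projection path; the helper is private, so its call is shown as a comment, and the names here are illustrative:

import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
import org.apache.spark.sql.types.TimestampType

// INSERT INTO t (c1) SELECT ... against a two-column target: the source query
// provides one value for the first column only.
val sourceProject = Project(
  Seq(Alias(Literal(1609401600000000L, TimestampType), "c1")()),
  OneRowRelation())
// addMissingDefaultValuesForInsertFromProject(sourceProject, targetSchema,
//   numUserSpecifiedColumns = 1) would return the same Project with one extra
//   alias over UnresolvedAttribute(CURRENT_DEFAULT_COLUMN_NAME), paired with true.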

/**
* This is a helper for the addMissingDefaultValuesForInsertFrom* methods above.
*/
private def getDefaultExpressionsForInsert(
schema: StructType,
numUserSpecifiedColumns: Int): Seq[Expression] = {
private def getNewDefaultExpressionsForInsert(
insertTableSchemaWithoutPartitionColumns: StructType,
numUserSpecifiedColumns: Int,
numProvidedValues: Int): Seq[UnresolvedAttribute] = {
val remainingFields: Seq[StructField] = if (numUserSpecifiedColumns > 0) {
schema.fields.drop(numUserSpecifiedColumns)
insertTableSchemaWithoutPartitionColumns.fields.drop(numUserSpecifiedColumns)
} else {
Seq.empty
}
val numDefaultExpressionsToAdd = getStructFieldsForDefaultExpressions(remainingFields).size
// Limit the number of new DEFAULT expressions to the difference between the number of columns
// in the target table and the number of values provided by the source relation. This clamps
// the final number of provided values to the number of columns in the target table.
.min(insertTableSchemaWithoutPartitionColumns.size - numProvidedValues)
Seq.fill(numDefaultExpressionsToAdd)(UnresolvedAttribute(CURRENT_DEFAULT_COLUMN_NAME))
}
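
A worked example of the clamp, assuming both trailing fields are eligible for default expressions:

// Target table t(a, b, c) has 3 columns; the statement is
// INSERT INTO t (a) SELECT 1, 2.
val numTargetColumns = 3
val numProvidedValues = 2 // the SELECT supplies two values
val eligibleTrailing = 2  // fields after the single user-specified column
val numDefaultsToAdd = eligibleTrailing.min(numTargetColumns - numProvidedValues)
assert(numDefaultsToAdd == 1) // never pads past the target column count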

@@ -351,7 +367,8 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPl
*/
private def replaceExplicitDefaultValuesForInputOfInsertInto(
insertTableSchemaWithoutPartitionColumns: StructType,
input: LogicalPlan): Option[LogicalPlan] = {
input: LogicalPlan,
addedDefaults: Boolean): Option[LogicalPlan] = {
val schema = insertTableSchemaWithoutPartitionColumns
val defaultExpressions: Seq[Expression] = schema.fields.map {
case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) => analyze(f, "INSERT")
Expand All @@ -371,7 +388,11 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPl
case project: Project =>
replaceExplicitDefaultValuesForProject(defaultExpressions, project)
case local: LocalRelation =>
Some(local)
if (addedDefaults) {
Some(local)
} else {
None
}
}
}
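
The new addedDefaults flag only matters in the LocalRelation branch; a minimal sketch of that contract, written as a hypothetical standalone helper:

import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}

// If defaults were materialized into the relation's rows (via JoinedRow above),
// the rewritten relation must be propagated; otherwise None signals the caller
// to keep the original InsertIntoStatement untouched.
def resolveLocalRelation(local: LocalRelation, addedDefaults: Boolean): Option[LogicalPlan] =
  if (addedDefaults) Some(local) else None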

@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis
Member commented:

Hi, @dtenedor and @gengliangwang. I have a question.

Although I understand this test suite provides test coverage for org.apache.spark.sql.catalyst.analysis.ResolveDefaultColumns, that doesn't mean it belongs to the org.apache.spark.sql.catalyst.analysis package. This test suite exists in the sql module, alone in this directory:

$ tree sql/core/src/test/scala/org/apache/spark/sql/catalyst
sql/core/src/test/scala/org/apache/spark/sql/catalyst
└── analysis
    └── ResolveDefaultColumnsSuite.scala

Is this intentional?

Contributor Author replied:

Hi @dongjoon-hyun, I don't think this is intentional; we could move ResolveDefaultColumnsSuite to the org.apache.spark.sql package. What do you think? If you want me to do this, I can prepare a PR.


import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{StructField, StructType, TimestampType}

class ResolveDefaultColumnsSuite extends QueryTest with SharedSparkSession {
val rule = ResolveDefaultColumns(catalog = null)
// This is the internal storage for the timestamp 2020-12-31 00:00:00.0.
val literal = Literal(1609401600000000L, TimestampType)
val table = UnresolvedInlineTable(
names = Seq("attr1"),
rows = Seq(Seq(literal)))
val localRelation = ResolveInlineTables(table).asInstanceOf[LocalRelation]

def asLocalRelation(result: LogicalPlan): LocalRelation = result match {
case r: LocalRelation => r
case _ => fail(s"invalid result operator type: $result")
}

test("SPARK-43018: Add DEFAULTs for INSERT from VALUES list with user-defined columns") {
// Call the 'addMissingDefaultValuesForInsertFromInlineTable' method with one user-specified
// column. We add a default value of NULL to the row as a result.
val insertTableSchemaWithoutPartitionColumns = StructType(Seq(
StructField("c1", TimestampType),
StructField("c2", TimestampType)))
val (result: LogicalPlan, _: Boolean) =
rule.addMissingDefaultValuesForInsertFromInlineTable(
localRelation, insertTableSchemaWithoutPartitionColumns, numUserSpecifiedColumns = 1)
val relation = asLocalRelation(result)
assert(relation.output.map(_.name) == Seq("c1", "c2"))
val data: Seq[Seq[Any]] = relation.data.map { row =>
row.toSeq(StructType(relation.output.map(col => StructField(col.name, col.dataType))))
}
assert(data == Seq(Seq(literal.value, null)))
}

test("SPARK-43018: Add no DEFAULTs for INSERT from VALUES list with no user-defined columns") {
// Call the 'addMissingDefaultValuesForInsertFromInlineTable' method with zero user-specified
// columns. The table is unchanged because there are no default columns to add in this case.
val insertTableSchemaWithoutPartitionColumns = StructType(Seq(
StructField("c1", TimestampType),
StructField("c2", TimestampType)))
val (result: LogicalPlan, _: Boolean) =
rule.addMissingDefaultValuesForInsertFromInlineTable(
localRelation, insertTableSchemaWithoutPartitionColumns, numUserSpecifiedColumns = 0)
assert(asLocalRelation(result) == localRelation)
}

test("SPARK-43018: INSERT timestamp values into a table with column DEFAULTs") {
Member commented: Especially, this test case tests more than catalyst/analysis.

withTable("t") {
sql("create table t(id int, ts timestamp) using parquet")
sql("insert into t (ts) values (timestamp'2020-12-31')")
checkAnswer(spark.table("t"),
sql("select null, timestamp'2020-12-31'").collect().head)
}
}
}
@@ -1100,9 +1100,15 @@ class DataSourceV2SQLSuiteV1Filter
exception = intercept[AnalysisException] {
sql(s"INSERT INTO $t1(data, data) VALUES(5)")
},
errorClass = "COLUMN_ALREADY_EXISTS",
parameters = Map("columnName" -> "`data`")
)
errorClass = "_LEGACY_ERROR_TEMP_2305",
parameters = Map(
"numCols" -> "3",
"rowSize" -> "2",
"ri" -> "0"),
context = ExpectedContext(
fragment = s"INSERT INTO $t1(data, data)",
start = 0,
stop = 26))
}
}

@@ -1123,14 +1129,20 @@ class DataSourceV2SQLSuiteV1Filter
assert(intercept[AnalysisException] {
sql(s"INSERT OVERWRITE $t1 VALUES(4)")
}.getMessage.contains("not enough data columns"))
// Duplicate columns
checkError(
exception = intercept[AnalysisException] {
sql(s"INSERT OVERWRITE $t1(data, data) VALUES(5)")
},
errorClass = "COLUMN_ALREADY_EXISTS",
parameters = Map("columnName" -> "`data`")
)
errorClass = "_LEGACY_ERROR_TEMP_2305",
parameters = Map(
"numCols" -> "3",
"rowSize" -> "2",
"ri" -> "0"),
context = ExpectedContext(
fragment = s"INSERT OVERWRITE $t1(data, data)",
start = 0,
stop = 31))
}
}

Expand All @@ -1152,14 +1164,20 @@ class DataSourceV2SQLSuiteV1Filter
assert(intercept[AnalysisException] {
sql(s"INSERT OVERWRITE $t1 VALUES('a', 4)")
}.getMessage.contains("not enough data columns"))
// Duplicate columns
checkError(
exception = intercept[AnalysisException] {
sql(s"INSERT OVERWRITE $t1(data, data) VALUES(5)")
},
errorClass = "COLUMN_ALREADY_EXISTS",
parameters = Map("columnName" -> "`data`")
)
errorClass = "_LEGACY_ERROR_TEMP_2305",
parameters = Map(
"numCols" -> "4",
"rowSize" -> "3",
"ri" -> "0"),
context = ExpectedContext(
fragment = s"INSERT OVERWRITE $t1(data, data)",
start = 0,
stop = 31))
}
}
