@@ -271,32 +271,33 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPl
/**
* Updates an inline table to generate missing default column values.
*/
private def addMissingDefaultValuesForInsertFromInlineTable(
def addMissingDefaultValuesForInsertFromInlineTable(
node: LogicalPlan,
insertTableSchemaWithoutPartitionColumns: StructType,
numUserSpecifiedColumns: Int): LogicalPlan = {
val schema = insertTableSchemaWithoutPartitionColumns
val newDefaultExpressions: Seq[Expression] =
getDefaultExpressionsForInsert(schema, numUserSpecifiedColumns)
val newNames: Seq[String] = if (numUserSpecifiedColumns > 0) {
schema.fields.drop(numUserSpecifiedColumns).map(_.name)
} else {
schema.fields.map(_.name)
}
val newNames: Seq[String] = schema.fields.map(_.name)
node match {
case _ if newDefaultExpressions.isEmpty => node
case table: UnresolvedInlineTable =>
table.copy(
names = table.names ++ newNames,
names = newNames,
rows = table.rows.map { row => row ++ newDefaultExpressions })
case local: LocalRelation =>
// Note that we have consumed a LocalRelation but return an UnresolvedInlineTable, because
// addMissingDefaultValuesForInsertFromProject must replace unresolved DEFAULT references.
UnresolvedInlineTable(
local.output.map(_.name) ++ newNames,
newNames,
local.data.map { row =>
val colTypes = StructType(local.output.map(col => StructField(col.name, col.dataType)))
row.toSeq(colTypes).map(Literal(_)) ++ newDefaultExpressions
val values: Seq[Any] = row.toSeq(colTypes)
val dataTypes: Seq[DataType] = colTypes.map(_.dataType)
val literals: Seq[Literal] = values.zip(dataTypes).map {
case (value, dataType) => Literal(value, dataType)
}
literals ++ newDefaultExpressions
})
case _ => node
}
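
For context, a minimal sketch (not part of the patch; names are illustrative) of why the new code re-wraps each LocalRelation value with an explicit data type: a TimestampType value is stored internally as a Long of microseconds, so wrapping the raw value with Literal(value) alone would infer LongType and the rebuilt inline table would carry the wrong column type.

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.{LongType, TimestampType}

object TypedLiteralSketch extends App {
  // Internal representation of a timestamp: microseconds since the epoch.
  val micros = 1609401600000000L
  val inferred = Literal(micros)                // type inferred from the Scala value: LongType
  val explicit = Literal(micros, TimestampType) // type carried over from the relation's schema
  assert(inferred.dataType == LongType)
  assert(explicit.dataType == TimestampType)
}
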
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis
Hi, @dtenedor and @gengliangwang. I have a question.

Although I understand this test suite provides test coverage for org.apache.spark.sql.catalyst.analysis.ResolveDefaultColumns, that doesn't mean it belongs in the org.apache.spark.sql.catalyst.analysis package. This test suite exists in the sql module, alone in this directory:

$ tree sql/core/src/test/scala/org/apache/spark/sql/catalyst
sql/core/src/test/scala/org/apache/spark/sql/catalyst
└── analysis
    └── ResolveDefaultColumnsSuite.scala

Is this intentional?

Hi @dongjoon-hyun, I don't think this is intentional; we could move ResolveDefaultColumnsSuite to the org.apache.spark.sql package. What do you think? If you'd like, I can prepare a PR.

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{StructField, StructType, TimestampType}

class ResolveDefaultColumnsSuite extends QueryTest with SharedSparkSession {
  val rule = ResolveDefaultColumns(catalog = null)

  test("Assign correct types when adding DEFAULTs for inserting from a VALUES list") {
    // This is the internal storage for the timestamp 2020-12-31 00:00:00.0.
    val literal = Literal(1609401600000000L, TimestampType)
    val table = UnresolvedInlineTable(
      names = Seq("attr1"),
      rows = Seq(Seq(literal)))
    val node = ResolveInlineTables(table).asInstanceOf[LocalRelation]

    assert(node.output.map(_.dataType) == Seq(TimestampType))
    assert(node.data.size == 1)

    val insertTableSchemaWithoutPartitionColumns = StructType(Seq(
      StructField("c1", TimestampType),
      StructField("c2", TimestampType)))
    val result = rule.addMissingDefaultValuesForInsertFromInlineTable(
      node, insertTableSchemaWithoutPartitionColumns, numUserSpecifiedColumns = 1)
    val inlineTable: UnresolvedInlineTable = result match {
      case u: UnresolvedInlineTable => u
      case _ => fail(s"invalid result operator type: $result")
    }
    assert(inlineTable.names == Seq("c1", "c2"))
    assert(inlineTable.rows == Seq(
      Seq(literal, UnresolvedAttribute("DEFAULT"))))
  }
}
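
To make the expected behavior concrete, an illustrative sketch (table and column names assumed; not part of the change) of the SQL shape this test models and the plan it asserts:

// Conceptually, the test covers an INSERT that specifies only the first column, e.g.
//   INSERT INTO t (c1) VALUES (TIMESTAMP'2020-12-31 00:00:00')
// which the rule pads to cover the full insert schema, as if the user had written
//   INSERT INTO t (c1, c2) VALUES (TIMESTAMP'2020-12-31 00:00:00', DEFAULT)
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedInlineTable}
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.types.TimestampType

object ExpectedPlanShapeSketch {
  // The typed timestamp literal is preserved; the missing column becomes a DEFAULT placeholder.
  val expectedShape = UnresolvedInlineTable(
    names = Seq("c1", "c2"),
    rows = Seq(Seq(Literal(1609401600000000L, TimestampType), UnresolvedAttribute("DEFAULT"))))
}
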
@@ -1100,9 +1100,15 @@ class DataSourceV2SQLSuiteV1Filter
exception = intercept[AnalysisException] {
sql(s"INSERT INTO $t1(data, data) VALUES(5)")
},
errorClass = "COLUMN_ALREADY_EXISTS",
parameters = Map("columnName" -> "`data`")
)
errorClass = "_LEGACY_ERROR_TEMP_2305",
parameters = Map(
"numCols" -> "3",
"rowSize" -> "2",
"ri" -> "0"),
context = ExpectedContext(
fragment = s"INSERT INTO $t1(data, data)",
start = 0,
stop = 26))
}
}

@@ -1123,14 +1129,20 @@ class DataSourceV2SQLSuiteV1Filter
assert(intercept[AnalysisException] {
sql(s"INSERT OVERWRITE $t1 VALUES(4)")
}.getMessage.contains("not enough data columns"))
// Duplicate columns
// Duplicate columns
checkError(
exception = intercept[AnalysisException] {
sql(s"INSERT OVERWRITE $t1(data, data) VALUES(5)")
},
errorClass = "COLUMN_ALREADY_EXISTS",
parameters = Map("columnName" -> "`data`")
)
errorClass = "_LEGACY_ERROR_TEMP_2305",
parameters = Map(
"numCols" -> "3",
"rowSize" -> "2",
"ri" -> "0"),
context = ExpectedContext(
fragment = s"INSERT OVERWRITE $t1(data, data)",
start = 0,
stop = 31))
}
}

@@ -1152,14 +1164,20 @@ class DataSourceV2SQLSuiteV1Filter
assert(intercept[AnalysisException] {
sql(s"INSERT OVERWRITE $t1 VALUES('a', 4)")
}.getMessage.contains("not enough data columns"))
// Duplicate columns
// Duplicate columns
checkError(
exception = intercept[AnalysisException] {
sql(s"INSERT OVERWRITE $t1(data, data) VALUES(5)")
},
errorClass = "COLUMN_ALREADY_EXISTS",
parameters = Map("columnName" -> "`data`")
)
errorClass = "_LEGACY_ERROR_TEMP_2305",
parameters = Map(
"numCols" -> "4",
"rowSize" -> "3",
"ri" -> "0"),
context = ExpectedContext(
fragment = s"INSERT OVERWRITE $t1(data, data)",
start = 0,
stop = 31))
}
}
