diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
index 1398552399cd..5b559becbb11 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
@@ -83,13 +83,9 @@ object TableOutputResolver extends SQLConfHelper with Logging {
       conf: SQLConf,
       // TODO: Only DS v1 writing will set it to true. We should enable in for DS v2 as well.
       supportColDefaultValue: Boolean = false): LogicalPlan = {
-    val actualExpectedCols = expected.map { attr =>
-      attr.withDataType(CharVarcharUtils.getRawType(attr.metadata).getOrElse(attr.dataType))
-    }
-
-    if (actualExpectedCols.size < query.output.size) {
+    if (expected.size < query.output.size) {
       throw QueryCompilationErrors.cannotWriteTooManyColumnsToTableError(
-        tableName, actualExpectedCols.map(_.name), query.output)
+        tableName, expected.map(_.name), query.output)
     }
 
     val errors = new mutable.ArrayBuffer[String]()
@@ -100,21 +96,21 @@ object TableOutputResolver extends SQLConfHelper with Logging {
       reorderColumnsByName(
         tableName,
         query.output,
-        actualExpectedCols,
+        expected,
         conf,
         errors += _,
         fillDefaultValue = supportColDefaultValue)
     } else {
-      if (actualExpectedCols.size > query.output.size) {
+      if (expected.size > query.output.size) {
         throw QueryCompilationErrors.cannotWriteNotEnoughColumnsToTableError(
-          tableName, actualExpectedCols.map(_.name), query.output)
+          tableName, expected.map(_.name), query.output)
       }
-      resolveColumnsByPosition(tableName, query.output, actualExpectedCols, conf, errors += _)
+      resolveColumnsByPosition(tableName, query.output, expected, conf, errors += _)
     }
 
     if (errors.nonEmpty) {
       throw QueryCompilationErrors.incompatibleDataToTableCannotFindDataError(
-        tableName, actualExpectedCols.map(_.name).map(toSQLId).mkString(", "))
+        tableName, expected.map(_.name).map(toSQLId).mkString(", "))
     }
 
     if (resolved == query.output) {
@@ -246,22 +242,25 @@ object TableOutputResolver extends SQLConfHelper with Logging {
           case a: Alias => a.withName(expectedName)
           case other => other
         }
-        (matchedCol.dataType, expectedCol.dataType) match {
+        val actualExpectedCol = expectedCol.withDataType {
+          CharVarcharUtils.getRawType(expectedCol.metadata).getOrElse(expectedCol.dataType)
+        }
+        (matchedCol.dataType, actualExpectedCol.dataType) match {
           case (matchedType: StructType, expectedType: StructType) =>
             resolveStructType(
-              tableName, matchedCol, matchedType, expectedCol, expectedType,
+              tableName, matchedCol, matchedType, actualExpectedCol, expectedType,
               byName = true, conf, addError, newColPath)
           case (matchedType: ArrayType, expectedType: ArrayType) =>
             resolveArrayType(
-              tableName, matchedCol, matchedType, expectedCol, expectedType,
+              tableName, matchedCol, matchedType, actualExpectedCol, expectedType,
               byName = true, conf, addError, newColPath)
           case (matchedType: MapType, expectedType: MapType) =>
             resolveMapType(
-              tableName, matchedCol, matchedType, expectedCol, expectedType,
+              tableName, matchedCol, matchedType, actualExpectedCol, expectedType,
               byName = true, conf, addError, newColPath)
           case _ =>
             checkField(
-              tableName, expectedCol, matchedCol, byName = true, conf, addError, newColPath)
+              tableName, actualExpectedCol, matchedCol, byName = true, conf, addError, newColPath)
         }
       }
     }
@@ -292,26 +291,28 @@ object TableOutputResolver extends SQLConfHelper with Logging {
       conf: SQLConf,
       addError: String => Unit,
       colPath: Seq[String] = Nil): Seq[NamedExpression] = {
-
-    if (inputCols.size > expectedCols.size) {
-      val extraColsStr = inputCols.takeRight(inputCols.size - expectedCols.size)
+    val actualExpectedCols = expectedCols.map { attr =>
+      attr.withDataType { CharVarcharUtils.getRawType(attr.metadata).getOrElse(attr.dataType) }
+    }
+    if (inputCols.size > actualExpectedCols.size) {
+      val extraColsStr = inputCols.takeRight(inputCols.size - actualExpectedCols.size)
         .map(col => toSQLId(col.name))
         .mkString(", ")
       if (colPath.isEmpty) {
         throw QueryCompilationErrors.cannotWriteTooManyColumnsToTableError(tableName,
-          expectedCols.map(_.name), inputCols.map(_.toAttribute))
+          actualExpectedCols.map(_.name), inputCols.map(_.toAttribute))
       } else {
         throw QueryCompilationErrors.incompatibleDataToTableExtraStructFieldsError(
           tableName, colPath.quoted, extraColsStr
         )
       }
-    } else if (inputCols.size < expectedCols.size) {
-      val missingColsStr = expectedCols.takeRight(expectedCols.size - inputCols.size)
+    } else if (inputCols.size < actualExpectedCols.size) {
+      val missingColsStr = actualExpectedCols.takeRight(actualExpectedCols.size - inputCols.size)
         .map(col => toSQLId(col.name))
         .mkString(", ")
       if (colPath.isEmpty) {
         throw QueryCompilationErrors.cannotWriteNotEnoughColumnsToTableError(tableName,
-          expectedCols.map(_.name), inputCols.map(_.toAttribute))
+          actualExpectedCols.map(_.name), inputCols.map(_.toAttribute))
       } else {
         throw QueryCompilationErrors.incompatibleDataToTableStructMissingFieldsError(
           tableName, colPath.quoted, missingColsStr
@@ -319,7 +320,7 @@ object TableOutputResolver extends SQLConfHelper with Logging {
       }
     }
 
-    inputCols.zip(expectedCols).flatMap { case (inputCol, expectedCol) =>
+    inputCols.zip(actualExpectedCols).flatMap { case (inputCol, expectedCol) =>
      val newColPath = colPath :+ expectedCol.name
       (inputCol.dataType, expectedCol.dataType) match {
         case (inputType: StructType, expectedType: StructType) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
index a93dee3bf2a6..5df46ea101c1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
@@ -661,6 +661,19 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
       }
     }
   }
+
+  test("SPARK-48792: Fix INSERT with partial column list to a table with char/varchar") {
+    assume(format != "foo",
+      "TODO: TableOutputResolver.resolveOutputColumns supportColDefaultValue is false")
+    Seq("char", "varchar").foreach { typ =>
+      withTable("students") {
+        sql(s"CREATE TABLE students (name $typ(64), address $typ(64)) USING $format")
+        sql("INSERT INTO students VALUES ('Kent Yao', 'Hangzhou')")
+        sql("INSERT INTO students (address) VALUES ('')")
+        checkAnswer(sql("SELECT count(*) FROM students"), Row(2))
+      }
+    }
+  }
 }
 
 // Some basic char/varchar tests which doesn't rely on table implementation.
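
For reviewers who want to reproduce the regression outside the test suite, here is a minimal spark-shell sketch of the scenario the new test covers. It assumes an active `SparkSession` named `spark` and uses `parquet` as a stand-in for the suite's `$format`. The gist of the patch: `resolveOutputColumns` used to rewrite every expected column to its raw char/varchar type up front, which broke the by-name default-value fill-in for partial-column INSERTs; the fix defers that rewrite to the per-column checks in `reorderColumnsByName` and `resolveColumnsByPosition`.

```scala
// Minimal reproducer sketch for SPARK-48792 (assumes an active SparkSession `spark`;
// `parquet` stands in for the test suite's `$format`).
spark.sql("CREATE TABLE students (name CHAR(64), address VARCHAR(64)) USING parquet")
spark.sql("INSERT INTO students VALUES ('Kent Yao', 'Hangzhou')")
// Partial column list: `name` is absent and must be filled with its default (NULL).
// Before this patch, resolving this INSERT failed because the expected columns had
// already been rewritten to their raw CHAR/VARCHAR types.
spark.sql("INSERT INTO students (address) VALUES ('')")
spark.sql("SELECT count(*) FROM students").show() // expect 2 rows
```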