diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala
index 118ed4e99190c..0296483adc20b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala
@@ -273,9 +273,8 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
       outputs: Seq[Seq[Expression]],
       colOrdinals: Seq[Int],
       attrs: Seq[Attribute]): ProjectingInternalRow = {
-    val schema = StructType(attrs.zipWithIndex.map { case (attr, index) =>
-      val nullable = outputs.exists(output => output(colOrdinals(index)).nullable)
-      StructField(attr.name, attr.dataType, nullable, attr.metadata)
+    val schema = StructType(attrs.zipWithIndex.map { case (attr, _) =>
+      StructField(attr.name, attr.dataType, attr.nullable, attr.metadata)
     })
     ProjectingInternalRow(schema, colOrdinals)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite.scala
index e93d4165be332..af4fc5c80ed4d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite.scala
@@ -82,4 +82,55 @@ class DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite
         insertWriteLogEntry(data = Row(6, 0, "new")))
     }
   }
+
+  test("SPARK-51479: Test Column Nullable") {
+    withTempView("source") {
+      createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+        """{ "pk": 1, "salary": 100, "dep": "hr" }
+          |{ "pk": 2, "salary": 200, "dep": "software" }
+          |{ "pk": 3, "salary": 300, "dep": "hr" }
+          |{ "pk": 4, "salary": 400, "dep": "hr" }
+          |{ "pk": 5, "salary": 500, "dep": "hr" }
+          |""".stripMargin)
+
+      val sourceDF = Seq(3, 4, 5, 6).toDF("pk")
+      sourceDF.createOrReplaceTempView("source")
+
+      sql(
+        s"""MERGE INTO $tableNameAsString t
+           |USING source s
+           |ON t.pk = s.pk
+           |WHEN MATCHED THEN
+           | UPDATE SET t.salary = 1000
+           |WHEN NOT MATCHED THEN
+           | INSERT (pk, salary, dep) VALUES (s.pk, 0, 'new')
+           |WHEN NOT MATCHED BY SOURCE AND pk = 1 THEN
+           | DELETE
+           |""".stripMargin)
+
+      checkAnswer(
+        sql(s"SELECT * FROM $tableNameAsString"),
+        Seq(
+          Row(2, 200, "software"), // unchanged
+          Row(3, 1000, "hr"), // update
+          Row(4, 1000, "hr"), // update
+          Row(5, 1000, "hr"), // update
+          Row(6, 0, "new"))) // insert
+
+      checkLastWriteInfo(
+        expectedRowSchema = table.schema,
+        expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
+        expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
+
+      checkLastWriteLog(
+        deleteWriteLogEntry(id = 1, metadata = Row("hr", null)),
+        deleteWriteLogEntry(id = 3, metadata = Row("hr", null)),
+        reinsertWriteLogEntry(metadata = Row("hr", null), data = Row(3, 1000, "hr")),
+        deleteWriteLogEntry(id = 4, metadata = Row("hr", null)),
+        reinsertWriteLogEntry(metadata = Row("hr", null), data = Row(4, 1000, "hr")),
+        deleteWriteLogEntry(id = 5, metadata = Row("hr", null)),
+        reinsertWriteLogEntry(metadata = Row("hr", null), data = Row(5, 1000, "hr")),
+        insertWriteLogEntry(data = Row(6, 0, "new")))
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateAsDeleteAndInsertTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateAsDeleteAndInsertTableSuite.scala
index 612a26e756abd..a768be4f872f8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateAsDeleteAndInsertTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateAsDeleteAndInsertTableSuite.scala
@@ -45,10 +45,7 @@ class DeltaBasedUpdateAsDeleteAndInsertTableSuite extends DeltaBasedUpdateTableS
       Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil)
 
     checkLastWriteInfo(
-      expectedRowSchema = StructType(table.schema.map {
-        case attr if attr.name == "id" => attr.copy(nullable = false) // input is a constant
-        case attr => attr
-      }),
+      expectedRowSchema = table.schema,
       expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
       expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
 
@@ -82,10 +79,7 @@ class DeltaBasedUpdateAsDeleteAndInsertTableSuite extends DeltaBasedUpdateTableS
       Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil)
 
     checkLastWriteInfo(
-      expectedRowSchema = StructType(table.schema.map {
-        case attr if attr.name == "id" => attr.copy(nullable = false) // input is a constant
-        case attr => attr
-      }),
+      expectedRowSchema = table.schema,
       expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
       expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuite.scala
index c9fd5d6e3ff0d..3443f6fdc433b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuite.scala
@@ -44,10 +44,7 @@ class DeltaBasedUpdateTableSuite extends DeltaBasedUpdateTableSuiteBase {
       Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil)
 
     checkLastWriteInfo(
-      expectedRowSchema = StructType(table.schema.map {
-        case attr if attr.name == "id" => attr.copy(nullable = false) // input is a constant
-        case attr => attr
-      }),
+      expectedRowSchema = table.schema,
       expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
       expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
 
@@ -80,10 +77,7 @@ class DeltaBasedUpdateTableSuite extends DeltaBasedUpdateTableSuiteBase {
      Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil)
 
     checkLastWriteInfo(
-      expectedRowSchema = StructType(table.schema.map {
-        case attr if attr.name == "id" => attr.copy(nullable = false) // input is a constant
-        case attr => attr
-      }),
+      expectedRowSchema = table.schema,
       expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
       expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
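For readers skimming the patch, the behavioral change in createProjectingInternalRow is that each projected field's nullability now comes straight from the attribute rather than being inferred from the output expressions. Below is a minimal standalone sketch of the two derivations; it is not part of the patch. Attr and Expr are hypothetical, simplified stand-ins for catalyst's Attribute and Expression, and only org.apache.spark.sql.types is assumed on the classpath.

object NullabilitySketch extends App {
  import org.apache.spark.sql.types.{DataType, IntegerType, Metadata, StructField, StructType}

  // Hypothetical stand-ins for catalyst's Attribute and Expression, just enough to show the schema logic.
  case class Attr(name: String, dataType: DataType, nullable: Boolean, metadata: Metadata = Metadata.empty)
  case class Expr(nullable: Boolean)

  // Old derivation: a field is nullable only if some output expression written into
  // that column ordinal is nullable.
  def schemaFromOutputs(outputs: Seq[Seq[Expr]], colOrdinals: Seq[Int], attrs: Seq[Attr]): StructType =
    StructType(attrs.zipWithIndex.map { case (attr, index) =>
      val nullable = outputs.exists(output => output(colOrdinals(index)).nullable)
      StructField(attr.name, attr.dataType, nullable, attr.metadata)
    })

  // New derivation: the field simply mirrors the attribute's own nullability.
  def schemaFromAttrs(attrs: Seq[Attr]): StructType =
    StructType(attrs.map(attr => StructField(attr.name, attr.dataType, attr.nullable, attr.metadata)))

  val attrs = Seq(Attr("pk", IntegerType, nullable = false), Attr("salary", IntegerType, nullable = true))
  // A constant assignment such as UPDATE SET salary = 1000 yields a non-nullable expression,
  // so the old derivation reports "salary" as non-nullable even though the column allows nulls.
  val outputs = Seq(Seq(Expr(nullable = false), Expr(nullable = false)))

  println(schemaFromOutputs(outputs, Seq(0, 1), attrs)) // salary field: nullable = false
  println(schemaFromAttrs(attrs))                       // salary field: nullable = true
}

This mirrors the test-suite changes above: expectedRowSchema is now table.schema verbatim, so columns keep their declared nullability even when the written value is a constant.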