From 8467ba0fdde323760e11a235da48f15412290535 Mon Sep 17 00:00:00 2001
From: huaxingao
Date: Tue, 11 Mar 2025 23:42:13 -0700
Subject: [PATCH 1/4] [SPARK-51479][SQL] Nullable in Row Level Operation Column is not correct

---
 .../analysis/RewriteRowLevelCommand.scala     | 24 +++++----
 ...ntoTableUpdateAsDeleteAndInsertSuite.scala | 51 +++++++++++++++++++
 2 files changed, 66 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala
index 118ed4e99190c..de96e229f4a5b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala
@@ -190,11 +190,11 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
     val outputs = extractOutputs(plan)
 
     val outputsWithRow = filterOutputs(outputs, Set(WRITE_WITH_METADATA_OPERATION, WRITE_OPERATION))
-    val rowProjection = newLazyProjection(plan, outputsWithRow, rowAttrs)
+    val rowProjection = newLazyProjection(plan, outputsWithRow, rowAttrs, true)
 
     val metadataProjection = if (metadataAttrs.nonEmpty) {
       val outputsWithMetadata = filterOutputs(outputs, Set(WRITE_WITH_METADATA_OPERATION))
-      Some(newLazyProjection(plan, outputsWithMetadata, metadataAttrs))
+      Some(newLazyProjection(plan, outputsWithMetadata, metadataAttrs, false))
     } else {
       None
     }
@@ -211,7 +211,7 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
 
     val rowProjection = if (rowAttrs.nonEmpty) {
       val outputsWithRow = filterOutputs(outputs, DELTA_OPERATIONS_WITH_ROW)
-      Some(newLazyProjection(plan, outputsWithRow, rowAttrs))
+      Some(newLazyProjection(plan, outputsWithRow, rowAttrs, true))
     } else {
       None
     }
@@ -221,7 +221,7 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
 
     val metadataProjection = if (metadataAttrs.nonEmpty) {
       val outputsWithMetadata = filterOutputs(outputs, DELTA_OPERATIONS_WITH_METADATA)
-      Some(newLazyProjection(plan, outputsWithMetadata, metadataAttrs))
+      Some(newLazyProjection(plan, outputsWithMetadata, metadataAttrs, false))
     } else {
       None
     }
@@ -251,9 +251,10 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
   private def newLazyProjection(
       plan: LogicalPlan,
       outputs: Seq[Seq[Expression]],
-      attrs: Seq[Attribute]): ProjectingInternalRow = {
+      attrs: Seq[Attribute],
+      isRowAttr: Boolean): ProjectingInternalRow = {
     val colOrdinals = attrs.map(attr => findColOrdinal(plan, attr.name))
-    createProjectingInternalRow(outputs, colOrdinals, attrs)
+    createProjectingInternalRow(outputs, colOrdinals, attrs, isRowAttr)
   }
 
   // if there are assignment to row ID attributes, original values are projected as special columns
@@ -266,15 +267,20 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
       val originalValueIndex = findColOrdinal(plan, ORIGINAL_ROW_ID_VALUE_PREFIX + attr.name)
       if (originalValueIndex != -1) originalValueIndex else findColOrdinal(plan, attr.name)
     }
-    createProjectingInternalRow(outputs, colOrdinals, rowIdAttrs)
+    createProjectingInternalRow(outputs, colOrdinals, rowIdAttrs, true)
   }
 
   private def createProjectingInternalRow(
       outputs: Seq[Seq[Expression]],
       colOrdinals: Seq[Int],
-      attrs: Seq[Attribute]): ProjectingInternalRow = {
+      attrs: Seq[Attribute],
+      isRowAttr: Boolean): ProjectingInternalRow = {
     val schema = StructType(attrs.zipWithIndex.map { case (attr, index) =>
-      val nullable = outputs.exists(output => output(colOrdinals(index)).nullable)
+      val nullable = if (!isRowAttr) {
+        outputs.exists(output => output(colOrdinals(index)).nullable)
+      } else {
+        attr.nullable
+      }
       StructField(attr.name, attr.dataType, nullable, attr.metadata)
     })
     ProjectingInternalRow(schema, colOrdinals)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite.scala
index e93d4165be332..af4fc5c80ed4d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite.scala
@@ -82,4 +82,55 @@ class DeltaBasedMergeIntoTableUpdateAsDeleteAndInsertSuite
         insertWriteLogEntry(data = Row(6, 0, "new")))
     }
   }
+
+  test("SPARK-51479: Test Column Nullable") {
+    withTempView("source") {
+      createAndInitTable("pk INT NOT NULL, salary INT, dep STRING",
+        """{ "pk": 1, "salary": 100, "dep": "hr" }
+          |{ "pk": 2, "salary": 200, "dep": "software" }
+          |{ "pk": 3, "salary": 300, "dep": "hr" }
+          |{ "pk": 4, "salary": 400, "dep": "hr" }
+          |{ "pk": 5, "salary": 500, "dep": "hr" }
+          |""".stripMargin)
+
+      val sourceDF = Seq(3, 4, 5, 6).toDF("pk")
+      sourceDF.createOrReplaceTempView("source")
+
+      sql(
+        s"""MERGE INTO $tableNameAsString t
+           |USING source s
+           |ON t.pk = s.pk
+           |WHEN MATCHED THEN
+           | UPDATE SET t.salary = 1000
+           |WHEN NOT MATCHED THEN
+           | INSERT (pk, salary, dep) VALUES (s.pk, 0, 'new')
+           |WHEN NOT MATCHED BY SOURCE AND pk = 1 THEN
+           | DELETE
+           |""".stripMargin)
+
+      checkAnswer(
+        sql(s"SELECT * FROM $tableNameAsString"),
+        Seq(
+          Row(2, 200, "software"), // unchanged
+          Row(3, 1000, "hr"), // update
+          Row(4, 1000, "hr"), // update
+          Row(5, 1000, "hr"), // update
+          Row(6, 0, "new"))) // insert
+
+      checkLastWriteInfo(
+        expectedRowSchema = table.schema,
+        expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
+        expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
+
+      checkLastWriteLog(
+        deleteWriteLogEntry(id = 1, metadata = Row("hr", null)),
+        deleteWriteLogEntry(id = 3, metadata = Row("hr", null)),
+        reinsertWriteLogEntry(metadata = Row("hr", null), data = Row(3, 1000, "hr")),
+        deleteWriteLogEntry(id = 4, metadata = Row("hr", null)),
+        reinsertWriteLogEntry(metadata = Row("hr", null), data = Row(4, 1000, "hr")),
+        deleteWriteLogEntry(id = 5, metadata = Row("hr", null)),
+        reinsertWriteLogEntry(metadata = Row("hr", null), data = Row(5, 1000, "hr")),
+        insertWriteLogEntry(data = Row(6, 0, "new")))
+    }
+  }
 }
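
The patch above keys the projected schema's nullability on the kind of column being projected: row columns take nullability from the target table attribute, while metadata columns remain nullable whenever any per-operation output expression can produce a null. The following minimal, self-contained Scala sketch models that rule; Attr and Expr are hypothetical stand-ins for Catalyst's Attribute and Expression (not Spark's API), and the object name is made up for illustration:

    object NullabilitySketchPatch1 {
      final case class Attr(name: String, nullable: Boolean) // stand-in for a Catalyst Attribute
      final case class Expr(nullable: Boolean)                // stand-in for a projected Expression

      // Mirrors the rule introduced in PATCH 1/4: isRowAttr selects between the two strategies.
      def fieldNullable(
          outputs: Seq[Seq[Expr]], // one projected expression list per DELETE/UPDATE/INSERT output
          ordinal: Int,
          attr: Attr,
          isRowAttr: Boolean): Boolean = {
        if (isRowAttr) {
          attr.nullable                        // row columns: trust the table schema
        } else {
          outputs.exists(_(ordinal).nullable)  // metadata columns: nullable if any output can be null
        }
      }

      def main(args: Array[String]): Unit = {
        val salary = Attr("salary", nullable = true)
        // Every MERGE branch assigns a non-null literal, yet the table column itself is nullable.
        val outputs = Seq(Seq(Expr(nullable = false)), Seq(Expr(nullable = false)))
        println(fieldNullable(outputs, 0, salary, isRowAttr = true))  // true  (behaviour after the patch)
        println(fieldNullable(outputs, 0, salary, isRowAttr = false)) // false (pre-patch behaviour for this column)
      }
    }

Under the pre-patch rule, a nullable salary column written only with non-null constants would be reported to the writer as non-nullable; that mismatch is what the new test and the removed "// input is a constant" expectations in the later patches are about.
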
From 6de21b90e63a347d29f4a04dee66079494d5f9dc Mon Sep 17 00:00:00 2001
From: huaxingao
Date: Wed, 12 Mar 2025 12:03:32 -0700
Subject: [PATCH 2/4] fix test failures

---
 .../DeltaBasedUpdateAsDeleteAndInsertTableSuite.scala | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateAsDeleteAndInsertTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateAsDeleteAndInsertTableSuite.scala
index 612a26e756abd..a768be4f872f8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateAsDeleteAndInsertTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateAsDeleteAndInsertTableSuite.scala
@@ -45,10 +45,7 @@ class DeltaBasedUpdateAsDeleteAndInsertTableSuite extends DeltaBasedUpdateTableS
       Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil)
 
     checkLastWriteInfo(
-      expectedRowSchema = StructType(table.schema.map {
-        case attr if attr.name == "id" => attr.copy(nullable = false) // input is a constant
-        case attr => attr
-      }),
+      expectedRowSchema = table.schema,
       expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
       expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
 
@@ -82,10 +79,7 @@ class DeltaBasedUpdateAsDeleteAndInsertTableSuite extends DeltaBasedUpdateTableS
       Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil)
 
     checkLastWriteInfo(
-      expectedRowSchema = StructType(table.schema.map {
-        case attr if attr.name == "id" => attr.copy(nullable = false) // input is a constant
-        case attr => attr
-      }),
+      expectedRowSchema = table.schema,
       expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
       expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
 

From 3e9af4426d072a0116eebc0bd302ab820ce95356 Mon Sep 17 00:00:00 2001
From: huaxingao
Date: Wed, 12 Mar 2025 14:43:54 -0700
Subject: [PATCH 3/4] fix test failure

---
 .../sql/connector/DeltaBasedUpdateTableSuite.scala | 10 ++--------
 1 file changed, 2 insertions(+), 8 deletions(-)

diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuite.scala
index c9fd5d6e3ff0d..3443f6fdc433b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeltaBasedUpdateTableSuite.scala
@@ -44,10 +44,7 @@ class DeltaBasedUpdateTableSuite extends DeltaBasedUpdateTableSuiteBase {
       Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil)
 
     checkLastWriteInfo(
-      expectedRowSchema = StructType(table.schema.map {
-        case attr if attr.name == "id" => attr.copy(nullable = false) // input is a constant
-        case attr => attr
-      }),
+      expectedRowSchema = table.schema,
       expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
       expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
 
@@ -80,10 +77,7 @@ class DeltaBasedUpdateTableSuite extends DeltaBasedUpdateTableSuiteBase {
       Row(1, -1, "hr") :: Row(2, 2, "software") :: Row(3, 3, "hr") :: Nil)
 
     checkLastWriteInfo(
-      expectedRowSchema = StructType(table.schema.map {
-        case attr if attr.name == "id" => attr.copy(nullable = false) // input is a constant
-        case attr => attr
-      }),
+      expectedRowSchema = table.schema,
       expectedRowIdSchema = Some(StructType(Array(PK_FIELD))),
       expectedMetadataSchema = Some(StructType(Array(PARTITION_FIELD, INDEX_FIELD_NULLABLE))))
 

From c7869ff1354cd30706d3025c974659e763286784 Mon Sep 17 00:00:00 2001
From: huaxingao
Date: Wed, 12 Mar 2025 19:05:24 -0700
Subject: [PATCH 4/4] use attri.nullable for all columns

---
 .../analysis/RewriteRowLevelCommand.scala | 27 +++++++------------
 1 file changed, 10 insertions(+), 17 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala
index de96e229f4a5b..0296483adc20b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelCommand.scala
@@ -190,11 +190,11 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
     val outputs = extractOutputs(plan)
 
     val outputsWithRow = filterOutputs(outputs, Set(WRITE_WITH_METADATA_OPERATION, WRITE_OPERATION))
-    val rowProjection = newLazyProjection(plan, outputsWithRow, rowAttrs, true)
+    val rowProjection = newLazyProjection(plan, outputsWithRow, rowAttrs)
 
     val metadataProjection = if (metadataAttrs.nonEmpty) {
       val outputsWithMetadata = filterOutputs(outputs, Set(WRITE_WITH_METADATA_OPERATION))
-      Some(newLazyProjection(plan, outputsWithMetadata, metadataAttrs, false))
+      Some(newLazyProjection(plan, outputsWithMetadata, metadataAttrs))
     } else {
       None
     }
@@ -211,7 +211,7 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
 
     val rowProjection = if (rowAttrs.nonEmpty) {
       val outputsWithRow = filterOutputs(outputs, DELTA_OPERATIONS_WITH_ROW)
-      Some(newLazyProjection(plan, outputsWithRow, rowAttrs, true))
+      Some(newLazyProjection(plan, outputsWithRow, rowAttrs))
     } else {
       None
     }
@@ -221,7 +221,7 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
 
     val metadataProjection = if (metadataAttrs.nonEmpty) {
       val outputsWithMetadata = filterOutputs(outputs, DELTA_OPERATIONS_WITH_METADATA)
-      Some(newLazyProjection(plan, outputsWithMetadata, metadataAttrs, false))
+      Some(newLazyProjection(plan, outputsWithMetadata, metadataAttrs))
     } else {
       None
     }
@@ -251,10 +251,9 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
   private def newLazyProjection(
       plan: LogicalPlan,
       outputs: Seq[Seq[Expression]],
-      attrs: Seq[Attribute],
-      isRowAttr: Boolean): ProjectingInternalRow = {
+      attrs: Seq[Attribute]): ProjectingInternalRow = {
     val colOrdinals = attrs.map(attr => findColOrdinal(plan, attr.name))
-    createProjectingInternalRow(outputs, colOrdinals, attrs, isRowAttr)
+    createProjectingInternalRow(outputs, colOrdinals, attrs)
   }
 
   // if there are assignment to row ID attributes, original values are projected as special columns
@@ -267,21 +266,15 @@ trait RewriteRowLevelCommand extends Rule[LogicalPlan] {
       val originalValueIndex = findColOrdinal(plan, ORIGINAL_ROW_ID_VALUE_PREFIX + attr.name)
       if (originalValueIndex != -1) originalValueIndex else findColOrdinal(plan, attr.name)
     }
-    createProjectingInternalRow(outputs, colOrdinals, rowIdAttrs, true)
+    createProjectingInternalRow(outputs, colOrdinals, rowIdAttrs)
   }
 
   private def createProjectingInternalRow(
       outputs: Seq[Seq[Expression]],
       colOrdinals: Seq[Int],
-      attrs: Seq[Attribute],
-      isRowAttr: Boolean): ProjectingInternalRow = {
-    val schema = StructType(attrs.zipWithIndex.map { case (attr, index) =>
-      val nullable = if (!isRowAttr) {
-        outputs.exists(output => output(colOrdinals(index)).nullable)
-      } else {
-        attr.nullable
-      }
-      StructField(attr.name, attr.dataType, nullable, attr.metadata)
+      attrs: Seq[Attribute]): ProjectingInternalRow = {
+    val schema = StructType(attrs.zipWithIndex.map { case (attr, _) =>
+      StructField(attr.name, attr.dataType, attr.nullable, attr.metadata)
     })
     ProjectingInternalRow(schema, colOrdinals)
   }
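
The final patch drops the isRowAttr flag again and derives nullability from the attribute for every projected column (row, row ID, and metadata alike), so the row schema reported to the writer follows the table attributes' nullability; that is what the updated tests assert via expectedRowSchema = table.schema. A small illustration of the resulting schema for the test table pk INT NOT NULL, salary INT, dep STRING follows; it uses Spark's public StructType/StructField types, but the Attr helper and object name are made up for the example:

    import org.apache.spark.sql.types._

    object FinalNullabilitySketch {
      // Simplified stand-in for a resolved Catalyst attribute (name, type, nullability).
      final case class Attr(name: String, dataType: DataType, nullable: Boolean)

      // After PATCH 4/4 the projected schema is built from the attribute alone, mirroring
      // StructField(attr.name, attr.dataType, attr.nullable, attr.metadata) in the final code.
      def projectedSchema(attrs: Seq[Attr]): StructType =
        StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable)))

      def main(args: Array[String]): Unit = {
        val rowAttrs = Seq(
          Attr("pk", IntegerType, nullable = false),
          Attr("salary", IntegerType, nullable = true),
          Attr("dep", StringType, nullable = true))
        // Matches table.schema for "pk INT NOT NULL, salary INT, dep STRING".
        println(projectedSchema(rowAttrs).treeString)
      }
    }

Run with spark-sql on the classpath, this prints a struct whose pk field is non-nullable while salary and dep stay nullable, regardless of what the MERGE or UPDATE branches assign.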