diff --git a/spark/v3.2/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala b/spark/v3.2/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala
index c640e4b2c789..455129f2c9d5 100644
--- a/spark/v3.2/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala
@@ -20,6 +20,7 @@
 package org.apache.iceberg.spark.extensions
 
 import org.apache.spark.sql.SparkSessionExtensions
+import org.apache.spark.sql.catalyst.analysis.AlignedRowLevelIcebergCommandCheck
 import org.apache.spark.sql.catalyst.analysis.AlignRowLevelCommandAssignments
 import org.apache.spark.sql.catalyst.analysis.CheckMergeIntoTableConditions
 import org.apache.spark.sql.catalyst.analysis.MergeIntoIcebergTableResolutionCheck
@@ -55,6 +56,7 @@ class IcebergSparkSessionExtensions extends (SparkSessionExtensions => Unit) {
     extensions.injectResolutionRule { _ => RewriteUpdateTable }
     extensions.injectResolutionRule { _ => RewriteMergeIntoTable }
     extensions.injectCheckRule { _ => MergeIntoIcebergTableResolutionCheck }
+    extensions.injectCheckRule { _ => AlignedRowLevelIcebergCommandCheck }
 
     // optimizer extensions
     extensions.injectOptimizerRule { _ => ExtendedSimplifyConditionalsInPredicate }
diff --git a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlignedRowLevelIcebergCommandCheck.scala b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlignedRowLevelIcebergCommandCheck.scala
new file mode 100644
index 000000000000..d915e4f10949
--- /dev/null
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlignedRowLevelIcebergCommandCheck.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.MergeIntoIcebergTable
+import org.apache.spark.sql.catalyst.plans.logical.UpdateIcebergTable
+
+/**
+ * Check rule that rejects Iceberg MERGE INTO and UPDATE plans that are not aligned.
+ *
+ * Visits every node of the plan and throws an [[AnalysisException]] for any
+ * [[MergeIntoIcebergTable]] or [[UpdateIcebergTable]] node whose `aligned` member is false.
+ */
+object AlignedRowLevelIcebergCommandCheck extends (LogicalPlan => Unit) {
+
+  override def apply(plan: LogicalPlan): Unit = {
+    plan foreach {
+      case m: MergeIntoIcebergTable if !m.aligned =>
+        throw new AnalysisException(s"Could not align Iceberg MERGE INTO: $m")
+      case u: UpdateIcebergTable if !u.aligned =>
+        throw new AnalysisException(s"Could not align Iceberg UPDATE: $u")
+      case _ => // OK
+    }
+  }
+}
diff --git a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/expressions/AssignmentUtils.scala b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/expressions/AssignmentUtils.scala
index c58c033f1d39..ce3818922c78 100644
--- a/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/expressions/AssignmentUtils.scala
+++ b/spark/v3.2/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/expressions/AssignmentUtils.scala
@@ -40,7 +40,11 @@ object AssignmentUtils extends SQLConfHelper {
     sameSize && table.output.zip(assignments).forall { case (attr, assignment) =>
       val key = assignment.key
       val value = assignment.value
-      toAssignmentRef(attr) == toAssignmentRef(key) &&
+      // `corresponds` requires both refs to have the same number of name parts; a plain `zip`
+      // would truncate to the shorter ref and accept a key that merely shares a prefix
+      val refsEqual = toAssignmentRef(attr).corresponds(toAssignmentRef(key))(conf.resolver)
+
+      refsEqual &&
         DataType.equalsIgnoreCompatibleNullability(value.dataType, attr.dataType) &&
         (attr.nullable || !value.nullable)
     }
diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
index 6537e31c6757..131c323b0734 100644
--- a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
@@ -1189,6 +1189,36 @@ public void testMergeAlignsUpdateAndInsertActions() {
         sql("SELECT * FROM %s ORDER BY id", tableName));
   }
 
+  @Test
+  public void testMergeMixedCaseAlignsUpdateAndInsertActions() {
+    createAndInitTable("id INT, a INT, b STRING", "{ \"id\": 1, \"a\": 2, \"b\": \"str\" }");
+    createOrReplaceView(
+        "source",
+        "{ \"id\": 1, \"c1\": -2, \"c2\": \"new_str_1\" }\n" +
+        "{ \"id\": 2, \"c1\": -20, \"c2\": \"new_str_2\" }");
+
+    sql("MERGE INTO %s t USING source " +
+        "ON t.iD == source.Id " +
+        "WHEN MATCHED THEN " +
+        " UPDATE SET B = c2, A = c1, t.Id = source.ID " +
+        "WHEN NOT MATCHED THEN " +
+        " INSERT (b, A, iD) VALUES (c2, c1, id)", tableName);
+
+    assertEquals(
+        "Output should match",
+        ImmutableList.of(row(1, -2, "new_str_1"), row(2, -20, "new_str_2")),
+        sql("SELECT * FROM %s ORDER BY id", tableName));
+
+    assertEquals(
+        "Output should match",
+        ImmutableList.of(row(1, -2, "new_str_1")),
+        sql("SELECT * FROM %s WHERE id = 1 ORDER BY id", tableName));
+    assertEquals(
+        "Output should match",
+        ImmutableList.of(row(2, -20, "new_str_2")),
+        sql("SELECT * FROM %s WHERE b = 'new_str_2'ORDER BY id", tableName));
+  }
+
   @Test
   public void testMergeUpdatesNestedStructFields() {
     createAndInitTable("id INT, s STRUCT,m:MAP>>",