diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java
index 4e2851972c28..463bf2a27b0e 100644
--- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java
+++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java
@@ -47,6 +47,7 @@
 import org.apache.iceberg.RowLevelOperationMode;
 import org.apache.iceberg.Snapshot;
 import org.apache.iceberg.Table;
+import org.apache.iceberg.TableProperties;
 import org.apache.iceberg.data.GenericRecord;
 import org.apache.iceberg.exceptions.ValidationException;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
@@ -55,6 +56,7 @@
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors;
 import org.apache.iceberg.spark.Spark3Util;
+import org.apache.iceberg.spark.SparkSQLProperties;
 import org.apache.iceberg.spark.data.TestHelpers;
 import org.apache.iceberg.util.SnapshotUtil;
 import org.apache.spark.SparkException;
@@ -1066,6 +1068,73 @@ public void testDeleteWithMultipleSpecs() {
         sql("SELECT * FROM %s ORDER BY id", selectTarget()));
   }
 
+  @Test
+  public void testDeleteToWapBranch() throws NoSuchTableException {
+    Assume.assumeTrue("WAP branch only works for table identifier without branch", branch == null);
+
+    createAndInitPartitionedTable();
+    sql(
+        "ALTER TABLE %s SET TBLPROPERTIES ('%s' = 'true')",
+        tableName, TableProperties.WRITE_AUDIT_PUBLISH_ENABLED);
+    append(new Employee(0, "hr"), new Employee(1, "hr"), new Employee(2, "hr"));
+
+    withSQLConf(
+        ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"),
+        () -> {
+          sql("DELETE FROM %s t WHERE id=0", tableName);
+          Assert.assertEquals(
+              "Should have expected num of rows when reading table",
+              2L,
+              spark.table(tableName).count());
+          Assert.assertEquals(
+              "Should have expected num of rows when reading WAP branch",
+              2L,
+              spark.table(tableName + ".branch_wap").count());
+          Assert.assertEquals(
+              "Should not modify main branch", 3L, spark.table(tableName + ".branch_main").count());
+        });
+
+    withSQLConf(
+        ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"),
+        () -> {
+          sql("DELETE FROM %s t WHERE id=1", tableName);
+          Assert.assertEquals(
+              "Should have expected num of rows when reading table with multiple writes",
+              1L,
+              spark.table(tableName).count());
+          Assert.assertEquals(
+              "Should have expected num of rows when reading WAP branch with multiple writes",
+              1L,
+              spark.table(tableName + ".branch_wap").count());
+          Assert.assertEquals(
+              "Should not modify main branch with multiple writes",
+              3L,
+              spark.table(tableName + ".branch_main").count());
+        });
+  }
+
+  @Test
+  public void testDeleteToWapBranchWithTableBranchIdentifier() throws NoSuchTableException {
+    Assume.assumeTrue("Test must have branch name part in table identifier", branch != null);
+
+    createAndInitPartitionedTable();
+    sql(
+        "ALTER TABLE %s SET TBLPROPERTIES ('%s' = 'true')",
+        tableName, TableProperties.WRITE_AUDIT_PUBLISH_ENABLED);
+    append(tableName, new Employee(0, "hr"), new Employee(1, "hr"), new Employee(2, "hr"));
+    createBranchIfNeeded();
+
+    withSQLConf(
+        ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"),
+        () ->
+            Assertions.assertThatThrownBy(() -> sql("DELETE FROM %s t WHERE id=0", commitTarget()))
+                .isInstanceOf(ValidationException.class)
+                .hasMessage(
+                    String.format(
+ "Cannot write to both branch and WAP branch, but got branch [%s] and WAP branch [wap]", + branch))); + } + // TODO: multiple stripes for ORC protected void createAndInitPartitionedTable() { diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java index 35f12f6ac83a..dc1e96be48a1 100644 --- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java @@ -47,12 +47,14 @@ import org.apache.iceberg.Snapshot; import org.apache.iceberg.SnapshotSummary; import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; import org.apache.iceberg.data.GenericRecord; import org.apache.iceberg.exceptions.ValidationException; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; +import org.apache.iceberg.spark.SparkSQLProperties; import org.apache.iceberg.util.SnapshotUtil; import org.apache.spark.SparkException; import org.apache.spark.sql.AnalysisException; @@ -2448,6 +2450,96 @@ public void testMergeNonExistingBranch() { .hasMessage("Cannot use branch (does not exist): test"); } + @Test + public void testMergeToWapBranch() { + Assume.assumeTrue("WAP branch only works for table identifier without branch", branch == null); + + createAndInitTable("id INT", "{\"id\": -1}"); + ImmutableList originalRows = ImmutableList.of(row(-1)); + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' = 'true')", + tableName, TableProperties.WRITE_AUDIT_PUBLISH_ENABLED); + spark.range(0, 5).coalesce(1).createOrReplaceTempView("source"); + ImmutableList expectedRows = + ImmutableList.of(row(-1), row(0), row(1), row(2), row(3), row(4)); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"), + () -> { + sql( + "MERGE INTO %s t USING source s ON t.id = s.id " + + "WHEN MATCHED THEN UPDATE SET *" + + "WHEN NOT MATCHED THEN INSERT *", + tableName); + assertEquals( + "Should have expected rows when reading table", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", tableName)); + assertEquals( + "Should have expected rows when reading WAP branch", + expectedRows, + sql("SELECT * FROM %s.branch_wap ORDER BY id", tableName)); + assertEquals( + "Should not modify main branch", + originalRows, + sql("SELECT * FROM %s.branch_main ORDER BY id", tableName)); + }); + + spark.range(3, 6).coalesce(1).createOrReplaceTempView("source2"); + ImmutableList expectedRows2 = + ImmutableList.of(row(-1), row(0), row(1), row(2), row(5)); + withSQLConf( + ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"), + () -> { + sql( + "MERGE INTO %s t USING source2 s ON t.id = s.id " + + "WHEN MATCHED THEN DELETE " + + "WHEN NOT MATCHED THEN INSERT *", + tableName); + assertEquals( + "Should have expected rows when reading table with multiple writes", + expectedRows2, + sql("SELECT * FROM %s ORDER BY id", tableName)); + assertEquals( + "Should have expected rows when reading WAP branch with multiple writes", + expectedRows2, + sql("SELECT * FROM %s.branch_wap ORDER BY id", tableName)); + assertEquals( + "Should not modify main branch with multiple writes", + originalRows, + sql("SELECT * FROM %s.branch_main 
ORDER BY id", tableName)); + }); + } + + @Test + public void testMergeToWapBranchWithTableBranchIdentifier() { + Assume.assumeTrue("Test must have branch name part in table identifier", branch != null); + + createAndInitTable("id INT", "{\"id\": -1}"); + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' = 'true')", + tableName, TableProperties.WRITE_AUDIT_PUBLISH_ENABLED); + spark.range(0, 5).coalesce(1).createOrReplaceTempView("source"); + ImmutableList expectedRows = + ImmutableList.of(row(-1), row(0), row(1), row(2), row(3), row(4)); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"), + () -> + Assertions.assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s ON t.id = s.id " + + "WHEN MATCHED THEN UPDATE SET *" + + "WHEN NOT MATCHED THEN INSERT *", + commitTarget())) + .isInstanceOf(ValidationException.class) + .hasMessage( + String.format( + "Cannot write to both branch and WAP branch, but got branch [%s] and WAP branch [wap]", + branch))); + } + private void checkJoinAndFilterConditions(String query, String join, String icebergFilters) { // disable runtime filtering for easier validation withSQLConf( diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java index f9230915d9e1..8093e6fc0984 100644 --- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java @@ -48,6 +48,7 @@ import org.apache.iceberg.Snapshot; import org.apache.iceberg.SnapshotSummary; import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; import org.apache.iceberg.data.GenericRecord; import org.apache.iceberg.exceptions.ValidationException; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; @@ -1257,6 +1258,74 @@ public void testUpdateOnNonIcebergTableNotSupported() { () -> sql("UPDATE %s SET c1 = -1 WHERE c2 = 1", "testtable")); } + @Test + public void testUpdateToWAPBranch() { + Assume.assumeTrue("WAP branch only works for table identifier without branch", branch == null); + + createAndInitTable( + "id INT, dep STRING", "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"a\" }"); + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' = 'true')", + tableName, TableProperties.WRITE_AUDIT_PUBLISH_ENABLED); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"), + () -> { + sql("UPDATE %s SET dep='hr' WHERE dep='a'", tableName); + Assert.assertEquals( + "Should have expected num of rows when reading table", + 2L, + sql("SELECT * FROM %s WHERE dep='hr'", tableName).size()); + Assert.assertEquals( + "Should have expected num of rows when reading WAP branch", + 2L, + sql("SELECT * FROM %s.branch_wap WHERE dep='hr'", tableName).size()); + Assert.assertEquals( + "Should not modify main branch", + 1L, + sql("SELECT * FROM %s.branch_main WHERE dep='hr'", tableName).size()); + }); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"), + () -> { + sql("UPDATE %s SET dep='b' WHERE dep='hr'", tableName); + Assert.assertEquals( + "Should have expected num of rows when reading table with multiple writes", + 2L, + sql("SELECT * FROM %s WHERE dep='b'", tableName).size()); + Assert.assertEquals( + "Should have expected num of rows when reading WAP branch with multiple writes", + 2L, + sql("SELECT * FROM %s.branch_wap WHERE dep='b'", 
tableName).size()); + Assert.assertEquals( + "Should not modify main branch with multiple writes", + 0L, + sql("SELECT * FROM %s.branch_main WHERE dep='b'", tableName).size()); + }); + } + + @Test + public void testUpdateToWapBranchWithTableBranchIdentifier() { + Assume.assumeTrue("Test must have branch name part in table identifier", branch != null); + + createAndInitTable("id INT, dep STRING", "{ \"id\": 1, \"dep\": \"hr\" }"); + sql( + "ALTER TABLE %s SET TBLPROPERTIES ('%s' = 'true')", + tableName, TableProperties.WRITE_AUDIT_PUBLISH_ENABLED); + + withSQLConf( + ImmutableMap.of(SparkSQLProperties.WAP_BRANCH, "wap"), + () -> + Assertions.assertThatThrownBy( + () -> sql("UPDATE %s SET dep='hr' WHERE dep='a'", commitTarget())) + .isInstanceOf(ValidationException.class) + .hasMessage( + String.format( + "Cannot write to both branch and WAP branch, but got branch [%s] and WAP branch [wap]", + branch))); + } + private RowLevelOperationMode mode(Table table) { String modeName = table.properties().getOrDefault(UPDATE_MODE, UPDATE_MODE_DEFAULT); return RowLevelOperationMode.fromName(modeName); diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java index 1d2576180c24..f4e27aa09b0f 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java @@ -98,7 +98,22 @@ public String branch() { + "got [%s] in identifier and [%s] in options", branch, optionBranch); - return branch != null ? branch : optionBranch; + String inputBranch = branch != null ? branch : optionBranch; + if (inputBranch != null) { + return inputBranch; + } + + boolean wapEnabled = + PropertyUtil.propertyAsBoolean( + table.properties(), TableProperties.WRITE_AUDIT_PUBLISH_ENABLED, false); + if (wapEnabled) { + String wapBranch = spark.conf().get(SparkSQLProperties.WAP_BRANCH, null); + if (wapBranch != null && table.refs().containsKey(wapBranch)) { + return wapBranch; + } + } + + return null; } public String tag() { diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java index d36ce76f6226..d7ff4311c907 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkSQLProperties.java @@ -55,4 +55,12 @@ private SparkSQLProperties() {} // Controls write distribution mode public static final String DISTRIBUTION_MODE = "spark.sql.iceberg.distribution-mode"; + + // Controls the WAP ID used for write-audit-publish workflow. + // When set, new snapshots will be staged with this ID in snapshot summary. + public static final String WAP_ID = "spark.wap.id"; + + // Controls the WAP branch used for write-audit-publish workflow. + // When set, new snapshots will be committed to this branch. 
+  public static final String WAP_BRANCH = "spark.wap.branch";
 }
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
index 8e88a9b9bdf0..87b2f0b25879 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
@@ -30,6 +30,7 @@
 import org.apache.iceberg.SnapshotSummary;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.TableProperties;
+import org.apache.iceberg.exceptions.ValidationException;
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 import org.apache.spark.sql.RuntimeConfig;
 import org.apache.spark.sql.SparkSession;
@@ -128,7 +129,7 @@ public boolean wapEnabled() {
   }
 
   public String wapId() {
-    return sessionConf.get("spark.wap.id", null);
+    return sessionConf.get(SparkSQLProperties.WAP_ID, null);
   }
 
   public boolean mergeSchema() {
@@ -333,6 +334,28 @@ public boolean caseSensitive() {
   }
 
   public String branch() {
+    if (wapEnabled()) {
+      String wapId = wapId();
+      String wapBranch =
+          confParser.stringConf().sessionConf(SparkSQLProperties.WAP_BRANCH).parseOptional();
+
+      ValidationException.check(
+          wapId == null || wapBranch == null,
+          "Cannot set both WAP ID and branch, but got ID [%s] and branch [%s]",
+          wapId,
+          wapBranch);
+
+      if (wapBranch != null) {
+        ValidationException.check(
+            branch == null,
+            "Cannot write to both branch and WAP branch, but got branch [%s] and WAP branch [%s]",
+            branch,
+            wapBranch);
+
+        return wapBranch;
+      }
+    }
+
     return branch;
   }
 }
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction.java
index 1c38d616970b..536dd5febbaa 100644
--- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction.java
+++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/actions/TestRemoveOrphanFilesAction.java
@@ -67,6 +67,7 @@
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 import org.apache.iceberg.relocated.com.google.common.collect.Sets;
+import org.apache.iceberg.spark.SparkSQLProperties;
 import org.apache.iceberg.spark.SparkTestBase;
 import org.apache.iceberg.spark.actions.DeleteOrphanFilesSparkAction.StringToFileURI;
 import org.apache.iceberg.spark.source.FilePathLastModifiedRecord;
@@ -319,7 +320,7 @@ public void testWapFilesAreKept() throws InterruptedException {
     // normal write
     df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation);
 
-    spark.conf().set("spark.wap.id", "1");
+    spark.conf().set(SparkSQLProperties.WAP_ID, "1");
 
     // wap write
     df.select("c1", "c2", "c3").write().format("iceberg").mode("append").save(tableLocation);
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceTablesBase.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceTablesBase.java
index 0f6ae3f20d77..178c52b840ca 100644
--- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceTablesBase.java
+++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestIcebergSourceTablesBase.java
@@ -62,6 +62,7 @@
 import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.spark.SparkReadOptions;
+import org.apache.iceberg.spark.SparkSQLProperties;
 import org.apache.iceberg.spark.SparkSchemaUtil;
 import org.apache.iceberg.spark.SparkTableUtil;
 import org.apache.iceberg.spark.SparkTestBase;
@@ -644,7 +645,7 @@ public void testAllMetadataTablesWithStagedCommits() throws Exception {
     Table table = createTable(tableIdentifier, SCHEMA, SPEC);
     table.updateProperties().set(TableProperties.WRITE_AUDIT_PUBLISH_ENABLED, "true").commit();
 
-    spark.conf().set("spark.wap.id", "1234567");
+    spark.conf().set(SparkSQLProperties.WAP_ID, "1234567");
     Dataset<Row> df1 =
         spark.createDataFrame(Lists.newArrayList(new SimpleRecord(1, "a")), SimpleRecord.class);
     Dataset<Row> df2 =
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesToWapBranch.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesToWapBranch.java
new file mode 100644
index 000000000000..a65e94ee6e62
--- /dev/null
+++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesToWapBranch.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.sql;
+
+import java.util.Map;
+import java.util.UUID;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.TableProperties;
+import org.apache.iceberg.exceptions.ValidationException;
+import org.apache.iceberg.spark.SparkSQLProperties;
+import org.assertj.core.api.Assertions;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+public class TestPartitionedWritesToWapBranch extends PartitionedWritesTestBase {
+
+  private static final String BRANCH = "test";
+
+  public TestPartitionedWritesToWapBranch(
+      String catalogName, String implementation, Map<String, String> config) {
+    super(catalogName, implementation, config);
+  }
+
+  @Before
+  @Override
+  public void createTables() {
+    spark.conf().set(SparkSQLProperties.WAP_BRANCH, BRANCH);
+    sql(
+        "CREATE TABLE %s (id bigint, data string) USING iceberg PARTITIONED BY (truncate(id, 3)) OPTIONS (%s = 'true')",
+        tableName, TableProperties.WRITE_AUDIT_PUBLISH_ENABLED);
+    sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName);
+  }
+
+  @After
+  @Override
+  public void removeTables() {
+    super.removeTables();
+    spark.conf().unset(SparkSQLProperties.WAP_BRANCH);
+    spark.conf().unset(SparkSQLProperties.WAP_ID);
+  }
+
+  @Override
+  protected String commitTarget() {
+    return tableName;
+  }
+
+  @Override
+  protected String selectTarget() {
+    return String.format("%s VERSION AS OF '%s'", tableName, BRANCH);
+  }
+
+  @Test
+  public void testBranchAndWapBranchCannotBothBeSetForWrite() {
+    Table table = validationCatalog.loadTable(tableIdent);
+    table.manageSnapshots().createBranch("test2", table.refs().get(BRANCH).snapshotId()).commit();
+    sql("REFRESH TABLE " + tableName);
+    Assertions.assertThatThrownBy(
+            () -> sql("INSERT INTO %s.branch_test2 VALUES (4, 'd')", tableName))
+        .isInstanceOf(ValidationException.class)
+        .hasMessage(
+            "Cannot write to both branch and WAP branch, but got branch [test2] and WAP branch [%s]",
+            BRANCH);
+  }
+
+  @Test
+  public void testWapIdAndWapBranchCannotBothBeSetForWrite() {
+    String wapId = UUID.randomUUID().toString();
+    spark.conf().set(SparkSQLProperties.WAP_ID, wapId);
+    Assertions.assertThatThrownBy(() -> sql("INSERT INTO %s VALUES (4, 'd')", tableName))
+        .isInstanceOf(ValidationException.class)
+        .hasMessage(
+            "Cannot set both WAP ID and branch, but got ID [%s] and branch [%s]", wapId, BRANCH);
+  }
+}
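Reviewer note: below is a minimal sketch of the session-level workflow this patch enables, pieced together from the tests above. The catalog/table name (db.events), the branch name, and the driver class are hypothetical; the publish step that moves main to the audited branch is outside the scope of this diff. It only uses APIs that appear in the patch (SparkSQLProperties.WAP_BRANCH, TableProperties.WRITE_AUDIT_PUBLISH_ENABLED, branch_<name> metadata reads) plus standard SparkSession calls.

import org.apache.iceberg.TableProperties;
import org.apache.iceberg.spark.SparkSQLProperties;
import org.apache.spark.sql.SparkSession;

public class WapBranchWriteSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().getOrCreate();

    // Opt the table in to write-audit-publish (hypothetical table name).
    spark.sql(
        String.format(
            "ALTER TABLE db.events SET TBLPROPERTIES ('%s' = 'true')",
            TableProperties.WRITE_AUDIT_PUBLISH_ENABLED));

    // While spark.wap.branch is set, DML commits land on the named branch instead of main.
    spark.conf().set(SparkSQLProperties.WAP_BRANCH, "wap");
    spark.sql("DELETE FROM db.events WHERE id = 0");

    // Audit: once the WAP branch exists, reads of the plain table name resolve to it
    // (per the SparkReadConf change); main is untouched until the branch is published.
    spark.table("db.events").show();
    spark.table("db.events.branch_main").show();

    // Clear the session property when the audit session is done.
    spark.conf().unset(SparkSQLProperties.WAP_BRANCH);
  }
}

Setting both spark.wap.id and spark.wap.branch, or combining spark.wap.branch with an explicit branch_<name> write target, fails with the ValidationException messages asserted in the tests above.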