diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
index f48e8fc064c7..62d157815eef 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
@@ -310,6 +310,7 @@ public String branch() {
     return confParser
         .stringConf()
         .option(SparkWriteOptions.BRANCH)
+        .sessionConf(SparkWriteOptions.BRANCH)
         .defaultValue(SnapshotRef.MAIN_BRANCH)
         .parse();
   }
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java
index 4252d0afd76f..b6c0b4e8d5fa 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java
@@ -77,6 +77,7 @@ private SparkWriteOptions() {}
 
   // Isolation Level for DataFrame calls. Currently supported by overwritePartitions
   public static final String ISOLATION_LEVEL = "isolation-level";
 
+  // Branch to write to
   public static final String BRANCH = "branch";
 }
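Note: with the sessionConf fallback added above, the target branch now resolves in order: per-write option first, then Spark session conf, then the SnapshotRef.MAIN_BRANCH default. A minimal sketch of the two entry points (the df, spark, and location names and the "audit" branch are illustrative; the session key is the same "branch" string passed to sessionConf):

    // Per-write option, takes precedence over everything else:
    df.write()
        .format("iceberg")
        .option(SparkWriteOptions.BRANCH, "audit")
        .mode(SaveMode.Append)
        .save(location);

    // Session-wide fallback, picked up when no per-write option is given:
    spark.conf().set(SparkWriteOptions.BRANCH, "audit");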
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWrite.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWrite.java
index 6c278e131d74..ea0f3b8133a2 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWrite.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWrite.java
@@ -100,6 +100,7 @@ class SparkPositionDeltaWrite implements DeltaWrite, RequiresDistributionAndOrde
   private final Map<String, String> extraSnapshotMetadata;
   private final Distribution requiredDistribution;
   private final SortOrder[] requiredOrdering;
+  private final String branch;
 
   private boolean cleanupOnAbort = true;
 
@@ -123,6 +124,7 @@ class SparkPositionDeltaWrite implements DeltaWrite, RequiresDistributionAndOrde
     this.applicationId = spark.sparkContext().applicationId();
     this.wapEnabled = writeConf.wapEnabled();
     this.wapId = writeConf.wapId();
+    this.branch = writeConf.branch();
     this.extraSnapshotMetadata = writeConf.extraSnapshotMetadata();
     this.requiredDistribution = requiredDistribution;
     this.requiredOrdering = requiredOrdering;
@@ -277,6 +279,7 @@ private void commitOperation(SnapshotUpdate<?> operation, String description) {
     try {
       long start = System.currentTimeMillis();
+      operation.toBranch(branch);
       operation.commit(); // abort is automatically called if this fails
       long duration = System.currentTimeMillis() - start;
       LOG.info("Committed in {} ms", duration);
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkTable.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkTable.java
index ee528a15f4a8..63a7123cad69 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkTable.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkTable.java
@@ -50,7 +50,6 @@
 import org.apache.iceberg.spark.SparkFilters;
 import org.apache.iceberg.spark.SparkReadOptions;
 import org.apache.iceberg.spark.SparkSchemaUtil;
-import org.apache.iceberg.spark.SparkWriteOptions;
 import org.apache.iceberg.util.PropertyUtil;
 import org.apache.iceberg.util.SnapshotUtil;
 import org.apache.spark.sql.SparkSession;
@@ -248,11 +247,8 @@ public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
 
   @Override
   public WriteBuilder newWriteBuilder(LogicalWriteInfo info) {
-    boolean branchOptionPresent = info.options().containsKey(SparkWriteOptions.BRANCH);
-    if (!branchOptionPresent) {
-      Preconditions.checkArgument(
-          snapshotId == null, "Cannot write to table at a specific snapshot: %s", snapshotId);
-    }
+    Preconditions.checkArgument(
+        snapshotId == null, "Cannot write to table at a specific snapshot: %s", snapshotId);
     return new SparkWriteBuilder(sparkSession(), icebergTable, info);
   }
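Note: the position delta (row-level MERGE/UPDATE/DELETE) path now carries the branch from SparkWriteConf and applies it once in commitOperation, so every SnapshotUpdate is staged against the configured branch before committing; because the branch defaults to "main", existing writes behave as before. SparkTable also stops special-casing the branch write option and now always rejects writes to a snapshot-pinned table. The underlying core API call is SnapshotUpdate.toBranch, e.g. (a sketch; dataFile and the "audit" branch name are illustrative):

    table.newAppend()
        .appendFile(dataFile)   // files produced by the write tasks
        .toBranch("audit")      // commit against the "audit" ref instead of main
        .commit();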
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java
index 5538d35f1bf1..93e76c666768 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java
@@ -220,6 +220,7 @@ private void commitOperation(SnapshotUpdate<?> operation, String description) {
     try {
       long start = System.currentTimeMillis();
+      operation.toBranch(branch);
       operation.commit(); // abort is automatically called if this fails
       long duration = System.currentTimeMillis() - start;
       LOG.info("Committed in {} ms", duration);
@@ -292,7 +293,7 @@ public String toString() {
   private class BatchAppend extends BaseBatchWrite {
     @Override
     public void commit(WriterCommitMessage[] messages) {
-      AppendFiles append = table.newAppend().toBranch(branch);
+      AppendFiles append = table.newAppend();
 
       int numFiles = 0;
       for (DataFile file : files(messages)) {
@@ -314,7 +315,7 @@ public void commit(WriterCommitMessage[] messages) {
         return;
       }
 
-      ReplacePartitions dynamicOverwrite = table.newReplacePartitions().toBranch(branch);
+      ReplacePartitions dynamicOverwrite = table.newReplacePartitions();
 
       IsolationLevel isolationLevel = writeConf.isolationLevel();
       Long validateFromSnapshotId = writeConf.validateFromSnapshotId();
@@ -352,8 +353,7 @@ private OverwriteByFilter(Expression overwriteExpr) {
 
     @Override
     public void commit(WriterCommitMessage[] messages) {
-      OverwriteFiles overwriteFiles =
-          table.newOverwrite().toBranch(branch).overwriteByRowFilter(overwriteExpr);
+      OverwriteFiles overwriteFiles = table.newOverwrite().overwriteByRowFilter(overwriteExpr);
 
       int numFiles = 0;
       for (DataFile file : files(messages)) {
@@ -414,7 +414,7 @@ private Expression conflictDetectionFilter() {
 
     @Override
     public void commit(WriterCommitMessage[] messages) {
-      OverwriteFiles overwriteFiles = table.newOverwrite().toBranch(branch);
+      OverwriteFiles overwriteFiles = table.newOverwrite();
 
       List<DataFile> overwrittenFiles = overwrittenFiles();
       int numOverwrittenFiles = overwrittenFiles.size();
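Note: SparkWrite gets the same treatment: the per-operation toBranch calls in BatchAppend, DynamicOverwrite, OverwriteByFilter, and the copy-on-write commit all move into the shared commitOperation. End to end, a branch write and read-back then looks like this (a sketch mirroring the tests below; location and the "audit" branch are illustrative):

    df.write()
        .format("iceberg")
        .option("branch", "audit")
        .mode(SaveMode.Append)
        .save(location);

    Dataset<Row> result =
        spark.read().format("iceberg").option("branch", "audit").load(location);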
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataWrite.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataWrite.java
index ea7b4a21ef17..3646377e9909 100644
--- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataWrite.java
+++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataWrite.java
@@ -36,6 +36,7 @@
 import org.apache.iceberg.ManifestFiles;
 import org.apache.iceberg.PartitionSpec;
 import org.apache.iceberg.Schema;
+import org.apache.iceberg.Snapshot;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.TableProperties;
 import org.apache.iceberg.exceptions.CommitStateUnknownException;
@@ -71,7 +72,7 @@ public class TestSparkDataWrite {
 
   @Rule public TemporaryFolder temp = new TemporaryFolder();
 
-  private String branch;
+  private String targetBranch;
 
   @Parameterized.Parameters(name = "format = {0}, branch = {1}")
   public static Object[] parameters() {
@@ -102,9 +103,9 @@ public static void stopSpark() {
     currentSpark.stop();
   }
 
-  public TestSparkDataWrite(String format, String branch) {
+  public TestSparkDataWrite(String format, String targetBranch) {
     this.format = FileFormat.fromString(format);
-    this.branch = branch;
+    this.targetBranch = targetBranch;
   }
 
   @Test
@@ -127,19 +128,19 @@ public void testBasicWrite() throws IOException {
         .format("iceberg")
         .option(SparkWriteOptions.WRITE_FORMAT, format.toString())
         .mode(SaveMode.Append)
-        .option("branch", branch)
+        .option("branch", targetBranch)
         .save(location.toString());
 
     table.refresh();
 
     Dataset<Row> result =
-        spark.read().format("iceberg").option("branch", branch).load(location.toString());
+        spark.read().format("iceberg").option("branch", targetBranch).load(location.toString());
 
     List<SimpleRecord> actual =
         result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
 
     Assert.assertEquals("Number of rows should match", expected.size(), actual.size());
     Assert.assertEquals("Result rows should match", expected, actual);
 
-    for (ManifestFile manifest : table.snapshot(branch).allManifests(table.io())) {
+    for (ManifestFile manifest : latestSnapshot(table, targetBranch).allManifests(table.io())) {
       for (DataFile file : ManifestFiles.read(manifest, table.io())) {
         // TODO: avro not support split
         if (!format.equals(FileFormat.AVRO)) {
@@ -187,7 +188,7 @@ public void testAppend() throws IOException {
         .format("iceberg")
         .option(SparkWriteOptions.WRITE_FORMAT, format.toString())
         .mode(SaveMode.Append)
-        .option("branch", branch)
+        .option("branch", targetBranch)
         .save(location.toString());
 
     df.withColumn("id", df.col("id").plus(3))
         .select("id", "data")
         .write()
         .format("iceberg")
         .option(SparkWriteOptions.WRITE_FORMAT, format.toString())
         .mode(SaveMode.Append)
-        .option("branch", branch)
+        .option("branch", targetBranch)
         .save(location.toString());
 
     table.refresh();
 
     Dataset<Row> result =
-        spark.read().format("iceberg").option("branch", branch).load(location.toString());
+        spark.read().format("iceberg").option("branch", targetBranch).load(location.toString());
 
     List<SimpleRecord> actual =
         result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
@@ -231,7 +232,7 @@ public void testEmptyOverwrite() throws IOException {
         .format("iceberg")
         .option(SparkWriteOptions.WRITE_FORMAT, format.toString())
         .mode(SaveMode.Append)
-        .option("branch", branch)
+        .option("branch", targetBranch)
         .save(location.toString());
 
     Dataset<Row> empty = spark.createDataFrame(ImmutableList.of(), SimpleRecord.class);
@@ -242,13 +243,13 @@
         .option(SparkWriteOptions.WRITE_FORMAT, format.toString())
         .mode(SaveMode.Overwrite)
         .option("overwrite-mode", "dynamic")
-        .option("branch", branch)
+        .option("branch", targetBranch)
         .save(location.toString());
 
     table.refresh();
 
     Dataset<Row> result =
-        spark.read().format("iceberg").option("branch", branch).load(location.toString());
+        spark.read().format("iceberg").option("branch", targetBranch).load(location.toString());
 
     List<SimpleRecord> actual =
         result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
@@ -284,7 +285,7 @@ public void testOverwrite() throws IOException {
         .format("iceberg")
         .option(SparkWriteOptions.WRITE_FORMAT, format.toString())
         .mode(SaveMode.Append)
-        .option("branch", branch)
+        .option("branch", targetBranch)
         .save(location.toString());
 
     // overwrite with 2*id to replace record 2, append 4 and 6
@@ -295,13 +296,13 @@ public void testOverwrite() throws IOException {
         .option(SparkWriteOptions.WRITE_FORMAT, format.toString())
         .mode(SaveMode.Overwrite)
         .option("overwrite-mode", "dynamic")
-        .option("branch", branch)
+        .option("branch", targetBranch)
         .save(location.toString());
 
     table.refresh();
 
     Dataset<Row> result =
-        spark.read().format("iceberg").option("branch", branch).load(location.toString());
+        spark.read().format("iceberg").option("branch", targetBranch).load(location.toString());
 
     List<SimpleRecord> actual =
         result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
@@ -329,7 +330,7 @@ public void testUnpartitionedOverwrite() throws IOException {
         .format("iceberg")
         .option(SparkWriteOptions.WRITE_FORMAT, format.toString())
         .mode(SaveMode.Append)
-        .option("branch", branch)
+        .option("branch", targetBranch)
         .save(location.toString());
 
     // overwrite with the same data; should not produce two copies
@@ -338,13 +339,13 @@
         .format("iceberg")
         .option(SparkWriteOptions.WRITE_FORMAT, format.toString())
         .mode(SaveMode.Overwrite)
-        .option("branch", branch)
+        .option("branch", targetBranch)
         .save(location.toString());
 
     table.refresh();
 
     Dataset<Row> result =
-        spark.read().format("iceberg").option("branch", branch).load(location.toString());
+        spark.read().format("iceberg").option("branch", targetBranch).load(location.toString());
 
     List<SimpleRecord> actual =
         result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
@@ -378,13 +379,13 @@ public void testUnpartitionedCreateWithTargetFileSizeViaTableProperties() throws
         .format("iceberg")
         .option(SparkWriteOptions.WRITE_FORMAT, format.toString())
         .mode(SaveMode.Append)
-        .option("branch", branch)
+        .option("branch", targetBranch)
         .save(location.toString());
 
     table.refresh();
 
     Dataset<Row> result =
-        spark.read().format("iceberg").option("branch", branch).load(location.toString());
+        spark.read().format("iceberg").option("branch", targetBranch).load(location.toString());
 
     List<SimpleRecord> actual =
         result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
@@ -392,7 +393,7 @@ public void testUnpartitionedCreateWithTargetFileSizeViaTableProperties() throws
     Assert.assertEquals("Result rows should match", expected, actual);
 
     List<DataFile> files = Lists.newArrayList();
-    for (ManifestFile manifest : table.snapshot(branch).allManifests(table.io())) {
+    for (ManifestFile manifest : latestSnapshot(table, targetBranch).allManifests(table.io())) {
       for (DataFile file : ManifestFiles.read(manifest, table.io())) {
         files.add(file);
       }
@@ -674,6 +675,14 @@ public void testCommitUnknownException() throws IOException {
     Assert.assertEquals("Result rows should match", records, actual);
   }
 
+  private Snapshot latestSnapshot(Table table, String branch) {
+    if ("main".equals(branch)) {
+      return table.currentSnapshot();
+    } else {
+      return table.snapshot(branch);
+    }
+  }
+
   public enum IcebergOptionsType {
     NONE,
     TABLE,
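Note: the tests rename the branch field to targetBranch and resolve snapshots through the new latestSnapshot helper, which falls back to currentSnapshot() when the target is "main", presumably because a freshly written table may not yet expose a "main" ref through snapshot(String). A natural follow-on assertion, not part of this diff, would check that a non-main branch write does not advance main (a sketch; it assumes the table already has a main snapshot and an "audit" branch):

    Snapshot mainHead = table.currentSnapshot();
    Snapshot branchHead = table.snapshot("audit");
    Assert.assertNotEquals(
        "Branch write should not advance main", mainHead.snapshotId(), branchHead.snapshotId());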