diff --git a/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
index 8a01b85e8fd3..e964aae1ac30 100644
--- a/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
+++ b/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
@@ -20,10 +20,12 @@
 import java.util.Locale;
 import java.util.Map;
+import java.util.Set;
 import org.apache.iceberg.FileFormat;
 import org.apache.iceberg.SnapshotSummary;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.TableProperties;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 import org.apache.spark.sql.RuntimeConfig;
 import org.apache.spark.sql.SparkSession;
@@ -51,11 +53,15 @@ public class SparkWriteConf {
   private final RuntimeConfig sessionConf;
   private final Map<String, String> writeOptions;
   private final SparkConfParser confParser;
+  private final int currentSpecId;
+  private final Set<Integer> partitionSpecIds;
 
   public SparkWriteConf(SparkSession spark, Table table, Map<String, String> writeOptions) {
     this.sessionConf = spark.conf();
     this.writeOptions = writeOptions;
     this.confParser = new SparkConfParser(spark, table, writeOptions);
+    this.currentSpecId = table.spec().specId();
+    this.partitionSpecIds = table.specs().keySet();
   }
 
   public boolean checkNullability() {
@@ -115,6 +121,20 @@ public String wapId() {
     return sessionConf.get("spark.wap.id", null);
   }
 
+  public int outputSpecId() {
+    int outputSpecId =
+        confParser
+            .intConf()
+            .option(SparkWriteOptions.OUTPUT_SPEC_ID)
+            .defaultValue(currentSpecId)
+            .parse();
+    Preconditions.checkArgument(
+        partitionSpecIds.contains(outputSpecId),
+        "Output spec id %s is not a valid spec id for table",
+        outputSpecId);
+    return outputSpecId;
+  }
+
   public FileFormat dataFileFormat() {
     String valueAsString =
         confParser
diff --git a/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java b/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java
index ef25f871aed7..52cb1537dfaa 100644
--- a/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java
+++ b/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java
@@ -51,5 +51,7 @@ private SparkWriteOptions() {}
   public static final String HANDLE_TIMESTAMP_WITHOUT_TIMEZONE =
       "handle-timestamp-without-timezone";
 
+  public static final String OUTPUT_SPEC_ID = "output-spec-id";
+
   public static final String OVERWRITE_MODE = "overwrite-mode";
 }
diff --git a/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java b/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java
index 8d955bdd21e8..e055c87f4586 100644
--- a/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java
+++ b/spark/v3.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java
@@ -99,6 +99,7 @@ class SparkWrite {
   private final String applicationId;
   private final boolean wapEnabled;
   private final String wapId;
+  private final int outputSpecId;
   private final long targetFileSize;
   private final Schema writeSchema;
   private final StructType dsSchema;
@@ -127,6 +128,7 @@ class SparkWrite {
     this.dsSchema = dsSchema;
     this.extraSnapshotMetadata = writeConf.extraSnapshotMetadata();
     this.partitionedFanoutEnabled = writeConf.fanoutWriterEnabled();
+    this.outputSpecId = writeConf.outputSpecId();
   }
 
   BatchWrite asBatchAppend() {
@@ -163,7 +165,13 @@ private WriterFactory createWriterFactory() {
     Broadcast<Table> tableBroadcast =
         sparkContext.broadcast(SerializableTableWithSize.copyOf(table));
     return new WriterFactory(
-        tableBroadcast, format, targetFileSize, writeSchema, dsSchema, partitionedFanoutEnabled);
+        tableBroadcast,
+        format,
+        outputSpecId,
+        targetFileSize,
+        writeSchema,
+        dsSchema,
+        partitionedFanoutEnabled);
   }
 
   private void commitOperation(SnapshotUpdate<?> operation, String description) {
@@ -558,6 +566,7 @@ DataFile[] files() {
   private static class WriterFactory implements DataWriterFactory, StreamingDataWriterFactory {
     private final Broadcast<Table> tableBroadcast;
     private final FileFormat format;
+    private final int outputSpecId;
     private final long targetFileSize;
     private final Schema writeSchema;
     private final StructType dsSchema;
@@ -566,12 +575,14 @@ private static class WriterFactory implements DataWriterFactory, StreamingDataWr
     protected WriterFactory(
         Broadcast<Table> tableBroadcast,
         FileFormat format,
+        int outputSpecId,
         long targetFileSize,
         Schema writeSchema,
         StructType dsSchema,
         boolean partitionedFanoutEnabled) {
       this.tableBroadcast = tableBroadcast;
       this.format = format;
+      this.outputSpecId = outputSpecId;
       this.targetFileSize = targetFileSize;
       this.writeSchema = writeSchema;
       this.dsSchema = dsSchema;
@@ -586,7 +597,7 @@ public DataWriter<InternalRow> createWriter(int partitionId, long taskId) {
     @Override
     public DataWriter<InternalRow> createWriter(int partitionId, long taskId, long epochId) {
       Table table = tableBroadcast.value();
-      PartitionSpec spec = table.spec();
+      PartitionSpec spec = table.specs().get(outputSpecId);
       FileIO io = table.io();
 
       OutputFileFactory fileFactory =
diff --git a/spark/v3.1/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWrites.java b/spark/v3.1/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWrites.java
index 51c56ac79d4d..6f65d469255c 100644
--- a/spark/v3.1/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWrites.java
+++ b/spark/v3.1/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWrites.java
@@ -20,6 +20,8 @@
 import java.util.List;
 import java.util.Map;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.expressions.Expressions;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
 import org.apache.iceberg.spark.SparkCatalogTestBase;
 import org.apache.iceberg.spark.source.SimpleRecord;
@@ -157,4 +159,87 @@ public void testViewsReturnRecentResults() {
         ImmutableList.of(row(1L, "a"), row(1L, "a")),
         sql("SELECT * FROM tmp"));
   }
+
+  @Test
+  public void testWriteWithOutputSpec() throws NoSuchTableException {
+    Table table = validationCatalog.loadTable(tableIdent);
+
+    // Drop all records in the table to have a fresh start.
+    table.newDelete().deleteFromRowFilter(Expressions.alwaysTrue()).commit();
+
+    final int originalSpecId = table.spec().specId();
+    table.updateSpec().addField("data").commit();
+
+    // Refresh when using SparkCatalog, since otherwise the new spec would not be picked up.
+    sql("REFRESH TABLE %s", tableName);
+
+    // By default, we write to the current spec.
+    List<SimpleRecord> data = ImmutableList.of(new SimpleRecord(10, "a"));
+    spark.createDataFrame(data, SimpleRecord.class).toDF().writeTo(tableName).append();
+
+    List<Object[]> expected = ImmutableList.of(row(10L, "a", table.spec().specId()));
+    assertEquals(
+        "Rows must match",
+        expected,
+        sql("SELECT id, data, _spec_id FROM %s WHERE id >= 10 ORDER BY id", tableName));
+
+    // The output spec ID should be respected when present.
+    data = ImmutableList.of(new SimpleRecord(11, "b"), new SimpleRecord(12, "c"));
+    spark
+        .createDataFrame(data, SimpleRecord.class)
+        .toDF()
+        .writeTo(tableName)
+        .option("output-spec-id", Integer.toString(originalSpecId))
+        .append();
+
+    expected =
+        ImmutableList.of(
+            row(10L, "a", table.spec().specId()),
+            row(11L, "b", originalSpecId),
+            row(12L, "c", originalSpecId));
+    assertEquals(
+        "Rows must match",
+        expected,
+        sql("SELECT id, data, _spec_id FROM %s WHERE id >= 10 ORDER BY id", tableName));
+
+    // Verify that the actual partitions are written with the correct spec ID.
+    // Two of the partitions should have the original spec ID and one should have the new one.
+    Dataset<Row> actualPartitionRows =
+        spark
+            .read()
+            .format("iceberg")
+            .load(tableName + ".partitions")
+            .select("spec_id", "partition.id_trunc", "partition.data")
+            .orderBy("spec_id", "partition.id_trunc");
+
+    expected =
+        ImmutableList.of(
+            row(originalSpecId, 9L, null),
+            row(originalSpecId, 12L, null),
+            row(table.spec().specId(), 9L, "a"));
+    assertEquals(
+        "There are 3 partitions, two with the original spec ID and one with the new one",
+        expected,
+        rowsToJava(actualPartitionRows.collectAsList()));
+
+    // Even the current spec ID should be respected when passed explicitly.
+    data = ImmutableList.of(new SimpleRecord(13, "d"));
+    spark
+        .createDataFrame(data, SimpleRecord.class)
+        .toDF()
+        .writeTo(tableName)
+        .option("output-spec-id", Integer.toString(table.spec().specId()))
+        .append();
+
+    expected =
+        ImmutableList.of(
+            row(10L, "a", table.spec().specId()),
+            row(11L, "b", originalSpecId),
+            row(12L, "c", originalSpecId),
+            row(13L, "d", table.spec().specId()));
+    assertEquals(
+        "Rows must match",
+        expected,
+        sql("SELECT id, data, _spec_id FROM %s WHERE id >= 10 ORDER BY id", tableName));
+  }
 }
diff --git a/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
index ca26405543c0..5d5c31b722fa 100644
--- a/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
+++ b/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
@@ -24,12 +24,14 @@
 import java.util.Locale;
 import java.util.Map;
+import java.util.Set;
 import org.apache.iceberg.DistributionMode;
 import org.apache.iceberg.FileFormat;
 import org.apache.iceberg.IsolationLevel;
 import org.apache.iceberg.SnapshotSummary;
 import org.apache.iceberg.Table;
 import org.apache.iceberg.TableProperties;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 import org.apache.spark.sql.RuntimeConfig;
 import org.apache.spark.sql.SparkSession;
@@ -58,12 +60,16 @@ public class SparkWriteConf {
   private final RuntimeConfig sessionConf;
   private final Map<String, String> writeOptions;
   private final SparkConfParser confParser;
+  private final int currentSpecId;
+  private final Set<Integer> partitionSpecIds;
 
   public SparkWriteConf(SparkSession spark, Table table, Map<String, String> writeOptions) {
     this.table = table;
     this.sessionConf = spark.conf();
     this.writeOptions = writeOptions;
     this.confParser = new SparkConfParser(spark, table, writeOptions);
+    this.currentSpecId = table.spec().specId();
+    this.partitionSpecIds = table.specs().keySet();
   }
 
   public boolean checkNullability() {
@@ -123,6 +129,20 @@ public String wapId() {
     return sessionConf.get("spark.wap.id", null);
   }
 
+  public int outputSpecId() {
+    int outputSpecId =
+        confParser
+            .intConf()
+            .option(SparkWriteOptions.OUTPUT_SPEC_ID)
+            .defaultValue(currentSpecId)
+            .parse();
+    Preconditions.checkArgument(
+        partitionSpecIds.contains(outputSpecId),
+        "Output spec id %s is not a valid spec id for table",
+        outputSpecId);
+    return outputSpecId;
+  }
+
   public boolean mergeSchema() {
     return confParser
         .booleanConf()
diff --git a/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java b/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java
index 6f4649642c57..c4eacb7b98a4 100644
--- a/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java
+++ b/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java
@@ -57,6 +57,8 @@ private SparkWriteOptions() {}
   public static final String HANDLE_TIMESTAMP_WITHOUT_TIMEZONE =
       "handle-timestamp-without-timezone";
 
+  public static final String OUTPUT_SPEC_ID = "output-spec-id";
+
   public static final String OVERWRITE_MODE = "overwrite-mode";
 
   // Overrides the default distribution mode for a write operation
diff --git a/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java b/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java
index f63db416cc2a..428888057bec 100644
--- a/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java
+++ b/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java
@@ -90,6 +90,7 @@ abstract class SparkWrite implements Write, RequiresDistributionAndOrdering {
   private final String applicationId;
   private final boolean wapEnabled;
   private final String wapId;
+  private final int outputSpecId;
   private final long targetFileSize;
   private final Schema writeSchema;
   private final StructType dsSchema;
@@ -125,6 +126,7 @@ abstract class SparkWrite implements Write, RequiresDistributionAndOrdering {
     this.partitionedFanoutEnabled = writeConf.fanoutWriterEnabled();
     this.requiredDistribution = requiredDistribution;
     this.requiredOrdering = requiredOrdering;
+    this.outputSpecId = writeConf.outputSpecId();
   }
 
   @Override
@@ -174,6 +176,7 @@ private WriterFactory createWriterFactory() {
         tableBroadcast,
         queryId,
         format,
+        outputSpecId,
         targetFileSize,
         writeSchema,
         dsSchema,
@@ -584,6 +587,7 @@ DataFile[] files() {
   private static class WriterFactory implements DataWriterFactory, StreamingDataWriterFactory {
     private final Broadcast<Table> tableBroadcast;
     private final FileFormat format;
+    private final int outputSpecId;
     private final long targetFileSize;
     private final Schema writeSchema;
     private final StructType dsSchema;
@@ -594,12 +598,14 @@ protected WriterFactory(
         Broadcast<Table> tableBroadcast,
         String queryId,
         FileFormat format,
+        int outputSpecId,
         long targetFileSize,
         Schema writeSchema,
         StructType dsSchema,
         boolean partitionedFanoutEnabled) {
       this.tableBroadcast = tableBroadcast;
       this.format = format;
+      this.outputSpecId = outputSpecId;
       this.targetFileSize = targetFileSize;
       this.writeSchema = writeSchema;
       this.dsSchema = dsSchema;
@@ -615,7 +621,7 @@ public DataWriter<InternalRow> createWriter(int partitionId, long taskId) {
     @Override
     public DataWriter<InternalRow> createWriter(int partitionId, long taskId, long epochId) {
       Table table = tableBroadcast.value();
-      PartitionSpec spec = table.spec();
+      PartitionSpec spec = table.specs().get(outputSpecId);
       FileIO io = table.io();
 
       OutputFileFactory fileFactory =
diff --git a/spark/v3.2/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWrites.java b/spark/v3.2/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWrites.java
index 51c56ac79d4d..6f65d469255c 100644
--- a/spark/v3.2/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWrites.java
+++ b/spark/v3.2/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWrites.java
@@ -20,6 +20,8 @@
 import java.util.List;
 import java.util.Map;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.expressions.Expressions;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
 import org.apache.iceberg.spark.SparkCatalogTestBase;
 import org.apache.iceberg.spark.source.SimpleRecord;
@@ -157,4 +159,87 @@ public void testViewsReturnRecentResults() {
         ImmutableList.of(row(1L, "a"), row(1L, "a")),
         sql("SELECT * FROM tmp"));
   }
+
+  @Test
+  public void testWriteWithOutputSpec() throws NoSuchTableException {
+    Table table = validationCatalog.loadTable(tableIdent);
+
+    // Drop all records in the table to have a fresh start.
+    table.newDelete().deleteFromRowFilter(Expressions.alwaysTrue()).commit();
+
+    final int originalSpecId = table.spec().specId();
+    table.updateSpec().addField("data").commit();
+
+    // Refresh when using SparkCatalog, since otherwise the new spec would not be picked up.
+    sql("REFRESH TABLE %s", tableName);
+
+    // By default, we write to the current spec.
+    List<SimpleRecord> data = ImmutableList.of(new SimpleRecord(10, "a"));
+    spark.createDataFrame(data, SimpleRecord.class).toDF().writeTo(tableName).append();
+
+    List<Object[]> expected = ImmutableList.of(row(10L, "a", table.spec().specId()));
+    assertEquals(
+        "Rows must match",
+        expected,
+        sql("SELECT id, data, _spec_id FROM %s WHERE id >= 10 ORDER BY id", tableName));
+
+    // The output spec ID should be respected when present.
+    data = ImmutableList.of(new SimpleRecord(11, "b"), new SimpleRecord(12, "c"));
+    spark
+        .createDataFrame(data, SimpleRecord.class)
+        .toDF()
+        .writeTo(tableName)
+        .option("output-spec-id", Integer.toString(originalSpecId))
+        .append();
+
+    expected =
+        ImmutableList.of(
+            row(10L, "a", table.spec().specId()),
+            row(11L, "b", originalSpecId),
+            row(12L, "c", originalSpecId));
+    assertEquals(
+        "Rows must match",
+        expected,
+        sql("SELECT id, data, _spec_id FROM %s WHERE id >= 10 ORDER BY id", tableName));
+
+    // Verify that the actual partitions are written with the correct spec ID.
+    // Two of the partitions should have the original spec ID and one should have the new one.
+    Dataset<Row> actualPartitionRows =
+        spark
+            .read()
+            .format("iceberg")
+            .load(tableName + ".partitions")
+            .select("spec_id", "partition.id_trunc", "partition.data")
+            .orderBy("spec_id", "partition.id_trunc");
+
+    expected =
+        ImmutableList.of(
+            row(originalSpecId, 9L, null),
+            row(originalSpecId, 12L, null),
+            row(table.spec().specId(), 9L, "a"));
+    assertEquals(
+        "There are 3 partitions, two with the original spec ID and one with the new one",
+        expected,
+        rowsToJava(actualPartitionRows.collectAsList()));
+
+    // Even the current spec ID should be respected when passed explicitly.
+    data = ImmutableList.of(new SimpleRecord(13, "d"));
+    spark
+        .createDataFrame(data, SimpleRecord.class)
+        .toDF()
+        .writeTo(tableName)
+        .option("output-spec-id", Integer.toString(table.spec().specId()))
+        .append();
+
+    expected =
+        ImmutableList.of(
+            row(10L, "a", table.spec().specId()),
+            row(11L, "b", originalSpecId),
+            row(12L, "c", originalSpecId),
+            row(13L, "d", table.spec().specId()));
+    assertEquals(
+        "Rows must match",
+        expected,
+        sql("SELECT id, data, _spec_id FROM %s WHERE id >= 10 ORDER BY id", tableName));
+  }
 }
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
index 87b2f0b25879..40d67717d116 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java
@@ -24,6 +24,7 @@
 import java.util.Locale;
 import java.util.Map;
+import java.util.Set;
 import org.apache.iceberg.DistributionMode;
 import org.apache.iceberg.FileFormat;
 import org.apache.iceberg.IsolationLevel;
@@ -31,6 +32,7 @@
 import org.apache.iceberg.Table;
 import org.apache.iceberg.TableProperties;
 import org.apache.iceberg.exceptions.ValidationException;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 import org.apache.spark.sql.RuntimeConfig;
 import org.apache.spark.sql.SparkSession;
@@ -61,6 +63,8 @@ public class SparkWriteConf {
   private final RuntimeConfig sessionConf;
   private final Map<String, String> writeOptions;
   private final SparkConfParser confParser;
+  private final int currentSpecId;
+  private final Set<Integer> partitionSpecIds;
 
   public SparkWriteConf(SparkSession spark, Table table, Map<String, String> writeOptions) {
     this(spark, table, null, writeOptions);
   }
@@ -73,6 +77,8 @@ public SparkWriteConf(
     this.sessionConf = spark.conf();
     this.writeOptions = writeOptions;
     this.confParser = new SparkConfParser(spark, table, writeOptions);
+    this.currentSpecId = table.spec().specId();
+    this.partitionSpecIds = table.specs().keySet();
   }
 
   public boolean checkNullability() {
@@ -141,6 +147,20 @@ public boolean mergeSchema() {
         .parse();
   }
 
+  public int outputSpecId() {
+    int outputSpecId =
+        confParser
+            .intConf()
+            .option(SparkWriteOptions.OUTPUT_SPEC_ID)
+            .defaultValue(currentSpecId)
+            .parse();
+    Preconditions.checkArgument(
+        partitionSpecIds.contains(outputSpecId),
+        "Output spec id %s is not a valid spec id for table",
+        outputSpecId);
+    return outputSpecId;
+  }
+
   public FileFormat dataFileFormat() {
     String valueAsString =
         confParser
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java
index 6f4649642c57..c4eacb7b98a4 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteOptions.java
@@ -57,6 +57,8 @@ private SparkWriteOptions() {}
   public static final String HANDLE_TIMESTAMP_WITHOUT_TIMEZONE =
       "handle-timestamp-without-timezone";
 
+  public static final String OUTPUT_SPEC_ID = "output-spec-id";
+
   public static final String OVERWRITE_MODE = "overwrite-mode";
 
   // Overrides the default distribution mode for a write operation
diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java
index 9bcbbde8b703..a080fcead13b 100644
--- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java
+++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java
@@ -91,6 +91,7 @@ abstract class SparkWrite implements Write, RequiresDistributionAndOrdering {
   private final String applicationId;
   private final boolean wapEnabled;
   private final String wapId;
+  private final int outputSpecId;
   private final String branch;
   private final long targetFileSize;
   private final Schema writeSchema;
@@ -128,6 +129,7 @@ abstract class SparkWrite implements Write, RequiresDistributionAndOrdering {
     this.partitionedFanoutEnabled = writeConf.fanoutWriterEnabled();
     this.requiredDistribution = requiredDistribution;
     this.requiredOrdering = requiredOrdering;
+    this.outputSpecId = writeConf.outputSpecId();
   }
 
   @Override
@@ -177,6 +179,7 @@ private WriterFactory createWriterFactory() {
         tableBroadcast,
         queryId,
         format,
+        outputSpecId,
         targetFileSize,
         writeSchema,
         dsSchema,
@@ -604,6 +607,7 @@ DataFile[] files() {
   private static class WriterFactory implements DataWriterFactory, StreamingDataWriterFactory {
     private final Broadcast<Table> tableBroadcast;
     private final FileFormat format;
+    private final int outputSpecId;
     private final long targetFileSize;
     private final Schema writeSchema;
     private final StructType dsSchema;
@@ -614,12 +618,14 @@ protected WriterFactory(
         Broadcast<Table> tableBroadcast,
         String queryId,
         FileFormat format,
+        int outputSpecId,
         long targetFileSize,
         Schema writeSchema,
         StructType dsSchema,
         boolean partitionedFanoutEnabled) {
       this.tableBroadcast = tableBroadcast;
       this.format = format;
+      this.outputSpecId = outputSpecId;
       this.targetFileSize = targetFileSize;
       this.writeSchema = writeSchema;
       this.dsSchema = dsSchema;
@@ -635,7 +641,7 @@ public DataWriter<InternalRow> createWriter(int partitionId, long taskId) {
     @Override
     public DataWriter<InternalRow> createWriter(int partitionId, long taskId, long epochId) {
       Table table = tableBroadcast.value();
-      PartitionSpec spec = table.spec();
+      PartitionSpec spec = table.specs().get(outputSpecId);
       FileIO io = table.io();
 
       OutputFileFactory fileFactory =
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/PartitionedWritesTestBase.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/PartitionedWritesTestBase.java
index c6775c2a0799..77dccbf1e064 100644
--- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/PartitionedWritesTestBase.java
+++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/PartitionedWritesTestBase.java
@@ -18,8 +18,11 @@
  */
 package org.apache.iceberg.spark.sql;
 
+import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.expressions.Expressions;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
 import org.apache.iceberg.spark.SparkCatalogTestBase;
 import org.apache.iceberg.spark.source.SimpleRecord;
@@ -183,4 +186,97 @@ public void testViewsReturnRecentResults() {
         ImmutableList.of(row(1L, "a"), row(1L, "a")),
         sql("SELECT * FROM tmp"));
   }
+
+  // Asserts that the given table's .partitions metadata table has the expected rows. Note that the
+  // output rows include spec_id and are sorted by spec_id and selectPartitionColumns.
+  protected void assertPartitionMetadata(
+      String tableName, List<Object[]> expected, String... selectPartitionColumns) {
+    String[] fullyQualifiedCols =
+        Arrays.stream(selectPartitionColumns).map(s -> "partition." + s).toArray(String[]::new);
+    Dataset<Row> actualPartitionRows =
+        spark
+            .read()
+            .format("iceberg")
+            .load(tableName + ".partitions")
+            .select("spec_id", fullyQualifiedCols)
+            .orderBy("spec_id", fullyQualifiedCols);
+
+    assertEquals(
+        "There are 3 partitions, two with the original spec ID and one with the new one",
+        expected,
+        rowsToJava(actualPartitionRows.collectAsList()));
+  }
+
+  @Test
+  public void testWriteWithOutputSpec() throws NoSuchTableException {
+    Table table = validationCatalog.loadTable(tableIdent);
+
+    // Drop all records in the table to have a fresh start.
+    table.newDelete().deleteFromRowFilter(Expressions.alwaysTrue()).commit();
+
+    final int originalSpecId = table.spec().specId();
+    table.updateSpec().addField("data").commit();
+
+    // Refresh when using SparkCatalog, since otherwise the new spec would not be picked up.
+    sql("REFRESH TABLE %s", tableName);
+
+    // By default, we write to the current spec.
+    List<SimpleRecord> data = ImmutableList.of(new SimpleRecord(10, "a"));
+    spark.createDataFrame(data, SimpleRecord.class).toDF().writeTo(tableName).append();
+
+    List<Object[]> expected = ImmutableList.of(row(10L, "a", table.spec().specId()));
+    assertEquals(
+        "Rows must match",
+        expected,
+        sql("SELECT id, data, _spec_id FROM %s WHERE id >= 10 ORDER BY id", tableName));
+
+    // The output spec ID should be respected when present.
+    data = ImmutableList.of(new SimpleRecord(11, "b"), new SimpleRecord(12, "c"));
+    spark
+        .createDataFrame(data, SimpleRecord.class)
+        .toDF()
+        .writeTo(tableName)
+        .option("output-spec-id", Integer.toString(originalSpecId))
+        .append();
+
+    expected =
+        ImmutableList.of(
+            row(10L, "a", table.spec().specId()),
+            row(11L, "b", originalSpecId),
+            row(12L, "c", originalSpecId));
+    assertEquals(
+        "Rows must match",
+        expected,
+        sql("SELECT id, data, _spec_id FROM %s WHERE id >= 10 ORDER BY id", tableName));
+
+    // Verify that the actual partitions are written with the correct spec ID.
+    // Two of the partitions should have the original spec ID and one should have the new one.
+    // TODO: the WAP branch does not support reading the partitions table; this check is skipped there for now.
+    expected =
+        ImmutableList.of(
+            row(originalSpecId, 9L, null),
+            row(originalSpecId, 12L, null),
+            row(table.spec().specId(), 9L, "a"));
+    assertPartitionMetadata(tableName, expected, "id_trunc", "data");
+
+    // Even the current spec ID should be respected when passed explicitly.
+    data = ImmutableList.of(new SimpleRecord(13, "d"));
+    spark
+        .createDataFrame(data, SimpleRecord.class)
+        .toDF()
+        .writeTo(tableName)
+        .option("output-spec-id", Integer.toString(table.spec().specId()))
+        .append();
+
+    expected =
+        ImmutableList.of(
+            row(10L, "a", table.spec().specId()),
+            row(11L, "b", originalSpecId),
+            row(12L, "c", originalSpecId),
+            row(13L, "d", table.spec().specId()));
+    assertEquals(
+        "Rows must match",
+        expected,
+        sql("SELECT id, data, _spec_id FROM %s WHERE id >= 10 ORDER BY id", tableName));
+  }
 }
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesToWapBranch.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesToWapBranch.java
index a65e94ee6e62..5dde5f33d965 100644
--- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesToWapBranch.java
+++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesToWapBranch.java
@@ -18,6 +18,7 @@
  */
 package org.apache.iceberg.spark.sql;
 
+import java.util.List;
 import java.util.Map;
 import java.util.UUID;
 import org.apache.iceberg.Table;
@@ -88,4 +89,11 @@ public void testWapIdAndWapBranchCannotBothBeSetForWrite() {
         .hasMessage(
             "Cannot set both WAP ID and branch, but got ID [%s] and branch [%s]", wapId, BRANCH);
   }
+
+  @Override
+  protected void assertPartitionMetadata(
+      String tableName, List<Object[]> expected, String... selectPartitionColumns) {
+    // Newly written data in the WAP branch cannot be read from the .partitions table. See
+    // https://github.com/apache/iceberg/issues/7297 for more details.
+  }
 }
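
Reviewer note: for anyone trying the patch out, below is a minimal usage sketch (not part of this diff) of the new `output-spec-id` write option through the DataFrameWriterV2 API. The catalog and table name (`local.db.events`) and the spec id `0` are placeholders, and the schema is invented for illustration.

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;

public class OutputSpecIdExample {
  public static void main(String[] args) throws NoSuchTableException {
    SparkSession spark = SparkSession.builder().appName("output-spec-id-example").getOrCreate();

    // A one-row DataFrame matching a hypothetical (id BIGINT, data STRING) table schema.
    Dataset<Row> df = spark.sql("SELECT 1L AS id, 'a' AS data");

    // Without the option, files are partitioned by the table's current spec (table.spec()).
    // With it, the WriterFactory looks up table.specs().get(0) and partitions new files by
    // that older spec; an unknown spec id fails fast in SparkWriteConf#outputSpecId().
    df.writeTo("local.db.events")
        .option("output-spec-id", "0")
        .append();
  }
}
```

This mirrors what the added `testWriteWithOutputSpec` tests do with `Integer.toString(originalSpecId)`: the option only changes how new data files are laid out, while reads continue to plan across all specs recorded in the table metadata.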