@@ -20,10 +20,12 @@

import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.apache.iceberg.FileFormat;
import org.apache.iceberg.SnapshotSummary;
import org.apache.iceberg.Table;
import org.apache.iceberg.TableProperties;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.spark.sql.RuntimeConfig;
import org.apache.spark.sql.SparkSession;
@@ -51,11 +53,15 @@ public class SparkWriteConf {
  private final RuntimeConfig sessionConf;
  private final Map<String, String> writeOptions;
  private final SparkConfParser confParser;
  private final int currentSpecId;
  private final Set<Integer> partitionSpecIds;

  public SparkWriteConf(SparkSession spark, Table table, Map<String, String> writeOptions) {
    this.sessionConf = spark.conf();
    this.writeOptions = writeOptions;
    this.confParser = new SparkConfParser(spark, table, writeOptions);
    this.currentSpecId = table.spec().specId();
    this.partitionSpecIds = table.specs().keySet();
  }

  public boolean checkNullability() {
@@ -115,6 +121,20 @@ public String wapId() {
    return sessionConf.get("spark.wap.id", null);
  }

  public int outputSpecId() {
    int outputSpecId =
        confParser
            .intConf()
            .option(SparkWriteOptions.OUTPUT_SPEC_ID)
            .defaultValue(currentSpecId)
            .parse();
    Preconditions.checkArgument(
        partitionSpecIds.contains(outputSpecId),
        "Output spec id %s is not a valid spec id for table",
        outputSpecId);
    return outputSpecId;
  }

  public FileFormat dataFileFormat() {
    String valueAsString =
        confParser
@@ -51,5 +51,7 @@ private SparkWriteOptions() {}
  public static final String HANDLE_TIMESTAMP_WITHOUT_TIMEZONE =
      "handle-timestamp-without-timezone";

  public static final String OUTPUT_SPEC_ID = "output-spec-id";

  public static final String OVERWRITE_MODE = "overwrite-mode";
}
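For context, the new option is consumed through the DataFrameWriterV2 API, exactly as the test further down exercises it. A minimal usage sketch, assuming an existing Iceberg table named db.sample, a list of SimpleRecord rows called records, and a spec id specId already registered on the table; all three names are illustrative placeholders, not part of this change:

// Sketch only: write records against an explicitly chosen partition spec.
// "db.sample", "records" and "specId" are placeholders.
Dataset<Row> df = spark.createDataFrame(records, SimpleRecord.class).toDF();
df.writeTo("db.sample")
    .option(SparkWriteOptions.OUTPUT_SPEC_ID, Integer.toString(specId))
    .append();
// With the option absent, outputSpecId() falls back to the table's current spec;
// an id that is not present in table.specs() fails the Preconditions check
// before any files are written.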
@@ -99,6 +99,7 @@ class SparkWrite {
  private final String applicationId;
  private final boolean wapEnabled;
  private final String wapId;
  private final int outputSpecId;
  private final long targetFileSize;
  private final Schema writeSchema;
  private final StructType dsSchema;
@@ -127,6 +128,7 @@ class SparkWrite {
    this.dsSchema = dsSchema;
    this.extraSnapshotMetadata = writeConf.extraSnapshotMetadata();
    this.partitionedFanoutEnabled = writeConf.fanoutWriterEnabled();
    this.outputSpecId = writeConf.outputSpecId();
  }

  BatchWrite asBatchAppend() {
@@ -163,7 +165,13 @@ private WriterFactory createWriterFactory() {
    Broadcast<Table> tableBroadcast =
        sparkContext.broadcast(SerializableTableWithSize.copyOf(table));
    return new WriterFactory(
-        tableBroadcast, format, targetFileSize, writeSchema, dsSchema, partitionedFanoutEnabled);
+        tableBroadcast,
+        format,
+        outputSpecId,
+        targetFileSize,
+        writeSchema,
+        dsSchema,
+        partitionedFanoutEnabled);
  }

  private void commitOperation(SnapshotUpdate<?> operation, String description) {
@@ -558,6 +566,7 @@ DataFile[] files() {
  private static class WriterFactory implements DataWriterFactory, StreamingDataWriterFactory {
    private final Broadcast<Table> tableBroadcast;
    private final FileFormat format;
    private final int outputSpecId;
    private final long targetFileSize;
    private final Schema writeSchema;
    private final StructType dsSchema;
@@ -566,12 +575,14 @@ private static class WriterFactory implements DataWriterFactory, StreamingDataWriterFactory {
    protected WriterFactory(
        Broadcast<Table> tableBroadcast,
        FileFormat format,
        int outputSpecId,
        long targetFileSize,
        Schema writeSchema,
        StructType dsSchema,
        boolean partitionedFanoutEnabled) {
      this.tableBroadcast = tableBroadcast;
      this.format = format;
      this.outputSpecId = outputSpecId;
      this.targetFileSize = targetFileSize;
      this.writeSchema = writeSchema;
      this.dsSchema = dsSchema;
@@ -586,7 +597,7 @@ public DataWriter<InternalRow> createWriter(int partitionId, long taskId) {
    @Override
    public DataWriter<InternalRow> createWriter(int partitionId, long taskId, long epochId) {
      Table table = tableBroadcast.value();
-      PartitionSpec spec = table.spec();
+      PartitionSpec spec = table.specs().get(outputSpecId);
      FileIO io = table.io();

      OutputFileFactory fileFactory =
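The writer-side change above swaps the spec lookup from the table's current spec to a lookup by id. A short comparison sketch, using only the public Table API; table and outputSpecId stand in for the broadcast table and the factory's field:

// Illustrative comparison, not part of the change itself.
PartitionSpec currentSpec = table.spec();                   // always the table's current spec
PartitionSpec chosenSpec = table.specs().get(outputSpecId); // any spec registered on the table
// table.specs() is a Map<Integer, PartitionSpec>, so an unknown id would yield null;
// SparkWriteConf.outputSpecId() validates the requested id against table.specs().keySet()
// on the driver, which keeps this executor-side lookup safe.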
@@ -20,6 +20,8 @@

import java.util.List;
import java.util.Map;
import org.apache.iceberg.Table;
import org.apache.iceberg.expressions.Expressions;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.spark.SparkCatalogTestBase;
import org.apache.iceberg.spark.source.SimpleRecord;
@@ -157,4 +159,87 @@ public void testViewsReturnRecentResults() {
        ImmutableList.of(row(1L, "a"), row(1L, "a")),
        sql("SELECT * FROM tmp"));
  }

  @Test
  public void testWriteWithOutputSpec() throws NoSuchTableException {
    Table table = validationCatalog.loadTable(tableIdent);

    // Drop all records in the table to start from a clean state.
    table.newDelete().deleteFromRowFilter(Expressions.alwaysTrue()).commit();

    final int originalSpecId = table.spec().specId();
    table.updateSpec().addField("data").commit();

    // Refresh the table when using SparkCatalog, since otherwise the new spec would not be picked up.
    sql("REFRESH TABLE %s", tableName);

    // By default, writes go to the table's current spec.
    List<SimpleRecord> data = ImmutableList.of(new SimpleRecord(10, "a"));
    spark.createDataFrame(data, SimpleRecord.class).toDF().writeTo(tableName).append();

    List<Object[]> expected = ImmutableList.of(row(10L, "a", table.spec().specId()));
    assertEquals(
        "Rows must match",
        expected,
        sql("SELECT id, data, _spec_id FROM %s WHERE id >= 10 ORDER BY id", tableName));

    // The output spec ID should be respected when present.
    data = ImmutableList.of(new SimpleRecord(11, "b"), new SimpleRecord(12, "c"));
    spark
        .createDataFrame(data, SimpleRecord.class)
        .toDF()
        .writeTo(tableName)
        .option("output-spec-id", Integer.toString(originalSpecId))
        .append();

    expected =
        ImmutableList.of(
            row(10L, "a", table.spec().specId()),
            row(11L, "b", originalSpecId),
            row(12L, "c", originalSpecId));
    assertEquals(
        "Rows must match",
        expected,
        sql("SELECT id, data, _spec_id FROM %s WHERE id >= 10 ORDER BY id", tableName));

    // Verify that the actual partitions are written with the correct spec ID.
    // Two of the partitions should have the original spec ID and one should have the new one.
    Dataset<Row> actualPartitionRows =
        spark
            .read()
            .format("iceberg")
            .load(tableName + ".partitions")
            .select("spec_id", "partition.id_trunc", "partition.data")
            .orderBy("spec_id", "partition.id_trunc");

    expected =
        ImmutableList.of(
            row(originalSpecId, 9L, null),
            row(originalSpecId, 12L, null),
            row(table.spec().specId(), 9L, "a"));
    assertEquals(
        "There are 3 partitions, two with the original spec ID and one with the new one",
        expected,
        rowsToJava(actualPartitionRows.collectAsList()));

    // Explicitly passing the current (default) spec ID should also be respected.
    data = ImmutableList.of(new SimpleRecord(13, "d"));
    spark
        .createDataFrame(data, SimpleRecord.class)
        .toDF()
        .writeTo(tableName)
        .option("output-spec-id", Integer.toString(table.spec().specId()))
        .append();

    expected =
        ImmutableList.of(
            row(10L, "a", table.spec().specId()),
            row(11L, "b", originalSpecId),
            row(12L, "c", originalSpecId),
            row(13L, "d", table.spec().specId()));
    assertEquals(
        "Rows must match",
        expected,
        sql("SELECT id, data, _spec_id FROM %s WHERE id >= 10 ORDER BY id", tableName));
  }
}
@@ -24,12 +24,14 @@

import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.apache.iceberg.DistributionMode;
import org.apache.iceberg.FileFormat;
import org.apache.iceberg.IsolationLevel;
import org.apache.iceberg.SnapshotSummary;
import org.apache.iceberg.Table;
import org.apache.iceberg.TableProperties;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.spark.sql.RuntimeConfig;
import org.apache.spark.sql.SparkSession;
@@ -58,12 +60,16 @@ public class SparkWriteConf {
  private final RuntimeConfig sessionConf;
  private final Map<String, String> writeOptions;
  private final SparkConfParser confParser;
  private final int currentSpecId;
  private final Set<Integer> partitionSpecIds;

  public SparkWriteConf(SparkSession spark, Table table, Map<String, String> writeOptions) {
    this.table = table;
    this.sessionConf = spark.conf();
    this.writeOptions = writeOptions;
    this.confParser = new SparkConfParser(spark, table, writeOptions);
    this.currentSpecId = table.spec().specId();
    this.partitionSpecIds = table.specs().keySet();
  }

  public boolean checkNullability() {
@@ -123,6 +129,20 @@ public String wapId() {
    return sessionConf.get("spark.wap.id", null);
  }

  public int outputSpecId() {
    int outputSpecId =
        confParser
            .intConf()
            .option(SparkWriteOptions.OUTPUT_SPEC_ID)
            .defaultValue(currentSpecId)
            .parse();
    Preconditions.checkArgument(
        partitionSpecIds.contains(outputSpecId),
        "Output spec id %s is not a valid spec id for table",
        outputSpecId);
    return outputSpecId;
  }

  public boolean mergeSchema() {
    return confParser
        .booleanConf()
@@ -57,6 +57,8 @@ private SparkWriteOptions() {}
  public static final String HANDLE_TIMESTAMP_WITHOUT_TIMEZONE =
      "handle-timestamp-without-timezone";

  public static final String OUTPUT_SPEC_ID = "output-spec-id";

  public static final String OVERWRITE_MODE = "overwrite-mode";

  // Overrides the default distribution mode for a write operation
@@ -90,6 +90,7 @@ abstract class SparkWrite implements Write, RequiresDistributionAndOrdering {
  private final String applicationId;
  private final boolean wapEnabled;
  private final String wapId;
  private final int outputSpecId;
  private final long targetFileSize;
  private final Schema writeSchema;
  private final StructType dsSchema;
@@ -125,6 +126,7 @@ abstract class SparkWrite implements Write, RequiresDistributionAndOrdering {
    this.partitionedFanoutEnabled = writeConf.fanoutWriterEnabled();
    this.requiredDistribution = requiredDistribution;
    this.requiredOrdering = requiredOrdering;
    this.outputSpecId = writeConf.outputSpecId();
  }

  @Override
@@ -174,6 +176,7 @@ private WriterFactory createWriterFactory() {
        tableBroadcast,
        queryId,
        format,
        outputSpecId,
        targetFileSize,
        writeSchema,
        dsSchema,
@@ -584,6 +587,7 @@ DataFile[] files() {
  private static class WriterFactory implements DataWriterFactory, StreamingDataWriterFactory {
    private final Broadcast<Table> tableBroadcast;
    private final FileFormat format;
    private final int outputSpecId;
    private final long targetFileSize;
    private final Schema writeSchema;
    private final StructType dsSchema;
@@ -594,12 +598,14 @@ protected WriterFactory(
        Broadcast<Table> tableBroadcast,
        String queryId,
        FileFormat format,
        int outputSpecId,
        long targetFileSize,
        Schema writeSchema,
        StructType dsSchema,
        boolean partitionedFanoutEnabled) {
      this.tableBroadcast = tableBroadcast;
      this.format = format;
      this.outputSpecId = outputSpecId;
      this.targetFileSize = targetFileSize;
      this.writeSchema = writeSchema;
      this.dsSchema = dsSchema;
@@ -615,7 +621,7 @@ public DataWriter<InternalRow> createWriter(int partitionId, long taskId) {
    @Override
    public DataWriter<InternalRow> createWriter(int partitionId, long taskId, long epochId) {
      Table table = tableBroadcast.value();
-      PartitionSpec spec = table.spec();
+      PartitionSpec spec = table.specs().get(outputSpecId);
      FileIO io = table.io();

      OutputFileFactory fileFactory =