@@ -310,6 +310,7 @@ public String branch() {
return confParser
.stringConf()
.option(SparkWriteOptions.BRANCH)
.sessionConf(SparkWriteOptions.BRANCH)
.defaultValue(SnapshotRef.MAIN_BRANCH)
.parse();
}
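Note: the added sessionConf(...) call lets the target branch be picked up from the Spark session configuration when no per-write option is set, falling back to main. A rough sketch of the resolution order (writeOptions/sessionConf are illustrative stand-ins; the real lookup goes through confParser.stringConf() as shown above):

// Illustrative precedence only, not the actual SparkConfParser internals
String branch = writeOptions.getOrDefault(
    SparkWriteOptions.BRANCH,                                     // 1. per-write option
    sessionConf.getOrDefault(SparkWriteOptions.BRANCH,            // 2. session configuration
        SnapshotRef.MAIN_BRANCH));                                // 3. default: "main"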
@@ -77,6 +77,7 @@ private SparkWriteOptions() {}

// Isolation Level for DataFrame calls. Currently supported by overwritePartitions
public static final String ISOLATION_LEVEL = "isolation-level";

// Branch to write to
public static final String BRANCH = "branch";
}
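Usage of the option mirrors the tests later in this diff; a minimal DataFrame round trip against a branch looks like this (location and branch name are placeholders):

// Write to a named branch instead of main, then read the same branch back
df.write()
    .format("iceberg")
    .option(SparkWriteOptions.BRANCH, "audit-branch")
    .mode(SaveMode.Append)
    .save(location);

Dataset<Row> rows = spark.read()
    .format("iceberg")
    .option("branch", "audit-branch")
    .load(location);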
@@ -100,6 +100,7 @@ class SparkPositionDeltaWrite implements DeltaWrite, RequiresDistributionAndOrdering
private final Map<String, String> extraSnapshotMetadata;
private final Distribution requiredDistribution;
private final SortOrder[] requiredOrdering;
private final String branch;

private boolean cleanupOnAbort = true;

@@ -123,6 +124,7 @@ class SparkPositionDeltaWrite implements DeltaWrite, RequiresDistributionAndOrdering
this.applicationId = spark.sparkContext().applicationId();
this.wapEnabled = writeConf.wapEnabled();
this.wapId = writeConf.wapId();
this.branch = writeConf.branch();
this.extraSnapshotMetadata = writeConf.extraSnapshotMetadata();
this.requiredDistribution = requiredDistribution;
this.requiredOrdering = requiredOrdering;
@@ -277,6 +279,7 @@ private void commitOperation(SnapshotUpdate<?> operation, String description) {

try {
long start = System.currentTimeMillis();
operation.toBranch(branch);
operation.commit(); // abort is automatically called if this fails
long duration = System.currentTimeMillis() - start;
LOG.info("Committed in {} ms", duration);
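toBranch(String) is part of the core SnapshotUpdate API, so the same call works for any snapshot-producing operation. A minimal sketch of the underlying API (table, dataFile, and the branch name are assumed to exist):

AppendFiles append = table.newAppend();
append.appendFile(dataFile);
append.toBranch("audit-branch"); // target the named branch ref instead of the default main
append.commit();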
@@ -50,7 +50,6 @@
import org.apache.iceberg.spark.SparkFilters;
import org.apache.iceberg.spark.SparkReadOptions;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.iceberg.spark.SparkWriteOptions;
import org.apache.iceberg.util.PropertyUtil;
import org.apache.iceberg.util.SnapshotUtil;
import org.apache.spark.sql.SparkSession;
@@ -248,11 +247,8 @@ public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {

@Override
public WriteBuilder newWriteBuilder(LogicalWriteInfo info) {
boolean branchOptionPresent = info.options().containsKey(SparkWriteOptions.BRANCH);
if (!branchOptionPresent) {
Preconditions.checkArgument(
snapshotId == null, "Cannot write to table at a specific snapshot: %s", snapshotId);
}
Preconditions.checkArgument(
snapshotId == null, "Cannot write to table at a specific snapshot: %s", snapshotId);
return new SparkWriteBuilder(sparkSession(), icebergTable, info);
}

@@ -220,6 +220,7 @@ private void commitOperation(SnapshotUpdate<?> operation, String description) {

try {
long start = System.currentTimeMillis();
operation.toBranch(branch);
operation.commit(); // abort is automatically called if this fails
long duration = System.currentTimeMillis() - start;
LOG.info("Committed in {} ms", duration);
@@ -292,7 +293,7 @@ public String toString() {
private class BatchAppend extends BaseBatchWrite {
@Override
public void commit(WriterCommitMessage[] messages) {
AppendFiles append = table.newAppend().toBranch(branch);
AppendFiles append = table.newAppend();

int numFiles = 0;
for (DataFile file : files(messages)) {
@@ -314,7 +315,7 @@ public void commit(WriterCommitMessage[] messages) {
return;
}

ReplacePartitions dynamicOverwrite = table.newReplacePartitions().toBranch(branch);
ReplacePartitions dynamicOverwrite = table.newReplacePartitions();

IsolationLevel isolationLevel = writeConf.isolationLevel();
Long validateFromSnapshotId = writeConf.validateFromSnapshotId();
@@ -352,8 +353,7 @@ private OverwriteByFilter(Expression overwriteExpr) {

@Override
public void commit(WriterCommitMessage[] messages) {
OverwriteFiles overwriteFiles =
table.newOverwrite().toBranch(branch).overwriteByRowFilter(overwriteExpr);
OverwriteFiles overwriteFiles = table.newOverwrite().overwriteByRowFilter(overwriteExpr);

int numFiles = 0;
for (DataFile file : files(messages)) {
@@ -414,7 +414,7 @@ private Expression conflictDetectionFilter() {

@Override
public void commit(WriterCommitMessage[] messages) {
OverwriteFiles overwriteFiles = table.newOverwrite().toBranch(branch);
OverwriteFiles overwriteFiles = table.newOverwrite();

List<DataFile> overwrittenFiles = overwrittenFiles();
int numOverwrittenFiles = overwrittenFiles.size();
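Net effect of this file's changes: instead of calling toBranch(branch) on each AppendFiles/ReplacePartitions/OverwriteFiles as it is created, the branch is applied once in commitOperation. Simplified shape of the shared commit path, condensed from the hunks above (the real method also handles WAP, snapshot metadata, and logging):

private void commitOperation(SnapshotUpdate<?> operation, String description) {
  // branch comes from SparkWriteConf.branch() and is applied once for every write mode
  operation.toBranch(branch);
  operation.commit(); // abort is automatically called if this fails
}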
@@ -36,6 +36,7 @@
import org.apache.iceberg.ManifestFiles;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Schema;
import org.apache.iceberg.Snapshot;
import org.apache.iceberg.Table;
import org.apache.iceberg.TableProperties;
import org.apache.iceberg.exceptions.CommitStateUnknownException;
@@ -71,7 +72,7 @@ public class TestSparkDataWrite {

@Rule public TemporaryFolder temp = new TemporaryFolder();

private String branch;
private String targetBranch;

@Parameterized.Parameters(name = "format = {0}, branch = {1}")
public static Object[] parameters() {
@@ -102,9 +103,9 @@ public static void stopSpark() {
currentSpark.stop();
}

public TestSparkDataWrite(String format, String branch) {
public TestSparkDataWrite(String format, String targetBranch) {
this.format = FileFormat.fromString(format);
this.branch = branch;
this.targetBranch = targetBranch;
}

@Test
@@ -127,19 +128,19 @@ public void testBasicWrite() throws IOException {
.format("iceberg")
.option(SparkWriteOptions.WRITE_FORMAT, format.toString())
.mode(SaveMode.Append)
.option("branch", branch)
.option("branch", targetBranch)
.save(location.toString());

table.refresh();

Dataset<Row> result =
spark.read().format("iceberg").option("branch", branch).load(location.toString());
spark.read().format("iceberg").option("branch", targetBranch).load(location.toString());

List<SimpleRecord> actual =
result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
Assert.assertEquals("Number of rows should match", expected.size(), actual.size());
Assert.assertEquals("Result rows should match", expected, actual);
for (ManifestFile manifest : table.snapshot(branch).allManifests(table.io())) {
for (ManifestFile manifest : latestSnapshot(table, targetBranch).allManifests(table.io())) {
for (DataFile file : ManifestFiles.read(manifest, table.io())) {
// TODO: avro not support split
if (!format.equals(FileFormat.AVRO)) {
@@ -187,7 +188,7 @@ public void testAppend() throws IOException {
.format("iceberg")
.option(SparkWriteOptions.WRITE_FORMAT, format.toString())
.mode(SaveMode.Append)
.option("branch", branch)
.option("branch", targetBranch)
.save(location.toString());

df.withColumn("id", df.col("id").plus(3))
@@ -196,13 +197,13 @@ public void testAppend() throws IOException {
.format("iceberg")
.option(SparkWriteOptions.WRITE_FORMAT, format.toString())
.mode(SaveMode.Append)
.option("branch", branch)
.option("branch", targetBranch)
.save(location.toString());

table.refresh();

Dataset<Row> result =
spark.read().format("iceberg").option("branch", branch).load(location.toString());
spark.read().format("iceberg").option("branch", targetBranch).load(location.toString());

List<SimpleRecord> actual =
result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
@@ -231,7 +232,7 @@ public void testEmptyOverwrite() throws IOException {
.format("iceberg")
.option(SparkWriteOptions.WRITE_FORMAT, format.toString())
.mode(SaveMode.Append)
.option("branch", branch)
.option("branch", targetBranch)
.save(location.toString());

Dataset<Row> empty = spark.createDataFrame(ImmutableList.of(), SimpleRecord.class);
@@ -242,13 +243,13 @@ public void testEmptyOverwrite() throws IOException {
.option(SparkWriteOptions.WRITE_FORMAT, format.toString())
.mode(SaveMode.Overwrite)
.option("overwrite-mode", "dynamic")
.option("branch", branch)
.option("branch", targetBranch)
.save(location.toString());

table.refresh();

Dataset<Row> result =
spark.read().format("iceberg").option("branch", branch).load(location.toString());
spark.read().format("iceberg").option("branch", targetBranch).load(location.toString());

List<SimpleRecord> actual =
result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
@@ -284,7 +285,7 @@ public void testOverwrite() throws IOException {
.format("iceberg")
.option(SparkWriteOptions.WRITE_FORMAT, format.toString())
.mode(SaveMode.Append)
.option("branch", branch)
.option("branch", targetBranch)
.save(location.toString());

// overwrite with 2*id to replace record 2, append 4 and 6
@@ -295,13 +296,13 @@ public void testOverwrite() throws IOException {
.option(SparkWriteOptions.WRITE_FORMAT, format.toString())
.mode(SaveMode.Overwrite)
.option("overwrite-mode", "dynamic")
.option("branch", branch)
.option("branch", targetBranch)
.save(location.toString());

table.refresh();

Dataset<Row> result =
spark.read().format("iceberg").option("branch", branch).load(location.toString());
spark.read().format("iceberg").option("branch", targetBranch).load(location.toString());

List<SimpleRecord> actual =
result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
@@ -329,7 +330,7 @@ public void testUnpartitionedOverwrite() throws IOException {
.format("iceberg")
.option(SparkWriteOptions.WRITE_FORMAT, format.toString())
.mode(SaveMode.Append)
.option("branch", branch)
.option("branch", targetBranch)
.save(location.toString());

// overwrite with the same data; should not produce two copies
@@ -338,13 +339,13 @@ public void testUnpartitionedOverwrite() throws IOException {
.format("iceberg")
.option(SparkWriteOptions.WRITE_FORMAT, format.toString())
.mode(SaveMode.Overwrite)
.option("branch", branch)
.option("branch", targetBranch)
.save(location.toString());

table.refresh();

Dataset<Row> result =
spark.read().format("iceberg").option("branch", branch).load(location.toString());
spark.read().format("iceberg").option("branch", targetBranch).load(location.toString());

List<SimpleRecord> actual =
result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
@@ -378,21 +379,21 @@ public void testUnpartitionedCreateWithTargetFileSizeViaTableProperties() throws IOException {
.format("iceberg")
.option(SparkWriteOptions.WRITE_FORMAT, format.toString())
.mode(SaveMode.Append)
.option("branch", branch)
.option("branch", targetBranch)
.save(location.toString());

table.refresh();

Dataset<Row> result =
spark.read().format("iceberg").option("branch", branch).load(location.toString());
spark.read().format("iceberg").option("branch", targetBranch).load(location.toString());

List<SimpleRecord> actual =
result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList();
Assert.assertEquals("Number of rows should match", expected.size(), actual.size());
Assert.assertEquals("Result rows should match", expected, actual);

List<DataFile> files = Lists.newArrayList();
for (ManifestFile manifest : table.snapshot(branch).allManifests(table.io())) {
for (ManifestFile manifest : latestSnapshot(table, targetBranch).allManifests(table.io())) {
for (DataFile file : ManifestFiles.read(manifest, table.io())) {
files.add(file);
}
Expand Down Expand Up @@ -674,6 +675,14 @@ public void testCommitUnknownException() throws IOException {
Assert.assertEquals("Result rows should match", records, actual);
}

private Snapshot latestSnapshot(Table table, String branch) {
if ("main".equals(branch)) {
return table.currentSnapshot();
} else {
return table.snapshot(branch);
}
}
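
One way a test can use this helper, sketched with assumed values (the assertion itself is not part of this diff; "append" is the standard Iceberg operation name for append snapshots):

// Inspect the tip of whichever branch the test wrote to
Snapshot tip = latestSnapshot(table, targetBranch);
Assert.assertNotNull("Branch should have a snapshot after the write", tip);
Assert.assertEquals("Expected an append commit", "append", tip.operation());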

public enum IcebergOptionsType {
NONE,
TABLE,