diff --git a/core/src/main/java/org/apache/iceberg/MergingSnapshotProducer.java b/core/src/main/java/org/apache/iceberg/MergingSnapshotProducer.java index 57df521b512a..8fc45c0213c8 100644 --- a/core/src/main/java/org/apache/iceberg/MergingSnapshotProducer.java +++ b/core/src/main/java/org/apache/iceberg/MergingSnapshotProducer.java @@ -976,7 +976,8 @@ public List apply(TableMetadata base, Snapshot snapshot) { // filter any existing manifests List filtered = filterManager.filterManifests( - base.schema(), snapshot != null ? snapshot.dataManifests(ops.io()) : null); + SnapshotUtil.schemaFor(base, targetBranch()), + snapshot != null ? snapshot.dataManifests(ops.io()) : null); long minDataSequenceNumber = filtered.stream() .map(ManifestFile::minSequenceNumber) @@ -989,7 +990,8 @@ public List apply(TableMetadata base, Snapshot snapshot) { deleteFilterManager.dropDeleteFilesOlderThan(minDataSequenceNumber); List filteredDeletes = deleteFilterManager.filterManifests( - base.schema(), snapshot != null ? snapshot.deleteManifests(ops.io()) : null); + SnapshotUtil.schemaFor(base, targetBranch()), + snapshot != null ? snapshot.deleteManifests(ops.io()) : null); // only keep manifests that have live data files or that were written by this commit Predicate shouldKeep = diff --git a/core/src/main/java/org/apache/iceberg/SnapshotProducer.java b/core/src/main/java/org/apache/iceberg/SnapshotProducer.java index 4a7aa746315f..49d4921796f9 100644 --- a/core/src/main/java/org/apache/iceberg/SnapshotProducer.java +++ b/core/src/main/java/org/apache/iceberg/SnapshotProducer.java @@ -57,6 +57,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.collect.Sets; import org.apache.iceberg.util.Exceptions; +import org.apache.iceberg.util.SnapshotUtil; import org.apache.iceberg.util.Tasks; import org.apache.iceberg.util.ThreadPools; import org.slf4j.Logger; @@ -149,6 +150,10 @@ protected void targetBranch(String branch) { this.targetBranch = branch; } + protected String targetBranch() { + return targetBranch; + } + protected ExecutorService workerPool() { return this.workerPool; } @@ -202,15 +207,7 @@ protected void validate(TableMetadata currentMetadata, Snapshot snapshot) {} @Override public Snapshot apply() { refresh(); - Snapshot parentSnapshot = base.currentSnapshot(); - if (targetBranch != null) { - SnapshotRef branch = base.ref(targetBranch); - if (branch != null) { - parentSnapshot = base.snapshot(branch.snapshotId()); - } else if (base.currentSnapshot() != null) { - parentSnapshot = base.currentSnapshot(); - } - } + Snapshot parentSnapshot = SnapshotUtil.latestSnapshot(base, targetBranch); long sequenceNumber = base.nextSequenceNumber(); Long parentSnapshotId = parentSnapshot == null ? 
null : parentSnapshot.snapshotId(); diff --git a/core/src/main/java/org/apache/iceberg/SnapshotScan.java b/core/src/main/java/org/apache/iceberg/SnapshotScan.java index 07858adb962c..b6520c2ff4d3 100644 --- a/core/src/main/java/org/apache/iceberg/SnapshotScan.java +++ b/core/src/main/java/org/apache/iceberg/SnapshotScan.java @@ -85,11 +85,16 @@ public ThisT useSnapshot(long scanSnapshotId) { } public ThisT useRef(String name) { + if (SnapshotRef.MAIN_BRANCH.equals(name)) { + return newRefinedScan(table(), tableSchema(), context()); + } + Preconditions.checkArgument( snapshotId() == null, "Cannot override ref, already set snapshot id=%s", snapshotId()); Snapshot snapshot = table().snapshot(name); Preconditions.checkArgument(snapshot != null, "Cannot find ref %s", name); - return newRefinedScan(table(), tableSchema(), context().useSnapshotId(snapshot.snapshotId())); + TableScanContext newContext = context().useSnapshotId(snapshot.snapshotId()); + return newRefinedScan(table(), SnapshotUtil.schemaFor(table(), name), newContext); } public ThisT asOfTime(long timestampMillis) { diff --git a/core/src/main/java/org/apache/iceberg/util/SnapshotUtil.java b/core/src/main/java/org/apache/iceberg/util/SnapshotUtil.java index 679a66c587fa..e5c4351f1e4c 100644 --- a/core/src/main/java/org/apache/iceberg/util/SnapshotUtil.java +++ b/core/src/main/java/org/apache/iceberg/util/SnapshotUtil.java @@ -397,6 +397,53 @@ public static Schema schemaFor(Table table, Long snapshotId, Long timestampMilli return table.schema(); } + /** + * Return the schema of the snapshot at a given branch. + * + *

If branch does not exist, the table schema is returned because it will be the schema when + * the new branch is created. + * + * @param table a {@link Table} + * @param branch branch name of the table (nullable) + * @return schema of the specific snapshot at the given branch + */ + public static Schema schemaFor(Table table, String branch) { + if (branch == null || branch.equals(SnapshotRef.MAIN_BRANCH)) { + return table.schema(); + } + + Snapshot ref = table.snapshot(branch); + if (ref == null) { + return table.schema(); + } + + return schemaFor(table, ref.snapshotId()); + } + + /** + * Return the schema of the snapshot at a given branch. + * + *

If branch does not exist, the table schema is returned because it will be the schema when + * the new branch is created. + * + * @param metadata a {@link TableMetadata} + * @param branch branch name of the table (nullable) + * @return schema of the specific snapshot at the given branch + */ + public static Schema schemaFor(TableMetadata metadata, String branch) { + if (branch == null || branch.equals(SnapshotRef.MAIN_BRANCH)) { + return metadata.schema(); + } + + SnapshotRef ref = metadata.ref(branch); + if (ref == null) { + return metadata.schema(); + } + + Snapshot snapshot = metadata.snapshot(ref.snapshotId()); + return metadata.schemas().get(snapshot.schemaId()); + } + /** * Fetch the snapshot at the head of the given branch in the given table. * @@ -405,11 +452,11 @@ public static Schema schemaFor(Table table, Long snapshotId, Long timestampMilli * code path to ensure backwards compatibility. * * @param table a {@link Table} - * @param branch branch name of the table + * @param branch branch name of the table (nullable) * @return the latest snapshot for the given branch */ public static Snapshot latestSnapshot(Table table, String branch) { - if (branch.equals(SnapshotRef.MAIN_BRANCH)) { + if (branch == null || branch.equals(SnapshotRef.MAIN_BRANCH)) { return table.currentSnapshot(); } @@ -423,15 +470,23 @@ public static Snapshot latestSnapshot(Table table, String branch) { * TableMetadata#ref(String)}} for the main branch so that existing code still goes through the * old code path to ensure backwards compatibility. * + *

If branch does not exist, the table's latest snapshot is returned because it will be the snapshot + * at the head of the new branch when it is created. + * + * @param metadata a {@link TableMetadata} + * @param branch branch name of the table metadata (nullable) + * @return the latest snapshot for the given branch */ + public static Snapshot latestSnapshot(TableMetadata metadata, String branch) { - if (branch.equals(SnapshotRef.MAIN_BRANCH)) { + if (branch == null || branch.equals(SnapshotRef.MAIN_BRANCH)) { + return metadata.currentSnapshot(); + } + + SnapshotRef ref = metadata.ref(branch); + if (ref == null) { return metadata.currentSnapshot(); } - return metadata.snapshot(metadata.ref(branch).snapshotId()); + return metadata.snapshot(ref.snapshotId()); } } diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java index ff09cf754a42..0ca63ae2bfa2 100644 --- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java @@ -43,6 +43,7 @@ import org.apache.iceberg.DataFile; import org.apache.iceberg.Files; import org.apache.iceberg.Snapshot; +import org.apache.iceberg.SnapshotRef; import org.apache.iceberg.Table; import org.apache.iceberg.data.GenericRecord; import org.apache.iceberg.data.parquet.GenericParquetWriter; @@ -70,6 +71,7 @@ public abstract class SparkRowLevelOperationsTestBase extends SparkExtensionsTes protected final String fileFormat; protected final boolean vectorized; protected final String distributionMode; + protected final String branch; public SparkRowLevelOperationsTestBase( String catalogName, @@ -77,17 +79,19 @@ public SparkRowLevelOperationsTestBase( Map config, String fileFormat, boolean vectorized, - String distributionMode) { + String distributionMode, + String branch) { super(catalogName, implementation, config); this.fileFormat = fileFormat; this.vectorized = vectorized; this.distributionMode = distributionMode; + this.branch = branch; } @Parameters( name = "catalogName = {0}, implementation = {1}, config = {2}," - + " format = {3}, vectorized = {4}, distributionMode = {5}") + + " format = {3}, vectorized = {4}, distributionMode = {5}, branch = {6}") public static Object[][] parameters() { return new Object[][] { { @@ -98,7 +102,8 @@ public static Object[][] parameters() { "default-namespace", "default"), "orc", true, - WRITE_DISTRIBUTION_MODE_NONE + WRITE_DISTRIBUTION_MODE_NONE, + SnapshotRef.MAIN_BRANCH }, { "testhive", @@ -108,7 +113,8 @@ public static Object[][] parameters() { "default-namespace", "default"), "parquet", true, - WRITE_DISTRIBUTION_MODE_NONE + WRITE_DISTRIBUTION_MODE_NONE, + null, }, { "testhadoop", @@ -116,7 +122,8 @@ public static Object[][] parameters() { ImmutableMap.of("type", "hadoop"), "parquet", RANDOM.nextBoolean(), - WRITE_DISTRIBUTION_MODE_HASH + WRITE_DISTRIBUTION_MODE_HASH, + null }, { "spark_catalog", @@ -131,7 +138,8 @@ public static Object[][] parameters() { ), "avro", false, - WRITE_DISTRIBUTION_MODE_RANGE + WRITE_DISTRIBUTION_MODE_RANGE, + "test" } }; } @@ -181,6 +189,7 @@ protected void createAndInitTable(String schema, String partitioning, String jso try { Dataset ds = toDS(schema, jsonData); ds.coalesce(1).writeTo(tableName).append(); + 
createBranchIfNeeded(); } catch (NoSuchTableException e) { throw new RuntimeException("Failed to write data", e); } @@ -315,4 +324,20 @@ protected DataFile writeDataFile(Table table, List records) { throw new UncheckedIOException(e); } } + + @Override + protected String commitTarget() { + return branch == null ? tableName : String.format("%s.branch_%s", tableName, branch); + } + + @Override + protected String selectTarget() { + return branch == null ? tableName : String.format("%s VERSION AS OF '%s'", tableName, branch); + } + + protected void createBranchIfNeeded() { + if (branch != null && !branch.equals(SnapshotRef.MAIN_BRANCH)) { + sql("ALTER TABLE %s CREATE BRANCH %s", tableName, branch); + } + } } diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteDelete.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteDelete.java index 5c9d547a6aba..53177340dadd 100644 --- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteDelete.java +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteDelete.java @@ -30,6 +30,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.AppendFiles; import org.apache.iceberg.DataFile; import org.apache.iceberg.RowLevelOperationMode; import org.apache.iceberg.Snapshot; @@ -42,6 +43,7 @@ import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; import org.apache.iceberg.spark.Spark3Util; import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.util.SnapshotUtil; import org.apache.spark.sql.Encoders; import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; import org.apache.spark.sql.internal.SQLConf; @@ -58,8 +60,9 @@ public TestCopyOnWriteDelete( Map config, String fileFormat, Boolean vectorized, - String distributionMode) { - super(catalogName, implementation, config, fileFormat, vectorized, distributionMode); + String distributionMode, + String branch) { + super(catalogName, implementation, config, fileFormat, vectorized, distributionMode, branch); } @Override @@ -82,6 +85,7 @@ public synchronized void testDeleteWithConcurrentTableRefresh() throws Exception tableName, DELETE_ISOLATION_LEVEL, "snapshot"); sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); Table table = Spark3Util.loadIcebergTable(spark, tableName); @@ -101,7 +105,7 @@ public synchronized void testDeleteWithConcurrentTableRefresh() throws Exception sleep(10); } - sql("DELETE FROM %s WHERE id IN (SELECT * FROM deleted_id)", tableName); + sql("DELETE FROM %s WHERE id IN (SELECT * FROM deleted_id)", commitTarget()); barrier.incrementAndGet(); } @@ -111,7 +115,7 @@ public synchronized void testDeleteWithConcurrentTableRefresh() throws Exception Future appendFuture = executorService.submit( () -> { - GenericRecord record = GenericRecord.create(table.schema()); + GenericRecord record = GenericRecord.create(SnapshotUtil.schemaFor(table, branch)); record.set(0, 1); // id record.set(1, "hr"); // dep @@ -126,7 +130,12 @@ public synchronized void testDeleteWithConcurrentTableRefresh() throws Exception for (int numAppends = 0; numAppends < 5; numAppends++) { DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); - table.newFastAppend().appendFile(dataFile).commit(); + AppendFiles appendFiles = 
table.newFastAppend().appendFile(dataFile); + if (branch != null) { + appendFiles.toBranch(branch); + } + + appendFiles.commit(); sleep(10); } @@ -153,7 +162,8 @@ public synchronized void testDeleteWithConcurrentTableRefresh() throws Exception public void testRuntimeFilteringWithPreservedDataGrouping() throws NoSuchTableException { createAndInitPartitionedTable(); - append(new Employee(1, "hr"), new Employee(3, "hr")); + append(tableName, new Employee(1, "hr"), new Employee(3, "hr")); + createBranchIfNeeded(); append(new Employee(1, "hardware"), new Employee(2, "hardware")); Map sqlConf = @@ -163,17 +173,17 @@ public void testRuntimeFilteringWithPreservedDataGrouping() throws NoSuchTableEx SparkSQLProperties.PRESERVE_DATA_GROUPING, "true"); - withSQLConf(sqlConf, () -> sql("DELETE FROM %s WHERE id = 2", tableName)); + withSQLConf(sqlConf, () -> sql("DELETE FROM %s WHERE id = 2", commitTarget())); Table table = validationCatalog.loadTable(tableIdent); Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots())); - Snapshot currentSnapshot = table.currentSnapshot(); + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); validateCopyOnWrite(currentSnapshot, "1", "1", "1"); assertEquals( "Should have expected rows", ImmutableList.of(row(1, "hardware"), row(1, "hr"), row(3, "hr")), - sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); } } diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteMerge.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteMerge.java index 6a50937c8845..ed1e05f822cf 100644 --- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteMerge.java +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteMerge.java @@ -42,6 +42,7 @@ import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; import org.apache.iceberg.spark.Spark3Util; import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.util.SnapshotUtil; import org.apache.spark.sql.Encoders; import org.apache.spark.sql.internal.SQLConf; import org.assertj.core.api.Assertions; @@ -57,8 +58,9 @@ public TestCopyOnWriteMerge( Map config, String fileFormat, boolean vectorized, - String distributionMode) { - super(catalogName, implementation, config, fileFormat, vectorized, distributionMode); + String distributionMode, + String branch) { + super(catalogName, implementation, config, fileFormat, vectorized, distributionMode, branch); } @Override @@ -81,6 +83,7 @@ public synchronized void testMergeWithConcurrentTableRefresh() throws Exception tableName, MERGE_ISOLATION_LEVEL, "snapshot"); sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); Table table = Spark3Util.loadIcebergTable(spark, tableName); @@ -159,8 +162,9 @@ public void testRuntimeFilteringWithReportedPartitioning() { sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 3, \"dep\": \"hr\" }"); + createBranchIfNeeded(); append( - tableName, + commitTarget(), "{ \"id\": 1, \"dep\": \"hardware\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }"); createOrReplaceView("source", Collections.singletonList(2), Encoders.INT()); @@ -180,17 +184,17 @@ public void testRuntimeFilteringWithReportedPartitioning() { + "ON t.id == s.value " + "WHEN MATCHED THEN 
" + " UPDATE SET id = -1", - tableName)); + commitTarget())); Table table = validationCatalog.loadTable(tableIdent); Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots())); - Snapshot currentSnapshot = table.currentSnapshot(); + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); validateCopyOnWrite(currentSnapshot, "1", "1", "1"); assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "hardware"), row(1, "hardware"), row(1, "hr"), row(3, "hr")), - sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); } } diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteUpdate.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteUpdate.java index cc17d4aa0546..f9f48e8f41c7 100644 --- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteUpdate.java +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteUpdate.java @@ -29,6 +29,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.AppendFiles; import org.apache.iceberg.DataFile; import org.apache.iceberg.RowLevelOperationMode; import org.apache.iceberg.Snapshot; @@ -41,6 +42,7 @@ import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; import org.apache.iceberg.spark.Spark3Util; import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.util.SnapshotUtil; import org.apache.spark.sql.internal.SQLConf; import org.assertj.core.api.Assertions; import org.junit.Assert; @@ -55,8 +57,9 @@ public TestCopyOnWriteUpdate( Map config, String fileFormat, boolean vectorized, - String distributionMode) { - super(catalogName, implementation, config, fileFormat, vectorized, distributionMode); + String distributionMode, + String branch) { + super(catalogName, implementation, config, fileFormat, vectorized, distributionMode, branch); } @Override @@ -78,6 +81,7 @@ public synchronized void testUpdateWithConcurrentTableRefresh() throws Exception tableName, UPDATE_ISOLATION_LEVEL, "snapshot"); sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); Table table = Spark3Util.loadIcebergTable(spark, tableName); @@ -97,7 +101,7 @@ public synchronized void testUpdateWithConcurrentTableRefresh() throws Exception sleep(10); } - sql("UPDATE %s SET id = -1 WHERE id = 1", tableName); + sql("UPDATE %s SET id = -1 WHERE id = 1", commitTarget()); barrier.incrementAndGet(); } @@ -107,7 +111,7 @@ public synchronized void testUpdateWithConcurrentTableRefresh() throws Exception Future appendFuture = executorService.submit( () -> { - GenericRecord record = GenericRecord.create(table.schema()); + GenericRecord record = GenericRecord.create(SnapshotUtil.schemaFor(table, branch)); record.set(0, 1); // id record.set(1, "hr"); // dep @@ -122,7 +126,12 @@ public synchronized void testUpdateWithConcurrentTableRefresh() throws Exception for (int numAppends = 0; numAppends < 5; numAppends++) { DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); - table.newFastAppend().appendFile(dataFile).commit(); + AppendFiles appendFiles = table.newFastAppend().appendFile(dataFile); + if (branch != null) { + appendFiles.toBranch(branch); + } + + appendFiles.commit(); sleep(10); } @@ -151,8 +160,9 @@ public void 
testRuntimeFilteringWithReportedPartitioning() { sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 3, \"dep\": \"hr\" }"); + createBranchIfNeeded(); append( - tableName, + commitTarget(), "{ \"id\": 1, \"dep\": \"hardware\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }"); Map sqlConf = @@ -162,17 +172,17 @@ public void testRuntimeFilteringWithReportedPartitioning() { SparkSQLProperties.PRESERVE_DATA_GROUPING, "true"); - withSQLConf(sqlConf, () -> sql("UPDATE %s SET id = -1 WHERE id = 2", tableName)); + withSQLConf(sqlConf, () -> sql("UPDATE %s SET id = -1 WHERE id = 2", commitTarget())); Table table = validationCatalog.loadTable(tableIdent); Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots())); - Snapshot currentSnapshot = table.currentSnapshot(); + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); validateCopyOnWrite(currentSnapshot, "1", "1", "1"); assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "hardware"), row(1, "hardware"), row(1, "hr"), row(3, "hr")), - sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); } } diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java index 0b73821c617d..4e2851972c28 100644 --- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java @@ -40,6 +40,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; +import org.apache.iceberg.AppendFiles; import org.apache.iceberg.AssertHelpers; import org.apache.iceberg.DataFile; import org.apache.iceberg.ManifestFile; @@ -55,6 +56,7 @@ import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; import org.apache.iceberg.spark.Spark3Util; import org.apache.iceberg.spark.data.TestHelpers; +import org.apache.iceberg.util.SnapshotUtil; import org.apache.spark.SparkException; import org.apache.spark.sql.AnalysisException; import org.apache.spark.sql.Dataset; @@ -81,8 +83,9 @@ public TestDelete( Map config, String fileFormat, Boolean vectorized, - String distributionMode) { - super(catalogName, implementation, config, fileFormat, vectorized, distributionMode); + String distributionMode, + String branch) { + super(catalogName, implementation, config, fileFormat, vectorized, distributionMode, branch); } @BeforeClass @@ -101,20 +104,21 @@ public void removeTables() { public void testDeleteWithoutScanningTable() throws Exception { createAndInitPartitionedTable(); - append(new Employee(1, "hr"), new Employee(3, "hr")); + append(tableName, new Employee(1, "hr"), new Employee(3, "hr")); + createBranchIfNeeded(); append(new Employee(1, "hardware"), new Employee(2, "hardware")); Table table = validationCatalog.loadTable(tableIdent); List manifestLocations = - table.currentSnapshot().allManifests(table.io()).stream() + SnapshotUtil.latestSnapshot(table, branch).allManifests(table.io()).stream() .map(ManifestFile::path) .collect(Collectors.toList()); withUnavailableLocations( manifestLocations, () -> { - LogicalPlan parsed = parsePlan("DELETE FROM %s WHERE dep = 'hr'", tableName); + LogicalPlan parsed = parsePlan("DELETE FROM %s WHERE dep = 
'hr'", commitTarget()); DeleteFromIcebergTable analyzed = (DeleteFromIcebergTable) spark.sessionState().analyzer().execute(parsed); @@ -125,12 +129,12 @@ public void testDeleteWithoutScanningTable() throws Exception { Assert.assertTrue("Should discard rewrite plan", optimized.rewritePlan().isEmpty()); }); - sql("DELETE FROM %s WHERE dep = 'hr'", tableName); + sql("DELETE FROM %s WHERE dep = 'hr'", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(1, "hardware"), row(2, "hardware")), - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -139,24 +143,25 @@ public void testDeleteFileThenMetadataDelete() throws Exception { createAndInitUnpartitionedTable(); sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + createBranchIfNeeded(); // MOR mode: writes a delete file as null cannot be deleted by metadata - sql("DELETE FROM %s AS t WHERE t.id IS NULL", tableName); + sql("DELETE FROM %s AS t WHERE t.id IS NULL", commitTarget()); // Metadata Delete Table table = Spark3Util.loadIcebergTable(spark, tableName); - Set dataFilesBefore = TestHelpers.dataFiles(table); + Set dataFilesBefore = TestHelpers.dataFiles(table, branch); - sql("DELETE FROM %s AS t WHERE t.id = 1", tableName); + sql("DELETE FROM %s AS t WHERE t.id = 1", commitTarget()); - Set dataFilesAfter = TestHelpers.dataFiles(table); + Set dataFilesAfter = TestHelpers.dataFiles(table, branch); Assert.assertTrue( "Data file should have been removed", dataFilesBefore.size() > dataFilesAfter.size()); assertEquals( "Should have expected rows", ImmutableList.of(row(2, "hardware")), - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -164,8 +169,9 @@ public void testDeleteWithFalseCondition() { createAndInitUnpartitionedTable(); sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware')", tableName); + createBranchIfNeeded(); - sql("DELETE FROM %s WHERE id = 1 AND id > 20", tableName); + sql("DELETE FROM %s WHERE id = 1 AND id > 20", commitTarget()); Table table = validationCatalog.loadTable(tableIdent); Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots())); @@ -173,15 +179,16 @@ public void testDeleteWithFalseCondition() { assertEquals( "Should have expected rows", ImmutableList.of(row(1, "hr"), row(2, "hardware")), - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test public void testDeleteFromEmptyTable() { + Assume.assumeFalse("Custom branch does not exist for empty table", "test".equals(branch)); createAndInitUnpartitionedTable(); - sql("DELETE FROM %s WHERE id IN (1)", tableName); - sql("DELETE FROM %s WHERE dep = 'hr'", tableName); + sql("DELETE FROM %s WHERE id IN (1)", commitTarget()); + sql("DELETE FROM %s WHERE dep = 'hr'", commitTarget()); Table table = validationCatalog.loadTable(tableIdent); Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots())); @@ -189,7 +196,17 @@ public void testDeleteFromEmptyTable() { assertEquals( "Should have expected rows", ImmutableList.of(), - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testDeleteFromNonExistingCustomBranch() { + Assume.assumeTrue("Test only applicable to custom branch", "test".equals(branch)); + createAndInitUnpartitionedTable(); + + Assertions.assertThatThrownBy(() -> sql("DELETE FROM 
%s WHERE id IN (1)", commitTarget())) + .isInstanceOf(ValidationException.class) + .hasMessage("Cannot use branch (does not exist): test"); } @Test @@ -197,10 +214,11 @@ public void testExplain() { createAndInitUnpartitionedTable(); sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + createBranchIfNeeded(); - sql("EXPLAIN DELETE FROM %s WHERE id <=> 1", tableName); + sql("EXPLAIN DELETE FROM %s WHERE id <=> 1", commitTarget()); - sql("EXPLAIN DELETE FROM %s WHERE true", tableName); + sql("EXPLAIN DELETE FROM %s WHERE true", commitTarget()); Table table = validationCatalog.loadTable(tableIdent); Assert.assertEquals("Should have 1 snapshot", 1, Iterables.size(table.snapshots())); @@ -208,7 +226,7 @@ public void testExplain() { assertEquals( "Should have expected rows", ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", commitTarget())); } @Test @@ -216,28 +234,30 @@ public void testDeleteWithAlias() { createAndInitUnpartitionedTable(); sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + createBranchIfNeeded(); - sql("DELETE FROM %s AS t WHERE t.id IS NULL", tableName); + sql("DELETE FROM %s AS t WHERE t.id IS NULL", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(1, "hr"), row(2, "hardware")), - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test public void testDeleteWithDynamicFileFiltering() throws NoSuchTableException { createAndInitPartitionedTable(); - append(new Employee(1, "hr"), new Employee(3, "hr")); + append(tableName, new Employee(1, "hr"), new Employee(3, "hr")); + createBranchIfNeeded(); append(new Employee(1, "hardware"), new Employee(2, "hardware")); - sql("DELETE FROM %s WHERE id = 2", tableName); + sql("DELETE FROM %s WHERE id = 2", commitTarget()); Table table = validationCatalog.loadTable(tableIdent); Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots())); - Snapshot currentSnapshot = table.currentSnapshot(); + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); if (mode(table) == COPY_ON_WRITE) { validateCopyOnWrite(currentSnapshot, "1", "1", "1"); } else { @@ -247,7 +267,7 @@ public void testDeleteWithDynamicFileFiltering() throws NoSuchTableException { assertEquals( "Should have expected rows", ImmutableList.of(row(1, "hardware"), row(1, "hr"), row(3, "hr")), - sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); } @Test @@ -255,13 +275,14 @@ public void testDeleteNonExistingRecords() { createAndInitPartitionedTable(); sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + createBranchIfNeeded(); - sql("DELETE FROM %s AS t WHERE t.id > 10", tableName); + sql("DELETE FROM %s AS t WHERE t.id > 10", commitTarget()); Table table = validationCatalog.loadTable(tableIdent); Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots())); - Snapshot currentSnapshot = table.currentSnapshot(); + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); if (fileFormat.equals("orc") || fileFormat.equals("parquet")) { validateDelete(currentSnapshot, "0", null); @@ -276,7 +297,7 @@ public void testDeleteNonExistingRecords() { assertEquals( "Should have expected rows", ImmutableList.of(row(1, "hr"), 
row(2, "hardware"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); } @Test @@ -284,43 +305,47 @@ public void testDeleteWithoutCondition() { createAndInitPartitionedTable(); sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); - sql("INSERT INTO TABLE %s VALUES (2, 'hardware')", tableName); - sql("INSERT INTO TABLE %s VALUES (null, 'hr')", tableName); + createBranchIfNeeded(); + sql("INSERT INTO TABLE %s VALUES (2, 'hardware')", commitTarget()); + sql("INSERT INTO TABLE %s VALUES (null, 'hr')", commitTarget()); - sql("DELETE FROM %s", tableName); + sql("DELETE FROM %s", commitTarget()); Table table = validationCatalog.loadTable(tableIdent); Assert.assertEquals("Should have 4 snapshots", 4, Iterables.size(table.snapshots())); // should be a delete instead of an overwrite as it is done through a metadata operation - Snapshot currentSnapshot = table.currentSnapshot(); + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); validateDelete(currentSnapshot, "2", "3"); assertEquals( - "Should have expected rows", ImmutableList.of(), sql("SELECT * FROM %s", tableName)); + "Should have expected rows", ImmutableList.of(), sql("SELECT * FROM %s", commitTarget())); } @Test public void testDeleteUsingMetadataWithComplexCondition() { createAndInitPartitionedTable(); - sql("INSERT INTO TABLE %s VALUES (1, 'dep1')", tableName); - sql("INSERT INTO TABLE %s VALUES (2, 'dep2')", tableName); - sql("INSERT INTO TABLE %s VALUES (null, 'dep3')", tableName); + sql("INSERT INTO %s VALUES (1, 'dep1')", tableName); + createBranchIfNeeded(); + sql("INSERT INTO %s VALUES (2, 'dep2')", commitTarget()); + sql("INSERT INTO %s VALUES (null, 'dep3')", commitTarget()); - sql("DELETE FROM %s WHERE dep > 'dep2' OR dep = CAST(4 AS STRING) OR dep = 'dep2'", tableName); + sql( + "DELETE FROM %s WHERE dep > 'dep2' OR dep = CAST(4 AS STRING) OR dep = 'dep2'", + commitTarget()); Table table = validationCatalog.loadTable(tableIdent); Assert.assertEquals("Should have 4 snapshots", 4, Iterables.size(table.snapshots())); // should be a delete instead of an overwrite as it is done through a metadata operation - Snapshot currentSnapshot = table.currentSnapshot(); + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); validateDelete(currentSnapshot, "2", "2"); assertEquals( "Should have expected rows", ImmutableList.of(row(1, "dep1")), - sql("SELECT * FROM %s", tableName)); + sql("SELECT * FROM %s", selectTarget())); } @Test @@ -328,17 +353,18 @@ public void testDeleteWithArbitraryPartitionPredicates() { createAndInitPartitionedTable(); sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); - sql("INSERT INTO TABLE %s VALUES (2, 'hardware')", tableName); - sql("INSERT INTO TABLE %s VALUES (null, 'hr')", tableName); + createBranchIfNeeded(); + sql("INSERT INTO TABLE %s VALUES (2, 'hardware')", commitTarget()); + sql("INSERT INTO TABLE %s VALUES (null, 'hr')", commitTarget()); // %% is an escaped version of % - sql("DELETE FROM %s WHERE id = 10 OR dep LIKE '%%ware'", tableName); + sql("DELETE FROM %s WHERE id = 10 OR dep LIKE '%%ware'", commitTarget()); Table table = validationCatalog.loadTable(tableIdent); Assert.assertEquals("Should have 4 snapshots", 4, Iterables.size(table.snapshots())); // should be an overwrite since cannot be executed using a metadata operation - Snapshot currentSnapshot = table.currentSnapshot(); + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); if 
(mode(table) == COPY_ON_WRITE) { validateCopyOnWrite(currentSnapshot, "1", "1", null); } else { @@ -348,7 +374,7 @@ public void testDeleteWithArbitraryPartitionPredicates() { assertEquals( "Should have expected rows", ImmutableList.of(row(1, "hr"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); } @Test @@ -356,12 +382,13 @@ public void testDeleteWithNonDeterministicCondition() { createAndInitPartitionedTable(); sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware')", tableName); + createBranchIfNeeded(); AssertHelpers.assertThrows( "Should complain about non-deterministic expressions", AnalysisException.class, "nondeterministic expressions are only allowed", - () -> sql("DELETE FROM %s WHERE id = 1 AND rand() > 0.5", tableName)); + () -> sql("DELETE FROM %s WHERE id = 1 AND rand() > 0.5", commitTarget())); } @Test @@ -369,34 +396,35 @@ public void testDeleteWithFoldableConditions() { createAndInitPartitionedTable(); sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware')", tableName); + createBranchIfNeeded(); // should keep all rows and don't trigger execution - sql("DELETE FROM %s WHERE false", tableName); + sql("DELETE FROM %s WHERE false", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(1, "hr"), row(2, "hardware")), - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); // should keep all rows and don't trigger execution - sql("DELETE FROM %s WHERE 50 <> 50", tableName); + sql("DELETE FROM %s WHERE 50 <> 50", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(1, "hr"), row(2, "hardware")), - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); // should keep all rows and don't trigger execution - sql("DELETE FROM %s WHERE 1 > null", tableName); + sql("DELETE FROM %s WHERE 1 > null", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(1, "hr"), row(2, "hardware")), - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); // should remove all rows - sql("DELETE FROM %s WHERE 21 = 21", tableName); + sql("DELETE FROM %s WHERE 21 = 21", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(), - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); Table table = validationCatalog.loadTable(tableIdent); Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots())); @@ -409,33 +437,34 @@ public void testDeleteWithNullConditions() { sql( "INSERT INTO TABLE %s VALUES (0, null), (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + createBranchIfNeeded(); // should keep all rows as null is never equal to null - sql("DELETE FROM %s WHERE dep = null", tableName); + sql("DELETE FROM %s WHERE dep = null", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(0, null), row(1, "hr"), row(2, "hardware"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); // null = 'software' -> null // should delete using metadata operation only - sql("DELETE FROM %s WHERE dep = 'software'", tableName); + sql("DELETE FROM %s WHERE dep = 'software'", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(0, 
null), row(1, "hr"), row(2, "hardware"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); // should delete using metadata operation only - sql("DELETE FROM %s WHERE dep <=> NULL", tableName); + sql("DELETE FROM %s WHERE dep <=> NULL", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); Table table = validationCatalog.loadTable(tableIdent); Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots())); - Snapshot currentSnapshot = table.currentSnapshot(); + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); validateDelete(currentSnapshot, "1", "1"); } @@ -444,24 +473,25 @@ public void testDeleteWithInAndNotInConditions() { createAndInitUnpartitionedTable(); sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + createBranchIfNeeded(); - sql("DELETE FROM %s WHERE id IN (1, null)", tableName); + sql("DELETE FROM %s WHERE id IN (1, null)", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(2, "hardware"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); - sql("DELETE FROM %s WHERE id NOT IN (null, 1)", tableName); + sql("DELETE FROM %s WHERE id NOT IN (null, 1)", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(2, "hardware"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); - sql("DELETE FROM %s WHERE id NOT IN (1, 10)", tableName); + sql("DELETE FROM %s WHERE id NOT IN (1, 10)", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); } @Test @@ -485,13 +515,14 @@ public void testDeleteWithMultipleRowGroupsParquet() throws NoSuchTableException .withColumnRenamed("value", "id") .withColumn("dep", lit("hr")); df.coalesce(1).writeTo(tableName).append(); + createBranchIfNeeded(); - Assert.assertEquals(200, spark.table(tableName).count()); + Assert.assertEquals(200, spark.table(commitTarget()).count()); // delete a record from one of two row groups and copy over the second one - sql("DELETE FROM %s WHERE id IN (200, 201)", tableName); + sql("DELETE FROM %s WHERE id IN (200, 201)", commitTarget()); - Assert.assertEquals(199, spark.table(tableName).count()); + Assert.assertEquals(199, spark.table(commitTarget()).count()); } @Test @@ -499,15 +530,18 @@ public void testDeleteWithConditionOnNestedColumn() { createAndInitNestedColumnsTable(); sql("INSERT INTO TABLE %s VALUES (1, named_struct(\"c1\", 3, \"c2\", \"v1\"))", tableName); - sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 2, \"c2\", \"v2\"))", tableName); + createBranchIfNeeded(); + sql("INSERT INTO TABLE %s VALUES (2, named_struct(\"c1\", 2, \"c2\", \"v2\"))", commitTarget()); - sql("DELETE FROM %s WHERE complex.c1 = id + 2", tableName); + sql("DELETE FROM %s WHERE complex.c1 = id + 2", commitTarget()); assertEquals( - "Should have expected rows", ImmutableList.of(row(2)), sql("SELECT id FROM %s", 
tableName)); + "Should have expected rows", + ImmutableList.of(row(2)), + sql("SELECT id FROM %s", selectTarget())); - sql("DELETE FROM %s t WHERE t.complex.c1 = id", tableName); + sql("DELETE FROM %s t WHERE t.complex.c1 = id", commitTarget()); assertEquals( - "Should have expected rows", ImmutableList.of(), sql("SELECT id FROM %s", tableName)); + "Should have expected rows", ImmutableList.of(), sql("SELECT id FROM %s", selectTarget())); } @Test @@ -515,117 +549,124 @@ public void testDeleteWithInSubquery() throws NoSuchTableException { createAndInitUnpartitionedTable(); sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + createBranchIfNeeded(); createOrReplaceView("deleted_id", Arrays.asList(0, 1, null), Encoders.INT()); createOrReplaceView("deleted_dep", Arrays.asList("software", "hr"), Encoders.STRING()); sql( "DELETE FROM %s WHERE id IN (SELECT * FROM deleted_id) AND dep IN (SELECT * from deleted_dep)", - tableName); + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(2, "hardware"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); append(new Employee(1, "hr"), new Employee(-1, "hr")); assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "hr"), row(1, "hr"), row(2, "hardware"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); - sql("DELETE FROM %s WHERE id IS NULL OR id IN (SELECT value + 2 FROM deleted_id)", tableName); + sql( + "DELETE FROM %s WHERE id IS NULL OR id IN (SELECT value + 2 FROM deleted_id)", + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "hr"), row(1, "hr")), - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); append(new Employee(null, "hr"), new Employee(2, "hr")); assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "hr"), row(1, "hr"), row(2, "hr"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); - sql("DELETE FROM %s WHERE id IN (SELECT value + 2 FROM deleted_id) AND dep = 'hr'", tableName); + sql( + "DELETE FROM %s WHERE id IN (SELECT value + 2 FROM deleted_id) AND dep = 'hr'", + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "hr"), row(1, "hr"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); } @Test public void testDeleteWithMultiColumnInSubquery() throws NoSuchTableException { createAndInitUnpartitionedTable(); - append(new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr")); + append(tableName, new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr")); + createBranchIfNeeded(); List deletedEmployees = Arrays.asList(new Employee(null, "hr"), new Employee(1, "hr")); createOrReplaceView("deleted_employee", deletedEmployees, Encoders.bean(Employee.class)); - sql("DELETE FROM %s WHERE (id, dep) IN (SELECT id, dep FROM deleted_employee)", tableName); + sql("DELETE FROM %s WHERE (id, dep) IN (SELECT id, dep FROM deleted_employee)", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(2, "hardware"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id 
ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); } @Test public void testDeleteWithNotInSubquery() throws NoSuchTableException { createAndInitUnpartitionedTable(); - append(new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr")); + append(tableName, new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr")); + createBranchIfNeeded(); createOrReplaceView("deleted_id", Arrays.asList(-1, -2, null), Encoders.INT()); createOrReplaceView("deleted_dep", Arrays.asList("software", "hr"), Encoders.STRING()); // the file filter subquery (nested loop lef-anti join) returns 0 records - sql("DELETE FROM %s WHERE id NOT IN (SELECT * FROM deleted_id)", tableName); + sql("DELETE FROM %s WHERE id NOT IN (SELECT * FROM deleted_id)", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); sql( "DELETE FROM %s WHERE id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL)", - tableName); + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); - sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); sql( "DELETE FROM %s WHERE id NOT IN (SELECT * FROM deleted_id) OR dep IN ('software', 'hr')", - tableName); + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(2, "hardware")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); sql( "DELETE FROM %s t WHERE " + "id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL) AND " + "EXISTS (SELECT 1 FROM FROM deleted_dep WHERE t.dep = deleted_dep.value)", - tableName); + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(2, "hardware")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); sql( "DELETE FROM %s t WHERE " + "id NOT IN (SELECT * FROM deleted_id WHERE value IS NOT NULL) OR " + "EXISTS (SELECT 1 FROM FROM deleted_dep WHERE t.dep = deleted_dep.value)", - tableName); + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); } @Test @@ -643,51 +684,53 @@ public void testDeleteOnNonIcebergTableNotSupported() { public void testDeleteWithExistSubquery() throws NoSuchTableException { createAndInitUnpartitionedTable(); - append(new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr")); + append(tableName, new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr")); + createBranchIfNeeded(); createOrReplaceView("deleted_id", Arrays.asList(-1, -2, null), Encoders.INT()); 
createOrReplaceView("deleted_dep", Arrays.asList("software", "hr"), Encoders.STRING()); sql( "DELETE FROM %s t WHERE EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value)", - tableName); + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); sql( "DELETE FROM %s t WHERE EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2)", - tableName); + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(2, "hardware"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); sql( "DELETE FROM %s t WHERE EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value) OR t.id IS NULL", - tableName); + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(2, "hardware")), - sql("SELECT * FROM %s", tableName)); + sql("SELECT * FROM %s", selectTarget())); sql( "DELETE FROM %s t WHERE " + "EXISTS (SELECT 1 FROM deleted_id di WHERE t.id = di.value) AND " + "EXISTS (SELECT 1 FROM deleted_dep dd WHERE t.dep = dd.value)", - tableName); + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(2, "hardware")), - sql("SELECT * FROM %s", tableName)); + sql("SELECT * FROM %s", selectTarget())); } @Test public void testDeleteWithNotExistsSubquery() throws NoSuchTableException { createAndInitUnpartitionedTable(); - append(new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr")); + append(tableName, new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr")); + createBranchIfNeeded(); createOrReplaceView("deleted_id", Arrays.asList(-1, -2, null), Encoders.INT()); createOrReplaceView("deleted_dep", Arrays.asList("software", "hr"), Encoders.STRING()); @@ -696,33 +739,34 @@ public void testDeleteWithNotExistsSubquery() throws NoSuchTableException { "DELETE FROM %s t WHERE " + "NOT EXISTS (SELECT 1 FROM deleted_id di WHERE t.id = di.value + 2) AND " + "NOT EXISTS (SELECT 1 FROM deleted_dep dd WHERE t.dep = dd.value)", - tableName); + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(1, "hr"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); sql( "DELETE FROM %s t WHERE NOT EXISTS (SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2)", - tableName); + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(1, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); String subquery = "SELECT 1 FROM deleted_id d WHERE t.id = d.value + 2"; - sql("DELETE FROM %s t WHERE NOT EXISTS (%s) OR t.id = 1", tableName, subquery); + sql("DELETE FROM %s t WHERE NOT EXISTS (%s) OR t.id = 1", commitTarget(), subquery); assertEquals( "Should have expected rows", ImmutableList.of(), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); } @Test public void testDeleteWithScalarSubquery() throws NoSuchTableException { createAndInitUnpartitionedTable(); - append(new Employee(1, "hr"), new Employee(2, "hardware"), new Employee(null, "hr")); + append(tableName, new Employee(1, "hr"), new 
Employee(2, "hardware"), new Employee(null, "hr")); + createBranchIfNeeded(); createOrReplaceView("deleted_id", Arrays.asList(1, 100, null), Encoders.INT()); @@ -730,11 +774,11 @@ public void testDeleteWithScalarSubquery() throws NoSuchTableException { withSQLConf( ImmutableMap.of(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), "false"), () -> { - sql("DELETE FROM %s t WHERE id <= (SELECT min(value) FROM deleted_id)", tableName); + sql("DELETE FROM %s t WHERE id <= (SELECT min(value) FROM deleted_id)", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(2, "hardware"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); }); } @@ -742,7 +786,8 @@ public void testDeleteWithScalarSubquery() throws NoSuchTableException { public void testDeleteThatRequiresGroupingBeforeWrite() throws NoSuchTableException { createAndInitPartitionedTable(); - append(new Employee(0, "hr"), new Employee(1, "hr"), new Employee(2, "hr")); + append(tableName, new Employee(0, "hr"), new Employee(1, "hr"), new Employee(2, "hr")); + createBranchIfNeeded(); append(new Employee(0, "ops"), new Employee(1, "ops"), new Employee(2, "ops")); append(new Employee(0, "hr"), new Employee(1, "hr"), new Employee(2, "hr")); append(new Employee(0, "ops"), new Employee(1, "ops"), new Employee(2, "ops")); @@ -754,8 +799,9 @@ public void testDeleteThatRequiresGroupingBeforeWrite() throws NoSuchTableExcept // set the num of shuffle partitions to 1 to ensure we have only 1 writing task spark.conf().set("spark.sql.shuffle.partitions", "1"); - sql("DELETE FROM %s t WHERE id IN (SELECT * FROM deleted_id)", tableName); - Assert.assertEquals("Should have expected num of rows", 8L, spark.table(tableName).count()); + sql("DELETE FROM %s t WHERE id IN (SELECT * FROM deleted_id)", commitTarget()); + Assert.assertEquals( + "Should have expected num of rows", 8L, spark.table(commitTarget()).count()); } finally { spark.conf().set("spark.sql.shuffle.partitions", originalNumOfShufflePartitions); } @@ -774,6 +820,7 @@ public synchronized void testDeleteWithSerializableIsolation() throws Interrupte tableName, DELETE_ISOLATION_LEVEL, "serializable"); sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); ExecutorService executorService = MoreExecutors.getExitingExecutorService( @@ -791,7 +838,7 @@ public synchronized void testDeleteWithSerializableIsolation() throws Interrupte sleep(10); } - sql("DELETE FROM %s WHERE id IN (SELECT * FROM deleted_id)", tableName); + sql("DELETE FROM %s WHERE id IN (SELECT * FROM deleted_id)", commitTarget()); barrier.incrementAndGet(); } @@ -804,7 +851,7 @@ public synchronized void testDeleteWithSerializableIsolation() throws Interrupte // load the table via the validation catalog to use another table instance Table table = validationCatalog.loadTable(tableIdent); - GenericRecord record = GenericRecord.create(table.schema()); + GenericRecord record = GenericRecord.create(SnapshotUtil.schemaFor(table, branch)); record.set(0, 1); // id record.set(1, "hr"); // dep @@ -819,7 +866,12 @@ public synchronized void testDeleteWithSerializableIsolation() throws Interrupte for (int numAppends = 0; numAppends < 5; numAppends++) { DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); - table.newFastAppend().appendFile(dataFile).commit(); + AppendFiles appendFiles = table.newFastAppend().appendFile(dataFile); + if (branch != null) { + appendFiles.toBranch(branch); + } + + 
appendFiles.commit(); sleep(10); } @@ -858,6 +910,7 @@ public synchronized void testDeleteWithSnapshotIsolation() tableName, DELETE_ISOLATION_LEVEL, "snapshot"); sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); ExecutorService executorService = MoreExecutors.getExitingExecutorService( @@ -875,7 +928,7 @@ public synchronized void testDeleteWithSnapshotIsolation() sleep(10); } - sql("DELETE FROM %s WHERE id IN (SELECT * FROM deleted_id)", tableName); + sql("DELETE FROM %s WHERE id IN (SELECT * FROM deleted_id)", commitTarget()); barrier.incrementAndGet(); } @@ -888,7 +941,7 @@ public synchronized void testDeleteWithSnapshotIsolation() // load the table via the validation catalog to use another table instance for inserts Table table = validationCatalog.loadTable(tableIdent); - GenericRecord record = GenericRecord.create(table.schema()); + GenericRecord record = GenericRecord.create(SnapshotUtil.schemaFor(table, branch)); record.set(0, 1); // id record.set(1, "hr"); // dep @@ -903,7 +956,12 @@ public synchronized void testDeleteWithSnapshotIsolation() for (int numAppends = 0; numAppends < 5; numAppends++) { DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); - table.newFastAppend().appendFile(dataFile).commit(); + AppendFiles appendFiles = table.newFastAppend().appendFile(dataFile); + if (branch != null) { + appendFiles.toBranch(branch); + } + + appendFiles.commit(); sleep(10); } @@ -926,10 +984,11 @@ public synchronized void testDeleteWithSnapshotIsolation() public void testDeleteRefreshesRelationCache() throws NoSuchTableException { createAndInitPartitionedTable(); - append(new Employee(1, "hr"), new Employee(3, "hr")); + append(tableName, new Employee(1, "hr"), new Employee(3, "hr")); + createBranchIfNeeded(); append(new Employee(1, "hardware"), new Employee(2, "hardware")); - Dataset query = spark.sql("SELECT * FROM " + tableName + " WHERE id = 1"); + Dataset query = spark.sql("SELECT * FROM " + commitTarget() + " WHERE id = 1"); query.createOrReplaceTempView("tmp"); spark.sql("CACHE TABLE tmp"); @@ -939,12 +998,12 @@ public void testDeleteRefreshesRelationCache() throws NoSuchTableException { ImmutableList.of(row(1, "hardware"), row(1, "hr")), sql("SELECT * FROM tmp ORDER BY id, dep")); - sql("DELETE FROM %s WHERE id = 1", tableName); + sql("DELETE FROM %s WHERE id = 1", commitTarget()); Table table = validationCatalog.loadTable(tableIdent); Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots())); - Snapshot currentSnapshot = table.currentSnapshot(); + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); if (mode(table) == COPY_ON_WRITE) { validateCopyOnWrite(currentSnapshot, "2", "2", "2"); } else { @@ -953,7 +1012,7 @@ public void testDeleteRefreshesRelationCache() throws NoSuchTableException { assertEquals( "Should have expected rows", ImmutableList.of(row(2, "hardware"), row(3, "hr")), - sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + sql("SELECT * FROM %s ORDER BY id, dep", commitTarget())); assertEquals( "Should refresh the relation cache", @@ -969,28 +1028,29 @@ public void testDeleteWithMultipleSpecs() { // write an unpartitioned file append(tableName, "{ \"id\": 1, \"dep\": \"hr\", \"category\": \"c1\"}"); + createBranchIfNeeded(); // write a file partitioned by dep sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); append( - tableName, + commitTarget(), "{ \"id\": 2, \"dep\": \"hr\", \"category\": \"c1\" }\n" + "{ \"id\": 3, \"dep\": \"hr\", \"category\": \"c1\" }"); 
// write a file partitioned by dep and category sql("ALTER TABLE %s ADD PARTITION FIELD category", tableName); - append(tableName, "{ \"id\": 5, \"dep\": \"hr\", \"category\": \"c1\"}"); + append(commitTarget(), "{ \"id\": 5, \"dep\": \"hr\", \"category\": \"c1\"}"); // write another file partitioned by dep sql("ALTER TABLE %s DROP PARTITION FIELD category", tableName); - append(tableName, "{ \"id\": 7, \"dep\": \"hr\", \"category\": \"c1\"}"); + append(commitTarget(), "{ \"id\": 7, \"dep\": \"hr\", \"category\": \"c1\"}"); - sql("DELETE FROM %s WHERE id IN (1, 3, 5, 7)", tableName); + sql("DELETE FROM %s WHERE id IN (1, 3, 5, 7)", commitTarget()); Table table = validationCatalog.loadTable(tableIdent); Assert.assertEquals("Should have 5 snapshots", 5, Iterables.size(table.snapshots())); - Snapshot currentSnapshot = table.currentSnapshot(); + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); if (mode(table) == COPY_ON_WRITE) { // copy-on-write is tested against v1 and such tables have different partition evolution // behavior @@ -1003,7 +1063,7 @@ public void testDeleteWithMultipleSpecs() { assertEquals( "Should have expected rows", ImmutableList.of(row(2, "hr", "c1")), - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } // TODO: multiple stripes for ORC @@ -1024,9 +1084,13 @@ protected void createAndInitNestedColumnsTable() { } protected void append(Employee... employees) throws NoSuchTableException { + append(commitTarget(), employees); + } + + protected void append(String target, Employee... employees) throws NoSuchTableException { List input = Arrays.asList(employees); Dataset inputDF = spark.createDataFrame(input, Employee.class); - inputDF.coalesce(1).writeTo(tableName).append(); + inputDF.coalesce(1).writeTo(target).append(); } private RowLevelOperationMode mode(Table table) { diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java index 9581748e324e..35f12f6ac83a 100644 --- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java @@ -39,6 +39,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.AppendFiles; import org.apache.iceberg.AssertHelpers; import org.apache.iceberg.DataFile; import org.apache.iceberg.DistributionMode; @@ -52,6 +53,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; +import org.apache.iceberg.util.SnapshotUtil; import org.apache.spark.SparkException; import org.apache.spark.sql.AnalysisException; import org.apache.spark.sql.Dataset; @@ -75,8 +77,9 @@ public TestMerge( Map config, String fileFormat, boolean vectorized, - String distributionMode) { - super(catalogName, implementation, config, fileFormat, vectorized, distributionMode); + String distributionMode, + String branch) { + super(catalogName, implementation, config, fileFormat, vectorized, distributionMode, branch); } @BeforeClass @@ -113,7 +116,7 @@ public void testMergeConditionSplitIntoTargetPredicateAndJoinCondition() { + " UPDATE SET salary = s.salary 
" + "WHEN NOT MATCHED THEN " + " INSERT *", - tableName); + commitTarget()); Table table = validationCatalog.loadTable(tableIdent); @@ -136,7 +139,7 @@ public void testMergeConditionSplitIntoTargetPredicateAndJoinCondition() { row(2, 200, "d2", "sd2"), // new row(3, 300, "d3", "sd3"), // new row(6, 600, "d6", "sd6")), // existing - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -147,13 +150,14 @@ public void testMergeWithStaticPredicatePushDown() { // add a data file to the 'software' partition append(tableName, "{ \"id\": 1, \"dep\": \"software\" }"); + createBranchIfNeeded(); // add a data file to the 'hr' partition - append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }"); + append(commitTarget(), "{ \"id\": 1, \"dep\": \"hr\" }"); Table table = validationCatalog.loadTable(tableIdent); - Snapshot snapshot = table.currentSnapshot(); + Snapshot snapshot = SnapshotUtil.latestSnapshot(table, branch); String dataFilesCount = snapshot.summary().get(SnapshotSummary.TOTAL_DATA_FILES_PROP); Assert.assertEquals("Must have 2 files before MERGE", "2", dataFilesCount); @@ -175,7 +179,7 @@ public void testMergeWithStaticPredicatePushDown() { + " UPDATE SET dep = source.dep " + "WHEN NOT MATCHED THEN " + " INSERT (dep, id) VALUES (source.dep, source.id)", - tableName); + commitTarget()); }); }); @@ -186,11 +190,14 @@ public void testMergeWithStaticPredicatePushDown() { row(2L, "hardware") // new ); assertEquals( - "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + "Output should match", + expectedRows, + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); } @Test public void testMergeIntoEmptyTargetInsertAllNonMatchingRows() { + Assume.assumeFalse("Custom branch does not exist for empty table", "test".equals(branch)); createAndInitTable("id INT, dep STRING"); createOrReplaceView( @@ -214,11 +221,14 @@ public void testMergeIntoEmptyTargetInsertAllNonMatchingRows() { row(3, "emp-id-3") // new ); assertEquals( - "Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test public void testMergeIntoEmptyTargetInsertOnlyMatchingRows() { + Assume.assumeFalse("Custom branch does not exist for empty table", "test".equals(branch)); createAndInitTable("id INT, dep STRING"); createOrReplaceView( @@ -241,7 +251,9 @@ public void testMergeIntoEmptyTargetInsertOnlyMatchingRows() { row(3, "emp-id-3") // new ); assertEquals( - "Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -262,7 +274,7 @@ public void testMergeWithOnlyUpdateClause() { + "ON t.id == s.id " + "WHEN MATCHED AND t.id = 1 THEN " + " UPDATE SET *", - tableName); + commitTarget()); ImmutableList expectedRows = ImmutableList.of( @@ -270,7 +282,9 @@ public void testMergeWithOnlyUpdateClause() { row(6, "emp-id-six") // kept ); assertEquals( - "Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -293,7 +307,7 @@ public void testMergeWithOnlyUpdateClauseAndNullValues() { + "ON t.id == s.id AND t.id < 3 " + "WHEN MATCHED THEN " + " UPDATE SET *", - tableName); + commitTarget()); ImmutableList expectedRows = 
ImmutableList.of( @@ -301,7 +315,9 @@ public void testMergeWithOnlyUpdateClauseAndNullValues() { row(1, "emp-id-1"), // updated row(6, "emp-id-six")); // kept assertEquals( - "Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -322,14 +338,16 @@ public void testMergeWithOnlyDeleteClause() { + "ON t.id == s.id " + "WHEN MATCHED AND t.id = 6 THEN " + " DELETE", - tableName); + commitTarget()); ImmutableList expectedRows = ImmutableList.of( row(1, "emp-id-one") // kept ); assertEquals( - "Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -354,7 +372,7 @@ public void testMergeWithAllCauses() { + " DELETE " + "WHEN NOT MATCHED AND s.id = 2 THEN " + " INSERT *", - tableName); + commitTarget()); ImmutableList expectedRows = ImmutableList.of( @@ -362,7 +380,9 @@ public void testMergeWithAllCauses() { row(2, "emp-id-2") // new ); assertEquals( - "Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -387,7 +407,7 @@ public void testMergeWithAllCausesWithExplicitColumnSpecification() { + " DELETE " + "WHEN NOT MATCHED AND s.id = 2 THEN " + " INSERT (t.id, t.dep) VALUES (s.id, s.dep)", - tableName); + commitTarget()); ImmutableList expectedRows = ImmutableList.of( @@ -395,7 +415,9 @@ public void testMergeWithAllCausesWithExplicitColumnSpecification() { row(2, "emp-id-2") // new ); assertEquals( - "Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -421,7 +443,7 @@ public void testMergeWithSourceCTE() { + " DELETE " + "WHEN NOT MATCHED AND s.id = 3 THEN " + " INSERT *", - tableName); + commitTarget()); ImmutableList expectedRows = ImmutableList.of( @@ -429,7 +451,9 @@ public void testMergeWithSourceCTE() { row(3, "emp-id-3") // new ); assertEquals( - "Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -459,7 +483,7 @@ public void testMergeWithSourceFromSetOps() { + " DELETE " + "WHEN NOT MATCHED AND s.id = 2 THEN " + " INSERT *", - tableName, derivedSource); + commitTarget(), derivedSource); ImmutableList expectedRows = ImmutableList.of( @@ -467,7 +491,9 @@ public void testMergeWithSourceFromSetOps() { row(2, "emp-id-2") // new ); assertEquals( - "Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -498,13 +524,13 @@ public void testMergeWithMultipleUpdatesForTargetRowSmallTargetLargeSource() { + " DELETE " + "WHEN NOT MATCHED AND s.value = 2 THEN " + " INSERT (id, dep) VALUES (s.value, null)", - tableName); + commitTarget()); }); assertEquals( "Target should be unchanged", ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); } @Test @@ -540,14 +566,14 @@ 
public void testMergeWithMultipleUpdatesForTargetRowSmallTargetLargeSource() { + " DELETE " + "WHEN NOT MATCHED AND s.value = 2 THEN " + " INSERT (id, dep) VALUES (s.value, null)", - tableName); + commitTarget()); }); }); assertEquals( "Target should be unchanged", ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); } @Test @@ -580,14 +606,14 @@ public void testMergeWithMultipleUpdatesForTargetRowSmallTargetLargeSourceNoEqua + " DELETE " + "WHEN NOT MATCHED AND s.value = 2 THEN " + " INSERT (id, dep) VALUES (s.value, null)", - tableName); + commitTarget()); }); }); assertEquals( "Target should be unchanged", ImmutableList.of(row(1, "emp-id-one")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); } @Test @@ -616,13 +642,13 @@ public void testMergeWithMultipleUpdatesForTargetRowSmallTargetLargeSourceNoNotM + " UPDATE SET id = 10 " + "WHEN MATCHED AND t.id = 6 THEN " + " DELETE", - tableName); + commitTarget()); }); assertEquals( "Target should be unchanged", ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); } @Test @@ -650,13 +676,13 @@ public void testMergeWithMultipleUpdatesForTargetRowSmallTargetLargeSourceNoNotM + " UPDATE SET id = 10 " + "WHEN MATCHED AND t.id = 6 THEN " + " DELETE", - tableName); + commitTarget()); }); assertEquals( "Target should be unchanged", ImmutableList.of(row(1, "emp-id-one")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); } @Test @@ -688,13 +714,13 @@ public void testMergeWithMultipleUpdatesForTargetRow() { + " DELETE " + "WHEN NOT MATCHED AND s.id = 2 THEN " + " INSERT *", - tableName); + commitTarget()); }); assertEquals( "Target should be unchanged", ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); } @Test @@ -718,14 +744,16 @@ public void testMergeWithUnconditionalDelete() { + " DELETE " + "WHEN NOT MATCHED AND s.id = 2 THEN " + " INSERT *", - tableName); + commitTarget()); ImmutableList expectedRows = ImmutableList.of( row(2, "emp-id-2") // new ); assertEquals( - "Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -755,13 +783,13 @@ public void testMergeWithSingleConditionalDelete() { + " DELETE " + "WHEN NOT MATCHED AND s.id = 2 THEN " + " INSERT *", - tableName); + commitTarget()); }); assertEquals( "Target should be unchanged", ImmutableList.of(row(1, "emp-id-one"), row(6, "emp-id-6")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); } @Test @@ -776,6 +804,7 @@ public void testMergeWithIdentityTransform() { append( tableName, "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + createBranchIfNeeded(); createOrReplaceView( "source", @@ -793,7 +822,7 @@ public void testMergeWithIdentityTransform() { + " DELETE " + "WHEN NOT MATCHED AND s.id = 2 THEN " + " INSERT *", - 
tableName); + commitTarget()); ImmutableList expectedRows = ImmutableList.of( @@ -803,7 +832,7 @@ public void testMergeWithIdentityTransform() { assertEquals( "Should have expected rows", expectedRows, - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); removeTables(); } @@ -823,6 +852,7 @@ public void testMergeWithDaysTransform() { "id INT, ts TIMESTAMP", "{ \"id\": 1, \"ts\": \"2000-01-01 00:00:00\" }\n" + "{ \"id\": 6, \"ts\": \"2000-01-06 00:00:00\" }"); + createBranchIfNeeded(); createOrReplaceView( "source", @@ -840,7 +870,7 @@ public void testMergeWithDaysTransform() { + " DELETE " + "WHEN NOT MATCHED AND s.id = 2 THEN " + " INSERT *", - tableName); + commitTarget()); ImmutableList expectedRows = ImmutableList.of( @@ -850,7 +880,7 @@ public void testMergeWithDaysTransform() { assertEquals( "Should have expected rows", expectedRows, - sql("SELECT id, CAST(ts AS STRING) FROM %s ORDER BY id", tableName)); + sql("SELECT id, CAST(ts AS STRING) FROM %s ORDER BY id", selectTarget())); removeTables(); } @@ -868,6 +898,7 @@ public void testMergeWithBucketTransform() { append( tableName, "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + createBranchIfNeeded(); createOrReplaceView( "source", @@ -885,7 +916,7 @@ public void testMergeWithBucketTransform() { + " DELETE " + "WHEN NOT MATCHED AND s.id = 2 THEN " + " INSERT *", - tableName); + commitTarget()); ImmutableList expectedRows = ImmutableList.of( @@ -895,7 +926,7 @@ public void testMergeWithBucketTransform() { assertEquals( "Should have expected rows", expectedRows, - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); removeTables(); } @@ -913,6 +944,7 @@ public void testMergeWithTruncateTransform() { append( tableName, "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + createBranchIfNeeded(); createOrReplaceView( "source", @@ -930,7 +962,7 @@ public void testMergeWithTruncateTransform() { + " DELETE " + "WHEN NOT MATCHED AND s.id = 2 THEN " + " INSERT *", - tableName); + commitTarget()); ImmutableList expectedRows = ImmutableList.of( @@ -940,7 +972,7 @@ public void testMergeWithTruncateTransform() { assertEquals( "Should have expected rows", expectedRows, - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); removeTables(); } @@ -959,6 +991,7 @@ public void testMergeIntoPartitionedAndOrderedTable() { append( tableName, "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }"); + createBranchIfNeeded(); createOrReplaceView( "source", @@ -976,7 +1009,7 @@ public void testMergeIntoPartitionedAndOrderedTable() { + " DELETE " + "WHEN NOT MATCHED AND s.id = 2 THEN " + " INSERT *", - tableName); + commitTarget()); ImmutableList expectedRows = ImmutableList.of( @@ -986,7 +1019,7 @@ public void testMergeIntoPartitionedAndOrderedTable() { assertEquals( "Should have expected rows", expectedRows, - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); removeTables(); } @@ -1004,7 +1037,7 @@ public void testSelfMerge() { + " UPDATE SET v = 'x' " + "WHEN NOT MATCHED THEN " + " INSERT *", - tableName, tableName); + commitTarget(), commitTarget()); ImmutableList expectedRows = ImmutableList.of( @@ -1012,7 +1045,7 @@ public void testSelfMerge() { row(2, "v2") // kept ); assertEquals( - "Output should match", expectedRows, sql("SELECT * FROM %s 
ORDER BY id", tableName)); + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -1029,7 +1062,7 @@ public void testSelfMergeWithCaching() { + " UPDATE SET v = 'x' " + "WHEN NOT MATCHED THEN " + " INSERT *", - tableName, tableName); + commitTarget(), commitTarget()); ImmutableList expectedRows = ImmutableList.of( @@ -1037,7 +1070,7 @@ public void testSelfMergeWithCaching() { row(2, "v2") // kept ); assertEquals( - "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", commitTarget())); } @Test @@ -1054,7 +1087,7 @@ public void testMergeWithSourceAsSelfSubquery() { + " UPDATE SET v = 'x' " + "WHEN NOT MATCHED THEN " + " INSERT (v, id) VALUES ('invalid', -1) ", - tableName, tableName); + commitTarget(), commitTarget()); ImmutableList expectedRows = ImmutableList.of( @@ -1062,7 +1095,7 @@ public void testMergeWithSourceAsSelfSubquery() { row(2, "v2") // kept ); assertEquals( - "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -1078,6 +1111,7 @@ public synchronized void testMergeWithSerializableIsolation() throws Interrupted tableName, MERGE_ISOLATION_LEVEL, "serializable"); sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); ExecutorService executorService = MoreExecutors.getExitingExecutorService( @@ -1100,7 +1134,7 @@ public synchronized void testMergeWithSerializableIsolation() throws Interrupted + "ON t.id == s.value " + "WHEN MATCHED THEN " + " UPDATE SET dep = 'x'", - tableName); + commitTarget()); barrier.incrementAndGet(); } @@ -1128,7 +1162,11 @@ public synchronized void testMergeWithSerializableIsolation() throws Interrupted for (int numAppends = 0; numAppends < 5; numAppends++) { DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); - table.newFastAppend().appendFile(dataFile).commit(); + AppendFiles appendFiles = table.newFastAppend().appendFile(dataFile); + if (branch != null) { + appendFiles.toBranch(branch); + } + appendFiles.commit(); sleep(10); } @@ -1167,6 +1205,7 @@ public synchronized void testMergeWithSnapshotIsolation() tableName, MERGE_ISOLATION_LEVEL, "snapshot"); sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); ExecutorService executorService = MoreExecutors.getExitingExecutorService( @@ -1189,7 +1228,7 @@ public synchronized void testMergeWithSnapshotIsolation() + "ON t.id == s.value " + "WHEN MATCHED THEN " + " UPDATE SET dep = 'x'", - tableName); + commitTarget()); barrier.incrementAndGet(); } @@ -1217,7 +1256,12 @@ public synchronized void testMergeWithSnapshotIsolation() for (int numAppends = 0; numAppends < 5; numAppends++) { DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); - table.newFastAppend().appendFile(dataFile).commit(); + AppendFiles appendFiles = table.newFastAppend().appendFile(dataFile); + if (branch != null) { + appendFiles.toBranch(branch); + } + + appendFiles.commit(); sleep(10); } @@ -1253,7 +1297,7 @@ public void testMergeWithExtraColumnsInSource() { + " UPDATE SET v = source.v " + "WHEN NOT MATCHED THEN " + " INSERT (v, id) VALUES (source.v, source.id)", - tableName); + commitTarget()); ImmutableList expectedRows = ImmutableList.of( @@ -1263,7 +1307,7 @@ public void testMergeWithExtraColumnsInSource() { row(4, "v4") // new ); assertEquals( - "Output should 
match", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -1281,7 +1325,7 @@ public void testMergeWithNullsInTargetAndSource() { + " UPDATE SET v = source.v " + "WHEN NOT MATCHED THEN " + " INSERT (v, id) VALUES (source.v, source.id)", - tableName); + commitTarget()); ImmutableList expectedRows = ImmutableList.of( @@ -1291,7 +1335,7 @@ public void testMergeWithNullsInTargetAndSource() { row(4, "v4") // new ); assertEquals( - "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY v", tableName)); + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY v", selectTarget())); } @Test @@ -1309,7 +1353,7 @@ public void testMergeWithNullSafeEquals() { + " UPDATE SET v = source.v " + "WHEN NOT MATCHED THEN " + " INSERT (v, id) VALUES (source.v, source.id)", - tableName); + commitTarget()); ImmutableList expectedRows = ImmutableList.of( @@ -1318,7 +1362,7 @@ public void testMergeWithNullSafeEquals() { row(4, "v4") // new ); assertEquals( - "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY v", tableName)); + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY v", selectTarget())); } @Test @@ -1336,7 +1380,7 @@ public void testMergeWithNullCondition() { + " UPDATE SET v = source.v " + "WHEN NOT MATCHED THEN " + " INSERT (v, id) VALUES (source.v, source.id)", - tableName); + commitTarget()); ImmutableList expectedRows = ImmutableList.of( @@ -1346,7 +1390,7 @@ public void testMergeWithNullCondition() { row(2, "v2_2") // new ); assertEquals( - "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY v", tableName)); + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY v", selectTarget())); } @Test @@ -1370,7 +1414,7 @@ public void testMergeWithNullActionConditions() { + " DELETE " + "WHEN NOT MATCHED AND source.id = 3 AND NULL THEN " + " INSERT (v, id) VALUES (source.v, source.id)", - tableName); + commitTarget()); ImmutableList expectedRows1 = ImmutableList.of( @@ -1378,7 +1422,7 @@ public void testMergeWithNullActionConditions() { row(2, "v2") // kept ); assertEquals( - "Output should match", expectedRows1, sql("SELECT * FROM %s ORDER BY v", tableName)); + "Output should match", expectedRows1, sql("SELECT * FROM %s ORDER BY v", selectTarget())); // only the update and insert conditions are NULL sql( @@ -1390,14 +1434,14 @@ public void testMergeWithNullActionConditions() { + " DELETE " + "WHEN NOT MATCHED AND source.id = 3 AND NULL THEN " + " INSERT (v, id) VALUES (source.v, source.id)", - tableName); + commitTarget()); ImmutableList expectedRows2 = ImmutableList.of( row(2, "v2") // kept ); assertEquals( - "Output should match", expectedRows2, sql("SELECT * FROM %s ORDER BY v", tableName)); + "Output should match", expectedRows2, sql("SELECT * FROM %s ORDER BY v", selectTarget())); } @Test @@ -1418,7 +1462,7 @@ public void testMergeWithMultipleMatchingActions() { + " DELETE " + "WHEN NOT MATCHED THEN " + " INSERT (v, id) VALUES (source.v, source.id)", - tableName); + commitTarget()); ImmutableList expectedRows = ImmutableList.of( @@ -1426,7 +1470,7 @@ public void testMergeWithMultipleMatchingActions() { row(2, "v2") // kept (matches neither the update nor the delete cond) ); assertEquals( - "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY v", tableName)); + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY v", selectTarget())); } @Test @@ -1453,8 +1497,9 @@ public 
void testMergeWithMultipleRowGroupsParquet() throws NoSuchTableException .withColumnRenamed("value", "id") .withColumn("dep", lit("hr")); df.coalesce(1).writeTo(tableName).append(); + createBranchIfNeeded(); - Assert.assertEquals(200, spark.table(tableName).count()); + Assert.assertEquals(200, spark.table(commitTarget()).count()); // update a record from one of two row groups and copy over the second one sql( @@ -1462,9 +1507,9 @@ public void testMergeWithMultipleRowGroupsParquet() throws NoSuchTableException + "ON t.id == source.value " + "WHEN MATCHED THEN " + " UPDATE SET dep = 'x'", - tableName); + commitTarget()); - Assert.assertEquals(200, spark.table(tableName).count()); + Assert.assertEquals(200, spark.table(commitTarget()).count()); } @Test @@ -1485,7 +1530,7 @@ public void testMergeInsertOnly() { + "ON t.id == source.id " + "WHEN NOT MATCHED THEN " + " INSERT *", - tableName); + commitTarget()); ImmutableList expectedRows = ImmutableList.of( @@ -1496,7 +1541,7 @@ public void testMergeInsertOnly() { row("d", "v4_2") // new ); assertEquals( - "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -1514,7 +1559,7 @@ public void testMergeInsertOnlyWithCondition() { + "ON t.id == s.id " + "WHEN NOT MATCHED AND is_new = TRUE THEN " + " INSERT (v, id) VALUES (s.v + 100, s.id)", - tableName); + commitTarget()); ImmutableList expectedRows = ImmutableList.of( @@ -1522,7 +1567,7 @@ public void testMergeInsertOnlyWithCondition() { row(2, 121) // new ); assertEquals( - "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + "Output should match", expectedRows, sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -1540,12 +1585,12 @@ public void testMergeAlignsUpdateAndInsertActions() { + " UPDATE SET b = c2, a = c1, t.id = source.id " + "WHEN NOT MATCHED THEN " + " INSERT (b, a, id) VALUES (c2, c1, id)", - tableName); + commitTarget()); assertEquals( "Output should match", ImmutableList.of(row(1, -2, "new_str_1"), row(2, -20, "new_str_2")), - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -1563,21 +1608,21 @@ public void testMergeMixedCaseAlignsUpdateAndInsertActions() { + " UPDATE SET B = c2, A = c1, t.Id = source.ID " + "WHEN NOT MATCHED THEN " + " INSERT (b, A, iD) VALUES (c2, c1, id)", - tableName); + commitTarget()); assertEquals( "Output should match", ImmutableList.of(row(1, -2, "new_str_1"), row(2, -20, "new_str_2")), - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); assertEquals( "Output should match", ImmutableList.of(row(1, -2, "new_str_1")), - sql("SELECT * FROM %s WHERE id = 1 ORDER BY id", tableName)); + sql("SELECT * FROM %s WHERE id = 1 ORDER BY id", selectTarget())); assertEquals( "Output should match", ImmutableList.of(row(2, -20, "new_str_2")), - sql("SELECT * FROM %s WHERE b = 'new_str_2'ORDER BY id", tableName)); + sql("SELECT * FROM %s WHERE b = 'new_str_2'ORDER BY id", selectTarget())); } @Test @@ -1593,12 +1638,12 @@ public void testMergeUpdatesNestedStructFields() { + "ON t.id == source.id " + "WHEN MATCHED THEN " + " UPDATE SET t.s.c1 = source.c1, t.s.c2.a = array(-1, -2), t.s.c2.m = map('k', 'v')", - tableName); + commitTarget()); assertEquals( "Output should match", ImmutableList.of(row(1, row(-2, row(ImmutableList.of(-1, -2), ImmutableMap.of("k", "v"))))), - 
sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); // set primitive, array, map columns to NULL (proper casts should be in place) sql( @@ -1606,12 +1651,12 @@ public void testMergeUpdatesNestedStructFields() { + "ON t.id == source.id " + "WHEN MATCHED THEN " + " UPDATE SET t.s.c1 = NULL, t.s.c2 = NULL", - tableName); + commitTarget()); assertEquals( "Output should match", ImmutableList.of(row(1, row(null, null))), - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); // update all fields in a struct sql( @@ -1619,12 +1664,12 @@ public void testMergeUpdatesNestedStructFields() { + "ON t.id == source.id " + "WHEN MATCHED THEN " + " UPDATE SET t.s = named_struct('c1', 100, 'c2', named_struct('a', array(1), 'm', map('x', 'y')))", - tableName); + commitTarget()); assertEquals( "Output should match", ImmutableList.of(row(1, row(100, row(ImmutableList.of(1), ImmutableMap.of("x", "y"))))), - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -1638,12 +1683,12 @@ public void testMergeWithInferredCasts() { + "ON t.id == source.id " + "WHEN MATCHED THEN " + " UPDATE SET t.s = source.c1", - tableName); + commitTarget()); assertEquals( "Output should match", ImmutableList.of(row(1, "-2")), - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -1656,12 +1701,12 @@ public void testMergeModifiesNullStruct() { + "ON t.id == s.id " + "WHEN MATCHED THEN " + " UPDATE SET t.s.n1 = s.n1", - tableName); + commitTarget()); assertEquals( "Output should match", ImmutableList.of(row(1, row(-10, null))), - sql("SELECT * FROM %s", tableName)); + sql("SELECT * FROM %s", selectTarget())); } @Test @@ -1669,7 +1714,7 @@ public void testMergeRefreshesRelationCache() { createAndInitTable("id INT, name STRING", "{ \"id\": 1, \"name\": \"n1\" }"); createOrReplaceView("source", "{ \"id\": 1, \"name\": \"n2\" }"); - Dataset query = spark.sql("SELECT name FROM " + tableName); + Dataset query = spark.sql("SELECT name FROM " + commitTarget()); query.createOrReplaceTempView("tmp"); spark.sql("CACHE TABLE tmp"); @@ -1682,7 +1727,7 @@ public void testMergeRefreshesRelationCache() { + "ON t.id == s.id " + "WHEN MATCHED THEN " + " UPDATE SET t.name = s.name", - tableName); + commitTarget()); assertEquals( "View should have correct data", ImmutableList.of(row("n2")), sql("SELECT * FROM tmp")); @@ -1708,7 +1753,7 @@ public void testMergeWithMultipleNotMatchedActions() { + " INSERT (dep, id) VALUES (s.dep, -1)" + "WHEN NOT MATCHED THEN " + " INSERT *", - tableName); + commitTarget()); ImmutableList expectedRows = ImmutableList.of( @@ -1718,7 +1763,9 @@ public void testMergeWithMultipleNotMatchedActions() { row(3, "emp-id-3") // new ); assertEquals( - "Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -1739,7 +1786,7 @@ public void testMergeWithMultipleConditionalNotMatchedActions() { + " INSERT (dep, id) VALUES (s.dep, -1)" + "WHEN NOT MATCHED AND s.id = 2 THEN " + " INSERT *", - tableName); + commitTarget()); ImmutableList expectedRows = ImmutableList.of( @@ -1748,7 +1795,9 @@ public void testMergeWithMultipleConditionalNotMatchedActions() { row(2, "emp-id-2") // new ); assertEquals( - "Should have expected rows", expectedRows, sql("SELECT * FROM %s 
ORDER BY id", tableName)); + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -1772,7 +1821,7 @@ public void testMergeResolvesColumnsByName() { + " UPDATE SET * " + "WHEN NOT MATCHED THEN " + " INSERT * ", - tableName); + commitTarget()); ImmutableList expectedRows = ImmutableList.of( @@ -1783,7 +1832,7 @@ public void testMergeResolvesColumnsByName() { assertEquals( "Should have expected rows", expectedRows, - sql("SELECT id, badge, dep FROM %s ORDER BY id", tableName)); + sql("SELECT id, badge, dep FROM %s ORDER BY id", selectTarget())); } @Test @@ -1807,6 +1856,7 @@ public void testMergeShouldResolveWhenThereAreNoUnresolvedExpressionsOrColumns() + "WHEN NOT MATCHED THEN " + " INSERT *", tableName); + createBranchIfNeeded(); ImmutableList expectedRows = ImmutableList.of( @@ -1815,7 +1865,9 @@ public void testMergeShouldResolveWhenThereAreNoUnresolvedExpressionsOrColumns() row(3, "emp-id-3") // new ); assertEquals( - "Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -1840,19 +1892,23 @@ public void testMergeWithTableWithNonNullableColumn() { + " DELETE " + "WHEN NOT MATCHED AND s.id = 2 THEN " + " INSERT *", - tableName); + commitTarget()); ImmutableList expectedRows = ImmutableList.of( row(1, "emp-id-1"), // updated row(2, "emp-id-2")); // new assertEquals( - "Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName)); + "Should have expected rows", + expectedRows, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test public void testMergeWithNonExistingColumns() { - createAndInitTable("id INT, c STRUCT>"); + createAndInitTable( + "id INT, c STRUCT>", + "{ \"id\": 1, \"c\": { \"n1\": 2, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); AssertHelpers.assertThrows( @@ -1865,7 +1921,7 @@ public void testMergeWithNonExistingColumns() { + "ON t.id == s.c1 " + "WHEN MATCHED THEN " + " UPDATE SET t.invalid_col = s.c2", - tableName); + commitTarget()); }); AssertHelpers.assertThrows( @@ -1878,7 +1934,7 @@ public void testMergeWithNonExistingColumns() { + "ON t.id == s.c1 " + "WHEN MATCHED THEN " + " UPDATE SET t.c.n2.invalid_col = s.c2", - tableName); + commitTarget()); }); AssertHelpers.assertThrows( @@ -1893,13 +1949,15 @@ public void testMergeWithNonExistingColumns() { + " UPDATE SET t.c.n2.dn1 = s.c2 " + "WHEN NOT MATCHED THEN " + " INSERT (id, invalid_col) VALUES (s.c1, null)", - tableName); + commitTarget()); }); } @Test public void testMergeWithInvalidColumnsInInsert() { - createAndInitTable("id INT, c STRUCT>"); + createAndInitTable( + "id INT, c STRUCT>", + "{ \"id\": 1, \"c\": { \"n1\": 2, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); AssertHelpers.assertThrows( @@ -1914,7 +1972,7 @@ public void testMergeWithInvalidColumnsInInsert() { + " UPDATE SET t.c.n2.dn1 = s.c2 " + "WHEN NOT MATCHED THEN " + " INSERT (id, c.n2) VALUES (s.c1, null)", - tableName); + commitTarget()); }); AssertHelpers.assertThrows( @@ -1929,7 +1987,7 @@ public void testMergeWithInvalidColumnsInInsert() { + " UPDATE SET t.c.n2.dn1 = s.c2 " + "WHEN NOT MATCHED THEN " + " INSERT (id, id) VALUES (s.c1, null)", - tableName); + commitTarget()); }); AssertHelpers.assertThrows( @@ -1942,13 +2000,15 @@ public void testMergeWithInvalidColumnsInInsert() { 
+ "ON t.id == s.c1 " + "WHEN NOT MATCHED THEN " + " INSERT (id) VALUES (s.c1)", - tableName); + commitTarget()); }); } @Test public void testMergeWithInvalidUpdates() { - createAndInitTable("id INT, a ARRAY>, m MAP"); + createAndInitTable( + "id INT, a ARRAY>, m MAP", + "{ \"id\": 1, \"a\": [ { \"c1\": 2, \"c2\": 3 } ], \"m\": { \"k\": \"v\"} }"); createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); AssertHelpers.assertThrows( @@ -1961,7 +2021,7 @@ public void testMergeWithInvalidUpdates() { + "ON t.id == s.c1 " + "WHEN MATCHED THEN " + " UPDATE SET t.a.c1 = s.c2", - tableName); + commitTarget()); }); AssertHelpers.assertThrows( @@ -1974,13 +2034,15 @@ public void testMergeWithInvalidUpdates() { + "ON t.id == s.c1 " + "WHEN MATCHED THEN " + " UPDATE SET t.m.key = 'new_key'", - tableName); + commitTarget()); }); } @Test public void testMergeWithConflictingUpdates() { - createAndInitTable("id INT, c STRUCT>"); + createAndInitTable( + "id INT, c STRUCT>", + "{ \"id\": 1, \"c\": { \"n1\": 2, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); AssertHelpers.assertThrows( @@ -1993,7 +2055,7 @@ public void testMergeWithConflictingUpdates() { + "ON t.id == s.c1 " + "WHEN MATCHED THEN " + " UPDATE SET t.id = 1, t.c.n1 = 2, t.id = 2", - tableName); + commitTarget()); }); AssertHelpers.assertThrows( @@ -2006,7 +2068,7 @@ public void testMergeWithConflictingUpdates() { + "ON t.id == s.c1 " + "WHEN MATCHED THEN " + " UPDATE SET t.c.n1 = 1, t.id = 2, t.c.n1 = 2", - tableName); + commitTarget()); }); AssertHelpers.assertThrows( @@ -2019,14 +2081,15 @@ public void testMergeWithConflictingUpdates() { + "ON t.id == s.c1 " + "WHEN MATCHED THEN " + " UPDATE SET c.n1 = 1, c = named_struct('n1', 1, 'n2', named_struct('dn1', 1, 'dn2', 2))", - tableName); + commitTarget()); }); } @Test public void testMergeWithInvalidAssignments() { createAndInitTable( - "id INT NOT NULL, s STRUCT> NOT NULL"); + "id INT NOT NULL, s STRUCT> NOT NULL", + "{ \"id\": 1, \"s\": { \"n1\": 2, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); createOrReplaceView( "source", "c1 INT, c2 STRUCT NOT NULL, c3 STRING NOT NULL, c4 STRUCT", @@ -2046,7 +2109,7 @@ public void testMergeWithInvalidAssignments() { + "ON t.id == s.c1 " + "WHEN MATCHED THEN " + " UPDATE SET t.id = NULL", - tableName); + commitTarget()); }); AssertHelpers.assertThrows( @@ -2059,7 +2122,7 @@ public void testMergeWithInvalidAssignments() { + "ON t.id == s.c1 " + "WHEN MATCHED THEN " + " UPDATE SET t.s.n1 = NULL", - tableName); + commitTarget()); }); AssertHelpers.assertThrows( @@ -2072,7 +2135,7 @@ public void testMergeWithInvalidAssignments() { + "ON t.id == s.c1 " + "WHEN MATCHED THEN " + " UPDATE SET t.s = s.c2", - tableName); + commitTarget()); }); AssertHelpers.assertThrows( @@ -2085,7 +2148,7 @@ public void testMergeWithInvalidAssignments() { + "ON t.id == s.c1 " + "WHEN MATCHED THEN " + " UPDATE SET t.s.n1 = s.c3", - tableName); + commitTarget()); }); AssertHelpers.assertThrows( @@ -2098,7 +2161,7 @@ public void testMergeWithInvalidAssignments() { + "ON t.id == s.c1 " + "WHEN MATCHED THEN " + " UPDATE SET t.s.n2 = s.c4", - tableName); + commitTarget()); }); }); } @@ -2106,7 +2169,9 @@ public void testMergeWithInvalidAssignments() { @Test public void testMergeWithNonDeterministicConditions() { - createAndInitTable("id INT, c STRUCT>"); + createAndInitTable( + "id INT, c STRUCT>", + "{ \"id\": 1, \"c\": { \"n1\": 2, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); createOrReplaceView("source", "{ \"c1\": -100, 
\"c2\": -200 }"); AssertHelpers.assertThrows( @@ -2119,7 +2184,7 @@ public void testMergeWithNonDeterministicConditions() { + "ON t.id == s.c1 AND rand() > t.id " + "WHEN MATCHED THEN " + " UPDATE SET t.c.n1 = -1", - tableName); + commitTarget()); }); AssertHelpers.assertThrows( @@ -2132,7 +2197,7 @@ public void testMergeWithNonDeterministicConditions() { + "ON t.id == s.c1 " + "WHEN MATCHED AND rand() > t.id THEN " + " UPDATE SET t.c.n1 = -1", - tableName); + commitTarget()); }); AssertHelpers.assertThrows( @@ -2145,7 +2210,7 @@ public void testMergeWithNonDeterministicConditions() { + "ON t.id == s.c1 " + "WHEN MATCHED AND rand() > t.id THEN " + " DELETE", - tableName); + commitTarget()); }); AssertHelpers.assertThrows( @@ -2158,13 +2223,15 @@ public void testMergeWithNonDeterministicConditions() { + "ON t.id == s.c1 " + "WHEN NOT MATCHED AND rand() > c1 THEN " + " INSERT (id, c) VALUES (1, null)", - tableName); + commitTarget()); }); } @Test public void testMergeWithAggregateExpressions() { - createAndInitTable("id INT, c STRUCT>"); + createAndInitTable( + "id INT, c STRUCT>", + "{ \"id\": 1, \"c\": { \"n1\": 2, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); AssertHelpers.assertThrows( @@ -2177,7 +2244,7 @@ public void testMergeWithAggregateExpressions() { + "ON t.id == s.c1 AND max(t.id) == 1 " + "WHEN MATCHED THEN " + " UPDATE SET t.c.n1 = -1", - tableName); + commitTarget()); }); AssertHelpers.assertThrows( @@ -2190,7 +2257,7 @@ public void testMergeWithAggregateExpressions() { + "ON t.id == s.c1 " + "WHEN MATCHED AND sum(t.id) < 1 THEN " + " UPDATE SET t.c.n1 = -1", - tableName); + commitTarget()); }); AssertHelpers.assertThrows( @@ -2203,7 +2270,7 @@ public void testMergeWithAggregateExpressions() { + "ON t.id == s.c1 " + "WHEN MATCHED AND sum(t.id) THEN " + " DELETE", - tableName); + commitTarget()); }); AssertHelpers.assertThrows( @@ -2216,13 +2283,15 @@ public void testMergeWithAggregateExpressions() { + "ON t.id == s.c1 " + "WHEN NOT MATCHED AND sum(c1) < 1 THEN " + " INSERT (id, c) VALUES (1, null)", - tableName); + commitTarget()); }); } @Test public void testMergeWithSubqueriesInConditions() { - createAndInitTable("id INT, c STRUCT>"); + createAndInitTable( + "id INT, c STRUCT>", + "{ \"id\": 1, \"c\": { \"n1\": 2, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }"); AssertHelpers.assertThrows( @@ -2235,7 +2304,7 @@ public void testMergeWithSubqueriesInConditions() { + "ON t.id == s.c1 AND t.id < (SELECT max(c2) FROM source) " + "WHEN MATCHED THEN " + " UPDATE SET t.c.n1 = s.c2", - tableName); + commitTarget()); }); AssertHelpers.assertThrows( @@ -2248,7 +2317,7 @@ public void testMergeWithSubqueriesInConditions() { + "ON t.id == s.c1 " + "WHEN MATCHED AND t.id < (SELECT max(c2) FROM source) THEN " + " UPDATE SET t.c.n1 = s.c2", - tableName); + commitTarget()); }); AssertHelpers.assertThrows( @@ -2261,7 +2330,7 @@ public void testMergeWithSubqueriesInConditions() { + "ON t.id == s.c1 " + "WHEN MATCHED AND t.id NOT IN (SELECT c2 FROM source) THEN " + " DELETE", - tableName); + commitTarget()); }); AssertHelpers.assertThrows( @@ -2274,13 +2343,13 @@ public void testMergeWithSubqueriesInConditions() { + "ON t.id == s.c1 " + "WHEN NOT MATCHED AND s.c1 IN (SELECT c2 FROM source) THEN " + " INSERT (id, c) VALUES (1, null)", - tableName); + commitTarget()); }); } @Test public void testMergeWithTargetColumnsInInsertConditions() { - createAndInitTable("id INT, c2 INT"); + 
createAndInitTable("id INT, c2 INT", "{ \"id\": 1, \"c2\": 2 }"); createOrReplaceView("source", "{ \"id\": 1, \"value\": 11 }"); AssertHelpers.assertThrows( @@ -2293,7 +2362,7 @@ public void testMergeWithTargetColumnsInInsertConditions() { + "ON t.id == s.id " + "WHEN NOT MATCHED AND c2 = 1 THEN " + " INSERT (id, c2) VALUES (s.id, null)", - tableName); + commitTarget()); }); } @@ -2331,17 +2400,18 @@ public void testMergeSinglePartitionPartitioning() { "MERGE INTO %s t USING source s ON t.id = s.id " + "WHEN MATCHED THEN UPDATE SET *" + "WHEN NOT MATCHED THEN INSERT *", - tableName); + commitTarget()); ImmutableList expectedRows = ImmutableList.of(row(-1), row(0), row(1), row(2), row(3), row(4)); - List result = sql("SELECT * FROM %s ORDER BY id", tableName); + List result = sql("SELECT * FROM %s ORDER BY id", selectTarget()); assertEquals("Should correctly add the non-matching rows", expectedRows, result); } @Test public void testMergeEmptyTable() { + Assume.assumeFalse("Custom branch does not exist for empty table", "test".equals(branch)); // This table will only have a single file and a single partition createAndInitTable("id INT", null); @@ -2352,14 +2422,32 @@ public void testMergeEmptyTable() { "MERGE INTO %s t USING source s ON t.id = s.id " + "WHEN MATCHED THEN UPDATE SET *" + "WHEN NOT MATCHED THEN INSERT *", - tableName); + commitTarget()); ImmutableList expectedRows = ImmutableList.of(row(0), row(1), row(2), row(3), row(4)); - List result = sql("SELECT * FROM %s ORDER BY id", tableName); + List result = sql("SELECT * FROM %s ORDER BY id", selectTarget()); assertEquals("Should correctly add the non-matching rows", expectedRows, result); } + @Test + public void testMergeNonExistingBranch() { + Assume.assumeTrue("Test only applicable to custom branch", "test".equals(branch)); + createAndInitTable("id INT", null); + + // Coalesce forces our source into a SinglePartition distribution + spark.range(0, 5).coalesce(1).createOrReplaceTempView("source"); + Assertions.assertThatThrownBy( + () -> + sql( + "MERGE INTO %s t USING source s ON t.id = s.id " + + "WHEN MATCHED THEN UPDATE SET *" + + "WHEN NOT MATCHED THEN INSERT *", + commitTarget())) + .isInstanceOf(ValidationException.class) + .hasMessage("Cannot use branch (does not exist): test"); + } + private void checkJoinAndFilterConditions(String query, String join, String icebergFilters) { // disable runtime filtering for easier validation withSQLConf( diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java index d7b6c0cda465..6d8120127586 100644 --- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadDelete.java @@ -50,8 +50,9 @@ public TestMergeOnReadDelete( Map config, String fileFormat, Boolean vectorized, - String distributionMode) { - super(catalogName, implementation, config, fileFormat, vectorized, distributionMode); + String distributionMode, + String branch) { + super(catalogName, implementation, config, fileFormat, vectorized, distributionMode, branch); } @Override @@ -74,14 +75,19 @@ public void testCommitUnknownException() { // write unpartitioned files append(tableName, "{ \"id\": 1, \"dep\": \"hr\", \"category\": \"c1\"}"); + createBranchIfNeeded(); append( - tableName, + commitTarget(), "{ \"id\": 2, 
\"dep\": \"hr\", \"category\": \"c1\" }\n" + "{ \"id\": 3, \"dep\": \"hr\", \"category\": \"c1\" }"); Table table = validationCatalog.loadTable(tableIdent); RowDelta newRowDelta = table.newRowDelta(); + if (branch != null) { + newRowDelta.toBranch(branch); + } + RowDelta spyNewRowDelta = spy(newRowDelta); doAnswer( invocation -> { @@ -93,7 +99,8 @@ public void testCommitUnknownException() { Table spyTable = spy(table); when(spyTable.newRowDelta()).thenReturn(spyNewRowDelta); - SparkTable sparkTable = new SparkTable(spyTable, false); + SparkTable sparkTable = + branch == null ? new SparkTable(spyTable, false) : new SparkTable(spyTable, branch, false); ImmutableMap config = ImmutableMap.of( @@ -129,11 +136,12 @@ public void testAggregatePushDownInMergeOnReadDelete() { sql( "INSERT INTO TABLE %s VALUES (1, 1111), (1, 2222), (2, 3333), (2, 4444), (3, 5555), (3, 6666) ", tableName); + createBranchIfNeeded(); - sql("DELETE FROM %s WHERE data = 1111", tableName); + sql("DELETE FROM %s WHERE data = 1111", commitTarget()); String select = "SELECT max(data), min(data), count(data) FROM %s"; - List explain = sql("EXPLAIN " + select, tableName); + List explain = sql("EXPLAIN " + select, selectTarget()); String explainString = explain.get(0)[0].toString().toLowerCase(Locale.ROOT); boolean explainContainsPushDownAggregates = false; if (explainString.contains("max(data)") @@ -145,7 +153,7 @@ public void testAggregatePushDownInMergeOnReadDelete() { Assert.assertFalse( "min/max/count not pushed down for deleted", explainContainsPushDownAggregates); - List actual = sql(select, tableName); + List actual = sql(select, selectTarget()); List expected = Lists.newArrayList(); expected.add(new Object[] {6666, 2222, 5L}); assertEquals("min/max/count push down", expected, actual); diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadMerge.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadMerge.java index 0453a8787a33..86629a127687 100644 --- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadMerge.java +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadMerge.java @@ -31,8 +31,9 @@ public TestMergeOnReadMerge( Map config, String fileFormat, boolean vectorized, - String distributionMode) { - super(catalogName, implementation, config, fileFormat, vectorized, distributionMode); + String distributionMode, + String branch) { + super(catalogName, implementation, config, fileFormat, vectorized, distributionMode, branch); } @Override diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadUpdate.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadUpdate.java index 2f359c7b4021..416ee8773af6 100644 --- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadUpdate.java +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMergeOnReadUpdate.java @@ -31,8 +31,9 @@ public TestMergeOnReadUpdate( Map config, String fileFormat, boolean vectorized, - String distributionMode) { - super(catalogName, implementation, config, fileFormat, vectorized, distributionMode); + String distributionMode, + String branch) { + super(catalogName, implementation, config, fileFormat, vectorized, distributionMode, branch); } @Override diff --git 
a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java index eb6ffbe71043..f9230915d9e1 100644 --- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java @@ -41,6 +41,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; +import org.apache.iceberg.AppendFiles; import org.apache.iceberg.AssertHelpers; import org.apache.iceberg.DataFile; import org.apache.iceberg.RowLevelOperationMode; @@ -56,6 +57,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecutors; import org.apache.iceberg.spark.SparkSQLProperties; +import org.apache.iceberg.util.SnapshotUtil; import org.apache.spark.SparkException; import org.apache.spark.sql.AnalysisException; import org.apache.spark.sql.Dataset; @@ -78,8 +80,9 @@ public TestUpdate( Map config, String fileFormat, boolean vectorized, - String distributionMode) { - super(catalogName, implementation, config, fileFormat, vectorized, distributionMode); + String distributionMode, + String branch) { + super(catalogName, implementation, config, fileFormat, vectorized, distributionMode, branch); } @BeforeClass @@ -100,10 +103,11 @@ public void testExplain() { createAndInitTable("id INT, dep STRING"); sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + createBranchIfNeeded(); - sql("EXPLAIN UPDATE %s SET dep = 'invalid' WHERE id <=> 1", tableName); + sql("EXPLAIN UPDATE %s SET dep = 'invalid' WHERE id <=> 1", commitTarget()); - sql("EXPLAIN UPDATE %s SET dep = 'invalid' WHERE true", tableName); + sql("EXPLAIN UPDATE %s SET dep = 'invalid' WHERE true", commitTarget()); Table table = validationCatalog.loadTable(tableIdent); Assert.assertEquals("Should have 1 snapshot", 1, Iterables.size(table.snapshots())); @@ -111,15 +115,16 @@ public void testExplain() { assertEquals( "Should have expected rows", ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); } @Test public void testUpdateEmptyTable() { + Assume.assumeFalse("Custom branch does not exist for empty table", "test".equals(branch)); createAndInitTable("id INT, dep STRING"); - sql("UPDATE %s SET dep = 'invalid' WHERE id IN (1)", tableName); - sql("UPDATE %s SET id = -1 WHERE dep = 'hr'", tableName); + sql("UPDATE %s SET dep = 'invalid' WHERE id IN (1)", commitTarget()); + sql("UPDATE %s SET id = -1 WHERE dep = 'hr'", commitTarget()); Table table = validationCatalog.loadTable(tableIdent); Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots())); @@ -127,7 +132,18 @@ public void testUpdateEmptyTable() { assertEquals( "Should have expected rows", ImmutableList.of(), - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testUpdateNonExistingCustomBranch() { + Assume.assumeTrue("Test only applicable to custom branch", "test".equals(branch)); + createAndInitTable("id INT, dep STRING"); + + Assertions.assertThatThrownBy( + () -> sql("UPDATE %s SET dep = 'invalid' WHERE id IN 
(1)", commitTarget())) + .isInstanceOf(ValidationException.class) + .hasMessage("Cannot use branch (does not exist): test"); } @Test @@ -135,7 +151,7 @@ public void testUpdateWithAlias() { createAndInitTable("id INT, dep STRING", "{ \"id\": 1, \"dep\": \"a\" }"); sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); - sql("UPDATE %s AS t SET t.dep = 'invalid'", tableName); + sql("UPDATE %s AS t SET t.dep = 'invalid'", commitTarget()); Table table = validationCatalog.loadTable(tableIdent); Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots())); @@ -143,7 +159,7 @@ public void testUpdateWithAlias() { assertEquals( "Should have expected rows", ImmutableList.of(row(1, "invalid")), - sql("SELECT * FROM %s", tableName)); + sql("SELECT * FROM %s", selectTarget())); } @Test @@ -151,13 +167,14 @@ public void testUpdateAlignsAssignments() { createAndInitTable("id INT, c1 INT, c2 INT"); sql("INSERT INTO TABLE %s VALUES (1, 11, 111), (2, 22, 222)", tableName); + createBranchIfNeeded(); - sql("UPDATE %s SET `c2` = c2 - 2, c1 = `c1` - 1 WHERE id <=> 1", tableName); + sql("UPDATE %s SET `c2` = c2 - 2, c1 = `c1` - 1 WHERE id <=> 1", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(1, 10, 109), row(2, 22, 222)), - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -166,13 +183,14 @@ public void testUpdateWithUnsupportedPartitionPredicate() { sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); sql("INSERT INTO TABLE %s VALUES (1, 'software'), (2, 'hr')", tableName); + createBranchIfNeeded(); - sql("UPDATE %s t SET `t`.`id` = -1 WHERE t.dep LIKE '%%r' ", tableName); + sql("UPDATE %s t SET `t`.`id` = -1 WHERE t.dep LIKE '%%r' ", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "hr"), row(1, "software")), - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -181,16 +199,17 @@ public void testUpdateWithDynamicFileFiltering() { sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 3, \"dep\": \"hr\" }"); + createBranchIfNeeded(); append( - tableName, + commitTarget(), "{ \"id\": 1, \"dep\": \"hardware\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }"); - sql("UPDATE %s SET id = cast('-1' AS INT) WHERE id = 2", tableName); + sql("UPDATE %s SET id = cast('-1' AS INT) WHERE id = 2", commitTarget()); Table table = validationCatalog.loadTable(tableIdent); Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots())); - Snapshot currentSnapshot = table.currentSnapshot(); + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); if (mode(table) == COPY_ON_WRITE) { validateCopyOnWrite(currentSnapshot, "1", "1", "1"); } else { @@ -200,7 +219,7 @@ public void testUpdateWithDynamicFileFiltering() { assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "hardware"), row(1, "hardware"), row(1, "hr"), row(3, "hr")), - sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + sql("SELECT * FROM %s ORDER BY id, dep", commitTarget())); } @Test @@ -209,13 +228,14 @@ public void testUpdateNonExistingRecords() { sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); sql("INSERT INTO TABLE %s VALUES (1, 'hr'), (2, 'hardware'), (null, 'hr')", tableName); + createBranchIfNeeded(); - sql("UPDATE %s SET id = -1 WHERE id > 10", tableName); + sql("UPDATE %s SET id = -1 WHERE id > 10", 
commitTarget()); Table table = validationCatalog.loadTable(tableIdent); Assert.assertEquals("Should have 2 snapshots", 2, Iterables.size(table.snapshots())); - Snapshot currentSnapshot = table.currentSnapshot(); + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); if (mode(table) == COPY_ON_WRITE) { validateCopyOnWrite(currentSnapshot, "0", null, null); } else { @@ -225,7 +245,7 @@ public void testUpdateNonExistingRecords() { assertEquals( "Should have expected rows", ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); } @Test @@ -236,8 +256,9 @@ public void testUpdateWithoutCondition() { sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION", tableName); sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); - sql("INSERT INTO TABLE %s VALUES (2, 'hardware')", tableName); - sql("INSERT INTO TABLE %s VALUES (null, 'hr')", tableName); + createBranchIfNeeded(); + sql("INSERT INTO TABLE %s VALUES (2, 'hardware')", commitTarget()); + sql("INSERT INTO TABLE %s VALUES (null, 'hr')", commitTarget()); // set the num of shuffle partitions to 200 instead of default 4 to reduce the chance of hashing // records for multiple source files to one writing task (needed for a predictable num of output @@ -245,13 +266,13 @@ public void testUpdateWithoutCondition() { withSQLConf( ImmutableMap.of(SQLConf.SHUFFLE_PARTITIONS().key(), "200"), () -> { - sql("UPDATE %s SET id = -1", tableName); + sql("UPDATE %s SET id = -1", commitTarget()); }); Table table = validationCatalog.loadTable(tableIdent); Assert.assertEquals("Should have 4 snapshots", 4, Iterables.size(table.snapshots())); - Snapshot currentSnapshot = table.currentSnapshot(); + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); Assert.assertEquals("Operation must match", OVERWRITE, currentSnapshot.operation()); if (mode(table) == COPY_ON_WRITE) { @@ -266,7 +287,7 @@ public void testUpdateWithoutCondition() { assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(-1, "hr")), - sql("SELECT * FROM %s ORDER BY dep ASC", tableName)); + sql("SELECT * FROM %s ORDER BY dep ASC", selectTarget())); } @Test @@ -278,27 +299,28 @@ public void testUpdateWithNullConditions() { "{ \"id\": 0, \"dep\": null }\n" + "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }"); + createBranchIfNeeded(); // should not update any rows as null is never equal to null - sql("UPDATE %s SET id = -1 WHERE dep = NULL", tableName); + sql("UPDATE %s SET id = -1 WHERE dep = NULL", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(0, null), row(1, "hr"), row(2, "hardware")), - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); // should not update any rows the condition does not match any records - sql("UPDATE %s SET id = -1 WHERE dep = 'software'", tableName); + sql("UPDATE %s SET id = -1 WHERE dep = 'software'", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(0, null), row(1, "hr"), row(2, "hardware")), - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); // should update one matching row with a null-safe condition - sql("UPDATE %s SET dep = 'invalid', id = -1 WHERE dep <=> NULL", tableName); + sql("UPDATE %s SET dep = 'invalid', id = -1 WHERE dep <=> NULL", 
commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "invalid"), row(1, "hr"), row(2, "hardware")), - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); } @Test @@ -310,24 +332,25 @@ public void testUpdateWithInAndNotInConditions() { "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + "{ \"id\": null, \"dep\": \"hr\" }"); + createBranchIfNeeded(); - sql("UPDATE %s SET id = -1 WHERE id IN (1, null)", tableName); + sql("UPDATE %s SET id = -1 WHERE id IN (1, null)", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "hr"), row(2, "hardware"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); - sql("UPDATE %s SET id = 100 WHERE id NOT IN (null, 1)", tableName); + sql("UPDATE %s SET id = 100 WHERE id NOT IN (null, 1)", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "hr"), row(2, "hardware"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); - sql("UPDATE %s SET id = 100 WHERE id NOT IN (1, 10)", tableName); + sql("UPDATE %s SET id = 100 WHERE id NOT IN (1, 10)", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(100, "hardware"), row(100, "hr"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", selectTarget())); } @Test @@ -352,13 +375,14 @@ public void testUpdateWithMultipleRowGroupsParquet() throws NoSuchTableException .withColumnRenamed("value", "id") .withColumn("dep", lit("hr")); df.coalesce(1).writeTo(tableName).append(); + createBranchIfNeeded(); - Assert.assertEquals(200, spark.table(tableName).count()); + Assert.assertEquals(200, spark.table(commitTarget()).count()); // update a record from one of two row groups and copy over the second one - sql("UPDATE %s SET id = -1 WHERE id IN (200, 201)", tableName); + sql("UPDATE %s SET id = -1 WHERE id IN (200, 201)", commitTarget()); - Assert.assertEquals(200, spark.table(tableName).count()); + Assert.assertEquals(200, spark.table(commitTarget()).count()); } @Test @@ -368,30 +392,30 @@ public void testUpdateNestedStructFields() { "{ \"id\": 1, \"s\": { \"c1\": 2, \"c2\": { \"a\": [1,2], \"m\": { \"a\": \"b\"} } } } }"); // update primitive, array, map columns inside a struct - sql("UPDATE %s SET s.c1 = -1, s.c2.m = map('k', 'v'), s.c2.a = array(-1)", tableName); + sql("UPDATE %s SET s.c1 = -1, s.c2.m = map('k', 'v'), s.c2.a = array(-1)", commitTarget()); assertEquals( "Output should match", ImmutableList.of(row(1, row(-1, row(ImmutableList.of(-1), ImmutableMap.of("k", "v"))))), - sql("SELECT * FROM %s", tableName)); + sql("SELECT * FROM %s", selectTarget())); // set primitive, array, map columns to NULL (proper casts should be in place) - sql("UPDATE %s SET s.c1 = NULL, s.c2 = NULL WHERE id IN (1)", tableName); + sql("UPDATE %s SET s.c1 = NULL, s.c2 = NULL WHERE id IN (1)", commitTarget()); assertEquals( "Output should match", ImmutableList.of(row(1, row(null, null))), - sql("SELECT * FROM %s", tableName)); + sql("SELECT * FROM %s", selectTarget())); // update all fields in a struct sql( "UPDATE %s SET s = named_struct('c1', 1, 'c2', named_struct('a', array(1), 'm', null))", - tableName); + commitTarget()); assertEquals( "Output 
should match", ImmutableList.of(row(1, row(1, row(ImmutableList.of(1), null)))), - sql("SELECT * FROM %s", tableName)); + sql("SELECT * FROM %s", selectTarget())); } @Test @@ -404,30 +428,31 @@ public void testUpdateWithUserDefinedDistribution() { "{ \"id\": 1, \"c2\": 11, \"c3\": 1 }\n" + "{ \"id\": 2, \"c2\": 22, \"c3\": 1 }\n" + "{ \"id\": 3, \"c2\": 33, \"c3\": 1 }"); + createBranchIfNeeded(); // request a global sort sql("ALTER TABLE %s WRITE ORDERED BY c2", tableName); - sql("UPDATE %s SET c2 = -22 WHERE id NOT IN (1, 3)", tableName); + sql("UPDATE %s SET c2 = -22 WHERE id NOT IN (1, 3)", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(1, 11, 1), row(2, -22, 1), row(3, 33, 1)), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); // request a local sort sql("ALTER TABLE %s WRITE LOCALLY ORDERED BY id", tableName); - sql("UPDATE %s SET c2 = -33 WHERE id = 3", tableName); + sql("UPDATE %s SET c2 = -33 WHERE id = 3", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(1, 11, 1), row(2, -22, 1), row(3, -33, 1)), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); // request a hash distribution + local sort sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION ORDERED BY id", tableName); - sql("UPDATE %s SET c2 = -11 WHERE id = 1", tableName); + sql("UPDATE %s SET c2 = -11 WHERE id = 1", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(1, -11, 1), row(2, -22, 1), row(3, -33, 1)), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); } @Test @@ -442,6 +467,7 @@ public synchronized void testUpdateWithSerializableIsolation() throws Interrupte tableName, UPDATE_ISOLATION_LEVEL, "serializable"); sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); ExecutorService executorService = MoreExecutors.getExitingExecutorService( @@ -459,7 +485,7 @@ public synchronized void testUpdateWithSerializableIsolation() throws Interrupte sleep(10); } - sql("UPDATE %s SET id = -1 WHERE id = 1", tableName); + sql("UPDATE %s SET id = -1 WHERE id = 1", commitTarget()); barrier.incrementAndGet(); } @@ -472,7 +498,7 @@ public synchronized void testUpdateWithSerializableIsolation() throws Interrupte // load the table via the validation catalog to use another table instance Table table = validationCatalog.loadTable(tableIdent); - GenericRecord record = GenericRecord.create(table.schema()); + GenericRecord record = GenericRecord.create(SnapshotUtil.schemaFor(table, branch)); record.set(0, 1); // id record.set(1, "hr"); // dep @@ -487,7 +513,12 @@ public synchronized void testUpdateWithSerializableIsolation() throws Interrupte for (int numAppends = 0; numAppends < 5; numAppends++) { DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); - table.newFastAppend().appendFile(dataFile).commit(); + AppendFiles appendFiles = table.newFastAppend().appendFile(dataFile); + if (branch != null) { + appendFiles.toBranch(branch); + } + + appendFiles.commit(); sleep(10); } @@ -525,6 +556,7 @@ public synchronized void testUpdateWithSnapshotIsolation() tableName, UPDATE_ISOLATION_LEVEL, "snapshot"); sql("INSERT INTO TABLE %s VALUES (1, 'hr')", tableName); + createBranchIfNeeded(); ExecutorService executorService = MoreExecutors.getExitingExecutorService( @@ 
-555,7 +587,7 @@ public synchronized void testUpdateWithSnapshotIsolation() // load the table via the validation catalog to use another table instance for inserts Table table = validationCatalog.loadTable(tableIdent); - GenericRecord record = GenericRecord.create(table.schema()); + GenericRecord record = GenericRecord.create(SnapshotUtil.schemaFor(table, branch)); record.set(0, 1); // id record.set(1, "hr"); // dep @@ -570,7 +602,12 @@ public synchronized void testUpdateWithSnapshotIsolation() for (int numAppends = 0; numAppends < 5; numAppends++) { DataFile dataFile = writeDataFile(table, ImmutableList.of(record)); - table.newFastAppend().appendFile(dataFile).commit(); + AppendFiles appendFiles = table.newFastAppend().appendFile(dataFile); + if (branch != null) { + appendFiles.toBranch(branch); + } + + appendFiles.commit(); sleep(10); } @@ -593,24 +630,24 @@ public synchronized void testUpdateWithSnapshotIsolation() public void testUpdateWithInferredCasts() { createAndInitTable("id INT, s STRING", "{ \"id\": 1, \"s\": \"value\" }"); - sql("UPDATE %s SET s = -1 WHERE id = 1", tableName); + sql("UPDATE %s SET s = -1 WHERE id = 1", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(1, "-1")), - sql("SELECT * FROM %s", tableName)); + sql("SELECT * FROM %s", selectTarget())); } @Test public void testUpdateModifiesNullStruct() { createAndInitTable("id INT, s STRUCT", "{ \"id\": 1, \"s\": null }"); - sql("UPDATE %s SET s.n1 = -1 WHERE id = 1", tableName); + sql("UPDATE %s SET s.n1 = -1 WHERE id = 1", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(1, row(-1, null))), - sql("SELECT * FROM %s", tableName)); + sql("SELECT * FROM %s", selectTarget())); } @Test @@ -619,12 +656,13 @@ public void testUpdateRefreshesRelationCache() { sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName); append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 3, \"dep\": \"hr\" }"); + createBranchIfNeeded(); append( - tableName, + commitTarget(), "{ \"id\": 1, \"dep\": \"hardware\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }"); - Dataset query = spark.sql("SELECT * FROM " + tableName + " WHERE id = 1"); + Dataset query = spark.sql("SELECT * FROM " + commitTarget() + " WHERE id = 1"); query.createOrReplaceTempView("tmp"); spark.sql("CACHE TABLE tmp"); @@ -634,12 +672,12 @@ public void testUpdateRefreshesRelationCache() { ImmutableList.of(row(1, "hardware"), row(1, "hr")), sql("SELECT * FROM tmp ORDER BY id, dep")); - sql("UPDATE %s SET id = -1 WHERE id = 1", tableName); + sql("UPDATE %s SET id = -1 WHERE id = 1", commitTarget()); Table table = validationCatalog.loadTable(tableIdent); Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots())); - Snapshot currentSnapshot = table.currentSnapshot(); + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); if (mode(table) == COPY_ON_WRITE) { validateCopyOnWrite(currentSnapshot, "2", "2", "2"); } else { @@ -649,7 +687,7 @@ public void testUpdateRefreshesRelationCache() { assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(2, "hardware"), row(3, "hr")), - sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + sql("SELECT * FROM %s ORDER BY id, dep", commitTarget())); assertEquals( "Should refresh the relation cache", @@ -668,6 +706,7 @@ public void testUpdateWithInSubquery() { "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + "{ \"id\": null, \"dep\": \"hr\" }"); + 
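Both isolation tests append files through the core API while the concurrent UPDATE runs, so those appends must also land on the branch under test; otherwise the conflict they are meant to provoke would never be visible to the branch commit. The inlined pattern from the two hunks above, pulled into a small helper for clarity (a sketch, not part of the patch):

import org.apache.iceberg.AppendFiles;
import org.apache.iceberg.DataFile;
import org.apache.iceberg.Table;

class BranchAppends {
  // Route a fast append to the branch under test; with no branch the snapshot advances main.
  static void appendTo(Table table, DataFile dataFile, String branch) {
    AppendFiles append = table.newFastAppend().appendFile(dataFile);
    if (branch != null) {
      append.toBranch(branch);
    }
    append.commit();
  }
}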
createBranchIfNeeded(); createOrReplaceView("updated_id", Arrays.asList(0, 1, null), Encoders.INT()); createOrReplaceView("updated_dep", Arrays.asList("software", "hr"), Encoders.STRING()); @@ -676,35 +715,36 @@ public void testUpdateWithInSubquery() { "UPDATE %s SET id = -1 WHERE " + "id IN (SELECT * FROM updated_id) AND " + "dep IN (SELECT * from updated_dep)", - tableName); + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "hr"), row(2, "hardware"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); sql( "UPDATE %s SET id = 5 WHERE id IS NULL OR id IN (SELECT value + 1 FROM updated_id)", - tableName); + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "hr"), row(5, "hardware"), row(5, "hr")), - sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); - append(tableName, "{ \"id\": null, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"hr\" }"); + append( + commitTarget(), "{ \"id\": null, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"hr\" }"); assertEquals( "Should have expected rows", ImmutableList.of( row(-1, "hr"), row(2, "hr"), row(5, "hardware"), row(5, "hr"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", selectTarget())); sql( "UPDATE %s SET id = 10 WHERE id IN (SELECT value + 2 FROM updated_id) AND dep = 'hr'", - tableName); + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of( row(-1, "hr"), row(5, "hardware"), row(5, "hr"), row(10, "hr"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", selectTarget())); } @Test @@ -715,18 +755,19 @@ public void testUpdateWithInSubqueryAndDynamicFileFiltering() { sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION", tableName); append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 3, \"dep\": \"hr\" }"); + createBranchIfNeeded(); append( - tableName, + commitTarget(), "{ \"id\": 1, \"dep\": \"hardware\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }"); createOrReplaceView("updated_id", Arrays.asList(-1, 2), Encoders.INT()); - sql("UPDATE %s SET id = -1 WHERE id IN (SELECT * FROM updated_id)", tableName); + sql("UPDATE %s SET id = -1 WHERE id IN (SELECT * FROM updated_id)", commitTarget()); Table table = validationCatalog.loadTable(tableIdent); Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots())); - Snapshot currentSnapshot = table.currentSnapshot(); + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table, branch); if (mode(table) == COPY_ON_WRITE) { validateCopyOnWrite(currentSnapshot, "1", "1", "1"); } else { @@ -736,7 +777,7 @@ public void testUpdateWithInSubqueryAndDynamicFileFiltering() { assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "hardware"), row(1, "hardware"), row(1, "hr"), row(3, "hr")), - sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + sql("SELECT * FROM %s ORDER BY id, dep", commitTarget())); } @Test @@ -744,12 +785,15 @@ public void testUpdateWithSelfSubquery() { createAndInitTable("id INT, dep STRING"); append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"hr\" }"); + createBranchIfNeeded(); - sql("UPDATE %s SET dep = 'x' WHERE id IN (SELECT id + 1 FROM %s)", tableName, 
tableName); + sql( + "UPDATE %s SET dep = 'x' WHERE id IN (SELECT id + 1 FROM %s)", + commitTarget(), commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(1, "hr"), row(2, "x")), - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); // TODO: Spark does not support AQE and DPP with aggregates at the moment withSQLConf( @@ -758,18 +802,18 @@ public void testUpdateWithSelfSubquery() { sql( "UPDATE %s SET dep = 'y' WHERE " + "id = (SELECT count(*) FROM (SELECT DISTINCT id FROM %s) AS t)", - tableName, tableName); + commitTarget(), commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(1, "hr"), row(2, "y")), - sql("SELECT * FROM %s ORDER BY id", tableName)); + sql("SELECT * FROM %s ORDER BY id", selectTarget())); }); - sql("UPDATE %s SET id = (SELECT id - 2 FROM %s WHERE id = 1)", tableName, tableName); + sql("UPDATE %s SET id = (SELECT id - 2 FROM %s WHERE id = 1)", commitTarget(), commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "hr"), row(-1, "y")), - sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); } @Test @@ -781,6 +825,7 @@ public void testUpdateWithMultiColumnInSubquery() { "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + "{ \"id\": null, \"dep\": \"hr\" }"); + createBranchIfNeeded(); List deletedEmployees = Arrays.asList(new Employee(null, "hr"), new Employee(1, "hr")); @@ -788,11 +833,11 @@ public void testUpdateWithMultiColumnInSubquery() { sql( "UPDATE %s SET dep = 'x', id = -1 WHERE (id, dep) IN (SELECT id, dep FROM deleted_employee)", - tableName); + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "x"), row(2, "hardware"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); } @Test @@ -804,32 +849,33 @@ public void testUpdateWithNotInSubquery() { "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + "{ \"id\": null, \"dep\": \"hr\" }"); + createBranchIfNeeded(); createOrReplaceView("updated_id", Arrays.asList(-1, -2, null), Encoders.INT()); createOrReplaceView("updated_dep", Arrays.asList("software", "hr"), Encoders.STRING()); // the file filter subquery (nested loop lef-anti join) returns 0 records - sql("UPDATE %s SET id = -1 WHERE id NOT IN (SELECT * FROM updated_id)", tableName); + sql("UPDATE %s SET id = -1 WHERE id NOT IN (SELECT * FROM updated_id)", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); sql( "UPDATE %s SET id = -1 WHERE id NOT IN (SELECT * FROM updated_id WHERE value IS NOT NULL)", - tableName); + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", selectTarget())); sql( "UPDATE %s SET id = 5 WHERE id NOT IN (SELECT * FROM updated_id) OR dep IN ('software', 'hr')", - tableName); + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "hardware"), row(5, "hr"), row(5, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC 
NULLS LAST, dep", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", selectTarget())); } @Test @@ -841,46 +887,47 @@ public void testUpdateWithExistSubquery() { "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + "{ \"id\": null, \"dep\": \"hr\" }"); + createBranchIfNeeded(); createOrReplaceView("updated_id", Arrays.asList(-1, -2, null), Encoders.INT()); createOrReplaceView("updated_dep", Arrays.asList("hr", null), Encoders.STRING()); sql( "UPDATE %s t SET id = -1 WHERE EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value)", - tableName); + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); sql( "UPDATE %s t SET dep = 'x', id = -1 WHERE " + "EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value + 2)", - tableName); + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "x"), row(2, "hardware"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); sql( "UPDATE %s t SET id = -2 WHERE " + "EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value) OR " + "t.id IS NULL", - tableName); + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(-2, "hr"), row(-2, "x"), row(2, "hardware")), - sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); sql( "UPDATE %s t SET id = 1 WHERE " + "EXISTS (SELECT 1 FROM updated_id ui WHERE t.id = ui.value) AND " + "EXISTS (SELECT 1 FROM updated_dep ud WHERE t.dep = ud.value)", - tableName); + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(-2, "x"), row(1, "hr"), row(2, "hardware")), - sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); } @Test @@ -892,37 +939,38 @@ public void testUpdateWithNotExistsSubquery() { "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + "{ \"id\": null, \"dep\": \"hr\" }"); + createBranchIfNeeded(); createOrReplaceView("updated_id", Arrays.asList(-1, -2, null), Encoders.INT()); createOrReplaceView("updated_dep", Arrays.asList("hr", "software"), Encoders.STRING()); sql( "UPDATE %s t SET id = -1 WHERE NOT EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value + 2)", - tableName); + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(1, "hr")), - sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); sql( "UPDATE %s t SET id = 5 WHERE " + "NOT EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value) OR " + "t.id = 1", - tableName); + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(5, "hr")), - sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); sql( "UPDATE %s t SET id = 10 WHERE " + "NOT EXISTS (SELECT 1 FROM updated_id ui WHERE t.id = ui.value) AND " + "EXISTS (SELECT 1 FROM updated_dep ud WHERE t.dep = ud.value)", - tableName); + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(10, "hr")), - sql("SELECT * FROM %s 
ORDER BY id, dep", tableName)); + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); } @Test @@ -934,6 +982,7 @@ public void testUpdateWithScalarSubquery() { "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"hardware\" }\n" + "{ \"id\": null, \"dep\": \"hr\" }"); + createBranchIfNeeded(); createOrReplaceView("updated_id", Arrays.asList(1, 100, null), Encoders.INT()); @@ -941,11 +990,13 @@ public void testUpdateWithScalarSubquery() { withSQLConf( ImmutableMap.of(SQLConf.ADAPTIVE_EXECUTION_ENABLED().key(), "false"), () -> { - sql("UPDATE %s SET id = -1 WHERE id <= (SELECT min(value) FROM updated_id)", tableName); + sql( + "UPDATE %s SET id = -1 WHERE id <= (SELECT min(value) FROM updated_id)", + commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "hr"), row(2, "hardware"), row(null, "hr")), - sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName)); + sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", selectTarget())); }); } @@ -959,21 +1010,22 @@ public void testUpdateThatRequiresGroupingBeforeWrite() { "{ \"id\": 0, \"dep\": \"hr\" }\n" + "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"hr\" }"); + createBranchIfNeeded(); append( - tableName, + commitTarget(), "{ \"id\": 0, \"dep\": \"ops\" }\n" + "{ \"id\": 1, \"dep\": \"ops\" }\n" + "{ \"id\": 2, \"dep\": \"ops\" }"); append( - tableName, + commitTarget(), "{ \"id\": 0, \"dep\": \"hr\" }\n" + "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"hr\" }"); append( - tableName, + commitTarget(), "{ \"id\": 0, \"dep\": \"ops\" }\n" + "{ \"id\": 1, \"dep\": \"ops\" }\n" + "{ \"id\": 2, \"dep\": \"ops\" }"); @@ -985,8 +1037,9 @@ public void testUpdateThatRequiresGroupingBeforeWrite() { // set the num of shuffle partitions to 1 to ensure we have only 1 writing task spark.conf().set("spark.sql.shuffle.partitions", "1"); - sql("UPDATE %s t SET id = -1 WHERE id IN (SELECT * FROM updated_id)", tableName); - Assert.assertEquals("Should have expected num of rows", 12L, spark.table(tableName).count()); + sql("UPDATE %s t SET id = -1 WHERE id IN (SELECT * FROM updated_id)", commitTarget()); + Assert.assertEquals( + "Should have expected num of rows", 12L, spark.table(commitTarget()).count()); } finally { spark.conf().set("spark.sql.shuffle.partitions", originalNumOfShufflePartitions); } @@ -1001,16 +1054,17 @@ public void testUpdateWithVectorization() { "{ \"id\": 0, \"dep\": \"hr\" }\n" + "{ \"id\": 1, \"dep\": \"hr\" }\n" + "{ \"id\": 2, \"dep\": \"hr\" }"); + createBranchIfNeeded(); withSQLConf( ImmutableMap.of(SparkSQLProperties.VECTORIZATION_ENABLED, "true"), () -> { - sql("UPDATE %s t SET id = -1", tableName); + sql("UPDATE %s t SET id = -1", commitTarget()); assertEquals( "Should have expected rows", ImmutableList.of(row(-1, "hr"), row(-1, "hr"), row(-1, "hr")), - sql("SELECT * FROM %s ORDER BY id, dep", tableName)); + sql("SELECT * FROM %s ORDER BY id, dep", selectTarget())); }); } @@ -1033,6 +1087,7 @@ public void testUpdateModifyPartitionSourceField() throws NoSuchTableException { .withColumn("dep", lit("hr")) .withColumn("country", lit("usa")); df1.coalesce(1).writeTo(tableName).append(); + createBranchIfNeeded(); Dataset df2 = spark @@ -1040,7 +1095,7 @@ public void testUpdateModifyPartitionSourceField() throws NoSuchTableException { .withColumnRenamed("value", "id") .withColumn("dep", lit("software")) .withColumn("country", lit("usa")); - df2.coalesce(1).writeTo(tableName).append(); + df2.coalesce(1).writeTo(commitTarget()).append(); Dataset df3 = spark @@ 
-1048,10 +1103,12 @@ public void testUpdateModifyPartitionSourceField() throws NoSuchTableException { .withColumnRenamed("value", "id") .withColumn("dep", lit("hardware")) .withColumn("country", lit("usa")); - df3.coalesce(1).writeTo(tableName).append(); + df3.coalesce(1).writeTo(commitTarget()).append(); - sql("UPDATE %s SET id = -1 WHERE id IN (10, 11, 12, 13, 14, 15, 16, 17, 18, 19)", tableName); - Assert.assertEquals(30L, scalarSql("SELECT count(*) FROM %s WHERE id = -1", tableName)); + sql( + "UPDATE %s SET id = -1 WHERE id IN (10, 11, 12, 13, 14, 15, 16, 17, 18, 19)", + commitTarget()); + Assert.assertEquals(30L, scalarSql("SELECT count(*) FROM %s WHERE id = -1", selectTarget())); } @Test @@ -1062,13 +1119,14 @@ public void testUpdateWithStaticPredicatePushdown() { // add a data file to the 'software' partition append(tableName, "{ \"id\": 1, \"dep\": \"software\" }"); + createBranchIfNeeded(); // add a data file to the 'hr' partition - append(tableName, "{ \"id\": 1, \"dep\": \"hr\" }"); + append(commitTarget(), "{ \"id\": 1, \"dep\": \"hr\" }"); Table table = validationCatalog.loadTable(tableIdent); - Snapshot snapshot = table.currentSnapshot(); + Snapshot snapshot = SnapshotUtil.latestSnapshot(table, branch); String dataFilesCount = snapshot.summary().get(SnapshotSummary.TOTAL_DATA_FILES_PROP); Assert.assertEquals("Must have 2 files before UPDATE", "2", dataFilesCount); @@ -1080,42 +1138,45 @@ public void testUpdateWithStaticPredicatePushdown() { withSQLConf( ImmutableMap.of(SQLConf.DYNAMIC_PARTITION_PRUNING_ENABLED().key(), "false"), () -> { - sql("UPDATE %s SET id = -1 WHERE dep IN ('software') AND id == 1", tableName); + sql("UPDATE %s SET id = -1 WHERE dep IN ('software') AND id == 1", commitTarget()); }); } @Test public void testUpdateWithInvalidUpdates() { - createAndInitTable("id INT, a ARRAY>, m MAP"); + createAndInitTable( + "id INT, a ARRAY>, m MAP", + "{ \"id\": 0, \"a\": null, \"m\": null }"); AssertHelpers.assertThrows( "Should complain about updating an array column", AnalysisException.class, "Updating nested fields is only supported for structs", - () -> sql("UPDATE %s SET a.c1 = 1", tableName)); + () -> sql("UPDATE %s SET a.c1 = 1", commitTarget())); AssertHelpers.assertThrows( "Should complain about updating a map column", AnalysisException.class, "Updating nested fields is only supported for structs", - () -> sql("UPDATE %s SET m.key = 'new_key'", tableName)); + () -> sql("UPDATE %s SET m.key = 'new_key'", commitTarget())); } @Test public void testUpdateWithConflictingAssignments() { - createAndInitTable("id INT, c STRUCT>"); + createAndInitTable( + "id INT, c STRUCT>", "{ \"id\": 0, \"s\": null }"); AssertHelpers.assertThrows( "Should complain about conflicting updates to a top-level column", AnalysisException.class, "Updates are in conflict", - () -> sql("UPDATE %s t SET t.id = 1, t.c.n1 = 2, t.id = 2", tableName)); + () -> sql("UPDATE %s t SET t.id = 1, t.c.n1 = 2, t.id = 2", commitTarget())); AssertHelpers.assertThrows( "Should complain about conflicting updates to a nested column", AnalysisException.class, "Updates are in conflict for these columns", - () -> sql("UPDATE %s t SET t.c.n1 = 1, t.id = 2, t.c.n1 = 2", tableName)); + () -> sql("UPDATE %s t SET t.c.n1 = 1, t.id = 2, t.c.n1 = 2", commitTarget())); AssertHelpers.assertThrows( "Should complain about conflicting updates to a nested column", @@ -1124,14 +1185,15 @@ public void testUpdateWithConflictingAssignments() { () -> { sql( "UPDATE %s SET c.n1 = 1, c = named_struct('n1', 1, 'n2', 
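From here on, the negative tests (invalid updates, conflicting and invalid assignments, non-deterministic conditions) seed one row when the table is created. A branch ref needs an existing snapshot to point at, so an empty table cannot be branched; seeding a row gives createBranchIfNeeded() a snapshot to branch from and keeps the branch-targeted statements against commitTarget() valid instead of failing with "Cannot use branch (does not exist)" as in the first test above. A sketch of the adjusted setup, assuming the createAndInitTable(schema, json) helper used throughout this suite:

// Seed a snapshot before branching; with zero snapshots there is nothing to branch from.
createAndInitTable("id INT, dep STRING", "{ \"id\": 1, \"dep\": \"hr\" }");
createBranchIfNeeded();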
named_struct('dn1', 1, 'dn2', 2))", - tableName); + commitTarget()); }); } @Test public void testUpdateWithInvalidAssignments() { createAndInitTable( - "id INT NOT NULL, s STRUCT> NOT NULL"); + "id INT NOT NULL, s STRUCT> NOT NULL", + "{ \"id\": 0, \"s\": { \"n1\": 1, \"n2\": { \"dn1\": 3, \"dn2\": 4 } } }"); for (String policy : new String[] {"ansi", "strict"}) { withSQLConf( @@ -1141,44 +1203,47 @@ public void testUpdateWithInvalidAssignments() { "Should complain about writing nulls to a top-level column", AnalysisException.class, "Cannot write nullable values to non-null column", - () -> sql("UPDATE %s t SET t.id = NULL", tableName)); + () -> sql("UPDATE %s t SET t.id = NULL", commitTarget())); AssertHelpers.assertThrows( "Should complain about writing nulls to a nested column", AnalysisException.class, "Cannot write nullable values to non-null column", - () -> sql("UPDATE %s t SET t.s.n1 = NULL", tableName)); + () -> sql("UPDATE %s t SET t.s.n1 = NULL", commitTarget())); AssertHelpers.assertThrows( "Should complain about writing missing fields in structs", AnalysisException.class, "missing fields", - () -> sql("UPDATE %s t SET t.s = named_struct('n1', 1)", tableName)); + () -> sql("UPDATE %s t SET t.s = named_struct('n1', 1)", commitTarget())); AssertHelpers.assertThrows( "Should complain about writing invalid data types", AnalysisException.class, "Cannot safely cast", - () -> sql("UPDATE %s t SET t.s.n1 = 'str'", tableName)); + () -> sql("UPDATE %s t SET t.s.n1 = 'str'", commitTarget())); AssertHelpers.assertThrows( "Should complain about writing incompatible structs", AnalysisException.class, "field name does not match", - () -> sql("UPDATE %s t SET t.s.n2 = named_struct('dn2', 1, 'dn1', 2)", tableName)); + () -> + sql( + "UPDATE %s t SET t.s.n2 = named_struct('dn2', 1, 'dn1', 2)", + commitTarget())); }); } } @Test public void testUpdateWithNonDeterministicCondition() { - createAndInitTable("id INT, dep STRING"); + createAndInitTable("id INT, dep STRING", "{ \"id\": 1, \"dep\": \"hr\" }"); AssertHelpers.assertThrows( "Should complain about non-deterministic expressions", AnalysisException.class, "nondeterministic expressions are only allowed", - () -> sql("UPDATE %s SET id = -1 WHERE id = 1 AND rand() > 0.5", tableName)); + () -> sql("UPDATE %s SET id = -1 WHERE id = 1 AND rand() > 0.5", commitTarget())); } @Test diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkCatalog.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkCatalog.java index f9a4787fc16d..3ad3f5d0ee2a 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkCatalog.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkCatalog.java @@ -172,7 +172,12 @@ public Table loadTable(Identifier ident, String version) throws NoSuchTableExcep ValidationException.check( ref != null, "Cannot find matching snapshot ID or reference name for version " + version); - return sparkTable.copyWithSnapshotId(ref.snapshotId()); + + if (ref.isBranch()) { + return sparkTable.copyWithBranch(version); + } else { + return sparkTable.copyWithSnapshotId(ref.snapshotId()); + } } } else if (table instanceof SparkChangelogTable) { @@ -659,10 +664,7 @@ private Table load(Identifier ident) { Matcher branch = BRANCH.matcher(ident.name()); if (branch.matches()) { - Snapshot branchSnapshot = table.snapshot(branch.group(1)); - if (branchSnapshot != null) { - return new SparkTable(table, branchSnapshot.snapshotId(), !cacheEnabled); - } + return new SparkTable(table, branch.group(1), 
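The SparkCatalog changes above and below keep branch semantics distinct from snapshot pinning: a VERSION AS OF value or a branch_<name> identifier that resolves to a branch yields a SparkTable bound to the branch name, so schema and scans follow the branch head as it advances, while tags and snapshot ids stay pinned. A condensed sketch of that decision, assuming the copyWithBranch method this patch introduces:

import org.apache.iceberg.SnapshotRef;

// `ref` is the SnapshotRef resolved from the requested version earlier in loadTable.
static SparkTable withVersion(SparkTable sparkTable, String version, SnapshotRef ref) {
  if (ref.isBranch()) {
    return sparkTable.copyWithBranch(version);              // name-bound: follows the branch head
  } else {
    return sparkTable.copyWithSnapshotId(ref.snapshotId()); // tag: pinned to one snapshot
  }
}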
!cacheEnabled); } Matcher tag = TAG.matcher(ident.name()); @@ -759,10 +761,7 @@ private Table loadFromPathIdentifier(PathIdentifier ident) { return new SparkTable(table, snapshotIdAsOfTime, !cacheEnabled); } else if (branch != null) { - Snapshot branchSnapshot = table.snapshot(branch); - Preconditions.checkArgument( - branchSnapshot != null, "Cannot find snapshot associated with branch name: %s", branch); - return new SparkTable(table, branchSnapshot.snapshotId(), !cacheEnabled); + return new SparkTable(table, branch, !cacheEnabled); } else if (tag != null) { Snapshot tagSnapshot = table.snapshot(tag); diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java index a44929aa30ab..1d2576180c24 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkReadConf.java @@ -21,6 +21,7 @@ import java.util.Map; import org.apache.iceberg.Table; import org.apache.iceberg.TableProperties; +import org.apache.iceberg.exceptions.ValidationException; import org.apache.iceberg.hadoop.Util; import org.apache.iceberg.util.PropertyUtil; import org.apache.spark.sql.SparkSession; @@ -47,12 +48,19 @@ public class SparkReadConf { private final SparkSession spark; private final Table table; + private final String branch; private final Map readOptions; private final SparkConfParser confParser; public SparkReadConf(SparkSession spark, Table table, Map readOptions) { + this(spark, table, null, readOptions); + } + + public SparkReadConf( + SparkSession spark, Table table, String branch, Map readOptions) { this.spark = spark; this.table = table; + this.branch = branch; this.readOptions = readOptions; this.confParser = new SparkConfParser(spark, table, readOptions); } @@ -83,7 +91,14 @@ public Long endSnapshotId() { } public String branch() { - return confParser.stringConf().option(SparkReadOptions.BRANCH).parseOptional(); + String optionBranch = confParser.stringConf().option(SparkReadOptions.BRANCH).parseOptional(); + ValidationException.check( + branch == null || optionBranch == null || optionBranch.equals(branch), + "Must not specify different branches in both table identifier and read option, " + + "got [%s] in identifier and [%s] in options", + branch, + optionBranch); + return branch != null ? 
branch : optionBranch; } public String tag() { diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java index 8620c46adaf6..8e88a9b9bdf0 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SparkWriteConf.java @@ -56,12 +56,19 @@ public class SparkWriteConf { private final Table table; + private final String branch; private final RuntimeConfig sessionConf; private final Map writeOptions; private final SparkConfParser confParser; public SparkWriteConf(SparkSession spark, Table table, Map writeOptions) { + this(spark, table, null, writeOptions); + } + + public SparkWriteConf( + SparkSession spark, Table table, String branch, Map writeOptions) { this.table = table; + this.branch = branch; this.sessionConf = spark.conf(); this.writeOptions = writeOptions; this.confParser = new SparkConfParser(spark, table, writeOptions); @@ -324,4 +331,8 @@ public boolean caseSensitive() { .defaultValue(SQLConf.CASE_SENSITIVE().defaultValueString()) .parse(); } + + public String branch() { + return branch; + } } diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BaseBatchReader.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BaseBatchReader.java index d206bb8e2b5b..c05b694a60dc 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BaseBatchReader.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BaseBatchReader.java @@ -43,10 +43,11 @@ abstract class BaseBatchReader extends BaseReader taskGroup, + Schema tableSchema, Schema expectedSchema, boolean caseSensitive, int batchSize) { - super(table, taskGroup, expectedSchema, caseSensitive); + super(table, taskGroup, tableSchema, expectedSchema, caseSensitive); this.batchSize = batchSize; } diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BaseReader.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BaseReader.java index 2333cd734bbe..4fb838202c88 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BaseReader.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BaseReader.java @@ -73,6 +73,7 @@ abstract class BaseReader implements Closeable { private static final Logger LOG = LoggerFactory.getLogger(BaseReader.class); private final Table table; + private final Schema tableSchema; private final Schema expectedSchema; private final boolean caseSensitive; private final NameMapping nameMapping; @@ -86,11 +87,16 @@ abstract class BaseReader implements Closeable { private TaskT currentTask = null; BaseReader( - Table table, ScanTaskGroup taskGroup, Schema expectedSchema, boolean caseSensitive) { + Table table, + ScanTaskGroup taskGroup, + Schema tableSchema, + Schema expectedSchema, + boolean caseSensitive) { this.table = table; this.taskGroup = taskGroup; this.tasks = taskGroup.tasks().iterator(); this.currentIterator = CloseableIterator.empty(); + this.tableSchema = tableSchema; this.expectedSchema = expectedSchema; this.caseSensitive = caseSensitive; String nameMappingString = table.properties().get(TableProperties.DEFAULT_NAME_MAPPING); @@ -252,7 +258,7 @@ protected class SparkDeleteFilter extends DeleteFilter { private final InternalRowWrapper asStructLike; SparkDeleteFilter(String filePath, List deletes, DeleteCounter counter) { - super(filePath, deletes, table.schema(), 
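With branch identifiers in play, SparkReadConf above can receive the branch from two places: the table identifier (e.g. db.tbl.branch_audit) and the branch read option. The new check rejects conflicting values and otherwise uses whichever is set. A usage sketch (SparkReadOptions.BRANCH is the existing read option key; the three-argument calls use the constructor added above):

import java.util.Map;
import org.apache.iceberg.Table;
import org.apache.iceberg.exceptions.ValidationException;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.spark.SparkReadConf;
import org.apache.iceberg.spark.SparkReadOptions;
import org.apache.spark.sql.SparkSession;

class BranchPrecedenceExample {
  static void show(SparkSession spark, Table table) {
    Map<String, String> options = ImmutableMap.of(SparkReadOptions.BRANCH, "audit");

    new SparkReadConf(spark, table, "audit", options).branch();  // identical values -> "audit"
    new SparkReadConf(spark, table, null, options).branch();     // option only -> "audit"

    try {
      new SparkReadConf(spark, table, "main", options).branch(); // identifier vs option conflict
    } catch (ValidationException e) {
      // "Must not specify different branches in both table identifier and read option ..."
    }
  }
}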
expectedSchema, counter); + super(filePath, deletes, tableSchema, expectedSchema, counter); this.asStructLike = new InternalRowWrapper(SparkSchemaUtil.convert(requiredSchema())); } diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BaseRowReader.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BaseRowReader.java index 608f0df0075d..927084caea1c 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BaseRowReader.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BaseRowReader.java @@ -40,8 +40,12 @@ abstract class BaseRowReader extends BaseReader { BaseRowReader( - Table table, ScanTaskGroup taskGroup, Schema expectedSchema, boolean caseSensitive) { - super(table, taskGroup, expectedSchema, caseSensitive); + Table table, + ScanTaskGroup taskGroup, + Schema tableSchema, + Schema expectedSchema, + boolean caseSensitive) { + super(table, taskGroup, tableSchema, expectedSchema, caseSensitive); } protected CloseableIterable newIterable( diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java index e087fde3f8db..389ad1d5a2d9 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java @@ -30,6 +30,7 @@ import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.spark.source.metrics.TaskNumDeletes; import org.apache.iceberg.spark.source.metrics.TaskNumSplits; +import org.apache.iceberg.util.SnapshotUtil; import org.apache.spark.rdd.InputFileBlockHolder; import org.apache.spark.sql.connector.metric.CustomTaskMetric; import org.apache.spark.sql.connector.read.PartitionReader; @@ -48,6 +49,7 @@ class BatchDataReader extends BaseBatchReader this( partition.table(), partition.taskGroup(), + SnapshotUtil.schemaFor(partition.table(), partition.branch()), partition.expectedSchema(), partition.isCaseSensitive(), batchSize); @@ -56,10 +58,11 @@ class BatchDataReader extends BaseBatchReader BatchDataReader( Table table, ScanTaskGroup taskGroup, + Schema tableSchema, Schema expectedSchema, boolean caseSensitive, int size) { - super(table, taskGroup, expectedSchema, caseSensitive, size); + super(table, taskGroup, tableSchema, expectedSchema, caseSensitive, size); numSplits = taskGroup.tasks().size(); LOG.debug("Reading {} file split(s) for table {}", numSplits, table.name()); diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/ChangelogRowReader.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/ChangelogRowReader.java index ecda3ca37c8f..572f955884a3 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/ChangelogRowReader.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/ChangelogRowReader.java @@ -37,6 +37,7 @@ import org.apache.iceberg.io.CloseableIterator; import org.apache.iceberg.io.InputFile; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.util.SnapshotUtil; import org.apache.spark.rdd.InputFileBlockHolder; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; @@ -51,6 +52,7 @@ class ChangelogRowReader extends BaseRowReader this( partition.table(), partition.taskGroup(), + SnapshotUtil.schemaFor(partition.table(), 
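BaseReader now threads a tableSchema next to the projected expectedSchema and hands it to SparkDeleteFilter, because equality and position deletes must be resolved against the schema of the snapshot actually being read; table.schema() always reflects main and can diverge from a branch after schema evolution. That schema comes from the SnapshotUtil.schemaFor overload added earlier in this patch, and the same resolution is reused below for grouping-key transforms and metadata-column validation. A minimal sketch:

import org.apache.iceberg.Schema;
import org.apache.iceberg.Table;
import org.apache.iceberg.util.SnapshotUtil;

// The schema a reader and its delete filter should use: the branch head's schema when a
// branch is set, the table's current schema otherwise.
static Schema readerSchema(Table table, String branch) {
  return SnapshotUtil.schemaFor(table, branch);
}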
partition.branch()), partition.expectedSchema(), partition.isCaseSensitive()); } @@ -58,9 +60,15 @@ class ChangelogRowReader extends BaseRowReader ChangelogRowReader( Table table, ScanTaskGroup taskGroup, + Schema tableSchema, Schema expectedSchema, boolean caseSensitive) { - super(table, taskGroup, ChangelogUtil.dropChangelogMetadata(expectedSchema), caseSensitive); + super( + table, + taskGroup, + tableSchema, + ChangelogUtil.dropChangelogMetadata(expectedSchema), + caseSensitive); } @Override diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/EqualityDeleteRowReader.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/EqualityDeleteRowReader.java index b930fc1091de..f5b98a5a43bd 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/EqualityDeleteRowReader.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/EqualityDeleteRowReader.java @@ -30,8 +30,12 @@ public class EqualityDeleteRowReader extends RowDataReader { public EqualityDeleteRowReader( - CombinedScanTask task, Table table, Schema expectedSchema, boolean caseSensitive) { - super(table, task, expectedSchema, caseSensitive); + CombinedScanTask task, + Table table, + Schema tableSchema, + Schema expectedSchema, + boolean caseSensitive) { + super(table, task, tableSchema, expectedSchema, caseSensitive); } @Override diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/PositionDeletesRowReader.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/PositionDeletesRowReader.java index 04eecc80bb49..4b847474153c 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/PositionDeletesRowReader.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/PositionDeletesRowReader.java @@ -33,6 +33,7 @@ import org.apache.iceberg.io.InputFile; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.primitives.Ints; +import org.apache.iceberg.util.SnapshotUtil; import org.apache.spark.rdd.InputFileBlockHolder; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.connector.read.PartitionReader; @@ -48,6 +49,7 @@ class PositionDeletesRowReader extends BaseRowReader this( partition.table(), partition.taskGroup(), + SnapshotUtil.schemaFor(partition.table(), partition.branch()), partition.expectedSchema(), partition.isCaseSensitive()); } @@ -55,10 +57,11 @@ class PositionDeletesRowReader extends BaseRowReader PositionDeletesRowReader( Table table, ScanTaskGroup taskGroup, + Schema tableSchema, Schema expectedSchema, boolean caseSensitive) { - super(table, taskGroup, expectedSchema, caseSensitive); + super(table, taskGroup, tableSchema, expectedSchema, caseSensitive); int numSplits = taskGroup.tasks().size(); LOG.debug("Reading {} position delete file split(s) for table {}", numSplits, table.name()); diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java index 3729df930cfe..9356f62f3593 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java @@ -32,6 +32,7 @@ import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.spark.source.metrics.TaskNumDeletes; import org.apache.iceberg.spark.source.metrics.TaskNumSplits; 
+import org.apache.iceberg.util.SnapshotUtil; import org.apache.spark.rdd.InputFileBlockHolder; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.connector.metric.CustomTaskMetric; @@ -48,6 +49,7 @@ class RowDataReader extends BaseRowReader implements PartitionRead this( partition.table(), partition.taskGroup(), + SnapshotUtil.schemaFor(partition.table(), partition.branch()), partition.expectedSchema(), partition.isCaseSensitive()); } @@ -55,10 +57,11 @@ class RowDataReader extends BaseRowReader implements PartitionRead RowDataReader( Table table, ScanTaskGroup taskGroup, + Schema tableSchema, Schema expectedSchema, boolean caseSensitive) { - super(table, taskGroup, expectedSchema, caseSensitive); + super(table, taskGroup, tableSchema, expectedSchema, caseSensitive); numSplits = taskGroup.tasks().size(); LOG.debug("Reading {} file split(s) for table {}", numSplits, table.name()); diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java index 394a920a2dfa..63aef25ba9b1 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatch.java @@ -41,6 +41,7 @@ class SparkBatch implements Batch { private final JavaSparkContext sparkContext; private final Table table; + private final String branch; private final SparkReadConf readConf; private final Types.StructType groupingKeyType; private final List> taskGroups; @@ -59,6 +60,7 @@ class SparkBatch implements Batch { int scanHashCode) { this.sparkContext = sparkContext; this.table = table; + this.branch = readConf.branch(); this.readConf = readConf; this.groupingKeyType = groupingKeyType; this.taskGroups = taskGroups; @@ -87,6 +89,7 @@ public InputPartition[] planInputPartitions() { groupingKeyType, taskGroups.get(index), tableBroadcast, + branch, expectedSchemaString, caseSensitive, localityEnabled)); diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java index ed589b2ce819..dd493fbc5097 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java @@ -64,7 +64,6 @@ class SparkBatchQueryScan extends SparkPartitioningAwareScan private final Long startSnapshotId; private final Long endSnapshotId; private final Long asOfTimestamp; - private final String branch; private final String tag; private final List runtimeFilterExpressions; @@ -82,7 +81,6 @@ class SparkBatchQueryScan extends SparkPartitioningAwareScan this.startSnapshotId = readConf.startSnapshotId(); this.endSnapshotId = readConf.endSnapshotId(); this.asOfTimestamp = readConf.asOfTimestamp(); - this.branch = readConf.branch(); this.tag = readConf.tag(); this.runtimeFilterExpressions = Lists.newArrayList(); } @@ -194,8 +192,8 @@ public Statistics estimateStatistics() { Snapshot snapshot = table().snapshot(snapshotIdAsOfTime); return estimateStatistics(snapshot); - } else if (branch != null) { - Snapshot snapshot = table().snapshot(branch); + } else if (branch() != null) { + Snapshot snapshot = table().snapshot(branch()); return estimateStatistics(snapshot); } else if (tag != null) { @@ -221,6 +219,7 @@ public boolean equals(Object o) { SparkBatchQueryScan that = 
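SparkBatch picks the branch up from SparkReadConf and stamps it onto every planned partition, so the executor-side readers above can resolve the branch schema from the broadcast table without another catalog round trip. Condensed from planInputPartitions in the hunk above (the surrounding locals belong to SparkBatch):

new SparkInputPartition(
    groupingKeyType,
    taskGroups.get(index),
    tableBroadcast,
    branch,                // null when reading main; readers pass it to SnapshotUtil.schemaFor
    expectedSchemaString,
    caseSensitive,
    localityEnabled);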
(SparkBatchQueryScan) o; return table().name().equals(that.table().name()) + && Objects.equals(branch(), that.branch()) && readSchema().equals(that.readSchema()) // compare Spark schemas to ignore field ids && filterExpressions().toString().equals(that.filterExpressions().toString()) && runtimeFilterExpressions.toString().equals(that.runtimeFilterExpressions.toString()) @@ -228,7 +227,6 @@ && filterExpressions().toString().equals(that.filterExpressions().toString()) && Objects.equals(startSnapshotId, that.startSnapshotId) && Objects.equals(endSnapshotId, that.endSnapshotId) && Objects.equals(asOfTimestamp, that.asOfTimestamp) - && Objects.equals(branch, that.branch) && Objects.equals(tag, that.tag); } @@ -236,6 +234,7 @@ && filterExpressions().toString().equals(that.filterExpressions().toString()) public int hashCode() { return Objects.hash( table().name(), + branch(), readSchema(), filterExpressions().toString(), runtimeFilterExpressions.toString(), @@ -243,15 +242,15 @@ public int hashCode() { startSnapshotId, endSnapshotId, asOfTimestamp, - branch, tag); } @Override public String toString() { return String.format( - "IcebergScan(table=%s, type=%s, filters=%s, runtimeFilters=%s, caseSensitive=%s)", + "IcebergScan(table=%s, branch=%s, type=%s, filters=%s, runtimeFilters=%s, caseSensitive=%s)", table(), + branch(), expectedSchema().asStruct(), filterExpressions(), runtimeFilterExpressions, diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteOperation.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteOperation.java index 68c99440441d..4fca05345a2e 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteOperation.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteOperation.java @@ -39,6 +39,7 @@ class SparkCopyOnWriteOperation implements RowLevelOperation { private final SparkSession spark; private final Table table; + private final String branch; private final Command command; private final IsolationLevel isolationLevel; @@ -48,9 +49,14 @@ class SparkCopyOnWriteOperation implements RowLevelOperation { private WriteBuilder lazyWriteBuilder; SparkCopyOnWriteOperation( - SparkSession spark, Table table, RowLevelOperationInfo info, IsolationLevel isolationLevel) { + SparkSession spark, + Table table, + String branch, + RowLevelOperationInfo info, + IsolationLevel isolationLevel) { this.spark = spark; this.table = table; + this.branch = branch; this.command = info.command(); this.isolationLevel = isolationLevel; } @@ -64,7 +70,7 @@ public Command command() { public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { if (lazyScanBuilder == null) { lazyScanBuilder = - new SparkScanBuilder(spark, table, options) { + new SparkScanBuilder(spark, table, branch, options) { @Override public Scan build() { Scan scan = super.buildCopyOnWriteScan(); @@ -80,7 +86,7 @@ public Scan build() { @Override public WriteBuilder newWriteBuilder(LogicalWriteInfo info) { if (lazyWriteBuilder == null) { - SparkWriteBuilder writeBuilder = new SparkWriteBuilder(spark, table, info); + SparkWriteBuilder writeBuilder = new SparkWriteBuilder(spark, table, branch, info); lazyWriteBuilder = writeBuilder.overwriteFiles(configuredScan, command, isolationLevel); } diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java index 
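Folding branch() into equals(), hashCode(), and toString() of SparkBatchQueryScan is what keeps Spark's scan and exchange reuse safe: two scans that differ only in branch must not compare equal, or cached output from one branch could be served for the other. A hedged check, where scanFor is a hypothetical test helper that builds a batch query scan through SparkScanBuilder for a given branch:

SparkBatchQueryScan mainScan = scanFor(table, "main");
SparkBatchQueryScan auditScan = scanFor(table, "audit");

// same table, schema, and filters, different branches -> must not be treated as the same scan
assert !mainScan.equals(auditScan);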
9a411c213484..d978b81e67bd 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkCopyOnWriteScan.java @@ -33,6 +33,7 @@ import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.Sets; import org.apache.iceberg.spark.SparkReadConf; +import org.apache.iceberg.util.SnapshotUtil; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.connector.expressions.Expressions; import org.apache.spark.sql.connector.expressions.NamedReference; @@ -179,7 +180,7 @@ public String toString() { } private Long currentSnapshotId() { - Snapshot currentSnapshot = table().currentSnapshot(); + Snapshot currentSnapshot = SnapshotUtil.latestSnapshot(table(), branch()); return currentSnapshot != null ? currentSnapshot.snapshotId() : null; } } diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkInputPartition.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkInputPartition.java index e4b5f63bd96d..0394b691e152 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkInputPartition.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkInputPartition.java @@ -36,6 +36,7 @@ class SparkInputPartition implements InputPartition, HasPartitionKey, Serializab private final Types.StructType groupingKeyType; private final ScanTaskGroup taskGroup; private final Broadcast tableBroadcast; + private final String branch; private final String expectedSchemaString; private final boolean caseSensitive; @@ -46,12 +47,14 @@ class SparkInputPartition implements InputPartition, HasPartitionKey, Serializab Types.StructType groupingKeyType, ScanTaskGroup taskGroup, Broadcast
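SparkCopyOnWriteScan pins its conflict validation to the snapshot it scanned; with branches, the relevant head is the target branch rather than main, hence the switch to SnapshotUtil.latestSnapshot in the hunk above. A minimal sketch of that resolution:

import org.apache.iceberg.Snapshot;
import org.apache.iceberg.Table;
import org.apache.iceberg.util.SnapshotUtil;

// The snapshot id a copy-on-write operation validates against: the branch head when a
// branch is set, the current (main) snapshot otherwise, or null if none exists yet.
static Long validationSnapshotId(Table table, String branch) {
  Snapshot latest = SnapshotUtil.latestSnapshot(table, branch);
  return latest != null ? latest.snapshotId() : null;
}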
tableBroadcast, + String branch, String expectedSchemaString, boolean caseSensitive, boolean localityPreferred) { this.groupingKeyType = groupingKeyType; this.taskGroup = taskGroup; this.tableBroadcast = tableBroadcast; + this.branch = branch; this.expectedSchemaString = expectedSchemaString; this.caseSensitive = caseSensitive; if (localityPreferred) { @@ -85,6 +88,10 @@ public Table table() { return tableBroadcast.value(); } + public String branch() { + return branch; + } + public boolean isCaseSensitive() { return caseSensitive; } diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java index aacacdd752d8..317a7863eff7 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkMicroBatchStream.java @@ -68,6 +68,7 @@ public class SparkMicroBatchStream implements MicroBatchStream { private static final Types.StructType EMPTY_GROUPING_KEY_TYPE = Types.StructType.of(); private final Table table; + private final String branch; private final boolean caseSensitive; private final String expectedSchema; private final Broadcast
tableBroadcast; @@ -87,6 +88,7 @@ public class SparkMicroBatchStream implements MicroBatchStream { Schema expectedSchema, String checkpointLocation) { this.table = table; + this.branch = readConf.branch(); this.caseSensitive = readConf.caseSensitive(); this.expectedSchema = SchemaParser.toJson(expectedSchema); this.localityPreferred = readConf.localityEnabled(); @@ -164,6 +166,7 @@ public InputPartition[] planInputPartitions(Offset start, Offset end) { EMPTY_GROUPING_KEY_TYPE, combinedScanTasks.get(index), tableBroadcast, + branch, expectedSchema, caseSensitive, localityPreferred)); diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitioningAwareScan.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitioningAwareScan.java index a164c0e13c55..4c7a02543abe 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitioningAwareScan.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPartitioningAwareScan.java @@ -44,6 +44,7 @@ import org.apache.iceberg.spark.Spark3Util; import org.apache.iceberg.spark.SparkReadConf; import org.apache.iceberg.types.Types.StructType; +import org.apache.iceberg.util.SnapshotUtil; import org.apache.iceberg.util.StructLikeSet; import org.apache.iceberg.util.TableScanUtil; import org.apache.spark.sql.SparkSession; @@ -140,7 +141,8 @@ private Transform[] groupingKeyTransforms() { .map(field -> fieldsById.get(field.fieldId())) .collect(Collectors.toList()); - this.groupingKeyTransforms = Spark3Util.toTransforms(table().schema(), groupingKeyFields); + Schema schema = SnapshotUtil.schemaFor(table(), branch()); + this.groupingKeyTransforms = Spark3Util.toTransforms(schema, groupingKeyFields); } return groupingKeyTransforms; diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaOperation.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaOperation.java index 72948dedb2bf..9f2647df9d8d 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaOperation.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaOperation.java @@ -39,6 +39,7 @@ class SparkPositionDeltaOperation implements RowLevelOperation, SupportsDelta { private final SparkSession spark; private final Table table; + private final String branch; private final Command command; private final IsolationLevel isolationLevel; @@ -48,9 +49,14 @@ class SparkPositionDeltaOperation implements RowLevelOperation, SupportsDelta { private DeltaWriteBuilder lazyWriteBuilder; SparkPositionDeltaOperation( - SparkSession spark, Table table, RowLevelOperationInfo info, IsolationLevel isolationLevel) { + SparkSession spark, + Table table, + String branch, + RowLevelOperationInfo info, + IsolationLevel isolationLevel) { this.spark = spark; this.table = table; + this.branch = branch; this.command = info.command(); this.isolationLevel = isolationLevel; } @@ -64,7 +70,7 @@ public Command command() { public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { if (lazyScanBuilder == null) { this.lazyScanBuilder = - new SparkScanBuilder(spark, table, options) { + new SparkScanBuilder(spark, table, branch, options) { @Override public Scan build() { Scan scan = super.buildMergeOnReadScan(); @@ -88,6 +94,7 @@ public DeltaWriteBuilder newWriteBuilder(LogicalWriteInfo info) { new SparkPositionDeltaWriteBuilder( spark, table, + branch, command, configuredScan, 
isolationLevel, diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWrite.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWrite.java index 5eba7166c98b..ce4b248e0f54 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWrite.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWrite.java @@ -97,6 +97,7 @@ class SparkPositionDeltaWrite implements DeltaWrite, RequiresDistributionAndOrde private final String applicationId; private final boolean wapEnabled; private final String wapId; + private final String branch; private final Map extraSnapshotMetadata; private final Distribution requiredDistribution; private final SortOrder[] requiredOrdering; @@ -123,6 +124,7 @@ class SparkPositionDeltaWrite implements DeltaWrite, RequiresDistributionAndOrde this.applicationId = spark.sparkContext().applicationId(); this.wapEnabled = writeConf.wapEnabled(); this.wapId = writeConf.wapId(); + this.branch = writeConf.branch(); this.extraSnapshotMetadata = writeConf.extraSnapshotMetadata(); this.requiredDistribution = requiredDistribution; this.requiredOrdering = requiredOrdering; @@ -273,6 +275,10 @@ private void commitOperation(SnapshotUpdate operation, String description) { operation.stageOnly(); } + if (branch != null) { + operation.toBranch(branch); + } + try { long start = System.currentTimeMillis(); operation.commit(); // abort is automatically called if this fails diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWriteBuilder.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWriteBuilder.java index ebac7e2515cc..b7ec6734dca1 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWriteBuilder.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkPositionDeltaWriteBuilder.java @@ -59,6 +59,7 @@ class SparkPositionDeltaWriteBuilder implements DeltaWriteBuilder { SparkPositionDeltaWriteBuilder( SparkSession spark, Table table, + String branch, Command command, Scan scan, IsolationLevel isolationLevel, @@ -68,7 +69,7 @@ class SparkPositionDeltaWriteBuilder implements DeltaWriteBuilder { this.command = command; this.scan = (SparkBatchQueryScan) scan; this.isolationLevel = isolationLevel; - this.writeConf = new SparkWriteConf(spark, table, info.options()); + this.writeConf = new SparkWriteConf(spark, table, branch, info.options()); this.info = info; this.handleTimestampWithoutZone = writeConf.handleTimestampWithoutZone(); this.checkNullability = writeConf.checkNullability(); diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkRowLevelOperationBuilder.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkRowLevelOperationBuilder.java index 0673d647703c..b113bd9b25af 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkRowLevelOperationBuilder.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkRowLevelOperationBuilder.java @@ -45,13 +45,16 @@ class SparkRowLevelOperationBuilder implements RowLevelOperationBuilder { private final SparkSession spark; private final Table table; + private final String branch; private final RowLevelOperationInfo info; private final RowLevelOperationMode mode; private final IsolationLevel isolationLevel; - SparkRowLevelOperationBuilder(SparkSession spark, Table table, 
RowLevelOperationInfo info) { + SparkRowLevelOperationBuilder( + SparkSession spark, Table table, String branch, RowLevelOperationInfo info) { this.spark = spark; this.table = table; + this.branch = branch; this.info = info; this.mode = mode(table.properties(), info.command()); this.isolationLevel = isolationLevel(table.properties(), info.command()); @@ -61,9 +64,9 @@ class SparkRowLevelOperationBuilder implements RowLevelOperationBuilder { public RowLevelOperation build() { switch (mode) { case COPY_ON_WRITE: - return new SparkCopyOnWriteOperation(spark, table, info, isolationLevel); + return new SparkCopyOnWriteOperation(spark, table, branch, info, isolationLevel); case MERGE_ON_READ: - return new SparkPositionDeltaOperation(spark, table, info, isolationLevel); + return new SparkPositionDeltaOperation(spark, table, branch, info, isolationLevel); default: throw new IllegalArgumentException("Unsupported operation mode: " + mode); } diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScan.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScan.java index 239b1642abd0..06fc4a07a0eb 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScan.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScan.java @@ -37,6 +37,7 @@ import org.apache.iceberg.spark.source.metrics.NumSplits; import org.apache.iceberg.types.Types; import org.apache.iceberg.util.PropertyUtil; +import org.apache.iceberg.util.SnapshotUtil; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.connector.metric.CustomMetric; @@ -59,6 +60,7 @@ abstract class SparkScan implements Scan, SupportsReportStatistics { private final Schema expectedSchema; private final List filterExpressions; private final boolean readTimestampWithoutZone; + private final String branch; // lazy variables private StructType readSchema; @@ -69,8 +71,8 @@ abstract class SparkScan implements Scan, SupportsReportStatistics { SparkReadConf readConf, Schema expectedSchema, List filters) { - - SparkSchemaUtil.validateMetadataColumnReferences(table.schema(), expectedSchema); + Schema snapshotSchema = SnapshotUtil.schemaFor(table, readConf.branch()); + SparkSchemaUtil.validateMetadataColumnReferences(snapshotSchema, expectedSchema); this.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext()); this.table = table; @@ -79,12 +81,17 @@ abstract class SparkScan implements Scan, SupportsReportStatistics { this.expectedSchema = expectedSchema; this.filterExpressions = filters != null ? 
filters : Collections.emptyList(); this.readTimestampWithoutZone = readConf.handleTimestampWithoutZone(); + this.branch = readConf.branch(); } protected Table table() { return table; } + protected String branch() { + return branch; + } + protected boolean caseSensitive() { return caseSensitive; } @@ -128,7 +135,7 @@ public StructType readSchema() { @Override public Statistics estimateStatistics() { - return estimateStatistics(table.currentSnapshot()); + return estimateStatistics(SnapshotUtil.latestSnapshot(table, branch)); } protected Statistics estimateStatistics(Snapshot snapshot) { @@ -166,8 +173,8 @@ public String description() { .collect(Collectors.joining(", ")); return String.format( - "%s [filters=%s, groupedBy=%s]", - table(), Spark3Util.describe(filterExpressions), groupingKeyFieldNamesAsString); + "%s (branch=%s) [filters=%s, groupedBy=%s]", + table(), branch(), Spark3Util.describe(filterExpressions), groupingKeyFieldNamesAsString); } @Override diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java index ee1d86531f00..23cd8524b3c8 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java @@ -98,12 +98,16 @@ public class SparkScanBuilder private Filter[] pushedFilters = NO_FILTERS; SparkScanBuilder( - SparkSession spark, Table table, Schema schema, CaseInsensitiveStringMap options) { + SparkSession spark, + Table table, + String branch, + Schema schema, + CaseInsensitiveStringMap options) { this.spark = spark; this.table = table; this.schema = schema; this.options = options; - this.readConf = new SparkReadConf(spark, table, options); + this.readConf = new SparkReadConf(spark, table, branch, options); this.caseSensitive = readConf.caseSensitive(); } @@ -111,6 +115,16 @@ public class SparkScanBuilder this(spark, table, table.schema(), options); } + SparkScanBuilder( + SparkSession spark, Table table, String branch, CaseInsensitiveStringMap options) { + this(spark, table, branch, SnapshotUtil.schemaFor(table, branch), options); + } + + SparkScanBuilder( + SparkSession spark, Table table, Schema schema, CaseInsensitiveStringMap options) { + this(spark, table, null, schema, options); + } + private Expression filterExpression() { if (filterExpressions != null) { return filterExpressions.stream().reduce(Expressions.alwaysTrue(), Expressions::and); @@ -273,7 +287,7 @@ private Snapshot readSnapshot() { if (readConf.snapshotId() != null) { snapshot = table.snapshot(readConf.snapshotId()); } else { - snapshot = table.currentSnapshot(); + snapshot = SnapshotUtil.latestSnapshot(table, readConf.branch()); } return snapshot; @@ -538,14 +552,10 @@ private Long getStartSnapshotId(Long startTimestamp) { public Scan buildMergeOnReadScan() { Preconditions.checkArgument( - readConf.snapshotId() == null - && readConf.asOfTimestamp() == null - && readConf.branch() == null - && readConf.tag() == null, - "Cannot set time travel options %s, %s, %s and %s for row-level command scans", + readConf.snapshotId() == null && readConf.asOfTimestamp() == null && readConf.tag() == null, + "Cannot set time travel options %s, %s, %s for row-level command scans", SparkReadOptions.SNAPSHOT_ID, SparkReadOptions.AS_OF_TIMESTAMP, - SparkReadOptions.BRANCH, SparkReadOptions.TAG); Preconditions.checkArgument( @@ -554,7 +564,7 @@ public Scan buildMergeOnReadScan() { 
SparkReadOptions.START_SNAPSHOT_ID, SparkReadOptions.END_SNAPSHOT_ID); - Snapshot snapshot = table.currentSnapshot(); + Snapshot snapshot = SnapshotUtil.latestSnapshot(table, readConf.branch()); if (snapshot == null) { return new SparkBatchQueryScan( @@ -566,7 +576,8 @@ public Scan buildMergeOnReadScan() { CaseInsensitiveStringMap adjustedOptions = Spark3Util.setOption(SparkReadOptions.SNAPSHOT_ID, Long.toString(snapshotId), options); - SparkReadConf adjustedReadConf = new SparkReadConf(spark, table, adjustedOptions); + SparkReadConf adjustedReadConf = + new SparkReadConf(spark, table, readConf.branch(), adjustedOptions); Schema expectedSchema = schemaWithMetadataColumns(); @@ -585,7 +596,7 @@ public Scan buildMergeOnReadScan() { } public Scan buildCopyOnWriteScan() { - Snapshot snapshot = table.currentSnapshot(); + Snapshot snapshot = SnapshotUtil.latestSnapshot(table, readConf.branch()); if (snapshot == null) { return new SparkCopyOnWriteScan( diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkTable.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkTable.java index c5e367de03bb..8c418d51b9a0 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkTable.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkTable.java @@ -27,15 +27,18 @@ import org.apache.iceberg.BaseMetadataTable; import org.apache.iceberg.BaseTable; import org.apache.iceberg.DataFile; +import org.apache.iceberg.DeleteFiles; import org.apache.iceberg.FileScanTask; import org.apache.iceberg.MetadataColumns; import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.Partitioning; import org.apache.iceberg.Schema; +import org.apache.iceberg.SnapshotRef; import org.apache.iceberg.Table; import org.apache.iceberg.TableOperations; import org.apache.iceberg.TableProperties; import org.apache.iceberg.TableScan; +import org.apache.iceberg.exceptions.ValidationException; import org.apache.iceberg.expressions.Evaluator; import org.apache.iceberg.expressions.Expression; import org.apache.iceberg.expressions.ExpressionUtil; @@ -114,11 +117,23 @@ public class SparkTable private final Long snapshotId; private final boolean refreshEagerly; private final Set capabilities; + private String branch; private StructType lazyTableSchema = null; private SparkSession lazySpark = null; public SparkTable(Table icebergTable, boolean refreshEagerly) { - this(icebergTable, null, refreshEagerly); + this(icebergTable, (Long) null, refreshEagerly); + } + + public SparkTable(Table icebergTable, String branch, boolean refreshEagerly) { + this(icebergTable, refreshEagerly); + this.branch = branch; + ValidationException.check( + branch == null + || SnapshotRef.MAIN_BRANCH.equals(branch) + || icebergTable.snapshot(branch) != null, + "Cannot use branch (does not exist): %s", + branch); } public SparkTable(Table icebergTable, Long snapshotId, boolean refreshEagerly) { @@ -159,9 +174,15 @@ public SparkTable copyWithSnapshotId(long newSnapshotId) { return new SparkTable(icebergTable, newSnapshotId, refreshEagerly); } + public SparkTable copyWithBranch(String targetBranch) { + return new SparkTable(icebergTable, targetBranch, refreshEagerly); + } + private Schema snapshotSchema() { if (icebergTable instanceof BaseMetadataTable) { return icebergTable.schema(); + } else if (branch != null) { + return SnapshotUtil.schemaFor(icebergTable, branch); } else { return SnapshotUtil.schemaFor(icebergTable, snapshotId, null); } @@ -248,8 +269,10 @@ public 
ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { icebergTable.refresh(); } - CaseInsensitiveStringMap scanOptions = addSnapshotId(options, snapshotId); - return new SparkScanBuilder(sparkSession(), icebergTable, snapshotSchema(), scanOptions); + CaseInsensitiveStringMap scanOptions = + branch != null ? options : addSnapshotId(options, snapshotId); + return new SparkScanBuilder( + sparkSession(), icebergTable, branch, snapshotSchema(), scanOptions); } @Override @@ -257,12 +280,12 @@ public WriteBuilder newWriteBuilder(LogicalWriteInfo info) { Preconditions.checkArgument( snapshotId == null, "Cannot write to table at a specific snapshot: %s", snapshotId); - return new SparkWriteBuilder(sparkSession(), icebergTable, info); + return new SparkWriteBuilder(sparkSession(), icebergTable, branch, info); } @Override public RowLevelOperationBuilder newRowLevelOperationBuilder(RowLevelOperationInfo info) { - return new SparkRowLevelOperationBuilder(sparkSession(), icebergTable, info); + return new SparkRowLevelOperationBuilder(sparkSession(), icebergTable, branch, info); } @Override @@ -300,10 +323,14 @@ private boolean canDeleteUsingMetadata(Expression deleteExpr) { .includeColumnStats() .ignoreResiduals(); + if (branch != null) { + scan.useRef(branch); + } + try (CloseableIterable tasks = scan.planFiles()) { Map evaluators = Maps.newHashMap(); StrictMetricsEvaluator metricsEvaluator = - new StrictMetricsEvaluator(table().schema(), deleteExpr); + new StrictMetricsEvaluator(SnapshotUtil.schemaFor(table(), branch), deleteExpr); return Iterables.all( tasks, @@ -334,11 +361,17 @@ public void deleteWhere(Filter[] filters) { return; } - icebergTable - .newDelete() - .set("spark.app.id", sparkSession().sparkContext().applicationId()) - .deleteFromRowFilter(deleteExpr) - .commit(); + DeleteFiles deleteFiles = + icebergTable + .newDelete() + .set("spark.app.id", sparkSession().sparkContext().applicationId()) + .deleteFromRowFilter(deleteExpr); + + if (branch != null) { + deleteFiles.toBranch(branch); + } + + deleteFiles.commit(); } @Override diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java index f68898e27b3d..9bcbbde8b703 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWrite.java @@ -91,6 +91,7 @@ abstract class SparkWrite implements Write, RequiresDistributionAndOrdering { private final String applicationId; private final boolean wapEnabled; private final String wapId; + private final String branch; private final long targetFileSize; private final Schema writeSchema; private final StructType dsSchema; @@ -119,6 +120,7 @@ abstract class SparkWrite implements Write, RequiresDistributionAndOrdering { this.applicationId = applicationId; this.wapEnabled = writeConf.wapEnabled(); this.wapId = writeConf.wapId(); + this.branch = writeConf.branch(); this.targetFileSize = writeConf.targetDataFileSize(); this.writeSchema = writeSchema; this.dsSchema = dsSchema; @@ -202,6 +204,10 @@ private void commitOperation(SnapshotUpdate operation, String description) { operation.stageOnly(); } + if (branch != null) { + operation.toBranch(branch); + } + try { long start = System.currentTimeMillis(); operation.commit(); // abort is automatically called if this fails diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java 
b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java index 55cf7961e92f..133ca45b4603 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkWriteBuilder.java @@ -73,10 +73,10 @@ class SparkWriteBuilder implements WriteBuilder, SupportsDynamicOverwrite, Suppo private Command copyOnWriteCommand = null; private IsolationLevel copyOnWriteIsolationLevel = null; - SparkWriteBuilder(SparkSession spark, Table table, LogicalWriteInfo info) { + SparkWriteBuilder(SparkSession spark, Table table, String branch, LogicalWriteInfo info) { this.spark = spark; this.table = table; - this.writeConf = new SparkWriteConf(spark, table, info.options()); + this.writeConf = new SparkWriteConf(spark, table, branch, info.options()); this.writeInfo = info; this.dsSchema = info.schema(); this.overwriteMode = writeConf.overwriteMode(); diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/SparkTestBaseWithCatalog.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/SparkTestBaseWithCatalog.java index e32aeea64d4d..00a9339cb743 100644 --- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/SparkTestBaseWithCatalog.java +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/SparkTestBaseWithCatalog.java @@ -94,4 +94,12 @@ public SparkTestBaseWithCatalog( protected String tableName(String name) { return (catalogName.equals("spark_catalog") ? "" : catalogName + ".") + "default." + name; } + + protected String commitTarget() { + return tableName; + } + + protected String selectTarget() { + return tableName; + } } diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java index 59074bbd923b..35d16d6f8588 100644 --- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java @@ -50,6 +50,7 @@ import org.apache.iceberg.ManifestFile; import org.apache.iceberg.Schema; import org.apache.iceberg.Table; +import org.apache.iceberg.TableScan; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.collect.Sets; import org.apache.iceberg.relocated.com.google.common.collect.Streams; @@ -797,9 +798,17 @@ public static List deleteManifests(Table table) { } public static Set dataFiles(Table table) { + return dataFiles(table, null); + } + + public static Set dataFiles(Table table, String branch) { Set dataFiles = Sets.newHashSet(); + TableScan scan = table.newScan(); + if (branch != null) { + scan.useRef(branch); + } - for (FileScanTask task : table.newScan().planFiles()) { + for (FileScanTask task : scan.planFiles()) { dataFiles.add(task.file()); } diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestBaseReader.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestBaseReader.java index cbcee867803f..3d94966eb76c 100644 --- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestBaseReader.java +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestBaseReader.java @@ -90,7 +90,7 @@ private static class ClosureTrackingReader extends BaseReader tracker = Maps.newHashMap(); ClosureTrackingReader(Table table, List tasks) { - super(table, new BaseCombinedScanTask(tasks), null, false); + super(table, 
new BaseCombinedScanTask(tasks), null, null, false); } @Override diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestChangelogReader.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestChangelogReader.java index 3fd8718aedab..fc17547fad41 100644 --- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestChangelogReader.java +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestChangelogReader.java @@ -103,7 +103,8 @@ public void testInsert() throws IOException { List rows = Lists.newArrayList(); for (ScanTaskGroup taskGroup : taskGroups) { - ChangelogRowReader reader = new ChangelogRowReader(table, taskGroup, table.schema(), false); + ChangelogRowReader reader = + new ChangelogRowReader(table, taskGroup, table.schema(), table.schema(), false); while (reader.next()) { rows.add(reader.get().copy()); } @@ -133,7 +134,8 @@ public void testDelete() throws IOException { List rows = Lists.newArrayList(); for (ScanTaskGroup taskGroup : taskGroups) { - ChangelogRowReader reader = new ChangelogRowReader(table, taskGroup, table.schema(), false); + ChangelogRowReader reader = + new ChangelogRowReader(table, taskGroup, table.schema(), table.schema(), false); while (reader.next()) { rows.add(reader.get().copy()); } @@ -166,7 +168,8 @@ public void testDataFileRewrite() throws IOException { List rows = Lists.newArrayList(); for (ScanTaskGroup taskGroup : taskGroups) { - ChangelogRowReader reader = new ChangelogRowReader(table, taskGroup, table.schema(), false); + ChangelogRowReader reader = + new ChangelogRowReader(table, taskGroup, table.schema(), table.schema(), false); while (reader.next()) { rows.add(reader.get().copy()); } @@ -192,7 +195,8 @@ public void testMixDeleteAndInsert() throws IOException { List rows = Lists.newArrayList(); for (ScanTaskGroup taskGroup : taskGroups) { - ChangelogRowReader reader = new ChangelogRowReader(table, taskGroup, table.schema(), false); + ChangelogRowReader reader = + new ChangelogRowReader(table, taskGroup, table.schema(), table.schema(), false); while (reader.next()) { rows.add(reader.get().copy()); } diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataWrite.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataWrite.java index dac1c150cdb6..551dc961e309 100644 --- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataWrite.java +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkDataWrite.java @@ -37,6 +37,7 @@ import org.apache.iceberg.ManifestFiles; import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.Schema; +import org.apache.iceberg.SnapshotRef; import org.apache.iceberg.Table; import org.apache.iceberg.TableProperties; import org.apache.iceberg.exceptions.CommitStateUnknownException; @@ -46,12 +47,14 @@ import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.spark.SparkWriteOptions; import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.SnapshotUtil; import org.apache.spark.SparkException; import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Encoders; import org.apache.spark.sql.Row; import org.apache.spark.sql.SaveMode; import org.apache.spark.sql.SparkSession; +import org.assertj.core.api.Assertions; import org.junit.AfterClass; import org.junit.Assert; import org.junit.Assume; @@ -66,6 +69,7 @@ public class TestSparkDataWrite { private static final Configuration CONF = new 
Configuration(); private final FileFormat format; + private final String branch; private static SparkSession spark = null; private static final Schema SCHEMA = new Schema( @@ -73,9 +77,15 @@ public class TestSparkDataWrite { @Rule public TemporaryFolder temp = new TemporaryFolder(); - @Parameterized.Parameters(name = "format = {0}") + @Parameterized.Parameters(name = "format = {0}, branch = {1}") public static Object[] parameters() { - return new Object[] {"parquet", "avro", "orc"}; + return new Object[] { + new Object[] {"parquet", null}, + new Object[] {"parquet", "main"}, + new Object[] {"parquet", "testBranch"}, + new Object[] {"avro", null}, + new Object[] {"orc", "testBranch"} + }; } @BeforeClass @@ -95,14 +105,16 @@ public static void stopSpark() { currentSpark.stop(); } - public TestSparkDataWrite(String format) { + public TestSparkDataWrite(String format, String branch) { this.format = FileFormat.fromString(format); + this.branch = branch; } @Test public void testBasicWrite() throws IOException { File parent = temp.newFolder(format.toString()); File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); HadoopTables tables = new HadoopTables(CONF); PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); @@ -121,15 +133,17 @@ public void testBasicWrite() throws IOException { .mode(SaveMode.Append) .save(location.toString()); + createBranch(table); table.refresh(); - Dataset result = spark.read().format("iceberg").load(location.toString()); + Dataset result = spark.read().format("iceberg").load(targetLocation); List actual = result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); Assert.assertEquals("Number of rows should match", expected.size(), actual.size()); Assert.assertEquals("Result rows should match", expected, actual); - for (ManifestFile manifest : table.currentSnapshot().allManifests(table.io())) { + for (ManifestFile manifest : + SnapshotUtil.latestSnapshot(table, branch).allManifests(table.io())) { for (DataFile file : ManifestFiles.read(manifest, table.io())) { // TODO: avro not support split if (!format.equals(FileFormat.AVRO)) { @@ -152,6 +166,7 @@ public void testBasicWrite() throws IOException { public void testAppend() throws IOException { File parent = temp.newFolder(format.toString()); File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); HadoopTables tables = new HadoopTables(CONF); PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); @@ -179,17 +194,19 @@ public void testAppend() throws IOException { .mode(SaveMode.Append) .save(location.toString()); + createBranch(table); + df.withColumn("id", df.col("id").plus(3)) .select("id", "data") .write() .format("iceberg") .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) .mode(SaveMode.Append) - .save(location.toString()); + .save(targetLocation); table.refresh(); - Dataset result = spark.read().format("iceberg").load(location.toString()); + Dataset result = spark.read().format("iceberg").load(targetLocation); List actual = result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); @@ -201,6 +218,7 @@ public void testAppend() throws IOException { public void testEmptyOverwrite() throws IOException { File parent = temp.newFolder(format.toString()); File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); HadoopTables tables = new HadoopTables(CONF); PartitionSpec spec = 
PartitionSpec.builderFor(SCHEMA).identity("id").build(); @@ -220,6 +238,8 @@ public void testEmptyOverwrite() throws IOException { .mode(SaveMode.Append) .save(location.toString()); + createBranch(table); + Dataset empty = spark.createDataFrame(ImmutableList.of(), SimpleRecord.class); empty .select("id", "data") @@ -228,11 +248,11 @@ public void testEmptyOverwrite() throws IOException { .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) .mode(SaveMode.Overwrite) .option("overwrite-mode", "dynamic") - .save(location.toString()); + .save(targetLocation); table.refresh(); - Dataset result = spark.read().format("iceberg").load(location.toString()); + Dataset result = spark.read().format("iceberg").load(targetLocation); List actual = result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); @@ -244,6 +264,7 @@ public void testEmptyOverwrite() throws IOException { public void testOverwrite() throws IOException { File parent = temp.newFolder(format.toString()); File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); HadoopTables tables = new HadoopTables(CONF); PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("id").build(); @@ -270,6 +291,8 @@ public void testOverwrite() throws IOException { .mode(SaveMode.Append) .save(location.toString()); + createBranch(table); + // overwrite with 2*id to replace record 2, append 4 and 6 df.withColumn("id", df.col("id").multiply(2)) .select("id", "data") @@ -278,11 +301,11 @@ public void testOverwrite() throws IOException { .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) .mode(SaveMode.Overwrite) .option("overwrite-mode", "dynamic") - .save(location.toString()); + .save(targetLocation); table.refresh(); - Dataset result = spark.read().format("iceberg").load(location.toString()); + Dataset result = spark.read().format("iceberg").load(targetLocation); List actual = result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); @@ -294,6 +317,7 @@ public void testOverwrite() throws IOException { public void testUnpartitionedOverwrite() throws IOException { File parent = temp.newFolder(format.toString()); File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); HadoopTables tables = new HadoopTables(CONF); PartitionSpec spec = PartitionSpec.unpartitioned(); @@ -312,17 +336,19 @@ public void testUnpartitionedOverwrite() throws IOException { .mode(SaveMode.Append) .save(location.toString()); + createBranch(table); + // overwrite with the same data; should not produce two copies df.select("id", "data") .write() .format("iceberg") .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) .mode(SaveMode.Overwrite) - .save(location.toString()); + .save(targetLocation); table.refresh(); - Dataset result = spark.read().format("iceberg").load(location.toString()); + Dataset result = spark.read().format("iceberg").load(targetLocation); List actual = result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); @@ -334,6 +360,7 @@ public void testUnpartitionedOverwrite() throws IOException { public void testUnpartitionedCreateWithTargetFileSizeViaTableProperties() throws IOException { File parent = temp.newFolder(format.toString()); File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); HadoopTables tables = new HadoopTables(CONF); PartitionSpec spec = PartitionSpec.unpartitioned(); @@ -358,9 +385,10 @@ public void 
testUnpartitionedCreateWithTargetFileSizeViaTableProperties() throws .mode(SaveMode.Append) .save(location.toString()); + createBranch(table); table.refresh(); - Dataset result = spark.read().format("iceberg").load(location.toString()); + Dataset result = spark.read().format("iceberg").load(targetLocation); List actual = result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); @@ -368,7 +396,8 @@ public void testUnpartitionedCreateWithTargetFileSizeViaTableProperties() throws Assert.assertEquals("Result rows should match", expected, actual); List files = Lists.newArrayList(); - for (ManifestFile manifest : table.currentSnapshot().allManifests(table.io())) { + for (ManifestFile manifest : + SnapshotUtil.latestSnapshot(table, branch).allManifests(table.io())) { for (DataFile file : ManifestFiles.read(manifest, table.io())) { files.add(file); } @@ -402,6 +431,7 @@ public void testWriteProjection() throws IOException { File parent = temp.newFolder(format.toString()); File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); HadoopTables tables = new HadoopTables(CONF); PartitionSpec spec = PartitionSpec.unpartitioned(); @@ -420,9 +450,10 @@ public void testWriteProjection() throws IOException { .mode(SaveMode.Append) .save(location.toString()); + createBranch(table); table.refresh(); - Dataset result = spark.read().format("iceberg").load(location.toString()); + Dataset result = spark.read().format("iceberg").load(targetLocation); List actual = result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); @@ -438,6 +469,7 @@ public void testWriteProjectionWithMiddle() throws IOException { File parent = temp.newFolder(format.toString()); File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); HadoopTables tables = new HadoopTables(CONF); PartitionSpec spec = PartitionSpec.unpartitioned(); @@ -463,9 +495,10 @@ public void testWriteProjectionWithMiddle() throws IOException { .mode(SaveMode.Append) .save(location.toString()); + createBranch(table); table.refresh(); - Dataset result = spark.read().format("iceberg").load(location.toString()); + Dataset result = spark.read().format("iceberg").load(targetLocation); List actual = result.orderBy("c1").as(Encoders.bean(ThreeColumnRecord.class)).collectAsList(); @@ -477,6 +510,7 @@ public void testWriteProjectionWithMiddle() throws IOException { public void testViewsReturnRecentResults() throws IOException { File parent = temp.newFolder(format.toString()); File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); HadoopTables tables = new HadoopTables(CONF); PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); @@ -495,7 +529,10 @@ public void testViewsReturnRecentResults() throws IOException { .mode(SaveMode.Append) .save(location.toString()); - Dataset query = spark.read().format("iceberg").load(location.toString()).where("id = 1"); + Table table = tables.load(location.toString()); + createBranch(table); + + Dataset query = spark.read().format("iceberg").load(targetLocation).where("id = 1"); query.createOrReplaceTempView("tmp"); List actual1 = @@ -509,7 +546,7 @@ public void testViewsReturnRecentResults() throws IOException { .format("iceberg") .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) .mode(SaveMode.Append) - .save(location.toString()); + .save(targetLocation); List actual2 = spark.table("tmp").as(Encoders.bean(SimpleRecord.class)).collectAsList(); @@ -523,6 
+560,7 @@ public void partitionedCreateWithTargetFileSizeViaOption(IcebergOptionsType opti throws IOException { File parent = temp.newFolder(format.toString()); File location = new File(parent, "test"); + String targetLocation = locationWithBranch(location); HadoopTables tables = new HadoopTables(CONF); PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); @@ -576,9 +614,10 @@ public void partitionedCreateWithTargetFileSizeViaOption(IcebergOptionsType opti break; } + createBranch(table); table.refresh(); - Dataset result = spark.read().format("iceberg").load(location.toString()); + Dataset result = spark.read().format("iceberg").load(targetLocation); List actual = result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); @@ -586,7 +625,8 @@ public void partitionedCreateWithTargetFileSizeViaOption(IcebergOptionsType opti Assert.assertEquals("Result rows should match", expected, actual); List files = Lists.newArrayList(); - for (ManifestFile manifest : table.currentSnapshot().allManifests(table.io())) { + for (ManifestFile manifest : + SnapshotUtil.latestSnapshot(table, branch).allManifests(table.io())) { for (DataFile file : ManifestFiles.read(manifest, table.io())) { files.add(file); } @@ -601,6 +641,7 @@ public void partitionedCreateWithTargetFileSizeViaOption(IcebergOptionsType opti public void testCommitUnknownException() throws IOException { File parent = temp.newFolder(format.toString()); File location = new File(parent, "commitunknown"); + String targetLocation = locationWithBranch(location); HadoopTables tables = new HadoopTables(CONF); PartitionSpec spec = PartitionSpec.builderFor(SCHEMA).identity("data").build(); @@ -612,7 +653,27 @@ public void testCommitUnknownException() throws IOException { Dataset df = spark.createDataFrame(records, SimpleRecord.class); + df.select("id", "data") + .write() + .format("iceberg") + .option(SparkWriteOptions.WRITE_FORMAT, format.toString()) + .mode(SaveMode.Append) + .save(location.toString()); + + createBranch(table); + table.refresh(); + + List records2 = + Lists.newArrayList( + new SimpleRecord(4, "d"), new SimpleRecord(5, "e"), new SimpleRecord(6, "f")); + + Dataset df2 = spark.createDataFrame(records2, SimpleRecord.class); + AppendFiles append = table.newFastAppend(); + if (branch != null) { + append.toBranch(branch); + } + AppendFiles spyAppend = spy(append); doAnswer( invocation -> { @@ -637,20 +698,28 @@ public void testCommitUnknownException() throws IOException { CommitStateUnknownException.class, "Datacenter on Fire", () -> - df.select("id", "data") + df2.select("id", "data") .sort("data") .write() .format("org.apache.iceberg.spark.source.ManualSource") .option(ManualSource.TABLE_NAME, manualTableName) .mode(SaveMode.Append) - .save(location.toString())); + .save(targetLocation)); // Since write and commit succeeded, the rows should be readable - Dataset result = spark.read().format("iceberg").load(location.toString()); + Dataset result = spark.read().format("iceberg").load(targetLocation); List actual = result.orderBy("id").as(Encoders.bean(SimpleRecord.class)).collectAsList(); - Assert.assertEquals("Number of rows should match", records.size(), actual.size()); - Assert.assertEquals("Result rows should match", records, actual); + Assert.assertEquals( + "Number of rows should match", records.size() + records2.size(), actual.size()); + Assertions.assertThat(actual) + .describedAs("Result rows should match") + .containsExactlyInAnyOrder( + ImmutableList.builder() + .addAll(records) + 
.addAll(records2) + .build() + .toArray(new SimpleRecord[0])); } public enum IcebergOptionsType { @@ -658,4 +727,18 @@ public enum IcebergOptionsType { TABLE, JOB } + + private String locationWithBranch(File location) { + if (branch == null) { + return location.toString(); + } + + return location + "#branch_" + branch; + } + + private void createBranch(Table table) { + if (branch != null && !branch.equals(SnapshotRef.MAIN_BRANCH)) { + table.manageSnapshots().createBranch(branch, table.currentSnapshot().snapshotId()).commit(); + } + } } diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReaderDeletes.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReaderDeletes.java index d1d85790868e..cadcbad6aa76 100644 --- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReaderDeletes.java +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReaderDeletes.java @@ -311,7 +311,7 @@ public void testReadEqualityDeleteRows() throws IOException { for (CombinedScanTask task : tasks) { try (EqualityDeleteRowReader reader = - new EqualityDeleteRowReader(task, table, table.schema(), false)) { + new EqualityDeleteRowReader(task, table, null, table.schema(), false)) { while (reader.next()) { actualRowSet.add( new InternalRowWrapper(SparkSchemaUtil.convert(table.schema())) diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/PartitionedWritesTestBase.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/PartitionedWritesTestBase.java new file mode 100644 index 000000000000..c6775c2a0799 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/PartitionedWritesTestBase.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iceberg.spark.sql; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.spark.SparkCatalogTestBase; +import org.apache.iceberg.spark.source.SimpleRecord; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.functions; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public abstract class PartitionedWritesTestBase extends SparkCatalogTestBase { + public PartitionedWritesTestBase( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Before + public void createTables() { + sql( + "CREATE TABLE %s (id bigint, data string) USING iceberg PARTITIONED BY (truncate(id, 3))", + tableName); + sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName); + } + + @After + public void removeTables() { + sql("DROP TABLE IF EXISTS %s", tableName); + } + + @Test + public void testInsertAppend() { + Assert.assertEquals( + "Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", selectTarget())); + + sql("INSERT INTO %s VALUES (4, 'd'), (5, 'e')", commitTarget()); + + Assert.assertEquals( + "Should have 5 rows after insert", + 5L, + scalarSql("SELECT count(*) FROM %s", selectTarget())); + + List expected = + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c"), row(4L, "d"), row(5L, "e")); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testInsertOverwrite() { + Assert.assertEquals( + "Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", selectTarget())); + + // 4 and 5 replace 3 in the partition (id - (id % 3)) = 3 + sql("INSERT OVERWRITE %s VALUES (4, 'd'), (5, 'e')", commitTarget()); + + Assert.assertEquals( + "Should have 4 rows after overwrite", + 4L, + scalarSql("SELECT count(*) FROM %s", selectTarget())); + + List expected = + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(4L, "d"), row(5L, "e")); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testDataFrameV2Append() throws NoSuchTableException { + Assert.assertEquals( + "Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", selectTarget())); + + List data = ImmutableList.of(new SimpleRecord(4, "d"), new SimpleRecord(5, "e")); + Dataset ds = spark.createDataFrame(data, SimpleRecord.class); + + ds.writeTo(commitTarget()).append(); + + Assert.assertEquals( + "Should have 5 rows after insert", + 5L, + scalarSql("SELECT count(*) FROM %s", selectTarget())); + + List expected = + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c"), row(4L, "d"), row(5L, "e")); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testDataFrameV2DynamicOverwrite() throws NoSuchTableException { + Assert.assertEquals( + "Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", selectTarget())); + + List data = ImmutableList.of(new SimpleRecord(4, "d"), new SimpleRecord(5, "e")); + Dataset ds = spark.createDataFrame(data, SimpleRecord.class); + + ds.writeTo(commitTarget()).overwritePartitions(); + + Assert.assertEquals( + "Should have 4 rows after overwrite", + 4L, + 
scalarSql("SELECT count(*) FROM %s", selectTarget())); + + List expected = + ImmutableList.of(row(1L, "a"), row(2L, "b"), row(4L, "d"), row(5L, "e")); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testDataFrameV2Overwrite() throws NoSuchTableException { + Assert.assertEquals( + "Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", selectTarget())); + + List data = ImmutableList.of(new SimpleRecord(4, "d"), new SimpleRecord(5, "e")); + Dataset ds = spark.createDataFrame(data, SimpleRecord.class); + + ds.writeTo(commitTarget()).overwrite(functions.col("id").$less(3)); + + Assert.assertEquals( + "Should have 3 rows after overwrite", + 3L, + scalarSql("SELECT count(*) FROM %s", selectTarget())); + + List expected = ImmutableList.of(row(3L, "c"), row(4L, "d"), row(5L, "e")); + + assertEquals( + "Row data should match expected", + expected, + sql("SELECT * FROM %s ORDER BY id", selectTarget())); + } + + @Test + public void testViewsReturnRecentResults() { + Assert.assertEquals( + "Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", selectTarget())); + + Dataset query = spark.sql("SELECT * FROM " + commitTarget() + " WHERE id = 1"); + query.createOrReplaceTempView("tmp"); + + assertEquals( + "View should have expected rows", ImmutableList.of(row(1L, "a")), sql("SELECT * FROM tmp")); + + sql("INSERT INTO TABLE %s VALUES (1, 'a')", commitTarget()); + + assertEquals( + "View should have expected rows", + ImmutableList.of(row(1L, "a"), row(1L, "a")), + sql("SELECT * FROM tmp")); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWrites.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWrites.java index 51c56ac79d4d..a18bd997250b 100644 --- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWrites.java +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWrites.java @@ -18,143 +18,12 @@ */ package org.apache.iceberg.spark.sql; -import java.util.List; import java.util.Map; -import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; -import org.apache.iceberg.spark.SparkCatalogTestBase; -import org.apache.iceberg.spark.source.SimpleRecord; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; -import org.apache.spark.sql.functions; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -public class TestPartitionedWrites extends SparkCatalogTestBase { +public class TestPartitionedWrites extends PartitionedWritesTestBase { + public TestPartitionedWrites( String catalogName, String implementation, Map config) { super(catalogName, implementation, config); } - - @Before - public void createTables() { - sql( - "CREATE TABLE %s (id bigint, data string) USING iceberg PARTITIONED BY (truncate(id, 3))", - tableName); - sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName); - } - - @After - public void removeTables() { - sql("DROP TABLE IF EXISTS %s", tableName); - } - - @Test - public void testInsertAppend() { - Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName)); - - sql("INSERT INTO %s VALUES (4, 'd'), (5, 'e')", tableName); - - Assert.assertEquals( - "Should have 5 rows after insert", 5L, scalarSql("SELECT count(*) FROM %s", tableName)); - - List expected 
= - ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c"), row(4L, "d"), row(5L, "e")); - - assertEquals( - "Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName)); - } - - @Test - public void testInsertOverwrite() { - Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName)); - - // 4 and 5 replace 3 in the partition (id - (id % 3)) = 3 - sql("INSERT OVERWRITE %s VALUES (4, 'd'), (5, 'e')", tableName); - - Assert.assertEquals( - "Should have 4 rows after overwrite", 4L, scalarSql("SELECT count(*) FROM %s", tableName)); - - List expected = - ImmutableList.of(row(1L, "a"), row(2L, "b"), row(4L, "d"), row(5L, "e")); - - assertEquals( - "Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName)); - } - - @Test - public void testDataFrameV2Append() throws NoSuchTableException { - Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName)); - - List data = ImmutableList.of(new SimpleRecord(4, "d"), new SimpleRecord(5, "e")); - Dataset ds = spark.createDataFrame(data, SimpleRecord.class); - - ds.writeTo(tableName).append(); - - Assert.assertEquals( - "Should have 5 rows after insert", 5L, scalarSql("SELECT count(*) FROM %s", tableName)); - - List expected = - ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c"), row(4L, "d"), row(5L, "e")); - - assertEquals( - "Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName)); - } - - @Test - public void testDataFrameV2DynamicOverwrite() throws NoSuchTableException { - Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName)); - - List data = ImmutableList.of(new SimpleRecord(4, "d"), new SimpleRecord(5, "e")); - Dataset ds = spark.createDataFrame(data, SimpleRecord.class); - - ds.writeTo(tableName).overwritePartitions(); - - Assert.assertEquals( - "Should have 4 rows after overwrite", 4L, scalarSql("SELECT count(*) FROM %s", tableName)); - - List expected = - ImmutableList.of(row(1L, "a"), row(2L, "b"), row(4L, "d"), row(5L, "e")); - - assertEquals( - "Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName)); - } - - @Test - public void testDataFrameV2Overwrite() throws NoSuchTableException { - Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName)); - - List data = ImmutableList.of(new SimpleRecord(4, "d"), new SimpleRecord(5, "e")); - Dataset ds = spark.createDataFrame(data, SimpleRecord.class); - - ds.writeTo(tableName).overwrite(functions.col("id").$less(3)); - - Assert.assertEquals( - "Should have 3 rows after overwrite", 3L, scalarSql("SELECT count(*) FROM %s", tableName)); - - List expected = ImmutableList.of(row(3L, "c"), row(4L, "d"), row(5L, "e")); - - assertEquals( - "Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName)); - } - - @Test - public void testViewsReturnRecentResults() { - Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName)); - - Dataset query = spark.sql("SELECT * FROM " + tableName + " WHERE id = 1"); - query.createOrReplaceTempView("tmp"); - - assertEquals( - "View should have expected rows", ImmutableList.of(row(1L, "a")), sql("SELECT * FROM tmp")); - - sql("INSERT INTO TABLE %s VALUES (1, 'a')", tableName); - - assertEquals( - "View should have expected rows", - ImmutableList.of(row(1L, "a"), row(1L, "a")), - sql("SELECT * FROM tmp")); - } } diff --git 
a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesToBranch.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesToBranch.java new file mode 100644 index 000000000000..c6cde7a5524e --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestPartitionedWritesToBranch.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iceberg.spark.sql; + +import java.util.Map; +import org.apache.iceberg.Table; +import org.junit.Before; + +public class TestPartitionedWritesToBranch extends PartitionedWritesTestBase { + + private static final String BRANCH = "test"; + + public TestPartitionedWritesToBranch( + String catalogName, String implementation, Map config) { + super(catalogName, implementation, config); + } + + @Before + @Override + public void createTables() { + super.createTables(); + Table table = validationCatalog.loadTable(tableIdent); + table.manageSnapshots().createBranch(BRANCH, table.currentSnapshot().snapshotId()).commit(); + sql("REFRESH TABLE " + tableName); + } + + @Override + protected String commitTarget() { + return String.format("%s.branch_%s", tableName, BRANCH); + } + + @Override + protected String selectTarget() { + return String.format("%s VERSION AS OF '%s'", tableName, BRANCH); + } +} diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestUnpartitionedWrites.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestUnpartitionedWrites.java index 0849602c3b92..d01ccab00f55 100644 --- a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestUnpartitionedWrites.java +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestUnpartitionedWrites.java @@ -18,146 +18,12 @@ */ package org.apache.iceberg.spark.sql; -import java.util.List; import java.util.Map; -import org.apache.iceberg.AssertHelpers; -import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; -import org.apache.iceberg.spark.SparkCatalogTestBase; -import org.apache.iceberg.spark.source.SimpleRecord; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.Row; -import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; -import org.apache.spark.sql.functions; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -public class TestUnpartitionedWrites extends SparkCatalogTestBase { +public class TestUnpartitionedWrites extends UnpartitionedWritesTestBase { + public TestUnpartitionedWrites( String catalogName, String implementation, Map config) { super(catalogName, implementation, config); } - - @Before - public void createTables() { - sql("CREATE TABLE %s (id bigint, data string) USING iceberg", 
tableName); - sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName); - } - - @After - public void removeTables() { - sql("DROP TABLE IF EXISTS %s", tableName); - } - - @Test - public void testInsertAppend() { - Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName)); - - sql("INSERT INTO %s VALUES (4, 'd'), (5, 'e')", tableName); - - Assert.assertEquals( - "Should have 5 rows after insert", 5L, scalarSql("SELECT count(*) FROM %s", tableName)); - - List expected = - ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c"), row(4L, "d"), row(5L, "e")); - - assertEquals( - "Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName)); - } - - @Test - public void testInsertOverwrite() { - Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName)); - - sql("INSERT OVERWRITE %s VALUES (4, 'd'), (5, 'e')", tableName); - - Assert.assertEquals( - "Should have 2 rows after overwrite", 2L, scalarSql("SELECT count(*) FROM %s", tableName)); - - List expected = ImmutableList.of(row(4L, "d"), row(5L, "e")); - - assertEquals( - "Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName)); - } - - @Test - public void testInsertAppendAtSnapshot() { - long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId(); - String prefix = "snapshot_id_"; - AssertHelpers.assertThrows( - "Should not be able to insert into a table at a specific snapshot", - IllegalArgumentException.class, - "Cannot write to table at a specific snapshot", - () -> sql("INSERT INTO %s.%s VALUES (4, 'd'), (5, 'e')", tableName, prefix + snapshotId)); - } - - @Test - public void testInsertOverwriteAtSnapshot() { - long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId(); - String prefix = "snapshot_id_"; - AssertHelpers.assertThrows( - "Should not be able to insert into a table at a specific snapshot", - IllegalArgumentException.class, - "Cannot write to table at a specific snapshot", - () -> - sql( - "INSERT OVERWRITE %s.%s VALUES (4, 'd'), (5, 'e')", - tableName, prefix + snapshotId)); - } - - @Test - public void testDataFrameV2Append() throws NoSuchTableException { - Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName)); - - List data = ImmutableList.of(new SimpleRecord(4, "d"), new SimpleRecord(5, "e")); - Dataset ds = spark.createDataFrame(data, SimpleRecord.class); - - ds.writeTo(tableName).append(); - - Assert.assertEquals( - "Should have 5 rows after insert", 5L, scalarSql("SELECT count(*) FROM %s", tableName)); - - List expected = - ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c"), row(4L, "d"), row(5L, "e")); - - assertEquals( - "Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName)); - } - - @Test - public void testDataFrameV2DynamicOverwrite() throws NoSuchTableException { - Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName)); - - List data = ImmutableList.of(new SimpleRecord(4, "d"), new SimpleRecord(5, "e")); - Dataset ds = spark.createDataFrame(data, SimpleRecord.class); - - ds.writeTo(tableName).overwritePartitions(); - - Assert.assertEquals( - "Should have 2 rows after overwrite", 2L, scalarSql("SELECT count(*) FROM %s", tableName)); - - List expected = ImmutableList.of(row(4L, "d"), row(5L, "e")); - - assertEquals( - "Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY 
id", tableName)); - } - - @Test - public void testDataFrameV2Overwrite() throws NoSuchTableException { - Assert.assertEquals("Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", tableName)); - - List data = ImmutableList.of(new SimpleRecord(4, "d"), new SimpleRecord(5, "e")); - Dataset ds = spark.createDataFrame(data, SimpleRecord.class); - - ds.writeTo(tableName).overwrite(functions.col("id").$less$eq(3)); - - Assert.assertEquals( - "Should have 2 rows after overwrite", 2L, scalarSql("SELECT count(*) FROM %s", tableName)); - - List expected = ImmutableList.of(row(4L, "d"), row(5L, "e")); - - assertEquals( - "Row data should match expected", expected, sql("SELECT * FROM %s ORDER BY id", tableName)); - } } diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestUnpartitionedWritesToBranch.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestUnpartitionedWritesToBranch.java new file mode 100644 index 000000000000..1f5bee42af05 --- /dev/null +++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/TestUnpartitionedWritesToBranch.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+package org.apache.iceberg.spark.sql;
+
+import java.util.Map;
+import org.apache.iceberg.Table;
+import org.apache.iceberg.exceptions.ValidationException;
+import org.assertj.core.api.Assertions;
+import org.junit.Test;
+
+public class TestUnpartitionedWritesToBranch extends UnpartitionedWritesTestBase {
+
+  private static final String BRANCH = "test";
+
+  public TestUnpartitionedWritesToBranch(
+      String catalogName, String implementation, Map<String, String> config) {
+    super(catalogName, implementation, config);
+  }
+
+  @Override
+  public void createTables() {
+    super.createTables();
+    Table table = validationCatalog.loadTable(tableIdent);
+    table.manageSnapshots().createBranch(BRANCH, table.currentSnapshot().snapshotId()).commit();
+    sql("REFRESH TABLE " + tableName);
+  }
+
+  @Override
+  protected String commitTarget() {
+    return String.format("%s.branch_%s", tableName, BRANCH);
+  }
+
+  @Override
+  protected String selectTarget() {
+    return String.format("%s VERSION AS OF '%s'", tableName, BRANCH);
+  }
+
+  @Test
+  public void testInsertIntoNonExistingBranchFails() {
+    Assertions.assertThatThrownBy(
+            () -> sql("INSERT INTO %s.branch_not_exist VALUES (4, 'd'), (5, 'e')", tableName))
+        .isInstanceOf(ValidationException.class)
+        .hasMessage("Cannot use branch (does not exist): not_exist");
+  }
+}
diff --git a/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/UnpartitionedWritesTestBase.java b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/UnpartitionedWritesTestBase.java
new file mode 100644
index 000000000000..71089ebfd79e
--- /dev/null
+++ b/spark/v3.3/spark/src/test/java/org/apache/iceberg/spark/sql/UnpartitionedWritesTestBase.java
@@ -0,0 +1,191 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.spark.sql;
+
+import java.util.List;
+import java.util.Map;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
+import org.apache.iceberg.spark.SparkCatalogTestBase;
+import org.apache.iceberg.spark.source.SimpleRecord;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
+import org.apache.spark.sql.functions;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Assume;
+import org.junit.Before;
+import org.junit.Test;
+
+public abstract class UnpartitionedWritesTestBase extends SparkCatalogTestBase {
+  public UnpartitionedWritesTestBase(
+      String catalogName, String implementation, Map<String, String> config) {
+    super(catalogName, implementation, config);
+  }
+
+  @Before
+  public void createTables() {
+    sql("CREATE TABLE %s (id bigint, data string) USING iceberg", tableName);
+    sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName);
+  }
+
+  @After
+  public void removeTables() {
+    sql("DROP TABLE IF EXISTS %s", tableName);
+  }
+
+  @Test
+  public void testInsertAppend() {
+    Assert.assertEquals(
+        "Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", selectTarget()));
+
+    sql("INSERT INTO %s VALUES (4, 'd'), (5, 'e')", commitTarget());
+
+    Assert.assertEquals(
+        "Should have 5 rows after insert",
+        5L,
+        scalarSql("SELECT count(*) FROM %s", selectTarget()));
+
+    List<Object[]> expected =
+        ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c"), row(4L, "d"), row(5L, "e"));
+
+    assertEquals(
+        "Row data should match expected",
+        expected,
+        sql("SELECT * FROM %s ORDER BY id", selectTarget()));
+  }
+
+  @Test
+  public void testInsertOverwrite() {
+    Assert.assertEquals(
+        "Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", selectTarget()));
+
+    sql("INSERT OVERWRITE %s VALUES (4, 'd'), (5, 'e')", commitTarget());
+
+    Assert.assertEquals(
+        "Should have 2 rows after overwrite",
+        2L,
+        scalarSql("SELECT count(*) FROM %s", selectTarget()));
+
+    List<Object[]> expected = ImmutableList.of(row(4L, "d"), row(5L, "e"));
+
+    assertEquals(
+        "Row data should match expected",
+        expected,
+        sql("SELECT * FROM %s ORDER BY id", selectTarget()));
+  }
+
+  @Test
+  public void testInsertAppendAtSnapshot() {
+    Assume.assumeTrue(tableName.equals(commitTarget()));
+    long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId();
+    String prefix = "snapshot_id_";
+    AssertHelpers.assertThrows(
+        "Should not be able to insert into a table at a specific snapshot",
+        IllegalArgumentException.class,
+        "Cannot write to table at a specific snapshot",
+        () -> sql("INSERT INTO %s.%s VALUES (4, 'd'), (5, 'e')", tableName, prefix + snapshotId));
+  }
+
+  @Test
+  public void testInsertOverwriteAtSnapshot() {
+    Assume.assumeTrue(tableName.equals(commitTarget()));
+    long snapshotId = validationCatalog.loadTable(tableIdent).currentSnapshot().snapshotId();
+    String prefix = "snapshot_id_";
+    AssertHelpers.assertThrows(
+        "Should not be able to insert into a table at a specific snapshot",
+        IllegalArgumentException.class,
+        "Cannot write to table at a specific snapshot",
+        () ->
+            sql(
+                "INSERT OVERWRITE %s.%s VALUES (4, 'd'), (5, 'e')",
+                tableName, prefix + snapshotId));
+  }
+
+  @Test
+  public void testDataFrameV2Append() throws NoSuchTableException {
+    Assert.assertEquals(
+        "Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", selectTarget()));
+
+    List<SimpleRecord> data = ImmutableList.of(new SimpleRecord(4, "d"), new SimpleRecord(5, "e"));
+    Dataset<Row> ds = spark.createDataFrame(data, SimpleRecord.class);
+
+    ds.writeTo(commitTarget()).append();
+
+    Assert.assertEquals(
+        "Should have 5 rows after insert",
+        5L,
+        scalarSql("SELECT count(*) FROM %s", selectTarget()));
+
+    List<Object[]> expected =
+        ImmutableList.of(row(1L, "a"), row(2L, "b"), row(3L, "c"), row(4L, "d"), row(5L, "e"));
+
+    assertEquals(
+        "Row data should match expected",
+        expected,
+        sql("SELECT * FROM %s ORDER BY id", selectTarget()));
+  }
+
+  @Test
+  public void testDataFrameV2DynamicOverwrite() throws NoSuchTableException {
+    Assert.assertEquals(
+        "Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", selectTarget()));
+
+    List<SimpleRecord> data = ImmutableList.of(new SimpleRecord(4, "d"), new SimpleRecord(5, "e"));
+    Dataset<Row> ds = spark.createDataFrame(data, SimpleRecord.class);
+
+    ds.writeTo(commitTarget()).overwritePartitions();
+
+    Assert.assertEquals(
+        "Should have 2 rows after overwrite",
+        2L,
+        scalarSql("SELECT count(*) FROM %s", selectTarget()));
+
+    List<Object[]> expected = ImmutableList.of(row(4L, "d"), row(5L, "e"));
+
+    assertEquals(
+        "Row data should match expected",
+        expected,
+        sql("SELECT * FROM %s ORDER BY id", selectTarget()));
+  }
+
+  @Test
+  public void testDataFrameV2Overwrite() throws NoSuchTableException {
+    Assert.assertEquals(
+        "Should have 3 rows", 3L, scalarSql("SELECT count(*) FROM %s", selectTarget()));
+
+    List<SimpleRecord> data = ImmutableList.of(new SimpleRecord(4, "d"), new SimpleRecord(5, "e"));
+    Dataset<Row> ds = spark.createDataFrame(data, SimpleRecord.class);
+
+    ds.writeTo(commitTarget()).overwrite(functions.col("id").$less$eq(3));
+
+    Assert.assertEquals(
+        "Should have 2 rows after overwrite",
+        2L,
+        scalarSql("SELECT count(*) FROM %s", selectTarget()));
+
+    List<Object[]> expected = ImmutableList.of(row(4L, "d"), row(5L, "e"));
+
+    assertEquals(
+        "Row data should match expected",
+        expected,
+        sql("SELECT * FROM %s ORDER BY id", selectTarget()));
+  }
+}
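For context, a minimal usage sketch of the pattern these tests exercise, not taken from the patch: the table name "db.tbl" and branch name "audit" are hypothetical, and only the SQL shapes mirror commitTarget() (writes to a "branch_<name>" identifier) and selectTarget() (reads via VERSION AS OF).

// Minimal sketch, assuming an Iceberg table db.tbl with columns (id bigint, data string)
// and an existing branch named "audit" created through manageSnapshots().createBranch(...).
import org.apache.spark.sql.SparkSession;

public class BranchWriteSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().appName("branch-write-sketch").getOrCreate();

    // Append to the branch rather than to main, mirroring commitTarget() in the tests above.
    spark.sql("INSERT INTO db.tbl.branch_audit VALUES (4, 'd'), (5, 'e')");

    // Read the branch head back, mirroring selectTarget() in the tests above.
    spark.sql("SELECT * FROM db.tbl VERSION AS OF 'audit' ORDER BY id").show();

    spark.stop();
  }
}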