Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions spark/v3.5/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ project(":iceberg-spark:iceberg-spark-extensions-${sparkMajorVersion}_${scalaVer

testImplementation libs.avro.avro
testImplementation libs.parquet.hadoop
testImplementation libs.awaitility

// Required because we remove antlr plugin dependencies from the compile configuration, see note above
runtimeOnly libs.antlr.runtime
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,15 @@
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
import org.apache.iceberg.DataFile;
import org.apache.iceberg.FileFormat;
import org.apache.iceberg.Files;
import org.apache.iceberg.Parameter;
import org.apache.iceberg.ParameterizedTestExtension;
import org.apache.iceberg.Parameters;
import org.apache.iceberg.PlanningMode;
import org.apache.iceberg.RowLevelOperationMode;
import org.apache.iceberg.Snapshot;
Expand All @@ -69,41 +73,30 @@
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
import org.apache.spark.sql.execution.SparkPlan;
import org.junit.Assert;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;
import org.junit.jupiter.api.extension.ExtendWith;

@RunWith(Parameterized.class)
public abstract class SparkRowLevelOperationsTestBase extends SparkExtensionsTestBase {
@ExtendWith(ParameterizedTestExtension.class)
public abstract class SparkRowLevelOperationsTestBase extends ExtensionsTestBase {

private static final Random RANDOM = ThreadLocalRandom.current();

protected final String fileFormat;
protected final boolean vectorized;
protected final String distributionMode;
protected final boolean fanoutEnabled;
protected final String branch;
protected final PlanningMode planningMode;

public SparkRowLevelOperationsTestBase(
String catalogName,
String implementation,
Map<String, String> config,
String fileFormat,
boolean vectorized,
String distributionMode,
boolean fanoutEnabled,
String branch,
PlanningMode planningMode) {
super(catalogName, implementation, config);
this.fileFormat = fileFormat;
this.vectorized = vectorized;
this.distributionMode = distributionMode;
this.fanoutEnabled = fanoutEnabled;
this.branch = branch;
this.planningMode = planningMode;
}
@Parameter(index = 3)
protected FileFormat fileFormat;

@Parameter(index = 4)
protected boolean vectorized;

@Parameter(index = 5)
protected String distributionMode;

@Parameter(index = 6)
protected boolean fanoutEnabled;

@Parameter(index = 7)
protected String branch;

@Parameter(index = 8)
protected PlanningMode planningMode;

@Parameters(
name =
Expand All @@ -118,7 +111,7 @@ public static Object[][] parameters() {
ImmutableMap.of(
"type", "hive",
"default-namespace", "default"),
"orc",
FileFormat.ORC,
true,
WRITE_DISTRIBUTION_MODE_NONE,
true,
Expand All @@ -131,7 +124,7 @@ public static Object[][] parameters() {
ImmutableMap.of(
"type", "hive",
"default-namespace", "default"),
"parquet",
FileFormat.PARQUET,
true,
WRITE_DISTRIBUTION_MODE_NONE,
false,
Expand All @@ -142,7 +135,7 @@ public static Object[][] parameters() {
"testhadoop",
SparkCatalog.class.getName(),
ImmutableMap.of("type", "hadoop"),
"parquet",
FileFormat.PARQUET,
RANDOM.nextBoolean(),
WRITE_DISTRIBUTION_MODE_HASH,
true,
Expand All @@ -160,7 +153,7 @@ public static Object[][] parameters() {
"cache-enabled",
"false" // Spark will delete tables using v1, leaving the cache out of sync
),
"avro",
FileFormat.AVRO,
false,
WRITE_DISTRIBUTION_MODE_RANGE,
false,
Expand Down Expand Up @@ -188,18 +181,18 @@ protected void initTable() {
planningMode.modeName());

switch (fileFormat) {
case "parquet":
case PARQUET:
sql(
"ALTER TABLE %s SET TBLPROPERTIES('%s' '%b')",
tableName, PARQUET_VECTORIZATION_ENABLED, vectorized);
break;
case "orc":
case ORC:
sql(
"ALTER TABLE %s SET TBLPROPERTIES('%s' '%b')",
tableName, ORC_VECTORIZATION_ENABLED, vectorized);
break;
case "avro":
Assert.assertFalse(vectorized);
case AVRO:
assertThat(vectorized).isFalse();
break;
}

Expand Down Expand Up @@ -303,7 +296,7 @@ protected void validateSnapshot(
String deletedDataFiles,
String addedDeleteFiles,
String addedDataFiles) {
Assert.assertEquals("Operation must match", operation, snapshot.operation());
assertThat(snapshot.operation()).as("Operation must match").isEqualTo(operation);
validateProperty(snapshot, CHANGED_PARTITION_COUNT_PROP, changedPartitionCount);
validateProperty(snapshot, DELETED_FILES_PROP, deletedDataFiles);
validateProperty(snapshot, ADDED_DELETE_FILES_PROP, addedDeleteFiles);
Expand All @@ -312,20 +305,22 @@ protected void validateSnapshot(

protected void validateProperty(Snapshot snapshot, String property, Set<String> expectedValues) {
String actual = snapshot.summary().get(property);
Assert.assertTrue(
"Snapshot property "
+ property
+ " has unexpected value, actual = "
+ actual
+ ", expected one of : "
+ String.join(",", expectedValues),
expectedValues.contains(actual));
assertThat(actual)
.as(
"Snapshot property "
+ property
+ " has unexpected value, actual = "
+ actual
+ ", expected one of : "
+ String.join(",", expectedValues))
.isIn(expectedValues);
}

protected void validateProperty(Snapshot snapshot, String property, String expectedValue) {
String actual = snapshot.summary().get(property);
Assert.assertEquals(
"Snapshot property " + property + " has unexpected value.", expectedValue, actual);
assertThat(actual)
.as("Snapshot property " + property + " has unexpected value.")
.isEqualTo(expectedValue);
}

protected void sleep(long millis) {
Expand All @@ -338,7 +333,9 @@ protected void sleep(long millis) {

protected DataFile writeDataFile(Table table, List<GenericRecord> records) {
try {
OutputFile file = Files.localOutput(temp.newFile());
OutputFile file =
Files.localOutput(
temp.resolve(fileFormat.addExtension(UUID.randomUUID().toString())).toFile());

DataWriter<GenericRecord> dataWriter =
Parquet.writeData(file)
Expand Down Expand Up @@ -384,7 +381,7 @@ protected boolean supportsVectorization() {
}

private boolean isParquet() {
return fileFormat.equalsIgnoreCase(FileFormat.PARQUET.name());
return fileFormat.equals(FileFormat.PARQUET);
}

private boolean isCopyOnWrite() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
package org.apache.iceberg.spark.extensions;

import java.util.List;
import java.util.Map;
import org.apache.iceberg.IsolationLevel;
import org.apache.iceberg.ParameterizedTestExtension;
import org.apache.iceberg.Table;
import org.apache.iceberg.exceptions.ValidationException;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
Expand All @@ -30,18 +30,15 @@
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions;
import org.assertj.core.api.Assertions;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.TestTemplate;
import org.junit.jupiter.api.extension.ExtendWith;

public class TestConflictValidation extends SparkExtensionsTestBase {
@ExtendWith(ParameterizedTestExtension.class)
public class TestConflictValidation extends ExtensionsTestBase {

public TestConflictValidation(
String catalogName, String implementation, Map<String, String> config) {
super(catalogName, implementation, config);
}

@Before
@BeforeEach
public void createTables() {
sql(
"CREATE TABLE %s (id int, data string) USING iceberg "
Expand All @@ -53,12 +50,12 @@ public void createTables() {
sql("INSERT INTO %s VALUES (1, 'a'), (2, 'b'), (3, 'c')", tableName);
}

@After
@AfterEach
public void removeTables() {
sql("DROP TABLE IF EXISTS %s", tableName);
}

@Test
@TestTemplate
public void testOverwriteFilterSerializableIsolation() throws Exception {
Table table = validationCatalog.loadTable(tableIdent);
long snapshotId = table.currentSnapshot().snapshotId();
Expand Down Expand Up @@ -90,7 +87,7 @@ public void testOverwriteFilterSerializableIsolation() throws Exception {
.overwrite(functions.col("id").equalTo(1));
}

@Test
@TestTemplate
public void testOverwriteFilterSerializableIsolation2() throws Exception {
List<SimpleRecord> records =
Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(1, "b"));
Expand Down Expand Up @@ -127,7 +124,7 @@ public void testOverwriteFilterSerializableIsolation2() throws Exception {
.overwrite(functions.col("id").equalTo(1));
}

@Test
@TestTemplate
public void testOverwriteFilterSerializableIsolation3() throws Exception {
Table table = validationCatalog.loadTable(tableIdent);
long snapshotId = table.currentSnapshot().snapshotId();
Expand Down Expand Up @@ -161,7 +158,7 @@ public void testOverwriteFilterSerializableIsolation3() throws Exception {
.overwrite(functions.col("id").equalTo(1));
}

@Test
@TestTemplate
public void testOverwriteFilterNoSnapshotIdValidation() throws Exception {
Table table = validationCatalog.loadTable(tableIdent);

Expand Down Expand Up @@ -192,7 +189,7 @@ public void testOverwriteFilterNoSnapshotIdValidation() throws Exception {
.overwrite(functions.col("id").equalTo(1));
}

@Test
@TestTemplate
public void testOverwriteFilterSnapshotIsolation() throws Exception {
List<SimpleRecord> records =
Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(1, "b"));
Expand Down Expand Up @@ -229,7 +226,7 @@ public void testOverwriteFilterSnapshotIsolation() throws Exception {
.overwrite(functions.col("id").equalTo(1));
}

@Test
@TestTemplate
public void testOverwriteFilterSnapshotIsolation2() throws Exception {
Table table = validationCatalog.loadTable(tableIdent);
long snapshotId = table.currentSnapshot().snapshotId();
Expand All @@ -246,7 +243,7 @@ public void testOverwriteFilterSnapshotIsolation2() throws Exception {
.overwrite(functions.col("id").equalTo(1));
}

@Test
@TestTemplate
public void testOverwritePartitionSerializableIsolation() throws Exception {
Table table = validationCatalog.loadTable(tableIdent);
final long snapshotId = table.currentSnapshot().snapshotId();
Expand Down Expand Up @@ -278,7 +275,7 @@ public void testOverwritePartitionSerializableIsolation() throws Exception {
.overwritePartitions();
}

@Test
@TestTemplate
public void testOverwritePartitionSnapshotIsolation() throws Exception {
List<SimpleRecord> records =
Lists.newArrayList(new SimpleRecord(1, "a"), new SimpleRecord(1, "b"));
Expand Down Expand Up @@ -313,7 +310,7 @@ public void testOverwritePartitionSnapshotIsolation() throws Exception {
.overwritePartitions();
}

@Test
@TestTemplate
public void testOverwritePartitionSnapshotIsolation2() throws Exception {
Table table = validationCatalog.loadTable(tableIdent);
final long snapshotId = table.currentSnapshot().snapshotId();
Expand Down Expand Up @@ -347,7 +344,7 @@ public void testOverwritePartitionSnapshotIsolation2() throws Exception {
.overwritePartitions();
}

@Test
@TestTemplate
public void testOverwritePartitionSnapshotIsolation3() throws Exception {
Table table = validationCatalog.loadTable(tableIdent);
final long snapshotId = table.currentSnapshot().snapshotId();
Expand All @@ -364,7 +361,7 @@ public void testOverwritePartitionSnapshotIsolation3() throws Exception {
.overwritePartitions();
}

@Test
@TestTemplate
public void testOverwritePartitionNoSnapshotIdValidation() throws Exception {
Table table = validationCatalog.loadTable(tableIdent);

Expand Down
Loading