@@ -64,11 +64,11 @@ public void testRewriteDataFilesOnPartitionTable() {
List<Object[]> expectedRecords = currentData();

List<Object[]> output = sql(
"CALL %s.system.rewrite_data_files(table => '%s')", catalogName, tableIdent);
"CALL %s.system.rewrite_data_files(table => '%s')", catalogName, tableIdent);

assertEquals("Action should rewrite 10 data files and add 2 data files (one per partition) ",
ImmutableList.of(row(10, 2)),
output);

List<Object[]> actualRecords = currentData();
assertEquals("Data after compaction should not change", expectedRecords, actualRecords);
@@ -82,11 +82,11 @@ public void testRewriteDataFilesOnNonPartitionTable() {
List<Object[]> expectedRecords = currentData();

List<Object[]> output = sql(
"CALL %s.system.rewrite_data_files(table => '%s')", catalogName, tableIdent);
"CALL %s.system.rewrite_data_files(table => '%s')", catalogName, tableIdent);

assertEquals("Action should rewrite 10 data files and add 1 data files",
ImmutableList.of(row(10, 1)),
output);

List<Object[]> actualRecords = currentData();
assertEquals("Data after compaction should not change", expectedRecords, actualRecords);
@@ -101,12 +101,12 @@ public void testRewriteDataFilesWithOptions() {

// set the min-input-files = 12, instead of default 5 to skip compacting the files.
List<Object[]> output = sql(
"CALL %s.system.rewrite_data_files(table => '%s', options => map('min-input-files','12'))",
catalogName, tableIdent);
"CALL %s.system.rewrite_data_files(table => '%s', options => map('min-input-files','12'))",
catalogName, tableIdent);

assertEquals("Action should rewrite 0 data files and add 0 data files",
ImmutableList.of(row(0, 0)),
output);

List<Object[]> actualRecords = currentData();
assertEquals("Data should not change", expectedRecords, actualRecords);
@@ -121,13 +121,13 @@ public void testRewriteDataFilesWithSortStrategy() {

// set sort_order = c1 DESC LAST
List<Object[]> output = sql(
"CALL %s.system.rewrite_data_files(table => '%s', " +
"strategy => 'sort', sort_order => 'c1 DESC NULLS LAST')",
catalogName, tableIdent);
"CALL %s.system.rewrite_data_files(table => '%s', " +
"strategy => 'sort', sort_order => 'c1 DESC NULLS LAST')",
catalogName, tableIdent);

assertEquals("Action should rewrite 10 data files and add 1 data files",
ImmutableList.of(row(10, 1)),
output);

List<Object[]> actualRecords = currentData();
assertEquals("Data after compaction should not change", expectedRecords, actualRecords);
@@ -142,12 +142,12 @@ public void testRewriteDataFilesWithFilter() {

// select only 5 files for compaction (files that may have c1 = 1)
List<Object[]> output = sql(
"CALL %s.system.rewrite_data_files(table => '%s'," +
" where => 'c1 = 1 and c2 is not null')", catalogName, tableIdent);
"CALL %s.system.rewrite_data_files(table => '%s'," +
" where => 'c1 = 1 and c2 is not null')", catalogName, tableIdent);

assertEquals("Action should rewrite 5 data files (containing c1 = 1) and add 1 data files",
ImmutableList.of(row(5, 1)),
output);

List<Object[]> actualRecords = currentData();
assertEquals("Data after compaction should not change", expectedRecords, actualRecords);
@@ -166,14 +166,90 @@ public void testRewriteDataFilesWithFilterOnPartitionTable() {
" where => 'c2 = \"bar\"')", catalogName, tableIdent);

assertEquals("Action should rewrite 5 data files from single matching partition" +
"(containing c2 = bar) and add 1 data files",
"(containing c2 = bar) and add 1 data files",
ImmutableList.of(row(5, 1)),
output);

List<Object[]> actualRecords = currentData();
assertEquals("Data after compaction should not change", expectedRecords, actualRecords);
}

@Test
public void testRewriteDataFilesWithInFilterOnPartitionTable() {
createPartitionTable();
// create 5 files for each partition (c2 = 'foo' and c2 = 'bar')
insertData(10);
List<Object[]> expectedRecords = currentData();

// select only 5 files for compaction (files in the partition c2 in ('bar'))
List<Object[]> output = sql(
"CALL %s.system.rewrite_data_files(table => '%s'," +
" where => 'c2 in (\"bar\")')", catalogName, tableIdent);

assertEquals("Action should rewrite 5 data files from single matching partition" +
"(containing c2 = bar) and add 1 data files",
ImmutableList.of(row(5, 1)),
output);

List<Object[]> actualRecords = currentData();
assertEquals("Data after compaction should not change", expectedRecords, actualRecords);
}

@Test
public void testRewriteDataFilesWithAllPossibleFilters() {
createPartitionTable();
// create 5 files for each partition (c2 = 'foo' and c2 = 'bar')
insertData(10);

// Pass the literal value which is not present in the data files.
// So that parsing can be tested on the same dataset without actually compacting the files.

// EqualTo
sql("CALL %s.system.rewrite_data_files(table => '%s'," +
" where => 'c1 = 3')", catalogName, tableIdent);
// GreaterThan
sql("CALL %s.system.rewrite_data_files(table => '%s'," +
" where => 'c1 > 3')", catalogName, tableIdent);
// GreaterThanOrEqual
sql("CALL %s.system.rewrite_data_files(table => '%s'," +
" where => 'c1 >= 3')", catalogName, tableIdent);
// LessThan
sql("CALL %s.system.rewrite_data_files(table => '%s'," +
" where => 'c1 < 0')", catalogName, tableIdent);
// LessThanOrEqual
sql("CALL %s.system.rewrite_data_files(table => '%s'," +
" where => 'c1 <= 0')", catalogName, tableIdent);
// In
sql("CALL %s.system.rewrite_data_files(table => '%s'," +
" where => 'c1 in (3,4,5)')", catalogName, tableIdent);
// IsNull
sql("CALL %s.system.rewrite_data_files(table => '%s'," +
" where => 'c1 is null')", catalogName, tableIdent);
// IsNotNull
sql("CALL %s.system.rewrite_data_files(table => '%s'," +
" where => 'c3 is not null')", catalogName, tableIdent);
// And
sql("CALL %s.system.rewrite_data_files(table => '%s'," +
" where => 'c1 = 3 and c2 = \"bar\"')", catalogName, tableIdent);
// Or
sql("CALL %s.system.rewrite_data_files(table => '%s'," +
" where => 'c1 = 3 or c1 = 5')", catalogName, tableIdent);
// Not
sql("CALL %s.system.rewrite_data_files(table => '%s'," +
" where => 'c1 not in (1,2)')", catalogName, tableIdent);
// StringStartsWith
sql("CALL %s.system.rewrite_data_files(table => '%s'," +
" where => 'c2 like \"%s\"')", catalogName, tableIdent, "car%");

// TODO: Enable when org.apache.iceberg.spark.SparkFilters has implementations for StringEndsWith & StringContains
// StringEndsWith
// sql("CALL %s.system.rewrite_data_files(table => '%s'," +
// " where => 'c2 like \"%s\"')", catalogName, tableIdent, "%car");
// StringContains
// sql("CALL %s.system.rewrite_data_files(table => '%s'," +
// " where => 'c2 like \"%s\"')", catalogName, tableIdent, "%car%");
}

@Test
public void testRewriteDataFilesWithInvalidInputs() {
createTable();
@@ -182,41 +258,40 @@ public void testRewriteDataFilesWithInvalidInputs() {

// Test for invalid strategy
AssertHelpers.assertThrows("Should reject calls with unsupported strategy error message",
IllegalArgumentException.class, "unsupported strategy: temp. Only binpack,sort is supported",
() -> sql("CALL %s.system.rewrite_data_files(table => '%s', options => map('min-input-files','2'), " +
"strategy => 'temp')", catalogName, tableIdent));
IllegalArgumentException.class, "unsupported strategy: temp. Only binpack,sort is supported",
() -> sql("CALL %s.system.rewrite_data_files(table => '%s', options => map('min-input-files','2'), " +
"strategy => 'temp')", catalogName, tableIdent));

// Test for sort_order with binpack strategy
AssertHelpers.assertThrows("Should reject calls with error message",
IllegalArgumentException.class, "Cannot set strategy to sort, it has already been set",
() -> sql("CALL %s.system.rewrite_data_files(table => '%s', strategy => 'binpack', " +
"sort_order => 'c1 ASC NULLS FIRST')", catalogName, tableIdent));
IllegalArgumentException.class, "Cannot set strategy to sort, it has already been set",
() -> sql("CALL %s.system.rewrite_data_files(table => '%s', strategy => 'binpack', " +
"sort_order => 'c1 ASC NULLS FIRST')", catalogName, tableIdent));

// Test for sort_order with invalid null order
AssertHelpers.assertThrows("Should reject calls with error message",
IllegalArgumentException.class, "Unable to parse sortOrder:",
() -> sql("CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " +
"sort_order => 'c1 ASC none')", catalogName, tableIdent));
IllegalArgumentException.class, "Unable to parse sortOrder:",
() -> sql("CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " +
"sort_order => 'c1 ASC none')", catalogName, tableIdent));

// Test for sort_order with invalid sort direction
AssertHelpers.assertThrows("Should reject calls with error message",
IllegalArgumentException.class, "Unable to parse sortOrder:",
() -> sql("CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " +
"sort_order => 'c1 none NULLS FIRST')", catalogName, tableIdent));
IllegalArgumentException.class, "Unable to parse sortOrder:",
() -> sql("CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " +
"sort_order => 'c1 none NULLS FIRST')", catalogName, tableIdent));

// Test for sort_order with invalid column name
AssertHelpers.assertThrows("Should reject calls with error message",
ValidationException.class, "Cannot find field 'col1' in struct:" +
" struct<1: c1: optional int, 2: c2: optional string, 3: c3: optional string>",
() -> sql("CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " +
"sort_order => 'col1 DESC NULLS FIRST')", catalogName, tableIdent));
ValidationException.class, "Cannot find field 'col1' in struct:" +
" struct<1: c1: optional int, 2: c2: optional string, 3: c3: optional string>",
() -> sql("CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " +
"sort_order => 'col1 DESC NULLS FIRST')", catalogName, tableIdent));

// Test for where filter with invalid column col1
AssertHelpers.assertThrows("Should reject calls with error message",
ValidationException.class, "Cannot find field 'col1' in struct:" +
" struct<1: c1: optional int, 2: c2: optional string, 3: c3: optional string>",
() -> sql("CALL %s.system.rewrite_data_files(table => '%s', " +
"where => 'col1 = 3')", catalogName, tableIdent));
IllegalArgumentException.class, "Cannot parse predicates in where option: col1 = 3",
() -> sql("CALL %s.system.rewrite_data_files(table => '%s', " +
"where => 'col1 = 3')", catalogName, tableIdent));
}

@Test
@@ -251,8 +326,8 @@ private void createPartitionTable() {
}

private void insertData(int filesCount) {
ThreeColumnRecord record1 = new ThreeColumnRecord(1, "foo", "detail1");
ThreeColumnRecord record2 = new ThreeColumnRecord(2, "bar", "detail2");
ThreeColumnRecord record1 = new ThreeColumnRecord(1, "foo", null);
Review comment from a Member: I missed why was this changed?

Reply from the Member Author: While adding new test cases for all possible filters, I wanted some null data so that my NOT NULL filter would not trigger compaction (it doesn't select any data). c3 was never used in the test cases, so I reused it with null data.

ThreeColumnRecord record2 = new ThreeColumnRecord(2, "bar", null);

List<ThreeColumnRecord> records = Lists.newArrayList();
IntStream.range(0, filesCount / 2).forEach(i -> {
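Following up on the review thread above, here is a minimal sketch, not part of the PR, of how the all-null c3 column keeps an IS NOT NULL filter from selecting anything. It reuses this test class's own helpers (createPartitionTable, insertData, sql, row, assertEquals); the test name is hypothetical, and the row(0, 0) expectation assumes the procedure reports zero rewritten and zero added files when no files match, as in testRewriteDataFilesWithOptions.

@Test
public void testRewriteDataFilesWithIsNotNullOnAllNullColumn() {
  // Hypothetical test, for illustration only.
  createPartitionTable();
  // insertData(10) now writes null into c3 for every record, so 'c3 is not null' matches no files.
  insertData(10);

  List<Object[]> output = sql(
      "CALL %s.system.rewrite_data_files(table => '%s', where => 'c3 is not null')",
      catalogName, tableIdent);

  // No file matches the filter, so nothing should be rewritten or added.
  assertEquals("Action should rewrite 0 data files and add 0 data files",
      ImmutableList.of(row(0, 0)), output);
}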
@@ -28,8 +28,6 @@
import org.apache.spark.sql.AnalysisException;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.catalyst.plans.logical.SetWriteDistributionAndOrdering;
import org.apache.spark.sql.catalyst.plans.logical.SortOrderParserUtil;
@@ -109,21 +107,20 @@ public InternalRow[] call(InternalRow args) {
}

String where = args.isNullAt(4) ? null : args.getString(4);
action = checkAndApplyFilter(action, where, table.name());

RewriteDataFiles.Result result = action.execute();

return toOutputRows(result);
});
}

private RewriteDataFiles checkAndApplyFilter(RewriteDataFiles action, String where, String tableName) {
if (where != null) {
try {
Expression expression = SparkExpressionConverter.collectResolvedSparkExpression(spark(), tableName, where);
return action.filter(SparkExpressionConverter.convertToIcebergExpression(expression));
} catch (AnalysisException e) {
throw new IllegalArgumentException("Cannot parse predicates in where option: " + where);
}
}
@@ -20,7 +20,11 @@
package org.apache.spark.sql.execution.datasources

import org.apache.iceberg.spark.SparkFilters
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.Filter
import org.apache.spark.sql.execution.CommandExecutionMode

object SparkExpressionConverter {

@@ -30,4 +34,18 @@ object SparkExpressionConverter {
// But these two conversions already exist and are well tested, so we are going with this approach.
SparkFilters.convert(DataSourceStrategy.translateFilter(sparkExpression, supportNestedPredicatePushdown = true).get)
}

@throws[AnalysisException]
def collectResolvedSparkExpression(session: SparkSession, tableName: String, where: String): Expression = {
var expression: Expression = null
// Add a dummy prefix linking to the table to collect the resolved spark expression from optimized plan.
val prefix = String.format("SELECT 42 from %s where ", tableName)
val logicalPlan = session.sessionState.sqlParser.parsePlan(prefix + where)
val optimizedLogicalPlan = session.sessionState.executePlan(logicalPlan, CommandExecutionMode.ALL).optimizedPlan
optimizedLogicalPlan.collectFirst {
case filter: Filter =>
expression = filter.expressions.head
}
expression
}
}
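To make the dummy-prefix approach concrete, here is a short sketch, not part of the change, of how the two helpers compose on the Java side; the table name and predicate are illustrative, and the surrounding procedure context (spark(), action) is assumed from the checkAndApplyFilter change above.

// Illustrative values only; any predicate the table can resolve would do.
String tableName = "db.sample";
String where = "c1 = 3 and c2 = 'bar'";
// The helper parses "SELECT 42 from db.sample where c1 = 3 and c2 = 'bar'", optimizes the plan,
// and returns the resolved condition collected from the Filter node.
Expression resolved = SparkExpressionConverter.collectResolvedSparkExpression(spark(), tableName, where);
// The resolved Spark expression is then translated into an Iceberg expression for the rewrite action.
action = action.filter(SparkExpressionConverter.convertToIcebergExpression(resolved));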