diff --git a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteDataFilesProcedure.java b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteDataFilesProcedure.java
index 17e489746f94..a6fa5e45eb7e 100644
--- a/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteDataFilesProcedure.java
+++ b/spark/v3.2/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRewriteDataFilesProcedure.java
@@ -64,11 +64,11 @@ public void testRewriteDataFilesOnPartitionTable() {
     List<Object[]> expectedRecords = currentData();
 
     List<Object[]> output = sql(
-        "CALL %s.system.rewrite_data_files(table => '%s')", catalogName, tableIdent);
+      "CALL %s.system.rewrite_data_files(table => '%s')", catalogName, tableIdent);
 
     assertEquals("Action should rewrite 10 data files and add 2 data files (one per partition) ",
-        ImmutableList.of(row(10, 2)),
-        output);
+      ImmutableList.of(row(10, 2)),
+      output);
 
     List<Object[]> actualRecords = currentData();
     assertEquals("Data after compaction should not change", expectedRecords, actualRecords);
@@ -82,11 +82,11 @@ public void testRewriteDataFilesOnNonPartitionTable() {
     List<Object[]> expectedRecords = currentData();
 
     List<Object[]> output = sql(
-        "CALL %s.system.rewrite_data_files(table => '%s')", catalogName, tableIdent);
+      "CALL %s.system.rewrite_data_files(table => '%s')", catalogName, tableIdent);
 
     assertEquals("Action should rewrite 10 data files and add 1 data files",
-        ImmutableList.of(row(10, 1)),
-        output);
+      ImmutableList.of(row(10, 1)),
+      output);
 
     List<Object[]> actualRecords = currentData();
     assertEquals("Data after compaction should not change", expectedRecords, actualRecords);
@@ -101,12 +101,12 @@ public void testRewriteDataFilesWithOptions() {
 
     // set the min-input-files = 12, instead of default 5 to skip compacting the files.
     List<Object[]> output = sql(
-        "CALL %s.system.rewrite_data_files(table => '%s', options => map('min-input-files','12'))",
-        catalogName, tableIdent);
+      "CALL %s.system.rewrite_data_files(table => '%s', options => map('min-input-files','12'))",
+      catalogName, tableIdent);
 
     assertEquals("Action should rewrite 0 data files and add 0 data files",
-        ImmutableList.of(row(0, 0)),
-        output);
+      ImmutableList.of(row(0, 0)),
+      output);
 
     List<Object[]> actualRecords = currentData();
     assertEquals("Data should not change", expectedRecords, actualRecords);
@@ -121,13 +121,13 @@ public void testRewriteDataFilesWithSortStrategy() {
 
     // set sort_order = c1 DESC LAST
     List<Object[]> output = sql(
-        "CALL %s.system.rewrite_data_files(table => '%s', " +
-            "strategy => 'sort', sort_order => 'c1 DESC NULLS LAST')",
-        catalogName, tableIdent);
+      "CALL %s.system.rewrite_data_files(table => '%s', " +
+        "strategy => 'sort', sort_order => 'c1 DESC NULLS LAST')",
+      catalogName, tableIdent);
 
     assertEquals("Action should rewrite 10 data files and add 1 data files",
-        ImmutableList.of(row(10, 1)),
-        output);
+      ImmutableList.of(row(10, 1)),
+      output);
 
     List<Object[]> actualRecords = currentData();
     assertEquals("Data after compaction should not change", expectedRecords, actualRecords);
@@ -142,12 +142,12 @@ public void testRewriteDataFilesWithFilter() {
 
     // select only 5 files for compaction (files that may have c1 = 1)
     List<Object[]> output = sql(
-        "CALL %s.system.rewrite_data_files(table => '%s'," +
-            " where => 'c1 = 1 and c2 is not null')", catalogName, tableIdent);
+      "CALL %s.system.rewrite_data_files(table => '%s'," +
+        " where => 'c1 = 1 and c2 is not null')", catalogName, tableIdent);
 
     assertEquals("Action should rewrite 5 data files (containing c1 = 1) and add 1 data files",
-        ImmutableList.of(row(5, 1)),
-        output);
+      ImmutableList.of(row(5, 1)),
+      output);
 
     List<Object[]> actualRecords = currentData();
     assertEquals("Data after compaction should not change", expectedRecords, actualRecords);
@@ -166,7 +166,28 @@ public void testRewriteDataFilesWithFilterOnPartitionTable() {
         " where => 'c2 = \"bar\"')", catalogName, tableIdent);
 
     assertEquals("Action should rewrite 5 data files from single matching partition" +
-        "(containing c2 = bar) and add 1 data files",
+      "(containing c2 = bar) and add 1 data files",
+      ImmutableList.of(row(5, 1)),
+      output);
+
+    List<Object[]> actualRecords = currentData();
+    assertEquals("Data after compaction should not change", expectedRecords, actualRecords);
+  }
+
+  @Test
+  public void testRewriteDataFilesWithInFilterOnPartitionTable() {
+    createPartitionTable();
+    // create 5 files for each partition (c2 = 'foo' and c2 = 'bar')
+    insertData(10);
+    List<Object[]> expectedRecords = currentData();
+
+    // select only 5 files for compaction (files in the partition c2 in ('bar'))
+    List<Object[]> output = sql(
+      "CALL %s.system.rewrite_data_files(table => '%s'," +
+        " where => 'c2 in (\"bar\")')", catalogName, tableIdent);
+
+    assertEquals("Action should rewrite 5 data files from single matching partition" +
+        "(containing c2 = bar) and add 1 data files",
         ImmutableList.of(row(5, 1)),
         output);
@@ -174,6 +195,61 @@ public void testRewriteDataFilesWithFilterOnPartitionTable() {
     assertEquals("Data after compaction should not change", expectedRecords, actualRecords);
   }
 
+  @Test
+  public void testRewriteDataFilesWithAllPossibleFilters() {
+    createPartitionTable();
+    // create 5 files for each partition (c2 = 'foo' and c2 = 'bar')
+    insertData(10);
+
+    // Pass the literal value which is not present in the data files.
+    // So that parsing can be tested on a same dataset without actually compacting the files.
+
+    // EqualTo
+    sql("CALL %s.system.rewrite_data_files(table => '%s'," +
+        " where => 'c1 = 3')", catalogName, tableIdent);
+    // GreaterThan
+    sql("CALL %s.system.rewrite_data_files(table => '%s'," +
+        " where => 'c1 > 3')", catalogName, tableIdent);
+    // GreaterThanOrEqual
+    sql("CALL %s.system.rewrite_data_files(table => '%s'," +
+        " where => 'c1 >= 3')", catalogName, tableIdent);
+    // LessThan
+    sql("CALL %s.system.rewrite_data_files(table => '%s'," +
+        " where => 'c1 < 0')", catalogName, tableIdent);
+    // LessThanOrEqual
+    sql("CALL %s.system.rewrite_data_files(table => '%s'," +
+        " where => 'c1 <= 0')", catalogName, tableIdent);
+    // In
+    sql("CALL %s.system.rewrite_data_files(table => '%s'," +
+        " where => 'c1 in (3,4,5)')", catalogName, tableIdent);
+    // IsNull
+    sql("CALL %s.system.rewrite_data_files(table => '%s'," +
+        " where => 'c1 is null')", catalogName, tableIdent);
+    // IsNotNull
+    sql("CALL %s.system.rewrite_data_files(table => '%s'," +
+        " where => 'c3 is not null')", catalogName, tableIdent);
+    // And
+    sql("CALL %s.system.rewrite_data_files(table => '%s'," +
+        " where => 'c1 = 3 and c2 = \"bar\"')", catalogName, tableIdent);
+    // Or
+    sql("CALL %s.system.rewrite_data_files(table => '%s'," +
+        " where => 'c1 = 3 or c1 = 5')", catalogName, tableIdent);
+    // Not
+    sql("CALL %s.system.rewrite_data_files(table => '%s'," +
+        " where => 'c1 not in (1,2)')", catalogName, tableIdent);
+    // StringStartsWith
+    sql("CALL %s.system.rewrite_data_files(table => '%s'," +
+        " where => 'c2 like \"%s\"')", catalogName, tableIdent, "car%");
+
+    // TODO: Enable when org.apache.iceberg.spark.SparkFilters have implementations for StringEndsWith & StringContains
+    // StringEndsWith
+    // sql("CALL %s.system.rewrite_data_files(table => '%s'," +
+    //     " where => 'c2 like \"%s\"')", catalogName, tableIdent, "%car");
+    // StringContains
+    // sql("CALL %s.system.rewrite_data_files(table => '%s'," +
+    //     " where => 'c2 like \"%s\"')", catalogName, tableIdent, "%car%");
+  }
+
   @Test
   public void testRewriteDataFilesWithInvalidInputs() {
     createTable();
@@ -182,41 +258,40 @@ public void testRewriteDataFilesWithInvalidInputs() {
 
     // Test for invalid strategy
     AssertHelpers.assertThrows("Should reject calls with unsupported strategy error message",
-        IllegalArgumentException.class, "unsupported strategy: temp. Only binpack,sort is supported",
-        () -> sql("CALL %s.system.rewrite_data_files(table => '%s', options => map('min-input-files','2'), " +
-            "strategy => 'temp')", catalogName, tableIdent));
+      IllegalArgumentException.class, "unsupported strategy: temp. Only binpack,sort is supported",
+      () -> sql("CALL %s.system.rewrite_data_files(table => '%s', options => map('min-input-files','2'), " +
+        "strategy => 'temp')", catalogName, tableIdent));
 
     // Test for sort_order with binpack strategy
     AssertHelpers.assertThrows("Should reject calls with error message",
-        IllegalArgumentException.class, "Cannot set strategy to sort, it has already been set",
-        () -> sql("CALL %s.system.rewrite_data_files(table => '%s', strategy => 'binpack', " +
-            "sort_order => 'c1 ASC NULLS FIRST')", catalogName, tableIdent));
+      IllegalArgumentException.class, "Cannot set strategy to sort, it has already been set",
+      () -> sql("CALL %s.system.rewrite_data_files(table => '%s', strategy => 'binpack', " +
+        "sort_order => 'c1 ASC NULLS FIRST')", catalogName, tableIdent));
 
     // Test for sort_order with invalid null order
     AssertHelpers.assertThrows("Should reject calls with error message",
-        IllegalArgumentException.class, "Unable to parse sortOrder:",
-        () -> sql("CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " +
-            "sort_order => 'c1 ASC none')", catalogName, tableIdent));
+      IllegalArgumentException.class, "Unable to parse sortOrder:",
+      () -> sql("CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " +
+        "sort_order => 'c1 ASC none')", catalogName, tableIdent));
 
     // Test for sort_order with invalid sort direction
     AssertHelpers.assertThrows("Should reject calls with error message",
-        IllegalArgumentException.class, "Unable to parse sortOrder:",
-        () -> sql("CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " +
-            "sort_order => 'c1 none NULLS FIRST')", catalogName, tableIdent));
+      IllegalArgumentException.class, "Unable to parse sortOrder:",
+      () -> sql("CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " +
+        "sort_order => 'c1 none NULLS FIRST')", catalogName, tableIdent));
 
     // Test for sort_order with invalid column name
     AssertHelpers.assertThrows("Should reject calls with error message",
-        ValidationException.class, "Cannot find field 'col1' in struct:" +
-            " struct<1: c1: optional int, 2: c2: optional string, 3: c3: optional string>",
-        () -> sql("CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " +
-            "sort_order => 'col1 DESC NULLS FIRST')", catalogName, tableIdent));
+      ValidationException.class, "Cannot find field 'col1' in struct:" +
+        " struct<1: c1: optional int, 2: c2: optional string, 3: c3: optional string>",
+      () -> sql("CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', " +
+        "sort_order => 'col1 DESC NULLS FIRST')", catalogName, tableIdent));
 
     // Test for sort_order with invalid filter column col1
     AssertHelpers.assertThrows("Should reject calls with error message",
-        ValidationException.class, "Cannot find field 'col1' in struct:" +
-            " struct<1: c1: optional int, 2: c2: optional string, 3: c3: optional string>",
-        () -> sql("CALL %s.system.rewrite_data_files(table => '%s', " +
-            "where => 'col1 = 3')", catalogName, tableIdent));
+      IllegalArgumentException.class, "Cannot parse predicates in where option: col1 = 3",
+      () -> sql("CALL %s.system.rewrite_data_files(table => '%s', " +
+        "where => 'col1 = 3')", catalogName, tableIdent));
   }
 
   @Test
@@ -251,8 +326,8 @@ private void createPartitionTable() {
   }
 
   private void insertData(int filesCount) {
-    ThreeColumnRecord record1 = new ThreeColumnRecord(1, "foo", "detail1");
-    ThreeColumnRecord record2 = new ThreeColumnRecord(2, "bar", "detail2");
+    ThreeColumnRecord record1 = new ThreeColumnRecord(1, "foo", null);
+    ThreeColumnRecord record2 = new ThreeColumnRecord(2, "bar", null);
     List<ThreeColumnRecord> records = Lists.newArrayList();
 
     IntStream.range(0, filesCount / 2).forEach(i -> {
diff --git a/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/procedures/RewriteDataFilesProcedure.java b/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/procedures/RewriteDataFilesProcedure.java
index 76550f93a33c..b33eab6b5b3c 100644
--- a/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/procedures/RewriteDataFilesProcedure.java
+++ b/spark/v3.2/spark/src/main/java/org/apache/iceberg/spark/procedures/RewriteDataFilesProcedure.java
@@ -28,8 +28,6 @@
 import org.apache.spark.sql.AnalysisException;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.Expression;
-import org.apache.spark.sql.catalyst.parser.ParseException;
-import org.apache.spark.sql.catalyst.parser.ParserInterface;
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
 import org.apache.spark.sql.catalyst.plans.logical.SetWriteDistributionAndOrdering;
 import org.apache.spark.sql.catalyst.plans.logical.SortOrderParserUtil;
@@ -109,7 +107,7 @@ public InternalRow[] call(InternalRow args) {
       }
 
       String where = args.isNullAt(4) ? null : args.getString(4);
-      action = checkAndApplyFilter(action, where);
+      action = checkAndApplyFilter(action, where, table.name());
 
       RewriteDataFiles.Result result = action.execute();
 
@@ -117,13 +115,12 @@
     });
   }
 
-  private RewriteDataFiles checkAndApplyFilter(RewriteDataFiles action, String where) {
+  private RewriteDataFiles checkAndApplyFilter(RewriteDataFiles action, String where, String tableName) {
     if (where != null) {
-      ParserInterface sqlParser = spark().sessionState().sqlParser();
       try {
-        Expression expression = sqlParser.parseExpression(where);
+        Expression expression = SparkExpressionConverter.collectResolvedSparkExpression(spark(), tableName, where);
         return action.filter(SparkExpressionConverter.convertToIcebergExpression(expression));
-      } catch (ParseException e) {
+      } catch (AnalysisException e) {
         throw new IllegalArgumentException("Cannot parse predicates in where option: " + where);
       }
     }
diff --git a/spark/v3.2/spark/src/main/scala/org/apache/spark/sql/execution/datasources/SparkExpressionConverter.scala b/spark/v3.2/spark/src/main/scala/org/apache/spark/sql/execution/datasources/SparkExpressionConverter.scala
index c41852713d1a..c4b5a7c0ce14 100644
--- a/spark/v3.2/spark/src/main/scala/org/apache/spark/sql/execution/datasources/SparkExpressionConverter.scala
+++ b/spark/v3.2/spark/src/main/scala/org/apache/spark/sql/execution/datasources/SparkExpressionConverter.scala
@@ -20,7 +20,11 @@
 package org.apache.spark.sql.execution.datasources
 
 import org.apache.iceberg.spark.SparkFilters
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.logical.Filter
+import org.apache.spark.sql.execution.CommandExecutionMode
 
 object SparkExpressionConverter {
 
@@ -30,4 +34,18 @@
     // But these two conversions already exist and well tested. So, we are going with this approach.
     SparkFilters.convert(DataSourceStrategy.translateFilter(sparkExpression, supportNestedPredicatePushdown = true).get)
   }
+
+  @throws[AnalysisException]
+  def collectResolvedSparkExpression(session: SparkSession, tableName: String, where: String): Expression = {
+    var expression: Expression = null
+    // Add a dummy prefix linking to the table to collect the resolved spark expression from optimized plan.
+    val prefix = String.format("SELECT 42 from %s where ", tableName)
+    val logicalPlan = session.sessionState.sqlParser.parsePlan(prefix + where)
+    val optimizedLogicalPlan = session.sessionState.executePlan(logicalPlan, CommandExecutionMode.ALL).optimizedPlan
+    optimizedLogicalPlan.collectFirst {
+      case filter: Filter =>
+        expression = filter.expressions.head
+    }
+    expression
+  }
 }
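
Usage sketch (not part of the patch; catalog and table names below are placeholders): because the where option is now
resolved against the table through Spark's optimized logical plan before being converted by SparkFilters, predicates
such as the IN filter exercised by the new tests can be passed directly to the procedure, for example:

  CALL my_catalog.system.rewrite_data_files(table => 'db.sample', where => 'c2 in ("bar")')
  CALL my_catalog.system.rewrite_data_files(table => 'db.sample', where => 'c1 = 3 and c2 = "bar"')
  CALL my_catalog.system.rewrite_data_files(table => 'db.sample', where => 'c2 like "car%"')

A filter that references a column missing from the table now fails in checkAndApplyFilter with
IllegalArgumentException ("Cannot parse predicates in where option: ...") instead of surfacing Iceberg's
ValidationException, which is what the updated invalid-input test asserts.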