diff --git a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala index 1f0e164d8467..4435587e925f 100644 --- a/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala +++ b/spark/v3.4/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala @@ -22,12 +22,16 @@ import org.apache.iceberg.spark.functions.SparkFunctions import org.apache.spark.sql.catalyst.expressions.ApplyFunctionExpression import org.apache.spark.sql.catalyst.expressions.BinaryComparison import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.In +import org.apache.spark.sql.catalyst.expressions.InSet import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.plans.logical.Filter import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.BINARY_COMPARISON import org.apache.spark.sql.catalyst.trees.TreePattern.FILTER +import org.apache.spark.sql.catalyst.trees.TreePattern.IN +import org.apache.spark.sql.catalyst.trees.TreePattern.INSET import org.apache.spark.sql.connector.catalog.functions.ScalarFunction import org.apache.spark.sql.types.StructField import org.apache.spark.sql.types.StructType @@ -40,14 +44,23 @@ import org.apache.spark.sql.types.StructType object ReplaceStaticInvoke extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = - plan.transformWithPruning (_.containsAllPatterns(BINARY_COMPARISON, FILTER)) { + plan.transformWithPruning (_.containsPattern(FILTER)) { case filter @ Filter(condition, _) => - val newCondition = condition.transformWithPruning(_.containsPattern(BINARY_COMPARISON)) { + val newCondition = condition.transformWithPruning(_.containsAnyPattern(BINARY_COMPARISON, IN, INSET)) { case c @ BinaryComparison(left: StaticInvoke, right) if canReplace(left) && right.foldable => c.withNewChildren(Seq(replaceStaticInvoke(left), right)) case c @ BinaryComparison(left, right: StaticInvoke) if canReplace(right) && left.foldable => c.withNewChildren(Seq(left, replaceStaticInvoke(right))) + + case in @ In(systemFunction: StaticInvoke, values) + if canReplace(systemFunction) && values.forall(_.foldable) => + in.copy(value = replaceStaticInvoke(systemFunction)) + + case in @ InSet(systemFunction: StaticInvoke, _) if canReplace(systemFunction) => + // InSet does not need the check on the values to be foldable + // because it contains only literals by definition + in.copy(child = replaceStaticInvoke(systemFunction)) } if (newCondition fastEquals condition) { diff --git a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownDQL.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownDQL.java index 7f2857cce0b9..d58917350945 100644 --- a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownDQL.java +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownDQL.java @@ -24,6 +24,7 @@ import static org.apache.iceberg.expressions.Expressions.greaterThan; import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual; import static org.apache.iceberg.expressions.Expressions.hour; +import static org.apache.iceberg.expressions.Expressions.in; import static org.apache.iceberg.expressions.Expressions.lessThan; import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual; import static org.apache.iceberg.expressions.Expressions.month; @@ -40,6 +41,8 @@ import java.util.List; import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import org.apache.iceberg.expressions.ExpressionUtil; import org.apache.iceberg.spark.SparkCatalogConfig; import org.apache.iceberg.spark.source.PlanUtils; @@ -213,17 +216,74 @@ private void testBucketLongFunction(boolean partitioned) { String query = String.format( "SELECT * FROM %s WHERE system.bucket(5, id) <= %s ORDER BY id", tableName, target); + checkQueryExecution(query, partitioned, lessThanOrEqual(bucket("id", 5), target)); + } + + @Test + public void testBucketLongFunctionInClauseOnUnpartitionedTable() { + createUnpartitionedTable(spark, tableName); + testBucketLongFunctionInClause(false); + } + + @Test + public void testBucketLongFunctionInClauseOnPartitionedTable() { + createPartitionedTable(spark, tableName, "bucket(5, id)"); + testBucketLongFunctionInClause(true); + } + + private void testBucketLongFunctionInClause(boolean partitioned) { + List inValues = IntStream.range(0, 3).boxed().collect(Collectors.toList()); + String inValuesAsSql = + inValues.stream().map(x -> Integer.toString(x)).collect(Collectors.joining(", ")); + String query = + String.format( + "SELECT * FROM %s WHERE system.bucket(5, id) IN (%s) ORDER BY id", + tableName, inValuesAsSql); + checkQueryExecution(query, partitioned, in(bucket("id", 5), inValues.toArray())); + } + + private void checkQueryExecution( + String query, boolean partitioned, org.apache.iceberg.expressions.Expression expression) { Dataset df = spark.sql(query); LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan(); checkExpressions(optimizedPlan, partitioned, "bucket"); - checkPushedFilters(optimizedPlan, lessThanOrEqual(bucket("id", 5), target)); + checkPushedFilters(optimizedPlan, expression); List actual = rowsToJava(df.collectAsList()); Assertions.assertThat(actual.size()).isEqualTo(5); } + @Test + public void testBucketLongFunctionIsNotReplacedWhenArgumentsAreNotLiteralsOnPartitionedTable() { + createPartitionedTable(spark, tableName, "bucket(5, id)"); + testBucketLongFunctionIsNotReplacedWhenArgumentsAreNotLiterals(); + } + + @Test + public void testBucketLongFunctionIsNotReplacedWhenArgumentsAreNotLiteralsOnUnpartitionedTable() { + createUnpartitionedTable(spark, tableName); + testBucketLongFunctionIsNotReplacedWhenArgumentsAreNotLiterals(); + } + + private void testBucketLongFunctionIsNotReplacedWhenArgumentsAreNotLiterals() { + String query = + String.format( + "SELECT * FROM %s WHERE system.bucket(5, id) IN (system.bucket(5, id), 1) ORDER BY id", + tableName); + + Dataset df = spark.sql(query); + LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan(); + + checkExpressionsNotReplaced( + optimizedPlan, "org.apache.iceberg.spark.functions.BucketFunction$BucketLong", 2); + checkNotPushedFilters(optimizedPlan); + + List actual = rowsToJava(df.collectAsList()); + Assertions.assertThat(actual.size()).isEqualTo(10); + } + @Test public void testBucketStringFunctionOnUnpartitionedTable() { createUnpartitionedTable(spark, tableName); @@ -311,4 +371,25 @@ private void checkPushedFilters( .as("Pushed filter should match") .isTrue(); } + + private void checkExpressionsNotReplaced( + LogicalPlan optimizedPlan, String expectedFunctionName, int expectedFunctionCount) { + List staticInvokes = + PlanUtils.collectSparkExpressions( + optimizedPlan, expression -> expression instanceof StaticInvoke) + .stream() + .map(x -> (StaticInvoke) x) + .collect(Collectors.toList()); + + Assertions.assertThat(staticInvokes.size()).isEqualTo(expectedFunctionCount); + Assertions.assertThat(staticInvokes) + .allSatisfy( + e -> Assertions.assertThat(e.staticObject().getName()).isEqualTo(expectedFunctionName)); + } + + private void checkNotPushedFilters(LogicalPlan optimizedPlan) { + List pushedFilters = + PlanUtils.collectPushDownFilters(optimizedPlan); + Assertions.assertThat(pushedFilters.size()).isEqualTo(0); + } } diff --git a/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownDQL.java b/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownDQL.java index f6102bab69b0..a22109d64674 100644 --- a/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownDQL.java +++ b/spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownDQL.java @@ -24,6 +24,7 @@ import static org.apache.iceberg.expressions.Expressions.greaterThan; import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual; import static org.apache.iceberg.expressions.Expressions.hour; +import static org.apache.iceberg.expressions.Expressions.in; import static org.apache.iceberg.expressions.Expressions.lessThan; import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual; import static org.apache.iceberg.expressions.Expressions.month; @@ -40,6 +41,8 @@ import static org.assertj.core.api.Assertions.assertThat; import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import org.apache.iceberg.ParameterizedTestExtension; import org.apache.iceberg.Parameters; import org.apache.iceberg.expressions.ExpressionUtil; @@ -212,17 +215,74 @@ private void testBucketLongFunction(boolean partitioned) { String query = String.format( "SELECT * FROM %s WHERE system.bucket(5, id) <= %s ORDER BY id", tableName, target); + checkQueryExecution(query, partitioned, lessThanOrEqual(bucket("id", 5), target)); + } + + @TestTemplate + public void testBucketLongFunctionInClauseOnUnpartitionedTable() { + createUnpartitionedTable(spark, tableName); + testBucketLongFunctionInClause(false); + } + + @TestTemplate + public void testBucketLongFunctionInClauseOnPartitionedTable() { + createPartitionedTable(spark, tableName, "bucket(5, id)"); + testBucketLongFunctionInClause(true); + } + + private void testBucketLongFunctionInClause(boolean partitioned) { + List inValues = IntStream.range(0, 3).boxed().collect(Collectors.toList()); + String inValuesAsSql = + inValues.stream().map(x -> Integer.toString(x)).collect(Collectors.joining(", ")); + String query = + String.format( + "SELECT * FROM %s WHERE system.bucket(5, id) IN (%s) ORDER BY id", + tableName, inValuesAsSql); + checkQueryExecution(query, partitioned, in(bucket("id", 5), inValues.toArray())); + } + + private void checkQueryExecution( + String query, boolean partitioned, org.apache.iceberg.expressions.Expression expression) { Dataset df = spark.sql(query); LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan(); checkExpressions(optimizedPlan, partitioned, "bucket"); - checkPushedFilters(optimizedPlan, lessThanOrEqual(bucket("id", 5), target)); + checkPushedFilters(optimizedPlan, expression); List actual = rowsToJava(df.collectAsList()); assertThat(actual).hasSize(5); } + @TestTemplate + public void testBucketLongFunctionIsNotReplacedWhenArgumentsAreNotLiteralsOnPartitionedTable() { + createPartitionedTable(spark, tableName, "bucket(5, id)"); + testBucketLongFunctionIsNotReplacedWhenArgumentsAreNotLiterals(); + } + + @TestTemplate + public void testBucketLongFunctionIsNotReplacedWhenArgumentsAreNotLiteralsOnUnpartitionedTable() { + createUnpartitionedTable(spark, tableName); + testBucketLongFunctionIsNotReplacedWhenArgumentsAreNotLiterals(); + } + + private void testBucketLongFunctionIsNotReplacedWhenArgumentsAreNotLiterals() { + String query = + String.format( + "SELECT * FROM %s WHERE system.bucket(5, id) IN (system.bucket(5, id), 1) ORDER BY id", + tableName); + + Dataset df = spark.sql(query); + LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan(); + + checkExpressionsNotReplaced( + optimizedPlan, "org.apache.iceberg.spark.functions.BucketFunction$BucketLong", 2); + checkNotPushedFilters(optimizedPlan); + + List actual = rowsToJava(df.collectAsList()); + assertThat(actual.size()).isEqualTo(10); + } + @TestTemplate public void testBucketStringFunctionOnUnpartitionedTable() { createUnpartitionedTable(spark, tableName); @@ -310,4 +370,24 @@ private void checkPushedFilters( .as("Pushed filter should match") .isTrue(); } + + private void checkExpressionsNotReplaced( + LogicalPlan optimizedPlan, String expectedFunctionName, int expectedFunctionCount) { + List staticInvokes = + PlanUtils.collectSparkExpressions( + optimizedPlan, expression -> expression instanceof StaticInvoke) + .stream() + .map(x -> (StaticInvoke) x) + .collect(Collectors.toList()); + + assertThat(staticInvokes.size()).isEqualTo(expectedFunctionCount); + assertThat(staticInvokes) + .allSatisfy(e -> assertThat(e.staticObject().getName()).isEqualTo(expectedFunctionName)); + } + + private void checkNotPushedFilters(LogicalPlan optimizedPlan) { + List pushedFilters = + PlanUtils.collectPushDownFilters(optimizedPlan); + assertThat(pushedFilters.size()).isEqualTo(0); + } }