Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,16 @@ import org.apache.iceberg.spark.functions.SparkFunctions
import org.apache.spark.sql.catalyst.expressions.ApplyFunctionExpression
import org.apache.spark.sql.catalyst.expressions.BinaryComparison
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.In
import org.apache.spark.sql.catalyst.expressions.InSet
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.plans.logical.Filter
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.BINARY_COMPARISON
import org.apache.spark.sql.catalyst.trees.TreePattern.FILTER
import org.apache.spark.sql.catalyst.trees.TreePattern.IN
import org.apache.spark.sql.catalyst.trees.TreePattern.INSET
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
Expand All @@ -40,14 +44,23 @@ import org.apache.spark.sql.types.StructType
object ReplaceStaticInvoke extends Rule[LogicalPlan] {

override def apply(plan: LogicalPlan): LogicalPlan =
plan.transformWithPruning (_.containsAllPatterns(BINARY_COMPARISON, FILTER)) {
plan.transformWithPruning (_.containsPattern(FILTER)) {
case filter @ Filter(condition, _) =>
val newCondition = condition.transformWithPruning(_.containsPattern(BINARY_COMPARISON)) {
val newCondition = condition.transformWithPruning(_.containsAnyPattern(BINARY_COMPARISON, IN, INSET)) {
case c @ BinaryComparison(left: StaticInvoke, right) if canReplace(left) && right.foldable =>
c.withNewChildren(Seq(replaceStaticInvoke(left), right))

case c @ BinaryComparison(left, right: StaticInvoke) if canReplace(right) && left.foldable =>
c.withNewChildren(Seq(left, replaceStaticInvoke(right)))

case in @ In(systemFunction: StaticInvoke, values)
if canReplace(systemFunction) && values.forall(_.foldable) =>
in.copy(value = replaceStaticInvoke(systemFunction))

case in @ InSet(systemFunction: StaticInvoke, _) if canReplace(systemFunction) =>
// InSet does not need a foldability check on its values because,
// by definition, it contains only literals
in.copy(child = replaceStaticInvoke(systemFunction))
}

if (newCondition fastEquals condition) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import static org.apache.iceberg.expressions.Expressions.greaterThan;
import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual;
import static org.apache.iceberg.expressions.Expressions.hour;
import static org.apache.iceberg.expressions.Expressions.in;
import static org.apache.iceberg.expressions.Expressions.lessThan;
import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual;
import static org.apache.iceberg.expressions.Expressions.month;
Expand All @@ -40,6 +41,8 @@

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.iceberg.expressions.ExpressionUtil;
import org.apache.iceberg.spark.SparkCatalogConfig;
import org.apache.iceberg.spark.source.PlanUtils;
Expand Down Expand Up @@ -213,17 +216,74 @@ private void testBucketLongFunction(boolean partitioned) {
String query =
String.format(
"SELECT * FROM %s WHERE system.bucket(5, id) <= %s ORDER BY id", tableName, target);
checkQueryExecution(query, partitioned, lessThanOrEqual(bucket("id", 5), target));
}

@Test
public void testBucketLongFunctionInClauseOnUnpartitionedTable() {
// Exercise the IN-clause rewrite of system.bucket against a table with no partition spec.
createUnpartitionedTable(spark, tableName);
testBucketLongFunctionInClause(false);
}

@Test
public void testBucketLongFunctionInClauseOnPartitionedTable() {
// Exercise the IN-clause rewrite of system.bucket against a table partitioned by bucket(5, id).
createPartitionedTable(spark, tableName, "bucket(5, id)");
testBucketLongFunctionInClause(true);
}

/**
 * Runs a query whose WHERE clause is {@code system.bucket(5, id) IN (0, 1, 2)} and delegates to
 * checkQueryExecution with the equivalent Iceberg {@code in(bucket("id", 5), ...)} expression.
 */
private void testBucketLongFunctionInClause(boolean partitioned) {
  List<Integer> bucketValues = IntStream.range(0, 3).boxed().collect(Collectors.toList());
  String sqlInList =
      bucketValues.stream().map(String::valueOf).collect(Collectors.joining(", "));
  String sql =
      String.format(
          "SELECT * FROM %s WHERE system.bucket(5, id) IN (%s) ORDER BY id",
          tableName, sqlInList);

  checkQueryExecution(sql, partitioned, in(bucket("id", 5), bucketValues.toArray()));
}

/**
 * Runs the query and verifies the whole pipeline: the bucket function call is rewritten in the
 * optimized plan, the caller-supplied Iceberg expression is pushed down, and the expected number
 * of rows comes back.
 */
private void checkQueryExecution(
    String query, boolean partitioned, org.apache.iceberg.expressions.Expression expression) {
  Dataset<Row> df = spark.sql(query);
  LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan();

  checkExpressions(optimizedPlan, partitioned, "bucket");
  // Only the caller-supplied expression is checked; the previous hard-coded
  // lessThanOrEqual(bucket("id", 5), target) check referenced `target`, which is not in
  // scope in this method, and has been removed.
  checkPushedFilters(optimizedPlan, expression);

  List<Object[]> actual = rowsToJava(df.collectAsList());
  // NOTE(review): assumes every caller's query matches exactly 5 rows — confirm for new callers.
  // hasSize reports the list contents on failure, unlike size()/isEqualTo.
  Assertions.assertThat(actual).hasSize(5);
}

@Test
public void testBucketLongFunctionIsNotReplacedWhenArgumentsAreNotLiteralsOnPartitionedTable() {
// Non-literal IN values must block the rewrite even when the table is bucket-partitioned.
createPartitionedTable(spark, tableName, "bucket(5, id)");
testBucketLongFunctionIsNotReplacedWhenArgumentsAreNotLiterals();
}

@Test
public void testBucketLongFunctionIsNotReplacedWhenArgumentsAreNotLiteralsOnUnpartitionedTable() {
// Non-literal IN values must block the rewrite on an unpartitioned table as well.
createUnpartitionedTable(spark, tableName);
testBucketLongFunctionIsNotReplacedWhenArgumentsAreNotLiterals();
}

/**
 * Verifies that when a system.bucket call appears among the IN-list values (i.e. the values are
 * not all literals), the optimizer rule leaves both StaticInvoke occurrences untouched and pushes
 * no filters down.
 */
private void testBucketLongFunctionIsNotReplacedWhenArgumentsAreNotLiterals() {
  String query =
      String.format(
          "SELECT * FROM %s WHERE system.bucket(5, id) IN (system.bucket(5, id), 1) ORDER BY id",
          tableName);

  Dataset<Row> df = spark.sql(query);
  LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan();

  // Both StaticInvoke occurrences of BucketLong must survive optimization un-replaced.
  checkExpressionsNotReplaced(
      optimizedPlan, "org.apache.iceberg.spark.functions.BucketFunction$BucketLong", 2);
  checkNotPushedFilters(optimizedPlan);

  List<Object[]> actual = rowsToJava(df.collectAsList());
  // x IN (x, 1) matches every row with a non-null id, so all 10 rows are expected back.
  Assertions.assertThat(actual).hasSize(10);
}

@Test
public void testBucketStringFunctionOnUnpartitionedTable() {
createUnpartitionedTable(spark, tableName);
Expand Down Expand Up @@ -311,4 +371,25 @@ private void checkPushedFilters(
.as("Pushed filter should match")
.isTrue();
}

/**
 * Asserts that exactly {@code expectedFunctionCount} StaticInvoke expressions remain in the
 * optimized plan and that each of them still targets {@code expectedFunctionName}, i.e. the
 * rewrite rule did not replace them.
 */
private void checkExpressionsNotReplaced(
    LogicalPlan optimizedPlan, String expectedFunctionName, int expectedFunctionCount) {
  List<StaticInvoke> staticInvokes =
      PlanUtils.collectSparkExpressions(
              optimizedPlan, expression -> expression instanceof StaticInvoke)
          .stream()
          .map(StaticInvoke.class::cast)
          .collect(Collectors.toList());

  // Chain hasSize + allSatisfy: one assertion subject, and hasSize prints the list on failure.
  Assertions.assertThat(staticInvokes)
      .hasSize(expectedFunctionCount)
      .allSatisfy(
          e -> Assertions.assertThat(e.staticObject().getName()).isEqualTo(expectedFunctionName));
}

/** Asserts that no Iceberg filters were pushed down to the scan in the optimized plan. */
private void checkNotPushedFilters(LogicalPlan optimizedPlan) {
  List<org.apache.iceberg.expressions.Expression> pushedFilters =
      PlanUtils.collectPushDownFilters(optimizedPlan);
  // isEmpty() reports the offending filters on failure, unlike size()/isEqualTo(0).
  Assertions.assertThat(pushedFilters).isEmpty();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import static org.apache.iceberg.expressions.Expressions.greaterThan;
import static org.apache.iceberg.expressions.Expressions.greaterThanOrEqual;
import static org.apache.iceberg.expressions.Expressions.hour;
import static org.apache.iceberg.expressions.Expressions.in;
import static org.apache.iceberg.expressions.Expressions.lessThan;
import static org.apache.iceberg.expressions.Expressions.lessThanOrEqual;
import static org.apache.iceberg.expressions.Expressions.month;
Expand All @@ -40,6 +41,8 @@
import static org.assertj.core.api.Assertions.assertThat;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.iceberg.ParameterizedTestExtension;
import org.apache.iceberg.Parameters;
import org.apache.iceberg.expressions.ExpressionUtil;
Expand Down Expand Up @@ -212,17 +215,74 @@ private void testBucketLongFunction(boolean partitioned) {
String query =
String.format(
"SELECT * FROM %s WHERE system.bucket(5, id) <= %s ORDER BY id", tableName, target);
checkQueryExecution(query, partitioned, lessThanOrEqual(bucket("id", 5), target));
}

@TestTemplate
public void testBucketLongFunctionInClauseOnUnpartitionedTable() {
// Exercise the IN-clause rewrite of system.bucket against a table with no partition spec.
createUnpartitionedTable(spark, tableName);
testBucketLongFunctionInClause(false);
}

@TestTemplate
public void testBucketLongFunctionInClauseOnPartitionedTable() {
// Exercise the IN-clause rewrite of system.bucket against a table partitioned by bucket(5, id).
createPartitionedTable(spark, tableName, "bucket(5, id)");
testBucketLongFunctionInClause(true);
}

/**
 * Runs a query whose WHERE clause is {@code system.bucket(5, id) IN (0, 1, 2)} and delegates to
 * checkQueryExecution with the equivalent Iceberg {@code in(bucket("id", 5), ...)} expression.
 */
private void testBucketLongFunctionInClause(boolean partitioned) {
  List<Integer> bucketValues = IntStream.range(0, 3).boxed().collect(Collectors.toList());
  String sqlInList =
      bucketValues.stream().map(String::valueOf).collect(Collectors.joining(", "));
  String sql =
      String.format(
          "SELECT * FROM %s WHERE system.bucket(5, id) IN (%s) ORDER BY id",
          tableName, sqlInList);

  checkQueryExecution(sql, partitioned, in(bucket("id", 5), bucketValues.toArray()));
}

/**
 * Runs the query and verifies the whole pipeline: the bucket function call is rewritten in the
 * optimized plan, the caller-supplied Iceberg expression is pushed down, and the expected number
 * of rows comes back.
 */
private void checkQueryExecution(
    String query, boolean partitioned, org.apache.iceberg.expressions.Expression expression) {
  Dataset<Row> df = spark.sql(query);
  LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan();

  checkExpressions(optimizedPlan, partitioned, "bucket");
  // Only the caller-supplied expression is checked; the previous hard-coded
  // lessThanOrEqual(bucket("id", 5), target) check referenced `target`, which is not in
  // scope in this method, and has been removed.
  checkPushedFilters(optimizedPlan, expression);

  List<Object[]> actual = rowsToJava(df.collectAsList());
  // NOTE(review): assumes every caller's query matches exactly 5 rows — confirm for new callers.
  assertThat(actual).hasSize(5);
}

@TestTemplate
public void testBucketLongFunctionIsNotReplacedWhenArgumentsAreNotLiteralsOnPartitionedTable() {
// Non-literal IN values must block the rewrite even when the table is bucket-partitioned.
createPartitionedTable(spark, tableName, "bucket(5, id)");
testBucketLongFunctionIsNotReplacedWhenArgumentsAreNotLiterals();
}

@TestTemplate
public void testBucketLongFunctionIsNotReplacedWhenArgumentsAreNotLiteralsOnUnpartitionedTable() {
// Non-literal IN values must block the rewrite on an unpartitioned table as well.
createUnpartitionedTable(spark, tableName);
testBucketLongFunctionIsNotReplacedWhenArgumentsAreNotLiterals();
}

/**
 * Verifies that when a system.bucket call appears among the IN-list values (i.e. the values are
 * not all literals), the optimizer rule leaves both StaticInvoke occurrences untouched and pushes
 * no filters down.
 */
private void testBucketLongFunctionIsNotReplacedWhenArgumentsAreNotLiterals() {
  String query =
      String.format(
          "SELECT * FROM %s WHERE system.bucket(5, id) IN (system.bucket(5, id), 1) ORDER BY id",
          tableName);

  Dataset<Row> df = spark.sql(query);
  LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan();

  // Both StaticInvoke occurrences of BucketLong must survive optimization un-replaced.
  checkExpressionsNotReplaced(
      optimizedPlan, "org.apache.iceberg.spark.functions.BucketFunction$BucketLong", 2);
  checkNotPushedFilters(optimizedPlan);

  List<Object[]> actual = rowsToJava(df.collectAsList());
  // x IN (x, 1) matches every row with a non-null id, so all 10 rows are expected back.
  // hasSize(10) matches the hasSize style used elsewhere in this class and prints the
  // list contents on failure, unlike size()/isEqualTo.
  assertThat(actual).hasSize(10);
}

@TestTemplate
public void testBucketStringFunctionOnUnpartitionedTable() {
createUnpartitionedTable(spark, tableName);
Expand Down Expand Up @@ -310,4 +370,24 @@ private void checkPushedFilters(
.as("Pushed filter should match")
.isTrue();
}

/**
 * Asserts that exactly {@code expectedFunctionCount} StaticInvoke expressions remain in the
 * optimized plan and that each of them still targets {@code expectedFunctionName}, i.e. the
 * rewrite rule did not replace them.
 */
private void checkExpressionsNotReplaced(
    LogicalPlan optimizedPlan, String expectedFunctionName, int expectedFunctionCount) {
  List<StaticInvoke> staticInvokes =
      PlanUtils.collectSparkExpressions(
              optimizedPlan, expression -> expression instanceof StaticInvoke)
          .stream()
          .map(StaticInvoke.class::cast)
          .collect(Collectors.toList());

  // Chain hasSize + allSatisfy: one assertion subject, and hasSize prints the list on failure.
  assertThat(staticInvokes)
      .hasSize(expectedFunctionCount)
      .allSatisfy(e -> assertThat(e.staticObject().getName()).isEqualTo(expectedFunctionName));
}

/** Asserts that no Iceberg filters were pushed down to the scan in the optimized plan. */
private void checkNotPushedFilters(LogicalPlan optimizedPlan) {
  List<org.apache.iceberg.expressions.Expression> pushedFilters =
      PlanUtils.collectPushDownFilters(optimizedPlan);
  // isEmpty() reports the offending filters on failure, unlike size()/isEqualTo(0).
  assertThat(pushedFilters).isEmpty();
}
}