Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,20 @@ import org.apache.iceberg.spark.functions.SparkFunctions
import org.apache.spark.sql.catalyst.expressions.ApplyFunctionExpression
import org.apache.spark.sql.catalyst.expressions.BinaryComparison
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.In
import org.apache.spark.sql.catalyst.expressions.InSet
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.plans.logical.Filter
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.ReplaceData
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.BINARY_COMPARISON
import org.apache.spark.sql.catalyst.trees.TreePattern.COMMAND
import org.apache.spark.sql.catalyst.trees.TreePattern.FILTER
import org.apache.spark.sql.catalyst.trees.TreePattern.IN
import org.apache.spark.sql.catalyst.trees.TreePattern.INSET
import org.apache.spark.sql.catalyst.trees.TreePattern.JOIN
import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
Expand All @@ -40,21 +48,39 @@ import org.apache.spark.sql.types.StructType
object ReplaceStaticInvoke extends Rule[LogicalPlan] {

/**
 * Rewrites eligible [[StaticInvoke]] expressions in filtering conditions so that
 * Iceberg can push them down as V2 function predicates.
 *
 * Matches not only [[Filter]] but also [[ReplaceData]] (row-level command) and
 * [[Join]] conditions: GroupBasedRowLevelOperationScanPlanning in Spark must be
 * able to simplify the join condition by discarding filters evaluated on the
 * Iceberg side.
 *
 * NOTE(review): the pruning uses containsAnyPattern (not containsAllPatterns as
 * before) because the three cases match disjoint node patterns; this is not
 * expected to be a performance problem.
 */
override def apply(plan: LogicalPlan): LogicalPlan =
  plan.transformWithPruning(_.containsAnyPattern(COMMAND, FILTER, JOIN)) {
    case replace @ ReplaceData(_, cond, _, _, _, _) =>
      replaceStaticInvoke(replace, cond, newCond => replace.copy(condition = newCond))

    case join @ Join(_, _, _, Some(cond), _) =>
      replaceStaticInvoke(join, cond, newCond => join.copy(condition = Some(newCond)))

    case filter @ Filter(cond, _) =>
      replaceStaticInvoke(filter, cond, newCond => filter.copy(condition = newCond))
  }

/**
 * Rewrites `StaticInvoke` calls inside `condition` and, only when the condition
 * actually changed, produces an updated copy of `node` via the supplied `copy`
 * function; otherwise returns `node` unchanged to avoid a needless re-allocation.
 */
private def replaceStaticInvoke[T <: LogicalPlan](
    node: T,
    condition: Expression,
    copy: Expression => T): T = {
  val rewritten = replaceStaticInvoke(condition)
  if (rewritten.fastEquals(condition)) {
    node
  } else {
    copy(rewritten)
  }
}

/**
 * Rewrites `StaticInvoke` children of predicate expressions into a pushdown-capable
 * form (see the single-argument `replaceStaticInvoke(invoke)` overload).
 *
 * Prunes on BINARY_COMPARISON as well as IN and INSET — covering only binary
 * comparisons would leave IN expressions untouched, and they would not be
 * pushed down.
 */
private def replaceStaticInvoke(condition: Expression): Expression = {
  condition.transformWithPruning(_.containsAnyPattern(BINARY_COMPARISON, IN, INSET)) {
    case in @ In(value: StaticInvoke, _) if canReplace(value) =>
      in.copy(value = replaceStaticInvoke(value))

    case in @ InSet(value: StaticInvoke, _) if canReplace(value) =>
      in.copy(child = replaceStaticInvoke(value))

    // Only rewrite a comparison when the other side is foldable: a constant
    // operand is what allows the predicate to be pushed down.
    case c @ BinaryComparison(left: StaticInvoke, right) if canReplace(left) && right.foldable =>
      c.withNewChildren(Seq(replaceStaticInvoke(left), right))

    case c @ BinaryComparison(left, right: StaticInvoke) if canReplace(right) && left.foldable =>
      c.withNewChildren(Seq(left, replaceStaticInvoke(right)))
  }
}

private def replaceStaticInvoke(invoke: StaticInvoke): Expression = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,17 @@

import static scala.collection.JavaConverters.seqAsJavaListConverter;

import java.util.Collection;
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.execution.CommandResultExec;
import org.apache.spark.sql.execution.SparkPlan;
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper;
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec;
import scala.PartialFunction;
import scala.collection.Seq;

public class SparkPlanUtil {
Expand Down Expand Up @@ -53,6 +58,49 @@ private static SparkPlan actualPlan(SparkPlan plan) {
}
}

public static List<Expression> collectExprs(
SparkPlan sparkPlan, Predicate<Expression> predicate) {
Seq<List<Expression>> seq =
SPARK_HELPER.collect(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to logic in PlanUtils but it accounts for AQE plans by relying on SPARK_HELPER.
This class also existed before PlanUtils, so we might want to converge in the future.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If these are only used in tests, should we add them to test utils?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is already a test util, just the one that existed before PlanUtils. I would probably merge PlanUtils with this class in the future because this one handles AQE plans.

sparkPlan,
new PartialFunction<SparkPlan, List<Expression>>() {
@Override
public List<Expression> apply(SparkPlan plan) {
List<Expression> exprs = Lists.newArrayList();

for (Expression expr : toJavaList(plan.expressions())) {
exprs.addAll(collectExprs(expr, predicate));
}

return exprs;
}

@Override
public boolean isDefinedAt(SparkPlan plan) {
return true;
}
});
return toJavaList(seq).stream().flatMap(Collection::stream).collect(Collectors.toList());
}

/**
 * Collects every sub-expression of {@code expression} (including itself) that
 * satisfies {@code predicate}, using Catalyst's {@code collect} traversal.
 */
private static List<Expression> collectExprs(
    Expression expression, Predicate<Expression> predicate) {
  PartialFunction<Expression, Expression> matcher =
      new PartialFunction<Expression, Expression>() {
        @Override
        public boolean isDefinedAt(Expression expr) {
          // The predicate drives which nodes the traversal keeps.
          return predicate.test(expr);
        }

        @Override
        public Expression apply(Expression expr) {
          return expr;
        }
      };
  return toJavaList(expression.collect(matcher));
}

// Bridges a Scala Seq into a java.util.List via the Scala/Java converters.
private static <T> List<T> toJavaList(Seq<T> seq) {
return seqAsJavaListConverter(seq).asJava();
}
Expand Down
Loading