diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysFalse.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysFalse.java
index 72ed83f86df6d..fadecfd085ba3 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysFalse.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysFalse.java
@@ -20,6 +20,7 @@
 import java.util.Objects;
 
 import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.Expression;
 import org.apache.spark.sql.connector.expressions.NamedReference;
 
 /**
@@ -46,5 +47,5 @@ public int hashCode() {
   public String toString() { return "FALSE"; }
 
   @Override
-  public NamedReference[] references() { return EMPTY_REFERENCE; }
+  public Expression[] references() { return EMPTY_EXPRESSION; }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysTrue.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysTrue.java
index b6d39c3f64a77..38e8660a9bc3b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysTrue.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysTrue.java
@@ -20,6 +20,7 @@
 import java.util.Objects;
 
 import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.Expression;
 import org.apache.spark.sql.connector.expressions.NamedReference;
 
 /**
@@ -46,5 +47,5 @@ public int hashCode() {
   public String toString() { return "TRUE"; }
 
   @Override
-  public NamedReference[] references() { return EMPTY_REFERENCE; }
+  public Expression[] references() { return EMPTY_EXPRESSION; }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/BinaryComparison.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/BinaryComparison.java
index 0ae6e5af3ca1a..23af11b2cc29f 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/BinaryComparison.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/BinaryComparison.java
@@ -20,8 +20,7 @@
 import java.util.Objects;
 
 import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.expressions.Literal;
-import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Expression;
 
 /**
  * Base class for {@link EqualNullSafe}, {@link EqualTo}, {@link GreaterThan},
@@ -31,16 +30,16 @@
  */
 @Evolving
 abstract class BinaryComparison extends Filter {
-  protected final NamedReference column;
-  protected final Literal value;
+  protected final Expression column;
+  protected final Expression value;
 
-  protected BinaryComparison(NamedReference column, Literal value) {
+  protected BinaryComparison(Expression column, Expression value) {
     this.column = column;
     this.value = value;
   }
 
-  public NamedReference column() { return column; }
-  public Literal value() { return value; }
+  public Expression column() { return column; }
+  public Expression value() { return value; }
 
   @Override
   public boolean equals(Object o) {
@@ -56,5 +55,5 @@ public int hashCode() {
   }
 
   @Override
-  public NamedReference[] references() { return new NamedReference[] { column }; }
+  public Expression[] references() { return new Expression[] { column }; }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualNullSafe.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualNullSafe.java
index 34b529194e075..3f700f04efdaf 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualNullSafe.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualNullSafe.java
@@ -18,8 +18,7 @@
 package org.apache.spark.sql.connector.expressions.filter;
 
 import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.expressions.Literal;
-import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Expression;
 
 /**
  * Performs equality comparison, similar to {@link EqualTo}. However, this differs from
@@ -31,7 +30,7 @@
 @Evolving
 public final class EqualNullSafe extends BinaryComparison {
 
-  public EqualNullSafe(NamedReference column, Literal value) {
+  public EqualNullSafe(Expression column, Expression value) {
     super(column, value);
   }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualTo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualTo.java
index b9c4fe053b83c..39ec9d32735fe 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualTo.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualTo.java
@@ -18,8 +18,7 @@
 package org.apache.spark.sql.connector.expressions.filter;
 
 import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.expressions.Literal;
-import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Expression;
 
 /**
  * A filter that evaluates to {@code true} iff the {@code column} evaluates to a value
@@ -30,7 +29,7 @@
 @Evolving
 public final class EqualTo extends BinaryComparison {
 
-  public EqualTo(NamedReference column, Literal value) {
+  public EqualTo(Expression column, Expression value) {
     super(column, value);
   }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java
index af87e76d2ff7d..c512e42d1cb09 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java
@@ -21,7 +21,6 @@
 
 import org.apache.spark.annotation.Evolving;
 import org.apache.spark.sql.connector.expressions.Expression;
-import org.apache.spark.sql.connector.expressions.NamedReference;
 
 /**
  * Filter base class
@@ -31,10 +30,10 @@
 @Evolving
 public abstract class Filter implements Expression, Serializable {
 
-  protected static final NamedReference[] EMPTY_REFERENCE = new NamedReference[0];
+  protected static final Expression[] EMPTY_EXPRESSION = new Expression[0];
 
   /**
    * Returns list of columns that are referenced by this filter.
    */
-  public abstract NamedReference[] references();
+  public abstract Expression[] references();
 }
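The `references()` change above is source-incompatible for connectors that consumed `NamedReference[]` directly. A minimal sketch of what a caller now has to do; the `referencedColumns` helper is hypothetical and not part of this patch:

```scala
import org.apache.spark.sql.connector.expressions.NamedReference
import org.apache.spark.sql.connector.expressions.filter.Filter

// Hypothetical connector-side helper: references() is now typed as
// Expression[], so column names must be recovered by narrowing each
// element to NamedReference explicitly.
def referencedColumns(filter: Filter): Seq[String] =
  filter.references().collect {
    case ref: NamedReference => ref.fieldNames().mkString(".")
  }.toSeq
```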
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThan.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThan.java
index a3374f359ea29..0aebc61038097 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThan.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThan.java
@@ -18,8 +18,7 @@
 package org.apache.spark.sql.connector.expressions.filter;
 
 import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.expressions.Literal;
-import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Expression;
 
 /**
  * A filter that evaluates to {@code true} iff the {@code column} evaluates to a value
@@ -30,7 +29,7 @@
 @Evolving
 public final class GreaterThan extends BinaryComparison {
 
-  public GreaterThan(NamedReference column, Literal value) {
+  public GreaterThan(Expression column, Expression value) {
     super(column, value);
   }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThanOrEqual.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThanOrEqual.java
index 4ee921014da41..5d938c7da9a48 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThanOrEqual.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThanOrEqual.java
@@ -18,8 +18,7 @@
 package org.apache.spark.sql.connector.expressions.filter;
 
 import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.expressions.Literal;
-import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Expression;
 
 /**
  * A filter that evaluates to {@code true} iff the {@code column} evaluates to a value
@@ -30,7 +29,7 @@
 @Evolving
 public final class GreaterThanOrEqual extends BinaryComparison {
 
-  public GreaterThanOrEqual(NamedReference column, Literal value) {
+  public GreaterThanOrEqual(Expression column, Expression value) {
     super(column, value);
   }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/In.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/In.java
index 8d6490b8984fd..6c4f2a5ffb2aa 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/In.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/In.java
@@ -22,8 +22,8 @@
 import java.util.stream.Collectors;
 
 import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.Expression;
 import org.apache.spark.sql.connector.expressions.Literal;
-import org.apache.spark.sql.connector.expressions.NamedReference;
 
 /**
  * A filter that evaluates to {@code true} iff the {@code column} evaluates to one of the
@@ -34,15 +34,15 @@
 @Evolving
 public final class In extends Filter {
   static final int MAX_LEN_TO_PRINT = 50;
-  private final NamedReference column;
+  private final Expression column;
   private final Literal[] values;
 
-  public In(NamedReference column, Literal[] values) {
+  public In(Expression column, Literal[] values) {
     this.column = column;
     this.values = values;
   }
 
-  public NamedReference column() { return column; }
+  public Expression column() { return column; }
   public Literal[] values() { return values; }
 
   @Override
@@ -72,5 +72,5 @@ public String toString() {
   }
 
   @Override
-  public NamedReference[] references() { return new NamedReference[] { column }; }
+  public Expression[] references() { return new Expression[] { column }; }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNotNull.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNotNull.java
index 2cf000e99878e..f1384e59d6082 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNotNull.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNotNull.java
@@ -20,7 +20,7 @@
 import java.util.Objects;
 
 import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Expression;
 
 /**
  * A filter that evaluates to {@code true} iff the {@code column} evaluates to a non-null value.
@@ -29,13 +29,13 @@
  */
 @Evolving
 public final class IsNotNull extends Filter {
-  private final NamedReference column;
+  private final Expression column;
 
-  public IsNotNull(NamedReference column) {
+  public IsNotNull(Expression column) {
     this.column = column;
   }
 
-  public NamedReference column() { return column; }
+  public Expression column() { return column; }
 
   @Override
   public String toString() { return column.describe() + " IS NOT NULL"; }
@@ -54,5 +54,5 @@ public int hashCode() {
   }
 
   @Override
-  public NamedReference[] references() { return new NamedReference[] { column }; }
+  public Expression[] references() { return new Expression[] { column }; }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNull.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNull.java
index 1cd497c02242e..f23308cd50866 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNull.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNull.java
@@ -20,7 +20,7 @@
 import java.util.Objects;
 
 import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Expression;
 
 /**
  * A filter that evaluates to {@code true} iff the {@code column} evaluates to null.
@@ -29,13 +29,13 @@
  */
 @Evolving
 public final class IsNull extends Filter {
-  private final NamedReference column;
+  private final Expression column;
 
-  public IsNull(NamedReference column) {
+  public IsNull(Expression column) {
     this.column = column;
   }
 
-  public NamedReference column() { return column; }
+  public Expression column() { return column; }
 
   @Override
   public String toString() { return column.describe() + " IS NULL"; }
@@ -54,5 +54,5 @@ public int hashCode() {
   }
 
   @Override
-  public NamedReference[] references() { return new NamedReference[] { column }; }
+  public Expression[] references() { return new Expression[] { column }; }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThan.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThan.java
index 9fa5cfb87f527..d6a3a01996789 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThan.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThan.java
@@ -18,8 +18,7 @@
 package org.apache.spark.sql.connector.expressions.filter;
 
 import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.expressions.Literal;
-import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Expression;
 
 /**
  * A filter that evaluates to {@code true} iff the {@code column} evaluates to a value
@@ -30,7 +29,7 @@
 @Evolving
 public final class LessThan extends BinaryComparison {
 
-  public LessThan(NamedReference column, Literal value) {
+  public LessThan(Expression column, Expression value) {
     super(column, value);
   }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThanOrEqual.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThanOrEqual.java
index a41b3c8045d5a..1ce3780c0be93 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThanOrEqual.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThanOrEqual.java
@@ -18,8 +18,7 @@
 package org.apache.spark.sql.connector.expressions.filter;
 
 import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.expressions.Literal;
-import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Expression;
 
 /**
  * A filter that evaluates to {@code true} iff the {@code column} evaluates to a value
@@ -30,7 +29,7 @@
 @Evolving
 public final class LessThanOrEqual extends BinaryComparison {
 
-  public LessThanOrEqual(NamedReference column, Literal value) {
+  public LessThanOrEqual(Expression column, Expression value) {
     super(column, value);
   }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringContains.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringContains.java
index 9a01e4d574888..c05e9ad261822 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringContains.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringContains.java
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.connector.expressions.filter;
 
 import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Expression;
 import org.apache.spark.unsafe.types.UTF8String;
 
 /**
@@ -30,7 +30,7 @@
 @Evolving
 public final class StringContains extends StringPredicate {
 
-  public StringContains(NamedReference column, UTF8String value) {
+  public StringContains(Expression column, UTF8String value) {
     super(column, value);
   }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringEndsWith.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringEndsWith.java
index 11b8317ba4895..c6ff5baba2cf1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringEndsWith.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringEndsWith.java
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.connector.expressions.filter;
 
 import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Expression;
 import org.apache.spark.unsafe.types.UTF8String;
 
 /**
@@ -30,7 +30,7 @@
 @Evolving
 public final class StringEndsWith extends StringPredicate {
 
-  public StringEndsWith(NamedReference column, UTF8String value) {
+  public StringEndsWith(Expression column, UTF8String value) {
     super(column, value);
   }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringPredicate.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringPredicate.java
index ffe5d5dba45b3..74d4c5c99eee3 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringPredicate.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringPredicate.java
@@ -20,7 +20,7 @@
 import java.util.Objects;
 
 import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Expression;
 import org.apache.spark.unsafe.types.UTF8String;
 
 /**
@@ -31,15 +31,15 @@
  */
 @Evolving
 abstract class StringPredicate extends Filter {
-  protected final NamedReference column;
+  protected final Expression column;
   protected final UTF8String value;
 
-  protected StringPredicate(NamedReference column, UTF8String value) {
+  protected StringPredicate(Expression column, UTF8String value) {
     this.column = column;
     this.value = value;
   }
 
-  public NamedReference column() { return column; }
+  public Expression column() { return column; }
   public UTF8String value() { return value; }
 
   @Override
@@ -56,5 +56,5 @@ public int hashCode() {
   }
 
   @Override
-  public NamedReference[] references() { return new NamedReference[] { column }; }
+  public Expression[] references() { return new Expression[] { column }; }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringStartsWith.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringStartsWith.java
index 38a5de1921cdc..4717ffcd3f0c8 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringStartsWith.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringStartsWith.java
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.connector.expressions.filter;
 
 import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Expression;
 import org.apache.spark.unsafe.types.UTF8String;
 
 /**
@@ -30,7 +30,7 @@
 @Evolving
 public final class StringStartsWith extends StringPredicate {
 
-  public StringStartsWith(NamedReference column, UTF8String value) {
+  public StringStartsWith(Expression column, UTF8String value) {
     super(column, value);
   }
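Taken together, the changes above widen every leaf filter from (NamedReference, Literal) to (Expression, Expression). A minimal construction sketch under the new signatures; the column names and values are illustrative only:

```scala
import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.{EqualTo, GreaterThan}
import org.apache.spark.sql.types.IntegerType

// The old constructors only allowed column-vs-literal comparisons; with
// Expression on both sides, column-vs-column predicates are representable.
val colVsLiteral = new GreaterThan(FieldReference("age"), LiteralValue(21, IntegerType))
val colVsColumn = new EqualTo(FieldReference("bonus"), FieldReference("salary"))
```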
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index e544a7c8767e7..64b9dce26094f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -27,17 +27,17 @@
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning, EmptyRow, Expression, Literal, NamedExpression, PredicateHelper, SubqueryExpression}
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.util.toPrettySQL
+import org.apache.spark.sql.catalyst.util.{toPrettySQL, V2ExpressionBuilder}
 import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, Table, TableCapability, TableCatalog}
 import org.apache.spark.sql.connector.catalog.index.SupportsIndex
-import org.apache.spark.sql.connector.expressions.{FieldReference, Literal => V2Literal, LiteralValue}
+import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, Literal => V2Literal, LiteralValue}
 import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse => V2AlwaysFalse, AlwaysTrue => V2AlwaysTrue, And => V2And, EqualNullSafe => V2EqualNullSafe, EqualTo => V2EqualTo, Filter => V2Filter, GreaterThan => V2GreaterThan, GreaterThanOrEqual => V2GreaterThanOrEqual, In => V2In, IsNotNull => V2IsNotNull, IsNull => V2IsNull, LessThan => V2LessThan, LessThanOrEqual => V2LessThanOrEqual, Not => V2Not, Or => V2Or, StringContains => V2StringContains, StringEndsWith => V2StringEndsWith, StringStartsWith => V2StringStartsWith}
 import org.apache.spark.sql.connector.read.LocalScan
 import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream}
 import org.apache.spark.sql.connector.write.V1Write
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.execution.{FilterExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan}
-import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumn, PushableColumnBase}
+import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumn, PushableColumnAndNestedColumn, PushableColumnBase, PushableColumnWithoutNestedColumn}
 import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec}
 import org.apache.spark.sql.internal.StaticSQLConf.WAREHOUSE_PATH
 import org.apache.spark.sql.sources.{BaseRelation, TableScan}
@@ -472,64 +472,46 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with PredicateHelper {
 
 private[sql] object DataSourceV2Strategy {
 
   private def translateLeafNodeFilterV2(
-      predicate: Expression,
-      pushableColumn: PushableColumnBase): Option[V2Filter] = predicate match {
-    case expressions.EqualTo(pushableColumn(name), Literal(v, t)) =>
-      Some(new V2EqualTo(FieldReference(name), LiteralValue(v, t)))
-    case expressions.EqualTo(Literal(v, t), pushableColumn(name)) =>
-      Some(new V2EqualTo(FieldReference(name), LiteralValue(v, t)))
-
-    case expressions.EqualNullSafe(pushableColumn(name), Literal(v, t)) =>
-      Some(new V2EqualNullSafe(FieldReference(name), LiteralValue(v, t)))
-    case expressions.EqualNullSafe(Literal(v, t), pushableColumn(name)) =>
-      Some(new V2EqualNullSafe(FieldReference(name), LiteralValue(v, t)))
-
-    case expressions.GreaterThan(pushableColumn(name), Literal(v, t)) =>
-      Some(new V2GreaterThan(FieldReference(name), LiteralValue(v, t)))
-    case expressions.GreaterThan(Literal(v, t), pushableColumn(name)) =>
-      Some(new V2LessThan(FieldReference(name), LiteralValue(v, t)))
-
-    case expressions.LessThan(pushableColumn(name), Literal(v, t)) =>
-      Some(new V2LessThan(FieldReference(name), LiteralValue(v, t)))
-    case expressions.LessThan(Literal(v, t), pushableColumn(name)) =>
-      Some(new V2GreaterThan(FieldReference(name), LiteralValue(v, t)))
-
-    case expressions.GreaterThanOrEqual(pushableColumn(name), Literal(v, t)) =>
-      Some(new V2GreaterThanOrEqual(FieldReference(name), LiteralValue(v, t)))
-    case expressions.GreaterThanOrEqual(Literal(v, t), pushableColumn(name)) =>
-      Some(new V2LessThanOrEqual(FieldReference(name), LiteralValue(v, t)))
-
-    case expressions.LessThanOrEqual(pushableColumn(name), Literal(v, t)) =>
-      Some(new V2LessThanOrEqual(FieldReference(name), LiteralValue(v, t)))
-    case expressions.LessThanOrEqual(Literal(v, t), pushableColumn(name)) =>
-      Some(new V2GreaterThanOrEqual(FieldReference(name), LiteralValue(v, t)))
-
-    case in @ expressions.InSet(pushableColumn(name), set) =>
+      predicate: Expression): Option[V2Filter] = predicate match {
+    case expressions.EqualTo(PushableExpression(expr1), PushableExpression(expr2)) =>
+      Some(new V2EqualTo(expr1, expr2))
+    case expressions.EqualNullSafe(PushableExpression(expr1), PushableExpression(expr2)) =>
+      Some(new V2EqualNullSafe(expr1, expr2))
+    case expressions.GreaterThan(PushableExpression(expr1), PushableExpression(expr2)) =>
+      Some(new V2GreaterThan(expr1, expr2))
+    case expressions.LessThan(PushableExpression(expr1), PushableExpression(expr2)) =>
+      Some(new V2LessThan(expr1, expr2))
+    case expressions.GreaterThanOrEqual(PushableExpression(expr1), PushableExpression(expr2)) =>
+      Some(new V2GreaterThanOrEqual(expr1, expr2))
+    case expressions.LessThanOrEqual(PushableExpression(expr1), PushableExpression(expr2)) =>
+      Some(new V2LessThanOrEqual(expr1, expr2))
+
+    case in @ expressions.InSet(PushableExpression(expr), set) =>
       val values: Array[V2Literal[_]] =
         set.toSeq.map(elem => LiteralValue(elem, in.dataType)).toArray
-      Some(new V2In(FieldReference(name), values))
+      Some(new V2In(expr, values))
 
     // Because we only convert In to InSet in Optimizer when there are more than certain
     // items. So it is possible we still get an In expression here that needs to be pushed
     // down.
-    case in @ expressions.In(pushableColumn(name), list) if list.forall(_.isInstanceOf[Literal]) =>
+    case in @ expressions.In(PushableExpression(expr), list)
+        if list.forall(_.isInstanceOf[Literal]) =>
       val hSet = list.map(_.eval(EmptyRow))
-      Some(new V2In(FieldReference(name),
-        hSet.toArray.map(LiteralValue(_, in.value.dataType))))
+      Some(new V2In(expr, hSet.toArray.map(LiteralValue(_, in.value.dataType))))
 
-    case expressions.IsNull(pushableColumn(name)) =>
-      Some(new V2IsNull(FieldReference(name)))
-    case expressions.IsNotNull(pushableColumn(name)) =>
-      Some(new V2IsNotNull(FieldReference(name)))
+    case expressions.IsNull(PushableExpression(expr)) =>
+      Some(new V2IsNull(expr))
+    case expressions.IsNotNull(PushableExpression(expr)) =>
+      Some(new V2IsNotNull(expr))
 
-    case expressions.StartsWith(pushableColumn(name), Literal(v: UTF8String, StringType)) =>
-      Some(new V2StringStartsWith(FieldReference(name), v))
+    case expressions.StartsWith(PushableExpression(expr), Literal(v: UTF8String, StringType)) =>
+      Some(new V2StringStartsWith(expr, v))
 
-    case expressions.EndsWith(pushableColumn(name), Literal(v: UTF8String, StringType)) =>
-      Some(new V2StringEndsWith(FieldReference(name), v))
+    case expressions.EndsWith(PushableExpression(expr), Literal(v: UTF8String, StringType)) =>
+      Some(new V2StringEndsWith(expr, v))
 
-    case expressions.Contains(pushableColumn(name), Literal(v: UTF8String, StringType)) =>
-      Some(new V2StringContains(FieldReference(name), v))
+    case expressions.Contains(PushableExpression(expr), Literal(v: UTF8String, StringType)) =>
+      Some(new V2StringContains(expr, v))
 
     case expressions.Literal(true, BooleanType) =>
       Some(new V2AlwaysTrue)
@@ -537,8 +519,8 @@ private[sql] object DataSourceV2Strategy {
     case expressions.Literal(false, BooleanType) =>
       Some(new V2AlwaysFalse)
 
-    case e @ pushableColumn(name) if e.dataType.isInstanceOf[BooleanType] =>
-      Some(new V2EqualTo(FieldReference(name), LiteralValue(true, BooleanType)))
+    case e @ PushableExpression(expr) if e.dataType.isInstanceOf[BooleanType] =>
+      Some(new V2EqualTo(expr, LiteralValue(true, BooleanType)))
 
     case _ => None
   }
@@ -599,8 +581,7 @@ private[sql] object DataSourceV2Strategy {
         .map(new V2Not(_))
 
       case other =>
-        val filter = translateLeafNodeFilterV2(
-          other, PushableColumn(nestedPredicatePushdownEnabled))
+        val filter = translateLeafNodeFilterV2(other)
         if (filter.isDefined && translatedFilterToExpr.isDefined) {
           translatedFilterToExpr.get(filter.get) = predicate
         }
@@ -626,3 +607,13 @@ private[sql] object DataSourceV2Strategy {
     }
   }
 }
+
+/**
+ * Get the DS V2 expression that represents the given catalyst expression, if it can be pushed down.
+ */
+object PushableExpression {
+  def unapply(e: Expression): Option[V2Expression] = e match {
+    case PushableColumnAndNestedColumn(name) => Some(FieldReference(name))
+    case _ => new V2ExpressionBuilder(e).build()
+  }
+}
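The `PushableExpression` extractor generalizes the old `pushableColumn` matching: a (possibly nested) column still becomes a `FieldReference`, and anything else gets a second chance through `V2ExpressionBuilder`. A hedged sketch of the behavior from a caller's point of view; the `toV2` wrapper is illustrative, not part of this patch:

```scala
import org.apache.spark.sql.catalyst.expressions.{Expression => CatalystExpr}
import org.apache.spark.sql.connector.expressions.{Expression => V2Expr}

// Illustrative wrapper: Some(v2) when the catalyst expression can be
// translated for pushdown, None when Spark must evaluate it post-scan.
def toV2(e: CatalystExpr): Option[V2Expr] = e match {
  case PushableExpression(v2) => Some(v2)
  case _ => None
}
```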
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index 61bf729bc8fbf..6c34ed32de5c0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -22,12 +22,12 @@
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connector.expressions.SortOrder
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN}
+import org.apache.spark.sql.connector.expressions.filter.Filter
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
 import org.apache.spark.sql.execution.datasources.PartitioningUtils
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation}
 import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
 import org.apache.spark.sql.jdbc.JdbcDialects
-import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types.StructType
 
 case class JDBCScanBuilder(
@@ -35,7 +35,7 @@ case class JDBCScanBuilder(
     schema: StructType,
     jdbcOptions: JDBCOptions)
   extends ScanBuilder
-    with SupportsPushDownFilters
+    with SupportsPushDownV2Filters
     with SupportsPushDownRequiredColumns
     with SupportsPushDownAggregates
     with SupportsPushDownLimit
@@ -55,6 +55,13 @@ case class JDBCScanBuilder(
 
   private var sortOrders: Array[SortOrder] = Array.empty[SortOrder]
 
-  override def pushFilters(filters: Array[Filter]): Array[Filter] = {
-    if (jdbcOptions.pushDownPredicate) {
-      val dialect = JdbcDialects.get(jdbcOptions.url)
+  override def pushFilters(filters: Array[Filter]): Array[Filter] = {
+    if (jdbcOptions.pushDownPredicate) {
+      val dialect = JdbcDialects.get(jdbcOptions.url)
+      val (pushed, unSupported) = filters.partition(dialect.compileFilter(_).isDefined)
+      this.pushedFilter = pushed
+      unSupported
+    } else {
+      filters
+    }
+  }
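The `pushFilters` implementation above follows the `SupportsPushDownV2Filters` contract: whatever the dialect can compile is retained for the JDBC `WHERE` clause, and the returned array is what Spark must still evaluate after the scan. A sketch of a call site; the builder argument is assumed to come from a JDBC table's `newScanBuilder`, and the candidate filters are illustrative:

```scala
import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.{EqualTo, Filter, IsNotNull}
import org.apache.spark.sql.connector.read.SupportsPushDownV2Filters
import org.apache.spark.sql.types.IntegerType

// Illustrative only: offer two candidate filters to any V2-filter-capable
// builder and get back the ones that could not be pushed down.
def pushAll(builder: SupportsPushDownV2Filters): Array[Filter] = {
  val candidates: Array[Filter] = Array(
    new IsNotNull(FieldReference("name")),
    new EqualTo(FieldReference("id"), LiteralValue(1, IntegerType)))
  builder.pushFilters(candidates) // whatever comes back stays in Spark
}
```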
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index a7e0ec8b72a7c..76e581417b885 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -34,6 +34,7 @@
 import org.apache.spark.sql.connector.catalog.TableChange._
 import org.apache.spark.sql.connector.catalog.index.TableIndex
 import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, NamedReference}
 import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, Max, Min, Sum}
+import org.apache.spark.sql.connector.expressions.filter.{EqualNullSafe => V2EqualNullSafe, EqualTo => V2EqualTo, Filter}
 import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
@@ -220,6 +221,34 @@ abstract class JdbcDialect extends Serializable with Logging{
     }
   }
 
+  @Since("3.3.0")
+  def compileFilter(filter: Filter): Option[String] = {
+    filter match {
+      case eq: V2EqualTo =>
+        val l = compileExpression(eq.column())
+        val r = compileExpression(eq.value())
+        if (l.isDefined && r.isDefined) {
+          Some(s"${l.get} = ${r.get}")
+        } else {
+          None
+        }
+      case eq: V2EqualNullSafe =>
+        val l = compileExpression(eq.column())
+        val r = compileExpression(eq.value())
+        if (l.isDefined && r.isDefined) {
+          Some(
+            s"""
+               |(NOT (${l.get} != ${r.get} OR ${l.get} IS NULL OR
+               |${r.get} IS NULL) OR
+               |(${l.get} IS NULL AND ${r.get} IS NULL))
+               |""".stripMargin)
+        } else {
+          None
+        }
+      case _ => None
+    }
+  }
+
   /**
    * Converts aggregate function to String representing a SQL expression.
    * @param aggFunction The aggregate function to be converted.
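Finally, a small sketch of the new `compileFilter` hook from the caller's side. The JDBC URL is a placeholder, and the exact SQL text depends on how `compileExpression` renders each child, so the result shown in the comment is indicative only:

```scala
import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.{EqualTo => V2EqualTo}
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.types.IntegerType

val dialect = JdbcDialects.get("jdbc:h2:mem:testdb") // placeholder URL
val filter = new V2EqualTo(FieldReference("id"), LiteralValue(1, IntegerType))
// Roughly Some("id = 1") when both children compile; None means the filter
// is kept in Spark rather than pushed to the remote database.
val compiled: Option[String] = dialect.compileFilter(filter)
```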