diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala index b43c7faddcd93..6ff5ebc93f65a 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroPartitionReaderFactory.scala @@ -30,11 +30,11 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, AvroUtils} import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters, OrderedFilters} +import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} import org.apache.spark.sql.connector.read.PartitionReader import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile} import org.apache.spark.sql.execution.datasources.v2.{EmptyPartitionReader, FilePartitionReaderFactory, PartitionReaderWithPartitionValues} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration @@ -55,7 +55,7 @@ case class AvroPartitionReaderFactory( readDataSchema: StructType, partitionSchema: StructType, parsedOptions: AvroOptions, - filters: Seq[Filter]) extends FilePartitionReaderFactory with Logging { + filters: Seq[V2Filter]) extends FilePartitionReaderFactory with Logging { private val datetimeRebaseModeInRead = parsedOptions.datetimeRebaseModeInRead override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] = { @@ -94,7 +94,7 @@ case class AvroPartitionReaderFactory( datetimeRebaseModeInRead) val avroFilters = if (SQLConf.get.avroFilterPushDown) { - new OrderedFilters(filters, readDataSchema) + new OrderedFilters(filters.map(_.toV1), readDataSchema) } else { new NoopFilters } diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala index d0f38c12427c3..ea1d16f1aa613 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala @@ -23,10 +23,10 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.SparkSession import org.apache.spark.sql.avro.AvroOptions import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScan -import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration @@ -38,7 +38,7 @@ case class AvroScan( readDataSchema: StructType, readPartitionSchema: StructType, options: CaseInsensitiveStringMap, - pushedFilters: Array[Filter], + pushedFilters: Array[V2Filter], partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty) extends FileScan { override def isSplitable(path: Path): Boolean = true diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala 
b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala index 8fae89a945826..cab7848fa86d3 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.v2.avro import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.StructFilters +import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder -import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -46,11 +46,11 @@ class AvroScanBuilder ( dataFilters) } - override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = { + override def pushDataFilters(dataFilters: Array[V2Filter]): Array[V2Filter] = { if (sparkSession.sessionState.conf.avroFilterPushDown) { - StructFilters.pushedFilters(dataFilters, dataSchema) + StructFilters.pushedFiltersV2(dataFilters, dataSchema) } else { - Array.empty[Filter] + Array.empty[V2Filter] } } } diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroScanSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroScanSuite.scala index 98a7190ba984e..4ea55ced65a27 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroScanSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroScanSuite.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.v2.avro.AvroScan class AvroScanSuite extends FileScanSuiteBase { val scanBuilders = Seq[(String, ScanBuilder, Seq[String])]( ("AvroScan", - (s, fi, ds, rds, rps, f, o, pf, df) => AvroScan(s, fi, ds, rds, rps, o, f, pf, df), - Seq.empty)) + (s, fi, ds, rds, rps, f, o, pf, df) => AvroScan(s, fi, ds, rds, rps, o, + f.map(_.toV2), pf, df), Seq.empty)) run(scanBuilders) } diff --git a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 510ddfcabc5a3..f64411cab157e 100644 --- a/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/external/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -2329,7 +2329,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper { |Format: avro |Location: InMemoryFileIndex\\([0-9]+ paths\\)\\[.*\\] |PartitionFilters: \\[isnotnull\\(id#x\\), \\(id#x > 1\\)\\] - |PushedFilters: \\[IsNotNull\\(value\\), GreaterThan\\(value,2\\)\\] + |PushedFilters: \\[value IS NOT NULL, value > 2\\] |ReadSchema: struct\\ |""".stripMargin.trim spark.range(10) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysFalse.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysFalse.java index 72ed83f86df6d..bb05479fcb88a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysFalse.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysFalse.java @@ -47,4 +47,9 @@ public int hashCode() { @Override public NamedReference[] references() { return EMPTY_REFERENCE; } + + @Override + public org.apache.spark.sql.sources.Filter toV1() { + return new org.apache.spark.sql.sources.AlwaysFalse(); + } } diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysTrue.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysTrue.java index b6d39c3f64a77..64aaf18fde109 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysTrue.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysTrue.java @@ -47,4 +47,9 @@ public int hashCode() { @Override public NamedReference[] references() { return EMPTY_REFERENCE; } + + @Override + public org.apache.spark.sql.sources.Filter toV1() { + return new org.apache.spark.sql.sources.AlwaysTrue(); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/And.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/And.java index e0b8b13acb158..88867b4bfde5a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/And.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/And.java @@ -36,4 +36,9 @@ public And(Filter left, Filter right) { public String toString() { return String.format("(%s) AND (%s)", left.describe(), right.describe()); } + + @Override + public org.apache.spark.sql.sources.Filter toV1() { + return new org.apache.spark.sql.sources.And(left.toV1(), right.toV1()); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualNullSafe.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualNullSafe.java index 34b529194e075..cfb9f8b56948f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualNullSafe.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualNullSafe.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.connector.expressions.filter; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.catalyst.CatalystTypeConverters; import org.apache.spark.sql.connector.expressions.Literal; import org.apache.spark.sql.connector.expressions.NamedReference; @@ -37,4 +38,10 @@ public EqualNullSafe(NamedReference column, Literal value) { @Override public String toString() { return this.column.describe() + " <=> " + value.describe(); } + + @Override + public org.apache.spark.sql.sources.Filter toV1() { + return new org.apache.spark.sql.sources.EqualNullSafe( + column.describe(), CatalystTypeConverters.convertToScala(value.value(), value.dataType())); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualTo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualTo.java index b9c4fe053b83c..9566506a3a4d6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualTo.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/EqualTo.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.connector.expressions.filter; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.catalyst.CatalystTypeConverters; import org.apache.spark.sql.connector.expressions.Literal; import org.apache.spark.sql.connector.expressions.NamedReference; @@ -36,4 +37,10 @@ public EqualTo(NamedReference column, Literal value) { @Override public String toString() { return column.describe() + " = " + value.describe(); } + + @Override + public org.apache.spark.sql.sources.Filter toV1() { + return new 
org.apache.spark.sql.sources.EqualTo( + (column).describe(), CatalystTypeConverters.convertToScala(value.value(), value.dataType())); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java index 852837496a103..a550cce5b9b49 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Filter.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.connector.expressions.filter; +import java.io.Serializable; + import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.expressions.Expression; import org.apache.spark.sql.connector.expressions.NamedReference; @@ -27,7 +29,7 @@ * @since 3.3.0 */ @Evolving -public abstract class Filter implements Expression { +public abstract class Filter implements Expression, Serializable { protected static final NamedReference[] EMPTY_REFERENCE = new NamedReference[0]; @@ -38,4 +40,9 @@ public abstract class Filter implements Expression { @Override public String describe() { return this.toString(); } + + /** + * Returns a V1 Filter. + */ + public abstract org.apache.spark.sql.sources.Filter toV1(); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThan.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThan.java index a3374f359ea29..c6184f66e0278 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThan.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThan.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.connector.expressions.filter; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.catalyst.CatalystTypeConverters; import org.apache.spark.sql.connector.expressions.Literal; import org.apache.spark.sql.connector.expressions.NamedReference; @@ -36,4 +37,11 @@ public GreaterThan(NamedReference column, Literal value) { @Override public String toString() { return column.describe() + " > " + value.describe(); } + + @Override + public org.apache.spark.sql.sources.Filter toV1() { + return new org.apache.spark.sql.sources.GreaterThan( + column.describe(), CatalystTypeConverters.convertToScala(value.value(), value.dataType())); + } + } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThanOrEqual.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThanOrEqual.java index 4ee921014da41..8be4dcd31aefe 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThanOrEqual.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/GreaterThanOrEqual.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.connector.expressions.filter; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.catalyst.CatalystTypeConverters; import org.apache.spark.sql.connector.expressions.Literal; import org.apache.spark.sql.connector.expressions.NamedReference; @@ -36,4 +37,10 @@ public GreaterThanOrEqual(NamedReference column, Literal value) { @Override public String toString() { return column.describe() + " >= " + value.describe(); } + + @Override + public org.apache.spark.sql.sources.Filter toV1() { + return new 
org.apache.spark.sql.sources.GreaterThanOrEqual( + column.describe(), CatalystTypeConverters.convertToScala(value.value(), value.dataType())); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/In.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/In.java index 8d6490b8984fd..f714c77510351 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/In.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/In.java @@ -22,6 +22,7 @@ import java.util.stream.Collectors; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.catalyst.CatalystTypeConverters; import org.apache.spark.sql.connector.expressions.Literal; import org.apache.spark.sql.connector.expressions.NamedReference; @@ -73,4 +74,15 @@ public String toString() { @Override public NamedReference[] references() { return new NamedReference[] { column }; } + + @Override + public org.apache.spark.sql.sources.Filter toV1() { + Object[] array = new Object[values.length]; + int index = 0; + for (Literal value: values) { + array[index] = CatalystTypeConverters.convertToScala(value.value(), value.dataType()); + index++; + } + return new org.apache.spark.sql.sources.In(column.describe(), array); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNotNull.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNotNull.java index 2cf000e99878e..da08d5012b810 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNotNull.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNotNull.java @@ -55,4 +55,9 @@ public int hashCode() { @Override public NamedReference[] references() { return new NamedReference[] { column }; } + + @Override + public org.apache.spark.sql.sources.Filter toV1() { + return new org.apache.spark.sql.sources.IsNotNull(column.describe()); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNull.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNull.java index 1cd497c02242e..03a581044d9a9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNull.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/IsNull.java @@ -55,4 +55,9 @@ public int hashCode() { @Override public NamedReference[] references() { return new NamedReference[] { column }; } + + @Override + public org.apache.spark.sql.sources.Filter toV1() { + return new org.apache.spark.sql.sources.IsNull(column.describe()); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThan.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThan.java index 9fa5cfb87f527..b024f1847279b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThan.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThan.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.connector.expressions.filter; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.catalyst.CatalystTypeConverters; import org.apache.spark.sql.connector.expressions.Literal; import org.apache.spark.sql.connector.expressions.NamedReference; @@ -36,4 +37,10 @@ public LessThan(NamedReference column, Literal 
value) { @Override public String toString() { return column.describe() + " < " + value.describe(); } + + @Override + public org.apache.spark.sql.sources.Filter toV1() { + return new org.apache.spark.sql.sources.LessThan( + column.describe(), CatalystTypeConverters.convertToScala(value.value(), value.dataType())); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThanOrEqual.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThanOrEqual.java index a41b3c8045d5a..2699b020f00d7 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThanOrEqual.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/LessThanOrEqual.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.connector.expressions.filter; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.catalyst.CatalystTypeConverters; import org.apache.spark.sql.connector.expressions.Literal; import org.apache.spark.sql.connector.expressions.NamedReference; @@ -36,4 +37,10 @@ public LessThanOrEqual(NamedReference column, Literal value) { @Override public String toString() { return column.describe() + " <= " + value.describe(); } + + @Override + public org.apache.spark.sql.sources.Filter toV1() { + return new org.apache.spark.sql.sources.LessThanOrEqual( + column.describe(), CatalystTypeConverters.convertToScala(value.value(), value.dataType())); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Not.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Not.java index 69746f59ee933..76572222923b4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Not.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Not.java @@ -53,4 +53,9 @@ public int hashCode() { @Override public NamedReference[] references() { return child.references(); } + + @Override + public org.apache.spark.sql.sources.Filter toV1() { + return new org.apache.spark.sql.sources.Not(child.toV1()); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Or.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Or.java index baa33d849feef..c81fa2b3dc0d6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Or.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Or.java @@ -36,4 +36,9 @@ public Or(Filter left, Filter right) { public String toString() { return String.format("(%s) OR (%s)", left.describe(), right.describe()); } + + @Override + public org.apache.spark.sql.sources.Filter toV1() { + return new org.apache.spark.sql.sources.Or(left.toV1(), right.toV1()); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringContains.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringContains.java index 9a01e4d574888..042e83063d621 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringContains.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringContains.java @@ -36,4 +36,9 @@ public StringContains(NamedReference column, UTF8String value) { @Override public String toString() { return "STRING_CONTAINS(" + column.describe() + ", " + value + ")"; } + + @Override + public 
org.apache.spark.sql.sources.Filter toV1() { + return new org.apache.spark.sql.sources.StringContains(column.describe(), value.toString()); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringEndsWith.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringEndsWith.java index 11b8317ba4895..8abdbfd5c623c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringEndsWith.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringEndsWith.java @@ -36,4 +36,9 @@ public StringEndsWith(NamedReference column, UTF8String value) { @Override public String toString() { return "STRING_ENDS_WITH(" + column.describe() + ", " + value + ")"; } + + @Override + public org.apache.spark.sql.sources.Filter toV1() { + return new org.apache.spark.sql.sources.StringEndsWith(column.describe(), value.toString()); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringStartsWith.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringStartsWith.java index 38a5de1921cdc..59d64f57deede 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringStartsWith.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/StringStartsWith.java @@ -38,4 +38,9 @@ public StringStartsWith(NamedReference column, UTF8String value) { public String toString() { return "STRING_STARTS_WITH(" + column.describe() + ", " + value + ")"; } + + @Override + public org.apache.spark.sql.sources.Filter toV1() { + return new org.apache.spark.sql.sources.StringStartsWith(column.describe(), value.toString()); + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/StructFilters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/StructFilters.scala index ff67b6fccfae9..bdaa4f7842c70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/StructFilters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/StructFilters.scala @@ -21,6 +21,7 @@ import scala.util.Try import org.apache.spark.sql.catalyst.StructFilters._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} import org.apache.spark.sql.sources import org.apache.spark.sql.types.{BooleanType, StructType} @@ -93,6 +94,16 @@ object StructFilters { filters.filter(checkFilterRefs(_, fieldNames)) } + private def checkFilterRefsV2(filter: V2Filter, fieldNames: Set[String]): Boolean = { + // The names have been normalized and case sensitivity is not a concern here. 
+ filter.references.map(_.fieldNames().mkString(".")).forall(fieldNames.contains) + } + + def pushedFiltersV2(filters: Array[V2Filter], schema: StructType): Array[V2Filter] = { + val fieldNames = schema.fieldNames.toSet + filters.filter(checkFilterRefsV2(_, fieldNames)) + } + private def zip[A, B](a: Option[A], b: Option[B]): Option[(A, B)] = { a.zip(b).headOption } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 385e6b783f1e2..5fef5077bfa61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, UnboundFunction} import org.apache.spark.sql.connector.expressions.{NamedReference, Transform} +import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.{LEGACY_ALLOW_NEGATIVE_SCALE_OF_DECIMAL_ENABLED, LEGACY_CTE_PRECEDENCE_POLICY} import org.apache.spark.sql.sources.Filter @@ -1136,6 +1137,11 @@ object QueryCompilationErrors { s"Fail to rebuild expression: missing key $filter in `translatedFilterToExpr`") } + def failedToRebuildExpressionError(filter: V2Filter): Throwable = { + new AnalysisException( + s"Fail to rebuild expression: missing key $filter in `translatedFilterToExpr`") + } + def dataTypeUnsupportedByDataSourceError(format: String, field: StructField): Throwable = { new AnalysisException( s"$format data source does not support ${field.dataType.catalogString} data type.") @@ -2392,4 +2398,8 @@ object QueryCompilationErrors { errorClass = "INVALID_JSON_SCHEMA_MAPTYPE", messageParameters = Array(schema.toString)) } + + def invalidDataTypeForFilterValue(value: Any): Throwable = { + new AnalysisException(s"Filter value $value has invalid data type") + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsPushDownCatalystFilters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsPushDownCatalystFilters.scala index 9c2a4ac78a24a..b1bc75c892a2d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsPushDownCatalystFilters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsPushDownCatalystFilters.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.internal.connector import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.sources.Filter +import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} /** * A mix-in interface for {@link FileScanBuilder}. File sources can implement this interface to @@ -37,5 +37,5 @@ trait SupportsPushDownCatalystFilters { * Returns the data filters that are pushed to the data source via * {@link #pushFilters(Expression[])}. 
*/ - def pushedFilters: Array[Filter] + def pushedFilters: Array[V2Filter] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala index 174dd088d4c66..5a06b93a9b162 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -17,8 +17,18 @@ package org.apache.spark.sql.sources +import java.math.{BigDecimal => JavaBigDecimal, BigInteger => JavaBigInteger} +import java.sql.{Date, Timestamp} +import java.time.{Instant, LocalDate} + import org.apache.spark.annotation.{Evolving, Stable} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.parseColumnPath +import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue} +import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse => V2AlwaysFalse, AlwaysTrue => V2AlwaysTrue, And => V2And, EqualNullSafe => V2EqualNullSafe, EqualTo => V2EqualTo, Filter => V2Filter, GreaterThan => V2GreaterThan, GreaterThanOrEqual => V2GreaterThanOrEqual, In => V2In, IsNotNull => V2IsNotNull, IsNull => V2IsNull, LessThan => V2LessThan, LessThanOrEqual => V2LessThanOrEqual, Not => V2Not, Or => V2Or, StringContains => V2StringContains, StringEndsWith => V2StringEndsWith, StringStartsWith => V2StringStartsWith} +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, DateType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, TimestampType} +import org.apache.spark.unsafe.types.UTF8String //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines all the filters that we can push down to the data sources. 
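The hunks that follow add a `private[sql] toV2` to the V1 `Filter` hierarchy, mirroring the `toV1` methods added on the V2 classes above. A minimal round-trip sketch, assuming this patch is applied; the package and object names are made up, and the snippet has to sit under `org.apache.spark.sql` only because `toV2` is `private[sql]`:

```scala
package org.apache.spark.sql.example  // hypothetical; needed only because toV2 is private[sql]

import org.apache.spark.sql.sources.{And, EqualTo, GreaterThan}

object FilterRoundTripSketch {
  def main(args: Array[String]): Unit = {
    // V1 filters hold external Scala/Java values (String, Int, ...).
    val v1 = And(EqualTo("city", "Berlin"), GreaterThan("id", 5))

    // toV2 wraps each value via getLiteralValue: "Berlin" becomes a UTF8String
    // literal of StringType, 5 becomes an IntegerType literal.
    val v2 = v1.toV2
    println(v2.describe())  // roughly: (city = 'Berlin') AND (id > 5)

    // toV1 converts back through CatalystTypeConverters.convertToScala.
    println(v2.toV1)        // And(EqualTo(city,Berlin),GreaterThan(id,5))
  }
}
```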
@@ -43,6 +53,40 @@ sealed abstract class Filter { */ def references: Array[String] + private[sql] def toV2: V2Filter + + private[sql] def getLiteralValue(value: Any): LiteralValue[_] = value match { + case _: JavaBigDecimal => + LiteralValue(Decimal(value.asInstanceOf[JavaBigDecimal]), DecimalType.SYSTEM_DEFAULT) + case _: JavaBigInteger => + LiteralValue(Decimal(value.asInstanceOf[JavaBigInteger]), DecimalType.SYSTEM_DEFAULT) + case _: BigDecimal => + LiteralValue(Decimal(value.asInstanceOf[BigDecimal]), DecimalType.SYSTEM_DEFAULT) + case _: Boolean => LiteralValue(value, BooleanType) + case _: Byte => LiteralValue(value, ByteType) + case _: Array[Byte] => LiteralValue(value, BinaryType) + case _: Date => + val date = DateTimeUtils.fromJavaDate(value.asInstanceOf[Date]) + LiteralValue(date, DateType) + case _: LocalDate => + val date = DateTimeUtils.localDateToDays(value.asInstanceOf[LocalDate]) + LiteralValue(date, DateType) + case _: Double => LiteralValue(value, DoubleType) + case _: Float => LiteralValue(value, FloatType) + case _: Integer => LiteralValue(value, IntegerType) + case _: Long => LiteralValue(value, LongType) + case _: Short => LiteralValue(value, ShortType) + case _: String => LiteralValue(UTF8String.fromString(value.toString), StringType) + case _: Timestamp => + val ts = DateTimeUtils.fromJavaTimestamp(value.asInstanceOf[Timestamp]) + LiteralValue(ts, TimestampType) + case _: Instant => + val ts = DateTimeUtils.instantToMicros(value.asInstanceOf[Instant]) + LiteralValue(ts, TimestampType) + case _ => + throw QueryCompilationErrors.invalidDataTypeForFilterValue(value) + } + protected def findReferences(value: Any): Array[String] = value match { case f: Filter => f.references case _ => Array.empty @@ -78,6 +122,9 @@ sealed abstract class Filter { @Stable case class EqualTo(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + override private[sql] def toV2 = { + new V2EqualTo(FieldReference(attribute), getLiteralValue(value)) + } } /** @@ -93,6 +140,8 @@ case class EqualTo(attribute: String, value: Any) extends Filter { @Stable case class EqualNullSafe(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + override private[sql] def toV2 = + new V2EqualNullSafe(FieldReference(attribute), getLiteralValue(value)) } /** @@ -107,6 +156,8 @@ case class EqualNullSafe(attribute: String, value: Any) extends Filter { @Stable case class GreaterThan(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + override private[sql] def toV2 = + new V2GreaterThan(FieldReference(attribute), getLiteralValue(value)) } /** @@ -121,6 +172,8 @@ case class GreaterThan(attribute: String, value: Any) extends Filter { @Stable case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + override private[sql] def toV2 = + new V2GreaterThanOrEqual(FieldReference(attribute), getLiteralValue(value)) } /** @@ -135,6 +188,8 @@ case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter { @Stable case class LessThan(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + override private[sql] def toV2 = + new V2LessThan(FieldReference(attribute), getLiteralValue(value)) } /** @@ -149,6 
+204,8 @@ case class LessThan(attribute: String, value: Any) extends Filter { @Stable case class LessThanOrEqual(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + override private[sql] def toV2 = + new V2LessThanOrEqual(FieldReference(attribute), getLiteralValue(value)) } /** @@ -185,6 +242,9 @@ case class In(attribute: String, values: Array[Any]) extends Filter { } override def references: Array[String] = Array(attribute) ++ values.flatMap(findReferences) + override private[sql] def toV2 = + new V2In(FieldReference(attribute), + values.map(value => getLiteralValue(value))) } /** @@ -198,6 +258,7 @@ case class In(attribute: String, values: Array[Any]) extends Filter { @Stable case class IsNull(attribute: String) extends Filter { override def references: Array[String] = Array(attribute) + override private[sql] def toV2 = new V2IsNull(FieldReference(attribute)) } /** @@ -211,6 +272,7 @@ case class IsNull(attribute: String) extends Filter { @Stable case class IsNotNull(attribute: String) extends Filter { override def references: Array[String] = Array(attribute) + override private[sql] def toV2 = new V2IsNotNull(FieldReference(attribute)) } /** @@ -221,6 +283,7 @@ case class IsNotNull(attribute: String) extends Filter { @Stable case class And(left: Filter, right: Filter) extends Filter { override def references: Array[String] = left.references ++ right.references + override private[sql] def toV2 = new V2And(left.toV2, right.toV2) } /** @@ -231,6 +294,7 @@ case class And(left: Filter, right: Filter) extends Filter { @Stable case class Or(left: Filter, right: Filter) extends Filter { override def references: Array[String] = left.references ++ right.references + override private[sql] def toV2 = new V2Or(left.toV2, right.toV2) } /** @@ -241,6 +305,7 @@ case class Or(left: Filter, right: Filter) extends Filter { @Stable case class Not(child: Filter) extends Filter { override def references: Array[String] = child.references + override private[sql] def toV2 = new V2Not(child.toV2) } /** @@ -255,6 +320,8 @@ case class Not(child: Filter) extends Filter { @Stable case class StringStartsWith(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) + override private[sql] def toV2 = new V2StringStartsWith(FieldReference(attribute), + UTF8String.fromString(value)) } /** @@ -269,6 +336,8 @@ case class StringStartsWith(attribute: String, value: String) extends Filter { @Stable case class StringEndsWith(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) + override private[sql] def toV2 = new V2StringEndsWith(FieldReference(attribute), + UTF8String.fromString(value)) } /** @@ -283,6 +352,8 @@ case class StringEndsWith(attribute: String, value: String) extends Filter { @Stable case class StringContains(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) + override private[sql] def toV2 = new V2StringContains(FieldReference(attribute), + UTF8String.fromString(value)) } /** @@ -293,6 +364,7 @@ case class StringContains(attribute: String, value: String) extends Filter { @Evolving case class AlwaysTrue() extends Filter { override def references: Array[String] = Array.empty + override private[sql] def toV2 = new V2AlwaysTrue() } @Evolving @@ -307,6 +379,7 @@ object AlwaysTrue extends AlwaysTrue { @Evolving case class AlwaysFalse() extends Filter { override def 
references: Array[String] = Array.empty + override private[sql] def toV2 = new V2AlwaysFalse() } @Evolving diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 5abfa4cc9ef0d..0545798615529 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -26,10 +26,13 @@ import org.apache.hadoop.hive.ql.io.sarg.SearchArgumentFactory.newBuilder import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp} +import org.apache.spark.sql.connector.expressions.Literal +import org.apache.spark.sql.connector.expressions.filter.{And => V2And, EqualNullSafe => V2EqualNullSafe, EqualTo => V2EqualTo, Filter => V2Filter, GreaterThan => V2GreaterThan, GreaterThanOrEqual => V2GreaterThanOrEqual, In => V2In, IsNotNull => V2IsNotNull, IsNull => V2IsNull, LessThan => V2LessThan, LessThanOrEqual => V2LessThanOrEqual, Not => V2Not, Or => V2Or} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Helper object for building ORC `SearchArgument`s, which are used for ORC predicate push-down. @@ -65,9 +68,17 @@ import org.apache.spark.sql.types._ private[sql] object OrcFilters extends OrcFiltersBase { /** - * Create ORC filter as a SearchArgument instance. + * Create ORC filter as a SearchArgument instance from V1 Filters. */ - def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = { + def createFilter(schema: StructType, filters: Seq[Filter])(implicit d: DummyImplicit) + : Option[SearchArgument] = { + createFilter(schema, filters.map(_.toV2)) + } + + /** + * Create ORC filter as a SearchArgument instance from V2 Filters. + */ + def createFilter(schema: StructType, filters: Seq[V2Filter]): Option[SearchArgument] = { val dataTypeMap = OrcFilters.getSearchableTypeMap(schema, SQLConf.get.caseSensitiveAnalysis) // Combines all convertible filters using `And` to produce a single conjunction val conjunctionOptional = buildTree(convertibleFilters(dataTypeMap, filters)) @@ -81,12 +92,11 @@ private[sql] object OrcFilters extends OrcFiltersBase { def convertibleFilters( dataTypeMap: Map[String, OrcPrimitiveField], - filters: Seq[Filter]): Seq[Filter] = { - import org.apache.spark.sql.sources._ + filters: Seq[V2Filter]): Seq[V2Filter] = { def convertibleFiltersHelper( - filter: Filter, - canPartialPushDown: Boolean): Option[Filter] = filter match { + filter: V2Filter, + canPartialPushDown: Boolean): Option[V2Filter] = filter match { // At here, it is not safe to just convert one side and remove the other side // if we do not understand what the parent filters are. // @@ -98,11 +108,11 @@ private[sql] object OrcFilters extends OrcFiltersBase { // Pushing one side of AND down is only safe to do at the top level or in the child // AND before hitting NOT or OR conditions, and in this case, the unsupported predicate // can be safely removed. 
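The comment above is the crux of the `canPartialPushDown` flag: keeping only the convertible leg of an `AND` is safe at the top level, but not underneath a `NOT` (or inside an `OR`), because the approximation may drop rows the full predicate would keep. A plain-Boolean sanity check of that claim (standalone Scala, no Spark required):

```scala
object PartialPushdownSafetySketch {
  def main(args: Array[String]): Unit = {
    // Model: p is the convertible leg, q the unconvertible one.
    for (p <- Seq(true, false); q <- Seq(true, false)) {
      // Top level: replacing (p AND q) with just p only lets extra rows through,
      // and the post-scan filter removes them again -- safe.
      assert(!(p && q) || p)
      // Under NOT: replacing NOT(p AND q) with NOT(p) can drop rows the query
      // needs, e.g. p = true, q = false keeps the row but NOT(p) discards it.
      val full = !(p && q)
      val pushed = !p
      println(s"p=$p q=$q full=$full pushed=$pushed lostRow=${full && !pushed}")
    }
  }
}
```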
- case And(left, right) => - val leftResultOptional = convertibleFiltersHelper(left, canPartialPushDown) - val rightResultOptional = convertibleFiltersHelper(right, canPartialPushDown) + case f: V2And => + val leftResultOptional = convertibleFiltersHelper(f.left, canPartialPushDown) + val rightResultOptional = convertibleFiltersHelper(f.right, canPartialPushDown) (leftResultOptional, rightResultOptional) match { - case (Some(leftResult), Some(rightResult)) => Some(And(leftResult, rightResult)) + case (Some(leftResult), Some(rightResult)) => Some(new V2And(leftResult, rightResult)) case (Some(leftResult), None) if canPartialPushDown => Some(leftResult) case (None, Some(rightResult)) if canPartialPushDown => Some(rightResult) case _ => None @@ -119,14 +129,14 @@ private[sql] object OrcFilters extends OrcFiltersBase { // The predicate can be converted as // (a1 OR b1) AND (a1 OR b2) AND (a2 OR b1) AND (a2 OR b2) // As per the logical in And predicate, we can push down (a1 OR b1). - case Or(left, right) => + case f: V2Or => for { - lhs <- convertibleFiltersHelper(left, canPartialPushDown) - rhs <- convertibleFiltersHelper(right, canPartialPushDown) - } yield Or(lhs, rhs) - case Not(pred) => - val childResultOptional = convertibleFiltersHelper(pred, canPartialPushDown = false) - childResultOptional.map(Not) + lhs <- convertibleFiltersHelper(f.left, canPartialPushDown) + rhs <- convertibleFiltersHelper(f.right, canPartialPushDown) + } yield new V2Or(lhs, rhs) + case f: V2Not => + val childResultOptional = convertibleFiltersHelper(f.child, canPartialPushDown = false) + childResultOptional.map(new V2Not(_)) case other => for (_ <- buildLeafSearchArgument(dataTypeMap, other, newBuilder())) yield other } @@ -157,16 +167,32 @@ private[sql] object OrcFilters extends OrcFiltersBase { */ private def castLiteralValue(value: Any, dataType: DataType): Any = dataType match { case ByteType | ShortType | IntegerType | LongType => - value.asInstanceOf[Number].longValue + value.asInstanceOf[Literal[_]].value.asInstanceOf[Number].longValue case FloatType | DoubleType => - value.asInstanceOf[Number].doubleValue() + value.asInstanceOf[Literal[_]].value.asInstanceOf[Number].doubleValue() + case _: DecimalType + if value.asInstanceOf[Literal[_]].value.isInstanceOf[java.math.BigDecimal] => + new HiveDecimalWritable(HiveDecimal.create + (value.asInstanceOf[Literal[_]].value.asInstanceOf[java.math.BigDecimal])) case _: DecimalType => - new HiveDecimalWritable(HiveDecimal.create(value.asInstanceOf[java.math.BigDecimal])) - case _: DateType if value.isInstanceOf[LocalDate] => - toJavaDate(localDateToDays(value.asInstanceOf[LocalDate])) - case _: TimestampType if value.isInstanceOf[Instant] => - toJavaTimestamp(instantToMicros(value.asInstanceOf[Instant])) - case _ => value + new HiveDecimalWritable(HiveDecimal.create + (value.asInstanceOf[Literal[_]].value.asInstanceOf[Decimal].toJavaBigDecimal)) + case _: DateType if value.asInstanceOf[Literal[_]].value.isInstanceOf[LocalDate] => + toJavaDate(localDateToDays(value.asInstanceOf[Literal[_]].value.asInstanceOf[LocalDate])) + case _: DateType if value.asInstanceOf[Literal[_]].value.isInstanceOf[Integer] => + toJavaDate(value.asInstanceOf[Literal[_]].value.asInstanceOf[Integer]) + case _: TimestampType if value.asInstanceOf[Literal[_]].value.isInstanceOf[Instant] => + toJavaTimestamp(instantToMicros(value.asInstanceOf[Literal[_]].value.asInstanceOf[Instant])) + case _: TimestampType if value.asInstanceOf[Literal[_]].value.isInstanceOf[Long] => + 
toJavaTimestamp(value.asInstanceOf[Literal[_]].value.asInstanceOf[Long]) + case StringType => + val str = value.asInstanceOf[Literal[_]].value + if(str.isInstanceOf[UTF8String]) { + str.asInstanceOf[UTF8String].toString + } else { + str + } + case _ => value.asInstanceOf[Literal[_]].value } /** @@ -179,23 +205,22 @@ private[sql] object OrcFilters extends OrcFiltersBase { */ private def buildSearchArgument( dataTypeMap: Map[String, OrcPrimitiveField], - expression: Filter, + expression: V2Filter, builder: Builder): Builder = { - import org.apache.spark.sql.sources._ expression match { - case And(left, right) => - val lhs = buildSearchArgument(dataTypeMap, left, builder.startAnd()) - val rhs = buildSearchArgument(dataTypeMap, right, lhs) + case f: V2And => + val lhs = buildSearchArgument(dataTypeMap, f.left, builder.startAnd()) + val rhs = buildSearchArgument(dataTypeMap, f.right, lhs) rhs.end() - case Or(left, right) => - val lhs = buildSearchArgument(dataTypeMap, left, builder.startOr()) - val rhs = buildSearchArgument(dataTypeMap, right, lhs) + case f: V2Or => + val lhs = buildSearchArgument(dataTypeMap, f.left, builder.startOr()) + val rhs = buildSearchArgument(dataTypeMap, f.right, lhs) rhs.end() - case Not(child) => - buildSearchArgument(dataTypeMap, child, builder.startNot()).end() + case f: V2Not => + buildSearchArgument(dataTypeMap, f.child, builder.startNot()).end() case other => buildLeafSearchArgument(dataTypeMap, other, builder).getOrElse { @@ -215,58 +240,67 @@ private[sql] object OrcFilters extends OrcFiltersBase { */ private def buildLeafSearchArgument( dataTypeMap: Map[String, OrcPrimitiveField], - expression: Filter, + expression: V2Filter, builder: Builder): Option[Builder] = { def getType(attribute: String): PredicateLeaf.Type = getPredicateLeafType(dataTypeMap(attribute).fieldType) - import org.apache.spark.sql.sources._ - // NOTE: For all case branches dealing with leaf predicates below, the additional `startAnd()` // call is mandatory. ORC `SearchArgument` builder requires that all leaf predicates must be // wrapped by a "parent" predicate (`And`, `Or`, or `Not`). 
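As the comment above notes, ORC's `SearchArgument` builder only accepts leaf predicates inside a parent node, which is why every branch below opens with `startAnd()` or `startNot()`. A small standalone sketch of the same pattern, assuming the Hive storage-api classes already used by `OrcFilters` are on the classpath:

```scala
import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgumentFactory}

object SearchArgumentSketch {
  def main(args: Array[String]): Unit = {
    // `value > 2` is encoded as NOT(value <= 2); the lone leaf is wrapped in
    // startNot()/end(), mirroring the GreaterThan branch of buildLeafSearchArgument.
    val sarg = SearchArgumentFactory.newBuilder()
      .startNot()
      .lessThanEquals("value", PredicateLeaf.Type.LONG, java.lang.Long.valueOf(2L))
      .end()
      .build()
    println(sarg)  // prints the leaf and the (not leaf-0) expression tree
  }
}
```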
expression match { - case EqualTo(name, value) if dataTypeMap.contains(name) => - val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType) - Some(builder.startAnd() - .equals(dataTypeMap(name).fieldName, getType(name), castedValue).end()) + case f: V2EqualTo if dataTypeMap.contains(f.column.describe) => + val colName = f.column.describe + val castedValue = castLiteralValue(f.value, dataTypeMap(colName).fieldType) + Some(builder.startAnd().equals( + dataTypeMap(colName).fieldName, getType(colName), castedValue).end()) - case EqualNullSafe(name, value) if dataTypeMap.contains(name) => - val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType) - Some(builder.startAnd() - .nullSafeEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end()) + case f: V2EqualNullSafe + if dataTypeMap.contains(f.column.describe) => + val colName = f.column.describe + val castedValue = castLiteralValue(f.value, dataTypeMap(colName).fieldType) + Some(builder.startAnd().nullSafeEquals( + dataTypeMap(colName).fieldName, getType(colName), castedValue).end()) - case LessThan(name, value) if dataTypeMap.contains(name) => - val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType) - Some(builder.startAnd() - .lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end()) + case f: V2LessThan if dataTypeMap.contains(f.column.describe) => + val colName = f.column.describe + val castedValue = castLiteralValue(f.value, dataTypeMap(colName).fieldType) + Some(builder.startAnd().lessThan( + dataTypeMap(colName).fieldName, getType(colName), castedValue).end()) - case LessThanOrEqual(name, value) if dataTypeMap.contains(name) => - val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType) - Some(builder.startAnd() - .lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end()) + case f: V2LessThanOrEqual + if dataTypeMap.contains(f.column.describe) => + val colName = f.column.describe + val castedValue = castLiteralValue(f.value, dataTypeMap(colName).fieldType) + Some(builder.startAnd().lessThanEquals( + dataTypeMap(colName).fieldName, getType(colName), castedValue).end()) - case GreaterThan(name, value) if dataTypeMap.contains(name) => - val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType) - Some(builder.startNot() - .lessThanEquals(dataTypeMap(name).fieldName, getType(name), castedValue).end()) + case f: V2GreaterThan + if dataTypeMap.contains(f.column.describe) => + val colName = f.column.describe + val castedValue = castLiteralValue(f.value, dataTypeMap(colName).fieldType) + Some(builder.startNot().lessThanEquals( + dataTypeMap(colName).fieldName, getType(colName), castedValue).end()) - case GreaterThanOrEqual(name, value) if dataTypeMap.contains(name) => - val castedValue = castLiteralValue(value, dataTypeMap(name).fieldType) - Some(builder.startNot() - .lessThan(dataTypeMap(name).fieldName, getType(name), castedValue).end()) + case f: V2GreaterThanOrEqual + if dataTypeMap.contains(f.column.describe) => + val colName = f.column.describe + val castedValue = castLiteralValue(f.value, dataTypeMap(colName).fieldType) + Some(builder.startNot().lessThan( + dataTypeMap(colName).fieldName, getType(colName), castedValue).end()) - case IsNull(name) if dataTypeMap.contains(name) => - Some(builder.startAnd() - .isNull(dataTypeMap(name).fieldName, getType(name)).end()) + case f: V2IsNull if dataTypeMap.contains(f.column.describe) => + val colName = f.column.describe + Some(builder.startAnd().isNull(dataTypeMap(colName).fieldName, 
getType(colName)).end()) - case IsNotNull(name) if dataTypeMap.contains(name) => - Some(builder.startNot() - .isNull(dataTypeMap(name).fieldName, getType(name)).end()) + case f: V2IsNotNull if dataTypeMap.contains(f.column.describe) => + val colName = f.column.describe + Some(builder.startNot().isNull(dataTypeMap(colName).fieldName, getType(colName)).end()) - case In(name, values) if dataTypeMap.contains(name) => - val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(name).fieldType)) - Some(builder.startAnd().in(dataTypeMap(name).fieldName, getType(name), + case f: V2In if dataTypeMap.contains(f.column.describe) => + val colName = f.column.describe + val castedValues = f.values.map(v => castLiteralValue(v, dataTypeMap(colName).fieldType)) + Some(builder.startAnd().in(dataTypeMap(colName).fieldName, getType(colName), castedValues.map(_.asInstanceOf[AnyRef]): _*).end()) case _ => None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala index b7de20ae29349..31ae93775a717 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFiltersBase.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.orc import java.util.Locale import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.sources.{And, Filter} +import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Filter => V2Filter} import org.apache.spark.sql.types.{AtomicType, BinaryType, DataType, StructField, StructType} /** @@ -28,14 +28,14 @@ import org.apache.spark.sql.types.{AtomicType, BinaryType, DataType, StructField */ trait OrcFiltersBase { - private[sql] def buildTree(filters: Seq[Filter]): Option[Filter] = { + private[sql] def buildTree(filters: Seq[V2Filter]): Option[V2Filter] = { filters match { case Seq() => None case Seq(filter) => Some(filter) - case Seq(filter1, filter2) => Some(And(filter1, filter2)) + case Seq(filter1, filter2) => Some(new V2And(filter1, filter2)) case _ => // length > 2 val (left, right) = filters.splitAt(filters.length / 2) - Some(And(buildTree(left).get, buildTree(right).get)) + Some(new V2And(buildTree(left).get, buildTree(right).get)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index ae8a092592c06..0e66ff32ac497 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -18,24 +18,30 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.JavaConverters._ +import scala.collection.mutable import org.apache.spark.sql.{SparkSession, Strategy} import org.apache.spark.sql.catalyst.analysis.{ResolvedDBObjectName, ResolvedNamespace, ResolvedPartitionSpec, ResolvedTable} -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning, Expression, NamedExpression, PredicateHelper, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning, EmptyRow, Expression, Literal, NamedExpression, PredicateHelper, 
SubqueryExpression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, Table, TableCapability, TableCatalog} +import org.apache.spark.sql.connector.expressions.{FieldReference, Literal => V2Literal, LiteralValue} +import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse => V2AlwaysFalse, AlwaysTrue => V2AlwaysTrue, And => V2And, EqualNullSafe => V2EqualNullSafe, EqualTo => V2EqualTo, Filter => V2Filter, GreaterThan => V2GreaterThan, GreaterThanOrEqual => V2GreaterThanOrEqual, In => V2In, IsNotNull => V2IsNotNull, IsNull => V2IsNull, LessThan => V2LessThan, LessThanOrEqual => V2LessThanOrEqual, Not => V2Not, Or => V2Or, StringContains => V2StringContains, StringEndsWith => V2StringEndsWith, StringStartsWith => V2StringStartsWith} import org.apache.spark.sql.connector.read.LocalScan import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.connector.write.V1Write import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.{FilterExec, LeafExecNode, LocalTableScanExec, ProjectExec, RowDataSourceScanExec, SparkPlan} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PushableColumn, PushableColumnBase} import org.apache.spark.sql.execution.streaming.continuous.{WriteToContinuousDataSource, WriteToContinuousDataSourceExec} import org.apache.spark.sql.sources.{BaseRelation, TableScan} +import org.apache.spark.sql.types.{BooleanType, StringType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.storage.StorageLevel +import org.apache.spark.unsafe.types.UTF8String class DataSourceV2Strategy(session: SparkSession) extends Strategy with PredicateHelper { @@ -427,3 +433,157 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case _ => Nil } } + +private[sql] object DataSourceV2Strategy { + + private def translateLeafNodeFilterV2( + predicate: Expression, + pushableColumn: PushableColumnBase): Option[V2Filter] = predicate match { + case expressions.EqualTo(pushableColumn(name), Literal(v, t)) => + Some(new V2EqualTo(FieldReference(name), LiteralValue(v, t))) + case expressions.EqualTo(Literal(v, t), pushableColumn(name)) => + Some(new V2EqualTo(FieldReference(name), LiteralValue(v, t))) + + case expressions.EqualNullSafe(pushableColumn(name), Literal(v, t)) => + Some(new V2EqualNullSafe(FieldReference(name), LiteralValue(v, t))) + case expressions.EqualNullSafe(Literal(v, t), pushableColumn(name)) => + Some(new V2EqualNullSafe(FieldReference(name), LiteralValue(v, t))) + + case expressions.GreaterThan(pushableColumn(name), Literal(v, t)) => + Some(new V2GreaterThan(FieldReference(name), LiteralValue(v, t))) + case expressions.GreaterThan(Literal(v, t), pushableColumn(name)) => + Some(new V2LessThan(FieldReference(name), LiteralValue(v, t))) + + case expressions.LessThan(pushableColumn(name), Literal(v, t)) => + Some(new V2LessThan(FieldReference(name), LiteralValue(v, t))) + case expressions.LessThan(Literal(v, t), pushableColumn(name)) => + Some(new V2GreaterThan(FieldReference(name), LiteralValue(v, t))) + + case 
expressions.GreaterThanOrEqual(pushableColumn(name), Literal(v, t)) => + Some(new V2GreaterThanOrEqual(FieldReference(name), LiteralValue(v, t))) + case expressions.GreaterThanOrEqual(Literal(v, t), pushableColumn(name)) => + Some(new V2LessThanOrEqual(FieldReference(name), LiteralValue(v, t))) + + case expressions.LessThanOrEqual(pushableColumn(name), Literal(v, t)) => + Some(new V2LessThanOrEqual(FieldReference(name), LiteralValue(v, t))) + case expressions.LessThanOrEqual(Literal(v, t), pushableColumn(name)) => + Some(new V2GreaterThanOrEqual(FieldReference(name), LiteralValue(v, t))) + + case in @ expressions.InSet(pushableColumn(name), set) => + val values: Array[V2Literal[_]] = + set.toSeq.map(elem => LiteralValue(elem, in.dataType)).toArray + Some(new V2In(FieldReference(name), values)) + + // Because we only convert In to InSet in Optimizer when there are more than certain + // items. So it is possible we still get an In expression here that needs to be pushed + // down. + case in @ expressions.In(pushableColumn(name), list) if list.forall(_.isInstanceOf[Literal]) => + val hSet = list.map(_.eval(EmptyRow)) + Some(new V2In(FieldReference(name), + hSet.toArray.map(LiteralValue(_, in.value.dataType)))) + + case expressions.IsNull(pushableColumn(name)) => + Some(new V2IsNull(FieldReference(name))) + case expressions.IsNotNull(pushableColumn(name)) => + Some(new V2IsNotNull(FieldReference(name))) + case expressions.StartsWith(pushableColumn(name), Literal(v: UTF8String, StringType)) => + Some(new V2StringStartsWith(FieldReference(name), v)) + + case expressions.EndsWith(pushableColumn(name), Literal(v: UTF8String, StringType)) => + Some(new V2StringEndsWith(FieldReference(name), v)) + + case expressions.Contains(pushableColumn(name), Literal(v: UTF8String, StringType)) => + Some(new V2StringContains(FieldReference(name), v)) + + case expressions.Literal(true, BooleanType) => + Some(new V2AlwaysTrue) + + case expressions.Literal(false, BooleanType) => + Some(new V2AlwaysFalse) + + case _ => None + } + + /** + * Tries to translate a Catalyst [[Expression]] into data source [[Filter]]. + * + * @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`. + */ + protected[sql] def translateFilterV2( + predicate: Expression, + supportNestedPredicatePushdown: Boolean): Option[V2Filter] = { + translateFilterV2WithMapping(predicate, None, supportNestedPredicatePushdown) + } + + /** + * Tries to translate a Catalyst [[Expression]] into data source [[Filter]]. + * + * @param predicate The input [[Expression]] to be translated as [[Filter]] + * @param translatedFilterToExpr An optional map from leaf node filter expressions to its + * translated [[Filter]]. The map is used for rebuilding + * [[Expression]] from [[Filter]]. + * @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`. + */ + protected[sql] def translateFilterV2WithMapping( + predicate: Expression, + translatedFilterToExpr: Option[mutable.HashMap[V2Filter, Expression]], + nestedPredicatePushdownEnabled: Boolean) + : Option[V2Filter] = { + predicate match { + case expressions.And(left, right) => + // See SPARK-12218 for detailed discussion + // It is not safe to just convert one side if we do not understand the + // other side. Here is an example used to explain the reason. + // Let's say we have (a = 2 AND trim(b) = 'blah') OR (c > 0) + // and we do not understand how to convert trim(b) = 'blah'. 
+ // If we only convert a = 2, we will end up with + // (a = 2) OR (c > 0), which will generate wrong results. + // Pushing one leg of AND down is only safe to do at the top level. + // You can see ParquetFilters' createFilter for more details. + for { + leftFilter <- translateFilterV2WithMapping( + left, translatedFilterToExpr, nestedPredicatePushdownEnabled) + rightFilter <- translateFilterV2WithMapping( + right, translatedFilterToExpr, nestedPredicatePushdownEnabled) + } yield new V2And(leftFilter, rightFilter) + + case expressions.Or(left, right) => + for { + leftFilter <- translateFilterV2WithMapping( + left, translatedFilterToExpr, nestedPredicatePushdownEnabled) + rightFilter <- translateFilterV2WithMapping( + right, translatedFilterToExpr, nestedPredicatePushdownEnabled) + } yield new V2Or(leftFilter, rightFilter) + + case expressions.Not(child) => + translateFilterV2WithMapping(child, translatedFilterToExpr, nestedPredicatePushdownEnabled) + .map(new V2Not(_)) + + case other => + val filter = translateLeafNodeFilterV2( + other, PushableColumn(nestedPredicatePushdownEnabled)) + if (filter.isDefined && translatedFilterToExpr.isDefined) { + translatedFilterToExpr.get(filter.get) = predicate + } + filter + } + } + + protected[sql] def rebuildExpressionFromFilter( + filter: V2Filter, + translatedFilterToExpr: mutable.HashMap[V2Filter, Expression]): Expression = { + filter match { + case and: V2And => + expressions.And(rebuildExpressionFromFilter(and.left, translatedFilterToExpr), + rebuildExpressionFromFilter(and.right, translatedFilterToExpr)) + case or: V2Or => + expressions.Or(rebuildExpressionFromFilter(or.left, translatedFilterToExpr), + rebuildExpressionFromFilter(or.right, translatedFilterToExpr)) + case not: V2Not => + expressions.Not(rebuildExpressionFromFilter(not.child, translatedFilterToExpr)) + case other => + translatedFilterToExpr.getOrElse(other, + throw QueryCompilationErrors.failedToRebuildExpressionError(filter)) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala index 8b0328cabc5a8..b43800239b686 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala @@ -27,12 +27,12 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression, ExpressionSet} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} import org.apache.spark.sql.connector.read.{Batch, InputPartition, Scan, Statistics, SupportsReportStatistics} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.PartitionedFileUtil import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.internal.connector.SupportsMetadata -import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -200,7 +200,7 @@ trait FileScan extends Scan StructType(readDataSchema.fields ++ readPartitionSchema.fields) // Returns whether the two given arrays of [[Filter]]s are equivalent. 
- protected def equivalentFilters(a: Array[Filter], b: Array[Filter]): Boolean = { + protected def equivalentFilters(a: Array[V2Filter], b: Array[V2Filter]): Boolean = { a.sortBy(_.hashCode()).sameElements(b.sortBy(_.hashCode())) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala index 309f045201140..6c92c1d7f8321 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala @@ -18,12 +18,12 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.{sources, SparkSession} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} import org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownRequiredColumns} -import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, DataSourceUtils, PartitioningAwareFileIndex, PartitioningUtils} +import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitioningAwareFileIndex, PartitioningUtils} import org.apache.spark.sql.internal.connector.SupportsPushDownCatalystFilters -import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType abstract class FileScanBuilder( @@ -39,7 +39,7 @@ abstract class FileScanBuilder( protected var requiredSchema = StructType(dataSchema.fields ++ partitionSchema.fields) protected var partitionFilters = Seq.empty[Expression] protected var dataFilters = Seq.empty[Expression] - protected var pushedDataFilters = Array.empty[Filter] + protected var pushedDataFilters = Array.empty[V2Filter] override def pruneColumns(requiredSchema: StructType): Unit = { // [SPARK-30107] While `requiredSchema` might have pruned nested columns, @@ -73,9 +73,9 @@ abstract class FileScanBuilder( DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, filters) this.partitionFilters = partitionFilters this.dataFilters = dataFilters - val translatedFilters = mutable.ArrayBuffer.empty[sources.Filter] + val translatedFilters = mutable.ArrayBuffer.empty[V2Filter] for (filterExpr <- dataFilters) { - val translated = DataSourceStrategy.translateFilter(filterExpr, true) + val translated = DataSourceV2Strategy.translateFilterV2(filterExpr, true) if (translated.nonEmpty) { translatedFilters += translated.get } @@ -84,14 +84,15 @@ abstract class FileScanBuilder( dataFilters } - override def pushedFilters: Array[Filter] = pushedDataFilters + override def pushedFilters: Array[V2Filter] = pushedDataFilters /* * Push down data filters to the file source, so the data filters can be evaluated there to * reduce the size of the data to be read. By default, data filters are not pushed down. * File source needs to implement this method to push down data filters. 
*/ - protected def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = Array.empty[Filter] + protected def pushDataFilters(dataFilters: Array[V2Filter]): Array[V2Filter] = + Array.empty[V2Filter] private def createRequiredNameSet(): Set[String] = requiredSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index 7229488026bc5..65cc1380d52bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -71,7 +71,7 @@ object PushDownUtils extends PredicateHelper { case f: FileScanBuilder => val postScanFilters = f.pushFilters(filters) - (f.pushedFilters, postScanFilters) + (f.pushedFilters.map(_.toV1).toSeq, postScanFilters) case _ => (Nil, filters) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala index 31d31bd43f453..23badc05be2cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVPartitionReaderFactory.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.execution.datasources.v2.csv import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser} +import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} import org.apache.spark.sql.connector.read.PartitionReader import org.apache.spark.sql.execution.datasources.PartitionedFile import org.apache.spark.sql.execution.datasources.csv.CSVDataSource import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration @@ -45,7 +45,7 @@ case class CSVPartitionReaderFactory( readDataSchema: StructType, partitionSchema: StructType, parsedOptions: CSVOptions, - filters: Seq[Filter]) extends FilePartitionReaderFactory { + filters: Seq[V2Filter]) extends FilePartitionReaderFactory { private val columnPruning = sqlConf.csvColumnPruning override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = { @@ -58,7 +58,7 @@ case class CSVPartitionReaderFactory( actualDataSchema, actualReadDataSchema, parsedOptions, - filters) + filters.map(_.toV1)) val schema = if (columnPruning) actualReadDataSchema else actualDataSchema val isStartOfFile = file.start == 0 val headerChecker = new CSVHeaderChecker( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala index cc3c146106670..46990316e0ec6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala @@ -23,12 +23,12 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.SparkSession import 
org.apache.spark.sql.catalyst.csv.CSVOptions import org.apache.spark.sql.catalyst.expressions.{Expression, ExprUtils} +import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.csv.CSVDataSource import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan -import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration @@ -40,7 +40,7 @@ case class CSVScan( readDataSchema: StructType, readPartitionSchema: StructType, options: CaseInsensitiveStringMap, - pushedFilters: Array[Filter], + pushedFilters: Array[V2Filter], partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty) extends TextBasedFileScan(sparkSession, options) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala index 2b6edd4f357ca..fbe81fc0e0fb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.execution.datasources.v2.csv import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.StructFilters +import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder -import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -47,11 +47,11 @@ case class CSVScanBuilder( dataFilters) } - override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = { + override def pushDataFilters(dataFilters: Array[V2Filter]): Array[V2Filter] = { if (sparkSession.sessionState.conf.csvFilterPushDown) { - StructFilters.pushedFilters(dataFilters, dataSchema) + StructFilters.pushedFiltersV2(dataFilters, dataSchema) } else { - Array.empty[Filter] + Array.empty[V2Filter] } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonPartitionReaderFactory.scala index 9737803b597a5..eee428ea15795 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonPartitionReaderFactory.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.execution.datasources.v2.json import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.{JacksonParser, JSONOptionsInRead} +import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} import org.apache.spark.sql.connector.read.PartitionReader import org.apache.spark.sql.execution.datasources.PartitionedFile import 
org.apache.spark.sql.execution.datasources.json.JsonDataSource import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration @@ -46,7 +46,7 @@ case class JsonPartitionReaderFactory( readDataSchema: StructType, partitionSchema: StructType, parsedOptions: JSONOptionsInRead, - filters: Seq[Filter]) extends FilePartitionReaderFactory { + filters: Seq[V2Filter]) extends FilePartitionReaderFactory { override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] = { val actualSchema = @@ -55,7 +55,7 @@ case class JsonPartitionReaderFactory( actualSchema, parsedOptions, allowArrayAsStructs = true, - filters) + filters.map((_.toV1))) val iter = JsonDataSource(parsedOptions).readFile( broadcastedConf.value.value, partitionedFile, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala index 9ab367136fc97..39b41c1a8aadd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala @@ -24,12 +24,12 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{Expression, ExprUtils} import org.apache.spark.sql.catalyst.json.JSONOptionsInRead import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.json.JsonDataSource import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan -import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration @@ -41,7 +41,7 @@ case class JsonScan( readDataSchema: StructType, readPartitionSchema: StructType, options: CaseInsensitiveStringMap, - pushedFilters: Array[Filter], + pushedFilters: Array[V2Filter], partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty) extends TextBasedFileScan(sparkSession, options) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala index c581617a4b7e4..57a09a9c6043e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.execution.datasources.v2.json import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.StructFilters +import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder -import org.apache.spark.sql.sources.Filter import 
org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -45,11 +45,11 @@ class JsonScanBuilder ( dataFilters) } - override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = { + override def pushDataFilters(dataFilters: Array[V2Filter]): Array[V2Filter] = { if (sparkSession.sessionState.conf.jsonFilterPushDown) { - StructFilters.pushedFilters(dataFilters, dataSchema) + StructFilters.pushedFiltersV2(dataFilters, dataSchema) } else { - Array.empty[Filter] + Array.empty[V2Filter] } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala index c5020cb79524c..45386b75fe818 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala @@ -29,13 +29,13 @@ import org.apache.orc.mapreduce.OrcInputFormat import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader} import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.datasources.PartitionedFile import org.apache.spark.sql.execution.datasources.orc.{OrcColumnarBatchReader, OrcDeserializer, OrcFilters, OrcUtils} import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -55,7 +55,7 @@ case class OrcPartitionReaderFactory( dataSchema: StructType, readDataSchema: StructType, partitionSchema: StructType, - filters: Array[Filter]) extends FilePartitionReaderFactory { + filters: Array[V2Filter]) extends FilePartitionReaderFactory { private val resultSchema = StructType(readDataSchema.fields ++ partitionSchema.fields) private val isCaseSensitive = sqlConf.caseSensitiveAnalysis private val capacity = sqlConf.orcVectorizedReaderBatchSize diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala index 7619e3c503139..b8d147928dc5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala @@ -21,10 +21,10 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScan -import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration @@ -37,7 +37,7 @@ case class OrcScan( readDataSchema: StructType, readPartitionSchema: StructType, options: 
CaseInsensitiveStringMap, - pushedFilters: Array[Filter], + pushedFilters: Array[V2Filter], partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty) extends FileScan { override def isSplitable(path: Path): Boolean = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala index cfa396f5482f4..af35dff286f4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.execution.datasources.v2.orc import scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.orc.OrcFilters import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -49,13 +49,13 @@ case class OrcScanBuilder( readPartitionSchema(), options, pushedDataFilters, partitionFilters, dataFilters) } - override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = { + override def pushDataFilters(dataFilters: Array[V2Filter]): Array[V2Filter] = { if (sparkSession.sessionState.conf.orcFilterPushDown) { val dataTypeMap = OrcFilters.getSearchableTypeMap( readDataSchema(), SQLConf.get.caseSensitiveAnalysis) OrcFilters.convertibleFilters(dataTypeMap, dataFilters).toArray } else { - Array.empty[Filter] + Array.empty[V2Filter] } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index e277e334845c9..e6ced640be54f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -24,12 +24,12 @@ import org.apache.parquet.hadoop.ParquetInputFormat import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport} import org.apache.spark.sql.execution.datasources.v2.FileScan import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration @@ -41,7 +41,7 @@ case class ParquetScan( dataSchema: StructType, readDataSchema: StructType, readPartitionSchema: StructType, - pushedFilters: Array[Filter], + pushedFilters: Array[V2Filter], options: CaseInsensitiveStringMap, partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = 
Seq.empty) extends FileScan { @@ -85,7 +85,7 @@ case class ParquetScan( dataSchema, readDataSchema, readPartitionSchema, - pushedFilters, + pushedFilters.map(_.toV1), new ParquetOptions(options.asCaseSensitiveMap.asScala.toMap, sqlConf)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index ff5137e928db3..c2cef816ed295 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.execution.datasources.v2.parquet import scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter} import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy -import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -63,20 +63,21 @@ case class ParquetScanBuilder( // The rebase mode doesn't matter here because the filters are used to determine // whether they is convertible. LegacyBehaviorPolicy.CORRECTED) - parquetFilters.convertibleFilters(pushedDataFilters).toArray + parquetFilters.convertibleFilters(pushedDataFilters.map(_.toV1)).toArray } override protected val supportsNestedSchemaPruning: Boolean = true - override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = dataFilters + override def pushDataFilters(dataFilters: Array[V2Filter]): Array[V2Filter] = dataFilters // Note: for Parquet, the actual filter push down happens in [[ParquetPartitionReaderFactory]]. // It requires the Parquet physical schema to determine whether a filter is convertible. // All filters that can be converted to Parquet are pushed down. 
- override def pushedFilters(): Array[Filter] = pushedParquetFilters + override def pushedFilters(): Array[V2Filter] = pushedParquetFilters.map(_.toV2) override def build(): Scan = { ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(), - readPartitionSchema(), pushedParquetFilters, options, partitionFilters, dataFilters) + readPartitionSchema(), pushedParquetFilters.map(_.toV2), options, partitionFilters, + dataFilters) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala index 688ded0c3eb39..1e2af1522756a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala @@ -458,11 +458,11 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite val basePath = dir.getCanonicalPath + "/" + fmt val pushFilterMaps = Map ( "parquet" -> - "|PushedFilters: \\[IsNotNull\\(value\\), GreaterThan\\(value,2\\)\\]", + "|PushedFilters: \\[value IS NOT NULL, value > 2\\]", "orc" -> - "|PushedFilters: \\[IsNotNull\\(value\\), GreaterThan\\(value,2\\)\\]", + "|PushedFilters: \\[value IS NOT NULL, value > 2\\]", "csv" -> - "|PushedFilters: \\[IsNotNull\\(value\\), GreaterThan\\(value,2\\)\\]", + "|PushedFilters: \\[value IS NOT NULL, value > 2\\]", "json" -> "|remove_marker" ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index ab28b0594c404..85a9b40a993f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -887,10 +887,10 @@ class FileBasedDataSourceSuite extends QueryTest format match { case "orc" => assert(scan.isInstanceOf[OrcScan]) - assert(scan.asInstanceOf[OrcScan].pushedFilters === filters) + assert(scan.asInstanceOf[OrcScan].pushedFilters.map(_.toV1) === filters) case "parquet" => assert(scan.isInstanceOf[ParquetScan]) - assert(scan.asInstanceOf[ParquetScan].pushedFilters === filters) + assert(scan.asInstanceOf[ParquetScan].pushedFilters.map(_.toV1) === filters) case _ => fail(s"unknown format $format") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala index d0877dbf316c7..94c1308a9632d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala @@ -354,18 +354,18 @@ class FileScanSuite extends FileScanSuiteBase { val scanBuilders = Seq[(String, ScanBuilder, Seq[String])]( ("ParquetScan", (s, fi, ds, rds, rps, f, o, pf, df) => - ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, pf, df), + ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f.map(_.toV2), o, pf, df), Seq.empty), ("OrcScan", (s, fi, ds, rds, rps, f, o, pf, df) => - OrcScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, o, f, pf, df), + OrcScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, o, f.map(_.toV2), pf, df), Seq.empty), ("CSVScan", - (s, fi, ds, rds, rps, f, o, pf, df) => CSVScan(s, fi, ds, rds, rps, o, f, pf, df), + (s, fi, ds, rds, rps, f, o, pf, df) => CSVScan(s, fi, ds, rds, rps, o, f.map(_.toV2), pf, df), Seq.empty), ("JsonScan", - (s, fi, ds, rds, rps, f, o, pf, df) => JsonScan(s, fi, ds, rds, rps, o, f, pf, df), - Seq.empty), + 
(s, fi, ds, rds, rps, f, o, pf, df) => + JsonScan(s, fi, ds, rds, rps, o, f.map(_.toV2), pf, df), Seq.empty), ("TextScan", (s, fi, ds, rds, rps, _, o, pf, df) => TextScan(s, fi, ds, rds, rps, o, pf, df), Seq("dataSchema", "pushedFilters"))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 45591565a9522..a322965c7c567 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3032,10 +3032,10 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark format match { case "orc" => assert(scan.isInstanceOf[OrcScan]) - assert(scan.asInstanceOf[OrcScan].pushedFilters === filters) + assert(scan.asInstanceOf[OrcScan].pushedFilters.map(_.toV1) === filters) case "parquet" => assert(scan.isInstanceOf[ParquetScan]) - assert(scan.asInstanceOf[ParquetScan].pushedFilters === filters) + assert(scan.asInstanceOf[ParquetScan].pushedFilters.map(_.toV1) === filters) case _ => fail(s"unknown format $format") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index f7f1d0b847cc1..7bfe368e58e03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -34,6 +34,8 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{functions => F, _} import org.apache.spark.sql.catalyst.json._ import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils} +import org.apache.spark.sql.connector.expressions.FieldReference +import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter, IsNotNull} import org.apache.spark.sql.execution.ExternalRDD import org.apache.spark.sql.execution.datasources.{CommonFileDataSourceSuite, DataSource, InMemoryFileIndex, NoopCache} import org.apache.spark.sql.execution.datasources.v2.json.JsonScanBuilder @@ -3019,7 +3021,7 @@ class JsonV2Suite extends JsonSuite { val options = CaseInsensitiveStringMap.empty() new JsonScanBuilder(spark, fileIndex, schema, schema, options) } - val filters: Array[sources.Filter] = Array(sources.IsNotNull(attr)) + val filters: Array[V2Filter] = Array(new IsNotNull(new FieldReference(Array(attr)))) withSQLConf(SQLConf.JSON_FILTER_PUSHDOWN_ENABLED.key -> "true") { withTempPath { file => val scanBuilder = getBuilder(file.getCanonicalPath) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala index 4243318ac1dd8..da26f04d6474c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala @@ -125,7 +125,7 @@ abstract class OrcTest extends QueryTest with FileBasedDataSourceTest with Befor assert(o.pushedFilters.isEmpty, "Unsupported filters should not show in pushed filters") } else { assert(o.pushedFilters.nonEmpty, "No filter is pushed down") - val maybeFilter = OrcFilters.createFilter(query.schema, o.pushedFilters) + val maybeFilter = OrcFilters.createFilter(query.schema, o.pushedFilters.map(_.toV1)) assert(maybeFilter.isEmpty, s"Couldn't generate filter predicate for ${o.pushedFilters}") 
} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala index d0032df488f47..e7673ebeb08ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala @@ -55,7 +55,7 @@ class OrcV1FilterSuite extends OrcFilterSuite { DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) assert(selectedFilters.nonEmpty, "No filter is pushed down") - val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters) + val maybeFilter = OrcFilters.createFilter(query.schema, selectedFilters.map(_.toV2)) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $selectedFilters") checker(maybeFilter.get) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 8354158533ee5..2d51c416900ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -1930,9 +1930,10 @@ class ParquetV2FilterSuite extends ParquetFilterSuite { val schema = new SparkToParquetSchemaConverter(conf).convert(df.schema) val parquetFilters = createParquetFilters(schema) // In this test suite, all the simple predicates are convertible here. - assert(parquetFilters.convertibleFilters(sourceFilters) === pushedFilters) + assert(parquetFilters.convertibleFilters(sourceFilters).toArray + === pushedFilters.map(_.toV1)) val pushedParquetFilters = pushedFilters.map { pred => - val maybeFilter = parquetFilters.createFilter(pred) + val maybeFilter = parquetFilters.createFilter(pred.toV1) assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $pred") maybeFilter.get } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala index 325f4923bd6c6..4f5590a9129e1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFsSuite.scala @@ -161,7 +161,7 @@ class ExtractPythonUDFsSuite extends SparkPlanTest with SharedSparkSession { // 'a is not null and 'a > 1 val filters = scanNodes.head.scan.asInstanceOf[ParquetScan].pushedFilters assert(filters.length == 2) - assert(filters.flatMap(_.references).distinct === Array("a")) + assert(filters.flatMap(_.references.map(_.describe())).distinct === Array("a")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala index 33b2db57d9f0f..cb350f1ec339a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FiltersSuite.scala @@ -18,6 +18,11 @@ package org.apache.spark.sql.sources import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue, NamedReference} +import org.apache.spark.sql.connector.expressions.filter.{And => V2And, 
EqualNullSafe => V2EqualNullSafe, EqualTo => V2EqualTo, GreaterThan => V2GreaterThan, GreaterThanOrEqual => V2GreaterThanOrEqual, In => V2In, IsNotNull => V2IsNotNull, IsNull => V2IsNull, LessThan => V2LessThan, LessThanOrEqual => V2LessThanOrEqual, Or => V2Or, StringContains => V2StringContains, StringEndsWith => V2StringEndsWith, StringStartsWith => V2StringStartsWith} +import org.apache.spark.sql.sources.FiltersSuite.ref +import org.apache.spark.sql.types.{StringType} +import org.apache.spark.unsafe.types.UTF8String /** * Unit test suites for data source filters. @@ -47,6 +52,15 @@ class FiltersSuite extends SparkFunSuite { == Seq(Seq("b"), fieldNames.toSeq)) }} + test("EqualTo V1 V2 conversion") { withFieldNames { (name, _) => + val v1Filter = EqualTo(name, "1") + val v2Filter = new V2EqualTo(ref(name), LiteralValue(UTF8String.fromString("1"), StringType)) + assert(v1Filter.toV2 == v2Filter) + assert(v2Filter.toV1 == v1Filter) + assert(v1Filter.toV2.toV1 == v1Filter) + assert(v2Filter.toV1.toV2 == v2Filter) + }} + test("EqualNullSafe references") { withFieldNames { (name, fieldNames) => assert(EqualNullSafe(name, "1").references.toSeq == Seq(name)) assert(EqualNullSafe(name, "1").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) @@ -60,6 +74,16 @@ class FiltersSuite extends SparkFunSuite { == Seq(Seq("b"), fieldNames.toSeq)) }} + test("EqualNullSafe V1 V2 conversion") { withFieldNames { (name, _) => + val v1Filter = EqualNullSafe(name, "1") + val v2Filter = new V2EqualNullSafe(ref(name), + LiteralValue(UTF8String.fromString("1"), StringType)) + assert(v1Filter.toV2 == v2Filter) + assert(v2Filter.toV1 == v1Filter) + assert(v1Filter.toV2.toV1 == v1Filter) + assert(v2Filter.toV1.toV2 == v2Filter) + }} + test("GreaterThan references") { withFieldNames { (name, fieldNames) => assert(GreaterThan(name, "1").references.toSeq == Seq(name)) assert(GreaterThan(name, "1").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) @@ -73,6 +97,16 @@ class FiltersSuite extends SparkFunSuite { == Seq(Seq("b"), fieldNames.toSeq)) }} + test("GreaterThan V1 V2 conversion") { withFieldNames { (name, _) => + val v1Filter = GreaterThan(name, "1") + val v2Filter = new V2GreaterThan(ref(name), + LiteralValue(UTF8String.fromString("1"), StringType)) + assert(v1Filter.toV2 == v2Filter) + assert(v2Filter.toV1 == v1Filter) + assert(v1Filter.toV2.toV1 == v1Filter) + assert(v2Filter.toV1.toV2 == v2Filter) + }} + test("GreaterThanOrEqual references") { withFieldNames { (name, fieldNames) => assert(GreaterThanOrEqual(name, "1").references.toSeq == Seq(name)) assert(GreaterThanOrEqual(name, "1").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) @@ -86,6 +120,16 @@ class FiltersSuite extends SparkFunSuite { == Seq(Seq("b"), fieldNames.toSeq)) }} + test("GreaterThanOrEqual V1 V2 conversion") { withFieldNames { (name, _) => + val v1Filter = GreaterThanOrEqual(name, "1") + val v2Filter = new V2GreaterThanOrEqual(ref(name), + LiteralValue(UTF8String.fromString("1"), StringType)) + assert(v1Filter.toV2 == v2Filter) + assert(v2Filter.toV1 == v1Filter) + assert(v1Filter.toV2.toV1 == v1Filter) + assert(v2Filter.toV1.toV2 == v2Filter) + }} + test("LessThan references") { withFieldNames { (name, fieldNames) => assert(LessThan(name, "1").references.toSeq == Seq(name)) assert(LessThan(name, "1").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) @@ -93,6 +137,15 @@ class FiltersSuite extends SparkFunSuite { assert(LessThan("a", EqualTo("b", "2")).references.toSeq == Seq("a", "b")) }} + test("LessThan V1 V2 
conversion") { withFieldNames { (name, _) => + val v1Filter = LessThan(name, "1") + val v2Filter = new V2LessThan(ref(name), LiteralValue(UTF8String.fromString("1"), StringType)) + assert(v1Filter.toV2 == v2Filter) + assert(v2Filter.toV1 == v1Filter) + assert(v1Filter.toV2.toV1 == v1Filter) + assert(v2Filter.toV1.toV2 == v2Filter) + }} + test("LessThanOrEqual references") { withFieldNames { (name, fieldNames) => assert(LessThanOrEqual(name, "1").references.toSeq == Seq(name)) assert(LessThanOrEqual(name, "1").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) @@ -106,6 +159,16 @@ class FiltersSuite extends SparkFunSuite { == Seq(Seq("b"), fieldNames.toSeq)) }} + test("LessThanOrEqual V1 V2 conversion") { withFieldNames { (name, _) => + val v1Filter = LessThanOrEqual(name, "1") + val v2Filter = new V2LessThanOrEqual(ref(name), + LiteralValue(UTF8String.fromString("1"), StringType)) + assert(v1Filter.toV2 == v2Filter) + assert(v2Filter.toV1 == v1Filter) + assert(v1Filter.toV2.toV1 == v1Filter) + assert(v2Filter.toV1.toV2 == v2Filter) + }} + test("In references") { withFieldNames { (name, fieldNames) => assert(In(name, Array("1")).references.toSeq == Seq(name)) assert(In(name, Array("1")).v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) @@ -119,16 +182,46 @@ class FiltersSuite extends SparkFunSuite { == Seq(Seq("b"), fieldNames.toSeq)) }} + test("In V1 V2 conversion") { withFieldNames { (name, _) => + val v1Filter = In(name, Array("1", "2", "3", "4")) + val v2Filter = new V2In(ref(name), Array(LiteralValue(UTF8String.fromString("1"), StringType), + LiteralValue(UTF8String.fromString("2"), StringType), + LiteralValue(UTF8String.fromString("3"), StringType), + LiteralValue(UTF8String.fromString("4"), StringType))) + assert(v1Filter.toV2 == v2Filter) + assert(v2Filter.toV1 == v1Filter) + assert(v1Filter.toV2.toV1 == v1Filter) + assert(v2Filter.toV1.toV2 == v2Filter) + }} + test("IsNull references") { withFieldNames { (name, fieldNames) => assert(IsNull(name).references.toSeq == Seq(name)) assert(IsNull(name).v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) }} + test("IsNull V1 V2 conversion") { withFieldNames { (name, _) => + val v1Filter = IsNull(name) + val v2Filter = new V2IsNull(ref(name)) + assert(v1Filter.toV2 == v2Filter) + assert(v2Filter.toV1 == v1Filter) + assert(v1Filter.toV2.toV1 == v1Filter) + assert(v2Filter.toV1.toV2 == v2Filter) + }} + test("IsNotNull references") { withFieldNames { (name, fieldNames) => assert(IsNotNull(name).references.toSeq == Seq(name)) assert(IsNull(name).v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) }} + test("IsNotNull V1 V2 conversion") { withFieldNames { (name, _) => + val v1Filter = IsNotNull(name) + val v2Filter = new V2IsNotNull(ref(name)) + assert(v1Filter.toV2 == v2Filter) + assert(v2Filter.toV1 == v1Filter) + assert(v1Filter.toV2.toV1 == v1Filter) + assert(v2Filter.toV1.toV2 == v2Filter) + }} + test("And references") { withFieldNames { (name, fieldNames) => assert(And(EqualTo(name, "1"), EqualTo("b", "1")).references.toSeq == Seq(name, "b")) assert(And(EqualTo("b", "1"), EqualTo(name, "1")).references.toSeq == Seq("b", name)) @@ -139,6 +232,17 @@ class FiltersSuite extends SparkFunSuite { Seq(Seq("b"), fieldNames.toSeq)) }} + test("And V1 V2 conversion") { withFieldNames { (name, _) => + val v1Filter = And(EqualTo(name, "1"), EqualTo("b", "1")) + val v2Filter = new V2And(new V2EqualTo(ref(name), + LiteralValue(UTF8String.fromString("1"), StringType)), + new V2EqualTo(ref("b"), 
LiteralValue(UTF8String.fromString("1"), StringType))) + assert(v1Filter.toV2 == v2Filter) + assert(v2Filter.toV1 == v1Filter) + assert(v1Filter.toV2.toV1 == v1Filter) + assert(v2Filter.toV1.toV2 == v2Filter) + }} + test("Or references") { withFieldNames { (name, fieldNames) => assert(Or(EqualTo(name, "1"), EqualTo("b", "1")).references.toSeq == Seq(name, "b")) assert(Or(EqualTo("b", "1"), EqualTo(name, "1")).references.toSeq == Seq("b", name)) @@ -149,18 +253,62 @@ class FiltersSuite extends SparkFunSuite { Seq(Seq("b"), fieldNames.toSeq)) }} + test("Or V1 V2 conversion") { withFieldNames { (name, _) => + val v1Filter = Or(EqualTo(name, "1"), EqualTo("b", "1")) + val v2Filter = new V2Or(new V2EqualTo(ref(name), + LiteralValue(UTF8String.fromString("1"), StringType)), + new V2EqualTo(ref("b"), LiteralValue(UTF8String.fromString("1"), StringType))) + assert(v1Filter.toV2 == v2Filter) + assert(v2Filter.toV1 == v1Filter) + assert(v1Filter.toV2.toV1 == v1Filter) + assert(v2Filter.toV1.toV2 == v2Filter) + }} + test("StringStartsWith references") { withFieldNames { (name, fieldNames) => assert(StringStartsWith(name, "str").references.toSeq == Seq(name)) assert(StringStartsWith(name, "str").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) }} + test("StringStartsWith V1 V2 conversion") { withFieldNames { (name, fieldNames) => + val v1Filter = StringStartsWith(name, "str") + val v2Filter = new V2StringStartsWith(ref(name), UTF8String.fromString("str")) + assert(v1Filter.toV2 == v2Filter) + assert(v2Filter.toV1 == v1Filter) + assert(v1Filter.toV2.toV1 == v1Filter) + assert(v2Filter.toV1.toV2 == v2Filter) + }} + test("StringEndsWith references") { withFieldNames { (name, fieldNames) => assert(StringEndsWith(name, "str").references.toSeq == Seq(name)) assert(StringEndsWith(name, "str").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) }} + test("StringEndsWith V1 V2 conversion") { withFieldNames { (name, _) => + val v1Filter = StringEndsWith(name, "str") + val v2Filter = new V2StringEndsWith(ref(name), UTF8String.fromString("str")) + assert(v1Filter.toV2 == v2Filter) + assert(v2Filter.toV1 == v1Filter) + assert(v1Filter.toV2.toV1 == v1Filter) + assert(v2Filter.toV1.toV2 == v2Filter) + }} + test("StringContains references") { withFieldNames { (name, fieldNames) => assert(StringContains(name, "str").references.toSeq == Seq(name)) assert(StringContains(name, "str").v2references.toSeq.map(_.toSeq) == Seq(fieldNames.toSeq)) }} + + test("StringContains V1 V2 conversion") { withFieldNames { (name, _) => + val v1Filter = StringContains(name, "str") + val v2Filter = new V2StringContains(ref(name), UTF8String.fromString("str")) + assert(v1Filter.toV2 == v2Filter) + assert(v2Filter.toV1 == v1Filter) + assert(v1Filter.toV2.toV1 == v1Filter) + assert(v2Filter.toV1.toV2 == v2Filter) + }} +} + +object FiltersSuite { + def ref(parts: String): NamedReference = { + FieldReference(parts) + } }
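For reviewers who want to see the V1/V2 filter mapping outside the test suite, the sketch below builds the same pair of filters that FiltersSuite checks and round-trips them through the toV1/toV2 helpers introduced by this patch. It is only an illustration: the object and method names are mine, the V2 constructor signatures are taken from the diff above, and it assumes the patch is applied and that toV1/toV2 are visible from the calling package (in the patch they are exercised from Spark's own packages, so they may be Spark-internal).

import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue}
import org.apache.spark.sql.connector.expressions.filter.{And => V2And, EqualTo => V2EqualTo, GreaterThan => V2GreaterThan}
import org.apache.spark.sql.sources.{And, EqualTo, GreaterThan}
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

// Illustrative only; not part of the patch.
object FilterConversionSketch {
  def main(args: Array[String]): Unit = {
    // V1 filters reference columns by name and carry plain Scala values.
    val v1 = And(EqualTo("a", "1"), GreaterThan("b", "2"))

    // The equivalent V2 filter spells out the column references and typed literals,
    // mirroring what translateFilterV2 produces for the same predicate.
    val v2 = new V2And(
      new V2EqualTo(FieldReference("a"), LiteralValue(UTF8String.fromString("1"), StringType)),
      new V2GreaterThan(FieldReference("b"), LiteralValue(UTF8String.fromString("2"), StringType)))

    // Round-tripping through the helpers added in this patch should be lossless
    // (assuming toV1/toV2 are accessible from this package).
    assert(v1.toV2 == v2)
    assert(v2.toV1 == v1)
    assert(v1.toV2.toV1 == v1)
  }
}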