AvroPartitionReaderFactory.scala

@@ -30,11 +30,11 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.avro.{AvroDeserializer, AvroOptions, AvroUtils}
 import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters, OrderedFilters}
+import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter}
 import org.apache.spark.sql.connector.read.PartitionReader
 import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile}
 import org.apache.spark.sql.execution.datasources.v2.{EmptyPartitionReader, FilePartitionReaderFactory, PartitionReaderWithPartitionValues}
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.SerializableConfiguration
@@ -55,7 +55,7 @@ case class AvroPartitionReaderFactory(
     readDataSchema: StructType,
     partitionSchema: StructType,
     parsedOptions: AvroOptions,
-    filters: Seq[Filter]) extends FilePartitionReaderFactory with Logging {
+    filters: Seq[V2Filter]) extends FilePartitionReaderFactory with Logging {
   private val datetimeRebaseModeInRead = parsedOptions.datetimeRebaseModeInRead

   override def buildReader(partitionedFile: PartitionedFile): PartitionReader[InternalRow] = {
@@ -94,7 +94,7 @@ case class AvroPartitionReaderFactory(
       datetimeRebaseModeInRead)

     val avroFilters = if (SQLConf.get.avroFilterPushDown) {
-      new OrderedFilters(filters, readDataSchema)
+      new OrderedFilters(filters.map(_.toV1), readDataSchema)
     } else {
       new NoopFilters
     }
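The factory now carries V2 filters end to end and lowers them to V1 only where OrderedFilters, which still consumes the V1 API, is constructed. A minimal sketch of that bridging step; the column name is invented, and FieldReference is used here merely as a convenient way to build a NamedReference:

    import org.apache.spark.sql.connector.expressions.FieldReference
    import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter, IsNotNull}

    // V2 filters as carried by AvroPartitionReaderFactory...
    val v2Filters: Seq[V2Filter] = Seq(new IsNotNull(FieldReference("value")))
    // ...are lowered right before OrderedFilters is built:
    val v1Filters = v2Filters.map(_.toV1)  // Seq(IsNotNull("value")) in the V1 sources API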
AvroScan.scala

@@ -23,10 +23,10 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.avro.AvroOptions
 import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter}
 import org.apache.spark.sql.connector.read.PartitionReaderFactory
 import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
 import org.apache.spark.sql.execution.datasources.v2.FileScan
-import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.util.SerializableConfiguration
@@ -38,7 +38,7 @@ case class AvroScan(
     readDataSchema: StructType,
     readPartitionSchema: StructType,
     options: CaseInsensitiveStringMap,
-    pushedFilters: Array[Filter],
+    pushedFilters: Array[V2Filter],
     partitionFilters: Seq[Expression] = Seq.empty,
     dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
   override def isSplitable(path: Path): Boolean = true
AvroScanBuilder.scala

@@ -18,10 +18,10 @@ package org.apache.spark.sql.v2.avro

 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.StructFilters
+import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter}
 import org.apache.spark.sql.connector.read.Scan
 import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
 import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
-import org.apache.spark.sql.sources.Filter
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -46,11 +46,11 @@ class AvroScanBuilder (
       dataFilters)
   }

-  override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = {
+  override def pushDataFilters(dataFilters: Array[V2Filter]): Array[V2Filter] = {
     if (sparkSession.sessionState.conf.avroFilterPushDown) {
-      StructFilters.pushedFilters(dataFilters, dataSchema)
+      StructFilters.pushedFiltersV2(dataFilters, dataSchema)
     } else {
-      Array.empty[Filter]
+      Array.empty[V2Filter]
     }
   }
 }
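As before, pushdown stays gated on the Avro-specific SQL conf; when it is off, pushDataFilters reports nothing as pushed. A sketch, assuming the usual conf key spark.sql.avro.filterPushdown backing avroFilterPushDown:

    // With pushdown disabled, pushDataFilters returns Array.empty[V2Filter]:
    spark.conf.set("spark.sql.avro.filterPushdown", "false")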
AvroScanSuite.scala

@@ -23,8 +23,8 @@ import org.apache.spark.sql.v2.avro.AvroScan
 class AvroScanSuite extends FileScanSuiteBase {
   val scanBuilders = Seq[(String, ScanBuilder, Seq[String])](
     ("AvroScan",
-      (s, fi, ds, rds, rps, f, o, pf, df) => AvroScan(s, fi, ds, rds, rps, o, f, pf, df),
-      Seq.empty))
+      (s, fi, ds, rds, rps, f, o, pf, df) => AvroScan(s, fi, ds, rds, rps, o,
+        f.map(_.toV2), pf, df), Seq.empty))

   run(scanBuilders)
 }
AvroV2Suite.scala

@@ -2329,7 +2329,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper {
           |Format: avro
           |Location: InMemoryFileIndex\\([0-9]+ paths\\)\\[.*\\]
           |PartitionFilters: \\[isnotnull\\(id#x\\), \\(id#x > 1\\)\\]
-          |PushedFilters: \\[IsNotNull\\(value\\), GreaterThan\\(value,2\\)\\]
+          |PushedFilters: \\[value IS NOT NULL, value > 2\\]
           |ReadSchema: struct\\<value:bigint\\>
           |""".stripMargin.trim
       spark.range(10)
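The expected PushedFilters text changes because V2 filters render through their own toString/describe rather than the V1 case-class toString. Roughly, per the toString implementations in this patch (LiteralValue is used here as an assumed Literal implementation):

    import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue}
    import org.apache.spark.sql.connector.expressions.filter.{GreaterThan, IsNotNull}
    import org.apache.spark.sql.types.LongType

    new IsNotNull(FieldReference("value")).describe()  // "value IS NOT NULL"
    new GreaterThan(FieldReference("value"), LiteralValue(2L, LongType)).describe()  // "value > 2"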
AlwaysFalse.java

@@ -47,4 +47,9 @@ public int hashCode() {

   @Override
   public NamedReference[] references() { return EMPTY_REFERENCE; }
+
+  @Override
+  public org.apache.spark.sql.sources.Filter toV1() {
+    return new org.apache.spark.sql.sources.AlwaysFalse();
+  }
 }
AlwaysTrue.java

@@ -47,4 +47,9 @@ public int hashCode() {

   @Override
   public NamedReference[] references() { return EMPTY_REFERENCE; }
+
+  @Override
+  public org.apache.spark.sql.sources.Filter toV1() {
+    return new org.apache.spark.sql.sources.AlwaysTrue();
+  }
 }
And.java

@@ -36,4 +36,9 @@ public And(Filter left, Filter right) {
   public String toString() {
     return String.format("(%s) AND (%s)", left.describe(), right.describe());
   }
+
+  @Override
+  public org.apache.spark.sql.sources.Filter toV1() {
+    return new org.apache.spark.sql.sources.And(left.toV1(), right.toV1());
+  }
 }
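Composite filters lower recursively, so an arbitrary V2 predicate tree maps onto the equivalent V1 tree. A small sketch; the column and literal are invented for illustration, and LiteralValue is an assumed Literal implementation:

    import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue}
    import org.apache.spark.sql.connector.expressions.filter.{And, GreaterThan, IsNotNull}
    import org.apache.spark.sql.types.LongType

    val v2 = new And(
      new IsNotNull(FieldReference("value")),
      new GreaterThan(FieldReference("value"), LiteralValue(2L, LongType)))
    // toV1 recurses into both children, yielding
    // sources.And(sources.IsNotNull("value"), sources.GreaterThan("value", 2))
    val v1 = v2.toV1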
EqualNullSafe.java

@@ -18,6 +18,7 @@
 package org.apache.spark.sql.connector.expressions.filter;

 import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.catalyst.CatalystTypeConverters;
 import org.apache.spark.sql.connector.expressions.Literal;
 import org.apache.spark.sql.connector.expressions.NamedReference;

@@ -37,4 +38,10 @@ public EqualNullSafe(NamedReference column, Literal<?> value) {

   @Override
   public String toString() { return this.column.describe() + " <=> " + value.describe(); }
+
+  @Override
+  public org.apache.spark.sql.sources.Filter toV1() {
+    return new org.apache.spark.sql.sources.EqualNullSafe(
+      column.describe(), CatalystTypeConverters.convertToScala(value.value(), value.dataType()));
+  }
 }
EqualTo.java

@@ -18,6 +18,7 @@
 package org.apache.spark.sql.connector.expressions.filter;

 import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.catalyst.CatalystTypeConverters;
 import org.apache.spark.sql.connector.expressions.Literal;
 import org.apache.spark.sql.connector.expressions.NamedReference;

@@ -36,4 +37,10 @@ public EqualTo(NamedReference column, Literal<?> value) {

   @Override
   public String toString() { return column.describe() + " = " + value.describe(); }
+
+  @Override
+  public org.apache.spark.sql.sources.Filter toV1() {
+    return new org.apache.spark.sql.sources.EqualTo(
+      (column).describe(), CatalystTypeConverters.convertToScala(value.value(), value.dataType()));
+  }
 }
Filter.java

@@ -17,6 +17,8 @@

 package org.apache.spark.sql.connector.expressions.filter;

+import java.io.Serializable;
+
 import org.apache.spark.annotation.Evolving;
 import org.apache.spark.sql.connector.expressions.Expression;
 import org.apache.spark.sql.connector.expressions.NamedReference;
@@ -27,7 +29,7 @@
  * @since 3.3.0
  */
 @Evolving
-public abstract class Filter implements Expression {
+public abstract class Filter implements Expression, Serializable {

   protected static final NamedReference[] EMPTY_REFERENCE = new NamedReference[0];

@@ -38,4 +40,9 @@ public abstract class Filter implements Expression {

   @Override
   public String describe() { return this.toString(); }
+
+  /**
+   * Returns a V1 Filter.
+   */
+  public abstract org.apache.spark.sql.sources.Filter toV1();
 }
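With toV1 abstract on the base class, every concrete V2 filter must know how to lower itself to the V1 sources API. A hypothetical minimal subclass, purely to illustrate the contract; MyTrue and its rendering are invented for this sketch:

    import org.apache.spark.sql.connector.expressions.NamedReference
    import org.apache.spark.sql.connector.expressions.filter.Filter

    // Invented example: a filter that matches everything and references no columns.
    class MyTrue extends Filter {
      override def toString: String = "MY_TRUE"
      override def references(): Array[NamedReference] = Array.empty
      override def toV1(): org.apache.spark.sql.sources.Filter =
        new org.apache.spark.sql.sources.AlwaysTrue()
    }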
GreaterThan.java

@@ -18,6 +18,7 @@
 package org.apache.spark.sql.connector.expressions.filter;

 import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.catalyst.CatalystTypeConverters;
 import org.apache.spark.sql.connector.expressions.Literal;
 import org.apache.spark.sql.connector.expressions.NamedReference;

@@ -36,4 +37,11 @@ public GreaterThan(NamedReference column, Literal<?> value) {

   @Override
   public String toString() { return column.describe() + " > " + value.describe(); }
+
+  @Override
+  public org.apache.spark.sql.sources.Filter toV1() {
+    return new org.apache.spark.sql.sources.GreaterThan(
+      column.describe(), CatalystTypeConverters.convertToScala(value.value(), value.dataType()));
+  }
+
 }
GreaterThanOrEqual.java

@@ -18,6 +18,7 @@
 package org.apache.spark.sql.connector.expressions.filter;

 import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.catalyst.CatalystTypeConverters;
 import org.apache.spark.sql.connector.expressions.Literal;
 import org.apache.spark.sql.connector.expressions.NamedReference;

@@ -36,4 +37,10 @@ public GreaterThanOrEqual(NamedReference column, Literal<?> value) {

   @Override
   public String toString() { return column.describe() + " >= " + value.describe(); }
+
+  @Override
+  public org.apache.spark.sql.sources.Filter toV1() {
+    return new org.apache.spark.sql.sources.GreaterThanOrEqual(
+      column.describe(), CatalystTypeConverters.convertToScala(value.value(), value.dataType()));
+  }
 }
In.java

@@ -22,6 +22,7 @@
 import java.util.stream.Collectors;

 import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.catalyst.CatalystTypeConverters;
 import org.apache.spark.sql.connector.expressions.Literal;
 import org.apache.spark.sql.connector.expressions.NamedReference;

@@ -73,4 +74,15 @@ public String toString() {

   @Override
   public NamedReference[] references() { return new NamedReference[] { column }; }
+
+  @Override
+  public org.apache.spark.sql.sources.Filter toV1() {
+    Object[] array = new Object[values.length];
+    int index = 0;
+    for (Literal value: values) {
+      array[index] = CatalystTypeConverters.convertToScala(value.value(), value.dataType());
+      index++;
+    }
+    return new org.apache.spark.sql.sources.In(column.describe(), array);
+  }
 }
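The loop in In.toV1 converts each catalyst literal back to its external Scala representation before building the V1 In: strings, for example, are UTF8String internally but plain java.lang.String on the V1 side. A small sketch of that conversion:

    import org.apache.spark.sql.catalyst.CatalystTypeConverters
    import org.apache.spark.sql.types.StringType
    import org.apache.spark.unsafe.types.UTF8String

    val catalystValue = UTF8String.fromString("abc")  // internal representation
    CatalystTypeConverters.convertToScala(catalystValue, StringType)  // "abc": java.lang.String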
IsNotNull.java

@@ -55,4 +55,9 @@ public int hashCode() {

   @Override
   public NamedReference[] references() { return new NamedReference[] { column }; }
+
+  @Override
+  public org.apache.spark.sql.sources.Filter toV1() {
+    return new org.apache.spark.sql.sources.IsNotNull(column.describe());
+  }
 }
IsNull.java

@@ -55,4 +55,9 @@ public int hashCode() {

   @Override
   public NamedReference[] references() { return new NamedReference[] { column }; }
+
+  @Override
+  public org.apache.spark.sql.sources.Filter toV1() {
+    return new org.apache.spark.sql.sources.IsNull(column.describe());
+  }
 }
LessThan.java

@@ -18,6 +18,7 @@
 package org.apache.spark.sql.connector.expressions.filter;

 import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.catalyst.CatalystTypeConverters;
 import org.apache.spark.sql.connector.expressions.Literal;
 import org.apache.spark.sql.connector.expressions.NamedReference;

@@ -36,4 +37,10 @@ public LessThan(NamedReference column, Literal<?> value) {

   @Override
   public String toString() { return column.describe() + " < " + value.describe(); }
+
+  @Override
+  public org.apache.spark.sql.sources.Filter toV1() {
+    return new org.apache.spark.sql.sources.LessThan(
+      column.describe(), CatalystTypeConverters.convertToScala(value.value(), value.dataType()));
+  }
 }
LessThanOrEqual.java

@@ -18,6 +18,7 @@
 package org.apache.spark.sql.connector.expressions.filter;

 import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.catalyst.CatalystTypeConverters;
 import org.apache.spark.sql.connector.expressions.Literal;
 import org.apache.spark.sql.connector.expressions.NamedReference;

@@ -36,4 +37,10 @@ public LessThanOrEqual(NamedReference column, Literal<?> value) {

   @Override
   public String toString() { return column.describe() + " <= " + value.describe(); }
+
+  @Override
+  public org.apache.spark.sql.sources.Filter toV1() {
+    return new org.apache.spark.sql.sources.LessThanOrEqual(
+      column.describe(), CatalystTypeConverters.convertToScala(value.value(), value.dataType()));
+  }
 }
Not.java

@@ -53,4 +53,9 @@ public int hashCode() {

   @Override
   public NamedReference[] references() { return child.references(); }
+
+  @Override
+  public org.apache.spark.sql.sources.Filter toV1() {
+    return new org.apache.spark.sql.sources.Not(child.toV1());
+  }
 }
Or.java

@@ -36,4 +36,9 @@ public Or(Filter left, Filter right) {
   public String toString() {
     return String.format("(%s) OR (%s)", left.describe(), right.describe());
   }
+
+  @Override
+  public org.apache.spark.sql.sources.Filter toV1() {
+    return new org.apache.spark.sql.sources.Or(left.toV1(), right.toV1());
+  }
 }
StringContains.java

@@ -36,4 +36,9 @@ public StringContains(NamedReference column, UTF8String value) {

   @Override
   public String toString() { return "STRING_CONTAINS(" + column.describe() + ", " + value + ")"; }
+
+  @Override
+  public org.apache.spark.sql.sources.Filter toV1() {
+    return new org.apache.spark.sql.sources.StringContains(column.describe(), value.toString());
+  }
 }
StringEndsWith.java

@@ -36,4 +36,9 @@ public StringEndsWith(NamedReference column, UTF8String value) {

   @Override
   public String toString() { return "STRING_ENDS_WITH(" + column.describe() + ", " + value + ")"; }
+
+  @Override
+  public org.apache.spark.sql.sources.Filter toV1() {
+    return new org.apache.spark.sql.sources.StringEndsWith(column.describe(), value.toString());
+  }
 }
StringStartsWith.java

@@ -38,4 +38,9 @@ public StringStartsWith(NamedReference column, UTF8String value) {
   public String toString() {
     return "STRING_STARTS_WITH(" + column.describe() + ", " + value + ")";
   }
+
+  @Override
+  public org.apache.spark.sql.sources.Filter toV1() {
+    return new org.apache.spark.sql.sources.StringStartsWith(column.describe(), value.toString());
+  }
 }
StructFilters.scala

@@ -21,6 +21,7 @@ import scala.util.Try

 import org.apache.spark.sql.catalyst.StructFilters._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter}
 import org.apache.spark.sql.sources
 import org.apache.spark.sql.types.{BooleanType, StructType}

@@ -93,6 +94,16 @@ object StructFilters {
     filters.filter(checkFilterRefs(_, fieldNames))
   }

+  private def checkFilterRefsV2(filter: V2Filter, fieldNames: Set[String]): Boolean = {
+    // The names have been normalized and case sensitivity is not a concern here.
+    filter.references.map(_.fieldNames().mkString(".")).forall(fieldNames.contains)
+  }
+
+  def pushedFiltersV2(filters: Array[V2Filter], schema: StructType): Array[V2Filter] = {
+    val fieldNames = schema.fieldNames.toSet
+    filters.filter(checkFilterRefsV2(_, fieldNames))
+  }
+
   private def zip[A, B](a: Option[A], b: Option[B]): Option[(A, B)] = {
     a.zip(b).headOption
   }
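pushedFiltersV2 mirrors the V1 pushedFilters: a filter survives only if every column it references (dot-joined for nested fields) exists in the data schema. A sketch, with FieldReference again used only as a convenient NamedReference builder:

    import org.apache.spark.sql.catalyst.StructFilters
    import org.apache.spark.sql.connector.expressions.FieldReference
    import org.apache.spark.sql.connector.expressions.filter.{Filter => V2Filter, IsNotNull}
    import org.apache.spark.sql.types.{LongType, StructType}

    val schema = new StructType().add("value", LongType)
    val candidates: Array[V2Filter] = Array(
      new IsNotNull(FieldReference("value")),    // known column  -> kept
      new IsNotNull(FieldReference("missing")))  // unknown column -> dropped
    StructFilters.pushedFiltersV2(candidates, schema)  // Array(value IS NOT NULL)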