Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,55 @@ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName

import org.apache.spark.sql.sources
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

private[sql] object ParquetFilters {
/**
 * A Parquet user-defined predicate implementing `IN (valueSet)` semantics.
 *
 * Extends `Serializable` because the predicate instance is embedded in the
 * Parquet filter and shipped to executors with the task.
 *
 * @param valueSet the set of values the column value must belong to
 */
case class InSetFilter[T <: Comparable[T]](valueSet: Set[T])
  extends UserDefinedPredicate[T] with Serializable {

  // Bounds of the value set, pre-computed once for statistics-based row-group
  // pruning in canDrop().
  // NOTE(review): `min`/`max` throw on an empty set — assumes the planner never
  // builds an IN filter with zero values; confirm at the call site.
  private val min = valueSet.min
  private val max = valueSet.max

  // Keep a record iff its value is non-null and a member of the set.
  override def keep(value: T): Boolean = {
    value != null && valueSet.contains(value)
  }

  // Drop a whole row group when its statistics range [getMin, getMax] cannot
  // intersect the value set's range [min, max].
  override def canDrop(statistics: Statistics[T]): Boolean = {
    statistics.getMax.compareTo(min) < 0 || statistics.getMin.compareTo(max) > 0
  }

  // For the inverted predicate (NOT IN) we conservatively never drop:
  // min/max statistics cannot prove that every value in a row group is
  // contained in the set.
  override def inverseCanDrop(statistics: Statistics[T]): Boolean = false
}

/**
 * Base class for string-matching Parquet user-defined predicates
 * (starts-with / ends-with / contains). Subclasses only implement `keep`.
 */
abstract class StringFilter extends UserDefinedPredicate[Binary] {
// Statistics-based pruning is not attempted for substring matches: a row
// group's min/max bounds cannot prove the absence (for canDrop) or universal
// presence (for inverseCanDrop) of a prefix/suffix/substring, so both
// conservatively return false and every row group is scanned.
override def canDrop(statistics: Statistics[Binary]): Boolean = false
override def inverseCanDrop(statistics: Statistics[Binary]): Boolean = false

// Decodes a Parquet Binary into a Spark UTF8String without copying bytes.
def binaryToUTF8String(value: Binary): UTF8String = {
// This is a trick used in CatalystStringConverter to steal the underlying
// byte array of the binary without copying it.
// NOTE(review): buffer.array() requires an array-backed (heap) ByteBuffer —
// assumes Binary.toByteBuffer always yields one; confirm for off-heap paths.
val buffer = value.toByteBuffer
val offset = buffer.position()
val numBytes = buffer.limit() - buffer.position()
UTF8String.fromBytes(buffer.array(), offset, numBytes)
}
}

// Pushed-down predicate for `col LIKE 'prefix%'`: keeps a record when the
// UTF-8 decoded column value begins with the given prefix.
case class StringStartsWithFilter(prefix: String) extends StringFilter {
  // Convert the prefix once so each keep() call avoids re-encoding it.
  private val utf8Prefix: UTF8String = UTF8String.fromString(prefix)

  override def keep(value: Binary): Boolean = {
    val candidate = binaryToUTF8String(value)
    candidate.startsWith(utf8Prefix)
  }
}

// Pushed-down predicate for `col LIKE '%suffix'`: keeps a record when the
// UTF-8 decoded column value ends with the given suffix.
case class StringEndsWithFilter(suffix: String) extends StringFilter {
  // Convert the suffix once so each keep() call avoids re-encoding it.
  private val utf8Suffix: UTF8String = UTF8String.fromString(suffix)

  override def keep(value: Binary): Boolean = {
    val candidate = binaryToUTF8String(value)
    candidate.endsWith(utf8Suffix)
  }
}

// Pushed-down predicate for `col LIKE '%str%'`: keeps a record when the
// UTF-8 decoded column value contains the given substring.
case class StringContainsFilter(str: String) extends StringFilter {
  // Convert the needle once so each keep() call avoids re-encoding it.
  private val utf8Needle: UTF8String = UTF8String.fromString(str)

  override def keep(value: Binary): Boolean = {
    val candidate = binaryToUTF8String(value)
    candidate.contains(utf8Needle)
  }
}

private val makeEq: PartialFunction[DataType, (String, Any) => FilterPredicate] = {
case BooleanType =>
(n: String, v: Any) => FilterApi.eq(booleanColumn(n), v.asInstanceOf[java.lang.Boolean])
Expand Down Expand Up @@ -157,27 +192,54 @@ private[sql] object ParquetFilters {
FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
}

// Builds the Parquet predicate for sources.StringStartsWith.
// Only defined for StringType columns; other types fall through to None
// via PartialFunction.lift at the call site.
private val makeStringStartsFilter: PartialFunction[DataType,
    (String, String) => FilterPredicate] = {
  case StringType =>
    (n: String, v: String) =>
      // `v` is already a String; the original asInstanceOf[java.lang.String]
      // cast was redundant and has been removed.
      FilterApi.userDefined(binaryColumn(n), StringStartsWithFilter(v))
}

// Builds the Parquet predicate for sources.StringEndsWith.
// Only defined for StringType columns; other types fall through to None
// via PartialFunction.lift at the call site.
private val makeStringEndsFilter: PartialFunction[DataType,
    (String, String) => FilterPredicate] = {
  case StringType =>
    (n: String, v: String) =>
      // `v` is already a String; the original asInstanceOf[java.lang.String]
      // cast was redundant and has been removed.
      FilterApi.userDefined(binaryColumn(n), StringEndsWithFilter(v))
}

// Builds the Parquet predicate for sources.StringContains.
// Only defined for StringType columns; other types fall through to None
// via PartialFunction.lift at the call site.
private val makeStringContainsFilter: PartialFunction[DataType,
    (String, String) => FilterPredicate] = {
  case StringType =>
    (n: String, v: String) =>
      // `v` is already a String; the original asInstanceOf[java.lang.String]
      // cast was redundant and has been removed.
      FilterApi.userDefined(binaryColumn(n), StringContainsFilter(v))
}

/**
 * Builds a Parquet user-defined `IN` predicate for the supported atomic types.
 * Undefined for any other DataType, so callers using `.lift` get None and skip
 * push-down. (This block also removes diff residue: stale `SetInFilter(...)`
 * lines were interleaved with the renamed `InSetFilter(...)` calls.)
 */
private val makeInSet: PartialFunction[DataType, (String, Set[Any]) => FilterPredicate] = {
  case BooleanType =>
    (n: String, v: Set[Any]) =>
      FilterApi.userDefined(booleanColumn(n), InSetFilter(v.asInstanceOf[Set[java.lang.Boolean]]))
  case IntegerType =>
    (n: String, v: Set[Any]) =>
      FilterApi.userDefined(intColumn(n), InSetFilter(v.asInstanceOf[Set[java.lang.Integer]]))
  case LongType =>
    (n: String, v: Set[Any]) =>
      FilterApi.userDefined(longColumn(n), InSetFilter(v.asInstanceOf[Set[java.lang.Long]]))
  case FloatType =>
    (n: String, v: Set[Any]) =>
      FilterApi.userDefined(floatColumn(n), InSetFilter(v.asInstanceOf[Set[java.lang.Float]]))
  case DoubleType =>
    (n: String, v: Set[Any]) =>
      FilterApi.userDefined(doubleColumn(n), InSetFilter(v.asInstanceOf[Set[java.lang.Double]]))
  case StringType =>
    (n: String, v: Set[Any]) =>
      // Strings are stored as UTF-8 encoded binaries in Parquet.
      FilterApi.userDefined(binaryColumn(n),
        InSetFilter(v.map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8")))))
  case BinaryType =>
    (n: String, v: Set[Any]) =>
      FilterApi.userDefined(binaryColumn(n),
        InSetFilter(v.map(e => Binary.fromByteArray(e.asInstanceOf[Array[Byte]]))))
}

/**
Expand Down Expand Up @@ -209,6 +271,9 @@ private[sql] object ParquetFilters {
case sources.IsNotNull(name) =>
makeNotEq.lift(dataTypeOf(name)).map(_(name, null))

case sources.In(name, values) =>
makeInSet.lift(dataTypeOf(name)).map(_(name, values.toSet))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding this back :) I think this is worth a separate JIRA ticket to track. Would you mind to open one and add the ID to the PR title?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. I will open another ticket.

case sources.EqualTo(name, value) =>
makeEq.lift(dataTypeOf(name)).map(_(name, value))
case sources.Not(sources.EqualTo(name, value)) =>
Expand All @@ -229,6 +294,13 @@ private[sql] object ParquetFilters {
case sources.GreaterThanOrEqual(name, value) =>
makeGtEq.lift(dataTypeOf(name)).map(_(name, value))

case sources.StringStartsWith(name, value) =>
makeStringStartsFilter.lift(dataTypeOf(name)).map(_(name, value))
case sources.StringEndsWith(name, value) =>
makeStringEndsFilter.lift(dataTypeOf(name)).map(_(name, value))
case sources.StringContains(name, value) =>
makeStringContainsFilter.lift(dataTypeOf(name)).map(_(name, value))

case sources.And(lhs, rhs) =>
(createFilter(schema, lhs) ++ createFilter(schema, rhs)).reduceOption(FilterApi.and)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
checkFilterPredicate('_1 === true, classOf[Eq[_]], true)
checkFilterPredicate('_1 <=> true, classOf[Eq[_]], true)
checkFilterPredicate('_1 !== true, classOf[NotEq[_]], false)

checkFilterPredicate(
('_1.in(true)).asInstanceOf[Predicate], classOf[UserDefinedByInstance[_, _]], true)
checkFilterPredicate(
('_1.in(false)).asInstanceOf[Predicate], classOf[UserDefinedByInstance[_, _]], false)
}
}

Expand All @@ -138,6 +143,15 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex

checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4)))

checkFilterPredicate(
('_1.in(1, 2)).asInstanceOf[Predicate],
classOf[UserDefinedByInstance[_, _]],
Seq(Row(1), Row(2)))
checkFilterPredicate(
('_1.in(3, 4)).asInstanceOf[Predicate],
classOf[UserDefinedByInstance[_, _]],
Seq(Row(3), Row(4)))
}
}

Expand All @@ -164,6 +178,15 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex

checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4)))

checkFilterPredicate(
('_1.in(1L, 2L)).asInstanceOf[Predicate],
classOf[UserDefinedByInstance[_, _]],
Seq(Row(1L), Row(2L)))
checkFilterPredicate(
('_1.in(3L, 4L)).asInstanceOf[Predicate],
classOf[UserDefinedByInstance[_, _]],
Seq(Row(3L), Row(4L)))
}
}

Expand All @@ -190,6 +213,15 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex

checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4)))

checkFilterPredicate(
('_1.in(1.0f, 2.0f)).asInstanceOf[Predicate],
classOf[UserDefinedByInstance[_, _]],
Seq(Row(1.0f), Row(2.0f)))
checkFilterPredicate(
('_1.in(3.0f, 4.0f)).asInstanceOf[Predicate],
classOf[UserDefinedByInstance[_, _]],
Seq(Row(3.0f), Row(4.0f)))
}
}

Expand All @@ -216,6 +248,15 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex

checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4)))

checkFilterPredicate(
('_1.in(1.0, 2.0)).asInstanceOf[Predicate],
classOf[UserDefinedByInstance[_, _]],
Seq(Row(1.0), Row(2.0)))
checkFilterPredicate(
('_1.in(3.0, 4.0)).asInstanceOf[Predicate],
classOf[UserDefinedByInstance[_, _]],
Seq(Row(3.0), Row(4.0)))
}
}

Expand Down Expand Up @@ -244,6 +285,47 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex

checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4")
checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4")))

checkFilterPredicate(
('_1.in("1", "2")).asInstanceOf[Predicate],
classOf[UserDefinedByInstance[_, _]],
Seq(Row("1"), Row("2")))
checkFilterPredicate(
('_1.in("3", "4")).asInstanceOf[Predicate],
classOf[UserDefinedByInstance[_, _]],
Seq(Row("3"), Row("4")))
}

withParquetDataFrame((1 to 4).map(i => Tuple1(i.toString * 5 + "test"))) { implicit df =>
checkFilterPredicate(
('_1 contains "11").asInstanceOf[Predicate],
classOf[UserDefinedByInstance[_, _]],
"11111test")

checkFilterPredicate(
('_1 contains "2test").asInstanceOf[Predicate],
classOf[UserDefinedByInstance[_, _]],
"22222test")

checkFilterPredicate(
('_1 contains "3t").asInstanceOf[Predicate],
classOf[UserDefinedByInstance[_, _]],
"33333test")

checkFilterPredicate(
('_1 startsWith "22").asInstanceOf[Predicate],
classOf[UserDefinedByInstance[_, _]],
"22222test")

checkFilterPredicate(
('_1 endsWith "4test").asInstanceOf[Predicate],
classOf[UserDefinedByInstance[_, _]],
"44444test")

checkFilterPredicate(
('_1 endsWith "2test").asInstanceOf[Predicate],
classOf[UserDefinedByInstance[_, _]],
"22222test")
}
}

Expand Down Expand Up @@ -278,6 +360,15 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b)
checkBinaryFilterPredicate(
'_1 < 2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b)))

checkFilterPredicate(
('_1.in(1.b, 2.b)).asInstanceOf[Predicate],
classOf[UserDefinedByInstance[_, _]],
Seq(Row(1.b), Row(2.b)))
checkFilterPredicate(
('_1.in(3.b, 4.b)).asInstanceOf[Predicate],
classOf[UserDefinedByInstance[_, _]],
Seq(Row(3.b), Row(4.b)))
}
}

Expand Down