ParquetFilters.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet
import java.lang.{Boolean => JBoolean, Double => JDouble, Float => JFloat, Long => JLong}
import java.math.{BigDecimal => JBigDecimal}
import java.sql.{Date, Timestamp}
import java.time.LocalDate
import java.time.{Instant, LocalDate}
import java.util.Locale

import scala.collection.JavaConverters.asScalaBufferConverter
@@ -129,6 +129,11 @@ class ParquetFilters(
case ld: LocalDate => DateTimeUtils.localDateToDays(ld)
}

private def timestampToMicros(v: Any): JLong = v match {
case i: Instant => DateTimeUtils.instantToMicros(i)
case t: Timestamp => DateTimeUtils.fromJavaTimestamp(t)
}

private def decimalToInt32(decimal: JBigDecimal): Integer = decimal.unscaledValue().intValue()

private def decimalToInt64(decimal: JBigDecimal): JLong = decimal.unscaledValue().longValue()
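The new `timestampToMicros` helper is what lets the same pushdown code accept either `java.sql.Timestamp` (the default external type) or `java.time.Instant` (the type Spark hands out when `spark.sql.datetime.java8API.enabled` is on). Below is a standalone sketch of the same dispatch, using plain `java.time` arithmetic in place of Spark's internal `DateTimeUtils` helpers; the real helpers also deal with calendar edge cases that this sketch ignores.

```scala
// Standalone sketch (not Spark's implementation): normalize either value type
// to microseconds since the epoch, mirroring the match in timestampToMicros.
import java.sql.Timestamp
import java.time.Instant

object TimestampToMicrosSketch {
  def toMicros(v: Any): Long = v match {
    case i: Instant =>
      // seconds * 1_000_000 + nanos / 1_000, with overflow checks
      Math.addExact(Math.multiplyExact(i.getEpochSecond, 1000000L), i.getNano / 1000L)
    case t: Timestamp =>
      toMicros(t.toInstant) // Timestamp carries nanos too; reuse the Instant branch
  }

  def main(args: Array[String]): Unit = {
    println(toMicros(Instant.parse("2018-06-17T08:28:53.999Z")))    // Instant input
    println(toMicros(Timestamp.valueOf("2018-06-17 08:28:53.999"))) // Timestamp input
  }
}
```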
@@ -149,8 +154,7 @@
}

private def timestampToMillis(v: Any): JLong = {
val timestamp = v.asInstanceOf[Timestamp]
val micros = DateTimeUtils.fromJavaTimestamp(timestamp)
val micros = timestampToMicros(v)
val millis = DateTimeUtils.microsToMillis(micros)
millis.asInstanceOf[JLong]
}
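`timestampToMillis` now funnels through the same helper and only then truncates to milliseconds. A minimal sketch of that last step, assuming `DateTimeUtils.microsToMillis` behaves like a floor division by 1000; it has to round toward negative infinity so that pre-1970 values (which the test data below exercises) keep their ordering after truncation.

```scala
// Sketch of the micros -> millis step. Floor division keeps ordering correct
// for timestamps before 1970 (negative microseconds), unlike plain `/ 1000`.
object MicrosToMillisSketch {
  def microsToMillis(micros: Long): Long = Math.floorDiv(micros, 1000L)

  def main(args: Array[String]): Unit = {
    println(microsToMillis(1529224133999999L)) // positive epoch: 1529224133999
    println(microsToMillis(-1L))               // -1 micros -> -1 millis (floor), not 0
  }
}
```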
@@ -186,8 +190,7 @@
case ParquetTimestampMicrosType if pushDownTimestamp =>
(n: Array[String], v: Any) => FilterApi.eq(
longColumn(n),
Option(v).map(t => DateTimeUtils.fromJavaTimestamp(t.asInstanceOf[Timestamp])
.asInstanceOf[JLong]).orNull)
Option(v).map(timestampToMicros).orNull)
case ParquetTimestampMillisType if pushDownTimestamp =>
(n: Array[String], v: Any) => FilterApi.eq(
longColumn(n),
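In the `eq` and `notEq` cases the incoming value may be `null` (that is how an `IsNull` source filter arrives here), so the conversion is wrapped in `Option(v).map(...).orNull` rather than applied directly. A hedged sketch of how such a lambda builds a Parquet predicate; the column name `event_ts` and the helper `micros` are made up for illustration.

```scala
// Sketch only: build the same kind of predicate the (names, value) lambdas return.
import java.sql.Timestamp
import java.time.Instant
import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate}

object NullSafeEqSketch {
  private def micros(i: Instant): java.lang.Long =
    Math.addExact(Math.multiplyExact(i.getEpochSecond, 1000000L), i.getNano / 1000L)

  // Mirrors the pushdown lambda: a null value stays null, which Parquet's
  // FilterApi treats as "column is null" instead of failing the conversion.
  def eqPredicate(column: String, v: Any): FilterPredicate =
    FilterApi.eq(
      FilterApi.longColumn(column),
      Option(v).map {
        case i: Instant   => micros(i)
        case t: Timestamp => micros(t.toInstant)
      }.orNull)

  def main(args: Array[String]): Unit = {
    println(eqPredicate("event_ts", Instant.parse("2018-06-17T08:28:53.999Z")))
    println(eqPredicate("event_ts", null)) // IsNull pushdown
  }
}
```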
@@ -237,8 +240,7 @@
case ParquetTimestampMicrosType if pushDownTimestamp =>
(n: Array[String], v: Any) => FilterApi.notEq(
longColumn(n),
Option(v).map(t => DateTimeUtils.fromJavaTimestamp(t.asInstanceOf[Timestamp])
.asInstanceOf[JLong]).orNull)
Option(v).map(timestampToMicros).orNull)
case ParquetTimestampMillisType if pushDownTimestamp =>
(n: Array[String], v: Any) => FilterApi.notEq(
longColumn(n),
@@ -280,9 +282,7 @@
(n: Array[String], v: Any) =>
FilterApi.lt(intColumn(n), dateToDays(v).asInstanceOf[Integer])
case ParquetTimestampMicrosType if pushDownTimestamp =>
(n: Array[String], v: Any) => FilterApi.lt(
longColumn(n),
DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong])
(n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMicros(v))
case ParquetTimestampMillisType if pushDownTimestamp =>
(n: Array[String], v: Any) => FilterApi.lt(longColumn(n), timestampToMillis(v))

@@ -319,9 +319,7 @@
(n: Array[String], v: Any) =>
FilterApi.ltEq(intColumn(n), dateToDays(v).asInstanceOf[Integer])
case ParquetTimestampMicrosType if pushDownTimestamp =>
(n: Array[String], v: Any) => FilterApi.ltEq(
longColumn(n),
DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong])
(n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMicros(v))
case ParquetTimestampMillisType if pushDownTimestamp =>
(n: Array[String], v: Any) => FilterApi.ltEq(longColumn(n), timestampToMillis(v))

@@ -358,9 +356,7 @@
(n: Array[String], v: Any) =>
FilterApi.gt(intColumn(n), dateToDays(v).asInstanceOf[Integer])
case ParquetTimestampMicrosType if pushDownTimestamp =>
(n: Array[String], v: Any) => FilterApi.gt(
longColumn(n),
DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong])
(n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMicros(v))
case ParquetTimestampMillisType if pushDownTimestamp =>
(n: Array[String], v: Any) => FilterApi.gt(longColumn(n), timestampToMillis(v))

@@ -397,9 +393,7 @@
(n: Array[String], v: Any) =>
FilterApi.gtEq(intColumn(n), dateToDays(v).asInstanceOf[Integer])
case ParquetTimestampMicrosType if pushDownTimestamp =>
(n: Array[String], v: Any) => FilterApi.gtEq(
longColumn(n),
DateTimeUtils.fromJavaTimestamp(v.asInstanceOf[Timestamp]).asInstanceOf[JLong])
(n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMicros(v))
case ParquetTimestampMillisType if pushDownTimestamp =>
(n: Array[String], v: Any) => FilterApi.gtEq(longColumn(n), timestampToMillis(v))

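All six comparison shapes (`eq`, `notEq`, `lt`, `ltEq`, `gt`, `gtEq`) now share the single `timestampToMicros` conversion, so a range scan over a TIMESTAMP_MICROS column composes naturally from the same building blocks. A small sketch, again with an invented column name and local `micros` helper, of how two such predicates combine into a half-open time range.

```scala
// Sketch: [start, end) range over a micros-encoded timestamp column.
import java.time.Instant
import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate}

object TimestampRangeSketch {
  private def micros(i: Instant): java.lang.Long =
    Math.addExact(Math.multiplyExact(i.getEpochSecond, 1000000L), i.getNano / 1000L)

  def range(column: String, start: Instant, end: Instant): FilterPredicate = {
    val col = FilterApi.longColumn(column)
    FilterApi.and(
      FilterApi.gtEq(col, micros(start)), // ts >= start
      FilterApi.lt(col, micros(end)))     // ts <  end
  }

  def main(args: Array[String]): Unit =
    println(range("event_ts",
      Instant.parse("2018-06-17T00:00:00Z"),
      Instant.parse("2018-06-18T00:00:00Z")))
}
```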
@@ -475,7 +469,7 @@ class ParquetFilters(
case ParquetDateType =>
value.isInstanceOf[Date] || value.isInstanceOf[LocalDate]
case ParquetTimestampMicrosType | ParquetTimestampMillisType =>
value.isInstanceOf[Timestamp]
value.isInstanceOf[Timestamp] || value.isInstanceOf[Instant]
case ParquetSchemaType(DECIMAL, INT32, _, decimalMeta) =>
isDecimalMatched(value, decimalMeta)
case ParquetSchemaType(DECIMAL, INT64, _, decimalMeta) =>
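The companion change in the value check widens the accepted runtime types, so an `Instant` literal now counts as pushable for both timestamp physical types. A sketch of why that arm matters: before the change an `Instant` fails the guard and the filter is simply not pushed down (Spark still evaluates it afterwards, just without row-group or page skipping). The surrounding method accepts `null` values separately, so `IsNull` is unaffected.

```scala
// Sketch of the before/after guard for timestamp filter values.
import java.sql.Timestamp
import java.time.Instant

object PushableTimestampValueSketch {
  def canPushBefore(value: Any): Boolean = value.isInstanceOf[Timestamp]
  def canPushAfter(value: Any): Boolean =
    value.isInstanceOf[Timestamp] || value.isInstanceOf[Instant]

  def main(args: Array[String]): Unit = {
    val v: Any = Instant.parse("2018-06-17T08:28:53.999Z")
    println(canPushBefore(v)) // false: pushdown silently skipped
    println(canPushAfter(v))  // true
  }
}
```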
ParquetFilterSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet
import java.math.{BigDecimal => JBigDecimal}
import java.nio.charset.StandardCharsets
import java.sql.{Date, Timestamp}
import java.time.LocalDate
import java.time.{LocalDate, LocalDateTime, ZoneId}

import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Operators}
import org.apache.parquet.filter2.predicate.FilterApi._
@@ -143,15 +143,29 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
}
}

private def testTimestampPushdown(data: Seq[Timestamp]): Unit = {
private def testTimestampPushdown(data: Seq[String], java8Api: Boolean): Unit = {
implicit class StringToTs(s: String) {
def ts: Timestamp = Timestamp.valueOf(s)
}
assert(data.size === 4)
val ts1 = data.head
val ts2 = data(1)
val ts3 = data(2)
val ts4 = data(3)

import testImplicits._
withNestedDataFrame(data.map(i => Tuple1(i)).toDF()) { case (inputDF, colName, resultFun) =>
val df = data.map(i => Tuple1(Timestamp.valueOf(i))).toDF()
withNestedDataFrame(df) { case (inputDF, colName, fun) =>
def resultFun(tsStr: String): Any = {
val parsed = if (java8Api) {
LocalDateTime.parse(tsStr.replace(" ", "T"))
.atZone(ZoneId.systemDefault())
.toInstant
} else {
Timestamp.valueOf(tsStr)
}
fun(parsed)
}
withParquetDataFrame(inputDF) { implicit df =>
val tsAttr = df(colName).expr
assert(df(colName).expr.dataType === TimestampType)
@@ -160,26 +174,26 @@
checkFilterPredicate(tsAttr.isNotNull, classOf[NotEq[_]],
data.map(i => Row.apply(resultFun(i))))

checkFilterPredicate(tsAttr === ts1, classOf[Eq[_]], resultFun(ts1))
checkFilterPredicate(tsAttr <=> ts1, classOf[Eq[_]], resultFun(ts1))
checkFilterPredicate(tsAttr =!= ts1, classOf[NotEq[_]],
checkFilterPredicate(tsAttr === ts1.ts, classOf[Eq[_]], resultFun(ts1))
checkFilterPredicate(tsAttr <=> ts1.ts, classOf[Eq[_]], resultFun(ts1))
checkFilterPredicate(tsAttr =!= ts1.ts, classOf[NotEq[_]],
Seq(ts2, ts3, ts4).map(i => Row.apply(resultFun(i))))

checkFilterPredicate(tsAttr < ts2, classOf[Lt[_]], resultFun(ts1))
checkFilterPredicate(tsAttr > ts1, classOf[Gt[_]],
checkFilterPredicate(tsAttr < ts2.ts, classOf[Lt[_]], resultFun(ts1))
checkFilterPredicate(tsAttr > ts1.ts, classOf[Gt[_]],
Seq(ts2, ts3, ts4).map(i => Row.apply(resultFun(i))))
checkFilterPredicate(tsAttr <= ts1, classOf[LtEq[_]], resultFun(ts1))
checkFilterPredicate(tsAttr >= ts4, classOf[GtEq[_]], resultFun(ts4))

checkFilterPredicate(Literal(ts1) === tsAttr, classOf[Eq[_]], resultFun(ts1))
checkFilterPredicate(Literal(ts1) <=> tsAttr, classOf[Eq[_]], resultFun(ts1))
checkFilterPredicate(Literal(ts2) > tsAttr, classOf[Lt[_]], resultFun(ts1))
checkFilterPredicate(Literal(ts3) < tsAttr, classOf[Gt[_]], resultFun(ts4))
checkFilterPredicate(Literal(ts1) >= tsAttr, classOf[LtEq[_]], resultFun(ts1))
checkFilterPredicate(Literal(ts4) <= tsAttr, classOf[GtEq[_]], resultFun(ts4))

checkFilterPredicate(!(tsAttr < ts4), classOf[GtEq[_]], resultFun(ts4))
checkFilterPredicate(tsAttr < ts2 || tsAttr > ts3, classOf[Operators.Or],
checkFilterPredicate(tsAttr <= ts1.ts, classOf[LtEq[_]], resultFun(ts1))
checkFilterPredicate(tsAttr >= ts4.ts, classOf[GtEq[_]], resultFun(ts4))

checkFilterPredicate(Literal(ts1.ts) === tsAttr, classOf[Eq[_]], resultFun(ts1))
checkFilterPredicate(Literal(ts1.ts) <=> tsAttr, classOf[Eq[_]], resultFun(ts1))
checkFilterPredicate(Literal(ts2.ts) > tsAttr, classOf[Lt[_]], resultFun(ts1))
checkFilterPredicate(Literal(ts3.ts) < tsAttr, classOf[Gt[_]], resultFun(ts4))
checkFilterPredicate(Literal(ts1.ts) >= tsAttr, classOf[LtEq[_]], resultFun(ts1))
checkFilterPredicate(Literal(ts4.ts) <= tsAttr, classOf[GtEq[_]], resultFun(ts4))

checkFilterPredicate(!(tsAttr < ts4.ts), classOf[GtEq[_]], resultFun(ts4))
checkFilterPredicate(tsAttr < ts2.ts || tsAttr > ts3.ts, classOf[Operators.Or],
Seq(Row(resultFun(ts1)), Row(resultFun(ts4))))
}
}
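The suite now keeps the test data as strings and converts per run: when `java8Api` is true the expected row values become `Instant`s by interpreting the wall-clock string in the default time zone, otherwise plain `Timestamp.valueOf`. A standalone sketch of exactly that parsing branch (the string literal is one of the test values):

```scala
// Sketch of resultFun's parsing step: the same wall-clock string becomes
// either an Instant (Java 8 API enabled) or a java.sql.Timestamp.
import java.sql.Timestamp
import java.time.{LocalDateTime, ZoneId}

object ParseExpectedValueSketch {
  def parse(tsStr: String, java8Api: Boolean): Any =
    if (java8Api) {
      LocalDateTime.parse(tsStr.replace(" ", "T")) // "2018-06-17 08:28:53.999" -> ISO form
        .atZone(ZoneId.systemDefault())            // interpret in the JVM default zone
        .toInstant
    } else {
      Timestamp.valueOf(tsStr)
    }

  def main(args: Array[String]): Unit = {
    println(parse("2018-06-17 08:28:53.999", java8Api = true))
    println(parse("2018-06-17 08:28:53.999", java8Api = false))
  }
}
```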
@@ -588,36 +602,41 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
}

test("filter pushdown - timestamp") {
// spark.sql.parquet.outputTimestampType = TIMESTAMP_MILLIS
val millisData = Seq(
Timestamp.valueOf("1000-06-14 08:28:53.123"),
Timestamp.valueOf("1582-06-15 08:28:53.001"),
Timestamp.valueOf("1900-06-16 08:28:53.0"),
Timestamp.valueOf("2018-06-17 08:28:53.999"))
withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key ->
ParquetOutputTimestampType.TIMESTAMP_MILLIS.toString) {
testTimestampPushdown(millisData)
}

// spark.sql.parquet.outputTimestampType = TIMESTAMP_MICROS
val microsData = Seq(
Timestamp.valueOf("1000-06-14 08:28:53.123456"),
Timestamp.valueOf("1582-06-15 08:28:53.123456"),
Timestamp.valueOf("1900-06-16 08:28:53.123456"),
Timestamp.valueOf("2018-06-17 08:28:53.123456"))
withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key ->
ParquetOutputTimestampType.TIMESTAMP_MICROS.toString) {
testTimestampPushdown(microsData)
}

// spark.sql.parquet.outputTimestampType = INT96 doesn't support pushdown
withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key ->
ParquetOutputTimestampType.INT96.toString) {
import testImplicits._
withParquetDataFrame(millisData.map(i => Tuple1(i)).toDF()) { implicit df =>
val schema = new SparkToParquetSchemaConverter(conf).convert(df.schema)
assertResult(None) {
createParquetFilters(schema).createFilter(sources.IsNull("_1"))
Seq(true, false).foreach { java8Api =>
withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8Api.toString) {
// spark.sql.parquet.outputTimestampType = TIMESTAMP_MILLIS
val millisData = Seq(
"1000-06-14 08:28:53.123",
"1582-06-15 08:28:53.001",
"1900-06-16 08:28:53.0",
"2018-06-17 08:28:53.999")
withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key ->
ParquetOutputTimestampType.TIMESTAMP_MILLIS.toString) {
testTimestampPushdown(millisData, java8Api)
}

// spark.sql.parquet.outputTimestampType = TIMESTAMP_MICROS
val microsData = Seq(
"1000-06-14 08:28:53.123456",
"1582-06-15 08:28:53.123456",
"1900-06-16 08:28:53.123456",
"2018-06-17 08:28:53.123456")
withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key ->
ParquetOutputTimestampType.TIMESTAMP_MICROS.toString) {
testTimestampPushdown(microsData, java8Api)
}

// spark.sql.parquet.outputTimestampType = INT96 doesn't support pushdown
withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key ->
ParquetOutputTimestampType.INT96.toString) {
import testImplicits._
withParquetDataFrame(
millisData.map(i => Tuple1(Timestamp.valueOf(i))).toDF()) { implicit df =>
val schema = new SparkToParquetSchemaConverter(conf).convert(df.schema)
assertResult(None) {
createParquetFilters(schema).createFilter(sources.IsNull("_1"))
}
}
}
}
}
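The whole test body is now wrapped in a loop over `SQLConf.DATETIME_JAVA8API_ENABLED`, since that flag decides which external class (`Instant` or `Timestamp`) reaches the filter translation in the first place. A hedged sketch of the flag's effect from the user side, assuming a local SparkSession and spark-sql on the classpath; the literal and object names are illustrative.

```scala
// Sketch: the flag decides which external class Spark hands back for TIMESTAMP
// values, which is the value type that later reaches ParquetFilters when such
// a literal appears in a pushed-down predicate.
import org.apache.spark.sql.SparkSession

object Java8ApiFlagSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()

    // Re-run the query after each conf change so the current setting applies.
    def tsClass(): Class[_] =
      spark.sql("SELECT TIMESTAMP '2018-06-17 08:28:53.999' AS ts").first().get(0).getClass

    spark.conf.set("spark.sql.datetime.java8API.enabled", "false") // key of DATETIME_JAVA8API_ENABLED
    println(tsClass()) // expected: class java.sql.Timestamp

    spark.conf.set("spark.sql.datetime.java8API.enabled", "true")
    println(tsClass()) // expected: class java.time.Instant

    spark.stop()
  }
}
```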