5 changes: 4 additions & 1 deletion sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -590,7 +590,10 @@ trait Row extends Serializable {
       case (r: Row, _) => r.jsonValue
       case (v: Any, udt: UserDefinedType[Any @unchecked]) =>
         val dataType = udt.sqlType
-        toJson(CatalystTypeConverters.convertToScala(udt.serialize(v), dataType), dataType)
+        toJson(CatalystTypeConverters.convertToScala(
+          udt.serialize(v),
+          dataType,
+          SQLConf.get.datetimeJava8ApiEnabled), dataType)
       case _ =>
         throw new IllegalArgumentException(s"Failed to convert value $value " +
           s"(class of ${value.getClass}}) with the type of $dataType to JSON.")
@@ -55,17 +55,17 @@ object CatalystTypeConverters {
     }
   }

-  private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = {
+  private def getConverterForType(
+      dataType: DataType,
+      useJava8DateTimeApi: Boolean): CatalystTypeConverter[Any, Any, Any] = {
     val converter = dataType match {
       case udt: UserDefinedType[_] => UDTConverter(udt)
-      case arrayType: ArrayType => ArrayConverter(arrayType.elementType)
-      case mapType: MapType => MapConverter(mapType.keyType, mapType.valueType)
-      case structType: StructType => StructConverter(structType)
+      case arrayType: ArrayType => ArrayConverter(arrayType.elementType, useJava8DateTimeApi)
+      case mapType: MapType => MapConverter(mapType.keyType, mapType.valueType, useJava8DateTimeApi)
+      case structType: StructType => StructConverter(structType, useJava8DateTimeApi)
       case StringType => StringConverter
-      case DateType if SQLConf.get.datetimeJava8ApiEnabled => LocalDateConverter
-      case DateType => DateConverter
-      case TimestampType if SQLConf.get.datetimeJava8ApiEnabled => InstantConverter
-      case TimestampType => TimestampConverter
+      case DateType => if (useJava8DateTimeApi) LocalDateConverter else DateConverter
+      case TimestampType => if (useJava8DateTimeApi) InstantConverter else TimestampConverter
       case dt: DecimalType => new DecimalConverter(dt)
       case BooleanType => BooleanConverter
       case ByteType => ByteConverter
@@ -156,9 +156,10 @@ object CatalystTypeConverters {

   /** Converter for arrays, sequences, and Java iterables. */
   private case class ArrayConverter(
-      elementType: DataType) extends CatalystTypeConverter[Any, Seq[Any], ArrayData] {
+      elementType: DataType,
+      useJava8DateTimeApi: Boolean) extends CatalystTypeConverter[Any, Seq[Any], ArrayData] {

-    private[this] val elementConverter = getConverterForType(elementType)
+    private[this] val elementConverter = getConverterForType(elementType, useJava8DateTimeApi)

     override def toCatalystImpl(scalaValue: Any): ArrayData = {
       scalaValue match {
@@ -200,11 +201,12 @@

   private case class MapConverter(
       keyType: DataType,
-      valueType: DataType)
+      valueType: DataType,
+      useJava8DateTimeApi: Boolean)
     extends CatalystTypeConverter[Any, Map[Any, Any], MapData] {

-    private[this] val keyConverter = getConverterForType(keyType)
-    private[this] val valueConverter = getConverterForType(valueType)
+    private[this] val keyConverter = getConverterForType(keyType, useJava8DateTimeApi)
+    private[this] val valueConverter = getConverterForType(valueType, useJava8DateTimeApi)

     override def toCatalystImpl(scalaValue: Any): MapData = {
       val keyFunction = (k: Any) => keyConverter.toCatalyst(k)
@@ -240,9 +242,11 @@ object CatalystTypeConverters {
   }

   private case class StructConverter(
-      structType: StructType) extends CatalystTypeConverter[Any, Row, InternalRow] {
+      structType: StructType,
+      useJava8DateTimeApi: Boolean) extends CatalystTypeConverter[Any, Row, InternalRow] {

-    private[this] val converters = structType.fields.map { f => getConverterForType(f.dataType) }
+    private[this] val converters = structType.fields
+      .map { f => getConverterForType(f.dataType, useJava8DateTimeApi) }

     override def toCatalystImpl(scalaValue: Any): InternalRow = scalaValue match {
       case row: Row =>
@@ -404,7 +408,9 @@ object CatalystTypeConverters {
    * Typical use case would be converting a collection of rows that have the same schema. You will
    * call this function once to get a converter, and apply it to every row.
    */
-  def createToCatalystConverter(dataType: DataType): Any => Any = {
+  def createToCatalystConverter(
+      dataType: DataType,
+      useJava8DateTimeApi: Boolean = SQLConf.get.datetimeJava8ApiEnabled): Any => Any = {
     if (isPrimitive(dataType)) {
       // Although the `else` branch here is capable of handling inbound conversion of primitives,
       // we add some special-case handling for those types here. The motivation for this relates to
@@ -422,7 +428,7 @@
       }
       convert
     } else {
-      getConverterForType(dataType).toCatalyst
+      getConverterForType(dataType, useJava8DateTimeApi).toCatalyst
     }
   }

@@ -431,11 +437,13 @@
    * Typical use case would be converting a collection of rows that have the same schema. You will
    * call this function once to get a converter, and apply it to every row.
    */
-  def createToScalaConverter(dataType: DataType): Any => Any = {
+  def createToScalaConverter(
+      dataType: DataType,
+      useJava8DateTimeApi: Boolean = SQLConf.get.datetimeJava8ApiEnabled): Any => Any = {
     if (isPrimitive(dataType)) {
       identity
     } else {
-      getConverterForType(dataType).toScala
+      getConverterForType(dataType, useJava8DateTimeApi).toScala
     }
   }

@@ -470,7 +478,7 @@ object CatalystTypeConverters {
    * This method is slow, and for batch conversion you should be using converter
    * produced by createToScalaConverter.
    */
-  def convertToScala(catalystValue: Any, dataType: DataType): Any = {
-    createToScalaConverter(dataType)(catalystValue)
+  def convertToScala(catalystValue: Any, dataType: DataType, useJava8DateTimeApi: Boolean): Any = {
+    createToScalaConverter(dataType, useJava8DateTimeApi)(catalystValue)
   }
 }
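
For readers skimming the new signatures above, here is a minimal usage sketch (not part of this PR's diff) of how the flag flows through the public helpers; the day count used as the Catalyst-side value is an arbitrary illustrative number.

```scala
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.DateType

// Catalyst stores DateType values as the number of days since the Unix epoch.
val catalystDate: Any = 18000  // arbitrary example value

// The factory methods default the flag to SQLConf.get.datetimeJava8ApiEnabled,
// but callers can now pin it explicitly.
val toJava8 = CatalystTypeConverters.createToScalaConverter(DateType, useJava8DateTimeApi = true)
val localDate = toJava8(catalystDate)   // java.time.LocalDate

val toLegacy = CatalystTypeConverters.createToScalaConverter(DateType, useJava8DateTimeApi = false)
val sqlDate = toLegacy(catalystDate)    // java.sql.Date

// convertToScala no longer reads the conf itself; the flag must be supplied.
val value = CatalystTypeConverters.convertToScala(catalystDate, DateType, useJava8DateTimeApi = true)
```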
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.catalyst

 import java.sql.{Date, Timestamp}
+import java.time.LocalDate

 import scala.language.implicitConversions

@@ -146,6 +147,7 @@ package object dsl {
     implicit def doubleToLiteral(d: Double): Literal = Literal(d)
     implicit def stringToLiteral(s: String): Literal = Literal.create(s, StringType)
     implicit def dateToLiteral(d: Date): Literal = Literal(d)
+    implicit def localDateToLiteral(d: LocalDate): Literal = Literal(d)
     implicit def bigDecimalToLiteral(d: BigDecimal): Literal = Literal(d.underlying())
     implicit def bigDecimalToLiteral(d: java.math.BigDecimal): Literal = Literal(d)
     implicit def decimalToLiteral(d: Decimal): Literal = Literal(d)
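
As a quick illustration (a sketch, assuming the usual Catalyst DSL imports used in tests), the new implicit lets java.time.LocalDate values appear directly in DSL expressions, mirroring the existing java.sql.Date support.

```scala
import java.time.LocalDate

import org.apache.spark.sql.catalyst.dsl.expressions._

// 'd.date builds a DateType attribute via the DSL; the LocalDate on the right
// is lifted to a Literal by the new localDateToLiteral implicit.
val predicate = 'd.date > LocalDate.parse("2018-03-20")
```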
@@ -450,45 +450,45 @@ object DataSourceStrategy {

   private def translateLeafNodeFilter(predicate: Expression): Option[Filter] = predicate match {
     case expressions.EqualTo(PushableColumn(name), Literal(v, t)) =>
-      Some(sources.EqualTo(name, convertToScala(v, t)))
+      Some(sources.EqualTo(name, convertToScala(v, t, false)))
Review comment (Member): I guess we're treating this as a temp fix for Spark 3.0? Ideally we should support Java 8 datetime instances for this interface as well when spark.sql.datetime.java8API.enabled is enabled; otherwise it could cause more confusion. In addition, it seems spark.sql.datetime.java8API.enabled is disabled by default, too.

Review comment (Contributor): I think it's problematic to let the java8 config also control the value type inside Filter, as it can break existing DS v1 implementations. It's a bit unfortunate that we don't document clearly what the value type can be for Filter, but even if we did, it's not user-friendly to say "the value type depends on xxx config". This just makes it harder to implement data source filter pushdown.

Reply (Member Author): @HyukjinKwon Taking into account #23811 (comment), the flag won't be enabled by default in the near future.

     case expressions.EqualTo(Literal(v, t), PushableColumn(name)) =>
-      Some(sources.EqualTo(name, convertToScala(v, t)))
+      Some(sources.EqualTo(name, convertToScala(v, t, false)))

     case expressions.EqualNullSafe(PushableColumn(name), Literal(v, t)) =>
-      Some(sources.EqualNullSafe(name, convertToScala(v, t)))
+      Some(sources.EqualNullSafe(name, convertToScala(v, t, false)))
     case expressions.EqualNullSafe(Literal(v, t), PushableColumn(name)) =>
-      Some(sources.EqualNullSafe(name, convertToScala(v, t)))
+      Some(sources.EqualNullSafe(name, convertToScala(v, t, false)))

     case expressions.GreaterThan(PushableColumn(name), Literal(v, t)) =>
-      Some(sources.GreaterThan(name, convertToScala(v, t)))
+      Some(sources.GreaterThan(name, convertToScala(v, t, false)))
     case expressions.GreaterThan(Literal(v, t), PushableColumn(name)) =>
-      Some(sources.LessThan(name, convertToScala(v, t)))
+      Some(sources.LessThan(name, convertToScala(v, t, false)))

     case expressions.LessThan(PushableColumn(name), Literal(v, t)) =>
-      Some(sources.LessThan(name, convertToScala(v, t)))
+      Some(sources.LessThan(name, convertToScala(v, t, false)))
     case expressions.LessThan(Literal(v, t), PushableColumn(name)) =>
-      Some(sources.GreaterThan(name, convertToScala(v, t)))
+      Some(sources.GreaterThan(name, convertToScala(v, t, false)))

     case expressions.GreaterThanOrEqual(PushableColumn(name), Literal(v, t)) =>
-      Some(sources.GreaterThanOrEqual(name, convertToScala(v, t)))
+      Some(sources.GreaterThanOrEqual(name, convertToScala(v, t, false)))
     case expressions.GreaterThanOrEqual(Literal(v, t), PushableColumn(name)) =>
-      Some(sources.LessThanOrEqual(name, convertToScala(v, t)))
+      Some(sources.LessThanOrEqual(name, convertToScala(v, t, false)))

     case expressions.LessThanOrEqual(PushableColumn(name), Literal(v, t)) =>
-      Some(sources.LessThanOrEqual(name, convertToScala(v, t)))
+      Some(sources.LessThanOrEqual(name, convertToScala(v, t, false)))
     case expressions.LessThanOrEqual(Literal(v, t), PushableColumn(name)) =>
-      Some(sources.GreaterThanOrEqual(name, convertToScala(v, t)))
+      Some(sources.GreaterThanOrEqual(name, convertToScala(v, t, false)))

     case expressions.InSet(e @ PushableColumn(name), set) =>
-      val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
+      val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType, false)
       Some(sources.In(name, set.toArray.map(toScala)))

     // Because we only convert In to InSet in Optimizer when there are more than certain
     // items. So it is possible we still get an In expression here that needs to be pushed
     // down.
     case expressions.In(e @ PushableColumn(name), list) if list.forall(_.isInstanceOf[Literal]) =>
       val hSet = list.map(_.eval(EmptyRow))
-      val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
+      val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType, false)
       Some(sources.In(name, hSet.toArray.map(toScala)))

     case expressions.IsNull(PushableColumn(name)) =>
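
As the thread above notes, the hard-coded `false` keeps the public Filter API stable: values pushed to DS v1 sources stay java.sql.Date / java.sql.Timestamp even when spark.sql.datetime.java8API.enabled is on. A minimal sketch of what an existing source can keep relying on (the column name and helper are hypothetical, not from the PR):

```scala
import java.sql.Date

import org.apache.spark.sql.sources.{EqualTo, Filter}

// A DS v1 relation inspecting pushed filters can keep matching on the legacy
// java.sql.Date type regardless of the session's Java 8 datetime setting,
// because translateLeafNodeFilter always converts with useJava8DateTimeApi = false.
def supportsFilter(filter: Filter): Boolean = filter match {
  case EqualTo("eventDate", _: Date) => true  // "eventDate" is a hypothetical column
  case _ => false
}
```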
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet
 import java.math.{BigDecimal => JBigDecimal}
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
+import java.time.LocalDate

 import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate, Operators}
 import org.apache.parquet.filter2.predicate.FilterApi._
@@ -1561,6 +1562,63 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
       }
     }
   }
+
+  test("filter pushdown - local date") {
Review comment (Contributor): Can we just update test("filter pushdown - date") to test with DATETIME_JAVA8API_ENABLED on and off, so that we have less duplicated code?

Reply (Member Author): done

+    implicit class StringToDate(s: String) {
+      def date: LocalDate = LocalDate.parse(s)
+    }
+
+    val data = Seq("2018-03-18", "2018-03-19", "2018-03-20", "2018-03-21").map(_.date)
+    import testImplicits._
+    withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
+      withNestedDataFrame(data.map(i => Tuple1(i)).toDF()) { case (inputDF, colName, resultFun) =>
+        withParquetDataFrame(inputDF) { implicit df =>
+          val dateAttr: Expression = df(colName).expr
+          assert(df(colName).expr.dataType === DateType)
+
+          checkFilterPredicate(dateAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
+          checkFilterPredicate(dateAttr.isNotNull, classOf[NotEq[_]],
+            data.map(i => Row.apply(resultFun(i))))
+
+          checkFilterPredicate(dateAttr === "2018-03-18".date, classOf[Eq[_]],
+            resultFun("2018-03-18".date))
+          checkFilterPredicate(dateAttr <=> "2018-03-18".date, classOf[Eq[_]],
+            resultFun("2018-03-18".date))
+          checkFilterPredicate(dateAttr =!= "2018-03-18".date, classOf[NotEq[_]],
+            Seq("2018-03-19", "2018-03-20", "2018-03-21").map(i => Row.apply(resultFun(i.date))))
+
+          checkFilterPredicate(dateAttr < "2018-03-19".date, classOf[Lt[_]],
+            resultFun("2018-03-18".date))
+          checkFilterPredicate(dateAttr > "2018-03-20".date, classOf[Gt[_]],
+            resultFun("2018-03-21".date))
+          checkFilterPredicate(dateAttr <= "2018-03-18".date, classOf[LtEq[_]],
+            resultFun("2018-03-18".date))
+          checkFilterPredicate(dateAttr >= "2018-03-21".date, classOf[GtEq[_]],
+            resultFun("2018-03-21".date))
+
+          checkFilterPredicate(Literal("2018-03-18".date) === dateAttr, classOf[Eq[_]],
+            resultFun("2018-03-18".date))
+          checkFilterPredicate(Literal("2018-03-18".date) <=> dateAttr, classOf[Eq[_]],
+            resultFun("2018-03-18".date))
+          checkFilterPredicate(Literal("2018-03-19".date) > dateAttr, classOf[Lt[_]],
+            resultFun("2018-03-18".date))
+          checkFilterPredicate(Literal("2018-03-20".date) < dateAttr, classOf[Gt[_]],
+            resultFun("2018-03-21".date))
+          checkFilterPredicate(Literal("2018-03-18".date) >= dateAttr, classOf[LtEq[_]],
+            resultFun("2018-03-18".date))
+          checkFilterPredicate(Literal("2018-03-21".date) <= dateAttr, classOf[GtEq[_]],
+            resultFun("2018-03-21".date))
+
+          checkFilterPredicate(!(dateAttr < "2018-03-21".date), classOf[GtEq[_]],
+            resultFun("2018-03-21".date))
+          checkFilterPredicate(
+            dateAttr < "2018-03-19".date || dateAttr > "2018-03-20".date,
+            classOf[Operators.Or],
+            Seq(Row(resultFun("2018-03-18".date)), Row(resultFun("2018-03-21".date))))
+        }
+      }
+    }
+  }
 }

 class ParquetV1FilterSuite extends ParquetFilterSuite {
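
The final shape of this test after the "done" above is not shown in this commit range. A rough sketch of the reviewer's idea, with checkDateFilters standing in as a hypothetical helper for the shared assertions, could look like:

```scala
// Run the same date pushdown assertions with the Java 8 API flag off and on,
// instead of keeping a separate "local date" test.
Seq(false, true).foreach { java8Api =>
  withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> java8Api.toString) {
    // With the flag on, input rows and expected literals use java.time.LocalDate;
    // with it off, java.sql.Date. The generated Parquet predicates should match.
    checkDateFilters(useJava8Api = java8Api)  // hypothetical shared helper
  }
}
```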
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.orc
 import java.math.MathContext
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
+import java.time.LocalDate

 import scala.collection.JavaConverters._

@@ -450,5 +451,31 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession {
       ).get.toString
     }
   }
+
+  test("filter pushdown - local date") {
Review comment (Contributor): Ditto (same suggestion as for the Parquet test above).

Reply (Member Author): done

+    val dates = Seq("2017-08-18", "2017-08-19", "2017-08-20", "2017-08-21").map { day =>
+      LocalDate.parse(day)
+    }
+    withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
+      withOrcDataFrame(dates.map(Tuple1(_))) { implicit df =>
+        checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL)
+
+        checkFilterPredicate($"_1" === dates(0), PredicateLeaf.Operator.EQUALS)
+        checkFilterPredicate($"_1" <=> dates(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+
+        checkFilterPredicate($"_1" < dates(1), PredicateLeaf.Operator.LESS_THAN)
+        checkFilterPredicate($"_1" > dates(2), PredicateLeaf.Operator.LESS_THAN_EQUALS)
+        checkFilterPredicate($"_1" <= dates(0), PredicateLeaf.Operator.LESS_THAN_EQUALS)
+        checkFilterPredicate($"_1" >= dates(3), PredicateLeaf.Operator.LESS_THAN)
+
+        checkFilterPredicate(Literal(dates(0)) === $"_1", PredicateLeaf.Operator.EQUALS)
+        checkFilterPredicate(Literal(dates(0)) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+        checkFilterPredicate(Literal(dates(1)) > $"_1", PredicateLeaf.Operator.LESS_THAN)
+        checkFilterPredicate(Literal(dates(2)) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS)
+        checkFilterPredicate(Literal(dates(0)) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS)
+        checkFilterPredicate(Literal(dates(3)) <= $"_1", PredicateLeaf.Operator.LESS_THAN)
+      }
+    }
+  }
 }

@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.orc
 import java.math.MathContext
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
+import java.time.LocalDate

 import scala.collection.JavaConverters._

@@ -451,5 +452,31 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession {
       ).get.toString
     }
   }
+
+  test("filter pushdown - local date") {
Review comment (Contributor): Ditto (same suggestion as for the Parquet test above).

Reply (Member Author): done

+    val dates = Seq("2017-08-18", "2017-08-19", "2017-08-20", "2017-08-21").map { day =>
+      LocalDate.parse(day)
+    }
+    withSQLConf(SQLConf.DATETIME_JAVA8API_ENABLED.key -> "true") {
+      withOrcDataFrame(dates.map(Tuple1(_))) { implicit df =>
+        checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL)
+
+        checkFilterPredicate($"_1" === dates(0), PredicateLeaf.Operator.EQUALS)
+        checkFilterPredicate($"_1" <=> dates(0), PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+
+        checkFilterPredicate($"_1" < dates(1), PredicateLeaf.Operator.LESS_THAN)
+        checkFilterPredicate($"_1" > dates(2), PredicateLeaf.Operator.LESS_THAN_EQUALS)
+        checkFilterPredicate($"_1" <= dates(0), PredicateLeaf.Operator.LESS_THAN_EQUALS)
+        checkFilterPredicate($"_1" >= dates(3), PredicateLeaf.Operator.LESS_THAN)
+
+        checkFilterPredicate(Literal(dates(0)) === $"_1", PredicateLeaf.Operator.EQUALS)
+        checkFilterPredicate(Literal(dates(0)) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+        checkFilterPredicate(Literal(dates(1)) > $"_1", PredicateLeaf.Operator.LESS_THAN)
+        checkFilterPredicate(Literal(dates(2)) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS)
+        checkFilterPredicate(Literal(dates(0)) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS)
+        checkFilterPredicate(Literal(dates(3)) <= $"_1", PredicateLeaf.Operator.LESS_THAN)
+      }
+    }
+  }
 }