diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index d44cb11e28762..4338360137b82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -444,69 +444,81 @@ object DataSourceStrategy { } } - private def translateLeafNodeFilter(predicate: Expression): Option[Filter] = predicate match { - case expressions.EqualTo(a: Attribute, Literal(v, t)) => - Some(sources.EqualTo(a.name, convertToScala(v, t))) - case expressions.EqualTo(Literal(v, t), a: Attribute) => - Some(sources.EqualTo(a.name, convertToScala(v, t))) - - case expressions.EqualNullSafe(a: Attribute, Literal(v, t)) => - Some(sources.EqualNullSafe(a.name, convertToScala(v, t))) - case expressions.EqualNullSafe(Literal(v, t), a: Attribute) => - Some(sources.EqualNullSafe(a.name, convertToScala(v, t))) - - case expressions.GreaterThan(a: Attribute, Literal(v, t)) => - Some(sources.GreaterThan(a.name, convertToScala(v, t))) - case expressions.GreaterThan(Literal(v, t), a: Attribute) => - Some(sources.LessThan(a.name, convertToScala(v, t))) - - case expressions.LessThan(a: Attribute, Literal(v, t)) => - Some(sources.LessThan(a.name, convertToScala(v, t))) - case expressions.LessThan(Literal(v, t), a: Attribute) => - Some(sources.GreaterThan(a.name, convertToScala(v, t))) - - case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, t)) => - Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t))) - case expressions.GreaterThanOrEqual(Literal(v, t), a: Attribute) => - Some(sources.LessThanOrEqual(a.name, convertToScala(v, t))) - - case expressions.LessThanOrEqual(a: Attribute, Literal(v, t)) => - Some(sources.LessThanOrEqual(a.name, convertToScala(v, t))) - case expressions.LessThanOrEqual(Literal(v, t), a: Attribute) => - Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t))) - - case expressions.InSet(a: Attribute, set) => - val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType) - Some(sources.In(a.name, set.toArray.map(toScala))) - - // Because we only convert In to InSet in Optimizer when there are more than certain - // items. So it is possible we still get an In expression here that needs to be pushed - // down. 
- case expressions.In(a: Attribute, list) if list.forall(_.isInstanceOf[Literal]) => - val hSet = list.map(_.eval(EmptyRow)) - val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType) - Some(sources.In(a.name, hSet.toArray.map(toScala))) - - case expressions.IsNull(a: Attribute) => - Some(sources.IsNull(a.name)) - case expressions.IsNotNull(a: Attribute) => - Some(sources.IsNotNull(a.name)) - case expressions.StartsWith(a: Attribute, Literal(v: UTF8String, StringType)) => - Some(sources.StringStartsWith(a.name, v.toString)) - - case expressions.EndsWith(a: Attribute, Literal(v: UTF8String, StringType)) => - Some(sources.StringEndsWith(a.name, v.toString)) - - case expressions.Contains(a: Attribute, Literal(v: UTF8String, StringType)) => - Some(sources.StringContains(a.name, v.toString)) - - case expressions.Literal(true, BooleanType) => - Some(sources.AlwaysTrue) - - case expressions.Literal(false, BooleanType) => - Some(sources.AlwaysFalse) - - case _ => None + private def translateLeafNodeFilter(predicate: Expression): Option[Filter] = { + // Recursively try to find an attribute name from the top level that can be pushed down. + def attrName(e: Expression): Option[String] = e match { + case a: Attribute if a.dataType != StructType => + Some(a.name) + case s: GetStructField if s.childSchema(s.ordinal).dataType != StructType => + attrName(s.child).map(_ + s".${s.childSchema(s.ordinal).name}") + case _ => + None + } + + predicate match { + case expressions.EqualTo(e: Expression, Literal(v, t)) => + attrName(e).map(name => sources.EqualTo(name, convertToScala(v, t))) + case expressions.EqualTo(Literal(v, t), e: Expression) => + attrName(e).map(name => sources.EqualTo(name, convertToScala(v, t))) + + case expressions.EqualNullSafe(e: Expression, Literal(v, t)) => + attrName(e).map(name => sources.EqualNullSafe(name, convertToScala(v, t))) + case expressions.EqualNullSafe(Literal(v, t), e: Expression) => + attrName(e).map(name => sources.EqualNullSafe(name, convertToScala(v, t))) + + case expressions.GreaterThan(e: Expression, Literal(v, t)) => + attrName(e).map(name => sources.GreaterThan(name, convertToScala(v, t))) + case expressions.GreaterThan(Literal(v, t), e: Expression) => + attrName(e).map(name => sources.LessThan(name, convertToScala(v, t))) + + case expressions.LessThan(e: Expression, Literal(v, t)) => + attrName(e).map(name => sources.LessThan(name, convertToScala(v, t))) + case expressions.LessThan(Literal(v, t), e: Expression) => + attrName(e).map(name => sources.GreaterThan(name, convertToScala(v, t))) + + case expressions.GreaterThanOrEqual(e: Expression, Literal(v, t)) => + attrName(e).map(name => sources.GreaterThanOrEqual(name, convertToScala(v, t))) + case expressions.GreaterThanOrEqual(Literal(v, t), e: Expression) => + attrName(e).map(name => sources.LessThanOrEqual(name, convertToScala(v, t))) + + case expressions.LessThanOrEqual(e: Expression, Literal(v, t)) => + attrName(e).map(name => sources.LessThanOrEqual(name, convertToScala(v, t))) + case expressions.LessThanOrEqual(Literal(v, t), e: Expression) => + attrName(e).map(name => sources.GreaterThanOrEqual(name, convertToScala(v, t))) + + case expressions.InSet(e: Expression, set) => + val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType) + attrName(e).map(name => sources.In(name, set.toArray.map(toScala))) + + // Because we only convert In to InSet in Optimizer when there are more than certain + // items. 
So it is possible we still get an In expression here that needs to be pushed + // down. + case expressions.In(e: Expression, list) if list.forall(_.isInstanceOf[Literal]) => + val hSet = list.map(_.eval(EmptyRow)) + val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType) + attrName(e).map(name => sources.In(name, hSet.toArray.map(toScala))) + + case expressions.IsNull(e: Expression) => + attrName(e).map(name => sources.IsNull(name)) + case expressions.IsNotNull(e: Expression) => + attrName(e).map(name => sources.IsNotNull(name)) + case expressions.StartsWith(e: Expression, Literal(v: UTF8String, StringType)) => + attrName(e).map(name => sources.StringStartsWith(name, v.toString)) + + case expressions.EndsWith(e: Expression, Literal(v: UTF8String, StringType)) => + attrName(e).map(name => sources.StringEndsWith(name, v.toString)) + + case expressions.Contains(e: Expression, Literal(v: UTF8String, StringType)) => + attrName(e).map(name => sources.StringContains(name, v.toString)) + + case expressions.Literal(true, BooleanType) => + Some(sources.AlwaysTrue) + + case expressions.Literal(false, BooleanType) => + Some(sources.AlwaysFalse) + + case _ => None + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index b9b86adb438e6..3cd0d840c08bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -27,7 +27,7 @@ import scala.collection.JavaConverters.asScalaBufferConverter import org.apache.parquet.filter2.predicate._ import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.io.api.Binary -import org.apache.parquet.schema.{DecimalMetadata, MessageType, OriginalType, PrimitiveComparator} +import org.apache.parquet.schema._ import org.apache.parquet.schema.OriginalType._ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ @@ -49,16 +49,47 @@ class ParquetFilters( pushDownInFilterThreshold: Int, caseSensitive: Boolean) { // A map which contains parquet field name and data type, if predicate push down applies. - private val nameToParquetField : Map[String, ParquetField] = { - // Here we don't flatten the fields in the nested schema but just look up through - // root fields. Currently, accessing to nested fields does not push down filters - // and it does not support to create filters for them. - val primitiveFields = - schema.getFields.asScala.filter(_.isPrimitive).map(_.asPrimitiveType()).map { f => - f.getName -> ParquetField(f.getName, - ParquetSchemaType(f.getOriginalType, - f.getPrimitiveTypeName, f.getTypeLength, f.getDecimalMetadata)) + private val nameToParquetField : Map[String, ParquetPrimitiveField] = { + def canPushDownField(field: Type): Boolean = { + if (field.getName.contains(".")) { + // Parquet does not allow dots in the column name because dots are used as a column path + // delimiter. Since Parquet 1.8.2 (PARQUET-389), Parquet accepts the filter predicates + // with missing columns. The incorrect results could be got from Parquet when we push down + // filters for the column having dots in the names. Thus, we do not push down such filters. + // See SPARK-20364. 
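For intuition, here is a minimal self-contained sketch of the recursion that the new `attrName` helper in `translateLeafNodeFilter` performs: a chain of struct-field accesses is folded into a dot-separated column name, and anything that is not a plain attribute or a struct access makes the predicate non-pushable. It also shows why a physical column whose name literally contains a dot has to be excluded (see the SPARK-20364 comment above): such a name is indistinguishable from a nested path. The `Expr`, `Attr`, `GetField` and `Abs` types below are toy stand-ins invented for this illustration, not the Catalyst classes.

```scala
// Simplified stand-ins for Catalyst's Attribute / GetStructField (illustration only).
sealed trait Expr
case class Attr(name: String) extends Expr
case class GetField(child: Expr, fieldName: String) extends Expr
case class Abs(child: Expr) extends Expr // an example of an unsupported expression

// Mirrors the idea of attrName: walk from the leaf access back to the root
// attribute, joining field names with dots; bail out on anything else.
def attrName(e: Expr): Option[String] = e match {
  case Attr(name)             => Some(name)
  case GetField(child, field) => attrName(child).map(_ + "." + field)
  case _                      => None
}

// Two-level nested access yields a dotted name; unsupported expressions yield None.
assert(attrName(GetField(GetField(Attr("b"), "c"), "cint")) == Some("b.c.cint"))
assert(attrName(Abs(Attr("cint"))).isEmpty)
```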
+ false + } else { + field match { + case _: PrimitiveType => true + // Parquet only supports push-down for primitive types; as a result, Map and List types + // are filtered out. FYI, when g is a `Struct`, `g.getOriginalType` is `null`. + // When g is a `Map`, `g.getOriginalType` is `MAP`. + // When g is a `List`, `g.getOriginalType` is `LIST`. + case g: GroupType if g.getOriginalType == null => true + case _ => false + } + } + } + + def getFieldMapHelper( + fields: Seq[Type], + baseName: Option[String] = None): Seq[(String, ParquetPrimitiveField)] = { + fields.filter(canPushDownField).flatMap { field => + val name = baseName.map(_ + "." + field.getName).getOrElse(field.getName) + field match { + case p: PrimitiveType => + val primitiveField = ParquetPrimitiveField(fieldName = name, + fieldType = ParquetSchemaType(p.getOriginalType, + p.getPrimitiveTypeName, p.getTypeLength, p.getDecimalMetadata)) + Some((name, primitiveField)) + case g: GroupType => + getFieldMapHelper(g.getFields.asScala, Some(name)) + } + } } + + val primitiveFields = getFieldMapHelper(schema.getFields.asScala) + if (caseSensitive) { primitiveFields.toMap } else { @@ -74,13 +105,14 @@ class ParquetFilters( } } + /** * Holds a single field information stored in the underlying parquet file. * * @param fieldName field name in parquet file * @param fieldType field type related info in parquet file */ - private case class ParquetField( + private case class ParquetPrimitiveField( fieldName: String, fieldType: ParquetSchemaType) @@ -466,13 +498,8 @@ class ParquetFilters( case _ => false } - // Parquet does not allow dots in the column name because dots are used as a column path - // delimiter. Since Parquet 1.8.2 (PARQUET-389), Parquet accepts the filter predicates - // with missing columns. The incorrect results could be got from Parquet when we push down - // filters for the column having dots in the names. Thus, we do not push down such filters. - // See SPARK-20364. private def canMakeFilterOn(name: String, value: Any): Boolean = { - nameToParquetField.contains(name) && !name.contains(".") && valueCanMakeFilterOn(name, value) + nameToParquetField.contains(name) && valueCanMakeFilterOn(name, value) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala index 1421ffd8b6de4..d55ebdf73b2a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala @@ -59,8 +59,7 @@ case class OrcScanBuilder( // changed `hadoopConf` in executors. 
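Stepping back to the `ParquetFilters` change above: `getFieldMapHelper` builds the same kind of dotted lookup table, but from the Parquet file schema. A rough, self-contained approximation of that flattening using Spark's public `StructType` (the real code walks Parquet `GroupType`s, keeps decimal/length metadata, and applies case sensitivity; the schema below is made up):

```scala
import org.apache.spark.sql.types._

// Recursively collect primitive leaves of a (possibly nested) schema as
// dotted-name -> DataType pairs, skipping dotted field names and complex containers.
def flattenPrimitiveFields(
    schema: StructType,
    prefix: Option[String] = None): Seq[(String, DataType)] = {
  schema.fields.toSeq.filterNot(_.name.contains(".")).flatMap { field =>
    val name = prefix.map(_ + "." + field.name).getOrElse(field.name)
    field.dataType match {
      case s: StructType             => flattenPrimitiveFields(s, Some(name))
      case _: ArrayType | _: MapType => Nil // no pushdown for maps and lists
      case other                     => Seq(name -> other)
    }
  }
}

val schema = new StructType()
  .add("cint", IntegerType)
  .add("a", new StructType().add("cstr", StringType))

// -> Seq(("cint", IntegerType), ("a.cstr", StringType))
println(flattenPrimitiveFields(schema))
```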
OrcInputFormat.setSearchArgument(hadoopConf, f, schema.fieldNames) } - val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap - _pushedFilters = OrcFilters.convertibleFilters(schema, dataTypeMap, filters).toArray + _pushedFilters = OrcFilters.convertibleFilters(schema, filters).toArray } filters } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala index b76db70494cf8..d8e0fa34144f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -22,76 +22,153 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.sources import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} class DataSourceStrategySuite extends PlanTest with SharedSparkSession { test("translate simple expression") { + val fields = StructField("cint", IntegerType, nullable = true) :: + StructField("cstr", StringType, nullable = true) :: Nil val attrInt = 'cint.int val attrStr = 'cstr.string + val attrNested1 = 'a.struct(StructType(fields)) + val attrNested2 = 'b.struct(StructType( + StructField("c", StructType(fields), nullable = true) :: Nil)) + + val attrIntNested1 = GetStructField(attrNested1, 0, None) + val attrStrNested1 = GetStructField(attrNested1, 1, None) + + val attrIntNested2 = GetStructField(GetStructField(attrNested2, 0, None), 0, None) + val attrStrNested2 = GetStructField(GetStructField(attrNested2, 0, None), 1, None) + + Seq(('cint.int, 'cstr.string, "cint", "cstr"), // no nesting + (attrIntNested1, attrStrNested1, "a.cint", "a.cstr"), // one level nesting + (attrIntNested2, attrStrNested2, "b.c.cint", "b.c.cstr") // two level nesting + ).foreach { case (attrInt, attrStr, attrIntString, attrStrString) => + testTranslateSimpleExpression( + attrInt, attrStr, attrIntString, attrStrString, isResultNone = false) + } + } + + test("translate complex expression") { + val fields = StructField("cint", IntegerType, nullable = true) :: Nil - testTranslateFilter(EqualTo(attrInt, 1), Some(sources.EqualTo("cint", 1))) - testTranslateFilter(EqualTo(1, attrInt), Some(sources.EqualTo("cint", 1))) + val attrNested1 = 'a.struct(StructType(fields)) + val attrNested2 = 'b.struct(StructType( + StructField("c", StructType(fields), nullable = true) :: Nil)) + + val attrIntNested1 = GetStructField(attrNested1, 0, None) + val attrIntNested2 = GetStructField(GetStructField(attrNested2, 0, None), 0, None) + + StructField("cint", IntegerType, nullable = true) + + Seq(('cint.int, "cint"), // no nesting + (attrIntNested1, "a.cint"), // one level nesting + (attrIntNested2, "b.c.cint") // two level nesting + ).foreach { case (attrInt, attrIntString) => + testTranslateComplexExpression(attrInt, attrIntString, isResultNone = false) + } + } + + // `isResultNone` is used when testing invalid input expression + // containing dots which translates into None + private def testTranslateSimpleExpression( + attrInt: Expression, + attrStr: Expression, + attrIntString: String, + attrStrString: String, + isResultNone: Boolean): Unit = { + + def result(result: sources.Filter): Option[sources.Filter] = { + if (isResultNone) { + None + } else { + Some(result) + } + } + + 
testTranslateFilter(EqualTo(attrInt, 1), result(sources.EqualTo(attrIntString, 1))) + testTranslateFilter(EqualTo(1, attrInt), result(sources.EqualTo(attrIntString, 1))) testTranslateFilter(EqualNullSafe(attrStr, Literal(null)), - Some(sources.EqualNullSafe("cstr", null))) + result(sources.EqualNullSafe(attrStrString, null))) testTranslateFilter(EqualNullSafe(Literal(null), attrStr), - Some(sources.EqualNullSafe("cstr", null))) + result(sources.EqualNullSafe(attrStrString, null))) - testTranslateFilter(GreaterThan(attrInt, 1), Some(sources.GreaterThan("cint", 1))) - testTranslateFilter(GreaterThan(1, attrInt), Some(sources.LessThan("cint", 1))) + testTranslateFilter(GreaterThan(attrInt, 1), result(sources.GreaterThan(attrIntString, 1))) + testTranslateFilter(GreaterThan(1, attrInt), result(sources.LessThan(attrIntString, 1))) - testTranslateFilter(LessThan(attrInt, 1), Some(sources.LessThan("cint", 1))) - testTranslateFilter(LessThan(1, attrInt), Some(sources.GreaterThan("cint", 1))) + testTranslateFilter(LessThan(attrInt, 1), result(sources.LessThan(attrIntString, 1))) + testTranslateFilter(LessThan(1, attrInt), result(sources.GreaterThan(attrIntString, 1))) - testTranslateFilter(GreaterThanOrEqual(attrInt, 1), Some(sources.GreaterThanOrEqual("cint", 1))) - testTranslateFilter(GreaterThanOrEqual(1, attrInt), Some(sources.LessThanOrEqual("cint", 1))) + testTranslateFilter(GreaterThanOrEqual(attrInt, 1), + result(sources.GreaterThanOrEqual(attrIntString, 1))) + testTranslateFilter(GreaterThanOrEqual(1, attrInt), + result(sources.LessThanOrEqual(attrIntString, 1))) - testTranslateFilter(LessThanOrEqual(attrInt, 1), Some(sources.LessThanOrEqual("cint", 1))) - testTranslateFilter(LessThanOrEqual(1, attrInt), Some(sources.GreaterThanOrEqual("cint", 1))) + testTranslateFilter(LessThanOrEqual(attrInt, 1), + result(sources.LessThanOrEqual(attrIntString, 1))) + testTranslateFilter(LessThanOrEqual(1, attrInt), + result(sources.GreaterThanOrEqual(attrIntString, 1))) - testTranslateFilter(InSet(attrInt, Set(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3)))) + testTranslateFilter(InSet(attrInt, Set(1, 2, 3)), + result(sources.In(attrIntString, Array(1, 2, 3)))) - testTranslateFilter(In(attrInt, Seq(1, 2, 3)), Some(sources.In("cint", Array(1, 2, 3)))) + testTranslateFilter(In(attrInt, Seq(1, 2, 3)), + result(sources.In(attrIntString, Array(1, 2, 3)))) - testTranslateFilter(IsNull(attrInt), Some(sources.IsNull("cint"))) - testTranslateFilter(IsNotNull(attrInt), Some(sources.IsNotNull("cint"))) + testTranslateFilter(IsNull(attrInt), result(sources.IsNull(attrIntString))) + testTranslateFilter(IsNotNull(attrInt), result(sources.IsNotNull(attrIntString))) // cint > 1 AND cint < 10 testTranslateFilter(And( GreaterThan(attrInt, 1), LessThan(attrInt, 10)), - Some(sources.And( - sources.GreaterThan("cint", 1), - sources.LessThan("cint", 10)))) + result(sources.And( + sources.GreaterThan(attrIntString, 1), + sources.LessThan(attrIntString, 10)))) // cint >= 8 OR cint <= 2 testTranslateFilter(Or( GreaterThanOrEqual(attrInt, 8), LessThanOrEqual(attrInt, 2)), - Some(sources.Or( - sources.GreaterThanOrEqual("cint", 8), - sources.LessThanOrEqual("cint", 2)))) + result(sources.Or( + sources.GreaterThanOrEqual(attrIntString, 8), + sources.LessThanOrEqual(attrIntString, 2)))) testTranslateFilter(Not(GreaterThanOrEqual(attrInt, 8)), - Some(sources.Not(sources.GreaterThanOrEqual("cint", 8)))) + result(sources.Not(sources.GreaterThanOrEqual(attrIntString, 8)))) - testTranslateFilter(StartsWith(attrStr, "a"), 
Some(sources.StringStartsWith("cstr", "a"))) + testTranslateFilter(StartsWith(attrStr, "a"), + result(sources.StringStartsWith(attrStrString, "a"))) - testTranslateFilter(EndsWith(attrStr, "a"), Some(sources.StringEndsWith("cstr", "a"))) + testTranslateFilter(EndsWith(attrStr, "a"), result(sources.StringEndsWith(attrStrString, "a"))) - testTranslateFilter(Contains(attrStr, "a"), Some(sources.StringContains("cstr", "a"))) + testTranslateFilter(Contains(attrStr, "a"), result(sources.StringContains(attrStrString, "a"))) } - test("translate complex expression") { - val attrInt = 'cint.int + // `isResultNone` is used when testing invalid input expression + // containing dots which translates into None + private def testTranslateComplexExpression( + attrInt: Expression, + attrIntString: String, + isResultNone: Boolean): Unit = { + + def result(result: sources.Filter): Option[sources.Filter] = { + if (isResultNone) { + None + } else { + Some(result) + } + } - // ABS(cint) - 2 <= 1 + // ABS(attrInt) - 2 <= 1 testTranslateFilter(LessThanOrEqual( // Expressions are not supported // Functions such as 'Abs' are not supported Subtract(Abs(attrInt), 2), 1), None) - // (cin1 > 1 AND cint < 10) OR (cint > 50 AND cint > 100) + // (attrInt > 1 AND attrInt < 10) OR (attrInt > 50 AND attrInt > 100) testTranslateFilter(Or( And( GreaterThan(attrInt, 1), @@ -100,16 +177,16 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { And( GreaterThan(attrInt, 50), LessThan(attrInt, 100))), - Some(sources.Or( + result(sources.Or( sources.And( - sources.GreaterThan("cint", 1), - sources.LessThan("cint", 10)), + sources.GreaterThan(attrIntString, 1), + sources.LessThan(attrIntString, 10)), sources.And( - sources.GreaterThan("cint", 50), - sources.LessThan("cint", 100))))) + sources.GreaterThan(attrIntString, 50), + sources.LessThan(attrIntString, 100))))) // SPARK-22548 Incorrect nested AND expression pushed down to JDBC data source - // (cint > 1 AND ABS(cint) < 10) OR (cint < 50 AND cint > 100) + // (attrInt > 1 AND ABS(attrInt) < 10) OR (attrInt < 50 AND attrInt > 100) testTranslateFilter(Or( And( GreaterThan(attrInt, 1), @@ -120,7 +197,7 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { GreaterThan(attrInt, 50), LessThan(attrInt, 100))), None) - // NOT ((cint <= 1 OR ABS(cint) >= 10) AND (cint <= 50 OR cint >= 100)) + // NOT ((attrInt <= 1 OR ABS(attrInt) >= 10) AND (attrInt <= 50 OR attrInt >= 100)) testTranslateFilter(Not(And( Or( LessThanOrEqual(attrInt, 1), @@ -131,7 +208,7 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { LessThanOrEqual(attrInt, 50), GreaterThanOrEqual(attrInt, 100)))), None) - // (cint = 1 OR cint = 10) OR (cint > 0 OR cint < -10) + // (attrInt = 1 OR attrInt = 10) OR (attrInt > 0 OR attrInt < -10) testTranslateFilter(Or( Or( EqualTo(attrInt, 1), @@ -140,15 +217,15 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { Or( GreaterThan(attrInt, 0), LessThan(attrInt, -10))), - Some(sources.Or( + result(sources.Or( sources.Or( - sources.EqualTo("cint", 1), - sources.EqualTo("cint", 10)), + sources.EqualTo(attrIntString, 1), + sources.EqualTo(attrIntString, 10)), sources.Or( - sources.GreaterThan("cint", 0), - sources.LessThan("cint", -10))))) + sources.GreaterThan(attrIntString, 0), + sources.LessThan(attrIntString, -10))))) - // (cint = 1 OR ABS(cint) = 10) OR (cint > 0 OR cint < -10) + // (attrInt = 1 OR ABS(attrInt) = 10) OR (attrInt > 0 OR attrInt < -10) testTranslateFilter(Or( Or( EqualTo(attrInt, 
1), @@ -162,7 +239,7 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { // In end-to-end testing, conjunctive predicate should has been split // before reaching DataSourceStrategy.translateFilter. // This is for UT purpose to test each [[case]]. - // (cint > 1 AND cint < 10) AND (cint = 6 AND cint IS NOT NULL) + // (attrInt > 1 AND attrInt < 10) AND (attrInt = 6 AND attrInt IS NOT NULL) testTranslateFilter(And( And( GreaterThan(attrInt, 1), @@ -171,15 +248,15 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { And( EqualTo(attrInt, 6), IsNotNull(attrInt))), - Some(sources.And( + result(sources.And( sources.And( - sources.GreaterThan("cint", 1), - sources.LessThan("cint", 10)), + sources.GreaterThan(attrIntString, 1), + sources.LessThan(attrIntString, 10)), sources.And( - sources.EqualTo("cint", 6), - sources.IsNotNull("cint"))))) + sources.EqualTo(attrIntString, 6), + sources.IsNotNull(attrIntString))))) - // (cint > 1 AND cint < 10) AND (ABS(cint) = 6 AND cint IS NOT NULL) + // (attrInt > 1 AND attrInt < 10) AND (ABS(attrInt) = 6 AND attrInt IS NOT NULL) testTranslateFilter(And( And( GreaterThan(attrInt, 1), @@ -190,7 +267,7 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { EqualTo(Abs(attrInt), 6), IsNotNull(attrInt))), None) - // (cint > 1 OR cint < 10) AND (cint = 6 OR cint IS NOT NULL) + // (attrInt > 1 OR attrInt < 10) AND (attrInt = 6 OR attrInt IS NOT NULL) testTranslateFilter(And( Or( GreaterThan(attrInt, 1), @@ -199,15 +276,16 @@ class DataSourceStrategySuite extends PlanTest with SharedSparkSession { Or( EqualTo(attrInt, 6), IsNotNull(attrInt))), - Some(sources.And( + result(sources.And( sources.Or( - sources.GreaterThan("cint", 1), - sources.LessThan("cint", 10)), + sources.GreaterThan(attrIntString, 1), + sources.LessThan(attrIntString, 10)), sources.Or( - sources.EqualTo("cint", 6), - sources.IsNotNull("cint"))))) + sources.EqualTo(attrIntString, 6), + sources.IsNotNull(attrIntString))))) - // (cint > 1 OR cint < 10) AND (cint = 6 OR cint IS NOT NULL) + // (attrInt > 1 OR attrInt < 10) AND + // (attrInt = 6 OR attrInt IS NOT NULL) testTranslateFilter(And( Or( GreaterThan(attrInt, 1), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala index d0032df488f47..d5ec565b33c55 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1FilterSuite.scala @@ -40,7 +40,10 @@ class OrcV1FilterSuite extends OrcFilterSuite { checker: (SearchArgument) => Unit): Unit = { val output = predicate.collect { case a: Attribute => a }.distinct val query = df - .select(output.map(e => Column(e)): _*) + // SPARK-25557 + // The following select will flatten the nested data structure, + // so comment it out for now until we find a better approach. + // .select(output.map(e => Column(e)): _*) .where(Column(predicate)) var maybeRelation: Option[HadoopFsRelation] = None @@ -84,7 +87,10 @@ class OrcV1FilterSuite extends OrcFilterSuite { (implicit df: DataFrame): Unit = { val output = predicate.collect { case a: Attribute => a }.distinct val query = df - .select(output.map(e => Column(e)): _*) + // SPARK-25557 + // The following select will flatten the nested data structure, + // so comment it out for now until we find a better approach. 
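The commented-out `.select(...)` calls in these suites refer to the fact that projecting a nested field produces a flattened top-level column, which changes the shape the filter tests expect. A quick illustration of that behavior with the public DataFrame API (column names and the local session are made up for the example):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.struct

val spark = SparkSession.builder().master("local[1]").appName("nested-select-demo").getOrCreate()
import spark.implicits._

val df = Seq((1, "x"), (2, "y")).toDF("cint", "cstr")
  .select(struct($"cint", $"cstr").as("a")) // schema: a struct<cint:int, cstr:string>

// Projecting the nested field yields a *top-level* column named "cint",
// so the projected plan no longer exposes the nested path "a.cint".
df.select($"a.cint").printSchema()

// Filtering on the nested path directly keeps the nested shape intact.
df.filter($"a.cint" > 1).show()
```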
+ // .select(output.map(e => Column(e)): _*) .where(Column(predicate)) var maybeRelation: Option[HadoopFsRelation] = None diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 286bb1e920266..5e53dea149a03 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -27,7 +27,7 @@ import org.apache.parquet.filter2.predicate.Operators.{Column => _, _} import org.apache.parquet.schema.MessageType import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.sql._ +import org.apache.spark.sql.{Column, _} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.InferFiltersFromConstraints @@ -186,14 +186,43 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared } } - test("filter pushdown - boolean") { - withParquetDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) - checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false))) + case class N1[T](a: Option[T]) + case class N2[T](b: Option[T]) - checkFilterPredicate('_1 === true, classOf[Eq[_]], true) - checkFilterPredicate('_1 <=> true, classOf[Eq[_]], true) - checkFilterPredicate('_1 =!= true, classOf[NotEq[_]], false) + test("filter pushdown - boolean") { + val data0 = (true :: false :: Nil).map(b => Tuple1.apply(Option(b))) + val data1 = data0.map(x => N1(Some(x))) + val data2 = data1.map(x => N2(Some(x))) + + // zero nesting + withParquetDataFrame(data0) { implicit df => + val col = Symbol("_1") + checkFilterPredicate(col.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(col.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false))) + checkFilterPredicate(col === true, classOf[Eq[_]], true) + checkFilterPredicate(col <=> true, classOf[Eq[_]], true) + checkFilterPredicate(col =!= true, classOf[NotEq[_]], false) + } + + // one level nesting + withParquetDataFrame(data1) { implicit df => + val col = Symbol("a._1") + checkFilterPredicate(col.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(col.isNotNull, classOf[NotEq[_]], Seq(Row(Row(true)), Row(Row(false)))) + checkFilterPredicate(col === true, classOf[Eq[_]], Seq(Row(Row(true)))) + checkFilterPredicate(col <=> true, classOf[Eq[_]], Seq(Row(Row(true)))) + checkFilterPredicate(col =!= true, classOf[NotEq[_]], Seq(Row(Row(false)))) + } + + // two level nesting + withParquetDataFrame(data2) { implicit df => + val col = Symbol("b.a._1") + checkFilterPredicate(col.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate(col.isNotNull, classOf[NotEq[_]], + Seq(Row(Row(Row(true))), Row(Row(Row(false))))) + checkFilterPredicate(col === true, classOf[Eq[_]], Seq(Row(Row(Row(true))))) + checkFilterPredicate(col <=> true, classOf[Eq[_]], Seq(Row(Row(Row(true))))) + checkFilterPredicate(col =!= true, classOf[NotEq[_]], Seq(Row(Row(Row(false))))) } } @@ -1418,7 +1447,10 @@ class ParquetV1FilterSuite extends ParquetFilterSuite { SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> InferFiltersFromConstraints.ruleName, SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { val query = df - 
.select(output.map(e => Column(e)): _*) + // SPARK-17636 + // The following select will flatten the nested data structure, + // so comment it out for now until we find a better approach. + // .select(output.map(e => Column(e)): _*) .where(Column(predicate)) var maybeRelation: Option[HadoopFsRelation] = None @@ -1478,7 +1510,10 @@ class ParquetV2FilterSuite extends ParquetFilterSuite { SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> InferFiltersFromConstraints.ruleName, SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { val query = df - .select(output.map(e => Column(e)): _*) + // SPARK-17636 + // The following select will flatten the nested data structure, + // so comment it out for now until we find a better approach. + // .select(output.map(e => Column(e)): _*) .where(Column(predicate)) query.queryExecution.optimizedPlan.collectFirst { diff --git a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 995c5ed317de1..741648083a534 100644 --- a/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/v1.2/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -64,20 +64,18 @@ private[sql] object OrcFilters extends OrcFiltersBase { * Create ORC filter as a SearchArgument instance. */ def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = { - val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap // Combines all convertible filters using `And` to produce a single conjunction - val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, filters)) + val conjunctionOptional = buildTree(convertibleFilters(schema, filters)) conjunctionOptional.map { conjunction => // Then tries to build a single ORC `SearchArgument` for the conjunction predicate. // The input predicate is fully convertible. There should not be any empty result in the // following recursive method call `buildSearchArgument`. - buildSearchArgument(dataTypeMap, conjunction, newBuilder).build() + buildSearchArgument(schema, conjunction, newBuilder).build() } } def convertibleFilters( schema: StructType, - dataTypeMap: Map[String, DataType], filters: Seq[Filter]): Seq[Filter] = { import org.apache.spark.sql.sources._ @@ -125,7 +123,7 @@ private[sql] object OrcFilters extends OrcFiltersBase { val childResultOptional = convertibleFiltersHelper(pred, canPartialPushDown = false) childResultOptional.map(Not) case other => - for (_ <- buildLeafSearchArgument(dataTypeMap, other, newBuilder())) yield other + for (_ <- buildLeafSearchArgument(schema, other, newBuilder())) yield other } filters.flatMap { filter => convertibleFiltersHelper(filter, true) @@ -165,33 +163,33 @@ private[sql] object OrcFilters extends OrcFiltersBase { /** * Build a SearchArgument and return the builder so far. * - * @param dataTypeMap a map from the attribute name to its data type. + * @param schema the schema of the data * @param expression the input predicates, which should be fully convertible to SearchArgument. * @param builder the input SearchArgument.Builder. * @return the builder so far. 
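The partial-pushdown rules applied by `convertibleFilters` are unchanged by this patch (only the `dataTypeMap` parameter is replaced by the schema), but they are what decides which of the translated nested filters survive. A hedged sketch of those rules over the public `org.apache.spark.sql.sources` filter classes, with a made-up set of convertible attribute names standing in for the real leaf check:

```scala
import org.apache.spark.sql.sources._

// Which attribute names the (hypothetical) data source can evaluate.
val convertibleAttrs = Set("cint", "a.cint")

def isConvertibleLeaf(f: Filter): Boolean =
  f.references.forall(convertibleAttrs.contains)

// Mirrors the shape of convertibleFiltersHelper: under AND, each side may be pushed
// independently when partial pushdown is allowed; under OR, both sides must be fully
// convertible; under NOT, partial pushdown is not allowed.
def convertible(f: Filter, canPartialPushDown: Boolean): Option[Filter] = f match {
  case And(l, r) =>
    val left  = convertible(l, canPartialPushDown)
    val right = convertible(r, canPartialPushDown)
    if (canPartialPushDown) (left ++ right).reduceOption(And)
    else for { lf <- left; rf <- right } yield And(lf, rf)
  case Or(l, r) =>
    for {
      lf <- convertible(l, canPartialPushDown = false)
      rf <- convertible(r, canPartialPushDown = false)
    } yield Or(lf, rf)
  case Not(child) =>
    convertible(child, canPartialPushDown = false).map(Not)
  case leaf if isConvertibleLeaf(leaf) => Some(leaf)
  case _ => None
}

// ("a.cint" > 1) AND ("unknown" = 2)  ->  only the first conjunct survives.
println(convertible(And(GreaterThan("a.cint", 1), EqualTo("unknown", 2)), true))
```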
*/ private def buildSearchArgument( - dataTypeMap: Map[String, DataType], + schema: StructType, expression: Filter, builder: Builder): Builder = { import org.apache.spark.sql.sources._ expression match { case And(left, right) => - val lhs = buildSearchArgument(dataTypeMap, left, builder.startAnd()) - val rhs = buildSearchArgument(dataTypeMap, right, lhs) + val lhs = buildSearchArgument(schema, left, builder.startAnd()) + val rhs = buildSearchArgument(schema, right, lhs) rhs.end() case Or(left, right) => - val lhs = buildSearchArgument(dataTypeMap, left, builder.startOr()) - val rhs = buildSearchArgument(dataTypeMap, right, lhs) + val lhs = buildSearchArgument(schema, left, builder.startOr()) + val rhs = buildSearchArgument(schema, right, lhs) rhs.end() case Not(child) => - buildSearchArgument(dataTypeMap, child, builder.startNot()).end() + buildSearchArgument(schema, child, builder.startNot()).end() case other => - buildLeafSearchArgument(dataTypeMap, other, builder).getOrElse { + buildLeafSearchArgument(schema, other, builder).getOrElse { throw new SparkException( "The input filter of OrcFilters.buildSearchArgument should be fully convertible.") } @@ -201,17 +199,35 @@ private[sql] object OrcFilters extends OrcFiltersBase { /** * Build a SearchArgument for a leaf predicate and return the builder so far. * - * @param dataTypeMap a map from the attribute name to its data type. + * @param schema the schema of the data * @param expression the input filter predicates. * @param builder the input SearchArgument.Builder. * @return the builder so far. */ private def buildLeafSearchArgument( - dataTypeMap: Map[String, DataType], + schema: StructType, expression: Filter, builder: Builder): Option[Builder] = { + + def getDataType(schema: StructType, attribute: String): DataType = { + val typeMap = schema.fields.map(f => f.name -> f.dataType).toMap + typeMap.get(attribute) match { + case Some(t) => t + case _ => + val levels = attribute.split("\\.", 2) + if (levels.length <= 1) { + NullType + } else { + typeMap.get(levels.head) match { + case Some(s: StructType) => getDataType(s, levels.last) + case _ => NullType + } + } + } + } + def getType(attribute: String): PredicateLeaf.Type = - getPredicateLeafType(dataTypeMap(attribute)) + getPredicateLeafType(getDataType(schema, attribute)) import org.apache.spark.sql.sources._ @@ -219,47 +235,47 @@ private[sql] object OrcFilters extends OrcFiltersBase { // call is mandatory. ORC `SearchArgument` builder requires that all leaf predicates must be // wrapped by a "parent" predicate (`And`, `Or`, or `Not`). 
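A standalone illustration of the `getDataType` helper defined above: it first looks the attribute name up as-is, and only then splits on the first dot and recurses into the matching struct, returning `NullType` for paths that do not resolve. The schema below is made up for the example.

```scala
import org.apache.spark.sql.types._

def getDataType(schema: StructType, attribute: String): DataType = {
  val typeMap = schema.fields.map(f => f.name -> f.dataType).toMap
  typeMap.get(attribute) match {
    case Some(t) => t
    case _ =>
      val levels = attribute.split("\\.", 2)
      if (levels.length <= 1) {
        NullType
      } else {
        typeMap.get(levels.head) match {
          case Some(s: StructType) => getDataType(s, levels.last)
          case _ => NullType
        }
      }
  }
}

val schema = new StructType()
  .add("cint", IntegerType)
  .add("b", new StructType()
    .add("c", new StructType().add("cint", IntegerType)))

assert(getDataType(schema, "cint") == IntegerType)     // top-level column
assert(getDataType(schema, "b.c.cint") == IntegerType) // two-level nested path
assert(getDataType(schema, "b.missing") == NullType)   // unknown path
```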
expression match { - case EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + case EqualTo(attribute, value) if isSearchableType(getDataType(schema, attribute)) => val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + val castedValue = castLiteralValue(value, getDataType(schema, attribute)) Some(builder.startAnd().equals(quotedName, getType(attribute), castedValue).end()) - case EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + case EqualNullSafe(attribute, value) if isSearchableType(getDataType(schema, attribute)) => val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + val castedValue = castLiteralValue(value, getDataType(schema, attribute)) Some(builder.startAnd().nullSafeEquals(quotedName, getType(attribute), castedValue).end()) - case LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + case LessThan(attribute, value) if isSearchableType(getDataType(schema, attribute)) => val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + val castedValue = castLiteralValue(value, getDataType(schema, attribute)) Some(builder.startAnd().lessThan(quotedName, getType(attribute), castedValue).end()) - case LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + case LessThanOrEqual(attribute, value) if isSearchableType(getDataType(schema, attribute)) => val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + val castedValue = castLiteralValue(value, getDataType(schema, attribute)) Some(builder.startAnd().lessThanEquals(quotedName, getType(attribute), castedValue).end()) - case GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + case GreaterThan(attribute, value) if isSearchableType(getDataType(schema, attribute)) => val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + val castedValue = castLiteralValue(value, getDataType(schema, attribute)) Some(builder.startNot().lessThanEquals(quotedName, getType(attribute), castedValue).end()) - case GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + case GreaterThanOrEqual(attribute, value) if isSearchableType(getDataType(schema, attribute)) => val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + val castedValue = castLiteralValue(value, getDataType(schema, attribute)) Some(builder.startNot().lessThan(quotedName, getType(attribute), castedValue).end()) - case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) => + case IsNull(attribute) if isSearchableType(getDataType(schema, attribute)) => val quotedName = quoteAttributeNameIfNeeded(attribute) Some(builder.startAnd().isNull(quotedName, getType(attribute)).end()) - case IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) => + case IsNotNull(attribute) if isSearchableType(getDataType(schema, attribute)) => val quotedName = quoteAttributeNameIfNeeded(attribute) Some(builder.startNot().isNull(quotedName, getType(attribute)).end()) - case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) => + case In(attribute, values) if isSearchableType(getDataType(schema, attribute)) => val quotedName = 
quoteAttributeNameIfNeeded(attribute) - val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(attribute))) + val castedValues = values.map(v => castLiteralValue(v, getDataType(schema, attribute))) Some(builder.startAnd().in(quotedName, getType(attribute), castedValues.map(_.asInstanceOf[AnyRef]): _*).end()) diff --git a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index d09236a934337..514c808a96442 100644 --- a/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/v1.2/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -22,9 +22,7 @@ import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ - import org.apache.orc.storage.ql.io.sarg.{PredicateLeaf, SearchArgument} - import org.apache.spark.sql.{AnalysisException, Column, DataFrame} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ @@ -49,7 +47,10 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { checker: (SearchArgument) => Unit): Unit = { val output = predicate.collect { case a: Attribute => a }.distinct val query = df - .select(output.map(e => Column(e)): _*) + // SPARK-25557 + // The following select will flatten the nested data structure, + // so comment it out for now until we find a better approach. + // .select(output.map(e => Column(e)): _*) .where(Column(predicate)) query.queryExecution.optimizedPlan match { @@ -190,24 +191,65 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { } } + case class N1[T](a: Option[T]) + case class N2[T](b: Option[T]) test("filter pushdown - boolean") { - withOrcDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => - checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) - - checkFilterPredicate($"_1" === true, PredicateLeaf.Operator.EQUALS) - checkFilterPredicate($"_1" <=> true, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + val data0 = (true :: false :: Nil).map(b => Tuple1.apply(Option(b))) + val data1 = data0.map(x => N1(Some(x))) + val data2 = data1.map(x => N2(Some(x))) + + // zero nesting + withOrcDataFrame(data0) { implicit df => + val col = Symbol("_1") + checkFilterPredicate(col.isNull, PredicateLeaf.Operator.IS_NULL) + checkFilterPredicate(col === true, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(col <=> true, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(col < true, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(col > false, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(col <= false, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(col >= false, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(false) === col, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(false) <=> col, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(false) > col, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(true) < col, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(true) >= col, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(true) <= col, PredicateLeaf.Operator.LESS_THAN) + } - checkFilterPredicate($"_1" < true, PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate($"_1" > false, 
PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" <= false, PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" >= false, PredicateLeaf.Operator.LESS_THAN) + // one level nesting + withOrcDataFrame(data1) { implicit df => + val col = Symbol("a._1") + checkFilterPredicate(col.isNull, PredicateLeaf.Operator.IS_NULL) + checkFilterPredicate(col === true, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(col <=> true, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(col < true, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(col > false, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(col <= false, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(col >= false, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(false) === col, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(false) <=> col, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(false) > col, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(true) < col, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(true) >= col, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(true) <= col, PredicateLeaf.Operator.LESS_THAN) + } - checkFilterPredicate(Literal(false) === $"_1", PredicateLeaf.Operator.EQUALS) - checkFilterPredicate(Literal(false) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) - checkFilterPredicate(Literal(false) > $"_1", PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate(Literal(true) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(true) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(true) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + // two level nesting + withOrcDataFrame(data2) { implicit df => + val col = Symbol("b.a._1") + checkFilterPredicate(col.isNull, PredicateLeaf.Operator.IS_NULL) + checkFilterPredicate(col === true, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(col <=> true, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(col < true, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(col > false, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(col <= false, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(col >= false, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(false) === col, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(false) <=> col, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(false) > col, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(true) < col, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(true) >= col, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(true) <= col, PredicateLeaf.Operator.LESS_THAN) } } diff --git a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala index 948ab44a8c19c..94d6909dea3c6 100644 --- a/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala +++ b/sql/core/v2.3/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala @@ -64,20 +64,18 @@ private[sql] object OrcFilters extends OrcFiltersBase { * Create ORC filter as a SearchArgument instance. 
*/ def createFilter(schema: StructType, filters: Seq[Filter]): Option[SearchArgument] = { - val dataTypeMap = schema.map(f => f.name -> f.dataType).toMap // Combines all convertible filters using `And` to produce a single conjunction - val conjunctionOptional = buildTree(convertibleFilters(schema, dataTypeMap, filters)) + val conjunctionOptional = buildTree(convertibleFilters(schema, filters)) conjunctionOptional.map { conjunction => // Then tries to build a single ORC `SearchArgument` for the conjunction predicate. // The input predicate is fully convertible. There should not be any empty result in the // following recursive method call `buildSearchArgument`. - buildSearchArgument(dataTypeMap, conjunction, newBuilder).build() + buildSearchArgument(schema, conjunction, newBuilder).build() } } def convertibleFilters( schema: StructType, - dataTypeMap: Map[String, DataType], filters: Seq[Filter]): Seq[Filter] = { import org.apache.spark.sql.sources._ @@ -125,7 +123,7 @@ private[sql] object OrcFilters extends OrcFiltersBase { val childResultOptional = convertibleFiltersHelper(pred, canPartialPushDown = false) childResultOptional.map(Not) case other => - for (_ <- buildLeafSearchArgument(dataTypeMap, other, newBuilder())) yield other + for (_ <- buildLeafSearchArgument(schema, other, newBuilder())) yield other } filters.flatMap { filter => convertibleFiltersHelper(filter, true) @@ -165,33 +163,33 @@ private[sql] object OrcFilters extends OrcFiltersBase { /** * Build a SearchArgument and return the builder so far. * - * @param dataTypeMap a map from the attribute name to its data type. + * @param schema the schema of the data * @param expression the input predicates, which should be fully convertible to SearchArgument. * @param builder the input SearchArgument.Builder. * @return the builder so far. */ private def buildSearchArgument( - dataTypeMap: Map[String, DataType], + schema: StructType, expression: Filter, builder: Builder): Builder = { import org.apache.spark.sql.sources._ expression match { case And(left, right) => - val lhs = buildSearchArgument(dataTypeMap, left, builder.startAnd()) - val rhs = buildSearchArgument(dataTypeMap, right, lhs) + val lhs = buildSearchArgument(schema, left, builder.startAnd()) + val rhs = buildSearchArgument(schema, right, lhs) rhs.end() case Or(left, right) => - val lhs = buildSearchArgument(dataTypeMap, left, builder.startOr()) - val rhs = buildSearchArgument(dataTypeMap, right, lhs) + val lhs = buildSearchArgument(schema, left, builder.startOr()) + val rhs = buildSearchArgument(schema, right, lhs) rhs.end() case Not(child) => - buildSearchArgument(dataTypeMap, child, builder.startNot()).end() + buildSearchArgument(schema, child, builder.startNot()).end() case other => - buildLeafSearchArgument(dataTypeMap, other, builder).getOrElse { + buildLeafSearchArgument(schema, other, builder).getOrElse { throw new SparkException( "The input filter of OrcFilters.buildSearchArgument should be fully convertible.") } @@ -201,17 +199,35 @@ private[sql] object OrcFilters extends OrcFiltersBase { /** * Build a SearchArgument for a leaf predicate and return the builder so far. * - * @param dataTypeMap a map from the attribute name to its data type. + * @param schema the schema of the data * @param expression the input filter predicates. * @param builder the input SearchArgument.Builder. * @return the builder so far. 
*/ private def buildLeafSearchArgument( - dataTypeMap: Map[String, DataType], + schema: StructType, expression: Filter, builder: Builder): Option[Builder] = { + + def getDataType(schema: StructType, attribute: String): DataType = { + val typeMap = schema.fields.map(f => f.name -> f.dataType).toMap + typeMap.get(attribute) match { + case Some(t) => t + case _ => + val levels = attribute.split("\\.", 2) + if (levels.length <= 1) { + NullType + } else { + typeMap.get(levels.head) match { + case Some(s: StructType) => getDataType(s, levels.last) + case _ => NullType + } + } + } + } + def getType(attribute: String): PredicateLeaf.Type = - getPredicateLeafType(dataTypeMap(attribute)) + getPredicateLeafType(getDataType(schema, attribute)) import org.apache.spark.sql.sources._ @@ -219,47 +235,56 @@ private[sql] object OrcFilters extends OrcFiltersBase { // call is mandatory. ORC `SearchArgument` builder requires that all leaf predicates must be // wrapped by a "parent" predicate (`And`, `Or`, or `Not`). expression match { - case EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + case EqualTo(attribute, value) + if isSearchableType(getDataType(schema, attribute)) => val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + val castedValue = castLiteralValue(value, getDataType(schema, attribute)) Some(builder.startAnd().equals(quotedName, getType(attribute), castedValue).end()) - case EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + case EqualNullSafe(attribute, value) + if isSearchableType(getDataType(schema, attribute)) => val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + val castedValue = castLiteralValue(value, getDataType(schema, attribute)) Some(builder.startAnd().nullSafeEquals(quotedName, getType(attribute), castedValue).end()) - case LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + case LessThan(attribute, value) + if isSearchableType(getDataType(schema, attribute)) => val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + val castedValue = castLiteralValue(value, getDataType(schema, attribute)) Some(builder.startAnd().lessThan(quotedName, getType(attribute), castedValue).end()) - case LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + case LessThanOrEqual(attribute, value) + if isSearchableType(getDataType(schema, attribute)) => val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + val castedValue = castLiteralValue(value, getDataType(schema, attribute)) Some(builder.startAnd().lessThanEquals(quotedName, getType(attribute), castedValue).end()) - case GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + case GreaterThan(attribute, value) + if isSearchableType(getDataType(schema, attribute)) => val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + val castedValue = castLiteralValue(value, getDataType(schema, attribute)) Some(builder.startNot().lessThanEquals(quotedName, getType(attribute), castedValue).end()) - case GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) => + case GreaterThanOrEqual(attribute, value) + if isSearchableType(getDataType(schema, attribute)) => val 
quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValue = castLiteralValue(value, dataTypeMap(attribute)) + val castedValue = castLiteralValue(value, getDataType(schema, attribute)) Some(builder.startNot().lessThan(quotedName, getType(attribute), castedValue).end()) - case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) => + case IsNull(attribute) + if isSearchableType(getDataType(schema, attribute)) => val quotedName = quoteAttributeNameIfNeeded(attribute) Some(builder.startAnd().isNull(quotedName, getType(attribute)).end()) - case IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) => + case IsNotNull(attribute) + if isSearchableType(getDataType(schema, attribute)) => val quotedName = quoteAttributeNameIfNeeded(attribute) Some(builder.startNot().isNull(quotedName, getType(attribute)).end()) - case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) => + case In(attribute, values) + if isSearchableType(getDataType(schema, attribute)) => val quotedName = quoteAttributeNameIfNeeded(attribute) - val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(attribute))) + val castedValues = values.map(v => castLiteralValue(v, getDataType(schema, attribute))) Some(builder.startAnd().in(quotedName, getType(attribute), castedValues.map(_.asInstanceOf[AnyRef]): _*).end()) diff --git a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index b95a32ef85ddf..38ea4750c79ad 100644 --- a/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/v2.3/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -50,7 +50,10 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { checker: (SearchArgument) => Unit): Unit = { val output = predicate.collect { case a: Attribute => a }.distinct val query = df - .select(output.map(e => Column(e)): _*) + // SPARK-25557 + // The following select will flatten the nested data structure, + // so comment it out for now until we find a better approach. 
+ // .select(output.map(e => Column(e)): _*) .where(Column(predicate)) query.queryExecution.optimizedPlan match { @@ -191,24 +194,66 @@ class OrcFilterSuite extends OrcTest with SharedSparkSession { } } - test("filter pushdown - boolean") { - withOrcDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => - checkFilterPredicate($"_1".isNull, PredicateLeaf.Operator.IS_NULL) + case class N1[T](a: Option[T]) + case class N2[T](b: Option[T]) - checkFilterPredicate($"_1" === true, PredicateLeaf.Operator.EQUALS) - checkFilterPredicate($"_1" <=> true, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + test("filter pushdown - boolean") { + val data0 = (true :: false :: Nil).map(b => Tuple1.apply(Option(b))) + val data1 = data0.map(x => N1(Some(x))) + val data2 = data1.map(x => N2(Some(x))) + + // zero nesting + withOrcDataFrame(data0) { implicit df => + val col = Symbol("_1") + checkFilterPredicate(col.isNull, PredicateLeaf.Operator.IS_NULL) + checkFilterPredicate(col === true, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(col <=> true, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(col < true, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(col > false, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(col <= false, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(col >= false, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(false) === col, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(false) <=> col, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(false) > col, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(true) < col, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(true) >= col, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(true) <= col, PredicateLeaf.Operator.LESS_THAN) + } - checkFilterPredicate($"_1" < true, PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate($"_1" > false, PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" <= false, PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate($"_1" >= false, PredicateLeaf.Operator.LESS_THAN) + // one level nesting + withOrcDataFrame(data1) { implicit df => + val col = Symbol("a._1") + checkFilterPredicate(col.isNull, PredicateLeaf.Operator.IS_NULL) + checkFilterPredicate(col === true, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(col <=> true, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(col < true, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(col > false, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(col <= false, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(col >= false, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(false) === col, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(false) <=> col, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(false) > col, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(true) < col, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(true) >= col, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(true) <= col, PredicateLeaf.Operator.LESS_THAN) + } - checkFilterPredicate(Literal(false) === $"_1", PredicateLeaf.Operator.EQUALS) - checkFilterPredicate(Literal(false) <=> $"_1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) - checkFilterPredicate(Literal(false) > $"_1", 
PredicateLeaf.Operator.LESS_THAN) - checkFilterPredicate(Literal(true) < $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(true) >= $"_1", PredicateLeaf.Operator.LESS_THAN_EQUALS) - checkFilterPredicate(Literal(true) <= $"_1", PredicateLeaf.Operator.LESS_THAN) + // two level nesting + withOrcDataFrame(data2) { implicit df => + val col = Symbol("b.a._1") + checkFilterPredicate(col.isNull, PredicateLeaf.Operator.IS_NULL) + checkFilterPredicate(col === true, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(col <=> true, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(col < true, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(col > false, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(col <= false, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(col >= false, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(false) === col, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(false) <=> col, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(false) > col, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(true) < col, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(true) >= col, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(true) <= col, PredicateLeaf.Operator.LESS_THAN) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcFilterSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcFilterSuite.scala index 5fc41067f661d..ac0b3b5c92095 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcFilterSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/HiveOrcFilterSuite.scala @@ -47,7 +47,10 @@ class HiveOrcFilterSuite extends OrcTest with TestHiveSingleton { checker: (SearchArgument) => Unit): Unit = { val output = predicate.collect { case a: Attribute => a }.distinct val query = df - .select(output.map(e => Column(e)): _*) + // SPARK-25557 + // The following select will flatten the nested data structure, + // so comment it out for now until we find a better approach. + // .select(output.map(e => Column(e)): _*) .where(Column(predicate)) var maybeRelation: Option[HadoopFsRelation] = None
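Taken together, these changes let a filter on a nested column reach the file format. A rough end-to-end sketch of the expected effect, using only public APIs (the path, column names, and local session are invented; the exact `PushedFilters` rendering in the physical plan may differ by version):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.struct

val spark = SparkSession.builder().master("local[1]").appName("nested-pushdown-demo").getOrCreate()
import spark.implicits._

val path = "/tmp/nested_pushdown_demo" // made-up scratch location

// Write a small Parquet file with one level of nesting: a struct<cint:int, cstr:string>.
Seq((1, "x"), (2, "y")).toDF("cint", "cstr")
  .select(struct($"cint", $"cstr").as("a"))
  .write.mode("overwrite").parquet(path)

// Filter on the nested column. With this change the scan node's PushedFilters should
// include entries such as GreaterThan(a.cint,1) instead of being empty for the nested path.
val df = spark.read.parquet(path).filter($"a.cint" > 1)
df.explain()
df.show()
```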