diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
index dbafc468c6c4..5b93a60a80ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilters.scala
@@ -67,6 +67,16 @@ private[sql] object OrcFilters {
     }
   }
 
+  // Since ORC 1.5.0 (ORC-323), we need to quote column names that contain `.` characters
+  // in order to distinguish them from predicate pushdown for nested columns.
+  private def quoteAttributeNameIfNeeded(name: String): String = {
+    if (!name.contains("`") && name.contains(".")) {
+      s"`$name`"
+    } else {
+      name
+    }
+  }
+
   /**
    * Create ORC filter as a SearchArgument instance.
    */
@@ -178,38 +188,47 @@ private[sql] object OrcFilters {
       // wrapped by a "parent" predicate (`And`, `Or`, or `Not`).
 
       case EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
         val castedValue = castLiteralValue(value, dataTypeMap(attribute))
-        Some(builder.startAnd().equals(attribute, getType(attribute), castedValue).end())
+        Some(builder.startAnd().equals(quotedName, getType(attribute), castedValue).end())
 
       case EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
         val castedValue = castLiteralValue(value, dataTypeMap(attribute))
-        Some(builder.startAnd().nullSafeEquals(attribute, getType(attribute), castedValue).end())
+        Some(builder.startAnd().nullSafeEquals(quotedName, getType(attribute), castedValue).end())
 
       case LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
         val castedValue = castLiteralValue(value, dataTypeMap(attribute))
-        Some(builder.startAnd().lessThan(attribute, getType(attribute), castedValue).end())
+        Some(builder.startAnd().lessThan(quotedName, getType(attribute), castedValue).end())
 
       case LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
         val castedValue = castLiteralValue(value, dataTypeMap(attribute))
-        Some(builder.startAnd().lessThanEquals(attribute, getType(attribute), castedValue).end())
+        Some(builder.startAnd().lessThanEquals(quotedName, getType(attribute), castedValue).end())
 
       case GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
         val castedValue = castLiteralValue(value, dataTypeMap(attribute))
-        Some(builder.startNot().lessThanEquals(attribute, getType(attribute), castedValue).end())
+        Some(builder.startNot().lessThanEquals(quotedName, getType(attribute), castedValue).end())
 
       case GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
         val castedValue = castLiteralValue(value, dataTypeMap(attribute))
-        Some(builder.startNot().lessThan(attribute, getType(attribute), castedValue).end())
+        Some(builder.startNot().lessThan(quotedName, getType(attribute), castedValue).end())
 
       case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) =>
-        Some(builder.startAnd().isNull(attribute, getType(attribute)).end())
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
+        Some(builder.startAnd().isNull(quotedName, getType(attribute)).end())
 
       case IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) =>
-        Some(builder.startNot().isNull(attribute, getType(attribute)).end())
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
+        Some(builder.startNot().isNull(quotedName, getType(attribute)).end())
 
       case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) =>
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
         val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(attribute)))
-        Some(builder.startAnd().in(attribute, getType(attribute),
+        Some(builder.startAnd().in(quotedName, getType(attribute),
           castedValues.map(_.asInstanceOf[AnyRef]): _*).end())
 
       case _ => None
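The helper above is the heart of the change: per the patch comment, since ORC 1.5.0 (ORC-323) an unquoted dot in a SearchArgument leaf is read as a struct-field separator, so a top-level column literally named `col.dot` must be backtick-quoted before it reaches the builder. A minimal standalone sketch of the rule; the assertions are illustrative and not part of the patch:

// Quote only when the name contains a dot and carries no backtick already,
// so ORC reads `a.b` as one top-level column, not field `b` of struct `a`.
def quoteAttributeNameIfNeeded(name: String): String =
  if (!name.contains("`") && name.contains(".")) s"`$name`" else name

assert(quoteAttributeNameIfNeeded("col") == "col")             // plain name: unchanged
assert(quoteAttributeNameIfNeeded("col.dot") == "`col.dot`")   // dotted name: quoted
assert(quoteAttributeNameIfNeeded("`col.dot`") == "`col.dot`") // already quoted: unchanged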
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala
index e9dccbf2e261..998b7b31dcd6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala
@@ -445,16 +445,7 @@ abstract class OrcQueryTest extends OrcTest {
   test("Support for pushing down filters for decimal types") {
     withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") {
       val data = (0 until 10).map(i => Tuple1(BigDecimal.valueOf(i)))
-      withTempPath { file =>
-        // It needs to repartition data so that we can have several ORC files
-        // in order to skip stripes in ORC.
-        spark.createDataFrame(data).toDF("a").repartition(10)
-          .write.orc(file.getCanonicalPath)
-        val df = spark.read.orc(file.getCanonicalPath).where("a == 2")
-        val actual = stripSparkFilter(df).count()
-
-        assert(actual < 10)
-      }
+      checkPredicatePushDown(spark.createDataFrame(data).toDF("a"), 10, "a == 2")
     }
   }
 
@@ -465,16 +456,7 @@ abstract class OrcQueryTest extends OrcTest {
       val milliseconds = Timestamp.valueOf(timeString).getTime + i * 3600
       Tuple1(new Timestamp(milliseconds))
     }
-    withTempPath { file =>
-      // It needs to repartition data so that we can have several ORC files
-      // in order to skip stripes in ORC.
-      spark.createDataFrame(data).toDF("a").repartition(10)
-        .write.orc(file.getCanonicalPath)
-      val df = spark.read.orc(file.getCanonicalPath).where(s"a == '$timeString'")
-      val actual = stripSparkFilter(df).count()
-
-      assert(actual < 10)
-    }
+    checkPredicatePushDown(spark.createDataFrame(data).toDF("a"), 10, s"a == '$timeString'")
   }
 }
 
@@ -674,6 +656,12 @@ class OrcQuerySuite extends OrcQueryTest with SharedSQLContext {
     }
   }
 
+  test("SPARK-25579 ORC PPD should support column names with dot") {
+    withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") {
+      checkPredicatePushDown(spark.range(10).toDF("col.dot"), 10, "`col.dot` == 2")
+    }
+  }
+
   test("SPARK-20728 Make ORCFileFormat configurable between sql/hive and sql/core") {
     withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> "hive") {
       val e = intercept[AnalysisException] {
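For context, the new SPARK-25579 test above boils down to the following end-to-end scenario, sketched here as a hypothetical spark-shell session (the output path is illustrative only):

// With ORC filter pushdown enabled, filter on a top-level column whose name
// contains a dot; backticks mark `col.dot` as one column, not a nested field.
spark.conf.set("spark.sql.orc.filterPushdown", "true")
spark.range(10).toDF("col.dot").write.orc("/tmp/orc_dots")
spark.read.orc("/tmp/orc_dots").where("`col.dot` == 2").show()

Before this patch, the pushed-down predicate carried the unquoted name, which ORC 1.5.0+ interpreted as a reference to a nested field, breaking pushdown for such columns.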
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala
index 38b34a03e3e4..a35c536038c7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala
@@ -106,4 +106,14 @@ abstract class OrcTest extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
       df: DataFrame, path: File): Unit = {
     df.write.mode(SaveMode.Overwrite).orc(path.getCanonicalPath)
   }
+
+  protected def checkPredicatePushDown(df: DataFrame, numRows: Int, predicate: String): Unit = {
+    withTempPath { file =>
+      // Repartition the data so that we get several ORC files, giving ORC
+      // stripes it can skip when the predicate is pushed down.
+      df.repartition(numRows).write.orc(file.getCanonicalPath)
+      val actual = stripSparkFilter(spark.read.orc(file.getCanonicalPath).where(predicate)).count()
+      assert(actual < numRows)
+    }
+  }
 }
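A note on why the shared helper's assertion demonstrates pushdown: `stripSparkFilter` (from SQLTestUtils) removes Spark's own Filter operator from the physical plan before counting, so any count below `numRows` can only come from the ORC reader skipping data via the pushed-down predicate. A hypothetical call from a suite mixing in OrcTest, assuming pushdown is enabled via SQLConf.ORC_FILTER_PUSHDOWN_ENABLED:

// Ten one-row ORC files are written; with pushdown working, ORC skips the
// files whose min/max statistics exclude 2, so the count stays below 10.
checkPredicatePushDown(spark.range(10).toDF("value"), 10, "value == 2")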