diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index 918db8e7d083..d69e5bcffb60 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -62,7 +62,20 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
     // Scanning partitioned HadoopFsRelation
     case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation, _))
         if t.partitionSpec.partitionColumns.nonEmpty =>
-      val selectedPartitions = prunePartitions(filters, t.partitionSpec).toArray
+      // We divide the filter expressions into 3 parts
+      val partitionColumns = AttributeSet(
+        t.partitionColumns.map(c => l.output.find(_.name == c.name).get))
+
+      // Only pruning the partition keys
+      val partitionFilters = filters.filter(_.references.subsetOf(partitionColumns))
+
+      // Only pushes down predicates that do not reference partition keys.
+      val pushedFilters = filters.filter(_.references.intersect(partitionColumns).isEmpty)
+
+      // Predicates with both partition keys and attributes
+      val combineFilters = filters.toSet -- partitionFilters.toSet -- pushedFilters.toSet
+
+      val selectedPartitions = prunePartitions(partitionFilters, t.partitionSpec).toArray
 
       logInfo {
         val total = t.partitionSpec.partitions.length
@@ -71,21 +84,16 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
         s"Selected $selected partitions out of $total, pruned $percentPruned% partitions."
       }
 
-      // Only pushes down predicates that do not reference partition columns.
-      val pushedFilters = {
-        val partitionColumnNames = t.partitionSpec.partitionColumns.map(_.name).toSet
-        filters.filter { f =>
-          val referencedColumnNames = f.references.map(_.name).toSet
-          referencedColumnNames.intersect(partitionColumnNames).isEmpty
-        }
-      }
-
-      buildPartitionedTableScan(
+      val scan = buildPartitionedTableScan(
         l,
         projects,
         pushedFilters,
         t.partitionSpec.partitionColumns,
-        selectedPartitions) :: Nil
+        selectedPartitions)
+
+      combineFilters
+        .reduceLeftOption(expressions.And)
+        .map(execution.Filter(_, scan)).getOrElse(scan) :: Nil
 
     // Scanning non-partitioned HadoopFsRelation
     case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation, _)) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 36063d8fa4a6..20787217e892 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -937,4 +937,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       expected(except)
     )
   }
+
+  test("SPARK-11301: fix case sensitivity for filter on partitioned columns") {
+    withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
+      withTempPath { path =>
+        Seq(2012 -> "a").toDF("year", "val").write.partitionBy("year").parquet(path.getAbsolutePath)
+        val df = sqlContext.read.parquet(path.getAbsolutePath)
+        checkAnswer(df.filter($"yEAr" > 2000).select($"val"), Row("a"))
+      }
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 7f4d36768e59..13fdd555a4c7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -299,4 +299,21 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
       }
     }
   }
+
+  test("SPARK-10829: Filter combine partition key and attribute doesn't work in DataSource scan") {
+    import testImplicits._
+
+    withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") {
+      withTempPath { dir =>
+        val path = s"${dir.getCanonicalPath}/part=1"
+        (1 to 3).map(i => (i, i.toString)).toDF("a", "b").write.parquet(path)
+
+        // If the "part = 1" filter gets pushed down, this query will throw an exception since
+        // "part" is not a valid column in the actual Parquet file
+        checkAnswer(
+          sqlContext.read.parquet(path).filter("a > 0 and (part = 0 or a > 1)"),
+          (2 to 3).map(i => Row(i, i.toString, 1)))
+      }
+    }
+  }
 }
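Note (illustrative sketch, not part of the patch): the DataSourceStrategy change above splits the incoming filters into three buckets -- predicates that reference only partition keys (used for partition pruning), predicates that reference no partition keys (pushed down to the data source), and mixed predicates (kept as an execution.Filter on top of the scan). The standalone Scala below mirrors that split using a hypothetical Pred model in place of Catalyst expressions, so the bucketing logic can be seen in isolation.

// Minimal, self-contained sketch of the three-way filter split.
// Pred is a hypothetical stand-in for a Catalyst predicate: just the SQL text
// plus the set of column names it references.
object FilterSplitSketch {
  case class Pred(sql: String, references: Set[String])

  def main(args: Array[String]): Unit = {
    val partitionColumns = Set("part")
    val filters = Seq(
      Pred("part = 0", Set("part")),                 // partition keys only
      Pred("a > 0", Set("a")),                       // data columns only
      Pred("part = 0 OR a > 1", Set("part", "a"))    // references both
    )

    // Used only for partition pruning: every referenced column is a partition key.
    val partitionFilters = filters.filter(_.references.subsetOf(partitionColumns))
    // Safe to push down to the data source: no partition keys referenced.
    val pushedFilters = filters.filter(_.references.intersect(partitionColumns).isEmpty)
    // Mixed predicates: must be evaluated in a Filter above the scan.
    val combineFilters = filters.toSet -- partitionFilters.toSet -- pushedFilters.toSet

    println(s"partitionFilters = ${partitionFilters.map(_.sql)}")  // List(part = 0)
    println(s"pushedFilters    = ${pushedFilters.map(_.sql)}")     // List(a > 0)
    println(s"combineFilters   = ${combineFilters.map(_.sql)}")    // Set(part = 0 OR a > 1)
  }
}

The mixed predicate here is the same shape as the one exercised by the SPARK-10829 test above ("a > 0 and (part = 0 or a > 1)"): it can neither prune partitions on its own nor be pushed into the Parquet reader, so it must survive as a post-scan Filter.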