
Commit 22a7a00

address feedback
1 parent 385738d commit 22a7a00

8 files changed (+55, -123 lines changed)


sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 13 additions & 0 deletions
@@ -2049,6 +2049,17 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val NESTED_PREDICATE_PUSHDOWN_ENABLED =
+    buildConf("spark.sql.optimizer.nestedPredicatePushdown.enabled")
+      .internal()
+      .doc("When true, Spark tries to push down predicates for nested columns and/or names " +
+        "containing `dots` to data sources. Currently, Parquet implements both optimizations " +
+        "while ORC only supports predicates for names containing `dots`. The other data sources " +
+        "don't support this feature yet.")
+      .version("3.0.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED =
     buildConf("spark.sql.optimizer.serializer.nestedSchemaPruning.enabled")
       .internal()
@@ -3035,6 +3046,8 @@ class SQLConf extends Serializable with Logging {
 
   def nestedSchemaPruningEnabled: Boolean = getConf(NESTED_SCHEMA_PRUNING_ENABLED)
 
+  def nestedPredicatePushdownEnabled: Boolean = getConf(NESTED_PREDICATE_PUSHDOWN_ENABLED)
+
   def serializerNestedSchemaPruningEnabled: Boolean =
     getConf(SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED)
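The new flag is an internal SQL conf, so it can still be flipped at runtime like any other conf key. A minimal sketch, assuming a local SparkSession; only the config key comes from this diff, the session setup is illustrative:

import org.apache.spark.sql.SparkSession

// Illustrative only: the key below is taken from the diff above; nothing else here is part of the commit.
val spark = SparkSession.builder().master("local[1]").appName("nested-pushdown-toggle").getOrCreate()

// Disable nested predicate pushdown for subsequent queries in this session...
spark.conf.set("spark.sql.optimizer.nestedPredicatePushdown.enabled", "false")

// ...and restore the default (true) afterwards.
spark.conf.set("spark.sql.optimizer.nestedPredicatePushdown.enabled", "true")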

sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala

Lines changed: 0 additions & 60 deletions
@@ -76,11 +76,6 @@ sealed abstract class Filter {
 @Stable
 case class EqualTo(attribute: String, value: Any) extends Filter {
   override def references: Array[String] = Array(attribute) ++ findReferences(value)
-
-  /**
-   * A column name as an array of string multi-identifier
-   */
-  val fieldNames: Array[String] = parseColumnPath(attribute).toArray
 }
 
 /**
@@ -96,11 +91,6 @@ case class EqualTo(attribute: String, value: Any) extends Filter {
 @Stable
 case class EqualNullSafe(attribute: String, value: Any) extends Filter {
   override def references: Array[String] = Array(attribute) ++ findReferences(value)
-
-  /**
-   * A column name as an array of string multi-identifier
-   */
-  val fieldNames: Array[String] = parseColumnPath(attribute).toArray
 }
 
 /**
@@ -115,11 +105,6 @@ case class EqualNullSafe(attribute: String, value: Any) extends Filter {
 @Stable
 case class GreaterThan(attribute: String, value: Any) extends Filter {
   override def references: Array[String] = Array(attribute) ++ findReferences(value)
-
-  /**
-   * A column name as an array of string multi-identifier
-   */
-  val fieldNames: Array[String] = parseColumnPath(attribute).toArray
 }
 
 /**
@@ -134,11 +119,6 @@ case class GreaterThan(attribute: String, value: Any) extends Filter {
 @Stable
 case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter {
   override def references: Array[String] = Array(attribute) ++ findReferences(value)
-
-  /**
-   * A column name as an array of string multi-identifier
-   */
-  val fieldNames: Array[String] = parseColumnPath(attribute).toArray
 }
 
 /**
@@ -153,11 +133,6 @@ case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter {
 @Stable
 case class LessThan(attribute: String, value: Any) extends Filter {
   override def references: Array[String] = Array(attribute) ++ findReferences(value)
-
-  /**
-   * A column name as an array of string multi-identifier
-   */
-  val fieldNames: Array[String] = parseColumnPath(attribute).toArray
 }
 
 /**
@@ -172,11 +147,6 @@ case class LessThan(attribute: String, value: Any) extends Filter {
 @Stable
 case class LessThanOrEqual(attribute: String, value: Any) extends Filter {
   override def references: Array[String] = Array(attribute) ++ findReferences(value)
-
-  /**
-   * A column name as an array of string multi-identifier
-   */
-  val fieldNames: Array[String] = parseColumnPath(attribute).toArray
 }
 
 /**
@@ -207,11 +177,6 @@ case class In(attribute: String, values: Array[Any]) extends Filter {
   }
 
   override def references: Array[String] = Array(attribute) ++ values.flatMap(findReferences)
-
-  /**
-   * A column name as an array of string multi-identifier
-   */
-  val fieldNames: Array[String] = parseColumnPath(attribute).toArray
 }
 
 /**
@@ -225,11 +190,6 @@ case class In(attribute: String, values: Array[Any]) extends Filter {
 @Stable
 case class IsNull(attribute: String) extends Filter {
   override def references: Array[String] = Array(attribute)
-
-  /**
-   * A column name as an array of string multi-identifier
-   */
-  val fieldNames: Array[String] = parseColumnPath(attribute).toArray
 }
 
 /**
@@ -243,11 +203,6 @@ case class IsNull(attribute: String) extends Filter {
 @Stable
 case class IsNotNull(attribute: String) extends Filter {
   override def references: Array[String] = Array(attribute)
-
-  /**
-   * A column name as an array of string multi-identifier
-   */
-  val fieldNames: Array[String] = parseColumnPath(attribute).toArray
 }
 
 /**
@@ -292,11 +247,6 @@ case class Not(child: Filter) extends Filter {
 @Stable
 case class StringStartsWith(attribute: String, value: String) extends Filter {
   override def references: Array[String] = Array(attribute)
-
-  /**
-   * A column name as an array of string multi-identifier
-   */
-  val fieldNames: Array[String] = parseColumnPath(attribute).toArray
 }
 
 /**
@@ -311,11 +261,6 @@ case class StringStartsWith(attribute: String, value: String) extends Filter {
 @Stable
 case class StringEndsWith(attribute: String, value: String) extends Filter {
   override def references: Array[String] = Array(attribute)
-
-  /**
-   * A column name as an array of string multi-identifier
-   */
-  val fieldNames: Array[String] = parseColumnPath(attribute).toArray
 }
 
 /**
@@ -330,11 +275,6 @@ case class StringEndsWith(attribute: String, value: String) extends Filter {
 @Stable
 case class StringContains(attribute: String, value: String) extends Filter {
   override def references: Array[String] = Array(attribute)
-
-  /**
-   * A column name as an array of string multi-identifier
-   */
-  val fieldNames: Array[String] = parseColumnPath(attribute).toArray
 }
 
 /**
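With the `fieldNames` member gone from every filter, a data source only sees the raw attribute string via `references`; parsing it into a multi-part identifier now happens on the Spark side (see the `PushableColumn` change below). A small sketch of the surface that remains, assuming a Scala REPL with Spark on the classpath:

import org.apache.spark.sql.sources.EqualTo

// The attribute is still carried as a single (possibly dotted/backquoted) string.
val f = EqualTo("a.b", 1)

// references still reports the attribute name; there is no parsed fieldNames array anymore.
println(f.references.toSeq)   // Seq("a.b")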

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala

Lines changed: 9 additions & 2 deletions
@@ -652,10 +652,17 @@ object DataSourceStrategy {
   */
 object PushableColumn {
   def unapply(e: Expression): Option[String] = {
+    val nestedPredicatePushdownEnabled = SQLConf.get.nestedPredicatePushdownEnabled
     import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
     def helper(e: Expression): Option[Seq[String]] = e match {
-      case a: Attribute => Some(Seq(a.name))
-      case s: GetStructField => helper(s.child).map(_ :+ s.childSchema(s.ordinal).name)
+      case a: Attribute =>
+        if (nestedPredicatePushdownEnabled || !a.name.contains(".")) {
+          Some(Seq(a.name))
+        } else {
+          None
+        }
+      case s: GetStructField if nestedPredicatePushdownEnabled =>
+        helper(s.child).map(_ :+ s.childSchema(s.ordinal).name)
       case _ => None
     }
     helper(e).map(_.quoted)
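A rough end-to-end way to observe the new guard, assuming a local session and a throwaway Parquet path; none of this is part of the commit, and the PushedFilters output described in the comments is the expected behavior rather than a verified log:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").appName("pushable-column-demo").getOrCreate()
import spark.implicits._

// Write a tiny table with a struct column `a` (fields `_1`, `_2`), then read it back.
Seq((1, (2, "x")), (2, (3, "y"))).toDF("id", "a").write.mode("overwrite").parquet("/tmp/nested_pushdown_demo")
val df = spark.read.parquet("/tmp/nested_pushdown_demo")

// With the flag on (default), a predicate on a nested field should appear under PushedFilters
// in the Parquet scan node of the physical plan.
spark.conf.set("spark.sql.optimizer.nestedPredicatePushdown.enabled", "true")
df.filter($"a._1" === 2).explain()

// With the flag off, PushableColumn returns None for GetStructField, so the predicate is
// evaluated by Spark after the scan instead of being pushed down to Parquet.
spark.conf.set("spark.sql.optimizer.nestedPredicatePushdown.enabled", "false")
df.filter($"a._1" === 2).explain()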

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala

Lines changed: 3 additions & 2 deletions
@@ -49,8 +49,9 @@ class ParquetFilters(
     pushDownInFilterThreshold: Int,
     caseSensitive: Boolean) {
   // A map which contains parquet field name and data type, if predicate push down applies.
-  // The keys are the column names. For nested column, `dot` will be used as a separator.
-  // For column name that contains `dot`, backquote will be used.
+  //
+  // Each key in `nameToParquetField` represents a column; `dots` are used as separators for
+  // nested columns. If any part of the names contains `dots`, it is quoted to avoid confusion.
   // See `org.apache.spark.sql.connector.catalog.quote` for implementation details.
   private val nameToParquetField : Map[String, ParquetPrimitiveField] = {
     // Recursively traverse the parquet schema to get primitive fields that can be pushed-down.
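The quoting rule referenced in the comment is the same multi-part identifier helper used by `PushableColumn` above: parts are joined with dots, and any part that itself contains a dot is backquoted. A sketch of the resulting keys; note that `CatalogV2Implicits` is Spark-internal, so this snippet is illustrative and would only compile inside Spark's own packages:

import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper

// Nested column: struct `a` with field `b` -> key "a.b"
Seq("a", "b").quoted

// A single top-level column literally named "a.b" -> key "`a.b`", so it cannot be
// confused with the nested case above.
Seq("a.b").quoted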

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala

Lines changed: 16 additions & 16 deletions
@@ -147,7 +147,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
         spark.createDataFrame(data.map(x => ColA(Some(ColB(Some(ColC(Some(x)))))))),
         "a.b.c", // two level nesting
         (x: Any) => Row(Row(x)))
-    ).foreach { case (i, pushDownColName, resultFun) => withParquetDFfromDF(i) { implicit df =>
+    ).foreach { case (i, pushDownColName, resultFun) => withParquetDataFrame(i) { implicit df =>
       val tsAttr = df(pushDownColName).expr
       checkFilterPredicate(tsAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
       checkFilterPredicate(tsAttr.isNotNull, classOf[NotEq[_]],
@@ -218,7 +218,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
           data.map(x => ColA(Option(ColB(Option(ColC(Option(x)))))))),
         "a.b.c", // two level nesting
         (x: Any) => Row(Row(x)))
-    ).foreach { case (i, pushDownColName, resultFun) => withParquetDFfromDF(i) { implicit df =>
+    ).foreach { case (i, pushDownColName, resultFun) => withParquetDataFrame(i) { implicit df =>
       val booleanAttr = df(pushDownColName).expr
       checkFilterPredicate(booleanAttr.isNull, classOf[Eq[_]], Seq.empty[Row])
       checkFilterPredicate(booleanAttr.isNotNull, classOf[NotEq[_]],
@@ -231,7 +231,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
   }}
 
   test("filter pushdown - tinyint") {
-    withParquetDFfromObjs((1 to 4).map(i => Tuple1(Option(i.toByte)))) { implicit df =>
+    withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(Option(i.toByte))))) { implicit df =>
       assert(df.schema.head.dataType === ByteType)
       checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
       checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
@@ -259,7 +259,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
   }
 
   test("filter pushdown - smallint") {
-    withParquetDFfromObjs((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit df =>
+    withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(Option(i.toShort))))) { implicit df =>
       assert(df.schema.head.dataType === ShortType)
       checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
       checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
@@ -287,7 +287,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
   }
 
   test("filter pushdown - integer") {
-    withParquetDFfromObjs((1 to 4).map(i => Tuple1(Option(i)))) { implicit df =>
+    withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(Option(i))))) { implicit df =>
       checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
       checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
 
@@ -313,7 +313,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
   }
 
   test("filter pushdown - long") {
-    withParquetDFfromObjs((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit df =>
+    withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(Option(i.toLong))))) { implicit df =>
       checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
       checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
 
@@ -339,7 +339,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
   }
 
   test("filter pushdown - float") {
-    withParquetDFfromObjs((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit df =>
+    withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(Option(i.toFloat))))) { implicit df =>
       checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
       checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
 
@@ -365,7 +365,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
   }
 
   test("filter pushdown - double") {
-    withParquetDFfromObjs((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit df =>
+    withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(Option(i.toDouble))))) { implicit df =>
      checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
      checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
 
@@ -391,7 +391,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
   }
 
   test("filter pushdown - string") {
-    withParquetDFfromObjs((1 to 4).map(i => Tuple1(i.toString))) { implicit df =>
+    withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(i.toString)))) { implicit df =>
       checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
       checkFilterPredicate(
         '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString)))
@@ -423,7 +423,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
       def b: Array[Byte] = int.toString.getBytes(StandardCharsets.UTF_8)
     }
 
-    withParquetDFfromObjs((1 to 4).map(i => Tuple1(i.b))) { implicit df =>
+    withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(i.b)))) { implicit df =>
       checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq[_]], 1.b)
       checkBinaryFilterPredicate('_1 <=> 1.b, classOf[Eq[_]], 1.b)
 
@@ -459,7 +459,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
 
     val data = Seq("2018-03-18", "2018-03-19", "2018-03-20", "2018-03-21")
 
-    withParquetDFfromObjs(data.map(i => Tuple1(i.date))) { implicit df =>
+    withParquetDataFrame(toDF(data.map(i => Tuple1(i.date)))) { implicit df =>
       checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
       checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], data.map(i => Row.apply(i.date)))
 
@@ -518,7 +518,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
     // spark.sql.parquet.outputTimestampType = INT96 doesn't support pushdown
     withSQLConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key ->
       ParquetOutputTimestampType.INT96.toString) {
-      withParquetDFfromObjs(millisData.map(i => Tuple1(i))) { implicit df =>
+      withParquetDataFrame(toDF(millisData.map(i => Tuple1(i)))) { implicit df =>
         val schema = new SparkToParquetSchemaConverter(conf).convert(df.schema)
         assertResult(None) {
           createParquetFilters(schema).createFilter(sources.IsNull("_1"))
@@ -539,7 +539,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
     val rdd =
       spark.sparkContext.parallelize((1 to 4).map(i => Row(new java.math.BigDecimal(i))))
     val dataFrame = spark.createDataFrame(rdd, schema)
-    withParquetDFfromDF(dataFrame) { implicit df =>
+    withParquetDataFrame(dataFrame) { implicit df =>
       assert(df.schema === schema)
       checkFilterPredicate('a.isNull, classOf[Eq[_]], Seq.empty[Row])
       checkFilterPredicate('a.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
@@ -1075,7 +1075,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
   }
 
   test("SPARK-16371 Do not push down filters when inner name and outer name are the same") {
-    withParquetDFfromObjs((1 to 4).map(i => Tuple1(Tuple1(i)))) { implicit df =>
+    withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(Tuple1(i))))) { implicit df =>
       // Here the schema becomes as below:
       //
       // root
@@ -1217,7 +1217,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
   }
 
   test("filter pushdown - StringStartsWith") {
-    withParquetDFfromObjs((1 to 4).map(i => Tuple1(i + "str" + i))) { implicit df =>
+    withParquetDataFrame(toDF((1 to 4).map(i => Tuple1(i + "str" + i)))) { implicit df =>
       checkFilterPredicate(
         '_1.startsWith("").asInstanceOf[Predicate],
         classOf[UserDefinedByInstance[_, _]],
@@ -1263,7 +1263,7 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
     }
 
     // SPARK-28371: make sure filter is null-safe.
-    withParquetDFfromObjs(Seq(Tuple1[String](null))) { implicit df =>
+    withParquetDataFrame(toDF(Seq(Tuple1[String](null)))) { implicit df =>
       checkFilterPredicate(
         '_1.startsWith("blah").asInstanceOf[Predicate],
         classOf[UserDefinedByInstance[_, _]],
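The test changes all follow one pattern: object sequences are now converted to a DataFrame explicitly with `toDF(...)`, and a single `withParquetDataFrame(df)` helper round-trips it through Parquet before running the assertion block. A plausible shape for those helpers, assuming they live in the suite's ParquetTest infrastructure (the real signatures may differ):

import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.sql.DataFrame

// Assumption: converts a sequence of case classes / tuples into a DataFrame for the tests above.
private def toDF[T <: Product: TypeTag](data: Seq[T]): DataFrame =
  spark.createDataFrame(data)

// Assumption: writes the DataFrame to a temporary Parquet file and runs `f` on the read-back data.
private def withParquetDataFrame(df: DataFrame)(f: DataFrame => Unit): Unit =
  withTempPath { file =>
    df.write.parquet(file.getCanonicalPath)
    f(spark.read.parquet(file.getCanonicalPath))
  }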
