
Commit a0f22cf

JoshRosen authored and rxin committed
[SPARK-10195] [SQL] Data sources Filter should not expose internal types
Spark SQL's data sources API exposes Catalyst's internal types through its Filter interfaces. This is a problem because types like UTF8String are not stable developer APIs and should not be exposed to third parties. This issue caused incompatibilities when upgrading our `spark-redshift` library to work against Spark 1.5.0. To avoid these issues in the future, we should only expose public types through these Filter objects. This patch accomplishes this by using CatalystTypeConverters to add the appropriate conversions.

Author: Josh Rosen <[email protected]>

Closes #8403 from JoshRosen/datasources-internal-vs-external-types.

(cherry picked from commit 7bc9a8c)
Signed-off-by: Reynold Xin <[email protected]>
1 parent e5cea56 commit a0f22cf
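
To see what the conversion buys us, here is a minimal sketch (not part of the commit; it assumes Spark 1.5-era internals, where CatalystTypeConverters and UTF8String are internal, non-stable APIs):

import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

// Catalyst stores string literals internally as UTF8String. convertToScala
// maps such values back to public types (here java.lang.String) before
// they are handed to a data source Filter.
val internal: Any = UTF8String.fromString("aaaaaAAAAA")
val external: Any = convertToScala(internal, StringType)
assert(external.isInstanceOf[String])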

File tree

4 files changed (+54, -41 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala

Lines changed: 35 additions & 32 deletions
@@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.datasources
 import org.apache.spark.{Logging, TaskContext}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD}
-import org.apache.spark.sql.catalyst.{InternalRow, expressions}
+import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical
@@ -344,45 +345,47 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
    */
   protected[sql] def selectFilters(filters: Seq[Expression]) = {
     def translate(predicate: Expression): Option[Filter] = predicate match {
-      case expressions.EqualTo(a: Attribute, Literal(v, _)) =>
-        Some(sources.EqualTo(a.name, v))
-      case expressions.EqualTo(Literal(v, _), a: Attribute) =>
-        Some(sources.EqualTo(a.name, v))
-
-      case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) =>
-        Some(sources.EqualNullSafe(a.name, v))
-      case expressions.EqualNullSafe(Literal(v, _), a: Attribute) =>
-        Some(sources.EqualNullSafe(a.name, v))
-
-      case expressions.GreaterThan(a: Attribute, Literal(v, _)) =>
-        Some(sources.GreaterThan(a.name, v))
-      case expressions.GreaterThan(Literal(v, _), a: Attribute) =>
-        Some(sources.LessThan(a.name, v))
-
-      case expressions.LessThan(a: Attribute, Literal(v, _)) =>
-        Some(sources.LessThan(a.name, v))
-      case expressions.LessThan(Literal(v, _), a: Attribute) =>
-        Some(sources.GreaterThan(a.name, v))
-
-      case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) =>
-        Some(sources.GreaterThanOrEqual(a.name, v))
-      case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) =>
-        Some(sources.LessThanOrEqual(a.name, v))
-
-      case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) =>
-        Some(sources.LessThanOrEqual(a.name, v))
-      case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) =>
-        Some(sources.GreaterThanOrEqual(a.name, v))
+      case expressions.EqualTo(a: Attribute, Literal(v, t)) =>
+        Some(sources.EqualTo(a.name, convertToScala(v, t)))
+      case expressions.EqualTo(Literal(v, t), a: Attribute) =>
+        Some(sources.EqualTo(a.name, convertToScala(v, t)))
+
+      case expressions.EqualNullSafe(a: Attribute, Literal(v, t)) =>
+        Some(sources.EqualNullSafe(a.name, convertToScala(v, t)))
+      case expressions.EqualNullSafe(Literal(v, t), a: Attribute) =>
+        Some(sources.EqualNullSafe(a.name, convertToScala(v, t)))
+
+      case expressions.GreaterThan(a: Attribute, Literal(v, t)) =>
+        Some(sources.GreaterThan(a.name, convertToScala(v, t)))
+      case expressions.GreaterThan(Literal(v, t), a: Attribute) =>
+        Some(sources.LessThan(a.name, convertToScala(v, t)))
+
+      case expressions.LessThan(a: Attribute, Literal(v, t)) =>
+        Some(sources.LessThan(a.name, convertToScala(v, t)))
+      case expressions.LessThan(Literal(v, t), a: Attribute) =>
+        Some(sources.GreaterThan(a.name, convertToScala(v, t)))
+
+      case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, t)) =>
+        Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t)))
+      case expressions.GreaterThanOrEqual(Literal(v, t), a: Attribute) =>
+        Some(sources.LessThanOrEqual(a.name, convertToScala(v, t)))
+
+      case expressions.LessThanOrEqual(a: Attribute, Literal(v, t)) =>
+        Some(sources.LessThanOrEqual(a.name, convertToScala(v, t)))
+      case expressions.LessThanOrEqual(Literal(v, t), a: Attribute) =>
+        Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t)))

       case expressions.InSet(a: Attribute, set) =>
-        Some(sources.In(a.name, set.toArray))
+        val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
+        Some(sources.In(a.name, set.toArray.map(toScala)))

       // Because we only convert In to InSet in Optimizer when there are more than certain
       // items. So it is possible we still get an In expression here that needs to be pushed
       // down.
       case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) =>
         val hSet = list.map(e => e.eval(EmptyRow))
-        Some(sources.In(a.name, hSet.toArray))
+        val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
+        Some(sources.In(a.name, hSet.toArray.map(toScala)))

       case expressions.IsNull(a: Attribute) =>
         Some(sources.IsNull(a.name))
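
The InSet and In cases build a single converter per attribute and map it across all pushed-down values. A small sketch of that pattern (again assuming the internal CatalystTypeConverters API shown in the diff):

import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

// createToScalaConverter returns a reusable Any => Any function for one
// DataType, so the conversion logic is resolved once rather than per element.
val toScala: Any => Any = CatalystTypeConverters.createToScalaConverter(StringType)
val internalSet: Set[Any] = Set(UTF8String.fromString("a"), UTF8String.fromString("b"))
val externalValues: Array[Any] = internalSet.toArray.map(toScala) // Array("a", "b")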

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala

Lines changed: 1 addition & 1 deletion
@@ -262,7 +262,7 @@ private[sql] class JDBCRDD(
    * Converts value to SQL expression.
    */
   private def compileValue(value: Any): Any = value match {
-    case stringValue: UTF8String => s"'${escapeSql(stringValue.toString)}'"
+    case stringValue: String => s"'${escapeSql(stringValue)}'"
     case _ => value
   }
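
With filters now carrying java.lang.String, compileValue can quote and escape values directly, with no UTF8String round-trip. A standalone sketch of the behavior (the local escapeSql here is an assumption standing in for JDBCRDD's private helper, which doubles single quotes):

// Hypothetical stand-in for JDBCRDD's private escapeSql helper.
def escapeSql(value: String): String =
  if (value == null) null else value.replace("'", "''")

def compileValue(value: Any): Any = value match {
  case stringValue: String => s"'${escapeSql(stringValue)}'"
  case _ => value // numbers and other values are embedded as-is
}

assert(compileValue("O'Brien") == "'O''Brien'") // safe inside a WHERE clause
assert(compileValue(42) == 42)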

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala

Lines changed: 11 additions & 8 deletions
@@ -34,7 +34,6 @@ import org.apache.spark.SparkEnv
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.sources
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String

 private[sql] object ParquetFilters {
   val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter"
@@ -73,7 +72,7 @@ private[sql] object ParquetFilters {
     case StringType =>
       (n: String, v: Any) => FilterApi.eq(
         binaryColumn(n),
-        Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull)
+        Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))).orNull)
     case BinaryType =>
       (n: String, v: Any) => FilterApi.eq(
         binaryColumn(n),
@@ -94,7 +93,7 @@ private[sql] object ParquetFilters {
     case StringType =>
       (n: String, v: Any) => FilterApi.notEq(
         binaryColumn(n),
-        Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull)
+        Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))).orNull)
     case BinaryType =>
       (n: String, v: Any) => FilterApi.notEq(
         binaryColumn(n),
@@ -112,7 +111,8 @@ private[sql] object ParquetFilters {
       (n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
     case StringType =>
       (n: String, v: Any) =>
-        FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
+        FilterApi.lt(binaryColumn(n),
+          Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8")))
     case BinaryType =>
       (n: String, v: Any) =>
         FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -129,7 +129,8 @@ private[sql] object ParquetFilters {
       (n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
     case StringType =>
       (n: String, v: Any) =>
-        FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
+        FilterApi.ltEq(binaryColumn(n),
+          Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8")))
     case BinaryType =>
       (n: String, v: Any) =>
         FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -146,7 +147,8 @@ private[sql] object ParquetFilters {
       (n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
     case StringType =>
       (n: String, v: Any) =>
-        FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
+        FilterApi.gt(binaryColumn(n),
+          Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8")))
     case BinaryType =>
       (n: String, v: Any) =>
         FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -163,7 +165,8 @@ private[sql] object ParquetFilters {
       (n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
     case StringType =>
       (n: String, v: Any) =>
-        FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
+        FilterApi.gtEq(binaryColumn(n),
+          Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8")))
     case BinaryType =>
       (n: String, v: Any) =>
         FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -185,7 +188,7 @@ private[sql] object ParquetFilters {
     case StringType =>
       (n: String, v: Set[Any]) =>
         FilterApi.userDefined(binaryColumn(n),
-          SetInFilter(v.map(e => Binary.fromByteArray(e.asInstanceOf[UTF8String].getBytes))))
+          SetInFilter(v.map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8")))))
     case BinaryType =>
       (n: String, v: Set[Any]) =>
         FilterApi.userDefined(binaryColumn(n),
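
The Parquet change is byte-for-byte equivalent: UTF8String.getBytes and String.getBytes("utf-8") produce the same UTF-8 encoding, so the pushed-down predicate compares identical Binary values. A quick check (assuming spark-unsafe on the classpath):

import org.apache.spark.unsafe.types.UTF8String

val s = "café"
// Old path: bytes out of the Catalyst-internal representation.
val viaUtf8String: Array[Byte] = UTF8String.fromString(s).getBytes
// New path: bytes straight from java.lang.String.
val viaJavaString: Array[Byte] = s.getBytes("utf-8")
assert(java.util.Arrays.equals(viaUtf8String, viaJavaString))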

sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala

Lines changed: 7 additions & 0 deletions
@@ -20,6 +20,7 @@ package org.apache.spark.sql.sources
 import scala.language.existentials

 import org.apache.spark.rdd.RDD
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.sql._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
@@ -78,6 +79,9 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQLContext)
       case StringStartsWith("c", v) => _.startsWith(v)
       case StringEndsWith("c", v) => _.endsWith(v)
       case StringContains("c", v) => _.contains(v)
+      case EqualTo("c", v: String) => _.equals(v)
+      case EqualTo("c", v: UTF8String) => sys.error("UTF8String should not appear in filters")
+      case In("c", values) => (s: String) => values.map(_.asInstanceOf[String]).toSet.contains(s)
       case _ => (c: String) => true
     }

@@ -237,6 +241,9 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
   testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1)
   testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0)

+  testPushDown("SELECT c FROM oneToTenFiltered WHERE c = 'aaaaaAAAAA'", 1)
+  testPushDown("SELECT c FROM oneToTenFiltered WHERE c IN ('aaaaaAAAAA', 'foo')", 1)
+
   def testPushDown(sqlString: String, expectedCount: Int): Unit = {
     test(s"PushDown Returns $expectedCount: $sqlString") {
       val queryExecution = sql(sqlString).queryExecution
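
Outside the suite, the same guard can be exercised directly against the public Filter classes; a minimal sketch (not from the commit):

import org.apache.spark.sql.sources.{EqualTo, Filter}
import org.apache.spark.unsafe.types.UTF8String

// Fail fast if a Catalyst-internal UTF8String ever reaches the public API.
def checkFilter(f: Filter): Unit = f match {
  case EqualTo("c", _: UTF8String) => sys.error("UTF8String should not appear in filters")
  case EqualTo("c", _: String) => () // expected: a public java.lang.String
  case _ => ()
}

checkFilter(EqualTo("c", "aaaaaAAAAA")) // passes after this patch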
