@@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.datasources
import org.apache.spark.{Logging, TaskContext}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD}
import org.apache.spark.sql.catalyst.{InternalRow, expressions}
import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
@@ -344,45 +345,47 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
*/
protected[sql] def selectFilters(filters: Seq[Expression]) = {
def translate(predicate: Expression): Option[Filter] = predicate match {
case expressions.EqualTo(a: Attribute, Literal(v, _)) =>
Some(sources.EqualTo(a.name, v))
case expressions.EqualTo(Literal(v, _), a: Attribute) =>
Some(sources.EqualTo(a.name, v))

case expressions.EqualNullSafe(a: Attribute, Literal(v, _)) =>
Some(sources.EqualNullSafe(a.name, v))
case expressions.EqualNullSafe(Literal(v, _), a: Attribute) =>
Some(sources.EqualNullSafe(a.name, v))

case expressions.GreaterThan(a: Attribute, Literal(v, _)) =>
Some(sources.GreaterThan(a.name, v))
case expressions.GreaterThan(Literal(v, _), a: Attribute) =>
Some(sources.LessThan(a.name, v))

case expressions.LessThan(a: Attribute, Literal(v, _)) =>
Some(sources.LessThan(a.name, v))
case expressions.LessThan(Literal(v, _), a: Attribute) =>
Some(sources.GreaterThan(a.name, v))

case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) =>
Some(sources.GreaterThanOrEqual(a.name, v))
case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) =>
Some(sources.LessThanOrEqual(a.name, v))

case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) =>
Some(sources.LessThanOrEqual(a.name, v))
case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) =>
Some(sources.GreaterThanOrEqual(a.name, v))
case expressions.EqualTo(a: Attribute, Literal(v, t)) =>
Some(sources.EqualTo(a.name, convertToScala(v, t)))
case expressions.EqualTo(Literal(v, t), a: Attribute) =>
Some(sources.EqualTo(a.name, convertToScala(v, t)))

case expressions.EqualNullSafe(a: Attribute, Literal(v, t)) =>
Some(sources.EqualNullSafe(a.name, convertToScala(v, t)))
case expressions.EqualNullSafe(Literal(v, t), a: Attribute) =>
Some(sources.EqualNullSafe(a.name, convertToScala(v, t)))

case expressions.GreaterThan(a: Attribute, Literal(v, t)) =>
Some(sources.GreaterThan(a.name, convertToScala(v, t)))
case expressions.GreaterThan(Literal(v, t), a: Attribute) =>
Some(sources.LessThan(a.name, convertToScala(v, t)))

case expressions.LessThan(a: Attribute, Literal(v, t)) =>
Some(sources.LessThan(a.name, convertToScala(v, t)))
case expressions.LessThan(Literal(v, t), a: Attribute) =>
Some(sources.GreaterThan(a.name, convertToScala(v, t)))

case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, t)) =>
Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t)))
case expressions.GreaterThanOrEqual(Literal(v, t), a: Attribute) =>
Some(sources.LessThanOrEqual(a.name, convertToScala(v, t)))

case expressions.LessThanOrEqual(a: Attribute, Literal(v, t)) =>
Some(sources.LessThanOrEqual(a.name, convertToScala(v, t)))
case expressions.LessThanOrEqual(Literal(v, t), a: Attribute) =>
Some(sources.GreaterThanOrEqual(a.name, convertToScala(v, t)))

case expressions.InSet(a: Attribute, set) =>
Some(sources.In(a.name, set.toArray))
val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
Some(sources.In(a.name, set.toArray.map(toScala)))

// The optimizer only rewrites In to InSet when the value list exceeds a certain
// size, so an In expression can still reach this point and needs to be pushed
// down as well.
case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) =>
val hSet = list.map(e => e.eval(EmptyRow))
Some(sources.In(a.name, hSet.toArray))
val toScala = CatalystTypeConverters.createToScalaConverter(a.dataType)
Some(sources.In(a.name, hSet.toArray.map(toScala)))

case expressions.IsNull(a: Attribute) =>
Some(sources.IsNull(a.name))
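
Every branch of translate above now maps the Catalyst-internal literal value back to its external Scala representation before it is wrapped in a sources.Filter. A minimal sketch of the conversion this relies on, using the CatalystTypeConverters utility imported at the top of this file (the sample value is illustrative):

import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String

// Catalyst stores StringType literals as UTF8String internally; convertToScala
// turns that back into a plain java.lang.String before the value reaches a
// data source filter such as sources.EqualTo.
val internal: Any = UTF8String.fromString("aaaaaAAAAA")
val external: Any = convertToScala(internal, StringType)  // "aaaaaAAAAA": String

For In and InSet, where the element type is only known from the attribute, createToScalaConverter(a.dataType) builds the same conversion once and maps it over the whole value set, as the two hunks above show.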
@@ -262,7 +262,7 @@ private[sql] class JDBCRDD(
* Converts value to SQL expression.
*/
private def compileValue(value: Any): Any = value match {
case stringValue: UTF8String => s"'${escapeSql(stringValue.toString)}'"
case stringValue: String => s"'${escapeSql(stringValue)}'"
case _ => value
}
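
Because filters are converted before they reach the JDBC relation, compileValue only ever sees external Scala types, so matching on String rather than UTF8String is sufficient here. A self-contained sketch of the string branch, assuming the same single-quote-doubling escapeSql helper used by this file (the input value is made up):

def escapeSql(value: String): String = value.replace("'", "''")

def compileValue(value: Any): Any = value match {
  // External Scala String: escape embedded quotes and wrap in SQL quotes.
  case stringValue: String => s"'${escapeSql(stringValue)}'"
  case _ => value
}

compileValue("O'Brien")  // "'O''Brien'" -- ready to splice into a generated WHERE clause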

@@ -32,7 +32,6 @@ import org.apache.spark.SparkEnv
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.sources
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

private[sql] object ParquetFilters {
val PARQUET_FILTER_DATA = "org.apache.spark.sql.parquet.row.filter"
@@ -65,7 +64,7 @@ private[sql] object ParquetFilters {
case StringType =>
(n: String, v: Any) => FilterApi.eq(
binaryColumn(n),
Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull)
Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))).orNull)
case BinaryType =>
(n: String, v: Any) => FilterApi.eq(
binaryColumn(n),
@@ -86,7 +85,7 @@ private[sql] object ParquetFilters {
case StringType =>
(n: String, v: Any) => FilterApi.notEq(
binaryColumn(n),
Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[UTF8String].getBytes)).orNull)
Option(v).map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8"))).orNull)
case BinaryType =>
(n: String, v: Any) => FilterApi.notEq(
binaryColumn(n),
@@ -104,7 +103,8 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.lt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
FilterApi.lt(binaryColumn(n),
Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8")))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.lt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -121,7 +121,8 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.ltEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
FilterApi.ltEq(binaryColumn(n),
Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8")))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.ltEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -138,7 +139,8 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.gt(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
FilterApi.gt(binaryColumn(n),
Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8")))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.gt(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -155,7 +157,8 @@ private[sql] object ParquetFilters {
(n: String, v: Any) => FilterApi.gtEq(doubleColumn(n), v.asInstanceOf[java.lang.Double])
case StringType =>
(n: String, v: Any) =>
FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[UTF8String].getBytes))
FilterApi.gtEq(binaryColumn(n),
Binary.fromByteArray(v.asInstanceOf[String].getBytes("utf-8")))
case BinaryType =>
(n: String, v: Any) =>
FilterApi.gtEq(binaryColumn(n), Binary.fromByteArray(v.asInstanceOf[Array[Byte]]))
@@ -177,7 +180,7 @@ private[sql] object ParquetFilters {
case StringType =>
(n: String, v: Set[Any]) =>
FilterApi.userDefined(binaryColumn(n),
SetInFilter(v.map(e => Binary.fromByteArray(e.asInstanceOf[UTF8String].getBytes))))
SetInFilter(v.map(s => Binary.fromByteArray(s.asInstanceOf[String].getBytes("utf-8")))))
case BinaryType =>
(n: String, v: Set[Any]) =>
FilterApi.userDefined(binaryColumn(n),
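
All of the StringType branches in ParquetFilters now receive a plain String, so the Parquet Binary is built from its UTF-8 bytes directly instead of going through UTF8String. A minimal sketch of the resulting pattern for a single equality predicate; the column name and value are illustrative, and the imports assume the org.apache.parquet packages this file builds against:

import org.apache.parquet.filter2.predicate.FilterApi
import org.apache.parquet.filter2.predicate.FilterApi.binaryColumn
import org.apache.parquet.io.api.Binary

// Equality predicate on a string column, built from an external Scala String.
val pred = FilterApi.eq(
  binaryColumn("c"),
  Binary.fromByteArray("aaaaaAAAAA".getBytes("utf-8")))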
@@ -20,6 +20,7 @@ package org.apache.spark.sql.sources
import scala.language.existentials

import org.apache.spark.rdd.RDD
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.sql._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -78,6 +79,9 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL
case StringStartsWith("c", v) => _.startsWith(v)
case StringEndsWith("c", v) => _.endsWith(v)
case StringContains("c", v) => _.contains(v)
case EqualTo("c", v: String) => _.equals(v)
case EqualTo("c", v: UTF8String) => sys.error("UTF8String should not appear in filters")
case In("c", values) => (s: String) => values.map(_.asInstanceOf[String]).toSet.contains(s)
case _ => (c: String) => true
}
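
The two new EqualTo cases act as a regression guard: a pushed-down equality on the string column must arrive as a plain String, and the UTF8String case deliberately fails the test if an internal value ever leaks through again. The In case mirrors the new conversion of In/InSet values on the planner side.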

@@ -237,6 +241,9 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%eE%'", 1)
testPushDown("SELECT a, b, c FROM oneToTenFiltered WHERE c like '%Ee%'", 0)

testPushDown("SELECT c FROM oneToTenFiltered WHERE c = 'aaaaaAAAAA'", 1)
testPushDown("SELECT c FROM oneToTenFiltered WHERE c IN ('aaaaaAAAAA', 'foo')", 1)

def testPushDown(sqlString: String, expectedCount: Int): Unit = {
test(s"PushDown Returns $expectedCount: $sqlString") {
val queryExecution = sql(sqlString).queryExecution
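
For third-party data sources, the observable contract after this change is that PrunedFilteredScan.buildScan receives filters whose values are external Scala types (String, not UTF8String). A minimal sketch of a relation that relies on that, assuming the Spark 1.5-era SQLContext API; the class name, column name, and data are made up:

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._

// Hypothetical relation: pushed-down filter values can be matched as String directly.
case class TinyStringRelation(@transient val sqlContext: SQLContext)
  extends BaseRelation with PrunedFilteredScan {

  override def schema: StructType = StructType(StructField("c", StringType) :: Nil)

  override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
    // Collect the equality values pushed down for column "c" and filter locally.
    val wanted = filters.collect { case EqualTo("c", v: String) => v }.toSet
    val rows = Seq("foo", "bar")
      .filter(s => wanted.isEmpty || wanted(s))
      .map(Row(_))
    sqlContext.sparkContext.parallelize(rows)
  }
}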