@@ -67,6 +67,16 @@ private[sql] object OrcFilters {
     }
   }
 
+  // Since ORC 1.5.0 (ORC-323), we need to quote for column names with `.` characters
+  // in order to distinguish predicate pushdown for nested columns.
+  private def quoteAttributeNameIfNeeded(name: String) : String = {
+    if (!name.contains("`") && name.contains(".")) {
+      s"`$name`"
+    } else {
+      name
+    }
+  }
+
   /**
    * Create ORC filter as a SearchArgument instance.
    */

Review thread on the new quoteAttributeNameIfNeeded condition:

Member: Does this condition take a backtick in the column name into account? For instance,

>>> spark.range(1).toDF("abc`.abc").show()
+--------+
|abc`.abc|
+--------+
|       0|
+--------+

Member Author: Thank you for the review. I'll consider that, too.

Member Author: @HyukjinKwon. Actually, Spark 2.3.2 ORC (native/hive) doesn't support a backtick character in column names; it fails on the write operation. And although Spark 2.4.0 broadens the special characters supported in column names, such as . and ", the backtick character is not handled yet.

So, for that one, I'll proceed in another PR, since it's an improvement rather than a regression.

Also, cc @gatorsmile and @dbtsai.

Member Author: For the ORC and Avro improvement, SPARK-25722 has been created.
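
For readers skimming the diff, a minimal standalone sketch of the quoting rule follows. The helper body mirrors the one added above; the object name and the sample column names are illustrative only.

// Standalone sketch of the quoting rule (same logic as the helper above,
// runnable outside Spark; the names below are examples, not from the PR).
object QuoteRuleDemo {
  def quoteAttributeNameIfNeeded(name: String): String =
    if (!name.contains("`") && name.contains(".")) s"`$name`" else name

  def main(args: Array[String]): Unit = {
    // A top-level column whose name contains a dot must be quoted so that
    // ORC 1.5+ does not treat it as a nested-field reference (ORC-323).
    println(quoteAttributeNameIfNeeded("col.dot"))   // prints: `col.dot`
    // Ordinary names pass through unchanged.
    println(quoteAttributeNameIfNeeded("id"))        // prints: id
    // Names already containing a backtick are left alone; per the thread
    // above, backtick support is deferred to SPARK-25722.
    println(quoteAttributeNameIfNeeded("abc`.abc"))  // prints: abc`.abc
  }
}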
@@ -178,38 +188,47 @@ private[sql] object OrcFilters {
       // wrapped by a "parent" predicate (`And`, `Or`, or `Not`).
 
       case EqualTo(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
         val castedValue = castLiteralValue(value, dataTypeMap(attribute))
-        Some(builder.startAnd().equals(attribute, getType(attribute), castedValue).end())
+        Some(builder.startAnd().equals(quotedName, getType(attribute), castedValue).end())
 
       case EqualNullSafe(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
         val castedValue = castLiteralValue(value, dataTypeMap(attribute))
-        Some(builder.startAnd().nullSafeEquals(attribute, getType(attribute), castedValue).end())
+        Some(builder.startAnd().nullSafeEquals(quotedName, getType(attribute), castedValue).end())
 
       case LessThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
         val castedValue = castLiteralValue(value, dataTypeMap(attribute))
-        Some(builder.startAnd().lessThan(attribute, getType(attribute), castedValue).end())
+        Some(builder.startAnd().lessThan(quotedName, getType(attribute), castedValue).end())
 
       case LessThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
         val castedValue = castLiteralValue(value, dataTypeMap(attribute))
-        Some(builder.startAnd().lessThanEquals(attribute, getType(attribute), castedValue).end())
+        Some(builder.startAnd().lessThanEquals(quotedName, getType(attribute), castedValue).end())
 
       case GreaterThan(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
         val castedValue = castLiteralValue(value, dataTypeMap(attribute))
-        Some(builder.startNot().lessThanEquals(attribute, getType(attribute), castedValue).end())
+        Some(builder.startNot().lessThanEquals(quotedName, getType(attribute), castedValue).end())
 
       case GreaterThanOrEqual(attribute, value) if isSearchableType(dataTypeMap(attribute)) =>
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
         val castedValue = castLiteralValue(value, dataTypeMap(attribute))
-        Some(builder.startNot().lessThan(attribute, getType(attribute), castedValue).end())
+        Some(builder.startNot().lessThan(quotedName, getType(attribute), castedValue).end())
 
       case IsNull(attribute) if isSearchableType(dataTypeMap(attribute)) =>
-        Some(builder.startAnd().isNull(attribute, getType(attribute)).end())
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
+        Some(builder.startAnd().isNull(quotedName, getType(attribute)).end())
 
       case IsNotNull(attribute) if isSearchableType(dataTypeMap(attribute)) =>
-        Some(builder.startNot().isNull(attribute, getType(attribute)).end())
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
+        Some(builder.startNot().isNull(quotedName, getType(attribute)).end())
 
       case In(attribute, values) if isSearchableType(dataTypeMap(attribute)) =>
+        val quotedName = quoteAttributeNameIfNeeded(attribute)
         val castedValues = values.map(v => castLiteralValue(v, dataTypeMap(attribute)))
-        Some(builder.startAnd().in(attribute, getType(attribute),
+        Some(builder.startAnd().in(quotedName, getType(attribute),
           castedValues.map(_.asInstanceOf[AnyRef]): _*).end())
 
       case _ => None
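
As a reading aid, here is a hedged end-to-end sketch of the user-facing scenario this hunk addresses. It assumes an existing SparkSession named spark; the output path is illustrative.

// Sketch only: a dotted top-level column with a pushed-down filter.
spark.conf.set("spark.sql.orc.filterPushdown", "true")
spark.range(10).toDF("col.dot").write.mode("overwrite").orc("/tmp/orc_dot_demo")
// Without quoting, the pushed leaf name `col.dot` is read by ORC 1.5+ as a
// nested-field reference; with quoteAttributeNameIfNeeded the predicate is
// pushed down correctly and the query still returns the single matching row.
spark.read.orc("/tmp/orc_dot_demo").where("`col.dot` == 2").show()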
@@ -445,16 +445,7 @@ abstract class OrcQueryTest extends OrcTest {
   test("Support for pushing down filters for decimal types") {
     withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") {
       val data = (0 until 10).map(i => Tuple1(BigDecimal.valueOf(i)))
-      withTempPath { file =>
-        // It needs to repartition data so that we can have several ORC files
-        // in order to skip stripes in ORC.
-        spark.createDataFrame(data).toDF("a").repartition(10)
-          .write.orc(file.getCanonicalPath)
-        val df = spark.read.orc(file.getCanonicalPath).where("a == 2")
-        val actual = stripSparkFilter(df).count()
-
-        assert(actual < 10)
-      }
+      checkPredicatePushDown(spark.createDataFrame(data).toDF("a"), 10, "a == 2")
     }
   }

@@ -465,16 +456,7 @@ abstract class OrcQueryTest extends OrcTest {
       val milliseconds = Timestamp.valueOf(timeString).getTime + i * 3600
       Tuple1(new Timestamp(milliseconds))
     }
-    withTempPath { file =>
-      // It needs to repartition data so that we can have several ORC files
-      // in order to skip stripes in ORC.
-      spark.createDataFrame(data).toDF("a").repartition(10)
-        .write.orc(file.getCanonicalPath)
-      val df = spark.read.orc(file.getCanonicalPath).where(s"a == '$timeString'")
-      val actual = stripSparkFilter(df).count()
-
-      assert(actual < 10)
-    }
+    checkPredicatePushDown(spark.createDataFrame(data).toDF("a"), 10, s"a == '$timeString'")
   }
 }

@@ -674,6 +656,12 @@ class OrcQuerySuite extends OrcQueryTest with SharedSQLContext {
     }
   }
 
+  test("SPARK-25579 ORC PPD should support column names with dot") {
+    withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") {
+      checkPredicatePushDown(spark.range(10).toDF("col.dot"), 10, "`col.dot` == 2")
+    }
+  }
+
   test("SPARK-20728 Make ORCFileFormat configurable between sql/hive and sql/core") {
     withSQLConf(SQLConf.ORC_IMPLEMENTATION.key -> "hive") {
       val e = intercept[AnalysisException] {
@@ -106,4 +106,14 @@ abstract class OrcTest extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
       df: DataFrame, path: File): Unit = {
     df.write.mode(SaveMode.Overwrite).orc(path.getCanonicalPath)
   }
+
+  protected def checkPredicatePushDown(df: DataFrame, numRows: Int, predicate: String): Unit = {
+    withTempPath { file =>
+      // It needs to repartition data so that we can have several ORC files
+      // in order to skip stripes in ORC.
+      df.repartition(numRows).write.orc(file.getCanonicalPath)
+      val actual = stripSparkFilter(spark.read.orc(file.getCanonicalPath).where(predicate)).count()
+      assert(actual < numRows)
+    }
+  }
 }

Review thread on checkPredicatePushDown:

Member Author (@dongjoon-hyun, Oct 16, 2018): @HyukjinKwon. I refactored this since it's repeated three times now. The function has to live here because the two existing call sites are in OrcQueryTest and the new one is in OrcQuerySuite. There is another similar instance, but I skipped it because it doesn't follow the same pattern.
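
For reference, a hypothetical call site of the new helper (mirroring the SPARK-25579 test above), with a comment spelling out why the assertion demonstrates pushdown.

// Hypothetical usage in a suite that mixes in OrcTest. repartition(numRows)
// spreads the rows across numRows ORC files; when the predicate reaches the
// ORC reader, whole files/stripes are skipped, so even after stripping
// Spark's own Filter operator (stripSparkFilter) fewer than numRows rows
// remain, which is why the helper asserts actual < numRows.
test("ORC PPD keeps filters on dotted column names effective") {
  withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") {
    checkPredicatePushDown(spark.range(10).toDF("col.dot"), 10, "`col.dot` == 2")
  }
}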