Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import org.apache.spark.sql.types._
object NestedColumnAliasing {

def unapply(plan: LogicalPlan)
: Option[(Map[GetStructField, Alias], Map[ExprId, Seq[Alias]])] = plan match {
: Option[(Map[ExtractValue, Alias], Map[ExprId, Seq[Alias]])] = plan match {
case Project(projectList, child)
if SQLConf.get.nestedSchemaPruningEnabled && canProjectPushThrough(child) =>
Comment thread
dongjoon-hyun marked this conversation as resolved.
getAliasSubMap(projectList)
Expand All @@ -43,7 +43,7 @@ object NestedColumnAliasing {
*/
def replaceToAliases(
plan: LogicalPlan,
nestedFieldToAlias: Map[GetStructField, Alias],
nestedFieldToAlias: Map[ExtractValue, Alias],
attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = plan match {
case Project(projectList, child) =>
Project(
Expand All @@ -56,9 +56,9 @@ object NestedColumnAliasing {
*/
private def getNewProjectList(
projectList: Seq[NamedExpression],
nestedFieldToAlias: Map[GetStructField, Alias]): Seq[NamedExpression] = {
nestedFieldToAlias: Map[ExtractValue, Alias]): Seq[NamedExpression] = {
projectList.map(_.transform {
case f: GetStructField if nestedFieldToAlias.contains(f) =>
case f: ExtractValue if nestedFieldToAlias.contains(f) =>
nestedFieldToAlias(f).toAttribute
}.asInstanceOf[NamedExpression])
}
Expand Down Expand Up @@ -86,32 +86,39 @@ object NestedColumnAliasing {
}

/**
* Return root references that are individually accessed as a whole, and `GetStructField`s.
* Return root references that are individually accessed as a whole, and `GetStructField`s
* or `GetArrayStructFields` that sit on top of other `ExtractValue`s or special expressions.
* Check `SelectedField` to see which expressions should be listed here.
*/
private def collectRootReferenceAndGetStructField(e: Expression): Seq[Expression] = e match {
case _: AttributeReference | _: GetStructField => Seq(e)
case es if es.children.nonEmpty => es.children.flatMap(collectRootReferenceAndGetStructField)
private def collectRootReferenceAndExtractValue(e: Expression): Seq[Expression] = e match {
case _: AttributeReference => Seq(e)
case GetStructField(_: ExtractValue | _: AttributeReference, _, _) => Seq(e)
case GetArrayStructFields(_: MapValues |
_: MapKeys |
_: ExtractValue |
_: AttributeReference, _, _, _, _) => Seq(e)
Comment thread
dongjoon-hyun marked this conversation as resolved.
case es if es.children.nonEmpty => es.children.flatMap(collectRootReferenceAndExtractValue)
case _ => Seq.empty
}

/**
* Return two maps in order to replace nested fields to aliases.
*
* 1. GetStructField -> Alias: A new alias is created for each nested field.
* 1. ExtractValue -> Alias: A new alias is created for each nested field.
2. ExprId -> Seq[Alias]: A reference attribute has multiple aliases pointing to it.
*/
private def getAliasSubMap(projectList: Seq[NamedExpression])
: Option[(Map[GetStructField, Alias], Map[ExprId, Seq[Alias]])] = {
: Option[(Map[ExtractValue, Alias], Map[ExprId, Seq[Alias]])] = {
val (nestedFieldReferences, otherRootReferences) =
projectList.flatMap(collectRootReferenceAndGetStructField).partition {
case _: GetStructField => true
projectList.flatMap(collectRootReferenceAndExtractValue).partition {
case _: ExtractValue => true
case _ => false
}

val aliasSub = nestedFieldReferences.asInstanceOf[Seq[GetStructField]]
val aliasSub = nestedFieldReferences.asInstanceOf[Seq[ExtractValue]]
.filter(!_.references.subsetOf(AttributeSet(otherRootReferences)))
.groupBy(_.references.head)
.flatMap { case (attr, nestedFields: Seq[GetStructField]) =>
.flatMap { case (attr, nestedFields: Seq[ExtractValue]) =>
// Each expression can contain multiple nested fields.
// Note that we keep the original names to deliver to parquet in a case-sensitive way.
val nestedFieldToAlias = nestedFields.distinct.map { f =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1562,7 +1562,7 @@ object SQLConf {
"reading unnecessary nested column data. Currently Parquet and ORC are the " +
"data sources that implement this optimization.")
.booleanConf
.createWithDefault(false)
.createWithDefault(true)
Comment thread
dongjoon-hyun marked this conversation as resolved.
Outdated

val SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED =
buildConf("spark.sql.optimizer.serializer.nestedSchemaPruning.enabled")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

class NestedColumnAliasingSuite extends SchemaPruningTest {

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are still many usages of GetStructField in this test suite. Maybe a minor follow-up PR could rewrite them.


Expand Down Expand Up @@ -221,6 +221,51 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
comparePlans(optimized, expected)
}

test("nested field pruning for getting struct field in array of struct") {
  // Nested accesses under test: the `first` field extracted from the elements of `friends`,
  // and the `id` field of the `employer` struct.
  val friendsFirst = GetArrayStructFields(
    child = 'friends,
    field = StructField("first", StringType),
    ordinal = 0,
    numFields = 3,
    containsNull = true)
  val employerId = GetStructField('employer, 0, Some("id"))

  // Input plan: the projection sits on top of the limit.
  val limitThenSelect = contact.limit(5).select(friendsFirst, employerId).analyze
  val optimizedPlan = Optimize.execute(limitThenSelect)

  // Expected plan: the same projection, now placed underneath the limit.
  val selectThenLimit = contact.select(friendsFirst, employerId).limit(5).analyze
  comparePlans(optimizedPlan, selectThenLimit)
}

test("nested field pruning for getting struct field in map") {
  // Nested accesses under test: the `first` field of the map value looked up by "key",
  // and the `middle` field extracted from all values of the `relatives` map.
  val relativeFirst = GetStructField(GetMapValue('relatives, Literal("key")), 0, Some("first"))
  val relativesMiddle = GetArrayStructFields(
    child = MapValues('relatives),
    field = StructField("middle", StringType),
    ordinal = 1,
    numFields = 3,
    containsNull = true)

  // Input plan: the projection sits on top of the limit.
  val limitThenSelect = contact.limit(5).select(relativeFirst, relativesMiddle).analyze
  val optimizedPlan = Optimize.execute(limitThenSelect)

  // Expected plan: the same projection, now placed underneath the limit.
  val selectThenLimit = contact.select(relativeFirst, relativesMiddle).limit(5).analyze
  comparePlans(optimizedPlan, selectThenLimit)
}


private def collectGeneratedAliases(query: LogicalPlan): ArrayBuffer[String] = {
val aliases = ArrayBuffer[String]()
query.transformAllExpressions {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,17 @@ abstract class NestedSchemaPruningBenchmark extends SqlBasedBenchmark {
protected val N = 1000000
protected val numIters = 10

// We use `col1 BIGINT, col2 STRUCT<_1: BIGINT, _2: STRING>` as a test schema.
// col1 and col2._1 is used for comparision. col2._2 mimics the burden for the other columns
// We use `col1 BIGINT, col2 STRUCT<_1: BIGINT, _2: STRING>,
// col3 ARRAY<STRUCT<_1: BIGINT, _2: STRING>>` as a test schema.
// col1, col2._1 and col3._1 are used for comparison. col2._2 and col3._2 mimic the burden
// for the other columns
Comment thread
dongjoon-hyun marked this conversation as resolved.
private val df = spark
.range(N * 10)
.sample(false, 0.1)
.map(x => (x, (x, s"$x" * 100)))
.toDF("col1", "col2")
.map { x =>
val col3 = (0 until 10).map(i => (x + i, s"$x" * 10))
(x, (x, s"$x" * 100), col3)
}.toDF("col1", "col2", "col3")
Comment thread
dongjoon-hyun marked this conversation as resolved.

private def addCase(benchmark: Benchmark, name: String, sql: String): Unit = {
benchmark.addCase(name) { _ =>
Expand All @@ -60,6 +64,7 @@ abstract class NestedSchemaPruningBenchmark extends SqlBasedBenchmark {

addCase(benchmark, "Top-level column", "SELECT col1 FROM (SELECT col1 FROM t1)")
addCase(benchmark, "Nested column", "SELECT col2._1 FROM (SELECT col2 FROM t2)")
addCase(benchmark, "Nested column in array", "SELECT col3._1 FROM (SELECT col3 FROM t2)")

benchmark.run()
}
Expand All @@ -80,6 +85,8 @@ abstract class NestedSchemaPruningBenchmark extends SqlBasedBenchmark {
s"SELECT col1 FROM (SELECT col1 FROM t1 LIMIT ${Int.MaxValue})")
addCase(benchmark, "Nested column",
s"SELECT col2._1 FROM (SELECT col2 FROM t2 LIMIT ${Int.MaxValue})")
addCase(benchmark, "Nested column in array",
s"SELECT col3._1 FROM (SELECT col3 FROM t2 LIMIT ${Int.MaxValue})")

benchmark.run()
}
Expand All @@ -100,6 +107,8 @@ abstract class NestedSchemaPruningBenchmark extends SqlBasedBenchmark {
s"SELECT col1 FROM (SELECT /*+ REPARTITION(1) */ col1 FROM t1)")
addCase(benchmark, "Nested column",
s"SELECT col2._1 FROM (SELECT /*+ REPARTITION(1) */ col2 FROM t2)")
addCase(benchmark, "Nested column in array",
s"SELECT col3._1 FROM (SELECT /*+ REPARTITION(1) */ col3 FROM t2)")

benchmark.run()
}
Expand All @@ -120,6 +129,8 @@ abstract class NestedSchemaPruningBenchmark extends SqlBasedBenchmark {
s"SELECT col1 FROM (SELECT col1 FROM t1 DISTRIBUTE BY col1)")
addCase(benchmark, "Nested column",
s"SELECT col2._1 FROM (SELECT col2 FROM t2 DISTRIBUTE BY col2._1)")
addCase(benchmark, "Nested column in array",
s"SELECT col3._1 FROM (SELECT col3 FROM t2 DISTRIBUTE BY col3._1)")

benchmark.run()
}
Expand All @@ -140,6 +151,8 @@ abstract class NestedSchemaPruningBenchmark extends SqlBasedBenchmark {
s"SELECT col1 FROM (SELECT col1 FROM t1 TABLESAMPLE(100 percent))")
addCase(benchmark, "Nested column",
s"SELECT col2._1 FROM (SELECT col2 FROM t2 TABLESAMPLE(100 percent))")
addCase(benchmark, "Nested column in array",
s"SELECT col3._1 FROM (SELECT col3 FROM t2 TABLESAMPLE(100 percent))")

benchmark.run()
}
Expand All @@ -158,6 +171,7 @@ abstract class NestedSchemaPruningBenchmark extends SqlBasedBenchmark {

addCase(benchmark, "Top-level column", "SELECT col1 FROM t1 ORDER BY col1")
addCase(benchmark, "Nested column", "SELECT col2._1 FROM t2 ORDER BY col2._1")
addCase(benchmark, "Nested column in array", "SELECT col3._1 FROM t2 ORDER BY col3._1")

benchmark.run()
}
Expand Down