Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ object NestedColumnAliasing {
: Option[(Map[ExtractValue, Alias], Map[ExprId, Seq[Alias]])] = plan match {
case Project(projectList, child)
if SQLConf.get.nestedSchemaPruningEnabled && canProjectPushThrough(child) =>
getAliasSubMap(projectList)
val exprCandidatesToPrune = projectList ++ child.expressions
getAliasSubMap(exprCandidatesToPrune, child.producedAttributes.toSeq)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@viirya, just to clarify, you added producedAttributes here just to be safe but not related to the current changes (?). Seems Join and RepartitionByExpression have an empty producedAttributes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, if it's going to output, it shouldn't be pruned anyway.


case plan if SQLConf.get.nestedSchemaPruningEnabled && canPruneOn(plan) =>
Copy link
Member

@HyukjinKwon HyukjinKwon Jun 12, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No big deal but I would rename plan to p to avoid shadowing the plan argument. At least my IDE complains on that.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will change it in other PR. Thanks.

val exprCandidatesToPrune = plan.expressions
Expand All @@ -53,11 +54,11 @@ object NestedColumnAliasing {
case Project(projectList, child) =>
Project(
getNewProjectList(projectList, nestedFieldToAlias),
replaceChildrenWithAliases(child, nestedFieldToAlias, attrToAliases))
replaceWithAliases(child, nestedFieldToAlias, attrToAliases))

// The operators reaching here was already guarded by `canPruneOn`.
case other =>
replaceChildrenWithAliases(other, nestedFieldToAlias, attrToAliases)
replaceWithAliases(other, nestedFieldToAlias, attrToAliases)
}

/**
Expand All @@ -73,9 +74,10 @@ object NestedColumnAliasing {
}

/**
* Return a plan with new children replaced with aliases.
* Return a plan with new children replaced with aliases, and expressions replaced with
* aliased attributes.
*/
def replaceChildrenWithAliases(
def replaceWithAliases(
plan: LogicalPlan,
nestedFieldToAlias: Map[ExtractValue, Alias],
attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = {
Expand Down Expand Up @@ -104,6 +106,8 @@ object NestedColumnAliasing {
case _: LocalLimit => true
case _: Repartition => true
case _: Sample => true
case _: RepartitionByExpression => true
case _: Join => true
case _ => false
}

Expand Down Expand Up @@ -202,7 +206,9 @@ object GeneratorNestedColumnAliasing {
val exprsToPrune = projectList ++ g.generator.children
NestedColumnAliasing.getAliasSubMap(exprsToPrune, g.qualifiedGeneratorOutput).map {
case (nestedFieldToAlias, attrToAliases) =>
val newChild = pruneGenerate(g, nestedFieldToAlias, attrToAliases)
// Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`.
val newChild =
NestedColumnAliasing.replaceWithAliases(g, nestedFieldToAlias, attrToAliases)
Project(NestedColumnAliasing.getNewProjectList(projectList, nestedFieldToAlias), newChild)
}

Expand All @@ -215,21 +221,14 @@ object GeneratorNestedColumnAliasing {
NestedColumnAliasing.getAliasSubMap(
g.generator.children, g.requiredChildOutput).map {
case (nestedFieldToAlias, attrToAliases) =>
pruneGenerate(g, nestedFieldToAlias, attrToAliases)
// Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`.
NestedColumnAliasing.replaceWithAliases(g, nestedFieldToAlias, attrToAliases)
}

case _ =>
None
}

private def pruneGenerate(
g: Generate,
nestedFieldToAlias: Map[ExtractValue, Alias],
attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = {
// Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`.
NestedColumnAliasing.replaceChildrenWithAliases(g, nestedFieldToAlias, attrToAliases)
}

/**
* This is a while-list for pruning nested fields at `Generator`.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
test("Pushing a single nested field projection - negative") {
val ops = Seq(
(input: LogicalPlan) => input.distribute('name)(1),
(input: LogicalPlan) => input.distribute($"name.middle")(1),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, looks nice. This PR could support this case.

(input: LogicalPlan) => input.orderBy('name.asc),
(input: LogicalPlan) => input.orderBy($"name.middle".asc),
(input: LogicalPlan) => input.sortBy('name.asc),
Expand Down Expand Up @@ -342,6 +341,89 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
comparePlans(optimized, expected)
}

test("Nested field pruning through RepartitionByExpression") {
val query1 = contact
.distribute($"id")(1)
.select($"name.middle")
.analyze
val optimized1 = Optimize.execute(query1)

val aliases1 = collectGeneratedAliases(optimized1)

val expected1 = contact
.select('id, 'name.getField("middle").as(aliases1(0)))
.distribute($"id")(1)
.select($"${aliases1(0)}".as("middle"))
.analyze
comparePlans(optimized1, expected1)

val query2 = contact
.distribute($"name.middle")(1)
.select($"name.middle")
.analyze
val optimized2 = Optimize.execute(query2)

val aliases2 = collectGeneratedAliases(optimized2)

val expected2 = contact
.select('name.getField("middle").as(aliases2(0)))
.distribute($"${aliases2(0)}")(1)
.select($"${aliases2(0)}".as("middle"))
.analyze
comparePlans(optimized2, expected2)

val query3 = contact
.select($"name")
.distribute($"name")(1)
.select($"name.middle")
.analyze
val optimized3 = Optimize.execute(query3)

comparePlans(optimized3, query3)
}

test("Nested field pruning through Join") {
val department = LocalRelation(
'depID.int,
'personID.string)

val query1 = contact.join(department, condition = Some($"id" === $"depID"))
.select($"name.middle")
.analyze
val optimized1 = Optimize.execute(query1)

val aliases1 = collectGeneratedAliases(optimized1)

val expected1 = contact.select('id, 'name.getField("middle").as(aliases1(0)))
.join(department.select('depID), condition = Some($"id" === $"depID"))
.select($"${aliases1(0)}".as("middle"))
.analyze
comparePlans(optimized1, expected1)

val query2 = contact.join(department, condition = Some($"name.middle" === $"personID"))
.select($"name.first")
.analyze
val optimized2 = Optimize.execute(query2)

val aliases2 = collectGeneratedAliases(optimized2)

val expected2 = contact.select(
'name.getField("first").as(aliases2(0)),
'name.getField("middle").as(aliases2(1)))
.join(department.select('personID), condition = Some($"${aliases2(1)}" === $"personID"))
.select($"${aliases2(0)}".as("first"))
.analyze
comparePlans(optimized2, expected2)

val contact2 = LocalRelation('name2.struct(name))
val query3 = contact.select('name)
.join(contact2, condition = Some($"name" === $"name2"))
.select($"name.first")
.analyze
val optimized3 = Optimize.execute(query3)
comparePlans(optimized3, query3)
}

test("Nested field pruning for Aggregate") {
def runTest(basePlan: LogicalPlan => LogicalPlan): Unit = {
val query1 = basePlan(contact).groupBy($"id")(first($"name.first").as("first")).analyze
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,19 @@ abstract class SchemaPruningSuite
relatives: Map[String, FullName] = Map.empty,
employer: Employer = null,
relations: Map[FullName, String] = Map.empty)
case class Department(
depId: Int,
depName: String,
contactId: Int,
employer: Employer)

val janeDoe = FullName("Jane", "X.", "Doe")
val johnDoe = FullName("John", "Y.", "Doe")
val susanSmith = FullName("Susan", "Z.", "Smith")

val employer = Employer(0, Company("abc", "123 Business Street"))
val employerWithNullCompany = Employer(1, null)
val employerWithNullCompany2 = Employer(2, null)

val contacts =
Contact(0, janeDoe, "123 Main Street", 1, friends = Array(susanSmith),
Expand All @@ -66,6 +72,11 @@ abstract class SchemaPruningSuite
Contact(1, johnDoe, "321 Wall Street", 3, relatives = Map("sister" -> janeDoe),
employer = employerWithNullCompany, relations = Map(janeDoe -> "sister")) :: Nil

val departments =
Department(0, "Engineering", 0, employer) ::
Department(1, "Marketing", 1, employerWithNullCompany) ::
Department(2, "Operation", 4, employerWithNullCompany2) :: Nil

case class Name(first: String, last: String)
case class BriefContact(id: Int, name: Name, address: String)

Expand Down Expand Up @@ -350,6 +361,83 @@ abstract class SchemaPruningSuite
checkAnswer(query, Row(0) :: Nil)
}

testSchemaPruning("select one deep nested complex field after repartition by expression") {
val query1 = sql("select * from contacts")
.repartition(100, col("id"))
.where("employer.company.address is not null")
.selectExpr("employer.id as employer_id")
checkScan(query1,
"struct<id:int,employer:struct<id:int,company:struct<address:string>>>")
checkAnswer(query1, Row(0) :: Nil)

val query2 = sql("select * from contacts")
.repartition(100, col("employer"))
.where("employer.company.address is not null")
.selectExpr("employer.id as employer_id")
checkScan(query2,
"struct<employer:struct<id:int,company:struct<name:string,address:string>>>")
checkAnswer(query2, Row(0) :: Nil)

val query3 = sql("select * from contacts")
.repartition(100, col("employer.company"))
.where("employer.company.address is not null")
.selectExpr("employer.company as employer_company")
checkScan(query3,
"struct<employer:struct<company:struct<name:string,address:string>>>")
checkAnswer(query3, Row(Row("abc", "123 Business Street")) :: Nil)

val query4 = sql("select * from contacts")
.repartition(100, col("employer.company.address"))
.where("employer.company.address is not null")
.selectExpr("employer.company.address as employer_company_addr")
checkScan(query4,
"struct<employer:struct<company:struct<address:string>>>")
checkAnswer(query4, Row("123 Business Street") :: Nil)
}

testSchemaPruning("select one deep nested complex field after join") {
val query1 = sql("select contacts.name.middle from contacts, departments where " +
"contacts.id = departments.contactId")
checkScan(query1,
"struct<id:int,name:struct<middle:string>>",
"struct<contactId:int>")
checkAnswer(query1, Row("X.") :: Row("Y.") :: Nil)

val query2 = sql("select contacts.name.middle from contacts, departments where " +
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think its better to use uppercases for SQL keywords where possible.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems all tests in this test suite are using lowercases. Changing all tests seems too bothering... :)

"contacts.employer = departments.employer")
checkScan(query2,
"struct<name:struct<middle:string>," +
"employer:struct<id:int,company:struct<name:string,address:string>>>",
"struct<employer:struct<id:int,company:struct<name:string,address:string>>>")
checkAnswer(query2, Row("X.") :: Row("Y.") :: Nil)

val query3 = sql("select contacts.employer.company.name from contacts, departments where " +
"contacts.employer = departments.employer")
checkScan(query3,
"struct<employer:struct<id:int,company:struct<name:string,address:string>>>",
"struct<employer:struct<id:int,company:struct<name:string,address:string>>>")
checkAnswer(query3, Row("abc") :: Row(null) :: Nil)
}

testSchemaPruning("select one deep nested complex field after outer join") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for adding the tests.

val query1 = sql("select departments.contactId, contacts.name.middle from departments " +
"left outer join contacts on departments.contactId = contacts.id")
checkScan(query1,
"struct<contactId:int>",
"struct<id:int,name:struct<middle:string>>")
checkAnswer(query1, Row(0, "X.") :: Row(1, "Y.") :: Row(4, null) :: Nil)

val query2 = sql("select contacts.name.first, departments.employer.company.name " +
"from contacts right outer join departments on contacts.id = departments.contactId")
checkScan(query2,
"struct<id:int,name:struct<first:string>>",
"struct<contactId:int,employer:struct<company:struct<name:string>>>")
checkAnswer(query2,
Row("Jane", "abc") ::
Row("John", null) ::
Row(null, null) :: Nil)
}

testSchemaPruning("select nested field in aggregation function of Aggregate") {
val query1 = sql("select count(name.first) from contacts group by name.last")
checkScan(query1, "struct<name:struct<first:string,last:string>>")
Expand Down Expand Up @@ -439,6 +527,7 @@ abstract class SchemaPruningSuite

makeDataSourceFile(contacts, new File(path + "/contacts/p=1"))
makeDataSourceFile(briefContacts, new File(path + "/contacts/p=2"))
makeDataSourceFile(departments, new File(path + "/departments"))

// Providing user specified schema. Inferred schema from different data sources might
// be different.
Expand All @@ -451,6 +540,11 @@ abstract class SchemaPruningSuite
spark.read.format(dataSourceName).schema(schema).load(path + "/contacts")
.createOrReplaceTempView("contacts")

val departmentScahem = "`depId` INT,`depName` STRING,`contactId` INT, " +
"`employer` STRUCT<`id`: INT, `company`: STRUCT<`name`: STRING, `address`: STRING>>"
spark.read.format(dataSourceName).schema(departmentScahem).load(path + "/departments")
.createOrReplaceTempView("departments")

testThunk
}
}
Expand All @@ -461,6 +555,7 @@ abstract class SchemaPruningSuite

makeDataSourceFile(contactsWithDataPartitionColumn, new File(path + "/contacts/p=1"))
makeDataSourceFile(briefContactsWithDataPartitionColumn, new File(path + "/contacts/p=2"))
makeDataSourceFile(departments, new File(path + "/departments"))

// Providing user specified schema. Inferred schema from different data sources might
// be different.
Expand All @@ -473,6 +568,11 @@ abstract class SchemaPruningSuite
spark.read.format(dataSourceName).schema(schema).load(path + "/contacts")
.createOrReplaceTempView("contacts")

val departmentScahem = "`depId` INT,`depName` STRING,`contactId` INT, " +
"`employer` STRUCT<`id`: INT, `company`: STRUCT<`name`: STRING, `address`: STRING>>"
spark.read.format(dataSourceName).schema(departmentScahem).load(path + "/departments")
.createOrReplaceTempView("departments")

testThunk
}
}
Expand Down