-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-31736][SQL] Nested column aliasing for RepartitionByExpression/Join #28556
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2c95e81
b77a1ba
db601df
f720bdf
719a2ad
ce5d8dc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,7 +34,8 @@ object NestedColumnAliasing { | |
| : Option[(Map[ExtractValue, Alias], Map[ExprId, Seq[Alias]])] = plan match { | ||
| case Project(projectList, child) | ||
| if SQLConf.get.nestedSchemaPruningEnabled && canProjectPushThrough(child) => | ||
| getAliasSubMap(projectList) | ||
| val exprCandidatesToPrune = projectList ++ child.expressions | ||
| getAliasSubMap(exprCandidatesToPrune, child.producedAttributes.toSeq) | ||
|
|
||
| case plan if SQLConf.get.nestedSchemaPruningEnabled && canPruneOn(plan) => | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No big deal but I would rename
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will change it in other PR. Thanks. |
||
| val exprCandidatesToPrune = plan.expressions | ||
|
|
@@ -53,11 +54,11 @@ object NestedColumnAliasing { | |
| case Project(projectList, child) => | ||
| Project( | ||
| getNewProjectList(projectList, nestedFieldToAlias), | ||
| replaceChildrenWithAliases(child, nestedFieldToAlias, attrToAliases)) | ||
| replaceWithAliases(child, nestedFieldToAlias, attrToAliases)) | ||
|
|
||
| // The operators reaching here was already guarded by `canPruneOn`. | ||
| case other => | ||
| replaceChildrenWithAliases(other, nestedFieldToAlias, attrToAliases) | ||
| replaceWithAliases(other, nestedFieldToAlias, attrToAliases) | ||
| } | ||
|
|
||
| /** | ||
|
|
@@ -73,9 +74,10 @@ object NestedColumnAliasing { | |
| } | ||
|
|
||
| /** | ||
| * Return a plan with new children replaced with aliases. | ||
| * Return a plan with new children replaced with aliases, and expressions replaced with | ||
| * aliased attributes. | ||
| */ | ||
| def replaceChildrenWithAliases( | ||
| def replaceWithAliases( | ||
| plan: LogicalPlan, | ||
| nestedFieldToAlias: Map[ExtractValue, Alias], | ||
| attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = { | ||
|
|
@@ -104,6 +106,8 @@ object NestedColumnAliasing { | |
| case _: LocalLimit => true | ||
| case _: Repartition => true | ||
| case _: Sample => true | ||
| case _: RepartitionByExpression => true | ||
| case _: Join => true | ||
HyukjinKwon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| case _ => false | ||
| } | ||
|
|
||
|
|
@@ -202,7 +206,9 @@ object GeneratorNestedColumnAliasing { | |
| val exprsToPrune = projectList ++ g.generator.children | ||
| NestedColumnAliasing.getAliasSubMap(exprsToPrune, g.qualifiedGeneratorOutput).map { | ||
| case (nestedFieldToAlias, attrToAliases) => | ||
| val newChild = pruneGenerate(g, nestedFieldToAlias, attrToAliases) | ||
| // Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`. | ||
| val newChild = | ||
| NestedColumnAliasing.replaceWithAliases(g, nestedFieldToAlias, attrToAliases) | ||
| Project(NestedColumnAliasing.getNewProjectList(projectList, nestedFieldToAlias), newChild) | ||
| } | ||
|
|
||
|
|
@@ -215,21 +221,14 @@ object GeneratorNestedColumnAliasing { | |
| NestedColumnAliasing.getAliasSubMap( | ||
| g.generator.children, g.requiredChildOutput).map { | ||
| case (nestedFieldToAlias, attrToAliases) => | ||
| pruneGenerate(g, nestedFieldToAlias, attrToAliases) | ||
| // Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`. | ||
| NestedColumnAliasing.replaceWithAliases(g, nestedFieldToAlias, attrToAliases) | ||
| } | ||
|
|
||
| case _ => | ||
| None | ||
| } | ||
|
|
||
| private def pruneGenerate( | ||
| g: Generate, | ||
| nestedFieldToAlias: Map[ExtractValue, Alias], | ||
| attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = { | ||
| // Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`. | ||
| NestedColumnAliasing.replaceChildrenWithAliases(g, nestedFieldToAlias, attrToAliases) | ||
| } | ||
|
|
||
HyukjinKwon marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| /** | ||
| * This is a while-list for pruning nested fields at `Generator`. | ||
| */ | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -144,7 +144,6 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { | |
| test("Pushing a single nested field projection - negative") { | ||
| val ops = Seq( | ||
| (input: LogicalPlan) => input.distribute('name)(1), | ||
| (input: LogicalPlan) => input.distribute($"name.middle")(1), | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, looks nice. This PR could support this case. |
||
| (input: LogicalPlan) => input.orderBy('name.asc), | ||
| (input: LogicalPlan) => input.orderBy($"name.middle".asc), | ||
| (input: LogicalPlan) => input.sortBy('name.asc), | ||
|
|
@@ -342,6 +341,89 @@ class NestedColumnAliasingSuite extends SchemaPruningTest { | |
| comparePlans(optimized, expected) | ||
| } | ||
|
|
||
| test("Nested field pruning through RepartitionByExpression") { | ||
| val query1 = contact | ||
| .distribute($"id")(1) | ||
| .select($"name.middle") | ||
| .analyze | ||
| val optimized1 = Optimize.execute(query1) | ||
|
|
||
| val aliases1 = collectGeneratedAliases(optimized1) | ||
|
|
||
| val expected1 = contact | ||
| .select('id, 'name.getField("middle").as(aliases1(0))) | ||
| .distribute($"id")(1) | ||
| .select($"${aliases1(0)}".as("middle")) | ||
| .analyze | ||
| comparePlans(optimized1, expected1) | ||
|
|
||
| val query2 = contact | ||
| .distribute($"name.middle")(1) | ||
| .select($"name.middle") | ||
| .analyze | ||
| val optimized2 = Optimize.execute(query2) | ||
|
|
||
| val aliases2 = collectGeneratedAliases(optimized2) | ||
|
|
||
| val expected2 = contact | ||
| .select('name.getField("middle").as(aliases2(0))) | ||
| .distribute($"${aliases2(0)}")(1) | ||
| .select($"${aliases2(0)}".as("middle")) | ||
| .analyze | ||
| comparePlans(optimized2, expected2) | ||
|
|
||
| val query3 = contact | ||
| .select($"name") | ||
| .distribute($"name")(1) | ||
| .select($"name.middle") | ||
| .analyze | ||
| val optimized3 = Optimize.execute(query3) | ||
|
|
||
| comparePlans(optimized3, query3) | ||
| } | ||
|
|
||
| test("Nested field pruning through Join") { | ||
| val department = LocalRelation( | ||
| 'depID.int, | ||
| 'personID.string) | ||
|
|
||
| val query1 = contact.join(department, condition = Some($"id" === $"depID")) | ||
| .select($"name.middle") | ||
| .analyze | ||
| val optimized1 = Optimize.execute(query1) | ||
|
|
||
| val aliases1 = collectGeneratedAliases(optimized1) | ||
|
|
||
| val expected1 = contact.select('id, 'name.getField("middle").as(aliases1(0))) | ||
| .join(department.select('depID), condition = Some($"id" === $"depID")) | ||
| .select($"${aliases1(0)}".as("middle")) | ||
| .analyze | ||
| comparePlans(optimized1, expected1) | ||
|
|
||
| val query2 = contact.join(department, condition = Some($"name.middle" === $"personID")) | ||
| .select($"name.first") | ||
| .analyze | ||
| val optimized2 = Optimize.execute(query2) | ||
|
|
||
| val aliases2 = collectGeneratedAliases(optimized2) | ||
|
|
||
| val expected2 = contact.select( | ||
| 'name.getField("first").as(aliases2(0)), | ||
| 'name.getField("middle").as(aliases2(1))) | ||
| .join(department.select('personID), condition = Some($"${aliases2(1)}" === $"personID")) | ||
| .select($"${aliases2(0)}".as("first")) | ||
| .analyze | ||
| comparePlans(optimized2, expected2) | ||
|
|
||
| val contact2 = LocalRelation('name2.struct(name)) | ||
| val query3 = contact.select('name) | ||
| .join(contact2, condition = Some($"name" === $"name2")) | ||
| .select($"name.first") | ||
| .analyze | ||
| val optimized3 = Optimize.execute(query3) | ||
| comparePlans(optimized3, query3) | ||
| } | ||
|
|
||
| test("Nested field pruning for Aggregate") { | ||
| def runTest(basePlan: LogicalPlan => LogicalPlan): Unit = { | ||
| val query1 = basePlan(contact).groupBy($"id")(first($"name.first").as("first")).analyze | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -51,13 +51,19 @@ abstract class SchemaPruningSuite | |
| relatives: Map[String, FullName] = Map.empty, | ||
| employer: Employer = null, | ||
| relations: Map[FullName, String] = Map.empty) | ||
| case class Department( | ||
| depId: Int, | ||
| depName: String, | ||
| contactId: Int, | ||
| employer: Employer) | ||
|
|
||
| val janeDoe = FullName("Jane", "X.", "Doe") | ||
| val johnDoe = FullName("John", "Y.", "Doe") | ||
| val susanSmith = FullName("Susan", "Z.", "Smith") | ||
|
|
||
| val employer = Employer(0, Company("abc", "123 Business Street")) | ||
| val employerWithNullCompany = Employer(1, null) | ||
| val employerWithNullCompany2 = Employer(2, null) | ||
|
|
||
| val contacts = | ||
| Contact(0, janeDoe, "123 Main Street", 1, friends = Array(susanSmith), | ||
|
|
@@ -66,6 +72,11 @@ abstract class SchemaPruningSuite | |
| Contact(1, johnDoe, "321 Wall Street", 3, relatives = Map("sister" -> janeDoe), | ||
| employer = employerWithNullCompany, relations = Map(janeDoe -> "sister")) :: Nil | ||
|
|
||
| val departments = | ||
| Department(0, "Engineering", 0, employer) :: | ||
| Department(1, "Marketing", 1, employerWithNullCompany) :: | ||
| Department(2, "Operation", 4, employerWithNullCompany2) :: Nil | ||
|
|
||
| case class Name(first: String, last: String) | ||
| case class BriefContact(id: Int, name: Name, address: String) | ||
|
|
||
|
|
@@ -350,6 +361,83 @@ abstract class SchemaPruningSuite | |
| checkAnswer(query, Row(0) :: Nil) | ||
| } | ||
|
|
||
| testSchemaPruning("select one deep nested complex field after repartition by expression") { | ||
| val query1 = sql("select * from contacts") | ||
| .repartition(100, col("id")) | ||
| .where("employer.company.address is not null") | ||
| .selectExpr("employer.id as employer_id") | ||
| checkScan(query1, | ||
| "struct<id:int,employer:struct<id:int,company:struct<address:string>>>") | ||
| checkAnswer(query1, Row(0) :: Nil) | ||
|
|
||
| val query2 = sql("select * from contacts") | ||
| .repartition(100, col("employer")) | ||
| .where("employer.company.address is not null") | ||
| .selectExpr("employer.id as employer_id") | ||
| checkScan(query2, | ||
| "struct<employer:struct<id:int,company:struct<name:string,address:string>>>") | ||
| checkAnswer(query2, Row(0) :: Nil) | ||
|
|
||
| val query3 = sql("select * from contacts") | ||
| .repartition(100, col("employer.company")) | ||
| .where("employer.company.address is not null") | ||
| .selectExpr("employer.company as employer_company") | ||
| checkScan(query3, | ||
| "struct<employer:struct<company:struct<name:string,address:string>>>") | ||
| checkAnswer(query3, Row(Row("abc", "123 Business Street")) :: Nil) | ||
|
|
||
| val query4 = sql("select * from contacts") | ||
| .repartition(100, col("employer.company.address")) | ||
| .where("employer.company.address is not null") | ||
| .selectExpr("employer.company.address as employer_company_addr") | ||
| checkScan(query4, | ||
| "struct<employer:struct<company:struct<address:string>>>") | ||
| checkAnswer(query4, Row("123 Business Street") :: Nil) | ||
| } | ||
|
|
||
| testSchemaPruning("select one deep nested complex field after join") { | ||
| val query1 = sql("select contacts.name.middle from contacts, departments where " + | ||
| "contacts.id = departments.contactId") | ||
| checkScan(query1, | ||
| "struct<id:int,name:struct<middle:string>>", | ||
| "struct<contactId:int>") | ||
| checkAnswer(query1, Row("X.") :: Row("Y.") :: Nil) | ||
|
|
||
| val query2 = sql("select contacts.name.middle from contacts, departments where " + | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: I think its better to use uppercases for SQL keywords where possible.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems all tests in this test suite are using lowercases. Changing all tests seems too bothering... :) |
||
| "contacts.employer = departments.employer") | ||
| checkScan(query2, | ||
| "struct<name:struct<middle:string>," + | ||
| "employer:struct<id:int,company:struct<name:string,address:string>>>", | ||
| "struct<employer:struct<id:int,company:struct<name:string,address:string>>>") | ||
| checkAnswer(query2, Row("X.") :: Row("Y.") :: Nil) | ||
|
|
||
| val query3 = sql("select contacts.employer.company.name from contacts, departments where " + | ||
| "contacts.employer = departments.employer") | ||
| checkScan(query3, | ||
| "struct<employer:struct<id:int,company:struct<name:string,address:string>>>", | ||
| "struct<employer:struct<id:int,company:struct<name:string,address:string>>>") | ||
| checkAnswer(query3, Row("abc") :: Row(null) :: Nil) | ||
| } | ||
|
|
||
| testSchemaPruning("select one deep nested complex field after outer join") { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for adding the tests. |
||
| val query1 = sql("select departments.contactId, contacts.name.middle from departments " + | ||
| "left outer join contacts on departments.contactId = contacts.id") | ||
| checkScan(query1, | ||
| "struct<contactId:int>", | ||
| "struct<id:int,name:struct<middle:string>>") | ||
| checkAnswer(query1, Row(0, "X.") :: Row(1, "Y.") :: Row(4, null) :: Nil) | ||
|
|
||
| val query2 = sql("select contacts.name.first, departments.employer.company.name " + | ||
| "from contacts right outer join departments on contacts.id = departments.contactId") | ||
| checkScan(query2, | ||
| "struct<id:int,name:struct<first:string>>", | ||
| "struct<contactId:int,employer:struct<company:struct<name:string>>>") | ||
| checkAnswer(query2, | ||
| Row("Jane", "abc") :: | ||
| Row("John", null) :: | ||
| Row(null, null) :: Nil) | ||
| } | ||
|
|
||
| testSchemaPruning("select nested field in aggregation function of Aggregate") { | ||
| val query1 = sql("select count(name.first) from contacts group by name.last") | ||
| checkScan(query1, "struct<name:struct<first:string,last:string>>") | ||
|
|
@@ -439,6 +527,7 @@ abstract class SchemaPruningSuite | |
|
|
||
| makeDataSourceFile(contacts, new File(path + "/contacts/p=1")) | ||
| makeDataSourceFile(briefContacts, new File(path + "/contacts/p=2")) | ||
| makeDataSourceFile(departments, new File(path + "/departments")) | ||
|
|
||
| // Providing user specified schema. Inferred schema from different data sources might | ||
| // be different. | ||
|
|
@@ -451,6 +540,11 @@ abstract class SchemaPruningSuite | |
| spark.read.format(dataSourceName).schema(schema).load(path + "/contacts") | ||
| .createOrReplaceTempView("contacts") | ||
|
|
||
| val departmentScahem = "`depId` INT,`depName` STRING,`contactId` INT, " + | ||
| "`employer` STRUCT<`id`: INT, `company`: STRUCT<`name`: STRING, `address`: STRING>>" | ||
| spark.read.format(dataSourceName).schema(departmentScahem).load(path + "/departments") | ||
| .createOrReplaceTempView("departments") | ||
|
|
||
| testThunk | ||
| } | ||
| } | ||
|
|
@@ -461,6 +555,7 @@ abstract class SchemaPruningSuite | |
|
|
||
| makeDataSourceFile(contactsWithDataPartitionColumn, new File(path + "/contacts/p=1")) | ||
| makeDataSourceFile(briefContactsWithDataPartitionColumn, new File(path + "/contacts/p=2")) | ||
| makeDataSourceFile(departments, new File(path + "/departments")) | ||
|
|
||
| // Providing user specified schema. Inferred schema from different data sources might | ||
| // be different. | ||
|
|
@@ -473,6 +568,11 @@ abstract class SchemaPruningSuite | |
| spark.read.format(dataSourceName).schema(schema).load(path + "/contacts") | ||
| .createOrReplaceTempView("contacts") | ||
|
|
||
| val departmentScahem = "`depId` INT,`depName` STRING,`contactId` INT, " + | ||
| "`employer` STRUCT<`id`: INT, `company`: STRUCT<`name`: STRING, `address`: STRING>>" | ||
| spark.read.format(dataSourceName).schema(departmentScahem).load(path + "/departments") | ||
| .createOrReplaceTempView("departments") | ||
|
|
||
| testThunk | ||
| } | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@viirya, just to clarify, you added
producedAttributeshere just to be safe but not related to the current changes (?). SeemsJoinandRepartitionByExpressionhave an emptyproducedAttributes.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, if it's going to output, it shouldn't be pruned anyway.