diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
index c38a1189387d7..eb8022c8c8404 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
@@ -34,7 +34,8 @@ object NestedColumnAliasing {
     : Option[(Map[ExtractValue, Alias], Map[ExprId, Seq[Alias]])] = plan match {
     case Project(projectList, child)
         if SQLConf.get.nestedSchemaPruningEnabled && canProjectPushThrough(child) =>
-      getAliasSubMap(projectList)
+      val exprCandidatesToPrune = projectList ++ child.expressions
+      getAliasSubMap(exprCandidatesToPrune, child.producedAttributes.toSeq)

     case plan if SQLConf.get.nestedSchemaPruningEnabled && canPruneOn(plan) =>
       val exprCandidatesToPrune = plan.expressions
@@ -53,11 +54,11 @@ object NestedColumnAliasing {
     case Project(projectList, child) =>
       Project(
         getNewProjectList(projectList, nestedFieldToAlias),
-        replaceChildrenWithAliases(child, nestedFieldToAlias, attrToAliases))
+        replaceWithAliases(child, nestedFieldToAlias, attrToAliases))

     // The operators reaching here were already guarded by `canPruneOn`.
     case other =>
-      replaceChildrenWithAliases(other, nestedFieldToAlias, attrToAliases)
+      replaceWithAliases(other, nestedFieldToAlias, attrToAliases)
   }

   /**
@@ -73,9 +74,10 @@ object NestedColumnAliasing {
   }

   /**
-   * Return a plan with new children replaced with aliases.
+   * Return a plan with new children replaced with aliases, and expressions replaced with
+   * aliased attributes.
    */
-  def replaceChildrenWithAliases(
+  def replaceWithAliases(
       plan: LogicalPlan,
       nestedFieldToAlias: Map[ExtractValue, Alias],
       attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = {
@@ -104,6 +106,8 @@ object NestedColumnAliasing {
     case _: LocalLimit => true
     case _: Repartition => true
     case _: Sample => true
+    case _: RepartitionByExpression => true
+    case _: Join => true
     case _ => false
   }

@@ -202,7 +206,9 @@ object GeneratorNestedColumnAliasing {
       val exprsToPrune = projectList ++ g.generator.children
       NestedColumnAliasing.getAliasSubMap(exprsToPrune, g.qualifiedGeneratorOutput).map {
         case (nestedFieldToAlias, attrToAliases) =>
-          val newChild = pruneGenerate(g, nestedFieldToAlias, attrToAliases)
+          // Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`.
+          val newChild =
+            NestedColumnAliasing.replaceWithAliases(g, nestedFieldToAlias, attrToAliases)
           Project(NestedColumnAliasing.getNewProjectList(projectList, nestedFieldToAlias),
             newChild)
       }
@@ -215,21 +221,14 @@ object GeneratorNestedColumnAliasing {
       NestedColumnAliasing.getAliasSubMap(
         g.generator.children, g.requiredChildOutput).map {
         case (nestedFieldToAlias, attrToAliases) =>
-          pruneGenerate(g, nestedFieldToAlias, attrToAliases)
+          // Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`.
+          NestedColumnAliasing.replaceWithAliases(g, nestedFieldToAlias, attrToAliases)
       }

     case _ => None
   }

-  private def pruneGenerate(
-      g: Generate,
-      nestedFieldToAlias: Map[ExtractValue, Alias],
-      attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = {
-    // Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`.
-    NestedColumnAliasing.replaceChildrenWithAliases(g, nestedFieldToAlias, attrToAliases)
-  }
-
   /**
    * This is a white list for pruning nested fields at `Generator`.
    */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
index 30fdcf17f8d60..7b1735a6f04ab 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
@@ -144,7 +144,6 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
   test("Pushing a single nested field projection - negative") {
     val ops = Seq(
       (input: LogicalPlan) => input.distribute('name)(1),
-      (input: LogicalPlan) => input.distribute($"name.middle")(1),
       (input: LogicalPlan) => input.orderBy('name.asc),
       (input: LogicalPlan) => input.orderBy($"name.middle".asc),
       (input: LogicalPlan) => input.sortBy('name.asc),
@@ -342,6 +341,89 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
     comparePlans(optimized, expected)
   }

+  test("Nested field pruning through RepartitionByExpression") {
+    val query1 = contact
+      .distribute($"id")(1)
+      .select($"name.middle")
+      .analyze
+    val optimized1 = Optimize.execute(query1)
+
+    val aliases1 = collectGeneratedAliases(optimized1)
+
+    val expected1 = contact
+      .select('id, 'name.getField("middle").as(aliases1(0)))
+      .distribute($"id")(1)
+      .select($"${aliases1(0)}".as("middle"))
+      .analyze
+    comparePlans(optimized1, expected1)
+
+    val query2 = contact
+      .distribute($"name.middle")(1)
+      .select($"name.middle")
+      .analyze
+    val optimized2 = Optimize.execute(query2)
+
+    val aliases2 = collectGeneratedAliases(optimized2)
+
+    val expected2 = contact
+      .select('name.getField("middle").as(aliases2(0)))
+      .distribute($"${aliases2(0)}")(1)
+      .select($"${aliases2(0)}".as("middle"))
+      .analyze
+    comparePlans(optimized2, expected2)
+
+    val query3 = contact
+      .select($"name")
+      .distribute($"name")(1)
+      .select($"name.middle")
+      .analyze
+    val optimized3 = Optimize.execute(query3)
+
+    comparePlans(optimized3, query3)
+  }
+
+  test("Nested field pruning through Join") {
+    val department = LocalRelation(
+      'depID.int,
+      'personID.string)
+
+    val query1 = contact.join(department, condition = Some($"id" === $"depID"))
+      .select($"name.middle")
+      .analyze
+    val optimized1 = Optimize.execute(query1)
+
+    val aliases1 = collectGeneratedAliases(optimized1)
+
+    val expected1 = contact.select('id, 'name.getField("middle").as(aliases1(0)))
+      .join(department.select('depID), condition = Some($"id" === $"depID"))
+      .select($"${aliases1(0)}".as("middle"))
+      .analyze
+    comparePlans(optimized1, expected1)
+
+    val query2 = contact.join(department, condition = Some($"name.middle" === $"personID"))
+      .select($"name.first")
+      .analyze
+    val optimized2 = Optimize.execute(query2)
+
+    val aliases2 = collectGeneratedAliases(optimized2)
+
+    val expected2 = contact.select(
+        'name.getField("first").as(aliases2(0)),
+        'name.getField("middle").as(aliases2(1)))
+      .join(department.select('personID), condition = Some($"${aliases2(1)}" === $"personID"))
+      .select($"${aliases2(0)}".as("first"))
+      .analyze
+    comparePlans(optimized2, expected2)
+
+    val contact2 = LocalRelation('name2.struct(name))
+    val query3 = contact.select('name)
+      .join(contact2, condition = Some($"name" === $"name2"))
+      .select($"name.first")
+      .analyze
+    val optimized3 = Optimize.execute(query3)
+    comparePlans(optimized3, query3)
+  }
+
   test("Nested field pruning for Aggregate") {
     def runTest(basePlan: LogicalPlan => LogicalPlan): Unit = {
       val query1 =
         basePlan(contact).groupBy($"id")(first($"name.first").as("first")).analyze
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index 2f9e510752b02..8b859e951b9b9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -51,6 +51,11 @@ abstract class SchemaPruningSuite
     relatives: Map[String, FullName] = Map.empty,
     employer: Employer = null,
     relations: Map[FullName, String] = Map.empty)
+  case class Department(
+    depId: Int,
+    depName: String,
+    contactId: Int,
+    employer: Employer)

   val janeDoe = FullName("Jane", "X.", "Doe")
   val johnDoe = FullName("John", "Y.", "Doe")
@@ -58,6 +63,7 @@ abstract class SchemaPruningSuite

   val employer = Employer(0, Company("abc", "123 Business Street"))
   val employerWithNullCompany = Employer(1, null)
+  val employerWithNullCompany2 = Employer(2, null)

   val contacts =
     Contact(0, janeDoe, "123 Main Street", 1, friends = Array(susanSmith),
@@ -66,6 +72,11 @@ abstract class SchemaPruningSuite
     Contact(1, johnDoe, "321 Wall Street", 3, relatives = Map("sister" -> janeDoe),
       employer = employerWithNullCompany, relations = Map(janeDoe -> "sister")) :: Nil

+  val departments =
+    Department(0, "Engineering", 0, employer) ::
+    Department(1, "Marketing", 1, employerWithNullCompany) ::
+    Department(2, "Operation", 4, employerWithNullCompany2) :: Nil
+
   case class Name(first: String, last: String)
   case class BriefContact(id: Int, name: Name, address: String)

@@ -350,6 +361,83 @@ abstract class SchemaPruningSuite
     checkAnswer(query, Row(0) :: Nil)
   }

+  testSchemaPruning("select one deep nested complex field after repartition by expression") {
+    val query1 = sql("select * from contacts")
+      .repartition(100, col("id"))
+      .where("employer.company.address is not null")
+      .selectExpr("employer.id as employer_id")
+    checkScan(query1,
+      "struct<id:int,employer:struct<id:int,company:struct<address:string>>>")
+    checkAnswer(query1, Row(0) :: Nil)
+
+    val query2 = sql("select * from contacts")
+      .repartition(100, col("employer"))
+      .where("employer.company.address is not null")
+      .selectExpr("employer.id as employer_id")
+    checkScan(query2,
+      "struct<employer:struct<id:int,company:struct<name:string,address:string>>>")
+    checkAnswer(query2, Row(0) :: Nil)
+
+    val query3 = sql("select * from contacts")
+      .repartition(100, col("employer.company"))
+      .where("employer.company.address is not null")
+      .selectExpr("employer.company as employer_company")
+    checkScan(query3,
+      "struct<employer:struct<company:struct<name:string,address:string>>>")
+    checkAnswer(query3, Row(Row("abc", "123 Business Street")) :: Nil)
+
+    val query4 = sql("select * from contacts")
+      .repartition(100, col("employer.company.address"))
+      .where("employer.company.address is not null")
+      .selectExpr("employer.company.address as employer_company_addr")
+    checkScan(query4,
+      "struct<employer:struct<company:struct<address:string>>>")
+    checkAnswer(query4, Row("123 Business Street") :: Nil)
+  }
+
+  testSchemaPruning("select one deep nested complex field after join") {
+    val query1 = sql("select contacts.name.middle from contacts, departments where " +
+      "contacts.id = departments.contactId")
+    checkScan(query1,
+      "struct<id:int,name:struct<middle:string>>",
+      "struct<contactId:int>")
+    checkAnswer(query1, Row("X.") :: Row("Y.") :: Nil)
+
+    val query2 = sql("select contacts.name.middle from contacts, departments where " +
+      "contacts.employer = departments.employer")
+    checkScan(query2,
+      "struct<name:struct<middle:string>," +
+        "employer:struct<id:int,company:struct<name:string,address:string>>>",
+      "struct<employer:struct<id:int,company:struct<name:string,address:string>>>")
+    checkAnswer(query2, Row("X.") :: Row("Y.") :: Nil)
+
+    val query3 =
sql("select contacts.employer.company.name from contacts, departments where " + + "contacts.employer = departments.employer") + checkScan(query3, + "struct>>", + "struct>>") + checkAnswer(query3, Row("abc") :: Row(null) :: Nil) + } + + testSchemaPruning("select one deep nested complex field after outer join") { + val query1 = sql("select departments.contactId, contacts.name.middle from departments " + + "left outer join contacts on departments.contactId = contacts.id") + checkScan(query1, + "struct", + "struct>") + checkAnswer(query1, Row(0, "X.") :: Row(1, "Y.") :: Row(4, null) :: Nil) + + val query2 = sql("select contacts.name.first, departments.employer.company.name " + + "from contacts right outer join departments on contacts.id = departments.contactId") + checkScan(query2, + "struct>", + "struct>>") + checkAnswer(query2, + Row("Jane", "abc") :: + Row("John", null) :: + Row(null, null) :: Nil) + } + testSchemaPruning("select nested field in aggregation function of Aggregate") { val query1 = sql("select count(name.first) from contacts group by name.last") checkScan(query1, "struct>") @@ -439,6 +527,7 @@ abstract class SchemaPruningSuite makeDataSourceFile(contacts, new File(path + "/contacts/p=1")) makeDataSourceFile(briefContacts, new File(path + "/contacts/p=2")) + makeDataSourceFile(departments, new File(path + "/departments")) // Providing user specified schema. Inferred schema from different data sources might // be different. @@ -451,6 +540,11 @@ abstract class SchemaPruningSuite spark.read.format(dataSourceName).schema(schema).load(path + "/contacts") .createOrReplaceTempView("contacts") + val departmentScahem = "`depId` INT,`depName` STRING,`contactId` INT, " + + "`employer` STRUCT<`id`: INT, `company`: STRUCT<`name`: STRING, `address`: STRING>>" + spark.read.format(dataSourceName).schema(departmentScahem).load(path + "/departments") + .createOrReplaceTempView("departments") + testThunk } } @@ -461,6 +555,7 @@ abstract class SchemaPruningSuite makeDataSourceFile(contactsWithDataPartitionColumn, new File(path + "/contacts/p=1")) makeDataSourceFile(briefContactsWithDataPartitionColumn, new File(path + "/contacts/p=2")) + makeDataSourceFile(departments, new File(path + "/departments")) // Providing user specified schema. Inferred schema from different data sources might // be different. @@ -473,6 +568,11 @@ abstract class SchemaPruningSuite spark.read.format(dataSourceName).schema(schema).load(path + "/contacts") .createOrReplaceTempView("contacts") + val departmentScahem = "`depId` INT,`depName` STRING,`contactId` INT, " + + "`employer` STRUCT<`id`: INT, `company`: STRUCT<`name`: STRING, `address`: STRING>>" + spark.read.format(dataSourceName).schema(departmentScahem).load(path + "/departments") + .createOrReplaceTempView("departments") + testThunk } }