diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/AggregateFieldExtractionPushdown.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/AggregateFieldExtractionPushdown.scala new file mode 100644 index 000000000000..0e883072f509 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/AggregateFieldExtractionPushdown.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, NamedExpression} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} + +/** + * Pushes down aliases to [[expressions.GetStructField]] expressions in an aggregate's grouping and + * aggregate expressions into a projection over its children. The original + * [[expressions.GetStructField]] expressions are replaced with references to the pushed down + * aliases. This allows the optimizer to minimize the columns read from a column-oriented file + * format for aggregation queries involving only nested fields. + */ +object AggregateFieldExtractionPushdown extends FieldExtractionPushdown { + override protected def apply0(plan: LogicalPlan): LogicalPlan = + plan transformDown { + case agg @ Aggregate(groupingExpressions, aggregateExpressions, child) => + val expressions = groupingExpressions ++ aggregateExpressions + val attributes = AttributeSet(expressions.collect { case att: Attribute => att }) + val childAttributes = AttributeSet(child.expressions) + val fieldExtractors0 = + expressions + .flatMap(getFieldExtractors) + .distinct + val fieldExtractors1 = + fieldExtractors0 + .filter(_.collectFirst { case att: Attribute => att } + .filter(attributes.contains).isEmpty) + val fieldExtractors = + fieldExtractors1 + .filter(_.collectFirst { case att: Attribute => att } + .filter(childAttributes.contains).nonEmpty) + + if (fieldExtractors.nonEmpty) { + val (aliases, substituteAttributes) = constructAliasesAndSubstitutions(fieldExtractors) + + if (aliases.nonEmpty) { + // Construct the new grouping and aggregate expressions by substituting + // each GetStructField expression with a reference to its alias + val newAggregateExpressions = + aggregateExpressions.map(substituteAttributes) + .collect { case named: NamedExpression => named } + val newGroupingExpressions = groupingExpressions.map(substituteAttributes) + + // We need to push down the aliases we've created. 
We do this with a new projection over + // this aggregate's child consisting of the aliases and original child's output sans + // attributes referenced by the aliases + + // None of these attributes are required by this aggregate because we filtered out the + // GetStructField instances which referred to attributes that were required + val unnecessaryAttributes = aliases.map(_.child.references).reduce(_ ++ _) + // The output we require from this aggregate is the child's output minus the unnecessary + // attributes + val requiredChildOutput = child.output.filterNot(unnecessaryAttributes.contains) + val projects = requiredChildOutput ++ aliases + val newProject = Project(projects, child) + + Aggregate(newGroupingExpressions, newAggregateExpressions, newProject) + } else { + agg + } + } else { + agg + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/FieldExtractionPushdown.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/FieldExtractionPushdown.scala new file mode 100644 index 000000000000..725ae1414dcb --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/FieldExtractionPushdown.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, GetStructField} +import org.apache.spark.sql.catalyst.planning.SelectedField +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf + +abstract class FieldExtractionPushdown extends Rule[LogicalPlan] { + final override def apply(plan: LogicalPlan): LogicalPlan = + if (SQLConf.get.nestedSchemaPruningEnabled) { + apply0(plan) + } else { + plan + } + + protected def apply0(plan: LogicalPlan): LogicalPlan + + // Gets the top-level GetStructField expressions from the given expression + // and its children. This does not return children of a GetStructField. + protected final def getFieldExtractors(expr: Expression): Seq[GetStructField] = + expr match { + // Check that getField matches SelectedField(_) to ensure that getField defines a chain of + // extractors down to an attribute. + case getField: GetStructField if SelectedField.unapply(getField).isDefined => + getField :: Nil + case _ => + expr.children.flatMap(getFieldExtractors) + } + + // Constructs aliases and a substitution function for the given sequence of + // GetStructField expressions. 
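+  // Returns the aliases together with a substitution function that replaces each of the given
+  // extractors with a reference to its alias.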
+ protected final def constructAliasesAndSubstitutions(fieldExtractors: Seq[GetStructField]) = { + val aliases = + fieldExtractors.map(extractor => + Alias(extractor, extractor.childSchema(extractor.ordinal).name)()) + + val attributes = aliases.map(alias => (alias.child, alias.toAttribute)).toMap + + val substituteAttributes: Expression => Expression = _.transformDown { + case expr: GetStructField => attributes.getOrElse(expr, expr) + } + + (aliases, substituteAttributes) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/JoinFieldExtractionPushdown.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/JoinFieldExtractionPushdown.scala new file mode 100644 index 000000000000..e7196a77582c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/JoinFieldExtractionPushdown.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, NamedExpression} +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Project} + +/** + * Pushes down aliases to [[expressions.GetStructField]] expressions in a projection over a join + * and its join condition. The original [[expressions.GetStructField]] expressions are replaced + * with references to the pushed down aliases. This allows the optimizer to minimize the columns + * read from a column-oriented file format for joins involving only nested fields. 
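+ *
+ * For illustration only (the relation and field names below are hypothetical, not taken from
+ * this patch), a plan such as
+ *
+ * {{{
+ *   Project [name.first]
+ *   +- Join Inner, (name.first = id)
+ *      :- Relation [name, address]
+ *      +- Relation [id]
+ * }}}
+ *
+ * is rewritten roughly into
+ *
+ * {{{
+ *   Project [first]
+ *   +- Join Inner, (first = id)
+ *      :- Project [name.first AS first]
+ *      :  +- Relation [name, address]
+ *      +- Project [id]
+ *         +- Relation [id]
+ * }}}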
+ */ +object JoinFieldExtractionPushdown extends FieldExtractionPushdown { + override protected def apply0(plan: LogicalPlan): LogicalPlan = + plan transformDown { + case op @ PhysicalOperation(projects, Seq(), + join @ Join(left, right, joinType, Some(joinCondition))) => + val fieldExtractors = (projects :+ joinCondition).flatMap(getFieldExtractors).distinct + + if (fieldExtractors.nonEmpty) { + val (aliases, substituteAttributes) = constructAliasesAndSubstitutions(fieldExtractors) + + if (aliases.nonEmpty) { + // Construct the new projections and join condition by substituting each GetStructField + // expression with a reference to its alias + val newProjects = + projects.map(substituteAttributes).collect { case named: NamedExpression => named } + val newJoinCondition = substituteAttributes(joinCondition) + + // Prune left and right output attributes according to whether they're needed by the + // new projections or join conditions + val aliasAttributes = AttributeSet(aliases.map(_.toAttribute)) + val neededAttributes = AttributeSet((newProjects :+ newJoinCondition) + .flatMap(_.collect { case att: Attribute => att })) -- aliasAttributes + val leftAtts = left.output.filter(neededAttributes.contains) + val rightAtts = right.output.filter(neededAttributes.contains) + + // Construct the left and right side aliases by partitioning the aliases according to + // whether they reference attributes in the left side or the right side + val (leftAliases, rightAliases) = + aliases.partition(_.references.intersect(left.outputSet).nonEmpty) + + val newLeft = Project(leftAtts.toSeq ++ leftAliases, left) + val newRight = Project(rightAtts.toSeq ++ rightAliases, right) + + Project(newProjects, Join(newLeft, newRight, joinType, Some(newJoinCondition))) + } else { + op + } + } else { + op + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 45f13956a0a8..6848b396b1e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -151,6 +151,9 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) // The following batch should be executed after batch "Join Reorder" and "LocalRelation". Batch("Check Cartesian Products", Once, CheckCartesianProducts) :+ + Batch("Field Extraction Pushdown", fixedPoint, + AggregateFieldExtractionPushdown, + JoinFieldExtractionPushdown) :+ Batch("RewriteSubquery", Once, RewritePredicateSubquery, ColumnPruning, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/GetStructFieldObject.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/GetStructFieldObject.scala new file mode 100644 index 000000000000..033792a9ac72 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/GetStructFieldObject.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.planning + +import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField} +import org.apache.spark.sql.types.StructField + +/** + * A Scala extractor that extracts the child expression and struct field from a [[GetStructField]]. + * This is in contrast to the [[GetStructField]] case class extractor which returns the field + * ordinal instead of the field itself. + */ +private[planning] object GetStructFieldObject { + def unapply(getStructField: GetStructField): Option[(Expression, StructField)] = + Some(( + getStructField.child, + getStructField.childSchema(getStructField.ordinal))) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/ProjectionOverSchema.scala new file mode 100644 index 000000000000..e305676ffa8a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/ProjectionOverSchema.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.planning + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +/** + * A Scala extractor that projects an expression over a given schema. Data types, + * field indexes and field counts of complex type extractors and attributes + * are adjusted to fit the schema. All other expressions are left as-is. This + * class is motivated by columnar nested schema pruning. 
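+ *
+ * For example (an illustrative schema, not taken from this patch), given an original relation
+ * schema of `a struct<b: int, c: int>` and a pruned `schema` of `a struct<c: int>`, the extractor
+ * for "a.c", originally `GetStructField(a, ordinal = 1)`, is rewritten to
+ * `GetStructField(a', ordinal = 0)`, where `a'` is a copy of attribute `a` whose data type is
+ * narrowed to `struct<c: int>`.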
+ */ +case class ProjectionOverSchema(schema: StructType) { + private val fieldNames = schema.fieldNames.toSet + + def unapply(expr: Expression): Option[Expression] = getProjection(expr) + + private def getProjection(expr: Expression): Option[Expression] = + expr match { + case a @ AttributeReference(name, _, _, _) if (fieldNames.contains(name)) => + Some(a.copy(dataType = schema(name).dataType)(a.exprId, a.qualifier)) + case GetArrayItem(child, arrayItemOrdinal) => + getProjection(child).map { + case projection => + GetArrayItem(projection, arrayItemOrdinal) + } + case GetArrayStructFields(child, StructField(name, _, _, _), _, numFields, containsNull) => + getProjection(child).map(p => (p, p.dataType)).map { + case (projection, ArrayType(projSchema @ StructType(_), _)) => + GetArrayStructFields(projection, + projSchema(name), projSchema.fieldIndex(name), projSchema.size, containsNull) + } + case GetMapValue(child, key) => + getProjection(child).map { + case projection => + GetMapValue(projection, key) + } + case GetStructFieldObject(child, StructField(name, _, _, _)) => + getProjection(child).map(p => (p, p.dataType)).map { + case (projection, projSchema @ StructType(_)) => + GetStructField(projection, projSchema.fieldIndex(name)) + } + case _ => + None + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/SelectedField.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/SelectedField.scala new file mode 100644 index 000000000000..dc1e00290bed --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/SelectedField.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.planning + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + +/** + * A Scala extractor that builds a [[org.apache.spark.sql.types.StructField]] from a Catalyst + * complex type extractor. For example, consider a relation with the following schema: + * + * {{{ + * root + * |-- name: struct (nullable = true) + * | |-- first: string (nullable = true) + * | |-- last: string (nullable = true) + * }}} + * + * Further, suppose we take the select expression `name.first`. This will parse into an + * `Alias(child, "first")`. 
Ignoring the alias, `child` matches the following pattern: + * + * {{{ + * GetStructFieldObject( + * AttributeReference("name", StructType(_), _, _), + * StructField("first", StringType, _, _)) + * }}} + * + * [[SelectedField]] converts that expression into + * + * {{{ + * StructField("name", StructType(Array(StructField("first", StringType)))) + * }}} + * + * by mapping each complex type extractor to a [[org.apache.spark.sql.types.StructField]] with the + * same name as its child (or "parent" going right to left in the select expression) and a data + * type appropriate to the complex type extractor. In our example, the name of the child expression + * is "name" and its data type is a [[org.apache.spark.sql.types.StructType]] with a single string + * field named "first". + * + * @param expr the top-level complex type extractor + */ +object SelectedField { + def unapply(expr: Expression): Option[StructField] = { + // If this expression is an alias, work on its child instead + val unaliased = expr match { + case Alias(child, _) => child + case expr => expr + } + selectField(unaliased, None) + } + + private def selectField(expr: Expression, fieldOpt: Option[StructField]): Option[StructField] = { + expr match { + // No children. Returns a StructField with the attribute name or None if fieldOpt is None. + case AttributeReference(name, dataType, nullable, metadata) => + fieldOpt.map(field => + StructField(name, wrapStructType(dataType, field), nullable, metadata)) + // Handles case "expr0.field[n]", where "expr0" is of struct type and "expr0.field" is of + // array type. + case GetArrayItem(x @ GetStructFieldObject(child, field @ StructField(name, + dataType, nullable, metadata)), _) => + val childField = fieldOpt.map(field => StructField(name, + wrapStructType(dataType, field), nullable, metadata)).getOrElse(field) + selectField(child, Some(childField)) + // Handles case "expr0.field[n]", where "expr0.field" is of array type. + case GetArrayItem(child, _) => + selectField(child, fieldOpt) + // Handles case "expr0.field.subfield", where "expr0" and "expr0.field" are of array type. + case GetArrayStructFields(child: GetArrayStructFields, + field @ StructField(name, dataType, nullable, metadata), _, _, _) => + val childField = fieldOpt.map(field => StructField(name, + wrapStructType(dataType, field), + nullable, metadata)).getOrElse(field) + selectField(child, Some(childField)) + // Handles case "expr0.field", where "expr0" is of array type. + case GetArrayStructFields(child, + field @ StructField(name, dataType, nullable, metadata), _, _, containsNull) => + val childField = + fieldOpt.map(field => StructField(name, + wrapStructType(dataType, field), + nullable, metadata)).getOrElse(field) + selectField(child, Some(childField)) + // Handles case "expr0.field[key]", where "expr0" is of struct type and "expr0.field" is of + // map type. + case GetMapValue(x @ GetStructFieldObject(child, field @ StructField(name, + dataType, + nullable, metadata)), _) => + val childField = fieldOpt.map(field => StructField(name, + wrapStructType(dataType, field), + nullable, metadata)).getOrElse(field) + selectField(child, Some(childField)) + // Handles case "expr0.field[key]", where "expr0.field" is of map type. + case GetMapValue(child, _) => + selectField(child, fieldOpt) + // Handles case "expr0.field", where expr0 is of struct type. 
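+      // For example, the select expression `name.first` from the scaladoc above reaches this case
+      // as GetStructFieldObject(AttributeReference("name", ...), StructField("first", StringType, ...)).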
+ case GetStructFieldObject(child, + field @ StructField(name, dataType, nullable, metadata)) => + val childField = fieldOpt.map(field => StructField(name, + wrapStructType(dataType, field), + nullable, metadata)).getOrElse(field) + selectField(child, Some(childField)) + case _ => + None + } + } + + // Constructs a composition of complex types with a StructType(Array(field)) at its core. Returns + // a StructType for a StructType, an ArrayType for an ArrayType and a MapType for a MapType. + private def wrapStructType(dataType: DataType, field: StructField): DataType = { + dataType match { + case _: StructType => + StructType(Array(field)) + case ArrayType(elementType, containsNull) => + ArrayType(wrapStructType(elementType, field), containsNull) + case MapType(keyType, valueType, valueContainsNull) => + MapType(keyType, wrapStructType(valueType, field), valueContainsNull) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index cc352c59dff8..5bfb2ec4c678 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -81,8 +81,9 @@ trait ConstraintHelper { /** * Infers a set of `isNotNull` constraints from null intolerant expressions as well as - * non-nullable attributes. For e.g., if an expression is of the form (`a > 5`), this - * returns a constraint of the form `isNotNull(a)` + * non-nullable attributes and complex type extractors. For example, if an expression is of the + * form (`a > 5`), this returns a constraint of the form `isNotNull(a)`. For an expression of the + * form (`a.b > 5`), this returns the more precise constraint `isNotNull(a.b)`. */ def constructIsNotNullConstraints( constraints: Set[Expression], @@ -99,27 +100,28 @@ trait ConstraintHelper { } /** - * Infer the Attribute-specific IsNotNull constraints from the null intolerant child expressions - * of constraints. + * Infer the Attribute and ExtractValue-specific IsNotNull constraints from the null intolerant + * child expressions of constraints. */ private def inferIsNotNullConstraints(constraint: Expression): Seq[Expression] = constraint match { // When the root is IsNotNull, we can push IsNotNull through the child null intolerant // expressions - case IsNotNull(expr) => scanNullIntolerantAttribute(expr).map(IsNotNull(_)) + case IsNotNull(expr) => scanNullIntolerantField(expr).map(IsNotNull(_)) // Constraints always return true for all the inputs. That means, null will never be returned. // Thus, we can infer `IsNotNull(constraint)`, and also push IsNotNull through the child // null intolerant expressions. - case _ => scanNullIntolerantAttribute(constraint).map(IsNotNull(_)) + case _ => scanNullIntolerantField(constraint).map(IsNotNull(_)) } /** - * Recursively explores the expressions which are null intolerant and returns all attributes - * in these expressions. + * Recursively explores the expressions which are null intolerant and returns all attributes and + * complex type extractors in these expressions. 
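+   * For example, scanning the null intolerant expression `a.b > 5` returns the extractor `a.b`
+   * rather than just the attribute `a`.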
*/ - private def scanNullIntolerantAttribute(expr: Expression): Seq[Attribute] = expr match { + private def scanNullIntolerantField(expr: Expression): Seq[Expression] = expr match { + case ev: ExtractValue => Seq(ev) case a: Attribute => Seq(a) - case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantAttribute) + case _: NullIntolerant => expr.children.flatMap(scanNullIntolerantField) case _ => Seq.empty[Attribute] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 3729bd5293ec..948cadd4843e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1235,8 +1235,18 @@ object SQLConf { "issues. Turn on this config to insert a local sort before actually doing repartition " + "to generate consistent repartition results. The performance of repartition() may go " + "down since we insert extra local sort before it.") + .booleanConf + .createWithDefault(true) + + val NESTED_SCHEMA_PRUNING_ENABLED = + buildConf("spark.sql.nestedSchemaPruning.enabled") + .internal() + .doc("Prune nested fields from a logical relation's output which are unnecessary in " + + "satisfying a query. This optimization allows columnar file format readers to avoid " + + "reading unnecessary nested column data. Currently Parquet is the only data source that " + + "implements this optimization.") .booleanConf - .createWithDefault(true) + .createWithDefault(false) object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" @@ -1603,6 +1613,8 @@ class SQLConf extends Serializable with Logging { def partitionOverwriteMode: PartitionOverwriteMode.Value = PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) + def nestedSchemaPruningEnabled: Boolean = getConf(NESTED_SCHEMA_PRUNING_ENABLED) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SchemaPruningTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SchemaPruningTest.scala new file mode 100644 index 000000000000..88e30e03d2d3 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SchemaPruningTest.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.internal.SQLConf.NESTED_SCHEMA_PRUNING_ENABLED + +/** + * A PlanTest that ensures that all tests in this suite are run with nested schema pruning enabled. 
* Remove this trait once the default value of SQLConf.NESTED_SCHEMA_PRUNING_ENABLED is set to true. + */ +private[sql] trait SchemaPruningTest extends PlanTest with BeforeAndAfterAll { + private var originalConfSchemaPruningEnabled = false + + override protected def beforeAll(): Unit = { + // Call `withSQLConf` eagerly because some subtypes of `PlanTest` (I'm looking at you, + // `SQLTestUtils`) override `withSQLConf` to reset the existing `SQLConf` with a new one without + // copying existing settings first. This here is an awful, ugly way to get around that behavior + // by initializing the "real" `SQLConf` with a no-op call to `withSQLConf`. I don't want to risk + // "fixing" the downstream behavior, breaking everything else that's expecting these semantics. + // Oh well... + withSQLConf()(()) + originalConfSchemaPruningEnabled = conf.nestedSchemaPruningEnabled + conf.setConf(NESTED_SCHEMA_PRUNING_ENABLED, true) + super.beforeAll() + } + + override protected def afterAll(): Unit = { + try { + super.afterAll() + } finally { + conf.setConf(NESTED_SCHEMA_PRUNING_ENABLED, originalConfSchemaPruningEnabled) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateFieldExtractionPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateFieldExtractionPushdownSuite.scala new file mode 100644 index 000000000000..6d4017ad371f --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateFieldExtractionPushdownSuite.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.SchemaPruningTest +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types._ + +class AggregateFieldExtractionPushdownSuite extends SchemaPruningTest { + private val testRelation = + LocalRelation( + StructField("a", StructType( + StructField("a1", IntegerType) :: Nil)), + StructField("b", IntegerType), + StructField("c", StructType( + StructField("c1", IntegerType) :: Nil))) + + object Optimizer extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Aggregate Field Extraction Pushdown", Once, + AggregateFieldExtractionPushdown) :: Nil + } + + test("basic aggregate field extraction pushdown") { + val originalQuery = + testRelation + .select('a) + .groupBy('a getField "a1")('a getField "a1" as 'a1, Count("*")) + .analyze + val optimized = Optimizer.execute(originalQuery) + val correctAnswer = + testRelation + .select('a) + .select('a getField "a1" as 'a1) + .groupBy('a1)('a1 as 'a1, Count("*")) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("do not push down a field extractor whose root child output is required by the aggregate") { + val originalQuery = + testRelation + .select('a, 'c) + .groupBy('a, 'a getField "a1", 'c getField "c1")( + 'a, 'a getField "a1" as 'a1, 'c getField "c1" as 'c1, Count("*")) + .analyze + val optimized = Optimizer.execute(originalQuery) + val correctAnswer = + testRelation + .select('a, 'c) + .select('a, 'c getField "c1" as 'c1) + .groupBy('a, 'a getField "a1", 'c1)('a, 'a getField "a1" as 'a1, 'c1 as 'c1, Count("*")) + .analyze + + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinFieldExtractionPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinFieldExtractionPushdownSuite.scala new file mode 100644 index 000000000000..16ea72b4ea9b --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinFieldExtractionPushdownSuite.scala @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.SchemaPruningTest +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types._ + +class JoinFieldExtractionPushdownSuite extends SchemaPruningTest { + private val leftRelation = + LocalRelation( + StructField("la", StructType( + StructField("la1", IntegerType) :: Nil)), + StructField("lb", IntegerType), + StructField("lc", StructType( + StructField("lc1", IntegerType) :: Nil))) + + private val rightRelation = + LocalRelation( + StructField("ra", StructType( + StructField("ra1", IntegerType) :: Nil)), + StructField("rb", IntegerType), + StructField("rc", StructType( + StructField("rc1", IntegerType) :: Nil))) + + private val joinTypes = Inner :: LeftOuter :: RightOuter :: FullOuter :: LeftSemi :: Nil + + object Optimizer extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Join Field Extraction Pushdown", Once, + JoinFieldExtractionPushdown) :: Nil + } + + test("don't modify simple joins") { + joinTypes.foreach { joinType => + val originalQuery = + leftRelation.join(rightRelation, joinType, Some($"lb" === $"rb")).analyze + val optimized = Optimizer.execute(originalQuery) + + comparePlans(optimized, originalQuery) + } + } + + test("don't modify output of bare join") { + joinTypes.filterNot(_ == LeftSemi).foreach { joinType => + val originalQuery = + leftRelation.join(rightRelation, joinType, Some($"la.la1" === $"rb")).analyze + val optimized = Optimizer.execute(originalQuery) + val correctAnswer = + leftRelation.select('la, 'lb, 'lc, $"la.la1" as 'la1) + .join(rightRelation.select('ra, 'rb, 'rc), joinType, Some('la1 === 'rb)) + .select('la, 'lb, 'lc, 'ra, 'rb, 'rc) + .analyze + + comparePlans(optimized, correctAnswer) + } + } + + test("push down left join condition path of degree 1") { + joinTypes.foreach { joinType => + val originalQuery = + leftRelation.join(rightRelation, joinType, Some($"la.la1" === $"rb")) + .select($"la.la1") + .analyze + val optimized = Optimizer.execute(originalQuery) + val correctAnswer = + leftRelation.select($"la.la1" as 'la1) + .join(rightRelation.select('rb), joinType, Some('la1 === 'rb)) + .select('la1 as 'la1) + .analyze + + comparePlans(optimized, correctAnswer) + } + } + + test("push down right join condition path of degree 1") { + joinTypes.filterNot(_ == LeftSemi).foreach { joinType => + val originalQuery = + leftRelation.join(rightRelation, joinType, Some($"lb" === $"ra.ra1")) + .select($"ra.ra1") + .analyze + val optimized = Optimizer.execute(originalQuery) + val correctAnswer = + leftRelation.select('lb) + .join(rightRelation.select($"ra.ra1" as 'ra1), joinType, Some('lb === 'ra1)) + .select('ra1 as 'ra1) + .analyze + + comparePlans(optimized, correctAnswer) + } + } + + test("push down both join condition paths of degree 1") { + def test1( + joinType: JoinType, + originalSelects: Seq[NamedExpression], + expectedSelects: Seq[NamedExpression]) { + val originalQuery = + leftRelation.join(rightRelation, joinType, Some($"la.la1" === $"ra.ra1")) + .select(originalSelects: _*) + .analyze + val optimized = Optimizer.execute(originalQuery) + val correctAnswer = + leftRelation.select($"la.la1" as 'la1) + .join(rightRelation.select($"ra.ra1" as 
'ra1), joinType, Some('la1 === 'ra1)) + .select(expectedSelects: _*) + .analyze + + comparePlans(optimized, correctAnswer) + } + + joinTypes.filterNot(_ == LeftSemi) + .foreach(joinType => + test1(joinType, $"la.la1" :: $"ra.ra1" :: Nil, ('la1 as 'la1) :: ('ra1 as 'ra1) :: Nil)) + + test1(LeftSemi, $"la.la1" :: Nil, ('la1 as 'la1) :: Nil) + } + + test("don't prune root of leaf when the root is in the projection") { + joinTypes.filterNot(_ == LeftSemi).foreach { joinType => + val originalQuery = + leftRelation.join(rightRelation, joinType, Some($"la.la1" === $"ra.ra1")) + .select($"la.la1", 'ra) + .analyze + val optimized = Optimizer.execute(originalQuery) + val correctAnswer = + leftRelation.select($"la.la1" as 'la1) + .join(rightRelation.select('ra, $"ra.ra1" as 'ra1), joinType, Some('la1 === 'ra1)) + .select('la1 as 'la1, 'ra) + .analyze + + comparePlans(optimized, correctAnswer) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/SelectedFieldSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/SelectedFieldSuite.scala new file mode 100644 index 000000000000..6923c7a89abf --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/SelectedFieldSuite.scala @@ -0,0 +1,432 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.planning + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.exceptions.TestFailedException + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.types._ + +// scalastyle:off line.size.limit +class SelectedFieldSuite extends SparkFunSuite with BeforeAndAfterAll { + // The test schema as a tree string, i.e. 
`schema.treeString` + // root + // |-- col1: string (nullable = false) + // |-- col2: struct (nullable = true) + // | |-- field1: integer (nullable = true) + // | |-- field2: array (nullable = true) + // | | |-- element: integer (containsNull = false) + // | |-- field3: array (nullable = false) + // | | |-- element: struct (containsNull = true) + // | | | |-- subfield1: integer (nullable = true) + // | | | |-- subfield2: integer (nullable = true) + // | | | |-- subfield3: array (nullable = true) + // | | | | |-- element: integer (containsNull = true) + // | |-- field4: map (nullable = true) + // | | |-- key: string + // | | |-- value: struct (valueContainsNull = false) + // | | | |-- subfield1: integer (nullable = true) + // | | | |-- subfield2: array (nullable = true) + // | | | | |-- element: integer (containsNull = false) + // | |-- field5: array (nullable = false) + // | | |-- element: struct (containsNull = true) + // | | | |-- subfield1: struct (nullable = false) + // | | | | |-- subsubfield1: integer (nullable = true) + // | | | | |-- subsubfield2: integer (nullable = true) + // | | | |-- subfield2: struct (nullable = true) + // | | | | |-- subsubfield1: struct (nullable = true) + // | | | | | |-- subsubsubfield1: string (nullable = true) + // | | | | |-- subsubfield2: integer (nullable = true) + // | |-- field6: struct (nullable = true) + // | | |-- subfield1: string (nullable = false) + // | | |-- subfield2: string (nullable = true) + // | |-- field7: struct (nullable = true) + // | | |-- subfield1: struct (nullable = true) + // | | | |-- subsubfield1: integer (nullable = true) + // | | | |-- subsubfield2: integer (nullable = true) + // | |-- field8: map (nullable = true) + // | | |-- key: string + // | | |-- value: array (valueContainsNull = false) + // | | | |-- element: struct (containsNull = true) + // | | | | |-- subfield1: integer (nullable = true) + // | | | | |-- subfield2: array (nullable = true) + // | | | | | |-- element: integer (containsNull = false) + // | |-- field9: map (nullable = true) + // | | |-- key: string + // | | |-- value: integer (valueContainsNull = false) + // |-- col3: array (nullable = false) + // | |-- element: struct (containsNull = false) + // | | |-- field1: struct (nullable = true) + // | | | |-- subfield1: integer (nullable = false) + // | | | |-- subfield2: integer (nullable = true) + // | | |-- field2: map (nullable = true) + // | | | |-- key: string + // | | | |-- value: integer (valueContainsNull = false) + // |-- col4: map (nullable = false) + // | |-- key: string + // | |-- value: struct (valueContainsNull = false) + // | | |-- field1: struct (nullable = true) + // | | | |-- subfield1: integer (nullable = false) + // | | | |-- subfield2: integer (nullable = true) + // | | |-- field2: map (nullable = true) + // | | | |-- key: string + // | | | |-- value: integer (valueContainsNull = false) + // |-- col5: array (nullable = true) + // | |-- element: map (containsNull = true) + // | | |-- key: string + // | | |-- value: struct (valueContainsNull = false) + // | | | |-- field1: struct (nullable = true) + // | | | | |-- subfield1: integer (nullable = true) + // | | | | |-- subfield2: integer (nullable = true) + // |-- col6: map (nullable = true) + // | |-- key: string + // | |-- value: array (valueContainsNull = true) + // | | |-- element: struct (containsNull = false) + // | | | |-- field1: struct (nullable = true) + // | | | | |-- subfield1: integer (nullable = true) + // | | | | |-- subfield2: integer (nullable = true) + // |-- col7: array 
(nullable = true) + // | |-- element: struct (containsNull = true) + // | | |-- field1: integer (nullable = false) + // | | |-- field2: struct (nullable = true) + // | | | |-- subfield1: integer (nullable = false) + // | | |-- field3: array (nullable = true) + // | | | |-- element: struct (containsNull = true) + // | | | | |-- subfield1: integer (nullable = false) + // |-- col8: array (nullable = true) + // | |-- element: struct (containsNull = true) + // | | |-- field1: array (nullable = false) + // | | | |-- element: integer (containsNull = false) + private val schema = + StructType( + StructField("col1", StringType, nullable = false) :: + StructField("col2", StructType( + StructField("field1", IntegerType) :: + StructField("field2", ArrayType(IntegerType, containsNull = false)) :: + StructField("field3", ArrayType(StructType( + StructField("subfield1", IntegerType) :: + StructField("subfield2", IntegerType) :: + StructField("subfield3", ArrayType(IntegerType)) :: Nil)), nullable = false) :: + StructField("field4", MapType(StringType, StructType( + StructField("subfield1", IntegerType) :: + StructField("subfield2", ArrayType(IntegerType, containsNull = false)) :: Nil), valueContainsNull = false)) :: + StructField("field5", ArrayType(StructType( + StructField("subfield1", StructType( + StructField("subsubfield1", IntegerType) :: + StructField("subsubfield2", IntegerType) :: Nil), nullable = false) :: + StructField("subfield2", StructType( + StructField("subsubfield1", StructType( + StructField("subsubsubfield1", StringType) :: Nil)) :: + StructField("subsubfield2", IntegerType) :: Nil)) :: Nil)), nullable = false) :: + StructField("field6", StructType( + StructField("subfield1", StringType, nullable = false) :: + StructField("subfield2", StringType) :: Nil)) :: + StructField("field7", StructType( + StructField("subfield1", StructType( + StructField("subsubfield1", IntegerType) :: + StructField("subsubfield2", IntegerType) :: Nil)) :: Nil)) :: + StructField("field8", MapType(StringType, ArrayType(StructType( + StructField("subfield1", IntegerType) :: + StructField("subfield2", ArrayType(IntegerType, containsNull = false)) :: Nil)), valueContainsNull = false)) :: + StructField("field9", MapType(StringType, IntegerType, valueContainsNull = false)) :: Nil)) :: + StructField("col3", ArrayType(StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType, nullable = false) :: + StructField("subfield2", IntegerType) :: Nil)) :: + StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) :: Nil), containsNull = false), nullable = false) :: + StructField("col4", MapType(StringType, StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType, nullable = false) :: + StructField("subfield2", IntegerType) :: Nil)) :: + StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) :: Nil), valueContainsNull = false), nullable = false) :: + StructField("col5", ArrayType(MapType(StringType, StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType) :: + StructField("subfield2", IntegerType) :: Nil)) :: Nil), valueContainsNull = false))) :: + StructField("col6", MapType(StringType, ArrayType(StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType) :: + StructField("subfield2", IntegerType) :: Nil)) :: Nil), containsNull = false))) :: + StructField("col7", ArrayType(StructType( + StructField("field1", IntegerType, nullable = false) :: + 
StructField("field2", StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil)) :: + StructField("field3", ArrayType(StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil))) :: Nil))) :: + StructField("col8", ArrayType(StructType( + StructField("field1", ArrayType(IntegerType, containsNull = false), nullable = false) :: Nil))) :: Nil) + + private val testRelation = LocalRelation(schema.toAttributes) + + test("should not match an attribute reference") { + assertResult(None)(unapplySelect("col1")) + assertResult(None)(unapplySelect("col1 as foo")) + assertResult(None)(unapplySelect("col2")) + } + + test("col2.field2, col2.field2[0] as foo") { + val expected = + StructField("col2", StructType( + StructField("field2", ArrayType(IntegerType, containsNull = false)) :: Nil)) + testSelect("col2.field2", expected) + testSelect("col2.field2[0] as foo", expected) + } + + test("col2.field9, col2.field9['foo'] as foo") { + val expected = + StructField("col2", StructType( + StructField("field9", MapType(StringType, IntegerType, valueContainsNull = false)) :: Nil)) + testSelect("col2.field9", expected) + testSelect("col2.field9['foo'] as foo", expected) + } + + test("col2.field3.subfield3, col2.field3[0].subfield3 as foo, col2.field3.subfield3[0] as foo, col2.field3[0].subfield3[0] as foo") { + val expected = + StructField("col2", StructType( + StructField("field3", ArrayType(StructType( + StructField("subfield3", ArrayType(IntegerType)) :: Nil)), nullable = false) :: Nil)) + testSelect("col2.field3.subfield3", expected) + testSelect("col2.field3[0].subfield3 as foo", expected) + testSelect("col2.field3.subfield3[0] as foo", expected) + testSelect("col2.field3[0].subfield3[0] as foo", expected) + } + + test("col2.field3.subfield1") { + val expected = + StructField("col2", StructType( + StructField("field3", ArrayType(StructType( + StructField("subfield1", IntegerType) :: Nil)), nullable = false) :: Nil)) + testSelect("col2.field3.subfield1", expected) + } + + test("col2.field5.subfield1") { + val expected = + StructField("col2", StructType( + StructField("field5", ArrayType(StructType( + StructField("subfield1", StructType( + StructField("subsubfield1", IntegerType) :: + StructField("subsubfield2", IntegerType) :: Nil), nullable = false) :: Nil)), nullable = false) :: Nil)) + testSelect("col2.field5.subfield1", expected) + } + + test("col3.field1.subfield1") { + val expected = + StructField("col3", ArrayType(StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil)) :: Nil), containsNull = false), nullable = false) + testSelect("col3.field1.subfield1", expected) + } + + test("col3.field2['foo'] as foo") { + val expected = + StructField("col3", ArrayType(StructType( + StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) :: Nil), containsNull = false), nullable = false) + testSelect("col3.field2['foo'] as foo", expected) + } + + test("col4['foo'].field1.subfield1 as foo") { + val expected = + StructField("col4", MapType(StringType, StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil)) :: Nil), valueContainsNull = false), nullable = false) + testSelect("col4['foo'].field1.subfield1 as foo", expected) + } + + test("col4['foo'].field2['bar'] as foo") { + val expected = + StructField("col4", MapType(StringType, StructType( + StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false)) :: Nil), 
valueContainsNull = false), nullable = false) + testSelect("col4['foo'].field2['bar'] as foo", expected) + } + + test("col5[0]['foo'].field1.subfield1 as foo") { + val expected = + StructField("col5", ArrayType(MapType(StringType, StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType) :: Nil)) :: Nil), valueContainsNull = false))) + testSelect("col5[0]['foo'].field1.subfield1 as foo", expected) + } + + test("col6['foo'][0].field1.subfield1 as foo") { + val expected = + StructField("col6", MapType(StringType, ArrayType(StructType( + StructField("field1", StructType( + StructField("subfield1", IntegerType) :: Nil)) :: Nil), containsNull = false))) + testSelect("col6['foo'][0].field1.subfield1 as foo", expected) + } + + test("col2.field5.subfield1.subsubfield1") { + val expected = + StructField("col2", StructType( + StructField("field5", ArrayType(StructType( + StructField("subfield1", StructType( + StructField("subsubfield1", IntegerType) :: Nil), nullable = false) :: Nil)), nullable = false) :: Nil)) + testSelect("col2.field5.subfield1.subsubfield1", expected) + } + + test("col2.field5.subfield2.subsubfield1.subsubsubfield1") { + val expected = + StructField("col2", StructType( + StructField("field5", ArrayType(StructType( + StructField("subfield2", StructType( + StructField("subsubfield1", StructType( + StructField("subsubsubfield1", StringType) :: Nil)) :: Nil)) :: Nil)), nullable = false) :: Nil)) + testSelect("col2.field5.subfield2.subsubfield1.subsubsubfield1", expected) + } + + test("col2.field4['foo'].subfield1 as foo") { + val expected = + StructField("col2", StructType( + StructField("field4", MapType(StringType, StructType( + StructField("subfield1", IntegerType) :: Nil), valueContainsNull = false)) :: Nil)) + testSelect("col2.field4['foo'].subfield1 as foo", expected) + } + + test("col2.field4['foo'].subfield2 as foo, col2.field4['foo'].subfield2[0] as foo") { + val expected = + StructField("col2", StructType( + StructField("field4", MapType(StringType, StructType( + StructField("subfield2", ArrayType(IntegerType, containsNull = false)) :: Nil), valueContainsNull = false)) :: Nil)) + testSelect("col2.field4['foo'].subfield2 as foo", expected) + testSelect("col2.field4['foo'].subfield2[0] as foo", expected) + } + + test("col2.field8['foo'][0].subfield1 as foo") { + val expected = + StructField("col2", StructType( + StructField("field8", MapType(StringType, ArrayType(StructType( + StructField("subfield1", IntegerType) :: Nil)), valueContainsNull = false)) :: Nil)) + testSelect("col2.field8['foo'][0].subfield1 as foo", expected) + } + + test("col2.field1") { + val expected = + StructField("col2", StructType( + StructField("field1", IntegerType) :: Nil)) + testSelect("col2.field1", expected) + } + + test("col2.field6") { + val expected = + StructField("col2", StructType( + StructField("field6", StructType( + StructField("subfield1", StringType, nullable = false) :: + StructField("subfield2", StringType) :: Nil)) :: Nil)) + testSelect("col2.field6", expected) + } + + test("col2.field7.subfield1") { + val expected = + StructField("col2", StructType( + StructField("field7", StructType( + StructField("subfield1", StructType( + StructField("subsubfield1", IntegerType) :: + StructField("subsubfield2", IntegerType) :: Nil)) :: Nil)) :: Nil)) + testSelect("col2.field7.subfield1", expected) + } + + test("col2.field6.subfield1") { + val expected = + StructField("col2", StructType( + StructField("field6", StructType( + StructField("subfield1", StringType, 
nullable = false) :: Nil)) :: Nil)) + testSelect("col2.field6.subfield1", expected) + } + + test("col7.field1, col7[0].field1 as foo, col7.field1[0] as foo") { + val expected = + StructField("col7", ArrayType(StructType( + StructField("field1", IntegerType, nullable = false) :: Nil))) + testSelect("col7.field1", expected) + testSelect("col7[0].field1 as foo", expected) + testSelect("col7.field1[0] as foo", expected) + } + + test("col7.field2.subfield1") { + val expected = + StructField("col7", ArrayType(StructType( + StructField("field2", StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil)) :: Nil))) + testSelect("col7.field2.subfield1", expected) + } + + test("col7.field3.subfield1") { + val expected = + StructField("col7", ArrayType(StructType( + StructField("field3", ArrayType(StructType( + StructField("subfield1", IntegerType, nullable = false) :: Nil))) :: Nil))) + testSelect("col7.field3.subfield1", expected) + } + + test("col8.field1, col8[0].field1 as foo, col8.field1[0] as foo, col8[0].field1[0] as foo") { + val expected = + StructField("col8", ArrayType(StructType( + StructField("field1", ArrayType(IntegerType, containsNull = false), nullable = false) :: Nil))) + testSelect("col8.field1", expected) + testSelect("col8[0].field1 as foo", expected) + testSelect("col8.field1[0] as foo", expected) + testSelect("col8[0].field1[0] as foo", expected) + } + + def assertResult(expected: StructField)(actual: StructField)(expr: String): Unit = { + try { + super.assertResult(expected)(actual) + } catch { + case ex: TestFailedException => + // Print some helpful diagnostics in the case of failure + // scalastyle:off println + println("For " + expr) + println("Expected:") + println(StructType(expected :: Nil).treeString) + println("Actual:") + println(StructType(actual :: Nil).treeString) + println("expected.dataType.sameType(actual.dataType) = " + + expected.dataType.sameType(actual.dataType)) + // scalastyle:on println + throw ex + } + } + + private def testSelect(expr: String, expected: StructField) = { + unapplySelect(expr) match { + case Some(field) => + assertResult(expected)(field)(expr) + case None => + val failureMessage = + "Failed to select a field from " + expr + ". 
" + + "Expected:\n" + + StructType(expected :: Nil).treeString + fail(failureMessage) + } + } + + private def unapplySelect(expr: String) = { + val parsedExpr = + CatalystSqlParser.parseExpression(expr) match { + case namedExpr: NamedExpression => namedExpr + } + val select = testRelation.select(parsedExpr) + val analyzed = select.analyze + SelectedField.unapply(analyzed.expressions.head) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index 08ff33afbba3..a9858542155a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -286,7 +286,19 @@ case class FileSourceScanExec( } getOrElse { metadata } - withOptPartitionCount + val withOptColumnCount = relation.fileFormat match { + case columnar: ColumnarFileFormat => + SparkSession + .getActiveSession + .map { sparkSession => + val columnCount = columnar.columnCountForSchema(sparkSession, requiredSchema) + withOptPartitionCount + ("ColumnCount" -> columnCount.toString) + } getOrElse { + withOptPartitionCount + } + case _ => withOptPartitionCount + } + withOptColumnCount } private lazy val inputRDD: RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 1c8e4050978d..fb0cdbf7f517 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.ExperimentalMethods import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions +import org.apache.spark.sql.execution.datasources.parquet.ParquetSchemaPruning import org.apache.spark.sql.execution.datasources.v2.PushDownOperatorsToDataSource import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate @@ -33,6 +34,7 @@ class SparkOptimizer( Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ + Batch("Parquet Schema Pruning", Once, ParquetSchemaPruning) :+ Batch("Push down operators to data source scan", Once, PushDownOperatorsToDataSource)) ++ postHocOptimizationBatches :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 1edfdc888afd..bbecd417e256 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -162,7 +162,9 @@ case class FilterExec(condition: Expression, child: SparkPlan) val generatedIsNotNullChecks = new Array[Boolean](notNullPreds.length) val generated = otherPreds.map { c => val nullChecks = c.references.map { r => - val idx = notNullPreds.indexWhere { n => n.asInstanceOf[IsNotNull].child.semanticEquals(r)} + val idx = notNullPreds.indexWhere { n => + 
n.asInstanceOf[IsNotNull].child.references.contains(r) + } if (idx != -1 && !generatedIsNotNullChecks(idx)) { generatedIsNotNullChecks(idx) = true // Use the child's output. The nullability is what the child produced. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ColumnarFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ColumnarFileFormat.scala new file mode 100644 index 000000000000..ee0726cb2f00 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ColumnarFileFormat.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.types.StructType + +/** + * An optional mix-in for columnar [[FileFormat]]s. This trait provides some helpful metadata when + * debugging a physical query plan. + */ +private[sql] trait ColumnarFileFormat { + _: FileFormat => + + /** Returns the number of columns in this file format required to satisfy the given schema. 
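   * For example, with the Parquet implementation below, a read schema of
   * struct<name:struct<first:string,last:string>> maps to the two Parquet leaf columns
   * "name.first" and "name.last", so this method would return 2.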
*/ + def columnCountForSchema(sparkSession: SparkSession, readSchema: StructType): Int +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index d8f47eec952d..70550676b3fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -55,6 +55,7 @@ import org.apache.spark.util.{SerializableConfiguration, ThreadUtils} class ParquetFileFormat extends FileFormat + with ColumnarFileFormat with DataSourceRegister with Logging with Serializable { @@ -72,6 +73,14 @@ class ParquetFileFormat override def equals(other: Any): Boolean = other.isInstanceOf[ParquetFileFormat] + override def columnCountForSchema(sparkSession: SparkSession, readSchema: StructType): Int = { + val converter = new SparkToParquetSchemaConverter( + sparkSession.sessionState.conf.writeLegacyParquetFormat, + sparkSession.sessionState.conf.parquetOutputTimestampType) + val parquetSchema = converter.convert(readSchema) + parquetSchema.getPaths.size + } + override def prepareWrite( sparkSession: SparkSession, job: Job, @@ -414,11 +423,12 @@ class ParquetFileFormat } else { logDebug(s"Falling back to parquet-mr") // ParquetRecordReader returns UnsafeRow + val readSupport = new ParquetReadSupport(convertTz, true) val reader = if (pushed.isDefined && enableRecordFilter) { val parquetFilter = FilterCompat.get(pushed.get, null) - new ParquetRecordReader[UnsafeRow](new ParquetReadSupport(convertTz), parquetFilter) + new ParquetRecordReader[UnsafeRow](readSupport, parquetFilter) } else { - new ParquetRecordReader[UnsafeRow](new ParquetReadSupport(convertTz)) + new ParquetRecordReader[UnsafeRow](readSupport) } val iter = new RecordReaderIterator(reader) // SPARK-23457 Register a task completion lister before `initialization`. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala index 40ce5d5e0564..00db3cc62ff4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala @@ -47,16 +47,25 @@ import org.apache.spark.sql.types._ * * Due to this reason, we no longer rely on [[ReadContext]] to pass requested schema from [[init()]] * to [[prepareForRead()]], but use a private `var` for simplicity. + * + * @param parquetMrCompatibility support reading with parquet-mr or Spark's built-in Parquet reader */ -private[parquet] class ParquetReadSupport(val convertTz: Option[TimeZone]) +private[parquet] class ParquetReadSupport(val convertTz: Option[TimeZone], + parquetMrCompatibility: Boolean) extends ReadSupport[UnsafeRow] with Logging { private var catalystRequestedSchema: StructType = _ + /** + * Construct a [[ParquetReadSupport]] with [[convertTz]] set to [[None]] and + * [[parquetMrCompatibility]] set to [[false]]. + * + * We need a zero-arg constructor for SpecificParquetRecordReaderBase. But that is only + * used in the vectorized reader, where we get the convertTz value directly, and the value here + * is ignored. 
Further, we set [[parquetMrCompatibility]] to [[false]] as this constructor is only + * called by the Spark reader. + */ def this() { - // We need a zero-arg constructor for SpecificParquetRecordReaderBase. But that is only - // used in the vectorized reader, where we get the convertTz value directly, and the value here - // is ignored. - this(None) + this(None, false) } /** @@ -71,9 +80,22 @@ private[parquet] class ParquetReadSupport(val convertTz: Option[TimeZone]) StructType.fromString(schemaString) } - val parquetRequestedSchema = + val clippedParquetSchema = ParquetReadSupport.clipParquetSchema(context.getFileSchema, catalystRequestedSchema) + val parquetRequestedSchema = if (parquetMrCompatibility) { + // Parquet-mr will throw an exception if we try to read a superset of the file's schema. + // Therefore, we intersect our clipped schema with the underlying file's schema + ParquetReadSupport.intersectParquetGroups(clippedParquetSchema, context.getFileSchema) + .map(intersectionGroup => + new MessageType(intersectionGroup.getName, intersectionGroup.getFields)) + .getOrElse(ParquetSchemaConverter.EMPTY_MESSAGE) + } else { + // Spark's built-in Parquet reader will throw an exception in some cases if the requested + // schema is not the same as the clipped schema + clippedParquetSchema + } + new ReadContext(parquetRequestedSchema, Map.empty[String, String].asJava) } @@ -96,7 +118,7 @@ private[parquet] class ParquetReadSupport(val convertTz: Option[TimeZone]) |Parquet form: |$parquetRequestedSchema |Catalyst form: - |$catalystRequestedSchema + |${catalystRequestedSchema.prettyJson} """.stripMargin } @@ -288,6 +310,27 @@ private[parquet] object ParquetReadSupport { } } + /** + * Computes the structural intersection between two Parquet group types. + */ + private def intersectParquetGroups( + groupType1: GroupType, groupType2: GroupType): Option[GroupType] = { + val fields = + groupType1.getFields.asScala + .filter(field => groupType2.containsField(field.getName)) + .flatMap { + case field1: GroupType => + intersectParquetGroups(field1, groupType2.getType(field1.getName).asGroupType) + case field1 => Some(field1) + } + + if (fields.nonEmpty) { + Some(groupType1.withNewFields(fields.asJava)) + } else { + None + } + } + def expandUDT(schema: StructType): StructType = { def expand(dataType: DataType): DataType = { dataType match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index 119972594184..8bf0f32fdf61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -130,8 +130,8 @@ private[parquet] class ParquetRowConverter( extends ParquetGroupConverter(updater) with Logging { assert( - parquetType.getFieldCount == catalystType.length, - s"""Field counts of the Parquet schema and the Catalyst schema don't match: + parquetType.getFieldCount <= catalystType.length, + s"""Field count of the Parquet schema is greater than the field count of the Catalyst schema: | |Parquet schema: |$parquetType @@ -182,10 +182,12 @@ private[parquet] class ParquetRowConverter( // Converters for each field. 
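  // The change below matches Parquet fields to Catalyst fields by name rather than by
  // position, so the file's Parquet schema may be a reordered subset of the Catalyst
  // schema (e.g. after the intersection performed in ParquetReadSupport). Catalyst
  // fields with no Parquet counterpart are left null by start().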
private val fieldConverters: Array[Converter with HasParentContainerUpdater] = { - parquetType.getFields.asScala.zip(catalystType).zipWithIndex.map { - case ((parquetFieldType, catalystField), ordinal) => - // Converted field value should be set to the `ordinal`-th cell of `currentRow` - newConverter(parquetFieldType, catalystField.dataType, new RowUpdater(currentRow, ordinal)) + parquetType.getFields.asScala.map { + case parquetField => + val fieldIndex = catalystType.fieldIndex(parquetField.getName) + val catalystField = catalystType(fieldIndex) + // Converted field value should be set to the `fieldIndex`-th cell of `currentRow` + newConverter(parquetField, catalystField.dataType, new RowUpdater(currentRow, fieldIndex)) }.toArray } @@ -193,7 +195,7 @@ private[parquet] class ParquetRowConverter( override def end(): Unit = { var i = 0 - while (i < currentRow.numFields) { + while (i < fieldConverters.length) { fieldConverters(i).updater.end() i += 1 } @@ -202,11 +204,15 @@ private[parquet] class ParquetRowConverter( override def start(): Unit = { var i = 0 - while (i < currentRow.numFields) { + while (i < fieldConverters.length) { fieldConverters(i).updater.start() currentRow.setNullAt(i) i += 1 } + while (i < currentRow.numFields) { + currentRow.setNullAt(i) + i += 1 + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala new file mode 100644 index 000000000000..c66afff2bc70 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruning.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, Expression, NamedExpression} +import org.apache.spark.sql.catalyst.planning.{PhysicalOperation, ProjectionOverSchema, SelectedField} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType} + +/** + * Prunes unnecessary Parquet columns given a [[PhysicalOperation]] over a + * [[ParquetRelation]]. By "Parquet column", we mean a column as defined in the + * Parquet format. In Spark SQL, a root-level Parquet column corresponds to a + * SQL column, and a nested Parquet column corresponds to a [[StructField]]. 
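 * For example, for a relation with schema struct<name:struct<first:string,last:string>,address:string>,
 * the query "select name.first from t" requires only the nested Parquet column "name.first",
 * and this rule prunes the relation's data schema down to struct<name:struct<first:string>>.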
+ */ +private[sql] object ParquetSchemaPruning extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = + if (SQLConf.get.nestedSchemaPruningEnabled) { + apply0(plan) + } else { + plan + } + + private def apply0(plan: LogicalPlan): LogicalPlan = + plan transformDown { + case op @ PhysicalOperation(projects, filters, + l @ LogicalRelation(hadoopFsRelation @ HadoopFsRelation(_, partitionSchema, + dataSchema, _, parquetFormat: ParquetFileFormat, _), _, _, _)) => + val projectionFields = projects.flatMap(getFields) + val filterFields = filters.flatMap(getFields) + val requestedFields = (projectionFields ++ filterFields).distinct + + // If [[requestedFields]] includes a nested field, continue. Otherwise, + // return [[op]] + if (requestedFields.exists { case (_, optAtt) => optAtt.isEmpty }) { + val prunedSchema = requestedFields + .map { case (field, _) => StructType(Array(field)) } + .reduceLeft(_ merge _) + val dataSchemaFieldNames = dataSchema.fieldNames.toSet + val prunedDataSchema = + StructType(prunedSchema.filter(f => dataSchemaFieldNames.contains(f.name))) + + // If the data schema is different from the pruned data schema, continue. Otherwise, + // return [[op]]. We effect this comparison by counting the number of "leaf" fields in + // each schemata, assuming the fields in [[prunedDataSchema]] are a subset of the fields + // in [[dataSchema]]. + if (countLeaves(dataSchema) > countLeaves(prunedDataSchema)) { + val prunedParquetRelation = + hadoopFsRelation.copy(dataSchema = prunedDataSchema)(hadoopFsRelation.sparkSession) + + // We need to replace the expression ids of the pruned relation output attributes + // with the expression ids of the original relation output attributes so that + // references to the original relation's output are not broken + val outputIdMap = l.output.map(att => (att.name, att.exprId)).toMap + val prunedRelationOutput = + prunedParquetRelation + .schema + .toAttributes + .map { + case att if outputIdMap.contains(att.name) => + att.withExprId(outputIdMap(att.name)) + case att => att + } + val prunedRelation = + l.copy(relation = prunedParquetRelation, output = prunedRelationOutput) + + val projectionOverSchema = ProjectionOverSchema(prunedDataSchema) + + // Construct a new target for our projection by rewriting and + // including the original filters where available + val projectionChild = + if (filters.nonEmpty) { + val projectedFilters = filters.map(_.transformDown { + case projectionOverSchema(expr) => expr + }) + val newFilterCondition = projectedFilters.reduce(And) + Filter(newFilterCondition, prunedRelation) + } else { + prunedRelation + } + + val nonDataPartitionColumnNames = + partitionSchema.map(_.name).filterNot(dataSchemaFieldNames.contains).toSet + + // Construct the new projections of our [[Project]] by + // rewriting the original projections + val newProjects = projects.map { + case project if (nonDataPartitionColumnNames.contains(project.name)) => project + case project => + (project transformDown { + case projectionOverSchema(expr) => expr + }).asInstanceOf[NamedExpression] + } + + logDebug("New projects:\n" + newProjects.map(_.treeString).mkString("\n")) + logDebug(s"Pruned data schema:\n${prunedDataSchema.treeString}") + + Project(newProjects, projectionChild) + } else { + op + } + } else { + op + } + } + + /** + * Gets the top-level (no-parent) [[StructField]]s for the given [[Expression]]. 
+ * When [[expr]] is an [[Attribute]], construct a field around it and return the + * attribute as the second component of the returned tuple. + */ + private def getFields(expr: Expression): Seq[(StructField, Option[Attribute])] = { + expr match { + case att: Attribute => + (StructField(att.name, att.dataType, att.nullable), Some(att)) :: Nil + case SelectedField(field) => (field, None) :: Nil + case _ => + expr.children.flatMap(getFields) + } + } + + /** + * Counts the "leaf" fields of the given [[dataType]]. Informally, this is the + * number of fields of non-complex data type in the tree representation of + * [[dataType]]. + */ + private def countLeaves(dataType: DataType): Int = { + dataType match { + case array: ArrayType => countLeaves(array.elementType) + case map: MapType => countLeaves(map.keyType) + countLeaves(map.valueType) + case struct: StructType => + struct.map(field => countLeaves(field.dataType)).sum + case _ => 1 + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 60e84e6ee750..72af9d7d9721 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2265,4 +2265,19 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val df = spark.range(1).select($"id", new Column(Uuid())) checkAnswer(df, df.collect()) } + + test("SPARK-4502: Nested column pruning shouldn't fail filter") { + withSQLConf(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") { + withTempPath { dir => + val path = dir.getCanonicalPath + val data = + """{"a":{"b":1,"c":2}} + |{}""".stripMargin + Seq(data).toDF().repartition(1).write.text(path) + checkAnswer( + spark.read.json(path).filter($"a.b" > 1).select($"a.b"), + Seq.empty) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 771e1186e63a..4acd257618f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -36,6 +36,24 @@ class JoinSuite extends QueryTest with SharedSQLContext { setupTestData() + private lazy val upperCaseStruct: DataFrame = { + val df = sql("select named_struct(\"N\", N, \"L\", L) as S from uppercasedata") + df.createOrReplaceTempView("upperCaseStruct") + df + } + + private lazy val lowerCaseStruct: DataFrame = { + val df = sql("select named_struct(\"n\", n, \"l\", l) as s from lowercasedata") + df.createOrReplaceTempView("lowerCaseStruct") + df + } + + override def loadTestData(): Unit = { + super.loadTestData + upperCaseStruct + lowerCaseStruct + } + def statisticSizeInByte(df: DataFrame): BigInt = { df.queryExecution.optimizedPlan.stats.sizeInBytes } @@ -167,6 +185,19 @@ class JoinSuite extends QueryTest with SharedSQLContext { } } + test("inner join with struct where, one match per row") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + checkAnswer( + upperCaseStruct.join(lowerCaseStruct.select("s.n", "s.l")).where('n === $"S.N"), + Seq( + Row(Row(1, "A"), 1, "a"), + Row(Row(2, "B"), 2, "b"), + Row(Row(3, "C"), 3, "c"), + Row(Row(4, "D"), 4, "d") + )) + } + } + test("inner join ON, one match per row") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { checkAnswer( @@ -180,6 +211,19 @@ class JoinSuite extends QueryTest with SharedSQLContext { } } + test("inner join with struct ON, one match per row") { + 
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + checkAnswer( + upperCaseStruct.join(lowerCaseStruct.select("s.n", "s.l"), $"n" === $"S.N"), + Seq( + Row(Row(1, "A"), 1, "a"), + Row(Row(2, "B"), 2, "b"), + Row(Row(3, "C"), 3, "c"), + Row(Row(4, "D"), 4, "d") + )) + } + } + test("inner join, where, multiple matches") { val x = testData2.where($"a" === 1).as("x") val y = testData2.where($"a" === 1).as("y") @@ -310,6 +354,72 @@ class JoinSuite extends QueryTest with SharedSQLContext { } } + test("left outer join with struct") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + checkAnswer( + upperCaseStruct.join(lowerCaseStruct.select("s.n", "s.l"), $"n" === $"S.N", "left"), + Row(Row(1, "A"), 1, "a") :: + Row(Row(2, "B"), 2, "b") :: + Row(Row(3, "C"), 3, "c") :: + Row(Row(4, "D"), 4, "d") :: + Row(Row(5, "E"), null, null) :: + Row(Row(6, "F"), null, null) :: Nil) + + checkAnswer( + upperCaseStruct + .join(lowerCaseStruct.select("s.n", "s.l"), $"n" === $"S.N" && $"n" > 1, "left"), + Row(Row(1, "A"), null, null) :: + Row(Row(2, "B"), 2, "b") :: + Row(Row(3, "C"), 3, "c") :: + Row(Row(4, "D"), 4, "d") :: + Row(Row(5, "E"), null, null) :: + Row(Row(6, "F"), null, null) :: Nil) + + checkAnswer( + upperCaseStruct + .join(lowerCaseStruct.select("s.n", "s.l"), $"n" === $"S.N" && $"S.N" > 1, "left"), + Row(Row(1, "A"), null, null) :: + Row(Row(2, "B"), 2, "b") :: + Row(Row(3, "C"), 3, "c") :: + Row(Row(4, "D"), 4, "d") :: + Row(Row(5, "E"), null, null) :: + Row(Row(6, "F"), null, null) :: Nil) + + checkAnswer( + upperCaseStruct + .join(lowerCaseStruct.select("s.n", "s.l"), $"n" === $"S.N" && $"l" > $"S.L", "left"), + Row(Row(1, "A"), 1, "a") :: + Row(Row(2, "B"), 2, "b") :: + Row(Row(3, "C"), 3, "c") :: + Row(Row(4, "D"), 4, "d") :: + Row(Row(5, "E"), null, null) :: + Row(Row(6, "F"), null, null) :: Nil) + + checkAnswer( + sql( + """ + |SELECT l.S.N, count(*) + |FROM uppercasestruct l LEFT OUTER JOIN allnulls r ON (l.S.N = r.a) + |GROUP BY l.S.N + """.stripMargin), + Row(1, 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: Nil) + + checkAnswer( + sql( + """ + |SELECT r.a, count(*) + |FROM uppercasestruct l LEFT OUTER JOIN allnulls r ON (l.S.N = r.a) + |GROUP BY r.a + """.stripMargin), + Row(null, 6) :: Nil) + } + } + test("right outer join") { withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { checkAnswer( @@ -374,6 +484,69 @@ class JoinSuite extends QueryTest with SharedSQLContext { } } + test("right outer join with struct") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + checkAnswer( + lowerCaseStruct.select("s.n", "s.l").join(upperCaseStruct, $"n" === $"S.N", "right"), + Row(1, "a", Row(1, "A")) :: + Row(2, "b", Row(2, "B")) :: + Row(3, "c", Row(3, "C")) :: + Row(4, "d", Row(4, "D")) :: + Row(null, null, Row(5, "E")) :: + Row(null, null, Row(6, "F")) :: Nil) + checkAnswer( + lowerCaseStruct + .select("s.n", "s.l").join(upperCaseStruct, $"n" === $"S.N" && $"n" > 1, "right"), + Row(null, null, Row(1, "A")) :: + Row(2, "b", Row(2, "B")) :: + Row(3, "c", Row(3, "C")) :: + Row(4, "d", Row(4, "D")) :: + Row(null, null, Row(5, "E")) :: + Row(null, null, Row(6, "F")) :: Nil) + checkAnswer( + lowerCaseStruct + .select("s.n", "s.l").join(upperCaseStruct, $"n" === $"S.N" && $"S.N" > 1, "right"), + Row(null, null, Row(1, "A")) :: + Row(2, "b", Row(2, "B")) :: + Row(3, "c", Row(3, "C")) :: + Row(4, "d", Row(4, "D")) :: + Row(null, null, Row(5, "E")) :: + Row(null, null, Row(6, "F")) :: Nil) + checkAnswer( + lowerCaseStruct + .select("s.n", 
"s.l").join(upperCaseStruct, $"n" === $"S.N" && $"l" > $"S.L", "right"), + Row(1, "a", Row(1, "A")) :: + Row(2, "b", Row(2, "B")) :: + Row(3, "c", Row(3, "C")) :: + Row(4, "d", Row(4, "D")) :: + Row(null, null, Row(5, "E")) :: + Row(null, null, Row(6, "F")) :: Nil) + + checkAnswer( + sql( + """ + |SELECT l.a, count(*) + |FROM allnulls l RIGHT OUTER JOIN uppercasestruct r ON (l.a = r.S.N) + |GROUP BY l.a + """.stripMargin), + Row(null, 6)) + + checkAnswer( + sql( + """ + |SELECT r.S.N, count(*) + |FROM allnulls l RIGHT OUTER JOIN uppercasestruct r ON (l.a = r.S.N) + |GROUP BY r.S.N + """.stripMargin), + Row(1, 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: Nil) + } + } + test("full outer join") { upperCaseData.where('N <= 4).createOrReplaceTempView("`left`") upperCaseData.where('N >= 3).createOrReplaceTempView("`right`") @@ -465,6 +638,91 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row(null, 10)) } + test("full outer join with struct") { + upperCaseStruct.where($"S.N" <= 4).createOrReplaceTempView("`left`") + upperCaseStruct.where($"S.N" >= 3).createOrReplaceTempView("`right`") + + val left = UnresolvedRelation(TableIdentifier("left")) + val right = UnresolvedRelation(TableIdentifier("right")) + + checkAnswer( + left.join(right.select("S.N", "S.L"), $"left.S.N" === $"N", "full"), + Row(Row(1, "A"), null, null) :: + Row(Row(2, "B"), null, null) :: + Row(Row(3, "C"), 3, "C") :: + Row(Row(4, "D"), 4, "D") :: + Row(null, 5, "E") :: + Row(null, 6, "F") :: Nil) + + checkAnswer( + left.join(right.select("S.N", "S.L"), ($"left.S.N" === $"N") && ($"left.S.N" =!= 3), "full"), + Row(Row(1, "A"), null, null) :: + Row(Row(2, "B"), null, null) :: + Row(Row(3, "C"), null, null) :: + Row(null, 3, "C") :: + Row(Row(4, "D"), 4, "D") :: + Row(null, 5, "E") :: + Row(null, 6, "F") :: Nil) + + checkAnswer( + left.join(right.select("S.N", "S.L"), ($"left.S.N" === $"N") && ($"N" =!= 3), "full"), + Row(Row(1, "A"), null, null) :: + Row(Row(2, "B"), null, null) :: + Row(Row(3, "C"), null, null) :: + Row(null, 3, "C") :: + Row(Row(4, "D"), 4, "D") :: + Row(null, 5, "E") :: + Row(null, 6, "F") :: Nil) + + checkAnswer( + sql( + """ + |SELECT l.a, count(*) + |FROM allNulls l FULL OUTER JOIN upperCaseStruct r ON (l.a = r.S.N) + |GROUP BY l.a + """.stripMargin), + Row(null, 10)) + + checkAnswer( + sql( + """ + |SELECT r.S.N, count(*) + |FROM allNulls l FULL OUTER JOIN upperCaseStruct r ON (l.a = r.S.N) + |GROUP BY r.S.N + """.stripMargin), + Row(1, 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: + Row(null, 4) :: Nil) + + checkAnswer( + sql( + """ + |SELECT l.S.N, count(*) + |FROM upperCaseStruct l FULL OUTER JOIN allNulls r ON (l.S.N = r.a) + |GROUP BY l.S.N + """.stripMargin), + Row(1, 1) :: + Row(2, 1) :: + Row(3, 1) :: + Row(4, 1) :: + Row(5, 1) :: + Row(6, 1) :: + Row(null, 4) :: Nil) + + checkAnswer( + sql( + """ + |SELECT r.a, count(*) + |FROM upperCaseStruct l FULL OUTER JOIN allNulls r ON (l.S.N = r.a) + |GROUP BY r.a + """.stripMargin), + Row(null, 10)) + } + test("broadcasted existence join operator selection") { spark.sharedState.cacheManager.clearCache() sql("CACHE TABLE testData") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/FileSchemaPruningTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/FileSchemaPruningTest.scala new file mode 100644 index 000000000000..e7b634d227c8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/FileSchemaPruningTest.scala @@ -0,0 +1,54 
@@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.scalactic.Equality +import org.scalatest.Assertions + +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.catalyst.SchemaPruningTest +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.types.StructType + +private[sql] trait FileSchemaPruningTest extends SchemaPruningTest { + _: Assertions => + + private val schemaEquality = new Equality[StructType] { + override def areEqual(a: StructType, b: Any) = + b match { + case otherType: StructType => a sameType otherType + case _ => false + } + } + + protected def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = { + val fileSourceScanSchemata = + df.queryExecution.executedPlan.collect { + case scan: FileSourceScanExec => scan.requiredSchema + } + assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size, + s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " + + s"but expected ${expectedSchemaCatalogStrings}") + fileSourceScanSchemata.zip(expectedSchemaCatalogStrings).foreach { + case (scanSchema, expectedScanSchemaCatalogString) => + val expectedScanSchema = CatalystSqlParser.parseDataType(expectedScanSchemaCatalogString) + implicit val equality = schemaEquality + assert(scanSchema === expectedScanSchema) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index e1f094d0a7af..2613d6b4f38a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -108,7 +108,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext val queryOutput = selfJoin.queryExecution.analyzed.output assertResult(4, "Field count mismatches")(queryOutput.size) - assertResult(2, "Duplicated expression ID in query plan:\n $selfJoin") { + assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") { queryOutput.filter(_.name == "_1").map(_.exprId).size } @@ -117,7 +117,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } test("nested data - struct with array field") { - val data = (1 to 10).map(i => Tuple1((i, Seq("val_$i")))) + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) withParquetTable(data, "t") { checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map { case Tuple1((_, Seq(string))) => Row(string) @@ -126,7 +126,7 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with 
SharedSQLContext } test("nested data - array of struct") { - val data = (1 to 10).map(i => Tuple1(Seq(i -> "val_$i"))) + val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i"))) withParquetTable(data, "t") { checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map { case Tuple1(Seq((_, string))) => Row(string) @@ -879,6 +879,15 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } } + + test("select function over nested data") { + val data = (1 to 10).map(i => Tuple1((i, s"val_$i"))) + withParquetTable(data, "t") { + checkAnswer(sql("SELECT isnotnull(_1._2) FROM t"), data.map { + case _ => Row(true) + }) + } + } } object TestingUDT { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningJoinSuite.scala new file mode 100644 index 000000000000..5db6ee3a003c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningJoinSuite.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.execution.FileSchemaPruningTest +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + +class ParquetSchemaPruningJoinSuite + extends QueryTest + with ParquetTest + with FileSchemaPruningTest + with SharedSQLContext { + setupTestData() + + private lazy val upperCaseStructData: DataFrame = { + val df = sql("select named_struct(\"N\", N, \"L\", L) as S from uppercasedata") + df.createOrReplaceTempView("upperCaseStruct") + df + } + + private lazy val lowerCaseStructData: DataFrame = { + val df = sql("select named_struct(\"n\", n, \"l\", l) as s from lowercasedata") + df.createOrReplaceTempView("lowerCaseStruct") + df + } + + override def loadTestData(): Unit = { + super.loadTestData + upperCaseStructData + lowerCaseStructData + } + + testStandardAndLegacyModes("schema pruning join 1") { + asParquetTable(upperCaseStructData, "r1") { + asParquetTable(lowerCaseStructData, "r2") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + def join(joinType: String): DataFrame = + sql(s"select s.n from r1 $joinType join r2 on r1.S.N = r2.s.n") + val scanSchema1 = "struct>" + val scanSchema2 = "struct>" + checkScanSchemata( + join("inner"), + scanSchema1, + scanSchema2) + checkScanSchemata( + join("left outer"), + scanSchema1, + scanSchema2) + checkScanSchemata( + join("right outer"), + scanSchema1, + scanSchema2) + checkScanSchemata( + join("full outer"), + scanSchema1, + scanSchema2) + } + } + } + } + + testStandardAndLegacyModes("schema pruning join 2") { + asParquetTable(upperCaseStructData, "r1") { + asParquetTable(lowerCaseStructData, "r2") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + def join(joinType: String): DataFrame = + sql(s"select s.l from r1 $joinType join r2 on r1.S.N = r2.s.n") + val scanSchema1 = "struct>" + val scanSchema2 = "struct>" + checkScanSchemata( + join("inner"), + scanSchema1, + scanSchema2) + checkScanSchemata( + join("left outer"), + scanSchema1, + scanSchema2) + checkScanSchemata( + join("right outer"), + scanSchema1, + scanSchema2) + checkScanSchemata( + join("full outer"), + scanSchema1, + scanSchema2) + } + } + } + } + + testStandardAndLegacyModes("schema pruning join 3") { + asParquetTable(upperCaseStructData, "r1") { + asParquetTable(lowerCaseStructData, "r2") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + def join(joinType: String): DataFrame = + sql(s"select S.L from r1 $joinType join r2 on r1.S.N = r2.s.n") + val scanSchema1 = "struct>" + val scanSchema2 = "struct>" + checkScanSchemata( + join("inner"), + scanSchema1, + scanSchema2) + checkScanSchemata( + join("left outer"), + scanSchema1, + scanSchema2) + checkScanSchemata( + join("right outer"), + scanSchema1, + scanSchema2) + checkScanSchemata( + join("full outer"), + scanSchema1, + scanSchema2) + } + } + } + } + + testStandardAndLegacyModes("schema pruning join 4") { + asParquetTable(upperCaseStructData, "r1") { + asParquetTable(lowerCaseStructData, "r2") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + def join(joinType: String): DataFrame = + sql(s"select count(s.n) from r1 $joinType join r2 on r1.S.N = r2.s.n") + val scanSchema1 = "struct>" + val scanSchema2 = "struct>" + checkScanSchemata( + join("inner"), + scanSchema1, + scanSchema2) + checkScanSchemata( + join("left outer"), + scanSchema1, + scanSchema2) + checkScanSchemata( + join("right outer"), + scanSchema1, + scanSchema2) + 
checkScanSchemata( + join("full outer"), + scanSchema1, + scanSchema2) + } + } + } + } + + testStandardAndLegacyModes("schema pruning join 5") { + asParquetTable(upperCaseStructData, "r1") { + asParquetTable(lowerCaseStructData, "r2") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + def join(joinType: String): DataFrame = + sql(s"select count(1), s.n from r1 $joinType join r2 on r1.S.N = r2.s.n group by s.n") + val scanSchema1 = "struct>" + val scanSchema2 = "struct>" + checkScanSchemata( + join("inner"), + scanSchema1, + scanSchema2) + checkScanSchemata( + join("left outer"), + scanSchema1, + scanSchema2) + checkScanSchemata( + join("right outer"), + scanSchema1, + scanSchema2) + checkScanSchemata( + join("full outer"), + scanSchema1, + scanSchema2) + } + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala new file mode 100644 index 000000000000..5d0f28184673 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.execution.FileSchemaPruningTest +import org.apache.spark.sql.test.SharedSQLContext + +class ParquetSchemaPruningSuite + extends QueryTest + with ParquetTest + with FileSchemaPruningTest + with SharedSQLContext { + case class FullName(first: String, middle: String, last: String) + case class Contact(name: FullName, address: String, pets: Int, friends: Array[FullName] = Array(), + relatives: Map[String, FullName] = Map()) + + val contacts = + Contact(FullName("Jane", "X.", "Doe"), "123 Main Street", 1) :: + Contact(FullName("John", "Y.", "Doe"), "321 Wall Street", 3) :: Nil + + case class Name(first: String, last: String) + case class BriefContact(name: Name, address: String) + + val briefContacts = + BriefContact(Name("Janet", "Jones"), "567 Maple Drive") :: + BriefContact(Name("Jim", "Jones"), "6242 Ash Street") :: Nil + + testStandardAndLegacyModes("partial schema intersection - select missing subfield") { + withTempPath { dir => + val path = dir.getCanonicalPath + + makeParquetFile(contacts, new File(path + "/contacts/p=1")) + makeParquetFile(briefContacts, new File(path + "/contacts/p=2")) + + spark.read.parquet(path + "/contacts").createOrReplaceTempView("contacts") + + val query = sql("select name.middle, address from contacts where p=2") + checkScanSchemata(query, "struct,address:string>") + checkAnswer(query, + Row(null, "567 Maple Drive") :: + Row(null, "6242 Ash Street") :: Nil) + } + } + + testStandardAndLegacyModes("partial schema intersection - filter on subfield") { + withTempPath { dir => + val path = dir.getCanonicalPath + + makeParquetFile(contacts, new File(path + "/contacts/p=1")) + makeParquetFile(briefContacts, new File(path + "/contacts/p=2")) + + spark.read.parquet(path + "/contacts").createOrReplaceTempView("contacts") + + val query = + sql("select name.middle, name.first, pets, address from contacts where " + + "name.first = 'Janet' and p=2") + checkScanSchemata(query, + "struct,pets:int,address:string>") + checkAnswer(query, + Row(null, "Janet", null, "567 Maple Drive") :: Nil) + } + } + + testStandardAndLegacyModes("no unnecessary schema pruning") { + withTempPath { dir => + val path = dir.getCanonicalPath + + makeParquetFile(contacts, new File(path + "/contacts/p=1")) + makeParquetFile(briefContacts, new File(path + "/contacts/p=2")) + + spark.read.parquet(path + "/contacts").createOrReplaceTempView("contacts") + + val query = + sql("select name.last, name.middle, name.first, relatives[''].last, " + + "relatives[''].middle, relatives[''].first, friends[0].last, friends[0].middle, " + + "friends[0].first, pets, address from contacts where p=2") + // We've selected every field in the schema. Therefore, no schema pruning should be performed. + // We check this by asserting that the scanned schema of the query is identical to the schema + // of the contacts relation, even though the fields are selected in different orders. 
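      // For reference, the catalog string asserted below is the full data schema of Contact,
      // written out as:
      //   struct<name:struct<first:string,middle:string,last:string>,address:string,pets:int,
      //     friends:array<struct<first:string,middle:string,last:string>>,
      //     relatives:map<string,struct<first:string,middle:string,last:string>>>
      // (the partition column p is not part of the file scan's required schema).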
+ checkScanSchemata(query, + "struct,address:string,pets:int," + + "friends:array>," + + "relatives:map>>") + checkAnswer(query, + Row("Jones", null, "Janet", null, null, null, null, null, null, null, "567 Maple Drive") :: + Row("Jones", null, "Jim", null, null, null, null, null, null, null, "6242 Ash Street") :: + Nil) + } + } + + testStandardAndLegacyModes("empty schema intersection") { + withTempPath { dir => + val path = dir.getCanonicalPath + + makeParquetFile(contacts, new File(path + "/contacts/p=1")) + makeParquetFile(briefContacts, new File(path + "/contacts/p=2")) + + spark.read.parquet(path + "/contacts").createOrReplaceTempView("contacts") + + val query = sql("select name.middle from contacts where p=2") + checkScanSchemata(query, "struct>") + checkAnswer(query, + Row(null) :: Row(null) :: Nil) + } + } + + testStandardAndLegacyModes("aggregation over nested data") { + withParquetTable(contacts, "contacts") { + val query = sql("select count(distinct name.last), address from contacts group by address " + + "order by address") + checkScanSchemata(query, "struct>") + checkAnswer(query, + Row(1, "123 Main Street") :: + Row(1, "321 Wall Street") :: Nil) + } + } + + testStandardAndLegacyModes("select function over nested data") { + withParquetTable(contacts, "contacts") { + val query = sql("select count(name.middle) from contacts") + checkScanSchemata(query, "struct>") + checkAnswer(query, + Row(2) :: Nil) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index f05f5722af51..bf15525cbda3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -59,42 +59,68 @@ private[sql] trait ParquetTest extends SQLTestUtils { } /** - * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` + * Writes `df` to a Parquet file, which is then passed to `f` and will be deleted after `f` * returns. */ - protected def withParquetFile[T <: Product: ClassTag: TypeTag] - (data: Seq[T]) + protected def asParquetFile + (df: DataFrame) (f: String => Unit): Unit = { withTempPath { file => - spark.createDataFrame(data).write.parquet(file.getCanonicalPath) + df.write.parquet(file.getCanonicalPath) f(file.getCanonicalPath) } } + /** + * Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f` + * returns. + */ + protected def withParquetFile[T <: Product: ClassTag: TypeTag] + (data: Seq[T]) + (f: String => Unit): Unit = asParquetFile(spark.createDataFrame(data))(f) + + /** + * Writes `df` to a Parquet file and reads it back as a [[DataFrame]], + * which is then passed to `f`. The Parquet file will be deleted after `f` returns. + */ + protected def asParquetDataFrame + (df: DataFrame, testVectorized: Boolean = true) + (f: DataFrame => Unit): Unit = + asParquetFile(df)(path => readParquetFile(path.toString, testVectorized)(f)) + /** * Writes `data` to a Parquet file and reads it back as a [[DataFrame]], * which is then passed to `f`. The Parquet file will be deleted after `f` returns. 
*/ protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag] (data: Seq[T], testVectorized: Boolean = true) - (f: DataFrame => Unit): Unit = { - withParquetFile(data)(path => readParquetFile(path.toString, testVectorized)(f)) - } + (f: DataFrame => Unit): Unit = + asParquetDataFrame(spark.createDataFrame(data), testVectorized)(f) /** - * Writes `data` to a Parquet file, reads it back as a [[DataFrame]] and registers it as a + * Writes `df` to a Parquet file, reads it back as a [[DataFrame]] and registers it as a * temporary table named `tableName`, then call `f`. The temporary table together with the * Parquet file will be dropped/deleted after `f` returns. */ - protected def withParquetTable[T <: Product: ClassTag: TypeTag] - (data: Seq[T], tableName: String, testVectorized: Boolean = true) + protected def asParquetTable + (df: DataFrame, tableName: String, testVectorized: Boolean = true) (f: => Unit): Unit = { - withParquetDataFrame(data, testVectorized) { df => + asParquetDataFrame(df, testVectorized) { df => df.createOrReplaceTempView(tableName) withTempView(tableName)(f) } } + /** + * Writes `data` to a Parquet file, reads it back as a [[DataFrame]] and registers it as a + * temporary table named `tableName`, then call `f`. The temporary table together with the + * Parquet file will be dropped/deleted after `f` returns. + */ + protected def withParquetTable[T <: Product: ClassTag: TypeTag] + (data: Seq[T], tableName: String, testVectorized: Boolean = true) + (f: => Unit): Unit = + asParquetTable(spark.createDataFrame(data), tableName, testVectorized)(f) + protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { spark.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath)
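For reference, a minimal usage sketch, not part of this patch, of how the new `asParquetTable` helper composes with `checkScanSchemata` from `FileSchemaPruningTest`. It assumes a suite mixing in `QueryTest with ParquetTest with FileSchemaPruningTest with SharedSQLContext`, as the suites above do; the `Name`/`Person` case classes, the temp view name `people`, and the test name are illustrative only.

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.internal.SQLConf

    // Suite-level case classes so createDataFrame can derive an encoder for them.
    case class Name(first: String, last: String)
    case class Person(name: Name, address: String)

    test("example: only the selected nested field reaches the file scan") {
      withSQLConf(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") {
        val people = Person(Name("Jane", "Doe"), "123 Main Street") :: Nil
        asParquetTable(spark.createDataFrame(people), "people") {
          val query = sql("select name.first from people")
          // The pruned scan should require only name.first, not address or name.last.
          checkScanSchemata(query, "struct<name:struct<first:string>>")
          checkAnswer(query, Row("Jane") :: Nil)
        }
      }
    }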