diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
index 55ca4f11068f9..9ad2108a93d43 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala
@@ -25,6 +25,8 @@ import org.apache.spark.sql.catalyst.planning.PhysicalOperation
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
+import org.apache.spark.sql.types.{StructField, StructType}
 
 /**
  * A strategy for planning scans over collections of files that might be partitioned or bucketed
@@ -97,7 +99,19 @@ object FileSourceStrategy extends Strategy with Logging {
         dataColumns
           .filter(requiredAttributes.contains)
           .filterNot(partitionColumns.contains)
-      val outputSchema = readDataColumns.toStructType
+      val outputSchema = if (
+        fsRelation.sqlContext.conf.parquetNestedColumnPruningEnabled &&
+          fsRelation.fileFormat.isInstanceOf[ParquetFileFormat]
+      ) {
+        val fullSchema = readDataColumns.toStructType
+        val prunedSchema = StructType(
+          generateStructFieldsContainsNesting(projects, fullSchema))
+        // Merge single-branch fields that share a StructType, then merge in filterAttributes
+        prunedSchema.fields.map(f => StructType(Array(f))).reduceLeft(_ merge _)
+          .merge(filterAttributes.toSeq.toStructType)
+      } else {
+        readDataColumns.toStructType
+      }
       logInfo(s"Output Data Schema: ${outputSchema.simpleString(5)}")
 
       val pushedDownFilters = dataFilters.flatMap(DataSourceStrategy.translateFilter)
@@ -126,4 +140,64 @@ object FileSourceStrategy extends Strategy with Logging {
     case _ => Nil
   }
+
+  private[sql] def generateStructFieldsContainsNesting(
+      projects: Seq[Expression],
+      fullSchema: StructType): Seq[StructField] = {
+    // By traversing the project list we first generate the access path of each nested
+    // struct reference, then use those access paths to reconstruct the pruned schema.
+    // During the traversal we handle every expression related to complex types, i.e.
+    // GetArrayItem, GetArrayStructFields, GetMapValue and GetStructField.
+    def generateStructField(
+        curField: List[String],
+        node: Expression): Seq[StructField] = {
+      node match {
+        case ai: GetArrayItem =>
+          // Drop the access path collected so far to simplify array and map support;
+          // the same strategy applies to GetArrayStructFields and GetMapValue.
+          generateStructField(List.empty[String], ai.child)
+        case asf: GetArrayStructFields =>
+          generateStructField(List.empty[String], asf.child)
+        case mv: GetMapValue =>
+          generateStructField(List.empty[String], mv.child)
+        case attr: AttributeReference =>
+          // We finally reach the leaf node AttributeReference; call getFieldRecursively
+          // with the access path of the current nested struct.
+          Seq(getFieldRecursively(fullSchema, attr.name :: curField))
+        case sf: GetStructField if !sf.child.isInstanceOf[CreateNamedStruct] &&
+            !sf.child.isInstanceOf[CreateStruct] =>
+          val name = sf.name.getOrElse(sf.child.dataType match {
+            case StructType(fields) =>
+              fields(sf.ordinal).name
+          })
+          generateStructField(name :: curField, sf.child)
+        case _ =>
+          if (node.children.nonEmpty) {
+            node.children.flatMap(child => generateStructField(curField, child))
+          } else {
+            Seq.empty[StructField]
+          }
+      }
+    }
+
+    def getFieldRecursively(schema: StructType, name: List[String]): StructField = {
+      if (name.length > 1) {
+        val curField = name.head
+        val curFieldType = schema(curField)
+        curFieldType.dataType match {
+          case st: StructType =>
+            val newField = getFieldRecursively(StructType(st.fields), name.drop(1))
+            StructField(curFieldType.name, StructType(Seq(newField)),
+              curFieldType.nullable, curFieldType.metadata)
+          case _ =>
+            throw new IllegalArgumentException(s"""Field "$curField" is not a struct field.""")
+        }
+      } else {
+        schema(name.head)
+      }
+    }
+
+    projects.flatMap(p => generateStructField(List.empty[String], p))
+  }
 }
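Note on the traversal above: each project expression is reduced to a leaf-to-root access path (e.g. List("col", "s1", "s1_1") for col.s1.s1_1), and getFieldRecursively rebuilds a single-branch StructField from that path. A toy, Spark-free sketch of the path collection, illustrative only; Attr/Field/ArrayItem stand in for Catalyst's AttributeReference/GetStructField/GetArrayItem and are not part of the patch:

    sealed trait Expr
    case class Attr(name: String) extends Expr
    case class Field(child: Expr, name: String) extends Expr
    case class ArrayItem(child: Expr) extends Expr

    // Collect the access path from the leaf upward; an array (or map) access
    // resets the path, so the entire element struct is retained.
    def path(e: Expr, acc: List[String] = Nil): List[String] = e match {
      case Attr(n)      => n :: acc
      case Field(c, n)  => path(c, n :: acc)
      case ArrayItem(c) => path(c, Nil)
    }

    path(Field(Field(Attr("col"), "s1"), "s1_1"))             // List(col, s1, s1_1)
    path(Field(ArrayItem(Field(Attr("col"), "list")), "s1"))  // List(col, list)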
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index f47ec7f3963a4..c7e44717f611c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -213,6 +213,11 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val PARQUET_NESTED_COLUMN_PRUNING = SQLConfigBuilder("spark.sql.parquet.nestedColumnPruning")
+    .doc("When true, Parquet column pruning also works for nested fields.")
+    .booleanConf
+    .createWithDefault(false)
+
   val PARQUET_CACHE_METADATA = SQLConfigBuilder("spark.sql.parquet.cacheMetadata")
     .doc("Turns on caching of Parquet schema metadata. Can speed up querying of static data.")
     .booleanConf
@@ -724,6 +729,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
 
   def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP)
 
+  def parquetNestedColumnPruningEnabled: Boolean = getConf(PARQUET_NESTED_COLUMN_PRUNING)
+
   def writeLegacyParquetFormat: Boolean = getConf(PARQUET_WRITE_LEGACY_FORMAT)
 
   def inMemoryPartitionPruning: Boolean = getConf(IN_MEMORY_PARTITION_PRUNING)
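The new flag defaults to false, so existing plans are unchanged unless it is opted into. A minimal way to turn it on for experiments, using standard conf APIs, nothing specific to this patch:

    // In spark-shell or any SparkSession-based application:
    spark.conf.set("spark.sql.parquet.nestedColumnPruning", "true")
    // or equivalently via SQL:
    spark.sql("SET spark.sql.parquet.nestedColumnPruning=true")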
diff --git a/sql/core/src/test/resources/test-data/nested-struct.snappy.parquet b/sql/core/src/test/resources/test-data/nested-struct.snappy.parquet
new file mode 100644
index 0000000000000..2f02fb0dea3b4
Binary files /dev/null and b/sql/core/src/test/resources/test-data/nested-struct.snappy.parquet differ
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
index c32254d9dfde2..38943febf52f2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala
@@ -27,16 +27,15 @@ import org.apache.hadoop.mapreduce.Job
 
 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{util, InternalRow}
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
-import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet, PredicateHelper}
-import org.apache.spark.sql.catalyst.util
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, CreateNamedStruct, Expression, ExpressionSet, GetArrayItem, GetStructField, Literal, PredicateHelper}
 import org.apache.spark.sql.execution.{DataSourceScanExec, SparkPlan}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
 class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper {
@@ -442,6 +441,132 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi
     }
   }
 
+  test("[SPARK-4502] pruning nested schema by GetStructField projects") {
+    // Construct fullSchema like below:
+    // root
+    //  |-- col: struct (nullable = true)
+    //  |    |-- s1: struct (nullable = true)
+    //  |    |    |-- s1_1: long (nullable = true)
+    //  |    |    |-- s1_2: long (nullable = true)
+    //  |    |-- str: string (nullable = true)
+    //  |-- num: long (nullable = true)
+    //  |-- str: string (nullable = true)
+    val nested_s1 = StructField("s1",
+      StructType(
+        Seq(
+          StructField("s1_1", LongType, true),
+          StructField("s1_2", LongType, true)
+        )
+      ), true)
+    val flat_str = StructField("str", StringType, true)
+
+    val fullSchema = StructType(
+      Seq(
+        StructField("col", StructType(Seq(nested_s1, flat_str)), true),
+        StructField("num", LongType, true),
+        flat_str
+      ))
+
+    // Attribute of the struct column `col`
+    val colAttr = AttributeReference("col", StructType(
+      Seq(nested_s1, flat_str)), true)()
+    // Child expression of col.s1.s1_1
+    val childExp = GetStructField(
+      GetStructField(colAttr, 0, Some("s1")), 0, Some("s1_1"))
+
+    // Project list of "select num, col.s1.s1_1 as s1_1"
+    val projects = Seq(
+      AttributeReference("num", LongType, true)(),
+      Alias(childExp, "s1_1")()
+    )
+    val expectedResult =
+      Seq(
+        StructField("num", LongType, true),
+        StructField("col", StructType(
+          Seq(
+            StructField(
+              "s1",
+              StructType(Seq(StructField("s1_1", LongType, true))),
+              true)
+          )
+        ), true)
+      )
+    // Call the function generateStructFieldsContainsNesting
+    val result = FileSourceStrategy.generateStructFieldsContainsNesting(projects,
+      fullSchema)
+    assert(result == expectedResult)
+  }
+
+  test("[SPARK-4502] pruning nested schema by GetArrayItem projects") {
+    // Construct fullSchema like below:
+    // root
+    //  |-- col: struct (nullable = true)
+    //  |    |-- info_list: array (nullable = true)
+    //  |    |    |-- element: struct (containsNull = true)
+    //  |    |    |    |-- s1: struct (nullable = true)
+    //  |    |    |    |    |-- s1_1: long (nullable = true)
+    //  |    |    |    |    |-- s1_2: long (nullable = true)
+    val nested_s1 = StructField("s1",
+      StructType(
+        Seq(
+          StructField("s1_1", LongType, true),
+          StructField("s1_2", LongType, true)
+        )
+      ), true)
+    val nested_arr = StructField("info_list", ArrayType(StructType(Seq(nested_s1))), true)
+
+    val fullSchema = StructType(
+      Seq(
+        StructField("col", StructType(Seq(nested_arr)), true)
+      ))
+
+    // Attribute of the struct column `col`
+    val colAttr = AttributeReference("col", StructType(
+      Seq(nested_arr)), true)()
+    // Child expression of col.info_list[0].s1.s1_1
+    val arrayChildExp = GetStructField(
+      GetStructField(
+        GetArrayItem(
+          GetStructField(colAttr, 0, Some("info_list")),
+          Literal(0)
+        ), 0, Some("s1")
+      ), 0, Some("s1_1")
+    )
+    // Project list of "select col.info_list[0].s1.s1_1 as complex_get"
+    val projects = Seq(
+      Alias(arrayChildExp, "complex_get")()
+    )
+    // The array access resets the path, so the whole info_list field is kept.
+    val expectedResult =
+      Seq(
+        StructField("col", StructType(Seq(nested_arr)))
+      )
+    // Call the function generateStructFieldsContainsNesting
+    val result = FileSourceStrategy.generateStructFieldsContainsNesting(projects,
+      fullSchema)
+    assert(result == expectedResult)
+  }
+
+  test("[SPARK-4502] pruning nested schema with named_struct in project") {
+    val schema = new StructType()
+      .add("f0", IntegerType)
+      .add("f1", new StructType()
+        .add("f10", IntegerType))
+
+    val expr = GetStructField(
+      CreateNamedStruct(Seq(
+        Literal("f10"),
+        AttributeReference("f0", IntegerType)()
+      )),
+      0,
+      Some("f10")
+    )
+
+    val expect = new StructType()
+      .add("f0", IntegerType)
+
+    assert(FileSourceStrategy.generateStructFieldsContainsNesting(expr :: Nil, schema) == expect)
+  }
+
   test("spark.files.ignoreCorruptFiles should work in SQL") {
     val inputFile = File.createTempFile("input-", ".gz")
     try {
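Each access path above yields a single-branch StructField; FileSourceStrategy then folds them together via prunedSchema.fields.map(f => StructType(Array(f))).reduceLeft(_ merge _), so two branches under the same top-level column collapse into one struct. Hand-written StructType values illustrating the intended effect for col.s1.s1_1 plus col.str (written out by hand because StructType.merge is private[sql]):

    import org.apache.spark.sql.types._

    val branch1 = StructType(Seq(StructField("col", StructType(Seq(
      StructField("s1", StructType(Seq(StructField("s1_1", LongType)))))))))
    val branch2 = StructType(Seq(StructField("col", StructType(Seq(
      StructField("str", StringType))))))
    // Merged result passed on to the reader: col.{s1.s1_1, str}
    val merged = StructType(Seq(StructField("col", StructType(Seq(
      StructField("s1", StructType(Seq(StructField("s1_1", LongType)))),
      StructField("str", StringType))))))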
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index 4c4a7d86f2bd3..c00c071477a23 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -571,6 +571,36 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext
     }
   }
 
+  test("SPARK-4502 parquet nested fields pruning") {
+    // Schema of "test-data/nested-struct.snappy.parquet":
+    // root
+    //  |-- col: struct (nullable = true)
+    //  |    |-- s1: struct (nullable = true)
+    //  |    |    |-- s1_1: long (nullable = true)
+    //  |    |    |-- s1_2: long (nullable = true)
+    //  |    |-- str: string (nullable = true)
+    //  |-- num: long (nullable = true)
+    //  |-- str: string (nullable = true)
+    withTempView("tmp_table") {
+      val df = readResourceParquetFile("test-data/nested-struct.snappy.parquet")
+      df.createOrReplaceTempView("tmp_table")
+      // Normal case: prune down to a single nested leaf. Collect the expected
+      // rows before enabling pruning so the two runs actually differ.
+      val query1 = "select num, col.s1.s1_1 from tmp_table"
+      val result1 = sql(query1).collect().toSeq
+      withSQLConf(SQLConf.PARQUET_NESTED_COLUMN_PRUNING.key -> "true") {
+        checkAnswer(sql(query1), result1)
+      }
+      // Merge case: col.s1.s1_1 and col.str share the same top-level struct and
+      // should merge into col.[s1.s1_1, str] before being passed to Parquet.
+      val query2 = "select col.s1.s1_1, col.str from tmp_table"
+      val result2 = sql(query2).collect().toSeq
+      withSQLConf(SQLConf.PARQUET_NESTED_COLUMN_PRUNING.key -> "true") {
+        checkAnswer(sql(query2), result2)
+      }
+    }
+  }
+
   test("expand UDT in StructType") {
     val schema = new StructType().add("n", new NestedStructUDT, nullable = true)
     val expected = new StructType().add("n", new NestedStructUDT().sqlType, nullable = true)
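With the flag enabled, the pruning should also be visible in the physical plan's scan node. A rough spark-shell check; the path is a placeholder and the exact ReadSchema rendering varies by Spark version:

    spark.conf.set("spark.sql.parquet.nestedColumnPruning", "true")
    val df = spark.read.parquet("/path/to/nested-struct.snappy.parquet")
    df.selectExpr("num", "col.s1.s1_1").explain()
    // expect the FileScan's ReadSchema to list only the pruned leaves, e.g.
    // ReadSchema: struct<num:bigint,col:struct<s1:struct<s1_1:bigint>>>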