diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala index 5abc6f3ed5769..1d34e07ccd17c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGe import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.types.StructType /** * Performs (external) sorting. @@ -71,36 +72,8 @@ case class SortExec( * should make it public. */ def createSorter(): UnsafeExternalRowSorter = { - val ordering = RowOrdering.create(sortOrder, output) - - // The comparator for comparing prefix - val boundSortExpression = BindReferences.bindReference(sortOrder.head, output) - val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) - - val canUseRadixSort = enableRadixSort && sortOrder.length == 1 && - SortPrefixUtils.canSortFullyWithPrefix(boundSortExpression) - - // The generator for prefix - val prefixExpr = SortPrefix(boundSortExpression) - val prefixProjection = UnsafeProjection.create(Seq(prefixExpr)) - val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { - private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix - override def computePrefix(row: InternalRow): - UnsafeExternalRowSorter.PrefixComputer.Prefix = { - val prefix = prefixProjection.apply(row) - result.isNull = prefix.isNullAt(0) - result.value = if (result.isNull) prefixExpr.nullValue else prefix.getLong(0) - result - } - } - - val pageSize = SparkEnv.get.memoryManager.pageSizeBytes - rowSorter = UnsafeExternalRowSorter.create( - schema, ordering, prefixComparator, prefixComputer, pageSize, canUseRadixSort) - - if 
(testSpillFrequency > 0) { - rowSorter.setTestSpillFrequency(testSpillFrequency) - } + rowSorter = SortExec.createSorter( + sortOrder, output, schema, enableRadixSort, testSpillFrequency) rowSorter } @@ -206,3 +179,43 @@ case class SortExec( override protected def withNewChildInternal(newChild: SparkPlan): SortExec = copy(child = newChild) } +object SortExec { + def createSorter( + sortOrder: Seq[SortOrder], + output: Seq[Attribute], + schema: StructType, + enableRadixSort: Boolean, + testSpillFrequency: Int = 0): UnsafeExternalRowSorter = { + val ordering = RowOrdering.create(sortOrder, output) + + // The comparator for comparing prefix + val boundSortExpression = BindReferences.bindReference(sortOrder.head, output) + val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression) + + val canUseRadixSort = enableRadixSort && sortOrder.length == 1 && + SortPrefixUtils.canSortFullyWithPrefix(boundSortExpression) + + // The generator for prefix + val prefixExpr = SortPrefix(boundSortExpression) + val prefixProjection = UnsafeProjection.create(Seq(prefixExpr)) + val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer { + private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix + override def computePrefix(row: InternalRow): + UnsafeExternalRowSorter.PrefixComputer.Prefix = { + val prefix = prefixProjection.apply(row) + result.isNull = prefix.isNullAt(0) + result.value = if (result.isNull) prefixExpr.nullValue else prefix.getLong(0) + result + } + } + + val pageSize = SparkEnv.get.memoryManager.pageSizeBytes + val rowSorter = UnsafeExternalRowSorter.create( + schema, ordering, prefixComparator, prefixComputer, pageSize, canUseRadixSort) + + if (testSpillFrequency > 0) { + rowSorter.setTestSpillFrequency(testSpillFrequency) + } + rowSorter + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 
dc3ceb5c595d0..4ffed07731ed9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -23,8 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.CatalogManager -import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions -import org.apache.spark.sql.execution.datasources.SchemaPruning +import org.apache.spark.sql.execution.datasources.{PruneFileSourcePartitions, SchemaPruning, V1Writes} import org.apache.spark.sql.execution.datasources.v2.{V2ScanRelationPushDown, V2Writes} import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning} import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs} @@ -37,7 +36,8 @@ class SparkOptimizer( override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] = // TODO: move SchemaPruning into catalyst - SchemaPruning :: V2ScanRelationPushDown :: V2Writes :: PruneFileSourcePartitions :: Nil + SchemaPruning :: V2ScanRelationPushDown :: V1Writes :: V2Writes :: + PruneFileSourcePartitions :: Nil override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala index 338ce8cac420f..b1ff9481ab353 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala @@ -45,6 +45,11 @@ trait DataWritingCommand extends
UnaryCommand { override final def child: LogicalPlan = query + /** + * Whether the write ordering required by this command has already been resolved by the + * V1Writes / V1HiveWrites optimizer rules. + */ + def outputOrderResolved: Boolean = true + // Output column names of the analyzed input query plan. def outputColumnNames: Seq[String] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index e64426f8de8f3..5fbda46d0f538 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -140,7 +140,8 @@ case class CreateDataSourceTableAsSelectCommand( table: CatalogTable, mode: SaveMode, query: LogicalPlan, - outputColumnNames: Seq[String]) + outputColumnNames: Seq[String], + override val outputOrderResolved: Boolean = false) extends DataWritingCommand { override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 99361264148df..a06fda3a89ef6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -534,7 +534,8 @@ case class DataSource( } val resolved = cmd.copy( partitionColumns = resolvedPartCols, - outputColumnNames = outputColumnNames) + outputColumnNames = outputColumnNames, + outputOrderResolved = true) resolved.run(sparkSession, physicalPlan) DataWritingCommand.propogateMetrics(sparkSession.sparkContext, resolved, metrics) // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 6febbd590f246..9ef6b16fc35b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -213,7 +213,8 @@ object DataSourceAnalysis extends Rule[LogicalPlan] with CastSupport { mode, table, Some(t.location), - actualQuery.output.map(_.name)) + actualQuery.output.map(_.name), + false) // For dynamic partition overwrite, we do not delete partition directories ahead. // We write to staging directories and move to final partition directories after writing diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 409e33448acf8..97e84ba8d664b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -34,9 +34,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution, UnsafeExternalRowSorter} @@ -45,9 +43,8 @@ import org.apache.spark.sql.types.StringType import org.apache.spark.unsafe.types.UTF8String import 
org.apache.spark.util.{SerializableConfiguration, Utils} - /** A helper object for writing FileFormat data out to a location. */ -object FileFormatWriter extends Logging { +object FileFormatWriter extends Logging with V1WritesHelper { /** Describes how output files should be placed in the filesystem. */ case class OutputSpec( outputPath: String, @@ -78,6 +75,7 @@ object FileFormatWriter extends Logging { maxWriters: Int, createSorter: () => UnsafeExternalRowSorter) + // scalastyle:off argcount /** * Basic work flow of this command is: * 1. Driver side setup, including output committer initialization and data source specific @@ -100,6 +98,7 @@ object FileFormatWriter extends Logging { outputSpec: OutputSpec, hadoopConf: Configuration, partitionColumns: Seq[Attribute], + staticPartitionColumns: Seq[Attribute], bucketSpec: Option[BucketSpec], statsTrackers: Seq[WriteJobStatsTracker], options: Map[String, String]) @@ -121,40 +120,7 @@ object FileFormatWriter extends Logging { case attr => attr } val empty2NullPlan = if (needConvert) ProjectExec(projectList, plan) else plan - - val writerBucketSpec = bucketSpec.map { spec => - val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get) - - if (options.getOrElse(BucketingUtils.optionForHiveCompatibleBucketWrite, "false") == - "true") { - // Hive bucketed table: use `HiveHash` and bitwise-and as bucket id expression. - // Without the extra bitwise-and operation, we can get wrong bucket id when hash value of - // columns is negative. See Hive implementation in - // `org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#getBucketNumber()`. - val hashId = BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue)) - val bucketIdExpression = Pmod(hashId, Literal(spec.numBuckets)) - - // The bucket file name prefix is following Hive, Presto and Trino conversion, so this - // makes sure Hive bucketed table written by Spark, can be read by other SQL engines. 
- // - // Hive: `org.apache.hadoop.hive.ql.exec.Utilities#getBucketIdFromFile()`. - // Trino: `io.trino.plugin.hive.BackgroundHiveSplitLoader#BUCKET_PATTERNS`. - val fileNamePrefix = (bucketId: Int) => f"$bucketId%05d_0_" - WriterBucketSpec(bucketIdExpression, fileNamePrefix) - } else { - // Spark bucketed table: use `HashPartitioning.partitionIdExpression` as bucket id - // expression, so that we can guarantee the data distribution is same between shuffle and - // bucketed data source, which enables us to only shuffle one side when join a bucketed - // table and a normal one. - val bucketIdExpression = HashPartitioning(bucketColumns, spec.numBuckets) - .partitionIdExpression - WriterBucketSpec(bucketIdExpression, (_: Int) => "") - } - } - val sortColumns = bucketSpec.toSeq.flatMap { - spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get) - } - + val writerBucketSpec = getBucketSpec(bucketSpec, dataColumns, options) val caseInsensitiveOptions = CaseInsensitiveMap(options) val dataSchema = dataColumns.toStructType @@ -180,20 +146,6 @@ object FileFormatWriter extends Logging { statsTrackers = statsTrackers ) - // We should first sort by partition columns, then bucket id, and finally sorting columns. - val requiredOrdering = - partitionColumns ++ writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns - // the sort order doesn't matter - val actualOrdering = empty2NullPlan.outputOrdering.map(_.child) - val orderingMatched = if (requiredOrdering.length > actualOrdering.length) { - false - } else { - requiredOrdering.zip(actualOrdering).forall { - case (requiredOrder, childOutputOrder) => - requiredOrder.semanticEquals(childOutputOrder) - } - } - SQLExecution.checkSQLExecutionId(sparkSession) // propagate the description UUID into the jobs, so that committers @@ -204,28 +156,25 @@ object FileFormatWriter extends Logging { // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. 
committer.setupJob(job) + val sortColumns = getBucketSortColumns(bucketSpec, dataColumns) try { - val (rdd, concurrentOutputWriterSpec) = if (orderingMatched) { - (empty2NullPlan.execute(), None) + val maxWriters = sparkSession.sessionState.conf.maxConcurrentOutputFileWriters + val concurrentWritersEnabled = maxWriters > 0 && sortColumns.isEmpty + val rdd = empty2NullPlan.execute() + val concurrentOutputWriterSpec = if (concurrentWritersEnabled) { + val enableRadixSort = sparkSession.sessionState.conf.enableRadixSort + val output = empty2NullPlan.output + val outputSchema = empty2NullPlan.schema + Some(ConcurrentOutputWriterSpec(maxWriters, + () => SortExec.createSorter( + getSortOrder(output, partitionColumns, staticPartitionColumns, + bucketSpec, options), + output, + outputSchema, + enableRadixSort + ))) } else { - // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and - // the physical plan may have different attribute ids due to optimizer removing some - // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch. - val orderingExpr = bindReferences( - requiredOrdering.map(SortOrder(_, Ascending)), outputSpec.outputColumns) - val sortPlan = SortExec( - orderingExpr, - global = false, - child = empty2NullPlan) - - val maxWriters = sparkSession.sessionState.conf.maxConcurrentOutputFileWriters - val concurrentWritersEnabled = maxWriters > 0 && sortColumns.isEmpty - if (concurrentWritersEnabled) { - (empty2NullPlan.execute(), - Some(ConcurrentOutputWriterSpec(maxWriters, () => sortPlan.createSorter()))) - } else { - (sortPlan.execute(), None) - } + None } // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single @@ -274,6 +223,7 @@ object FileFormatWriter extends Logging { throw QueryExecutionErrors.jobAbortedError(cause) } } + // scalastyle:on argcount /** Writes data out in a single Spark task. 
*/ private def executeTask( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 267b360b474ca..51b8d8503b7c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -56,7 +56,8 @@ case class InsertIntoHadoopFsRelationCommand( mode: SaveMode, catalogTable: Option[CatalogTable], fileIndex: Option[FileIndex], - outputColumnNames: Seq[String]) + outputColumnNames: Seq[String], + override val outputOrderResolved: Boolean = false) extends DataWritingCommand { private lazy val parameters = CaseInsensitiveMap(options) @@ -181,6 +182,7 @@ case class InsertIntoHadoopFsRelationCommand( committerOutputPath.toString, customPartitionLocations, outputColumns), hadoopConf = hadoopConf, partitionColumns = partitionColumns, + staticPartitionColumns = partitionColumns.take(staticPartitions.size), bucketSpec = bucketSpec, statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)), options = options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala new file mode 100644 index 0000000000000..4930962c06ed9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/V1Writes.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, AttributeSet, BitwiseAnd, HiveHash, Literal, Pmod, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sort} +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.command.CreateDataSourceTableAsSelectCommand +import org.apache.spark.sql.internal.SQLConf + +/** + * A rule that constructs logical writes for datasource v1. 
+ */ +object V1Writes extends Rule[LogicalPlan] with V1WritesHelper { + override def apply(plan: LogicalPlan): LogicalPlan = plan match { + case i @ InsertIntoHadoopFsRelationCommand(_, staticPartitions, _, partitionColumns, + bucketSpec, _, options, query, _, _, _, _, false) => + i.copy( + query = + prepareQuery( + query, + i.outputColumns, + partitionColumns, + partitionColumns.take(staticPartitions.size), + bucketSpec, + options), + outputOrderResolved = true) + + case c @ CreateDataSourceTableAsSelectCommand(table, _, query, _, false) => + val partitionColumns = table.partitionColumnNames.map { name => + query.resolve(name :: Nil, SparkSession.active.sessionState.analyzer.resolver).getOrElse { + throw QueryCompilationErrors.cannotResolveAttributeError( + name, query.output.map(_.name).mkString(", ")) + }.asInstanceOf[Attribute] + } + c.copy( + query = + prepareQuery( + query, + c.outputColumns, + partitionColumns, + Seq.empty, + table.bucketSpec, + table.storage.properties + ), + outputOrderResolved = true) + + case _ => plan + } +} + +trait V1WritesHelper { + + def getBucketSpec( + bucketSpec: Option[BucketSpec], + dataColumns: Seq[Attribute], + options: Map[String, String]): Option[WriterBucketSpec] = { + bucketSpec.map { spec => + val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get) + if (options.getOrElse(BucketingUtils.optionForHiveCompatibleBucketWrite, "false") == + "true") { + // Hive bucketed table: use `HiveHash` and bitwise-and as bucket id expression. + // Without the extra bitwise-and operation, we can get wrong bucket id when hash value of + // columns is negative. See Hive implementation in + // `org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#getBucketNumber()`. 
+ val hashId = BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue)) + val bucketIdExpression = Pmod(hashId, Literal(spec.numBuckets)) + + // The bucket file name prefix is following Hive, Presto and Trino conversion, so this + // makes sure Hive bucketed table written by Spark, can be read by other SQL engines. + // + // Hive: `org.apache.hadoop.hive.ql.exec.Utilities#getBucketIdFromFile()`. + // Trino: `io.trino.plugin.hive.BackgroundHiveSplitLoader#BUCKET_PATTERNS`. + val fileNamePrefix = (bucketId: Int) => f"$bucketId%05d_0_" + WriterBucketSpec(bucketIdExpression, fileNamePrefix) + } else { + // Spark bucketed table: use `HashPartitioning.partitionIdExpression` as bucket id + // expression, so that we can guarantee the data distribution is same between shuffle and + // bucketed data source, which enables us to only shuffle one side when join a bucketed + // table and a normal one. + val bucketIdExpression = HashPartitioning(bucketColumns, spec.numBuckets) + .partitionIdExpression + WriterBucketSpec(bucketIdExpression, (_: Int) => "") + } + } + } + + def getBucketSortColumns( + bucketSpec: Option[BucketSpec], dataColumns: Seq[Attribute]): Seq[Attribute] = { + bucketSpec.toSeq.flatMap { + spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get) + } + } + + def getSortOrder( + outputColumns: Seq[Attribute], + partitionColumns: Seq[Attribute], + staticPartitions: Seq[Attribute], + bucketSpec: Option[BucketSpec], + options: Map[String, String]): Seq[SortOrder] = { + val partitionSet = AttributeSet(partitionColumns) + val dataColumns = outputColumns.filterNot(partitionSet.contains) + val writerBucketSpec = getBucketSpec(bucketSpec, dataColumns, options) + val sortColumns = getBucketSortColumns(bucketSpec, dataColumns) + + assert(partitionColumns.size >= staticPartitions.size) + // We should first sort by partition columns, then bucket id, and finally sorting columns. 
+ (partitionColumns.drop(staticPartitions.size) ++ + writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns) + .map(SortOrder(_, Ascending)) + } + + def prepareQuery( + query: LogicalPlan, + outputColumns: Seq[Attribute], + partitionColumns: Seq[Attribute], + staticPartitions: Seq[Attribute], + bucketSpec: Option[BucketSpec], + options: Map[String, String]): LogicalPlan = { + val requiredOrdering = getSortOrder( + outputColumns, partitionColumns, staticPartitions, bucketSpec, options) + val actualOrdering = query.outputOrdering + val orderingMatched = if (requiredOrdering.length > actualOrdering.length) { + false + } else { + requiredOrdering.zip(actualOrdering).forall { + case (requiredOrder, childOutputOrder) => + requiredOrder.semanticEquals(childOutputOrder) + } + } + + val partitionSet = AttributeSet(partitionColumns) + val dataColumns = outputColumns.filterNot(partitionSet.contains) + val sortColumns = getBucketSortColumns(bucketSpec, dataColumns) + if (orderingMatched || + (SQLConf.get.maxConcurrentOutputFileWriters > 0 && sortColumns.isEmpty)) { + query + } else { + Sort(requiredOrdering, global = false, query) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 5058a1dfc3baf..91dff9b4339d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -176,6 +176,7 @@ class FileStreamSink( outputSpec = FileFormatWriter.OutputSpec(path, Map.empty, qe.analyzed.output), hadoopConf = hadoopConf, partitionColumns = partitionColumns, + staticPartitionColumns = Seq.empty, bucketSpec = None, statsTrackers = Seq(basicWriteJobStatsTracker), options = options) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 471f2c2303048..01da6843228f1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.execution.datasources.v2.TableCapabilityCheck import org.apache.spark.sql.execution.streaming.ResolveWriteToStream import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.hive.client.HiveClient -import org.apache.spark.sql.hive.execution.PruneHiveTablePartitions +import org.apache.spark.sql.hive.execution.{PruneHiveTablePartitions, V1HiveWrites} import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState, SparkUDFExpressionBuilder} import org.apache.spark.util.Utils @@ -113,7 +113,7 @@ class HiveSessionStateBuilder( } override def customEarlyScanPushDownRules: Seq[Rule[LogicalPlan]] = - Seq(new PruneHiveTablePartitions(session)) + V1HiveWrites :: new PruneHiveTablePartitions(session) :: Nil /** * Planner that takes into account Hive-specific strategies. 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index 96b41dd8e35fa..f392c14e3d019 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -115,7 +115,8 @@ case class CreateHiveTableAsSelectCommand( tableDesc: CatalogTable, query: LogicalPlan, outputColumnNames: Seq[String], - mode: SaveMode) + mode: SaveMode, + override val outputOrderResolved: Boolean = false) extends CreateHiveTableAsSelectBase { override def getWritingCommand( @@ -152,7 +153,8 @@ case class OptimizedCreateHiveTableAsSelectCommand( tableDesc: CatalogTable, query: LogicalPlan, outputColumnNames: Seq[String], - mode: SaveMode) + mode: SaveMode, + override val outputOrderResolved: Boolean = false) extends CreateHiveTableAsSelectBase { override def getWritingCommand( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 7484a1edcd92f..3e90684cebd92 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -17,20 +17,15 @@ package org.apache.spark.sql.hive.execution -import java.util.Locale - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.hive.ql.ErrorMsg import org.apache.hadoop.hive.ql.plan.TableDesc -import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.Attribute 
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap -import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} +import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.CommandUtils import org.apache.spark.sql.hive.HiveExternalCatalog @@ -75,8 +70,9 @@ case class InsertIntoHiveTable( query: LogicalPlan, overwrite: Boolean, ifPartitionNotExists: Boolean, - outputColumnNames: Seq[String]) extends SaveAsHiveFile { - + outputColumnNames: Seq[String], + override val outputOrderResolved: Boolean = false) + extends SaveAsHiveFile with V1HiveWritesHelper { /** * Inserts all the rows in the table into Hive. Row objects are properly serialized with the * `org.apache.hadoop.hive.serde2.SerDe` and the @@ -130,55 +126,7 @@ case class InsertIntoHiveTable( tmpLocation: Path, child: SparkPlan): Unit = { val fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) - - val numDynamicPartitions = partition.values.count(_.isEmpty) - val numStaticPartitions = partition.values.count(_.nonEmpty) - val partitionSpec = partition.map { - case (key, Some(null)) => key -> ExternalCatalogUtils.DEFAULT_PARTITION_NAME - case (key, Some(value)) => key -> value - case (key, None) => key -> "" - } - - // All partition column names in the format of "<column name 1>/<column name 2>/..." 
- val partitionColumns = fileSinkConf.getTableInfo.getProperties.getProperty("partition_columns") - val partitionColumnNames = Option(partitionColumns).map(_.split("/")).getOrElse(Array.empty) - - // By this time, the partition map must match the table's partition columns - if (partitionColumnNames.toSet != partition.keySet) { - throw QueryExecutionErrors.requestedPartitionsMismatchTablePartitionsError(table, partition) - } - - // Validate partition spec if there exist any dynamic partitions - if (numDynamicPartitions > 0) { - // Report error if dynamic partitioning is not enabled - if (!hadoopConf.get("hive.exec.dynamic.partition", "true").toBoolean) { - throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_DISABLED.getMsg) - } - - // Report error if dynamic partition strict mode is on but no static partition is found - if (numStaticPartitions == 0 && - hadoopConf.get("hive.exec.dynamic.partition.mode", "strict").equalsIgnoreCase("strict")) { - throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_STRICT_MODE.getMsg) - } - - // Report error if any static partition appears after a dynamic partition - val isDynamic = partitionColumnNames.map(partitionSpec(_).isEmpty) - if (isDynamic.init.zip(isDynamic.tail).contains((true, false))) { - throw new AnalysisException(ErrorMsg.PARTITION_DYN_STA_ORDER.getMsg) - } - } - - val partitionAttributes = partitionColumnNames.takeRight(numDynamicPartitions).map { name => - val attr = query.resolve(name :: Nil, sparkSession.sessionState.analyzer.resolver).getOrElse { - throw QueryCompilationErrors.cannotResolveAttributeError( - name, query.output.map(_.name).mkString(", ")) - }.asInstanceOf[Attribute] - // SPARK-28054: Hive metastore is not case preserving and keeps partition columns - // with lower cased names. Hive will validate the column names in the partition directories - // during `loadDynamicPartitions`. Spark needs to write partition directories with lower-cased - // column names in order to make `loadDynamicPartitions` work. 
- attr.withName(name.toLowerCase(Locale.ROOT)) - } + val partitionAttributes = getPartitionColumns(table, query, partition) val writtenParts = saveAsHiveFile( sparkSession = sparkSession, @@ -189,6 +137,8 @@ case class InsertIntoHiveTable( partitionAttributes = partitionAttributes, bucketSpec = table.bucketSpec) + val partitionSpec = getPartitionSpec(partition) + val numDynamicPartitions = partition.values.count(_.isEmpty) if (partition.nonEmpty) { if (numDynamicPartitions > 0) { if (overwrite && table.tableType == CatalogTableType.EXTERNAL) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index 7f885729bd2be..22f9bb2f881e2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -37,13 +37,13 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.DataWritingCommand -import org.apache.spark.sql.execution.datasources.{BucketingUtils, FileFormatWriter} +import org.apache.spark.sql.execution.datasources.FileFormatWriter import org.apache.spark.sql.hive.HiveExternalCatalog import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive.client.HiveVersion // Base trait from which all hive insert statement physical execution extends. 
-private[hive] trait SaveAsHiveFile extends DataWritingCommand { +private[hive] trait SaveAsHiveFile extends DataWritingCommand with V1HiveWritesHelper { var createdTempDir: Option[Path] = None @@ -86,10 +86,6 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { jobId = java.util.UUID.randomUUID().toString, outputPath = outputLocation) - val options = bucketSpec - .map(_ => Map(BucketingUtils.optionForHiveCompatibleBucketWrite -> "true")) - .getOrElse(Map.empty) - FileFormatWriter.write( sparkSession = sparkSession, plan = plan, @@ -99,9 +95,10 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { FileFormatWriter.OutputSpec(outputLocation, customPartitionLocations, outputColumns), hadoopConf = hadoopConf, partitionColumns = partitionAttributes, + staticPartitionColumns = Seq.empty, bucketSpec = bucketSpec, statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)), - options = options) + options = options(bucketSpec)) } protected def getExternalTmpPath( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/V1HiveWrites.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/V1HiveWrites.scala new file mode 100644 index 0000000000000..54387dc8a09cd --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/V1HiveWrites.scala @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import java.util.Locale + +import org.apache.hadoop.hive.ql.ErrorMsg +import org.apache.hadoop.hive.ql.plan.TableDesc + +import org.apache.spark.SparkException +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, ExternalCatalogUtils} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} +import org.apache.spark.sql.execution.datasources.{BucketingUtils, V1WritesHelper} +import org.apache.spark.sql.hive.client.HiveClientImpl + +/** + * A rule that constructs logical writes for Hive. 
+ */ +object V1HiveWrites extends Rule[LogicalPlan] with V1WritesHelper with V1HiveWritesHelper { + override def apply(plan: LogicalPlan): LogicalPlan = plan match { + case i @ InsertIntoHiveTable(table, partition, query, _, _, _, false) => + val partitionCols = getPartitionColumns(table, query, partition) + i.copy( + query = + prepareQuery(i.query, + i.outputColumns, + partitionCols, + partitionCols.take(partition.count(_._2.isDefined)), + table.bucketSpec, + options(table.bucketSpec)), + outputOrderResolved = true) + + case c @ CreateHiveTableAsSelectCommand(tableDesc, query, _, _, false) => + // if table is not exists the schema should always be empty + val table = if (tableDesc.schema.isEmpty) { + val tableSchema = CharVarcharUtils.getRawSchema(c.outputColumns.toStructType, conf) + tableDesc.copy(schema = tableSchema) + } else { + tableDesc + } + // For CTAS, there is no static partition values to insert. + val partition = tableDesc.partitionColumnNames.map(_ -> None).toMap + c.copy( + query = + prepareQuery(query, + c.outputColumns, + getPartitionColumns(table, query, partition), + Seq.empty, + tableDesc.bucketSpec, + options(tableDesc.bucketSpec)), + outputOrderResolved = true) + + // OptimizedCreateHiveTableAsSelectCommand does not support partitioned table + case c @ OptimizedCreateHiveTableAsSelectCommand(tableDesc, query, _, _, false) => + c.copy( + query = + prepareQuery(query, + c.outputColumns, + Seq.empty, + Seq.empty, + tableDesc.bucketSpec, + options(tableDesc.bucketSpec)), + outputOrderResolved = true) + + case _ => plan + } +} + +trait V1HiveWritesHelper { + def options(bucketSpec: Option[BucketSpec]): Map[String, String] = { + bucketSpec + .map(_ => Map(BucketingUtils.optionForHiveCompatibleBucketWrite -> "true")) + .getOrElse(Map.empty) + } + + def getPartitionSpec(partition: Map[String, Option[String]]): Map[String, String] = { + partition.map { + case (key, Some(null)) => key -> ExternalCatalogUtils.DEFAULT_PARTITION_NAME + case (key, 
Some(value)) => key -> value + case (key, None) => key -> "" + } + } + + def getPartitionColumns( + table: CatalogTable, + query: LogicalPlan, + partition: Map[String, Option[String]]): Seq[Attribute] = { + val hadoopConf = SparkSession.active.sessionState.newHadoopConf() + val numStaticPartitions = partition.values.count(_.nonEmpty) + val numDynamicPartitions = partition.values.count(_.isEmpty) + + val hiveQlTable = HiveClientImpl.toHiveTable(table) + val tableDesc = new TableDesc( + hiveQlTable.getInputFormatClass, + hiveQlTable.getOutputFormatClass, + hiveQlTable.getMetadata + ) + // All partition column names in the format of "<column name 1>/<column name 2>/..." + val partitionColumns = tableDesc.getProperties.getProperty("partition_columns") + val partitionColumnNames = Option(partitionColumns).map(_.split("/")).getOrElse(Array.empty) + val partitionSpec = getPartitionSpec(partition) + + // By this time, the partition map must match the table's partition columns + if (partitionColumnNames.toSet != partition.keySet) { + throw QueryExecutionErrors.requestedPartitionsMismatchTablePartitionsError(table, partition) + } + + // Validate partition spec if there exist any dynamic partitions + if (numDynamicPartitions > 0) { + // Report error if dynamic partitioning is not enabled + if (!hadoopConf.get("hive.exec.dynamic.partition", "true").toBoolean) { + throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_DISABLED.getMsg) + } + + // Report error if dynamic partition strict mode is on but no static partition is found + if (numStaticPartitions == 0 && + hadoopConf.get("hive.exec.dynamic.partition.mode", "strict").equalsIgnoreCase("strict")) { + throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_STRICT_MODE.getMsg) + } + + // Report error if any static partition appears after a dynamic partition + val isDynamic = partitionColumnNames.map(partitionSpec(_).isEmpty) + if (isDynamic.init.zip(isDynamic.tail).contains((true, false))) { + throw new AnalysisException(ErrorMsg.PARTITION_DYN_STA_ORDER.getMsg) 
+ } + } + + partitionColumnNames.takeRight(numDynamicPartitions).map { name => + val attr = query.resolve(name :: Nil, SparkSession.active.sessionState.analyzer.resolver) + .getOrElse { + throw QueryCompilationErrors.cannotResolveAttributeError( + name, query.output.map(_.name).mkString(", ")) + }.asInstanceOf[Attribute] + // SPARK-28054: Hive metastore is not case preserving and keeps partition columns + // with lower cased names. Hive will validate the column names in the partition directories + // during `loadDynamicPartitions`. Spark needs to write partition directories with lower-cased + // column names in order to make `loadDynamicPartitions` work. + attr.withName(name.toLowerCase(Locale.ROOT)) + } + } +}