@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGe
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.StructType

/**
* Performs (external) sorting.
@@ -71,36 +72,8 @@ case class SortExec(
* should make it public.
*/
def createSorter(): UnsafeExternalRowSorter = {
val ordering = RowOrdering.create(sortOrder, output)

// The comparator for comparing prefix
val boundSortExpression = BindReferences.bindReference(sortOrder.head, output)
val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)

val canUseRadixSort = enableRadixSort && sortOrder.length == 1 &&
SortPrefixUtils.canSortFullyWithPrefix(boundSortExpression)

// The generator for prefix
val prefixExpr = SortPrefix(boundSortExpression)
val prefixProjection = UnsafeProjection.create(Seq(prefixExpr))
val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix
override def computePrefix(row: InternalRow):
UnsafeExternalRowSorter.PrefixComputer.Prefix = {
val prefix = prefixProjection.apply(row)
result.isNull = prefix.isNullAt(0)
result.value = if (result.isNull) prefixExpr.nullValue else prefix.getLong(0)
result
}
}

val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
rowSorter = UnsafeExternalRowSorter.create(
schema, ordering, prefixComparator, prefixComputer, pageSize, canUseRadixSort)

if (testSpillFrequency > 0) {
rowSorter.setTestSpillFrequency(testSpillFrequency)
}
rowSorter = SortExec.createSorter(
sortOrder, output, schema, enableRadixSort, testSpillFrequency)
rowSorter
}

@@ -206,3 +179,43 @@ case class SortExec(
override protected def withNewChildInternal(newChild: SparkPlan): SortExec =
copy(child = newChild)
}
object SortExec {
def createSorter(
sortOrder: Seq[SortOrder],
output: Seq[Attribute],
schema: StructType,
enableRadixSort: Boolean,
testSpillFrequency: Int = 0): UnsafeExternalRowSorter = {
val ordering = RowOrdering.create(sortOrder, output)

// The comparator for comparing prefix
val boundSortExpression = BindReferences.bindReference(sortOrder.head, output)
val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)

val canUseRadixSort = enableRadixSort && sortOrder.length == 1 &&
SortPrefixUtils.canSortFullyWithPrefix(boundSortExpression)

// The generator for prefix
val prefixExpr = SortPrefix(boundSortExpression)
val prefixProjection = UnsafeProjection.create(Seq(prefixExpr))
val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
private val result = new UnsafeExternalRowSorter.PrefixComputer.Prefix
override def computePrefix(row: InternalRow):
UnsafeExternalRowSorter.PrefixComputer.Prefix = {
val prefix = prefixProjection.apply(row)
result.isNull = prefix.isNullAt(0)
result.value = if (result.isNull) prefixExpr.nullValue else prefix.getLong(0)
result
}
}

val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
val rowSorter = UnsafeExternalRowSorter.create(
schema, ordering, prefixComparator, prefixComputer, pageSize, canUseRadixSort)

if (testSpillFrequency > 0) {
rowSorter.setTestSpillFrequency(testSpillFrequency)
}
rowSorter
}
}
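Note: the extracted companion helper can now be invoked outside of the SortExec operator (FileFormatWriter does so below). A minimal, hypothetical call is sketched here; it assumes an active SparkEnv, since the helper reads the memory manager's page size, and the column name is illustrative.

import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.execution.SortExec
import org.apache.spark.sql.types.{LongType, StructField, StructType}

// Build a sorter for a single ascending long key; testSpillFrequency keeps its default of 0.
val key = AttributeReference("a", LongType)()
val sorter = SortExec.createSorter(
  sortOrder = Seq(SortOrder(key, Ascending)),
  output = Seq(key),
  schema = StructType(Seq(StructField("a", LongType))),
  enableRadixSort = true)
// sorter is an UnsafeExternalRowSorter, ready to sort incoming UnsafeRows.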
@@ -23,8 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions
import org.apache.spark.sql.execution.datasources.SchemaPruning
import org.apache.spark.sql.execution.datasources.{PruneFileSourcePartitions, SchemaPruning, V1Writes}
import org.apache.spark.sql.execution.datasources.v2.{V2ScanRelationPushDown, V2Writes}
import org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters, PartitionPruning}
import org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate, ExtractPythonUDFFromAggregate, ExtractPythonUDFs}
@@ -37,7 +36,8 @@ class SparkOptimizer(

override def earlyScanPushDownRules: Seq[Rule[LogicalPlan]] =
// TODO: move SchemaPruning into catalyst
SchemaPruning :: V2ScanRelationPushDown :: V2Writes :: PruneFileSourcePartitions :: Nil
SchemaPruning :: V2ScanRelationPushDown :: V1Writes :: V2Writes ::
PruneFileSourcePartitions :: Nil

override def defaultBatches: Seq[Batch] = (preOptimizationBatches ++ super.defaultBatches :+
Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+
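Note: the V1Writes rule added to earlyScanPushDownRules above is not part of this excerpt. The following is only a rough, hypothetical sketch of the idea — wrap a V1 file write whose ordering is not yet resolved in a logical Sort and flip the new outputOrderResolved flag — simplified to partition columns and ignoring bucketing, which the real rule would also have to handle.

import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationCommand

// Illustrative only; not the actual V1Writes implementation.
object V1WritesSketch extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
    case cmd: InsertIntoHadoopFsRelationCommand if !cmd.outputOrderResolved =>
      // Simplification: only require an ascending sort on the partition columns.
      val requiredOrdering = cmd.partitionColumns.map(SortOrder(_, Ascending))
      val sortedQuery =
        if (requiredOrdering.isEmpty) cmd.query
        else Sort(requiredOrdering, global = false, cmd.query)
      cmd.copy(query = sortedQuery, outputOrderResolved = true)
  }
}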
@@ -45,6 +45,11 @@ trait DataWritingCommand extends UnaryCommand {

override final def child: LogicalPlan = query

/**
* Whether the output ordering required by this write has already been resolved,
* i.e. handled by the V1Writes and V1HiveWrites rules. Commands that need a
* specific ordering override this with `false` until those rules have run.
*/
def outputOrderResolved: Boolean = true

// Output column names of the analyzed input query plan.
def outputColumnNames: Seq[String]

@@ -140,7 +140,8 @@ case class CreateDataSourceTableAsSelectCommand(
table: CatalogTable,
mode: SaveMode,
query: LogicalPlan,
outputColumnNames: Seq[String])
outputColumnNames: Seq[String],
override val outputOrderResolved: Boolean = false)
extends DataWritingCommand {

override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
@@ -534,7 +534,8 @@ case class DataSource(
}
val resolved = cmd.copy(
partitionColumns = resolvedPartCols,
outputColumnNames = outputColumnNames)
outputColumnNames = outputColumnNames,
outputOrderResolved = true)
resolved.run(sparkSession, physicalPlan)
DataWritingCommand.propogateMetrics(sparkSession.sparkContext, resolved, metrics)
// Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring
@@ -213,7 +213,8 @@ object DataSourceAnalysis extends Rule[LogicalPlan] with CastSupport {
mode,
table,
Some(t.location),
actualQuery.output.map(_.name))
actualQuery.output.map(_.name),
false)

// For dynamic partition overwrite, we do not delete partition directories ahead.
// We write to staging directories and move to final partition directories after writing
@@ -34,9 +34,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution, UnsafeExternalRowSorter}
@@ -45,9 +43,8 @@ import org.apache.spark.sql.types.StringType
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.{SerializableConfiguration, Utils}


/** A helper object for writing FileFormat data out to a location. */
object FileFormatWriter extends Logging {
object FileFormatWriter extends Logging with V1WritesHelper {
/** Describes how output files should be placed in the filesystem. */
case class OutputSpec(
outputPath: String,
@@ -78,6 +75,7 @@ object FileFormatWriter extends Logging {
maxWriters: Int,
createSorter: () => UnsafeExternalRowSorter)

// scalastyle:off argcount
/**
* Basic work flow of this command is:
* 1. Driver side setup, including output committer initialization and data source specific
@@ -100,6 +98,7 @@ outputSpec: OutputSpec,
outputSpec: OutputSpec,
hadoopConf: Configuration,
partitionColumns: Seq[Attribute],
staticPartitionColumns: Seq[Attribute],
bucketSpec: Option[BucketSpec],
statsTrackers: Seq[WriteJobStatsTracker],
options: Map[String, String])
@@ -121,40 +120,7 @@ case attr => attr
case attr => attr
}
val empty2NullPlan = if (needConvert) ProjectExec(projectList, plan) else plan

val writerBucketSpec = bucketSpec.map { spec =>
val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)

if (options.getOrElse(BucketingUtils.optionForHiveCompatibleBucketWrite, "false") ==
"true") {
// Hive bucketed table: use `HiveHash` and bitwise-and as bucket id expression.
// Without the extra bitwise-and operation, we can get wrong bucket id when hash value of
// columns is negative. See Hive implementation in
// `org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#getBucketNumber()`.
val hashId = BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue))
val bucketIdExpression = Pmod(hashId, Literal(spec.numBuckets))

// The bucket file name prefix is following Hive, Presto and Trino conversion, so this
// makes sure Hive bucketed table written by Spark, can be read by other SQL engines.
//
// Hive: `org.apache.hadoop.hive.ql.exec.Utilities#getBucketIdFromFile()`.
// Trino: `io.trino.plugin.hive.BackgroundHiveSplitLoader#BUCKET_PATTERNS`.
val fileNamePrefix = (bucketId: Int) => f"$bucketId%05d_0_"
WriterBucketSpec(bucketIdExpression, fileNamePrefix)
} else {
// Spark bucketed table: use `HashPartitioning.partitionIdExpression` as bucket id
// expression, so that we can guarantee the data distribution is same between shuffle and
// bucketed data source, which enables us to only shuffle one side when join a bucketed
// table and a normal one.
val bucketIdExpression = HashPartitioning(bucketColumns, spec.numBuckets)
.partitionIdExpression
WriterBucketSpec(bucketIdExpression, (_: Int) => "")
}
}
val sortColumns = bucketSpec.toSeq.flatMap {
spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
}

val writerBucketSpec = getBucketSpec(bucketSpec, dataColumns, options)
val caseInsensitiveOptions = CaseInsensitiveMap(options)

val dataSchema = dataColumns.toStructType
@@ -180,20 +146,6 @@ object FileFormatWriter extends Logging {
statsTrackers = statsTrackers
)

// We should first sort by partition columns, then bucket id, and finally sorting columns.
val requiredOrdering =
partitionColumns ++ writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns
// the sort order doesn't matter
val actualOrdering = empty2NullPlan.outputOrdering.map(_.child)
val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
false
} else {
requiredOrdering.zip(actualOrdering).forall {
case (requiredOrder, childOutputOrder) =>
requiredOrder.semanticEquals(childOutputOrder)
}
}

SQLExecution.checkSQLExecutionId(sparkSession)

// propagate the description UUID into the jobs, so that committers
@@ -204,28 +156,25 @@ object FileFormatWriter extends Logging {
// prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
committer.setupJob(job)

val sortColumns = getBucketSortColumns(bucketSpec, dataColumns)
try {
val (rdd, concurrentOutputWriterSpec) = if (orderingMatched) {
(empty2NullPlan.execute(), None)
val maxWriters = sparkSession.sessionState.conf.maxConcurrentOutputFileWriters
val concurrentWritersEnabled = maxWriters > 0 && sortColumns.isEmpty
val rdd = empty2NullPlan.execute()
val concurrentOutputWriterSpec = if (concurrentWritersEnabled) {
val enableRadixSort = sparkSession.sessionState.conf.enableRadixSort
val output = empty2NullPlan.output
val outputSchema = empty2NullPlan.schema
Some(ConcurrentOutputWriterSpec(maxWriters,
() => SortExec.createSorter(
getSortOrder(output, partitionColumns, staticPartitionColumns,
bucketSpec, options),
output,
outputSchema,
enableRadixSort
)))
} else {
// SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and
// the physical plan may have different attribute ids due to optimizer removing some
// aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch.
val orderingExpr = bindReferences(
requiredOrdering.map(SortOrder(_, Ascending)), outputSpec.outputColumns)
val sortPlan = SortExec(
orderingExpr,
global = false,
child = empty2NullPlan)

val maxWriters = sparkSession.sessionState.conf.maxConcurrentOutputFileWriters
val concurrentWritersEnabled = maxWriters > 0 && sortColumns.isEmpty
if (concurrentWritersEnabled) {
(empty2NullPlan.execute(),
Some(ConcurrentOutputWriterSpec(maxWriters, () => sortPlan.createSorter())))
} else {
(sortPlan.execute(), None)
}
None
}

// SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single
@@ -274,6 +223,7 @@ object FileFormatWriter extends Logging {
throw QueryExecutionErrors.jobAbortedError(cause)
}
}
// scalastyle:on argcount

/** Writes data out in a single Spark task. */
private def executeTask(
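Note: getBucketSpec, getSortOrder and getBucketSortColumns above come from the new V1WritesHelper trait, which is not shown in this excerpt (the real getSortOrder also takes the write options, used for Hive-compatible bucketing). Below is a rough reconstruction of what getSortOrder could compute, based on the ordering logic removed from write(); the names and behavior are assumptions and the actual helper may differ.

import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning

object V1WritesHelperSketch {
  // Partition columns first, then the bucket id expression, then the bucket sort
  // columns. Static partition columns are constant for a given write, so they are
  // dropped from the required ordering.
  def getSortOrder(
      output: Seq[Attribute],
      partitionColumns: Seq[Attribute],
      staticPartitionColumns: Seq[Attribute],
      bucketSpec: Option[BucketSpec]): Seq[SortOrder] = {
    val dataColumns = output.filterNot(partitionColumns.contains)
    val dynamicPartitionColumns = partitionColumns.filterNot(staticPartitionColumns.contains)
    val bucketIdExpression = bucketSpec.map { spec =>
      val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
      // Same expression Spark uses when shuffling bucketed data, so the bucket id
      // matches between shuffles and bucketed files.
      HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
    }
    val sortColumns = bucketSpec.toSeq.flatMap(
      _.sortColumnNames.map(c => dataColumns.find(_.name == c).get))
    (dynamicPartitionColumns ++ bucketIdExpression ++ sortColumns).map(SortOrder(_, Ascending))
  }
}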
@@ -56,7 +56,8 @@ case class InsertIntoHadoopFsRelationCommand(
mode: SaveMode,
catalogTable: Option[CatalogTable],
fileIndex: Option[FileIndex],
outputColumnNames: Seq[String])
outputColumnNames: Seq[String],
override val outputOrderResolved: Boolean = false)
extends DataWritingCommand {

private lazy val parameters = CaseInsensitiveMap(options)
@@ -181,6 +182,7 @@ case class InsertIntoHadoopFsRelationCommand(
committerOutputPath.toString, customPartitionLocations, outputColumns),
hadoopConf = hadoopConf,
partitionColumns = partitionColumns,
staticPartitionColumns = partitionColumns.take(staticPartitions.size),
bucketSpec = bucketSpec,
statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)),
options = options)