SortShuffleManager.scala
@@ -157,7 +157,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
metrics,
shuffleExecutorComponents.writes())
case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
new SortShuffleWriter(
shuffleBlockResolver, other, mapId, context, shuffleExecutorComponents.writes())
}
}

SortShuffleWriter.scala
@@ -18,18 +18,18 @@
package org.apache.spark.shuffle.sort

import org.apache.spark._
import org.apache.spark.api.shuffle.ShuffleWriteSupport
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter}
import org.apache.spark.storage.ShuffleBlockId
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.ExternalSorter

private[spark] class SortShuffleWriter[K, V, C](
shuffleBlockResolver: IndexShuffleBlockResolver,
handle: BaseShuffleHandle[K, V, C],
mapId: Int,
context: TaskContext)
context: TaskContext,
writeSupport: ShuffleWriteSupport)
extends ShuffleWriter[K, V] with Logging {

private val dep = handle.dependency
@@ -64,18 +64,11 @@ private[spark] class SortShuffleWriter[K, V, C](
// Don't bother including the time to open the merged output file in the shuffle write time,
// because it just opens a single file, so is typically too fast to measure accurately
// (see SPARK-3570).
val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
val tmp = Utils.tempFileWith(output)
try {
val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
} finally {
if (tmp.exists() && !tmp.delete()) {
logError(s"Error while deleting temp file ${tmp.getAbsolutePath}")
}
}
val mapOutputWriter = writeSupport.createMapOutputWriter(
dep.shuffleId, mapId, dep.partitioner.numPartitions)
val partitionLengths = sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
mapOutputWriter.commitAllPartitions()
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
}

/** Close this writer, passing along whether the map completed */
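For context, the rewritten write() path above goes entirely through the new shuffle plugin surface. Below is a minimal sketch of what those interfaces look like, inferred only from the imports and call sites in this diff; the exact signatures (and whether they are defined in Java or Scala upstream) are assumptions:

package org.apache.spark.api.shuffle

import java.io.OutputStream

// Entry point handed to SortShuffleWriter via shuffleExecutorComponents.writes().
trait ShuffleWriteSupport {
  // One map output writer per (shuffleId, mapId), expecting numPartitions reduce partitions.
  def createMapOutputWriter(shuffleId: Int, mapId: Int, numPartitions: Int): ShuffleMapOutputWriter
}

// Hands out one partition writer per reduce partition, in ascending partition order,
// and commits the whole map output once every partition has been written.
trait ShuffleMapOutputWriter {
  def getNextPartitionWriter: ShufflePartitionWriter
  def commitAllPartitions(): Unit
}

// Byte sink for a single reduce partition of this map task's output.
trait ShufflePartitionWriter {
  def toStream: OutputStream    // lazily opens the underlying output stream
  def getNumBytesWritten: Long  // feeds the partition lengths reported in MapStatus
  def close(): Unit
}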
DiskBlockObjectWriter.scala
@@ -24,6 +24,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.PairsWriter

/**
* A class for writing JVM objects directly to a file on disk. This class allows data to be appended
@@ -46,7 +47,8 @@ private[spark] class DiskBlockObjectWriter(
writeMetrics: ShuffleWriteMetricsReporter,
val blockId: BlockId = null)
extends OutputStream
with Logging {
with Logging
with PairsWriter {

/**
* Guards against close calls, e.g. from a wrapping stream.
ExternalSorter.scala
@@ -26,10 +26,11 @@ import scala.collection.mutable.ArrayBuffer
import com.google.common.io.ByteStreams

import org.apache.spark._
import org.apache.spark.api.shuffle.{ShuffleMapOutputWriter, ShufflePartitionWriter}
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.internal.{config, Logging}
import org.apache.spark.serializer._
import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, ShuffleBlockId}

/**
* Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -674,11 +675,9 @@ private[spark] class ExternalSorter[K, V, C](
}

/**
* Write all the data added into this ExternalSorter into a file in the disk store. This is
* called by the SortShuffleWriter.
*
* @param blockId block ID to write to. The index file will be blockId.name + ".index".
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
* TODO remove this, as this is only used by UnsafeRowSerializerSuite in the SQL project.
* We should figure out an alternative way to test that so that we can remove this otherwise
* unused code path.
*/
def writePartitionedFile(
blockId: BlockId,
@@ -722,6 +721,123 @@ private[spark] class ExternalSorter[K, V, C](
lengths
}

private def writeEmptyPartition(mapOutputWriter: ShuffleMapOutputWriter): Unit = {
var partitionWriter: ShufflePartitionWriter = null
try {
partitionWriter = mapOutputWriter.getNextPartitionWriter

Reviewer: Would you need to call partitionWriter.toStream(), as in UnsafeShuffleWriter, to ensure that the underlying output stream is created and an empty file exists? That seems to be expected by UnsafeShuffleWriterSuite; I don't know whether the same applies here.

Author: That shouldn't be necessary; writer.close() should know what to do if a stream was never created.

} finally {
if (partitionWriter != null) {
partitionWriter.close()
}
}
}

/**
* Write all the data added into this ExternalSorter into a map output writer that pushes bytes
* to some arbitrary backing store. This is called by the SortShuffleWriter.
*
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
*/
def writePartitionedMapOutput(
shuffleId: Int, mapId: Int, mapOutputWriter: ShuffleMapOutputWriter): Array[Long] = {
// Track location of each range in the map output
val lengths = new Array[Long](numPartitions)
var nextPartitionId = 0
if (spills.isEmpty) {
// Case where we only have in-memory data
val collection = if (aggregator.isDefined) map else buffer
val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
while (it.hasNext()) {
val partitionId = it.nextPartition()
// The contract for the plugin is that we will ask for a writer for every partition
// even if it's empty. However, the external sorter will return non-contiguous
// partition ids. So this loop "backfills" the empty partitions that form the gaps.

// The algorithm as a whole is correct because the partition ids are returned by the
// iterator in ascending order.
for (emptyPartition <- nextPartitionId until partitionId) {
writeEmptyPartition(mapOutputWriter)
}
var partitionWriter: ShufflePartitionWriter = null
var partitionPairsWriter: ShufflePartitionPairsWriter = null
try {
partitionWriter = mapOutputWriter.getNextPartitionWriter
val blockId = ShuffleBlockId(shuffleId, mapId, partitionId)
partitionPairsWriter = new ShufflePartitionPairsWriter(
partitionWriter,
serializerManager,
serInstance,
blockId,
context.taskMetrics().shuffleWriteMetrics)
while (it.hasNext && it.nextPartition() == partitionId) {
it.writeNext(partitionPairsWriter)
}
} finally {
if (partitionPairsWriter != null) {
partitionPairsWriter.close()
}
if (partitionWriter != null) {
partitionWriter.close()
}
}
if (partitionWriter != null) {
lengths(partitionId) = partitionWriter.getNumBytesWritten
}
nextPartitionId = partitionId + 1
}
} else {
// We must perform merge-sort; get an iterator by partition and write everything directly.
for ((id, elements) <- this.partitionedIterator) {
// The contract for the plugin is that we will ask for a writer for every partition
// even if it's empty. However, the external sorter will return non-contiguous
// partition ids. So this loop "backfills" the empty partitions that form the gaps.

// The algorithm as a whole is correct because the partition ids are returned by the
// iterator in ascending order.
for (emptyPartition <- nextPartitionId until id) {
writeEmptyPartition(mapOutputWriter)
}
val blockId = ShuffleBlockId(shuffleId, mapId, id)
var partitionWriter: ShufflePartitionWriter = null
var partitionPairsWriter: ShufflePartitionPairsWriter = null
try {
partitionWriter = mapOutputWriter.getNextPartitionWriter
partitionPairsWriter = new ShufflePartitionPairsWriter(
partitionWriter,
serializerManager,
serInstance,
blockId,
context.taskMetrics().shuffleWriteMetrics)
if (elements.hasNext) {
for (elem <- elements) {
partitionPairsWriter.write(elem._1, elem._2)
}
}
} finally {
if (partitionPairsWriter != null) {
partitionPairsWriter.close()
}
}
if (partitionWriter != null) {
lengths(id) = partitionWriter.getNumBytesWritten
}
nextPartitionId = id + 1
}
}

// The iterator may have stopped short of opening a writer for every partition. So fill in the
// remaining empty partitions.
for (emptyPartition <- nextPartitionId until numPartitions) {
writeEmptyPartition(mapOutputWriter)
Reviewer: Hmm, wait, why do we need this? Shouldn't new long[numPartitions] fill the array with a default value of 0?

Author: It depends on the contract we want to present to other plugin writers. That is, do we promise that we open a writer for strictly every partition, even empty ones? Or do we say we open writers only for the first N partitions, where N is the last non-empty partition? My take is that we should have the contract that we always open a writer for every partition, empty or not, from 0 through numPartitions - 1. But, again, this shows the limitation of presenting an API that doesn't include the partition identifier explicitly when getting partition writers.

Reviewer: Ah, I see. I'm OK keeping it like this then; it does show more consistency for plugin implementers.

}

context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)

lengths
}

def stop(): Unit = {
spills.foreach(s => s.file.delete())
spills.clear()
@@ -785,7 +901,7 @@ private[spark] class ExternalSorter[K, V, C](
val inMemoryIterator = new WritablePartitionedIterator {
private[this] var cur = if (upstream.hasNext) upstream.next() else null

def writeNext(writer: DiskBlockObjectWriter): Unit = {
def writeNext(writer: PairsWriter): Unit = {
writer.write(cur._1._2, cur._2)
cur = if (upstream.hasNext) upstream.next() else null
}
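To make the "one writer per partition, even if empty" contract from the review thread above concrete, here is a small standalone sketch (not code from this change) of the backfill bookkeeping used in writePartitionedMapOutput:

object BackfillSketch {
  def main(args: Array[String]): Unit = {
    val numPartitions = 4
    val nonEmptyPartitions = Seq(1, 3)          // partition ids the sorted iterator would yield
    var nextPartitionId = 0
    val opened = Seq.newBuilder[(Int, Boolean)] // (partitionId, isEmpty)

    for (id <- nonEmptyPartitions) {
      // Backfill the gap of empty partitions before this non-empty one.
      for (empty <- nextPartitionId until id) opened += ((empty, true))
      opened += ((id, false))
      nextPartitionId = id + 1
    }
    // Fill in any trailing empty partitions after the last non-empty one.
    for (empty <- nextPartitionId until numPartitions) opened += ((empty, true))

    // Prints (0,true) (1,false) (2,true) (3,false): a writer is requested for every
    // partition id in order, and empty partitions simply report zero bytes written.
    println(opened.result().mkString(" "))
  }
}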
PairsWriter.scala
@@ -0,0 +1,23 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util.collection

private[spark] trait PairsWriter {

def write(key: Any, value: Any): Unit
}
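PairsWriter is the small seam that lets WritablePartitionedIterator.writeNext target either a DiskBlockObjectWriter or the new ShufflePartitionPairsWriter. As a hypothetical illustration of how narrow that contract is (InMemoryPairsWriter is not part of this change), any pair sink can implement it:

package org.apache.spark.util.collection

import scala.collection.mutable.ArrayBuffer

// Hypothetical in-memory sink, used only to illustrate the PairsWriter contract.
private[spark] class InMemoryPairsWriter extends PairsWriter {
  val pairs = ArrayBuffer.empty[(Any, Any)]
  override def write(key: Any, value: Any): Unit = pairs += ((key, value))
}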
ShufflePartitionPairsWriter.scala
@@ -0,0 +1,105 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util.collection

import java.io.{Closeable, FilterOutputStream, OutputStream}

import org.apache.spark.api.shuffle.ShufflePartitionWriter
import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
import org.apache.spark.storage.BlockId

/**
* A key-value writer inspired by {@link DiskBlockObjectWriter} that pushes the bytes to an
* arbitrary partition writer instead of writing to local disk through the block manager.
*/
private[spark] class ShufflePartitionPairsWriter(
partitionWriter: ShufflePartitionWriter,
serializerManager: SerializerManager,
serializerInstance: SerializerInstance,
blockId: BlockId,
writeMetrics: ShuffleWriteMetricsReporter)
extends PairsWriter with Closeable {

private var isOpen = false
private var partitionStream: OutputStream = _
private var wrappedStream: OutputStream = _
private var objOut: SerializationStream = _
private var numRecordsWritten = 0
private var curNumBytesWritten = 0L

override def write(key: Any, value: Any): Unit = {
if (!isOpen) {
open()
isOpen = true
}
objOut.writeKey(key)
objOut.writeValue(value)
recordWritten()
}

private def open(): Unit = {
// The contract is that the partition writer is expected to close its own streams, but
// the compressor will only flush the stream when it is specifically closed. So we want to
// close objOut to flush the compressed bytes to the partition writer stream, but we don't want
// to close the partition output stream in the process.
partitionStream = new CloseShieldOutputStream(partitionWriter.toStream)
wrappedStream = serializerManager.wrapStream(blockId, partitionStream)
objOut = serializerInstance.serializeStream(wrappedStream)
}

override def close(): Unit = {
if (isOpen) {
// Closing objOut should propagate close to all inner layers
// We can't close wrappedStream explicitly because closing objOut and closing wrappedStream
// causes problems when closing compressed output streams twice.
objOut.close()
objOut = null
wrappedStream = null
partitionStream = null
partitionWriter.close()
isOpen = false
updateBytesWritten()
}
}

/**
* Notify the writer that a record worth of bytes has been written with OutputStream#write.
*/
private def recordWritten(): Unit = {
numRecordsWritten += 1
writeMetrics.incRecordsWritten(1)

if (numRecordsWritten % 16384 == 0) {
updateBytesWritten()
}
}

private def updateBytesWritten(): Unit = {
val numBytesWritten = partitionWriter.getNumBytesWritten
val bytesWrittenDiff = numBytesWritten - curNumBytesWritten
writeMetrics.incBytesWritten(bytesWrittenDiff)
curNumBytesWritten = numBytesWritten
}

private class CloseShieldOutputStream(delegate: OutputStream)
extends FilterOutputStream(delegate) {

override def close(): Unit = flush()
}
}
WritablePartitionedPairCollection.scala
@@ -52,7 +52,7 @@ private[spark] trait WritablePartitionedPairCollection[K, V] {
new WritablePartitionedIterator {
private[this] var cur = if (it.hasNext) it.next() else null

def writeNext(writer: DiskBlockObjectWriter): Unit = {
def writeNext(writer: PairsWriter): Unit = {
writer.write(cur._1._2, cur._2)
cur = if (it.hasNext) it.next() else null
}
@@ -96,7 +96,7 @@ private[spark] object WritablePartitionedPairCollection {
* has an associated partition.
*/
private[spark] trait WritablePartitionedIterator {
def writeNext(writer: DiskBlockObjectWriter): Unit
def writeNext(writer: PairsWriter): Unit

def hasNext(): Boolean
