diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 36081069b0e75..d7a6d6450ebc0 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -20,6 +20,7 @@
 import javax.annotation.Nullable;
 import java.io.*;
 import java.nio.channels.FileChannel;
+import java.nio.channels.WritableByteChannel;
 import java.util.Iterator;
 
 import scala.Option;
@@ -31,18 +32,19 @@
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.io.ByteStreams;
 import com.google.common.io.Closeables;
-import com.google.common.io.Files;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.*;
 import org.apache.spark.annotation.Private;
+import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
+import org.apache.spark.api.shuffle.ShufflePartitionWriter;
+import org.apache.spark.api.shuffle.ShuffleWriteSupport;
 import org.apache.spark.internal.config.package$;
 import org.apache.spark.io.CompressionCodec;
 import org.apache.spark.io.CompressionCodec$;
 import org.apache.spark.io.NioBufferedFileInputStream;
 import org.apache.commons.io.output.CloseShieldOutputStream;
-import org.apache.commons.io.output.CountingOutputStream;
 import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.network.util.LimitedInputStream;
 import org.apache.spark.scheduler.MapStatus;
@@ -53,7 +55,6 @@
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
 import org.apache.spark.shuffle.ShuffleWriter;
 import org.apache.spark.storage.BlockManager;
-import org.apache.spark.storage.TimeTrackingOutputStream;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.util.Utils;
 
@@ -65,7 +66,6 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
 
   @VisibleForTesting
-  static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096;
   static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024;
 
   private final BlockManager blockManager;
@@ -74,6 +74,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private final SerializerInstance serializer;
   private final Partitioner partitioner;
   private final ShuffleWriteMetricsReporter writeMetrics;
+  private final ShuffleWriteSupport shuffleWriteSupport;
   private final int shuffleId;
   private final int mapId;
   private final TaskContext taskContext;
@@ -81,7 +82,6 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private final boolean transferToEnabled;
   private final int initialSortBufferSize;
   private final int inputBufferSizeInBytes;
-  private final int outputBufferSizeInBytes;
 
   @Nullable private MapStatus mapStatus;
   @Nullable private ShuffleExternalSorter sorter;
@@ -103,18 +103,6 @@ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream
    */
   private boolean stopping = false;
 
-  private class CloseAndFlushShieldOutputStream extends CloseShieldOutputStream {
-
-    CloseAndFlushShieldOutputStream(OutputStream outputStream) {
-      super(outputStream);
-    }
-
-    @Override
-    public void flush() {
-      // do nothing
-    }
-  }
-
   public UnsafeShuffleWriter(
       BlockManager blockManager,
       IndexShuffleBlockResolver shuffleBlockResolver,
@@ -123,7 +111,8 @@ public UnsafeShuffleWriter(
       int mapId,
       TaskContext taskContext,
       SparkConf sparkConf,
-      ShuffleWriteMetricsReporter writeMetrics) throws IOException {
+      ShuffleWriteMetricsReporter writeMetrics,
+      ShuffleWriteSupport shuffleWriteSupport) throws IOException {
     final int numPartitions = handle.dependency().partitioner().numPartitions();
     if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
       throw new IllegalArgumentException(
@@ -140,6 +129,7 @@ public UnsafeShuffleWriter(
     this.serializer = dep.serializer().newInstance();
     this.partitioner = dep.partitioner();
     this.writeMetrics = writeMetrics;
+    this.shuffleWriteSupport = shuffleWriteSupport;
     this.taskContext = taskContext;
     this.sparkConf = sparkConf;
     this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
@@ -147,8 +137,6 @@ public UnsafeShuffleWriter(
       (int) sparkConf.get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE());
     this.inputBufferSizeInBytes =
       (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024;
-    this.outputBufferSizeInBytes =
-      (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024;
     open();
   }
 
@@ -230,24 +218,27 @@ void closeAndWriteOutput() throws IOException {
     serOutputStream = null;
     final SpillInfo[] spills = sorter.closeAndGetSpills();
     sorter = null;
+    final ShuffleMapOutputWriter mapWriter = shuffleWriteSupport
+        .createMapOutputWriter(shuffleId, mapId, partitioner.numPartitions());
     final long[] partitionLengths;
-    final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
-    final File tmp = Utils.tempFileWith(output);
     try {
       try {
-        partitionLengths = mergeSpills(spills, tmp);
+        partitionLengths = mergeSpills(spills, mapWriter);
       } finally {
         for (SpillInfo spill : spills) {
-          if (spill.file.exists() && ! spill.file.delete()) {
+          if (spill.file.exists() && !spill.file.delete()) {
             logger.error("Error while deleting spill file {}", spill.file.getPath());
           }
         }
       }
-      shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
-    } finally {
-      if (tmp.exists() && !tmp.delete()) {
-        logger.error("Error while deleting temp file {}", tmp.getAbsolutePath());
+      mapWriter.commitAllPartitions();
+    } catch (Exception e) {
+      try {
+        mapWriter.abort(e);
+      } catch (Exception innerE) {
+        logger.error("Failed to abort the Map Output Writer", innerE);
       }
+      throw e;
     }
     mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
   }
@@ -281,7 +272,8 @@ void forceSorterToSpill() throws IOException {
    *
    * @return the partition lengths in the merged file.
    */
-  private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
+  private long[] mergeSpills(SpillInfo[] spills,
+      ShuffleMapOutputWriter mapWriter) throws IOException {
     final boolean compressionEnabled = (boolean) sparkConf.get(package$.MODULE$.SHUFFLE_COMPRESS());
     final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
     final boolean fastMergeEnabled =
@@ -289,17 +281,24 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti
     final boolean fastMergeIsSupported = !compressionEnabled ||
       CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec);
     final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled();
+    final int numPartitions = partitioner.numPartitions();
+    long[] partitionLengths = new long[numPartitions];
     try {
       if (spills.length == 0) {
-        new FileOutputStream(outputFile).close(); // Create an empty file
-        return new long[partitioner.numPartitions()];
-      } else if (spills.length == 1) {
-        // Here, we don't need to perform any metrics updates because the bytes written to this
-        // output file would have already been counted as shuffle bytes written.
-        Files.move(spills[0].file, outputFile);
-        return spills[0].partitionLengths;
+        // The contract we are working under states that we will open a partition writer for
+        // each partition, regardless of number of spills
+        for (int i = 0; i < numPartitions; i++) {
+          ShufflePartitionWriter writer = null;
+          try {
+            writer = mapWriter.getNextPartitionWriter();
+          } finally {
+            if (writer != null) {
+              writer.close();
+            }
+          }
+        }
+        return partitionLengths;
       } else {
-        final long[] partitionLengths;
         // There are multiple spills to merge, so none of these spill files' lengths were counted
         // towards our shuffle write count or shuffle write time. If we use the slow merge path,
         // then the final output file's size won't necessarily be equal to the sum of the spill
@@ -316,14 +315,14 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti
           // that doesn't need to interpret the spilled bytes.
           if (transferToEnabled && !encryptionEnabled) {
             logger.debug("Using transferTo-based fast merge");
-            partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
+            partitionLengths = mergeSpillsWithTransferTo(spills, mapWriter);
           } else {
             logger.debug("Using fileStream-based fast merge");
-            partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null);
+            partitionLengths = mergeSpillsWithFileStream(spills, mapWriter, null);
           }
         } else {
           logger.debug("Using slow merge");
-          partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec);
+          partitionLengths = mergeSpillsWithFileStream(spills, mapWriter, compressionCodec);
         }
         // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has
         // in-memory records, we write out the in-memory records to a file but do not count that
@@ -331,13 +330,9 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti
         // to be counted as shuffle write, but this will lead to double-counting of the final
         // SpillInfo's bytes.
         writeMetrics.decBytesWritten(spills[spills.length - 1].file.length());
-        writeMetrics.incBytesWritten(outputFile.length());
         return partitionLengths;
       }
     } catch (IOException e) {
-      if (outputFile.exists() && !outputFile.delete()) {
-        logger.error("Unable to delete output file {}", outputFile.getPath());
-      }
       throw e;
     }
   }
@@ -345,73 +340,79 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti
   /**
    * Merges spill files using Java FileStreams. This code path is typically slower than
    * the NIO-based merge, {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[],
-   * File)}, and it's mostly used in cases where the IO compression codec does not support
-   * concatenation of compressed data, when encryption is enabled, or when users have
-   * explicitly disabled use of {@code transferTo} in order to work around kernel bugs.
+   * ShuffleMapOutputWriter)}, and it's mostly used in cases where the IO compression codec
+   * does not support concatenation of compressed data, when encryption is enabled, or when
+   * users have explicitly disabled use of {@code transferTo} in order to work around kernel bugs.
    * This code path might also be faster in cases where individual partition size in a spill
    * is small and UnsafeShuffleWriter#mergeSpillsWithTransferTo method performs many small
    * disk ios which is inefficient. In those case, Using large buffers for input and output
    * files helps reducing the number of disk ios, making the file merging faster.
    *
    * @param spills the spills to merge.
-   * @param outputFile the file to write the merged data to.
+   * @param mapWriter the map output writer to use for output.
    * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled.
    * @return the partition lengths in the merged file.
    */
   private long[] mergeSpillsWithFileStream(
       SpillInfo[] spills,
-      File outputFile,
+      ShuffleMapOutputWriter mapWriter,
       @Nullable CompressionCodec compressionCodec) throws IOException {
-    assert (spills.length >= 2);
     final int numPartitions = partitioner.numPartitions();
     final long[] partitionLengths = new long[numPartitions];
     final InputStream[] spillInputStreams = new InputStream[spills.length];
 
-    final OutputStream bos = new BufferedOutputStream(
-      new FileOutputStream(outputFile),
-      outputBufferSizeInBytes);
-    // Use a counting output stream to avoid having to close the underlying file and ask
-    // the file system for its size after each partition is written.
-    final CountingOutputStream mergedFileOutputStream = new CountingOutputStream(bos);
-
     boolean threwException = true;
     try {
       for (int i = 0; i < spills.length; i++) {
         spillInputStreams[i] = new NioBufferedFileInputStream(
-          spills[i].file,
-          inputBufferSizeInBytes);
+            spills[i].file,
+            inputBufferSizeInBytes);
       }
       for (int partition = 0; partition < numPartitions; partition++) {
-        final long initialFileLength = mergedFileOutputStream.getByteCount();
-        // Shield the underlying output stream from close() and flush() calls, so that we can close
-        // the higher level streams to make sure all data is really flushed and internal state is
-        // cleaned.
-        OutputStream partitionOutput = new CloseAndFlushShieldOutputStream(
-          new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream));
-        partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput);
-        if (compressionCodec != null) {
-          partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
-        }
-        for (int i = 0; i < spills.length; i++) {
-          final long partitionLengthInSpill = spills[i].partitionLengths[partition];
-          if (partitionLengthInSpill > 0) {
-            InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i],
-              partitionLengthInSpill, false);
-            try {
-              partitionInputStream = blockManager.serializerManager().wrapForEncryption(
-                partitionInputStream);
-              if (compressionCodec != null) {
-                partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
+        boolean copyThrewException = true;
+        ShufflePartitionWriter writer = null;
+        try {
+          writer = mapWriter.getNextPartitionWriter();
+          OutputStream partitionOutput = null;
+          try {
+            // Shield the underlying output stream from close() calls, so that we can close the
+            // higher level streams to make sure all data is really flushed and internal state
+            // is cleaned
+            partitionOutput = new CloseShieldOutputStream(writer.toStream());
+            partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput);
+            if (compressionCodec != null) {
+              partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
+            }
+            for (int i = 0; i < spills.length; i++) {
+              final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+
+              if (partitionLengthInSpill > 0) {
+                InputStream partitionInputStream = null;
+                try {
+                  partitionInputStream = new LimitedInputStream(spillInputStreams[i],
+                      partitionLengthInSpill, false);
+                  partitionInputStream = blockManager.serializerManager().wrapForEncryption(
+                      partitionInputStream);
+                  if (compressionCodec != null) {
+                    partitionInputStream = compressionCodec.compressedInputStream(
+                        partitionInputStream);
+                  }
+                  ByteStreams.copy(partitionInputStream, partitionOutput);
+                } finally {
+                  partitionInputStream.close();
+                }
               }
-              ByteStreams.copy(partitionInputStream, partitionOutput);
-            } finally {
-              partitionInputStream.close();
+              copyThrewException = false;
             }
+          } finally {
+            Closeables.close(partitionOutput, copyThrewException);
           }
+        } finally {
+          Closeables.close(writer, copyThrewException);
         }
-        partitionOutput.flush();
-        partitionOutput.close();
-        partitionLengths[partition] = (mergedFileOutputStream.getByteCount() - initialFileLength);
+        long numBytesWritten = writer.getNumBytesWritten();
+        partitionLengths[partition] = numBytesWritten;
+        writeMetrics.incBytesWritten(numBytesWritten);
       }
       threwException = false;
     } finally {
@@ -420,7 +421,6 @@ private long[] mergeSpillsWithFileStream(
       for (InputStream stream : spillInputStreams) {
         Closeables.close(stream, threwException);
       }
-      Closeables.close(mergedFileOutputStream, threwException);
     }
     return partitionLengths;
   }
@@ -430,54 +430,49 @@ private long[] mergeSpillsWithFileStream(
    * This is only safe when the IO compression codec and serializer support concatenation of
    * serialized streams.
    *
+   * @param spills the spills to merge.
+   * @param mapWriter the map output writer to use for output.
    * @return the partition lengths in the merged file.
    */
-  private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
-    assert (spills.length >= 2);
+  private long[] mergeSpillsWithTransferTo(
+      SpillInfo[] spills,
+      ShuffleMapOutputWriter mapWriter) throws IOException {
     final int numPartitions = partitioner.numPartitions();
     final long[] partitionLengths = new long[numPartitions];
     final FileChannel[] spillInputChannels = new FileChannel[spills.length];
     final long[] spillInputChannelPositions = new long[spills.length];
-    FileChannel mergedFileOutputChannel = null;
 
     boolean threwException = true;
     try {
       for (int i = 0; i < spills.length; i++) {
         spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
       }
-      // This file needs to opened in append mode in order to work around a Linux kernel bug that
-      // affects transferTo; see SPARK-3948 for more details.
-      mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel();
-
-      long bytesWrittenToMergedFile = 0;
       for (int partition = 0; partition < numPartitions; partition++) {
-        for (int i = 0; i < spills.length; i++) {
-          final long partitionLengthInSpill = spills[i].partitionLengths[partition];
-          final FileChannel spillInputChannel = spillInputChannels[i];
-          final long writeStartTime = System.nanoTime();
-          Utils.copyFileStreamNIO(
-            spillInputChannel,
-            mergedFileOutputChannel,
-            spillInputChannelPositions[i],
-            partitionLengthInSpill);
-          spillInputChannelPositions[i] += partitionLengthInSpill;
-          writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
-          bytesWrittenToMergedFile += partitionLengthInSpill;
-          partitionLengths[partition] += partitionLengthInSpill;
+        boolean copyThrewException = true;
+        ShufflePartitionWriter writer = null;
+        try {
+          writer = mapWriter.getNextPartitionWriter();
+          WritableByteChannel channel = writer.toChannel();
+          for (int i = 0; i < spills.length; i++) {
+            long partitionLengthInSpill = 0L;
+            partitionLengthInSpill += spills[i].partitionLengths[partition];
+            final FileChannel spillInputChannel = spillInputChannels[i];
+            final long writeStartTime = System.nanoTime();
+            Utils.copyFileStreamNIO(
+              spillInputChannel,
+              channel,
+              spillInputChannelPositions[i],
+              partitionLengthInSpill);
+            copyThrewException = false;
+            spillInputChannelPositions[i] += partitionLengthInSpill;
+            writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
+          }
+        } finally {
+          Closeables.close(writer, copyThrewException);
        }
-      }
-      // Check the position after transferTo loop to see if it is in the right position and raise an
-      // exception if it is incorrect. The position will not be increased to the expected length
-      // after calling transferTo in kernel version 2.6.32. This issue is described at
-      // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948.
-      if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) {
-        throw new IOException(
-          "Current position " + mergedFileOutputChannel.position() + " does not equal expected " +
-          "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" +
-          " version to see if it is 2.6.32, as there is a kernel bug which will lead to " +
-          "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " +
-          "to disable this NIO feature."
-        );
+        long numBytes = writer.getNumBytesWritten();
+        partitionLengths[partition] = numBytes;
+        writeMetrics.incBytesWritten(numBytes);
       }
       threwException = false;
     } finally {
@@ -487,7 +482,6 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th
         assert(spillInputChannelPositions[i] == spills[i].file.length());
         Closeables.close(spillInputChannels[i], threwException);
       }
-      Closeables.close(mergedFileOutputChannel, threwException);
     }
     return partitionLengths;
   }
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java
index 0f7e5ed66bb76..c84158e1891d7 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java
@@ -92,7 +92,8 @@ public ShufflePartitionWriter getNextPartitionWriter() throws IOException {
   @Override
   public void commitAllPartitions() throws IOException {
     cleanUp();
-    blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, outputTempFile);
+    File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? outputTempFile : null;
+    blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp);
   }
 
   @Override
@@ -111,11 +112,9 @@ private void cleanUp() throws IOException {
     if (outputBufferedFileStream != null) {
       outputBufferedFileStream.close();
     }
-
     if (outputFileChannel != null) {
       outputFileChannel.close();
     }
-
    if (outputFileStream != null) {
       outputFileStream.close();
     }
@@ -191,8 +190,9 @@ public long getNumBytesWritten() {
     }
 
     @Override
-    public void close() throws IOException {
+    public void close() {
       if (stream != null) {
+        // Closing is a no-op.
        stream.close();
       }
       partitionLengths[partitionId] = getNumBytesWritten();
@@ -222,18 +222,10 @@ public void write(byte[] buf, int pos, int length) throws IOException {
     }
 
     @Override
-    public void close() throws IOException {
-      flush();
+    public void close() {
       isClosed = true;
     }
 
-    @Override
-    public void flush() throws IOException {
-      if (!isClosed) {
-        outputBufferedFileStream.flush();
-      }
-    }
-
     private void verifyNotClosed() {
       if (isClosed) {
         throw new IllegalStateException("Attempting to write to a closed block output stream.");
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 42a249564cd07..849050556c569 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -146,7 +146,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
           mapId,
           context,
           env.conf,
-          metrics)
+          metrics,
+          shuffleExecutorComponents.writes())
       case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
         new BypassMergeSortShuffleWriter(
           env.blockManager,
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 9bf707f783d44..012dc5d21bce4 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -19,8 +19,10 @@
 import java.io.*;
 import java.nio.ByteBuffer;
+import java.nio.file.Files;
 import java.util.*;
 
+import org.mockito.stubbing.Answer;
 import scala.Option;
 import scala.Product2;
 import scala.Tuple2;
 
@@ -39,6 +41,7 @@
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.SparkConf;
 import org.apache.spark.TaskContext;
+import org.apache.spark.TaskContext$;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.executor.TaskMetrics;
 import org.apache.spark.io.CompressionCodec$;
@@ -53,6 +56,7 @@
 import org.apache.spark.security.CryptoStreamUtils;
 import org.apache.spark.serializer.*;
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport;
 import org.apache.spark.storage.*;
 import org.apache.spark.util.Utils;
 
@@ -65,6 +69,7 @@ public class UnsafeShuffleWriterSuite {
 
+  static final int DEFAULT_INITIAL_SORT_BUFFER_SIZE = 4096;
   static final int NUM_PARTITITONS = 4;
   TestMemoryManager memoryManager;
   TaskMemoryManager taskMemoryManager;
@@ -85,6 +90,7 @@ public class UnsafeShuffleWriterSuite {
 
   @After
   public void tearDown() {
+    TaskContext$.MODULE$.unset();
     Utils.deleteRecursively(tempDir);
     final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory();
     if (leakedMemory != 0) {
@@ -132,14 +138,28 @@ public void setUp() throws IOException {
     });
 
     when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile);
-    doAnswer(invocationOnMock -> {
+
+    Answer<?> renameTempAnswer = invocationOnMock -> {
       partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
       File tmp = (File) invocationOnMock.getArguments()[3];
-      mergedOutputFile.delete();
-      tmp.renameTo(mergedOutputFile);
+      if (!mergedOutputFile.delete()) {
+        throw new RuntimeException("Failed to delete old merged output file.");
+      }
+      if (tmp != null) {
+        Files.move(tmp.toPath(), mergedOutputFile.toPath());
+      } else if (!mergedOutputFile.createNewFile()) {
+        throw new RuntimeException("Failed to create empty merged output file.");
+      }
       return null;
-    }).when(shuffleBlockResolver)
-      .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class));
+    };
+
+    doAnswer(renameTempAnswer)
+        .when(shuffleBlockResolver)
+        .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class));
+
+    doAnswer(renameTempAnswer)
+        .when(shuffleBlockResolver)
+        .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), eq(null));
 
     when(diskBlockManager.createTempShuffleBlock()).thenAnswer(invocationOnMock -> {
       TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID());
@@ -151,6 +171,9 @@ public void setUp() throws IOException {
     when(taskContext.taskMetrics()).thenReturn(taskMetrics);
     when(shuffleDep.serializer()).thenReturn(serializer);
     when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
+    when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager);
+
+    TaskContext$.MODULE$.setTaskContext(taskContext);
   }
 
   private UnsafeShuffleWriter<Object, Object> createWriter(
@@ -164,7 +187,8 @@ private UnsafeShuffleWriter<Object, Object> createWriter(
       0, // map id
       taskContext,
       conf,
-      taskContext.taskMetrics().shuffleWriteMetrics()
+      taskContext.taskMetrics().shuffleWriteMetrics(),
+      new DefaultShuffleWriteSupport(conf, shuffleBlockResolver)
     );
   }
 
@@ -444,10 +468,10 @@ public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpillRadixOn() thro
   }
 
   private void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
-    memoryManager.limit(UnsafeShuffleWriter.DEFAULT_INITIAL_SORT_BUFFER_SIZE * 16);
+    memoryManager.limit(DEFAULT_INITIAL_SORT_BUFFER_SIZE * 16);
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
     final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>();
-    for (int i = 0; i < UnsafeShuffleWriter.DEFAULT_INITIAL_SORT_BUFFER_SIZE + 1; i++) {
+    for (int i = 0; i < DEFAULT_INITIAL_SORT_BUFFER_SIZE + 1; i++) {
       dataToWrite.add(new Tuple2<>(i, i));
     }
     writer.write(dataToWrite.iterator());
@@ -525,7 +549,8 @@ public void testPeakMemoryUsed() throws Exception {
         0, // map id
         taskContext,
         conf,
-        taskContext.taskMetrics().shuffleWriteMetrics());
+        taskContext.taskMetrics().shuffleWriteMetrics(),
+        new DefaultShuffleWriteSupport(conf, shuffleBlockResolver));
 
     // Peak memory should be monotonically increasing. More specifically, every time
     // we allocate a new page it should increase by exactly the size of the page.
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala
index 7eb867fc29fd2..69fe03e75606f 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala
@@ -19,7 +19,7 @@ package org.apache.spark.shuffle.sort
 
 import org.apache.spark.SparkConf
 import org.apache.spark.benchmark.Benchmark
-import org.apache.spark.shuffle.sort.io.{DefaultShuffleWriteSupport}
+import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport
 
 /**
  * Benchmark to measure performance for aggregate primitives.
@@ -46,9 +46,9 @@ object BypassMergeSortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase
 
   def getWriter(transferTo: Boolean): BypassMergeSortShuffleWriter[String, String] = {
     val conf = new SparkConf(loadDefaults = false)
-    val shuffleWriteSupport = new DefaultShuffleWriteSupport(conf, blockResolver)
     conf.set("spark.file.transferTo", String.valueOf(transferTo))
     conf.set("spark.shuffle.file.buffer", "32k")
+    val shuffleWriteSupport = new DefaultShuffleWriteSupport(conf, blockResolver)
 
     val shuffleWriter = new BypassMergeSortShuffleWriter[String, String](
       blockManager,
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala
index 15a08111f6d54..20bf3eac95d84 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala
@@ -18,6 +18,7 @@ package org.apache.spark.shuffle.sort
 
 import org.apache.spark.{SparkConf, TaskContext}
 import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.shuffle.sort.io.DefaultShuffleWriteSupport
 
 /**
  * Benchmark to measure performance for aggregate primitives.
@@ -42,6 +43,7 @@ object UnsafeShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase {
   def getWriter(transferTo: Boolean): UnsafeShuffleWriter[String, String] = {
     val conf = new SparkConf(loadDefaults = false)
     conf.set("spark.file.transferTo", String.valueOf(transferTo))
+    val shuffleWriteSupport = new DefaultShuffleWriteSupport(conf, blockResolver)
     TaskContext.setTaskContext(taskContext)
 
     new UnsafeShuffleWriter[String, String](
@@ -52,7 +54,8 @@ object UnsafeShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase {
       0,
       taskContext,
       conf,
-      taskContext.taskMetrics().shuffleWriteMetrics
+      taskContext.taskMetrics().shuffleWriteMetrics,
+      shuffleWriteSupport
     )
   }
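
For reference, the sketch below shows how a caller drives the writer plugin that this patch wires into `UnsafeShuffleWriter`: create a `ShuffleMapOutputWriter`, obtain one `ShufflePartitionWriter` per partition in order (even for empty partitions), then either `commitAllPartitions()` or `abort(...)`. The interface and method names (`createMapOutputWriter`, `getNextPartitionWriter`, `toStream`, `getNumBytesWritten`, `commitAllPartitions`, `abort`) are taken from the diff above; the wrapper class, the payload, and the suppressed-exception handling are illustrative only, not part of the patch.

```java
import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;

import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
import org.apache.spark.api.shuffle.ShufflePartitionWriter;
import org.apache.spark.api.shuffle.ShuffleWriteSupport;

// Hypothetical helper, not part of the patch; it mirrors the commit/abort flow of
// closeAndWriteOutput() and the "open a writer for every partition" contract in mergeSpills().
final class MapOutputWriterSketch {

  static long[] writeAllPartitions(
      ShuffleWriteSupport writeSupport,
      int shuffleId,
      int mapId,
      int numPartitions) throws IOException {
    ShuffleMapOutputWriter mapWriter =
        writeSupport.createMapOutputWriter(shuffleId, mapId, numPartitions);
    long[] partitionLengths = new long[numPartitions];
    try {
      for (int p = 0; p < numPartitions; p++) {
        // Partition writers must be requested in partition order, even for empty partitions.
        ShufflePartitionWriter writer = mapWriter.getNextPartitionWriter();
        try {
          // Illustrative payload; a real caller streams serialized shuffle records here.
          OutputStream out = writer.toStream();
          out.write(("partition-" + p).getBytes(StandardCharsets.UTF_8));
        } finally {
          writer.close();
        }
        partitionLengths[p] = writer.getNumBytesWritten();
      }
      mapWriter.commitAllPartitions();
    } catch (Exception e) {
      try {
        mapWriter.abort(e);
      } catch (Exception inner) {
        // Keep the original failure as the primary exception.
        e.addSuppressed(inner);
      }
      throw e;
    }
    return partitionLengths;
  }
}
```

The per-partition lengths collected this way are what `UnsafeShuffleWriter` feeds into `MapStatus`, which is why both merge paths in the patch record `getNumBytesWritten()` after closing each partition writer.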