diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java index 6a53803e5d117..74c928b0b9c8f 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShufflePartitionWriter.java @@ -17,13 +17,10 @@ package org.apache.spark.api.shuffle; -import java.io.Closeable; import java.io.IOException; import java.io.OutputStream; -import java.nio.channels.Channels; -import java.nio.channels.WritableByteChannel; -import org.apache.http.annotation.Experimental; +import org.apache.spark.annotation.Experimental; /** * :: Experimental :: @@ -32,43 +29,16 @@ * @since 3.0.0 */ @Experimental -public interface ShufflePartitionWriter extends Closeable { +public interface ShufflePartitionWriter { /** - * Returns an underlying {@link OutputStream} that can write bytes to the underlying data store. - *
<p>
- * Note that this stream itself is not closed by the caller; close the stream in the - * implementation of this interface's {@link #close()}. + * Opens and returns an underlying {@link OutputStream} that can write bytes to the underlying + * data store. */ - OutputStream toStream() throws IOException; + OutputStream openStream() throws IOException; /** - * Returns an underlying {@link WritableByteChannel} that can write bytes to the underlying data - * store. - *
<p>
- * Note that this channel itself is not closed by the caller; close the channel in the - * implementation of this interface's {@link #close()}. - */ - default WritableByteChannel toChannel() throws IOException { - return Channels.newChannel(toStream()); - } - - /** - * Get the number of bytes written by this writer's stream returned by {@link #toStream()} or - * the channel returned by {@link #toChannel()}. + * Get the number of bytes written by this writer's stream returned by {@link #openStream()}. */ long getNumBytesWritten(); - - /** - * Close all resources created by this ShufflePartitionWriter, via calls to {@link #toStream()} - * or {@link #toChannel()}. - *
<p>
- * This must always close any stream returned by {@link #toStream()}. - *
<p>
- * Note that the default version of {@link #toChannel()} returns a {@link WritableByteChannel} - * that does not itself need to be closed up front; only the underlying output stream given by - * {@link #toStream()} must be closed. - */ - @Override - void close() throws IOException; } diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java index 6c69d5db9fd06..7e2b6cf4133fd 100644 --- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java +++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java @@ -19,7 +19,7 @@ import java.io.IOException; -import org.apache.http.annotation.Experimental; +import org.apache.spark.annotation.Experimental; /** * :: Experimental :: diff --git a/core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java b/core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java new file mode 100644 index 0000000000000..866b61d0bafd9 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +import java.io.IOException; + +import org.apache.spark.annotation.Experimental; + +/** + * :: Experimental :: + * Indicates that partition writers can transfer bytes directly from input byte channels to + * output channels that stream data to the underlying shuffle partition storage medium. + *
<p>
+ * This API is separated out for advanced users because it only needs to be used for + * specific low-level optimizations. The idea is that the returned channel can transfer bytes + * from the input file channel out to the backing storage system without copying data into + * memory. + *
<p>
+ * Most shuffle plugin implementations should use {@link ShufflePartitionWriter} instead. + * + * @since 3.0.0 + */ +@Experimental +public interface SupportsTransferTo extends ShufflePartitionWriter { + + /** + * Opens and returns a {@link TransferrableWritableByteChannel} for transferring bytes from + * input byte channels to the underlying shuffle data store. + */ + TransferrableWritableByteChannel openTransferrableChannel() throws IOException; + + /** + * Returns the number of bytes written either by this writer's output stream opened by + * {@link #openStream()} or the byte channel opened by {@link #openTransferrableChannel()}. + */ + @Override + long getNumBytesWritten(); +} diff --git a/core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java b/core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java new file mode 100644 index 0000000000000..18234d7c4c944 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.shuffle; + +import java.io.Closeable; +import java.io.IOException; + +import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; +import org.apache.spark.annotation.Experimental; + +/** + * :: Experimental :: + * Represents an output byte channel that can copy bytes from input file channels to some + * arbitrary storage system. + *
<p>
+ * This API is provided for advanced users who can transfer bytes from a file channel to + * some output sink without copying data into memory. Most users should not need to use + * this functionality; this is primarily provided for the built-in shuffle storage backends + * that persist shuffle files on local disk. + *
<p>
+ * For a simpler alternative, see {@link ShufflePartitionWriter}. + * + * @since 3.0.0 + */ +@Experimental +public interface TransferrableWritableByteChannel extends Closeable { + + /** + * Copy all bytes from the source readable byte channel into this byte channel. + * + * @param source File to transfer bytes from. Do not call anything on this channel other than + * {@link FileChannel#transferTo(long, long, WritableByteChannel)}. + * @param transferStartPosition Start position of the input file to transfer from. + * @param numBytesToTransfer Number of bytes to transfer from the given source. + */ + void transferFrom(FileChannel source, long transferStartPosition, long numBytesToTransfer) + throws IOException; +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 22386c39aca0a..128b90429209e 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -21,12 +21,10 @@ import java.io.FileInputStream; import java.io.IOException; import java.io.OutputStream; +import java.nio.channels.Channels; import java.nio.channels.FileChannel; -import java.nio.channels.WritableByteChannel; import javax.annotation.Nullable; -import org.apache.spark.api.java.Optional; -import org.apache.spark.api.shuffle.MapShuffleLocations; import scala.None$; import scala.Option; import scala.Product2; @@ -38,19 +36,22 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.Partitioner; +import org.apache.spark.ShuffleDependency; +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.Optional; +import org.apache.spark.api.shuffle.MapShuffleLocations; +import org.apache.spark.api.shuffle.SupportsTransferTo; import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; import org.apache.spark.api.shuffle.ShufflePartitionWriter; import org.apache.spark.api.shuffle.ShuffleWriteSupport; +import org.apache.spark.api.shuffle.TransferrableWritableByteChannel; import org.apache.spark.internal.config.package$; -import org.apache.spark.Partitioner; -import org.apache.spark.ShuffleDependency; -import org.apache.spark.SparkConf; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; -import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.*; import org.apache.spark.util.Utils; @@ -90,7 +91,6 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { private final int mapId; private final Serializer serializer; private final ShuffleWriteSupport shuffleWriteSupport; - private final IndexShuffleBlockResolver shuffleBlockResolver; /** Array of file writers, one for each partition */ private DiskBlockObjectWriter[] partitionWriters; @@ -107,7 +107,6 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { BypassMergeSortShuffleWriter( BlockManager blockManager, - IndexShuffleBlockResolver shuffleBlockResolver, BypassMergeSortShuffleHandle handle, int mapId, SparkConf conf, @@ -124,7 +123,6 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { this.numPartitions = partitioner.numPartitions(); this.writeMetrics = writeMetrics; 
this.serializer = dep.serializer(); - this.shuffleBlockResolver = shuffleBlockResolver; this.shuffleWriteSupport = shuffleWriteSupport; } @@ -209,40 +207,43 @@ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) thro try { for (int i = 0; i < numPartitions; i++) { final File file = partitionWriterSegments[i].file(); - boolean copyThrewException = true; - ShufflePartitionWriter writer = null; - try { - writer = mapOutputWriter.getPartitionWriter(i); - if (!file.exists()) { - copyThrewException = false; - } else { - if (transferToEnabled) { - WritableByteChannel outputChannel = writer.toChannel(); - FileInputStream in = new FileInputStream(file); - try (FileChannel inputChannel = in.getChannel()) { - Utils.copyFileStreamNIO(inputChannel, outputChannel, 0, inputChannel.size()); - copyThrewException = false; - } finally { - Closeables.close(in, copyThrewException); - } - } else { - OutputStream tempOutputStream = writer.toStream(); - FileInputStream in = new FileInputStream(file); - try { - Utils.copyStream(in, tempOutputStream, false, false); - copyThrewException = false; - } finally { - Closeables.close(in, copyThrewException); + ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i); + if (file.exists()) { + boolean copyThrewException = true; + if (transferToEnabled) { + FileInputStream in = new FileInputStream(file); + TransferrableWritableByteChannel outputChannel = null; + try (FileChannel inputChannel = in.getChannel()) { + if (writer instanceof SupportsTransferTo) { + outputChannel = ((SupportsTransferTo) writer).openTransferrableChannel(); + } else { + // Use default transferrable writable channel anyways in order to have parity with + // UnsafeShuffleWriter. + outputChannel = new DefaultTransferrableWritableByteChannel( + Channels.newChannel(writer.openStream())); } + outputChannel.transferFrom(inputChannel, 0L, inputChannel.size()); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + Closeables.close(outputChannel, copyThrewException); } - if (!file.delete()) { - logger.error("Unable to delete file for partition {}", i); + } else { + FileInputStream in = new FileInputStream(file); + OutputStream outputStream = null; + try { + outputStream = writer.openStream(); + Utils.copyStream(in, outputStream, false, false); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + Closeables.close(outputStream, copyThrewException); } } - } finally { - Closeables.close(writer, copyThrewException); + if (!file.delete()) { + logger.error("Unable to delete file for partition {}", i); + } } - lengths[i] = writer.getNumBytesWritten(); } } finally { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java new file mode 100644 index 0000000000000..64ce851e392d2 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort; + +import java.io.IOException; +import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; +import org.apache.spark.api.shuffle.TransferrableWritableByteChannel; +import org.apache.spark.util.Utils; + +/** + * This is used when transferTo is enabled but the shuffle plugin hasn't implemented + * {@link org.apache.spark.api.shuffle.SupportsTransferTo}. + *
<p>
+ * This default implementation exists as a convenience to the unsafe shuffle writer and + * the bypass merge sort shuffle writers. + */ +public class DefaultTransferrableWritableByteChannel implements TransferrableWritableByteChannel { + + private final WritableByteChannel delegate; + + public DefaultTransferrableWritableByteChannel(WritableByteChannel delegate) { + this.delegate = delegate; + } + + @Override + public void transferFrom( + FileChannel source, long transferStartPosition, long numBytesToTransfer) { + Utils.copyFileStreamNIO(source, delegate, transferStartPosition, numBytesToTransfer); + } + + @Override + public void close() throws IOException { + delegate.close(); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 069716cc99db0..f147bd79773e1 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -17,14 +17,12 @@ package org.apache.spark.shuffle.sort; +import java.nio.channels.Channels; import javax.annotation.Nullable; import java.io.*; import java.nio.channels.FileChannel; -import java.nio.channels.WritableByteChannel; import java.util.Iterator; -import org.apache.spark.api.java.Optional; -import org.apache.spark.api.shuffle.MapShuffleLocations; import scala.Option; import scala.Product2; import scala.collection.JavaConverters; @@ -39,14 +37,17 @@ import org.apache.spark.*; import org.apache.spark.annotation.Private; +import org.apache.spark.api.java.Optional; +import org.apache.spark.api.shuffle.MapShuffleLocations; +import org.apache.spark.api.shuffle.TransferrableWritableByteChannel; import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; import org.apache.spark.api.shuffle.ShufflePartitionWriter; import org.apache.spark.api.shuffle.ShuffleWriteSupport; +import org.apache.spark.api.shuffle.SupportsTransferTo; import org.apache.spark.internal.config.package$; import org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; import org.apache.spark.io.NioBufferedFileInputStream; -import org.apache.commons.io.output.CloseShieldOutputStream; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; @@ -54,11 +55,9 @@ import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.serializer.SerializationStream; import org.apache.spark.serializer.SerializerInstance; -import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.BlockManager; import org.apache.spark.unsafe.Platform; -import org.apache.spark.util.Utils; @Private public class UnsafeShuffleWriter extends ShuffleWriter { @@ -71,7 +70,6 @@ public class UnsafeShuffleWriter extends ShuffleWriter { static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024; private final BlockManager blockManager; - private final IndexShuffleBlockResolver shuffleBlockResolver; private final TaskMemoryManager memoryManager; private final SerializerInstance serializer; private final Partitioner partitioner; @@ -107,7 +105,6 @@ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream public UnsafeShuffleWriter( BlockManager blockManager, - IndexShuffleBlockResolver shuffleBlockResolver, TaskMemoryManager memoryManager, SerializedShuffleHandle handle, int 
mapId, @@ -123,7 +120,6 @@ public UnsafeShuffleWriter( " reduce partitions"); } this.blockManager = blockManager; - this.shuffleBlockResolver = shuffleBlockResolver; this.memoryManager = memoryManager; this.mapId = mapId; final ShuffleDependency dep = handle.dependency(); @@ -364,45 +360,37 @@ private long[] mergeSpillsWithFileStream( } for (int partition = 0; partition < numPartitions; partition++) { boolean copyThrewExecption = true; - ShufflePartitionWriter writer = null; + ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition); + OutputStream partitionOutput = null; try { - writer = mapWriter.getPartitionWriter(partition); - OutputStream partitionOutput = null; - try { - // Shield the underlying output stream from close() calls, so that we can close the - // higher level streams to make sure all data is really flushed and internal state - // is cleaned - partitionOutput = new CloseShieldOutputStream(writer.toStream()); - partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); - if (compressionCodec != null) { - partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); - } - for (int i = 0; i < spills.length; i++) { - final long partitionLengthInSpill = spills[i].partitionLengths[partition]; - - if (partitionLengthInSpill > 0) { - InputStream partitionInputStream = null; - try { - partitionInputStream = new LimitedInputStream(spillInputStreams[i], - partitionLengthInSpill, false); - partitionInputStream = blockManager.serializerManager().wrapForEncryption( + partitionOutput = writer.openStream(); + partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); + if (compressionCodec != null) { + partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); + } + for (int i = 0; i < spills.length; i++) { + final long partitionLengthInSpill = spills[i].partitionLengths[partition]; + + if (partitionLengthInSpill > 0) { + InputStream partitionInputStream = null; + try { + partitionInputStream = new LimitedInputStream(spillInputStreams[i], + partitionLengthInSpill, false); + partitionInputStream = blockManager.serializerManager().wrapForEncryption( + partitionInputStream); + if (compressionCodec != null) { + partitionInputStream = compressionCodec.compressedInputStream( partitionInputStream); - if (compressionCodec != null) { - partitionInputStream = compressionCodec.compressedInputStream( - partitionInputStream); - } - ByteStreams.copy(partitionInputStream, partitionOutput); - } finally { - partitionInputStream.close(); } + ByteStreams.copy(partitionInputStream, partitionOutput); + } finally { + partitionInputStream.close(); } - copyThrewExecption = false; } - } finally { - Closeables.close(partitionOutput, copyThrewExecption); + copyThrewExecption = false; } } finally { - Closeables.close(writer, copyThrewExecption); + Closeables.close(partitionOutput, copyThrewExecption); } long numBytesWritten = writer.getNumBytesWritten(); partitionLengths[partition] = numBytesWritten; @@ -443,26 +431,26 @@ private long[] mergeSpillsWithTransferTo( } for (int partition = 0; partition < numPartitions; partition++) { boolean copyThrewExecption = true; - ShufflePartitionWriter writer = null; + ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition); + TransferrableWritableByteChannel partitionChannel = null; try { - writer = mapWriter.getPartitionWriter(partition); - WritableByteChannel channel = writer.toChannel(); + partitionChannel = writer instanceof SupportsTransferTo ? 
+ ((SupportsTransferTo) writer).openTransferrableChannel() + : new DefaultTransferrableWritableByteChannel( + Channels.newChannel(writer.openStream())); for (int i = 0; i < spills.length; i++) { long partitionLengthInSpill = 0L; partitionLengthInSpill += spills[i].partitionLengths[partition]; final FileChannel spillInputChannel = spillInputChannels[i]; final long writeStartTime = System.nanoTime(); - Utils.copyFileStreamNIO( - spillInputChannel, - channel, - spillInputChannelPositions[i], - partitionLengthInSpill); - copyThrewExecption = false; + partitionChannel.transferFrom( + spillInputChannel, spillInputChannelPositions[i], partitionLengthInSpill); spillInputChannelPositions[i] += partitionLengthInSpill; writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); } + copyThrewExecption = false; } finally { - Closeables.close(writer, copyThrewExecption); + Closeables.close(partitionChannel, copyThrewExecption); } long numBytes = writer.getNumBytesWritten(); partitionLengths[partition] = numBytes; diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java index 926c3b9433990..e83db4e4bcef6 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java @@ -24,18 +24,21 @@ import java.io.OutputStream; import java.nio.channels.FileChannel; -import org.apache.spark.api.java.Optional; -import org.apache.spark.api.shuffle.MapShuffleLocations; -import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations; -import org.apache.spark.storage.BlockManagerId; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.SparkConf; +import org.apache.spark.api.java.Optional; +import org.apache.spark.api.shuffle.MapShuffleLocations; import org.apache.spark.api.shuffle.ShuffleMapOutputWriter; import org.apache.spark.api.shuffle.ShufflePartitionWriter; -import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; +import org.apache.spark.api.shuffle.SupportsTransferTo; +import org.apache.spark.api.shuffle.TransferrableWritableByteChannel; import org.apache.spark.internal.config.package$; +import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations; +import org.apache.spark.shuffle.sort.DefaultTransferrableWritableByteChannel; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; +import org.apache.spark.storage.BlockManagerId; import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.storage.TimeTrackingOutputStream; import org.apache.spark.util.Utils; @@ -151,70 +154,70 @@ private void initChannel() throws IOException { } } - private class DefaultShufflePartitionWriter implements ShufflePartitionWriter { + private class DefaultShufflePartitionWriter implements SupportsTransferTo { private final int partitionId; - private PartitionWriterStream stream = null; + private PartitionWriterStream partStream = null; + private PartitionWriterChannel partChannel = null; private DefaultShufflePartitionWriter(int partitionId) { this.partitionId = partitionId; } @Override - public OutputStream toStream() throws IOException { - if (outputFileChannel != null) { - throw new IllegalStateException("Requested an output channel for a previous write but" + - " now an output stream has been requested. 
Should not be using both channels" + - " and streams to write."); + public OutputStream openStream() throws IOException { + if (partStream == null) { + if (outputFileChannel != null) { + throw new IllegalStateException("Requested an output channel for a previous write but" + + " now an output stream has been requested. Should not be using both channels" + + " and streams to write."); + } + initStream(); + partStream = new PartitionWriterStream(partitionId); } - initStream(); - stream = new PartitionWriterStream(); - return stream; + return partStream; } @Override - public FileChannel toChannel() throws IOException { - if (stream != null) { - throw new IllegalStateException("Requested an output stream for a previous write but" + - " now an output channel has been requested. Should not be using both channels" + - " and streams to write."); + public TransferrableWritableByteChannel openTransferrableChannel() throws IOException { + if (partChannel == null) { + if (partStream != null) { + throw new IllegalStateException("Requested an output stream for a previous write but" + + " now an output channel has been requested. Should not be using both channels" + + " and streams to write."); + } + initChannel(); + partChannel = new PartitionWriterChannel(partitionId); } - initChannel(); - return outputFileChannel; + return partChannel; } @Override public long getNumBytesWritten() { - if (outputFileChannel != null && stream == null) { + if (partChannel != null) { try { - long newPosition = outputFileChannel.position(); - return newPosition - currChannelPosition; - } catch (Exception e) { - log.error("The partition which failed is: {}", partitionId, e); - throw new IllegalStateException("Failed to calculate position of file channel", e); + return partChannel.getCount(); + } catch (IOException e) { + throw new RuntimeException(e); } - } else if (stream != null) { - return stream.getCount(); + } else if (partStream != null) { + return partStream.getCount(); } else { // Assume an empty partition if stream and channel are never created return 0; } } - - @Override - public void close() { - if (stream != null) { - // Closing is a no-op. 
- stream.close(); - } - partitionLengths[partitionId] = getNumBytesWritten(); - } } private class PartitionWriterStream extends OutputStream { + private final int partitionId; private int count = 0; private boolean isClosed = false; + PartitionWriterStream(int partitionId) { + this.partitionId = partitionId; + } + public int getCount() { return count; } @@ -236,6 +239,7 @@ public void write(byte[] buf, int pos, int length) throws IOException { @Override public void close() { isClosed = true; + partitionLengths[partitionId] = count; } private void verifyNotClosed() { @@ -244,4 +248,24 @@ private void verifyNotClosed() { } } } + + private class PartitionWriterChannel extends DefaultTransferrableWritableByteChannel { + + private final int partitionId; + + PartitionWriterChannel(int partitionId) { + super(outputFileChannel); + this.partitionId = partitionId; + } + + public long getCount() throws IOException { + long writtenPosition = outputFileChannel.position(); + return writtenPosition - currChannelPosition; + } + + @Override + public void close() throws IOException { + partitionLengths[partitionId] = getCount(); + } + } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 849050556c569..5fa9296b022ca 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -140,7 +140,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => new UnsafeShuffleWriter( env.blockManager, - shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], context.taskMemoryManager(), unsafeShuffleHandle, mapId, @@ -151,7 +150,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => new BypassMergeSortShuffleWriter( env.blockManager, - shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], bypassMergeSortHandle, mapId, env.conf, diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index df5ce73b9acf1..14d34e1c47c8e 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -755,9 +755,6 @@ private[spark] class ExternalSorter[K, V, C]( if (partitionPairsWriter != null) { partitionPairsWriter.close() } - if (partitionWriter != null) { - partitionWriter.close() - } } if (partitionWriter != null) { lengths(partitionId) = partitionWriter.getNumBytesWritten diff --git a/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala b/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala index 6f19a2323efde..8538a78b377c8 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala @@ -54,25 +54,17 @@ private[spark] class ShufflePartitionPairsWriter( } private def open(): Unit = { - // The contract is that the partition writer is expected to close its own streams, but - // the compressor will only flush the stream when it is specifically closed. 
So we want to - // close objOut to flush the compressed bytes to the partition writer stream, but we don't want - // to close the partition output stream in the process. - partitionStream = new CloseShieldOutputStream(partitionWriter.toStream) + partitionStream = partitionWriter.openStream wrappedStream = serializerManager.wrapStream(blockId, partitionStream) objOut = serializerInstance.serializeStream(wrappedStream) } override def close(): Unit = { if (isOpen) { - // Closing objOut should propagate close to all inner layers - // We can't close wrappedStream explicitly because closing objOut and closing wrappedStream - // causes problems when closing compressed output streams twice. objOut.close() objOut = null wrappedStream = null partitionStream = null - partitionWriter.close() isOpen = false updateBytesWritten() } @@ -96,10 +88,4 @@ private[spark] class ShufflePartitionPairsWriter( writeMetrics.incBytesWritten(bytesWrittenDiff) curNumBytesWritten = numBytesWritten } - - private class CloseShieldOutputStream(delegate: OutputStream) - extends FilterOutputStream(delegate) { - - override def close(): Unit = flush() - } } diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 5f0de31bd25e3..5ea0907277ebf 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -183,8 +183,7 @@ private UnsafeShuffleWriter createWriter( conf.set("spark.file.transferTo", String.valueOf(transferToEnabled)); return new UnsafeShuffleWriter<>( blockManager, - shuffleBlockResolver, - taskMemoryManager, + taskMemoryManager, new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, @@ -544,8 +543,7 @@ public void testPeakMemoryUsed() throws Exception { final UnsafeShuffleWriter writer = new UnsafeShuffleWriter<>( blockManager, - shuffleBlockResolver, - taskMemoryManager, + taskMemoryManager, new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala index 0b3394e88d9f1..dbd73f2688dfc 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterBenchmark.scala @@ -54,7 +54,6 @@ object BypassMergeSortShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase val shuffleWriter = new BypassMergeSortShuffleWriter[String, String]( blockManager, - blockResolver, shuffleHandle, 0, conf, diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 538672e4bc738..013c5916284d2 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -158,7 +158,6 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte test("write empty iterator") { val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, - blockResolver, shuffleHandle, 0, // MapId conf, @@ -184,7 +183,6 @@ class BypassMergeSortShuffleWriterSuite extends 
SparkFunSuite with BeforeAndAfte Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, - blockResolver, shuffleHandle, 0, // MapId transferConf, @@ -209,7 +207,6 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, - blockResolver, shuffleHandle, 0, // MapId conf, @@ -245,7 +242,6 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, - blockResolver, shuffleHandle, 0, // MapId conf, @@ -268,7 +264,6 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte test("cleanup of intermediate files after errors") { val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, - blockResolver, shuffleHandle, 0, // MapId conf, diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala index 0e659ff7cc5f3..7066ba8fb44df 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/UnsafeShuffleWriterBenchmark.scala @@ -50,15 +50,13 @@ object UnsafeShuffleWriterBenchmark extends ShuffleWriterBenchmarkBase { TaskContext.setTaskContext(taskContext) new UnsafeShuffleWriter[String, String]( blockManager, - blockResolver, taskMemoryManager, shuffleHandle, 0, taskContext, conf, taskContext.taskMetrics().shuffleWriteMetrics, - shuffleWriteSupport - ) + shuffleWriteSupport) } def writeBenchmarkWithSmallDataset(): Unit = { diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala index 420b0d4d2f674..1f4ef0f203994 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriterSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.shuffle.sort.io import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream} import java.math.BigInteger import java.nio.ByteBuffer +import java.nio.channels.{Channels, WritableByteChannel} import org.mockito.Answers.RETURNS_SMART_NULLS import org.mockito.ArgumentMatchers.{any, anyInt, anyLong} @@ -31,10 +32,12 @@ import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfterEach import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.api.shuffle.SupportsTransferTo import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.network.util.LimitedInputStream import org.apache.spark.shuffle.IndexShuffleBlockResolver import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.ByteBufferInputStream import org.apache.spark.util.Utils class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAfterEach { @@ -141,14 +144,13 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft test("writing to an outputstream") { (0 until NUM_PARTITIONS).foreach{ p => val writer = mapOutputWriter.getPartitionWriter(p) - val stream = writer.toStream() + val stream = writer.openStream() data(p).foreach { i => 
stream.write(i)} stream.close() intercept[IllegalStateException] { stream.write(p) } assert(writer.getNumBytesWritten() == D_LEN) - writer.close } mapOutputWriter.commitAllPartitions() val partitionLengths = (0 until NUM_PARTITIONS).map { _ => D_LEN.toDouble}.toArray @@ -160,15 +162,23 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft test("writing to a channel") { (0 until NUM_PARTITIONS).foreach{ p => val writer = mapOutputWriter.getPartitionWriter(p) - val channel = writer.toChannel() + val channel = writer.asInstanceOf[SupportsTransferTo].openTransferrableChannel() val byteBuffer = ByteBuffer.allocate(D_LEN * 4) val intBuffer = byteBuffer.asIntBuffer() intBuffer.put(data(p)) - assert(channel.isOpen) - channel.write(byteBuffer) + val numBytes = byteBuffer.remaining() + val outputTempFile = File.createTempFile("channelTemp", "", tempDir) + val outputTempFileStream = new FileOutputStream(outputTempFile) + Utils.copyStream( + new ByteBufferInputStream(byteBuffer), + outputTempFileStream, + closeStreams = true) + val tempFileInput = new FileInputStream(outputTempFile) + channel.transferFrom(tempFileInput.getChannel, 0L, numBytes) // Bytes require * 4 + channel.close() + tempFileInput.close() assert(writer.getNumBytesWritten == D_LEN * 4) - writer.close } mapOutputWriter.commitAllPartitions() val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray @@ -180,7 +190,7 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft test("copyStreams with an outputstream") { (0 until NUM_PARTITIONS).foreach{ p => val writer = mapOutputWriter.getPartitionWriter(p) - val stream = writer.toStream() + val stream = writer.openStream() val byteBuffer = ByteBuffer.allocate(D_LEN * 4) val intBuffer = byteBuffer.asIntBuffer() intBuffer.put(data(p)) @@ -189,7 +199,6 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft in.close() stream.close() assert(writer.getNumBytesWritten == D_LEN * 4) - writer.close } mapOutputWriter.commitAllPartitions() val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray @@ -201,7 +210,7 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft test("copyStreamsWithNIO with a channel") { (0 until NUM_PARTITIONS).foreach{ p => val writer = mapOutputWriter.getPartitionWriter(p) - val channel = writer.toChannel() + val channel = writer.asInstanceOf[SupportsTransferTo].openTransferrableChannel() val byteBuffer = ByteBuffer.allocate(D_LEN * 4) val intBuffer = byteBuffer.asIntBuffer() intBuffer.put(data(p)) @@ -209,10 +218,9 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft out.write(byteBuffer.array()) out.close() val in = new FileInputStream(tempFile) - Utils.copyFileStreamNIO(in.getChannel, channel, 0, D_LEN * 4) - in.close() + channel.transferFrom(in.getChannel, 0L, byteBuffer.remaining()) + channel.close() assert(writer.getNumBytesWritten == D_LEN * 4) - writer.close } mapOutputWriter.commitAllPartitions() val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray
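
To make the new plugin surface concrete, here is a minimal, hypothetical implementation sketch of the SupportsTransferTo / TransferrableWritableByteChannel contract added above. LocalFilePartitionWriter is not part of this diff; the method shapes (openStream, openTransferrableChannel, getNumBytesWritten, transferFrom) follow the interfaces introduced here, but the local-file handling and byte counting are illustrative assumptions only.

// Hypothetical example (not part of this change): a partition writer that appends one
// partition's bytes to a local file and supports both streaming and transferTo-style writes.
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.channels.FileChannel;

import org.apache.spark.api.shuffle.SupportsTransferTo;
import org.apache.spark.api.shuffle.TransferrableWritableByteChannel;

final class LocalFilePartitionWriter implements SupportsTransferTo {

  private final File partitionFile;
  private long bytesWritten = 0L;

  LocalFilePartitionWriter(File partitionFile) {
    this.partitionFile = partitionFile;
  }

  @Override
  public OutputStream openStream() throws IOException {
    // Count bytes as they pass through so getNumBytesWritten() can report them afterwards.
    return new FileOutputStream(partitionFile, /* append = */ true) {
      @Override
      public void write(int b) throws IOException {
        super.write(b);
        bytesWritten += 1;
      }

      @Override
      public void write(byte[] buf) throws IOException {
        write(buf, 0, buf.length);
      }

      @Override
      public void write(byte[] buf, int off, int len) throws IOException {
        super.write(buf, off, len);
        bytesWritten += len;
      }
    };
  }

  @Override
  public TransferrableWritableByteChannel openTransferrableChannel() throws IOException {
    FileChannel out = new FileOutputStream(partitionFile, true).getChannel();
    return new TransferrableWritableByteChannel() {
      @Override
      public void transferFrom(
          FileChannel source, long transferStartPosition, long numBytesToTransfer)
          throws IOException {
        long transferred = 0L;
        // FileChannel#transferTo may move fewer bytes than requested, so loop until done.
        while (transferred < numBytesToTransfer) {
          transferred += source.transferTo(
              transferStartPosition + transferred, numBytesToTransfer - transferred, out);
        }
        bytesWritten += transferred;
      }

      @Override
      public void close() throws IOException {
        out.close();
      }
    };
  }

  @Override
  public long getNumBytesWritten() {
    return bytesWritten;
  }
}

A writer such as BypassMergeSortShuffleWriter then calls openTransferrableChannel() when spark.file.transferTo is enabled, and otherwise falls back to wrapping openStream() in a DefaultTransferrableWritableByteChannel, as shown in the diff above.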