- * Note that this stream itself is not closed by the caller; close the stream in the
- * implementation of this interface's {@link #close()}.
+ * Opens and returns an {@link OutputStream} that can write bytes to the underlying
+ * data store.
*/
- OutputStream toStream() throws IOException;
+ OutputStream openStream() throws IOException;
/**
- * Returns an underlying {@link WritableByteChannel} that can write bytes to the underlying data
- * store.
- *
- * Note that this channel itself is not closed by the caller; close the channel in the
- * implementation of this interface's {@link #close()}.
- */
- default WritableByteChannel toChannel() throws IOException {
- return Channels.newChannel(toStream());
- }
-
- /**
- * Get the number of bytes written by this writer's stream returned by {@link #toStream()} or
- * the channel returned by {@link #toChannel()}.
+ * Get the number of bytes written by this writer's stream returned by {@link #openStream()}.
*/
long getNumBytesWritten();
-
- /**
- * Close all resources created by this ShufflePartitionWriter, via calls to {@link #toStream()}
- * or {@link #toChannel()}.
- *
- * This must always close any stream returned by {@link #toStream()}.
- *
- * Note that the default version of {@link #toChannel()} returns a {@link WritableByteChannel}
- * that does not itself need to be closed up front; only the underlying output stream given by
- * {@link #toStream()} must be closed.
- */
- @Override
- void close() throws IOException;
}
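
With toChannel() and close() removed, ShufflePartitionWriter is reduced to openStream() and getNumBytesWritten(), and callers are now responsible for closing the streams they open. A minimal sketch of a plugin-side implementation under that contract, assuming the interface now exposes only the two methods visible above (the FileBackedPartitionWriter name and the use of Guava's CountingOutputStream are illustrative assumptions, not part of this patch):

    import java.io.File;
    import java.io.FileOutputStream;
    import java.io.IOException;
    import java.io.OutputStream;

    import com.google.common.io.CountingOutputStream;

    import org.apache.spark.api.shuffle.ShufflePartitionWriter;

    // Hypothetical plugin-side writer: the stream is opened lazily, the caller closes it,
    // and getNumBytesWritten() reports whatever that stream wrote.
    final class FileBackedPartitionWriter implements ShufflePartitionWriter {
      private final File partitionFile;
      private CountingOutputStream openedStream;

      FileBackedPartitionWriter(File partitionFile) {
        this.partitionFile = partitionFile;
      }

      @Override
      public OutputStream openStream() throws IOException {
        if (openedStream == null) {
          openedStream = new CountingOutputStream(new FileOutputStream(partitionFile));
        }
        return openedStream;
      }

      @Override
      public long getNumBytesWritten() {
        return openedStream == null ? 0L : openedStream.getCount();
      }
    }
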
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java
index 6c69d5db9fd06..7e2b6cf4133fd 100644
--- a/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java
+++ b/core/src/main/java/org/apache/spark/api/shuffle/ShuffleWriteSupport.java
@@ -19,7 +19,7 @@
import java.io.IOException;
-import org.apache.http.annotation.Experimental;
+import org.apache.spark.annotation.Experimental;
/**
* :: Experimental ::
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java b/core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java
new file mode 100644
index 0000000000000..866b61d0bafd9
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/shuffle/SupportsTransferTo.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.shuffle;
+
+import java.io.IOException;
+
+import org.apache.spark.annotation.Experimental;
+
+/**
+ * :: Experimental ::
+ * Indicates that partition writers can transfer bytes directly from input byte channels to
+ * output channels that stream data to the underlying shuffle partition storage medium.
+ *
+ * This API is separated out for advanced users because it only needs to be used for
+ * specific low-level optimizations. The idea is that the returned channel can transfer bytes
+ * from the input file channel out to the backing storage system without copying data into
+ * memory.
+ *
+ * Most shuffle plugin implementations should use {@link ShufflePartitionWriter} instead.
+ *
+ * @since 3.0.0
+ */
+@Experimental
+public interface SupportsTransferTo extends ShufflePartitionWriter {
+
+ /**
+ * Opens and returns a {@link TransferrableWritableByteChannel} for transferring bytes from
+ * input byte channels to the underlying shuffle data store.
+ */
+ TransferrableWritableByteChannel openTransferrableChannel() throws IOException;
+
+ /**
+   * Returns the number of bytes written by either this writer's output stream opened by
+ * {@link #openStream()} or the byte channel opened by {@link #openTransferrableChannel()}.
+ */
+ @Override
+ long getNumBytesWritten();
+}
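
For illustration, the caller-side probing pattern this interface enables looks roughly like the sketch below; PartitionFileCopier and copyPartitionFile are hypothetical names, and the logic mirrors what BypassMergeSortShuffleWriter and UnsafeShuffleWriter do later in this diff.

    import java.io.File;
    import java.io.IOException;
    import java.io.OutputStream;
    import java.nio.channels.FileChannel;
    import java.nio.file.Files;

    import org.apache.spark.api.shuffle.ShufflePartitionWriter;
    import org.apache.spark.api.shuffle.SupportsTransferTo;
    import org.apache.spark.api.shuffle.TransferrableWritableByteChannel;

    final class PartitionFileCopier {
      // Take the zero-copy path when the plugin opts in via SupportsTransferTo; otherwise
      // fall back to the plain stream from ShufflePartitionWriter#openStream().
      static void copyPartitionFile(File partitionFile, ShufflePartitionWriter writer)
          throws IOException {
        if (writer instanceof SupportsTransferTo) {
          try (FileChannel in = FileChannel.open(partitionFile.toPath());
              TransferrableWritableByteChannel out =
                  ((SupportsTransferTo) writer).openTransferrableChannel()) {
            out.transferFrom(in, 0L, in.size());
          }
        } else {
          try (OutputStream out = writer.openStream()) {
            Files.copy(partitionFile.toPath(), out);
          }
        }
      }
    }
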
diff --git a/core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java b/core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java
new file mode 100644
index 0000000000000..18234d7c4c944
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/shuffle/TransferrableWritableByteChannel.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.shuffle;
+
+import java.io.Closeable;
+import java.io.IOException;
+
+import java.nio.channels.FileChannel;
+import java.nio.channels.WritableByteChannel;
+import org.apache.spark.annotation.Experimental;
+
+/**
+ * :: Experimental ::
+ * Represents an output byte channel that can copy bytes from input file channels to some
+ * arbitrary storage system.
+ *
+ * This API is provided for advanced users who can transfer bytes from a file channel to
+ * some output sink without copying data into memory. Most users should not need to use
+ * this functionality; this is primarily provided for the built-in shuffle storage backends
+ * that persist shuffle files on local disk.
+ *
+ * For a simpler alternative, see {@link ShufflePartitionWriter}.
+ *
+ * @since 3.0.0
+ */
+@Experimental
+public interface TransferrableWritableByteChannel extends Closeable {
+
+ /**
+ * Copy all bytes from the source readable byte channel into this byte channel.
+ *
+ * @param source The file channel to transfer bytes from. Do not call anything on this channel
+ * other than {@link FileChannel#transferTo(long, long, WritableByteChannel)}.
+ * @param transferStartPosition Start position of the input file to transfer from.
+ * @param numBytesToTransfer Number of bytes to transfer from the given source.
+ */
+ void transferFrom(FileChannel source, long transferStartPosition, long numBytesToTransfer)
+ throws IOException;
+}
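
A minimal sketch of one way to implement this contract over a plain WritableByteChannel (LoopingTransferChannel is an illustrative name, not part of this patch). The loop matters because FileChannel#transferTo may move fewer bytes than requested per call, which is essentially what Utils.copyFileStreamNIO handles for the default implementation added below.

    import java.io.IOException;
    import java.nio.channels.FileChannel;
    import java.nio.channels.WritableByteChannel;

    import org.apache.spark.api.shuffle.TransferrableWritableByteChannel;

    // Hypothetical implementation over a plain WritableByteChannel. FileChannel#transferTo is
    // allowed to transfer fewer bytes than requested, so keep going until the whole range is
    // copied.
    final class LoopingTransferChannel implements TransferrableWritableByteChannel {
      private final WritableByteChannel sink;

      LoopingTransferChannel(WritableByteChannel sink) {
        this.sink = sink;
      }

      @Override
      public void transferFrom(
          FileChannel source, long transferStartPosition, long numBytesToTransfer)
          throws IOException {
        long transferred = 0L;
        while (transferred < numBytesToTransfer) {
          long copied = source.transferTo(
              transferStartPosition + transferred, numBytesToTransfer - transferred, sink);
          if (copied <= 0) {
            throw new IOException("Reached end of file before copying the requested bytes");
          }
          transferred += copied;
        }
      }

      @Override
      public void close() throws IOException {
        sink.close();
      }
    }
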
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index 22386c39aca0a..128b90429209e 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -21,12 +21,10 @@
import java.io.FileInputStream;
import java.io.IOException;
import java.io.OutputStream;
+import java.nio.channels.Channels;
import java.nio.channels.FileChannel;
-import java.nio.channels.WritableByteChannel;
import javax.annotation.Nullable;
-import org.apache.spark.api.java.Optional;
-import org.apache.spark.api.shuffle.MapShuffleLocations;
import scala.None$;
import scala.Option;
import scala.Product2;
@@ -38,19 +36,22 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.spark.Partitioner;
+import org.apache.spark.ShuffleDependency;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.Optional;
+import org.apache.spark.api.shuffle.MapShuffleLocations;
+import org.apache.spark.api.shuffle.SupportsTransferTo;
import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
import org.apache.spark.api.shuffle.ShufflePartitionWriter;
import org.apache.spark.api.shuffle.ShuffleWriteSupport;
+import org.apache.spark.api.shuffle.TransferrableWritableByteChannel;
import org.apache.spark.internal.config.package$;
-import org.apache.spark.Partitioner;
-import org.apache.spark.ShuffleDependency;
-import org.apache.spark.SparkConf;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
-import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.*;
import org.apache.spark.util.Utils;
@@ -90,7 +91,6 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter {
private final int mapId;
private final Serializer serializer;
private final ShuffleWriteSupport shuffleWriteSupport;
- private final IndexShuffleBlockResolver shuffleBlockResolver;
/** Array of file writers, one for each partition */
private DiskBlockObjectWriter[] partitionWriters;
@@ -107,7 +107,6 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter {
BypassMergeSortShuffleWriter(
BlockManager blockManager,
- IndexShuffleBlockResolver shuffleBlockResolver,
BypassMergeSortShuffleHandle handle,
int mapId,
SparkConf conf,
@@ -124,7 +123,6 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter {
this.numPartitions = partitioner.numPartitions();
this.writeMetrics = writeMetrics;
this.serializer = dep.serializer();
- this.shuffleBlockResolver = shuffleBlockResolver;
this.shuffleWriteSupport = shuffleWriteSupport;
}
@@ -209,40 +207,43 @@ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) thro
try {
for (int i = 0; i < numPartitions; i++) {
final File file = partitionWriterSegments[i].file();
- boolean copyThrewException = true;
- ShufflePartitionWriter writer = null;
- try {
- writer = mapOutputWriter.getPartitionWriter(i);
- if (!file.exists()) {
- copyThrewException = false;
- } else {
- if (transferToEnabled) {
- WritableByteChannel outputChannel = writer.toChannel();
- FileInputStream in = new FileInputStream(file);
- try (FileChannel inputChannel = in.getChannel()) {
- Utils.copyFileStreamNIO(inputChannel, outputChannel, 0, inputChannel.size());
- copyThrewException = false;
- } finally {
- Closeables.close(in, copyThrewException);
- }
- } else {
- OutputStream tempOutputStream = writer.toStream();
- FileInputStream in = new FileInputStream(file);
- try {
- Utils.copyStream(in, tempOutputStream, false, false);
- copyThrewException = false;
- } finally {
- Closeables.close(in, copyThrewException);
+ ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i);
+ if (file.exists()) {
+ boolean copyThrewException = true;
+ if (transferToEnabled) {
+ FileInputStream in = new FileInputStream(file);
+ TransferrableWritableByteChannel outputChannel = null;
+ try (FileChannel inputChannel = in.getChannel()) {
+ if (writer instanceof SupportsTransferTo) {
+ outputChannel = ((SupportsTransferTo) writer).openTransferrableChannel();
+ } else {
+ // Use the default transferrable writable channel anyway, in order to have parity with
+ // UnsafeShuffleWriter.
+ outputChannel = new DefaultTransferrableWritableByteChannel(
+ Channels.newChannel(writer.openStream()));
}
+ outputChannel.transferFrom(inputChannel, 0L, inputChannel.size());
+ copyThrewException = false;
+ } finally {
+ Closeables.close(in, copyThrewException);
+ Closeables.close(outputChannel, copyThrewException);
}
- if (!file.delete()) {
- logger.error("Unable to delete file for partition {}", i);
+ } else {
+ FileInputStream in = new FileInputStream(file);
+ OutputStream outputStream = null;
+ try {
+ outputStream = writer.openStream();
+ Utils.copyStream(in, outputStream, false, false);
+ copyThrewException = false;
+ } finally {
+ Closeables.close(in, copyThrewException);
+ Closeables.close(outputStream, copyThrewException);
}
}
- } finally {
- Closeables.close(writer, copyThrewException);
+ if (!file.delete()) {
+ logger.error("Unable to delete file for partition {}", i);
+ }
}
-
lengths[i] = writer.getNumBytesWritten();
}
} finally {
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java
new file mode 100644
index 0000000000000..64ce851e392d2
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/DefaultTransferrableWritableByteChannel.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.io.IOException;
+import java.nio.channels.FileChannel;
+import java.nio.channels.WritableByteChannel;
+import org.apache.spark.api.shuffle.TransferrableWritableByteChannel;
+import org.apache.spark.util.Utils;
+
+/**
+ * This is used when transferTo is enabled but the shuffle plugin has not implemented
+ * {@link org.apache.spark.api.shuffle.SupportsTransferTo}.
+ *
+ * This default implementation exists as a convenience for the unsafe and bypass-merge-sort
+ * shuffle writers.
+ */
+public class DefaultTransferrableWritableByteChannel implements TransferrableWritableByteChannel {
+
+ private final WritableByteChannel delegate;
+
+ public DefaultTransferrableWritableByteChannel(WritableByteChannel delegate) {
+ this.delegate = delegate;
+ }
+
+ @Override
+ public void transferFrom(
+ FileChannel source, long transferStartPosition, long numBytesToTransfer) {
+ Utils.copyFileStreamNIO(source, delegate, transferStartPosition, numBytesToTransfer);
+ }
+
+ @Override
+ public void close() throws IOException {
+ delegate.close();
+ }
+}
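
As a usage sketch (writer stands for any ShufflePartitionWriter obtained from a ShuffleMapOutputWriter, and TransferChannels.openFor is a hypothetical helper), this class lets call sites always work against TransferrableWritableByteChannel, wrapping the plugin's plain stream when it does not implement SupportsTransferTo; this is the same selection both writers below perform.

    import java.io.IOException;
    import java.nio.channels.Channels;

    import org.apache.spark.api.shuffle.ShufflePartitionWriter;
    import org.apache.spark.api.shuffle.SupportsTransferTo;
    import org.apache.spark.api.shuffle.TransferrableWritableByteChannel;
    import org.apache.spark.shuffle.sort.DefaultTransferrableWritableByteChannel;

    final class TransferChannels {
      // Prefer the plugin's own transferrable channel, otherwise adapt its output stream.
      static TransferrableWritableByteChannel openFor(ShufflePartitionWriter writer)
          throws IOException {
        return writer instanceof SupportsTransferTo
            ? ((SupportsTransferTo) writer).openTransferrableChannel()
            : new DefaultTransferrableWritableByteChannel(
                Channels.newChannel(writer.openStream()));
      }
    }
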
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 069716cc99db0..f147bd79773e1 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -17,14 +17,12 @@
package org.apache.spark.shuffle.sort;
+import java.nio.channels.Channels;
import javax.annotation.Nullable;
import java.io.*;
import java.nio.channels.FileChannel;
-import java.nio.channels.WritableByteChannel;
import java.util.Iterator;
-import org.apache.spark.api.java.Optional;
-import org.apache.spark.api.shuffle.MapShuffleLocations;
import scala.Option;
import scala.Product2;
import scala.collection.JavaConverters;
@@ -39,14 +37,17 @@
import org.apache.spark.*;
import org.apache.spark.annotation.Private;
+import org.apache.spark.api.java.Optional;
+import org.apache.spark.api.shuffle.MapShuffleLocations;
+import org.apache.spark.api.shuffle.TransferrableWritableByteChannel;
import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
import org.apache.spark.api.shuffle.ShufflePartitionWriter;
import org.apache.spark.api.shuffle.ShuffleWriteSupport;
+import org.apache.spark.api.shuffle.SupportsTransferTo;
import org.apache.spark.internal.config.package$;
import org.apache.spark.io.CompressionCodec;
import org.apache.spark.io.CompressionCodec$;
import org.apache.spark.io.NioBufferedFileInputStream;
-import org.apache.commons.io.output.CloseShieldOutputStream;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
@@ -54,11 +55,9 @@
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.SerializerInstance;
-import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.unsafe.Platform;
-import org.apache.spark.util.Utils;
@Private
public class UnsafeShuffleWriter extends ShuffleWriter {
@@ -71,7 +70,6 @@ public class UnsafeShuffleWriter extends ShuffleWriter {
static final int DEFAULT_INITIAL_SER_BUFFER_SIZE = 1024 * 1024;
private final BlockManager blockManager;
- private final IndexShuffleBlockResolver shuffleBlockResolver;
private final TaskMemoryManager memoryManager;
private final SerializerInstance serializer;
private final Partitioner partitioner;
@@ -107,7 +105,6 @@ private static final class MyByteArrayOutputStream extends ByteArrayOutputStream
public UnsafeShuffleWriter(
BlockManager blockManager,
- IndexShuffleBlockResolver shuffleBlockResolver,
TaskMemoryManager memoryManager,
SerializedShuffleHandle handle,
int mapId,
@@ -123,7 +120,6 @@ public UnsafeShuffleWriter(
" reduce partitions");
}
this.blockManager = blockManager;
- this.shuffleBlockResolver = shuffleBlockResolver;
this.memoryManager = memoryManager;
this.mapId = mapId;
final ShuffleDependency dep = handle.dependency();
@@ -364,45 +360,37 @@ private long[] mergeSpillsWithFileStream(
}
for (int partition = 0; partition < numPartitions; partition++) {
boolean copyThrewExecption = true;
- ShufflePartitionWriter writer = null;
+ ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition);
+ OutputStream partitionOutput = null;
try {
- writer = mapWriter.getPartitionWriter(partition);
- OutputStream partitionOutput = null;
- try {
- // Shield the underlying output stream from close() calls, so that we can close the
- // higher level streams to make sure all data is really flushed and internal state
- // is cleaned
- partitionOutput = new CloseShieldOutputStream(writer.toStream());
- partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput);
- if (compressionCodec != null) {
- partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
- }
- for (int i = 0; i < spills.length; i++) {
- final long partitionLengthInSpill = spills[i].partitionLengths[partition];
-
- if (partitionLengthInSpill > 0) {
- InputStream partitionInputStream = null;
- try {
- partitionInputStream = new LimitedInputStream(spillInputStreams[i],
- partitionLengthInSpill, false);
- partitionInputStream = blockManager.serializerManager().wrapForEncryption(
+ partitionOutput = writer.openStream();
+ partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput);
+ if (compressionCodec != null) {
+ partitionOutput = compressionCodec.compressedOutputStream(partitionOutput);
+ }
+ for (int i = 0; i < spills.length; i++) {
+ final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+
+ if (partitionLengthInSpill > 0) {
+ InputStream partitionInputStream = null;
+ try {
+ partitionInputStream = new LimitedInputStream(spillInputStreams[i],
+ partitionLengthInSpill, false);
+ partitionInputStream = blockManager.serializerManager().wrapForEncryption(
+ partitionInputStream);
+ if (compressionCodec != null) {
+ partitionInputStream = compressionCodec.compressedInputStream(
partitionInputStream);
- if (compressionCodec != null) {
- partitionInputStream = compressionCodec.compressedInputStream(
- partitionInputStream);
- }
- ByteStreams.copy(partitionInputStream, partitionOutput);
- } finally {
- partitionInputStream.close();
}
+ ByteStreams.copy(partitionInputStream, partitionOutput);
+ } finally {
+ partitionInputStream.close();
}
- copyThrewExecption = false;
}
- } finally {
- Closeables.close(partitionOutput, copyThrewExecption);
+ copyThrewExecption = false;
}
} finally {
- Closeables.close(writer, copyThrewExecption);
+ Closeables.close(partitionOutput, copyThrewExecption);
}
long numBytesWritten = writer.getNumBytesWritten();
partitionLengths[partition] = numBytesWritten;
@@ -443,26 +431,26 @@ private long[] mergeSpillsWithTransferTo(
}
for (int partition = 0; partition < numPartitions; partition++) {
boolean copyThrewExecption = true;
- ShufflePartitionWriter writer = null;
+ ShufflePartitionWriter writer = mapWriter.getPartitionWriter(partition);
+ TransferrableWritableByteChannel partitionChannel = null;
try {
- writer = mapWriter.getPartitionWriter(partition);
- WritableByteChannel channel = writer.toChannel();
+ partitionChannel = writer instanceof SupportsTransferTo ?
+ ((SupportsTransferTo) writer).openTransferrableChannel()
+ : new DefaultTransferrableWritableByteChannel(
+ Channels.newChannel(writer.openStream()));
for (int i = 0; i < spills.length; i++) {
long partitionLengthInSpill = 0L;
partitionLengthInSpill += spills[i].partitionLengths[partition];
final FileChannel spillInputChannel = spillInputChannels[i];
final long writeStartTime = System.nanoTime();
- Utils.copyFileStreamNIO(
- spillInputChannel,
- channel,
- spillInputChannelPositions[i],
- partitionLengthInSpill);
- copyThrewExecption = false;
+ partitionChannel.transferFrom(
+ spillInputChannel, spillInputChannelPositions[i], partitionLengthInSpill);
spillInputChannelPositions[i] += partitionLengthInSpill;
writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
}
+ copyThrewExecption = false;
} finally {
- Closeables.close(writer, copyThrewExecption);
+ Closeables.close(partitionChannel, copyThrewExecption);
}
long numBytes = writer.getNumBytesWritten();
partitionLengths[partition] = numBytes;
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java
index 926c3b9433990..e83db4e4bcef6 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/DefaultShuffleMapOutputWriter.java
@@ -24,18 +24,21 @@
import java.io.OutputStream;
import java.nio.channels.FileChannel;
-import org.apache.spark.api.java.Optional;
-import org.apache.spark.api.shuffle.MapShuffleLocations;
-import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations;
-import org.apache.spark.storage.BlockManagerId;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.Optional;
+import org.apache.spark.api.shuffle.MapShuffleLocations;
import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
import org.apache.spark.api.shuffle.ShufflePartitionWriter;
-import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
+import org.apache.spark.api.shuffle.SupportsTransferTo;
+import org.apache.spark.api.shuffle.TransferrableWritableByteChannel;
import org.apache.spark.internal.config.package$;
+import org.apache.spark.shuffle.sort.DefaultMapShuffleLocations;
+import org.apache.spark.shuffle.sort.DefaultTransferrableWritableByteChannel;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
+import org.apache.spark.storage.BlockManagerId;
import org.apache.spark.shuffle.IndexShuffleBlockResolver;
import org.apache.spark.storage.TimeTrackingOutputStream;
import org.apache.spark.util.Utils;
@@ -151,70 +154,70 @@ private void initChannel() throws IOException {
}
}
- private class DefaultShufflePartitionWriter implements ShufflePartitionWriter {
+ private class DefaultShufflePartitionWriter implements SupportsTransferTo {
private final int partitionId;
- private PartitionWriterStream stream = null;
+ private PartitionWriterStream partStream = null;
+ private PartitionWriterChannel partChannel = null;
private DefaultShufflePartitionWriter(int partitionId) {
this.partitionId = partitionId;
}
@Override
- public OutputStream toStream() throws IOException {
- if (outputFileChannel != null) {
- throw new IllegalStateException("Requested an output channel for a previous write but" +
- " now an output stream has been requested. Should not be using both channels" +
- " and streams to write.");
+ public OutputStream openStream() throws IOException {
+ if (partStream == null) {
+ if (outputFileChannel != null) {
+ throw new IllegalStateException("Requested an output channel for a previous write but" +
+ " now an output stream has been requested. Should not be using both channels" +
+ " and streams to write.");
+ }
+ initStream();
+ partStream = new PartitionWriterStream(partitionId);
}
- initStream();
- stream = new PartitionWriterStream();
- return stream;
+ return partStream;
}
@Override
- public FileChannel toChannel() throws IOException {
- if (stream != null) {
- throw new IllegalStateException("Requested an output stream for a previous write but" +
- " now an output channel has been requested. Should not be using both channels" +
- " and streams to write.");
+ public TransferrableWritableByteChannel openTransferrableChannel() throws IOException {
+ if (partChannel == null) {
+ if (partStream != null) {
+ throw new IllegalStateException("Requested an output stream for a previous write but" +
+ " now an output channel has been requested. Should not be using both channels" +
+ " and streams to write.");
+ }
+ initChannel();
+ partChannel = new PartitionWriterChannel(partitionId);
}
- initChannel();
- return outputFileChannel;
+ return partChannel;
}
@Override
public long getNumBytesWritten() {
- if (outputFileChannel != null && stream == null) {
+ if (partChannel != null) {
try {
- long newPosition = outputFileChannel.position();
- return newPosition - currChannelPosition;
- } catch (Exception e) {
- log.error("The partition which failed is: {}", partitionId, e);
- throw new IllegalStateException("Failed to calculate position of file channel", e);
+ return partChannel.getCount();
+ } catch (IOException e) {
+ throw new RuntimeException(e);
}
- } else if (stream != null) {
- return stream.getCount();
+ } else if (partStream != null) {
+ return partStream.getCount();
} else {
// Assume an empty partition if stream and channel are never created
return 0;
}
}
-
- @Override
- public void close() {
- if (stream != null) {
- // Closing is a no-op.
- stream.close();
- }
- partitionLengths[partitionId] = getNumBytesWritten();
- }
}
private class PartitionWriterStream extends OutputStream {
+ private final int partitionId;
private int count = 0;
private boolean isClosed = false;
+ PartitionWriterStream(int partitionId) {
+ this.partitionId = partitionId;
+ }
+
public int getCount() {
return count;
}
@@ -236,6 +239,7 @@ public void write(byte[] buf, int pos, int length) throws IOException {
@Override
public void close() {
isClosed = true;
+ partitionLengths[partitionId] = count;
}
private void verifyNotClosed() {
@@ -244,4 +248,24 @@ private void verifyNotClosed() {
}
}
}
+
+ private class PartitionWriterChannel extends DefaultTransferrableWritableByteChannel {
+
+ private final int partitionId;
+
+ PartitionWriterChannel(int partitionId) {
+ super(outputFileChannel);
+ this.partitionId = partitionId;
+ }
+
+ public long getCount() throws IOException {
+ long writtenPosition = outputFileChannel.position();
+ return writtenPosition - currChannelPosition;
+ }
+
+ @Override
+ public void close() throws IOException {
+ partitionLengths[partitionId] = getCount();
+ }
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 849050556c569..5fa9296b022ca 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -140,7 +140,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
new UnsafeShuffleWriter(
env.blockManager,
- shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
context.taskMemoryManager(),
unsafeShuffleHandle,
mapId,
@@ -151,7 +150,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
new BypassMergeSortShuffleWriter(
env.blockManager,
- shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
bypassMergeSortHandle,
mapId,
env.conf,
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index df5ce73b9acf1..14d34e1c47c8e 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -755,9 +755,6 @@ private[spark] class ExternalSorter[K, V, C](
if (partitionPairsWriter != null) {
partitionPairsWriter.close()
}
- if (partitionWriter != null) {
- partitionWriter.close()
- }
}
if (partitionWriter != null) {
lengths(partitionId) = partitionWriter.getNumBytesWritten
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala b/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala
index 6f19a2323efde..8538a78b377c8 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ShufflePartitionPairsWriter.scala
@@ -54,25 +54,17 @@ private[spark] class ShufflePartitionPairsWriter(
}
private def open(): Unit = {
- // The contract is that the partition writer is expected to close its own streams, but
- // the compressor will only flush the stream when it is specifically closed. So we want to
- // close objOut to flush the compressed bytes to the partition writer stream, but we don't want
- // to close the partition output stream in the process.
- partitionStream = new CloseShieldOutputStream(partitionWriter.toStream)
+ partitionStream = partitionWriter.openStream
wrappedStream = serializerManager.wrapStream(blockId, partitionStream)
objOut = serializerInstance.serializeStream(wrappedStream)
}
override def close(): Unit = {
if (isOpen) {
- // Closing objOut should propagate close to all inner layers
- // We can't close wrappedStream explicitly because closing objOut and closing wrappedStream
- // causes problems when closing compressed output streams twice.
objOut.close()
objOut = null
wrappedStream = null
partitionStream = null
- partitionWriter.close()
isOpen = false
updateBytesWritten()
}
@@ -96,10 +88,4 @@ private[spark] class ShufflePartitionPairsWriter(
writeMetrics.incBytesWritten(bytesWrittenDiff)
curNumBytesWritten = numBytesWritten
}
-
- private class CloseShieldOutputStream(delegate: OutputStream)
- extends FilterOutputStream(delegate) {
-
- override def close(): Unit = flush()
- }
}
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 5f0de31bd25e3..5ea0907277ebf 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -183,8 +183,7 @@ private UnsafeShuffleWriter