@@ -17,6 +17,7 @@

package org.apache.spark.api.shuffle;

import java.io.Closeable;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.channels.Channels;
@@ -31,12 +32,43 @@
* @since 3.0.0
*/
@Experimental
public interface ShufflePartitionWriter {
OutputStream openStream() throws IOException;
public interface ShufflePartitionWriter extends Closeable {

long closeAndGetLength();
/**
* Returns an underlying {@link OutputStream} that can write bytes to the underlying data store.
* <p>
* Note that this stream itself is not closed by the caller; close the stream in
* the implementation of this class's {@link #close()}.
*/
OutputStream toStream() throws IOException;
Author:
Renamed from openStream -> toStream in order to better indicate that this writer is still responsible for closing its own resources. Meaning "convert this writer to a stream", in a sense, rather than "open a stream to write contents". But there might be a better naming convention here.

Collaborator:
I see how the conversion logic can be advantageous. +1


default WritableByteChannel openChannel() throws IOException {
return Channels.newChannel(openStream());
/**
* Returns an underlying {@link WritableByteChannel} that can write bytes to the underlying data
* store.
* <p>
* Note that this channel itself is not closed by the caller; close the stream in
* the implementation of this class's {@link #close()}.
*/
default WritableByteChannel toChannel() throws IOException {
return Channels.newChannel(toStream());
}

/**
* Get the number of bytes written by this writer's stream returned by {@link #toStream()} or
* the channel returned by {@link #toChannel()}.
*/
long getNumBytesWritten();

/**
* Close all resources created by this ShufflePartitionWriter, via calls to {@link #toStream()}
* or {@link #toChannel()}.
* <p>
* This must always close any stream returned by {@link #toStream()}.
* <p>
* Note that the default version of {@link #toChannel()} returns a {@link WritableByteChannel}
* that does not itself need to be closed up front; only the underlying output stream given by
* {@link #toStream()} must be closed.
Reviewer:
nit: Do we need to mention this in the API doc? Seems like a comment we can add to the implementation class.

Author:
"default implementation" here refers to the default method we put above.

*/
@Override
void close() throws IOException;
}
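To make the new contract concrete, a minimal caller-side sketch is shown below. It assumes ShuffleMapOutputWriter lives in the same package and exposes getNextPartitionWriter() as in the diff that follows; the class and helper names here are purely illustrative.

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.OutputStream;

import org.apache.spark.api.shuffle.ShuffleMapOutputWriter;
import org.apache.spark.api.shuffle.ShufflePartitionWriter;

class PartitionWriterUsageSketch {
  // Copies one partition's spill file into the next partition writer and
  // returns the number of bytes that ended up in the shuffle output.
  static long writePartition(ShuffleMapOutputWriter mapOutputWriter, File spillFile)
      throws IOException {
    ShufflePartitionWriter writer = mapOutputWriter.getNextPartitionWriter();
    try {
      OutputStream out = writer.toStream(); // owned by the writer, not closed here
      try (FileInputStream in = new FileInputStream(spillFile)) {
        byte[] buf = new byte[8192];
        int n;
        while ((n = in.read(buf)) != -1) {
          out.write(buf, 0, n);
        }
      }
    } finally {
      writer.close(); // closes the stream returned by toStream()
    }
    return writer.getNumBytesWritten(); // read the count after the writer is closed
  }
}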
@@ -130,7 +130,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
public void write(Iterator<Product2<K, V>> records) throws IOException {
assert (partitionWriters == null);
ShuffleMapOutputWriter mapOutputWriter = shuffleWriteSupport
.createMapOutputWriter(shuffleId, mapId, numPartitions);
.createMapOutputWriter(shuffleId, mapId, numPartitions);
try {
if (!records.hasNext()) {
partitionLengths = new long[numPartitions];
@@ -144,11 +144,11 @@ public void write(Iterator<Product2<K, V>> records) throws IOException {
partitionWriterSegments = new FileSegment[numPartitions];
for (int i = 0; i < numPartitions; i++) {
final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
blockManager.diskBlockManager().createTempShuffleBlock();
blockManager.diskBlockManager().createTempShuffleBlock();
final File file = tempShuffleBlockIdPlusFile._2();
final BlockId blockId = tempShuffleBlockIdPlusFile._1();
partitionWriters[i] =
blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics);
}
// Creating the file to write to and creating a disk writer both involve interacting with
// the disk, and can take a long time in aggregate when we open many files, so should be
@@ -202,20 +202,22 @@ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) thro
for (int i = 0; i < numPartitions; i++) {
final File file = partitionWriterSegments[i].file();
boolean copyThrewException = true;
ShufflePartitionWriter writer = mapOutputWriter.getNextPartitionWriter();
if (transferToEnabled) {
WritableByteChannel outputChannel = writer.openChannel();
if (file.exists()) {
FileInputStream in = new FileInputStream(file);
try (FileChannel inputChannel = in.getChannel()){
Utils.copyFileStreamNIO(inputChannel, outputChannel, 0, inputChannel.size());
copyThrewException = false;
} finally {
Closeables.close(in, copyThrewException);
ShufflePartitionWriter writer = null;
try {
writer = mapOutputWriter.getNextPartitionWriter();
if (transferToEnabled) {
WritableByteChannel outputChannel = writer.toChannel();
if (file.exists()) {
FileInputStream in = new FileInputStream(file);
try (FileChannel inputChannel = in.getChannel()) {
Utils.copyFileStreamNIO(inputChannel, outputChannel, 0, inputChannel.size());
copyThrewException = false;
} finally {
Closeables.close(in, copyThrewException);
}
}
}
} else {
try (OutputStream tempOutputStream = writer.openStream()) {
} else {
OutputStream tempOutputStream = writer.toStream();
if (file.exists()) {
FileInputStream in = new FileInputStream(file);
try {
@@ -226,11 +228,14 @@ private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) thro
}
}
}
if (file.exists() && !file.delete()) {
logger.error("Unable to delete file for partition {}", i);
}
} finally {
Closeables.close(writer, copyThrewException);
}
lengths[i] = writer.closeAndGetLength();
if (file.exists() && !file.delete()) {
logger.error("Unable to delete file for partition {}", i);
}

lengths[i] = writer.getNumBytesWritten();
}
} finally {
writeMetrics.incWriteTime(System.nanoTime() - writeStartTime);
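The copyThrewException flag above drives Guava's Closeables.close(closeable, swallowIOException): an IOException thrown while closing is swallowed only when the copy itself already failed, so it never masks the original error, and is rethrown when the copy succeeded. A minimal, self-contained sketch of the same idiom follows (a plain file-to-file copy, used purely for illustration).

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;

import com.google.common.io.Closeables;

class SwallowOnFailureCopy {
  // Copies src to dst; an IOException from close() is swallowed only if the
  // copy itself already failed, so it never masks the original exception.
  static void copy(File src, File dst) throws IOException {
    boolean copyThrewException = true;
    FileOutputStream out = new FileOutputStream(dst);
    try (FileInputStream in = new FileInputStream(src)) {
      byte[] buf = new byte[8192];
      int n;
      while ((n = in.read(buf)) != -1) {
        out.write(buf, 0, n);
      }
      copyThrewException = false; // reached only when the copy succeeded
    } finally {
      // swallowIOException is true exactly when the copy failed
      Closeables.close(out, copyThrewException);
    }
  }
}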
@@ -73,7 +73,7 @@ public DefaultShuffleMapOutputWriter(
}

@Override
public ShufflePartitionWriter getNextPartitionWriter() throws IOException {
public ShufflePartitionWriter getNextPartitionWriter() {
return new DefaultShufflePartitionWriter(currPartitionId++);
}

@@ -97,7 +97,7 @@ public void commitAllPartitions() throws IOException {
}

@Override
public void abort(Throwable error) throws IOException {
public void abort(Throwable error) {
try {
cleanUp();
} catch (Exception e) {
@@ -107,7 +107,7 @@ public void abort(Throwable error) throws IOException {
log.warn("Failed to delete temporary shuffle file at {}", outputTempFile.getAbsolutePath());
}
if (!outputFile.delete() && outputFile.exists()) {
log.warn("Failed to delete outputshuffle file at {}", outputFile.getAbsolutePath());
log.warn("Failed to delete output shuffle file at {}", outputFile.getAbsolutePath());
}
}

@@ -154,42 +154,42 @@ private DefaultShufflePartitionWriter(int partitionId) {
}

@Override
public OutputStream openStream() throws IOException {
public OutputStream toStream() throws IOException {
initStream();
stream = new PartitionWriterStream();
return stream;
}

@Override
public long closeAndGetLength() {
public FileChannel toChannel() throws IOException {
initChannel();
currChannelPosition = outputFileChannel.position();
return outputFileChannel;
}

@Override
public long getNumBytesWritten() {
if (outputFileChannel != null && stream == null) {
try {
long newPosition = outputFileChannel.position();
long length = newPosition - currChannelPosition;
partitionLengths[partitionId] = length;
currChannelPosition = newPosition;
return length;
return newPosition - currChannelPosition;
} catch (Exception e) {
log.error("The currPartition is: " + partitionId, e);
throw new IllegalStateException("Attempting to calculate position of file channel", e);
log.error("The currPartition is: {}", partitionId, e);
throw new IllegalStateException("Failed to calculate position of file channel", e);
}
} else if (stream != null) {
return stream.getCount();
Author:
Observe that flush isn't strictly necessary here. Getting the count just reads the counting output stream's view of the number of bytes written, which is correct.

Collaborator:
Well, you are calling stream.close() below (before calling getNumBytesWritten) so it will call flush via that, no?

Author:
Not necessarily - see BypassMergeSortShuffleWriter. In BypassMergeSortShuffleWriter we call getNumBytesWritten after transferring the bytes from the spill files to the output writer, before the writer is closed. I think we're effectively counting on getNumBytesWritten telling the truth about how many bytes were actually written to the streams / channels it gives back, which seems reasonable enough - the method does what it says it does.

} else {
try {
stream.close();
} catch (Exception e) {
throw new IllegalStateException("Attempting to close output stream", e);
}
int length = stream.getCount();
partitionLengths[partitionId] = length;
return length;
return 0;
}
}

@Override
public FileChannel openChannel() throws IOException {
initChannel();
currChannelPosition = outputFileChannel.position();
return outputFileChannel;
public void close() throws IOException {
Reviewer:
I think it might be good to keep track of the outputFileChannel close as well. Basically, we should ensure that this method is called before getNumBytesWritten() is called. Otherwise, I might actually be in favor of having close() return the long to ensure that getNumBytesWritten() can never be called before closing.

if (stream != null) {
stream.close();
}
partitionLengths[partitionId] = getNumBytesWritten();
}
}
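For comparison, the reviewer's suggestion of returning the count from close() might look roughly like the sketch below. This is purely hypothetical and not part of the PR; the method name is made up.

import java.io.IOException;
import java.io.OutputStream;

// Hypothetical variant of the interface in which the byte count is only
// available as the result of closing, so it can never be read before close.
interface ClosingShufflePartitionWriter {
  OutputStream toStream() throws IOException;

  // Closes all resources and returns the number of bytes written.
  long closeAndGetNumBytesWritten() throws IOException;
}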

@@ -218,7 +218,9 @@ public void close() throws IOException {

@Override
public void flush() throws IOException {
outputBufferedFileStream.flush();
if (!isClosed) {
outputBufferedFileStream.flush();
}
}
}
}
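The flush discussion above hinges on where the byte count is taken: a counting wrapper increments its count in write(), before the bytes reach any buffering below it, so getCount() is accurate even if nothing has been flushed yet. A minimal sketch of such a wrapper follows (not the actual PartitionWriterStream, just the counting behaviour).

import java.io.FilterOutputStream;
import java.io.IOException;
import java.io.OutputStream;

// The count reflects bytes handed to write(), regardless of whether the
// underlying (possibly buffered) stream has flushed them to disk yet.
class CountingOutputStreamSketch extends FilterOutputStream {
  private long count = 0;

  CountingOutputStreamSketch(OutputStream out) {
    super(out);
  }

  @Override
  public void write(int b) throws IOException {
    out.write(b);
    count++;
  }

  @Override
  public void write(byte[] b, int off, int len) throws IOException {
    out.write(b, off, len);
    count += len;
  }

  long getCount() {
    return count;
  }
}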
38 changes: 22 additions & 16 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -340,7 +340,11 @@ private[spark] object Utils extends Logging {
output: WritableByteChannel,
startPosition: Long,
bytesToCopy: Long): Unit = {
// val initialPos = output.position()
val outputInitialState = output match {
case outputFileChannel: FileChannel =>
Some((outputFileChannel.position(), outputFileChannel))
case _ => None
}
var count = 0L
// In case transferTo method transferred less data than we have required.
while (count < bytesToCopy) {
@@ -349,21 +353,23 @@
assert(count == bytesToCopy,
s"request to copy $bytesToCopy bytes, but actually copied $count bytes.")

// // Check the position after transferTo loop to see if it is in the right position and
// // give user information if not.
// // Position will not be increased to the expected length after calling transferTo in
// // kernel version 2.6.32, this issue can be seen in
// // https://bugs.openjdk.java.net/browse/JDK-7052359
// // This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948).
// val finalPos = output.position()
// val expectedPos = initialPos + bytesToCopy
// assert(finalPos == expectedPos,
// s"""
// |Current position $finalPos do not equal to expected position $expectedPos
// |after transferTo, please check your kernel version to see if it is 2.6.32,
// |this is a kernel bug which will lead to unexpected behavior when using transferTo.
// |You can set spark.file.transferTo = false to disable this NIO feature.
// """.stripMargin)
// Check the position after transferTo loop to see if it is in the right position and
// give user information if not.
// Position will not be increased to the expected length after calling transferTo in
// kernel version 2.6.32, this issue can be seen in
// https://bugs.openjdk.java.net/browse/JDK-7052359
// This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948).
outputInitialState.foreach { case (initialPos, outputFileChannel) =>
val finalPos = outputFileChannel.position()
val expectedPos = initialPos + bytesToCopy
assert(finalPos == expectedPos,
s"""
|Current position $finalPos do not equal to expected position $expectedPos
|after transferTo, please check your kernel version to see if it is 2.6.32,
|this is a kernel bug which will lead to unexpected behavior when using transferTo.
|You can set spark.file.transferTo = false to disable this NIO feature.
""".stripMargin)
}
}

/**
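The guard re-enabled above only applies when the destination is a FileChannel, since a generic WritableByteChannel (for example one wrapping a plain OutputStream) has no position() to compare against. A rough Java rendering of the same idea is sketched below; it is illustrative only, the real change is the Scala pattern match above.

import java.io.IOException;
import java.nio.channels.FileChannel;
import java.nio.channels.WritableByteChannel;

class NioCopyPositionCheck {
  // Capture the initial position only when the destination is a FileChannel,
  // copy with transferTo, then verify the position advanced by bytesToCopy.
  static void copyWithCheck(FileChannel input, WritableByteChannel output,
                            long startPosition, long bytesToCopy) throws IOException {
    Long initialPos = (output instanceof FileChannel)
        ? ((FileChannel) output).position()
        : null;

    long count = 0L;
    while (count < bytesToCopy) {
      count += input.transferTo(startPosition + count, bytesToCopy - count, output);
    }

    if (initialPos != null) {
      long finalPos = ((FileChannel) output).position();
      if (finalPos != initialPos + bytesToCopy) {
        // Mirrors the kernel 2.6.32 transferTo issue guarded against above
        // (JDK-7052359 / SPARK-3948).
        throw new IllegalStateException("transferTo did not advance the output channel "
            + "position as expected: " + finalPos + " vs " + (initialPos + bytesToCopy));
      }
    }
  }
}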
@@ -134,13 +134,14 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft
test("writing to an outputstream") {
(0 until NUM_PARTITIONS).foreach{ p =>
val writer = mapOutputWriter.getNextPartitionWriter
val stream = writer.openStream()
val stream = writer.toStream()
data(p).foreach { i => stream.write(i)}
stream.close()
intercept[IllegalStateException] {
stream.write(p)
}
assert(writer.closeAndGetLength() == D_LEN)
assert(writer.getNumBytesWritten() == D_LEN)
writer.close
}
mapOutputWriter.commitAllPartitions()
val partitionLengths = (0 until NUM_PARTITIONS).map { _ => D_LEN.toDouble}.toArray
@@ -152,14 +153,15 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft
test("writing to a channel") {
(0 until NUM_PARTITIONS).foreach{ p =>
val writer = mapOutputWriter.getNextPartitionWriter
val channel = writer.openChannel()
val channel = writer.toChannel()
val byteBuffer = ByteBuffer.allocate(D_LEN * 4)
val intBuffer = byteBuffer.asIntBuffer()
intBuffer.put(data(p))
assert(channel.isOpen)
channel.write(byteBuffer)
// Bytes require * 4
assert(writer.closeAndGetLength == D_LEN * 4)
assert(writer.getNumBytesWritten == D_LEN * 4)
writer.close
}
mapOutputWriter.commitAllPartitions()
val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray
@@ -171,15 +173,16 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft
test("copyStreams with an outputstream") {
(0 until NUM_PARTITIONS).foreach{ p =>
val writer = mapOutputWriter.getNextPartitionWriter
val stream = writer.openStream()
val stream = writer.toStream()
val byteBuffer = ByteBuffer.allocate(D_LEN * 4)
val intBuffer = byteBuffer.asIntBuffer()
intBuffer.put(data(p))
val in: InputStream = new ByteArrayInputStream(byteBuffer.array())
Utils.copyStream(in, stream, false, false)
in.close()
stream.close()
assert(writer.closeAndGetLength == D_LEN * 4)
assert(writer.getNumBytesWritten == D_LEN * 4)
writer.close
}
mapOutputWriter.commitAllPartitions()
val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray
@@ -191,7 +194,7 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft
test("copyStreamsWithNIO with a channel") {
(0 until NUM_PARTITIONS).foreach{ p =>
val writer = mapOutputWriter.getNextPartitionWriter
val channel = writer.openChannel()
val channel = writer.toChannel()
val byteBuffer = ByteBuffer.allocate(D_LEN * 4)
val intBuffer = byteBuffer.asIntBuffer()
intBuffer.put(data(p))
@@ -201,7 +204,8 @@ class DefaultShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAft
val in = new FileInputStream(tempFile)
Utils.copyFileStreamNIO(in.getChannel, channel, 0, D_LEN * 4)
in.close()
assert(writer.closeAndGetLength == D_LEN * 4)
assert(writer.getNumBytesWritten == D_LEN * 4)
writer.close
}
mapOutputWriter.commitAllPartitions()
val partitionLengths = (0 until NUM_PARTITIONS).map { _ => (D_LEN * 4).toDouble}.toArray
@@ -91,6 +91,6 @@ case class InputFileBlockLength() extends LeafExpression with Nondeterministic {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val className = InputFileBlockHolder.getClass.getName.stripSuffix("$")
val typeDef = s"final ${CodeGenerator.javaType(dataType)}"
ev.copy(code = code"$typeDef ${ev.value} = $className.getLength();", isNull = FalseLiteral)
ev.copy(code = code"$typeDef ${ev.value} = $className.getNumBytesWritten();", isNull = FalseLiteral)
}
}