
Commit e3b8855

Cleanup in UnsafeShuffleWriter
1 parent 4a2c785

File tree: 2 files changed (+48, -22 lines)

core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java

Lines changed: 44 additions & 19 deletions
@@ -102,7 +102,7 @@ public UnsafeShuffleWriter(
       UnsafeShuffleHandle<K, V> handle,
       int mapId,
       TaskContext taskContext,
-      SparkConf sparkConf) {
+      SparkConf sparkConf) throws IOException {
     final int numPartitions = handle.dependency().partitioner().numPartitions();
     if (numPartitions > PackedRecordPointer.MAXIMUM_PARTITION_ID) {
       throw new IllegalArgumentException(
@@ -123,27 +123,29 @@ public UnsafeShuffleWriter(
     this.taskContext = taskContext;
     this.sparkConf = sparkConf;
     this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
+    open();
   }
 
+  /**
+   * This convenience method should only be called in test code.
+   */
+  @VisibleForTesting
   public void write(Iterator<Product2<K, V>> records) throws IOException {
     write(JavaConversions.asScalaIterator(records));
   }
 
   @Override
   public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
+    boolean success = false;
     try {
       while (records.hasNext()) {
         insertRecordIntoSorter(records.next());
       }
       closeAndWriteOutput();
-    } catch (Exception e) {
-      // Unfortunately, we have to catch Exception here in order to ensure proper cleanup after
-      // errors because Spark's Scala code, or users' custom Serializers, might throw arbitrary
-      // unchecked exceptions.
-      try {
+      success = true;
+    } finally {
+      if (!success) {
         sorter.cleanupAfterError();
-      } finally {
-        throw new IOException("Error during shuffle write", e);
       }
     }
   }
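
Note: the rewritten write() replaces the catch-Exception-and-rethrow block with a success flag checked in a finally block, so the original exception propagates unwrapped while sorter cleanup still runs on failure. A minimal standalone sketch of that pattern follows; the class and the doWork/cleanupAfterError helpers are hypothetical stand-ins, not the Spark code.

import java.io.IOException;

public class SuccessFlagCleanupExample {
  // Stand-in for the sorter's cleanupAfterError(); hypothetical, for illustration only.
  static void cleanupAfterError() {
    System.out.println("cleaning up partial state");
  }

  static void doWork(boolean fail) throws IOException {
    if (fail) {
      throw new IOException("simulated failure during shuffle write");
    }
  }

  static void write(boolean fail) throws IOException {
    boolean success = false;
    try {
      doWork(fail);
      success = true;            // only reached if no exception was thrown
    } finally {
      if (!success) {
        cleanupAfterError();     // run cleanup, but let the original exception propagate as-is
      }
    }
  }

  public static void main(String[] args) throws IOException {
    write(false);                // completes normally, no cleanup
    try {
      write(true);               // cleanup runs, original IOException propagates unwrapped
    } catch (IOException e) {
      System.out.println("caught: " + e.getMessage());
    }
  }
}
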
@@ -165,9 +167,6 @@ private void open() throws IOException {
 
   @VisibleForTesting
   void closeAndWriteOutput() throws IOException {
-    if (sorter == null) {
-      open();
-    }
     serBuffer = null;
     serOutputStream = null;
     final SpillInfo[] spills = sorter.closeAndGetSpills();
@@ -187,10 +186,7 @@ void closeAndWriteOutput() throws IOException {
   }
 
   @VisibleForTesting
-  void insertRecordIntoSorter(Product2<K, V> record) throws IOException{
-    if (sorter == null) {
-      open();
-    }
+  void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
     final K key = record._1();
     final int partitionId = partitioner.getPartition(key);
     serBuffer.reset();
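
Note: because open() is now called from the constructor (which is why the constructor gained throws IOException), closeAndWriteOutput() and insertRecordIntoSorter() can drop their lazy `if (sorter == null) open();` guards. Below is a hedged sketch of that eager-initialization shape using hypothetical simplified names, not the real Spark classes.

import java.io.IOException;

// Hypothetical, simplified illustration of moving from lazy to eager initialization.
class EagerWriter {
  private final StringBuilder sorter;   // stand-in for the real in-memory sorter

  // Initialization happens up front, so construction itself can fail with IOException...
  EagerWriter() throws IOException {
    this.sorter = open();
  }

  private StringBuilder open() throws IOException {
    return new StringBuilder();         // real code would create spill files, buffers, etc.
  }

  // ...and every later method can assume `sorter` is non-null:
  // no `if (sorter == null) { open(); }` guard on the hot path.
  void insertRecord(String record) {
    sorter.append(record).append('\n');
  }
}
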
@@ -275,15 +271,29 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException {
     }
   }
 
+  /**
+   * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge,
+   * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in
+   * cases where the IO compression codec does not support concatenation of compressed data, or in
+   * cases where users have explicitly disabled use of {@code transferTo} in order to work around
+   * kernel bugs.
+   *
+   * @param spills the spills to merge.
+   * @param outputFile the file to write the merged data to.
+   * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled.
+   * @return the partition lengths in the merged file.
+   */
   private long[] mergeSpillsWithFileStream(
       SpillInfo[] spills,
       File outputFile,
       @Nullable CompressionCodec compressionCodec) throws IOException {
+    assert (spills.length >= 2);
     final int numPartitions = partitioner.numPartitions();
     final long[] partitionLengths = new long[numPartitions];
     final InputStream[] spillInputStreams = new FileInputStream[spills.length];
     OutputStream mergedFileOutputStream = null;
 
+    boolean threwException = true;
     try {
       for (int i = 0; i < spills.length; i++) {
         spillInputStreams[i] = new FileInputStream(spills[i].file);
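
Note: the threwException flag pairs with Guava's Closeables.close(closeable, swallowIOException): while unwinding from a failure, close() errors are swallowed so they cannot mask the original exception, but on the success path a close() failure is still surfaced. A self-contained sketch of the idiom, assuming only Guava on the classpath; the file-copy scenario is illustrative, not taken from the commit.

import com.google.common.io.Closeables;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

public class CloseablesExample {
  // Copies one file to another, closing both streams without masking a copy failure.
  static void copy(String from, String to) throws IOException {
    InputStream in = null;
    OutputStream out = null;
    boolean threwException = true;
    try {
      in = new FileInputStream(from);
      out = new FileOutputStream(to);
      byte[] buf = new byte[8192];
      int n;
      while ((n = in.read(buf)) != -1) {
        out.write(buf, 0, n);
      }
      threwException = false;     // reached only if the copy itself succeeded
    } finally {
      // If the copy failed, swallow IOExceptions from close() so the original error wins;
      // if it succeeded, let a close() failure propagate since it may indicate lost data.
      Closeables.close(in, threwException);
      Closeables.close(out, threwException);
    }
  }
}
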
@@ -311,22 +321,34 @@ private long[] mergeSpillsWithFileStream(
         mergedFileOutputStream.close();
         partitionLengths[partition] = (outputFile.length() - initialFileLength);
       }
+      threwException = false;
     } finally {
+      // To avoid masking exceptions that caused us to prematurely enter the finally block, only
+      // throw exceptions during cleanup if threwException == false.
       for (InputStream stream : spillInputStreams) {
-        Closeables.close(stream, false);
+        Closeables.close(stream, threwException);
       }
-      Closeables.close(mergedFileOutputStream, false);
+      Closeables.close(mergedFileOutputStream, threwException);
     }
     return partitionLengths;
   }
 
+  /**
+   * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes.
+   * This is only safe when the IO compression codec and serializer support concatenation of
+   * serialized streams.
+   *
+   * @return the partition lengths in the merged file.
+   */
   private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
+    assert (spills.length >= 2);
     final int numPartitions = partitioner.numPartitions();
     final long[] partitionLengths = new long[numPartitions];
     final FileChannel[] spillInputChannels = new FileChannel[spills.length];
     final long[] spillInputChannelPositions = new long[spills.length];
     FileChannel mergedFileOutputChannel = null;
 
+    boolean threwException = true;
     try {
       for (int i = 0; i < spills.length; i++) {
         spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
@@ -368,12 +390,15 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th
             "to disable this NIO feature."
           );
         }
+      threwException = false;
     } finally {
+      // To avoid masking exceptions that caused us to prematurely enter the finally block, only
+      // throw exceptions during cleanup if threwException == false.
       for (int i = 0; i < spills.length; i++) {
         assert(spillInputChannelPositions[i] == spills[i].file.length());
-        Closeables.close(spillInputChannels[i], false);
+        Closeables.close(spillInputChannels[i], threwException);
       }
-      Closeables.close(mergedFileOutputChannel, false);
+      Closeables.close(mergedFileOutputChannel, threwException);
     }
     return partitionLengths;
   }
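
Note: mergeSpillsWithTransferTo relies on FileChannel.transferTo to concatenate the spill files' bytes without routing them through user space, which is only safe when the compression codec and serializer tolerate concatenated streams. The sketch below shows the basic transferTo concatenation loop; the mergeFiles helper and its signature are hypothetical and simplified (no compression, no per-partition accounting).

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.channels.FileChannel;

public class TransferToConcatExample {
  // Appends each input file's bytes to outputFile using NIO transferTo.
  static void mergeFiles(File[] inputs, File outputFile) throws IOException {
    try (FileChannel out = new FileOutputStream(outputFile, true).getChannel()) {
      for (File input : inputs) {
        try (FileChannel in = new FileInputStream(input).getChannel()) {
          long position = 0;
          long size = in.size();
          // transferTo may copy fewer bytes than requested, so loop until the whole file is done.
          while (position < size) {
            position += in.transferTo(position, size - position, out);
          }
        }
      }
    }
  }
}
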

core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java

Lines changed: 4 additions & 3 deletions
@@ -194,7 +194,8 @@ public Tuple2<TempShuffleBlockId, File> answer(
     when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
   }
 
-  private UnsafeShuffleWriter<Object, Object> createWriter(boolean transferToEnabled) {
+  private UnsafeShuffleWriter<Object, Object> createWriter(
+      boolean transferToEnabled) throws IOException {
     conf.set("spark.file.transferTo", String.valueOf(transferToEnabled));
     return new UnsafeShuffleWriter<Object, Object>(
       blockManager,
@@ -242,12 +243,12 @@ private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException {
   }
 
   @Test(expected=IllegalStateException.class)
-  public void mustCallWriteBeforeSuccessfulStop() {
+  public void mustCallWriteBeforeSuccessfulStop() throws IOException {
     createWriter(false).stop(true);
   }
 
   @Test
-  public void doNotNeedToCallWriteBeforeUnsuccessfulStop() {
+  public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException {
     createWriter(false).stop(false);
   }
 