@@ -102,7 +102,7 @@ public UnsafeShuffleWriter(
102102 UnsafeShuffleHandle <K , V > handle ,
103103 int mapId ,
104104 TaskContext taskContext ,
105- SparkConf sparkConf ) {
105+ SparkConf sparkConf ) throws IOException {
106106 final int numPartitions = handle .dependency ().partitioner ().numPartitions ();
107107 if (numPartitions > PackedRecordPointer .MAXIMUM_PARTITION_ID ) {
108108 throw new IllegalArgumentException (
@@ -123,27 +123,29 @@ public UnsafeShuffleWriter(
123123 this .taskContext = taskContext ;
124124 this .sparkConf = sparkConf ;
125125 this .transferToEnabled = sparkConf .getBoolean ("spark.file.transferTo" , true );
126+ open ();
126127 }
127128
129+ /**
130+ * This convenience method should only be called in test code.
131+ */
132+ @ VisibleForTesting
128133 public void write (Iterator <Product2 <K , V >> records ) throws IOException {
129134 write (JavaConversions .asScalaIterator (records ));
130135 }
131136
132137 @ Override
133138 public void write (scala .collection .Iterator <Product2 <K , V >> records ) throws IOException {
139+ boolean success = false ;
134140 try {
135141 while (records .hasNext ()) {
136142 insertRecordIntoSorter (records .next ());
137143 }
138144 closeAndWriteOutput ();
139- } catch (Exception e ) {
140- // Unfortunately, we have to catch Exception here in order to ensure proper cleanup after
141- // errors because Spark's Scala code, or users' custom Serializers, might throw arbitrary
142- // unchecked exceptions.
143- try {
145+ success = true ;
146+ } finally {
147+ if (!success ) {
144148 sorter .cleanupAfterError ();
145- } finally {
146- throw new IOException ("Error during shuffle write" , e );
147149 }
148150 }
149151 }
@@ -165,9 +167,6 @@ private void open() throws IOException {
165167
166168 @ VisibleForTesting
167169 void closeAndWriteOutput () throws IOException {
168- if (sorter == null ) {
169- open ();
170- }
171170 serBuffer = null ;
172171 serOutputStream = null ;
173172 final SpillInfo [] spills = sorter .closeAndGetSpills ();
@@ -187,10 +186,7 @@ void closeAndWriteOutput() throws IOException {
187186 }
188187
189188 @ VisibleForTesting
190- void insertRecordIntoSorter (Product2 <K , V > record ) throws IOException {
191- if (sorter == null ) {
192- open ();
193- }
189+ void insertRecordIntoSorter (Product2 <K , V > record ) throws IOException {
194190 final K key = record ._1 ();
195191 final int partitionId = partitioner .getPartition (key );
196192 serBuffer .reset ();
@@ -275,15 +271,29 @@ private long[] mergeSpills(SpillInfo[] spills) throws IOException {
275271 }
276272 }
277273
274+ /**
275+ * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge,
276+ * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in
277+ * cases where the IO compression codec does not support concatenation of compressed data, or in
278+ * cases where users have explicitly disabled use of {@code transferTo} in order to work around
279+ * kernel bugs.
280+ *
281+ * @param spills the spills to merge.
282+ * @param outputFile the file to write the merged data to.
283+ * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled.
284+ * @return the partition lengths in the merged file.
285+ */
278286 private long [] mergeSpillsWithFileStream (
279287 SpillInfo [] spills ,
280288 File outputFile ,
281289 @ Nullable CompressionCodec compressionCodec ) throws IOException {
290+ assert (spills .length >= 2 );
282291 final int numPartitions = partitioner .numPartitions ();
283292 final long [] partitionLengths = new long [numPartitions ];
284293 final InputStream [] spillInputStreams = new FileInputStream [spills .length ];
285294 OutputStream mergedFileOutputStream = null ;
286295
296+ boolean threwException = true ;
287297 try {
288298 for (int i = 0 ; i < spills .length ; i ++) {
289299 spillInputStreams [i ] = new FileInputStream (spills [i ].file );
@@ -311,22 +321,34 @@ private long[] mergeSpillsWithFileStream(
311321 mergedFileOutputStream .close ();
312322 partitionLengths [partition ] = (outputFile .length () - initialFileLength );
313323 }
324+ threwException = false ;
314325 } finally {
326+ // To avoid masking exceptions that caused us to prematurely enter the finally block, only
327+ // throw exceptions during cleanup if threwException == false.
315328 for (InputStream stream : spillInputStreams ) {
316- Closeables .close (stream , false );
329+ Closeables .close (stream , threwException );
317330 }
318- Closeables .close (mergedFileOutputStream , false );
331+ Closeables .close (mergedFileOutputStream , threwException );
319332 }
320333 return partitionLengths ;
321334 }
322335
336+ /**
337+ * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes.
338+ * This is only safe when the IO compression codec and serializer support concatenation of
339+ * serialized streams.
340+ *
341+ * @return the partition lengths in the merged file.
342+ */
323343 private long [] mergeSpillsWithTransferTo (SpillInfo [] spills , File outputFile ) throws IOException {
344+ assert (spills .length >= 2 );
324345 final int numPartitions = partitioner .numPartitions ();
325346 final long [] partitionLengths = new long [numPartitions ];
326347 final FileChannel [] spillInputChannels = new FileChannel [spills .length ];
327348 final long [] spillInputChannelPositions = new long [spills .length ];
328349 FileChannel mergedFileOutputChannel = null ;
329350
351+ boolean threwException = true ;
330352 try {
331353 for (int i = 0 ; i < spills .length ; i ++) {
332354 spillInputChannels [i ] = new FileInputStream (spills [i ].file ).getChannel ();
@@ -368,12 +390,15 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th
368390 "to disable this NIO feature."
369391 );
370392 }
393+ threwException = false ;
371394 } finally {
395+ // To avoid masking exceptions that caused us to prematurely enter the finally block, only
396+ // throw exceptions during cleanup if threwException == false.
372397 for (int i = 0 ; i < spills .length ; i ++) {
373398 assert (spillInputChannelPositions [i ] == spills [i ].file .length ());
374- Closeables .close (spillInputChannels [i ], false );
399+ Closeables .close (spillInputChannels [i ], threwException );
375400 }
376- Closeables .close (mergedFileOutputChannel , false );
401+ Closeables .close (mergedFileOutputChannel , threwException );
377402 }
378403 return partitionLengths ;
379404 }