3737import org .mockito .invocation .InvocationOnMock ;
3838import org .mockito .stubbing .Answer ;
3939import static org .mockito .AdditionalAnswers .returnsFirstArg ;
40- import static org .mockito .AdditionalAnswers .returnsSecondArg ;
4140import static org .mockito .Answers .RETURNS_SMART_NULLS ;
4241import static org .mockito .Mockito .*;
4342
4443import org .apache .spark .*;
44+ import org .apache .spark .io .CompressionCodec$ ;
45+ import org .apache .spark .io .LZ4CompressionCodec ;
46+ import org .apache .spark .io .LZFCompressionCodec ;
47+ import org .apache .spark .io .SnappyCompressionCodec ;
4548import org .apache .spark .executor .ShuffleWriteMetrics ;
4649import org .apache .spark .executor .TaskMetrics ;
4750import org .apache .spark .network .util .LimitedInputStream ;
@@ -65,6 +68,7 @@ public class UnsafeShuffleWriterSuite {
6568 File tempDir ;
6669 long [] partitionSizesInMergedFile ;
6770 final LinkedList <File > spillFilesCreated = new LinkedList <File >();
71+ SparkConf conf ;
6872 final Serializer serializer = new KryoSerializer (new SparkConf ());
6973
7074 @ Mock (answer = RETURNS_SMART_NULLS ) ShuffleMemoryManager shuffleMemoryManager ;
@@ -74,10 +78,14 @@ public class UnsafeShuffleWriterSuite {
7478 @ Mock (answer = RETURNS_SMART_NULLS ) TaskContext taskContext ;
7579 @ Mock (answer = RETURNS_SMART_NULLS ) ShuffleDependency <Object , Object , Object > shuffleDep ;
7680
77- private static final class CompressStream extends AbstractFunction1 <OutputStream , OutputStream > {
81+ private final class CompressStream extends AbstractFunction1 <OutputStream , OutputStream > {
7882 @ Override
7983 public OutputStream apply (OutputStream stream ) {
80- return stream ;
84+ if (conf .getBoolean ("spark.shuffle.compress" , true )) {
85+ return CompressionCodec$ .MODULE$ .createCodec (conf ).compressedOutputStream (stream );
86+ } else {
87+ return stream ;
88+ }
8189 }
8290 }
8391
@@ -98,6 +106,7 @@ public void setUp() throws IOException {
98106 mergedOutputFile = File .createTempFile ("mergedoutput" , "" , tempDir );
99107 partitionSizesInMergedFile = null ;
100108 spillFilesCreated .clear ();
109+ conf = new SparkConf ();
101110
102111 when (shuffleMemoryManager .tryToAcquire (anyLong ())).then (returnsFirstArg ());
103112
@@ -123,8 +132,35 @@ public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Th
123132 );
124133 }
125134 });
126- when (blockManager .wrapForCompression (any (BlockId .class ), any (InputStream .class )))
127- .then (returnsSecondArg ());
135+ when (blockManager .wrapForCompression (any (BlockId .class ), any (InputStream .class ))).thenAnswer (
136+ new Answer <InputStream >() {
137+ @ Override
138+ public InputStream answer (InvocationOnMock invocation ) throws Throwable {
139+ assert (invocation .getArguments ()[0 ] instanceof TempShuffleBlockId );
140+ InputStream is = (InputStream ) invocation .getArguments ()[1 ];
141+ if (conf .getBoolean ("spark.shuffle.compress" , true )) {
142+ return CompressionCodec$ .MODULE$ .createCodec (conf ).compressedInputStream (is );
143+ } else {
144+ return is ;
145+ }
146+ }
147+ }
148+ );
149+
150+ when (blockManager .wrapForCompression (any (BlockId .class ), any (OutputStream .class ))).thenAnswer (
151+ new Answer <OutputStream >() {
152+ @ Override
153+ public OutputStream answer (InvocationOnMock invocation ) throws Throwable {
154+ assert (invocation .getArguments ()[0 ] instanceof TempShuffleBlockId );
155+ OutputStream os = (OutputStream ) invocation .getArguments ()[1 ];
156+ if (conf .getBoolean ("spark.shuffle.compress" , true )) {
157+ return CompressionCodec$ .MODULE$ .createCodec (conf ).compressedOutputStream (os );
158+ } else {
159+ return os ;
160+ }
161+ }
162+ }
163+ );
128164
129165 when (shuffleBlockResolver .getDataFile (anyInt (), anyInt ())).thenReturn (mergedOutputFile );
130166 doAnswer (new Answer <Void >() {
@@ -136,11 +172,11 @@ public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
136172 }).when (shuffleBlockResolver ).writeIndexFile (anyInt (), anyInt (), any (long [].class ));
137173
138174 when (diskBlockManager .createTempShuffleBlock ()).thenAnswer (
139- new Answer <Tuple2 <TempLocalBlockId , File >>() {
175+ new Answer <Tuple2 <TempShuffleBlockId , File >>() {
140176 @ Override
141- public Tuple2 <TempLocalBlockId , File > answer (
177+ public Tuple2 <TempShuffleBlockId , File > answer (
142178 InvocationOnMock invocationOnMock ) throws Throwable {
143- TempLocalBlockId blockId = new TempLocalBlockId (UUID .randomUUID ());
179+ TempShuffleBlockId blockId = new TempShuffleBlockId (UUID .randomUUID ());
144180 File file = File .createTempFile ("spillFile" , ".spill" , tempDir );
145181 spillFilesCreated .add (file );
146182 return Tuple2$ .MODULE$ .apply (blockId , file );
@@ -154,7 +190,6 @@ public Tuple2<TempLocalBlockId, File> answer(
154190 }
155191
156192 private UnsafeShuffleWriter <Object , Object > createWriter (boolean transferToEnabled ) {
157- SparkConf conf = new SparkConf ();
158193 conf .set ("spark.file.transferTo" , String .valueOf (transferToEnabled ));
159194 return new UnsafeShuffleWriter <Object , Object >(
160195 blockManager ,
@@ -164,7 +199,7 @@ private UnsafeShuffleWriter<Object, Object> createWriter(boolean transferToEnabl
164199 new UnsafeShuffleHandle <Object , Object >(0 , 1 , shuffleDep ),
165200 0 , // map id
166201 taskContext ,
167- new SparkConf ()
202+ conf
168203 );
169204 }
170205
@@ -183,8 +218,11 @@ private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException {
183218 if (partitionSize > 0 ) {
184219 InputStream in = new FileInputStream (mergedOutputFile );
185220 ByteStreams .skipFully (in , startOffset );
186- DeserializationStream recordsStream = serializer .newInstance ().deserializeStream (
187- new LimitedInputStream (in , partitionSize ));
221+ in = new LimitedInputStream (in , partitionSize );
222+ if (conf .getBoolean ("spark.shuffle.compress" , true )) {
223+ in = CompressionCodec$ .MODULE$ .createCodec (conf ).compressedInputStream (in );
224+ }
225+ DeserializationStream recordsStream = serializer .newInstance ().deserializeStream (in );
188226 Iterator <Tuple2 <Object , Object >> records = recordsStream .asKeyValueIterator ();
189227 while (records .hasNext ()) {
190228 Tuple2 <Object , Object > record = records .next ();
@@ -245,7 +283,15 @@ public void writeWithoutSpilling() throws Exception {
245283 assertSpillFilesWereCleanedUp ();
246284 }
247285
248- private void testMergingSpills (boolean transferToEnabled ) throws IOException {
286+ private void testMergingSpills (
287+ boolean transferToEnabled ,
288+ String compressionCodecName ) throws IOException {
289+ if (compressionCodecName != null ) {
290+ conf .set ("spark.shuffle.compress" , "true" );
291+ conf .set ("spark.io.compression.codec" , compressionCodecName );
292+ } else {
293+ conf .set ("spark.shuffle.compress" , "false" );
294+ }
249295 final UnsafeShuffleWriter <Object , Object > writer = createWriter (transferToEnabled );
250296 final ArrayList <Product2 <Object , Object >> dataToWrite =
251297 new ArrayList <Product2 <Object , Object >>();
@@ -265,25 +311,57 @@ private void testMergingSpills(boolean transferToEnabled) throws IOException {
265311 Assert .assertTrue (mergedOutputFile .exists ());
266312 Assert .assertEquals (2 , spillFilesCreated .size ());
267313
268- long sumOfPartitionSizes = 0 ;
269- for (long size : partitionSizesInMergedFile ) {
270- sumOfPartitionSizes += size ;
271- }
272- Assert .assertEquals (mergedOutputFile .length (), sumOfPartitionSizes );
314+ // This assertion only holds for the fast merging path:
315+ // long sumOfPartitionSizes = 0;
316+ // for (long size: partitionSizesInMergedFile) {
317+ // sumOfPartitionSizes += size;
318+ // }
319+ // Assert.assertEquals(sumOfPartitionSizes, mergedOutputFile.length());
320+ Assert .assertTrue (mergedOutputFile .length () > 0 );
273321 Assert .assertEquals (
274322 HashMultiset .create (dataToWrite ),
275323 HashMultiset .create (readRecordsFromFile ()));
276324 assertSpillFilesWereCleanedUp ();
277325 }
278326
279327 @ Test
280- public void mergeSpillsWithTransferTo () throws Exception {
281- testMergingSpills (true );
328+ public void mergeSpillsWithTransferToAndLZF () throws Exception {
329+ testMergingSpills (true , LZFCompressionCodec .class .getName ());
330+ }
331+
332+ @ Test
333+ public void mergeSpillsWithFileStreamAndLZF () throws Exception {
334+ testMergingSpills (false , LZFCompressionCodec .class .getName ());
335+ }
336+
337+ @ Test
338+ public void mergeSpillsWithTransferToAndLZ4 () throws Exception {
339+ testMergingSpills (true , LZ4CompressionCodec .class .getName ());
340+ }
341+
342+ @ Test
343+ public void mergeSpillsWithFileStreamAndLZ4 () throws Exception {
344+ testMergingSpills (false , LZ4CompressionCodec .class .getName ());
345+ }
346+
347+ @ Test
348+ public void mergeSpillsWithTransferToAndSnappy () throws Exception {
349+ testMergingSpills (true , SnappyCompressionCodec .class .getName ());
350+ }
351+
352+ @ Test
353+ public void mergeSpillsWithFileStreamAndSnappy () throws Exception {
354+ testMergingSpills (false , SnappyCompressionCodec .class .getName ());
355+ }
356+
357+ @ Test
358+ public void mergeSpillsWithTransferToAndNoCompression () throws Exception {
359+ testMergingSpills (true , null );
282360 }
283361
284362 @ Test
285- public void mergeSpillsWithFileStream () throws Exception {
286- testMergingSpills (false );
363+ public void mergeSpillsWithFileStreamAndNoCompression () throws Exception {
364+ testMergingSpills (false , null );
287365 }
288366
289367 @ Test
0 commit comments