@@ -352,6 +352,63 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
352352 intercept[FetchFailedException ] { iterator.next() }
353353 }
354354
355+ test(" big corrupt blocks will not be retiried" ) {
356+ val corruptStream = mock(classOf [InputStream ])
357+ when(corruptStream.read(any(), any(), any())).thenThrow(new IOException (" corrupt" ))
358+ val corruptBuffer = mock(classOf [ManagedBuffer ])
359+ when(corruptBuffer.createInputStream()).thenReturn(corruptStream)
360+ doReturn(10000L ).when(corruptBuffer).size()
361+
362+ val blockManager = mock(classOf [BlockManager ])
363+ val localBmId = BlockManagerId (" test-client" , " test-client" , 1 )
364+ doReturn(localBmId).when(blockManager).blockManagerId
365+ doReturn(corruptBuffer).when(blockManager).getBlockData(ShuffleBlockId (0 , 0 , 0 ))
366+ val localBlockLengths = Seq [Tuple2 [BlockId , Long ]](
367+ ShuffleBlockId (0 , 0 , 0 ) -> corruptBuffer.size()
368+ )
369+
370+ val remoteBmId = BlockManagerId (" test-client-1" , " test-client-1" , 2 )
371+ val remoteBlockLengths = Seq [Tuple2 [BlockId , Long ]](
372+ ShuffleBlockId (0 , 1 , 0 ) -> corruptBuffer.size()
373+ )
374+
375+ val transfer = mock(classOf [BlockTransferService ])
376+ when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
377+ .thenAnswer(new Answer [Unit ] {
378+ override def answer (invocation : InvocationOnMock ): Unit = {
379+ val listener = invocation.getArguments()(4 ).asInstanceOf [BlockFetchingListener ]
380+ val blocks = invocation.getArguments()(3 ).asInstanceOf [Array [String ]]
381+ Future {
382+ blocks.foreach (listener.onBlockFetchSuccess(_, corruptBuffer))
383+ }
384+ }
385+ })
386+
387+ val blocksByAddress = Seq [(BlockManagerId , Seq [(BlockId , Long )])](
388+ (localBmId, localBlockLengths),
389+ (remoteBmId, remoteBlockLengths)
390+ )
391+
392+ val taskContext = TaskContext .empty()
393+ val iterator = new ShuffleBlockFetcherIterator (
394+ taskContext,
395+ transfer,
396+ blockManager,
397+ blocksByAddress,
398+ (_, in) => new LimitedInputStream (in, 10000 ),
399+ 2048 ,
400+ Int .MaxValue ,
401+ Int .MaxValue ,
402+ Int .MaxValue ,
403+ true )
404+ // Blocks should be returned without exceptions.
405+ val blockSet = collection.mutable.HashSet [BlockId ]()
406+ blockSet.add(iterator.next()._1)
407+ blockSet.add(iterator.next()._1)
408+ assert(blockSet == collection.immutable.HashSet (
409+ ShuffleBlockId (0 , 0 , 0 ), ShuffleBlockId (0 , 1 , 0 )))
410+ }
411+
355412 test(" retry corrupt blocks (disabled)" ) {
356413 val blockManager = mock(classOf [BlockManager ])
357414 val localBmId = BlockManagerId (" test-client" , " test-client" , 1 )
0 commit comments