diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index a5064cc25113..002bf65ba593 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -328,15 +328,15 @@ object SparkEnv extends Logging {
       conf.get(BLOCK_MANAGER_PORT)
     }

-    val blockTransferService =
-      new NettyBlockTransferService(conf, securityManager, bindAddress, advertiseAddress,
-        blockManagerPort, numUsableCores)
-
     val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint(
       BlockManagerMaster.DRIVER_ENDPOINT_NAME,
       new BlockManagerMasterEndpoint(rpcEnv, isLocal, conf, listenerBus)),
       conf, isDriver)

+    val blockTransferService =
+      new NettyBlockTransferService(conf, securityManager, bindAddress, advertiseAddress,
+        blockManagerPort, numUsableCores, blockManagerMaster.driverEndpoint)
+
     // NB: blockManager is not valid until initialize() is called later.
     val blockManager = new BlockManager(executorId, rpcEnv, blockManagerMaster,
       serializerManager, conf, memoryManager, mapOutputTracker, shuffleManager,
diff --git a/core/src/main/scala/org/apache/spark/SparkException.scala b/core/src/main/scala/org/apache/spark/SparkException.scala
index 977a27bdfe1b..4ad9a0cc4b10 100644
--- a/core/src/main/scala/org/apache/spark/SparkException.scala
+++ b/core/src/main/scala/org/apache/spark/SparkException.scala
@@ -37,3 +37,9 @@ private[spark] class SparkDriverExecutionException(cause: Throwable)
  */
 private[spark] case class SparkUserAppException(exitCode: Int)
   extends SparkException(s"User application exited with $exitCode")
+
+/**
+ * Exception thrown when the relative executor to access is dead.
+ */
+private[spark] case class ExecutorDeadException(message: String)
+  extends SparkException(message)
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index 864e8ad1a6f9..b12cd4254f19 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -17,16 +17,19 @@

 package org.apache.spark.network.netty

+import java.io.IOException
 import java.nio.ByteBuffer
 import java.util.{HashMap => JHashMap, Map => JMap}

 import scala.collection.JavaConverters._
 import scala.concurrent.{Future, Promise}
 import scala.reflect.ClassTag
+import scala.util.{Success, Try}

 import com.codahale.metrics.{Metric, MetricSet}

 import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.ExecutorDeadException
 import org.apache.spark.internal.config
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
@@ -36,8 +39,10 @@
 import org.apache.spark.network.server._
 import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager,
   OneForOneBlockFetcher, RetryingBlockFetcher}
 import org.apache.spark.network.shuffle.protocol.{UploadBlock, UploadBlockStream}
 import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.rpc.RpcEndpointRef
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.storage.{BlockId, StorageLevel}
+import org.apache.spark.storage.BlockManagerMessages.IsExecutorAlive
 import org.apache.spark.util.Utils
@@ -49,7 +54,8 @@ private[spark] class NettyBlockTransferService(
     bindAddress: String,
     override val hostName: String,
     _port: Int,
-    numCores: Int)
+    numCores: Int,
+    driverEndPointRef: RpcEndpointRef = null)
   extends BlockTransferService {

   // TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
@@ -112,8 +118,20 @@ private[spark] class NettyBlockTransferService(
       val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
         override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
           val client = clientFactory.createClient(host, port)
-          new OneForOneBlockFetcher(client, appId, execId, blockIds, listener,
-            transportConf, tempFileManager).start()
+          try {
+            new OneForOneBlockFetcher(client, appId, execId, blockIds, listener,
+              transportConf, tempFileManager).start()
+          } catch {
+            case e: IOException =>
+              Try {
+                driverEndPointRef.askSync[Boolean](IsExecutorAlive(execId))
+              } match {
+                case Success(v) if v == false =>
+                  throw new ExecutorDeadException(s"The relative remote executor(Id: $execId)," +
+                    " which maintains the block data to fetch is dead.")
+                case _ => throw e
+              }
+          }
         }
       }

diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index f5d6029e445c..f388d59e78ba 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -105,6 +105,9 @@ class BlockManagerMasterEndpoint(
     case GetBlockStatus(blockId, askSlaves) =>
       context.reply(blockStatus(blockId, askSlaves))

+    case IsExecutorAlive(executorId) =>
+      context.reply(blockManagerIdByExecutor.contains(executorId))
+
     case GetMatchingBlockIds(filter, askSlaves) =>
       context.reply(getMatchingBlockIds(filter, askSlaves))

diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 1bbe7a5b3950..2be28420b495 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -123,4 +123,6 @@ private[spark] object BlockManagerMessages {
   case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster

   case class HasCachedBlocks(executorId: String) extends ToBlockManagerMaster
+
+  case class IsExecutorAlive(executorId: String) extends ToBlockManagerMaster
 }
diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala
index 78423ee68a0e..5d67d3358a9c 100644
--- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala
@@ -17,13 +17,21 @@

 package org.apache.spark.network.netty

+import java.io.IOException
+
+import scala.concurrent.{ExecutionContext, Future}
+import scala.reflect.ClassTag
 import scala.util.Random

-import org.mockito.Mockito.mock
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito.{mock, times, verify, when}
 import org.scalatest._

-import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
+import org.apache.spark.{ExecutorDeadException, SecurityManager, SparkConf, SparkFunSuite}
 import org.apache.spark.network.BlockDataManager
+import org.apache.spark.network.client.{TransportClient, TransportClientFactory}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager}
+import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcTimeout}

 class NettyBlockTransferServiceSuite
   extends SparkFunSuite
@@ -77,6 +85,48 @@ class NettyBlockTransferServiceSuite
     verifyServicePort(expectedPort = service0.port + 1, actualPort = service1.port)
   }

+  test("SPARK-27637: test fetch block with executor dead") {
+    implicit val exectionContext = ExecutionContext.global
+    val port = 17634 + Random.nextInt(10000)
+    logInfo("random port for test: " + port)
+
+    val driverEndpointRef = new RpcEndpointRef(new SparkConf()) {
+      override def address: RpcAddress = null
+      override def name: String = "test"
+      override def send(message: Any): Unit = {}
+      // This rpcEndPointRef always return false for unit test to touch ExecutorDeadException.
+      override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = {
+        Future{false.asInstanceOf[T]}
+      }
+    }
+
+    val clientFactory = mock(classOf[TransportClientFactory])
+    val client = mock(classOf[TransportClient])
+    // This is used to touch an IOException during fetching block.
+    when(client.sendRpc(any(), any())).thenAnswer(_ => {throw new IOException()})
+    var createClientCount = 0
+    when(clientFactory.createClient(any(), any())).thenAnswer(_ => {
+      createClientCount += 1
+      client
+    })
+
+    val listener = mock(classOf[BlockFetchingListener])
+    var hitExecutorDeadException = false
+    when(listener.onBlockFetchFailure(any(), any(classOf[ExecutorDeadException])))
+      .thenAnswer(_ => {hitExecutorDeadException = true})
+
+    service0 = createService(port, driverEndpointRef)
+    val clientFactoryField = service0.getClass.getField(
+      "org$apache$spark$network$netty$NettyBlockTransferService$$clientFactory")
+    clientFactoryField.setAccessible(true)
+    clientFactoryField.set(service0, clientFactory)
+
+    service0.fetchBlocks("localhost", port, "exec1",
+      Array("block1"), listener, mock(classOf[DownloadFileManager]))
+    assert(createClientCount === 1)
+    assert(hitExecutorDeadException)
+  }
+
   private def verifyServicePort(expectedPort: Int, actualPort: Int): Unit = {
     actualPort should be >= expectedPort
     // avoid testing equality in case of simultaneous tests
@@ -85,13 +135,15 @@ class NettyBlockTransferServiceSuite
     actualPort should be <= (expectedPort + 100)
   }

-  private def createService(port: Int): NettyBlockTransferService = {
+  private def createService(
+      port: Int,
+      rpcEndpointRef: RpcEndpointRef = null): NettyBlockTransferService = {
     val conf = new SparkConf()
       .set("spark.app.id", s"test-${getClass.getName}")
     val securityManager = new SecurityManager(conf)
     val blockDataManager = mock(classOf[BlockDataManager])
     val service = new NettyBlockTransferService(conf, securityManager, "localhost", "localhost",
-      port, 1)
+      port, 1, rpcEndpointRef)
     service.init(blockDataManager)
     service
   }
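
For readers following the patch, here is a minimal, self-contained sketch of the control flow the NettyBlockTransferService change adds to createAndStart: when a block fetch fails with an IOException, the driver is asked whether the remote executor is still alive, and only a confirmed-dead executor turns the failure into ExecutorDeadException; in every other case the original error is rethrown. The names ExecutorDeadException, fetchWithLivenessCheck, and isExecutorAlive below are illustrative stand-ins, not Spark APIs; in the real code the liveness check is driverEndPointRef.askSync[Boolean](IsExecutorAlive(execId)).

```scala
import java.io.IOException

import scala.util.{Success, Try}

// Hypothetical stand-in for the exception added in SparkException.scala
// (the real one extends SparkException and is private[spark]).
case class ExecutorDeadException(message: String) extends Exception(message)

object FetchFlowSketch {

  // Mirrors the try/catch added to createAndStart: `fetch` stands in for
  // OneForOneBlockFetcher.start(), and `isExecutorAlive` stands in for the
  // driverEndPointRef.askSync[Boolean](IsExecutorAlive(execId)) round trip.
  def fetchWithLivenessCheck(
      execId: String,
      isExecutorAlive: String => Boolean)(fetch: => Unit): Unit = {
    try {
      fetch
    } catch {
      case e: IOException =>
        // Only translate the failure when the driver positively reports the
        // executor as dead; if the executor is alive, or the liveness check
        // itself fails, rethrow the original IOException.
        Try(isExecutorAlive(execId)) match {
          case Success(false) =>
            throw ExecutorDeadException(
              s"Remote executor (Id: $execId) holding the requested blocks is dead.")
          case _ => throw e
        }
    }
  }

  def main(args: Array[String]): Unit = {
    // The fetch fails with an IOException and the stubbed driver reports the
    // executor as dead, so the caller sees ExecutorDeadException instead.
    try {
      fetchWithLivenessCheck("exec-1", _ => false) {
        throw new IOException("connection reset by peer")
      }
    } catch {
      case e: ExecutorDeadException => println(s"caught: ${e.getMessage}")
    }
  }
}
```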
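Why translating the error matters: RetryingBlockFetcher only re-attempts failures it considers IO-related, so a plain IOException from an executor that is already gone would be retried until the retry budget is exhausted, whereas ExecutorDeadException fails fast. The predicate below is a simplified stand-in for that rule; the exact logic lives in the network-shuffle module and may differ in detail.

```scala
import java.io.IOException

object RetryRuleSketch {
  // Simplified sketch of the retry decision: re-attempt only when the failure
  // (or its cause) is an IOException and retries remain. ExecutorDeadException
  // is not an IOException, so a fetch from a confirmed-dead executor is not
  // retried under this rule.
  def shouldRetry(e: Throwable, retriesLeft: Int): Boolean = {
    val ioRelated = e.isInstanceOf[IOException] ||
      (e.getCause != null && e.getCause.isInstanceOf[IOException])
    ioRelated && retriesLeft > 0
  }
}
```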