diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/PrepareRequestReceivedCallBack.java b/common/network-common/src/main/java/org/apache/spark/network/client/PrepareRequestReceivedCallBack.java new file mode 100644 index 000000000000..379c17b9968e --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/client/PrepareRequestReceivedCallBack.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.client; + +public interface PrepareRequestReceivedCallBack { + void onSuccess(); + + void onFailure(Throwable e); +} \ No newline at end of file diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockPreparingListener.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockPreparingListener.java new file mode 100644 index 000000000000..67f643fd4a8a --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockPreparingListener.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +public interface BlockPreparingListener { + void onBlockPrepareSuccess(); + void onBlockPrepareFailure(Throwable exception); +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockToPrepareInfoSender.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockToPrepareInfoSender.java new file mode 100644 index 000000000000..0e7a7f78dca9 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockToPrepareInfoSender.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.lang.Override; +import java.lang.String; +import java.lang.Throwable; +import java.nio.ByteBuffer; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.apache.spark.network.client.PrepareRequestReceivedCallBack; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.shuffle.protocol.PrepareBlocks; + + +public class BlockToPrepareInfoSender { + private final Logger logger = LoggerFactory.getLogger(BlockToPrepareInfoSender.class); + + private final TransportClient client; + private final PrepareBlocks prepareMessage; + private final String[] blockIds; + private final String[] blocksToRelease; + private final BlockPreparingListener listener; + private final PrepareRequestReceivedCallBack requestReceivedCallBack; + + public BlockToPrepareInfoSender( + TransportClient client, + String appId, + String execId, + String[] blockIds, + String[] blocksToRelease, + BlockPreparingListener listener) { + this.client = client; + this.prepareMessage = new PrepareBlocks(appId, execId, blockIds, blocksToRelease); + this.blockIds = blockIds; + this.blocksToRelease = blocksToRelease; + this.listener = listener; + this.requestReceivedCallBack = new PrepareCallBack(); + } + + private class PrepareCallBack implements PrepareRequestReceivedCallBack { + @Override + public void onSuccess() { + listener.onBlockPrepareSuccess(); + } + + @Override + public void onFailure(Throwable e) { + listener.onBlockPrepareFailure(e); + } + } + + public void start() { + if (blockIds.length == 0) { +// throw new IllegalArgumentException("Zero-size blockIds array"); + logger.warn("Zero-size blockIds array"); + } + + client.sendRpc(prepareMessage.toByteBuffer(), new RpcResponseCallback() { + @Override + public void onSuccess(ByteBuffer response) { + logger.debug("Successfully send prepare block's info, ready for the next step"); + } + + @Override + public void onFailure(Throwable e) { + logger.error("Failed while send the prepare message"); + } + }); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 58ca87d9d3b1..94d26f18d517 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -151,4 +151,38 @@ public void registerWithShuffleServer( public void close() { clientFactory.close(); } + + @Override + public void prepareBlocks( + final String host, + final int port, + final String execId, + String[] prepareBlockIds, + final String[] releaseBlockIds, + BlockPreparingListener listener) { + logger.debug("Send prepare block info to {}:{} (executor id {})", host, port, execId); + + try { + RetryingBlockPreparer.PreparerStarter blockPrepareStarter = new RetryingBlockPreparer.PreparerStarter() { + @Override + public void createAndStart(String[] prepareBlockIds, String[] releaseBlocks, BlockPreparingListener listener) throws IOException { + TransportClient client = clientFactory.createClient(host, port); + new BlockToPrepareInfoSender(client, appId, execId, prepareBlockIds, + releaseBlockIds, listener).start(); + } + }; + + int maxRetries = conf.maxIORetries(); + if (maxRetries > 0) { + new RetryingBlockPreparer(conf, blockPrepareStarter, prepareBlockIds, + releaseBlockIds, listener).start(); + } else { + blockPrepareStarter.createAndStart(prepareBlockIds, releaseBlockIds, listener); + } + + } catch (Exception e) { + logger.error("Exception while sending the block list", e); + listener.onBlockPrepareFailure(e); + } + } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockPreparer.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockPreparer.java new file mode 100644 index 000000000000..10c8ad9d6335 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockPreparer.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import com.google.common.collect.Sets; +import com.google.common.util.concurrent.Uninterruptibles; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.TransportConf; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +public class RetryingBlockPreparer { + + public static interface PreparerStarter { + void createAndStart(String[] prepareBlockIds, String[] releaseBlocks, BlockPreparingListener listener) throws IOException; + } + + private static final ExecutorService executorService = Executors.newCachedThreadPool( + NettyUtils.createThreadFactory("Prepare Info Send Retry") + ); + + private final Logger logger = LoggerFactory.getLogger(RetryingBlockPreparer.class); + + private final PreparerStarter preparerStarter; + + private final BlockPreparingListener listener; + + private final int maxRetries; + + private final int retryWaitTime; + + private int retryCount = 0; + + private final LinkedHashSet outstandingBlockInfosForPrepare; + + private final LinkedHashSet outStandingBlockInfosForRelease; + + private RetryingBlockPreparerListener currentListener; + + public RetryingBlockPreparer( + TransportConf conf, + PreparerStarter prepareStarter, + String[] prepareBlockIds, + String[] releaseBlockIds, + BlockPreparingListener listener) { + this.preparerStarter = prepareStarter; + this.listener = listener; + this.maxRetries = conf.maxIORetries(); + this.retryWaitTime = conf.ioRetryWaitTimeMs(); + this.outstandingBlockInfosForPrepare = Sets.newLinkedHashSet(); + this.outStandingBlockInfosForRelease = Sets.newLinkedHashSet(); + Collections.addAll(outstandingBlockInfosForPrepare, prepareBlockIds); + Collections.addAll(outStandingBlockInfosForRelease, releaseBlockIds); + this.currentListener = new RetryingBlockPreparerListener(); + } + + public void start(){ + senAllOutStanding(); + } + + private void senAllOutStanding() { + String[] blockIdsToSendForPrepare; + String[] blockIdsToSendForRelease; + int numRetries; + RetryingBlockPreparerListener myListener; + synchronized (this) { + blockIdsToSendForPrepare = outstandingBlockInfosForPrepare.toArray(new String[outstandingBlockInfosForPrepare.size()]); + blockIdsToSendForRelease = outStandingBlockInfosForRelease.toArray(new String[outStandingBlockInfosForRelease.size()]); + numRetries = retryCount; + myListener = currentListener; + } + + try { + preparerStarter.createAndStart(blockIdsToSendForPrepare, blockIdsToSendForRelease ,myListener); + listener.onBlockPrepareSuccess(); + } catch (Exception e) { + logger.error(String.format("Exception while begin send %s outstanding block info %s", + blockIdsToSendForPrepare.length, numRetries > 0 ? "(after )" + numRetries + "retries)" : ""), e); + if (shouldRetry(e)) { + initiateRetry(); + } else { + for (String bid: blockIdsToSendForPrepare) { + listener.onBlockPrepareFailure(e); + } + } + } + } + + private synchronized void initiateRetry(){ + retryCount += 1; + currentListener = new RetryingBlockPreparerListener(); + logger.info("Retrying send ({}/{}) for {} outstading_prepare and release blocks after {} ms", + retryCount, maxRetries, outstandingBlockInfosForPrepare.size()+outStandingBlockInfosForRelease.size(), retryWaitTime); + + executorService.submit(new Runnable() { + @Override + public void run() { + Uninterruptibles.sleepUninterruptibly(retryWaitTime, TimeUnit.MILLISECONDS); + senAllOutStanding(); + } + }); + } + + private synchronized boolean shouldRetry(Throwable e) { + boolean isIOException = e instanceof IOException + || (e.getCause() != null + && e.getCause() instanceof IOException); + boolean hasRemainRetries = retryCount < maxRetries; + return isIOException && hasRemainRetries; + } + + private class RetryingBlockPreparerListener implements BlockPreparingListener { + @Override + public void onBlockPrepareSuccess() { + boolean shouldForwardSuccess = false; + synchronized (RetryingBlockPreparer.this) { + if (this == currentListener) { + shouldForwardSuccess = true; + } + } + + if (shouldForwardSuccess) { + listener.onBlockPrepareSuccess(); + } + } + + @Override + public void onBlockPrepareFailure(Throwable exception) { + boolean shouldForwardFailure = false; + synchronized (RetryingBlockPreparer.this) { + if (this == currentListener) { + initiateRetry(); + } else { + logger.error(String.format("PrepareBlock failed to send blocks' info, " + + "and will not retry (%s retries)", retryCount), exception); + shouldForwardFailure = true; + } + } + + if (shouldForwardFailure) { + listener.onBlockPrepareFailure(exception); + } + } + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index f72ab40690d0..e068a9ee5db6 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -41,4 +41,15 @@ public abstract void fetchBlocks( String execId, String[] blockIds, BlockFetchingListener listener); + + /** + * Prepare a sequence of blocks from remote node asynchronously + */ + public abstract void prepareBlocks( + String host, + int port, + String execId, + String[] prepareBlockIds, + String[] releaseBlocks, + BlockPreparingListener listener); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index 9af6759f5d5f..7532f678a97d 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -42,7 +42,7 @@ public abstract class BlockTransferMessage implements Encodable { /** Preceding every serialized message is its type, which allows us to deserialize it. */ public enum Type { OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4), - HEARTBEAT(5); + HEARTBEAT(5), PREPARE_BLOCKS(6); private final byte id; @@ -67,6 +67,7 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { case 3: return StreamHandle.decode(buf); case 4: return RegisterDriver.decode(buf); case 5: return ShuffleServiceHeartbeat.decode(buf); + case 6: return PrepareBlocks.decode(buf); default: throw new IllegalArgumentException("Unknown message type: " + type); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/PrepareBlocks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/PrepareBlocks.java new file mode 100644 index 000000000000..6136e944dbde --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/PrepareBlocks.java @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import org.apache.spark.network.protocol.Encoders; +import java.util.Arrays; + +public class PrepareBlocks extends BlockTransferMessage{ + public final String appId; + public final String execId; + public final String[] blockIds; + public final String[] blockIdsToRelease; + + public PrepareBlocks (String appId, String execId, String[] blockIdsToPrepare, String[] blockIdsToRelease) { + this.appId = appId; + this.execId = execId; + this.blockIds = blockIdsToPrepare; + this.blockIdsToRelease = blockIdsToRelease; + } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId) * 41 + Arrays.hashCode(blockIds); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("blockIds", Arrays.toString(blockIds)) + .add("blockIdsToRelease", Arrays.toString(blockIdsToRelease)) + .toString(); + } + + @Override + public boolean equals(Object obj) { + if (obj != null && obj instanceof PrepareBlocks) { + PrepareBlocks o = (PrepareBlocks) obj; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Arrays.equals(blockIds, o.blockIds) + && Arrays.equals(blockIdsToRelease, o.blockIdsToRelease); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + Encoders.StringArrays.encodedLength(blockIds) + + Encoders.StringArrays.encodedLength(blockIdsToRelease); + } + + @Override + protected Type type() { + return Type.PREPARE_BLOCKS; + } + + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + Encoders.StringArrays.encode(buf, blockIds); + Encoders.StringArrays.encode(buf, blockIdsToRelease); + } + + public static PrepareBlocks decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + String[] blockIds = Encoders.StringArrays.decode(buf); + String[] releaseBlocks = Encoders.StringArrays.decode(buf); + return new PrepareBlocks(appId, execId, blockIds, releaseBlocks); + } +} + diff --git a/core/src/main/scala/org/apache/spark/network/netty/BlockCache.scala b/core/src/main/scala/org/apache/spark/network/netty/BlockCache.scala new file mode 100644 index 000000000000..39e8e61d5a89 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/network/netty/BlockCache.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.netty + +import java.util.concurrent._ + +import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.storage.BlockId +import org.apache.spark.SparkEnv + +object BlockCache extends Logging{ + + val reqBuffer = new ConcurrentHashMap[Seq[BlockId], FutureCacheForBLocks]() + + def releaseAll(blockIds: Array[BlockId]): Unit = { + reqBuffer.remove(blockIds) + } + + def addAll(blockIds: Seq[BlockId]): Unit = { + val data = new FutureCacheForBLocks(blockIds) + reqBuffer.put(blockIds, data) + } + + def getAll(blockIds: Seq[BlockId]): LinkedBlockingQueue[ManagedBuffer] = { + val buffers = reqBuffer.get(blockIds) + buffers.get() + } +} + +class FutureCacheForBLocks { + var blockIds: Seq[BlockId] = _ + var future: FutureTask[LinkedBlockingQueue[ManagedBuffer]] = _ + + def this (blockIds: Seq[BlockId]) { + this() + this.blockIds = blockIds + future = new FutureTask[LinkedBlockingQueue[ManagedBuffer]](new RealCacheForBlocks(blockIds)) + + val executor = Executors.newFixedThreadPool(1) + + executor.submit(future) + } + + def get(): LinkedBlockingQueue[ManagedBuffer] = { + future.get() + } +} + +class RealCacheForBlocks extends Callable[LinkedBlockingQueue[ManagedBuffer]] { + val blockManager = SparkEnv.get.blockManager + var blockIds: Seq[BlockId] = _ + + def this(blockIds: Seq[BlockId]) { + this() + this.blockIds = blockIds + } + + override def call(): LinkedBlockingQueue[ManagedBuffer] = { + val resQueue = new LinkedBlockingQueue[ManagedBuffer]() + val iterator = blockIds.iterator + while (iterator.hasNext) { + val blockId = iterator.next() + if (blockId != null) { + val data = blockManager.getBlockData(blockId) + resQueue.add(data) + } + } + resQueue + } +} diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 2ed8a00df702..aa0e5d49cbf8 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -28,7 +28,7 @@ import org.apache.spark.network.BlockDataManager import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager} -import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock} +import org.apache.spark.network.shuffle.protocol._ import org.apache.spark.serializer.Serializer import org.apache.spark.storage.{BlockId, StorageLevel} @@ -74,6 +74,22 @@ class NettyBlockRpcServer( val blockId = BlockId(uploadBlock.blockId) blockManager.putBlockData(blockId, data, level, classTag) responseContext.onSuccess(ByteBuffer.allocate(0)) + + case prepareBlocks: PrepareBlocks => + + if (prepareBlocks.blockIdsToRelease.size > 0) { + val blocksToRelease: Seq[BlockId] = + prepareBlocks.blockIdsToRelease.map(BlockId.apply) + BlockCache.releaseAll(blocksToRelease.toArray) + } + + if (prepareBlocks.blockIds.size > 0) { + val blockIds: Seq[BlockId] = + prepareBlocks.blockIds.map(BlockId.apply) + BlockCache.addAll(blockIds) + } + + responseContext.onSuccess(ByteBuffer.allocate(0)) } } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 33a321960774..fcc9212941a3 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -29,7 +29,8 @@ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory} import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} import org.apache.spark.network.server._ -import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher} +import org.apache.spark.network.shuffle._ +import org.apache.spark.network.shuffle.RetryingBlockPreparer.PreparerStarter import org.apache.spark.network.shuffle.protocol.UploadBlock import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer @@ -113,6 +114,40 @@ private[spark] class NettyBlockTransferService( } } + override def prepareBlocks( + host: String, + port: Int, + execId: String, + prepareBlockIds: Array[String], + releaseBlockIds: Array[String], + listener: BlockPreparingListener): Unit = { + + try { + val blockPrepareStarter = new PreparerStarter { + override def createAndStart( + prepareBlockIds: Array[String], + releaseBlockIds: Array[String], + listener: BlockPreparingListener): Unit = { + val client = clientFactory.createClient(host, port) + new BlockToPrepareInfoSender(client, appId, execId, prepareBlockIds.toArray, + releaseBlockIds, listener).start() + } + } + + val maxRetries = transportConf.maxIORetries() + if (maxRetries > 0) { + new RetryingBlockPreparer(transportConf, blockPrepareStarter, prepareBlockIds, + releaseBlockIds, listener).start() + } else { + blockPrepareStarter.createAndStart(prepareBlockIds, releaseBlockIds, listener) + } + } catch { + case e : Exception => + logError("Exception while sending the block list", e) + listener.onBlockPrepareFailure(e) + } + } + override def port: Int = server.getPort override def uploadBlock( diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 87c8628ce97e..dfe8eb8c105a 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -39,7 +39,7 @@ import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.{BlockDataManager, BlockTransferService} import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.netty.NettyBlockTransferService -import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockPreparingListener} import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager} @@ -1241,6 +1241,16 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE Future {} } + override def prepareBlocks( + host: String, + port: Int, + execId: String, + prepareBlockIds: Array[String], + releaseBlocks: Array[String], + listener: BlockPreparingListener): Unit = { + listener.onBlockPrepareSuccess() + } + override def fetchBlockSync( host: String, port: Int,