diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java index b886fce9be21..8c05288fb411 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java @@ -150,6 +150,12 @@ protected void handleMessage( int numRemovedBlocks = blockManager.removeBlocks(msg.appId, msg.execId, msg.blockIds); callback.onSuccess(new BlocksRemoved(numRemovedBlocks).toByteBuffer()); + } else if (msgObj instanceof GetLocalDirsForExecutors) { + GetLocalDirsForExecutors msg = (GetLocalDirsForExecutors) msgObj; + checkAuth(client, msg.appId); + Map localDirs = blockManager.getLocalDirs(msg.appId, msg.execIds); + callback.onSuccess(new LocalDirsForExecutors(localDirs).toByteBuffer()); + } else { throw new UnsupportedOperationException("Unexpected message: " + msgObj); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java index 85d278138c2b..d6185f089d3c 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java @@ -21,20 +21,21 @@ import java.nio.ByteBuffer; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Future; import com.codahale.metrics.MetricSet; import com.google.common.collect.Lists; import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.shuffle.protocol.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.TransportContext; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientBootstrap; -import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.crypto.AuthClientBootstrap; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.NoOpRpcHandler; @@ -182,7 +183,7 @@ public void onSuccess(ByteBuffer response) { @Override public void onFailure(Throwable e) { logger.warn("Error trying to remove RDD blocks " + Arrays.toString(blockIds) + - " via external shuffle service from executor: " + execId, e); + " via external shuffle service from executor: " + execId, e); numRemovedBlocksFuture.complete(0); client.close(); } @@ -190,6 +191,46 @@ public void onFailure(Throwable e) { return numRemovedBlocksFuture; } + public void getHostLocalDirs( + String host, + int port, + String[] execIds, + CompletableFuture> hostLocalDirsCompletable) { + checkInit(); + GetLocalDirsForExecutors getLocalDirsMessage = new GetLocalDirsForExecutors(appId, execIds); + try { + TransportClient client = clientFactory.createClient(host, port); + client.sendRpc(getLocalDirsMessage.toByteBuffer(), new RpcResponseCallback() { + @Override + public void onSuccess(ByteBuffer response) { + try { + BlockTransferMessage msgObj = 
BlockTransferMessage.Decoder.fromByteBuffer(response); + hostLocalDirsCompletable.complete( + ((LocalDirsForExecutors) msgObj).getLocalDirsByExec()); + } catch (Throwable t) { + logger.warn("Error trying to get the host local dirs for " + + Arrays.toString(getLocalDirsMessage.execIds) + " via external shuffle service", + t.getCause()); + hostLocalDirsCompletable.completeExceptionally(t); + } finally { + client.close(); + } + } + + @Override + public void onFailure(Throwable t) { + logger.warn("Error trying to get the host local dirs for " + + Arrays.toString(getLocalDirsMessage.execIds) + " via external shuffle service", + t.getCause()); + hostLocalDirsCompletable.completeExceptionally(t); + client.close(); + } + }); + } catch (IOException | InterruptedException e) { + hostLocalDirsCompletable.completeExceptionally(e); + } + } + @Override public void close() { checkInit(); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index beca5d6e5a78..657774c1b468 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -25,7 +25,9 @@ import java.util.concurrent.Executor; import java.util.concurrent.Executors; import java.util.regex.Pattern; +import java.util.stream.Collectors; +import org.apache.commons.lang3.tuple.Pair; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.ObjectMapper; @@ -369,6 +371,19 @@ public int removeBlocks(String appId, String execId, String[] blockIds) { return numRemovedBlocks; } + public Map getLocalDirs(String appId, String[] execIds) { + return Arrays.stream(execIds) + .map(exec -> { + ExecutorShuffleInfo info = executors.get(new AppExecId(appId, exec)); + if (info == null) { + throw new RuntimeException( + String.format("Executor is not registered (appId=%s, execId=%s)", appId, exec)); + } + return Pair.of(exec, info.localDirs); + }) + .collect(Collectors.toMap(Pair::getKey, Pair::getValue)); + } + /** Simply encodes an executor's full ID, which is appId + execId. 
*/ public static class AppExecId { public final String appId; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index 41dd55847ebd..89d8dfe8716b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -47,7 +47,7 @@ public abstract class BlockTransferMessage implements Encodable { public enum Type { OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4), HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), REMOVE_BLOCKS(7), BLOCKS_REMOVED(8), - FETCH_SHUFFLE_BLOCKS(9); + FETCH_SHUFFLE_BLOCKS(9), GET_LOCAL_DIRS_FOR_EXECUTORS(10), LOCAL_DIRS_FOR_EXECUTORS(11); private final byte id; @@ -76,6 +76,8 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { case 7: return RemoveBlocks.decode(buf); case 8: return BlocksRemoved.decode(buf); case 9: return FetchShuffleBlocks.decode(buf); + case 10: return GetLocalDirsForExecutors.decode(buf); + case 11: return LocalDirsForExecutors.decode(buf); default: throw new IllegalArgumentException("Unknown message type: " + type); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlocksRemoved.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlocksRemoved.java index 3f04443871b6..723b2f75c6fc 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlocksRemoved.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlocksRemoved.java @@ -50,7 +50,7 @@ public String toString() { public boolean equals(Object other) { if (other != null && other instanceof BlocksRemoved) { BlocksRemoved o = (BlocksRemoved) other; - return Objects.equal(numRemovedBlocks, o.numRemovedBlocks); + return numRemovedBlocks == o.numRemovedBlocks; } return false; } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java index 93758bdc58fb..540ecd09a7e3 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java @@ -65,7 +65,7 @@ public boolean equals(Object other) { if (other != null && other instanceof ExecutorShuffleInfo) { ExecutorShuffleInfo o = (ExecutorShuffleInfo) other; return Arrays.equals(localDirs, o.localDirs) - && Objects.equal(subDirsPerLocalDir, o.subDirsPerLocalDir) + && subDirsPerLocalDir == o.subDirsPerLocalDir && Objects.equal(shuffleManager, o.shuffleManager); } return false; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/GetLocalDirsForExecutors.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/GetLocalDirsForExecutors.java new file mode 100644 index 000000000000..90c416acc69a --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/GetLocalDirsForExecutors.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + +/** Request to get the local dirs for the given executors. */ +public class GetLocalDirsForExecutors extends BlockTransferMessage { + public final String appId; + public final String[] execIds; + + public GetLocalDirsForExecutors(String appId, String[] execIds) { + this.appId = appId; + this.execIds = execIds; + } + + @Override + protected Type type() { return Type.GET_LOCAL_DIRS_FOR_EXECUTORS; } + + @Override + public int hashCode() { + return Objects.hashCode(appId) * 41 + Arrays.hashCode(execIds); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execIds", Arrays.toString(execIds)) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other instanceof GetLocalDirsForExecutors) { + GetLocalDirsForExecutors o = (GetLocalDirsForExecutors) other; + return appId.equals(o.appId) && Arrays.equals(execIds, o.execIds); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + Encoders.StringArrays.encodedLength(execIds); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.StringArrays.encode(buf, execIds); + } + + public static GetLocalDirsForExecutors decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String[] execIds = Encoders.StringArrays.decode(buf); + return new GetLocalDirsForExecutors(appId, execIds); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/LocalDirsForExecutors.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/LocalDirsForExecutors.java new file mode 100644 index 000000000000..0c3aa6a46114 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/LocalDirsForExecutors.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import java.util.*; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + +/** The reply to get local dirs giving back the dirs for each of the requested executors. */ +public class LocalDirsForExecutors extends BlockTransferMessage { + private final String[] execIds; + private final int[] numLocalDirsByExec; + private final String[] allLocalDirs; + + public LocalDirsForExecutors(Map localDirsByExec) { + this.execIds = new String[localDirsByExec.size()]; + this.numLocalDirsByExec = new int[localDirsByExec.size()]; + ArrayList localDirs = new ArrayList<>(); + int index = 0; + for (Map.Entry e: localDirsByExec.entrySet()) { + execIds[index] = e.getKey(); + numLocalDirsByExec[index] = e.getValue().length; + Collections.addAll(localDirs, e.getValue()); + index++; + } + this.allLocalDirs = localDirs.toArray(new String[0]); + } + + private LocalDirsForExecutors(String[] execIds, int[] numLocalDirsByExec, String[] allLocalDirs) { + this.execIds = execIds; + this.numLocalDirsByExec = numLocalDirsByExec; + this.allLocalDirs = allLocalDirs; + } + + @Override + protected Type type() { return Type.LOCAL_DIRS_FOR_EXECUTORS; } + + @Override + public int hashCode() { + return Arrays.hashCode(execIds); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("execIds", Arrays.toString(execIds)) + .add("numLocalDirsByExec", Arrays.toString(numLocalDirsByExec)) + .add("allLocalDirs", Arrays.toString(allLocalDirs)) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other instanceof LocalDirsForExecutors) { + LocalDirsForExecutors o = (LocalDirsForExecutors) other; + return Arrays.equals(execIds, o.execIds) + && Arrays.equals(numLocalDirsByExec, o.numLocalDirsByExec) + && Arrays.equals(allLocalDirs, o.allLocalDirs); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.StringArrays.encodedLength(execIds) + + Encoders.IntArrays.encodedLength(numLocalDirsByExec) + + Encoders.StringArrays.encodedLength(allLocalDirs); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.StringArrays.encode(buf, execIds); + Encoders.IntArrays.encode(buf, numLocalDirsByExec); + Encoders.StringArrays.encode(buf, allLocalDirs); + } + + public static LocalDirsForExecutors decode(ByteBuf buf) { + String[] execIds = Encoders.StringArrays.decode(buf); + int[] numLocalDirsByExec = Encoders.IntArrays.decode(buf); + String[] allLocalDirs = Encoders.StringArrays.decode(buf); + return new LocalDirsForExecutors(execIds, numLocalDirsByExec, allLocalDirs); + } + + public Map getLocalDirsByExec() { + Map localDirsByExec = new HashMap<>(); + int index = 0; + int localDirsIndex = 0; + for (int length: numLocalDirsByExec) { + localDirsByExec.put(execIds[index], + Arrays.copyOfRange(allLocalDirs, localDirsIndex, localDirsIndex + length)); + localDirsIndex += length; + 
index++; + } + return localDirsByExec; + } +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java index fd2c67a3a270..67229371c3a4 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java @@ -21,6 +21,9 @@ import static org.junit.Assert.*; +import java.util.HashMap; +import java.util.Map; + import org.apache.spark.network.shuffle.protocol.*; /** Verifies that all BlockTransferMessages can be serialized correctly. */ @@ -41,10 +44,29 @@ public void serializeOpenShuffleBlocks() { checkSerializeDeserialize(new StreamHandle(12345, 16)); } - private void checkSerializeDeserialize(BlockTransferMessage msg) { + @Test + public void testLocalDirsMessages() { + checkSerializeDeserialize( + new GetLocalDirsForExecutors("app-1", new String[]{"exec-1", "exec-2"})); + + Map map = new HashMap<>(); + map.put("exec-1", new String[]{"loc1.1"}); + map.put("exec-22", new String[]{"loc2.1", "loc2.2"}); + LocalDirsForExecutors localDirsForExecs = new LocalDirsForExecutors(map); + Map resultMap = + ((LocalDirsForExecutors)checkSerializeDeserialize(localDirsForExecs)).getLocalDirsByExec(); + assertEquals(resultMap.size(), map.keySet().size()); + for (Map.Entry e: map.entrySet()) { + assertTrue(resultMap.containsKey(e.getKey())); + assertArrayEquals(e.getValue(), resultMap.get(e.getKey())); + } + } + + private BlockTransferMessage checkSerializeDeserialize(BlockTransferMessage msg) { BlockTransferMessage msg2 = BlockTransferMessage.Decoder.fromByteBuffer(msg.toByteBuffer()); assertEquals(msg, msg2); assertEquals(msg.hashCode(), msg2.hashCode()); assertEquals(msg.toString(), msg2.toString()); + return msg2; } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index db3f2266cf33..c6521eacb0bb 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2884,6 +2884,14 @@ object SparkContext extends Logging { memoryPerSlaveInt, sc.executorMemory)) } + // For host local mode setting the default of SHUFFLE_HOST_LOCAL_DISK_READING_ENABLED + // to false because this mode is intended to be used for testing and in this case all the + // executors are running on the same host. So if host local reading was enabled here then + // testing of the remote fetching would be secondary as setting this config explicitly to + // false would be required in most of the unit test (despite the fact that remote fetching + // is much more frequent in production). 
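On a real cluster, by contrast, an application opts in explicitly. A minimal sketch of the relevant settings (the two new keys come from this patch, spark.shuffle.service.enabled is the pre-existing prerequisite for the external shuffle service, and the app name and cache-size value are purely illustrative):

    import org.apache.spark.SparkConf

    val sparkConf = new SparkConf()
      .setAppName("host-local-shuffle-read-demo")
      // Prerequisite: the external shuffle service must be running on every host.
      .set("spark.shuffle.service.enabled", "true")
      // New in this patch: read same-host shuffle blocks straight from disk.
      .set("spark.shuffle.readHostLocalDisk.enabled", "true")
      // New in this patch: bound the executorId -> local dirs cache on driver and executors.
      .set("spark.storage.localDiskByExecutors.cacheSize", "500")
    // Pass sparkConf to SparkContext / spark-submit as usual.
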
+ sc.conf.setIfMissing(SHUFFLE_HOST_LOCAL_DISK_READING_ENABLED, false) + val scheduler = new TaskSchedulerImpl(sc) val localCluster = new LocalSparkCluster( numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt, sc.conf) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 8e8e36dbda94..e1eda91bf132 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1075,6 +1075,24 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val SHUFFLE_HOST_LOCAL_DISK_READING_ENABLED = + ConfigBuilder("spark.shuffle.readHostLocalDisk.enabled") + .doc("If enabled, shuffle blocks requested from those block managers which are running on " + + "the same host are read from the disk directly instead of being fetched as remote blocks " + + "over the network.") + .booleanConf + .createWithDefault(true) + + private[spark] val STORAGE_LOCAL_DISK_BY_EXECUTORS_CACHE_SIZE = + ConfigBuilder("spark.storage.localDiskByExecutors.cacheSize") + .doc("The max number of executors for which the local dirs are stored. This size is " + + "both applied for the driver and both for the executors side to avoid having an " + + "unbounded store. This cache will be used to avoid the network in case of fetching disk " + + "persisted RDD blocks or shuffle blocks (when `spark.shuffle.readHostLocalDisk.enabled` " + + "is set) from the same host.") + .intConf + .createWithDefault(1000) + private[spark] val SHUFFLE_SYNC = ConfigBuilder("spark.shuffle.sync") .doc("Whether to force outstanding writes to disk.") diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 4993519aa384..0bd5774b632b 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -22,16 +22,22 @@ import scala.reflect.ClassTag import org.apache.spark.TaskContext import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.StreamCallbackWithID -import org.apache.spark.storage.{BlockId, StorageLevel} +import org.apache.spark.storage.{BlockId, ShuffleBlockId, StorageLevel} private[spark] trait BlockDataManager { + /** + * Interface to get host-local shuffle block data. Throws an exception if the block cannot be + * found or cannot be read successfully. + */ + def getHostLocalShuffleData(blockId: BlockId, dirs: Array[String]): ManagedBuffer + /** * Interface to get local block data. Throws an exception if the block cannot be found or * cannot be read successfully. */ - def getBlockData(blockId: BlockId): ManagedBuffer + def getLocalBlockData(blockId: BlockId): ManagedBuffer /** * Put the block locally, using the given storage level. @@ -57,7 +63,7 @@ trait BlockDataManager { classTag: ClassTag[_]): StreamCallbackWithID /** - * Release locks acquired by [[putBlockData()]] and [[getBlockData()]]. + * Release locks acquired by [[putBlockData()]] and [[getLocalBlockData()]]. 
*/ def releaseLock(blockId: BlockId, taskContext: Option[TaskContext]): Unit } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 3a41c5f73c0a..91910b936e7c 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -57,7 +57,7 @@ class NettyBlockRpcServer( case openBlocks: OpenBlocks => val blocksNum = openBlocks.blockIds.length val blocks = for (i <- (0 until blocksNum).view) - yield blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i))) + yield blockManager.getLocalBlockData(BlockId.apply(openBlocks.blockIds(i))) val streamId = streamManager.registerStream(appId, blocks.iterator.asJava, client.getChannel) logTrace(s"Registered streamId $streamId with $blocksNum buffers") @@ -67,7 +67,7 @@ class NettyBlockRpcServer( val blocks = fetchShuffleBlocks.mapIds.zipWithIndex.flatMap { case (mapId, index) => if (!fetchShuffleBlocks.batchFetchEnabled) { fetchShuffleBlocks.reduceIds(index).map { reduceId => - blockManager.getBlockData( + blockManager.getLocalBlockData( ShuffleBlockId(fetchShuffleBlocks.shuffleId, mapId, reduceId)) } } else { @@ -76,7 +76,7 @@ class NettyBlockRpcServer( throw new IllegalStateException(s"Invalid shuffle fetch request when batch mode " + s"is enabled: $fetchShuffleBlocks") } - Array(blockManager.getBlockData( + Array(blockManager.getLocalBlockData( ShuffleBlockBatchId( fetchShuffleBlocks.shuffleId, mapId, startAndEndId(0), startAndEndId(1)))) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 8b3993e21f07..af2c82e77197 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.io.NioBufferedFileInputStream import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.shuffle.ExecutorDiskUtils import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -51,12 +52,36 @@ private[spark] class IndexShuffleBlockResolver( private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") - def getDataFile(shuffleId: Int, mapId: Long): File = { - blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) + + def getDataFile(shuffleId: Int, mapId: Long): File = getDataFile(shuffleId, mapId, None) + + /** + * Get the shuffle data file. + * + * When the dirs parameter is None then use the disk manager's local directories. Otherwise, + * read from the specified directories. + */ + def getDataFile(shuffleId: Int, mapId: Long, dirs: Option[Array[String]]): File = { + val blockId = ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID) + dirs + .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, blockId.name)) + .getOrElse(blockManager.diskBlockManager.getFile(blockId)) } - private def getIndexFile(shuffleId: Int, mapId: Long): File = { - blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) + /** + * Get the shuffle index file. 
+ * + * When the dirs parameter is None then use the disk manager's local directories. Otherwise, + * read from the specified directories. + */ + private def getIndexFile( + shuffleId: Int, + mapId: Long, + dirs: Option[Array[String]] = None): File = { + val blockId = ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID) + dirs + .map(ExecutorDiskUtils.getFile(_, blockManager.subDirsPerLocalDir, blockId.name)) + .getOrElse(blockManager.diskBlockManager.getFile(blockId)) } /** @@ -190,7 +215,9 @@ private[spark] class IndexShuffleBlockResolver( } } - override def getBlockData(blockId: BlockId): ManagedBuffer = { + override def getBlockData( + blockId: BlockId, + dirs: Option[Array[String]]): ManagedBuffer = { val (shuffleId, mapId, startReduceId, endReduceId) = blockId match { case id: ShuffleBlockId => (id.shuffleId, id.mapId, id.reduceId, id.reduceId + 1) @@ -201,7 +228,7 @@ private[spark] class IndexShuffleBlockResolver( } // The block is actually going to be a range of a single map output file for this map, so // find out the consolidated file, then the offset within that from our index - val indexFile = getIndexFile(shuffleId, mapId) + val indexFile = getIndexFile(shuffleId, mapId, dirs) // SPARK-22982: if this FileInputStream's position is seeked forward by another piece of code // which is incorrectly using our file descriptor then this code will fetch the wrong offsets @@ -224,7 +251,7 @@ private[spark] class IndexShuffleBlockResolver( } new FileSegmentManagedBuffer( transportConf, - getDataFile(shuffleId, mapId), + getDataFile(shuffleId, mapId, dirs), startOffset, endOffset - startOffset) } finally { diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala index c50789658d61..5485cf955f11 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala @@ -31,10 +31,14 @@ trait ShuffleBlockResolver { type ShuffleId = Int /** - * Retrieve the data for the specified block. If the data for that block is not available, - * throws an unspecified exception. + * Retrieve the data for the specified block. + * + * When the dirs parameter is None then use the disk manager's local directories. Otherwise, + * read from the specified directories. + * + * If the data for that block is not available, throws an unspecified exception. 
*/ - def getBlockData(blockId: BlockId): ManagedBuffer + def getBlockData(blockId: BlockId, dirs: Option[Array[String]] = None): ManagedBuffer def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index c869a7078a1e..cc28f9b77da3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -22,17 +22,18 @@ import java.lang.ref.{ReferenceQueue => JReferenceQueue, WeakReference} import java.nio.ByteBuffer import java.nio.channels.Channels import java.util.Collections -import java.util.concurrent.{ConcurrentHashMap, TimeUnit} +import java.util.concurrent.{CompletableFuture, ConcurrentHashMap, TimeUnit} import scala.collection.mutable import scala.collection.mutable.HashMap import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ import scala.reflect.ClassTag -import scala.util.Random +import scala.util.{Failure, Random, Success, Try} import scala.util.control.NonFatal import com.codahale.metrics.{MetricRegistry, MetricSet} +import com.google.common.cache.CacheBuilder import org.apache.commons.io.IOUtils import org.apache.spark._ @@ -113,6 +114,47 @@ private[spark] class ByteBufferBlockData( } +private[spark] class HostLocalDirManager( + futureExecutionContext: ExecutionContext, + cacheSize: Int, + externalBlockStoreClient: ExternalBlockStoreClient, + host: String, + externalShuffleServicePort: Int) extends Logging { + + private val executorIdToLocalDirsCache = + CacheBuilder + .newBuilder() + .maximumSize(cacheSize) + .build[String, Array[String]]() + + private[spark] def getCachedHostLocalDirs() + : scala.collection.Map[String, Array[String]] = executorIdToLocalDirsCache.synchronized { + import scala.collection.JavaConverters._ + return executorIdToLocalDirsCache.asMap().asScala + } + + private[spark] def getHostLocalDirs( + executorIds: Array[String])( + callback: Try[java.util.Map[String, Array[String]]] => Unit): Unit = { + val hostLocalDirsCompletable = new CompletableFuture[java.util.Map[String, Array[String]]] + externalBlockStoreClient.getHostLocalDirs( + host, + externalShuffleServicePort, + executorIds, + hostLocalDirsCompletable) + hostLocalDirsCompletable.whenComplete { (hostLocalDirs, throwable) => + if (hostLocalDirs != null) { + callback(Success(hostLocalDirs)) + executorIdToLocalDirsCache.synchronized { + executorIdToLocalDirsCache.putAll(hostLocalDirs) + } + } else { + callback(Failure(throwable)) + } + } + } +} + /** * Manager running on every node (driver and executors) which provides interfaces for putting and * retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap). @@ -206,6 +248,8 @@ private[spark] class BlockManager( new BlockManager.RemoteBlockDownloadFileManager(this) private val maxRemoteBlockToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) + var hostLocalDirManager: Option[HostLocalDirManager] = None + /** * Abstraction for storing blocks from bytes, whether they start in memory or on disk. 
* @@ -433,6 +477,20 @@ private[spark] class BlockManager( registerWithExternalShuffleServer() } + hostLocalDirManager = + if (conf.get(config.SHUFFLE_HOST_LOCAL_DISK_READING_ENABLED)) { + externalBlockStoreClient.map { blockStoreClient => + new HostLocalDirManager( + futureExecutionContext, + conf.get(config.STORAGE_LOCAL_DISK_BY_EXECUTORS_CACHE_SIZE), + blockStoreClient, + blockManagerId.host, + externalShuffleServicePort) + } + } else { + None + } + logInfo(s"Initialized BlockManager: $blockManagerId") } @@ -542,11 +600,17 @@ private[spark] class BlockManager( } } + override def getHostLocalShuffleData( + blockId: BlockId, + dirs: Array[String]): ManagedBuffer = { + shuffleManager.shuffleBlockResolver.getBlockData(blockId, Some(dirs)) + } + /** * Interface to get local block data. Throws an exception if the block cannot be found or * cannot be read successfully. */ - override def getBlockData(blockId: BlockId): ManagedBuffer = { + override def getLocalBlockData(blockId: BlockId): ManagedBuffer = { if (blockId.isShuffle) { shuffleManager.shuffleBlockResolver.getBlockData(blockId) } else { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 7e2027701c33..41ef1909cd4c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -26,6 +26,8 @@ import scala.collection.mutable import scala.concurrent.{ExecutionContext, Future} import scala.util.Random +import com.google.common.cache.CacheBuilder + import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.{config, Logging} @@ -49,6 +51,13 @@ class BlockManagerMasterEndpoint( blockManagerInfo: mutable.Map[BlockManagerId, BlockManagerInfo]) extends IsolatedRpcEndpoint with Logging { + // Mapping from executor id to the block manager's local disk directories. + private val executorIdToLocalDirs = + CacheBuilder + .newBuilder() + .maximumSize(conf.get(config.STORAGE_LOCAL_DISK_BY_EXECUTORS_CACHE_SIZE)) + .build[String, Array[String]]() + // Mapping from external shuffle service block manager id to the block statuses. private val blockStatusByShuffleService = new mutable.HashMap[BlockManagerId, JHashMap[BlockId, BlockStatus]] @@ -393,6 +402,7 @@ class BlockManagerMasterEndpoint( topologyMapper.getTopologyForHost(idWithoutTopologyInfo.host)) val time = System.currentTimeMillis() + executorIdToLocalDirs.put(id.executorId, localDirs) if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { case Some(oldId) => @@ -416,7 +426,7 @@ class BlockManagerMasterEndpoint( None } - blockManagerInfo(id) = new BlockManagerInfo(id, System.currentTimeMillis(), localDirs, + blockManagerInfo(id) = new BlockManagerInfo(id, System.currentTimeMillis(), maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint, externalShuffleServiceBlockStatus) } listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxOnHeapMemSize + maxOffHeapMemSize, @@ -496,15 +506,16 @@ class BlockManagerMasterEndpoint( if (locations.nonEmpty && status.isDefined) { val localDirs = locations.find { loc => - if (loc.port != externalShuffleServicePort && loc.host == requesterHost) { + // When the external shuffle service running on the same host is found among the block + // locations then the block must be persisted on the disk. 
In this case the executorId + // can be used to access this block even when the original executor is already stopped. + loc.host == requesterHost && + (loc.port == externalShuffleServicePort || blockManagerInfo .get(loc) .flatMap(_.getStatus(blockId).map(_.storageLevel.useDisk)) - .getOrElse(false) - } else { - false - } - }.map(blockManagerInfo(_).localDirs) + .getOrElse(false)) + }.flatMap { bmId => Option(executorIdToLocalDirs.getIfPresent(bmId.executorId)) } Some(BlockLocationsAndStatus(locations, status.get, localDirs)) } else { None @@ -556,7 +567,6 @@ object BlockStatus { private[spark] class BlockManagerInfo( val blockManagerId: BlockManagerId, timeMs: Long, - val localDirs: Array[String], val maxOnHeapMem: Long, val maxOffHeapMem: Long, val slaveEndpoint: RpcEndpointRef, diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index f8aa97267cf1..8fa7e68815a9 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -23,7 +23,8 @@ import java.util.concurrent.{LinkedBlockingQueue, TimeUnit} import javax.annotation.concurrent.GuardedBy import scala.collection.mutable -import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, LinkedHashMap, Queue} +import scala.util.{Failure, Success} import org.apache.commons.io.IOUtils @@ -84,11 +85,14 @@ final class ShuffleBlockFetcherIterator( import ShuffleBlockFetcherIterator._ + // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them + // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 + // nodes, rather than blocking on reading output from one node. + private val targetRemoteRequestSize = math.max(maxBytesInFlight / 5, 1L) + /** * Total number of blocks to fetch. This should be equal to the total number of blocks * in [[blocksByAddress]] because we already filter out zero-sized blocks in [[blocksByAddress]]. - * - * This should equal localBlocks.size + remoteBlocks.size. */ private[this] var numBlocksToFetch = 0 @@ -103,8 +107,12 @@ final class ShuffleBlockFetcherIterator( /** Local blocks to fetch, excluding zero-sized blocks. */ private[this] val localBlocks = scala.collection.mutable.LinkedHashSet[(BlockId, Int)]() - /** Remote blocks to fetch, excluding zero-sized blocks. */ - private[this] val remoteBlocks = new HashSet[BlockId]() + /** Host local blockIds to fetch by executors, excluding zero-sized blocks. */ + private[this] val hostLocalBlocksByExecutor = + LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]() + + /** Host local blocks to fetch, excluding zero-sized blocks. */ + private[this] val hostLocalBlocks = scala.collection.mutable.LinkedHashSet[(BlockId, Int)]() /** * A queue to hold our results. This turns the asynchronous model provided by @@ -272,73 +280,91 @@ final class ShuffleBlockFetcherIterator( } } - private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = { - // Make remote requests at most maxBytesInFlight / 5 in length; the reason to keep them - // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 - // nodes, rather than blocking on reading output from one node. 
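To make that sizing rule concrete (48m is Spark's default spark.reducer.maxSizeInFlight; the numbers are only an example of the rule kept here as targetRemoteRequestSize):

    val maxBytesInFlight = 48L * 1024 * 1024                        // spark.reducer.maxSizeInFlight = 48m
    val targetRemoteRequestSize = math.max(maxBytesInFlight / 5, 1L)
    // => 10066329 bytes, roughly 9.6 MiB per request, so ~5 requests can be in flight in parallel
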
- val targetRequestSize = math.max(maxBytesInFlight / 5, 1L) - logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize - + ", maxBlocksInFlightPerAddress: " + maxBlocksInFlightPerAddress) - - // Split local and remote blocks. Remote blocks are further split into FetchRequests of size - // at most maxBytesInFlight in order to limit the amount of data in flight. - val remoteRequests = new ArrayBuffer[FetchRequest] + private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = { + logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: " + + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress") + + // Partition to local, host-local and remote blocks. Remote blocks are further split into + // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight + val collectedRemoteRequests = new ArrayBuffer[FetchRequest] var localBlockBytes = 0L + var hostLocalBlockBytes = 0L var remoteBlockBytes = 0L + var numRemoteBlocks = 0 + + val hostLocalDirReadingEnabled = + blockManager.hostLocalDirManager != null && blockManager.hostLocalDirManager.isDefined for ((address, blockInfos) <- blocksByAddress) { if (address.executorId == blockManager.blockManagerId.executorId) { - blockInfos.find(_._2 <= 0) match { - case Some((blockId, size, _)) if size < 0 => - throw new BlockException(blockId, "Negative block size " + size) - case Some((blockId, size, _)) if size == 0 => - throw new BlockException(blockId, "Zero-sized blocks should be excluded.") - case None => // do nothing. - } + checkBlockSizes(blockInfos) val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)).to[ArrayBuffer]) localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex)) localBlockBytes += mergedBlockInfos.map(_.size).sum + } else if (hostLocalDirReadingEnabled && address.host == blockManager.blockManagerId.host) { + checkBlockSizes(blockInfos) + val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( + blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)).to[ArrayBuffer]) + val blocksForAddress = + mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex)) + hostLocalBlocksByExecutor += address -> blocksForAddress + hostLocalBlocks ++= blocksForAddress.map(info => (info._1, info._3)) + hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum } else { - val iterator = blockInfos.iterator - var curRequestSize = 0L - var curBlocks = new ArrayBuffer[FetchBlockInfo] - while (iterator.hasNext) { - val (blockId, size, mapIndex) = iterator.next() - remoteBlockBytes += size - if (size < 0) { - throw new BlockException(blockId, "Negative block size " + size) - } else if (size == 0) { - throw new BlockException(blockId, "Zero-sized blocks should be excluded.") - } else { - curBlocks += FetchBlockInfo(blockId, size, mapIndex) - curRequestSize += size - } - if (curRequestSize >= targetRequestSize || - curBlocks.size >= maxBlocksInFlightPerAddress) { - // Add this FetchRequest - val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks) - remoteBlocks ++= mergedBlocks.map(_.blockId) - remoteRequests += new FetchRequest(address, mergedBlocks) - logDebug(s"Creating fetch request of $curRequestSize at $address " - + s"with ${mergedBlocks.size} blocks") - curBlocks = new ArrayBuffer[FetchBlockInfo] - curRequestSize = 0 - } - } - // Add in the final request - if (curBlocks.nonEmpty) { - val 
mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks) - remoteBlocks ++= mergedBlocks.map(_.blockId) - remoteRequests += new FetchRequest(address, mergedBlocks) - } + numRemoteBlocks += blockInfos.size + remoteBlockBytes += blockInfos.map(_._2).sum + collectFetchRequests(address, blockInfos, collectedRemoteRequests) } } val totalBytes = localBlockBytes + remoteBlockBytes logInfo(s"Getting $numBlocksToFetch (${Utils.bytesToString(totalBytes)}) non-empty blocks " + - s"including ${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local blocks and " + - s"${remoteBlocks.size} (${Utils.bytesToString(remoteBlockBytes)}) remote blocks") - remoteRequests + s"including ${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " + + s"${hostLocalBlocks.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " + + s"host-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) remote blocks") + collectedRemoteRequests + } + + private def collectFetchRequests( + address: BlockManagerId, + blockInfos: Seq[(BlockId, Long, Int)], + collectedRemoteRequests: ArrayBuffer[FetchRequest]): Unit = { + val iterator = blockInfos.iterator + var curRequestSize = 0L + var curBlocks = new ArrayBuffer[FetchBlockInfo] + while (iterator.hasNext) { + val (blockId, size, mapIndex) = iterator.next() + assertPositiveBlockSize(blockId, size) + curBlocks += FetchBlockInfo(blockId, size, mapIndex) + curRequestSize += size + if (curRequestSize >= targetRemoteRequestSize || + curBlocks.size >= maxBlocksInFlightPerAddress) { + // Add this FetchRequest + val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks) + collectedRemoteRequests += new FetchRequest(address, mergedBlocks) + logDebug(s"Creating fetch request of $curRequestSize at $address " + + s"with ${mergedBlocks.size} blocks") + curBlocks = new ArrayBuffer[FetchBlockInfo] + curRequestSize = 0 + } + } + // Add in the final request + if (curBlocks.nonEmpty) { + val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks) + collectedRemoteRequests += new FetchRequest(address, mergedBlocks) + } + } + + private def assertPositiveBlockSize(blockId: BlockId, blockSize: Long): Unit = { + if (blockSize < 0) { + throw BlockException(blockId, "Negative block size " + size) + } else if (blockSize == 0) { + throw BlockException(blockId, "Zero-sized blocks should be excluded.") + } + } + + private def checkBlockSizes(blockInfos: Seq[(BlockId, Long, Int)]): Unit = { + blockInfos.foreach { case (blockId, size, _) => assertPositiveBlockSize(blockId, size) } } private[this] def mergeContinuousShuffleBlockIdsIfNeeded( @@ -397,7 +423,7 @@ final class ShuffleBlockFetcherIterator( while (iter.hasNext) { val (blockId, mapIndex) = iter.next() try { - val buf = blockManager.getBlockData(blockId) + val buf = blockManager.getLocalBlockData(blockId) shuffleMetrics.incLocalBlocksFetched(1) shuffleMetrics.incLocalBytesRead(buf.size) buf.retain() @@ -420,12 +446,89 @@ final class ShuffleBlockFetcherIterator( } } + private[this] def fetchHostLocalBlock( + blockId: BlockId, + mapIndex: Int, + localDirs: Array[String], + blockManagerId: BlockManagerId): Boolean = { + try { + val buf = blockManager.getHostLocalShuffleData(blockId, localDirs) + buf.retain() + results.put(SuccessFetchResult(blockId, mapIndex, blockManagerId, buf.size(), buf, + isNetworkReqDone = false)) + true + } catch { + case e: Exception => + // If we see an exception, stop immediately. 
+ logError(s"Error occurred while fetching local blocks", e) + results.put(FailureFetchResult(blockId, mapIndex, blockManagerId, e)) + false + } + } + + /** + * Fetch the host-local blocks while we are fetching remote blocks. This is ok because + * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we + * track in-memory are the ManagedBuffer references themselves. + */ + private[this] def fetchHostLocalBlocks(hostLocalDirManager: HostLocalDirManager): Unit = { + val cachedDirsByExec = hostLocalDirManager.getCachedHostLocalDirs() + val (hostLocalBlocksWithCachedDirs, hostLocalBlocksWithMissingDirs) = + hostLocalBlocksByExecutor + .map { case (hostLocalBmId, bmInfos) => + (hostLocalBmId, bmInfos, cachedDirsByExec.get(hostLocalBmId.executorId)) + }.partition(_._3.isDefined) + val bmId = blockManager.blockManagerId + val immutableHostLocalBlocksWithoutDirs = + hostLocalBlocksWithMissingDirs.map { case (hostLocalBmId, bmInfos, _) => + hostLocalBmId -> bmInfos + }.toMap + if (immutableHostLocalBlocksWithoutDirs.nonEmpty) { + logDebug(s"Asynchronous fetching host-local blocks without cached executors' dir: " + + s"${immutableHostLocalBlocksWithoutDirs.mkString(", ")}") + val execIdsWithoutDirs = immutableHostLocalBlocksWithoutDirs.keys.map(_.executorId).toArray + hostLocalDirManager.getHostLocalDirs(execIdsWithoutDirs) { + case Success(dirs) => + immutableHostLocalBlocksWithoutDirs.foreach { case (hostLocalBmId, blockInfos) => + blockInfos.takeWhile { case (blockId, _, mapIndex) => + fetchHostLocalBlock( + blockId, + mapIndex, + dirs.get(hostLocalBmId.executorId), + hostLocalBmId) + } + } + logDebug(s"Got host-local blocks (without cached executors' dir) in " + + s"${Utils.getUsedTimeNs(startTimeNs)}") + + case Failure(throwable) => + logError(s"Error occurred while fetching host local blocks", throwable) + val (hostLocalBmId, blockInfoSeq) = immutableHostLocalBlocksWithoutDirs.head + val (blockId, _, mapIndex) = blockInfoSeq.head + results.put(FailureFetchResult(blockId, mapIndex, hostLocalBmId, throwable)) + } + } + if (hostLocalBlocksWithCachedDirs.nonEmpty) { + logDebug(s"Synchronous fetching host-local blocks with cached executors' dir: " + + s"${hostLocalBlocksWithCachedDirs.mkString(", ")}") + hostLocalBlocksWithCachedDirs.foreach { case (_, blockInfos, localDirs) => + blockInfos.foreach { case (blockId, _, mapIndex) => + if (!fetchHostLocalBlock(blockId, mapIndex, localDirs.get, bmId)) { + return + } + } + } + logDebug(s"Got host-local blocks (with cached executors' dir) in " + + s"${Utils.getUsedTimeNs(startTimeNs)}") + } + } + private[this] def initialize(): Unit = { // Add a task completion callback (called in both success case and failure case) to cleanup. context.addTaskCompletionListener(onCompleteCallback) - // Split local and remote blocks. - val remoteRequests = splitLocalRemoteBlocks() + // Partition blocks by the different fetch modes: local, host-local and remote blocks. 
+ val remoteRequests = partitionBlocksByFetchMode() // Add the remote requests into our queue in a random order fetchRequests ++= Utils.randomize(remoteRequests) assert ((0 == reqsInFlight) == (0 == bytesInFlight), @@ -441,6 +544,10 @@ final class ShuffleBlockFetcherIterator( // Get Local Blocks fetchLocalBlocks() logDebug(s"Got local blocks in ${Utils.getUsedTimeNs(startTimeNs)}") + + if (hostLocalBlocks.nonEmpty) { + blockManager.hostLocalDirManager.foreach(fetchHostLocalBlocks) + } } override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch @@ -476,15 +583,18 @@ final class ShuffleBlockFetcherIterator( result match { case r @ SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) => if (address != blockManager.blockManagerId) { - numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 - shuffleMetrics.incRemoteBytesRead(buf.size) - if (buf.isInstanceOf[FileSegmentManagedBuffer]) { - shuffleMetrics.incRemoteBytesReadToDisk(buf.size) + if (hostLocalBlocks.contains(blockId -> mapIndex)) { + shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) + } else { + numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 + shuffleMetrics.incRemoteBytesRead(buf.size) + if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + shuffleMetrics.incRemoteBytesReadToDisk(buf.size) + } + shuffleMetrics.incRemoteBlocksFetched(1) + bytesInFlight -= size } - shuffleMetrics.incRemoteBlocksFetched(1) - } - if (!localBlocks.contains((blockId, mapIndex))) { - bytesInFlight -= size } if (isNetworkReqDone) { reqsInFlight -= 1 diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 8844a0598ccb..c217419f4092 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -68,6 +68,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll wi // This test ensures that the external shuffle service is actually in use for the other tests. 
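One note on the metrics change in next() above; a worked example of the resulting accounting (block sizes are illustrative):

    // A task that reads 3 local + 4 host-local + 2 remote blocks of 1 MiB each now reports:
    //   localBlocksFetched  = 7      (local and host-local blocks are pooled together)
    //   localBytesRead      = 7 MiB
    //   remoteBlocksFetched = 2
    //   remoteBytesRead     = 2 MiB
    // and only the 2 remote MiB ever count toward bytesInFlight.
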
test("using external shuffle service") { sc = new SparkContext("local-cluster[2,1,1024]", "test", conf) + sc.getConf.get(config.SHUFFLE_HOST_LOCAL_DISK_READING_ENABLED) should equal(false) sc.env.blockManager.externalShuffleServiceEnabled should equal(true) sc.env.blockManager.blockStoreClient.getClass should equal(classOf[ExternalBlockStoreClient]) @@ -79,7 +80,9 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll wi // Therefore, we should wait until all slaves are up TestUtils.waitUntilExecutorsUp(sc, 2, 60000) - val rdd = sc.parallelize(0 until 1000, 10).map(i => (i, 1)).reduceByKey(_ + _) + val rdd = sc.parallelize(0 until 1000, 10) + .map { i => (i, 1) } + .reduceByKey(_ + _) rdd.count() rdd.count() @@ -96,6 +99,50 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll wi e.getMessage should include ("Fetch failure will not retry stage due to testing config") } + test("SPARK-27651: read host local shuffle blocks from disk and avoid network remote fetches") { + val confWithHostLocalRead = + conf.clone.set(config.SHUFFLE_HOST_LOCAL_DISK_READING_ENABLED, true) + confWithHostLocalRead.set(config.STORAGE_LOCAL_DISK_BY_EXECUTORS_CACHE_SIZE, 5) + sc = new SparkContext("local-cluster[2,1,1024]", "test", confWithHostLocalRead) + sc.getConf.get(config.SHUFFLE_HOST_LOCAL_DISK_READING_ENABLED) should equal(true) + sc.env.blockManager.externalShuffleServiceEnabled should equal(true) + sc.env.blockManager.hostLocalDirManager.isDefined should equal(true) + sc.env.blockManager.blockStoreClient.getClass should equal(classOf[ExternalBlockStoreClient]) + + // In a slow machine, one slave may register hundreds of milliseconds ahead of the other one. + // If we don't wait for all slaves, it's possible that only one executor runs all jobs. Then + // all shuffle blocks will be in this executor, ShuffleBlockFetcherIterator will directly fetch + // local blocks from the local BlockManager and won't send requests to ExternalShuffleService. + // In this case, we won't receive FetchFailed. And it will make this test fail. + // Therefore, we should wait until all slaves are up + TestUtils.waitUntilExecutorsUp(sc, 2, 60000) + + val rdd = sc.parallelize(0 until 1000, 10) + .map { i => (i, 1) } + .reduceByKey(_ + _) + + rdd.count() + rdd.count() + + val cachedExecutors = rdd.mapPartitions { _ => + SparkEnv.get.blockManager.hostLocalDirManager.map { localDirManager => + localDirManager.getCachedHostLocalDirs().keySet.iterator + }.getOrElse(Iterator.empty) + }.collect().toSet + + // both executors are caching the dirs of the other one + cachedExecutors should equal(sc.getExecutorIds().toSet) + + // Invalidate the registered executors, disallowing access to their shuffle blocks (without + // deleting the actual shuffle files, so we could access them without the shuffle service). + // As directories are already cached there is no request to external shuffle service. 
+ rpcHandler.applicationRemoved(sc.conf.getAppId, false /* cleanupLocalDirs */) + + // Now Spark will not receive FetchFailed as host local blocks are read from the cached local + // disk directly + rdd.collect().map(_._2).sum should equal(1000) + } + test("SPARK-25888: using external shuffle service fetching disk persisted blocks") { val confWithRddFetchEnabled = conf.clone.set(config.SHUFFLE_SERVICE_FETCH_RDD_ENABLED, true) sc = new SparkContext("local-cluster[1,1,1024]", "test", confWithRddFetchEnabled) diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index e05fad19567a..c726329ce8a8 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -122,7 +122,7 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi val blockString = "Hello, world!" val blockBuffer = new NioManagedBuffer(ByteBuffer.wrap( blockString.getBytes(StandardCharsets.UTF_8))) - when(blockManager.getBlockData(blockId)).thenReturn(blockBuffer) + when(blockManager.getLocalBlockData(blockId)).thenReturn(blockBuffer) val securityManager0 = new SecurityManager(conf0) val exec0 = new NettyBlockTransferService(conf0, securityManager0, "localhost", "localhost", 0, diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 3f9536e224de..a82f86a11c77 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.shuffle import java.io.{ByteArrayOutputStream, InputStream} import java.nio.ByteBuffer +import org.mockito.ArgumentMatchers.{eq => meq} import org.mockito.Mockito.{mock, when} import org.apache.spark._ @@ -95,7 +96,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // Setup the blockManager mock so the buffer gets returned when the shuffle code tries to // fetch shuffle data. 
val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) - when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer) + when(blockManager.getLocalBlockData(meq(shuffleBlockId))).thenReturn(managedBuffer) managedBuffer } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerInfoSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerInfoSuite.scala index 49cbd66cccb8..01e3d6a46e70 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerInfoSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerInfoSuite.scala @@ -31,7 +31,6 @@ class BlockManagerInfoSuite extends SparkFunSuite { val bmInfo = new BlockManagerInfo( BlockManagerId("executor0", "host", 1234, None), timeMs = 300, - Array(), maxOnHeapMem = 10000, maxOffHeapMem = 20000, slaveEndpoint = null, diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 8595f73fe5dd..89f00b5a9d90 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -654,7 +654,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // check getRemoteBytes val bytesViaStore1 = cleanBm.getRemoteBytes(blockId) assert(bytesViaStore1.isDefined) - val expectedContent = sameHostBm.getBlockData(blockId).nioByteBuffer().array() + val expectedContent = sameHostBm.getLocalBlockData(blockId).nioByteBuffer().array() assert(bytesViaStore1.get.toArray === expectedContent) // check getRemoteValues @@ -1095,7 +1095,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val blockStatus = blockStatusOption.get assert((blockStatus.diskSize > 0) === !storageLevel.useMemory) assert((blockStatus.memSize > 0) === storageLevel.useMemory) - assert(blockManager.getBlockData(blockId).nioByteBuffer().array() === ser) + assert(blockManager.getLocalBlockData(blockId).nioByteBuffer().array() === ser) } Seq( diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 85b1a865603a..cf4c292d4ca9 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.storage import java.io._ import java.nio.ByteBuffer import java.util.UUID -import java.util.concurrent.Semaphore +import java.util.concurrent.{CompletableFuture, Semaphore} import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.Future @@ -33,7 +33,7 @@ import org.scalatest.PrivateMethodTester import org.apache.spark.{SparkFunSuite, TaskContext} import org.apache.spark.network._ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExternalBlockStoreClient} import org.apache.spark.network.util.LimitedInputStream import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.Utils @@ -65,6 +65,29 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT transfer } + private def initHostLocalDirManager( + blockManager: BlockManager, + hostLocalDirs: 
+    val mockExternalBlockStoreClient = mock(classOf[ExternalBlockStoreClient])
+    val hostLocalDirManager = new HostLocalDirManager(
+      futureExecutionContext = global,
+      cacheSize = 1,
+      externalBlockStoreClient = mockExternalBlockStoreClient,
+      host = "localhost",
+      externalShuffleServicePort = 7337)
+
+    when(blockManager.hostLocalDirManager).thenReturn(Some(hostLocalDirManager))
+    when(mockExternalBlockStoreClient.getHostLocalDirs(any(), any(), any(), any()))
+      .thenAnswer { invocation =>
+        val completableFuture = invocation.getArguments()(3)
+          .asInstanceOf[CompletableFuture[java.util.Map[String, Array[String]]]]
+        import scala.collection.JavaConverters._
+        completableFuture.complete(hostLocalDirs.asJava)
+      }
+
+    blockManager.hostLocalDirManager = Some(hostLocalDirManager)
+  }
+
   // Create a mock managed buffer for testing
   def createMockManagedBuffer(size: Int = 1): ManagedBuffer = {
     val mockManagedBuffer = mock(classOf[ManagedBuffer])
@@ -76,9 +99,24 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     mockManagedBuffer
   }

-  test("successful 3 local reads + 2 remote reads") {
+  def verifyBufferRelease(buffer: ManagedBuffer, inputStream: InputStream): Unit = {
+    // Note: ShuffleBlockFetcherIterator wraps input streams in a BufferReleasingInputStream
+    val wrappedInputStream = inputStream.asInstanceOf[BufferReleasingInputStream]
+    verify(buffer, times(0)).release()
+    val delegateAccess = PrivateMethod[InputStream](Symbol("delegate"))
+
+    verify(wrappedInputStream.invokePrivate(delegateAccess()), times(0)).close()
+    wrappedInputStream.close()
+    verify(buffer, times(1)).release()
+    verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close()
+    wrappedInputStream.close() // close should be idempotent
+    verify(buffer, times(1)).release()
+    verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close()
+  }
+
+  test("successful 3 local + 4 host local + 2 remote reads") {
     val blockManager = mock(classOf[BlockManager])
-    val localBmId = BlockManagerId("test-client", "test-client", 1)
+    val localBmId = BlockManagerId("test-local-client", "test-local-host", 1)
     doReturn(localBmId).when(blockManager).blockManagerId

     // Make sure blockManager.getBlockData would return the blocks
@@ -87,20 +125,38 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
       ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer())
     localBlocks.foreach { case (blockId, buf) =>
-      doReturn(buf).when(blockManager).getBlockData(meq(blockId))
+      doReturn(buf).when(blockManager).getLocalBlockData(meq(blockId))
     }

     // Make sure remote blocks would return
-    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+    val remoteBmId = BlockManagerId("test-remote-client-1", "test-remote-host", 2)
     val remoteBlocks = Map[BlockId, ManagedBuffer](
       ShuffleBlockId(0, 3, 0) -> createMockManagedBuffer(),
       ShuffleBlockId(0, 4, 0) -> createMockManagedBuffer())
     val transfer = createMockTransfer(remoteBlocks)

+    // Create a block manager running on the same host (host-local)
+    val hostLocalBmId = BlockManagerId("test-host-local-client-1", "test-local-host", 3)
+    val hostLocalBlocks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockId(0, 5, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 6, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 7, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 8, 0) -> createMockManagedBuffer())
+
+    hostLocalBlocks.foreach { case (blockId, buf) =>
+      doReturn(buf)
+        .when(blockManager)
+        .getHostLocalShuffleData(meq(blockId.asInstanceOf[ShuffleBlockId]), any())
+    }
+    val hostLocalDirs = Map("test-host-local-client-1" -> Array("local-dir"))
+    // returning local dir for hostLocalBmId
+    initHostLocalDirManager(blockManager, hostLocalDirs)
+
     val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
       (localBmId, localBlocks.keys.map(blockId => (blockId, 1L, 0)).toSeq),
-      (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1L, 1)).toSeq)
+      (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1L, 1)).toSeq),
+      (hostLocalBmId, hostLocalBlocks.keys.map(blockId => (blockId, 1L, 1)).toSeq)
     ).toIterator

     val taskContext = TaskContext.empty()
@@ -121,37 +177,86 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       false)

     // 3 local blocks fetched in initialization
-    verify(blockManager, times(3)).getBlockData(any())
+    verify(blockManager, times(3)).getLocalBlockData(any())

-    for (i <- 0 until 5) {
-      assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements")
+    val allBlocks = localBlocks ++ remoteBlocks ++ hostLocalBlocks
+    for (i <- 0 until allBlocks.size) {
+      assert(iterator.hasNext,
+        s"iterator should have ${allBlocks.size} elements but actually has $i elements")
       val (blockId, inputStream) = iterator.next()

       // Make sure we release buffers when a wrapped input stream is closed.
-      val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId))
-      // Note: ShuffleBlockFetcherIterator wraps input streams in a BufferReleasingInputStream
-      val wrappedInputStream = inputStream.asInstanceOf[BufferReleasingInputStream]
-      verify(mockBuf, times(0)).release()
-      val delegateAccess = PrivateMethod[InputStream](Symbol("delegate"))
-
-      verify(wrappedInputStream.invokePrivate(delegateAccess()), times(0)).close()
-      wrappedInputStream.close()
-      verify(mockBuf, times(1)).release()
-      verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close()
-      wrappedInputStream.close() // close should be idempotent
-      verify(mockBuf, times(1)).release()
-      verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close()
+      val mockBuf = allBlocks(blockId)
+      verifyBufferRelease(mockBuf, inputStream)
     }

-    // 3 local blocks, and 2 remote blocks
-    // (but from the same block manager so one call to fetchBlocks)
-    verify(blockManager, times(3)).getBlockData(any())
+    // 4 host-local blocks fetched
+    verify(blockManager, times(4))
+      .getHostLocalShuffleData(any(), meq(Array("local-dir")))
+
+    // 2 remote blocks are read from the same block manager
     verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any())
+    assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs().size === 1)
   }

-  test("fetch continuous blocks in batch successful 3 local reads + 2 remote reads") {
+  test("error during accessing host local dirs for executors") {
     val blockManager = mock(classOf[BlockManager])
-    val localBmId = BlockManagerId("test-client", "test-client", 1)
+    val localBmId = BlockManagerId("test-local-client", "test-local-host", 1)
     doReturn(localBmId).when(blockManager).blockManagerId
+    val hostLocalBlocks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer())
+
+    hostLocalBlocks.foreach { case (blockId, buf) =>
+      doReturn(buf)
+        .when(blockManager)
+        .getHostLocalShuffleData(meq(blockId.asInstanceOf[ShuffleBlockId]), any())
+    }
+    val hostLocalBmId = BlockManagerId("test-host-local-client-1", "test-local-host", 3)
+    val mockExternalBlockStoreClient = mock(classOf[ExternalBlockStoreClient])
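+    // This client is stubbed below to complete the local-dirs future exceptionally,
+    // so the host-local fetch is expected to fail with a FetchFailedException.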
+    val hostLocalDirManager = new HostLocalDirManager(
+      futureExecutionContext = global,
+      cacheSize = 1,
+      externalBlockStoreClient = mockExternalBlockStoreClient,
+      host = "localhost",
+      externalShuffleServicePort = 7337)
+
+    when(blockManager.hostLocalDirManager).thenReturn(Some(hostLocalDirManager))
+    when(mockExternalBlockStoreClient.getHostLocalDirs(any(), any(), any(), any()))
+      .thenAnswer { invocation =>
+        val completableFuture = invocation.getArguments()(3)
+          .asInstanceOf[CompletableFuture[java.util.Map[String, Array[String]]]]
+        completableFuture.completeExceptionally(new Throwable("failed fetch"))
+      }
+
+    blockManager.hostLocalDirManager = Some(hostLocalDirManager)
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+      (hostLocalBmId, hostLocalBlocks.keys.map(blockId => (blockId, 1L, 1)).toSeq)
+    ).toIterator
+
+    val transfer = createMockTransfer(Map())
+    val taskContext = TaskContext.empty()
+    val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
+    val iterator = new ShuffleBlockFetcherIterator(
+      taskContext,
+      transfer,
+      blockManager,
+      blocksByAddress,
+      (_, in) => in,
+      48 * 1024 * 1024,
+      Int.MaxValue,
+      Int.MaxValue,
+      Int.MaxValue,
+      true,
+      false,
+      metrics,
+      false)
+    intercept[FetchFailedException] { iterator.next() }
+  }
+
+  test("fetch continuous blocks in batch successful 3 local + 4 host local + 2 remote reads") {
+    val blockManager = mock(classOf[BlockManager])
+    val localBmId = BlockManagerId("test-client", "test-local-host", 1)
     doReturn(localBmId).when(blockManager).blockManagerId

     // Make sure blockManager.getBlockData would return the merged block
@@ -162,7 +267,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val mergedLocalBlocks = Map[BlockId, ManagedBuffer](
       ShuffleBlockBatchId(0, 0, 0, 3) -> createMockManagedBuffer())
     mergedLocalBlocks.foreach { case (blockId, buf) =>
-      doReturn(buf).when(blockManager).getBlockData(meq(blockId))
+      doReturn(buf).when(blockManager).getLocalBlockData(meq(blockId))
     }

     // Make sure remote blocks would return the merged block
@@ -174,9 +279,29 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       ShuffleBlockBatchId(0, 3, 0, 2) -> createMockManagedBuffer())
     val transfer = createMockTransfer(mergedRemoteBlocks)

+    // Create a block manager running on the same host (host-local)
+    val hostLocalBmId = BlockManagerId("test-host-local-client-1", "test-local-host", 3)
+    val hostLocalBlocks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockId(0, 4, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 4, 1) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 4, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 4, 3) -> createMockManagedBuffer())
+    val mergedHostLocalBlocks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockBatchId(0, 4, 0, 4) -> createMockManagedBuffer())
+
+    mergedHostLocalBlocks.foreach { case (blockId, buf) =>
+      doReturn(buf)
+        .when(blockManager)
+        .getHostLocalShuffleData(meq(blockId.asInstanceOf[ShuffleBlockBatchId]), any())
+    }
+    val hostLocalDirs = Map("test-host-local-client-1" -> Array("local-dir"))
+    // returning local dir for hostLocalBmId
+    initHostLocalDirManager(blockManager, hostLocalDirs)
+
     val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
       (localBmId, localBlocks.map(blockId => (blockId, 1L, 0))),
-      (remoteBmId, remoteBlocks.map(blockId => (blockId, 1L, 1)))
+      (remoteBmId, remoteBlocks.map(blockId => (blockId, 1L, 1))),
+      (hostLocalBmId, hostLocalBlocks.keys.map(blockId => (blockId, 1L, 1)).toSeq)
     ).toIterator

     val taskContext = TaskContext.empty()
@@ -197,32 +322,23 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       true)

     // 3 local blocks batch fetched in initialization
-    verify(blockManager, times(1)).getBlockData(any())
+    verify(blockManager, times(1)).getLocalBlockData(any())

-    for (i <- 0 until 2) {
-      assert(iterator.hasNext, s"iterator should have 2 elements but actually has $i elements")
+    val allBlocks = mergedLocalBlocks ++ mergedRemoteBlocks ++ mergedHostLocalBlocks
+    for (i <- 0 until 3) {
+      assert(iterator.hasNext, s"iterator should have 3 elements but actually has $i elements")
       val (blockId, inputStream) = iterator.next()
-
+      verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any())
       // Make sure we release buffers when a wrapped input stream is closed.
-      val mockBuf = mergedLocalBlocks.getOrElse(blockId, mergedRemoteBlocks(blockId))
-      // Note: ShuffleBlockFetcherIterator wraps input streams in a BufferReleasingInputStream
-      val wrappedInputStream = inputStream.asInstanceOf[BufferReleasingInputStream]
-      verify(mockBuf, times(0)).release()
-      val delegateAccess = PrivateMethod[InputStream]('delegate)
-
-      verify(wrappedInputStream.invokePrivate(delegateAccess()), times(0)).close()
-      wrappedInputStream.close()
-      verify(mockBuf, times(1)).release()
-      verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close()
-      wrappedInputStream.close() // close should be idempotent
-      verify(mockBuf, times(1)).release()
-      verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close()
+      val mockBuf = allBlocks(blockId)
+      verifyBufferRelease(mockBuf, inputStream)
     }

-    // 2 remote blocks batch fetched
-    // (but from the same block manager so one call to fetchBlocks)
-    verify(blockManager, times(1)).getBlockData(any())
-    verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any())
+    // 4 host-local blocks fetched as a single merged batch
+    verify(blockManager, times(1))
+      .getHostLocalShuffleData(any(), meq(Array("local-dir")))
+
+    assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs().size === 1)
   }

   test("release current unexhausted buffer in case the task completes early") {
@@ -546,7 +662,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val blockManager = mock(classOf[BlockManager])
     val localBmId = BlockManagerId("test-client", "test-client", 1)
     doReturn(localBmId).when(blockManager).blockManagerId
-    doReturn(managedBuffer).when(blockManager).getBlockData(ShuffleBlockId(0, 0, 0))
+    doReturn(managedBuffer).when(blockManager).getLocalBlockData(meq(ShuffleBlockId(0, 0, 0)))
     val localBlockLengths = Seq[Tuple3[BlockId, Long, Int]](
       (ShuffleBlockId(0, 0, 0), 10000, 0)
     )