diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java index 43c3d23b6304..94412c4db559 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java @@ -91,6 +91,9 @@ protected void channelRead0( try { streamManager.checkAuthorization(client, msg.streamChunkId.streamId); buf = streamManager.getChunk(msg.streamChunkId.streamId, msg.streamChunkId.chunkIndex); + if (buf == null) { + throw new IllegalStateException("Chunk was not found"); + } } catch (Exception e) { logger.error(String.format("Error opening block %s for request from %s", msg.streamChunkId, getRemoteAddress(channel)), e); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index 6fafcc131fa2..67f64d796203 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -125,7 +125,10 @@ public void connectionTerminated(Channel channel) { // Release all remaining buffers. while (state.buffers.hasNext()) { - state.buffers.next().release(); + ManagedBuffer buffer = state.buffers.next(); + if (buffer != null) { + buffer.release(); + } } } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java index 6c9239606bb8..7e30ed4048ca 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java @@ -23,6 +23,7 @@ import io.netty.channel.Channel; import org.apache.spark.network.server.ChunkFetchRequestHandler; +import org.junit.Assert; import org.junit.Test; import static org.mockito.Mockito.*; @@ -45,9 +46,8 @@ public void handleChunkFetchRequest() throws Exception { Channel channel = mock(Channel.class); ChannelHandlerContext context = mock(ChannelHandlerContext.class); when(context.channel()) - .thenAnswer(invocationOnMock0 -> { - return channel; - }); + .thenAnswer(invocationOnMock0 -> channel); + List> responseAndPromisePairs = new ArrayList<>(); when(channel.writeAndFlush(any())) @@ -62,6 +62,7 @@ public void handleChunkFetchRequest() throws Exception { List managedBuffers = new ArrayList<>(); managedBuffers.add(new TestManagedBuffer(10)); managedBuffers.add(new TestManagedBuffer(20)); + managedBuffers.add(null); managedBuffers.add(new TestManagedBuffer(30)); managedBuffers.add(new TestManagedBuffer(40)); long streamId = streamManager.registerStream("test-app", managedBuffers.iterator(), channel); @@ -71,31 +72,40 @@ public void handleChunkFetchRequest() throws Exception { RequestMessage request0 = new ChunkFetchRequest(new StreamChunkId(streamId, 0)); requestHandler.channelRead(context, request0); - assert responseAndPromisePairs.size() == 1; - assert responseAndPromisePairs.get(0).getLeft() instanceof ChunkFetchSuccess; - assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(0).getLeft())).body() == - managedBuffers.get(0); + 
Assert.assertEquals(1, responseAndPromisePairs.size()); + Assert.assertTrue(responseAndPromisePairs.get(0).getLeft() instanceof ChunkFetchSuccess); + Assert.assertEquals(managedBuffers.get(0), + ((ChunkFetchSuccess) (responseAndPromisePairs.get(0).getLeft())).body()); RequestMessage request1 = new ChunkFetchRequest(new StreamChunkId(streamId, 1)); requestHandler.channelRead(context, request1); - assert responseAndPromisePairs.size() == 2; - assert responseAndPromisePairs.get(1).getLeft() instanceof ChunkFetchSuccess; - assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(1).getLeft())).body() == - managedBuffers.get(1); + Assert.assertEquals(2, responseAndPromisePairs.size()); + Assert.assertTrue(responseAndPromisePairs.get(1).getLeft() instanceof ChunkFetchSuccess); + Assert.assertEquals(managedBuffers.get(1), + ((ChunkFetchSuccess) (responseAndPromisePairs.get(1).getLeft())).body()); // Finish flushing the response for request0. responseAndPromisePairs.get(0).getRight().finish(true); RequestMessage request2 = new ChunkFetchRequest(new StreamChunkId(streamId, 2)); requestHandler.channelRead(context, request2); - assert responseAndPromisePairs.size() == 3; - assert responseAndPromisePairs.get(2).getLeft() instanceof ChunkFetchSuccess; - assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(2).getLeft())).body() == - managedBuffers.get(2); + Assert.assertEquals(3, responseAndPromisePairs.size()); + Assert.assertTrue(responseAndPromisePairs.get(2).getLeft() instanceof ChunkFetchFailure); + ChunkFetchFailure chunkFetchFailure = + ((ChunkFetchFailure) (responseAndPromisePairs.get(2).getLeft())); + Assert.assertEquals("java.lang.IllegalStateException: Chunk was not found", + chunkFetchFailure.errorString.split("\\r?\\n")[0]); RequestMessage request3 = new ChunkFetchRequest(new StreamChunkId(streamId, 3)); requestHandler.channelRead(context, request3); + Assert.assertEquals(4, responseAndPromisePairs.size()); + Assert.assertTrue(responseAndPromisePairs.get(3).getLeft() instanceof ChunkFetchSuccess); + Assert.assertEquals(managedBuffers.get(3), + ((ChunkFetchSuccess) (responseAndPromisePairs.get(3).getLeft())).body()); + + RequestMessage request4 = new ChunkFetchRequest(new StreamChunkId(streamId, 4)); + requestHandler.channelRead(context, request4); verify(channel, times(1)).close(); - assert responseAndPromisePairs.size() == 3; + Assert.assertEquals(4, responseAndPromisePairs.size()); } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java index a87f6c11a2bf..a43a65904868 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java @@ -21,6 +21,7 @@ import java.util.List; import io.netty.channel.Channel; +import org.junit.Assert; import org.junit.Test; import static org.mockito.Mockito.*; @@ -38,7 +39,7 @@ public class TransportRequestHandlerSuite { @Test - public void handleStreamRequest() throws Exception { + public void handleStreamRequest() { RpcHandler rpcHandler = new NoOpRpcHandler(); OneForOneStreamManager streamManager = (OneForOneStreamManager) (rpcHandler.getStreamManager()); Channel channel = mock(Channel.class); @@ -56,11 +57,12 @@ public void handleStreamRequest() throws Exception { List managedBuffers = new ArrayList<>(); managedBuffers.add(new TestManagedBuffer(10)); 
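The two production changes at the top of this diff form one contract: ChunkFetchRequestHandler fails fast when getChunk() returns null, and OneForOneStreamManager tolerates null placeholders during cleanup. Here is a minimal stand-alone sketch of that contract, using a hypothetical Buffer stand-in for ManagedBuffer; the class and method names are illustrative only, not Spark code:

```java
import java.util.Arrays;
import java.util.Iterator;

class NullChunkGuards {
  interface Buffer { void release(); }

  // Mirrors the ChunkFetchRequestHandler change: a missing chunk fails fast
  // instead of letting a null buffer reach the response encoder.
  static Buffer getChunkOrThrow(Buffer chunk) {
    if (chunk == null) {
      throw new IllegalStateException("Chunk was not found");
    }
    return chunk;
  }

  // Mirrors OneForOneStreamManager.connectionTerminated: null entries in the
  // stream's buffer iterator are skipped while releasing leftovers.
  static void releaseRemaining(Iterator<Buffer> buffers) {
    while (buffers.hasNext()) {
      Buffer buffer = buffers.next();
      if (buffer != null) {
        buffer.release();
      }
    }
  }

  public static void main(String[] args) {
    Buffer kept = () -> System.out.println("released");
    releaseRemaining(Arrays.asList(kept, null, kept).iterator()); // prints twice
  }
}
```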
managedBuffers.add(new TestManagedBuffer(20)); + managedBuffers.add(null); managedBuffers.add(new TestManagedBuffer(30)); managedBuffers.add(new TestManagedBuffer(40)); long streamId = streamManager.registerStream("test-app", managedBuffers.iterator(), channel); - assert streamManager.numStreamStates() == 1; + Assert.assertEquals(1, streamManager.numStreamStates()); TransportClient reverseClient = mock(TransportClient.class); TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient, @@ -68,36 +70,43 @@ public void handleStreamRequest() throws Exception { RequestMessage request0 = new StreamRequest(String.format("%d_%d", streamId, 0)); requestHandler.handle(request0); - assert responseAndPromisePairs.size() == 1; - assert responseAndPromisePairs.get(0).getLeft() instanceof StreamResponse; - assert ((StreamResponse) (responseAndPromisePairs.get(0).getLeft())).body() == - managedBuffers.get(0); + Assert.assertEquals(1, responseAndPromisePairs.size()); + Assert.assertTrue(responseAndPromisePairs.get(0).getLeft() instanceof StreamResponse); + Assert.assertEquals(managedBuffers.get(0), + ((StreamResponse) (responseAndPromisePairs.get(0).getLeft())).body()); RequestMessage request1 = new StreamRequest(String.format("%d_%d", streamId, 1)); requestHandler.handle(request1); - assert responseAndPromisePairs.size() == 2; - assert responseAndPromisePairs.get(1).getLeft() instanceof StreamResponse; - assert ((StreamResponse) (responseAndPromisePairs.get(1).getLeft())).body() == - managedBuffers.get(1); + Assert.assertEquals(2, responseAndPromisePairs.size()); + Assert.assertTrue(responseAndPromisePairs.get(1).getLeft() instanceof StreamResponse); + Assert.assertEquals(managedBuffers.get(1), + ((StreamResponse) (responseAndPromisePairs.get(1).getLeft())).body()); // Finish flushing the response for request0. 
responseAndPromisePairs.get(0).getRight().finish(true); - RequestMessage request2 = new StreamRequest(String.format("%d_%d", streamId, 2)); + StreamRequest request2 = new StreamRequest(String.format("%d_%d", streamId, 2)); requestHandler.handle(request2); - assert responseAndPromisePairs.size() == 3; - assert responseAndPromisePairs.get(2).getLeft() instanceof StreamResponse; - assert ((StreamResponse) (responseAndPromisePairs.get(2).getLeft())).body() == - managedBuffers.get(2); + Assert.assertEquals(3, responseAndPromisePairs.size()); + Assert.assertTrue(responseAndPromisePairs.get(2).getLeft() instanceof StreamFailure); + Assert.assertEquals(String.format("Stream '%s' was not found.", request2.streamId), + ((StreamFailure) (responseAndPromisePairs.get(2).getLeft())).error); - // Request3 will trigger the close of channel, because the number of max chunks being - // transferred is 2; RequestMessage request3 = new StreamRequest(String.format("%d_%d", streamId, 3)); requestHandler.handle(request3); + Assert.assertEquals(4, responseAndPromisePairs.size()); + Assert.assertTrue(responseAndPromisePairs.get(3).getLeft() instanceof StreamResponse); + Assert.assertEquals(managedBuffers.get(3), + ((StreamResponse) (responseAndPromisePairs.get(3).getLeft())).body()); + + // Request4 will trigger the close of channel, because the number of max chunks being + // transferred is 2; + RequestMessage request4 = new StreamRequest(String.format("%d_%d", streamId, 4)); + requestHandler.handle(request4); verify(channel, times(1)).close(); - assert responseAndPromisePairs.size() == 3; + Assert.assertEquals(4, responseAndPromisePairs.size()); streamManager.connectionTerminated(channel); - assert streamManager.numStreamStates() == 0; + Assert.assertEquals(0, streamManager.numStreamStates()); } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java index 4248762c3238..fb3503b783e5 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java @@ -21,6 +21,8 @@ import java.util.List; import io.netty.channel.Channel; +import org.junit.After; +import org.junit.Assert; import org.junit.Test; import org.mockito.Mockito; @@ -29,23 +31,69 @@ public class OneForOneStreamManagerSuite { + List managedBuffersToRelease = new ArrayList<>(); + + @After + public void tearDown() { + managedBuffersToRelease.forEach(managedBuffer -> managedBuffer.release()); + managedBuffersToRelease.clear(); + } + + private ManagedBuffer getChunk(OneForOneStreamManager manager, long streamId, int chunkIndex) { + ManagedBuffer chunk = manager.getChunk(streamId, chunkIndex); + if (chunk != null) { + managedBuffersToRelease.add(chunk); + } + return chunk; + } + @Test - public void managedBuffersAreFeedWhenConnectionIsClosed() throws Exception { + public void testMissingChunk() { OneForOneStreamManager manager = new OneForOneStreamManager(); List buffers = new ArrayList<>(); TestManagedBuffer buffer1 = Mockito.spy(new TestManagedBuffer(10)); TestManagedBuffer buffer2 = Mockito.spy(new TestManagedBuffer(20)); + TestManagedBuffer buffer3 = Mockito.spy(new TestManagedBuffer(20)); + buffers.add(buffer1); + // the nulls here are to simulate a file which goes missing before being read, + // just as a defensive measure + buffers.add(null); 
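The test assembled here leans on an ownership rule that is easy to miss: a buffer already handed out by getChunk() is released by the transport layer after the response is written, so connectionTerminated() must release only what is still queued, skipping nulls. A toy reference-count model of that rule, with hypothetical names and purely for illustration:

```java
import java.util.ArrayList;
import java.util.List;

class OwnershipDemo {
  static class RefCounted {
    int refs = 1;
    void release() { refs--; }
  }

  public static void main(String[] args) {
    List<RefCounted> queued = new ArrayList<>();
    RefCounted handedOut = new RefCounted(); // returned by getChunk(): the transport releases it later
    RefCounted neverSent = new RefCounted(); // still queued when the connection dies
    queued.add(neverSent);
    queued.add(null);                        // placeholder for a missing chunk

    // connectionTerminated(): release only what is still queued, skip nulls.
    for (RefCounted b : queued) {
      if (b != null) {
        b.release();
      }
    }
    System.out.println(handedOut.refs); // 1: released later, after the network write
    System.out.println(neverSent.refs); // 0: released exactly once by cleanup
  }
}
```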
buffers.add(buffer2); + buffers.add(null); + buffers.add(buffer3); Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS); - manager.registerStream("appId", buffers.iterator(), dummyChannel); - assert manager.numStreamStates() == 1; + long streamId = manager.registerStream("appId", buffers.iterator(), dummyChannel); + Assert.assertEquals(1, manager.numStreamStates()); + Assert.assertNotNull(getChunk(manager, streamId, 0)); + Assert.assertNull(getChunk(manager, streamId, 1)); + Assert.assertNotNull(getChunk(manager, streamId, 2)); + manager.connectionTerminated(dummyChannel); + + // loaded buffers are not released yet as in production a MangedBuffer returned by getChunk() + // would only be released by Netty after it is written to the network + Mockito.verify(buffer1, Mockito.never()).release(); + Mockito.verify(buffer2, Mockito.never()).release(); + Mockito.verify(buffer3, Mockito.times(1)).release(); + } + @Test + public void managedBuffersAreFreedWhenConnectionIsClosed() { + OneForOneStreamManager manager = new OneForOneStreamManager(); + List buffers = new ArrayList<>(); + TestManagedBuffer buffer1 = Mockito.spy(new TestManagedBuffer(10)); + TestManagedBuffer buffer2 = Mockito.spy(new TestManagedBuffer(20)); + buffers.add(buffer1); + buffers.add(buffer2); + + Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS); + manager.registerStream("appId", buffers.iterator(), dummyChannel); + Assert.assertEquals(1, manager.numStreamStates()); manager.connectionTerminated(dummyChannel); Mockito.verify(buffer1, Mockito.times(1)).release(); Mockito.verify(buffer2, Mockito.times(1)).release(); - assert manager.numStreamStates() == 0; + Assert.assertEquals(0, manager.numStreamStates()); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/Constants.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/Constants.java new file mode 100644 index 000000000000..01aca7efb12b --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/Constants.java @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle; + +public class Constants { + + public static final String SHUFFLE_SERVICE_FETCH_RDD_ENABLED = + "spark.shuffle.service.fetch.rdd.enabled"; +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 70dcc8b8b8b6..ba9d657d0e56 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -23,6 +23,7 @@ import java.util.HashMap; import java.util.Iterator; import java.util.Map; +import java.util.function.Function; import com.codahale.metrics.Gauge; import com.codahale.metrics.Meter; @@ -48,9 +49,9 @@ /** * RPC Handler for a server which can serve shuffle blocks from outside of an Executor process. * - * Handles registering executors and opening shuffle blocks from them. Shuffle blocks are registered - * with the "one-for-one" strategy, meaning each Transport-layer Chunk is equivalent to one Spark- - * level shuffle block. + * Handles registering executors and opening shuffle or disk persisted RDD blocks from them. + * Blocks are registered with the "one-for-one" strategy, meaning each Transport-layer Chunk + * is equivalent to one block. */ public class ExternalShuffleBlockHandler extends RpcHandler { private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockHandler.class); @@ -99,11 +100,12 @@ protected void handleMessage( long streamId = streamManager.registerStream(client.getClientId(), new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds), client.getChannel()); if (logger.isTraceEnabled()) { - logger.trace("Registered streamId {} with {} buffers for client {} from host {}", - streamId, - msg.blockIds.length, - client.getClientId(), - getRemoteAddress(client.getChannel())); + logger.trace( + "Registered streamId {} with {} buffers for client {} from host {}", + streamId, + msg.blockIds.length, + client.getClientId(), + getRemoteAddress(client.getChannel())); } callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer()); } finally { @@ -122,6 +124,12 @@ protected void handleMessage( responseDelayContext.stop(); } + } else if (msgObj instanceof RemoveBlocks) { + RemoveBlocks msg = (RemoveBlocks) msgObj; + checkAuth(client, msg.appId); + int numRemovedBlocks = blockManager.removeBlocks(msg.appId, msg.execId, msg.blockIds); + callback.onSuccess(new BlocksRemoved(numRemovedBlocks).toByteBuffer()); + } else { throw new UnsupportedOperationException("Unexpected message: " + msgObj); } @@ -213,21 +221,42 @@ public Map getMetrics() { private class ManagedBufferIterator implements Iterator { private int index = 0; - private final String appId; - private final String execId; - private final int shuffleId; - // An array containing mapId and reduceId pairs. 
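The ManagedBufferIterator refactor in the hunk below dispatches on the format of the first block id. As a hedged restatement of the two accepted layouts, shuffle_&lt;shuffleId&gt;_&lt;mapId&gt;_&lt;reduceId&gt; with four parts and rdd_&lt;rddId&gt;_&lt;splitIndex&gt; with three, here is a compact sketch mirroring that parsing; the class and method names are hypothetical, not the production code:

```java
class BlockIdFormats {
  static boolean isShuffleBlock(String blockId) {
    String[] p = blockId.split("_");
    return p.length == 4 && p[0].equals("shuffle");
  }

  static boolean isRddBlock(String blockId) {
    String[] p = blockId.split("_");
    return p.length == 3 && p[0].equals("rdd");
  }

  public static void main(String[] args) {
    System.out.println(isShuffleBlock("shuffle_0_0_1")); // true
    System.out.println(isRddBlock("rdd_12_34"));         // true
    System.out.println(isRddBlock("broadcast_1"));       // false: IllegalArgumentException path
  }
}
```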
- private final int[] mapIdAndReduceIds; - - ManagedBufferIterator(String appId, String execId, String[] blockIds) { - this.appId = appId; - this.execId = execId; + private final Function blockDataForIndexFn; + private final int size; + + ManagedBufferIterator(final String appId, final String execId, String[] blockIds) { String[] blockId0Parts = blockIds[0].split("_"); - if (blockId0Parts.length != 4 || !blockId0Parts[0].equals("shuffle")) { - throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[0]); + if (blockId0Parts.length == 4 && blockId0Parts[0].equals("shuffle")) { + final int shuffleId = Integer.parseInt(blockId0Parts[1]); + final int[] mapIdAndReduceIds = shuffleMapIdAndReduceIds(blockIds, shuffleId); + size = mapIdAndReduceIds.length; + blockDataForIndexFn = index -> blockManager.getBlockData(appId, execId, shuffleId, + mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]); + } else if (blockId0Parts.length == 3 && blockId0Parts[0].equals("rdd")) { + final int[] rddAndSplitIds = rddAndSplitIds(blockIds); + size = rddAndSplitIds.length; + blockDataForIndexFn = index -> blockManager.getRddBlockData(appId, execId, + rddAndSplitIds[index], rddAndSplitIds[index + 1]); + } else { + throw new IllegalArgumentException("Unexpected block id format: " + blockIds[0]); + } + } + + private int[] rddAndSplitIds(String[] blockIds) { + final int[] rddAndSplitIds = new int[2 * blockIds.length]; + for (int i = 0; i < blockIds.length; i++) { + String[] blockIdParts = blockIds[i].split("_"); + if (blockIdParts.length != 3 || !blockIdParts[0].equals("rdd")) { + throw new IllegalArgumentException("Unexpected RDD block id format: " + blockIds[i]); + } + rddAndSplitIds[2 * i] = Integer.parseInt(blockIdParts[1]); + rddAndSplitIds[2 * i + 1] = Integer.parseInt(blockIdParts[2]); } - this.shuffleId = Integer.parseInt(blockId0Parts[1]); - mapIdAndReduceIds = new int[2 * blockIds.length]; + return rddAndSplitIds; + } + + private int[] shuffleMapIdAndReduceIds(String[] blockIds, int shuffleId) { + final int[] mapIdAndReduceIds = new int[2 * blockIds.length]; for (int i = 0; i < blockIds.length; i++) { String[] blockIdParts = blockIds[i].split("_"); if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) { @@ -240,17 +269,17 @@ private class ManagedBufferIterator implements Iterator { mapIdAndReduceIds[2 * i] = Integer.parseInt(blockIdParts[2]); mapIdAndReduceIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]); } + return mapIdAndReduceIds; } @Override public boolean hasNext() { - return index < mapIdAndReduceIds.length; + return index < size; } @Override public ManagedBuffer next() { - final ManagedBuffer block = blockManager.getBlockData(appId, execId, shuffleId, - mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]); + final ManagedBuffer block = blockDataForIndexFn.apply(index); index += 2; metrics.blockTransferRateBytes.mark(block != null ? 
block.size() : 0); return block; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 0b7a27402369..6b6ca9243b62 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -86,6 +86,8 @@ public class ExternalShuffleBlockResolver { private final TransportConf conf; + private final boolean rddFetchEnabled; + @VisibleForTesting final File registeredExecutorFile; @VisibleForTesting @@ -109,6 +111,8 @@ public ExternalShuffleBlockResolver(TransportConf conf, File registeredExecutorF File registeredExecutorFile, Executor directoryCleaner) throws IOException { this.conf = conf; + this.rddFetchEnabled = + Boolean.valueOf(conf.get(Constants.SHUFFLE_SERVICE_FETCH_RDD_ENABLED, "true")); this.registeredExecutorFile = registeredExecutorFile; String indexCacheSize = conf.get("spark.shuffle.service.index.cache.size", "100m"); CacheLoader indexCacheLoader = @@ -179,6 +183,18 @@ public ManagedBuffer getBlockData( return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId); } + public ManagedBuffer getRddBlockData( + String appId, + String execId, + int rddId, + int splitIndex) { + ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); + if (executor == null) { + throw new RuntimeException( + String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); + } + return getDiskPersistedRddBlockData(executor, rddId, splitIndex); + } /** * Removes our metadata of all executors registered for the given application, and optionally * also deletes the local directories associated with the executors of that application in a @@ -217,22 +233,23 @@ public void applicationRemoved(String appId, boolean cleanupLocalDirs) { } /** - * Removes all the non-shuffle files in any local directories associated with the finished - * executor. + * Removes all the files which cannot be served by the external shuffle service (non-shuffle and + * non-RDD files) in any local directories associated with the finished executor. */ public void executorRemoved(String executorId, String appId) { - logger.info("Clean up non-shuffle files associated with the finished executor {}", executorId); + logger.info("Clean up non-shuffle and non-RDD files associated with the finished executor {}", + executorId); AppExecId fullId = new AppExecId(appId, executorId); final ExecutorShuffleInfo executor = executors.get(fullId); if (executor == null) { // Executor not registered, skip clean up of the local directories. logger.info("Executor is not registered (appId={}, execId={})", appId, executorId); } else { - logger.info("Cleaning up non-shuffle files in executor {}'s {} local dirs", fullId, - executor.localDirs.length); + logger.info("Cleaning up non-shuffle and non-RDD files in executor {}'s {} local dirs", + fullId, executor.localDirs.length); // Execute the actual deletion in a different thread, as it may take some time. - directoryCleaner.execute(() -> deleteNonShuffleFiles(executor.localDirs)); + directoryCleaner.execute(() -> deleteNonShuffleServiceServedFiles(executor.localDirs)); } } @@ -252,24 +269,24 @@ private void deleteExecutorDirs(String[] dirs) { } /** - * Synchronously deletes non-shuffle files in each directory recursively. 
+ * Synchronously deletes files not served by shuffle service in each directory recursively. * Should be executed in its own thread, as this may take a long time. */ - private void deleteNonShuffleFiles(String[] dirs) { - FilenameFilter filter = new FilenameFilter() { - @Override - public boolean accept(File dir, String name) { - // Don't delete shuffle data or shuffle index files. - return !name.endsWith(".index") && !name.endsWith(".data"); - } + private void deleteNonShuffleServiceServedFiles(String[] dirs) { + FilenameFilter filter = (dir, name) -> { + // Don't delete shuffle data, shuffle index files or cached RDD files. + return !name.endsWith(".index") && !name.endsWith(".data") + && (!rddFetchEnabled || !name.startsWith("rdd_")); }; for (String localDir : dirs) { try { JavaUtils.deleteRecursively(new File(localDir), filter); - logger.debug("Successfully cleaned up non-shuffle files in directory: {}", localDir); + logger.debug("Successfully cleaned up files not served by shuffle service in directory: {}", + localDir); } catch (Exception e) { - logger.error("Failed to delete non-shuffle files in directory: " + localDir, e); + logger.error("Failed to delete files not served by shuffle service in directory: " + + localDir, e); } } } @@ -298,6 +315,18 @@ private ManagedBuffer getSortBasedShuffleBlockData( } } + public ManagedBuffer getDiskPersistedRddBlockData( + ExecutorShuffleInfo executor, int rddId, int splitIndex) { + File file = getFile(executor.localDirs, executor.subDirsPerLocalDir, + "rdd_" + rddId + "_" + splitIndex); + long fileLength = file.length(); + ManagedBuffer res = null; + if (file.exists()) { + res = new FileSegmentManagedBuffer(conf, file, 0, fileLength); + } + return res; + } + /** * Hashes a filename into the corresponding local directory, in a manner consistent with * Spark's DiskBlockManager.getFile(). @@ -343,6 +372,24 @@ static String createNormalizedInternedPathname(String dir1, String dir2, String return pathname.intern(); } + public int removeBlocks(String appId, String execId, String[] blockIds) { + ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); + if (executor == null) { + throw new RuntimeException( + String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); + } + int numRemovedBlocks = 0; + for (String blockId : blockIds) { + File file = getFile(executor.localDirs, executor.subDirsPerLocalDir, blockId); + if (file.delete()) { + numRemovedBlocks++; + } else { + logger.warn("Failed to delete block: " + file.getAbsolutePath()); + } + } + return numRemovedBlocks; + } + /** Simply encodes an executor's full ID, which is appId + execId. 
*/ public static class AppExecId { public final String appId; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index e49e27ab5aa7..0e11d2124ada 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -19,10 +19,15 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.util.Arrays; import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.Future; import com.codahale.metrics.MetricSet; import com.google.common.collect.Lists; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.shuffle.protocol.*; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -33,8 +38,6 @@ import org.apache.spark.network.crypto.AuthClientBootstrap; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.NoOpRpcHandler; -import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; -import org.apache.spark.network.shuffle.protocol.RegisterExecutor; import org.apache.spark.network.util.TransportConf; /** @@ -73,7 +76,10 @@ protected void checkInit() { assert appId != null : "Called before init()"; } - @Override + /** + * Initializes the ShuffleClient, specifying this Executor's appId. + * Must be called before any other method on the ShuffleClient. + */ public void init(String appId) { this.appId = appId; TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true, true); @@ -139,12 +145,40 @@ public void registerWithShuffleServer( String execId, ExecutorShuffleInfo executorInfo) throws IOException, InterruptedException { checkInit(); - try (TransportClient client = clientFactory.createUnmanagedClient(host, port)) { + try (TransportClient client = clientFactory.createClient(host, port)) { ByteBuffer registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteBuffer(); client.sendRpcSync(registerMessage, registrationTimeoutMs); } } + public Future removeBlocks( + String host, + int port, + String execId, + String[] blockIds) throws IOException, InterruptedException { + checkInit(); + CompletableFuture numRemovedBlocksFuture = new CompletableFuture<>(); + ByteBuffer removeBlocksMessage = new RemoveBlocks(appId, execId, blockIds).toByteBuffer(); + final TransportClient client = clientFactory.createClient(host, port); + client.sendRpc(removeBlocksMessage, new RpcResponseCallback() { + @Override + public void onSuccess(ByteBuffer response) { + BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(response); + numRemovedBlocksFuture.complete(((BlocksRemoved)msgObj).numRemovedBlocks); + client.close(); + } + + @Override + public void onFailure(Throwable e) { + logger.warn("Error trying to remove RDD blocks " + Arrays.toString(blockIds) + + " via external shuffle service from executor: " + execId, e); + numRemovedBlocksFuture.complete(0); + client.close(); + } + }); + return numRemovedBlocksFuture; + } + @Override public void close() { checkInit(); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index 62b99c40f61f..0be5cf5ad922 100644 --- 
a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -25,12 +25,6 @@ /** Provides an interface for reading shuffle files, either from an Executor or external service. */ public abstract class ShuffleClient implements Closeable { - /** - * Initializes the ShuffleClient, specifying this Executor's appId. - * Must be called before any other method on the ShuffleClient. - */ - public void init(String appId) { } - /** * Fetch a sequence of blocks from a remote node asynchronously, * diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index a68a297519b6..e09775644963 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -35,6 +35,7 @@ * shuffle service. It returns a StreamHandle. * - UploadBlock is only handled by the NettyBlockTransferService. * - RegisterExecutor is only handled by the external shuffle service. + * - RemoveBlocks is only handled by the external shuffle service. */ public abstract class BlockTransferMessage implements Encodable { protected abstract Type type(); @@ -42,7 +43,7 @@ public abstract class BlockTransferMessage implements Encodable { /** Preceding every serialized message is its type, which allows us to deserialize it. */ public enum Type { OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4), - HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6); + HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), REMOVE_BLOCKS(7), BLOCKS_REMOVED(8); private final byte id; @@ -68,6 +69,8 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { case 4: return RegisterDriver.decode(buf); case 5: return ShuffleServiceHeartbeat.decode(buf); case 6: return UploadBlockStream.decode(buf); + case 7: return RemoveBlocks.decode(buf); + case 8: return BlocksRemoved.decode(buf); default: throw new IllegalArgumentException("Unknown message type: " + type); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlocksRemoved.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlocksRemoved.java new file mode 100644 index 000000000000..3f04443871b6 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlocksRemoved.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + +/** The reply to remove blocks giving back the number of removed blocks. */ +public class BlocksRemoved extends BlockTransferMessage { + public final int numRemovedBlocks; + + public BlocksRemoved(int numRemovedBlocks) { + this.numRemovedBlocks = numRemovedBlocks; + } + + @Override + protected Type type() { return Type.BLOCKS_REMOVED; } + + @Override + public int hashCode() { + return Objects.hashCode(numRemovedBlocks); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("numRemovedBlocks", numRemovedBlocks) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof BlocksRemoved) { + BlocksRemoved o = (BlocksRemoved) other; + return Objects.equal(numRemovedBlocks, o.numRemovedBlocks); + } + return false; + } + + @Override + public int encodedLength() { + return 4; + } + + @Override + public void encode(ByteBuf buf) { + buf.writeInt(numRemovedBlocks); + } + + public static BlocksRemoved decode(ByteBuf buf) { + int numRemovedBlocks = buf.readInt(); + return new BlocksRemoved(numRemovedBlocks); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RemoveBlocks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RemoveBlocks.java new file mode 100644 index 000000000000..1c718d307753 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RemoveBlocks.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import org.apache.spark.network.protocol.Encoders; + +import java.util.Arrays; + +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + +/** Request to remove a set of blocks. 
*/ +public class RemoveBlocks extends BlockTransferMessage { + public final String appId; + public final String execId; + public final String[] blockIds; + + public RemoveBlocks(String appId, String execId, String[] blockIds) { + this.appId = appId; + this.execId = execId; + this.blockIds = blockIds; + } + + @Override + protected Type type() { return Type.REMOVE_BLOCKS; } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId) * 41 + Arrays.hashCode(blockIds); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("blockIds", Arrays.toString(blockIds)) + .toString(); + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof RemoveBlocks) { + RemoveBlocks o = (RemoveBlocks) other; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Arrays.equals(blockIds, o.blockIds); + } + return false; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + Encoders.StringArrays.encodedLength(blockIds); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + Encoders.StringArrays.encode(buf, blockIds); + } + + public static RemoveBlocks decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + String[] blockIds = Encoders.StringArrays.decode(buf); + return new RemoveBlocks(appId, execId, blockIds); + } +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/CleanupNonShuffleServiceServedFilesSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/CleanupNonShuffleServiceServedFilesSuite.java new file mode 100644 index 000000000000..e38442327e22 --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/CleanupNonShuffleServiceServedFilesSuite.java @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle; + +import java.io.File; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.*; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.util.concurrent.MoreExecutors; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import org.apache.spark.network.util.MapConfigProvider; +import org.apache.spark.network.util.TransportConf; + +public class CleanupNonShuffleServiceServedFilesSuite { + + // Same-thread Executor used to ensure cleanup happens synchronously in test thread. + private Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); + + private static final String SORT_MANAGER = "org.apache.spark.shuffle.sort.SortShuffleManager"; + + private static Set expectedShuffleFilesToKeep = + ImmutableSet.of("shuffle_782_450_0.index", "shuffle_782_450_0.data"); + + private static Set expectedShuffleAndRddFilesToKeep = + ImmutableSet.of("shuffle_782_450_0.index", "shuffle_782_450_0.data", "rdd_12_34"); + + private TransportConf getConf(boolean isFetchRddEnabled) { + return new TransportConf( + "shuffle", + new MapConfigProvider(ImmutableMap.of( + Constants.SHUFFLE_SERVICE_FETCH_RDD_ENABLED, + Boolean.toString(isFetchRddEnabled)))); + } + + @Test + public void cleanupOnRemovedExecutorWithFilesToKeepFetchRddEnabled() throws IOException { + cleanupOnRemovedExecutor(true, getConf(true), expectedShuffleAndRddFilesToKeep); + } + + @Test + public void cleanupOnRemovedExecutorWithFilesToKeepFetchRddDisabled() throws IOException { + cleanupOnRemovedExecutor(true, getConf(false), expectedShuffleFilesToKeep); + } + + @Test + public void cleanupOnRemovedExecutorWithoutFilesToKeep() throws IOException { + cleanupOnRemovedExecutor(false, getConf(true), Collections.emptySet()); + } + + private void cleanupOnRemovedExecutor( + boolean withFilesToKeep, + TransportConf conf, + Set expectedFilesKept) throws IOException { + TestShuffleDataContext dataContext = initDataContext(withFilesToKeep); + + ExternalShuffleBlockResolver resolver = + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); + resolver.registerExecutor("app", "exec0", dataContext.createExecutorInfo(SORT_MANAGER)); + resolver.executorRemoved("exec0", "app"); + + assertContainedFilenames(dataContext, expectedFilesKept); + } + + @Test + public void cleanupUsesExecutorWithFilesToKeep() throws IOException { + cleanupUsesExecutor(true); + } + + @Test + public void cleanupUsesExecutorWithoutFilesToKeep() throws IOException { + cleanupUsesExecutor(false); + } + + private void cleanupUsesExecutor(boolean withFilesToKeep) throws IOException { + TestShuffleDataContext dataContext = initDataContext(withFilesToKeep); + + AtomicBoolean cleanupCalled = new AtomicBoolean(false); + + // Executor which only captures whether it's being used, without executing anything. 
+ Executor dummyExecutor = runnable -> cleanupCalled.set(true); + + ExternalShuffleBlockResolver manager = + new ExternalShuffleBlockResolver(getConf(true), null, dummyExecutor); + + manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo(SORT_MANAGER)); + manager.executorRemoved("exec0", "app"); + + assertTrue(cleanupCalled.get()); + assertStillThere(dataContext); + } + + @Test + public void cleanupOnlyRemovedExecutorWithFilesToKeepFetchRddEnabled() throws IOException { + cleanupOnlyRemovedExecutor(true, getConf(true), expectedShuffleAndRddFilesToKeep); + } + + @Test + public void cleanupOnlyRemovedExecutorWithFilesToKeepFetchRddDisabled() throws IOException { + cleanupOnlyRemovedExecutor(true, getConf(false), expectedShuffleFilesToKeep); + } + + @Test + public void cleanupOnlyRemovedExecutorWithoutFilesToKeep() throws IOException { + cleanupOnlyRemovedExecutor(false, getConf(true) , Collections.emptySet()); + } + + private void cleanupOnlyRemovedExecutor( + boolean withFilesToKeep, + TransportConf conf, + Set expectedFilesKept) throws IOException { + TestShuffleDataContext dataContext0 = initDataContext(withFilesToKeep); + TestShuffleDataContext dataContext1 = initDataContext(withFilesToKeep); + + ExternalShuffleBlockResolver resolver = + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); + resolver.registerExecutor("app", "exec0", dataContext0.createExecutorInfo(SORT_MANAGER)); + resolver.registerExecutor("app", "exec1", dataContext1.createExecutorInfo(SORT_MANAGER)); + + + resolver.executorRemoved("exec-nonexistent", "app"); + assertStillThere(dataContext0); + assertStillThere(dataContext1); + + resolver.executorRemoved("exec0", "app"); + assertContainedFilenames(dataContext0, expectedFilesKept); + assertStillThere(dataContext1); + + resolver.executorRemoved("exec1", "app"); + assertContainedFilenames(dataContext0, expectedFilesKept); + assertContainedFilenames(dataContext1, expectedFilesKept); + + // Make sure it's not an error to cleanup multiple times + resolver.executorRemoved("exec1", "app"); + assertContainedFilenames(dataContext0, expectedFilesKept); + assertContainedFilenames(dataContext1, expectedFilesKept); + } + + @Test + public void cleanupOnlyRegisteredExecutorWithFilesToKeepFetchRddEnabled() throws IOException { + cleanupOnlyRegisteredExecutor(true, getConf(true), expectedShuffleAndRddFilesToKeep); + } + + @Test + public void cleanupOnlyRegisteredExecutorWithFilesToKeepFetchRddDisabled() throws IOException { + cleanupOnlyRegisteredExecutor(true, getConf(false), expectedShuffleFilesToKeep); + } + + @Test + public void cleanupOnlyRegisteredExecutorWithoutFilesToKeep() throws IOException { + cleanupOnlyRegisteredExecutor(false, getConf(true), Collections.emptySet()); + } + + private void cleanupOnlyRegisteredExecutor( + boolean withFilesToKeep, + TransportConf conf, + Set expectedFilesKept) throws IOException { + TestShuffleDataContext dataContext = initDataContext(withFilesToKeep); + + ExternalShuffleBlockResolver resolver = + new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); + resolver.registerExecutor("app", "exec0", dataContext.createExecutorInfo(SORT_MANAGER)); + + resolver.executorRemoved("exec1", "app"); + assertStillThere(dataContext); + + resolver.executorRemoved("exec0", "app"); + assertContainedFilenames(dataContext, expectedFilesKept); + } + + private static void assertStillThere(TestShuffleDataContext dataContext) { + for (String localDir : dataContext.localDirs) { + assertTrue(localDir + " was cleaned up 
prematurely", new File(localDir).exists()); + } + } + + private static Set collectFilenames(File[] files) throws IOException { + Set result = new HashSet<>(); + for (File file : files) { + if (file.exists()) { + try (Stream walk = Files.walk(file.toPath())) { + result.addAll(walk + .filter(Files::isRegularFile) + .map(x -> x.toFile().getName()) + .collect(Collectors.toSet())); + } + } + } + return result; + } + + private static void assertContainedFilenames( + TestShuffleDataContext dataContext, + Set expectedFilenames) throws IOException { + Set collectedFilenames = new HashSet<>(); + for (String localDir : dataContext.localDirs) { + File[] dirs = new File[] { new File(localDir) }; + collectedFilenames.addAll(collectFilenames(dirs)); + } + assertEquals(expectedFilenames, collectedFilenames); + } + + private static TestShuffleDataContext initDataContext(boolean withFilesToKeep) + throws IOException { + TestShuffleDataContext dataContext = new TestShuffleDataContext(10, 5); + dataContext.create(); + if (withFilesToKeep) { + createFilesToKeep(dataContext); + } else { + createRemovableTestFiles(dataContext); + } + return dataContext; + } + + private static void createFilesToKeep(TestShuffleDataContext dataContext) throws IOException { + Random rand = new Random(123); + dataContext.insertSortShuffleData(rand.nextInt(1000), rand.nextInt(1000), new byte[][] { + "ABC".getBytes(StandardCharsets.UTF_8), + "DEF".getBytes(StandardCharsets.UTF_8)}); + dataContext.insertCachedRddData(12, 34, new byte[] { 42 }); + } + + private static void createRemovableTestFiles(TestShuffleDataContext dataContext) + throws IOException { + dataContext.insertSpillData(); + dataContext.insertBroadcastData(); + dataContext.insertTempShuffleData(); + } +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index 537c277cd26b..d51e14a66faf 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -49,6 +49,10 @@ public class ExternalShuffleBlockHandlerSuite { OneForOneStreamManager streamManager; ExternalShuffleBlockResolver blockResolver; RpcHandler handler; + ManagedBuffer[] blockMarkers = { + new NioManagedBuffer(ByteBuffer.wrap(new byte[3])), + new NioManagedBuffer(ByteBuffer.wrap(new byte[7])) + }; @Before public void beforeEach() { @@ -76,20 +80,52 @@ public void testRegisterExecutor() { assertEquals(1, registerExecutorRequestLatencyMillis.getCount()); } - @SuppressWarnings("unchecked") @Test public void testOpenShuffleBlocks() { + when(blockResolver.getBlockData("app0", "exec1", 0, 0, 0)).thenReturn(blockMarkers[0]); + when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1)).thenReturn(blockMarkers[1]); + + checkOpenBlocksReceive(new String[] { "shuffle_0_0_0", "shuffle_0_0_1" }, blockMarkers); + + verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 1); + verifyOpenBlockLatencyMetrics(); + } + + @Test + public void testOpenDiskPersistedRDDBlocks() { + when(blockResolver.getRddBlockData("app0", "exec1", 0, 0)).thenReturn(blockMarkers[0]); + when(blockResolver.getRddBlockData("app0", "exec1", 0, 1)).thenReturn(blockMarkers[1]); + + checkOpenBlocksReceive(new String[] { 
"rdd_0_0", "rdd_0_1" }, blockMarkers); + + verify(blockResolver, times(1)).getRddBlockData("app0", "exec1", 0, 0); + verify(blockResolver, times(1)).getRddBlockData("app0", "exec1", 0, 1); + verifyOpenBlockLatencyMetrics(); + } + + @Test + public void testOpenDiskPersistedRDDBlocksWithMissingBlock() { + ManagedBuffer[] blockMarkersWithMissingBlock = { + new NioManagedBuffer(ByteBuffer.wrap(new byte[3])), + null + }; + when(blockResolver.getRddBlockData("app0", "exec1", 0, 0)) + .thenReturn(blockMarkersWithMissingBlock[0]); + when(blockResolver.getRddBlockData("app0", "exec1", 0, 1)) + .thenReturn(null); + + checkOpenBlocksReceive(new String[] { "rdd_0_0", "rdd_0_1" }, blockMarkersWithMissingBlock); + + verify(blockResolver, times(1)).getRddBlockData("app0", "exec1", 0, 0); + verify(blockResolver, times(1)).getRddBlockData("app0", "exec1", 0, 1); + } + + private void checkOpenBlocksReceive(String[] blockIds, ManagedBuffer[] blockMarkers) { when(client.getClientId()).thenReturn("app0"); RpcResponseCallback callback = mock(RpcResponseCallback.class); - - ManagedBuffer block0Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[3])); - ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); - when(blockResolver.getBlockData("app0", "exec1", 0, 0, 0)).thenReturn(block0Marker); - when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1)).thenReturn(block1Marker); - ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", - new String[] { "shuffle_0_0_0", "shuffle_0_0_1" }) - .toByteBuffer(); + ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", blockIds).toByteBuffer(); handler.receive(client, openBlocks, callback); ArgumentCaptor response = ArgumentCaptor.forClass(ByteBuffer.class); @@ -106,13 +142,12 @@ public void testOpenShuffleBlocks() { verify(streamManager, times(1)).registerStream(anyString(), stream.capture(), any()); Iterator buffers = stream.getValue(); - assertEquals(block0Marker, buffers.next()); - assertEquals(block1Marker, buffers.next()); + assertEquals(blockMarkers[0], buffers.next()); + assertEquals(blockMarkers[1], buffers.next()); assertFalse(buffers.hasNext()); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 1); + } - // Verify open block request latency metrics + private void verifyOpenBlockLatencyMetrics() { Timer openBlockRequestLatencyMillis = (Timer) ((ExternalShuffleBlockHandler) handler) .getAllMetrics() .getMetrics() diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index f5b1ec9d46da..55eac27da6b0 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -17,20 +17,25 @@ package org.apache.spark.network.shuffle; +import java.io.File; import java.io.IOException; import java.nio.ByteBuffer; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.HashSet; import java.util.LinkedList; import java.util.List; import java.util.Random; import java.util.Set; +import java.util.concurrent.Future; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Sets; 
+import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.server.OneForOneStreamManager; import org.junit.After; import org.junit.AfterClass; import org.junit.BeforeClass; @@ -52,6 +57,13 @@ public class ExternalShuffleIntegrationSuite { private static final String APP_ID = "app-id"; private static final String SORT_MANAGER = "org.apache.spark.shuffle.sort.SortShuffleManager"; + private static final int RDD_ID = 1; + private static final int SPLIT_INDEX_VALID_BLOCK = 0; + private static final int SPLIT_INDEX_MISSING_FILE = 1; + private static final int SPLIT_INDEX_CORRUPT_LENGTH = 2; + private static final int SPLIT_INDEX_VALID_BLOCK_TO_RM = 3; + private static final int SPLIT_INDEX_MISSING_BLOCK_TO_RM = 4; + // Executor 0 is sort-based static TestShuffleDataContext dataContext0; @@ -60,6 +72,9 @@ public class ExternalShuffleIntegrationSuite { static TransportConf conf; static TransportContext transportContext; + static byte[] exec0RddBlockValid = new byte[123]; + static byte[] exec0RddBlockToRemove = new byte[124]; + static byte[][] exec0Blocks = new byte[][] { new byte[123], new byte[12345], @@ -81,13 +96,38 @@ public static void beforeAll() throws IOException { for (byte[] block: exec1Blocks) { rand.nextBytes(block); } + rand.nextBytes(exec0RddBlockValid); + rand.nextBytes(exec0RddBlockToRemove); dataContext0 = new TestShuffleDataContext(2, 5); dataContext0.create(); dataContext0.insertSortShuffleData(0, 0, exec0Blocks); - - conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); - handler = new ExternalShuffleBlockHandler(conf, null); + dataContext0.insertCachedRddData(RDD_ID, SPLIT_INDEX_VALID_BLOCK, exec0RddBlockValid); + dataContext0.insertCachedRddData(RDD_ID, SPLIT_INDEX_VALID_BLOCK_TO_RM, exec0RddBlockToRemove); + + HashMap config = new HashMap<>(); + config.put("spark.shuffle.io.maxRetries", "0"); + conf = new TransportConf("shuffle", new MapConfigProvider(config)); + handler = new ExternalShuffleBlockHandler( + new OneForOneStreamManager(), + new ExternalShuffleBlockResolver(conf, null) { + @Override + public ManagedBuffer getRddBlockData(String appId, String execId, int rddId, int splitIdx) { + ManagedBuffer res; + if (rddId == RDD_ID) { + switch (splitIdx) { + case SPLIT_INDEX_CORRUPT_LENGTH: + res = new FileSegmentManagedBuffer(conf, new File("missing.file"), 0, 12); + break; + default: + res = super.getRddBlockData(appId, execId, rddId, splitIdx); + } + } else { + res = super.getRddBlockData(appId, execId, rddId, splitIdx); + } + return res; + } + }); transportContext = new TransportContext(conf, handler); server = transportContext.createServer(); } @@ -199,9 +239,55 @@ public void testRegisterInvalidExecutor() throws Exception { @Test public void testFetchWrongBlockId() throws Exception { registerExecutor("exec-1", dataContext0.createExecutorInfo(SORT_MANAGER)); - FetchResult execFetch = fetchBlocks("exec-1", new String[] { "rdd_1_0_0" }); + FetchResult execFetch = fetchBlocks("exec-1", new String[] { "broadcast_1" }); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet("broadcast_1"), execFetch.failedBlocks); + } + + @Test + public void testFetchValidRddBlock() throws Exception { + registerExecutor("exec-1", dataContext0.createExecutorInfo(SORT_MANAGER)); + String validBlockId = "rdd_" + RDD_ID +"_" + SPLIT_INDEX_VALID_BLOCK; + FetchResult execFetch = fetchBlocks("exec-1", new String[] { validBlockId }); + assertTrue(execFetch.failedBlocks.isEmpty()); + assertEquals(Sets.newHashSet(validBlockId), 
execFetch.successBlocks); + assertBuffersEqual(new NioManagedBuffer(ByteBuffer.wrap(exec0RddBlockValid)), + execFetch.buffers.get(0)); + } + + @Test + public void testFetchDeletedRddBlock() throws Exception { + registerExecutor("exec-1", dataContext0.createExecutorInfo(SORT_MANAGER)); + String missingBlockId = "rdd_" + RDD_ID + "_" + SPLIT_INDEX_MISSING_FILE; + FetchResult execFetch = fetchBlocks("exec-1", new String[] { missingBlockId }); + assertTrue(execFetch.successBlocks.isEmpty()); + assertEquals(Sets.newHashSet(missingBlockId), execFetch.failedBlocks); + } + + @Test + public void testRemoveRddBlocks() throws Exception { + registerExecutor("exec-1", dataContext0.createExecutorInfo(SORT_MANAGER)); + String validBlockIdToRemove = "rdd_" + RDD_ID + "_" + SPLIT_INDEX_VALID_BLOCK_TO_RM; + String missingBlockIdToRemove = "rdd_" + RDD_ID + "_" + SPLIT_INDEX_MISSING_BLOCK_TO_RM; + + try (ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, 5000)) { + client.init(APP_ID); + Future<Integer> numRemovedBlocks = client.removeBlocks( + TestUtils.getLocalHost(), + server.getPort(), + "exec-1", + new String[] { validBlockIdToRemove, missingBlockIdToRemove }); + assertEquals(1, numRemovedBlocks.get().intValue()); + } + } + + @Test + public void testFetchCorruptRddBlock() throws Exception { + registerExecutor("exec-1", dataContext0.createExecutorInfo(SORT_MANAGER)); + String corruptBlockId = "rdd_" + RDD_ID + "_" + SPLIT_INDEX_CORRUPT_LENGTH; + FetchResult execFetch = fetchBlocks("exec-1", new String[] { corruptBlockId }); assertTrue(execFetch.successBlocks.isEmpty()); - assertEquals(Sets.newHashSet("rdd_1_0_0"), execFetch.failedBlocks); + assertEquals(Sets.newHashSet(corruptBlockId), execFetch.failedBlocks); } @Test diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/NonShuffleFilesCleanupSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/NonShuffleFilesCleanupSuite.java deleted file mode 100644 index d22f3ace4103..000000000000 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/NonShuffleFilesCleanupSuite.java +++ /dev/null @@ -1,221 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
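The fetch and remove tests above rely on two conventions: disk-persisted RDD blocks are addressed as rdd_<rddId>_<splitIndex>, and removeBlocks reports only the blocks whose files were actually deleted, which is why one valid and one missing block yield a count of 1. A small sketch under those assumptions, with illustrative names:

```scala
// Sketch of the naming and remove-count semantics exercised by the tests
// (illustrative stand-ins, not the shuffle service implementation).
object RemoveBlocksSketch {
  def rddBlockId(rddId: Int, splitIndex: Int): String = s"rdd_${rddId}_$splitIndex"

  // reports how many of the requested block files actually existed and were removed
  def removeBlocks(existing: Set[String], toRemove: Seq[String]): Int =
    toRemove.count(existing.contains)

  def main(args: Array[String]): Unit = {
    val existing = Set(rddBlockId(1, 3))
    println(removeBlocks(existing, Seq(rddBlockId(1, 3), rddBlockId(1, 4)))) // prints 1
  }
}
```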
- */ - -package org.apache.spark.network.shuffle; - -import java.io.File; -import java.io.FilenameFilter; -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.Random; -import java.util.concurrent.Executor; -import java.util.concurrent.atomic.AtomicBoolean; - -import com.google.common.util.concurrent.MoreExecutors; -import org.junit.Test; -import static org.junit.Assert.assertTrue; - -import org.apache.spark.network.util.MapConfigProvider; -import org.apache.spark.network.util.TransportConf; - -public class NonShuffleFilesCleanupSuite { - - // Same-thread Executor used to ensure cleanup happens synchronously in test thread. - private Executor sameThreadExecutor = MoreExecutors.sameThreadExecutor(); - private TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); - private static final String SORT_MANAGER = "org.apache.spark.shuffle.sort.SortShuffleManager"; - - @Test - public void cleanupOnRemovedExecutorWithShuffleFiles() throws IOException { - cleanupOnRemovedExecutor(true); - } - - @Test - public void cleanupOnRemovedExecutorWithoutShuffleFiles() throws IOException { - cleanupOnRemovedExecutor(false); - } - - private void cleanupOnRemovedExecutor(boolean withShuffleFiles) throws IOException { - TestShuffleDataContext dataContext = initDataContext(withShuffleFiles); - - ExternalShuffleBlockResolver resolver = - new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); - resolver.registerExecutor("app", "exec0", dataContext.createExecutorInfo(SORT_MANAGER)); - resolver.executorRemoved("exec0", "app"); - - assertCleanedUp(dataContext); - } - - @Test - public void cleanupUsesExecutorWithShuffleFiles() throws IOException { - cleanupUsesExecutor(true); - } - - @Test - public void cleanupUsesExecutorWithoutShuffleFiles() throws IOException { - cleanupUsesExecutor(false); - } - - private void cleanupUsesExecutor(boolean withShuffleFiles) throws IOException { - TestShuffleDataContext dataContext = initDataContext(withShuffleFiles); - - AtomicBoolean cleanupCalled = new AtomicBoolean(false); - - // Executor which does nothing to ensure we're actually using it. 
- Executor noThreadExecutor = runnable -> cleanupCalled.set(true); - - ExternalShuffleBlockResolver manager = - new ExternalShuffleBlockResolver(conf, null, noThreadExecutor); - - manager.registerExecutor("app", "exec0", dataContext.createExecutorInfo(SORT_MANAGER)); - manager.executorRemoved("exec0", "app"); - - assertTrue(cleanupCalled.get()); - assertStillThere(dataContext); - } - - @Test - public void cleanupOnlyRemovedExecutorWithShuffleFiles() throws IOException { - cleanupOnlyRemovedExecutor(true); - } - - @Test - public void cleanupOnlyRemovedExecutorWithoutShuffleFiles() throws IOException { - cleanupOnlyRemovedExecutor(false); - } - - private void cleanupOnlyRemovedExecutor(boolean withShuffleFiles) throws IOException { - TestShuffleDataContext dataContext0 = initDataContext(withShuffleFiles); - TestShuffleDataContext dataContext1 = initDataContext(withShuffleFiles); - - ExternalShuffleBlockResolver resolver = - new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); - resolver.registerExecutor("app", "exec0", dataContext0.createExecutorInfo(SORT_MANAGER)); - resolver.registerExecutor("app", "exec1", dataContext1.createExecutorInfo(SORT_MANAGER)); - - - resolver.executorRemoved("exec-nonexistent", "app"); - assertStillThere(dataContext0); - assertStillThere(dataContext1); - - resolver.executorRemoved("exec0", "app"); - assertCleanedUp(dataContext0); - assertStillThere(dataContext1); - - resolver.executorRemoved("exec1", "app"); - assertCleanedUp(dataContext0); - assertCleanedUp(dataContext1); - - // Make sure it's not an error to cleanup multiple times - resolver.executorRemoved("exec1", "app"); - assertCleanedUp(dataContext0); - assertCleanedUp(dataContext1); - } - - @Test - public void cleanupOnlyRegisteredExecutorWithShuffleFiles() throws IOException { - cleanupOnlyRegisteredExecutor(true); - } - - @Test - public void cleanupOnlyRegisteredExecutorWithoutShuffleFiles() throws IOException { - cleanupOnlyRegisteredExecutor(false); - } - - private void cleanupOnlyRegisteredExecutor(boolean withShuffleFiles) throws IOException { - TestShuffleDataContext dataContext = initDataContext(withShuffleFiles); - - ExternalShuffleBlockResolver resolver = - new ExternalShuffleBlockResolver(conf, null, sameThreadExecutor); - resolver.registerExecutor("app", "exec0", dataContext.createExecutorInfo(SORT_MANAGER)); - - resolver.executorRemoved("exec1", "app"); - assertStillThere(dataContext); - - resolver.executorRemoved("exec0", "app"); - assertCleanedUp(dataContext); - } - - private static void assertStillThere(TestShuffleDataContext dataContext) { - for (String localDir : dataContext.localDirs) { - assertTrue(localDir + " was cleaned up prematurely", new File(localDir).exists()); - } - } - - private static FilenameFilter filter = new FilenameFilter() { - @Override - public boolean accept(File dir, String name) { - // Don't delete shuffle data or shuffle index files. 
- return !name.endsWith(".index") && !name.endsWith(".data"); - } - }; - - private static boolean assertOnlyShuffleDataInDir(File[] dirs) { - for (File dir : dirs) { - assertTrue(dir.getName() + " wasn't cleaned up", !dir.exists() || - dir.listFiles(filter).length == 0 || assertOnlyShuffleDataInDir(dir.listFiles())); - } - return true; - } - - private static void assertCleanedUp(TestShuffleDataContext dataContext) { - for (String localDir : dataContext.localDirs) { - File[] dirs = new File[] {new File(localDir)}; - assertOnlyShuffleDataInDir(dirs); - } - } - - private static TestShuffleDataContext initDataContext(boolean withShuffleFiles) - throws IOException { - if (withShuffleFiles) { - return initDataContextWithShuffleFiles(); - } else { - return initDataContextWithoutShuffleFiles(); - } - } - - private static TestShuffleDataContext initDataContextWithShuffleFiles() throws IOException { - TestShuffleDataContext dataContext = createDataContext(); - createShuffleFiles(dataContext); - createNonShuffleFiles(dataContext); - return dataContext; - } - - private static TestShuffleDataContext initDataContextWithoutShuffleFiles() throws IOException { - TestShuffleDataContext dataContext = createDataContext(); - createNonShuffleFiles(dataContext); - return dataContext; - } - - private static TestShuffleDataContext createDataContext() { - TestShuffleDataContext dataContext = new TestShuffleDataContext(10, 5); - dataContext.create(); - return dataContext; - } - - private static void createShuffleFiles(TestShuffleDataContext dataContext) throws IOException { - Random rand = new Random(123); - dataContext.insertSortShuffleData(rand.nextInt(1000), rand.nextInt(1000), new byte[][] { - "ABC".getBytes(StandardCharsets.UTF_8), - "DEF".getBytes(StandardCharsets.UTF_8)}); - } - - private static void createNonShuffleFiles(TestShuffleDataContext dataContext) throws IOException { - // Create spill file(s) - dataContext.insertSpillData(); - } -} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java index 6989c3baf2e2..10be95ec50c3 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -22,7 +22,6 @@ import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; -import java.util.UUID; import com.google.common.io.Closeables; import com.google.common.io.Files; @@ -97,13 +96,36 @@ public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) thr /** Creates spill file(s) within the local dirs. 
*/ public void insertSpillData() throws IOException { - String filename = "temp_local_" + UUID.randomUUID(); - OutputStream dataStream = null; + String filename = "temp_local_uuid"; + insertFile(filename); + } + + public void insertBroadcastData() throws IOException { + String filename = "broadcast_12_uuid"; + insertFile(filename); + } + + public void insertTempShuffleData() throws IOException { + String filename = "temp_shuffle_uuid"; + insertFile(filename); + } + public void insertCachedRddData(int rddId, int splitId, byte[] block) throws IOException { + String blockId = "rdd_" + rddId + "_" + splitId; + insertFile(blockId, block); + } + + private void insertFile(String filename) throws IOException { + insertFile(filename, new byte[] { 42 }); + } + + private void insertFile(String filename, byte[] block) throws IOException { + OutputStream dataStream = null; + File file = ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, filename); + assert(!file.exists()) : "this test file has already been generated"; try { - dataStream = new FileOutputStream( - ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, filename)); - dataStream.write(42); + dataStream = new FileOutputStream(file); + dataStream.write(block); } finally { Closeables.close(dataStream, false); } diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 6fade10b7a3c..1782027fe293 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -491,8 +491,8 @@ private[spark] class ExecutorAllocationManager( newExecutorTotal = numExistingExecutors if (testing || executorsRemoved.nonEmpty) { executorsRemoved.foreach { removedExecutorId => - // If it is a cached block, it uses cachedExecutorIdleTimeoutS for timeout - val idleTimeout = if (blockManagerMaster.hasCachedBlocks(removedExecutorId)) { + // If it has an exclusive cached block, cachedExecutorIdleTimeoutS is used for the timeout + val idleTimeout = if (blockManagerMaster.hasExclusiveCachedBlocks(removedExecutorId)) { cachedExecutorIdleTimeoutS } else { executorIdleTimeoutS @@ -605,10 +605,10 @@ private[spark] class ExecutorAllocationManager( private def onExecutorIdle(executorId: String): Unit = synchronized { if (executorIds.contains(executorId)) { if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) { - // Note that it is not necessary to query the executors since all the cached - // blocks we are concerned with are reported to the driver. Note that this - // does not include broadcast blocks. - val hasCachedBlocks = blockManagerMaster.hasCachedBlocks(executorId) + // Note that it is not necessary to query the executors since all the cached blocks we are + // concerned with are reported to the driver. This does not include broadcast blocks and + // non-exclusive blocks, which are also available via the external shuffle service.
+ val hasCachedBlocks = blockManagerMaster.hasExclusiveCachedBlocks(executorId) val now = clock.getTimeMillis() val timeout = { if (hasCachedBlocks) { diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 002bf65ba593..4e778a1ddd5f 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -35,7 +35,8 @@ import org.apache.spark.internal.{config, Logging} import org.apache.spark.internal.config._ import org.apache.spark.memory.{MemoryManager, UnifiedMemoryManager} import org.apache.spark.metrics.{MetricsSystem, MetricsSystemInstances} -import org.apache.spark.network.netty.NettyBlockTransferService +import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} +import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator} import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint @@ -328,9 +329,26 @@ object SparkEnv extends Logging { conf.get(BLOCK_MANAGER_PORT) } + val externalShuffleClient = if (conf.get(config.SHUFFLE_SERVICE_ENABLED)) { + val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) + Some(new ExternalShuffleClient(transConf, securityManager, + securityManager.isAuthenticationEnabled(), conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT))) + } else { + None + } + val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint( BlockManagerMaster.DRIVER_ENDPOINT_NAME, - new BlockManagerMasterEndpoint(rpcEnv, isLocal, conf, listenerBus)), + new BlockManagerMasterEndpoint( + rpcEnv, + isLocal, + conf, + listenerBus, + if (conf.get(config.SHUFFLE_SERVICE_FETCH_RDD_ENABLED)) { + externalShuffleClient + } else { + None + })), conf, isDriver) val blockTransferService = @@ -338,9 +356,18 @@ object SparkEnv extends Logging { blockManagerPort, numUsableCores, blockManagerMaster.driverEndpoint) // NB: blockManager is not valid until initialize() is called later. - val blockManager = new BlockManager(executorId, rpcEnv, blockManagerMaster, - serializerManager, conf, memoryManager, mapOutputTracker, shuffleManager, - blockTransferService, securityManager, numUsableCores) + val blockManager = new BlockManager( + executorId, + rpcEnv, + blockManagerMaster, + serializerManager, + conf, + memoryManager, + mapOutputTracker, + shuffleManager, + blockTransferService, + securityManager, + externalShuffleClient) val metricsSystem = if (isDriver) { // Don't start metrics system right now for Driver. diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index f8ec5b6b190c..974e54689cc2 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -100,8 +100,8 @@ private[deploy] class Worker( // TTL for app folders/data; after TTL expires it will be cleaned up private val APP_DATA_RETENTION_SECONDS = conf.get(APP_DATA_RETENTION) - // Whether or not cleanup the non-shuffle files on executor exits. - private val CLEANUP_NON_SHUFFLE_FILES_ENABLED = + // Whether or not to clean up files that are not served by the shuffle service on executor exit.
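The ExecutorAllocationManager change above boils down to choosing between two timeouts: the long cached-block timeout is only needed when the executor holds data nothing else can serve. A sketch of that decision with illustrative parameter names:

```scala
// Executors whose cached blocks are all reachable through the external
// shuffle service fall back to the ordinary idle timeout and can be
// released without losing data (illustrative names, not the Spark API).
def idleTimeoutS(
    hasExclusiveCachedBlocks: Boolean,
    cachedExecutorIdleTimeoutS: Long,
    executorIdleTimeoutS: Long): Long = {
  if (hasExclusiveCachedBlocks) cachedExecutorIdleTimeoutS else executorIdleTimeoutS
}
```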
+ private val CLEANUP_FILES_AFTER_EXECUTOR_EXIT = conf.get(config.STORAGE_CLEANUP_FILES_AFTER_EXECUTOR_EXIT) private var master: Option[RpcEndpointRef] = None @@ -750,7 +750,8 @@ private[deploy] class Worker( trimFinishedExecutorsIfNecessary() coresUsed -= executor.cores memoryUsed -= executor.memory - if (CLEANUP_NON_SHUFFLE_FILES_ENABLED) { + + if (CLEANUP_FILES_AFTER_EXECUTOR_EXIT) { shuffleService.executorRemoved(executorStateChanged.execId.toString, appId) } case None => diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 0aed1af023f8..882de1deb5e6 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -21,6 +21,7 @@ import java.util.concurrent.TimeUnit import org.apache.spark.launcher.SparkLauncher import org.apache.spark.metrics.GarbageCollectionMetrics +import org.apache.spark.network.shuffle.Constants import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{EventLoggingListener, SchedulingMode} import org.apache.spark.storage.{DefaultTopologyMapper, RandomBlockReplicationPolicy} @@ -299,7 +300,8 @@ package object config { private[spark] val STORAGE_CLEANUP_FILES_AFTER_EXECUTOR_EXIT = ConfigBuilder("spark.storage.cleanupFilesAfterExecutorExit") - .doc("Whether or not cleanup the non-shuffle files on executor exits.") + .doc("Whether or not to clean up the files not served by the external shuffle service " + + "on executor exits.") .booleanConf .createWithDefault(true) @@ -366,6 +368,15 @@ package object config { private[spark] val SHUFFLE_SERVICE_ENABLED = ConfigBuilder("spark.shuffle.service.enabled").booleanConf.createWithDefault(false) + private[spark] val SHUFFLE_SERVICE_FETCH_RDD_ENABLED = + ConfigBuilder(Constants.SHUFFLE_SERVICE_FETCH_RDD_ENABLED) + .doc("Whether to use the ExternalShuffleService for fetching disk persisted RDD blocks. " + + "In case of dynamic allocation, if this feature is enabled, executors having only disk " + + "persisted blocks are considered idle after " + + "'spark.dynamicAllocation.executorIdleTimeout' and will be released accordingly.") + .booleanConf + .createWithDefault(true) + private[spark] val SHUFFLE_SERVICE_DB_ENABLED = ConfigBuilder("spark.shuffle.service.db.enabled") .doc("Whether to use db in ExternalShuffleService.
Note that this only affects " + diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 09928e47634c..6dbef784fe0c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -45,7 +45,6 @@ import org.apache.spark.metrics.source.Source import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.StreamCallbackWithID -import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle._ import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.network.util.TransportConf @@ -130,11 +129,12 @@ private[spark] class BlockManager( shuffleManager: ShuffleManager, val blockTransferService: BlockTransferService, securityManager: SecurityManager, - numUsableCores: Int) + externalShuffleClient: Option[ExternalShuffleClient]) extends BlockDataManager with BlockEvictionHandler with Logging { - private[spark] val externalShuffleServiceEnabled = - conf.get(config.SHUFFLE_SERVICE_ENABLED) + // same as `conf.get(config.SHUFFLE_SERVICE_ENABLED)` + private[spark] val externalShuffleServiceEnabled: Boolean = externalShuffleClient.isDefined + private val remoteReadNioBufferConversion = conf.get(Network.NETWORK_REMOTE_READ_NIO_BUFFER_CONVERSION) @@ -164,20 +164,7 @@ private[spark] class BlockManager( private val maxOnHeapMemory = memoryManager.maxOnHeapStorageMemory private val maxOffHeapMemory = memoryManager.maxOffHeapStorageMemory - // Port used by the external shuffle service. In Yarn mode, this may be already be - // set through the Hadoop configuration as the server is launched in the Yarn NM. - private val externalShuffleServicePort = { - val tmpPort = Utils.getSparkOrYarnConfig(conf, config.SHUFFLE_SERVICE_PORT.key, - config.SHUFFLE_SERVICE_PORT.defaultValueString).toInt - if (tmpPort == 0) { - // for testing, we set "spark.shuffle.service.port" to 0 in the yarn config, so yarn finds - // an open port. But we still need to tell our spark apps the right port to use. So - // only if the yarn config has the port set to 0, we prefer the value in the spark config - conf.get(config.SHUFFLE_SERVICE_PORT.key).toInt - } else { - tmpPort - } - } + private val externalShuffleServicePort = StorageUtils.externalShuffleServicePort(conf) var blockManagerId: BlockManagerId = _ @@ -187,13 +174,7 @@ private[spark] class BlockManager( // Client to read other executors' shuffle files. This is either an external service, or just the // standard BlockTransferService to directly connect to other Executors. 
- private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { - val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) - new ExternalShuffleClient(transConf, securityManager, - securityManager.isAuthenticationEnabled(), conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT)) - } else { - blockTransferService - } + private[spark] val shuffleClient = externalShuffleClient.getOrElse(blockTransferService) // Max number of failures before this block manager refreshes the block locations from the driver private val maxFailuresBeforeLocationRefresh = @@ -414,8 +395,9 @@ private[spark] class BlockManager( */ def initialize(appId: String): Unit = { blockTransferService.init(this) - shuffleClient.init(appId) - + externalShuffleClient.foreach { shuffleClient => + shuffleClient.init(appId) + } blockReplicationPolicy = { val priorityClass = conf.get(config.STORAGE_REPLICATION_POLICY) val clazz = Utils.classForName(priorityClass) @@ -843,7 +825,7 @@ private[spark] class BlockManager( * * This does not acquire a lock on this block in this JVM. */ - private def getRemoteValues[T: ClassTag](blockId: BlockId): Option[BlockResult] = { + private[spark] def getRemoteValues[T: ClassTag](blockId: BlockId): Option[BlockResult] = { val ct = implicitly[ClassTag[T]] getRemoteManagedBuffer(blockId).map { data => val values = @@ -852,21 +834,30 @@ private[spark] class BlockManager( } } + private def preferExecutors(locations: Seq[BlockManagerId]): Seq[BlockManagerId] = { + val (executors, shuffleServers) = locations.partition(_.port != externalShuffleServicePort) + executors ++ shuffleServers + } + /** * Return a list of locations for the given block, prioritizing the local machine since * multiple block managers can share the same host, followed by hosts on the same rack. + * + * Within each of the above listed groups (same host, same rack and others) executors are + * preferred over the external shuffle service. 
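The reworked sortLocations below orders candidates in three groups (same host, same rack, others) and, inside each group, puts executors ahead of external shuffle service entries. A compact model of that ordering with stand-in types; the initial random shuffle of the candidate list is omitted here:

```scala
// Sketch of the three-group ordering, with executors preferred over
// shuffle-service locations inside every group.
object SortLocationsSketch {
  case class Loc(host: String, port: Int, rack: Option[String])

  def sortLocations(locs: Seq[Loc], self: Loc, shuffleServicePort: Int): Seq[Loc] = {
    def preferExecutors(ls: Seq[Loc]): Seq[Loc] = {
      val (executors, services) = ls.partition(_.port != shuffleServicePort)
      executors ++ services
    }
    val (sameHost, others) = locs.partition(_.host == self.host)
    val (sameRack, differentRack) =
      others.partition(l => self.rack.isDefined && l.rack == self.rack)
    Seq(sameHost, sameRack, differentRack).map(preferExecutors).reduce(_ ++ _)
  }
}
```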
*/ - private def sortLocations(locations: Seq[BlockManagerId]): Seq[BlockManagerId] = { + private[spark] def sortLocations(locations: Seq[BlockManagerId]): Seq[BlockManagerId] = { val locs = Random.shuffle(locations) - val (preferredLocs, otherLocs) = locs.partition { loc => blockManagerId.host == loc.host } - blockManagerId.topologyInfo match { - case None => preferredLocs ++ otherLocs + val (preferredLocs, otherLocs) = locs.partition(_.host == blockManagerId.host) + val orderedParts = blockManagerId.topologyInfo match { + case None => Seq(preferredLocs, otherLocs) case Some(_) => val (sameRackLocs, differentRackLocs) = otherLocs.partition { loc => blockManagerId.topologyInfo == loc.topologyInfo } - preferredLocs ++ sameRackLocs ++ differentRackLocs + Seq(preferredLocs, sameRackLocs, differentRackLocs) } + orderedParts.map(preferExecutors).reduce(_ ++ _) } /** @@ -902,8 +893,12 @@ private[spark] class BlockManager( val loc = locationIterator.next() logDebug(s"Getting remote block $blockId from $loc") val data = try { - blockTransferService.fetchBlockSync( - loc.host, loc.port, loc.executorId, blockId.toString, tempFileManager) + val buf = blockTransferService.fetchBlockSync(loc.host, loc.port, loc.executorId, + blockId.toString, tempFileManager) + if (blockSize > 0 && buf.size() == 0) { + throw new IllegalStateException("Empty buffer received for non empty block") + } + buf } catch { case NonFatal(e) => runningFailureCount += 1 diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index d24421b96277..b18d38fe5253 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -223,11 +223,12 @@ class BlockManagerMaster( } /** - * Find out if the executor has cached blocks. This method does not consider broadcast blocks, - * since they are not reported the master. + * Find out if the executor has cached blocks which are not available via the external shuffle + * service. + * This method does not consider broadcast blocks, since they are not reported to the master. 
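The new guard in the fetch path above treats a zero-length buffer for a block of known positive size as a failed attempt, so the retry loop moves on to the next location instead of returning corrupt data. Reduced to a sketch with an assumed signature:

```scala
// Essence of the guard (illustrative helper, not the BlockManager method):
// a non-empty block must never come back as an empty buffer.
def checkNonEmpty(blockSize: Long, receivedBytes: Long): Unit = {
  if (blockSize > 0 && receivedBytes == 0) {
    throw new IllegalStateException("Empty buffer received for non empty block")
  }
}
```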
*/ - def hasCachedBlocks(executorId: String): Boolean = { - driverEndpoint.askSync[Boolean](HasCachedBlocks(executorId)) + def hasExclusiveCachedBlocks(executorId: String): Boolean = { + driverEndpoint.askSync[Boolean](HasExclusiveCachedBlocks(executorId)) } /** Stop the driver endpoint, called only on the Spark driver node */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index f388d59e78ba..65ec1c3f0dc6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.io.IOException import java.util.{HashMap => JHashMap} +import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.collection.mutable @@ -28,10 +29,11 @@ import scala.util.Random import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.{config, Logging} +import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils} /** * BlockManagerMasterEndpoint is an [[ThreadSafeRpcEndpoint]] on the master node to track statuses @@ -42,12 +44,17 @@ class BlockManagerMasterEndpoint( override val rpcEnv: RpcEnv, val isLocal: Boolean, conf: SparkConf, - listenerBus: LiveListenerBus) + listenerBus: LiveListenerBus, + externalShuffleClient: Option[ExternalShuffleClient]) extends ThreadSafeRpcEndpoint with Logging { // Mapping from block manager id to the block manager's information. private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo] + // Mapping from external shuffle service block manager id to the block statuses. + private val blockStatusByShuffleService = + new mutable.HashMap[BlockManagerId, JHashMap[BlockId, BlockStatus]] + // Mapping from executor ID to block manager ID. 
private val blockManagerIdByExecutor = new mutable.HashMap[String, BlockManagerId] @@ -70,7 +77,13 @@ class BlockManagerMasterEndpoint( val proactivelyReplicate = conf.get(config.STORAGE_REPLICATION_PROACTIVE) + val defaultRpcTimeout = RpcUtils.askRpcTimeout(conf) + logInfo("BlockManagerMasterEndpoint up") + // same as `conf.get(config.SHUFFLE_SERVICE_ENABLED) + // && conf.get(config.SHUFFLE_SERVICE_FETCH_RDD_ENABLED)` + private val externalShuffleServiceRddFetchEnabled: Boolean = externalShuffleClient.isDefined + private val externalShuffleServicePort: Int = StorageUtils.externalShuffleServicePort(conf) override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RegisterBlockManager(blockManagerId, maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint) => @@ -135,12 +148,12 @@ class BlockManagerMasterEndpoint( case BlockManagerHeartbeat(blockManagerId) => context.reply(heartbeatReceived(blockManagerId)) - case HasCachedBlocks(executorId) => + case HasExclusiveCachedBlocks(executorId) => blockManagerIdByExecutor.get(executorId) match { case Some(bm) => if (blockManagerInfo.contains(bm)) { val bmInfo = blockManagerInfo(bm) - context.reply(bmInfo.cachedBlocks.nonEmpty) + context.reply(bmInfo.exclusiveCachedBlocks.nonEmpty) } else { context.reply(false) } @@ -152,29 +165,62 @@ class BlockManagerMasterEndpoint( // First remove the metadata for the given RDD, and then asynchronously remove the blocks // from the slaves. + // The message sent to the slaves to remove the RDD + val removeMsg = RemoveRdd(rddId) + // Find all blocks for the given RDD, remove the block from both blockLocations and - // the blockManagerInfo that is tracking the blocks. + // the blockManagerInfo that is tracking the blocks, and create the futures which asynchronously + // remove the blocks from the slaves and give back the number of removed blocks val blocks = blockLocations.asScala.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) + val blocksToDeleteByShuffleService = + new mutable.HashMap[BlockManagerId, mutable.HashSet[RDDBlockId]] + blocks.foreach { blockId => - val bms: mutable.HashSet[BlockManagerId] = blockLocations.get(blockId) - bms.foreach(bm => blockManagerInfo.get(bm).foreach(_.removeBlock(blockId))) - blockLocations.remove(blockId) + val bms: mutable.HashSet[BlockManagerId] = blockLocations.remove(blockId) + + val (bmIdsExtShuffle, bmIdsExecutor) = bms.partition(_.port == externalShuffleServicePort) + val liveExecutorsForBlock = bmIdsExecutor.map(_.executorId).toSet + bmIdsExtShuffle.foreach { bmIdForShuffleService => + // if the original executor has already been released, delete this disk block via + // the external shuffle service + if (!liveExecutorsForBlock.contains(bmIdForShuffleService.executorId)) { + val blockIdsToDel = blocksToDeleteByShuffleService.getOrElseUpdate(bmIdForShuffleService, + new mutable.HashSet[RDDBlockId]()) + blockIdsToDel += blockId + blockStatusByShuffleService.get(bmIdForShuffleService).foreach { blockStatus => + blockStatus.remove(blockId) + } + } + } + bmIdsExecutor.foreach { bmId => + blockManagerInfo.get(bmId).foreach { bmInfo => + bmInfo.removeBlock(blockId) + } + } } - - // Ask the slaves to remove the RDD, and put the result in a sequence of Futures. - // The dispatcher is used as an implicit argument into the Future sequence construction.
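A sketch of the partitioning performed in removeRdd above, with stand-in types: locations registered at the shuffle-service port are deleted through the service only when the originating executor no longer appears among the live executor locations, since a live executor will delete its own files.

```scala
// Illustrative model, not the endpoint implementation.
object RemoveRddSketch {
  case class BmId(executorId: String, host: String, port: Int)

  def toDeleteViaService(locations: Set[BmId], servicePort: Int): Set[BmId] = {
    val (serviceIds, executorIds) = locations.partition(_.port == servicePort)
    val live = executorIds.map(_.executorId)
    serviceIds.filterNot(id => live.contains(id.executorId))
  }
}
```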
- val removeMsg = RemoveRdd(rddId) - - val futures = blockManagerInfo.values.map { bm => - bm.slaveEndpoint.ask[Int](removeMsg).recover { + val removeRddFromExecutorsFutures = blockManagerInfo.values.map { bmInfo => + bmInfo.slaveEndpoint.ask[Int](removeMsg).recover { case e: IOException => - logWarning(s"Error trying to remove RDD $rddId from block manager ${bm.blockManagerId}", - e) + logWarning(s"Error trying to remove RDD ${removeMsg.rddId} " + + s"from block manager ${bmInfo.blockManagerId}", e) 0 // zero blocks were removed } }.toSeq - Future.sequence(futures) + val removeRddBlockViaExtShuffleServiceFutures = externalShuffleClient.map { shuffleClient => + blocksToDeleteByShuffleService.map { case (bmId, blockIds) => + Future[Int] { + val numRemovedBlocks = shuffleClient.removeBlocks( + bmId.host, + bmId.port, + bmId.executorId, + blockIds.map(_.toString).toArray) + numRemovedBlocks.get(defaultRpcTimeout.duration.toSeconds, TimeUnit.SECONDS) + } + } + }.getOrElse(Seq.empty) + + Future.sequence(removeRddFromExecutorsFutures ++ removeRddBlockViaExtShuffleServiceFutures) } private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = { @@ -353,6 +399,12 @@ class BlockManagerMasterEndpoint( ).map(_.flatten.toSeq) } + private def externalShuffleServiceIdOnHost(blockManagerId: BlockManagerId): BlockManagerId = { + // we need to keep the executor ID of the original executor to let the shuffle service know + // which local directories should be used to look for the file + BlockManagerId(blockManagerId.executorId, blockManagerId.host, externalShuffleServicePort) + } + /** * Returns the BlockManagerId with topology information populated, if available. */ @@ -384,8 +436,17 @@ class BlockManagerMasterEndpoint( blockManagerIdByExecutor(id.executorId) = id - blockManagerInfo(id) = new BlockManagerInfo( - id, System.currentTimeMillis(), maxOnHeapMemSize, maxOffHeapMemSize, slaveEndpoint) + val externalShuffleServiceBlockStatus = + if (externalShuffleServiceRddFetchEnabled) { + val externalShuffleServiceBlocks = blockStatusByShuffleService + .getOrElseUpdate(externalShuffleServiceIdOnHost(id), new JHashMap[BlockId, BlockStatus]) + Some(externalShuffleServiceBlocks) + } else { + None + } + + blockManagerInfo(id) = new BlockManagerInfo(id, System.currentTimeMillis(), maxOnHeapMemSize, + maxOffHeapMemSize, slaveEndpoint, externalShuffleServiceBlockStatus) } listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxOnHeapMemSize + maxOffHeapMemSize, Some(maxOnHeapMemSize), Some(maxOffHeapMemSize))) @@ -430,6 +491,15 @@ class BlockManagerMasterEndpoint( locations.remove(blockManagerId) } + if (blockId.isRDD && storageLevel.useDisk && externalShuffleServiceRddFetchEnabled) { + val externalShuffleServiceId = externalShuffleServiceIdOnHost(blockManagerId) + if (storageLevel.isValid) { + locations.add(externalShuffleServiceId) + } else { + locations.remove(externalShuffleServiceId) + } + } + // Remove the block from master tracking if it has been removed on all slaves. 
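The updateBlockInfo hunk above keeps the executor location and a same-host shuffle-service location in lockstep for valid, disk-backed RDD blocks: both are added on a valid update and both are dropped when the block becomes invalid. A minimal model under that assumption:

```scala
// Illustrative location bookkeeping (stand-in types): a disk-backed RDD
// block additionally advertises the shuffle-service port on its host.
object BlockLocationSketch {
  case class BmId(executorId: String, host: String, port: Int)

  def updateLocations(
      current: Set[BmId],
      bm: BmId,
      servicePort: Int,
      isRddOnDisk: Boolean,
      isValid: Boolean): Set[BmId] = {
    val serviceId = BmId(bm.executorId, bm.host, servicePort)
    val afterExec = if (isValid) current + bm else current - bm
    if (isRddOnDisk) {
      if (isValid) afterExec + serviceId else afterExec - serviceId
    } else {
      afterExec
    }
  }
}
```

Note that the service-side BlockManagerId keeps the original executor ID, so the shuffle service can locate the right local directories.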
if (locations.size == 0) { blockLocations.remove(blockId) @@ -443,7 +513,13 @@ class BlockManagerMasterEndpoint( private def getLocationsAndStatus(blockId: BlockId): Option[BlockLocationsAndStatus] = { val locations = Option(blockLocations.get(blockId)).map(_.toSeq).getOrElse(Seq.empty) - val status = locations.headOption.flatMap { bmId => blockManagerInfo(bmId).getStatus(blockId) } + val status = locations.headOption.flatMap { bmId => + if (externalShuffleServiceRddFetchEnabled && bmId.port == externalShuffleServicePort) { + Option(blockStatusByShuffleService(bmId).get(blockId)) + } else { + blockManagerInfo(bmId).getStatus(blockId) + } + } if (locations.nonEmpty && status.isDefined) { Some(BlockLocationsAndStatus(locations, status.get)) @@ -499,19 +575,25 @@ private[spark] class BlockManagerInfo( timeMs: Long, val maxOnHeapMem: Long, val maxOffHeapMem: Long, - val slaveEndpoint: RpcEndpointRef) + val slaveEndpoint: RpcEndpointRef, + val externalShuffleServiceBlockStatus: Option[JHashMap[BlockId, BlockStatus]]) extends Logging { val maxMem = maxOnHeapMem + maxOffHeapMem + val externalShuffleServiceEnabled = externalShuffleServiceBlockStatus.isDefined + private var _lastSeenMs: Long = timeMs private var _remainingMem: Long = maxMem // Mapping from block id to its status. private val _blocks = new JHashMap[BlockId, BlockStatus] - // Cached blocks held by this BlockManager. This does not include broadcast blocks. - private val _cachedBlocks = new mutable.HashSet[BlockId] + /** + * Cached blocks which are not available via the external shuffle service. + * This does not include broadcast blocks. + */ + private val _exclusiveCachedBlocks = new mutable.HashSet[BlockId] def getStatus(blockId: BlockId): Option[BlockStatus] = Option(_blocks.get(blockId)) @@ -579,13 +661,28 @@ private[spark] class BlockManagerInfo( s" (size: ${Utils.bytesToString(diskSize)})") } } - if (!blockId.isBroadcast && blockStatus.isCached) { - _cachedBlocks += blockId + + if (!blockId.isBroadcast) { + if (!externalShuffleServiceEnabled || !storageLevel.useDisk) { + _exclusiveCachedBlocks += blockId + } else if (blockExists) { + // removing block from the exclusive cached blocks when updated to non-exclusive + _exclusiveCachedBlocks -= blockId + } + } + + externalShuffleServiceBlockStatus.foreach { shuffleServiceBlocks => + if (!blockId.isBroadcast && blockStatus.diskSize > 0) { + shuffleServiceBlocks.put(blockId, blockStatus) + } } } else if (blockExists) { // If isValid is not true, drop the block. _blocks.remove(blockId) - _cachedBlocks -= blockId + _exclusiveCachedBlocks -= blockId + externalShuffleServiceBlockStatus.foreach { blockStatus => + blockStatus.remove(blockId) + } if (originalLevel.useMemory) { logInfo(s"Removed $blockId on ${blockManagerId.hostPort} in memory" + s" (size: ${Utils.bytesToString(originalMemSize)}," + @@ -602,8 +699,11 @@ private[spark] class BlockManagerInfo( if (_blocks.containsKey(blockId)) { _remainingMem += _blocks.get(blockId).memSize _blocks.remove(blockId) + externalShuffleServiceBlockStatus.foreach { blockStatus => + blockStatus.remove(blockId) + } } - _cachedBlocks -= blockId + _exclusiveCachedBlocks -= blockId } def remainingMem: Long = _remainingMem @@ -612,8 +712,7 @@ private[spark] class BlockManagerInfo( def blocks: JHashMap[BlockId, BlockStatus] = _blocks - // This does not include broadcast blocks. 
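The exclusive-cached bookkeeping in BlockManagerInfo above reduces to a single predicate; a hypothetical helper expressing it:

```scala
// A cached block is "exclusive" to its executor unless the external shuffle
// service is enabled and the block also lives on disk; broadcast blocks
// never count (illustrative helper, not the BlockManagerInfo API).
def isExclusiveCached(isBroadcast: Boolean, useDisk: Boolean, serviceEnabled: Boolean): Boolean =
  !isBroadcast && !(serviceEnabled && useDisk)
```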
- def cachedBlocks: collection.Set[BlockId] = _cachedBlocks + def exclusiveCachedBlocks: collection.Set[BlockId] = _exclusiveCachedBlocks override def toString: String = "BlockManagerInfo " + timeMs + " " + _remainingMem diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 2be28420b495..3dbac694cf81 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -122,7 +122,8 @@ private[spark] object BlockManagerMessages { case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster - case class HasCachedBlocks(executorId: String) extends ToBlockManagerMaster + case class HasExclusiveCachedBlocks(executorId: String) extends ToBlockManagerMaster case class IsExecutorAlive(executorId: String) extends ToBlockManagerMaster + } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 1c9ea1dba97d..fc426eee608c 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -26,7 +26,8 @@ import org.apache.commons.lang3.{JavaVersion, SystemUtils} import sun.misc.Unsafe import sun.nio.ch.DirectBuffer -import org.apache.spark.internal.Logging +import org.apache.spark.SparkConf +import org.apache.spark.internal.{config, Logging} import org.apache.spark.util.Utils /** @@ -236,4 +237,20 @@ private[spark] object StorageUtils extends Logging { } } + /** + * Get the port used by the external shuffle service. In Yarn mode, this may already be + * set through the Hadoop configuration as the server is launched in the Yarn NM. + */ + def externalShuffleServicePort(conf: SparkConf): Int = { + val tmpPort = Utils.getSparkOrYarnConfig(conf, config.SHUFFLE_SERVICE_PORT.key, + config.SHUFFLE_SERVICE_PORT.defaultValueString).toInt + if (tmpPort == 0) { + // for testing, we set "spark.shuffle.service.port" to 0 in the yarn config, so yarn finds + // an open port. But we still need to tell our spark apps the right port to use. So + // we prefer the value in the spark config only if the yarn config has the port set to 0 + conf.get(config.SHUFFLE_SERVICE_PORT.key).toInt + } else { + tmpPort + } + } } diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 8b737cd8c81f..e5644f25a0b7 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -18,12 +18,15 @@ package org.apache.spark import org.scalatest.BeforeAndAfterAll +import org.scalatest.concurrent.Eventually +import org.scalatest.time.SpanSugar._ import org.apache.spark.internal.config import org.apache.spark.network.TransportContext import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.server.TransportServer import org.apache.spark.network.shuffle.{ExternalShuffleBlockHandler, ExternalShuffleClient} +import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.Utils /** @@ -32,7 +35,7 @@ import org.apache.spark.util.Utils * set up in `ExternalShuffleBlockHandler`, such as changing the format of shuffle files or how * we hash files into folders.
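Port resolution in the new StorageUtils helper prefers the Hadoop/Yarn-provided value and falls back to the Spark conf only when that value is 0. A sketch with the config lookup abstracted away:

```scala
// Illustrative stand-in for Utils.getSparkOrYarnConfig plus the fallback:
// the Yarn-side (or default) port wins unless it is 0, in which case the
// value from the Spark conf is used.
def resolveShuffleServicePort(yarnOrDefaultPort: Int, sparkConfPort: Int): Int =
  if (yarnOrDefaultPort == 0) sparkConfPort else yarnOrDefaultPort
```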
*/ -class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { +class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll with Eventually { var server: TransportServer = _ var transportContext: TransportContext = _ var rpcHandler: ExternalShuffleBlockHandler = _ @@ -92,4 +95,42 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { } e.getMessage should include ("Fetch failure will not retry stage due to testing config") } + + test("SPARK-25888: using external shuffle service fetching disk persisted blocks") { + sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) + sc.env.blockManager.externalShuffleServiceEnabled should equal(true) + sc.env.blockManager.shuffleClient.getClass should equal(classOf[ExternalShuffleClient]) + try { + val rdd = sc.parallelize(0 until 100, 2) + .map { i => (i, 1) } + .persist(StorageLevel.DISK_ONLY) + + rdd.count() + + val blockId = RDDBlockId(rdd.id, 0) + eventually(timeout(2.seconds), interval(100.milliseconds)) { + val locations = sc.env.blockManager.master.getLocations(blockId) + assert(locations.size === 2) + assert(locations.map(_.port).contains(server.getPort), + "external shuffle service port should be contained") + } + + sc.killExecutors(sc.getExecutorIds()) + + eventually(timeout(2.seconds), interval(100.milliseconds)) { + val locations = sc.env.blockManager.master.getLocations(blockId) + assert(locations.size === 1) + assert(locations.map(_.port).contains(server.getPort), + "external shuffle service port should be contained") + } + + assert(sc.env.blockManager.getRemoteValues(blockId).isDefined) + + // test unpersist: as executors are killed the blocks will be removed via the shuffle service + rdd.unpersist(true) + assert(sc.env.blockManager.getRemoteValues(blockId).isEmpty) + } finally { + rpcHandler.applicationRemoved(sc.conf.getAppId, true) + } + } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerInfoSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerInfoSuite.scala new file mode 100644 index 000000000000..8df123250303 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerInfoSuite.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.storage + +import java.util.{HashMap => JHashMap} + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkFunSuite + +class BlockManagerInfoSuite extends SparkFunSuite { + + def testWithShuffleServiceOnOff(testName: String) + (f: (Boolean, BlockManagerInfo) => Unit): Unit = { + Seq(true, false).foreach { svcEnabled => + val bmInfo = new BlockManagerInfo( + BlockManagerId("executor0", "host", 1234, None), + timeMs = 300, + maxOnHeapMem = 10000, + maxOffHeapMem = 20000, + slaveEndpoint = null, + if (svcEnabled) Some(new JHashMap[BlockId, BlockStatus]) else None) + test(s"$testName externalShuffleServiceEnabled=$svcEnabled") { + f(svcEnabled, bmInfo) + } + } + } + + testWithShuffleServiceOnOff("broadcast block") { (_, bmInfo) => + val broadcastId: BlockId = BroadcastBlockId(0, "field1") + bmInfo.updateBlockInfo( + broadcastId, StorageLevel.MEMORY_AND_DISK, memSize = 200, diskSize = 100) + assert(bmInfo.blocks.asScala === + Map(broadcastId -> BlockStatus(StorageLevel.MEMORY_AND_DISK, 0, 100))) + assert(bmInfo.exclusiveCachedBlocks.isEmpty) + assert(bmInfo.remainingMem === 29800) + } + + testWithShuffleServiceOnOff("RDD block with MEMORY_ONLY") { (svcEnabled, bmInfo) => + val rddId: BlockId = RDDBlockId(0, 0) + bmInfo.updateBlockInfo(rddId, StorageLevel.MEMORY_ONLY, memSize = 200, diskSize = 0) + assert(bmInfo.blocks.asScala === + Map(rddId -> BlockStatus(StorageLevel.MEMORY_ONLY, 200, 0))) + assert(bmInfo.exclusiveCachedBlocks === Set(rddId)) + assert(bmInfo.remainingMem === 29800) + if (svcEnabled) { + assert(bmInfo.externalShuffleServiceBlockStatus.get.isEmpty) + } + } + + testWithShuffleServiceOnOff("RDD block with MEMORY_AND_DISK") { (svcEnabled, bmInfo) => + // This is the effective storage level, not the requested storage level, but MEMORY_AND_DISK + // is still possible if it's first in memory, purged to disk, and later promoted back to memory. + val rddId: BlockId = RDDBlockId(0, 0) + bmInfo.updateBlockInfo(rddId, StorageLevel.MEMORY_AND_DISK, memSize = 200, diskSize = 400) + assert(bmInfo.blocks.asScala === + Map(rddId -> BlockStatus(StorageLevel.MEMORY_AND_DISK, 0, 400))) + val exclusiveCachedBlocksForOneMemoryOnly = if (svcEnabled) Set() else Set(rddId) + assert(bmInfo.exclusiveCachedBlocks === exclusiveCachedBlocksForOneMemoryOnly) + assert(bmInfo.remainingMem === 29800) + if (svcEnabled) { + assert(bmInfo.externalShuffleServiceBlockStatus.get.asScala === + Map(rddId -> BlockStatus(StorageLevel.MEMORY_AND_DISK, 0, 400))) + } + } + + testWithShuffleServiceOnOff("RDD block with DISK_ONLY") { (svcEnabled, bmInfo) => + val rddId: BlockId = RDDBlockId(0, 0) + bmInfo.updateBlockInfo(rddId, StorageLevel.DISK_ONLY, memSize = 0, diskSize = 200) + assert(bmInfo.blocks.asScala === + Map(rddId -> BlockStatus(StorageLevel.DISK_ONLY, 0, 200))) + val exclusiveCachedBlocksForOneMemoryOnly = if (svcEnabled) Set() else Set(rddId) + assert(bmInfo.exclusiveCachedBlocks === exclusiveCachedBlocksForOneMemoryOnly) + assert(bmInfo.remainingMem === 30000) + if (svcEnabled) { + assert(bmInfo.externalShuffleServiceBlockStatus.get.asScala === + Map(rddId -> BlockStatus(StorageLevel.DISK_ONLY, 0, 200))) + } + } + + testWithShuffleServiceOnOff("update from MEMORY_ONLY to DISK_ONLY") { (svcEnabled, bmInfo) => + // This happens if MEMORY_AND_DISK is the requested storage level, but the block gets purged + // to disk under memory pressure. 
+ val rddId: BlockId = RDDBlockId(0, 0) + bmInfo.updateBlockInfo(rddId, StorageLevel.MEMORY_ONLY, memSize = 200, 0) + assert(bmInfo.blocks.asScala === Map(rddId -> BlockStatus(StorageLevel.MEMORY_ONLY, 200, 0))) + assert(bmInfo.exclusiveCachedBlocks === Set(rddId)) + assert(bmInfo.remainingMem === 29800) + if (svcEnabled) { + assert(bmInfo.externalShuffleServiceBlockStatus.get.isEmpty) + } + + bmInfo.updateBlockInfo(rddId, StorageLevel.DISK_ONLY, memSize = 0, diskSize = 200) + assert(bmInfo.blocks.asScala === Map(rddId -> BlockStatus(StorageLevel.DISK_ONLY, 0, 200))) + val exclusiveCachedBlocksForNoMemoryOnly = if (svcEnabled) Set() else Set(rddId) + assert(bmInfo.exclusiveCachedBlocks === exclusiveCachedBlocksForNoMemoryOnly) + assert(bmInfo.remainingMem === 30000) + if (svcEnabled) { + assert(bmInfo.externalShuffleServiceBlockStatus.get.asScala === + Map(rddId -> BlockStatus(StorageLevel.DISK_ONLY, 0, 200))) + } + } + + testWithShuffleServiceOnOff("using invalid StorageLevel") { (svcEnabled, bmInfo) => + val rddId: BlockId = RDDBlockId(0, 0) + bmInfo.updateBlockInfo(rddId, StorageLevel.DISK_ONLY, memSize = 0, diskSize = 200) + assert(bmInfo.blocks.asScala === Map(rddId -> BlockStatus(StorageLevel.DISK_ONLY, 0, 200))) + val exclusiveCachedBlocksForOneMemoryOnly = if (svcEnabled) Set() else Set(rddId) + assert(bmInfo.exclusiveCachedBlocks === exclusiveCachedBlocksForOneMemoryOnly) + assert(bmInfo.remainingMem === 30000) + if (svcEnabled) { + assert(bmInfo.externalShuffleServiceBlockStatus.get.asScala === + Map(rddId -> BlockStatus(StorageLevel.DISK_ONLY, 0, 200))) + } + + bmInfo.updateBlockInfo(rddId, StorageLevel.NONE, memSize = 0, diskSize = 200) + assert(bmInfo.blocks.isEmpty) + assert(bmInfo.exclusiveCachedBlocks.isEmpty) + assert(bmInfo.remainingMem === 30000) + if (svcEnabled) { + assert(bmInfo.externalShuffleServiceBlockStatus.get.isEmpty) + } + } + + testWithShuffleServiceOnOff("remove block") { (svcEnabled, bmInfo) => + val rddId: BlockId = RDDBlockId(0, 0) + bmInfo.updateBlockInfo(rddId, StorageLevel.DISK_ONLY, memSize = 0, diskSize = 200) + assert(bmInfo.blocks.asScala === Map(rddId -> BlockStatus(StorageLevel.DISK_ONLY, 0, 200))) + val exclusiveCachedBlocksForOneMemoryOnly = if (svcEnabled) Set() else Set(rddId) + assert(bmInfo.exclusiveCachedBlocks === exclusiveCachedBlocksForOneMemoryOnly) + assert(bmInfo.remainingMem === 30000) + if (svcEnabled) { + assert(bmInfo.externalShuffleServiceBlockStatus.get.asScala === + Map(rddId -> BlockStatus(StorageLevel.DISK_ONLY, 0, 200))) + } + + bmInfo.removeBlock(rddId) + assert(bmInfo.blocks.asScala.isEmpty) + assert(bmInfo.exclusiveCachedBlocks.isEmpty) + assert(bmInfo.remainingMem === 30000) + if (svcEnabled) { + assert(bmInfo.externalShuffleServiceBlockStatus.get.isEmpty) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index a739701853f6..3f1a14fc91ed 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -75,7 +75,7 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite val memManager = UnifiedMemoryManager(conf, numCores = 1) val serializerManager = new SerializerManager(serializer, conf) val store = new BlockManager(name, rpcEnv, master, serializerManager, conf, - memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) + memManager, 
mapOutputTracker, shuffleManager, transfer, securityMgr, None) memManager.setMemoryStore(store.memoryStore) store.initialize("app-id") allStores += store @@ -99,7 +99,7 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite sc = new SparkContext("local", "test", conf) master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, - new LiveListenerBus(conf))), conf, true) + new LiveListenerBus(conf), None)), conf, true) allStores.clear() } @@ -235,7 +235,7 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite val memManager = UnifiedMemoryManager(conf, numCores = 1) val serializerManager = new SerializerManager(serializer, conf) val failableStore = new BlockManager("failable-store", rpcEnv, master, serializerManager, conf, - memManager, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0) + memManager, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, None) memManager.setMemoryStore(failableStore.memoryStore) failableStore.initialize("app-id") allStores += failableStore // so that this gets stopped after test diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 9f3d8f291ede..59d58edc9dfe 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.storage +import java.io.File import java.nio.ByteBuffer import scala.collection.JavaConverters._ @@ -36,15 +37,16 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.DataReadMethod +import org.apache.spark.internal.config import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Tests._ import org.apache.spark.memory.UnifiedMemoryManager import org.apache.spark.network.{BlockDataManager, BlockTransferService, TransportContext} -import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, TransportServerBootstrap} -import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExternalShuffleClient} import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor} import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus @@ -110,8 +112,15 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE .getOrElse(new NettyBlockTransferService(conf, securityMgr, "localhost", "localhost", 0, 1)) val memManager = UnifiedMemoryManager(bmConf, numCores = 1) val serializerManager = new SerializerManager(serializer, bmConf) + val externalShuffleClient = if (conf.get(config.SHUFFLE_SERVICE_ENABLED)) { + val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", 0) + Some(new ExternalShuffleClient(transConf, bmSecurityMgr, + bmSecurityMgr.isAuthenticationEnabled(), conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT))) + } else { + None + } val blockManager = new 
BlockManager(name, rpcEnv, master, serializerManager, bmConf, - memManager, mapOutputTracker, shuffleManager, transfer, bmSecurityMgr, 0) + memManager, mapOutputTracker, shuffleManager, transfer, bmSecurityMgr, externalShuffleClient) memManager.setMemoryStore(blockManager.memoryStore) allStores += blockManager blockManager.initialize("app-id") @@ -134,7 +143,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE when(sc.conf).thenReturn(conf) master = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, - new LiveListenerBus(conf))), conf, true) + new LiveListenerBus(conf), None)), conf, true) val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() @@ -893,7 +902,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val serializerManager = new SerializerManager(new JavaSerializer(conf), conf) val store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master, serializerManager, conf, memoryManager, mapOutputTracker, - shuffleManager, transfer, securityMgr, 0) + shuffleManager, transfer, securityMgr, None) allStores += store store.initialize("app-id") @@ -942,7 +951,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val memoryManager = UnifiedMemoryManager(conf, numCores = 1) val blockManager = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master, serializerManager, conf, memoryManager, mapOutputTracker, - shuffleManager, transfer, securityMgr, 0) + shuffleManager, transfer, securityMgr, None) try { blockManager.initialize("app-id") testPutBlockDataAsStream(blockManager, storageLevel) @@ -1354,6 +1363,58 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(master.getLocations("item").isEmpty) } + test("SPARK-25888: serving of removed file not detected by shuffle service") { + // although the existence of the file is checked before serving it, a delete can happen + // somewhere after that check + val store = makeBlockManager(8000, "executor1") + val emptyBlockFetcher = new MockBlockTransferService(0) { + override def fetchBlockSync( + host: String, + port: Int, + execId: String, + blockId: String, + tempFileManager: DownloadFileManager): ManagedBuffer = { + val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 1) + // empty ManagedBuffer + new FileSegmentManagedBuffer(transConf, new File("missing.file"), 0, 0) + } + } + val store2 = makeBlockManager(8000, "executor2", this.master, Some(emptyBlockFetcher)) + store.putSingle("item", "value", StorageLevel.DISK_ONLY, tellMaster = true) + assert(master.getLocations("item").nonEmpty) + assert(store2.getRemoteBytes("item").isEmpty) + } + + test("test sorting of block locations") { + val localHost = "localhost" + val otherHost = "otherHost" + val store = makeBlockManager(8000, "executor1") + val externalShuffleServicePort = StorageUtils.externalShuffleServicePort(conf) + val port = store.blockTransferService.port + val rack = Some("rack") + val blockManagerWithTopologyInfo = BlockManagerId( + store.blockManagerId.executorId, + store.blockManagerId.host, + store.blockManagerId.port, + rack) + store.blockManagerId = blockManagerWithTopologyInfo + val locations = Seq( + BlockManagerId("executor4", otherHost, externalShuffleServicePort, rack), + BlockManagerId("executor3", otherHost, port, rack), + BlockManagerId("executor6", otherHost, externalShuffleServicePort), +
BlockManagerId("executor5", otherHost, port), + BlockManagerId("executor2", localHost, externalShuffleServicePort), + BlockManagerId("executor1", localHost, port)) + val sortedLocations = Seq( + BlockManagerId("executor1", localHost, port), + BlockManagerId("executor2", localHost, externalShuffleServicePort), + BlockManagerId("executor3", otherHost, port, rack), + BlockManagerId("executor4", otherHost, externalShuffleServicePort, rack), + BlockManagerId("executor5", otherHost, port), + BlockManagerId("executor6", otherHost, externalShuffleServicePort)) + assert(store.sortLocations(locations) === sortedLocations) + } + test("SPARK-20640: Shuffle registration timeout and maxAttempts conf are working") { val tryAgainMsg = "test_spark_20640_try_again" val timingoutExecutor = "timingoutExecutor" diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 15060614983a..c8f424af9af0 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -89,7 +89,7 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean) blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("blockmanager", new BlockManagerMasterEndpoint(rpcEnv, true, conf, - new LiveListenerBus(conf))), conf, true) + new LiveListenerBus(conf), None)), conf, true) storageLevel = StorageLevel.MEMORY_ONLY_SER blockManager = createBlockManager(blockManagerSize, conf) @@ -282,7 +282,7 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean) val memManager = new UnifiedMemoryManager(conf, maxMem, maxMem / 2, 1) val transfer = new NettyBlockTransferService(conf, securityMgr, "localhost", "localhost", 0, 1) val blockManager = new BlockManager(name, rpcEnv, blockManagerMaster, serializerManager, conf, - memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) + memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, None) memManager.setMemoryStore(blockManager.memoryStore) blockManager.initialize("app-id") blockManagerBuffer += blockManager