diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java
index be217522367c..736059fdd1f5 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java
@@ -89,4 +89,27 @@ public static String[] decode(ByteBuf buf) {
       return strings;
     }
   }
+
+  /** Integer arrays are encoded with their length followed by the integers themselves. */
+  public static class IntArrays {
+    public static int encodedLength(int[] ints) {
+      return 4 + 4 * ints.length;
+    }
+
+    public static void encode(ByteBuf buf, int[] ints) {
+      buf.writeInt(ints.length);
+      for (int i : ints) {
+        buf.writeInt(i);
+      }
+    }
+
+    public static int[] decode(ByteBuf buf) {
+      int numInts = buf.readInt();
+      int[] ints = new int[numInts];
+      for (int i = 0; i < ints.length; i++) {
+        ints[i] = buf.readInt();
+      }
+      return ints;
+    }
+  }
 }
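For reference, a minimal round-trip sketch (not part of the patch; the class name `IntArraysRoundTrip` is just for illustration) of the wire format the new `IntArrays` encoder defines: a 4-byte length prefix followed by one 4-byte int per element.

```java
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.apache.spark.network.protocol.Encoders;

public class IntArraysRoundTrip {
  public static void main(String[] args) {
    int[] ids = {7, 11, 42};
    // encodedLength = 4 (length prefix) + 4 * 3 (elements) = 16 bytes
    ByteBuf buf = Unpooled.buffer(Encoders.IntArrays.encodedLength(ids));
    Encoders.IntArrays.encode(buf, ids);
    int[] decoded = Encoders.IntArrays.decode(buf);
    assert decoded.length == 3 && decoded[2] == 42;
    buf.release();
  }
}
```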
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
index 3628da68f1c6..453bff23026d 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -341,4 +341,13 @@ public int chunkFetchHandlerThreads() {
     return (int) Math.ceil(threads * (chunkFetchHandlerThreadsPercent / 100.0));
   }
 
+  /**
+   * Whether to use the old protocol for shuffle block fetching. It should only be enabled
+   * for compatibility in the scenario where a job running on a new Spark version needs to
+   * fetch blocks from an old-version external shuffle service.
+   */
+  public boolean useOldFetchProtocol() {
+    return conf.getBoolean("spark.shuffle.useOldFetchProtocol", false);
+  }
+
 }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index ba9d657d0e56..cb2d01d4161e 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -92,22 +92,37 @@ protected void handleMessage(
       BlockTransferMessage msgObj,
       TransportClient client,
       RpcResponseCallback callback) {
-    if (msgObj instanceof OpenBlocks) {
+    if (msgObj instanceof FetchShuffleBlocks || msgObj instanceof OpenBlocks) {
       final Timer.Context responseDelayContext = metrics.openBlockRequestLatencyMillis.time();
       try {
-        OpenBlocks msg = (OpenBlocks) msgObj;
-        checkAuth(client, msg.appId);
-        long streamId = streamManager.registerStream(client.getClientId(),
-          new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds), client.getChannel());
+        int numBlockIds;
+        long streamId;
+        if (msgObj instanceof FetchShuffleBlocks) {
+          FetchShuffleBlocks msg = (FetchShuffleBlocks) msgObj;
+          checkAuth(client, msg.appId);
+          numBlockIds = 0;
+          for (int[] ids: msg.reduceIds) {
+            numBlockIds += ids.length;
+          }
+          streamId = streamManager.registerStream(client.getClientId(),
+            new ManagedBufferIterator(msg, numBlockIds), client.getChannel());
+        } else {
+          // Keep supporting OpenBlocks for compatibility with old Spark versions.
+          OpenBlocks msg = (OpenBlocks) msgObj;
+          numBlockIds = msg.blockIds.length;
+          checkAuth(client, msg.appId);
+          streamId = streamManager.registerStream(client.getClientId(),
+            new ManagedBufferIterator(msg), client.getChannel());
+        }
         if (logger.isTraceEnabled()) {
           logger.trace(
             "Registered streamId {} with {} buffers for client {} from host {}",
             streamId,
-            msg.blockIds.length,
+            numBlockIds,
             client.getClientId(),
             getRemoteAddress(client.getChannel()));
         }
-        callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer());
+        callback.onSuccess(new StreamHandle(streamId, numBlockIds).toByteBuffer());
       } finally {
         responseDelayContext.stop();
       }
@@ -224,7 +239,10 @@ private class ManagedBufferIterator implements Iterator<ManagedBuffer> {
     private final Function<Integer, ManagedBuffer> blockDataForIndexFn;
     private final int size;
 
-    ManagedBufferIterator(final String appId, final String execId, String[] blockIds) {
+    ManagedBufferIterator(OpenBlocks msg) {
+      String appId = msg.appId;
+      String execId = msg.execId;
+      String[] blockIds = msg.blockIds;
       String[] blockId0Parts = blockIds[0].split("_");
       if (blockId0Parts.length == 4 && blockId0Parts[0].equals("shuffle")) {
         final int shuffleId = Integer.parseInt(blockId0Parts[1]);
@@ -272,6 +290,21 @@ private int[] shuffleMapIdAndReduceIds(String[] blockIds, int shuffleId) {
       return mapIdAndReduceIds;
     }
 
+    ManagedBufferIterator(FetchShuffleBlocks msg, int numBlockIds) {
+      final int[] mapIdAndReduceIds = new int[2 * numBlockIds];
+      int idx = 0;
+      for (int i = 0; i < msg.mapIds.length; i++) {
+        for (int reduceId : msg.reduceIds[i]) {
+          mapIdAndReduceIds[idx++] = msg.mapIds[i];
+          mapIdAndReduceIds[idx++] = reduceId;
+        }
+      }
+      assert(idx == 2 * numBlockIds);
+      size = mapIdAndReduceIds.length;
+      blockDataForIndexFn = index -> blockManager.getBlockData(msg.appId, msg.execId,
+        msg.shuffleId, mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]);
+    }
+
     @Override
     public boolean hasNext() {
       return index < size;
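To make the new constructor's expansion concrete, here is an illustrative fragment (not part of the patch, with hypothetical ids) showing how a `FetchShuffleBlocks` request is flattened into the `(mapId, reduceId)` index array that the iterator then walks two ints at a time.

```java
public class FlattenExample {
  public static void main(String[] args) {
    // Hypothetical request: shuffleId=0, mapIds={5, 9}, reduceIds={{0, 1}, {3}}
    int[] mapIds = {5, 9};
    int[][] reduceIds = {{0, 1}, {3}};
    int numBlockIds = 3;  // 2 reduce ids for map 5, 1 for map 9
    int[] mapIdAndReduceIds = new int[2 * numBlockIds];
    int idx = 0;
    for (int i = 0; i < mapIds.length; i++) {
      for (int reduceId : reduceIds[i]) {
        mapIdAndReduceIds[idx++] = mapIds[i];
        mapIdAndReduceIds[idx++] = reduceId;
      }
    }
    // mapIdAndReduceIds is now {5, 0, 5, 1, 9, 3}: the pairs (5, 0), (5, 1)
    // and (9, 3), i.e. blocks shuffle_0_5_0, shuffle_0_5_1 and shuffle_0_9_3.
    System.out.println(java.util.Arrays.toString(mapIdAndReduceIds));
  }
}
```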
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
index 30587023877c..cc11e9206737 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -19,8 +19,11 @@
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashMap;
 
+import com.google.common.primitives.Ints;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -31,6 +34,7 @@
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.server.OneForOneStreamManager;
 import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks;
 import org.apache.spark.network.shuffle.protocol.OpenBlocks;
 import org.apache.spark.network.shuffle.protocol.StreamHandle;
 import org.apache.spark.network.util.TransportConf;
@@ -48,7 +52,7 @@ public class OneForOneBlockFetcher {
   private static final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class);
 
   private final TransportClient client;
-  private final OpenBlocks openMessage;
+  private final BlockTransferMessage message;
   private final String[] blockIds;
   private final BlockFetchingListener listener;
   private final ChunkReceivedCallback chunkCallback;
@@ -76,12 +80,71 @@ public OneForOneBlockFetcher(
       TransportConf transportConf,
       DownloadFileManager downloadFileManager) {
     this.client = client;
-    this.openMessage = new OpenBlocks(appId, execId, blockIds);
     this.blockIds = blockIds;
     this.listener = listener;
     this.chunkCallback = new ChunkCallback();
     this.transportConf = transportConf;
     this.downloadFileManager = downloadFileManager;
+    if (blockIds.length == 0) {
+      throw new IllegalArgumentException("Zero-sized blockIds array");
+    }
+    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+      this.message = createFetchShuffleBlocksMsg(appId, execId, blockIds);
+    } else {
+      this.message = new OpenBlocks(appId, execId, blockIds);
+    }
+  }
+
+  private boolean isShuffleBlocks(String[] blockIds) {
+    for (String blockId : blockIds) {
+      if (!blockId.startsWith("shuffle_")) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  /**
+   * Analyze the passed-in blockIds and create a FetchShuffleBlocks message.
+   * The blockIds have been sorted by mapId and reduceId. They are produced in
+   * org.apache.spark.MapOutputTracker.convertMapStatuses.
+   */
+  private FetchShuffleBlocks createFetchShuffleBlocksMsg(
+      String appId, String execId, String[] blockIds) {
+    int shuffleId = splitBlockId(blockIds[0])[0];
+    HashMap<Integer, ArrayList<Integer>> mapIdToReduceIds = new HashMap<>();
+    for (String blockId : blockIds) {
+      int[] blockIdParts = splitBlockId(blockId);
+      if (blockIdParts[0] != shuffleId) {
+        throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
+          ", got: " + blockId);
+      }
+      int mapId = blockIdParts[1];
+      if (!mapIdToReduceIds.containsKey(mapId)) {
+        mapIdToReduceIds.put(mapId, new ArrayList<>());
+      }
+      mapIdToReduceIds.get(mapId).add(blockIdParts[2]);
+    }
+    int[] mapIds = Ints.toArray(mapIdToReduceIds.keySet());
+    int[][] reduceIdArr = new int[mapIds.length][];
+    for (int i = 0; i < mapIds.length; i++) {
+      reduceIdArr[i] = Ints.toArray(mapIdToReduceIds.get(mapIds[i]));
+    }
+    return new FetchShuffleBlocks(appId, execId, shuffleId, mapIds, reduceIdArr);
+  }
+
+  /** Split the shuffleBlockId and return shuffleId, mapId and reduceId. */
+  private int[] splitBlockId(String blockId) {
+    String[] blockIdParts = blockId.split("_");
+    if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) {
+      throw new IllegalArgumentException(
+        "Unexpected shuffle block id format: " + blockId);
+    }
+    return new int[] {
+      Integer.parseInt(blockIdParts[1]),
+      Integer.parseInt(blockIdParts[2]),
+      Integer.parseInt(blockIdParts[3])
+    };
   }
 
   /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
@@ -106,11 +169,7 @@ public void onFailure(int chunkIndex, Throwable e) {
    * {@link StreamHandle}. We will send all fetch requests immediately, without throttling.
    */
   public void start() {
-    if (blockIds.length == 0) {
-      throw new IllegalArgumentException("Zero-sized blockIds array");
-    }
-
-    client.sendRpc(openMessage.toByteBuffer(), new RpcResponseCallback() {
+    client.sendRpc(message.toByteBuffer(), new RpcResponseCallback() {
       @Override
       public void onSuccess(ByteBuffer response) {
         try {
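The net effect of `createFetchShuffleBlocksMsg` is that a sorted run of shuffle block ids collapses into one compact message. An illustrative fragment (hypothetical app/exec ids):

```java
String[] blockIds = {"shuffle_2_0_0", "shuffle_2_0_1", "shuffle_2_3_1"};
// After grouping by mapId, the resulting message is equivalent to:
FetchShuffleBlocks msg = new FetchShuffleBlocks(
  "app-id", "exec-id", /* shuffleId = */ 2,
  new int[] {0, 3},            // distinct map ids
  new int[][] {{0, 1}, {1}});  // reduce ids per map id, in the same order
// whereas OpenBlocks would carry all three id strings verbatim.
```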
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
index e09775644963..29d7edc25813 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
@@ -31,11 +31,13 @@
  * by Spark's NettyBlockTransferService.
  *
  * At a high level:
- * - OpenBlock is handled by both services, but only services shuffle files for the external
- *   shuffle service. It returns a StreamHandle.
+ * - OpenBlock is logically only handled by the NettyBlockTransferService, but the external
+ *   shuffle service still supports it for compatibility with old Spark versions.
+ *   It returns a StreamHandle.
 * - UploadBlock is only handled by the NettyBlockTransferService.
 * - RegisterExecutor is only handled by the external shuffle service.
 * - RemoveBlocks is only handled by the external shuffle service.
+ * - FetchShuffleBlocks is handled by both services for shuffle files. It returns a StreamHandle.
 */
 public abstract class BlockTransferMessage implements Encodable {
   protected abstract Type type();
@@ -43,7 +45,8 @@ public abstract class BlockTransferMessage implements Encodable {
   /** Preceding every serialized message is its type, which allows us to deserialize it. */
   public enum Type {
     OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4),
-    HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), REMOVE_BLOCKS(7), BLOCKS_REMOVED(8);
+    HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), REMOVE_BLOCKS(7), BLOCKS_REMOVED(8),
+    FETCH_SHUFFLE_BLOCKS(9);
 
     private final byte id;
 
@@ -71,6 +74,7 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) {
       case 6: return UploadBlockStream.decode(buf);
       case 7: return RemoveBlocks.decode(buf);
       case 8: return BlocksRemoved.decode(buf);
+      case 9: return FetchShuffleBlocks.decode(buf);
       default: throw new IllegalArgumentException("Unknown message type: " + type);
     }
   }
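Concretely, the one-byte type tag written by `toByteBuffer()` is what lets `fromByteBuffer` dispatch to the right decoder. A round-trip sketch, assuming the static `Decoder` entry point that the serialization test below exercises:

```java
import java.nio.ByteBuffer;
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks;

FetchShuffleBlocks msg = new FetchShuffleBlocks(
  "app-1", "exec-2", 0, new int[] {0}, new int[][] {{0, 1}});
ByteBuffer wire = msg.toByteBuffer();  // first byte on the wire is the type id, 9
BlockTransferMessage decoded = BlockTransferMessage.Decoder.fromByteBuffer(wire);
assert decoded.equals(msg);            // relies on FetchShuffleBlocks.equals below
```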
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java
new file mode 100644
index 000000000000..466eeb3e048a
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import java.util.Arrays;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encoders;
+
+// Needed by ScalaDoc. See SPARK-7726
+import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
+
+/** Request to read a set of blocks. Returns {@link StreamHandle}. */
+public class FetchShuffleBlocks extends BlockTransferMessage {
+  public final String appId;
+  public final String execId;
+  public final int shuffleId;
+  // The length of mapIds must equal the length of reduceIds: the i-th mapId in mapIds
+  // corresponds to the i-th int[] in reduceIds, which contains all reduce ids for that map id.
+  public final int[] mapIds;
+  public final int[][] reduceIds;
+
+  public FetchShuffleBlocks(
+      String appId,
+      String execId,
+      int shuffleId,
+      int[] mapIds,
+      int[][] reduceIds) {
+    this.appId = appId;
+    this.execId = execId;
+    this.shuffleId = shuffleId;
+    this.mapIds = mapIds;
+    this.reduceIds = reduceIds;
+    assert(mapIds.length == reduceIds.length);
+  }
+
+  @Override
+  protected Type type() { return Type.FETCH_SHUFFLE_BLOCKS; }
+
+  @Override
+  public String toString() {
+    return Objects.toStringHelper(this)
+      .add("appId", appId)
+      .add("execId", execId)
+      .add("shuffleId", shuffleId)
+      .add("mapIds", Arrays.toString(mapIds))
+      .add("reduceIds", Arrays.deepToString(reduceIds))
+      .toString();
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) return true;
+    if (o == null || getClass() != o.getClass()) return false;
+
+    FetchShuffleBlocks that = (FetchShuffleBlocks) o;
+
+    if (shuffleId != that.shuffleId) return false;
+    if (!appId.equals(that.appId)) return false;
+    if (!execId.equals(that.execId)) return false;
+    if (!Arrays.equals(mapIds, that.mapIds)) return false;
+    return Arrays.deepEquals(reduceIds, that.reduceIds);
+  }
+
+  @Override
+  public int hashCode() {
+    int result = appId.hashCode();
+    result = 31 * result + execId.hashCode();
+    result = 31 * result + shuffleId;
+    result = 31 * result + Arrays.hashCode(mapIds);
+    result = 31 * result + Arrays.deepHashCode(reduceIds);
+    return result;
+  }
+
+  @Override
+  public int encodedLength() {
+    int encodedLengthOfReduceIds = 0;
+    for (int[] ids: reduceIds) {
+      encodedLengthOfReduceIds += Encoders.IntArrays.encodedLength(ids);
+    }
+    return Encoders.Strings.encodedLength(appId)
+      + Encoders.Strings.encodedLength(execId)
+      + 4 /* encoded length of shuffleId */
+      + Encoders.IntArrays.encodedLength(mapIds)
+      + 4 /* encoded length of reduceIds.length */
+      + encodedLengthOfReduceIds;
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    Encoders.Strings.encode(buf, appId);
+    Encoders.Strings.encode(buf, execId);
+    buf.writeInt(shuffleId);
+    Encoders.IntArrays.encode(buf, mapIds);
+    buf.writeInt(reduceIds.length);
+    for (int[] ids: reduceIds) {
+      Encoders.IntArrays.encode(buf, ids);
+    }
+  }
+
+  public static FetchShuffleBlocks decode(ByteBuf buf) {
+    String appId = Encoders.Strings.decode(buf);
+    String execId = Encoders.Strings.decode(buf);
+    int shuffleId = buf.readInt();
+    int[] mapIds = Encoders.IntArrays.decode(buf);
+    int reduceIdsSize = buf.readInt();
+    int[][] reduceIds = new int[reduceIdsSize][];
+    for (int i = 0; i < reduceIdsSize; i++) {
+      reduceIds[i] = Encoders.IntArrays.decode(buf);
+    }
+    return new FetchShuffleBlocks(appId, execId, shuffleId, mapIds, reduceIds);
+  }
+}
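As a worked example of `encodedLength()` (assuming the `Strings` encoder writes a 4-byte length followed by the UTF-8 bytes), the message used by the serialization test below occupies 67 bytes, plus the one-byte type tag once framed by `toByteBuffer()`:

```java
// new FetchShuffleBlocks("app-1", "exec-2", 0,
//     new int[] {0, 1}, new int[][] {{0, 1}, {0, 1, 2}})
//
// appId  "app-1"          4 + 5  =  9 bytes
// execId "exec-2"         4 + 6  = 10 bytes
// shuffleId                         4 bytes
// mapIds {0, 1}           4 + 8  = 12 bytes
// reduceIds.length                  4 bytes
// reduceIds[0] {0, 1}     4 + 8  = 12 bytes
// reduceIds[1] {0, 1, 2}  4 + 12 = 16 bytes
//                         ------------------
// encodedLength()                  67 bytes
```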
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
index 86c8609e7070..649c471dc167 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
@@ -28,6 +28,9 @@ public class BlockTransferMessagesSuite {
   @Test
   public void serializeOpenShuffleBlocks() {
     checkSerializeDeserialize(new OpenBlocks("app-1", "exec-2", new String[] { "b1", "b2" }));
+    checkSerializeDeserialize(new FetchShuffleBlocks(
+      "app-1", "exec-2", 0, new int[] {0, 1},
+      new int[][] {{ 0, 1 }, { 0, 1, 2 }}));
     checkSerializeDeserialize(new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo(
       new String[] { "/local1", "/local2" }, 32, "MyShuffleManager")));
     checkSerializeDeserialize(new UploadBlock("app-1", "exec-2", "block-3", new byte[] { 1, 2 },
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
index d51e14a66faf..3d30fd02a8ca 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
@@ -38,6 +38,7 @@
 import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
 import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
+import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks;
 import org.apache.spark.network.shuffle.protocol.OpenBlocks;
 import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
 import org.apache.spark.network.shuffle.protocol.StreamHandle;
@@ -81,11 +82,27 @@ public void testRegisterExecutor() {
   }
 
   @Test
-  public void testOpenShuffleBlocks() {
+  public void testCompatibilityWithOldVersion() {
     when(blockResolver.getBlockData("app0", "exec1", 0, 0, 0)).thenReturn(blockMarkers[0]);
     when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1)).thenReturn(blockMarkers[1]);
 
-    checkOpenBlocksReceive(new String[] { "shuffle_0_0_0", "shuffle_0_0_1" }, blockMarkers);
+    OpenBlocks openBlocks = new OpenBlocks(
+      "app0", "exec1", new String[] { "shuffle_0_0_0", "shuffle_0_0_1" });
+    checkOpenBlocksReceive(openBlocks, blockMarkers);
+
+    verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0);
+    verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 1);
+    verifyOpenBlockLatencyMetrics();
+  }
+
+  @Test
+  public void testFetchShuffleBlocks() {
+    when(blockResolver.getBlockData("app0", "exec1", 0, 0, 0)).thenReturn(blockMarkers[0]);
+    when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1)).thenReturn(blockMarkers[1]);
+
+    FetchShuffleBlocks fetchShuffleBlocks = new FetchShuffleBlocks(
+      "app0", "exec1", 0, new int[] { 0 }, new int[][] {{ 0, 1 }});
+    checkOpenBlocksReceive(fetchShuffleBlocks, blockMarkers);
 
     verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0);
     verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 1);
@@ -97,7 +114,9 @@ public void testOpenDiskPersistedRDDBlocks() {
     when(blockResolver.getRddBlockData("app0", "exec1", 0, 0)).thenReturn(blockMarkers[0]);
     when(blockResolver.getRddBlockData("app0", "exec1", 0, 1)).thenReturn(blockMarkers[1]);
 
-    checkOpenBlocksReceive(new String[] { "rdd_0_0", "rdd_0_1" }, blockMarkers);
+    OpenBlocks openBlocks = new OpenBlocks(
+      "app0", "exec1", new String[] { "rdd_0_0", "rdd_0_1" });
+    checkOpenBlocksReceive(openBlocks, blockMarkers);
 
     verify(blockResolver, times(1)).getRddBlockData("app0", "exec1", 0, 0);
     verify(blockResolver, times(1)).getRddBlockData("app0", "exec1", 0, 1);
@@ -115,18 +134,19 @@ public void testOpenDiskPersistedRDDBlocksWithMissingBlock() {
     when(blockResolver.getRddBlockData("app0", "exec1", 0, 1))
       .thenReturn(null);
 
-    checkOpenBlocksReceive(new String[] { "rdd_0_0", "rdd_0_1" }, blockMarkersWithMissingBlock);
+    OpenBlocks openBlocks = new OpenBlocks(
+      "app0", "exec1", new String[] { "rdd_0_0", "rdd_0_1" });
+    checkOpenBlocksReceive(openBlocks, blockMarkersWithMissingBlock);
 
     verify(blockResolver, times(1)).getRddBlockData("app0", "exec1", 0, 0);
     verify(blockResolver, times(1)).getRddBlockData("app0", "exec1", 0, 1);
   }
 
-  private void checkOpenBlocksReceive(String[] blockIds, ManagedBuffer[] blockMarkers) {
+  private void checkOpenBlocksReceive(BlockTransferMessage msg, ManagedBuffer[] blockMarkers) {
     when(client.getClientId()).thenReturn("app0");
 
     RpcResponseCallback callback = mock(RpcResponseCallback.class);
-    ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", blockIds).toByteBuffer();
-    handler.receive(client, openBlocks, callback);
+    handler.receive(client, msg.toByteBuffer(), callback);
 
     ArgumentCaptor<ByteBuffer> response = ArgumentCaptor.forClass(ByteBuffer.class);
     verify(callback, times(1)).onSuccess(response.capture());
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
index 95460637db89..66633cc7a359 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
@@ -18,6 +18,7 @@
 package org.apache.spark.network.shuffle;
 
 import java.nio.ByteBuffer;
+import java.util.HashMap;
 import java.util.Iterator;
 import java.util.LinkedHashMap;
 import java.util.concurrent.atomic.AtomicInteger;
@@ -44,6 +45,7 @@
 import org.apache.spark.network.client.RpcResponseCallback;
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks;
 import org.apache.spark.network.shuffle.protocol.OpenBlocks;
 import org.apache.spark.network.shuffle.protocol.StreamHandle;
 import org.apache.spark.network.util.MapConfigProvider;
@@ -57,20 +59,69 @@ public class OneForOneBlockFetcherSuite {
   public void testFetchOne() {
     LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
     blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
+    String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
+
+    BlockFetchingListener listener = fetchBlocks(
+      blocks,
+      blockIds,
+      new FetchShuffleBlocks("app-id", "exec-id", 0, new int[] { 0 }, new int[][] {{ 0 }}),
+      conf);
+
+    verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0"));
+  }
+
+  @Test
+  public void testUseOldProtocol() {
+    LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
+    blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
+    String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
 
-    BlockFetchingListener listener = fetchBlocks(blocks);
+    BlockFetchingListener listener = fetchBlocks(
+      blocks,
+      blockIds,
+      new OpenBlocks("app-id", "exec-id", blockIds),
+      new TransportConf("shuffle", new MapConfigProvider(
+        new HashMap<String, String>() {{
+          put("spark.shuffle.useOldFetchProtocol", "true");
+        }}
+      )));
 
     verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0"));
   }
+  @Test
+  public void testFetchThreeShuffleBlocks() {
+    LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
+    blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
+    blocks.put("shuffle_0_0_1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23])));
+    blocks.put("shuffle_0_0_2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
+    String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
+
+    BlockFetchingListener listener = fetchBlocks(
+      blocks,
+      blockIds,
+      new FetchShuffleBlocks("app-id", "exec-id", 0, new int[] { 0 }, new int[][] {{ 0, 1, 2 }}),
+      conf);
+
+    for (int i = 0; i < 3; i++) {
+      verify(listener, times(1)).onBlockFetchSuccess(
+        "shuffle_0_0_" + i, blocks.get("shuffle_0_0_" + i));
+    }
+  }
+
   @Test
   public void testFetchThree() {
     LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
     blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
     blocks.put("b1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23])));
     blocks.put("b2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
+    String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
 
-    BlockFetchingListener listener = fetchBlocks(blocks);
+    BlockFetchingListener listener = fetchBlocks(
+      blocks,
+      blockIds,
+      new OpenBlocks("app-id", "exec-id", blockIds),
+      conf);
 
     for (int i = 0; i < 3; i++) {
       verify(listener, times(1)).onBlockFetchSuccess("b" + i, blocks.get("b" + i));
@@ -83,8 +134,13 @@ public void testFailure() {
     blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
     blocks.put("b1", null);
     blocks.put("b2", null);
+    String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
 
-    BlockFetchingListener listener = fetchBlocks(blocks);
+    BlockFetchingListener listener = fetchBlocks(
+      blocks,
+      blockIds,
+      new OpenBlocks("app-id", "exec-id", blockIds),
+      conf);
 
     // Each failure will cause a failure to be invoked in all remaining block fetches.
     verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0"));
@@ -98,8 +154,13 @@ public void testFailureAndSuccess() {
     blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
     blocks.put("b1", null);
     blocks.put("b2", new NioManagedBuffer(ByteBuffer.wrap(new byte[21])));
+    String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
 
-    BlockFetchingListener listener = fetchBlocks(blocks);
+    BlockFetchingListener listener = fetchBlocks(
+      blocks,
+      blockIds,
+      new OpenBlocks("app-id", "exec-id", blockIds),
+      conf);
 
     // We may call both success and failure for the same block.
     verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0"));
@@ -111,7 +172,11 @@ public void testFailureAndSuccess() {
   @Test
   public void testEmptyBlockFetch() {
     try {
-      fetchBlocks(Maps.newLinkedHashMap());
+      fetchBlocks(
+        Maps.newLinkedHashMap(),
+        new String[] {},
+        new OpenBlocks("app-id", "exec-id", new String[] {}),
+        conf);
       fail();
     } catch (IllegalArgumentException e) {
       assertEquals("Zero-sized blockIds array", e.getMessage());
@@ -126,12 +191,15 @@ public void testEmptyBlockFetch() {
    *
    * If a block's buffer is "null", an exception will be thrown instead.
    */
-  private static BlockFetchingListener fetchBlocks(LinkedHashMap<String, ManagedBuffer> blocks) {
+  private static BlockFetchingListener fetchBlocks(
+      LinkedHashMap<String, ManagedBuffer> blocks,
+      String[] blockIds,
+      BlockTransferMessage expectMessage,
+      TransportConf transportConf) {
     TransportClient client = mock(TransportClient.class);
     BlockFetchingListener listener = mock(BlockFetchingListener.class);
-    String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
     OneForOneBlockFetcher fetcher =
-      new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener, conf);
+      new OneForOneBlockFetcher(client, "app-id", "exec-id", blockIds, listener, transportConf);
 
     // Respond to the "OpenBlocks" message with an appropriate ShuffleStreamHandle with streamId 123
     doAnswer(invocationOnMock -> {
@@ -139,7 +207,7 @@ private static BlockFetchingListener fetchBlocks(LinkedHashMap<String, ManagedBuffer> blocks)
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
+      case fetchShuffleBlocks: FetchShuffleBlocks =>
+        val blocks = fetchShuffleBlocks.mapIds.zipWithIndex.flatMap { case (mapId, index) =>
+          fetchShuffleBlocks.reduceIds.apply(index).map { reduceId =>
+            blockManager.getBlockData(
+              ShuffleBlockId(fetchShuffleBlocks.shuffleId, mapId, reduceId))
+          }
+        }
+        val numBlockIds = fetchShuffleBlocks.reduceIds.map(_.length).sum
+        val streamId = streamManager.registerStream(appId, blocks.iterator.asJava,
+          client.getChannel)
+        logTrace(s"Registered streamId $streamId with $numBlockIds buffers")
+        responseContext.onSuccess(
+          new StreamHandle(streamId, numBlockIds).toByteBuffer)
+
       case uploadBlock: UploadBlock =>
         // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer.
         val (level, classTag) = deserializeMetadata(uploadBlock.metadata)
diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md
index 5dc7a8951497..2de0a1d9290d 100644
--- a/docs/sql-migration-guide-upgrade.md
+++ b/docs/sql-migration-guide-upgrade.md
@@ -136,6 +136,8 @@ license: |
 
   - Since Spark 3.0, when Avro files are written with user provided non-nullable schema, even the catalyst schema is nullable, Spark is still able to write the files. However, Spark will throw runtime NPE if any of the records contains null.
 
+  - Since Spark 3.0, Spark uses a new protocol for fetching shuffle blocks. External shuffle service users need to upgrade the server accordingly; otherwise jobs fail with `UnsupportedOperationException: Unexpected message: FetchShuffleBlocks`. If upgrading the external shuffle service right away is hard, you can keep using the old protocol by setting `spark.shuffle.useOldFetchProtocol` to `true`.
+
 ## Upgrading from Spark SQL 2.4 to 2.4.1
 
   - The value of `spark.executor.heartbeatInterval`, when specified without units like "30" rather than "30s", was
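For users who need the fallback described in the migration note, the flag is an ordinary Spark conf. A minimal sketch, mirroring the testUseOldProtocol setup above, of how it reaches the fetcher:

```java
import java.util.Collections;
import org.apache.spark.network.util.MapConfigProvider;
import org.apache.spark.network.util.TransportConf;

TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(
  Collections.singletonMap("spark.shuffle.useOldFetchProtocol", "true")));
// With the flag set, OneForOneBlockFetcher builds an OpenBlocks message even
// for shuffle blocks, so an old external shuffle service can still serve them.
assert conf.useOldFetchProtocol();
```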