diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml
index 9d5bc9aae0719..d328a7de0a762 100644
--- a/common/network-common/pom.xml
+++ b/common/network-common/pom.xml
@@ -91,6 +91,10 @@
     <dependency>
       <groupId>org.apache.commons</groupId>
       <artifactId>commons-crypto</artifactId>
     </dependency>
+    <dependency>
+      <groupId>org.roaringbitmap</groupId>
+      <artifactId>RoaringBitmap</artifactId>
+    </dependency>
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java
index 490915f6de4b3..4fa191b3917e3 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java
@@ -17,9 +17,11 @@
package org.apache.spark.network.protocol;
+import java.io.IOException;
import java.nio.charset.StandardCharsets;
import io.netty.buffer.ByteBuf;
+import org.roaringbitmap.RoaringBitmap;
/** Provides a canonical set of Encoders for simple types. */
public class Encoders {
@@ -44,6 +46,40 @@ public static String decode(ByteBuf buf) {
}
}
+ /** Bitmaps are encoded with their serialization bytes; the RoaringBitmap format is self-describing. */
+ public static class Bitmaps {
+ public static int encodedLength(RoaringBitmap b) {
+ // Compress the bitmap before serializing it. Note that since BlockTransferMessage
+ // needs to invoke encodedLength first to figure out the length for the ByteBuf, it
+ // guarantees that the bitmap will always be compressed before being serialized.
+ b.trim();
+ b.runOptimize();
+ return b.serializedSizeInBytes();
+ }
+
+ public static void encode(ByteBuf buf, RoaringBitmap b) {
+ int encodedLength = b.serializedSizeInBytes();
+ // RoaringBitmap requires nio ByteBuffer for serde. We expose the netty ByteBuf as a nio
+ // ByteBuffer. Here, we need to explicitly manage the index so we can write into the
+ // ByteBuffer, and the write is reflected in the underlying ByteBuf.
+ b.serialize(buf.nioBuffer(buf.writerIndex(), encodedLength));
+ buf.writerIndex(buf.writerIndex() + encodedLength);
+ }
+
+ public static RoaringBitmap decode(ByteBuf buf) {
+ RoaringBitmap bitmap = new RoaringBitmap();
+ try {
+ bitmap.deserialize(buf.nioBuffer());
+ // RoaringBitmap deserialize does not advance the reader index of the underlying ByteBuf.
+ // Manually update the index here.
+ buf.readerIndex(buf.readerIndex() + bitmap.serializedSizeInBytes());
+ } catch (IOException e) {
+ throw new RuntimeException("Exception while decoding bitmap", e);
+ }
+ return bitmap;
+ }
+ }
+
/** Byte arrays are encoded with their length followed by bytes. */
public static class ByteArrays {
public static int encodedLength(byte[] arr) {
@@ -135,4 +171,31 @@ public static long[] decode(ByteBuf buf) {
return longs;
}
}
+
+ /** Bitmap arrays are encoded with the number of bitmaps followed by per-Bitmap encoding. */
+ public static class BitmapArrays {
+ public static int encodedLength(RoaringBitmap[] bitmaps) {
+ int totalLength = 4;
+ for (RoaringBitmap b : bitmaps) {
+ totalLength += Bitmaps.encodedLength(b);
+ }
+ return totalLength;
+ }
+
+ public static void encode(ByteBuf buf, RoaringBitmap[] bitmaps) {
+ buf.writeInt(bitmaps.length);
+ for (RoaringBitmap b : bitmaps) {
+ Bitmaps.encode(buf, b);
+ }
+ }
+
+ public static RoaringBitmap[] decode(ByteBuf buf) {
+ int numBitmaps = buf.readInt();
+ RoaringBitmap[] bitmaps = new RoaringBitmap[numBitmaps];
+ for (int i = 0; i < bitmaps.length; i++) {
+ bitmaps[i] = Bitmaps.decode(buf);
+ }
+ return bitmaps;
+ }
+ }
}
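
For illustration, a minimal round trip through the new Bitmaps encoder could look like the sketch below (not part of the patch; the class name BitmapRoundTrip is hypothetical). It highlights the ordering constraint called out in the comments: encodedLength() must run before encode(), so the bitmap is trimmed and run-optimized before serialization.

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.roaringbitmap.RoaringBitmap;

import org.apache.spark.network.protocol.Encoders;

public class BitmapRoundTrip {
  public static void main(String[] args) {
    RoaringBitmap bitmap = RoaringBitmap.bitmapOf(1, 5, 42);

    // encodedLength() trims and run-optimizes the bitmap, so the size it
    // reports matches what encode() will subsequently write.
    int len = Encoders.Bitmaps.encodedLength(bitmap);
    ByteBuf buf = Unpooled.buffer(len);
    Encoders.Bitmaps.encode(buf, bitmap);

    // decode() advances the reader index past the serialized bytes.
    RoaringBitmap decoded = Encoders.Bitmaps.decode(buf);
    assert bitmap.equals(decoded);
    assert !buf.isReadable();
  }
}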
diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml
index 00f1defbb0093..a4a1ff92ef9a0 100644
--- a/common/network-shuffle/pom.xml
+++ b/common/network-shuffle/pom.xml
@@ -57,6 +57,10 @@
     <dependency>
       <groupId>com.google.guava</groupId>
       <artifactId>guava</artifactId>
     </dependency>
+    <dependency>
+      <groupId>org.roaringbitmap</groupId>
+      <artifactId>RoaringBitmap</artifactId>
+    </dependency>
@@ -93,6 +97,11 @@
       <artifactId>mockito-core</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>commons-io</groupId>
+      <artifactId>commons-io</artifactId>
+      <scope>test</scope>
+    </dependency>
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java
index e762bd2071632..37befcd4b67fa 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockStoreClient.java
@@ -29,6 +29,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;
@@ -135,4 +136,24 @@ public void onFailure(Throwable t) {
hostLocalDirsCompletable.completeExceptionally(e);
}
}
+
+ /**
+ * Push a sequence of shuffle blocks in a best-effort manner to a remote node asynchronously.
+ * These shuffle blocks, along with blocks pushed by other clients, will be merged into
+ * per-shuffle partition merged shuffle files on the destination node.
+ *
+ * @param host the host of the remote node.
+ * @param port the port of the remote node.
+ * @param blockIds block ids to be pushed
+ * @param buffers buffers to be pushed
+ * @param listener the listener to receive block push status.
+ */
+ public void pushBlocks(
+ String host,
+ int port,
+ String[] blockIds,
+ ManagedBuffer[] buffers,
+ BlockFetchingListener listener) {
+ throw new UnsupportedOperationException();
+ }
}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java
new file mode 100644
index 0000000000000..308b0b7a6b33b
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.net.ConnectException;
+
+import com.google.common.base.Throwables;
+
+/**
+ * Plugs into {@link RetryingBlockFetcher} to further control when an exception should be retried
+ * and logged.
+ * Note: {@link RetryingBlockFetcher} will delegate the exception to this handler only when:
+ * - the retry count is still below the max number of retries, and
+ * - the exception is an IOException.
+ */
+
+public interface ErrorHandler {
+
+ boolean shouldRetryError(Throwable t);
+
+ default boolean shouldLogError(Throwable t) {
+ return true;
+ }
+
+ /**
+ * A no-op error handler instance.
+ */
+ ErrorHandler NOOP_ERROR_HANDLER = t -> true;
+
+ /**
+ * The error handler for pushing shuffle blocks to remote shuffle services.
+ */
+ class BlockPushErrorHandler implements ErrorHandler {
+ /**
+ * String constant used for generating exception messages indicating a block to be merged
+ * arrives too late on the server side, and also for later checking such exceptions on the
+ * client side. When we get a block push failure because of the block arrives too late, we
+ * will not retry pushing the block nor log the exception on the client side.
+ */
+ public static final String TOO_LATE_MESSAGE_SUFFIX =
+ "received after merged shuffle is finalized";
+
+ /**
+ * String constant used for generating exception messages indicating the server couldn't
+ * append a block after all available attempts due to collision with other blocks belonging
+ * to the same shuffle partition, and also for later checking such exceptions on the client
+ * side. When we get a block push failure because the block couldn't be written for this
+ * reason, we will not log the exception on the client side.
+ */
+ public static final String BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX =
+ "Couldn't find an opportunity to write block";
+
+ @Override
+ public boolean shouldRetryError(Throwable t) {
+ // If it is a connection timeout or a connection-closed exception, there is no need to retry.
+ if (t.getCause() != null && t.getCause() instanceof ConnectException) {
+ return false;
+ }
+ // If the block is too late, there is no need to retry it
+ return !Throwables.getStackTraceAsString(t).contains(TOO_LATE_MESSAGE_SUFFIX);
+ }
+
+ @Override
+ public boolean shouldLogError(Throwable t) {
+ String errorStackTrace = Throwables.getStackTraceAsString(t);
+ return !errorStackTrace.contains(BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX) &&
+ !errorStackTrace.contains(TOO_LATE_MESSAGE_SUFFIX);
+ }
+ }
+}
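
To make the handler's contract concrete, here is a small sketch (not part of the patch; the class name ErrorHandlerDemo is hypothetical) of how the three push failure modes are classified; the ErrorHandlerSuite further below exercises the same rules.

import java.net.ConnectException;

import org.apache.spark.network.shuffle.ErrorHandler;

public class ErrorHandlerDemo {
  public static void main(String[] args) {
    ErrorHandler.BlockPushErrorHandler handler = new ErrorHandler.BlockPushErrorHandler();

    // A connection failure is fatal for the remaining pushes: never retried.
    assert !handler.shouldRetryError(new RuntimeException(new ConnectException()));

    // A "too late" failure is neither retried nor logged at error level.
    Throwable tooLate = new RuntimeException(
      "block " + ErrorHandler.BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX);
    assert !handler.shouldRetryError(tooLate);
    assert !handler.shouldLogError(tooLate);

    // A write collision is retriable, but too noisy to log at error level.
    Throwable collision = new RuntimeException(
      ErrorHandler.BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX);
    assert handler.shouldRetryError(collision);
    assert !handler.shouldLogError(collision);
  }
}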
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
index 33865a21ea914..321b25305c504 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
@@ -32,6 +32,7 @@
import com.codahale.metrics.Timer;
import com.codahale.metrics.Counter;
import com.google.common.annotations.VisibleForTesting;
+import org.apache.spark.network.client.StreamCallbackWithID;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -61,11 +62,21 @@ public class ExternalBlockHandler extends RpcHandler {
final ExternalShuffleBlockResolver blockManager;
private final OneForOneStreamManager streamManager;
private final ShuffleMetrics metrics;
+ private final MergedShuffleFileManager mergeManager;
public ExternalBlockHandler(TransportConf conf, File registeredExecutorFile)
throws IOException {
this(new OneForOneStreamManager(),
- new ExternalShuffleBlockResolver(conf, registeredExecutorFile));
+ new ExternalShuffleBlockResolver(conf, registeredExecutorFile),
+ new NoOpMergedShuffleFileManager());
+ }
+
+ public ExternalBlockHandler(
+ TransportConf conf,
+ File registeredExecutorFile,
+ MergedShuffleFileManager mergeManager) throws IOException {
+ this(new OneForOneStreamManager(),
+ new ExternalShuffleBlockResolver(conf, registeredExecutorFile), mergeManager);
}
@VisibleForTesting
@@ -78,9 +89,19 @@ public ExternalShuffleBlockResolver getBlockResolver() {
public ExternalBlockHandler(
OneForOneStreamManager streamManager,
ExternalShuffleBlockResolver blockManager) {
+ this(streamManager, blockManager, new NoOpMergedShuffleFileManager());
+ }
+
+ /** Enables mocking out the StreamManager, BlockManager, and MergeManager. */
+ @VisibleForTesting
+ public ExternalBlockHandler(
+ OneForOneStreamManager streamManager,
+ ExternalShuffleBlockResolver blockManager,
+ MergedShuffleFileManager mergeManager) {
this.metrics = new ShuffleMetrics();
this.streamManager = streamManager;
this.blockManager = blockManager;
+ this.mergeManager = mergeManager;
}
@Override
@@ -89,6 +110,21 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb
handleMessage(msgObj, client, callback);
}
+ @Override
+ public StreamCallbackWithID receiveStream(
+ TransportClient client,
+ ByteBuffer messageHeader,
+ RpcResponseCallback callback) {
+ BlockTransferMessage msgObj = BlockTransferMessage.Decoder.fromByteBuffer(messageHeader);
+ if (msgObj instanceof PushBlockStream) {
+ PushBlockStream message = (PushBlockStream) msgObj;
+ checkAuth(client, message.appId);
+ return mergeManager.receiveBlockDataAsStream(message);
+ } else {
+ throw new UnsupportedOperationException("Unexpected message with #receiveStream: " + msgObj);
+ }
+ }
+
protected void handleMessage(
BlockTransferMessage msgObj,
TransportClient client,
@@ -139,6 +175,7 @@ protected void handleMessage(
RegisterExecutor msg = (RegisterExecutor) msgObj;
checkAuth(client, msg.appId);
blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo);
+ mergeManager.registerExecutor(msg.appId, msg.executorInfo.localDirs);
callback.onSuccess(ByteBuffer.wrap(new byte[0]));
} finally {
responseDelayContext.stop();
@@ -156,6 +193,20 @@ protected void handleMessage(
Map<String, String[]> localDirs = blockManager.getLocalDirs(msg.appId, msg.execIds);
callback.onSuccess(new LocalDirsForExecutors(localDirs).toByteBuffer());
+ } else if (msgObj instanceof FinalizeShuffleMerge) {
+ final Timer.Context responseDelayContext =
+ metrics.finalizeShuffleMergeLatencyMillis.time();
+ FinalizeShuffleMerge msg = (FinalizeShuffleMerge) msgObj;
+ try {
+ checkAuth(client, msg.appId);
+ MergeStatuses statuses = mergeManager.finalizeShuffleMerge(msg);
+ callback.onSuccess(statuses.toByteBuffer());
+ } catch (IOException e) {
+ throw new RuntimeException(String.format("Error while finalizing shuffle merge "
+ + "for application %s shuffle %d", msg.appId, msg.shuffleId), e);
+ } finally {
+ responseDelayContext.stop();
+ }
} else {
throw new UnsupportedOperationException("Unexpected message: " + msgObj);
}
@@ -225,6 +276,8 @@ public class ShuffleMetrics implements MetricSet {
private final Timer openBlockRequestLatencyMillis = new Timer();
// Time latency for executor registration latency in ms
private final Timer registerExecutorRequestLatencyMillis = new Timer();
+ // Time latency for processing finalize shuffle merge request in ms
+ private final Timer finalizeShuffleMergeLatencyMillis = new Timer();
// Block transfer rate in byte per second
private final Meter blockTransferRateBytes = new Meter();
// Number of active connections to the shuffle service
@@ -236,6 +289,7 @@ public ShuffleMetrics() {
allMetrics = new HashMap<>();
allMetrics.put("openBlockRequestLatencyMillis", openBlockRequestLatencyMillis);
allMetrics.put("registerExecutorRequestLatencyMillis", registerExecutorRequestLatencyMillis);
+ allMetrics.put("finalizeShuffleMergeLatencyMillis", finalizeShuffleMergeLatencyMillis);
allMetrics.put("blockTransferRateBytes", blockTransferRateBytes);
allMetrics.put("registeredExecutorsSize",
(Gauge<Integer>) () -> blockManager.getRegisteredExecutorsSize());
@@ -373,6 +427,54 @@ public ManagedBuffer next() {
}
}
+ /**
+ * Dummy implementation of merged shuffle file manager, used when push-based shuffle
+ * is not enabled.
+ */
+ private static class NoOpMergedShuffleFileManager implements MergedShuffleFileManager {
+
+ @Override
+ public StreamCallbackWithID receiveBlockDataAsStream(PushBlockStream msg) {
+ throw new UnsupportedOperationException("Cannot handle shuffle block merge");
+ }
+
+ @Override
+ public MergeStatuses finalizeShuffleMerge(FinalizeShuffleMerge msg) throws IOException {
+ throw new UnsupportedOperationException("Cannot handle shuffle block merge");
+ }
+
+ @Override
+ public void registerApplication(String appId, String user) {
+ // No-op. Do nothing.
+ }
+
+ @Override
+ public void registerExecutor(String appId, String[] localDirs) {
+ // No-Op. Do nothing.
+ }
+
+ @Override
+ public void applicationRemoved(String appId, boolean cleanupLocalDirs) {
+ throw new UnsupportedOperationException("Cannot handle shuffle block merge");
+ }
+
+ @Override
+ public ManagedBuffer getMergedBlockData(
+ String appId, int shuffleId, int reduceId, int chunkId) {
+ throw new UnsupportedOperationException("Cannot handle shuffle block merge");
+ }
+
+ @Override
+ public MergedBlockMeta getMergedBlockMeta(String appId, int shuffleId, int reduceId) {
+ throw new UnsupportedOperationException("Cannot handle shuffle block merge");
+ }
+
+ @Override
+ public String[] getMergedBlockDirs(String appId) {
+ throw new UnsupportedOperationException("Cannot handle shuffle block merge");
+ }
+ }
+
@Override
public void channelActive(TransportClient client) {
metrics.activeConnections.inc();
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java
index 76e23e7c69d2d..eca35ed290467 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java
@@ -20,21 +20,24 @@
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;
import com.codahale.metrics.MetricSet;
import com.google.common.collect.Lists;
+
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.RpcResponseCallback;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;
-import org.apache.spark.network.shuffle.protocol.*;
-
-import org.apache.spark.network.TransportContext;
import org.apache.spark.network.crypto.AuthClientBootstrap;
import org.apache.spark.network.sasl.SecretKeyHolder;
import org.apache.spark.network.server.NoOpRpcHandler;
+import org.apache.spark.network.shuffle.protocol.*;
import org.apache.spark.network.util.TransportConf;
/**
@@ -43,6 +46,8 @@
* (via BlockTransferService), which has the downside of losing the data if we lose the executors.
*/
public class ExternalBlockStoreClient extends BlockStoreClient {
+ private static final ErrorHandler PUSH_ERROR_HANDLER = new ErrorHandler.BlockPushErrorHandler();
+
private final TransportConf conf;
private final boolean authEnabled;
private final SecretKeyHolder secretKeyHolder;
@@ -90,12 +95,12 @@ public void fetchBlocks(
try {
int maxRetries = conf.maxIORetries();
RetryingBlockFetcher.BlockFetchStarter blockFetchStarter =
- (blockIds1, listener1) -> {
+ (inputBlockId, inputListener) -> {
// Unless this client is closed.
if (clientFactory != null) {
TransportClient client = clientFactory.createClient(host, port, maxRetries > 0);
new OneForOneBlockFetcher(client, appId, execId,
- blockIds1, listener1, conf, downloadFileManager).start();
+ inputBlockId, inputListener, conf, downloadFileManager).start();
} else {
logger.info("This clientFactory was closed. Skipping further block fetch retries.");
}
@@ -116,6 +121,43 @@ public void fetchBlocks(
}
}
+ @Override
+ public void pushBlocks(
+ String host,
+ int port,
+ String[] blockIds,
+ ManagedBuffer[] buffers,
+ BlockFetchingListener listener) {
+ checkInit();
+ assert blockIds.length == buffers.length : "Number of block ids and buffers do not match.";
+
+ Map<String, ManagedBuffer> buffersWithId = new HashMap<>();
+ for (int i = 0; i < blockIds.length; i++) {
+ buffersWithId.put(blockIds[i], buffers[i]);
+ }
+ logger.debug("Push {} shuffle blocks to {}:{}", blockIds.length, host, port);
+ try {
+ RetryingBlockFetcher.BlockFetchStarter blockPushStarter =
+ (inputBlockId, inputListener) -> {
+ TransportClient client = clientFactory.createClient(host, port);
+ new OneForOneBlockPusher(client, appId, inputBlockId, inputListener, buffersWithId)
+ .start();
+ };
+ int maxRetries = conf.maxIORetries();
+ if (maxRetries > 0) {
+ new RetryingBlockFetcher(
+ conf, blockPushStarter, blockIds, listener, PUSH_ERROR_HANDLER).start();
+ } else {
+ blockPushStarter.createAndStart(blockIds, listener);
+ }
+ } catch (Exception e) {
+ logger.error("Exception while beginning pushBlocks", e);
+ for (String blockId : blockIds) {
+ listener.onBlockFetchFailure(blockId, e);
+ }
+ }
+ }
+
@Override
public MetricSet shuffleMetrics() {
checkInit();
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedBlockMeta.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedBlockMeta.java
new file mode 100644
index 0000000000000..e9d9e53495469
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedBlockMeta.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import com.google.common.base.Preconditions;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import org.roaringbitmap.RoaringBitmap;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.protocol.Encoders;
+
+/**
+ * Contains meta information for a merged block. Currently this information consists of:
+ * 1. Number of chunks in a merged shuffle block.
+ * 2. Bitmaps for each chunk in the merged block. A chunk bitmap contains all the mapIds that
+ * were merged into that merged block chunk.
+ */
+public class MergedBlockMeta {
+ private final int numChunks;
+ private final ManagedBuffer chunksBitmapBuffer;
+
+ public MergedBlockMeta(int numChunks, ManagedBuffer chunksBitmapBuffer) {
+ this.numChunks = numChunks;
+ this.chunksBitmapBuffer = Preconditions.checkNotNull(chunksBitmapBuffer);
+ }
+
+ public int getNumChunks() {
+ return numChunks;
+ }
+
+ public ManagedBuffer getChunksBitmapBuffer() {
+ return chunksBitmapBuffer;
+ }
+
+ public RoaringBitmap[] readChunkBitmaps() throws IOException {
+ ByteBuf buf = Unpooled.wrappedBuffer(chunksBitmapBuffer.nioByteBuffer());
+ List<RoaringBitmap> bitmaps = new ArrayList<>();
+ while (buf.isReadable()) {
+ bitmaps.add(Encoders.Bitmaps.decode(buf));
+ }
+ assert (bitmaps.size() == numChunks);
+ return bitmaps.toArray(new RoaringBitmap[0]);
+ }
+}
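
As a sketch of how a MergedBlockMeta is assembled and consumed (not part of the patch; the class name MergedBlockMetaDemo is hypothetical), the chunk bitmaps are simply concatenated Bitmaps encodings:

import java.io.IOException;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.roaringbitmap.RoaringBitmap;

import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.protocol.Encoders;
import org.apache.spark.network.shuffle.MergedBlockMeta;

public class MergedBlockMetaDemo {
  public static void main(String[] args) throws IOException {
    RoaringBitmap chunk0 = RoaringBitmap.bitmapOf(0, 1);  // mapIds merged into chunk 0
    RoaringBitmap chunk1 = RoaringBitmap.bitmapOf(2, 3);  // mapIds merged into chunk 1

    // Serialize the per-chunk bitmaps back to back; this is the layout
    // readChunkBitmaps() expects when it decodes until the buffer is exhausted.
    ByteBuf buf = Unpooled.buffer(
      Encoders.Bitmaps.encodedLength(chunk0) + Encoders.Bitmaps.encodedLength(chunk1));
    Encoders.Bitmaps.encode(buf, chunk0);
    Encoders.Bitmaps.encode(buf, chunk1);

    MergedBlockMeta meta = new MergedBlockMeta(2, new NioManagedBuffer(buf.nioBuffer()));
    RoaringBitmap[] bitmaps = meta.readChunkBitmaps();
    assert bitmaps.length == meta.getNumChunks();
    assert bitmaps[0].contains(1) && bitmaps[1].contains(3);
  }
}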
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedShuffleFileManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedShuffleFileManager.java
new file mode 100644
index 0000000000000..ef4dbb2bd0059
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedShuffleFileManager.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.io.IOException;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.client.StreamCallbackWithID;
+import org.apache.spark.network.shuffle.protocol.FinalizeShuffleMerge;
+import org.apache.spark.network.shuffle.protocol.MergeStatuses;
+import org.apache.spark.network.shuffle.protocol.PushBlockStream;
+
+
+/**
+ * The MergedShuffleFileManager is used to process push-based shuffle when it is enabled. It
+ * works alongside {@link ExternalBlockHandler} and serves as the handler for
+ * {@link org.apache.spark.network.server.RpcHandler#receiveStream}, where it processes the
+ * remotely pushed streams of shuffle blocks to merge them into merged shuffle files. Right
+ * now, support for push-based shuffle is only implemented for the external shuffle service in
+ * YARN mode.
+ */
+public interface MergedShuffleFileManager {
+ /**
+ * Provides the stream callback used to process a remotely pushed block. The callback is
+ * used by the {@link org.apache.spark.network.client.StreamInterceptor} installed on the
+ * channel to process the block data in the channel outside of the message frame.
+ *
+ * @param msg metadata of the remotely pushed blocks. This is processed inside the message frame
+ * @return A stream callback to process the block data in streaming fashion as it arrives
+ */
+ StreamCallbackWithID receiveBlockDataAsStream(PushBlockStream msg);
+
+ /**
+ * Handles the request to finalize shuffle merge for a given shuffle.
+ *
+ * @param msg contains appId and shuffleId to uniquely identify a shuffle to be finalized
+ * @return The statuses of the merged shuffle partitions for the given shuffle on this
+ * shuffle service
+ * @throws IOException if finalizing the shuffle merge fails on this shuffle service
+ */
+ MergeStatuses finalizeShuffleMerge(FinalizeShuffleMerge msg) throws IOException;
+
+ /**
+ * Registers an application when it starts. It also stores the username which is necessary
+ * for generating the host local directories for merged shuffle files.
+ * Right now, this is invoked by YarnShuffleService.
+ *
+ * @param appId application ID
+ * @param user username
+ */
+ void registerApplication(String appId, String user);
+
+ /**
+ * Registers an executor with its local dir list when it starts. This provides the specific path
+ * so MergedShuffleFileManager knows where to store and look for shuffle data for a
+ * given application. It is invoked by the RPC call when an executor tries to register with
+ * the local shuffle service.
+ *
+ * @param appId application ID
+ * @param localDirs The list of local dirs that this executor gets granted from NodeManager
+ */
+ void registerExecutor(String appId, String[] localDirs);
+
+ /**
+ * Invoked when an application finishes. This cleans up any remaining metadata associated with
+ * this application, and optionally deletes the application specific directory path.
+ *
+ * @param appId application ID
+ * @param cleanupLocalDirs flag indicating whether MergedShuffleFileManager should handle
+ * deletion of local dirs itself.
+ */
+ void applicationRemoved(String appId, boolean cleanupLocalDirs);
+
+ /**
+ * Get the buffer for a given merged shuffle chunk when serving merged shuffle to reducers.
+ *
+ * @param appId application ID
+ * @param shuffleId shuffle ID
+ * @param reduceId reducer ID
+ * @param chunkId merged shuffle file chunk ID
+ * @return The {@link ManagedBuffer} for the given merged shuffle chunk
+ */
+ ManagedBuffer getMergedBlockData(String appId, int shuffleId, int reduceId, int chunkId);
+
+ /**
+ * Get the meta information of a merged block.
+ *
+ * @param appId application ID
+ * @param shuffleId shuffle ID
+ * @param reduceId reducer ID
+ * @return meta information of a merged block
+ */
+ MergedBlockMeta getMergedBlockMeta(String appId, int shuffleId, int reduceId);
+
+ /**
+ * Get the local directories that store the merged shuffle files.
+ *
+ * @param appId application ID
+ * @return the local directories storing merged shuffle files for the given application
+ */
+ String[] getMergedBlockDirs(String appId);
+}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockPusher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockPusher.java
new file mode 100644
index 0000000000000..407b248170a46
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockPusher.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.Map;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.shuffle.protocol.PushBlockStream;
+
+/**
+ * Similar to {@link OneForOneBlockFetcher}, but for pushing blocks to a remote shuffle
+ * service to be merged instead of fetching them from remote shuffle services. This is used
+ * by the ShuffleWriter when the block push process is initiated. The supplied
+ * BlockFetchingListener is used to handle the success or failure of pushing each block.
+ */
+public class OneForOneBlockPusher {
+ private static final Logger logger = LoggerFactory.getLogger(OneForOneBlockPusher.class);
+ private static final ErrorHandler PUSH_ERROR_HANDLER = new ErrorHandler.BlockPushErrorHandler();
+
+ private final TransportClient client;
+ private final String appId;
+ private final String[] blockIds;
+ private final BlockFetchingListener listener;
+ private final Map<String, ManagedBuffer> buffers;
+
+ public OneForOneBlockPusher(
+ TransportClient client,
+ String appId,
+ String[] blockIds,
+ BlockFetchingListener listener,
+ Map<String, ManagedBuffer> buffers) {
+ this.client = client;
+ this.appId = appId;
+ this.blockIds = blockIds;
+ this.listener = listener;
+ this.buffers = buffers;
+ }
+
+ private class BlockPushCallback implements RpcResponseCallback {
+
+ private int index;
+ private String blockId;
+
+ BlockPushCallback(int index, String blockId) {
+ this.index = index;
+ this.blockId = blockId;
+ }
+
+ @Override
+ public void onSuccess(ByteBuffer response) {
+ // On receipt of a successful block push
+ listener.onBlockFetchSuccess(blockId, new NioManagedBuffer(ByteBuffer.allocate(0)));
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ // Since block push is best effort, i.e., if we encounter a block push failure that's not
+ // retriable or exceeds the max retries, we should not fail all remaining block pushes.
+ // The best-effort nature makes block push tolerant of partial completion. Thus, we only
+ // fail the block that has actually failed. Note that, on the RetryingBlockFetcher side,
+ // once a retry is initiated, it would still invalidate the previous active retry listener
+ // and retry all outstanding blocks. We are preventing forwarding unnecessary block push
+ // failures to the parent listener of the retry listener. The only exceptions are if the
+ // block push failure is due to the block arriving on the server side after merge
+ // finalization, or the client failing to establish a connection to the server. In both
+ // cases, we fail all remaining blocks.
+ if (PUSH_ERROR_HANDLER.shouldRetryError(e)) {
+ String[] targetBlockId = Arrays.copyOfRange(blockIds, index, index + 1);
+ failRemainingBlocks(targetBlockId, e);
+ } else {
+ String[] targetBlockId = Arrays.copyOfRange(blockIds, index, blockIds.length);
+ failRemainingBlocks(targetBlockId, e);
+ }
+ }
+ }
+
+ private void failRemainingBlocks(String[] failedBlockIds, Throwable e) {
+ for (String blockId : failedBlockIds) {
+ try {
+ listener.onBlockFetchFailure(blockId, e);
+ } catch (Exception e2) {
+ logger.error("Error in block push failure callback", e2);
+ }
+ }
+ }
+
+ /**
+ * Begins the block pushing process, calling the listener with every block pushed.
+ */
+ public void start() {
+ logger.debug("Start pushing {} blocks", blockIds.length);
+ for (int i = 0; i < blockIds.length; i++) {
+ assert buffers.containsKey(blockIds[i]) : "Could not find the block buffer for block "
+ + blockIds[i];
+ ByteBuffer header = new PushBlockStream(appId, blockIds[i], i).toByteBuffer();
+ client.uploadStream(new NioManagedBuffer(header), buffers.get(blockIds[i]),
+ new BlockPushCallback(i, blockIds[i]));
+ }
+ }
+}
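
A minimal sketch of driving the pusher, under the assumption that an established TransportClient to the remote shuffle service is available (the class name BlockPushExample and the block id value are illustrative):

import java.nio.ByteBuffer;
import java.util.LinkedHashMap;
import java.util.Map;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.shuffle.BlockFetchingListener;
import org.apache.spark.network.shuffle.OneForOneBlockPusher;

public class BlockPushExample {
  // `client` must be an established TransportClient to the remote shuffle service.
  static void pushOne(TransportClient client) {
    Map<String, ManagedBuffer> buffers = new LinkedHashMap<>();
    buffers.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[] {1, 2, 3})));
    String[] blockIds = buffers.keySet().toArray(new String[0]);

    BlockFetchingListener listener = new BlockFetchingListener() {
      @Override
      public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
        // Per-block success; the pusher passes an empty buffer as the payload.
      }
      @Override
      public void onBlockFetchFailure(String blockId, Throwable t) {
        // Best effort: only the failed block is reported here, unless the failure
        // is fatal (too-late or connection failure), which fails all remaining blocks.
      }
    };
    new OneForOneBlockPusher(client, "app-id", blockIds, listener, buffers).start();
  }
}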
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java
index 6bf3da94030d4..43bde1610e41e 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java
@@ -99,11 +99,14 @@ void createAndStart(String[] blockIds, BlockFetchingListener listener)
*/
private RetryingBlockFetchListener currentListener;
+ private final ErrorHandler errorHandler;
+
public RetryingBlockFetcher(
TransportConf conf,
RetryingBlockFetcher.BlockFetchStarter fetchStarter,
String[] blockIds,
- BlockFetchingListener listener) {
+ BlockFetchingListener listener,
+ ErrorHandler errorHandler) {
this.fetchStarter = fetchStarter;
this.listener = listener;
this.maxRetries = conf.maxIORetries();
@@ -111,6 +114,15 @@ public RetryingBlockFetcher(
this.outstandingBlocksIds = Sets.newLinkedHashSet();
Collections.addAll(outstandingBlocksIds, blockIds);
this.currentListener = new RetryingBlockFetchListener();
+ this.errorHandler = errorHandler;
+ }
+
+ public RetryingBlockFetcher(
+ TransportConf conf,
+ BlockFetchStarter fetchStarter,
+ String[] blockIds,
+ BlockFetchingListener listener) {
+ this(conf, fetchStarter, blockIds, listener, ErrorHandler.NOOP_ERROR_HANDLER);
}
/**
@@ -178,7 +190,7 @@ private synchronized boolean shouldRetry(Throwable e) {
boolean isIOException = e instanceof IOException
|| (e.getCause() != null && e.getCause() instanceof IOException);
boolean hasRemainingRetries = retryCount < maxRetries;
- return isIOException && hasRemainingRetries;
+ return isIOException && hasRemainingRetries && errorHandler.shouldRetryError(e);
}
/**
@@ -215,8 +227,15 @@ public void onBlockFetchFailure(String blockId, Throwable exception) {
if (shouldRetry(exception)) {
initiateRetry();
} else {
- logger.error(String.format("Failed to fetch block %s, and will not retry (%s retries)",
- blockId, retryCount), exception);
+ if (errorHandler.shouldLogError(exception)) {
+ logger.error(
+ String.format("Failed to fetch block %s, and will not retry (%s retries)",
+ blockId, retryCount), exception);
+ } else {
+ logger.debug(
+ String.format("Failed to fetch block %s, and will not retry (%s retries)",
+ blockId, retryCount), exception);
+ }
outstandingBlocksIds.remove(blockId);
shouldForwardFailure = true;
}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
index 89d8dfe8716b8..7f5058124988f 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java
@@ -47,7 +47,8 @@ public abstract class BlockTransferMessage implements Encodable {
public enum Type {
OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4),
HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), REMOVE_BLOCKS(7), BLOCKS_REMOVED(8),
- FETCH_SHUFFLE_BLOCKS(9), GET_LOCAL_DIRS_FOR_EXECUTORS(10), LOCAL_DIRS_FOR_EXECUTORS(11);
+ FETCH_SHUFFLE_BLOCKS(9), GET_LOCAL_DIRS_FOR_EXECUTORS(10), LOCAL_DIRS_FOR_EXECUTORS(11),
+ PUSH_BLOCK_STREAM(12), FINALIZE_SHUFFLE_MERGE(13), MERGE_STATUSES(14);
private final byte id;
@@ -78,6 +79,9 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) {
case 9: return FetchShuffleBlocks.decode(buf);
case 10: return GetLocalDirsForExecutors.decode(buf);
case 11: return LocalDirsForExecutors.decode(buf);
+ case 12: return PushBlockStream.decode(buf);
+ case 13: return FinalizeShuffleMerge.decode(buf);
+ case 14: return MergeStatuses.decode(buf);
default: throw new IllegalArgumentException("Unknown message type: " + type);
}
}
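
The new message types ride on the existing one-byte type-tag framing. A minimal round trip could look like this sketch (not part of the patch; the class name MessageRoundTrip is hypothetical):

import java.nio.ByteBuffer;

import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.PushBlockStream;

public class MessageRoundTrip {
  public static void main(String[] args) {
    PushBlockStream msg = new PushBlockStream("app-id", "shuffle_0_0_0", 0);

    // toByteBuffer() prepends the one-byte type tag (PUSH_BLOCK_STREAM = 12),
    // which the decoder uses to dispatch to PushBlockStream.decode().
    ByteBuffer wire = msg.toByteBuffer();
    BlockTransferMessage decoded = BlockTransferMessage.Decoder.fromByteBuffer(wire);
    assert msg.equals(decoded);
  }
}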
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FinalizeShuffleMerge.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FinalizeShuffleMerge.java
new file mode 100644
index 0000000000000..9058575df57ef
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FinalizeShuffleMerge.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encoders;
+
+/**
+ * Request to finalize merge for a given shuffle.
+ * Returns {@link MergeStatuses}
+ */
+public class FinalizeShuffleMerge extends BlockTransferMessage {
+ public final String appId;
+ public final int shuffleId;
+
+ public FinalizeShuffleMerge(
+ String appId,
+ int shuffleId) {
+ this.appId = appId;
+ this.shuffleId = shuffleId;
+ }
+
+ @Override
+ protected BlockTransferMessage.Type type() {
+ return Type.FINALIZE_SHUFFLE_MERGE;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(appId, shuffleId);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("appId", appId)
+ .add("shuffleId", shuffleId)
+ .toString();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other != null && other instanceof FinalizeShuffleMerge) {
+ FinalizeShuffleMerge o = (FinalizeShuffleMerge) other;
+ return Objects.equal(appId, o.appId)
+ && shuffleId == o.shuffleId;
+ }
+ return false;
+ }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.Strings.encodedLength(appId) + 4;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, appId);
+ buf.writeInt(shuffleId);
+ }
+
+ public static FinalizeShuffleMerge decode(ByteBuf buf) {
+ String appId = Encoders.Strings.decode(buf);
+ int shuffleId = buf.readInt();
+ return new FinalizeShuffleMerge(appId, shuffleId);
+ }
+}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergeStatuses.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergeStatuses.java
new file mode 100644
index 0000000000000..f57e8b326e5e2
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergeStatuses.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import java.util.Arrays;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+import org.roaringbitmap.RoaringBitmap;
+
+import org.apache.spark.network.protocol.Encoders;
+
+/**
+ * Result returned by an ExternalShuffleService to the DAGScheduler. This represents the result
+ * of all the remote shuffle block merge operations performed by an ExternalShuffleService
+ * for a given shuffle ID. It includes the shuffle ID, an array of bitmaps each representing
+ * the set of mapper partition blocks that are merged for a given reducer partition, an array
+ * of reducer IDs, and an array of merged shuffle partition sizes. The three arrays are
+ * aligned by index, so entries at the same position describe the same reducer partition.
+ */
+public class MergeStatuses extends BlockTransferMessage {
+ /** Shuffle ID **/
+ public final int shuffleId;
+ /**
+ * Array of bitmaps tracking the set of mapper partition blocks merged for each
+ * reducer partition
+ */
+ public final RoaringBitmap[] bitmaps;
+ /** Array of reducer IDs **/
+ public final int[] reduceIds;
+ /**
+ * Array of merged shuffle partition block sizes. Each entry represents the total size of
+ * all merged shuffle partition blocks for one reducer partition.
+ */
+ public final long[] sizes;
+
+ public MergeStatuses(
+ int shuffleId,
+ RoaringBitmap[] bitmaps,
+ int[] reduceIds,
+ long[] sizes) {
+ this.shuffleId = shuffleId;
+ this.bitmaps = bitmaps;
+ this.reduceIds = reduceIds;
+ this.sizes = sizes;
+ }
+
+ @Override
+ protected Type type() {
+ return Type.MERGE_STATUSES;
+ }
+
+ @Override
+ public int hashCode() {
+ int objectHashCode = Objects.hashCode(shuffleId);
+ return (objectHashCode * 41 + Arrays.hashCode(reduceIds) * 41
+ + Arrays.hashCode(bitmaps) * 41 + Arrays.hashCode(sizes));
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("shuffleId", shuffleId)
+ .add("reduceId size", reduceIds.length)
+ .toString();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other != null && other instanceof MergeStatuses) {
+ MergeStatuses o = (MergeStatuses) other;
+ return Objects.equal(shuffleId, o.shuffleId)
+ && Arrays.equals(bitmaps, o.bitmaps)
+ && Arrays.equals(reduceIds, o.reduceIds)
+ && Arrays.equals(sizes, o.sizes);
+ }
+ return false;
+ }
+
+ @Override
+ public int encodedLength() {
+ return 4 // int
+ + Encoders.BitmapArrays.encodedLength(bitmaps)
+ + Encoders.IntArrays.encodedLength(reduceIds)
+ + Encoders.LongArrays.encodedLength(sizes);
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ buf.writeInt(shuffleId);
+ Encoders.BitmapArrays.encode(buf, bitmaps);
+ Encoders.IntArrays.encode(buf, reduceIds);
+ Encoders.LongArrays.encode(buf, sizes);
+ }
+
+ public static MergeStatuses decode(ByteBuf buf) {
+ int shuffleId = buf.readInt();
+ RoaringBitmap[] bitmaps = Encoders.BitmapArrays.decode(buf);
+ int[] reduceIds = Encoders.IntArrays.decode(buf);
+ long[] sizes = Encoders.LongArrays.decode(buf);
+ return new MergeStatuses(shuffleId, bitmaps, reduceIds, sizes);
+ }
+}
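
To illustrate the aligned-arrays layout and the bitmap-array encoding end to end, here is a small round-trip sketch (not part of the patch; the class name is hypothetical), mirroring what testFinalizeShuffleMerge below verifies:

import java.nio.ByteBuffer;

import org.roaringbitmap.RoaringBitmap;

import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.MergeStatuses;

public class MergeStatusesRoundTrip {
  public static void main(String[] args) {
    // One entry per reducer partition: reducer 3 had maps 0-2 merged, totalling 30 bytes.
    RoaringBitmap mapsMerged = RoaringBitmap.bitmapOf(0, 1, 2);
    MergeStatuses statuses = new MergeStatuses(
      0,                                   // shuffleId
      new RoaringBitmap[] { mapsMerged },  // per-reducer map bitmaps
      new int[] { 3 },                     // reducer IDs, aligned by index
      new long[] { 30 });                  // merged partition sizes, aligned by index

    ByteBuffer wire = statuses.toByteBuffer();
    MergeStatuses decoded =
      (MergeStatuses) BlockTransferMessage.Decoder.fromByteBuffer(wire);
    assert statuses.equals(decoded);
    assert decoded.bitmaps[0].contains(2);
  }
}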
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/PushBlockStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/PushBlockStream.java
new file mode 100644
index 0000000000000..7eab5a644783c
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/PushBlockStream.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encoders;
+
+// Needed by ScalaDoc. See SPARK-7726
+import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
+
+
+/**
+ * Request to push a block to a remote shuffle service to be merged in push-based shuffle.
+ * The remote shuffle service will also include this message when responding to push requests.
+ */
+public class PushBlockStream extends BlockTransferMessage {
+ public final String appId;
+ public final String blockId;
+ // Similar to the chunkIndex in StreamChunkId, indicating the index of a block in a batch of
+ // blocks to be pushed.
+ public final int index;
+
+ public PushBlockStream(String appId, String blockId, int index) {
+ this.appId = appId;
+ this.blockId = blockId;
+ this.index = index;
+ }
+
+ @Override
+ protected Type type() {
+ return Type.PUSH_BLOCK_STREAM;
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(appId, blockId, index);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("appId", appId)
+ .add("blockId", blockId)
+ .add("index", index)
+ .toString();
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other != null && other instanceof PushBlockStream) {
+ PushBlockStream o = (PushBlockStream) other;
+ return Objects.equal(appId, o.appId)
+ && Objects.equal(blockId, o.blockId)
+ && index == o.index;
+ }
+ return false;
+ }
+
+ @Override
+ public int encodedLength() {
+ return Encoders.Strings.encodedLength(appId)
+ + Encoders.Strings.encodedLength(blockId) + 4;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ Encoders.Strings.encode(buf, appId);
+ Encoders.Strings.encode(buf, blockId);
+ buf.writeInt(index);
+ }
+
+ public static PushBlockStream decode(ByteBuf buf) {
+ String appId = Encoders.Strings.decode(buf);
+ String blockId = Encoders.Strings.decode(buf);
+ int index = buf.readInt();
+ return new PushBlockStream(appId, blockId, index);
+ }
+}
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ErrorHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ErrorHandlerSuite.java
new file mode 100644
index 0000000000000..992e7762c5a54
--- /dev/null
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ErrorHandlerSuite.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.net.ConnectException;
+
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+/**
+ * Test suite for {@link ErrorHandler}
+ */
+public class ErrorHandlerSuite {
+
+ @Test
+ public void testPushErrorRetry() {
+ ErrorHandler.BlockPushErrorHandler handler = new ErrorHandler.BlockPushErrorHandler();
+ assertFalse(handler.shouldRetryError(new RuntimeException(new IllegalArgumentException(
+ ErrorHandler.BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX))));
+ assertFalse(handler.shouldRetryError(new RuntimeException(new ConnectException())));
+ assertTrue(handler.shouldRetryError(new RuntimeException(new IllegalArgumentException(
+ ErrorHandler.BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX))));
+ assertTrue(handler.shouldRetryError(new Throwable()));
+ }
+
+ @Test
+ public void testPushErrorLogging() {
+ ErrorHandler.BlockPushErrorHandler handler = new ErrorHandler.BlockPushErrorHandler();
+ assertFalse(handler.shouldLogError(new RuntimeException(new IllegalArgumentException(
+ ErrorHandler.BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX))));
+ assertFalse(handler.shouldLogError(new RuntimeException(new IllegalArgumentException(
+ ErrorHandler.BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX))));
+ assertTrue(handler.shouldLogError(new Throwable()));
+ }
+}
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java
index 455351fcf767c..680b8d74a2eea 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java
@@ -17,6 +17,7 @@
package org.apache.spark.network.shuffle;
+import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Iterator;
@@ -25,6 +26,7 @@
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
+import org.roaringbitmap.RoaringBitmap;
import static org.junit.Assert.*;
import static org.mockito.ArgumentMatchers.any;
@@ -39,6 +41,8 @@
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks;
+import org.apache.spark.network.shuffle.protocol.FinalizeShuffleMerge;
+import org.apache.spark.network.shuffle.protocol.MergeStatuses;
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
import org.apache.spark.network.shuffle.protocol.StreamHandle;
@@ -50,6 +54,7 @@ public class ExternalBlockHandlerSuite {
OneForOneStreamManager streamManager;
ExternalShuffleBlockResolver blockResolver;
RpcHandler handler;
+ MergedShuffleFileManager mergedShuffleManager;
ManagedBuffer[] blockMarkers = {
new NioManagedBuffer(ByteBuffer.wrap(new byte[3])),
new NioManagedBuffer(ByteBuffer.wrap(new byte[7]))
@@ -59,17 +64,20 @@ public class ExternalBlockHandlerSuite {
public void beforeEach() {
streamManager = mock(OneForOneStreamManager.class);
blockResolver = mock(ExternalShuffleBlockResolver.class);
- handler = new ExternalBlockHandler(streamManager, blockResolver);
+ mergedShuffleManager = mock(MergedShuffleFileManager.class);
+ handler = new ExternalBlockHandler(streamManager, blockResolver, mergedShuffleManager);
}
@Test
public void testRegisterExecutor() {
RpcResponseCallback callback = mock(RpcResponseCallback.class);
- ExecutorShuffleInfo config = new ExecutorShuffleInfo(new String[] {"/a", "/b"}, 16, "sort");
+ String[] localDirs = new String[] {"/a", "/b"};
+ ExecutorShuffleInfo config = new ExecutorShuffleInfo(localDirs, 16, "sort");
ByteBuffer registerMessage = new RegisterExecutor("app0", "exec1", config).toByteBuffer();
handler.receive(client, registerMessage, callback);
verify(blockResolver, times(1)).registerExecutor("app0", "exec1", config);
+ verify(mergedShuffleManager, times(1)).registerExecutor("app0", localDirs);
verify(callback, times(1)).onSuccess(any(ByteBuffer.class));
verify(callback, never()).onFailure(any(Throwable.class));
@@ -222,4 +230,32 @@ public void testBadMessages() {
verify(callback, never()).onSuccess(any(ByteBuffer.class));
verify(callback, never()).onFailure(any(Throwable.class));
}
+
+ @Test
+ public void testFinalizeShuffleMerge() throws IOException {
+ RpcResponseCallback callback = mock(RpcResponseCallback.class);
+
+ FinalizeShuffleMerge req = new FinalizeShuffleMerge("app0", 0);
+ RoaringBitmap bitmap = RoaringBitmap.bitmapOf(0, 1, 2);
+ MergeStatuses statuses = new MergeStatuses(0, new RoaringBitmap[]{bitmap},
+ new int[]{3}, new long[]{30});
+ when(mergedShuffleManager.finalizeShuffleMerge(req)).thenReturn(statuses);
+
+ ByteBuffer reqBuf = req.toByteBuffer();
+ handler.receive(client, reqBuf, callback);
+ verify(mergedShuffleManager, times(1)).finalizeShuffleMerge(req);
+ ArgumentCaptor<ByteBuffer> response = ArgumentCaptor.forClass(ByteBuffer.class);
+ verify(callback, times(1)).onSuccess(response.capture());
+ verify(callback, never()).onFailure(any());
+
+ MergeStatuses mergeStatuses =
+ (MergeStatuses) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue());
+ assertEquals(mergeStatuses, statuses);
+
+ Timer finalizeShuffleMergeLatencyMillis = (Timer) ((ExternalBlockHandler) handler)
+ .getAllMetrics()
+ .getMetrics()
+ .get("finalizeShuffleMergeLatencyMillis");
+ assertEquals(1, finalizeShuffleMergeLatencyMillis.getCount());
+ }
}
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockPusherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockPusherSuite.java
new file mode 100644
index 0000000000000..ebcdba72aa1a8
--- /dev/null
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockPusherSuite.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle;
+
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+import com.google.common.collect.Maps;
+import io.netty.buffer.Unpooled;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+import static org.mockito.AdditionalMatchers.*;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+import org.apache.spark.network.buffer.NioManagedBuffer;
+import org.apache.spark.network.client.RpcResponseCallback;
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
+import org.apache.spark.network.shuffle.protocol.PushBlockStream;
+
+
+public class OneForOneBlockPusherSuite {
+
+ @Test
+ public void testPushOne() {
+ LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
+ blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[1])));
+ String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
+
+ BlockFetchingListener listener = pushBlocks(
+ blocks,
+ blockIds,
+ Arrays.asList(new PushBlockStream("app-id", "shuffle_0_0_0", 0)));
+
+ verify(listener).onBlockFetchSuccess(eq("shuffle_0_0_0"), any());
+ }
+
+ @Test
+ public void testPushThree() {
+ LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
+ blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
+ blocks.put("b1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23])));
+ blocks.put("b2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23])));
+ String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
+
+ BlockFetchingListener listener = pushBlocks(
+ blocks,
+ blockIds,
+ Arrays.asList(new PushBlockStream("app-id", "b0", 0),
+ new PushBlockStream("app-id", "b1", 1),
+ new PushBlockStream("app-id", "b2", 2)));
+
+ for (int i = 0; i < 3; i++) {
+ verify(listener, times(1)).onBlockFetchSuccess(eq("b" + i), any());
+ }
+ }
+
+ @Test
+ public void testServerFailures() {
+ LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
+ blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
+ blocks.put("b1", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
+ blocks.put("b2", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
+ String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
+
+ BlockFetchingListener listener = pushBlocks(
+ blocks,
+ blockIds,
+ Arrays.asList(new PushBlockStream("app-id", "b0", 0),
+ new PushBlockStream("app-id", "b1", 1),
+ new PushBlockStream("app-id", "b2", 2)));
+
+ verify(listener, times(1)).onBlockFetchSuccess(eq("b0"), any());
+ verify(listener, times(1)).onBlockFetchFailure(eq("b1"), any());
+ verify(listener, times(1)).onBlockFetchFailure(eq("b2"), any());
+ }
+
+ @Test
+ public void testHandlingRetriableFailures() {
+ LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap();
+ blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12])));
+ blocks.put("b1", null);
+ blocks.put("b2", new NioManagedBuffer(ByteBuffer.wrap(new byte[0])));
+ String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]);
+
+ BlockFetchingListener listener = pushBlocks(
+ blocks,
+ blockIds,
+ Arrays.asList(new PushBlockStream("app-id", "b0", 0),
+ new PushBlockStream("app-id", "b1", 1),
+ new PushBlockStream("app-id", "b2", 2)));
+
+ verify(listener, times(1)).onBlockFetchSuccess(eq("b0"), any());
+ verify(listener, times(0)).onBlockFetchSuccess(not(eq("b0")), any());
+ verify(listener, times(0)).onBlockFetchFailure(eq("b0"), any());
+ verify(listener, times(1)).onBlockFetchFailure(eq("b1"), any());
+ verify(listener, times(2)).onBlockFetchFailure(eq("b2"), any());
+ }
+
+ /**
+ * Begins a push on the given set of blocks by mocking the response from the server side.
+ * If a block is an empty buffer, a server-side retriable exception will be thrown.
+ * If a block is null, a non-retriable exception will be thrown.
+ */
+ private static BlockFetchingListener pushBlocks(
+ LinkedHashMap<String, ManagedBuffer> blocks,
+ String[] blockIds,
+ Iterable<BlockTransferMessage> expectMessages) {
+ TransportClient client = mock(TransportClient.class);
+ BlockFetchingListener listener = mock(BlockFetchingListener.class);
+ OneForOneBlockPusher pusher =
+ new OneForOneBlockPusher(client, "app-id", blockIds, listener, blocks);
+
+ Iterator<Map.Entry<String, ManagedBuffer>> blockIterator = blocks.entrySet().iterator();
+ Iterator<BlockTransferMessage> msgIterator = expectMessages.iterator();
+ doAnswer(invocation -> {
+ ByteBuffer header = ((ManagedBuffer) invocation.getArguments()[0]).nioByteBuffer();
+ BlockTransferMessage message = BlockTransferMessage.Decoder.fromByteBuffer(header);
+ RpcResponseCallback callback = (RpcResponseCallback) invocation.getArguments()[2];
+ Map.Entry<String, ManagedBuffer> entry = blockIterator.next();
+ ManagedBuffer block = entry.getValue();
+ if (block != null && block.nioByteBuffer().capacity() > 0) {
+ callback.onSuccess(header);
+ } else if (block != null) {
+ callback.onFailure(new RuntimeException("Failed " + entry.getKey()
+ + ErrorHandler.BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX));
+ } else {
+ callback.onFailure(new RuntimeException("Quick fail " + entry.getKey()
+ + ErrorHandler.BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX));
+ }
+ assertEquals(msgIterator.next(), message);
+ return null;
+ }).when(client).uploadStream(any(ManagedBuffer.class), any(), any(RpcResponseCallback.class));
+
+ pusher.start();
+ return listener;
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceMetricsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceMetricsSuite.scala
index d681c13337e0d..ea4d252f0dbae 100644
--- a/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceMetricsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/ExternalShuffleServiceMetricsSuite.scala
@@ -61,7 +61,8 @@ class ExternalShuffleServiceMetricsSuite extends SparkFunSuite {
"registeredExecutorsSize",
"registerExecutorRequestLatencyMillis",
"shuffle-server.usedDirectMemory",
- "shuffle-server.usedHeapMemory")
+ "shuffle-server.usedHeapMemory",
+ "finalizeShuffleMergeLatencyMillis")
)
}
}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala
index 63ac1af8a9127..9239d891aae3b 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala
@@ -40,7 +40,7 @@ class YarnShuffleServiceMetricsSuite extends SparkFunSuite with Matchers {
val allMetrics = Set(
"openBlockRequestLatencyMillis", "registerExecutorRequestLatencyMillis",
"blockTransferRateBytes", "registeredExecutorsSize", "numActiveConnections",
- "numCaughtExceptions")
+ "numCaughtExceptions", "finalizeShuffleMergeLatencyMillis")
metrics.getMetrics.keySet().asScala should be (allMetrics)
}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala
index 46e596575533d..a6a302ad5df95 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala
@@ -405,6 +405,7 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd
"openBlockRequestLatencyMillis",
"registeredExecutorsSize",
"registerExecutorRequestLatencyMillis",
+ "finalizeShuffleMergeLatencyMillis",
"shuffle-server.usedDirectMemory",
"shuffle-server.usedHeapMemory"
))