Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,27 @@ public static String[] decode(ByteBuf buf) {
return strings;
}
}

/** Integer arrays are encoded with their length followed by integers. */
public static class IntArrays {
public static int encodedLength(int[] ints) {
return 4 + 4 * ints.length;
}

public static void encode(ByteBuf buf, int[] ints) {
buf.writeInt(ints.length);
for (int i : ints) {
buf.writeInt(i);
}
}

public static int[] decode(ByteBuf buf) {
int numInts = buf.readInt();
int[] ints = new int[numInts];
for (int i = 0; i < ints.length; i ++) {
ints[i] = buf.readInt();
}
return ints;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -341,4 +341,13 @@ public int chunkFetchHandlerThreads() {
return (int) Math.ceil(threads * (chunkFetchHandlerThreadsPercent / 100.0));
}

/**
* Whether to use the old protocol while doing the shuffle block fetching.
* It is only enabled while we need the compatibility in the scenario of new spark version
* job fetching blocks from old version external shuffle service.
*/
public boolean useOldFetchProtocol() {
return conf.getBoolean("spark.shuffle.useOldFetchProtocol", false);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -92,22 +92,37 @@ protected void handleMessage(
BlockTransferMessage msgObj,
TransportClient client,
RpcResponseCallback callback) {
if (msgObj instanceof OpenBlocks) {
if (msgObj instanceof FetchShuffleBlocks || msgObj instanceof OpenBlocks) {
final Timer.Context responseDelayContext = metrics.openBlockRequestLatencyMillis.time();
try {
OpenBlocks msg = (OpenBlocks) msgObj;
checkAuth(client, msg.appId);
long streamId = streamManager.registerStream(client.getClientId(),
new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds), client.getChannel());
int numBlockIds;
long streamId;
if (msgObj instanceof FetchShuffleBlocks) {
FetchShuffleBlocks msg = (FetchShuffleBlocks) msgObj;
checkAuth(client, msg.appId);
numBlockIds = 0;
for (int[] ids: msg.reduceIds) {
numBlockIds += ids.length;
}
streamId = streamManager.registerStream(client.getClientId(),
new ManagedBufferIterator(msg, numBlockIds), client.getChannel());
} else {
// For the compatibility with the old version, still keep the support for OpenBlocks.
OpenBlocks msg = (OpenBlocks) msgObj;
numBlockIds = msg.blockIds.length;
checkAuth(client, msg.appId);
streamId = streamManager.registerStream(client.getClientId(),
new ManagedBufferIterator(msg), client.getChannel());
}
if (logger.isTraceEnabled()) {
logger.trace(
"Registered streamId {} with {} buffers for client {} from host {}",
streamId,
msg.blockIds.length,
numBlockIds,
client.getClientId(),
getRemoteAddress(client.getChannel()));
}
callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer());
callback.onSuccess(new StreamHandle(streamId, numBlockIds).toByteBuffer());
} finally {
responseDelayContext.stop();
}
Expand Down Expand Up @@ -224,7 +239,10 @@ private class ManagedBufferIterator implements Iterator<ManagedBuffer> {
private final Function<Integer, ManagedBuffer> blockDataForIndexFn;
private final int size;

ManagedBufferIterator(final String appId, final String execId, String[] blockIds) {
ManagedBufferIterator(OpenBlocks msg) {
String appId = msg.appId;
String execId = msg.execId;
String[] blockIds = msg.blockIds;
String[] blockId0Parts = blockIds[0].split("_");
if (blockId0Parts.length == 4 && blockId0Parts[0].equals("shuffle")) {
final int shuffleId = Integer.parseInt(blockId0Parts[1]);
Expand Down Expand Up @@ -272,6 +290,21 @@ private int[] shuffleMapIdAndReduceIds(String[] blockIds, int shuffleId) {
return mapIdAndReduceIds;
}

ManagedBufferIterator(FetchShuffleBlocks msg, int numBlockIds) {
final int[] mapIdAndReduceIds = new int[2 * numBlockIds];
int idx = 0;
for (int i = 0; i < msg.mapIds.length; i++) {
for (int reduceId : msg.reduceIds[i]) {
mapIdAndReduceIds[idx++] = msg.mapIds[i];
mapIdAndReduceIds[idx++] = reduceId;
}
}
assert(idx == 2 * numBlockIds);
size = mapIdAndReduceIds.length;
blockDataForIndexFn = index -> blockManager.getBlockData(msg.appId, msg.execId,
msg.shuffleId, mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]);
}

@Override
public boolean hasNext() {
return index < size;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;

import com.google.common.primitives.Ints;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -31,6 +34,7 @@
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks;
import org.apache.spark.network.shuffle.protocol.OpenBlocks;
import org.apache.spark.network.shuffle.protocol.StreamHandle;
import org.apache.spark.network.util.TransportConf;
Expand All @@ -48,7 +52,7 @@ public class OneForOneBlockFetcher {
private static final Logger logger = LoggerFactory.getLogger(OneForOneBlockFetcher.class);

private final TransportClient client;
private final OpenBlocks openMessage;
private final BlockTransferMessage message;
private final String[] blockIds;
private final BlockFetchingListener listener;
private final ChunkReceivedCallback chunkCallback;
Expand Down Expand Up @@ -76,12 +80,71 @@ public OneForOneBlockFetcher(
TransportConf transportConf,
DownloadFileManager downloadFileManager) {
this.client = client;
this.openMessage = new OpenBlocks(appId, execId, blockIds);
this.blockIds = blockIds;
this.listener = listener;
this.chunkCallback = new ChunkCallback();
this.transportConf = transportConf;
this.downloadFileManager = downloadFileManager;
if (blockIds.length == 0) {
throw new IllegalArgumentException("Zero-sized blockIds array");
}
if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
this.message = createFetchShuffleBlocksMsg(appId, execId, blockIds);
} else {
this.message = new OpenBlocks(appId, execId, blockIds);
}
}

private boolean isShuffleBlocks(String[] blockIds) {
for (String blockId : blockIds) {
if (!blockId.startsWith("shuffle_")) {
return false;
}
}
return true;
}

/**
* Analyze the pass in blockIds and create FetchShuffleBlocks message.
* The blockIds has been sorted by mapId and reduceId. It's produced in
* org.apache.spark.MapOutputTracker.convertMapStatuses.
*/
private FetchShuffleBlocks createFetchShuffleBlocksMsg(
String appId, String execId, String[] blockIds) {
int shuffleId = splitBlockId(blockIds[0])[0];
HashMap<Integer, ArrayList<Integer>> mapIdToReduceIds = new HashMap<>();
for (String blockId : blockIds) {
int[] blockIdParts = splitBlockId(blockId);
if (blockIdParts[0] != shuffleId) {
throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
", got:" + blockId);
}
int mapId = blockIdParts[1];
if (!mapIdToReduceIds.containsKey(mapId)) {
mapIdToReduceIds.put(mapId, new ArrayList<>());
}
mapIdToReduceIds.get(mapId).add(blockIdParts[2]);
}
int[] mapIds = Ints.toArray(mapIdToReduceIds.keySet());
int[][] reduceIdArr = new int[mapIds.length][];
for (int i = 0; i < mapIds.length; i++) {
reduceIdArr[i] = Ints.toArray(mapIdToReduceIds.get(mapIds[i]));
}
return new FetchShuffleBlocks(appId, execId, shuffleId, mapIds, reduceIdArr);
}

/** Split the shuffleBlockId and return shuffleId, mapId and reduceId. */
private int[] splitBlockId(String blockId) {
String[] blockIdParts = blockId.split("_");
if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) {
throw new IllegalArgumentException(
"Unexpected shuffle block id format: " + blockId);
}
return new int[] {
Integer.parseInt(blockIdParts[1]),
Integer.parseInt(blockIdParts[2]),
Integer.parseInt(blockIdParts[3])
};
}

/** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
Expand All @@ -106,11 +169,7 @@ public void onFailure(int chunkIndex, Throwable e) {
* {@link StreamHandle}. We will send all fetch requests immediately, without throttling.
*/
public void start() {
if (blockIds.length == 0) {
throw new IllegalArgumentException("Zero-sized blockIds array");
}

client.sendRpc(openMessage.toByteBuffer(), new RpcResponseCallback() {
client.sendRpc(message.toByteBuffer(), new RpcResponseCallback() {
@Override
public void onSuccess(ByteBuffer response) {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,22 @@
* by Spark's NettyBlockTransferService.
*
* At a high level:
* - OpenBlock is handled by both services, but only services shuffle files for the external
* shuffle service. It returns a StreamHandle.
* - OpenBlock is logically only handled by the NettyBlockTransferService, but for the capability
* for old version Spark, we still keep it in external shuffle service.
* It returns a StreamHandle.
* - UploadBlock is only handled by the NettyBlockTransferService.
* - RegisterExecutor is only handled by the external shuffle service.
* - RemoveBlocks is only handled by the external shuffle service.
* - FetchShuffleBlocks is handled by both services for shuffle files. It returns a StreamHandle.
*/
public abstract class BlockTransferMessage implements Encodable {
protected abstract Type type();

/** Preceding every serialized message is its type, which allows us to deserialize it. */
public enum Type {
OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4),
HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), REMOVE_BLOCKS(7), BLOCKS_REMOVED(8);
HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), REMOVE_BLOCKS(7), BLOCKS_REMOVED(8),
FETCH_SHUFFLE_BLOCKS(9);

private final byte id;

Expand Down Expand Up @@ -71,6 +74,7 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) {
case 6: return UploadBlockStream.decode(buf);
case 7: return RemoveBlocks.decode(buf);
case 8: return BlocksRemoved.decode(buf);
case 9: return FetchShuffleBlocks.decode(buf);
default: throw new IllegalArgumentException("Unknown message type: " + type);
}
}
Expand Down
Loading