Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,12 @@ protected void handleMessage(
FetchShuffleBlocks msg = (FetchShuffleBlocks) msgObj;
checkAuth(client, msg.appId);
numBlockIds = 0;
for (int[] ids: msg.reduceIds) {
numBlockIds += ids.length;
if (msg.batchFetchEnabled) {
numBlockIds = msg.mapIds.length;
} else {
for (int[] ids: msg.reduceIds) {
numBlockIds += ids.length;
}
}
streamId = streamManager.registerStream(client.getClientId(),
new ShuffleManagedBufferIterator(msg), client.getChannel());
Expand Down Expand Up @@ -323,13 +327,15 @@ private class ShuffleManagedBufferIterator implements Iterator<ManagedBuffer> {
private final int shuffleId;
private final long[] mapIds;
private final int[][] reduceIds;
private final boolean batchFetchEnabled;

ShuffleManagedBufferIterator(FetchShuffleBlocks msg) {
appId = msg.appId;
execId = msg.execId;
shuffleId = msg.shuffleId;
mapIds = msg.mapIds;
reduceIds = msg.reduceIds;
batchFetchEnabled = msg.batchFetchEnabled;
}

@Override
Expand All @@ -343,12 +349,20 @@ public boolean hasNext() {

@Override
public ManagedBuffer next() {
final ManagedBuffer block = blockManager.getBlockData(
appId, execId, shuffleId, mapIds[mapIdx], reduceIds[mapIdx][reduceIdx]);
if (reduceIdx < reduceIds[mapIdx].length - 1) {
reduceIdx += 1;
ManagedBuffer block;
if (!batchFetchEnabled) {
block = blockManager.getBlockData(
appId, execId, shuffleId, mapIds[mapIdx], reduceIds[mapIdx][reduceIdx]);
if (reduceIdx < reduceIds[mapIdx].length - 1) {
reduceIdx += 1;
} else {
reduceIdx = 0;
mapIdx += 1;
}
} else {
reduceIdx = 0;
assert(reduceIds[mapIdx].length == 2);
block = blockManager.getContinuousBlocksData(appId, execId, shuffleId, mapIds[mapIdx],
reduceIds[mapIdx][0], reduceIds[mapIdx][1]);
mapIdx += 1;
}
metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,21 +165,34 @@ public void registerExecutor(
}

/**
* Obtains a FileSegmentManagedBuffer from (shuffleId, mapId, reduceId). We make assumptions
* about how the hash and sort based shuffles store their data.
* Obtains a FileSegmentManagedBuffer from a single block (shuffleId, mapId, reduceId).
*/
public ManagedBuffer getBlockData(
String appId,
String execId,
int shuffleId,
long mapId,
int reduceId) {
return getContinuousBlocksData(appId, execId, shuffleId, mapId, reduceId, reduceId + 1);
}

/**
* Obtains a FileSegmentManagedBuffer from (shuffleId, mapId, [startReduceId, endReduceId)).
* We make assumptions about how the hash and sort based shuffles store their data.
*/
public ManagedBuffer getContinuousBlocksData(
String appId,
String execId,
int shuffleId,
long mapId,
int startReduceId,
int endReduceId) {
ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId));
if (executor == null) {
throw new RuntimeException(
String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId));
}
return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId);
return getSortBasedShuffleBlockData(executor, shuffleId, mapId, startReduceId, endReduceId);
}

public ManagedBuffer getRddBlockData(
Expand Down Expand Up @@ -296,13 +309,14 @@ private void deleteNonShuffleServiceServedFiles(String[] dirs) {
* and the block id format is from ShuffleDataBlockId and ShuffleIndexBlockId.
*/
private ManagedBuffer getSortBasedShuffleBlockData(
ExecutorShuffleInfo executor, int shuffleId, long mapId, int reduceId) {
ExecutorShuffleInfo executor, int shuffleId, long mapId, int startReduceId, int endReduceId) {
File indexFile = ExecutorDiskUtils.getFile(executor.localDirs, executor.subDirsPerLocalDir,
"shuffle_" + shuffleId + "_" + mapId + "_0.index");

try {
ShuffleIndexInformation shuffleIndexInformation = shuffleIndexCache.get(indexFile);
ShuffleIndexRecord shuffleIndexRecord = shuffleIndexInformation.getIndex(reduceId);
ShuffleIndexRecord shuffleIndexRecord = shuffleIndexInformation.getIndex(
startReduceId, endReduceId);
return new FileSegmentManagedBuffer(
conf,
ExecutorDiskUtils.getFile(executor.localDirs, executor.subDirsPerLocalDir,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

import com.google.common.primitives.Ints;
import com.google.common.primitives.Longs;
import org.apache.commons.lang3.tuple.ImmutableTriple;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -113,39 +112,47 @@ private boolean isShuffleBlocks(String[] blockIds) {
*/
private FetchShuffleBlocks createFetchShuffleBlocksMsg(
String appId, String execId, String[] blockIds) {
int shuffleId = splitBlockId(blockIds[0]).left;
String[] firstBlock = splitBlockId(blockIds[0]);
int shuffleId = Integer.parseInt(firstBlock[1]);
boolean batchFetchEnabled = firstBlock.length == 5;

HashMap<Long, ArrayList<Integer>> mapIdToReduceIds = new HashMap<>();
for (String blockId : blockIds) {
ImmutableTriple<Integer, Long, Integer> blockIdParts = splitBlockId(blockId);
if (blockIdParts.left != shuffleId) {
String[] blockIdParts = splitBlockId(blockId);
if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
", got:" + blockId);
}
long mapId = blockIdParts.middle;
long mapId = Long.parseLong(blockIdParts[2]);
if (!mapIdToReduceIds.containsKey(mapId)) {
mapIdToReduceIds.put(mapId, new ArrayList<>());
}
mapIdToReduceIds.get(mapId).add(blockIdParts.right);
mapIdToReduceIds.get(mapId).add(Integer.parseInt(blockIdParts[3]));
if (batchFetchEnabled) {
// When we read continuous shuffle blocks in batch, we will reuse reduceIds in
// FetchShuffleBlocks to store the start and end reduce id for range
// [startReduceId, endReduceId).
assert(blockIdParts.length == 5);
mapIdToReduceIds.get(mapId).add(Integer.parseInt(blockIdParts[4]));
}
}
long[] mapIds = Longs.toArray(mapIdToReduceIds.keySet());
int[][] reduceIdArr = new int[mapIds.length][];
for (int i = 0; i < mapIds.length; i++) {
reduceIdArr[i] = Ints.toArray(mapIdToReduceIds.get(mapIds[i]));
}
return new FetchShuffleBlocks(appId, execId, shuffleId, mapIds, reduceIdArr);
return new FetchShuffleBlocks(
appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled);
}

/** Split the shuffleBlockId and return shuffleId, mapId and reduceId. */
private ImmutableTriple<Integer, Long, Integer> splitBlockId(String blockId) {
/** Split the shuffleBlockId and return shuffleId, mapId and reduceIds. */
private String[] splitBlockId(String blockId) {
String[] blockIdParts = blockId.split("_");
if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) {
if (blockIdParts.length < 4 || blockIdParts.length > 5 || !blockIdParts[0].equals("shuffle")) {
throw new IllegalArgumentException(
"Unexpected shuffle block id format: " + blockId);
}
return new ImmutableTriple<>(
Integer.parseInt(blockIdParts[1]),
Long.parseLong(blockIdParts[2]),
Integer.parseInt(blockIdParts[3]));
return blockIdParts;
}

/** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,15 @@ public int getSize() {
* Get index offset for a particular reducer.
*/
public ShuffleIndexRecord getIndex(int reduceId) {
long offset = offsets.get(reduceId);
long nextOffset = offsets.get(reduceId + 1);
return getIndex(reduceId, reduceId + 1);
}

/**
* Get index offset for the reducer range of [startReduceId, endReduceId).
*/
public ShuffleIndexRecord getIndex(int startReduceId, int endReduceId) {
long offset = offsets.get(startReduceId);
long nextOffset = offsets.get(endReduceId);
return new ShuffleIndexRecord(offset, nextOffset - offset);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,32 @@ public class FetchShuffleBlocks extends BlockTransferMessage {
// The length of mapIds must equal to reduceIds.size(), for the i-th mapId in mapIds,
// it corresponds to the i-th int[] in reduceIds, which contains all reduce id for this map id.
public final long[] mapIds;
// When batchFetchEnabled=true, reduceIds[i] contains 2 elements: startReduceId (inclusive) and
// endReduceId (exclusive) for the mapper mapIds[i].
// When batchFetchEnabled=false, reduceIds[i] contains all the reduce IDs that mapper mapIds[i]
// needs to fetch.
public final int[][] reduceIds;
public final boolean batchFetchEnabled;

public FetchShuffleBlocks(
String appId,
String execId,
int shuffleId,
long[] mapIds,
int[][] reduceIds) {
int[][] reduceIds,
boolean batchFetchEnabled) {
this.appId = appId;
this.execId = execId;
this.shuffleId = shuffleId;
this.mapIds = mapIds;
this.reduceIds = reduceIds;
assert(mapIds.length == reduceIds.length);
this.batchFetchEnabled = batchFetchEnabled;
if (batchFetchEnabled) {
for (int[] ids: reduceIds) {
assert(ids.length == 2);
}
}
}

@Override
Expand All @@ -62,6 +74,7 @@ public String toString() {
.add("shuffleId", shuffleId)
.add("mapIds", Arrays.toString(mapIds))
.add("reduceIds", Arrays.deepToString(reduceIds))
.add("batchFetchEnabled", batchFetchEnabled)
.toString();
}

Expand All @@ -73,6 +86,7 @@ public boolean equals(Object o) {
FetchShuffleBlocks that = (FetchShuffleBlocks) o;

if (shuffleId != that.shuffleId) return false;
if (batchFetchEnabled != that.batchFetchEnabled) return false;
if (!appId.equals(that.appId)) return false;
if (!execId.equals(that.execId)) return false;
if (!Arrays.equals(mapIds, that.mapIds)) return false;
Expand All @@ -86,6 +100,7 @@ public int hashCode() {
result = 31 * result + shuffleId;
result = 31 * result + Arrays.hashCode(mapIds);
result = 31 * result + Arrays.deepHashCode(reduceIds);
result = 31 * result + (batchFetchEnabled ? 1 : 0);
return result;
}

Expand All @@ -100,7 +115,8 @@ public int encodedLength() {
+ 4 /* encoded length of shuffleId */
+ Encoders.LongArrays.encodedLength(mapIds)
+ 4 /* encoded length of reduceIds.size() */
+ encodedLengthOfReduceIds;
+ encodedLengthOfReduceIds
+ 1; /* encoded length of batchFetchEnabled */
}

@Override
Expand All @@ -113,6 +129,7 @@ public void encode(ByteBuf buf) {
for (int[] ids: reduceIds) {
Encoders.IntArrays.encode(buf, ids);
}
buf.writeBoolean(batchFetchEnabled);
}

public static FetchShuffleBlocks decode(ByteBuf buf) {
Expand All @@ -125,6 +142,7 @@ public static FetchShuffleBlocks decode(ByteBuf buf) {
for (int i = 0; i < reduceIdsSize; i++) {
reduceIds[i] = Encoders.IntArrays.decode(buf);
}
return new FetchShuffleBlocks(appId, execId, shuffleId, mapIds, reduceIds);
boolean batchFetchEnabled = buf.readBoolean();
return new FetchShuffleBlocks(appId, execId, shuffleId, mapIds, reduceIds, batchFetchEnabled);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@ public void serializeOpenShuffleBlocks() {
checkSerializeDeserialize(new OpenBlocks("app-1", "exec-2", new String[] { "b1", "b2" }));
checkSerializeDeserialize(new FetchShuffleBlocks(
"app-1", "exec-2", 0, new long[] {0, 1},
new int[][] {{ 0, 1 }, { 0, 1, 2 }}));
new int[][] {{ 0, 1 }, { 0, 1, 2 }}, false));
checkSerializeDeserialize(new FetchShuffleBlocks(
"app-1", "exec-2", 0, new long[] {0, 1},
new int[][] {{ 0, 1 }, { 0, 2 }}, true));
checkSerializeDeserialize(new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo(
new String[] { "/local1", "/local2" }, 32, "MyShuffleManager")));
checkSerializeDeserialize(new UploadBlock("app-1", "exec-2", "block-3", new byte[] { 1, 2 },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,30 @@ public void testFetchShuffleBlocks() {
when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1)).thenReturn(blockMarkers[1]);

FetchShuffleBlocks fetchShuffleBlocks = new FetchShuffleBlocks(
"app0", "exec1", 0, new long[] { 0 }, new int[][] {{ 0, 1 }});
"app0", "exec1", 0, new long[] { 0 }, new int[][] {{ 0, 1 }}, false);
checkOpenBlocksReceive(fetchShuffleBlocks, blockMarkers);

verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0);
verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 1);
verifyOpenBlockLatencyMetrics();
}

@Test
public void testFetchShuffleBlocksInBatch() {
ManagedBuffer[] batchBlockMarkers = {
new NioManagedBuffer(ByteBuffer.wrap(new byte[10]))
};
when(blockResolver.getContinuousBlocksData(
"app0", "exec1", 0, 0, 0, 1)).thenReturn(batchBlockMarkers[0]);

FetchShuffleBlocks fetchShuffleBlocks = new FetchShuffleBlocks(
"app0", "exec1", 0, new long[] { 0 }, new int[][] {{ 0, 1 }}, true);
checkOpenBlocksReceive(fetchShuffleBlocks, batchBlockMarkers);

verify(blockResolver, times(1)).getContinuousBlocksData("app0", "exec1", 0, 0, 0, 1);
verifyOpenBlockLatencyMetrics();
}

@Test
public void testOpenDiskPersistedRDDBlocks() {
when(blockResolver.getRddBlockData("app0", "exec1", 0, 0)).thenReturn(blockMarkers[0]);
Expand Down Expand Up @@ -154,16 +170,17 @@ private void checkOpenBlocksReceive(BlockTransferMessage msg, ManagedBuffer[] bl

StreamHandle handle =
(StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue());
assertEquals(2, handle.numChunks);
assertEquals(blockMarkers.length, handle.numChunks);

@SuppressWarnings("unchecked")
ArgumentCaptor<Iterator<ManagedBuffer>> stream = (ArgumentCaptor<Iterator<ManagedBuffer>>)
(ArgumentCaptor<?>) ArgumentCaptor.forClass(Iterator.class);
verify(streamManager, times(1)).registerStream(anyString(), stream.capture(),
any());
Iterator<ManagedBuffer> buffers = stream.getValue();
assertEquals(blockMarkers[0], buffers.next());
assertEquals(blockMarkers[1], buffers.next());
for (ManagedBuffer blockMarker : blockMarkers) {
assertEquals(blockMarker, buffers.next());
}
assertFalse(buffers.hasNext());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,13 @@ public void testSortShuffleBlocks() throws IOException {
CharStreams.toString(new InputStreamReader(block1Stream, StandardCharsets.UTF_8));
assertEquals(sortBlock1, block1);
}

try (InputStream blocksStream = resolver.getContinuousBlocksData(
"app0", "exec0", 0, 0, 0, 2).createInputStream()) {
String blocks =
CharStreams.toString(new InputStreamReader(blocksStream, StandardCharsets.UTF_8));
assertEquals(sortBlock0 + sortBlock1, blocks);
}
}

@Test
Expand Down
Loading