OneForOneStreamManager.java

@@ -45,7 +45,7 @@ public class OneForOneStreamManager extends StreamManager {
   private final ConcurrentHashMap<Long, StreamState> streams;
 
   /** State of a single stream. */
-  private static class StreamState {
+  public static class StreamState {
     final String appId;
     final Iterator<ManagedBuffer> buffers;
 
@@ -59,11 +59,19 @@ private static class StreamState {
     // Used to keep track of the number of chunks being transferred and not finished yet.
     volatile long chunksBeingTransferred = 0L;
 
-    StreamState(String appId, Iterator<ManagedBuffer> buffers, Channel channel) {
+    public StreamState(String appId, Iterator<ManagedBuffer> buffers, Channel channel) {
       this.appId = appId;
       this.buffers = Preconditions.checkNotNull(buffers);
       this.associatedChannel = channel;
     }
+
+    public String getAppId() {
+      return appId;
+    }
+
+    public Iterator<ManagedBuffer> getBuffers() {
+      return buffers;
+    }
   }
 
   public OneForOneStreamManager() {
@@ -208,4 +216,8 @@ public long registerStream(String appId, Iterator<ManagedBuffer> buffers, Channel channel) {
   public int numStreamStates() {
     return streams.size();
   }
+
+  public ConcurrentHashMap<Long, StreamState> getStreams() {
+    return streams;
+  }
 }
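With StreamState and its getters made public here, callers outside the server package can register a stream and then inspect its state directly. A minimal sketch of that flow, using only APIs visible in this diff; the wrapper class name is hypothetical and the Channel is a Mockito mock:

import static org.mockito.Mockito.mock;

import java.nio.ByteBuffer;
import java.util.Collections;
import java.util.Iterator;

import io.netty.channel.Channel;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.buffer.NioManagedBuffer;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.OneForOneStreamManager.StreamState;

public class StreamStateAccessSketch {
  public static void main(String[] args) {
    OneForOneStreamManager manager = new OneForOneStreamManager();
    Iterator<ManagedBuffer> buffers = Collections.<ManagedBuffer>singletonList(
      new NioManagedBuffer(ByteBuffer.wrap(new byte[4]))).iterator();

    // Register a stream for "app0"; the returned id keys into getStreams().
    long streamId = manager.registerStream("app0", buffers, mock(Channel.class));

    // The new accessors expose per-stream state to callers such as
    // ExternalBlockHandler.applicationRemoved() in the next file.
    StreamState state = manager.getStreams().get(streamId);
    System.out.println(state.getAppId()); // prints "app0"
  }
}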
ExternalBlockHandler.java

@@ -23,6 +23,7 @@
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Function;
 
 import com.codahale.metrics.Gauge;
@@ -39,6 +40,7 @@
 import org.apache.spark.network.client.RpcResponseCallback;
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.server.OneForOneStreamManager;
+import org.apache.spark.network.server.OneForOneStreamManager.StreamState;
 import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.server.StreamManager;
 import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId;
@@ -109,15 +111,29 @@ protected void handleMessage(
           numBlockIds += ids.length;
         }
       }
-      streamId = streamManager.registerStream(client.getClientId(),
-        new ShuffleManagedBufferIterator(msg), client.getChannel());
+      if (shouldRegisterStream(msg.appId, msg.execId, client)) {
+        streamId = streamManager.registerStream(msg.appId,
+          new ShuffleManagedBufferIterator(msg), client.getChannel());
+      } else {
+        Exception e = new RuntimeException("cannot register stream since app " +
+          msg.appId + " has already terminated");
+        callback.onFailure(e);
+        return;
+      }
     } else {
       // For compatibility with older versions, keep supporting OpenBlocks.
       OpenBlocks msg = (OpenBlocks) msgObj;
       numBlockIds = msg.blockIds.length;
       checkAuth(client, msg.appId);
-      streamId = streamManager.registerStream(client.getClientId(),
-        new ManagedBufferIterator(msg), client.getChannel());
+      if (shouldRegisterStream(msg.appId, msg.execId, client)) {
+        streamId = streamManager.registerStream(msg.appId,
+          new ManagedBufferIterator(msg), client.getChannel());
+      } else {
+        Exception e = new RuntimeException("cannot register stream since app " +
+          msg.appId + " has already terminated");
+        callback.onFailure(e);
+        return;
+      }
     }
     if (logger.isTraceEnabled()) {
       logger.trace(
@@ -180,6 +196,22 @@ public StreamManager getStreamManager() {
    * local directories associated with the executors of that application in a separate thread.
    */
   public void applicationRemoved(String appId, boolean cleanupLocalDirs) {
+    logger.info("Current number of stream states: {}", streamManager.numStreamStates());
+    ConcurrentHashMap<Long, StreamState> streams = streamManager.getStreams();
+    for (Map.Entry<Long, StreamState> entry : streams.entrySet()) {
+      StreamState state = entry.getValue();
+      if (state.getAppId().equals(appId)) {
+        logger.warn("Found finished app {}, but its StreamState is still in memory; " +
+          "cleaning it now", appId);
+        streams.remove(entry.getKey());
+
+        // Release all remaining buffers.
+        while (state.getBuffers().hasNext()) {
+          state.getBuffers().next().release();
+        }
+      }
+    }
+
     blockManager.applicationRemoved(appId, cleanupLocalDirs);
   }
 
@@ -388,4 +420,13 @@ public void channelInactive(TransportClient client) {
     super.channelInactive(client);
   }
 
+  private boolean shouldRegisterStream(String appId, String execId, TransportClient client) {
+    if (!blockManager.getExecutors().containsKey(new AppExecId(appId, execId))) {
+      logger.warn("App {} with execId {} does not exist, so do not register stream",
+        appId, execId);
+      client.getChannel().close();
+      return false;
+    }
+    return true;
+  }
 }
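A note on the cleanup loop added to applicationRemoved() above: removing entries while iterating is only safe because the streams map is a ConcurrentHashMap, whose iterators are weakly consistent and never throw ConcurrentModificationException. A standalone sketch of that property; the class and variable names here are illustrative, not part of the patch:

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class WeaklyConsistentIterationSketch {
  public static void main(String[] args) {
    ConcurrentHashMap<Long, String> streams = new ConcurrentHashMap<>();
    streams.put(1L, "app0");
    streams.put(2L, "app1");

    // Removing during iteration would typically throw
    // ConcurrentModificationException on a HashMap; on a
    // ConcurrentHashMap it is safe without extra locking.
    for (Map.Entry<Long, String> entry : streams.entrySet()) {
      if ("app0".equals(entry.getValue())) {
        streams.remove(entry.getKey());
      }
    }
    System.out.println(streams); // {2=app1}
  }
}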
ExternalShuffleBlockResolver.java

@@ -76,6 +76,10 @@ public class ExternalShuffleBlockResolver {
   @VisibleForTesting
   final ConcurrentMap<AppExecId, ExecutorShuffleInfo> executors;
 
+  public ConcurrentMap<AppExecId, ExecutorShuffleInfo> getExecutors() {
+    return executors;
+  }
+
   /**
    * Caches index file information so that we can avoid open/close the index files
    * for each block fetch.
ExternalBlockHandlerSuite.java

@@ -18,10 +18,15 @@
 package org.apache.spark.network.shuffle;
 
 import java.nio.ByteBuffer;
+import java.util.ArrayList;
 import java.util.Iterator;
+import java.util.List;
+import java.util.concurrent.ConcurrentMap;
 
 import com.codahale.metrics.Meter;
 import com.codahale.metrics.Timer;
+import com.google.common.collect.Maps;
+import io.netty.channel.Channel;
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.ArgumentCaptor;
@@ -35,6 +40,7 @@
 import org.apache.spark.network.client.RpcResponseCallback;
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.server.OneForOneStreamManager;
+import org.apache.spark.network.server.OneForOneStreamManager.StreamState;
 import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
 import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
@@ -43,6 +49,7 @@
 import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
 import org.apache.spark.network.shuffle.protocol.StreamHandle;
 import org.apache.spark.network.shuffle.protocol.UploadBlock;
+import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId;
 
 public class ExternalBlockHandlerSuite {
   TransportClient client = mock(TransportClient.class);
@@ -161,6 +168,13 @@ public void testOpenDiskPersistedRDDBlocksWithMissingBlock() {
   private void checkOpenBlocksReceive(BlockTransferMessage msg, ManagedBuffer[] blockMarkers) {
     when(client.getClientId()).thenReturn("app0");
 
+    // Add app0's executor info so the new registration guard passes.
+    ExecutorShuffleInfo shuffleInfo = new ExecutorShuffleInfo(new String[] {"/a", "/b"},
+      16, "sort");
+    ConcurrentMap<AppExecId, ExecutorShuffleInfo> executors = Maps.newConcurrentMap();
+    executors.put(new AppExecId("app0", "exec1"), shuffleInfo);
+    when(blockResolver.getExecutors()).thenReturn(executors);
+
     RpcResponseCallback callback = mock(RpcResponseCallback.class);
     handler.receive(client, msg.toByteBuffer(), callback);
 
@@ -222,4 +236,53 @@ public void testBadMessages() {
     verify(callback, never()).onSuccess(any(ByteBuffer.class));
     verify(callback, never()).onFailure(any(Throwable.class));
   }
+
+  @Test
+  public void testDoNotRegisterStreamWhenAppHasFinished() {
+    // Add app0's executor info to the resolver.
+    ExecutorShuffleInfo shuffleInfo = new ExecutorShuffleInfo(new String[] {"/a", "/b"},
+      16, "sort");
+    ConcurrentMap<AppExecId, ExecutorShuffleInfo> executors = Maps.newConcurrentMap();
+    executors.put(new AppExecId("app0", "exec1"), shuffleInfo);
+    when(blockResolver.getExecutors()).thenReturn(executors);
+
+    Channel dummyChannel = mock(Channel.class, RETURNS_SMART_NULLS);
+    when(client.getChannel()).thenReturn(dummyChannel);
+    when(client.getClientId()).thenReturn("app1");
+
+    // Assume app1's executor info has already been cleaned up.
+    // OpenBlocks message case:
+    OpenBlocks openBlocks = new OpenBlocks(
+      "app1", "exec1", new String[] { "shuffle_0_0_0", "shuffle_0_0_1" });
+    RpcResponseCallback callback = mock(RpcResponseCallback.class);
+    handler.receive(client, openBlocks.toByteBuffer(), callback);
+    verify(callback, times(1)).onFailure(any(Throwable.class));
+    assertEquals(0, streamManager.numStreamStates());
+
+    // FetchShuffleBlocks message case:
+    FetchShuffleBlocks fetchShuffleBlocks = new FetchShuffleBlocks(
+      "app1", "exec1", 0, new long[] { 0 }, new int[][] {{ 0, 1 }}, false);
+    handler.receive(client, fetchShuffleBlocks.toByteBuffer(), callback);
+    verify(callback, never()).onSuccess(any(ByteBuffer.class));
+    assertEquals(0, streamManager.numStreamStates());
+  }
+
+  @Test
+  public void testWhenApplicationRemovedCleanRelatedStreamState() {
+    OneForOneStreamManager oneForOneStreamManager = new OneForOneStreamManager();
+    ExternalBlockHandler externalBlockHandler = new ExternalBlockHandler(
+      oneForOneStreamManager,
+      blockResolver);
+    Channel dummyChannel = mock(Channel.class, RETURNS_SMART_NULLS);
+
+    List<ManagedBuffer> buffers = new ArrayList<>();
+    buffers.add(new NioManagedBuffer(ByteBuffer.wrap(new byte[3])));
+    buffers.add(new NioManagedBuffer(ByteBuffer.wrap(new byte[7])));
+    oneForOneStreamManager.getStreams().put(1L,
+      new StreamState("app0", buffers.iterator(), dummyChannel));
+    assertEquals(1, oneForOneStreamManager.numStreamStates());
+
+    externalBlockHandler.applicationRemoved("app0", false);
+    assertEquals(0, oneForOneStreamManager.numStreamStates());
+  }
 }
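For completeness, what a fetch client observes when the guard rejects registration, as exercised by testDoNotRegisterStreamWhenAppHasFinished above: the server invokes the callback's failure path with the RuntimeException built in handleMessage() and then closes the channel. A hypothetical client-side callback implementing the existing RpcResponseCallback interface (this class is not part of the patch):

import java.nio.ByteBuffer;

import org.apache.spark.network.client.RpcResponseCallback;

public class TerminatedAppCallbackSketch implements RpcResponseCallback {
  @Override
  public void onSuccess(ByteBuffer response) {
    // Not reached for a terminated app: registration is refused before
    // any StreamHandle is sent back.
  }

  @Override
  public void onFailure(Throwable e) {
    // Receives the RuntimeException from handleMessage(), e.g.
    // "cannot register stream since app app1 has already terminated".
    System.err.println("Block fetch refused: " + e.getMessage());
  }
}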