Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ private TransportChannelHandler createChannelHandler(Channel channel, RpcHandler
TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
TransportClient client = new TransportClient(channel, responseHandler);
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
rpcHandler);
rpcHandler, conf.maxChunksBeingTransferred());
return new TransportChannelHandler(client, responseHandler, requestHandler,
conf.connectionTimeoutMs(), closeIdleConnections);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

import com.google.common.base.Preconditions;
import io.netty.channel.Channel;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -53,6 +55,9 @@ private static class StreamState {
// that the caller only requests each chunk one at a time, in order.
int curChunk = 0;

// Used to keep track of the number of chunks being transferred and not finished yet.
volatile long chunksBeingTransferred = 0L;

StreamState(String appId, Iterator<ManagedBuffer> buffers) {
this.appId = appId;
this.buffers = Preconditions.checkNotNull(buffers);
Expand Down Expand Up @@ -96,18 +101,25 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) {

@Override
public ManagedBuffer openStream(String streamChunkId) {
String[] array = streamChunkId.split("_");
assert array.length == 2:
"Stream id and chunk index should be specified when open stream for fetching block.";
long streamId = Long.valueOf(array[0]);
int chunkIndex = Integer.valueOf(array[1]);
return getChunk(streamId, chunkIndex);
Pair<Long, Integer> streamChunkIdPair = parseStreamChunkId(streamChunkId);
return getChunk(streamChunkIdPair.getLeft(), streamChunkIdPair.getRight());
}

public static String genStreamChunkId(long streamId, int chunkId) {
return String.format("%d_%d", streamId, chunkId);
}

// Parse streamChunkId to be stream id and chunk id. This is used when fetch remote chunk as a
// stream.
public static Pair<Long, Integer> parseStreamChunkId(String streamChunkId) {
String[] array = streamChunkId.split("_");
assert array.length == 2:
"Stream id and chunk index should be specified.";
long streamId = Long.valueOf(array[0]);
int chunkIndex = Integer.valueOf(array[1]);
return ImmutablePair.of(streamId, chunkIndex);
}

@Override
public void connectionTerminated(Channel channel) {
// Close all streams which have been associated with the channel.
Expand Down Expand Up @@ -139,6 +151,42 @@ public void checkAuthorization(TransportClient client, long streamId) {
}
}

@Override
public void chunkBeingSent(long streamId) {
StreamState streamState = streams.get(streamId);
if (streamState != null) {
streamState.chunksBeingTransferred++;
}

}

@Override
public void streamBeingSent(String streamId) {
chunkBeingSent(parseStreamChunkId(streamId).getLeft());
}

@Override
public void chunkSent(long streamId) {
StreamState streamState = streams.get(streamId);
if (streamState != null) {
streamState.chunksBeingTransferred--;
}
}

@Override
public void streamSent(String streamId) {
chunkSent(OneForOneStreamManager.parseStreamChunkId(streamId).getLeft());
}

@Override
public long chunksBeingTransferred() {
long sum = 0L;
for (StreamState streamState: streams.values()) {
sum += streamState.chunksBeingTransferred;
}
return sum;
}

/**
* Registers a stream of ManagedBuffers which are served as individual chunks one at a time to
* callers. Each ManagedBuffer will be release()'d after it is transferred on the wire. If a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,31 @@ public void connectionTerminated(Channel channel) { }
*/
public void checkAuthorization(TransportClient client, long streamId) { }

/**
* Return the number of chunks being transferred and not finished yet in this StreamManager.
*/
public long chunksBeingTransferred() {
return 0;
}

/**
* Called when start sending a chunk.
*/
public void chunkBeingSent(long streamId) { }

/**
* Called when start sending a stream.
*/
public void streamBeingSent(String streamId) { }

/**
* Called when a chunk is successfully sent.
*/
public void chunkSent(long streamId) { }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it makes more sense to have 2 methods:

chunkSent(long streamId);

streamSent(String streamId);

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it would be much better


/**
* Called when a stream is successfully sent.
*/
public void streamSent(String streamId) { }

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import com.google.common.base.Throwables;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -65,14 +66,19 @@ public class TransportRequestHandler extends MessageHandler<RequestMessage> {
/** Returns each chunk part of a stream. */
private final StreamManager streamManager;

/** The max number of chunks being transferred and not finished yet. */
private final long maxChunksBeingTransferred;

public TransportRequestHandler(
Channel channel,
TransportClient reverseClient,
RpcHandler rpcHandler) {
RpcHandler rpcHandler,
Long maxChunksBeingTransferred) {
this.channel = channel;
this.reverseClient = reverseClient;
this.rpcHandler = rpcHandler;
this.streamManager = rpcHandler.getStreamManager();
this.maxChunksBeingTransferred = maxChunksBeingTransferred;
}

@Override
Expand Down Expand Up @@ -117,7 +123,13 @@ private void processFetchRequest(final ChunkFetchRequest req) {
logger.trace("Received req from {} to fetch block {}", getRemoteAddress(channel),
req.streamChunkId);
}

long chunksBeingTransferred = streamManager.chunksBeingTransferred();
if (chunksBeingTransferred >= maxChunksBeingTransferred) {
logger.warn("The number of chunks being transferred {} is above {}, close the connection.",
chunksBeingTransferred, maxChunksBeingTransferred);
channel.close();
return;
}
ManagedBuffer buf;
try {
streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId);
Expand All @@ -130,10 +142,25 @@ private void processFetchRequest(final ChunkFetchRequest req) {
return;
}

respond(new ChunkFetchSuccess(req.streamChunkId, buf));
streamManager.chunkBeingSent(req.streamChunkId.streamId);
respond(new ChunkFetchSuccess(req.streamChunkId, buf)).addListener(future -> {
streamManager.chunkSent(req.streamChunkId.streamId);
});
}

private void processStreamRequest(final StreamRequest req) {
if (logger.isTraceEnabled()) {
logger.trace("Received req from {} to fetch stream {}", getRemoteAddress(channel),
req.streamId);
}

long chunksBeingTransferred = streamManager.chunksBeingTransferred();
if (chunksBeingTransferred >= maxChunksBeingTransferred) {
logger.warn("The number of chunks being transferred {} is above {}, close the connection.",
chunksBeingTransferred, maxChunksBeingTransferred);
channel.close();
return;
}
ManagedBuffer buf;
try {
buf = streamManager.openStream(req.streamId);
Expand All @@ -145,7 +172,10 @@ private void processStreamRequest(final StreamRequest req) {
}

if (buf != null) {
respond(new StreamResponse(req.streamId, buf.size(), buf));
streamManager.streamBeingSent(req.streamId);
respond(new StreamResponse(req.streamId, buf.size(), buf)).addListener(future -> {
streamManager.streamSent(req.streamId);
});
} else {
respond(new StreamFailure(req.streamId, String.format(
"Stream '%s' was not found.", req.streamId)));
Expand Down Expand Up @@ -187,9 +217,9 @@ private void processOneWayMessage(OneWayMessage req) {
* Responds to a single message with some Encodable object. If a failure occurs while sending,
* it will be logged and the channel closed.
*/
private void respond(Encodable result) {
private ChannelFuture respond(Encodable result) {
SocketAddress remoteAddress = channel.remoteAddress();
channel.writeAndFlush(result).addListener(future -> {
return channel.writeAndFlush(result).addListener(future -> {
if (future.isSuccess()) {
logger.trace("Sent result {} to client {}", result, remoteAddress);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,4 +257,10 @@ public Properties cryptoConf() {
return CryptoUtils.toCryptoConf("spark.network.crypto.config.", conf.getAll());
}

/**
* The max number of chunks allowed to being transferred at the same time on shuffle service.
*/
public long maxChunksBeingTransferred() {
return conf.getLong("spark.shuffle.maxChunksBeingTransferred", Long.MAX_VALUE);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.network;

import java.util.ArrayList;
import java.util.List;

import io.netty.channel.Channel;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelPromise;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import org.junit.Test;

import static org.mockito.Mockito.*;

import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.protocol.*;
import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.server.OneForOneStreamManager;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportRequestHandler;

public class TransportRequestHandlerSuite {

@Test
public void handleFetchRequestAndStreamRequest() throws Exception {
RpcHandler rpcHandler = new NoOpRpcHandler();
OneForOneStreamManager streamManager = (OneForOneStreamManager) (rpcHandler.getStreamManager());
Channel channel = mock(Channel.class);
List<Pair<Object, ExtendedChannelPromise>> responseAndPromisePairs =
new ArrayList<>();
when(channel.writeAndFlush(any()))
.thenAnswer(invocationOnMock0 -> {
Object response = invocationOnMock0.getArguments()[0];
ExtendedChannelPromise channelFuture = new ExtendedChannelPromise(channel);
responseAndPromisePairs.add(ImmutablePair.of(response, channelFuture));
return channelFuture;
});

// Prepare the stream.
List<ManagedBuffer> managedBuffers = new ArrayList<>();
managedBuffers.add(new TestManagedBuffer(10));
managedBuffers.add(new TestManagedBuffer(20));
managedBuffers.add(new TestManagedBuffer(30));
managedBuffers.add(new TestManagedBuffer(40));
long streamId = streamManager.registerStream("test-app", managedBuffers.iterator());
streamManager.registerChannel(channel, streamId);
TransportClient reverseClient = mock(TransportClient.class);
TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient,
rpcHandler, 2L);

RequestMessage request0 = new ChunkFetchRequest(new StreamChunkId(streamId, 0));
requestHandler.handle(request0);
assert responseAndPromisePairs.size() == 1;
assert responseAndPromisePairs.get(0).getLeft() instanceof ChunkFetchSuccess;
assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(0).getLeft())).body() ==
managedBuffers.get(0);

RequestMessage request1 = new ChunkFetchRequest(new StreamChunkId(streamId, 1));
requestHandler.handle(request1);
assert responseAndPromisePairs.size() == 2;
assert responseAndPromisePairs.get(1).getLeft() instanceof ChunkFetchSuccess;
assert ((ChunkFetchSuccess) (responseAndPromisePairs.get(1).getLeft())).body() ==
managedBuffers.get(1);

// Finish flushing the response for request0.
responseAndPromisePairs.get(0).getRight().finish(true);

RequestMessage request2 = new StreamRequest(String.format("%d_%d", streamId, 2));
requestHandler.handle(request2);
assert responseAndPromisePairs.size() == 3;
assert responseAndPromisePairs.get(2).getLeft() instanceof StreamResponse;
assert ((StreamResponse) (responseAndPromisePairs.get(2).getLeft())).body() ==
managedBuffers.get(2);

// Request3 will trigger the close of channel, because the number of max chunks being
// transferred is 2;
RequestMessage request3 = new StreamRequest(String.format("%d_%d", streamId, 3));
requestHandler.handle(request3);
verify(channel, times(1)).close();
assert responseAndPromisePairs.size() == 3;
}

private class ExtendedChannelPromise extends DefaultChannelPromise {

private List<GenericFutureListener> listeners = new ArrayList<>();
private boolean success;

public ExtendedChannelPromise(Channel channel) {
super(channel);
success = false;
}

@Override
public ChannelPromise addListener(
GenericFutureListener<? extends Future<? super Void>> listener) {
listeners.add(listener);
return super.addListener(listener);
}

@Override
public boolean isSuccess() {
return success;
}

public void finish(boolean success) {
this.success = success;
listeners.forEach(listener -> {
try {
listener.operationComplete(this);
} catch (Exception e) { }
});
}
}
}
7 changes: 7 additions & 0 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,13 @@ Apart from these, the following properties are also available, and may be useful
Max number of entries to keep in the index cache of the shuffle service.
</td>
</tr>
<tr>
<td><code>spark.shuffle.maxChunksBeingTransferred</code></td>
<td>Long.MAX_VALUE</td>
<td>
The max number of chunks allowed to being transferred at the same time on shuffle service.
</td>
</tr>
<tr>
<td><code>spark.shuffle.sort.bypassMergeThreshold</code></td>
<td>200</td>
Expand Down