Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
import java.nio.channels.spi.SelectorProvider;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.atomic.AtomicBoolean;

import org.apache.hadoop.util.Time;
import org.slf4j.Logger;
Expand All @@ -48,8 +49,6 @@ abstract class SocketIOWithTimeout {
private long timeout;
private boolean closed = false;

private static SelectorPool selector = new SelectorPool();

/* A timeout value of 0 implies wait for ever.
* We should have a value of timeout that implies zero wait.. i.e.
* read or write returns immediately.
Expand Down Expand Up @@ -154,7 +153,7 @@ int doIO(ByteBuffer buf, int ops) throws IOException {
//now wait for socket to be ready.
int count = 0;
try {
count = selector.select(channel, ops, timeout);
count = SelectorPool.select(channel, ops, timeout);
} catch (IOException e) { //unexpected IOException.
closed = true;
throw e;
Expand Down Expand Up @@ -200,7 +199,7 @@ static void connect(SocketChannel channel,
// we might have to call finishConnect() more than once
// for some channels (with user level protocols)

int ret = selector.select((SelectableChannel)channel,
int ret = SelectorPool.select(channel,
SelectionKey.OP_CONNECT, timeoutLeft);

if (ret > 0 && channel.finishConnect()) {
Expand Down Expand Up @@ -242,7 +241,7 @@ static void connect(SocketChannel channel,
*/
void waitForIO(int ops) throws IOException {

if (selector.select(channel, ops, timeout) == 0) {
if (SelectorPool.select(channel, ops, timeout) == 0) {
throw new SocketTimeoutException(timeoutExceptionString(channel, timeout,
ops));
}
Expand Down Expand Up @@ -280,12 +279,17 @@ private static String timeoutExceptionString(SelectableChannel channel,
* This maintains a pool of selectors. These selectors are closed
* once they are idle (unused) for a few seconds.
*/
private static class SelectorPool {
private static final class SelectorPool {

private static class SelectorInfo {
Selector selector;
long lastActivityTime;
LinkedList<SelectorInfo> queue;
private static final class SelectorInfo {
private final SelectorProvider provider;
private final Selector selector;
private long lastActivityTime;

private SelectorInfo(SelectorProvider provider, Selector selector) {
this.provider = provider;
this.selector = selector;
}

void close() {
if (selector != null) {
Expand All @@ -298,16 +302,11 @@ void close() {
}
}

private static class ProviderInfo {
SelectorProvider provider;
LinkedList<SelectorInfo> queue; // lifo
ProviderInfo next;
}
private static ConcurrentHashMap<SelectorProvider, ConcurrentLinkedDeque
<SelectorInfo>> providerMap = new ConcurrentHashMap<>();

private static final long IDLE_TIMEOUT = 10 * 1000; // 10 seconds.

private ProviderInfo providerList = null;

/**
* Waits on the channel with the given timeout using one of the
* cached selectors. It also removes any cached selectors that are
Expand All @@ -319,7 +318,7 @@ private static class ProviderInfo {
* @return
* @throws IOException
*/
int select(SelectableChannel channel, int ops, long timeout)
static int select(SelectableChannel channel, int ops, long timeout)
throws IOException {

SelectorInfo info = get(channel);
Expand Down Expand Up @@ -385,35 +384,18 @@ int select(SelectableChannel channel, int ops, long timeout)
* @return
* @throws IOException
*/
private synchronized SelectorInfo get(SelectableChannel channel)
private static SelectorInfo get(SelectableChannel channel)
throws IOException {
SelectorInfo selInfo = null;

SelectorProvider provider = channel.provider();

// pick the list : rarely there is more than one provider in use.
ProviderInfo pList = providerList;
while (pList != null && pList.provider != provider) {
pList = pList.next;
}
if (pList == null) {
//LOG.info("Creating new ProviderInfo : " + provider.toString());
pList = new ProviderInfo();
pList.provider = provider;
pList.queue = new LinkedList<SelectorInfo>();
pList.next = providerList;
providerList = pList;
}

LinkedList<SelectorInfo> queue = pList.queue;

if (queue.isEmpty()) {
ConcurrentLinkedDeque<SelectorInfo> infoQ = providerMap.computeIfAbsent(
provider, k -> new ConcurrentLinkedDeque<>());

SelectorInfo selInfo = infoQ.pollLast(); // last in first out
if (selInfo == null) {
Selector selector = provider.openSelector();
selInfo = new SelectorInfo();
selInfo.selector = selector;
selInfo.queue = queue;
} else {
selInfo = queue.removeLast();
// selInfo will be put into infoQ after `#release()`
selInfo = new SelectorInfo(provider, selector);
}

trimIdleSelectors(Time.now());
Expand All @@ -426,34 +408,39 @@ private synchronized SelectorInfo get(SelectableChannel channel)
*
* @param info
*/
private synchronized void release(SelectorInfo info) {
private static void release(SelectorInfo info) {
long now = Time.now();
trimIdleSelectors(now);
info.lastActivityTime = now;
info.queue.addLast(info);
// SelectorInfos in queue are sorted by lastActivityTime
providerMap.get(info.provider).addLast(info);
}

private static AtomicBoolean trimming = new AtomicBoolean(false);

/**
* Closes selectors that are idle for IDLE_TIMEOUT (10 sec). It does not
* traverse the whole list, just over the one that have crossed
* the timeout.
*/
private void trimIdleSelectors(long now) {
private static void trimIdleSelectors(long now) {
if (!trimming.compareAndSet(false, true)) {
return;
}

long cutoff = now - IDLE_TIMEOUT;

for(ProviderInfo pList=providerList; pList != null; pList=pList.next) {
if (pList.queue.isEmpty()) {
continue;
}
for(Iterator<SelectorInfo> it = pList.queue.iterator(); it.hasNext();) {
SelectorInfo info = it.next();
if (info.lastActivityTime > cutoff) {
for (ConcurrentLinkedDeque<SelectorInfo> infoQ : providerMap.values()) {
SelectorInfo oldest;
while ((oldest = infoQ.peekFirst()) != null) {
if (oldest.lastActivityTime <= cutoff && infoQ.remove(oldest)) {
oldest.close();
} else {
break;
}
it.remove();
info.close();
}
}

trimming.set(false);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@
import java.net.SocketTimeoutException;
import java.nio.channels.Pipe;
import java.util.Arrays;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;

import org.apache.hadoop.test.GenericTestUtils;
import org.apache.hadoop.test.MultithreadedTestUtil;
Expand Down Expand Up @@ -186,6 +191,46 @@ public void doWork() throws Exception {
}
}

@Test
public void testSocketIOWithTimeoutByMultiThread() throws Exception {
CountDownLatch latch = new CountDownLatch(1);
Runnable ioTask = () -> {
try {
Pipe pipe = Pipe.open();
try (Pipe.SourceChannel source = pipe.source();
InputStream in = new SocketInputStream(source, TIMEOUT);
Pipe.SinkChannel sink = pipe.sink();
OutputStream out = new SocketOutputStream(sink, TIMEOUT)) {

byte[] writeBytes = TEST_STRING.getBytes();
byte[] readBytes = new byte[writeBytes.length];
latch.await();

out.write(writeBytes);
doIO(null, out, TIMEOUT);

in.read(readBytes);
assertArrayEquals(writeBytes, readBytes);
doIO(in, null, TIMEOUT);
}
} catch (Exception e) {
fail(e.getMessage());
}
};

int threadCnt = 64;
ExecutorService threadPool = Executors.newFixedThreadPool(threadCnt);
for (int i = 0; i < threadCnt; ++i) {
threadPool.submit(ioTask);
}

Thread.sleep(1000);
latch.countDown();

threadPool.shutdown();
assertTrue(threadPool.awaitTermination(3, TimeUnit.SECONDS));
}

@Test
public void testSocketIOWithTimeoutInterrupted() throws Exception {
Pipe pipe = Pipe.open();
Expand Down Expand Up @@ -223,4 +268,38 @@ public void doWork() throws Exception {
ctx.stop();
}
}

@Test
public void testSocketIOWithTimeoutInterruptedByMultiThread()
throws Exception {
final int timeout = TIMEOUT * 10;
AtomicLong readCount = new AtomicLong();
AtomicLong exceptionCount = new AtomicLong();
Runnable ioTask = () -> {
try {
Pipe pipe = Pipe.open();
try (Pipe.SourceChannel source = pipe.source();
InputStream in = new SocketInputStream(source, timeout)) {
in.read();
readCount.incrementAndGet();
} catch (InterruptedIOException ste) {
exceptionCount.incrementAndGet();
}
} catch (Exception e) {
fail(e.getMessage());
}
};

int threadCnt = 64;
ExecutorService threadPool = Executors.newFixedThreadPool(threadCnt);
for (int i = 0; i < threadCnt; ++i) {
threadPool.submit(ioTask);
}
Thread.sleep(1000);
threadPool.shutdownNow();
threadPool.awaitTermination(1, TimeUnit.SECONDS);

assertEquals(0, readCount.get());
assertEquals(threadCnt, exceptionCount.get());
}
}