Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import java.net.StandardSocketOptions;
import java.nio.ByteBuffer;
import java.nio.channels.CompletionHandler;
import java.nio.channels.InterruptedByTimeoutException;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.SocketChannel;
Expand All @@ -49,6 +50,7 @@
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import static com.mongodb.assertions.Assertions.assertTrue;
import static com.mongodb.assertions.Assertions.isTrue;
Expand Down Expand Up @@ -97,21 +99,40 @@ public void close() {
group.shutdown();
}

/**
* Monitors `OP_CONNECT` events for socket connections.
*/
private static class SelectorMonitor implements Closeable {

private static final class Pair {
static final class SocketRegistration {
private final SocketChannel socketChannel;
private final Runnable attachment;
private final AtomicReference<Runnable> afterConnectAction;

private Pair(final SocketChannel socketChannel, final Runnable attachment) {
SocketRegistration(final SocketChannel socketChannel, final Runnable afterConnectAction) {
this.socketChannel = socketChannel;
this.attachment = attachment;
this.afterConnectAction = new AtomicReference<>(afterConnectAction);
}

boolean tryCancelPendingConnection() {
return tryTakeAction() != null;
}

void runAfterConnectActionIfNotCanceled() {
Runnable afterConnectActionToExecute = tryTakeAction();
if (afterConnectActionToExecute != null) {
afterConnectActionToExecute.run();
}
}

@Nullable
private Runnable tryTakeAction() {
return afterConnectAction.getAndSet(null);
}
}

private final Selector selector;
private volatile boolean isClosed;
private final ConcurrentLinkedDeque<Pair> pendingRegistrations = new ConcurrentLinkedDeque<>();
private final ConcurrentLinkedDeque<SocketRegistration> pendingRegistrations = new ConcurrentLinkedDeque<>();

SelectorMonitor() {
try {
Expand All @@ -127,17 +148,14 @@ void start() {
while (!isClosed) {
try {
selector.select();

for (SelectionKey selectionKey : selector.selectedKeys()) {
selectionKey.cancel();
Runnable runnable = (Runnable) selectionKey.attachment();
runnable.run();
((SocketRegistration) selectionKey.attachment()).runAfterConnectActionIfNotCanceled();
}

for (Iterator<Pair> iter = pendingRegistrations.iterator(); iter.hasNext();) {
Pair pendingRegistration = iter.next();
pendingRegistration.socketChannel.register(selector, SelectionKey.OP_CONNECT,
pendingRegistration.attachment);
for (Iterator<SocketRegistration> iter = pendingRegistrations.iterator(); iter.hasNext();) {
SocketRegistration pendingRegistration = iter.next();
pendingRegistration.socketChannel.register(selector, SelectionKey.OP_CONNECT, pendingRegistration);
iter.remove();
}
} catch (Exception e) {
Expand All @@ -156,8 +174,9 @@ void start() {
selectorThread.start();
}

void register(final SocketChannel channel, final Runnable attachment) {
pendingRegistrations.add(new Pair(channel, attachment));

void register(final SocketRegistration registration) {
pendingRegistrations.add(registration);
selector.wakeup();
}

Expand Down Expand Up @@ -200,44 +219,82 @@ public void openAsync(final OperationContext operationContext, final AsyncComple
if (getSettings().getSendBufferSize() > 0) {
socketChannel.setOption(StandardSocketOptions.SO_SNDBUF, getSettings().getSendBufferSize());
}

//getConnectTimeoutMs MUST be called before connection attempt, as it might throw MongoOperationTimeout exception.
int connectTimeoutMs = operationContext.getTimeoutContext().getConnectTimeoutMs();
socketChannel.connect(getSocketAddresses(getServerAddress(), inetAddressResolver).get(0));
SelectorMonitor.SocketRegistration socketRegistration = new SelectorMonitor.SocketRegistration(
socketChannel, () -> initializeTslChannel(handler, socketChannel));

selectorMonitor.register(socketChannel, () -> {
try {
if (!socketChannel.finishConnect()) {
throw new MongoSocketOpenException("Failed to finish connect", getServerAddress());
}
if (connectTimeoutMs > 0) {
scheduleTimeoutInterruption(handler, socketRegistration, connectTimeoutMs);
}
selectorMonitor.register(socketRegistration);
} catch (IOException e) {
handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e));
} catch (Throwable t) {
handler.failed(t);
}
}

SSLEngine sslEngine = getSslContext().createSSLEngine(getServerAddress().getHost(),
getServerAddress().getPort());
sslEngine.setUseClientMode(true);
private void scheduleTimeoutInterruption(final AsyncCompletionHandler<Void> handler,
final SelectorMonitor.SocketRegistration socketRegistration,
final int connectTimeoutMs) {
group.getTimeoutExecutor().schedule(() -> {
if (socketRegistration.tryCancelPendingConnection()) {
closeAndTimeout(handler, socketRegistration.socketChannel);
}
}, connectTimeoutMs, TimeUnit.MILLISECONDS);
}

SSLParameters sslParameters = sslEngine.getSSLParameters();
enableSni(getServerAddress().getHost(), sslParameters);
private void closeAndTimeout(final AsyncCompletionHandler<Void> handler, final SocketChannel socketChannel) {
// We check if this stream was closed before timeout exception.
boolean streamClosed = isClosed();

if (!sslSettings.isInvalidHostNameAllowed()) {
enableHostNameVerification(sslParameters);
}
sslEngine.setSSLParameters(sslParameters);
//TODO refactor ths draft
InterruptedByTimeoutException timeoutException = new InterruptedByTimeoutException();
try {
socketChannel.close();
} catch (Exception e) {
//TODO should ignore this exception? We seem to do so in other places
timeoutException.addSuppressed(e);
}

BufferAllocator bufferAllocator = new BufferProviderAllocator();
if (streamClosed) {
handler.completed(null);
} else {
handler.failed(new MongoSocketOpenException("Exception opening socket", getAddress(), timeoutException));
}
}

TlsChannel tlsChannel = ClientTlsChannel.newBuilder(socketChannel, sslEngine)
.withEncryptedBufferAllocator(bufferAllocator)
.withPlainBufferAllocator(bufferAllocator)
.build();
private void initializeTslChannel(final AsyncCompletionHandler<Void> handler, final SocketChannel socketChannel) {
try {
if (!socketChannel.finishConnect()) {
throw new MongoSocketOpenException("Failed to finish connect", getServerAddress());
}

// build asynchronous channel, based in the TLS channel and associated with the global group.
setChannel(new AsynchronousTlsChannelAdapter(new AsynchronousTlsChannel(group, tlsChannel, socketChannel)));
SSLEngine sslEngine = getSslContext().createSSLEngine(getServerAddress().getHost(),
getServerAddress().getPort());
sslEngine.setUseClientMode(true);

handler.completed(null);
} catch (IOException e) {
handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e));
} catch (Throwable t) {
handler.failed(t);
}
});
SSLParameters sslParameters = sslEngine.getSSLParameters();
enableSni(getServerAddress().getHost(), sslParameters);

if (!sslSettings.isInvalidHostNameAllowed()) {
enableHostNameVerification(sslParameters);
}
sslEngine.setSSLParameters(sslParameters);

BufferAllocator bufferAllocator = new BufferProviderAllocator();

TlsChannel tlsChannel = ClientTlsChannel.newBuilder(socketChannel, sslEngine)
.withEncryptedBufferAllocator(bufferAllocator)
.withPlainBufferAllocator(bufferAllocator)
.build();

// build asynchronous channel, based in the TLS channel and associated with the global group.
setChannel(new AsynchronousTlsChannelAdapter(new AsynchronousTlsChannel(group, tlsChannel, socketChannel)));

handler.completed(null);
} catch (IOException e) {
handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e));
} catch (Throwable t) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -823,4 +823,13 @@ public long getCurrentWriteCount() {
public long getCurrentRegistrationCount() {
return registrations.mappingCount();
}

/**
* Returns the timeout executor used by this channel group.
*
* @return the timeout executor
*/
public ScheduledThreadPoolExecutor getTimeoutExecutor() {
return timeoutExecutor;
}
}