diff --git a/core/src/main/java/org/elasticsearch/action/ActionListener.java b/core/src/main/java/org/elasticsearch/action/ActionListener.java index fa32ab417737c..8579fb55613ce 100644 --- a/core/src/main/java/org/elasticsearch/action/ActionListener.java +++ b/core/src/main/java/org/elasticsearch/action/ActionListener.java @@ -24,6 +24,7 @@ import java.util.ArrayList; import java.util.List; +import java.util.function.BiConsumer; import java.util.function.Consumer; /** @@ -69,6 +70,42 @@ public void onFailure(Exception e) { }; } + /** + * Creates a listener that listens for a response (or failure) and executes the + * corresponding runnable when the response (or failure) is received. + * + * @param runnable the runnable that will be called in event of success or failure + * @param the type of the response + * @return a listener that listens for responses and invokes the runnable when received + */ + static ActionListener wrap(Runnable runnable) { + return wrap(r -> runnable.run(), e -> runnable.run()); + } + + /** + * Converts a listener to a {@link BiConsumer} for compatibility with the {@link java.util.concurrent.CompletableFuture} + * api. + * + * @param listener that will be wrapped + * @param the type of the response + * @return a bi consumer that will complete the wrapped listener + */ + static BiConsumer toBiConsumer(ActionListener listener) { + return (response, throwable) -> { + if (throwable == null) { + listener.onResponse(response); + } else { + if (throwable instanceof Exception) { + listener.onFailure((Exception) throwable); + } else if (throwable instanceof Error) { + throw (Error) throwable; + } else { + throw new AssertionError("Should have been either Error or Exception", throwable); + } + } + }; + } + /** * Notifies every given listener with the response passed to {@link #onResponse(Object)}. If a listener itself throws an exception * the exception is forwarded to {@link #onFailure(Exception)}. If in turn {@link #onFailure(Exception)} fails all remaining diff --git a/core/src/main/java/org/elasticsearch/transport/ConnectionProfile.java b/core/src/main/java/org/elasticsearch/transport/ConnectionProfile.java index 17f3f7b7b4a0b..a36c9f6f77b9b 100644 --- a/core/src/main/java/org/elasticsearch/transport/ConnectionProfile.java +++ b/core/src/main/java/org/elasticsearch/transport/ConnectionProfile.java @@ -208,12 +208,12 @@ private ConnectionTypeHandle(int offset, int length, Set T getChannel(T[] channels) { + T getChannel(List channels) { if (length == 0) { throw new IllegalStateException("can't select channel size is 0 for types: " + types); } - assert channels.length >= offset + length : "illegal size: " + channels.length + " expected >= " + (offset + length); - return channels[offset + Math.floorMod(counter.incrementAndGet(), length)]; + assert channels.size() >= offset + length : "illegal size: " + channels.size() + " expected >= " + (offset + length); + return channels.get(offset + Math.floorMod(counter.incrementAndGet(), length)); } /** @@ -223,5 +223,4 @@ Set getTypes() { return types; } } - } diff --git a/core/src/main/java/org/elasticsearch/transport/TcpChannel.java b/core/src/main/java/org/elasticsearch/transport/TcpChannel.java new file mode 100644 index 0000000000000..f429e71f4a874 --- /dev/null +++ b/core/src/main/java/org/elasticsearch/transport/TcpChannel.java @@ -0,0 +1,169 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport; + +import org.apache.logging.log4j.Logger; +import org.apache.logging.log4j.message.ParameterizedMessage; +import org.elasticsearch.action.ActionFuture; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.lease.Releasable; +import org.elasticsearch.common.lease.Releasables; +import org.elasticsearch.common.unit.TimeValue; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + + +/** + * This is a tcp channel representing a single channel connection to another node. It is the base channel + * abstraction used by the {@link TcpTransport} and {@link TransportService}. All tcp transport + * implementations must return channels that adhere to the required method contracts. + */ +public interface TcpChannel extends Releasable { + + /** + * Closes the channel. This might be an asynchronous process. There is notguarantee that the channel + * will be closed when this method returns. Use the {@link #addCloseListener(ActionListener)} method + * to implement logic that depends on knowing when the channel is closed. + */ + void close(); + + /** + * Adds a listener that will be executed when the channel is closed. If the channel is still open when + * this listener is added, the listener will be executed by the thread that eventually closes the + * channel. If the channel is already closed when the listener is added the listener will immediately be + * executed by the thread that is attempting to add the listener. + * + * @param listener to be executed + */ + void addCloseListener(ActionListener listener); + + + /** + * This sets the low level socket option {@link java.net.StandardSocketOptions} SO_LINGER on a channel. + * + * @param value to set for SO_LINGER + * @throws IOException that can be throw by the low level socket implementation + */ + void setSoLinger(int value) throws IOException; + + + /** + * Indicates whether a channel is currently open + * + * @return boolean indicating if channel is open + */ + boolean isOpen(); + + /** + * Closes the channel. + * + * @param channel to close + * @param blocking indicates if we should block on channel close + */ + static void closeChannel(C channel, boolean blocking) { + closeChannels(Collections.singletonList(channel), blocking); + } + + /** + * Closes the channels. + * + * @param channels to close + * @param blocking indicates if we should block on channel close + */ + static void closeChannels(List channels, boolean blocking) { + if (blocking) { + ArrayList> futures = new ArrayList<>(channels.size()); + for (final C channel : channels) { + if (channel.isOpen()) { + PlainActionFuture closeFuture = PlainActionFuture.newFuture(); + channel.addCloseListener(closeFuture); + channel.close(); + futures.add(closeFuture); + } + } + blockOnFutures(futures); + } else { + Releasables.close(channels); + } + } + + /** + * Awaits for all of the pending connections to complete. Will throw an exception if at least one of the + * connections fails. + * + * @param discoveryNode the node for the pending connections + * @param connectionFutures representing the pending connections + * @param connectTimeout to wait for a connection + * @param the type of channel + * @throws ConnectTransportException if one of the connections fails + */ + static void awaitConnected(DiscoveryNode discoveryNode, List> connectionFutures, + TimeValue connectTimeout) throws ConnectTransportException { + Exception connectionException = null; + boolean allConnected = true; + + for (ActionFuture connectionFuture : connectionFutures) { + try { + connectionFuture.get(connectTimeout.getMillis(), TimeUnit.MILLISECONDS); + } catch (TimeoutException e) { + allConnected = false; + break; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IllegalStateException(e); + } catch (ExecutionException e) { + allConnected = false; + connectionException = (Exception) e.getCause(); + break; + } + } + + if (allConnected == false) { + if (connectionException == null) { + throw new ConnectTransportException(discoveryNode, "connect_timeout[" + connectTimeout + "]"); + } else { + throw new ConnectTransportException(discoveryNode, "connect_exception", connectionException); + } + } + } + + static void blockOnFutures(List> futures) { + for (ActionFuture future : futures) { + try { + future.get(); + } catch (ExecutionException e) { + // Ignore as we are only interested in waiting for the close process to complete. Logging + // close exceptions happens elsewhere. + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IllegalStateException("Future got interrupted", e); + } + } + } +} diff --git a/core/src/main/java/org/elasticsearch/transport/TcpTransport.java b/core/src/main/java/org/elasticsearch/transport/TcpTransport.java index 62ad2b58fb78e..4092eb6256988 100644 --- a/core/src/main/java/org/elasticsearch/transport/TcpTransport.java +++ b/core/src/main/java/org/elasticsearch/transport/TcpTransport.java @@ -25,8 +25,10 @@ import org.apache.lucene.util.IOUtils; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.Version; +import org.elasticsearch.action.ActionFuture; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.NotifyOnceListener; +import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.Nullable; @@ -104,7 +106,6 @@ import java.util.function.Function; import java.util.regex.Matcher; import java.util.regex.Pattern; -import java.util.stream.Collectors; import static java.util.Collections.emptyList; import static java.util.Collections.unmodifiableMap; @@ -117,7 +118,7 @@ import static org.elasticsearch.common.transport.NetworkExceptionHelper.isConnectException; import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap; -public abstract class TcpTransport extends AbstractLifecycleComponent implements Transport { +public abstract class TcpTransport extends AbstractLifecycleComponent implements Transport { public static final String TRANSPORT_SERVER_WORKER_THREAD_NAME_PREFIX = "transport_server_worker"; public static final String TRANSPORT_CLIENT_BOSS_THREAD_NAME_PREFIX = "transport_client_boss"; @@ -178,7 +179,7 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i public static final Setting.AffixSetting> PUBLISH_HOST_PROFILE = affixKeySetting("transport.profiles.", "publish_host", key -> listSetting(key, PUBLISH_HOST, Function.identity(), Setting.Property.NodeScope)); public static final Setting.AffixSetting PORT_PROFILE = affixKeySetting("transport.profiles.", "port", - key -> new Setting(key, PORT, Function.identity(), Setting.Property.NodeScope)); + key -> new Setting<>(key, PORT, Function.identity(), Setting.Property.NodeScope)); public static final Setting.AffixSetting PUBLISH_PORT_PROFILE = affixKeySetting("transport.profiles.", "publish_port", key -> intSetting(key, -1, -1, Setting.Property.NodeScope)); @@ -197,8 +198,9 @@ public abstract class TcpTransport extends AbstractLifecycleComponent i // node id to actual channel protected final ConcurrentMap connectedNodes = newConcurrentMap(); - protected final Map> serverChannels = newConcurrentMap(); protected final ConcurrentMap profileBoundAddresses = newConcurrentMap(); + private final Map> serverChannels = newConcurrentMap(); + private final Set acceptedChannels = Collections.newSetFromMap(new ConcurrentHashMap<>()); protected final KeyedLock connectionLock = new KeyedLock<>(); private final NamedWriteableRegistry namedWriteableRegistry; @@ -347,7 +349,7 @@ protected void innerInnerOnResponse(Channel channel) { @Override protected void innerOnFailure(Exception e) { - if (isOpen(channel)) { + if (channel.isOpen()) { logger.debug( (Supplier) () -> new ParameterizedMessage("[{}] failed to send ping transport message", node), e); failedPings.inc(); @@ -395,29 +397,22 @@ public void onFailure(Exception e) { public final class NodeChannels implements Connection { private final Map typeMapping; - private final Channel[] channels; + private final List channels; private final DiscoveryNode node; private final AtomicBoolean closed = new AtomicBoolean(false); private final Version version; - public NodeChannels(DiscoveryNode node, Channel[] channels, ConnectionProfile connectionProfile) { + NodeChannels(DiscoveryNode node, List channels, ConnectionProfile connectionProfile, Version handshakeVersion) { this.node = node; - this.channels = channels; - assert channels.length == connectionProfile.getNumConnections() : "expected channels size to be == " - + connectionProfile.getNumConnections() + " but was: [" + channels.length + "]"; + this.channels = Collections.unmodifiableList(channels); + assert channels.size() == connectionProfile.getNumConnections() : "expected channels size to be == " + + connectionProfile.getNumConnections() + " but was: [" + channels.size() + "]"; typeMapping = new EnumMap<>(TransportRequestOptions.Type.class); for (ConnectionProfile.ConnectionTypeHandle handle : connectionProfile.getHandles()) { for (TransportRequestOptions.Type type : handle.getTypes()) typeMapping.put(type, handle); } - version = node.getVersion(); - } - - NodeChannels(NodeChannels channels, Version handshakeVersion) { - this.node = channels.node; - this.channels = channels.channels; - this.typeMapping = channels.typeMapping; - this.version = handshakeVersion; + version = handshakeVersion; } @Override @@ -426,7 +421,7 @@ public Version getVersion() { } public List getChannels() { - return Arrays.asList(channels); + return channels; } public Channel channel(TransportRequestOptions.Type type) { @@ -437,12 +432,34 @@ public Channel channel(TransportRequestOptions.Type type) { return connectionTypeHandle.getChannel(channels); } + public boolean allChannelsOpen() { + return channels.stream().allMatch(TcpChannel::isOpen); + } + @Override public void close() throws IOException { if (closed.compareAndSet(false, true)) { try { - closeChannels(Arrays.stream(channels).filter(Objects::nonNull).collect(Collectors.toList()), false, - lifecycle.stopped()); + if (lifecycle.stopped()) { + /* We set SO_LINGER timeout to 0 to ensure that when we shutdown the node we don't + * have a gazillion connections sitting in TIME_WAIT to free up resources quickly. + * This is really the only part where we close the connection from the server side + * otherwise the client (node) initiates the TCP closing sequence which doesn't cause + * these issues. Setting this by default from the beginning can have unexpected + * side-effects an should be avoided, our protocol is designed in a way that clients + * close connection which is how it should be*/ + + channels.forEach(c -> { + try { + c.setSoLinger(0); + } catch (IOException e) { + logger.warn(new ParameterizedMessage("unexpected exception when setting SO_LINGER on channel {}", c), e); + } + }); + } + + boolean block = lifecycle.stopped() && Transports.isTransportThread(Thread.currentThread()) == false; + TcpChannel.closeChannels(channels, block); } finally { transportService.onConnectionClosed(this); } @@ -478,7 +495,7 @@ public boolean nodeConnected(DiscoveryNode node) { public void connectToNode(DiscoveryNode node, ConnectionProfile connectionProfile, CheckedBiConsumer connectionValidator) throws ConnectTransportException { - connectionProfile = resolveConnectionProfile(connectionProfile, defaultConnectionProfile); + connectionProfile = resolveConnectionProfile(connectionProfile); if (node == null) { throw new ConnectTransportException(null, "can't connect to a null node"); } @@ -559,6 +576,10 @@ static ConnectionProfile resolveConnectionProfile(@Nullable ConnectionProfile co } } + protected ConnectionProfile resolveConnectionProfile(ConnectionProfile connectionProfile) { + return resolveConnectionProfile(connectionProfile, defaultConnectionProfile); + } + @Override public final NodeChannels openConnection(DiscoveryNode node, ConnectionProfile connectionProfile) throws IOException { if (node == null) { @@ -566,40 +587,66 @@ public final NodeChannels openConnection(DiscoveryNode node, ConnectionProfile c } boolean success = false; NodeChannels nodeChannels = null; - connectionProfile = resolveConnectionProfile(connectionProfile, defaultConnectionProfile); + connectionProfile = resolveConnectionProfile(connectionProfile); closeLock.readLock().lock(); // ensure we don't open connections while we are closing try { ensureOpen(); try { + int numConnections = connectionProfile.getNumConnections(); + assert numConnections > 0 : "A connection profile must be configured with at least one connection"; + List channels = new ArrayList<>(numConnections); + List> connectionFutures = new ArrayList<>(numConnections); + for (int i = 0; i < numConnections; ++i) { + try { + PlainActionFuture connectFuture = PlainActionFuture.newFuture(); + connectionFutures.add(connectFuture); + Channel channel = initiateChannel(node, connectionProfile.getConnectTimeout(), connectFuture); + channels.add(channel); + } catch (Exception e) { + // If there was an exception when attempting to instantiate the raw channels, we close all of the channels + TcpChannel.closeChannels(channels, false); + throw e; + } + } + + // If we make it past the block above, we successfully instantiated all of the channels + try { + TcpChannel.awaitConnected(node, connectionFutures, connectionProfile.getConnectTimeout()); + } catch (Exception ex) { + TcpChannel.closeChannels(channels, false); + throw ex; + } + + // If we make it past the block above, we have successfully established connections for all of the channels + final Channel handshakeChannel = channels.get(0); // one channel is guaranteed by the connection profile + handshakeChannel.addCloseListener(ActionListener.wrap(() -> cancelHandshakeForChannel(handshakeChannel))); + Version version; + try { + version = executeHandshake(node, handshakeChannel, connectionProfile.getHandshakeTimeout()); + } catch (Exception ex) { + TcpChannel.closeChannels(channels, false); + throw ex; + } + + // If we make it past the block above, we have successfully completed the handshake and the connection is now open. + // At this point we should construct the connection, notify the transport service, and attach close listeners to the + // underlying channels. + nodeChannels = new NodeChannels(node, channels, connectionProfile, version); + transportService.onConnectionOpened(nodeChannels); + final NodeChannels finalNodeChannels = nodeChannels; final AtomicBoolean runOnce = new AtomicBoolean(false); - final AtomicReference connectionRef = new AtomicReference<>(); Consumer onClose = c -> { - assert isOpen(c) == false : "channel is still open when onClose is called"; - try { - onChannelClosed(c); - } finally { - // we only need to disconnect from the nodes once since all other channels - // will also try to run this we protect it from running multiple times. - if (runOnce.compareAndSet(false, true)) { - NodeChannels connection = connectionRef.get(); - if (connection != null) { - disconnectFromNodeCloseAndNotify(node, connection); - } - } + assert c.isOpen() == false : "channel is still open when onClose is called"; + // we only need to disconnect from the nodes once since all other channels + // will also try to run this we protect it from running multiple times. + if (runOnce.compareAndSet(false, true)) { + disconnectFromNodeCloseAndNotify(node, finalNodeChannels); } }; - nodeChannels = connectToChannels(node, connectionProfile, onClose); - final Channel channel = nodeChannels.getChannels().get(0); // one channel is guaranteed by the connection profile - final TimeValue connectTimeout = connectionProfile.getConnectTimeout() == null ? - defaultConnectionProfile.getConnectTimeout() : - connectionProfile.getConnectTimeout(); - final TimeValue handshakeTimeout = connectionProfile.getHandshakeTimeout() == null ? - connectTimeout : connectionProfile.getHandshakeTimeout(); - final Version version = executeHandshake(node, channel, handshakeTimeout); - nodeChannels = new NodeChannels(nodeChannels, version); // clone the channels - we now have the correct version - transportService.onConnectionOpened(nodeChannels); - connectionRef.set(nodeChannels); - if (Arrays.stream(nodeChannels.channels).allMatch(this::isOpen) == false) { + + nodeChannels.channels.forEach(ch -> ch.addCloseListener(ActionListener.wrap(() -> onClose.accept(ch)))); + + if (nodeChannels.allChannelsOpen() == false) { throw new ConnectTransportException(node, "a channel closed while connecting"); } success = true; @@ -637,19 +684,6 @@ private void disconnectFromNodeCloseAndNotify(DiscoveryNode node, NodeChannels n } } - /** - * Disconnects from a node if a channel is found as part of that nodes channels. - */ - protected final void closeChannelWhileHandlingExceptions(final Channel channel) { - if (isOpen(channel)) { - try { - closeChannels(Collections.singletonList(channel), false, false); - } catch (IOException e) { - logger.warn("failed to close channel", e); - } - } - } - @Override public NodeChannels getConnection(DiscoveryNode node) { NodeChannels nodeChannels = connectedNodes.get(node); @@ -904,12 +938,20 @@ protected final void doStop() { try { // first stop to accept any incoming connections so nobody can connect to this transport for (Map.Entry> entry : serverChannels.entrySet()) { - try { - closeChannels(entry.getValue(), true, false); - } catch (Exception e) { - logger.warn(new ParameterizedMessage("Error closing serverChannel for profile [{}]", entry.getKey()), e); - } + String profile = entry.getKey(); + List channels = entry.getValue(); + ActionListener closeFailLogger = ActionListener.wrap(c -> {}, + e -> logger.warn(() -> new ParameterizedMessage("Error closing serverChannel for profile [{}]", profile), e)); + channels.forEach(c -> c.addCloseListener(closeFailLogger)); + TcpChannel.closeChannels(channels, true); } + serverChannels.clear(); + + // close all of the incoming channels. The closeChannels method takes a list so we must convert the set. + TcpChannel.closeChannels(new ArrayList<>(acceptedChannels), true); + acceptedChannels.clear(); + + // we are holding a write lock so nobody modifies the connectedNodes / openConnections map - it's safe to first close // all instances and then clear them maps Iterator> iterator = connectedNodes.entrySet().iterator(); @@ -940,7 +982,7 @@ protected final void doStop() { protected void onException(Channel channel, Exception e) { if (!lifecycle.started()) { // just close and ignore - we are already stopped and just need to make sure we release all resources - closeChannelWhileHandlingExceptions(channel); + TcpChannel.closeChannel(channel, false); return; } @@ -951,15 +993,15 @@ protected void onException(Channel channel, Exception e) { channel), e); // close the channel, which will cause a node to be disconnected if relevant - closeChannelWhileHandlingExceptions(channel); + TcpChannel.closeChannel(channel, false); } else if (isConnectException(e)) { logger.trace((Supplier) () -> new ParameterizedMessage("connect exception caught on transport layer [{}]", channel), e); // close the channel as safe measure, which will cause a node to be disconnected if relevant - closeChannelWhileHandlingExceptions(channel); + TcpChannel.closeChannel(channel, false); } else if (e instanceof BindException) { logger.trace((Supplier) () -> new ParameterizedMessage("bind exception caught on transport layer [{}]", channel), e); // close the channel as safe measure, which will cause a node to be disconnected if relevant - closeChannelWhileHandlingExceptions(channel); + TcpChannel.closeChannel(channel, false); } else if (e instanceof CancelledKeyException) { logger.trace( (Supplier) () -> new ParameterizedMessage( @@ -967,29 +1009,21 @@ protected void onException(Channel channel, Exception e) { channel), e); // close the channel as safe measure, which will cause a node to be disconnected if relevant - closeChannelWhileHandlingExceptions(channel); + TcpChannel.closeChannel(channel, false); } else if (e instanceof TcpTransport.HttpOnTransportException) { // in case we are able to return data, serialize the exception content and sent it back to the client - if (isOpen(channel)) { + if (channel.isOpen()) { BytesArray message = new BytesArray(e.getMessage().getBytes(StandardCharsets.UTF_8)); final SendMetricListener closeChannel = new SendMetricListener(message.length()) { @Override protected void innerInnerOnResponse(Channel channel) { - try { - closeChannels(Collections.singletonList(channel), false, false); - } catch (IOException e1) { - logger.debug("failed to close httpOnTransport channel", e1); - } + TcpChannel.closeChannel(channel, false); } @Override protected void innerOnFailure(Exception e) { - try { - closeChannels(Collections.singletonList(channel), false, false); - } catch (IOException e1) { - e.addSuppressed(e1); - logger.debug("failed to close httpOnTransport channel", e1); - } + logger.debug("failed to send message to httpOnTransport channel", e); + TcpChannel.closeChannel(channel, false); } }; internalSendMessage(channel, message, closeChannel); @@ -998,10 +1032,16 @@ protected void innerOnFailure(Exception e) { logger.warn( (Supplier) () -> new ParameterizedMessage("exception caught on transport layer [{}], closing connection", channel), e); // close the channel, which will cause a node to be disconnected if relevant - closeChannelWhileHandlingExceptions(channel); + TcpChannel.closeChannel(channel, false); } } + protected void serverAcceptedChannel(Channel channel) { + boolean addedOnThisCall = acceptedChannels.add(channel); + assert addedOnThisCall : "Channel should only be added to accept channel set once"; + channel.addCloseListener(ActionListener.wrap(() -> acceptedChannels.remove(channel))); + } + /** * Returns the channels local address */ @@ -1015,44 +1055,34 @@ protected void innerOnFailure(Exception e) { */ protected abstract Channel bind(String name, InetSocketAddress address) throws IOException; - /** - * Closes all channels in this list. If the blocking boolean is set to true, the channels must be - * closed before the method returns. This should never be called with blocking set to true from a - * network thread. - * - * @param channels the channels to close - * @param blocking whether the channels should be closed synchronously - * @param doNotLinger whether we abort the connection on RST instead of FIN - */ - protected abstract void closeChannels(List channels, boolean blocking, boolean doNotLinger) throws IOException; - /** * Sends message to channel. The listener's onResponse method will be called when the send is complete unless an exception * is thrown during the send. If an exception is thrown, the listener's onException method will be called. - * @param channel the destination channel + * + * @param channel the destination channel * @param reference the byte reference for the message - * @param listener the listener to call when the operation has completed + * @param listener the listener to call when the operation has completed */ protected abstract void sendMessage(Channel channel, BytesReference reference, ActionListener listener); /** - * Connect to the node with channels as defined by the specified connection profile. Implementations must invoke the specified channel - * close callback when a channel is closed. + * Initiate a single tcp socket channel to a node. Implementations do not have to observe the connectTimeout. + * It is provided for synchronous connection implementations. * - * @param node the node to connect to - * @param connectionProfile the connection profile - * @param onChannelClose callback to invoke when a channel is closed - * @return the channels - * @throws IOException if an I/O exception occurs while opening channels + * @param node the node + * @param connectTimeout the connection timeout + * @param connectListener listener to be called when connection complete + * @return the pending connection + * @throws IOException if an I/O exception occurs while opening the channel */ - protected abstract NodeChannels connectToChannels(DiscoveryNode node, - ConnectionProfile connectionProfile, - Consumer onChannelClose) throws IOException; + protected abstract Channel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener connectListener) + throws IOException; /** * Called to tear down internal resources */ - protected void stopInternal() {} + protected void stopInternal() { + } public boolean canCompress(TransportRequest request) { return compress && (!(request instanceof BytesTransportRequest)); @@ -1118,10 +1148,10 @@ private void internalSendMessage(Channel targetChannel, BytesReference message, * Sends back an error response to the caller via the given channel * * @param nodeVersion the caller node version - * @param channel the channel to send the response to - * @param error the error to return - * @param requestId the request ID this response replies to - * @param action the action this response replies to + * @param channel the channel to send the response to + * @param error the error to return + * @param requestId the request ID this response replies to + * @param action the action this response replies to */ public void sendErrorResponse(Version nodeVersion, Channel channel, final Exception error, final long requestId, final String action) throws IOException { @@ -1146,7 +1176,7 @@ public void sendErrorResponse(Version nodeVersion, Channel channel, final Except /** * Sends the response to the given channel. This method should be used to send {@link TransportResponse} objects back to the caller. * - * @see #sendErrorResponse(Version, Object, Exception, long, String) for sending back errors to the caller + * @see #sendErrorResponse(Version, TcpChannel, Exception, long, String) for sending back errors to the caller */ public void sendResponse(Version nodeVersion, Channel channel, final TransportResponse response, final long requestId, final String action, TransportResponseOptions options) throws IOException { @@ -1154,7 +1184,7 @@ public void sendResponse(Version nodeVersion, Channel channel, final TransportRe } private void sendResponse(Version nodeVersion, Channel channel, final TransportResponse response, final long requestId, - final String action, TransportResponseOptions options, byte status) throws IOException { + final String action, TransportResponseOptions options, byte status) throws IOException { if (compress) { options = TransportResponseOptions.builder(options).withCompress(true).build(); } @@ -1232,10 +1262,10 @@ private BytesReference buildMessage(long requestId, byte status, Version nodeVer * Validates the first N bytes of the message header and returns false if the message is * a ping message and has no payload ie. isn't a real user level message. * - * @throws IllegalStateException if the message is too short, less than the header or less that the header plus the message size + * @throws IllegalStateException if the message is too short, less than the header or less that the header plus the message size * @throws HttpOnTransportException if the message has no valid header and appears to be a HTTP message * @throws IllegalArgumentException if the message is greater that the maximum allowed frame size. This is dependent on the available - * memory. + * memory. */ public static boolean validateMessageHeader(BytesReference buffer) throws IOException { final int sizeHeaderLength = TcpHeader.MARKER_BYTES_SIZE + TcpHeader.MESSAGE_LENGTH_SIZE; @@ -1246,23 +1276,23 @@ public static boolean validateMessageHeader(BytesReference buffer) throws IOExce if (buffer.get(offset) != 'E' || buffer.get(offset + 1) != 'S') { // special handling for what is probably HTTP if (bufferStartsWith(buffer, offset, "GET ") || - bufferStartsWith(buffer, offset, "POST ") || - bufferStartsWith(buffer, offset, "PUT ") || - bufferStartsWith(buffer, offset, "HEAD ") || - bufferStartsWith(buffer, offset, "DELETE ") || - bufferStartsWith(buffer, offset, "OPTIONS ") || - bufferStartsWith(buffer, offset, "PATCH ") || - bufferStartsWith(buffer, offset, "TRACE ")) { + bufferStartsWith(buffer, offset, "POST ") || + bufferStartsWith(buffer, offset, "PUT ") || + bufferStartsWith(buffer, offset, "HEAD ") || + bufferStartsWith(buffer, offset, "DELETE ") || + bufferStartsWith(buffer, offset, "OPTIONS ") || + bufferStartsWith(buffer, offset, "PATCH ") || + bufferStartsWith(buffer, offset, "TRACE ")) { throw new HttpOnTransportException("This is not a HTTP port"); } // we have 6 readable bytes, show 4 (should be enough) throw new StreamCorruptedException("invalid internal transport message format, got (" - + Integer.toHexString(buffer.get(offset) & 0xFF) + "," - + Integer.toHexString(buffer.get(offset + 1) & 0xFF) + "," - + Integer.toHexString(buffer.get(offset + 2) & 0xFF) + "," - + Integer.toHexString(buffer.get(offset + 3) & 0xFF) + ")"); + + Integer.toHexString(buffer.get(offset) & 0xFF) + "," + + Integer.toHexString(buffer.get(offset + 1) & 0xFF) + "," + + Integer.toHexString(buffer.get(offset + 2) & 0xFF) + "," + + Integer.toHexString(buffer.get(offset + 3) & 0xFF) + ")"); } final int dataLen; @@ -1322,8 +1352,6 @@ public HttpOnTransportException(StreamInput in) throws IOException { } } - protected abstract boolean isOpen(Channel channel); - /** * This method handles the message receive part for both request and responses */ @@ -1410,7 +1438,7 @@ static void ensureVersionCompatibility(Version version, Version currentVersion, final Version compatibilityVersion = isHandshake ? currentVersion.minimumCompatibilityVersion() : currentVersion; if (version.isCompatible(compatibilityVersion) == false) { final Version minCompatibilityVersion = isHandshake ? compatibilityVersion : compatibilityVersion.minimumCompatibilityVersion(); - String msg = "Received " + (isHandshake? "handshake " : "") + "message from unsupported version: ["; + String msg = "Received " + (isHandshake ? "handshake " : "") + "message from unsupported version: ["; throw new IllegalStateException(msg + version + "] minimal compatible version is: [" + minCompatibilityVersion + "]"); } } @@ -1566,7 +1594,8 @@ private VersionHandshakeResponse(Version version) { this.version = version; } - private VersionHandshakeResponse() {} + private VersionHandshakeResponse() { + } @Override public void readFrom(StreamInput in) throws IOException { @@ -1591,7 +1620,7 @@ protected Version executeHandshake(DiscoveryNode node, Channel channel, TimeValu pendingHandshakes.put(requestId, handler); boolean success = false; try { - if (isOpen(channel) == false) { + if (channel.isOpen() == false) { // we have to protect us here since sendRequestToChannel won't barf if the channel is closed. // it's weird but to change it will cause a lot of impact on the exception handling code all over the codebase. // yet, if we don't check the state here we might have registered a pending handshake handler but the close @@ -1642,9 +1671,9 @@ public long newRequestId() { /** * Called once the channel is closed for instance due to a disconnect or a closed socket etc. */ - private void onChannelClosed(Channel channel) { + private void cancelHandshakeForChannel(Channel channel) { final Optional first = pendingHandshakes.entrySet().stream() - .filter((entry) -> entry.getValue().channel == channel).map((e) -> e.getKey()).findFirst(); + .filter((entry) -> entry.getValue().channel == channel).map(Map.Entry::getKey).findFirst(); if (first.isPresent()) { final Long requestId = first.get(); final HandshakeResponseHandler handler = pendingHandshakes.remove(requestId); @@ -1778,5 +1807,4 @@ public ProfileSettings(Settings settings, String profileName) { PUBLISH_PORT_PROFILE.getConcreteSettingForNamespace(profileName).get(settings); } } - } diff --git a/core/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java b/core/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java index ae9fa22f70e32..3267548e91434 100644 --- a/core/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java +++ b/core/src/main/java/org/elasticsearch/transport/TcpTransportChannel.java @@ -23,8 +23,7 @@ import java.io.IOException; import java.util.concurrent.atomic.AtomicBoolean; -public final class TcpTransportChannel implements TransportChannel { - +public final class TcpTransportChannel implements TransportChannel { private final TcpTransport transport; private final Version version; private final String action; diff --git a/core/src/test/java/org/elasticsearch/transport/ConnectionProfileTests.java b/core/src/test/java/org/elasticsearch/transport/ConnectionProfileTests.java index b18b57e371782..c4a7ca5bee190 100644 --- a/core/src/test/java/org/elasticsearch/transport/ConnectionProfileTests.java +++ b/core/src/test/java/org/elasticsearch/transport/ConnectionProfileTests.java @@ -22,7 +22,11 @@ import org.elasticsearch.test.ESTestCase; import org.hamcrest.Matchers; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; import java.util.EnumSet; +import java.util.List; public class ConnectionProfileTests extends ESTestCase { @@ -65,16 +69,16 @@ public void testBuildConnectionProfile() { assertNull(build.getHandshakeTimeout()); } - Integer[] array = new Integer[10]; - for (int i = 0; i < array.length; i++) { - array[i] = i; + List list = new ArrayList<>(10); + for (int i = 0; i < 10; i++) { + list.add(i); } final int numIters = randomIntBetween(5, 10); assertEquals(4, build.getHandles().size()); assertEquals(0, build.getHandles().get(0).offset); assertEquals(1, build.getHandles().get(0).length); assertEquals(EnumSet.of(TransportRequestOptions.Type.BULK), build.getHandles().get(0).getTypes()); - Integer channel = build.getHandles().get(0).getChannel(array); + Integer channel = build.getHandles().get(0).getChannel(list); for (int i = 0; i < numIters; i++) { assertEquals(0, channel.intValue()); } @@ -83,7 +87,7 @@ public void testBuildConnectionProfile() { assertEquals(2, build.getHandles().get(1).length); assertEquals(EnumSet.of(TransportRequestOptions.Type.STATE, TransportRequestOptions.Type.RECOVERY), build.getHandles().get(1).getTypes()); - channel = build.getHandles().get(1).getChannel(array); + channel = build.getHandles().get(1).getChannel(list); for (int i = 0; i < numIters; i++) { assertThat(channel, Matchers.anyOf(Matchers.is(1), Matchers.is(2))); } @@ -91,7 +95,7 @@ public void testBuildConnectionProfile() { assertEquals(3, build.getHandles().get(2).offset); assertEquals(3, build.getHandles().get(2).length); assertEquals(EnumSet.of(TransportRequestOptions.Type.PING), build.getHandles().get(2).getTypes()); - channel = build.getHandles().get(2).getChannel(array); + channel = build.getHandles().get(2).getChannel(list); for (int i = 0; i < numIters; i++) { assertThat(channel, Matchers.anyOf(Matchers.is(3), Matchers.is(4), Matchers.is(5))); } @@ -99,7 +103,7 @@ public void testBuildConnectionProfile() { assertEquals(6, build.getHandles().get(3).offset); assertEquals(4, build.getHandles().get(3).length); assertEquals(EnumSet.of(TransportRequestOptions.Type.REG), build.getHandles().get(3).getTypes()); - channel = build.getHandles().get(3).getChannel(array); + channel = build.getHandles().get(3).getChannel(list); for (int i = 0; i < numIters; i++) { assertThat(channel, Matchers.anyOf(Matchers.is(6), Matchers.is(7), Matchers.is(8), Matchers.is(9))); } @@ -119,7 +123,7 @@ public void testNoChannels() { TransportRequestOptions.Type.REG); builder.addConnections(0, TransportRequestOptions.Type.PING); ConnectionProfile build = builder.build(); - Integer[] array = new Integer[]{Integer.valueOf(0)}; + List array = Collections.singletonList(0); assertEquals(Integer.valueOf(0), build.getHandles().get(0).getChannel(array)); expectThrows(IllegalStateException.class, () -> build.getHandles().get(1).getChannel(array)); } diff --git a/core/src/test/java/org/elasticsearch/transport/TcpTransportTests.java b/core/src/test/java/org/elasticsearch/transport/TcpTransportTests.java index 54efd231182b6..19ada600cc105 100644 --- a/core/src/test/java/org/elasticsearch/transport/TcpTransportTests.java +++ b/core/src/test/java/org/elasticsearch/transport/TcpTransportTests.java @@ -37,11 +37,10 @@ import java.io.IOException; import java.net.InetSocketAddress; -import java.util.List; +import java.util.ArrayList; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Consumer; import static org.hamcrest.Matchers.equalTo; @@ -178,25 +177,21 @@ public void testCompressRequest() throws IOException { ThreadPool threadPool = new TestThreadPool(TcpTransportTests.class.getName()); AtomicReference exceptionReference = new AtomicReference<>(); try { - TcpTransport transport = new TcpTransport("test", Settings.builder().put("transport.tcp.compress", compressed).build(), - threadPool, new BigArrays(Settings.EMPTY, null), null, null, null) { + TcpTransport transport = new TcpTransport( + "test", Settings.builder().put("transport.tcp.compress", compressed).build(), threadPool, + new BigArrays(Settings.EMPTY, null), null, null, null) { @Override - protected InetSocketAddress getLocalAddress(Object o) { + protected InetSocketAddress getLocalAddress(FakeChannel o) { return null; } @Override - protected Object bind(String name, InetSocketAddress address) throws IOException { + protected FakeChannel bind(String name, InetSocketAddress address) throws IOException { return null; } @Override - protected void closeChannels(List channel, boolean blocking, boolean doNotLinger) throws IOException { - - } - - @Override - protected void sendMessage(Object o, BytesReference reference, ActionListener listener) { + protected void sendMessage(FakeChannel o, BytesReference reference, ActionListener listener) { try { StreamInput streamIn = reference.streamInput(); streamIn.skip(TcpHeader.MARKER_BYTES_SIZE); @@ -224,14 +219,10 @@ protected void sendMessage(Object o, BytesReference reference, ActionListener li } @Override - protected NodeChannels connectToChannels( - DiscoveryNode node, ConnectionProfile profile, Consumer onChannelClose) throws IOException { - return new NodeChannels(node, new Object[profile.getNumConnections()], profile); - } - - @Override - protected boolean isOpen(Object o) { - return false; + protected FakeChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, + ActionListener connectListener) throws IOException { + FakeChannel fakeChannel = new FakeChannel(); + return fakeChannel; } @Override @@ -241,8 +232,12 @@ public long getNumOpenServerConnections() { @Override public NodeChannels getConnection(DiscoveryNode node) { - return new NodeChannels(node, new Object[MockTcpTransport.LIGHT_PROFILE.getNumConnections()], - MockTcpTransport.LIGHT_PROFILE); + int numConnections = MockTcpTransport.LIGHT_PROFILE.getNumConnections(); + ArrayList fakeChannels = new ArrayList<>(numConnections); + for (int i = 0; i < numConnections; ++i) { + fakeChannels.add(new FakeChannel()); + } + return new NodeChannels(node, fakeChannels, MockTcpTransport.LIGHT_PROFILE, Version.CURRENT); } }; DiscoveryNode node = new DiscoveryNode("foo", buildNewFakeTransportAddress(), Version.CURRENT); @@ -255,6 +250,26 @@ public NodeChannels getConnection(DiscoveryNode node) { } } + private static final class FakeChannel implements TcpChannel { + + @Override + public void close() { + } + + @Override + public void addCloseListener(ActionListener listener) { + } + + @Override + public void setSoLinger(int value) throws IOException { + } + + @Override + public boolean isOpen() { + return false; + } + } + private static final class Req extends TransportRequest { public String value; diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4MessageChannelHandler.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4MessageChannelHandler.java index 9763a5116b163..9e59ba0908d0b 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4MessageChannelHandler.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4MessageChannelHandler.java @@ -20,8 +20,10 @@ package org.elasticsearch.transport.netty4; import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelHandlerContext; +import io.netty.util.Attribute; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.transport.TcpHeader; import org.elasticsearch.transport.Transports; @@ -53,11 +55,13 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception final int remainingMessageSize = buffer.getInt(buffer.readerIndex() - TcpHeader.MESSAGE_LENGTH_SIZE); final int expectedReaderIndex = buffer.readerIndex() + remainingMessageSize; try { - InetSocketAddress remoteAddress = (InetSocketAddress) ctx.channel().remoteAddress(); + Channel channel = ctx.channel(); + InetSocketAddress remoteAddress = (InetSocketAddress) channel.remoteAddress(); // netty always copies a buffer, either in NioWorker in its read handler, where it copies to a fresh // buffer, or in the cumulative buffer, which is cleaned each time so it could be bigger than the actual size BytesReference reference = Netty4Utils.toBytesReference(buffer, remainingMessageSize); - transport.messageReceived(reference, ctx.channel(), profileName, remoteAddress, remainingMessageSize); + Attribute channelAttribute = channel.attr(Netty4Transport.CHANNEL_KEY); + transport.messageReceived(reference, channelAttribute.get(), profileName, remoteAddress, remainingMessageSize); } finally { // Set the expected position of the buffer, no matter what happened buffer.readerIndex(expectedReaderIndex); diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java index 11e5d2f44a81a..9cdefc292f22f 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java @@ -24,7 +24,6 @@ import io.netty.channel.AdaptiveRecvByteBufAllocator; import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelInitializer; @@ -34,6 +33,7 @@ import io.netty.channel.nio.NioEventLoopGroup; import io.netty.channel.socket.nio.NioServerSocketChannel; import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.util.AttributeKey; import io.netty.util.concurrent.Future; import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.logging.log4j.util.Supplier; @@ -55,24 +55,18 @@ import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.concurrent.EsExecutors; -import org.elasticsearch.common.util.concurrent.FutureUtils; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.transport.ConnectTransportException; -import org.elasticsearch.transport.ConnectionProfile; import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.TransportRequestOptions; import java.io.IOException; import java.net.InetSocketAddress; import java.util.ArrayList; -import java.util.Collections; -import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.concurrent.ThreadFactory; import java.util.concurrent.TimeUnit; -import java.util.function.Consumer; import static org.elasticsearch.common.settings.Setting.byteSizeSetting; import static org.elasticsearch.common.settings.Setting.intSetting; @@ -85,7 +79,7 @@ * longer. Med is for the typical search / single doc index. And High for things like cluster state. Ping is reserved for * sending out ping requests to other nodes. */ -public class Netty4Transport extends TcpTransport { +public class Netty4Transport extends TcpTransport { static { Netty4Utils.setup(); @@ -97,7 +91,7 @@ public class Netty4Transport extends TcpTransport { (s) -> Setting.parseInt(s, 1, "transport.netty.worker_count"), Property.NodeScope); public static final Setting NETTY_RECEIVE_PREDICTOR_SIZE = Setting.byteSizeSetting( - "transport.netty.receive_predictor_size", new ByteSizeValue(64, ByteSizeUnit.KB), Property.NodeScope); + "transport.netty.receive_predictor_size", new ByteSizeValue(64, ByteSizeUnit.KB), Property.NodeScope); public static final Setting NETTY_RECEIVE_PREDICTOR_MIN = byteSizeSetting("transport.netty.receive_predictor_min", NETTY_RECEIVE_PREDICTOR_SIZE, Property.NodeScope); public static final Setting NETTY_RECEIVE_PREDICTOR_MAX = @@ -116,7 +110,7 @@ public class Netty4Transport extends TcpTransport { protected final Map serverBootstraps = newConcurrentMap(); public Netty4Transport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays, - NamedWriteableRegistry namedWriteableRegistry, CircuitBreakerService circuitBreakerService) { + NamedWriteableRegistry namedWriteableRegistry, CircuitBreakerService circuitBreakerService) { super("netty", settings, threadPool, bigArrays, circuitBreakerService, namedWriteableRegistry, networkService); Netty4Utils.setAvailableProcessors(EsExecutors.PROCESSORS_SETTING.get(settings)); this.workerCount = WORKER_COUNT.get(settings); @@ -239,10 +233,13 @@ protected ChannelHandler getClientChannelInitializer() { return new ClientChannelInitializer(); } + static final AttributeKey CHANNEL_KEY = AttributeKey.newInstance("es-channel"); + protected final void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { final Throwable unwrapped = ExceptionsHelper.unwrap(cause, ElasticsearchException.class); final Throwable t = unwrapped != null ? unwrapped : cause; - onException(ctx.channel(), t instanceof Exception ? (Exception) t : new ElasticsearchException(t)); + Channel channel = ctx.channel(); + onException(channel.attr(CHANNEL_KEY).get(), t instanceof Exception ? (Exception) t : new ElasticsearchException(t)); } @Override @@ -252,70 +249,39 @@ public long getNumOpenServerConnections() { } @Override - protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile, Consumer onChannelClose) { - final Channel[] channels = new Channel[profile.getNumConnections()]; - final NodeChannels nodeChannels = new NodeChannels(node, channels, profile); - boolean success = false; - try { - final TimeValue connectTimeout; - final Bootstrap bootstrap; - final TimeValue defaultConnectTimeout = defaultConnectionProfile.getConnectTimeout(); - if (profile.getConnectTimeout() != null && profile.getConnectTimeout().equals(defaultConnectTimeout) == false) { - bootstrap = this.bootstrap.clone(this.bootstrap.config().group()); - bootstrap.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, Math.toIntExact(profile.getConnectTimeout().millis())); - connectTimeout = profile.getConnectTimeout(); + protected NettyTcpChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener listener) + throws IOException { + ChannelFuture channelFuture = bootstrap.connect(node.getAddress().address()); + Channel channel = channelFuture.channel(); + if (channel == null) { + Netty4Utils.maybeDie(channelFuture.cause()); + throw new IOException(channelFuture.cause()); + } + addClosedExceptionLogger(channel); + + NettyTcpChannel nettyChannel = new NettyTcpChannel(channel); + channel.attr(CHANNEL_KEY).set(nettyChannel); + + channelFuture.addListener(f -> { + if (f.isSuccess()) { + listener.onResponse(nettyChannel); } else { - connectTimeout = defaultConnectTimeout; - bootstrap = this.bootstrap; - } - final ArrayList connections = new ArrayList<>(channels.length); - final InetSocketAddress address = node.getAddress().address(); - for (int i = 0; i < channels.length; i++) { - connections.add(bootstrap.connect(address)); - } - final Iterator iterator = connections.iterator(); - final ChannelFutureListener closeListener = future -> onChannelClose.accept(future.channel()); - try { - for (int i = 0; i < channels.length; i++) { - assert iterator.hasNext(); - ChannelFuture future = iterator.next(); - future.awaitUninterruptibly((long) (connectTimeout.millis() * 1.5)); - if (!future.isSuccess()) { - throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]", future.cause()); - } - channels[i] = future.channel(); - channels[i].closeFuture().addListener(closeListener); + Throwable cause = f.cause(); + if (cause instanceof Error) { + Netty4Utils.maybeDie(cause); + listener.onFailure(new Exception(cause)); + } else { + listener.onFailure((Exception) cause); } - assert iterator.hasNext() == false : "not all created connection have been consumed"; - } catch (final RuntimeException e) { - for (final ChannelFuture future : Collections.unmodifiableList(connections)) { - FutureUtils.cancel(future); - if (future.channel() != null && future.channel().isOpen()) { - try { - future.channel().close(); - } catch (Exception inner) { - e.addSuppressed(inner); - } - } - } - throw e; } - success = true; - } finally { - if (success == false) { - try { - nodeChannels.close(); - } catch (IOException e) { - logger.trace("exception while closing channels", e); - } - } - } - return nodeChannels; + }); + + return nettyChannel; } @Override - protected void sendMessage(Channel channel, BytesReference reference, ActionListener listener) { - final ChannelFuture future = channel.writeAndFlush(Netty4Utils.toByteBuf(reference)); + protected void sendMessage(NettyTcpChannel channel, BytesReference reference, ActionListener listener) { + final ChannelFuture future = channel.getLowLevelChannel().writeAndFlush(Netty4Utils.toByteBuf(reference)); future.addListener(f -> { if (f.isSuccess()) { listener.onResponse(channel); @@ -331,54 +297,22 @@ protected void sendMessage(Channel channel, BytesReference reference, ActionList } @Override - protected void closeChannels(final List channels, boolean blocking, boolean doNotLinger) throws IOException { - if (doNotLinger) { - for (Channel channel : channels) { - /* We set SO_LINGER timeout to 0 to ensure that when we shutdown the node we don't have a gazillion connections sitting - * in TIME_WAIT to free up resources quickly. This is really the only part where we close the connection from the server - * side otherwise the client (node) initiates the TCP closing sequence which doesn't cause these issues. Setting this - * by default from the beginning can have unexpected side-effects an should be avoided, our protocol is designed - * in a way that clients close connection which is how it should be*/ - if (channel.isOpen()) { - channel.config().setOption(ChannelOption.SO_LINGER, 0); - } - } - } - if (blocking) { - Netty4Utils.closeChannels(channels); - } else { - for (Channel channel : channels) { - if (channel != null && channel.isOpen()) { - ChannelFuture closeFuture = channel.close(); - closeFuture.addListener((f) -> { - if (f.isSuccess() == false) { - logger.warn("failed to close channel", f.cause()); - } - }); - } - } - } - } - - @Override - protected InetSocketAddress getLocalAddress(Channel channel) { - return (InetSocketAddress) channel.localAddress(); + protected InetSocketAddress getLocalAddress(NettyTcpChannel channel) { + return (InetSocketAddress) channel.getLowLevelChannel().localAddress(); } @Override - protected Channel bind(String name, InetSocketAddress address) { - return serverBootstraps.get(name).bind(address).syncUninterruptibly().channel(); + protected NettyTcpChannel bind(String name, InetSocketAddress address) { + Channel channel = serverBootstraps.get(name).bind(address).syncUninterruptibly().channel(); + NettyTcpChannel esChannel = new NettyTcpChannel(channel); + channel.attr(CHANNEL_KEY).set(esChannel); + return esChannel; } ScheduledPing getPing() { return scheduledPing; } - @Override - protected boolean isOpen(Channel channel) { - return channel.isOpen(); - } - @Override @SuppressForbidden(reason = "debug") protected void stopInternal() { @@ -420,7 +354,6 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E Netty4Utils.maybeDie(cause); super.exceptionCaught(ctx, cause); } - } protected class ServerChannelInitializer extends ChannelInitializer { @@ -433,6 +366,10 @@ protected ServerChannelInitializer(String name) { @Override protected void initChannel(Channel ch) throws Exception { + addClosedExceptionLogger(ch); + NettyTcpChannel nettyTcpChannel = new NettyTcpChannel(ch); + ch.attr(CHANNEL_KEY).set(nettyTcpChannel); + serverAcceptedChannel(nettyTcpChannel); ch.pipeline().addLast("logging", new ESLoggingHandler()); ch.pipeline().addLast("open_channels", Netty4Transport.this.serverOpenChannels); ch.pipeline().addLast("size", new Netty4SizeHeaderFrameDecoder()); @@ -444,7 +381,13 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E Netty4Utils.maybeDie(cause); super.exceptionCaught(ctx, cause); } - } + private void addClosedExceptionLogger(Channel channel) { + channel.closeFuture().addListener(f -> { + if (f.isSuccess() == false) { + logger.debug(() -> new ParameterizedMessage("exception while closing channel: {}", channel), f.cause()); + } + }); + } } diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/NettyTcpChannel.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/NettyTcpChannel.java new file mode 100644 index 0000000000000..c18c3c4fe1f11 --- /dev/null +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/NettyTcpChannel.java @@ -0,0 +1,74 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.transport.netty4; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelOption; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.transport.TcpChannel; + +import java.util.concurrent.CompletableFuture; + +public class NettyTcpChannel implements TcpChannel { + + private final Channel channel; + private final CompletableFuture closeContext = new CompletableFuture<>(); + + NettyTcpChannel(Channel channel) { + this.channel = channel; + this.channel.closeFuture().addListener(f -> { + if (f.isSuccess()) { + closeContext.complete(this); + } else { + Throwable cause = f.cause(); + if (cause instanceof Error) { + Netty4Utils.maybeDie(cause); + closeContext.completeExceptionally(cause); + } else { + closeContext.completeExceptionally(cause); + } + } + }); + } + + public Channel getLowLevelChannel() { + return channel; + } + + @Override + public void close() { + channel.close(); + } + + @Override + public void addCloseListener(ActionListener listener) { + closeContext.whenComplete(ActionListener.toBiConsumer(listener)); + } + + @Override + public void setSoLinger(int value) { + channel.config().setOption(ChannelOption.SO_LINGER, value); + } + + @Override + public boolean isOpen() { + return channel.isOpen(); + } +} diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/ByteBufBytesReferenceTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/ByteBufBytesReferenceTests.java index afe6bbbc90f14..7a6768010eb16 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/ByteBufBytesReferenceTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/ByteBufBytesReferenceTests.java @@ -25,7 +25,9 @@ import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; +import javax.net.ssl.SSLEngine; import java.io.IOException; +import java.nio.ByteBuffer; public class ByteBufBytesReferenceTests extends AbstractBytesReferenceTestCase { @@ -81,5 +83,4 @@ public void testImmutable() throws IOException { channelBuffer.readInt(); // this advances the index of the channel buffer assertEquals(utf8ToString, byteBufBytesReference.utf8ToString()); } - } diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java index 3537d5fbbe578..3eb5adc8d067d 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/Netty4TransportIT.java @@ -18,7 +18,6 @@ */ package org.elasticsearch.transport.netty4; -import io.netty.channel.Channel; import org.elasticsearch.ESNetty4IntegTestCase; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.Version; @@ -109,7 +108,7 @@ public ExceptionThrowingNetty4Transport( super(settings, threadPool, networkService, bigArrays, namedWriteableRegistry, circuitBreakerService); } - protected String handleRequest(Channel channel, String profileName, + protected String handleRequest(NettyTcpChannel channel, String profileName, StreamInput stream, long requestId, int messageLengthBytes, Version version, InetSocketAddress remoteAddress, byte status) throws IOException { String action = super.handleRequest(channel, profileName, stream, requestId, messageLengthBytes, version, diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/SimpleNetty4TransportTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/SimpleNetty4TransportTests.java index bdf4adb5ea91c..47259a7c613eb 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/SimpleNetty4TransportTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/transport/netty4/SimpleNetty4TransportTests.java @@ -19,7 +19,6 @@ package org.elasticsearch.transport.netty4; -import io.netty.channel.Channel; import org.elasticsearch.Version; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; @@ -36,6 +35,7 @@ import org.elasticsearch.transport.AbstractSimpleTransportTestCase; import org.elasticsearch.transport.BindTransportException; import org.elasticsearch.transport.ConnectTransportException; +import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.TransportService; @@ -58,7 +58,7 @@ public static MockTransportService nettyFromThreadPool(Settings settings, Thread BigArrays.NON_RECYCLING_INSTANCE, namedWriteableRegistry, new NoneCircuitBreakerService()) { @Override - protected Version executeHandshake(DiscoveryNode node, Channel channel, TimeValue timeout) throws IOException, + protected Version executeHandshake(DiscoveryNode node, NettyTcpChannel channel, TimeValue timeout) throws IOException, InterruptedException { if (doHandshake) { return super.executeHandshake(node, channel, timeout); @@ -89,8 +89,9 @@ protected MockTransportService build(Settings settings, Version version, Cluster @Override protected void closeConnectionChannel(Transport transport, Transport.Connection connection) throws IOException { final Netty4Transport t = (Netty4Transport) transport; - @SuppressWarnings("unchecked") final TcpTransport.NodeChannels channels = (TcpTransport.NodeChannels) connection; - t.closeChannels(channels.getChannels().subList(0, randomIntBetween(1, channels.getChannels().size())), true, false); + @SuppressWarnings("unchecked") + final TcpTransport.NodeChannels channels = (TcpTransport.NodeChannels) connection; + TcpChannel.closeChannels(channels.getChannels().subList(0, randomIntBetween(1, channels.getChannels().size())), true); } public void testConnectException() throws UnknownHostException { @@ -99,7 +100,7 @@ public void testConnectException() throws UnknownHostException { emptyMap(), emptySet(),Version.CURRENT)); fail("Expected ConnectTransportException"); } catch (ConnectTransportException e) { - assertThat(e.getMessage(), containsString("connect_timeout")); + assertThat(e.getMessage(), containsString("connect_exception")); assertThat(e.getMessage(), containsString("[127.0.0.1:9876]")); } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java index 6d5b94dd67a05..4b1da5c212621 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/MockTcpTransport.java @@ -52,8 +52,8 @@ import java.net.SocketTimeoutException; import java.util.Collections; import java.util.HashSet; -import java.util.List; import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executor; @@ -176,37 +176,38 @@ private void readMessage(MockChannel mockChannel, StreamInput input) throws IOEx } @Override - protected NodeChannels connectToChannels(DiscoveryNode node, - ConnectionProfile profile, - Consumer onChannelClose) throws IOException { - final MockChannel[] mockChannels = new MockChannel[1]; - final NodeChannels nodeChannels = new NodeChannels(node, mockChannels, LIGHT_PROFILE); // we always use light here - boolean success = false; + protected MockChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener connectListener) + throws IOException { + InetSocketAddress address = node.getAddress().address(); final MockSocket socket = new MockSocket(); + boolean success = false; try { - final InetSocketAddress address = node.getAddress().address(); - // we just use a single connections configureSocket(socket); - final TimeValue connectTimeout = profile.getConnectTimeout(); try { socket.connect(address, Math.toIntExact(connectTimeout.millis())); } catch (SocketTimeoutException ex) { throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]", ex); } - MockChannel channel = new MockChannel(socket, address, "none", onChannelClose); + MockChannel channel = new MockChannel(socket, address, "none", (c) -> {}); channel.loopRead(executor); - mockChannels[0] = channel; success = true; + connectListener.onResponse(channel); + return channel; } finally { if (success == false) { - IOUtils.close(nodeChannels, socket); + IOUtils.close(socket); } } - - return nodeChannels; } - + @Override + protected ConnectionProfile resolveConnectionProfile(ConnectionProfile connectionProfile) { + ConnectionProfile connectionProfile1 = resolveConnectionProfile(connectionProfile, defaultConnectionProfile); + ConnectionProfile.Builder builder = new ConnectionProfile.Builder(LIGHT_PROFILE); + builder.setHandshakeTimeout(connectionProfile1.getHandshakeTimeout()); + builder.setConnectTimeout(connectionProfile1.getConnectTimeout()); + return builder.build(); + } private void configureSocket(Socket socket) throws SocketException { socket.setTcpNoDelay(TCP_NO_DELAY.get(settings)); @@ -221,11 +222,6 @@ private void configureSocket(Socket socket) throws SocketException { socket.setReuseAddress(TCP_REUSE_ADDRESS.get(settings)); } - @Override - protected boolean isOpen(MockChannel mockChannel) { - return mockChannel.isOpen.get(); - } - @Override protected void sendMessage(MockChannel mockChannel, BytesReference reference, ActionListener listener) { try { @@ -242,31 +238,12 @@ protected void sendMessage(MockChannel mockChannel, BytesReference reference, Ac } } - @Override - protected void closeChannels(List channels, boolean blocking, boolean doNotLinger) throws IOException { - if (doNotLinger) { - for (MockChannel channel : channels) { - if (channel.activeChannel != null) { - /* We set SO_LINGER timeout to 0 to ensure that when we shutdown the node we don't have a gazillion connections sitting - * in TIME_WAIT to free up resources quickly. This is really the only part where we close the connection from the server - * side otherwise the client (node) initiates the TCP closing sequence which doesn't cause these issues. Setting this - * by default from the beginning can have unexpected side-effects an should be avoided, our protocol is designed - * in a way that clients close connection which is how it should be*/ - if (channel.isOpen.get()) { - channel.activeChannel.setSoLinger(true, 0); - } - } - } - } - IOUtils.close(channels); - } - @Override public long getNumOpenServerConnections() { return 1; } - public final class MockChannel implements Closeable { + public final class MockChannel implements Closeable, TcpChannel { private final AtomicBoolean isOpen = new AtomicBoolean(true); private final InetSocketAddress localAddress; private final ServerSocket serverSocket; @@ -275,6 +252,7 @@ public final class MockChannel implements Closeable { private final String profile; private final CancellableThreads cancellableThreads = new CancellableThreads(); private final Closeable onClose; + private final CompletableFuture closeFuture = new CompletableFuture<>(); /** * Constructs a new MockChannel instance intended for handling the actual incoming / outgoing traffic. @@ -323,6 +301,7 @@ public void accept(Executor executor) throws IOException { incomingChannel = new MockChannel(incomingSocket, new InetSocketAddress(incomingSocket.getLocalAddress(), incomingSocket.getPort()), profile, workerChannels::remove); + serverAcceptedChannel(incomingChannel); //establish a happens-before edge between closing and accepting a new connection workerChannels.add(incomingChannel); @@ -366,8 +345,7 @@ protected void doRun() throws Exception { }); } - @Override - public synchronized void close() throws IOException { + public synchronized void close0() throws IOException { // establish a happens-before edge between closing and accepting a new connection // we have to sync this entire block to ensure that our openChannels checks work correctly. // The close block below will close all worker channels but if one of the worker channels runs into an exception @@ -394,6 +372,35 @@ public String toString() { ", isServerSocket=" + (serverSocket != null) + '}'; } + + @Override + public void close() { + try { + close0(); + closeFuture.complete(this); + } catch (IOException e) { + closeFuture.completeExceptionally(e); + } + } + + @Override + public void addCloseListener(ActionListener listener) { + closeFuture.whenComplete(ActionListener.toBiConsumer(listener)); + } + + @Override + public void setSoLinger(int value) throws IOException { + if (activeChannel != null && activeChannel.isClosed() == false) { + activeChannel.setSoLinger(true, value); + } + + } + + @Override + public boolean isOpen() { + return isOpen.get(); + } + } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptorEventHandler.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptorEventHandler.java index 3de846fd61f6b..49bba47ef0256 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptorEventHandler.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/AcceptorEventHandler.java @@ -22,11 +22,13 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.elasticsearch.transport.nio.channel.ChannelFactory; +import org.elasticsearch.transport.nio.channel.NioChannel; import org.elasticsearch.transport.nio.channel.NioServerSocketChannel; import org.elasticsearch.transport.nio.channel.NioSocketChannel; import org.elasticsearch.transport.nio.channel.SelectionKeyUtils; import java.io.IOException; +import java.util.function.Consumer; import java.util.function.Supplier; /** @@ -35,12 +37,15 @@ public class AcceptorEventHandler extends EventHandler { private final Supplier selectorSupplier; + private final Consumer acceptedChannelCallback; private final OpenChannels openChannels; - public AcceptorEventHandler(Logger logger, OpenChannels openChannels, Supplier selectorSupplier) { - super(logger); + public AcceptorEventHandler(Logger logger, OpenChannels openChannels, Supplier selectorSupplier, + Consumer acceptedChannelCallback) { + super(logger, openChannels); this.openChannels = openChannels; this.selectorSupplier = selectorSupplier; + this.acceptedChannelCallback = acceptedChannelCallback; } /** @@ -73,8 +78,9 @@ void registrationException(NioServerSocketChannel channel, Exception exception) void acceptChannel(NioServerSocketChannel nioServerChannel) throws IOException { ChannelFactory channelFactory = nioServerChannel.getChannelFactory(); SocketSelector selector = selectorSupplier.get(); - NioSocketChannel nioSocketChannel = channelFactory.acceptNioChannel(nioServerChannel, selector, openChannels::channelClosed); + NioSocketChannel nioSocketChannel = channelFactory.acceptNioChannel(nioServerChannel, selector); openChannels.acceptedChannelOpened(nioSocketChannel); + acceptedChannelCallback.accept(nioSocketChannel); } /** diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/EventHandler.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/EventHandler.java index 04e1b21b1b065..59e866036cc51 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/EventHandler.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/EventHandler.java @@ -21,7 +21,6 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; -import org.elasticsearch.transport.nio.channel.CloseFuture; import org.elasticsearch.transport.nio.channel.NioChannel; import java.io.IOException; @@ -30,9 +29,11 @@ public abstract class EventHandler { protected final Logger logger; + private final OpenChannels openChannels; - EventHandler(Logger logger) { + public EventHandler(Logger logger, OpenChannels openChannels) { this.logger = logger; + this.openChannels = openChannels; } /** @@ -70,13 +71,13 @@ void uncaughtException(Exception exception) { * @param channel that should be closed */ void handleClose(NioChannel channel) { - channel.closeFromSelector(); - CloseFuture closeFuture = channel.getCloseFuture(); - assert closeFuture.isDone() : "Should always be done as we are on the selector thread"; - IOException closeException = closeFuture.getCloseException(); - if (closeException != null) { - closeException(channel, closeException); + openChannels.channelClosed(channel); + try { + channel.closeFromSelector(); + } catch (IOException e) { + closeException(channel, e); } + assert channel.isOpen() == false : "Should always be done as we are on the selector thread"; } /** diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioClient.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioClient.java index ee0b32db0149a..74a9eb46a23c8 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioClient.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioClient.java @@ -19,129 +19,44 @@ package org.elasticsearch.transport.nio; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.cluster.node.DiscoveryNode; -import org.elasticsearch.common.unit.TimeValue; -import org.elasticsearch.transport.ConnectTransportException; import org.elasticsearch.transport.nio.channel.ChannelFactory; -import org.elasticsearch.transport.nio.channel.ConnectFuture; -import org.elasticsearch.transport.nio.channel.NioChannel; import org.elasticsearch.transport.nio.channel.NioSocketChannel; import java.io.IOException; import java.net.InetSocketAddress; -import java.util.ArrayList; -import java.util.Iterator; import java.util.concurrent.Semaphore; -import java.util.concurrent.TimeUnit; -import java.util.function.Consumer; import java.util.function.Supplier; public class NioClient { - private final Logger logger; private final OpenChannels openChannels; private final Supplier selectorSupplier; - private final TimeValue defaultConnectTimeout; private final ChannelFactory channelFactory; private final Semaphore semaphore = new Semaphore(Integer.MAX_VALUE); - public NioClient(Logger logger, OpenChannels openChannels, Supplier selectorSupplier, TimeValue connectTimeout, - ChannelFactory channelFactory) { - this.logger = logger; + NioClient(OpenChannels openChannels, Supplier selectorSupplier, ChannelFactory channelFactory) { this.openChannels = openChannels; this.selectorSupplier = selectorSupplier; - this.defaultConnectTimeout = connectTimeout; this.channelFactory = channelFactory; } - public boolean connectToChannels(DiscoveryNode node, - NioSocketChannel[] channels, - TimeValue connectTimeout, - Consumer closeListener) throws IOException { + public void close() { + semaphore.acquireUninterruptibly(Integer.MAX_VALUE); + } + + NioSocketChannel initiateConnection(InetSocketAddress address) throws IOException { boolean allowedToConnect = semaphore.tryAcquire(); if (allowedToConnect == false) { - return false; + return null; } - final ArrayList connections = new ArrayList<>(channels.length); - connectTimeout = getConnectTimeout(connectTimeout); - final InetSocketAddress address = node.getAddress().address(); try { - for (int i = 0; i < channels.length; i++) { - SocketSelector selector = selectorSupplier.get(); - NioSocketChannel nioSocketChannel = channelFactory.openNioChannel(address, selector, closeListener); - openChannels.clientChannelOpened(nioSocketChannel); - connections.add(nioSocketChannel); - } - - Exception ex = null; - boolean allConnected = true; - for (NioSocketChannel socketChannel : connections) { - ConnectFuture connectFuture = socketChannel.getConnectFuture(); - boolean success = connectFuture.awaitConnectionComplete(connectTimeout.getMillis(), TimeUnit.MILLISECONDS); - if (success == false) { - allConnected = false; - Exception exception = connectFuture.getException(); - if (exception != null) { - ex = exception; - break; - } - } - } - - if (allConnected == false) { - if (ex == null) { - throw new ConnectTransportException(node, "connect_timeout[" + connectTimeout + "]"); - } else { - throw new ConnectTransportException(node, "connect_exception", ex); - } - } - addConnectionsToList(channels, connections); - return true; - - } catch (IOException | RuntimeException e) { - closeChannels(connections, e); - throw e; - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - closeChannels(connections, e); - throw new ElasticsearchException(e); + SocketSelector selector = selectorSupplier.get(); + NioSocketChannel nioSocketChannel = channelFactory.openNioChannel(address, selector); + openChannels.clientChannelOpened(nioSocketChannel); + return nioSocketChannel; } finally { semaphore.release(); } } - - public void close() { - semaphore.acquireUninterruptibly(Integer.MAX_VALUE); - } - - private TimeValue getConnectTimeout(TimeValue connectTimeout) { - if (connectTimeout != null && connectTimeout.equals(defaultConnectTimeout) == false) { - return connectTimeout; - } else { - return defaultConnectTimeout; - } - } - - private static void addConnectionsToList(NioSocketChannel[] channels, ArrayList connections) { - final Iterator iterator = connections.iterator(); - for (int i = 0; i < channels.length; i++) { - assert iterator.hasNext(); - channels[i] = iterator.next(); - } - assert iterator.hasNext() == false : "not all created connection have been consumed"; - } - - private void closeChannels(ArrayList connections, Exception e) { - for (final NioSocketChannel socketChannel : connections) { - try { - socketChannel.closeAsync().awaitClose(); - } catch (Exception inner) { - logger.trace("exception while closing channel", e); - e.addSuppressed(inner); - } - } - } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java index 686432722a204..381f2841136e9 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/NioTransport.java @@ -19,7 +19,6 @@ package org.elasticsearch.transport.nio; -import java.net.StandardSocketOptions; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; @@ -29,15 +28,14 @@ import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.transport.ConnectionProfile; import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.Transports; import org.elasticsearch.transport.nio.channel.ChannelFactory; -import org.elasticsearch.transport.nio.channel.CloseFuture; import org.elasticsearch.transport.nio.channel.NioChannel; import org.elasticsearch.transport.nio.channel.NioServerSocketChannel; import org.elasticsearch.transport.nio.channel.NioSocketChannel; @@ -47,7 +45,6 @@ import java.io.IOException; import java.net.InetSocketAddress; import java.util.ArrayList; -import java.util.List; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ThreadFactory; import java.util.function.Consumer; @@ -70,9 +67,9 @@ public class NioTransport extends TcpTransport { public static final Setting NIO_ACCEPTOR_COUNT = intSetting("transport.nio.acceptor_count", 1, 1, Setting.Property.NodeScope); + protected final OpenChannels openChannels = new OpenChannels(logger); private final Consumer contextSetter; private final ConcurrentMap profileToChannelFactory = newConcurrentMap(); - private final OpenChannels openChannels = new OpenChannels(logger); private final ArrayList acceptors = new ArrayList<>(); private final ArrayList socketSelectors = new ArrayList<>(); private NioClient client; @@ -101,48 +98,6 @@ protected NioServerSocketChannel bind(String name, InetSocketAddress address) th return channelFactory.openNioServerSocketChannel(name, address, selector); } - @Override - protected void closeChannels(List channels, boolean blocking, boolean doNotLinger) throws IOException { - if (doNotLinger) { - for (NioChannel channel : channels) { - /* We set SO_LINGER timeout to 0 to ensure that when we shutdown the node we don't have a gazillion connections sitting - * in TIME_WAIT to free up resources quickly. This is really the only part where we close the connection from the server - * side otherwise the client (node) initiates the TCP closing sequence which doesn't cause these issues. Setting this - * by default from the beginning can have unexpected side-effects an should be avoided, our protocol is designed - * in a way that clients close connection which is how it should be*/ - if (channel.isOpen() && channel.getRawChannel().supportedOptions().contains(StandardSocketOptions.SO_LINGER)) { - channel.getRawChannel().setOption(StandardSocketOptions.SO_LINGER, 0); - } - } - } - ArrayList futures = new ArrayList<>(channels.size()); - for (final NioChannel channel : channels) { - if (channel != null && channel.isOpen()) { - // We do not need to wait for the close operation to complete. If the close operation fails due - // to an IOException, the selector's handler will log the exception. Additionally, in the case - // of transport shutdown, where we do want to ensure that all channels are finished closing, the - // NioShutdown class will block on close. - futures.add(channel.closeAsync()); - } - } - - if (blocking == false) { - return; - } - - IOException closingExceptions = null; - for (CloseFuture future : futures) { - try { - future.awaitClose(); - } catch (Exception e) { - closingExceptions = addClosingException(closingExceptions, e); - } - } - if (closingExceptions != null) { - throw closingExceptions; - } - } - @Override protected void sendMessage(NioChannel channel, BytesReference reference, ActionListener listener) { if (channel instanceof NioSocketChannel) { @@ -154,20 +109,14 @@ protected void sendMessage(NioChannel channel, BytesReference reference, ActionL } @Override - protected NodeChannels connectToChannels(DiscoveryNode node, ConnectionProfile profile, Consumer onChannelClose) + protected NioChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener connectListener) throws IOException { - NioSocketChannel[] channels = new NioSocketChannel[profile.getNumConnections()]; - ClientChannelCloseListener closeListener = new ClientChannelCloseListener(onChannelClose); - boolean connected = client.connectToChannels(node, channels, profile.getConnectTimeout(), closeListener); - if (connected == false) { + NioSocketChannel channel = client.initiateConnection(node.getAddress().address()); + if (channel == null) { throw new ElasticsearchException("client is shutdown"); } - return new NodeChannels(node, channels, profile); - } - - @Override - protected boolean isOpen(NioChannel channel) { - return channel.isOpen(); + channel.addConnectListener(connectListener); + return channel; } @Override @@ -194,7 +143,8 @@ protected void doStart() { int acceptorCount = NioTransport.NIO_ACCEPTOR_COUNT.get(settings); for (int i = 0; i < acceptorCount; ++i) { Supplier selectorSupplier = new RoundRobinSelectorSupplier(socketSelectors); - AcceptorEventHandler eventHandler = new AcceptorEventHandler(logger, openChannels, selectorSupplier); + AcceptorEventHandler eventHandler = new AcceptorEventHandler(logger, openChannels, selectorSupplier, + this::serverAcceptedChannel); AcceptingSelector acceptor = new AcceptingSelector(eventHandler); acceptors.add(acceptor); } @@ -235,7 +185,7 @@ protected void stopInternal() { } protected SocketEventHandler getSocketEventHandler() { - return new SocketEventHandler(logger, this::exceptionCaught); + return new SocketEventHandler(logger, this::exceptionCaught, openChannels); } final void exceptionCaught(NioSocketChannel channel, Throwable cause) { @@ -247,29 +197,6 @@ final void exceptionCaught(NioSocketChannel channel, Throwable cause) { private NioClient createClient() { Supplier selectorSupplier = new RoundRobinSelectorSupplier(socketSelectors); ChannelFactory channelFactory = new ChannelFactory(new ProfileSettings(settings, "default"), contextSetter); - return new NioClient(logger, openChannels, selectorSupplier, defaultConnectionProfile.getConnectTimeout(), channelFactory); - } - - private IOException addClosingException(IOException closingExceptions, Exception e) { - if (closingExceptions == null) { - closingExceptions = new IOException("failed to close channels"); - } - closingExceptions.addSuppressed(e); - return closingExceptions; - } - - class ClientChannelCloseListener implements Consumer { - - private final Consumer consumer; - - private ClientChannelCloseListener(Consumer consumer) { - this.consumer = consumer; - } - - @Override - public void accept(final NioChannel channel) { - consumer.accept(channel); - openChannels.channelClosed(channel); - } + return new NioClient(openChannels, selectorSupplier, channelFactory); } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/OpenChannels.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/OpenChannels.java index 4655f19001da3..68bb2f99bf3c5 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/OpenChannels.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/OpenChannels.java @@ -21,15 +21,17 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.common.lease.Releasable; -import org.elasticsearch.transport.nio.channel.CloseFuture; +import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.nio.channel.NioChannel; import org.elasticsearch.transport.nio.channel.NioServerSocketChannel; import org.elasticsearch.transport.nio.channel.NioSocketChannel; import java.util.ArrayList; import java.util.HashSet; -import java.util.List; +import java.util.Map; import java.util.concurrent.ConcurrentMap; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.elasticsearch.common.util.concurrent.ConcurrentCollections.newConcurrentMap; @@ -75,6 +77,10 @@ public void clientChannelOpened(NioSocketChannel channel) { } } + public Map getClientChannels() { + return openClientChannels; + } + public void channelClosed(NioChannel channel) { boolean removed; if (channel instanceof NioServerSocketChannel) { @@ -92,40 +98,17 @@ public void channelClosed(NioChannel channel) { } public void closeServerChannels() { - List futures = new ArrayList<>(); - for (NioServerSocketChannel channel : openServerChannels.keySet()) { - CloseFuture closeFuture = channel.closeAsync(); - futures.add(closeFuture); - } - ensureChannelsClosed(futures); + TcpChannel.closeChannels(new ArrayList<>(openServerChannels.keySet()), true); openServerChannels.clear(); } @Override public void close() { - List futures = new ArrayList<>(); - for (NioSocketChannel channel : openClientChannels.keySet()) { - CloseFuture closeFuture = channel.closeAsync(); - futures.add(closeFuture); - } - for (NioSocketChannel channel : openAcceptedChannels.keySet()) { - CloseFuture closeFuture = channel.closeAsync(); - futures.add(closeFuture); - } - ensureChannelsClosed(futures); + Stream channels = Stream.concat(openClientChannels.keySet().stream(), openAcceptedChannels.keySet().stream()); + TcpChannel.closeChannels(channels.collect(Collectors.toList()), true); openClientChannels.clear(); openAcceptedChannels.clear(); } - - private void ensureChannelsClosed(List futures) { - for (CloseFuture future : futures) { - try { - future.get(); - } catch (Exception e) { - logger.debug("exception while closing channels", e); - } - } - } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/SocketEventHandler.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/SocketEventHandler.java index 58958a2b3ce3f..b04ecc4ea9a6f 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/SocketEventHandler.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/SocketEventHandler.java @@ -37,8 +37,8 @@ public class SocketEventHandler extends EventHandler { private final BiConsumer exceptionHandler; private final Logger logger; - public SocketEventHandler(Logger logger, BiConsumer exceptionHandler) { - super(logger); + public SocketEventHandler(Logger logger, BiConsumer exceptionHandler, OpenChannels openChannels) { + super(logger, openChannels); this.exceptionHandler = exceptionHandler; this.logger = logger; } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/AbstractNioChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/AbstractNioChannel.java index c550785fac517..a7208beb6618f 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/AbstractNioChannel.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/AbstractNioChannel.java @@ -19,14 +19,18 @@ package org.elasticsearch.transport.nio.channel; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.nio.ESSelector; import java.io.IOException; import java.net.InetSocketAddress; +import java.net.StandardSocketOptions; import java.nio.channels.ClosedChannelException; import java.nio.channels.NetworkChannel; import java.nio.channels.SelectableChannel; import java.nio.channels.SelectionKey; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; /** @@ -54,7 +58,7 @@ public abstract class AbstractNioChannel closeContext = new CompletableFuture<>(); private final ESSelector selector; private SelectionKey selectionKey; @@ -67,7 +71,7 @@ public abstract class AbstractNioChannel * If the channel is already set to closed, it is assumed that it is already scheduled to be closed. - * - * @return future that will be complete when the channel is closed */ @Override - public CloseFuture closeAsync() { + public void close() { if (isClosing.compareAndSet(false, true)) { selector.queueChannelClose(this); } - return closeFuture; } /** @@ -104,20 +105,19 @@ public CloseFuture closeAsync() { * Once this method returns, the channel will be closed. */ @Override - public void closeFromSelector() { + public void closeFromSelector() throws IOException { assert selector.isOnCurrentThread() : "Should only call from selector thread"; isClosing.set(true); - if (closeFuture.isClosed() == false) { - boolean closedOnThisCall = false; + if (closeContext.isDone() == false) { try { closeRawChannel(); - closedOnThisCall = closeFuture.channelClosed(this); + closeContext.complete(this); } catch (IOException e) { - closedOnThisCall = closeFuture.channelCloseThrewException(e); + closeContext.completeExceptionally(e); + throw e; } finally { - if (closedOnThisCall) { - selector.removeRegisteredChannel(this); - } + // There is no problem with calling this multiple times + selector.removeRegisteredChannel(this); } } } @@ -143,11 +143,6 @@ public SelectionKey getSelectionKey() { return selectionKey; } - @Override - public CloseFuture getCloseFuture() { - return closeFuture; - } - @Override public S getRawChannel() { return socketChannel; @@ -162,4 +157,16 @@ void setSelectionKey(SelectionKey selectionKey) { void closeRawChannel() throws IOException { socketChannel.close(); } + + @Override + public void addCloseListener(ActionListener listener) { + closeContext.whenComplete(ActionListener.toBiConsumer(listener)); + } + + @Override + public void setSoLinger(int value) throws IOException { + if (isOpen()) { + socketChannel.setOption(StandardSocketOptions.SO_LINGER, value); + } + } } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ChannelFactory.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ChannelFactory.java index 199bab9a904b0..8d739fd72778e 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ChannelFactory.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ChannelFactory.java @@ -21,7 +21,6 @@ import org.apache.lucene.util.IOUtils; -import org.elasticsearch.action.ActionListener; import org.elasticsearch.mocksocket.PrivilegedSocketAccess; import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.nio.AcceptingSelector; @@ -63,22 +62,18 @@ public ChannelFactory(TcpTransport.ProfileSettings profileSettings, Consumer closeListener) throws IOException { + public NioSocketChannel openNioChannel(InetSocketAddress remoteAddress, SocketSelector selector) throws IOException { SocketChannel rawChannel = rawChannelFactory.openNioChannel(remoteAddress); NioSocketChannel channel = new NioSocketChannel(NioChannel.CLIENT, rawChannel, selector); setContexts(channel); - channel.getCloseFuture().addListener(ActionListener.wrap(closeListener::accept, (e) -> closeListener.accept(channel))); scheduleChannel(channel, selector); return channel; } - public NioSocketChannel acceptNioChannel(NioServerSocketChannel serverChannel, SocketSelector selector, - Consumer closeListener) throws IOException { + public NioSocketChannel acceptNioChannel(NioServerSocketChannel serverChannel, SocketSelector selector) throws IOException { SocketChannel rawChannel = rawChannelFactory.acceptNioChannel(serverChannel); NioSocketChannel channel = new NioSocketChannel(serverChannel.getProfile(), rawChannel, selector); setContexts(channel); - channel.getCloseFuture().addListener(ActionListener.wrap(closeListener::accept, (e) -> closeListener.accept(channel))); scheduleChannel(channel, selector); return channel; } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/CloseFuture.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/CloseFuture.java deleted file mode 100644 index 5932de8fef708..0000000000000 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/CloseFuture.java +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.transport.nio.channel; - -import org.elasticsearch.action.support.PlainListenableActionFuture; - -import java.io.IOException; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; - -public class CloseFuture extends PlainListenableActionFuture { - - @Override - public boolean cancel(boolean mayInterruptIfRunning) { - throw new UnsupportedOperationException("Cannot cancel close future"); - } - - public void awaitClose() throws IOException { - try { - super.get(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new IllegalStateException("Future got interrupted", e); - } catch (ExecutionException e) { - throw (IOException) e.getCause(); - } - } - - public void awaitClose(long timeout, TimeUnit unit) throws TimeoutException, IOException { - try { - super.get(timeout, unit); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new IllegalStateException("Future got interrupted", e); - } catch (ExecutionException e) { - throw (IOException) e.getCause(); - } - } - - public IOException getCloseException() { - if (isDone()) { - try { - super.get(0, TimeUnit.NANOSECONDS); - return null; - } catch (ExecutionException e) { - // We only make a setter for IOException - return (IOException) e.getCause(); - } catch (TimeoutException e) { - return null; - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - return null; - } - } else { - return null; - } - } - - public boolean isClosed() { - return super.isDone(); - } - - boolean channelClosed(NioChannel channel) { - return set(channel); - } - - - boolean channelCloseThrewException(IOException ex) { - return setException(ex); - } - -} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ConnectFuture.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ConnectFuture.java deleted file mode 100644 index 1675c7326ee04..0000000000000 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/ConnectFuture.java +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.transport.nio.channel; - -import org.elasticsearch.common.util.concurrent.BaseFuture; - -import java.io.IOException; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; - -public class ConnectFuture extends BaseFuture { - - public boolean awaitConnectionComplete(long timeout, TimeUnit unit) throws InterruptedException { - try { - super.get(timeout, unit); - return true; - } catch (ExecutionException | TimeoutException e) { - return false; - } - } - - public Exception getException() { - if (isDone()) { - try { - // Get should always return without blocking as we already checked 'isDone' - // We are calling 'get' here in order to throw the ExecutionException - super.get(); - return null; - } catch (ExecutionException e) { - // We only make a public setters for IOException or RuntimeException - return (Exception) e.getCause(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - return null; - } - } else { - return null; - } - } - - public boolean isConnectComplete() { - return getChannel() != null; - } - - public boolean connectFailed() { - return getException() != null; - } - - void setConnectionComplete(NioSocketChannel channel) { - set(channel); - } - - void setConnectionFailed(IOException e) { - setException(e); - } - - void setConnectionFailed(RuntimeException e) { - setException(e); - } - - private NioSocketChannel getChannel() { - if (isDone()) { - try { - // Get should always return without blocking as we already checked 'isDone' - return super.get(0, TimeUnit.NANOSECONDS); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - return null; - } catch (ExecutionException e) { - return null; - } catch (TimeoutException e) { - throw new AssertionError("This should never happen as we only call get() after isDone() is true."); - } - } else { - return null; - } - } -} diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioChannel.java index c4133cce27105..b519ec0dc11df 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioChannel.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioChannel.java @@ -19,26 +19,26 @@ package org.elasticsearch.transport.nio.channel; +import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.nio.ESSelector; +import java.io.IOException; import java.net.InetSocketAddress; import java.nio.channels.ClosedChannelException; import java.nio.channels.NetworkChannel; import java.nio.channels.SelectionKey; -public interface NioChannel { +public interface NioChannel extends TcpChannel { String CLIENT = "client-socket"; - boolean isOpen(); - InetSocketAddress getLocalAddress(); String getProfile(); - CloseFuture closeAsync(); + void close(); - void closeFromSelector(); + void closeFromSelector() throws IOException; void register() throws ClosedChannelException; @@ -46,7 +46,5 @@ public interface NioChannel { SelectionKey getSelectionKey(); - CloseFuture getCloseFuture(); - NetworkChannel getRawChannel(); } diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannel.java index 4c6c0b2b65acd..fab6fa22c6b16 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannel.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannel.java @@ -23,6 +23,7 @@ import java.io.IOException; import java.nio.channels.ServerSocketChannel; +import java.util.concurrent.Future; public class NioServerSocketChannel extends AbstractNioChannel { diff --git a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioSocketChannel.java b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioSocketChannel.java index 6d41ad563a4e3..5e4e323094199 100644 --- a/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioSocketChannel.java +++ b/test/framework/src/main/java/org/elasticsearch/transport/nio/channel/NioSocketChannel.java @@ -19,6 +19,7 @@ package org.elasticsearch.transport.nio.channel; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.transport.nio.NetworkBytesReference; import org.elasticsearch.transport.nio.SocketSelector; @@ -28,14 +29,16 @@ import java.nio.channels.ClosedChannelException; import java.nio.channels.SocketChannel; import java.util.Arrays; +import java.util.concurrent.CompletableFuture; public class NioSocketChannel extends AbstractNioChannel { private final InetSocketAddress remoteAddress; - private final ConnectFuture connectFuture = new ConnectFuture(); + private final CompletableFuture connectContext = new CompletableFuture<>(); private final SocketSelector socketSelector; private WriteContext writeContext; private ReadContext readContext; + private Exception connectException; public NioSocketChannel(String profile, SocketChannel socketChannel, SocketSelector selector) throws IOException { super(profile, socketChannel, selector); @@ -44,7 +47,7 @@ public NioSocketChannel(String profile, SocketChannel socketChannel, SocketSelec } @Override - public void closeFromSelector() { + public void closeFromSelector() throws IOException { assert socketSelector.isOnCurrentThread() : "Should only call from selector thread"; // Even if the channel has already been closed we will clear any pending write operations just in case if (writeContext.hasQueuedWriteOps()) { @@ -108,7 +111,7 @@ public InetSocketAddress getRemoteAddress() { } public boolean isConnectComplete() { - return connectFuture.isConnectComplete(); + return isConnectComplete0(); } public boolean isWritable() { @@ -130,11 +133,13 @@ public boolean isReadable() { * @throws IOException if an I/O error occurs */ public boolean finishConnect() throws IOException { - if (connectFuture.isConnectComplete()) { + if (isConnectComplete0()) { return true; - } else if (connectFuture.connectFailed()) { - Exception exception = connectFuture.getException(); - if (exception instanceof IOException) { + } else if (connectContext.isCompletedExceptionally()) { + Exception exception = connectException; + if (exception == null) { + throw new AssertionError("Should have received connection exception"); + } else if (exception instanceof IOException) { throw (IOException) exception; } else { throw (RuntimeException) exception; @@ -146,13 +151,13 @@ public boolean finishConnect() throws IOException { isConnected = internalFinish(); } if (isConnected) { - connectFuture.setConnectionComplete(this); + connectContext.complete(this); } return isConnected; } - public ConnectFuture getConnectFuture() { - return connectFuture; + public void addConnectListener(ActionListener listener) { + connectContext.whenComplete(ActionListener.toBiConsumer(listener)); } @Override @@ -167,12 +172,14 @@ public String toString() { private boolean internalFinish() throws IOException { try { return socketChannel.finishConnect(); - } catch (IOException e) { - connectFuture.setConnectionFailed(e); - throw e; - } catch (RuntimeException e) { - connectFuture.setConnectionFailed(e); + } catch (IOException | RuntimeException e) { + connectException = e; + connectContext.completeExceptionally(e); throw e; } } + + private boolean isConnectComplete0() { + return connectContext.isDone() && connectContext.isCompletedExceptionally() == false; + } } diff --git a/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java b/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java index b1a3a914be89e..bbe3c13442cec 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/MockTcpTransportTests.java @@ -60,7 +60,7 @@ protected void closeConnectionChannel(Transport transport, Transport.Connection final MockTcpTransport t = (MockTcpTransport) transport; @SuppressWarnings("unchecked") final TcpTransport.NodeChannels channels = (TcpTransport.NodeChannels) connection; - t.closeChannels(channels.getChannels().subList(0, randomIntBetween(1, channels.getChannels().size())), true, false); + TcpChannel.closeChannels(channels.getChannels().subList(0, randomIntBetween(1, channels.getChannels().size())), true); } } diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptorEventHandlerTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptorEventHandlerTests.java index 8ae6559c7413b..abca0295f5835 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptorEventHandlerTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/AcceptorEventHandlerTests.java @@ -53,15 +53,18 @@ public class AcceptorEventHandlerTests extends ESTestCase { private ChannelFactory channelFactory; private OpenChannels openChannels; private NioServerSocketChannel channel; + private Consumer acceptedChannelCallback; @Before + @SuppressWarnings("unchecked") public void setUpHandler() throws IOException { channelFactory = mock(ChannelFactory.class); socketSelector = mock(SocketSelector.class); + acceptedChannelCallback = mock(Consumer.class); openChannels = new OpenChannels(logger); ArrayList selectors = new ArrayList<>(); selectors.add(socketSelector); - handler = new AcceptorEventHandler(logger, openChannels, new RoundRobinSelectorSupplier(selectors)); + handler = new AcceptorEventHandler(logger, openChannels, new RoundRobinSelectorSupplier(selectors), acceptedChannelCallback); AcceptingSelector selector = mock(AcceptingSelector.class); channel = new DoNotRegisterServerChannel("", mock(ServerSocketChannel.class), channelFactory, selector); @@ -86,31 +89,26 @@ public void testHandleRegisterSetsOP_ACCEPTInterest() { public void testHandleAcceptCallsChannelFactory() throws IOException { NioSocketChannel childChannel = new NioSocketChannel("", mock(SocketChannel.class), socketSelector); - when(channelFactory.acceptNioChannel(same(channel), same(socketSelector), any())).thenReturn(childChannel); + when(channelFactory.acceptNioChannel(same(channel), same(socketSelector))).thenReturn(childChannel); handler.acceptChannel(channel); - verify(channelFactory).acceptNioChannel(same(channel), same(socketSelector), any()); + verify(channelFactory).acceptNioChannel(same(channel), same(socketSelector)); } @SuppressWarnings("unchecked") - public void testHandleAcceptAddsToOpenChannelsAndAddsCloseListenerToRemove() throws IOException { + public void testHandleAcceptAddsToOpenChannelsAndIsRemovedOnClose() throws IOException { SocketChannel rawChannel = SocketChannel.open(); NioSocketChannel childChannel = new NioSocketChannel("", rawChannel, socketSelector); childChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class)); - when(channelFactory.acceptNioChannel(same(channel), same(socketSelector), any())).thenReturn(childChannel); + when(channelFactory.acceptNioChannel(same(channel), same(socketSelector))).thenReturn(childChannel); handler.acceptChannel(channel); - Class> clazz = (Class>)(Class)Consumer.class; - ArgumentCaptor> listener = ArgumentCaptor.forClass(clazz); - verify(channelFactory).acceptNioChannel(same(channel), same(socketSelector), listener.capture()); - - assertEquals(new HashSet<>(Collections.singletonList(childChannel)), openChannels.getAcceptedChannels()); - listener.getValue().accept(childChannel); + verify(acceptedChannelCallback).accept(childChannel); - assertEquals(new HashSet<>(), openChannels.getAcceptedChannels()); + assertEquals(new HashSet<>(Collections.singletonList(childChannel)), openChannels.getAcceptedChannels()); IOUtils.closeWhileHandlingException(rawChannel); } diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/NioClientTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/NioClientTests.java index 4cae51acc83fa..6b376af066474 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/NioClientTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/NioClientTests.java @@ -19,16 +19,8 @@ package org.elasticsearch.transport.nio; -import org.elasticsearch.Version; -import org.elasticsearch.cluster.node.DiscoveryNode; -import org.elasticsearch.common.transport.TransportAddress; -import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.transport.ConnectTransportException; import org.elasticsearch.transport.nio.channel.ChannelFactory; -import org.elasticsearch.transport.nio.channel.CloseFuture; -import org.elasticsearch.transport.nio.channel.ConnectFuture; -import org.elasticsearch.transport.nio.channel.NioChannel; import org.elasticsearch.transport.nio.channel.NioSocketChannel; import org.junit.Before; @@ -36,12 +28,11 @@ import java.net.InetAddress; import java.net.InetSocketAddress; import java.util.ArrayList; -import java.util.concurrent.TimeUnit; -import java.util.function.Consumer; import java.util.function.Supplier; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class NioClientTests extends ESTestCase { @@ -50,126 +41,41 @@ public class NioClientTests extends ESTestCase { private SocketSelector selector; private ChannelFactory channelFactory; private OpenChannels openChannels = new OpenChannels(logger); - private NioSocketChannel[] channels; - private DiscoveryNode node; - private Consumer listener; - private TransportAddress address; + private InetSocketAddress address; @Before @SuppressWarnings("unchecked") public void setUpClient() { channelFactory = mock(ChannelFactory.class); selector = mock(SocketSelector.class); - listener = mock(Consumer.class); - ArrayList selectors = new ArrayList<>(); selectors.add(selector); Supplier selectorSupplier = new RoundRobinSelectorSupplier(selectors); - client = new NioClient(logger, openChannels, selectorSupplier, TimeValue.timeValueMillis(5), channelFactory); - - channels = new NioSocketChannel[2]; - address = new TransportAddress(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0)); - node = new DiscoveryNode("node-id", address, Version.CURRENT); - } - - public void testCreateConnections() throws IOException, InterruptedException { - NioSocketChannel channel1 = mock(NioSocketChannel.class); - ConnectFuture connectFuture1 = mock(ConnectFuture.class); - NioSocketChannel channel2 = mock(NioSocketChannel.class); - ConnectFuture connectFuture2 = mock(ConnectFuture.class); - - when(channelFactory.openNioChannel(address.address(), selector, listener)).thenReturn(channel1, channel2); - when(channel1.getConnectFuture()).thenReturn(connectFuture1); - when(channel2.getConnectFuture()).thenReturn(connectFuture2); - when(connectFuture1.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(true); - when(connectFuture2.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(true); - - client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), listener); - - assertEquals(channel1, channels[0]); - assertEquals(channel2, channels[1]); + client = new NioClient(openChannels, selectorSupplier, channelFactory); + address = new InetSocketAddress(InetAddress.getLoopbackAddress(), 0); } - public void testWithADifferentConnectTimeout() throws IOException, InterruptedException { + public void testCreateConnection() throws IOException, InterruptedException { NioSocketChannel channel1 = mock(NioSocketChannel.class); - ConnectFuture connectFuture1 = mock(ConnectFuture.class); - - when(channelFactory.openNioChannel(address.address(), selector, listener)).thenReturn(channel1); - when(channel1.getConnectFuture()).thenReturn(connectFuture1); - when(connectFuture1.awaitConnectionComplete(3, TimeUnit.MILLISECONDS)).thenReturn(true); - channels = new NioSocketChannel[1]; - client.connectToChannels(node, channels, TimeValue.timeValueMillis(3), listener); + when(channelFactory.openNioChannel(eq(address), eq(selector))).thenReturn(channel1); - assertEquals(channel1, channels[0]); - } + NioSocketChannel nioSocketChannel = client.initiateConnection(address); - public void testConnectionTimeout() throws IOException, InterruptedException { - NioSocketChannel channel1 = mock(NioSocketChannel.class); - ConnectFuture connectFuture1 = mock(ConnectFuture.class); - CloseFuture closeFuture1 = mock(CloseFuture.class); - NioSocketChannel channel2 = mock(NioSocketChannel.class); - ConnectFuture connectFuture2 = mock(ConnectFuture.class); - CloseFuture closeFuture2 = mock(CloseFuture.class); - - when(channelFactory.openNioChannel(address.address(), selector, listener)).thenReturn(channel1, channel2); - when(channel1.getCloseFuture()).thenReturn(closeFuture1); - when(channel1.getConnectFuture()).thenReturn(connectFuture1); - when(channel2.getCloseFuture()).thenReturn(closeFuture2); - when(channel2.getConnectFuture()).thenReturn(connectFuture2); - when(connectFuture1.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(true); - when(connectFuture2.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(false); - - try { - client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), listener); - fail("Should have thrown ConnectTransportException"); - } catch (ConnectTransportException e) { - assertTrue(e.getMessage().contains("connect_timeout[5ms]")); - } - - verify(channel1).closeAsync(); - verify(channel2).closeAsync(); - - assertNull(channels[0]); - assertNull(channels[1]); + assertEquals(channel1, nioSocketChannel); } public void testConnectionException() throws IOException, InterruptedException { - NioSocketChannel channel1 = mock(NioSocketChannel.class); - ConnectFuture connectFuture1 = mock(ConnectFuture.class); - NioSocketChannel channel2 = mock(NioSocketChannel.class); - ConnectFuture connectFuture2 = mock(ConnectFuture.class); IOException ioException = new IOException(); - when(channelFactory.openNioChannel(address.address(), selector, listener)).thenReturn(channel1, channel2); - when(channel1.getConnectFuture()).thenReturn(connectFuture1); - when(channel2.getConnectFuture()).thenReturn(connectFuture2); - when(connectFuture1.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(true); - when(connectFuture2.awaitConnectionComplete(5, TimeUnit.MILLISECONDS)).thenReturn(false); - when(connectFuture2.getException()).thenReturn(ioException); - - try { - client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), listener); - fail("Should have thrown ConnectTransportException"); - } catch (ConnectTransportException e) { - assertTrue(e.getMessage().contains("connect_exception")); - assertSame(ioException, e.getCause()); - } - - verify(channel1).closeAsync(); - verify(channel2).closeAsync(); - - assertNull(channels[0]); - assertNull(channels[1]); + when(channelFactory.openNioChannel(eq(address), eq(selector))).thenThrow(ioException); + + expectThrows(IOException.class, () -> client.initiateConnection(address)); } public void testCloseDoesNotAllowConnections() throws IOException { client.close(); - assertFalse(client.connectToChannels(node, channels, TimeValue.timeValueMillis(5), listener)); - - for (NioSocketChannel channel : channels) { - assertNull(channel); - } + assertNull(client.initiateConnection(address)); } } diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java index f4e21f7093be1..04f1b424142c5 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/SimpleNioTransportTests.java @@ -35,6 +35,7 @@ import org.elasticsearch.transport.AbstractSimpleTransportTestCase; import org.elasticsearch.transport.BindTransportException; import org.elasticsearch.transport.ConnectTransportException; +import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.TcpTransport; import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.TransportService; @@ -77,7 +78,7 @@ protected Version getCurrentVersion() { @Override protected SocketEventHandler getSocketEventHandler() { - return new TestingSocketEventHandler(logger, this::exceptionCaught); + return new TestingSocketEventHandler(logger, this::exceptionCaught, openChannels); } }; MockTransportService mockTransportService = @@ -98,9 +99,9 @@ protected MockTransportService build(Settings settings, Version version, Cluster @Override protected void closeConnectionChannel(Transport transport, Transport.Connection connection) throws IOException { - final NioTransport t = (NioTransport) transport; - @SuppressWarnings("unchecked") TcpTransport.NodeChannels channels = (TcpTransport.NodeChannels) connection; - t.closeChannels(channels.getChannels().subList(0, randomIntBetween(1, channels.getChannels().size())), true, false); + @SuppressWarnings("unchecked") + TcpTransport.NodeChannels channels = (TcpTransport.NodeChannels) connection; + TcpChannel.closeChannels(channels.getChannels().subList(0, randomIntBetween(1, channels.getChannels().size())), true); } public void testConnectException() throws UnknownHostException { diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketEventHandlerTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketEventHandlerTests.java index 3bc5cd083a692..b1c6fab2065a9 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketEventHandlerTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketEventHandlerTests.java @@ -22,7 +22,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.transport.nio.channel.CloseFuture; import org.elasticsearch.transport.nio.channel.DoNotRegisterChannel; import org.elasticsearch.transport.nio.channel.NioChannel; import org.elasticsearch.transport.nio.channel.NioSocketChannel; @@ -56,7 +55,7 @@ public class SocketEventHandlerTests extends ESTestCase { public void setUpHandler() throws IOException { exceptionHandler = mock(BiConsumer.class); SocketSelector socketSelector = mock(SocketSelector.class); - handler = new SocketEventHandler(logger, exceptionHandler); + handler = new SocketEventHandler(logger, exceptionHandler, mock(OpenChannels.class)); rawChannel = mock(SocketChannel.class); channel = new DoNotRegisterChannel("", rawChannel, socketSelector); readContext = mock(ReadContext.class); @@ -102,11 +101,8 @@ public void testHandleReadDelegatesToReadContext() throws IOException { public void testHandleReadMarksChannelForCloseIfPeerClosed() throws IOException { NioSocketChannel nioSocketChannel = mock(NioSocketChannel.class); - CloseFuture closeFuture = mock(CloseFuture.class); when(nioSocketChannel.getReadContext()).thenReturn(readContext); when(readContext.read()).thenReturn(-1); - when(nioSocketChannel.getCloseFuture()).thenReturn(closeFuture); - when(closeFuture.isDone()).thenReturn(true); handler.handleRead(nioSocketChannel); diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketSelectorTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketSelectorTests.java index cb266831530c8..fdaed26a557f7 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketSelectorTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/SocketSelectorTests.java @@ -34,7 +34,6 @@ import java.nio.channels.ClosedSelectorException; import java.nio.channels.SelectionKey; import java.nio.channels.Selector; -import java.util.HashSet; import java.util.Set; import static org.mockito.Matchers.any; diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/TestingSocketEventHandler.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/TestingSocketEventHandler.java index 29f595c87a53b..7d3cf97ee08ee 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/TestingSocketEventHandler.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/TestingSocketEventHandler.java @@ -32,8 +32,8 @@ public class TestingSocketEventHandler extends SocketEventHandler { private final Logger logger; - public TestingSocketEventHandler(Logger logger, BiConsumer exceptionHandler) { - super(logger, exceptionHandler); + public TestingSocketEventHandler(Logger logger, BiConsumer exceptionHandler, OpenChannels openChannels) { + super(logger, exceptionHandler, openChannels); this.logger = logger; } diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/ChannelFactoryTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/ChannelFactoryTests.java index 710f26bedcf39..50770f459cd75 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/ChannelFactoryTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/ChannelFactoryTests.java @@ -23,14 +23,10 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.transport.nio.AcceptingSelector; import org.elasticsearch.transport.nio.SocketSelector; -import org.elasticsearch.transport.nio.TcpReadHandler; import org.junit.After; import org.junit.Before; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; import java.io.IOException; -import java.net.InetAddress; import java.net.InetSocketAddress; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; @@ -48,7 +44,6 @@ public class ChannelFactoryTests extends ESTestCase { private ChannelFactory channelFactory; private ChannelFactory.RawChannelFactory rawChannelFactory; - private Consumer listener; private SocketChannel rawChannel; private ServerSocketChannel rawServerChannel; private SocketSelector socketSelector; @@ -60,7 +55,6 @@ public void setupFactory() throws IOException { rawChannelFactory = mock(ChannelFactory.RawChannelFactory.class); Consumer contextSetter = mock(Consumer.class); channelFactory = new ChannelFactory(rawChannelFactory, contextSetter); - listener = mock(Consumer.class); socketSelector = mock(SocketSelector.class); acceptingSelector = mock(AcceptingSelector.class); rawChannel = SocketChannel.open(); @@ -84,17 +78,13 @@ public void testAcceptChannel() throws IOException { when(rawChannelFactory.acceptNioChannel(serverChannel)).thenReturn(rawChannel); when(serverChannel.getProfile()).thenReturn("parent-profile"); - NioSocketChannel channel = channelFactory.acceptNioChannel(serverChannel, socketSelector, listener); + NioSocketChannel channel = channelFactory.acceptNioChannel(serverChannel, socketSelector); verify(socketSelector).scheduleForRegistration(channel); assertEquals(socketSelector, channel.getSelector()); assertEquals("parent-profile", channel.getProfile()); assertEquals(rawChannel, channel.getRawChannel()); - - channel.getCloseFuture().channelClosed(channel); - - verify(listener).accept(channel); } public void testAcceptedChannelRejected() throws IOException { @@ -102,7 +92,7 @@ public void testAcceptedChannelRejected() throws IOException { when(rawChannelFactory.acceptNioChannel(serverChannel)).thenReturn(rawChannel); doThrow(new IllegalStateException()).when(socketSelector).scheduleForRegistration(any()); - expectThrows(IllegalStateException.class, () -> channelFactory.acceptNioChannel(serverChannel, socketSelector, listener)); + expectThrows(IllegalStateException.class, () -> channelFactory.acceptNioChannel(serverChannel, socketSelector)); assertFalse(rawChannel.isOpen()); } @@ -111,17 +101,13 @@ public void testOpenChannel() throws IOException { InetSocketAddress address = mock(InetSocketAddress.class); when(rawChannelFactory.openNioChannel(same(address))).thenReturn(rawChannel); - NioSocketChannel channel = channelFactory.openNioChannel(address, socketSelector, listener); + NioSocketChannel channel = channelFactory.openNioChannel(address, socketSelector); verify(socketSelector).scheduleForRegistration(channel); assertEquals(socketSelector, channel.getSelector()); assertEquals("client-socket", channel.getProfile()); assertEquals(rawChannel, channel.getRawChannel()); - - channel.getCloseFuture().channelClosed(channel); - - verify(listener).accept(channel); } public void testOpenedChannelRejected() throws IOException { @@ -129,7 +115,7 @@ public void testOpenedChannelRejected() throws IOException { when(rawChannelFactory.openNioChannel(same(address))).thenReturn(rawChannel); doThrow(new IllegalStateException()).when(socketSelector).scheduleForRegistration(any()); - expectThrows(IllegalStateException.class, () -> channelFactory.openNioChannel(address, socketSelector, listener)); + expectThrows(IllegalStateException.class, () -> channelFactory.openNioChannel(address, socketSelector)); assertFalse(rawChannel.isOpen()); } diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannelTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannelTests.java index 367df0c78f4c8..62f87d4f57473 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannelTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioServerSocketChannelTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.TcpChannel; import org.elasticsearch.transport.nio.AcceptingSelector; import org.elasticsearch.transport.nio.AcceptorEventHandler; import org.elasticsearch.transport.nio.OpenChannels; @@ -30,8 +31,6 @@ import java.io.IOException; import java.nio.channels.ServerSocketChannel; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; @@ -48,7 +47,7 @@ public class NioServerSocketChannelTests extends ESTestCase { @Before @SuppressWarnings("unchecked") public void setSelector() throws IOException { - selector = new AcceptingSelector(new AcceptorEventHandler(logger, mock(OpenChannels.class), mock(Supplier.class))); + selector = new AcceptingSelector(new AcceptorEventHandler(logger, mock(OpenChannels.class), mock(Supplier.class), (c) -> {})); thread = new Thread(selector::runLoop); closedRawChannel = new AtomicBoolean(false); thread.start(); @@ -61,28 +60,25 @@ public void stopSelector() throws IOException, InterruptedException { thread.join(); } - public void testClose() throws IOException, TimeoutException, InterruptedException { - AtomicReference ref = new AtomicReference<>(); + public void testClose() throws Exception { + AtomicReference ref = new AtomicReference<>(); CountDownLatch latch = new CountDownLatch(1); NioChannel channel = new DoNotCloseServerChannel("nio", mock(ServerSocketChannel.class), mock(ChannelFactory.class), selector); - Consumer listener = (c) -> { + Consumer listener = (c) -> { ref.set(c); latch.countDown(); }; - channel.getCloseFuture().addListener(ActionListener.wrap(listener::accept, (e) -> listener.accept(channel))); + channel.addCloseListener(ActionListener.wrap(listener::accept, (e) -> listener.accept(channel))); - CloseFuture closeFuture = channel.getCloseFuture(); - - assertFalse(closeFuture.isClosed()); + assertTrue(channel.isOpen()); assertFalse(closedRawChannel.get()); - channel.closeAsync(); + TcpChannel.closeChannel(channel, true); - closeFuture.awaitClose(100, TimeUnit.SECONDS); assertTrue(closedRawChannel.get()); - assertTrue(closeFuture.isClosed()); + assertFalse(channel.isOpen()); latch.await(); assertSame(channel, ref.get()); } diff --git a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioSocketChannelTests.java b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioSocketChannelTests.java index 75ec57b2603db..d8d4b41df7038 100644 --- a/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioSocketChannelTests.java +++ b/test/framework/src/test/java/org/elasticsearch/transport/nio/channel/NioSocketChannelTests.java @@ -20,7 +20,10 @@ package org.elasticsearch.transport.nio.channel; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.TcpChannel; +import org.elasticsearch.transport.nio.OpenChannels; import org.elasticsearch.transport.nio.SocketEventHandler; import org.elasticsearch.transport.nio.SocketSelector; import org.junit.After; @@ -30,14 +33,13 @@ import java.net.ConnectException; import java.nio.channels.SocketChannel; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.function.BiConsumer; import java.util.function.Consumer; -import static org.hamcrest.Matchers.instanceOf; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -46,11 +48,13 @@ public class NioSocketChannelTests extends ESTestCase { private SocketSelector selector; private AtomicBoolean closedRawChannel; private Thread thread; + private OpenChannels openChannels; @Before @SuppressWarnings("unchecked") public void startSelector() throws IOException { - selector = new SocketSelector(new SocketEventHandler(logger, mock(BiConsumer.class))); + openChannels = new OpenChannels(logger); + selector = new SocketSelector(new SocketEventHandler(logger, mock(BiConsumer.class), openChannels)); thread = new Thread(selector::runLoop); closedRawChannel = new AtomicBoolean(false); thread.start(); @@ -63,64 +67,63 @@ public void stopSelector() throws IOException, InterruptedException { thread.join(); } - public void testClose() throws IOException, TimeoutException, InterruptedException { - AtomicReference ref = new AtomicReference<>(); + public void testClose() throws Exception { + AtomicReference ref = new AtomicReference<>(); CountDownLatch latch = new CountDownLatch(1); NioSocketChannel socketChannel = new DoNotCloseChannel(NioChannel.CLIENT, mock(SocketChannel.class), selector); + openChannels.clientChannelOpened(socketChannel); socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class)); - Consumer listener = (c) -> { + Consumer listener = (c) -> { ref.set(c); latch.countDown(); }; - socketChannel.getCloseFuture().addListener(ActionListener.wrap(listener::accept, (e) -> listener.accept(socketChannel))); - CloseFuture closeFuture = socketChannel.getCloseFuture(); + socketChannel.addCloseListener(ActionListener.wrap(listener::accept, (e) -> listener.accept(socketChannel))); - assertFalse(closeFuture.isClosed()); + assertTrue(socketChannel.isOpen()); assertFalse(closedRawChannel.get()); + assertTrue(openChannels.getClientChannels().containsKey(socketChannel)); - socketChannel.closeAsync(); - - closeFuture.awaitClose(100, TimeUnit.SECONDS); + TcpChannel.closeChannel(socketChannel, true); assertTrue(closedRawChannel.get()); - assertTrue(closeFuture.isClosed()); + assertFalse(socketChannel.isOpen()); + assertFalse(openChannels.getClientChannels().containsKey(socketChannel)); latch.await(); assertSame(socketChannel, ref.get()); } - public void testConnectSucceeds() throws IOException, InterruptedException { + public void testConnectSucceeds() throws Exception { SocketChannel rawChannel = mock(SocketChannel.class); when(rawChannel.finishConnect()).thenReturn(true); NioSocketChannel socketChannel = new DoNotCloseChannel(NioChannel.CLIENT, rawChannel, selector); socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class)); selector.scheduleForRegistration(socketChannel); - ConnectFuture connectFuture = socketChannel.getConnectFuture(); - assertTrue(connectFuture.awaitConnectionComplete(100, TimeUnit.SECONDS)); + PlainActionFuture connectFuture = PlainActionFuture.newFuture(); + socketChannel.addConnectListener(connectFuture); + connectFuture.get(100, TimeUnit.SECONDS); assertTrue(socketChannel.isConnectComplete()); assertTrue(socketChannel.isOpen()); assertFalse(closedRawChannel.get()); - assertFalse(connectFuture.connectFailed()); - assertNull(connectFuture.getException()); } - public void testConnectFails() throws IOException, InterruptedException { + public void testConnectFails() throws Exception { SocketChannel rawChannel = mock(SocketChannel.class); when(rawChannel.finishConnect()).thenThrow(new ConnectException()); NioSocketChannel socketChannel = new DoNotCloseChannel(NioChannel.CLIENT, rawChannel, selector); socketChannel.setContexts(mock(ReadContext.class), mock(WriteContext.class)); selector.scheduleForRegistration(socketChannel); - ConnectFuture connectFuture = socketChannel.getConnectFuture(); - assertFalse(connectFuture.awaitConnectionComplete(100, TimeUnit.SECONDS)); + PlainActionFuture connectFuture = PlainActionFuture.newFuture(); + socketChannel.addConnectListener(connectFuture); + ExecutionException e = expectThrows(ExecutionException.class, () -> connectFuture.get(100, TimeUnit.SECONDS)); + assertTrue(e.getCause() instanceof IOException); assertFalse(socketChannel.isConnectComplete()); // Even if connection fails the channel is 'open' until close() is called assertTrue(socketChannel.isOpen()); - assertTrue(connectFuture.connectFailed()); - assertThat(connectFuture.getException(), instanceOf(ConnectException.class)); } private class DoNotCloseChannel extends DoNotRegisterChannel {