Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/83035.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 83035
summary: Correct context for `ClusterConnManager` listener
area: Network
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
import org.elasticsearch.common.util.concurrent.ListenableFuture;
import org.elasticsearch.common.util.concurrent.RunOnce;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.AbstractRefCounted;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Releasable;
Expand Down Expand Up @@ -44,18 +46,20 @@ public class ClusterConnectionManager implements ConnectionManager {
private final AbstractRefCounted connectingRefCounter = AbstractRefCounted.of(this::pendingConnectionsComplete);

private final Transport transport;
private final ThreadContext threadContext;
private final ConnectionProfile defaultProfile;
private final AtomicBoolean closing = new AtomicBoolean(false);
private final CountDownLatch closeLatch = new CountDownLatch(1);
private final DelegatingNodeConnectionListener connectionListener = new DelegatingNodeConnectionListener();

public ClusterConnectionManager(Settings settings, Transport transport) {
this(ConnectionProfile.buildDefaultConnectionProfile(settings), transport);
public ClusterConnectionManager(Settings settings, Transport transport, ThreadContext threadContext) {
this(ConnectionProfile.buildDefaultConnectionProfile(settings), transport, threadContext);
}

public ClusterConnectionManager(ConnectionProfile connectionProfile, Transport transport) {
public ClusterConnectionManager(ConnectionProfile connectionProfile, Transport transport, ThreadContext threadContext) {
this.transport = transport;
this.defaultProfile = connectionProfile;
this.threadContext = threadContext;
}

@Override
Expand Down Expand Up @@ -91,7 +95,13 @@ public void connectToNode(
ConnectionValidator connectionValidator,
ActionListener<Releasable> listener
) throws ConnectTransportException {
connectToNodeOrRetry(node, connectionProfile, connectionValidator, 0, listener);
connectToNodeOrRetry(
node,
connectionProfile,
connectionValidator,
0,
ContextPreservingActionListener.wrapPreservingContext(listener, threadContext)
);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ int getNumNodesConnected() {
}

private static ConnectionManager createConnectionManager(ConnectionProfile connectionProfile, TransportService transportService) {
return new ClusterConnectionManager(connectionProfile, transportService.transport);
return new ClusterConnectionManager(connectionProfile, transportService.transport, transportService.threadPool.getThreadContext());
}

ConnectionManager getConnectionManager() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ public TransportService(
localNodeFactory,
clusterSettings,
taskHeaders,
new ClusterConnectionManager(settings, transport)
new ClusterConnectionManager(settings, transport, threadPool.getThreadContext())
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportService;

Expand Down Expand Up @@ -77,7 +78,7 @@ public void testMainActionClusterAvailable() {
TransportService transportService = new TransportService(
Settings.EMPTY,
mock(Transport.class),
null,
mock(ThreadPool.class),
TransportService.NOOP_TRANSPORT_INTERCEPTOR,
x -> null,
null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ private TransportMultiSearchAction createTransportMultiSearchAction(boolean cont
TransportService transportService = new TransportService(
Settings.EMPTY,
mock(Transport.class),
null,
threadPool,
TransportService.NOOP_TRANSPORT_INTERCEPTOR,
boundAddress -> DiscoveryNode.createLocal(settings, boundAddress.publishAddress(), UUIDs.randomBase64UUID()),
null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ protected TestAction(boolean withDocumentFailureOnPrimary, boolean withDocumentF
new TransportService(
Settings.EMPTY,
mock(Transport.class),
null,
TransportWriteActionTests.threadPool,
TransportService.NOOP_TRANSPORT_INTERCEPTOR,
x -> null,
null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,16 @@ public void testJoinDeduplication() {
DeterministicTaskQueue deterministicTaskQueue = new DeterministicTaskQueue();
CapturingTransport capturingTransport = new HandshakingCapturingTransport();
DiscoveryNode localNode = new DiscoveryNode("node0", buildNewFakeTransportAddress(), Version.CURRENT);
final ThreadPool threadPool = deterministicTaskQueue.getThreadPool();
TransportService transportService = new TransportService(
Settings.EMPTY,
capturingTransport,
deterministicTaskQueue.getThreadPool(),
threadPool,
TransportService.NOOP_TRANSPORT_INTERCEPTOR,
x -> localNode,
null,
Collections.emptySet(),
new ClusterConnectionManager(Settings.EMPTY, capturingTransport)
new ClusterConnectionManager(Settings.EMPTY, capturingTransport, threadPool.getThreadContext())
);
JoinHelper joinHelper = new JoinHelper(
Settings.EMPTY,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.elasticsearch.test.transport.CapturingTransport;
import org.elasticsearch.test.transport.CapturingTransport.CapturedRequest;
import org.elasticsearch.test.transport.StubbableConnectionManager;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.ClusterConnectionManager;
import org.elasticsearch.transport.ConnectionManager;
import org.elasticsearch.transport.TransportException;
Expand Down Expand Up @@ -210,7 +211,13 @@ public void setup() {

localNode = newDiscoveryNode("local-node");

ConnectionManager innerConnectionManager = new ClusterConnectionManager(settings, capturingTransport);
final ThreadPool threadPool = deterministicTaskQueue.getThreadPool();

final ConnectionManager innerConnectionManager = new ClusterConnectionManager(
settings,
capturingTransport,
threadPool.getThreadContext()
);
StubbableConnectionManager connectionManager = new StubbableConnectionManager(innerConnectionManager);
connectionManager.setDefaultNodeConnectedBehavior((cm, discoveryNode) -> {
final boolean isConnected = connectedNodes.contains(discoveryNode);
Expand All @@ -222,7 +229,7 @@ public void setup() {
transportService = new TransportService(
settings,
capturingTransport,
deterministicTaskQueue.getThreadPool(),
threadPool,
TransportService.NOOP_TRANSPORT_INTERCEPTOR,
boundTransportAddress -> localNode,
null,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.core.TimeValue;
Expand All @@ -45,6 +46,7 @@
import java.util.function.Supplier;

import static org.elasticsearch.test.ActionListenerUtils.anyActionListener;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.notNullValue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
Expand All @@ -63,7 +65,7 @@ public void createConnectionManager() {
Settings settings = Settings.builder().put("node.name", ClusterConnectionManagerTests.class.getSimpleName()).build();
threadPool = new ThreadPool(settings);
transport = mock(Transport.class);
connectionManager = new ClusterConnectionManager(settings, transport);
connectionManager = new ClusterConnectionManager(settings, transport, threadPool.getThreadContext());
TimeValue oneSecond = new TimeValue(1000);
TimeValue oneMinute = TimeValue.timeValueMinutes(1);
connectionProfile = ConnectionProfile.buildSingleChannelProfile(
Expand Down Expand Up @@ -254,6 +256,9 @@ public void testConcurrentConnects() throws Exception {
int threadCount = between(1, 10);
Releasable[] releasables = new Releasable[threadCount];

final ThreadContext threadContext = threadPool.getThreadContext();
final String contextHeader = "test-context-header";

CyclicBarrier barrier = new CyclicBarrier(threadCount + 1);
Semaphore pendingCloses = new Semaphore(threadCount);
for (int i = 0; i < threadCount; i++) {
Expand All @@ -265,27 +270,33 @@ public void testConcurrentConnects() throws Exception {
throw new RuntimeException(e);
}
CountDownLatch latch = new CountDownLatch(1);
connectionManager.connectToNode(node, connectionProfile, validator, ActionListener.wrap(c -> {
assert connectionManager.nodeConnected(node);

assertTrue(pendingCloses.tryAcquire());
connectionManager.getConnection(node).addRemovedListener(ActionListener.wrap(pendingCloses::release));

if (randomBoolean()) {
releasables[threadIndex] = c;
nodeConnectedCount.incrementAndGet();
} else {
Releasables.close(c);
nodeClosedCount.incrementAndGet();
}

assert latch.getCount() == 1;
latch.countDown();
}, e -> {
nodeFailureCount.incrementAndGet();
assert latch.getCount() == 1;
latch.countDown();
}));
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
final String contextValue = randomAlphaOfLength(10);
threadContext.putHeader(contextHeader, contextValue);
connectionManager.connectToNode(node, connectionProfile, validator, ActionListener.wrap(c -> {
assert connectionManager.nodeConnected(node);
assertThat(threadContext.getHeader(contextHeader), equalTo(contextValue));

assertTrue(pendingCloses.tryAcquire());
connectionManager.getConnection(node).addRemovedListener(ActionListener.wrap(pendingCloses::release));

if (randomBoolean()) {
releasables[threadIndex] = c;
nodeConnectedCount.incrementAndGet();
} else {
Releasables.close(c);
nodeClosedCount.incrementAndGet();
}

assert latch.getCount() == 1;
latch.countDown();
}, e -> {
assertThat(threadContext.getHeader(contextHeader), equalTo(contextValue));
nodeFailureCount.incrementAndGet();
assert latch.getCount() == 1;
latch.countDown();
}));
}
try {
latch.await();
} catch (InterruptedException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,11 @@ public void testProxyStrategyWillOpenExpectedNumberOfConnectionsToAddress() {
localService.start();
localService.acceptIncomingRequests();

ClusterConnectionManager connectionManager = new ClusterConnectionManager(profile, localService.transport);
final ClusterConnectionManager connectionManager = new ClusterConnectionManager(
profile,
localService.transport,
threadPool.getThreadContext()
);
int numOfConnections = randomIntBetween(4, 8);
try (
RemoteConnectionManager remoteConnectionManager = new RemoteConnectionManager(clusterAlias, connectionManager);
Expand Down Expand Up @@ -127,7 +131,11 @@ public void testProxyStrategyWillOpenNewConnectionsOnDisconnect() throws Excepti
localService.start();
localService.acceptIncomingRequests();

ClusterConnectionManager connectionManager = new ClusterConnectionManager(profile, localService.transport);
final ClusterConnectionManager connectionManager = new ClusterConnectionManager(
profile,
localService.transport,
threadPool.getThreadContext()
);
int numOfConnections = randomIntBetween(4, 8);

AtomicBoolean useAddress1 = new AtomicBoolean(true);
Expand Down Expand Up @@ -189,7 +197,11 @@ public void testConnectFailsWithIncompatibleNodes() {
localService.start();
localService.acceptIncomingRequests();

ClusterConnectionManager connectionManager = new ClusterConnectionManager(profile, localService.transport);
final ClusterConnectionManager connectionManager = new ClusterConnectionManager(
profile,
localService.transport,
threadPool.getThreadContext()
);
int numOfConnections = randomIntBetween(4, 8);
try (
RemoteConnectionManager remoteConnectionManager = new RemoteConnectionManager(clusterAlias, connectionManager);
Expand Down Expand Up @@ -232,7 +244,11 @@ public void testClusterNameValidationPreventConnectingToDifferentClusters() thro
localService.start();
localService.acceptIncomingRequests();

ClusterConnectionManager connectionManager = new ClusterConnectionManager(profile, localService.transport);
final ClusterConnectionManager connectionManager = new ClusterConnectionManager(
profile,
localService.transport,
threadPool.getThreadContext()
);
int numOfConnections = randomIntBetween(4, 8);

AtomicBoolean useAddress1 = new AtomicBoolean(true);
Expand Down Expand Up @@ -295,7 +311,11 @@ public void testProxyStrategyWillResolveAddressesEachConnect() throws Exception
localService.start();
localService.acceptIncomingRequests();

ClusterConnectionManager connectionManager = new ClusterConnectionManager(profile, localService.transport);
final ClusterConnectionManager connectionManager = new ClusterConnectionManager(
profile,
localService.transport,
threadPool.getThreadContext()
);
int numOfConnections = randomIntBetween(4, 8);
try (
RemoteConnectionManager remoteConnectionManager = new RemoteConnectionManager(clusterAlias, connectionManager);
Expand Down Expand Up @@ -330,7 +350,11 @@ public void testProxyStrategyWillNeedToBeRebuiltIfNumOfSocketsOrAddressesOrServe
localService.start();
localService.acceptIncomingRequests();

ClusterConnectionManager connectionManager = new ClusterConnectionManager(profile, localService.transport);
final ClusterConnectionManager connectionManager = new ClusterConnectionManager(
profile,
localService.transport,
threadPool.getThreadContext()
);
int numOfConnections = randomIntBetween(4, 8);
try (
RemoteConnectionManager remoteConnectionManager = new RemoteConnectionManager(clusterAlias, connectionManager);
Expand Down Expand Up @@ -435,7 +459,11 @@ public void testServerNameAttributes() {

String address = "localhost:" + address1.getPort();

ClusterConnectionManager connectionManager = new ClusterConnectionManager(profile, localService.transport);
final ClusterConnectionManager connectionManager = new ClusterConnectionManager(
profile,
localService.transport,
threadPool.getThreadContext()
);
int numOfConnections = randomIntBetween(4, 8);
try (
RemoteConnectionManager remoteConnectionManager = new RemoteConnectionManager(clusterAlias, connectionManager);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.test.ESTestCase;

import java.net.InetAddress;
Expand All @@ -35,7 +36,10 @@ public class RemoteConnectionManagerTests extends ESTestCase {
public void setUp() throws Exception {
super.setUp();
transport = mock(Transport.class);
remoteConnectionManager = new RemoteConnectionManager("remote-cluster", new ClusterConnectionManager(Settings.EMPTY, transport));
remoteConnectionManager = new RemoteConnectionManager(
"remote-cluster",
new ClusterConnectionManager(Settings.EMPTY, transport, new ThreadContext(Settings.EMPTY))
);
}

@SuppressWarnings("unchecked")
Expand Down
Loading