Skip to content

Commit 6d99d7d

Browse files
authored
ListenableFuture should preserve ThreadContext (#34394)
ListenableFuture may run a listener on the same thread that called the addListener method or it may execute on another thread after the future has completed. Whenever the ListenableFuture stores the listener for execution later, it should preserve the thread context which is what this change does.
1 parent 7bc11a8 commit 6d99d7d

File tree

4 files changed

+31
-9
lines changed

4 files changed

+31
-9
lines changed

server/src/main/java/org/elasticsearch/common/util/concurrent/ListenableFuture.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package org.elasticsearch.common.util.concurrent;
2121

2222
import org.elasticsearch.action.ActionListener;
23+
import org.elasticsearch.action.support.ContextPreservingActionListener;
2324
import org.elasticsearch.common.collect.Tuple;
2425

2526
import java.util.ArrayList;
@@ -47,7 +48,7 @@ public final class ListenableFuture<V> extends BaseFuture<V> implements ActionLi
4748
* If the future has completed, the listener will be notified immediately without forking to
4849
* a different thread.
4950
*/
50-
public void addListener(ActionListener<V> listener, ExecutorService executor) {
51+
public void addListener(ActionListener<V> listener, ExecutorService executor, ThreadContext threadContext) {
5152
if (done) {
5253
// run the callback directly, we don't hold the lock and don't need to fork!
5354
notifyListener(listener, EsExecutors.newDirectExecutorService());
@@ -59,7 +60,7 @@ public void addListener(ActionListener<V> listener, ExecutorService executor) {
5960
if (done) {
6061
run = true;
6162
} else {
62-
listeners.add(new Tuple<>(listener, executor));
63+
listeners.add(new Tuple<>(ContextPreservingActionListener.wrapPreservingContext(listener, threadContext), executor));
6364
run = false;
6465
}
6566
}

server/src/test/java/org/elasticsearch/common/util/concurrent/ListenableFutureTests.java

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
package org.elasticsearch.common.util.concurrent;
2121

22+
import org.apache.logging.log4j.message.ParameterizedMessage;
2223
import org.elasticsearch.action.ActionListener;
2324
import org.elasticsearch.common.settings.Settings;
2425
import org.elasticsearch.test.ESTestCase;
@@ -30,9 +31,12 @@
3031
import java.util.concurrent.ExecutorService;
3132
import java.util.concurrent.atomic.AtomicInteger;
3233

34+
import static org.hamcrest.Matchers.is;
35+
3336
public class ListenableFutureTests extends ESTestCase {
3437

3538
private ExecutorService executorService;
39+
private ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
3640

3741
@After
3842
public void stopExecutorService() throws InterruptedException {
@@ -46,7 +50,7 @@ public void testListenableFutureNotifiesListeners() {
4650
AtomicInteger notifications = new AtomicInteger(0);
4751
final int numberOfListeners = scaledRandomIntBetween(1, 12);
4852
for (int i = 0; i < numberOfListeners; i++) {
49-
future.addListener(ActionListener.wrap(notifications::incrementAndGet), EsExecutors.newDirectExecutorService());
53+
future.addListener(ActionListener.wrap(notifications::incrementAndGet), EsExecutors.newDirectExecutorService(), threadContext);
5054
}
5155

5256
future.onResponse("");
@@ -63,7 +67,7 @@ public void testListenableFutureNotifiesListenersOnException() {
6367
future.addListener(ActionListener.wrap(s -> fail("this should never be called"), e -> {
6468
assertEquals(exception, e);
6569
notifications.incrementAndGet();
66-
}), EsExecutors.newDirectExecutorService());
70+
}), EsExecutors.newDirectExecutorService(), threadContext);
6771
}
6872

6973
future.onFailure(exception);
@@ -76,7 +80,7 @@ public void testConcurrentListenerRegistrationAndCompletion() throws BrokenBarri
7680
final int completingThread = randomIntBetween(0, numberOfThreads - 1);
7781
final ListenableFuture<String> future = new ListenableFuture<>();
7882
executorService = EsExecutors.newFixed("testConcurrentListenerRegistrationAndCompletion", numberOfThreads, 1000,
79-
EsExecutors.daemonThreadFactory("listener"), new ThreadContext(Settings.EMPTY));
83+
EsExecutors.daemonThreadFactory("listener"), threadContext);
8084
final CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads);
8185
final CountDownLatch listenersLatch = new CountDownLatch(numberOfThreads - 1);
8286
final AtomicInteger numResponses = new AtomicInteger(0);
@@ -85,20 +89,31 @@ public void testConcurrentListenerRegistrationAndCompletion() throws BrokenBarri
8589
for (int i = 0; i < numberOfThreads; i++) {
8690
final int threadNum = i;
8791
Thread thread = new Thread(() -> {
92+
threadContext.putTransient("key", threadNum);
8893
try {
8994
barrier.await();
9095
if (threadNum == completingThread) {
96+
// we need to do more than just call onResponse as this often results in synchronous
97+
// execution of the listeners instead of actually going async
98+
final int waitTime = randomIntBetween(0, 50);
99+
Thread.sleep(waitTime);
100+
logger.info("completing the future after sleeping {}ms", waitTime);
91101
future.onResponse("");
102+
logger.info("future received response");
92103
} else {
104+
logger.info("adding listener {}", threadNum);
93105
future.addListener(ActionListener.wrap(s -> {
106+
logger.info("listener {} received value {}", threadNum, s);
94107
assertEquals("", s);
108+
assertThat(threadContext.getTransient("key"), is(threadNum));
95109
numResponses.incrementAndGet();
96110
listenersLatch.countDown();
97111
}, e -> {
98-
logger.error("caught unexpected exception", e);
112+
logger.error(new ParameterizedMessage("listener {} caught unexpected exception", threadNum), e);
99113
numExceptions.incrementAndGet();
100114
listenersLatch.countDown();
101-
}), executorService);
115+
}), executorService, threadContext);
116+
logger.info("listener {} added", threadNum);
102117
}
103118
barrier.await();
104119
} catch (InterruptedException | BrokenBarrierException e) {

x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealm.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ private void authenticateWithCache(UsernamePasswordToken token, ActionListener<A
153153
// is cleared of the failed authentication
154154
cache.invalidate(token.principal(), listenableCacheEntry);
155155
authenticateWithCache(token, listener);
156-
}), threadPool.executor(ThreadPool.Names.GENERIC));
156+
}), threadPool.executor(ThreadPool.Names.GENERIC), threadPool.getThreadContext());
157157
} else {
158158
// attempt authentication against the authentication source
159159
doAuthenticate(token, ActionListener.wrap(authResult -> {
@@ -255,7 +255,7 @@ private void lookupWithCache(String username, ActionListener<User> listener) {
255255
} else {
256256
listener.onResponse(null);
257257
}
258-
}, listener::onFailure), threadPool.executor(ThreadPool.Names.GENERIC));
258+
}, listener::onFailure), threadPool.executor(ThreadPool.Names.GENERIC), threadPool.getThreadContext());
259259
} catch (final ExecutionException e) {
260260
listener.onFailure(e);
261261
}

x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/CachingUsernamePasswordRealmTests.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,14 +469,17 @@ protected void doLookupUser(String username, ActionListener<User> listener) {
469469
List<Thread> threads = new ArrayList<>(numberOfThreads);
470470
for (int i = 0; i < numberOfThreads; i++) {
471471
final boolean invalidPassword = randomBoolean();
472+
final int threadNum = i;
472473
threads.add(new Thread(() -> {
474+
threadPool.getThreadContext().putTransient("key", threadNum);
473475
try {
474476
latch.countDown();
475477
latch.await();
476478
for (int i1 = 0; i1 < numberOfIterations; i1++) {
477479
UsernamePasswordToken token = new UsernamePasswordToken(username, invalidPassword ? randomPassword : password);
478480

479481
realm.authenticate(token, ActionListener.wrap((result) -> {
482+
assertThat(threadPool.getThreadContext().getTransient("key"), is(threadNum));
480483
if (invalidPassword && result.isAuthenticated()) {
481484
throw new RuntimeException("invalid password led to an authenticated user: " + result);
482485
} else if (invalidPassword == false && result.isAuthenticated() == false) {
@@ -529,12 +532,15 @@ protected void doLookupUser(String username, ActionListener<User> listener) {
529532
final CountDownLatch latch = new CountDownLatch(1 + numberOfThreads);
530533
List<Thread> threads = new ArrayList<>(numberOfThreads);
531534
for (int i = 0; i < numberOfThreads; i++) {
535+
final int threadNum = i;
532536
threads.add(new Thread(() -> {
533537
try {
538+
threadPool.getThreadContext().putTransient("key", threadNum);
534539
latch.countDown();
535540
latch.await();
536541
for (int i1 = 0; i1 < numberOfIterations; i1++) {
537542
realm.lookupUser(username, ActionListener.wrap((user) -> {
543+
assertThat(threadPool.getThreadContext().getTransient("key"), is(threadNum));
538544
if (user == null) {
539545
throw new RuntimeException("failed to lookup user");
540546
}

0 commit comments

Comments
 (0)