Skip to content

Commit 09c621e

Browse files
coeuvrecopybara-github
authored andcommitted
Remote: Fix a race that AsyncTaskCache#Execution could be reused after disposed which results in CancellationException("disposed") propagated to downstream.
Also added a test case to verify the fix. PiperOrigin-RevId: 364699975
1 parent a8ef70e commit 09c621e

File tree

2 files changed

+133
-68
lines changed

2 files changed

+133
-68
lines changed

src/main/java/com/google/devtools/build/lib/remote/util/AsyncTaskCache.java

+90-50
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import java.util.Map;
2727
import java.util.Optional;
2828
import java.util.concurrent.CancellationException;
29+
import java.util.concurrent.atomic.AtomicBoolean;
2930
import java.util.concurrent.atomic.AtomicInteger;
3031
import java.util.concurrent.atomic.AtomicReference;
3132
import javax.annotation.concurrent.GuardedBy;
@@ -54,7 +55,7 @@ public final class AsyncTaskCache<KeyT, ValueT> {
5455
private final Map<KeyT, ValueT> finished;
5556

5657
@GuardedBy("lock")
57-
private final Map<KeyT, Execution> inProgress;
58+
private final Map<KeyT, Execution<ValueT>> inProgress;
5859

5960
public static <KeyT, ValueT> AsyncTaskCache<KeyT, ValueT> create() {
6061
return new AsyncTaskCache<>();
@@ -90,18 +91,22 @@ public Single<ValueT> executeIfNot(KeyT key, Single<ValueT> task) {
9091
return execute(key, task, false);
9192
}
9293

93-
private class Execution {
94+
private static class Execution<ValueT> {
95+
private final AtomicBoolean isTaskDisposed = new AtomicBoolean(false);
9496
private final Single<ValueT> task;
9597
private final AsyncSubject<ValueT> asyncSubject = AsyncSubject.create();
96-
private final AtomicInteger subscriberCount = new AtomicInteger(0);
98+
private final AtomicInteger referenceCount = new AtomicInteger(0);
9799
private final AtomicReference<Disposable> taskDisposable = new AtomicReference<>(null);
98100

99101
Execution(Single<ValueT> task) {
100102
this.task = task;
101103
}
102104

103-
public Single<ValueT> start() {
104-
if (taskDisposable.get() == null) {
105+
Single<ValueT> executeIfNot() {
106+
checkState(!isTaskDisposed(), "disposed");
107+
108+
int subscribed = referenceCount.getAndIncrement();
109+
if (taskDisposable.get() == null && subscribed == 0) {
105110
task.subscribe(
106111
new SingleObserver<ValueT>() {
107112
@Override
@@ -122,27 +127,39 @@ public void onError(@NonNull Throwable e) {
122127
});
123128
}
124129

125-
return Single.fromObservable(asyncSubject)
126-
.doOnSubscribe(d -> subscriberCount.incrementAndGet())
127-
.doOnDispose(
128-
() -> {
129-
if (subscriberCount.decrementAndGet() == 0) {
130-
Disposable d = taskDisposable.get();
131-
if (d != null) {
132-
d.dispose();
133-
}
134-
asyncSubject.onError(new CancellationException("disposed"));
135-
}
136-
});
130+
return Single.fromObservable(asyncSubject);
131+
}
132+
133+
boolean isTaskTerminated() {
134+
return asyncSubject.hasComplete() || asyncSubject.hasThrowable();
135+
}
136+
137+
boolean isTaskDisposed() {
138+
return isTaskDisposed.get();
139+
}
140+
141+
void tryDisposeTask() {
142+
checkState(!isTaskDisposed(), "disposed");
143+
checkState(!isTaskTerminated(), "terminated");
144+
145+
if (referenceCount.decrementAndGet() == 0) {
146+
isTaskDisposed.set(true);
147+
asyncSubject.onError(new CancellationException("disposed"));
148+
149+
Disposable d = taskDisposable.get();
150+
if (d != null) {
151+
d.dispose();
152+
}
153+
}
137154
}
138155
}
139156

140157
/** Returns count of subscribers for a task. */
141158
public int getSubscriberCount(KeyT key) {
142159
synchronized (lock) {
143-
Execution execution = inProgress.get(key);
160+
Execution<ValueT> execution = inProgress.get(key);
144161
if (execution != null) {
145-
return execution.subscriberCount.get();
162+
return execution.referenceCount.get();
146163
}
147164
}
148165

@@ -158,49 +175,72 @@ public int getSubscriberCount(KeyT key) {
158175
* error if any.
159176
*/
160177
public Single<ValueT> execute(KeyT key, Single<ValueT> task, boolean force) {
161-
return Single.defer(
162-
() -> {
178+
return Single.create(
179+
emitter -> {
163180
synchronized (lock) {
164181
if (!force && finished.containsKey(key)) {
165-
return Single.just(finished.get(key));
182+
emitter.onSuccess(finished.get(key));
183+
return;
166184
}
167185

168186
finished.remove(key);
169187

170-
Execution execution =
188+
Execution<ValueT> execution =
171189
inProgress.computeIfAbsent(
172190
key,
173-
missingKey -> {
191+
ignoredKey -> {
174192
AtomicInteger subscribeTimes = new AtomicInteger(0);
175-
return new Execution(
193+
return new Execution<>(
176194
Single.defer(
177-
() -> {
178-
int times = subscribeTimes.incrementAndGet();
179-
checkState(times == 1, "Subscribed more than once to the task");
180-
return task;
181-
})
182-
.doOnSuccess(
183-
value -> {
184-
synchronized (lock) {
185-
finished.put(key, value);
186-
inProgress.remove(key);
187-
}
188-
})
189-
.doOnError(
190-
error -> {
191-
synchronized (lock) {
192-
inProgress.remove(key);
193-
}
194-
})
195-
.doOnDispose(
196-
() -> {
197-
synchronized (lock) {
198-
inProgress.remove(key);
199-
}
200-
}));
195+
() -> {
196+
int times = subscribeTimes.incrementAndGet();
197+
checkState(times == 1, "Subscribed more than once to the task");
198+
return task;
199+
}));
201200
});
202201

203-
return execution.start();
202+
execution
203+
.executeIfNot()
204+
.subscribe(
205+
new SingleObserver<ValueT>() {
206+
@Override
207+
public void onSubscribe(@NonNull Disposable d) {
208+
emitter.setCancellable(
209+
() -> {
210+
d.dispose();
211+
212+
if (!execution.isTaskTerminated()) {
213+
synchronized (lock) {
214+
execution.tryDisposeTask();
215+
if (execution.isTaskDisposed()) {
216+
inProgress.remove(key);
217+
}
218+
}
219+
}
220+
});
221+
}
222+
223+
@Override
224+
public void onSuccess(@NonNull ValueT value) {
225+
synchronized (lock) {
226+
finished.put(key, value);
227+
inProgress.remove(key);
228+
}
229+
230+
emitter.onSuccess(value);
231+
}
232+
233+
@Override
234+
public void onError(@NonNull Throwable e) {
235+
synchronized (lock) {
236+
inProgress.remove(key);
237+
}
238+
239+
if (!emitter.isDisposed()) {
240+
emitter.onError(e);
241+
}
242+
}
243+
});
204244
}
205245
});
206246
}

src/test/java/com/google/devtools/build/lib/remote/util/AsyncTaskCacheTest.java

+43-18
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,15 @@
1414
package com.google.devtools.build.lib.remote.util;
1515

1616
import static com.google.common.truth.Truth.assertThat;
17+
import static java.util.concurrent.TimeUnit.SECONDS;
1718

1819
import io.reactivex.rxjava3.core.Single;
1920
import io.reactivex.rxjava3.core.SingleEmitter;
2021
import io.reactivex.rxjava3.observers.TestObserver;
21-
import io.reactivex.rxjava3.plugins.RxJavaPlugins;
2222
import java.util.concurrent.atomic.AtomicBoolean;
2323
import java.util.concurrent.atomic.AtomicInteger;
2424
import java.util.concurrent.atomic.AtomicReference;
25-
import org.junit.After;
26-
import org.junit.Before;
25+
import org.junit.Rule;
2726
import org.junit.Test;
2827
import org.junit.runner.RunWith;
2928
import org.junit.runners.JUnit4;
@@ -32,21 +31,7 @@
3231
@RunWith(JUnit4.class)
3332
public class AsyncTaskCacheTest {
3433

35-
private final AtomicReference<Throwable> rxGlobalThrowable = new AtomicReference<>(null);
36-
37-
@Before
38-
public void setUp() {
39-
RxJavaPlugins.setErrorHandler(rxGlobalThrowable::set);
40-
}
41-
42-
@After
43-
public void tearDown() throws Throwable {
44-
// Make sure rxjava didn't receive global errors
45-
Throwable t = rxGlobalThrowable.getAndSet(null);
46-
if (t != null) {
47-
throw t;
48-
}
49-
}
34+
@Rule public final RxNoGlobalErrorsRule rxNoGlobalErrorsRule = new RxNoGlobalErrorsRule();
5035

5136
@Test
5237
public void execute_noSubscription_noExecution() {
@@ -296,4 +281,44 @@ public void execute_multipleTasks_completeOne() {
296281
assertThat(cache.getInProgressTasks()).containsExactly("key2");
297282
assertThat(cache.getFinishedTasks()).containsExactly("key1");
298283
}
284+
285+
@Test
286+
public void execute_executeAndDisposeLoop_noErrors() throws InterruptedException {
287+
AsyncTaskCache<String, Long> cache = AsyncTaskCache.create();
288+
Single<Long> task = Single.timer(1, SECONDS);
289+
AtomicReference<Throwable> error = new AtomicReference<>(null);
290+
AtomicInteger errorCount = new AtomicInteger(0);
291+
int executionCount = 100;
292+
Runnable runnable =
293+
() -> {
294+
try {
295+
for (int i = 0; i < executionCount; ++i) {
296+
TestObserver<Long> observer = cache.execute("key1", task, true).test();
297+
observer.assertNoErrors();
298+
observer.dispose();
299+
}
300+
} catch (Throwable t) {
301+
errorCount.incrementAndGet();
302+
error.set(t);
303+
}
304+
};
305+
int threadCount = 10;
306+
Thread[] threads = new Thread[threadCount];
307+
for (int i = 0; i < threadCount; ++i) {
308+
Thread thread = new Thread(runnable);
309+
threads[i] = thread;
310+
}
311+
312+
for (Thread thread : threads) {
313+
thread.start();
314+
}
315+
for (Thread thread : threads) {
316+
thread.join();
317+
}
318+
319+
if (error.get() != null) {
320+
throw new IllegalStateException(
321+
String.format("%s/%s errors", errorCount.get(), threadCount), error.get());
322+
}
323+
}
299324
}

0 commit comments

Comments
 (0)