Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix PublishSubject non-deterministic behavior on concurrent modification #288

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 208 additions & 9 deletions rxjava-core/src/main/java/rx/subjects/PublishSubject.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,30 @@
*/
package rx.subjects;

import static org.junit.Assert.*;
import static org.mockito.Matchers.*;
import static org.mockito.Mockito.*;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import junit.framework.Assert;

import org.junit.Test;
import org.mockito.InOrder;
import org.mockito.Mockito;

import rx.Notification;
import rx.Observable;
import rx.Observer;
import rx.Subscription;
import rx.operators.AtomicObservableSubscription;
import rx.subscriptions.Subscriptions;
import rx.util.functions.Action1;
import rx.util.functions.Func0;
import rx.util.functions.Func1;
Expand Down Expand Up @@ -62,10 +67,15 @@
public class PublishSubject<T> extends Subject<T, T> {
public static <T> PublishSubject<T> create() {
final ConcurrentHashMap<Subscription, Observer<T>> observers = new ConcurrentHashMap<Subscription, Observer<T>>();

final AtomicReference<Notification<T>> terminalState = new AtomicReference<Notification<T>>();

Func1<Observer<T>, Subscription> onSubscribe = new Func1<Observer<T>, Subscription>() {
@Override
public Subscription call(Observer<T> observer) {
// shortcut check if terminal state exists already
Subscription s = checkTerminalState(observer);
if(s != null) return s;

final AtomicObservableSubscription subscription = new AtomicObservableSubscription();

subscription.wrap(new Subscription() {
Expand All @@ -76,43 +86,110 @@ public void unsubscribe() {
}
});

// on subscribe add it to the map of outbound observers to notify
observers.put(subscription, observer);
return subscription;
/**
* NOTE: We are synchronizing to avoid a race condition between terminalState being set and
* a new observer being added to observers.
*
* The synchronization only occurs on subscription and terminal states, it does not affect onNext calls
* so a high-volume hot-observable will not pay this cost for emitting data.
*
* Due to the restricted impact of blocking synchronization here I have not pursued more complicated
* approaches to try and stay completely non-blocking.
*/
synchronized (terminalState) {
// check terminal state again
s = checkTerminalState(observer);
if (s != null)
return s;

// on subscribe add it to the map of outbound observers to notify
observers.put(subscription, observer);

return subscription;
}
}

private Subscription checkTerminalState(Observer<T> observer) {
Notification<T> n = terminalState.get();
if (n != null) {
// we are terminated to immediately emit and don't continue with subscription
if (n.isOnCompleted()) {
observer.onCompleted();
} else {
observer.onError(n.getException());
}
return Subscriptions.empty();
} else {
return null;
}
}
};

return new PublishSubject<T>(onSubscribe, observers);
return new PublishSubject<T>(onSubscribe, observers, terminalState);
}

private final ConcurrentHashMap<Subscription, Observer<T>> observers;
private final AtomicReference<Notification<T>> terminalState;

protected PublishSubject(Func1<Observer<T>, Subscription> onSubscribe, ConcurrentHashMap<Subscription, Observer<T>> observers) {
protected PublishSubject(Func1<Observer<T>, Subscription> onSubscribe, ConcurrentHashMap<Subscription, Observer<T>> observers, AtomicReference<Notification<T>> terminalState) {
super(onSubscribe);
this.observers = observers;
this.terminalState = terminalState;
}

@Override
public void onCompleted() {
for (Observer<T> observer : observers.values()) {
/**
* Synchronizing despite terminalState being an AtomicReference because of multi-step logic in subscription.
* Why use AtomicReference then? Convenient for passing around a mutable reference holder between the
* onSubscribe function and PublishSubject instance... and it's a "better volatile" for the shortcut codepath.
*/
synchronized (terminalState) {
terminalState.set(new Notification<T>());
}
for (Observer<T> observer : snapshotOfValues()) {
observer.onCompleted();
}
observers.clear();
}

@Override
public void onError(Exception e) {
for (Observer<T> observer : observers.values()) {
/**
* Synchronizing despite terminalState being an AtomicReference because of multi-step logic in subscription.
* Why use AtomicReference then? Convenient for passing around a mutable reference holder between the
* onSubscribe function and PublishSubject instance... and it's a "better volatile" for the shortcut codepath.
*/
synchronized (terminalState) {
terminalState.set(new Notification<T>(e));
}
for (Observer<T> observer : snapshotOfValues()) {
observer.onError(e);
}
observers.clear();
}

@Override
public void onNext(T args) {
for (Observer<T> observer : observers.values()) {
for (Observer<T> observer : snapshotOfValues()) {
observer.onNext(args);
}
}

/**
* Current snapshot of 'values()' so that concurrent modifications aren't included.
*
* This makes it behave deterministically in a single-threaded execution when nesting subscribes.
*
* In multi-threaded execution it will cause new subscriptions to wait until the following onNext instead
* of possibly being included in the current onNext iteration.
*
* @return List<Observer<T>>
*/
private Collection<Observer<T>> snapshotOfValues() {
return new ArrayList<Observer<T>>(observers.values());
}

public static class UnitTest {
@Test
public void test() {
Expand Down Expand Up @@ -307,6 +384,75 @@ private void assertObservedUntilTwo(Observer<String> aObserver)
verify(aObserver, Mockito.never()).onCompleted();
}

/**
* Test that subscribing after onError/onCompleted immediately terminates instead of causing it to hang.
*
* Nothing is mentioned in Rx Guidelines for what to do in this case so I'm doing what seems to make sense
* which is:
*
* - cache terminal state (onError/onCompleted)
* - any subsequent subscriptions will immediately receive the terminal state rather than start a new subscription
*
*/
@Test
public void testUnsubscribeAfterOnCompleted() {
PublishSubject<Object> subject = PublishSubject.create();

@SuppressWarnings("unchecked")
Observer<String> anObserver = mock(Observer.class);
subject.subscribe(anObserver);

subject.onNext("one");
subject.onNext("two");
subject.onCompleted();

InOrder inOrder = inOrder(anObserver);
inOrder.verify(anObserver, times(1)).onNext("one");
inOrder.verify(anObserver, times(1)).onNext("two");
inOrder.verify(anObserver, times(1)).onCompleted();
inOrder.verify(anObserver, Mockito.never()).onError(any(Exception.class));

@SuppressWarnings("unchecked")
Observer<String> anotherObserver = mock(Observer.class);
subject.subscribe(anotherObserver);

inOrder = inOrder(anotherObserver);
inOrder.verify(anotherObserver, Mockito.never()).onNext("one");
inOrder.verify(anotherObserver, Mockito.never()).onNext("two");
inOrder.verify(anotherObserver, times(1)).onCompleted();
inOrder.verify(anotherObserver, Mockito.never()).onError(any(Exception.class));
}

@Test
public void testUnsubscribeAfterOnError() {
PublishSubject<Object> subject = PublishSubject.create();
RuntimeException exception = new RuntimeException("failure");

@SuppressWarnings("unchecked")
Observer<String> anObserver = mock(Observer.class);
subject.subscribe(anObserver);

subject.onNext("one");
subject.onNext("two");
subject.onError(exception);

InOrder inOrder = inOrder(anObserver);
inOrder.verify(anObserver, times(1)).onNext("one");
inOrder.verify(anObserver, times(1)).onNext("two");
inOrder.verify(anObserver, times(1)).onError(exception);
inOrder.verify(anObserver, Mockito.never()).onCompleted();

@SuppressWarnings("unchecked")
Observer<String> anotherObserver = mock(Observer.class);
subject.subscribe(anotherObserver);

inOrder = inOrder(anotherObserver);
inOrder.verify(anotherObserver, Mockito.never()).onNext("one");
inOrder.verify(anotherObserver, Mockito.never()).onNext("two");
inOrder.verify(anotherObserver, times(1)).onError(exception);
inOrder.verify(anotherObserver, Mockito.never()).onCompleted();
}

@Test
public void testUnsubscribe()
{
Expand Down Expand Up @@ -340,5 +486,58 @@ public void call(PublishSubject<Object> DefaultSubject)
}
});
}

@Test
public void testNestedSubscribe() {
final PublishSubject<Integer> s = PublishSubject.create();

final AtomicInteger countParent = new AtomicInteger();
final AtomicInteger countChildren = new AtomicInteger();
final AtomicInteger countTotal = new AtomicInteger();

final ArrayList<String> list = new ArrayList<String>();

s.mapMany(new Func1<Integer, Observable<String>>() {

@Override
public Observable<String> call(final Integer v) {
countParent.incrementAndGet();

// then subscribe to subject again (it will not receive the previous value)
return s.map(new Func1<Integer, String>() {

@Override
public String call(Integer v2) {
countChildren.incrementAndGet();
return "Parent: " + v + " Child: " + v2;
}

});
}

}).subscribe(new Action1<String>() {

@Override
public void call(String v) {
countTotal.incrementAndGet();
list.add(v);
}

});


for(int i=0; i<10; i++) {
s.onNext(i);
}
s.onCompleted();

// System.out.println("countParent: " + countParent.get());
// System.out.println("countChildren: " + countChildren.get());
// System.out.println("countTotal: " + countTotal.get());

// 9+8+7+6+5+4+3+2+1+0 == 45
assertEquals(45, list.size());
}

}
}