Skip to content

Commit

Permalink
GH-414: Correct error handling in KeyExchangeMessageHandler
Browse files Browse the repository at this point in the history
During KEX we may queue up higher-level packets. When KEX ends, we flush
this queue, writing all these queued packets. When writing a queued
packet fails, previous code handled the failure wrongly, leading to an
inconsistent state that could cause an endless loop in writeOrEnqueue.

Fix this by making sure that (a) kexFlushed is true also in this case,
and (b) by fulfilling the kexFlushedFuture and closing the session
outside of the critical region.

Because kexFlushed = true now also on failure, drain the queue and set
up all the queued futures such that they will be fulfilled with the
exception.

Additionally, do the same if the session closes while we're still
flushing queued packets.

Bug: #414
  • Loading branch information
tomaswolf committed Sep 8, 2023
1 parent 79573ca commit e92a46b
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 43 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
* [GH-403](https://github.com/apache/mina-sshd/issues/403) Work-around a bug in WS_FTP <= 12.9 SFTP clients.
* [GH-407](https://github.com/apache/mina-sshd/issues/407) (Regression in 2.10.0) SFTP performance fix: override `FilterOutputStream.write(byte[], int, int)`.
* [GH-410](https://github.com/apache/mina-sshd/issues/410) Fix a race condition to ensure `SSH_MSG_CHANNEL_EOF` is always sent before `SSH_MSG_CHANNEL_CLOSE`.
* [GH-414](https://github.com/apache/mina-sshd/issues/414) Fix error handling while flushing queued packets at end of KEX.

* [SSHD-1259](https://issues.apache.org/jira/browse/SSHD-1259) Consider all applicable host keys from the known_hosts files.
* [SSHD-1310](https://issues.apache.org/jira/browse/SSHD-1310) `SftpFileSystem`: do not close user session.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,15 @@
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Supplier;

import org.apache.sshd.common.SshConstants;
import org.apache.sshd.common.SshException;
import org.apache.sshd.common.future.DefaultKeyExchangeFuture;
import org.apache.sshd.common.io.AbstractIoWriteFuture;
import org.apache.sshd.common.io.IoWriteFuture;
import org.apache.sshd.common.kex.KexState;
import org.apache.sshd.common.util.ExceptionUtils;
Expand Down Expand Up @@ -101,22 +104,23 @@ public class KeyExchangeMessageHandler {
protected final Queue<PendingWriteFuture> pendingPackets = new ConcurrentLinkedQueue<>();

/**
* Indicates that all pending packets have been flushed.
* Indicates that all pending packets have been flushed. Set to {@code true} by the flushing thread, or at the end
* of KEX if there are no packets to be flushed. Set to {@code false} when a new KEX starts. Initially {@code true}.
*/
protected volatile boolean kexFlushed = true;
protected final AtomicBoolean kexFlushed = new AtomicBoolean(true);

/**
* Indicates that the handler has been shut down.
*/
protected volatile boolean shutDown;
protected final AtomicBoolean shutDown = new AtomicBoolean();

/**
* Never {@code null}. Used to block some threads when writing packets while pending packets are still being flushed
* at the end of a KEX to avoid overrunning the flushing thread. Always set, initially fulfilled. At the beginning
* of a KEX a new future is installed, which is fulfilled at the end of the KEX once there are no more pending
* packets to be flushed.
*/
protected volatile DefaultKeyExchangeFuture kexFlushedFuture;
protected final AtomicReference<DefaultKeyExchangeFuture> kexFlushedFuture = new AtomicReference<>();

/**
* Creates a new {@link KeyExchangeMessageHandler} for the given {@code session}, using the given {@code Logger}.
Expand All @@ -128,8 +132,9 @@ public KeyExchangeMessageHandler(AbstractSession session, Logger log) {
this.session = Objects.requireNonNull(session);
this.log = Objects.requireNonNull(log);
// Start with a fulfilled kexFlushed future.
kexFlushedFuture = new DefaultKeyExchangeFuture(session.toString(), session.getFutureLock());
kexFlushedFuture.setValue(Boolean.TRUE);
DefaultKeyExchangeFuture initialFuture = new DefaultKeyExchangeFuture(session.toString(), session.getFutureLock());
initialFuture.setValue(Boolean.TRUE);
kexFlushedFuture.set(initialFuture);
}

public void updateState(Runnable update) {
Expand Down Expand Up @@ -170,10 +175,8 @@ public <V> V updateState(Supplier<V> update) {
*/
public DefaultKeyExchangeFuture initNewKeyExchange() {
return updateState(() -> {
kexFlushed = false;
DefaultKeyExchangeFuture oldFuture = kexFlushedFuture;
kexFlushedFuture = new DefaultKeyExchangeFuture(session.toString(), session.getFutureLock());
return oldFuture;
kexFlushed.set(false);
return kexFlushedFuture.getAndSet(new DefaultKeyExchangeFuture(session.toString(), session.getFutureLock()));
});
}

Expand All @@ -188,22 +191,22 @@ public SimpleImmutableEntry<Integer, DefaultKeyExchangeFuture> terminateKeyExcha
return updateState(() -> {
int numPending = pendingPackets.size();
if (numPending == 0) {
kexFlushed = true;
kexFlushed.set(true);
}
return new SimpleImmutableEntry<>(Integer.valueOf(numPending), kexFlushedFuture);
return new SimpleImmutableEntry<>(Integer.valueOf(numPending), kexFlushedFuture.get());
});
}

/**
* Pretends all pending packets had been written. To be called when the {@link AbstractSession} closes.
*/
public void shutdown() {
shutDown.set(true);
SimpleImmutableEntry<Integer, DefaultKeyExchangeFuture> items = updateState(() -> {
kexFlushed = true;
shutDown = true;
kexFlushed.set(true);
return new SimpleImmutableEntry<Integer, DefaultKeyExchangeFuture>(
Integer.valueOf(pendingPackets.size()),
kexFlushedFuture);
kexFlushedFuture.get());
});
items.getValue().setValue(Boolean.valueOf(items.getKey().intValue() == 0));
flushRunner.shutdownNow();
Expand Down Expand Up @@ -295,16 +298,16 @@ protected IoWriteFuture writeOrEnqueue(int cmd, Buffer buffer, long timeout, Tim
// Use the readLock here to give KEX state updates and the flushing thread priority.
lock.readLock().lock();
try {
if (shutDown) {
if (shutDown.get()) {
throw new SshException("Write attempt on closing session: " + SshConstants.getCommandMessageName(cmd));
}
KexState state = session.kexState.get();
boolean kexDone = KexState.DONE.equals(state) || KexState.KEYS.equals(state);
if (kexDone && kexFlushed) {
if (kexDone && kexFlushed.get()) {
// Not in KEX, no pending packets: out it goes.
return session.doWritePacket(buffer);
} else if (!holdsFutureLock && isBlockAllowed(cmd)) {
// KEX done, but still flushing: block until flushing is done, if we may block.
// Still in KEX or still flushing: block until flushing is done, if we may block.
//
// The future lock is a _very_ global lock used for synchronization in many futures, and in
// particular in the key exchange related futures; and it is accessible by client code. If we
Expand All @@ -323,7 +326,7 @@ protected IoWriteFuture writeOrEnqueue(int cmd, Buffer buffer, long timeout, Tim
// thread and ensures that the flushing thread does indeed terminate.
//
// Note that we block only for channel data.
block = kexFlushedFuture;
block = kexFlushedFuture.get();
} else {
// Still in KEX or still flushing and we cannot block the thread. Enqueue the packet; it will
// get written by the flushing thread at the end of KEX. Note that theoretically threads may
Expand Down Expand Up @@ -420,20 +423,19 @@ protected PendingWriteFuture enqueuePendingPacket(int cmd, Buffer buffer) {
* have been written
*/
protected void flushQueue(DefaultKeyExchangeFuture flushDone) {
// kexFlushed must be set to true in all cases when this thread exits, **except** if a new KEX has started while
// flushing.
flushRunner.submit(() -> {
List<SimpleImmutableEntry<PendingWriteFuture, IoWriteFuture>> pendingFutures = new ArrayList<>();
boolean allFlushed = false;
DefaultKeyExchangeFuture newFuture = null;
// A Throwable when doWritePacket fails, or Boolean.FALSE if the session closes while flushing.
Object error = null;
try {
boolean warnedAboutChunkLimit = false;
int lastSize = -1;
int take = 2;
while (!allFlushed) {
if (!session.isOpen()) {
log.info("flushQueue({}): Session closed while flushing pending packets at end of KEX", session);
flushDone.setValue(Boolean.FALSE);
return;
}
// Using the writeLock this thread gets priority over the readLock used by writePacket(). Note that
// the outer loop essentially is just a loop around the critical region, so typically only one
// reader (i.e., writePacket() call) gets the lock before we get it again, and thus the flush really
Expand All @@ -445,16 +447,29 @@ protected void flushQueue(DefaultKeyExchangeFuture flushDone) {
if (log.isDebugEnabled()) {
log.debug("flushQueue({}): All packets at end of KEX flushed", session);
}
kexFlushed = true;
kexFlushed.set(true);
allFlushed = true;
break;
}
if (kexFlushedFuture != flushDone) {

if (!session.isOpen()) {
log.info("flushQueue({}): Session closed while flushing pending packets at end of KEX", session);
AbstractIoWriteFuture aborted = new AbstractIoWriteFuture(session, null) {
};
aborted.setValue(new SshException("Session closed while flushing pending packets at end of KEX"));
drainQueueTo(pendingFutures, aborted);
kexFlushed.set(true);
error = Boolean.FALSE;
break;
}

DefaultKeyExchangeFuture currentFuture = kexFlushedFuture.get();
if (currentFuture != flushDone) {
if (log.isDebugEnabled()) {
log.debug("flushQueue({}): Stopping flushing pending packets", session);
}
// Another KEX was started. Exit and hook up the flushDone future with the new future.
newFuture = kexFlushedFuture;
newFuture = currentFuture;
break;
}
int newSize = pendingPackets.size();
Expand Down Expand Up @@ -486,26 +501,30 @@ protected void flushQueue(DefaultKeyExchangeFuture flushDone) {
pending.getId());
}
written = session.doWritePacket(pending.getBuffer());
pendingFutures.add(new SimpleImmutableEntry<>(pending, written));
if (log.isTraceEnabled()) {
log.trace("flushQueue({}): Flushed a packet at end of KEX for {}", session,
pending.getId());
}
session.resetIdleTimeout();
} catch (Throwable e) {
log.error("flushQueue({}): Exception while flushing packet at end of KEX for {}", session,
pending.getId(), e);
pending.setException(e);
flushDone.setValue(e);
session.exceptionCaught(e);
AbstractIoWriteFuture aborted = new AbstractIoWriteFuture(pending.getId(), null) {
};
aborted.setValue(e);
pendingFutures.add(new SimpleImmutableEntry<>(pending, aborted));
drainQueueTo(pendingFutures, aborted);
kexFlushed.set(true);
// Remember the error, but close the session outside of the lock critical region.
error = e;
return;
}
pendingFutures.add(new SimpleImmutableEntry<>(pending, written));
if (log.isTraceEnabled()) {
log.trace("flushQueue({}): Flushed a packet at end of KEX for {}", session, pending.getId());
}
session.resetIdleTimeout();
}
if (pendingPackets.isEmpty()) {
if (log.isDebugEnabled()) {
log.debug("flushQueue({}): All packets at end of KEX flushed", session);
}
kexFlushed = true;
kexFlushed.set(true);
allFlushed = true;
break;
}
Expand All @@ -516,14 +535,16 @@ protected void flushQueue(DefaultKeyExchangeFuture flushDone) {
} finally {
if (allFlushed) {
flushDone.setValue(Boolean.TRUE);
} else if (error != null) {
// We'll close the session (or it is closing already). Pretend we had written everything.
flushDone.setValue(error);
if (error instanceof Throwable) {
session.exceptionCaught((Throwable) error);
}
} else if (newFuture != null) {
newFuture.addListener(f -> {
Throwable error = f.getException();
if (error != null) {
flushDone.setValue(error);
} else {
flushDone.setValue(Boolean.TRUE);
}
Throwable failed = f.getException();
flushDone.setValue(failed != null ? failed : Boolean.TRUE);
});
}
// Connect all futures of packets that we wrote. We do this at the end instead of one-by-one inside the
Expand All @@ -532,4 +553,14 @@ protected void flushQueue(DefaultKeyExchangeFuture flushDone) {
}
});
}

private void drainQueueTo(
List<SimpleImmutableEntry<PendingWriteFuture, IoWriteFuture>> pendingAborted,
IoWriteFuture aborted) {
PendingWriteFuture pending = pendingPackets.poll();
while (pending != null) {
pendingAborted.add(new SimpleImmutableEntry<>(pending, aborted));
pending = pendingPackets.poll();
}
}
}

0 comments on commit e92a46b

Please sign in to comment.