diff --git a/src/Renci.SshNet/BaseClient.cs b/src/Renci.SshNet/BaseClient.cs index 6879e1349..f6044f796 100644 --- a/src/Renci.SshNet/BaseClient.cs +++ b/src/Renci.SshNet/BaseClient.cs @@ -240,6 +240,11 @@ public void Connect() var session = Session; if (session is null || !session.IsConnected) { + if (session is not null) + { + DisposeSession(session); + } + Session = CreateAndConnectSession(); } @@ -304,6 +309,11 @@ public async Task ConnectAsync(CancellationToken cancellationToken) var session = Session; if (session is null || !session.IsConnected) { + if (session is not null) + { + DisposeSession(session); + } + Session = await CreateAndConnectSessionAsync(cancellationToken).ConfigureAwait(false); } diff --git a/src/Renci.SshNet/SftpClient.cs b/src/Renci.SshNet/SftpClient.cs index b82e605cc..d4d04183d 100644 --- a/src/Renci.SshNet/SftpClient.cs +++ b/src/Renci.SshNet/SftpClient.cs @@ -2496,15 +2496,8 @@ protected override void OnConnected() { base.OnConnected(); - var sftpSession = _sftpSession; - if (sftpSession is null) - { - _sftpSession = CreateAndConnectToSftpSession(); - } - else if (!sftpSession.IsOpen) - { - sftpSession.Connect(); - } + _sftpSession?.Dispose(); + _sftpSession = CreateAndConnectToSftpSession(); } /// diff --git a/test/Renci.SshNet.IntegrationTests/ConnectivityTests.cs b/test/Renci.SshNet.IntegrationTests/ConnectivityTests.cs index 16b80e161..3fae0a8dc 100644 --- a/test/Renci.SshNet.IntegrationTests/ConnectivityTests.cs +++ b/test/Renci.SshNet.IntegrationTests/ConnectivityTests.cs @@ -326,6 +326,54 @@ public async Task SftpClient_HandleSftpSessionCloseAsync() } } + [TestMethod] + public void SftpClient_HandleSftpSessionAbortByServer() + { + using (var client = new SftpClient(_connectionInfoFactory.Create())) + { + client.Connect(); + Assert.IsTrue(client.IsConnected); + + _sshConnectionDisruptor.BreakConnections(); + WaitForConnectionInterruption(client); + Assert.IsFalse(client.IsConnected); + + client.Connect(); + Assert.IsTrue(client.IsConnected); + + foreach (var file in client.ListDirectory(".")) + { + } + + client.Disconnect(); + Assert.IsFalse(client.IsConnected); + } + } + + [TestMethod] + public async Task SftpClient_HandleSftpSessionAbortByServerAsync() + { + using (var client = new SftpClient(_connectionInfoFactory.Create())) + { + await client.ConnectAsync(CancellationToken.None); + Assert.IsTrue(client.IsConnected); + + _sshConnectionDisruptor.BreakConnections(); + WaitForConnectionInterruption(client); + Assert.IsFalse(client.IsConnected); + + await client.ConnectAsync(CancellationToken.None); + Assert.IsTrue(client.IsConnected); + + await foreach (var file in client.ListDirectoryAsync(".", CancellationToken.None)) + { + } + + client.Disconnect(); + Assert.IsFalse(client.IsConnected); + } + } + [TestMethod] public void Common_DetectSessionKilledOnServer() {