Skip to content

Commit f9f2b0e

Browse files
authored
Expose SshIdentificationReceived event (#1195)
* Fix #1191 * Expose `SshIdentificationReceived` event so that lib consumer can adjust based on server identification * revert unrelated code style change * revert OpenSSH 6.6 related tests * revert ConnectionBase * Add unit tests * Rename to `ServerIdentificationReceived` * rename
1 parent 54d0162 commit f9f2b0e

File tree

7 files changed

+128
-4
lines changed

7 files changed

+128
-4
lines changed

src/Renci.SshNet/BaseClient.cs

+13
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,11 @@ public TimeSpan KeepAliveInterval
153153
/// </example>
154154
public event EventHandler<HostKeyEventArgs> HostKeyReceived;
155155

156+
/// <summary>
157+
/// Occurs when server identification received.
158+
/// </summary>
159+
public event EventHandler<SshIdentificationEventArgs> ServerIdentificationReceived;
160+
156161
/// <summary>
157162
/// Initializes a new instance of the <see cref="BaseClient"/> class.
158163
/// </summary>
@@ -390,6 +395,11 @@ private void Session_HostKeyReceived(object sender, HostKeyEventArgs e)
390395
HostKeyReceived?.Invoke(this, e);
391396
}
392397

398+
private void Session_ServerIdentificationReceived(object sender, SshIdentificationEventArgs e)
399+
{
400+
ServerIdentificationReceived?.Invoke(this, e);
401+
}
402+
393403
/// <summary>
394404
/// Performs application-defined tasks associated with freeing, releasing, or resetting unmanaged resources.
395405
/// </summary>
@@ -532,6 +542,7 @@ private Timer CreateKeepAliveTimer(TimeSpan dueTime, TimeSpan period)
532542
private ISession CreateAndConnectSession()
533543
{
534544
var session = _serviceFactory.CreateSession(ConnectionInfo, _serviceFactory.CreateSocketFactory());
545+
session.ServerIdentificationReceived += Session_ServerIdentificationReceived;
535546
session.HostKeyReceived += Session_HostKeyReceived;
536547
session.ErrorOccured += Session_ErrorOccured;
537548

@@ -550,6 +561,7 @@ private ISession CreateAndConnectSession()
550561
private async Task<ISession> CreateAndConnectSessionAsync(CancellationToken cancellationToken)
551562
{
552563
var session = _serviceFactory.CreateSession(ConnectionInfo, _serviceFactory.CreateSocketFactory());
564+
session.ServerIdentificationReceived += Session_ServerIdentificationReceived;
553565
session.HostKeyReceived += Session_HostKeyReceived;
554566
session.ErrorOccured += Session_ErrorOccured;
555567

@@ -569,6 +581,7 @@ private void DisposeSession(ISession session)
569581
{
570582
session.ErrorOccured -= Session_ErrorOccured;
571583
session.HostKeyReceived -= Session_HostKeyReceived;
584+
session.ServerIdentificationReceived -= Session_ServerIdentificationReceived;
572585
session.Dispose();
573586
}
574587

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using System;
2+
3+
using Renci.SshNet.Connection;
4+
5+
namespace Renci.SshNet.Common
6+
{
7+
/// <summary>
8+
/// Provides data for the ServerIdentificationReceived events.
9+
/// </summary>
10+
public class SshIdentificationEventArgs : EventArgs
11+
{
12+
/// <summary>
13+
/// Initializes a new instance of the <see cref="SshIdentificationEventArgs"/> class.
14+
/// </summary>
15+
/// <param name="sshIdentification">The SSH identification.</param>
16+
public SshIdentificationEventArgs(SshIdentification sshIdentification)
17+
{
18+
SshIdentification = sshIdentification;
19+
}
20+
21+
/// <summary>
22+
/// Gets the SSH identification.
23+
/// </summary>
24+
public SshIdentification SshIdentification { get; private set; }
25+
}
26+
}

src/Renci.SshNet/Connection/SshIdentification.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ namespace Renci.SshNet.Connection
55
/// <summary>
66
/// Represents an SSH identification.
77
/// </summary>
8-
internal sealed class SshIdentification
8+
public sealed class SshIdentification
99
{
1010
/// <summary>
1111
/// Initializes a new instance of the <see cref="SshIdentification"/> class with the specified protocol version

src/Renci.SshNet/ISession.cs

+5
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,11 @@ internal interface ISession : IDisposable
260260
/// </summary>
261261
event EventHandler<ExceptionEventArgs> ErrorOccured;
262262

263+
/// <summary>
264+
/// Occurs when server identification received.
265+
/// </summary>
266+
event EventHandler<SshIdentificationEventArgs> ServerIdentificationReceived;
267+
263268
/// <summary>
264269
/// Occurs when host key received.
265270
/// </summary>

src/Renci.SshNet/Session.cs

+9
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,11 @@ public Message ClientInitMessage
366366
/// </summary>
367367
public event EventHandler<EventArgs> Disconnected;
368368

369+
/// <summary>
370+
/// Occurs when server identification received.
371+
/// </summary>
372+
public event EventHandler<SshIdentificationEventArgs> ServerIdentificationReceived;
373+
369374
/// <summary>
370375
/// Occurs when host key received.
371376
/// </summary>
@@ -624,6 +629,8 @@ public void Connect()
624629
DisconnectReason.ProtocolVersionNotSupported);
625630
}
626631

632+
ServerIdentificationReceived?.Invoke(this, new SshIdentificationEventArgs(serverIdentification));
633+
627634
// Register Transport response messages
628635
RegisterMessage("SSH_MSG_DISCONNECT");
629636
RegisterMessage("SSH_MSG_IGNORE");
@@ -736,6 +743,8 @@ public async Task ConnectAsync(CancellationToken cancellationToken)
736743
DisconnectReason.ProtocolVersionNotSupported);
737744
}
738745

746+
ServerIdentificationReceived?.Invoke(this, new SshIdentificationEventArgs(serverIdentification));
747+
739748
// Register Transport response messages
740749
RegisterMessage("SSH_MSG_DISCONNECT");
741750
RegisterMessage("SSH_MSG_IGNORE");

test/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs

+9-3
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ public abstract class SessionTest_ConnectedBase
4646
protected Session Session { get; private set; }
4747
protected Socket ClientSocket { get; private set; }
4848
protected Socket ServerSocket { get; private set; }
49-
internal SshIdentification ServerIdentification { get; private set; }
49+
internal SshIdentification ServerIdentification { get; set; }
50+
protected bool CallSessionConnectWhenArrange { get; set; }
5051

5152
[TestInitialize]
5253
public void Setup()
@@ -159,6 +160,8 @@ protected virtual void SetupData()
159160
ServerListener.Start();
160161

161162
ClientSocket = new DirectConnector(_socketFactory).Connect(ConnectionInfo);
163+
164+
CallSessionConnectWhenArrange = true;
162165
}
163166

164167
private void CreateMocks()
@@ -180,7 +183,7 @@ private void SetupMocks()
180183
_ = ServiceFactoryMock.Setup(p => p.CreateProtocolVersionExchange())
181184
.Returns(_protocolVersionExchangeMock.Object);
182185
_ = _protocolVersionExchangeMock.Setup(p => p.Start(Session.ClientVersion, ClientSocket, ConnectionInfo.Timeout))
183-
.Returns(ServerIdentification);
186+
.Returns(() => ServerIdentification);
184187
_ = ServiceFactoryMock.Setup(p => p.CreateKeyExchange(ConnectionInfo.KeyExchangeAlgorithms, new[] { _keyExchangeAlgorithm })).Returns(_keyExchangeMock.Object);
185188
_ = _keyExchangeMock.Setup(p => p.Name)
186189
.Returns(_keyExchangeAlgorithm);
@@ -212,7 +215,10 @@ protected void Arrange()
212215
SetupData();
213216
SetupMocks();
214217

215-
Session.Connect();
218+
if (CallSessionConnectWhenArrange)
219+
{
220+
Session.Connect();
221+
}
216222
}
217223

218224
protected virtual void ClientAuthentication_Callback()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
3+
using Renci.SshNet.Connection;
4+
5+
namespace Renci.SshNet.Tests.Classes
6+
{
7+
[TestClass]
8+
public class SessionTest_Connected_ServerIdentificationReceived : SessionTest_ConnectedBase
9+
{
10+
protected override void SetupData()
11+
{
12+
base.SetupData();
13+
14+
CallSessionConnectWhenArrange = false;
15+
16+
Session.ServerIdentificationReceived += (s, e) =>
17+
{
18+
if ((e.SshIdentification.SoftwareVersion.StartsWith("OpenSSH_6.5", System.StringComparison.Ordinal) || e.SshIdentification.SoftwareVersion.StartsWith("OpenSSH_6.6", System.StringComparison.Ordinal))
19+
&& !e.SshIdentification.SoftwareVersion.StartsWith("OpenSSH_6.6.1", System.StringComparison.Ordinal))
20+
{
21+
_ = ConnectionInfo.KeyExchangeAlgorithms.Remove("curve25519-sha256");
22+
_ = ConnectionInfo.KeyExchangeAlgorithms.Remove("[email protected]");
23+
}
24+
};
25+
}
26+
27+
protected override void Act()
28+
{
29+
}
30+
31+
[TestMethod]
32+
[DataRow("OpenSSH_6.5")]
33+
[DataRow("OpenSSH_6.5p1")]
34+
[DataRow("OpenSSH_6.5 PKIX")]
35+
[DataRow("OpenSSH_6.6")]
36+
[DataRow("OpenSSH_6.6p1")]
37+
[DataRow("OpenSSH_6.6 PKIX")]
38+
public void ShouldExcludeCurve25519KexWhenServerIs(string softwareVersion)
39+
{
40+
ServerIdentification = new SshIdentification("2.0", softwareVersion);
41+
42+
Session.Connect();
43+
44+
Assert.IsFalse(ConnectionInfo.KeyExchangeAlgorithms.ContainsKey("curve25519-sha256"));
45+
Assert.IsFalse(ConnectionInfo.KeyExchangeAlgorithms.ContainsKey("[email protected]"));
46+
}
47+
48+
[TestMethod]
49+
[DataRow("OpenSSH_6.6.1")]
50+
[DataRow("OpenSSH_6.6.1p1")]
51+
[DataRow("OpenSSH_6.6.1 PKIX")]
52+
[DataRow("OpenSSH_6.7")]
53+
[DataRow("OpenSSH_6.7p1")]
54+
[DataRow("OpenSSH_6.7 PKIX")]
55+
public void ShouldIncludeCurve25519KexWhenServerIs(string softwareVersion)
56+
{
57+
ServerIdentification = new SshIdentification("2.0", softwareVersion);
58+
59+
Session.Connect();
60+
61+
Assert.IsTrue(ConnectionInfo.KeyExchangeAlgorithms.ContainsKey("curve25519-sha256"));
62+
Assert.IsTrue(ConnectionInfo.KeyExchangeAlgorithms.ContainsKey("[email protected]"));
63+
}
64+
}
65+
}

0 commit comments

Comments
 (0)