Skip to content

Commit 1bb7a6c

Browse files
committed
Remove unnecessary DNS lookup in Connect
The library currently performs a DNS lookup of the desired host, takes the first returned IP address and connects to that. Instead, we can just pass the hostname down to System.Net.Sockets which will do the right thing, potentially trying multiple addresses if needed.
1 parent fdbc4d3 commit 1bb7a6c

File tree

43 files changed

+147
-154
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+147
-154
lines changed

src/Renci.SshNet/Abstractions/SocketAbstraction.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -47,17 +47,17 @@ public static Socket Connect(IPEndPoint remoteEndpoint, TimeSpan connectTimeout)
4747
return socket;
4848
}
4949

50-
public static void Connect(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout)
50+
public static void Connect(Socket socket, EndPoint remoteEndpoint, TimeSpan connectTimeout)
5151
{
5252
ConnectCore(socket, remoteEndpoint, connectTimeout, ownsSocket: false);
5353
}
5454

55-
public static async Task ConnectAsync(Socket socket, IPEndPoint remoteEndpoint, CancellationToken cancellationToken)
55+
public static async Task ConnectAsync(Socket socket, EndPoint remoteEndpoint, CancellationToken cancellationToken)
5656
{
5757
await socket.ConnectAsync(remoteEndpoint, cancellationToken).ConfigureAwait(false);
5858
}
5959

60-
private static void ConnectCore(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout, bool ownsSocket)
60+
private static void ConnectCore(Socket socket, EndPoint remoteEndpoint, TimeSpan connectTimeout, bool ownsSocket)
6161
{
6262
var connectCompleted = new ManualResetEvent(initialState: false);
6363
var args = new SocketAsyncEventArgs

src/Renci.SshNet/Abstractions/SocketExtensions.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ public void GetResult()
8585
}
8686
}
8787

88-
public static async Task ConnectAsync(this Socket socket, IPEndPoint remoteEndpoint, CancellationToken cancellationToken)
88+
public static async Task ConnectAsync(this Socket socket, EndPoint remoteEndpoint, CancellationToken cancellationToken)
8989
{
9090
cancellationToken.ThrowIfCancellationRequested();
9191

src/Renci.SshNet/Connection/ConnectorBase.cs

+12-25
Original file line numberDiff line numberDiff line change
@@ -29,25 +29,21 @@ protected ConnectorBase(ISocketFactory socketFactory)
2929
public abstract Task<Socket> ConnectAsync(IConnectionInfo connectionInfo, CancellationToken cancellationToken);
3030

3131
/// <summary>
32-
/// Establishes a socket connection to the specified host and port.
32+
/// Establishes a socket connection to the specified endpoint.
3333
/// </summary>
34-
/// <param name="host">The host name of the server to connect to.</param>
35-
/// <param name="port">The port to connect to.</param>
34+
/// <param name="endPoint">The <see cref="EndPoint"/> representing the server to connect to.</param>
3635
/// <param name="timeout">The maximum time to wait for the connection to be established.</param>
3736
/// <exception cref="SshOperationTimeoutException">The connection failed to establish within the configured <see cref="ConnectionInfo.Timeout"/>.</exception>
3837
/// <exception cref="SocketException">An error occurred trying to establish the connection.</exception>
39-
protected Socket SocketConnect(string host, int port, TimeSpan timeout)
38+
protected Socket SocketConnect(EndPoint endPoint, TimeSpan timeout)
4039
{
41-
var ipAddress = Dns.GetHostAddresses(host)[0];
42-
var ep = new IPEndPoint(ipAddress, port);
40+
DiagnosticAbstraction.Log(string.Format("Initiating connection to '{0}'.", endPoint));
4341

44-
DiagnosticAbstraction.Log(string.Format("Initiating connection to '{0}:{1}'.", host, port));
45-
46-
var socket = SocketFactory.Create(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
42+
var socket = SocketFactory.Create(SocketType.Stream, ProtocolType.Tcp);
4743

4844
try
4945
{
50-
SocketAbstraction.Connect(socket, ep, timeout);
46+
SocketAbstraction.Connect(socket, endPoint, timeout);
5147

5248
const int socketBufferSize = 10 * Session.MaximumSshPacketSize;
5349
socket.SendBufferSize = socketBufferSize;
@@ -62,31 +58,22 @@ protected Socket SocketConnect(string host, int port, TimeSpan timeout)
6258
}
6359

6460
/// <summary>
65-
/// Establishes a socket connection to the specified host and port.
61+
/// Establishes a socket connection to the specified endpoint.
6662
/// </summary>
67-
/// <param name="host">The host name of the server to connect to.</param>
68-
/// <param name="port">The port to connect to.</param>
63+
/// <param name="endPoint">The <see cref="EndPoint"/> representing the server to connect to.</param>
6964
/// <param name="cancellationToken">The cancellation token to observe.</param>
7065
/// <exception cref="SshOperationTimeoutException">The connection failed to establish within the configured <see cref="ConnectionInfo.Timeout"/>.</exception>
7166
/// <exception cref="SocketException">An error occurred trying to establish the connection.</exception>
72-
protected async Task<Socket> SocketConnectAsync(string host, int port, CancellationToken cancellationToken)
67+
protected async Task<Socket> SocketConnectAsync(EndPoint endPoint, CancellationToken cancellationToken)
7368
{
7469
cancellationToken.ThrowIfCancellationRequested();
7570

76-
#if NET6_0_OR_GREATER
77-
var ipAddress = (await Dns.GetHostAddressesAsync(host, cancellationToken).ConfigureAwait(false))[0];
78-
#else
79-
var ipAddress = (await Dns.GetHostAddressesAsync(host).ConfigureAwait(false))[0];
80-
#endif
81-
82-
var ep = new IPEndPoint(ipAddress, port);
83-
84-
DiagnosticAbstraction.Log(string.Format("Initiating connection to '{0}:{1}'.", host, port));
71+
DiagnosticAbstraction.Log(string.Format("Initiating connection to '{0}'.", endPoint));
8572

86-
var socket = SocketFactory.Create(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
73+
var socket = SocketFactory.Create(SocketType.Stream, ProtocolType.Tcp);
8774
try
8875
{
89-
await SocketAbstraction.ConnectAsync(socket, ep, cancellationToken).ConfigureAwait(false);
76+
await SocketAbstraction.ConnectAsync(socket, endPoint, cancellationToken).ConfigureAwait(false);
9077

9178
const int socketBufferSize = 2 * Session.MaximumSshPacketSize;
9279
socket.SendBufferSize = socketBufferSize;

src/Renci.SshNet/Connection/DirectConnector.cs

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System.Net.Sockets;
1+
using System.Net;
2+
using System.Net.Sockets;
23
using System.Threading;
34

45
namespace Renci.SshNet.Connection
@@ -12,12 +13,12 @@ public DirectConnector(ISocketFactory socketFactory)
1213

1314
public override Socket Connect(IConnectionInfo connectionInfo)
1415
{
15-
return SocketConnect(connectionInfo.Host, connectionInfo.Port, connectionInfo.Timeout);
16+
return SocketConnect(new DnsEndPoint(connectionInfo.Host, connectionInfo.Port), connectionInfo.Timeout);
1617
}
1718

1819
public override System.Threading.Tasks.Task<Socket> ConnectAsync(IConnectionInfo connectionInfo, CancellationToken cancellationToken)
1920
{
20-
return SocketConnectAsync(connectionInfo.Host, connectionInfo.Port, cancellationToken);
21+
return SocketConnectAsync(new DnsEndPoint(connectionInfo.Host, connectionInfo.Port), cancellationToken);
2122
}
2223
}
2324
}

src/Renci.SshNet/Connection/ISocketFactory.cs

+3-5
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,14 @@ namespace Renci.SshNet.Connection
88
internal interface ISocketFactory
99
{
1010
/// <summary>
11-
/// Creates a <see cref="Socket"/> with the specified <see cref="AddressFamily"/>,
12-
/// <see cref="SocketType"/> and <see cref="ProtocolType"/> that does not use the
13-
/// <c>Nagle</c> algorithm.
11+
/// Creates a <see cref="Socket"/> with the specified <see cref="SocketType"/>
12+
/// and <see cref="ProtocolType"/> that does not use the <c>Nagle</c> algorithm.
1413
/// </summary>
15-
/// <param name="addressFamily">The <see cref="AddressFamily"/>.</param>
1614
/// <param name="socketType">The <see cref="SocketType"/>.</param>
1715
/// <param name="protocolType">The <see cref="ProtocolType"/>.</param>
1816
/// <returns>
1917
/// The <see cref="Socket"/>.
2018
/// </returns>
21-
Socket Create(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType);
19+
Socket Create(SocketType socketType, ProtocolType protocolType);
2220
}
2321
}

src/Renci.SshNet/Connection/ProxyConnector.cs

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Net;
23
using System.Net.Sockets;
34
using System.Threading;
45
using System.Threading.Tasks;
@@ -52,7 +53,7 @@ Task HandleProxyConnectAsync(IConnectionInfo connectionInfo, Socket socket, Canc
5253
/// </returns>
5354
public override Socket Connect(IConnectionInfo connectionInfo)
5455
{
55-
var socket = SocketConnect(connectionInfo.ProxyHost, connectionInfo.ProxyPort, connectionInfo.Timeout);
56+
var socket = SocketConnect(new DnsEndPoint(connectionInfo.ProxyHost, connectionInfo.ProxyPort), connectionInfo.Timeout);
5657

5758
try
5859
{
@@ -78,7 +79,7 @@ public override Socket Connect(IConnectionInfo connectionInfo)
7879
/// </returns>
7980
public override async Task<Socket> ConnectAsync(IConnectionInfo connectionInfo, CancellationToken cancellationToken)
8081
{
81-
var socket = await SocketConnectAsync(connectionInfo.ProxyHost, connectionInfo.ProxyPort, cancellationToken).ConfigureAwait(false);
82+
var socket = await SocketConnectAsync(new DnsEndPoint(connectionInfo.ProxyHost, connectionInfo.ProxyPort), cancellationToken).ConfigureAwait(false);
8283

8384
try
8485
{

src/Renci.SshNet/Connection/SocketFactory.cs

+3-13
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,10 @@ namespace Renci.SshNet.Connection
77
/// </summary>
88
internal sealed class SocketFactory : ISocketFactory
99
{
10-
/// <summary>
11-
/// Creates a <see cref="Socket"/> with the specified <see cref="AddressFamily"/>,
12-
/// <see cref="SocketType"/> and <see cref="ProtocolType"/> that does not use the
13-
/// <c>Nagle</c> algorithm.
14-
/// </summary>
15-
/// <param name="addressFamily">The <see cref="AddressFamily"/>.</param>
16-
/// <param name="socketType">The <see cref="SocketType"/>.</param>
17-
/// <param name="protocolType">The <see cref="ProtocolType"/>.</param>
18-
/// <returns>
19-
/// The <see cref="Socket"/>.
20-
/// </returns>
21-
public Socket Create(AddressFamily addressFamily, SocketType socketType, ProtocolType protocolType)
10+
/// <inheritdoc/>
11+
public Socket Create(SocketType socketType, ProtocolType protocolType)
2212
{
23-
return new Socket(addressFamily, socketType, protocolType) { NoDelay = true };
13+
return new Socket(socketType, protocolType) { NoDelay = true };
2414
}
2515
}
2616
}

test/Renci.SshNet.Tests/Classes/Connection/DirectConnectorTest_Connect_ConnectionRefusedByServer.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ protected override void SetupData()
2626
_stopWatch = new Stopwatch();
2727
_actualException = null;
2828

29-
_clientSocket = SocketFactory.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
29+
_clientSocket = SocketFactory.Create(SocketType.Stream, ProtocolType.Tcp);
3030
}
3131

3232
protected override void SetupMocks()
3333
{
34-
_ = SocketFactoryMock.Setup(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
34+
_ = SocketFactoryMock.Setup(p => p.Create(SocketType.Stream, ProtocolType.Tcp))
3535
.Returns(_clientSocket);
3636
}
3737

@@ -95,7 +95,7 @@ public void ClientSocketShouldHaveBeenDisposed()
9595
[TestMethod]
9696
public void CreateOnSocketFactoryShouldHaveBeenInvokedOnce()
9797
{
98-
SocketFactoryMock.Verify(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp),
98+
SocketFactoryMock.Verify(p => p.Create(SocketType.Stream, ProtocolType.Tcp),
9999
Times.Once());
100100
}
101101
}

test/Renci.SshNet.Tests/Classes/Connection/DirectConnectorTest_Connect_ConnectionSucceeded.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ protected override void SetupData()
3232
_stopWatch = new Stopwatch();
3333
_disconnected = false;
3434

35-
_clientSocket = SocketFactory.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
35+
_clientSocket = SocketFactory.Create(SocketType.Stream, ProtocolType.Tcp);
3636

3737
_server = new AsyncSocketListener(new IPEndPoint(IPAddress.Loopback, _connectionInfo.Port));
3838
_server.Disconnected += (socket) => _disconnected = true;
@@ -42,7 +42,7 @@ protected override void SetupData()
4242

4343
protected override void SetupMocks()
4444
{
45-
_ = SocketFactoryMock.Setup(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
45+
_ = SocketFactoryMock.Setup(p => p.Create(SocketType.Stream, ProtocolType.Tcp))
4646
.Returns(_clientSocket);
4747
}
4848

@@ -106,7 +106,7 @@ public void NoBytesShouldHaveBeenReadFromSocket()
106106
[TestMethod]
107107
public void CreateOnSocketFactoryShouldHaveBeenInvokedOnce()
108108
{
109-
SocketFactoryMock.Verify(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp),
109+
SocketFactoryMock.Verify(p => p.Create(SocketType.Stream, ProtocolType.Tcp),
110110
Times.Once());
111111
}
112112
}

test/Renci.SshNet.Tests/Classes/Connection/DirectConnectorTest_Connect_HostNameInvalid.cs

+8
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ namespace Renci.SshNet.Tests.Classes.Connection
88
public class DirectConnectorTest_Connect_HostNameInvalid : DirectConnectorTestBase
99
{
1010
private ConnectionInfo _connectionInfo;
11+
private Socket _clientSocket;
1112
private SocketException _actualException;
1213

1314
protected override void SetupData()
@@ -16,6 +17,7 @@ protected override void SetupData()
1617

1718
_connectionInfo = CreateConnectionInfo("invalid.");
1819
_actualException = null;
20+
_clientSocket = SocketFactory.Create(SocketType.Stream, ProtocolType.Tcp);
1921
}
2022

2123
protected override void Act()
@@ -31,6 +33,12 @@ protected override void Act()
3133
}
3234
}
3335

36+
protected override void SetupMocks()
37+
{
38+
_ = SocketFactoryMock.Setup(p => p.Create(SocketType.Stream, ProtocolType.Tcp))
39+
.Returns(_clientSocket);
40+
}
41+
3442
[TestMethod]
3543
public void ConnectShouldHaveThrownSocketException()
3644
{

test/Renci.SshNet.Tests/Classes/Connection/DirectConnectorTest_Connect_TimeoutConnectingToServer.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@ protected override void SetupData()
3333
_stopWatch = new Stopwatch();
3434
_actualException = null;
3535

36-
_clientSocket = SocketFactory.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
36+
_clientSocket = SocketFactory.Create(SocketType.Stream, ProtocolType.Tcp);
3737
}
3838

3939
protected override void SetupMocks()
4040
{
41-
_ = SocketFactoryMock.Setup(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
41+
_ = SocketFactoryMock.Setup(p => p.Create(SocketType.Stream, ProtocolType.Tcp))
4242
.Returns(_clientSocket);
4343
}
4444

@@ -116,7 +116,7 @@ public void ClientSocketShouldHaveBeenDisposed()
116116
[TestMethod]
117117
public void CreateOnSocketFactoryShouldHaveBeenInvokedOnce()
118118
{
119-
SocketFactoryMock.Verify(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp),
119+
SocketFactoryMock.Verify(p => p.Create(SocketType.Stream, ProtocolType.Tcp),
120120
Times.Once());
121121
}
122122
}

test/Renci.SshNet.Tests/Classes/Connection/HttpConnectorTest_Connect_ConnectionToProxyRefused.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ protected override void SetupData()
3636
_stopWatch = new Stopwatch();
3737
_actualException = null;
3838

39-
_clientSocket = SocketFactory.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
39+
_clientSocket = SocketFactory.Create(SocketType.Stream, ProtocolType.Tcp);
4040
}
4141

4242
protected override void SetupMocks()
4343
{
44-
_ = SocketFactoryMock.Setup(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
44+
_ = SocketFactoryMock.Setup(p => p.Create(SocketType.Stream, ProtocolType.Tcp))
4545
.Returns(_clientSocket);
4646
}
4747

@@ -105,7 +105,7 @@ public void ClientSocketShouldHaveBeenDisposed()
105105
[TestMethod]
106106
public void CreateOnSocketFactoryShouldHaveBeenInvokedOnce()
107107
{
108-
SocketFactoryMock.Verify(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp),
108+
SocketFactoryMock.Verify(p => p.Create(SocketType.Stream, ProtocolType.Tcp),
109109
Times.Once());
110110
}
111111
}

test/Renci.SshNet.Tests/Classes/Connection/HttpConnectorTest_Connect_ProxyClosesConnectionBeforeStatusLineIsSent.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ protected override void SetupData()
3939
};
4040
_actualException = null;
4141

42-
_clientSocket = SocketFactory.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
42+
_clientSocket = SocketFactory.Create(SocketType.Stream, ProtocolType.Tcp);
4343

4444
_proxyServer = new AsyncSocketListener(new IPEndPoint(IPAddress.Loopback, _connectionInfo.ProxyPort));
4545
_proxyServer.Disconnected += socket => _disconnected = true;
@@ -52,7 +52,7 @@ protected override void SetupData()
5252

5353
protected override void SetupMocks()
5454
{
55-
_ = SocketFactoryMock.Setup(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
55+
_ = SocketFactoryMock.Setup(p => p.Create(SocketType.Stream, ProtocolType.Tcp))
5656
.Returns(_clientSocket);
5757
}
5858

@@ -110,7 +110,7 @@ public void ClientSocketShouldHaveBeenDisposed()
110110
[TestMethod]
111111
public void CreateOnSocketFactoryShouldHaveBeenInvokedOnce()
112112
{
113-
SocketFactoryMock.Verify(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp),
113+
SocketFactoryMock.Verify(p => p.Create(SocketType.Stream, ProtocolType.Tcp),
114114
Times.Once());
115115
}
116116
}

test/Renci.SshNet.Tests/Classes/Connection/HttpConnectorTest_Connect_ProxyHostInvalid.cs

+8
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ public class HttpConnectorTest_Connect_ProxyHostInvalid : HttpConnectorTestBase
99
{
1010
private ConnectionInfo _connectionInfo;
1111
private SocketException _actualException;
12+
private Socket _clientSocket;
1213

1314
protected override void SetupData()
1415
{
@@ -24,6 +25,13 @@ protected override void SetupData()
2425
"proxyPwd",
2526
new KeyboardInteractiveAuthenticationMethod("user"));
2627
_actualException = null;
28+
_clientSocket = SocketFactory.Create(SocketType.Stream, ProtocolType.Tcp);
29+
}
30+
31+
protected override void SetupMocks()
32+
{
33+
_ = SocketFactoryMock.Setup(p => p.Create(SocketType.Stream, ProtocolType.Tcp))
34+
.Returns(_clientSocket);
2735
}
2836

2937
protected override void Act()

test/Renci.SshNet.Tests/Classes/Connection/HttpConnectorTest_Connect_ProxyPasswordIsEmpty.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ protected override void SetupData()
4848
"\r\n");
4949
_bytesReceivedByProxy = new List<byte>();
5050
_disconnected = false;
51-
_clientSocket = SocketFactory.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
51+
_clientSocket = SocketFactory.Create(SocketType.Stream, ProtocolType.Tcp);
5252

5353
_proxyServer = new AsyncSocketListener(new IPEndPoint(IPAddress.Loopback, _connectionInfo.ProxyPort));
5454
_proxyServer.Disconnected += (socket) => _disconnected = true;
@@ -72,7 +72,7 @@ protected override void SetupData()
7272

7373
protected override void SetupMocks()
7474
{
75-
_ = SocketFactoryMock.Setup(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
75+
_ = SocketFactoryMock.Setup(p => p.Create(SocketType.Stream, ProtocolType.Tcp))
7676
.Returns(_clientSocket);
7777
}
7878

@@ -131,7 +131,7 @@ public void ClientSocketShouldBeConnected()
131131
[TestMethod]
132132
public void CreateOnSocketFactoryShouldHaveBeenInvokedOnce()
133133
{
134-
SocketFactoryMock.Verify(p => p.Create(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp),
134+
SocketFactoryMock.Verify(p => p.Create(SocketType.Stream, ProtocolType.Tcp),
135135
Times.Once());
136136
}
137137
}

0 commit comments

Comments
 (0)