Skip to content
This repository was archived by the owner on Jan 23, 2023. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion src/Common/tests/System/Net/VirtualNetwork/VirtualNetwork.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,29 @@ namespace System.Net.Test.Common
{
public class VirtualNetwork
{
public class VirtualNetworkConnectionBroken : Exception
{
public VirtualNetworkConnectionBroken() : base("Connection broken") { }
}

private readonly int WaitForReadDataTimeoutMilliseconds = 30 * 1000;

private readonly ConcurrentQueue<byte[]> _clientWriteQueue = new ConcurrentQueue<byte[]>();
private readonly ConcurrentQueue<byte[]> _serverWriteQueue = new ConcurrentQueue<byte[]>();

private readonly SemaphoreSlim _clientDataAvailable = new SemaphoreSlim(0);
private readonly SemaphoreSlim _serverDataAvailable = new SemaphoreSlim(0);

public bool DisableConnectionBreaking { get; set; } = false;
private bool _connectionBroken = false;

public void ReadFrame(bool server, out byte[] buffer)
{
if (_connectionBroken)
{
throw new VirtualNetworkConnectionBroken();
}

SemaphoreSlim semaphore;
ConcurrentQueue<byte[]> packetQueue;

Expand All @@ -39,6 +52,11 @@ public void ReadFrame(bool server, out byte[] buffer)
throw new TimeoutException("VirtualNetwork: Timeout reading the next frame.");
}

if (_connectionBroken)
{
throw new VirtualNetworkConnectionBroken();
}

bool dequeueSucceeded = false;
int remainingTries = 3;
int backOffDelayMilliseconds = 2;
Expand All @@ -62,6 +80,11 @@ public void ReadFrame(bool server, out byte[] buffer)

public void WriteFrame(bool server, byte[] buffer)
{
if (_connectionBroken)
{
throw new VirtualNetworkConnectionBroken();
}

SemaphoreSlim semaphore;
ConcurrentQueue<byte[]> packetQueue;

Expand All @@ -82,5 +105,15 @@ public void WriteFrame(bool server, byte[] buffer)
packetQueue.Enqueue(innerBuffer);
semaphore.Release();
}

public void BreakConnection()
{
if (!DisableConnectionBreaking)
{
_connectionBroken = true;
_serverDataAvailable.Release(1_000_000);
_clientDataAvailable.Release(1_000_000);
}
}
}
}
10 changes: 10 additions & 0 deletions src/Common/tests/System/Net/VirtualNetwork/VirtualNetworkStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -156,5 +156,15 @@ public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, As

public override void EndWrite(IAsyncResult asyncResult) =>
TaskToApm.End(asyncResult);

protected override void Dispose(bool disposing)
{
if (disposing)
{
_network.BreakConnection();
}

base.Dispose(disposing);
}
}
}
3 changes: 3 additions & 0 deletions src/System.Net.Security/ref/System.Net.Security.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ public enum EncryptionPolicy
RequireEncryption = 0,
}
public delegate System.Security.Cryptography.X509Certificates.X509Certificate LocalCertificateSelectionCallback(object sender, string targetHost, System.Security.Cryptography.X509Certificates.X509CertificateCollection localCertificates, System.Security.Cryptography.X509Certificates.X509Certificate remoteCertificate, string[] acceptableIssuers);
public delegate System.Security.Cryptography.X509Certificates.X509Certificate ServerCertificateSelectionCallback(object sender, string hostName);

public partial class NegotiateStream : AuthenticatedStream
{
public NegotiateStream(System.IO.Stream innerStream) : base(innerStream, false) { }
Expand Down Expand Up @@ -108,6 +110,7 @@ public class SslServerAuthenticationOptions
public X509RevocationMode CertificateRevocationCheckMode { get { throw null; } set { } }
public List<SslApplicationProtocol> ApplicationProtocols { get { throw null; } set { } }
public RemoteCertificateValidationCallback RemoteCertificateValidationCallback { get { throw null; } set { } }
public ServerCertificateSelectionCallback ServerCertificateSelectionCallback { get { throw null; } set { } }
public EncryptionPolicy EncryptionPolicy { get { throw null; } set { } }
}
public partial class SslClientAuthenticationOptions
Expand Down
1 change: 1 addition & 0 deletions src/System.Net.Security/src/System.Net.Security.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
<Compile Include="System\Net\FixedSizeReader.cs" />
<Compile Include="System\Net\HelperAsyncResults.cs" />
<Compile Include="System\Net\Logging\NetEventSource.cs" />
<Compile Include="System\Net\Security\SniHelper.cs" />
<Compile Include="System\Net\Security\SslApplicationProtocol.cs" />
<Compile Include="System\Net\Security\SslAuthenticationOptions.cs" />
<Compile Include="System\Net\Security\SslClientAuthenticationOptions.cs" />
Expand Down
18 changes: 14 additions & 4 deletions src/System.Net.Security/src/System/Net/Security/SecureChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -624,15 +624,26 @@ private bool AcquireClientCredentials(ref byte[] thumbPrint)
//
// Acquire Server Side Certificate information and set it on the class.
//
private bool AcquireServerCredentials(ref byte[] thumbPrint)
private bool AcquireServerCredentials(ref byte[] thumbPrint, byte[] clientHello)
{
if (NetEventSource.IsEnabled)
NetEventSource.Enter(this);

X509Certificate localCertificate = null;
bool cachedCred = false;

if (_sslAuthenticationOptions.CertSelectionDelegate != null)
if (_sslAuthenticationOptions.ServerCertSelectionDelegate != null)
{
string serverIdentity = SniHelper.GetServerName(clientHello);
localCertificate = _sslAuthenticationOptions.ServerCertSelectionDelegate(serverIdentity);

if (localCertificate == null)
{
throw new AuthenticationException(SR.net_ssl_io_no_server_cert);
}
}
// This probably never gets called as this is a client options delegate
else if (_sslAuthenticationOptions.CertSelectionDelegate != null)
{
X509CertificateCollection tempCollection = new X509CertificateCollection();
tempCollection.Add(_sslAuthenticationOptions.ServerCertificate);
Expand Down Expand Up @@ -744,7 +755,6 @@ private SecurityStatusPal GenerateToken(byte[] input, int offset, int count, ref
#if TRACE_VERBOSE
if (NetEventSource.IsEnabled) NetEventSource.Enter(this, $"_refreshCredentialNeeded = {_refreshCredentialNeeded}");
#endif

if (offset < 0 || offset > (input == null ? 0 : input.Length))
{
NetEventSource.Fail(this, "Argument 'offset' out of range.");
Expand Down Expand Up @@ -786,7 +796,7 @@ private SecurityStatusPal GenerateToken(byte[] input, int offset, int count, ref
if (_refreshCredentialNeeded)
{
cachedCreds = _sslAuthenticationOptions.IsServer
? AcquireServerCredentials(ref thumbPrint)
? AcquireServerCredentials(ref thumbPrint, input)
: AcquireClientCredentials(ref thumbPrint);
}

Expand Down
Loading