Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ private static readonly Dictionary<string, SqlColumnEncryptionKeyStoreProvider>
private IReadOnlyDictionary<string, SqlColumnEncryptionKeyStoreProvider> _customColumnEncryptionKeyStoreProviders;

private Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> _accessTokenCallback;
private SspiContextProvider _sspiContextProvider;

internal bool HasColumnEncryptionKeyStoreProvidersRegistered =>
_customColumnEncryptionKeyStoreProviders is not null && _customColumnEncryptionKeyStoreProviders.Count > 0;
Expand Down Expand Up @@ -650,7 +651,7 @@ public override string ConnectionString
CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback(connectionOptions);
}
}
ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _accessTokenCallback));
ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _accessTokenCallback, _sspiContextProvider));
_connectionString = value; // Change _connectionString value only after value is validated
CacheConnectionStringProperties();
}
Expand Down Expand Up @@ -710,7 +711,7 @@ public string AccessToken
}

// Need to call ConnectionString_Set to do proper pool group check
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: value, accessTokenCallback: null));
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: value, accessTokenCallback: null, sspiContextProvider: null));
_accessToken = value;
}
}
Expand All @@ -733,11 +734,21 @@ public Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticati
CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback((SqlConnectionString)ConnectionOptions);
}

ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: null, accessTokenCallback: value));
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: null, accessTokenCallback: value, sspiContextProvider: null));
_accessTokenCallback = value;
}
}

internal SspiContextProvider SspiContextProvider
{
get { return _sspiContextProvider; }
set
{
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: null, accessTokenCallback: null, sspiContextProvider: value));
_sspiContextProvider = value;
}
}

/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/Database/*' />
[ResDescription(StringsHelper.ResourceNames.SqlConnection_Database)]
[ResCategory(StringsHelper.ResourceNames.SqlConnection_DataSource)]
Expand Down Expand Up @@ -1035,7 +1046,7 @@ public SqlCredential Credential
_credential = value;

// Need to call ConnectionString_Set to do proper pool group check
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: _accessToken, accessTokenCallback: _accessTokenCallback));
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: _accessToken, accessTokenCallback: _accessTokenCallback, sspiContextProvider: null));
}
}

Expand Down Expand Up @@ -2265,7 +2276,7 @@ public static void ChangePassword(string connectionString, string newPassword)
throw ADP.InvalidArgumentLength(nameof(newPassword), TdsEnums.MAXLEN_NEWPASSWORD);
}

SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null, accessTokenCallback: null);
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null, accessTokenCallback: null, sspiContextProvider: null);

SqlConnectionString connectionOptions = SqlConnectionFactory.Instance.FindSqlConnectionOptions(key);
if (connectionOptions.IntegratedSecurity)
Expand Down Expand Up @@ -2314,7 +2325,7 @@ public static void ChangePassword(string connectionString, SqlCredential credent
throw ADP.InvalidArgumentLength(nameof(newSecurePassword), TdsEnums.MAXLEN_NEWPASSWORD);
}

SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null);
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null, sspiContextProvider: null);

SqlConnectionString connectionOptions = SqlConnectionFactory.Instance.FindSqlConnectionOptions(key);

Expand Down Expand Up @@ -2352,7 +2363,7 @@ private static void ChangePassword(string connectionString, SqlConnectionString
{
con?.Dispose();
}
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null);
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null, sspiContextProvider: null);

SqlConnectionFactory.Instance.ClearPool(key);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ internal sealed class SqlInternalConnectionTds : SqlInternalConnection, IDisposa
SqlFedAuthToken _fedAuthToken = null;
internal byte[] _accessTokenInBytes;
internal readonly Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> _accessTokenCallback;
internal readonly SspiContextProvider _sspiContextProvider;

private readonly ActiveDirectoryAuthenticationTimeoutRetryHelper _activeDirectoryAuthTimeoutRetryHelper;

Expand Down Expand Up @@ -460,8 +461,8 @@ internal SqlInternalConnectionTds(
bool applyTransientFaultHandling = false,
string accessToken = null,
IDbConnectionPool pool = null,
Func<SqlAuthenticationParameters, CancellationToken,
Task<SqlAuthenticationToken>> accessTokenCallback = null) : base(connectionOptions)
Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> accessTokenCallback = null,
SspiContextProvider sspiContextProvider = null) : base(connectionOptions)
{
#if DEBUG
if (reconnectSessionData != null)
Expand Down Expand Up @@ -514,6 +515,7 @@ internal SqlInternalConnectionTds(
}

_accessTokenCallback = accessTokenCallback;
_sspiContextProvider = sspiContextProvider;

_activeDirectoryAuthTimeoutRetryHelper = new ActiveDirectoryAuthenticationTimeoutRetryHelper();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ internal void Connect(ServerInfo serverInfo,
// AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server
if (integratedSecurity || authType == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
{
_authenticationProvider = _physicalStateObj.CreateSspiContextProvider();
_authenticationProvider = Connection._sspiContextProvider ?? _physicalStateObj.CreateSspiContextProvider();
SqlClientEventSource.Log.TryTraceEvent("TdsParser.Connect | SEC | SSPI or Active Directory Authentication Library loaded for SQL Server based integrated authentication");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ private static readonly Dictionary<string, SqlColumnEncryptionKeyStoreProvider>

private Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> _accessTokenCallback;

private SspiContextProvider _sspiContextProvider;

internal bool HasColumnEncryptionKeyStoreProvidersRegistered =>
_customColumnEncryptionKeyStoreProviders is not null && _customColumnEncryptionKeyStoreProviders.Count > 0;

Expand Down Expand Up @@ -577,6 +579,16 @@ internal int ConnectRetryInterval
get => ((SqlConnectionString)ConnectionOptions).ConnectRetryInterval;
}

internal SspiContextProvider SspiContextProvider
{
get { return _sspiContextProvider; }
set
{
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: null, accessTokenCallback: null, sspiContextProvider: value));
_sspiContextProvider = value;
}
}

/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/ConnectionString/*' />
[DefaultValue("")]
#pragma warning disable 618 // ignore obsolete warning about RecommendedAsConfigurable to use SettingsBindableAttribute
Expand Down Expand Up @@ -645,7 +657,7 @@ public override string ConnectionString
CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback(connectionOptions);
}
}
ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _accessTokenCallback));
ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _accessTokenCallback, _sspiContextProvider));
_connectionString = value; // Change _connectionString value only after value is validated
CacheConnectionStringProperties();
}
Expand Down Expand Up @@ -705,7 +717,7 @@ public string AccessToken

_accessToken = value;
// Need to call ConnectionString_Set to do proper pool group check
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, _accessToken, null));
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, _accessToken, null, sspiContextProvider: null));
}
}

Expand All @@ -727,7 +739,7 @@ public Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticati
CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback((SqlConnectionString)ConnectionOptions);
}

ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, null, value));
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, null, value, sspiContextProvider: null));
_accessTokenCallback = value;
}
}
Expand Down Expand Up @@ -1029,7 +1041,7 @@ public SqlCredential Credential
_credential = value;

// Need to call ConnectionString_Set to do proper pool group check
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: _accessToken, accessTokenCallback: _accessTokenCallback));
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, _accessToken, _accessTokenCallback, sspiContextProvider: null));
}
}

Expand Down Expand Up @@ -2184,7 +2196,7 @@ public static void ChangePassword(string connectionString, string newPassword)
throw ADP.InvalidArgumentLength(nameof(newPassword), TdsEnums.MAXLEN_NEWPASSWORD);
}

SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null, accessTokenCallback: null);
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null, accessTokenCallback: null, sspiContextProvider: null);

SqlConnectionString connectionOptions = SqlConnectionFactory.Instance.FindSqlConnectionOptions(key);
if (connectionOptions.IntegratedSecurity || connectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
Expand Down Expand Up @@ -2236,7 +2248,7 @@ public static void ChangePassword(string connectionString, SqlCredential credent
throw ADP.InvalidArgumentLength(nameof(newSecurePassword), TdsEnums.MAXLEN_NEWPASSWORD);
}

SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null);
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null, sspiContextProvider: null);

SqlConnectionString connectionOptions = SqlConnectionFactory.Instance.FindSqlConnectionOptions(key);

Expand Down Expand Up @@ -2277,7 +2289,7 @@ private static void ChangePassword(string connectionString, SqlConnectionString
{
con?.Dispose();
}
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null);
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null, sspiContextProvider: null);

SqlConnectionFactory.Instance.ClearPool(key);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ internal sealed class SqlInternalConnectionTds : SqlInternalConnection, IDisposa
SqlFedAuthToken _fedAuthToken = null;
internal byte[] _accessTokenInBytes;
internal readonly Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> _accessTokenCallback;
internal readonly SspiContextProvider _sspiContextProvider;

private readonly ActiveDirectoryAuthenticationTimeoutRetryHelper _activeDirectoryAuthTimeoutRetryHelper;

Expand Down Expand Up @@ -472,8 +473,8 @@ internal SqlInternalConnectionTds(
bool applyTransientFaultHandling = false,
string accessToken = null,
IDbConnectionPool pool = null,
Func<SqlAuthenticationParameters, CancellationToken,
Task<SqlAuthenticationToken>> accessTokenCallback = null)
Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> accessTokenCallback = null,
SspiContextProvider sspiContextProvider = null)
: base(connectionOptions)
{
#if DEBUG
Expand Down Expand Up @@ -525,6 +526,7 @@ internal SqlInternalConnectionTds(
}

_accessTokenCallback = accessTokenCallback;
_sspiContextProvider = sspiContextProvider;

_activeDirectoryAuthTimeoutRetryHelper = new ActiveDirectoryAuthenticationTimeoutRetryHelper();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ internal void Connect(ServerInfo serverInfo,
// AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server
if (integratedSecurity || authType == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
{
_authenticationProvider = _physicalStateObj.CreateSspiContextProvider();
_authenticationProvider = Connection._sspiContextProvider ?? _physicalStateObj.CreateSspiContextProvider();

if (!string.IsNullOrEmpty(serverInfo.ServerSPN))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ internal class SqlConnectionPoolKey : DbConnectionPoolKey
private readonly SqlCredential _credential;
private readonly string _accessToken;
private Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> _accessTokenCallback;
private SspiContextProvider _sspiContextProvider;

internal SqlCredential Credential => _credential;
internal string AccessToken => _accessToken;
internal Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> AccessTokenCallback => _accessTokenCallback;
internal SspiContextProvider SspiContextProvider => _sspiContextProvider;

internal override string ConnectionString
{
Expand All @@ -33,12 +35,18 @@ internal override string ConnectionString
}
}

internal SqlConnectionPoolKey(string connectionString, SqlCredential credential, string accessToken, Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> accessTokenCallback) : base(connectionString)
internal SqlConnectionPoolKey(
string connectionString,
SqlCredential credential,
string accessToken,
Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> accessTokenCallback,
SspiContextProvider sspiContextProvider) : base(connectionString)
{
Debug.Assert(credential == null || accessToken == null || accessTokenCallback == null, "Credential, AccessToken, and Callback can't have a value at the same time.");
_credential = credential;
_accessToken = accessToken;
_accessTokenCallback = accessTokenCallback;
_sspiContextProvider = sspiContextProvider;
CalculateHashCode();
}

Expand All @@ -47,6 +55,8 @@ private SqlConnectionPoolKey(SqlConnectionPoolKey key) : base(key)
_credential = key.Credential;
_accessToken = key.AccessToken;
_accessTokenCallback = key._accessTokenCallback;
_sspiContextProvider = key._sspiContextProvider;

CalculateHashCode();
}

Expand All @@ -61,7 +71,8 @@ public override bool Equals(object obj)
&& _credential == key._credential
&& ConnectionString == key.ConnectionString
&& _accessTokenCallback == key._accessTokenCallback
&& string.CompareOrdinal(_accessToken, key._accessToken) == 0);
&& string.CompareOrdinal(_accessToken, key._accessToken) == 0
&& _sspiContextProvider == key._sspiContextProvider);
}

public override int GetHashCode()
Expand Down Expand Up @@ -94,6 +105,11 @@ private void CalculateHashCode()
_hashValue = _hashValue * 17 + _accessTokenCallback.GetHashCode();
}
}

if (_sspiContextProvider != null)
{
_hashValue = _hashValue * 17 + _sspiContextProvider.GetHashCode();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ private void LoadSSPILibrary()
}
}

protected override bool GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, SspiAuthenticationParameters authParams)
protected override bool GenerateContext(ReadOnlySpan<byte> incomingBlob, IBufferWriter<byte> outgoingBlobWriter, SspiAuthenticationParameters authParams)
{
#if NETFRAMEWORK
SNIHandle handle = _physicalStateObj.Handle;
Expand All @@ -62,7 +62,7 @@ protected override bool GenerateSspiClientContext(ReadOnlySpan<byte> incomingBlo
var sendLength = s_maxSSPILength;
var outBuff = outgoingBlobWriter.GetSpan((int)sendLength);

if (0 != SniNativeWrapper.SniSecGenClientContext(handle, incomingBlob, outBuff, ref sendLength, authParams.Resource))
if (SniNativeWrapper.SniSecGenClientContext(handle, incomingBlob, outBuff, ref sendLength, authParams.Resource) != 0)
{
return false;
}
Expand Down
Loading
Loading