diff --git a/Directory.Build.props b/Directory.Build.props index 423b9e4a7..3b2062ebc 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -30,6 +30,7 @@ + net8.0; net9.0; net10.0; net462; net472; netstandard2.0 true ../../build/MSAL.snk @@ -199,4 +200,8 @@ runtime; build; native; contentfiles; analyzers + + + $(DefineConstants);SUPPORTS_MTLS; + diff --git a/src/Microsoft.Identity.Web.DownstreamApi/DownstreamApi.cs b/src/Microsoft.Identity.Web.DownstreamApi/DownstreamApi.cs index cfe802afe..1fd567af3 100644 --- a/src/Microsoft.Identity.Web.DownstreamApi/DownstreamApi.cs +++ b/src/Microsoft.Identity.Web.DownstreamApi/DownstreamApi.cs @@ -1,22 +1,22 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -using System; -using System.Collections.Generic; -using System.Diagnostics.CodeAnalysis; -using System.IO; -using System.Linq; -using System.Net.Http; -using System.Runtime.CompilerServices; -using System.Security.Claims; -using System.Text; -using System.Text.Json; -using System.Text.Json.Serialization.Metadata; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Options; -using Microsoft.Identity.Abstractions; +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Linq; +using System.Net.Http; +using System.Runtime.CompilerServices; +using System.Security.Claims; +using System.Text; +using System.Text.Json; +using System.Text.Json.Serialization.Metadata; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using Microsoft.Identity.Abstractions; using Microsoft.Identity.Client; namespace Microsoft.Identity.Web @@ -26,6 +26,12 @@ internal partial class DownstreamApi : IDownstreamApi { private readonly IAuthorizationHeaderProvider _authorizationHeaderProvider; private readonly IHttpClientFactory _httpClientFactory; + + // This MSAL HTTP client factory is used to create HTTP clients with mTLS binding certificate. + // Note, that it doesn't replace _httpClientFactory to keep backward compatibility and ability + // to create named HTTP clients for non-mTLS scenarios. + private readonly IMsalHttpClientFactory? _msalHttpClientFactory; + private readonly IOptionsMonitor _namedDownstreamApiOptions; private const string Authorization = "Authorization"; protected readonly ILogger _logger; @@ -43,10 +49,33 @@ public DownstreamApi( IOptionsMonitor namedDownstreamApiOptions, IHttpClientFactory httpClientFactory, ILogger logger) + : this(authorizationHeaderProvider, + namedDownstreamApiOptions, + httpClientFactory, + logger, + msalHttpClientFactory: null) + { + } + + /// + /// Constructor which accepts optional MSAL HTTP client factory. + /// + /// Authorization header provider. + /// Named options provider. + /// HTTP client factory. + /// Logger. + /// The MSAL HTTP client factory for mTLS PoP scenarios. + public DownstreamApi( + IAuthorizationHeaderProvider authorizationHeaderProvider, + IOptionsMonitor namedDownstreamApiOptions, + IHttpClientFactory httpClientFactory, + ILogger logger, + IMsalHttpClientFactory? msalHttpClientFactory) { _authorizationHeaderProvider = authorizationHeaderProvider; _namedDownstreamApiOptions = namedDownstreamApiOptions; _httpClientFactory = httpClientFactory; + _msalHttpClientFactory = msalHttpClientFactory ?? new MsalMtlsHttpClientFactory(httpClientFactory); _logger = logger; } @@ -436,7 +465,7 @@ public Task CallApiForAppAsync( string stringContent = await content.ReadAsStringAsync(); if (mediaType == "application/json") { - return JsonSerializer.Deserialize(stringContent, new JsonSerializerOptions { PropertyNameCaseInsensitive = true }); + return JsonSerializer.Deserialize(stringContent, new JsonSerializerOptions { PropertyNameCaseInsensitive = true }); } if (mediaType != null && !mediaType.StartsWith("text/", StringComparison.OrdinalIgnoreCase)) { @@ -514,11 +543,17 @@ public Task CallApiForAppAsync( new HttpMethod(effectiveOptions.HttpMethod), apiUrl); - await UpdateRequestAsync(httpRequestMessage, content, effectiveOptions, appToken, user, cancellationToken); + // Request result will contain authorization header and potentially binding certificate for mTLS + var requestResult = await UpdateRequestAsync(httpRequestMessage, content, effectiveOptions, appToken, user, cancellationToken); - using HttpClient client = string.IsNullOrEmpty(serviceName) ? _httpClientFactory.CreateClient() : _httpClientFactory.CreateClient(serviceName); + // If a binding certificate is specified (which means mTLS is required) and MSAL mTLS HTTP factory is present + // then create an HttpClient with the certificate by using IMsalMtlsHttpClientFactory. + // Otherwise use the default HttpClientFactory with optional named client. + using HttpClient client = requestResult?.BindingCertificate != null && _msalHttpClientFactory != null && _msalHttpClientFactory is IMsalMtlsHttpClientFactory msalMtlsHttpClientFactory + ? msalMtlsHttpClientFactory.GetHttpClient(requestResult.BindingCertificate) + : (string.IsNullOrEmpty(serviceName) ? _httpClientFactory.CreateClient() : _httpClientFactory.CreateClient(serviceName)); - // Send the HTTP message + // Send the HTTP message var downstreamApiResult = await client.SendAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false); // Retry only if the resource sent 401 Unauthorized with WWW-Authenticate header and claims @@ -541,7 +576,7 @@ public Task CallApiForAppAsync( return downstreamApiResult; } - internal /* internal for test */ async Task UpdateRequestAsync( + internal /* internal for test */ async Task UpdateRequestAsync( HttpRequestMessage httpRequestMessage, HttpContent? content, DownstreamApiOptions effectiveOptions, @@ -558,15 +593,38 @@ public Task CallApiForAppAsync( effectiveOptions.RequestAppToken = appToken; + AuthorizationHeaderInformation? authorizationHeaderInformation = null; + // Obtention of the authorization header (except when calling an anonymous endpoint // which is done by not specifying any scopes if (effectiveOptions.Scopes != null && effectiveOptions.Scopes.Any()) { - string authorizationHeader = await _authorizationHeaderProvider.CreateAuthorizationHeaderAsync( - effectiveOptions.Scopes, - effectiveOptions, - user, - cancellationToken).ConfigureAwait(false); + string authorizationHeader = string.Empty; + + // Firstly check if it's token binding scenario so authorization header provider returns + + // a binding certificate along with acquired authorization header. + if (_authorizationHeaderProvider is IAuthorizationHeaderProvider2 authorizationHeaderBoundProviderForMtls) + { + var authorizationHeaderResult = await authorizationHeaderBoundProviderForMtls.CreateAuthorizationHeaderAsync( + effectiveOptions, + user, + cancellationToken).ConfigureAwait(false); + + if (authorizationHeaderResult.Succeeded) + { + authorizationHeaderInformation = authorizationHeaderResult.Result; + authorizationHeader = authorizationHeaderInformation?.AuthorizationHeaderValue ?? string.Empty; + } + } + else + { + authorizationHeader = await _authorizationHeaderProvider.CreateAuthorizationHeaderAsync( + effectiveOptions.Scopes, + effectiveOptions, + user, + cancellationToken).ConfigureAwait(false); + } if (authorizationHeader.StartsWith(AuthSchemeDstsSamlBearer, StringComparison.OrdinalIgnoreCase)) { @@ -582,54 +640,56 @@ public Task CallApiForAppAsync( { Logger.UnauthenticatedApiCall(_logger, null); } - if (!string.IsNullOrEmpty(effectiveOptions.AcceptHeader)) - { - httpRequestMessage.Headers.Accept.ParseAdd(effectiveOptions.AcceptHeader); - } - - // Add extra headers if specified directly on DownstreamApiOptions - if (effectiveOptions.ExtraHeaderParameters != null) - { - foreach (var header in effectiveOptions.ExtraHeaderParameters) - { - httpRequestMessage.Headers.TryAddWithoutValidation(header.Key, header.Value); - } - } - - // Add extra query parameters if specified directly on DownstreamApiOptions - if (effectiveOptions.ExtraQueryParameters != null && effectiveOptions.ExtraQueryParameters.Count > 0) - { - var uriBuilder = new UriBuilder(httpRequestMessage.RequestUri!); - var existingQuery = uriBuilder.Query; - var queryString = new StringBuilder(existingQuery); - - foreach (var queryParam in effectiveOptions.ExtraQueryParameters) - { - if (queryString.Length > 1) // if there are existing query parameters - { - queryString.Append('&'); - } - else if (queryString.Length == 0) - { - queryString.Append('?'); - } - - queryString.Append(Uri.EscapeDataString(queryParam.Key)); - queryString.Append('='); - queryString.Append(Uri.EscapeDataString(queryParam.Value)); - } - - uriBuilder.Query = queryString.ToString().TrimStart('?'); - httpRequestMessage.RequestUri = uriBuilder.Uri; - } - - // Opportunity to change the request message + if (!string.IsNullOrEmpty(effectiveOptions.AcceptHeader)) + { + httpRequestMessage.Headers.Accept.ParseAdd(effectiveOptions.AcceptHeader); + } + + // Add extra headers if specified directly on DownstreamApiOptions + if (effectiveOptions.ExtraHeaderParameters != null) + { + foreach (var header in effectiveOptions.ExtraHeaderParameters) + { + httpRequestMessage.Headers.TryAddWithoutValidation(header.Key, header.Value); + } + } + + // Add extra query parameters if specified directly on DownstreamApiOptions + if (effectiveOptions.ExtraQueryParameters != null && effectiveOptions.ExtraQueryParameters.Count > 0) + { + var uriBuilder = new UriBuilder(httpRequestMessage.RequestUri!); + var existingQuery = uriBuilder.Query; + var queryString = new StringBuilder(existingQuery); + + foreach (var queryParam in effectiveOptions.ExtraQueryParameters) + { + if (queryString.Length > 1) // if there are existing query parameters + { + queryString.Append('&'); + } + else if (queryString.Length == 0) + { + queryString.Append('?'); + } + + queryString.Append(Uri.EscapeDataString(queryParam.Key)); + queryString.Append('='); + queryString.Append(Uri.EscapeDataString(queryParam.Value)); + } + + uriBuilder.Query = queryString.ToString().TrimStart('?'); + httpRequestMessage.RequestUri = uriBuilder.Uri; + } + + // Opportunity to change the request message effectiveOptions.CustomizeHttpRequestMessage?.Invoke(httpRequestMessage); + + return authorizationHeaderInformation; } internal /* for test */ static Dictionary CallerSDKDetails { get; } = new() { - { "caller-sdk-id", "IdWeb_1" }, + { "caller-sdk-id", "IdWeb_1" }, { "caller-sdk-ver", IdHelper.GetIdWebVersion() } }; @@ -657,14 +717,14 @@ private static void AddCallerSDKTelemetry(DownstreamApiOptions effectiveOptions) internal static async Task ReadErrorResponseContentAsync(HttpResponseMessage response, CancellationToken cancellationToken = default) { const int maxErrorContentLength = 4096; - + long? contentLength = response.Content.Headers.ContentLength; - + if (contentLength.HasValue && contentLength.Value > maxErrorContentLength) { return $"[Error response too large: {contentLength.Value} bytes, not captured]"; } - + // Use streaming to read only up to maxErrorContentLength to avoid loading entire response into memory #if NET5_0_OR_GREATER using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false); @@ -672,18 +732,18 @@ internal static async Task ReadErrorResponseContentAsync(HttpResponseMes using var stream = await response.Content.ReadAsStreamAsync().ConfigureAwait(false); #endif using var reader = new StreamReader(stream); - + char[] buffer = new char[maxErrorContentLength]; int readCount = await reader.ReadBlockAsync(buffer, 0, maxErrorContentLength).ConfigureAwait(false); - + string errorResponseContent = new string(buffer, 0, readCount); - + // Check if there's more content that was truncated if (readCount == maxErrorContentLength && reader.Peek() != -1) { errorResponseContent += "... (truncated)"; } - + return errorResponseContent; } } diff --git a/src/Microsoft.Identity.Web.DownstreamApi/PublicAPI/net10.0/InternalAPI.Unshipped.txt b/src/Microsoft.Identity.Web.DownstreamApi/PublicAPI/net10.0/InternalAPI.Unshipped.txt index 40d750c56..983d4aa1b 100644 --- a/src/Microsoft.Identity.Web.DownstreamApi/PublicAPI/net10.0/InternalAPI.Unshipped.txt +++ b/src/Microsoft.Identity.Web.DownstreamApi/PublicAPI/net10.0/InternalAPI.Unshipped.txt @@ -1,4 +1,6 @@ #nullable enable +Microsoft.Identity.Web.DownstreamApi.DownstreamApi(Microsoft.Identity.Abstractions.IAuthorizationHeaderProvider! authorizationHeaderProvider, Microsoft.Extensions.Options.IOptionsMonitor! namedDownstreamApiOptions, System.Net.Http.IHttpClientFactory! httpClientFactory, Microsoft.Extensions.Logging.ILogger! logger, Microsoft.Identity.Client.IMsalHttpClientFactory? msalHttpClientFactory) -> void +Microsoft.Identity.Web.DownstreamApi.UpdateRequestAsync(System.Net.Http.HttpRequestMessage! httpRequestMessage, System.Net.Http.HttpContent? content, Microsoft.Identity.Abstractions.DownstreamApiOptions! effectiveOptions, bool appToken, System.Security.Claims.ClaimsPrincipal? user, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! static Microsoft.Identity.Web.DownstreamApi.DeserializeOutputAsync(System.Net.Http.HttpResponseMessage! response, Microsoft.Identity.Abstractions.DownstreamApiOptions! effectiveOptions, System.Text.Json.Serialization.Metadata.JsonTypeInfo! outputJsonTypeInfo, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! static Microsoft.Identity.Web.DownstreamApi.DeserializeOutputAsync(System.Net.Http.HttpResponseMessage! response, Microsoft.Identity.Abstractions.DownstreamApiOptions! effectiveOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! static Microsoft.Identity.Web.DownstreamApi.Logger.HttpRequestError(Microsoft.Extensions.Logging.ILogger! logger, string! ServiceName, string! BaseUrl, string! RelativePath, int statusCode, string! responseContent, System.Exception? ex) -> void diff --git a/src/Microsoft.Identity.Web.DownstreamApi/PublicAPI/net462/InternalAPI.Unshipped.txt b/src/Microsoft.Identity.Web.DownstreamApi/PublicAPI/net462/InternalAPI.Unshipped.txt index 7dc5c5811..57a84df1b 100644 --- a/src/Microsoft.Identity.Web.DownstreamApi/PublicAPI/net462/InternalAPI.Unshipped.txt +++ b/src/Microsoft.Identity.Web.DownstreamApi/PublicAPI/net462/InternalAPI.Unshipped.txt @@ -1 +1,3 @@ #nullable enable +Microsoft.Identity.Web.DownstreamApi.DownstreamApi(Microsoft.Identity.Abstractions.IAuthorizationHeaderProvider! authorizationHeaderProvider, Microsoft.Extensions.Options.IOptionsMonitor! namedDownstreamApiOptions, System.Net.Http.IHttpClientFactory! httpClientFactory, Microsoft.Extensions.Logging.ILogger! logger, Microsoft.Identity.Client.IMsalHttpClientFactory? msalHttpClientFactory) -> void +Microsoft.Identity.Web.DownstreamApi.UpdateRequestAsync(System.Net.Http.HttpRequestMessage! httpRequestMessage, System.Net.Http.HttpContent? content, Microsoft.Identity.Abstractions.DownstreamApiOptions! effectiveOptions, bool appToken, System.Security.Claims.ClaimsPrincipal? user, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! diff --git a/src/Microsoft.Identity.Web.DownstreamApi/PublicAPI/net472/InternalAPI.Unshipped.txt b/src/Microsoft.Identity.Web.DownstreamApi/PublicAPI/net472/InternalAPI.Unshipped.txt index 7dc5c5811..57a84df1b 100644 --- a/src/Microsoft.Identity.Web.DownstreamApi/PublicAPI/net472/InternalAPI.Unshipped.txt +++ b/src/Microsoft.Identity.Web.DownstreamApi/PublicAPI/net472/InternalAPI.Unshipped.txt @@ -1 +1,3 @@ #nullable enable +Microsoft.Identity.Web.DownstreamApi.DownstreamApi(Microsoft.Identity.Abstractions.IAuthorizationHeaderProvider! authorizationHeaderProvider, Microsoft.Extensions.Options.IOptionsMonitor! namedDownstreamApiOptions, System.Net.Http.IHttpClientFactory! httpClientFactory, Microsoft.Extensions.Logging.ILogger! logger, Microsoft.Identity.Client.IMsalHttpClientFactory? msalHttpClientFactory) -> void +Microsoft.Identity.Web.DownstreamApi.UpdateRequestAsync(System.Net.Http.HttpRequestMessage! httpRequestMessage, System.Net.Http.HttpContent? content, Microsoft.Identity.Abstractions.DownstreamApiOptions! effectiveOptions, bool appToken, System.Security.Claims.ClaimsPrincipal? user, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! diff --git a/src/Microsoft.Identity.Web.DownstreamApi/PublicAPI/net8.0/InternalAPI.Unshipped.txt b/src/Microsoft.Identity.Web.DownstreamApi/PublicAPI/net8.0/InternalAPI.Unshipped.txt index 7dc5c5811..57a84df1b 100644 --- a/src/Microsoft.Identity.Web.DownstreamApi/PublicAPI/net8.0/InternalAPI.Unshipped.txt +++ b/src/Microsoft.Identity.Web.DownstreamApi/PublicAPI/net8.0/InternalAPI.Unshipped.txt @@ -1 +1,3 @@ #nullable enable +Microsoft.Identity.Web.DownstreamApi.DownstreamApi(Microsoft.Identity.Abstractions.IAuthorizationHeaderProvider! authorizationHeaderProvider, Microsoft.Extensions.Options.IOptionsMonitor! namedDownstreamApiOptions, System.Net.Http.IHttpClientFactory! httpClientFactory, Microsoft.Extensions.Logging.ILogger! logger, Microsoft.Identity.Client.IMsalHttpClientFactory? msalHttpClientFactory) -> void +Microsoft.Identity.Web.DownstreamApi.UpdateRequestAsync(System.Net.Http.HttpRequestMessage! httpRequestMessage, System.Net.Http.HttpContent? content, Microsoft.Identity.Abstractions.DownstreamApiOptions! effectiveOptions, bool appToken, System.Security.Claims.ClaimsPrincipal? user, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! diff --git a/src/Microsoft.Identity.Web.DownstreamApi/PublicAPI/net9.0/InternalAPI.Unshipped.txt b/src/Microsoft.Identity.Web.DownstreamApi/PublicAPI/net9.0/InternalAPI.Unshipped.txt index 7dc5c5811..57a84df1b 100644 --- a/src/Microsoft.Identity.Web.DownstreamApi/PublicAPI/net9.0/InternalAPI.Unshipped.txt +++ b/src/Microsoft.Identity.Web.DownstreamApi/PublicAPI/net9.0/InternalAPI.Unshipped.txt @@ -1 +1,3 @@ #nullable enable +Microsoft.Identity.Web.DownstreamApi.DownstreamApi(Microsoft.Identity.Abstractions.IAuthorizationHeaderProvider! authorizationHeaderProvider, Microsoft.Extensions.Options.IOptionsMonitor! namedDownstreamApiOptions, System.Net.Http.IHttpClientFactory! httpClientFactory, Microsoft.Extensions.Logging.ILogger! logger, Microsoft.Identity.Client.IMsalHttpClientFactory? msalHttpClientFactory) -> void +Microsoft.Identity.Web.DownstreamApi.UpdateRequestAsync(System.Net.Http.HttpRequestMessage! httpRequestMessage, System.Net.Http.HttpContent? content, Microsoft.Identity.Abstractions.DownstreamApiOptions! effectiveOptions, bool appToken, System.Security.Claims.ClaimsPrincipal? user, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! diff --git a/src/Microsoft.Identity.Web.DownstreamApi/PublicAPI/netstandard2.0/InternalAPI.Unshipped.txt b/src/Microsoft.Identity.Web.DownstreamApi/PublicAPI/netstandard2.0/InternalAPI.Unshipped.txt index 7dc5c5811..57a84df1b 100644 --- a/src/Microsoft.Identity.Web.DownstreamApi/PublicAPI/netstandard2.0/InternalAPI.Unshipped.txt +++ b/src/Microsoft.Identity.Web.DownstreamApi/PublicAPI/netstandard2.0/InternalAPI.Unshipped.txt @@ -1 +1,3 @@ #nullable enable +Microsoft.Identity.Web.DownstreamApi.DownstreamApi(Microsoft.Identity.Abstractions.IAuthorizationHeaderProvider! authorizationHeaderProvider, Microsoft.Extensions.Options.IOptionsMonitor! namedDownstreamApiOptions, System.Net.Http.IHttpClientFactory! httpClientFactory, Microsoft.Extensions.Logging.ILogger! logger, Microsoft.Identity.Client.IMsalHttpClientFactory? msalHttpClientFactory) -> void +Microsoft.Identity.Web.DownstreamApi.UpdateRequestAsync(System.Net.Http.HttpRequestMessage! httpRequestMessage, System.Net.Http.HttpContent? content, Microsoft.Identity.Abstractions.DownstreamApiOptions! effectiveOptions, bool appToken, System.Security.Claims.ClaimsPrincipal? user, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/ConfidentialClientApplicationBuilderExtension.cs b/src/Microsoft.Identity.Web.TokenAcquisition/ConfidentialClientApplicationBuilderExtension.cs index 324f0e458..340076157 100644 --- a/src/Microsoft.Identity.Web.TokenAcquisition/ConfidentialClientApplicationBuilderExtension.cs +++ b/src/Microsoft.Identity.Web.TokenAcquisition/ConfidentialClientApplicationBuilderExtension.cs @@ -58,6 +58,34 @@ public static async Task WithClientCredent } } + public static async Task WithBindingCertificateAsync( + this ConfidentialClientApplicationBuilder builder, + IEnumerable clientCredentials, + ILogger logger, + ICredentialsLoader credentialsLoader, + CredentialSourceLoaderParameters? credentialSourceLoaderParameters, + bool isTokenBinding) + { + var credential = await LoadCredentialForMsalOrFailAsync( + clientCredentials, + logger, + credentialsLoader, + credentialSourceLoaderParameters).ConfigureAwait(false); + + if (credential?.Certificate != null) + { + return builder.WithCertificate(credential.Certificate); + } + + if (isTokenBinding) + { + logger.LogError("A certificate, which is required for token binding, is missing in loaded credentials."); + throw new InvalidOperationException(IDWebErrorMessage.MissingTokenBindingCertificate); + } + + return builder; + } + internal /* for test */ async static Task LoadCredentialForMsalOrFailAsync( IEnumerable clientCredentials, ILogger logger, diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/DefaultAuthorizationHeaderProvider.cs b/src/Microsoft.Identity.Web.TokenAcquisition/DefaultAuthorizationHeaderProvider.cs index a3db6ad04..7c24a9956 100644 --- a/src/Microsoft.Identity.Web.TokenAcquisition/DefaultAuthorizationHeaderProvider.cs +++ b/src/Microsoft.Identity.Web.TokenAcquisition/DefaultAuthorizationHeaderProvider.cs @@ -11,7 +11,7 @@ namespace Microsoft.Identity.Web { - internal sealed class DefaultAuthorizationHeaderProvider : IAuthorizationHeaderProvider + internal sealed class DefaultAuthorizationHeaderProvider : IAuthorizationHeaderProvider, IAuthorizationHeaderProvider2 { private readonly ITokenAcquisition _tokenAcquisition; @@ -97,6 +97,33 @@ public async Task CreateAuthorizationHeaderAsync( return result.CreateAuthorizationHeader(); } + /// + public async Task> CreateAuthorizationHeaderAsync( + DownstreamApiOptions downstreamApiOptions, + ClaimsPrincipal? claimsPrincipal = null, + CancellationToken cancellationToken = default) + { + var newTokenAcquisitionOptions = CreateTokenAcquisitionOptionsFromApiOptions(downstreamApiOptions, cancellationToken); + + // Token binding flow currently supports only app tokens. + var tokenAcquisitionResult = await _tokenAcquisition.GetAuthenticationResultForAppAsync( + downstreamApiOptions.Scopes?.FirstOrDefault() ?? string.Empty, + downstreamApiOptions?.AcquireTokenOptions.AuthenticationOptionsName, + downstreamApiOptions?.AcquireTokenOptions.Tenant, + newTokenAcquisitionOptions).ConfigureAwait(false); + + UpdateOriginalTokenAcquisitionOptions(downstreamApiOptions?.AcquireTokenOptions, newTokenAcquisitionOptions); + + var authorizationHeader = tokenAcquisitionResult.CreateAuthorizationHeader(); + var authorizationHeaderInformation = new AuthorizationHeaderInformation() + { + AuthorizationHeaderValue = authorizationHeader, + BindingCertificate = tokenAcquisitionResult.BindingCertificate + }; + + return new(authorizationHeaderInformation); + } + private static TokenAcquisitionOptions CreateTokenAcquisitionOptionsFromApiOptions( AuthorizationHeaderProviderOptions? downstreamApiOptions, CancellationToken cancellationToken) diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/IDWebErrorMessage.cs b/src/Microsoft.Identity.Web.TokenAcquisition/IDWebErrorMessage.cs index d01e300ae..07e79bc40 100644 --- a/src/Microsoft.Identity.Web.TokenAcquisition/IDWebErrorMessage.cs +++ b/src/Microsoft.Identity.Web.TokenAcquisition/IDWebErrorMessage.cs @@ -21,6 +21,7 @@ internal static class IDWebErrorMessage public const string MissingRequiredScopesForAuthorizationFilter = "IDW10108: RequiredScope Attribute does not contain a value. The scopes need to be set on the controller, the page or action. See https://aka.ms/ms-id-web/required-scope-attribute. "; public const string ClientCertificatesHaveExpiredOrCannotBeLoaded = "IDW10109: No credential could be loaded. This can happen when certificates passed to the configuration have expired or can't be loaded and the code isn't running on Azure to be able to use Managed Identity, Pod Identity etc. Details: "; public const string ClientSecretAndCredentialsCannotBeCombined = "IDW10110: ClientSecret top level configuration cannot be combined with ClientCredentials. Instead, add a new entry in the ClientCredentials array describing the secret."; + public const string MissingTokenBindingCertificate = "IDW10111: A certificate, which is required for token binding, is missing in loaded credentials."; // Authorization IDW10200 = "IDW10200:" public const string NeitherScopeOrRolesClaimFoundInToken = "IDW10201: Neither scope nor roles claim was found in the bearer token. Authentication scheme used: '{0}'. "; diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/MergedOptions.cs b/src/Microsoft.Identity.Web.TokenAcquisition/MergedOptions.cs index 62606f3d5..54847977e 100644 --- a/src/Microsoft.Identity.Web.TokenAcquisition/MergedOptions.cs +++ b/src/Microsoft.Identity.Web.TokenAcquisition/MergedOptions.cs @@ -50,6 +50,7 @@ public ConfidentialClientApplicationOptions ConfidentialClientApplicationOptions public LogLevel LogLevel { get; set; } public string? RedirectUri { get; set; } public bool EnableCacheSynchronization { get; set; } + public bool IsTokenBinding { get; set; } internal bool MergedWithCca { get; set; } // This is for supporting for CIAM authorities including custom url domains, see https://github.com/AzureAD/microsoft-identity-web/issues/2690 internal bool PreserveAuthority { get; set; } diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/MsalMtlsHttpClientFactory.cs b/src/Microsoft.Identity.Web.TokenAcquisition/MsalMtlsHttpClientFactory.cs new file mode 100644 index 000000000..f3d6aeadd --- /dev/null +++ b/src/Microsoft.Identity.Web.TokenAcquisition/MsalMtlsHttpClientFactory.cs @@ -0,0 +1,120 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Concurrent; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Security.Cryptography.X509Certificates; +using Microsoft.Identity.Client; + +namespace Microsoft.Identity.Web +{ + /// + /// Provides a factory for creating HTTP clients configured for mTLS authentication with using binding certificate. + /// It uses a hybrid approach with leveraging IHttpClientFactory for non-mTLS HTTP clients and maintaining + /// a pool of mTLS clients with using certificate thumbprint as a key. + /// + public sealed class MsalMtlsHttpClientFactory : IMsalMtlsHttpClientFactory + { + private const long MaxMtlsHttpClientCountInPool = 1000; + private const long MaxResponseContentBufferSizeInBytes = 1024 * 1024; + + // Please see (https://aka.ms/msal-httpclient-info) for important information regarding the HttpClient. + private static readonly ConcurrentDictionary s_mtlsHttpClientPool = new ConcurrentDictionary(); + private static readonly object s_cacheLock = new object(); + + private readonly IHttpClientFactory _httpClientFactory; + + /// + /// Initializes a new instance of the MsalMtlsHttpClientFactory class using the specified HTTP client factory. + /// + /// The factory used to create HttpClient instances for mutual TLS (mTLS) operations. Cannot be null. + public MsalMtlsHttpClientFactory(IHttpClientFactory httpClientFactory) + { + _httpClientFactory = httpClientFactory; + } + + /// + /// Creates and configures a new instance of with telemetry headers applied. + /// + /// + /// The returned includes a telemetry header for tracking or + /// diagnostics purposes. Callers are responsible for disposing the instance when it is + /// no longer needed. + /// + /// A new instance with telemetry information included in the default request headers. + public HttpClient GetHttpClient() + { + HttpClient httpClient = _httpClientFactory.CreateClient(); + httpClient.DefaultRequestHeaders.Add(Constants.TelemetryHeaderKey, IdHelper.CreateTelemetryInfo()); + return httpClient; + } + + /// + /// Returns an instance of configured to use the specified X.509 client certificate for + /// mutual TLS authentication. + /// + /// + /// The returned instance is pooled and reused for the given certificate. + /// The client includes a telemetry header in each request. Callers should not modify the default + /// request headers or dispose the returned instance. + /// + /// The X.509 certificate to use for client authentication. If , a default instance without client certificate authentication is returned. + /// A instance configured for mutual TLS authentication using the specified certificate or default instance. + public HttpClient GetHttpClient(X509Certificate2 x509Certificate2) + { + if (x509Certificate2 == null) + { + return GetHttpClient(); + } + + string key = x509Certificate2.Thumbprint; + HttpClient httpClient = CreateMtlsHttpClient(x509Certificate2); + httpClient = s_mtlsHttpClientPool.GetOrAdd(key, httpClient); + return httpClient; + } + + private HttpClient CreateMtlsHttpClient(X509Certificate2 bindingCertificate) + { +#if SUPPORTS_MTLS + CheckAndManageCache(); + + if (bindingCertificate == null) + { + throw new ArgumentNullException(nameof(bindingCertificate), "A valid X509 certificate must be provided for mTLS."); + } + + HttpClientHandler handler = new(); + handler.ClientCertificates.Add(bindingCertificate); + + // HTTP client factory can't be used there because HTTP client handler needs to be configured + // before a HTTP client instance is created + var httpClient = new HttpClient(handler); + ConfigureRequestHeadersAndSize(httpClient); + return httpClient; +#else + throw new NotSupportedException("mTLS is not supported on this platform."); +#endif + } + + private static void CheckAndManageCache() + { + lock (s_cacheLock) + { + if (s_mtlsHttpClientPool.Count >= MaxMtlsHttpClientCountInPool) + { + s_mtlsHttpClientPool.Clear(); + } + } + } + + private static void ConfigureRequestHeadersAndSize(HttpClient httpClient) + { + httpClient.MaxResponseContentBufferSize = MaxResponseContentBufferSizeInBytes; + httpClient.DefaultRequestHeaders.Accept.Clear(); + httpClient.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); + httpClient.DefaultRequestHeaders.Add(Constants.TelemetryHeaderKey, IdHelper.CreateTelemetryInfo()); + } + } +} diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net10.0/InternalAPI.Unshipped.txt b/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net10.0/InternalAPI.Unshipped.txt index ec8e30c95..460f69d3a 100644 --- a/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net10.0/InternalAPI.Unshipped.txt +++ b/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net10.0/InternalAPI.Unshipped.txt @@ -1,4 +1,9 @@ #nullable enable const Microsoft.Identity.Web.Constants.UserIdKey = "IDWEB_USER_ID" -> string! readonly Microsoft.Identity.Web.TokenAcquisition._certificatesObservers -> System.Collections.Generic.IReadOnlyList! +const Microsoft.Identity.Web.IDWebErrorMessage.MissingTokenBindingCertificate = "IDW10111: A certificate, which is required for token binding, is missing in loaded credentials." -> string! +Microsoft.Identity.Web.DefaultAuthorizationHeaderProvider.CreateAuthorizationHeaderAsync(Microsoft.Identity.Abstractions.DownstreamApiOptions! downstreamApiOptions, System.Security.Claims.ClaimsPrincipal? claimsPrincipal = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task>! +Microsoft.Identity.Web.MergedOptions.IsTokenBinding.get -> bool +Microsoft.Identity.Web.MergedOptions.IsTokenBinding.set -> void +static Microsoft.Identity.Web.ConfidentialClientApplicationBuilderExtension.WithBindingCertificateAsync(this Microsoft.Identity.Client.ConfidentialClientApplicationBuilder! builder, System.Collections.Generic.IEnumerable! clientCredentials, Microsoft.Extensions.Logging.ILogger! logger, Microsoft.Identity.Abstractions.ICredentialsLoader! credentialsLoader, Microsoft.Identity.Abstractions.CredentialSourceLoaderParameters? credentialSourceLoaderParameters, bool isTokenBinding) -> System.Threading.Tasks.Task! static Microsoft.Identity.Web.TokenAcquisition.MergeExtraQueryParameters(Microsoft.Identity.Web.MergedOptions! mergedOptions, Microsoft.Identity.Web.TokenAcquisitionOptions? tokenAcquisitionOptions) -> System.Collections.Generic.Dictionary? diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net10.0/PublicAPI.Unshipped.txt b/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net10.0/PublicAPI.Unshipped.txt index ec706b293..a562dd0dc 100644 --- a/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net10.0/PublicAPI.Unshipped.txt +++ b/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net10.0/PublicAPI.Unshipped.txt @@ -10,6 +10,10 @@ Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions.MicrosoftIdentityMessageHandlerOptions() -> void Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions.Scopes.get -> System.Collections.Generic.IList! Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions.Scopes.set -> void +Microsoft.Identity.Web.MsalMtlsHttpClientFactory +Microsoft.Identity.Web.MsalMtlsHttpClientFactory.GetHttpClient() -> System.Net.Http.HttpClient! +Microsoft.Identity.Web.MsalMtlsHttpClientFactory.GetHttpClient(System.Security.Cryptography.X509Certificates.X509Certificate2! x509Certificate2) -> System.Net.Http.HttpClient! +Microsoft.Identity.Web.MsalMtlsHttpClientFactory.MsalMtlsHttpClientFactory(System.Net.Http.IHttpClientFactory! httpClientFactory) -> void override Microsoft.Identity.Web.MicrosoftIdentityMessageHandler.SendAsync(System.Net.Http.HttpRequestMessage! request, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! static Microsoft.Identity.Web.HttpRequestMessageAuthenticationExtensions.GetAuthenticationOptions(this System.Net.Http.HttpRequestMessage! request) -> Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions? static Microsoft.Identity.Web.HttpRequestMessageAuthenticationExtensions.WithAuthenticationOptions(this System.Net.Http.HttpRequestMessage! request, Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions! options) -> System.Net.Http.HttpRequestMessage! diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net462/InternalAPI.Unshipped.txt b/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net462/InternalAPI.Unshipped.txt index 62d5c481d..a92211697 100644 --- a/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net462/InternalAPI.Unshipped.txt +++ b/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net462/InternalAPI.Unshipped.txt @@ -1,2 +1,7 @@ #nullable enable +const Microsoft.Identity.Web.IDWebErrorMessage.MissingTokenBindingCertificate = "IDW10111: A certificate, which is required for token binding, is missing in loaded credentials." -> string! +Microsoft.Identity.Web.DefaultAuthorizationHeaderProvider.CreateAuthorizationHeaderAsync(Microsoft.Identity.Abstractions.DownstreamApiOptions! downstreamApiOptions, System.Security.Claims.ClaimsPrincipal? claimsPrincipal = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task>! +Microsoft.Identity.Web.MergedOptions.IsTokenBinding.get -> bool +Microsoft.Identity.Web.MergedOptions.IsTokenBinding.set -> void +static Microsoft.Identity.Web.ConfidentialClientApplicationBuilderExtension.WithBindingCertificateAsync(this Microsoft.Identity.Client.ConfidentialClientApplicationBuilder! builder, System.Collections.Generic.IEnumerable! clientCredentials, Microsoft.Extensions.Logging.ILogger! logger, Microsoft.Identity.Abstractions.ICredentialsLoader! credentialsLoader, Microsoft.Identity.Abstractions.CredentialSourceLoaderParameters? credentialSourceLoaderParameters, bool isTokenBinding) -> System.Threading.Tasks.Task! static Microsoft.Identity.Web.TokenAcquisition.MergeExtraQueryParameters(Microsoft.Identity.Web.MergedOptions! mergedOptions, Microsoft.Identity.Web.TokenAcquisitionOptions? tokenAcquisitionOptions) -> System.Collections.Generic.Dictionary? diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net462/PublicAPI.Unshipped.txt b/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net462/PublicAPI.Unshipped.txt index 13496576f..d6ab94814 100644 --- a/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net462/PublicAPI.Unshipped.txt +++ b/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net462/PublicAPI.Unshipped.txt @@ -9,6 +9,10 @@ Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions.MicrosoftIdentityMessageHandlerOptions() -> void Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions.Scopes.get -> System.Collections.Generic.IList! Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions.Scopes.set -> void +Microsoft.Identity.Web.MsalMtlsHttpClientFactory +Microsoft.Identity.Web.MsalMtlsHttpClientFactory.GetHttpClient() -> System.Net.Http.HttpClient! +Microsoft.Identity.Web.MsalMtlsHttpClientFactory.GetHttpClient(System.Security.Cryptography.X509Certificates.X509Certificate2! x509Certificate2) -> System.Net.Http.HttpClient! +Microsoft.Identity.Web.MsalMtlsHttpClientFactory.MsalMtlsHttpClientFactory(System.Net.Http.IHttpClientFactory! httpClientFactory) -> void override Microsoft.Identity.Web.MicrosoftIdentityMessageHandler.SendAsync(System.Net.Http.HttpRequestMessage! request, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! static Microsoft.Identity.Web.HttpRequestMessageAuthenticationExtensions.GetAuthenticationOptions(this System.Net.Http.HttpRequestMessage! request) -> Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions? static Microsoft.Identity.Web.HttpRequestMessageAuthenticationExtensions.WithAuthenticationOptions(this System.Net.Http.HttpRequestMessage! request, Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions! options) -> System.Net.Http.HttpRequestMessage! diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net472/InternalAPI.Unshipped.txt b/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net472/InternalAPI.Unshipped.txt index 62d5c481d..a92211697 100644 --- a/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net472/InternalAPI.Unshipped.txt +++ b/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net472/InternalAPI.Unshipped.txt @@ -1,2 +1,7 @@ #nullable enable +const Microsoft.Identity.Web.IDWebErrorMessage.MissingTokenBindingCertificate = "IDW10111: A certificate, which is required for token binding, is missing in loaded credentials." -> string! +Microsoft.Identity.Web.DefaultAuthorizationHeaderProvider.CreateAuthorizationHeaderAsync(Microsoft.Identity.Abstractions.DownstreamApiOptions! downstreamApiOptions, System.Security.Claims.ClaimsPrincipal? claimsPrincipal = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task>! +Microsoft.Identity.Web.MergedOptions.IsTokenBinding.get -> bool +Microsoft.Identity.Web.MergedOptions.IsTokenBinding.set -> void +static Microsoft.Identity.Web.ConfidentialClientApplicationBuilderExtension.WithBindingCertificateAsync(this Microsoft.Identity.Client.ConfidentialClientApplicationBuilder! builder, System.Collections.Generic.IEnumerable! clientCredentials, Microsoft.Extensions.Logging.ILogger! logger, Microsoft.Identity.Abstractions.ICredentialsLoader! credentialsLoader, Microsoft.Identity.Abstractions.CredentialSourceLoaderParameters? credentialSourceLoaderParameters, bool isTokenBinding) -> System.Threading.Tasks.Task! static Microsoft.Identity.Web.TokenAcquisition.MergeExtraQueryParameters(Microsoft.Identity.Web.MergedOptions! mergedOptions, Microsoft.Identity.Web.TokenAcquisitionOptions? tokenAcquisitionOptions) -> System.Collections.Generic.Dictionary? diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net472/PublicAPI.Unshipped.txt b/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net472/PublicAPI.Unshipped.txt index 13496576f..d6ab94814 100644 --- a/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net472/PublicAPI.Unshipped.txt +++ b/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net472/PublicAPI.Unshipped.txt @@ -9,6 +9,10 @@ Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions.MicrosoftIdentityMessageHandlerOptions() -> void Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions.Scopes.get -> System.Collections.Generic.IList! Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions.Scopes.set -> void +Microsoft.Identity.Web.MsalMtlsHttpClientFactory +Microsoft.Identity.Web.MsalMtlsHttpClientFactory.GetHttpClient() -> System.Net.Http.HttpClient! +Microsoft.Identity.Web.MsalMtlsHttpClientFactory.GetHttpClient(System.Security.Cryptography.X509Certificates.X509Certificate2! x509Certificate2) -> System.Net.Http.HttpClient! +Microsoft.Identity.Web.MsalMtlsHttpClientFactory.MsalMtlsHttpClientFactory(System.Net.Http.IHttpClientFactory! httpClientFactory) -> void override Microsoft.Identity.Web.MicrosoftIdentityMessageHandler.SendAsync(System.Net.Http.HttpRequestMessage! request, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! static Microsoft.Identity.Web.HttpRequestMessageAuthenticationExtensions.GetAuthenticationOptions(this System.Net.Http.HttpRequestMessage! request) -> Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions? static Microsoft.Identity.Web.HttpRequestMessageAuthenticationExtensions.WithAuthenticationOptions(this System.Net.Http.HttpRequestMessage! request, Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions! options) -> System.Net.Http.HttpRequestMessage! diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net8.0/InternalAPI.Unshipped.txt b/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net8.0/InternalAPI.Unshipped.txt index 62d5c481d..a92211697 100644 --- a/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net8.0/InternalAPI.Unshipped.txt +++ b/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net8.0/InternalAPI.Unshipped.txt @@ -1,2 +1,7 @@ #nullable enable +const Microsoft.Identity.Web.IDWebErrorMessage.MissingTokenBindingCertificate = "IDW10111: A certificate, which is required for token binding, is missing in loaded credentials." -> string! +Microsoft.Identity.Web.DefaultAuthorizationHeaderProvider.CreateAuthorizationHeaderAsync(Microsoft.Identity.Abstractions.DownstreamApiOptions! downstreamApiOptions, System.Security.Claims.ClaimsPrincipal? claimsPrincipal = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task>! +Microsoft.Identity.Web.MergedOptions.IsTokenBinding.get -> bool +Microsoft.Identity.Web.MergedOptions.IsTokenBinding.set -> void +static Microsoft.Identity.Web.ConfidentialClientApplicationBuilderExtension.WithBindingCertificateAsync(this Microsoft.Identity.Client.ConfidentialClientApplicationBuilder! builder, System.Collections.Generic.IEnumerable! clientCredentials, Microsoft.Extensions.Logging.ILogger! logger, Microsoft.Identity.Abstractions.ICredentialsLoader! credentialsLoader, Microsoft.Identity.Abstractions.CredentialSourceLoaderParameters? credentialSourceLoaderParameters, bool isTokenBinding) -> System.Threading.Tasks.Task! static Microsoft.Identity.Web.TokenAcquisition.MergeExtraQueryParameters(Microsoft.Identity.Web.MergedOptions! mergedOptions, Microsoft.Identity.Web.TokenAcquisitionOptions? tokenAcquisitionOptions) -> System.Collections.Generic.Dictionary? diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net8.0/PublicAPI.Unshipped.txt b/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net8.0/PublicAPI.Unshipped.txt index 13496576f..d6ab94814 100644 --- a/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net8.0/PublicAPI.Unshipped.txt +++ b/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net8.0/PublicAPI.Unshipped.txt @@ -9,6 +9,10 @@ Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions.MicrosoftIdentityMessageHandlerOptions() -> void Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions.Scopes.get -> System.Collections.Generic.IList! Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions.Scopes.set -> void +Microsoft.Identity.Web.MsalMtlsHttpClientFactory +Microsoft.Identity.Web.MsalMtlsHttpClientFactory.GetHttpClient() -> System.Net.Http.HttpClient! +Microsoft.Identity.Web.MsalMtlsHttpClientFactory.GetHttpClient(System.Security.Cryptography.X509Certificates.X509Certificate2! x509Certificate2) -> System.Net.Http.HttpClient! +Microsoft.Identity.Web.MsalMtlsHttpClientFactory.MsalMtlsHttpClientFactory(System.Net.Http.IHttpClientFactory! httpClientFactory) -> void override Microsoft.Identity.Web.MicrosoftIdentityMessageHandler.SendAsync(System.Net.Http.HttpRequestMessage! request, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! static Microsoft.Identity.Web.HttpRequestMessageAuthenticationExtensions.GetAuthenticationOptions(this System.Net.Http.HttpRequestMessage! request) -> Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions? static Microsoft.Identity.Web.HttpRequestMessageAuthenticationExtensions.WithAuthenticationOptions(this System.Net.Http.HttpRequestMessage! request, Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions! options) -> System.Net.Http.HttpRequestMessage! diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net9.0/InternalAPI.Unshipped.txt b/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net9.0/InternalAPI.Unshipped.txt index 62d5c481d..a92211697 100644 --- a/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net9.0/InternalAPI.Unshipped.txt +++ b/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net9.0/InternalAPI.Unshipped.txt @@ -1,2 +1,7 @@ #nullable enable +const Microsoft.Identity.Web.IDWebErrorMessage.MissingTokenBindingCertificate = "IDW10111: A certificate, which is required for token binding, is missing in loaded credentials." -> string! +Microsoft.Identity.Web.DefaultAuthorizationHeaderProvider.CreateAuthorizationHeaderAsync(Microsoft.Identity.Abstractions.DownstreamApiOptions! downstreamApiOptions, System.Security.Claims.ClaimsPrincipal? claimsPrincipal = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task>! +Microsoft.Identity.Web.MergedOptions.IsTokenBinding.get -> bool +Microsoft.Identity.Web.MergedOptions.IsTokenBinding.set -> void +static Microsoft.Identity.Web.ConfidentialClientApplicationBuilderExtension.WithBindingCertificateAsync(this Microsoft.Identity.Client.ConfidentialClientApplicationBuilder! builder, System.Collections.Generic.IEnumerable! clientCredentials, Microsoft.Extensions.Logging.ILogger! logger, Microsoft.Identity.Abstractions.ICredentialsLoader! credentialsLoader, Microsoft.Identity.Abstractions.CredentialSourceLoaderParameters? credentialSourceLoaderParameters, bool isTokenBinding) -> System.Threading.Tasks.Task! static Microsoft.Identity.Web.TokenAcquisition.MergeExtraQueryParameters(Microsoft.Identity.Web.MergedOptions! mergedOptions, Microsoft.Identity.Web.TokenAcquisitionOptions? tokenAcquisitionOptions) -> System.Collections.Generic.Dictionary? diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net9.0/PublicAPI.Unshipped.txt b/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net9.0/PublicAPI.Unshipped.txt index 13496576f..d6ab94814 100644 --- a/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net9.0/PublicAPI.Unshipped.txt +++ b/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/net9.0/PublicAPI.Unshipped.txt @@ -9,6 +9,10 @@ Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions.MicrosoftIdentityMessageHandlerOptions() -> void Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions.Scopes.get -> System.Collections.Generic.IList! Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions.Scopes.set -> void +Microsoft.Identity.Web.MsalMtlsHttpClientFactory +Microsoft.Identity.Web.MsalMtlsHttpClientFactory.GetHttpClient() -> System.Net.Http.HttpClient! +Microsoft.Identity.Web.MsalMtlsHttpClientFactory.GetHttpClient(System.Security.Cryptography.X509Certificates.X509Certificate2! x509Certificate2) -> System.Net.Http.HttpClient! +Microsoft.Identity.Web.MsalMtlsHttpClientFactory.MsalMtlsHttpClientFactory(System.Net.Http.IHttpClientFactory! httpClientFactory) -> void override Microsoft.Identity.Web.MicrosoftIdentityMessageHandler.SendAsync(System.Net.Http.HttpRequestMessage! request, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! static Microsoft.Identity.Web.HttpRequestMessageAuthenticationExtensions.GetAuthenticationOptions(this System.Net.Http.HttpRequestMessage! request) -> Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions? static Microsoft.Identity.Web.HttpRequestMessageAuthenticationExtensions.WithAuthenticationOptions(this System.Net.Http.HttpRequestMessage! request, Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions! options) -> System.Net.Http.HttpRequestMessage! diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/netstandard2.0/InternalAPI.Unshipped.txt b/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/netstandard2.0/InternalAPI.Unshipped.txt index 62d5c481d..a92211697 100644 --- a/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/netstandard2.0/InternalAPI.Unshipped.txt +++ b/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/netstandard2.0/InternalAPI.Unshipped.txt @@ -1,2 +1,7 @@ #nullable enable +const Microsoft.Identity.Web.IDWebErrorMessage.MissingTokenBindingCertificate = "IDW10111: A certificate, which is required for token binding, is missing in loaded credentials." -> string! +Microsoft.Identity.Web.DefaultAuthorizationHeaderProvider.CreateAuthorizationHeaderAsync(Microsoft.Identity.Abstractions.DownstreamApiOptions! downstreamApiOptions, System.Security.Claims.ClaimsPrincipal? claimsPrincipal = null, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task>! +Microsoft.Identity.Web.MergedOptions.IsTokenBinding.get -> bool +Microsoft.Identity.Web.MergedOptions.IsTokenBinding.set -> void +static Microsoft.Identity.Web.ConfidentialClientApplicationBuilderExtension.WithBindingCertificateAsync(this Microsoft.Identity.Client.ConfidentialClientApplicationBuilder! builder, System.Collections.Generic.IEnumerable! clientCredentials, Microsoft.Extensions.Logging.ILogger! logger, Microsoft.Identity.Abstractions.ICredentialsLoader! credentialsLoader, Microsoft.Identity.Abstractions.CredentialSourceLoaderParameters? credentialSourceLoaderParameters, bool isTokenBinding) -> System.Threading.Tasks.Task! static Microsoft.Identity.Web.TokenAcquisition.MergeExtraQueryParameters(Microsoft.Identity.Web.MergedOptions! mergedOptions, Microsoft.Identity.Web.TokenAcquisitionOptions? tokenAcquisitionOptions) -> System.Collections.Generic.Dictionary? diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/netstandard2.0/PublicAPI.Unshipped.txt b/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/netstandard2.0/PublicAPI.Unshipped.txt index 13496576f..d6ab94814 100644 --- a/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/netstandard2.0/PublicAPI.Unshipped.txt +++ b/src/Microsoft.Identity.Web.TokenAcquisition/PublicAPI/netstandard2.0/PublicAPI.Unshipped.txt @@ -9,6 +9,10 @@ Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions.MicrosoftIdentityMessageHandlerOptions() -> void Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions.Scopes.get -> System.Collections.Generic.IList! Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions.Scopes.set -> void +Microsoft.Identity.Web.MsalMtlsHttpClientFactory +Microsoft.Identity.Web.MsalMtlsHttpClientFactory.GetHttpClient() -> System.Net.Http.HttpClient! +Microsoft.Identity.Web.MsalMtlsHttpClientFactory.GetHttpClient(System.Security.Cryptography.X509Certificates.X509Certificate2! x509Certificate2) -> System.Net.Http.HttpClient! +Microsoft.Identity.Web.MsalMtlsHttpClientFactory.MsalMtlsHttpClientFactory(System.Net.Http.IHttpClientFactory! httpClientFactory) -> void override Microsoft.Identity.Web.MicrosoftIdentityMessageHandler.SendAsync(System.Net.Http.HttpRequestMessage! request, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task! static Microsoft.Identity.Web.HttpRequestMessageAuthenticationExtensions.GetAuthenticationOptions(this System.Net.Http.HttpRequestMessage! request) -> Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions? static Microsoft.Identity.Web.HttpRequestMessageAuthenticationExtensions.WithAuthenticationOptions(this System.Net.Http.HttpRequestMessage! request, Microsoft.Identity.Web.MicrosoftIdentityMessageHandlerOptions! options) -> System.Net.Http.HttpRequestMessage! diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/ServiceCollectionExtensions.cs b/src/Microsoft.Identity.Web.TokenAcquisition/ServiceCollectionExtensions.cs index 71368cc29..49c8a50fc 100644 --- a/src/Microsoft.Identity.Web.TokenAcquisition/ServiceCollectionExtensions.cs +++ b/src/Microsoft.Identity.Web.TokenAcquisition/ServiceCollectionExtensions.cs @@ -3,6 +3,7 @@ using System; using System.Linq; +using System.Net.Http; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Options; @@ -63,14 +64,14 @@ public static IServiceCollection AddTokenAcquisition( services.TryAddSingleton, ConfidentialClientApplicationOptionsMerger>(); } - ServiceDescriptor? tokenAcquisitionService = services.FirstOrDefault(s => s.ServiceType == typeof(ITokenAcquisition)); + ServiceDescriptor? tokenAcquisitionService = services.FirstOrDefault(s => s.ServiceType == typeof(ITokenAcquisition)); ServiceDescriptor? tokenAcquisitionInternalService = services.FirstOrDefault(s => s.ServiceType == typeof(ITokenAcquisitionInternal)); ServiceDescriptor? tokenAcquisitionhost = services.FirstOrDefault(s => s.ServiceType == typeof(ITokenAcquisitionHost)); ServiceDescriptor? authenticationHeaderCreator = services.FirstOrDefault(s => s.ServiceType == typeof(IAuthorizationHeaderProvider)); ServiceDescriptor? tokenAcquirerFactory = services.FirstOrDefault(s => s.ServiceType == typeof(ITokenAcquirerFactory)); ServiceDescriptor? authSchemeInfoProvider = services.FirstOrDefault(s => s.ServiceType == typeof(Abstractions.IAuthenticationSchemeInformationProvider)); - - if (tokenAcquisitionService != null && tokenAcquisitionInternalService != null && + + if (tokenAcquisitionService != null && tokenAcquisitionInternalService != null && tokenAcquisitionhost != null && authenticationHeaderCreator != null && authSchemeInfoProvider != null) { if (isTokenAcquisitionSingleton ^ (tokenAcquisitionService.Lifetime == ServiceLifetime.Singleton)) diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquirer.cs b/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquirer.cs index 9d0d1ce1e..26676da84 100644 --- a/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquirer.cs +++ b/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquirer.cs @@ -65,7 +65,10 @@ async Task ITokenAcquirer.GetTokenForAppAsync(string scope, result.IdToken, result.Scopes, result.CorrelationId, - result.TokenType); + result.TokenType) + { + BindingCertificate = result.BindingCertificate + }; } private static TokenAcquisitionOptions? GetEffectiveTokenAcquisitionOptions(AcquireTokenOptions? tokenAcquisitionOptions, string? authenticationScheme, CancellationToken cancellationToken) diff --git a/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquisition.cs b/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquisition.cs index 45b0a4518..6d3b0c989 100644 --- a/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquisition.cs +++ b/src/Microsoft.Identity.Web.TokenAcquisition/TokenAcquisition.cs @@ -52,6 +52,8 @@ class OAuthConstants private readonly ConcurrentDictionary _applicationsByAuthorityClientId = new(); private readonly ConcurrentDictionary _appSemaphores = new(); + private const string RequestBoundTokenParameterName = "RequestBoundToken"; + private bool _retryClientCertificate; protected readonly IMsalHttpClientFactory _httpClientFactory; protected readonly ILogger _logger; @@ -104,7 +106,7 @@ public TokenAcquisition( ICredentialsLoader credentialsLoader) { _tokenCacheProvider = tokenCacheProvider; - _httpClientFactory = serviceProvider.GetService() ?? new MsalAspNetCoreHttpClientFactory(httpClientFactory); + _httpClientFactory = serviceProvider.GetService() ?? new MsalMtlsHttpClientFactory(httpClientFactory); _logger = logger; _serviceProvider = serviceProvider; _tokenAcquisitionHost = tokenAcquisitionHost; @@ -587,6 +589,11 @@ public async Task GetAuthenticationResultForAppAsync( .AcquireTokenForClient(new[] { scope }.Except(_scopesRequestedByMsal)) .WithSendX5C(mergedOptions.SendX5C); + if (mergedOptions.IsTokenBinding) + { + builder.WithMtlsProofOfPossession(); + } + if (addInOptions != null) { addInOptions.InvokeOnBeforeTokenAcquisitionForApp(builder, tokenAcquisitionOptions); @@ -747,6 +754,10 @@ private MergedOptions GetMergedOptions(string? authenticationScheme, TokenAcquis mergedOptions = _tokenAcquisitionHost.GetOptions(authenticationScheme ?? tokenAcquisitionOptions?.AuthenticationOptionsName, out _); } + mergedOptions.IsTokenBinding = tokenAcquisitionOptions?.ExtraParameters?.TryGetValue(RequestBoundTokenParameterName, out var requestBoundTokenValue) == true + && requestBoundTokenValue is bool requestBoundToken + && requestBoundToken; + return mergedOptions; } @@ -987,11 +998,23 @@ private async Task BuildConfidentialClientApplic try { - await builder.WithClientCredentialsAsync( - mergedOptions.ClientCredentials!, - _logger, - _credentialsLoader, - new CredentialSourceLoaderParameters(mergedOptions.ClientId!, authority)); + if (mergedOptions.IsTokenBinding) + { + await builder.WithBindingCertificateAsync( + mergedOptions.ClientCredentials!, + _logger, + _credentialsLoader, + new CredentialSourceLoaderParameters(mergedOptions.ClientId!, authority), + isTokenBinding: true); + } + else + { + await builder.WithClientCredentialsAsync( + mergedOptions.ClientCredentials!, + _logger, + _credentialsLoader, + new CredentialSourceLoaderParameters(mergedOptions.ClientId!, authority)); + } } catch (ArgumentException ex) when (ex.Message == IDWebErrorMessage.ClientCertificatesHaveExpiredOrCannotBeLoaded) { diff --git a/tests/DevApps/MtlsPop/MtlsPopClient/MtlsPopClient.csproj b/tests/DevApps/MtlsPop/MtlsPopClient/MtlsPopClient.csproj new file mode 100644 index 000000000..4a7b2c7a0 --- /dev/null +++ b/tests/DevApps/MtlsPop/MtlsPopClient/MtlsPopClient.csproj @@ -0,0 +1,31 @@ + + + + Exe + net8.0 + enable + enable + true + + + + + + + + + + + + + + + + + + + Always + + + + diff --git a/tests/DevApps/MtlsPop/MtlsPopClient/Program.cs b/tests/DevApps/MtlsPop/MtlsPopClient/Program.cs new file mode 100644 index 000000000..983f40764 --- /dev/null +++ b/tests/DevApps/MtlsPop/MtlsPopClient/Program.cs @@ -0,0 +1,55 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Graph; +using Microsoft.Identity.Abstractions; +using Microsoft.Identity.Web; + +namespace MtlsPopSample +{ + public class Program + { + public static async Task Main(string[] args) + { + var tokenAcquirerFactory = TokenAcquirerFactory.GetDefaultInstance(); + + tokenAcquirerFactory.Services.AddLogging(builder => + builder.AddConsole().SetMinimumLevel(LogLevel.Information)); + + tokenAcquirerFactory.Services.AddDownstreamApi("WebApi", + tokenAcquirerFactory.Configuration.GetSection("WebApi")); + + tokenAcquirerFactory.Services.AddMicrosoftGraph(); + + var sp = tokenAcquirerFactory.Build(); + + Console.WriteLine("Scenario 1: calling web API with mTLS PoP token..."); + var webApi = sp.GetRequiredService(); + var result = await webApi.GetForAppAsync>("WebApi", + options => + { + options.AcquireTokenOptions.ExtraParameters ??= new Dictionary(); + options.AcquireTokenOptions.ExtraParameters["RequestBoundToken"] = true; // mTLS PoP + }).ConfigureAwait(false); + + Console.WriteLine("Web API result:"); + foreach (var forecast in result!) + { + Console.WriteLine($"{forecast.Date}: {forecast.Summary} - {forecast.TemperatureC}C/{forecast.TemperatureF}F"); + } + + Console.WriteLine(); + Console.WriteLine("Scenario 2: Calling Microsoft Graph with Bearer (non mTLS PoP) token..."); + var graphServiceClient = sp.GetRequiredService(); + var users = await graphServiceClient.Users + .Request() + .WithAppOnly() + .GetAsync(); + + Console.WriteLine("Microsoft Graph result:"); + Console.WriteLine($"{users.Count} users"); + } + } +} diff --git a/tests/DevApps/MtlsPop/MtlsPopClient/WeatherForecast.cs b/tests/DevApps/MtlsPop/MtlsPopClient/WeatherForecast.cs new file mode 100644 index 000000000..77f705c3e --- /dev/null +++ b/tests/DevApps/MtlsPop/MtlsPopClient/WeatherForecast.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace MtlsPopSample +{ + public class WeatherForecast + { + public DateOnly Date { get; set; } + + public int TemperatureC { get; set; } + + public int TemperatureF => 32 + (int)(TemperatureC / 0.5556); + + public string? Summary { get; set; } + } +} diff --git a/tests/DevApps/MtlsPop/MtlsPopClient/appsettings.json b/tests/DevApps/MtlsPop/MtlsPopClient/appsettings.json new file mode 100644 index 000000000..dfd9ee81f --- /dev/null +++ b/tests/DevApps/MtlsPop/MtlsPopClient/appsettings.json @@ -0,0 +1,22 @@ +{ + "AzureAd": { + "Instance": "https://login.microsoftonline.com/", + "TenantId": "bea21ebe-8b64-4d06-9f6d-6a889b120a7c", + "ClientId": "163ffef9-a313-45b4-ab2f-c7e2f5e0e23e", + "AzureRegion": "westus3", + "ClientCredentials": [ + { + "SourceType": "StoreWithDistinguishedName", + "CertificateStorePath": "CurrentUser/My", + "CertificateDistinguishedName": "CN=LabAuth.MSIDLab.com" + } + ], + "SendX5c": true + }, + "WebApi": { + "BaseUrl": "https://localhost:7060/", + "RelativePath": "WeatherForecast", + "RequestAppToken": true, + "Scopes": [ "https://graph.microsoft.com/.default" ] + } +} diff --git a/tests/DevApps/MtlsPop/MtlsPopWebApi/Controllers/WeatherForecastController.cs b/tests/DevApps/MtlsPop/MtlsPopWebApi/Controllers/WeatherForecastController.cs new file mode 100644 index 000000000..12d490a26 --- /dev/null +++ b/tests/DevApps/MtlsPop/MtlsPopWebApi/Controllers/WeatherForecastController.cs @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Mvc; +using Microsoft.Identity.Web.Resource; + +namespace MtlsPopSample.Controllers +{ + [ApiController] + [Route("[controller]")] + [Authorize(Policy = "MtlsPop")] + public class WeatherForecastController : ControllerBase + { + private static readonly string[] Summaries = new[] + { + "Freezing", "Bracing", "Chilly", "Cool", "Mild", "Warm", "Balmy", "Hot", "Sweltering", "Scorching" + }; + + private readonly ILogger _logger; + + public WeatherForecastController(ILogger logger) + { + _logger = logger; + } + + [HttpGet(Name = "GetWeatherForecast")] + public IEnumerable Get() + { + return Enumerable.Range(1, 5).Select(index => new WeatherForecast + { + Date = DateOnly.FromDateTime(DateTime.Now.AddDays(index)), + TemperatureC = Random.Shared.Next(-20, 55), + Summary = Summaries[Random.Shared.Next(Summaries.Length)] + }) + .ToArray(); + } + } +} diff --git a/tests/DevApps/MtlsPop/MtlsPopWebApi/MtlsPopAuthorizationHandler.cs b/tests/DevApps/MtlsPop/MtlsPopWebApi/MtlsPopAuthorizationHandler.cs new file mode 100644 index 000000000..d57751087 --- /dev/null +++ b/tests/DevApps/MtlsPop/MtlsPopWebApi/MtlsPopAuthorizationHandler.cs @@ -0,0 +1,120 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using Microsoft.AspNetCore.Authorization; +using Microsoft.IdentityModel.Tokens; +using System.IdentityModel.Tokens.Jwt; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using System.Text.Json; + +namespace MtlsPopSample +{ + public class MtlsPopRequirement : IAuthorizationRequirement + { + } + + public class MtlsPopAuthorizationHandler : AuthorizationHandler + { + private const string ProtocolName = "MTLS_POP"; + + private readonly ILogger _logger; + + public MtlsPopAuthorizationHandler(ILogger logger) + { + _logger = logger; + } + + protected override Task HandleRequirementAsync(AuthorizationHandlerContext context, MtlsPopRequirement requirement) + { + _logger.LogInformation("MtlsPopAuthorizationHandler invoked"); + + if (context.Resource is not HttpContext httpContext) + { + _logger.LogWarning("Resource is not HttpContext"); + return Task.CompletedTask; + } + + var authHeader = httpContext.Request.Headers.Authorization.FirstOrDefault(); + var authToken = !string.IsNullOrEmpty(authHeader) && authHeader.StartsWith($"{ProtocolName} ", StringComparison.OrdinalIgnoreCase) + ? authHeader.Substring($"{ProtocolName} ".Length).Trim() + : null; + + if (string.IsNullOrEmpty(authToken)) + { + _logger.LogWarning("No auth token found"); + return Task.CompletedTask; + } + + try + { + var handler = new JwtSecurityTokenHandler(); + var token = handler.ReadJwtToken(authToken); + + var cnfClaim = token.Claims.FirstOrDefault(c => c.Type == "cnf"); + if (cnfClaim == null) + { + _logger.LogWarning("mTLS PoP token does not contain 'cnf' claim"); + context.Fail(new AuthorizationFailureReason(this, "Missing 'cnf' claim in token")); + return Task.CompletedTask; + } + + _logger.LogInformation($"The 'cnf' claim value: {cnfClaim.Value}"); + + var cnfJson = JsonDocument.Parse(cnfClaim.Value); + if (!cnfJson.RootElement.TryGetProperty("x5t#S256", out var x5tS256Element)) + { + _logger.LogWarning("The 'cnf' claim does not contain 'x5t#S256' property"); + context.Fail(new AuthorizationFailureReason(this, "Missing 'x5t#S256' property in mTLS PoP 'cnf' claim")); + return Task.CompletedTask; + } + + var x5tS256 = x5tS256Element.GetString(); + if (string.IsNullOrEmpty(x5tS256)) + { + _logger.LogWarning("The 'cnf' claim contains an empty 'x5t#S256' property"); + context.Fail(new AuthorizationFailureReason(this, "Empty 'x5t#S256' property in mTLS PoP 'cnf' claim")); + return Task.CompletedTask; + } + + _logger.LogInformation($"Token bound to certificate with x5t#S256: {x5tS256}"); + + var clientCert = httpContext.Connection.ClientCertificate; + if (clientCert != null) + { + var certThumbprint = GetCertificateThumbprint(clientCert); + _logger.LogInformation($"Client cert thumprint: {certThumbprint}"); + + if (!string.Equals(certThumbprint, x5tS256, StringComparison.OrdinalIgnoreCase)) + { + _logger.LogWarning($"Mismatch between cert thumbprint and 'x5t#S256' from mTLS PoP 'cnf' claim property: cert thumbprint - {certThumbprint}, x5t#S256 = {x5tS256}"); + context.Fail(new AuthorizationFailureReason(this, "Cert thumbprint and mTLS PoP 'cnf' claim 'x5t#S256' property mismatch")); + return Task.CompletedTask; + } + + _logger.LogInformation("mTLS PoP token validation successful"); + } + else + { + _logger.LogInformation("No client certificate in request"); + } + + context.Succeed(requirement); + } + catch (Exception ex) + { + _logger.LogError(ex, "Error validating mTLS PoP token"); + context.Fail(new AuthorizationFailureReason(this, $"mTLS PoP validation error: {ex.Message}")); + } + + return Task.CompletedTask; + } + + private static string GetCertificateThumbprint(X509Certificate2 certificate) + { + using var sha256 = SHA256.Create(); + var hash = sha256.ComputeHash(certificate.RawData); + return Base64UrlEncoder.Encode(hash); + } + } +} diff --git a/tests/DevApps/MtlsPop/MtlsPopWebApi/MtlsPopWebApi.csproj b/tests/DevApps/MtlsPop/MtlsPopWebApi/MtlsPopWebApi.csproj new file mode 100644 index 000000000..8b593378f --- /dev/null +++ b/tests/DevApps/MtlsPop/MtlsPopWebApi/MtlsPopWebApi.csproj @@ -0,0 +1,22 @@ + + + + net8.0 + enable + enable + true + + + + + + + + + + + + + + + diff --git a/tests/DevApps/MtlsPop/MtlsPopWebApi/Program.cs b/tests/DevApps/MtlsPop/MtlsPopWebApi/Program.cs new file mode 100644 index 000000000..1c5903ac4 --- /dev/null +++ b/tests/DevApps/MtlsPop/MtlsPopWebApi/Program.cs @@ -0,0 +1,172 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using Microsoft.AspNetCore.Authentication.JwtBearer; +using Microsoft.AspNetCore.Authorization; +using Microsoft.Identity.Web; + +namespace MtlsPopSample +{ + public class Program + { + public static void Main(string[] args) + { + var builder = WebApplication.CreateBuilder(args); + + builder.Services.AddControllers(); + + // Learn more about configuring OpenAPI at https://learn.microsoft.com/aspnet/core/fundamentals/openapi/aspnetcore-openapi + builder.Services.AddEndpointsApiExplorer(); + + builder.Services.AddMicrosoftIdentityWebApiAuthentication(builder.Configuration); + + builder.Services.AddAuthorization(options => + { + options.AddPolicy("MtlsPop", policy => + policy.Requirements.Add(new MtlsPopRequirement())); + }); + + builder.Services.AddSingleton(); + + var app = builder.Build(); + + app.UseHttpsRedirection(); + + app.UseAuthentication(); + app.UseAuthorization(); + + app.MapControllers(); + + app.Run(); + } + } +} + +/* + +using Microsoft.AspNetCore.Authentication.JwtBearer; +using Microsoft.Identity.Web; +using Microsoft.IdentityModel.Tokens; +using System.IdentityModel.Tokens.Jwt; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; + +var builder = WebApplication.CreateBuilder(args); + +// Add services to the container. +builder.Services.AddControllers(); +builder.Services.AddEndpointsApiExplorer(); + +// Configure Microsoft Identity Web API authentication +builder.Services.AddMicrosoftIdentityWebApiAuthentication(builder.Configuration, "AzureAd", "Bearer", true); + +// Configure custom mTLS PoP token validation +builder.Services.Configure(JwtBearerDefaults.AuthenticationScheme, options => +{ + var existingOnTokenValidated = options.Events?.OnTokenValidated; + + options.Events ??= new JwtBearerEvents(); + + options.Events.OnTokenValidated = async context => + { + // Call the existing handler first (if any) + if (existingOnTokenValidated != null) + { + await existingOnTokenValidated(context); + } + + var logger = context.HttpContext.RequestServices.GetRequiredService>(); + + try + { + // Get the JWT token + var token = context.SecurityToken as JwtSecurityToken; + if (token == null) + { + logger.LogWarning("SecurityToken is not a JwtSecurityToken"); + context.Fail("Invalid token type"); + return; + } + + // Check for cnf claim (confirmation claim for mTLS PoP) + var cnfClaim = token.Claims.FirstOrDefault(c => c.Type == "cnf"); + if (cnfClaim == null) + { + logger.LogWarning("Token does not contain cnf claim - not an mTLS PoP token"); + context.Fail("Missing cnf claim - mTLS PoP token required"); + return; + } + + logger.LogInformation("Found cnf claim in token: {CnfValue}", cnfClaim.Value); + + // Parse the cnf claim to get x5t#S256 + var cnfJson = System.Text.Json.JsonDocument.Parse(cnfClaim.Value); + if (!cnfJson.RootElement.TryGetProperty("x5t#S256", out var x5tS256Element)) + { + logger.LogWarning("cnf claim does not contain x5t#S256 property"); + context.Fail("Invalid cnf claim - missing x5t#S256"); + return; + } + + var x5tS256 = x5tS256Element.GetString(); + if (string.IsNullOrEmpty(x5tS256)) + { + logger.LogWarning("x5t#S256 value is empty"); + context.Fail("Invalid x5t#S256 value"); + return; + } + + logger.LogInformation("Token bound to certificate with x5t#S256: {X5tS256}", x5tS256); + + // Get client certificate from request (if mTLS is configured) + var clientCert = context.HttpContext.Connection.ClientCertificate; + if (clientCert != null) + { + // Compute the SHA256 thumbprint of the client certificate + var certThumbprint = ComputeCertificateThumbprint(clientCert); + + logger.LogInformation("Client certificate x5t#S256: {CertThumbprint}", certThumbprint); + + // Verify that the certificate in the request matches the one in the token + if (!string.Equals(certThumbprint, x5tS256, StringComparison.OrdinalIgnoreCase)) + { + logger.LogWarning("Certificate mismatch - Token x5t#S256: {TokenThumbprint}, Client cert x5t#S256: {CertThumbprint}", + x5tS256, certThumbprint); + context.Fail("Certificate thumbprint mismatch"); + return; + } + + logger.LogInformation("mTLS PoP token validation successful - certificate binding verified"); + } + else + { + logger.LogInformation("No client certificate in request - mTLS PoP token accepted based on cnf claim presence"); + } + } + catch (Exception ex) + { + logger.LogError(ex, "Error validating mTLS PoP token"); + context.Fail($"mTLS PoP validation error: {ex.Message}"); + } + + await Task.CompletedTask; + }; +}); + +var app = builder.Build(); + +app.UseHttpsRedirection(); +app.UseAuthentication(); +app.UseAuthorization(); + +app.MapControllers(); + +app.Run(); + +static string ComputeCertificateThumbprint(X509Certificate2 certificate) +{ + using var sha256 = SHA256.Create(); + var hash = sha256.ComputeHash(certificate.RawData); + return Base64UrlEncoder.Encode(hash); +} +*/ diff --git a/tests/DevApps/MtlsPop/MtlsPopWebApi/Properties/launchSettings.json b/tests/DevApps/MtlsPop/MtlsPopWebApi/Properties/launchSettings.json new file mode 100644 index 000000000..d0f30c895 --- /dev/null +++ b/tests/DevApps/MtlsPop/MtlsPopWebApi/Properties/launchSettings.json @@ -0,0 +1,14 @@ +{ + "$schema": "http://json.schemastore.org/launchsettings.json", + "profiles": { + "https": { + "commandName": "Project", + "dotnetRunMessages": true, + "launchBrowser": false, + "applicationUrl": "https://localhost:7060", + "environmentVariables": { + "ASPNETCORE_ENVIRONMENT": "Development" + } + } + } +} diff --git a/tests/DevApps/MtlsPop/MtlsPopWebApi/WeatherForecast.cs b/tests/DevApps/MtlsPop/MtlsPopWebApi/WeatherForecast.cs new file mode 100644 index 000000000..77f705c3e --- /dev/null +++ b/tests/DevApps/MtlsPop/MtlsPopWebApi/WeatherForecast.cs @@ -0,0 +1,16 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +namespace MtlsPopSample +{ + public class WeatherForecast + { + public DateOnly Date { get; set; } + + public int TemperatureC { get; set; } + + public int TemperatureF => 32 + (int)(TemperatureC / 0.5556); + + public string? Summary { get; set; } + } +} diff --git a/tests/DevApps/MtlsPop/MtlsPopWebApi/appsettings.json b/tests/DevApps/MtlsPop/MtlsPopWebApi/appsettings.json new file mode 100644 index 000000000..b14406255 --- /dev/null +++ b/tests/DevApps/MtlsPop/MtlsPopWebApi/appsettings.json @@ -0,0 +1,24 @@ +{ + "AzureAd": { + "Instance": "https://login.microsoftonline.com/", + "TenantId": "f645ad92-e38d-4d1a-b510-d1b09a74a8ca", + "ClientId": "556d438d-2f4b-4add-9713-ede4e5f5d7da", + "Scopes": "https://graph.microsoft.com/.default" + }, + "Logging": { + "LogLevel": { + "Default": "Trace", + "Microsoft.AspNetCore": "Trace" + } + }, + "AllowedHosts": "*", + "Kestrel": { + "Endpoints": { + "HttpsClientCert": { + "Url": "https://localhost:7060", + "ClientCertificateMode": "RequireCertificate", + "CheckCertificateRevocation": true + } + } + } +} diff --git a/tests/E2E Tests/TokenAcquirerTests/TokenAcquirer.cs b/tests/E2E Tests/TokenAcquirerTests/TokenAcquirer.cs index 84687c2a9..fa6b8d24c 100644 --- a/tests/E2E Tests/TokenAcquirerTests/TokenAcquirer.cs +++ b/tests/E2E Tests/TokenAcquirerTests/TokenAcquirer.cs @@ -3,9 +3,12 @@ using System; using System.Collections.Concurrent; +using System.Collections.Generic; using System.Linq; using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; +using System.Text; +using System.Text.Json; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.DependencyInjection; @@ -406,6 +409,61 @@ public async Task AcquireTokenWithMs10AtPop_ClientCredentialsAsync() Assert.NotNull(result.AccessToken); } + [IgnoreOnAzureDevopsFact] + // [Fact] + public async Task AcquireTokenWithMtlsPop_WithBindingCertificate_ReturnsMtlsPopToken() + { + // Arrange + TokenAcquirerFactoryTesting.ResetTokenAcquirerFactoryInTest(); + TokenAcquirerFactory tokenAcquirerFactory = TokenAcquirerFactory.GetDefaultInstance(); + IServiceCollection services = tokenAcquirerFactory.Services; + + services.Configure(s_optionName, option => + { + option.Instance = "https://login.microsoftonline.com/"; + option.TenantId = "bea21ebe-8b64-4d06-9f6d-6a889b120a7c"; + option.ClientId = "163ffef9-a313-45b4-ab2f-c7e2f5e0e23e"; + option.AzureRegion = "westus3"; + option.ClientCredentials = s_clientCredentials; + }); + + services.AddInMemoryTokenCaches(); + + var serviceProvider = tokenAcquirerFactory.Build(); + ITokenAcquirer tokenAcquirer = tokenAcquirerFactory.GetTokenAcquirer(s_optionName); + + var tokenAcquisitionOptions = new TokenAcquisitionOptions + { + ExtraParameters = new Dictionary + { + { "RequestBoundToken", true } // mTLS PoP + } + }; + + // Act + var result = await tokenAcquirer.GetTokenForAppAsync("https://graph.microsoft.com/.default", tokenAcquisitionOptions); + + // Assert + Assert.NotNull(result.AccessToken); + Assert.StartsWith("eyJ0e", result.AccessToken, StringComparison.OrdinalIgnoreCase); + + var tokenParts = result.AccessToken.Split('.'); + Assert.Equal(3, tokenParts.Length); + + var tokenPayload = tokenParts[1]; + var tokenPayloadBytes = Base64UrlEncoder.DecodeBytes(tokenPayload); + var tokenPayloadString = Encoding.UTF8.GetString(tokenPayloadBytes); + + using var tokenPayloadJson = JsonDocument.Parse(tokenPayloadString); + var tokenPayloadJsonRoot = tokenPayloadJson.RootElement; + + Assert.True(tokenPayloadJsonRoot.TryGetProperty("cnf", out var tokenCnfClaim), "The mTLS PoP token should contain a 'cnf' claim"); + Assert.True(tokenCnfClaim.TryGetProperty("x5t#S256", out var tokenX5tS256), "The mTLS PoP 'cnf' claim should contain an 'x5t#S256' property"); + + var tokenX5tS256Value = tokenX5tS256.GetString(); + Assert.False(string.IsNullOrEmpty(tokenX5tS256Value)); + } + private static string CreatePopClaim(RsaSecurityKey key, string algorithm) { var parameters = key.Rsa == null ? key.Parameters : key.Rsa.ExportParameters(false); diff --git a/tests/Microsoft.Identity.Web.Test.Common/Mocks/MockMtlsHttpClientFactory.cs b/tests/Microsoft.Identity.Web.Test.Common/Mocks/MockMtlsHttpClientFactory.cs new file mode 100644 index 000000000..a3e65152d --- /dev/null +++ b/tests/Microsoft.Identity.Web.Test.Common/Mocks/MockMtlsHttpClientFactory.cs @@ -0,0 +1,97 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Runtime.InteropServices; +using System.Security.Cryptography.X509Certificates; +using Microsoft.Identity.Client; +using Xunit; + +namespace Microsoft.Identity.Web.Test.Common.Mocks +{ + /// + /// HttpClient factory that serves Http responses for testing purposes and supports mTLS certificate binding. + /// + /// + /// This implements both IHttpClientFactory and IMsalMtlsHttpClientFactory for testing mTLS scenarios. + /// + public class MockMtlsHttpClientFactory : IMsalMtlsHttpClientFactory, IHttpClientFactory, IDisposable + { + private LinkedList _httpMessageHandlerQueue = new(); + + private volatile bool _addInstanceDiscovery = true; + + public MockHttpMessageHandler AddMockHandler(MockHttpMessageHandler handler) + { + if (_httpMessageHandlerQueue.Count == 0 && _addInstanceDiscovery) + { + _addInstanceDiscovery = false; + handler.ReplaceMockHttpMessageHandler = (h) => + { + return _httpMessageHandlerQueue.AddFirst(h).Value; + }; + } + + // add a message to the front of the queue + _httpMessageHandlerQueue.AddLast(handler); + return handler; + } + + public HttpClient GetHttpClient() + { + HttpMessageHandler? messageHandler = _httpMessageHandlerQueue.First?.Value; + if (messageHandler == null) + { + throw new InvalidOperationException("The mock HTTP message handler queue is empty."); + } + _httpMessageHandlerQueue.RemoveFirst(); + + var httpClient = new HttpClient(messageHandler); + + httpClient.DefaultRequestHeaders.Accept.Clear(); + httpClient.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); + + return httpClient; + } + + public HttpClient CreateClient(string name) + { + return GetHttpClient(); + } + + /// + /// Gets an HttpClient configured with the specified certificate for mTLS. + /// + /// The certificate to use for mTLS. + /// An HttpClient configured for mTLS. + public HttpClient GetHttpClient(X509Certificate2 certificate) + { + // For testing purposes, return the same mocked HttpClient regardless of certificate + // In a real implementation, this would configure the HttpClient with the certificate + return GetHttpClient(); + } + + /// + public void Dispose() + { + // This ensures we only check the mock queue on dispose when we're not in the middle of an + // exception flow. Otherwise, any early assertion will cause this to likely fail + // even though it's not the root cause. +#pragma warning disable CS0618 // Type or member is obsolete - this is non-production code so it's fine + if (Marshal.GetExceptionCode() == 0) +#pragma warning restore CS0618 // Type or member is obsolete + { + string remainingMocks = string.Join( + " ", + _httpMessageHandlerQueue.Select( + h => (h as MockHttpMessageHandler)?.ExpectedUrl ?? string.Empty)); + + Assert.Empty(_httpMessageHandlerQueue); + } + } + } +} diff --git a/tests/Microsoft.Identity.Web.Test/Certificates/WithClientCredentialsTests.cs b/tests/Microsoft.Identity.Web.Test/Certificates/WithClientCredentialsTests.cs index 176d802bf..435982a51 100644 --- a/tests/Microsoft.Identity.Web.Test/Certificates/WithClientCredentialsTests.cs +++ b/tests/Microsoft.Identity.Web.Test/Certificates/WithClientCredentialsTests.cs @@ -8,6 +8,7 @@ using System.Threading.Tasks; using Microsoft.Extensions.Logging; using Microsoft.Identity.Abstractions; +using Microsoft.Identity.Client; using Microsoft.Identity.Web.Test.Common; using NSubstitute; using NSubstitute.ExceptionExtensions; @@ -195,5 +196,222 @@ private static async Task RunFailToLoadLogicAsync(IEnumerable(); + var credLoader = Substitute.For(); + var builder = ConfidentialClientApplicationBuilder.Create(TestConstants.ClientId) + .WithAuthority(TestConstants.AuthorityCommonTenant); + + var credentialDescription = new CredentialDescription + { + SourceType = CredentialSource.StoreWithThumbprint, + CertificateThumbprint = "test-thumbprint", + CertificateStorePath = "CurrentUser/My" + }; + + var testCertificate = Base64EncodedCertificateLoader.LoadFromBase64Encoded( + TestConstants.CertificateX5cWithPrivateKey, + TestConstants.CertificateX5cWithPrivateKeyPassword, + X509KeyStorageFlags.DefaultKeySet); + + // Mock the credential loader to successfully load the certificate + credLoader.LoadCredentialsIfNeededAsync(Arg.Any(), Arg.Any()) + .Returns(args => + { + var cd = (args[0] as CredentialDescription)!; + cd.Certificate = testCertificate; + return Task.CompletedTask; + }); + + // Act + var result = await builder.WithBindingCertificateAsync( + new[] { credentialDescription }, + logger, + credLoader, + credentialSourceLoaderParameters: null, + isTokenBinding: true); + + // Assert + Assert.NotNull(result); + Assert.Same(builder, result); // Should return the same builder instance + await credLoader.Received(1).LoadCredentialsIfNeededAsync(credentialDescription, null); + } + + [Fact] + public async Task WithBindingCertificateAsync_NoValidCredentials_ReturnsOriginalBuilderAfterException() + { + // Arrange + var logger = Substitute.For(); + var credLoader = Substitute.For(); + var builder = ConfidentialClientApplicationBuilder.Create(TestConstants.ClientId) + .WithAuthority(TestConstants.AuthorityCommonTenant); + + var credentialDescription = new CredentialDescription + { + SourceType = CredentialSource.StoreWithThumbprint, + CertificateThumbprint = "test-thumbprint", + CertificateStorePath = "CurrentUser/My" + }; + + // Mock the credential loader to fail loading (Skip = true causes LoadCredentialForMsalOrFailAsync to throw) + credLoader.LoadCredentialsIfNeededAsync(Arg.Any(), Arg.Any()) + .Returns(args => + { + var cd = (args[0] as CredentialDescription)!; + cd.Skip = true; + return Task.CompletedTask; + }); + + // Act & Assert + // This should throw because LoadCredentialForMsalOrFailAsync throws when no credentials can be loaded + await Assert.ThrowsAsync( + () => builder.WithBindingCertificateAsync( + new[] { credentialDescription }, + logger, + credLoader, + credentialSourceLoaderParameters: null, + isTokenBinding: true)); + } + + [Fact] + public async Task WithBindingCertificateAsync_CredentialWithoutCertificate_ReturnsOriginalBuilder() + { + // Arrange + var logger = Substitute.For(); + var credLoader = Substitute.For(); + var builder = ConfidentialClientApplicationBuilder.Create(TestConstants.ClientId) + .WithAuthority(TestConstants.AuthorityCommonTenant); + + var credentialDescription = new CredentialDescription + { + SourceType = CredentialSource.ClientSecret, + ClientSecret = "test-secret" + }; + + // Mock the credential loader to load a credential without a certificate + credLoader.LoadCredentialsIfNeededAsync(Arg.Any(), Arg.Any()) + .Returns(args => + { + var cd = (args[0] as CredentialDescription)!; + // Certificate is null by default + return Task.CompletedTask; + }); + + // Act & Assert + await Assert.ThrowsAsync( + () => builder.WithBindingCertificateAsync( + new[] { credentialDescription }, + logger, + credLoader, + credentialSourceLoaderParameters: null, + isTokenBinding: true)); + } + + [Fact] + public async Task WithBindingCertificateAsync_CredentialLoadingFails_PropagatesException() + { + // Arrange + var logger = Substitute.For(); + var credLoader = Substitute.For(); + var builder = ConfidentialClientApplicationBuilder.Create(TestConstants.ClientId) + .WithAuthority(TestConstants.AuthorityCommonTenant); + + var credentialDescription = new CredentialDescription + { + SourceType = CredentialSource.StoreWithThumbprint, + CertificateThumbprint = "invalid-thumbprint", + CertificateStorePath = "CurrentUser/My" + }; + + var expectedException = new Exception("Certificate not found"); + + // Mock the credential loader to throw an exception + credLoader.LoadCredentialsIfNeededAsync(Arg.Any(), Arg.Any()) + .ThrowsAsync(expectedException); + + // Act & Assert + var actualException = await Assert.ThrowsAsync( + () => builder.WithBindingCertificateAsync( + new[] { credentialDescription }, + logger, + credLoader, + credentialSourceLoaderParameters: null, + isTokenBinding: true)); + + // Verify the exception is propagated from LoadCredentialForMsalOrFailAsync + Assert.Contains("Certificate not found", actualException.Message, StringComparison.Ordinal); + } + + [Fact] + public async Task WithBindingCertificateAsync_EmptyCredentialsList_ReturnsOriginalBuilder() + { + // Arrange + var logger = Substitute.For(); + var credLoader = Substitute.For(); + var builder = ConfidentialClientApplicationBuilder.Create(TestConstants.ClientId) + .WithAuthority(TestConstants.AuthorityCommonTenant); + + // Act & Assert + await Assert.ThrowsAsync( + () => builder.WithBindingCertificateAsync( + new CredentialDescription[0], + logger, + credLoader, + credentialSourceLoaderParameters: null, + isTokenBinding: true)); + } + + [Fact] + public async Task WithBindingCertificateAsync_WithCredentialSourceLoaderParameters_PassesParametersCorrectly() + { + // Arrange + var logger = Substitute.For(); + var credLoader = Substitute.For(); + var builder = ConfidentialClientApplicationBuilder.Create(TestConstants.ClientId) + .WithAuthority(TestConstants.AuthorityCommonTenant); + + var credentialDescription = new CredentialDescription + { + SourceType = CredentialSource.StoreWithThumbprint, + CertificateThumbprint = "test-thumbprint", + CertificateStorePath = "CurrentUser/My" + }; + + var testCertificate = Base64EncodedCertificateLoader.LoadFromBase64Encoded( + TestConstants.CertificateX5cWithPrivateKey, + TestConstants.CertificateX5cWithPrivateKeyPassword, + X509KeyStorageFlags.DefaultKeySet); + + var credentialSourceLoaderParameters = new CredentialSourceLoaderParameters("test-client-id", "test-tenant-id"); + + // Mock the credential loader to successfully load the certificate + credLoader.LoadCredentialsIfNeededAsync(Arg.Any(), Arg.Any()) + .Returns(args => + { + var cd = (args[0] as CredentialDescription)!; + cd.Certificate = testCertificate; + return Task.CompletedTask; + }); + + // Act + var result = await builder.WithBindingCertificateAsync( + new[] { credentialDescription }, + logger, + credLoader, + credentialSourceLoaderParameters, + isTokenBinding: true); + + // Assert + Assert.NotNull(result); + await credLoader.Received(1).LoadCredentialsIfNeededAsync(credentialDescription, credentialSourceLoaderParameters); + } + + #endregion + } } diff --git a/tests/Microsoft.Identity.Web.Test/DefaultAuthorizationHeaderProviderTests.cs b/tests/Microsoft.Identity.Web.Test/DefaultAuthorizationHeaderProviderTests.cs new file mode 100644 index 000000000..508d634ee --- /dev/null +++ b/tests/Microsoft.Identity.Web.Test/DefaultAuthorizationHeaderProviderTests.cs @@ -0,0 +1,533 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Security.Claims; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Identity.Abstractions; +using Microsoft.Identity.Client; +using Microsoft.Identity.Web.Test.Common; +using NSubstitute; +using Xunit; + +namespace Microsoft.Identity.Web.Test +{ + [Collection(nameof(TokenAcquirerFactorySingletonProtection))] + public class DefaultAuthorizationHeaderProviderTests + { + private readonly ITokenAcquisition _mockTokenAcquisition; + private readonly DefaultAuthorizationHeaderProvider _provider; + + public DefaultAuthorizationHeaderProviderTests() + { + _mockTokenAcquisition = Substitute.For(); + _provider = new DefaultAuthorizationHeaderProvider(_mockTokenAcquisition); + } + + [Fact] + public void Constructor_WithValidParameters_InitializesCorrectly() + { + // Arrange & Act + var provider = new DefaultAuthorizationHeaderProvider(_mockTokenAcquisition); + + // Assert + Assert.NotNull(provider); + } + + [Fact] + public async Task CreateAuthorizationHeaderAsync_ForBoundHeader_WithValidOptions_ReturnsSuccessResult() + { + // Arrange + var downstreamApiOptions = new DownstreamApiOptions + { + Scopes = new[] { "https://graph.microsoft.com/.default" } + }; + + var mockAuthenticationResult = new AuthenticationResult( + "access_token", + false, + null, + DateTimeOffset.UtcNow.AddHours(1), + DateTimeOffset.UtcNow.AddHours(1), + "tenant_id", + null, + null, + new[] { "scope1" }, + Guid.NewGuid()); + + _mockTokenAcquisition + .GetAuthenticationResultForAppAsync( + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()) + .Returns(Task.FromResult(mockAuthenticationResult)); + + var claimsPrincipal = new ClaimsPrincipal(); + var cancellationToken = CancellationToken.None; + + // Act + var result = await ((IAuthorizationHeaderProvider2)_provider).CreateAuthorizationHeaderAsync( + downstreamApiOptions, + claimsPrincipal, + cancellationToken); + + // Assert + Assert.NotNull(result.Result); + Assert.Equal("Bearer access_token", result.Result.AuthorizationHeaderValue); + + await _mockTokenAcquisition.Received(1).GetAuthenticationResultForAppAsync( + "https://graph.microsoft.com/.default", + null, + null, + Arg.Any()); + } + + [Fact] + public async Task CreateAuthorizationHeaderAsync_ForBoundHeader_WithBindingCertificate_ReturnsBindingCertificate() + { + // Arrange + var downstreamApiOptions = new DownstreamApiOptions + { + Scopes = new[] { "https://graph.microsoft.com/.default" } + }; + + // Create test certificate + var bytes = Convert.FromBase64String(TestConstants.CertificateX5c); +#if NET9_0_OR_GREATER + var bindingCertificate = X509CertificateLoader.LoadCertificate(bytes); +#else +#pragma warning disable SYSLIB0057 // Type or member is obsolete + var bindingCertificate = new X509Certificate2(bytes); +#pragma warning restore SYSLIB0057 // Type or member is obsolete +#endif + + var mockAuthenticationResult = new AuthenticationResult( + "access_token", + false, + null, + DateTimeOffset.UtcNow.AddHours(1), + DateTimeOffset.UtcNow.AddHours(1), + "tenant_id", + null, + null, + new[] { "scope1" }, + Guid.NewGuid()) + { + BindingCertificate = bindingCertificate + }; + + _mockTokenAcquisition + .GetAuthenticationResultForAppAsync( + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()) + .Returns(Task.FromResult(mockAuthenticationResult)); + + // Act + var result = await ((IAuthorizationHeaderProvider2)_provider).CreateAuthorizationHeaderAsync( + downstreamApiOptions, + null, + CancellationToken.None); + + // Assert + Assert.NotNull(result.Result); + Assert.Equal("Bearer access_token", result.Result.AuthorizationHeaderValue); + Assert.Same(bindingCertificate, result.Result.BindingCertificate); + } + + [Fact] + public async Task CreateAuthorizationHeaderAsync_ForBoundHeader_WithoutBindingCertificate_ReturnsNullBindingCertificate() + { + // Arrange + var downstreamApiOptions = new DownstreamApiOptions + { + Scopes = new[] { "https://graph.microsoft.com/.default" } + }; + + var mockAuthenticationResult = new AuthenticationResult( + "access_token", + false, + null, + DateTimeOffset.UtcNow.AddHours(1), + DateTimeOffset.UtcNow.AddHours(1), + "tenant_id", + null, + null, + new[] { "scope1" }, + Guid.NewGuid()); + + _mockTokenAcquisition + .GetAuthenticationResultForAppAsync( + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()) + .Returns(Task.FromResult(mockAuthenticationResult)); + + // Act + var result = await ((IAuthorizationHeaderProvider2)_provider).CreateAuthorizationHeaderAsync( + downstreamApiOptions, + null, + CancellationToken.None); + + // Assert + Assert.NotNull(result.Result); + Assert.Equal("Bearer access_token", result.Result.AuthorizationHeaderValue); + Assert.Null(result.Result.BindingCertificate); + } + + [Fact] + public async Task CreateAuthorizationHeaderAsync_ForBoundHeader_WithEmptyScopes_UsesEmptyString() + { + // Arrange + var downstreamApiOptions = new DownstreamApiOptions + { + Scopes = new string[0] // Empty scopes + }; + + var mockAuthenticationResult = new AuthenticationResult( + "access_token", + false, + null, + DateTimeOffset.UtcNow.AddHours(1), + DateTimeOffset.UtcNow.AddHours(1), + "tenant_id", + null, + null, + new[] { "scope1" }, + Guid.NewGuid()); + + _mockTokenAcquisition + .GetAuthenticationResultForAppAsync( + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()) + .Returns(Task.FromResult(mockAuthenticationResult)); + + // Act + var result = await ((IAuthorizationHeaderProvider2)_provider).CreateAuthorizationHeaderAsync( + downstreamApiOptions, + null, + CancellationToken.None); + + // Assert + Assert.NotNull(result.Result); + await _mockTokenAcquisition.Received(1).GetAuthenticationResultForAppAsync( + string.Empty, // Should use empty string when no scopes + null, + null, + Arg.Any()); + } + + [Fact] + public async Task CreateAuthorizationHeaderAsync_ForBoundHeader_WithAcquireTokenOptions_PassesCorrectParameters() + { + // Arrange + var downstreamApiOptions = new DownstreamApiOptions + { + Scopes = new[] { "https://graph.microsoft.com/.default" }, + AcquireTokenOptions = new AcquireTokenOptions + { + AuthenticationOptionsName = "TestAuth", + Tenant = "test-tenant" + } + }; + + var mockAuthenticationResult = new AuthenticationResult( + "access_token", + false, + null, + DateTimeOffset.UtcNow.AddHours(1), + DateTimeOffset.UtcNow.AddHours(1), + "tenant_id", + null, + null, + new[] { "scope1" }, + Guid.NewGuid()); + + _mockTokenAcquisition + .GetAuthenticationResultForAppAsync( + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()) + .Returns(Task.FromResult(mockAuthenticationResult)); + + // Act + await ((IAuthorizationHeaderProvider2)_provider).CreateAuthorizationHeaderAsync( + downstreamApiOptions, + null, + CancellationToken.None); + + // Assert + await _mockTokenAcquisition.Received(1).GetAuthenticationResultForAppAsync( + "https://graph.microsoft.com/.default", + "TestAuth", + "test-tenant", + Arg.Any()); + } + + [Fact] + public async Task CreateAuthorizationHeaderAsync_WithUserScopes_AcquiresUserToken() + { + // Arrange + var scopes = new[] { "User.Read", "Mail.Read" }; + var options = new AuthorizationHeaderProviderOptions(); + var claimsPrincipal = new ClaimsPrincipal(); + var cancellationToken = CancellationToken.None; + var expectedHeader = "Bearer test-token"; + + var mockAuthenticationResult = new AuthenticationResult( + "test-token", + false, + null, + DateTimeOffset.UtcNow.AddHours(1), + DateTimeOffset.UtcNow.AddHours(1), + "tenant_id", + null, + null, + scopes, + Guid.NewGuid()); + + _mockTokenAcquisition + .GetAuthenticationResultForUserAsync( + scopes, + null, + null, + null, + claimsPrincipal, + Arg.Any()) + .Returns(Task.FromResult(mockAuthenticationResult)); + + // Act + var result = await ((IAuthorizationHeaderProvider)_provider).CreateAuthorizationHeaderAsync(scopes, options, claimsPrincipal, cancellationToken); + + // Assert + Assert.Equal(expectedHeader, result); + await _mockTokenAcquisition.Received(1) + .GetAuthenticationResultForUserAsync(scopes, null, null, null, claimsPrincipal, Arg.Any()); + } + + [Fact] + public async Task CreateAuthorizationHeaderForAppAsync_AcquiresAppToken() + { + // Arrange + var scopes = "https://graph.microsoft.com/.default"; + var options = new AuthorizationHeaderProviderOptions(); + var cancellationToken = CancellationToken.None; + var expectedHeader = "Bearer app-token"; + + var mockAuthenticationResult = new AuthenticationResult( + "app-token", + false, + null, + DateTimeOffset.UtcNow.AddHours(1), + DateTimeOffset.UtcNow.AddHours(1), + "tenant_id", + null, + null, + new[] { scopes }, + Guid.NewGuid()); + + _mockTokenAcquisition + .GetAuthenticationResultForAppAsync( + scopes, + null, + null, + Arg.Any()) + .Returns(Task.FromResult(mockAuthenticationResult)); + + // Act + var result = await ((IAuthorizationHeaderProvider)_provider).CreateAuthorizationHeaderForAppAsync(scopes, options, cancellationToken); + + // Assert + Assert.Equal(expectedHeader, result); + await _mockTokenAcquisition.Received(1) + .GetAuthenticationResultForAppAsync(scopes, null, null, Arg.Any()); + } + + [Fact] + public async Task CreateAuthorizationHeaderForUserAsync_AcquiresUserToken() + { + // Arrange + var scopes = new[] { "User.Read", "Mail.Read" }; + var options = new AuthorizationHeaderProviderOptions(); + var claimsPrincipal = new ClaimsPrincipal(); + var cancellationToken = CancellationToken.None; + var expectedHeader = "Bearer user-token"; + + var mockAuthenticationResult = new AuthenticationResult( + "user-token", + false, + null, + DateTimeOffset.UtcNow.AddHours(1), + DateTimeOffset.UtcNow.AddHours(1), + "tenant_id", + null, + null, + scopes, + Guid.NewGuid()); + + _mockTokenAcquisition + .GetAuthenticationResultForUserAsync( + scopes, + null, + null, + null, + claimsPrincipal, + Arg.Any()) + .Returns(Task.FromResult(mockAuthenticationResult)); + + // Act + var result = await ((IAuthorizationHeaderProvider)_provider).CreateAuthorizationHeaderForUserAsync(scopes, options, claimsPrincipal, cancellationToken); + + // Assert + Assert.Equal(expectedHeader, result); + await _mockTokenAcquisition.Received(1) + .GetAuthenticationResultForUserAsync(scopes, null, null, null, claimsPrincipal, Arg.Any()); + } + + [Fact] + public async Task CreateAuthorizationHeaderAsync_WithNullParameters_AcquiresUserToken() + { + // Arrange + var scopes = new[] { "User.Read" }; + var expectedHeader = "Bearer test-token"; + + var mockAuthenticationResult = new AuthenticationResult( + "test-token", + false, + null, + DateTimeOffset.UtcNow.AddHours(1), + DateTimeOffset.UtcNow.AddHours(1), + "tenant_id", + null, + null, + scopes, + Guid.NewGuid()); + + _mockTokenAcquisition + .GetAuthenticationResultForUserAsync( + scopes, + null, + null, + null, + null, + Arg.Any()) + .Returns(Task.FromResult(mockAuthenticationResult)); + + // Act + var result = await ((IAuthorizationHeaderProvider)_provider).CreateAuthorizationHeaderAsync(scopes, null, null, CancellationToken.None); + + // Assert + Assert.Equal(expectedHeader, result); + await _mockTokenAcquisition.Received(1) + .GetAuthenticationResultForUserAsync(scopes, null, null, null, null, Arg.Any()); + } + + [Fact] + public async Task CreateAuthorizationHeaderForAppAsync_WithNullOptions_AcquiresAppToken() + { + // Arrange + var scopes = "https://graph.microsoft.com/.default"; + var expectedHeader = "Bearer app-token"; + + var mockAuthenticationResult = new AuthenticationResult( + "app-token", + false, + null, + DateTimeOffset.UtcNow.AddHours(1), + DateTimeOffset.UtcNow.AddHours(1), + "tenant_id", + null, + null, + new[] { scopes }, + Guid.NewGuid()); + + _mockTokenAcquisition + .GetAuthenticationResultForAppAsync( + scopes, + null, + null, + Arg.Any()) + .Returns(Task.FromResult(mockAuthenticationResult)); + + // Act + var result = await ((IAuthorizationHeaderProvider)_provider).CreateAuthorizationHeaderForAppAsync(scopes, null, CancellationToken.None); + + // Assert + Assert.Equal(expectedHeader, result); + await _mockTokenAcquisition.Received(1) + .GetAuthenticationResultForAppAsync(scopes, null, null, Arg.Any()); + } + + [Fact] + public async Task CreateAuthorizationHeaderForUserAsync_WithNullParameters_AcquiresUserToken() + { + // Arrange + var scopes = new[] { "User.Read" }; + var expectedHeader = "Bearer user-token"; + + var mockAuthenticationResult = new AuthenticationResult( + "user-token", + false, + null, + DateTimeOffset.UtcNow.AddHours(1), + DateTimeOffset.UtcNow.AddHours(1), + "tenant_id", + null, + null, + scopes, + Guid.NewGuid()); + + _mockTokenAcquisition + .GetAuthenticationResultForUserAsync( + scopes, + null, + null, + null, + null, + Arg.Any()) + .Returns(Task.FromResult(mockAuthenticationResult)); + + // Act + var result = await ((IAuthorizationHeaderProvider)_provider).CreateAuthorizationHeaderForUserAsync(scopes, null, null, CancellationToken.None); + + // Assert + Assert.Equal(expectedHeader, result); + await _mockTokenAcquisition.Received(1) + .GetAuthenticationResultForUserAsync(scopes, null, null, null, null, Arg.Any()); + } + + [Fact] + public async Task CreateAuthorizationHeaderAsync_ForBoundProvider_TokenAcquisitionThrows_PropagatesException() + { + // Arrange + var downstreamApiOptions = new DownstreamApiOptions + { + Scopes = new[] { "https://graph.microsoft.com/.default" } + }; + + var expectedException = new MsalServiceException("test-error", "Test error message"); + _mockTokenAcquisition + .GetAuthenticationResultForAppAsync( + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()) + .Returns(Task.FromException(expectedException)); + + // Act & Assert + var actualException = await Assert.ThrowsAsync( + () => ((IAuthorizationHeaderProvider2)_provider).CreateAuthorizationHeaderAsync(downstreamApiOptions, null, CancellationToken.None)); + + Assert.Equal(expectedException.ErrorCode, actualException.ErrorCode); + Assert.Equal(expectedException.Message, actualException.Message); + } + } +} diff --git a/tests/Microsoft.Identity.Web.Test/DownstreamWebApiSupport/DownstreamApiTests.cs b/tests/Microsoft.Identity.Web.Test/DownstreamWebApiSupport/DownstreamApiTests.cs index 774ec159e..391ac79d5 100644 --- a/tests/Microsoft.Identity.Web.Test/DownstreamWebApiSupport/DownstreamApiTests.cs +++ b/tests/Microsoft.Identity.Web.Test/DownstreamWebApiSupport/DownstreamApiTests.cs @@ -8,15 +8,21 @@ using System.Net; using System.Net.Http; using System.Security.Claims; +using System.Security.Cryptography.X509Certificates; using System.Text; using System.Text.Json; using System.Text.Json.Serialization; using System.Threading; using System.Threading.Tasks; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Microsoft.Identity.Abstractions; +using Microsoft.Identity.Client; +using Microsoft.Identity.Web.Test.Common; +using Microsoft.Identity.Web.Test.Common.Mocks; using Microsoft.Identity.Web.Test.Resource; +using NSubstitute; using Xunit; namespace Microsoft.Identity.Web.Tests @@ -77,10 +83,10 @@ public async Task UpdateRequestAsync_AddsToExtraQP() // Arrange var httpRequestMessage = new HttpRequestMessage(HttpMethod.Post, "https://example.com"); var content = new StringContent("test content"); - var options = new DownstreamApiOptions() { - AcquireTokenOptions = new AcquireTokenOptions() { - ExtraQueryParameters = new Dictionary() - { + var options = new DownstreamApiOptions() { + AcquireTokenOptions = new AcquireTokenOptions() { + ExtraQueryParameters = new Dictionary() + { { "n1", "v1" }, { "n2", "v2" }, { "caller-sdk-id", "bogus" } // value will be overwritten by the SDK @@ -97,7 +103,7 @@ public async Task UpdateRequestAsync_AddsToExtraQP() Assert.Equal("v1", options.AcquireTokenOptions.ExtraQueryParameters["n1"]); Assert.Equal("v2", options.AcquireTokenOptions.ExtraQueryParameters["n2"]); Assert.Equal( - DownstreamApi.CallerSDKDetails["caller-sdk-id"], + DownstreamApi.CallerSDKDetails["caller-sdk-id"], options.AcquireTokenOptions.ExtraQueryParameters["caller-sdk-id"] ); Assert.Equal( DownstreamApi.CallerSDKDetails["caller-sdk-ver"], @@ -464,11 +470,334 @@ public async Task ReadErrorResponseContentAsync_ReturnsMessage_WhenContentLength Assert.NotNull(result); // Either we get the truncation message about size, or the actual content is truncated Assert.True( - result.Contains("[Error response too large:", StringComparison.Ordinal) || + result.Contains("[Error response too large:", StringComparison.Ordinal) || result.EndsWith("... (truncated)", StringComparison.Ordinal) || result.Length <= 4096 + "... (truncated)".Length, "Error response should be limited in size"); } + + [Fact] + public void DownstreamApi_Constructor_WithBoundProvider_AcceptsIMsalMtlsHttpClientFactory() + { + // Arrange + var mockBoundProvider = Substitute.For(); + var mockMtlsHttpClientFactory = Substitute.For(); + + // Act & Assert - Should not throw + var downstreamApi = new DownstreamApi( + mockBoundProvider, + _namedDownstreamApiOptions, + mockMtlsHttpClientFactory, + _logger); + + Assert.NotNull(downstreamApi); + } + + [Fact] + public async Task UpdateRequestAsync_WithAuthorizationHeaderBoundProvider_CallsCorrectInterface() + { + // Arrange + var mockBoundProvider = Substitute.For(); + var testCertificate = Substitute.For(); + + var downstreamApi = new DownstreamApi( + mockBoundProvider, + _namedDownstreamApiOptions, + _httpClientFactory, + _logger); + + var options = new DownstreamApiOptions + { + Scopes = new[] { "https://api.example.com/.default" } + }; + + var authHeaderInfo = new AuthorizationHeaderInformation + { + AuthorizationHeaderValue = "Bearer test-token", + BindingCertificate = testCertificate + }; + + var mockResult = new OperationResult(authHeaderInfo); + + ((IAuthorizationHeaderProvider2)mockBoundProvider) + .CreateAuthorizationHeaderAsync( + Arg.Any(), + Arg.Any(), + Arg.Any()) + .Returns(mockResult); + + var httpRequestMessage = new HttpRequestMessage(HttpMethod.Get, "https://api.example.com"); + + // Act + var result = await downstreamApi.UpdateRequestAsync( + httpRequestMessage, + null, + options, + false, + null, + CancellationToken.None); + + // Assert - Verify the bound provider interface was called + await ((IAuthorizationHeaderProvider2)mockBoundProvider).Received(1).CreateAuthorizationHeaderAsync( + Arg.Any(), + Arg.Any(), + Arg.Any()); + + // Verify the regular provider interface was NOT called + await mockBoundProvider.DidNotReceive().CreateAuthorizationHeaderAsync( + Arg.Any>(), + Arg.Any(), + Arg.Any(), + Arg.Any()); + + // Assert - Verify the returned AuthorizationHeaderInformation + Assert.NotNull(result); + Assert.Equal("Bearer test-token", result.AuthorizationHeaderValue); + Assert.Equal(testCertificate, result.BindingCertificate); + Assert.Equal("Bearer test-token", httpRequestMessage.Headers.Authorization?.ToString()); + } + + [Fact] + public async Task UpdateRequestAsync_WithRegularAuthorizationHeaderProvider_FallsBackCorrectly() + { + // Arrange - Using the existing regular provider from the constructor + var httpRequestMessage = new HttpRequestMessage(HttpMethod.Get, "https://api.example.com"); + var options = new DownstreamApiOptions + { + Scopes = new[] { "https://api.example.com/.default" } + }; + + // Act + var result = await _input.UpdateRequestAsync( + httpRequestMessage, + null, + options, + false, + null, + CancellationToken.None); + + // Assert + Assert.Null(result); // Regular provider doesn't return AuthorizationHeaderInformation + Assert.Equal("Bearer ey", httpRequestMessage.Headers.Authorization?.ToString()); + } + + [Fact] + public async Task CallApiInternalAsync_WithRegularAuthorizationHeaderProvider_UsesRegularHttpClientFactory() + { + // Arrange + var mockHttpClientFactory = Substitute.For(); + var mockHandler = new MockHttpMessageHandler() + { + ExpectedMethod = HttpMethod.Get, + ResponseMessage = new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent("{\"result\": \"success\"}") + } + }; + var mockHttpClient = new HttpClient(mockHandler); + + mockHttpClientFactory.CreateClient(Arg.Any()).Returns(mockHttpClient); + + var downstreamApi = new DownstreamApi( + _authorizationHeaderProvider, // Regular provider + _namedDownstreamApiOptions, + mockHttpClientFactory, + _logger); + + var options = new DownstreamApiOptions + { + BaseUrl = "https://api.example.com", + Scopes = new[] { "https://api.example.com/.default" }, + HttpMethod = "GET" + }; + + // Act + await downstreamApi.CallApiInternalAsync(null, options, false, null, null, CancellationToken.None); + + // Assert + mockHttpClientFactory.Received(1).CreateClient(Arg.Any()); + + // Note: HttpClient is disposed by DownstreamApi, no manual disposal needed + } + + [Fact] + public async Task CallApiInternalAsync_WithAuthorizationHeaderBoundProviderAndWithBindingCertificate_UsesMtlsHttpClientFactory() + { + // Arrange + var mockBoundProvider = Substitute.For(); + var mockMtlsHttpClientFactory = Substitute.For(); + var testCertificate = CreateTestCertificate(); + + var mockHandler = new MockHttpMessageHandler() + { + ExpectedMethod = HttpMethod.Get, + ResponseMessage = new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent("{\"result\": \"success\"}") + } + }; + + // Create HttpClient with our mock handler + var mockMtlsHttpClient = new HttpClient(mockHandler); + + var downstreamApi = new DownstreamApi( + mockBoundProvider, + _namedDownstreamApiOptions, + mockMtlsHttpClientFactory, + _logger, + (IMsalHttpClientFactory)mockMtlsHttpClientFactory); + + var options = new DownstreamApiOptions + { + BaseUrl = "https://api.example.com", + Scopes = new[] { "https://api.example.com/.default" }, + HttpMethod = "GET" + }; + + var authHeaderInfo = new AuthorizationHeaderInformation + { + AuthorizationHeaderValue = "Bearer test-token", + BindingCertificate = testCertificate + }; + + var mockResult = new OperationResult(authHeaderInfo); + + ((IAuthorizationHeaderProvider2)mockBoundProvider) + .CreateAuthorizationHeaderAsync( + Arg.Any(), + Arg.Any(), + Arg.Any()) + .Returns(mockResult); + + // Setup mTLS HTTP client factory to return our pre-configured HttpClient + ((IMsalMtlsHttpClientFactory)mockMtlsHttpClientFactory) + .GetHttpClient(testCertificate) + .Returns(mockMtlsHttpClient); + + // Act + await downstreamApi.CallApiInternalAsync(null, options, false, null, null, CancellationToken.None); + + // Assert - Verify mTLS HTTP client factory was used + var _ = ((IMsalMtlsHttpClientFactory)mockMtlsHttpClientFactory).Received(1).GetHttpClient(testCertificate); + + // Verify regular HTTP client factory was NOT used + ((IHttpClientFactory)mockMtlsHttpClientFactory).DidNotReceive().CreateClient(Arg.Any()); + } + + [Fact] + public async Task CallApiInternalAsync_WithAuthorizationHeaderBoundProviderButWithoutBindingCertificate_UsesRegularHttpClientFactory() + { + // Arrange + var mockBoundProvider = Substitute.For(); + var mockMtlsHttpClientFactory = Substitute.For(); + + var mockHandler = new MockHttpMessageHandler() + { + ExpectedMethod = HttpMethod.Get, + ResponseMessage = new HttpResponseMessage(HttpStatusCode.OK) + { + Content = new StringContent("{\"result\": \"success\"}") + } + }; + var mockRegularHttpClient = new HttpClient(mockHandler); + + var downstreamApi = new DownstreamApi( + mockBoundProvider, + _namedDownstreamApiOptions, + mockMtlsHttpClientFactory, + _logger, + (IMsalHttpClientFactory)mockMtlsHttpClientFactory); + + var options = new DownstreamApiOptions + { + BaseUrl = "https://api.example.com", + Scopes = new[] { "https://api.example.com/.default" }, + HttpMethod = "GET" + }; + + var authHeaderInfo = new AuthorizationHeaderInformation + { + AuthorizationHeaderValue = "Bearer test-token", + BindingCertificate = null // No binding certificate + }; + + var mockResult = new OperationResult(authHeaderInfo); + + ((IAuthorizationHeaderProvider2)mockBoundProvider) + .CreateAuthorizationHeaderAsync( + Arg.Any(), + Arg.Any(), + Arg.Any()) + .Returns(mockResult); + + // Setup regular HTTP client factory to return our mocked HttpClient + ((IHttpClientFactory)mockMtlsHttpClientFactory) + .CreateClient(Arg.Any()) + .Returns(mockRegularHttpClient); + + // Act + await downstreamApi.CallApiInternalAsync(null, options, false, null, null, CancellationToken.None); + + // Assert - Verify regular HTTP client factory was used + ((IHttpClientFactory)mockMtlsHttpClientFactory).Received(1).CreateClient(Arg.Any()); + + // Verify mTLS HTTP client factory was NOT used + ((IMsalMtlsHttpClientFactory)mockMtlsHttpClientFactory).DidNotReceive().GetHttpClient(Arg.Any()); + + // Note: HttpClient is disposed by DownstreamApi, no manual disposal needed + } + + [Fact] + public async Task CallApiInternalAsync_WithAuthorizationHeaderBoundProviderWithAuthenticationFailure_ThrowsException() + { + // Arrange + var mockBoundProvider = Substitute.For(); + var mockHttpClientFactory = Substitute.For(); + + var downstreamApi = new DownstreamApi( + mockBoundProvider, + _namedDownstreamApiOptions, + mockHttpClientFactory, + _logger); + + var options = new DownstreamApiOptions + { + BaseUrl = "https://api.example.com", + Scopes = new[] { "https://api.example.com/.default" } + }; + + // Mock authentication failure + var mockResult = new OperationResult( + new AuthorizationHeaderError("token_acquisition_failed", "Failed to acquire token")); + + ((IAuthorizationHeaderProvider2)mockBoundProvider) + .CreateAuthorizationHeaderAsync( + Arg.Any(), + Arg.Any(), + Arg.Any()) + .Returns(mockResult); + + // Act & Assert + await Assert.ThrowsAnyAsync(() => + downstreamApi.CallApiInternalAsync(null, options, false, null, null, CancellationToken.None)); + } + + private static X509Certificate2 CreateTestCertificate() + { + // Create a simple test certificate for mocking purposes + // We don't need a real certificate with private key for HTTP client factory testing + var bytes = Convert.FromBase64String(TestConstants.CertificateX5c); + +#if NET9_0_OR_GREATER + // Use the new X509CertificateLoader for .NET 9.0+ + return X509CertificateLoader.LoadCertificate(bytes); +#else + // Use the legacy constructor for older frameworks +#pragma warning disable SYSLIB0057 // Type or member is obsolete + return new X509Certificate2(bytes); +#pragma warning restore SYSLIB0057 // Type or member is obsolete +#endif + } } public class Person diff --git a/tests/Microsoft.Identity.Web.Test/DownstreamWebApiSupport/ExtraParametersTests.cs b/tests/Microsoft.Identity.Web.Test/DownstreamWebApiSupport/ExtraParametersTests.cs index 9714e3ea0..e63810920 100644 --- a/tests/Microsoft.Identity.Web.Test/DownstreamWebApiSupport/ExtraParametersTests.cs +++ b/tests/Microsoft.Identity.Web.Test/DownstreamWebApiSupport/ExtraParametersTests.cs @@ -12,6 +12,7 @@ using Microsoft.Extensions.Options; using Microsoft.Identity.Abstractions; using Microsoft.Identity.Web.Test.Resource; +using NSubstitute; using Xunit; namespace Microsoft.Identity.Web.Tests diff --git a/tests/Microsoft.Identity.Web.Test/MsalMtlsHttpClientFactoryTests.cs b/tests/Microsoft.Identity.Web.Test/MsalMtlsHttpClientFactoryTests.cs new file mode 100644 index 000000000..8a779111b --- /dev/null +++ b/tests/Microsoft.Identity.Web.Test/MsalMtlsHttpClientFactoryTests.cs @@ -0,0 +1,197 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if SUPPORTS_MTLS +using System; +using System.Net.Http; +using System.Security.Cryptography.X509Certificates; +using Microsoft.Identity.Web.Test.Common; +using Xunit; + +namespace Microsoft.Identity.Web.Test +{ + public class MsalMtlsHttpClientFactoryTests : IDisposable + { + private readonly TestHttpClientFactory _httpClientFactory; + private readonly MsalMtlsHttpClientFactory _factory; + private bool _disposed = false; + + public MsalMtlsHttpClientFactoryTests() + { + _httpClientFactory = new TestHttpClientFactory(); + _factory = new MsalMtlsHttpClientFactory(_httpClientFactory); + } + + [Fact] + public void Constructor_WithValidHttpClientFactory_ShouldNotThrow() + { + // Arrange & Act + var factory = new MsalMtlsHttpClientFactory(_httpClientFactory); + + // Assert + Assert.NotNull(factory); + } + + [Fact] + public void Constructor_WithNullHttpClientFactory_ShouldAcceptNull() + { + // Arrange & Act + var factory = new MsalMtlsHttpClientFactory(null!); + + // Assert + Assert.NotNull(factory); + } + + [Fact] + public void GetHttpClient_WithoutCertificate_ShouldReturnConfiguredHttpClient() + { + // Arrange & Act + HttpClient actualHttpClient = _factory.GetHttpClient(); + + // Assert + Assert.NotNull(actualHttpClient); + + // Verify telemetry header is present + Assert.True(actualHttpClient.DefaultRequestHeaders.Contains(Constants.TelemetryHeaderKey)); + + var telemetryHeaderValues = actualHttpClient.DefaultRequestHeaders.GetValues(Constants.TelemetryHeaderKey); + Assert.Single(telemetryHeaderValues); + } + + [Fact] + public void GetHttpClient_WithNullCertificate_ShouldReturnConfiguredHttpClient() + { + // Arrange & Act + HttpClient actualHttpClient = _factory.GetHttpClient(null!); + + // Assert + Assert.NotNull(actualHttpClient); + Assert.True(actualHttpClient.DefaultRequestHeaders.Contains(Constants.TelemetryHeaderKey)); + } + +#if SUPPORTS_MTLS + [Fact] + public void GetHttpClient_WithSameCertificate_ShouldReturnCachedClient() + { + // Arrange + using var certificate = CreateTestCertificate(); + + // Act + HttpClient firstClient = _factory.GetHttpClient(certificate); + HttpClient secondClient = _factory.GetHttpClient(certificate); + + // Assert + Assert.Same(firstClient, secondClient); + } + + [Fact] + public void GetHttpClient_WithCertificate_ShouldConfigureProperHeaders() + { + // Arrange + using var certificate = CreateTestCertificate(); + + // Act + HttpClient httpClient = _factory.GetHttpClient(certificate); + + // Assert + // Verify telemetry header + Assert.True(httpClient.DefaultRequestHeaders.Contains(Constants.TelemetryHeaderKey)); + + // Verify max response buffer size + Assert.Equal(1024 * 1024, httpClient.MaxResponseContentBufferSize); + } +#else + [Fact] + public void GetHttpClient_WithCertificateOnUnsupportedPlatform_ShouldThrowNotSupportedException() + { + // Arrange + using var certificate = CreateTestCertificate(); + + // Act & Assert + Assert.Throws(() => _factory.GetHttpClient(certificate)); + } +#endif + + [Fact] + public void GetHttpClient_CreatesClientFromFactory() + { + // Arrange & Act + _factory.GetHttpClient(); + + // Assert + Assert.True(_httpClientFactory.CreateClientCalled); + } + + [Fact] + public void GetHttpClient_MultipleCalls_CallsFactoryEachTime() + { + // Arrange & Act + _factory.GetHttpClient(); + _factory.GetHttpClient(); + + // Assert + Assert.Equal(2, _httpClientFactory.CreateClientCallCount); + } + + private static X509Certificate2 CreateTestCertificate() + { + // Create a simple test certificate for mocking purposes + // We don't need a real certificate with private key for HTTP client factory testing + var bytes = Convert.FromBase64String(TestConstants.CertificateX5c); + +#if NET9_0_OR_GREATER + // Use the new X509CertificateLoader for .NET 9.0+ + return X509CertificateLoader.LoadCertificate(bytes); +#else + // Use the legacy constructor for older frameworks +#pragma warning disable SYSLIB0057 // Type or member is obsolete + return new X509Certificate2(bytes); +#pragma warning restore SYSLIB0057 // Type or member is obsolete +#endif + } + + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + protected virtual void Dispose(bool disposing) + { + if (!_disposed) + { + if (disposing) + { + _httpClientFactory?.Dispose(); + } + _disposed = true; + } + } + + /// + /// Simple test HttpClientFactory implementation for testing purposes. + /// + private sealed class TestHttpClientFactory : IHttpClientFactory, IDisposable + { + public bool CreateClientCalled { get; private set; } + public int CreateClientCallCount { get; private set; } + private bool _disposed = false; + + public HttpClient CreateClient(string name) + { + CreateClientCalled = true; + CreateClientCallCount++; + return new HttpClient(); + } + + public void Dispose() + { + if (!_disposed) + { + _disposed = true; + } + } + } + } +} +#endif \ No newline at end of file diff --git a/tests/Microsoft.Identity.Web.Test/ServiceCollectionExtensionsTests.cs b/tests/Microsoft.Identity.Web.Test/ServiceCollectionExtensionsTests.cs index 5815eb5c7..570f5debe 100644 --- a/tests/Microsoft.Identity.Web.Test/ServiceCollectionExtensionsTests.cs +++ b/tests/Microsoft.Identity.Web.Test/ServiceCollectionExtensionsTests.cs @@ -6,6 +6,8 @@ using System.Linq; using System.Threading.Tasks; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Diagnostics.Metrics.Configuration; +using Microsoft.Extensions.Http; using Microsoft.Extensions.Options; using Microsoft.Identity.Abstractions; using Microsoft.Identity.Client; @@ -34,7 +36,8 @@ public void AddTokenAcquisition_Sdk_AddsWithCorrectLifetime() Assert.Equal(typeof(MicrosoftIdentityApplicationOptionsMerger), actual.ImplementationType); Assert.Null(actual.ImplementationInstance); Assert.Null(actual.ImplementationFactory); - }, actual => + }, + actual => { Assert.Equal(ServiceLifetime.Singleton, actual.Lifetime); Assert.Equal(typeof(IPostConfigureOptions), actual.ServiceType); @@ -90,13 +93,13 @@ public void AddTokenAcquisition_Sdk_AddsWithCorrectLifetime() Assert.Null(actual.ImplementationFactory); }, actual => - { - Assert.Equal(ServiceLifetime.Scoped, actual.Lifetime); - Assert.Equal(typeof(ITokenAcquisition), actual.ServiceType); - Assert.Equal(typeof(TokenAcquisitionAspNetCore), actual.ImplementationType); - Assert.Null(actual.ImplementationInstance); - Assert.Null(actual.ImplementationFactory); - }, + { + Assert.Equal(ServiceLifetime.Scoped, actual.Lifetime); + Assert.Equal(typeof(ITokenAcquisition), actual.ServiceType); + Assert.Equal(typeof(TokenAcquisitionAspNetCore), actual.ImplementationType); + Assert.Null(actual.ImplementationInstance); + Assert.Null(actual.ImplementationFactory); + }, actual => { Assert.Equal(ServiceLifetime.Scoped, actual.Lifetime); diff --git a/tests/Microsoft.Identity.Web.Test/TokenAcquirerTests.cs b/tests/Microsoft.Identity.Web.Test/TokenAcquirerTests.cs new file mode 100644 index 000000000..3fcab3b21 --- /dev/null +++ b/tests/Microsoft.Identity.Web.Test/TokenAcquirerTests.cs @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +using System; +using System.Reflection; +using System.Security.Cryptography; +using System.Security.Cryptography.X509Certificates; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Identity.Abstractions; +using Microsoft.Identity.Client; +using NSubstitute; +using Xunit; + +namespace Microsoft.Identity.Web.Test +{ + public class TokenAcquirerTests + { + private readonly ITokenAcquisition _tokenAcquisition; + private readonly string _scope = "https://graph.microsoft.com/.default"; + private readonly string _accessToken = "test_access_token"; + private readonly string _tenantId = "test_tenant_id"; + private readonly string _idToken = "test_id_token"; + private readonly DateTimeOffset _expiresOn = DateTimeOffset.UtcNow.AddHours(1); + private readonly Guid _correlationId = Guid.NewGuid(); + private readonly string _tokenType = "Bearer"; + private readonly string _authenticationScheme = "TestScheme"; + private readonly X509Certificate2 _bindingCertificate; + + public TokenAcquirerTests() + { + _tokenAcquisition = Substitute.For(); + + // Create a test certificate for BindingCertificate scenarios + using var rsa = RSA.Create(); + var request = new CertificateRequest(new X500DistinguishedName("CN=Test"), rsa, HashAlgorithmName.SHA256, RSASignaturePadding.Pkcs1); + _bindingCertificate = request.CreateSelfSigned(DateTimeOffset.UtcNow, DateTimeOffset.UtcNow.AddYears(1)); + } + + [Fact] + public async Task GetTokenForAppAsync_WithoutBindingCertificate_ReturnsCorrectAcquireTokenResult() + { + // Arrange + var authResult = CreateMockAuthenticationResult(); + _tokenAcquisition.GetAuthenticationResultForAppAsync( + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()) + .Returns(authResult); + + var tokenAcquirer = new TokenAcquirer(_tokenAcquisition, _authenticationScheme); + + // Act + var result = await ((ITokenAcquirer)tokenAcquirer).GetTokenForAppAsync( + _scope, + null, + CancellationToken.None); + + // Assert + Assert.NotNull(result); + Assert.Equal(_accessToken, result.AccessToken); + Assert.Equal(_expiresOn, result.ExpiresOn); + Assert.Equal(_tenantId, result.TenantId); + Assert.Equal(_idToken, result.IdToken); + Assert.Equal(new[] { _scope }, result.Scopes); + Assert.Equal(_correlationId, result.CorrelationId); + Assert.Equal(_tokenType, result.TokenType); + Assert.Null(result.BindingCertificate); + } + + [Fact] + public async Task GetTokenForAppAsync_WithBindingCertificate_ReturnsAcquireTokenResultWithBindingCertificate() + { + // Arrange + var authResult = CreateMockAuthenticationResult(bindingCertificate: _bindingCertificate); + _tokenAcquisition.GetAuthenticationResultForAppAsync( + Arg.Any(), + Arg.Any(), + Arg.Any(), + Arg.Any()) + .Returns(authResult); + + var tokenAcquirer = new TokenAcquirer(_tokenAcquisition, _authenticationScheme); + + // Act + var result = await ((ITokenAcquirer)tokenAcquirer).GetTokenForAppAsync( + _scope, + null, + CancellationToken.None); + + // Assert + Assert.NotNull(result); + Assert.Equal(_accessToken, result.AccessToken); + Assert.Equal(_expiresOn, result.ExpiresOn); + Assert.Equal(_tenantId, result.TenantId); + Assert.Equal(_idToken, result.IdToken); + Assert.Equal(new[] { _scope }, result.Scopes); + Assert.Equal(_correlationId, result.CorrelationId); + Assert.Equal(_tokenType, result.TokenType); + Assert.NotNull(result.BindingCertificate); + Assert.Equal(_bindingCertificate.Thumbprint, result.BindingCertificate.Thumbprint); + } + + private AuthenticationResult CreateMockAuthenticationResult(X509Certificate2? bindingCertificate = null) + { + var authResult = new AuthenticationResult( + _accessToken, + false, + null, + _expiresOn, + _expiresOn, + _tenantId, + null, + _idToken, + new[] { _scope }, + _correlationId); + + // Unfortunately, MSAL's AuthenticationResult.BindingCertificate is not settable, + // and we can't mock it, so we'll use a custom AuthenticationResult wrapper + // or test the functionality through integration tests + if (bindingCertificate != null) + { + // Use reflection to set the BindingCertificate property since it's not exposed in the constructor + var bindingCertificateProperty = typeof(AuthenticationResult).GetProperty("BindingCertificate"); + bindingCertificateProperty?.SetValue(authResult, bindingCertificate); + } + + return authResult; + } + } +} diff --git a/tests/Microsoft.Identity.Web.Test/TokenAcquisitionAuthorityTests.cs b/tests/Microsoft.Identity.Web.Test/TokenAcquisitionAuthorityTests.cs index b998d57a2..9d12b6d36 100644 --- a/tests/Microsoft.Identity.Web.Test/TokenAcquisitionAuthorityTests.cs +++ b/tests/Microsoft.Identity.Web.Test/TokenAcquisitionAuthorityTests.cs @@ -478,9 +478,9 @@ public async Task GetOrBuildManagedIdentity_TestAsync(string? clientId) InitializeTokenAcquisitionObjects(); // Act - var app1 = + var app1 = await _tokenAcquisition.GetOrBuildManagedIdentityApplicationAsync(mergedOptions, managedIdentityOptions); - var app2 = + var app2 = await _tokenAcquisition.GetOrBuildManagedIdentityApplicationAsync(mergedOptions, managedIdentityOptions); // Assert diff --git a/tests/Microsoft.Identity.Web.Test/TokenAcquisitionTests.cs b/tests/Microsoft.Identity.Web.Test/TokenAcquisitionTests.cs index bc8ef2fd7..5fe5f40e6 100644 --- a/tests/Microsoft.Identity.Web.Test/TokenAcquisitionTests.cs +++ b/tests/Microsoft.Identity.Web.Test/TokenAcquisitionTests.cs @@ -126,6 +126,42 @@ public async Task ExtraBodyParametersAreSentToEndpointTest() Assert.Equal("Bearer header.payload.signature", result); } + [Theory] + [InlineData(true, true)] + [InlineData(false, false)] + [InlineData(null, false)] + public void GetMergedOptions_SetsIsTokenBindingCorrectly(bool? requestBoundToken, bool expectedIsTokenBinding) + { + // Arrange + var tokenAcquirerFactory = InitTokenAcquirerFactory(); + IServiceProvider serviceProvider = tokenAcquirerFactory.Build(); + var tokenAcquisition = serviceProvider.GetRequiredService() as TokenAcquisition; + + var tokenAcquisitionOptions = new TokenAcquisitionOptions(); + + if (requestBoundToken.HasValue) + { + tokenAcquisitionOptions.ExtraParameters = new Dictionary + { + { "RequestBoundToken", requestBoundToken.Value } + }; + } + + // Act + // Use reflection to access the private GetMergedOptions method + var method = typeof(TokenAcquisition).GetMethod("GetMergedOptions", + System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance); + Assert.NotNull(method); + +#pragma warning disable CS8601 // Possible null reference assignment. + var mergedOptions = method.Invoke(tokenAcquisition, new object?[] { null, tokenAcquisitionOptions }) as MergedOptions; +#pragma warning restore CS8601 // Possible null reference assignment. + + // Assert + Assert.NotNull(mergedOptions); + Assert.Equal(expectedIsTokenBinding, mergedOptions.IsTokenBinding); + } + private TokenAcquirerFactory InitTokenAcquirerFactory() { TokenAcquirerFactoryTesting.ResetTokenAcquirerFactoryInTest();