Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Directory.Build.props
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
<PropertyGroup>
<!-- For files to appear in the Visual Studio Solution explorer given we have conditional inclusion in some projects (IdWeb for instance)
we need to have the higher framework, even if this produces a warning in the IDE -->
<!-- Please update SUPPORTS_MTLS constant below if needed when targeting new frameworks -->
<TargetFrameworks>net8.0; net9.0; net10.0; net462; net472; netstandard2.0</TargetFrameworks>
<SignAssembly>true</SignAssembly>
<AssemblyOriginatorKeyFile>../../build/MSAL.snk</AssemblyOriginatorKeyFile>
Expand Down Expand Up @@ -199,4 +200,8 @@
<IncludeAssets>runtime; build; native; contentfiles; analyzers</IncludeAssets>
</PackageReference>
</ItemGroup>

<PropertyGroup Condition="'$(TargetFramework)' != 'net462'">
<DefineConstants>$(DefineConstants);SUPPORTS_MTLS;</DefineConstants>
</PropertyGroup>
</Project>
212 changes: 136 additions & 76 deletions src/Microsoft.Identity.Web.DownstreamApi/DownstreamApi.cs
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Security.Claims;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.Identity.Abstractions;
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Linq;
using System.Net.Http;
using System.Runtime.CompilerServices;
using System.Security.Claims;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization.Metadata;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Microsoft.Identity.Abstractions;
using Microsoft.Identity.Client;

namespace Microsoft.Identity.Web
Expand All @@ -26,6 +26,12 @@ internal partial class DownstreamApi : IDownstreamApi
{
private readonly IAuthorizationHeaderProvider _authorizationHeaderProvider;
private readonly IHttpClientFactory _httpClientFactory;

// This MSAL HTTP client factory is used to create HTTP clients with mTLS binding certificate.
// Note, that it doesn't replace _httpClientFactory to keep backward compatibility and ability
// to create named HTTP clients for non-mTLS scenarios.
private readonly IMsalHttpClientFactory? _msalHttpClientFactory;

private readonly IOptionsMonitor<DownstreamApiOptions> _namedDownstreamApiOptions;
private const string Authorization = "Authorization";
protected readonly ILogger<DownstreamApi> _logger;
Expand All @@ -43,10 +49,33 @@ public DownstreamApi(
IOptionsMonitor<DownstreamApiOptions> namedDownstreamApiOptions,
IHttpClientFactory httpClientFactory,
ILogger<DownstreamApi> logger)
: this(authorizationHeaderProvider,
namedDownstreamApiOptions,
httpClientFactory,
logger,
msalHttpClientFactory: null)
{
}

/// <summary>
/// Constructor which accepts optional MSAL HTTP client factory.
/// </summary>
/// <param name="authorizationHeaderProvider">Authorization header provider.</param>
/// <param name="namedDownstreamApiOptions">Named options provider.</param>
/// <param name="httpClientFactory">HTTP client factory.</param>
/// <param name="logger">Logger.</param>
/// <param name="msalHttpClientFactory">The MSAL HTTP client factory for mTLS PoP scenarios.</param>
public DownstreamApi(
IAuthorizationHeaderProvider authorizationHeaderProvider,
IOptionsMonitor<DownstreamApiOptions> namedDownstreamApiOptions,
IHttpClientFactory httpClientFactory,
ILogger<DownstreamApi> logger,
IMsalHttpClientFactory? msalHttpClientFactory)
{
_authorizationHeaderProvider = authorizationHeaderProvider;
_namedDownstreamApiOptions = namedDownstreamApiOptions;
_httpClientFactory = httpClientFactory;
_msalHttpClientFactory = msalHttpClientFactory ?? new MsalMtlsHttpClientFactory(httpClientFactory);
_logger = logger;
}

Expand Down Expand Up @@ -436,7 +465,7 @@ public Task<HttpResponseMessage> CallApiForAppAsync(
string stringContent = await content.ReadAsStringAsync();
if (mediaType == "application/json")
{
return JsonSerializer.Deserialize<TOutput>(stringContent, new JsonSerializerOptions { PropertyNameCaseInsensitive = true });
return JsonSerializer.Deserialize<TOutput>(stringContent, new JsonSerializerOptions { PropertyNameCaseInsensitive = true });
}
if (mediaType != null && !mediaType.StartsWith("text/", StringComparison.OrdinalIgnoreCase))
{
Expand Down Expand Up @@ -514,11 +543,17 @@ public Task<HttpResponseMessage> CallApiForAppAsync(
new HttpMethod(effectiveOptions.HttpMethod),
apiUrl);

await UpdateRequestAsync(httpRequestMessage, content, effectiveOptions, appToken, user, cancellationToken);
// Request result will contain authorization header and potentially binding certificate for mTLS
var requestResult = await UpdateRequestAsync(httpRequestMessage, content, effectiveOptions, appToken, user, cancellationToken);

using HttpClient client = string.IsNullOrEmpty(serviceName) ? _httpClientFactory.CreateClient() : _httpClientFactory.CreateClient(serviceName);
// If a binding certificate is specified (which means mTLS is required) and MSAL mTLS HTTP factory is present
// then create an HttpClient with the certificate by using IMsalMtlsHttpClientFactory.
// Otherwise use the default HttpClientFactory with optional named client.
using HttpClient client = requestResult?.BindingCertificate != null && _msalHttpClientFactory != null && _msalHttpClientFactory is IMsalMtlsHttpClientFactory msalMtlsHttpClientFactory
? msalMtlsHttpClientFactory.GetHttpClient(requestResult.BindingCertificate)
: (string.IsNullOrEmpty(serviceName) ? _httpClientFactory.CreateClient() : _httpClientFactory.CreateClient(serviceName));

// Send the HTTP message
// Send the HTTP message
var downstreamApiResult = await client.SendAsync(httpRequestMessage, cancellationToken).ConfigureAwait(false);

// Retry only if the resource sent 401 Unauthorized with WWW-Authenticate header and claims
Expand All @@ -541,7 +576,7 @@ public Task<HttpResponseMessage> CallApiForAppAsync(
return downstreamApiResult;
}

internal /* internal for test */ async Task UpdateRequestAsync(
internal /* internal for test */ async Task<AuthorizationHeaderInformation?> UpdateRequestAsync(
HttpRequestMessage httpRequestMessage,
HttpContent? content,
DownstreamApiOptions effectiveOptions,
Expand All @@ -558,15 +593,38 @@ public Task<HttpResponseMessage> CallApiForAppAsync(

effectiveOptions.RequestAppToken = appToken;

AuthorizationHeaderInformation? authorizationHeaderInformation = null;

// Obtention of the authorization header (except when calling an anonymous endpoint
// which is done by not specifying any scopes
if (effectiveOptions.Scopes != null && effectiveOptions.Scopes.Any())
{
string authorizationHeader = await _authorizationHeaderProvider.CreateAuthorizationHeaderAsync(
effectiveOptions.Scopes,
effectiveOptions,
user,
cancellationToken).ConfigureAwait(false);
string authorizationHeader = string.Empty;

// Firstly check if it's token binding scenario so authorization header provider returns

// a binding certificate along with acquired authorization header.
if (_authorizationHeaderProvider is IAuthorizationHeaderProvider2 authorizationHeaderBoundProviderForMtls)
{
var authorizationHeaderResult = await authorizationHeaderBoundProviderForMtls.CreateAuthorizationHeaderAsync(
effectiveOptions,
user,
cancellationToken).ConfigureAwait(false);

if (authorizationHeaderResult.Succeeded)
{
authorizationHeaderInformation = authorizationHeaderResult.Result;
authorizationHeader = authorizationHeaderInformation?.AuthorizationHeaderValue ?? string.Empty;
}
}
else
{
authorizationHeader = await _authorizationHeaderProvider.CreateAuthorizationHeaderAsync(
effectiveOptions.Scopes,
effectiveOptions,
user,
cancellationToken).ConfigureAwait(false);
}

if (authorizationHeader.StartsWith(AuthSchemeDstsSamlBearer, StringComparison.OrdinalIgnoreCase))
{
Expand All @@ -582,54 +640,56 @@ public Task<HttpResponseMessage> CallApiForAppAsync(
{
Logger.UnauthenticatedApiCall(_logger, null);
}
if (!string.IsNullOrEmpty(effectiveOptions.AcceptHeader))
{
httpRequestMessage.Headers.Accept.ParseAdd(effectiveOptions.AcceptHeader);
}

// Add extra headers if specified directly on DownstreamApiOptions
if (effectiveOptions.ExtraHeaderParameters != null)
{
foreach (var header in effectiveOptions.ExtraHeaderParameters)
{
httpRequestMessage.Headers.TryAddWithoutValidation(header.Key, header.Value);
}
}

// Add extra query parameters if specified directly on DownstreamApiOptions
if (effectiveOptions.ExtraQueryParameters != null && effectiveOptions.ExtraQueryParameters.Count > 0)
{
var uriBuilder = new UriBuilder(httpRequestMessage.RequestUri!);
var existingQuery = uriBuilder.Query;
var queryString = new StringBuilder(existingQuery);
foreach (var queryParam in effectiveOptions.ExtraQueryParameters)
{
if (queryString.Length > 1) // if there are existing query parameters
{
queryString.Append('&');
}
else if (queryString.Length == 0)
{
queryString.Append('?');
}
queryString.Append(Uri.EscapeDataString(queryParam.Key));
queryString.Append('=');
queryString.Append(Uri.EscapeDataString(queryParam.Value));
}
uriBuilder.Query = queryString.ToString().TrimStart('?');
httpRequestMessage.RequestUri = uriBuilder.Uri;
}

// Opportunity to change the request message
if (!string.IsNullOrEmpty(effectiveOptions.AcceptHeader))
{
httpRequestMessage.Headers.Accept.ParseAdd(effectiveOptions.AcceptHeader);
}
// Add extra headers if specified directly on DownstreamApiOptions
if (effectiveOptions.ExtraHeaderParameters != null)
{
foreach (var header in effectiveOptions.ExtraHeaderParameters)
{
httpRequestMessage.Headers.TryAddWithoutValidation(header.Key, header.Value);
}
}
// Add extra query parameters if specified directly on DownstreamApiOptions
if (effectiveOptions.ExtraQueryParameters != null && effectiveOptions.ExtraQueryParameters.Count > 0)
{
var uriBuilder = new UriBuilder(httpRequestMessage.RequestUri!);
var existingQuery = uriBuilder.Query;
var queryString = new StringBuilder(existingQuery);
foreach (var queryParam in effectiveOptions.ExtraQueryParameters)
{
if (queryString.Length > 1) // if there are existing query parameters
{
queryString.Append('&');
}
else if (queryString.Length == 0)
{
queryString.Append('?');
}
queryString.Append(Uri.EscapeDataString(queryParam.Key));
queryString.Append('=');
queryString.Append(Uri.EscapeDataString(queryParam.Value));
}
uriBuilder.Query = queryString.ToString().TrimStart('?');
httpRequestMessage.RequestUri = uriBuilder.Uri;
}
// Opportunity to change the request message
effectiveOptions.CustomizeHttpRequestMessage?.Invoke(httpRequestMessage);

return authorizationHeaderInformation;
}

internal /* for test */ static Dictionary<string, string> CallerSDKDetails { get; } = new()
{
{ "caller-sdk-id", "IdWeb_1" },
{ "caller-sdk-id", "IdWeb_1" },
{ "caller-sdk-ver", IdHelper.GetIdWebVersion() }
};

Expand Down Expand Up @@ -657,33 +717,33 @@ private static void AddCallerSDKTelemetry(DownstreamApiOptions effectiveOptions)
internal static async Task<string> ReadErrorResponseContentAsync(HttpResponseMessage response, CancellationToken cancellationToken = default)
{
const int maxErrorContentLength = 4096;

long? contentLength = response.Content.Headers.ContentLength;

if (contentLength.HasValue && contentLength.Value > maxErrorContentLength)
{
return $"[Error response too large: {contentLength.Value} bytes, not captured]";
}

// Use streaming to read only up to maxErrorContentLength to avoid loading entire response into memory
#if NET5_0_OR_GREATER
using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);
#else
using var stream = await response.Content.ReadAsStreamAsync().ConfigureAwait(false);
#endif
using var reader = new StreamReader(stream);

char[] buffer = new char[maxErrorContentLength];
int readCount = await reader.ReadBlockAsync(buffer, 0, maxErrorContentLength).ConfigureAwait(false);

string errorResponseContent = new string(buffer, 0, readCount);

// Check if there's more content that was truncated
if (readCount == maxErrorContentLength && reader.Peek() != -1)
{
errorResponseContent += "... (truncated)";
}

return errorResponseContent;
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#nullable enable
Microsoft.Identity.Web.DownstreamApi.DownstreamApi(Microsoft.Identity.Abstractions.IAuthorizationHeaderProvider! authorizationHeaderProvider, Microsoft.Extensions.Options.IOptionsMonitor<Microsoft.Identity.Abstractions.DownstreamApiOptions!>! namedDownstreamApiOptions, System.Net.Http.IHttpClientFactory! httpClientFactory, Microsoft.Extensions.Logging.ILogger<Microsoft.Identity.Web.DownstreamApi!>! logger, Microsoft.Identity.Client.IMsalHttpClientFactory? msalHttpClientFactory) -> void
Microsoft.Identity.Web.DownstreamApi.UpdateRequestAsync(System.Net.Http.HttpRequestMessage! httpRequestMessage, System.Net.Http.HttpContent? content, Microsoft.Identity.Abstractions.DownstreamApiOptions! effectiveOptions, bool appToken, System.Security.Claims.ClaimsPrincipal? user, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task<Microsoft.Identity.Abstractions.AuthorizationHeaderInformation?>!
static Microsoft.Identity.Web.DownstreamApi.DeserializeOutputAsync<TOutput>(System.Net.Http.HttpResponseMessage! response, Microsoft.Identity.Abstractions.DownstreamApiOptions! effectiveOptions, System.Text.Json.Serialization.Metadata.JsonTypeInfo<TOutput!>! outputJsonTypeInfo, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task<TOutput?>!
static Microsoft.Identity.Web.DownstreamApi.DeserializeOutputAsync<TOutput>(System.Net.Http.HttpResponseMessage! response, Microsoft.Identity.Abstractions.DownstreamApiOptions! effectiveOptions, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task<TOutput?>!
static Microsoft.Identity.Web.DownstreamApi.Logger.HttpRequestError(Microsoft.Extensions.Logging.ILogger! logger, string! ServiceName, string! BaseUrl, string! RelativePath, int statusCode, string! responseContent, System.Exception? ex) -> void
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
#nullable enable
Microsoft.Identity.Web.DownstreamApi.DownstreamApi(Microsoft.Identity.Abstractions.IAuthorizationHeaderProvider! authorizationHeaderProvider, Microsoft.Extensions.Options.IOptionsMonitor<Microsoft.Identity.Abstractions.DownstreamApiOptions!>! namedDownstreamApiOptions, System.Net.Http.IHttpClientFactory! httpClientFactory, Microsoft.Extensions.Logging.ILogger<Microsoft.Identity.Web.DownstreamApi!>! logger, Microsoft.Identity.Client.IMsalHttpClientFactory? msalHttpClientFactory) -> void
Microsoft.Identity.Web.DownstreamApi.UpdateRequestAsync(System.Net.Http.HttpRequestMessage! httpRequestMessage, System.Net.Http.HttpContent? content, Microsoft.Identity.Abstractions.DownstreamApiOptions! effectiveOptions, bool appToken, System.Security.Claims.ClaimsPrincipal? user, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task<Microsoft.Identity.Abstractions.AuthorizationHeaderInformation?>!
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
#nullable enable
Microsoft.Identity.Web.DownstreamApi.DownstreamApi(Microsoft.Identity.Abstractions.IAuthorizationHeaderProvider! authorizationHeaderProvider, Microsoft.Extensions.Options.IOptionsMonitor<Microsoft.Identity.Abstractions.DownstreamApiOptions!>! namedDownstreamApiOptions, System.Net.Http.IHttpClientFactory! httpClientFactory, Microsoft.Extensions.Logging.ILogger<Microsoft.Identity.Web.DownstreamApi!>! logger, Microsoft.Identity.Client.IMsalHttpClientFactory? msalHttpClientFactory) -> void
Microsoft.Identity.Web.DownstreamApi.UpdateRequestAsync(System.Net.Http.HttpRequestMessage! httpRequestMessage, System.Net.Http.HttpContent? content, Microsoft.Identity.Abstractions.DownstreamApiOptions! effectiveOptions, bool appToken, System.Security.Claims.ClaimsPrincipal? user, System.Threading.CancellationToken cancellationToken) -> System.Threading.Tasks.Task<Microsoft.Identity.Abstractions.AuthorizationHeaderInformation?>!
Loading
Loading