Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
using System.Net;
using Microsoft.Identity.Client.ApiConfig.Parameters;
using System.Text;
using System.Security.Cryptography.X509Certificates;
using System.Net.Security;

#if SUPPORTS_SYSTEM_TEXT_JSON
using System.Text.Json;
#else
Expand All @@ -22,11 +25,13 @@ namespace Microsoft.Identity.Client.ManagedIdentity
{
internal abstract class AbstractManagedIdentity
{
private const string ManagedIdentityPrefix = "[Managed Identity] ";

protected readonly RequestContext _requestContext;

internal const string TimeoutError = "[Managed Identity] Authentication unavailable. The request to the managed identity endpoint timed out.";
internal readonly ManagedIdentitySource _sourceType;
private const string ManagedIdentityPrefix = "[Managed Identity] ";


protected AbstractManagedIdentity(RequestContext requestContext, ManagedIdentitySource sourceType)
{
_requestContext = requestContext;
Expand Down Expand Up @@ -65,7 +70,7 @@ public virtual async Task<ManagedIdentityResponse> AuthenticateAsync(
logger: _requestContext.Logger,
doNotThrow: true,
mtlsCertificate: null,
validateServerCertificate: ValidateServerCertificate,
validateServerCertificate: GetValidationCallback(),
cancellationToken: cancellationToken,
retryPolicy: request.RetryPolicy).ConfigureAwait(false);
}
Expand All @@ -80,7 +85,7 @@ public virtual async Task<ManagedIdentityResponse> AuthenticateAsync(
logger: _requestContext.Logger,
doNotThrow: true,
mtlsCertificate: null,
validateServerCertificate: ValidateServerCertificate,
validateServerCertificate: GetValidationCallback(),
cancellationToken: cancellationToken,
retryPolicy: request.RetryPolicy)
.ConfigureAwait(false);
Expand All @@ -96,13 +101,14 @@ public virtual async Task<ManagedIdentityResponse> AuthenticateAsync(
}
}

// This method is used to validate the server certificate.
// It is overridden in the Service Fabric managed identity source to validate the certificate thumbprint.
// The default implementation always returns true.
internal virtual bool ValidateServerCertificate(HttpRequestMessage message, System.Security.Cryptography.X509Certificates.X509Certificate2 certificate,
System.Security.Cryptography.X509Certificates.X509Chain chain, System.Net.Security.SslPolicyErrors sslPolicyErrors)
/// <summary>
/// Method to be overridden in the derived classes to provide a custom validation callback for the server certificate.
/// This validation is needed for service fabric managed identity endpoints.
/// </summary>
/// <returns>Callback to validate the server certificate.</returns>
internal virtual Func<HttpRequestMessage, X509Certificate2, X509Chain, SslPolicyErrors, bool> GetValidationCallback()
{
return true;
return null;
}

protected virtual Task<ManagedIdentityResponse> HandleResponseAsync(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Globalization;
using System.Net.Http;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using Microsoft.Identity.Client.Core;
using Microsoft.Identity.Client.Internal;

Expand All @@ -15,6 +16,7 @@ internal class ServiceFabricManagedIdentitySource : AbstractManagedIdentity
private const string ServiceFabricMsiApiVersion = "2019-07-01-preview";
private readonly Uri _endpoint;
private readonly string _identityHeaderValue;

internal static Lazy<HttpClient> _httpClientLazy;

public static AbstractManagedIdentity Create(RequestContext requestContext)
Expand All @@ -40,11 +42,17 @@ public static AbstractManagedIdentity Create(RequestContext requestContext)
}

requestContext.Logger.Verbose(() => "[Managed Identity] Creating Service Fabric managed identity. Endpoint URI: " + identityEndpoint);

return new ServiceFabricManagedIdentitySource(requestContext, endpointUri, EnvironmentVariables.IdentityHeader);
}

internal override bool ValidateServerCertificate(HttpRequestMessage message, System.Security.Cryptography.X509Certificates.X509Certificate2 certificate,
System.Security.Cryptography.X509Certificates.X509Chain chain, System.Net.Security.SslPolicyErrors sslPolicyErrors)
internal override Func<HttpRequestMessage, X509Certificate2, X509Chain, SslPolicyErrors, bool> GetValidationCallback()
{
return ValidateServerCertificateCallback;
}

private bool ValidateServerCertificateCallback(HttpRequestMessage message, X509Certificate2 certificate,
X509Chain chain, SslPolicyErrors sslPolicyErrors)
{
if (sslPolicyErrors == SslPolicyErrors.None)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@

using System;
using System.Diagnostics;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Identity.Client;
Expand Down Expand Up @@ -1330,5 +1333,89 @@ await mi.AcquireTokenForManagedIdentity(Resource)
Assert.AreEqual(httpManager.QueueSize, 0);
}
}

[TestMethod]
public void ValidateServerCertificate_OnlySetForServiceFabric()
{
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
// Test all managed identity sources
foreach (ManagedIdentitySource sourceType in Enum.GetValues(typeof(ManagedIdentitySource))
.Cast<ManagedIdentitySource>()
.Where(s => s != ManagedIdentitySource.None && s != ManagedIdentitySource.DefaultToImds))
{
// Create a managed identity source for each type
AbstractManagedIdentity managedIdentity = CreateManagedIdentitySource(sourceType, httpManager);

// Check if ValidateServerCertificate is set based on the source type
bool shouldHaveCallback = sourceType == ManagedIdentitySource.ServiceFabric;
bool hasCallback = managedIdentity.GetValidationCallback() != null;

Assert.AreEqual(
shouldHaveCallback,
hasCallback,
$"For source type {sourceType}, ValidateServerCertificate should {(shouldHaveCallback ? "" : "not ")}be set");

// For ServiceFabric, verify it's set to the right method
if (sourceType == ManagedIdentitySource.ServiceFabric)
{
Assert.IsNotNull(managedIdentity.GetValidationCallback(),
"ServiceFabric should have ValidateServerCertificate set");

Assert.IsInstanceOfType(managedIdentity, typeof(ServiceFabricManagedIdentitySource),
"ServiceFabric managed identity should be of type ServiceFabricManagedIdentitySource");
}
else
{
Assert.IsNull(managedIdentity.GetValidationCallback(),
$"Non-ServiceFabric source type {sourceType} should not have ValidateServerCertificate set");
}
}
}
}

private AbstractManagedIdentity CreateManagedIdentitySource(ManagedIdentitySource sourceType, MockHttpManager httpManager)
{
string endpoint = "https://identity.endpoint.com";

// Setup environment based on the source type
SetEnvironmentVariables(sourceType, endpoint);

var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned)
.WithHttpManager(httpManager);

var managedIdentityApp = miBuilder.BuildConcrete();
RequestContext requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null);

// Create the correct managed identity source based on the type
AbstractManagedIdentity managedIdentity = null;

switch (sourceType)
{
case ManagedIdentitySource.ServiceFabric:
managedIdentity = ServiceFabricManagedIdentitySource.Create(requestContext);
break;
case ManagedIdentitySource.AppService:
managedIdentity = AppServiceManagedIdentitySource.Create(requestContext);
break;
case ManagedIdentitySource.AzureArc:
managedIdentity = AzureArcManagedIdentitySource.Create(requestContext);
break;
case ManagedIdentitySource.CloudShell:
managedIdentity = CloudShellManagedIdentitySource.Create(requestContext);
break;
case ManagedIdentitySource.Imds:
managedIdentity = new ImdsManagedIdentitySource(requestContext);
break;
case ManagedIdentitySource.MachineLearning:
managedIdentity = MachineLearningManagedIdentitySource.Create(requestContext);
break;
default:
throw new NotSupportedException($"Unsupported managed identity source type: {sourceType}");
}

return managedIdentity;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public void ValidateServerCertificateCallback_ServerCertificateValidationCallbac
var sf = ServiceFabricManagedIdentitySource.Create(requestContext);

Assert.IsInstanceOfType(sf, typeof(ServiceFabricManagedIdentitySource));
var callback = ((ServiceFabricManagedIdentitySource)sf).ValidateServerCertificate(null, certificate, chain, sslPolicyErrors);
var callback = sf.GetValidationCallback();
Assert.IsNotNull(callback);
}
}
Expand Down
Loading