diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs index 65a4f56597..17b3c56f3a 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/AbstractManagedIdentity.cs @@ -12,6 +12,9 @@ using System.Net; using Microsoft.Identity.Client.ApiConfig.Parameters; using System.Text; +using System.Security.Cryptography.X509Certificates; +using System.Net.Security; + #if SUPPORTS_SYSTEM_TEXT_JSON using System.Text.Json; #else @@ -22,11 +25,13 @@ namespace Microsoft.Identity.Client.ManagedIdentity { internal abstract class AbstractManagedIdentity { + private const string ManagedIdentityPrefix = "[Managed Identity] "; + protected readonly RequestContext _requestContext; + internal const string TimeoutError = "[Managed Identity] Authentication unavailable. The request to the managed identity endpoint timed out."; internal readonly ManagedIdentitySource _sourceType; - private const string ManagedIdentityPrefix = "[Managed Identity] "; - + protected AbstractManagedIdentity(RequestContext requestContext, ManagedIdentitySource sourceType) { _requestContext = requestContext; @@ -65,7 +70,7 @@ public virtual async Task AuthenticateAsync( logger: _requestContext.Logger, doNotThrow: true, mtlsCertificate: null, - validateServerCertificate: ValidateServerCertificate, + validateServerCertificate: GetValidationCallback(), cancellationToken: cancellationToken, retryPolicy: request.RetryPolicy).ConfigureAwait(false); } @@ -80,7 +85,7 @@ public virtual async Task AuthenticateAsync( logger: _requestContext.Logger, doNotThrow: true, mtlsCertificate: null, - validateServerCertificate: ValidateServerCertificate, + validateServerCertificate: GetValidationCallback(), cancellationToken: cancellationToken, retryPolicy: request.RetryPolicy) .ConfigureAwait(false); @@ -96,13 +101,14 @@ public virtual async Task AuthenticateAsync( } } - // This method is used to validate the server certificate. - // It is overridden in the Service Fabric managed identity source to validate the certificate thumbprint. - // The default implementation always returns true. - internal virtual bool ValidateServerCertificate(HttpRequestMessage message, System.Security.Cryptography.X509Certificates.X509Certificate2 certificate, - System.Security.Cryptography.X509Certificates.X509Chain chain, System.Net.Security.SslPolicyErrors sslPolicyErrors) + /// + /// Method to be overridden in the derived classes to provide a custom validation callback for the server certificate. + /// This validation is needed for service fabric managed identity endpoints. + /// + /// Callback to validate the server certificate. + internal virtual Func GetValidationCallback() { - return true; + return null; } protected virtual Task HandleResponseAsync( diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/ServiceFabricManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/ServiceFabricManagedIdentitySource.cs index a8fb2379fd..a35ce1b1bf 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/ServiceFabricManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/ServiceFabricManagedIdentitySource.cs @@ -5,6 +5,7 @@ using System.Globalization; using System.Net.Http; using System.Net.Security; +using System.Security.Cryptography.X509Certificates; using Microsoft.Identity.Client.Core; using Microsoft.Identity.Client.Internal; @@ -15,6 +16,7 @@ internal class ServiceFabricManagedIdentitySource : AbstractManagedIdentity private const string ServiceFabricMsiApiVersion = "2019-07-01-preview"; private readonly Uri _endpoint; private readonly string _identityHeaderValue; + internal static Lazy _httpClientLazy; public static AbstractManagedIdentity Create(RequestContext requestContext) @@ -40,11 +42,17 @@ public static AbstractManagedIdentity Create(RequestContext requestContext) } requestContext.Logger.Verbose(() => "[Managed Identity] Creating Service Fabric managed identity. Endpoint URI: " + identityEndpoint); + return new ServiceFabricManagedIdentitySource(requestContext, endpointUri, EnvironmentVariables.IdentityHeader); } - internal override bool ValidateServerCertificate(HttpRequestMessage message, System.Security.Cryptography.X509Certificates.X509Certificate2 certificate, - System.Security.Cryptography.X509Certificates.X509Chain chain, System.Net.Security.SslPolicyErrors sslPolicyErrors) + internal override Func GetValidationCallback() + { + return ValidateServerCertificateCallback; + } + + private bool ValidateServerCertificateCallback(HttpRequestMessage message, X509Certificate2 certificate, + X509Chain chain, SslPolicyErrors sslPolicyErrors) { if (sslPolicyErrors == SslPolicyErrors.None) { diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs index 84d399a51d..809f725682 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ManagedIdentityTests.cs @@ -3,9 +3,12 @@ using System; using System.Diagnostics; +using System.Linq; using System.Net; using System.Net.Http; +using System.Net.Security; using System.Net.Sockets; +using System.Security.Claims; using System.Threading; using System.Threading.Tasks; using Microsoft.Identity.Client; @@ -1330,5 +1333,89 @@ await mi.AcquireTokenForManagedIdentity(Resource) Assert.AreEqual(httpManager.QueueSize, 0); } } + + [TestMethod] + public void ValidateServerCertificate_OnlySetForServiceFabric() + { + using (new EnvVariableContext()) + using (var httpManager = new MockHttpManager()) + { + // Test all managed identity sources + foreach (ManagedIdentitySource sourceType in Enum.GetValues(typeof(ManagedIdentitySource)) + .Cast() + .Where(s => s != ManagedIdentitySource.None && s != ManagedIdentitySource.DefaultToImds)) + { + // Create a managed identity source for each type + AbstractManagedIdentity managedIdentity = CreateManagedIdentitySource(sourceType, httpManager); + + // Check if ValidateServerCertificate is set based on the source type + bool shouldHaveCallback = sourceType == ManagedIdentitySource.ServiceFabric; + bool hasCallback = managedIdentity.GetValidationCallback() != null; + + Assert.AreEqual( + shouldHaveCallback, + hasCallback, + $"For source type {sourceType}, ValidateServerCertificate should {(shouldHaveCallback ? "" : "not ")}be set"); + + // For ServiceFabric, verify it's set to the right method + if (sourceType == ManagedIdentitySource.ServiceFabric) + { + Assert.IsNotNull(managedIdentity.GetValidationCallback(), + "ServiceFabric should have ValidateServerCertificate set"); + + Assert.IsInstanceOfType(managedIdentity, typeof(ServiceFabricManagedIdentitySource), + "ServiceFabric managed identity should be of type ServiceFabricManagedIdentitySource"); + } + else + { + Assert.IsNull(managedIdentity.GetValidationCallback(), + $"Non-ServiceFabric source type {sourceType} should not have ValidateServerCertificate set"); + } + } + } + } + + private AbstractManagedIdentity CreateManagedIdentitySource(ManagedIdentitySource sourceType, MockHttpManager httpManager) + { + string endpoint = "https://identity.endpoint.com"; + + // Setup environment based on the source type + SetEnvironmentVariables(sourceType, endpoint); + + var miBuilder = ManagedIdentityApplicationBuilder.Create(ManagedIdentityId.SystemAssigned) + .WithHttpManager(httpManager); + + var managedIdentityApp = miBuilder.BuildConcrete(); + RequestContext requestContext = new RequestContext(managedIdentityApp.ServiceBundle, Guid.NewGuid(), null); + + // Create the correct managed identity source based on the type + AbstractManagedIdentity managedIdentity = null; + + switch (sourceType) + { + case ManagedIdentitySource.ServiceFabric: + managedIdentity = ServiceFabricManagedIdentitySource.Create(requestContext); + break; + case ManagedIdentitySource.AppService: + managedIdentity = AppServiceManagedIdentitySource.Create(requestContext); + break; + case ManagedIdentitySource.AzureArc: + managedIdentity = AzureArcManagedIdentitySource.Create(requestContext); + break; + case ManagedIdentitySource.CloudShell: + managedIdentity = CloudShellManagedIdentitySource.Create(requestContext); + break; + case ManagedIdentitySource.Imds: + managedIdentity = new ImdsManagedIdentitySource(requestContext); + break; + case ManagedIdentitySource.MachineLearning: + managedIdentity = MachineLearningManagedIdentitySource.Create(requestContext); + break; + default: + throw new NotSupportedException($"Unsupported managed identity source type: {sourceType}"); + } + + return managedIdentity; + } } } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ServiceFabricTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ServiceFabricTests.cs index 91905f4897..8dc79d1e99 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ServiceFabricTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/ServiceFabricTests.cs @@ -86,7 +86,7 @@ public void ValidateServerCertificateCallback_ServerCertificateValidationCallbac var sf = ServiceFabricManagedIdentitySource.Create(requestContext); Assert.IsInstanceOfType(sf, typeof(ServiceFabricManagedIdentitySource)); - var callback = ((ServiceFabricManagedIdentitySource)sf).ValidateServerCertificate(null, certificate, chain, sslPolicyErrors); + var callback = sf.GetValidationCallback(); Assert.IsNotNull(callback); } }