diff --git a/src/client/Microsoft.Identity.Client/ManagedIdentity/AzureArcManagedIdentitySource.cs b/src/client/Microsoft.Identity.Client/ManagedIdentity/AzureArcManagedIdentitySource.cs index b33afbe6e5..66a5dee5bb 100644 --- a/src/client/Microsoft.Identity.Client/ManagedIdentity/AzureArcManagedIdentitySource.cs +++ b/src/client/Microsoft.Identity.Client/ManagedIdentity/AzureArcManagedIdentitySource.cs @@ -138,36 +138,18 @@ private void ValidateSplitChallenge(string[] splitChallenge) } _requestContext.Logger.Verbose(() => $"[Managed Identity] Challenge is valid. FilePath: {splitChallenge[1]}"); + string path = Path.GetFullPath(new Uri(splitChallenge[1]).LocalPath) + .TrimEnd(Path.DirectorySeparatorChar, Path.AltDirectorySeparatorChar); - if (DesktopOsHelper.IsWindows()) - { - if (!IsValidWindowsPath(splitChallenge[1])) - { - throw CreateManagedIdentityException( - MsalError.ManagedIdentityRequestFailed, - MsalErrorMessage.ManagedIdentityInvalidFile); - } - - _requestContext.Logger.Verbose(() => "[Managed Identity] Windows path is valid."); - } - else if (DesktopOsHelper.IsLinux()) - { - if (!IsValidLinuxPath(splitChallenge[1])) - { - throw CreateManagedIdentityException( - MsalError.ManagedIdentityRequestFailed, - MsalErrorMessage.ManagedIdentityInvalidFile); - } - - _requestContext.Logger.Verbose(() => "[Managed Identity] Linux path is valid."); - } - else + if (!IsValidPath(splitChallenge[1])) { throw CreateManagedIdentityException( MsalError.ManagedIdentityRequestFailed, - MsalErrorMessage.ManagedIdentityPlatformNotSupported); + MsalErrorMessage.ManagedIdentityInvalidFile); } + _requestContext.Logger.Verbose(() => $"[Managed Identity] File path is valid. Path: {path}"); + var length = new FileInfo(splitChallenge[1]).Length; if ((!File.Exists(splitChallenge[1]) || (length) > 4096)) @@ -191,19 +173,28 @@ private MsalException CreateManagedIdentityException(string errorCode, string er null); } - private bool IsValidLinuxPath(string path) + private bool IsValidPath(string path) { - string linuxPath = "/var/opt/azcmagent/tokens/"; + string expectedFilePath; - return path.StartsWith(linuxPath, StringComparison.OrdinalIgnoreCase) && - path.EndsWith(".key", StringComparison.OrdinalIgnoreCase); - } + if (DesktopOsHelper.IsWindows()) + { + string expandedExpectedPath = Environment.ExpandEnvironmentVariables("%ProgramData%\\AzureConnectedMachineAgent\\Tokens\\"); - private bool IsValidWindowsPath(string path) - { - string expandedExpectedPath = Environment.ExpandEnvironmentVariables("%ProgramData%\\AzureConnectedMachineAgent\\Tokens\\"); + expectedFilePath = expandedExpectedPath + Path.GetFileName(path); + } + else if (DesktopOsHelper.IsLinux()) + { + expectedFilePath = "/var/opt/azcmagent/tokens/" + Path.GetFileName(path); + } + else + { + throw CreateManagedIdentityException( + MsalError.ManagedIdentityRequestFailed, + MsalErrorMessage.ManagedIdentityPlatformNotSupported); + } - return path.StartsWith(expandedExpectedPath, StringComparison.OrdinalIgnoreCase) && + return path.Equals(expectedFilePath, StringComparison.OrdinalIgnoreCase) && path.EndsWith(".key", StringComparison.OrdinalIgnoreCase); } } diff --git a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/AzureArcTests.cs b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/AzureArcTests.cs index 6697f76df9..40f5d68564 100644 --- a/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/AzureArcTests.cs +++ b/tests/Microsoft.Identity.Test.Unit/ManagedIdentityTests/AzureArcTests.cs @@ -81,7 +81,8 @@ await mi.AcquireTokenForManagedIdentity("scope") [DataTestMethod] [DataRow("somefile=filename", MsalErrorMessage.ManagedIdentityInvalidChallenge)] - [DataRow("path/filename", MsalErrorMessage.ManagedIdentityInvalidFile)] + [DataRow("C:\\ProgramData\\AzureConnectedMachineAgent\\Tokens\\filename.txt", MsalErrorMessage.ManagedIdentityInvalidFile)] + [DataRow("C:\\ProgramData\\AzureConnectedMachineAgent\\Tokens\\...\\etc\\filename.key", MsalErrorMessage.ManagedIdentityInvalidFile)] public async Task AzureArcAuthHeaderInvalidAsync(string filename, string errorMessage) { using (new EnvVariableContext()) @@ -97,7 +98,7 @@ public async Task AzureArcAuthHeaderInvalidAsync(string filename, string errorMe var mi = miBuilder.Build(); - httpManager.AddManagedIdentityWSTrustMockHandler(ManagedIdentityTests.AzureArcEndpoint, "somevalue=filepath"); + httpManager.AddManagedIdentityWSTrustMockHandler(ManagedIdentityTests.AzureArcEndpoint, filename); MsalServiceException ex = await Assert.ThrowsExceptionAsync(async () => await mi.AcquireTokenForManagedIdentity("scope") @@ -106,7 +107,7 @@ await mi.AcquireTokenForManagedIdentity("scope") Assert.IsNotNull(ex); Assert.AreEqual(ManagedIdentitySource.AzureArc.ToString(), ex.AdditionalExceptionData[MsalException.ManagedIdentitySource]); Assert.AreEqual(MsalError.ManagedIdentityRequestFailed, ex.ErrorCode); - Assert.AreEqual(MsalErrorMessage.ManagedIdentityInvalidChallenge, ex.Message); + Assert.AreEqual(errorMessage, ex.Message); } }