Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,31 @@ internal enum MsiTestType
MsiAppJsonParseFailure,
MsiMissingToken,
MsiAppServicesIncorrectRequest,
MsiAzureVmTimeout,
MsiAzureVmImdsTimeout,
MsiUnresponsive,
MsiThrottled,
MsiTransientServerError
}

private readonly MsiTestType _msiTestType;

private const string _azureVmImdsInstanceEndpoint = "http://169.254.169.254/metadata/instance";
private const string _azureVmImdsTokenEndpoint = "http://169.254.169.254/metadata/identity/oauth2/token";

internal MockMsi(MsiTestType msiTestType)
{
_msiTestType = msiTestType;
}

/// <summary>
/// Returns a response based on the response type.
/// Returns a response based on the response type.
/// </summary>
/// <param name="request"></param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
// HitCount is updated when this method gets called. This allows for testing of cache and retry logic.
// HitCount is updated when this method gets called. This allows for testing of cache and retry logic.
HitCount++;

HttpResponseMessage responseMessage = null;
Expand Down Expand Up @@ -138,16 +141,20 @@ protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage reques
};
break;

case MsiTestType.MsiAzureVmTimeout:
var start = DateTime.Now;
while(DateTime.Now - start < TimeSpan.FromSeconds(MsiAccessTokenProvider.AzureVmImdsProbeTimeoutInSeconds + 10))
case MsiTestType.MsiAzureVmImdsTimeout:
if (request.RequestUri.AbsoluteUri.StartsWith(_azureVmImdsInstanceEndpoint))
{
if (cancellationToken.IsCancellationRequested)
var start = DateTime.Now;
while (DateTime.Now - start < TimeSpan.FromSeconds(MsiAccessTokenProvider.AzureVmImdsProbeTimeoutInSeconds + 10))
{
throw new TaskCanceledException();
if (cancellationToken.IsCancellationRequested)
{
throw new TaskCanceledException();
}
}
throw new Exception("Test fail");
}
throw new Exception("Test fail");
break;

case MsiTestType.MsiUnresponsive:
case MsiTestType.MsiThrottled:
Expand All @@ -167,7 +174,20 @@ protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage reques
// give error based on test type
if (_msiTestType == MsiTestType.MsiUnresponsive)
{
throw new HttpRequestException();
if (request.RequestUri.AbsoluteUri.StartsWith(_azureVmImdsInstanceEndpoint))
{
responseMessage = new HttpResponseMessage
{
Content = new StringContent(TokenHelper.GetInstanceMetadataResponse(),
Encoding.UTF8,
Constants.JsonContentType)
};
}
else if (Environment.GetEnvironmentVariable(Constants.MsiAppServiceEndpointEnv) != null
|| request.RequestUri.AbsoluteUri.StartsWith(_azureVmImdsTokenEndpoint))
{
throw new HttpRequestException();
}
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
namespace Microsoft.Azure.Services.AppAuthentication.Unit.Tests
{
/// <summary>
/// Test cases for MsiAccessTokenProvider class. MsiAccessTokenProvider is an internal class.
/// Test cases for MsiAccessTokenProvider class. MsiAccessTokenProvider is an internal class.
/// </summary>
public class MsiAccessTokenProviderTests : IDisposable
{
Expand Down Expand Up @@ -48,26 +48,26 @@ public async Task GetTokenUsingManagedIdentityAzureVm(bool specifyUserAssignedMa
expectedAppId = Constants.TestAppId;
}

// MockMsi is being asked to act like response from Azure VM MSI succeeded.
// MockMsi is being asked to act like response from Azure VM MSI succeeded.
MockMsi mockMsi = new MockMsi(msiTestType);
HttpClient httpClient = new HttpClient(mockMsi);
MsiAccessTokenProvider msiAccessTokenProvider = new MsiAccessTokenProvider(httpClient, managedIdentityClientId: managedIdentityArgument);

// Get token.
var authResult = await msiAccessTokenProvider.GetAuthResultAsync(Constants.KeyVaultResourceId, Constants.TenantId).ConfigureAwait(false);

// Check if the principalused and type are as expected.
// Check if the principalused and type are as expected.
Validator.ValidateToken(authResult.AccessToken, msiAccessTokenProvider.PrincipalUsed, Constants.AppType, Constants.TenantId, expectedAppId, expiresOn: authResult.ExpiresOn);
}

/// <summary>
/// If json parse error when aquiring token, an exception should be thrown.
/// If json parse error when aquiring token, an exception should be thrown.
/// </summary>
/// <returns></returns>
[Fact]
public async Task ParseErrorMsiGetTokenTest()
{
// MockMsi is being asked to act like response from Azure VM MSI suceeded.
// MockMsi is being asked to act like response from Azure VM MSI suceeded.
MockMsi mockMsi = new MockMsi(MockMsi.MsiTestType.MsiAppJsonParseFailure);
HttpClient httpClient = new HttpClient(mockMsi);
MsiAccessTokenProvider msiAccessTokenProvider = new MsiAccessTokenProvider(httpClient);
Expand All @@ -80,13 +80,13 @@ public async Task ParseErrorMsiGetTokenTest()
}

/// <summary>
/// If MSI response if missing the token, an exception should be thrown.
/// If MSI response if missing the token, an exception should be thrown.
/// </summary>
/// <returns></returns>
[Fact]
public async Task MsiResponseMissingTokenTest()
{
// MockMsi is being asked to act like response from Azure VM MSI failed.
// MockMsi is being asked to act like response from Azure VM MSI failed.
MockMsi mockMsi = new MockMsi(MockMsi.MsiTestType.MsiMissingToken);
HttpClient httpClient = new HttpClient(mockMsi);
MsiAccessTokenProvider msiAccessTokenProvider = new MsiAccessTokenProvider(httpClient);
Expand All @@ -103,7 +103,7 @@ public async Task MsiResponseMissingTokenTest()
[InlineData(false)]
public async Task GetTokenUsingManagedIdentityAppServices(bool specifyUserAssignedManagedIdentity)
{
// Setup the environment variables that App Service MSI would setup.
// Setup the environment variables that App Service MSI would setup.
Environment.SetEnvironmentVariable(Constants.MsiAppServiceEndpointEnv, Constants.MsiEndpoint);
Environment.SetEnvironmentVariable(Constants.MsiAppServiceHeaderEnv, Constants.ClientSecret);

Expand All @@ -125,19 +125,19 @@ public async Task GetTokenUsingManagedIdentityAppServices(bool specifyUserAssign
expectedAppId = Constants.TestAppId;
}

// MockMsi is being asked to act like response from App Service MSI suceeded.
// MockMsi is being asked to act like response from App Service MSI suceeded.
MockMsi mockMsi = new MockMsi(msiTestType);
HttpClient httpClient = new HttpClient(mockMsi);
MsiAccessTokenProvider msiAccessTokenProvider = new MsiAccessTokenProvider(httpClient, managedIdentityClientId: managedIdentityArgument);

// Get token. This confirms that the environment variables are being read.
// Get token. This confirms that the environment variables are being read.
var authResult = await msiAccessTokenProvider.GetAuthResultAsync(Constants.KeyVaultResourceId, Constants.TenantId).ConfigureAwait(false);

Validator.ValidateToken(authResult.AccessToken, msiAccessTokenProvider.PrincipalUsed, Constants.AppType, Constants.TenantId, expectedAppId, expiresOn: authResult.ExpiresOn);
}

/// <summary>
/// Test response when IDENTITY_HEADER in AppServices MSI is invalid.
/// Test response when IDENTITY_HEADER in AppServices MSI is invalid.
/// </summary>
/// <returns></returns>
[Fact]
Expand All @@ -147,7 +147,7 @@ public async Task UnauthorizedTest()
Environment.SetEnvironmentVariable(Constants.MsiAppServiceEndpointEnv, Constants.MsiEndpoint);
Environment.SetEnvironmentVariable(Constants.MsiAppServiceHeaderEnv, Constants.ClientSecret);

// MockMsi is being asked to act like response from App Service MSI failed (unauthorized).
// MockMsi is being asked to act like response from App Service MSI failed (unauthorized).
MockMsi mockMsi = new MockMsi(MockMsi.MsiTestType.MsiAppServicesUnauthorized);
HttpClient httpClient = new HttpClient(mockMsi);
MsiAccessTokenProvider msiAccessTokenProvider = new MsiAccessTokenProvider(httpClient);
Expand All @@ -159,7 +159,7 @@ public async Task UnauthorizedTest()
}

/// <summary>
/// Test that response when MSI request is not valid is as expected.
/// Test that response when MSI request is not valid is as expected.
/// </summary>
/// <returns></returns>
[Fact]
Expand All @@ -180,7 +180,7 @@ public async Task IncorrectFormatTest()
}

/// <summary>
/// If an unexpected http response has been received, ensure exception is thrown.
/// If an unexpected http response has been received, ensure exception is thrown.
/// </summary>
/// <returns></returns>
[Fact]
Expand Down Expand Up @@ -208,13 +208,13 @@ public async Task AzureVmImdsTimeoutTest()
Environment.SetEnvironmentVariable(Constants.MsiAppServiceEndpointEnv, null);
Environment.SetEnvironmentVariable(Constants.MsiAppServiceHeaderEnv, null);

MockMsi mockMsi = new MockMsi(MockMsi.MsiTestType.MsiAzureVmTimeout);
MockMsi mockMsi = new MockMsi(MockMsi.MsiTestType.MsiAzureVmImdsTimeout);
HttpClient httpClient = new HttpClient(mockMsi);
MsiAccessTokenProvider msiAccessTokenProvider = new MsiAccessTokenProvider(httpClient);

var exception = await Assert.ThrowsAsync<AzureServiceTokenProviderException>(() => Task.Run(() => msiAccessTokenProvider.GetAuthResultAsync(Constants.KeyVaultResourceId, Constants.TenantId)));

Assert.Contains(AzureServiceTokenProviderException.MsiEndpointNotListening, exception.Message);
Assert.Contains(AzureServiceTokenProviderException.MetadataEndpointNotListening, exception.Message);
Assert.DoesNotContain(AzureServiceTokenProviderException.RetryFailure, exception.Message);
}

Expand All @@ -224,7 +224,7 @@ public async Task AzureVmImdsTimeoutTest()
[InlineData(MockMsi.MsiTestType.MsiTransientServerError)]
internal async Task TransientErrorRetryTest(MockMsi.MsiTestType testType)
{
// To simplify tests, mock as MSI App Services to skip Azure VM IDMS probe request by
// To simplify tests, mock as MSI App Services to skip Azure VM IDMS probe request by
Environment.SetEnvironmentVariable(Constants.MsiAppServiceEndpointEnv, Constants.MsiEndpoint);
Environment.SetEnvironmentVariable(Constants.MsiAppServiceHeaderEnv, Constants.ClientSecret);

Expand Down Expand Up @@ -254,12 +254,17 @@ internal async Task TransientErrorRetryTest(MockMsi.MsiTestType testType)
}
}

[Fact]
private async Task MsiRetryTimeoutTest()
[Theory]
[InlineData(false)]
[InlineData(true)]
internal async Task MsiRetryTimeoutTest(bool isAppServices)
{
// To simplify tests, mock as MSI App Services to skip Azure VM IDMS probe request
Environment.SetEnvironmentVariable(Constants.MsiAppServiceEndpointEnv, Constants.MsiEndpoint);
Environment.SetEnvironmentVariable(Constants.MsiAppServiceHeaderEnv, Constants.ClientSecret);
if (isAppServices)
{
// Mock as MSI App Services to skip Azure VM IDMS probe request
Environment.SetEnvironmentVariable(Constants.MsiAppServiceEndpointEnv, Constants.MsiEndpoint);
Environment.SetEnvironmentVariable(Constants.MsiAppServiceHeaderEnv, Constants.ClientSecret);
}

int timeoutInSeconds = (new Random()).Next(1, 4);

Expand Down Expand Up @@ -291,11 +296,11 @@ private async Task AppServicesDifferentCultureTest()
// ensure thread culture is NOT using en-US culture (App Services MSI endpoint always uses en-US DateTime format)
Thread.CurrentThread.CurrentCulture = new CultureInfo("en-GB");

// Setup the environment variables that App Service MSI would setup.
// Setup the environment variables that App Service MSI would setup.
Environment.SetEnvironmentVariable(Constants.MsiAppServiceEndpointEnv, Constants.MsiEndpoint);
Environment.SetEnvironmentVariable(Constants.MsiAppServiceHeaderEnv, Constants.ClientSecret);

// MockMsi is being asked to act like response from App Service MSI suceeded.
// MockMsi is being asked to act like response from App Service MSI suceeded.
MockMsi mockMsi = new MockMsi(MockMsi.MsiTestType.MsiAppServicesSuccess);
HttpClient httpClient = new HttpClient(mockMsi);
MsiAccessTokenProvider msiAccessTokenProvider = new MsiAccessTokenProvider(httpClient);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace Microsoft.Azure.Services.AppAuthentication.Unit.Tests
public class TokenHelper
{
/// <summary>
/// The hardcoded user token has expiry replaced by [exp], so we can replace it with some value to test functionality.
/// The hardcoded user token has expiry replaced by [exp], so we can replace it with some value to test functionality.
/// </summary>
/// <param name="accessToken"></param>
/// <param name="secondsFromCurrent"></param>
Expand All @@ -30,7 +30,7 @@ private static string UpdateTokenTime(string accessToken, long secondsFromCurren

internal static string GetUserToken()
{
// Gets a user token that will expire in 10 seconds from now.
// Gets a user token that will expire in 10 seconds from now.
return GetUserToken(10);
}

Expand Down Expand Up @@ -63,6 +63,16 @@ internal static string GetUserTokenResponse(long secondsFromCurrent, bool format
return tokenResult;
}

/// <summary>
/// Sample IMDS /instance response
/// </summary>
/// <returns></returns>
internal static string GetInstanceMetadataResponse()
{
return
"{\"compute\":{\"location\":\"westus\",\"name\":\"TestBedVm\",\"resourceGroupName\":\"testbed\",\"subscriptionId\":\"bdd789f3-d9d1-4bea-ac14-30a39ed66d33\"}}";
}

/// <summary>
/// The response has claims as expected from App Service MSI response
/// </summary>
Expand Down Expand Up @@ -128,7 +138,7 @@ internal static string GetInvalidMsiTokenResponse()
}

/// <summary>
/// The response has claims as expected from Client Credentials flow response.
/// The response has claims as expected from Client Credentials flow response.
/// </summary>
/// <returns></returns>
internal static string GetAppToken()
Expand All @@ -139,7 +149,7 @@ internal static string GetAppToken()
}

/// <summary>
/// Invalid AppToken.
/// Invalid AppToken.
/// </summary>
/// <returns></returns>
internal static string GetInvalidAppToken()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@
namespace Microsoft.Azure.Services.AppAuthentication
{
/// <summary>
/// Instance of this exception is thrown if access token cannot be acquired.
/// Instance of this exception is thrown if access token cannot be acquired.
/// </summary>
#if FullNetFx || NETSTANDARD2_0
[Serializable]
#endif

public class AzureServiceTokenProviderException : Exception
{
internal const string MetadataEndpointNotListening = "Unable to connect to the Instance Metadata Service (IMDS). Skipping request to the Managed Service Identity (MSI) token endpoint.";

internal const string MsiEndpointNotListening = "Unable to connect to the Managed Service Identity (MSI) endpoint. Please check that you are running on an Azure resource that has MSI setup.";

internal const string UnableToParseMsiTokenResponse = "A successful response was received from Managed Service Identity, but it could not be parsed.";
Expand All @@ -42,7 +44,7 @@ public class AzureServiceTokenProviderException : Exception
internal const string NonRetryableError = "Received a non-retryable error.";

/// <summary>
/// Creates an instance of AzureServiceTokenProviderException.
/// Creates an instance of AzureServiceTokenProviderException.
/// </summary>
/// <param name="connectionString">Connection string used.</param>
/// <param name="resource">Resource for which token was expected.</param>
Expand Down
Loading