diff --git a/src/Microsoft.SqlTools.ResourceProvider.Core/Common/ExpiredTokenException.cs b/src/Microsoft.SqlTools.ResourceProvider.Core/Common/ExpiredTokenException.cs new file mode 100644 index 0000000000..c9e760c9c2 --- /dev/null +++ b/src/Microsoft.SqlTools.ResourceProvider.Core/Common/ExpiredTokenException.cs @@ -0,0 +1,53 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Runtime.Serialization; + +namespace Microsoft.SqlTools.ResourceProvider.Core +{ + /// + /// The exception is used if any operation fails as a request failed due to an expired token + /// + public class ExpiredTokenException : ServiceExceptionBase + { + /// + /// Initializes a new instance of the ServiceFailedException class. + /// + public ExpiredTokenException() + { + } + + /// + /// Initializes a new instance of the ServiceFailedException class with a specified error message. + /// + /// The error message that explains the reason for the exception. + public ExpiredTokenException(string message) + : base(message) + { + } + + /// + /// Initializes a new instance of the ServiceFailedException class with a specified error message + /// and a reference to the inner exception that is the cause of this exception. + /// + /// The error message that explains the reason for the exception. + /// The exception that is the cause of the current exception, or a null reference + /// (Nothing in Visual Basic) if no inner exception is specified + public ExpiredTokenException(string message, Exception innerException) + : base(message, innerException) + { + } + + /// + /// Initializes a new instance of the ServiceFailedException class with serialized data. + /// + /// The SerializationInfo that holds the serialized object data about the exception being thrown. + /// The StreamingContext that contains contextual information about the source or destination. + public ExpiredTokenException(SerializationInfo info, StreamingContext context) + : base(info, context) + { + } + } +} diff --git a/src/Microsoft.SqlTools.ResourceProvider.Core/Contracts/FirewallRule.cs b/src/Microsoft.SqlTools.ResourceProvider.Core/Contracts/FirewallRule.cs index 28de4b42cc..bebb473149 100644 --- a/src/Microsoft.SqlTools.ResourceProvider.Core/Contracts/FirewallRule.cs +++ b/src/Microsoft.SqlTools.ResourceProvider.Core/Contracts/FirewallRule.cs @@ -50,9 +50,11 @@ public class CreateFirewallRuleParams } - public class CreateFirewallRuleResponse + public class CreateFirewallRuleResponse : TokenReliantResponse { - public bool Result { get; set; } + /// + /// An error message for why the request failed, if any + /// public string ErrorMessage { get; set; } } @@ -97,6 +99,4 @@ public class HandleFirewallRuleResponse /// public string IpAddress { get; set; } } - - } diff --git a/src/Microsoft.SqlTools.ResourceProvider.Core/Contracts/TokenReliantResponse.cs b/src/Microsoft.SqlTools.ResourceProvider.Core/Contracts/TokenReliantResponse.cs new file mode 100644 index 0000000000..1cd15ba26c --- /dev/null +++ b/src/Microsoft.SqlTools.ResourceProvider.Core/Contracts/TokenReliantResponse.cs @@ -0,0 +1,28 @@ +// +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. +// + +using System.Collections.Generic; +using Microsoft.SqlTools.Hosting.Protocol.Contracts; + +namespace Microsoft.SqlTools.ResourceProvider.Core.Contracts +{ + /// + /// Any response which relies on a token may indicated that the operation failed due to token being expired. + /// All operational response messages should inherit from this class in order to support a standard method for defining + /// this failure path + /// + public class TokenReliantResponse + { + /// + /// Did this succeed? + /// + public bool Result { get; set; } + + /// + /// If this failed, was it due to a token expiring? + /// + public bool IsTokenExpiredFailure { get; set; } + } +} diff --git a/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.cs b/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.cs index c04bdeee6b..dcb666cf0f 100755 --- a/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.cs +++ b/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.cs @@ -27,7 +27,7 @@ public static CultureInfo Culture Keys.Culture = value; } } - + public static string NoSubscriptionsFound { @@ -35,7 +35,7 @@ public static string NoSubscriptionsFound { return Keys.GetString(Keys.NoSubscriptionsFound); } - } + } public static string AzureServerNotFound { @@ -43,7 +43,7 @@ public static string AzureServerNotFound { return Keys.GetString(Keys.AzureServerNotFound); } - } + } public static string AzureSubscriptionFailedErrorMessage { @@ -51,7 +51,7 @@ public static string AzureSubscriptionFailedErrorMessage { return Keys.GetString(Keys.AzureSubscriptionFailedErrorMessage); } - } + } public static string DatabaseDiscoveryFailedErrorMessage { @@ -59,7 +59,7 @@ public static string DatabaseDiscoveryFailedErrorMessage { return Keys.GetString(Keys.DatabaseDiscoveryFailedErrorMessage); } - } + } public static string FirewallRuleAccessForbidden { @@ -67,7 +67,7 @@ public static string FirewallRuleAccessForbidden { return Keys.GetString(Keys.FirewallRuleAccessForbidden); } - } + } public static string FirewallRuleCreationFailed { @@ -75,7 +75,7 @@ public static string FirewallRuleCreationFailed { return Keys.GetString(Keys.FirewallRuleCreationFailed); } - } + } public static string FirewallRuleCreationFailedWithError { @@ -83,7 +83,7 @@ public static string FirewallRuleCreationFailedWithError { return Keys.GetString(Keys.FirewallRuleCreationFailedWithError); } - } + } public static string InvalidIpAddress { @@ -91,7 +91,7 @@ public static string InvalidIpAddress { return Keys.GetString(Keys.InvalidIpAddress); } - } + } public static string InvalidServerTypeErrorMessage { @@ -99,7 +99,7 @@ public static string InvalidServerTypeErrorMessage { return Keys.GetString(Keys.InvalidServerTypeErrorMessage); } - } + } public static string LoadingExportableFailedGeneralErrorMessage { @@ -107,7 +107,7 @@ public static string LoadingExportableFailedGeneralErrorMessage { return Keys.GetString(Keys.LoadingExportableFailedGeneralErrorMessage); } - } + } public static string FirewallRuleUnsupportedConnectionType { @@ -115,7 +115,7 @@ public static string FirewallRuleUnsupportedConnectionType { return Keys.GetString(Keys.FirewallRuleUnsupportedConnectionType); } - } + } [System.Runtime.CompilerServices.CompilerGeneratedAttribute()] public class Keys @@ -123,40 +123,40 @@ public class Keys static ResourceManager resourceManager = new ResourceManager("Microsoft.SqlTools.ResourceProvider.Core.Localization.SR", typeof(SR).GetTypeInfo().Assembly); static CultureInfo _culture = null; - - - public const string NoSubscriptionsFound = "NoSubscriptionsFound"; - - - public const string AzureServerNotFound = "AzureServerNotFound"; - - - public const string AzureSubscriptionFailedErrorMessage = "AzureSubscriptionFailedErrorMessage"; - - - public const string DatabaseDiscoveryFailedErrorMessage = "DatabaseDiscoveryFailedErrorMessage"; - - - public const string FirewallRuleAccessForbidden = "FirewallRuleAccessForbidden"; - - - public const string FirewallRuleCreationFailed = "FirewallRuleCreationFailed"; - - - public const string FirewallRuleCreationFailedWithError = "FirewallRuleCreationFailedWithError"; - - - public const string InvalidIpAddress = "InvalidIpAddress"; - - - public const string InvalidServerTypeErrorMessage = "InvalidServerTypeErrorMessage"; - - - public const string LoadingExportableFailedGeneralErrorMessage = "LoadingExportableFailedGeneralErrorMessage"; - - - public const string FirewallRuleUnsupportedConnectionType = "FirewallRuleUnsupportedConnectionType"; - + + + public const string NoSubscriptionsFound = "NoSubscriptionsFound"; + + + public const string AzureServerNotFound = "AzureServerNotFound"; + + + public const string AzureSubscriptionFailedErrorMessage = "AzureSubscriptionFailedErrorMessage"; + + + public const string DatabaseDiscoveryFailedErrorMessage = "DatabaseDiscoveryFailedErrorMessage"; + + + public const string FirewallRuleAccessForbidden = "FirewallRuleAccessForbidden"; + + + public const string FirewallRuleCreationFailed = "FirewallRuleCreationFailed"; + + + public const string FirewallRuleCreationFailedWithError = "FirewallRuleCreationFailedWithError"; + + + public const string InvalidIpAddress = "InvalidIpAddress"; + + + public const string InvalidServerTypeErrorMessage = "InvalidServerTypeErrorMessage"; + + + public const string LoadingExportableFailedGeneralErrorMessage = "LoadingExportableFailedGeneralErrorMessage"; + + + public const string FirewallRuleUnsupportedConnectionType = "FirewallRuleUnsupportedConnectionType"; + private Keys() { } @@ -177,7 +177,7 @@ public static string GetString(string key) { return resourceManager.GetString(key, _culture); } - - } - } -} + + } + } +} diff --git a/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.resx b/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.resx index 55bc2567b3..0882e79a07 100755 --- a/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.resx +++ b/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.resx @@ -120,45 +120,45 @@ No subscriptions were found for the currently logged in user account. - + The server you specified {0} does not exist in any subscription in {1}. Either you have signed in with an incorrect account or your server was removed from subscription(s) in this account. Please check your account and try again. - + - An error occurred while getting Azure subscriptions + An error occurred while getting Azure subscriptions: {0} - + An error occurred while getting databases from servers of type {0} from {1} - + {0} does not have permission to change the server firewall rule. Try again with a different account that is an Owner or Contributor of the Azure subscription or the server. - + An error occurred while creating a new firewall rule. - + An error occurred while creating a new firewall rule: '{0}' - + Invalid IP address - + Server Type is invalid. - + A required dll cannot be loaded. Please repair your application. - + Cannot open a firewall rule for the specified connection type - - + + diff --git a/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.strings b/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.strings index eab975f1be..10aa1ae955 100644 --- a/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.strings +++ b/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.strings @@ -24,7 +24,7 @@ # Azure Core DLL NoSubscriptionsFound = No subscriptions were found for the currently logged in user account. AzureServerNotFound = The server you specified {0} does not exist in any subscription in {1}. Either you have signed in with an incorrect account or your server was removed from subscription(s) in this account. Please check your account and try again. -AzureSubscriptionFailedErrorMessage = An error occurred while getting Azure subscriptions +AzureSubscriptionFailedErrorMessage = An error occurred while getting Azure subscriptions: {0} DatabaseDiscoveryFailedErrorMessage = An error occurred while getting databases from servers of type {0} from {1} FirewallRuleAccessForbidden = {0} does not have permission to change the server firewall rule. Try again with a different account that is an Owner or Contributor of the Azure subscription or the server. FirewallRuleCreationFailed = An error occurred while creating a new firewall rule. diff --git a/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.xlf b/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.xlf index 8479b2e8cf..dd771645de 100644 --- a/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.xlf +++ b/src/Microsoft.SqlTools.ResourceProvider.Core/Localization/sr.xlf @@ -8,7 +8,7 @@ - An error occurred while getting Azure subscriptions + An error occurred while getting Azure subscriptions: {0} An error occurred while getting Azure subscriptions diff --git a/src/Microsoft.SqlTools.ResourceProvider.Core/ResourceProviderService.cs b/src/Microsoft.SqlTools.ResourceProvider.Core/ResourceProviderService.cs index 3a93cb133d..f013faad97 100644 --- a/src/Microsoft.SqlTools.ResourceProvider.Core/ResourceProviderService.cs +++ b/src/Microsoft.SqlTools.ResourceProvider.Core/ResourceProviderService.cs @@ -56,7 +56,16 @@ public async Task HandleCreateFirewallRuleRequest(CreateFirewallRuleParams firew { return DoHandleCreateFirewallRuleRequest(firewallRule); }; - await HandleRequest(requestHandler, requestContext, "HandleCreateFirewallRuleRequest"); + Func tokenExpiredHandler = (ExpiredTokenException ex) => + { + return new CreateFirewallRuleResponse() + { + Result = false, + IsTokenExpiredFailure = true, + ErrorMessage = ex.Message + }; + }; + await HandleRequest(requestHandler, tokenExpiredHandler, requestContext, "HandleCreateFirewallRuleRequest"); } private async Task DoHandleCreateFirewallRuleRequest(CreateFirewallRuleParams firewallRule) @@ -98,10 +107,10 @@ public async Task ProcessHandleFirewallRuleRequest(HandleFirewallRuleParams canH } return Task.FromResult(response); }; - await HandleRequest(requestHandler, requestContext, "HandleCreateFirewallRuleRequest"); + await HandleRequest(requestHandler, null, requestContext, "HandleCreateFirewallRuleRequest"); } - private async Task HandleRequest(Func> handler, RequestContext requestContext, string requestType) + private async Task HandleRequest(Func> handler, Func expiredTokenHandler, RequestContext requestContext, string requestType) { Logger.Write(LogLevel.Verbose, requestType); @@ -110,9 +119,25 @@ private async Task HandleRequest(Func> handler, RequestContext req T result = await handler(); await requestContext.SendResult(result); } + catch(ExpiredTokenException ex) + { + if (expiredTokenHandler != null) + { + // This is a special exception indicating the token(s) used to request resources had expired. + // Any Azure resource should have handling for this such as an error path that clearly indicates that a refresh is needed + T result = expiredTokenHandler(ex); + await requestContext.SendResult(result); + } + else + { + // No handling for expired tokens defined / expected + await requestContext.SendError(ex.Message); + } + } catch (Exception ex) { - await requestContext.SendError(ex.ToString()); + // Send just the error message back for now as stack trace isn't useful + await requestContext.SendError(ex.Message); } } } diff --git a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureAuthenticationManager.cs b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureAuthenticationManager.cs index 7dd2d8eec5..7702f6fb03 100644 --- a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureAuthenticationManager.cs +++ b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureAuthenticationManager.cs @@ -207,7 +207,7 @@ public async Task> GetSubscrip } catch (Exception ex) { - throw new ServiceFailedException(SR.AzureSubscriptionFailedErrorMessage, ex); + throw new ServiceFailedException(SR.FailedToGetAzureSubscriptionsErrorMessage, ex); } } result = result ?? Enumerable.Empty(); @@ -234,38 +234,25 @@ private async Task> GetSubscri private async Task> GetSubscriptionFromServiceAsync(AzureUserAccount userAccount) { + CommonUtil.CheckForNull(userAccount, nameof(userAccount)); List subscriptionList = new List(); - + if (userAccount.NeedsReauthentication) + { + throw new ExpiredTokenException(SR.UserNeedsAuthenticationError); + } try { - if (userAccount != null && !userAccount.NeedsReauthentication) - { - IAzureResourceManager resourceManager = ServiceProvider.GetService(); - IEnumerable contexts = await resourceManager.GetSubscriptionContextsAsync(userAccount); - subscriptionList = contexts.ToList(); - } - else - { - throw new UserNeedsAuthenticationException(SR.AzureSubscriptionFailedErrorMessage); - } + IAzureResourceManager resourceManager = ServiceProvider.GetService(); + IEnumerable contexts = await resourceManager.GetSubscriptionContextsAsync(userAccount); + subscriptionList = contexts.ToList(); } - // TODO handle stale tokens - //catch (MissingSecurityTokenException missingSecurityTokenException) - //{ - // //User needs to reauthenticate - // if (userAccount != null) - // { - // userAccount.NeedsReauthentication = true; - // } - // throw new UserNeedsAuthenticationException(SR.AzureSubscriptionFailedErrorMessage, missingSecurityTokenException); - //} catch (ServiceExceptionBase) { throw; } catch (Exception ex) { - throw new ServiceFailedException(SR.AzureSubscriptionFailedErrorMessage, ex); + throw new ServiceFailedException(SR.FailedToGetAzureSubscriptionsErrorMessage, ex); } return subscriptionList; } diff --git a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureResourceManager.cs b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureResourceManager.cs index 54c06d5d71..2948381d47 100644 --- a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureResourceManager.cs +++ b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/AzureResourceManager.cs @@ -38,7 +38,8 @@ namespace Microsoft.SqlTools.ResourceProvider.DefaultImpl public class AzureResourceManager : ExportableBase, IAzureResourceManager { private readonly Uri _resourceManagementUri = new Uri("https://management.azure.com/"); - + private const string ExpiredTokenCode = "ExpiredAuthenticationToken"; + public AzureResourceManager() { // Duplicate the exportable attribute as at present we do not support filtering using extensiondescriptor. @@ -91,16 +92,10 @@ public async Task> GetAzureDatabasesAsync( if (vsAzureResourceManagementSession != null) { - try - { - IEnumerable databaseListResponse = await vsAzureResourceManagementSession.SqlManagementClient.Databases.ListByServerAsync(resourceGroupName, serverName); - return databaseListResponse.Select( - x => new AzureResourceWrapper(x) { ResourceGroupName = resourceGroupName }); - } - catch(HttpOperationException ex) - { - throw new AzureResourceFailedException(SR.FailedToGetAzureDatabasesErrorMessage, ex.Response.StatusCode); - } + IEnumerable databaseListResponse = await ExecuteCloudRequest( + () => vsAzureResourceManagementSession.SqlManagementClient.Databases.ListByServerAsync(resourceGroupName, serverName), + SR.FailedToGetAzureDatabasesErrorMessage); + return databaseListResponse.Select(x => new AzureResourceWrapper(x) { ResourceGroupName = resourceGroupName }); } } catch (Exception ex) @@ -127,28 +122,17 @@ public async Task> GetSqlServerAzureResourc AzureResourceManagementSession vsAzureResourceManagementSession = azureResourceManagementSession as AzureResourceManagementSession; if(vsAzureResourceManagementSession != null) { - // Note: Ideally wouldn't need to query resource groups, but the current impl requires it - // since any update will need the resource group name and it's not returned from the server. - // This has a very negative impact on perf, so we should investigate running these queries - // in parallel - - try + IServersOperations serverOperations = vsAzureResourceManagementSession.SqlManagementClient.Servers; + IPage servers = await ExecuteCloudRequest( + () => serverOperations.ListAsync(), + SR.FailedToGetAzureSqlServersWithError); + if (servers != null) { - IServersOperations serverOperations = vsAzureResourceManagementSession.SqlManagementClient.Servers; - IPage servers = await serverOperations.ListAsync(); - if (servers != null) - { - sqlServers.AddRange(servers.Select(server => { - var serverResource = new SqlAzureResource(server); - // TODO ResourceGroup name - return serverResource; - })); - } - } - catch (HttpOperationException ex) - { - throw new AzureResourceFailedException( - string.Format(CultureInfo.CurrentCulture, SR.FailedToGetAzureSqlServersWithError, ex.Message), ex.Response.StatusCode); + sqlServers.AddRange(servers.Select(server => { + var serverResource = new SqlAzureResource(server); + // TODO ResourceGroup name + return serverResource; + })); } } } @@ -176,33 +160,27 @@ public async Task CreateFirewallRuleAsync( if (vsAzureResourceManagementSession != null) { - try + var firewallRule = new RestFirewallRule() { - var firewallRule = new RestFirewallRule() - { - EndIpAddress = firewallRuleRequest.EndIpAddress.ToString(), - StartIpAddress = firewallRuleRequest.StartIpAddress.ToString() - }; - IFirewallRulesOperations firewallRuleOperations = vsAzureResourceManagementSession.SqlManagementClient.FirewallRules; - var firewallRuleResponse = await firewallRuleOperations.CreateOrUpdateWithHttpMessagesAsync( - azureSqlServer.ResourceGroupName ?? string.Empty, - azureSqlServer.Name, - firewallRuleRequest.FirewallRuleName, - firewallRule, - GetCustomHeaders()); - var response = firewallRuleResponse.Body; - return new FirewallRuleResponse() - { - StartIpAddress = response.StartIpAddress, - EndIpAddress = response.EndIpAddress, - Created = true - }; - } - catch (HttpOperationException ex) + EndIpAddress = firewallRuleRequest.EndIpAddress.ToString(), + StartIpAddress = firewallRuleRequest.StartIpAddress.ToString() + }; + IFirewallRulesOperations firewallRuleOperations = vsAzureResourceManagementSession.SqlManagementClient.FirewallRules; + var firewallRuleResponse = await ExecuteCloudRequest( + () => firewallRuleOperations.CreateOrUpdateWithHttpMessagesAsync( + azureSqlServer.ResourceGroupName ?? string.Empty, + azureSqlServer.Name, + firewallRuleRequest.FirewallRuleName, + firewallRule, + GetCustomHeaders()), + SR.FirewallRuleCreationFailedWithError); + var response = firewallRuleResponse.Body; + return new FirewallRuleResponse() { - throw new AzureResourceFailedException( - string.Format(CultureInfo.CurrentCulture, SR.FirewallRuleCreationFailedWithError, ex.Message), ex.Response.StatusCode); - } + StartIpAddress = response.StartIpAddress, + EndIpAddress = response.EndIpAddress, + Created = true + }; } // else respond with failure case return new FirewallRuleResponse() @@ -228,40 +206,6 @@ private Dictionary> GetCustomHeaders() return headers; } - /// - /// Returns the azure resource groups for given subscription - /// - private async Task> GetResourceGroupsAsync(AzureResourceManagementSession vsAzureResourceManagementSession) - { - try - { - if (vsAzureResourceManagementSession != null) - { - try - { - IResourceGroupsOperations resourceGroupOperations = vsAzureResourceManagementSession.ResourceManagementClient.ResourceGroups; - IPage resourceGroupList = await resourceGroupOperations.ListAsync(); - if (resourceGroupList != null) - { - return resourceGroupList.AsEnumerable(); - } - - } - catch (HttpOperationException ex) - { - throw new AzureResourceFailedException(string.Format(CultureInfo.CurrentCulture, SR.FailedToGetAzureResourceGroupsErrorMessage, ex.Message), ex.Response.StatusCode); - } - } - - return Enumerable.Empty(); - } - catch (Exception ex) - { - TraceException(TraceEventType.Error, (int)TraceId.AzureResource, ex, "Failed to get azure resource groups"); - throw; - } - } - /// /// Gets all subscription contexts under a specific user account. Queries all tenants for the account and uses these to log in /// and retrieve subscription information as needed @@ -278,7 +222,7 @@ public async Task> GetSubscrip { var ex = response.Errors.First(); throw new AzureResourceFailedException( - string.Format(CultureInfo.CurrentCulture, SR.AzureSubscriptionFailedErrorMessage, ex.Message)); + string.Format(CultureInfo.CurrentCulture, SR.FailedToGetAzureSubscriptionsErrorMessage, ex.Message)); } contexts.AddRange(response.Data); stopwatch.Stop(); @@ -318,20 +262,12 @@ private async Task> GetSubscriptionsAsync(Subscription { if (subscriptionClient != null) { - try + IPage subscriptionList = await ExecuteCloudRequest( + () => subscriptionClient.Subscriptions.ListAsync(), + SR.FailedToGetAzureSubscriptionsErrorMessage); + if (subscriptionList != null) { - ISubscriptionsOperations subscriptionsOperations = subscriptionClient.Subscriptions; - IPage subscriptionList = await subscriptionsOperations.ListAsync(); - if (subscriptionList != null) - { - return subscriptionList.AsEnumerable(); - } - - } - catch (HttpOperationException ex) - { - throw new AzureResourceFailedException( - string.Format(CultureInfo.CurrentCulture, SR.AzureSubscriptionFailedErrorMessage, ex.Message), ex.Response.StatusCode); + return subscriptionList.AsEnumerable(); } } @@ -382,5 +318,29 @@ private ServiceClientCredentials CreateCredentials(IAzureUserAccountSubscription } throw new NotSupportedException("This uses an unknown subscription type"); } + + private async Task ExecuteCloudRequest(Func> operation, string errorOccurredMsg) + { + try + { + return await operation(); + } + catch(CloudException ex) + { + if (ex.Body != null && string.Equals(ExpiredTokenCode, ex.Body.Code, StringComparison.OrdinalIgnoreCase)) + { + // Throw an expired token exception, which indicates that the operation could succeed if the user reauthenticates + throw new ExpiredTokenException(ex.Message); + } + throw new AzureResourceFailedException( + string.Format(CultureInfo.CurrentCulture, errorOccurredMsg, ex.Message), ex.Response.StatusCode); + } + catch (HttpOperationException ex) + { + throw new AzureResourceFailedException( + string.Format(CultureInfo.CurrentCulture, errorOccurredMsg, ex.Message), ex.Response.StatusCode); + } + + } } } diff --git a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Localization/sr.cs b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Localization/sr.cs index 44f609074d..5ccc2bfad8 100755 --- a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Localization/sr.cs +++ b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Localization/sr.cs @@ -27,7 +27,7 @@ public static CultureInfo Culture Keys.Culture = value; } } - + public static string FailedToGetAzureDatabasesErrorMessage { @@ -35,7 +35,7 @@ public static string FailedToGetAzureDatabasesErrorMessage { return Keys.GetString(Keys.FailedToGetAzureDatabasesErrorMessage); } - } + } public static string FailedToGetAzureSubscriptionsErrorMessage { @@ -43,7 +43,7 @@ public static string FailedToGetAzureSubscriptionsErrorMessage { return Keys.GetString(Keys.FailedToGetAzureSubscriptionsErrorMessage); } - } + } public static string FailedToGetAzureResourceGroupsErrorMessage { @@ -51,7 +51,7 @@ public static string FailedToGetAzureResourceGroupsErrorMessage { return Keys.GetString(Keys.FailedToGetAzureResourceGroupsErrorMessage); } - } + } public static string FailedToGetAzureSqlServersErrorMessage { @@ -59,7 +59,7 @@ public static string FailedToGetAzureSqlServersErrorMessage { return Keys.GetString(Keys.FailedToGetAzureSqlServersErrorMessage); } - } + } public static string FailedToGetAzureSqlServersWithError { @@ -67,7 +67,7 @@ public static string FailedToGetAzureSqlServersWithError { return Keys.GetString(Keys.FailedToGetAzureSqlServersWithError); } - } + } public static string FirewallRuleCreationFailed { @@ -75,7 +75,7 @@ public static string FirewallRuleCreationFailed { return Keys.GetString(Keys.FirewallRuleCreationFailed); } - } + } public static string FirewallRuleCreationFailedWithError { @@ -83,23 +83,31 @@ public static string FirewallRuleCreationFailedWithError { return Keys.GetString(Keys.FirewallRuleCreationFailedWithError); } - } + } + + public static string UnsupportedAuthType + { + get + { + return Keys.GetString(Keys.UnsupportedAuthType); + } + } - public static string AzureSubscriptionFailedErrorMessage + public static string UserNotFoundError { get { - return Keys.GetString(Keys.AzureSubscriptionFailedErrorMessage); + return Keys.GetString(Keys.UserNotFoundError); } - } + } - public static string UnsupportedAuthType + public static string UserNeedsAuthenticationError { get { - return Keys.GetString(Keys.UnsupportedAuthType); + return Keys.GetString(Keys.UserNeedsAuthenticationError); } - } + } [System.Runtime.CompilerServices.CompilerGeneratedAttribute()] public class Keys @@ -107,34 +115,37 @@ public class Keys static ResourceManager resourceManager = new ResourceManager("Microsoft.SqlTools.ResourceProvider.DefaultImpl.Localization.SR", typeof(SR).GetTypeInfo().Assembly); static CultureInfo _culture = null; - - - public const string FailedToGetAzureDatabasesErrorMessage = "FailedToGetAzureDatabasesErrorMessage"; - - - public const string FailedToGetAzureSubscriptionsErrorMessage = "FailedToGetAzureSubscriptionsErrorMessage"; - - - public const string FailedToGetAzureResourceGroupsErrorMessage = "FailedToGetAzureResourceGroupsErrorMessage"; - - - public const string FailedToGetAzureSqlServersErrorMessage = "FailedToGetAzureSqlServersErrorMessage"; - - - public const string FailedToGetAzureSqlServersWithError = "FailedToGetAzureSqlServersWithError"; - - - public const string FirewallRuleCreationFailed = "FirewallRuleCreationFailed"; - - - public const string FirewallRuleCreationFailedWithError = "FirewallRuleCreationFailedWithError"; - - - public const string AzureSubscriptionFailedErrorMessage = "AzureSubscriptionFailedErrorMessage"; - - - public const string UnsupportedAuthType = "UnsupportedAuthType"; - + + + public const string FailedToGetAzureDatabasesErrorMessage = "FailedToGetAzureDatabasesErrorMessage"; + + + public const string FailedToGetAzureSubscriptionsErrorMessage = "FailedToGetAzureSubscriptionsErrorMessage"; + + + public const string FailedToGetAzureResourceGroupsErrorMessage = "FailedToGetAzureResourceGroupsErrorMessage"; + + + public const string FailedToGetAzureSqlServersErrorMessage = "FailedToGetAzureSqlServersErrorMessage"; + + + public const string FailedToGetAzureSqlServersWithError = "FailedToGetAzureSqlServersWithError"; + + + public const string FirewallRuleCreationFailed = "FirewallRuleCreationFailed"; + + + public const string FirewallRuleCreationFailedWithError = "FirewallRuleCreationFailedWithError"; + + + public const string UnsupportedAuthType = "UnsupportedAuthType"; + + + public const string UserNotFoundError = "UserNotFoundError"; + + + public const string UserNeedsAuthenticationError = "UserNeedsAuthenticationError"; + private Keys() { } @@ -155,7 +166,7 @@ public static string GetString(string key) { return resourceManager.GetString(key, _culture); } - - } - } -} + + } + } +} diff --git a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Localization/sr.resx b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Localization/sr.resx index 81c293a464..540f690db1 100755 --- a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Localization/sr.resx +++ b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Localization/sr.resx @@ -120,37 +120,41 @@ An error occurred while getting Azure databases - + An error occurred while getting Azure subscriptions: {0} - + An error occurred while getting Azure resource groups: {0} - + An error occurred while getting Azure Sql Servers - + An error occurred while getting Azure Sql Servers: '{0}' - + An error occurred while creating a new firewall rule. - + An error occurred while creating a new firewall rule: '{0}' - - - An error occurred while getting Azure subscriptions - - + Unsupported account type '{0}' for this provider - - + + + No user was found, cannot execute the operation + + + + The current user must be reauthenticated before executing this operation + + + diff --git a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Localization/sr.strings b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Localization/sr.strings index b8e5ba2dd7..603f04a6ee 100644 --- a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Localization/sr.strings +++ b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Localization/sr.strings @@ -29,5 +29,6 @@ FailedToGetAzureSqlServersErrorMessage = An error occurred while getting Azure S FailedToGetAzureSqlServersWithError = An error occurred while getting Azure Sql Servers: '{0}' FirewallRuleCreationFailed = An error occurred while creating a new firewall rule. FirewallRuleCreationFailedWithError = An error occurred while creating a new firewall rule: '{0}' -AzureSubscriptionFailedErrorMessage = An error occurred while getting Azure subscriptions -UnsupportedAuthType = Unsupported account type '{0}' for this provider \ No newline at end of file +UnsupportedAuthType = Unsupported account type '{0}' for this provider +UserNotFoundError = No user was found, cannot execute the operation +UserNeedsAuthenticationError = The current user must be reauthenticated before executing this operation \ No newline at end of file diff --git a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Localization/sr.xlf b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Localization/sr.xlf index 41c18ae027..1cd1d84bf3 100644 --- a/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Localization/sr.xlf +++ b/src/Microsoft.SqlTools.ResourceProvider.DefaultImpl/Localization/sr.xlf @@ -22,11 +22,6 @@ An error occurred while creating a new firewall rule. - - An error occurred while getting Azure subscriptions - An error occurred while getting Azure subscriptions - - Unsupported account type '{0}' for this provider Unsupported account type '{0}' for this provider @@ -47,6 +42,16 @@ An error occurred while getting Azure subscriptions: {0} + + The current user must be reauthenticated before executing this operation + The current user must be reauthenticated before executing this operation + + + + No user was found, cannot execute the operation + No user was found, cannot execute the operation + + \ No newline at end of file diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Completion/AutoCompletionResultTest.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Completion/AutoCompletionResultTest.cs index a6219cd7df..8779265608 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Completion/AutoCompletionResultTest.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/Completion/AutoCompletionResultTest.cs @@ -14,10 +14,10 @@ namespace Microsoft.SqlTools.ServiceLayer.UnitTests.Completion public class AutoCompletionResultTest { [Fact] - public void MetricsShouldGetSortedGivenUnSortedArray() + public void CompletionShouldRecordDuration() { AutoCompletionResult result = new AutoCompletionResult(); - int duration = 2000; + int duration = 200; Thread.Sleep(duration); result.CompleteResult(new CompletionItem[] { }); Assert.True(result.Duration >= duration); diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/Azure/AzureAuthenticationManagerTest.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/Azure/AzureAuthenticationManagerTest.cs index a685ee32ab..9c0bf2a684 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/Azure/AzureAuthenticationManagerTest.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/Azure/AzureAuthenticationManagerTest.cs @@ -51,7 +51,7 @@ public async Task GetSubscriptionShouldThrowWhenUserNeedsAuthentication() var currentUserAccount = CreateAccount(); currentUserAccount.Account.IsStale = true; IAzureAuthenticationManager accountManager = await CreateAccountManager(currentUserAccount, null); - await Assert.ThrowsAsync(() => accountManager.GetSelectedSubscriptionsAsync()); + await Assert.ThrowsAsync(() => accountManager.GetSelectedSubscriptionsAsync()); } [Fact] diff --git a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/ResourceProviderServiceTests.cs b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/ResourceProviderServiceTests.cs index 832fd91aed..9876fa2fce 100644 --- a/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/ResourceProviderServiceTests.cs +++ b/test/Microsoft.SqlTools.ServiceLayer.UnitTests/ResourceProvider/ResourceProviderServiceTests.cs @@ -123,7 +123,7 @@ public async Task TestHandleFirewallRuleDoesntBreakWithoutIp() // When I ask whether the service can process an error as a firewall rule request await TestUtils.RunAndVerify((context) => ResourceProviderService.ProcessHandleFirewallRuleRequest(handleFirewallParams, context), (response) => { - // Then I expect the response to be fakse as we require the known IP address to function + // Then I expect the response to be OK as we require the known IP address to function Assert.NotNull(response); Assert.False(response.Result); Assert.Equal(string.Empty, response.IpAddress); @@ -168,10 +168,42 @@ await TestUtils.RunAndVerify( (context) => ResourceProviderService.HandleCreateFirewallRuleRequest(createFirewallParams, context), (response) => { - // Then I expect the response to be fakse as we require the known IP address to function + // Then I expect the response to be OK as we require the known IP address to function Assert.NotNull(response); Assert.Null(response.ErrorMessage); Assert.True(response.Result); + Assert.False(response.IsTokenExpiredFailure); + }); + } + + [Fact] + public async Task TestCreateFirewallRuleHandlesTokenExpiration() + { + // Given the token has expired + string serverName = "myserver.database.windows.net"; + var sub1Mock = new Mock(); + SetupCreateSession(); + string expectedErrorMsg = "Token is expired"; + AuthenticationManagerMock.Setup(a => a.GetSubscriptionsAsync()).ThrowsAsync(new ExpiredTokenException(expectedErrorMsg)); + + // When I request the firewall be created + var createFirewallParams = new CreateFirewallRuleParams() + { + ServerName = serverName, + StartIpAddress = "1.1.1.1", + EndIpAddress = "1.1.1.255", + Account = CreateAccount(), + SecurityTokenMappings = new Dictionary() + }; + await TestUtils.RunAndVerify( + (context) => ResourceProviderService.HandleCreateFirewallRuleRequest(createFirewallParams, context), + (response) => + { + // Then I expect the response to indicate that we failed due to token expiration + Assert.NotNull(response); + Assert.Equal(expectedErrorMsg, response.ErrorMessage); + Assert.True(response.IsTokenExpiredFailure); + Assert.False(response.Result); }); }