Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ namespace Aspire.Hosting.Azure.Provisioning;

internal sealed class AzureProvisionerOptions
{
public string? TenantId { get; set; }

public string? SubscriptionId { get; set; }

public string? ResourceGroup { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ internal abstract partial class BaseProvisioningContextProvider(
internal const string LocationName = "Location";
internal const string SubscriptionIdName = "SubscriptionId";
internal const string ResourceGroupName = "ResourceGroup";
internal const string TenantName = "Tenant";

protected readonly IInteractionService _interactionService = interactionService;
protected readonly AzureProvisionerOptions _options = options.Value;
Expand Down Expand Up @@ -161,6 +162,10 @@ public virtual async Task<ProvisioningContext> CreateProvisioningContextAsync(Js
azureSection["Location"] = _options.Location;
azureSection["SubscriptionId"] = _options.SubscriptionId;
azureSection["ResourceGroup"] = resourceGroupName;
if (!string.IsNullOrEmpty(_options.TenantId))
{
azureSection["TenantId"] = _options.TenantId;
}
if (_options.AllowResourceGroupCreation.HasValue)
{
azureSection["AllowResourceGroupCreation"] = _options.AllowResourceGroupCreation.Value;
Expand All @@ -180,7 +185,56 @@ public virtual async Task<ProvisioningContext> CreateProvisioningContextAsync(Js

protected abstract string GetDefaultResourceGroupName();

protected async Task<(List<KeyValuePair<string, string>>? subscriptionOptions, bool fetchSucceeded)> TryGetSubscriptionsAsync(CancellationToken cancellationToken)
protected async Task<(List<KeyValuePair<string, string>>? tenantOptions, bool fetchSucceeded)> TryGetTenantsAsync(CancellationToken cancellationToken)
{
List<KeyValuePair<string, string>>? tenantOptions = null;
var fetchSucceeded = false;

try
{
var credential = _tokenCredentialProvider.TokenCredential;
var armClient = _armClientProvider.GetArmClient(credential);
var availableTenants = await armClient.GetAvailableTenantsAsync(cancellationToken).ConfigureAwait(false);
var tenantList = availableTenants.ToList();

if (tenantList.Count > 0)
{
tenantOptions = tenantList
.Select(t =>
{
var tenantId = t.TenantId?.ToString() ?? "";

// Build display name: prefer DisplayName, fall back to domain, then to "Unknown"
var displayName = !string.IsNullOrEmpty(t.DisplayName)
? t.DisplayName
: !string.IsNullOrEmpty(t.DefaultDomain)
? t.DefaultDomain
: "Unknown";

// Build full description
var description = displayName;
if (!string.IsNullOrEmpty(t.DefaultDomain) && t.DisplayName != t.DefaultDomain)
{
description += $" ({t.DefaultDomain})";
}
description += $" — {tenantId}";

return KeyValuePair.Create(tenantId, description);
})
.OrderBy(kvp => kvp.Value)
.ToList();
fetchSucceeded = true;
}
}
catch (Exception ex)
{
_logger.LogWarning(ex, "Failed to enumerate available tenants. Falling back to manual input.");
}

return (tenantOptions, fetchSucceeded);
}

protected async Task<(List<KeyValuePair<string, string>>? subscriptionOptions, bool fetchSucceeded)> TryGetSubscriptionsAsync(string? tenantId, CancellationToken cancellationToken)
{
List<KeyValuePair<string, string>>? subscriptionOptions = null;
var fetchSucceeded = false;
Expand All @@ -189,7 +243,7 @@ public virtual async Task<ProvisioningContext> CreateProvisioningContextAsync(Js
{
var credential = _tokenCredentialProvider.TokenCredential;
var armClient = _armClientProvider.GetArmClient(credential);
var availableSubscriptions = await armClient.GetAvailableSubscriptionsAsync(cancellationToken).ConfigureAwait(false);
var availableSubscriptions = await armClient.GetAvailableSubscriptionsAsync(tenantId, cancellationToken).ConfigureAwait(false);
var subscriptionList = availableSubscriptions.ToList();

if (subscriptionList.Count > 0)
Expand All @@ -208,6 +262,11 @@ public virtual async Task<ProvisioningContext> CreateProvisioningContextAsync(Js
return (subscriptionOptions, fetchSucceeded);
}

protected async Task<(List<KeyValuePair<string, string>>? subscriptionOptions, bool fetchSucceeded)> TryGetSubscriptionsAsync(CancellationToken cancellationToken)
{
return await TryGetSubscriptionsAsync(_options.TenantId, cancellationToken).ConfigureAwait(false);
}

protected async Task<(List<KeyValuePair<string, string>> locationOptions, bool fetchSucceeded)> TryGetLocationsAsync(string subscriptionId, CancellationToken cancellationToken)
{
List<KeyValuePair<string, string>>? locationOptions = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,18 @@ private sealed class DefaultArmClient(ArmClient armClient) : IArmClient
return (subscriptionResource, tenantResource);
}

public async Task<IEnumerable<ITenantResource>> GetAvailableTenantsAsync(CancellationToken cancellationToken = default)
{
var tenants = new List<ITenantResource>();

await foreach (var tenant in armClient.GetTenants().GetAllAsync(cancellationToken: cancellationToken).ConfigureAwait(false))
{
tenants.Add(new DefaultTenantResource(tenant));
}

return tenants;
}

public async Task<IEnumerable<ISubscriptionResource>> GetAvailableSubscriptionsAsync(CancellationToken cancellationToken = default)
{
var subscriptions = new List<ISubscriptionResource>();
Expand All @@ -62,6 +74,27 @@ public async Task<IEnumerable<ISubscriptionResource>> GetAvailableSubscriptionsA
return subscriptions;
}

public async Task<IEnumerable<ISubscriptionResource>> GetAvailableSubscriptionsAsync(string? tenantId, CancellationToken cancellationToken = default)
{
if (string.IsNullOrEmpty(tenantId))
{
return await GetAvailableSubscriptionsAsync(cancellationToken).ConfigureAwait(false);
}

var subscriptions = new List<ISubscriptionResource>();

await foreach (var subscription in armClient.GetSubscriptions().GetAllAsync(cancellationToken: cancellationToken).ConfigureAwait(false))
{
// Filter subscriptions by tenant ID
if (subscription.Data.TenantId?.ToString().Equals(tenantId, StringComparison.OrdinalIgnoreCase) == true)
{
subscriptions.Add(new DefaultSubscriptionResource(subscription));
}
}

return subscriptions;
}

public async Task<IEnumerable<(string Name, string DisplayName)>> GetAvailableLocationsAsync(string subscriptionId, CancellationToken cancellationToken = default)
{
var subscription = await armClient.GetSubscriptions().GetAsync(subscriptionId, cancellationToken).ConfigureAwait(false);
Expand All @@ -78,6 +111,7 @@ public async Task<IEnumerable<ISubscriptionResource>> GetAvailableSubscriptionsA
private sealed class DefaultTenantResource(TenantResource tenantResource) : ITenantResource
{
public Guid? TenantId => tenantResource.Data.TenantId;
public string? DisplayName => tenantResource.Data.DisplayName;
public string? DefaultDomain => tenantResource.Data.DefaultDomain;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,38 @@ internal class DefaultTokenCredentialProvider : ITokenCredentialProvider

public DefaultTokenCredentialProvider(
ILogger<DefaultTokenCredentialProvider> logger,
IOptions<AzureProvisionerOptions> options,
DistributedApplicationExecutionContext distributedApplicationExecutionContext)
IOptions<AzureProvisionerOptions> options)
{
_logger = logger;

// Optionally configured in AppHost appSettings under "Azure" : { "CredentialSource": "AzureCli" }
var credentialSetting = options.Value.CredentialSource;

// Use AzureCli as default for publish mode when no explicit credential source is set
var credentialSource = credentialSetting switch
TokenCredential credential = options.Value.CredentialSource switch
{
null or "Default" when distributedApplicationExecutionContext.IsPublishMode => "AzureCli",
_ => credentialSetting ?? "Default"
};

TokenCredential credential = credentialSource switch
{
"AzureCli" => new AzureCliCredential(),
"AzurePowerShell" => new AzurePowerShellCredential(),
"VisualStudio" => new VisualStudioCredential(),
"AzureDeveloperCli" => new AzureDeveloperCliCredential(),
"AzureCli" => new AzureCliCredential(new()
{
AdditionallyAllowedTenants = { "*" }
}),
"AzurePowerShell" => new AzurePowerShellCredential(new()
{
AdditionallyAllowedTenants = { "*" }
}),
"VisualStudio" => new VisualStudioCredential(new()
{
AdditionallyAllowedTenants = { "*" }
}),
"AzureDeveloperCli" => new AzureDeveloperCliCredential(new()
{
AdditionallyAllowedTenants = { "*" }
}),
"InteractiveBrowser" => new InteractiveBrowserCredential(),
_ => new DefaultAzureCredential(new DefaultAzureCredentialOptions()
{
ExcludeManagedIdentityCredential = true,
ExcludeWorkloadIdentityCredential = true,
ExcludeAzurePowerShellCredential = true,
CredentialProcessTimeout = TimeSpan.FromSeconds(15)
CredentialProcessTimeout = TimeSpan.FromSeconds(15),
AdditionallyAllowedTenants = { "*" }
})
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,21 @@ internal interface IArmClient
/// </summary>
Task<(ISubscriptionResource subscription, ITenantResource tenant)> GetSubscriptionAndTenantAsync(CancellationToken cancellationToken = default);

/// <summary>
/// Gets all tenants accessible to the current user.
/// </summary>
Task<IEnumerable<ITenantResource>> GetAvailableTenantsAsync(CancellationToken cancellationToken = default);

/// <summary>
/// Gets all subscriptions accessible to the current user.
/// </summary>
Task<IEnumerable<ISubscriptionResource>> GetAvailableSubscriptionsAsync(CancellationToken cancellationToken = default);

/// <summary>
/// Gets all subscriptions accessible to the current user filtered by tenant ID.
/// </summary>
Task<IEnumerable<ISubscriptionResource>> GetAvailableSubscriptionsAsync(string? tenantId, CancellationToken cancellationToken = default);

/// <summary>
/// Gets all available locations for the specified subscription.
/// </summary>
Expand Down Expand Up @@ -174,6 +184,11 @@ internal interface ITenantResource
/// </summary>
Guid? TenantId { get; }

/// <summary>
/// Gets the display name.
/// </summary>
string? DisplayName { get; }

/// <summary>
/// Gets the default domain.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,16 @@ private async Task RetrieveAzureProvisioningOptions(CancellationToken cancellati
{
while (_options.Location == null || _options.SubscriptionId == null)
{
// Skip tenant prompting if subscription ID is already set
if (_options.TenantId == null && _options.SubscriptionId == null)
{
await PromptForTenantAsync(cancellationToken).ConfigureAwait(false);
if (_options.TenantId == null)
{
continue;
}
}

if (_options.SubscriptionId == null)
{
await PromptForSubscriptionAsync(cancellationToken).ConfigureAwait(false);
Expand All @@ -97,6 +107,105 @@ private async Task RetrieveAzureProvisioningOptions(CancellationToken cancellati
}
}

private async Task PromptForTenantAsync(CancellationToken cancellationToken)
{
List<KeyValuePair<string, string>>? tenantOptions = null;
var fetchSucceeded = false;

var step = await activityReporter.CreateStepAsync(
"fetch-tenant",
cancellationToken).ConfigureAwait(false);

await using (step.ConfigureAwait(false))
{
try
{
var task = await step.CreateTaskAsync("Fetching available tenants", cancellationToken).ConfigureAwait(false);

await using (task.ConfigureAwait(false))
{
(tenantOptions, fetchSucceeded) = await TryGetTenantsAsync(cancellationToken).ConfigureAwait(false);
}

if (fetchSucceeded)
{
await step.SucceedAsync($"Found {tenantOptions!.Count} available tenant(s)", cancellationToken).ConfigureAwait(false);
}
else
{
await step.WarnAsync("Failed to fetch tenants, falling back to manual entry", cancellationToken).ConfigureAwait(false);
}
}
catch (Exception ex)
{
_logger.LogError(ex, "Failed to retrieve Azure tenant information.");
await step.FailAsync($"Failed to retrieve tenant information: {ex.Message}", cancellationToken).ConfigureAwait(false);
throw;
}
}

if (tenantOptions?.Count > 0)
{
var result = await _interactionService.PromptInputsAsync(
AzureProvisioningStrings.TenantDialogTitle,
AzureProvisioningStrings.TenantSelectionMessage,
[
new InteractionInput
{
Name = TenantName,
InputType = InputType.Choice,
Label = AzureProvisioningStrings.TenantLabel,
Required = true,
Options = [..tenantOptions]
}
],
new InputsDialogInteractionOptions
{
EnableMessageMarkdown = false
},
cancellationToken).ConfigureAwait(false);

if (!result.Canceled)
{
_options.TenantId = result.Data[TenantName].Value;
return;
}
}

var manualResult = await _interactionService.PromptInputsAsync(
AzureProvisioningStrings.TenantDialogTitle,
AzureProvisioningStrings.TenantManualEntryMessage,
[
new InteractionInput
{
Name = TenantName,
InputType = InputType.SecretText,
Label = AzureProvisioningStrings.TenantLabel,
Placeholder = AzureProvisioningStrings.TenantPlaceholder,
Required = true
}
],
new InputsDialogInteractionOptions
{
EnableMessageMarkdown = false,
ValidationCallback = static (validationContext) =>
{
var tenantInput = validationContext.Inputs[TenantName];
if (!Guid.TryParse(tenantInput.Value, out var _))
{
validationContext.AddValidationError(tenantInput, AzureProvisioningStrings.ValidationTenantIdInvalid);
}
return Task.CompletedTask;
}
},
cancellationToken).ConfigureAwait(false);

if (!manualResult.Canceled)
{
_options.TenantId = manualResult.Data[TenantName].Value;
}
}

private async Task PromptForSubscriptionAsync(CancellationToken cancellationToken)
{
List<KeyValuePair<string, string>>? subscriptionOptions = null;
Expand All @@ -114,7 +223,7 @@ private async Task PromptForSubscriptionAsync(CancellationToken cancellationToke

await using (task.ConfigureAwait(false))
{
(subscriptionOptions, fetchSucceeded) = await TryGetSubscriptionsAsync(cancellationToken).ConfigureAwait(false);
(subscriptionOptions, fetchSucceeded) = await TryGetSubscriptionsAsync(_options.TenantId, cancellationToken).ConfigureAwait(false);
}

if (fetchSucceeded)
Expand Down
Loading