diff --git a/cli/azd/grpc/proto/prompt.proto b/cli/azd/grpc/proto/prompt.proto index 94b806312d3..e4265f1fede 100644 --- a/cli/azd/grpc/proto/prompt.proto +++ b/cli/azd/grpc/proto/prompt.proto @@ -74,6 +74,7 @@ message PromptSubscriptionResponse { message PromptLocationRequest { AzureContext azure_context = 1; + repeated string allowed_locations = 2; } message PromptLocationResponse { diff --git a/cli/azd/internal/grpcserver/prompt_service.go b/cli/azd/internal/grpcserver/prompt_service.go index 119eda1d5bc..d6b128533e9 100644 --- a/cli/azd/internal/grpcserver/prompt_service.go +++ b/cli/azd/internal/grpcserver/prompt_service.go @@ -266,7 +266,14 @@ func (s *promptService) PromptLocation( return nil, err } - selectedLocation, err := s.prompter.PromptLocation(ctx, azureContext, nil) + var selectorOptions *prompt.SelectOptions + if len(req.AllowedLocations) > 0 { + selectorOptions = &prompt.SelectOptions{ + AllowedValues: req.AllowedLocations, + } + } + + selectedLocation, err := s.prompter.PromptLocation(ctx, azureContext, selectorOptions) if err != nil { return nil, err } @@ -843,6 +850,18 @@ func (s *promptService) PromptAiDeployment( // --- Step 3: Resolve capacity, optionally prompting --- capacity := ai.ResolveCapacity(selectedSku.sku, options.Capacity) + if req.Quota != nil && selectedSku.remaining != nil { + resolvedCapacity, ok := ai.ResolveCapacityWithQuota(selectedSku.sku, options.Capacity, *selectedSku.remaining) + if !ok { + return nil, aiStatusError( + codes.FailedPrecondition, + azdext.AiErrorReasonNoDeploymentMatch, + fmt.Sprintf("no deployment match for model %q with the selected SKU and quota", req.ModelName), + map[string]string{"model_name": req.ModelName}, + ) + } + capacity = resolvedCapacity + } if !req.UseDefaultCapacity { sku := selectedSku.sku @@ -1229,8 +1248,6 @@ func buildSkuCandidatesForVersion( continue } - capacity := ai.ResolveCapacity(sku, options.Capacity) - var remaining *float64 if quota != nil { if usageMap == nil { @@ -1244,7 +1261,11 @@ func buildSkuCandidatesForVersion( rem := usage.Limit - usage.CurrentValue remaining = &rem - if rem < minReq || (capacity > 0 && float64(capacity) > rem) { + if rem < minReq { + continue + } + + if _, ok := ai.ResolveCapacityWithQuota(sku, options.Capacity, rem); !ok { continue } } diff --git a/cli/azd/internal/grpcserver/prompt_service_test.go b/cli/azd/internal/grpcserver/prompt_service_test.go index b59c7391ae2..5373984ac99 100644 --- a/cli/azd/internal/grpcserver/prompt_service_test.go +++ b/cli/azd/internal/grpcserver/prompt_service_test.go @@ -6,6 +6,7 @@ package grpcserver import ( "context" "errors" + "slices" "testing" "github.com/azure/azure-dev/cli/azd/internal" @@ -216,6 +217,39 @@ func Test_PromptService_PromptLocation(t *testing.T) { mockPrompter.AssertExpectations(t) } +func Test_PromptService_PromptLocation_WithAllowedLocations(t *testing.T) { + mockPrompter := &mockprompt.MockPromptService{} + globalOptions := &internal.GlobalCommandOptions{NoPrompt: false} + + expectedLocation := &account.Location{ + Name: "westus3", + DisplayName: "West US 3", + RegionalDisplayName: "(US) West US 3", + } + + mockPrompter. + On("PromptLocation", mock.Anything, mock.Anything, mock.MatchedBy(func(opts *prompt.SelectOptions) bool { + return opts != nil && slices.Equal(opts.AllowedValues, []string{"westus3", "eastus2"}) + })). + Return(expectedLocation, nil) + + service := NewPromptService(mockPrompter, nil, nil, globalOptions) + + resp, err := service.PromptLocation(context.Background(), &azdext.PromptLocationRequest{ + AzureContext: &azdext.AzureContext{ + Scope: &azdext.AzureScope{ + SubscriptionId: "sub-123", + }, + }, + AllowedLocations: []string{"westus3", "eastus2"}, + }) + + require.NoError(t, err) + require.NotNil(t, resp.Location) + require.Equal(t, expectedLocation.Name, resp.Location.Name) + mockPrompter.AssertExpectations(t) +} + func Test_PromptService_PromptResourceGroup(t *testing.T) { mockPrompter := &mockprompt.MockPromptService{} globalOptions := &internal.GlobalCommandOptions{NoPrompt: false} @@ -826,6 +860,38 @@ func Test_buildSkuCandidatesForVersion(t *testing.T) { require.NotNil(t, candidates[0].remaining) require.Equal(t, float64(10), *candidates[0].remaining) }) + + t.Run("falls back to lower capacity that fits remaining quota", func(t *testing.T) { + deepSeekVersion := ai.AiModelVersion{ + Version: "1", + Skus: []ai.AiModelSku{ + { + Name: "GlobalStandard", + UsageName: "AIServices.GlobalStandard.DeepSeek-R1-0528", + DefaultCapacity: 5000, + MinCapacity: 0, + MaxCapacity: 5000, + CapacityStep: 0, + }, + }, + } + quota := &azdext.QuotaCheckOptions{ + MinRemainingCapacity: 1, + } + usageMap := map[string]ai.AiModelUsage{ + "AIServices.GlobalStandard.DeepSeek-R1-0528": { + Name: "AIServices.GlobalStandard.DeepSeek-R1-0528", + CurrentValue: 0, + Limit: 1000, + }, + } + + candidates := buildSkuCandidatesForVersion(deepSeekVersion, nil, quota, usageMap, false) + require.Len(t, candidates, 1) + require.Equal(t, "AIServices.GlobalStandard.DeepSeek-R1-0528", candidates[0].sku.UsageName) + require.NotNil(t, candidates[0].remaining) + require.Equal(t, float64(1000), *candidates[0].remaining) + }) } func Test_maxSkuCandidateRemaining(t *testing.T) { diff --git a/cli/azd/pkg/ai/model_service.go b/cli/azd/pkg/ai/model_service.go index 601df534e72..013aaec11aa 100644 --- a/cli/azd/pkg/ai/model_service.go +++ b/cli/azd/pkg/ai/model_service.go @@ -62,19 +62,15 @@ func (s *AiModelService) ListModels( return s.convertToAiModels(rawModels), nil } -// ListLocations returns subscription location names that can be used for model queries. +// ListLocations returns AI Services-supported location names that can be used for model queries. func (s *AiModelService) ListLocations( ctx context.Context, subscriptionId string, ) ([]string, error) { - subLocations, err := s.subManager.GetLocations(ctx, subscriptionId) + locations, err := s.azureClient.GetResourceSkuLocations( + ctx, subscriptionId, "AIServices", "S0", "Standard", "accounts") if err != nil { - return nil, fmt.Errorf("listing locations: %w", err) - } - - locations := make([]string, 0, len(subLocations)) - for _, loc := range subLocations { - locations = append(locations, loc.Name) + return nil, fmt.Errorf("listing AI Services locations: %w", err) } return locations, nil @@ -456,9 +452,8 @@ func (s *AiModelService) resolveDeployments( continue } - capacity := ResolveCapacity(sku, options.Capacity) - // Quota check + capacity := ResolveCapacity(sku, options.Capacity) if quotaOpts != nil && usageMap != nil { usage, ok := usageMap[sku.UsageName] if !ok { @@ -470,9 +465,15 @@ func (s *AiModelService) resolveDeployments( if minReq <= 0 { minReq = 1 } - if remaining < minReq || (capacity > 0 && float64(capacity) > remaining) { + if remaining < minReq { continue } + + resolvedCapacity, fitsQuota := ResolveCapacityWithQuota(sku, options.Capacity, remaining) + if !fitsQuota { + continue + } + capacity = resolvedCapacity } // Only set location when exactly one was provided — never guess. @@ -731,16 +732,109 @@ func convertSku(sku *armcognitiveservices.ModelSKU) AiModelSku { func ResolveCapacity(sku AiModelSku, preferred *int32) int32 { if preferred != nil { cap := *preferred - if cap > 0 && - (sku.MinCapacity <= 0 || cap >= sku.MinCapacity) && - (sku.MaxCapacity <= 0 || cap <= sku.MaxCapacity) && - (sku.CapacityStep <= 0 || cap%sku.CapacityStep == 0) { + if capacityValidForSku(sku, cap) { return cap } } return sku.DefaultCapacity } +// ResolveCapacityWithQuota resolves the deployment capacity for a SKU while considering remaining quota. +// If preferred is set, it must fit within the remaining quota or resolution fails. +// When preferred is unset and the default capacity exceeds remaining quota, it falls back to the highest +// valid capacity within the SKU constraints that still fits in the remaining quota. +func ResolveCapacityWithQuota(sku AiModelSku, preferred *int32, remaining float64) (int32, bool) { + capacity := ResolveCapacity(sku, preferred) + if preferred != nil { + return capacity, capacityFitsWithinQuota(sku, capacity, remaining) + } + + if capacityFitsWithinQuota(sku, capacity, remaining) { + return capacity, true + } + + return fallbackCapacityWithinQuota(sku, remaining) +} + +func capacityValidForSku(sku AiModelSku, capacity int32) bool { + if capacity <= 0 { + return false + } + + if sku.MinCapacity > 0 && capacity < sku.MinCapacity { + return false + } + + if sku.MaxCapacity > 0 && capacity > sku.MaxCapacity { + return false + } + + if sku.CapacityStep > 0 { + baseline := capacityStepBaseline(sku) + if capacity < baseline || (capacity-baseline)%sku.CapacityStep != 0 { + return false + } + } + + return true +} + +func capacityStepBaseline(sku AiModelSku) int32 { + if sku.MinCapacity > 0 { + return sku.MinCapacity + } + + return sku.CapacityStep +} + +func minimumValidCapacity(sku AiModelSku) int32 { + if sku.MinCapacity > 0 { + return sku.MinCapacity + } + + if sku.CapacityStep > 0 { + return sku.CapacityStep + } + + return 1 +} + +func capacityFitsWithinQuota(sku AiModelSku, capacity int32, remaining float64) bool { + if !capacityValidForSku(sku, capacity) { + return false + } + + return float64(capacity) <= remaining +} + +func fallbackCapacityWithinQuota(sku AiModelSku, remaining float64) (int32, bool) { + upperBound := int32(remaining) + if upperBound <= 0 { + return 0, false + } + + if sku.MaxCapacity > 0 && upperBound > sku.MaxCapacity { + upperBound = sku.MaxCapacity + } + + lowerBound := minimumValidCapacity(sku) + + if upperBound < lowerBound { + return 0, false + } + + if sku.CapacityStep <= 0 { + return upperBound, true + } + + candidate := upperBound - (upperBound-lowerBound)%sku.CapacityStep + if candidate < lowerBound { + return 0, false + } + + return candidate, true +} + // ModelHasDefaultVersion returns true if any version of the model is marked as default. func ModelHasDefaultVersion(model AiModel) bool { for _, v := range model.Versions { @@ -758,6 +852,9 @@ func modelHasQuota(model AiModel, usageMap map[string]AiModelUsage, minRemaining if ok { remaining := usage.Limit - usage.CurrentValue if remaining >= minRemaining { + if _, ok := ResolveCapacityWithQuota(sku, nil, remaining); !ok { + continue + } return true } } @@ -777,6 +874,9 @@ func maxModelRemainingQuota(model AiModel, usageMap map[string]AiModelUsage) (fl } remaining := usage.Limit - usage.CurrentValue + if _, ok := ResolveCapacityWithQuota(sku, nil, remaining); !ok { + continue + } if !found || remaining > maxRemaining { maxRemaining = remaining } diff --git a/cli/azd/pkg/ai/model_service_test.go b/cli/azd/pkg/ai/model_service_test.go index 96cc1aa5a0c..e56697fbcfa 100644 --- a/cli/azd/pkg/ai/model_service_test.go +++ b/cli/azd/pkg/ai/model_service_test.go @@ -269,6 +269,17 @@ func TestResolveCapacity(t *testing.T) { preferred: new(int32(15)), expected: 10, }, + { + name: "preferred capacity aligned relative to minimum", + sku: AiModelSku{ + DefaultCapacity: 7, + MinCapacity: 7, + MaxCapacity: 100, + CapacityStep: 5, + }, + preferred: new(int32(12)), + expected: 12, + }, { name: "no preferred uses default", sku: AiModelSku{ @@ -301,6 +312,80 @@ func TestResolveCapacity(t *testing.T) { } } +func TestResolveCapacityWithQuota(t *testing.T) { + t.Run("uses default when it fits in remaining quota", func(t *testing.T) { + capacity, ok := ResolveCapacityWithQuota(AiModelSku{ + DefaultCapacity: 25, + MinCapacity: 1, + MaxCapacity: 100, + CapacityStep: 1, + }, nil, 50) + + require.True(t, ok) + require.Equal(t, int32(25), capacity) + }) + + t.Run("falls back below default when no preferred capacity is set", func(t *testing.T) { + capacity, ok := ResolveCapacityWithQuota(AiModelSku{ + DefaultCapacity: 5000, + MinCapacity: 0, + MaxCapacity: 5000, + CapacityStep: 0, + }, nil, 1000) + + require.True(t, ok) + require.Equal(t, int32(1000), capacity) + }) + + t.Run("respects min and step when falling back", func(t *testing.T) { + capacity, ok := ResolveCapacityWithQuota(AiModelSku{ + DefaultCapacity: 3000, + MinCapacity: 100, + MaxCapacity: 3000, + CapacityStep: 100, + }, nil, 950) + + require.True(t, ok) + require.Equal(t, int32(900), capacity) + }) + + t.Run("fails when explicit preferred capacity does not fit", func(t *testing.T) { + capacity, ok := ResolveCapacityWithQuota(AiModelSku{ + DefaultCapacity: 5000, + MinCapacity: 0, + MaxCapacity: 5000, + CapacityStep: 0, + }, new(int32(5000)), 1000) + + require.False(t, ok) + require.Equal(t, int32(5000), capacity) + }) + + t.Run("fails when remaining quota is below effective minimum", func(t *testing.T) { + capacity, ok := ResolveCapacityWithQuota(AiModelSku{ + DefaultCapacity: 0, + MinCapacity: 100, + MaxCapacity: 3000, + CapacityStep: 100, + }, nil, 50) + + require.False(t, ok) + require.Equal(t, int32(0), capacity) + }) + + t.Run("falls back using step alignment relative to minimum", func(t *testing.T) { + capacity, ok := ResolveCapacityWithQuota(AiModelSku{ + DefaultCapacity: 27, + MinCapacity: 7, + MaxCapacity: 100, + CapacityStep: 5, + }, nil, 20) + + require.True(t, ok) + require.Equal(t, int32(17), capacity) + }) +} + func TestMaxModelRemainingQuota(t *testing.T) { model := AiModel{ Name: "gpt-4o", diff --git a/cli/azd/pkg/azdext/prompt.pb.go b/cli/azd/pkg/azdext/prompt.pb.go index 3aa40a5c8b9..4cc94bdae95 100644 --- a/cli/azd/pkg/azdext/prompt.pb.go +++ b/cli/azd/pkg/azdext/prompt.pb.go @@ -121,10 +121,11 @@ func (x *PromptSubscriptionResponse) GetSubscription() *Subscription { } type PromptLocationRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - AzureContext *AzureContext `protobuf:"bytes,1,opt,name=azure_context,json=azureContext,proto3" json:"azure_context,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + AzureContext *AzureContext `protobuf:"bytes,1,opt,name=azure_context,json=azureContext,proto3" json:"azure_context,omitempty"` + AllowedLocations []string `protobuf:"bytes,2,rep,name=allowed_locations,json=allowedLocations,proto3" json:"allowed_locations,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *PromptLocationRequest) Reset() { @@ -164,6 +165,13 @@ func (x *PromptLocationRequest) GetAzureContext() *AzureContext { return nil } +func (x *PromptLocationRequest) GetAllowedLocations() []string { + if x != nil { + return x.AllowedLocations + } + return nil +} + type PromptLocationResponse struct { state protoimpl.MessageState `protogen:"open.v1"` Location *Location `protobuf:"bytes,1,opt,name=location,proto3" json:"location,omitempty"` @@ -2131,9 +2139,10 @@ const file_prompt_proto_rawDesc = "" + "\aMessage\x18\x01 \x01(\tR\aMessage\x12 \n" + "\vHelpMessage\x18\x02 \x01(\tR\vHelpMessage\"V\n" + "\x1aPromptSubscriptionResponse\x128\n" + - "\fsubscription\x18\x01 \x01(\v2\x14.azdext.SubscriptionR\fsubscription\"R\n" + + "\fsubscription\x18\x01 \x01(\v2\x14.azdext.SubscriptionR\fsubscription\"\x7f\n" + "\x15PromptLocationRequest\x129\n" + - "\razure_context\x18\x01 \x01(\v2\x14.azdext.AzureContextR\fazureContext\"F\n" + + "\razure_context\x18\x01 \x01(\v2\x14.azdext.AzureContextR\fazureContext\x12+\n" + + "\x11allowed_locations\x18\x02 \x03(\tR\x10allowedLocations\"F\n" + "\x16PromptLocationResponse\x12,\n" + "\blocation\x18\x01 \x01(\v2\x10.azdext.LocationR\blocation\"\x95\x01\n" + "\x1aPromptResourceGroupRequest\x129\n" + diff --git a/cli/azd/pkg/prompt/prompt_service.go b/cli/azd/pkg/prompt/prompt_service.go index f8d4fc0ad97..ecda02797d6 100644 --- a/cli/azd/pkg/prompt/prompt_service.go +++ b/cli/azd/pkg/prompt/prompt_service.go @@ -110,6 +110,9 @@ type SelectOptions struct { Hint string // EnableFiltering specifies whether to enable filtering of choices. EnableFiltering *bool + // AllowedValues limits candidates for prompts that support value filtering, + // such as PromptLocation. + AllowedValues []string // Writer is the writer to use for output. Writer io.Writer } @@ -377,8 +380,10 @@ func (ps *promptService) PromptLocation( return nil, fmt.Errorf("failed to load locations: %w", err) } + locationList = filterLocationOptions(locationList, mergedOptions.AllowedValues) + for _, location := range locationList { - if location.Name == defaultLocation { + if strings.EqualFold(location.Name, defaultLocation) { return &account.Location{ Name: location.Name, DisplayName: location.DisplayName, @@ -388,7 +393,7 @@ func (ps *promptService) PromptLocation( } return nil, fmt.Errorf( - "default location '%s' not found. "+ + "default location '%s' not found in the available location options. "+ "Update your default location using 'azd config set defaults.location '", defaultLocation) } @@ -404,6 +409,14 @@ func (ps *promptService) PromptLocation( return nil, err } + locationList = filterLocationOptions(locationList, mergedOptions.AllowedValues) + + if len(locationList) == 0 { + return nil, fmt.Errorf( + "no locations matched the allowed locations filter. " + + "Verify the allowed locations configuration is correct") + } + locations := make([]*account.Location, len(locationList)) for i, location := range locationList { locations[i] = &account.Location{ @@ -419,11 +432,40 @@ func (ps *promptService) PromptLocation( return fmt.Sprintf("%s %s", location.RegionalDisplayName, output.WithGrayFormat("(%s)", location.Name)), nil }, Selected: func(resource *account.Location) bool { - return resource.Name == defaultLocation + return strings.EqualFold(resource.Name, defaultLocation) }, }) } +func filterLocationOptions(locations []account.Location, allowed []string) []account.Location { + if len(allowed) == 0 { + return locations + } + + allowedSet := make(map[string]struct{}, len(allowed)) + for _, location := range allowed { + normalized := normalizePromptLocationName(location) + if normalized == "" { + continue + } + allowedSet[normalized] = struct{}{} + } + + // If all allowed entries normalize to empty/whitespace, treat as "no filtering". + if len(allowedSet) == 0 { + return locations + } + + return slices.DeleteFunc(slices.Clone(locations), func(location account.Location) bool { + _, ok := allowedSet[normalizePromptLocationName(location.Name)] + return !ok + }) +} + +func normalizePromptLocationName(location string) string { + return strings.TrimSpace(strings.ToLower(location)) +} + // PromptResourceGroup prompts the user to select an Azure resource group. func (ps *promptService) PromptResourceGroup( ctx context.Context, diff --git a/cli/azd/pkg/prompt/prompt_service_test.go b/cli/azd/pkg/prompt/prompt_service_test.go index 01def1c4399..939125d8ad9 100644 --- a/cli/azd/pkg/prompt/prompt_service_test.go +++ b/cli/azd/pkg/prompt/prompt_service_test.go @@ -243,3 +243,159 @@ func TestPromptSubscription_NoPrompt_DefaultNotFound_DemoModeRedactsId(t *testin require.ErrorContains(t, err, "default subscription not found") require.False(t, strings.Contains(err.Error(), "sub-secret")) } + +func TestPromptLocation_NoPrompt_FiltersAllowedValues(t *testing.T) { + cfg := config.NewEmptyConfig() + err := cfg.Set("defaults.location", "westus3") + require.NoError(t, err) + + ucm := newInMemoryUserConfigManager(cfg) + authManager := &mockauth.MockAuthManager{} + subscriptionManager := &mockaccount.MockSubscriptionManager{} + resourceService := &mockazapi.MockResourceService{} + mockConsole := mockinput.NewMockConsole() + mockConsole.SetNoPromptMode(true) + + subscriptionManager. + On("GetLocations", mock.Anything, "sub-123"). + Return([]account.Location{ + {Name: "eastus2", DisplayName: "East US 2", RegionalDisplayName: "(US) East US 2"}, + {Name: "westus3", DisplayName: "West US 3", RegionalDisplayName: "(US) West US 3"}, + }, nil) + + ps := NewPromptService( + authManager, + mockConsole, + ucm, + subscriptionManager, + resourceService, + &internal.GlobalCommandOptions{NoPrompt: true}, + ) + + location, err := ps.PromptLocation(context.Background(), &AzureContext{ + Scope: AzureScope{SubscriptionId: "sub-123"}, + }, &SelectOptions{ + AllowedValues: []string{"westus3"}, + }) + + require.NoError(t, err) + require.Equal(t, "westus3", location.Name) + subscriptionManager.AssertExpectations(t) +} + +func TestPromptLocation_NoPrompt_FiltersAllowedValuesCaseInsensitive(t *testing.T) { + cfg := config.NewEmptyConfig() + err := cfg.Set("defaults.location", "WESTUS3") + require.NoError(t, err) + + ucm := newInMemoryUserConfigManager(cfg) + authManager := &mockauth.MockAuthManager{} + subscriptionManager := &mockaccount.MockSubscriptionManager{} + resourceService := &mockazapi.MockResourceService{} + mockConsole := mockinput.NewMockConsole() + mockConsole.SetNoPromptMode(true) + + subscriptionManager. + On("GetLocations", mock.Anything, "sub-123"). + Return([]account.Location{ + {Name: "eastus2", DisplayName: "East US 2", RegionalDisplayName: "(US) East US 2"}, + {Name: "westus3", DisplayName: "West US 3", RegionalDisplayName: "(US) West US 3"}, + }, nil) + + ps := NewPromptService( + authManager, + mockConsole, + ucm, + subscriptionManager, + resourceService, + &internal.GlobalCommandOptions{NoPrompt: true}, + ) + + location, err := ps.PromptLocation(context.Background(), &AzureContext{ + Scope: AzureScope{SubscriptionId: "sub-123"}, + }, &SelectOptions{ + AllowedValues: []string{"WestUS3"}, + }) + + require.NoError(t, err) + require.Equal(t, "westus3", location.Name) + subscriptionManager.AssertExpectations(t) +} + +func TestPromptLocation_NoPrompt_DefaultFilteredOut(t *testing.T) { + cfg := config.NewEmptyConfig() + err := cfg.Set("defaults.location", "westus3") + require.NoError(t, err) + + ucm := newInMemoryUserConfigManager(cfg) + authManager := &mockauth.MockAuthManager{} + subscriptionManager := &mockaccount.MockSubscriptionManager{} + resourceService := &mockazapi.MockResourceService{} + mockConsole := mockinput.NewMockConsole() + mockConsole.SetNoPromptMode(true) + + subscriptionManager. + On("GetLocations", mock.Anything, "sub-123"). + Return([]account.Location{ + {Name: "eastus2", DisplayName: "East US 2", RegionalDisplayName: "(US) East US 2"}, + {Name: "westus3", DisplayName: "West US 3", RegionalDisplayName: "(US) West US 3"}, + }, nil) + + ps := NewPromptService( + authManager, + mockConsole, + ucm, + subscriptionManager, + resourceService, + &internal.GlobalCommandOptions{NoPrompt: true}, + ) + + _, err = ps.PromptLocation(context.Background(), &AzureContext{ + Scope: AzureScope{SubscriptionId: "sub-123"}, + }, &SelectOptions{ + AllowedValues: []string{"eastus2"}, + }) + + require.Error(t, err) + require.ErrorContains(t, err, "default location 'westus3' not found in the available location options") + subscriptionManager.AssertExpectations(t) +} + +func TestPromptLocation_NoPrompt_IgnoresEmptyAllowedValues(t *testing.T) { + cfg := config.NewEmptyConfig() + err := cfg.Set("defaults.location", "westus3") + require.NoError(t, err) + + ucm := newInMemoryUserConfigManager(cfg) + authManager := &mockauth.MockAuthManager{} + subscriptionManager := &mockaccount.MockSubscriptionManager{} + resourceService := &mockazapi.MockResourceService{} + mockConsole := mockinput.NewMockConsole() + mockConsole.SetNoPromptMode(true) + + subscriptionManager. + On("GetLocations", mock.Anything, "sub-123"). + Return([]account.Location{ + {Name: "eastus2", DisplayName: "East US 2", RegionalDisplayName: "(US) East US 2"}, + {Name: "westus3", DisplayName: "West US 3", RegionalDisplayName: "(US) West US 3"}, + }, nil) + + ps := NewPromptService( + authManager, + mockConsole, + ucm, + subscriptionManager, + resourceService, + &internal.GlobalCommandOptions{NoPrompt: true}, + ) + + location, err := ps.PromptLocation(context.Background(), &AzureContext{ + Scope: AzureScope{SubscriptionId: "sub-123"}, + }, &SelectOptions{ + AllowedValues: []string{" ", ""}, + }) + + require.NoError(t, err) + require.Equal(t, "westus3", location.Name) + subscriptionManager.AssertExpectations(t) +}