Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cli/azd/grpc/proto/prompt.proto
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ message PromptSubscriptionResponse {

message PromptLocationRequest {
AzureContext azure_context = 1;
repeated string allowed_locations = 2;
}

message PromptLocationResponse {
Expand Down
29 changes: 25 additions & 4 deletions cli/azd/internal/grpcserver/prompt_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,14 @@ func (s *promptService) PromptLocation(
return nil, err
}

selectedLocation, err := s.prompter.PromptLocation(ctx, azureContext, nil)
var selectorOptions *prompt.SelectOptions
if len(req.AllowedLocations) > 0 {
selectorOptions = &prompt.SelectOptions{
AllowedValues: req.AllowedLocations,
}
}

selectedLocation, err := s.prompter.PromptLocation(ctx, azureContext, selectorOptions)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -843,6 +850,18 @@ func (s *promptService) PromptAiDeployment(

// --- Step 3: Resolve capacity, optionally prompting ---
capacity := ai.ResolveCapacity(selectedSku.sku, options.Capacity)
if req.Quota != nil && selectedSku.remaining != nil {
resolvedCapacity, ok := ai.ResolveCapacityWithQuota(selectedSku.sku, options.Capacity, *selectedSku.remaining)
if !ok {
return nil, aiStatusError(
codes.FailedPrecondition,
azdext.AiErrorReasonNoDeploymentMatch,
fmt.Sprintf("no deployment match for model %q with the selected SKU and quota", req.ModelName),
map[string]string{"model_name": req.ModelName},
)
}
capacity = resolvedCapacity
}

if !req.UseDefaultCapacity {
sku := selectedSku.sku
Expand Down Expand Up @@ -1229,8 +1248,6 @@ func buildSkuCandidatesForVersion(
continue
}

capacity := ai.ResolveCapacity(sku, options.Capacity)

var remaining *float64
if quota != nil {
if usageMap == nil {
Expand All @@ -1244,7 +1261,11 @@ func buildSkuCandidatesForVersion(

rem := usage.Limit - usage.CurrentValue
remaining = &rem
if rem < minReq || (capacity > 0 && float64(capacity) > rem) {
if rem < minReq {
continue
}

if _, ok := ai.ResolveCapacityWithQuota(sku, options.Capacity, rem); !ok {
continue
}
}
Expand Down
66 changes: 66 additions & 0 deletions cli/azd/internal/grpcserver/prompt_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package grpcserver
import (
"context"
"errors"
"slices"
"testing"

"github.com/azure/azure-dev/cli/azd/internal"
Expand Down Expand Up @@ -216,6 +217,39 @@ func Test_PromptService_PromptLocation(t *testing.T) {
mockPrompter.AssertExpectations(t)
}

func Test_PromptService_PromptLocation_WithAllowedLocations(t *testing.T) {
mockPrompter := &mockprompt.MockPromptService{}
globalOptions := &internal.GlobalCommandOptions{NoPrompt: false}

expectedLocation := &account.Location{
Name: "westus3",
DisplayName: "West US 3",
RegionalDisplayName: "(US) West US 3",
}

mockPrompter.
On("PromptLocation", mock.Anything, mock.Anything, mock.MatchedBy(func(opts *prompt.SelectOptions) bool {
return opts != nil && slices.Equal(opts.AllowedValues, []string{"westus3", "eastus2"})
})).
Return(expectedLocation, nil)

service := NewPromptService(mockPrompter, nil, nil, globalOptions)

resp, err := service.PromptLocation(context.Background(), &azdext.PromptLocationRequest{
AzureContext: &azdext.AzureContext{
Scope: &azdext.AzureScope{
SubscriptionId: "sub-123",
},
},
AllowedLocations: []string{"westus3", "eastus2"},
})

require.NoError(t, err)
require.NotNil(t, resp.Location)
require.Equal(t, expectedLocation.Name, resp.Location.Name)
mockPrompter.AssertExpectations(t)
}

func Test_PromptService_PromptResourceGroup(t *testing.T) {
mockPrompter := &mockprompt.MockPromptService{}
globalOptions := &internal.GlobalCommandOptions{NoPrompt: false}
Expand Down Expand Up @@ -826,6 +860,38 @@ func Test_buildSkuCandidatesForVersion(t *testing.T) {
require.NotNil(t, candidates[0].remaining)
require.Equal(t, float64(10), *candidates[0].remaining)
})

t.Run("falls back to lower capacity that fits remaining quota", func(t *testing.T) {
deepSeekVersion := ai.AiModelVersion{
Version: "1",
Skus: []ai.AiModelSku{
{
Name: "GlobalStandard",
UsageName: "AIServices.GlobalStandard.DeepSeek-R1-0528",
DefaultCapacity: 5000,
MinCapacity: 0,
MaxCapacity: 5000,
CapacityStep: 0,
},
},
}
quota := &azdext.QuotaCheckOptions{
MinRemainingCapacity: 1,
}
usageMap := map[string]ai.AiModelUsage{
"AIServices.GlobalStandard.DeepSeek-R1-0528": {
Name: "AIServices.GlobalStandard.DeepSeek-R1-0528",
CurrentValue: 0,
Limit: 1000,
},
}

candidates := buildSkuCandidatesForVersion(deepSeekVersion, nil, quota, usageMap, false)
require.Len(t, candidates, 1)
require.Equal(t, "AIServices.GlobalStandard.DeepSeek-R1-0528", candidates[0].sku.UsageName)
require.NotNil(t, candidates[0].remaining)
require.Equal(t, float64(1000), *candidates[0].remaining)
})
}

func Test_maxSkuCandidateRemaining(t *testing.T) {
Expand Down
130 changes: 115 additions & 15 deletions cli/azd/pkg/ai/model_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,15 @@ func (s *AiModelService) ListModels(
return s.convertToAiModels(rawModels), nil
}

// ListLocations returns subscription location names that can be used for model queries.
// ListLocations returns AI Services-supported location names that can be used for model queries.
func (s *AiModelService) ListLocations(
ctx context.Context,
subscriptionId string,
) ([]string, error) {
subLocations, err := s.subManager.GetLocations(ctx, subscriptionId)
locations, err := s.azureClient.GetResourceSkuLocations(
ctx, subscriptionId, "AIServices", "S0", "Standard", "accounts")
if err != nil {
return nil, fmt.Errorf("listing locations: %w", err)
}

locations := make([]string, 0, len(subLocations))
for _, loc := range subLocations {
locations = append(locations, loc.Name)
return nil, fmt.Errorf("listing AI Services locations: %w", err)
}

return locations, nil
Expand Down Expand Up @@ -456,9 +452,8 @@ func (s *AiModelService) resolveDeployments(
continue
}

capacity := ResolveCapacity(sku, options.Capacity)

// Quota check
capacity := ResolveCapacity(sku, options.Capacity)
if quotaOpts != nil && usageMap != nil {
usage, ok := usageMap[sku.UsageName]
if !ok {
Expand All @@ -470,9 +465,15 @@ func (s *AiModelService) resolveDeployments(
if minReq <= 0 {
minReq = 1
}
if remaining < minReq || (capacity > 0 && float64(capacity) > remaining) {
if remaining < minReq {
continue
}

resolvedCapacity, fitsQuota := ResolveCapacityWithQuota(sku, options.Capacity, remaining)
if !fitsQuota {
continue
}
capacity = resolvedCapacity
}

// Only set location when exactly one was provided — never guess.
Expand Down Expand Up @@ -731,16 +732,109 @@ func convertSku(sku *armcognitiveservices.ModelSKU) AiModelSku {
func ResolveCapacity(sku AiModelSku, preferred *int32) int32 {
if preferred != nil {
cap := *preferred
if cap > 0 &&
(sku.MinCapacity <= 0 || cap >= sku.MinCapacity) &&
(sku.MaxCapacity <= 0 || cap <= sku.MaxCapacity) &&
(sku.CapacityStep <= 0 || cap%sku.CapacityStep == 0) {
if capacityValidForSku(sku, cap) {
return cap
}
}
return sku.DefaultCapacity
}

// ResolveCapacityWithQuota resolves the deployment capacity for a SKU while considering remaining quota.
// If preferred is set, it must fit within the remaining quota or resolution fails.
// When preferred is unset and the default capacity exceeds remaining quota, it falls back to the highest
// valid capacity within the SKU constraints that still fits in the remaining quota.
func ResolveCapacityWithQuota(sku AiModelSku, preferred *int32, remaining float64) (int32, bool) {
capacity := ResolveCapacity(sku, preferred)
if preferred != nil {
return capacity, capacityFitsWithinQuota(sku, capacity, remaining)
}

if capacityFitsWithinQuota(sku, capacity, remaining) {
return capacity, true
}

return fallbackCapacityWithinQuota(sku, remaining)
}

func capacityValidForSku(sku AiModelSku, capacity int32) bool {
if capacity <= 0 {
return false
}

if sku.MinCapacity > 0 && capacity < sku.MinCapacity {
return false
}

if sku.MaxCapacity > 0 && capacity > sku.MaxCapacity {
return false
}

if sku.CapacityStep > 0 {
baseline := capacityStepBaseline(sku)
if capacity < baseline || (capacity-baseline)%sku.CapacityStep != 0 {
return false
}
}

return true
}

func capacityStepBaseline(sku AiModelSku) int32 {
if sku.MinCapacity > 0 {
return sku.MinCapacity
}

return sku.CapacityStep
}

func minimumValidCapacity(sku AiModelSku) int32 {
if sku.MinCapacity > 0 {
return sku.MinCapacity
}

if sku.CapacityStep > 0 {
return sku.CapacityStep
}

return 1
}

func capacityFitsWithinQuota(sku AiModelSku, capacity int32, remaining float64) bool {
if !capacityValidForSku(sku, capacity) {
return false
}

return float64(capacity) <= remaining
}

func fallbackCapacityWithinQuota(sku AiModelSku, remaining float64) (int32, bool) {
upperBound := int32(remaining)
if upperBound <= 0 {
return 0, false
}

if sku.MaxCapacity > 0 && upperBound > sku.MaxCapacity {
upperBound = sku.MaxCapacity
}

lowerBound := minimumValidCapacity(sku)

if upperBound < lowerBound {
return 0, false
}

if sku.CapacityStep <= 0 {
return upperBound, true
}

candidate := upperBound - (upperBound-lowerBound)%sku.CapacityStep
if candidate < lowerBound {
return 0, false
}

return candidate, true
}

// ModelHasDefaultVersion returns true if any version of the model is marked as default.
func ModelHasDefaultVersion(model AiModel) bool {
for _, v := range model.Versions {
Expand All @@ -758,6 +852,9 @@ func modelHasQuota(model AiModel, usageMap map[string]AiModelUsage, minRemaining
if ok {
remaining := usage.Limit - usage.CurrentValue
if remaining >= minRemaining {
if _, ok := ResolveCapacityWithQuota(sku, nil, remaining); !ok {
continue
}
return true
}
}
Expand All @@ -777,6 +874,9 @@ func maxModelRemainingQuota(model AiModel, usageMap map[string]AiModelUsage) (fl
}

remaining := usage.Limit - usage.CurrentValue
if _, ok := ResolveCapacityWithQuota(sku, nil, remaining); !ok {
continue
}
if !found || remaining > maxRemaining {
maxRemaining = remaining
}
Expand Down
Loading
Loading