Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cli/azd/grpc/proto/prompt.proto
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ message PromptSubscriptionResponse {

// PromptLocationRequest is the request message for the location prompt.
message PromptLocationRequest {
// Azure context that accompanies the request.
AzureContext azure_context = 1;
// Optional list of location names. When non-empty, the prompt service passes
// it to the location selector as its allowed values; when empty, no selector
// options are supplied.
repeated string allowed_locations = 2;
}

message PromptLocationResponse {
Expand Down
29 changes: 25 additions & 4 deletions cli/azd/internal/grpcserver/prompt_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,14 @@ func (s *promptService) PromptLocation(
return nil, err
}

selectedLocation, err := s.prompter.PromptLocation(ctx, azureContext, nil)
var selectorOptions *prompt.SelectOptions
if len(req.AllowedLocations) > 0 {
selectorOptions = &prompt.SelectOptions{
AllowedValues: req.AllowedLocations,
}
}

selectedLocation, err := s.prompter.PromptLocation(ctx, azureContext, selectorOptions)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -843,6 +850,18 @@ func (s *promptService) PromptAiDeployment(

// --- Step 3: Resolve capacity, optionally prompting ---
capacity := ai.ResolveCapacity(selectedSku.sku, options.Capacity)
if req.Quota != nil && selectedSku.remaining != nil {
resolvedCapacity, ok := ai.ResolveCapacityWithQuota(selectedSku.sku, options.Capacity, *selectedSku.remaining)
if !ok {
return nil, aiStatusError(
codes.FailedPrecondition,
azdext.AiErrorReasonNoDeploymentMatch,
fmt.Sprintf("no deployment match for model %q with the selected SKU and quota", req.ModelName),
map[string]string{"model_name": req.ModelName},
)
}
capacity = resolvedCapacity
}

if !req.UseDefaultCapacity {
sku := selectedSku.sku
Expand Down Expand Up @@ -1229,8 +1248,6 @@ func buildSkuCandidatesForVersion(
continue
}

capacity := ai.ResolveCapacity(sku, options.Capacity)

var remaining *float64
if quota != nil {
if usageMap == nil {
Expand All @@ -1244,7 +1261,11 @@ func buildSkuCandidatesForVersion(

rem := usage.Limit - usage.CurrentValue
remaining = &rem
if rem < minReq || (capacity > 0 && float64(capacity) > rem) {
if rem < minReq {
continue
}

if _, ok := ai.ResolveCapacityWithQuota(sku, options.Capacity, rem); !ok {
continue
}
}
Expand Down
66 changes: 66 additions & 0 deletions cli/azd/internal/grpcserver/prompt_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package grpcserver
import (
"context"
"errors"
"slices"
"testing"

"github.com/azure/azure-dev/cli/azd/internal"
Expand Down Expand Up @@ -216,6 +217,39 @@ func Test_PromptService_PromptLocation(t *testing.T) {
mockPrompter.AssertExpectations(t)
}

func Test_PromptService_PromptLocation_WithAllowedLocations(t *testing.T) {
mockPrompter := &mockprompt.MockPromptService{}
globalOptions := &internal.GlobalCommandOptions{NoPrompt: false}

expectedLocation := &account.Location{
Name: "westus3",
DisplayName: "West US 3",
RegionalDisplayName: "(US) West US 3",
}

mockPrompter.
On("PromptLocation", mock.Anything, mock.Anything, mock.MatchedBy(func(opts *prompt.SelectOptions) bool {
return opts != nil && slices.Equal(opts.AllowedValues, []string{"westus3", "eastus2"})
})).
Return(expectedLocation, nil)

service := NewPromptService(mockPrompter, nil, nil, globalOptions)

resp, err := service.PromptLocation(context.Background(), &azdext.PromptLocationRequest{
AzureContext: &azdext.AzureContext{
Scope: &azdext.AzureScope{
SubscriptionId: "sub-123",
},
},
AllowedLocations: []string{"westus3", "eastus2"},
})

require.NoError(t, err)
require.NotNil(t, resp.Location)
require.Equal(t, expectedLocation.Name, resp.Location.Name)
mockPrompter.AssertExpectations(t)
}

func Test_PromptService_PromptResourceGroup(t *testing.T) {
mockPrompter := &mockprompt.MockPromptService{}
globalOptions := &internal.GlobalCommandOptions{NoPrompt: false}
Expand Down Expand Up @@ -826,6 +860,38 @@ func Test_buildSkuCandidatesForVersion(t *testing.T) {
require.NotNil(t, candidates[0].remaining)
require.Equal(t, float64(10), *candidates[0].remaining)
})

t.Run("falls back to lower capacity that fits remaining quota", func(t *testing.T) {
deepSeekVersion := ai.AiModelVersion{
Version: "1",
Skus: []ai.AiModelSku{
{
Name: "GlobalStandard",
UsageName: "AIServices.GlobalStandard.DeepSeek-R1-0528",
DefaultCapacity: 5000,
MinCapacity: 0,
MaxCapacity: 5000,
CapacityStep: 0,
},
},
}
quota := &azdext.QuotaCheckOptions{
MinRemainingCapacity: 1,
}
usageMap := map[string]ai.AiModelUsage{
"AIServices.GlobalStandard.DeepSeek-R1-0528": {
Name: "AIServices.GlobalStandard.DeepSeek-R1-0528",
CurrentValue: 0,
Limit: 1000,
},
}

candidates := buildSkuCandidatesForVersion(deepSeekVersion, nil, quota, usageMap, false)
require.Len(t, candidates, 1)
require.Equal(t, "AIServices.GlobalStandard.DeepSeek-R1-0528", candidates[0].sku.UsageName)
require.NotNil(t, candidates[0].remaining)
require.Equal(t, float64(1000), *candidates[0].remaining)
})
}

func Test_maxSkuCandidateRemaining(t *testing.T) {
Expand Down
107 changes: 96 additions & 11 deletions cli/azd/pkg/ai/model_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,15 @@ func (s *AiModelService) ListModels(
return s.convertToAiModels(rawModels), nil
}

// ListLocations returns subscription location names that can be used for model queries.
// ListLocations returns AI Services-supported location names that can be used for model queries.
func (s *AiModelService) ListLocations(
ctx context.Context,
subscriptionId string,
) ([]string, error) {
subLocations, err := s.subManager.GetLocations(ctx, subscriptionId)
locations, err := s.azureClient.GetResourceSkuLocations(
ctx, subscriptionId, "AIServices", "S0", "Standard", "accounts")
if err != nil {
return nil, fmt.Errorf("listing locations: %w", err)
}

locations := make([]string, 0, len(subLocations))
for _, loc := range subLocations {
locations = append(locations, loc.Name)
return nil, fmt.Errorf("listing AI Services locations: %w", err)
}

return locations, nil
Expand Down Expand Up @@ -456,9 +452,8 @@ func (s *AiModelService) resolveDeployments(
continue
}

capacity := ResolveCapacity(sku, options.Capacity)

// Quota check
capacity := ResolveCapacity(sku, options.Capacity)
if quotaOpts != nil && usageMap != nil {
usage, ok := usageMap[sku.UsageName]
if !ok {
Expand All @@ -470,9 +465,15 @@ func (s *AiModelService) resolveDeployments(
if minReq <= 0 {
minReq = 1
}
if remaining < minReq || (capacity > 0 && float64(capacity) > remaining) {
if remaining < minReq {
continue
}

resolvedCapacity, fitsQuota := ResolveCapacityWithQuota(sku, options.Capacity, remaining)
if !fitsQuota {
continue
}
capacity = resolvedCapacity
}

// Only set location when exactly one was provided — never guess.
Expand Down Expand Up @@ -741,6 +742,84 @@ func ResolveCapacity(sku AiModelSku, preferred *int32) int32 {
return sku.DefaultCapacity
}

// ResolveCapacityWithQuota picks a deployment capacity for sku that fits in the
// remaining quota and reports whether such a capacity exists.
//
// The starting point is the capacity returned by ResolveCapacity. An explicit
// preferred capacity is never adjusted: it is returned as resolved, and the
// boolean reports whether it satisfies the SKU constraints and the quota.
// Without a preferred capacity, the resolved value is used when it fits;
// otherwise the result is the fallback chosen by fallbackCapacityWithinQuota.
func ResolveCapacityWithQuota(sku AiModelSku, preferred *int32, remaining float64) (int32, bool) {
	resolved := ResolveCapacity(sku, preferred)
	fits := capacityFitsWithinQuota(sku, resolved, remaining)

	switch {
	case preferred != nil:
		// Caller asked for a specific capacity; do not substitute another one.
		return resolved, fits
	case fits:
		return resolved, true
	default:
		return fallbackCapacityWithinQuota(sku, remaining)
	}
}

func capacityFitsWithinQuota(sku AiModelSku, capacity int32, remaining float64) bool {
if capacity <= 0 {
return false
}

if sku.MinCapacity > 0 && capacity < sku.MinCapacity {
return false
}

if sku.MaxCapacity > 0 && capacity > sku.MaxCapacity {
return false
}

if sku.CapacityStep > 0 {
baseline := sku.MinCapacity
if baseline <= 0 {
baseline = sku.CapacityStep
}
if capacity < baseline || (capacity-baseline)%sku.CapacityStep != 0 {
Comment thread
JeffreyCA marked this conversation as resolved.
return false
}
}

return float64(capacity) <= remaining
}

// fallbackCapacityWithinQuota returns the largest capacity allowed by the SKU
// constraints that still fits in the remaining quota, and reports whether such
// a capacity exists.
//
// The search range is [lowerBound, upperBound]:
//   - upperBound is remaining truncated toward zero, capped at sku.MaxCapacity
//     when MaxCapacity > 0;
//   - lowerBound is sku.MinCapacity when positive, otherwise sku.CapacityStep
//     when positive, otherwise 1.
//
// When sku.CapacityStep > 0 the result is snapped down to the step grid that
// starts at lowerBound. It returns (0, false) when remaining is below 1, is
// NaN, or when the range is empty.
func fallbackCapacityWithinQuota(sku AiModelSku, remaining float64) (int32, bool) {
	// Largest value representable by the int32 capacity type.
	const maxCapacityValue = 1<<31 - 1

	// Written as a negated comparison so that NaN is rejected too.
	if !(remaining >= 1) {
		return 0, false
	}

	// Converting an out-of-range float64 to int32 is implementation-defined in
	// Go, so clamp first: a very large quota must not turn into a garbage or
	// negative upper bound.
	upperBound := int32(maxCapacityValue)
	if remaining < maxCapacityValue {
		upperBound = int32(remaining)
	}

	if sku.MaxCapacity > 0 && upperBound > sku.MaxCapacity {
		upperBound = sku.MaxCapacity
	}

	lowerBound := sku.MinCapacity
	if lowerBound <= 0 {
		if sku.CapacityStep > 0 {
			lowerBound = sku.CapacityStep
		} else {
			lowerBound = 1
		}
	}

	if upperBound < lowerBound {
		return 0, false
	}

	if sku.CapacityStep <= 0 {
		return upperBound, true
	}

	// Snap down to the nearest value of the form lowerBound + k*CapacityStep.
	candidate := upperBound - (upperBound-lowerBound)%sku.CapacityStep
	if candidate < lowerBound {
		return 0, false
	}

	return candidate, true
}

// ModelHasDefaultVersion returns true if any version of the model is marked as default.
func ModelHasDefaultVersion(model AiModel) bool {
for _, v := range model.Versions {
Expand All @@ -758,6 +837,9 @@ func modelHasQuota(model AiModel, usageMap map[string]AiModelUsage, minRemaining
if ok {
remaining := usage.Limit - usage.CurrentValue
if remaining >= minRemaining {
if _, ok := ResolveCapacityWithQuota(sku, nil, remaining); !ok {
continue
}
return true
}
}
Expand All @@ -777,6 +859,9 @@ func maxModelRemainingQuota(model AiModel, usageMap map[string]AiModelUsage) (fl
}

remaining := usage.Limit - usage.CurrentValue
if _, ok := ResolveCapacityWithQuota(sku, nil, remaining); !ok {
continue
}
if !found || remaining > maxRemaining {
maxRemaining = remaining
}
Expand Down
62 changes: 62 additions & 0 deletions cli/azd/pkg/ai/model_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,68 @@ func TestResolveCapacity(t *testing.T) {
}
}

func TestResolveCapacityWithQuota(t *testing.T) {
t.Run("uses default when it fits in remaining quota", func(t *testing.T) {
capacity, ok := ResolveCapacityWithQuota(AiModelSku{
DefaultCapacity: 25,
MinCapacity: 1,
MaxCapacity: 100,
CapacityStep: 1,
}, nil, 50)

require.True(t, ok)
require.Equal(t, int32(25), capacity)
})

t.Run("falls back below default when no preferred capacity is set", func(t *testing.T) {
capacity, ok := ResolveCapacityWithQuota(AiModelSku{
DefaultCapacity: 5000,
MinCapacity: 0,
MaxCapacity: 5000,
CapacityStep: 0,
}, nil, 1000)

require.True(t, ok)
require.Equal(t, int32(1000), capacity)
})

t.Run("respects min and step when falling back", func(t *testing.T) {
capacity, ok := ResolveCapacityWithQuota(AiModelSku{
DefaultCapacity: 3000,
MinCapacity: 100,
MaxCapacity: 3000,
CapacityStep: 100,
}, nil, 950)

require.True(t, ok)
require.Equal(t, int32(900), capacity)
})

t.Run("fails when explicit preferred capacity does not fit", func(t *testing.T) {
capacity, ok := ResolveCapacityWithQuota(AiModelSku{
DefaultCapacity: 5000,
MinCapacity: 0,
MaxCapacity: 5000,
CapacityStep: 0,
}, new(int32(5000)), 1000)

require.False(t, ok)
require.Equal(t, int32(5000), capacity)
})

t.Run("fails when remaining quota is below effective minimum", func(t *testing.T) {
capacity, ok := ResolveCapacityWithQuota(AiModelSku{
DefaultCapacity: 0,
MinCapacity: 100,
MaxCapacity: 3000,
CapacityStep: 100,
}, nil, 50)

require.False(t, ok)
require.Equal(t, int32(0), capacity)
})
}

func TestMaxModelRemainingQuota(t *testing.T) {
model := AiModel{
Name: "gpt-4o",
Expand Down
Loading
Loading