Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion cli/azd/docs/extensions/extension-framework.md
Original file line number Diff line number Diff line change
Expand Up @@ -1979,11 +1979,15 @@ Returns available AI models for a subscription.
- `locations` (repeated string)
- `capabilities` (repeated string)
- `formats` (repeated string)
- `statuses` (repeated string)
- `statuses` (repeated string, applied to version lifecycle status before aggregation)
- `exclude_model_names` (repeated string)
- **Response:** _ListModelsResponse_
- `models` (repeated _AiModel_)

`filter.statuses` matches version-level lifecycle status before aggregation. Returned models
only contain versions (and locations) that matched. `AiModel.lifecycle_status` is deprecated
and always empty; use `AiModelVersion.lifecycle_status` for lifecycle state.

If `filter.locations` is empty, models are listed across all subscription locations.
When `filter.locations` is provided, it limits which models are returned, but each returned model still contains canonical
`locations`.
Expand Down
4 changes: 3 additions & 1 deletion cli/azd/extensions/microsoft.azd.demo/internal/cmd/ai.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ func printAiModelDetails(model *azdext.AiModel) {
color.HiWhite("Model Details")
fmt.Printf(" Name: %s\n", color.CyanString(model.Name))
fmt.Printf(" Format: %s\n", model.Format)
fmt.Printf(" Status: %s\n", model.LifecycleStatus)

if len(model.Capabilities) > 0 {
capabilities := slices.Clone(model.Capabilities)
Expand Down Expand Up @@ -138,6 +137,9 @@ func printAiModelDetails(model *azdext.AiModel) {
defaultLabel = color.YellowString(" (default)")
}
fmt.Printf(" - Version: %s%s\n", version.Version, defaultLabel)
if version.LifecycleStatus != "" {
fmt.Printf(" Status: %s\n", version.LifecycleStatus)
}

skus := slices.Clone(version.Skus)
slices.SortFunc(skus, func(a, b *azdext.AiModelSku) int {
Expand Down
8 changes: 5 additions & 3 deletions cli/azd/grpc/proto/ai_model.proto
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ service AiModelService {
message AiModel {
string name = 1; // e.g. "gpt-4o"
string format = 2; // e.g. "OpenAI"
string lifecycle_status = 3; // e.g. "preview", "stable"
string lifecycle_status = 3 [deprecated = true]; // deprecated; always empty; use AiModelVersion.lifecycle_status
repeated string capabilities = 4; // e.g. ["chat", "embeddings"]
repeated AiModelVersion versions = 5;
repeated string locations = 6; // canonical locations where available
Expand All @@ -50,6 +50,7 @@ message AiModelVersion {
string version = 1;
bool is_default = 2;
repeated AiModelSku skus = 3;
string lifecycle_status = 4; // e.g. "GenerallyAvailable", "Preview"
}

// AiModelSku represents a deployment SKU with capacity constraints.
Expand Down Expand Up @@ -110,8 +111,9 @@ message AiModelFilterOptions {
// Matches AiModel.format exactly (for example: "OpenAI", "Microsoft").
repeated string formats = 3;

// Include models whose lifecycle status matches one of these values.
// Matches AiModel.lifecycle_status exactly (for example: "Stable", "Preview").
// Include model versions whose lifecycle status matches one of these values.
// Filtering is applied before aggregation, so returned versions, derived
// AiModel.lifecycle_status, and locations reflect only matching versions.
repeated string statuses = 4;

// Exclude models by exact model name (for example: "gpt-4o-mini").
Expand Down
38 changes: 20 additions & 18 deletions cli/azd/pkg/ai/mapper_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ func registerAiModelMappings() {
versions[i] = proto
}

// LifecycleStatus intentionally omitted — deprecated, always empty.
return &azdext.AiModel{
Name: src.Name,
Format: src.Format,
LifecycleStatus: src.LifecycleStatus,
Capabilities: src.Capabilities,
Versions: versions,
Locations: src.Locations,
Name: src.Name,
Format: src.Format,
Capabilities: src.Capabilities,
Versions: versions,
Locations: src.Locations,
}, nil
})

Expand All @@ -44,13 +44,13 @@ func registerAiModelMappings() {
versions[i] = protoToAiModelVersion(v)
}

// LifecycleStatus intentionally omitted — deprecated, always empty.
return &AiModel{
Name: src.Name,
Format: src.Format,
LifecycleStatus: src.LifecycleStatus,
Capabilities: src.Capabilities,
Versions: versions,
Locations: src.Locations,
Name: src.Name,
Format: src.Format,
Capabilities: src.Capabilities,
Versions: versions,
Locations: src.Locations,
}, nil
})

Expand Down Expand Up @@ -120,9 +120,10 @@ func aiModelVersionToProto(src *AiModelVersion) (*azdext.AiModelVersion, error)
}

return &azdext.AiModelVersion{
Version: src.Version,
IsDefault: src.IsDefault,
Skus: skus,
Version: src.Version,
IsDefault: src.IsDefault,
Skus: skus,
LifecycleStatus: src.LifecycleStatus,
}, nil
}

Expand All @@ -133,9 +134,10 @@ func protoToAiModelVersion(src *azdext.AiModelVersion) AiModelVersion {
}

return AiModelVersion{
Version: src.Version,
IsDefault: src.IsDefault,
Skus: skus,
Version: src.Version,
IsDefault: src.IsDefault,
Skus: skus,
LifecycleStatus: src.LifecycleStatus,
}
}

Expand Down
145 changes: 126 additions & 19 deletions cli/azd/pkg/ai/model_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"slices"
"strings"
"sync"
"time"

"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/cognitiveservices/armcognitiveservices"
"github.com/azure/azure-dev/cli/azd/pkg/account"
Expand Down Expand Up @@ -82,14 +83,30 @@ func (s *AiModelService) ListFilteredModels(
subscriptionId string,
options *FilterOptions,
) ([]AiModel, error) {
if options == nil {
return s.ListModels(ctx, subscriptionId, nil)
}

filteredOptions := *options

// Fetch canonical models and apply filters in-memory so model metadata
// (especially Locations) remains complete.
models, err := s.ListModels(ctx, subscriptionId, nil)
// remains complete for non-status filters. Status filtering is applied during
// aggregation so Versions, derived LifecycleStatus, and Locations reflect only
// versions matching the requested statuses.
locations, err := s.ListLocations(ctx, subscriptionId)
if err != nil {
return nil, err
}

rawModels, err := s.fetchModelsForLocations(ctx, subscriptionId, locations)
if err != nil {
return nil, err
}

return FilterModels(models, options), nil
models := s.convertToAiModelsAt(rawModels, time.Now().UTC(), filteredOptions.Statuses)
filteredOptions.Statuses = nil

return FilterModels(models, &filteredOptions), nil
}

// ListModelVersions returns available versions for a specific model at a location.
Expand Down Expand Up @@ -569,13 +586,27 @@ func (s *AiModelService) fetchModelsForLocations(
// convertToAiModels converts raw ARM models grouped by location into domain AiModel types.
func (s *AiModelService) convertToAiModels(
rawByLocation map[string][]*armcognitiveservices.Model,
) []AiModel {
return s.convertToAiModelsAt(rawByLocation, time.Now().UTC(), nil)
}

// convertToAiModelsAt converts raw ARM models grouped by location into domain AiModel types,
// optionally filtering by version lifecycle status before aggregation. The now parameter
// makes deprecation filtering deterministic in tests.
func (s *AiModelService) convertToAiModelsAt(
rawByLocation map[string][]*armcognitiveservices.Model,
now time.Time,
statuses []string,
) []AiModel {
// Aggregate: model name → location → version → SKUs
modelMap := make(map[string]*AiModel)

for loc, models := range rawByLocation {
for _, m := range models {
if m.Model == nil || m.Model.Name == nil {
if m.Model == nil || m.Model.Name == nil || modelVersionDeprecated(m.Model, now) {
continue
}
if len(statuses) > 0 && !slices.Contains(statuses, modelLifecycleStatusValue(m.Model.LifecycleStatus)) {
continue
}
name := *m.Model.Name
Expand All @@ -586,9 +617,6 @@ func (s *AiModelService) convertToAiModels(
Name: name,
Format: safeString(m.Model.Format),
}
if m.Model.LifecycleStatus != nil {
aiModel.LifecycleStatus = string(*m.Model.LifecycleStatus)
}
if m.Model.Capabilities != nil {
for key := range m.Model.Capabilities {
aiModel.Capabilities = append(aiModel.Capabilities, key)
Expand All @@ -598,21 +626,29 @@ func (s *AiModelService) convertToAiModels(
modelMap[name] = aiModel
}

// Track locations
if !slices.Contains(aiModel.Locations, loc) {
aiModel.Locations = append(aiModel.Locations, loc)
}

// Build version entry
ver := safeString(m.Model.Version)
isDefault := m.Model.IsDefaultVersion != nil && *m.Model.IsDefaultVersion
lifecycleStatus := modelLifecycleStatusValue(m.Model.LifecycleStatus)

hadSkus := len(m.Model.SKUs) > 0
var skus []AiModelSku
if m.Model.SKUs != nil {
for _, sku := range m.Model.SKUs {
if modelSkuDeprecated(sku, now) {
continue
}
skus = append(skus, convertSku(sku))
}
}
if hadSkus && len(skus) == 0 {
continue
}

// Track locations only when this location contributes a surviving version/SKU.
if !slices.Contains(aiModel.Locations, loc) {
aiModel.Locations = append(aiModel.Locations, loc)
}

// Find or create version in model
versionFound := false
Expand All @@ -622,6 +658,9 @@ func (s *AiModelService) convertToAiModels(
if isDefault {
aiModel.Versions[i].IsDefault = true
}
if aiModel.Versions[i].LifecycleStatus == "" {
aiModel.Versions[i].LifecycleStatus = lifecycleStatus
}
// Merge SKUs (deduplicate by name + usage_name, since the same SKU name
// can appear with different usage names representing different quota pools)
for _, newSku := range skus {
Expand All @@ -636,9 +675,10 @@ func (s *AiModelService) convertToAiModels(
}
if !versionFound {
aiModel.Versions = append(aiModel.Versions, AiModelVersion{
Version: ver,
IsDefault: isDefault,
Skus: skus,
Version: ver,
IsDefault: isDefault,
LifecycleStatus: lifecycleStatus,
Skus: skus,
})
}
}
Expand All @@ -647,6 +687,9 @@ func (s *AiModelService) convertToAiModels(
// Convert map to sorted slice
result := make([]AiModel, 0, len(modelMap))
for _, model := range modelMap {
if len(model.Versions) == 0 {
continue
}
slices.Sort(model.Locations)
result = append(result, *model)
}
Expand All @@ -657,23 +700,87 @@ func (s *AiModelService) convertToAiModels(
return result
}

// FilterModels applies FilterOptions to a list of models.
func modelVersionDeprecated(model *armcognitiveservices.AccountModel, now time.Time) bool {
if model == nil {
return false
}

if modelLifecycleDeprecated(model.LifecycleStatus) {
return true
}

return modelDeprecationReached(model.Deprecation, now)
}

func modelLifecycleDeprecated(status *armcognitiveservices.ModelLifecycleStatus) bool {
if status == nil {
return false
}

return strings.EqualFold(string(*status), "Deprecated")
}

func modelLifecycleStatusValue(status *armcognitiveservices.ModelLifecycleStatus) string {
if status == nil {
return ""
}

return string(*status)
}

func modelDeprecationReached(info *armcognitiveservices.ModelDeprecationInfo, now time.Time) bool {
if info == nil || info.Inference == nil {
return false
}

return deprecationReached(*info.Inference, now)
}

func modelSkuDeprecated(sku *armcognitiveservices.ModelSKU, now time.Time) bool {
if sku == nil || sku.DeprecationDate == nil {
return false
}

return !sku.DeprecationDate.After(now)
}

func deprecationReached(value string, now time.Time) bool {
if strings.TrimSpace(value) == "" {
return false
}

deprecatedAt, err := time.Parse(time.RFC3339, value)
if err != nil {
return false
}

return !deprecatedAt.After(now)
}

// FilterModels applies FilterOptions to already-aggregated models. When Statuses is set,
// versions are pruned, but Locations cannot be recomputed (version-to-location provenance
// is lost). Use ListFilteredModels for full fidelity.
func FilterModels(models []AiModel, options *FilterOptions) []AiModel {
if options == nil {
return models
}

var filtered []AiModel
for _, model := range models {
if len(options.Statuses) > 0 {
model.Versions = slices.DeleteFunc(slices.Clone(model.Versions), func(version AiModelVersion) bool {
return !slices.Contains(options.Statuses, version.LifecycleStatus)
})
if len(model.Versions) == 0 {
continue
}
}
if len(options.ExcludeModelNames) > 0 && slices.Contains(options.ExcludeModelNames, model.Name) {
continue
}
if len(options.Formats) > 0 && !slices.Contains(options.Formats, model.Format) {
continue
}
if len(options.Statuses) > 0 && !slices.Contains(options.Statuses, model.LifecycleStatus) {
continue
}
if len(options.Capabilities) > 0 {
hasCapability := false
for _, cap := range options.Capabilities {
Expand Down
Loading
Loading