diff --git a/cli/azd/internal/cmd/add/add_coverage_test.go b/cli/azd/internal/cmd/add/add_coverage_test.go new file mode 100644 index 00000000000..1d0a4353a7c --- /dev/null +++ b/cli/azd/internal/cmd/add/add_coverage_test.go @@ -0,0 +1,833 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package add + +import ( + "bytes" + "strings" + "testing" + + "github.com/azure/azure-dev/cli/azd/internal/appdetect" + "github.com/azure/azure-dev/cli/azd/pkg/project" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// add.go — NewAddCmd +// --------------------------------------------------------------------------- + +func TestNewAddCmd_ReturnsCommand(t *testing.T) { + t.Parallel() + cmd := NewAddCmd() + require.NotNil(t, cmd) + assert.Equal(t, "add", cmd.Use) + assert.NotEmpty(t, cmd.Short) +} + +// --------------------------------------------------------------------------- +// add.go — selectMenu +// --------------------------------------------------------------------------- + +func TestSelectMenu_AllNamespacesPresent(t *testing.T) { + t.Parallel() + a := &AddAction{} + menu := a.selectMenu() + require.NotEmpty(t, menu) + + namespaces := make(map[string]bool, len(menu)) + for _, m := range menu { + namespaces[m.Namespace] = true + assert.NotEmpty(t, m.Label, "menu item %q should have a label", m.Namespace) + assert.NotNil(t, m.SelectResource, "menu item %q should have a SelectResource func", m.Namespace) + } + + for _, ns := range []string{"db", "host", "ai", "messaging", "storage", "keyvault", "existing"} { + assert.True(t, namespaces[ns], "expected namespace %q in menu", ns) + } +} + +// --------------------------------------------------------------------------- +// add_configure.go — Configure: default (unknown) type +// --------------------------------------------------------------------------- + +func TestConfigure_DefaultType_ReturnsUnchanged(t *testing.T) { + t.Parallel() + r := &project.ResourceConfig{ + Type: project.ResourceType("unknown.something"), + Name: "thing", + } + opts := PromptOptions{ + PrjConfig: &project.ProjectConfig{ + Resources: map[string]*project.ResourceConfig{}, + }, + } + got, err := Configure(t.Context(), r, nil, opts) + require.NoError(t, err) + assert.Equal(t, r, got) +} + +// --------------------------------------------------------------------------- +// add_configure.go — Configure: Existing with name preset +// --------------------------------------------------------------------------- + +func TestConfigure_ExistingWithNamePreset(t *testing.T) { + t.Parallel() + r := &project.ResourceConfig{ + Type: project.ResourceTypeDbPostgres, + Name: "existing-db", + Existing: true, + } + opts := PromptOptions{ + PrjConfig: &project.ProjectConfig{ + Resources: map[string]*project.ResourceConfig{}, + }, + } + got, err := Configure(t.Context(), r, nil, opts) + require.NoError(t, err) + assert.Equal(t, "existing-db", got.Name) + assert.True(t, got.Existing) +} + +// --------------------------------------------------------------------------- +// add_configure.go — Configure: DB types with name preset (short-circuit) +// --------------------------------------------------------------------------- + +func TestConfigure_DbTypesWithNamePreset(t *testing.T) { + t.Parallel() + tests := []struct { + name string + resType project.ResourceType + }{ + {"postgres", project.ResourceTypeDbPostgres}, + {"mysql", project.ResourceTypeDbMySql}, + {"mongo", project.ResourceTypeDbMongo}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + r := &project.ResourceConfig{ + Type: tt.resType, + Name: "my-db", + } + opts := PromptOptions{ + PrjConfig: &project.ProjectConfig{ + Resources: map[string]*project.ResourceConfig{}, + }, + } + got, err := Configure(t.Context(), r, nil, opts) + require.NoError(t, err) + assert.Equal(t, "my-db", got.Name) + }) + } +} + +// --------------------------------------------------------------------------- +// add_configure.go — Configure: CosmosDB sets CosmosDBProps +// --------------------------------------------------------------------------- + +func TestConfigure_CosmosDbWithNamePreset(t *testing.T) { + t.Parallel() + r := &project.ResourceConfig{ + Type: project.ResourceTypeDbCosmos, + Name: "my-cosmos", + } + opts := PromptOptions{ + PrjConfig: &project.ProjectConfig{ + Resources: map[string]*project.ResourceConfig{}, + }, + } + got, err := Configure(t.Context(), r, nil, opts) + require.NoError(t, err) + assert.Equal(t, "my-cosmos", got.Name) + _, ok := got.Props.(project.CosmosDBProps) + assert.True(t, ok, "expected CosmosDBProps to be set") +} + +// --------------------------------------------------------------------------- +// add_configure.go — Configure: OpenAI with name preset +// --------------------------------------------------------------------------- + +func TestConfigure_OpenAiWithNamePreset(t *testing.T) { + t.Parallel() + r := &project.ResourceConfig{ + Type: project.ResourceTypeOpenAiModel, + Name: "my-model", + } + opts := PromptOptions{ + PrjConfig: &project.ProjectConfig{ + Resources: map[string]*project.ResourceConfig{}, + }, + } + got, err := Configure(t.Context(), r, nil, opts) + require.NoError(t, err) + assert.Equal(t, "my-model", got.Name) +} + +// --------------------------------------------------------------------------- +// add_configure.go — Configure: host types with empty resources +// (fillUses short-circuits when no resources to link) +// --------------------------------------------------------------------------- + +func TestConfigure_HostTypes_EmptyResources(t *testing.T) { + t.Parallel() + tests := []struct { + name string + resType project.ResourceType + }{ + {"container app", project.ResourceTypeHostContainerApp}, + {"app service", project.ResourceTypeHostAppService}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + r := &project.ResourceConfig{ + Type: tt.resType, + Name: "my-host", + } + opts := PromptOptions{ + PrjConfig: &project.ProjectConfig{ + Resources: map[string]*project.ResourceConfig{}, + }, + } + got, err := Configure(t.Context(), r, nil, opts) + require.NoError(t, err) + assert.Equal(t, "my-host", got.Name) + }) + } +} + +// --------------------------------------------------------------------------- +// add_configure.go / add_configure_messaging.go — duplicate messaging errors +// --------------------------------------------------------------------------- + +func TestConfigure_MessagingDuplicates(t *testing.T) { + t.Parallel() + tests := []struct { + name string + resType project.ResourceType + existingKey string + wantError string + }{ + { + name: "event hubs duplicate", + resType: project.ResourceTypeMessagingEventHubs, + existingKey: "event-hubs", + wantError: "only one event hubs", + }, + { + name: "service bus duplicate", + resType: project.ResourceTypeMessagingServiceBus, + existingKey: "service-bus", + wantError: "only one service bus", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + r := &project.ResourceConfig{Type: tt.resType} + opts := PromptOptions{ + PrjConfig: &project.ProjectConfig{ + Resources: map[string]*project.ResourceConfig{ + tt.existingKey: {}, + }, + }, + } + _, err := Configure(t.Context(), r, nil, opts) + require.Error(t, err) + assert.Contains(t, err.Error(), tt.wantError) + }) + } +} + +// --------------------------------------------------------------------------- +// add_configure.go / add_configure_storage.go — storage duplicate & invalid props +// --------------------------------------------------------------------------- + +func TestConfigure_StorageDuplicate(t *testing.T) { + t.Parallel() + r := &project.ResourceConfig{ + Type: project.ResourceTypeStorage, + Props: project.StorageProps{}, + } + opts := PromptOptions{ + PrjConfig: &project.ProjectConfig{ + Resources: map[string]*project.ResourceConfig{ + "storage": {}, + }, + }, + } + _, err := Configure(t.Context(), r, nil, opts) + require.Error(t, err) + assert.Contains(t, err.Error(), "only one Storage") +} + +func TestConfigure_StorageInvalidProps(t *testing.T) { + t.Parallel() + r := &project.ResourceConfig{ + Type: project.ResourceTypeStorage, + Props: nil, // not StorageProps + } + opts := PromptOptions{ + PrjConfig: &project.ProjectConfig{ + Resources: map[string]*project.ResourceConfig{}, + }, + } + _, err := Configure(t.Context(), r, nil, opts) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid resource properties") +} + +// --------------------------------------------------------------------------- +// add_configure.go — ConfigureLive +// --------------------------------------------------------------------------- + +func TestConfigureLive_ExistingResource(t *testing.T) { + t.Parallel() + a := &AddAction{} + r := &project.ResourceConfig{ + Type: project.ResourceTypeOpenAiModel, + Name: "existing-model", + Existing: true, + } + opts := PromptOptions{ + PrjConfig: &project.ProjectConfig{ + Resources: map[string]*project.ResourceConfig{}, + }, + } + got, err := a.ConfigureLive(t.Context(), r, nil, opts) + require.NoError(t, err) + assert.Equal(t, r, got) +} + +func TestConfigureLive_NonAiType_ReturnsUnchanged(t *testing.T) { + t.Parallel() + a := &AddAction{} + r := &project.ResourceConfig{ + Type: project.ResourceTypeDbPostgres, + Name: "my-db", + } + opts := PromptOptions{ + PrjConfig: &project.ProjectConfig{ + Resources: map[string]*project.ResourceConfig{}, + }, + } + got, err := a.ConfigureLive(t.Context(), r, nil, opts) + require.NoError(t, err) + assert.Equal(t, r, got) +} + +// --------------------------------------------------------------------------- +// add_configure_existing.go — ConfigureExisting with name preset +// --------------------------------------------------------------------------- + +func TestConfigureExisting_WithNamePreset(t *testing.T) { + t.Parallel() + r := &project.ResourceConfig{ + Type: project.ResourceTypeDbRedis, + Name: "my-redis", + ResourceId: "/subscriptions/sub1/resourceGroups/rg/providers/Microsoft.Cache/redis/my-redis", + } + opts := PromptOptions{ + PrjConfig: &project.ProjectConfig{ + Resources: map[string]*project.ResourceConfig{}, + }, + } + got, err := ConfigureExisting(t.Context(), r, nil, opts) + require.NoError(t, err) + assert.Equal(t, "my-redis", got.Name) + assert.Equal(t, r.ResourceId, got.ResourceId) +} + +// --------------------------------------------------------------------------- +// add_configure_host.go — PromptPort (pure language-based paths) +// --------------------------------------------------------------------------- + +func TestPromptPort_NoDocker(t *testing.T) { + t.Parallel() + tests := []struct { + name string + lang appdetect.Language + wantPort int + }{ + {"python returns 80", appdetect.Python, 80}, + {"java returns 8080", appdetect.Java, 8080}, + {"dotnet returns 8080", appdetect.DotNet, 8080}, + {"javascript returns 80", appdetect.JavaScript, 80}, + {"typescript returns 80", appdetect.TypeScript, 80}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + prj := appdetect.Project{ + Language: tt.lang, + Docker: nil, + } + port, err := PromptPort(nil, t.Context(), "svc", prj) + require.NoError(t, err) + assert.Equal(t, tt.wantPort, port) + }) + } +} + +func TestPromptPort_DockerEmptyPath(t *testing.T) { + t.Parallel() + prj := appdetect.Project{ + Language: appdetect.Python, + Docker: &appdetect.Docker{Path: ""}, + } + port, err := PromptPort(nil, t.Context(), "svc", prj) + require.NoError(t, err) + assert.Equal(t, 80, port) +} + +func TestPromptPort_SingleDockerPort(t *testing.T) { + t.Parallel() + prj := appdetect.Project{ + Language: appdetect.Python, + Docker: &appdetect.Docker{ + Path: "/some/Dockerfile", + Ports: []appdetect.Port{{Number: 3000}}, + }, + } + port, err := PromptPort(nil, t.Context(), "svc", prj) + require.NoError(t, err) + assert.Equal(t, 3000, port) +} + +// --------------------------------------------------------------------------- +// add_configure_host.go — addServiceAsResource +// --------------------------------------------------------------------------- + +func TestAddServiceAsResource_UnsupportedTarget(t *testing.T) { + t.Parallel() + svc := &project.ServiceConfig{ + Name: "svc", + Host: project.ServiceTargetKind("unsupported"), + } + _, err := addServiceAsResource(t.Context(), nil, svc, appdetect.Project{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported service target") +} + +func TestAddServiceAsResource_ContainerApp_NoDockerfile(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() // no Dockerfile inside + svc := &project.ServiceConfig{ + Name: "test-svc", + Host: project.ContainerAppTarget, + Language: project.ServiceLanguagePython, + RelativePath: tempDir, + } + prj := appdetect.Project{Language: appdetect.Python} + r, err := addServiceAsResource(t.Context(), nil, svc, prj) + require.NoError(t, err) + assert.Equal(t, "test-svc", r.Name) + assert.Equal(t, project.ResourceTypeHostContainerApp, r.Type) + props, ok := r.Props.(project.ContainerAppProps) + require.True(t, ok) + assert.Equal(t, 80, props.Port) +} + +func TestAddServiceAsResource_ContainerApp_JavaPort(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + svc := &project.ServiceConfig{ + Name: "java-svc", + Host: project.ContainerAppTarget, + Language: project.ServiceLanguageJava, + RelativePath: tempDir, + } + prj := appdetect.Project{Language: appdetect.Java} + r, err := addServiceAsResource(t.Context(), nil, svc, prj) + require.NoError(t, err) + props, ok := r.Props.(project.ContainerAppProps) + require.True(t, ok) + assert.Equal(t, 8080, props.Port) +} + +func TestAddServiceAsResource_ContainerApp_DotNetPort(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + svc := &project.ServiceConfig{ + Name: "dotnet-svc", + Host: project.ContainerAppTarget, + Language: project.ServiceLanguageDotNet, + RelativePath: tempDir, + } + prj := appdetect.Project{Language: appdetect.DotNet} + r, err := addServiceAsResource(t.Context(), nil, svc, prj) + require.NoError(t, err) + props, ok := r.Props.(project.ContainerAppProps) + require.True(t, ok) + assert.Equal(t, 8080, props.Port) +} + +func TestAddServiceAsResource_AppService_UnsupportedLanguage(t *testing.T) { + t.Parallel() + tempDir := t.TempDir() + svc := &project.ServiceConfig{ + Name: "java-svc", + Host: project.AppServiceTarget, + Language: project.ServiceLanguageJava, + RelativePath: tempDir, + } + prj := appdetect.Project{Language: appdetect.Java} + _, err := addServiceAsResource(t.Context(), nil, svc, prj) + require.Error(t, err) + assert.Contains(t, err.Error(), "unsupported language") +} + +// --------------------------------------------------------------------------- +// add_configure_host.go — ServiceFromDetect additional cases +// --------------------------------------------------------------------------- + +func TestServiceFromDetect_AngularDependency(t *testing.T) { + t.Parallel() + svc, err := ServiceFromDetect( + "/projects", + "angular-app", + appdetect.Project{ + Path: "/projects/angular-app", + Language: appdetect.TypeScript, + Dependencies: []appdetect.Dependency{ + appdetect.JsAngular, + }, + }, + project.ContainerAppTarget, + ) + require.NoError(t, err) + assert.Equal(t, "dist/angular-app", svc.OutputPath) +} + +func TestServiceFromDetect_DockerRelativePath(t *testing.T) { + t.Parallel() + svc, err := ServiceFromDetect( + "/projects", + "docker-svc", + appdetect.Project{ + Path: "/projects/app", + Language: appdetect.Python, + Docker: &appdetect.Docker{ + Path: "/projects/app/Dockerfile", + }, + }, + project.ContainerAppTarget, + ) + require.NoError(t, err) + assert.Equal(t, "docker-svc", svc.Name) + assert.Equal(t, "Dockerfile", svc.Docker.Path) +} + +func TestServiceFromDetect_WithRootPath(t *testing.T) { + t.Parallel() + svc, err := ServiceFromDetect( + "/projects", + "mono-svc", + appdetect.Project{ + Path: "/projects/app", + Language: appdetect.Python, + Docker: &appdetect.Docker{ + Path: "/projects/app/Dockerfile", + }, + RootPath: "/projects", + }, + project.ContainerAppTarget, + ) + require.NoError(t, err) + assert.Equal(t, "..", svc.Docker.Context) +} + +func TestServiceFromDetect_ViteOverridesReact(t *testing.T) { + t.Parallel() + svc, err := ServiceFromDetect( + "/projects", + "spa", + appdetect.Project{ + Path: "/projects/spa", + Language: appdetect.TypeScript, + Dependencies: []appdetect.Dependency{ + appdetect.JsReact, // react sets "build" + appdetect.JsVite, // vite overrides to "dist" + }, + }, + project.ContainerAppTarget, + ) + require.NoError(t, err) + assert.Equal(t, "dist", svc.OutputPath) +} + +// --------------------------------------------------------------------------- +// add_select.go — selectStorage, selectKeyVault +// --------------------------------------------------------------------------- + +func TestSelectStorage_ReturnType(t *testing.T) { + t.Parallel() + r, err := selectStorage(nil, t.Context(), PromptOptions{}) + require.NoError(t, err) + assert.Equal(t, project.ResourceTypeStorage, r.Type) + _, ok := r.Props.(project.StorageProps) + assert.True(t, ok, "expected StorageProps") +} + +func TestSelectKeyVault_ReturnType(t *testing.T) { + t.Parallel() + r, err := selectKeyVault(nil, t.Context(), PromptOptions{}) + require.NoError(t, err) + assert.Equal(t, project.ResourceTypeKeyVault, r.Type) +} + +// --------------------------------------------------------------------------- +// add_configure.go — promptUsedBy +// --------------------------------------------------------------------------- + +func TestPromptUsedBy_EmptyResources(t *testing.T) { + t.Parallel() + r := &project.ResourceConfig{ + Type: project.ResourceTypeDbRedis, + Name: "redis", + } + opts := PromptOptions{ + PrjConfig: &project.ProjectConfig{ + Resources: map[string]*project.ResourceConfig{}, + }, + } + result, err := promptUsedBy(t.Context(), r, nil, opts) + require.NoError(t, err) + assert.Nil(t, result) +} + +func TestPromptUsedBy_NonHostResources(t *testing.T) { + t.Parallel() + r := &project.ResourceConfig{ + Type: project.ResourceTypeDbRedis, + Name: "redis", + } + opts := PromptOptions{ + PrjConfig: &project.ProjectConfig{ + Resources: map[string]*project.ResourceConfig{ + "postgres": {Type: project.ResourceTypeDbPostgres, Name: "postgres"}, + }, + }, + } + result, err := promptUsedBy(t.Context(), r, nil, opts) + require.NoError(t, err) + assert.Nil(t, result) +} + +func TestPromptUsedBy_HostAlreadyUsesResource(t *testing.T) { + t.Parallel() + r := &project.ResourceConfig{ + Type: project.ResourceTypeDbRedis, + Name: "redis", + } + opts := PromptOptions{ + PrjConfig: &project.ProjectConfig{ + Resources: map[string]*project.ResourceConfig{ + "web": { + Type: project.ResourceTypeHostContainerApp, + Name: "web", + Uses: []string{"redis"}, + }, + }, + }, + } + result, err := promptUsedBy(t.Context(), r, nil, opts) + require.NoError(t, err) + assert.Nil(t, result) +} + +func TestPromptUsedBy_DifferentHostTypesSkipped(t *testing.T) { + t.Parallel() + r := &project.ResourceConfig{ + Type: project.ResourceTypeHostContainerApp, + Name: "backend", + } + opts := PromptOptions{ + PrjConfig: &project.ProjectConfig{ + Resources: map[string]*project.ResourceConfig{ + "web": { + Type: project.ResourceTypeHostAppService, // different host type + Name: "web", + }, + }, + }, + } + result, err := promptUsedBy(t.Context(), r, nil, opts) + require.NoError(t, err) + assert.Nil(t, result) +} + +// --------------------------------------------------------------------------- +// add_select_ai.go — selectSearch, selectOpenAi, selectAiModel +// --------------------------------------------------------------------------- + +func TestSelectSearch_ReturnType(t *testing.T) { + t.Parallel() + a := &AddAction{} + r, err := a.selectSearch(nil, t.Context(), PromptOptions{}) + require.NoError(t, err) + assert.Equal(t, project.ResourceTypeAiSearch, r.Type) +} + +func TestSelectOpenAi_ReturnType(t *testing.T) { + t.Parallel() + a := &AddAction{} + r, err := a.selectOpenAi(nil, t.Context(), PromptOptions{}) + require.NoError(t, err) + assert.Equal(t, project.ResourceTypeOpenAiModel, r.Type) +} + +func TestSelectAiModel_ReturnType(t *testing.T) { + t.Parallel() + a := &AddAction{} + r, err := a.selectAiModel(nil, t.Context(), PromptOptions{}) + require.NoError(t, err) + assert.Equal(t, project.ResourceTypeAiProject, r.Type) +} + +// --------------------------------------------------------------------------- +// add_select_ai.go — selectFromSkus +// --------------------------------------------------------------------------- + +func TestSelectFromSkus_Empty(t *testing.T) { + t.Parallel() + _, err := selectFromSkus(t.Context(), nil, "Select", []ModelSku{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "no skus found") +} + +func TestSelectFromSkus_SingleAutoSelects(t *testing.T) { + t.Parallel() + expected := ModelSku{ + Name: "Standard", + UsageName: "std", + Capacity: ModelSkuCapacity{Default: 10}, + } + got, err := selectFromSkus(t.Context(), nil, "Select", []ModelSku{expected}) + require.NoError(t, err) + assert.Equal(t, expected, got) +} + +// --------------------------------------------------------------------------- +// add_select_ai.go — selectFromMap (single-entry auto-select) +// --------------------------------------------------------------------------- + +func TestSelectFromMap_SingleEntry(t *testing.T) { + t.Parallel() + m := map[string]string{"only-key": "only-value"} + key, val, err := selectFromMap(t.Context(), nil, "Pick one", m, nil) + require.NoError(t, err) + assert.Equal(t, "only-key", key) + assert.Equal(t, "only-value", val) +} + +func TestSelectFromMap_SingleEntry_ComplexType(t *testing.T) { + t.Parallel() + m := map[string]ModelCatalogKind{ + "gpt-4o": {Kinds: map[string]ModelCatalogVersions{}}, + } + key, val, err := selectFromMap(t.Context(), nil, "Select model", m, nil) + require.NoError(t, err) + assert.Equal(t, "gpt-4o", key) + assert.NotNil(t, val.Kinds) +} + +// --------------------------------------------------------------------------- +// diff.go — DiffBlocks: modified entry (same key, different value) +// --------------------------------------------------------------------------- + +func TestDiffBlocks_ModifiedEntry(t *testing.T) { + t.Parallel() + old := map[string]*project.ResourceConfig{ + "db": {Type: project.ResourceTypeDbPostgres}, + } + newMap := map[string]*project.ResourceConfig{ + "db": {Type: project.ResourceTypeDbPostgres, Uses: []string{"web"}}, + } + result, err := DiffBlocks(old, newMap) + require.NoError(t, err) + assert.Contains(t, result, "db:") + assert.NotEmpty(t, result) +} + +// --------------------------------------------------------------------------- +// diff.go — DiffBlocks: multiple new entries (verify sorted output) +// --------------------------------------------------------------------------- + +func TestDiffBlocks_MultipleNewEntries_Sorted(t *testing.T) { + t.Parallel() + old := map[string]*project.ResourceConfig{} + newMap := map[string]*project.ResourceConfig{ + "beta": {Type: project.ResourceTypeDbRedis, Name: "beta"}, + "alpha": {Type: project.ResourceTypeDbPostgres, Name: "alpha"}, + } + result, err := DiffBlocks(old, newMap) + require.NoError(t, err) + + alphaIdx := strings.Index(result, "alpha:") + betaIdx := strings.Index(result, "beta:") + require.Greater(t, alphaIdx, -1, "expected alpha in output") + require.Greater(t, betaIdx, -1, "expected beta in output") + assert.Less(t, alphaIdx, betaIdx, "entries should be sorted alphabetically") +} + +// --------------------------------------------------------------------------- +// diff.go — DiffBlocks: new + existing mix +// --------------------------------------------------------------------------- + +func TestDiffBlocks_NewAndExistingMix(t *testing.T) { + t.Parallel() + existing := &project.ResourceConfig{Type: project.ResourceTypeDbRedis, Name: "redis"} + old := map[string]*project.ResourceConfig{ + "redis": existing, + } + newMap := map[string]*project.ResourceConfig{ + "redis": existing, // unchanged + "postgres": {Type: project.ResourceTypeDbPostgres, Name: "postgres"}, // new + } + result, err := DiffBlocks(old, newMap) + require.NoError(t, err) + // Unchanged redis should NOT appear, new postgres should appear + assert.Contains(t, result, "postgres:") + assert.Contains(t, result, "+") +} + +// --------------------------------------------------------------------------- +// add_preview.go — previewWriter edge cases +// --------------------------------------------------------------------------- + +func TestPreviewWriter_EmptyWrite(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + pw := &previewWriter{w: &buf} + n, err := pw.Write([]byte{}) + require.NoError(t, err) + assert.Equal(t, 0, n) + assert.Empty(t, buf.String()) +} + +func TestPreviewWriter_NoNewline_BuffersInternally(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + pw := &previewWriter{w: &buf} + n, err := pw.Write([]byte("partial")) + require.NoError(t, err) + assert.Equal(t, 7, n) + // No newline means nothing flushed to underlying writer + assert.Empty(t, buf.String()) +} + +func TestPreviewWriter_MultipleLines(t *testing.T) { + t.Parallel() + var buf bytes.Buffer + pw := &previewWriter{w: &buf} + input := "+ added\n normal\n" + n, err := pw.Write([]byte(input)) + require.NoError(t, err) + assert.Equal(t, len(input), n) + out := buf.String() + assert.Contains(t, out, "added") + assert.Contains(t, out, "normal") +} diff --git a/cli/azd/pkg/auth/credential_providers_test.go b/cli/azd/pkg/auth/credential_providers_test.go new file mode 100644 index 00000000000..170064ea800 --- /dev/null +++ b/cli/azd/pkg/auth/credential_providers_test.go @@ -0,0 +1,295 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package auth + +import ( + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/azure/azure-dev/cli/azd/pkg/cloud" + "github.com/stretchr/testify/require" +) + +// tokenServer creates an httptest.Server that responds to RemoteCredential token +// requests. It returns a valid success response and tracks the number of calls +// received via the returned *atomic.Int32. +func tokenServer(t *testing.T) (*httptest.Server, *atomic.Int32) { + t.Helper() + + var calls atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls.Add(1) + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, + `{"status":"success","token":"tok","expiresOn":"2099-01-01T00:00:00Z"}`) + })) + t.Cleanup(srv.Close) + return srv, &calls +} + +// errorTokenServer creates an httptest.Server that always returns an error +// response so EnsureLoggedInCredential fails. +func errorTokenServer(t *testing.T) *httptest.Server { + t.Helper() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, + `{"status":"error","code":"auth_failed","message":"token denied"}`) + })) + t.Cleanup(srv.Close) + return srv +} + +// externalAuthManager returns a *Manager configured to use external auth backed +// by the given endpoint URL. The cloud field is set to AzurePublic. +func externalAuthManager(endpoint string, client *http.Client) *Manager { + return &Manager{ + cloud: cloud.AzurePublic(), + externalAuthCfg: ExternalAuthConfiguration{ + Endpoint: endpoint, + Key: "test-key", + Transporter: client, + }, + } +} + +func TestCredentialProvider_SuccessAndCaching(t *testing.T) { + t.Parallel() + + srv, calls := tokenServer(t) + m := externalAuthManager(srv.URL, srv.Client()) + provider := NewMultiTenantCredentialProvider(m) + + // First call: should hit the server (via EnsureLoggedInCredential) + cred1, err := provider.GetTokenCredential(t.Context(), "tenant-a") + require.NoError(t, err) + require.NotNil(t, cred1) + require.Equal(t, int32(1), calls.Load(), "expected exactly one HTTP call on first fetch") + + // Second call with same tenant: should return cached credential, no new HTTP call + cred2, err := provider.GetTokenCredential(t.Context(), "tenant-a") + require.NoError(t, err) + require.Same(t, cred1, cred2, "expected same pointer from cache") + require.Equal(t, int32(1), calls.Load(), "expected no additional HTTP call on cache hit") +} + +func TestCredentialProvider_DifferentTenants(t *testing.T) { + t.Parallel() + + srv, calls := tokenServer(t) + m := externalAuthManager(srv.URL, srv.Client()) + provider := NewMultiTenantCredentialProvider(m) + + credA, err := provider.GetTokenCredential(t.Context(), "tenant-a") + require.NoError(t, err) + + credB, err := provider.GetTokenCredential(t.Context(), "tenant-b") + require.NoError(t, err) + + require.NotSame(t, credA, credB, "different tenants must return different credential instances") + require.Equal(t, int32(2), calls.Load(), "expected one HTTP call per distinct tenant") +} + +func TestCredentialProvider_EmptyTenantID(t *testing.T) { + t.Parallel() + + srv, calls := tokenServer(t) + m := externalAuthManager(srv.URL, srv.Client()) + provider := NewMultiTenantCredentialProvider(m) + + cred, err := provider.GetTokenCredential(t.Context(), "") + require.NoError(t, err) + require.NotNil(t, cred) + require.Equal(t, int32(1), calls.Load()) + + // Empty tenant should also be cached under the "" key + cred2, err := provider.GetTokenCredential(t.Context(), "") + require.NoError(t, err) + require.Same(t, cred, cred2) + require.Equal(t, int32(1), calls.Load(), "empty tenant credential should be cached") +} + +func TestCredentialProvider_ErrorFromCredentialForCurrentUser(t *testing.T) { + t.Parallel() + + // Manager with no auth config and no current user - CredentialForCurrentUser + // will return ErrNoCurrentUser. + m := &Manager{ + configManager: newMemoryConfigManager(), + userConfigManager: newMemoryUserConfigManager(), + publicClient: &mockPublicClient{}, + } + + provider := NewMultiTenantCredentialProvider(m) + _, err := provider.GetTokenCredential(t.Context(), "any-tenant") + + require.Error(t, err) + require.ErrorIs(t, err, ErrNoCurrentUser) +} + +func TestCredentialProvider_ErrorFromEnsureLoggedIn(t *testing.T) { + t.Parallel() + + // The remote credential server returns an error response, so + // EnsureLoggedInCredential (which calls GetToken) will fail. + srv := errorTokenServer(t) + m := externalAuthManager(srv.URL, srv.Client()) + provider := NewMultiTenantCredentialProvider(m) + + _, err := provider.GetTokenCredential(t.Context(), "tenant-x") + + require.Error(t, err) + require.Contains(t, err.Error(), "token denied") +} + +func TestCredentialProvider_EnsureLoggedInErrorDoesNotCache(t *testing.T) { + t.Parallel() + + // Use a server that fails first, then succeeds. This verifies that a failed + // EnsureLoggedInCredential call does NOT store the credential in the cache. + var attempt atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + if attempt.Add(1) == 1 { + _, _ = io.WriteString(w, + `{"status":"error","code":"auth_failed","message":"transient failure"}`) + return + } + _, _ = io.WriteString(w, + `{"status":"success","token":"recovered","expiresOn":"2099-01-01T00:00:00Z"}`) + })) + t.Cleanup(srv.Close) + + m := externalAuthManager(srv.URL, srv.Client()) + provider := NewMultiTenantCredentialProvider(m) + + // First call: EnsureLoggedInCredential fails + _, err := provider.GetTokenCredential(t.Context(), "tenant-retry") + require.Error(t, err) + require.Contains(t, err.Error(), "transient failure") + + // Second call: should NOT return a cached (bad) credential; should retry the + // full flow and succeed. + cred, err := provider.GetTokenCredential(t.Context(), "tenant-retry") + require.NoError(t, err) + require.NotNil(t, cred) +} + +func TestCredentialProvider_ConcurrentAccess(t *testing.T) { + t.Parallel() + + srv, calls := tokenServer(t) + m := externalAuthManager(srv.URL, srv.Client()) + provider := NewMultiTenantCredentialProvider(m) + + const goroutines = 20 + + var wg sync.WaitGroup + wg.Add(goroutines) + + errs := make([]error, goroutines) + creds := make([]azcore.TokenCredential, goroutines) + + for i := range goroutines { + go func(idx int) { + defer wg.Done() + c, err := provider.GetTokenCredential(t.Context(), "shared-tenant") + creds[idx] = c + errs[idx] = err + }(i) + } + + wg.Wait() + + // All goroutines must succeed + for i, err := range errs { + require.NoError(t, err, "goroutine %d returned error", i) + require.NotNil(t, creds[i], "goroutine %d returned nil credential", i) + } + + // After concurrent access, a subsequent call must return a cached credential + cachedCred, err := provider.GetTokenCredential(t.Context(), "shared-tenant") + require.NoError(t, err) + require.NotNil(t, cachedCred) + + // Verify that the HTTP server wasn't called an excessive number of times. + // The implementation doesn't use LoadOrStore so multiple goroutines may + // redundantly create credentials, but total calls should be bounded. + totalCalls := calls.Load() + require.LessOrEqual(t, totalCalls, int32(goroutines), + "expected at most %d HTTP calls, got %d", goroutines, totalCalls) + require.GreaterOrEqual(t, totalCalls, int32(1), + "expected at least 1 HTTP call") +} + +func TestCredentialProvider_ConcurrentDifferentTenants(t *testing.T) { + t.Parallel() + + srv, calls := tokenServer(t) + m := externalAuthManager(srv.URL, srv.Client()) + provider := NewMultiTenantCredentialProvider(m) + + const tenantCount = 10 + var wg sync.WaitGroup + wg.Add(tenantCount) + + errs := make([]error, tenantCount) + creds := make([]azcore.TokenCredential, tenantCount) + + for i := range tenantCount { + go func(idx int) { + defer wg.Done() + c, err := provider.GetTokenCredential(t.Context(), fmt.Sprintf("tenant-%d", idx)) + creds[idx] = c + errs[idx] = err + }(i) + } + + wg.Wait() + + for i, err := range errs { + require.NoError(t, err, "tenant-%d returned error", i) + require.NotNil(t, creds[i], "tenant-%d returned nil credential", i) + } + + // Each distinct tenant should have made at least one HTTP call + require.GreaterOrEqual(t, calls.Load(), int32(tenantCount), + "expected at least %d HTTP calls for %d distinct tenants", tenantCount, tenantCount) +} + +func TestCredentialProvider_NewMultiTenantCredentialProviderReturnsInterface(t *testing.T) { + t.Parallel() + + m := &Manager{cloud: cloud.AzurePublic()} + provider := NewMultiTenantCredentialProvider(m) + + // Verify the returned value satisfies the interface + var _ MultiTenantCredentialProvider = provider + require.NotNil(t, provider) +} + +func TestCredentialProvider_CredentialForCurrentUserWrapsErrors(t *testing.T) { + t.Parallel() + + // A Manager with a userConfigManager that returns an error on Load + m := &Manager{ + configManager: newMemoryConfigManager(), + userConfigManager: &failingUserConfigManager{err: errors.New("config load boom")}, + publicClient: &mockPublicClient{}, + } + + provider := NewMultiTenantCredentialProvider(m) + _, err := provider.GetTokenCredential(t.Context(), "some-tenant") + + require.Error(t, err) + require.Contains(t, err.Error(), "config load boom") +} diff --git a/cli/azd/pkg/auth/remote_credential_test.go b/cli/azd/pkg/auth/remote_credential_test.go new file mode 100644 index 00000000000..81246a4221d --- /dev/null +++ b/cli/azd/pkg/auth/remote_credential_test.go @@ -0,0 +1,319 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package auth + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/stretchr/testify/require" +) + +func TestRemoteCredential(t *testing.T) { + t.Parallel() + + fixedExpiry := time.Date(2030, 6, 15, 12, 0, 0, 0, time.UTC) + + tests := []struct { + name string + tenantID string // struct-level tenant + opts policy.TokenRequestOptions + status int + body string + wantToken string + wantExpiry time.Time + wantErr bool + errContains string + }{ + { + name: "success returns token and expiry", + tenantID: "my-tenant", + opts: policy.TokenRequestOptions{ + Scopes: []string{"https://management.azure.com/.default"}, + }, + status: http.StatusOK, + body: remoteCredTestJSON(map[string]any{ + "status": "success", "token": "tok-abc", + "expiresOn": "2030-06-15T12:00:00Z", + }), + wantToken: "tok-abc", + wantExpiry: fixedExpiry, + }, + { + name: "error status returns code and message", + opts: policy.TokenRequestOptions{ + Scopes: []string{"scope1"}, + }, + status: http.StatusOK, + body: remoteCredTestJSON( + map[string]any{"status": "error", "code": "auth_failed", "message": "bad creds"}), + wantErr: true, + errContains: "bad creds", + }, + { + name: "non-200 HTTP status returns error with status code", + opts: policy.TokenRequestOptions{ + Scopes: []string{"scope1"}, + }, + status: http.StatusForbidden, + body: `{"error":"forbidden"}`, + wantErr: true, + errContains: "unexpected status code", + }, + { + name: "malformed JSON returns decode error", + opts: policy.TokenRequestOptions{ + Scopes: []string{"scope1"}, + }, + status: http.StatusOK, + body: "<<>>", + wantErr: true, + errContains: "decoding token response", + }, + { + name: "unexpected status field returns error", + opts: policy.TokenRequestOptions{ + Scopes: []string{"scope1"}, + }, + status: http.StatusOK, + body: remoteCredTestJSON(map[string]any{"status": "pending"}), + wantErr: true, + errContains: "unexpected status", + }, + { + name: "empty scopes still succeeds", + opts: policy.TokenRequestOptions{ + Scopes: []string{}, + }, + status: http.StatusOK, + body: remoteCredTestJSON(map[string]any{ + "status": "success", "token": "empty-scope-tok", + "expiresOn": "2030-06-15T12:00:00Z", + }), + wantToken: "empty-scope-tok", + wantExpiry: fixedExpiry, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(tt.status) + _, _ = io.WriteString(w, tt.body) + })) + defer srv.Close() + + rc := newRemoteCredential(srv.URL, "test-key", tt.tenantID, srv.Client()) + tok, err := rc.GetToken(t.Context(), tt.opts) + + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errContains) + require.Equal(t, azcore.AccessToken{}, tok) + return + } + + require.NoError(t, err) + require.Equal(t, tt.wantToken, tok.Token) + require.True(t, tt.wantExpiry.Equal(tok.ExpiresOn), + "expiry mismatch: want %v, got %v", tt.wantExpiry, tok.ExpiresOn) + }) + } +} + +// TestRemoteCredential_RequestFormat validates the HTTP method, URL path, query params, headers, and body. +func TestRemoteCredential_RequestFormat(t *testing.T) { + t.Parallel() + + var captured struct { + method string + path string + apiVersion string + contentType string + authHeader string + body []byte + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + captured.method = r.Method + captured.path = r.URL.Path + captured.apiVersion = r.URL.Query().Get("api-version") + captured.contentType = r.Header.Get("Content-Type") + captured.authHeader = r.Header.Get("Authorization") + captured.body, _ = io.ReadAll(r.Body) + + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, `{"status":"success","token":"t","expiresOn":"2030-01-01T00:00:00Z"}`) + })) + defer srv.Close() + + rc := newRemoteCredential(srv.URL, "my-api-key", "the-tenant", srv.Client()) + _, err := rc.GetToken(t.Context(), policy.TokenRequestOptions{ + Scopes: []string{"https://graph.microsoft.com/.default", "openid"}, + }) + require.NoError(t, err) + + require.Equal(t, http.MethodPost, captured.method) + require.Equal(t, "/token", captured.path) + require.Equal(t, "2023-07-12-preview", captured.apiVersion) + require.Equal(t, "application/json", captured.contentType) + require.Equal(t, "Bearer my-api-key", captured.authHeader) + + var reqBody struct { + Scopes []string `json:"scopes"` + TenantId string `json:"tenantId"` + } + require.NoError(t, json.Unmarshal(captured.body, &reqBody)) + require.Equal(t, []string{"https://graph.microsoft.com/.default", "openid"}, reqBody.Scopes) + require.Equal(t, "the-tenant", reqBody.TenantId) +} + +// TestRemoteCredential_TenantOverride verifies options.TenantID overrides the struct tenantID. +func TestRemoteCredential_TenantOverride(t *testing.T) { + t.Parallel() + + var receivedTenant string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req struct { + TenantId string `json:"tenantId"` + } + _ = json.Unmarshal(body, &req) + receivedTenant = req.TenantId + + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, `{"status":"success","token":"t","expiresOn":"2030-01-01T00:00:00Z"}`) + })) + defer srv.Close() + + rc := newRemoteCredential(srv.URL, "key", "default-tenant", srv.Client()) + _, err := rc.GetToken(t.Context(), policy.TokenRequestOptions{ + Scopes: []string{"s1"}, + TenantID: "override-tenant", + }) + require.NoError(t, err) + require.Equal(t, "override-tenant", receivedTenant) +} + +// TestRemoteCredential_TenantDefault verifies the struct tenantID is used when options.TenantID is empty. +func TestRemoteCredential_TenantDefault(t *testing.T) { + t.Parallel() + + var receivedTenant string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + var req struct { + TenantId string `json:"tenantId"` + } + _ = json.Unmarshal(body, &req) + receivedTenant = req.TenantId + + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, `{"status":"success","token":"t","expiresOn":"2030-01-01T00:00:00Z"}`) + })) + defer srv.Close() + + rc := newRemoteCredential(srv.URL, "key", "struct-tenant", srv.Client()) + _, err := rc.GetToken(t.Context(), policy.TokenRequestOptions{ + Scopes: []string{"s1"}, + }) + require.NoError(t, err) + require.Equal(t, "struct-tenant", receivedTenant) +} + +// TestRemoteCredential_ConnectionFailure verifies the error path when the server is unreachable. +func TestRemoteCredential_ConnectionFailure(t *testing.T) { + t.Parallel() + + // Start a server then close it immediately so the port is unreachable. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + endpoint := srv.URL + srv.Close() + + rc := newRemoteCredential(endpoint, "key", "", http.DefaultClient) + _, err := rc.GetToken(t.Context(), policy.TokenRequestOptions{ + Scopes: []string{"scope1"}, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "making request") +} + +// TestRemoteCredential_ContextCancelled verifies the request fails when the context is already cancelled. +func TestRemoteCredential_ContextCancelled(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, `{"status":"success","token":"t","expiresOn":"2030-01-01T00:00:00Z"}`) + })) + defer srv.Close() + + ctx, cancel := context.WithCancel(t.Context()) + cancel() // cancel before making the request + + rc := newRemoteCredential(srv.URL, "key", "", srv.Client()) + _, err := rc.GetToken(ctx, policy.TokenRequestOptions{ + Scopes: []string{"scope1"}, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "RemoteCredential") +} + +// TestRemoteCredential_NonOKStatusIncludesCode verifies the error message includes the HTTP status code. +func TestRemoteCredential_NonOKStatusIncludesCode(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) // 503 + })) + defer srv.Close() + + rc := newRemoteCredential(srv.URL, "key", "", srv.Client()) + _, err := rc.GetToken(t.Context(), policy.TokenRequestOptions{ + Scopes: []string{"scope1"}, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "503") +} + +// TestRemoteCredential_ErrorResponseIncludesCode verifies the error response includes the code field. +func TestRemoteCredential_ErrorResponseIncludesCode(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, `{"status":"error","code":"token_expired","message":"token has expired"}`) + })) + defer srv.Close() + + rc := newRemoteCredential(srv.URL, "key", "", srv.Client()) + _, err := rc.GetToken(t.Context(), policy.TokenRequestOptions{ + Scopes: []string{"scope1"}, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "token_expired") + require.Contains(t, err.Error(), "token has expired") + require.Contains(t, err.Error(), "failed to acquire token") +} + +// remoteCredTestJSON marshals v to JSON, panicking on error. Test helper only. +func remoteCredTestJSON(v any) string { + b, err := json.Marshal(v) + if err != nil { + panic(err) + } + return string(b) +} diff --git a/cli/azd/pkg/exec/runresult_test.go b/cli/azd/pkg/exec/runresult_test.go new file mode 100644 index 00000000000..7e9e4a3cd5c --- /dev/null +++ b/cli/azd/pkg/exec/runresult_test.go @@ -0,0 +1,226 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package exec + +import ( + osexec "os/exec" + "testing" + + "github.com/stretchr/testify/require" +) + +// makeOSExitError runs a command that exits with a non-zero code and returns the resulting exec.ExitError. +// "go --help" exits with code 2 on all platforms. +func makeOSExitError(t *testing.T) osexec.ExitError { + t.Helper() + cmd := osexec.CommandContext(t.Context(), "go", "--help") //nolint:gosec // hardcoded test args + err := cmd.Run() + var exitErr *osexec.ExitError + require.ErrorAs(t, err, &exitErr) + return *exitErr +} + +func TestNewRunResult(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + code int + stdout string + stderr string + }{ + { + name: "SuccessWithOutput", + code: 0, + stdout: "hello world", + stderr: "", + }, + { + name: "NonZeroExitCode", + code: 1, + stdout: "", + stderr: "error occurred", + }, + { + name: "AllFieldsEmpty", + code: 0, + stdout: "", + stderr: "", + }, + { + name: "BothStdoutAndStderr", + code: 42, + stdout: "some output", + stderr: "some error", + }, + { + name: "NegativeExitCode", + code: -1, + stdout: "", + stderr: "", + }, + { + name: "MultilineOutput", + code: 0, + stdout: "line1\nline2\nline3", + stderr: "err1\nerr2", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + result := NewRunResult(tt.code, tt.stdout, tt.stderr) + require.Equal(t, tt.code, result.ExitCode) + require.Equal(t, tt.stdout, result.Stdout) + require.Equal(t, tt.stderr, result.Stderr) + }) + } +} + +func TestNewExitError(t *testing.T) { + t.Parallel() + + osExitErr := makeOSExitError(t) + + tests := []struct { + name string + cmd string + stdOut string + stdErr string + outputAvailable bool + }{ + { + name: "WithOutputAvailable", + cmd: "mycli", + stdOut: "standard output", + stdErr: "error output", + outputAvailable: true, + }, + { + name: "WithoutOutputAvailable", + cmd: "mycli", + stdOut: "", + stdErr: "", + outputAvailable: false, + }, + { + name: "OnlyStdout", + cmd: "anothercli", + stdOut: "some output", + stdErr: "", + outputAvailable: true, + }, + { + name: "OnlyStderr", + cmd: "anothercli", + stdOut: "", + stdErr: "some error", + outputAvailable: true, + }, + { + name: "EmptyCmd", + cmd: "", + stdOut: "out", + stdErr: "err", + outputAvailable: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := NewExitError(osExitErr, tt.cmd, tt.stdOut, tt.stdErr, tt.outputAvailable) + require.Error(t, err) + + var typedErr *ExitError + require.ErrorAs(t, err, &typedErr) + require.Equal(t, tt.cmd, typedErr.Cmd) + require.Equal(t, osExitErr.ExitCode(), typedErr.ExitCode) + }) + } +} + +func TestExitError_Error(t *testing.T) { + t.Parallel() + + osExitErr := makeOSExitError(t) + + tests := []struct { + name string + stdOut string + stdErr string + outputAvailable bool + wantContains []string + wantNotContains []string + }{ + { + name: "OutputAvailableIncludesStdoutStderr", + stdOut: "my stdout", + stdErr: "my stderr", + outputAvailable: true, + wantContains: []string{"exit code:", "stdout: my stdout", "stderr: my stderr"}, + }, + { + name: "OutputNotAvailableExcludesStdoutStderr", + stdOut: "my stdout", + stdErr: "my stderr", + outputAvailable: false, + wantContains: []string{"exit code:"}, + wantNotContains: []string{"stdout:", "stderr:"}, + }, + { + name: "EmptyOutputFields", + stdOut: "", + stdErr: "", + outputAvailable: true, + wantContains: []string{"exit code:", "stdout: ", "stderr: "}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + err := NewExitError(osExitErr, "testcmd", tt.stdOut, tt.stdErr, tt.outputAvailable) + errMsg := err.Error() + + for _, want := range tt.wantContains { + require.Contains(t, errMsg, want) + } + for _, notWant := range tt.wantNotContains { + require.NotContains(t, errMsg, notWant) + } + }) + } +} + +func TestExitError_ErrorContainsExitCode(t *testing.T) { + t.Parallel() + + osExitErr := makeOSExitError(t) + err := NewExitError(osExitErr, "go", "out", "err", true) + + // go --help exits with code 2 + require.ErrorContains(t, err, "exit code: 2") + require.ErrorContains(t, err, "stdout: out") + require.ErrorContains(t, err, "stderr: err") +} + +func TestExitError_SatisfiesErrorInterface(t *testing.T) { + t.Parallel() + + osExitErr := makeOSExitError(t) + err := NewExitError(osExitErr, "go", "", "", false) + + // NewExitError returns the error interface; confirm non-nil and type-assertable. + require.Error(t, err) + + var typedErr *ExitError + require.ErrorAs(t, err, &typedErr) + + // The error string is non-empty even without output. + require.NotEmpty(t, typedErr.Error()) +} diff --git a/cli/azd/pkg/exec/sanitizer_test.go b/cli/azd/pkg/exec/sanitizer_test.go new file mode 100644 index 00000000000..8f34ac9aa5b --- /dev/null +++ b/cli/azd/pkg/exec/sanitizer_test.go @@ -0,0 +1,175 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package exec + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRedactSensitiveArgs(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + args []string + sensitiveDataMatch []string + expected []string + }{ + { + name: "EmptySensitiveDataReturnsOriginalSlice", + args: []string{"--password", "secret123"}, + sensitiveDataMatch: []string{}, + expected: []string{"--password", "secret123"}, + }, + { + name: "NilSensitiveDataReturnsOriginalSlice", + args: []string{"--password", "secret123"}, + sensitiveDataMatch: nil, + expected: []string{"--password", "secret123"}, + }, + { + name: "EmptyArgs", + args: []string{}, + sensitiveDataMatch: []string{"secret"}, + expected: []string{}, + }, + { + name: "NoMatchingData", + args: []string{"git", "push", "origin"}, + sensitiveDataMatch: []string{"secret123"}, + expected: []string{"git", "push", "origin"}, + }, + { + name: "SingleArgRedacted", + args: []string{"--token", "abc123"}, + sensitiveDataMatch: []string{"abc123"}, + expected: []string{"--token", ""}, + }, + { + name: "MultipleArgsRedacted", + args: []string{"--user", "admin", "--pass", "s3cret"}, + sensitiveDataMatch: []string{"admin", "s3cret"}, + expected: []string{"--user", "", "--pass", ""}, + }, + { + name: "SamePatternAppearsMultipleTimesInOneArg", + args: []string{"user=admin&backup=admin"}, + sensitiveDataMatch: []string{"admin"}, + expected: []string{"user=&backup="}, + }, + { + name: "SensitiveDataAsSubstring", + args: []string{"Server=myhost;Password=secret123;Database=mydb"}, + sensitiveDataMatch: []string{"secret123"}, + expected: []string{"Server=myhost;Password=;Database=mydb"}, + }, + { + name: "MultipleSensitivePatternsInOneArg", + args: []string{"user:admin pass:secret123"}, + sensitiveDataMatch: []string{"admin", "secret123"}, + expected: []string{"user: pass:"}, + }, + { + name: "EntireArgIsSensitive", + args: []string{"mytoken"}, + sensitiveDataMatch: []string{"mytoken"}, + expected: []string{""}, + }, + { + name: "PreservesNonSensitiveArgs", + args: []string{"--verbose", "--token", "secret", "--output", "json"}, + sensitiveDataMatch: []string{"secret"}, + expected: []string{"--verbose", "--token", "", "--output", "json"}, + }, + { + name: "SingleArgSingleSensitive", + args: []string{"key=value"}, + sensitiveDataMatch: []string{"value"}, + expected: []string{"key="}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + originalArgs := make([]string, len(tt.args)) + copy(originalArgs, tt.args) + + result := RedactSensitiveArgs(tt.args, tt.sensitiveDataMatch) + require.Equal(t, tt.expected, result) + + // When sensitive patterns are provided, a new slice is allocated, so the original must be untouched. + if len(tt.sensitiveDataMatch) > 0 { + require.Equal(t, originalArgs, tt.args, "original args must not be modified") + } + }) + } +} + +func TestRedactSensitiveData_TokenPattern(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected string + }{ + { + name: "EmptyString", + input: "", + expected: "", + }, + { + name: "TokenField", + input: `{"token": "eyJhbGciOiJSUzI1NiJ9"}`, + expected: `{"token": ""}`, + }, + { + name: "KubectlFromLiteral", + input: `kubectl create secret generic my-secret --from-literal=DB_PASSWORD=super-s3cret`, + expected: `kubectl create secret generic my-secret --from-literal=DB_PASSWORD=`, + }, + { + name: "CombinedArgKeyValue", + input: `--api-key=abc123xyz`, + expected: `--api-key=`, + }, + { + name: "AccessTokenField", + input: `{"accessToken": "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.payload.sig"}`, + expected: `{"accessToken": ""}`, + }, + { + name: "DeploymentToken", + input: `az staticwebapp deploy --deployment-token abc123secret`, + expected: `az staticwebapp deploy --deployment-token `, + }, + { + name: "UsernameFlag", + input: `docker login --username myuser --password mypass`, + expected: `docker login --username --password `, + }, + { + name: "PasswordFlagAlone", + input: `mysql --password SuperSecret123`, + expected: `mysql --password `, + }, + { + name: "NoSensitiveData", + input: `just a plain message with no patterns`, + expected: `just a plain message with no patterns`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + actual := RedactSensitiveData(tt.input) + require.Equal(t, tt.expected, actual) + }) + } +} diff --git a/cli/azd/pkg/extensions/command_resolver_test.go b/cli/azd/pkg/extensions/command_resolver_test.go new file mode 100644 index 00000000000..4fd08f389be --- /dev/null +++ b/cli/azd/pkg/extensions/command_resolver_test.go @@ -0,0 +1,436 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package extensions + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestResolveCommandPathEdgeCases(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + metadata *ExtensionCommandMetadata + args []string + want []string + }{ + { + name: "LongestMatchWins", + metadata: &ExtensionCommandMetadata{ + Commands: []Command{ + {Name: []string{"deploy"}}, + { + Name: []string{"deploy"}, + Subcommands: []Command{ + {Name: []string{"deploy", "status"}}, + }, + }, + }, + }, + args: []string{"deploy", "status"}, + want: []string{"deploy", "status"}, + }, + { + name: "ShortMatchWhenSubcommandDoesNotMatch", + metadata: &ExtensionCommandMetadata{ + Commands: []Command{ + {Name: []string{"deploy"}}, + { + Name: []string{"deploy"}, + Subcommands: []Command{ + {Name: []string{"deploy", "status"}}, + }, + }, + }, + }, + args: []string{"deploy", "other"}, + want: []string{"deploy"}, + }, + { + name: "CommandNameLongerThanArgsSkippedButSubcommandsSearched", + metadata: &ExtensionCommandMetadata{ + Commands: []Command{ + { + // This parent command name is longer than what's provided, + // but its subcommands should still be searched. + Name: []string{"group", "sub", "deep"}, + Subcommands: []Command{ + {Name: []string{"single"}}, + }, + }, + }, + }, + args: []string{"single"}, + want: []string{"single"}, + }, + { + name: "EmptyCommandNameSkippedButSubcommandsSearched", + metadata: &ExtensionCommandMetadata{ + Commands: []Command{ + { + // Parent has empty name (acts as a grouping container). + Name: []string{}, + Subcommands: []Command{ + {Name: []string{"leaf"}}, + }, + }, + }, + }, + args: []string{"leaf"}, + want: []string{"leaf"}, + }, + { + name: "EmptyAliasSkipped", + metadata: &ExtensionCommandMetadata{ + Commands: []Command{ + { + Name: []string{"colors"}, + Aliases: []string{"", "colours"}, + }, + }, + }, + args: []string{"colours"}, + want: []string{"colors"}, + }, + { + name: "AliasOnMultiSegmentCommand", + metadata: &ExtensionCommandMetadata{ + Commands: []Command{ + { + Name: []string{"mcp"}, + Subcommands: []Command{ + { + Name: []string{"mcp", "start"}, + Aliases: []string{"run"}, + }, + }, + }, + }, + }, + args: []string{"mcp", "run"}, + want: []string{"mcp", "start"}, + }, + { + name: "OnlyFlagsNoCommandArgs", + metadata: &ExtensionCommandMetadata{ + Commands: []Command{ + {Name: []string{"version"}}, + }, + }, + args: []string{"--verbose"}, + want: nil, + }, + { + name: "DoubleDashBeforeCommandArgs", + metadata: &ExtensionCommandMetadata{ + Commands: []Command{ + {Name: []string{"version"}}, + }, + }, + args: []string{"--", "version"}, + want: nil, + }, + { + name: "FlagBetweenCommandSegments", + metadata: &ExtensionCommandMetadata{ + Commands: []Command{ + {Name: []string{"mcp"}}, + { + Name: []string{"mcp"}, + Subcommands: []Command{ + {Name: []string{"mcp", "start"}}, + }, + }, + }, + }, + // The flag stops command arg extraction at "mcp", so "start" is not considered a command segment. + args: []string{"mcp", "--verbose", "start"}, + want: []string{"mcp"}, + }, + { + name: "MultipleAliasesFirstMatchUsed", + metadata: &ExtensionCommandMetadata{ + Commands: []Command{ + { + Name: []string{"list"}, + Aliases: []string{"ls", "l"}, + }, + }, + }, + args: []string{"ls"}, + want: []string{"list"}, + }, + { + name: "SingleArgExactMatch", + metadata: &ExtensionCommandMetadata{ + Commands: []Command{ + {Name: []string{"init"}}, + {Name: []string{"up"}}, + {Name: []string{"down"}}, + }, + }, + args: []string{"up"}, + want: []string{"up"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := ResolveCommandPath(tt.metadata, tt.args) + require.Equal(t, tt.want, got) + }) + } +} + +func TestResolveCommandFlagsAllTypes(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + metadata *ExtensionCommandMetadata + args []string + want []string + }{ + { + name: "IntFlag", + metadata: &ExtensionCommandMetadata{ + Commands: []Command{ + { + Name: []string{"run"}, + Flags: []Flag{ + {Name: "count", Shorthand: "c", Type: "int"}, + }, + }, + }, + }, + args: []string{"run", "--count", "5"}, + want: []string{"count"}, + }, + { + name: "IntFlagShorthand", + metadata: &ExtensionCommandMetadata{ + Commands: []Command{ + { + Name: []string{"run"}, + Flags: []Flag{ + {Name: "count", Shorthand: "c", Type: "int"}, + }, + }, + }, + }, + args: []string{"run", "-c", "3"}, + want: []string{"count"}, + }, + { + name: "StringArrayFlag", + metadata: &ExtensionCommandMetadata{ + Commands: []Command{ + { + Name: []string{"run"}, + Flags: []Flag{ + {Name: "tags", Shorthand: "t", Type: "stringArray"}, + }, + }, + }, + }, + args: []string{"run", "--tags", "a,b,c"}, + want: []string{"tags"}, + }, + { + name: "IntArrayFlag", + metadata: &ExtensionCommandMetadata{ + Commands: []Command{ + { + Name: []string{"run"}, + Flags: []Flag{ + {Name: "ports", Shorthand: "p", Type: "intArray"}, + }, + }, + }, + }, + args: []string{"run", "--ports", "80,443"}, + want: []string{"ports"}, + }, + { + name: "MixedFlagTypes", + metadata: &ExtensionCommandMetadata{ + Commands: []Command{ + { + Name: []string{"run"}, + Flags: []Flag{ + {Name: "verbose", Shorthand: "v", Type: "bool"}, + {Name: "count", Shorthand: "c", Type: "int"}, + {Name: "output", Shorthand: "o", Type: "string"}, + {Name: "tags", Shorthand: "t", Type: "stringArray"}, + {Name: "ports", Shorthand: "p", Type: "intArray"}, + }, + }, + }, + }, + args: []string{"run", "-v", "--count", "5", "--output", "json", "--tags", "a,b", "--ports", "80"}, + want: []string{"verbose", "count", "output", "tags", "ports"}, + }, + { + name: "UnknownFlagTypeDefaultsToString", + metadata: &ExtensionCommandMetadata{ + Commands: []Command{ + { + Name: []string{"run"}, + Flags: []Flag{ + {Name: "custom", Shorthand: "x", Type: "special"}, + }, + }, + }, + }, + args: []string{"run", "--custom", "val"}, + want: []string{"custom"}, + }, + { + name: "EmptyFlagNameSkipped", + metadata: &ExtensionCommandMetadata{ + Commands: []Command{ + { + Name: []string{"run"}, + Flags: []Flag{ + {Name: "", Shorthand: "x", Type: "string"}, + {Name: "valid", Shorthand: "v", Type: "bool"}, + }, + }, + }, + }, + args: []string{"run", "-v"}, + want: []string{"valid"}, + }, + { + name: "FlagsAfterDoubleDashIgnored", + metadata: &ExtensionCommandMetadata{ + Commands: []Command{ + { + Name: []string{"run"}, + Flags: []Flag{ + {Name: "verbose", Shorthand: "v", Type: "bool"}, + }, + }, + }, + }, + args: []string{"run", "--", "--verbose"}, + want: nil, + }, + { + name: "BoolFlagWithEquals", + metadata: &ExtensionCommandMetadata{ + Commands: []Command{ + { + Name: []string{"run"}, + Flags: []Flag{ + {Name: "verbose", Type: "bool"}, + }, + }, + }, + }, + args: []string{"run", "--verbose=true"}, + want: []string{"verbose"}, + }, + { + name: "NoFlagsProvided", + metadata: &ExtensionCommandMetadata{ + Commands: []Command{ + { + Name: []string{"run"}, + Flags: []Flag{ + {Name: "verbose", Type: "bool"}, + }, + }, + }, + }, + args: []string{"run"}, + want: nil, + }, + { + name: "NilMetadata", + metadata: nil, + args: []string{"run", "--verbose"}, + want: nil, + }, + { + name: "CommandNotFoundReturnsNil", + metadata: &ExtensionCommandMetadata{ + Commands: []Command{ + { + Name: []string{"run"}, + Flags: []Flag{ + {Name: "verbose", Type: "bool"}, + }, + }, + }, + }, + args: []string{"unknown", "--verbose"}, + want: nil, + }, + { + name: "FlagsResolvedViaAlias", + metadata: &ExtensionCommandMetadata{ + Commands: []Command{ + { + Name: []string{"list"}, + Aliases: []string{"ls"}, + Flags: []Flag{ + {Name: "all", Shorthand: "a", Type: "bool"}, + }, + }, + }, + }, + args: []string{"ls", "-a"}, + want: []string{"all"}, + }, + { + name: "FlagWithNoShorthand", + metadata: &ExtensionCommandMetadata{ + Commands: []Command{ + { + Name: []string{"run"}, + Flags: []Flag{ + {Name: "long-only", Type: "string"}, + }, + }, + }, + }, + args: []string{"run", "--long-only", "value"}, + want: []string{"long-only"}, + }, + { + name: "CombinedShortBoolFlags", + metadata: &ExtensionCommandMetadata{ + Commands: []Command{ + { + Name: []string{"run"}, + Flags: []Flag{ + {Name: "all", Shorthand: "a", Type: "bool"}, + {Name: "brief", Shorthand: "b", Type: "bool"}, + {Name: "color", Shorthand: "c", Type: "bool"}, + }, + }, + }, + }, + args: []string{"run", "-abc"}, + want: []string{"all", "brief", "color"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := ResolveCommandFlags(tt.metadata, tt.args) + if tt.want == nil { + require.Nil(t, got) + } else { + require.ElementsMatch(t, tt.want, got) + } + }) + } +} diff --git a/cli/azd/pkg/extensions/extension_test.go b/cli/azd/pkg/extensions/extension_test.go new file mode 100644 index 00000000000..7bb15038af6 --- /dev/null +++ b/cli/azd/pkg/extensions/extension_test.go @@ -0,0 +1,208 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package extensions + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestExtension_Initialize_SignalsReadiness(t *testing.T) { + t.Parallel() + + ext := &Extension{} + ext.Initialize() + + err := ext.WaitUntilReady(t.Context()) + require.NoError(t, err) +} + +func TestExtension_Initialize_Idempotent(t *testing.T) { + t.Parallel() + + ext := &Extension{} + + // Calling Initialize twice must not panic or block. + ext.Initialize() + ext.Initialize() + + err := ext.WaitUntilReady(t.Context()) + require.NoError(t, err) +} + +func TestExtension_Fail_SignalsError(t *testing.T) { + t.Parallel() + + ext := &Extension{} + expected := errors.New("extension startup failed") + ext.Fail(expected) + + err := ext.WaitUntilReady(t.Context()) + require.ErrorIs(t, err, expected) +} + +func TestExtension_WaitUntilReady_CancelledContext(t *testing.T) { + t.Parallel() + + ext := &Extension{} + ctx, cancel := context.WithCancel(t.Context()) + cancel() // cancel immediately + + err := ext.WaitUntilReady(ctx) + require.ErrorIs(t, err, context.Canceled) +} + +func TestExtension_WaitUntilReady_Timeout(t *testing.T) { + t.Parallel() + + ext := &Extension{} + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Millisecond) + defer cancel() + + err := ext.WaitUntilReady(ctx) + require.ErrorIs(t, err, context.DeadlineExceeded) +} + +func TestExtension_HasCapability(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + capabilities []CapabilityType + query []CapabilityType + want bool + }{ + { + name: "SinglePresent", + capabilities: []CapabilityType{CustomCommandCapability, McpServerCapability}, + query: []CapabilityType{CustomCommandCapability}, + want: true, + }, + { + name: "SingleMissing", + capabilities: []CapabilityType{CustomCommandCapability}, + query: []CapabilityType{McpServerCapability}, + want: false, + }, + { + name: "MultipleAllPresent", + capabilities: []CapabilityType{CustomCommandCapability, McpServerCapability, MetadataCapability}, + query: []CapabilityType{CustomCommandCapability, McpServerCapability}, + want: true, + }, + { + name: "MultipleOneMissing", + capabilities: []CapabilityType{CustomCommandCapability}, + query: []CapabilityType{CustomCommandCapability, McpServerCapability}, + want: false, + }, + { + name: "EmptyQuery", + capabilities: []CapabilityType{CustomCommandCapability}, + query: []CapabilityType{}, + want: true, + }, + { + name: "EmptyCapabilities", + capabilities: []CapabilityType{}, + query: []CapabilityType{CustomCommandCapability}, + want: false, + }, + { + name: "NilCapabilities", + capabilities: nil, + query: []CapabilityType{CustomCommandCapability}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ext := &Extension{Capabilities: tt.capabilities} + got := ext.HasCapability(tt.query...) + require.Equal(t, tt.want, got) + }) + } +} + +func TestExtension_StdIn_ReturnsNonNil(t *testing.T) { + t.Parallel() + + ext := &Extension{} + reader := ext.StdIn() + require.NotNil(t, reader) +} + +func TestExtension_StdOut_ReturnsNonNil(t *testing.T) { + t.Parallel() + + ext := &Extension{} + writer := ext.StdOut() + require.NotNil(t, writer) +} + +func TestExtension_StdErr_ReturnsNonNil(t *testing.T) { + t.Parallel() + + ext := &Extension{} + writer := ext.StdErr() + require.NotNil(t, writer) +} + +func TestExtension_ReportedError_RoundTrip(t *testing.T) { + t.Parallel() + + ext := &Extension{} + expected := errors.New("something went wrong") + + ext.SetReportedError(expected) + got := ext.GetReportedError() + require.ErrorIs(t, got, expected) +} + +func TestExtension_GetReportedError_NilByDefault(t *testing.T) { + t.Parallel() + + ext := &Extension{} + got := ext.GetReportedError() + require.NoError(t, got) +} + +func TestExtension_ReportedError_ConcurrentAccess(t *testing.T) { + t.Parallel() + + ext := &Extension{} + const goroutines = 50 + + var wg sync.WaitGroup + wg.Add(goroutines * 2) + + for range goroutines { + err := errors.New("error from writer") + + go func() { + defer wg.Done() + ext.SetReportedError(err) + }() + + go func() { + defer wg.Done() + _ = ext.GetReportedError() + }() + } + + wg.Wait() + + // After all goroutines complete, the reported error should be one of the + // written errors (non-nil) — the exact value depends on scheduling. + got := ext.GetReportedError() + require.Error(t, got) +} diff --git a/cli/azd/pkg/extensions/runner_test.go b/cli/azd/pkg/extensions/runner_test.go new file mode 100644 index 00000000000..a31dbd268c3 --- /dev/null +++ b/cli/azd/pkg/extensions/runner_test.go @@ -0,0 +1,437 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package extensions + +import ( + "bytes" + "errors" + "os" + "path/filepath" + "slices" + "strings" + "testing" + + "github.com/azure/azure-dev/cli/azd/pkg/exec" + "github.com/azure/azure-dev/cli/azd/test/mocks/mockexec" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// ExtensionRunError +// --------------------------------------------------------------------------- + +func TestExtensionRunError_Error(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + extensionId string + inner error + want string + }{ + { + name: "BasicError", + extensionId: "my-ext", + inner: errors.New("exit code 1"), + want: "extension 'my-ext' run failed: exit code 1", + }, + { + name: "WrappedError", + extensionId: "azd.test", + inner: errors.New("signal: killed"), + want: "extension 'azd.test' run failed: signal: killed", + }, + { + name: "EmptyExtensionId", + extensionId: "", + inner: errors.New("boom"), + want: "extension '' run failed: boom", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + e := &ExtensionRunError{ExtensionId: tt.extensionId, Err: tt.inner} + require.Equal(t, tt.want, e.Error()) + }) + } +} + +func TestExtensionRunError_Unwrap(t *testing.T) { + t.Parallel() + + inner := errors.New("root cause") + e := &ExtensionRunError{ExtensionId: "test-ext", Err: inner} + + require.ErrorIs(t, e, inner) + require.Equal(t, inner, e.Unwrap()) +} + +func TestExtensionRunError_NilInner(t *testing.T) { + t.Parallel() + + e := &ExtensionRunError{ExtensionId: "ext", Err: nil} + require.Contains(t, e.Error(), "ext") + require.Nil(t, e.Unwrap()) +} + +// --------------------------------------------------------------------------- +// NewRunner +// --------------------------------------------------------------------------- + +func TestRunner_NewRunner(t *testing.T) { + t.Parallel() + + cmdRunner := mockexec.NewMockCommandRunner() + runner := NewRunner(cmdRunner) + require.NotNil(t, runner) +} + +// --------------------------------------------------------------------------- +// Runner.Invoke +// --------------------------------------------------------------------------- + +// setupConfigAndExtension creates a temp config dir, sets AZD_CONFIG_DIR, and +// creates a fake extension binary at the expected path. Returns the config dir +// and a minimal Extension value whose Path resolves correctly. +// +// Tests that call this helper must NOT use t.Parallel() because t.Setenv +// mutates process-global state. +func setupConfigAndExtension(t *testing.T) (string, *Extension) { + t.Helper() + + configDir := t.TempDir() + t.Setenv("AZD_CONFIG_DIR", configDir) + + extRelPath := filepath.Join("extensions", "test-ext", "bin", "test-ext") + extFullPath := filepath.Join(configDir, extRelPath) + + require.NoError(t, os.MkdirAll(filepath.Dir(extFullPath), 0o755)) + require.NoError(t, os.WriteFile(extFullPath, []byte("fake-binary"), 0o600)) + + ext := &Extension{ + Id: "test-ext", + Path: extRelPath, + } + + return configDir, ext +} + +func TestRunner_Invoke_Success(t *testing.T) { + configDir, ext := setupConfigAndExtension(t) + cmdRunner := mockexec.NewMockCommandRunner() + runner := NewRunner(cmdRunner) + + expectedPath := filepath.Join(configDir, ext.Path) + + cmdRunner.When(func(args exec.RunArgs, command string) bool { + return args.Cmd == expectedPath + }).RespondFn(func(args exec.RunArgs) (exec.RunResult, error) { + return exec.RunResult{ExitCode: 0, Stdout: "ok"}, nil + }) + + result, err := runner.Invoke(t.Context(), ext, &InvokeOptions{ + Args: []string{"hello", "world"}, + }) + + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 0, result.ExitCode) + require.Equal(t, "ok", result.Stdout) +} + +func TestRunner_Invoke_MissingExtensionPath(t *testing.T) { + configDir := t.TempDir() + t.Setenv("AZD_CONFIG_DIR", configDir) + + ext := &Extension{ + Id: "nonexistent", + Path: filepath.Join("extensions", "missing", "binary"), + } + + cmdRunner := mockexec.NewMockCommandRunner() + runner := NewRunner(cmdRunner) + + result, err := runner.Invoke(t.Context(), ext, &InvokeOptions{}) + require.Nil(t, result) + require.Error(t, err) + require.Contains(t, err.Error(), "extension path") + require.Contains(t, err.Error(), "not found") +} + +func TestRunner_Invoke_ArgsPassedThrough(t *testing.T) { + configDir, ext := setupConfigAndExtension(t) + cmdRunner := mockexec.NewMockCommandRunner() + runner := NewRunner(cmdRunner) + + expectedPath := filepath.Join(configDir, ext.Path) + wantArgs := []string{"serve", "--port", "8080"} + + var capturedArgs exec.RunArgs + + cmdRunner.When(func(args exec.RunArgs, command string) bool { + return args.Cmd == expectedPath + }).RespondFn(func(args exec.RunArgs) (exec.RunResult, error) { + capturedArgs = args + return exec.RunResult{ExitCode: 0}, nil + }) + + _, err := runner.Invoke(t.Context(), ext, &InvokeOptions{ + Args: wantArgs, + }) + require.NoError(t, err) + require.Equal(t, wantArgs, capturedArgs.Args) +} + +func TestRunner_Invoke_EnvVariablePropagation(t *testing.T) { + tests := []struct { + name string + options InvokeOptions + wantEnv []string + }{ + { + name: "DebugTrue", + options: InvokeOptions{ + Debug: true, + }, + wantEnv: []string{"AZD_DEBUG=true"}, + }, + { + name: "NoPromptTrue", + options: InvokeOptions{ + NoPrompt: true, + }, + wantEnv: []string{"AZD_NO_PROMPT=true"}, + }, + { + name: "CwdSet", + options: InvokeOptions{ + Cwd: "/my/project", + }, + wantEnv: []string{"AZD_CWD=/my/project"}, + }, + { + name: "EnvironmentSet", + options: InvokeOptions{ + Environment: "dev", + }, + wantEnv: []string{"AZD_ENVIRONMENT=dev"}, + }, + { + name: "AllFlags", + options: InvokeOptions{ + Debug: true, + NoPrompt: true, + Cwd: "/work", + Environment: "staging", + }, + wantEnv: []string{ + "AZD_DEBUG=true", + "AZD_NO_PROMPT=true", + "AZD_CWD=/work", + "AZD_ENVIRONMENT=staging", + }, + }, + { + name: "NoFlags", + options: InvokeOptions{}, + wantEnv: nil, + }, + { + name: "ExistingEnvPreserved", + options: InvokeOptions{ + Env: []string{"CUSTOM_VAR=hello"}, + Debug: true, + }, + wantEnv: []string{"CUSTOM_VAR=hello", "AZD_DEBUG=true"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Each subtest gets its own config dir via setupConfigAndExtension, + // which calls t.Setenv - cannot use t.Parallel(). + _, ext := setupConfigAndExtension(t) + cmdRunner := mockexec.NewMockCommandRunner() + runner := NewRunner(cmdRunner) + + var capturedArgs exec.RunArgs + + cmdRunner.When(func(args exec.RunArgs, command string) bool { + return true + }).RespondFn(func(args exec.RunArgs) (exec.RunResult, error) { + capturedArgs = args + return exec.RunResult{ExitCode: 0}, nil + }) + + // Make a copy so the table entry isn't mutated across iterations. + opts := tt.options + _, err := runner.Invoke(t.Context(), ext, &opts) + require.NoError(t, err) + + if tt.wantEnv == nil { + require.Empty(t, capturedArgs.Env) + } else { + for _, expected := range tt.wantEnv { + require.True(t, + slices.Contains(capturedArgs.Env, expected), + "expected env var %q not found in %v", expected, capturedArgs.Env, + ) + } + require.Len(t, capturedArgs.Env, len(tt.wantEnv)) + } + }) + } +} + +func TestRunner_Invoke_InteractiveMode(t *testing.T) { + _, ext := setupConfigAndExtension(t) + cmdRunner := mockexec.NewMockCommandRunner() + runner := NewRunner(cmdRunner) + + var capturedArgs exec.RunArgs + + cmdRunner.When(func(args exec.RunArgs, command string) bool { + return true + }).RespondFn(func(args exec.RunArgs) (exec.RunResult, error) { + capturedArgs = args + return exec.RunResult{ExitCode: 0}, nil + }) + + _, err := runner.Invoke(t.Context(), ext, &InvokeOptions{ + Interactive: true, + }) + require.NoError(t, err) + require.True(t, capturedArgs.Interactive) + // In interactive mode, custom streams should not be set on RunArgs + require.Nil(t, capturedArgs.StdIn) + require.Nil(t, capturedArgs.StdOut) +} + +func TestRunner_Invoke_NonInteractiveWithStreams(t *testing.T) { + _, ext := setupConfigAndExtension(t) + cmdRunner := mockexec.NewMockCommandRunner() + runner := NewRunner(cmdRunner) + + var capturedArgs exec.RunArgs + + stdinBuf := strings.NewReader("input data") + stdoutBuf := &bytes.Buffer{} + stderrBuf := &bytes.Buffer{} + + cmdRunner.When(func(args exec.RunArgs, command string) bool { + return true + }).RespondFn(func(args exec.RunArgs) (exec.RunResult, error) { + capturedArgs = args + return exec.RunResult{ExitCode: 0}, nil + }) + + _, err := runner.Invoke(t.Context(), ext, &InvokeOptions{ + Interactive: false, + StdIn: stdinBuf, + StdOut: stdoutBuf, + StdErr: stderrBuf, + }) + require.NoError(t, err) + require.False(t, capturedArgs.Interactive) + require.Equal(t, stdinBuf, capturedArgs.StdIn) + require.Equal(t, stdoutBuf, capturedArgs.StdOut) + // RunArgs.Stderr (note: lowercase 'e') maps to WithStdErr + require.Equal(t, stderrBuf, capturedArgs.Stderr) +} + +func TestRunner_Invoke_NonInteractiveNilStreams(t *testing.T) { + _, ext := setupConfigAndExtension(t) + cmdRunner := mockexec.NewMockCommandRunner() + runner := NewRunner(cmdRunner) + + var capturedArgs exec.RunArgs + + cmdRunner.When(func(args exec.RunArgs, command string) bool { + return true + }).RespondFn(func(args exec.RunArgs) (exec.RunResult, error) { + capturedArgs = args + return exec.RunResult{ExitCode: 0}, nil + }) + + _, err := runner.Invoke(t.Context(), ext, &InvokeOptions{ + Interactive: false, + StdIn: nil, + StdOut: nil, + StdErr: nil, + }) + require.NoError(t, err) + require.False(t, capturedArgs.Interactive) + require.Nil(t, capturedArgs.StdIn) + require.Nil(t, capturedArgs.StdOut) + require.Nil(t, capturedArgs.Stderr) +} + +func TestRunner_Invoke_CommandError_WrapsInExtensionRunError(t *testing.T) { + _, ext := setupConfigAndExtension(t) + cmdRunner := mockexec.NewMockCommandRunner() + runner := NewRunner(cmdRunner) + + cmdError := errors.New("exit status 42") + + cmdRunner.When(func(args exec.RunArgs, command string) bool { + return true + }).RespondFn(func(args exec.RunArgs) (exec.RunResult, error) { + return exec.RunResult{ExitCode: 42, Stderr: "something broke"}, cmdError + }) + + result, err := runner.Invoke(t.Context(), ext, &InvokeOptions{}) + require.Error(t, err) + require.NotNil(t, result) + require.Equal(t, 42, result.ExitCode) + + // Verify we get an ExtensionRunError + var runErr *ExtensionRunError + require.ErrorAs(t, err, &runErr) + require.Equal(t, ext.Id, runErr.ExtensionId) + require.ErrorIs(t, runErr, cmdError) +} + +func TestRunner_Invoke_ExtensionPathResolution(t *testing.T) { + configDir, ext := setupConfigAndExtension(t) + cmdRunner := mockexec.NewMockCommandRunner() + runner := NewRunner(cmdRunner) + + expectedCmd := filepath.Join(configDir, ext.Path) + var capturedCmd string + + cmdRunner.When(func(args exec.RunArgs, command string) bool { + return true + }).RespondFn(func(args exec.RunArgs) (exec.RunResult, error) { + capturedCmd = args.Cmd + return exec.RunResult{ExitCode: 0}, nil + }) + + _, err := runner.Invoke(t.Context(), ext, &InvokeOptions{}) + require.NoError(t, err) + require.Equal(t, expectedCmd, capturedCmd) +} + +func TestRunner_Invoke_EnsureInit_Called(t *testing.T) { + _, ext := setupConfigAndExtension(t) + cmdRunner := mockexec.NewMockCommandRunner() + runner := NewRunner(cmdRunner) + + // Extension starts uninitialized + require.False(t, ext.initialized) + + cmdRunner.When(func(args exec.RunArgs, command string) bool { + return true + }).RespondFn(func(args exec.RunArgs) (exec.RunResult, error) { + return exec.RunResult{ExitCode: 0}, nil + }) + + _, err := runner.Invoke(t.Context(), ext, &InvokeOptions{}) + require.NoError(t, err) + + // After Invoke, extension should be initialized + require.True(t, ext.initialized) +} diff --git a/cli/azd/pkg/ioc/container_coverage_test.go b/cli/azd/pkg/ioc/container_coverage_test.go new file mode 100644 index 00000000000..ea33b1df6d3 --- /dev/null +++ b/cli/azd/pkg/ioc/container_coverage_test.go @@ -0,0 +1,812 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package ioc + +import ( + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +// ---------- helper types for tests ---------- + +type greeter interface { + Greet() string +} + +type englishGreeter struct{ name string } + +func (g *englishGreeter) Greet() string { return "Hello, " + g.name } + +type spanishGreeter struct{ name string } + +func (g *spanishGreeter) Greet() string { return "Hola, " + g.name } + +type counterService struct { + calls int +} + +func newCounterService() *counterService { return &counterService{} } + +type depService struct { + counter *counterService +} + +func newDepService(c *counterService) *depService { return &depService{counter: c} } + +// fillTarget is used to test Fill(). +type fillTarget struct { + Greeter greeter `container:"type"` +} + +// ---------- Named registration & resolution ---------- + +func Test_Named_Singleton_Register_Resolve(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + regName string + resolver any + wantGreet string + }{ + { + name: "EnglishGreeter", + regName: "english", + resolver: func() greeter { return &englishGreeter{name: "World"} }, + wantGreet: "Hello, World", + }, + { + name: "SpanishGreeter", + regName: "spanish", + resolver: func() greeter { return &spanishGreeter{name: "Mundo"} }, + wantGreet: "Hola, Mundo", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + err := c.RegisterNamedSingleton(tt.regName, tt.resolver) + require.NoError(t, err) + + var resolved greeter + err = c.ResolveNamed(tt.regName, &resolved) + require.NoError(t, err) + require.Equal(t, tt.wantGreet, resolved.Greet()) + + // Singleton: second resolve returns same pointer + var resolved2 greeter + err = c.ResolveNamed(tt.regName, &resolved2) + require.NoError(t, err) + require.Same(t, resolved, resolved2) + }) + } +} + +func Test_MustRegisterNamedSingleton(t *testing.T) { + t.Parallel() + + t.Run("Success", func(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + c.MustRegisterNamedSingleton("en", func() greeter { + return &englishGreeter{name: "test"} + }) + + var g greeter + err := c.ResolveNamed("en", &g) + require.NoError(t, err) + require.Equal(t, "Hello, test", g.Greet()) + }) + + t.Run("PanicsOnInvalidResolver", func(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + require.Panics(t, func() { + c.MustRegisterNamedSingleton("bad", "not-a-func") + }) + }) +} + +func Test_Named_Transient_Register_Resolve(t *testing.T) { + t.Parallel() + + c := NewNestedContainer(nil) + err := c.RegisterNamedTransient("counter", func() *counterService { + return newCounterService() + }) + require.NoError(t, err) + + var inst1 *counterService + err = c.ResolveNamed("counter", &inst1) + require.NoError(t, err) + require.NotNil(t, inst1) + + var inst2 *counterService + err = c.ResolveNamed("counter", &inst2) + require.NoError(t, err) + require.NotNil(t, inst2) + + // Transient: different instances each time + require.NotSame(t, inst1, inst2) +} + +func Test_MustRegisterNamedTransient(t *testing.T) { + t.Parallel() + + t.Run("Success", func(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + c.MustRegisterNamedTransient("svc", func() *counterService { + return newCounterService() + }) + + var inst *counterService + err := c.ResolveNamed("svc", &inst) + require.NoError(t, err) + require.NotNil(t, inst) + }) + + t.Run("PanicsOnInvalidResolver", func(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + require.Panics(t, func() { + c.MustRegisterNamedTransient("bad", 42) + }) + }) +} + +// ---------- RegisterTransient (error-returning) ---------- + +func Test_RegisterTransient(t *testing.T) { + t.Parallel() + + t.Run("Success", func(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + err := c.RegisterTransient(func() *counterService { + return newCounterService() + }) + require.NoError(t, err) + + var inst *counterService + err = c.Resolve(&inst) + require.NoError(t, err) + require.NotNil(t, inst) + }) + + t.Run("ErrorOnInvalidResolver", func(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + err := c.RegisterTransient("not-a-function") + require.Error(t, err) + }) +} + +// ---------- RegisterNamedInstance ---------- + +func Test_RegisterNamedInstance(t *testing.T) { + t.Parallel() + + c := NewNestedContainer(nil) + eng := &englishGreeter{name: "Named"} + spa := &spanishGreeter{name: "Nombrado"} + RegisterNamedInstance[greeter](c, "english", eng) + RegisterNamedInstance[greeter](c, "spanish", spa) + + var resolvedEn greeter + err := c.ResolveNamed("english", &resolvedEn) + require.NoError(t, err) + require.Equal(t, "Hello, Named", resolvedEn.Greet()) + + var resolvedEs greeter + err = c.ResolveNamed("spanish", &resolvedEs) + require.NoError(t, err) + require.Equal(t, "Hola, Nombrado", resolvedEs.Greet()) +} + +// ---------- ResolveNamed error cases ---------- + +func Test_ResolveNamed_Errors(t *testing.T) { + t.Parallel() + + t.Run("UnregisteredName", func(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + var g greeter + err := c.ResolveNamed("nope", &g) + require.Error(t, err) + require.True(t, errors.Is(err, ErrResolveInstance)) + }) + + t.Run("ResolverReturnsError", func(t *testing.T) { + t.Parallel() + sentinel := fmt.Errorf("custom resolution error") + c := NewNestedContainer(nil) + c.MustRegisterNamedSingleton("fail", func() (greeter, error) { + return nil, sentinel + }) + + var g greeter + err := c.ResolveNamed("fail", &g) + require.Error(t, err) + // The underlying error should propagate, not be wrapped as container error + require.ErrorIs(t, err, sentinel) + require.False(t, errors.Is(err, ErrResolveInstance)) + }) +} + +// ---------- Invoke ---------- + +func Test_Invoke(t *testing.T) { + t.Parallel() + + t.Run("InjectsRegisteredDeps", func(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + c.MustRegisterSingleton(func() *counterService { + return &counterService{calls: 42} + }) + + var captured *counterService + err := c.Invoke(func(cs *counterService) { + captured = cs + }) + require.NoError(t, err) + require.NotNil(t, captured) + require.Equal(t, 42, captured.calls) + }) + + t.Run("ErrorWhenDepMissing", func(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + err := c.Invoke(func(cs *counterService) { + t.Fatal("should not be called") + }) + require.Error(t, err) + }) +} + +// ---------- Fill ---------- + +func Test_Fill(t *testing.T) { + t.Parallel() + + t.Run("FillByType", func(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + c.MustRegisterSingleton(func() greeter { + return &englishGreeter{name: "Fill"} + }) + + target := &fillTarget{} + err := c.Fill(target) + require.NoError(t, err) + require.NotNil(t, target.Greeter) + require.Equal(t, "Hello, Fill", target.Greeter.Greet()) + }) + + t.Run("ErrorWhenUnregistered", func(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + target := &fillTarget{} + err := c.Fill(target) + require.Error(t, err) + }) +} + +// ---------- RegisterSingletonAndInvoke ---------- + +func Test_RegisterSingletonAndInvoke(t *testing.T) { + t.Parallel() + + t.Run("Success", func(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + err := c.RegisterSingletonAndInvoke(func() *counterService { + return &counterService{calls: 7} + }) + require.NoError(t, err) + + var inst *counterService + err = c.Resolve(&inst) + require.NoError(t, err) + require.Equal(t, 7, inst.calls) + }) + + t.Run("ErrorOnInvalidResolver", func(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + err := c.RegisterSingletonAndInvoke("not-a-func") + require.Error(t, err) + }) +} + +// ---------- RegisterSingleton (error-returning) ---------- + +func Test_RegisterSingleton_Error(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + err := c.RegisterSingleton("not-a-function") + require.Error(t, err) +} + +// ---------- Named Scoped ---------- + +func Test_RegisterNamedScoped(t *testing.T) { + t.Parallel() + + t.Run("Success", func(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + err := c.RegisterNamedScoped("svc", func() *counterService { + return newCounterService() + }) + require.NoError(t, err) + + var inst *counterService + err = c.ResolveNamed("svc", &inst) + require.NoError(t, err) + require.NotNil(t, inst) + }) + + t.Run("ErrorOnInvalidResolver", func(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + err := c.RegisterNamedScoped("bad", 123) + require.Error(t, err) + }) +} + +func Test_MustRegisterNamedScoped(t *testing.T) { + t.Parallel() + + t.Run("Success", func(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + c.MustRegisterNamedScoped("svc", func() *counterService { + return newCounterService() + }) + + var inst *counterService + err := c.ResolveNamed("svc", &inst) + require.NoError(t, err) + require.NotNil(t, inst) + }) + + t.Run("PanicsOnInvalidResolver", func(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + require.Panics(t, func() { + c.MustRegisterNamedScoped("bad", 99) + }) + }) +} + +// ---------- MustRegisterScoped panics ---------- + +func Test_MustRegisterScoped_Panics(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + require.Panics(t, func() { + c.MustRegisterScoped("not-a-function") + }) +} + +// ---------- NewScope with named scoped bindings ---------- + +func Test_NewScope_NamedScopedBindings(t *testing.T) { + t.Parallel() + + root := NewNestedContainer(nil) + root.MustRegisterNamedScoped("greeter", func() greeter { + return &englishGreeter{name: "scoped"} + }) + + scope1, err := root.NewScope() + require.NoError(t, err) + + var g1 greeter + err = scope1.ResolveNamed("greeter", &g1) + require.NoError(t, err) + require.Equal(t, "Hello, scoped", g1.Greet()) + + // Same scope resolves same singleton + var g1again greeter + err = scope1.ResolveNamed("greeter", &g1again) + require.NoError(t, err) + require.Same(t, g1, g1again) + + // Different scope gets different instance + scope2, err := root.NewScope() + require.NoError(t, err) + + var g2 greeter + err = scope2.ResolveNamed("greeter", &g2) + require.NoError(t, err) + require.NotSame(t, g1, g2) +} + +// ---------- NewScopeRegistrationsOnly ---------- + +func Test_NewScopeRegistrationsOnly(t *testing.T) { + t.Parallel() + + t.Run("ResetsSingletonInstances", func(t *testing.T) { + t.Parallel() + root := NewNestedContainer(nil) + root.MustRegisterSingleton(func() *counterService { + return newCounterService() + }) + + // Resolve in root to cache the singleton + var rootInst *counterService + err := root.Resolve(&rootInst) + require.NoError(t, err) + require.NotNil(t, rootInst) + + // Create scope from registrations only - resets cached instances + scope, err := root.NewScopeRegistrationsOnly() + require.NoError(t, err) + + var scopeInst *counterService + err = scope.Resolve(&scopeInst) + require.NoError(t, err) + require.NotNil(t, scopeInst) + + // Instances should differ because the scope got fresh registrations + require.NotSame(t, rootInst, scopeInst) + }) + + t.Run("WithScopedBindings", func(t *testing.T) { + t.Parallel() + root := NewNestedContainer(nil) + root.MustRegisterScoped(func() *counterService { + return newCounterService() + }) + + scope1, err := root.NewScopeRegistrationsOnly() + require.NoError(t, err) + + var inst1 *counterService + err = scope1.Resolve(&inst1) + require.NoError(t, err) + + scope2, err := root.NewScopeRegistrationsOnly() + require.NoError(t, err) + + var inst2 *counterService + err = scope2.Resolve(&inst2) + require.NoError(t, err) + + // Different scopes, different instances + require.NotSame(t, inst1, inst2) + }) + + t.Run("WithNamedScopedBindings", func(t *testing.T) { + t.Parallel() + root := NewNestedContainer(nil) + root.MustRegisterNamedScoped("counter", func() *counterService { + return newCounterService() + }) + + scope, err := root.NewScopeRegistrationsOnly() + require.NoError(t, err) + + var inst *counterService + err = scope.ResolveNamed("counter", &inst) + require.NoError(t, err) + require.NotNil(t, inst) + }) + + t.Run("NilParent", func(t *testing.T) { + t.Parallel() + // NewRegistrationsOnly(nil) should produce a working empty container + c := NewRegistrationsOnly(nil) + require.NotNil(t, c) + + // ServiceLocator should still self-register + var sl ServiceLocator + err := c.Resolve(&sl) + require.NoError(t, err) + require.NotNil(t, sl) + }) +} + +// ---------- ServiceLocator self-registration ---------- + +func Test_ServiceLocator_SelfRegistered(t *testing.T) { + t.Parallel() + + t.Run("InNewNestedContainer", func(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + + var sl ServiceLocator + err := c.Resolve(&sl) + require.NoError(t, err) + require.NotNil(t, sl) + require.Same(t, c, sl) + }) + + t.Run("InNewRegistrationsOnly", func(t *testing.T) { + t.Parallel() + parent := NewNestedContainer(nil) + child := NewRegistrationsOnly(parent) + + var sl ServiceLocator + err := child.Resolve(&sl) + require.NoError(t, err) + require.Same(t, child, sl) + }) +} + +// ---------- ServiceLocator interface methods ---------- + +func Test_ServiceLocator_Methods(t *testing.T) { + t.Parallel() + + c := NewNestedContainer(nil) + c.MustRegisterSingleton(func() *counterService { + return &counterService{calls: 10} + }) + RegisterNamedInstance[greeter](c, "en", &englishGreeter{name: "SL"}) + + var sl ServiceLocator + err := c.Resolve(&sl) + require.NoError(t, err) + + t.Run("Resolve", func(t *testing.T) { + var cs *counterService + err := sl.Resolve(&cs) + require.NoError(t, err) + require.Equal(t, 10, cs.calls) + }) + + t.Run("ResolveNamed", func(t *testing.T) { + var g greeter + err := sl.ResolveNamed("en", &g) + require.NoError(t, err) + require.Equal(t, "Hello, SL", g.Greet()) + }) + + t.Run("Invoke", func(t *testing.T) { + var captured *counterService + err := sl.Invoke(func(cs *counterService) { + captured = cs + }) + require.NoError(t, err) + require.Equal(t, 10, captured.calls) + }) +} + +// ---------- Must* panic on invalid resolver ---------- + +func Test_MustRegisterSingleton_Panics(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + require.Panics(t, func() { + c.MustRegisterSingleton("invalid") + }) +} + +func Test_MustRegisterTransient_Panics(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + require.Panics(t, func() { + c.MustRegisterTransient(42) + }) +} + +// ---------- RegisterNamedSingleton error path ---------- + +func Test_RegisterNamedSingleton_Error(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + err := c.RegisterNamedSingleton("bad", "not-a-func") + require.Error(t, err) +} + +// ---------- RegisterNamedTransient error path ---------- + +func Test_RegisterNamedTransient_Error(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + err := c.RegisterNamedTransient("bad", 123) + require.Error(t, err) +} + +// ---------- RegisterScoped error path ---------- + +func Test_RegisterScoped_Error(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + err := c.RegisterScoped("not-a-func") + require.Error(t, err) +} + +// ---------- Dependency injection chain ---------- + +func Test_DependencyChain(t *testing.T) { + t.Parallel() + c := NewNestedContainer(nil) + c.MustRegisterSingleton(func() *counterService { + return &counterService{calls: 5} + }) + c.MustRegisterSingleton(newDepService) + + var dep *depService + err := c.Resolve(&dep) + require.NoError(t, err) + require.NotNil(t, dep) + require.NotNil(t, dep.counter) + require.Equal(t, 5, dep.counter.calls) +} + +// ---------- Nested container inherits from parent ---------- + +func Test_NewNestedContainer_InheritsParent(t *testing.T) { + t.Parallel() + + parent := NewNestedContainer(nil) + parent.MustRegisterSingleton(func() *counterService { + return &counterService{calls: 99} + }) + + // Resolve in parent first to cache the singleton + var parentInst *counterService + err := parent.Resolve(&parentInst) + require.NoError(t, err) + + child := NewNestedContainer(parent) + var childInst *counterService + err = child.Resolve(&childInst) + require.NoError(t, err) + // Child inherits parent's cached singleton + require.Same(t, parentInst, childInst) +} + +// ---------- inspectResolveError ---------- + +func Test_InspectResolveError(t *testing.T) { + t.Parallel() + + t.Run("ContainerError", func(t *testing.T) { + t.Parallel() + err := inspectResolveError(fmt.Errorf("container: no binding found")) + require.ErrorIs(t, err, ErrResolveInstance) + }) + + t.Run("WrappedContainerError", func(t *testing.T) { + t.Parallel() + inner := fmt.Errorf("container: something broke") + wrapped := fmt.Errorf("outer: %w", inner) + err := inspectResolveError(wrapped) + require.ErrorIs(t, err, ErrResolveInstance) + }) + + t.Run("NonContainerError", func(t *testing.T) { + t.Parallel() + sentinel := fmt.Errorf("custom app error") + err := inspectResolveError(sentinel) + require.Equal(t, sentinel, err) + require.False(t, errors.Is(err, ErrResolveInstance)) + }) + + t.Run("WrappedNonContainerError", func(t *testing.T) { + t.Parallel() + inner := fmt.Errorf("app error inside") + wrapped := fmt.Errorf("outer: %w", inner) + err := inspectResolveError(wrapped) + // Should unwrap and return the inner non-container error + require.Equal(t, inner, err) + }) +} + +// ---------- Multiple named registrations for same type ---------- + +func Test_MultipleNamedRegistrations(t *testing.T) { + t.Parallel() + + c := NewNestedContainer(nil) + RegisterNamedInstance[greeter](c, "en", &englishGreeter{name: "A"}) + RegisterNamedInstance[greeter](c, "es", &spanishGreeter{name: "B"}) + + var en greeter + err := c.ResolveNamed("en", &en) + require.NoError(t, err) + require.Equal(t, "Hello, A", en.Greet()) + + var es greeter + err = c.ResolveNamed("es", &es) + require.NoError(t, err) + require.Equal(t, "Hola, B", es.Greet()) +} + +// ---------- Scoped with mixed named and unnamed ---------- + +func Test_NewScope_MixedScopedBindings(t *testing.T) { + t.Parallel() + + root := NewNestedContainer(nil) + // Unnamed scoped + root.MustRegisterScoped(func() *counterService { + return newCounterService() + }) + // Named scoped + root.MustRegisterNamedScoped("named-counter", func() *counterService { + return newCounterService() + }) + + scope, err := root.NewScope() + require.NoError(t, err) + + var unnamed *counterService + err = scope.Resolve(&unnamed) + require.NoError(t, err) + require.NotNil(t, unnamed) + + var named *counterService + err = scope.ResolveNamed("named-counter", &named) + require.NoError(t, err) + require.NotNil(t, named) + + // They're different registrations so different instances + require.NotSame(t, unnamed, named) +} + +// ---------- NewScopeRegistrationsOnly with mixed ---------- + +func Test_NewScopeRegistrationsOnly_MixedScopedBindings(t *testing.T) { + t.Parallel() + + root := NewNestedContainer(nil) + root.MustRegisterScoped(func() *counterService { + return newCounterService() + }) + root.MustRegisterNamedScoped("named-counter", func() *counterService { + return newCounterService() + }) + + scope, err := root.NewScopeRegistrationsOnly() + require.NoError(t, err) + + var unnamed *counterService + err = scope.Resolve(&unnamed) + require.NoError(t, err) + require.NotNil(t, unnamed) + + var named *counterService + err = scope.ResolveNamed("named-counter", &named) + require.NoError(t, err) + require.NotNil(t, named) +} + +// ---------- Transient in nested scope ---------- + +func Test_TransientInNestedScope(t *testing.T) { + t.Parallel() + + root := NewNestedContainer(nil) + root.MustRegisterTransient(func() *counterService { + return newCounterService() + }) + + scope, err := root.NewScope() + require.NoError(t, err) + + var inst1 *counterService + err = scope.Resolve(&inst1) + require.NoError(t, err) + + var inst2 *counterService + err = scope.Resolve(&inst2) + require.NoError(t, err) + + require.NotSame(t, inst1, inst2) +}