diff --git a/cli/azd/internal/grpcserver/account_service.go b/cli/azd/internal/grpcserver/account_service.go index fcfb3243068..59cdfa2d9e0 100644 --- a/cli/azd/internal/grpcserver/account_service.go +++ b/cli/azd/internal/grpcserver/account_service.go @@ -8,6 +8,8 @@ import ( "github.com/azure/azure-dev/cli/azd/pkg/account" "github.com/azure/azure-dev/cli/azd/pkg/azdext" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) type accountService struct { @@ -63,6 +65,10 @@ func (s *accountService) LookupTenant( ctx context.Context, req *azdext.LookupTenantRequest, ) (*azdext.LookupTenantResponse, error) { + if req.SubscriptionId == "" { + return nil, status.Error(codes.InvalidArgument, "subscription id is required") + } + tenantId, err := s.subscriptionsManager.LookupTenant(ctx, req.SubscriptionId) if err != nil { return nil, err diff --git a/cli/azd/internal/grpcserver/compose_service.go b/cli/azd/internal/grpcserver/compose_service.go index 2d84d2a3bc2..9f555ea320a 100644 --- a/cli/azd/internal/grpcserver/compose_service.go +++ b/cli/azd/internal/grpcserver/compose_service.go @@ -136,7 +136,7 @@ func (c *composeService) GetResourceType( context.Context, *azdext.GetResourceTypeRequest, ) (*azdext.GetResourceTypeResponse, error) { - panic("unimplemented") + return nil, status.Error(codes.Unimplemented, "GetResourceType is not yet implemented") } // ListResourceTypes lists all available resource types. diff --git a/cli/azd/internal/grpcserver/compose_service_test.go b/cli/azd/internal/grpcserver/compose_service_test.go index 748dd108e77..c139ecdfc65 100644 --- a/cli/azd/internal/grpcserver/compose_service_test.go +++ b/cli/azd/internal/grpcserver/compose_service_test.go @@ -18,6 +18,8 @@ import ( "github.com/azure/azure-dev/cli/azd/test/mocks" "github.com/azure/azure-dev/cli/azd/test/mocks/mockenv" "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func Test_ComposeService_AddResource(t *testing.T) { @@ -229,3 +231,27 @@ func Test_Test_ComposeService_ListResourceTypes(t *testing.T) { require.NotEmpty(t, randomResource.DisplayName) require.NotEmpty(t, randomResource.Type) } + +func Test_ComposeService_GetResourceType_Unimplemented(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + lazyAzdContext := lazy.NewLazy(func() (*azdcontext.AzdContext, error) { + return nil, azdcontext.ErrNoProject + }) + env := environment.New("test") + envManager := &mockenv.MockEnvManager{} + lazyEnvManager := lazy.NewLazy(func() (environment.Manager, error) { + return envManager, nil + }) + lazyEnv := lazy.NewLazy(func() (*environment.Environment, error) { + return env, nil + }) + service := NewComposeService(lazyAzdContext, lazyEnv, lazyEnvManager) + + _, err := service.GetResourceType(*mockContext.Context, &azdext.GetResourceTypeRequest{}) + require.Error(t, err) + + st, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, codes.Unimplemented, st.Code()) + require.Contains(t, st.Message(), "not yet implemented") +} diff --git a/cli/azd/internal/grpcserver/container_service.go b/cli/azd/internal/grpcserver/container_service.go index 73dcb9eeb6a..dba079449a5 100644 --- a/cli/azd/internal/grpcserver/container_service.go +++ b/cli/azd/internal/grpcserver/container_service.go @@ -14,6 +14,8 @@ import ( "github.com/azure/azure-dev/cli/azd/pkg/input" "github.com/azure/azure-dev/cli/azd/pkg/lazy" "github.com/azure/azure-dev/cli/azd/pkg/project" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) type containerService struct { @@ -45,6 +47,10 @@ func (c *containerService) Build( ctx context.Context, req *azdext.ContainerBuildRequest, ) (*azdext.ContainerBuildResponse, error) { + if req.ServiceName == "" { + return nil, status.Error(codes.InvalidArgument, "service name is required") + } + projectConfig, err := c.lazyProject.GetValue() if err != nil { return nil, err @@ -52,7 +58,8 @@ func (c *containerService) Build( serviceConfig, has := projectConfig.Services[req.ServiceName] if !has { - return nil, fmt.Errorf("service %q not found in project configuration", req.ServiceName) + return nil, status.Errorf(codes.NotFound, + "service %q not found in project configuration", req.ServiceName) } containerHelper, err := c.lazyContainerHelper.GetValue() @@ -95,6 +102,10 @@ func (c *containerService) Package( ctx context.Context, req *azdext.ContainerPackageRequest, ) (*azdext.ContainerPackageResponse, error) { + if req.ServiceName == "" { + return nil, status.Error(codes.InvalidArgument, "service name is required") + } + projectConfig, err := c.lazyProject.GetValue() if err != nil { return nil, err @@ -102,7 +113,8 @@ func (c *containerService) Package( serviceConfig, has := projectConfig.Services[req.ServiceName] if !has { - return nil, fmt.Errorf("service %q not found in project configuration", req.ServiceName) + return nil, status.Errorf(codes.NotFound, + "service %q not found in project configuration", req.ServiceName) } containerHelper, err := c.lazyContainerHelper.GetValue() @@ -145,6 +157,10 @@ func (c *containerService) Publish( ctx context.Context, req *azdext.ContainerPublishRequest, ) (*azdext.ContainerPublishResponse, error) { + if req.ServiceName == "" { + return nil, status.Error(codes.InvalidArgument, "service name is required") + } + projectConfig, err := c.lazyProject.GetValue() if err != nil { return nil, err @@ -152,7 +168,8 @@ func (c *containerService) Publish( serviceConfig, has := projectConfig.Services[req.ServiceName] if !has { - return nil, fmt.Errorf("service %q not found in project configuration", req.ServiceName) + return nil, status.Errorf(codes.NotFound, + "service %q not found in project configuration", req.ServiceName) } containerHelper, err := c.lazyContainerHelper.GetValue() diff --git a/cli/azd/internal/grpcserver/copilot_service.go b/cli/azd/internal/grpcserver/copilot_service.go index 3180d38e48f..c5369727d6d 100644 --- a/cli/azd/internal/grpcserver/copilot_service.go +++ b/cli/azd/internal/grpcserver/copilot_service.go @@ -359,6 +359,13 @@ func convertFileChangeType(ct watch.FileChangeType) azdext.CopilotFileChangeType // convertSessionEvent converts a Copilot SDK SessionEvent to the proto representation. // Event data is marshaled to JSON then converted to google.protobuf.Struct for // dynamic, schema-free transport. +// +// Errors during data conversion are intentionally logged and swallowed rather than propagated. +// This function feeds into a streaming response, and its signature intentionally omits an error +// return to support graceful degradation: callers always receive a valid event with at least +// the Type and Timestamp fields populated, even when the Data payload cannot be converted. +// Failing the entire stream for a single malformed event would be worse than delivering +// partial data. func convertSessionEvent(event agent.SessionEvent) *azdext.CopilotSessionEvent { protoEvent := &azdext.CopilotSessionEvent{ Type: string(event.Type), @@ -368,19 +375,28 @@ func convertSessionEvent(event agent.SessionEvent) *azdext.CopilotSessionEvent { // Marshal event.Data to JSON, then to protobuf Struct jsonBytes, err := json.Marshal(event.Data) if err != nil { - log.Printf("[copilot-service] failed to marshal event data: %v", err) + log.Printf( + "[copilot-service] failed to marshal event data for event type %q: %v", + event.Type, err, + ) return protoEvent } var dataMap map[string]any if err := json.Unmarshal(jsonBytes, &dataMap); err != nil { - log.Printf("[copilot-service] failed to unmarshal event data to map: %v", err) + log.Printf( + "[copilot-service] failed to unmarshal event data to map for event type %q: %v", + event.Type, err, + ) return protoEvent } protoStruct, err := structpb.NewStruct(dataMap) if err != nil { - log.Printf("[copilot-service] failed to create protobuf struct: %v", err) + log.Printf( + "[copilot-service] failed to create protobuf struct for event type %q: %v", + event.Type, err, + ) return protoEvent } diff --git a/cli/azd/internal/grpcserver/environment_service.go b/cli/azd/internal/grpcserver/environment_service.go index 2b52c9eb4ad..afe5a4355bc 100644 --- a/cli/azd/internal/grpcserver/environment_service.go +++ b/cli/azd/internal/grpcserver/environment_service.go @@ -12,6 +12,8 @@ import ( "github.com/azure/azure-dev/cli/azd/pkg/environment" "github.com/azure/azure-dev/cli/azd/pkg/environment/azdcontext" "github.com/azure/azure-dev/cli/azd/pkg/lazy" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) type environmentService struct { @@ -152,6 +154,10 @@ func (s *environmentService) GetValues( // GetValue retrieves the value of a specific key in the specified environment. func (s *environmentService) GetValue(ctx context.Context, req *azdext.GetEnvRequest) (*azdext.KeyValueResponse, error) { + if req.Key == "" { + return nil, status.Error(codes.InvalidArgument, "key is required") + } + env, err := s.resolveEnvironment(ctx, req.EnvName) if err != nil { return nil, err @@ -167,6 +173,10 @@ func (s *environmentService) GetValue(ctx context.Context, req *azdext.GetEnvReq // SetValue sets the value of a key in the specified environment. func (s *environmentService) SetValue(ctx context.Context, req *azdext.SetEnvRequest) (*azdext.EmptyResponse, error) { + if req.Key == "" { + return nil, status.Error(codes.InvalidArgument, "key is required") + } + envManager, err := s.lazyEnvManager.GetValue() if err != nil { return nil, err diff --git a/cli/azd/internal/grpcserver/environment_service_test.go b/cli/azd/internal/grpcserver/environment_service_test.go index b63875c99df..17b9db934ef 100644 --- a/cli/azd/internal/grpcserver/environment_service_test.go +++ b/cli/azd/internal/grpcserver/environment_service_test.go @@ -15,6 +15,8 @@ import ( "github.com/azure/azure-dev/cli/azd/pkg/project" "github.com/azure/azure-dev/cli/azd/test/mocks" "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) // Test_EnvironmentService_NoEnvironment verifies that when no environments are set, @@ -332,3 +334,60 @@ func Test_EnvironmentService_ResolveEnvironment(t *testing.T) { }) }) } + +// Test_EnvironmentService_EmptyKeyValidation verifies that GetValue and SetValue +// return InvalidArgument when called with an empty key. +func Test_EnvironmentService_EmptyKeyValidation(t *testing.T) { + mockContext := mocks.NewMockContext(context.Background()) + temp := t.TempDir() + + azdContext := azdcontext.NewAzdContextWithDirectory(temp) + projectConfig := project.ProjectConfig{Name: "test"} + err := project.Save(*mockContext.Context, &projectConfig, azdContext.ProjectPath()) + require.NoError(t, err) + + fileConfigManager := config.NewFileConfigManager(config.NewManager()) + localDataStore := environment.NewLocalFileDataStore(azdContext, fileConfigManager) + envManager, err := environment.NewManager( + mockContext.Container, azdContext, mockContext.Console, localDataStore, nil, + ) + require.NoError(t, err) + + env1, err := envManager.Create(*mockContext.Context, environment.Spec{Name: "env1"}) + require.NoError(t, err) + require.NoError(t, envManager.Save(*mockContext.Context, env1)) + require.NoError(t, azdContext.SetProjectState( + azdcontext.ProjectState{DefaultEnvironment: "env1"}, + )) + + service := NewEnvironmentService(lazy.From(azdContext), lazy.From(envManager)) + ctx := *mockContext.Context + + tests := []struct { + name string + method string + }{ + {"GetValue_empty_key", "GetValue"}, + {"SetValue_empty_key", "SetValue"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var callErr error + switch tt.method { + case "GetValue": + _, callErr = service.GetValue(ctx, &azdext.GetEnvRequest{Key: ""}) + case "SetValue": + _, callErr = service.SetValue( + ctx, &azdext.SetEnvRequest{Key: "", Value: "v"}, + ) + } + + require.Error(t, callErr) + st, ok := status.FromError(callErr) + require.True(t, ok) + require.Equal(t, codes.InvalidArgument, st.Code()) + require.Contains(t, st.Message(), "key is required") + }) + } +} diff --git a/cli/azd/internal/grpcserver/event_service.go b/cli/azd/internal/grpcserver/event_service.go index 531e5b3f4f5..6555a78fa55 100644 --- a/cli/azd/internal/grpcserver/event_service.go +++ b/cli/azd/internal/grpcserver/event_service.go @@ -106,13 +106,22 @@ func (s *eventService) onSubscribeProjectEvent( subscribeMsg *azdext.SubscribeProjectEvent, broker *grpcbroker.MessageBroker[azdext.EventMessage], ) error { + if subscribeMsg == nil || len(subscribeMsg.EventNames) == 0 { + return status.Error(codes.InvalidArgument, "event names are required") + } + projectConfig, err := s.lazyProject.GetValue() if err != nil { return err } - for i := 0; i < len(subscribeMsg.EventNames); i++ { - eventName := subscribeMsg.EventNames[i] + for i, eventName := range subscribeMsg.EventNames { + if eventName == "" { + return status.Errorf( + codes.InvalidArgument, + "event name at index %d cannot be empty", i, + ) + } evt := ext.Event(eventName) // Pass the stream context (ctx) which has extension claims @@ -191,13 +200,23 @@ func (s *eventService) onSubscribeServiceEvent( subscribeMsg *azdext.SubscribeServiceEvent, broker *grpcbroker.MessageBroker[azdext.EventMessage], ) error { + if subscribeMsg == nil || len(subscribeMsg.EventNames) == 0 { + return status.Error(codes.InvalidArgument, "event names are required") + } + projectConfig, err := s.lazyProject.GetValue() if err != nil { return err } - for i := 0; i < len(subscribeMsg.EventNames); i++ { - eventName := subscribeMsg.EventNames[i] + for i, eventName := range subscribeMsg.EventNames { + if eventName == "" { + return status.Errorf( + codes.InvalidArgument, + "event name at index %d cannot be empty", i, + ) + } + evt := ext.Event(eventName) for _, serviceConfig := range projectConfig.Services { if subscribeMsg.Language != "" && string(serviceConfig.Language) != subscribeMsg.Language { diff --git a/cli/azd/internal/grpcserver/event_service_test.go b/cli/azd/internal/grpcserver/event_service_test.go index c036c756c3d..670b29f098c 100644 --- a/cli/azd/internal/grpcserver/event_service_test.go +++ b/cli/azd/internal/grpcserver/event_service_test.go @@ -18,7 +18,9 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" ) // MockBidiStreamingServer mocks the gRPC bidirectional streaming server using generics @@ -178,7 +180,7 @@ func TestEventService_handleSubscribeProjectEvent(t *testing.T) { subscribeMsg: &azdext.SubscribeProjectEvent{ EventNames: []string{}, }, - expectError: false, + expectError: true, }, } @@ -360,3 +362,37 @@ func TestEventService_New(t *testing.T) { assert.NotNil(t, eventSvc.lazyEnv) assert.NotNil(t, eventSvc.console) } + +func TestEventService_EmptyEventNameInArray(t *testing.T) { + service, _ := createTestEventService() + extension := createTestExtension() + ctx := t.Context() + + t.Run("project_event_with_empty_name", func(t *testing.T) { + var mockBroker *grpcbroker.MessageBroker[azdext.EventMessage] + + err := service.onSubscribeProjectEvent(ctx, extension, &azdext.SubscribeProjectEvent{ + EventNames: []string{"prepackage", "", "postpackage"}, + }, mockBroker) + + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, codes.InvalidArgument, st.Code()) + require.Contains(t, st.Message(), "event name at index 1 cannot be empty") + }) + + t.Run("service_event_with_empty_name", func(t *testing.T) { + var mockBroker *grpcbroker.MessageBroker[azdext.EventMessage] + + err := service.onSubscribeServiceEvent(ctx, extension, &azdext.SubscribeServiceEvent{ + EventNames: []string{"", "prepackage"}, + }, mockBroker) + + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, codes.InvalidArgument, st.Code()) + require.Contains(t, st.Message(), "event name at index 0 cannot be empty") + }) +} diff --git a/cli/azd/internal/grpcserver/prompt_service.go b/cli/azd/internal/grpcserver/prompt_service.go index 119eda1d5bc..7e65a275b0f 100644 --- a/cli/azd/internal/grpcserver/prompt_service.go +++ b/cli/azd/internal/grpcserver/prompt_service.go @@ -20,6 +20,7 @@ import ( "github.com/azure/azure-dev/cli/azd/pkg/prompt" "github.com/azure/azure-dev/cli/azd/pkg/ux" "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) type promptService struct { @@ -47,6 +48,10 @@ func NewPromptService( } func (s *promptService) Confirm(ctx context.Context, req *azdext.ConfirmRequest) (*azdext.ConfirmResponse, error) { + if req == nil || req.Options == nil { + return nil, status.Error(codes.InvalidArgument, "request and options are required") + } + if s.globalOptions.NoPrompt { if req.Options.DefaultValue == nil { return nil, fmt.Errorf("no default response for prompt '%s'", req.Options.Message) @@ -80,6 +85,10 @@ func (s *promptService) Confirm(ctx context.Context, req *azdext.ConfirmRequest) } func (s *promptService) Select(ctx context.Context, req *azdext.SelectRequest) (*azdext.SelectResponse, error) { + if req == nil || req.Options == nil { + return nil, status.Error(codes.InvalidArgument, "request and options are required") + } + if s.globalOptions.NoPrompt { if req.Options.SelectedIndex == nil { return nil, fmt.Errorf("no default selection for prompt '%s'", req.Options.Message) @@ -126,6 +135,10 @@ func (s *promptService) MultiSelect( ctx context.Context, req *azdext.MultiSelectRequest, ) (*azdext.MultiSelectResponse, error) { + if req == nil || req.Options == nil { + return nil, status.Error(codes.InvalidArgument, "request and options are required") + } + if s.globalOptions.NoPrompt { var selectedChoices []*azdext.MultiSelectChoice for _, choice := range req.Options.Choices { @@ -385,6 +398,13 @@ func (s *promptService) PromptResourceGroupResource( } func (s *promptService) createAzureContext(wire *azdext.AzureContext) (*prompt.AzureContext, error) { + if wire == nil { + return nil, status.Error(codes.InvalidArgument, "azure context is required") + } + if wire.Scope == nil { + return nil, status.Error(codes.InvalidArgument, "azure context scope is required") + } + scope := prompt.AzureScope{ TenantId: wire.Scope.TenantId, SubscriptionId: wire.Scope.SubscriptionId, @@ -396,7 +416,8 @@ func (s *promptService) createAzureContext(wire *azdext.AzureContext) (*prompt.A for _, resourceId := range wire.Resources { parsedResource, err := arm.ParseResourceID(resourceId) if err != nil { - return nil, err + return nil, status.Errorf(codes.InvalidArgument, + "invalid resource ID %q: %v", resourceId, err) } resources = append(resources, parsedResource) diff --git a/cli/azd/internal/grpcserver/prompt_service_test.go b/cli/azd/internal/grpcserver/prompt_service_test.go index b59c7391ae2..f10dc82d199 100644 --- a/cli/azd/internal/grpcserver/prompt_service_test.go +++ b/cli/azd/internal/grpcserver/prompt_service_test.go @@ -19,6 +19,8 @@ import ( "github.com/azure/azure-dev/cli/azd/test/mocks/mockprompt" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" ) func Test_PromptService_Confirm_NoPromptWithDefault(t *testing.T) { @@ -1000,3 +1002,80 @@ func Test_selectModelNoPrompt(t *testing.T) { }) } } + +func Test_PromptService_NilOptions_Validation(t *testing.T) { + globalOptions := &internal.GlobalCommandOptions{NoPrompt: true} + service := NewPromptService(nil, nil, nil, globalOptions) + + tests := []struct { + name string + method string + }{ + {"Confirm_nil_options", "Confirm"}, + {"Select_nil_options", "Select"}, + {"MultiSelect_nil_options", "MultiSelect"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var err error + switch tt.method { + case "Confirm": + _, err = service.Confirm( + t.Context(), + &azdext.ConfirmRequest{Options: nil}, + ) + case "Select": + _, err = service.Select( + t.Context(), + &azdext.SelectRequest{Options: nil}, + ) + case "MultiSelect": + _, err = service.MultiSelect( + t.Context(), + &azdext.MultiSelectRequest{Options: nil}, + ) + } + + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, codes.InvalidArgument, st.Code()) + require.Contains(t, st.Message(), "options are required") + }) + } +} + +func Test_PromptService_CreateAzureContext_NilScope(t *testing.T) { + globalOptions := &internal.GlobalCommandOptions{NoPrompt: false} + svc := NewPromptService(nil, nil, nil, globalOptions) + ps := svc.(*promptService) + + tests := []struct { + name string + wire *azdext.AzureContext + errContains string + }{ + { + name: "nil_azure_context", + wire: nil, + errContains: "azure context is required", + }, + { + name: "nil_scope", + wire: &azdext.AzureContext{Scope: nil}, + errContains: "azure context scope is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := ps.createAzureContext(tt.wire) + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok) + require.Equal(t, codes.InvalidArgument, st.Code()) + require.Contains(t, st.Message(), tt.errContains) + }) + } +}