diff --git a/internal/api/relationships.go b/internal/api/relationships.go index fec54c61e..546c53710 100644 --- a/internal/api/relationships.go +++ b/internal/api/relationships.go @@ -199,7 +199,7 @@ func (r *Router) relationshipDelete(c echo.Context) error { Subject: relatedResource, } - _, err = r.engine.DeleteRelationship(ctx, relationship) + _, err = r.engine.DeleteRelationships(ctx, relationship) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "error deleting relationship").SetInternal(err) } diff --git a/internal/pubsub/subscriber.go b/internal/pubsub/subscriber.go index 72e66197a..04121a600 100644 --- a/internal/pubsub/subscriber.go +++ b/internal/pubsub/subscriber.go @@ -2,29 +2,36 @@ package pubsub import ( "context" + "errors" + "fmt" "sync" "time" "go.infratographer.com/permissions-api/internal/query" "go.infratographer.com/permissions-api/internal/types" "go.infratographer.com/x/events" - "go.infratographer.com/x/gidx" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" + "go.uber.org/multierr" "go.uber.org/zap" ) const nakDelay = 10 * time.Second -var tracer = otel.Tracer("go.infratographer.com/permissions-api/internal/pubsub") +var ( + tracer = otel.Tracer("go.infratographer.com/permissions-api/internal/pubsub") + + // ErrUnknownResourceType is returned when the corresponding resource type is not found for a resource id. + ErrUnknownResourceType = errors.New("unknown resource type") +) // Subscriber is the subscriber client type Subscriber struct { ctx context.Context - changeChannels []<-chan events.Message[events.ChangeMessage] + changeChannels []<-chan events.Request[events.AuthRelationshipRequest, events.AuthRelationshipResponse] logger *zap.SugaredLogger - subscriber events.Subscriber + subscriber events.AuthRelationshipSubscriber qe query.Engine } @@ -39,7 +46,7 @@ func WithLogger(l *zap.SugaredLogger) SubscriberOption { } // NewSubscriber creates a new Subscriber -func NewSubscriber(ctx context.Context, subscriber events.Subscriber, engine query.Engine, opts ...SubscriberOption) (*Subscriber, error) { +func NewSubscriber(ctx context.Context, subscriber events.AuthRelationshipSubscriber, engine query.Engine, opts ...SubscriberOption) (*Subscriber, error) { s := &Subscriber{ ctx: ctx, logger: zap.NewNop().Sugar(), @@ -56,7 +63,7 @@ func NewSubscriber(ctx context.Context, subscriber events.Subscriber, engine que // Subscribe subscribes to a nats subject func (s *Subscriber) Subscribe(topic string) error { - msgChan, err := s.subscriber.SubscribeChanges(s.ctx, topic) + msgChan, err := s.subscriber.SubscribeAuthRelationshipRequests(s.ctx, topic) if err != nil { return err } @@ -83,14 +90,15 @@ func (s Subscriber) Listen() error { } // listen listens for messages on a channel and calls the registered message handler -func (s Subscriber) listen(messages <-chan events.Message[events.ChangeMessage], wg *sync.WaitGroup) { +func (s Subscriber) listen(messages <-chan events.Request[events.AuthRelationshipRequest, events.AuthRelationshipResponse], wg *sync.WaitGroup) { defer wg.Done() for msg := range messages { elogger := s.logger.With( - "event.message.id", msg.ID(), - "event.message.timestamp", msg.Timestamp(), - "event.message.deliveries", msg.Deliveries(), + "event.message.topic", msg.Topic(), + "event.message.action", msg.Message().Action, + "event.message.object.id", msg.Message().ObjectID.String(), + "event.message.relations", len(msg.Message().Relations), ) if err := s.processEvent(msg); err != nil { @@ -100,17 +108,18 @@ func (s Subscriber) listen(messages <-chan events.Message[events.ChangeMessage], elogger.Warnw("error occurred while naking", "error", nakErr) } } else if ackErr := msg.Ack(); ackErr != nil { - elogger.Warnw("error occurred while acking", "error", ackErr) + elogger.Errorw("error occurred while acking", "error", ackErr) } } } // processEvent event message handler -func (s *Subscriber) processEvent(msg events.Message[events.ChangeMessage]) error { +func (s *Subscriber) processEvent(msg events.Request[events.AuthRelationshipRequest, events.AuthRelationshipResponse]) error { elogger := s.logger.With( - "event.message.id", msg.ID(), - "event.message.timestamp", msg.Timestamp(), - "event.message.deliveries", msg.Deliveries(), + "event.message.topic", msg.Topic(), + "event.message.action", msg.Message().Action, + "event.message.object.id", msg.Message().ObjectID.String(), + "event.message.relations", len(msg.Message().Relations), ) if msg.Error() != nil { @@ -119,34 +128,46 @@ func (s *Subscriber) processEvent(msg events.Message[events.ChangeMessage]) erro return msg.Error() } - changeMsg := msg.Message() + request := msg.Message() - ctx := changeMsg.GetTraceContext(context.Background()) + ctx := request.GetTraceContext(context.Background()) - ctx, span := tracer.Start(ctx, "pubsub.receive", trace.WithAttributes(attribute.String("pubsub.subject", changeMsg.SubjectID.String()))) + ctx, span := tracer.Start(ctx, "pubsub.receive", trace.WithAttributes(attribute.String("pubsub.subject", request.ObjectID.String()))) defer span.End() - elogger = elogger.With( - "event.resource.id", changeMsg.SubjectID.String(), - "event.type", changeMsg.EventType, - ) - elogger.Debugw("received message") var err error - switch events.ChangeType(changeMsg.EventType) { - case events.CreateChangeType: + switch request.Action { + case events.WriteAuthRelationshipAction: err = s.handleCreateEvent(ctx, msg) - case events.UpdateChangeType: - err = s.handleUpdateEvent(ctx, msg) - case events.DeleteChangeType: + case events.DeleteAuthRelationshipAction: err = s.handleDeleteEvent(ctx, msg) default: - elogger.Warnw("ignoring msg, not a create, update or delete event") + elogger.Warnw("ignoring msg, not a write or delete action") + } + + if err != nil { + return err + } + + return nil +} + +func (s *Subscriber) createRelationships(ctx context.Context, relationships []types.Relationship) error { + // Attempt to create the relationships in SpiceDB. + _, err := s.qe.CreateRelationships(ctx, relationships) + if err != nil { + return fmt.Errorf("%w: error creating relationships", err) } + return nil +} + +func (s *Subscriber) deleteRelationships(ctx context.Context, relationships []types.Relationship) error { + _, err := s.qe.DeleteRelationships(ctx, relationships...) if err != nil { return err } @@ -154,106 +175,165 @@ func (s *Subscriber) processEvent(msg events.Message[events.ChangeMessage]) erro return nil } -func (s *Subscriber) createRelationships(ctx context.Context, msg events.Message[events.ChangeMessage], resource types.Resource, additionalSubjectIDs []gidx.PrefixedID) error { - var relationships []types.Relationship +func (s *Subscriber) handleCreateEvent(ctx context.Context, msg events.Request[events.AuthRelationshipRequest, events.AuthRelationshipResponse]) error { + elogger := s.logger.With( + "event.message.topic", msg.Topic(), + "event.message.action", msg.Message().Action, + "event.message.object.id", msg.Message().ObjectID.String(), + "event.message.relations", len(msg.Message().Relations), + ) + + var errors []error + + if err := msg.Message().Validate(); err != nil { + errors = multierr.Errors(err) + } + + resource, err := s.qe.NewResourceFromID(msg.Message().ObjectID) + if err != nil { + elogger.Warnw("error parsing resource ID", "error", err.Error()) + + respondRequest(ctx, elogger, msg, err) + + return nil + } rType := s.qe.GetResourceType(resource.Type) if rType == nil { - s.logger.Warnw("no resource type found for", "resource_type", resource.Type) + elogger.Warnw("error finding resource type", "error", err.Error()) - return nil + respondRequest(ctx, elogger, msg, fmt.Errorf("%w: resource: %s", ErrUnknownResourceType, resource.Type)) } - // Attempt to create relationships from the message fields. If this fails, reject the message - for _, id := range additionalSubjectIDs { - subjResource, err := s.qe.NewResourceFromID(id) + relationships := make([]types.Relationship, len(msg.Message().Relations)) + + for i, relation := range msg.Message().Relations { + subject, err := s.qe.NewResourceFromID(relation.SubjectID) if err != nil { - s.logger.Warnw("error parsing additional subject id - will not reprocess", "event_type", events.CreateChangeType, "id", id.String(), "error", err.Error()) + elogger.Warnw("error parsing subject ID", "error", err.Error()) + + errors = append(errors, fmt.Errorf("%w: relation %d", err, i)) continue } - for _, rel := range rType.Relationships { - var relation string + sType := s.qe.GetResourceType(subject.Type) + if sType == nil { + elogger.Warnw("error finding subject resource type", "error", err.Error()) - for _, tName := range rel.Types { - if tName == subjResource.Type { - relation = rel.Relation + errors = append(errors, fmt.Errorf("%w: relation %d subject: %s", ErrUnknownResourceType, i, subject.Type)) - break - } - } - - if relation != "" { - relationship := types.Relationship{ - Resource: resource, - Relation: relation, - Subject: subjResource, - } + continue + } - relationships = append(relationships, relationship) - } + relationships[i] = types.Relationship{ + Resource: resource, + Relation: relation.Relation, + Subject: subject, } } - if len(relationships) == 0 { - s.logger.Warnw("no relations to create for resource", "resource_type", resource.Type, "resource_id", resource.ID.String()) - - return nil + if len(errors) != 0 { + respondRequest(ctx, elogger, msg, errors...) } - // Attempt to create the relationships in SpiceDB. If this fails, nak the message for reprocessing - _, err := s.qe.CreateRelationships(ctx, relationships) - if err != nil { - s.logger.Errorw("error creating relationships - will not reprocess", "error", err.Error()) - } + err = s.createRelationships(ctx, relationships) + + respondRequest(ctx, elogger, msg, err) return nil } -func (s *Subscriber) deleteRelationships(ctx context.Context, msg events.Message[events.ChangeMessage], resource types.Resource) error { - _, err := s.qe.DeleteRelationships(ctx, resource) +func (s *Subscriber) handleDeleteEvent(ctx context.Context, msg events.Request[events.AuthRelationshipRequest, events.AuthRelationshipResponse]) error { + elogger := s.logger.With( + "event.message.topic", msg.Topic(), + "event.message.action", msg.Message().Action, + "event.message.object.id", msg.Message().ObjectID.String(), + "event.message.relations", len(msg.Message().Relations), + ) + + var errors []error + + if err := msg.Message().Validate(); err != nil { + errors = multierr.Errors(err) + } + + resource, err := s.qe.NewResourceFromID(msg.Message().ObjectID) if err != nil { - s.logger.Errorw("error deleting relationships - will not reprocess", "error", err.Error()) + elogger.Warnw("error parsing resource ID", "error", err.Error()) + + errors = append(errors, err) } - return nil -} + rType := s.qe.GetResourceType(resource.Type) + if rType == nil { + elogger.Warnw("error finding resource type", "error", err.Error()) -func (s *Subscriber) handleCreateEvent(ctx context.Context, msg events.Message[events.ChangeMessage]) error { - resource, err := s.qe.NewResourceFromID(msg.Message().SubjectID) - if err != nil { - s.logger.Warnw("error parsing subject ID - will not reprocess", "event_type", msg.Message().EventType, "error", err.Error()) + errors = append(errors, fmt.Errorf("%w: resource: %s", ErrUnknownResourceType, resource.Type)) + } - return nil + relationships := make([]types.Relationship, len(msg.Message().Relations)) + + for i, relation := range msg.Message().Relations { + subject, err := s.qe.NewResourceFromID(relation.SubjectID) + if err != nil { + elogger.Warnw("error parsing subject ID", "error", err.Error()) + + errors = append(errors, fmt.Errorf("%w: relation %d", err, i)) + + continue + } + + sType := s.qe.GetResourceType(subject.Type) + if sType == nil { + elogger.Warnw("error finding subject resource type", "error", err.Error()) + + errors = append(errors, fmt.Errorf("%w: relation %d subject: %s", ErrUnknownResourceType, i, subject.Type)) + + continue + } + + relationships[i] = types.Relationship{ + Resource: resource, + Relation: relation.Relation, + Subject: subject, + } } - return s.createRelationships(ctx, msg, resource, msg.Message().AdditionalSubjectIDs) + if len(errors) != 0 { + respondRequest(ctx, elogger, msg, errors...) + } + + err = s.deleteRelationships(ctx, relationships) + + respondRequest(ctx, elogger, msg, err) + + return nil } -func (s *Subscriber) handleDeleteEvent(ctx context.Context, msg events.Message[events.ChangeMessage]) error { - resource, err := s.qe.NewResourceFromID(msg.Message().SubjectID) - if err != nil { - s.logger.Warnw("error parsing subject ID - will not reprocess", "event_type", msg.Message().EventType, "error", err.Error()) +func respondRequest(ctx context.Context, logger *zap.SugaredLogger, msg events.Request[events.AuthRelationshipRequest, events.AuthRelationshipResponse], errors ...error) { + var filteredErrors []error - return nil + for _, err := range errors { + if err != nil { + filteredErrors = append(filteredErrors, err) + } } - return s.deleteRelationships(ctx, msg, resource) -} + response := events.AuthRelationshipResponse{ + Errors: filteredErrors, + } -func (s *Subscriber) handleUpdateEvent(ctx context.Context, msg events.Message[events.ChangeMessage]) error { - resource, err := s.qe.NewResourceFromID(msg.Message().SubjectID) - if err != nil { - s.logger.Warnw("error parsing subject ID - will not reprocess", "event_type", msg.Message().EventType, "error", err.Error()) + if len(filteredErrors) != 0 { + err := multierr.Combine(filteredErrors...) - return nil + logger.Errorw("error processing relationship, sending error response", "error", err) + } else { + logger.Debug("relationship successfully processed, sending response") } - err = s.deleteRelationships(ctx, msg, resource) + _, err := msg.Reply(ctx, response) if err != nil { - return err + logger.Errorw("error sending response", "error", err) } - - return s.createRelationships(ctx, msg, resource, msg.Message().AdditionalSubjectIDs) } diff --git a/internal/pubsub/subscriber_test.go b/internal/pubsub/subscriber_test.go index 4ef1d0301..8ce35c531 100644 --- a/internal/pubsub/subscriber_test.go +++ b/internal/pubsub/subscriber_test.go @@ -16,13 +16,9 @@ import ( "github.com/stretchr/testify/require" ) -const ( - sampleFrequency = "100" -) - var contextKeyEngine = struct{}{} -func setupEvents(t *testing.T, engine query.Engine) (*eventtools.TestNats, events.Publisher, *Subscriber) { +func setupEvents(t *testing.T, engine query.Engine) (*eventtools.TestNats, events.AuthRelationshipPublisher, *Subscriber) { ctx := context.Background() nats, err := eventtools.NewNatsServer() @@ -47,44 +43,51 @@ func setupEvents(t *testing.T, engine query.Engine) (*eventtools.TestNats, event func TestNATS(t *testing.T) { type testInput struct { - subject string - changeMessage events.ChangeMessage + subject string + request events.AuthRelationshipRequest } - createMsg := events.ChangeMessage{ - SubjectID: gidx.PrefixedID("loadbal-UCN7pxJO57BV_5pNiV95B"), - EventType: string(events.CreateChangeType), - AdditionalSubjectIDs: []gidx.PrefixedID{ - gidx.PrefixedID("othrsid-kXboa2UZbaNzMhng9vVha"), - gidx.PrefixedID("tnntten-gd6RExwAz353UqHLzjC1n"), + createMsg := events.AuthRelationshipRequest{ + Action: events.WriteAuthRelationshipAction, + ObjectID: gidx.PrefixedID("loadbal-UCN7pxJO57BV_5pNiV95B"), + Relations: []events.AuthRelationshipRelation{ + { + Relation: "owner", + SubjectID: gidx.PrefixedID("tnntten-gd6RExwAz353UqHLzjC1n"), + }, }, } - noCreateMsg := events.ChangeMessage{ - SubjectID: gidx.PrefixedID("loadbal-EA8CJagJPM4J-yw6_skd1"), - EventType: string(events.CreateChangeType), - AdditionalSubjectIDs: []gidx.PrefixedID{ - gidx.PrefixedID("othrsid-kXboa2UZbaNzMhng9vVha"), + noCreateMsg := events.AuthRelationshipRequest{ + Action: events.WriteAuthRelationshipAction, + ObjectID: gidx.PrefixedID("loadbal-EA8CJagJPM4J-yw6_skd1"), + Relations: []events.AuthRelationshipRelation{ + { + Relation: "owner", + }, }, } - updateMsg := events.ChangeMessage{ - SubjectID: gidx.PrefixedID("loadbal-UCN7pxJO57BV_5pNiV95B"), - EventType: string(events.UpdateChangeType), - AdditionalSubjectIDs: []gidx.PrefixedID{ - gidx.PrefixedID("othrsid-kXboa2UZbaNzMhng9vVha"), - gidx.PrefixedID("tnntten-gd6RExwAz353UqHLzjC1n"), + deleteMsg := events.AuthRelationshipRequest{ + Action: events.DeleteAuthRelationshipAction, + ObjectID: gidx.PrefixedID("loadbal-UCN7pxJO57BV_5pNiV95B"), + Relations: []events.AuthRelationshipRelation{ + { + Relation: "owner", + SubjectID: gidx.PrefixedID("tnntten-gd6RExwAz353UqHLzjC1n"), + }, }, } - deleteMsg := events.ChangeMessage{ - SubjectID: gidx.PrefixedID("loadbal-UCN7pxJO57BV_5pNiV95B"), - EventType: string(events.DeleteChangeType), - } - - unknownResourceMsg := events.ChangeMessage{ - SubjectID: gidx.PrefixedID("baddres-BfqAzfYxtFNlpKPGYLmra"), - EventType: string(events.CreateChangeType), + unknownResourceMsg := events.AuthRelationshipRequest{ + Action: events.WriteAuthRelationshipAction, + ObjectID: gidx.PrefixedID("baddres-BfqAzfYxtFNlpKPGYLmra"), + Relations: []events.AuthRelationshipRelation{ + { + Relation: "owner", + SubjectID: gidx.PrefixedID("tnntten-gd6RExwAz353UqHLzjC1n"), + }, + }, } // Each of these tests works as follows: @@ -96,12 +99,12 @@ func TestNATS(t *testing.T) { // // When writing tests, make sure the subject prefix in the test input matches the prefix provided in // setupClient, or else you will get undefined, racy behavior. - testCases := []testingx.TestCase[testInput, *Subscriber]{ + testCases := []testingx.TestCase[testInput, events.Message[events.AuthRelationshipResponse]]{ { Name: "goodcreate", Input: testInput{ - subject: "goodcreate.loadbalancer", - changeMessage: createMsg, + subject: "goodcreate.loadbalancer", + request: createMsg, }, SetupFn: func(ctx context.Context, t *testing.T) context.Context { var engine mock.Engine @@ -109,18 +112,20 @@ func TestNATS(t *testing.T) { return context.WithValue(ctx, contextKeyEngine, &engine) }, - CheckFn: func(ctx context.Context, t *testing.T, result testingx.TestResult[*Subscriber]) { + CheckFn: func(ctx context.Context, t *testing.T, result testingx.TestResult[events.Message[events.AuthRelationshipResponse]]) { require.NoError(t, result.Err) + require.NotNil(t, result.Success) + require.Empty(t, result.Success.Message().Errors) - engine := result.Success.qe.(*mock.Engine) + engine := ctx.Value(contextKeyEngine).(*mock.Engine) engine.AssertExpectations(t) }, }, { Name: "errcreate", Input: testInput{ - subject: "errcreate.loadbalancer", - changeMessage: createMsg, + subject: "errcreate.loadbalancer", + request: createMsg, }, SetupFn: func(ctx context.Context, t *testing.T) context.Context { var engine mock.Engine @@ -128,50 +133,34 @@ func TestNATS(t *testing.T) { return context.WithValue(ctx, contextKeyEngine, &engine) }, - CheckFn: func(ctx context.Context, t *testing.T, result testingx.TestResult[*Subscriber]) { + CheckFn: func(ctx context.Context, t *testing.T, result testingx.TestResult[events.Message[events.AuthRelationshipResponse]]) { require.NoError(t, result.Err) + require.NotNil(t, result.Success) + require.NotEmpty(t, result.Success.Message().Errors) }, }, { Name: "nocreate", Input: testInput{ - subject: "nocreate.loadbalancer", - changeMessage: noCreateMsg, + subject: "nocreate.loadbalancer", + request: noCreateMsg, }, SetupFn: func(ctx context.Context, t *testing.T) context.Context { var engine mock.Engine return context.WithValue(ctx, contextKeyEngine, &engine) }, - CheckFn: func(ctx context.Context, t *testing.T, result testingx.TestResult[*Subscriber]) { - require.NoError(t, result.Err) - }, - }, - { - Name: "goodupdate", - Input: testInput{ - subject: "goodupdate.loadbalancer", - changeMessage: updateMsg, - }, - SetupFn: func(ctx context.Context, t *testing.T) context.Context { - var engine mock.Engine - engine.On("DeleteRelationships").Return("", nil) - engine.On("CreateRelationships").Return("", nil) - - return context.WithValue(ctx, contextKeyEngine, &engine) - }, - CheckFn: func(ctx context.Context, t *testing.T, result testingx.TestResult[*Subscriber]) { - require.NoError(t, result.Err) - - engine := result.Success.qe.(*mock.Engine) - engine.AssertExpectations(t) + CheckFn: func(ctx context.Context, t *testing.T, result testingx.TestResult[events.Message[events.AuthRelationshipResponse]]) { + require.Error(t, result.Err) + require.ErrorIs(t, result.Err, events.ErrMissingAuthRelationshipRequestRelationSubjectID) + require.Nil(t, result.Success) }, }, { Name: "gooddelete", Input: testInput{ - subject: "gooddelete.loadbalancer", - changeMessage: deleteMsg, + subject: "gooddelete.loadbalancer", + request: deleteMsg, }, SetupFn: func(ctx context.Context, t *testing.T) context.Context { var engine mock.Engine @@ -180,36 +169,38 @@ func TestNATS(t *testing.T) { return context.WithValue(ctx, contextKeyEngine, &engine) }, - CheckFn: func(ctx context.Context, t *testing.T, result testingx.TestResult[*Subscriber]) { + CheckFn: func(ctx context.Context, t *testing.T, result testingx.TestResult[events.Message[events.AuthRelationshipResponse]]) { require.NoError(t, result.Err) + require.NotNil(t, result.Success) + require.Empty(t, result.Success.Message().Errors) - engine := result.Success.qe.(*mock.Engine) + engine := ctx.Value(contextKeyEngine).(*mock.Engine) engine.AssertExpectations(t) }, }, { Name: "badresource", Input: testInput{ - subject: "badresource.fakeresource", - changeMessage: unknownResourceMsg, + subject: "badresource.fakeresource", + request: unknownResourceMsg, }, SetupFn: func(ctx context.Context, t *testing.T) context.Context { var engine mock.Engine return context.WithValue(ctx, contextKeyEngine, &engine) }, - CheckFn: func(ctx context.Context, t *testing.T, result testingx.TestResult[*Subscriber]) { + CheckFn: func(ctx context.Context, t *testing.T, result testingx.TestResult[events.Message[events.AuthRelationshipResponse]]) { require.NoError(t, result.Err) + require.NotNil(t, result.Success) + require.NotEmpty(t, result.Success.Message().Errors) }, }, } - testFn := func(ctx context.Context, input testInput) testingx.TestResult[*Subscriber] { + testFn := func(ctx context.Context, input testInput) testingx.TestResult[events.Message[events.AuthRelationshipResponse]] { engine := ctx.Value(contextKeyEngine).(query.Engine) - nats, pub, sub := setupEvents(t, engine) - - consumerName := events.NATSConsumerDurableName("", eventtools.Prefix+".changes.*."+input.subject) + _, pub, sub := setupEvents(t, engine) err := sub.Subscribe("*." + input.subject) @@ -224,28 +215,13 @@ func TestNATS(t *testing.T) { // Allow time for the listener to to start time.Sleep(time.Second) - err = nats.SetConsumerSampleFrequency(consumerName, sampleFrequency) - require.NoError(t, err) - ackErr := make(chan error, 1) - - go func() { - ackErr <- nats.WaitForAck(consumerName, 5*time.Second) - }() - - _, err = pub.PublishChange(ctx, input.subject, input.changeMessage) - - require.NoError(t, err) - - if err = <-ackErr; err != nil { - return testingx.TestResult[*Subscriber]{ - Err: err, - } - } + resp, err := pub.PublishAuthRelationshipRequest(ctx, input.subject, input.request) - return testingx.TestResult[*Subscriber]{ - Success: sub, + return testingx.TestResult[events.Message[events.AuthRelationshipResponse]]{ + Err: err, + Success: resp, } } diff --git a/internal/query/mock/mock.go b/internal/query/mock/mock.go index 5105f0770..e20ace549 100644 --- a/internal/query/mock/mock.go +++ b/internal/query/mock/mock.go @@ -87,8 +87,8 @@ func (e *Engine) ListRoles(ctx context.Context, resource types.Resource, queryTo return nil, nil } -// DeleteRelationship does nothing but satisfies the Engine interface. -func (e *Engine) DeleteRelationship(ctx context.Context, rel types.Relationship) (string, error) { +// DeleteRelationships does nothing but satisfies the Engine interface. +func (e *Engine) DeleteRelationships(ctx context.Context, relationships ...types.Relationship) (string, error) { args := e.Called() return args.String(0), args.Error(1) @@ -101,8 +101,8 @@ func (e *Engine) DeleteRole(ctx context.Context, roleResource types.Resource, qu return args.String(0), args.Error(1) } -// DeleteRelationships does nothing but satisfies the Engine interface. -func (e *Engine) DeleteRelationships(ctx context.Context, resource types.Resource) (string, error) { +// DeleteResourceRelationships does nothing but satisfies the Engine interface. +func (e *Engine) DeleteResourceRelationships(ctx context.Context, resource types.Resource) (string, error) { args := e.Called() return args.String(0), args.Error(1) diff --git a/internal/query/relations.go b/internal/query/relations.go index dfcc1cc10..bf30a336a 100644 --- a/internal/query/relations.go +++ b/internal/query/relations.go @@ -9,6 +9,7 @@ import ( pb "github.com/authzed/authzed-go/proto/authzed/api/v1" "go.infratographer.com/permissions-api/internal/types" "go.infratographer.com/x/gidx" + "go.uber.org/multierr" ) var roleSubjectRelation = "subject" @@ -330,31 +331,77 @@ func (e *engine) readRelationships(ctx context.Context, filter *pb.RelationshipF return responses, nil } -// DeleteRelationship removes the specified relationship between the two resources. -func (e *engine) DeleteRelationship(ctx context.Context, rel types.Relationship) (string, error) { - err := e.validateRelationship(rel) - if err != nil { - return "", err +// DeleteRelationships removes the specified relationships. +// If any relationships fails to be deleted, all completed deletions are re-created. +func (e *engine) DeleteRelationships(ctx context.Context, relationships ...types.Relationship) (string, error) { + var errors []error + + for i, relationship := range relationships { + err := e.validateRelationship(relationship) + if err != nil { + errors = append(errors, fmt.Errorf("%w: invalid relationship %d", err, i)) + } } - resType := e.namespace + "/" + rel.Resource.Type - subjType := e.namespace + "/" + rel.Subject.Type + if len(errors) != 0 { + return "", multierr.Combine(errors...) + } - filter := &pb.RelationshipFilter{ - ResourceType: resType, - OptionalResourceId: rel.Resource.ID.String(), - OptionalRelation: rel.Relation, - OptionalSubjectFilter: &pb.SubjectFilter{ - SubjectType: subjType, - OptionalSubjectId: rel.Subject.ID.String(), - }, + errors = []error{} + + var ( + complete []types.Relationship + queryToken string + dErr error + cErr error + ) + + for i, relationship := range relationships { + resType := e.namespace + "/" + relationship.Resource.Type + subjType := e.namespace + "/" + relationship.Subject.Type + + filter := &pb.RelationshipFilter{ + ResourceType: resType, + OptionalResourceId: relationship.Resource.ID.String(), + OptionalRelation: relationship.Relation, + OptionalSubjectFilter: &pb.SubjectFilter{ + SubjectType: subjType, + OptionalSubjectId: relationship.Subject.ID.String(), + }, + } + + queryToken, dErr = e.deleteRelationships(ctx, filter) + if dErr != nil { + e.logger.Errorf("%w: failed to delete relationships %d reverting %d completed deletes", dErr, i, len(complete)) + + errors = append(errors, fmt.Errorf("%w: failed to delete relationship %d reverting", dErr, i)) + + break + } + + complete = append(complete, relationship) } - return e.deleteRelationships(ctx, filter) + if len(errors) != 0 { + if len(complete) != 0 { + _, cErr = e.CreateRelationships(ctx, complete) + if cErr != nil { + e.logger.Error("%w: failed to revert deleted relationships %d", cErr, len(complete)) + + err := fmt.Errorf("%w: failed to revert deleted relationships %d", cErr, len(complete)) + + errors = append(errors, err) + } + } + + return "", multierr.Combine(errors...) + } + + return queryToken, nil } -// DeleteRelationships deletes all relationships originating from the given resource. -func (e *engine) DeleteRelationships(ctx context.Context, resource types.Resource) (string, error) { +// DeleteResourceRelationships deletes all relationships originating from the given resource. +func (e *engine) DeleteResourceRelationships(ctx context.Context, resource types.Resource) (string, error) { resType := e.namespace + "/" + resource.Type filter := &pb.RelationshipFilter{ diff --git a/internal/query/relations_test.go b/internal/query/relations_test.go index d568c6306..c844d40f7 100644 --- a/internal/query/relations_test.go +++ b/internal/query/relations_test.go @@ -533,7 +533,7 @@ func TestRelationshipDelete(t *testing.T) { } testFn := func(ctx context.Context, input types.Relationship) testingx.TestResult[[]types.Relationship] { - queryToken, err := e.DeleteRelationship(ctx, input) + queryToken, err := e.DeleteRelationships(ctx, input) if err != nil { return testingx.TestResult[[]types.Relationship]{ Err: err, diff --git a/internal/query/service.go b/internal/query/service.go index c984b65ab..bec6021d4 100644 --- a/internal/query/service.go +++ b/internal/query/service.go @@ -23,9 +23,9 @@ type Engine interface { ListRelationshipsFrom(ctx context.Context, resource types.Resource, queryToken string) ([]types.Relationship, error) ListRelationshipsTo(ctx context.Context, resource types.Resource, queryToken string) ([]types.Relationship, error) ListRoles(ctx context.Context, resource types.Resource, queryToken string) ([]types.Role, error) - DeleteRelationship(ctx context.Context, rel types.Relationship) (string, error) + DeleteRelationships(ctx context.Context, relationships ...types.Relationship) (string, error) DeleteRole(ctx context.Context, roleResource types.Resource, queryToken string) (string, error) - DeleteRelationships(ctx context.Context, resource types.Resource) (string, error) + DeleteResourceRelationships(ctx context.Context, resource types.Resource) (string, error) NewResourceFromID(id gidx.PrefixedID) (types.Resource, error) GetResourceType(name string) *types.ResourceType SubjectHasPermission(ctx context.Context, subject types.Resource, action string, resource types.Resource) error