diff --git a/internal/query/relations.go b/internal/query/relations.go index 508de160c..d7f913bb7 100644 --- a/internal/query/relations.go +++ b/internal/query/relations.go @@ -371,8 +371,22 @@ func (e *engine) UpdateRole(ctx context.Context, actor, roleResource types.Resou return types.Role{}, err } + err = e.store.LockRoleForUpdate(dbCtx, roleResource.ID) + if err != nil { + sErr := fmt.Errorf("failed to lock role: %s: %w", roleResource.ID, err) + + span.RecordError(sErr) + span.SetStatus(codes.Error, sErr.Error()) + + logRollbackErr(e.logger, e.store.RollbackContext(dbCtx)) + + return types.Role{}, err + } + role, err := e.GetRole(dbCtx, roleResource) if err != nil { + logRollbackErr(e.logger, e.store.RollbackContext(dbCtx)) + return types.Role{}, err } @@ -1003,14 +1017,30 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er defer span.End() - var ( - resActions map[types.Resource][]string - err error - ) + dbCtx, err := e.store.BeginContext(ctx) + if err != nil { + return err + } + + err = e.store.LockRoleForUpdate(dbCtx, roleResource.ID) + if err != nil { + sErr := fmt.Errorf("failed to lock role: %s: %w", roleResource.ID, err) + + span.RecordError(sErr) + span.SetStatus(codes.Error, sErr.Error()) + + logRollbackErr(e.logger, e.store.RollbackContext(dbCtx)) + + return err + } + + var resActions map[types.Resource][]string for _, resType := range e.schemaRoleables { resActions, err = e.listRoleResourceActions(ctx, roleResource, resType.Name) if err != nil { + logRollbackErr(e.logger, e.store.RollbackContext(dbCtx)) + return err } @@ -1020,10 +1050,6 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er } } - if len(resActions) == 0 { - return ErrRoleNotFound - } - roleType := e.namespace + "/role" var filters []*pb.RelationshipFilter @@ -1047,11 +1073,6 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er } } - dbCtx, err := e.store.BeginContext(ctx) - if err != nil { - return err - } - _, err = e.store.DeleteRole(dbCtx, roleResource.ID) if err != nil { logRollbackErr(e.logger, e.store.RollbackContext(dbCtx)) diff --git a/internal/query/relations_test.go b/internal/query/relations_test.go index afa88eaa6..2116f01e7 100644 --- a/internal/query/relations_test.go +++ b/internal/query/relations_test.go @@ -13,6 +13,7 @@ import ( "go.infratographer.com/permissions-api/internal/iapl" "go.infratographer.com/permissions-api/internal/spicedbx" + "go.infratographer.com/permissions-api/internal/storage" "go.infratographer.com/permissions-api/internal/storage/teststore" "go.infratographer.com/permissions-api/internal/testingx" "go.infratographer.com/permissions-api/internal/types" @@ -229,7 +230,7 @@ func TestRoleUpdate(t *testing.T) { Input: gidx.MustNewID(RolePrefix), CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) { require.Error(t, res.Err) - assert.ErrorIs(t, res.Err, ErrRoleNotFound) + assert.ErrorIs(t, res.Err, storage.ErrNoRoleFound) }, }, { diff --git a/internal/storage/roles.go b/internal/storage/roles.go index d3cd24fc6..dbe1af45b 100644 --- a/internal/storage/roles.go +++ b/internal/storage/roles.go @@ -18,6 +18,7 @@ type RoleService interface { CreateRole(ctx context.Context, actorID gidx.PrefixedID, roleID gidx.PrefixedID, name string, resourceID gidx.PrefixedID) (Role, error) UpdateRole(ctx context.Context, actorID, roleID gidx.PrefixedID, name string) (Role, error) DeleteRole(ctx context.Context, roleID gidx.PrefixedID) (Role, error) + LockRoleForUpdate(ctx context.Context, roleID gidx.PrefixedID) error } // Role represents a role in the database. @@ -74,6 +75,29 @@ func (e *engine) GetRoleByID(ctx context.Context, id gidx.PrefixedID) (Role, err return role, nil } +// LockRoleForUpdate locks the provided role's record to be updated to ensure consistency. +// If no role exists an ErrNoRoleFound error is returned. +func (e *engine) LockRoleForUpdate(ctx context.Context, id gidx.PrefixedID) error { + db, err := getContextDBQuery(ctx, e) + if err != nil { + return err + } + + var one int + + err = db.QueryRowContext(ctx, `SELECT 1 FROM roles WHERE id = $1 FOR UPDATE`, id.String()).Scan(&one) + + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("%w: %s", ErrNoRoleFound, id.String()) + } + + return fmt.Errorf("%w: %s", err, id.String()) + } + + return nil +} + // GetResourceRoleByName retrieves a role from the database by the provided resource ID and role name. // If no role exists an ErrRoleNotFound error is returned. func (e *engine) GetResourceRoleByName(ctx context.Context, resourceID gidx.PrefixedID, name string) (Role, error) {