Skip to content

Commit

Permalink
lock role record before updating or deleting
Browse files Browse the repository at this point in the history
Since we're working with multiple backends, this allows us to place a
lock early and ensure a separate request doesn't conflict with an
in-flight change.

Signed-off-by: Mike Mason <[email protected]>
  • Loading branch information
mikemrm committed Jan 16, 2024
1 parent 3e0f1a8 commit ecd30fe
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 14 deletions.
47 changes: 34 additions & 13 deletions internal/query/relations.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,22 @@ func (e *engine) UpdateRole(ctx context.Context, actor, roleResource types.Resou
return types.Role{}, err
}

err = e.store.LockRoleForUpdate(dbCtx, roleResource.ID)
if err != nil {
sErr := fmt.Errorf("failed to lock role: %s: %w", roleResource.ID, err)

span.RecordError(sErr)
span.SetStatus(codes.Error, sErr.Error())

logRollbackErr(e.logger, e.store.RollbackContext(dbCtx))

return types.Role{}, err
}

role, err := e.GetRole(dbCtx, roleResource)
if err != nil {
logRollbackErr(e.logger, e.store.RollbackContext(dbCtx))

return types.Role{}, err
}

Expand Down Expand Up @@ -1010,14 +1024,30 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er

defer span.End()

var (
resActions map[types.Resource][]string
err error
)
dbCtx, err := e.store.BeginContext(ctx)
if err != nil {
return err
}

err = e.store.LockRoleForUpdate(dbCtx, roleResource.ID)
if err != nil {
sErr := fmt.Errorf("failed to lock role: %s: %w", roleResource.ID, err)

span.RecordError(sErr)
span.SetStatus(codes.Error, sErr.Error())

logRollbackErr(e.logger, e.store.RollbackContext(dbCtx))

return err
}

var resActions map[types.Resource][]string

for _, resType := range e.schemaRoleables {
resActions, err = e.listRoleResourceActions(ctx, roleResource, resType.Name)
if err != nil {
logRollbackErr(e.logger, e.store.RollbackContext(dbCtx))

return err
}

Expand All @@ -1027,10 +1057,6 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er
}
}

if len(resActions) == 0 {
return ErrRoleNotFound
}

roleType := e.namespace + "/role"

var filters []*pb.RelationshipFilter
Expand All @@ -1054,11 +1080,6 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er
}
}

dbCtx, err := e.store.BeginContext(ctx)
if err != nil {
return err
}

_, err = e.store.DeleteRole(dbCtx, roleResource.ID)
if err != nil {
logRollbackErr(e.logger, e.store.RollbackContext(dbCtx))
Expand Down
3 changes: 2 additions & 1 deletion internal/query/relations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

"go.infratographer.com/permissions-api/internal/iapl"
"go.infratographer.com/permissions-api/internal/spicedbx"
"go.infratographer.com/permissions-api/internal/storage"
"go.infratographer.com/permissions-api/internal/storage/teststore"
"go.infratographer.com/permissions-api/internal/testingx"
"go.infratographer.com/permissions-api/internal/types"
Expand Down Expand Up @@ -244,7 +245,7 @@ func TestRoleUpdate(t *testing.T) {
Input: gidx.MustNewID(RolePrefix),
CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) {
require.Error(t, res.Err)
assert.ErrorIs(t, res.Err, ErrRoleNotFound)
assert.ErrorIs(t, res.Err, storage.ErrNoRoleFound)
},
},
{
Expand Down
26 changes: 26 additions & 0 deletions internal/storage/roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type RoleService interface {
CreateRole(ctx context.Context, actorID gidx.PrefixedID, roleID gidx.PrefixedID, name string, resourceID gidx.PrefixedID) (Role, error)
UpdateRole(ctx context.Context, actorID, roleID gidx.PrefixedID, name string) (Role, error)
DeleteRole(ctx context.Context, roleID gidx.PrefixedID) (Role, error)
LockRoleForUpdate(ctx context.Context, roleID gidx.PrefixedID) error
}

// Role represents a role in the database.
Expand Down Expand Up @@ -74,6 +75,31 @@ func (e *engine) GetRoleByID(ctx context.Context, id gidx.PrefixedID) (Role, err
return role, nil
}

// LockRoleForUpdate locks the provided role's record to be updated to ensure consistency.
// If no role exists an ErrNoRoleFound error is returned.
func (e *engine) LockRoleForUpdate(ctx context.Context, id gidx.PrefixedID) error {
db, err := getContextDBQuery(ctx, e)
if err != nil {
return err
}

result, err := db.ExecContext(ctx, `SELECT 1 FROM roles WHERE id = $1 FOR UPDATE`, id.String())
if err != nil {
return err
}

rowsAffected, err := result.RowsAffected()
if err != nil {
return err
}

if rowsAffected == 0 {
return ErrNoRoleFound
}

return nil
}

// GetResourceRoleByName retrieves a role from the database by the provided resource ID and role name.
// If no role exists an ErrRoleNotFound error is returned.
func (e *engine) GetResourceRoleByName(ctx context.Context, resourceID gidx.PrefixedID, name string) (Role, error) {
Expand Down

0 comments on commit ecd30fe

Please sign in to comment.