Skip to content

Commit

Permalink
lock role record before updating or deleting
Browse files Browse the repository at this point in the history
Since we're working with multiple backends, this allows us to place a
lock early and ensure a separate request doesn't conflict with an
in-flight change.

Signed-off-by: Mike Mason <[email protected]>
  • Loading branch information
mikemrm committed Jan 11, 2024
1 parent 199915c commit efd9772
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 14 deletions.
47 changes: 34 additions & 13 deletions internal/query/relations.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,22 @@ func (e *engine) UpdateRole(ctx context.Context, actor, roleResource types.Resou
return types.Role{}, err
}

err = e.store.LockRoleForUpdate(dbCtx, roleResource.ID)
if err != nil {
sErr := fmt.Errorf("failed to lock role: %s: %w", roleResource.ID, err)

span.RecordError(sErr)
span.SetStatus(codes.Error, sErr.Error())

logRollbackErr(e.logger, e.store.RollbackContext(dbCtx))

return types.Role{}, err
}

role, err := e.GetRole(dbCtx, roleResource)
if err != nil {
logRollbackErr(e.logger, e.store.RollbackContext(dbCtx))

return types.Role{}, err
}

Expand Down Expand Up @@ -1003,14 +1017,30 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er

defer span.End()

var (
resActions map[types.Resource][]string
err error
)
dbCtx, err := e.store.BeginContext(ctx)
if err != nil {
return err
}

err = e.store.LockRoleForUpdate(dbCtx, roleResource.ID)
if err != nil {
sErr := fmt.Errorf("failed to lock role: %s: %w", roleResource.ID, err)

span.RecordError(sErr)
span.SetStatus(codes.Error, sErr.Error())

logRollbackErr(e.logger, e.store.RollbackContext(dbCtx))

return err
}

var resActions map[types.Resource][]string

for _, resType := range e.schemaRoleables {
resActions, err = e.listRoleResourceActions(ctx, roleResource, resType.Name)
if err != nil {
logRollbackErr(e.logger, e.store.RollbackContext(dbCtx))

return err
}

Expand All @@ -1020,10 +1050,6 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er
}
}

if len(resActions) == 0 {
return ErrRoleNotFound
}

roleType := e.namespace + "/role"

var filters []*pb.RelationshipFilter
Expand All @@ -1047,11 +1073,6 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er
}
}

dbCtx, err := e.store.BeginContext(ctx)
if err != nil {
return err
}

_, err = e.store.DeleteRole(dbCtx, roleResource.ID)
if err != nil {
logRollbackErr(e.logger, e.store.RollbackContext(dbCtx))
Expand Down
3 changes: 2 additions & 1 deletion internal/query/relations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"go.infratographer.com/permissions-api/internal/iapl"
"go.infratographer.com/permissions-api/internal/spicedbx"
"go.infratographer.com/permissions-api/internal/storage"
"go.infratographer.com/permissions-api/internal/storage/teststore"
"go.infratographer.com/permissions-api/internal/testingx"
"go.infratographer.com/permissions-api/internal/types"
Expand Down Expand Up @@ -229,7 +230,7 @@ func TestRoleUpdate(t *testing.T) {
Input: gidx.MustNewID(RolePrefix),
CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) {
require.Error(t, res.Err)
assert.ErrorIs(t, res.Err, ErrRoleNotFound)
assert.ErrorIs(t, res.Err, storage.ErrNoRoleFound)
},
},
{
Expand Down
24 changes: 24 additions & 0 deletions internal/storage/roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type RoleService interface {
CreateRole(ctx context.Context, actorID gidx.PrefixedID, roleID gidx.PrefixedID, name string, resourceID gidx.PrefixedID) (Role, error)
UpdateRole(ctx context.Context, actorID, roleID gidx.PrefixedID, name string) (Role, error)
DeleteRole(ctx context.Context, roleID gidx.PrefixedID) (Role, error)
LockRoleForUpdate(ctx context.Context, roleID gidx.PrefixedID) error
}

// Role represents a role in the database.
Expand Down Expand Up @@ -74,6 +75,29 @@ func (e *engine) GetRoleByID(ctx context.Context, id gidx.PrefixedID) (Role, err
return role, nil
}

// LockRoleForUpdate locks the provided role's record to be updated to ensure consistency.
// If no role exists an ErrNoRoleFound error is returned.
func (e *engine) LockRoleForUpdate(ctx context.Context, id gidx.PrefixedID) error {
db, err := getContextDBQuery(ctx, e)
if err != nil {
return err
}

var one int

err = db.QueryRowContext(ctx, `SELECT 1 FROM roles WHERE id = $1 FOR UPDATE`, id.String()).Scan(&one)

if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return fmt.Errorf("%w: %s", ErrNoRoleFound, id.String())
}

return fmt.Errorf("%w: %s", err, id.String())
}

return nil
}

// GetResourceRoleByName retrieves a role from the database by the provided resource ID and role name.
// If no role exists an ErrRoleNotFound error is returned.
func (e *engine) GetResourceRoleByName(ctx context.Context, resourceID gidx.PrefixedID, name string) (Role, error) {
Expand Down

0 comments on commit efd9772

Please sign in to comment.