Skip to content

Commit

Permalink
Simplify role fetching logic in query engine
Browse files Browse the repository at this point in the history
Prior implementations of the query engine fetched role information
such as the owning resource ID directly from SpiceDB, as it was the
only data store available. With the introduction of CRDB, that is no
longer the case and the CRDB SQL table should be considered the
authoritative source of most role data. This commit updates the query
engine to fetch role resource owner ID and other data from the SQL DB
whenever possible, getting rid of some obscure failure modes that
occur when a role has no associated actions.

Signed-off-by: John Schaeffer <[email protected]>
  • Loading branch information
jnschaeffer committed Aug 22, 2024
1 parent 91d9a4e commit bb53015
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 98 deletions.
152 changes: 68 additions & 84 deletions internal/query/relations.go
Original file line number Diff line number Diff line change
Expand Up @@ -913,12 +913,18 @@ func (e *engine) ListRoles(ctx context.Context, resource types.Resource) ([]type

// listRoleResourceActions returns all resources and action relations for the provided resource type to the provided role.
// Note: The actions returned by this function are the spicedb relationship action.
func (e *engine) listRoleResourceActions(ctx context.Context, role types.Resource, resTypeName string) (map[types.Resource][]string, error) {
resType := e.namespace + "/" + resTypeName
func (e *engine) listRoleResourceActions(ctx context.Context, role storage.Role) ([]string, error) {
roleOwnerResource, err := e.NewResourceFromID(role.ResourceID)
if err != nil {
return nil, err
}

resType := e.namespace + "/" + roleOwnerResource.Type
roleType := e.namespace + "/role"

filter := &pb.RelationshipFilter{
ResourceType: resType,
ResourceType: resType,
OptionalResourceId: roleOwnerResource.ID.String(),
OptionalSubjectFilter: &pb.SubjectFilter{
SubjectType: roleType,
OptionalSubjectId: role.ID.String(),
Expand All @@ -933,84 +939,47 @@ func (e *engine) listRoleResourceActions(ctx context.Context, role types.Resourc
return nil, err
}

resourceIDActions := make(map[gidx.PrefixedID][]string)
out := make([]string, 0, len(relationships))

for _, rel := range relationships {
resourceID, err := gidx.Parse(rel.Resource.ObjectId)
if err != nil {
return nil, err
}

resourceIDActions[resourceID] = append(resourceIDActions[resourceID], rel.Relation)
}

resourceActions := make(map[types.Resource][]string, len(resourceIDActions))

for resID, actions := range resourceIDActions {
res, err := e.NewResourceFromID(resID)
if err != nil {
return nil, err
}
action := relationToAction(rel.Relation)

resourceActions[res] = actions
out = append(out, action)
}

return resourceActions, nil
return out, nil
}

// GetRole gets the role with it's actions.
// GetRole gets the given role and its actions.
func (e *engine) GetRole(ctx context.Context, roleResource types.Resource) (types.Role, error) {
var (
resActions map[types.Resource][]string
err error
)

for _, resType := range e.schemaRoleables {
resActions, err = e.listRoleResourceActions(ctx, roleResource, resType.Name)
if err != nil {
return types.Role{}, err
}

// roles are only ever created for a single resource, so we can break after the first one is found.
if len(resActions) != 0 {
break
}
dbRole, err := e.getStorageRole(ctx, roleResource)
if err != nil {
return types.Role{}, err
}

if len(resActions) > 1 {
return types.Role{}, ErrRoleHasTooManyResources
actions, err := e.listRoleResourceActions(ctx, dbRole)
if err != nil {
return types.Role{}, err
}

// returns the first resources actions.
for _, actions := range resActions {
for i, action := range actions {
actions[i] = relationToAction(action)
}
out := types.Role{
ID: roleResource.ID,
Name: dbRole.Name,
Actions: actions,

dbRole, err := e.store.GetRoleByID(ctx, roleResource.ID)
if err != nil && !errors.Is(err, storage.ErrNoRoleFound) {
e.logger.Error("error while getting role", zap.Error(err))
}

return types.Role{
ID: roleResource.ID,
Name: dbRole.Name,
Actions: actions,

ResourceID: dbRole.ResourceID,
CreatedBy: dbRole.CreatedBy,
UpdatedBy: dbRole.UpdatedBy,
CreatedAt: dbRole.CreatedAt,
UpdatedAt: dbRole.UpdatedAt,
}, nil
ResourceID: dbRole.ResourceID,
CreatedBy: dbRole.CreatedBy,
UpdatedBy: dbRole.UpdatedBy,
CreatedAt: dbRole.CreatedAt,
UpdatedAt: dbRole.UpdatedAt,
}

return types.Role{}, ErrRoleNotFound
return out, nil
}

// GetRoleResource gets the role's assigned resource.
func (e *engine) GetRoleResource(ctx context.Context, roleResource types.Resource) (types.Resource, error) {
dbRole, err := e.store.GetRoleByID(ctx, roleResource.ID)
dbRole, err := e.getStorageRole(ctx, roleResource)
if err != nil {
return types.Resource{}, err
}
Expand All @@ -1029,6 +998,16 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er
return err
}

dbRole, err := e.getStorageRole(ctx, roleResource)
if err != nil {
return err
}

roleOwnerResource, err := e.NewResourceFromID(dbRole.ResourceID)
if err != nil {
return err
}

err = e.store.LockRoleForUpdate(dbCtx, roleResource.ID)
if err != nil {
sErr := fmt.Errorf("failed to lock role: %s: %w", roleResource.ID, err)
Expand All @@ -1041,20 +1020,11 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er
return err
}

var resActions map[types.Resource][]string

for _, resType := range e.schemaRoleables {
resActions, err = e.listRoleResourceActions(ctx, roleResource, resType.Name)
if err != nil {
logRollbackErr(e.logger, e.store.RollbackContext(dbCtx))

return err
}
actions, err := e.listRoleResourceActions(ctx, dbRole)
if err != nil {
logRollbackErr(e.logger, e.store.RollbackContext(dbCtx))

// roles are only ever created for a single resource, so we can break after the first one is found.
if len(resActions) != 0 {
break
}
return err
}

roleType := e.namespace + "/role"
Expand All @@ -1069,15 +1039,16 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er
},
}

for resource, relActions := range resActions {
for _, relAction := range relActions {
filters = append(filters, &pb.RelationshipFilter{
ResourceType: e.namespace + "/" + resource.Type,
OptionalResourceId: resource.ID.String(),
OptionalRelation: relAction,
OptionalSubjectFilter: roleSubjectFilter,
})
}
ownerType := e.namespace + "/" + roleOwnerResource.Type
ownerIDStr := roleOwnerResource.ID.String()

for _, relAction := range actions {
filters = append(filters, &pb.RelationshipFilter{
ResourceType: ownerType,
OptionalResourceId: ownerIDStr,
OptionalRelation: relAction,
OptionalSubjectFilter: roleSubjectFilter,
})
}

_, err = e.store.DeleteRole(dbCtx, roleResource.ID)
Expand Down Expand Up @@ -1229,3 +1200,16 @@ func (e *engine) applyUpdates(ctx context.Context, updates []*pb.RelationshipUpd

return nil
}

func (e *engine) getStorageRole(ctx context.Context, roleResource types.Resource) (storage.Role, error) {
dbRole, err := e.store.GetRoleByID(ctx, roleResource.ID)

switch {
case err == nil:
return dbRole, nil
case errors.Is(err, storage.ErrNoRoleFound):
return storage.Role{}, ErrRoleNotFound
default:
return storage.Role{}, err
}
}
38 changes: 26 additions & 12 deletions internal/query/relations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,54 +96,68 @@ func TestCreateRoles(t *testing.T) {
ctx := context.Background()
e := testEngine(ctx, t, namespace, testPolicy())

testCases := []testingx.TestCase[[]string, []types.Role]{
testCases := []testingx.TestCase[[]string, types.Role]{
{
Name: "CreateInvalidAction",
Input: []string{
"bad_action",
},
CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[[]types.Role]) {
CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) {
assert.Error(t, res.Err)
},
},
{
Name: "CreateNoActions",
Input: []string{},
CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) {
expActions := []string{}

require.NoError(t, res.Err)

role := res.Success
assert.Equal(t, expActions, role.Actions)
},
},
{
Name: "CreateSuccess",
Input: []string{
"loadbalancer_get",
},
CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[[]types.Role]) {
CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) {
expActions := []string{
"loadbalancer_get",
}

assert.NoError(t, res.Err)
require.Equal(t, 1, len(res.Success))
require.NoError(t, res.Err)

role := res.Success[0]
role := res.Success
assert.Equal(t, expActions, role.Actions)
},
},
}

testFn := func(ctx context.Context, actions []string) testingx.TestResult[[]types.Role] {
testFn := func(ctx context.Context, actions []string) testingx.TestResult[types.Role] {
tenID, err := gidx.NewID("tnntten")
require.NoError(t, err)
tenRes, err := e.NewResourceFromID(tenID)
require.NoError(t, err)
actorRes, err := e.NewResourceFromID(gidx.MustNewID("idntusr"))
require.NoError(t, err)

_, err = e.CreateRole(ctx, actorRes, tenRes, "test", actions)
role, err := e.CreateRole(ctx, actorRes, tenRes, "test", actions)
if err != nil {
return testingx.TestResult[[]types.Role]{
return testingx.TestResult[types.Role]{
Err: err,
}
}

roles, err := e.ListRoles(ctx, tenRes)
roleResource, err := e.NewResourceFromID(role.ID)
require.NoError(t, err)

return testingx.TestResult[[]types.Role]{
Success: roles,
obs, err := e.GetRole(ctx, roleResource)

return testingx.TestResult[types.Role]{
Success: obs,
Err: err,
}
}
Expand Down
4 changes: 2 additions & 2 deletions internal/query/roles_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ func (e *engine) GetRoleV2(ctx context.Context, role types.Resource) (types.Role
}

// 2. Get role info (name, created_by, etc.) from permissions API DB
dbrole, err := e.store.GetRoleByID(ctx, role.ID)
dbrole, err := e.getStorageRole(ctx, role)
if err != nil {
span.RecordError(err)
span.SetStatus(codes.Error, err.Error())
Expand Down Expand Up @@ -399,7 +399,7 @@ func (e *engine) DeleteRoleV2(ctx context.Context, roleResource types.Resource)
return err
}

dbRole, err := e.store.GetRoleByID(dbCtx, roleResource.ID)
dbRole, err := e.getStorageRole(dbCtx, roleResource)
if err != nil {
span.RecordError(err)
span.SetStatus(codes.Error, err.Error())
Expand Down

0 comments on commit bb53015

Please sign in to comment.