From 31bbd1c2742fc45015f0a4dfac7a5aa72270c55d Mon Sep 17 00:00:00 2001 From: John Schaeffer Date: Thu, 22 Aug 2024 17:05:32 -0400 Subject: [PATCH] Simplify role fetching logic in query engine (#282) * Simplify role fetching logic in query engine Prior implementations of the query engine fetched role information such as the owning resource ID directly from SpiceDB, as it was the only data store available. With the introduction of CRDB, that is no longer the case and the CRDB SQL table should be considered the authoritative source of most role data. This commit updates the query engine to fetch role resource owner ID and other data from the SQL DB whenever possible, getting rid of some obscure failure modes that occur when a role has no associated actions. Signed-off-by: John Schaeffer * Fix error type in RBAC v2 tests As described. Signed-off-by: John Schaeffer * Wrap LockRoleForUpdate in a method to return non-DB errors As described. Signed-off-by: John Schaeffer * Fix incorrect error in role update test case As described. Signed-off-by: John Schaeffer --------- Signed-off-by: John Schaeffer --- internal/query/relations.go | 169 +++++++++++++++---------------- internal/query/relations_test.go | 41 +++++--- internal/query/roles_v2.go | 8 +- internal/query/roles_v2_test.go | 9 +- internal/storage/errors.go | 2 +- 5 files changed, 119 insertions(+), 110 deletions(-) diff --git a/internal/query/relations.go b/internal/query/relations.go index c9a7a174..58480bf0 100644 --- a/internal/query/relations.go +++ b/internal/query/relations.go @@ -408,7 +408,7 @@ func (e *engine) UpdateRole(ctx context.Context, actor, roleResource types.Resou return types.Role{}, err } - err = e.store.LockRoleForUpdate(dbCtx, roleResource.ID) + err = e.lockRoleForUpdate(dbCtx, roleResource) if err != nil { sErr := fmt.Errorf("failed to lock role: %s: %w", roleResource.ID, err) @@ -913,12 +913,18 @@ func (e *engine) ListRoles(ctx context.Context, resource types.Resource) ([]type // listRoleResourceActions returns all resources and action relations for the provided resource type to the provided role. // Note: The actions returned by this function are the spicedb relationship action. -func (e *engine) listRoleResourceActions(ctx context.Context, role types.Resource, resTypeName string) (map[types.Resource][]string, error) { - resType := e.namespace + "/" + resTypeName +func (e *engine) listRoleResourceActions(ctx context.Context, role storage.Role) ([]string, error) { + roleOwnerResource, err := e.NewResourceFromID(role.ResourceID) + if err != nil { + return nil, err + } + + resType := e.namespace + "/" + roleOwnerResource.Type roleType := e.namespace + "/role" filter := &pb.RelationshipFilter{ - ResourceType: resType, + ResourceType: resType, + OptionalResourceId: roleOwnerResource.ID.String(), OptionalSubjectFilter: &pb.SubjectFilter{ SubjectType: roleType, OptionalSubjectId: role.ID.String(), @@ -933,84 +939,47 @@ func (e *engine) listRoleResourceActions(ctx context.Context, role types.Resourc return nil, err } - resourceIDActions := make(map[gidx.PrefixedID][]string) + out := make([]string, 0, len(relationships)) for _, rel := range relationships { - resourceID, err := gidx.Parse(rel.Resource.ObjectId) - if err != nil { - return nil, err - } - - resourceIDActions[resourceID] = append(resourceIDActions[resourceID], rel.Relation) - } - - resourceActions := make(map[types.Resource][]string, len(resourceIDActions)) - - for resID, actions := range resourceIDActions { - res, err := e.NewResourceFromID(resID) - if err != nil { - return nil, err - } + action := relationToAction(rel.Relation) - resourceActions[res] = actions + out = append(out, action) } - return resourceActions, nil + return out, nil } -// GetRole gets the role with it's actions. +// GetRole gets the given role and its actions. func (e *engine) GetRole(ctx context.Context, roleResource types.Resource) (types.Role, error) { - var ( - resActions map[types.Resource][]string - err error - ) - - for _, resType := range e.schemaRoleables { - resActions, err = e.listRoleResourceActions(ctx, roleResource, resType.Name) - if err != nil { - return types.Role{}, err - } - - // roles are only ever created for a single resource, so we can break after the first one is found. - if len(resActions) != 0 { - break - } + dbRole, err := e.getStorageRole(ctx, roleResource) + if err != nil { + return types.Role{}, err } - if len(resActions) > 1 { - return types.Role{}, ErrRoleHasTooManyResources + actions, err := e.listRoleResourceActions(ctx, dbRole) + if err != nil { + return types.Role{}, err } - // returns the first resources actions. - for _, actions := range resActions { - for i, action := range actions { - actions[i] = relationToAction(action) - } - - dbRole, err := e.store.GetRoleByID(ctx, roleResource.ID) - if err != nil && !errors.Is(err, storage.ErrNoRoleFound) { - e.logger.Error("error while getting role", zap.Error(err)) - } + out := types.Role{ + ID: roleResource.ID, + Name: dbRole.Name, + Actions: actions, - return types.Role{ - ID: roleResource.ID, - Name: dbRole.Name, - Actions: actions, - - ResourceID: dbRole.ResourceID, - CreatedBy: dbRole.CreatedBy, - UpdatedBy: dbRole.UpdatedBy, - CreatedAt: dbRole.CreatedAt, - UpdatedAt: dbRole.UpdatedAt, - }, nil + ResourceID: dbRole.ResourceID, + CreatedBy: dbRole.CreatedBy, + UpdatedBy: dbRole.UpdatedBy, + CreatedAt: dbRole.CreatedAt, + UpdatedAt: dbRole.UpdatedAt, } - return types.Role{}, ErrRoleNotFound + return out, nil } // GetRoleResource gets the role's assigned resource. func (e *engine) GetRoleResource(ctx context.Context, roleResource types.Resource) (types.Resource, error) { - dbRole, err := e.store.GetRoleByID(ctx, roleResource.ID) + dbRole, err := e.getStorageRole(ctx, roleResource) if err != nil { return types.Resource{}, err } @@ -1029,7 +998,17 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er return err } - err = e.store.LockRoleForUpdate(dbCtx, roleResource.ID) + dbRole, err := e.getStorageRole(ctx, roleResource) + if err != nil { + return err + } + + roleOwnerResource, err := e.NewResourceFromID(dbRole.ResourceID) + if err != nil { + return err + } + + err = e.lockRoleForUpdate(dbCtx, roleResource) if err != nil { sErr := fmt.Errorf("failed to lock role: %s: %w", roleResource.ID, err) @@ -1041,20 +1020,11 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er return err } - var resActions map[types.Resource][]string - - for _, resType := range e.schemaRoleables { - resActions, err = e.listRoleResourceActions(ctx, roleResource, resType.Name) - if err != nil { - logRollbackErr(e.logger, e.store.RollbackContext(dbCtx)) - - return err - } + actions, err := e.listRoleResourceActions(ctx, dbRole) + if err != nil { + logRollbackErr(e.logger, e.store.RollbackContext(dbCtx)) - // roles are only ever created for a single resource, so we can break after the first one is found. - if len(resActions) != 0 { - break - } + return err } roleType := e.namespace + "/role" @@ -1069,15 +1039,16 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er }, } - for resource, relActions := range resActions { - for _, relAction := range relActions { - filters = append(filters, &pb.RelationshipFilter{ - ResourceType: e.namespace + "/" + resource.Type, - OptionalResourceId: resource.ID.String(), - OptionalRelation: relAction, - OptionalSubjectFilter: roleSubjectFilter, - }) - } + ownerType := e.namespace + "/" + roleOwnerResource.Type + ownerIDStr := roleOwnerResource.ID.String() + + for _, relAction := range actions { + filters = append(filters, &pb.RelationshipFilter{ + ResourceType: ownerType, + OptionalResourceId: ownerIDStr, + OptionalRelation: relAction, + OptionalSubjectFilter: roleSubjectFilter, + }) } _, err = e.store.DeleteRole(dbCtx, roleResource.ID) @@ -1229,3 +1200,29 @@ func (e *engine) applyUpdates(ctx context.Context, updates []*pb.RelationshipUpd return nil } + +func (e *engine) getStorageRole(ctx context.Context, roleResource types.Resource) (storage.Role, error) { + dbRole, err := e.store.GetRoleByID(ctx, roleResource.ID) + + switch { + case err == nil: + return dbRole, nil + case errors.Is(err, storage.ErrNoRoleFound): + return storage.Role{}, ErrRoleNotFound + default: + return storage.Role{}, err + } +} + +func (e *engine) lockRoleForUpdate(ctx context.Context, roleResource types.Resource) error { + err := e.store.LockRoleForUpdate(ctx, roleResource.ID) + + switch { + case err == nil: + return nil + case errors.Is(err, storage.ErrNoRoleFound): + return ErrRoleNotFound + default: + return err + } +} diff --git a/internal/query/relations_test.go b/internal/query/relations_test.go index 62860f04..c74aef5d 100644 --- a/internal/query/relations_test.go +++ b/internal/query/relations_test.go @@ -13,7 +13,6 @@ import ( "go.infratographer.com/permissions-api/internal/iapl" "go.infratographer.com/permissions-api/internal/spicedbx" - "go.infratographer.com/permissions-api/internal/storage" "go.infratographer.com/permissions-api/internal/storage/teststore" "go.infratographer.com/permissions-api/internal/testingx" "go.infratographer.com/permissions-api/internal/types" @@ -96,36 +95,47 @@ func TestCreateRoles(t *testing.T) { ctx := context.Background() e := testEngine(ctx, t, namespace, testPolicy()) - testCases := []testingx.TestCase[[]string, []types.Role]{ + testCases := []testingx.TestCase[[]string, types.Role]{ { Name: "CreateInvalidAction", Input: []string{ "bad_action", }, - CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[[]types.Role]) { + CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) { assert.Error(t, res.Err) }, }, + { + Name: "CreateNoActions", + Input: []string{}, + CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) { + expActions := []string{} + + require.NoError(t, res.Err) + + role := res.Success + assert.Equal(t, expActions, role.Actions) + }, + }, { Name: "CreateSuccess", Input: []string{ "loadbalancer_get", }, - CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[[]types.Role]) { + CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) { expActions := []string{ "loadbalancer_get", } - assert.NoError(t, res.Err) - require.Equal(t, 1, len(res.Success)) + require.NoError(t, res.Err) - role := res.Success[0] + role := res.Success assert.Equal(t, expActions, role.Actions) }, }, } - testFn := func(ctx context.Context, actions []string) testingx.TestResult[[]types.Role] { + testFn := func(ctx context.Context, actions []string) testingx.TestResult[types.Role] { tenID, err := gidx.NewID("tnntten") require.NoError(t, err) tenRes, err := e.NewResourceFromID(tenID) @@ -133,17 +143,20 @@ func TestCreateRoles(t *testing.T) { actorRes, err := e.NewResourceFromID(gidx.MustNewID("idntusr")) require.NoError(t, err) - _, err = e.CreateRole(ctx, actorRes, tenRes, "test", actions) + role, err := e.CreateRole(ctx, actorRes, tenRes, "test", actions) if err != nil { - return testingx.TestResult[[]types.Role]{ + return testingx.TestResult[types.Role]{ Err: err, } } - roles, err := e.ListRoles(ctx, tenRes) + roleResource, err := e.NewResourceFromID(role.ID) + require.NoError(t, err) - return testingx.TestResult[[]types.Role]{ - Success: roles, + obs, err := e.GetRole(ctx, roleResource) + + return testingx.TestResult[types.Role]{ + Success: obs, Err: err, } } @@ -232,7 +245,7 @@ func TestRoleUpdate(t *testing.T) { Input: gidx.MustNewID(RolePrefix), CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) { require.Error(t, res.Err) - assert.ErrorIs(t, res.Err, storage.ErrNoRoleFound) + assert.ErrorIs(t, res.Err, ErrRoleNotFound) }, }, { diff --git a/internal/query/roles_v2.go b/internal/query/roles_v2.go index e13f4282..bdfa9287 100644 --- a/internal/query/roles_v2.go +++ b/internal/query/roles_v2.go @@ -202,7 +202,7 @@ func (e *engine) GetRoleV2(ctx context.Context, role types.Resource) (types.Role } // 2. Get role info (name, created_by, etc.) from permissions API DB - dbrole, err := e.store.GetRoleByID(ctx, role.ID) + dbrole, err := e.getStorageRole(ctx, role) if err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) @@ -234,7 +234,7 @@ func (e *engine) UpdateRoleV2(ctx context.Context, actor, roleResource types.Res return types.Role{}, err } - err = e.store.LockRoleForUpdate(dbCtx, roleResource.ID) + err = e.lockRoleForUpdate(dbCtx, roleResource) if err != nil { sErr := fmt.Errorf("failed to lock role: %s: %w", roleResource.ID, err) @@ -360,7 +360,7 @@ func (e *engine) DeleteRoleV2(ctx context.Context, roleResource types.Resource) return err } - err = e.store.LockRoleForUpdate(dbCtx, roleResource.ID) + err = e.lockRoleForUpdate(dbCtx, roleResource) if err != nil { sErr := fmt.Errorf("failed to lock role: %s: %w", roleResource.ID, err) @@ -399,7 +399,7 @@ func (e *engine) DeleteRoleV2(ctx context.Context, roleResource types.Resource) return err } - dbRole, err := e.store.GetRoleByID(dbCtx, roleResource.ID) + dbRole, err := e.getStorageRole(dbCtx, roleResource) if err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) diff --git a/internal/query/roles_v2_test.go b/internal/query/roles_v2_test.go index 1aee76e6..68d0e596 100644 --- a/internal/query/roles_v2_test.go +++ b/internal/query/roles_v2_test.go @@ -11,7 +11,6 @@ import ( "google.golang.org/grpc/status" "go.infratographer.com/permissions-api/internal/iapl" - "go.infratographer.com/permissions-api/internal/storage" "go.infratographer.com/permissions-api/internal/testingx" "go.infratographer.com/permissions-api/internal/types" ) @@ -176,7 +175,7 @@ func TestGetRoleV2(t *testing.T) { Name: "GetRoleNotFound", Input: missingRes, CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) { - assert.ErrorIs(t, res.Err, storage.ErrNoRoleFound) + assert.ErrorIs(t, res.Err, ErrRoleNotFound) }, }, { @@ -324,7 +323,7 @@ func TestUpdateRolesV2(t *testing.T) { role: notfoundRes, }, CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) { - assert.ErrorIs(t, res.Err, storage.ErrNoRoleFound) + assert.ErrorIs(t, res.Err, ErrRoleNotFound) }, Sync: true, }, @@ -460,7 +459,7 @@ func TestDeleteRolesV2(t *testing.T) { Name: "DeleteRoleNotFound", Input: notfoundRes, CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) { - assert.ErrorIs(t, res.Err, storage.ErrNoRoleFound) + assert.ErrorIs(t, res.Err, ErrRoleNotFound) }, Sync: true, }, @@ -510,7 +509,7 @@ func TestDeleteRolesV2(t *testing.T) { assert.NoError(t, res.Err) _, err := e.GetRoleV2(ctx, roleRes) - assert.ErrorIs(t, err, storage.ErrNoRoleFound) + assert.ErrorIs(t, err, ErrRoleNotFound) }, Sync: true, }, diff --git a/internal/storage/errors.go b/internal/storage/errors.go index 61c00464..4ee5e6cb 100644 --- a/internal/storage/errors.go +++ b/internal/storage/errors.go @@ -8,7 +8,7 @@ import ( var ( // ErrNoRoleFound is returned when no role is found when retrieving or deleting a role. - ErrNoRoleFound = errors.New("role not found") + ErrNoRoleFound = errors.New("role not found in database") // ErrRoleAlreadyExists is returned when creating a role which already has an existing record. ErrRoleAlreadyExists = errors.New("role already exists")