diff --git a/internal/query/relations.go b/internal/query/relations.go index c9a7a174..58480bf0 100644 --- a/internal/query/relations.go +++ b/internal/query/relations.go @@ -408,7 +408,7 @@ func (e *engine) UpdateRole(ctx context.Context, actor, roleResource types.Resou return types.Role{}, err } - err = e.store.LockRoleForUpdate(dbCtx, roleResource.ID) + err = e.lockRoleForUpdate(dbCtx, roleResource) if err != nil { sErr := fmt.Errorf("failed to lock role: %s: %w", roleResource.ID, err) @@ -913,12 +913,18 @@ func (e *engine) ListRoles(ctx context.Context, resource types.Resource) ([]type // listRoleResourceActions returns all resources and action relations for the provided resource type to the provided role. // Note: The actions returned by this function are the spicedb relationship action. -func (e *engine) listRoleResourceActions(ctx context.Context, role types.Resource, resTypeName string) (map[types.Resource][]string, error) { - resType := e.namespace + "/" + resTypeName +func (e *engine) listRoleResourceActions(ctx context.Context, role storage.Role) ([]string, error) { + roleOwnerResource, err := e.NewResourceFromID(role.ResourceID) + if err != nil { + return nil, err + } + + resType := e.namespace + "/" + roleOwnerResource.Type roleType := e.namespace + "/role" filter := &pb.RelationshipFilter{ - ResourceType: resType, + ResourceType: resType, + OptionalResourceId: roleOwnerResource.ID.String(), OptionalSubjectFilter: &pb.SubjectFilter{ SubjectType: roleType, OptionalSubjectId: role.ID.String(), @@ -933,84 +939,47 @@ func (e *engine) listRoleResourceActions(ctx context.Context, role types.Resourc return nil, err } - resourceIDActions := make(map[gidx.PrefixedID][]string) + out := make([]string, 0, len(relationships)) for _, rel := range relationships { - resourceID, err := gidx.Parse(rel.Resource.ObjectId) - if err != nil { - return nil, err - } - - resourceIDActions[resourceID] = append(resourceIDActions[resourceID], rel.Relation) - } - - resourceActions := make(map[types.Resource][]string, len(resourceIDActions)) - - for resID, actions := range resourceIDActions { - res, err := e.NewResourceFromID(resID) - if err != nil { - return nil, err - } + action := relationToAction(rel.Relation) - resourceActions[res] = actions + out = append(out, action) } - return resourceActions, nil + return out, nil } -// GetRole gets the role with it's actions. +// GetRole gets the given role and its actions. func (e *engine) GetRole(ctx context.Context, roleResource types.Resource) (types.Role, error) { - var ( - resActions map[types.Resource][]string - err error - ) - - for _, resType := range e.schemaRoleables { - resActions, err = e.listRoleResourceActions(ctx, roleResource, resType.Name) - if err != nil { - return types.Role{}, err - } - - // roles are only ever created for a single resource, so we can break after the first one is found. - if len(resActions) != 0 { - break - } + dbRole, err := e.getStorageRole(ctx, roleResource) + if err != nil { + return types.Role{}, err } - if len(resActions) > 1 { - return types.Role{}, ErrRoleHasTooManyResources + actions, err := e.listRoleResourceActions(ctx, dbRole) + if err != nil { + return types.Role{}, err } - // returns the first resources actions. - for _, actions := range resActions { - for i, action := range actions { - actions[i] = relationToAction(action) - } - - dbRole, err := e.store.GetRoleByID(ctx, roleResource.ID) - if err != nil && !errors.Is(err, storage.ErrNoRoleFound) { - e.logger.Error("error while getting role", zap.Error(err)) - } + out := types.Role{ + ID: roleResource.ID, + Name: dbRole.Name, + Actions: actions, - return types.Role{ - ID: roleResource.ID, - Name: dbRole.Name, - Actions: actions, - - ResourceID: dbRole.ResourceID, - CreatedBy: dbRole.CreatedBy, - UpdatedBy: dbRole.UpdatedBy, - CreatedAt: dbRole.CreatedAt, - UpdatedAt: dbRole.UpdatedAt, - }, nil + ResourceID: dbRole.ResourceID, + CreatedBy: dbRole.CreatedBy, + UpdatedBy: dbRole.UpdatedBy, + CreatedAt: dbRole.CreatedAt, + UpdatedAt: dbRole.UpdatedAt, } - return types.Role{}, ErrRoleNotFound + return out, nil } // GetRoleResource gets the role's assigned resource. func (e *engine) GetRoleResource(ctx context.Context, roleResource types.Resource) (types.Resource, error) { - dbRole, err := e.store.GetRoleByID(ctx, roleResource.ID) + dbRole, err := e.getStorageRole(ctx, roleResource) if err != nil { return types.Resource{}, err } @@ -1029,7 +998,17 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er return err } - err = e.store.LockRoleForUpdate(dbCtx, roleResource.ID) + dbRole, err := e.getStorageRole(ctx, roleResource) + if err != nil { + return err + } + + roleOwnerResource, err := e.NewResourceFromID(dbRole.ResourceID) + if err != nil { + return err + } + + err = e.lockRoleForUpdate(dbCtx, roleResource) if err != nil { sErr := fmt.Errorf("failed to lock role: %s: %w", roleResource.ID, err) @@ -1041,20 +1020,11 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er return err } - var resActions map[types.Resource][]string - - for _, resType := range e.schemaRoleables { - resActions, err = e.listRoleResourceActions(ctx, roleResource, resType.Name) - if err != nil { - logRollbackErr(e.logger, e.store.RollbackContext(dbCtx)) - - return err - } + actions, err := e.listRoleResourceActions(ctx, dbRole) + if err != nil { + logRollbackErr(e.logger, e.store.RollbackContext(dbCtx)) - // roles are only ever created for a single resource, so we can break after the first one is found. - if len(resActions) != 0 { - break - } + return err } roleType := e.namespace + "/role" @@ -1069,15 +1039,16 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er }, } - for resource, relActions := range resActions { - for _, relAction := range relActions { - filters = append(filters, &pb.RelationshipFilter{ - ResourceType: e.namespace + "/" + resource.Type, - OptionalResourceId: resource.ID.String(), - OptionalRelation: relAction, - OptionalSubjectFilter: roleSubjectFilter, - }) - } + ownerType := e.namespace + "/" + roleOwnerResource.Type + ownerIDStr := roleOwnerResource.ID.String() + + for _, relAction := range actions { + filters = append(filters, &pb.RelationshipFilter{ + ResourceType: ownerType, + OptionalResourceId: ownerIDStr, + OptionalRelation: relAction, + OptionalSubjectFilter: roleSubjectFilter, + }) } _, err = e.store.DeleteRole(dbCtx, roleResource.ID) @@ -1229,3 +1200,29 @@ func (e *engine) applyUpdates(ctx context.Context, updates []*pb.RelationshipUpd return nil } + +func (e *engine) getStorageRole(ctx context.Context, roleResource types.Resource) (storage.Role, error) { + dbRole, err := e.store.GetRoleByID(ctx, roleResource.ID) + + switch { + case err == nil: + return dbRole, nil + case errors.Is(err, storage.ErrNoRoleFound): + return storage.Role{}, ErrRoleNotFound + default: + return storage.Role{}, err + } +} + +func (e *engine) lockRoleForUpdate(ctx context.Context, roleResource types.Resource) error { + err := e.store.LockRoleForUpdate(ctx, roleResource.ID) + + switch { + case err == nil: + return nil + case errors.Is(err, storage.ErrNoRoleFound): + return ErrRoleNotFound + default: + return err + } +} diff --git a/internal/query/relations_test.go b/internal/query/relations_test.go index 62860f04..c74aef5d 100644 --- a/internal/query/relations_test.go +++ b/internal/query/relations_test.go @@ -13,7 +13,6 @@ import ( "go.infratographer.com/permissions-api/internal/iapl" "go.infratographer.com/permissions-api/internal/spicedbx" - "go.infratographer.com/permissions-api/internal/storage" "go.infratographer.com/permissions-api/internal/storage/teststore" "go.infratographer.com/permissions-api/internal/testingx" "go.infratographer.com/permissions-api/internal/types" @@ -96,36 +95,47 @@ func TestCreateRoles(t *testing.T) { ctx := context.Background() e := testEngine(ctx, t, namespace, testPolicy()) - testCases := []testingx.TestCase[[]string, []types.Role]{ + testCases := []testingx.TestCase[[]string, types.Role]{ { Name: "CreateInvalidAction", Input: []string{ "bad_action", }, - CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[[]types.Role]) { + CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) { assert.Error(t, res.Err) }, }, + { + Name: "CreateNoActions", + Input: []string{}, + CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) { + expActions := []string{} + + require.NoError(t, res.Err) + + role := res.Success + assert.Equal(t, expActions, role.Actions) + }, + }, { Name: "CreateSuccess", Input: []string{ "loadbalancer_get", }, - CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[[]types.Role]) { + CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) { expActions := []string{ "loadbalancer_get", } - assert.NoError(t, res.Err) - require.Equal(t, 1, len(res.Success)) + require.NoError(t, res.Err) - role := res.Success[0] + role := res.Success assert.Equal(t, expActions, role.Actions) }, }, } - testFn := func(ctx context.Context, actions []string) testingx.TestResult[[]types.Role] { + testFn := func(ctx context.Context, actions []string) testingx.TestResult[types.Role] { tenID, err := gidx.NewID("tnntten") require.NoError(t, err) tenRes, err := e.NewResourceFromID(tenID) @@ -133,17 +143,20 @@ func TestCreateRoles(t *testing.T) { actorRes, err := e.NewResourceFromID(gidx.MustNewID("idntusr")) require.NoError(t, err) - _, err = e.CreateRole(ctx, actorRes, tenRes, "test", actions) + role, err := e.CreateRole(ctx, actorRes, tenRes, "test", actions) if err != nil { - return testingx.TestResult[[]types.Role]{ + return testingx.TestResult[types.Role]{ Err: err, } } - roles, err := e.ListRoles(ctx, tenRes) + roleResource, err := e.NewResourceFromID(role.ID) + require.NoError(t, err) - return testingx.TestResult[[]types.Role]{ - Success: roles, + obs, err := e.GetRole(ctx, roleResource) + + return testingx.TestResult[types.Role]{ + Success: obs, Err: err, } } @@ -232,7 +245,7 @@ func TestRoleUpdate(t *testing.T) { Input: gidx.MustNewID(RolePrefix), CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) { require.Error(t, res.Err) - assert.ErrorIs(t, res.Err, storage.ErrNoRoleFound) + assert.ErrorIs(t, res.Err, ErrRoleNotFound) }, }, { diff --git a/internal/query/roles_v2.go b/internal/query/roles_v2.go index e13f4282..bdfa9287 100644 --- a/internal/query/roles_v2.go +++ b/internal/query/roles_v2.go @@ -202,7 +202,7 @@ func (e *engine) GetRoleV2(ctx context.Context, role types.Resource) (types.Role } // 2. Get role info (name, created_by, etc.) from permissions API DB - dbrole, err := e.store.GetRoleByID(ctx, role.ID) + dbrole, err := e.getStorageRole(ctx, role) if err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) @@ -234,7 +234,7 @@ func (e *engine) UpdateRoleV2(ctx context.Context, actor, roleResource types.Res return types.Role{}, err } - err = e.store.LockRoleForUpdate(dbCtx, roleResource.ID) + err = e.lockRoleForUpdate(dbCtx, roleResource) if err != nil { sErr := fmt.Errorf("failed to lock role: %s: %w", roleResource.ID, err) @@ -360,7 +360,7 @@ func (e *engine) DeleteRoleV2(ctx context.Context, roleResource types.Resource) return err } - err = e.store.LockRoleForUpdate(dbCtx, roleResource.ID) + err = e.lockRoleForUpdate(dbCtx, roleResource) if err != nil { sErr := fmt.Errorf("failed to lock role: %s: %w", roleResource.ID, err) @@ -399,7 +399,7 @@ func (e *engine) DeleteRoleV2(ctx context.Context, roleResource types.Resource) return err } - dbRole, err := e.store.GetRoleByID(dbCtx, roleResource.ID) + dbRole, err := e.getStorageRole(dbCtx, roleResource) if err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) diff --git a/internal/query/roles_v2_test.go b/internal/query/roles_v2_test.go index 1aee76e6..68d0e596 100644 --- a/internal/query/roles_v2_test.go +++ b/internal/query/roles_v2_test.go @@ -11,7 +11,6 @@ import ( "google.golang.org/grpc/status" "go.infratographer.com/permissions-api/internal/iapl" - "go.infratographer.com/permissions-api/internal/storage" "go.infratographer.com/permissions-api/internal/testingx" "go.infratographer.com/permissions-api/internal/types" ) @@ -176,7 +175,7 @@ func TestGetRoleV2(t *testing.T) { Name: "GetRoleNotFound", Input: missingRes, CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) { - assert.ErrorIs(t, res.Err, storage.ErrNoRoleFound) + assert.ErrorIs(t, res.Err, ErrRoleNotFound) }, }, { @@ -324,7 +323,7 @@ func TestUpdateRolesV2(t *testing.T) { role: notfoundRes, }, CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) { - assert.ErrorIs(t, res.Err, storage.ErrNoRoleFound) + assert.ErrorIs(t, res.Err, ErrRoleNotFound) }, Sync: true, }, @@ -460,7 +459,7 @@ func TestDeleteRolesV2(t *testing.T) { Name: "DeleteRoleNotFound", Input: notfoundRes, CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[types.Role]) { - assert.ErrorIs(t, res.Err, storage.ErrNoRoleFound) + assert.ErrorIs(t, res.Err, ErrRoleNotFound) }, Sync: true, }, @@ -510,7 +509,7 @@ func TestDeleteRolesV2(t *testing.T) { assert.NoError(t, res.Err) _, err := e.GetRoleV2(ctx, roleRes) - assert.ErrorIs(t, err, storage.ErrNoRoleFound) + assert.ErrorIs(t, err, ErrRoleNotFound) }, Sync: true, }, diff --git a/internal/storage/errors.go b/internal/storage/errors.go index 61c00464..4ee5e6cb 100644 --- a/internal/storage/errors.go +++ b/internal/storage/errors.go @@ -8,7 +8,7 @@ import ( var ( // ErrNoRoleFound is returned when no role is found when retrieving or deleting a role. - ErrNoRoleFound = errors.New("role not found") + ErrNoRoleFound = errors.New("role not found in database") // ErrRoleAlreadyExists is returned when creating a role which already has an existing record. ErrRoleAlreadyExists = errors.New("role already exists")