Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for basic cursors and limits to LookupSubjects #1379

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions internal/datasets/basesubjectset.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,25 @@ func (bss BaseSubjectSet[T]) AsSlice() []T {
return values
}

// SubjectCount returns the number of subjects in the set.
func (bss BaseSubjectSet[T]) SubjectCount() int {
if bss.HasWildcard() {
return bss.ConcreteSubjectCount() + 1
}
return bss.ConcreteSubjectCount()
}

// ConcreteSubjectCount returns the number of concrete subjects in the set.
func (bss BaseSubjectSet[T]) ConcreteSubjectCount() int {
return len(bss.concrete)
}

// HasWildcard returns true if the subject set contains the specialized wildcard subject.
func (bss BaseSubjectSet[T]) HasWildcard() bool {
_, ok := bss.wildcard.get()
return ok
}

// Clone returns a clone of this subject set. Note that this is a shallow clone.
// NOTE: Should only be used when performance is not a concern.
func (bss BaseSubjectSet[T]) Clone() BaseSubjectSet[T] {
Expand Down
7 changes: 7 additions & 0 deletions internal/datasets/subjectset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,13 @@ func TestSubjectSetAdd(t *testing.T) {
expectedSet := tc.expectedSet
computedSet := existingSet.AsSlice()
testutil.RequireEquivalentSets(t, expectedSet, computedSet)

require.Equal(t, len(expectedSet), existingSet.SubjectCount())
if existingSet.HasWildcard() {
require.Equal(t, len(expectedSet), existingSet.ConcreteSubjectCount()+1)
} else {
require.Equal(t, len(expectedSet), existingSet.ConcreteSubjectCount())
}
})
}
}
Expand Down
9 changes: 9 additions & 0 deletions internal/datasets/subjectsetbyresourceid.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ func (ssr SubjectSetByResourceID) add(resourceID string, subject *v1.FoundSubjec
return ssr.subjectSetByResourceID[resourceID].Add(subject)
}

// ConcreteSubjectCount returns the number concrete subjects in the map.
func (ssr SubjectSetByResourceID) ConcreteSubjectCount() int {
count := 0
for _, subjectSet := range ssr.subjectSetByResourceID {
count += subjectSet.ConcreteSubjectCount()
}
return count
}

// AddFromRelationship adds the subject found in the given relationship to this map, indexed at
// the resource ID specified in the relationship.
func (ssr SubjectSetByResourceID) AddFromRelationship(relationship *core.RelationTuple) error {
Expand Down
3 changes: 3 additions & 0 deletions internal/datasets/subjectsetbyresourceid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func TestSubjectSetByResourceIDBasicOperations(t *testing.T) {
slices.SortFunc(asMap["seconddoc"].FoundSubjects, testutil.CmpSubjects)

require.Equal(t, expected, asMap)
require.Equal(t, 3, ssr.ConcreteSubjectCount())
}

func TestSubjectSetByResourceIDUnionWith(t *testing.T) {
Expand Down Expand Up @@ -88,6 +89,8 @@ func TestSubjectSetByResourceIDUnionWith(t *testing.T) {
},
},
}, found)

require.Equal(t, 5, ssr.ConcreteSubjectCount())
}

func TestSubjectSetByResourceIDIntersectionDifference(t *testing.T) {
Expand Down
19 changes: 19 additions & 0 deletions internal/datasets/subjectsetbytype.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,25 @@ func (s *SubjectByTypeSet) ForEachType(handler func(rr *core.RelationReference,
}
}

// ForEachTypeUntil invokes the handler for each type of ObjectAndRelation found in the set, along
// with all IDs of objects of that type, until the handler returns an error or false.
func (s *SubjectByTypeSet) ForEachTypeUntil(handler func(rr *core.RelationReference, subjects SubjectSet) (bool, error)) error {
vroldanbet marked this conversation as resolved.
Show resolved Hide resolved
for key, subjects := range s.byType {
ns, rel := tuple.MustSplitRelRef(key)
ok, err := handler(&core.RelationReference{
Namespace: ns,
Relation: rel,
}, subjects)
if err != nil {
return err
}
if !ok {
return nil
}
}
return nil
}

// Map runs the mapper function over each type of object in the set, returning a new ONRByTypeSet with
// the object type replaced by that returned by the mapper function.
func (s *SubjectByTypeSet) Map(mapper func(rr *core.RelationReference) (*core.RelationReference, error)) (*SubjectByTypeSet, error) {
Expand Down
20 changes: 20 additions & 0 deletions internal/datasets/subjectsetbytype_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,26 @@ func TestSubjectByTypeSet(t *testing.T) {
}
})
require.True(t, wasFound)

wasFound = false
err := s.ForEachTypeUntil(func(foundRR *core.RelationReference, subjects SubjectSet) (bool, error) {
objectIds := make([]string, 0, len(subjects.AsSlice()))
for _, subject := range subjects.AsSlice() {
require.Empty(t, subject.GetExcludedSubjects())
objectIds = append(objectIds, subject.SubjectId)
}

if rr.Namespace == foundRR.Namespace && rr.Relation == foundRR.Relation {
sort.Strings(objectIds)
require.Equal(t, expected, objectIds)
wasFound = true
return false, nil
}

return true, nil
})
require.True(t, wasFound)
require.NoError(t, err)
}

set := NewSubjectByTypeSet()
Expand Down
2 changes: 1 addition & 1 deletion internal/datastore/proxy/schemacaching/watchingcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ func (swc *schemaWatchCache[T]) readDefinitionsWithNames(ctx context.Context, na
}

// Find whichever trackers are cached.
remainingNames := mapz.NewSet(names...)
remainingNames := mapz.NewSetFromSlice(names)
foundDefs := make([]datastore.RevisionedDefinition[T], 0, len(names))
for _, name := range names {
tracker := swc.getTrackerForName(name)
Expand Down
4 changes: 2 additions & 2 deletions internal/developmentmembership/trackingsubjectset_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -335,8 +335,8 @@ func TestTrackingSubjectSet(t *testing.T) {
found, ok := tc.set.Get(fs.subject)
require.True(ok, "missing expected subject %s", fs.subject)

expectedExcluded := mapz.NewSet[string](fs.excludedSubjectStrings()...)
foundExcluded := mapz.NewSet[string](found.excludedSubjectStrings()...)
expectedExcluded := mapz.NewSetFromSlice(fs.excludedSubjectStrings())
foundExcluded := mapz.NewSetFromSlice(found.excludedSubjectStrings())
require.Len(expectedExcluded.Subtract(foundExcluded).AsSlice(), 0, "mismatch on excluded subjects on %s: expected: %s, found: %s", fs.subject, expectedExcluded, foundExcluded)
require.Len(foundExcluded.Subtract(expectedExcluded).AsSlice(), 0, "mismatch on excluded subjects on %s: expected: %s, found: %s", fs.subject, expectedExcluded, foundExcluded)
} else {
Expand Down
109 changes: 37 additions & 72 deletions internal/dispatch/graph/lookupresources_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"slices"
"strings"
"testing"
"time"

Expand All @@ -14,6 +15,7 @@ import (
"github.com/authzed/spicedb/internal/dispatch"
datastoremw "github.com/authzed/spicedb/internal/middleware/datastore"
"github.com/authzed/spicedb/internal/testfixtures"
"github.com/authzed/spicedb/internal/testutil"
"github.com/authzed/spicedb/pkg/genutil/mapz"
core "github.com/authzed/spicedb/pkg/proto/core/v1"
v1 "github.com/authzed/spicedb/pkg/proto/dispatch/v1"
Expand Down Expand Up @@ -333,52 +335,15 @@ func TestMaxDepthLookup(t *testing.T) {
require.Error(err)
}

func joinTuples(first []*core.RelationTuple, second []*core.RelationTuple) []*core.RelationTuple {
return append(first, second...)
}

func genTuplesWithOffset(resourceName string, relation string, subjectName string, subjectID string, offset int, number int) []*core.RelationTuple {
return genTuplesWithCaveat(resourceName, relation, subjectName, subjectID, "", nil, offset, number)
}

func genTuples(resourceName string, relation string, subjectName string, subjectID string, number int) []*core.RelationTuple {
return genTuplesWithOffset(resourceName, relation, subjectName, subjectID, 0, number)
}
type OrderedResolved []*v1.ResolvedResource

func genSubjectTuples(resourceName string, relation string, subjectName string, subjectRelation string, number int) []*core.RelationTuple {
tuples := make([]*core.RelationTuple, 0, number)
for i := 0; i < number; i++ {
tpl := &core.RelationTuple{
ResourceAndRelation: ONR(resourceName, fmt.Sprintf("%s-%d", resourceName, i), relation),
Subject: ONR(subjectName, fmt.Sprintf("%s-%d", subjectName, i), subjectRelation),
}
tuples = append(tuples, tpl)
}
return tuples
}
func (a OrderedResolved) Len() int { return len(a) }

func genTuplesWithCaveat(resourceName string, relation string, subjectName string, subjectID string, caveatName string, context map[string]any, offset int, number int) []*core.RelationTuple {
tuples := make([]*core.RelationTuple, 0, number)
for i := 0; i < number; i++ {
tpl := &core.RelationTuple{
ResourceAndRelation: ONR(resourceName, fmt.Sprintf("%s-%d", resourceName, i+offset), relation),
Subject: ONR(subjectName, subjectID, "..."),
}
if caveatName != "" {
tpl = tuple.MustWithCaveat(tpl, caveatName, context)
}
tuples = append(tuples, tpl)
}
return tuples
func (a OrderedResolved) Less(i, j int) bool {
return strings.Compare(a[i].ResourceId, a[j].ResourceId) < 0
}

func genResourceIds(resourceName string, number int) []string {
resourceIDs := make([]string, 0, number)
for i := 0; i < number; i++ {
resourceIDs = append(resourceIDs, fmt.Sprintf("%s-%d", resourceName, i))
}
return resourceIDs
}
func (a OrderedResolved) Swap(i, j int) { a[i], a[j] = a[j], a[i] }

func TestLookupResourcesOverSchemaWithCursors(t *testing.T) {
testCases := []struct {
Expand All @@ -398,13 +363,13 @@ func TestLookupResourcesOverSchemaWithCursors(t *testing.T) {
relation viewer: user
permission view = viewer + editor
}`,
joinTuples(
genTuples("document", "viewer", "user", "tom", 1510),
genTuples("document", "editor", "user", "tom", 1510),
testutil.JoinTuples(
testutil.GenTuples("document", "viewer", "user", "tom", 1510),
testutil.GenTuples("document", "editor", "user", "tom", 1510),
),
RR("document", "view"),
ONR("user", "tom", "..."),
genResourceIds("document", 1510),
testutil.GenResourceIds("document", 1510),
},
{
"basic exclusion",
Expand All @@ -415,10 +380,10 @@ func TestLookupResourcesOverSchemaWithCursors(t *testing.T) {
relation viewer: user
permission view = viewer - banned
}`,
genTuples("document", "viewer", "user", "tom", 1010),
testutil.GenTuples("document", "viewer", "user", "tom", 1010),
RR("document", "view"),
ONR("user", "tom", "..."),
genResourceIds("document", 1010),
testutil.GenResourceIds("document", 1010),
},
{
"basic intersection",
Expand All @@ -429,13 +394,13 @@ func TestLookupResourcesOverSchemaWithCursors(t *testing.T) {
relation viewer: user
permission view = viewer & editor
}`,
joinTuples(
genTuples("document", "viewer", "user", "tom", 510),
genTuples("document", "editor", "user", "tom", 510),
testutil.JoinTuples(
testutil.GenTuples("document", "viewer", "user", "tom", 510),
testutil.GenTuples("document", "editor", "user", "tom", 510),
),
RR("document", "view"),
ONR("user", "tom", "..."),
genResourceIds("document", 510),
testutil.GenResourceIds("document", 510),
},
{
"union and exclused union",
Expand All @@ -448,13 +413,13 @@ func TestLookupResourcesOverSchemaWithCursors(t *testing.T) {
permission can_view = viewer - banned
permission view = can_view + editor
}`,
joinTuples(
genTuples("document", "viewer", "user", "tom", 1310),
genTuplesWithOffset("document", "editor", "user", "tom", 1250, 1200),
testutil.JoinTuples(
testutil.GenTuples("document", "viewer", "user", "tom", 1310),
testutil.GenTuplesWithOffset("document", "editor", "user", "tom", 1250, 1200),
),
RR("document", "view"),
ONR("user", "tom", "..."),
genResourceIds("document", 2450),
testutil.GenResourceIds("document", 2450),
},
{
"basic caveats",
Expand All @@ -468,10 +433,10 @@ func TestLookupResourcesOverSchemaWithCursors(t *testing.T) {
relation viewer: user with somecaveat
permission view = viewer
}`,
genTuplesWithCaveat("document", "viewer", "user", "tom", "somecaveat", map[string]any{"somecondition": 42}, 0, 2450),
testutil.GenTuplesWithCaveat("document", "viewer", "user", "tom", "somecaveat", map[string]any{"somecondition": 42}, 0, 2450),
RR("document", "view"),
ONR("user", "tom", "..."),
genResourceIds("document", 2450),
testutil.GenResourceIds("document", 2450),
},
{
"excluded items",
Expand All @@ -482,13 +447,13 @@ func TestLookupResourcesOverSchemaWithCursors(t *testing.T) {
relation viewer: user
permission view = viewer - banned
}`,
joinTuples(
genTuples("document", "viewer", "user", "tom", 1310),
genTuplesWithOffset("document", "banned", "user", "tom", 1210, 100),
testutil.JoinTuples(
testutil.GenTuples("document", "viewer", "user", "tom", 1310),
testutil.GenTuplesWithOffset("document", "banned", "user", "tom", 1210, 100),
),
RR("document", "view"),
ONR("user", "tom", "..."),
genResourceIds("document", 1210),
testutil.GenResourceIds("document", 1210),
},
{
"basic caveats with missing field",
Expand All @@ -502,10 +467,10 @@ func TestLookupResourcesOverSchemaWithCursors(t *testing.T) {
relation viewer: user with somecaveat
permission view = viewer
}`,
genTuplesWithCaveat("document", "viewer", "user", "tom", "somecaveat", map[string]any{}, 0, 2450),
testutil.GenTuplesWithCaveat("document", "viewer", "user", "tom", "somecaveat", map[string]any{}, 0, 2450),
RR("document", "view"),
ONR("user", "tom", "..."),
genResourceIds("document", 2450),
testutil.GenResourceIds("document", 2450),
},
{
"larger arrow dispatch",
Expand All @@ -519,13 +484,13 @@ func TestLookupResourcesOverSchemaWithCursors(t *testing.T) {
relation folder: folder
permission view = folder->viewer
}`,
joinTuples(
genTuples("folder", "viewer", "user", "tom", 150),
genSubjectTuples("document", "folder", "folder", "...", 150),
testutil.JoinTuples(
testutil.GenTuples("folder", "viewer", "user", "tom", 150),
testutil.GenSubjectTuples("document", "folder", "folder", "...", 150),
),
RR("document", "view"),
ONR("user", "tom", "..."),
genResourceIds("document", 150),
testutil.GenResourceIds("document", 150),
},
{
"big",
Expand All @@ -536,13 +501,13 @@ func TestLookupResourcesOverSchemaWithCursors(t *testing.T) {
relation viewer: user
permission view = viewer + editor
}`,
joinTuples(
genTuples("document", "viewer", "user", "tom", 15100),
genTuples("document", "editor", "user", "tom", 15100),
testutil.JoinTuples(
testutil.GenTuples("document", "viewer", "user", "tom", 15100),
testutil.GenTuples("document", "editor", "user", "tom", 15100),
),
RR("document", "view"),
ONR("user", "tom", "..."),
genResourceIds("document", 15100),
testutil.GenResourceIds("document", 15100),
},
}

Expand Down
Loading
Loading