Skip to content

Commit

Permalink
Merge pull request #1708 from Permify/refactor/remove-continuous-toke…
Browse files Browse the repository at this point in the history
…n-handling

refactor: remove unnecessary continuous token handling in subject filter
  • Loading branch information
tolgaOzen authored Oct 19, 2024
2 parents 3c7e9a9 + 754478c commit f90d035
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 25 deletions.
30 changes: 5 additions & 25 deletions internal/engines/lookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ func (engine *LookupEngine) LookupSubject(ctx context.Context, request *base.Per
}

// If '<>' was found, query all subjects with exclusions if provided
if excludedIds != nil || slices.Contains(ids, "<>") {
if excludedIds != nil || slices.Contains(ids, ALL) {
resp, pct, err := engine.dataReader.QueryUniqueSubjectReferences(
ctx,
request.GetTenantId(),
Expand All @@ -243,30 +243,10 @@ func (engine *LookupEngine) LookupSubject(ctx context.Context, request *base.Per
// Sort the IDs
sort.Strings(ids)

// Initialize the start index as a string (to match token format)
start := ""

// Handle continuous token if present
if request.GetContinuousToken() != "" {
var t database.ContinuousToken
t, err := utils.EncodedContinuousToken{Value: request.GetContinuousToken()}.Decode()
if err != nil {
return nil, err
}
start = t.(utils.ContinuousToken).Value
}

// Find the start index based on the continuous token
// Since the incoming 'ids' are already filtered based on the continuous token,
// there is no need to decode or handle the continuous token manually.
// The startIndex is initialized to 0.
startIndex := 0
if start != "" {
// Locate the position in the sorted list where the ID equals or exceeds the token value
for i, id := range ids {
if id >= start {
startIndex = i
break
}
}
}

// Convert size to int for compatibility with startIndex
pageSize := int(size)
Expand All @@ -284,7 +264,7 @@ func (engine *LookupEngine) LookupSubject(ctx context.Context, request *base.Per
ct = ""
}

// Return the paginated and sorted list of IDs
// Return the paginated list of IDs
return &base.PermissionLookupSubjectResponse{
SubjectIds: ids[startIndex:end], // Slice the IDs based on pagination
ContinuousToken: ct, // Return the next continuous token
Expand Down
175 changes: 175 additions & 0 deletions internal/engines/lookup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4394,5 +4394,180 @@ var _ = Describe("lookup-entity-engine", func() {
}
}
})

It("Weekday Sample: Case 3 pagination", func() {
db, err := factories.DatabaseFactory(
config.Database{
Engine: "memory",
},
)

Expect(err).ShouldNot(HaveOccurred())

conf, err := newSchema(workdaySchemaSubjectFilter)
Expect(err).ShouldNot(HaveOccurred())

schemaWriter := factories.SchemaWriterFactory(db)
err = schemaWriter.WriteSchema(context.Background(), conf)

Expect(err).ShouldNot(HaveOccurred())

type filter struct {
subjectReference string
entity string
assertions map[string][]string
}

tests := struct {
relationships []string
attributes []string
filters []filter
}{
relationships: []string{
"organization:1#member@user:1",
"repository:4#organization@organization:1",

"repository:3#organization@organization:1",
"repository:1#organization@organization:1",

"organization:2#member@user:1",
"organization:2#member@user:3",
"organization:5#member@user:2",
"organization:5#member@user:5",

"repository:12#member@user:1",
"repository:12#member@user:2",

"repository:82#organization@organization:43",

"organization:43#member@user:90",
"organization:43#member@user:54",
},
attributes: []string{
"repository:1$is_public|boolean:true",
"repository:2$is_public|boolean:false",
"repository:3$is_public|boolean:true",
"repository:12$is_public|boolean:true",
"repository:82$is_public|boolean:true",

"organization:1$balance|integer:4000",
"organization:2$balance|integer:6000",

"organization:43$balance|integer:6000",
},
filters: []filter{
{
subjectReference: "user",
entity: "repository:1",
assertions: map[string][]string{
"up": {"2", "3", "5", "54", "90"},
},
},
{
subjectReference: "user",
entity: "repository:3",
assertions: map[string][]string{
"up": {"2", "3", "5", "54", "90"},
},
},
{
subjectReference: "user",
entity: "repository:12",
assertions: map[string][]string{
"deploy": {"3", "5", "54", "90"},
},
},
{
subjectReference: "user",
entity: "repository:82",
assertions: map[string][]string{
"check": {"1", "2", "3", "5"},
},
},
},
}

// filters

schemaReader := factories.SchemaReaderFactory(db)
dataReader := factories.DataReaderFactory(db)
dataWriter := factories.DataWriterFactory(db)

checkEngine := NewCheckEngine(schemaReader, dataReader)

lookupEngine := NewLookupEngine(
checkEngine,
schemaReader,
dataReader,
)

invoker := invoke.NewDirectInvoker(
schemaReader,
dataReader,
checkEngine,
nil,
lookupEngine,
nil,
)

checkEngine.SetInvoker(invoker)

var tuples []*base.Tuple

for _, relationship := range tests.relationships {
t, err := tuple.Tuple(relationship)
Expect(err).ShouldNot(HaveOccurred())
tuples = append(tuples, t)
}

var attributes []*base.Attribute

for _, attr := range tests.attributes {
a, err := attribute.Attribute(attr)
Expect(err).ShouldNot(HaveOccurred())
attributes = append(attributes, a)
}

_, err = dataWriter.Write(context.Background(), "t1", database.NewTupleCollection(tuples...), database.NewAttributeCollection(attributes...))
Expect(err).ShouldNot(HaveOccurred())

for _, filter := range tests.filters {
entity, err := tuple.E(filter.entity)
Expect(err).ShouldNot(HaveOccurred())

for permission, res := range filter.assertions {

ct := ""

var ids []string

for {
response, err := invoker.LookupSubject(context.Background(), &base.PermissionLookupSubjectRequest{
TenantId: "t1",
SubjectReference: tuple.RelationReference(filter.subjectReference),
Entity: entity,
Permission: permission,
Metadata: &base.PermissionLookupSubjectRequestMetadata{
SnapToken: token.NewNoopToken().Encode().String(),
SchemaVersion: "",
},
ContinuousToken: ct,
PageSize: 2,
})
Expect(err).ShouldNot(HaveOccurred())

ids = append(ids, response.GetSubjectIds()...)

ct = response.GetContinuousToken()

if ct == "" {
break
}
}

Expect(ids).Should(Equal(res))
}
}
})
})
})

0 comments on commit f90d035

Please sign in to comment.