Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure cursored LRv2 calls are dispatched to LRv2 #2040

Merged
merged 1 commit into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/middleware/consistency/consistency_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func TestAddRevisionToContextWithCursor(t *testing.T) {
ds.On("RevisionFromString", optimized.String()).Return(optimized, nil).Once()

// cursor is at `optimized`
cursor, err := cursor.EncodeFromDispatchCursor(&dispatch.Cursor{}, "somehash", optimized)
cursor, err := cursor.EncodeFromDispatchCursor(&dispatch.Cursor{}, "somehash", optimized, nil)
require.NoError(err)

// revision in context is at `exact`
Expand Down
25 changes: 21 additions & 4 deletions internal/services/v1/permissions.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,22 @@ func TranslateExpansionTree(node *core.RelationTupleTreeNode) *v1.PermissionRela
}
}

const lrv2CursorFlag = "lrv2"

func (ps *permissionServer) LookupResources(req *v1.LookupResourcesRequest, resp v1.PermissionsService_LookupResourcesServer) error {
// If the cursor specifies that this is a LookupResources2 request, then that implementation must
// be used.
if req.OptionalCursor != nil {
_, ok, err := cursor.GetCursorFlag(req.OptionalCursor, lrv2CursorFlag)
if err != nil {
return ps.rewriteError(resp.Context(), err)
}

if ok {
return ps.lookupResources2(req, resp)
}
}

if ps.config.UseExperimentalLookupResources2 {
return ps.lookupResources2(req, resp)
}
Expand Down Expand Up @@ -445,7 +460,7 @@ func (ps *permissionServer) lookupResources1(req *v1.LookupResourcesRequest, res
}

if req.OptionalCursor != nil {
decodedCursor, err := cursor.DecodeToDispatchCursor(req.OptionalCursor, lrRequestHash)
decodedCursor, _, err := cursor.DecodeToDispatchCursor(req.OptionalCursor, lrRequestHash)
if err != nil {
return ps.rewriteError(ctx, err)
}
Expand Down Expand Up @@ -476,7 +491,7 @@ func (ps *permissionServer) lookupResources1(req *v1.LookupResourcesRequest, res
alreadyPublishedPermissionedResourceIds[found.ResourceId] = struct{}{}
}

encodedCursor, err := cursor.EncodeFromDispatchCursor(result.AfterResponseCursor, lrRequestHash, atRevision)
encodedCursor, err := cursor.EncodeFromDispatchCursor(result.AfterResponseCursor, lrRequestHash, atRevision, nil)
if err != nil {
return ps.rewriteError(ctx, err)
}
Expand Down Expand Up @@ -573,7 +588,7 @@ func (ps *permissionServer) lookupResources2(req *v1.LookupResourcesRequest, res
}

if req.OptionalCursor != nil {
decodedCursor, err := cursor.DecodeToDispatchCursor(req.OptionalCursor, lrRequestHash)
decodedCursor, _, err := cursor.DecodeToDispatchCursor(req.OptionalCursor, lrRequestHash)
if err != nil {
return ps.rewriteError(ctx, err)
}
Expand Down Expand Up @@ -605,7 +620,9 @@ func (ps *permissionServer) lookupResources2(req *v1.LookupResourcesRequest, res
alreadyPublishedPermissionedResourceIds[found.ResourceId] = struct{}{}
}

encodedCursor, err := cursor.EncodeFromDispatchCursor(result.AfterResponseCursor, lrRequestHash, atRevision)
encodedCursor, err := cursor.EncodeFromDispatchCursor(result.AfterResponseCursor, lrRequestHash, atRevision, map[string]string{
lrv2CursorFlag: "1",
})
if err != nil {
return ps.rewriteError(ctx, err)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/services/v1/relationships.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ func (ps *permissionServer) ReadRelationships(req *v1.ReadRelationshipsRequest,
}

if req.OptionalCursor != nil {
decodedCursor, err := cursor.DecodeToDispatchCursor(req.OptionalCursor, rrRequestHash)
decodedCursor, _, err := cursor.DecodeToDispatchCursor(req.OptionalCursor, rrRequestHash)
if err != nil {
return ps.rewriteError(ctx, err)
}
Expand Down Expand Up @@ -249,7 +249,7 @@ func (ps *permissionServer) ReadRelationships(req *v1.ReadRelationshipsRequest,
}

dispatchCursor.Sections[0] = tuple.StringWithoutCaveat(tpl)
encodedCursor, err := cursor.EncodeFromDispatchCursor(dispatchCursor, rrRequestHash, atRevision)
encodedCursor, err := cursor.EncodeFromDispatchCursor(dispatchCursor, rrRequestHash, atRevision, nil)
if err != nil {
return ps.rewriteError(ctx, err)
}
Expand Down
29 changes: 23 additions & 6 deletions pkg/cursor/cursor.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func Decode(encoded *v1.Cursor) (*impl.DecodedCursor, error) {
// consumption, including the provided call context to ensure the API cursor reflects the calling
// API method. The call hash should contain all the parameters of the calling API function,
// as well as its revision and name.
func EncodeFromDispatchCursor(dispatchCursor *dispatch.Cursor, callAndParameterHash string, revision datastore.Revision) (*v1.Cursor, error) {
func EncodeFromDispatchCursor(dispatchCursor *dispatch.Cursor, callAndParameterHash string, revision datastore.Revision, flags map[string]string) (*v1.Cursor, error) {
if dispatchCursor == nil {
return nil, spiceerrors.MustBugf("got nil dispatch cursor")
}
Expand All @@ -60,34 +60,51 @@ func EncodeFromDispatchCursor(dispatchCursor *dispatch.Cursor, callAndParameterH
DispatchVersion: dispatchCursor.DispatchVersion,
Sections: dispatchCursor.Sections,
CallAndParametersHash: callAndParameterHash,
Flags: flags,
},
},
})
}

// GetCursorFlag retrieves a flag from an encoded API cursor, if any.
func GetCursorFlag(encoded *v1.Cursor, flagName string) (string, bool, error) {
decoded, err := Decode(encoded)
if err != nil {
return "", false, err
}

v1decoded := decoded.GetV1()
if v1decoded == nil {
return "", false, NewInvalidCursorErr(ErrNilCursor)
}

value, ok := v1decoded.Flags[flagName]
return value, ok, nil
}

// DecodeToDispatchCursor decodes an encoded API cursor into an internal dispatching cursor,
// ensuring that the provided call context matches that encoded into the API cursor. The call
// hash should contain all the parameters of the calling API function, as well as its revision
// and name.
func DecodeToDispatchCursor(encoded *v1.Cursor, callAndParameterHash string) (*dispatch.Cursor, error) {
func DecodeToDispatchCursor(encoded *v1.Cursor, callAndParameterHash string) (*dispatch.Cursor, map[string]string, error) {
decoded, err := Decode(encoded)
if err != nil {
return nil, err
return nil, nil, err
}

v1decoded := decoded.GetV1()
if v1decoded == nil {
return nil, NewInvalidCursorErr(ErrNilCursor)
return nil, nil, NewInvalidCursorErr(ErrNilCursor)
}

if v1decoded.CallAndParametersHash != callAndParameterHash {
return nil, NewInvalidCursorErr(ErrHashMismatch)
return nil, nil, NewInvalidCursorErr(ErrHashMismatch)
}

return &dispatch.Cursor{
DispatchVersion: v1decoded.DispatchVersion,
Sections: v1decoded.Sections,
}, nil
}, v1decoded.Flags, nil
vroldanbet marked this conversation as resolved.
Show resolved Hide resolved
}

// DecodeToDispatchRevision decodes an encoded API cursor into an internal dispatch revision.
Expand Down
7 changes: 4 additions & 3 deletions pkg/cursor/cursor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,14 @@ func TestEncodeDecode(t *testing.T) {
require := require.New(t)
encoded, err := EncodeFromDispatchCursor(&dispatch.Cursor{
Sections: tc.sections,
}, tc.hash, tc.revision)
}, tc.hash, tc.revision, map[string]string{"some": "flag"})
require.NoError(err)
require.NotNil(encoded)

decoded, err := DecodeToDispatchCursor(encoded, tc.hash)
decoded, flags, err := DecodeToDispatchCursor(encoded, tc.hash)
require.NoError(err)
require.NotNil(decoded)
require.Equal(map[string]string{"some": "flag"}, flags)

require.Equal(tc.sections, decoded.Sections)

Expand Down Expand Up @@ -123,7 +124,7 @@ func TestDecode(t *testing.T) {
t.Run(testName, func(t *testing.T) {
require := require.New(t)

decoded, err := DecodeToDispatchCursor(&v1.Cursor{
decoded, _, err := DecodeToDispatchCursor(&v1.Cursor{
Token: testCase.token,
}, testCase.expectedHash)

Expand Down
108 changes: 63 additions & 45 deletions pkg/proto/impl/v1/impl.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pkg/proto/impl/v1/impl.pb.validate.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading