Skip to content

Commit

Permalink
create relationships for all matched relations
Browse files Browse the repository at this point in the history
This allows for relations to be determined based on their type.

All matched relations will be created.

Signed-off-by: Mike Mason <[email protected]>
  • Loading branch information
mikemrm committed Jun 21, 2023
1 parent 5f8c6e9 commit e7cdd25
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 20 deletions.
33 changes: 27 additions & 6 deletions internal/pubsub/subscriber.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,13 @@ func (s *Subscriber) processEvent(msg *message.Message) error {
func (s *Subscriber) createRelationships(ctx context.Context, msg *message.Message, resource types.Resource, additionalSubjectIDs []gidx.PrefixedID) error {
var relationships []types.Relationship

rType := s.qe.GetResourceType(resource.Type)
if rType == nil {
s.logger.Warnw("no resource type found for", "resource_type", resource.Type)

return nil
}

// Attempt to create relationships from the message fields. If this fails, reject the message
for _, id := range additionalSubjectIDs {
subjResource, err := s.qe.NewResourceFromID(id)
Expand All @@ -174,13 +181,27 @@ func (s *Subscriber) createRelationships(ctx context.Context, msg *message.Messa
continue
}

relationship := types.Relationship{
Resource: resource,
Relation: subjResource.Type,
Subject: subjResource,
}
for _, rel := range rType.Relationships {
var relation string

relationships = append(relationships, relationship)
for _, tName := range rel.Types {
if tName == subjResource.Type {
relation = rel.Relation

break
}
}

if relation != "" {
relationship := types.Relationship{
Resource: resource,
Relation: relation,
Subject: subjResource,
}

relationships = append(relationships, relationship)
}
}
}

if len(relationships) == 0 {
Expand Down
9 changes: 9 additions & 0 deletions internal/query/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,13 @@ var (

// ErrInvalidReference represents an error condition where a given SpiceDB object reference is for some reason invalid.
ErrInvalidReference = errors.New("invalid reference")

// ErrInvalidNamespace represents an error when the id prefix is not found in the resource schema
ErrInvalidNamespace = errors.New("invalid namespace")

// ErrInvalidType represents an error when a resource type is not found in the resource schema
ErrInvalidType = errors.New("invalid type")

// ErrInvalidRelationship represents an error when no matching relationship was found
ErrInvalidRelationship = errors.New("invalid relationship")
)
15 changes: 15 additions & 0 deletions internal/query/mock/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,21 @@ func (e *Engine) NewResourceFromID(id gidx.PrefixedID) (types.Resource, error) {
return out, nil
}

// GetResourceType returns the resource type by name
func (e *Engine) GetResourceType(name string) *types.ResourceType {
if e.schema == nil {
e.schema = iapl.DefaultPolicy().Schema()
}

for _, resourceType := range e.schema {
if resourceType.Name == name {
return &resourceType
}
}

return nil
}

// SubjectHasPermission returns nil to satisfy the Engine interface.
func (e *Engine) SubjectHasPermission(ctx context.Context, subject types.Resource, action string, resource types.Resource, queryToken string) error {
e.Called()
Expand Down
25 changes: 15 additions & 10 deletions internal/query/relations.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package query

import (
"context"
"errors"
"io"
"strings"

Expand All @@ -13,20 +12,14 @@ import (

var roleSubjectRelation = "subject"

var (
errorInvalidNamespace = errors.New("invalid namespace")
errorInvalidType = errors.New("invalid type")
errorInvalidRelationship = errors.New("invalid relationship")
)

func (e *engine) getTypeForResource(res types.Resource) (types.ResourceType, error) {
for _, resType := range e.schema {
if res.Type == resType.Name {
return resType, nil
}
}

return types.ResourceType{}, errorInvalidType
return types.ResourceType{}, ErrInvalidType
}

func (e *engine) validateRelationship(rel types.Relationship) error {
Expand All @@ -40,6 +33,8 @@ func (e *engine) validateRelationship(rel types.Relationship) error {
return err
}

e.logger.Infow("validation relationship", "sub", subjType.Name, "rel", rel.Relation, "res", resType.Name)

for _, typeRel := range resType.Relationships {
// If we find a relation with a name and type that matches our relationship,
// return
Expand All @@ -53,7 +48,7 @@ func (e *engine) validateRelationship(rel types.Relationship) error {
}

// No matching relationship was found, so we should return an error
return errorInvalidRelationship
return ErrInvalidRelationship
}

func resourceToSpiceDBRef(namespace string, r types.Resource) *pb.ObjectReference {
Expand Down Expand Up @@ -441,7 +436,7 @@ func (e *engine) NewResourceFromID(id gidx.PrefixedID) (types.Resource, error) {

rType, ok := e.schemaPrefixMap[prefix]
if !ok {
return types.Resource{}, errorInvalidNamespace
return types.Resource{}, ErrInvalidNamespace
}

out := types.Resource{
Expand All @@ -451,3 +446,13 @@ func (e *engine) NewResourceFromID(id gidx.PrefixedID) (types.Resource, error) {

return out, nil
}

// GetResourceType returns the resource type by name
func (e *engine) GetResourceType(name string) *types.ResourceType {
rType, ok := e.schemaTypeMap[name]
if !ok {
return nil
}

return &rType
}
2 changes: 1 addition & 1 deletion internal/query/relations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ func TestRelationships(t *testing.T) {
Subject: parentRes,
},
CheckFn: func(ctx context.Context, t *testing.T, res testingx.TestResult[[]types.Relationship]) {
assert.ErrorIs(t, res.Err, errorInvalidRelationship)
assert.ErrorIs(t, res.Err, ErrInvalidRelationship)
},
},
{
Expand Down
10 changes: 7 additions & 3 deletions internal/query/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type Engine interface {
ListRoles(ctx context.Context, resource types.Resource, queryToken string) ([]types.Role, error)
DeleteRelationships(ctx context.Context, resource types.Resource) (string, error)
NewResourceFromID(id gidx.PrefixedID) (types.Resource, error)
GetResourceType(name string) *types.ResourceType
SubjectHasPermission(ctx context.Context, subject types.Resource, action string, resource types.Resource, queryToken string) error
}

Expand All @@ -30,13 +31,16 @@ type engine struct {
client *authzed.Client
schema []types.ResourceType
schemaPrefixMap map[string]types.ResourceType
schemaTypeMap map[string]types.ResourceType
}

func (e *engine) cacheSchemaPrefixes() {
func (e *engine) cacheSchemaResources() {
e.schemaPrefixMap = make(map[string]types.ResourceType, len(e.schema))
e.schemaTypeMap = make(map[string]types.ResourceType, len(e.schema))

for _, res := range e.schema {
e.schemaPrefixMap[res.IDPrefix] = res
e.schemaTypeMap[res.Name] = res
}
}

Expand All @@ -55,7 +59,7 @@ func NewEngine(namespace string, client *authzed.Client, options ...Option) Engi
if e.schema == nil {
e.schema = iapl.DefaultPolicy().Schema()

e.cacheSchemaPrefixes()
e.cacheSchemaResources()
}

return e
Expand All @@ -76,6 +80,6 @@ func WithPolicy(policy iapl.Policy) Option {
return func(e *engine) {
e.schema = policy.Schema()

e.cacheSchemaPrefixes()
e.cacheSchemaResources()
}
}

0 comments on commit e7cdd25

Please sign in to comment.