From 30227c1b60ba915a10f1979513f57b706ca0fdd0 Mon Sep 17 00:00:00 2001 From: Bailin He <15058035+bailinhe@users.noreply.github.com> Date: Fri, 26 Apr 2024 17:22:17 -0400 Subject: [PATCH] Apply suggestions from code review Co-authored-by: John Schaeffer Signed-off-by: Bailin He <15058035+bailinhe@users.noreply.github.com> --- cmd/createrole.go | 45 +++++----------- internal/api/rolebindings.go | 24 ++------- internal/api/router.go | 4 +- internal/api/types.go | 7 +-- internal/query/relations.go | 56 ++++++++++++++++++++ internal/query/rolebindings.go | 97 ++++++++++++---------------------- 6 files changed, 111 insertions(+), 122 deletions(-) diff --git a/cmd/createrole.go b/cmd/createrole.go index dff4dc6f9..4beb91458 100644 --- a/cmd/createrole.go +++ b/cmd/createrole.go @@ -23,7 +23,6 @@ const ( createRoleFlagResource = "resource" createRoleFlagActions = "actions" createRoleFlagName = "name" - createRoleFlagIsV2 = "v2" ) var createRoleCmd = &cobra.Command{ @@ -42,7 +41,6 @@ func init() { flags.StringSlice(createRoleFlagActions, []string{}, "actions to assign to created role") flags.String(createRoleFlagResource, "", "resource to bind to created role") flags.String(createRoleFlagName, "", "name of role to create") - flags.Bool(createRoleFlagIsV2, false, "create a v2 role") v := viper.GetViper() @@ -50,7 +48,6 @@ func init() { viperx.MustBindFlag(v, createRoleFlagActions, flags.Lookup(createRoleFlagActions)) viperx.MustBindFlag(v, createRoleFlagResource, flags.Lookup(createRoleFlagResource)) viperx.MustBindFlag(v, createRoleFlagName, flags.Lookup(createRoleFlagName)) - viperx.MustBindFlag(v, createRoleFlagIsV2, flags.Lookup(createRoleFlagIsV2)) } func createRole(ctx context.Context, cfg *config.AppConfig) { @@ -58,7 +55,6 @@ func createRole(ctx context.Context, cfg *config.AppConfig) { actions := viper.GetStringSlice(createRoleFlagActions) resourceIDStr := viper.GetString(createRoleFlagResource) name := viper.GetString(createRoleFlagName) - v2 := viper.GetBool(createRoleFlagIsV2) if subjectIDStr == "" || len(actions) == 0 || resourceIDStr == "" || name == "" { logger.Fatal("invalid config") @@ -128,35 +124,22 @@ func createRole(ctx context.Context, cfg *config.AppConfig) { logger.Fatalw("error creating subject resource", "error", err) } - if v2 { - role, err := engine.CreateRoleV2(ctx, subjectResource, resource, name, actions) - if err != nil { - logger.Fatalw("error creating role", "error", err) - } - - rbsubj := []types.RoleBindingSubject{{SubjectResource: subjectResource}} - - roleres, err := engine.NewResourceFromID(role.ID) - if err != nil { - logger.Fatalw("error creating role resource", "error", err) - } - - rb, err := engine.CreateRoleBinding(ctx, subjectResource, resource, roleres, rbsubj) - if err != nil { - logger.Fatalw("error creating role binding", "error", err) - } + role, err := engine.CreateRoleV2(ctx, subjectResource, resource, name, actions) + if err != nil { + logger.Fatalw("error creating role", "error", err) + } - logger.Infof("created role %s[%s] and role-binding %s", role.Name, role.ID, rb.ID) - } else { - role, err := engine.CreateRole(ctx, subjectResource, resource, name, actions) - if err != nil { - logger.Fatalw("error creating role", "error", err) - } + rbsubj := []types.RoleBindingSubject{{SubjectResource: subjectResource}} - if err = engine.AssignSubjectRole(ctx, subjectResource, role); err != nil { - logger.Fatalw("error creating role", "error", err) - } + roleres, err := engine.NewResourceFromID(role.ID) + if err != nil { + logger.Fatalw("error creating role resource", "error", err) + } - logger.Infow("role successfully created", "role_id", role.ID) + rb, err := engine.CreateRoleBinding(ctx, subjectResource, resource, roleres, rbsubj) + if err != nil { + logger.Fatalw("error creating role binding", "error", err) } + + logger.Infof("created role %s[%s] and role-binding %s", role.Name, role.ID, rb.ID) } diff --git a/internal/api/rolebindings.go b/internal/api/rolebindings.go index db4566d23..feb67bb5b 100644 --- a/internal/api/rolebindings.go +++ b/internal/api/rolebindings.go @@ -18,9 +18,8 @@ func resourceToSubject(subjects []types.RoleBindingSubject) []roleBindingSubject resp := make([]roleBindingSubject, len(subjects)) for i, subj := range subjects { resp[i] = roleBindingSubject{ - ID: subj.SubjectResource.ID, - Type: subj.SubjectResource.Type, - Condition: nil, + ID: subj.SubjectResource.ID, + Type: subj.SubjectResource.Type, } } @@ -113,7 +112,6 @@ func (r *Router) roleBindingCreate(c echo.Context) error { func (r *Router) roleBindingsList(c echo.Context) error { resourceIDStr := c.Param("id") - roleIDStr := c.QueryParam("role_id") ctx, span := tracer.Start( c.Request().Context(), "api.roleBindingList", @@ -140,23 +138,7 @@ func (r *Router) roleBindingsList(c echo.Context) error { return err } - roleFilter := (*types.Resource)(nil) - - if roleIDStr != "" { - roleID, err := gidx.Parse(roleIDStr) - if err != nil { - return r.errorResponse("error parsing role ID", fmt.Errorf("%w: %s", ErrInvalidID, err.Error())) - } - - roleResource, err := r.engine.NewResourceFromID(roleID) - if err != nil { - return r.errorResponse("error creating role resource", err) - } - - roleFilter = &roleResource - } - - rbs, err := r.engine.ListRoleBindings(ctx, resource, roleFilter) + rbs, err := r.engine.ListRoleBindings(ctx, resource, nil) if err != nil { return r.errorResponse("error listing role-binding", err) } diff --git a/internal/api/router.go b/internal/api/router.go index 391602555..0eafba6ba 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -87,10 +87,10 @@ func (r *Router) Routes(rg *echo.Group) { v2.DELETE("/roles/:id", r.roleV2Delete) v2.GET("/resources/:id/role-bindings", r.roleBindingsList) - v2.GET("/resources/:id/role-bindings/:rb_id", r.roleBindingGet) + v2.GET("/role-bindings/:rb_id", r.roleBindingGet) v2.POST("/resources/:id/role-bindings", r.roleBindingCreate) v2.DELETE("/resources/:id/role-bindings/:rb_id", r.roleBindingsDelete) - v2.PATCH("/resources/:id/role-bindings/:rb_id", r.roleBindingUpdate) + v2.PATCH("/role-bindings/:rb_id", r.roleBindingUpdate) v2.GET("/actions", r.listActions) } diff --git a/internal/api/types.go b/internal/api/types.go index 33e756275..0bf128eb0 100644 --- a/internal/api/types.go +++ b/internal/api/types.go @@ -88,12 +88,9 @@ type roleBindingResponseRole struct { Name string `json:"name"` } -type roleBindingSubjectCondition struct{} - type roleBindingSubject struct { - ID gidx.PrefixedID `json:"id" binding:"required"` - Type string `json:"type,omitempty"` - Condition *roleBindingSubjectCondition `json:"condition,omitempty"` + ID gidx.PrefixedID `json:"id" binding:"required"` + Type string `json:"type,omitempty"` } type roleBindingRequest struct { diff --git a/internal/query/relations.go b/internal/query/relations.go index ddbfbc7b6..40a844cda 100644 --- a/internal/query/relations.go +++ b/internal/query/relations.go @@ -1231,3 +1231,59 @@ func (e *engine) NewResourceFromIDString(id string) (types.Resource, error) { return subject, nil } + +// rollbackUpdates is a helper function that rolls back a list of +// relationship updates on spiceDB. +func (e *engine) rollbackUpdates(ctx context.Context, updates []*pb.RelationshipUpdate) error { + updatesLen := len(updates) + rollbacks := make([]*pb.RelationshipUpdate, 0, updatesLen) + + for i := range updates { + // reversed order + u := updates[updatesLen-i-1] + + if u == nil { + continue + } + + var op pb.RelationshipUpdate_Operation + + switch u.Operation { + case pb.RelationshipUpdate_OPERATION_CREATE: + fallthrough + case pb.RelationshipUpdate_OPERATION_TOUCH: + op = pb.RelationshipUpdate_OPERATION_DELETE + case pb.RelationshipUpdate_OPERATION_DELETE: + op = pb.RelationshipUpdate_OPERATION_TOUCH + default: + continue + } + + rollbacks = append(rollbacks, &pb.RelationshipUpdate{ + Operation: op, + Relationship: u.Relationship, + }) + } + + return e.applyUpdates(ctx, rollbacks) +} + +// applyUpdates is a wrapper function around the spiceDB WriteRelationships method +// it applies the given relationship updates and store the zed token for each resource. +func (e *engine) applyUpdates(ctx context.Context, updates []*pb.RelationshipUpdate) error { + resp, err := e.client.WriteRelationships(ctx, &pb.WriteRelationshipsRequest{Updates: updates}) + if err != nil { + return err + } + + t := resp.WrittenAt.Token + + for _, u := range updates { + resID := u.Relationship.Resource.ObjectId + if err := e.upsertZedToken(ctx, resID, t); err != nil { + return err + } + } + + return nil +} diff --git a/internal/query/rolebindings.go b/internal/query/rolebindings.go index a1c9f43ef..43f47f4ce 100644 --- a/internal/query/rolebindings.go +++ b/internal/query/rolebindings.go @@ -60,6 +60,8 @@ func (e *engine) GetRoleBinding(ctx context.Context, roleBinding types.Resource) rb.Subjects = make([]types.RoleBindingSubject, 0, len(rbRel)) + var roleID gidx.PrefixedID + for _, rel := range rbRel { // process subject relationships if rel.Relation == iapl.RolebindingSubjectRelation { @@ -77,26 +79,26 @@ func (e *engine) GetRoleBinding(ctx context.Context, roleBinding types.Resource) } // process role relationships - roleID, err := gidx.Parse(rel.Subject.Object.ObjectId) + roleID, err = gidx.Parse(rel.Subject.Object.ObjectId) if err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) return types.RoleBinding{}, err } + } - dbRole, err := e.store.GetRoleByID(ctx, roleID) - if err != nil { - span.RecordError(err) - span.SetStatus(codes.Error, err.Error()) + dbRole, err := e.store.GetRoleByID(ctx, roleID) + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) - return types.RoleBinding{}, err - } + return types.RoleBinding{}, err + } - rb.Role = types.Role{ - ID: roleID, - Name: dbRole.Name, - } + rb.Role = types.Role{ + ID: roleID, + Name: dbRole.Name, } return rb, nil @@ -142,15 +144,24 @@ func (e *engine) CreateRoleBinding( dbCtx, err := e.store.BeginContext(ctx) if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + return types.RoleBinding{}, nil } rbResourceType, ok := e.schemaTypeMap[e.rbac.RoleBindingResource.Name] if !ok { - return types.RoleBinding{}, fmt.Errorf( + err := fmt.Errorf( "%w: invalid role-binding resource type: %s", ErrInvalidType, e.rbac.RoleBindingResource.Name, ) + + span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) + logRollbackErr(e.logger, e.store.RollbackContext(dbCtx)) + + return types.RoleBinding{}, err } rbid, err := gidx.NewID(rbResourceType.IDPrefix) @@ -215,9 +226,7 @@ func (e *engine) CreateRoleBinding( } } - if _, err := e.client.WriteRelationships(ctx, &pb.WriteRelationshipsRequest{ - Updates: append(updates, subjUpdates...), - }); err != nil { + if err := e.applyUpdates(ctx, append(updates, subjUpdates...)); err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) logRollbackErr(e.logger, e.store.RollbackContext(dbCtx)) @@ -229,7 +238,7 @@ func (e *engine) CreateRoleBinding( span.RecordError(err) span.SetStatus(codes.Error, err.Error()) logRollbackErr(e.logger, e.store.RollbackContext(dbCtx)) - logRollbackErr(e.logger, e.rollbackRoleBindingUpdates(ctx, updates)) + logRollbackErr(e.logger, e.rollbackUpdates(ctx, updates)) return types.RoleBinding{}, err } @@ -319,7 +328,7 @@ func (e *engine) DeleteRoleBinding(ctx context.Context, rb types.Resource) error } // apply changes - if _, err := e.client.WriteRelationships(ctx, &pb.WriteRelationshipsRequest{Updates: updates}); err != nil { + if err := e.applyUpdates(ctx, updates); err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) logRollbackErr(e.logger, e.store.RollbackContext(dbCtx)) @@ -331,7 +340,7 @@ func (e *engine) DeleteRoleBinding(ctx context.Context, rb types.Resource) error span.RecordError(err) span.SetStatus(codes.Error, err.Error()) logRollbackErr(e.logger, e.store.RollbackContext(dbCtx)) - logRollbackErr(e.logger, e.rollbackRoleBindingUpdates(ctx, updates)) + logRollbackErr(e.logger, e.rollbackUpdates(ctx, updates)) return err } @@ -340,7 +349,7 @@ func (e *engine) DeleteRoleBinding(ctx context.Context, rb types.Resource) error span.RecordError(err) span.SetStatus(codes.Error, err.Error()) logRollbackErr(e.logger, e.store.RollbackContext(dbCtx)) - logRollbackErr(e.logger, e.rollbackRoleBindingUpdates(ctx, updates)) + logRollbackErr(e.logger, e.rollbackUpdates(ctx, updates)) return err } @@ -400,6 +409,7 @@ func (e *engine) ListRoleBindings(ctx context.Context, resource types.Resource, err := fmt.Errorf("%w: dangling grant relationship: %s", err, rel.String()) span.RecordError(err) + span.SetStatus(codes.Error, err.Error()) e.logger.Warnf(err.Error()) } @@ -516,8 +526,7 @@ func (e *engine) UpdateRoleBinding(ctx context.Context, actor, rb types.Resource i++ } - _, err = e.client.WriteRelationships(ctx, &pb.WriteRelationshipsRequest{Updates: updates}) - if err != nil { + if err := e.applyUpdates(ctx, updates); err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) logRollbackErr(e.logger, e.store.RollbackContext(dbCtx)) @@ -531,14 +540,14 @@ func (e *engine) UpdateRoleBinding(ctx context.Context, actor, rb types.Resource span.RecordError(err) span.SetStatus(codes.Error, err.Error()) logRollbackErr(e.logger, e.store.RollbackContext(dbCtx)) - logRollbackErr(e.logger, e.rollbackRoleBindingUpdates(ctx, updates)) + logRollbackErr(e.logger, e.rollbackUpdates(ctx, updates)) } if err := e.store.CommitContext(dbCtx); err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) logRollbackErr(e.logger, e.store.RollbackContext(dbCtx)) - logRollbackErr(e.logger, e.rollbackRoleBindingUpdates(ctx, updates)) + logRollbackErr(e.logger, e.rollbackUpdates(ctx, updates)) return types.RoleBinding{}, err } @@ -722,7 +731,7 @@ func (e *engine) deleteRoleBindingsForRole(ctx context.Context, roleResource typ } // 3.2 delete all the relationships - if _, err := e.client.WriteRelationships(ctx, &pb.WriteRelationshipsRequest{Updates: updates}); err != nil { + if err := e.applyUpdates(ctx, updates); err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) logRollbackErr(e.logger, e.store.RollbackContext(dbCtx)) @@ -734,7 +743,7 @@ func (e *engine) deleteRoleBindingsForRole(ctx context.Context, roleResource typ span.RecordError(err) span.SetStatus(codes.Error, err.Error()) logRollbackErr(e.logger, e.store.RollbackContext(dbCtx)) - logRollbackErr(e.logger, e.rollbackRoleBindingUpdates(ctx, updates)) + logRollbackErr(e.logger, e.rollbackUpdates(ctx, updates)) return err } @@ -833,41 +842,3 @@ func (e *engine) rolebindingRelationshipUpdateForSubject( return &pb.RelationshipUpdate{Operation: op, Relationship: rel}, nil } - -// rollbackRoleBindingUpdates is a helper function that rolls back a list of -// relationship updates on spiceDB. -func (e *engine) rollbackRoleBindingUpdates(ctx context.Context, updates []*pb.RelationshipUpdate) error { - updatesLen := len(updates) - rollbacks := make([]*pb.RelationshipUpdate, 0, updatesLen) - - for i := range updates { - // reversed order - u := updates[updatesLen-i-1] - - if u == nil { - continue - } - - var op pb.RelationshipUpdate_Operation - - switch u.Operation { - case pb.RelationshipUpdate_OPERATION_CREATE: - fallthrough - case pb.RelationshipUpdate_OPERATION_TOUCH: - op = pb.RelationshipUpdate_OPERATION_DELETE - case pb.RelationshipUpdate_OPERATION_DELETE: - op = pb.RelationshipUpdate_OPERATION_TOUCH - default: - continue - } - - rollbacks = append(rollbacks, &pb.RelationshipUpdate{ - Operation: op, - Relationship: u.Relationship, - }) - } - - _, err := e.client.WriteRelationships(ctx, &pb.WriteRelationshipsRequest{Updates: rollbacks}) - - return err -}