diff --git a/internal/api/roles.go b/internal/api/roles.go index 50166100c..3619286b4 100644 --- a/internal/api/roles.go +++ b/internal/api/roles.go @@ -11,6 +11,7 @@ import ( "go.opentelemetry.io/otel/trace" "go.infratographer.com/permissions-api/internal/query" + "go.infratographer.com/permissions-api/internal/storage" ) const ( @@ -54,7 +55,14 @@ func (r *Router) roleCreate(c echo.Context) error { } role, err := r.engine.CreateRole(ctx, subjectResource, resource, reqBody.Name, reqBody.Actions) - if err != nil { + + switch { + case err == nil: + case errors.Is(err, query.ErrInvalidAction): + return echo.NewHTTPError(http.StatusBadRequest, "error creating resource: "+err.Error()) + case errors.Is(err, storage.ErrRoleAlreadyExists), errors.Is(err, storage.ErrRoleNameTaken): + return echo.NewHTTPError(http.StatusConflict, "error creating resource: "+err.Error()) + default: return echo.NewHTTPError(http.StatusInternalServerError, "error creating resource").SetInternal(err) } @@ -103,11 +111,12 @@ func (r *Router) roleUpdate(c echo.Context) error { // Roles belong to resources by way of the actions they can perform; do the permissions // check on the role resource. resource, err := r.engine.GetRoleResource(ctx, roleResource) - if err != nil { - if errors.Is(err, query.ErrRoleNotFound) { - return echo.NewHTTPError(http.StatusNotFound, "resource not found").SetInternal(err) - } + switch { + case err == nil: + case errors.Is(err, query.ErrRoleNotFound): + return echo.NewHTTPError(http.StatusNotFound, "resource not found").SetInternal(err) + default: return echo.NewHTTPError(http.StatusInternalServerError, "error getting resource").SetInternal(err) } @@ -116,7 +125,14 @@ func (r *Router) roleUpdate(c echo.Context) error { } role, err := r.engine.UpdateRole(ctx, subjectResource, roleResource, reqBody.Name, reqBody.Actions) - if err != nil { + + switch { + case err == nil: + case errors.Is(err, query.ErrInvalidAction): + return echo.NewHTTPError(http.StatusBadRequest, "error updating resource: "+err.Error()) + case errors.Is(err, storage.ErrRoleNameTaken): + return echo.NewHTTPError(http.StatusConflict, "error updating resource: "+err.Error()) + default: return echo.NewHTTPError(http.StatusInternalServerError, "error updating resource").SetInternal(err) } diff --git a/internal/query/relations.go b/internal/query/relations.go index 46f29a04e..1e136e884 100644 --- a/internal/query/relations.go +++ b/internal/query/relations.go @@ -68,6 +68,28 @@ func resourceToSpiceDBRef(namespace string, r types.Resource) *pb.ObjectReferenc } } +func (e *engine) validateResourceActions(resource types.Resource, actions ...string) error { + var invalidActions []string + + for _, action := range actions { + containsFn := func(sliceAction types.Action) bool { + return sliceAction.Name == action + } + + rescType := e.schemaTypeMap[resource.Type] + + if !slices.ContainsFunc(rescType.Actions, containsFn) { + invalidActions = append(invalidActions, action) + } + } + + if len(invalidActions) == 0 { + return nil + } + + return fmt.Errorf("%w: %s for %s", ErrInvalidAction, strings.Join(invalidActions, ","), resource.Type) +} + // SubjectHasPermission checks if the given subject can do the given action on the given resource func (e *engine) SubjectHasPermission(ctx context.Context, subject types.Resource, action string, resource types.Resource) error { ctx, span := e.tracer.Start( @@ -99,14 +121,10 @@ func (e *engine) SubjectHasPermission(ctx context.Context, subject types.Resourc ), ) - var err error - - containsFn := func(sliceAction types.Action) bool { - return sliceAction.Name == action - } + err := e.validateResourceActions(resource, action) // Only check permissions if the requested action exists in the policy. - if rescType := e.schemaTypeMap[resource.Type]; slices.ContainsFunc(rescType.Actions, containsFn) { + if err == nil { req := &pb.CheckPermissionRequest{ Consistency: consistency, Resource: resourceToSpiceDBRef(e.namespace, resource), @@ -117,8 +135,6 @@ func (e *engine) SubjectHasPermission(ctx context.Context, subject types.Resourc } err = e.checkPermission(ctx, req) - } else { - err = ErrInvalidAction } switch { @@ -293,6 +309,10 @@ func (e *engine) CreateRole(ctx context.Context, actor, res types.Resource, role defer span.End() + if err := e.validateResourceActions(res, actions...); err != nil { + return types.Role{}, err + } + roleName = strings.TrimSpace(roleName) role := newRole(roleName, actions) @@ -385,6 +405,10 @@ func (e *engine) UpdateRole(ctx context.Context, actor, roleResource types.Resou defer span.End() + if err := e.validateResourceActions(roleResource, newActions...); err != nil { + return types.Role{}, err + } + dbCtx, err := e.store.BeginContext(ctx) if err != nil { return types.Role{}, err