diff --git a/cmd/server.go b/cmd/server.go index 10debe79..d0b94c5a 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -79,7 +79,7 @@ func serve(ctx context.Context, cfg *config.AppConfig) { logger.Fatal("failed to initialize new server", zap.Error(err)) } - r, err := api.NewRouter(cfg.OIDC, engine, logger) + r, err := api.NewRouter(cfg.OIDC, engine, api.WithLogger(logger)) if err != nil { logger.Fatalw("unable to initialize router", "error", err) } diff --git a/go.mod b/go.mod index 299f894e..6c7bc49c 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.42.0 go.opentelemetry.io/otel v1.16.0 go.opentelemetry.io/otel/trace v1.16.0 + go.uber.org/multierr v1.9.0 go.uber.org/zap v1.24.0 google.golang.org/grpc v1.56.1 gopkg.in/yaml.v3 v3.0.1 @@ -91,7 +92,6 @@ require ( go.opentelemetry.io/otel/sdk v1.16.0 // indirect go.opentelemetry.io/proto/otlp v0.19.0 // indirect go.uber.org/atomic v1.10.0 // indirect - go.uber.org/multierr v1.9.0 // indirect golang.org/x/crypto v0.11.0 // indirect golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect golang.org/x/net v0.12.0 // indirect diff --git a/internal/api/permissions.go b/internal/api/permissions.go index 29005012..ed2f2073 100644 --- a/internal/api/permissions.go +++ b/internal/api/permissions.go @@ -1,13 +1,31 @@ package api import ( + "context" "errors" "fmt" "net/http" + "time" "github.com/labstack/echo/v4" "go.infratographer.com/permissions-api/internal/query" + "go.infratographer.com/permissions-api/internal/types" "go.infratographer.com/x/gidx" + "go.uber.org/multierr" +) + +const ( + defaultMaxCheckConcurrency = 5 + + maxCheckDuration = 5 * time.Second +) + +var ( + // ErrNoActionDefined is the error returned when an access request is has no action defined + ErrNoActionDefined = errors.New("no action defined") + + // ErrAccessDenied is returned when access is denied + ErrAccessDenied = errors.New("access denied") ) // checkAction will check if a subject is allowed to perform an action on a resource. @@ -72,6 +90,174 @@ func (r *Router) checkAction(c echo.Context) error { return c.JSON(http.StatusOK, echo.Map{}) } +type checkPermissionsRequest struct { + Actions []checkAction `json:"actions"` +} + +type checkAction struct { + ResourceID string `json:"resource_id"` + Action string `json:"action"` +} + +type checkRequest struct { + Index int + Resource types.Resource + Action string +} + +type checkResult struct { + Request checkRequest + Error error +} + +// checkAllActions will check if a subject is allowed to perform an action on a list of resources. +// This is the permissions check endpoint. +// It will return a 200 if the subject is allowed to perform all requested resource actions. +// It will return a 400 if the request is invalid. +// It will return a 403 if the subject is not allowed to perform all requested resource actions. +// +// Note that this expects a JWT token to be present in the request. This token must +// contain the subject of the request in the "sub" claim. +func (r *Router) checkAllActions(c echo.Context) error { + ctx, span := tracer.Start(c.Request().Context(), "api.checkAllActions") + defer span.End() + + // Subject validation + subject, err := currentSubject(c) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "failed to get the subject").SetInternal(err) + } + + subjectResource, err := r.engine.NewResourceFromID(subject) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "error processing subject ID").SetInternal(err) + } + + var reqBody checkPermissionsRequest + + if err := c.Bind(&reqBody); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "error parsing request body").SetInternal(err) + } + + var errs []error + + requestsCh := make(chan checkRequest, len(reqBody.Actions)) + + for i, check := range reqBody.Actions { + if check.Action == "" { + errs = append(errs, fmt.Errorf("check %d: %w", i, ErrNoActionDefined)) + + continue + } + + resourceID, err := gidx.Parse(check.ResourceID) + if err != nil { + errs = append(errs, fmt.Errorf("check %d: %w: error parsing resource id: %s", i, err, check.ResourceID)) + + continue + } + + resource, err := r.engine.NewResourceFromID(resourceID) + if err != nil { + errs = append(errs, fmt.Errorf("check %d: %w: error creating resource from id: %s", i, err, resourceID.String())) + + continue + } + + requestsCh <- checkRequest{ + Index: i, + Resource: resource, + Action: check.Action, + } + } + + close(requestsCh) + + if len(errs) != 0 { + return echo.NewHTTPError(http.StatusBadRequest, "invalid check request").SetInternal(multierr.Combine(errs...)) + } + + resultsCh := make(chan checkResult, len(reqBody.Actions)) + + ctx, cancel := context.WithTimeout(ctx, maxCheckDuration) + + defer cancel() + + for i := 0; i < r.concurrentChecks; i++ { + go func() { + for { + var result checkResult + + select { + case check, ok := <-requestsCh: + // if channel is closed, quit the go routine. + if !ok { + return + } + + result.Request = check + + // Check the permissions + err := r.engine.SubjectHasPermission(ctx, subjectResource, check.Action, check.Resource) + if err != nil { + result.Error = err + } + case <-ctx.Done(): + result.Error = ctx.Err() + } + + resultsCh <- result + } + }() + } + + var ( + unauthorizedErrors int + internalErrors int + allErrors []error + ) + + for i := 0; i < len(reqBody.Actions); i++ { + select { + case result := <-resultsCh: + if result.Error != nil { + if errors.Is(result.Error, query.ErrActionNotAssigned) { + err := fmt.Errorf("%w: subject '%s' does not have permission to perform action '%s' on resource '%s'", + ErrAccessDenied, subject, result.Request.Action, result.Request.Resource.ID.String()) + + unauthorizedErrors++ + + allErrors = append(allErrors, err) + } else { + err := fmt.Errorf("check %d: %w", result.Request.Index, result.Error) + + internalErrors++ + + allErrors = append(allErrors, err) + } + } + case <-ctx.Done(): + if err := ctx.Err(); err != nil { + internalErrors++ + + allErrors = append(allErrors, ctx.Err()) + } + } + } + + if internalErrors != 0 { + return echo.NewHTTPError(http.StatusInternalServerError, "an error occurred checking permissions").SetInternal(multierr.Combine(allErrors...)) + } + + if unauthorizedErrors != 0 { + msg := fmt.Sprintf("subject '%s' does not have permission to the requested resource actions", subject) + + return echo.NewHTTPError(http.StatusForbidden, msg).SetInternal(multierr.Combine(allErrors...)) + } + + return nil +} + func getParam(c echo.Context, name string) (string, bool) { values, ok := c.QueryParams()[name] if !ok { diff --git a/internal/api/router.go b/internal/api/router.go index 458071bd..bd3d1c29 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -18,22 +18,32 @@ type Router struct { authMW echo.MiddlewareFunc engine query.Engine logger *zap.SugaredLogger + + concurrentChecks int } // NewRouter returns a new api router -func NewRouter(authCfg echojwtx.AuthConfig, engine query.Engine, l *zap.SugaredLogger) (*Router, error) { +func NewRouter(authCfg echojwtx.AuthConfig, engine query.Engine, options ...Option) (*Router, error) { auth, err := echojwtx.NewAuth(context.Background(), authCfg) if err != nil { return nil, err } - out := &Router{ + router := &Router{ authMW: auth.Middleware(), engine: engine, - logger: l.Named("api"), + logger: zap.NewNop().Sugar(), + + concurrentChecks: defaultMaxCheckConcurrency, + } + + for _, opt := range options { + if err := opt(router); err != nil { + return nil, err + } } - return out, nil + return router, nil } // Routes will add the routes for this API version to a router group @@ -58,6 +68,32 @@ func (r *Router) Routes(rg *echo.Group) { // /allow is the permissions check endpoint v1.GET("/allow", r.checkAction) + v1.POST("/allow", r.checkAllActions) + } +} + +// Option defines a router option function. +type Option func(r *Router) error + +// WithLogger sets the logger for the router. +func WithLogger(logger *zap.SugaredLogger) Option { + return func(r *Router) error { + r.logger = logger.Named("api") + + return nil + } +} + +// WithCheckConcurrency sets the check concurrency for bulk permission checks. +func WithCheckConcurrency(count int) Option { + return func(r *Router) error { + if count <= 0 { + count = 5 + } + + r.concurrentChecks = count + + return nil } } diff --git a/pkg/permissions/checker.go b/pkg/permissions/checker.go new file mode 100644 index 00000000..1a2f5e58 --- /dev/null +++ b/pkg/permissions/checker.go @@ -0,0 +1,75 @@ +package permissions + +import ( + "context" + + "github.com/labstack/echo/v4" + "go.infratographer.com/x/gidx" +) + +var ( + // CheckerCtxKey is the context key used to set the checker handling function + CheckerCtxKey = checkerCtxKey{} + + // DefaultAllowChecker defaults to allow when checker is disabled or skipped + DefaultAllowChecker Checker = func(_ context.Context, _ ...AccessRequest) error { + return nil + } + + // DefaultDenyChecker defaults to denied when checker is disabled or skipped + DefaultDenyChecker Checker = func(_ context.Context, _ ...AccessRequest) error { + return ErrPermissionDenied + } +) + +// Checker defines the checker function definition +type Checker func(ctx context.Context, requests ...AccessRequest) error + +// AccessRequest defines the required fields to check permissions access. +type AccessRequest struct { + ResourceID gidx.PrefixedID `json:"resource_id"` + Action string `json:"action"` +} + +type checkerCtxKey struct{} + +func setCheckerContext(c echo.Context, checker Checker) { + if checker == nil { + checker = DefaultDenyChecker + } + + req := c.Request().WithContext( + context.WithValue( + c.Request().Context(), + CheckerCtxKey, + checker, + ), + ) + + c.SetRequest(req) +} + +// CheckAccess runs the checker function to check if the provided resource and action are supported. +func CheckAccess(ctx context.Context, resource gidx.PrefixedID, action string) error { + checker, ok := ctx.Value(CheckerCtxKey).(Checker) + if !ok { + return ErrCheckerNotFound + } + + request := AccessRequest{ + ResourceID: resource, + Action: action, + } + + return checker(ctx, request) +} + +// CheckAll runs the checker function to check if all the provided resources and actions are permitted. +func CheckAll(ctx context.Context, requests ...AccessRequest) error { + checker, ok := ctx.Value(CheckerCtxKey).(Checker) + if !ok { + return ErrCheckerNotFound + } + + return checker(ctx, requests...) +} diff --git a/pkg/permissions/permissions.go b/pkg/permissions/permissions.go index 8bc72541..33b64d47 100644 --- a/pkg/permissions/permissions.go +++ b/pkg/permissions/permissions.go @@ -1,7 +1,10 @@ package permissions import ( + "bytes" "context" + "encoding/json" + "fmt" "io" "net/http" "net/url" @@ -12,7 +15,6 @@ import ( "github.com/labstack/echo/v4/middleware" "github.com/pkg/errors" "go.infratographer.com/x/echojwtx" - "go.infratographer.com/x/gidx" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" @@ -27,19 +29,6 @@ const ( ) var ( - // CheckerCtxKey is the context key used to set the checker handling function - CheckerCtxKey = checkerCtxKey{} - - // DefaultAllowChecker defaults to allow when checker is disabled or skipped - DefaultAllowChecker Checker = func(_ context.Context, _ gidx.PrefixedID, _ string) error { - return nil - } - - // DefaultDenyChecker defaults to denied when checker is disabled or skipped - DefaultDenyChecker Checker = func(_ context.Context, _ gidx.PrefixedID, _ string) error { - return ErrPermissionDenied - } - defaultClient = &http.Client{ Timeout: defaultClientTimeout, Transport: otelhttp.NewTransport(http.DefaultTransport), @@ -48,11 +37,6 @@ var ( tracer = otel.GetTracerProvider().Tracer("go.infratographer.com/permissions-api/pkg/permissions") ) -// Checker defines the checker function definition -type Checker func(ctx context.Context, resource gidx.PrefixedID, action string) error - -type checkerCtxKey struct{} - // Permissions handles supporting authorization checks type Permissions struct { enabled bool @@ -97,35 +81,49 @@ func (p *Permissions) Middleware() echo.MiddlewareFunc { } } +type checkPermissionRequest struct { + Actions []AccessRequest `json:"actions"` +} + func (p *Permissions) checker(c echo.Context, actor, token string) Checker { - return func(ctx context.Context, resource gidx.PrefixedID, action string) error { - ctx, span := tracer.Start(ctx, "permissions.checkAccess") + return func(ctx context.Context, requests ...AccessRequest) error { + ctx, span := tracer.Start(ctx, "permissions.checker") defer span.End() span.SetAttributes( attribute.String("permissions.actor", actor), - attribute.String("permissions.action", action), - attribute.String("permissions.resource", resource.String()), + attribute.Int("permissions.requests", len(requests)), ) - logger := p.logger.With("actor", actor, "resource", resource.String(), "action", action) + logger := p.logger.With("actor", actor, "requests", len(requests)) - values := url.Values{} - values.Add("resource", resource.String()) - values.Add("action", action) + request := checkPermissionRequest{ + Actions: requests, + } - url := *p.url - url.RawQuery = values.Encode() + var reqBody bytes.Buffer - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil) + if err := json.NewEncoder(&reqBody).Encode(request); err != nil { + err = errors.WithStack(err) + + span.SetStatus(codes.Error, err.Error()) + logger.Errorw("failed to encode request body", "error", err) + + return err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url.String(), &reqBody) if err != nil { - span.SetStatus(codes.Error, errors.WithStack(err).Error()) + err = errors.WithStack(err) + + span.SetStatus(codes.Error, err.Error()) logger.Errorw("failed to create checker request", "error", err) - return errors.WithStack(err) + return err } req.Header.Set(echo.HeaderAuthorization, c.Request().Header.Get(echo.HeaderAuthorization)) + req.Header.Set(echo.HeaderContentType, "application/json") resp, err := p.client.Do(req) if err != nil { @@ -191,40 +189,14 @@ func New(config Config, options ...Option) (*Permissions, error) { return p, nil } -func setCheckerContext(c echo.Context, checker Checker) { - if checker == nil { - checker = DefaultDenyChecker - } - - req := c.Request().WithContext( - context.WithValue( - c.Request().Context(), - CheckerCtxKey, - checker, - ), - ) - - c.SetRequest(req) -} - func ensureValidServerResponse(resp *http.Response) error { if resp.StatusCode >= http.StatusMultiStatus { if resp.StatusCode == http.StatusForbidden { return ErrPermissionDenied } - return ErrBadResponse + return fmt.Errorf("%w: %d", ErrBadResponse, resp.StatusCode) } return nil } - -// CheckAccess runs the checker function to check if the provided resource and action are supported. -func CheckAccess(ctx context.Context, resource gidx.PrefixedID, action string) error { - checker, ok := ctx.Value(CheckerCtxKey).(Checker) - if !ok { - return ErrCheckerNotFound - } - - return checker(ctx, resource, action) -} diff --git a/pkg/permissions/permissions_test.go b/pkg/permissions/permissions_test.go index d5f9b58b..90949920 100644 --- a/pkg/permissions/permissions_test.go +++ b/pkg/permissions/permissions_test.go @@ -1,6 +1,7 @@ package permissions_test import ( + "encoding/json" "net/http" "net/http/httptest" "testing" @@ -30,25 +31,40 @@ func TestPermissions(t *testing.T) { return } - resource, err := gidx.Parse(r.URL.Query().Get("resource")) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - - return + var reqBody struct { + Actions []struct { + ResourceID string `json:"resource_id"` + Action string `json:"action"` + } `json:"actions"` } - action := r.URL.Query().Get("action") - - if resource != allowedID && resource != deniedID { + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { w.WriteHeader(http.StatusInternalServerError) return } - if resource != allowedID || !actions[action] { - w.WriteHeader(http.StatusForbidden) + for _, request := range reqBody.Actions { + resource, err := gidx.Parse(request.ResourceID) + if err != nil { + w.WriteHeader(http.StatusBadRequest) - return + return + } + + action := request.Action + + if resource != allowedID && resource != deniedID { + w.WriteHeader(http.StatusInternalServerError) + + return + } + + if resource != allowedID || !actions[action] { + w.WriteHeader(http.StatusForbidden) + + return + } } }))