From 400833baa9f74938b90c8c53d27afb1e37573fe3 Mon Sep 17 00:00:00 2001 From: Mike Mason Date: Mon, 24 Jul 2023 16:47:46 +0000 Subject: [PATCH 1/5] add support for bulk permission check requests Signed-off-by: Mike Mason --- cmd/server.go | 2 +- go.mod | 2 +- internal/api/permissions.go | 169 ++++++++++++++++++++++++++++++++++++ internal/api/router.go | 44 +++++++++- 4 files changed, 211 insertions(+), 6 deletions(-) diff --git a/cmd/server.go b/cmd/server.go index 10debe79..d0b94c5a 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -79,7 +79,7 @@ func serve(ctx context.Context, cfg *config.AppConfig) { logger.Fatal("failed to initialize new server", zap.Error(err)) } - r, err := api.NewRouter(cfg.OIDC, engine, logger) + r, err := api.NewRouter(cfg.OIDC, engine, api.WithLogger(logger)) if err != nil { logger.Fatalw("unable to initialize router", "error", err) } diff --git a/go.mod b/go.mod index 299f894e..6c7bc49c 100644 --- a/go.mod +++ b/go.mod @@ -18,6 +18,7 @@ require ( go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.42.0 go.opentelemetry.io/otel v1.16.0 go.opentelemetry.io/otel/trace v1.16.0 + go.uber.org/multierr v1.9.0 go.uber.org/zap v1.24.0 google.golang.org/grpc v1.56.1 gopkg.in/yaml.v3 v3.0.1 @@ -91,7 +92,6 @@ require ( go.opentelemetry.io/otel/sdk v1.16.0 // indirect go.opentelemetry.io/proto/otlp v0.19.0 // indirect go.uber.org/atomic v1.10.0 // indirect - go.uber.org/multierr v1.9.0 // indirect golang.org/x/crypto v0.11.0 // indirect golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect golang.org/x/net v0.12.0 // indirect diff --git a/internal/api/permissions.go b/internal/api/permissions.go index 29005012..1a335c1c 100644 --- a/internal/api/permissions.go +++ b/internal/api/permissions.go @@ -4,12 +4,18 @@ import ( "errors" "fmt" "net/http" + "sync" + "time" "github.com/labstack/echo/v4" "go.infratographer.com/permissions-api/internal/query" + "go.infratographer.com/permissions-api/internal/types" "go.infratographer.com/x/gidx" + "go.uber.org/multierr" ) +const maxCheckDuration = 5 * time.Second + // checkAction will check if a subject is allowed to perform an action on a resource. // This is the permissions check endpoint. // It will return a 200 if the subject is allowed to perform the action on the resource. @@ -72,6 +78,169 @@ func (r *Router) checkAction(c echo.Context) error { return c.JSON(http.StatusOK, echo.Map{}) } +type checkPermissionsRequest struct { + Actions []checkAction `json:"actions"` +} + +type checkAction struct { + ResourceID string `json:"resource_id"` + Action string `json:"action"` +} + +type checkStatus struct { + Resource types.Resource + Action string + Error error +} + +// checkAllActions will check if a subject is allowed to perform an action on a list of resources. +// This is the permissions check endpoint. +// It will return a 200 if the subject is allowed to perform all requested resource actions. +// It will return a 400 if the request is invalid. +// It will return a 403 if the subject is not allowed to perform all requested resource actions. +// +// Note that this expects a JWT token to be present in the request. This token must +// contain the subject of the request in the "sub" claim. +func (r *Router) checkAllActions(c echo.Context) error { + ctx, span := tracer.Start(c.Request().Context(), "api.checkAllActions") + defer span.End() + + // Subject validation + subject, err := currentSubject(c) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "failed to get the subject").SetInternal(err) + } + + subjectResource, err := r.engine.NewResourceFromID(subject) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "error processing subject ID").SetInternal(err) + } + + var reqBody checkPermissionsRequest + + if err := c.Bind(&reqBody); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "error parsing request body").SetInternal(err) + } + + var errs []error + + results := make([]*checkStatus, len(reqBody.Actions)) + + for i, check := range reqBody.Actions { + if check.Action == "" { + errs = append(errs, fmt.Errorf("check %d: no action defined", i)) + + continue + } + + resourceID, err := gidx.Parse(check.ResourceID) + if err != nil { + errs = append(errs, fmt.Errorf("check %d: %w: error parsing resource id: %s", i, err, check.ResourceID)) + + continue + } + + resource, err := r.engine.NewResourceFromID(resourceID) + if err != nil { + errs = append(errs, fmt.Errorf("check %d: %w: error creating resource from id: %s", i, err, resourceID.String())) + + continue + } + + results[i] = &checkStatus{ + Resource: resource, + Action: check.Action, + } + } + + if len(errs) != 0 { + return echo.NewHTTPError(http.StatusBadRequest, "invalid check request").SetInternal(multierr.Combine(errs...)) + } + + checkCh := make(chan *checkStatus) + + wg := new(sync.WaitGroup) + + for i := 0; i < r.concurrentChecks; i++ { + wg.Add(1) + + go func() { + defer wg.Done() + + for check := range checkCh { + // Check the permissions + err := r.engine.SubjectHasPermission(ctx, subjectResource, check.Action, check.Resource) + if err != nil { + check.Error = err + } + } + }() + } + + wg.Add(1) + + go func() { + defer wg.Done() + + for _, check := range results { + checkCh <- check + } + + close(checkCh) + }() + + doneCh := make(chan struct{}) + + go func() { + defer close(doneCh) + + wg.Wait() + }() + + select { + case <-doneCh: + case <-ctx.Done(): + return echo.NewHTTPError(http.StatusInternalServerError, "request cancelled").WithInternal(ctx.Err()) + case <-time.After(maxCheckDuration): + return echo.NewHTTPError(http.StatusInternalServerError, "checks didn't complete in time") + } + + var ( + unauthorizedErrors []error + internalErrors []error + allErrors []error + ) + + for i, check := range results { + if check.Error != nil { + if errors.Is(check.Error, query.ErrActionNotAssigned) { + err := fmt.Errorf("subject '%s' does not have permission to perform action '%s' on resource '%s'", + subject, check.Action, check.Resource.ID.String()) + + unauthorizedErrors = append(unauthorizedErrors, err) + allErrors = append(allErrors, err) + } else { + err := fmt.Errorf("check %d: %w", i, check.Error) + + internalErrors = append(internalErrors, err) + allErrors = append(allErrors, err) + } + } + } + + if len(internalErrors) != 0 { + return echo.NewHTTPError(http.StatusInternalServerError, "an error occurred checking permissions").SetInternal(multierr.Combine(allErrors...)) + } + + if len(unauthorizedErrors) != 0 { + msg := fmt.Sprintf("subject '%s' does not have permission to the requests resource actions", subject) + + return echo.NewHTTPError(http.StatusForbidden, msg).SetInternal(multierr.Combine(allErrors...)) + } + + return nil +} + func getParam(c echo.Context, name string) (string, bool) { values, ok := c.QueryParams()[name] if !ok { diff --git a/internal/api/router.go b/internal/api/router.go index 458071bd..fe39424b 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -18,22 +18,32 @@ type Router struct { authMW echo.MiddlewareFunc engine query.Engine logger *zap.SugaredLogger + + concurrentChecks int } // NewRouter returns a new api router -func NewRouter(authCfg echojwtx.AuthConfig, engine query.Engine, l *zap.SugaredLogger) (*Router, error) { +func NewRouter(authCfg echojwtx.AuthConfig, engine query.Engine, options ...Option) (*Router, error) { auth, err := echojwtx.NewAuth(context.Background(), authCfg) if err != nil { return nil, err } - out := &Router{ + router := &Router{ authMW: auth.Middleware(), engine: engine, - logger: l.Named("api"), + logger: zap.NewNop().Sugar(), + + concurrentChecks: 5, + } + + for _, opt := range options { + if err := opt(router); err != nil { + return nil, err + } } - return out, nil + return router, nil } // Routes will add the routes for this API version to a router group @@ -58,6 +68,32 @@ func (r *Router) Routes(rg *echo.Group) { // /allow is the permissions check endpoint v1.GET("/allow", r.checkAction) + v1.POST("/allow", r.checkAllActions) + } +} + +// Option defines a router option function. +type Option func(r *Router) error + +// WithLogger sets the logger for the router. +func WithLogger(logger *zap.SugaredLogger) Option { + return func(r *Router) error { + r.logger = logger.Named("api") + + return nil + } +} + +// WithCheckConcurrency sets the check concurrency for bulk permission checks. +func WithCheckConcurrency(count int) Option { + return func(r *Router) error { + if count <= 0 { + count = 5 + } + + r.concurrentChecks = count + + return nil } } From ade6fc1c00e25a52ce43a7d6b108097bf4556226 Mon Sep 17 00:00:00 2001 From: Mike Mason Date: Mon, 24 Jul 2023 16:58:53 +0000 Subject: [PATCH 2/5] split out checker code Signed-off-by: Mike Mason --- pkg/permissions/checker.go | 54 ++++++++++++++++++++++++++++++++++ pkg/permissions/permissions.go | 44 --------------------------- 2 files changed, 54 insertions(+), 44 deletions(-) create mode 100644 pkg/permissions/checker.go diff --git a/pkg/permissions/checker.go b/pkg/permissions/checker.go new file mode 100644 index 00000000..991dc92e --- /dev/null +++ b/pkg/permissions/checker.go @@ -0,0 +1,54 @@ +package permissions + +import ( + "context" + + "github.com/labstack/echo/v4" + "go.infratographer.com/x/gidx" +) + +var ( + // CheckerCtxKey is the context key used to set the checker handling function + CheckerCtxKey = checkerCtxKey{} + + // DefaultAllowChecker defaults to allow when checker is disabled or skipped + DefaultAllowChecker Checker = func(_ context.Context, _ gidx.PrefixedID, _ string) error { + return nil + } + + // DefaultDenyChecker defaults to denied when checker is disabled or skipped + DefaultDenyChecker Checker = func(_ context.Context, _ gidx.PrefixedID, _ string) error { + return ErrPermissionDenied + } +) + +// Checker defines the checker function definition +type Checker func(ctx context.Context, resource gidx.PrefixedID, action string) error + +type checkerCtxKey struct{} + +func setCheckerContext(c echo.Context, checker Checker) { + if checker == nil { + checker = DefaultDenyChecker + } + + req := c.Request().WithContext( + context.WithValue( + c.Request().Context(), + CheckerCtxKey, + checker, + ), + ) + + c.SetRequest(req) +} + +// CheckAccess runs the checker function to check if the provided resource and action are supported. +func CheckAccess(ctx context.Context, resource gidx.PrefixedID, action string) error { + checker, ok := ctx.Value(CheckerCtxKey).(Checker) + if !ok { + return ErrCheckerNotFound + } + + return checker(ctx, resource, action) +} diff --git a/pkg/permissions/permissions.go b/pkg/permissions/permissions.go index 8bc72541..e3763863 100644 --- a/pkg/permissions/permissions.go +++ b/pkg/permissions/permissions.go @@ -27,19 +27,6 @@ const ( ) var ( - // CheckerCtxKey is the context key used to set the checker handling function - CheckerCtxKey = checkerCtxKey{} - - // DefaultAllowChecker defaults to allow when checker is disabled or skipped - DefaultAllowChecker Checker = func(_ context.Context, _ gidx.PrefixedID, _ string) error { - return nil - } - - // DefaultDenyChecker defaults to denied when checker is disabled or skipped - DefaultDenyChecker Checker = func(_ context.Context, _ gidx.PrefixedID, _ string) error { - return ErrPermissionDenied - } - defaultClient = &http.Client{ Timeout: defaultClientTimeout, Transport: otelhttp.NewTransport(http.DefaultTransport), @@ -48,11 +35,6 @@ var ( tracer = otel.GetTracerProvider().Tracer("go.infratographer.com/permissions-api/pkg/permissions") ) -// Checker defines the checker function definition -type Checker func(ctx context.Context, resource gidx.PrefixedID, action string) error - -type checkerCtxKey struct{} - // Permissions handles supporting authorization checks type Permissions struct { enabled bool @@ -191,22 +173,6 @@ func New(config Config, options ...Option) (*Permissions, error) { return p, nil } -func setCheckerContext(c echo.Context, checker Checker) { - if checker == nil { - checker = DefaultDenyChecker - } - - req := c.Request().WithContext( - context.WithValue( - c.Request().Context(), - CheckerCtxKey, - checker, - ), - ) - - c.SetRequest(req) -} - func ensureValidServerResponse(resp *http.Response) error { if resp.StatusCode >= http.StatusMultiStatus { if resp.StatusCode == http.StatusForbidden { @@ -218,13 +184,3 @@ func ensureValidServerResponse(resp *http.Response) error { return nil } - -// CheckAccess runs the checker function to check if the provided resource and action are supported. -func CheckAccess(ctx context.Context, resource gidx.PrefixedID, action string) error { - checker, ok := ctx.Value(CheckerCtxKey).(Checker) - if !ok { - return ErrCheckerNotFound - } - - return checker(ctx, resource, action) -} From 030f958cb75b4e2ee4725be11ead19fba27a59a6 Mon Sep 17 00:00:00 2001 From: Mike Mason Date: Mon, 24 Jul 2023 17:21:08 +0000 Subject: [PATCH 3/5] add bulk permission checks to client Signed-off-by: Mike Mason --- pkg/permissions/checker.go | 29 +++++++++++++++--- pkg/permissions/permissions.go | 46 +++++++++++++++++++---------- pkg/permissions/permissions_test.go | 38 +++++++++++++++++------- 3 files changed, 83 insertions(+), 30 deletions(-) diff --git a/pkg/permissions/checker.go b/pkg/permissions/checker.go index 991dc92e..1a2f5e58 100644 --- a/pkg/permissions/checker.go +++ b/pkg/permissions/checker.go @@ -12,18 +12,24 @@ var ( CheckerCtxKey = checkerCtxKey{} // DefaultAllowChecker defaults to allow when checker is disabled or skipped - DefaultAllowChecker Checker = func(_ context.Context, _ gidx.PrefixedID, _ string) error { + DefaultAllowChecker Checker = func(_ context.Context, _ ...AccessRequest) error { return nil } // DefaultDenyChecker defaults to denied when checker is disabled or skipped - DefaultDenyChecker Checker = func(_ context.Context, _ gidx.PrefixedID, _ string) error { + DefaultDenyChecker Checker = func(_ context.Context, _ ...AccessRequest) error { return ErrPermissionDenied } ) // Checker defines the checker function definition -type Checker func(ctx context.Context, resource gidx.PrefixedID, action string) error +type Checker func(ctx context.Context, requests ...AccessRequest) error + +// AccessRequest defines the required fields to check permissions access. +type AccessRequest struct { + ResourceID gidx.PrefixedID `json:"resource_id"` + Action string `json:"action"` +} type checkerCtxKey struct{} @@ -50,5 +56,20 @@ func CheckAccess(ctx context.Context, resource gidx.PrefixedID, action string) e return ErrCheckerNotFound } - return checker(ctx, resource, action) + request := AccessRequest{ + ResourceID: resource, + Action: action, + } + + return checker(ctx, request) +} + +// CheckAll runs the checker function to check if all the provided resources and actions are permitted. +func CheckAll(ctx context.Context, requests ...AccessRequest) error { + checker, ok := ctx.Value(CheckerCtxKey).(Checker) + if !ok { + return ErrCheckerNotFound + } + + return checker(ctx, requests...) } diff --git a/pkg/permissions/permissions.go b/pkg/permissions/permissions.go index e3763863..33b64d47 100644 --- a/pkg/permissions/permissions.go +++ b/pkg/permissions/permissions.go @@ -1,7 +1,10 @@ package permissions import ( + "bytes" "context" + "encoding/json" + "fmt" "io" "net/http" "net/url" @@ -12,7 +15,6 @@ import ( "github.com/labstack/echo/v4/middleware" "github.com/pkg/errors" "go.infratographer.com/x/echojwtx" - "go.infratographer.com/x/gidx" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" @@ -79,35 +81,49 @@ func (p *Permissions) Middleware() echo.MiddlewareFunc { } } +type checkPermissionRequest struct { + Actions []AccessRequest `json:"actions"` +} + func (p *Permissions) checker(c echo.Context, actor, token string) Checker { - return func(ctx context.Context, resource gidx.PrefixedID, action string) error { - ctx, span := tracer.Start(ctx, "permissions.checkAccess") + return func(ctx context.Context, requests ...AccessRequest) error { + ctx, span := tracer.Start(ctx, "permissions.checker") defer span.End() span.SetAttributes( attribute.String("permissions.actor", actor), - attribute.String("permissions.action", action), - attribute.String("permissions.resource", resource.String()), + attribute.Int("permissions.requests", len(requests)), ) - logger := p.logger.With("actor", actor, "resource", resource.String(), "action", action) + logger := p.logger.With("actor", actor, "requests", len(requests)) + + request := checkPermissionRequest{ + Actions: requests, + } + + var reqBody bytes.Buffer + + if err := json.NewEncoder(&reqBody).Encode(request); err != nil { + err = errors.WithStack(err) - values := url.Values{} - values.Add("resource", resource.String()) - values.Add("action", action) + span.SetStatus(codes.Error, err.Error()) + logger.Errorw("failed to encode request body", "error", err) - url := *p.url - url.RawQuery = values.Encode() + return err + } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url.String(), &reqBody) if err != nil { - span.SetStatus(codes.Error, errors.WithStack(err).Error()) + err = errors.WithStack(err) + + span.SetStatus(codes.Error, err.Error()) logger.Errorw("failed to create checker request", "error", err) - return errors.WithStack(err) + return err } req.Header.Set(echo.HeaderAuthorization, c.Request().Header.Get(echo.HeaderAuthorization)) + req.Header.Set(echo.HeaderContentType, "application/json") resp, err := p.client.Do(req) if err != nil { @@ -179,7 +195,7 @@ func ensureValidServerResponse(resp *http.Response) error { return ErrPermissionDenied } - return ErrBadResponse + return fmt.Errorf("%w: %d", ErrBadResponse, resp.StatusCode) } return nil diff --git a/pkg/permissions/permissions_test.go b/pkg/permissions/permissions_test.go index d5f9b58b..90949920 100644 --- a/pkg/permissions/permissions_test.go +++ b/pkg/permissions/permissions_test.go @@ -1,6 +1,7 @@ package permissions_test import ( + "encoding/json" "net/http" "net/http/httptest" "testing" @@ -30,25 +31,40 @@ func TestPermissions(t *testing.T) { return } - resource, err := gidx.Parse(r.URL.Query().Get("resource")) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - - return + var reqBody struct { + Actions []struct { + ResourceID string `json:"resource_id"` + Action string `json:"action"` + } `json:"actions"` } - action := r.URL.Query().Get("action") - - if resource != allowedID && resource != deniedID { + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { w.WriteHeader(http.StatusInternalServerError) return } - if resource != allowedID || !actions[action] { - w.WriteHeader(http.StatusForbidden) + for _, request := range reqBody.Actions { + resource, err := gidx.Parse(request.ResourceID) + if err != nil { + w.WriteHeader(http.StatusBadRequest) - return + return + } + + action := request.Action + + if resource != allowedID && resource != deniedID { + w.WriteHeader(http.StatusInternalServerError) + + return + } + + if resource != allowedID || !actions[action] { + w.WriteHeader(http.StatusForbidden) + + return + } } })) From aefb5ec5262114b36e24a480e25547f990a2ca3f Mon Sep 17 00:00:00 2001 From: Mike Mason Date: Mon, 24 Jul 2023 17:45:41 +0000 Subject: [PATCH 4/5] correct lint issues Signed-off-by: Mike Mason --- internal/api/permissions.go | 20 ++++++++++++++++---- internal/api/router.go | 2 +- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/internal/api/permissions.go b/internal/api/permissions.go index 1a335c1c..93411ae0 100644 --- a/internal/api/permissions.go +++ b/internal/api/permissions.go @@ -14,7 +14,19 @@ import ( "go.uber.org/multierr" ) -const maxCheckDuration = 5 * time.Second +const ( + defaultMaxCheckConcurrency = 5 + + maxCheckDuration = 5 * time.Second +) + +var ( + // ErrNoActionDefined is the error returned when an access request is has no action defined + ErrNoActionDefined = errors.New("no action defined") + + // ErrAccessDenied is returned when access is denied + ErrAccessDenied = errors.New("access denied") +) // checkAction will check if a subject is allowed to perform an action on a resource. // This is the permissions check endpoint. @@ -128,7 +140,7 @@ func (r *Router) checkAllActions(c echo.Context) error { for i, check := range reqBody.Actions { if check.Action == "" { - errs = append(errs, fmt.Errorf("check %d: no action defined", i)) + errs = append(errs, fmt.Errorf("check %d: %w", i, ErrNoActionDefined)) continue } @@ -214,8 +226,8 @@ func (r *Router) checkAllActions(c echo.Context) error { for i, check := range results { if check.Error != nil { if errors.Is(check.Error, query.ErrActionNotAssigned) { - err := fmt.Errorf("subject '%s' does not have permission to perform action '%s' on resource '%s'", - subject, check.Action, check.Resource.ID.String()) + err := fmt.Errorf("%w: subject '%s' does not have permission to perform action '%s' on resource '%s'", + ErrAccessDenied, subject, check.Action, check.Resource.ID.String()) unauthorizedErrors = append(unauthorizedErrors, err) allErrors = append(allErrors, err) diff --git a/internal/api/router.go b/internal/api/router.go index fe39424b..bd3d1c29 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -34,7 +34,7 @@ func NewRouter(authCfg echojwtx.AuthConfig, engine query.Engine, options ...Opti engine: engine, logger: zap.NewNop().Sugar(), - concurrentChecks: 5, + concurrentChecks: defaultMaxCheckConcurrency, } for _, opt := range options { From db7b6815052c1134d1340ec9e7bdfe1c8e71f74a Mon Sep 17 00:00:00 2001 From: Mike Mason Date: Tue, 25 Jul 2023 20:14:49 +0000 Subject: [PATCH 5/5] implement review suggestions - Splits up the Request and Results - Switch to using context.WithTimeout instead of time.After to ensure context is cancelled - Replaces WaitGroup with looping through the known count and logging all errors Signed-off-by: Mike Mason --- internal/api/permissions.go | 125 +++++++++++++++++++----------------- 1 file changed, 65 insertions(+), 60 deletions(-) diff --git a/internal/api/permissions.go b/internal/api/permissions.go index 93411ae0..ed2f2073 100644 --- a/internal/api/permissions.go +++ b/internal/api/permissions.go @@ -1,10 +1,10 @@ package api import ( + "context" "errors" "fmt" "net/http" - "sync" "time" "github.com/labstack/echo/v4" @@ -99,10 +99,15 @@ type checkAction struct { Action string `json:"action"` } -type checkStatus struct { +type checkRequest struct { + Index int Resource types.Resource Action string - Error error +} + +type checkResult struct { + Request checkRequest + Error error } // checkAllActions will check if a subject is allowed to perform an action on a list of resources. @@ -136,7 +141,7 @@ func (r *Router) checkAllActions(c echo.Context) error { var errs []error - results := make([]*checkStatus, len(reqBody.Actions)) + requestsCh := make(chan checkRequest, len(reqBody.Actions)) for i, check := range reqBody.Actions { if check.Action == "" { @@ -159,93 +164,93 @@ func (r *Router) checkAllActions(c echo.Context) error { continue } - results[i] = &checkStatus{ + requestsCh <- checkRequest{ + Index: i, Resource: resource, Action: check.Action, } } + close(requestsCh) + if len(errs) != 0 { return echo.NewHTTPError(http.StatusBadRequest, "invalid check request").SetInternal(multierr.Combine(errs...)) } - checkCh := make(chan *checkStatus) + resultsCh := make(chan checkResult, len(reqBody.Actions)) - wg := new(sync.WaitGroup) + ctx, cancel := context.WithTimeout(ctx, maxCheckDuration) - for i := 0; i < r.concurrentChecks; i++ { - wg.Add(1) + defer cancel() + for i := 0; i < r.concurrentChecks; i++ { go func() { - defer wg.Done() - - for check := range checkCh { - // Check the permissions - err := r.engine.SubjectHasPermission(ctx, subjectResource, check.Action, check.Resource) - if err != nil { - check.Error = err + for { + var result checkResult + + select { + case check, ok := <-requestsCh: + // if channel is closed, quit the go routine. + if !ok { + return + } + + result.Request = check + + // Check the permissions + err := r.engine.SubjectHasPermission(ctx, subjectResource, check.Action, check.Resource) + if err != nil { + result.Error = err + } + case <-ctx.Done(): + result.Error = ctx.Err() } + + resultsCh <- result } }() } - wg.Add(1) - - go func() { - defer wg.Done() - - for _, check := range results { - checkCh <- check - } - - close(checkCh) - }() - - doneCh := make(chan struct{}) - - go func() { - defer close(doneCh) - - wg.Wait() - }() - - select { - case <-doneCh: - case <-ctx.Done(): - return echo.NewHTTPError(http.StatusInternalServerError, "request cancelled").WithInternal(ctx.Err()) - case <-time.After(maxCheckDuration): - return echo.NewHTTPError(http.StatusInternalServerError, "checks didn't complete in time") - } - var ( - unauthorizedErrors []error - internalErrors []error + unauthorizedErrors int + internalErrors int allErrors []error ) - for i, check := range results { - if check.Error != nil { - if errors.Is(check.Error, query.ErrActionNotAssigned) { - err := fmt.Errorf("%w: subject '%s' does not have permission to perform action '%s' on resource '%s'", - ErrAccessDenied, subject, check.Action, check.Resource.ID.String()) + for i := 0; i < len(reqBody.Actions); i++ { + select { + case result := <-resultsCh: + if result.Error != nil { + if errors.Is(result.Error, query.ErrActionNotAssigned) { + err := fmt.Errorf("%w: subject '%s' does not have permission to perform action '%s' on resource '%s'", + ErrAccessDenied, subject, result.Request.Action, result.Request.Resource.ID.String()) - unauthorizedErrors = append(unauthorizedErrors, err) - allErrors = append(allErrors, err) - } else { - err := fmt.Errorf("check %d: %w", i, check.Error) + unauthorizedErrors++ + + allErrors = append(allErrors, err) + } else { + err := fmt.Errorf("check %d: %w", result.Request.Index, result.Error) + + internalErrors++ + + allErrors = append(allErrors, err) + } + } + case <-ctx.Done(): + if err := ctx.Err(); err != nil { + internalErrors++ - internalErrors = append(internalErrors, err) - allErrors = append(allErrors, err) + allErrors = append(allErrors, ctx.Err()) } } } - if len(internalErrors) != 0 { + if internalErrors != 0 { return echo.NewHTTPError(http.StatusInternalServerError, "an error occurred checking permissions").SetInternal(multierr.Combine(allErrors...)) } - if len(unauthorizedErrors) != 0 { - msg := fmt.Sprintf("subject '%s' does not have permission to the requests resource actions", subject) + if unauthorizedErrors != 0 { + msg := fmt.Sprintf("subject '%s' does not have permission to the requested resource actions", subject) return echo.NewHTTPError(http.StatusForbidden, msg).SetInternal(multierr.Combine(allErrors...)) }