diff --git a/internal/api/permissions.go b/internal/api/permissions.go index 93411ae09..48984da29 100644 --- a/internal/api/permissions.go +++ b/internal/api/permissions.go @@ -1,10 +1,10 @@ package api import ( + "context" "errors" "fmt" "net/http" - "sync" "time" "github.com/labstack/echo/v4" @@ -99,10 +99,15 @@ type checkAction struct { Action string `json:"action"` } -type checkStatus struct { +type checkRequest struct { + Index int Resource types.Resource Action string - Error error +} + +type checkResult struct { + Request checkRequest + Error error } // checkAllActions will check if a subject is allowed to perform an action on a list of resources. @@ -136,7 +141,7 @@ func (r *Router) checkAllActions(c echo.Context) error { var errs []error - results := make([]*checkStatus, len(reqBody.Actions)) + requestsCh := make(chan checkRequest, len(reqBody.Actions)) for i, check := range reqBody.Actions { if check.Action == "" { @@ -159,93 +164,108 @@ func (r *Router) checkAllActions(c echo.Context) error { continue } - results[i] = &checkStatus{ + requestsCh <- checkRequest{ + Index: i, Resource: resource, Action: check.Action, } } + close(requestsCh) + if len(errs) != 0 { return echo.NewHTTPError(http.StatusBadRequest, "invalid check request").SetInternal(multierr.Combine(errs...)) } - checkCh := make(chan *checkStatus) + doneCh := make(chan bool) + resultsCh := make(chan checkResult, len(reqBody.Actions)) - wg := new(sync.WaitGroup) + ctx, cancel := context.WithTimeout(ctx, maxCheckDuration) - for i := 0; i < r.concurrentChecks; i++ { - wg.Add(1) + defer cancel() + for i := 0; i < r.concurrentChecks; i++ { go func() { - defer wg.Done() + for check := range requestsCh { + result := &checkResult{ + Request: check, + } - for check := range checkCh { // Check the permissions err := r.engine.SubjectHasPermission(ctx, subjectResource, check.Action, check.Resource) if err != nil { - check.Error = err + result.Error = err + } + + // Check if doneCh has been closed, if so, don't write to resultsCh. + select { + case <-doneCh: + return + default: } + + resultsCh <- *result } }() } - wg.Add(1) + var ( + unauthorizedErrors int + internalErrors int + allErrors []error + ) go func() { - defer wg.Done() + var count int - for _, check := range results { - checkCh <- check - } + for result := range resultsCh { + count++ - close(checkCh) - }() + if result.Error != nil { + if errors.Is(result.Error, query.ErrActionNotAssigned) { + err := fmt.Errorf("%w: subject '%s' does not have permission to perform action '%s' on resource '%s'", + ErrAccessDenied, subject, result.Request.Action, result.Request.Resource.ID.String()) - doneCh := make(chan struct{}) + unauthorizedErrors++ - go func() { - defer close(doneCh) + allErrors = append(allErrors, err) + } else { + err := fmt.Errorf("check %d: %w", result.Request.Index, result.Error) - wg.Wait() - }() + internalErrors++ - select { - case <-doneCh: - case <-ctx.Done(): - return echo.NewHTTPError(http.StatusInternalServerError, "request cancelled").WithInternal(ctx.Err()) - case <-time.After(maxCheckDuration): - return echo.NewHTTPError(http.StatusInternalServerError, "checks didn't complete in time") - } - - var ( - unauthorizedErrors []error - internalErrors []error - allErrors []error - ) + allErrors = append(allErrors, err) + } - for i, check := range results { - if check.Error != nil { - if errors.Is(check.Error, query.ErrActionNotAssigned) { - err := fmt.Errorf("%w: subject '%s' does not have permission to perform action '%s' on resource '%s'", - ErrAccessDenied, subject, check.Action, check.Resource.ID.String()) + close(doneCh) + close(resultsCh) - unauthorizedErrors = append(unauthorizedErrors, err) - allErrors = append(allErrors, err) - } else { - err := fmt.Errorf("check %d: %w", i, check.Error) + return + } - internalErrors = append(internalErrors, err) - allErrors = append(allErrors, err) + if count == len(reqBody.Actions) { + close(doneCh) + close(resultsCh) } } + }() + + select { + case <-doneCh: + case <-ctx.Done(): + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + return echo.NewHTTPError(http.StatusInternalServerError, "checks didn't complete in time") + } + + return echo.NewHTTPError(http.StatusInternalServerError, "request cancelled").WithInternal(ctx.Err()) } - if len(internalErrors) != 0 { + if internalErrors != 0 { return echo.NewHTTPError(http.StatusInternalServerError, "an error occurred checking permissions").SetInternal(multierr.Combine(allErrors...)) } - if len(unauthorizedErrors) != 0 { - msg := fmt.Sprintf("subject '%s' does not have permission to the requests resource actions", subject) + if unauthorizedErrors != 0 { + msg := fmt.Sprintf("subject '%s' does not have permission to the requested resource actions", subject) return echo.NewHTTPError(http.StatusForbidden, msg).SetInternal(multierr.Combine(allErrors...)) }