From db7b6815052c1134d1340ec9e7bdfe1c8e71f74a Mon Sep 17 00:00:00 2001 From: Mike Mason Date: Tue, 25 Jul 2023 20:14:49 +0000 Subject: [PATCH] implement review suggestions - Splits up the Request and Results - Switch to using context.WithTimeout instead of time.After to ensure context is cancelled - Replaces WaitGroup with looping through the known count and logging all errors Signed-off-by: Mike Mason --- internal/api/permissions.go | 125 +++++++++++++++++++----------------- 1 file changed, 65 insertions(+), 60 deletions(-) diff --git a/internal/api/permissions.go b/internal/api/permissions.go index 93411ae0..ed2f2073 100644 --- a/internal/api/permissions.go +++ b/internal/api/permissions.go @@ -1,10 +1,10 @@ package api import ( + "context" "errors" "fmt" "net/http" - "sync" "time" "github.com/labstack/echo/v4" @@ -99,10 +99,15 @@ type checkAction struct { Action string `json:"action"` } -type checkStatus struct { +type checkRequest struct { + Index int Resource types.Resource Action string - Error error +} + +type checkResult struct { + Request checkRequest + Error error } // checkAllActions will check if a subject is allowed to perform an action on a list of resources. @@ -136,7 +141,7 @@ func (r *Router) checkAllActions(c echo.Context) error { var errs []error - results := make([]*checkStatus, len(reqBody.Actions)) + requestsCh := make(chan checkRequest, len(reqBody.Actions)) for i, check := range reqBody.Actions { if check.Action == "" { @@ -159,93 +164,93 @@ func (r *Router) checkAllActions(c echo.Context) error { continue } - results[i] = &checkStatus{ + requestsCh <- checkRequest{ + Index: i, Resource: resource, Action: check.Action, } } + close(requestsCh) + if len(errs) != 0 { return echo.NewHTTPError(http.StatusBadRequest, "invalid check request").SetInternal(multierr.Combine(errs...)) } - checkCh := make(chan *checkStatus) + resultsCh := make(chan checkResult, len(reqBody.Actions)) - wg := new(sync.WaitGroup) + ctx, cancel := context.WithTimeout(ctx, maxCheckDuration) - for i := 0; i < r.concurrentChecks; i++ { - wg.Add(1) + defer cancel() + for i := 0; i < r.concurrentChecks; i++ { go func() { - defer wg.Done() - - for check := range checkCh { - // Check the permissions - err := r.engine.SubjectHasPermission(ctx, subjectResource, check.Action, check.Resource) - if err != nil { - check.Error = err + for { + var result checkResult + + select { + case check, ok := <-requestsCh: + // if channel is closed, quit the go routine. + if !ok { + return + } + + result.Request = check + + // Check the permissions + err := r.engine.SubjectHasPermission(ctx, subjectResource, check.Action, check.Resource) + if err != nil { + result.Error = err + } + case <-ctx.Done(): + result.Error = ctx.Err() } + + resultsCh <- result } }() } - wg.Add(1) - - go func() { - defer wg.Done() - - for _, check := range results { - checkCh <- check - } - - close(checkCh) - }() - - doneCh := make(chan struct{}) - - go func() { - defer close(doneCh) - - wg.Wait() - }() - - select { - case <-doneCh: - case <-ctx.Done(): - return echo.NewHTTPError(http.StatusInternalServerError, "request cancelled").WithInternal(ctx.Err()) - case <-time.After(maxCheckDuration): - return echo.NewHTTPError(http.StatusInternalServerError, "checks didn't complete in time") - } - var ( - unauthorizedErrors []error - internalErrors []error + unauthorizedErrors int + internalErrors int allErrors []error ) - for i, check := range results { - if check.Error != nil { - if errors.Is(check.Error, query.ErrActionNotAssigned) { - err := fmt.Errorf("%w: subject '%s' does not have permission to perform action '%s' on resource '%s'", - ErrAccessDenied, subject, check.Action, check.Resource.ID.String()) + for i := 0; i < len(reqBody.Actions); i++ { + select { + case result := <-resultsCh: + if result.Error != nil { + if errors.Is(result.Error, query.ErrActionNotAssigned) { + err := fmt.Errorf("%w: subject '%s' does not have permission to perform action '%s' on resource '%s'", + ErrAccessDenied, subject, result.Request.Action, result.Request.Resource.ID.String()) - unauthorizedErrors = append(unauthorizedErrors, err) - allErrors = append(allErrors, err) - } else { - err := fmt.Errorf("check %d: %w", i, check.Error) + unauthorizedErrors++ + + allErrors = append(allErrors, err) + } else { + err := fmt.Errorf("check %d: %w", result.Request.Index, result.Error) + + internalErrors++ + + allErrors = append(allErrors, err) + } + } + case <-ctx.Done(): + if err := ctx.Err(); err != nil { + internalErrors++ - internalErrors = append(internalErrors, err) - allErrors = append(allErrors, err) + allErrors = append(allErrors, ctx.Err()) } } } - if len(internalErrors) != 0 { + if internalErrors != 0 { return echo.NewHTTPError(http.StatusInternalServerError, "an error occurred checking permissions").SetInternal(multierr.Combine(allErrors...)) } - if len(unauthorizedErrors) != 0 { - msg := fmt.Sprintf("subject '%s' does not have permission to the requests resource actions", subject) + if unauthorizedErrors != 0 { + msg := fmt.Sprintf("subject '%s' does not have permission to the requested resource actions", subject) return echo.NewHTTPError(http.StatusForbidden, msg).SetInternal(multierr.Combine(allErrors...)) }