diff --git a/pkg/permissions/checker.go b/pkg/permissions/checker.go index 991dc92e8..1a2f5e581 100644 --- a/pkg/permissions/checker.go +++ b/pkg/permissions/checker.go @@ -12,18 +12,24 @@ var ( CheckerCtxKey = checkerCtxKey{} // DefaultAllowChecker defaults to allow when checker is disabled or skipped - DefaultAllowChecker Checker = func(_ context.Context, _ gidx.PrefixedID, _ string) error { + DefaultAllowChecker Checker = func(_ context.Context, _ ...AccessRequest) error { return nil } // DefaultDenyChecker defaults to denied when checker is disabled or skipped - DefaultDenyChecker Checker = func(_ context.Context, _ gidx.PrefixedID, _ string) error { + DefaultDenyChecker Checker = func(_ context.Context, _ ...AccessRequest) error { return ErrPermissionDenied } ) // Checker defines the checker function definition -type Checker func(ctx context.Context, resource gidx.PrefixedID, action string) error +type Checker func(ctx context.Context, requests ...AccessRequest) error + +// AccessRequest defines the required fields to check permissions access. +type AccessRequest struct { + ResourceID gidx.PrefixedID `json:"resource_id"` + Action string `json:"action"` +} type checkerCtxKey struct{} @@ -50,5 +56,20 @@ func CheckAccess(ctx context.Context, resource gidx.PrefixedID, action string) e return ErrCheckerNotFound } - return checker(ctx, resource, action) + request := AccessRequest{ + ResourceID: resource, + Action: action, + } + + return checker(ctx, request) +} + +// CheckAll runs the checker function to check if all the provided resources and actions are permitted. +func CheckAll(ctx context.Context, requests ...AccessRequest) error { + checker, ok := ctx.Value(CheckerCtxKey).(Checker) + if !ok { + return ErrCheckerNotFound + } + + return checker(ctx, requests...) } diff --git a/pkg/permissions/permissions.go b/pkg/permissions/permissions.go index e37638637..33b64d472 100644 --- a/pkg/permissions/permissions.go +++ b/pkg/permissions/permissions.go @@ -1,7 +1,10 @@ package permissions import ( + "bytes" "context" + "encoding/json" + "fmt" "io" "net/http" "net/url" @@ -12,7 +15,6 @@ import ( "github.com/labstack/echo/v4/middleware" "github.com/pkg/errors" "go.infratographer.com/x/echojwtx" - "go.infratographer.com/x/gidx" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" @@ -79,35 +81,49 @@ func (p *Permissions) Middleware() echo.MiddlewareFunc { } } +type checkPermissionRequest struct { + Actions []AccessRequest `json:"actions"` +} + func (p *Permissions) checker(c echo.Context, actor, token string) Checker { - return func(ctx context.Context, resource gidx.PrefixedID, action string) error { - ctx, span := tracer.Start(ctx, "permissions.checkAccess") + return func(ctx context.Context, requests ...AccessRequest) error { + ctx, span := tracer.Start(ctx, "permissions.checker") defer span.End() span.SetAttributes( attribute.String("permissions.actor", actor), - attribute.String("permissions.action", action), - attribute.String("permissions.resource", resource.String()), + attribute.Int("permissions.requests", len(requests)), ) - logger := p.logger.With("actor", actor, "resource", resource.String(), "action", action) + logger := p.logger.With("actor", actor, "requests", len(requests)) + + request := checkPermissionRequest{ + Actions: requests, + } + + var reqBody bytes.Buffer + + if err := json.NewEncoder(&reqBody).Encode(request); err != nil { + err = errors.WithStack(err) - values := url.Values{} - values.Add("resource", resource.String()) - values.Add("action", action) + span.SetStatus(codes.Error, err.Error()) + logger.Errorw("failed to encode request body", "error", err) - url := *p.url - url.RawQuery = values.Encode() + return err + } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url.String(), nil) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, p.url.String(), &reqBody) if err != nil { - span.SetStatus(codes.Error, errors.WithStack(err).Error()) + err = errors.WithStack(err) + + span.SetStatus(codes.Error, err.Error()) logger.Errorw("failed to create checker request", "error", err) - return errors.WithStack(err) + return err } req.Header.Set(echo.HeaderAuthorization, c.Request().Header.Get(echo.HeaderAuthorization)) + req.Header.Set(echo.HeaderContentType, "application/json") resp, err := p.client.Do(req) if err != nil { @@ -179,7 +195,7 @@ func ensureValidServerResponse(resp *http.Response) error { return ErrPermissionDenied } - return ErrBadResponse + return fmt.Errorf("%w: %d", ErrBadResponse, resp.StatusCode) } return nil diff --git a/pkg/permissions/permissions_test.go b/pkg/permissions/permissions_test.go index d5f9b58ba..90949920e 100644 --- a/pkg/permissions/permissions_test.go +++ b/pkg/permissions/permissions_test.go @@ -1,6 +1,7 @@ package permissions_test import ( + "encoding/json" "net/http" "net/http/httptest" "testing" @@ -30,25 +31,40 @@ func TestPermissions(t *testing.T) { return } - resource, err := gidx.Parse(r.URL.Query().Get("resource")) - if err != nil { - w.WriteHeader(http.StatusBadRequest) - - return + var reqBody struct { + Actions []struct { + ResourceID string `json:"resource_id"` + Action string `json:"action"` + } `json:"actions"` } - action := r.URL.Query().Get("action") - - if resource != allowedID && resource != deniedID { + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { w.WriteHeader(http.StatusInternalServerError) return } - if resource != allowedID || !actions[action] { - w.WriteHeader(http.StatusForbidden) + for _, request := range reqBody.Actions { + resource, err := gidx.Parse(request.ResourceID) + if err != nil { + w.WriteHeader(http.StatusBadRequest) - return + return + } + + action := request.Action + + if resource != allowedID && resource != deniedID { + w.WriteHeader(http.StatusInternalServerError) + + return + } + + if resource != allowedID || !actions[action] { + w.WriteHeader(http.StatusForbidden) + + return + } } }))