Skip to content

Commit

Permalink
Support bulk checks (#146)
Browse files Browse the repository at this point in the history
* add support for bulk permission check requests

Signed-off-by: Mike Mason <[email protected]>

* split out checker code

Signed-off-by: Mike Mason <[email protected]>

* add bulk permission checks to client

Signed-off-by: Mike Mason <[email protected]>

* correct lint issues

Signed-off-by: Mike Mason <[email protected]>

* implement review suggestions

- Splits up the Request and Results
- Switch to using context.WithTimeout instead of time.After to ensure context is cancelled
- Replaces WaitGroup with looping through the known count and logging all errors

Signed-off-by: Mike Mason <[email protected]>

---------

Signed-off-by: Mike Mason <[email protected]>
  • Loading branch information
mikemrm authored Jul 26, 2023
1 parent 7364ba4 commit 53a5013
Show file tree
Hide file tree
Showing 7 changed files with 361 additions and 76 deletions.
2 changes: 1 addition & 1 deletion cmd/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func serve(ctx context.Context, cfg *config.AppConfig) {
logger.Fatal("failed to initialize new server", zap.Error(err))
}

r, err := api.NewRouter(cfg.OIDC, engine, logger)
r, err := api.NewRouter(cfg.OIDC, engine, api.WithLogger(logger))
if err != nil {
logger.Fatalw("unable to initialize router", "error", err)
}
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ require (
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.42.0
go.opentelemetry.io/otel v1.16.0
go.opentelemetry.io/otel/trace v1.16.0
go.uber.org/multierr v1.9.0
go.uber.org/zap v1.24.0
google.golang.org/grpc v1.56.1
gopkg.in/yaml.v3 v3.0.1
Expand Down Expand Up @@ -91,7 +92,6 @@ require (
go.opentelemetry.io/otel/sdk v1.16.0 // indirect
go.opentelemetry.io/proto/otlp v0.19.0 // indirect
go.uber.org/atomic v1.10.0 // indirect
go.uber.org/multierr v1.9.0 // indirect
golang.org/x/crypto v0.11.0 // indirect
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1 // indirect
golang.org/x/net v0.12.0 // indirect
Expand Down
186 changes: 186 additions & 0 deletions internal/api/permissions.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,31 @@
package api

import (
"context"
"errors"
"fmt"
"net/http"
"time"

"github.com/labstack/echo/v4"
"go.infratographer.com/permissions-api/internal/query"
"go.infratographer.com/permissions-api/internal/types"
"go.infratographer.com/x/gidx"
"go.uber.org/multierr"
)

const (
defaultMaxCheckConcurrency = 5

maxCheckDuration = 5 * time.Second
)

var (
// ErrNoActionDefined is the error returned when an access request is has no action defined
ErrNoActionDefined = errors.New("no action defined")

// ErrAccessDenied is returned when access is denied
ErrAccessDenied = errors.New("access denied")
)

// checkAction will check if a subject is allowed to perform an action on a resource.
Expand Down Expand Up @@ -72,6 +90,174 @@ func (r *Router) checkAction(c echo.Context) error {
return c.JSON(http.StatusOK, echo.Map{})
}

type checkPermissionsRequest struct {
Actions []checkAction `json:"actions"`
}

type checkAction struct {
ResourceID string `json:"resource_id"`
Action string `json:"action"`
}

type checkRequest struct {
Index int
Resource types.Resource
Action string
}

type checkResult struct {
Request checkRequest
Error error
}

// checkAllActions will check if a subject is allowed to perform an action on a list of resources.
// This is the permissions check endpoint.
// It will return a 200 if the subject is allowed to perform all requested resource actions.
// It will return a 400 if the request is invalid.
// It will return a 403 if the subject is not allowed to perform all requested resource actions.
//
// Note that this expects a JWT token to be present in the request. This token must
// contain the subject of the request in the "sub" claim.
func (r *Router) checkAllActions(c echo.Context) error {
ctx, span := tracer.Start(c.Request().Context(), "api.checkAllActions")
defer span.End()

// Subject validation
subject, err := currentSubject(c)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "failed to get the subject").SetInternal(err)
}

subjectResource, err := r.engine.NewResourceFromID(subject)
if err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "error processing subject ID").SetInternal(err)
}

var reqBody checkPermissionsRequest

if err := c.Bind(&reqBody); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "error parsing request body").SetInternal(err)
}

var errs []error

requestsCh := make(chan checkRequest, len(reqBody.Actions))

for i, check := range reqBody.Actions {
if check.Action == "" {
errs = append(errs, fmt.Errorf("check %d: %w", i, ErrNoActionDefined))

continue
}

resourceID, err := gidx.Parse(check.ResourceID)
if err != nil {
errs = append(errs, fmt.Errorf("check %d: %w: error parsing resource id: %s", i, err, check.ResourceID))

continue
}

resource, err := r.engine.NewResourceFromID(resourceID)
if err != nil {
errs = append(errs, fmt.Errorf("check %d: %w: error creating resource from id: %s", i, err, resourceID.String()))

continue
}

requestsCh <- checkRequest{
Index: i,
Resource: resource,
Action: check.Action,
}
}

close(requestsCh)

if len(errs) != 0 {
return echo.NewHTTPError(http.StatusBadRequest, "invalid check request").SetInternal(multierr.Combine(errs...))
}

resultsCh := make(chan checkResult, len(reqBody.Actions))

ctx, cancel := context.WithTimeout(ctx, maxCheckDuration)

defer cancel()

for i := 0; i < r.concurrentChecks; i++ {
go func() {
for {
var result checkResult

select {
case check, ok := <-requestsCh:
// if channel is closed, quit the go routine.
if !ok {
return
}

result.Request = check

// Check the permissions
err := r.engine.SubjectHasPermission(ctx, subjectResource, check.Action, check.Resource)
if err != nil {
result.Error = err
}
case <-ctx.Done():
result.Error = ctx.Err()
}

resultsCh <- result
}
}()
}

var (
unauthorizedErrors int
internalErrors int
allErrors []error
)

for i := 0; i < len(reqBody.Actions); i++ {
select {
case result := <-resultsCh:
if result.Error != nil {
if errors.Is(result.Error, query.ErrActionNotAssigned) {
err := fmt.Errorf("%w: subject '%s' does not have permission to perform action '%s' on resource '%s'",
ErrAccessDenied, subject, result.Request.Action, result.Request.Resource.ID.String())

unauthorizedErrors++

allErrors = append(allErrors, err)
} else {
err := fmt.Errorf("check %d: %w", result.Request.Index, result.Error)

internalErrors++

allErrors = append(allErrors, err)
}
}
case <-ctx.Done():
if err := ctx.Err(); err != nil {
internalErrors++

allErrors = append(allErrors, ctx.Err())
}
}
}

if internalErrors != 0 {
return echo.NewHTTPError(http.StatusInternalServerError, "an error occurred checking permissions").SetInternal(multierr.Combine(allErrors...))
}

if unauthorizedErrors != 0 {
msg := fmt.Sprintf("subject '%s' does not have permission to the requested resource actions", subject)

return echo.NewHTTPError(http.StatusForbidden, msg).SetInternal(multierr.Combine(allErrors...))
}

return nil
}

func getParam(c echo.Context, name string) (string, bool) {
values, ok := c.QueryParams()[name]
if !ok {
Expand Down
44 changes: 40 additions & 4 deletions internal/api/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,32 @@ type Router struct {
authMW echo.MiddlewareFunc
engine query.Engine
logger *zap.SugaredLogger

concurrentChecks int
}

// NewRouter returns a new api router
func NewRouter(authCfg echojwtx.AuthConfig, engine query.Engine, l *zap.SugaredLogger) (*Router, error) {
func NewRouter(authCfg echojwtx.AuthConfig, engine query.Engine, options ...Option) (*Router, error) {
auth, err := echojwtx.NewAuth(context.Background(), authCfg)
if err != nil {
return nil, err
}

out := &Router{
router := &Router{
authMW: auth.Middleware(),
engine: engine,
logger: l.Named("api"),
logger: zap.NewNop().Sugar(),

concurrentChecks: defaultMaxCheckConcurrency,
}

for _, opt := range options {
if err := opt(router); err != nil {
return nil, err
}
}

return out, nil
return router, nil
}

// Routes will add the routes for this API version to a router group
Expand All @@ -58,6 +68,32 @@ func (r *Router) Routes(rg *echo.Group) {

// /allow is the permissions check endpoint
v1.GET("/allow", r.checkAction)
v1.POST("/allow", r.checkAllActions)
}
}

// Option defines a router option function.
type Option func(r *Router) error

// WithLogger sets the logger for the router.
func WithLogger(logger *zap.SugaredLogger) Option {
return func(r *Router) error {
r.logger = logger.Named("api")

return nil
}
}

// WithCheckConcurrency sets the check concurrency for bulk permission checks.
func WithCheckConcurrency(count int) Option {
return func(r *Router) error {
if count <= 0 {
count = 5
}

r.concurrentChecks = count

return nil
}
}

Expand Down
75 changes: 75 additions & 0 deletions pkg/permissions/checker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
package permissions

import (
"context"

"github.com/labstack/echo/v4"
"go.infratographer.com/x/gidx"
)

var (
// CheckerCtxKey is the context key used to set the checker handling function
CheckerCtxKey = checkerCtxKey{}

// DefaultAllowChecker defaults to allow when checker is disabled or skipped
DefaultAllowChecker Checker = func(_ context.Context, _ ...AccessRequest) error {
return nil
}

// DefaultDenyChecker defaults to denied when checker is disabled or skipped
DefaultDenyChecker Checker = func(_ context.Context, _ ...AccessRequest) error {
return ErrPermissionDenied
}
)

// Checker defines the checker function definition
type Checker func(ctx context.Context, requests ...AccessRequest) error

// AccessRequest defines the required fields to check permissions access.
type AccessRequest struct {
ResourceID gidx.PrefixedID `json:"resource_id"`
Action string `json:"action"`
}

type checkerCtxKey struct{}

func setCheckerContext(c echo.Context, checker Checker) {
if checker == nil {
checker = DefaultDenyChecker
}

req := c.Request().WithContext(
context.WithValue(
c.Request().Context(),
CheckerCtxKey,
checker,
),
)

c.SetRequest(req)
}

// CheckAccess runs the checker function to check if the provided resource and action are supported.
func CheckAccess(ctx context.Context, resource gidx.PrefixedID, action string) error {
checker, ok := ctx.Value(CheckerCtxKey).(Checker)
if !ok {
return ErrCheckerNotFound
}

request := AccessRequest{
ResourceID: resource,
Action: action,
}

return checker(ctx, request)
}

// CheckAll runs the checker function to check if all the provided resources and actions are permitted.
func CheckAll(ctx context.Context, requests ...AccessRequest) error {
checker, ok := ctx.Value(CheckerCtxKey).(Checker)
if !ok {
return ErrCheckerNotFound
}

return checker(ctx, requests...)
}
Loading

0 comments on commit 53a5013

Please sign in to comment.