diff --git a/lib/auth/access_request_test.go b/lib/auth/access_request_test.go index 989a19ebf3c76..ee634728f06ce 100644 --- a/lib/auth/access_request_test.go +++ b/lib/auth/access_request_test.go @@ -25,6 +25,7 @@ import ( "time" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -33,10 +34,14 @@ import ( "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/types/accesslist" + "github.com/gravitational/teleport/api/types/header" "github.com/gravitational/teleport/api/utils/sshutils" "github.com/gravitational/teleport/lib/auth/testauthority" + "github.com/gravitational/teleport/lib/backend/memory" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/local" "github.com/gravitational/teleport/lib/tlsca" ) @@ -932,3 +937,103 @@ func TestPromotedRequest(t *testing.T) { require.Equal(t, "ACL title", promotedRequest.GetPromotedAccessListTitle()) }) } + +func TestUpdateAccessRequestWithAdditionalReviewers(t *testing.T) { + t.Parallel() + + clock := clockwork.NewFakeClock() + + mustRequest := func(suggestedReviewers ...string) types.AccessRequest { + req, err := services.NewAccessRequest("test-user", "admins") + require.NoError(t, err) + req.SetSuggestedReviewers(suggestedReviewers) + return req + } + + mustAccessList := func(name string, owners ...string) *accesslist.AccessList { + ownersSpec := make([]accesslist.Owner, len(owners)) + for i, owner := range owners { + ownersSpec[i] = accesslist.Owner{ + Name: owner, + } + } + accessList, err := accesslist.NewAccessList(header.Metadata{ + Name: name, + }, accesslist.Spec{ + Title: "simple", + Grants: accesslist.Grants{ + Roles: []string{"grant-role"}, + }, + Audit: accesslist.Audit{ + NextAuditDate: clock.Now().AddDate(1, 0, 0), + }, + Owners: ownersSpec, + }) + require.NoError(t, err) + return accessList + } + + tests := []struct { + name string + req types.AccessRequest + accessLists []*accesslist.AccessList + promotions *types.AccessRequestAllowedPromotions + expectedReviewers []string + }{ + { + name: "nil promotions", + req: mustRequest("rev1", "rev2"), + expectedReviewers: []string{"rev1", "rev2"}, + }, + { + name: "a few promotions", + req: mustRequest("rev1", "rev2"), + accessLists: []*accesslist.AccessList{ + mustAccessList("name1", "owner1", "owner2"), + mustAccessList("name2", "owner1", "owner3"), + mustAccessList("name3", "owner4", "owner5"), + }, + promotions: &types.AccessRequestAllowedPromotions{ + Promotions: []*types.AccessRequestAllowedPromotion{ + {AccessListName: "name1"}, + {AccessListName: "name2"}, + }, + }, + expectedReviewers: []string{"rev1", "rev2", "owner1", "owner2", "owner3"}, + }, + { + name: "no promotions", + req: mustRequest("rev1", "rev2"), + accessLists: []*accesslist.AccessList{ + mustAccessList("name1", "owner1", "owner2"), + mustAccessList("name2", "owner1", "owner3"), + mustAccessList("name3", "owner4", "owner5"), + }, + promotions: &types.AccessRequestAllowedPromotions{ + Promotions: []*types.AccessRequestAllowedPromotion{}, + }, + expectedReviewers: []string{"rev1", "rev2"}, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + mem, err := memory.New(memory.Config{}) + require.NoError(t, err) + accessLists, err := local.NewAccessListService(mem, clock) + require.NoError(t, err) + + ctx := context.Background() + for _, accessList := range test.accessLists { + _, err = accessLists.UpsertAccessList(ctx, accessList) + require.NoError(t, err) + } + + req := test.req.Copy() + updateAccessRequestWithAdditionalReviewers(ctx, req, accessLists, test.promotions) + require.ElementsMatch(t, test.expectedReviewers, req.GetSuggestedReviewers()) + }) + } +} diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 4168069b5ba4c..ecd86e715817a 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -51,6 +51,7 @@ import ( "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/exporters/otlp/otlptrace" "golang.org/x/crypto/ssh" + "golang.org/x/exp/maps" "golang.org/x/exp/slices" "golang.org/x/time/rate" "google.golang.org/protobuf/types/known/durationpb" @@ -4220,6 +4221,9 @@ func (a *Server) CreateAccessRequestV2(ctx context.Context, req types.AccessRequ } if req.GetDryRun() { + _, promotions := a.generateAccessRequestPromotions(ctx, req) + // update the request with additional reviewers if possible. + updateAccessRequestWithAdditionalReviewers(ctx, req, a.AccessListClient(), promotions) // Made it this far with no errors, return before creating the request // if this is a dry run. return req, nil @@ -4253,14 +4257,10 @@ func (a *Server) CreateAccessRequestV2(ctx context.Context, req types.AccessRequ if err != nil { log.WithError(err).Warn("Failed to emit access request create event.") } + // calculate the promotions - reqCopy := req.Copy() - promotions, err := modules.GetModules().GenerateAccessRequestPromotions(ctx, a.Services, reqCopy) - if err != nil { - // Do not fail the request if the promotions failed to generate. - // The request promotion will be blocked, but the request can still be approved. - log.WithError(err).Warn("Failed to generate access list promotions.") - } else if promotions != nil { + reqCopy, promotions := a.generateAccessRequestPromotions(ctx, req) + if promotions != nil { // Create the promotion entry even if the allowed promotion is empty. Otherwise, we won't // be able to distinguish between an allowed empty set and generation failure. if err := a.Services.CreateAccessRequestAllowedPromotions(ctx, reqCopy, promotions); err != nil { @@ -4274,6 +4274,48 @@ func (a *Server) CreateAccessRequestV2(ctx context.Context, req types.AccessRequ return req, nil } +// generateAccessRequestPromotions will return potential access list promotions for an access request. On error, this function will log +// the error and return whatever it has. The caller is expected to deal with the possibility of a nil promotions object. +func (a *Server) generateAccessRequestPromotions(ctx context.Context, req types.AccessRequest) (types.AccessRequest, *types.AccessRequestAllowedPromotions) { + reqCopy := req.Copy() + promotions, err := modules.GetModules().GenerateAccessRequestPromotions(ctx, a.Services, reqCopy) + if err != nil { + // Do not fail the request if the promotions failed to generate. + // The request promotion will be blocked, but the request can still be approved. + log.WithError(err).Warn("Failed to generate access list promotions.") + } + return reqCopy, promotions +} + +// updateAccessRequestWithAdditionalReviewers will update the given access request with additional reviewers given the promotions +// created for the access request. +func updateAccessRequestWithAdditionalReviewers(ctx context.Context, req types.AccessRequest, accessLists services.AccessListsGetter, promotions *types.AccessRequestAllowedPromotions) { + if promotions == nil { + return + } + + // For promotions, add in access list owners as additional suggested reviewers + additionalReviewers := map[string]struct{}{} + + // Iterate through the promotions, adding the owners of the corresponding access lists as reviewers. + for _, promotion := range promotions.Promotions { + accessList, err := accessLists.GetAccessList(ctx, promotion.AccessListName) + if err != nil { + log.WithError(err).Warn("Failed to get access list, skipping additional reviewers") + break + } + + for _, owner := range accessList.GetOwners() { + additionalReviewers[owner.Name] = struct{}{} + } + } + + // Only modify the original request if additional reviewers were found. + if len(additionalReviewers) > 0 { + req.SetSuggestedReviewers(append(req.GetSuggestedReviewers(), maps.Keys(additionalReviewers)...)) + } +} + func (a *Server) DeleteAccessRequest(ctx context.Context, name string) error { if err := a.Services.DeleteAccessRequest(ctx, name); err != nil { return trace.Wrap(err)