Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions lib/auth/access_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"time"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand All @@ -33,10 +34,14 @@ import (
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/accesslist"
"github.com/gravitational/teleport/api/types/header"
"github.com/gravitational/teleport/api/utils/sshutils"
"github.com/gravitational/teleport/lib/auth/testauthority"
"github.com/gravitational/teleport/lib/backend/memory"
"github.com/gravitational/teleport/lib/modules"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/services/local"
"github.com/gravitational/teleport/lib/tlsca"
)

Expand Down Expand Up @@ -932,3 +937,103 @@ func TestPromotedRequest(t *testing.T) {
require.Equal(t, "ACL title", promotedRequest.GetPromotedAccessListTitle())
})
}

func TestUpdateAccessRequestWithAdditionalReviewers(t *testing.T) {
t.Parallel()

clock := clockwork.NewFakeClock()

mustRequest := func(suggestedReviewers ...string) types.AccessRequest {
req, err := services.NewAccessRequest("test-user", "admins")
require.NoError(t, err)
req.SetSuggestedReviewers(suggestedReviewers)
return req
}

mustAccessList := func(name string, owners ...string) *accesslist.AccessList {
ownersSpec := make([]accesslist.Owner, len(owners))
for i, owner := range owners {
ownersSpec[i] = accesslist.Owner{
Name: owner,
}
}
accessList, err := accesslist.NewAccessList(header.Metadata{
Name: name,
}, accesslist.Spec{
Title: "simple",
Grants: accesslist.Grants{
Roles: []string{"grant-role"},
},
Audit: accesslist.Audit{
NextAuditDate: clock.Now().AddDate(1, 0, 0),
},
Owners: ownersSpec,
})
require.NoError(t, err)
return accessList
}

tests := []struct {
name string
req types.AccessRequest
accessLists []*accesslist.AccessList
promotions *types.AccessRequestAllowedPromotions
expectedReviewers []string
}{
{
name: "nil promotions",
req: mustRequest("rev1", "rev2"),
expectedReviewers: []string{"rev1", "rev2"},
},
{
name: "a few promotions",
req: mustRequest("rev1", "rev2"),
accessLists: []*accesslist.AccessList{
mustAccessList("name1", "owner1", "owner2"),
mustAccessList("name2", "owner1", "owner3"),
mustAccessList("name3", "owner4", "owner5"),
},
promotions: &types.AccessRequestAllowedPromotions{
Promotions: []*types.AccessRequestAllowedPromotion{
{AccessListName: "name1"},
{AccessListName: "name2"},
},
},
expectedReviewers: []string{"rev1", "rev2", "owner1", "owner2", "owner3"},
},
{
name: "no promotions",
req: mustRequest("rev1", "rev2"),
accessLists: []*accesslist.AccessList{
mustAccessList("name1", "owner1", "owner2"),
mustAccessList("name2", "owner1", "owner3"),
mustAccessList("name3", "owner4", "owner5"),
},
promotions: &types.AccessRequestAllowedPromotions{
Promotions: []*types.AccessRequestAllowedPromotion{},
},
expectedReviewers: []string{"rev1", "rev2"},
},
}

for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
mem, err := memory.New(memory.Config{})
require.NoError(t, err)
accessLists, err := local.NewAccessListService(mem, clock)
require.NoError(t, err)

ctx := context.Background()
for _, accessList := range test.accessLists {
_, err = accessLists.UpsertAccessList(ctx, accessList)
require.NoError(t, err)
}

req := test.req.Copy()
updateAccessRequestWithAdditionalReviewers(ctx, req, accessLists, test.promotions)
require.ElementsMatch(t, test.expectedReviewers, req.GetSuggestedReviewers())
})
}
}
56 changes: 49 additions & 7 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ import (
"github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
"golang.org/x/crypto/ssh"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"golang.org/x/time/rate"
"google.golang.org/protobuf/types/known/durationpb"
Expand Down Expand Up @@ -4220,6 +4221,9 @@ func (a *Server) CreateAccessRequestV2(ctx context.Context, req types.AccessRequ
}

if req.GetDryRun() {
_, promotions := a.generateAccessRequestPromotions(ctx, req)
// update the request with additional reviewers if possible.
updateAccessRequestWithAdditionalReviewers(ctx, req, a.AccessListClient(), promotions)
// Made it this far with no errors, return before creating the request
// if this is a dry run.
return req, nil
Expand Down Expand Up @@ -4253,14 +4257,10 @@ func (a *Server) CreateAccessRequestV2(ctx context.Context, req types.AccessRequ
if err != nil {
log.WithError(err).Warn("Failed to emit access request create event.")
}

// calculate the promotions
reqCopy := req.Copy()
promotions, err := modules.GetModules().GenerateAccessRequestPromotions(ctx, a.Services, reqCopy)
if err != nil {
// Do not fail the request if the promotions failed to generate.
// The request promotion will be blocked, but the request can still be approved.
log.WithError(err).Warn("Failed to generate access list promotions.")
} else if promotions != nil {
reqCopy, promotions := a.generateAccessRequestPromotions(ctx, req)
if promotions != nil {
// Create the promotion entry even if the allowed promotion is empty. Otherwise, we won't
// be able to distinguish between an allowed empty set and generation failure.
if err := a.Services.CreateAccessRequestAllowedPromotions(ctx, reqCopy, promotions); err != nil {
Expand All @@ -4274,6 +4274,48 @@ func (a *Server) CreateAccessRequestV2(ctx context.Context, req types.AccessRequ
return req, nil
}

// generateAccessRequestPromotions will return potential access list promotions for an access request. On error, this function will log
// the error and return whatever it has. The caller is expected to deal with the possibility of a nil promotions object.
func (a *Server) generateAccessRequestPromotions(ctx context.Context, req types.AccessRequest) (types.AccessRequest, *types.AccessRequestAllowedPromotions) {
reqCopy := req.Copy()
promotions, err := modules.GetModules().GenerateAccessRequestPromotions(ctx, a.Services, reqCopy)
if err != nil {
// Do not fail the request if the promotions failed to generate.
// The request promotion will be blocked, but the request can still be approved.
log.WithError(err).Warn("Failed to generate access list promotions.")
}
return reqCopy, promotions
}

// updateAccessRequestWithAdditionalReviewers will update the given access request with additional reviewers given the promotions
// created for the access request.
func updateAccessRequestWithAdditionalReviewers(ctx context.Context, req types.AccessRequest, accessLists services.AccessListsGetter, promotions *types.AccessRequestAllowedPromotions) {
if promotions == nil {
return
}

// For promotions, add in access list owners as additional suggested reviewers
additionalReviewers := map[string]struct{}{}

// Iterate through the promotions, adding the owners of the corresponding access lists as reviewers.
for _, promotion := range promotions.Promotions {
accessList, err := accessLists.GetAccessList(ctx, promotion.AccessListName)
if err != nil {
log.WithError(err).Warn("Failed to get access list, skipping additional reviewers")
break
}

for _, owner := range accessList.GetOwners() {
additionalReviewers[owner.Name] = struct{}{}
}
}

// Only modify the original request if additional reviewers were found.
if len(additionalReviewers) > 0 {
req.SetSuggestedReviewers(append(req.GetSuggestedReviewers(), maps.Keys(additionalReviewers)...))
}
}

func (a *Server) DeleteAccessRequest(ctx context.Context, name string) error {
if err := a.Services.DeleteAccessRequest(ctx, name); err != nil {
return trace.Wrap(err)
Expand Down