Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions api/types/access_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (

"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/utils"
)

Expand Down Expand Up @@ -826,3 +827,25 @@ func NewAccessRequestAllowedPromotions(promotions []*AccessRequestAllowedPromoti
Promotions: promotions,
}
}

// ValidateAssumeStartTime returns error if start time is in an invalid range.
func ValidateAssumeStartTime(assumeStartTime time.Time, accessExpiry time.Time, creationTime time.Time) error {
// Guard against requesting a start time before the request creation time.
if assumeStartTime.Before(creationTime) {
return trace.BadParameter("assume start time has to be after %v", creationTime.Format(time.RFC3339))
}
// Guard against requesting a start time after access expiry.
if assumeStartTime.After(accessExpiry) || assumeStartTime.Equal(accessExpiry) {
return trace.BadParameter("assume start time must be prior to access expiry time at %v",
accessExpiry.Format(time.RFC3339))
}
// Access expiry can be greater than constants.MaxAssumeStartDuration, but start time
// should be on or before constants.MaxAssumeStartDuration.
maxAssumableStartTime := creationTime.Add(constants.MaxAssumeStartDuration)
if maxAssumableStartTime.Before(accessExpiry) && assumeStartTime.After(maxAssumableStartTime) {
return trace.BadParameter("assume start time is too far in the future, latest time allowed is %v",
maxAssumableStartTime.Format(time.RFC3339))
}

return nil
}
55 changes: 55 additions & 0 deletions api/types/access_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,67 @@ package types

import (
"testing"
"time"

"github.com/gravitational/trace"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/constants"
)

func TestAssertAccessRequestImplementsResourceWithLabels(t *testing.T) {
ar, err := NewAccessRequest("test", "test", "test")
require.NoError(t, err)
require.Implements(t, (*ResourceWithLabels)(nil), ar)
}

func TestValidateAssumeStartTime(t *testing.T) {
creation := time.Now().UTC()
const day = 24 * time.Hour

expiry := creation.Add(12 * day)
maxAssumeStartDuration := creation.Add(constants.MaxAssumeStartDuration)

testCases := []struct {
name string
startTime time.Time
errCheck require.ErrorAssertionFunc
}{
{
name: "start time too far in the future",
startTime: creation.Add(constants.MaxAssumeStartDuration + day),
errCheck: func(tt require.TestingT, err error, i ...any) {
require.ErrorIs(tt, err, trace.BadParameter("assume start time is too far in the future, latest time allowed is %v",
maxAssumeStartDuration.Format(time.RFC3339)))
},
},
{
name: "expired start time",
startTime: creation.Add(100 * day),
errCheck: func(tt require.TestingT, err error, i ...any) {
require.ErrorIs(t, err, trace.BadParameter("assume start time must be prior to access expiry time at %v",
expiry.Format(time.RFC3339)))
},
},
{
name: "before creation start time",
startTime: creation.Add(-10 * day),
errCheck: func(tt require.TestingT, err error, i ...any) {
require.ErrorIs(t, err, trace.BadParameter("assume start time has to be after %v",
creation.Format(time.RFC3339)))
},
},
{
name: "valid start time",
startTime: creation.Add(6 * day),
errCheck: require.NoError,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
err := ValidateAssumeStartTime(tc.startTime, expiry, creation)
tc.errCheck(t, err)
})
}
}
178 changes: 178 additions & 0 deletions lib/auth/access_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ func newAccessRequestTestPack(ctx context.Context, t *testing.T) *accessRequestT
Request: &types.AccessRequestConditions{
Roles: []string{"admins", "superadmins"},
SearchAsRoles: []string{"admins", "superadmins"},
MaxDuration: types.Duration(services.MaxAccessDuration),
},
},
},
Expand Down Expand Up @@ -1215,3 +1216,180 @@ func TestUpdateAccessRequestWithAdditionalReviewers(t *testing.T) {
})
}
}

func TestAssumeStartTime_CreateAccessRequestV2(t *testing.T) {
ctx := context.Background()
s := createAccessRequestWithStartTime(t)

testCases := []struct {
name string
startTime time.Time
errCheck require.ErrorAssertionFunc
}{
{
name: "after access expiry time",
startTime: s.invalidExpiredAssumeStartTime,
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.ErrorContains(t, err, "assume start time must be prior to access expiry time")
},
},
}

for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
req, err := services.NewAccessRequest(s.requesterUserName, "admins")
require.NoError(t, err)
req.SetMaxDuration(s.maxDuration)
req.SetAssumeStartTime(tc.startTime)
_, err = s.requesterClient.CreateAccessRequestV2(ctx, req)
tc.errCheck(t, err)
})
}
}

func TestAssumeStartTime_SubmitAccessReview(t *testing.T) {
ctx := context.Background()
s := createAccessRequestWithStartTime(t)

testCases := []struct {
name string
startTime time.Time
errCheck require.ErrorAssertionFunc
}{
{
name: "after access expiry time",
startTime: s.invalidExpiredAssumeStartTime,
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.ErrorContains(t, err, "assume start time must be prior to access expiry time")
},
},
{
name: "valid submission",
startTime: s.validStartTime,
errCheck: require.NoError,
},
}
review := types.AccessReviewSubmission{
RequestID: s.createdRequest.GetName(),
Review: types.AccessReview{
Author: "admin",
ProposedState: types.RequestState_APPROVED,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
review.Review.AssumeStartTime = &tc.startTime
resp, err := s.testPack.tlsServer.AuthServer.AuthServer.SubmitAccessReview(ctx, review)
tc.errCheck(t, err)
if err == nil {
require.Equal(t, tc.startTime, *resp.GetAssumeStartTime())
}
})
}
}

func TestAssumeStartTime_SetAccessRequestState(t *testing.T) {
ctx := context.Background()
s := createAccessRequestWithStartTime(t)

testCases := []struct {
name string
startTime time.Time
errCheck require.ErrorAssertionFunc
}{
{
name: "after access expiry time",
startTime: s.invalidExpiredAssumeStartTime,
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err), "expected bad parameter, got %v", err)
require.ErrorContains(t, err, "assume start time must be prior to access expiry time")
},
},
{
name: "valid set state",
startTime: s.validStartTime,
errCheck: require.NoError,
},
}
update := types.AccessRequestUpdate{
RequestID: s.createdRequest.GetName(),
State: types.RequestState_APPROVED,
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
update.AssumeStartTime = &tc.startTime
err := s.testPack.tlsServer.Auth().SetAccessRequestState(ctx, update)
tc.errCheck(t, err)
if err == nil {
resp, err := s.testPack.tlsServer.AuthServer.AuthServer.GetAccessRequests(ctx, types.AccessRequestFilter{})
require.NoError(t, err)
require.Len(t, resp, 1)
require.Equal(t, tc.startTime, *resp[0].GetAssumeStartTime())
}
})
}
}

type accessRequestWithStartTime struct {
testPack *accessRequestTestPack
requesterClient *Client
invalidMaxedAssumeStartTime time.Time
invalidExpiredAssumeStartTime time.Time
validStartTime time.Time
maxDuration time.Time
requesterUserName string
createdRequest types.AccessRequest
}

func createAccessRequestWithStartTime(t *testing.T) accessRequestWithStartTime {
t.Helper()
clock := clockwork.NewFakeClock()
now := clock.Now().UTC()

modules.SetTestModules(t, &modules.TestModules{TestBuildType: modules.BuildEnterprise})
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

testPack := newAccessRequestTestPack(ctx, t)

const requesterUserName = "requester"
requester := TestUser(requesterUserName)
requesterClient, err := testPack.tlsServer.NewClient(requester)
require.NoError(t, err)

t.Cleanup(func() { require.NoError(t, requesterClient.Close()) })

day := 24 * time.Hour

maxDuration := now.Add(services.MaxAccessDuration)

invalidMaxedAssumeStartTime := now.Add(constants.MaxAssumeStartDuration + (1 * day))
invalidExpiredAssumeStartTime := now.Add(100 * day)
validStartTime := now.Add(2 * day)

// create the access request object
req, err := services.NewAccessRequest(requesterUserName, "admins")
require.NoError(t, err)
req.SetMaxDuration(maxDuration)

req.SetAssumeStartTime(validStartTime)
createdReq, err := requesterClient.CreateAccessRequestV2(ctx, req)
require.NoError(t, err)
require.Equal(t, validStartTime, *createdReq.GetAssumeStartTime())

return accessRequestWithStartTime{
testPack: testPack,
requesterClient: requesterClient,
invalidMaxedAssumeStartTime: invalidMaxedAssumeStartTime,
invalidExpiredAssumeStartTime: invalidExpiredAssumeStartTime,
validStartTime: validStartTime,
maxDuration: maxDuration,
requesterUserName: requesterUserName,
createdRequest: createdReq,
}
}
25 changes: 16 additions & 9 deletions lib/services/access_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ const maxAccessRequestReasonSize = 4096
// A day is sometimes 23 hours, sometimes 25 hours, usually 24 hours.
const day = 24 * time.Hour

// maxAccessDuration is the maximum duration that an access request can be
// MaxAccessDuration is the maximum duration that an access request can be
// granted for.
const maxAccessDuration = 7 * day
const MaxAccessDuration = 7 * day

// ValidateAccessRequest validates the AccessRequest and sets default values
func ValidateAccessRequest(ar types.AccessRequest) error {
Expand Down Expand Up @@ -365,8 +365,8 @@ func ValidateAccessPredicates(role types.Role) error {
}

if maxDuration := role.GetAccessRequestConditions(types.Allow).MaxDuration; maxDuration.Duration() != 0 &&
maxDuration.Duration() > maxAccessDuration {
return trace.BadParameter("max access duration must be less or equal 7 days")
maxDuration.Duration() > MaxAccessDuration {
return trace.BadParameter("max access duration must be less than or equal to %v", MaxAccessDuration)
}

return nil
Expand Down Expand Up @@ -414,8 +414,8 @@ func ApplyAccessReview(req types.AccessRequest, rev types.AccessReview, author U
req.SetReviews(append(req.GetReviews(), rev))

if rev.AssumeStartTime != nil {
if rev.AssumeStartTime.After(req.GetAccessExpiry()) {
return trace.BadParameter("request start time is after expiry")
if err := types.ValidateAssumeStartTime(*rev.AssumeStartTime, req.GetAccessExpiry(), req.GetCreationTime()); err != nil {
return trace.Wrap(err)
}
req.SetAssumeStartTime(*rev.AssumeStartTime)
}
Expand Down Expand Up @@ -1210,6 +1210,13 @@ func (m *RequestValidator) Validate(ctx context.Context, req types.AccessRequest
req.SetAccessExpiry(accessTTL)
// Adjusted max access duration is equal to the access expiry time.
req.SetMaxDuration(accessTTL)

if req.GetAssumeStartTime() != nil {
assumeStartTime := *req.GetAssumeStartTime()
if err := types.ValidateAssumeStartTime(assumeStartTime, accessTTL, req.GetCreationTime()); err != nil {
return trace.Wrap(err)
}
}
}

return nil
Expand Down Expand Up @@ -1240,13 +1247,13 @@ func (m *RequestValidator) calculateMaxAccessDuration(req types.AccessRequest) (
// This prevents the time drift that can occur as the value is set on the client side.
// TODO(jakule): Replace with MaxAccessDuration that is a duration (5h, 4d etc), and not a point in time.
if req.GetDryRun() {
maxDuration = maxAccessDuration
maxDuration = MaxAccessDuration
} else if maxDuration < 0 {
return 0, trace.BadParameter("invalid maxDuration: must be greater than creation time")
}

if maxDuration > maxAccessDuration {
return 0, trace.BadParameter("max_duration must be less or equal 7 days")
if maxDuration > MaxAccessDuration {
return 0, trace.BadParameter("max_duration must be less than or equal to %v", MaxAccessDuration)
}

minAdjDuration := maxDuration
Expand Down
2 changes: 1 addition & 1 deletion lib/services/access_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ func TestReviewThresholds(t *testing.T) {
propose: approve,
assumeStartTime: clock.Now().UTC().Add(10000 * time.Hour),
errCheck: func(tt require.TestingT, err error, i ...interface{}) {
require.ErrorIs(tt, err, trace.BadParameter("request start time is after expiry"), i...)
require.ErrorContains(tt, err, "assume start time must be prior to access expiry time", i...)
},
},
},
Expand Down
7 changes: 7 additions & 0 deletions lib/services/local/dynamic_access.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,13 @@ func (s *DynamicAccessService) SetAccessRequestState(ctx context.Context, params
req.SetRoles(params.Roles)
}

if params.AssumeStartTime != nil {
if err := types.ValidateAssumeStartTime(*params.AssumeStartTime, req.GetAccessExpiry(), req.GetCreationTime()); err != nil {
return nil, trace.Wrap(err)
}
req.SetAssumeStartTime(*params.AssumeStartTime)
}

// approved requests should have a resource expiry which matches
// the underlying access expiry.
if params.State.IsApproved() {
Expand Down
5 changes: 0 additions & 5 deletions tool/tctl/common/access_request_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (
"github.com/jonboulle/clockwork"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/asciitable"
"github.com/gravitational/teleport/lib/auth"
Expand Down Expand Up @@ -236,10 +235,6 @@ func (c *AccessRequestCommand) Approve(ctx context.Context, client *auth.Client)
if err != nil {
return trace.BadParameter("parsing assume-start-time (required format RFC3339 e.g 2023-12-12T23:20:50.52Z): %v", err)
}
if time.Until(parsedAssumeStartTime) > constants.MaxAssumeStartDuration {
return trace.BadParameter("assume-start-time too far in future: latest date %q",
parsedAssumeStartTime.Add(constants.MaxAssumeStartDuration).Format(time.RFC3339))
}
assumeStartTime = &parsedAssumeStartTime
}
for _, reqID := range strings.Split(c.reqIDs, ",") {
Expand Down
Loading