diff --git a/api/client/client.go b/api/client/client.go index 6c92a59be0697..9bac7060c88a0 100644 --- a/api/client/client.go +++ b/api/client/client.go @@ -1040,16 +1040,6 @@ func (c *Client) GetAccessRequests(ctx context.Context, filter types.AccessReque return reqs, nil } -// CreateAccessRequest registers a new access request with the auth server. -func (c *Client) CreateAccessRequest(ctx context.Context, req types.AccessRequest) error { - r, ok := req.(*types.AccessRequestV3) - if !ok { - return trace.BadParameter("unexpected access request type %T", req) - } - _, err := c.grpc.CreateAccessRequest(ctx, r) - return trace.Wrap(err) -} - // CreateAccessRequestV2 registers a new access request with the auth server. func (c *Client) CreateAccessRequestV2(ctx context.Context, req types.AccessRequest) (types.AccessRequest, error) { r, ok := req.(*types.AccessRequestV3) diff --git a/e b/e index e412b8e570de9..93162b421a5ca 160000 --- a/e +++ b/e @@ -1 +1 @@ -Subproject commit e412b8e570de946a885d6ff6218421d78e5d9d7a +Subproject commit 93162b421a5caf23392540f5bf5a09b75bc60a12 diff --git a/integration/integration_test.go b/integration/integration_test.go index 12aaa718f2c8c..be8449195d0a8 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -6764,11 +6764,11 @@ func testSessionStartContainsAccessRequest(t *testing.T, suite *integrationTestS req, err := services.NewAccessRequest(suite.Me.Username, requestedRole.GetMetadata().Name) require.NoError(t, err) - accessRequestID := req.GetName() - - err = authServer.CreateAccessRequest(ctx, req, tlsca.Identity{}) + req, err = authServer.CreateAccessRequestV2(ctx, req, tlsca.Identity{}) require.NoError(t, err) + accessRequestID := req.GetName() + err = authServer.SetAccessRequestState(ctx, types.AccessRequestUpdate{ RequestID: accessRequestID, State: types.RequestState_APPROVED, diff --git a/integrations/access/discord/discord_test.go b/integrations/access/discord/discord_test.go index 16157f2076bb4..981b2158262b2 100644 --- a/integrations/access/discord/discord_test.go +++ b/integrations/access/discord/discord_test.go @@ -235,9 +235,9 @@ func (s *DiscordSuite) createAccessRequest() types.AccessRequest { t.Helper() req := s.newAccessRequest() - err := s.requestor().CreateAccessRequest(s.Context(), req) + out, err := s.requestor().CreateAccessRequestV2(s.Context(), req) require.NoError(t, err) - return req + return out } func (s *DiscordSuite) checkPluginData(reqID string, cond func(common.GenericPluginData) bool) common.GenericPluginData { @@ -626,7 +626,7 @@ func (s *DiscordSuite) TestRace() { if err != nil { return setRaceErr(trace.Wrap(err)) } - if err := s.requestor().CreateAccessRequest(ctx, req); err != nil { + if _, err := s.requestor().CreateAccessRequestV2(ctx, req); err != nil { return setRaceErr(trace.Wrap(err)) } return nil diff --git a/integrations/access/jira/jira_test.go b/integrations/access/jira/jira_test.go index 1fabc18ba3261..3c779993d9f58 100644 --- a/integrations/access/jira/jira_test.go +++ b/integrations/access/jira/jira_test.go @@ -249,9 +249,9 @@ func (s *JiraSuite) createAccessRequest() types.AccessRequest { t.Helper() req := s.newAccessRequest() - err := s.requestor().CreateAccessRequest(s.Context(), req) + out, err := s.requestor().CreateAccessRequestV2(s.Context(), req) require.NoError(t, err) - return req + return out } func (s *JiraSuite) checkPluginData(reqID string, cond func(PluginData) bool) PluginData { @@ -323,7 +323,8 @@ func (s *JiraSuite) TestIssueCreationWithRequestReason() { req := s.newAccessRequest() req.SetRequestReason("because of") - err := s.requestor().CreateAccessRequest(s.Context(), req) + var err error + req, err = s.requestor().CreateAccessRequestV2(s.Context(), req) require.NoError(t, err) s.checkPluginData(req.GetName(), func(data PluginData) bool { return data.IssueID != "" @@ -344,7 +345,8 @@ func (s *JiraSuite) TestIssueCreationWithLargeRequestReason() { req := s.newAccessRequest() req.SetRequestReason(strings.Repeat("a", jiraReasonLimit+10)) - err := s.requestor().CreateAccessRequest(s.Context(), req) + var err error + req, err = s.requestor().CreateAccessRequestV2(s.Context(), req) require.NoError(t, err) s.checkPluginData(req.GetName(), func(data PluginData) bool { return data.IssueID != "" @@ -757,7 +759,8 @@ func (s *JiraSuite) TestRace() { if err != nil { return setRaceErr(trace.Wrap(err)) } - if err = s.requestor().CreateAccessRequest(s.Context(), req); err != nil { + _, err = s.requestor().CreateAccessRequestV2(s.Context(), req) + if err != nil { return setRaceErr(trace.Wrap(err)) } return nil diff --git a/integrations/access/mattermost/mattermost_test.go b/integrations/access/mattermost/mattermost_test.go index 90f3f8feb7e16..bcc8d1a3962dc 100644 --- a/integrations/access/mattermost/mattermost_test.go +++ b/integrations/access/mattermost/mattermost_test.go @@ -256,9 +256,9 @@ func (s *MattermostSuite) createAccessRequest(reviewers []User) types.AccessRequ t.Helper() req := s.newAccessRequest(reviewers) - err := s.requestor().CreateAccessRequest(s.Context(), req) + out, err := s.requestor().CreateAccessRequestV2(s.Context(), req) require.NoError(s.T(), err) - return req + return out } func (s *MattermostSuite) checkPluginData(reqID string, cond func(common.GenericPluginData) bool) common.GenericPluginData { @@ -650,7 +650,7 @@ func (s *MattermostSuite) TestRace() { return setRaceErr(trace.Wrap(err)) } req.SetSuggestedReviewers([]string{reviewer1.Email, reviewer2.Email}) - if err := s.requestor().CreateAccessRequest(ctx, req); err != nil { + if _, err := s.requestor().CreateAccessRequestV2(ctx, req); err != nil { return setRaceErr(trace.Wrap(err)) } return nil diff --git a/integrations/access/opsgenie/opsgenie_test.go b/integrations/access/opsgenie/opsgenie_test.go index 230bb481a4e13..7c58a31ee8fb7 100644 --- a/integrations/access/opsgenie/opsgenie_test.go +++ b/integrations/access/opsgenie/opsgenie_test.go @@ -331,9 +331,9 @@ func (s *OpsgenieSuite) createAccessRequest() types.AccessRequest { t.Helper() req := s.newAccessRequest() - err := s.requestor().CreateAccessRequest(s.Context(), req) + out, err := s.requestor().CreateAccessRequestV2(s.Context(), req) require.NoError(t, err) - return req + return out } func (s *OpsgenieSuite) checkPluginData(reqID string, cond func(PluginData) bool) PluginData { diff --git a/integrations/access/pagerduty/pagerduty_test.go b/integrations/access/pagerduty/pagerduty_test.go index 77536cbd6c116..b164ba8632b19 100644 --- a/integrations/access/pagerduty/pagerduty_test.go +++ b/integrations/access/pagerduty/pagerduty_test.go @@ -340,9 +340,9 @@ func (s *PagerdutySuite) createAccessRequest() types.AccessRequest { t.Helper() req := s.newAccessRequest() - err := s.requestor().CreateAccessRequest(s.Context(), req) + out, err := s.requestor().CreateAccessRequestV2(s.Context(), req) require.NoError(t, err) - return req + return out } func (s *PagerdutySuite) checkPluginData(reqID string, cond func(PluginData) bool) PluginData { @@ -851,7 +851,8 @@ func (s *PagerdutySuite) TestRace() { if err != nil { return setRaceErr(trace.Wrap(err)) } - if err := s.clients[userName].CreateAccessRequest(ctx, req); err != nil { + req, err = s.clients[userName].CreateAccessRequestV2(ctx, req) + if err != nil { return setRaceErr(trace.Wrap(err)) } pendingRequests.Store(req.GetName(), struct{}{}) diff --git a/integrations/access/servicenow/servicenow_test.go b/integrations/access/servicenow/servicenow_test.go index 5bcdecb889e1f..a46fbbfd21905 100644 --- a/integrations/access/servicenow/servicenow_test.go +++ b/integrations/access/servicenow/servicenow_test.go @@ -321,9 +321,9 @@ func (s *ServiceNowSuite) createAccessRequest() types.AccessRequest { t.Helper() req := s.newAccessRequest() - err := s.requestor().CreateAccessRequest(s.Context(), req) + out, err := s.requestor().CreateAccessRequestV2(s.Context(), req) require.NoError(t, err) - return req + return out } func (s *ServiceNowSuite) checkPluginData(reqID string, cond func(PluginData) bool) PluginData { diff --git a/integrations/access/slack/slack_test.go b/integrations/access/slack/slack_test.go index d37592e7b3d14..45412bb279a2a 100644 --- a/integrations/access/slack/slack_test.go +++ b/integrations/access/slack/slack_test.go @@ -249,9 +249,9 @@ func (s *SlackSuite) createAccessRequest(reviewers []User) types.AccessRequest { t.Helper() req := s.newAccessRequest(reviewers) - err := s.requestor().CreateAccessRequest(s.Context(), req) + out, err := s.requestor().CreateAccessRequestV2(s.Context(), req) require.NoError(t, err) - return req + return out } func (s *SlackSuite) checkPluginData(reqID string, cond func(common.GenericPluginData) bool) common.GenericPluginData { @@ -651,7 +651,7 @@ func (s *SlackSuite) TestRace() { return setRaceErr(trace.Wrap(err)) } req.SetSuggestedReviewers([]string{reviewer1.Profile.Email, reviewer2.Profile.Email}) - if err := s.requestor().CreateAccessRequest(ctx, req); err != nil { + if _, err := s.requestor().CreateAccessRequestV2(ctx, req); err != nil { return setRaceErr(trace.Wrap(err)) } return nil diff --git a/lib/ai/model/tools/tool.go b/lib/ai/model/tools/tool.go index 2265f606d4a44..d263fd94962d3 100644 --- a/lib/ai/model/tools/tool.go +++ b/lib/ai/model/tools/tool.go @@ -62,7 +62,7 @@ type AccessPoint interface { // AccessRequestClient abstracts away the access request client for testing purposes. type AccessRequestClient interface { - CreateAccessRequest(ctx context.Context, req types.AccessRequest) error + CreateAccessRequestV2(ctx context.Context, req types.AccessRequest) (types.AccessRequest, error) GetAccessRequests(ctx context.Context, filter types.AccessRequestFilter) ([]types.AccessRequest, error) } diff --git a/lib/auth/access_request_test.go b/lib/auth/access_request_test.go index 5a60411c1aea8..3c55831e190d0 100644 --- a/lib/auth/access_request_test.go +++ b/lib/auth/access_request_test.go @@ -313,7 +313,7 @@ func testSingleAccessRequests(t *testing.T, testPack *accessRequestTestPack) { require.NoError(t, err) // send the request to the auth server - err = requesterClient.CreateAccessRequest(ctx, req) + req, err = requesterClient.CreateAccessRequestV2(ctx, req) require.ErrorIs(t, err, tc.expectRequestError) if tc.expectRequestError != nil { return diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 3a9748329f8ab..a9d0099f0e8a0 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -4087,10 +4087,6 @@ func (a *Server) DeleteNamespace(namespace string) error { } return a.Services.DeleteNamespace(namespace) } -func (a *Server) CreateAccessRequest(ctx context.Context, req types.AccessRequest, identity tlsca.Identity) error { - _, err := a.CreateAccessRequestV2(ctx, req, identity) - return trace.Wrap(err) -} func (a *Server) CreateAccessRequestV2(ctx context.Context, req types.AccessRequest, identity tlsca.Identity) (types.AccessRequest, error) { now := a.clock.Now().UTC() diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 0d9a9d838727e..2a49026293c39 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -24,6 +24,7 @@ import ( "time" "github.com/coreos/go-semver/semver" + "github.com/google/uuid" "github.com/gravitational/roundtrip" "github.com/gravitational/trace" "github.com/sirupsen/logrus" @@ -2649,11 +2650,6 @@ func (a *ServerWithRoles) GetAccessRequests(ctx context.Context, filter types.Ac return filtered, nil } -func (a *ServerWithRoles) CreateAccessRequest(ctx context.Context, req types.AccessRequest) error { - _, err := a.CreateAccessRequestV2(ctx, req) - return trace.Wrap(err) -} - func (a *ServerWithRoles) CreateAccessRequestV2(ctx context.Context, req types.AccessRequest) (types.AccessRequest, error) { // An exception is made to allow users to create access *pending* requests for themselves. if !req.GetState().IsPending() || a.currentUserAction(req.GetUser()) != nil { @@ -2661,6 +2657,10 @@ func (a *ServerWithRoles) CreateAccessRequestV2(ctx context.Context, req types.A return nil, trace.Wrap(err) } } + + // ensure request ID is set server-side + req.SetName(uuid.NewString()) + resp, err := a.authServer.CreateAccessRequestV2(ctx, req, a.context.Identity.GetIdentity()) return resp, trace.Wrap(err) } diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index a3de2eda0c920..4c812d35797e3 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -6352,10 +6352,19 @@ func TestCreateAccessRequest(t *testing.T) { client, err := srv.NewClient(TestUser(test.user)) require.NoError(t, err) - test.errAssertionFunc(t, client.CreateAccessRequest(ctx, test.accessRequest)) + req, err := client.CreateAccessRequestV2(ctx, test.accessRequest) + test.errAssertionFunc(t, err) + + if err != nil { + require.Nil(t, test.expected, "erroring test-cases should not assert expectations (this is a bug)") + return + } + + // id should be regenerated server-side + require.NotEqual(t, test.accessRequest.GetName(), req.GetName()) accessRequests, err := srv.Auth().GetAccessRequests(ctx, types.AccessRequestFilter{ - ID: test.accessRequest.GetName(), + ID: req.GetName(), }) require.NoError(t, err) diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index 12a2b1fecda11..327728aeae873 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -779,8 +779,7 @@ func (g *GRPCServer) GetAccessRequestsV2(f *types.AccessRequestFilter, stream au } func (g *GRPCServer) CreateAccessRequest(ctx context.Context, req *types.AccessRequestV3) (*emptypb.Empty, error) { - _, err := g.CreateAccessRequestV2(ctx, req) - return &emptypb.Empty{}, trace.Wrap(err) + return nil, trace.NotImplemented("access request creation API has changed, please update your client to v14 or newer") } func (g *GRPCServer) CreateAccessRequestV2(ctx context.Context, req *types.AccessRequestV3) (*types.AccessRequestV3, error) { @@ -796,10 +795,17 @@ func (g *GRPCServer) CreateAccessRequestV2(ctx context.Context, req *types.Acces return nil, trace.Wrap(err) } - if err := auth.ServerWithRoles.CreateAccessRequest(ctx, req); err != nil { + out, err := auth.ServerWithRoles.CreateAccessRequestV2(ctx, req) + if err != nil { return nil, trace.Wrap(err) } - return req, nil + + r, ok := out.(*types.AccessRequestV3) + if !ok { + return nil, trace.Wrap(trace.BadParameter("unexpected access request type %T", r)) + } + + return r, nil } func (g *GRPCServer) DeleteAccessRequest(ctx context.Context, id *authpb.RequestID) (*emptypb.Empty, error) { diff --git a/lib/auth/tls_test.go b/lib/auth/tls_test.go index 233d53565aa27..3c4c76eff45ec 100644 --- a/lib/auth/tls_test.go +++ b/lib/auth/tls_test.go @@ -1486,7 +1486,7 @@ func TestWebSessionMultiAccessRequests(t *testing.T) { require.NoError(t, err) roleReq.SetState(types.RequestState_APPROVED) roleReq.SetAccessExpiry(clock.Now().Add(8 * time.Hour)) - err = clt.CreateAccessRequest(ctx, roleReq) + roleReq, err = clt.CreateAccessRequestV2(ctx, roleReq) require.NoError(t, err) // Create remote cluster so create access request doesn't err due to non existent cluster @@ -1499,7 +1499,7 @@ func TestWebSessionMultiAccessRequests(t *testing.T) { resourceReq, err := services.NewAccessRequestWithResources(username, []string{resourceRequestRoleName}, resourceIDs) require.NoError(t, err) resourceReq.SetState(types.RequestState_APPROVED) - err = clt.CreateAccessRequest(ctx, resourceReq) + resourceReq, err = clt.CreateAccessRequestV2(ctx, resourceReq) require.NoError(t, err) // Create a web session and client for the user. @@ -1694,7 +1694,7 @@ func TestWebSessionWithApprovedAccessRequestAndSwitchback(t *testing.T) { accessReq.SetAccessExpiry(clock.Now().Add(time.Minute * 10)) accessReq.SetState(types.RequestState_APPROVED) - err = clt.CreateAccessRequest(ctx, accessReq) + accessReq, err = clt.CreateAccessRequestV2(ctx, accessReq) require.NoError(t, err) sess1, err := web.ExtendWebSession(ctx, WebSessionReq{ @@ -1898,7 +1898,7 @@ func TestExtendWebSessionWithMaxDuration(t *testing.T) { err = accessReq.SetState(types.RequestState_APPROVED) require.NoError(t, err) - err = adminClient.CreateAccessRequest(ctx, accessReq) + accessReq, err = adminClient.CreateAccessRequestV2(ctx, accessReq) require.NoError(t, err) sess1, err := userClient.ExtendWebSession(ctx, WebSessionReq{ @@ -2055,7 +2055,8 @@ func TestPluginData(t *testing.T) { req, err := services.NewAccessRequest(user, role) require.NoError(t, err) - require.NoError(t, userClient.CreateAccessRequest(ctx, req)) + req, err = userClient.CreateAccessRequestV2(ctx, req) + require.NoError(t, err) err = pluginClient.UpdatePluginData(ctx, types.PluginDataUpdateParams{ Kind: types.KindAccessRequest, diff --git a/lib/auth/usage_test.go b/lib/auth/usage_test.go index 6b9f06d3c2869..77322e21fc327 100644 --- a/lib/auth/usage_test.go +++ b/lib/auth/usage_test.go @@ -108,13 +108,13 @@ func TestAccessRequestLimit(t *testing.T) { // Check July req, err := types.NewAccessRequest(uuid.New().String(), "alice", "access") require.NoError(t, err) - err = p.a.CreateAccessRequest(ctx, req, tlsca.Identity{}) + _, err = p.a.CreateAccessRequestV2(ctx, req, tlsca.Identity{}) require.Error(t, err, "expected access request creation to fail due to the monthly limit") // Check August clock.Advance(31 * 24 * time.Hour) req, err = types.NewAccessRequest(uuid.New().String(), "alice", "access") require.NoError(t, err) - err = p.a.CreateAccessRequest(ctx, req, tlsca.Identity{}) + _, err = p.a.CreateAccessRequestV2(ctx, req, tlsca.Identity{}) require.NoError(t, err) } diff --git a/lib/cache/cache_test.go b/lib/cache/cache_test.go index ab87566ccded5..44fb7e917b2a7 100644 --- a/lib/cache/cache_test.go +++ b/lib/cache/cache_test.go @@ -404,7 +404,8 @@ func TestWatchers(t *testing.T) { req, err := services.NewAccessRequest("alice", "dictator") require.NoError(t, err) - require.NoError(t, p.dynamicAccessS.CreateAccessRequest(ctx, req)) + req, err = p.dynamicAccessS.CreateAccessRequestV2(ctx, req) + require.NoError(t, err) select { case e := <-w.Events(): @@ -429,7 +430,8 @@ func TestWatchers(t *testing.T) { require.NoError(t, err) // create and then delete the non-matching request. - require.NoError(t, p.dynamicAccessS.CreateAccessRequest(ctx, req2)) + req2, err = p.dynamicAccessS.CreateAccessRequestV2(ctx, req2) + require.NoError(t, err) require.NoError(t, p.dynamicAccessS.DeleteAccessRequest(ctx, req2.GetName())) // because our filter did not match the request, the create event should never diff --git a/lib/client/api.go b/lib/client/api.go index c7f254e586d6c..fee101b7797f4 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -1349,11 +1349,11 @@ func (tc *TeleportClient) IssueUserCertsWithMFA(ctx context.Context, params Reis return key, trace.Wrap(err) } -// CreateAccessRequest registers a new access request with the auth server. -func (tc *TeleportClient) CreateAccessRequest(ctx context.Context, req types.AccessRequest) error { +// CreateAccessRequestV2 registers a new access request with the auth server. +func (tc *TeleportClient) CreateAccessRequestV2(ctx context.Context, req types.AccessRequest) (types.AccessRequest, error) { ctx, span := tc.Tracer.Start( ctx, - "teleportClient/CreateAccessRequest", + "teleportClient/CreateAccessRequestV2", oteltrace.WithSpanKind(oteltrace.SpanKindClient), oteltrace.WithSpanKind(oteltrace.SpanKindClient), oteltrace.WithAttributes(attribute.String("request", req.GetName())), @@ -1362,11 +1362,11 @@ func (tc *TeleportClient) CreateAccessRequest(ctx context.Context, req types.Acc proxyClient, err := tc.ConnectToProxy(ctx) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } defer proxyClient.Close() - return proxyClient.CreateAccessRequest(ctx, req) + return proxyClient.CreateAccessRequestV2(ctx, req) } // GetAccessRequests loads all access requests matching the supplied filter. diff --git a/lib/client/client.go b/lib/client/client.go index bb5f7e02d820d..51e412f8ceead 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -693,11 +693,11 @@ func (proxy *ProxyClient) RootClusterName(ctx context.Context) (string, error) { return proxy.teleportClient.RootClusterName(ctx) } -// CreateAccessRequest registers a new access request with the auth server. -func (proxy *ProxyClient) CreateAccessRequest(ctx context.Context, req types.AccessRequest) error { +// CreateAccessRequestV2 registers a new access request with the auth server. +func (proxy *ProxyClient) CreateAccessRequestV2(ctx context.Context, req types.AccessRequest) (types.AccessRequest, error) { ctx, span := proxy.Tracer.Start( ctx, - "proxyClient/CreateAccessRequest", + "proxyClient/CreateAccessRequestV2", oteltrace.WithSpanKind(oteltrace.SpanKindClient), oteltrace.WithAttributes(attribute.String("request", req.GetName())), ) @@ -705,7 +705,7 @@ func (proxy *ProxyClient) CreateAccessRequest(ctx context.Context, req types.Acc site := proxy.CurrentCluster() - return site.CreateAccessRequest(ctx, req) + return site.CreateAccessRequestV2(ctx, req) } // GetAccessRequests loads all access requests matching the supplied filter. diff --git a/lib/services/access_request.go b/lib/services/access_request.go index 7bc28b6d8f752..3b9e837c7dded 100644 --- a/lib/services/access_request.go +++ b/lib/services/access_request.go @@ -167,8 +167,6 @@ type AccessRequestGetter interface { // DynamicAccessCore is the core functionality common to all DynamicAccess implementations. type DynamicAccessCore interface { AccessRequestGetter - // CreateAccessRequest stores a new access request. - CreateAccessRequest(ctx context.Context, req types.AccessRequest) error // CreateAccessRequestV2 stores a new access request. CreateAccessRequestV2(ctx context.Context, req types.AccessRequest) (types.AccessRequest, error) // DeleteAccessRequest deletes an access request. @@ -261,6 +259,8 @@ func (m *RequestValidator) applicableSearchAsRoles(ctx context.Context, resource // used to implement some auth server internals. type DynamicAccessExt interface { DynamicAccessCore + // CreateAccessRequest stores a new access request. + CreateAccessRequest(ctx context.Context, req types.AccessRequest) error // ApplyAccessReview applies a review to a request in the backend and returns the post-application state. ApplyAccessReview(ctx context.Context, params types.AccessReviewSubmission, checker ReviewPermissionChecker) (types.AccessRequest, error) // UpsertAccessRequest creates or updates an access request. diff --git a/lib/services/watcher_test.go b/lib/services/watcher_test.go index def125495bf10..9c6a0d4a34c47 100644 --- a/lib/services/watcher_test.go +++ b/lib/services/watcher_test.go @@ -1148,7 +1148,8 @@ func TestAccessRequestWatcher(t *testing.T) { // Add an access request. accessRequest1 := newAccessRequest(t, uuid.NewString()) - require.NoError(t, dynamicAccessService.CreateAccessRequest(ctx, accessRequest1)) + accessRequest1, err = dynamicAccessService.CreateAccessRequestV2(ctx, accessRequest1) + require.NoError(t, err) // The first event is always the current list of access requests. select { @@ -1163,7 +1164,8 @@ func TestAccessRequestWatcher(t *testing.T) { // Add a second access request. accessRequest2 := newAccessRequest(t, uuid.NewString()) - require.NoError(t, dynamicAccessService.CreateAccessRequest(ctx, accessRequest2)) + accessRequest2, err = dynamicAccessService.CreateAccessRequestV2(ctx, accessRequest2) + require.NoError(t, err) // Watcher should detect the access request list change. select { diff --git a/lib/teleterm/clusters/cluster_access_requests.go b/lib/teleterm/clusters/cluster_access_requests.go index 39f731e779808..c4d24687c6297 100644 --- a/lib/teleterm/clusters/cluster_access_requests.go +++ b/lib/teleterm/clusters/cluster_access_requests.go @@ -145,8 +145,10 @@ func (c *Cluster) CreateAccessRequest(ctx context.Context, req *api.CreateAccess request.SetRequestReason(req.Reason) request.SetSuggestedReviewers(req.SuggestedReviewers) + var reqOut types.AccessRequest err = AddMetadataToRetryableError(ctx, func() error { - return c.clusterClient.CreateAccessRequest(ctx, request) + reqOut, err = c.clusterClient.CreateAccessRequestV2(ctx, request) + return trace.Wrap(err) }) if err != nil { return nil, trace.Wrap(err) @@ -154,7 +156,7 @@ func (c *Cluster) CreateAccessRequest(ctx context.Context, req *api.CreateAccess return &AccessRequest{ URI: c.URI.AppendAccessRequest(request.GetName()), - AccessRequest: request, + AccessRequest: reqOut, }, nil } diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index e6daccf3006a3..30caef6eaca68 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -8234,7 +8234,7 @@ func TestUserContextWithAccessRequest(t *testing.T) { accessReq, err := services.NewAccessRequest(username, requestableRolename) require.NoError(t, err) accessReq.SetState(types.RequestState_APPROVED) - err = env.server.Auth().CreateAccessRequest(ctx, accessReq, identity) + accessReq, err = env.server.Auth().CreateAccessRequestV2(ctx, accessReq, identity) require.NoError(t, err) // Get the ID of the created and approved access request. diff --git a/tool/tctl/common/access_request_command.go b/tool/tctl/common/access_request_command.go index d2db54a91515c..da620c60edb2f 100644 --- a/tool/tctl/common/access_request_command.go +++ b/tool/tctl/common/access_request_command.go @@ -282,7 +282,8 @@ func (c *AccessRequestCommand) Create(ctx context.Context, client auth.ClientI) } return trace.Wrap(printJSON(req, "request")) } - if err := client.CreateAccessRequest(ctx, req); err != nil { + req, err = client.CreateAccessRequestV2(ctx, req) + if err != nil { return trace.Wrap(err) } fmt.Printf("%s\n", req.GetName()) diff --git a/tool/tsh/common/kube.go b/tool/tsh/common/kube.go index 51420ad203995..61e28481052f5 100644 --- a/tool/tsh/common/kube.go +++ b/tool/tsh/common/kube.go @@ -1685,7 +1685,8 @@ func (c *kubeLoginCommand) accessRequestForKubeCluster(ctx context.Context, cf * req.SetDryRun(true) req.SetRequestReason("Dry run, this request will not be created. If you see this, there is a bug.") if err := tc.WithRootClusterClient(ctx, func(clt auth.ClientI) error { - return trace.Wrap(clt.CreateAccessRequest(ctx, req)) + req, err = clt.CreateAccessRequestV2(ctx, req) + return trace.Wrap(err) }); err != nil { return nil, trace.Wrap(err) } diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index b7035157fc418..731621e2b6544 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -2499,28 +2499,17 @@ func executeAccessRequest(cf *CLIConf, tc *client.TeleportClient) error { } } - // Watch for resolution events on the given request. Start watcher and wait - // for it to be ready before creating the request to avoid a potential race. - requestWatcher := newAccessRequestWatcher(req) - defer requestWatcher.Close() - if !cf.NoWait { - // Don't initialize the watcher unless we'll actually use it. - if err := requestWatcher.initialize(cf.Context, tc); err != nil { - return trace.Wrap(err) - } - } - // Upsert request if it doesn't already exist. if cf.RequestID == "" { - cf.RequestID = req.GetName() fmt.Fprint(os.Stdout, "Creating request...\n") // always create access request against the root cluster if err := tc.WithRootClusterClient(cf.Context, func(clt auth.ClientI) error { - err := clt.CreateAccessRequest(cf.Context, req) + req, err = clt.CreateAccessRequestV2(cf.Context, req) return trace.Wrap(err) }); err != nil { return trace.Wrap(err) } + cf.RequestID = req.GetName() } onRequestShow(cf) @@ -2533,13 +2522,12 @@ func executeAccessRequest(cf *CLIConf, tc *client.TeleportClient) error { // Wait for the request to be resolved. fmt.Fprintf(os.Stdout, "Waiting for request approval...\n") - resolvedReq, err := requestWatcher.awaitResolution() - if err != nil { + + var resolvedReq types.AccessRequest + if err := tc.WithRootClusterClient(cf.Context, func(clt auth.ClientI) error { + resolvedReq, err = awaitRequestResolution(cf.Context, clt, req) return trace.Wrap(err) - } - if err := requestWatcher.Close(); err != nil { - // This was deferred above to catch all other error cases, here we - // actually handle any errors from requestWatcher.Close(). + }); err != nil { return trace.Wrap(err) } @@ -3123,7 +3111,8 @@ func accessRequestForSSH(ctx context.Context, _ *CLIConf, tc *client.TeleportCli req.SetDryRun(true) req.SetRequestReason("Dry run, this request will not be created. If you see this, there is a bug.") if err := tc.WithRootClusterClient(ctx, func(clt auth.ClientI) error { - return trace.Wrap(clt.CreateAccessRequest(ctx, req)) + req, err = clt.CreateAccessRequestV2(ctx, req) + return trace.Wrap(err) }); err != nil { return nil, trace.Wrap(err) } @@ -3174,18 +3163,11 @@ func retryWithAccessRequest( } req.SetRequestReason(requestReason) - // Watch for resolution events on the given request. Start watcher and wait - // for it to be ready before creating the request to avoid a potential race. - requestWatcher := newAccessRequestWatcher(req) - defer requestWatcher.Close() - if err := requestWatcher.initialize(cf.Context, tc); err != nil { - return trace.Wrap(err) - } - fmt.Fprint(os.Stdout, "Creating request...\n") // Always create access request against the root cluster. if err := tc.WithRootClusterClient(cf.Context, func(clt auth.ClientI) error { - return trace.Wrap(clt.CreateAccessRequest(cf.Context, req)) + req, err = clt.CreateAccessRequestV2(cf.Context, req) + return trace.Wrap(err) }); err != nil { return trace.Wrap(err) } @@ -3199,13 +3181,11 @@ func retryWithAccessRequest( // Wait for the request to be resolved. fmt.Fprintf(os.Stdout, "Waiting for request approval...\n") - resolvedReq, err := requestWatcher.awaitResolution() - if err != nil { + var resolvedReq types.AccessRequest + if err := tc.WithRootClusterClient(cf.Context, func(clt auth.ClientI) error { + resolvedReq, err = awaitRequestResolution(cf.Context, clt, req) return trace.Wrap(err) - } - if err := requestWatcher.Close(); err != nil { - // This was deferred above to catch all other error cases, here we - // actually handle any errors from requestWatcher.Close(). + }); err != nil { return trace.Wrap(err) } @@ -4347,50 +4327,12 @@ func host(in string) string { return out } -// accessRequestWatcher is a helper to wait for an access request to be resolved. -type accessRequestWatcher struct { - req types.AccessRequest - watcher types.Watcher - closers []io.Closer - sync.RWMutex -} - -// newAccessRequestWatcher returns a new accessRequestWatcher. Callers should -// always defer (*accessRequestWatcher).Close(). -func newAccessRequestWatcher(req types.AccessRequest) *accessRequestWatcher { - return &accessRequestWatcher{ - req: req, - } -} - -// initialize sets up the underlying event watcher, when this returns without -// error the watcher is guaranteed to be in a ready state. Call this before -// creating the request to prevent a race. -func (w *accessRequestWatcher) initialize(ctx context.Context, tc *client.TeleportClient) error { - w.Lock() - defer w.Unlock() - - if w.watcher != nil { - return trace.BadParameter("cannot re-initialize accessRequestWatcher") - } - - proxyClient, err := tc.ConnectToProxy(ctx) - if err != nil { - return trace.Wrap(err) - } - w.closers = append(w.closers, proxyClient) - - rootClient, err := proxyClient.ConnectToRootCluster(ctx) - if err != nil { - return trace.Wrap(err) - } - w.closers = append(w.closers, rootClient) - +func awaitRequestResolution(ctx context.Context, clt auth.ClientI, req types.AccessRequest) (types.AccessRequest, error) { filter := types.AccessRequestFilter{ - User: w.req.GetUser(), - ID: w.req.GetName(), + User: req.GetUser(), + ID: req.GetName(), } - w.watcher, err = rootClient.NewWatcher(ctx, types.Watch{ + watcher, err := clt.NewWatcher(ctx, types.Watch{ Name: "await-request-approval", Kinds: []types.WatchKind{{ Kind: types.KindAccessRequest, @@ -4398,77 +4340,51 @@ func (w *accessRequestWatcher) initialize(ctx context.Context, tc *client.Telepo }}, }) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } - w.closers = append(w.closers, w.watcher) + defer watcher.Close() // Wait for OpInit event so that returned watcher is ready. select { - case event := <-w.watcher.Events(): + case event := <-watcher.Events(): if event.Type != types.OpInit { - return trace.BadParameter("failed to watch for access requests: received an unexpected event while waiting for the initial OpInit") + return nil, trace.BadParameter("failed to watch for access requests: received an unexpected event while waiting for the initial OpInit") } - case <-w.watcher.Done(): - return trace.Wrap(w.watcher.Error()) - case <-ctx.Done(): - // This should be the same as w.watcher.Done(), including for completeness. - return trace.Wrap(ctx.Err()) + case <-watcher.Done(): + return nil, trace.Wrap(watcher.Error()) } - return nil -} - -// awaitResolution waits for the request to be resolved (state != PENDING). -func (w *accessRequestWatcher) awaitResolution() (types.AccessRequest, error) { - w.RLock() - defer w.RUnlock() - - if w.watcher == nil { - return nil, trace.BadParameter("must initialize accessRequestWatcher before calling awaitResolution()") + // get initial state of request + reqState, err := services.GetAccessRequest(ctx, clt, req.GetName()) + if err != nil { + return nil, trace.Wrap(err) } for { + if !reqState.GetState().IsPending() { + return reqState, nil + } + select { - case event := <-w.watcher.Events(): + case event := <-watcher.Events(): switch event.Type { case types.OpPut: - r, ok := event.Resource.(*types.AccessRequestV3) + var ok bool + reqState, ok = event.Resource.(*types.AccessRequestV3) if !ok { return nil, trace.BadParameter("unexpected resource type %T", event.Resource) } - if !r.GetState().IsPending() { - return r, nil - } case types.OpDelete: return nil, trace.Errorf("request %s has expired or been deleted...", event.Resource.GetName()) default: log.Warnf("Skipping unknown event type %s", event.Type) } - case <-w.watcher.Done(): - return nil, trace.Wrap(w.watcher.Error()) + case <-watcher.Done(): + return nil, trace.Wrap(watcher.Error()) } } } -// Close closes the clients held by the watcher. -func (w *accessRequestWatcher) Close() error { - var errs []error - // Close in reverse order, like defer. - w.RLock() - for i := len(w.closers) - 1; i >= 0; i-- { - errs = append(errs, w.closers[i].Close()) - } - w.RUnlock() - - // Closed the watcher above, awaitResolution should now terminate and we can - // grab the lock. - w.Lock() - w.closers = nil - w.Unlock() - - return trace.NewAggregate(errs...) -} - func onRequestResolution(cf *CLIConf, tc *client.TeleportClient, req types.AccessRequest) error { if !req.GetState().IsApproved() { msg := fmt.Sprintf("request %s has been set to %s", req.GetName(), req.GetState().String())