diff --git a/integrations/access/common/auth/token_provider.go b/integrations/access/common/auth/token_provider.go
index e0c23b0b36427..d67f96911c91a 100644
--- a/integrations/access/common/auth/token_provider.go
+++ b/integrations/access/common/auth/token_provider.go
@@ -27,6 +27,7 @@ import (
 	"github.com/gravitational/trace"
 	"github.com/jonboulle/clockwork"
 
+	"github.com/gravitational/teleport/api/utils/retryutils"
 	"github.com/gravitational/teleport/integrations/access/common/auth/oauth"
 	"github.com/gravitational/teleport/integrations/access/common/auth/storage"
 )
@@ -146,11 +147,9 @@ func (r *RotatedAccessTokenProvider) GetAccessToken() (string, error) {
 // RefreshLoop runs the credential refresh process.
 func (r *RotatedAccessTokenProvider) RefreshLoop(ctx context.Context) {
 	r.lock.RLock()
-	creds := r.creds
+	interval := r.getRefreshInterval(r.creds)
 	r.lock.RUnlock()
 
-	interval := r.getRefreshInterval(creds)
-
 	timer := r.clock.NewTimer(interval)
 	defer timer.Stop()
 	r.log.InfoContext(ctx, "Starting token refresh loop", "next_refresh", interval)
@@ -161,7 +160,20 @@ func (r *RotatedAccessTokenProvider) RefreshLoop(ctx context.Context) {
 			r.log.InfoContext(ctx, "Shutting down")
 			return
 		case <-timer.Chan():
-			creds, _ := r.store.GetCredentials(ctx)
+			r.log.DebugContext(ctx, "Entering token refresh loop")
+			creds, err := r.store.GetCredentials(ctx)
+			if err != nil {
+				r.lock.RLock()
+				credsExpiry := r.creds.ExpiresAt
+				r.lock.RUnlock()
+				r.log.WarnContext(ctx, "Error getting credentials, not attempting to refresh credentials", "error", err, "creds_expiry", credsExpiry)
+				// We cannot get the credentials from the backend, something is going on.
+				// If we don't have backend access, or we are in an unknown state, we should not attempt to refresh
+				// credentials. This will lower the probability of ending up in an awkward state where we refreshed the
+				// token but cannot store it.
+				timer.Reset(r.retryInterval)
+				continue
+			}
 
 			// Skip if the credentials are sufficiently fresh
 			// (in an HA setup another instance might have refreshed the credentials).
@@ -174,11 +186,17 @@ func (r *RotatedAccessTokenProvider) RefreshLoop(ctx context.Context) {
 
 				interval := r.getRefreshInterval(creds)
 				timer.Reset(interval)
-				r.log.InfoContext(ctx, "Refreshed token", "next_refresh", interval)
+				r.log.InfoContext(ctx, "Reloaded token", "next_refresh", interval)
 				continue
 			}
 
-			creds, err := r.refresh(ctx)
+			// Important: we are entering the critical section here.
+			// Once we start refreshing the token, we must not stop until we are done writing it to the backend.
+			// Failure to do so results in a lost token and broken Slack integration until the user re-registers it.
+			// We ignore cancellation here to make sure the refresh process finishes even during a shutdown.
+			criticalCtx := context.WithoutCancel(ctx)
+
+			creds, err = r.refresh(criticalCtx)
 			if err != nil {
 				r.log.ErrorContext(ctx, "Error while refreshing token",
 					"error", err,
@@ -186,13 +204,29 @@ func (r *RotatedAccessTokenProvider) RefreshLoop(ctx context.Context) {
 				)
 				timer.Reset(r.retryInterval)
 			} else {
-				err := r.store.PutCredentials(ctx, creds)
+				retry, err := retryutils.NewLinear(retryutils.LinearConfig{
+					Step:   time.Second,
+					Max:    time.Minute,
+					Jitter: retryutils.DefaultJitter,
+				})
 				if err != nil {
-					r.log.ErrorContext(ctx, "Error while storing the refreshed credentials", "error", err)
-					timer.Reset(r.retryInterval)
-					continue
+					r.log.ErrorContext(ctx, "Error while creating the token retry configuration, this is a bug", "error", err)
 				}
 
+				// We don't need to check the error: we keep retrying and the context cannot be canceled.
+				_ = retry.For(criticalCtx, func() error {
+					err := r.store.PutCredentials(criticalCtx, creds)
+					if err != nil {
+						// If we land here, we refreshed the Slack token but failed to store it back.
+						// This is the worst case scenario: the refresh token is single-use, and we burnt it.
+						// This Slack integration will very likely get locked out.
+						// The only thing we can do is try again.
+						r.log.WarnContext(ctx, "Error while saving credentials to storage", "error", err)
+						return err
+					}
+					return nil
+				})
+
 				r.lock.Lock()
 				r.creds = creds
 				r.lock.Unlock()
diff --git a/integrations/access/slack/oauth.go b/integrations/access/slack/oauth.go
index 52f4c6931d781..57f2f44187bf6 100644
--- a/integrations/access/slack/oauth.go
+++ b/integrations/access/slack/oauth.go
@@ -20,6 +20,7 @@ package slack
 
 import (
 	"context"
+	"log/slog"
 	"time"
 
 	"github.com/go-resty/resty/v2"
@@ -29,19 +30,25 @@ import (
 	"github.com/gravitational/teleport/integrations/access/common/auth/storage"
 )
 
+const (
+	requestTimeout = 30 * time.Second
+)
+
 // Authorizer implements oauth2.Authorizer for Slack API.
 type Authorizer struct {
 	client *resty.Client
 
 	clientID     string
 	clientSecret string
+	log          *slog.Logger
 }
 
-func newAuthorizer(client *resty.Client, clientID string, clientSecret string) *Authorizer {
+func newAuthorizer(client *resty.Client, clientID string, clientSecret string, log *slog.Logger) *Authorizer {
 	return &Authorizer{
 		client:       client,
 		clientID:     clientID,
 		clientSecret: clientSecret,
+		log:          log,
 	}
 }
 
@@ -49,16 +56,20 @@ func newAuthorizer(client *resty.Client, clientID string, clientSecret string) *
 //
 // clientID is the Client ID for this Slack app as specified by OAuth2.
 // clientSecret is the Client Secret for this Slack app as specified by OAuth2.
-func NewAuthorizer(clientID string, clientSecret string) *Authorizer {
+func NewAuthorizer(clientID string, clientSecret string, log *slog.Logger) *Authorizer {
 	client := makeSlackClient(slackAPIURL)
-	return newAuthorizer(client, clientID, clientSecret)
+	return newAuthorizer(client, clientID, clientSecret, log.With("authorizer", "slack"))
 }
 
 // Exchange implements oauth.Exchanger
 func (a *Authorizer) Exchange(ctx context.Context, authorizationCode string, redirectURI string) (*storage.Credentials, error) {
 	var result AccessResponse
 
+	ctx, cancel := context.WithTimeout(ctx, requestTimeout)
+	defer cancel()
+
 	_, err := a.client.R().
+		SetContext(ctx).
 		SetQueryParam("client_id", a.clientID).
 		SetQueryParam("client_secret", a.clientSecret).
 		SetQueryParam("code", authorizationCode).
@@ -67,6 +78,7 @@ func (a *Authorizer) Exchange(ctx context.Context, authorizationCode string, red
 		Post("oauth.v2.access")
 
 	if err != nil {
+		a.log.WarnContext(ctx, "Failed to exchange access token.", "error", err)
 		return nil, trace.Wrap(err)
 	}
 
@@ -84,7 +96,12 @@ func (a *Authorizer) Exchange(ctx context.Context, authorizationCode string, red
 // Refresh implements oauth.Refresher
 func (a *Authorizer) Refresh(ctx context.Context, refreshToken string) (*storage.Credentials, error) {
 	var result AccessResponse
+
+	ctx, cancel := context.WithTimeout(ctx, requestTimeout)
+	defer cancel()
+
 	_, err := a.client.R().
+		SetContext(ctx).
 		SetQueryParam("client_id", a.clientID).
 		SetQueryParam("client_secret", a.clientSecret).
 		SetQueryParam("grant_type", "refresh_token").
@@ -93,6 +110,7 @@ func (a *Authorizer) Refresh(ctx context.Context, refreshToken string) (*storage
 		Post("oauth.v2.access")
 
 	if err != nil {
+		a.log.WarnContext(ctx, "Failed to refresh access token.", "error", err)
 		return nil, trace.Wrap(err)
 	}
 
diff --git a/integrations/access/slack/oauth_test.go b/integrations/access/slack/oauth_test.go
index d9b47588476c6..b838e044ec0ad 100644
--- a/integrations/access/slack/oauth_test.go
+++ b/integrations/access/slack/oauth_test.go
@@ -28,6 +28,8 @@ import (
 
 	"github.com/julienschmidt/httprouter"
 	"github.com/stretchr/testify/require"
+
+	"github.com/gravitational/teleport/lib/utils/log/logtest"
 )
 
 type testOAuthServer struct {
@@ -100,6 +102,8 @@ func TestOAuth(t *testing.T) {
 		expiresInSeconds = 43200
 	)
 
+	log := logtest.NewLogger()
+
 	newServer := func(t *testing.T) *testOAuthServer {
 		s := &testOAuthServer{
 			clientID: clientID,
@@ -137,7 +141,7 @@ func TestOAuth(t *testing.T) {
 		defer s.close()
 		s.exchangeResponse = ok("my-access-token1", "my-refresh-token2", expiresInSeconds)
 
-		authorizer := newAuthorizer(makeSlackClient(s.url()), clientID, clientSecret)
+		authorizer := newAuthorizer(makeSlackClient(s.url()), clientID, clientSecret, log)
 
 		creds, err := authorizer.Exchange(context.Background(), s.authorizationCode, s.redirectURI)
 		require.NoError(t, err)
@@ -151,7 +155,7 @@ func TestOAuth(t *testing.T) {
 		defer s.close()
 		s.exchangeResponse = fail("invalid_code")
 
-		authorizer := newAuthorizer(makeSlackClient(s.url()), clientID, clientSecret)
+		authorizer := newAuthorizer(makeSlackClient(s.url()), clientID, clientSecret, log)
 
 		_, err := authorizer.Exchange(context.Background(), s.authorizationCode, s.redirectURI)
 		require.Error(t, err)
@@ -164,7 +168,7 @@ func TestOAuth(t *testing.T) {
 		defer s.close()
 		s.refreshResponse = ok("my-access-token2", "my-refresh-token3", expiresInSeconds)
 
-		authorizer := newAuthorizer(makeSlackClient(s.url()), clientID, clientSecret)
+		authorizer := newAuthorizer(makeSlackClient(s.url()), clientID, clientSecret, log)
 
 		creds, err := authorizer.Refresh(context.Background(), refreshToken)
 		require.NoError(t, err)
@@ -179,7 +183,7 @@ func TestOAuth(t *testing.T) {
 		defer s.close()
 		s.refreshResponse = fail("expired_token")
 
-		authorizer := newAuthorizer(makeSlackClient(s.url()), clientID, clientSecret)
+		authorizer := newAuthorizer(makeSlackClient(s.url()), clientID, clientSecret, log)
 
 		_, err := authorizer.Refresh(context.Background(), refreshToken)
 		require.Error(t, err)