Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 43 additions & 10 deletions integrations/access/common/auth/token_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"

"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/integrations/access/common/auth/oauth"
"github.com/gravitational/teleport/integrations/access/common/auth/storage"
)
Expand Down Expand Up @@ -146,11 +147,9 @@ func (r *RotatedAccessTokenProvider) GetAccessToken() (string, error) {
// RefreshLoop runs the credential refresh process.
func (r *RotatedAccessTokenProvider) RefreshLoop(ctx context.Context) {
r.lock.RLock()
creds := r.creds
interval := r.getRefreshInterval(r.creds)
r.lock.RUnlock()

interval := r.getRefreshInterval(creds)

timer := r.clock.NewTimer(interval)
defer timer.Stop()
r.log.InfoContext(ctx, "Starting token refresh loop", "next_refresh", interval)
Expand All @@ -161,7 +160,19 @@ func (r *RotatedAccessTokenProvider) RefreshLoop(ctx context.Context) {
r.log.InfoContext(ctx, "Shutting down")
return
case <-timer.Chan():
creds, _ := r.store.GetCredentials(ctx)
r.log.DebugContext(ctx, "Entering token refresh loop")
creds, err := r.store.GetCredentials(ctx)
if err != nil {
r.lock.RLock()
credsExpiry := r.creds.ExpiresAt
r.lock.RUnlock()
r.log.WarnContext(ctx, "Error getting credentials, not attempting to refresh credentials", "error", err, "creds_expiry", credsExpiry)
// We cannot get the credentials from the backend, something is going on.
// If we don't have backend access, or we are in an unknown state, we should not attempt to refresh
// credentials. This will lower the probability of ending up in an awkward state where we refreshed the
// token but cannot store it.
timer.Reset(r.retryInterval)
}

// Skip if the credentials are sufficiently fresh
// (in an HA setup another instance might have refreshed the credentials).
Expand All @@ -174,25 +185,47 @@ func (r *RotatedAccessTokenProvider) RefreshLoop(ctx context.Context) {

interval := r.getRefreshInterval(creds)
timer.Reset(interval)
r.log.InfoContext(ctx, "Refreshed token", "next_refresh", interval)
r.log.InfoContext(ctx, "Reloaded token", "next_refresh", interval)
continue
}

creds, err := r.refresh(ctx)
// Important: we are entering the critical section here.
// Once we start refreshing the token, we must not stop until we are done writing it to the backend.
// Failure to do so results in a lost token and broken Slack integration until the user re-registers it.
// We ignore cancellation here to make sure the refresh process finishes even during a shutdown.
criticalCtx := context.WithoutCancel(ctx)

creds, err = r.refresh(criticalCtx)
if err != nil {
r.log.ErrorContext(ctx, "Error while refreshing token",
"error", err,
"retry_interval", r.retryInterval,
)
timer.Reset(r.retryInterval)
} else {
err := r.store.PutCredentials(ctx, creds)
retry, err := retryutils.NewLinear(retryutils.LinearConfig{
Step: time.Second,
Max: time.Minute,
Jitter: retryutils.DefaultJitter,
})
if err != nil {
r.log.ErrorContext(ctx, "Error while storing the refreshed credentials", "error", err)
timer.Reset(r.retryInterval)
continue
r.log.ErrorContext(ctx, "Error while creating the token retry configuration, this is a bug", "error", err)
}

// We don't need to check the error because we keep retrying the context cannot be canceled.
_ = retry.For(criticalCtx, func() error {
err := r.store.PutCredentials(criticalCtx, creds)
if err != nil {
// If we land here, we refreshed the Slack token but failed to store it back.
// This is the worst case scenario: the refresh token is single-use, and we burnt it.
// This Slack integration will very likely get locked out.
// The only thing we can do is try again.
r.log.WarnContext(ctx, "Error while saving credentials to storage", "error", err)
return err
}
return nil
})

r.lock.Lock()
r.creds = creds
r.lock.Unlock()
Expand Down
24 changes: 21 additions & 3 deletions integrations/access/slack/oauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package slack

import (
"context"
"log/slog"
"time"

"github.com/go-resty/resty/v2"
Expand All @@ -29,36 +30,46 @@ import (
"github.com/gravitational/teleport/integrations/access/common/auth/storage"
)

const (
requestTimeout = 30 * time.Second
)
Comment on lines +33 to +35
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to add a timeout here? It's always better to let the caller do this IMO.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't agree, I don't trust the caller to set a deadline and I don't want this to hang. I'm still honouring the caller's context by cancelling when they want, but I don't see why this should be the caller's responsibility to be sure my HTTP client will not block infinitely.


// Authorizer implements oauth2.Authorizer for Slack API.
type Authorizer struct {
client *resty.Client

clientID string
clientSecret string
log *slog.Logger
}

func newAuthorizer(client *resty.Client, clientID string, clientSecret string) *Authorizer {
func newAuthorizer(client *resty.Client, clientID string, clientSecret string, log *slog.Logger) *Authorizer {
return &Authorizer{
client: client,
clientID: clientID,
clientSecret: clientSecret,
log: log,
}
}

// NewAuthorizer returns a new Authorizer.
//
// clientID is the Client ID for this Slack app as specified by OAuth2.
// clientSecret is the Client Secret for this Slack app as specified by OAuth2.
func NewAuthorizer(clientID string, clientSecret string) *Authorizer {
func NewAuthorizer(clientID string, clientSecret string, log *slog.Logger) *Authorizer {
client := makeSlackClient(slackAPIURL)
return newAuthorizer(client, clientID, clientSecret)
return newAuthorizer(client, clientID, clientSecret, log.With("authorizer", "slack"))
}

// Exchange implements oauth.Exchanger
func (a *Authorizer) Exchange(ctx context.Context, authorizationCode string, redirectURI string) (*storage.Credentials, error) {
var result AccessResponse

ctx, cancel := context.WithTimeout(ctx, requestTimeout)
defer cancel()

_, err := a.client.R().
SetContext(ctx).
SetQueryParam("client_id", a.clientID).
SetQueryParam("client_secret", a.clientSecret).
SetQueryParam("code", authorizationCode).
Expand All @@ -67,6 +78,7 @@ func (a *Authorizer) Exchange(ctx context.Context, authorizationCode string, red
Post("oauth.v2.access")

if err != nil {
a.log.WarnContext(ctx, "Failed to exchange access token.", "error", err)
return nil, trace.Wrap(err)
}

Expand All @@ -84,7 +96,12 @@ func (a *Authorizer) Exchange(ctx context.Context, authorizationCode string, red
// Refresh implements oauth.Refresher
func (a *Authorizer) Refresh(ctx context.Context, refreshToken string) (*storage.Credentials, error) {
var result AccessResponse

ctx, cancel := context.WithTimeout(ctx, requestTimeout)
defer cancel()

_, err := a.client.R().
SetContext(ctx).
SetQueryParam("client_id", a.clientID).
SetQueryParam("client_secret", a.clientSecret).
SetQueryParam("grant_type", "refresh_token").
Expand All @@ -93,6 +110,7 @@ func (a *Authorizer) Refresh(ctx context.Context, refreshToken string) (*storage
Post("oauth.v2.access")

if err != nil {
a.log.WarnContext(ctx, "Failed to refresh access token.", "error", err)
return nil, trace.Wrap(err)
}

Expand Down
12 changes: 8 additions & 4 deletions integrations/access/slack/oauth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import (

"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/lib/utils/log/logtest"
)

type testOAuthServer struct {
Expand Down Expand Up @@ -100,6 +102,8 @@ func TestOAuth(t *testing.T) {
expiresInSeconds = 43200
)

log := logtest.NewLogger()

newServer := func(t *testing.T) *testOAuthServer {
s := &testOAuthServer{
clientID: clientID,
Expand Down Expand Up @@ -137,7 +141,7 @@ func TestOAuth(t *testing.T) {
defer s.close()
s.exchangeResponse = ok("my-access-token1", "my-refresh-token2", expiresInSeconds)

authorizer := newAuthorizer(makeSlackClient(s.url()), clientID, clientSecret)
authorizer := newAuthorizer(makeSlackClient(s.url()), clientID, clientSecret, log)

creds, err := authorizer.Exchange(context.Background(), s.authorizationCode, s.redirectURI)
require.NoError(t, err)
Expand All @@ -151,7 +155,7 @@ func TestOAuth(t *testing.T) {
defer s.close()
s.exchangeResponse = fail("invalid_code")

authorizer := newAuthorizer(makeSlackClient(s.url()), clientID, clientSecret)
authorizer := newAuthorizer(makeSlackClient(s.url()), clientID, clientSecret, log)

_, err := authorizer.Exchange(context.Background(), s.authorizationCode, s.redirectURI)
require.Error(t, err)
Expand All @@ -164,7 +168,7 @@ func TestOAuth(t *testing.T) {
defer s.close()
s.refreshResponse = ok("my-access-token2", "my-refresh-token3", expiresInSeconds)

authorizer := newAuthorizer(makeSlackClient(s.url()), clientID, clientSecret)
authorizer := newAuthorizer(makeSlackClient(s.url()), clientID, clientSecret, log)

creds, err := authorizer.Refresh(context.Background(), refreshToken)
require.NoError(t, err)
Expand All @@ -179,7 +183,7 @@ func TestOAuth(t *testing.T) {
defer s.close()
s.refreshResponse = fail("expired_token")

authorizer := newAuthorizer(makeSlackClient(s.url()), clientID, clientSecret)
authorizer := newAuthorizer(makeSlackClient(s.url()), clientID, clientSecret, log)

_, err := authorizer.Refresh(context.Background(), refreshToken)
require.Error(t, err)
Expand Down
Loading