diff --git a/e2e/keycloak/setup-keycloak.sh b/e2e/keycloak/setup-keycloak.sh index 7fc46af0..ee3d3402 100755 --- a/e2e/keycloak/setup-keycloak.sh +++ b/e2e/keycloak/setup-keycloak.sh @@ -51,6 +51,7 @@ set -ex -s clientId="${CLIENT_ID}" \ -s secret="${CLIENT_SECRET}" \ -s "redirectUris=[\"${REDIRECT_URL}\"]" \ + -s "attributes={\"pkce.code.challenge.method\":\"S256\"}" \ -s consentRequired=false \ --server "${KEYCLOAK_SERVER}" \ --realm "${REALM}" \ diff --git a/e2e/redis/store_test.go b/e2e/redis/store_test.go index f77db33e..88917f70 100644 --- a/e2e/redis/store_test.go +++ b/e2e/redis/store_test.go @@ -85,6 +85,7 @@ func TestRedisAuthorizationState(t *testing.T) { State: "state", Nonce: "nonce", RequestedURL: "https://example.com", + CodeVerifier: "code_verifier", } require.NoError(t, store.SetAuthorizationState(ctx, "s1", as)) @@ -126,6 +127,7 @@ func TestSessionExpiration(t *testing.T) { State: "state", Nonce: "nonce", RequestedURL: "https://example.com", + CodeVerifier: "code_verifier", } require.NoError(t, store.SetAuthorizationState(ctx, "s1", as)) require.Eventually(t, func() bool { diff --git a/go.mod b/go.mod index 52839f67..ed5144e8 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/tetratelabs/run v0.3.0 github.com/tetratelabs/telemetry v0.8.2 golang.org/x/net v0.23.0 + golang.org/x/oauth2 v0.18.0 google.golang.org/genproto/googleapis/rpc v0.0.0-20240304212257-790db918fca8 google.golang.org/grpc v1.62.1 google.golang.org/protobuf v1.33.0 @@ -74,7 +75,6 @@ require ( github.com/yuin/gopher-lua v1.1.1 // indirect golang.org/x/crypto v0.21.0 // indirect golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 // indirect - golang.org/x/oauth2 v0.18.0 // indirect golang.org/x/sys v0.18.0 // indirect golang.org/x/term v0.18.0 // indirect golang.org/x/text v0.14.0 // indirect diff --git a/internal/authz/oidc.go b/internal/authz/oidc.go index d5a2d8f6..ce224f3d 100644 --- a/internal/authz/oidc.go +++ b/internal/authz/oidc.go @@ -30,6 +30,7 @@ import ( typev3 "github.com/envoyproxy/go-control-plane/envoy/type/v3" "github.com/lestrrat-go/jwx/v2/jws" "github.com/tetratelabs/telemetry" + "golang.org/x/oauth2" "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/grpc/codes" @@ -237,9 +238,10 @@ func (o *oidcHandler) redirectToIDP(ctx context.Context, log telemetry.Logger, } var ( - sessionID = o.sessionGen.GenerateSessionID() - nonce = o.sessionGen.GenerateNonce() - state = o.sessionGen.GenerateState() + sessionID = o.sessionGen.GenerateSessionID() + nonce = o.sessionGen.GenerateNonce() + state = o.sessionGen.GenerateState() + codeVerifier = o.sessionGen.GenerateCodeVerifier() ) // Store the authorization state @@ -251,6 +253,7 @@ func (o *oidcHandler) redirectToIDP(ctx context.Context, log telemetry.Logger, State: state, Nonce: nonce, RequestedURL: requestedURL, + CodeVerifier: codeVerifier, }); err != nil { log.Error("error storing the new authorization state", err) setDenyResponse(resp, newSessionErrorResponse(), codes.Unauthenticated) @@ -259,12 +262,14 @@ func (o *oidcHandler) redirectToIDP(ctx context.Context, log telemetry.Logger, // Generate the redirect URL query := url.Values{ - "response_type": []string{"code"}, - "client_id": []string{o.config.GetClientId()}, - "redirect_uri": []string{o.config.GetCallbackUri()}, - "scope": []string{strings.Join(o.config.GetScopes(), " ")}, - "state": []string{state}, - "nonce": []string{nonce}, + "response_type": []string{"code"}, + "client_id": []string{o.config.GetClientId()}, + "redirect_uri": []string{o.config.GetCallbackUri()}, + "scope": []string{strings.Join(o.config.GetScopes(), " ")}, + "state": []string{state}, + "nonce": []string{nonce}, + "code_challenge": []string{oauth2.S256ChallengeFromVerifier(codeVerifier)}, + "code_challenge_method": []string{"S256"}, } redirectURL := o.config.GetAuthorizationUri() + "?" + query.Encode() @@ -328,9 +333,10 @@ func (o *oidcHandler) retrieveTokens(ctx context.Context, log telemetry.Logger, // build body form := url.Values{ - "grant_type": []string{"authorization_code"}, - "code": []string{codeFromReq}, - "redirect_uri": []string{o.config.GetCallbackUri()}, + "grant_type": []string{"authorization_code"}, + "code": []string{codeFromReq}, + "redirect_uri": []string{o.config.GetCallbackUri()}, + "code_verifier": []string{stateFromStore.CodeVerifier}, } // build headers diff --git a/internal/authz/oidc_test.go b/internal/authz/oidc_test.go index c16e5d32..e3d21a7b 100644 --- a/internal/authz/oidc_test.go +++ b/internal/authz/oidc_test.go @@ -34,6 +34,7 @@ import ( "github.com/lestrrat-go/jwx/v2/jwt" "github.com/stretchr/testify/require" "github.com/tetratelabs/telemetry" + "golang.org/x/oauth2" "google.golang.org/grpc/codes" "google.golang.org/grpc/test/bufconn" "google.golang.org/protobuf/proto" @@ -125,10 +126,11 @@ var ( yesterday = time.Now().Add(-24 * time.Hour) tomorrow = time.Now().Add(24 * time.Hour) - sessionID = "test-session-id" - newSessionID = "new-session-id" - newNonce = "new-nonce" - newState = "new-state" + sessionID = "test-session-id" + newSessionID = "new-session-id" + newNonce = "new-nonce" + newState = "new-state" + newCodeVerifier = "new-code-verifier" basicOIDCConfig = &oidcv1.OIDCConfig{ IdToken: &oidcv1.TokenConfig{ @@ -190,12 +192,14 @@ var ( }` wantRedirectParams = url.Values{ - "response_type": {"code"}, - "client_id": {"test-client-id"}, - "redirect_uri": {"https://localhost:443/callback"}, - "scope": {"openid email"}, - "state": {newState}, - "nonce": {newNonce}, + "response_type": {"code"}, + "client_id": {"test-client-id"}, + "redirect_uri": {"https://localhost:443/callback"}, + "scope": {"openid email"}, + "state": {newState}, + "nonce": {newNonce}, + "code_challenge": {oauth2.S256ChallengeFromVerifier(newCodeVerifier)}, + "code_challenge_method": {"S256"}, } wantRedirectBaseURI = "http://idp-test-server/auth" @@ -228,7 +232,7 @@ func TestOIDCProcess(t *testing.T) { tlsPool := internal.NewTLSConfigPool(context.Background()) h, err := NewOIDCHandler(basicOIDCConfig, tlsPool, oidc.NewJWKSProvider(newConfigFor(basicOIDCConfig), tlsPool), sessions, clock, - oidc.NewStaticGenerator(newSessionID, newNonce, newState)) + oidc.NewStaticGenerator(newSessionID, newNonce, newState, newCodeVerifier)) require.NoError(t, err) ctx := context.Background() @@ -949,7 +953,7 @@ func TestOIDCProcessWithFailingSessionStore(t *testing.T) { } h, err := NewOIDCHandler(basicOIDCConfig, tlsPool, oidc.NewJWKSProvider(newConfigFor(basicOIDCConfig), tlsPool), - sessions, oidc.Clock{}, oidc.NewStaticGenerator(newSessionID, newNonce, newState)) + sessions, oidc.Clock{}, oidc.NewStaticGenerator(newSessionID, newNonce, newState, newCodeVerifier)) require.NoError(t, err) ctx := context.Background() @@ -1094,7 +1098,8 @@ func TestOIDCProcessWithFailingJWKSProvider(t *testing.T) { sessions := &mockSessionStoreFactory{store: oidc.NewMemoryStore(&clock, time.Hour, time.Hour)} store := sessions.Get(basicOIDCConfig) tlsPool := internal.NewTLSConfigPool(context.Background()) - h, err := NewOIDCHandler(basicOIDCConfig, tlsPool, funcJWKSProvider, sessions, clock, oidc.NewStaticGenerator(newSessionID, newNonce, newState)) + h, err := NewOIDCHandler(basicOIDCConfig, tlsPool, funcJWKSProvider, sessions, clock, + oidc.NewStaticGenerator(newSessionID, newNonce, newState, newCodeVerifier)) require.NoError(t, err) idpServer := newServer(wellKnownURIs) @@ -1425,7 +1430,7 @@ func TestLoadWellKnownConfigError(t *testing.T) { cfg.ConfigurationUri = "http://stopped-server/.well-known/openid-configuration" sessions := &mockSessionStoreFactory{store: oidc.NewMemoryStore(&clock, time.Hour, time.Hour)} _, err := NewOIDCHandler(cfg, tlsPool, oidc.NewJWKSProvider(newConfigFor(basicOIDCConfig), tlsPool), - sessions, clock, oidc.NewStaticGenerator(newSessionID, newNonce, newState)) + sessions, clock, oidc.NewStaticGenerator(newSessionID, newNonce, newState, newCodeVerifier)) require.Error(t, err) // Fail to retrieve the dynamic config since the test server is not running } @@ -1447,7 +1452,7 @@ func TestNewOIDCHandler(t *testing.T) { t.Run(tt.name, func(t *testing.T) { _, err := NewOIDCHandler(tt.config, tlsPool, oidc.NewJWKSProvider(newConfigFor(basicOIDCConfig), tlsPool), - sessions, clock, oidc.NewStaticGenerator(newSessionID, newNonce, newState)) + sessions, clock, oidc.NewStaticGenerator(newSessionID, newNonce, newState, newCodeVerifier)) if tt.wantErr { require.Error(t, err) } else { diff --git a/internal/oidc/redis.go b/internal/oidc/redis.go index 6033631f..1c3a98ca 100644 --- a/internal/oidc/redis.go +++ b/internal/oidc/redis.go @@ -40,12 +40,13 @@ const ( keyState = "state" keyNonce = "nonce" keyRequestedURL = "requested_url" + keyCodeVerifier = "code_verifier" keyTimeAdded = "time_added" ) var ( tokenResponseKeys = []string{keyIDToken, keyAccessToken, keyRefreshToken, keyAccessTokenExpiry, keyTimeAdded} - authorizationStateKeys = []string{keyState, keyNonce, keyRequestedURL, keyTimeAdded} + authorizationStateKeys = []string{keyState, keyNonce, keyRequestedURL, keyTimeAdded, keyCodeVerifier} ) // redisStore is an in-memory implementation of the SessionStore interface that stores @@ -165,6 +166,7 @@ func (r *redisStore) SetAuthorizationState(ctx context.Context, sessionID string keyState: authorizationState.State, keyNonce: authorizationState.Nonce, keyRequestedURL: authorizationState.RequestedURL, + keyCodeVerifier: authorizationState.CodeVerifier, } if err := r.client.HMSet(ctx, sessionID, state).Err(); err != nil { @@ -193,7 +195,7 @@ func (r *redisStore) GetAuthorizationState(ctx context.Context, sessionID string return nil, err } - if state.State == "" || state.Nonce == "" || state.RequestedURL == "" { + if state.State == "" || state.Nonce == "" || state.RequestedURL == "" || state.CodeVerifier == "" { return nil, nil } @@ -286,6 +288,7 @@ type ( State string `redis:"state"` Nonce string `redis:"nonce"` RequestedURL string `redis:"requested_url"` + CodeVerifier string `redis:"code_verifier"` TimeAdded time.Time `redis:"time_added"` } ) @@ -304,5 +307,6 @@ func (r redisAuthState) AuthorizationState() *AuthorizationState { State: r.State, Nonce: r.Nonce, RequestedURL: r.RequestedURL, + CodeVerifier: r.CodeVerifier, } } diff --git a/internal/oidc/redis_test.go b/internal/oidc/redis_test.go index 89230c82..30a5f4f4 100644 --- a/internal/oidc/redis_test.go +++ b/internal/oidc/redis_test.go @@ -91,6 +91,7 @@ func TestRedisAuthorizationState(t *testing.T) { State: "state", Nonce: "nonce", RequestedURL: "requested_url", + CodeVerifier: "code_verifier", } require.NoError(t, store.SetAuthorizationState(ctx, "s1", as)) diff --git a/internal/oidc/session.go b/internal/oidc/session.go index 0d702765..6c9be37f 100644 --- a/internal/oidc/session.go +++ b/internal/oidc/session.go @@ -22,6 +22,7 @@ import ( "github.com/redis/go-redis/v9" "github.com/tetratelabs/run" "github.com/tetratelabs/telemetry" + "golang.org/x/oauth2" configv1 "github.com/istio-ecosystem/authservice/config/gen/go/v1" oidcv1 "github.com/istio-ecosystem/authservice/config/gen/go/v1/oidc" @@ -136,6 +137,7 @@ type SessionGenerator interface { GenerateSessionID() string GenerateNonce() string GenerateState() string + GenerateCodeVerifier() string } var ( @@ -151,9 +153,10 @@ type ( // staticGenerator is a session generator that uses static strings. staticGenerator struct { - sessionID string - nonce string - state string + sessionID string + nonce string + state string + codeVerifier string } ) @@ -176,6 +179,10 @@ func (r randomGenerator) GenerateState() string { return r.generate(32) } +func (r randomGenerator) GenerateCodeVerifier() string { + return oauth2.GenerateVerifier() +} + func (r *randomGenerator) generate(n int) string { const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" b := make([]byte, n) @@ -186,11 +193,12 @@ func (r *randomGenerator) generate(n int) string { } // NewStaticGenerator creates a new static session generator. -func NewStaticGenerator(sessionID, nonce, state string) SessionGenerator { +func NewStaticGenerator(sessionID, nonce, state, codeVerifier string) SessionGenerator { return &staticGenerator{ - sessionID: sessionID, - nonce: nonce, - state: state, + sessionID: sessionID, + nonce: nonce, + state: state, + codeVerifier: codeVerifier, } } @@ -205,3 +213,7 @@ func (s staticGenerator) GenerateNonce() string { func (s staticGenerator) GenerateState() string { return s.state } + +func (s staticGenerator) GenerateCodeVerifier() string { + return s.codeVerifier +} diff --git a/internal/oidc/session_test.go b/internal/oidc/session_test.go index 35abc12f..f5a0addf 100644 --- a/internal/oidc/session_test.go +++ b/internal/oidc/session_test.go @@ -131,14 +131,17 @@ func TestSessionGenerator(t *testing.T) { require.NotEqual(t, sg.GenerateSessionID(), sg.GenerateSessionID()) require.NotEqual(t, sg.GenerateState(), sg.GenerateState()) require.NotEqual(t, sg.GenerateNonce(), sg.GenerateNonce()) + require.NotEqual(t, sg.GenerateCodeVerifier(), sg.GenerateCodeVerifier()) }) t.Run("static", func(t *testing.T) { - sg := NewStaticGenerator("sessionid", "nonce", "state") + sg := NewStaticGenerator("sessionid", "nonce", "state", "codeverifier") require.Equal(t, sg.GenerateSessionID(), sg.GenerateSessionID()) require.Equal(t, sg.GenerateState(), sg.GenerateState()) require.Equal(t, sg.GenerateNonce(), sg.GenerateNonce()) + require.Equal(t, sg.GenerateCodeVerifier(), sg.GenerateCodeVerifier()) require.Equal(t, "sessionid", sg.GenerateSessionID()) require.Equal(t, "state", sg.GenerateState()) require.Equal(t, "nonce", sg.GenerateNonce()) + require.Equal(t, "codeverifier", sg.GenerateCodeVerifier()) }) } diff --git a/internal/oidc/state.go b/internal/oidc/state.go index 96236e71..02b2048d 100644 --- a/internal/oidc/state.go +++ b/internal/oidc/state.go @@ -19,4 +19,5 @@ type AuthorizationState struct { State string Nonce string RequestedURL string + CodeVerifier string }