Skip to content

Commit

Permalink
feat: propagate logout to identity provider
Browse files Browse the repository at this point in the history
This commit improves the integration between Hydra and Kratos when logging
out the user.

This adds a new configuration key for configuring a Kratos admin URL.
Additionally, Kratos can send a session ID when accepting a login request.
If a session ID was specified and a Kratos admin URL was configured,
Hydra will disable the corresponding Kratos session through the admin API
if a frontchannel or backchannel logout was triggered.
  • Loading branch information
hperl committed Aug 10, 2023
1 parent 219a7c0 commit 017859b
Show file tree
Hide file tree
Showing 58 changed files with 363 additions and 32 deletions.
2 changes: 1 addition & 1 deletion consent/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type (
// Cookie management
GetRememberedLoginSession(ctx context.Context, loginSessionFromCookie *flow.LoginSession, id string) (*flow.LoginSession, error)
CreateLoginSession(ctx context.Context, session *flow.LoginSession) error
DeleteLoginSession(ctx context.Context, id string) error
DeleteLoginSession(ctx context.Context, id string) (deletedSession *flow.LoginSession, err error)
RevokeSubjectLoginSession(ctx context.Context, user string) error
ConfirmLoginSession(ctx context.Context, session *flow.LoginSession, id string, authTime time.Time, subject string, remember bool) error

Expand Down
14 changes: 10 additions & 4 deletions consent/manager_test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,12 @@ func TestHelperNID(r interface {
require.NoError(t, err)
require.Error(t, t2InvalidNID.ConfirmLoginSession(ctx, &testLS, testLS.ID, time.Now(), testLS.Subject, true))
require.NoError(t, t1ValidNID.ConfirmLoginSession(ctx, &testLS, testLS.ID, time.Now(), testLS.Subject, true))
require.Error(t, t2InvalidNID.DeleteLoginSession(ctx, testLS.ID))
require.NoError(t, t1ValidNID.DeleteLoginSession(ctx, testLS.ID))
ls, err := t2InvalidNID.DeleteLoginSession(ctx, testLS.ID)
require.Error(t, err)
assert.Nil(t, ls)
ls, err = t1ValidNID.DeleteLoginSession(ctx, testLS.ID)
require.NoError(t, err)
assert.Equal(t, testLS.ID, ls.ID)
}
}

Expand Down Expand Up @@ -429,8 +433,9 @@ func ManagerTests(deps Deps, m Manager, clientManager client.Manager, fositeMana
},
} {
t.Run("case=delete-get-"+tc.id, func(t *testing.T) {
err := m.DeleteLoginSession(ctx, tc.id)
ls, err := m.DeleteLoginSession(ctx, tc.id)
require.NoError(t, err)
assert.EqualValues(t, tc.id, ls.ID)

_, err = m.GetRememberedLoginSession(ctx, nil, tc.id)
require.Error(t, err)
Expand Down Expand Up @@ -1083,7 +1088,8 @@ func ManagerTests(deps Deps, m Manager, clientManager client.Manager, fositeMana
require.NoError(t, err)
assert.EqualValues(t, expected.ID, result.ID)

require.NoError(t, m.DeleteLoginSession(ctx, s.ID))
_, err = m.DeleteLoginSession(ctx, s.ID)
require.NoError(t, err)

result, err = m.GetConsentRequest(ctx, expected.ID)
require.NoError(t, err)
Expand Down
2 changes: 2 additions & 0 deletions consent/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/ory/fosite/handler/openid"
"github.com/ory/hydra/v2/aead"
"github.com/ory/hydra/v2/client"
"github.com/ory/hydra/v2/internal/kratos"
"github.com/ory/hydra/v2/x"
)

Expand All @@ -17,6 +18,7 @@ type InternalRegistry interface {
x.RegistryCookieStore
x.RegistryLogger
x.HTTPClientProvider
kratos.Provider
Registry
client.Registry

Expand Down
21 changes: 14 additions & 7 deletions consent/strategy_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,9 @@ func (s *DefaultStrategy) revokeAuthenticationSession(ctx context.Context, w htt
return nil
}

return s.r.ConsentManager().DeleteLoginSession(r.Context(), sid)
_, err = s.r.ConsentManager().DeleteLoginSession(r.Context(), sid)

return err
}

func (s *DefaultStrategy) revokeAuthenticationCookie(w http.ResponseWriter, r *http.Request, ss sessions.Store) (string, error) {
Expand Down Expand Up @@ -449,6 +451,7 @@ func (s *DefaultStrategy) verifyAuthentication(
return nil, fosite.ErrAccessDenied.WithHint("The login session cookie was not found or malformed.")
}

loginSession.KratosSessionID = f.KratosSessionID
if err := s.r.ConsentManager().ConfirmLoginSession(ctx, loginSession, sessionID, time.Time(session.AuthenticatedAt), session.Subject, session.Remember); err != nil {
return nil, err
}
Expand Down Expand Up @@ -716,7 +719,8 @@ func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, su
return urls, nil
}

func (s *DefaultStrategy) executeBackChannelLogout(ctx context.Context, r *http.Request, subject, sid string) error {
func (s *DefaultStrategy) executeBackChannelLogout(r *http.Request, subject, sid string) error {
ctx := r.Context()
clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithBackChannelLogout(ctx, subject, sid)
if err != nil {
return err
Expand Down Expand Up @@ -976,8 +980,9 @@ func (s *DefaultStrategy) issueLogoutVerifier(ctx context.Context, w http.Respon
return nil, errorsx.WithStack(ErrAbortOAuth2Request)
}

func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(_ context.Context, r *http.Request, subject string, sid string) error {
if err := s.executeBackChannelLogout(r.Context(), r, subject, sid); err != nil {
func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(r *http.Request, subject string, sid string) error {
ctx := r.Context()
if err := s.executeBackChannelLogout(r, subject, sid); err != nil {
return err
}

Expand All @@ -986,10 +991,12 @@ func (s *DefaultStrategy) performBackChannelLogoutAndDeleteSession(_ context.Con
//
// executeBackChannelLogout only fails on system errors so not on URL errors, so this should be fine
// even if an upstream URL fails!
if err := s.r.ConsentManager().DeleteLoginSession(r.Context(), sid); errors.Is(err, sqlcon.ErrNoRows) {
if session, err := s.r.ConsentManager().DeleteLoginSession(ctx, sid); errors.Is(err, sqlcon.ErrNoRows) {
// This is ok (session probably already revoked), do nothing!
} else if err != nil {
return err
} else {
_ = s.r.Kratos().DisableSession(ctx, session.KratosSessionID.String())
}

return nil
Expand Down Expand Up @@ -1044,7 +1051,7 @@ func (s *DefaultStrategy) completeLogout(ctx context.Context, w http.ResponseWri
return nil, err
}

if err := s.performBackChannelLogoutAndDeleteSession(r.Context(), r, lr.Subject, lr.SessionID); err != nil {
if err := s.performBackChannelLogoutAndDeleteSession(r, lr.Subject, lr.SessionID); err != nil {
return nil, err
}

Expand Down Expand Up @@ -1081,7 +1088,7 @@ func (s *DefaultStrategy) HandleHeadlessLogout(ctx context.Context, _ http.Respo
return lsErr
}

if err := s.performBackChannelLogoutAndDeleteSession(r.Context(), r, loginSession.Subject, sid); err != nil {
if err := s.performBackChannelLogoutAndDeleteSession(r, loginSession.Subject, sid); err != nil {
return err
}

Expand Down
3 changes: 1 addition & 2 deletions consent/strategy_default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/http"
"net/http/cookiejar"
"net/http/httptest"
"net/url"
"testing"

hydra "github.com/ory/hydra-client-go/v2"
Expand All @@ -17,8 +18,6 @@ import (
"github.com/ory/fosite/token/jwt"
"github.com/ory/x/urlx"

"net/url"

"github.com/google/uuid"
"github.com/tidwall/gjson"

Expand Down
19 changes: 18 additions & 1 deletion consent/strategy_logout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"testing"
"time"

"github.com/ory/hydra/v2/internal/kratos"
"github.com/ory/x/pointerx"

"github.com/stretchr/testify/assert"
Expand All @@ -35,9 +36,11 @@ import (

func TestLogoutFlows(t *testing.T) {
ctx := context.Background()
fakeKratos := kratos.NewFake()
reg := internal.NewMockedRegistry(t, &contextx.Default{})
reg.Config().MustSet(ctx, config.KeyAccessTokenStrategy, "opaque")
reg.Config().MustSet(ctx, config.KeyConsentRequestMaxAge, time.Hour)
reg.WithKratos(fakeKratos)

defaultRedirectedMessage := "redirected to default server"
postLogoutCallback := func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -181,7 +184,7 @@ func TestLogoutFlows(t *testing.T) {
checkAndAcceptLoginHandler(t, adminApi, subject, func(t *testing.T, res *hydra.OAuth2LoginRequest, err error) hydra.AcceptOAuth2LoginRequest {
require.NoError(t, err)
//res.Payload.SessionID
return hydra.AcceptOAuth2LoginRequest{Remember: pointerx.Bool(true)}
return hydra.AcceptOAuth2LoginRequest{Remember: pointerx.Bool(true), SessionId: pointerx.Ptr(kratos.FakeSessionID)}
}),
checkAndAcceptConsentHandler(t, adminApi, func(t *testing.T, res *hydra.OAuth2ConsentRequest, err error) hydra.AcceptOAuth2ConsentRequest {
require.NoError(t, err)
Expand Down Expand Up @@ -476,6 +479,7 @@ func TestLogoutFlows(t *testing.T) {
})

t.Run("case=should return to default post logout because session was revoked in browser context", func(t *testing.T) {
fakeKratos.Reset()
c := createSampleClient(t)
sid := make(chan string)
acceptLoginAsAndWatchSid(t, subject, sid)
Expand Down Expand Up @@ -518,9 +522,13 @@ func TestLogoutFlows(t *testing.T) {
assert.NotEmpty(t, res.Request.URL.Query().Get("code"))

wg.Wait()

assert.True(t, fakeKratos.DisableSessionWasCalled)
assert.Equal(t, fakeKratos.LastDisabledSession, kratos.FakeSessionID)
})

t.Run("case=should execute backchannel logout in headless flow with sid", func(t *testing.T) {
fakeKratos.Reset()
numSidConsumers := 2
sid := make(chan string, numSidConsumers)
acceptLoginAsAndWatchSidForConsumers(t, subject, sid, true, numSidConsumers)
Expand All @@ -535,22 +543,31 @@ func TestLogoutFlows(t *testing.T) {
logoutViaHeadlessAndExpectNoContent(t, createBrowserWithSession(t, c), url.Values{"sid": {<-sid}})

backChannelWG.Wait() // we want to ensure that all back channels have been called!
assert.True(t, fakeKratos.DisableSessionWasCalled)
assert.Equal(t, fakeKratos.LastDisabledSession, kratos.FakeSessionID)
})

t.Run("case=should logout in headless flow with non-existing sid", func(t *testing.T) {
fakeKratos.Reset()
logoutViaHeadlessAndExpectNoContent(t, browserWithoutSession, url.Values{"sid": {"non-existing-sid"}})
assert.False(t, fakeKratos.DisableSessionWasCalled)
})

t.Run("case=should logout in headless flow with session that has remember=false", func(t *testing.T) {
fakeKratos.Reset()
sid := make(chan string)
acceptLoginAsAndWatchSidForConsumers(t, subject, sid, false, 1)

c := createSampleClient(t)

logoutViaHeadlessAndExpectNoContent(t, createBrowserWithSession(t, c), url.Values{"sid": {<-sid}})
assert.True(t, fakeKratos.DisableSessionWasCalled)
assert.Equal(t, fakeKratos.LastDisabledSession, kratos.FakeSessionID)
})

t.Run("case=should fail headless logout because neither sid nor subject were provided", func(t *testing.T) {
fakeKratos.Reset()
logoutViaHeadlessAndExpectError(t, browserWithoutSession, url.Values{}, `Either 'subject' or 'sid' query parameters need to be defined.`)
assert.False(t, fakeKratos.DisableSessionWasCalled)
})
}
7 changes: 7 additions & 0 deletions driver/config/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ const (
KeyPublicURL = "urls.self.public"
KeyAdminURL = "urls.self.admin"
KeyIssuerURL = "urls.self.issuer"
KeyKratosAdminURL = "urls.kratos.admin"
KeyAccessTokenStrategy = "strategies.access_token"
KeyJWTScopeClaimStrategy = "strategies.jwt.scope_claim"
KeyDBIgnoreUnknownTableColumns = "db.ignore_unknown_table_columns"
Expand Down Expand Up @@ -388,6 +389,12 @@ func (p *DefaultProvider) IssuerURL(ctx context.Context) *url.URL {
)
}

func (p *DefaultProvider) KratosAdminURL(ctx context.Context) (*url.URL, bool) {
u := p.getProvider(ctx).RequestURIF(KeyKratosAdminURL, nil)

return u, u != nil
}

func (p *DefaultProvider) OAuth2ClientRegistrationURL(ctx context.Context) *url.URL {
return p.getProvider(ctx).RequestURIF(KeyOAuth2ClientRegistrationURL, new(url.URL))
}
Expand Down
4 changes: 4 additions & 0 deletions driver/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"go.opentelemetry.io/otel/trace"

"github.com/ory/hydra/v2/internal/kratos"
"github.com/ory/x/httprouterx"

"github.com/ory/hydra/v2/aead"
Expand Down Expand Up @@ -53,6 +54,7 @@ type Registry interface {
WithLogger(l *logrusx.Logger) Registry
WithTracer(t trace.Tracer) Registry
WithTracerWrapper(TracerWrapper) Registry
WithKratos(k kratos.Client) Registry
x.HTTPClientProvider
GetJWKSFetcherStrategy() fosite.JWKSFetcherStrategy

Expand All @@ -71,6 +73,8 @@ type Registry interface {
x.TracingProvider
FlowCipher() *aead.XChaCha20Poly1305

kratos.Provider

RegisterRoutes(ctx context.Context, admin *httprouterx.RouterAdmin, public *httprouterx.RouterPublic)
ClientHandler() *client.Handler
KeyHandler() *jwk.Handler
Expand Down
14 changes: 14 additions & 0 deletions driver/registry_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/ory/hydra/v2/driver/config"
"github.com/ory/hydra/v2/fositex"
"github.com/ory/hydra/v2/hsm"
"github.com/ory/hydra/v2/internal/kratos"
"github.com/ory/hydra/v2/jwk"
"github.com/ory/hydra/v2/oauth2"
"github.com/ory/hydra/v2/oauth2/trust"
Expand Down Expand Up @@ -88,6 +89,7 @@ type RegistryBase struct {
hmacs *foauth2.HMACSHAStrategy
fc *fositex.Config
publicCORS *cors.Cors
kratos kratos.Client
}

func (m *RegistryBase) GetJWKSFetcherStrategy() fosite.JWKSFetcherStrategy {
Expand Down Expand Up @@ -201,6 +203,11 @@ func (m *RegistryBase) WithTracerWrapper(wrapper TracerWrapper) Registry {
return m.r
}

func (m *RegistryBase) WithKratos(k kratos.Client) Registry {
m.kratos = k
return m.r
}

func (m *RegistryBase) Logger() *logrusx.Logger {
if m.l == nil {
m.l = logrusx.New("Ory Hydra", m.BuildVersion())
Expand Down Expand Up @@ -552,3 +559,10 @@ func (m *RegistryBase) HSMContext() hsm.Context {
func (m *RegistrySQL) ClientAuthenticator() x.ClientAuthenticator {
return m.OAuth2Provider().(*fosite.Fosite)
}

func (m *RegistryBase) Kratos() kratos.Client {
if m.kratos == nil {
m.kratos = kratos.New(m)
}
return m.kratos
}
17 changes: 12 additions & 5 deletions flow/consent_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,12 @@ type OAuth2RedirectTo struct {

// swagger:ignore
type LoginSession struct {
ID string `db:"id"`
NID uuid.UUID `db:"nid"`
AuthenticatedAt sqlxx.NullTime `db:"authenticated_at"`
Subject string `db:"subject"`
Remember bool `db:"remember"`
ID string `db:"id"`
NID uuid.UUID `db:"nid"`
AuthenticatedAt sqlxx.NullTime `db:"authenticated_at"`
Subject string `db:"subject"`
KratosSessionID sqlxx.NullString `db:"kratos_session_id"`
Remember bool `db:"remember"`
}

func (LoginSession) TableName() string {
Expand Down Expand Up @@ -292,6 +293,12 @@ type HandledLoginRequest struct {
// required: true
Subject string `json:"subject"`

// KratosSessionID is the session ID of the end-user that authenticated.
// If specified, we will use this value to propagate the logout.
//
// required: false
KratosSessionID string `json:"session_id"`

// ForceSubjectIdentifier forces the "pairwise" user ID of the end-user that authenticated. The "pairwise" user ID refers to the
// (Pairwise Identifier Algorithm)[http://openid.net/specs/openid-connect-core-1_0.html#PairwiseAlg] of the OpenID
// Connect specification. It allows you to set an obfuscated subject ("user") identifier that is unique to the client.
Expand Down
4 changes: 4 additions & 0 deletions flow/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,8 @@ type Flow struct {
// channel logout. Its value can generally be used to associate consecutive login requests by a certain user.
SessionID sqlxx.NullString `db:"login_session_id"`

KratosSessionID sqlxx.NullString `db:"kratos_session_id"`

LoginVerifier string `db:"login_verifier"`
LoginCSRF string `db:"login_csrf"`

Expand Down Expand Up @@ -291,6 +293,7 @@ func (f *Flow) HandleLoginRequest(h *HandledLoginRequest) error {
f.ForceSubjectIdentifier = h.ForceSubjectIdentifier
f.LoginError = h.Error

f.KratosSessionID = sqlxx.NullString(h.KratosSessionID)
f.LoginRemember = h.Remember
f.LoginRememberFor = h.RememberFor
f.LoginExtendSessionLifespan = h.ExtendSessionLifespan
Expand All @@ -311,6 +314,7 @@ func (f *Flow) GetHandledLoginRequest() HandledLoginRequest {
ACR: f.ACR,
AMR: f.AMR,
Subject: f.Subject,
KratosSessionID: f.KratosSessionID.String(),
ForceSubjectIdentifier: f.ForceSubjectIdentifier,
Context: f.Context,
WasHandled: f.LoginWasUsed,
Expand Down
1 change: 1 addition & 0 deletions flow/flow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ func (f *Flow) setHandledLoginRequest(r *HandledLoginRequest) {
f.ACR = r.ACR
f.AMR = r.AMR
f.Subject = r.Subject
f.KratosSessionID = sqlxx.NullString(r.KratosSessionID)
f.ForceSubjectIdentifier = r.ForceSubjectIdentifier
f.Context = r.Context
f.LoginWasUsed = r.WasHandled
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ require (
github.com/openzipkin/zipkin-go v0.4.1 // indirect
github.com/ory/dockertest/v3 v3.9.1 // indirect
github.com/ory/go-convenience v0.1.0 // indirect
github.com/ory/kratos-client-go v0.13.1 // indirect
github.com/pelletier/go-toml v1.9.5 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/pkg/profile v1.7.0 // indirect
Expand Down
Loading

0 comments on commit 017859b

Please sign in to comment.