From a1edb4cca3fd7a1dc960813153c4c49dc9e88c62 Mon Sep 17 00:00:00 2001 From: "STeve (Xin) Huang" Date: Thu, 20 Mar 2025 11:42:20 -0400 Subject: [PATCH 1/4] Fix an issue expired app session won't redirect to login on DynamoDB backend --- lib/web/app/handler.go | 5 ++ lib/web/app/handler_test.go | 162 +++++++++++++++++++++++++++--------- 2 files changed, 127 insertions(+), 40 deletions(-) diff --git a/lib/web/app/handler.go b/lib/web/app/handler.go index a23742957136d..5cc6b6d02da2b 100644 --- a/lib/web/app/handler.go +++ b/lib/web/app/handler.go @@ -390,6 +390,11 @@ func (h *Handler) getAppSession(r *http.Request) (ws types.WebSession, err error h.log.Warnf("Failed to get session: %v.", err) return nil, trace.AccessDenied("invalid session") } + + if ws.Expiry().Before(h.c.Clock.Now()) { + h.logger.WarnContext(r.Context(), "Session expired") + return nil, trace.AccessDenied("session expired") + } return ws, nil } diff --git a/lib/web/app/handler_test.go b/lib/web/app/handler_test.go index d7cc3a6e03fde..a5cebee8889ae 100644 --- a/lib/web/app/handler_test.go +++ b/lib/web/app/handler_test.go @@ -23,6 +23,7 @@ import ( "context" "crypto" "crypto/tls" + "crypto/x509" "crypto/x509/pkix" "encoding/json" "fmt" @@ -325,35 +326,15 @@ func TestMatchApplicationServers(t *testing.T) { caCert: cert, } - // Create a fake remote site and tunnel. - fakeRemoteSite := reversetunnelclient.NewFakeRemoteSite(clusterName, authClient) - tunnel := &reversetunnelclient.FakeServer{ - Sites: []reversetunnelclient.RemoteSite{ - fakeRemoteSite, - }, - } - // Create a httptest server to serve the application requests. It must serve // TLS content with the generated certificate. - tlsCert, err := tls.X509KeyPair(cert, key) - require.NoError(t, err) expectedContent := "Hello from application" - server := &httptest.Server{ - TLS: &tls.Config{ - Certificates: []tls.Certificate{tlsCert}, + fakeRemoteSite := startFakeAppServerOnRemoteSite(t, clusterName, authClient, cert, key) + tunnel := &reversetunnelclient.FakeServer{ + Sites: []reversetunnelclient.RemoteSite{ + fakeRemoteSite, }, - Listener: &fakeRemoteListener{fakeRemoteSite}, - Config: &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - fmt.Fprint(w, expectedContent) - })}, } - server.StartTLS() - - // Teardown the remote site and the httptest server. - t.Cleanup(func() { - require.NoError(t, fakeRemoteSite.Close()) - server.Close() - }) p := setup(t, fakeClock, authClient, tunnel) status, content := p.makeRequest(t, "GET", "/", []byte{}, []http.Cookie{ @@ -436,24 +417,9 @@ func TestHealthCheckAppServer(t *testing.T) { caCert: cert, } - fakeRemoteSite := reversetunnelclient.NewFakeRemoteSite(clusterName, authClient) + fakeRemoteSite := startFakeAppServerOnRemoteSite(t, clusterName, authClient, cert, key) authClient.appServers = tc.appServersFunc(t, fakeRemoteSite) - // Create a httptest server to serve the application requests. It must serve - // TLS content with the generated certificate. - tlsCert, err := tls.X509KeyPair(cert, key) - require.NoError(t, err) - server := &httptest.Server{ - TLS: &tls.Config{ - Certificates: []tls.Certificate{tlsCert}, - }, - Listener: &fakeRemoteListener{fakeRemoteSite}, - Config: &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { - fmt.Fprint(w, "Hello application") - })}, - } - server.StartTLS() - tunnel := &reversetunnelclient.FakeServer{ Sites: []reversetunnelclient.RemoteSite{fakeRemoteSite}, } @@ -809,3 +775,119 @@ func TestMakeAppRedirectURL(t *testing.T) { }) } } + +func startFakeAppServerOnRemoteSite(t *testing.T, clusterName string, accessPoint authclient.RemoteProxyAccessPoint, cert, key []byte) *reversetunnelclient.FakeRemoteSite { + t.Helper() + + tlsCert, err := tls.X509KeyPair(cert, key) + require.NoError(t, err) + + fakeRemoteSite := reversetunnelclient.NewFakeRemoteSite(clusterName, accessPoint) + server := &httptest.Server{ + TLS: &tls.Config{ + Certificates: []tls.Certificate{tlsCert}, + }, + Listener: &fakeRemoteListener{ + fakeRemote: fakeRemoteSite, + }, + Config: &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + fmt.Fprint(w, "Hello application") + })}, + } + server.StartTLS() + t.Cleanup(func() { + // Close fake remote site first to make sure fake listener quits. + fakeRemoteSite.Close() + server.Close() + }) + return fakeRemoteSite +} + +func TestHandlerAuthenticate(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + clusterName := "test-cluster" + publicAddr := "app.example.com" + key, cert, err := tlsca.GenerateSelfSignedCA( + pkix.Name{CommonName: clusterName}, + []string{publicAddr, apiutils.EncodeClusterName(clusterName)}, + defaults.CATTL, + ) + require.NoError(t, err) + fakeClock := clockwork.NewFakeClock() + + authClient := &mockAuthClient{ + clusterName: clusterName, + appSession: createAppSession(t, fakeClock, key, cert, clusterName, publicAddr), + appServers: []types.AppServer{ + createAppServer(t, publicAddr), + }, + caKey: key, + caCert: cert, + } + + fakeRemoteSite := startFakeAppServerOnRemoteSite(t, clusterName, authClient, cert, key) + + appHandler, err := NewHandler(ctx, &HandlerConfig{ + Clock: fakeClock, + AuthClient: authClient, + AccessPoint: authClient, + ProxyClient: &reversetunnelclient.FakeServer{ + Sites: []reversetunnelclient.RemoteSite{fakeRemoteSite}, + }, + CipherSuites: utils.DefaultCipherSuites(), + IntegrationAppHandler: &mockIntegrationAppHandler{}, + }) + require.NoError(t, err) + + t.Run("with cookie", func(t *testing.T) { + request := httptest.NewRequest("GET", "https://"+publicAddr, nil) + addValidSessionCookiesToRequest(authClient.appSession, request) + + _, err = appHandler.authenticate(ctx, request) + require.NoError(t, err) + }) + + t.Run("with client cert", func(t *testing.T) { + clientCert, err := tls.X509KeyPair(authClient.appSession.GetTLSCert(), authClient.appSession.GetTLSPriv()) + require.NoError(t, err) + require.NotEmpty(t, clientCert.Certificate) + x509Cert, err := x509.ParseCertificate(clientCert.Certificate[0]) + require.NoError(t, err) + + request := httptest.NewRequest("GET", "https://"+publicAddr, nil) + request.TLS.PeerCertificates = []*x509.Certificate{x509Cert} + + _, err = appHandler.authenticate(ctx, request) + require.NoError(t, err) + }) + + t.Run("without cookie or client cert", func(t *testing.T) { + request := httptest.NewRequest("GET", "https://"+publicAddr, nil) + _, err := appHandler.authenticate(ctx, request) + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) + }) + + t.Run("session expired", func(t *testing.T) { + fakeClock.Advance(authClient.appSession.Expiry().Sub(fakeClock.Now()) + time.Minute) + request := httptest.NewRequest("GET", "https://"+publicAddr, nil) + addValidSessionCookiesToRequest(authClient.appSession, request) + + _, err := appHandler.authenticate(ctx, request) + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err)) + }) +} + +func addValidSessionCookiesToRequest(appSession types.WebSession, r *http.Request) { + r.AddCookie(&http.Cookie{ + Name: CookieName, + Value: appSession.GetName(), + }) + r.AddCookie(&http.Cookie{ + Name: SubjectCookieName, + Value: appSession.GetBearerToken(), + }) +} From 7dfb65617c1f8ed93f153301156b552ded8f62b3 Mon Sep 17 00:00:00 2001 From: "STeve (Xin) Huang" Date: Thu, 20 Mar 2025 15:03:20 -0400 Subject: [PATCH 2/4] fix ut --- lib/web/app/handler_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/web/app/handler_test.go b/lib/web/app/handler_test.go index a5cebee8889ae..fd83e1a75c524 100644 --- a/lib/web/app/handler_test.go +++ b/lib/web/app/handler_test.go @@ -328,7 +328,7 @@ func TestMatchApplicationServers(t *testing.T) { // Create a httptest server to serve the application requests. It must serve // TLS content with the generated certificate. - expectedContent := "Hello from application" + expectedContent := "Hello application" fakeRemoteSite := startFakeAppServerOnRemoteSite(t, clusterName, authClient, cert, key) tunnel := &reversetunnelclient.FakeServer{ Sites: []reversetunnelclient.RemoteSite{ From 4d14a4b17771c18a3e7e569de9a0940b903ed6d1 Mon Sep 17 00:00:00 2001 From: "STeve (Xin) Huang" Date: Wed, 26 Mar 2025 09:43:43 -0400 Subject: [PATCH 3/4] convert all usages from access point --- lib/web/app/auth.go | 4 +--- lib/web/app/handler.go | 26 +++++++++++++++----------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/lib/web/app/auth.go b/lib/web/app/auth.go index fd2bb9620fb29..b6adb1059a02f 100644 --- a/lib/web/app/auth.go +++ b/lib/web/app/auth.go @@ -184,9 +184,7 @@ func (h *Handler) completeAppAuthExchange(w http.ResponseWriter, r *http.Request // Validate that the caller is asking for a session that exists and that they have the secret // session token for. - ws, err := h.c.AccessPoint.GetAppSession(r.Context(), types.GetAppSessionRequest{ - SessionID: req.CookieValue, - }) + ws, err := h.getAppSessionFromAccessPoint(r.Context(), req.CookieValue) if err != nil { h.log.WithError(err).Warn("Request failed: session does not exist.") return trace.AccessDenied("access denied") diff --git a/lib/web/app/handler.go b/lib/web/app/handler.go index 5cc6b6d02da2b..287dd79a20080 100644 --- a/lib/web/app/handler.go +++ b/lib/web/app/handler.go @@ -174,9 +174,7 @@ func (h *Handler) HandleConnection(ctx context.Context, clientConn net.Conn) err return trace.Wrap(err) } - ws, err := h.c.AccessPoint.GetAppSession(ctx, types.GetAppSessionRequest{ - SessionID: identity.RouteToApp.SessionID, - }) + ws, err := h.getAppSessionFromAccessPoint(ctx, identity.RouteToApp.SessionID) if err != nil { return trace.Wrap(err) } @@ -390,10 +388,20 @@ func (h *Handler) getAppSession(r *http.Request) (ws types.WebSession, err error h.log.Warnf("Failed to get session: %v.", err) return nil, trace.AccessDenied("invalid session") } + return ws, nil +} +func (h *Handler) getAppSessionFromAccessPoint(ctx context.Context, sessionID string) (types.WebSession, error) { + ws, err := h.c.AccessPoint.GetAppSession(ctx, types.GetAppSessionRequest{ + SessionID: sessionID, + }) + if err != nil { + return nil, trace.Wrap(err) + } + // Do an extra check in case expired app session is still cached. if ws.Expiry().Before(h.c.Clock.Now()) { - h.logger.WarnContext(r.Context(), "Session expired") - return nil, trace.AccessDenied("session expired") + h.logger.DebugContext(ctx, "Session expired") + return nil, trace.AccessDenied("invalid session") } return ws, nil } @@ -410,9 +418,7 @@ func (h *Handler) getAppSessionFromCert(r *http.Request) (types.WebSession, erro // Check that the session exists in the backend cache. This allows the user // to logout and invalidate their application session immediately. This // lookup should also be fast because it's in the local cache. - ws, err := h.c.AccessPoint.GetAppSession(r.Context(), types.GetAppSessionRequest{ - SessionID: identity.RouteToApp.SessionID, - }) + ws, err := h.getAppSessionFromAccessPoint(r.Context(), identity.RouteToApp.SessionID) if err != nil { return nil, trace.Wrap(err) } @@ -454,9 +460,7 @@ func (h *Handler) getAppSessionFromCookie(r *http.Request) (types.WebSession, er // Check that the session exists in the backend cache. This allows the user // to logout and invalidate their application session immediately. This // lookup should also be fast because it's in the local cache. - ws, err := h.c.AccessPoint.GetAppSession(r.Context(), types.GetAppSessionRequest{ - SessionID: sessionID, - }) + ws, err := h.getAppSessionFromAccessPoint(r.Context(), sessionID) if err != nil { return nil, trace.Wrap(err) } From 836999834c27959099215299922a9c2d69709e74 Mon Sep 17 00:00:00 2001 From: "STeve (Xin) Huang" Date: Tue, 1 Apr 2025 07:25:52 -0700 Subject: [PATCH 4/4] fix logger --- lib/web/app/handler.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/web/app/handler.go b/lib/web/app/handler.go index 287dd79a20080..fc5e5c243eaa6 100644 --- a/lib/web/app/handler.go +++ b/lib/web/app/handler.go @@ -400,7 +400,7 @@ func (h *Handler) getAppSessionFromAccessPoint(ctx context.Context, sessionID st } // Do an extra check in case expired app session is still cached. if ws.Expiry().Before(h.c.Clock.Now()) { - h.logger.DebugContext(ctx, "Session expired") + h.log.Debug(ctx, "Session expired") return nil, trace.AccessDenied("invalid session") } return ws, nil