Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions lib/web/app/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,7 @@ func (h *Handler) completeAppAuthExchange(w http.ResponseWriter, r *http.Request

// Validate that the caller is asking for a session that exists and that they have the secret
// session token for.
ws, err := h.c.AccessPoint.GetAppSession(r.Context(), types.GetAppSessionRequest{
SessionID: req.CookieValue,
})
ws, err := h.getAppSessionFromAccessPoint(r.Context(), req.CookieValue)
if err != nil {
h.log.WithError(err).Warn("Request failed: session does not exist.")
return trace.AccessDenied("access denied")
Expand Down
27 changes: 18 additions & 9 deletions lib/web/app/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,7 @@ func (h *Handler) HandleConnection(ctx context.Context, clientConn net.Conn) err
return trace.Wrap(err)
}

ws, err := h.c.AccessPoint.GetAppSession(ctx, types.GetAppSessionRequest{
SessionID: identity.RouteToApp.SessionID,
})
ws, err := h.getAppSessionFromAccessPoint(ctx, identity.RouteToApp.SessionID)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -393,6 +391,21 @@ func (h *Handler) getAppSession(r *http.Request) (ws types.WebSession, err error
return ws, nil
}

func (h *Handler) getAppSessionFromAccessPoint(ctx context.Context, sessionID string) (types.WebSession, error) {
ws, err := h.c.AccessPoint.GetAppSession(ctx, types.GetAppSessionRequest{
SessionID: sessionID,
})
if err != nil {
return nil, trace.Wrap(err)
}
// Do an extra check in case expired app session is still cached.
if ws.Expiry().Before(h.c.Clock.Now()) {
h.log.Debug(ctx, "Session expired")
return nil, trace.AccessDenied("invalid session")
}
return ws, nil
}

func (h *Handler) getAppSessionFromCert(r *http.Request) (types.WebSession, error) {
if !HasClientCert(r) {
return nil, trace.BadParameter("request missing client certificate")
Expand All @@ -405,9 +418,7 @@ func (h *Handler) getAppSessionFromCert(r *http.Request) (types.WebSession, erro
// Check that the session exists in the backend cache. This allows the user
// to logout and invalidate their application session immediately. This
// lookup should also be fast because it's in the local cache.
ws, err := h.c.AccessPoint.GetAppSession(r.Context(), types.GetAppSessionRequest{
SessionID: identity.RouteToApp.SessionID,
})
ws, err := h.getAppSessionFromAccessPoint(r.Context(), identity.RouteToApp.SessionID)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -449,9 +460,7 @@ func (h *Handler) getAppSessionFromCookie(r *http.Request) (types.WebSession, er
// Check that the session exists in the backend cache. This allows the user
// to logout and invalidate their application session immediately. This
// lookup should also be fast because it's in the local cache.
ws, err := h.c.AccessPoint.GetAppSession(r.Context(), types.GetAppSessionRequest{
SessionID: sessionID,
})
ws, err := h.getAppSessionFromAccessPoint(r.Context(), sessionID)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
162 changes: 122 additions & 40 deletions lib/web/app/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"context"
"crypto"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"fmt"
Expand Down Expand Up @@ -325,36 +326,16 @@ func TestMatchApplicationServers(t *testing.T) {
caCert: cert,
}

// Create a fake remote site and tunnel.
fakeRemoteSite := reversetunnelclient.NewFakeRemoteSite(clusterName, authClient)
// Create a httptest server to serve the application requests. It must serve
// TLS content with the generated certificate.
expectedContent := "Hello application"
fakeRemoteSite := startFakeAppServerOnRemoteSite(t, clusterName, authClient, cert, key)
tunnel := &reversetunnelclient.FakeServer{
Sites: []reversetunnelclient.RemoteSite{
fakeRemoteSite,
},
}

// Create a httptest server to serve the application requests. It must serve
// TLS content with the generated certificate.
tlsCert, err := tls.X509KeyPair(cert, key)
require.NoError(t, err)
expectedContent := "Hello from application"
server := &httptest.Server{
TLS: &tls.Config{
Certificates: []tls.Certificate{tlsCert},
},
Listener: &fakeRemoteListener{fakeRemoteSite},
Config: &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
fmt.Fprint(w, expectedContent)
})},
}
server.StartTLS()

// Teardown the remote site and the httptest server.
t.Cleanup(func() {
require.NoError(t, fakeRemoteSite.Close())
server.Close()
})

p := setup(t, fakeClock, authClient, tunnel)
status, content := p.makeRequest(t, "GET", "/", []byte{}, []http.Cookie{
{
Expand Down Expand Up @@ -436,24 +417,9 @@ func TestHealthCheckAppServer(t *testing.T) {
caCert: cert,
}

fakeRemoteSite := reversetunnelclient.NewFakeRemoteSite(clusterName, authClient)
fakeRemoteSite := startFakeAppServerOnRemoteSite(t, clusterName, authClient, cert, key)
authClient.appServers = tc.appServersFunc(t, fakeRemoteSite)

// Create a httptest server to serve the application requests. It must serve
// TLS content with the generated certificate.
tlsCert, err := tls.X509KeyPair(cert, key)
require.NoError(t, err)
server := &httptest.Server{
TLS: &tls.Config{
Certificates: []tls.Certificate{tlsCert},
},
Listener: &fakeRemoteListener{fakeRemoteSite},
Config: &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
fmt.Fprint(w, "Hello application")
})},
}
server.StartTLS()

tunnel := &reversetunnelclient.FakeServer{
Sites: []reversetunnelclient.RemoteSite{fakeRemoteSite},
}
Expand Down Expand Up @@ -809,3 +775,119 @@ func TestMakeAppRedirectURL(t *testing.T) {
})
}
}

func startFakeAppServerOnRemoteSite(t *testing.T, clusterName string, accessPoint authclient.RemoteProxyAccessPoint, cert, key []byte) *reversetunnelclient.FakeRemoteSite {
t.Helper()

tlsCert, err := tls.X509KeyPair(cert, key)
require.NoError(t, err)

fakeRemoteSite := reversetunnelclient.NewFakeRemoteSite(clusterName, accessPoint)
server := &httptest.Server{
TLS: &tls.Config{
Certificates: []tls.Certificate{tlsCert},
},
Listener: &fakeRemoteListener{
fakeRemote: fakeRemoteSite,
},
Config: &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
fmt.Fprint(w, "Hello application")
})},
}
server.StartTLS()
t.Cleanup(func() {
// Close fake remote site first to make sure fake listener quits.
fakeRemoteSite.Close()
server.Close()
})
return fakeRemoteSite
}

func TestHandlerAuthenticate(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

clusterName := "test-cluster"
publicAddr := "app.example.com"
key, cert, err := tlsca.GenerateSelfSignedCA(
pkix.Name{CommonName: clusterName},
[]string{publicAddr, apiutils.EncodeClusterName(clusterName)},
defaults.CATTL,
)
require.NoError(t, err)
fakeClock := clockwork.NewFakeClock()

authClient := &mockAuthClient{
clusterName: clusterName,
appSession: createAppSession(t, fakeClock, key, cert, clusterName, publicAddr),
appServers: []types.AppServer{
createAppServer(t, publicAddr),
},
caKey: key,
caCert: cert,
}

fakeRemoteSite := startFakeAppServerOnRemoteSite(t, clusterName, authClient, cert, key)

appHandler, err := NewHandler(ctx, &HandlerConfig{
Clock: fakeClock,
AuthClient: authClient,
AccessPoint: authClient,
ProxyClient: &reversetunnelclient.FakeServer{
Sites: []reversetunnelclient.RemoteSite{fakeRemoteSite},
},
CipherSuites: utils.DefaultCipherSuites(),
IntegrationAppHandler: &mockIntegrationAppHandler{},
})
require.NoError(t, err)

t.Run("with cookie", func(t *testing.T) {
request := httptest.NewRequest("GET", "https://"+publicAddr, nil)
addValidSessionCookiesToRequest(authClient.appSession, request)

_, err = appHandler.authenticate(ctx, request)
require.NoError(t, err)
})

t.Run("with client cert", func(t *testing.T) {
clientCert, err := tls.X509KeyPair(authClient.appSession.GetTLSCert(), authClient.appSession.GetTLSPriv())
require.NoError(t, err)
require.NotEmpty(t, clientCert.Certificate)
x509Cert, err := x509.ParseCertificate(clientCert.Certificate[0])
require.NoError(t, err)

request := httptest.NewRequest("GET", "https://"+publicAddr, nil)
request.TLS.PeerCertificates = []*x509.Certificate{x509Cert}

_, err = appHandler.authenticate(ctx, request)
require.NoError(t, err)
})

t.Run("without cookie or client cert", func(t *testing.T) {
request := httptest.NewRequest("GET", "https://"+publicAddr, nil)
_, err := appHandler.authenticate(ctx, request)
require.Error(t, err)
require.True(t, trace.IsAccessDenied(err))
})

t.Run("session expired", func(t *testing.T) {
fakeClock.Advance(authClient.appSession.Expiry().Sub(fakeClock.Now()) + time.Minute)
request := httptest.NewRequest("GET", "https://"+publicAddr, nil)
addValidSessionCookiesToRequest(authClient.appSession, request)

_, err := appHandler.authenticate(ctx, request)
require.Error(t, err)
require.True(t, trace.IsAccessDenied(err))
})
}

func addValidSessionCookiesToRequest(appSession types.WebSession, r *http.Request) {
r.AddCookie(&http.Cookie{
Name: CookieName,
Value: appSession.GetName(),
})
r.AddCookie(&http.Cookie{
Name: SubjectCookieName,
Value: appSession.GetBearerToken(),
})
}
Loading