From dec0f6840c2948df4d83d0ae7d496d4cbd0aeac8 Mon Sep 17 00:00:00 2001 From: Gavin Frazar Date: Thu, 31 Jul 2025 14:27:05 -0700 Subject: [PATCH] Fix web UI database REPL error message on startup - The websocket must not be closed before an error is written to it. - Since the Postgres REPL uses a fake dialer returning a single connection, we must disable connection fallbacks to avoid obscuring the error message with fallback errors attempted after closing the connection. This fixes error propagation for access denied errors during connection startup. --- lib/client/db/postgres/repl/repl.go | 5 + lib/client/db/postgres/repl/repl_test.go | 49 ++++++-- lib/web/databases.go | 21 +++- lib/web/databases_test.go | 135 ++++++++++++++--------- 4 files changed, 146 insertions(+), 64 deletions(-) diff --git a/lib/client/db/postgres/repl/repl.go b/lib/client/db/postgres/repl/repl.go index 66c8cb0f046f2..344ae8f033f13 100644 --- a/lib/client/db/postgres/repl/repl.go +++ b/lib/client/db/postgres/repl/repl.go @@ -57,6 +57,11 @@ func New(_ context.Context, cfg *dbrepl.NewREPLConfig) (dbrepl.REPLInstance, err applicationNameParamName: applicationNameParamValue, } config.TLSConfig = nil + // disable fallbacks because our fake dialer returns the same connection + // each time and pgconn closes a conn on error before using a fallback, + // which obscures the actual error and instead shows: + // "failed to write startup message (use of closed network connection)" + config.Fallbacks = nil // Provide a lookup function to avoid having the hostname placeholder to // resolve into something else. Note that the returned value won't be used. diff --git a/lib/client/db/postgres/repl/repl_test.go b/lib/client/db/postgres/repl/repl_test.go index 58edeb7947443..47f2346bf0825 100644 --- a/lib/client/db/postgres/repl/repl_test.go +++ b/lib/client/db/postgres/repl/repl_test.go @@ -27,6 +27,7 @@ import ( "github.com/gravitational/trace" "github.com/jackc/pgconn" + "github.com/jackc/pgerrcode" "github.com/jackc/pgproto3/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -178,14 +179,34 @@ func TestClose(t *testing.T) { func TestConnectionError(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - instance, tc := StartWithServer(t, ctx, WithSkipREPLRun()) - - // Force the server to be closed - tc.CloseServer() - - err := instance.Run(ctx) - require.Error(t, err) - require.True(t, trace.IsConnectionProblem(err), "expected run to be a connection error but got %T", err) + tests := []struct { + desc string + modifyTestCtx func(tc *testCtx) + wantErrContains string + }{ + { + desc: "closed server", + // Force the server to be closed + modifyTestCtx: func(tc *testCtx) { tc.CloseServer() }, + wantErrContains: "failed to write startup message", + }, + { + desc: "access denied", + modifyTestCtx: func(tc *testCtx) { tc.denyAccess = true }, + wantErrContains: "server error (ERROR: access to db denied (SQLSTATE 28000))", + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + instance, tc := StartWithServer(t, ctx, WithSkipREPLRun()) + + test.modifyTestCtx(tc) + err := instance.Run(ctx) + require.Error(t, err) + require.True(t, trace.IsConnectionProblem(err), "expected run to be a connection error but got %T", err) + require.ErrorContains(t, err, test.wantErrContains) + }) + } } func writeLine(t *testing.T, c *testCtx, line string) { @@ -259,6 +280,8 @@ type testCtx struct { cfg *testCtxConfig ctx context.Context cancelFunc context.CancelFunc + // denyAccess controls whether access is denied during authentication + denyAccess bool // conn is the connection used by tests to read/write from/to the REPL. conn net.Conn @@ -404,6 +427,16 @@ func (tc *testCtx) processMessages() error { switch msg := startupMessage.(type) { case *pgproto3.StartupMessage: + if tc.denyAccess { + if err := tc.pgClient.Send(&pgproto3.ErrorResponse{ + Severity: "ERROR", + Code: pgerrcode.InvalidAuthorizationSpecification, + Message: "access to db denied", + }); err != nil { + return trace.Wrap(err) + } + return nil + } // Accept auth and send ready for query. if err := tc.pgClient.Send(&pgproto3.AuthenticationOk{}); err != nil { return trace.Wrap(err) diff --git a/lib/web/databases.go b/lib/web/databases.go index 0c33f4b83bcc1..f9bbf0c782a56 100644 --- a/lib/web/databases.go +++ b/lib/web/databases.go @@ -491,9 +491,6 @@ func (h *Handler) dbConnect( } defer sess.Close() - // Don't close the terminal stream on session error, as it would also - // cause the underlying connection to be closed. This will prevent the - // middleware from properly writing the error into the WebSocket connection. if err := sess.Run(); err != nil { log.ErrorContext(ctx, "Database interactive session exited with error", "error", err) return nil, trace.Wrap(err) @@ -624,11 +621,27 @@ func newDatabaseInteractiveSession(ctx context.Context, cfg databaseInteractiveS replConn: replConn, alpnConn: alpnConn, stream: terminal.NewStream(ctx, terminal.StreamConfig{ - WS: cfg.ws, + // Don't close the terminal stream on session error, as it would also + // cause the underlying connection to be closed. This will prevent the + // middleware from properly writing the error into the WebSocket connection. + // The middleware initiates the connection, forwards it to our + // handler, and always closes it. + WS: noopCloserWS{Conn: cfg.ws}, }), }, nil } +// noopCloserWS prevents the stream from closing the websocket, to allow the +// middleware to write any returned errors to the client before closing the +// websocket. +type noopCloserWS struct { + *websocket.Conn +} + +func (c noopCloserWS) Close() error { + return nil +} + func (s *databaseInteractiveSession) Run() error { replConn, err := s.makeReplConn() if err != nil { diff --git a/lib/web/databases_test.go b/lib/web/databases_test.go index daaace2d0bbf8..54b92a611bebb 100644 --- a/lib/web/databases_test.go +++ b/lib/web/databases_test.go @@ -599,61 +599,81 @@ func TestConnectDatabaseInteractiveSession(t *testing.T) { _, err = s.server.Auth().UpsertDatabaseServer(ctx, mustCreateDatabaseServer(t, selfHosted)) require.NoError(t, err) - - u := url.URL{ - Host: s.webServer.Listener.Addr().String(), - Scheme: client.WSS, - Path: fmt.Sprintf("/v1/webapi/sites/%s/db/exec/ws", s.server.ClusterName()), + tests := []struct { + desc string + replErr error + }{ + { + desc: "success", + }, + { + desc: "errors are sent to the user", + replErr: trace.Errorf("database connection interrupted by unexpected llama crossing"), + }, } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + repl.setError(test.replErr) + u := url.URL{ + Host: s.webServer.Listener.Addr().String(), + Scheme: client.WSS, + Path: fmt.Sprintf("/v1/webapi/sites/%s/db/exec/ws", s.server.ClusterName()), + } - header := http.Header{} - header.Add(xForwardedForHeader, "1.2.3.4") - for _, cookie := range pack.cookies { - header.Add("Cookie", cookie.String()) - } + header := http.Header{} + header.Add(xForwardedForHeader, "1.2.3.4") + for _, cookie := range pack.cookies { + header.Add("Cookie", cookie.String()) + } - dialer := websocket.Dialer{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } + dialer := websocket.Dialer{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } - ws, resp, err := dialer.DialContext(ctx, u.String(), header) - require.NoError(t, err) - defer ws.Close() - require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) - require.NoError(t, resp.Body.Close()) - require.NoError(t, makeAuthReqOverWS(ws, pack.session.Token)) - - req := DatabaseSessionRequest{ - Protocol: databaseProtocol, - ServiceName: databaseName, - DatabaseName: "postgres", - DatabaseUser: "postgres", - DatabaseRoles: []string{"reader"}, + ws, resp, err := dialer.DialContext(ctx, u.String(), header) + require.NoError(t, err) + defer ws.Close() + require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + require.NoError(t, resp.Body.Close()) + require.NoError(t, makeAuthReqOverWS(ws, pack.session.Token)) + + req := DatabaseSessionRequest{ + Protocol: databaseProtocol, + ServiceName: databaseName, + DatabaseName: "postgres", + DatabaseUser: "postgres", + DatabaseRoles: []string{"reader"}, + } + encodedReq, err := json.Marshal(req) + require.NoError(t, err) + reqWebSocketMessage, err := proto.Marshal(&terminal.Envelope{ + Version: defaults.WebsocketVersion, + Type: defaults.WebsocketDatabaseSessionRequest, + Payload: string(encodedReq), + }) + require.NoError(t, err) + require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, reqWebSocketMessage)) + + performMFACeremonyWS(t, ws, pack) + + // After the MFA is performed we expect the WebSocket to receive the + // session data information. + sessionData := receiveWSMessage(t, ws) + require.Equal(t, defaults.WebsocketSessionMetadata, sessionData.Type) + + // Assert data written by the REPL comes as raw data. + replResp := receiveWSMessage(t, ws) + if test.replErr != nil { + require.Equal(t, defaults.WebsocketError, replResp.Type) + require.Equal(t, test.replErr.Error(), replResp.Payload) + } else { + require.Equal(t, defaults.WebsocketRaw, replResp.Type) + require.Equal(t, repl.message, replResp.Payload) + } + require.NoError(t, ws.Close()) + require.True(t, repl.getClosed(), "expected REPL instance to be closed after websocket.Conn is closed") + }) } - encodedReq, err := json.Marshal(req) - require.NoError(t, err) - reqWebSocketMessage, err := proto.Marshal(&terminal.Envelope{ - Version: defaults.WebsocketVersion, - Type: defaults.WebsocketDatabaseSessionRequest, - Payload: string(encodedReq), - }) - require.NoError(t, err) - require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, reqWebSocketMessage)) - - performMFACeremonyWS(t, ws, pack) - - // After the MFA is performed we expect the WebSocket to receive the - // session data information. - sessionData := receiveWSMessage(t, ws) - require.Equal(t, defaults.WebsocketSessionMetadata, sessionData.Type) - - // Assert data written by the REPL comes as raw data. - replResp := receiveWSMessage(t, ws) - require.Equal(t, defaults.WebsocketRaw, replResp.Type) - require.Equal(t, repl.message, replResp.Payload) - - require.NoError(t, ws.Close()) - require.True(t, repl.getClosed(), "expected REPL instance to be closed after websocket.Conn is closed") } func receiveWSMessage(t *testing.T, ws *websocket.Conn) terminal.Envelope { @@ -719,17 +739,28 @@ func (m *mockDatabaseREPLRegistry) IsSupported(protocol string) bool { type mockDatabaseREPL struct { mu sync.Mutex message string + err error cfg *dbrepl.NewREPLConfig closed bool } +func (m *mockDatabaseREPL) setError(err error) { + m.mu.Lock() + defer m.mu.Unlock() + m.err = err +} + func (m *mockDatabaseREPL) Run(_ context.Context) error { m.mu.Lock() defer func() { - m.closeUnlocked() + m.closeLocked() m.mu.Unlock() }() + if m.err != nil { + return trace.Wrap(m.err) + } + if _, err := m.cfg.Client.Write([]byte(m.message)); err != nil { return trace.Wrap(err) } @@ -753,7 +784,7 @@ func (m *mockDatabaseREPL) getClosed() bool { return m.closed } -func (m *mockDatabaseREPL) closeUnlocked() { +func (m *mockDatabaseREPL) closeLocked() { m.closed = true }