diff --git a/lib/client/db/postgres/repl/repl.go b/lib/client/db/postgres/repl/repl.go index 66c8cb0f046f2..344ae8f033f13 100644 --- a/lib/client/db/postgres/repl/repl.go +++ b/lib/client/db/postgres/repl/repl.go @@ -57,6 +57,11 @@ func New(_ context.Context, cfg *dbrepl.NewREPLConfig) (dbrepl.REPLInstance, err applicationNameParamName: applicationNameParamValue, } config.TLSConfig = nil + // disable fallbacks because our fake dialer returns the same connection + // each time and pgconn closes a conn on error before using a fallback, + // which obscures the actual error and instead shows: + // "failed to write startup message (use of closed network connection)" + config.Fallbacks = nil // Provide a lookup function to avoid having the hostname placeholder to // resolve into something else. Note that the returned value won't be used. diff --git a/lib/client/db/postgres/repl/repl_test.go b/lib/client/db/postgres/repl/repl_test.go index 58edeb7947443..47f2346bf0825 100644 --- a/lib/client/db/postgres/repl/repl_test.go +++ b/lib/client/db/postgres/repl/repl_test.go @@ -27,6 +27,7 @@ import ( "github.com/gravitational/trace" "github.com/jackc/pgconn" + "github.com/jackc/pgerrcode" "github.com/jackc/pgproto3/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -178,14 +179,34 @@ func TestClose(t *testing.T) { func TestConnectionError(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - instance, tc := StartWithServer(t, ctx, WithSkipREPLRun()) - - // Force the server to be closed - tc.CloseServer() - - err := instance.Run(ctx) - require.Error(t, err) - require.True(t, trace.IsConnectionProblem(err), "expected run to be a connection error but got %T", err) + tests := []struct { + desc string + modifyTestCtx func(tc *testCtx) + wantErrContains string + }{ + { + desc: "closed server", + // Force the server to be closed + modifyTestCtx: func(tc *testCtx) { tc.CloseServer() }, + wantErrContains: "failed to write startup message", + }, + { + desc: "access denied", + modifyTestCtx: func(tc *testCtx) { tc.denyAccess = true }, + wantErrContains: "server error (ERROR: access to db denied (SQLSTATE 28000))", + }, + } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + instance, tc := StartWithServer(t, ctx, WithSkipREPLRun()) + + test.modifyTestCtx(tc) + err := instance.Run(ctx) + require.Error(t, err) + require.True(t, trace.IsConnectionProblem(err), "expected run to be a connection error but got %T", err) + require.ErrorContains(t, err, test.wantErrContains) + }) + } } func writeLine(t *testing.T, c *testCtx, line string) { @@ -259,6 +280,8 @@ type testCtx struct { cfg *testCtxConfig ctx context.Context cancelFunc context.CancelFunc + // denyAccess controls whether access is denied during authentication + denyAccess bool // conn is the connection used by tests to read/write from/to the REPL. conn net.Conn @@ -404,6 +427,16 @@ func (tc *testCtx) processMessages() error { switch msg := startupMessage.(type) { case *pgproto3.StartupMessage: + if tc.denyAccess { + if err := tc.pgClient.Send(&pgproto3.ErrorResponse{ + Severity: "ERROR", + Code: pgerrcode.InvalidAuthorizationSpecification, + Message: "access to db denied", + }); err != nil { + return trace.Wrap(err) + } + return nil + } // Accept auth and send ready for query. if err := tc.pgClient.Send(&pgproto3.AuthenticationOk{}); err != nil { return trace.Wrap(err) diff --git a/lib/web/databases.go b/lib/web/databases.go index 0c33f4b83bcc1..f9bbf0c782a56 100644 --- a/lib/web/databases.go +++ b/lib/web/databases.go @@ -491,9 +491,6 @@ func (h *Handler) dbConnect( } defer sess.Close() - // Don't close the terminal stream on session error, as it would also - // cause the underlying connection to be closed. This will prevent the - // middleware from properly writing the error into the WebSocket connection. if err := sess.Run(); err != nil { log.ErrorContext(ctx, "Database interactive session exited with error", "error", err) return nil, trace.Wrap(err) @@ -624,11 +621,27 @@ func newDatabaseInteractiveSession(ctx context.Context, cfg databaseInteractiveS replConn: replConn, alpnConn: alpnConn, stream: terminal.NewStream(ctx, terminal.StreamConfig{ - WS: cfg.ws, + // Don't close the terminal stream on session error, as it would also + // cause the underlying connection to be closed. This will prevent the + // middleware from properly writing the error into the WebSocket connection. + // The middleware initiates the connection, forwards it to our + // handler, and always closes it. + WS: noopCloserWS{Conn: cfg.ws}, }), }, nil } +// noopCloserWS prevents the stream from closing the websocket, to allow the +// middleware to write any returned errors to the client before closing the +// websocket. +type noopCloserWS struct { + *websocket.Conn +} + +func (c noopCloserWS) Close() error { + return nil +} + func (s *databaseInteractiveSession) Run() error { replConn, err := s.makeReplConn() if err != nil { diff --git a/lib/web/databases_test.go b/lib/web/databases_test.go index daaace2d0bbf8..54b92a611bebb 100644 --- a/lib/web/databases_test.go +++ b/lib/web/databases_test.go @@ -599,61 +599,81 @@ func TestConnectDatabaseInteractiveSession(t *testing.T) { _, err = s.server.Auth().UpsertDatabaseServer(ctx, mustCreateDatabaseServer(t, selfHosted)) require.NoError(t, err) - - u := url.URL{ - Host: s.webServer.Listener.Addr().String(), - Scheme: client.WSS, - Path: fmt.Sprintf("/v1/webapi/sites/%s/db/exec/ws", s.server.ClusterName()), + tests := []struct { + desc string + replErr error + }{ + { + desc: "success", + }, + { + desc: "errors are sent to the user", + replErr: trace.Errorf("database connection interrupted by unexpected llama crossing"), + }, } + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + repl.setError(test.replErr) + u := url.URL{ + Host: s.webServer.Listener.Addr().String(), + Scheme: client.WSS, + Path: fmt.Sprintf("/v1/webapi/sites/%s/db/exec/ws", s.server.ClusterName()), + } - header := http.Header{} - header.Add(xForwardedForHeader, "1.2.3.4") - for _, cookie := range pack.cookies { - header.Add("Cookie", cookie.String()) - } + header := http.Header{} + header.Add(xForwardedForHeader, "1.2.3.4") + for _, cookie := range pack.cookies { + header.Add("Cookie", cookie.String()) + } - dialer := websocket.Dialer{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - } + dialer := websocket.Dialer{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } - ws, resp, err := dialer.DialContext(ctx, u.String(), header) - require.NoError(t, err) - defer ws.Close() - require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) - require.NoError(t, resp.Body.Close()) - require.NoError(t, makeAuthReqOverWS(ws, pack.session.Token)) - - req := DatabaseSessionRequest{ - Protocol: databaseProtocol, - ServiceName: databaseName, - DatabaseName: "postgres", - DatabaseUser: "postgres", - DatabaseRoles: []string{"reader"}, + ws, resp, err := dialer.DialContext(ctx, u.String(), header) + require.NoError(t, err) + defer ws.Close() + require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode) + require.NoError(t, resp.Body.Close()) + require.NoError(t, makeAuthReqOverWS(ws, pack.session.Token)) + + req := DatabaseSessionRequest{ + Protocol: databaseProtocol, + ServiceName: databaseName, + DatabaseName: "postgres", + DatabaseUser: "postgres", + DatabaseRoles: []string{"reader"}, + } + encodedReq, err := json.Marshal(req) + require.NoError(t, err) + reqWebSocketMessage, err := proto.Marshal(&terminal.Envelope{ + Version: defaults.WebsocketVersion, + Type: defaults.WebsocketDatabaseSessionRequest, + Payload: string(encodedReq), + }) + require.NoError(t, err) + require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, reqWebSocketMessage)) + + performMFACeremonyWS(t, ws, pack) + + // After the MFA is performed we expect the WebSocket to receive the + // session data information. + sessionData := receiveWSMessage(t, ws) + require.Equal(t, defaults.WebsocketSessionMetadata, sessionData.Type) + + // Assert data written by the REPL comes as raw data. + replResp := receiveWSMessage(t, ws) + if test.replErr != nil { + require.Equal(t, defaults.WebsocketError, replResp.Type) + require.Equal(t, test.replErr.Error(), replResp.Payload) + } else { + require.Equal(t, defaults.WebsocketRaw, replResp.Type) + require.Equal(t, repl.message, replResp.Payload) + } + require.NoError(t, ws.Close()) + require.True(t, repl.getClosed(), "expected REPL instance to be closed after websocket.Conn is closed") + }) } - encodedReq, err := json.Marshal(req) - require.NoError(t, err) - reqWebSocketMessage, err := proto.Marshal(&terminal.Envelope{ - Version: defaults.WebsocketVersion, - Type: defaults.WebsocketDatabaseSessionRequest, - Payload: string(encodedReq), - }) - require.NoError(t, err) - require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, reqWebSocketMessage)) - - performMFACeremonyWS(t, ws, pack) - - // After the MFA is performed we expect the WebSocket to receive the - // session data information. - sessionData := receiveWSMessage(t, ws) - require.Equal(t, defaults.WebsocketSessionMetadata, sessionData.Type) - - // Assert data written by the REPL comes as raw data. - replResp := receiveWSMessage(t, ws) - require.Equal(t, defaults.WebsocketRaw, replResp.Type) - require.Equal(t, repl.message, replResp.Payload) - - require.NoError(t, ws.Close()) - require.True(t, repl.getClosed(), "expected REPL instance to be closed after websocket.Conn is closed") } func receiveWSMessage(t *testing.T, ws *websocket.Conn) terminal.Envelope { @@ -719,17 +739,28 @@ func (m *mockDatabaseREPLRegistry) IsSupported(protocol string) bool { type mockDatabaseREPL struct { mu sync.Mutex message string + err error cfg *dbrepl.NewREPLConfig closed bool } +func (m *mockDatabaseREPL) setError(err error) { + m.mu.Lock() + defer m.mu.Unlock() + m.err = err +} + func (m *mockDatabaseREPL) Run(_ context.Context) error { m.mu.Lock() defer func() { - m.closeUnlocked() + m.closeLocked() m.mu.Unlock() }() + if m.err != nil { + return trace.Wrap(m.err) + } + if _, err := m.cfg.Client.Write([]byte(m.message)); err != nil { return trace.Wrap(err) } @@ -753,7 +784,7 @@ func (m *mockDatabaseREPL) getClosed() bool { return m.closed } -func (m *mockDatabaseREPL) closeUnlocked() { +func (m *mockDatabaseREPL) closeLocked() { m.closed = true }