Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions lib/client/db/postgres/repl/repl.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ func New(_ context.Context, cfg *dbrepl.NewREPLConfig) (dbrepl.REPLInstance, err
applicationNameParamName: applicationNameParamValue,
}
config.TLSConfig = nil
// disable fallbacks because our fake dialer returns the same connection
// each time and pgconn closes a conn on error before using a fallback,
// which obscures the actual error and instead shows:
// "failed to write startup message (use of closed network connection)"
config.Fallbacks = nil

// Provide a lookup function to avoid having the hostname placeholder to
// resolve into something else. Note that the returned value won't be used.
Expand Down
49 changes: 41 additions & 8 deletions lib/client/db/postgres/repl/repl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (

"github.com/gravitational/trace"
"github.com/jackc/pgconn"
"github.com/jackc/pgerrcode"
"github.com/jackc/pgproto3/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -178,14 +179,34 @@ func TestClose(t *testing.T) {
func TestConnectionError(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
instance, tc := StartWithServer(t, ctx, WithSkipREPLRun())

// Force the server to be closed
tc.CloseServer()

err := instance.Run(ctx)
require.Error(t, err)
require.True(t, trace.IsConnectionProblem(err), "expected run to be a connection error but got %T", err)
tests := []struct {
desc string
modifyTestCtx func(tc *testCtx)
wantErrContains string
}{
{
desc: "closed server",
// Force the server to be closed
modifyTestCtx: func(tc *testCtx) { tc.CloseServer() },
wantErrContains: "failed to write startup message",
},
{
desc: "access denied",
modifyTestCtx: func(tc *testCtx) { tc.denyAccess = true },
wantErrContains: "server error (ERROR: access to db denied (SQLSTATE 28000))",
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
instance, tc := StartWithServer(t, ctx, WithSkipREPLRun())

test.modifyTestCtx(tc)
err := instance.Run(ctx)
require.Error(t, err)
require.True(t, trace.IsConnectionProblem(err), "expected run to be a connection error but got %T", err)
require.ErrorContains(t, err, test.wantErrContains)
})
}
}

func writeLine(t *testing.T, c *testCtx, line string) {
Expand Down Expand Up @@ -259,6 +280,8 @@ type testCtx struct {
cfg *testCtxConfig
ctx context.Context
cancelFunc context.CancelFunc
// denyAccess controls whether access is denied during authentication
denyAccess bool

// conn is the connection used by tests to read/write from/to the REPL.
conn net.Conn
Expand Down Expand Up @@ -404,6 +427,16 @@ func (tc *testCtx) processMessages() error {

switch msg := startupMessage.(type) {
case *pgproto3.StartupMessage:
if tc.denyAccess {
if err := tc.pgClient.Send(&pgproto3.ErrorResponse{
Severity: "ERROR",
Code: pgerrcode.InvalidAuthorizationSpecification,
Message: "access to db denied",
}); err != nil {
return trace.Wrap(err)
}
return nil
}
// Accept auth and send ready for query.
if err := tc.pgClient.Send(&pgproto3.AuthenticationOk{}); err != nil {
return trace.Wrap(err)
Expand Down
21 changes: 17 additions & 4 deletions lib/web/databases.go
Original file line number Diff line number Diff line change
Expand Up @@ -491,9 +491,6 @@ func (h *Handler) dbConnect(
}
defer sess.Close()

// Don't close the terminal stream on session error, as it would also
// cause the underlying connection to be closed. This will prevent the
// middleware from properly writing the error into the WebSocket connection.
if err := sess.Run(); err != nil {
log.ErrorContext(ctx, "Database interactive session exited with error", "error", err)
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -624,11 +621,27 @@ func newDatabaseInteractiveSession(ctx context.Context, cfg databaseInteractiveS
replConn: replConn,
alpnConn: alpnConn,
stream: terminal.NewStream(ctx, terminal.StreamConfig{
WS: cfg.ws,
// Don't close the terminal stream on session error, as it would also
// cause the underlying connection to be closed. This will prevent the
// middleware from properly writing the error into the WebSocket connection.
// The middleware initiates the connection, forwards it to our
// handler, and always closes it.
WS: noopCloserWS{Conn: cfg.ws},
}),
}, nil
}

// noopCloserWS prevents the stream from closing the websocket, to allow the
// middleware to write any returned errors to the client before closing the
// websocket.
type noopCloserWS struct {
*websocket.Conn
}

func (c noopCloserWS) Close() error {
return nil
}

func (s *databaseInteractiveSession) Run() error {
replConn, err := s.makeReplConn()
if err != nil {
Expand Down
135 changes: 83 additions & 52 deletions lib/web/databases_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -599,61 +599,81 @@ func TestConnectDatabaseInteractiveSession(t *testing.T) {

_, err = s.server.Auth().UpsertDatabaseServer(ctx, mustCreateDatabaseServer(t, selfHosted))
require.NoError(t, err)

u := url.URL{
Host: s.webServer.Listener.Addr().String(),
Scheme: client.WSS,
Path: fmt.Sprintf("/v1/webapi/sites/%s/db/exec/ws", s.server.ClusterName()),
tests := []struct {
desc string
replErr error
}{
{
desc: "success",
},
{
desc: "errors are sent to the user",
replErr: trace.Errorf("database connection interrupted by unexpected llama crossing"),
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
repl.setError(test.replErr)
u := url.URL{
Host: s.webServer.Listener.Addr().String(),
Scheme: client.WSS,
Path: fmt.Sprintf("/v1/webapi/sites/%s/db/exec/ws", s.server.ClusterName()),
}

header := http.Header{}
header.Add(xForwardedForHeader, "1.2.3.4")
for _, cookie := range pack.cookies {
header.Add("Cookie", cookie.String())
}
header := http.Header{}
header.Add(xForwardedForHeader, "1.2.3.4")
for _, cookie := range pack.cookies {
header.Add("Cookie", cookie.String())
}

dialer := websocket.Dialer{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
dialer := websocket.Dialer{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}

ws, resp, err := dialer.DialContext(ctx, u.String(), header)
require.NoError(t, err)
defer ws.Close()
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
require.NoError(t, resp.Body.Close())
require.NoError(t, makeAuthReqOverWS(ws, pack.session.Token))

req := DatabaseSessionRequest{
Protocol: databaseProtocol,
ServiceName: databaseName,
DatabaseName: "postgres",
DatabaseUser: "postgres",
DatabaseRoles: []string{"reader"},
ws, resp, err := dialer.DialContext(ctx, u.String(), header)
require.NoError(t, err)
defer ws.Close()
require.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode)
require.NoError(t, resp.Body.Close())
require.NoError(t, makeAuthReqOverWS(ws, pack.session.Token))

req := DatabaseSessionRequest{
Protocol: databaseProtocol,
ServiceName: databaseName,
DatabaseName: "postgres",
DatabaseUser: "postgres",
DatabaseRoles: []string{"reader"},
}
encodedReq, err := json.Marshal(req)
require.NoError(t, err)
reqWebSocketMessage, err := proto.Marshal(&terminal.Envelope{
Version: defaults.WebsocketVersion,
Type: defaults.WebsocketDatabaseSessionRequest,
Payload: string(encodedReq),
})
require.NoError(t, err)
require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, reqWebSocketMessage))

performMFACeremonyWS(t, ws, pack)

// After the MFA is performed we expect the WebSocket to receive the
// session data information.
sessionData := receiveWSMessage(t, ws)
require.Equal(t, defaults.WebsocketSessionMetadata, sessionData.Type)

// Assert data written by the REPL comes as raw data.
replResp := receiveWSMessage(t, ws)
if test.replErr != nil {
require.Equal(t, defaults.WebsocketError, replResp.Type)
require.Equal(t, test.replErr.Error(), replResp.Payload)
} else {
require.Equal(t, defaults.WebsocketRaw, replResp.Type)
require.Equal(t, repl.message, replResp.Payload)
}
require.NoError(t, ws.Close())
require.True(t, repl.getClosed(), "expected REPL instance to be closed after websocket.Conn is closed")
})
}
encodedReq, err := json.Marshal(req)
require.NoError(t, err)
reqWebSocketMessage, err := proto.Marshal(&terminal.Envelope{
Version: defaults.WebsocketVersion,
Type: defaults.WebsocketDatabaseSessionRequest,
Payload: string(encodedReq),
})
require.NoError(t, err)
require.NoError(t, ws.WriteMessage(websocket.BinaryMessage, reqWebSocketMessage))

performMFACeremonyWS(t, ws, pack)

// After the MFA is performed we expect the WebSocket to receive the
// session data information.
sessionData := receiveWSMessage(t, ws)
require.Equal(t, defaults.WebsocketSessionMetadata, sessionData.Type)

// Assert data written by the REPL comes as raw data.
replResp := receiveWSMessage(t, ws)
require.Equal(t, defaults.WebsocketRaw, replResp.Type)
require.Equal(t, repl.message, replResp.Payload)

require.NoError(t, ws.Close())
require.True(t, repl.getClosed(), "expected REPL instance to be closed after websocket.Conn is closed")
}

func receiveWSMessage(t *testing.T, ws *websocket.Conn) terminal.Envelope {
Expand Down Expand Up @@ -719,17 +739,28 @@ func (m *mockDatabaseREPLRegistry) IsSupported(protocol string) bool {
type mockDatabaseREPL struct {
mu sync.Mutex
message string
err error
cfg *dbrepl.NewREPLConfig
closed bool
}

func (m *mockDatabaseREPL) setError(err error) {
m.mu.Lock()
defer m.mu.Unlock()
m.err = err
}

func (m *mockDatabaseREPL) Run(_ context.Context) error {
m.mu.Lock()
defer func() {
m.closeUnlocked()
m.closeLocked()
m.mu.Unlock()
}()

if m.err != nil {
return trace.Wrap(m.err)
}

if _, err := m.cfg.Client.Write([]byte(m.message)); err != nil {
return trace.Wrap(err)
}
Expand All @@ -753,7 +784,7 @@ func (m *mockDatabaseREPL) getClosed() bool {
return m.closed
}

func (m *mockDatabaseREPL) closeUnlocked() {
func (m *mockDatabaseREPL) closeLocked() {
m.closed = true
}

Expand Down
Loading