Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 1 addition & 7 deletions lib/multiplexer/multiplexer.go
Original file line number Diff line number Diff line change
Expand Up @@ -496,19 +496,13 @@ func (m *Mux) detect(conn net.Conn) (*Conn, error) {

proxyLine = newProxyLine
// repeat the cycle to detect the protocol
case ProtoTLS, ProtoSSH, ProtoHTTP:
case ProtoTLS, ProtoSSH, ProtoHTTP, ProtoPostgres:
return &Conn{
protocol: proto,
Conn: conn,
reader: reader,
proxyLine: proxyLine,
}, nil
case ProtoPostgres:
return &Conn{
protocol: proto,
Conn: conn,
reader: reader,
}, nil
}
}
// if code ended here after three attempts, something is wrong
Expand Down
45 changes: 31 additions & 14 deletions lib/multiplexer/multiplexer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -538,8 +538,9 @@ func TestMux(t *testing.T) {
defer cancel()

mux, err := New(Config{
Context: ctx,
Listener: listener,
Context: ctx,
Listener: listener,
EnableExternalProxyProtocol: true,
})
require.NoError(t, err)
go mux.Serve()
Expand All @@ -548,20 +549,36 @@ func TestMux(t *testing.T) {
// register listener before establishing frontend connection
dblistener := mux.DB()

// Connect to the listener and send Postgres SSLRequest which is what
// psql or other Postgres client will do.
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
defer conn.Close()
check := func(t *testing.T, expectedAddr string, proxyLine []byte) {
// Connect to the listener and send Postgres SSLRequest which is what
// psql or other Postgres client will do.
conn, err := net.Dial("tcp", listener.Addr().String())
require.NoError(t, err)
defer conn.Close()

frontend := pgproto3.NewFrontend(pgproto3.NewChunkReader(conn), conn)
err = frontend.Send(&pgproto3.SSLRequest{})
require.NoError(t, err)
_, err = conn.Write(sampleProxyV2Line)
require.NoError(t, err)

frontend := pgproto3.NewFrontend(pgproto3.NewChunkReader(conn), conn)
err = frontend.Send(&pgproto3.SSLRequest{})
require.NoError(t, err)

// This should not hang indefinitely since we set timeout on the mux context above.
conn, err = dblistener.Accept()
require.NoError(t, err, "detected Postgres connection")
require.Equal(t, ProtoPostgres, conn.(*Conn).Protocol())
// This should not hang indefinitely since we set timeout on the mux context above.
dbConn, err := dblistener.Accept()
require.NoError(t, err, "detected Postgres connection")
require.Equal(t, ProtoPostgres, dbConn.(*Conn).Protocol())
if expectedAddr != "" {
require.Equal(t, expectedAddr, dbConn.RemoteAddr().String())
}
}

t.Run("without proxy line", func(t *testing.T) {
check(t, "", nil)
})

t.Run("with proxy line", func(t *testing.T) {
check(t, "127.0.0.1:12345", sampleProxyV2Line)
})
})

// WebListener verifies web listener correctly multiplexes connections
Expand Down