diff --git a/lib/multiplexer/multiplexer.go b/lib/multiplexer/multiplexer.go index ce1d1a1d0e40a..8ab2d38a66b85 100644 --- a/lib/multiplexer/multiplexer.go +++ b/lib/multiplexer/multiplexer.go @@ -496,19 +496,13 @@ func (m *Mux) detect(conn net.Conn) (*Conn, error) { proxyLine = newProxyLine // repeat the cycle to detect the protocol - case ProtoTLS, ProtoSSH, ProtoHTTP: + case ProtoTLS, ProtoSSH, ProtoHTTP, ProtoPostgres: return &Conn{ protocol: proto, Conn: conn, reader: reader, proxyLine: proxyLine, }, nil - case ProtoPostgres: - return &Conn{ - protocol: proto, - Conn: conn, - reader: reader, - }, nil } } // if code ended here after three attempts, something is wrong diff --git a/lib/multiplexer/multiplexer_test.go b/lib/multiplexer/multiplexer_test.go index c3716663ca277..c24d20ccba693 100644 --- a/lib/multiplexer/multiplexer_test.go +++ b/lib/multiplexer/multiplexer_test.go @@ -538,8 +538,9 @@ func TestMux(t *testing.T) { defer cancel() mux, err := New(Config{ - Context: ctx, - Listener: listener, + Context: ctx, + Listener: listener, + EnableExternalProxyProtocol: true, }) require.NoError(t, err) go mux.Serve() @@ -548,20 +549,36 @@ func TestMux(t *testing.T) { // register listener before establishing frontend connection dblistener := mux.DB() - // Connect to the listener and send Postgres SSLRequest which is what - // psql or other Postgres client will do. - conn, err := net.Dial("tcp", listener.Addr().String()) - require.NoError(t, err) - defer conn.Close() + check := func(t *testing.T, expectedAddr string, proxyLine []byte) { + // Connect to the listener and send Postgres SSLRequest which is what + // psql or other Postgres client will do. + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + defer conn.Close() - frontend := pgproto3.NewFrontend(pgproto3.NewChunkReader(conn), conn) - err = frontend.Send(&pgproto3.SSLRequest{}) - require.NoError(t, err) + _, err = conn.Write(sampleProxyV2Line) + require.NoError(t, err) + + frontend := pgproto3.NewFrontend(pgproto3.NewChunkReader(conn), conn) + err = frontend.Send(&pgproto3.SSLRequest{}) + require.NoError(t, err) - // This should not hang indefinitely since we set timeout on the mux context above. - conn, err = dblistener.Accept() - require.NoError(t, err, "detected Postgres connection") - require.Equal(t, ProtoPostgres, conn.(*Conn).Protocol()) + // This should not hang indefinitely since we set timeout on the mux context above. + dbConn, err := dblistener.Accept() + require.NoError(t, err, "detected Postgres connection") + require.Equal(t, ProtoPostgres, dbConn.(*Conn).Protocol()) + if expectedAddr != "" { + require.Equal(t, expectedAddr, dbConn.RemoteAddr().String()) + } + } + + t.Run("without proxy line", func(t *testing.T) { + check(t, "", nil) + }) + + t.Run("with proxy line", func(t *testing.T) { + check(t, "127.0.0.1:12345", sampleProxyV2Line) + }) }) // WebListener verifies web listener correctly multiplexes connections