diff --git a/lib/kube/proxy/server.go b/lib/kube/proxy/server.go index 913eb519b67d7..990e21b50df0b 100644 --- a/lib/kube/proxy/server.go +++ b/lib/kube/proxy/server.go @@ -299,7 +299,7 @@ func (t *TLSServer) Serve(listener net.Listener, options ...ServeOption) error { // It's required to accommodate setups with high latency and where the time // between the TCP being accepted and the time for the first byte is longer // than the default value - 1s. - ReadDeadline: 10 * time.Second, + DetectTimeout: 10 * time.Second, } for _, opt := range options { opt(&muxConfig) diff --git a/lib/multiplexer/multiplexer.go b/lib/multiplexer/multiplexer.go index c064641565fda..b2548beb4e503 100644 --- a/lib/multiplexer/multiplexer.go +++ b/lib/multiplexer/multiplexer.go @@ -79,9 +79,9 @@ type Config struct { Listener net.Listener // Context is a context to signal stops, cancellations Context context.Context - // ReadDeadline is a connection read deadline, - // set to defaults.ReadHeadersTimeout if unspecified - ReadDeadline time.Duration + // DetectTimeout is a timeout applied to the whole detection phase of the + // connection, set to defaults.ReadHeadersTimeout if unspecified + DetectTimeout time.Duration // Clock is a clock to override in tests, set to real time clock // by default Clock clockwork.Clock @@ -98,6 +98,12 @@ type Config struct { // connection (coming from same IP as the listening address) when deciding if it should drop connection with // missing required PROXY header. This is needed since all connections in tests are self connections. IgnoreSelfConnections bool + + // FixedHeader contains data that's sent to the client at the beginning of + // every connection, before protocol detection happens. An equal amount of + // data is then skipped from the connection when the application writes into + // it. Mostly useful for SSH servers. + FixedHeader string } // CheckAndSetDefaults verifies configuration and sets defaults @@ -108,8 +114,8 @@ func (c *Config) CheckAndSetDefaults() error { if c.Context == nil { c.Context = context.TODO() } - if c.ReadDeadline == 0 { - c.ReadDeadline = defaults.ReadHeadersTimeout + if c.DetectTimeout == 0 { + c.DetectTimeout = defaults.ReadHeadersTimeout } if c.Clock == nil { c.Clock = clockwork.NewRealClock() @@ -277,13 +283,22 @@ func (m *Mux) protocolListener(proto Protocol) *Listener { // protocol without a registered protocol listener are closed. This // method is called as a goroutine by Serve for each connection. func (m *Mux) detectAndForward(conn net.Conn) { - err := conn.SetReadDeadline(m.Clock.Now().Add(m.ReadDeadline)) - if err != nil { + if err := conn.SetDeadline(m.Clock.Now().Add(m.DetectTimeout)); err != nil { m.Warning(err.Error()) conn.Close() return } + if m.FixedHeader != "" { + if _, err := conn.Write([]byte(m.FixedHeader)); err != nil { + if !utils.IsOKNetworkError(err) { + m.WithError(err).Warn("Failed to send connection header.") + } + conn.Close() + return + } + } + connWrapper, err := m.detect(conn) if err != nil { if trace.Unwrap(err) != io.EOF { @@ -295,8 +310,8 @@ func (m *Mux) detectAndForward(conn net.Conn) { conn.Close() return } - err = conn.SetReadDeadline(time.Time{}) - if err != nil { + + if err := connWrapper.SetDeadline(time.Time{}); err != nil { m.Warning(trace.DebugReport(err)) connWrapper.Close() return @@ -568,10 +583,11 @@ func (m *Mux) detect(conn net.Conn) (*Conn, error) { } return &Conn{ - protocol: proto, - Conn: conn, - reader: reader, - proxyLine: proxyLine, + protocol: proto, + Conn: conn, + reader: reader, + proxyLine: proxyLine, + alreadyWritten: []byte(m.FixedHeader), }, nil } } diff --git a/lib/multiplexer/multiplexer_test.go b/lib/multiplexer/multiplexer_test.go index fe5fe2852ccda..e5a82ae72ed44 100644 --- a/lib/multiplexer/multiplexer_test.go +++ b/lib/multiplexer/multiplexer_test.go @@ -357,7 +357,7 @@ func TestMux(t *testing.T) { Listener: listener, // Set read deadline in the past to remove reliance on real time // and simulate scenario when read deadline has elapsed. - ReadDeadline: -time.Millisecond, + DetectTimeout: -time.Millisecond, } mux, err := New(config) require.NoError(t, err) @@ -1393,3 +1393,53 @@ func TestIsDifferentTCPVersion(t *testing.T) { fmt.Sprintf("Unexpected result for %q, %q", tt.addr1, tt.addr2)) } } + +func TestFixedHeader(t *testing.T) { + t.Parallel() + require := require.New(t) + + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(err) + t.Cleanup(func() { listener.Close() }) + + const defaultSSHVersionIdentifier = "SSH-2.0-Go\r\n" + mux, err := New(Config{ + Listener: listener, + FixedHeader: defaultSSHVersionIdentifier, + }) + require.NoError(err) + t.Cleanup(func() { mux.Close() }) + go mux.Serve() + + go startSSHServer(t, mux.SSH()) + + netConn, err := net.DialTimeout(listener.Addr().Network(), listener.Addr().String(), 5*time.Second) + require.NoError(err) + t.Cleanup(func() { netConn.Close() }) + + // the SSH transport layer protocol rfc (5423) states that SSH servers must + // send a version string immediately after the connection is established, so + // we expect (a specific) version string without sending anything + buf := make([]byte, len(defaultSSHVersionIdentifier)) + _, err = io.ReadFull(netConn, buf) + require.NoError(err) + require.Equal(defaultSSHVersionIdentifier, string(buf)) + + // the SSH server hasn't even been touched yet, so we can connect to it from + // a separate connection (we have to, in fact, or startSSHServer will fail + // the test) + + sshClient, err := ssh.Dial(listener.Addr().Network(), listener.Addr().String(), &ssh.ClientConfig{ + User: "bob", + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 5 * time.Second, + }) + require.NoError(err) + t.Cleanup(func() { sshClient.Close() }) + + const payload = "this is a bit useless since we already went through a full handshake" + ok, echoReply, err := sshClient.Conn.SendRequest("echo", true, []byte(payload)) + require.NoError(err) + require.True(ok) + require.Equal(payload, string(echoReply)) +} diff --git a/lib/multiplexer/wrapper_test.go b/lib/multiplexer/wrapper_test.go index ba18822d44dce..1abdd5146a98c 100644 --- a/lib/multiplexer/wrapper_test.go +++ b/lib/multiplexer/wrapper_test.go @@ -22,6 +22,7 @@ import ( "testing" "time" + "github.com/gravitational/trace" "github.com/stretchr/testify/require" ) @@ -124,3 +125,33 @@ func TestPROXYEnabledListener_Accept(t *testing.T) { }) } } + +func TestAlreadyWritten(t *testing.T) { + require := require.New(t) + + c := &Conn{ + Conn: zeroConn{}, + alreadyWritten: []byte("aa"), + } + + n, err := c.Write([]byte("a")) + require.NoError(err) + require.Equal(1, n) + require.Equal([]byte("a"), c.alreadyWritten) + + n, err = c.Write([]byte("b")) + require.Error(err) + require.ErrorAs(err, new(*trace.BadParameterError)) + require.Equal(0, n) + + n, err = c.Write([]byte("ab")) + require.NoError(err) + require.Equal(2, n) + require.Empty(c.alreadyWritten) +} + +type zeroConn struct{ net.Conn } + +func (zeroConn) Write(p []byte) (int, error) { + return len(p), nil +} diff --git a/lib/multiplexer/wrappers.go b/lib/multiplexer/wrappers.go index c3f5995d15cb1..0ce209968ec29 100644 --- a/lib/multiplexer/wrappers.go +++ b/lib/multiplexer/wrappers.go @@ -18,6 +18,7 @@ package multiplexer import ( "bufio" + "bytes" "context" "net" @@ -35,6 +36,11 @@ type Conn struct { protocol Protocol proxyLine *ProxyLine reader *bufio.Reader + + // alreadyWritten is a slice of data that we expect the application to + // Write() on the connection (because it was already sent on the wire). As + // the application writes, the slice gets smaller. + alreadyWritten []byte } // NewConn returns a net.Conn wrapper that supports peeking into the connection. @@ -55,6 +61,31 @@ func (c *Conn) Read(p []byte) (int, error) { return c.reader.Read(p) } +// Write implements [io.Writer] and [net.Conn]. +func (c *Conn) Write(p []byte) (int, error) { + if len(c.alreadyWritten) < 1 { + return c.Conn.Write(p) + } + + s := min(len(p), len(c.alreadyWritten)) + if !bytes.Equal(p[:s], c.alreadyWritten[:s]) { + return 0, trace.BadParameter("new application data doesn't match already written data (this is a bug)") + } + + // we should do the write even if it's zero-length to check that the + // connection is still open and that we're not past the write deadline + n, err := c.Conn.Write(p[s:]) + if n > 0 || err == nil { + n += s + c.alreadyWritten = c.alreadyWritten[s:] + if len(c.alreadyWritten) < 1 { + c.alreadyWritten = nil + } + } + + return n, trace.Wrap(err) +} + // LocalAddr returns local address of the connection func (c *Conn) LocalAddr() net.Addr { if c.proxyLine != nil { diff --git a/lib/service/service.go b/lib/service/service.go index 8f661949fc74e..203cb1d239e4d 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -134,6 +134,7 @@ import ( "github.com/gravitational/teleport/lib/srv/ingress" "github.com/gravitational/teleport/lib/srv/regular" "github.com/gravitational/teleport/lib/srv/transport/transportv1" + "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/system" usagereporter "github.com/gravitational/teleport/lib/usagereporter/teleport" "github.com/gravitational/teleport/lib/utils" @@ -2662,6 +2663,7 @@ func (process *TeleportProcess) initSSH() error { ID: teleport.Component(teleport.ComponentNode, process.id), CertAuthorityGetter: authClient.GetCertAuthority, LocalClusterName: conn.ServerIdentity.ClusterName, + FixedHeader: sshutils.SSHVersionPrefix + "\r\n", }) if err != nil { return trace.Wrap(err)