diff --git a/lib/multiplexer/singleplexer.go b/lib/multiplexer/singleplexer.go new file mode 100644 index 0000000000000..4efd07b344059 --- /dev/null +++ b/lib/multiplexer/singleplexer.go @@ -0,0 +1,241 @@ +// Teleport +// Copyright (C) 2023 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package multiplexer + +import ( + "bufio" + "context" + "net" + "sync" + "time" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/utils" +) + +type connectionLimiter interface { + AcquireConnection(string) error + ReleaseConnection(string) +} + +func RunSingleplexer[B ~string | ~[]byte](ctx context.Context, + listener net.Listener, + handleConn func(net.Conn), + earlyData B, + getCA CertAuthorityGetter, clusterName string, + limiter connectionLimiter, +) { + listenCtx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + // unblock the Accept by closing the listener when the context is done + <-listenCtx.Done() + _ = listener.Close() + }() + + for { + c, err := listener.Accept() + if err == nil { + go handleSingleplexedConn(ctx, c, handleConn, earlyData, getCA, clusterName, limiter) + continue + } + if listenCtx.Err() != nil || utils.IsUseOfClosedNetworkError(err) { + break + } + backoff := 5 * time.Second + if tErr, ok := err.(interface{ Temporary() bool }); ok && tErr.Temporary() { + backoff = 100 * time.Millisecond + } + select { + case <-listenCtx.Done(): + break + case <-time.After(backoff): + } + } +} + +func handleSingleplexedConn[B ~string | ~[]byte](ctx context.Context, + c net.Conn, + handleConn func(net.Conn), + earlyData B, + getCA CertAuthorityGetter, clusterName string, + limiter connectionLimiter, +) { + defer func() { + if c != nil { + c.Close() + } + }() + + // copied from [multiplexer.Mux.Serve()] + if t, ok := c.(*net.TCPConn); ok { + _ = t.SetKeepAlive(true) + _ = t.SetKeepAlivePeriod(3 * time.Minute) + } + + _ = c.SetDeadline(time.Now().Add(defaults.ReadHeadersTimeout)) + + // XXX: this makes the same assumption regarding the availability of a small + // write buffer that [ssh.NewServerConn] makes. It's not great as it limits + // the use of synchronous connections like [net.Pipe], but doing it in + // parallel makes the code quite a bit more complicated. + if len(earlyData) > 0 { + if _, err := c.Write([]byte(earlyData)); err != nil { + return + } + } + + reader := bufio.NewReader(c) + isProxyV2, err := readerHasPrefix(reader, ProxyV2Prefix) + if err != nil { + // errors on Peek(), almost surely I/O + return + } + + var remoteAddr net.Addr + var limiterToken string + + if isProxyV2 { + proxyline, err := ReadProxyLineV2(reader) + if err != nil { + // mostly I/O errors + return + } + if proxyline == nil { + // we shouldn't honor LOCAL proxylines + return + } + if err := proxyline.VerifySignature(ctx, + getCA, clusterName, + clockwork.NewRealClock(), + ); err != nil { + // bad signature + return + } + remoteAddr = &proxyline.Source + limiterToken = proxyline.Source.IP.String() + } else if r := c.RemoteAddr(); r != nil { + limiterToken = r.String() + if host, _, err := utils.SplitHostPort(c.RemoteAddr().String()); err == nil { + limiterToken = host + } + } + + if limiter != nil { + if err := limiter.AcquireConnection(limiterToken); err != nil { + return + } + defer limiter.ReleaseConnection(limiterToken) + } + + _ = c.SetDeadline(time.Time{}) + + wrapped := &singleplexedConn{ + Conn: c, + remoteAddr: remoteAddr, + reader: reader, + skip: len(earlyData), + } + + // handing the connection over, disable the defer + c = nil + + handleConn(wrapped) +} + +type singleplexedConn struct { + net.Conn + + remoteAddr net.Addr + + readMu sync.Mutex + reader *bufio.Reader + + writeMu sync.Mutex + skip int +} + +// Close implements [io.Closer] and [net.Conn]. +func (c *singleplexedConn) Close() error { + err := trace.Wrap(c.Conn.Close()) + + c.readMu.Lock() + defer c.readMu.Unlock() + _, _ = c.reader.Discard(c.reader.Buffered()) + + return err +} + +// Read implements [io.Reader] and [net.Conn]. +func (c *singleplexedConn) Read(b []byte) (int, error) { + c.readMu.Lock() + defer c.readMu.Unlock() + + return c.reader.Read(b) +} + +// Write implements [io.Writer] and [net.Conn]. +func (c *singleplexedConn) Write(b []byte) (int, error) { + c.writeMu.Lock() + if c.skip < 1 { + c.writeMu.Unlock() + return c.Conn.Write(b) + } + defer c.writeMu.Unlock() + + if len(b) <= c.skip { + // check if the connection is open and not past the write deadline + _, err := c.Conn.Write(nil) + if err != nil { + return 0, trace.Wrap(err) + } + c.skip -= len(b) + return len(b), nil + } + + b = b[c.skip:] + n, err := c.Conn.Write(b) + if n > 0 { + n += c.skip + c.skip = 0 + } + return n, trace.Wrap(err) +} + +// RemoteAddr implements [net.Conn]. +func (c *singleplexedConn) RemoteAddr() net.Addr { + if c.remoteAddr != nil { + return c.remoteAddr + } + return c.Conn.RemoteAddr() +} + +func readerHasPrefix[B ~[]byte | ~string](r *bufio.Reader, prefix B) (bool, error) { + for i, b := range []byte(prefix) { + buf, err := r.Peek(i + 1) + if err != nil { + return false, trace.Wrap(err) + } + if buf[i] != b { + return false, nil + } + } + return true, nil +} diff --git a/lib/multiplexer/wrappers.go b/lib/multiplexer/wrappers.go index 26f508fa5fce6..633ee242b4973 100644 --- a/lib/multiplexer/wrappers.go +++ b/lib/multiplexer/wrappers.go @@ -118,6 +118,13 @@ func newListener(parent context.Context, addr net.Addr) *Listener { } } +// NewListener returns an artificial [net.Listener] that pretends to be +// listening the given address and that receives connections through its +// HandleConnection method. +func NewListener(parent context.Context, addr net.Addr) *Listener { + return newListener(parent, addr) +} + // Listener is a listener that receives // connections from multiplexer based on the connection type type Listener struct { diff --git a/lib/service/service.go b/lib/service/service.go index c37a25bfcde58..897e068cb1bbe 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -137,6 +137,7 @@ import ( "github.com/gravitational/teleport/lib/srv/ingress" "github.com/gravitational/teleport/lib/srv/regular" "github.com/gravitational/teleport/lib/srv/transport/transportv1" + "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/system" usagereporter "github.com/gravitational/teleport/lib/usagereporter/teleport" "github.com/gravitational/teleport/lib/utils" @@ -2664,27 +2665,18 @@ func (process *TeleportProcess) initSSH() error { log.Infof("Service %s:%s is starting on %v %v.", teleport.Version, teleport.Gitref, cfg.SSH.Addr.Addr, process.Config.CachePolicy) - // Use multiplexer to leverage support for signed PROXY protocol headers. - mux, err := multiplexer.New(multiplexer.Config{ - Context: process.ExitContext(), - PROXYProtocolMode: multiplexer.PROXYProtocolOff, - Listener: listener, - ID: teleport.Component(teleport.ComponentNode, process.id), - CertAuthorityGetter: authClient.GetCertAuthority, - LocalClusterName: conn.ServerIdentity.ClusterName, - }) - if err != nil { - return trace.Wrap(err) - } + wrapper := multiplexer.NewListener(process.ExitContext(), listener.Addr()) + defer wrapper.Close() - go func() { - if err := mux.Serve(); err != nil && !utils.IsOKNetworkError(err) { - mux.Entry.WithError(err).Error("node ssh multiplexer terminated unexpectedly") - } - }() - defer mux.Close() + go multiplexer.RunSingleplexer(process.ExitContext(), + listener, + func(c net.Conn) { wrapper.HandleConnection(process.ExitContext(), c) }, + sshutils.SSHVersionPrefix+"\r\n", + authClient.GetCertAuthority, conn.ServerIdentity.ClusterName, + limiter, + ) - go s.Serve(limiter.WrapListener(mux.SSH())) + go s.Serve(wrapper) } else { // Start the SSH server. This kicks off updating labels and starting the // heartbeat.