diff --git a/lib/multiplexer/multiplexer.go b/lib/multiplexer/multiplexer.go index 2bd01d80e5020..d3c1527acd579 100644 --- a/lib/multiplexer/multiplexer.go +++ b/lib/multiplexer/multiplexer.go @@ -132,13 +132,14 @@ type Mux struct { sync.RWMutex *log.Entry Config - sshListener *Listener - tlsListener *Listener - dbListener *Listener - context context.Context - cancel context.CancelFunc - waitContext context.Context - waitCancel context.CancelFunc + sshListener *Listener + tlsListener *Listener + dbListener *Listener + httpListener *Listener + context context.Context + cancel context.CancelFunc + waitContext context.Context + waitCancel context.CancelFunc // logLimiter is a goroutine responsible for deduplicating multiplexer errors // (over a 1min window) that occur when detecting the types of new connections. // This ensures that health checkers / malicious actors cannot overpower / @@ -177,6 +178,16 @@ func (m *Mux) DB() net.Listener { return m.dbListener } +// HTTP returns listener that receives plain HTTP connections +func (m *Mux) HTTP() net.Listener { + m.Lock() + defer m.Unlock() + if m.httpListener == nil { + m.httpListener = newListener(m.context, m.Config.Listener.Addr()) + } + return m.httpListener +} + func (m *Mux) closeListener() { m.Lock() defer m.Unlock() @@ -248,6 +259,9 @@ func (m *Mux) protocolListener(proto Protocol) *Listener { return m.sshListener case ProtoPostgres: return m.dbListener + case ProtoHTTP: + return m.httpListener + } return nil } diff --git a/lib/multiplexer/multiplexer_test.go b/lib/multiplexer/multiplexer_test.go index c24d20ccba693..1a7a495481007 100644 --- a/lib/multiplexer/multiplexer_test.go +++ b/lib/multiplexer/multiplexer_test.go @@ -128,6 +128,48 @@ func TestMux(t *testing.T) { } require.NotNil(t, err) }) + t.Run("HTTP", func(t *testing.T) { + t.Parallel() + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + mux, err := New(Config{ + Listener: listener, + }) + require.NoError(t, err) + go mux.Serve() + defer mux.Close() + + backend1 := &httptest.Server{ + Listener: mux.HTTP(), + Config: &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "backend 1") + }), + }, + } + backend1.Start() + defer backend1.Close() + + re, err := http.Get(backend1.URL) + require.NoError(t, err) + defer re.Body.Close() + bytes, err := io.ReadAll(re.Body) + require.NoError(t, err) + require.Equal(t, "backend 1", string(bytes)) + + // Close mux, new requests should fail + mux.Close() + mux.Wait() + + // Use new client to use new connection pool + client := &http.Client{Transport: &http.Transport{}} + re, err = client.Get(backend1.URL) + if err == nil { + re.Body.Close() + } + require.Error(t, err) + }) // ProxyLine tests proxy line protocol t.Run("ProxyLines", func(t *testing.T) { t.Parallel() diff --git a/lib/service/service.go b/lib/service/service.go index eee5b3abcf043..dcc75673930a2 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -3093,15 +3093,32 @@ func (process *TeleportProcess) initDiagnosticService() error { log.Infof("Starting diagnostic service on %v.", process.Config.DiagnosticAddr.Addr) + muxListener, err := multiplexer.New(multiplexer.Config{ + Context: process.ExitContext(), + Listener: listener, + EnableExternalProxyProtocol: true, + ID: teleport.Component(teleport.ComponentDiagnostic), + }) + if err != nil { + return trace.Wrap(err) + } + process.RegisterFunc("diagnostic.service", func() error { - err := server.Serve(listener) - if err != nil && err != http.ErrServerClosed { + listenerHTTP := muxListener.HTTP() + go func() { + if err := muxListener.Serve(); err != nil && !utils.IsOKNetworkError(err) { + muxListener.Entry.WithError(err).Error("Mux encountered err serving") + } + }() + + if err := server.Serve(listenerHTTP); !errors.Is(err, http.ErrServerClosed) { log.Warningf("Diagnostic server exited with error: %v.", err) } return nil }) process.OnExit("diagnostic.shutdown", func(payload interface{}) { + warnOnErr(muxListener.Close(), log) if payload == nil { log.Infof("Shutting down immediately.") warnOnErr(server.Close(), log)