From dc434431655329e2d2974436b562fcf63837f69b Mon Sep 17 00:00:00 2001 From: Tiago Silva Date: Thu, 18 Jan 2024 19:03:58 +0000 Subject: [PATCH] Prevent deadlock on moderated sessions when mod connection drops This PR removes a deadlock caused by the moderator leaving the session because his connection drops. `OnWriteError` was called under lock which creates an issue if the function calls `TermManager.DeleteWriter` to exclude the writer from the term manager. This PR also correctly forwards errors to the clients when they occur. Changelog: Ensure that moderated sessions do not get stuck in the event of an unexpected drop in the moderator's connection. Fixes #36881 Signed-off-by: Tiago Silva --- lib/kube/proxy/forwarder.go | 92 ++++++++++++++++------------- lib/kube/proxy/streamproto/proto.go | 14 +++++ lib/srv/termmanager.go | 29 +++++++-- 3 files changed, 89 insertions(+), 46 deletions(-) diff --git a/lib/kube/proxy/forwarder.go b/lib/kube/proxy/forwarder.go index b374399c99ef6..25545ff43b56d 100644 --- a/lib/kube/proxy/forwarder.go +++ b/lib/kube/proxy/forwarder.go @@ -1339,9 +1339,16 @@ func (f *Forwarder) join(ctx *authContext, w http.ResponseWriter, req *http.Requ if err != nil { return nil, trace.Wrap(err) } - + var stream *streamproto.SessionStream + // Close the stream when we exit to ensure no goroutines are leaked and + // to ensure the client gets a close message in case of an error. 
+ defer func() { + if stream != nil { + stream.Close() + } + }() if err := func() error { - stream, err := streamproto.NewSessionStream(ws, streamproto.ServerHandshake{MFARequired: session.PresenceEnabled}) + stream, err = streamproto.NewSessionStream(ws, streamproto.ServerHandshake{MFARequired: session.PresenceEnabled}) if err != nil { return trace.Wrap(err) } @@ -1444,67 +1451,70 @@ func (f *Forwarder) remoteJoin(ctx *authContext, w http.ResponseWriter, req *htt } defer wsSource.Close() - err = wsProxy(wsSource, wsTarget) - if err != nil { - return nil, trace.Wrap(err) - } + wsProxy(f.log, wsSource, wsTarget) return nil, nil } // wsProxy proxies a websocket connection between two clusters transparently to allow for // remote joins. -func wsProxy(wsSource *websocket.Conn, wsTarget *websocket.Conn) error { - closeM := make(chan struct{}) - errS := make(chan error) - errT := make(chan error) - - go func() { +func wsProxy(log logrus.FieldLogger, wsSource *websocket.Conn, wsTarget *websocket.Conn) { + errS := make(chan error, 1) + errT := make(chan error, 1) + wg := &sync.WaitGroup{} + + forwardConn := func(dst, src *websocket.Conn, errc chan<- error) { + defer dst.Close() + defer src.Close() for { - ty, data, err := wsSource.ReadMessage() + msgType, msg, err := src.ReadMessage() if err != nil { - wsSource.Close() - errS <- trace.Wrap(err) - return + m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error()) + var e *websocket.CloseError + if errors.As(err, &e) { + if e.Code != websocket.CloseNoStatusReceived { + m = websocket.FormatCloseMessage(e.Code, e.Text) + } + } + errc <- err + dst.WriteMessage(websocket.CloseMessage, m) + break } - wsTarget.WriteMessage(ty, data) - - if ty == websocket.CloseMessage { - closeM <- struct{}{} - return + err = dst.WriteMessage(msgType, msg) + if err != nil { + errc <- err + break } } - }() + } + wg.Add(2) go func() { - for { - ty, data, err := wsTarget.ReadMessage() - if err != nil { - wsTarget.Close() - errT <- 
trace.Wrap(err) - return - } - - wsSource.WriteMessage(ty, data) - - if ty == websocket.CloseMessage { - closeM <- struct{}{} - return - } - } + defer wg.Done() + forwardConn(wsSource, wsTarget, errS) + }() + go func() { + defer wg.Done() + forwardConn(wsTarget, wsSource, errT) }() var err error + var from, to string select { case err = <-errS: - wsTarget.WriteMessage(websocket.CloseMessage, []byte{}) + from = "client" + to = "upstream" case err = <-errT: - wsSource.WriteMessage(websocket.CloseMessage, []byte{}) - case <-closeM: + from = "upstream" + to = "client" } - return trace.Wrap(err) + var websocketErr *websocket.CloseError + if errors.As(err, &websocketErr) && websocketErr.Code == websocket.CloseAbnormalClosure { + log.WithError(err).Debugf("websocket proxy: Error when copying from %s to %s", from, to) + } + wg.Wait() } // acquireConnectionLock acquires a semaphore used to limit connections to the Kubernetes agent. diff --git a/lib/kube/proxy/streamproto/proto.go b/lib/kube/proxy/streamproto/proto.go index 957a8c092f391..e11c34b815d97 100644 --- a/lib/kube/proxy/streamproto/proto.go +++ b/lib/kube/proxy/streamproto/proto.go @@ -17,6 +17,8 @@ limitations under the License. package streamproto import ( + "errors" + "fmt" "io" "sync" "sync/atomic" @@ -72,6 +74,7 @@ type SessionStream struct { closed int32 MFARequired bool Mode types.SessionParticipantMode + isClient bool } // NewSessionStream creates a new session stream. 
@@ -87,6 +90,7 @@ func NewSessionStream(conn *websocket.Conn, handshake any) (*SessionStream, erro clientHandshake, isClient := handshake.(ClientHandshake) serverHandshake, ok := handshake.(ServerHandshake) + s.isClient = isClient if !isClient && !ok { return nil, trace.BadParameter("Handshake must be either client or server handshake, got %T", handshake) @@ -167,6 +171,16 @@ func (s *SessionStream) readTask() { log.WithError(err).Warn("Failed to read message from websocket") } + var closeErr *websocket.CloseError + // If it's a close error, we want to send a message to the stdout + if s.isClient && errors.As(err, &closeErr) && closeErr.Text != "" { + select { + case s.in <- []byte(fmt.Sprintf("\r\n---\r\nConnection closed: %v\r\n", closeErr.Text)): + case <-s.done: + return + } + } + return } diff --git a/lib/srv/termmanager.go b/lib/srv/termmanager.go index 7053a4708912a..0f72352573ce2 100644 --- a/lib/srv/termmanager.go +++ b/lib/srv/termmanager.go @@ -95,21 +95,37 @@ func (g *TermManager) writeToClients(p []byte) { g.history = truncateFront(append(g.history, p...), maxHistoryBytes) atomic.AddUint64(&g.countWritten, uint64(len(p))) + var toDelete []struct { + key string + err error + } for key, w := range g.writers { _, err := w.Write(p) if err != nil { if err != io.EOF { log.Warnf("Failed to write to remote terminal: %v", err) } - - // Let term manager decide how to handle broken party writers - if g.OnWriteError != nil { - g.OnWriteError(key, err) - } + toDelete = append( + toDelete, struct { + key string + err error + }{key, err}) delete(g.writers, key) } } + + // Let term manager decide how to handle broken party writers + if g.OnWriteError != nil { + // writeToClients is called with the lock held, so we need to release it + // before calling OnWriteError to avoid a deadlock if OnWriteError + // calls DeleteWriter/DeleteReader. 
+ g.mu.Unlock() + for _, deleteWriter := range toDelete { + g.OnWriteError(deleteWriter.key, deleteWriter.err) + } + g.mu.Lock() + } } func (g *TermManager) TerminateNotifier() <-chan struct{} { @@ -213,7 +229,10 @@ func (g *TermManager) DeleteWriter(name string) { } func (g *TermManager) AddReader(name string, r io.Reader) { + // AddReader is called by goroutines so we need to hold the lock. + g.mu.Lock() g.readerState[name] = false + g.mu.Unlock() go func() { for {