diff --git a/lib/kube/proxy/streamproto/proto.go b/lib/kube/proxy/streamproto/proto.go index c333b24091427..871cb857a92db 100644 --- a/lib/kube/proxy/streamproto/proto.go +++ b/lib/kube/proxy/streamproto/proto.go @@ -74,7 +74,7 @@ type SessionStream struct { writeSync sync.Mutex done chan struct{} closeOnce sync.Once - closed int32 + closed atomic.Bool MFARequired bool Mode types.SessionParticipantMode isClient bool @@ -215,7 +215,7 @@ func (s *SessionStream) readTask() { if ty == websocket.CloseMessage { s.conn.Close() - atomic.StoreInt32(&s.closed, 1) + s.closed.Store(true) return } } @@ -236,11 +236,7 @@ func (s *SessionStream) Read(p []byte) (int, error) { } func (s *SessionStream) Write(data []byte) (int, error) { - s.writeSync.Lock() - defer s.writeSync.Unlock() - - err := s.conn.WriteMessage(websocket.BinaryMessage, data) - if err != nil { + if err := s.write(websocket.BinaryMessage, data); err != nil { return 0, trace.Wrap(err) } @@ -255,9 +251,7 @@ func (s *SessionStream) Resize(size *remotecommand.TerminalSize) error { return trace.Wrap(err) } - s.writeSync.Lock() - defer s.writeSync.Unlock() - return trace.Wrap(s.conn.WriteMessage(websocket.TextMessage, json)) + return trace.Wrap(s.write(websocket.TextMessage, json)) } // ResizeQueue returns a channel that will receive resize requests. @@ -278,32 +272,37 @@ func (s *SessionStream) ForceTerminate() error { return trace.Wrap(err) } - s.writeSync.Lock() - defer s.writeSync.Unlock() - - return trace.Wrap(s.conn.WriteMessage(websocket.TextMessage, json)) + return trace.Wrap(s.write(websocket.TextMessage, json)) } func (s *SessionStream) Done() <-chan struct{} { return s.done } +func (s *SessionStream) write(messageType int, data []byte) error { + s.writeSync.Lock() + defer s.writeSync.Unlock() + + return trace.Wrap(s.conn.WriteMessage(messageType, data)) +} + // Close closes the stream. func (s *SessionStream) Close() error { - if atomic.CompareAndSwapInt32(&s.closed, 0, 1) { - err := s.conn.WriteMessage(websocket.CloseMessage, []byte{}) - if err != nil { - slog.WarnContext(context.Background(), "Failed to gracefully close websocket connection", "error", err) - } - t := time.NewTimer(time.Second * 5) - defer t.Stop() - select { - case <-s.done: - case <-t.C: - s.conn.Close() - } - s.closeOnce.Do(func() { close(s.done) }) + if !s.closed.CompareAndSwap(false, true) { + return nil + } + + if err := s.write(websocket.CloseMessage, nil); err != nil { + slog.WarnContext(context.Background(), "Failed to gracefully close websocket connection", "error", err) + } + + var err error + select { + case <-s.done: + case <-time.After(5 * time.Second): + err = s.conn.Close() } + s.closeOnce.Do(func() { close(s.done) }) - return nil + return trace.Wrap(err) }