Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 51 additions & 41 deletions lib/kube/proxy/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -1167,9 +1167,16 @@ func (f *Forwarder) join(ctx *authContext, w http.ResponseWriter, req *http.Requ
if err != nil {
return nil, trace.Wrap(err)
}

var stream *streamproto.SessionStream
// Close the stream when we exit to ensure no goroutines are leaked and
// to ensure the client gets a close message in case of an error.
defer func() {
if stream != nil {
stream.Close()
}
}()
if err := func() error {
stream, err := streamproto.NewSessionStream(ws, streamproto.ServerHandshake{MFARequired: session.PresenceEnabled})
stream, err = streamproto.NewSessionStream(ws, streamproto.ServerHandshake{MFARequired: session.PresenceEnabled})
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -1269,67 +1276,70 @@ func (f *Forwarder) remoteJoin(ctx *authContext, w http.ResponseWriter, req *htt
}
defer wsSource.Close()

err = wsProxy(wsSource, wsTarget)
if err != nil {
return nil, trace.Wrap(err)
}
wsProxy(f.log, wsSource, wsTarget)

return nil, nil
}

// wsProxy proxies a websocket connection between two clusters transparently to allow for
// remote joins.
func wsProxy(wsSource *websocket.Conn, wsTarget *websocket.Conn) error {
closeM := make(chan struct{})
errS := make(chan error)
errT := make(chan error)

go func() {
func wsProxy(log logrus.FieldLogger, wsSource *websocket.Conn, wsTarget *websocket.Conn) {
errS := make(chan error, 1)
errT := make(chan error, 1)
wg := &sync.WaitGroup{}

forwardConn := func(dst, src *websocket.Conn, errc chan<- error) {
defer dst.Close()
defer src.Close()
for {
ty, data, err := wsSource.ReadMessage()
msgType, msg, err := src.ReadMessage()
if err != nil {
wsSource.Close()
errS <- trace.Wrap(err)
return
m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, err.Error())
var e *websocket.CloseError
if errors.As(err, &e) {
if e.Code != websocket.CloseNoStatusReceived {
m = websocket.FormatCloseMessage(e.Code, e.Text)
}
}
errc <- err
dst.WriteMessage(websocket.CloseMessage, m)
break
}

wsTarget.WriteMessage(ty, data)

if ty == websocket.CloseMessage {
closeM <- struct{}{}
return
err = dst.WriteMessage(msgType, msg)
if err != nil {
errc <- err
break
}
}
}()
}

wg.Add(2)
go func() {
for {
ty, data, err := wsTarget.ReadMessage()
if err != nil {
wsTarget.Close()
errT <- trace.Wrap(err)
return
}

wsSource.WriteMessage(ty, data)

if ty == websocket.CloseMessage {
closeM <- struct{}{}
return
}
}
defer wg.Done()
forwardConn(wsSource, wsTarget, errS)
}()
go func() {
defer wg.Done()
forwardConn(wsTarget, wsSource, errT)
}()

var err error
var from, to string
select {
case err = <-errS:
wsTarget.WriteMessage(websocket.CloseMessage, []byte{})
from = "client"
to = "upstream"
case err = <-errT:
wsSource.WriteMessage(websocket.CloseMessage, []byte{})
case <-closeM:
from = "upstream"
to = "client"
}

return trace.Wrap(err)
var websocketErr *websocket.CloseError
if errors.As(err, &websocketErr) && websocketErr.Code == websocket.CloseAbnormalClosure {
log.WithError(err).Debugf("websocket proxy: Error when copying from %s to %s", from, to)
}
wg.Wait()
}

// acquireConnectionLock acquires a semaphore used to limit connections to the Kubernetes agent.
Expand Down
14 changes: 14 additions & 0 deletions lib/kube/proxy/streamproto/proto.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.
package streamproto

import (
"errors"
"fmt"
"io"
"sync"
"sync/atomic"
Expand Down Expand Up @@ -72,6 +74,7 @@ type SessionStream struct {
closed int32
MFARequired bool
Mode types.SessionParticipantMode
isClient bool
}

// NewSessionStream creates a new session stream.
Expand All @@ -87,6 +90,7 @@ func NewSessionStream(conn *websocket.Conn, handshake any) (*SessionStream, erro

clientHandshake, isClient := handshake.(ClientHandshake)
serverHandshake, ok := handshake.(ServerHandshake)
s.isClient = isClient

if !isClient && !ok {
return nil, trace.BadParameter("Handshake must be either client or server handshake, got %T", handshake)
Expand Down Expand Up @@ -167,6 +171,16 @@ func (s *SessionStream) readTask() {
log.WithError(err).Warn("Failed to read message from websocket")
}

var closeErr *websocket.CloseError
// If it's a close error, we want to send a message to the stdout
if s.isClient && errors.As(err, &closeErr) && closeErr.Text != "" {
select {
case s.in <- []byte(fmt.Sprintf("\r\n---\r\nConnection closed: %v\r\n", closeErr.Text)):
case <-s.done:
return
}
}

return
}

Expand Down
29 changes: 24 additions & 5 deletions lib/srv/termmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,21 +95,37 @@ func (g *TermManager) writeToClients(p []byte) {
g.history = truncateFront(append(g.history, p...), maxHistoryBytes)

atomic.AddUint64(&g.countWritten, uint64(len(p)))
var toDelete []struct {
key string
err error
}
for key, w := range g.writers {
_, err := w.Write(p)
if err != nil {
if err != io.EOF {
log.Warnf("Failed to write to remote terminal: %v", err)
}

// Let term manager decide how to handle broken party writers
if g.OnWriteError != nil {
g.OnWriteError(key, err)
}
toDelete = append(
toDelete, struct {
key string
err error
}{key, err})

delete(g.writers, key)
}
}

// Let term manager decide how to handle broken party writers
if g.OnWriteError != nil {
// writeToClients is called with the lock held, so we need to release it
// before calling OnWriteError to avoid a deadlock if OnWriteError
// calls DeleteWriter/DeleteReader.
g.mu.Unlock()
for _, deleteWriter := range toDelete {
g.OnWriteError(deleteWriter.key, deleteWriter.err)
}
g.mu.Lock()
}
}

func (g *TermManager) TerminateNotifier() <-chan struct{} {
Expand Down Expand Up @@ -213,7 +229,10 @@ func (g *TermManager) DeleteWriter(name string) {
}

func (g *TermManager) AddReader(name string, r io.Reader) {
// AddReader is called by goroutines so we need to hold the lock.
g.mu.Lock()
g.readerState[name] = false
g.mu.Unlock()

go func() {
for {
Expand Down