diff --git a/api/client/proxy/client.go b/api/client/proxy/client.go index 4653544760f53..8413e99293ebf 100644 --- a/api/client/proxy/client.go +++ b/api/client/proxy/client.go @@ -449,7 +449,7 @@ func (c *Client) ClusterDetails(ctx context.Context) (ClusterDetails, error) { // ProxyWindowsDesktopSession establishes a connection to the target desktop over a bidirectional stream. // The caller is required to pass a valid desktop certificate. -func (c *Client) ProxyWindowsDesktopSession(ctx context.Context, cluster string, desktopName string, windowsDesktopCert tls.Certificate, rootCAs *x509.CertPool) (net.Conn, error) { +func (c *Client) ProxyWindowsDesktopSession(ctx context.Context, cluster string, desktopName string, windowsDesktopCert tls.Certificate, rootCAs *x509.CertPool) (*tls.Conn, error) { session, err := c.transport.ProxyWindowsDesktopSession(ctx, cluster, desktopName, windowsDesktopCert, rootCAs) if err != nil { return nil, trace.Wrap(err) diff --git a/api/client/proxy/transport/transportv1/client.go b/api/client/proxy/transport/transportv1/client.go index 31cbba10ef55a..92d92fb59ebc3 100644 --- a/api/client/proxy/transport/transportv1/client.go +++ b/api/client/proxy/transport/transportv1/client.go @@ -74,7 +74,7 @@ const ( // ProxyWindowsDesktopSession establishes a connection to the target desktop over a bidirectional stream. // The caller is required to pass a valid desktop certificate. -func (c *Client) ProxyWindowsDesktopSession(ctx context.Context, cluster string, desktopName string, desktopCert tls.Certificate, rootCAs *x509.CertPool) (net.Conn, error) { +func (c *Client) ProxyWindowsDesktopSession(ctx context.Context, cluster string, desktopName string, desktopCert tls.Certificate, rootCAs *x509.CertPool) (*tls.Conn, error) { connCtx, cancel := context.WithCancel(context.WithoutCancel(ctx)) stop := context.AfterFunc(ctx, cancel) defer stop() @@ -99,7 +99,7 @@ func (c *Client) ProxyWindowsDesktopSession(ctx context.Context, cluster string, return nc, nil } -func (c *Client) dialProxyWindowsDesktopSession(ctx context.Context, cancel context.CancelFunc, stream grpc.BidiStreamingClient[transportv1pb.ProxyWindowsDesktopSessionRequest, transportv1pb.ProxyWindowsDesktopSessionResponse], cluster string, desktopName string, desktopCert tls.Certificate, rootCAs *x509.CertPool) (net.Conn, error) { +func (c *Client) dialProxyWindowsDesktopSession(ctx context.Context, cancel context.CancelFunc, stream grpc.BidiStreamingClient[transportv1pb.ProxyWindowsDesktopSessionRequest, transportv1pb.ProxyWindowsDesktopSessionResponse], cluster string, desktopName string, desktopCert tls.Certificate, rootCAs *x509.CertPool) (*tls.Conn, error) { err := stream.Send(&transportv1pb.ProxyWindowsDesktopSessionRequest{ DialTarget: &transportv1pb.TargetWindowsDesktop{ DesktopName: desktopName, diff --git a/lib/srv/desktop/rdp/rdpclient/client.go b/lib/srv/desktop/rdp/rdpclient/client.go index 5439ba2296301..4b9d160d2b3d9 100644 --- a/lib/srv/desktop/rdp/rdpclient/client.go +++ b/lib/srv/desktop/rdp/rdpclient/client.go @@ -438,19 +438,20 @@ func (c *Client) startInputStreaming(stopCh chan struct{}) error { return err } if m, ok := msg.(tdp.Ping); ok { - // Upon receiving a ping message, we make a connection - // to the host and send the same message back to the proxy. - // The proxy will then compute the round trip time. - if !disableDesktopPing { - conn, err := net.Dial("tcp", c.cfg.Addr) - if err == nil { - conn.Close() + go func() { + // Upon receiving a ping message, we make a connection + // to the host and send the same message back to the proxy. + // The proxy will then compute the round trip time. + if !disableDesktopPing { + conn, err := net.Dial("tcp", c.cfg.Addr) + if err == nil { + conn.Close() + } } - } - if err := c.cfg.Conn.WriteMessage(m); err != nil { - c.cfg.Logger.WarnContext(context.Background(), "Failed writing TDP ping message", "error", err) - return err - } + if err := c.cfg.Conn.WriteMessage(m); err != nil { + c.cfg.Logger.WarnContext(context.Background(), "Failed writing TDP ping message", "error", err) + } + }() continue } diff --git a/lib/srv/desktop/tdp/conn.go b/lib/srv/desktop/tdp/conn.go index 40239dd6f9782..021942709d8e7 100644 --- a/lib/srv/desktop/tdp/conn.go +++ b/lib/srv/desktop/tdp/conn.go @@ -20,7 +20,6 @@ package tdp import ( "bufio" - "context" "errors" "io" "net" @@ -37,6 +36,7 @@ import ( // Teleport Desktop Protocol (TDP) messages. type Conn struct { rwc io.ReadWriteCloser + writeMu sync.Mutex bufr *bufio.Reader closeOnce sync.Once @@ -107,7 +107,10 @@ func (c *Conn) WriteMessage(m Message) error { return trace.Wrap(err) } + c.writeMu.Lock() _, err = c.rwc.Write(buf) + c.writeMu.Unlock() + if c.OnSend != nil { c.OnSend(m, buf) } @@ -173,24 +176,22 @@ func IsFatalErr(err error) bool { // It accepts an optional serverInterceptor to intercept received messages. func NewConnProxy(client, server io.ReadWriteCloser, serverInterceptor ServerInterceptor) *ConnProxy { return &ConnProxy{ - client: client, - server: server, + client: NewConn(client), + server: NewConn(server), serverInterceptor: serverInterceptor, - messagesToClient: make(chan Message), } } // ConnProxy does a bidirectional copy between the connection to the client and the mTLS connection to the server. type ConnProxy struct { // client is a connection to the client (browser/Connect). - client io.ReadWriteCloser - // server io.ReadWriteCloser is a connection to the server (Windows Desktop Service). - server io.ReadWriteCloser + client *Conn + // server is a connection to the server (Windows Desktop Service). + server *Conn // serverInterceptor intercepts the incoming messages. // If the returned message is non-nil, it is forwarded to the client. // If an error is returned, the stream is canceled. serverInterceptor ServerInterceptor - messagesToClient chan Message } // ServerInterceptor intercepts messages received from the server. @@ -198,115 +199,57 @@ type ConnProxy struct { type ServerInterceptor func(serverTdpConn *Conn, message Message) (Message, error) // SendToClient sends a message to the client and blocks until the operation completes. -func (c *ConnProxy) SendToClient(ctx context.Context, message Message) error { - select { - case c.messagesToClient <- message: - return nil - case <-ctx.Done(): - return ctx.Err() - } +func (c *ConnProxy) SendToClient(message Message) error { + err := c.client.WriteMessage(message) + return trace.Wrap(err) +} + +// SendToServer sends a message to the server and blocks until the operation completes. +func (c *ConnProxy) SendToServer(message Message) error { + err := c.server.WriteMessage(message) + return trace.Wrap(err) } // Run starts proxying the connection. -func (c *ConnProxy) Run(ctx context.Context) error { - ctx, cancel := context.WithCancel(ctx) +func (c *ConnProxy) Run() error { + var errs errgroup.Group - var closeOnce sync.Once - closeAll := func() { - cancel() + closeAll := sync.OnceFunc(func() { c.client.Close() c.server.Close() - } - defer closeOnce.Do(closeAll) - - var errs errgroup.Group - - sendTDPAlert := func(err error, severity Severity) error { - msg := Alert{Message: err.Error(), Severity: severity} - b, err := msg.Encode() - if err != nil { - return trace.Wrap(err) - } - _, err = c.client.Write(b) - return trace.Wrap(err) - } - - // Run a goroutine to pick TDP messages up from a channel and send - // them to the client. - errs.Go(func() error { - defer closeOnce.Do(closeAll) - - for { - select { - case msg := <-c.messagesToClient: - encoded, err := msg.Encode() - if err != nil { - return trace.Wrap(err) - } - if _, err := c.client.Write(encoded); err != nil { - return trace.Wrap(err) - } - case <-ctx.Done(): - return ctx.Err() - } - } }) + defer closeAll() - // Run a second goroutine to read TDP messages from the Windows - // agent and write them to our send channel. + // Run a goroutine to read TDP messages from the Windows + // agent and write them to client. errs.Go(func() error { - defer closeOnce.Do(closeAll) + defer closeAll() // We avoid using io.Copy here, as we want to make sure // each TDP message is sent as a unit so that a single // 'message' event is emitted in the JS TDP client. // Internal buffer of io.Copy could split one message // into multiple downstreamConn.Send() calls. - tdpConn := NewConn(c.server) - defer tdpConn.Close() - - // we don't care about the content of the message, we just + // We don't care about the content of the message, we just // need to split the stream into individual messages and // write them to the client for { - msg, err := tdpConn.ReadMessage() - if utils.IsOKNetworkError(err) { - return trace.Wrap(err) - } - if err != nil { - isFatal := IsFatalErr(err) - severity := SeverityError - if !isFatal { - severity = SeverityWarning - } - sendErr := sendTDPAlert(err, severity) - - // If the error wasn't fatal, and we successfully - // sent it back to the client, continue. - if !isFatal && sendErr == nil { - continue - } + msg, err := c.server.ReadMessage() - // If the error was fatal, or we failed to send it back - // to the client, send it to the errCh channel and end - // the session. - if sendErr != nil { - err = sendErr - } - return trace.Wrap(err) + if err := c.handleError(err); err != nil { + return err } if c.serverInterceptor != nil { - msg, err = c.serverInterceptor(tdpConn, msg) + msg, err = c.serverInterceptor(c.server, msg) if err != nil { return trace.Wrap(err) } } if msg != nil { - select { - case c.messagesToClient <- msg: - case <-ctx.Done(): - return ctx.Err() + err := c.SendToClient(msg) + if err != nil { + return trace.Wrap(err) } } } @@ -315,9 +258,18 @@ func (c *ConnProxy) Run(ctx context.Context) error { // Run a goroutine to read TDP messages coming from the client // and pass them on to the Windows agent. errs.Go(func() error { - defer closeOnce.Do(closeAll) - _, err := io.Copy(c.server, c.client) - return trace.Wrap(err, "sending TDP message to desktop agent") + defer closeAll() + for { + msg, err := c.client.ReadMessage() + + if err := c.handleError(err); err != nil { + return err + } + + if err := c.SendToServer(msg); err != nil { + return trace.Wrap(err) + } + } }) // Wait for all goroutines to finish @@ -327,3 +279,31 @@ func (c *ConnProxy) Run(ctx context.Context) error { return nil } + +func (c *ConnProxy) handleError(err error) error { + if err == nil { + return nil + } + if utils.IsOKNetworkError(err) { + return trace.Wrap(err) + } + isFatal := IsFatalErr(err) + severity := SeverityError + if !isFatal { + severity = SeverityWarning + } + sendErr := c.SendToClient(Alert{Message: err.Error(), Severity: severity}) + + // If the error wasn't fatal, and we successfully + // sent it back to the client, continue. + if !isFatal && sendErr == nil { + return nil + } + + // If the error was fatal, or we failed to send it back + // to the client, return it and end the session. + if sendErr != nil { + err = sendErr + } + return trace.Wrap(err) +} diff --git a/lib/teleterm/services/desktop/desktop.go b/lib/teleterm/services/desktop/desktop.go index 4b71beb6c508f..2ae2e1cc065f4 100644 --- a/lib/teleterm/services/desktop/desktop.go +++ b/lib/teleterm/services/desktop/desktop.go @@ -155,7 +155,7 @@ func (s *Session) Start(ctx context.Context, stream grpc.BidiStreamingServer[api return msg, nil }) - return trace.Wrap(tdpConnProxy.Run(ctx)) + return trace.Wrap(tdpConnProxy.Run()) } // clientStream implements the [streamutils.Source] interface diff --git a/lib/web/desktop.go b/lib/web/desktop.go index 144f8ee25362b..80d6aa96a47d9 100644 --- a/lib/web/desktop.go +++ b/lib/web/desktop.go @@ -24,7 +24,6 @@ import ( "crypto/tls" "errors" "log/slog" - "net" "net/http" "github.com/google/uuid" @@ -456,20 +455,15 @@ func readClientScreenSpec(ws *websocket.Conn) (*tdp.ClientScreenSpec, error) { // desktopPinger measures latency between proxy and the desktop by sending tdp.Ping messages // Windows Desktop Service and measuring the time it takes to receive message with the same UUID back. type desktopPinger struct { - wds net.Conn - ch <-chan tdp.Ping + proxy *tdp.ConnProxy + ch <-chan tdp.Ping } func (d desktopPinger) Ping(ctx context.Context) error { ping := tdp.Ping{ UUID: uuid.New(), } - buf, err := ping.Encode() - if err != nil { - return trace.Wrap(err) - } - _, err = d.wds.Write(buf) - if err != nil { + if err := d.proxy.SendToServer(ping); err != nil { return trace.Wrap(err) } for { @@ -487,7 +481,7 @@ func (d desktopPinger) Ping(ctx context.Context) error { // proxyWebsocketConn does a bidrectional copy between the websocket // connection to the browser (ws) and the mTLS connection to Windows // Desktop Serivce (wds) -func proxyWebsocketConn(ctx context.Context, ws *websocket.Conn, wds net.Conn, log *slog.Logger, version string) error { +func proxyWebsocketConn(ctx context.Context, ws *websocket.Conn, wds *tls.Conn, log *slog.Logger, version string) error { ctx, cancel := context.WithCancel(ctx) defer func() { cancel() @@ -504,25 +498,28 @@ func proxyWebsocketConn(ctx context.Context, ws *websocket.Conn, wds net.Conn, l tdpConnProxy := tdp.NewConnProxy(&WebsocketIO{Conn: ws}, wds, func(_ *tdp.Conn, msg tdp.Message) (tdp.Message, error) { if ping, ok := msg.(tdp.Ping); ok { - pings <- ping + if !latencySupported { + return nil, trace.BadParameter("received unexpected Ping message from server (this is a bug)") + } + select { + case pings <- ping: + case <-ctx.Done(): + } return nil, nil } - - if ls, ok := msg.(tdp.LatencyStats); ok { - log.DebugContext(ctx, "sending latency stats", "client", ls.ClientLatency, "server", ls.ServerLatency) - } return msg, nil }) if latencySupported { pinger := desktopPinger{ - wds: wds, - ch: pings, + proxy: tdpConnProxy, + ch: pings, } go monitorLatency(ctx, clockwork.NewRealClock(), ws, pinger, latency.ReporterFunc(func(ctx context.Context, stats latency.Statistics) error { - return trace.Wrap(tdpConnProxy.SendToClient(ctx, tdp.LatencyStats{ + log.DebugContext(ctx, "sending latency stats", "client", stats.Client, "server", stats.Server) + return trace.Wrap(tdpConnProxy.SendToClient(tdp.LatencyStats{ ClientLatency: uint32(stats.Client), ServerLatency: uint32(stats.Server), })) @@ -531,7 +528,7 @@ func proxyWebsocketConn(ctx context.Context, ws *websocket.Conn, wds net.Conn, l } - return trace.Wrap(tdpConnProxy.Run(ctx)) + return trace.Wrap(tdpConnProxy.Run()) } // handleProxyWebsocketConnErr handles the error returned by proxyWebsocketConn by