Skip to content
2 changes: 1 addition & 1 deletion api/client/proxy/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ func (c *Client) ClusterDetails(ctx context.Context) (ClusterDetails, error) {

// ProxyWindowsDesktopSession establishes a connection to the target desktop over a bidirectional stream.
// The caller is required to pass a valid desktop certificate.
func (c *Client) ProxyWindowsDesktopSession(ctx context.Context, cluster string, desktopName string, windowsDesktopCert tls.Certificate, rootCAs *x509.CertPool) (net.Conn, error) {
func (c *Client) ProxyWindowsDesktopSession(ctx context.Context, cluster string, desktopName string, windowsDesktopCert tls.Certificate, rootCAs *x509.CertPool) (*tls.Conn, error) {
session, err := c.transport.ProxyWindowsDesktopSession(ctx, cluster, desktopName, windowsDesktopCert, rootCAs)
if err != nil {
return nil, trace.Wrap(err)
Expand Down
4 changes: 2 additions & 2 deletions api/client/proxy/transport/transportv1/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ const (

// ProxyWindowsDesktopSession establishes a connection to the target desktop over a bidirectional stream.
// The caller is required to pass a valid desktop certificate.
func (c *Client) ProxyWindowsDesktopSession(ctx context.Context, cluster string, desktopName string, desktopCert tls.Certificate, rootCAs *x509.CertPool) (net.Conn, error) {
func (c *Client) ProxyWindowsDesktopSession(ctx context.Context, cluster string, desktopName string, desktopCert tls.Certificate, rootCAs *x509.CertPool) (*tls.Conn, error) {
connCtx, cancel := context.WithCancel(context.WithoutCancel(ctx))
stop := context.AfterFunc(ctx, cancel)
defer stop()
Expand All @@ -99,7 +99,7 @@ func (c *Client) ProxyWindowsDesktopSession(ctx context.Context, cluster string,
return nc, nil
}

func (c *Client) dialProxyWindowsDesktopSession(ctx context.Context, cancel context.CancelFunc, stream grpc.BidiStreamingClient[transportv1pb.ProxyWindowsDesktopSessionRequest, transportv1pb.ProxyWindowsDesktopSessionResponse], cluster string, desktopName string, desktopCert tls.Certificate, rootCAs *x509.CertPool) (net.Conn, error) {
func (c *Client) dialProxyWindowsDesktopSession(ctx context.Context, cancel context.CancelFunc, stream grpc.BidiStreamingClient[transportv1pb.ProxyWindowsDesktopSessionRequest, transportv1pb.ProxyWindowsDesktopSessionResponse], cluster string, desktopName string, desktopCert tls.Certificate, rootCAs *x509.CertPool) (*tls.Conn, error) {
err := stream.Send(&transportv1pb.ProxyWindowsDesktopSessionRequest{
DialTarget: &transportv1pb.TargetWindowsDesktop{
DesktopName: desktopName,
Expand Down
25 changes: 13 additions & 12 deletions lib/srv/desktop/rdp/rdpclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,19 +438,20 @@ func (c *Client) startInputStreaming(stopCh chan struct{}) error {
return err
}
if m, ok := msg.(tdp.Ping); ok {
// Upon receiving a ping message, we make a connection
// to the host and send the same message back to the proxy.
// The proxy will then compute the round trip time.
if !disableDesktopPing {
conn, err := net.Dial("tcp", c.cfg.Addr)
if err == nil {
conn.Close()
go func() {
// Upon receiving a ping message, we make a connection
// to the host and send the same message back to the proxy.
// The proxy will then compute the round trip time.
if !disableDesktopPing {
conn, err := net.Dial("tcp", c.cfg.Addr)
if err == nil {
conn.Close()
}
}
}
if err := c.cfg.Conn.WriteMessage(m); err != nil {
c.cfg.Logger.WarnContext(context.Background(), "Failed writing TDP ping message", "error", err)
return err
}
if err := c.cfg.Conn.WriteMessage(m); err != nil {
c.cfg.Logger.WarnContext(context.Background(), "Failed writing TDP ping message", "error", err)
}
}()
continue
}

Expand Down
166 changes: 73 additions & 93 deletions lib/srv/desktop/tdp/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ package tdp

import (
"bufio"
"context"
"errors"
"io"
"net"
Expand All @@ -37,6 +36,7 @@ import (
// Teleport Desktop Protocol (TDP) messages.
type Conn struct {
rwc io.ReadWriteCloser
writeMu sync.Mutex
bufr *bufio.Reader
closeOnce sync.Once

Expand Down Expand Up @@ -107,7 +107,10 @@ func (c *Conn) WriteMessage(m Message) error {
return trace.Wrap(err)
}

c.writeMu.Lock()
_, err = c.rwc.Write(buf)
c.writeMu.Unlock()

if c.OnSend != nil {
c.OnSend(m, buf)
}
Expand Down Expand Up @@ -173,140 +176,80 @@ func IsFatalErr(err error) bool {
// It accepts an optional serverInterceptor to intercept received messages.
func NewConnProxy(client, server io.ReadWriteCloser, serverInterceptor ServerInterceptor) *ConnProxy {
return &ConnProxy{
client: client,
server: server,
client: NewConn(client),
server: NewConn(server),
serverInterceptor: serverInterceptor,
messagesToClient: make(chan Message),
}
}

// ConnProxy does a bidirectional copy between the connection to the client and the mTLS connection to the server.
type ConnProxy struct {
// client is a connection to the client (browser/Connect).
client io.ReadWriteCloser
// server io.ReadWriteCloser is a connection to the server (Windows Desktop Service).
server io.ReadWriteCloser
client *Conn
// server is a connection to the server (Windows Desktop Service).
server *Conn
// serverInterceptor intercepts the incoming messages.
// If the returned message is non-nil, it is forwarded to the client.
// If an error is returned, the stream is canceled.
serverInterceptor ServerInterceptor
messagesToClient chan Message
}

// ServerInterceptor intercepts messages received from the server.
// If a message returned from the interceptor is nil, it's not sent to the client.
type ServerInterceptor func(serverTdpConn *Conn, message Message) (Message, error)

// SendToClient sends a message to the client and blocks until the operation completes.
func (c *ConnProxy) SendToClient(ctx context.Context, message Message) error {
select {
case c.messagesToClient <- message:
return nil
case <-ctx.Done():
return ctx.Err()
}
func (c *ConnProxy) SendToClient(message Message) error {
Comment thread
probakowski marked this conversation as resolved.
err := c.client.WriteMessage(message)
return trace.Wrap(err)
}

// SendToServer sends a message to the server and blocks until the operation completes.
func (c *ConnProxy) SendToServer(message Message) error {
err := c.server.WriteMessage(message)
return trace.Wrap(err)
}

// Run starts proxying the connection.
func (c *ConnProxy) Run(ctx context.Context) error {
ctx, cancel := context.WithCancel(ctx)
func (c *ConnProxy) Run() error {
var errs errgroup.Group

var closeOnce sync.Once
closeAll := func() {
cancel()
closeAll := sync.OnceFunc(func() {
c.client.Close()
c.server.Close()
}
defer closeOnce.Do(closeAll)

var errs errgroup.Group

sendTDPAlert := func(err error, severity Severity) error {
msg := Alert{Message: err.Error(), Severity: severity}
b, err := msg.Encode()
if err != nil {
return trace.Wrap(err)
}
_, err = c.client.Write(b)
return trace.Wrap(err)
}

// Run a goroutine to pick TDP messages up from a channel and send
// them to the client.
errs.Go(func() error {
defer closeOnce.Do(closeAll)

for {
select {
case msg := <-c.messagesToClient:
encoded, err := msg.Encode()
if err != nil {
return trace.Wrap(err)
}
if _, err := c.client.Write(encoded); err != nil {
return trace.Wrap(err)
}
case <-ctx.Done():
return ctx.Err()
}
}
})
defer closeAll()

// Run a second goroutine to read TDP messages from the Windows
// agent and write them to our send channel.
// Run a goroutine to read TDP messages from the Windows
// agent and write them to client.
errs.Go(func() error {
defer closeOnce.Do(closeAll)
defer closeAll()

// We avoid using io.Copy here, as we want to make sure
// each TDP message is sent as a unit so that a single
// 'message' event is emitted in the JS TDP client.
// Internal buffer of io.Copy could split one message
// into multiple downstreamConn.Send() calls.
tdpConn := NewConn(c.server)
defer tdpConn.Close()

// we don't care about the content of the message, we just
// We don't care about the content of the message, we just
// need to split the stream into individual messages and
// write them to the client
for {
msg, err := tdpConn.ReadMessage()
if utils.IsOKNetworkError(err) {
return trace.Wrap(err)
}
if err != nil {
isFatal := IsFatalErr(err)
severity := SeverityError
if !isFatal {
severity = SeverityWarning
}
sendErr := sendTDPAlert(err, severity)

// If the error wasn't fatal, and we successfully
// sent it back to the client, continue.
if !isFatal && sendErr == nil {
continue
}
msg, err := c.server.ReadMessage()

// If the error was fatal, or we failed to send it back
// to the client, send it to the errCh channel and end
// the session.
if sendErr != nil {
err = sendErr
}
return trace.Wrap(err)
if err := c.handleError(err); err != nil {
return err
}

if c.serverInterceptor != nil {
msg, err = c.serverInterceptor(tdpConn, msg)
msg, err = c.serverInterceptor(c.server, msg)
if err != nil {
return trace.Wrap(err)
}
}
if msg != nil {
select {
case c.messagesToClient <- msg:
case <-ctx.Done():
return ctx.Err()
err := c.SendToClient(msg)
if err != nil {
return trace.Wrap(err)
}
}
}
Expand All @@ -315,9 +258,18 @@ func (c *ConnProxy) Run(ctx context.Context) error {
// Run a goroutine to read TDP messages coming from the client
// and pass them on to the Windows agent.
errs.Go(func() error {
defer closeOnce.Do(closeAll)
_, err := io.Copy(c.server, c.client)
return trace.Wrap(err, "sending TDP message to desktop agent")
defer closeAll()
for {
msg, err := c.client.ReadMessage()

if err := c.handleError(err); err != nil {
return err
}

if err := c.SendToServer(msg); err != nil {
return trace.Wrap(err)
}
}
Comment on lines +262 to +272
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aren’t we missing context cancellation handling in the goroutines? Something like

Suggested change
for {
msg, err := c.client.ReadMessage()
if err := c.handleError(err); err != nil {
return err
}
if err := c.SendToServer(msg); err != nil {
return trace.Wrap(err)
}
}
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
msg, err := c.client.ReadMessage()
if err := c.handleError(err); err != nil {
return err
}
if err := c.SendToServer(msg); err != nil {
return trace.Wrap(err)
}
}
}

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't mix context cancellation with IO like that, it doesn't do anything and just adds to the confusion. If the loop should end when the context is cancelled then c.client.ReadMessage() and c.SendToServer() should unblock and return an error when the context is cancelled, and if they don't then this is just occasionally succeeding in exiting the loop when the context is done if we get lucky with the timing, and a possible deadlock otherwise.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To clarify, if cancelling the context closes the connections then there's nothing to be added by checking for the context in this loop, and if cancelling the context does not close the connections then we are going to block on read or write and we will not be respecting context cancellation.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So ctx here is connected with client side of the connection so its cancellation should mean that client will close as well. That means its goroutine will exit and in turn close server side (with closeAll), that will end the other goroutine as well. I removed the ctx.Err check.

})

// Wait for all goroutines to finish
Expand All @@ -327,3 +279,31 @@ func (c *ConnProxy) Run(ctx context.Context) error {

return nil
}

func (c *ConnProxy) handleError(err error) error {
if err == nil {
return nil
}
if utils.IsOKNetworkError(err) {
return trace.Wrap(err)
}
isFatal := IsFatalErr(err)
severity := SeverityError
if !isFatal {
severity = SeverityWarning
}
sendErr := c.SendToClient(Alert{Message: err.Error(), Severity: severity})

// If the error wasn't fatal, and we successfully
// sent it back to the client, continue.
if !isFatal && sendErr == nil {
return nil
}

// If the error was fatal, or we failed to send it back
// to the client, return it and end the session.
if sendErr != nil {
err = sendErr
}
return trace.Wrap(err)
}
2 changes: 1 addition & 1 deletion lib/teleterm/services/desktop/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func (s *Session) Start(ctx context.Context, stream grpc.BidiStreamingServer[api
return msg, nil
})

return trace.Wrap(tdpConnProxy.Run(ctx))
return trace.Wrap(tdpConnProxy.Run())
}

// clientStream implements the [streamutils.Source] interface
Expand Down
Loading
Loading