diff --git a/lib/srv/desktop/rdp/rdpclient/client.go b/lib/srv/desktop/rdp/rdpclient/client.go index c2640b1f18a75..5439ba2296301 100644 --- a/lib/srv/desktop/rdp/rdpclient/client.go +++ b/lib/srv/desktop/rdp/rdpclient/client.go @@ -74,8 +74,10 @@ import ( "encoding/binary" "fmt" "log/slog" + "net" "os" "runtime/cgo" + "strconv" "sync" "sync/atomic" "time" @@ -414,6 +416,9 @@ func (c *Client) startInputStreaming(stopCh chan struct{}) error { c.cfg.Logger.InfoContext(context.Background(), "TDP input streaming starting") defer c.cfg.Logger.InfoContext(context.Background(), "TDP input streaming finished") + // we will disable ping only if the env var is truthy + disableDesktopPing, _ := strconv.ParseBool(os.Getenv("TELEPORT_DISABLE_DESKTOP_LATENCY_DETECTOR_PING")) + var withheldResize *tdp.ClientScreenSpec for { select { @@ -432,6 +437,22 @@ func (c *Client) startInputStreaming(stopCh chan struct{}) error { c.cfg.Logger.WarnContext(context.Background(), "Failed reading TDP input message", "error", err) return err } + if m, ok := msg.(tdp.Ping); ok { + // Upon receiving a ping message, we make a connection + // to the host and send the same message back to the proxy. + // The proxy will then compute the round trip time. + if !disableDesktopPing { + conn, err := net.Dial("tcp", c.cfg.Addr) + if err == nil { + conn.Close() + } + } + if err := c.cfg.Conn.WriteMessage(m); err != nil { + c.cfg.Logger.WarnContext(context.Background(), "Failed writing TDP ping message", "error", err) + return err + } + continue + } if atomic.LoadUint32(&c.readyForInput) == 0 { switch m := msg.(type) { diff --git a/lib/srv/desktop/tdp/proto.go b/lib/srv/desktop/tdp/proto.go index 5ad8baf9d38ed..c753b48f9fcf2 100644 --- a/lib/srv/desktop/tdp/proto.go +++ b/lib/srv/desktop/tdp/proto.go @@ -33,6 +33,7 @@ import ( "image/png" "io" + "github.com/google/uuid" "github.com/gravitational/trace" authproto "github.com/gravitational/teleport/api/client/proto" @@ -82,6 +83,8 @@ const ( TypeSyncKeys = MessageType(32) TypeSharedDirectoryTruncateRequest = MessageType(33) TypeSharedDirectoryTruncateResponse = MessageType(34) + TypeLatencyStats = MessageType(35) + TypePing = MessageType(36) ) // Message is a Go representation of a desktop protocol message. @@ -182,6 +185,8 @@ func decodeMessage(firstByte byte, in byteReader) (Message, error) { return decodeSharedDirectoryTruncateRequest(in) case TypeSharedDirectoryTruncateResponse: return decodeSharedDirectoryTruncateResponse(in) + case TypePing: + return decodePing(in) default: return nil, trace.BadParameter("unsupported desktop protocol message type %d", firstByte) } @@ -1631,6 +1636,44 @@ func decodeSharedDirectoryTruncateResponse(in io.Reader) (SharedDirectoryTruncat return res, err } +// LatencyStats is used to report the latency of the connection(s) to the client. +type LatencyStats struct { + ClientLatency uint32 + ServerLatency uint32 +} + +func (l LatencyStats) Encode() ([]byte, error) { + buf := new(bytes.Buffer) + buf.WriteByte(byte(TypeLatencyStats)) + writeUint32(buf, l.ClientLatency) + writeUint32(buf, l.ServerLatency) + return buf.Bytes(), nil +} + +// Ping is used to measure the latency of the connection(s) between proxy and desktop (includes +// latency between proxy and Windows Desktop Service and between WDS and desktop). +type Ping struct { + + // UUID is used to correlate message send by proxy and received from the Windows Desktop Service + UUID uuid.UUID +} + +func (p Ping) Encode() ([]byte, error) { + buf := new(bytes.Buffer) + buf.WriteByte(byte(TypePing)) + buf.Write(p.UUID[:]) + return buf.Bytes(), nil +} + +func decodePing(in io.Reader) (Ping, error) { + var ping Ping + _, err := io.ReadFull(in, ping.UUID[:]) + if err != nil { + return ping, trace.Wrap(err) + } + return ping, nil +} + // encodeString encodes strings for TDP. Strings are encoded as UTF-8 with // a 32-bit length prefix (in bytes): // https://github.com/gravitational/teleport/blob/master/rfd/0037-desktop-access-protocol.md#field-types diff --git a/lib/web/desktop.go b/lib/web/desktop.go index 2acad1ae98396..77b5e1d37c744 100644 --- a/lib/web/desktop.go +++ b/lib/web/desktop.go @@ -30,9 +30,12 @@ import ( "net/http" "sync" + "github.com/google/uuid" "github.com/gorilla/websocket" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" "github.com/julienschmidt/httprouter" + "golang.org/x/sync/errgroup" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/constants" @@ -47,6 +50,7 @@ import ( "github.com/gravitational/teleport/lib/reversetunnelclient" "github.com/gravitational/teleport/lib/srv/desktop/tdp" "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/lib/utils/diagnostics/latency" logutils "github.com/gravitational/teleport/lib/utils/log" ) @@ -194,7 +198,7 @@ func (h *Handler) createDesktopConnection( clientSrcAddr: clientSrcAddr, clientDstAddr: clientDstAddr, } - serviceConn, _, err := c.connectToWindowsService(ctx, clusterName, validServiceIDs) + serviceConn, version, err := c.connectToWindowsService(ctx, clusterName, validServiceIDs) if err != nil { return sendTDPError(trace.Wrap(err, "cannot connect to Windows Desktop Service")) } @@ -233,7 +237,7 @@ func (h *Handler) createDesktopConnection( // proxyWebsocketConn hangs here until connection is closed handleProxyWebsocketConnErr( ctx, - proxyWebsocketConn(ws, serviceConnTLS), + proxyWebsocketConn(ctx, ws, serviceConnTLS, log, version), log, ) @@ -535,19 +539,108 @@ func (c *connector) tryConnect(ctx context.Context, clusterName, desktopServiceI return conn, ver, trace.Wrap(err) } +// desktopPinger measures latency between proxy and the desktop by sending tdp.Ping messages +// Windows Desktop Service and measuring the time it takes to receive message with the same UUID back. +type desktopPinger struct { + wds net.Conn + ch <-chan tdp.Ping +} + +func (d desktopPinger) Ping(ctx context.Context) error { + ping := tdp.Ping{ + UUID: uuid.New(), + } + buf, err := ping.Encode() + if err != nil { + return trace.Wrap(err) + } + _, err = d.wds.Write(buf) + if err != nil { + return trace.Wrap(err) + } + for { + select { + case pong := <-d.ch: + if pong.UUID == ping.UUID { + return nil + } + case <-ctx.Done(): + return trace.Wrap(ctx.Err()) + } + } +} + // proxyWebsocketConn does a bidrectional copy between the websocket // connection to the browser (ws) and the mTLS connection to Windows // Desktop Serivce (wds) -func proxyWebsocketConn(ws *websocket.Conn, wds net.Conn) error { +func proxyWebsocketConn(ctx context.Context, ws *websocket.Conn, wds net.Conn, log *slog.Logger, version string) error { + ctx, cancel := context.WithCancel(ctx) var closeOnce sync.Once close := func() { + cancel() ws.Close() wds.Close() } - errs := make(chan error, 2) + tdpMessagesToSend := make(chan tdp.Message) + + latencySupported, err := utils.MinVerWithoutPreRelease(version, "17.5.0") + if err != nil { + return trace.Wrap(err) + } + + pings := make(chan tdp.Ping) + + if latencySupported { + pinger := desktopPinger{ + wds: wds, + ch: pings, + } + + go monitorLatency(ctx, clockwork.NewRealClock(), ws, pinger, + latency.ReporterFunc(func(ctx context.Context, stats latency.Statistics) error { + tdpMessagesToSend <- tdp.LatencyStats{ + ClientLatency: uint32(stats.Client), + ServerLatency: uint32(stats.Server), + } + return nil + }), + ) + + } + + var errs errgroup.Group + + // run a goroutine to pick TDP messages up from a channel and send + // them to the browser + errs.Go(func() error { + for msg := range tdpMessagesToSend { + if ping, ok := msg.(tdp.Ping); ok { + pings <- ping + continue + } + if ls, ok := msg.(tdp.LatencyStats); ok { + log.DebugContext(ctx, "sending latency stats", "client", ls.ClientLatency, "server", ls.ServerLatency) + } + encoded, err := msg.Encode() + if err != nil { + return err + } + + err = ws.WriteMessage(websocket.BinaryMessage, encoded) + if utils.IsOKNetworkError(err) { + return err + } + if err != nil { + return err + } + } + return nil + }) - go func() { + // run a second goroutine to read TDP messages from the Windows + // agent and write them to our send channel + errs.Go(func() error { defer closeOnce.Do(close) // we avoid using io.Copy here, as we want to make sure @@ -563,8 +656,7 @@ func proxyWebsocketConn(ws *websocket.Conn, wds net.Conn) error { for { msg, err := tc.ReadMessage() if utils.IsOKNetworkError(err) { - errs <- nil - return + return err } else if err != nil { isFatal := tdp.IsFatalErr(err) severity := tdp.SeverityError @@ -585,27 +677,15 @@ func proxyWebsocketConn(ws *websocket.Conn, wds net.Conn) error { if sendErr != nil { err = sendErr } - errs <- err - return - } - encoded, err := msg.Encode() - if err != nil { - errs <- err - return - } - err = ws.WriteMessage(websocket.BinaryMessage, encoded) - if utils.IsOKNetworkError(err) { - errs <- nil - return - } - if err != nil { - errs <- err - return + return err } + tdpMessagesToSend <- msg } - }() + }) - go func() { + // run a goroutine to read TDP messages coming from the browser + // and pass them on to the Windows agent + errs.Go(func() error { defer closeOnce.Do(close) var buf bytes.Buffer @@ -613,30 +693,22 @@ func proxyWebsocketConn(ws *websocket.Conn, wds net.Conn) error { _, reader, err := ws.NextReader() switch { case utils.IsOKNetworkError(err): - errs <- nil - return + return err case err != nil: - errs <- err - return + return err } buf.Reset() if _, err := io.Copy(&buf, reader); err != nil { - errs <- err - return + return err } if _, err := wds.Write(buf.Bytes()); err != nil { - errs <- trace.Wrap(err, "sending TDP message to desktop agent") - return + return trace.Wrap(err, "sending TDP message to desktop agent") } } - }() + }) - var retErrs []error - for i := 0; i < 2; i++ { - retErrs = append(retErrs, <-errs) - } - return trace.NewAggregate(retErrs...) + return trace.Wrap(errs.Wait()) } // handleProxyWebsocketConnErr handles the error returned by proxyWebsocketConn by diff --git a/lib/web/latency.go b/lib/web/latency.go new file mode 100644 index 0000000000000..822ab3249ce4a --- /dev/null +++ b/lib/web/latency.go @@ -0,0 +1,63 @@ +/* + * * + * * Teleport + * * Copyright (C) 2024 Gravitational, Inc. + * * + * * This program is free software: you can redistribute it and/or modify + * * it under the terms of the GNU Affero General Public License as published by + * * the Free Software Foundation, either version 3 of the License, or + * * (at your option) any later version. + * * + * * This program is distributed in the hope that it will be useful, + * * but WITHOUT ANY WARRANTY; without even the implied warranty of + * * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * * GNU Affero General Public License for more details. + * * + * * You should have received a copy of the GNU Affero General Public License + * * along with this program. If not, see . + * + */ + +package web + +import ( + "context" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + + "github.com/gravitational/teleport/lib/utils/diagnostics/latency" +) + +// monitorLatency implements the Web UI's latency detector. +// It runs as long as the provided context has not expired. +// +// The latency of the provided websocket is monitored automatically, +// and the latency to the target endpoint is monitored with the provided pinger. +// The results of the latency calculation are reported to the web UI +// with the provided reporter. +func monitorLatency( + ctx context.Context, + clock clockwork.Clock, + ws latency.WebSocket, + endpointPinger latency.Pinger, + reporter latency.Reporter, +) error { + wsPinger, err := latency.NewWebsocketPinger(clock, ws) + if err != nil { + return trace.Wrap(err, "creating websocket pinger") + } + + monitor, err := latency.NewMonitor(latency.MonitorConfig{ + ClientPinger: wsPinger, + ServerPinger: endpointPinger, + Reporter: reporter, + Clock: clock, + }) + if err != nil { + return trace.Wrap(err, "creating latency monitor") + } + + monitor.Run(ctx) + return nil +} diff --git a/lib/web/terminal.go b/lib/web/terminal.go index e3cd8daed17c5..ae31b5bb40ce9 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -778,36 +778,6 @@ func (t *sshBaseHandler) connectToHost(ctx context.Context, ws terminal.WSConn, } } -func monitorSessionLatency(ctx context.Context, clock clockwork.Clock, stream *terminal.WSStream, sshClient *tracessh.Client) error { - wsPinger, err := latency.NewWebsocketPinger(clock, stream) - if err != nil { - return trace.Wrap(err, "creating websocket pinger") - } - - sshPinger, err := latency.NewSSHPinger(sshClient) - if err != nil { - return trace.Wrap(err, "creating ssh pinger") - } - - monitor, err := latency.NewMonitor(latency.MonitorConfig{ - ClientPinger: wsPinger, - ServerPinger: sshPinger, - Reporter: latency.ReporterFunc(func(ctx context.Context, statistics latency.Statistics) error { - return trace.Wrap(stream.WriteLatency(terminal.SSHSessionLatencyStats{ - WebSocket: statistics.Client, - SSH: statistics.Server, - })) - }), - Clock: clock, - }) - if err != nil { - return trace.Wrap(err, "creating latency monitor") - } - - monitor.Run(ctx) - return nil -} - // streamTerminal opens an SSH connection to the remote host and streams // events back to the web client. func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.TeleportClient) { @@ -846,11 +816,24 @@ func (t *TerminalHandler) streamTerminal(ctx context.Context, tc *client.Telepor monitorCtx, monitorCancel := context.WithCancel(ctx) defer monitorCancel() - go func() { - if err := monitorSessionLatency(monitorCtx, t.clock, t.stream.WSStream, nc.Client); err != nil { - t.logger.WarnContext(monitorCtx, "failure monitoring session latency", "error", err) - } - }() + + sshPinger, err := latency.NewSSHPinger(nc.Client) + if err != nil { + t.logger.WarnContext(monitorCtx, "failure monitoring session latency", "error", err) + } else { + go monitorLatency(monitorCtx, t.clock, t.stream.WSStream, sshPinger, + latency.ReporterFunc( + func(ctx context.Context, statistics latency.Statistics) error { + return trace.Wrap( + t.stream.WSStream.WriteLatency(terminal.SSHSessionLatencyStats{ + WebSocket: statistics.Client, + SSH: statistics.Server, + }), + ) + }, + ), + ) + } sessionDataSent := make(chan struct{}) // If we are joining a session, send the session data right away, we diff --git a/rfd/0037-desktop-access-protocol.md b/rfd/0037-desktop-access-protocol.md index 49cd7807afa48..77af6b0aba68d 100644 --- a/rfd/0037-desktop-access-protocol.md +++ b/rfd/0037-desktop-access-protocol.md @@ -316,3 +316,22 @@ This message is sent from the client to the server to synchronize the state of k - `0` for \* lock inactive - `1` FOR \* LOCK ACTIVE + +#### 35 - latency stats + +This message is sent from the server to the client to indicate latency +between client and proxy and between proxy and desktop. + +``` +| message type (35) | client_latency uint32 | server_latency uint32 | +``` + +#### 36 - ping + +This message is sent between proxy and Windows desktop service to measure latency between proxy and desktop. +Proxy will send ping message with random UUID and WDS will respond with the same message +after measuring latency to desktop. + +``` +| message type (36) | uuid [16]byte | +``` \ No newline at end of file diff --git a/web/packages/shared/components/DesktopSession/DesktopSession.tsx b/web/packages/shared/components/DesktopSession/DesktopSession.tsx index 12decc4ab1691..6621242b741be 100644 --- a/web/packages/shared/components/DesktopSession/DesktopSession.tsx +++ b/web/packages/shared/components/DesktopSession/DesktopSession.tsx @@ -34,6 +34,7 @@ import { CanvasRenderer, CanvasRendererRef, } from 'shared/components/CanvasRenderer'; +import { Latency } from 'shared/components/LatencyDiagnostic'; import { Attempt, makeEmptyAttempt, @@ -218,6 +219,17 @@ export function DesktopSession({ useListener(client.onReset, canvasRendererRef.current?.clear); useListener(client.onScreenSpec, canvasRendererRef.current?.setResolution); + const [latencyStats, setLatencyStats] = useState(); + useListener( + client.onLatencyStats, + useCallback(stats => { + setLatencyStats({ + client: stats.client, + server: stats.server, + }); + }, []) + ); + const shouldConnect = aclAttempt.status === 'success' && anotherDesktopActiveAttempt.status === 'success' && @@ -358,6 +370,7 @@ export function DesktopSession({ onCtrlAltDel={handleCtrlAltDel} alerts={alerts} onRemoveAlert={onRemoveAlert} + latency={latencyStats} /> {/* The UI states below (except the loading indicator) take up space.*/} diff --git a/web/packages/shared/components/DesktopSession/TopBar.tsx b/web/packages/shared/components/DesktopSession/TopBar.tsx index 8bcabde5b61aa..b40d841f8091a 100644 --- a/web/packages/shared/components/DesktopSession/TopBar.tsx +++ b/web/packages/shared/components/DesktopSession/TopBar.tsx @@ -21,6 +21,7 @@ import { useTheme } from 'styled-components'; import { Flex, Text, TopNav } from 'design'; import { Clipboard, FolderShared } from 'design/Icon'; import { HoverTooltip } from 'design/Tooltip'; +import { LatencyDiagnostic } from 'shared/components/LatencyDiagnostic'; import type { NotificationItem } from 'shared/components/Notification'; import ActionMenu from './ActionMenu'; @@ -39,6 +40,7 @@ export default function TopBar(props: Props) { alerts, onRemoveAlert, isConnected, + latency, } = props; const theme = useTheme(); @@ -62,7 +64,8 @@ export default function TopBar(props: Props) { {isConnected && ( - + + {latency && } - + - + @@ -117,4 +120,8 @@ type Props = { alerts: NotificationItem[]; isConnected: boolean; onRemoveAlert(id: string): void; + latency: { + client: number; + server: number; + }; }; diff --git a/web/packages/shared/components/LatencyDiagnostic/LatencyDiagnostic.tsx b/web/packages/shared/components/LatencyDiagnostic/LatencyDiagnostic.tsx index bc436591e0340..2649b4f447f99 100644 --- a/web/packages/shared/components/LatencyDiagnostic/LatencyDiagnostic.tsx +++ b/web/packages/shared/components/LatencyDiagnostic/LatencyDiagnostic.tsx @@ -99,7 +99,11 @@ export function LatencyDiagnostic({ const colors = latencyColors(latency); return ( - +

Network Connection

diff --git a/web/packages/shared/components/MenuAction/MenuActionIcon.tsx b/web/packages/shared/components/MenuAction/MenuActionIcon.tsx index 679d09e05b287..6f201e169fb14 100644 --- a/web/packages/shared/components/MenuAction/MenuActionIcon.tsx +++ b/web/packages/shared/components/MenuAction/MenuActionIcon.tsx @@ -22,6 +22,7 @@ import { ButtonIcon } from 'design'; import { MoreHoriz } from 'design/Icon'; import { IconProps } from 'design/Icon/Icon'; import Menu from 'design/Menu'; +import { HoverTooltip } from 'design/Tooltip'; import { AnchorProps, MenuProps } from './types'; @@ -56,14 +57,16 @@ export default class MenuActionIcon extends React.Component< const { children, buttonIconProps, menuProps, Icon } = this.props; return ( <> - (this.anchorEl = e)} - onClick={this.onOpen} - data-testid="button" - > - - + + (this.anchorEl = e)} + onClick={this.onOpen} + data-testid="button" + > + + + ; + tooltip?: React.ReactNode; }; diff --git a/web/packages/shared/libs/tdp/client.ts b/web/packages/shared/libs/tdp/client.ts index cb961087bb7a8..70b428ad3c244 100644 --- a/web/packages/shared/libs/tdp/client.ts +++ b/web/packages/shared/libs/tdp/client.ts @@ -28,6 +28,7 @@ import Logger from 'shared/libs/logger'; import Codec, { FileType, + LatencyStats, MessageType, PointerData, Severity, @@ -72,6 +73,7 @@ export enum TdpClientEvent { TRANSPORT_CLOSE = 'transport close', RESET = 'reset', POINTER = 'pointer', + LATENCY_STATS = 'latency stats', } export enum LogType { @@ -255,6 +257,11 @@ export class TdpClient extends EventEmitter { return () => this.off(TdpClientEvent.TDP_CLIENT_SCREEN_SPEC, listener); }; + onLatencyStats = (listener: (stats: LatencyStats) => void) => { + this.on(TdpClientEvent.LATENCY_STATS, listener); + return () => this.off(TdpClientEvent.LATENCY_STATS, listener); + }; + private async initWasm() { // select the wasm log level let wasmLogLevel = LogType.OFF; @@ -352,11 +359,19 @@ export class TdpClient extends EventEmitter { case MessageType.SHARED_DIRECTORY_TRUNCATE_REQUEST: await this.handleSharedDirectoryTruncateRequest(buffer); break; + case MessageType.LATENCY_STATS: + this.handleLatencyStats(buffer); + break; default: this.logger.warn(`received unsupported message type ${messageType}`); } } + handleLatencyStats(buffer: ArrayBuffer) { + const stats = this.codec.decodeLatencyStats(buffer); + this.emit(TdpClientEvent.LATENCY_STATS, stats); + } + handleClientScreenSpec(buffer: ArrayBuffer) { this.logger.warn( `received unsupported message type ${this.codec.decodeMessageType( diff --git a/web/packages/shared/libs/tdp/codec.ts b/web/packages/shared/libs/tdp/codec.ts index c57ff4f89d7dd..1dc94efc54dbb 100644 --- a/web/packages/shared/libs/tdp/codec.ts +++ b/web/packages/shared/libs/tdp/codec.ts @@ -55,6 +55,7 @@ export enum MessageType { SYNC_KEYS = 32, SHARED_DIRECTORY_TRUNCATE_REQUEST = 33, SHARED_DIRECTORY_TRUNCATE_RESPONSE = 34, + LATENCY_STATS = 35, __LAST, // utility value } @@ -323,6 +324,12 @@ export enum FileType { Directory = 1, } +// | message type (35) | client_latency uint32 | server_latency uint32 | +export type LatencyStats = { + client: number; + server: number; +}; + function toSharedDirectoryErrCode(errCode: number): SharedDirectoryErrCode { if (!(errCode in SharedDirectoryErrCode)) { throw new Error(`attempted to convert invalid error code ${errCode}`); @@ -1117,6 +1124,20 @@ export default class Codec { }; } + decodeLatencyStats(buffer: ArrayBuffer): LatencyStats { + const dv = new DataView(buffer); + let bufOffset = BYTE_LEN; // eat message type + const browserLatency = dv.getUint32(bufOffset); + bufOffset += UINT_32_LEN; + const desktopLatency = dv.getUint32(bufOffset); + bufOffset += UINT_32_LEN; + + return { + client: browserLatency, + server: desktopLatency, + }; + } + // asBase64Url creates a data:image uri from the png data part of a PNG_FRAME tdp message. private asBase64Url(buffer: ArrayBuffer, offset: number): string { return `data:image/png;base64,${arrayBufferToBase64(buffer.slice(offset))}`;