diff --git a/lib/web/desktop.go b/lib/web/desktop.go index 8b42460ae1345..a89bcf01b17c3 100644 --- a/lib/web/desktop.go +++ b/lib/web/desktop.go @@ -35,6 +35,7 @@ import ( "strings" "sync" + "github.com/coreos/go-semver/semver" "github.com/gorilla/websocket" "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" @@ -175,7 +176,7 @@ func (h *Handler) createDesktopConnection( clientSrcAddr: clientSrcAddr, clientDstAddr: clientDstAddr, } - serviceConn, err := c.connectToWindowsService(clusterName, validServiceIDs) + serviceConn, version, err := c.connectToWindowsService(clusterName, validServiceIDs) if err != nil { return sendTDPError(trace.Wrap(err, "cannot connect to Windows Desktop Service")) } @@ -200,7 +201,7 @@ func (h *Handler) createDesktopConnection( // proxyWebsocketConn hangs here until connection is closed handleProxyWebsocketConnErr( - proxyWebsocketConn(ws, serviceConnTLS), log) + proxyWebsocketConn(ws, serviceConnTLS, version), log) return nil } @@ -385,11 +386,13 @@ type connector struct { // connectToWindowsService tries to make a connection to a Windows Desktop Service // by trying each of the services provided. It returns an error if it could not connect // to any of the services or if it encounters an error that is not a connection problem. -func (c *connector) connectToWindowsService(clusterName string, desktopServiceIDs []string) (net.Conn, error) { +func (c *connector) connectToWindowsService( + clusterName string, + desktopServiceIDs []string) (conn net.Conn, version string, err error) { for _, id := range desktopServiceIDs { - conn, err := c.tryConnect(clusterName, id) + conn, ver, err := c.tryConnect(clusterName, id) if err != nil && !trace.IsConnectionProblem(err) { - return nil, trace.WrapWithMessage(err, + return nil, "", trace.WrapWithMessage(err, "error connecting to windows_desktop_service %q", id) } if trace.IsConnectionProblem(err) { @@ -397,22 +400,25 @@ func (c *connector) connectToWindowsService(clusterName string, desktopServiceID continue } if err == nil { - return conn, err + return conn, ver, nil } } - return nil, trace.Errorf("failed to connect to any windows_desktop_service") + return nil, "", trace.Errorf("failed to connect to any windows_desktop_service") } -func (c *connector) tryConnect(clusterName, desktopServiceID string) (net.Conn, error) { +func (c *connector) tryConnect(clusterName, desktopServiceID string) (conn net.Conn, version string, err error) { service, err := c.clt.GetWindowsDesktopService(context.Background(), desktopServiceID) if err != nil { log.Errorf("Error finding service with id %s", desktopServiceID) - return nil, trace.NotFound("could not find windows desktop service %s: %v", desktopServiceID, err) + return nil, "", trace.NotFound("could not find windows desktop service %s: %v", desktopServiceID, err) } + ver := service.GetTeleportVersion() + *c.log = *c.log.WithField("windows-service-version", ver) *c.log = *c.log.WithField("windows-service-uuid", service.GetName()) *c.log = *c.log.WithField("windows-service-addr", service.GetAddr()) - return c.site.DialTCP(reversetunnelclient.DialParams{ + + conn, err = c.site.DialTCP(reversetunnelclient.DialParams{ From: c.clientSrcAddr, To: &utils.NetAddr{AddrNetwork: "tcp", Addr: service.GetAddr()}, ConnType: types.WindowsDesktopTunnel, @@ -420,18 +426,26 @@ func (c *connector) tryConnect(clusterName, desktopServiceID string) (net.Conn, ProxyIDs: service.GetProxyIDs(), OriginalClientDstAddr: c.clientDstAddr, }) + return conn, ver, trace.Wrap(err) } // proxyWebsocketConn does a bidrectional copy between the websocket // connection to the browser (ws) and the mTLS connection to Windows // Desktop Serivce (wds) -func proxyWebsocketConn(ws *websocket.Conn, wds net.Conn) error { +func proxyWebsocketConn(ws *websocket.Conn, wds net.Conn, wdsVersion string) error { var closeOnce sync.Once close := func() { ws.Close() wds.Close() } + v, err := semver.NewVersion(wdsVersion) + if err != nil { + return trace.BadParameter("invalid windows desktop service version %q: %v", wdsVersion, err) + } + + isPre15 := v.Major < 15 + errs := make(chan error, 2) go func() { @@ -495,15 +509,33 @@ func proxyWebsocketConn(ws *websocket.Conn, wds net.Conn) error { go func() { defer closeOnce.Do(close) - // io.Copy is fine here, as the Windows Desktop Service - // operates on a stream and doesn't care if TPD messages - // are fragmented - stream := &WebsocketIO{Conn: ws} - _, err := io.Copy(wds, stream) - if utils.IsOKNetworkError(err) { - err = nil + buf := make([]byte, 4096) + for { + _, reader, err := ws.NextReader() + switch { + case utils.IsOKNetworkError(err): + errs <- nil + return + case err != nil: + errs <- err + return + } + n, err := reader.Read(buf) + if err != nil { + errs <- err + return + } + // don't pass the sync keys message along to old agents + // (they don't support it) + if isPre15 && tdp.MessageType(buf[0]) == tdp.TypeSyncKeys { + continue + } + + if _, err := wds.Write(buf[:n]); err != nil { + errs <- trace.Wrap(err, "sending TDP message to desktop agent") + return + } } - errs <- err }() var retErrs []error