Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 51 additions & 19 deletions lib/web/desktop.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import (
"strings"
"sync"

"github.com/coreos/go-semver/semver"
"github.com/gorilla/websocket"
"github.com/gravitational/trace"
"github.com/julienschmidt/httprouter"
Expand Down Expand Up @@ -175,7 +176,7 @@ func (h *Handler) createDesktopConnection(
clientSrcAddr: clientSrcAddr,
clientDstAddr: clientDstAddr,
}
serviceConn, err := c.connectToWindowsService(clusterName, validServiceIDs)
serviceConn, version, err := c.connectToWindowsService(clusterName, validServiceIDs)
if err != nil {
return sendTDPError(trace.Wrap(err, "cannot connect to Windows Desktop Service"))
}
Expand All @@ -200,7 +201,7 @@ func (h *Handler) createDesktopConnection(

// proxyWebsocketConn hangs here until connection is closed
handleProxyWebsocketConnErr(
proxyWebsocketConn(ws, serviceConnTLS), log)
proxyWebsocketConn(ws, serviceConnTLS, version), log)

return nil
}
Expand Down Expand Up @@ -385,53 +386,66 @@ type connector struct {
// connectToWindowsService tries to make a connection to a Windows Desktop Service
// by trying each of the services provided. It returns an error if it could not connect
// to any of the services or if it encounters an error that is not a connection problem.
func (c *connector) connectToWindowsService(clusterName string, desktopServiceIDs []string) (net.Conn, error) {
func (c *connector) connectToWindowsService(
clusterName string,
desktopServiceIDs []string) (conn net.Conn, version string, err error) {
for _, id := range desktopServiceIDs {
conn, err := c.tryConnect(clusterName, id)
conn, ver, err := c.tryConnect(clusterName, id)
if err != nil && !trace.IsConnectionProblem(err) {
return nil, trace.WrapWithMessage(err,
return nil, "", trace.WrapWithMessage(err,
"error connecting to windows_desktop_service %q", id)
}
if trace.IsConnectionProblem(err) {
c.log.Warnf("failed to connect to windows_desktop_service %q: %v", id, err)
continue
}
if err == nil {
return conn, err
return conn, ver, nil
}
}
return nil, trace.Errorf("failed to connect to any windows_desktop_service")
return nil, "", trace.Errorf("failed to connect to any windows_desktop_service")
}

func (c *connector) tryConnect(clusterName, desktopServiceID string) (net.Conn, error) {
func (c *connector) tryConnect(clusterName, desktopServiceID string) (conn net.Conn, version string, err error) {
service, err := c.clt.GetWindowsDesktopService(context.Background(), desktopServiceID)
if err != nil {
log.Errorf("Error finding service with id %s", desktopServiceID)
return nil, trace.NotFound("could not find windows desktop service %s: %v", desktopServiceID, err)
return nil, "", trace.NotFound("could not find windows desktop service %s: %v", desktopServiceID, err)
}

ver := service.GetTeleportVersion()
*c.log = *c.log.WithField("windows-service-version", ver)
*c.log = *c.log.WithField("windows-service-uuid", service.GetName())
*c.log = *c.log.WithField("windows-service-addr", service.GetAddr())
return c.site.DialTCP(reversetunnelclient.DialParams{

conn, err = c.site.DialTCP(reversetunnelclient.DialParams{
From: c.clientSrcAddr,
To: &utils.NetAddr{AddrNetwork: "tcp", Addr: service.GetAddr()},
ConnType: types.WindowsDesktopTunnel,
ServerID: service.GetName() + "." + clusterName,
ProxyIDs: service.GetProxyIDs(),
OriginalClientDstAddr: c.clientDstAddr,
})
return conn, ver, trace.Wrap(err)
}

// proxyWebsocketConn does a bidrectional copy between the websocket
// connection to the browser (ws) and the mTLS connection to Windows
// Desktop Serivce (wds)
func proxyWebsocketConn(ws *websocket.Conn, wds net.Conn) error {
func proxyWebsocketConn(ws *websocket.Conn, wds net.Conn, wdsVersion string) error {
var closeOnce sync.Once
close := func() {
ws.Close()
wds.Close()
}

v, err := semver.NewVersion(wdsVersion)
if err != nil {
return trace.BadParameter("invalid windows desktop service version %q: %v", wdsVersion, err)
}

isPre15 := v.Major < 15

errs := make(chan error, 2)

go func() {
Expand Down Expand Up @@ -495,15 +509,33 @@ func proxyWebsocketConn(ws *websocket.Conn, wds net.Conn) error {
go func() {
defer closeOnce.Do(close)

// io.Copy is fine here, as the Windows Desktop Service
// operates on a stream and doesn't care if TPD messages
// are fragmented
stream := &WebsocketIO{Conn: ws}
_, err := io.Copy(wds, stream)
if utils.IsOKNetworkError(err) {
err = nil
buf := make([]byte, 4096)
for {
_, reader, err := ws.NextReader()
switch {
case utils.IsOKNetworkError(err):
errs <- nil
return
case err != nil:
errs <- err
return
}
n, err := reader.Read(buf)
if err != nil {
errs <- err
return
}
// don't pass the sync keys message along to old agents
// (they don't support it)
if isPre15 && tdp.MessageType(buf[0]) == tdp.TypeSyncKeys {
continue
}

if _, err := wds.Write(buf[:n]); err != nil {
errs <- trace.Wrap(err, "sending TDP message to desktop agent")
return
}
}
errs <- err
}()

var retErrs []error
Expand Down