From b752b3b38999db77640931b9d7f23355320c5e19 Mon Sep 17 00:00:00 2001 From: Alex McGrath Date: Thu, 23 Nov 2023 13:36:21 +0000 Subject: [PATCH 01/17] Read the bearer token over WS endpoints use the request context, not session Dont pass websocket by context lint resolve some comments Add TestWSAuthenticateRequest Close ws in handler deprecation notices, doc resolve comments resolve comments give a longer read/write deadline dont set write deadline, ws endpoints never did before and it breaks things convert frontend to use ws access token Resolove comments, move to using an explicit state fix ci reset read deadline prettier --- lib/web/apiserver.go | 166 ++++++++++- lib/web/apiserver_test.go | 279 +++++++++++++++++- lib/web/assistant.go | 20 +- lib/web/command.go | 15 +- lib/web/desktop.go | 12 +- lib/web/terminal.go | 22 +- .../src/Assist/context/AssistContext.tsx | 68 ++++- .../TerminalAssist/TerminalAssistContext.tsx | 35 ++- .../teleport/src/Console/consoleContext.tsx | 3 +- .../src/DesktopSession/useTdpClientCanvas.tsx | 3 +- .../teleport/src/Player/DesktopPlayer.tsx | 3 +- web/packages/teleport/src/config.ts | 16 +- web/packages/teleport/src/lib/tdp/client.ts | 26 +- web/packages/teleport/src/lib/term/tty.ts | 29 +- web/packages/teleport/src/types.ts | 8 + 15 files changed, 609 insertions(+), 96 deletions(-) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 0977e13cfc1d4..e323120a114ed 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -40,6 +40,7 @@ import ( "time" "github.com/google/uuid" + "github.com/gorilla/websocket" "github.com/gravitational/oxy/ratelimit" "github.com/gravitational/roundtrip" "github.com/gravitational/trace" @@ -720,7 +721,10 @@ func (h *Handler) bindDefaultEndpoints() { h.DELETE("/webapi/sites/:site/locks/:uuid", h.WithClusterAuth(h.deleteClusterLock)) // active sessions handlers - h.GET("/webapi/sites/:site/connect", h.WithClusterAuth(h.siteNodeConnect)) // connect to an active session (via websocket) + // Deprecated: The connect/ws variant should be used instead. + // TODO(lxea): DELETE in v16 + h.GET("/webapi/sites/:site/connect", h.WithClusterAuthWS(false, h.siteNodeConnect)) // connect to an active session (via websocket) + h.GET("/webapi/sites/:site/connect/ws", h.WithClusterAuthWS(true, h.siteNodeConnect)) // connect to an active session (via websocket, with auth over websocket) h.GET("/webapi/sites/:site/sessions", h.WithClusterAuth(h.clusterActiveAndPendingSessionsGet)) // get list of active and pending sessions // Audit events handlers. @@ -828,7 +832,11 @@ func (h *Handler) bindDefaultEndpoints() { h.GET("/webapi/sites/:site/desktopservices", h.WithClusterAuth(h.clusterDesktopServicesGet)) h.GET("/webapi/sites/:site/desktops/:desktopName", h.WithClusterAuth(h.getDesktopHandle)) // GET /webapi/sites/:site/desktops/:desktopName/connect?access_token=&username=&width=&height= - h.GET("/webapi/sites/:site/desktops/:desktopName/connect", h.WithClusterAuth(h.desktopConnectHandle)) + // Deprecated: The connect/ws variant should be used instead. + // TODO(lxea): DELETE in v16 + h.GET("/webapi/sites/:site/desktops/:desktopName/connect", h.WithClusterAuthWS(false, h.desktopConnectHandle)) + // GET /webapi/sites/:site/desktops/:desktopName/connect?username=&width=&height= + h.GET("/webapi/sites/:site/desktops/:desktopName/connect/ws", h.WithClusterAuthWS(true, h.desktopConnectHandle)) // GET /webapi/sites/:site/desktopplayback/:sid?access_token= h.GET("/webapi/sites/:site/desktopplayback/:sid", h.WithClusterAuth(h.desktopPlaybackHandle)) h.GET("/webapi/sites/:site/desktops/:desktopName/active", h.WithClusterAuth(h.desktopIsActive)) @@ -889,7 +897,11 @@ func (h *Handler) bindDefaultEndpoints() { h.GET("/webapi/sites/:site/user-groups", h.WithClusterAuth(h.getUserGroups)) // WebSocket endpoint for the chat conversation - h.GET("/webapi/sites/:site/assistant", h.WithClusterAuth(h.assistant)) + // Deprecated: The connect/ws variant should be used instead. + // TODO(lxea): DELETE in v16 + h.GET("/webapi/sites/:site/assistant", h.WithClusterAuthWS(false, h.assistant)) + // WebSocket endpoint for the chat conversation, websocket auth + h.GET("/webapi/sites/:site/assistant/ws", h.WithClusterAuthWS(true, h.assistant)) // Sets the title for the conversation. h.POST("/webapi/assistant/conversations/:conversation_id/title", h.WithAuth(h.setAssistantTitle)) @@ -908,7 +920,11 @@ func (h *Handler) bindDefaultEndpoints() { h.GET("/webapi/assistant/conversations/:conversation_id", h.WithAuth(h.getAssistantConversationByID)) // Allows executing an arbitrary command on multiple nodes. - h.GET("/webapi/command/:site/execute", h.WithClusterAuth(h.executeCommand)) + // Deprecated: The execute/ws variant should be used instead. + // TODO(lxea): DELETE in v16 + h.GET("/webapi/command/:site/execute", h.WithClusterAuthWS(false, h.executeCommand)) + // Allows executing an arbitrary command on multiple nodes, websocket auth. + h.GET("/webapi/command/:site/execute/ws", h.WithClusterAuthWS(true, h.executeCommand)) // Fetches the user's preferences h.GET("/webapi/user/preferences", h.WithAuth(h.getUserPreferences)) @@ -2941,6 +2957,7 @@ func (h *Handler) siteNodeConnect( p httprouter.Params, sessionCtx *SessionContext, site reversetunnelclient.RemoteSite, + ws *websocket.Conn, ) (interface{}, error) { q := r.URL.Query() params := q.Get("params") @@ -3033,6 +3050,7 @@ func (h *Handler) siteNodeConnect( PROXYSigner: h.cfg.PROXYSigner, Tracker: tracker, PresenceChecker: h.cfg.PresenceChecker, + WebsocketConn: ws, } term, err := NewTerminal(ctx, terminalConfig) @@ -3731,6 +3749,9 @@ type ContextHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Pa // ClusterHandler is a authenticated handler that is called for some existing remote cluster type ClusterHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) +// ClusterWebsocketHandler is a authenticated websocket handler that is called for some existing remote cluster +type ClusterWebsocketHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite, ws *websocket.Conn) (interface{}, error) + // WithClusterAuth wraps a ClusterHandler to ensure that a request is authenticated to this proxy // (the same as WithAuth), as well as to grab the remoteSite (which can represent this local cluster // or a remote trusted cluster) as specified by the ":site" url parameter. @@ -3745,12 +3766,72 @@ func (h *Handler) WithClusterAuth(fn ClusterHandler) httprouter.Handle { }) } +// WithClusterAuthWS wraps a ClusterWebsocketHandler to ensure that a request is authenticated +// to this proxy via websocket if websocketAuth is true, or via query parameter if false (the same as WithAuth), as +// well as to grab the remoteSite (which can represent this local cluster or a remote trusted cluster) +// as specified by the ":site" url parameter. +// +// TODO(lxea): remove the 'websocketAuth' bool once the deprecated websocket handlers are removed +func (h *Handler) WithClusterAuthWS(websocketAuth bool, fn ClusterWebsocketHandler) httprouter.Handle { + return httplib.MakeHandler(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) { + if websocketAuth { + sctx, ws, site, err := h.authenticateWSRequestWithCluster(w, r, p) + if err != nil { + return nil, trace.Wrap(err) + } + + defer ws.Close() + return fn(w, r, p, sctx, site, ws) + } + + sctx, site, err := h.authenticateRequestWithCluster(w, r, p) + if err != nil { + return nil, trace.Wrap(err) + } + upgrader := websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { return true }, + } + ws, err := upgrader.Upgrade(w, r, nil) + if err != nil { + const errMsg = "Error upgrading to websocket" + h.log.WithError(err).Error(errMsg) + http.Error(w, errMsg, http.StatusInternalServerError) + return nil, nil + } + + defer ws.Close() + return fn(w, r, p, sctx, site, ws) + }) +} + +// authenticateWSRequestWithCluster ensures that a request is +// authenticated to this proxy via websocket, returning the +// *SessionContext (same as AuthenticateRequest), and also grabs the +// remoteSite (which can represent this local cluster or a remote +// trusted cluster) as specified by the ":site" url parameter. +func (h *Handler) authenticateWSRequestWithCluster(w http.ResponseWriter, r *http.Request, p httprouter.Params) (*SessionContext, *websocket.Conn, reversetunnelclient.RemoteSite, error) { + sctx, ws, err := h.AuthenticateRequestWS(w, r) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + + site, err := h.getSiteByParams(sctx, p) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + + return sctx, ws, site, nil +} + // authenticateRequestWithCluster ensures that a request is authenticated // to this proxy, returning the *SessionContext (same as AuthenticateRequest), // and also grabs the remoteSite (which can represent this local cluster or a // remote trusted cluster) as specified by the ":site" url parameter. func (h *Handler) authenticateRequestWithCluster(w http.ResponseWriter, r *http.Request, p httprouter.Params) (*SessionContext, reversetunnelclient.RemoteSite, error) { sctx, err := h.AuthenticateRequest(w, r, true) + if err != nil { return nil, nil, trace.Wrap(err) } @@ -4068,9 +4149,7 @@ func rateLimitRequest(r *http.Request, limiter *limiter.RateLimiter) error { return trace.Wrap(err) } -// AuthenticateRequest authenticates request using combination of a session cookie -// and bearer token -func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, checkBearerToken bool) (*SessionContext, error) { +func (h *Handler) validateCookie(w http.ResponseWriter, r *http.Request) (*SessionContext, error) { const missingCookieMsg = "missing session cookie" cookie, err := r.Cookie(websession.CookieName) if err != nil || (cookie != nil && cookie.Value == "") { @@ -4085,6 +4164,17 @@ func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, ch clearSessionCookies((w)) return nil, trace.AccessDenied("need auth") } + + return sctx, nil +} + +// AuthenticateRequest authenticates request using combination of a session cookie +// and bearer token +func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, checkBearerToken bool) (*SessionContext, error) { + sctx, err := h.validateCookie(w, r) + if err != nil { + return nil, trace.Wrap(err) + } if checkBearerToken { creds, err := roundtrip.ParseAuthHeaders(r) if err != nil { @@ -4137,6 +4227,68 @@ func contextWithMFAResponseFromRequestHeader(ctx context.Context, requestHeader return ctx, nil } +type wsBearerToken struct { + Token string `json:"token"` +} + +type wsStatus struct { + Type string `json:"type"` + Status string `json:"status"` + Message string `json:"message,omitempty"` +} + +var wsIODeadline = time.Second * 4 + +// AuthenticateRequest authenticates request using combination of a session cookie +// and bearer token retrieved from a websocket +func (h *Handler) AuthenticateRequestWS(w http.ResponseWriter, r *http.Request) (*SessionContext, *websocket.Conn, error) { + ctx, err := h.validateCookie(w, r) + if err != nil { + return nil, nil, trace.Wrap(err) + } + upgrader := websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { return true }, + } + + ws, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return nil, nil, trace.ConnectionProblem(err, "Error upgrading to websocket: %v", err) + } + if err := ws.SetReadDeadline(deadlineForInterval(wsIODeadline)); err != nil { + log.WithError(err).Error("Error setting websocket read deadline") + return nil, nil, trace.ConnectionProblem(err, "Error setting websocket read deadline: %v", err) + } + + var t wsBearerToken + if err := ws.ReadJSON(&t); err != nil { + return nil, nil, trace.Wrap(err) + } + if err := ctx.validateBearerToken(r.Context(), t.Token); err != nil { + ws.WriteJSON(wsStatus{ + Type: "create_session_response", + Status: "error", + Message: "invalid token", + }) + return nil, nil, trace.Wrap(err) + } + + if err := ws.WriteJSON(wsStatus{ + Type: "create_session_response", + Status: "ok", + }); err != nil { + return nil, nil, trace.Wrap(err) + } + + if err := ws.SetReadDeadline(time.Time{}); err != nil { + log.WithError(err).Error("Error setting websocket read deadline") + return nil, nil, trace.ConnectionProblem(err, "Error setting websocket read deadline: %v", err) + } + + return ctx, ws, nil +} + // ProxyWithRoles returns a reverse tunnel proxy verifying the permissions // of the given user. func (h *Handler) ProxyWithRoles(ctx *SessionContext) (reversetunnelclient.Tunnel, error) { diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 680ec4b727493..780de3e3173e5 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -7349,6 +7349,102 @@ func (mock authProviderMock) GetRole(_ context.Context, _ string) (types.Role, e return nil, nil } +type terminalOpt func(t *TerminalRequest) + +func withSessionID(sid session.ID) terminalOpt { + return func(t *TerminalRequest) { t.SessionID = sid } +} + +func withServer(target string) terminalOpt { + return func(t *TerminalRequest) { t.Server = target } +} + +func withKeepaliveInterval(d time.Duration) terminalOpt { + return func(t *TerminalRequest) { t.KeepAliveInterval = d } +} + +func withParticipantMode(m types.SessionParticipantMode) terminalOpt { + return func(t *TerminalRequest) { t.ParticipantMode = m } +} + +func (s *WebSuite) makeTerminal(t *testing.T, pack *authPack, opts ...terminalOpt) (*websocket.Conn, *session.Session, error) { + req := TerminalRequest{ + Server: s.srvID, + Login: pack.login, + Term: session.TerminalParams{ + W: 100, + H: 100, + }, + } + for _, opt := range opts { + opt(&req) + } + + u := url.URL{ + Host: s.url().Host, + Scheme: client.WSS, + Path: fmt.Sprintf("/v1/webapi/sites/%v/connect/ws", currentSiteShortcut), + } + data, err := json.Marshal(req) + if err != nil { + return nil, nil, err + } + + q := u.Query() + q.Set("params", string(data)) + u.RawQuery = q.Encode() + + dialer := websocket.Dialer{} + dialer.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, + } + + header := http.Header{} + header.Add("Origin", "http://localhost") + for _, cookie := range pack.cookies { + header.Add("Cookie", cookie.String()) + } + + ws, resp, err := dialer.Dial(u.String(), header) + if err != nil { + var sb strings.Builder + sb.WriteString("websocket dial") + if resp != nil { + fmt.Fprintf(&sb, "; status code %v;", resp.StatusCode) + fmt.Fprintf(&sb, "headers: %v; body: ", resp.Header) + io.Copy(&sb, resp.Body) + } + return nil, nil, trace.Wrap(err, sb.String()) + } + makeAuthReqOverWS(t, ws, pack.session.Token) + + ty, raw, err := ws.ReadMessage() + if err != nil { + return nil, nil, trace.Wrap(err) + } + require.Equal(t, websocket.BinaryMessage, ty) + var env Envelope + + err = proto.Unmarshal(raw, &env) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + var sessResp siteSessionGenerateResponse + + err = json.Unmarshal([]byte(env.Payload), &sessResp) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + err = resp.Body.Close() + if err != nil { + return nil, nil, trace.Wrap(err) + } + + return ws, &sessResp.Session, nil +} + func waitForOutputWithDuration(r io.Reader, substr string, timeout time.Duration) error { timeoutCh := time.After(timeout) @@ -8118,18 +8214,91 @@ func (r *testProxy) newClient(t *testing.T, opts ...roundtrip.ClientParam) *Test return &TestWebClient{clt, t} } +func (r *testProxy) makeTerminal(t *testing.T, pack *authPack, sessionID session.ID) (*websocket.Conn, session.Session) { + u := url.URL{ + Host: r.webURL.Host, + Scheme: client.WSS, + Path: fmt.Sprintf("/v1/webapi/sites/%v/connect/ws", currentSiteShortcut), + } + + requestData := TerminalRequest{ + Server: r.node.ID(), + Login: pack.login, + Term: session.TerminalParams{ + W: 100, + H: 100, + }, + } + + if sessionID != "" { + requestData.SessionID = sessionID + } + + data, err := json.Marshal(requestData) + require.NoError(t, err) + + q := u.Query() + q.Set("params", string(data)) + u.RawQuery = q.Encode() + + dialer := websocket.Dialer{} + dialer.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, + } + + header := http.Header{} + header.Add("Origin", "http://localhost") + for _, cookie := range pack.cookies { + header.Add("Cookie", cookie.String()) + } + + ws, resp, err := dialer.Dial(u.String(), header) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, ws.Close()) + require.NoError(t, resp.Body.Close()) + }) + + makeAuthReqOverWS(t, ws, pack.session.Token) + + ty, raw, err := ws.ReadMessage() + require.NoError(t, err) + require.Equal(t, websocket.BinaryMessage, ty) + var env Envelope + require.NoError(t, proto.Unmarshal(raw, &env)) + + var sessResp siteSessionGenerateResponse + require.NoError(t, json.Unmarshal([]byte(env.Payload), &sessResp)) + + return ws, sessResp.Session +} + +func makeAuthReqOverWS(t *testing.T, ws *websocket.Conn, token string) { + t.Helper() + authReq, err := json.Marshal(struct { + Token string `json:"token"` + }{Token: token}) + require.NoError(t, err) + + err = ws.WriteMessage(websocket.TextMessage, authReq) + require.NoError(t, err) + + _, authRes, err := ws.ReadMessage() + require.NoError(t, err) + require.Contains(t, string(authRes), `"status":"ok"`) +} + func (r *testProxy) makeDesktopSession(t *testing.T, pack *authPack, sessionID session.ID, addr net.Addr) *websocket.Conn { u := url.URL{ Host: r.webURL.Host, Scheme: client.WSS, - Path: fmt.Sprintf("/webapi/sites/%s/desktops/%s/connect", currentSiteShortcut, "desktop1"), + Path: fmt.Sprintf("/webapi/sites/%s/desktops/%s/connect/ws", currentSiteShortcut, "desktop1"), } q := u.Query() q.Set("username", "marek") q.Set("width", "100") q.Set("height", "100") - q.Set(roundtrip.AccessTokenQueryParam, pack.session.Token) u.RawQuery = q.Encode() dialer := websocket.Dialer{} @@ -8144,6 +8313,9 @@ func (r *testProxy) makeDesktopSession(t *testing.T, pack *authPack, sessionID s ws, resp, err := dialer.Dial(u.String(), header) require.NoError(t, err) + + makeAuthReqOverWS(t, ws, pack.session.Token) + t.Cleanup(func() { require.NoError(t, ws.Close()) require.NoError(t, resp.Body.Close()) @@ -9135,6 +9307,109 @@ func (s *fakeKubeService) ListKubernetesResources(ctx context.Context, req *kube }, nil } +func TestWSAuthenticateRequest(t *testing.T) { + t.Parallel() + ctx := context.Background() + env := newWebPack(t, 1) + proxy := env.proxies[0] + pack := proxy.authPack(t, "test-user@example.com", nil) + wsIODeadline = time.Second + + for _, tc := range []struct { + name string + serverExpectError string + expectResponse wsStatus + token string + writeTimeout func() + readTimeout func() + }{ + { + name: "valid token", + expectResponse: wsStatus{ + Type: "create_session_response", + Status: "ok", + }, + token: pack.session.Token, + }, + { + name: "invalid token", + serverExpectError: "not found", + expectResponse: wsStatus{ + Type: "create_session_response", + Status: "error", + Message: "invalid token", + }, + token: "honk", + }, + { + name: "server read timeout", + serverExpectError: "i/o timeout", + token: pack.session.Token, + readTimeout: func() { + <-time.After(time.Second * 3) + }, + }, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sctx, ws, err := proxy.handler.handler.AuthenticateRequestWS(w, r) + if err != nil { + if !strings.Contains(err.Error(), tc.serverExpectError) { + t.Errorf("unexpected error: %v", err) + return + } + return + } + defer ws.Close() + if err == nil && tc.serverExpectError != "" { + t.Errorf("expected error, got nil") + return + } + + clt, err := sctx.GetClient() + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + _, err = clt.GetDomainName(ctx) + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + })) + + header := http.Header{} + for _, cookie := range pack.cookies { + header.Add("Cookie", cookie.String()) + } + + u := strings.Replace(server.URL, "http:", "ws:", 1) + conn, resp, err := websocket.DefaultDialer.Dial(u, header) + require.NoError(t, err) + defer conn.Close() + defer resp.Body.Close() + + if tc.readTimeout != nil { + tc.readTimeout() + } + err = conn.WriteJSON(wsBearerToken{ + Token: tc.token, + }) + require.NoError(t, err) + if tc.readTimeout != nil { + return // Reading will fail as the server will have closed the connection + } + + var status wsStatus + err = conn.ReadJSON(&status) + require.NoError(t, err) + require.Equal(t, tc.expectResponse, status) + }) + } +} + // TestSimultaneousAuthenticateRequest ensures that multiple authenticated // requests do not race to create a SessionContext. This would happen when // Proxies were deployed behind a round-robin load balancer. Only the Proxy diff --git a/lib/web/assistant.go b/lib/web/assistant.go index fca5fde9c4e19..da1caae195a17 100644 --- a/lib/web/assistant.go +++ b/lib/web/assistant.go @@ -332,9 +332,9 @@ func (h *Handler) generateAssistantTitle(_ http.ResponseWriter, r *http.Request, // This handler covers the main chat conversation as well as the // SSH completition (SSH command generation and output explanation). func (h *Handler) assistant(w http.ResponseWriter, r *http.Request, _ httprouter.Params, - sctx *SessionContext, site reversetunnelclient.RemoteSite, + sctx *SessionContext, site reversetunnelclient.RemoteSite, ws *websocket.Conn, ) (any, error) { - if err := runAssistant(h, w, r, sctx, site); err != nil { + if err := runAssistant(h, w, r, sctx, site, ws); err != nil { h.log.Warn(trace.DebugReport(err)) return nil, trace.Wrap(err) } @@ -420,7 +420,7 @@ func checkAssistEnabled(a auth.ClientI, ctx context.Context) error { // runAssistant upgrades the HTTP connection to a websocket and starts a chat loop. func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, - sctx *SessionContext, site reversetunnelclient.RemoteSite, + sctx *SessionContext, site reversetunnelclient.RemoteSite, ws *websocket.Conn, ) (err error) { q := r.URL.Query() conversationID := q.Get("conversation_id") @@ -455,20 +455,6 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, return trace.Wrap(err) } - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { return true }, - } - - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - errMsg := "Error upgrading to websocket" - h.log.WithError(err).Error(errMsg) - http.Error(w, errMsg, http.StatusInternalServerError) - return nil - } - // Note: This time should be longer than OpenAI response time. keepAliveInterval := netConfig.GetKeepAliveInterval() err = ws.SetReadDeadline(deadlineForInterval(keepAliveInterval)) diff --git a/lib/web/command.go b/lib/web/command.go index 9017b63ed1b34..bae2a66247b04 100644 --- a/lib/web/command.go +++ b/lib/web/command.go @@ -128,6 +128,7 @@ func (h *Handler) executeCommand( _ httprouter.Params, sessionCtx *SessionContext, site reversetunnelclient.RemoteSite, + rawWS *websocket.Conn, ) (any, error) { q := r.URL.Query() params := q.Get("params") @@ -171,20 +172,6 @@ func (h *Handler) executeCommand( clusterName := site.GetName() - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { return true }, - } - - rawWS, err := upgrader.Upgrade(w, r, nil) - if err != nil { - errMsg := "Error upgrading to websocket" - h.log.WithError(err).Error(errMsg) - http.Error(w, errMsg, http.StatusInternalServerError) - return nil, nil - } - defer func() { rawWS.WriteMessage(websocket.CloseMessage, nil) rawWS.Close() diff --git a/lib/web/desktop.go b/lib/web/desktop.go index a89bcf01b17c3..abd663d897c1c 100644 --- a/lib/web/desktop.go +++ b/lib/web/desktop.go @@ -66,6 +66,7 @@ func (h *Handler) desktopConnectHandle( p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite, + ws *websocket.Conn, ) (interface{}, error) { desktopName := p.ByName("desktopName") if desktopName == "" { @@ -75,7 +76,7 @@ func (h *Handler) desktopConnectHandle( log := sctx.cfg.Log.WithField("desktop-name", desktopName).WithField("cluster-name", site.GetName()) log.Debug("New desktop access websocket connection") - if err := h.createDesktopConnection(w, r, desktopName, site.GetName(), log, sctx, site); err != nil { + if err := h.createDesktopConnection(w, r, desktopName, site.GetName(), log, sctx, site, ws); err != nil { // createDesktopConnection makes a best effort attempt to send an error to the user // (via websocket) before terminating the connection. We log the error here, but // return nil because our HTTP middleware will try to write the returned error in JSON @@ -94,15 +95,8 @@ func (h *Handler) createDesktopConnection( log *logrus.Entry, sctx *SessionContext, site reversetunnelclient.RemoteSite, + ws *websocket.Conn, ) error { - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - } - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return trace.Wrap(err) - } defer ws.Close() sendTDPError := func(err error) error { diff --git a/lib/web/terminal.go b/lib/web/terminal.go index f8c6ebb17882f..3c40c7a18c15c 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -142,6 +142,7 @@ func NewTerminal(ctx context.Context, cfg TerminalHandlerConfig) (*TerminalHandl participantMode: cfg.ParticipantMode, tracker: cfg.Tracker, presenceChecker: cfg.PresenceChecker, + websocketConn: cfg.WebsocketConn, }, nil } @@ -191,6 +192,8 @@ type TerminalHandlerConfig struct { PresenceChecker PresenceChecker // Clock allows interaction with time. Clock clockwork.Clock + // WebsocketConn is the active websocket connection + WebsocketConn *websocket.Conn } func (t *TerminalHandlerConfig) CheckAndSetDefaults() error { @@ -317,6 +320,9 @@ type TerminalHandler struct { // clock used to interact with time. clock clockwork.Clock + + // websocketConn is the active websocket connection + websocketConn *websocket.Conn } // ServeHTTP builds a connection to the remote node and then pumps back two types of @@ -328,21 +334,9 @@ func (t *TerminalHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { t.ctx.AddClosers(t) defer t.ctx.RemoveCloser(t) - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { return true }, - } - - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - errMsg := "Error upgrading to websocket" - t.log.WithError(err).Error(errMsg) - http.Error(w, errMsg, http.StatusInternalServerError) - return - } + ws := t.websocketConn - err = ws.SetReadDeadline(deadlineForInterval(t.keepAliveInterval)) + err := ws.SetReadDeadline(deadlineForInterval(t.keepAliveInterval)) if err != nil { t.log.WithError(err).Error("Error setting websocket readline") return diff --git a/web/packages/teleport/src/Assist/context/AssistContext.tsx b/web/packages/teleport/src/Assist/context/AssistContext.tsx index 15dbf01f5068d..ec9dbe7d921b5 100644 --- a/web/packages/teleport/src/Assist/context/AssistContext.tsx +++ b/web/packages/teleport/src/Assist/context/AssistContext.tsx @@ -33,6 +33,8 @@ import useStickyClusterId from 'teleport/useStickyClusterId'; import cfg from 'teleport/config'; import { getAccessToken, getHostName } from 'teleport/services/api'; +import { WebsocketStatus } from 'teleport/types'; + import { AccessRequestClientMessage, ExecutionEnvelopeType, @@ -128,7 +130,6 @@ export function AssistContextProvider(props: PropsWithChildren) { cfg.getAssistConversationWebSocketUrl( getHostName(), clusterId, - getAccessToken(), conversationId ) ); @@ -142,11 +143,13 @@ export function AssistContextProvider(props: PropsWithChildren) { ); activeWebSocket.current.onopen = () => { - if (initialMessage) { - activeWebSocket.current.send(initialMessage); - } + executeCommandWebSocket.current.send( + JSON.stringify({ token: getAccessToken() }) + ); }; + let authenticated = false; + activeWebSocket.current.onclose = () => { dispatch({ type: AssistStateActionType.SetStreaming, @@ -154,9 +157,29 @@ export function AssistContextProvider(props: PropsWithChildren) { }); }; - activeWebSocket.current.onmessage = async event => { - const data = JSON.parse(event.data) as ServerMessage; + executeCommandWebSocket.current.onmessage = event => { + if (!authenticated) { + const authResponse = JSON.parse(event.data) as WebsocketStatus; + if (authResponse.type != 'create_session_response') { + executeCommandWebSocket.current.close(); + console.log('invalid auth response type: ' + authResponse.message); + return; + } + + if (authResponse.status == 'error') { + executeCommandWebSocket.current.close(); + console.log( + 'auth error connecting to websocket: ' + authResponse.message + ); + return; + } + authenticated = true; + if (initialMessage) { + activeWebSocket.current.send(initialMessage); + } + } + const data = JSON.parse(event.data) as ServerMessage; switch (data.type) { case ServerMessageType.Assist: dispatch({ @@ -448,15 +471,41 @@ export function AssistContextProvider(props: PropsWithChildren) { const url = cfg.getAssistExecuteCommandUrl( getHostName(), clusterId, - getAccessToken(), execParams ); - const proto = new Protobuf(); executeCommandWebSocket.current = new WebSocket(url); - executeCommandWebSocket.current.binaryType = 'arraybuffer'; + executeCommandWebSocket.current.onopen = () => { + executeCommandWebSocket.current.send( + JSON.stringify({ token: getAccessToken() }) + ); + }; + + let authenticated = false; + + const proto = new Protobuf(); executeCommandWebSocket.current.onmessage = event => { + if (!authenticated) { + const authResponse = JSON.parse(event.data) as WebsocketStatus; + if (authResponse.type != 'create_session_response') { + executeCommandWebSocket.current.close(); + console.log('invalid auth response type: ' + authResponse.message); + return; + } + + if (authResponse.status == 'error') { + executeCommandWebSocket.current.close(); + console.log( + 'auth error connecting to websocket: ' + authResponse.message + ); + return; + } + authenticated = true; + return; + } + + executeCommandWebSocket.current.binaryType = 'arraybuffer'; const uintArray = new Uint8Array(event.data); const msg = proto.decode(uintArray); @@ -535,6 +584,7 @@ export function AssistContextProvider(props: PropsWithChildren) { executeCommandWebSocket.current.onclose = () => { executeCommandWebSocket.current = null; + authenticated = false; // If the execution failed, we won't get a SESSION_END message, so we // need to mark all the results as finished here. diff --git a/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx b/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx index 9bcd78ce43027..4b00c571aff69 100644 --- a/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx +++ b/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx @@ -61,7 +61,6 @@ export function TerminalAssistContextProvider( const socketUrl = cfg.getAssistActionWebSocketUrl( getHostName(), clusterId, - getAccessToken(), 'ssh-cmdgen' ); @@ -74,8 +73,23 @@ export function TerminalAssistContextProvider( useEffect(() => { socketRef.current = new WebSocket(socketUrl); + socketRef.current.onopen = () => { + socketRef.current.send(JSON.stringify({ token: getAccessToken() })); + }; + socketRef.current.onmessage = e => { - const data = JSON.parse(e.data) as ServerMessage; + const resData = JSON.parse(e.data); + + if (resData.type === 'create_session_response') { + if (resData.status == 'error') { + socketRef.current.close(); + console.log('auth error connecting to websocket: ' + resData.message); + return; + } + return; + } + + const data = resData as ServerMessage; const payload = JSON.parse(data.payload) as { action: string; input: string; @@ -117,19 +131,30 @@ export function TerminalAssistContextProvider( const socketUrl = cfg.getAssistActionWebSocketUrl( getHostName(), clusterId, - getAccessToken(), 'ssh-explain' ); const ws = new WebSocket(socketUrl); ws.onopen = () => { - ws.send(encodedOutput); + socketRef.current.send(JSON.stringify({ token: getAccessToken() })); }; ws.onmessage = event => { const message = event.data; - const msg = JSON.parse(message) as ServerMessage; + const resMsg = JSON.parse(message); + + if (resMsg.type === 'create_session_response') { + if (resMsg.status == 'error') { + socketRef.current.close(); + console.log('auth error connecting to websocket: ' + resMsg.message); + return; + } + ws.send(encodedOutput); + return; + } + + const msg = resMsg as ServerMessage; const explanation: ExplanationMessage = { author: Author.Teleport, diff --git a/web/packages/teleport/src/Console/consoleContext.tsx b/web/packages/teleport/src/Console/consoleContext.tsx index e696f18602e0e..76ba3013d7cac 100644 --- a/web/packages/teleport/src/Console/consoleContext.tsx +++ b/web/packages/teleport/src/Console/consoleContext.tsx @@ -24,7 +24,7 @@ import { W3CTraceContextPropagator } from '@opentelemetry/core'; import webSession from 'teleport/services/websession'; import history from 'teleport/services/history'; import cfg, { UrlResourcesParams, UrlSshParams } from 'teleport/config'; -import { getAccessToken, getHostName } from 'teleport/services/api'; +import { getHostName } from 'teleport/services/api'; import Tty from 'teleport/lib/term/tty'; import TtyAddressResolver from 'teleport/lib/term/ttyAddressResolver'; import serviceSession, { @@ -197,7 +197,6 @@ export default class ConsoleContext { const ttyUrl = cfg.api.ttyWsAddr .replace(':fqdn', getHostName()) - .replace(':token', getAccessToken()) .replace(':clusterId', clusterId) .replace(':traceparent', carrier['traceparent']); diff --git a/web/packages/teleport/src/DesktopSession/useTdpClientCanvas.tsx b/web/packages/teleport/src/DesktopSession/useTdpClientCanvas.tsx index 74153145af2b9..e5f3c986a3f21 100644 --- a/web/packages/teleport/src/DesktopSession/useTdpClientCanvas.tsx +++ b/web/packages/teleport/src/DesktopSession/useTdpClientCanvas.tsx @@ -30,7 +30,7 @@ import { PngFrame, SyncKeys, } from 'teleport/lib/tdp/codec'; -import { getAccessToken, getHostName } from 'teleport/services/api'; +import { getHostName } from 'teleport/services/api'; import cfg from 'teleport/config'; import { Sha256Digest } from 'teleport/lib/util'; @@ -85,7 +85,6 @@ export default function useTdpClientCanvas(props: Props) { .replace(':fqdn', getHostName()) .replace(':clusterId', clusterId) .replace(':desktopName', desktopName) - .replace(':token', getAccessToken()) .replace(':username', username); setTdpClient(new TdpClient(addr)); diff --git a/web/packages/teleport/src/Player/DesktopPlayer.tsx b/web/packages/teleport/src/Player/DesktopPlayer.tsx index c819d7e6581c8..b56e8b51ac0b9 100644 --- a/web/packages/teleport/src/Player/DesktopPlayer.tsx +++ b/web/packages/teleport/src/Player/DesktopPlayer.tsx @@ -23,7 +23,7 @@ import { Indicator, Box, Alert, Flex } from 'design'; import cfg from 'teleport/config'; import { StatusEnum, formatDisplayTime } from 'teleport/lib/player'; import { PlayerClient, TdpClient } from 'teleport/lib/tdp'; -import { getAccessToken, getHostName } from 'teleport/services/api'; +import { getHostName } from 'teleport/services/api'; import TdpClientCanvas from 'teleport/components/TdpClientCanvas'; import ProgressBar from './ProgressBar'; @@ -158,7 +158,6 @@ const useDesktopPlayer = ({ clusterId, sid }) => { .replace(':fqdn', getHostName()) .replace(':clusterId', clusterId) .replace(':sid', sid) - .replace(':token', getAccessToken()); return new PlayerClient({ url, setTime, setPlayerStatus, setStatusText }); }, [clusterId, sid, setTime, setPlayerStatus]); diff --git a/web/packages/teleport/src/config.ts b/web/packages/teleport/src/config.ts index ac3521c95bb86..99d3f7ab0a22b 100644 --- a/web/packages/teleport/src/config.ts +++ b/web/packages/teleport/src/config.ts @@ -196,12 +196,12 @@ const cfg = { desktopServicesPath: `/v1/webapi/sites/:clusterId/desktopservices?searchAsRoles=:searchAsRoles?&limit=:limit?&startKey=:startKey?&query=:query?&search=:search?&sort=:sort?`, desktopPath: `/v1/webapi/sites/:clusterId/desktops/:desktopName`, desktopWsAddr: - 'wss://:fqdn/v1/webapi/sites/:clusterId/desktops/:desktopName/connect?access_token=:token&username=:username', + 'wss://:fqdn/v1/webapi/sites/:clusterId/desktops/:desktopName/connect/ws?username=:username', desktopPlaybackWsAddr: - 'wss://:fqdn/v1/webapi/sites/:clusterId/desktopplayback/:sid?access_token=:token', + 'wss://:fqdn/v1/webapi/sites/:clusterId/desktopplayback/:sid/ws', desktopIsActive: '/v1/webapi/sites/:clusterId/desktops/:desktopName/active', ttyWsAddr: - 'wss://:fqdn/v1/webapi/sites/:clusterId/connect?access_token=:token¶ms=:params&traceparent=:traceparent', + 'wss://:fqdn/v1/webapi/sites/:clusterId/connect?params=:params&traceparent=:traceparent', ttyPlaybackWsAddr: 'wss://:fqdn/v1/webapi/sites/:clusterId/ttyplayback/:sid?access_token=:token', // TODO(zmb3): get token out of URL activeAndPendingSessionsPath: '/v1/webapi/sites/:clusterId/sessions', @@ -310,11 +310,11 @@ const cfg = { '/v1/webapi/assistant/conversations/:conversationId/title', assistGenerateSummaryPath: '/v1/webapi/assistant/title/summary', assistConversationWebSocketPath: - 'wss://:hostname/v1/webapi/sites/:clusterId/assistant', + 'wss://:hostname/v1/webapi/sites/:clusterId/assistant/ws', assistConversationHistoryPath: '/v1/webapi/assistant/conversations/:conversationId', assistExecuteCommandWebSocketPath: - 'wss://:hostname/v1/webapi/command/:clusterId/execute', + 'wss://:hostname/v1/webapi/command/:clusterId/execute/ws', userPreferencesPath: '/v1/webapi/user/preferences', userClusterPreferencesPath: '/v1/webapi/user/preferences/:clusterId', @@ -857,12 +857,10 @@ const cfg = { getAssistConversationWebSocketUrl( hostname: string, clusterId: string, - accessToken: string, conversationId: string ) { const searchParams = new URLSearchParams(); - searchParams.set('access_token', accessToken); searchParams.set('conversation_id', conversationId); return ( @@ -876,12 +874,10 @@ const cfg = { getAssistActionWebSocketUrl( hostname: string, clusterId: string, - accessToken: string, action: string ) { const searchParams = new URLSearchParams(); - searchParams.set('access_token', accessToken); searchParams.set('action', action); return ( @@ -901,12 +897,10 @@ const cfg = { getAssistExecuteCommandUrl( hostname: string, clusterId: string, - accessToken: string, params: Record ) { const searchParams = new URLSearchParams(); - searchParams.set('access_token', accessToken); searchParams.set('params', JSON.stringify(params)); return ( diff --git a/web/packages/teleport/src/lib/tdp/client.ts b/web/packages/teleport/src/lib/tdp/client.ts index f2584c95eadea..5955cc8e3a1a7 100644 --- a/web/packages/teleport/src/lib/tdp/client.ts +++ b/web/packages/teleport/src/lib/tdp/client.ts @@ -25,6 +25,9 @@ import init, { import { WebsocketCloseCode, TermEvent } from 'teleport/lib/term/enums'; import { EventEmitterWebAuthnSender } from 'teleport/lib/EventEmitterWebAuthnSender'; +import { WebsocketStatus } from 'teleport/types'; + +import { getAccessToken } from 'teleport/services/api'; import Codec, { MessageType, @@ -119,13 +122,34 @@ export default class Client extends EventEmitterWebAuthnSender { this.socket.onopen = () => { this.logger.info('websocket is open'); + this.socket.send(JSON.stringify({ token: getAccessToken() })); this.emit(TdpClientEvent.WS_OPEN); if (spec) { this.sendClientScreenSpec(spec); } }; + let authenticated = false; + this.socket.onmessage = async (ev: MessageEvent) => { + if (!authenticated) { + const authResponse = JSON.parse(ev.data) as WebsocketStatus; + if (authResponse.type != 'create_session_response') { + this.socket.close(); + console.log('invalid auth response type: ' + authResponse.message); + return; + } + + if (authResponse.status == 'error') { + this.socket.close(); + console.log( + 'auth error connecting to websocket: ' + authResponse.message + ); + return; + } + authenticated = true; + return; + } await this.processMessage(ev.data as ArrayBuffer); }; @@ -135,7 +159,7 @@ export default class Client extends EventEmitterWebAuthnSender { this.socket.onerror = null; this.socket.onclose = () => { this.logger.info('websocket is closed'); - + authenticated = false; // Clean up all of our socket's listeners and the socket itself. this.socket.onopen = null; this.socket.onmessage = null; diff --git a/web/packages/teleport/src/lib/term/tty.ts b/web/packages/teleport/src/lib/term/tty.ts index 88a18bcc5d246..373668d5b6cde 100644 --- a/web/packages/teleport/src/lib/term/tty.ts +++ b/web/packages/teleport/src/lib/term/tty.ts @@ -19,7 +19,9 @@ import Logger from 'shared/libs/logger'; import { EventEmitterWebAuthnSender } from 'teleport/lib/EventEmitterWebAuthnSender'; +import { getAccessToken } from 'teleport/services/api'; import { WebauthnAssertionResponse } from 'teleport/services/auth'; +import { WebsocketStatus } from 'teleport/types'; import { EventType, TermEvent, WebsocketCloseCode } from './enums'; import { Protobuf, MessageTypeEnum } from './protobuf'; @@ -39,6 +41,7 @@ class Tty extends EventEmitterWebAuthnSender { _addressResolver = null; _proto = new Protobuf(); _pendingUploads = {}; + _authenticated = false; constructor(addressResolver, props = {}) { super(); @@ -52,6 +55,7 @@ class Tty extends EventEmitterWebAuthnSender { this._onOpenConnection = this._onOpenConnection.bind(this); this._onCloseConnection = this._onCloseConnection.bind(this); this._onMessage = this._onMessage.bind(this); + this._onAuthMessage = this._onAuthMessage.bind(this); } disconnect(closeCode = WebsocketCloseCode.NORMAL) { @@ -65,7 +69,7 @@ class Tty extends EventEmitterWebAuthnSender { this.socket = new WebSocket(connStr); this.socket.binaryType = 'arraybuffer'; this.socket.onopen = this._onOpenConnection; - this.socket.onmessage = this._onMessage; + this.socket.onmessage = this._onAuthMessage; this.socket.onclose = this._onCloseConnection; } @@ -166,6 +170,7 @@ class Tty extends EventEmitterWebAuthnSender { _onOpenConnection() { this.emit('open'); logger.info('websocket is open'); + this.socket.send(JSON.stringify({ token: getAccessToken() })); } _onCloseConnection(e) { @@ -174,10 +179,32 @@ class Tty extends EventEmitterWebAuthnSender { this.socket.onclose = null; this.socket = null; this.emit(TermEvent.CONN_CLOSE, e); + this._authenticated = false; logger.info('websocket is closed'); } + _onAuthMessage(ev) { + const authResponse = JSON.parse(ev.data) as WebsocketStatus; + if (authResponse.type != 'create_session_response') { + this.socket.close(); + console.log('invalid auth response type: ' + authResponse.message); + return; + } + if (authResponse.status == 'error') { + this.socket.close(); + console.log( + 'auth error connecting to websocket: ' + authResponse.message + ); + return; + } + this._authenticated = true; + } + _onMessage(ev) { + if (!this._authenticated) { + this._onAuthMessage(ev); + return; + } try { const uintArray = new Uint8Array(ev.data); const msg = this._proto.decode(uintArray); diff --git a/web/packages/teleport/src/types.ts b/web/packages/teleport/src/types.ts index 3aabdcec6a52a..02c52361d9791 100644 --- a/web/packages/teleport/src/types.ts +++ b/web/packages/teleport/src/types.ts @@ -197,3 +197,11 @@ export enum RecommendationStatus { Notify = 'NOTIFY', Done = 'DONE', } + +// WebsocketStatus is used to indicate the auth status from a +// websocket connection +export type WebsocketStatus = { + type: string; + status: string; + message: string; +}; From b39fd47ca8356bffb610a8ee1699a09c7dc54662 Mon Sep 17 00:00:00 2001 From: Alex McGrath Date: Tue, 30 Jan 2024 15:33:52 +0000 Subject: [PATCH 02/17] update connectToHost --- lib/web/apiserver_test.go | 162 ++++---------------------------------- lib/web/terminal_test.go | 8 +- 2 files changed, 21 insertions(+), 149 deletions(-) diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 780de3e3173e5..e6e7dafd2c092 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -7367,84 +7367,6 @@ func withParticipantMode(m types.SessionParticipantMode) terminalOpt { return func(t *TerminalRequest) { t.ParticipantMode = m } } -func (s *WebSuite) makeTerminal(t *testing.T, pack *authPack, opts ...terminalOpt) (*websocket.Conn, *session.Session, error) { - req := TerminalRequest{ - Server: s.srvID, - Login: pack.login, - Term: session.TerminalParams{ - W: 100, - H: 100, - }, - } - for _, opt := range opts { - opt(&req) - } - - u := url.URL{ - Host: s.url().Host, - Scheme: client.WSS, - Path: fmt.Sprintf("/v1/webapi/sites/%v/connect/ws", currentSiteShortcut), - } - data, err := json.Marshal(req) - if err != nil { - return nil, nil, err - } - - q := u.Query() - q.Set("params", string(data)) - u.RawQuery = q.Encode() - - dialer := websocket.Dialer{} - dialer.TLSClientConfig = &tls.Config{ - InsecureSkipVerify: true, - } - - header := http.Header{} - header.Add("Origin", "http://localhost") - for _, cookie := range pack.cookies { - header.Add("Cookie", cookie.String()) - } - - ws, resp, err := dialer.Dial(u.String(), header) - if err != nil { - var sb strings.Builder - sb.WriteString("websocket dial") - if resp != nil { - fmt.Fprintf(&sb, "; status code %v;", resp.StatusCode) - fmt.Fprintf(&sb, "headers: %v; body: ", resp.Header) - io.Copy(&sb, resp.Body) - } - return nil, nil, trace.Wrap(err, sb.String()) - } - makeAuthReqOverWS(t, ws, pack.session.Token) - - ty, raw, err := ws.ReadMessage() - if err != nil { - return nil, nil, trace.Wrap(err) - } - require.Equal(t, websocket.BinaryMessage, ty) - var env Envelope - - err = proto.Unmarshal(raw, &env) - if err != nil { - return nil, nil, trace.Wrap(err) - } - - var sessResp siteSessionGenerateResponse - - err = json.Unmarshal([]byte(env.Payload), &sessResp) - if err != nil { - return nil, nil, trace.Wrap(err) - } - - err = resp.Body.Close() - if err != nil { - return nil, nil, trace.Wrap(err) - } - - return ws, &sessResp.Session, nil -} - func waitForOutputWithDuration(r io.Reader, substr string, timeout time.Duration) error { timeoutCh := time.After(timeout) @@ -8214,78 +8136,25 @@ func (r *testProxy) newClient(t *testing.T, opts ...roundtrip.ClientParam) *Test return &TestWebClient{clt, t} } -func (r *testProxy) makeTerminal(t *testing.T, pack *authPack, sessionID session.ID) (*websocket.Conn, session.Session) { - u := url.URL{ - Host: r.webURL.Host, - Scheme: client.WSS, - Path: fmt.Sprintf("/v1/webapi/sites/%v/connect/ws", currentSiteShortcut), - } - - requestData := TerminalRequest{ - Server: r.node.ID(), - Login: pack.login, - Term: session.TerminalParams{ - W: 100, - H: 100, - }, - } - - if sessionID != "" { - requestData.SessionID = sessionID - } - - data, err := json.Marshal(requestData) - require.NoError(t, err) - - q := u.Query() - q.Set("params", string(data)) - u.RawQuery = q.Encode() - - dialer := websocket.Dialer{} - dialer.TLSClientConfig = &tls.Config{ - InsecureSkipVerify: true, - } - - header := http.Header{} - header.Add("Origin", "http://localhost") - for _, cookie := range pack.cookies { - header.Add("Cookie", cookie.String()) - } - - ws, resp, err := dialer.Dial(u.String(), header) - require.NoError(t, err) - t.Cleanup(func() { - require.NoError(t, ws.Close()) - require.NoError(t, resp.Body.Close()) - }) - - makeAuthReqOverWS(t, ws, pack.session.Token) - - ty, raw, err := ws.ReadMessage() - require.NoError(t, err) - require.Equal(t, websocket.BinaryMessage, ty) - var env Envelope - require.NoError(t, proto.Unmarshal(raw, &env)) - - var sessResp siteSessionGenerateResponse - require.NoError(t, json.Unmarshal([]byte(env.Payload), &sessResp)) - - return ws, sessResp.Session -} - -func makeAuthReqOverWS(t *testing.T, ws *websocket.Conn, token string) { - t.Helper() +func makeAuthReqOverWS(ws *websocket.Conn, token string) error { authReq, err := json.Marshal(struct { Token string `json:"token"` }{Token: token}) - require.NoError(t, err) - - err = ws.WriteMessage(websocket.TextMessage, authReq) - require.NoError(t, err) + if err != nil { + return trace.Wrap(err) + } + if err := ws.WriteMessage(websocket.TextMessage, authReq); err != nil { + return trace.Wrap(err) + } _, authRes, err := ws.ReadMessage() - require.NoError(t, err) - require.Contains(t, string(authRes), `"status":"ok"`) + if err != nil { + return trace.Wrap(err) + } + if !strings.Contains(string(authRes), `"status":"ok"`) { + return trace.AccessDenied("unexpected response") + } + return nil } func (r *testProxy) makeDesktopSession(t *testing.T, pack *authPack, sessionID session.ID, addr net.Addr) *websocket.Conn { @@ -8314,7 +8183,8 @@ func (r *testProxy) makeDesktopSession(t *testing.T, pack *authPack, sessionID s ws, resp, err := dialer.Dial(u.String(), header) require.NoError(t, err) - makeAuthReqOverWS(t, ws, pack.session.Token) + err = makeAuthReqOverWS(ws, pack.session.Token) + require.NoError(t, err) t.Cleanup(func() { require.NoError(t, ws.Close()) diff --git a/lib/web/terminal_test.go b/lib/web/terminal_test.go index ce8996308ed53..2e7edf5150f6f 100644 --- a/lib/web/terminal_test.go +++ b/lib/web/terminal_test.go @@ -33,7 +33,6 @@ import ( "github.com/gogo/protobuf/proto" "github.com/gorilla/websocket" - "github.com/gravitational/roundtrip" "github.com/gravitational/trace" "github.com/stretchr/testify/require" @@ -127,12 +126,11 @@ func connectToHost(ctx context.Context, cfg connectConfig) (*terminal, error) { u := url.URL{ Host: cfg.proxy, Scheme: client.WSS, - Path: "/v1/webapi/sites/-current-/connect", + Path: "/v1/webapi/sites/-current-/connect/ws", } q := u.Query() q.Set("params", string(data)) - q.Set(roundtrip.AccessTokenQueryParam, cfg.pack.session.Token) u.RawQuery = q.Encode() header := http.Header{} @@ -162,6 +160,10 @@ func connectToHost(ctx context.Context, cfg connectConfig) (*terminal, error) { return nil, trace.Wrap(err) } + if err := makeAuthReqOverWS(ws, cfg.pack.session.Token); err != nil { + return nil, trace.Wrap(err) + } + if cfg.pingHandler != nil { ws.SetPingHandler(func(message string) error { return cfg.pingHandler(ws, message) From 4e2aca34df0f56604020fa2507e1c31b644fda93 Mon Sep 17 00:00:00 2001 From: Alex McGrath Date: Tue, 30 Jan 2024 15:35:30 +0000 Subject: [PATCH 03/17] linter --- lib/web/apiserver_test.go | 18 ------------------ .../teleport/src/Player/DesktopPlayer.tsx | 2 +- 2 files changed, 1 insertion(+), 19 deletions(-) diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index e6e7dafd2c092..0edb9a51c7b7a 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -7349,24 +7349,6 @@ func (mock authProviderMock) GetRole(_ context.Context, _ string) (types.Role, e return nil, nil } -type terminalOpt func(t *TerminalRequest) - -func withSessionID(sid session.ID) terminalOpt { - return func(t *TerminalRequest) { t.SessionID = sid } -} - -func withServer(target string) terminalOpt { - return func(t *TerminalRequest) { t.Server = target } -} - -func withKeepaliveInterval(d time.Duration) terminalOpt { - return func(t *TerminalRequest) { t.KeepAliveInterval = d } -} - -func withParticipantMode(m types.SessionParticipantMode) terminalOpt { - return func(t *TerminalRequest) { t.ParticipantMode = m } -} - func waitForOutputWithDuration(r io.Reader, substr string, timeout time.Duration) error { timeoutCh := time.After(timeout) diff --git a/web/packages/teleport/src/Player/DesktopPlayer.tsx b/web/packages/teleport/src/Player/DesktopPlayer.tsx index b56e8b51ac0b9..8262835db060c 100644 --- a/web/packages/teleport/src/Player/DesktopPlayer.tsx +++ b/web/packages/teleport/src/Player/DesktopPlayer.tsx @@ -157,7 +157,7 @@ const useDesktopPlayer = ({ clusterId, sid }) => { const url = cfg.api.desktopPlaybackWsAddr .replace(':fqdn', getHostName()) .replace(':clusterId', clusterId) - .replace(':sid', sid) + .replace(':sid', sid); return new PlayerClient({ url, setTime, setPlayerStatus, setStatusText }); }, [clusterId, sid, setTime, setPlayerStatus]); From 4a3391e7c6e0c44549cc09c0d76519b2faa0cd8f Mon Sep 17 00:00:00 2001 From: Alex McGrath Date: Tue, 30 Jan 2024 18:33:45 +0000 Subject: [PATCH 04/17] read errors from websocket --- lib/web/apiserver.go | 18 ++++++++++++++++-- lib/web/apiserver_test.go | 7 +++++-- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index e323120a114ed..1eb097a51d24c 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -3766,6 +3766,12 @@ func (h *Handler) WithClusterAuth(fn ClusterHandler) httprouter.Handle { }) } +// WSError is used to write errors that previously occurred before a +// websocket got upgraded +type WSError struct { + Error string `json:"error"` +} + // WithClusterAuthWS wraps a ClusterWebsocketHandler to ensure that a request is authenticated // to this proxy via websocket if websocketAuth is true, or via query parameter if false (the same as WithAuth), as // well as to grab the remoteSite (which can represent this local cluster or a remote trusted cluster) @@ -3781,7 +3787,11 @@ func (h *Handler) WithClusterAuthWS(websocketAuth bool, fn ClusterWebsocketHandl } defer ws.Close() - return fn(w, r, p, sctx, site, ws) + _, err = fn(w, r, p, sctx, site, ws) + if err := ws.WriteJSON(WSError{Error: err.Error()}); err != nil { + h.log.WithError(err).Error("error writing json") + } + return nil, nil } sctx, site, err := h.authenticateRequestWithCluster(w, r, p) @@ -3802,7 +3812,11 @@ func (h *Handler) WithClusterAuthWS(websocketAuth bool, fn ClusterWebsocketHandl } defer ws.Close() - return fn(w, r, p, sctx, site, ws) + _, err = fn(w, r, p, sctx, site, ws) + if err := ws.WriteJSON(WSError{Error: err.Error()}); err != nil { + h.log.WithError(err).Error("error writing json") + } + return nil, nil }) } diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 0edb9a51c7b7a..10e68723bb177 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -1361,8 +1361,11 @@ func TestSiteNodeConnectInvalidSessionID(t *testing.T) { proxy: s.webServer.Listener.Addr().String(), sessionID: "/../../../foo", }) - require.Error(t, err) - require.Nil(t, term) + require.NoError(t, err) + var wsError WSError + err = term.ws.ReadJSON(&wsError) + require.NoError(t, err) + require.Equal(t, "/../../../foo is not a valid UUID", wsError.Error) } func TestResolveServerHostPort(t *testing.T) { From ff6b0e252501ef40c54a7283be21d37a90458982 Mon Sep 17 00:00:00 2001 From: Alex McGrath Date: Wed, 31 Jan 2024 10:10:31 +0000 Subject: [PATCH 05/17] missing /ws on ttyWsAddr and fix wrong onmessage --- web/packages/teleport/src/config.ts | 2 +- web/packages/teleport/src/lib/term/tty.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/web/packages/teleport/src/config.ts b/web/packages/teleport/src/config.ts index 99d3f7ab0a22b..c8d0748193fd5 100644 --- a/web/packages/teleport/src/config.ts +++ b/web/packages/teleport/src/config.ts @@ -201,7 +201,7 @@ const cfg = { 'wss://:fqdn/v1/webapi/sites/:clusterId/desktopplayback/:sid/ws', desktopIsActive: '/v1/webapi/sites/:clusterId/desktops/:desktopName/active', ttyWsAddr: - 'wss://:fqdn/v1/webapi/sites/:clusterId/connect?params=:params&traceparent=:traceparent', + 'wss://:fqdn/v1/webapi/sites/:clusterId/connect/ws?params=:params&traceparent=:traceparent', ttyPlaybackWsAddr: 'wss://:fqdn/v1/webapi/sites/:clusterId/ttyplayback/:sid?access_token=:token', // TODO(zmb3): get token out of URL activeAndPendingSessionsPath: '/v1/webapi/sites/:clusterId/sessions', diff --git a/web/packages/teleport/src/lib/term/tty.ts b/web/packages/teleport/src/lib/term/tty.ts index 373668d5b6cde..d6da364a6b17b 100644 --- a/web/packages/teleport/src/lib/term/tty.ts +++ b/web/packages/teleport/src/lib/term/tty.ts @@ -69,7 +69,7 @@ class Tty extends EventEmitterWebAuthnSender { this.socket = new WebSocket(connStr); this.socket.binaryType = 'arraybuffer'; this.socket.onopen = this._onOpenConnection; - this.socket.onmessage = this._onAuthMessage; + this.socket.onmessage = this._onMessage; this.socket.onclose = this._onCloseConnection; } From 9b81949e0a962297d8524dba58750bc7ce822500 Mon Sep 17 00:00:00 2001 From: Alex McGrath Date: Wed, 31 Jan 2024 11:09:32 +0000 Subject: [PATCH 06/17] fix race in test --- lib/web/apiserver.go | 26 ++++++++++++++++++++++---- lib/web/apiserver_test.go | 19 ++++++++++++++----- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 1eb097a51d24c..3213f5d999837 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -3788,8 +3788,17 @@ func (h *Handler) WithClusterAuthWS(websocketAuth bool, fn ClusterWebsocketHandl defer ws.Close() _, err = fn(w, r, p, sctx, site, ws) - if err := ws.WriteJSON(WSError{Error: err.Error()}); err != nil { - h.log.WithError(err).Error("error writing json") + errEnvelope := Envelope{ + Type: defaults.WebsocketError, + Payload: err.Error(), + } + env, err := errEnvelope.Marshal() + if err != nil { + h.log.WithError(err).Error("error marshalling proto") + return nil, nil + } + if err := ws.WriteMessage(websocket.BinaryMessage, env); err != nil { + h.log.WithError(err).Error("error writing proto") } return nil, nil } @@ -3813,8 +3822,17 @@ func (h *Handler) WithClusterAuthWS(websocketAuth bool, fn ClusterWebsocketHandl defer ws.Close() _, err = fn(w, r, p, sctx, site, ws) - if err := ws.WriteJSON(WSError{Error: err.Error()}); err != nil { - h.log.WithError(err).Error("error writing json") + errEnvelope := Envelope{ + Type: defaults.WebsocketError, + Payload: err.Error(), + } + env, err := errEnvelope.Marshal() + if err != nil { + h.log.WithError(err).Error("error marshalling proto") + return nil, nil + } + if err := ws.WriteMessage(websocket.BinaryMessage, env); err != nil { + h.log.WithError(err).Error("error writing proto") } return nil, nil }) diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 10e68723bb177..4bcbe816b0a6b 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -1355,17 +1355,25 @@ func TestSiteNodeConnectInvalidSessionID(t *testing.T) { ctx, cancel := context.WithCancel(s.ctx) t.Cleanup(cancel) - term, err := connectToHost(ctx, connectConfig{ + result := make(chan error) + + _, err := connectToHost(ctx, connectConfig{ pack: s.authPack(t, "foo"), host: s.node.ID(), proxy: s.webServer.Listener.Addr().String(), sessionID: "/../../../foo", + handlers: map[string]WSHandlerFunc{ + defaults.WebsocketError: func(ctx context.Context, e Envelope) { + if e.Payload == "/../../../foo is not a valid UUID" { + result <- errors.New(e.Payload) + } + close(result) + }, + }, }) require.NoError(t, err) - var wsError WSError - err = term.ws.ReadJSON(&wsError) - require.NoError(t, err) - require.Equal(t, "/../../../foo is not a valid UUID", wsError.Error) + res := <-result + require.NotNil(t, res) } func TestResolveServerHostPort(t *testing.T) { @@ -1900,6 +1908,7 @@ func TestTerminal(t *testing.T) { host: s.node.ID(), proxy: s.webServer.Listener.Addr().String(), }) + require.NoError(t, err) t.Cleanup(func() { require.True(t, utils.IsOKNetworkError(term.Close())) }) From 73f49843caa27b54f8e2aa96a1d381d1099b2e17 Mon Sep 17 00:00:00 2001 From: Alex McGrath Date: Wed, 31 Jan 2024 11:23:38 +0000 Subject: [PATCH 07/17] lint --- lib/web/apiserver.go | 4 ++-- lib/web/apiserver_test.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 3213f5d999837..86dcf3ea9229d 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -3794,7 +3794,7 @@ func (h *Handler) WithClusterAuthWS(websocketAuth bool, fn ClusterWebsocketHandl } env, err := errEnvelope.Marshal() if err != nil { - h.log.WithError(err).Error("error marshalling proto") + h.log.WithError(err).Error("error marshaling proto") return nil, nil } if err := ws.WriteMessage(websocket.BinaryMessage, env); err != nil { @@ -3828,7 +3828,7 @@ func (h *Handler) WithClusterAuthWS(websocketAuth bool, fn ClusterWebsocketHandl } env, err := errEnvelope.Marshal() if err != nil { - h.log.WithError(err).Error("error marshalling proto") + h.log.WithError(err).Error("error marshaling proto") return nil, nil } if err := ws.WriteMessage(websocket.BinaryMessage, env); err != nil { diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 4bcbe816b0a6b..5cdf12a31c45d 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -1373,7 +1373,7 @@ func TestSiteNodeConnectInvalidSessionID(t *testing.T) { }) require.NoError(t, err) res := <-result - require.NotNil(t, res) + require.Error(t, res) } func TestResolveServerHostPort(t *testing.T) { From 9836ad6608194a4b79c92f78368450de63f20dda Mon Sep 17 00:00:00 2001 From: Alex McGrath Date: Wed, 31 Jan 2024 12:06:09 +0000 Subject: [PATCH 08/17] skip TestTerminal as it takes 11 seconds to run --- build.assets/tooling/cmd/difftest/main.go | 3 ++ lib/web/apiserver.go | 59 +++++++++++------------ 2 files changed, 32 insertions(+), 30 deletions(-) diff --git a/build.assets/tooling/cmd/difftest/main.go b/build.assets/tooling/cmd/difftest/main.go index 725372012ba29..ed4a485534ae8 100644 --- a/build.assets/tooling/cmd/difftest/main.go +++ b/build.assets/tooling/cmd/difftest/main.go @@ -76,6 +76,9 @@ var ( // TestAdminActionMFA takes longer than 6 seconds to run. "TestAdminActionMFA", + + // TestTerminal takes 11 seconds to run. + "TestTerminal", } ) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 86dcf3ea9229d..0ea8835d6d547 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -3766,12 +3766,6 @@ func (h *Handler) WithClusterAuth(fn ClusterHandler) httprouter.Handle { }) } -// WSError is used to write errors that previously occurred before a -// websocket got upgraded -type WSError struct { - Error string `json:"error"` -} - // WithClusterAuthWS wraps a ClusterWebsocketHandler to ensure that a request is authenticated // to this proxy via websocket if websocketAuth is true, or via query parameter if false (the same as WithAuth), as // well as to grab the remoteSite (which can represent this local cluster or a remote trusted cluster) @@ -3787,18 +3781,21 @@ func (h *Handler) WithClusterAuthWS(websocketAuth bool, fn ClusterWebsocketHandl } defer ws.Close() - _, err = fn(w, r, p, sctx, site, ws) - errEnvelope := Envelope{ - Type: defaults.WebsocketError, - Payload: err.Error(), - } - env, err := errEnvelope.Marshal() - if err != nil { - h.log.WithError(err).Error("error marshaling proto") - return nil, nil - } - if err := ws.WriteMessage(websocket.BinaryMessage, env); err != nil { - h.log.WithError(err).Error("error writing proto") + + if _, err := fn(w, r, p, sctx, site, ws); err != nil { + errEnvelope := Envelope{ + Type: defaults.WebsocketError, + Payload: err.Error(), + } + env, err := errEnvelope.Marshal() + if err != nil { + h.log.WithError(err).Error("error marshaling proto") + return nil, nil + } + if err := ws.WriteMessage(websocket.BinaryMessage, env); err != nil { + h.log.WithError(err).Error("error writing proto") + return nil, nil + } } return nil, nil } @@ -3821,18 +3818,20 @@ func (h *Handler) WithClusterAuthWS(websocketAuth bool, fn ClusterWebsocketHandl } defer ws.Close() - _, err = fn(w, r, p, sctx, site, ws) - errEnvelope := Envelope{ - Type: defaults.WebsocketError, - Payload: err.Error(), - } - env, err := errEnvelope.Marshal() - if err != nil { - h.log.WithError(err).Error("error marshaling proto") - return nil, nil - } - if err := ws.WriteMessage(websocket.BinaryMessage, env); err != nil { - h.log.WithError(err).Error("error writing proto") + if _, err := fn(w, r, p, sctx, site, ws); err != nil { + errEnvelope := Envelope{ + Type: defaults.WebsocketError, + Payload: err.Error(), + } + env, err := errEnvelope.Marshal() + if err != nil { + h.log.WithError(err).Error("error marshaling proto") + return nil, nil + } + if err := ws.WriteMessage(websocket.BinaryMessage, env); err != nil { + h.log.WithError(err).Error("error writing proto") + return nil, nil + } } return nil, nil }) From 6afb9165e12f0268198d435478533e0f89efb08b Mon Sep 17 00:00:00 2001 From: Alex McGrath Date: Wed, 31 Jan 2024 17:56:47 +0000 Subject: [PATCH 09/17] dont skip the test --- build.assets/tooling/cmd/difftest/main.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/build.assets/tooling/cmd/difftest/main.go b/build.assets/tooling/cmd/difftest/main.go index ed4a485534ae8..725372012ba29 100644 --- a/build.assets/tooling/cmd/difftest/main.go +++ b/build.assets/tooling/cmd/difftest/main.go @@ -76,9 +76,6 @@ var ( // TestAdminActionMFA takes longer than 6 seconds to run. "TestAdminActionMFA", - - // TestTerminal takes 11 seconds to run. - "TestTerminal", } ) From 6b3a9665e52b81b7c745ed069f1f23c7e74fd2a9 Mon Sep 17 00:00:00 2001 From: Alex McGrath Date: Thu, 1 Feb 2024 12:57:07 +0000 Subject: [PATCH 10/17] resolve apiserver comments --- lib/web/apiserver.go | 62 +++++++++++++++++++-------------------- lib/web/apiserver_test.go | 2 +- 2 files changed, 31 insertions(+), 33 deletions(-) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 0ea8835d6d547..e5ca340db555a 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -3766,6 +3766,22 @@ func (h *Handler) WithClusterAuth(fn ClusterHandler) httprouter.Handle { }) } +func (h *Handler) writeErrToWS(ws *websocket.Conn, err error) { + errEnvelope := Envelope{ + Type: defaults.WebsocketError, + Payload: err.Error(), + } + env, err := errEnvelope.Marshal() + if err != nil { + h.log.WithError(err).Error("error marshaling proto") + return + } + if err := ws.WriteMessage(websocket.BinaryMessage, env); err != nil { + h.log.WithError(err).Error("error writing proto") + return + } +} + // WithClusterAuthWS wraps a ClusterWebsocketHandler to ensure that a request is authenticated // to this proxy via websocket if websocketAuth is true, or via query parameter if false (the same as WithAuth), as // well as to grab the remoteSite (which can represent this local cluster or a remote trusted cluster) @@ -3783,19 +3799,7 @@ func (h *Handler) WithClusterAuthWS(websocketAuth bool, fn ClusterWebsocketHandl defer ws.Close() if _, err := fn(w, r, p, sctx, site, ws); err != nil { - errEnvelope := Envelope{ - Type: defaults.WebsocketError, - Payload: err.Error(), - } - env, err := errEnvelope.Marshal() - if err != nil { - h.log.WithError(err).Error("error marshaling proto") - return nil, nil - } - if err := ws.WriteMessage(websocket.BinaryMessage, env); err != nil { - h.log.WithError(err).Error("error writing proto") - return nil, nil - } + h.writeErrToWS(ws, err) } return nil, nil } @@ -3819,19 +3823,7 @@ func (h *Handler) WithClusterAuthWS(websocketAuth bool, fn ClusterWebsocketHandl defer ws.Close() if _, err := fn(w, r, p, sctx, site, ws); err != nil { - errEnvelope := Envelope{ - Type: defaults.WebsocketError, - Payload: err.Error(), - } - env, err := errEnvelope.Marshal() - if err != nil { - h.log.WithError(err).Error("error marshaling proto") - return nil, nil - } - if err := ws.WriteMessage(websocket.BinaryMessage, env); err != nil { - h.log.WithError(err).Error("error writing proto") - return nil, nil - } + h.writeErrToWS(ws, err) } return nil, nil }) @@ -4268,12 +4260,15 @@ type wsStatus struct { Message string `json:"message,omitempty"` } +// wsIODeadline is used to set a deadline for recieving a message from +// an authenticated websocket so unauthenticated sockets dont get left +// open. var wsIODeadline = time.Second * 4 // AuthenticateRequest authenticates request using combination of a session cookie // and bearer token retrieved from a websocket func (h *Handler) AuthenticateRequestWS(w http.ResponseWriter, r *http.Request) (*SessionContext, *websocket.Conn, error) { - ctx, err := h.validateCookie(w, r) + sctx, err := h.validateCookie(w, r) if err != nil { return nil, nil, trace.Wrap(err) } @@ -4287,8 +4282,7 @@ func (h *Handler) AuthenticateRequestWS(w http.ResponseWriter, r *http.Request) if err != nil { return nil, nil, trace.ConnectionProblem(err, "Error upgrading to websocket: %v", err) } - if err := ws.SetReadDeadline(deadlineForInterval(wsIODeadline)); err != nil { - log.WithError(err).Error("Error setting websocket read deadline") + if err := ws.SetReadDeadline(time.Now().Add(wsIODeadline)); err != nil { return nil, nil, trace.ConnectionProblem(err, "Error setting websocket read deadline: %v", err) } @@ -4296,7 +4290,7 @@ func (h *Handler) AuthenticateRequestWS(w http.ResponseWriter, r *http.Request) if err := ws.ReadJSON(&t); err != nil { return nil, nil, trace.Wrap(err) } - if err := ctx.validateBearerToken(r.Context(), t.Token); err != nil { + if err := sctx.validateBearerToken(r.Context(), t.Token); err != nil { ws.WriteJSON(wsStatus{ Type: "create_session_response", Status: "error", @@ -4312,12 +4306,16 @@ func (h *Handler) AuthenticateRequestWS(w http.ResponseWriter, r *http.Request) return nil, nil, trace.Wrap(err) } + // unset the deadline as downstream consumers should handle this themselves. if err := ws.SetReadDeadline(time.Time{}); err != nil { - log.WithError(err).Error("Error setting websocket read deadline") return nil, nil, trace.ConnectionProblem(err, "Error setting websocket read deadline: %v", err) } - return ctx, ws, nil + if err := parseMFAResponseFromRequest(r); err != nil { + return nil, nil, trace.Wrap(err) + } + + return sctx, ws, nil } // ProxyWithRoles returns a reverse tunnel proxy verifying the permissions diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 5cdf12a31c45d..f5c3177a0abd9 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -9210,7 +9210,7 @@ func TestWSAuthenticateRequest(t *testing.T) { serverExpectError: "i/o timeout", token: pack.session.Token, readTimeout: func() { - <-time.After(time.Second * 3) + <-time.After(wsIODeadline * 3) }, }, } { From 1cfaecf74e2d216e0b54ae0b0f23f43d7d6e30dd Mon Sep 17 00:00:00 2001 From: Alex McGrath Date: Thu, 1 Feb 2024 16:03:55 +0000 Subject: [PATCH 11/17] Add an AuthenticatedWebSocket class --- .../src/lib/AuthenticatedWebsoscket.ts | 105 ++++++++++++++++++ web/packages/teleport/src/lib/term/tty.ts | 46 ++------ 2 files changed, 115 insertions(+), 36 deletions(-) create mode 100644 web/packages/teleport/src/lib/AuthenticatedWebsoscket.ts diff --git a/web/packages/teleport/src/lib/AuthenticatedWebsoscket.ts b/web/packages/teleport/src/lib/AuthenticatedWebsoscket.ts new file mode 100644 index 0000000000000..30d8b17b5ff8c --- /dev/null +++ b/web/packages/teleport/src/lib/AuthenticatedWebsoscket.ts @@ -0,0 +1,105 @@ +/** + * Teleport + * Copyright (C) 2023 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +import { getAccessToken } from 'teleport/services/api'; +import { WebsocketStatus } from 'teleport/types'; + +export class AuthenticatedWebSocket { + ws: WebSocket | undefined; + + onopenAfterAuth: (ev: Event) => void | undefined; + onmessageAfterAuth: (ev: MessageEvent) => void | undefined; + oncloseAfterAuth: (ev: CloseEvent) => void | undefined; + + private authenticated: boolean; + + constructor( + socketAddr: string, + onopen: (ev: Event) => void | null, + onmessage: (ev: MessageEvent) => void | null, + onerror: (ev: Event) => void | null, + onclose: (ev: CloseEvent) => void | null + ) { + this.onopen = this.onopen.bind(this); + this.onmessage = this.onmessage.bind(this); + this.onclose = this.onclose.bind(this); + + this.authenticated = false; + this.onmessageAfterAuth = onmessage; + this.onopenAfterAuth = onopen; + this.oncloseAfterAuth = onclose; + + this.ws = new WebSocket(socketAddr); + this.ws.binaryType = 'arraybuffer'; + + this.ws.onopen = this.onopen; + this.ws.onmessage = this.onmessage; + this.ws.onerror = onerror; + this.ws.onclose = this.onclose; + } + + onopen(): void { + this.ws.send(JSON.stringify({ token: getAccessToken() })); + } + + onmessage(ev: MessageEvent): void { + if (!this.authenticated) { + const authResponse = JSON.parse(ev.data) as WebsocketStatus; + if (authResponse.type != 'create_session_response') { + this.ws.close(); + console.log('invalid auth response type: ' + authResponse.message); + return; + } + + if (authResponse.status == 'error') { + this.ws.close(); + console.log( + 'auth error connecting to websocket: ' + authResponse.message + ); + return; + } + this.authenticated = true; + + if (this.onopenAfterAuth) { + this.onopenAfterAuth(ev); + } + + return; + } + + if (this.onmessageAfterAuth) { + this.onmessageAfterAuth(ev); + } + } + + onclose(ev: CloseEvent): void { + if (this.oncloseAfterAuth) { + this.oncloseAfterAuth(ev); + } + this.authenticated = false; + this.ws = null; + } + + send(data: string | ArrayBufferLike | Blob | ArrayBufferView): void { + this.ws.send(data); + } + + close(code?: number, reason?: string): void { + this.ws.close(code, reason); + } +} diff --git a/web/packages/teleport/src/lib/term/tty.ts b/web/packages/teleport/src/lib/term/tty.ts index d6da364a6b17b..39ac5e5da6dde 100644 --- a/web/packages/teleport/src/lib/term/tty.ts +++ b/web/packages/teleport/src/lib/term/tty.ts @@ -19,9 +19,9 @@ import Logger from 'shared/libs/logger'; import { EventEmitterWebAuthnSender } from 'teleport/lib/EventEmitterWebAuthnSender'; -import { getAccessToken } from 'teleport/services/api'; import { WebauthnAssertionResponse } from 'teleport/services/auth'; -import { WebsocketStatus } from 'teleport/types'; + +import { AuthenticatedWebSocket } from '../AuthenticatedWebsoscket'; import { EventType, TermEvent, WebsocketCloseCode } from './enums'; import { Protobuf, MessageTypeEnum } from './protobuf'; @@ -33,7 +33,7 @@ const defaultOptions = { }; class Tty extends EventEmitterWebAuthnSender { - socket = null; + socket: AuthenticatedWebSocket = null; _buffered = true; _attachSocketBufferTimer; @@ -41,7 +41,6 @@ class Tty extends EventEmitterWebAuthnSender { _addressResolver = null; _proto = new Protobuf(); _pendingUploads = {}; - _authenticated = false; constructor(addressResolver, props = {}) { super(); @@ -55,7 +54,6 @@ class Tty extends EventEmitterWebAuthnSender { this._onOpenConnection = this._onOpenConnection.bind(this); this._onCloseConnection = this._onCloseConnection.bind(this); this._onMessage = this._onMessage.bind(this); - this._onAuthMessage = this._onAuthMessage.bind(this); } disconnect(closeCode = WebsocketCloseCode.NORMAL) { @@ -66,11 +64,13 @@ class Tty extends EventEmitterWebAuthnSender { connect(w: number, h: number) { const connStr = this._addressResolver.getConnStr(w, h); - this.socket = new WebSocket(connStr); - this.socket.binaryType = 'arraybuffer'; - this.socket.onopen = this._onOpenConnection; - this.socket.onmessage = this._onMessage; - this.socket.onclose = this._onCloseConnection; + this.socket = new AuthenticatedWebSocket( + connStr, + this._onOpenConnection, + this._onMessage, + null, + this._onCloseConnection + ); } send(data) { @@ -170,41 +170,15 @@ class Tty extends EventEmitterWebAuthnSender { _onOpenConnection() { this.emit('open'); logger.info('websocket is open'); - this.socket.send(JSON.stringify({ token: getAccessToken() })); } _onCloseConnection(e) { - this.socket.onopen = null; - this.socket.onmessage = null; - this.socket.onclose = null; this.socket = null; this.emit(TermEvent.CONN_CLOSE, e); - this._authenticated = false; logger.info('websocket is closed'); } - _onAuthMessage(ev) { - const authResponse = JSON.parse(ev.data) as WebsocketStatus; - if (authResponse.type != 'create_session_response') { - this.socket.close(); - console.log('invalid auth response type: ' + authResponse.message); - return; - } - if (authResponse.status == 'error') { - this.socket.close(); - console.log( - 'auth error connecting to websocket: ' + authResponse.message - ); - return; - } - this._authenticated = true; - } - _onMessage(ev) { - if (!this._authenticated) { - this._onAuthMessage(ev); - return; - } try { const uintArray = new Uint8Array(ev.data); const msg = this._proto.decode(uintArray); From 8fb7e751b8155647864c76ac8d7f6c201b1ac7c1 Mon Sep 17 00:00:00 2001 From: Alex McGrath Date: Thu, 1 Feb 2024 16:41:35 +0000 Subject: [PATCH 12/17] convert other clients to use AuthenticatedWebSocket --- .../src/Assist/context/AssistContext.tsx | 95 +++++-------------- .../TerminalAssist/TerminalAssistContext.tsx | 48 +++------- .../src/lib/AuthenticatedWebsoscket.ts | 22 +++-- web/packages/teleport/src/lib/tdp/client.ts | 59 ++++-------- 4 files changed, 68 insertions(+), 156 deletions(-) diff --git a/web/packages/teleport/src/Assist/context/AssistContext.tsx b/web/packages/teleport/src/Assist/context/AssistContext.tsx index ec9dbe7d921b5..23bcd53fcf8a7 100644 --- a/web/packages/teleport/src/Assist/context/AssistContext.tsx +++ b/web/packages/teleport/src/Assist/context/AssistContext.tsx @@ -65,6 +65,7 @@ import type { ServerMessage, } from 'teleport/Assist/types'; import type { AssistState } from 'teleport/Assist/context/state'; +import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebsoscket'; interface AssistContextValue { cancelMfaChallenge: () => void; @@ -86,9 +87,9 @@ let lastCommandExecutionResultId = 0; const TEN_MINUTES = 10 * 60 * 1000; export function AssistContextProvider(props: PropsWithChildren) { - const activeWebSocket = useRef(null); + const activeWebSocket = useRef(null); // TODO(ryan): this should be removed once https://github.com/gravitational/teleport.e/pull/1609 is implemented - const executeCommandWebSocket = useRef(null); + const executeCommandWebSocket = useRef(null); const refreshWebSocketTimeout = useRef(null); const { clusterId } = useStickyClusterId(); @@ -126,13 +127,7 @@ export function AssistContextProvider(props: PropsWithChildren) { } function setupWebSocket(conversationId: string, initialMessage?: string) { - activeWebSocket.current = new WebSocket( - cfg.getAssistConversationWebSocketUrl( - getHostName(), - clusterId, - conversationId - ) - ); + window.clearTimeout(refreshWebSocketTimeout.current); @@ -142,43 +137,20 @@ export function AssistContextProvider(props: PropsWithChildren) { TEN_MINUTES * 0.8 ); - activeWebSocket.current.onopen = () => { - executeCommandWebSocket.current.send( - JSON.stringify({ token: getAccessToken() }) - ); - }; - - let authenticated = false; + const onopen = () => { + if (initialMessage) { + activeWebSocket.current.send(initialMessage); + } + } - activeWebSocket.current.onclose = () => { + const onclose = () => { dispatch({ type: AssistStateActionType.SetStreaming, streaming: false, }); }; - executeCommandWebSocket.current.onmessage = event => { - if (!authenticated) { - const authResponse = JSON.parse(event.data) as WebsocketStatus; - if (authResponse.type != 'create_session_response') { - executeCommandWebSocket.current.close(); - console.log('invalid auth response type: ' + authResponse.message); - return; - } - - if (authResponse.status == 'error') { - executeCommandWebSocket.current.close(); - console.log( - 'auth error connecting to websocket: ' + authResponse.message - ); - return; - } - authenticated = true; - if (initialMessage) { - activeWebSocket.current.send(initialMessage); - } - } - + const onmessage = event => { const data = JSON.parse(event.data) as ServerMessage; switch (data.type) { case ServerMessageType.Assist: @@ -273,6 +245,14 @@ export function AssistContextProvider(props: PropsWithChildren) { break; } }; + + activeWebSocket.current = new AuthenticatedWebSocket( + cfg.getAssistConversationWebSocketUrl( + getHostName(), + clusterId, + conversationId + ), onopen, onmessage, null, onclose + ); } async function createConversation() { @@ -474,37 +454,8 @@ export function AssistContextProvider(props: PropsWithChildren) { execParams ); - executeCommandWebSocket.current = new WebSocket(url); - - executeCommandWebSocket.current.onopen = () => { - executeCommandWebSocket.current.send( - JSON.stringify({ token: getAccessToken() }) - ); - }; - - let authenticated = false; - const proto = new Protobuf(); - executeCommandWebSocket.current.onmessage = event => { - if (!authenticated) { - const authResponse = JSON.parse(event.data) as WebsocketStatus; - if (authResponse.type != 'create_session_response') { - executeCommandWebSocket.current.close(); - console.log('invalid auth response type: ' + authResponse.message); - return; - } - - if (authResponse.status == 'error') { - executeCommandWebSocket.current.close(); - console.log( - 'auth error connecting to websocket: ' + authResponse.message - ); - return; - } - authenticated = true; - return; - } - + const onmessage = (event: MessageEvent) => { executeCommandWebSocket.current.binaryType = 'arraybuffer'; const uintArray = new Uint8Array(event.data); @@ -582,10 +533,8 @@ export function AssistContextProvider(props: PropsWithChildren) { } }; - executeCommandWebSocket.current.onclose = () => { + const onclose = () => { executeCommandWebSocket.current = null; - authenticated = false; - // If the execution failed, we won't get a SESSION_END message, so we // need to mark all the results as finished here. for (const nodeId of nodeIdToResultId.keys()) { @@ -597,6 +546,8 @@ export function AssistContextProvider(props: PropsWithChildren) { } nodeIdToResultId.clear(); }; + + executeCommandWebSocket.current = new AuthenticatedWebSocket(url, null, onmessage, null, onclose); } async function deleteConversation(conversationId: string) { diff --git a/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx b/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx index 4b00c571aff69..9820351446396 100644 --- a/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx +++ b/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx @@ -36,6 +36,7 @@ import { SuggestedCommandMessage, UserMessage, } from 'teleport/Console/DocumentSsh/TerminalAssist/types'; +import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebsoscket'; interface TerminalAssistContextValue { close: () => void; @@ -57,7 +58,7 @@ export function TerminalAssistContextProvider( const [visible, setVisible] = useState(false); - const socketRef = useRef(null); + const socketRef = useRef(null); const socketUrl = cfg.getAssistActionWebSocketUrl( getHostName(), clusterId, @@ -71,25 +72,8 @@ export function TerminalAssistContextProvider( const [messages, setMessages] = useState([]); useEffect(() => { - socketRef.current = new WebSocket(socketUrl); - - socketRef.current.onopen = () => { - socketRef.current.send(JSON.stringify({ token: getAccessToken() })); - }; - - socketRef.current.onmessage = e => { - const resData = JSON.parse(e.data); - - if (resData.type === 'create_session_response') { - if (resData.status == 'error') { - socketRef.current.close(); - console.log('auth error connecting to websocket: ' + resData.message); - return; - } - return; - } - - const data = resData as ServerMessage; + let onmessage = (e: MessageEvent) => { + const data = JSON.parse(e.data) as ServerMessage; const payload = JSON.parse(data.payload) as { action: string; input: string; @@ -109,6 +93,8 @@ export function TerminalAssistContextProvider( setLoading(false); setMessages(m => [message, ...m]); }; + + socketRef.current = new AuthenticatedWebSocket(socketUrl, null, onmessage); }, []); function close() { @@ -134,27 +120,14 @@ export function TerminalAssistContextProvider( 'ssh-explain' ); - const ws = new WebSocket(socketUrl); - - ws.onopen = () => { - socketRef.current.send(JSON.stringify({ token: getAccessToken() })); - }; - ws.onmessage = event => { - const message = event.data; - const resMsg = JSON.parse(message); - if (resMsg.type === 'create_session_response') { - if (resMsg.status == 'error') { - socketRef.current.close(); - console.log('auth error connecting to websocket: ' + resMsg.message); - return; - } + let onopen = () => { ws.send(encodedOutput); - return; - } + }; - const msg = resMsg as ServerMessage; + let onmessage = (event: MessageEvent) => { + const msg = JSON.parse(event.data) as ServerMessage; const explanation: ExplanationMessage = { author: Author.Teleport, @@ -167,6 +140,7 @@ export function TerminalAssistContextProvider( ws.close(); }; + const ws = new AuthenticatedWebSocket(socketUrl, onopen, onmessage); } function send(message: string) { diff --git a/web/packages/teleport/src/lib/AuthenticatedWebsoscket.ts b/web/packages/teleport/src/lib/AuthenticatedWebsoscket.ts index 30d8b17b5ff8c..8f717e7ddf3c9 100644 --- a/web/packages/teleport/src/lib/AuthenticatedWebsoscket.ts +++ b/web/packages/teleport/src/lib/AuthenticatedWebsoscket.ts @@ -30,10 +30,10 @@ export class AuthenticatedWebSocket { constructor( socketAddr: string, - onopen: (ev: Event) => void | null, - onmessage: (ev: MessageEvent) => void | null, - onerror: (ev: Event) => void | null, - onclose: (ev: CloseEvent) => void | null + onopen: (ev: Event) => void | null = null, + onmessage: (ev: MessageEvent) => void | null = null, + onerror: (ev: Event) => void | null = null, + onclose: (ev: CloseEvent) => void | null = null, ) { this.onopen = this.onopen.bind(this); this.onmessage = this.onmessage.bind(this); @@ -53,11 +53,19 @@ export class AuthenticatedWebSocket { this.ws.onclose = this.onclose; } - onopen(): void { + set binaryType(btype: BinaryType) { + this.ws.binaryType = btype + } + + get readyState(): number { + return this.ws.readyState + } + + private onopen(): void { this.ws.send(JSON.stringify({ token: getAccessToken() })); } - onmessage(ev: MessageEvent): void { + private onmessage(ev: MessageEvent): void { if (!this.authenticated) { const authResponse = JSON.parse(ev.data) as WebsocketStatus; if (authResponse.type != 'create_session_response') { @@ -87,7 +95,7 @@ export class AuthenticatedWebSocket { } } - onclose(ev: CloseEvent): void { + private onclose(ev: CloseEvent): void { if (this.oncloseAfterAuth) { this.oncloseAfterAuth(ev); } diff --git a/web/packages/teleport/src/lib/tdp/client.ts b/web/packages/teleport/src/lib/tdp/client.ts index 5955cc8e3a1a7..44df2ccb8f563 100644 --- a/web/packages/teleport/src/lib/tdp/client.ts +++ b/web/packages/teleport/src/lib/tdp/client.ts @@ -25,9 +25,8 @@ import init, { import { WebsocketCloseCode, TermEvent } from 'teleport/lib/term/enums'; import { EventEmitterWebAuthnSender } from 'teleport/lib/EventEmitterWebAuthnSender'; -import { WebsocketStatus } from 'teleport/types'; -import { getAccessToken } from 'teleport/services/api'; +import { AuthenticatedWebSocket } from '../AuthenticatedWebsoscket'; import Codec, { MessageType, @@ -93,7 +92,7 @@ export enum LogType { // For convenience, this can be done in one fell swoop by calling Client.shutdown(). export default class Client extends EventEmitterWebAuthnSender { protected codec: Codec; - protected socket: WebSocket | undefined; + protected socket: AuthenticatedWebSocket | undefined; private socketAddr: string; private sdManager: SharedDirectoryManager; private fastPathProcessor: FastPathProcessor | undefined; @@ -117,57 +116,37 @@ export default class Client extends EventEmitterWebAuthnSender { async connect(spec?: ClientScreenSpec) { await this.initWasm(); - this.socket = new WebSocket(this.socketAddr); - this.socket.binaryType = 'arraybuffer'; - - this.socket.onopen = () => { + let onopen = () => { this.logger.info('websocket is open'); - this.socket.send(JSON.stringify({ token: getAccessToken() })); + this.emit(TdpClientEvent.WS_OPEN); if (spec) { this.sendClientScreenSpec(spec); } }; - let authenticated = false; - - this.socket.onmessage = async (ev: MessageEvent) => { - if (!authenticated) { - const authResponse = JSON.parse(ev.data) as WebsocketStatus; - if (authResponse.type != 'create_session_response') { - this.socket.close(); - console.log('invalid auth response type: ' + authResponse.message); - return; - } - - if (authResponse.status == 'error') { - this.socket.close(); - console.log( - 'auth error connecting to websocket: ' + authResponse.message - ); - return; - } - authenticated = true; - return; - } + let onmessage = async (ev: MessageEvent) => { await this.processMessage(ev.data as ArrayBuffer); }; // The socket 'error' event will only ever be emitted by the socket // prior to a socket 'close' event (https://stackoverflow.com/a/40084550/6277051). // Therefore, we can rely on our onclose handler to account for any websocket errors. - this.socket.onerror = null; - this.socket.onclose = () => { + let onerror = null; + let onclose = () => { this.logger.info('websocket is closed'); - authenticated = false; // Clean up all of our socket's listeners and the socket itself. - this.socket.onopen = null; - this.socket.onmessage = null; - this.socket.onclose = null; this.socket = null; - this.emit(TdpClientEvent.WS_CLOSE); }; + + this.socket = new AuthenticatedWebSocket( + this.socketAddr, + onopen, + onmessage, + onerror, + onclose + ); } private async initWasm() { @@ -586,9 +565,9 @@ export default class Client extends EventEmitterWebAuthnSender { protected send( data: string | ArrayBufferLike | Blob | ArrayBufferView ): void { - if (this.socket && this.socket.readyState === 1) { + if (this.socket && this.socket.ws.readyState === 1) { try { - this.socket.send(data); + this.socket.ws.send(data); } catch (e) { this.handleError(e, TdpClientEvent.CLIENT_ERROR); } @@ -711,7 +690,7 @@ export default class Client extends EventEmitterWebAuthnSender { ) { this.logger.error(err); this.emit(errType, err); - this.socket?.close(); + this.socket?.ws?.close(); } // Emits an warnType event @@ -730,7 +709,7 @@ export default class Client extends EventEmitterWebAuthnSender { // will simply do nothing. shutdown(closeCode = WebsocketCloseCode.NORMAL) { this.removeAllListeners(); - this.socket?.close(closeCode); + this.socket?.ws?.close(closeCode); } } From 7a4d2bfa902eaa55ddb549e22ea0ecf8b11de8a2 Mon Sep 17 00:00:00 2001 From: Isaiah Becker-Mayer Date: Mon, 5 Feb 2024 10:43:35 -0800 Subject: [PATCH 13/17] Converts `AuthenticatedWebSocket` into drop-in replacement for `WebSocket` (#37699) * Converts `AuthenticatedWebSocket` into drop-in replacement for `WebSocket` that automatically goes through Teleport's custom authentication process before facilitating any caller-defined communication. This also reverts previous-`WebSocket` users to their original state (sans the code for passing the bearer token in the query string), swapping in `AuthenticatedWebSocket` in place of `WebSocket`. --- .../src/Assist/context/AssistContext.tsx | 47 ++-- .../TerminalAssist/TerminalAssistContext.tsx | 22 +- .../src/lib/AuthenticatedWebSocket.ts | 261 ++++++++++++++++++ .../src/lib/AuthenticatedWebsoscket.ts | 113 -------- web/packages/teleport/src/lib/tdp/client.ts | 36 ++- web/packages/teleport/src/lib/term/tty.ts | 20 +- web/packages/teleport/src/types.ts | 2 +- 7 files changed, 323 insertions(+), 178 deletions(-) create mode 100644 web/packages/teleport/src/lib/AuthenticatedWebSocket.ts delete mode 100644 web/packages/teleport/src/lib/AuthenticatedWebsoscket.ts diff --git a/web/packages/teleport/src/Assist/context/AssistContext.tsx b/web/packages/teleport/src/Assist/context/AssistContext.tsx index 23bcd53fcf8a7..fb9319285701d 100644 --- a/web/packages/teleport/src/Assist/context/AssistContext.tsx +++ b/web/packages/teleport/src/Assist/context/AssistContext.tsx @@ -31,9 +31,7 @@ import { AssistStateActionType, reducer } from 'teleport/Assist/context/state'; import { convertServerMessages } from 'teleport/Assist/context/utils'; import useStickyClusterId from 'teleport/useStickyClusterId'; import cfg from 'teleport/config'; -import { getAccessToken, getHostName } from 'teleport/services/api'; - -import { WebsocketStatus } from 'teleport/types'; +import { getHostName } from 'teleport/services/api'; import { AccessRequestClientMessage, @@ -50,6 +48,7 @@ import { makeMfaAuthenticateChallenge, WebauthnAssertionResponse, } from 'teleport/services/auth'; +import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket'; import * as service from '../service'; import { @@ -65,7 +64,6 @@ import type { ServerMessage, } from 'teleport/Assist/types'; import type { AssistState } from 'teleport/Assist/context/state'; -import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebsoscket'; interface AssistContextValue { cancelMfaChallenge: () => void; @@ -127,7 +125,13 @@ export function AssistContextProvider(props: PropsWithChildren) { } function setupWebSocket(conversationId: string, initialMessage?: string) { - + activeWebSocket.current = new AuthenticatedWebSocket( + cfg.getAssistConversationWebSocketUrl( + getHostName(), + clusterId, + conversationId + ) + ); window.clearTimeout(refreshWebSocketTimeout.current); @@ -137,21 +141,22 @@ export function AssistContextProvider(props: PropsWithChildren) { TEN_MINUTES * 0.8 ); - const onopen = () => { + activeWebSocket.current.onopen = () => { if (initialMessage) { activeWebSocket.current.send(initialMessage); } - } + }; - const onclose = () => { + activeWebSocket.current.onclose = () => { dispatch({ type: AssistStateActionType.SetStreaming, streaming: false, }); }; - const onmessage = event => { + activeWebSocket.current.onmessage = async event => { const data = JSON.parse(event.data) as ServerMessage; + switch (data.type) { case ServerMessageType.Assist: dispatch({ @@ -245,14 +250,6 @@ export function AssistContextProvider(props: PropsWithChildren) { break; } }; - - activeWebSocket.current = new AuthenticatedWebSocket( - cfg.getAssistConversationWebSocketUrl( - getHostName(), - clusterId, - conversationId - ), onopen, onmessage, null, onclose - ); } async function createConversation() { @@ -353,7 +350,7 @@ export function AssistContextProvider(props: PropsWithChildren) { if ( !activeWebSocket.current || - activeWebSocket.current.readyState === WebSocket.CLOSED + activeWebSocket.current.readyState === AuthenticatedWebSocket.CLOSED ) { setupWebSocket(state.conversations.selectedId, data); } else { @@ -383,7 +380,8 @@ export function AssistContextProvider(props: PropsWithChildren) { function sendMfaChallenge(data: WebauthnAssertionResponse) { if ( !executeCommandWebSocket.current || - executeCommandWebSocket.current.readyState !== WebSocket.OPEN || + executeCommandWebSocket.current.readyState !== + AuthenticatedWebSocket.OPEN || !data ) { console.warn( @@ -455,8 +453,10 @@ export function AssistContextProvider(props: PropsWithChildren) { ); const proto = new Protobuf(); - const onmessage = (event: MessageEvent) => { - executeCommandWebSocket.current.binaryType = 'arraybuffer'; + executeCommandWebSocket.current = new AuthenticatedWebSocket(url); + executeCommandWebSocket.current.binaryType = 'arraybuffer'; + + executeCommandWebSocket.current.onmessage = event => { const uintArray = new Uint8Array(event.data); const msg = proto.decode(uintArray); @@ -533,8 +533,9 @@ export function AssistContextProvider(props: PropsWithChildren) { } }; - const onclose = () => { + executeCommandWebSocket.current.onclose = () => { executeCommandWebSocket.current = null; + // If the execution failed, we won't get a SESSION_END message, so we // need to mark all the results as finished here. for (const nodeId of nodeIdToResultId.keys()) { @@ -546,8 +547,6 @@ export function AssistContextProvider(props: PropsWithChildren) { } nodeIdToResultId.clear(); }; - - executeCommandWebSocket.current = new AuthenticatedWebSocket(url, null, onmessage, null, onclose); } async function deleteConversation(conversationId: string) { diff --git a/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx b/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx index 9820351446396..ebdec9bfcf4dc 100644 --- a/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx +++ b/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx @@ -26,7 +26,7 @@ import React, { } from 'react'; import { Author, ServerMessage } from 'teleport/Assist/types'; -import { getAccessToken, getHostName } from 'teleport/services/api'; +import { getHostName } from 'teleport/services/api'; import useStickyClusterId from 'teleport/useStickyClusterId'; import cfg from 'teleport/config'; import { @@ -36,7 +36,7 @@ import { SuggestedCommandMessage, UserMessage, } from 'teleport/Console/DocumentSsh/TerminalAssist/types'; -import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebsoscket'; +import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket'; interface TerminalAssistContextValue { close: () => void; @@ -72,7 +72,9 @@ export function TerminalAssistContextProvider( const [messages, setMessages] = useState([]); useEffect(() => { - let onmessage = (e: MessageEvent) => { + socketRef.current = new AuthenticatedWebSocket(socketUrl); + + socketRef.current.onmessage = e => { const data = JSON.parse(e.data) as ServerMessage; const payload = JSON.parse(data.payload) as { action: string; @@ -93,8 +95,6 @@ export function TerminalAssistContextProvider( setLoading(false); setMessages(m => [message, ...m]); }; - - socketRef.current = new AuthenticatedWebSocket(socketUrl, null, onmessage); }, []); function close() { @@ -120,14 +120,15 @@ export function TerminalAssistContextProvider( 'ssh-explain' ); + const ws = new AuthenticatedWebSocket(socketUrl); - - let onopen = () => { - ws.send(encodedOutput); + ws.onopen = () => { + ws.send(encodedOutput); }; - let onmessage = (event: MessageEvent) => { - const msg = JSON.parse(event.data) as ServerMessage; + ws.onmessage = event => { + const message = event.data; + const msg = JSON.parse(message) as ServerMessage; const explanation: ExplanationMessage = { author: Author.Teleport, @@ -140,7 +141,6 @@ export function TerminalAssistContextProvider( ws.close(); }; - const ws = new AuthenticatedWebSocket(socketUrl, onopen, onmessage); } function send(message: string) { diff --git a/web/packages/teleport/src/lib/AuthenticatedWebSocket.ts b/web/packages/teleport/src/lib/AuthenticatedWebSocket.ts new file mode 100644 index 0000000000000..7986a9aae1ea1 --- /dev/null +++ b/web/packages/teleport/src/lib/AuthenticatedWebSocket.ts @@ -0,0 +1,261 @@ +import { getAccessToken } from 'teleport/services/api'; +import { WebsocketStatus } from 'teleport/types'; + +/** + * `AuthenticatedWebSocket` is a drop-in replacement for + * the `WebSocket` class that handles Teleport's websocket + * authentication process. + */ +export class AuthenticatedWebSocket extends WebSocket { + private authenticated: boolean = false; + private openListeners: ((this: WebSocket, ev: Event) => any)[] = []; + private onopenInternal: ((this: WebSocket, ev: Event) => any) | null = null; + private messageListeners: ((this: WebSocket, ev: MessageEvent) => any)[] = []; + private onmessageInternal: + | ((this: WebSocket, ev: MessageEvent) => any) + | null = null; + private oncloseListeners: ((this: WebSocket, ev: CloseEvent) => any)[] = []; + private oncloseInternal: ((this: WebSocket, ev: CloseEvent) => any) | null = + null; + private onerrorListeners: ((this: WebSocket, ev: Event) => any)[] = []; + private onerrorInternal: ((this: WebSocket, ev: Event) => any) | null = null; + private binaryTypeInternal: BinaryType = 'blob'; // Default binaryType + private onopenEvent: Event | null = null; + + constructor(url: string | URL, protocols?: string | string[]) { + super(url, protocols); + // Set the binaryType to 'arraybuffer' to handle the authentication process. + super.binaryType = 'arraybuffer'; + + // The open event listener should immediately send the authentication token + super.onopen = (onopenEvent: Event) => { + super.send(JSON.stringify({ token: getAccessToken() })); + // Don't call the user defined onopen messages yet, wait for the authentication response. + this.onopenEvent = onopenEvent; + }; + + // The message event listener should handle the authentication response, + // and if it succeeds, set the binaryType to the user-defined value and + // trigger any user-added open listeners. + super.onmessage = (ev: MessageEvent) => { + // If not yet authenticated, handle the authentication response. + if (!this.authenticated) { + // Parse the message as a WebsocketStatus. + let authResponse: WebsocketStatus; + try { + authResponse = JSON.parse(ev.data) as WebsocketStatus; + } catch (e) { + this.triggerError('Error parsing JSON from websocket message: ' + e); + return; + } + + // Validate the WebsocketStatus. + if ( + !authResponse.type || + !authResponse.status || + !(authResponse.type === 'create_session_response') || + !(authResponse.status === 'ok' || authResponse.status === 'error') + ) { + this.triggerError( + 'Invalid auth response: ' + JSON.stringify(authResponse) + ); + return; + } + + // Authentication succeeded. + if (authResponse.status === 'ok') { + this.authenticated = true; + // Set the binaryType to the value set by the user (or back to the default 'blob'). + super.binaryType = this.binaryTypeInternal; + // Now that authentication is complete, trigger any user-added open listeners + // with the original onopen event. + this.openListeners.forEach(listener => + listener.call(this, this.onopenEvent) + ); + this.onopenInternal?.call(this, this.onopenEvent); + return; + } else { + // Authentication failed, authResponse.status === 'error'. + this.triggerError( + 'auth error connecting to websocket: ' + authResponse.message + ); + return; + } + } else { + // If authenticated, pass messages to user-added listeners. + this.messageListeners.forEach(listener => { + listener.call(this, ev); + }); + this.onmessageInternal?.call(this, ev); + } + }; + + // Set the 'close' event for cleanup. + super.onclose = (ev: CloseEvent) => { + // Trigger any user-added close listeners + this.oncloseListeners.forEach(listener => listener.call(this, ev)); + this.oncloseInternal?.call(this, ev); + this.authenticated = false; + }; + + // Set the 'error' event for cleanup. + super.onerror = (ev: Event) => { + // Trigger any user-added error listeners + this.onerrorListeners.forEach(listener => listener.call(this, ev)); + this.onerrorInternal?.call(this, ev); + this.authenticated = false; + }; + } + + // Authenticated send + override send(data: string | ArrayBufferLike | Blob | ArrayBufferView): void { + if (!this.authenticated) { + // This should be unreachable, but just in case. + this.triggerError( + 'Cannot send data before authentication is complete. Data: ' + data + ); + return; + } + super.send(data); + } + + // Override addEventListener to intercept these listeners and store them in + // our appropriate arrays. They are called in the appropriate places in the + // `onopen`, `onmessage`, `onclose`, and `onerror` methods set in the constructor. + override addEventListener( + type: K, + listener: (this: WebSocket, ev: WebSocketEventMap[K]) => any + ): void { + if (type === 'open') { + this.openListeners.push( + listener as (this: WebSocket, ev: WebSocketEventMap['open']) => any + ); + } else if (type === 'message') { + this.messageListeners.push( + listener as (this: WebSocket, ev: WebSocketEventMap['message']) => any + ); + } else if (type === 'close') { + this.oncloseListeners.push( + listener as (this: WebSocket, ev: WebSocketEventMap['close']) => any + ); + } else if (type === 'error') { + this.onerrorListeners.push( + listener as (this: WebSocket, ev: WebSocketEventMap['error']) => any + ); + } else { + // This should be unreachable, but just in case. + super.addEventListener(type, listener); + } + } + + // Override the onopen, onmessage, onclose, and onerror properties to store the user-defined + // listeners in the appropriate internal properties. These are called in the appropriate places + // in the `onopen`, `onmessage`, `onclose`, and `onerror` methods set in the constructor. + + override set onopen(listener: (this: WebSocket, ev: Event) => any | null) { + this.onopenInternal = listener; + } + + override get onopen(): ((this: WebSocket, ev: Event) => any) | null { + return this.onopenInternal; + } + + override set onmessage( + listener: ((this: WebSocket, ev: MessageEvent) => any) | null + ) { + this.onmessageInternal = listener; + } + + override get onmessage(): + | ((this: WebSocket, ev: MessageEvent) => any) + | null { + return this.onmessageInternal; + } + + override set onclose( + listener: ((this: WebSocket, ev: CloseEvent) => any) | null + ) { + this.oncloseInternal = listener; + } + + override get onclose(): ((this: WebSocket, ev: CloseEvent) => any) | null { + return this.oncloseInternal; + } + + override set onerror(listener: ((this: WebSocket, ev: Event) => any) | null) { + this.onerrorInternal = listener; + } + + override get onerror(): ((this: WebSocket, ev: Event) => any) | null { + return this.onerrorInternal; + } + + // Override the binaryType property to store the user-defined binaryType in the appropriate internal property. + // This is because we need to set the binaryType to 'arraybuffer' for the authentication process (see constructor), + // and only then can we set it to the user-defined value. + override set binaryType(binaryType: BinaryType) { + if (this.authenticated) { + super.binaryType = binaryType; + return; + } + + this.binaryTypeInternal = binaryType; + } + + override get binaryType(): BinaryType { + return this.binaryTypeInternal; + } + + // Override removeEventListener to support listeners removal for 'open', 'message', and 'close' events + override removeEventListener( + type: K, + listener: (this: WebSocket, ev: WebSocketEventMap[K]) => any + ): void { + if (type === 'open') { + const index = this.openListeners.indexOf( + listener as (this: WebSocket, ev: WebSocketEventMap['open']) => any + ); + if (index !== -1) { + this.openListeners.splice(index, 1); + } + } else if (type === 'message') { + const index = this.messageListeners.indexOf( + listener as (this: WebSocket, ev: WebSocketEventMap['message']) => any + ); + if (index !== -1) { + this.messageListeners.splice(index, 1); + } + } else if (type === 'close') { + const index = this.oncloseListeners.indexOf( + listener as (this: WebSocket, ev: WebSocketEventMap['close']) => any + ); + if (index !== -1) { + this.oncloseListeners.splice(index, 1); + } + } else if (type === 'error') { + const index = this.onerrorListeners.indexOf( + listener as (this: WebSocket, ev: WebSocketEventMap['error']) => any + ); + if (index !== -1) { + this.onerrorListeners.splice(index, 1); + } + } else { + // This should be unreachable, but just in case. + super.removeEventListener( + type, + listener as EventListenerOrEventListenerObject + ); + } + } + + // Method to manually trigger an error event. + private triggerError(errorMessage: string): void { + const errorEvent = new ErrorEvent('error', { + error: new Error(errorMessage), + message: errorMessage, + }); + + // Dispatch the event to trigger all listeners attached for 'error' events. + this.dispatchEvent(errorEvent); + } +} diff --git a/web/packages/teleport/src/lib/AuthenticatedWebsoscket.ts b/web/packages/teleport/src/lib/AuthenticatedWebsoscket.ts deleted file mode 100644 index 8f717e7ddf3c9..0000000000000 --- a/web/packages/teleport/src/lib/AuthenticatedWebsoscket.ts +++ /dev/null @@ -1,113 +0,0 @@ -/** - * Teleport - * Copyright (C) 2023 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -import { getAccessToken } from 'teleport/services/api'; -import { WebsocketStatus } from 'teleport/types'; - -export class AuthenticatedWebSocket { - ws: WebSocket | undefined; - - onopenAfterAuth: (ev: Event) => void | undefined; - onmessageAfterAuth: (ev: MessageEvent) => void | undefined; - oncloseAfterAuth: (ev: CloseEvent) => void | undefined; - - private authenticated: boolean; - - constructor( - socketAddr: string, - onopen: (ev: Event) => void | null = null, - onmessage: (ev: MessageEvent) => void | null = null, - onerror: (ev: Event) => void | null = null, - onclose: (ev: CloseEvent) => void | null = null, - ) { - this.onopen = this.onopen.bind(this); - this.onmessage = this.onmessage.bind(this); - this.onclose = this.onclose.bind(this); - - this.authenticated = false; - this.onmessageAfterAuth = onmessage; - this.onopenAfterAuth = onopen; - this.oncloseAfterAuth = onclose; - - this.ws = new WebSocket(socketAddr); - this.ws.binaryType = 'arraybuffer'; - - this.ws.onopen = this.onopen; - this.ws.onmessage = this.onmessage; - this.ws.onerror = onerror; - this.ws.onclose = this.onclose; - } - - set binaryType(btype: BinaryType) { - this.ws.binaryType = btype - } - - get readyState(): number { - return this.ws.readyState - } - - private onopen(): void { - this.ws.send(JSON.stringify({ token: getAccessToken() })); - } - - private onmessage(ev: MessageEvent): void { - if (!this.authenticated) { - const authResponse = JSON.parse(ev.data) as WebsocketStatus; - if (authResponse.type != 'create_session_response') { - this.ws.close(); - console.log('invalid auth response type: ' + authResponse.message); - return; - } - - if (authResponse.status == 'error') { - this.ws.close(); - console.log( - 'auth error connecting to websocket: ' + authResponse.message - ); - return; - } - this.authenticated = true; - - if (this.onopenAfterAuth) { - this.onopenAfterAuth(ev); - } - - return; - } - - if (this.onmessageAfterAuth) { - this.onmessageAfterAuth(ev); - } - } - - private onclose(ev: CloseEvent): void { - if (this.oncloseAfterAuth) { - this.oncloseAfterAuth(ev); - } - this.authenticated = false; - this.ws = null; - } - - send(data: string | ArrayBufferLike | Blob | ArrayBufferView): void { - this.ws.send(data); - } - - close(code?: number, reason?: string): void { - this.ws.close(code, reason); - } -} diff --git a/web/packages/teleport/src/lib/tdp/client.ts b/web/packages/teleport/src/lib/tdp/client.ts index 44df2ccb8f563..0dcbc8077cde8 100644 --- a/web/packages/teleport/src/lib/tdp/client.ts +++ b/web/packages/teleport/src/lib/tdp/client.ts @@ -25,8 +25,7 @@ import init, { import { WebsocketCloseCode, TermEvent } from 'teleport/lib/term/enums'; import { EventEmitterWebAuthnSender } from 'teleport/lib/EventEmitterWebAuthnSender'; - -import { AuthenticatedWebSocket } from '../AuthenticatedWebsoscket'; +import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket'; import Codec, { MessageType, @@ -116,37 +115,36 @@ export default class Client extends EventEmitterWebAuthnSender { async connect(spec?: ClientScreenSpec) { await this.initWasm(); - let onopen = () => { - this.logger.info('websocket is open'); + this.socket = new AuthenticatedWebSocket(this.socketAddr); + this.socket.binaryType = 'arraybuffer'; + this.socket.onopen = () => { + this.logger.info('websocket is open'); this.emit(TdpClientEvent.WS_OPEN); if (spec) { this.sendClientScreenSpec(spec); } }; - let onmessage = async (ev: MessageEvent) => { + this.socket.onmessage = async (ev: MessageEvent) => { await this.processMessage(ev.data as ArrayBuffer); }; // The socket 'error' event will only ever be emitted by the socket // prior to a socket 'close' event (https://stackoverflow.com/a/40084550/6277051). // Therefore, we can rely on our onclose handler to account for any websocket errors. - let onerror = null; - let onclose = () => { + this.socket.onerror = null; + this.socket.onclose = () => { this.logger.info('websocket is closed'); + // Clean up all of our socket's listeners and the socket itself. + this.socket.onopen = null; + this.socket.onmessage = null; + this.socket.onclose = null; this.socket = null; + this.emit(TdpClientEvent.WS_CLOSE); }; - - this.socket = new AuthenticatedWebSocket( - this.socketAddr, - onopen, - onmessage, - onerror, - onclose - ); } private async initWasm() { @@ -565,9 +563,9 @@ export default class Client extends EventEmitterWebAuthnSender { protected send( data: string | ArrayBufferLike | Blob | ArrayBufferView ): void { - if (this.socket && this.socket.ws.readyState === 1) { + if (this.socket && this.socket.readyState === 1) { try { - this.socket.ws.send(data); + this.socket.send(data); } catch (e) { this.handleError(e, TdpClientEvent.CLIENT_ERROR); } @@ -690,7 +688,7 @@ export default class Client extends EventEmitterWebAuthnSender { ) { this.logger.error(err); this.emit(errType, err); - this.socket?.ws?.close(); + this.socket?.close(); } // Emits an warnType event @@ -709,7 +707,7 @@ export default class Client extends EventEmitterWebAuthnSender { // will simply do nothing. shutdown(closeCode = WebsocketCloseCode.NORMAL) { this.removeAllListeners(); - this.socket?.ws?.close(closeCode); + this.socket?.close(closeCode); } } diff --git a/web/packages/teleport/src/lib/term/tty.ts b/web/packages/teleport/src/lib/term/tty.ts index 39ac5e5da6dde..fe45eb930d65a 100644 --- a/web/packages/teleport/src/lib/term/tty.ts +++ b/web/packages/teleport/src/lib/term/tty.ts @@ -20,8 +20,7 @@ import Logger from 'shared/libs/logger'; import { EventEmitterWebAuthnSender } from 'teleport/lib/EventEmitterWebAuthnSender'; import { WebauthnAssertionResponse } from 'teleport/services/auth'; - -import { AuthenticatedWebSocket } from '../AuthenticatedWebsoscket'; +import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket'; import { EventType, TermEvent, WebsocketCloseCode } from './enums'; import { Protobuf, MessageTypeEnum } from './protobuf'; @@ -33,7 +32,7 @@ const defaultOptions = { }; class Tty extends EventEmitterWebAuthnSender { - socket: AuthenticatedWebSocket = null; + socket = null; _buffered = true; _attachSocketBufferTimer; @@ -64,13 +63,11 @@ class Tty extends EventEmitterWebAuthnSender { connect(w: number, h: number) { const connStr = this._addressResolver.getConnStr(w, h); - this.socket = new AuthenticatedWebSocket( - connStr, - this._onOpenConnection, - this._onMessage, - null, - this._onCloseConnection - ); + this.socket = new AuthenticatedWebSocket(connStr); + this.socket.binaryType = 'arraybuffer'; + this.socket.onopen = this._onOpenConnection; + this.socket.onmessage = this._onMessage; + this.socket.onclose = this._onCloseConnection; } send(data) { @@ -173,6 +170,9 @@ class Tty extends EventEmitterWebAuthnSender { } _onCloseConnection(e) { + this.socket.onopen = null; + this.socket.onmessage = null; + this.socket.onclose = null; this.socket = null; this.emit(TermEvent.CONN_CLOSE, e); logger.info('websocket is closed'); diff --git a/web/packages/teleport/src/types.ts b/web/packages/teleport/src/types.ts index 02c52361d9791..144db28946953 100644 --- a/web/packages/teleport/src/types.ts +++ b/web/packages/teleport/src/types.ts @@ -203,5 +203,5 @@ export enum RecommendationStatus { export type WebsocketStatus = { type: string; status: string; - message: string; + message?: string; }; From 27ed40a7a6cbc284bed1c4d8751ed93b560b4e89 Mon Sep 17 00:00:00 2001 From: Isaiah Becker-Mayer Date: Mon, 5 Feb 2024 11:40:36 -0800 Subject: [PATCH 14/17] Create a single authnWsUpgrader with a comment justifying why we turn off CORS --- lib/web/apiserver.go | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index e5ca340db555a..a81e43cb63928 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -3782,6 +3782,18 @@ func (h *Handler) writeErrToWS(ws *websocket.Conn, err error) { } } +// authnWsUpgrader is an upgrader that allows any origin to connect to the websocket. +// This makes our lives easier in our automated tests. While ordinarily this would be +// used to enforce the same-origin policy, we don't need to worry about that for authenticated +// websockets, which also require a valid bearer token sent over the websocket after upgrade. +// Therefore even if an attacker were to connect to the websocket and trick the browser into +// sending the session cookie, they would still fail to send the bearer token needed to authenticate. +var authnWsUpgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { return true }, +} + // WithClusterAuthWS wraps a ClusterWebsocketHandler to ensure that a request is authenticated // to this proxy via websocket if websocketAuth is true, or via query parameter if false (the same as WithAuth), as // well as to grab the remoteSite (which can represent this local cluster or a remote trusted cluster) @@ -3808,12 +3820,7 @@ func (h *Handler) WithClusterAuthWS(websocketAuth bool, fn ClusterWebsocketHandl if err != nil { return nil, trace.Wrap(err) } - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { return true }, - } - ws, err := upgrader.Upgrade(w, r, nil) + ws, err := authnWsUpgrader.Upgrade(w, r, nil) if err != nil { const errMsg = "Error upgrading to websocket" h.log.WithError(err).Error(errMsg) @@ -4272,13 +4279,7 @@ func (h *Handler) AuthenticateRequestWS(w http.ResponseWriter, r *http.Request) if err != nil { return nil, nil, trace.Wrap(err) } - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { return true }, - } - - ws, err := upgrader.Upgrade(w, r, nil) + ws, err := authnWsUpgrader.Upgrade(w, r, nil) if err != nil { return nil, nil, trace.ConnectionProblem(err, "Error upgrading to websocket: %v", err) } From 70b21d803bbc0cab6f944a3c4b495f7b31ed232d Mon Sep 17 00:00:00 2001 From: Isaiah Becker-Mayer Date: Mon, 5 Feb 2024 13:08:15 -0800 Subject: [PATCH 15/17] recieving to receiving --- lib/web/apiserver.go | 2 +- web/packages/teleport/src/lib/tdp/client.ts | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index a81e43cb63928..a443b614a04b3 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -4267,7 +4267,7 @@ type wsStatus struct { Message string `json:"message,omitempty"` } -// wsIODeadline is used to set a deadline for recieving a message from +// wsIODeadline is used to set a deadline for receiving a message from // an authenticated websocket so unauthenticated sockets dont get left // open. var wsIODeadline = time.Second * 4 diff --git a/web/packages/teleport/src/lib/tdp/client.ts b/web/packages/teleport/src/lib/tdp/client.ts index 0dcbc8077cde8..64b66dde7dffd 100644 --- a/web/packages/teleport/src/lib/tdp/client.ts +++ b/web/packages/teleport/src/lib/tdp/client.ts @@ -86,7 +86,7 @@ export enum LogType { } // Client is the TDP client. It is responsible for connecting to a websocket serving the tdp server, -// sending client commands, and recieving and processing server messages. Its creator is responsible for +// sending client commands, and receiving and processing server messages. Its creator is responsible for // ensuring the websocket gets closed and all of its event listeners cleaned up when it is no longer in use. // For convenience, this can be done in one fell swoop by calling Client.shutdown(). export default class Client extends EventEmitterWebAuthnSender { From 883e3d2d6f9a91ec1bb0e7dbcf28c06d4669e1ea Mon Sep 17 00:00:00 2001 From: Alex McGrath Date: Wed, 7 Feb 2024 12:40:02 +0000 Subject: [PATCH 16/17] resolve comments --- lib/web/apiserver.go | 53 ++++++++++++------- lib/web/apiserver_test.go | 14 ++--- .../src/lib/AuthenticatedWebSocket.ts | 18 +++++++ 3 files changed, 60 insertions(+), 25 deletions(-) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index a443b614a04b3..061a04e6916ec 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -152,6 +152,11 @@ type Handler struct { // tracer is used to create spans. tracer oteltrace.Tracer + + // wsIODeadline is used to set a deadline for receiving a message from + // an authenticated websocket so unauthenticated sockets dont get left + // open. + wsIODeadline time.Duration } // HandlerOption is a functional argument - an option that can be passed @@ -366,6 +371,7 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) { ClusterFeatures: cfg.ClusterFeatures, healthCheckAppServer: cfg.HealthCheckAppServer, tracer: cfg.TracerProvider.Tracer(teleport.ComponentWeb), + wsIODeadline: wsIODeadline, } // Check for self-hosted vs Cloud. @@ -723,8 +729,8 @@ func (h *Handler) bindDefaultEndpoints() { // active sessions handlers // Deprecated: The connect/ws variant should be used instead. // TODO(lxea): DELETE in v16 - h.GET("/webapi/sites/:site/connect", h.WithClusterAuthWS(false, h.siteNodeConnect)) // connect to an active session (via websocket) - h.GET("/webapi/sites/:site/connect/ws", h.WithClusterAuthWS(true, h.siteNodeConnect)) // connect to an active session (via websocket, with auth over websocket) + h.GET("/webapi/sites/:site/connect", h.WithClusterAuthWebSocket(false, h.siteNodeConnect)) // connect to an active session (via websocket) + h.GET("/webapi/sites/:site/connect/ws", h.WithClusterAuthWebSocket(true, h.siteNodeConnect)) // connect to an active session (via websocket, with auth over websocket) h.GET("/webapi/sites/:site/sessions", h.WithClusterAuth(h.clusterActiveAndPendingSessionsGet)) // get list of active and pending sessions // Audit events handlers. @@ -834,9 +840,9 @@ func (h *Handler) bindDefaultEndpoints() { // GET /webapi/sites/:site/desktops/:desktopName/connect?access_token=&username=&width=&height= // Deprecated: The connect/ws variant should be used instead. // TODO(lxea): DELETE in v16 - h.GET("/webapi/sites/:site/desktops/:desktopName/connect", h.WithClusterAuthWS(false, h.desktopConnectHandle)) + h.GET("/webapi/sites/:site/desktops/:desktopName/connect", h.WithClusterAuthWebSocket(false, h.desktopConnectHandle)) // GET /webapi/sites/:site/desktops/:desktopName/connect?username=&width=&height= - h.GET("/webapi/sites/:site/desktops/:desktopName/connect/ws", h.WithClusterAuthWS(true, h.desktopConnectHandle)) + h.GET("/webapi/sites/:site/desktops/:desktopName/connect/ws", h.WithClusterAuthWebSocket(true, h.desktopConnectHandle)) // GET /webapi/sites/:site/desktopplayback/:sid?access_token= h.GET("/webapi/sites/:site/desktopplayback/:sid", h.WithClusterAuth(h.desktopPlaybackHandle)) h.GET("/webapi/sites/:site/desktops/:desktopName/active", h.WithClusterAuth(h.desktopIsActive)) @@ -899,9 +905,9 @@ func (h *Handler) bindDefaultEndpoints() { // WebSocket endpoint for the chat conversation // Deprecated: The connect/ws variant should be used instead. // TODO(lxea): DELETE in v16 - h.GET("/webapi/sites/:site/assistant", h.WithClusterAuthWS(false, h.assistant)) + h.GET("/webapi/sites/:site/assistant", h.WithClusterAuthWebSocket(false, h.assistant)) // WebSocket endpoint for the chat conversation, websocket auth - h.GET("/webapi/sites/:site/assistant/ws", h.WithClusterAuthWS(true, h.assistant)) + h.GET("/webapi/sites/:site/assistant/ws", h.WithClusterAuthWebSocket(true, h.assistant)) // Sets the title for the conversation. h.POST("/webapi/assistant/conversations/:conversation_id/title", h.WithAuth(h.setAssistantTitle)) @@ -922,9 +928,9 @@ func (h *Handler) bindDefaultEndpoints() { // Allows executing an arbitrary command on multiple nodes. // Deprecated: The execute/ws variant should be used instead. // TODO(lxea): DELETE in v16 - h.GET("/webapi/command/:site/execute", h.WithClusterAuthWS(false, h.executeCommand)) + h.GET("/webapi/command/:site/execute", h.WithClusterAuthWebSocket(false, h.executeCommand)) // Allows executing an arbitrary command on multiple nodes, websocket auth. - h.GET("/webapi/command/:site/execute/ws", h.WithClusterAuthWS(true, h.executeCommand)) + h.GET("/webapi/command/:site/execute/ws", h.WithClusterAuthWebSocket(true, h.executeCommand)) // Fetches the user's preferences h.GET("/webapi/user/preferences", h.WithAuth(h.getUserPreferences)) @@ -3766,10 +3772,13 @@ func (h *Handler) WithClusterAuth(fn ClusterHandler) httprouter.Handle { }) } -func (h *Handler) writeErrToWS(ws *websocket.Conn, err error) { +func (h *Handler) writeErrToWebSocket(ws *websocket.Conn, err error) { + if err == nil { + return + } errEnvelope := Envelope{ Type: defaults.WebsocketError, - Payload: err.Error(), + Payload: trace.UserMessage(err), } env, err := errEnvelope.Marshal() if err != nil { @@ -3794,24 +3803,25 @@ var authnWsUpgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, } -// WithClusterAuthWS wraps a ClusterWebsocketHandler to ensure that a request is authenticated +// WithClusterAuthWebSocket wraps a ClusterWebsocketHandler to ensure that a request is authenticated // to this proxy via websocket if websocketAuth is true, or via query parameter if false (the same as WithAuth), as // well as to grab the remoteSite (which can represent this local cluster or a remote trusted cluster) // as specified by the ":site" url parameter. // // TODO(lxea): remove the 'websocketAuth' bool once the deprecated websocket handlers are removed -func (h *Handler) WithClusterAuthWS(websocketAuth bool, fn ClusterWebsocketHandler) httprouter.Handle { - return httplib.MakeHandler(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) { +func (h *Handler) WithClusterAuthWebSocket(websocketAuth bool, fn ClusterWebsocketHandler) httprouter.Handle { + return httplib.MakeHandler(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (any, error) { if websocketAuth { sctx, ws, site, err := h.authenticateWSRequestWithCluster(w, r, p) if err != nil { return nil, trace.Wrap(err) } - + // WS protocol requires the server send a close message + // which should be done by downstream users defer ws.Close() if _, err := fn(w, r, p, sctx, site, ws); err != nil { - h.writeErrToWS(ws, err) + h.writeErrToWebSocket(ws, err) } return nil, nil } @@ -3827,10 +3837,11 @@ func (h *Handler) WithClusterAuthWS(websocketAuth bool, fn ClusterWebsocketHandl http.Error(w, errMsg, http.StatusInternalServerError) return nil, nil } - + // WS protocol requires the server send a close message + // which should be done by downstream users defer ws.Close() if _, err := fn(w, r, p, sctx, site, ws); err != nil { - h.writeErrToWS(ws, err) + h.writeErrToWebSocket(ws, err) } return nil, nil }) @@ -4270,7 +4281,7 @@ type wsStatus struct { // wsIODeadline is used to set a deadline for receiving a message from // an authenticated websocket so unauthenticated sockets dont get left // open. -var wsIODeadline = time.Second * 4 +const wsIODeadline = time.Second * 4 // AuthenticateRequest authenticates request using combination of a session cookie // and bearer token retrieved from a websocket @@ -4292,11 +4303,15 @@ func (h *Handler) AuthenticateRequestWS(w http.ResponseWriter, r *http.Request) return nil, nil, trace.Wrap(err) } if err := sctx.validateBearerToken(r.Context(), t.Token); err != nil { - ws.WriteJSON(wsStatus{ + writeErr := ws.WriteJSON(wsStatus{ Type: "create_session_response", Status: "error", Message: "invalid token", }) + if writeErr != nil { + log.Errorf("Error while writing invalid token error to websocket: %s", writeErr) + } + return nil, nil, trace.Wrap(err) } diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index f5c3177a0abd9..d47abc5c0a6b3 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -9171,14 +9171,13 @@ func (s *fakeKubeService) ListKubernetesResources(ctx context.Context, req *kube }, nil } -func TestWSAuthenticateRequest(t *testing.T) { +func TestWebSocketAuthenticateRequest(t *testing.T) { t.Parallel() ctx := context.Background() env := newWebPack(t, 1) proxy := env.proxies[0] + proxy.handler.handler.wsIODeadline = time.Second pack := proxy.authPack(t, "test-user@example.com", nil) - wsIODeadline = time.Second - for _, tc := range []struct { name string serverExpectError string @@ -9220,13 +9219,16 @@ func TestWSAuthenticateRequest(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { sctx, ws, err := proxy.handler.handler.AuthenticateRequestWS(w, r) if err != nil { + if tc.serverExpectError == "" { + t.Errorf("unexpected error: %v", err) + } if !strings.Contains(err.Error(), tc.serverExpectError) { t.Errorf("unexpected error: %v", err) return } return } - defer ws.Close() + t.Cleanup(func() { ws.Close() }) if err == nil && tc.serverExpectError != "" { t.Errorf("expected error, got nil") return @@ -9252,8 +9254,8 @@ func TestWSAuthenticateRequest(t *testing.T) { u := strings.Replace(server.URL, "http:", "ws:", 1) conn, resp, err := websocket.DefaultDialer.Dial(u, header) require.NoError(t, err) - defer conn.Close() - defer resp.Body.Close() + t.Cleanup(func() { conn.Close() }) + t.Cleanup(func() { resp.Body.Close() }) if tc.readTimeout != nil { tc.readTimeout() diff --git a/web/packages/teleport/src/lib/AuthenticatedWebSocket.ts b/web/packages/teleport/src/lib/AuthenticatedWebSocket.ts index 7986a9aae1ea1..4c1d0c4e5e281 100644 --- a/web/packages/teleport/src/lib/AuthenticatedWebSocket.ts +++ b/web/packages/teleport/src/lib/AuthenticatedWebSocket.ts @@ -1,3 +1,21 @@ +/** + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + import { getAccessToken } from 'teleport/services/api'; import { WebsocketStatus } from 'teleport/types'; From cc4e557bce3353dfaaa785798f76a7c7af1a3d99 Mon Sep 17 00:00:00 2001 From: Isaiah Becker-Mayer Date: Thu, 8 Feb 2024 21:29:03 -0800 Subject: [PATCH 17/17] Updates `desktopPlaybackHandle` to new ws paradigm This was mistakenly left out of https://github.com/gravitational/teleport/pull/37520. This commit also refactors `WithClusterAuthWebSocket` slightly for easier comprehension, and updates the vite config to facilitate the new websocket endpoints in development mode. --- lib/web/apiserver.go | 47 ++++++++++++++++++++----------------- lib/web/desktop_playback.go | 11 +-------- 2 files changed, 27 insertions(+), 31 deletions(-) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 061a04e6916ec..ff06a2b7e4a4f 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -844,7 +844,11 @@ func (h *Handler) bindDefaultEndpoints() { // GET /webapi/sites/:site/desktops/:desktopName/connect?username=&width=&height= h.GET("/webapi/sites/:site/desktops/:desktopName/connect/ws", h.WithClusterAuthWebSocket(true, h.desktopConnectHandle)) // GET /webapi/sites/:site/desktopplayback/:sid?access_token= - h.GET("/webapi/sites/:site/desktopplayback/:sid", h.WithClusterAuth(h.desktopPlaybackHandle)) + // Deprecated: The desktopplayback/ws variant should be used instead. + // TODO(lxea): DELETE in v16 + h.GET("/webapi/sites/:site/desktopplayback/:sid", h.WithClusterAuthWebSocket(false, h.desktopPlaybackHandle)) + // // GET /webapi/sites/:site/desktopplayback/:sid/ws + h.GET("/webapi/sites/:site/desktopplayback/:sid/ws", h.WithClusterAuthWebSocket(true, h.desktopPlaybackHandle)) h.GET("/webapi/sites/:site/desktops/:desktopName/active", h.WithClusterAuth(h.desktopIsActive)) // GET a Connection Diagnostics by its name @@ -3811,32 +3815,20 @@ var authnWsUpgrader = websocket.Upgrader{ // TODO(lxea): remove the 'websocketAuth' bool once the deprecated websocket handlers are removed func (h *Handler) WithClusterAuthWebSocket(websocketAuth bool, fn ClusterWebsocketHandler) httprouter.Handle { return httplib.MakeHandler(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (any, error) { - if websocketAuth { - sctx, ws, site, err := h.authenticateWSRequestWithCluster(w, r, p) - if err != nil { - return nil, trace.Wrap(err) - } - // WS protocol requires the server send a close message - // which should be done by downstream users - defer ws.Close() + var sctx *SessionContext + var ws *websocket.Conn + var site reversetunnelclient.RemoteSite + var err error - if _, err := fn(w, r, p, sctx, site, ws); err != nil { - h.writeErrToWebSocket(ws, err) - } - return nil, nil + if websocketAuth { + sctx, ws, site, err = h.authenticateWSRequestWithCluster(w, r, p) + } else { + sctx, ws, site, err = h.authenticateWSRequestWithClusterDeprecated(w, r, p) } - sctx, site, err := h.authenticateRequestWithCluster(w, r, p) if err != nil { return nil, trace.Wrap(err) } - ws, err := authnWsUpgrader.Upgrade(w, r, nil) - if err != nil { - const errMsg = "Error upgrading to websocket" - h.log.WithError(err).Error(errMsg) - http.Error(w, errMsg, http.StatusInternalServerError) - return nil, nil - } // WS protocol requires the server send a close message // which should be done by downstream users defer ws.Close() @@ -3866,6 +3858,19 @@ func (h *Handler) authenticateWSRequestWithCluster(w http.ResponseWriter, r *htt return sctx, ws, site, nil } +// TODO(lxea): remove once the deprecated websocket handlers are removed +func (h *Handler) authenticateWSRequestWithClusterDeprecated(w http.ResponseWriter, r *http.Request, p httprouter.Params) (*SessionContext, *websocket.Conn, reversetunnelclient.RemoteSite, error) { + sctx, site, err := h.authenticateRequestWithCluster(w, r, p) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + ws, err := authnWsUpgrader.Upgrade(w, r, nil) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + return sctx, ws, site, nil +} + // authenticateRequestWithCluster ensures that a request is authenticated // to this proxy, returning the *SessionContext (same as AuthenticateRequest), // and also grabs the remoteSite (which can represent this local cluster or a diff --git a/lib/web/desktop_playback.go b/lib/web/desktop_playback.go index df04c330eed7f..9c50cdcc153c7 100644 --- a/lib/web/desktop_playback.go +++ b/lib/web/desktop_playback.go @@ -38,6 +38,7 @@ func (h *Handler) desktopPlaybackHandle( p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite, + ws *websocket.Conn, ) (interface{}, error) { sID := p.ByName("sid") if sID == "" { @@ -49,16 +50,6 @@ func (h *Handler) desktopPlaybackHandle( return nil, trace.Wrap(err) } - upgrader := websocket.Upgrader{ - ReadBufferSize: 4096, - WriteBufferSize: 4096, - } - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return nil, trace.Wrap(err) - } - defer ws.Close() - player, err := player.New(&player.Config{ Clock: h.clock, Log: h.log,