diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 189909f9a0f5f..cbf765756171f 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -37,6 +37,7 @@ import ( "time" "github.com/google/uuid" + "github.com/gorilla/websocket" "github.com/gravitational/oxy/ratelimit" "github.com/gravitational/roundtrip" "github.com/gravitational/trace" @@ -146,6 +147,11 @@ type Handler struct { // tracer is used to create spans. tracer oteltrace.Tracer + + // wsIODeadline is used to set a deadline for receiving a message from + // an authenticated websocket so unauthenticated sockets dont get left + // open. + wsIODeadline time.Duration } // HandlerOption is a functional argument - an option that can be passed @@ -348,6 +354,7 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) { ClusterFeatures: cfg.ClusterFeatures, healthCheckAppServer: cfg.HealthCheckAppServer, tracer: cfg.TracerProvider.Tracer(teleport.ComponentWeb), + wsIODeadline: wsIODeadline, } // Check for self-hosted vs Cloud. @@ -695,7 +702,10 @@ func (h *Handler) bindDefaultEndpoints() { h.DELETE("/webapi/sites/:site/locks/:uuid", h.WithClusterAuth(h.deleteClusterLock)) // active sessions handlers - h.GET("/webapi/sites/:site/connect", h.WithClusterAuth(h.siteNodeConnect)) // connect to an active session (via websocket) + // Deprecated: The connect/ws variant should be used instead. + // TODO(lxea): DELETE in v16 + h.GET("/webapi/sites/:site/connect", h.WithClusterAuthWebSocket(false, h.siteNodeConnect)) // connect to an active session (via websocket) + h.GET("/webapi/sites/:site/connect/ws", h.WithClusterAuthWebSocket(true, h.siteNodeConnect)) // connect to an active session (via websocket, with auth over websocket) h.GET("/webapi/sites/:site/sessions", h.WithClusterAuth(h.clusterActiveAndPendingSessionsGet)) // get list of active and pending sessions // Audit events handlers. @@ -800,9 +810,17 @@ func (h *Handler) bindDefaultEndpoints() { h.GET("/webapi/sites/:site/desktopservices", h.WithClusterAuth(h.clusterDesktopServicesGet)) h.GET("/webapi/sites/:site/desktops/:desktopName", h.WithClusterAuth(h.getDesktopHandle)) // GET /webapi/sites/:site/desktops/:desktopName/connect?access_token=&username=&width=&height= - h.GET("/webapi/sites/:site/desktops/:desktopName/connect", h.WithClusterAuth(h.desktopConnectHandle)) + // Deprecated: The connect/ws variant should be used instead. + // TODO(lxea): DELETE in v16 + h.GET("/webapi/sites/:site/desktops/:desktopName/connect", h.WithClusterAuthWebSocket(false, h.desktopConnectHandle)) + // GET /webapi/sites/:site/desktops/:desktopName/connect?username=&width=&height= + h.GET("/webapi/sites/:site/desktops/:desktopName/connect/ws", h.WithClusterAuthWebSocket(true, h.desktopConnectHandle)) // GET /webapi/sites/:site/desktopplayback/:sid?access_token= - h.GET("/webapi/sites/:site/desktopplayback/:sid", h.WithClusterAuth(h.desktopPlaybackHandle)) + // Deprecated: The desktopplayback/ws variant should be used instead. + // TODO(lxea): DELETE in v16 + h.GET("/webapi/sites/:site/desktopplayback/:sid", h.WithClusterAuthWebSocket(false, h.desktopPlaybackHandle)) + // // GET /webapi/sites/:site/desktopplayback/:sid/ws + h.GET("/webapi/sites/:site/desktopplayback/:sid/ws", h.WithClusterAuthWebSocket(true, h.desktopPlaybackHandle)) h.GET("/webapi/sites/:site/desktops/:desktopName/active", h.WithClusterAuth(h.desktopIsActive)) // GET a Connection Diagnostics by its name @@ -858,7 +876,11 @@ func (h *Handler) bindDefaultEndpoints() { h.GET("/webapi/sites/:site/user-groups", h.WithClusterAuth(h.getUserGroups)) // WebSocket endpoint for the chat conversation - h.GET("/webapi/sites/:site/assistant", h.WithClusterAuth(h.assistant)) + // Deprecated: The connect/ws variant should be used instead. + // TODO(lxea): DELETE in v16 + h.GET("/webapi/sites/:site/assistant", h.WithClusterAuthWebSocket(false, h.assistant)) + // WebSocket endpoint for the chat conversation, websocket auth + h.GET("/webapi/sites/:site/assistant/ws", h.WithClusterAuthWebSocket(true, h.assistant)) // Sets the title for the conversation. h.POST("/webapi/assistant/conversations/:conversation_id/title", h.WithAuth(h.setAssistantTitle)) @@ -877,7 +899,11 @@ func (h *Handler) bindDefaultEndpoints() { h.GET("/webapi/assistant/conversations/:conversation_id", h.WithAuth(h.getAssistantConversationByID)) // Allows executing an arbitrary command on multiple nodes. - h.GET("/webapi/command/:site/execute", h.WithClusterAuth(h.executeCommand)) + // Deprecated: The execute/ws variant should be used instead. + // TODO(lxea): DELETE in v16 + h.GET("/webapi/command/:site/execute", h.WithClusterAuthWebSocket(false, h.executeCommand)) + // Allows executing an arbitrary command on multiple nodes, websocket auth. + h.GET("/webapi/command/:site/execute/ws", h.WithClusterAuthWebSocket(true, h.executeCommand)) // Fetches the user's preferences h.GET("/webapi/user/preferences", h.WithAuth(h.getUserPreferences)) @@ -2942,6 +2968,7 @@ func (h *Handler) siteNodeConnect( p httprouter.Params, sessionCtx *SessionContext, site reversetunnelclient.RemoteSite, + ws *websocket.Conn, ) (interface{}, error) { q := r.URL.Query() params := q.Get("params") @@ -3034,6 +3061,7 @@ func (h *Handler) siteNodeConnect( PROXYSigner: h.cfg.PROXYSigner, Tracker: tracker, Clock: h.clock, + WebsocketConn: ws, } term, err := NewTerminal(ctx, terminalConfig) @@ -3752,6 +3780,9 @@ type ContextHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Pa // ClusterHandler is a authenticated handler that is called for some existing remote cluster type ClusterHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) +// ClusterWebsocketHandler is a authenticated websocket handler that is called for some existing remote cluster +type ClusterWebsocketHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite, ws *websocket.Conn) (interface{}, error) + // WithClusterAuth wraps a ClusterHandler to ensure that a request is authenticated to this proxy // (the same as WithAuth), as well as to grab the remoteSite (which can represent this local cluster // or a remote trusted cluster) as specified by the ":site" url parameter. @@ -3766,12 +3797,108 @@ func (h *Handler) WithClusterAuth(fn ClusterHandler) httprouter.Handle { }) } +func (h *Handler) writeErrToWebSocket(ws *websocket.Conn, err error) { + if err == nil { + return + } + errEnvelope := Envelope{ + Type: defaults.WebsocketError, + Payload: trace.UserMessage(err), + } + env, err := errEnvelope.Marshal() + if err != nil { + h.log.WithError(err).Error("error marshaling proto") + return + } + if err := ws.WriteMessage(websocket.BinaryMessage, env); err != nil { + h.log.WithError(err).Error("error writing proto") + return + } +} + +// authnWsUpgrader is an upgrader that allows any origin to connect to the websocket. +// This makes our lives easier in our automated tests. While ordinarily this would be +// used to enforce the same-origin policy, we don't need to worry about that for authenticated +// websockets, which also require a valid bearer token sent over the websocket after upgrade. +// Therefore even if an attacker were to connect to the websocket and trick the browser into +// sending the session cookie, they would still fail to send the bearer token needed to authenticate. +var authnWsUpgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { return true }, +} + +// WithClusterAuthWebSocket wraps a ClusterWebsocketHandler to ensure that a request is authenticated +// to this proxy via websocket if websocketAuth is true, or via query parameter if false (the same as WithAuth), as +// well as to grab the remoteSite (which can represent this local cluster or a remote trusted cluster) +// as specified by the ":site" url parameter. +// +// TODO(lxea): remove the 'websocketAuth' bool once the deprecated websocket handlers are removed +func (h *Handler) WithClusterAuthWebSocket(websocketAuth bool, fn ClusterWebsocketHandler) httprouter.Handle { + return httplib.MakeHandler(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (any, error) { + var sctx *SessionContext + var ws *websocket.Conn + var site reversetunnelclient.RemoteSite + var err error + + if websocketAuth { + sctx, ws, site, err = h.authenticateWSRequestWithCluster(w, r, p) + } else { + sctx, ws, site, err = h.authenticateWSRequestWithClusterDeprecated(w, r, p) + } + + if err != nil { + return nil, trace.Wrap(err) + } + // WS protocol requires the server send a close message + // which should be done by downstream users + defer ws.Close() + if _, err := fn(w, r, p, sctx, site, ws); err != nil { + h.writeErrToWebSocket(ws, err) + } + return nil, nil + }) +} + +// authenticateWSRequestWithCluster ensures that a request is +// authenticated to this proxy via websocket, returning the +// *SessionContext (same as AuthenticateRequest), and also grabs the +// remoteSite (which can represent this local cluster or a remote +// trusted cluster) as specified by the ":site" url parameter. +func (h *Handler) authenticateWSRequestWithCluster(w http.ResponseWriter, r *http.Request, p httprouter.Params) (*SessionContext, *websocket.Conn, reversetunnelclient.RemoteSite, error) { + sctx, ws, err := h.AuthenticateRequestWS(w, r) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + + site, err := h.getSiteByParams(sctx, p) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + + return sctx, ws, site, nil +} + +// TODO(lxea): remove once the deprecated websocket handlers are removed +func (h *Handler) authenticateWSRequestWithClusterDeprecated(w http.ResponseWriter, r *http.Request, p httprouter.Params) (*SessionContext, *websocket.Conn, reversetunnelclient.RemoteSite, error) { + sctx, site, err := h.authenticateRequestWithCluster(w, r, p) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + ws, err := authnWsUpgrader.Upgrade(w, r, nil) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + return sctx, ws, site, nil +} + // authenticateRequestWithCluster ensures that a request is authenticated // to this proxy, returning the *SessionContext (same as AuthenticateRequest), // and also grabs the remoteSite (which can represent this local cluster or a // remote trusted cluster) as specified by the ":site" url parameter. func (h *Handler) authenticateRequestWithCluster(w http.ResponseWriter, r *http.Request, p httprouter.Params) (*SessionContext, reversetunnelclient.RemoteSite, error) { sctx, err := h.AuthenticateRequest(w, r, true) + if err != nil { return nil, nil, trace.Wrap(err) } @@ -4089,9 +4216,7 @@ func rateLimitRequest(r *http.Request, limiter *limiter.RateLimiter) error { return trace.Wrap(err) } -// AuthenticateRequest authenticates request using combination of a session cookie -// and bearer token -func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, checkBearerToken bool) (*SessionContext, error) { +func (h *Handler) validateCookie(w http.ResponseWriter, r *http.Request) (*SessionContext, error) { const missingCookieMsg = "missing session cookie" cookie, err := r.Cookie(websession.CookieName) if err != nil || (cookie != nil && cookie.Value == "") { @@ -4101,11 +4226,22 @@ func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, ch if err != nil { return nil, trace.AccessDenied("failed to decode cookie") } - ctx, err := h.auth.getOrCreateSession(r.Context(), decodedCookie.User, decodedCookie.SID) + sctx, err := h.auth.getOrCreateSession(r.Context(), decodedCookie.User, decodedCookie.SID) if err != nil { websession.ClearCookie(w) return nil, trace.AccessDenied("need auth") } + + return sctx, nil +} + +// AuthenticateRequest authenticates request using combination of a session cookie +// and bearer token +func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, checkBearerToken bool) (*SessionContext, error) { + ctx, err := h.validateCookie(w, r) + if err != nil { + return nil, trace.Wrap(err) + } if checkBearerToken { creds, err := roundtrip.ParseAuthHeaders(r) if err != nil { @@ -4118,6 +4254,68 @@ func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, ch return ctx, nil } +type wsBearerToken struct { + Token string `json:"token"` +} + +type wsStatus struct { + Type string `json:"type"` + Status string `json:"status"` + Message string `json:"message,omitempty"` +} + +// wsIODeadline is used to set a deadline for receiving a message from +// an authenticated websocket so unauthenticated sockets dont get left +// open. +const wsIODeadline = time.Second * 4 + +// AuthenticateRequest authenticates request using combination of a session cookie +// and bearer token retrieved from a websocket +func (h *Handler) AuthenticateRequestWS(w http.ResponseWriter, r *http.Request) (*SessionContext, *websocket.Conn, error) { + sctx, err := h.validateCookie(w, r) + if err != nil { + return nil, nil, trace.Wrap(err) + } + ws, err := authnWsUpgrader.Upgrade(w, r, nil) + if err != nil { + return nil, nil, trace.ConnectionProblem(err, "Error upgrading to websocket: %v", err) + } + if err := ws.SetReadDeadline(time.Now().Add(wsIODeadline)); err != nil { + return nil, nil, trace.ConnectionProblem(err, "Error setting websocket read deadline: %v", err) + } + + var t wsBearerToken + if err := ws.ReadJSON(&t); err != nil { + return nil, nil, trace.Wrap(err) + } + if err := sctx.validateBearerToken(r.Context(), t.Token); err != nil { + writeErr := ws.WriteJSON(wsStatus{ + Type: "create_session_response", + Status: "error", + Message: "invalid token", + }) + if writeErr != nil { + log.Errorf("Error while writing invalid token error to websocket: %s", writeErr) + } + + return nil, nil, trace.Wrap(err) + } + + if err := ws.WriteJSON(wsStatus{ + Type: "create_session_response", + Status: "ok", + }); err != nil { + return nil, nil, trace.Wrap(err) + } + + // unset the deadline as downstream consumers should handle this themselves. + if err := ws.SetReadDeadline(time.Time{}); err != nil { + return nil, nil, trace.ConnectionProblem(err, "Error setting websocket read deadline: %v", err) + } + + return sctx, ws, nil +} + // ProxyWithRoles returns a reverse tunnel proxy verifying the permissions // of the given user. func (h *Handler) ProxyWithRoles(ctx *SessionContext) (reversetunnelclient.Tunnel, error) { diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index c389b92e76a65..a960d9c7cdeb4 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -7435,7 +7435,7 @@ func (s *WebSuite) makeTerminal(t *testing.T, pack *authPack, opts ...terminalOp u := url.URL{ Host: s.url().Host, Scheme: client.WSS, - Path: fmt.Sprintf("/v1/webapi/sites/%v/connect", currentSiteShortcut), + Path: fmt.Sprintf("/v1/webapi/sites/%v/connect/ws", currentSiteShortcut), } data, err := json.Marshal(req) if err != nil { @@ -7444,7 +7444,6 @@ func (s *WebSuite) makeTerminal(t *testing.T, pack *authPack, opts ...terminalOp q := u.Query() q.Set("params", string(data)) - q.Set(roundtrip.AccessTokenQueryParam, pack.session.Token) u.RawQuery = q.Encode() dialer := websocket.Dialer{} @@ -7470,6 +7469,10 @@ func (s *WebSuite) makeTerminal(t *testing.T, pack *authPack, opts ...terminalOp return nil, nil, trace.Wrap(err, sb.String()) } + if err := makeAuthReqOverWS(ws, pack.session.Token); err != nil { + return nil, nil, trace.Wrap(err) + } + ty, raw, err := ws.ReadMessage() if err != nil { return nil, nil, trace.Wrap(err) @@ -8307,6 +8310,9 @@ func (r *testProxy) makeTerminal(t *testing.T, pack *authPack, sessionID session require.NoError(t, resp.Body.Close()) }) + err = makeAuthReqOverWS(ws, pack.session.Token) + require.NoError(t, err) + ty, raw, err := ws.ReadMessage() require.NoError(t, err) require.Equal(t, websocket.BinaryMessage, ty) @@ -8319,18 +8325,38 @@ func (r *testProxy) makeTerminal(t *testing.T, pack *authPack, sessionID session return ws, sessResp.Session } +func makeAuthReqOverWS(ws *websocket.Conn, token string) error { + authReq, err := json.Marshal(struct { + Token string `json:"token"` + }{Token: token}) + if err != nil { + return trace.Wrap(err) + } + + if err := ws.WriteMessage(websocket.TextMessage, authReq); err != nil { + return trace.Wrap(err) + } + _, authRes, err := ws.ReadMessage() + if err != nil { + return trace.Wrap(err) + } + if !strings.Contains(string(authRes), `"status":"ok"`) { + return trace.AccessDenied("unexpected response") + } + return nil +} + func (r *testProxy) makeDesktopSession(t *testing.T, pack *authPack, sessionID session.ID, addr net.Addr) *websocket.Conn { u := url.URL{ Host: r.webURL.Host, Scheme: client.WSS, - Path: fmt.Sprintf("/webapi/sites/%s/desktops/%s/connect", currentSiteShortcut, "desktop1"), + Path: fmt.Sprintf("/webapi/sites/%s/desktops/%s/connect/ws", currentSiteShortcut, "desktop1"), } q := u.Query() q.Set("username", "marek") q.Set("width", "100") q.Set("height", "100") - q.Set(roundtrip.AccessTokenQueryParam, pack.session.Token) u.RawQuery = q.Encode() dialer := websocket.Dialer{} @@ -8345,6 +8371,10 @@ func (r *testProxy) makeDesktopSession(t *testing.T, pack *authPack, sessionID s ws, resp, err := dialer.Dial(u.String(), header) require.NoError(t, err) + + err = makeAuthReqOverWS(ws, pack.session.Token) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws.Close()) require.NoError(t, resp.Body.Close()) @@ -9269,6 +9299,111 @@ func (s *fakeKubeService) ListKubernetesResources(ctx context.Context, req *kube }, nil } +func TestWebSocketAuthenticateRequest(t *testing.T) { + t.Parallel() + ctx := context.Background() + env := newWebPack(t, 1) + proxy := env.proxies[0] + proxy.handler.handler.wsIODeadline = time.Second + pack := proxy.authPack(t, "test-user@example.com", nil) + for _, tc := range []struct { + name string + serverExpectError string + expectResponse wsStatus + token string + writeTimeout func() + readTimeout func() + }{ + { + name: "valid token", + expectResponse: wsStatus{ + Type: "create_session_response", + Status: "ok", + }, + token: pack.session.Token, + }, + { + name: "invalid token", + serverExpectError: "not found", + expectResponse: wsStatus{ + Type: "create_session_response", + Status: "error", + Message: "invalid token", + }, + token: "honk", + }, + { + name: "server read timeout", + serverExpectError: "i/o timeout", + token: pack.session.Token, + readTimeout: func() { + <-time.After(wsIODeadline * 3) + }, + }, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sctx, ws, err := proxy.handler.handler.AuthenticateRequestWS(w, r) + if err != nil { + if tc.serverExpectError == "" { + t.Errorf("unexpected error: %v", err) + } + if !strings.Contains(err.Error(), tc.serverExpectError) { + t.Errorf("unexpected error: %v", err) + return + } + return + } + t.Cleanup(func() { ws.Close() }) + if err == nil && tc.serverExpectError != "" { + t.Errorf("expected error, got nil") + return + } + + clt, err := sctx.GetClient() + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + _, err = clt.GetDomainName(ctx) + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + })) + + header := http.Header{} + for _, cookie := range pack.cookies { + header.Add("Cookie", cookie.String()) + } + + u := strings.Replace(server.URL, "http:", "ws:", 1) + conn, resp, err := websocket.DefaultDialer.Dial(u, header) + require.NoError(t, err) + t.Cleanup(func() { conn.Close() }) + t.Cleanup(func() { resp.Body.Close() }) + + if tc.readTimeout != nil { + tc.readTimeout() + } + err = conn.WriteJSON(wsBearerToken{ + Token: tc.token, + }) + require.NoError(t, err) + if tc.readTimeout != nil { + return // Reading will fail as the server will have closed the connection + } + + var status wsStatus + err = conn.ReadJSON(&status) + require.NoError(t, err) + require.Equal(t, tc.expectResponse, status) + }) + } +} + // TestSimultaneousAuthenticateRequest ensures that multiple authenticated // requests do not race to create a SessionContext. This would happen when // Proxies were deployed behind a round-robin load balancer. Only the Proxy diff --git a/lib/web/assistant.go b/lib/web/assistant.go index 843f93f9d39dc..b2906c62b45b0 100644 --- a/lib/web/assistant.go +++ b/lib/web/assistant.go @@ -332,9 +332,9 @@ func (h *Handler) generateAssistantTitle(_ http.ResponseWriter, r *http.Request, // This handler covers the main chat conversation as well as the // SSH completition (SSH command generation and output explanation). func (h *Handler) assistant(w http.ResponseWriter, r *http.Request, _ httprouter.Params, - sctx *SessionContext, site reversetunnelclient.RemoteSite, + sctx *SessionContext, site reversetunnelclient.RemoteSite, ws *websocket.Conn, ) (any, error) { - if err := runAssistant(h, w, r, sctx, site); err != nil { + if err := runAssistant(h, w, r, sctx, site, ws); err != nil { h.log.Warn(trace.DebugReport(err)) return nil, trace.Wrap(err) } @@ -420,7 +420,7 @@ func checkAssistEnabled(a auth.ClientI, ctx context.Context) error { // runAssistant upgrades the HTTP connection to a websocket and starts a chat loop. func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, - sctx *SessionContext, site reversetunnelclient.RemoteSite, + sctx *SessionContext, site reversetunnelclient.RemoteSite, ws *websocket.Conn, ) (err error) { q := r.URL.Query() conversationID := q.Get("conversation_id") @@ -455,20 +455,6 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, return trace.Wrap(err) } - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { return true }, - } - - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - errMsg := "Error upgrading to websocket" - h.log.WithError(err).Error(errMsg) - http.Error(w, errMsg, http.StatusInternalServerError) - return nil - } - // Note: This time should be longer than OpenAI response time. keepAliveInterval := netConfig.GetKeepAliveInterval() err = ws.SetReadDeadline(deadlineForInterval(keepAliveInterval)) diff --git a/lib/web/command.go b/lib/web/command.go index 3a6aa08197ce7..e9c6e979ad846 100644 --- a/lib/web/command.go +++ b/lib/web/command.go @@ -128,6 +128,7 @@ func (h *Handler) executeCommand( _ httprouter.Params, sessionCtx *SessionContext, site reversetunnelclient.RemoteSite, + rawWS *websocket.Conn, ) (any, error) { q := r.URL.Query() params := q.Get("params") @@ -171,20 +172,6 @@ func (h *Handler) executeCommand( clusterName := site.GetName() - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { return true }, - } - - rawWS, err := upgrader.Upgrade(w, r, nil) - if err != nil { - errMsg := "Error upgrading to websocket" - h.log.WithError(err).Error(errMsg) - http.Error(w, errMsg, http.StatusInternalServerError) - return nil, nil - } - defer func() { rawWS.WriteMessage(websocket.CloseMessage, nil) rawWS.Close() diff --git a/lib/web/desktop.go b/lib/web/desktop.go index 8ac3b99ce07a0..62fbbf90ace71 100644 --- a/lib/web/desktop.go +++ b/lib/web/desktop.go @@ -63,6 +63,7 @@ func (h *Handler) desktopConnectHandle( p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite, + ws *websocket.Conn, ) (interface{}, error) { desktopName := p.ByName("desktopName") if desktopName == "" { @@ -72,7 +73,7 @@ func (h *Handler) desktopConnectHandle( log := sctx.cfg.Log.WithField("desktop-name", desktopName).WithField("cluster-name", site.GetName()) log.Debug("New desktop access websocket connection") - if err := h.createDesktopConnection(w, r, desktopName, site.GetName(), log, sctx, site); err != nil { + if err := h.createDesktopConnection(w, r, desktopName, site.GetName(), log, sctx, site, ws); err != nil { // createDesktopConnection makes a best effort attempt to send an error to the user // (via websocket) before terminating the connection. We log the error here, but // return nil because our HTTP middleware will try to write the returned error in JSON @@ -97,15 +98,8 @@ func (h *Handler) createDesktopConnection( log *logrus.Entry, sctx *SessionContext, site reversetunnelclient.RemoteSite, + ws *websocket.Conn, ) error { - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - } - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return trace.Wrap(err) - } defer ws.Close() sendTDPError := func(err error) error { diff --git a/lib/web/desktop_playback.go b/lib/web/desktop_playback.go index be2035288e580..be32bbb36fbd5 100644 --- a/lib/web/desktop_playback.go +++ b/lib/web/desktop_playback.go @@ -31,22 +31,20 @@ func (h *Handler) desktopPlaybackHandle( w http.ResponseWriter, r *http.Request, p httprouter.Params, - ctx *SessionContext, + sctx *SessionContext, site reversetunnelclient.RemoteSite, + ws *websocket.Conn, ) (interface{}, error) { sID := p.ByName("sid") if sID == "" { return nil, trace.BadParameter("missing sid in request URL") } - clt, err := ctx.GetUserClient(r.Context(), site) + clt, err := sctx.GetUserClient(r.Context(), site) if err != nil { return nil, trace.Wrap(err) } - websocket.Handler(func(ws *websocket.Conn) { - defer h.log.Debug("desktopPlaybackHandle websocket handler returned") - desktop.NewPlayer(sID, ws, clt, h.log).Play(r.Context()) - }).ServeHTTP(w, r) + desktop.NewPlayer(sID, ws, clt, h.log).Play(r.Context()) return nil, nil } diff --git a/lib/web/terminal.go b/lib/web/terminal.go index ffab9e8e2c12d..f8da3fd870622 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -138,6 +138,7 @@ func NewTerminal(ctx context.Context, cfg TerminalHandlerConfig) (*TerminalHandl participantMode: cfg.ParticipantMode, tracker: cfg.Tracker, clock: cfg.Clock, + websocketConn: cfg.WebsocketConn, }, nil } @@ -182,6 +183,8 @@ type TerminalHandlerConfig struct { Tracker types.SessionTracker // Clock used for presence checking. Clock clockwork.Clock + // WebsocketConn is the active websocket connection + WebsocketConn *websocket.Conn } func (t *TerminalHandlerConfig) CheckAndSetDefaults() error { @@ -288,12 +291,14 @@ type TerminalHandler struct { // if the user is not joining a session. tracker types.SessionTracker - // clock to use for presence checking - clock clockwork.Clock - // closedByClient indicates if the websocket connection was closed by the // user (closing the browser tab, exiting the session, etc). closedByClient atomic.Bool + // clock used to interact with time. + clock clockwork.Clock + + // websocketConn is the active websocket connection + websocketConn *websocket.Conn } // ServeHTTP builds a connection to the remote node and then pumps back two types of @@ -305,21 +310,9 @@ func (t *TerminalHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { t.ctx.AddClosers(t) defer t.ctx.RemoveCloser(t) - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { return true }, - } - - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - errMsg := "Error upgrading to websocket" - t.log.WithError(err).Error(errMsg) - http.Error(w, errMsg, http.StatusInternalServerError) - return - } + ws := t.websocketConn - err = ws.SetReadDeadline(deadlineForInterval(t.keepAliveInterval)) + err := ws.SetReadDeadline(deadlineForInterval(t.keepAliveInterval)) if err != nil { t.log.WithError(err).Error("Error setting websocket readline") return diff --git a/web/packages/teleport/src/Assist/context/AssistContext.tsx b/web/packages/teleport/src/Assist/context/AssistContext.tsx index badc1838bc3bf..50a4c0d455b63 100644 --- a/web/packages/teleport/src/Assist/context/AssistContext.tsx +++ b/web/packages/teleport/src/Assist/context/AssistContext.tsx @@ -29,7 +29,7 @@ import { AssistStateActionType, reducer } from 'teleport/Assist/context/state'; import { convertServerMessages } from 'teleport/Assist/context/utils'; import useStickyClusterId from 'teleport/useStickyClusterId'; import cfg from 'teleport/config'; -import { getAccessToken, getHostName } from 'teleport/services/api'; +import { getHostName } from 'teleport/services/api'; import { AccessRequestClientMessage, @@ -46,6 +46,7 @@ import { makeMfaAuthenticateChallenge, WebauthnAssertionResponse, } from 'teleport/services/auth'; +import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket'; import * as service from '../service'; import { @@ -82,9 +83,9 @@ let lastCommandExecutionResultId = 0; const TEN_MINUTES = 10 * 60 * 1000; export function AssistContextProvider(props: PropsWithChildren) { - const activeWebSocket = useRef(null); + const activeWebSocket = useRef(null); // TODO(ryan): this should be removed once https://github.com/gravitational/teleport.e/pull/1609 is implemented - const executeCommandWebSocket = useRef(null); + const executeCommandWebSocket = useRef(null); const refreshWebSocketTimeout = useRef(null); const { clusterId } = useStickyClusterId(); @@ -122,11 +123,10 @@ export function AssistContextProvider(props: PropsWithChildren) { } function setupWebSocket(conversationId: string, initialMessage?: string) { - activeWebSocket.current = new WebSocket( + activeWebSocket.current = new AuthenticatedWebSocket( cfg.getAssistConversationWebSocketUrl( getHostName(), clusterId, - getAccessToken(), conversationId ) ); @@ -348,7 +348,7 @@ export function AssistContextProvider(props: PropsWithChildren) { if ( !activeWebSocket.current || - activeWebSocket.current.readyState === WebSocket.CLOSED + activeWebSocket.current.readyState === AuthenticatedWebSocket.CLOSED ) { setupWebSocket(state.conversations.selectedId, data); } else { @@ -378,7 +378,8 @@ export function AssistContextProvider(props: PropsWithChildren) { function sendMfaChallenge(data: WebauthnAssertionResponse) { if ( !executeCommandWebSocket.current || - executeCommandWebSocket.current.readyState !== WebSocket.OPEN || + executeCommandWebSocket.current.readyState !== + AuthenticatedWebSocket.OPEN || !data ) { console.warn( @@ -446,12 +447,11 @@ export function AssistContextProvider(props: PropsWithChildren) { const url = cfg.getAssistExecuteCommandUrl( getHostName(), clusterId, - getAccessToken(), execParams ); const proto = new Protobuf(); - executeCommandWebSocket.current = new WebSocket(url); + executeCommandWebSocket.current = new AuthenticatedWebSocket(url); executeCommandWebSocket.current.binaryType = 'arraybuffer'; executeCommandWebSocket.current.onmessage = event => { diff --git a/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx b/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx index 9af1886c7a1a9..0b6535bdc4c5a 100644 --- a/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx +++ b/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx @@ -24,7 +24,7 @@ import React, { } from 'react'; import { Author, ServerMessage } from 'teleport/Assist/types'; -import { getAccessToken, getHostName } from 'teleport/services/api'; +import { getHostName } from 'teleport/services/api'; import useStickyClusterId from 'teleport/useStickyClusterId'; import cfg from 'teleport/config'; import { @@ -34,6 +34,7 @@ import { SuggestedCommandMessage, UserMessage, } from 'teleport/Console/DocumentSsh/TerminalAssist/types'; +import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket'; interface TerminalAssistContextValue { close: () => void; @@ -55,11 +56,10 @@ export function TerminalAssistContextProvider( const [visible, setVisible] = useState(false); - const socketRef = useRef(null); + const socketRef = useRef(null); const socketUrl = cfg.getAssistActionWebSocketUrl( getHostName(), clusterId, - getAccessToken(), 'ssh-cmdgen' ); @@ -70,7 +70,7 @@ export function TerminalAssistContextProvider( const [messages, setMessages] = useState([]); useEffect(() => { - socketRef.current = new WebSocket(socketUrl); + socketRef.current = new AuthenticatedWebSocket(socketUrl); socketRef.current.onmessage = e => { const data = JSON.parse(e.data) as ServerMessage; @@ -115,11 +115,10 @@ export function TerminalAssistContextProvider( const socketUrl = cfg.getAssistActionWebSocketUrl( getHostName(), clusterId, - getAccessToken(), 'ssh-explain' ); - const ws = new WebSocket(socketUrl); + const ws = new AuthenticatedWebSocket(socketUrl); ws.onopen = () => { ws.send(encodedOutput); diff --git a/web/packages/teleport/src/Console/consoleContext.tsx b/web/packages/teleport/src/Console/consoleContext.tsx index 0d5a937fc344b..a0bb4642b6975 100644 --- a/web/packages/teleport/src/Console/consoleContext.tsx +++ b/web/packages/teleport/src/Console/consoleContext.tsx @@ -22,7 +22,7 @@ import { W3CTraceContextPropagator } from '@opentelemetry/core'; import webSession from 'teleport/services/websession'; import history from 'teleport/services/history'; import cfg, { UrlResourcesParams, UrlSshParams } from 'teleport/config'; -import { getAccessToken, getHostName } from 'teleport/services/api'; +import { getHostName } from 'teleport/services/api'; import Tty from 'teleport/lib/term/tty'; import TtyAddressResolver from 'teleport/lib/term/ttyAddressResolver'; import serviceSession, { @@ -194,7 +194,6 @@ export default class ConsoleContext { const ttyUrl = cfg.api.ttyWsAddr .replace(':fqdn', getHostName()) - .replace(':token', getAccessToken()) .replace(':clusterId', clusterId) .replace(':traceparent', carrier['traceparent']); diff --git a/web/packages/teleport/src/DesktopSession/useTdpClientCanvas.tsx b/web/packages/teleport/src/DesktopSession/useTdpClientCanvas.tsx index 3e0afce12cef8..97a0b78277960 100644 --- a/web/packages/teleport/src/DesktopSession/useTdpClientCanvas.tsx +++ b/web/packages/teleport/src/DesktopSession/useTdpClientCanvas.tsx @@ -22,7 +22,7 @@ import { getPlatformType } from 'design/platform'; import { TdpClient, ButtonState, ScrollAxis } from 'teleport/lib/tdp'; import { ClipboardData, PngFrame } from 'teleport/lib/tdp/codec'; -import { getAccessToken, getHostName } from 'teleport/services/api'; +import { getHostName } from 'teleport/services/api'; import cfg from 'teleport/config'; import { Sha256Digest } from 'teleport/lib/util'; @@ -58,7 +58,6 @@ export default function useTdpClientCanvas(props: Props) { .replace(':fqdn', getHostName()) .replace(':clusterId', clusterId) .replace(':desktopName', desktopName) - .replace(':token', getAccessToken()) .replace(':username', username) .replace(':width', width.toString()) .replace(':height', height.toString()); diff --git a/web/packages/teleport/src/Player/DesktopPlayer.tsx b/web/packages/teleport/src/Player/DesktopPlayer.tsx index 280af21335219..3a112addb6f41 100644 --- a/web/packages/teleport/src/Player/DesktopPlayer.tsx +++ b/web/packages/teleport/src/Player/DesktopPlayer.tsx @@ -21,7 +21,7 @@ import useAttempt from 'shared/hooks/useAttemptNext'; import cfg from 'teleport/config'; import { PlayerClient, PlayerClientEvent } from 'teleport/lib/tdp'; -import { getAccessToken, getHostName } from 'teleport/services/api'; +import { getHostName } from 'teleport/services/api'; import TdpClientCanvas from 'teleport/components/TdpClientCanvas'; import { ProgressBarDesktop } from './ProgressBar'; @@ -110,7 +110,6 @@ const useDesktopPlayer = ({ .replace(':fqdn', getHostName()) .replace(':clusterId', clusterId) .replace(':sid', sid) - .replace(':token', getAccessToken()) ) ); }, [clusterId, sid]); diff --git a/web/packages/teleport/src/config.ts b/web/packages/teleport/src/config.ts index 2c5afd4714e6c..89b23b6ddf053 100644 --- a/web/packages/teleport/src/config.ts +++ b/web/packages/teleport/src/config.ts @@ -191,12 +191,12 @@ const cfg = { desktopServicesPath: `/v1/webapi/sites/:clusterId/desktopservices?searchAsRoles=:searchAsRoles?&limit=:limit?&startKey=:startKey?&query=:query?&search=:search?&sort=:sort?`, desktopPath: `/v1/webapi/sites/:clusterId/desktops/:desktopName`, desktopWsAddr: - 'wss://:fqdn/v1/webapi/sites/:clusterId/desktops/:desktopName/connect?access_token=:token&username=:username&width=:width&height=:height', + 'wss://:fqdn/v1/webapi/sites/:clusterId/desktops/:desktopName/connect/ws?username=:username&width=:width&height=:height', desktopPlaybackWsAddr: - 'wss://:fqdn/v1/webapi/sites/:clusterId/desktopplayback/:sid?access_token=:token', + 'wss://:fqdn/v1/webapi/sites/:clusterId/desktopplayback/:sid/ws', desktopIsActive: '/v1/webapi/sites/:clusterId/desktops/:desktopName/active', ttyWsAddr: - 'wss://:fqdn/v1/webapi/sites/:clusterId/connect?access_token=:token¶ms=:params&traceparent=:traceparent', + 'wss://:fqdn/v1/webapi/sites/:clusterId/connect/ws?params=:params&traceparent=:traceparent', activeAndPendingSessionsPath: '/v1/webapi/sites/:clusterId/sessions', sshPlaybackPrefix: '/v1/webapi/sites/:clusterId/sessions/:sid', // prefix because this is eventually concatenated with "/stream" or "/events" kubernetesPath: @@ -294,11 +294,11 @@ const cfg = { '/v1/webapi/assistant/conversations/:conversationId/title', assistGenerateSummaryPath: '/v1/webapi/assistant/title/summary', assistConversationWebSocketPath: - 'wss://:hostname/v1/webapi/sites/:clusterId/assistant', + 'wss://:hostname/v1/webapi/sites/:clusterId/assistant/ws', assistConversationHistoryPath: '/v1/webapi/assistant/conversations/:conversationId', assistExecuteCommandWebSocketPath: - 'wss://:hostname/v1/webapi/command/:clusterId/execute', + 'wss://:hostname/v1/webapi/command/:clusterId/execute/ws', userPreferencesPath: '/v1/webapi/user/preferences', userClusterPreferencesPath: '/v1/webapi/user/preferences/:clusterId', @@ -836,12 +836,10 @@ const cfg = { getAssistConversationWebSocketUrl( hostname: string, clusterId: string, - accessToken: string, conversationId: string ) { const searchParams = new URLSearchParams(); - searchParams.set('access_token', accessToken); searchParams.set('conversation_id', conversationId); return ( @@ -855,12 +853,10 @@ const cfg = { getAssistActionWebSocketUrl( hostname: string, clusterId: string, - accessToken: string, action: string ) { const searchParams = new URLSearchParams(); - searchParams.set('access_token', accessToken); searchParams.set('action', action); return ( @@ -880,12 +876,10 @@ const cfg = { getAssistExecuteCommandUrl( hostname: string, clusterId: string, - accessToken: string, params: Record ) { const searchParams = new URLSearchParams(); - searchParams.set('access_token', accessToken); searchParams.set('params', JSON.stringify(params)); return ( diff --git a/web/packages/teleport/src/lib/AuthenticatedWebSocket.ts b/web/packages/teleport/src/lib/AuthenticatedWebSocket.ts new file mode 100644 index 0000000000000..4c1d0c4e5e281 --- /dev/null +++ b/web/packages/teleport/src/lib/AuthenticatedWebSocket.ts @@ -0,0 +1,279 @@ +/** + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +import { getAccessToken } from 'teleport/services/api'; +import { WebsocketStatus } from 'teleport/types'; + +/** + * `AuthenticatedWebSocket` is a drop-in replacement for + * the `WebSocket` class that handles Teleport's websocket + * authentication process. + */ +export class AuthenticatedWebSocket extends WebSocket { + private authenticated: boolean = false; + private openListeners: ((this: WebSocket, ev: Event) => any)[] = []; + private onopenInternal: ((this: WebSocket, ev: Event) => any) | null = null; + private messageListeners: ((this: WebSocket, ev: MessageEvent) => any)[] = []; + private onmessageInternal: + | ((this: WebSocket, ev: MessageEvent) => any) + | null = null; + private oncloseListeners: ((this: WebSocket, ev: CloseEvent) => any)[] = []; + private oncloseInternal: ((this: WebSocket, ev: CloseEvent) => any) | null = + null; + private onerrorListeners: ((this: WebSocket, ev: Event) => any)[] = []; + private onerrorInternal: ((this: WebSocket, ev: Event) => any) | null = null; + private binaryTypeInternal: BinaryType = 'blob'; // Default binaryType + private onopenEvent: Event | null = null; + + constructor(url: string | URL, protocols?: string | string[]) { + super(url, protocols); + // Set the binaryType to 'arraybuffer' to handle the authentication process. + super.binaryType = 'arraybuffer'; + + // The open event listener should immediately send the authentication token + super.onopen = (onopenEvent: Event) => { + super.send(JSON.stringify({ token: getAccessToken() })); + // Don't call the user defined onopen messages yet, wait for the authentication response. + this.onopenEvent = onopenEvent; + }; + + // The message event listener should handle the authentication response, + // and if it succeeds, set the binaryType to the user-defined value and + // trigger any user-added open listeners. + super.onmessage = (ev: MessageEvent) => { + // If not yet authenticated, handle the authentication response. + if (!this.authenticated) { + // Parse the message as a WebsocketStatus. + let authResponse: WebsocketStatus; + try { + authResponse = JSON.parse(ev.data) as WebsocketStatus; + } catch (e) { + this.triggerError('Error parsing JSON from websocket message: ' + e); + return; + } + + // Validate the WebsocketStatus. + if ( + !authResponse.type || + !authResponse.status || + !(authResponse.type === 'create_session_response') || + !(authResponse.status === 'ok' || authResponse.status === 'error') + ) { + this.triggerError( + 'Invalid auth response: ' + JSON.stringify(authResponse) + ); + return; + } + + // Authentication succeeded. + if (authResponse.status === 'ok') { + this.authenticated = true; + // Set the binaryType to the value set by the user (or back to the default 'blob'). + super.binaryType = this.binaryTypeInternal; + // Now that authentication is complete, trigger any user-added open listeners + // with the original onopen event. + this.openListeners.forEach(listener => + listener.call(this, this.onopenEvent) + ); + this.onopenInternal?.call(this, this.onopenEvent); + return; + } else { + // Authentication failed, authResponse.status === 'error'. + this.triggerError( + 'auth error connecting to websocket: ' + authResponse.message + ); + return; + } + } else { + // If authenticated, pass messages to user-added listeners. + this.messageListeners.forEach(listener => { + listener.call(this, ev); + }); + this.onmessageInternal?.call(this, ev); + } + }; + + // Set the 'close' event for cleanup. + super.onclose = (ev: CloseEvent) => { + // Trigger any user-added close listeners + this.oncloseListeners.forEach(listener => listener.call(this, ev)); + this.oncloseInternal?.call(this, ev); + this.authenticated = false; + }; + + // Set the 'error' event for cleanup. + super.onerror = (ev: Event) => { + // Trigger any user-added error listeners + this.onerrorListeners.forEach(listener => listener.call(this, ev)); + this.onerrorInternal?.call(this, ev); + this.authenticated = false; + }; + } + + // Authenticated send + override send(data: string | ArrayBufferLike | Blob | ArrayBufferView): void { + if (!this.authenticated) { + // This should be unreachable, but just in case. + this.triggerError( + 'Cannot send data before authentication is complete. Data: ' + data + ); + return; + } + super.send(data); + } + + // Override addEventListener to intercept these listeners and store them in + // our appropriate arrays. They are called in the appropriate places in the + // `onopen`, `onmessage`, `onclose`, and `onerror` methods set in the constructor. + override addEventListener( + type: K, + listener: (this: WebSocket, ev: WebSocketEventMap[K]) => any + ): void { + if (type === 'open') { + this.openListeners.push( + listener as (this: WebSocket, ev: WebSocketEventMap['open']) => any + ); + } else if (type === 'message') { + this.messageListeners.push( + listener as (this: WebSocket, ev: WebSocketEventMap['message']) => any + ); + } else if (type === 'close') { + this.oncloseListeners.push( + listener as (this: WebSocket, ev: WebSocketEventMap['close']) => any + ); + } else if (type === 'error') { + this.onerrorListeners.push( + listener as (this: WebSocket, ev: WebSocketEventMap['error']) => any + ); + } else { + // This should be unreachable, but just in case. + super.addEventListener(type, listener); + } + } + + // Override the onopen, onmessage, onclose, and onerror properties to store the user-defined + // listeners in the appropriate internal properties. These are called in the appropriate places + // in the `onopen`, `onmessage`, `onclose`, and `onerror` methods set in the constructor. + + override set onopen(listener: (this: WebSocket, ev: Event) => any | null) { + this.onopenInternal = listener; + } + + override get onopen(): ((this: WebSocket, ev: Event) => any) | null { + return this.onopenInternal; + } + + override set onmessage( + listener: ((this: WebSocket, ev: MessageEvent) => any) | null + ) { + this.onmessageInternal = listener; + } + + override get onmessage(): + | ((this: WebSocket, ev: MessageEvent) => any) + | null { + return this.onmessageInternal; + } + + override set onclose( + listener: ((this: WebSocket, ev: CloseEvent) => any) | null + ) { + this.oncloseInternal = listener; + } + + override get onclose(): ((this: WebSocket, ev: CloseEvent) => any) | null { + return this.oncloseInternal; + } + + override set onerror(listener: ((this: WebSocket, ev: Event) => any) | null) { + this.onerrorInternal = listener; + } + + override get onerror(): ((this: WebSocket, ev: Event) => any) | null { + return this.onerrorInternal; + } + + // Override the binaryType property to store the user-defined binaryType in the appropriate internal property. + // This is because we need to set the binaryType to 'arraybuffer' for the authentication process (see constructor), + // and only then can we set it to the user-defined value. + override set binaryType(binaryType: BinaryType) { + if (this.authenticated) { + super.binaryType = binaryType; + return; + } + + this.binaryTypeInternal = binaryType; + } + + override get binaryType(): BinaryType { + return this.binaryTypeInternal; + } + + // Override removeEventListener to support listeners removal for 'open', 'message', and 'close' events + override removeEventListener( + type: K, + listener: (this: WebSocket, ev: WebSocketEventMap[K]) => any + ): void { + if (type === 'open') { + const index = this.openListeners.indexOf( + listener as (this: WebSocket, ev: WebSocketEventMap['open']) => any + ); + if (index !== -1) { + this.openListeners.splice(index, 1); + } + } else if (type === 'message') { + const index = this.messageListeners.indexOf( + listener as (this: WebSocket, ev: WebSocketEventMap['message']) => any + ); + if (index !== -1) { + this.messageListeners.splice(index, 1); + } + } else if (type === 'close') { + const index = this.oncloseListeners.indexOf( + listener as (this: WebSocket, ev: WebSocketEventMap['close']) => any + ); + if (index !== -1) { + this.oncloseListeners.splice(index, 1); + } + } else if (type === 'error') { + const index = this.onerrorListeners.indexOf( + listener as (this: WebSocket, ev: WebSocketEventMap['error']) => any + ); + if (index !== -1) { + this.onerrorListeners.splice(index, 1); + } + } else { + // This should be unreachable, but just in case. + super.removeEventListener( + type, + listener as EventListenerOrEventListenerObject + ); + } + } + + // Method to manually trigger an error event. + private triggerError(errorMessage: string): void { + const errorEvent = new ErrorEvent('error', { + error: new Error(errorMessage), + message: errorMessage, + }); + + // Dispatch the event to trigger all listeners attached for 'error' events. + this.dispatchEvent(errorEvent); + } +} diff --git a/web/packages/teleport/src/lib/tdp/client.ts b/web/packages/teleport/src/lib/tdp/client.ts index fd59ba4b5337b..3e54d04ea53d2 100644 --- a/web/packages/teleport/src/lib/tdp/client.ts +++ b/web/packages/teleport/src/lib/tdp/client.ts @@ -15,6 +15,7 @@ import Logger from 'shared/libs/logger'; import { WebsocketCloseCode, TermEvent } from 'teleport/lib/term/enums'; import { EventEmitterWebAuthnSender } from 'teleport/lib/EventEmitterWebAuthnSender'; +import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket'; import Codec, { MessageType, @@ -63,12 +64,12 @@ export enum TdpClientEvent { } // Client is the TDP client. It is responsible for connecting to a websocket serving the tdp server, -// sending client commands, and recieving and processing server messages. Its creator is responsible for +// sending client commands, and receiving and processing server messages. Its creator is responsible for // ensuring the websocket gets closed and all of its event listeners cleaned up when it is no longer in use. // For convenience, this can be done in one fell swoop by calling Client.shutdown(). export default class Client extends EventEmitterWebAuthnSender { protected codec: Codec; - protected socket: WebSocket | undefined; + protected socket: AuthenticatedWebSocket | undefined; private socketAddr: string; private sdManager: SharedDirectoryManager; @@ -83,7 +84,7 @@ export default class Client extends EventEmitterWebAuthnSender { // Connect to the websocket and register websocket event handlers. init() { - this.socket = new WebSocket(this.socketAddr); + this.socket = new AuthenticatedWebSocket(this.socketAddr); this.socket.binaryType = 'arraybuffer'; this.socket.onopen = () => { diff --git a/web/packages/teleport/src/lib/term/tty.ts b/web/packages/teleport/src/lib/term/tty.ts index 0e4f264f0116f..703282bdcc2de 100644 --- a/web/packages/teleport/src/lib/term/tty.ts +++ b/web/packages/teleport/src/lib/term/tty.ts @@ -18,6 +18,7 @@ import Logger from 'shared/libs/logger'; import { EventEmitterWebAuthnSender } from 'teleport/lib/EventEmitterWebAuthnSender'; import { WebauthnAssertionResponse } from 'teleport/services/auth'; +import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket'; import { EventType, TermEvent, WebsocketCloseCode } from './enums'; import { Protobuf, MessageTypeEnum } from './protobuf'; @@ -60,7 +61,7 @@ class Tty extends EventEmitterWebAuthnSender { connect(w: number, h: number) { const connStr = this._addressResolver.getConnStr(w, h); - this.socket = new WebSocket(connStr); + this.socket = new AuthenticatedWebSocket(connStr); this.socket.binaryType = 'arraybuffer'; this.socket.onopen = this._onOpenConnection; this.socket.onmessage = this._onMessage; diff --git a/web/packages/teleport/src/types.ts b/web/packages/teleport/src/types.ts index c5273ea800ef0..6d40f7b0fe191 100644 --- a/web/packages/teleport/src/types.ts +++ b/web/packages/teleport/src/types.ts @@ -189,3 +189,11 @@ export enum RecommendationStatus { Notify = 'NOTIFY', Done = 'DONE', } + +// WebsocketStatus is used to indicate the auth status from a +// websocket connection +export type WebsocketStatus = { + type: string; + status: string; + message?: string; +};