diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 0977e13cfc1d4..ff06a2b7e4a4f 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -40,6 +40,7 @@ import ( "time" "github.com/google/uuid" + "github.com/gorilla/websocket" "github.com/gravitational/oxy/ratelimit" "github.com/gravitational/roundtrip" "github.com/gravitational/trace" @@ -151,6 +152,11 @@ type Handler struct { // tracer is used to create spans. tracer oteltrace.Tracer + + // wsIODeadline is used to set a deadline for receiving a message from + // an authenticated websocket so unauthenticated sockets dont get left + // open. + wsIODeadline time.Duration } // HandlerOption is a functional argument - an option that can be passed @@ -365,6 +371,7 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) { ClusterFeatures: cfg.ClusterFeatures, healthCheckAppServer: cfg.HealthCheckAppServer, tracer: cfg.TracerProvider.Tracer(teleport.ComponentWeb), + wsIODeadline: wsIODeadline, } // Check for self-hosted vs Cloud. @@ -720,7 +727,10 @@ func (h *Handler) bindDefaultEndpoints() { h.DELETE("/webapi/sites/:site/locks/:uuid", h.WithClusterAuth(h.deleteClusterLock)) // active sessions handlers - h.GET("/webapi/sites/:site/connect", h.WithClusterAuth(h.siteNodeConnect)) // connect to an active session (via websocket) + // Deprecated: The connect/ws variant should be used instead. + // TODO(lxea): DELETE in v16 + h.GET("/webapi/sites/:site/connect", h.WithClusterAuthWebSocket(false, h.siteNodeConnect)) // connect to an active session (via websocket) + h.GET("/webapi/sites/:site/connect/ws", h.WithClusterAuthWebSocket(true, h.siteNodeConnect)) // connect to an active session (via websocket, with auth over websocket) h.GET("/webapi/sites/:site/sessions", h.WithClusterAuth(h.clusterActiveAndPendingSessionsGet)) // get list of active and pending sessions // Audit events handlers. @@ -828,9 +838,17 @@ func (h *Handler) bindDefaultEndpoints() { h.GET("/webapi/sites/:site/desktopservices", h.WithClusterAuth(h.clusterDesktopServicesGet)) h.GET("/webapi/sites/:site/desktops/:desktopName", h.WithClusterAuth(h.getDesktopHandle)) // GET /webapi/sites/:site/desktops/:desktopName/connect?access_token=&username=&width=&height= - h.GET("/webapi/sites/:site/desktops/:desktopName/connect", h.WithClusterAuth(h.desktopConnectHandle)) + // Deprecated: The connect/ws variant should be used instead. + // TODO(lxea): DELETE in v16 + h.GET("/webapi/sites/:site/desktops/:desktopName/connect", h.WithClusterAuthWebSocket(false, h.desktopConnectHandle)) + // GET /webapi/sites/:site/desktops/:desktopName/connect?username=&width=&height= + h.GET("/webapi/sites/:site/desktops/:desktopName/connect/ws", h.WithClusterAuthWebSocket(true, h.desktopConnectHandle)) // GET /webapi/sites/:site/desktopplayback/:sid?access_token= - h.GET("/webapi/sites/:site/desktopplayback/:sid", h.WithClusterAuth(h.desktopPlaybackHandle)) + // Deprecated: The desktopplayback/ws variant should be used instead. + // TODO(lxea): DELETE in v16 + h.GET("/webapi/sites/:site/desktopplayback/:sid", h.WithClusterAuthWebSocket(false, h.desktopPlaybackHandle)) + // // GET /webapi/sites/:site/desktopplayback/:sid/ws + h.GET("/webapi/sites/:site/desktopplayback/:sid/ws", h.WithClusterAuthWebSocket(true, h.desktopPlaybackHandle)) h.GET("/webapi/sites/:site/desktops/:desktopName/active", h.WithClusterAuth(h.desktopIsActive)) // GET a Connection Diagnostics by its name @@ -889,7 +907,11 @@ func (h *Handler) bindDefaultEndpoints() { h.GET("/webapi/sites/:site/user-groups", h.WithClusterAuth(h.getUserGroups)) // WebSocket endpoint for the chat conversation - h.GET("/webapi/sites/:site/assistant", h.WithClusterAuth(h.assistant)) + // Deprecated: The connect/ws variant should be used instead. + // TODO(lxea): DELETE in v16 + h.GET("/webapi/sites/:site/assistant", h.WithClusterAuthWebSocket(false, h.assistant)) + // WebSocket endpoint for the chat conversation, websocket auth + h.GET("/webapi/sites/:site/assistant/ws", h.WithClusterAuthWebSocket(true, h.assistant)) // Sets the title for the conversation. h.POST("/webapi/assistant/conversations/:conversation_id/title", h.WithAuth(h.setAssistantTitle)) @@ -908,7 +930,11 @@ func (h *Handler) bindDefaultEndpoints() { h.GET("/webapi/assistant/conversations/:conversation_id", h.WithAuth(h.getAssistantConversationByID)) // Allows executing an arbitrary command on multiple nodes. - h.GET("/webapi/command/:site/execute", h.WithClusterAuth(h.executeCommand)) + // Deprecated: The execute/ws variant should be used instead. + // TODO(lxea): DELETE in v16 + h.GET("/webapi/command/:site/execute", h.WithClusterAuthWebSocket(false, h.executeCommand)) + // Allows executing an arbitrary command on multiple nodes, websocket auth. + h.GET("/webapi/command/:site/execute/ws", h.WithClusterAuthWebSocket(true, h.executeCommand)) // Fetches the user's preferences h.GET("/webapi/user/preferences", h.WithAuth(h.getUserPreferences)) @@ -2941,6 +2967,7 @@ func (h *Handler) siteNodeConnect( p httprouter.Params, sessionCtx *SessionContext, site reversetunnelclient.RemoteSite, + ws *websocket.Conn, ) (interface{}, error) { q := r.URL.Query() params := q.Get("params") @@ -3033,6 +3060,7 @@ func (h *Handler) siteNodeConnect( PROXYSigner: h.cfg.PROXYSigner, Tracker: tracker, PresenceChecker: h.cfg.PresenceChecker, + WebsocketConn: ws, } term, err := NewTerminal(ctx, terminalConfig) @@ -3731,6 +3759,9 @@ type ContextHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Pa // ClusterHandler is a authenticated handler that is called for some existing remote cluster type ClusterHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) +// ClusterWebsocketHandler is a authenticated websocket handler that is called for some existing remote cluster +type ClusterWebsocketHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite, ws *websocket.Conn) (interface{}, error) + // WithClusterAuth wraps a ClusterHandler to ensure that a request is authenticated to this proxy // (the same as WithAuth), as well as to grab the remoteSite (which can represent this local cluster // or a remote trusted cluster) as specified by the ":site" url parameter. @@ -3745,12 +3776,108 @@ func (h *Handler) WithClusterAuth(fn ClusterHandler) httprouter.Handle { }) } +func (h *Handler) writeErrToWebSocket(ws *websocket.Conn, err error) { + if err == nil { + return + } + errEnvelope := Envelope{ + Type: defaults.WebsocketError, + Payload: trace.UserMessage(err), + } + env, err := errEnvelope.Marshal() + if err != nil { + h.log.WithError(err).Error("error marshaling proto") + return + } + if err := ws.WriteMessage(websocket.BinaryMessage, env); err != nil { + h.log.WithError(err).Error("error writing proto") + return + } +} + +// authnWsUpgrader is an upgrader that allows any origin to connect to the websocket. +// This makes our lives easier in our automated tests. While ordinarily this would be +// used to enforce the same-origin policy, we don't need to worry about that for authenticated +// websockets, which also require a valid bearer token sent over the websocket after upgrade. +// Therefore even if an attacker were to connect to the websocket and trick the browser into +// sending the session cookie, they would still fail to send the bearer token needed to authenticate. +var authnWsUpgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { return true }, +} + +// WithClusterAuthWebSocket wraps a ClusterWebsocketHandler to ensure that a request is authenticated +// to this proxy via websocket if websocketAuth is true, or via query parameter if false (the same as WithAuth), as +// well as to grab the remoteSite (which can represent this local cluster or a remote trusted cluster) +// as specified by the ":site" url parameter. +// +// TODO(lxea): remove the 'websocketAuth' bool once the deprecated websocket handlers are removed +func (h *Handler) WithClusterAuthWebSocket(websocketAuth bool, fn ClusterWebsocketHandler) httprouter.Handle { + return httplib.MakeHandler(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (any, error) { + var sctx *SessionContext + var ws *websocket.Conn + var site reversetunnelclient.RemoteSite + var err error + + if websocketAuth { + sctx, ws, site, err = h.authenticateWSRequestWithCluster(w, r, p) + } else { + sctx, ws, site, err = h.authenticateWSRequestWithClusterDeprecated(w, r, p) + } + + if err != nil { + return nil, trace.Wrap(err) + } + // WS protocol requires the server send a close message + // which should be done by downstream users + defer ws.Close() + if _, err := fn(w, r, p, sctx, site, ws); err != nil { + h.writeErrToWebSocket(ws, err) + } + return nil, nil + }) +} + +// authenticateWSRequestWithCluster ensures that a request is +// authenticated to this proxy via websocket, returning the +// *SessionContext (same as AuthenticateRequest), and also grabs the +// remoteSite (which can represent this local cluster or a remote +// trusted cluster) as specified by the ":site" url parameter. +func (h *Handler) authenticateWSRequestWithCluster(w http.ResponseWriter, r *http.Request, p httprouter.Params) (*SessionContext, *websocket.Conn, reversetunnelclient.RemoteSite, error) { + sctx, ws, err := h.AuthenticateRequestWS(w, r) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + + site, err := h.getSiteByParams(sctx, p) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + + return sctx, ws, site, nil +} + +// TODO(lxea): remove once the deprecated websocket handlers are removed +func (h *Handler) authenticateWSRequestWithClusterDeprecated(w http.ResponseWriter, r *http.Request, p httprouter.Params) (*SessionContext, *websocket.Conn, reversetunnelclient.RemoteSite, error) { + sctx, site, err := h.authenticateRequestWithCluster(w, r, p) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + ws, err := authnWsUpgrader.Upgrade(w, r, nil) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + return sctx, ws, site, nil +} + // authenticateRequestWithCluster ensures that a request is authenticated // to this proxy, returning the *SessionContext (same as AuthenticateRequest), // and also grabs the remoteSite (which can represent this local cluster or a // remote trusted cluster) as specified by the ":site" url parameter. func (h *Handler) authenticateRequestWithCluster(w http.ResponseWriter, r *http.Request, p httprouter.Params) (*SessionContext, reversetunnelclient.RemoteSite, error) { sctx, err := h.AuthenticateRequest(w, r, true) + if err != nil { return nil, nil, trace.Wrap(err) } @@ -4068,9 +4195,7 @@ func rateLimitRequest(r *http.Request, limiter *limiter.RateLimiter) error { return trace.Wrap(err) } -// AuthenticateRequest authenticates request using combination of a session cookie -// and bearer token -func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, checkBearerToken bool) (*SessionContext, error) { +func (h *Handler) validateCookie(w http.ResponseWriter, r *http.Request) (*SessionContext, error) { const missingCookieMsg = "missing session cookie" cookie, err := r.Cookie(websession.CookieName) if err != nil || (cookie != nil && cookie.Value == "") { @@ -4085,6 +4210,17 @@ func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, ch clearSessionCookies((w)) return nil, trace.AccessDenied("need auth") } + + return sctx, nil +} + +// AuthenticateRequest authenticates request using combination of a session cookie +// and bearer token +func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, checkBearerToken bool) (*SessionContext, error) { + sctx, err := h.validateCookie(w, r) + if err != nil { + return nil, trace.Wrap(err) + } if checkBearerToken { creds, err := roundtrip.ParseAuthHeaders(r) if err != nil { @@ -4137,6 +4273,72 @@ func contextWithMFAResponseFromRequestHeader(ctx context.Context, requestHeader return ctx, nil } +type wsBearerToken struct { + Token string `json:"token"` +} + +type wsStatus struct { + Type string `json:"type"` + Status string `json:"status"` + Message string `json:"message,omitempty"` +} + +// wsIODeadline is used to set a deadline for receiving a message from +// an authenticated websocket so unauthenticated sockets dont get left +// open. +const wsIODeadline = time.Second * 4 + +// AuthenticateRequest authenticates request using combination of a session cookie +// and bearer token retrieved from a websocket +func (h *Handler) AuthenticateRequestWS(w http.ResponseWriter, r *http.Request) (*SessionContext, *websocket.Conn, error) { + sctx, err := h.validateCookie(w, r) + if err != nil { + return nil, nil, trace.Wrap(err) + } + ws, err := authnWsUpgrader.Upgrade(w, r, nil) + if err != nil { + return nil, nil, trace.ConnectionProblem(err, "Error upgrading to websocket: %v", err) + } + if err := ws.SetReadDeadline(time.Now().Add(wsIODeadline)); err != nil { + return nil, nil, trace.ConnectionProblem(err, "Error setting websocket read deadline: %v", err) + } + + var t wsBearerToken + if err := ws.ReadJSON(&t); err != nil { + return nil, nil, trace.Wrap(err) + } + if err := sctx.validateBearerToken(r.Context(), t.Token); err != nil { + writeErr := ws.WriteJSON(wsStatus{ + Type: "create_session_response", + Status: "error", + Message: "invalid token", + }) + if writeErr != nil { + log.Errorf("Error while writing invalid token error to websocket: %s", writeErr) + } + + return nil, nil, trace.Wrap(err) + } + + if err := ws.WriteJSON(wsStatus{ + Type: "create_session_response", + Status: "ok", + }); err != nil { + return nil, nil, trace.Wrap(err) + } + + // unset the deadline as downstream consumers should handle this themselves. + if err := ws.SetReadDeadline(time.Time{}); err != nil { + return nil, nil, trace.ConnectionProblem(err, "Error setting websocket read deadline: %v", err) + } + + if err := parseMFAResponseFromRequest(r); err != nil { + return nil, nil, trace.Wrap(err) + } + + return sctx, ws, nil +} + // ProxyWithRoles returns a reverse tunnel proxy verifying the permissions // of the given user. func (h *Handler) ProxyWithRoles(ctx *SessionContext) (reversetunnelclient.Tunnel, error) { diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 680ec4b727493..d47abc5c0a6b3 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -1355,14 +1355,25 @@ func TestSiteNodeConnectInvalidSessionID(t *testing.T) { ctx, cancel := context.WithCancel(s.ctx) t.Cleanup(cancel) - term, err := connectToHost(ctx, connectConfig{ + result := make(chan error) + + _, err := connectToHost(ctx, connectConfig{ pack: s.authPack(t, "foo"), host: s.node.ID(), proxy: s.webServer.Listener.Addr().String(), sessionID: "/../../../foo", + handlers: map[string]WSHandlerFunc{ + defaults.WebsocketError: func(ctx context.Context, e Envelope) { + if e.Payload == "/../../../foo is not a valid UUID" { + result <- errors.New(e.Payload) + } + close(result) + }, + }, }) - require.Error(t, err) - require.Nil(t, term) + require.NoError(t, err) + res := <-result + require.Error(t, res) } func TestResolveServerHostPort(t *testing.T) { @@ -1897,6 +1908,7 @@ func TestTerminal(t *testing.T) { host: s.node.ID(), proxy: s.webServer.Listener.Addr().String(), }) + require.NoError(t, err) t.Cleanup(func() { require.True(t, utils.IsOKNetworkError(term.Close())) }) @@ -8118,18 +8130,38 @@ func (r *testProxy) newClient(t *testing.T, opts ...roundtrip.ClientParam) *Test return &TestWebClient{clt, t} } +func makeAuthReqOverWS(ws *websocket.Conn, token string) error { + authReq, err := json.Marshal(struct { + Token string `json:"token"` + }{Token: token}) + if err != nil { + return trace.Wrap(err) + } + + if err := ws.WriteMessage(websocket.TextMessage, authReq); err != nil { + return trace.Wrap(err) + } + _, authRes, err := ws.ReadMessage() + if err != nil { + return trace.Wrap(err) + } + if !strings.Contains(string(authRes), `"status":"ok"`) { + return trace.AccessDenied("unexpected response") + } + return nil +} + func (r *testProxy) makeDesktopSession(t *testing.T, pack *authPack, sessionID session.ID, addr net.Addr) *websocket.Conn { u := url.URL{ Host: r.webURL.Host, Scheme: client.WSS, - Path: fmt.Sprintf("/webapi/sites/%s/desktops/%s/connect", currentSiteShortcut, "desktop1"), + Path: fmt.Sprintf("/webapi/sites/%s/desktops/%s/connect/ws", currentSiteShortcut, "desktop1"), } q := u.Query() q.Set("username", "marek") q.Set("width", "100") q.Set("height", "100") - q.Set(roundtrip.AccessTokenQueryParam, pack.session.Token) u.RawQuery = q.Encode() dialer := websocket.Dialer{} @@ -8144,6 +8176,10 @@ func (r *testProxy) makeDesktopSession(t *testing.T, pack *authPack, sessionID s ws, resp, err := dialer.Dial(u.String(), header) require.NoError(t, err) + + err = makeAuthReqOverWS(ws, pack.session.Token) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws.Close()) require.NoError(t, resp.Body.Close()) @@ -9135,6 +9171,111 @@ func (s *fakeKubeService) ListKubernetesResources(ctx context.Context, req *kube }, nil } +func TestWebSocketAuthenticateRequest(t *testing.T) { + t.Parallel() + ctx := context.Background() + env := newWebPack(t, 1) + proxy := env.proxies[0] + proxy.handler.handler.wsIODeadline = time.Second + pack := proxy.authPack(t, "test-user@example.com", nil) + for _, tc := range []struct { + name string + serverExpectError string + expectResponse wsStatus + token string + writeTimeout func() + readTimeout func() + }{ + { + name: "valid token", + expectResponse: wsStatus{ + Type: "create_session_response", + Status: "ok", + }, + token: pack.session.Token, + }, + { + name: "invalid token", + serverExpectError: "not found", + expectResponse: wsStatus{ + Type: "create_session_response", + Status: "error", + Message: "invalid token", + }, + token: "honk", + }, + { + name: "server read timeout", + serverExpectError: "i/o timeout", + token: pack.session.Token, + readTimeout: func() { + <-time.After(wsIODeadline * 3) + }, + }, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sctx, ws, err := proxy.handler.handler.AuthenticateRequestWS(w, r) + if err != nil { + if tc.serverExpectError == "" { + t.Errorf("unexpected error: %v", err) + } + if !strings.Contains(err.Error(), tc.serverExpectError) { + t.Errorf("unexpected error: %v", err) + return + } + return + } + t.Cleanup(func() { ws.Close() }) + if err == nil && tc.serverExpectError != "" { + t.Errorf("expected error, got nil") + return + } + + clt, err := sctx.GetClient() + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + _, err = clt.GetDomainName(ctx) + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + })) + + header := http.Header{} + for _, cookie := range pack.cookies { + header.Add("Cookie", cookie.String()) + } + + u := strings.Replace(server.URL, "http:", "ws:", 1) + conn, resp, err := websocket.DefaultDialer.Dial(u, header) + require.NoError(t, err) + t.Cleanup(func() { conn.Close() }) + t.Cleanup(func() { resp.Body.Close() }) + + if tc.readTimeout != nil { + tc.readTimeout() + } + err = conn.WriteJSON(wsBearerToken{ + Token: tc.token, + }) + require.NoError(t, err) + if tc.readTimeout != nil { + return // Reading will fail as the server will have closed the connection + } + + var status wsStatus + err = conn.ReadJSON(&status) + require.NoError(t, err) + require.Equal(t, tc.expectResponse, status) + }) + } +} + // TestSimultaneousAuthenticateRequest ensures that multiple authenticated // requests do not race to create a SessionContext. This would happen when // Proxies were deployed behind a round-robin load balancer. Only the Proxy diff --git a/lib/web/assistant.go b/lib/web/assistant.go index fca5fde9c4e19..da1caae195a17 100644 --- a/lib/web/assistant.go +++ b/lib/web/assistant.go @@ -332,9 +332,9 @@ func (h *Handler) generateAssistantTitle(_ http.ResponseWriter, r *http.Request, // This handler covers the main chat conversation as well as the // SSH completition (SSH command generation and output explanation). func (h *Handler) assistant(w http.ResponseWriter, r *http.Request, _ httprouter.Params, - sctx *SessionContext, site reversetunnelclient.RemoteSite, + sctx *SessionContext, site reversetunnelclient.RemoteSite, ws *websocket.Conn, ) (any, error) { - if err := runAssistant(h, w, r, sctx, site); err != nil { + if err := runAssistant(h, w, r, sctx, site, ws); err != nil { h.log.Warn(trace.DebugReport(err)) return nil, trace.Wrap(err) } @@ -420,7 +420,7 @@ func checkAssistEnabled(a auth.ClientI, ctx context.Context) error { // runAssistant upgrades the HTTP connection to a websocket and starts a chat loop. func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, - sctx *SessionContext, site reversetunnelclient.RemoteSite, + sctx *SessionContext, site reversetunnelclient.RemoteSite, ws *websocket.Conn, ) (err error) { q := r.URL.Query() conversationID := q.Get("conversation_id") @@ -455,20 +455,6 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, return trace.Wrap(err) } - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { return true }, - } - - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - errMsg := "Error upgrading to websocket" - h.log.WithError(err).Error(errMsg) - http.Error(w, errMsg, http.StatusInternalServerError) - return nil - } - // Note: This time should be longer than OpenAI response time. keepAliveInterval := netConfig.GetKeepAliveInterval() err = ws.SetReadDeadline(deadlineForInterval(keepAliveInterval)) diff --git a/lib/web/command.go b/lib/web/command.go index 9017b63ed1b34..bae2a66247b04 100644 --- a/lib/web/command.go +++ b/lib/web/command.go @@ -128,6 +128,7 @@ func (h *Handler) executeCommand( _ httprouter.Params, sessionCtx *SessionContext, site reversetunnelclient.RemoteSite, + rawWS *websocket.Conn, ) (any, error) { q := r.URL.Query() params := q.Get("params") @@ -171,20 +172,6 @@ func (h *Handler) executeCommand( clusterName := site.GetName() - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { return true }, - } - - rawWS, err := upgrader.Upgrade(w, r, nil) - if err != nil { - errMsg := "Error upgrading to websocket" - h.log.WithError(err).Error(errMsg) - http.Error(w, errMsg, http.StatusInternalServerError) - return nil, nil - } - defer func() { rawWS.WriteMessage(websocket.CloseMessage, nil) rawWS.Close() diff --git a/lib/web/desktop.go b/lib/web/desktop.go index a89bcf01b17c3..abd663d897c1c 100644 --- a/lib/web/desktop.go +++ b/lib/web/desktop.go @@ -66,6 +66,7 @@ func (h *Handler) desktopConnectHandle( p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite, + ws *websocket.Conn, ) (interface{}, error) { desktopName := p.ByName("desktopName") if desktopName == "" { @@ -75,7 +76,7 @@ func (h *Handler) desktopConnectHandle( log := sctx.cfg.Log.WithField("desktop-name", desktopName).WithField("cluster-name", site.GetName()) log.Debug("New desktop access websocket connection") - if err := h.createDesktopConnection(w, r, desktopName, site.GetName(), log, sctx, site); err != nil { + if err := h.createDesktopConnection(w, r, desktopName, site.GetName(), log, sctx, site, ws); err != nil { // createDesktopConnection makes a best effort attempt to send an error to the user // (via websocket) before terminating the connection. We log the error here, but // return nil because our HTTP middleware will try to write the returned error in JSON @@ -94,15 +95,8 @@ func (h *Handler) createDesktopConnection( log *logrus.Entry, sctx *SessionContext, site reversetunnelclient.RemoteSite, + ws *websocket.Conn, ) error { - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - } - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return trace.Wrap(err) - } defer ws.Close() sendTDPError := func(err error) error { diff --git a/lib/web/desktop_playback.go b/lib/web/desktop_playback.go index df04c330eed7f..9c50cdcc153c7 100644 --- a/lib/web/desktop_playback.go +++ b/lib/web/desktop_playback.go @@ -38,6 +38,7 @@ func (h *Handler) desktopPlaybackHandle( p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite, + ws *websocket.Conn, ) (interface{}, error) { sID := p.ByName("sid") if sID == "" { @@ -49,16 +50,6 @@ func (h *Handler) desktopPlaybackHandle( return nil, trace.Wrap(err) } - upgrader := websocket.Upgrader{ - ReadBufferSize: 4096, - WriteBufferSize: 4096, - } - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return nil, trace.Wrap(err) - } - defer ws.Close() - player, err := player.New(&player.Config{ Clock: h.clock, Log: h.log, diff --git a/lib/web/terminal.go b/lib/web/terminal.go index f8c6ebb17882f..3c40c7a18c15c 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -142,6 +142,7 @@ func NewTerminal(ctx context.Context, cfg TerminalHandlerConfig) (*TerminalHandl participantMode: cfg.ParticipantMode, tracker: cfg.Tracker, presenceChecker: cfg.PresenceChecker, + websocketConn: cfg.WebsocketConn, }, nil } @@ -191,6 +192,8 @@ type TerminalHandlerConfig struct { PresenceChecker PresenceChecker // Clock allows interaction with time. Clock clockwork.Clock + // WebsocketConn is the active websocket connection + WebsocketConn *websocket.Conn } func (t *TerminalHandlerConfig) CheckAndSetDefaults() error { @@ -317,6 +320,9 @@ type TerminalHandler struct { // clock used to interact with time. clock clockwork.Clock + + // websocketConn is the active websocket connection + websocketConn *websocket.Conn } // ServeHTTP builds a connection to the remote node and then pumps back two types of @@ -328,21 +334,9 @@ func (t *TerminalHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { t.ctx.AddClosers(t) defer t.ctx.RemoveCloser(t) - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { return true }, - } - - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - errMsg := "Error upgrading to websocket" - t.log.WithError(err).Error(errMsg) - http.Error(w, errMsg, http.StatusInternalServerError) - return - } + ws := t.websocketConn - err = ws.SetReadDeadline(deadlineForInterval(t.keepAliveInterval)) + err := ws.SetReadDeadline(deadlineForInterval(t.keepAliveInterval)) if err != nil { t.log.WithError(err).Error("Error setting websocket readline") return diff --git a/lib/web/terminal_test.go b/lib/web/terminal_test.go index ce8996308ed53..2e7edf5150f6f 100644 --- a/lib/web/terminal_test.go +++ b/lib/web/terminal_test.go @@ -33,7 +33,6 @@ import ( "github.com/gogo/protobuf/proto" "github.com/gorilla/websocket" - "github.com/gravitational/roundtrip" "github.com/gravitational/trace" "github.com/stretchr/testify/require" @@ -127,12 +126,11 @@ func connectToHost(ctx context.Context, cfg connectConfig) (*terminal, error) { u := url.URL{ Host: cfg.proxy, Scheme: client.WSS, - Path: "/v1/webapi/sites/-current-/connect", + Path: "/v1/webapi/sites/-current-/connect/ws", } q := u.Query() q.Set("params", string(data)) - q.Set(roundtrip.AccessTokenQueryParam, cfg.pack.session.Token) u.RawQuery = q.Encode() header := http.Header{} @@ -162,6 +160,10 @@ func connectToHost(ctx context.Context, cfg connectConfig) (*terminal, error) { return nil, trace.Wrap(err) } + if err := makeAuthReqOverWS(ws, cfg.pack.session.Token); err != nil { + return nil, trace.Wrap(err) + } + if cfg.pingHandler != nil { ws.SetPingHandler(func(message string) error { return cfg.pingHandler(ws, message) diff --git a/web/packages/teleport/src/Assist/context/AssistContext.tsx b/web/packages/teleport/src/Assist/context/AssistContext.tsx index 15dbf01f5068d..fb9319285701d 100644 --- a/web/packages/teleport/src/Assist/context/AssistContext.tsx +++ b/web/packages/teleport/src/Assist/context/AssistContext.tsx @@ -31,7 +31,7 @@ import { AssistStateActionType, reducer } from 'teleport/Assist/context/state'; import { convertServerMessages } from 'teleport/Assist/context/utils'; import useStickyClusterId from 'teleport/useStickyClusterId'; import cfg from 'teleport/config'; -import { getAccessToken, getHostName } from 'teleport/services/api'; +import { getHostName } from 'teleport/services/api'; import { AccessRequestClientMessage, @@ -48,6 +48,7 @@ import { makeMfaAuthenticateChallenge, WebauthnAssertionResponse, } from 'teleport/services/auth'; +import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket'; import * as service from '../service'; import { @@ -84,9 +85,9 @@ let lastCommandExecutionResultId = 0; const TEN_MINUTES = 10 * 60 * 1000; export function AssistContextProvider(props: PropsWithChildren) { - const activeWebSocket = useRef(null); + const activeWebSocket = useRef(null); // TODO(ryan): this should be removed once https://github.com/gravitational/teleport.e/pull/1609 is implemented - const executeCommandWebSocket = useRef(null); + const executeCommandWebSocket = useRef(null); const refreshWebSocketTimeout = useRef(null); const { clusterId } = useStickyClusterId(); @@ -124,11 +125,10 @@ export function AssistContextProvider(props: PropsWithChildren) { } function setupWebSocket(conversationId: string, initialMessage?: string) { - activeWebSocket.current = new WebSocket( + activeWebSocket.current = new AuthenticatedWebSocket( cfg.getAssistConversationWebSocketUrl( getHostName(), clusterId, - getAccessToken(), conversationId ) ); @@ -350,7 +350,7 @@ export function AssistContextProvider(props: PropsWithChildren) { if ( !activeWebSocket.current || - activeWebSocket.current.readyState === WebSocket.CLOSED + activeWebSocket.current.readyState === AuthenticatedWebSocket.CLOSED ) { setupWebSocket(state.conversations.selectedId, data); } else { @@ -380,7 +380,8 @@ export function AssistContextProvider(props: PropsWithChildren) { function sendMfaChallenge(data: WebauthnAssertionResponse) { if ( !executeCommandWebSocket.current || - executeCommandWebSocket.current.readyState !== WebSocket.OPEN || + executeCommandWebSocket.current.readyState !== + AuthenticatedWebSocket.OPEN || !data ) { console.warn( @@ -448,12 +449,11 @@ export function AssistContextProvider(props: PropsWithChildren) { const url = cfg.getAssistExecuteCommandUrl( getHostName(), clusterId, - getAccessToken(), execParams ); const proto = new Protobuf(); - executeCommandWebSocket.current = new WebSocket(url); + executeCommandWebSocket.current = new AuthenticatedWebSocket(url); executeCommandWebSocket.current.binaryType = 'arraybuffer'; executeCommandWebSocket.current.onmessage = event => { diff --git a/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx b/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx index 9bcd78ce43027..ebdec9bfcf4dc 100644 --- a/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx +++ b/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx @@ -26,7 +26,7 @@ import React, { } from 'react'; import { Author, ServerMessage } from 'teleport/Assist/types'; -import { getAccessToken, getHostName } from 'teleport/services/api'; +import { getHostName } from 'teleport/services/api'; import useStickyClusterId from 'teleport/useStickyClusterId'; import cfg from 'teleport/config'; import { @@ -36,6 +36,7 @@ import { SuggestedCommandMessage, UserMessage, } from 'teleport/Console/DocumentSsh/TerminalAssist/types'; +import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket'; interface TerminalAssistContextValue { close: () => void; @@ -57,11 +58,10 @@ export function TerminalAssistContextProvider( const [visible, setVisible] = useState(false); - const socketRef = useRef(null); + const socketRef = useRef(null); const socketUrl = cfg.getAssistActionWebSocketUrl( getHostName(), clusterId, - getAccessToken(), 'ssh-cmdgen' ); @@ -72,7 +72,7 @@ export function TerminalAssistContextProvider( const [messages, setMessages] = useState([]); useEffect(() => { - socketRef.current = new WebSocket(socketUrl); + socketRef.current = new AuthenticatedWebSocket(socketUrl); socketRef.current.onmessage = e => { const data = JSON.parse(e.data) as ServerMessage; @@ -117,11 +117,10 @@ export function TerminalAssistContextProvider( const socketUrl = cfg.getAssistActionWebSocketUrl( getHostName(), clusterId, - getAccessToken(), 'ssh-explain' ); - const ws = new WebSocket(socketUrl); + const ws = new AuthenticatedWebSocket(socketUrl); ws.onopen = () => { ws.send(encodedOutput); diff --git a/web/packages/teleport/src/Console/consoleContext.tsx b/web/packages/teleport/src/Console/consoleContext.tsx index e696f18602e0e..76ba3013d7cac 100644 --- a/web/packages/teleport/src/Console/consoleContext.tsx +++ b/web/packages/teleport/src/Console/consoleContext.tsx @@ -24,7 +24,7 @@ import { W3CTraceContextPropagator } from '@opentelemetry/core'; import webSession from 'teleport/services/websession'; import history from 'teleport/services/history'; import cfg, { UrlResourcesParams, UrlSshParams } from 'teleport/config'; -import { getAccessToken, getHostName } from 'teleport/services/api'; +import { getHostName } from 'teleport/services/api'; import Tty from 'teleport/lib/term/tty'; import TtyAddressResolver from 'teleport/lib/term/ttyAddressResolver'; import serviceSession, { @@ -197,7 +197,6 @@ export default class ConsoleContext { const ttyUrl = cfg.api.ttyWsAddr .replace(':fqdn', getHostName()) - .replace(':token', getAccessToken()) .replace(':clusterId', clusterId) .replace(':traceparent', carrier['traceparent']); diff --git a/web/packages/teleport/src/DesktopSession/useTdpClientCanvas.tsx b/web/packages/teleport/src/DesktopSession/useTdpClientCanvas.tsx index 74153145af2b9..e5f3c986a3f21 100644 --- a/web/packages/teleport/src/DesktopSession/useTdpClientCanvas.tsx +++ b/web/packages/teleport/src/DesktopSession/useTdpClientCanvas.tsx @@ -30,7 +30,7 @@ import { PngFrame, SyncKeys, } from 'teleport/lib/tdp/codec'; -import { getAccessToken, getHostName } from 'teleport/services/api'; +import { getHostName } from 'teleport/services/api'; import cfg from 'teleport/config'; import { Sha256Digest } from 'teleport/lib/util'; @@ -85,7 +85,6 @@ export default function useTdpClientCanvas(props: Props) { .replace(':fqdn', getHostName()) .replace(':clusterId', clusterId) .replace(':desktopName', desktopName) - .replace(':token', getAccessToken()) .replace(':username', username); setTdpClient(new TdpClient(addr)); diff --git a/web/packages/teleport/src/Player/DesktopPlayer.tsx b/web/packages/teleport/src/Player/DesktopPlayer.tsx index c819d7e6581c8..8262835db060c 100644 --- a/web/packages/teleport/src/Player/DesktopPlayer.tsx +++ b/web/packages/teleport/src/Player/DesktopPlayer.tsx @@ -23,7 +23,7 @@ import { Indicator, Box, Alert, Flex } from 'design'; import cfg from 'teleport/config'; import { StatusEnum, formatDisplayTime } from 'teleport/lib/player'; import { PlayerClient, TdpClient } from 'teleport/lib/tdp'; -import { getAccessToken, getHostName } from 'teleport/services/api'; +import { getHostName } from 'teleport/services/api'; import TdpClientCanvas from 'teleport/components/TdpClientCanvas'; import ProgressBar from './ProgressBar'; @@ -157,8 +157,7 @@ const useDesktopPlayer = ({ clusterId, sid }) => { const url = cfg.api.desktopPlaybackWsAddr .replace(':fqdn', getHostName()) .replace(':clusterId', clusterId) - .replace(':sid', sid) - .replace(':token', getAccessToken()); + .replace(':sid', sid); return new PlayerClient({ url, setTime, setPlayerStatus, setStatusText }); }, [clusterId, sid, setTime, setPlayerStatus]); diff --git a/web/packages/teleport/src/config.ts b/web/packages/teleport/src/config.ts index ac3521c95bb86..c8d0748193fd5 100644 --- a/web/packages/teleport/src/config.ts +++ b/web/packages/teleport/src/config.ts @@ -196,12 +196,12 @@ const cfg = { desktopServicesPath: `/v1/webapi/sites/:clusterId/desktopservices?searchAsRoles=:searchAsRoles?&limit=:limit?&startKey=:startKey?&query=:query?&search=:search?&sort=:sort?`, desktopPath: `/v1/webapi/sites/:clusterId/desktops/:desktopName`, desktopWsAddr: - 'wss://:fqdn/v1/webapi/sites/:clusterId/desktops/:desktopName/connect?access_token=:token&username=:username', + 'wss://:fqdn/v1/webapi/sites/:clusterId/desktops/:desktopName/connect/ws?username=:username', desktopPlaybackWsAddr: - 'wss://:fqdn/v1/webapi/sites/:clusterId/desktopplayback/:sid?access_token=:token', + 'wss://:fqdn/v1/webapi/sites/:clusterId/desktopplayback/:sid/ws', desktopIsActive: '/v1/webapi/sites/:clusterId/desktops/:desktopName/active', ttyWsAddr: - 'wss://:fqdn/v1/webapi/sites/:clusterId/connect?access_token=:token¶ms=:params&traceparent=:traceparent', + 'wss://:fqdn/v1/webapi/sites/:clusterId/connect/ws?params=:params&traceparent=:traceparent', ttyPlaybackWsAddr: 'wss://:fqdn/v1/webapi/sites/:clusterId/ttyplayback/:sid?access_token=:token', // TODO(zmb3): get token out of URL activeAndPendingSessionsPath: '/v1/webapi/sites/:clusterId/sessions', @@ -310,11 +310,11 @@ const cfg = { '/v1/webapi/assistant/conversations/:conversationId/title', assistGenerateSummaryPath: '/v1/webapi/assistant/title/summary', assistConversationWebSocketPath: - 'wss://:hostname/v1/webapi/sites/:clusterId/assistant', + 'wss://:hostname/v1/webapi/sites/:clusterId/assistant/ws', assistConversationHistoryPath: '/v1/webapi/assistant/conversations/:conversationId', assistExecuteCommandWebSocketPath: - 'wss://:hostname/v1/webapi/command/:clusterId/execute', + 'wss://:hostname/v1/webapi/command/:clusterId/execute/ws', userPreferencesPath: '/v1/webapi/user/preferences', userClusterPreferencesPath: '/v1/webapi/user/preferences/:clusterId', @@ -857,12 +857,10 @@ const cfg = { getAssistConversationWebSocketUrl( hostname: string, clusterId: string, - accessToken: string, conversationId: string ) { const searchParams = new URLSearchParams(); - searchParams.set('access_token', accessToken); searchParams.set('conversation_id', conversationId); return ( @@ -876,12 +874,10 @@ const cfg = { getAssistActionWebSocketUrl( hostname: string, clusterId: string, - accessToken: string, action: string ) { const searchParams = new URLSearchParams(); - searchParams.set('access_token', accessToken); searchParams.set('action', action); return ( @@ -901,12 +897,10 @@ const cfg = { getAssistExecuteCommandUrl( hostname: string, clusterId: string, - accessToken: string, params: Record ) { const searchParams = new URLSearchParams(); - searchParams.set('access_token', accessToken); searchParams.set('params', JSON.stringify(params)); return ( diff --git a/web/packages/teleport/src/lib/AuthenticatedWebSocket.ts b/web/packages/teleport/src/lib/AuthenticatedWebSocket.ts new file mode 100644 index 0000000000000..4c1d0c4e5e281 --- /dev/null +++ b/web/packages/teleport/src/lib/AuthenticatedWebSocket.ts @@ -0,0 +1,279 @@ +/** + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +import { getAccessToken } from 'teleport/services/api'; +import { WebsocketStatus } from 'teleport/types'; + +/** + * `AuthenticatedWebSocket` is a drop-in replacement for + * the `WebSocket` class that handles Teleport's websocket + * authentication process. + */ +export class AuthenticatedWebSocket extends WebSocket { + private authenticated: boolean = false; + private openListeners: ((this: WebSocket, ev: Event) => any)[] = []; + private onopenInternal: ((this: WebSocket, ev: Event) => any) | null = null; + private messageListeners: ((this: WebSocket, ev: MessageEvent) => any)[] = []; + private onmessageInternal: + | ((this: WebSocket, ev: MessageEvent) => any) + | null = null; + private oncloseListeners: ((this: WebSocket, ev: CloseEvent) => any)[] = []; + private oncloseInternal: ((this: WebSocket, ev: CloseEvent) => any) | null = + null; + private onerrorListeners: ((this: WebSocket, ev: Event) => any)[] = []; + private onerrorInternal: ((this: WebSocket, ev: Event) => any) | null = null; + private binaryTypeInternal: BinaryType = 'blob'; // Default binaryType + private onopenEvent: Event | null = null; + + constructor(url: string | URL, protocols?: string | string[]) { + super(url, protocols); + // Set the binaryType to 'arraybuffer' to handle the authentication process. + super.binaryType = 'arraybuffer'; + + // The open event listener should immediately send the authentication token + super.onopen = (onopenEvent: Event) => { + super.send(JSON.stringify({ token: getAccessToken() })); + // Don't call the user defined onopen messages yet, wait for the authentication response. + this.onopenEvent = onopenEvent; + }; + + // The message event listener should handle the authentication response, + // and if it succeeds, set the binaryType to the user-defined value and + // trigger any user-added open listeners. + super.onmessage = (ev: MessageEvent) => { + // If not yet authenticated, handle the authentication response. + if (!this.authenticated) { + // Parse the message as a WebsocketStatus. + let authResponse: WebsocketStatus; + try { + authResponse = JSON.parse(ev.data) as WebsocketStatus; + } catch (e) { + this.triggerError('Error parsing JSON from websocket message: ' + e); + return; + } + + // Validate the WebsocketStatus. + if ( + !authResponse.type || + !authResponse.status || + !(authResponse.type === 'create_session_response') || + !(authResponse.status === 'ok' || authResponse.status === 'error') + ) { + this.triggerError( + 'Invalid auth response: ' + JSON.stringify(authResponse) + ); + return; + } + + // Authentication succeeded. + if (authResponse.status === 'ok') { + this.authenticated = true; + // Set the binaryType to the value set by the user (or back to the default 'blob'). + super.binaryType = this.binaryTypeInternal; + // Now that authentication is complete, trigger any user-added open listeners + // with the original onopen event. + this.openListeners.forEach(listener => + listener.call(this, this.onopenEvent) + ); + this.onopenInternal?.call(this, this.onopenEvent); + return; + } else { + // Authentication failed, authResponse.status === 'error'. + this.triggerError( + 'auth error connecting to websocket: ' + authResponse.message + ); + return; + } + } else { + // If authenticated, pass messages to user-added listeners. + this.messageListeners.forEach(listener => { + listener.call(this, ev); + }); + this.onmessageInternal?.call(this, ev); + } + }; + + // Set the 'close' event for cleanup. + super.onclose = (ev: CloseEvent) => { + // Trigger any user-added close listeners + this.oncloseListeners.forEach(listener => listener.call(this, ev)); + this.oncloseInternal?.call(this, ev); + this.authenticated = false; + }; + + // Set the 'error' event for cleanup. + super.onerror = (ev: Event) => { + // Trigger any user-added error listeners + this.onerrorListeners.forEach(listener => listener.call(this, ev)); + this.onerrorInternal?.call(this, ev); + this.authenticated = false; + }; + } + + // Authenticated send + override send(data: string | ArrayBufferLike | Blob | ArrayBufferView): void { + if (!this.authenticated) { + // This should be unreachable, but just in case. + this.triggerError( + 'Cannot send data before authentication is complete. Data: ' + data + ); + return; + } + super.send(data); + } + + // Override addEventListener to intercept these listeners and store them in + // our appropriate arrays. They are called in the appropriate places in the + // `onopen`, `onmessage`, `onclose`, and `onerror` methods set in the constructor. + override addEventListener( + type: K, + listener: (this: WebSocket, ev: WebSocketEventMap[K]) => any + ): void { + if (type === 'open') { + this.openListeners.push( + listener as (this: WebSocket, ev: WebSocketEventMap['open']) => any + ); + } else if (type === 'message') { + this.messageListeners.push( + listener as (this: WebSocket, ev: WebSocketEventMap['message']) => any + ); + } else if (type === 'close') { + this.oncloseListeners.push( + listener as (this: WebSocket, ev: WebSocketEventMap['close']) => any + ); + } else if (type === 'error') { + this.onerrorListeners.push( + listener as (this: WebSocket, ev: WebSocketEventMap['error']) => any + ); + } else { + // This should be unreachable, but just in case. + super.addEventListener(type, listener); + } + } + + // Override the onopen, onmessage, onclose, and onerror properties to store the user-defined + // listeners in the appropriate internal properties. These are called in the appropriate places + // in the `onopen`, `onmessage`, `onclose`, and `onerror` methods set in the constructor. + + override set onopen(listener: (this: WebSocket, ev: Event) => any | null) { + this.onopenInternal = listener; + } + + override get onopen(): ((this: WebSocket, ev: Event) => any) | null { + return this.onopenInternal; + } + + override set onmessage( + listener: ((this: WebSocket, ev: MessageEvent) => any) | null + ) { + this.onmessageInternal = listener; + } + + override get onmessage(): + | ((this: WebSocket, ev: MessageEvent) => any) + | null { + return this.onmessageInternal; + } + + override set onclose( + listener: ((this: WebSocket, ev: CloseEvent) => any) | null + ) { + this.oncloseInternal = listener; + } + + override get onclose(): ((this: WebSocket, ev: CloseEvent) => any) | null { + return this.oncloseInternal; + } + + override set onerror(listener: ((this: WebSocket, ev: Event) => any) | null) { + this.onerrorInternal = listener; + } + + override get onerror(): ((this: WebSocket, ev: Event) => any) | null { + return this.onerrorInternal; + } + + // Override the binaryType property to store the user-defined binaryType in the appropriate internal property. + // This is because we need to set the binaryType to 'arraybuffer' for the authentication process (see constructor), + // and only then can we set it to the user-defined value. + override set binaryType(binaryType: BinaryType) { + if (this.authenticated) { + super.binaryType = binaryType; + return; + } + + this.binaryTypeInternal = binaryType; + } + + override get binaryType(): BinaryType { + return this.binaryTypeInternal; + } + + // Override removeEventListener to support listeners removal for 'open', 'message', and 'close' events + override removeEventListener( + type: K, + listener: (this: WebSocket, ev: WebSocketEventMap[K]) => any + ): void { + if (type === 'open') { + const index = this.openListeners.indexOf( + listener as (this: WebSocket, ev: WebSocketEventMap['open']) => any + ); + if (index !== -1) { + this.openListeners.splice(index, 1); + } + } else if (type === 'message') { + const index = this.messageListeners.indexOf( + listener as (this: WebSocket, ev: WebSocketEventMap['message']) => any + ); + if (index !== -1) { + this.messageListeners.splice(index, 1); + } + } else if (type === 'close') { + const index = this.oncloseListeners.indexOf( + listener as (this: WebSocket, ev: WebSocketEventMap['close']) => any + ); + if (index !== -1) { + this.oncloseListeners.splice(index, 1); + } + } else if (type === 'error') { + const index = this.onerrorListeners.indexOf( + listener as (this: WebSocket, ev: WebSocketEventMap['error']) => any + ); + if (index !== -1) { + this.onerrorListeners.splice(index, 1); + } + } else { + // This should be unreachable, but just in case. + super.removeEventListener( + type, + listener as EventListenerOrEventListenerObject + ); + } + } + + // Method to manually trigger an error event. + private triggerError(errorMessage: string): void { + const errorEvent = new ErrorEvent('error', { + error: new Error(errorMessage), + message: errorMessage, + }); + + // Dispatch the event to trigger all listeners attached for 'error' events. + this.dispatchEvent(errorEvent); + } +} diff --git a/web/packages/teleport/src/lib/tdp/client.ts b/web/packages/teleport/src/lib/tdp/client.ts index f2584c95eadea..64b66dde7dffd 100644 --- a/web/packages/teleport/src/lib/tdp/client.ts +++ b/web/packages/teleport/src/lib/tdp/client.ts @@ -25,6 +25,7 @@ import init, { import { WebsocketCloseCode, TermEvent } from 'teleport/lib/term/enums'; import { EventEmitterWebAuthnSender } from 'teleport/lib/EventEmitterWebAuthnSender'; +import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket'; import Codec, { MessageType, @@ -85,12 +86,12 @@ export enum LogType { } // Client is the TDP client. It is responsible for connecting to a websocket serving the tdp server, -// sending client commands, and recieving and processing server messages. Its creator is responsible for +// sending client commands, and receiving and processing server messages. Its creator is responsible for // ensuring the websocket gets closed and all of its event listeners cleaned up when it is no longer in use. // For convenience, this can be done in one fell swoop by calling Client.shutdown(). export default class Client extends EventEmitterWebAuthnSender { protected codec: Codec; - protected socket: WebSocket | undefined; + protected socket: AuthenticatedWebSocket | undefined; private socketAddr: string; private sdManager: SharedDirectoryManager; private fastPathProcessor: FastPathProcessor | undefined; @@ -114,7 +115,7 @@ export default class Client extends EventEmitterWebAuthnSender { async connect(spec?: ClientScreenSpec) { await this.initWasm(); - this.socket = new WebSocket(this.socketAddr); + this.socket = new AuthenticatedWebSocket(this.socketAddr); this.socket.binaryType = 'arraybuffer'; this.socket.onopen = () => { diff --git a/web/packages/teleport/src/lib/term/tty.ts b/web/packages/teleport/src/lib/term/tty.ts index 88a18bcc5d246..fe45eb930d65a 100644 --- a/web/packages/teleport/src/lib/term/tty.ts +++ b/web/packages/teleport/src/lib/term/tty.ts @@ -20,6 +20,7 @@ import Logger from 'shared/libs/logger'; import { EventEmitterWebAuthnSender } from 'teleport/lib/EventEmitterWebAuthnSender'; import { WebauthnAssertionResponse } from 'teleport/services/auth'; +import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket'; import { EventType, TermEvent, WebsocketCloseCode } from './enums'; import { Protobuf, MessageTypeEnum } from './protobuf'; @@ -62,7 +63,7 @@ class Tty extends EventEmitterWebAuthnSender { connect(w: number, h: number) { const connStr = this._addressResolver.getConnStr(w, h); - this.socket = new WebSocket(connStr); + this.socket = new AuthenticatedWebSocket(connStr); this.socket.binaryType = 'arraybuffer'; this.socket.onopen = this._onOpenConnection; this.socket.onmessage = this._onMessage; diff --git a/web/packages/teleport/src/types.ts b/web/packages/teleport/src/types.ts index 3aabdcec6a52a..144db28946953 100644 --- a/web/packages/teleport/src/types.ts +++ b/web/packages/teleport/src/types.ts @@ -197,3 +197,11 @@ export enum RecommendationStatus { Notify = 'NOTIFY', Done = 'DONE', } + +// WebsocketStatus is used to indicate the auth status from a +// websocket connection +export type WebsocketStatus = { + type: string; + status: string; + message?: string; +};