diff --git a/lib/events/auditlog.go b/lib/events/auditlog.go index 7d97222606483..56f7dbde8d909 100644 --- a/lib/events/auditlog.go +++ b/lib/events/auditlog.go @@ -25,6 +25,7 @@ import ( "errors" "fmt" "io" + "io/fs" "os" "path/filepath" "sort" @@ -985,7 +986,9 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID if rmErr := os.Remove(tarballPath); rmErr != nil { l.log.WithError(rmErr).Warningf("Failed to remove file %v.", tarballPath) } - + if errors.Is(err, fs.ErrNotExist) { + err = trace.NotFound("a recording for session %v was not found", sessionID) + } e <- trace.Wrap(err) return c, e } @@ -1003,7 +1006,7 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID for { if ctx.Err() != nil { e <- trace.Wrap(ctx.Err()) - break + return } event, err := protoReader.Read(ctx) @@ -1013,12 +1016,16 @@ func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID } else { close(c) } - - break + return } if event.GetIndex() >= startIndex { - c <- event + select { + case c <- event: + case <-ctx.Done(): + e <- trace.Wrap(ctx.Err()) + return + } } } }() diff --git a/lib/player/player.go b/lib/player/player.go new file mode 100644 index 0000000000000..2e225a2136ac0 --- /dev/null +++ b/lib/player/player.go @@ -0,0 +1,332 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package player includes an API to play back recorded sessions. +package player + +import ( + "context" + "errors" + "math" + "sync/atomic" + "time" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/sirupsen/logrus" + + "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/session" +) + +// Player is used to stream recorded sessions over a channel. +type Player struct { + // read only config fields + clock clockwork.Clock + log logrus.FieldLogger + sessionID session.ID + streamer Streamer + + speed atomic.Value // playback speed (1.0 for normal speed) + lastPlayed atomic.Int64 // timestamp of most recently played event + + // advanceTo is used to implement fast-forward and rewind. + // During normal operation, it is set to [normalPlayback]. + // + // When set to a positive value the player is seeking forward + // in time (and plays events as quickly as possible). + // + // When set to a negative value, the player needs to "rewind" + // by starting the stream over from the beginning and then + // seeking forward to the rewind point. + advanceTo atomic.Int64 + + emit chan events.AuditEvent + done chan struct{} + + // playPause holds a channel to be closed when + // the player transitions from paused to playing, + // or nil if the player is already playing. + // + // This approach mimics a "select-able" condition variable + // and is inspired by "Rethinking Classical Concurrency Patterns" + // by Bryan C. Mills (GopherCon 2018): https://www.youtube.com/watch?v=5zXAHh5tJqQ + playPause chan chan struct{} + + // err holds the error (if any) encountered during playback + err error +} + +const normalPlayback = math.MinInt64 + +// Streamer is the underlying streamer that provides +// access to recorded session events. +type Streamer interface { + StreamSessionEvents( + ctx context.Context, + sessionID session.ID, + startIndex int64, + ) (chan events.AuditEvent, chan error) +} + +// Config configures a session player. +type Config struct { + Clock clockwork.Clock + Log logrus.FieldLogger + SessionID session.ID + Streamer Streamer +} + +func New(cfg *Config) (*Player, error) { + if cfg.Streamer == nil { + return nil, trace.BadParameter("missing Streamer") + } + + if cfg.SessionID == "" { + return nil, trace.BadParameter("missing SessionID") + } + + clk := cfg.Clock + if clk == nil { + clk = clockwork.NewRealClock() + } + + var log logrus.FieldLogger = cfg.Log + if log == nil { + log = logrus.New().WithField(trace.Component, "player") + } + + p := &Player{ + clock: clk, + log: log, + sessionID: cfg.SessionID, + streamer: cfg.Streamer, + emit: make(chan events.AuditEvent, 64), + playPause: make(chan chan struct{}, 1), + done: make(chan struct{}), + } + + p.speed.Store(float64(defaultPlaybackSpeed)) + p.advanceTo.Store(normalPlayback) + + // start in a paused state + p.playPause <- make(chan struct{}) + + go p.stream() + + return p, nil +} + +// errClosed is an internal error that is used to signal +// that the player has been closed +var errClosed = errors.New("player closed") + +const ( + minPlaybackSpeed = 0.25 + defaultPlaybackSpeed = 1.0 + maxPlaybackSpeed = 16 +) + +// SetSpeed adjusts the playback speed of the player. +// It can be called at any time (the player can be in a playing +// or paused state). A speed of 1.0 plays back at regular speed, +// while a speed of 2.0 plays back twice as fast as originally +// recorded. Valid speeds range from 0.25 to 16.0. +func (p *Player) SetSpeed(s float64) error { + if s < minPlaybackSpeed || s > maxPlaybackSpeed { + return trace.BadParameter("speed %v is out of range", s) + } + p.speed.Store(s) + return nil +} + +func (p *Player) stream() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + eventsC, errC := p.streamer.StreamSessionEvents(ctx, p.sessionID, 0) + lastDelay := int64(0) + for { + select { + case <-p.done: + close(p.emit) + return + case err := <-errC: + p.log.Warn(err) + p.err = err + close(p.emit) + return + case evt := <-eventsC: + if evt == nil { + p.log.Debugf("reached end of playback for session %v", p.sessionID) + close(p.emit) + return + } + + if err := p.waitWhilePaused(); err != nil { + p.log.Warn(err) + close(p.emit) + return + } + + currentDelay := getDelay(evt) + if currentDelay > 0 && currentDelay > lastDelay { + switch adv := p.advanceTo.Load(); { + case adv >= currentDelay: + // no timing delay necessary, we are fast forwarding + break + case adv < 0 && adv != normalPlayback: + // any negative value other than normalPlayback means + // we rewind (by restarting the stream and seeking forward + // to the rewind point) + p.advanceTo.Store(adv * -1) + go p.stream() + return + default: + if adv != normalPlayback { + p.advanceTo.Store(normalPlayback) + + // we're catching back up to real time, so the delay + // is calculated not from the last event but from the + // time we were advanced to + lastDelay = adv + } + if err := p.applyDelay(time.Duration(currentDelay-lastDelay) * time.Millisecond); err != nil { + close(p.emit) + return + } + } + + lastDelay = currentDelay + } + + select { + case p.emit <- evt: + p.lastPlayed.Store(currentDelay) + default: + p.log.Warnf("dropped event %v, reader too slow", evt.GetID()) + } + } + } +} + +// Close shuts down the player and cancels any streams that are +// in progress. +func (p *Player) Close() error { + close(p.done) + return nil +} + +// C returns a read only channel of recorded session events. +// The player manages the timing of events and writes them to the channel +// when they should be rendered. The channel is closed when the player +// has reached the end of playback. +func (p *Player) C() <-chan events.AuditEvent { + return p.emit +} + +// Err returns the error (if any) that occurred during playback. +// It should only be called after the channel returned by [C] is +// closed. +func (p *Player) Err() error { + return p.err +} + +// Pause temporarily stops the player from emitting events. +// It is a no-op if playback is currently paused. +func (p *Player) Pause() error { + p.setPlaying(false) + return nil +} + +// Play starts emitting events. It is used to start playback +// for the first time and to resume playing after the player +// is paused. +func (p *Player) Play() error { + p.setPlaying(true) + return nil +} + +// SetPos sets playback to a specific time, expressed as a duration +// from the beginning of the session. A duration of 0 restarts playback +// from the beginning. A duration greater than the length of the session +// will cause playback to rapidly advance to the end of the recording. +func (p *Player) SetPos(d time.Duration) error { + if d.Milliseconds() < p.lastPlayed.Load() { + // if we're rewinding we store a negative value + d = -1 * d + } + p.advanceTo.Store(d.Milliseconds()) + return nil +} + +// applyDelay "sleeps" for d in a manner that +// can be canceled +func (p *Player) applyDelay(d time.Duration) error { + scaled := float64(d) / p.speed.Load().(float64) + select { + case <-p.done: + return errClosed + case <-p.clock.After(time.Duration(scaled)): + return nil + } +} + +func (p *Player) setPlaying(play bool) { + ch := <-p.playPause + alreadyPlaying := ch == nil + + if alreadyPlaying && !play { + ch = make(chan struct{}) + } else if !alreadyPlaying && play { + // signal waiters who are paused that it's time to resume playing + close(ch) + ch = nil + } + + p.playPause <- ch +} + +// waitWhilePaused blocks while the player is in a paused state. +// It returns immediately if the player is currently playing. +func (p *Player) waitWhilePaused() error { + ch := <-p.playPause + p.playPause <- ch + + if alreadyPlaying := ch == nil; !alreadyPlaying { + select { + case <-p.done: + return errClosed + case <-ch: + } + } + return nil +} + +// LastPlayed returns the time of the last played event, +// expressed as milliseconds since the start of the session. +func (p *Player) LastPlayed() int64 { + return p.lastPlayed.Load() +} + +func getDelay(e events.AuditEvent) int64 { + switch x := e.(type) { + case *events.DesktopRecording: + return x.DelayMilliseconds + case *events.SessionPrint: + return x.DelayMilliseconds + default: + return int64(0) + } +} diff --git a/lib/player/player_test.go b/lib/player/player_test.go new file mode 100644 index 0000000000000..8ac04cb632928 --- /dev/null +++ b/lib/player/player_test.go @@ -0,0 +1,279 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package player_test + +import ( + "context" + "fmt" + "strconv" + "testing" + "time" + + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + + apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/player" + "github.com/gravitational/teleport/lib/session" +) + +func TestBasicStream(t *testing.T) { + clk := clockwork.NewFakeClock() + p, err := player.New(&player.Config{ + Clock: clk, + SessionID: "test-session", + Streamer: &simpleStreamer{count: 3}, + }) + require.NoError(t, err) + + require.NoError(t, p.Play()) + + count := 0 + for range p.C() { + count++ + } + + require.Equal(t, 3, count) + require.NoError(t, p.Err()) +} + +func TestPlayPause(t *testing.T) { + clk := clockwork.NewFakeClock() + p, err := player.New(&player.Config{ + Clock: clk, + SessionID: "test-session", + Streamer: &simpleStreamer{count: 3}, + }) + require.NoError(t, err) + + // pausing an already paused player should be a no-op + require.NoError(t, p.Pause()) + require.NoError(t, p.Pause()) + + // toggling back and forth between play and pause + // should not impact our ability to receive all + // 3 events + require.NoError(t, p.Play()) + require.NoError(t, p.Pause()) + require.NoError(t, p.Play()) + + count := 0 + for range p.C() { + count++ + } + + require.Equal(t, 3, count) +} + +func TestAppliesTiming(t *testing.T) { + for _, test := range []struct { + desc string + speed float64 + advance time.Duration + }{ + { + desc: "half speed", + speed: 0.5, + advance: 2000 * time.Millisecond, + }, + { + desc: "normal speed", + speed: 1.0, + advance: 1000 * time.Millisecond, + }, + { + desc: "double speed", + speed: 2.0, + advance: 500 * time.Millisecond, + }, + } { + t.Run(test.desc, func(t *testing.T) { + clk := clockwork.NewFakeClock() + p, err := player.New(&player.Config{ + Clock: clk, + SessionID: "test-session", + Streamer: &simpleStreamer{count: 3, delay: 1000}, + }) + require.NoError(t, err) + + require.NoError(t, p.SetSpeed(test.speed)) + require.NoError(t, p.Play()) + + clk.BlockUntil(1) // player is now waiting to emit event 0 + + // advance to next event (player will have emitted event 0 + // and will be waiting to emit event 1) + clk.Advance(test.advance) + clk.BlockUntil(1) + evt := <-p.C() + require.Equal(t, int64(0), evt.GetIndex()) + + // repeat the process (emit event 1, wait for event 2) + clk.Advance(test.advance) + clk.BlockUntil(1) + evt = <-p.C() + require.Equal(t, int64(1), evt.GetIndex()) + + // advance the player to allow event 2 to be emitted + clk.Advance(test.advance) + evt = <-p.C() + require.Equal(t, int64(2), evt.GetIndex()) + + // channel should be closed + _, ok := <-p.C() + require.False(t, ok, "player should be closed") + }) + } +} + +func TestClose(t *testing.T) { + clk := clockwork.NewFakeClock() + p, err := player.New(&player.Config{ + Clock: clk, + SessionID: "test-session", + Streamer: &simpleStreamer{count: 2, delay: 1000}, + }) + require.NoError(t, err) + + require.NoError(t, p.Play()) + + clk.BlockUntil(1) // player is now waiting to emit event 0 + + // advance to next event (player will have emitted event 0 + // and will be waiting to emit event 1) + clk.Advance(1001 * time.Millisecond) + clk.BlockUntil(1) + evt := <-p.C() + require.Equal(t, int64(0), evt.GetIndex()) + + require.NoError(t, p.Close()) + + // channel should have been closed + _, ok := <-p.C() + require.False(t, ok, "player channel should have been closed") + require.NoError(t, p.Err()) +} + +func TestSeekForward(t *testing.T) { + clk := clockwork.NewFakeClock() + p, err := player.New(&player.Config{ + Clock: clk, + SessionID: "test-session", + Streamer: &simpleStreamer{count: 10, delay: 1000}, + }) + require.NoError(t, err) + require.NoError(t, p.Play()) + + clk.BlockUntil(1) // player is now waiting to emit event 0 + + // advance playback until right before the last event + p.SetPos(9001 * time.Millisecond) + + // advance the clock to unblock the player + // (it should now spit out all but the last event in rapid succession) + clk.Advance(1001 * time.Millisecond) + + ch := make(chan struct{}) + go func() { + defer close(ch) + for evt := range p.C() { + t.Logf("got event %v (delay=%v)", evt.GetID(), evt.GetCode()) + } + }() + + clk.BlockUntil(1) + require.Equal(t, int64(9000), p.LastPlayed()) + + clk.Advance(999 * time.Millisecond) + select { + case <-ch: + case <-time.After(3 * time.Second): + require.FailNow(t, "player hasn't closed in time") + } +} + +func TestRewind(t *testing.T) { + clk := clockwork.NewFakeClock() + p, err := player.New(&player.Config{ + Clock: clk, + SessionID: "test-session", + Streamer: &simpleStreamer{count: 10, delay: 1000}, + }) + require.NoError(t, err) + require.NoError(t, p.Play()) + + // play through 7 events at regular speed + for i := 0; i < 7; i++ { + clk.BlockUntil(1) // player is now waiting to emit event + clk.Advance(1000 * time.Millisecond) // unblock event + <-p.C() // read event + } + + // now "rewind" to the point just prior to event index 3 (4000 ms into session) + clk.BlockUntil(1) + p.SetPos(3900 * time.Millisecond) + + // when we advance the clock, we expect the following behavior: + // - event index 7 (which we were blocked on) comes out right away + // - playback restarts, events 0 through 2 are emitted immediately + // - event index 3 is emitted after another 100ms + clk.Advance(1000 * time.Millisecond) + require.Equal(t, int64(7), (<-p.C()).GetIndex()) + require.Equal(t, int64(0), (<-p.C()).GetIndex(), "expected playback to retart for rewind") + require.Equal(t, int64(1), (<-p.C()).GetIndex(), "expected rapid streaming up to rewind point") + require.Equal(t, int64(2), (<-p.C()).GetIndex()) + clk.BlockUntil(1) + clk.Advance(100 * time.Millisecond) + require.Equal(t, int64(3), (<-p.C()).GetIndex()) + + p.Close() +} + +// simpleStreamer streams a fake session that contains +// count events, emitted at a particular interval +type simpleStreamer struct { + count int64 + delay int64 // milliseconds +} + +func (s *simpleStreamer) StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) { + errors := make(chan error, 1) + evts := make(chan apievents.AuditEvent) + + go func() { + defer close(evts) + + for i := int64(0); i < s.count; i++ { + select { + case <-ctx.Done(): + return + case evts <- &apievents.SessionPrint{ + Metadata: apievents.Metadata{ + Type: events.SessionPrintEvent, + Index: i, + ID: strconv.Itoa(int(i)), + Code: strconv.FormatInt((i+1)*s.delay, 10), + }, + Data: []byte(fmt.Sprintf("event %d\n", i)), + ChunkIndex: i, // TODO(zmb3) deprecate this + DelayMilliseconds: (i + 1) * s.delay, + }: + } + } + }() + + return evts, errors +} diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index e311c3c2061b4..fd6fa62f25541 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -37,6 +37,7 @@ import ( "time" "github.com/google/uuid" + "github.com/gorilla/websocket" "github.com/gravitational/oxy/ratelimit" "github.com/gravitational/roundtrip" "github.com/gravitational/trace" @@ -142,6 +143,11 @@ type Handler struct { // nodeWatcher is a services.NodeWatcher used by Assist to lookup nodes from // the proxy's cache and get nodes in real time. nodeWatcher *services.NodeWatcher + + // wsIODeadline is used to set a deadline for receiving a message from + // an authenticated websocket so unauthenticated sockets dont get left + // open. + wsIODeadline time.Duration } // HandlerOption is a functional argument - an option that can be passed @@ -338,6 +344,7 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) { clock: clockwork.NewRealClock(), ClusterFeatures: cfg.ClusterFeatures, healthCheckAppServer: cfg.HealthCheckAppServer, + wsIODeadline: wsIODeadline, } // Check for self-hosted vs Cloud. @@ -672,7 +679,10 @@ func (h *Handler) bindDefaultEndpoints() { h.DELETE("/webapi/sites/:site/locks/:uuid", h.WithClusterAuth(h.deleteClusterLock)) // active sessions handlers - h.GET("/webapi/sites/:site/connect", h.WithClusterAuth(h.siteNodeConnect)) // connect to an active session (via websocket) + // Deprecated: The connect/ws variant should be used instead. + // TODO(lxea): DELETE in v16 + h.GET("/webapi/sites/:site/connect", h.WithClusterAuthWebSocket(false, h.siteNodeConnect)) // connect to an active session (via websocket) + h.GET("/webapi/sites/:site/connect/ws", h.WithClusterAuthWebSocket(true, h.siteNodeConnect)) // connect to an active session (via websocket, with auth over websocket) h.GET("/webapi/sites/:site/sessions", h.WithClusterAuth(h.clusterActiveAndPendingSessionsGet)) // get list of active and pending sessions // Audit events handlers. @@ -771,9 +781,17 @@ func (h *Handler) bindDefaultEndpoints() { h.GET("/webapi/sites/:site/desktopservices", h.WithClusterAuth(h.clusterDesktopServicesGet)) h.GET("/webapi/sites/:site/desktops/:desktopName", h.WithClusterAuth(h.getDesktopHandle)) // GET /webapi/sites/:site/desktops/:desktopName/connect?access_token=&username=&width=&height= - h.GET("/webapi/sites/:site/desktops/:desktopName/connect", h.WithClusterAuth(h.desktopConnectHandle)) + // Deprecated: The connect/ws variant should be used instead. + // TODO(lxea): DELETE in v16 + h.GET("/webapi/sites/:site/desktops/:desktopName/connect", h.WithClusterAuthWebSocket(false, h.desktopConnectHandle)) + // GET /webapi/sites/:site/desktops/:desktopName/connect?username=&width=&height= + h.GET("/webapi/sites/:site/desktops/:desktopName/connect/ws", h.WithClusterAuthWebSocket(true, h.desktopConnectHandle)) // GET /webapi/sites/:site/desktopplayback/:sid?access_token= - h.GET("/webapi/sites/:site/desktopplayback/:sid", h.WithClusterAuth(h.desktopPlaybackHandle)) + // Deprecated: The desktopplayback/ws variant should be used instead. + // TODO(lxea): DELETE in v16 + h.GET("/webapi/sites/:site/desktopplayback/:sid", h.WithClusterAuthWebSocket(false, h.desktopPlaybackHandle)) + // GET /webapi/sites/:site/desktopplayback/:sid/ws + h.GET("/webapi/sites/:site/desktopplayback/:sid/ws", h.WithClusterAuthWebSocket(true, h.desktopPlaybackHandle)) h.GET("/webapi/sites/:site/desktops/:desktopName/active", h.WithClusterAuth(h.desktopIsActive)) // GET a Connection Diagnostics by its name @@ -820,7 +838,11 @@ func (h *Handler) bindDefaultEndpoints() { h.GET("/webapi/sites/:site/user-groups", h.WithClusterAuth(h.getUserGroups)) // WebSocket endpoint for the chat conversation - h.GET("/webapi/sites/:site/assistant", h.WithClusterAuth(h.assistant)) + // Deprecated: The connect/ws variant should be used instead. + // TODO(lxea): DELETE in v16 + h.GET("/webapi/sites/:site/assistant", h.WithClusterAuthWebSocket(false, h.assistant)) + // WebSocket endpoint for the chat conversation, websocket auth + h.GET("/webapi/sites/:site/assistant/ws", h.WithClusterAuthWebSocket(true, h.assistant)) // Sets the title for the conversation. h.POST("/webapi/assistant/conversations/:conversation_id/title", h.WithAuth(h.setAssistantTitle)) @@ -839,7 +861,11 @@ func (h *Handler) bindDefaultEndpoints() { h.GET("/webapi/assistant/conversations/:conversation_id", h.WithAuth(h.getAssistantConversationByID)) // Allows executing an arbitrary command on multiple nodes. - h.GET("/webapi/command/:site/execute", h.WithClusterAuth(h.executeCommand)) + // Deprecated: The execute/ws variant should be used instead. + // TODO(lxea): DELETE in v16 + h.GET("/webapi/command/:site/execute", h.WithClusterAuthWebSocket(false, h.executeCommand)) + // Allows executing an arbitrary command on multiple nodes, websocket auth. + h.GET("/webapi/command/:site/execute/ws", h.WithClusterAuthWebSocket(true, h.executeCommand)) // Fetches the user's preferences h.GET("/webapi/user/preferences", h.WithAuth(h.getUserPreferences)) @@ -2698,6 +2724,7 @@ func (h *Handler) siteNodeConnect( p httprouter.Params, sessionCtx *SessionContext, site reversetunnelclient.RemoteSite, + ws *websocket.Conn, ) (interface{}, error) { q := r.URL.Query() params := q.Get("params") @@ -2795,6 +2822,7 @@ func (h *Handler) siteNodeConnect( PROXYSigner: h.cfg.PROXYSigner, Tracker: tracker, Clock: h.clock, + WebsocketConn: ws, } term, err := NewTerminal(ctx, terminalConfig) @@ -3566,6 +3594,9 @@ type ContextHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Pa // ClusterHandler is a authenticated handler that is called for some existing remote cluster type ClusterHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (interface{}, error) +// ClusterWebsocketHandler is a authenticated websocket handler that is called for some existing remote cluster +type ClusterWebsocketHandler func(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite, ws *websocket.Conn) (interface{}, error) + // WithClusterAuth wraps a ClusterHandler to ensure that a request is authenticated to this proxy // (the same as WithAuth), as well as to grab the remoteSite (which can represent this local cluster // or a remote trusted cluster) as specified by the ":site" url parameter. @@ -3580,12 +3611,108 @@ func (h *Handler) WithClusterAuth(fn ClusterHandler) httprouter.Handle { }) } +func (h *Handler) writeErrToWebSocket(ws *websocket.Conn, err error) { + if err == nil { + return + } + errEnvelope := Envelope{ + Type: defaults.WebsocketError, + Payload: trace.UserMessage(err), + } + env, err := errEnvelope.Marshal() + if err != nil { + h.log.WithError(err).Error("error marshaling proto") + return + } + if err := ws.WriteMessage(websocket.BinaryMessage, env); err != nil { + h.log.WithError(err).Error("error writing proto") + return + } +} + +// authnWsUpgrader is an upgrader that allows any origin to connect to the websocket. +// This makes our lives easier in our automated tests. While ordinarily this would be +// used to enforce the same-origin policy, we don't need to worry about that for authenticated +// websockets, which also require a valid bearer token sent over the websocket after upgrade. +// Therefore even if an attacker were to connect to the websocket and trick the browser into +// sending the session cookie, they would still fail to send the bearer token needed to authenticate. +var authnWsUpgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { return true }, +} + +// WithClusterAuthWebSocket wraps a ClusterWebsocketHandler to ensure that a request is authenticated +// to this proxy via websocket if websocketAuth is true, or via query parameter if false (the same as WithAuth), as +// well as to grab the remoteSite (which can represent this local cluster or a remote trusted cluster) +// as specified by the ":site" url parameter. +// +// TODO(lxea): remove the 'websocketAuth' bool once the deprecated websocket handlers are removed +func (h *Handler) WithClusterAuthWebSocket(websocketAuth bool, fn ClusterWebsocketHandler) httprouter.Handle { + return httplib.MakeHandler(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (any, error) { + var sctx *SessionContext + var ws *websocket.Conn + var site reversetunnelclient.RemoteSite + var err error + + if websocketAuth { + sctx, ws, site, err = h.authenticateWSRequestWithCluster(w, r, p) + } else { + sctx, ws, site, err = h.authenticateWSRequestWithClusterDeprecated(w, r, p) + } + + if err != nil { + return nil, trace.Wrap(err) + } + // WS protocol requires the server send a close message + // which should be done by downstream users + defer ws.Close() + if _, err := fn(w, r, p, sctx, site, ws); err != nil { + h.writeErrToWebSocket(ws, err) + } + return nil, nil + }) +} + +// authenticateWSRequestWithCluster ensures that a request is +// authenticated to this proxy via websocket, returning the +// *SessionContext (same as AuthenticateRequest), and also grabs the +// remoteSite (which can represent this local cluster or a remote +// trusted cluster) as specified by the ":site" url parameter. +func (h *Handler) authenticateWSRequestWithCluster(w http.ResponseWriter, r *http.Request, p httprouter.Params) (*SessionContext, *websocket.Conn, reversetunnelclient.RemoteSite, error) { + sctx, ws, err := h.AuthenticateRequestWS(w, r) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + + site, err := h.getSiteByParams(sctx, p) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + + return sctx, ws, site, nil +} + +// TODO(lxea): remove once the deprecated websocket handlers are removed +func (h *Handler) authenticateWSRequestWithClusterDeprecated(w http.ResponseWriter, r *http.Request, p httprouter.Params) (*SessionContext, *websocket.Conn, reversetunnelclient.RemoteSite, error) { + sctx, site, err := h.authenticateRequestWithCluster(w, r, p) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + ws, err := authnWsUpgrader.Upgrade(w, r, nil) + if err != nil { + return nil, nil, nil, trace.Wrap(err) + } + return sctx, ws, site, nil +} + // authenticateRequestWithCluster ensures that a request is authenticated // to this proxy, returning the *SessionContext (same as AuthenticateRequest), // and also grabs the remoteSite (which can represent this local cluster or a // remote trusted cluster) as specified by the ":site" url parameter. func (h *Handler) authenticateRequestWithCluster(w http.ResponseWriter, r *http.Request, p httprouter.Params) (*SessionContext, reversetunnelclient.RemoteSite, error) { sctx, err := h.AuthenticateRequest(w, r, true) + if err != nil { h.log.WithError(err).Warn("Failed to authenticate.") return nil, nil, trace.Wrap(err) @@ -3849,9 +3976,34 @@ func (h *Handler) WithLimiterHandlerFunc(fn httplib.HandlerFunc) httplib.Handler } } -// AuthenticateRequest authenticates request using combination of a session cookie -// and bearer token -func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, checkBearerToken bool) (*SessionContext, error) { +// WithHighLimiterHandlerFunc adds IP-based rate limiting to a HandlerFunc. This is similar to WithLimiterHandlerFunc +// but provides a higher rate limit. This should only be used for requests which are only CPU bound (no disk or other +// resources used). +func (h *Handler) WithHighLimiterHandlerFunc(fn httplib.HandlerFunc) httplib.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) { + err := rateLimitRequest(r, h.highLimiter) + if err != nil { + return nil, trace.Wrap(err) + } + return fn(w, r, p) + } +} + +func rateLimitRequest(r *http.Request, limiter *limiter.RateLimiter) error { + remote, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return trace.Wrap(err) + } + + err = limiter.RegisterRequest(remote, nil /* customRate */) + // MaxRateError doesn't play well with errors.Is, hence the type assertion. + if _, ok := err.(*ratelimit.MaxRateError); ok { + return trace.LimitExceeded(err.Error()) + } + return trace.Wrap(err) +} + +func (h *Handler) validateCookie(w http.ResponseWriter, r *http.Request) (*SessionContext, error) { const missingCookieMsg = "missing session cookie" logger := h.log.WithField("request", fmt.Sprintf("%v %v", r.Method, r.URL.Path)) cookie, err := r.Cookie(CookieName) @@ -3866,24 +4018,97 @@ func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, ch logger.WithError(err).Warn("Failed to decode cookie.") return nil, trace.AccessDenied("failed to decode cookie") } - ctx, err := h.auth.getOrCreateSession(r.Context(), decodedCookie.User, decodedCookie.SID) + sctx, err := h.auth.getOrCreateSession(r.Context(), decodedCookie.User, decodedCookie.SID) if err != nil { logger.WithError(err).Warn("Invalid session.") ClearSession(w) return nil, trace.AccessDenied("need auth") } + + return sctx, nil +} + +// AuthenticateRequest authenticates request using combination of a session cookie +// and bearer token +func (h *Handler) AuthenticateRequest(w http.ResponseWriter, r *http.Request, checkBearerToken bool) (*SessionContext, error) { + logger := h.log.WithField("request", fmt.Sprintf("%v %v", r.Method, r.URL.Path)) + sctx, err := h.validateCookie(w, r) + if err != nil { + return nil, trace.Wrap(err) + } if checkBearerToken { creds, err := roundtrip.ParseAuthHeaders(r) if err != nil { logger.WithError(err).Warn("No auth headers.") return nil, trace.AccessDenied("need auth") } - if err := ctx.validateBearerToken(r.Context(), creds.Password); err != nil { - logger.WithError(err).Warn("Request failed: bad bearer token.") + if err := sctx.validateBearerToken(r.Context(), creds.Password); err != nil { return nil, trace.AccessDenied("bad bearer token") } } - return ctx, nil + return sctx, nil +} + +type wsBearerToken struct { + Token string `json:"token"` +} + +type wsStatus struct { + Type string `json:"type"` + Status string `json:"status"` + Message string `json:"message,omitempty"` +} + +// wsIODeadline is used to set a deadline for receiving a message from +// an authenticated websocket so unauthenticated sockets dont get left +// open. +const wsIODeadline = time.Second * 4 + +// AuthenticateRequest authenticates request using combination of a session cookie +// and bearer token retrieved from a websocket +func (h *Handler) AuthenticateRequestWS(w http.ResponseWriter, r *http.Request) (*SessionContext, *websocket.Conn, error) { + sctx, err := h.validateCookie(w, r) + if err != nil { + return nil, nil, trace.Wrap(err) + } + ws, err := authnWsUpgrader.Upgrade(w, r, nil) + if err != nil { + return nil, nil, trace.ConnectionProblem(err, "Error upgrading to websocket: %v", err) + } + if err := ws.SetReadDeadline(time.Now().Add(wsIODeadline)); err != nil { + return nil, nil, trace.ConnectionProblem(err, "Error setting websocket read deadline: %v", err) + } + + var t wsBearerToken + if err := ws.ReadJSON(&t); err != nil { + return nil, nil, trace.Wrap(err) + } + if err := sctx.validateBearerToken(r.Context(), t.Token); err != nil { + writeErr := ws.WriteJSON(wsStatus{ + Type: "create_session_response", + Status: "error", + Message: "invalid token", + }) + if writeErr != nil { + log.Errorf("Error while writing invalid token error to websocket: %s", writeErr) + } + + return nil, nil, trace.Wrap(err) + } + + if err := ws.WriteJSON(wsStatus{ + Type: "create_session_response", + Status: "ok", + }); err != nil { + return nil, nil, trace.Wrap(err) + } + + // unset the deadline as downstream consumers should handle this themselves. + if err := ws.SetReadDeadline(time.Time{}); err != nil { + return nil, nil, trace.ConnectionProblem(err, "Error setting websocket read deadline: %v", err) + } + + return sctx, ws, nil } // ProxyWithRoles returns a reverse tunnel proxy verifying the permissions diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 26250266bf6dd..d8d1ac3c7a080 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -6711,9 +6711,9 @@ func TestDiagnoseKubeConnection(t *testing.T) { } foundTrace = true - require.Equal(t, expectedTrace.Status.String(), returnedTrace.Status) - require.Equal(t, expectedTrace.Details, returnedTrace.Details) - require.Contains(t, expectedTrace.Error, returnedTrace.Error) + require.Equal(t, returnedTrace.Status, expectedTrace.Status.String()) + require.Equal(t, returnedTrace.Details, expectedTrace.Details) + require.Contains(t, returnedTrace.Error, expectedTrace.Error) } require.True(t, foundTrace, expectedTrace) @@ -7303,7 +7303,7 @@ func (s *WebSuite) makeTerminal(t *testing.T, pack *authPack, opts ...terminalOp u := url.URL{ Host: s.url().Host, Scheme: client.WSS, - Path: fmt.Sprintf("/v1/webapi/sites/%v/connect", currentSiteShortcut), + Path: fmt.Sprintf("/v1/webapi/sites/%v/connect/ws", currentSiteShortcut), } data, err := json.Marshal(req) if err != nil { @@ -7312,7 +7312,6 @@ func (s *WebSuite) makeTerminal(t *testing.T, pack *authPack, opts ...terminalOp q := u.Query() q.Set("params", string(data)) - q.Set(roundtrip.AccessTokenQueryParam, pack.session.Token) u.RawQuery = q.Encode() dialer := websocket.Dialer{} @@ -7338,6 +7337,10 @@ func (s *WebSuite) makeTerminal(t *testing.T, pack *authPack, opts ...terminalOp return nil, nil, trace.Wrap(err, sb.String()) } + if err := makeAuthReqOverWS(ws, pack.session.Token); err != nil { + return nil, nil, trace.Wrap(err) + } + ty, raw, err := ws.ReadMessage() if err != nil { return nil, nil, trace.Wrap(err) @@ -8025,7 +8028,7 @@ func (r *testProxy) makeTerminal(t *testing.T, pack *authPack, sessionID session u := url.URL{ Host: r.webURL.Host, Scheme: client.WSS, - Path: fmt.Sprintf("/v1/webapi/sites/%v/connect", currentSiteShortcut), + Path: fmt.Sprintf("/v1/webapi/sites/%v/connect/ws", currentSiteShortcut), } requestData := TerminalRequest{ @@ -8067,6 +8070,9 @@ func (r *testProxy) makeTerminal(t *testing.T, pack *authPack, sessionID session require.NoError(t, resp.Body.Close()) }) + err = makeAuthReqOverWS(ws, pack.session.Token) + require.NoError(t, err) + ty, raw, err := ws.ReadMessage() require.NoError(t, err) require.Equal(t, websocket.BinaryMessage, ty) @@ -8079,18 +8085,38 @@ func (r *testProxy) makeTerminal(t *testing.T, pack *authPack, sessionID session return ws, sessResp.Session } +func makeAuthReqOverWS(ws *websocket.Conn, token string) error { + authReq, err := json.Marshal(struct { + Token string `json:"token"` + }{Token: token}) + if err != nil { + return trace.Wrap(err) + } + + if err := ws.WriteMessage(websocket.TextMessage, authReq); err != nil { + return trace.Wrap(err) + } + _, authRes, err := ws.ReadMessage() + if err != nil { + return trace.Wrap(err) + } + if !strings.Contains(string(authRes), `"status":"ok"`) { + return trace.AccessDenied("unexpected response") + } + return nil +} + func (r *testProxy) makeDesktopSession(t *testing.T, pack *authPack, sessionID session.ID, addr net.Addr) *websocket.Conn { u := url.URL{ Host: r.webURL.Host, Scheme: client.WSS, - Path: fmt.Sprintf("/webapi/sites/%s/desktops/%s/connect", currentSiteShortcut, "desktop1"), + Path: fmt.Sprintf("/webapi/sites/%s/desktops/%s/connect/ws", currentSiteShortcut, "desktop1"), } q := u.Query() q.Set("username", "marek") q.Set("width", "100") q.Set("height", "100") - q.Set(roundtrip.AccessTokenQueryParam, pack.session.Token) u.RawQuery = q.Encode() dialer := websocket.Dialer{} @@ -8105,6 +8131,10 @@ func (r *testProxy) makeDesktopSession(t *testing.T, pack *authPack, sessionID s ws, resp, err := dialer.Dial(u.String(), header) require.NoError(t, err) + + err = makeAuthReqOverWS(ws, pack.session.Token) + require.NoError(t, err) + t.Cleanup(func() { require.NoError(t, ws.Close()) require.NoError(t, resp.Body.Close()) @@ -9064,6 +9094,111 @@ func (s *fakeKubeService) ListKubernetesResources(ctx context.Context, req *kube }, nil } +func TestWebSocketAuthenticateRequest(t *testing.T) { + t.Parallel() + ctx := context.Background() + env := newWebPack(t, 1) + proxy := env.proxies[0] + proxy.handler.handler.wsIODeadline = time.Second + pack := proxy.authPack(t, "test-user@example.com", nil) + for _, tc := range []struct { + name string + serverExpectError string + expectResponse wsStatus + token string + writeTimeout func() + readTimeout func() + }{ + { + name: "valid token", + expectResponse: wsStatus{ + Type: "create_session_response", + Status: "ok", + }, + token: pack.session.Token, + }, + { + name: "invalid token", + serverExpectError: "not found", + expectResponse: wsStatus{ + Type: "create_session_response", + Status: "error", + Message: "invalid token", + }, + token: "honk", + }, + { + name: "server read timeout", + serverExpectError: "i/o timeout", + token: pack.session.Token, + readTimeout: func() { + <-time.After(wsIODeadline * 3) + }, + }, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sctx, ws, err := proxy.handler.handler.AuthenticateRequestWS(w, r) + if err != nil { + if tc.serverExpectError == "" { + t.Errorf("unexpected error: %v", err) + } + if !strings.Contains(err.Error(), tc.serverExpectError) { + t.Errorf("unexpected error: %v", err) + return + } + return + } + t.Cleanup(func() { ws.Close() }) + if err == nil && tc.serverExpectError != "" { + t.Errorf("expected error, got nil") + return + } + + clt, err := sctx.GetClient() + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + _, err = clt.GetDomainName(ctx) + if err != nil { + t.Errorf("unexpected error: %v", err) + return + } + })) + + header := http.Header{} + for _, cookie := range pack.cookies { + header.Add("Cookie", cookie.String()) + } + + u := strings.Replace(server.URL, "http:", "ws:", 1) + conn, resp, err := websocket.DefaultDialer.Dial(u, header) + require.NoError(t, err) + t.Cleanup(func() { conn.Close() }) + t.Cleanup(func() { resp.Body.Close() }) + + if tc.readTimeout != nil { + tc.readTimeout() + } + err = conn.WriteJSON(wsBearerToken{ + Token: tc.token, + }) + require.NoError(t, err) + if tc.readTimeout != nil { + return // Reading will fail as the server will have closed the connection + } + + var status wsStatus + err = conn.ReadJSON(&status) + require.NoError(t, err) + require.Equal(t, tc.expectResponse, status) + }) + } +} + // TestSimultaneousAuthenticateRequest ensures that multiple authenticated // requests do not race to create a SessionContext. This would happen when // Proxies were deployed behind a round-robin load balancer. Only the Proxy diff --git a/lib/web/assistant.go b/lib/web/assistant.go index bfb4ea09f1b5e..bb1d476d24154 100644 --- a/lib/web/assistant.go +++ b/lib/web/assistant.go @@ -330,9 +330,9 @@ func (h *Handler) generateAssistantTitle(_ http.ResponseWriter, r *http.Request, // This handler covers the main chat conversation as well as the // SSH completition (SSH command generation and output explanation). func (h *Handler) assistant(w http.ResponseWriter, r *http.Request, _ httprouter.Params, - sctx *SessionContext, site reversetunnelclient.RemoteSite, + sctx *SessionContext, site reversetunnelclient.RemoteSite, ws *websocket.Conn, ) (any, error) { - if err := runAssistant(h, w, r, sctx, site); err != nil { + if err := runAssistant(h, w, r, sctx, site, ws); err != nil { h.log.Warn(trace.DebugReport(err)) return nil, trace.Wrap(err) } @@ -386,7 +386,7 @@ func checkAssistEnabled(a auth.ClientI, ctx context.Context) error { // runAssistant upgrades the HTTP connection to a websocket and starts a chat loop. func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, - sctx *SessionContext, site reversetunnelclient.RemoteSite, + sctx *SessionContext, site reversetunnelclient.RemoteSite, ws *websocket.Conn, ) (err error) { q := r.URL.Query() conversationID := q.Get("conversation_id") @@ -426,20 +426,6 @@ func runAssistant(h *Handler, w http.ResponseWriter, r *http.Request, return trace.Wrap(err) } - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { return true }, - } - - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - errMsg := "Error upgrading to websocket" - h.log.WithError(err).Error(errMsg) - http.Error(w, errMsg, http.StatusInternalServerError) - return nil - } - // Note: This time should be longer than OpenAI response time. keepAliveInterval := netConfig.GetKeepAliveInterval() err = ws.SetReadDeadline(deadlineForInterval(keepAliveInterval)) diff --git a/lib/web/command.go b/lib/web/command.go index b5c7cc6797463..cd2c602d2d95d 100644 --- a/lib/web/command.go +++ b/lib/web/command.go @@ -129,6 +129,7 @@ func (h *Handler) executeCommand( _ httprouter.Params, sessionCtx *SessionContext, site reversetunnelclient.RemoteSite, + rawWS *websocket.Conn, ) (any, error) { q := r.URL.Query() params := q.Get("params") @@ -177,20 +178,6 @@ func (h *Handler) executeCommand( clusterName := site.GetName() - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { return true }, - } - - rawWS, err := upgrader.Upgrade(w, r, nil) - if err != nil { - errMsg := "Error upgrading to websocket" - h.log.WithError(err).Error(errMsg) - http.Error(w, errMsg, http.StatusInternalServerError) - return nil, nil - } - defer func() { rawWS.WriteMessage(websocket.CloseMessage, nil) rawWS.Close() diff --git a/lib/web/desktop.go b/lib/web/desktop.go index ca4ff6be12ec5..b3d25228d2dd3 100644 --- a/lib/web/desktop.go +++ b/lib/web/desktop.go @@ -63,6 +63,7 @@ func (h *Handler) desktopConnectHandle( p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite, + ws *websocket.Conn, ) (interface{}, error) { desktopName := p.ByName("desktopName") if desktopName == "" { @@ -72,7 +73,7 @@ func (h *Handler) desktopConnectHandle( log := sctx.cfg.Log.WithField("desktop-name", desktopName).WithField("cluster-name", site.GetName()) log.Debug("New desktop access websocket connection") - if err := h.createDesktopConnection(w, r, desktopName, site.GetName(), log, sctx, site); err != nil { + if err := h.createDesktopConnection(w, r, desktopName, site.GetName(), log, sctx, site, ws); err != nil { // createDesktopConnection makes a best effort attempt to send an error to the user // (via websocket) before terminating the connection. We log the error here, but // return nil because our HTTP middleware will try to write the returned error in JSON @@ -97,15 +98,8 @@ func (h *Handler) createDesktopConnection( log *logrus.Entry, sctx *SessionContext, site reversetunnelclient.RemoteSite, + ws *websocket.Conn, ) error { - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - } - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - return trace.Wrap(err) - } defer ws.Close() sendTDPError := func(err error) error { diff --git a/lib/web/desktop/playback.go b/lib/web/desktop/playback.go index de53c02d6c00d..c51d5913883fc 100644 --- a/lib/web/desktop/playback.go +++ b/lib/web/desktop/playback.go @@ -18,19 +18,14 @@ package desktop import ( "context" - "errors" + "encoding/json" "fmt" - "net" - "os" - "sync" - "time" - "github.com/gravitational/trace" + "github.com/gorilla/websocket" "github.com/sirupsen/logrus" - "golang.org/x/net/websocket" - apievents "github.com/gravitational/teleport/api/types/events" - "github.com/gravitational/teleport/lib/session" + "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/player" "github.com/gravitational/teleport/lib/utils" ) @@ -39,74 +34,6 @@ const ( maxPlaybackSpeed = 16 ) -// Player manages the playback of a recorded desktop session. -// It streams events from the audit log to the browser over -// a websocket connection. -type Player struct { - ws *websocket.Conn - streamer Streamer - - mu sync.Mutex - cond *sync.Cond - playState playbackState - playSpeed float32 - - log logrus.FieldLogger - sID string - - closeOnce sync.Once -} - -// Streamer is the interface that can provide with a stream of events related to -// a particular session. -type Streamer interface { - // StreamSessionEvents streams all events from a given session recording. An error is returned on the first - // channel if one is encountered. Otherwise the event channel is closed when the stream ends. - // The event channel is not closed on error to prevent race conditions in downstream select statements. - StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) -} - -// NewPlayer creates a player that streams a desktop session -// over the provided websocket connection. -func NewPlayer(sID string, ws *websocket.Conn, streamer Streamer, log logrus.FieldLogger) *Player { - p := &Player{ - ws: ws, - streamer: streamer, - playState: playStatePlaying, - log: log, - sID: sID, - playSpeed: 1.0, - } - p.cond = sync.NewCond(&p.mu) - return p -} - -// Play kicks off goroutines for receiving actions -// and playing back the session over the websocket, -// and then waits for the stream to complete. -func (pp *Player) Play(ctx context.Context) { - defer pp.log.Debug("playbackPlayer.Play returned") - - pp.ws.PayloadType = websocket.BinaryFrame - ppCtx, cancel := context.WithCancel(ctx) - defer pp.close(cancel) - - go pp.receiveActions(cancel) - go pp.streamSessionEvents(ppCtx, cancel) - - // Wait until the ctx is canceled, either by - // one of the goroutines above or by the http handler. - <-ppCtx.Done() -} - -type playbackState string - -const ( - playStatePlaying = playbackState("playing") - playStatePaused = playbackState("paused") - playStateFinished = playbackState("finished") -) - // playbackAction identifies a command sent from the // browser to control playback type playbackAction string @@ -125,163 +52,101 @@ const ( // control playback. type actionMessage struct { Action playbackAction `json:"action"` - PlaybackSpeed float32 `json:"speed,omitempty"` -} - -// waitWhilePaused waits idly while the player's state is paused, waiting until: -// - the play state is toggled back to playing -// - the play state is set to finished (the player is closed) -func (pp *Player) waitWhilePaused() { - pp.cond.L.Lock() - defer pp.cond.L.Unlock() - - for pp.playState == playStatePaused { - pp.cond.Wait() - } -} - -// togglePlaying toggles the state of the player between playing and paused, -// and wakes up any goroutines waiting in waitWhilePaused. -func (pp *Player) togglePlaying() { - pp.cond.L.Lock() - defer pp.cond.L.Unlock() - switch pp.playState { - case playStatePlaying: - pp.playState = playStatePaused - case playStatePaused: - pp.playState = playStatePlaying - } - pp.cond.Broadcast() -} - -// close closes the websocket connection, wakes up any goroutines waiting on the playState condition, -// and cancels the playbackPlayer's context. -// -// It should be deferred by all the goroutines that use playbackPlayer, -// in order to ensure that when one goroutine closes, all the others do too. -func (pp *Player) close(cancel context.CancelFunc) { - pp.closeOnce.Do(func() { - pp.mu.Lock() - defer pp.mu.Unlock() - - err := pp.ws.Close() - if err != nil { - pp.log.WithError(err).Errorf("websocket.Close() failed") - } - - pp.playState = playStateFinished - pp.cond.Broadcast() - cancel() - }) + PlaybackSpeed float64 `json:"speed,omitempty"` } -// receiveActions handles logic for receiving playbackAction jsons -// over the websocket and modifying playbackPlayer's state accordingly. -func (pp *Player) receiveActions(cancel context.CancelFunc) { - defer pp.log.Debug("playbackPlayer.ReceiveActions returned") - defer pp.close(cancel) +// ReceivePlaybackActions handles logic for receiving playbackAction messages +// over the websocket and updating the player state accordingly. +func ReceivePlaybackActions( + log logrus.FieldLogger, + ws *websocket.Conn, + player *player.Player) { + // playback always starts in a playing state + playing := true for { var action actionMessage - if err := websocket.JSON.Receive(pp.ws, &action); err != nil { - // We expect net.ErrClosed if the websocket is closed by another - // goroutine and io.EOF if the websocket is closed by the browser - // while websocket.JSON.Receive() is hanging. + + if err := ws.ReadJSON(&action); err != nil { + // Connection close errors are expected if the user closes the tab. + // Only log unexpected errors to avoid cluttering the logs. if !utils.IsOKNetworkError(err) { - pp.log.WithError(err).Error("error reading from websocket") + log.Warnf("websocket read error: %v", err) } return } - pp.log.Debugf("received playback action: %+v", action) + switch action.Action { case actionPlayPause: - pp.togglePlaying() - case actionSpeed: - if action.PlaybackSpeed < minPlaybackSpeed { - action.PlaybackSpeed = minPlaybackSpeed - } else if action.PlaybackSpeed > maxPlaybackSpeed { - action.PlaybackSpeed = maxPlaybackSpeed + if playing { + player.Pause() + } else { + player.Play() } - - pp.mu.Lock() - pp.playSpeed = action.PlaybackSpeed - pp.mu.Unlock() + playing = !playing + case actionSpeed: + action.PlaybackSpeed = max(action.PlaybackSpeed, minPlaybackSpeed) + action.PlaybackSpeed = min(action.PlaybackSpeed, maxPlaybackSpeed) + player.SetSpeed(action.PlaybackSpeed) default: - pp.log.Errorf("received unknown action: %v", action.Action) + log.Warnf("invalid desktop playback action: %v", action.Action) return } } } -// streamSessionEvents streams the session's events as playback events over the websocket. -func (pp *Player) streamSessionEvents(ctx context.Context, cancel context.CancelFunc) { - defer pp.log.Debug("playbackPlayer.StreamSessionEvents returned") - defer pp.close(cancel) - - var lastDelay int64 - scaleDelay := func(delay int64) int64 { - pp.mu.Lock() - defer pp.mu.Unlock() - return int64(float32(delay) / pp.playSpeed) - } - eventsC, errC := pp.streamer.StreamSessionEvents(ctx, session.ID(pp.sID), 0) +// PlayRecording feeds recorded events from a player +// over a websocket. +func PlayRecording( + ctx context.Context, + log logrus.FieldLogger, + ws *websocket.Conn, + player *player.Player) { + player.Play() for { - pp.waitWhilePaused() - select { - case err := <-errC: - if err != nil && !errors.Is(err, context.Canceled) { - pp.log.WithError(err).Errorf("streaming session %v", pp.sID) - var errorText string - if os.IsNotExist(err) || trace.IsNotFound(err) { - errorText = "session not found" - } else { - errorText = "server error" - } - if _, err := pp.ws.Write([]byte(fmt.Sprintf(`{"message": "error", "errorText": "%v"}`, errorText))); err != nil { - pp.log.WithError(err).Error("failed to write \"error\" message over websocket") - } - } + case <-ctx.Done(): return - case evt := <-eventsC: - if evt == nil { - pp.log.Debug("reached end of playback") - if _, err := pp.ws.Write([]byte(`{"message":"end"}`)); err != nil { - pp.log.WithError(err).Error("failed to write \"end\" message over websocket") - } - return - } - switch e := evt.(type) { - case *apievents.DesktopRecording: - if e.DelayMilliseconds > lastDelay { - // TODO(zmb3): replace with time.After so we can cancel - time.Sleep(time.Duration(scaleDelay(e.DelayMilliseconds-lastDelay)) * time.Millisecond) - lastDelay = e.DelayMilliseconds - } - msg, err := utils.FastMarshal(e) - if err != nil { - pp.log.WithError(err).Errorf("failed to marshal DesktopRecording event into JSON: %v", e) - if _, err := pp.ws.Write([]byte(`{"message":"error","errorText":"server error"}`)); err != nil { - pp.log.WithError(err).Error("failed to write \"error\" message over websocket") + case evt, ok := <-player.C(): + if !ok { + if playerErr := player.Err(); playerErr != nil { + // Attempt to JSONify the error (escaping any quotes) + msg, err := json.Marshal(playerErr.Error()) + if err != nil { + log.Warnf("failed to marshal player error message: %v", err) + msg = []byte(`"internal server error"`) } - return - } - if _, err := pp.ws.Write(msg); err != nil { - // We expect net.ErrClosed to arise when another goroutine returns before - // this one or the browser window is closed, both of which cause the websocket to close. - if !errors.Is(err, net.ErrClosed) { - pp.log.WithError(err).Error("failed to write DesktopRecording event over websocket") + //lint:ignore QF1012 this write needs to happen in a single operation + bytes := []byte(fmt.Sprintf(`{"message":"error", "errorText":%s}`, string(msg))) + if err := ws.WriteMessage(websocket.BinaryMessage, bytes); err != nil { + log.Errorf("failed to write error message: %v", err) } return } - case *apievents.WindowsDesktopSessionStart, *apievents.WindowsDesktopSessionEnd: - // these events are part of the stream but never needed for playback - case *apievents.DesktopClipboardReceive, *apievents.DesktopClipboardSend: - // these events are not currently needed for playback, - // but may be useful in the future + if err := ws.WriteMessage(websocket.BinaryMessage, []byte(`{"message":"end"}`)); err != nil { + log.Errorf("failed to write end message: %v", err) + } + return + } - default: - pp.log.Warnf("session %v contains unexpected event type %T", pp.sID, evt) + // some events are part of the stream but not currently + // needed during playback (session start/end, clipboard use, etc) + if _, ok := evt.(*events.DesktopRecording); !ok { + continue + } + msg, err := utils.FastMarshal(evt) + if err != nil { + log.Errorf("failed to marshal desktop event: %v", err) + ws.WriteMessage(websocket.BinaryMessage, []byte(`{"message":"error","errorText":"server error"}`)) + return + } + if err := ws.WriteMessage(websocket.BinaryMessage, msg); err != nil { + // Connection close errors are expected if the user closes the tab. + // Only log unexpected errors to avoid cluttering the logs. + if !utils.IsOKNetworkError(err) { + log.Warnf("websocket write error: %v", err) + } + return } } } diff --git a/lib/web/desktop/playback_test.go b/lib/web/desktop/playback_test.go index 4728b40a6a843..f4fad2a3e7354 100644 --- a/lib/web/desktop/playback_test.go +++ b/lib/web/desktop/playback_test.go @@ -23,11 +23,15 @@ import ( "testing" "time" + "github.com/gorilla/websocket" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/net/websocket" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/events/eventstest" + "github.com/gravitational/teleport/lib/player" + "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/web/desktop" ) @@ -41,43 +45,64 @@ func TestStreamsDesktopEvents(t *testing.T) { &apievents.DesktopRecording{Message: []byte("jkl")}, } s := newServer(t, 20*time.Millisecond, events) - url := strings.Replace(s.URL, "http", "ws", 1) - cfg, err := websocket.NewConfig(url, "http://localhost") - require.NoError(t, err) // connect to the server and verify that we receive // all 4 JSON-encoded events - ws, err := websocket.DialConfig(cfg) + url := strings.Replace(s.URL, "http", "ws", 1) + + // As per https://pkg.go.dev/github.com/gorilla/websocket#Dialer.DialContext: + // "The response body may not contain the entire response and does not need to be closed by the application." + //nolint:bodyclose // false positive + ws, _, err := websocket.DefaultDialer.Dial(url, nil) require.NoError(t, err) + t.Cleanup(func() { ws.Close() }) for _, evt := range events { - b := make([]byte, 4096) - n, err := ws.Read(b) + typ, b, err := ws.ReadMessage() require.NoError(t, err) + require.Equal(t, websocket.BinaryMessage, typ) var dr apievents.DesktopRecording - err = utils.FastUnmarshal(b[:n], &dr) + err = utils.FastUnmarshal(b, &dr) require.NoError(t, err) require.Equal(t, evt.(*apievents.DesktopRecording).Message, dr.Message) } - b := make([]byte, 4096) - n, err := ws.Read(b) + typ, b, err := ws.ReadMessage() require.NoError(t, err) - require.JSONEq(t, `{"message":"end"}`, string(b[:n])) + require.Equal(t, websocket.BinaryMessage, typ) + require.JSONEq(t, `{"message":"end"}`, string(b)) } func newServer(t *testing.T, streamInterval time.Duration, events []apievents.AuditEvent) *httptest.Server { t.Helper() fs := eventstest.NewFakeStreamer(events, streamInterval) + log := utils.NewLoggerForTests() + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - websocket.Handler(func(ws *websocket.Conn) { - desktop.NewPlayer("session-id", ws, fs, utils.NewLoggerForTests()).Play(r.Context()) - }).ServeHTTP(w, r) + upgrader := websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } + ws, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer ws.Close() + + player, err := player.New(&player.Config{ + Clock: clockwork.NewRealClock(), + Log: log, + SessionID: session.ID("session-id"), + Streamer: fs, + }) + assert.NoError(t, err) + player.Play() + desktop.PlayRecording(r.Context(), log, ws, player) })) - t.Cleanup(s.Close) + t.Cleanup(s.Close) return s } diff --git a/lib/web/desktop_playback.go b/lib/web/desktop_playback.go index be2035288e580..69a3b7e999a2f 100644 --- a/lib/web/desktop_playback.go +++ b/lib/web/desktop_playback.go @@ -17,13 +17,16 @@ limitations under the License. package web import ( + "context" "net/http" + "github.com/gorilla/websocket" "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" - "golang.org/x/net/websocket" + "github.com/gravitational/teleport/lib/player" "github.com/gravitational/teleport/lib/reversetunnelclient" + "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/web/desktop" ) @@ -31,22 +34,49 @@ func (h *Handler) desktopPlaybackHandle( w http.ResponseWriter, r *http.Request, p httprouter.Params, - ctx *SessionContext, + sctx *SessionContext, site reversetunnelclient.RemoteSite, + ws *websocket.Conn, ) (interface{}, error) { sID := p.ByName("sid") if sID == "" { - return nil, trace.BadParameter("missing sid in request URL") + return nil, trace.BadParameter("missing session ID in request URL") } - clt, err := ctx.GetUserClient(r.Context(), site) + clt, err := sctx.GetUserClient(r.Context(), site) if err != nil { return nil, trace.Wrap(err) } - websocket.Handler(func(ws *websocket.Conn) { - defer h.log.Debug("desktopPlaybackHandle websocket handler returned") - desktop.NewPlayer(sID, ws, clt, h.log).Play(r.Context()) - }).ServeHTTP(w, r) + player, err := player.New(&player.Config{ + Clock: h.clock, + Log: h.log, + SessionID: session.ID(sID), + Streamer: clt, + }) + if err != nil { + h.log.Errorf("couldn't create player for session %v: %v", sID, err) + ws.WriteMessage(websocket.BinaryMessage, + []byte(`{"message": "error", "errorText": "Internal server error"}`)) + return nil, nil + } + + defer player.Close() + + ctx, cancel := context.WithCancel(r.Context()) + defer cancel() + + go func() { + defer cancel() + desktop.ReceivePlaybackActions(h.log, ws, player) + }() + + go func() { + defer cancel() + defer ws.Close() + desktop.PlayRecording(ctx, h.log, ws, player) + }() + + <-ctx.Done() return nil, nil } diff --git a/lib/web/terminal.go b/lib/web/terminal.go index 56fcaad49c50f..70bb69fdffb58 100644 --- a/lib/web/terminal.go +++ b/lib/web/terminal.go @@ -137,6 +137,7 @@ func NewTerminal(ctx context.Context, cfg TerminalHandlerConfig) (*TerminalHandl participantMode: cfg.ParticipantMode, tracker: cfg.Tracker, clock: cfg.Clock, + websocketConn: cfg.WebsocketConn, }, nil } @@ -181,6 +182,8 @@ type TerminalHandlerConfig struct { Tracker types.SessionTracker // Clock used for presence checking. Clock clockwork.Clock + // WebsocketConn is the active websocket connection + WebsocketConn *websocket.Conn } func (t *TerminalHandlerConfig) CheckAndSetDefaults() error { @@ -283,12 +286,15 @@ type TerminalHandler struct { // if the user is not joining a session. tracker types.SessionTracker - // clock to use for presence checking - clock clockwork.Clock - // closedByClient indicates if the websocket connection was closed by the // user (closing the browser tab, exiting the session, etc). closedByClient atomic.Bool + + // clock used to interact with time. + clock clockwork.Clock + + // websocketConn is the active websocket connection + websocketConn *websocket.Conn } // ServeHTTP builds a connection to the remote node and then pumps back two types of @@ -300,21 +306,9 @@ func (t *TerminalHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { t.ctx.AddClosers(t) defer t.ctx.RemoveCloser(t) - upgrader := websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { return true }, - } - - ws, err := upgrader.Upgrade(w, r, nil) - if err != nil { - errMsg := "Error upgrading to websocket" - t.log.WithError(err).Error(errMsg) - http.Error(w, errMsg, http.StatusInternalServerError) - return - } + ws := t.websocketConn - err = ws.SetReadDeadline(deadlineForInterval(t.keepAliveInterval)) + err := ws.SetReadDeadline(deadlineForInterval(t.keepAliveInterval)) if err != nil { t.log.WithError(err).Error("Error setting websocket readline") return diff --git a/web/packages/teleport/src/Assist/context/AssistContext.tsx b/web/packages/teleport/src/Assist/context/AssistContext.tsx index 349d3d713f453..b1b97a0f4684c 100644 --- a/web/packages/teleport/src/Assist/context/AssistContext.tsx +++ b/web/packages/teleport/src/Assist/context/AssistContext.tsx @@ -29,7 +29,7 @@ import { AssistStateActionType, reducer } from 'teleport/Assist/context/state'; import { convertServerMessages } from 'teleport/Assist/context/utils'; import useStickyClusterId from 'teleport/useStickyClusterId'; import cfg from 'teleport/config'; -import { getAccessToken, getHostName } from 'teleport/services/api'; +import { getHostName } from 'teleport/services/api'; import { ExecutionEnvelopeType, @@ -45,6 +45,7 @@ import { makeMfaAuthenticateChallenge, WebauthnAssertionResponse, } from 'teleport/services/auth'; +import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket'; import * as service from '../service'; import { @@ -79,9 +80,9 @@ let lastCommandExecutionResultId = 0; const TEN_MINUTES = 10 * 60 * 1000; export function AssistContextProvider(props: PropsWithChildren) { - const activeWebSocket = useRef(null); + const activeWebSocket = useRef(null); // TODO(ryan): this should be removed once https://github.com/gravitational/teleport.e/pull/1609 is implemented - const executeCommandWebSocket = useRef(null); + const executeCommandWebSocket = useRef(null); const refreshWebSocketTimeout = useRef(null); const { clusterId } = useStickyClusterId(); @@ -119,11 +120,10 @@ export function AssistContextProvider(props: PropsWithChildren) { } function setupWebSocket(conversationId: string, initialMessage?: string) { - activeWebSocket.current = new WebSocket( + activeWebSocket.current = new AuthenticatedWebSocket( cfg.getAssistConversationWebSocketUrl( getHostName(), clusterId, - getAccessToken(), conversationId ) ); @@ -326,7 +326,7 @@ export function AssistContextProvider(props: PropsWithChildren) { if ( !activeWebSocket.current || - activeWebSocket.current.readyState === WebSocket.CLOSED + activeWebSocket.current.readyState === AuthenticatedWebSocket.CLOSED ) { setupWebSocket(state.conversations.selectedId, data); } else { @@ -356,7 +356,8 @@ export function AssistContextProvider(props: PropsWithChildren) { function sendMfaChallenge(data: WebauthnAssertionResponse) { if ( !executeCommandWebSocket.current || - executeCommandWebSocket.current.readyState !== WebSocket.OPEN || + executeCommandWebSocket.current.readyState !== + AuthenticatedWebSocket.OPEN || !data ) { console.warn( @@ -424,12 +425,11 @@ export function AssistContextProvider(props: PropsWithChildren) { const url = cfg.getAssistExecuteCommandUrl( getHostName(), clusterId, - getAccessToken(), execParams ); const proto = new Protobuf(); - executeCommandWebSocket.current = new WebSocket(url); + executeCommandWebSocket.current = new AuthenticatedWebSocket(url); executeCommandWebSocket.current.binaryType = 'arraybuffer'; executeCommandWebSocket.current.onmessage = event => { diff --git a/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx b/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx index 9af1886c7a1a9..0b6535bdc4c5a 100644 --- a/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx +++ b/web/packages/teleport/src/Console/DocumentSsh/TerminalAssist/TerminalAssistContext.tsx @@ -24,7 +24,7 @@ import React, { } from 'react'; import { Author, ServerMessage } from 'teleport/Assist/types'; -import { getAccessToken, getHostName } from 'teleport/services/api'; +import { getHostName } from 'teleport/services/api'; import useStickyClusterId from 'teleport/useStickyClusterId'; import cfg from 'teleport/config'; import { @@ -34,6 +34,7 @@ import { SuggestedCommandMessage, UserMessage, } from 'teleport/Console/DocumentSsh/TerminalAssist/types'; +import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket'; interface TerminalAssistContextValue { close: () => void; @@ -55,11 +56,10 @@ export function TerminalAssistContextProvider( const [visible, setVisible] = useState(false); - const socketRef = useRef(null); + const socketRef = useRef(null); const socketUrl = cfg.getAssistActionWebSocketUrl( getHostName(), clusterId, - getAccessToken(), 'ssh-cmdgen' ); @@ -70,7 +70,7 @@ export function TerminalAssistContextProvider( const [messages, setMessages] = useState([]); useEffect(() => { - socketRef.current = new WebSocket(socketUrl); + socketRef.current = new AuthenticatedWebSocket(socketUrl); socketRef.current.onmessage = e => { const data = JSON.parse(e.data) as ServerMessage; @@ -115,11 +115,10 @@ export function TerminalAssistContextProvider( const socketUrl = cfg.getAssistActionWebSocketUrl( getHostName(), clusterId, - getAccessToken(), 'ssh-explain' ); - const ws = new WebSocket(socketUrl); + const ws = new AuthenticatedWebSocket(socketUrl); ws.onopen = () => { ws.send(encodedOutput); diff --git a/web/packages/teleport/src/Console/consoleContext.tsx b/web/packages/teleport/src/Console/consoleContext.tsx index 519a63890b61d..c39a0eb3688f8 100644 --- a/web/packages/teleport/src/Console/consoleContext.tsx +++ b/web/packages/teleport/src/Console/consoleContext.tsx @@ -22,7 +22,7 @@ import { W3CTraceContextPropagator } from '@opentelemetry/core'; import webSession from 'teleport/services/websession'; import history from 'teleport/services/history'; import cfg, { UrlResourcesParams, UrlSshParams } from 'teleport/config'; -import { getAccessToken, getHostName } from 'teleport/services/api'; +import { getHostName } from 'teleport/services/api'; import Tty from 'teleport/lib/term/tty'; import TtyAddressResolver from 'teleport/lib/term/ttyAddressResolver'; import serviceSession, { @@ -193,7 +193,6 @@ export default class ConsoleContext { const ttyUrl = cfg.api.ttyWsAddr .replace(':fqdn', getHostName()) - .replace(':token', getAccessToken()) .replace(':clusterId', clusterId) .replace(':traceparent', carrier['traceparent']); diff --git a/web/packages/teleport/src/DesktopSession/useTdpClientCanvas.tsx b/web/packages/teleport/src/DesktopSession/useTdpClientCanvas.tsx index 3e0afce12cef8..97a0b78277960 100644 --- a/web/packages/teleport/src/DesktopSession/useTdpClientCanvas.tsx +++ b/web/packages/teleport/src/DesktopSession/useTdpClientCanvas.tsx @@ -22,7 +22,7 @@ import { getPlatformType } from 'design/platform'; import { TdpClient, ButtonState, ScrollAxis } from 'teleport/lib/tdp'; import { ClipboardData, PngFrame } from 'teleport/lib/tdp/codec'; -import { getAccessToken, getHostName } from 'teleport/services/api'; +import { getHostName } from 'teleport/services/api'; import cfg from 'teleport/config'; import { Sha256Digest } from 'teleport/lib/util'; @@ -58,7 +58,6 @@ export default function useTdpClientCanvas(props: Props) { .replace(':fqdn', getHostName()) .replace(':clusterId', clusterId) .replace(':desktopName', desktopName) - .replace(':token', getAccessToken()) .replace(':username', username) .replace(':width', width.toString()) .replace(':height', height.toString()); diff --git a/web/packages/teleport/src/Player/DesktopPlayer.tsx b/web/packages/teleport/src/Player/DesktopPlayer.tsx index 280af21335219..3a112addb6f41 100644 --- a/web/packages/teleport/src/Player/DesktopPlayer.tsx +++ b/web/packages/teleport/src/Player/DesktopPlayer.tsx @@ -21,7 +21,7 @@ import useAttempt from 'shared/hooks/useAttemptNext'; import cfg from 'teleport/config'; import { PlayerClient, PlayerClientEvent } from 'teleport/lib/tdp'; -import { getAccessToken, getHostName } from 'teleport/services/api'; +import { getHostName } from 'teleport/services/api'; import TdpClientCanvas from 'teleport/components/TdpClientCanvas'; import { ProgressBarDesktop } from './ProgressBar'; @@ -110,7 +110,6 @@ const useDesktopPlayer = ({ .replace(':fqdn', getHostName()) .replace(':clusterId', clusterId) .replace(':sid', sid) - .replace(':token', getAccessToken()) ) ); }, [clusterId, sid]); diff --git a/web/packages/teleport/src/config.ts b/web/packages/teleport/src/config.ts index e5813ffb81aa3..b66ed5c323838 100644 --- a/web/packages/teleport/src/config.ts +++ b/web/packages/teleport/src/config.ts @@ -168,12 +168,14 @@ const cfg = { desktopServicesPath: `/v1/webapi/sites/:clusterId/desktopservices?searchAsRoles=:searchAsRoles?&limit=:limit?&startKey=:startKey?&query=:query?&search=:search?&sort=:sort?`, desktopPath: `/v1/webapi/sites/:clusterId/desktops/:desktopName`, desktopWsAddr: - 'wss://:fqdn/v1/webapi/sites/:clusterId/desktops/:desktopName/connect?access_token=:token&username=:username&width=:width&height=:height', + 'wss://:fqdn/v1/webapi/sites/:clusterId/desktops/:desktopName/connect/ws?username=:username&width=:width&height=:height', desktopPlaybackWsAddr: - 'wss://:fqdn/v1/webapi/sites/:clusterId/desktopplayback/:sid?access_token=:token', + 'wss://:fqdn/v1/webapi/sites/:clusterId/desktopplayback/:sid/ws', desktopIsActive: '/v1/webapi/sites/:clusterId/desktops/:desktopName/active', ttyWsAddr: - 'wss://:fqdn/v1/webapi/sites/:clusterId/connect?access_token=:token¶ms=:params&traceparent=:traceparent', + 'wss://:fqdn/v1/webapi/sites/:clusterId/connect/ws?params=:params&traceparent=:traceparent', + ttyPlaybackWsAddr: + 'wss://:fqdn/v1/webapi/sites/:clusterId/ttyplayback/:sid?access_token=:token', // TODO(zmb3): get token out of URL activeAndPendingSessionsPath: '/v1/webapi/sites/:clusterId/sessions', sshPlaybackPrefix: '/v1/webapi/sites/:clusterId/sessions/:sid', // prefix because this is eventually concatenated with "/stream" or "/events" kubernetesPath: @@ -259,11 +261,11 @@ const cfg = { '/v1/webapi/assistant/conversations/:conversationId/title', assistGenerateSummaryPath: '/v1/webapi/assistant/title/summary', assistConversationWebSocketPath: - 'wss://:hostname/v1/webapi/sites/:clusterId/assistant', + 'wss://:hostname/v1/webapi/sites/:clusterId/assistant/ws', assistConversationHistoryPath: '/v1/webapi/assistant/conversations/:conversationId', assistExecuteCommandWebSocketPath: - 'wss://:hostname/v1/webapi/command/:clusterId/execute', + 'wss://:hostname/v1/webapi/command/:clusterId/execute/ws', userPreferencesPath: '/v1/webapi/user/preferences', }, @@ -741,12 +743,10 @@ const cfg = { getAssistConversationWebSocketUrl( hostname: string, clusterId: string, - accessToken: string, conversationId: string ) { const searchParams = new URLSearchParams(); - searchParams.set('access_token', accessToken); searchParams.set('conversation_id', conversationId); return ( @@ -760,12 +760,10 @@ const cfg = { getAssistActionWebSocketUrl( hostname: string, clusterId: string, - accessToken: string, action: string ) { const searchParams = new URLSearchParams(); - searchParams.set('access_token', accessToken); searchParams.set('action', action); return ( @@ -785,12 +783,10 @@ const cfg = { getAssistExecuteCommandUrl( hostname: string, clusterId: string, - accessToken: string, params: Record ) { const searchParams = new URLSearchParams(); - searchParams.set('access_token', accessToken); searchParams.set('params', JSON.stringify(params)); return ( diff --git a/web/packages/teleport/src/lib/AuthenticatedWebSocket.ts b/web/packages/teleport/src/lib/AuthenticatedWebSocket.ts new file mode 100644 index 0000000000000..4c1d0c4e5e281 --- /dev/null +++ b/web/packages/teleport/src/lib/AuthenticatedWebSocket.ts @@ -0,0 +1,279 @@ +/** + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +import { getAccessToken } from 'teleport/services/api'; +import { WebsocketStatus } from 'teleport/types'; + +/** + * `AuthenticatedWebSocket` is a drop-in replacement for + * the `WebSocket` class that handles Teleport's websocket + * authentication process. + */ +export class AuthenticatedWebSocket extends WebSocket { + private authenticated: boolean = false; + private openListeners: ((this: WebSocket, ev: Event) => any)[] = []; + private onopenInternal: ((this: WebSocket, ev: Event) => any) | null = null; + private messageListeners: ((this: WebSocket, ev: MessageEvent) => any)[] = []; + private onmessageInternal: + | ((this: WebSocket, ev: MessageEvent) => any) + | null = null; + private oncloseListeners: ((this: WebSocket, ev: CloseEvent) => any)[] = []; + private oncloseInternal: ((this: WebSocket, ev: CloseEvent) => any) | null = + null; + private onerrorListeners: ((this: WebSocket, ev: Event) => any)[] = []; + private onerrorInternal: ((this: WebSocket, ev: Event) => any) | null = null; + private binaryTypeInternal: BinaryType = 'blob'; // Default binaryType + private onopenEvent: Event | null = null; + + constructor(url: string | URL, protocols?: string | string[]) { + super(url, protocols); + // Set the binaryType to 'arraybuffer' to handle the authentication process. + super.binaryType = 'arraybuffer'; + + // The open event listener should immediately send the authentication token + super.onopen = (onopenEvent: Event) => { + super.send(JSON.stringify({ token: getAccessToken() })); + // Don't call the user defined onopen messages yet, wait for the authentication response. + this.onopenEvent = onopenEvent; + }; + + // The message event listener should handle the authentication response, + // and if it succeeds, set the binaryType to the user-defined value and + // trigger any user-added open listeners. + super.onmessage = (ev: MessageEvent) => { + // If not yet authenticated, handle the authentication response. + if (!this.authenticated) { + // Parse the message as a WebsocketStatus. + let authResponse: WebsocketStatus; + try { + authResponse = JSON.parse(ev.data) as WebsocketStatus; + } catch (e) { + this.triggerError('Error parsing JSON from websocket message: ' + e); + return; + } + + // Validate the WebsocketStatus. + if ( + !authResponse.type || + !authResponse.status || + !(authResponse.type === 'create_session_response') || + !(authResponse.status === 'ok' || authResponse.status === 'error') + ) { + this.triggerError( + 'Invalid auth response: ' + JSON.stringify(authResponse) + ); + return; + } + + // Authentication succeeded. + if (authResponse.status === 'ok') { + this.authenticated = true; + // Set the binaryType to the value set by the user (or back to the default 'blob'). + super.binaryType = this.binaryTypeInternal; + // Now that authentication is complete, trigger any user-added open listeners + // with the original onopen event. + this.openListeners.forEach(listener => + listener.call(this, this.onopenEvent) + ); + this.onopenInternal?.call(this, this.onopenEvent); + return; + } else { + // Authentication failed, authResponse.status === 'error'. + this.triggerError( + 'auth error connecting to websocket: ' + authResponse.message + ); + return; + } + } else { + // If authenticated, pass messages to user-added listeners. + this.messageListeners.forEach(listener => { + listener.call(this, ev); + }); + this.onmessageInternal?.call(this, ev); + } + }; + + // Set the 'close' event for cleanup. + super.onclose = (ev: CloseEvent) => { + // Trigger any user-added close listeners + this.oncloseListeners.forEach(listener => listener.call(this, ev)); + this.oncloseInternal?.call(this, ev); + this.authenticated = false; + }; + + // Set the 'error' event for cleanup. + super.onerror = (ev: Event) => { + // Trigger any user-added error listeners + this.onerrorListeners.forEach(listener => listener.call(this, ev)); + this.onerrorInternal?.call(this, ev); + this.authenticated = false; + }; + } + + // Authenticated send + override send(data: string | ArrayBufferLike | Blob | ArrayBufferView): void { + if (!this.authenticated) { + // This should be unreachable, but just in case. + this.triggerError( + 'Cannot send data before authentication is complete. Data: ' + data + ); + return; + } + super.send(data); + } + + // Override addEventListener to intercept these listeners and store them in + // our appropriate arrays. They are called in the appropriate places in the + // `onopen`, `onmessage`, `onclose`, and `onerror` methods set in the constructor. + override addEventListener( + type: K, + listener: (this: WebSocket, ev: WebSocketEventMap[K]) => any + ): void { + if (type === 'open') { + this.openListeners.push( + listener as (this: WebSocket, ev: WebSocketEventMap['open']) => any + ); + } else if (type === 'message') { + this.messageListeners.push( + listener as (this: WebSocket, ev: WebSocketEventMap['message']) => any + ); + } else if (type === 'close') { + this.oncloseListeners.push( + listener as (this: WebSocket, ev: WebSocketEventMap['close']) => any + ); + } else if (type === 'error') { + this.onerrorListeners.push( + listener as (this: WebSocket, ev: WebSocketEventMap['error']) => any + ); + } else { + // This should be unreachable, but just in case. + super.addEventListener(type, listener); + } + } + + // Override the onopen, onmessage, onclose, and onerror properties to store the user-defined + // listeners in the appropriate internal properties. These are called in the appropriate places + // in the `onopen`, `onmessage`, `onclose`, and `onerror` methods set in the constructor. + + override set onopen(listener: (this: WebSocket, ev: Event) => any | null) { + this.onopenInternal = listener; + } + + override get onopen(): ((this: WebSocket, ev: Event) => any) | null { + return this.onopenInternal; + } + + override set onmessage( + listener: ((this: WebSocket, ev: MessageEvent) => any) | null + ) { + this.onmessageInternal = listener; + } + + override get onmessage(): + | ((this: WebSocket, ev: MessageEvent) => any) + | null { + return this.onmessageInternal; + } + + override set onclose( + listener: ((this: WebSocket, ev: CloseEvent) => any) | null + ) { + this.oncloseInternal = listener; + } + + override get onclose(): ((this: WebSocket, ev: CloseEvent) => any) | null { + return this.oncloseInternal; + } + + override set onerror(listener: ((this: WebSocket, ev: Event) => any) | null) { + this.onerrorInternal = listener; + } + + override get onerror(): ((this: WebSocket, ev: Event) => any) | null { + return this.onerrorInternal; + } + + // Override the binaryType property to store the user-defined binaryType in the appropriate internal property. + // This is because we need to set the binaryType to 'arraybuffer' for the authentication process (see constructor), + // and only then can we set it to the user-defined value. + override set binaryType(binaryType: BinaryType) { + if (this.authenticated) { + super.binaryType = binaryType; + return; + } + + this.binaryTypeInternal = binaryType; + } + + override get binaryType(): BinaryType { + return this.binaryTypeInternal; + } + + // Override removeEventListener to support listeners removal for 'open', 'message', and 'close' events + override removeEventListener( + type: K, + listener: (this: WebSocket, ev: WebSocketEventMap[K]) => any + ): void { + if (type === 'open') { + const index = this.openListeners.indexOf( + listener as (this: WebSocket, ev: WebSocketEventMap['open']) => any + ); + if (index !== -1) { + this.openListeners.splice(index, 1); + } + } else if (type === 'message') { + const index = this.messageListeners.indexOf( + listener as (this: WebSocket, ev: WebSocketEventMap['message']) => any + ); + if (index !== -1) { + this.messageListeners.splice(index, 1); + } + } else if (type === 'close') { + const index = this.oncloseListeners.indexOf( + listener as (this: WebSocket, ev: WebSocketEventMap['close']) => any + ); + if (index !== -1) { + this.oncloseListeners.splice(index, 1); + } + } else if (type === 'error') { + const index = this.onerrorListeners.indexOf( + listener as (this: WebSocket, ev: WebSocketEventMap['error']) => any + ); + if (index !== -1) { + this.onerrorListeners.splice(index, 1); + } + } else { + // This should be unreachable, but just in case. + super.removeEventListener( + type, + listener as EventListenerOrEventListenerObject + ); + } + } + + // Method to manually trigger an error event. + private triggerError(errorMessage: string): void { + const errorEvent = new ErrorEvent('error', { + error: new Error(errorMessage), + message: errorMessage, + }); + + // Dispatch the event to trigger all listeners attached for 'error' events. + this.dispatchEvent(errorEvent); + } +} diff --git a/web/packages/teleport/src/lib/tdp/client.ts b/web/packages/teleport/src/lib/tdp/client.ts index fd59ba4b5337b..3e54d04ea53d2 100644 --- a/web/packages/teleport/src/lib/tdp/client.ts +++ b/web/packages/teleport/src/lib/tdp/client.ts @@ -15,6 +15,7 @@ import Logger from 'shared/libs/logger'; import { WebsocketCloseCode, TermEvent } from 'teleport/lib/term/enums'; import { EventEmitterWebAuthnSender } from 'teleport/lib/EventEmitterWebAuthnSender'; +import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket'; import Codec, { MessageType, @@ -63,12 +64,12 @@ export enum TdpClientEvent { } // Client is the TDP client. It is responsible for connecting to a websocket serving the tdp server, -// sending client commands, and recieving and processing server messages. Its creator is responsible for +// sending client commands, and receiving and processing server messages. Its creator is responsible for // ensuring the websocket gets closed and all of its event listeners cleaned up when it is no longer in use. // For convenience, this can be done in one fell swoop by calling Client.shutdown(). export default class Client extends EventEmitterWebAuthnSender { protected codec: Codec; - protected socket: WebSocket | undefined; + protected socket: AuthenticatedWebSocket | undefined; private socketAddr: string; private sdManager: SharedDirectoryManager; @@ -83,7 +84,7 @@ export default class Client extends EventEmitterWebAuthnSender { // Connect to the websocket and register websocket event handlers. init() { - this.socket = new WebSocket(this.socketAddr); + this.socket = new AuthenticatedWebSocket(this.socketAddr); this.socket.binaryType = 'arraybuffer'; this.socket.onopen = () => { diff --git a/web/packages/teleport/src/lib/term/tty.ts b/web/packages/teleport/src/lib/term/tty.ts index 852898ef020d9..ee4b9f1dd4904 100644 --- a/web/packages/teleport/src/lib/term/tty.ts +++ b/web/packages/teleport/src/lib/term/tty.ts @@ -18,6 +18,7 @@ import Logger from 'shared/libs/logger'; import { EventEmitterWebAuthnSender } from 'teleport/lib/EventEmitterWebAuthnSender'; import { WebauthnAssertionResponse } from 'teleport/services/auth'; +import { AuthenticatedWebSocket } from 'teleport/lib/AuthenticatedWebSocket'; import { EventType, TermEvent, WebsocketCloseCode } from './enums'; import { Protobuf, MessageTypeEnum } from './protobuf'; @@ -60,7 +61,7 @@ class Tty extends EventEmitterWebAuthnSender { connect(w: number, h: number) { const connStr = this._addressResolver.getConnStr(w, h); - this.socket = new WebSocket(connStr); + this.socket = new AuthenticatedWebSocket(connStr); this.socket.binaryType = 'arraybuffer'; this.socket.onopen = this._onOpenConnection; this.socket.onmessage = this._onMessage; diff --git a/web/packages/teleport/src/types.ts b/web/packages/teleport/src/types.ts index 250287fd74b75..485af3208eebe 100644 --- a/web/packages/teleport/src/types.ts +++ b/web/packages/teleport/src/types.ts @@ -178,3 +178,11 @@ export enum RecommendationStatus { Notify = 'NOTIFY', Done = 'DONE', } + +// WebsocketStatus is used to indicate the auth status from a +// websocket connection +export type WebsocketStatus = { + type: string; + status: string; + message?: string; +};