Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions lib/web/desktop/playback.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ import (
"encoding/json"
"fmt"

"github.com/gorilla/websocket"
"github.com/sirupsen/logrus"
"golang.org/x/net/websocket"

"github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/player"
Expand Down Expand Up @@ -68,7 +68,8 @@ func ReceivePlaybackActions(

for {
var action actionMessage
if err := websocket.JSON.Receive(ws, &action); err != nil {

if err := ws.ReadJSON(&action); err != nil {
// Connection close errors are expected if the user closes the tab.
// Only log unexpected errors to avoid cluttering the logs.
if !utils.IsOKNetworkError(err) {
Expand Down Expand Up @@ -118,12 +119,13 @@ func PlayRecording(
msg = []byte(`"internal server error"`)
}
//lint:ignore QF1012 this write needs to happen in a single operation
if _, err := ws.Write([]byte(fmt.Sprintf(`{"message":"error", "errorText":%s}`, string(msg)))); err != nil {
bytes := []byte(fmt.Sprintf(`{"message":"error", "errorText":%s}`, string(msg)))
if err := ws.WriteMessage(websocket.BinaryMessage, bytes); err != nil {
log.Errorf("failed to write error message: %v", err)
}
return
}
if _, err := ws.Write([]byte(`{"message":"end"}`)); err != nil {
if err := ws.WriteMessage(websocket.BinaryMessage, []byte(`{"message":"end"}`)); err != nil {
log.Errorf("failed to write end message: %v", err)
}
return
Expand All @@ -137,10 +139,10 @@ func PlayRecording(
msg, err := utils.FastMarshal(evt)
if err != nil {
log.Errorf("failed to marshal desktop event: %v", err)
ws.Write([]byte(`{"message":"error","errorText":"server error"}`))
ws.WriteMessage(websocket.BinaryMessage, []byte(`{"message":"error","errorText":"server error"}`))
return
}
if _, err := ws.Write(msg); err != nil {
if err := ws.WriteMessage(websocket.BinaryMessage, msg); err != nil {
// Connection close errors are expected if the user closes the tab.
// Only log unexpected errors to avoid cluttering the logs.
if !utils.IsOKNetworkError(err) {
Expand Down
57 changes: 34 additions & 23 deletions lib/web/desktop/playback_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ import (
"testing"
"time"

"github.com/gorilla/websocket"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/websocket"

apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/events/eventstest"
Expand All @@ -47,31 +47,34 @@ func TestStreamsDesktopEvents(t *testing.T) {
&apievents.DesktopRecording{Message: []byte("jkl")},
}
s := newServer(t, 20*time.Millisecond, events)
url := strings.Replace(s.URL, "http", "ws", 1)
cfg, err := websocket.NewConfig(url, "http://localhost")
require.NoError(t, err)

// connect to the server and verify that we receive
// all 4 JSON-encoded events
ws, err := websocket.DialConfig(cfg)
url := strings.Replace(s.URL, "http", "ws", 1)

// As per https://pkg.go.dev/github.com/gorilla/websocket#Dialer.DialContext:
// "The response body may not contain the entire response and does not need to be closed by the application."
//nolint:bodyclose // false positive
ws, _, err := websocket.DefaultDialer.Dial(url, nil)
require.NoError(t, err)

t.Cleanup(func() { ws.Close() })

for _, evt := range events {
b := make([]byte, 4096)
n, err := ws.Read(b)
typ, b, err := ws.ReadMessage()
require.NoError(t, err)
require.Equal(t, websocket.BinaryMessage, typ)

var dr apievents.DesktopRecording
err = utils.FastUnmarshal(b[:n], &dr)
err = utils.FastUnmarshal(b, &dr)
require.NoError(t, err)
require.Equal(t, evt.(*apievents.DesktopRecording).Message, dr.Message)
}

b := make([]byte, 4096)
n, err := ws.Read(b)
typ, b, err := ws.ReadMessage()
require.NoError(t, err)
require.JSONEq(t, `{"message":"end"}`, string(b[:n]))
require.Equal(t, websocket.BinaryMessage, typ)
require.JSONEq(t, `{"message":"end"}`, string(b))
}

func newServer(t *testing.T, streamInterval time.Duration, events []apievents.AuditEvent) *httptest.Server {
Expand All @@ -81,19 +84,27 @@ func newServer(t *testing.T, streamInterval time.Duration, events []apievents.Au
log := utils.NewLoggerForTests()

s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
websocket.Handler(func(ws *websocket.Conn) {
player, err := player.New(&player.Config{
Clock: clockwork.NewRealClock(),
Log: log,
SessionID: session.ID("session-id"),
Streamer: fs,
})
assert.NoError(t, err)
player.Play()
desktop.PlayRecording(r.Context(), log, ws, player)
}).ServeHTTP(w, r)
upgrader := websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
}
ws, err := upgrader.Upgrade(w, r, nil)
if err != nil {
return
}
defer ws.Close()

player, err := player.New(&player.Config{
Clock: clockwork.NewRealClock(),
Log: log,
SessionID: session.ID("session-id"),
Streamer: fs,
})
assert.NoError(t, err)
player.Play()
desktop.PlayRecording(r.Context(), log, ws, player)
}))
t.Cleanup(s.Close)

t.Cleanup(s.Close)
return s
}
64 changes: 35 additions & 29 deletions lib/web/desktop_playback.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ import (
"context"
"net/http"

"github.com/gorilla/websocket"
"github.com/gravitational/trace"
"github.com/julienschmidt/httprouter"
"golang.org/x/net/websocket"

"github.com/gravitational/teleport/lib/player"
"github.com/gravitational/teleport/lib/reversetunnelclient"
Expand All @@ -49,39 +49,45 @@ func (h *Handler) desktopPlaybackHandle(
return nil, trace.Wrap(err)
}

websocket.Handler(func(ws *websocket.Conn) {
ws.PayloadType = websocket.BinaryFrame

player, err := player.New(&player.Config{
Clock: h.clock,
Log: h.log,
SessionID: session.ID(sID),
Streamer: clt,
})
if err != nil {
h.log.Errorf("couldn't create player for session %v: %v", sID, err)
ws.Write([]byte(`{"message": "error", "errorText": "Internal server error"}`))
return
}
upgrader := websocket.Upgrader{
ReadBufferSize: 4096,
WriteBufferSize: 4096,
}
ws, err := upgrader.Upgrade(w, r, nil)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not very familiar with this code, but WS after it's created with gorilla library should:

  1. Set ping/pong handler, so the connection won't terminate because of timeout
  2. Send WS 1006 message (WS connection closed) before closing the connection. I don't see it anywhere in this PR
  3. Our implementation in other places sets some custom timeouts.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. I will look into whether ping/pong handler is required here. We didn't have any sort of ping happening with x/net/websocket so I suspect its fine.
  2. This is completely optional and really only makes sense when you need graceful shutdown or when one side of the connection is going to shut down before the other. We use ws.Close() (which doesn't send a 1006) already in lib/web/terminal and lib/web/assistant.
  3. I don't see any reason why that is necessary here.

Copy link
Copy Markdown
Contributor

@rosstimothy rosstimothy Jan 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

gorilla sets a default ping and pong handler when connections are created

if err != nil {
return nil, trace.Wrap(err)
}
defer ws.Close()

defer player.Close()
player, err := player.New(&player.Config{
Clock: h.clock,
Log: h.log,
SessionID: session.ID(sID),
Streamer: clt,
})
if err != nil {
h.log.Errorf("couldn't create player for session %v: %v", sID, err)
ws.WriteMessage(websocket.BinaryMessage,
[]byte(`{"message": "error", "errorText": "Internal server error"}`))
return nil, nil
}

ctx, cancel := context.WithCancel(r.Context())
defer cancel()
defer player.Close()

go func() {
defer cancel()
desktop.ReceivePlaybackActions(h.log, ws, player)
}()
ctx, cancel := context.WithCancel(r.Context())
defer cancel()

go func() {
defer cancel()
defer ws.Close()
desktop.PlayRecording(ctx, h.log, ws, player)
}()
go func() {
defer cancel()
desktop.ReceivePlaybackActions(h.log, ws, player)
}()

<-ctx.Done()
}).ServeHTTP(w, r)
go func() {
defer cancel()
defer ws.Close()
desktop.PlayRecording(ctx, h.log, ws, player)
}()

<-ctx.Done()
return nil, nil
}