Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions lib/web/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ func (h *Handler) executeCommand(
CheckOrigin: func(r *http.Request) bool { return true },
}

ws, err := upgrader.Upgrade(w, r, nil)
rawWS, err := upgrader.Upgrade(w, r, nil)
if err != nil {
errMsg := "Error upgrading to websocket"
h.log.WithError(err).Error(errMsg)
Expand All @@ -174,16 +174,27 @@ func (h *Handler) executeCommand(
}

defer func() {
ws.WriteMessage(websocket.CloseMessage, nil)
ws.Close()
rawWS.WriteMessage(websocket.CloseMessage, nil)
rawWS.Close()
}()

keepAliveInterval := netConfig.GetKeepAliveInterval()
err = ws.SetReadDeadline(deadlineForInterval(keepAliveInterval))
err = rawWS.SetReadDeadline(deadlineForInterval(keepAliveInterval))
if err != nil {
h.log.WithError(err).Error("Error setting websocket readline")
return nil, trace.Wrap(err)
}
// Update the read deadline upon receiving a pong message.
rawWS.SetPongHandler(func(_ string) error {
// This is intentonally called without a lock as this callback is
// called from the same goroutine as the read loop which is already locked.
return trace.Wrap(rawWS.SetReadDeadline(deadlineForInterval(keepAliveInterval)))
})

// Wrap the raw websocket connection in a syncRWWSConn so that we can
// safely read and write to the the single websocket connection from
// multiple goroutines/execution nodes.
ws := &syncRWWSConn{WSConn: rawWS}

hosts, err := findByQuery(ctx, clt, req.Query)
if err != nil {
Expand Down Expand Up @@ -506,11 +517,6 @@ func (t *commandHandler) handler(r *http.Request) {

t.log.Debug("Creating websocket stream")

// Update the read deadline upon receiving a pong message.
t.ws.SetPongHandler(func(_ string) error {
return trace.Wrap(t.ws.SetReadDeadline(deadlineForInterval(t.keepAliveInterval)))
})

// Start sending ping frames through websocket to the client.
go startPingLoop(r.Context(), t.ws, t.keepAliveInterval, t.log, t.Close)

Expand Down
37 changes: 31 additions & 6 deletions lib/web/command_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ import (
"encoding/json"
"io"
"net"
"sync"
"time"

"github.com/gorilla/websocket"
"github.com/gravitational/trace"
)

Expand All @@ -39,13 +39,9 @@ type WSConn interface {

WriteControl(messageType int, data []byte, deadline time.Time) error
WriteMessage(messageType int, data []byte) error
NextReader() (messageType int, r io.Reader, err error)
ReadMessage() (messageType int, p []byte, err error)
SetReadLimit(limit int64)
SetReadDeadline(t time.Time) error
PingHandler() func(appData string) error
SetPingHandler(h func(appData string) error)
PongHandler() func(appData string) error
SetPongHandler(h func(appData string) error)
}

Expand Down Expand Up @@ -98,10 +94,39 @@ func newPayloadWriter(nodeID, outputName string, stream io.Writer) *payloadWrite
// by any underlying code as we want to keep the connection open until the command
// is executed on all nodes and a single failure should not close the connection.
type noopCloserWS struct {
*websocket.Conn
WSConn
}

// Close does nothing.
func (ws *noopCloserWS) Close() error {
return nil
}

// syncRWWSConn is a wrapper around websocket.Conn, which serializes
// read and write to a web socket connection. This is needed to prevent
// a race conditions and panics in gorilla/websocket.
// Details https://pkg.go.dev/github.com/gorilla/websocket#hdr-Concurrency
// This struct does not lock SetReadDeadline() as the SetReadDeadline()
// is called from the pong handler, which is interanlly called on ReadMessage()
// according to https://pkg.go.dev/github.com/gorilla/websocket#hdr-Control_Messages
// This would prevent the pong handler from being called.
type syncRWWSConn struct {
// WSConn the underlying websocket connection.
WSConn
// rmtx is a mutex used to serialize reads.
rmtx sync.Mutex
// wmtx is a mutex used to serialize writes.
wmtx sync.Mutex
}

func (s *syncRWWSConn) WriteMessage(messageType int, data []byte) error {
s.wmtx.Lock()
defer s.wmtx.Unlock()
return s.WSConn.WriteMessage(messageType, data)
}

func (s *syncRWWSConn) ReadMessage() (messageType int, p []byte, err error) {
s.rmtx.Lock()
defer s.rmtx.Unlock()
return s.WSConn.ReadMessage()
}