diff --git a/lib/web/command.go b/lib/web/command.go index 2117bf51be768..7a92ca483513d 100644 --- a/lib/web/command.go +++ b/lib/web/command.go @@ -165,7 +165,7 @@ func (h *Handler) executeCommand( CheckOrigin: func(r *http.Request) bool { return true }, } - ws, err := upgrader.Upgrade(w, r, nil) + rawWS, err := upgrader.Upgrade(w, r, nil) if err != nil { errMsg := "Error upgrading to websocket" h.log.WithError(err).Error(errMsg) @@ -174,16 +174,27 @@ func (h *Handler) executeCommand( } defer func() { - ws.WriteMessage(websocket.CloseMessage, nil) - ws.Close() + rawWS.WriteMessage(websocket.CloseMessage, nil) + rawWS.Close() }() keepAliveInterval := netConfig.GetKeepAliveInterval() - err = ws.SetReadDeadline(deadlineForInterval(keepAliveInterval)) + err = rawWS.SetReadDeadline(deadlineForInterval(keepAliveInterval)) if err != nil { h.log.WithError(err).Error("Error setting websocket readline") return nil, trace.Wrap(err) } + // Update the read deadline upon receiving a pong message. + rawWS.SetPongHandler(func(_ string) error { + // This is intentionally called without a lock as this callback is + // called from the same goroutine as the read loop which is already locked. + return trace.Wrap(rawWS.SetReadDeadline(deadlineForInterval(keepAliveInterval))) + }) + + // Wrap the raw websocket connection in a syncRWWSConn so that we can + // safely read and write to the single websocket connection from + // multiple goroutines/execution nodes. + ws := &syncRWWSConn{WSConn: rawWS} hosts, err := findByQuery(ctx, clt, req.Query) if err != nil { @@ -506,11 +517,6 @@ func (t *commandHandler) handler(r *http.Request) { t.log.Debug("Creating websocket stream") - // Update the read deadline upon receiving a pong message. - t.ws.SetPongHandler(func(_ string) error { - return trace.Wrap(t.ws.SetReadDeadline(deadlineForInterval(t.keepAliveInterval))) - }) - // Start sending ping frames through websocket to the client. 
go startPingLoop(r.Context(), t.ws, t.keepAliveInterval, t.log, t.Close) diff --git a/lib/web/command_utils.go b/lib/web/command_utils.go index cc3a0835cc91e..db8ebdbc42a5a 100644 --- a/lib/web/command_utils.go +++ b/lib/web/command_utils.go @@ -22,9 +22,9 @@ import ( "encoding/json" "io" "net" + "sync" "time" - "github.com/gorilla/websocket" "github.com/gravitational/trace" ) @@ -39,13 +39,9 @@ type WSConn interface { WriteControl(messageType int, data []byte, deadline time.Time) error WriteMessage(messageType int, data []byte) error - NextReader() (messageType int, r io.Reader, err error) ReadMessage() (messageType int, p []byte, err error) SetReadLimit(limit int64) SetReadDeadline(t time.Time) error - PingHandler() func(appData string) error - SetPingHandler(h func(appData string) error) - PongHandler() func(appData string) error SetPongHandler(h func(appData string) error) } @@ -98,10 +94,39 @@ func newPayloadWriter(nodeID, outputName string, stream io.Writer) *payloadWrite // by any underlying code as we want to keep the connection open until the command // is executed on all nodes and a single failure should not close the connection. type noopCloserWS struct { - *websocket.Conn + WSConn } // Close does nothing. func (ws *noopCloserWS) Close() error { return nil } + +// syncRWWSConn is a wrapper around websocket.Conn, which serializes +// read and write to a web socket connection. This is needed to prevent +// race conditions and panics in gorilla/websocket. +// Details https://pkg.go.dev/github.com/gorilla/websocket#hdr-Concurrency +// This struct does not lock SetReadDeadline() as the SetReadDeadline() +// is called from the pong handler, which is internally called on ReadMessage() +// according to https://pkg.go.dev/github.com/gorilla/websocket#hdr-Control_Messages +// Locking it would prevent the pong handler from being called. +type syncRWWSConn struct { + // WSConn the underlying websocket connection. + WSConn + // rmtx is a mutex used to serialize reads. 
+ rmtx sync.Mutex + // wmtx is a mutex used to serialize writes. + wmtx sync.Mutex +} + +func (s *syncRWWSConn) WriteMessage(messageType int, data []byte) error { + s.wmtx.Lock() + defer s.wmtx.Unlock() + return s.WSConn.WriteMessage(messageType, data) +} + +func (s *syncRWWSConn) ReadMessage() (messageType int, p []byte, err error) { + s.rmtx.Lock() + defer s.rmtx.Unlock() + return s.WSConn.ReadMessage() +}