diff --git a/server/serverimpl.go b/server/serverimpl.go index 12048199..f303df14 100644 --- a/server/serverimpl.go +++ b/server/serverimpl.go @@ -249,7 +249,7 @@ func (s *server) handleWSConnection(reqCtx context.Context, wsConn *websocket.Co // Loop until fail to read from the WebSocket connection. for { - msgContext := context.Background() + msgContext := context.Background() // reqContext is cancelled when the ServerHTTP method is returned, for WebSockets connections this happens before this loop even starts. request := protobufs.AgentToServer{} // Block until the next message can be read. @@ -311,8 +311,8 @@ func (s *server) handleWSConnection(reqCtx context.Context, wsConn *websocket.Co } } -func decompressGzip(data []byte) ([]byte, error) { - r, err := gzip.NewReader(bytes.NewBuffer(data)) +func decompressGzip(data io.Reader) ([]byte, error) { + r, err := gzip.NewReader(data) if err != nil { return nil, err } @@ -321,15 +321,16 @@ func decompressGzip(data []byte) ([]byte, error) { } func (s *server) readReqBody(req *http.Request) ([]byte, error) { - data, err := io.ReadAll(req.Body) - if err != nil { - return nil, err - } if req.Header.Get(headerContentEncoding) == contentEncodingGzip { - data, err = decompressGzip(data) + data, err := decompressGzip(req.Body) if err != nil { return nil, err } + return data, nil + } + data, err := io.ReadAll(req.Body) + if err != nil { + return nil, err } return data, nil } @@ -391,6 +392,7 @@ func (s *server) handlePlainHTTPRequest(req *http.Request, w http.ResponseWriter } // Return the CustomCapabilities + // Note that unlike a WebSocket response, this is included in all HTTP responses. response.CustomCapabilities = &protobufs.CustomCapabilities{ Capabilities: s.settings.CustomCapabilities, } diff --git a/server/serverimpl_test.go b/server/serverimpl_test.go index 74b1f1cc..62f48e75 100644 --- a/server/serverimpl_test.go +++ b/server/serverimpl_test.go @@ -964,11 +964,8 @@ func TestServerHonoursAcceptEncoding(t *testing.T) { // Verify the received message is what was sent. assert.True(t, proto.Equal(rcvMsg.Load().(proto.Message), &sendMsg)) - // Read Server's response. - b, err = io.ReadAll(resp.Body) - require.NoError(t, err) // Decompress the gzip response - b, err = decompressGzip(b) + b, err = decompressGzip(resp.Body) require.NoError(t, err) assert.EqualValues(t, http.StatusOK, resp.StatusCode)