diff --git a/lib/srv/db/mysql/engine.go b/lib/srv/db/mysql/engine.go index 83af22edac3e8..b7fdc7751c392 100644 --- a/lib/srv/db/mysql/engine.go +++ b/lib/srv/db/mysql/engine.go @@ -20,6 +20,7 @@ import ( "context" "crypto/tls" "fmt" + "io" "net" "time" @@ -383,34 +384,46 @@ func (e *Engine) receiveFromServer(serverConn, clientConn net.Conn, serverErrCh "client": clientConn.RemoteAddr(), "server": serverConn.RemoteAddr(), }) - defer func() { - log.Debug("Stop receiving from server.") - close(serverErrCh) - }() + messagesCounter := common.GetMessagesFromServerMetric(sessionCtx.Database) - msgFromServer := common.GetMessagesFromServerMetric(sessionCtx.Database) + // parse and count the messages from the server in a separate goroutine, + // operating on a copy of the server message stream. the copy is arranged below. + copyReader, copyWriter := io.Pipe() + defer copyWriter.Close() - for { - packet, _, err := protocol.ReadPacket(serverConn) - if err != nil { - if utils.IsOKNetworkError(err) { - log.Debug("Server connection closed.") + go func() { + defer copyReader.Close() + + var count int64 + defer func() { + log.WithField("parsed_total", count).Debug("Stopped parsing messages from server.") + }() + + for { + _, _, err := protocol.ReadPacket(copyReader) + if err != nil { return } - log.WithError(err).Error("Failed to read server packet.") - serverErrCh <- err - return - } - msgFromServer.Inc() + count += 1 + messagesCounter.Inc() + } + }() - _, err = protocol.WritePacket(packet, clientConn) - if err != nil { - log.WithError(err).Error("Failed to write client packet.") - serverErrCh <- err - return + // the messages are ultimately copied from serverConn to clientConn, + // but a copy of that message stream is written to a synchronous pipe, + // which is read by the analysis goroutine above. + total, err := io.Copy(clientConn, io.TeeReader(serverConn, copyWriter)) + if err != nil { + if utils.IsOKNetworkError(err) { + log.Debug("Server connection closed.") + } else { + log.WithError(err).Warn("Server -> Client copy finished with unexpected error.") } } + + log.Debugf("Stopped receiving from server. Transferred %v bytes.", total) + serverErrCh <- trace.Wrap(err) } // makeAcquireSemaphoreConfig builds parameters for acquiring a semaphore