diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 093b474fb2b..7d0d9e4e146 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -421,7 +421,7 @@ func (c *Conn) readEphemeralPacket(ctx context.Context) ([]byte, error) { // readEphemeralPacketDirect attempts to read a packet from the socket directly. // It needs to be used for the first handshake packet the server receives, -// so we do't buffer the SSL negotiation packet. As a shortcut, only +// so we don't buffer the SSL negotiation packet. As a shortcut, only // packets smaller than MaxPacketSize can be read here. // This function usually shouldn't be used - use readEphemeralPacket. func (c *Conn) readEphemeralPacketDirect(ctx context.Context) ([]byte, error) { @@ -1075,8 +1075,8 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error { case 1: c.Capabilities &^= CapabilityClientMultiStatements default: - log.Errorf("Got unhandled packet (ComSetOption default) from client %v, returning error: %v", c.ConnectionID, data) - if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil { + log.Errorf("Got unhandled packet (ComSetOption default) from client %v", c.ConnectionID) + if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling ComSetOption packet"); err != nil { log.Errorf("Error writing error packet to client: %v", err) return err } @@ -1086,8 +1086,8 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error { return err } } else { - log.Errorf("Got unhandled packet (ComSetOption else) from client %v, returning error: %v", c.ConnectionID, data) - if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil { + log.Errorf("Got unhandled packet (ComSetOption else) from client %v", c.ConnectionID) + if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling ComSetOption packet"); err != nil { log.Errorf("Error writing error packet to client: %v", err) return err } @@ -1098,8 +1098,8 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error { if c.cs != nil { log.Error("Received ComStmtPrepare with outstanding cursor") - if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); werr != nil { - log.Error("Error writing error packet to client: %v", werr) + if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling ComStmtPrepare packet"); werr != nil { + log.Errorf("Error writing error packet to client: %v", werr) return werr } return nil @@ -1181,7 +1181,7 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error { log.Errorf("unable to prepare query: %s", err.Error()) if werr := c.writeErrorPacketFromError(err); werr != nil { // If we can't even write the error, we're done. - log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) + log.Errorf("Error writing query error to client %v: %v", c.ConnectionID, werr) return werr } return nil @@ -1194,8 +1194,8 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error { // outstanding cursor, error if c.cs != nil { log.Error("Received ComStmtExecute with outstanding cursor") - if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); werr != nil { - log.Error("Error writing error packet to client: %v", werr) + if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling ComStmtExecute packet"); werr != nil { + log.Errorf("Error writing error packet to client: %v", werr) return werr } return nil @@ -1211,7 +1211,7 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error { if err != nil { if werr := c.writeErrorPacketFromError(err); werr != nil { // If we can't even write the error, we're done. - log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) + log.Errorf("Error writing query error to client %v: %v", c.ConnectionID, werr) return werr } return c.flush(ctx) @@ -1239,7 +1239,7 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error { stmtID, paramID, chunk, ok := c.parseComStmtSendLongData(data) c.recycleReadPacket() if !ok { - err := fmt.Errorf("error parsing statement send long data from client %v, returning error: %v", c.ConnectionID, data) + err := fmt.Errorf("error parsing statement send long data from client %v", c.ConnectionID) log.Error(err.Error()) return err } @@ -1276,18 +1276,18 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error { stmtID, ok := c.parseComStmtReset(data) c.recycleReadPacket() if !ok { - log.Error("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data) - if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil { - log.Error("Error writing error packet to client: %v", err) + log.Errorf("Got unhandled ComStmtReset packet from client %v", c.ConnectionID) + if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling ComStmtReset packet"); err != nil { + log.Errorf("Error writing error packet to client: %v", err) return err } } prepare, ok := c.PrepareData[stmtID] if !ok { - log.Error("Commands were executed in an improper order from client %v, packet: %v", c.ConnectionID, data) - if werr := c.writeErrorPacket(CRCommandsOutOfSync, SSUnknownComError, "commands were executed in an improper order: %v", data); werr != nil { - log.Error("Error writing error packet to client: %v", err) + log.Errorf("Commands were executed in an improper order from client %v", c.ConnectionID) + if werr := c.writeErrorPacket(CRCommandsOutOfSync, SSUnknownComError, "commands were executed in an improper order"); werr != nil { + log.Errorf("Error writing error packet to client: %v", err) return werr } } @@ -1301,7 +1301,7 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error { c.discardCursor() if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil { - log.Error("Error writing ComStmtReset OK packet to client %v: %v", c.ConnectionID, err) + log.Errorf("Error writing ComStmtReset OK packet to client %v: %v", c.ConnectionID, err) return err } case ComStmtFetch: @@ -1309,9 +1309,10 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error { stmtID, numRows, ok := c.parseComStmtFetch(data) c.recycleReadPacket() if !ok { - log.Error("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data) - if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); werr != nil { - log.Error("Error writing error packet to client: %v", werr) + log.Errorf("Unable to parse COM_STMT_FETCH message on connection %v", c.ConnectionID) + if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, + "unable to parse COM_STMT_FETCH message on connection %v", c.ConnectionID); werr != nil { + log.Errorf("Error writing error packet to client: %v", werr) return werr } return c.flush(ctx) @@ -1319,9 +1320,9 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error { // fetching from wrong statement if c.cs == nil || stmtID != c.cs.stmtID { - log.Errorf("Requested stmtID does not match stmtID of open cursor. Client %v, returning error: %v", c.ConnectionID, data) - if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); werr != nil { - log.Error("Error writing error packet to client: %v", err) + log.Errorf("Requested stmtID does not match stmtID of open cursor. Client %v", c.ConnectionID) + if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling ComStmtFetch packet"); werr != nil { + log.Errorf("Error writing error packet to client: %v", err) return werr } return c.flush(ctx) @@ -1393,15 +1394,17 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error { case ComBinlogDumpGTID: ok := c.handleComBinlogDumpGTID(handler, data) + c.recycleReadPacket() if !ok { - return fmt.Errorf("error handling ComBinlogDumpGTID packet: %v", data) + return fmt.Errorf("error handling ComBinlogDumpGTID packet") } return nil case ComRegisterReplica: ok := c.handleComRegisterReplica(handler, data) + c.recycleReadPacket() if !ok { - return fmt.Errorf("error handling ComRegisterReplica packet: %v", data) + return fmt.Errorf("error handling ComRegisterReplica packet") } return nil @@ -1430,8 +1433,6 @@ func (c *Conn) handleComRegisterReplica(handler Handler, data []byte) (kontinue return false } - c.recycleReadPacket() - if err := binlogReplicaHandler.ComRegisterReplica(c, replicaHost, replicaPort, replicaUser, replicaPassword); err != nil { c.writeErrorPacketFromError(err) return false @@ -1465,7 +1466,7 @@ func (c *Conn) handleComBinlogDumpGTID(handler Handler, data []byte) (kontinue b log.Errorf("conn %v: parseComBinlogDumpGTID failed: %v", c.ID(), err) return false } - c.recycleReadPacket() + if err := binlogReplicaHandler.ComBinlogDumpGTID(c, logFile, logPos, position.GTIDSet); err != nil { log.Error(err.Error()) c.writeErrorPacketFromError(err) diff --git a/go/mysql/server.go b/go/mysql/server.go index 3fed025dd29..38d07b667ee 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -407,12 +407,12 @@ func (l *Listener) handle(ctx context.Context, conn net.Conn, connectionID uint3 // Returns copies of the data, so we can recycle the buffer. user, clientAuthMethod, clientAuthResponse, err = l.parseClientHandshakePacket(c, false, response) + c.recycleReadPacket() if err != nil { l.handleConnectionError(c, fmt.Sprintf( "Cannot parse post-SSL client handshake response from %s: %v", c, err)) return } - c.recycleReadPacket() if con, ok := c.Conn.(*tls.Conn); ok { connState := con.ConnectionState() @@ -473,12 +473,19 @@ func (l *Listener) handle(ctx context.Context, conn net.Conn, connectionID uint3 return } - clientAuthResponse, err = c.readEphemeralPacket(context.Background()) + data, err := c.readEphemeralPacket(context.Background()) if err != nil { l.handleConnectionError(c, fmt.Sprintf("Error reading auth switch response for %s: %v", c, err)) return } + + var ok bool + clientAuthResponse, _, ok = readBytesCopy(data, 0, len(data)) c.recycleReadPacket() + if !ok { + l.handleConnectionError(c, fmt.Sprintf("Unable to copy client auth response for %s", c)) + return + } } userData, err := negotiatedAuthMethod.HandleAuthPluginData(c, user, serverAuthPluginData, clientAuthResponse, conn.RemoteAddr())