Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 31 additions & 30 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ func (c *Conn) readEphemeralPacket(ctx context.Context) ([]byte, error) {

// readEphemeralPacketDirect attempts to read a packet from the socket directly.
// It needs to be used for the first handshake packet the server receives,
// so we do't buffer the SSL negotiation packet. As a shortcut, only
// so we don't buffer the SSL negotiation packet. As a shortcut, only
// packets smaller than MaxPacketSize can be read here.
// This function usually shouldn't be used - use readEphemeralPacket.
func (c *Conn) readEphemeralPacketDirect(ctx context.Context) ([]byte, error) {
Expand Down Expand Up @@ -1075,8 +1075,8 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error {
case 1:
c.Capabilities &^= CapabilityClientMultiStatements
default:
log.Errorf("Got unhandled packet (ComSetOption default) from client %v, returning error: %v", c.ConnectionID, data)
if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil {
log.Errorf("Got unhandled packet (ComSetOption default) from client %v", c.ConnectionID)
if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling ComSetOption packet"); err != nil {
log.Errorf("Error writing error packet to client: %v", err)
return err
}
Expand All @@ -1086,8 +1086,8 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error {
return err
}
} else {
log.Errorf("Got unhandled packet (ComSetOption else) from client %v, returning error: %v", c.ConnectionID, data)
if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil {
log.Errorf("Got unhandled packet (ComSetOption else) from client %v", c.ConnectionID)
if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling ComSetOption packet"); err != nil {
log.Errorf("Error writing error packet to client: %v", err)
return err
}
Expand All @@ -1098,8 +1098,8 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error {

if c.cs != nil {
log.Error("Received ComStmtPrepare with outstanding cursor")
if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); werr != nil {
log.Error("Error writing error packet to client: %v", werr)
if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling ComStmtPrepare packet"); werr != nil {
log.Errorf("Error writing error packet to client: %v", werr)
return werr
}
return nil
Expand Down Expand Up @@ -1181,7 +1181,7 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error {
log.Errorf("unable to prepare query: %s", err.Error())
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr)
log.Errorf("Error writing query error to client %v: %v", c.ConnectionID, werr)
return werr
}
return nil
Expand All @@ -1194,8 +1194,8 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error {
// outstanding cursor, error
if c.cs != nil {
log.Error("Received ComStmtExecute with outstanding cursor")
if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); werr != nil {
log.Error("Error writing error packet to client: %v", werr)
if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling ComStmtExecute packet"); werr != nil {
log.Errorf("Error writing error packet to client: %v", werr)
return werr
}
return nil
Expand All @@ -1211,7 +1211,7 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error {
if err != nil {
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr)
log.Errorf("Error writing query error to client %v: %v", c.ConnectionID, werr)
return werr
}
return c.flush(ctx)
Expand Down Expand Up @@ -1239,7 +1239,7 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error {
stmtID, paramID, chunk, ok := c.parseComStmtSendLongData(data)
c.recycleReadPacket()
if !ok {
err := fmt.Errorf("error parsing statement send long data from client %v, returning error: %v", c.ConnectionID, data)
err := fmt.Errorf("error parsing statement send long data from client %v", c.ConnectionID)
log.Error(err.Error())
return err
}
Expand Down Expand Up @@ -1276,18 +1276,18 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error {
stmtID, ok := c.parseComStmtReset(data)
c.recycleReadPacket()
if !ok {
log.Error("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data)
if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil {
log.Error("Error writing error packet to client: %v", err)
log.Errorf("Got unhandled ComStmtReset packet from client %v", c.ConnectionID)
if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling ComStmtReset packet"); err != nil {
log.Errorf("Error writing error packet to client: %v", err)
return err
}
}

prepare, ok := c.PrepareData[stmtID]
if !ok {
log.Error("Commands were executed in an improper order from client %v, packet: %v", c.ConnectionID, data)
if werr := c.writeErrorPacket(CRCommandsOutOfSync, SSUnknownComError, "commands were executed in an improper order: %v", data); werr != nil {
log.Error("Error writing error packet to client: %v", err)
log.Errorf("Commands were executed in an improper order from client %v", c.ConnectionID)
if werr := c.writeErrorPacket(CRCommandsOutOfSync, SSUnknownComError, "commands were executed in an improper order"); werr != nil {
log.Errorf("Error writing error packet to client: %v", err)
return werr
}
}
Expand All @@ -1301,27 +1301,28 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error {
c.discardCursor()

if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil {
log.Error("Error writing ComStmtReset OK packet to client %v: %v", c.ConnectionID, err)
log.Errorf("Error writing ComStmtReset OK packet to client %v: %v", c.ConnectionID, err)
return err
}
case ComStmtFetch:
c.startWriterBuffering()
stmtID, numRows, ok := c.parseComStmtFetch(data)
c.recycleReadPacket()
if !ok {
log.Error("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data)
if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); werr != nil {
log.Error("Error writing error packet to client: %v", werr)
log.Errorf("Unable to parse COM_STMT_FETCH message on connection %v", c.ConnectionID)
if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError,
"unable to parse COM_STMT_FETCH message on connection %v", c.ConnectionID); werr != nil {
log.Errorf("Error writing error packet to client: %v", werr)
return werr
}
return c.flush(ctx)
}

// fetching from wrong statement
if c.cs == nil || stmtID != c.cs.stmtID {
log.Errorf("Requested stmtID does not match stmtID of open cursor. Client %v, returning error: %v", c.ConnectionID, data)
if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); werr != nil {
log.Error("Error writing error packet to client: %v", err)
log.Errorf("Requested stmtID does not match stmtID of open cursor. Client %v", c.ConnectionID)
if werr := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling ComStmtFetch packet"); werr != nil {
log.Errorf("Error writing error packet to client: %v", err)
return werr
}
return c.flush(ctx)
Expand Down Expand Up @@ -1393,15 +1394,17 @@ func (c *Conn) handleNextCommand(ctx context.Context, handler Handler) error {

case ComBinlogDumpGTID:
ok := c.handleComBinlogDumpGTID(handler, data)
c.recycleReadPacket()
if !ok {
return fmt.Errorf("error handling ComBinlogDumpGTID packet: %v", data)
return fmt.Errorf("error handling ComBinlogDumpGTID packet")
}
return nil

case ComRegisterReplica:
ok := c.handleComRegisterReplica(handler, data)
c.recycleReadPacket()
if !ok {
return fmt.Errorf("error handling ComRegisterReplica packet: %v", data)
return fmt.Errorf("error handling ComRegisterReplica packet")
}
return nil

Expand Down Expand Up @@ -1430,8 +1433,6 @@ func (c *Conn) handleComRegisterReplica(handler Handler, data []byte) (kontinue
return false
}

c.recycleReadPacket()

if err := binlogReplicaHandler.ComRegisterReplica(c, replicaHost, replicaPort, replicaUser, replicaPassword); err != nil {
c.writeErrorPacketFromError(err)
return false
Expand Down Expand Up @@ -1465,7 +1466,7 @@ func (c *Conn) handleComBinlogDumpGTID(handler Handler, data []byte) (kontinue b
log.Errorf("conn %v: parseComBinlogDumpGTID failed: %v", c.ID(), err)
return false
}
c.recycleReadPacket()

if err := binlogReplicaHandler.ComBinlogDumpGTID(c, logFile, logPos, position.GTIDSet); err != nil {
log.Error(err.Error())
c.writeErrorPacketFromError(err)
Expand Down
11 changes: 9 additions & 2 deletions go/mysql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -407,12 +407,12 @@ func (l *Listener) handle(ctx context.Context, conn net.Conn, connectionID uint3

// Returns copies of the data, so we can recycle the buffer.
user, clientAuthMethod, clientAuthResponse, err = l.parseClientHandshakePacket(c, false, response)
c.recycleReadPacket()
if err != nil {
l.handleConnectionError(c, fmt.Sprintf(
"Cannot parse post-SSL client handshake response from %s: %v", c, err))
return
}
c.recycleReadPacket()

if con, ok := c.Conn.(*tls.Conn); ok {
connState := con.ConnectionState()
Expand Down Expand Up @@ -473,12 +473,19 @@ func (l *Listener) handle(ctx context.Context, conn net.Conn, connectionID uint3
return
}

clientAuthResponse, err = c.readEphemeralPacket(context.Background())
data, err := c.readEphemeralPacket(context.Background())
if err != nil {
l.handleConnectionError(c, fmt.Sprintf("Error reading auth switch response for %s: %v", c, err))
return
}

var ok bool
clientAuthResponse, _, ok = readBytesCopy(data, 0, len(data))
c.recycleReadPacket()
if !ok {
l.handleConnectionError(c, fmt.Sprintf("Unable to copy client auth response for %s", c))
return
}
}

userData, err := negotiatedAuthMethod.HandleAuthPluginData(c, user, serverAuthPluginData, clientAuthResponse, conn.RemoteAddr())
Expand Down