diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 7a220168df4..bbb30d241d8 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -190,6 +190,7 @@ type PrepareData struct { BindVars map[string]*querypb.BindVariable } +// execResult is an enum signifying the result of executing a query type execResult byte const ( @@ -720,6 +721,23 @@ func (c *Conn) writeOKPacketWithEOFHeader(affectedRows, lastInsertID uint64, fla return c.writeEphemeralPacket() } +func (c *Conn) writeErrorAndLog(errorCode uint16, sqlState string, format string, args ...interface{}) bool { + if err := c.writeErrorPacket(errorCode, sqlState, format, args...); err != nil { + log.Errorf("Error writing error to %s: %v", c, err) + return false + } + return true +} + +func (c *Conn) writeErrorPacketFromErrorAndLog(err error) bool { + werr := c.writeErrorPacketFromError(err) + if werr != nil { + log.Errorf("Error writing error to %s: %v", c, werr) + return false + } + return true +} + // writeErrorPacket writes an error packet. // Server -> Client. // This method returns a generic error, not a SQLError. @@ -787,361 +805,350 @@ func (c *Conn) handleNextCommand(handler Handler) bool { res := c.execQuery("use "+sqlescape.EscapeID(db), handler, false) return res != connErr case ComQuery: - c.startWriterBuffering() - defer func() { - if err := c.endWriterBuffering(); err != nil { - log.Errorf("conn %v: flush() failed: %v", c.ID(), err) - } - }() - - queryStart := time.Now() - query := c.parseComQuery(data) - c.recycleReadPacket() - - var queries []string - if c.Capabilities&CapabilityClientMultiStatements != 0 { - queries, err = sqlparser.SplitStatementToPieces(query) - if err != nil { - log.Errorf("Conn %v: Error splitting query: %v", c, err) - if werr := c.writeErrorPacketFromError(err); werr != nil { - // If we can't even write the error, we're done. - log.Errorf("Conn %v: Error writing query error: %v", c, werr) - return false - } - } - } else { - queries = []string{query} - } - for index, sql := range queries { - more := false - if index != len(queries)-1 { - more = true - } - res := c.execQuery(sql, handler, more) - if res != execSuccess { - return res != connErr - } - } - - timings.Record(queryTimingKey, queryStart) - + return c.handleComQuery(handler, data) case ComPing: - c.recycleReadPacket() - // Return error if listener was shut down and OK otherwise - if c.listener.isShutdown() { - if err := c.writeErrorPacket(ERServerShutdown, SSServerShutdown, "Server shutdown in progress"); err != nil { - log.Errorf("Error writing ComPing error to %s: %v", c, err) - return false - } - } else { - if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil { - log.Errorf("Error writing ComPing result to %s: %v", c, err) - return false - } - } - + return c.handleComPing() case ComSetOption: - operation, ok := c.parseComSetOption(data) + return c.handleComSetOption(data) + case ComPrepare: + return c.handleComPrepare(handler, data) + case ComStmtExecute: + return c.handleComStmtExecute(handler, data) + case ComStmtSendLongData: + return c.handleComStmtSendLongData(data) + case ComStmtClose: + stmtID, ok := c.parseComStmtClose(data) c.recycleReadPacket() if ok { - switch operation { - case 0: - c.Capabilities |= CapabilityClientMultiStatements - case 1: - c.Capabilities &^= CapabilityClientMultiStatements - default: - log.Errorf("Got unhandled packet (ComSetOption default) from client %v, returning error: %v", c.ConnectionID, data) - if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil { - log.Errorf("Error writing error packet to client: %v", err) - return false - } - } - if err := c.writeEndResult(false, 0, 0, 0); err != nil { - log.Errorf("Error writeEndResult error %v ", err) - return false - } - } else { - log.Errorf("Got unhandled packet (ComSetOption else) from client %v, returning error: %v", c.ConnectionID, data) - if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil { - log.Errorf("Error writing error packet to client: %v", err) - return false - } + delete(c.PrepareData, stmtID) } + case ComStmtReset: + return c.handleComStmtReset(data) + case ComResetConnection: + c.handleComResetConnection(handler) + return true - case ComPrepare: - query := c.parseComPrepare(data) + default: + log.Errorf("Got unhandled packet (default) from %s, returning error: %v", c, data) c.recycleReadPacket() + if !c.writeErrorAndLog(ERUnknownComError, SSUnknownComError, "command handling not implemented yet: %v", data[0]) { + return false + } + } - var queries []string - if c.Capabilities&CapabilityClientMultiStatements != 0 { - queries, err = sqlparser.SplitStatementToPieces(query) - if err != nil { - log.Errorf("Conn %v: Error splitting query: %v", c, err) - if werr := c.writeErrorPacketFromError(err); werr != nil { - // If we can't even write the error, we're done. - log.Errorf("Conn %v: Error writing query error: %v", c, werr) - return false - } - } - if len(queries) != 1 { - log.Errorf("Conn %v: can not prepare multiple statements", c, err) - if werr := c.writeErrorPacketFromError(err); werr != nil { - // If we can't even write the error, we're done. - log.Errorf("Conn %v: Error writing query error: %v", c, werr) - return false - } - return true - } - } else { - queries = []string{query} + return true +} + +func (c *Conn) handleComResetConnection(handler Handler) { + // Clean up and reset the connection + c.recycleReadPacket() + handler.ComResetConnection(c) + // Reset prepared statements + c.PrepareData = make(map[uint32]*PrepareData) + err := c.writeOKPacket(0, 0, 0, 0) + if err != nil { + c.writeErrorPacketFromError(err) + } +} + +func (c *Conn) handleComStmtReset(data []byte) bool { + stmtID, ok := c.parseComStmtReset(data) + c.recycleReadPacket() + if !ok { + log.Error("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data) + if !c.writeErrorAndLog(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data) { + return false } + } - // Popoulate PrepareData - c.StatementID++ - prepare := &PrepareData{ - StatementID: c.StatementID, - PrepareStmt: queries[0], + prepare, ok := c.PrepareData[stmtID] + if !ok { + log.Error("Commands were executed in an improper order from client %v, packet: %v", c.ConnectionID, data) + if !c.writeErrorAndLog(CRCommandsOutOfSync, SSUnknownComError, "commands were executed in an improper order: %v", data) { + return false } + } - statement, err := sqlparser.ParseStrictDDL(query) - if err != nil { - log.Errorf("Conn %v: Error parsing prepared statement: %v", c, err) - if werr := c.writeErrorPacketFromError(err); werr != nil { - // If we can't even write the error, we're done. - log.Errorf("Conn %v: Error writing prepared statement error: %v", c, werr) - return false - } + if prepare.BindVars != nil { + for k := range prepare.BindVars { + prepare.BindVars[k] = nil } + } - paramsCount := uint16(0) - _ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) { - switch node := node.(type) { - case sqlparser.Argument: - if strings.HasPrefix(string(node), ":v") { - paramsCount++ - } - } - return true, nil - }, statement) + if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil { + log.Error("Error writing ComStmtReset OK packet to client %v: %v", c.ConnectionID, err) + return false + } + return true +} - if paramsCount > 0 { - prepare.ParamsCount = paramsCount - prepare.ParamsType = make([]int32, paramsCount) - prepare.BindVars = make(map[string]*querypb.BindVariable, paramsCount) - } +func (c *Conn) handleComStmtSendLongData(data []byte) bool { + stmtID, paramID, chunkData, ok := c.parseComStmtSendLongData(data) + c.recycleReadPacket() + if !ok { + err := fmt.Errorf("error parsing statement send long data from client %v, returning error: %v", c.ConnectionID, data) + return c.writeErrorPacketFromErrorAndLog(err) + } - bindVars := make(map[string]*querypb.BindVariable, paramsCount) - for i := uint16(0); i < paramsCount; i++ { - parameterID := fmt.Sprintf("v%d", i+1) - bindVars[parameterID] = &querypb.BindVariable{} - } + prepare, ok := c.PrepareData[stmtID] + if !ok { + err := fmt.Errorf("got wrong statement id from client %v, statement ID(%v) is not found from record", c.ConnectionID, stmtID) + return !c.writeErrorPacketFromErrorAndLog(err) + } - c.PrepareData[c.StatementID] = prepare + if prepare.BindVars == nil || + prepare.ParamsCount == uint16(0) || + paramID >= prepare.ParamsCount { + err := fmt.Errorf("invalid parameter Number from client %v, statement: %v", c.ConnectionID, prepare.PrepareStmt) + return !c.writeErrorPacketFromErrorAndLog(err) + } - fld, err := handler.ComPrepare(c, queries[0], bindVars) + chunk := make([]byte, len(chunkData)) + copy(chunk, chunkData) - if err != nil { - if werr := c.writeErrorPacketFromError(err); werr != nil { - // If we can't even write the error, we're done. - log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) - return false - } - return true - } + key := fmt.Sprintf("v%d", paramID+1) + if val, ok := prepare.BindVars[key]; ok { + val.Value = append(val.Value, chunk...) + } else { + prepare.BindVars[key] = sqltypes.BytesBindVariable(chunk) + } + return true +} - if err := c.writePrepare(fld, c.PrepareData[c.StatementID]); err != nil { - log.Error("Error writing prepare data to client %v: %v", c.ConnectionID, err) - return false +func (c *Conn) handleComStmtExecute(handler Handler, data []byte) (kontinue bool) { + c.startWriterBuffering() + defer func() { + if err := c.endWriterBuffering(); err != nil { + log.Errorf("conn %v: flush() failed: %v", c.ID(), err) + kontinue = false } + }() + queryStart := time.Now() + stmtID, _, err := c.parseComStmtExecute(c.PrepareData, data) + c.recycleReadPacket() - case ComStmtExecute: - c.startWriterBuffering() + if stmtID != uint32(0) { defer func() { - if err := c.endWriterBuffering(); err != nil { - log.Errorf("conn %v: flush() failed: %v", c.ID(), err) - } + // Allocate a new bindvar map every time since VTGate.Execute() mutates it. + prepare := c.PrepareData[stmtID] + prepare.BindVars = make(map[string]*querypb.BindVariable, prepare.ParamsCount) }() - queryStart := time.Now() - stmtID, _, err := c.parseComStmtExecute(c.PrepareData, data) - c.recycleReadPacket() + } - if stmtID != uint32(0) { - defer func() { - // Allocate a new bindvar map every time since VTGate.Execute() mutates it. - prepare := c.PrepareData[stmtID] - prepare.BindVars = make(map[string]*querypb.BindVariable, prepare.ParamsCount) - }() - } + if err != nil { + return !c.writeErrorPacketFromErrorAndLog(err) + } - if err != nil { - if werr := c.writeErrorPacketFromError(err); werr != nil { - // If we can't even write the error, we're done. - log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) - return false - } - return true + fieldSent := false + // sendFinished is set if the response should just be an OK packet. + sendFinished := false + prepare := c.PrepareData[stmtID] + err = handler.ComStmtExecute(c, prepare, func(qr *sqltypes.Result) error { + if sendFinished { + // Failsafe: Unreachable if server is well-behaved. + return io.EOF } - fieldSent := false - // sendFinished is set if the response should just be an OK packet. - sendFinished := false - prepare := c.PrepareData[stmtID] - err = handler.ComStmtExecute(c, prepare, func(qr *sqltypes.Result) error { - if sendFinished { - // Failsafe: Unreachable if server is well-behaved. - return io.EOF - } - - if !fieldSent { - fieldSent = true - - if len(qr.Fields) == 0 { - sendFinished = true - // We should not send any more packets after this. - return c.writeOKPacket(qr.RowsAffected, qr.InsertID, c.StatusFlags, 0) - } - if err := c.writeFields(qr); err != nil { - return err - } - } - - return c.writeBinaryRows(qr) - }) - - // If no field was sent, we expect an error. if !fieldSent { - // This is just a failsafe. Should never happen. - if err == nil || err == io.EOF { - err = NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error")) - } - if werr := c.writeErrorPacketFromError(err); werr != nil { - // If we can't even write the error, we're done. - log.Errorf("Error writing query error to %s: %v", c, werr) - return false - } - } else { - if err != nil { - // We can't send an error in the middle of a stream. - // All we can do is abort the send, which will cause a 2013. - log.Errorf("Error in the middle of a stream to %s: %v", c, err) - return false - } + fieldSent = true - // Send the end packet only sendFinished is false (results were streamed). - // In this case the affectedRows and lastInsertID are always 0 since it - // was a read operation. - if !sendFinished { - if err := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); err != nil { - log.Errorf("Error writing result to %s: %v", c, err) - return false - } + if len(qr.Fields) == 0 { + sendFinished = true + // We should not send any more packets after this. + return c.writeOKPacket(qr.RowsAffected, qr.InsertID, c.StatusFlags, 0) + } + if err := c.writeFields(qr); err != nil { + return err } } - timings.Record(queryTimingKey, queryStart) + return c.writeBinaryRows(qr) + }) - case ComStmtSendLongData: - stmtID, paramID, chunkData, ok := c.parseComStmtSendLongData(data) - c.recycleReadPacket() - if !ok { - err = fmt.Errorf("error parsing statement send long data from client %v, returning error: %v", c.ConnectionID, data) - if werr := c.writeErrorPacketFromError(err); werr != nil { - // If we can't even write the error, we're done. - log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) - return false - } - return true + // If no field was sent, we expect an error. + if !fieldSent { + // This is just a failsafe. Should never happen. + if err == nil || err == io.EOF { + err = NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error")) } - - prepare, ok := c.PrepareData[stmtID] - if !ok { - err = fmt.Errorf("got wrong statement id from client %v, statement ID(%v) is not found from record", c.ConnectionID, stmtID) - if werr := c.writeErrorPacketFromError(err); werr != nil { - // If we can't even write the error, we're done. - log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) - return false - } - return true + if !c.writeErrorPacketFromErrorAndLog(err) { + return false + } + } else { + if err != nil { + // We can't send an error in the middle of a stream. + // All we can do is abort the send, which will cause a 2013. + log.Errorf("Error in the middle of a stream to %s: %v", c, err) + return false } - if prepare.BindVars == nil || - prepare.ParamsCount == uint16(0) || - paramID >= prepare.ParamsCount { - err = fmt.Errorf("invalid parameter Number from client %v, statement: %v", c.ConnectionID, prepare.PrepareStmt) - if werr := c.writeErrorPacketFromError(err); werr != nil { - // If we can't even write the error, we're done. - log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) + // Send the end packet only sendFinished is false (results were streamed). + // In this case the affectedRows and lastInsertID are always 0 since it + // was a read operation. + if !sendFinished { + if err := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); err != nil { + log.Errorf("Error writing result to %s: %v", c, err) return false } - return true } + } - chunk := make([]byte, len(chunkData)) - copy(chunk, chunkData) + timings.Record(queryTimingKey, queryStart) + return true +} + +func (c *Conn) handleComPrepare(handler Handler, data []byte) bool { + query := c.parseComPrepare(data) + c.recycleReadPacket() - key := fmt.Sprintf("v%d", paramID+1) - if val, ok := prepare.BindVars[key]; ok { - val.Value = append(val.Value, chunk...) - } else { - prepare.BindVars[key] = sqltypes.BytesBindVariable(chunk) + var queries []string + if c.Capabilities&CapabilityClientMultiStatements != 0 { + queries, err := sqlparser.SplitStatementToPieces(query) + if err != nil { + log.Errorf("Conn %v: Error splitting query: %v", c, err) + return !c.writeErrorPacketFromErrorAndLog(err) } - case ComStmtClose: - stmtID, ok := c.parseComStmtClose(data) - c.recycleReadPacket() - if ok { - delete(c.PrepareData, stmtID) + if len(queries) != 1 { + log.Errorf("Conn %v: can not prepare multiple statements", c, err) + return !c.writeErrorPacketFromErrorAndLog(err) } - case ComStmtReset: - stmtID, ok := c.parseComStmtReset(data) - c.recycleReadPacket() - if !ok { - log.Error("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data) - if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil { - log.Error("Error writing error packet to client: %v", err) - return false - } + } else { + queries = []string{query} + } + + // Popoulate PrepareData + c.StatementID++ + prepare := &PrepareData{ + StatementID: c.StatementID, + PrepareStmt: queries[0], + } + + statement, err := sqlparser.ParseStrictDDL(query) + if err != nil { + log.Errorf("Conn %v: Error parsing prepared statement: %v", c, err) + if !c.writeErrorPacketFromErrorAndLog(err) { + return false } + } - prepare, ok := c.PrepareData[stmtID] - if !ok { - log.Error("Commands were executed in an improper order from client %v, packet: %v", c.ConnectionID, data) - if err := c.writeErrorPacket(CRCommandsOutOfSync, SSUnknownComError, "commands were executed in an improper order: %v", data); err != nil { - log.Error("Error writing error packet to client: %v", err) - return false + paramsCount := uint16(0) + _ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) { + switch node := node.(type) { + case sqlparser.Argument: + if strings.HasPrefix(string(node), ":v") { + paramsCount++ } } + return true, nil + }, statement) + + if paramsCount > 0 { + prepare.ParamsCount = paramsCount + prepare.ParamsType = make([]int32, paramsCount) + prepare.BindVars = make(map[string]*querypb.BindVariable, paramsCount) + } - if prepare.BindVars != nil { - for k := range prepare.BindVars { - prepare.BindVars[k] = nil + bindVars := make(map[string]*querypb.BindVariable, paramsCount) + for i := uint16(0); i < paramsCount; i++ { + parameterID := fmt.Sprintf("v%d", i+1) + bindVars[parameterID] = &querypb.BindVariable{} + } + + c.PrepareData[c.StatementID] = prepare + + fld, err := handler.ComPrepare(c, queries[0], bindVars) + + if err != nil { + return !c.writeErrorPacketFromErrorAndLog(err) + } + + if err := c.writePrepare(fld, c.PrepareData[c.StatementID]); err != nil { + log.Error("Error writing prepare data to client %v: %v", c.ConnectionID, err) + return false + } + return true +} + +func (c *Conn) handleComSetOption(data []byte) bool { + operation, ok := c.parseComSetOption(data) + c.recycleReadPacket() + if ok { + switch operation { + case 0: + c.Capabilities |= CapabilityClientMultiStatements + case 1: + c.Capabilities &^= CapabilityClientMultiStatements + default: + log.Errorf("Got unhandled packet (ComSetOption default) from client %v, returning error: %v", c.ConnectionID, data) + if !c.writeErrorAndLog(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data) { + return false } } + if err := c.writeEndResult(false, 0, 0, 0); err != nil { + log.Errorf("Error writeEndResult error %v ", err) + return false + } + } else { + log.Errorf("Got unhandled packet (ComSetOption else) from client %v, returning error: %v", c.ConnectionID, data) + if !c.writeErrorAndLog(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data) { + return false + } + } + return true +} +func (c *Conn) handleComPing() bool { + c.recycleReadPacket() + // Return error if listener was shut down and OK otherwise + if c.listener.isShutdown() { + if !c.writeErrorAndLog(ERServerShutdown, SSServerShutdown, "Server shutdown in progress") { + return false + } + } else { if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil { - log.Error("Error writing ComStmtReset OK packet to client %v: %v", c.ConnectionID, err) + log.Errorf("Error writing ComPing result to %s: %v", c, err) return false } + } + return true +} - case ComResetConnection: - // Clean up and reset the connection - c.recycleReadPacket() - handler.ComResetConnection(c) - // Reset prepared statements - c.PrepareData = make(map[uint32]*PrepareData) - err = c.writeOKPacket(0, 0, 0, 0) - if err != nil { - c.writeErrorPacketFromError(err) +func (c *Conn) handleComQuery(handler Handler, data []byte) (kontinue bool) { + c.startWriterBuffering() + defer func() { + if err := c.endWriterBuffering(); err != nil { + log.Errorf("conn %v: flush() failed: %v", c.ID(), err) + kontinue = false } + }() - default: - log.Errorf("Got unhandled packet (default) from %s, returning error: %v", c, data) - c.recycleReadPacket() - if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "command handling not implemented yet: %v", data[0]); err != nil { - log.Errorf("Error writing error packet to %s: %s", c, err) - return false + queryStart := time.Now() + query := c.parseComQuery(data) + c.recycleReadPacket() + + var queries []string + var err error + if c.Capabilities&CapabilityClientMultiStatements != 0 { + queries, err = sqlparser.SplitStatementToPieces(query) + if err != nil { + log.Errorf("Conn %v: Error splitting query: %v", c, err) + return !c.writeErrorPacketFromErrorAndLog(err) + } + } else { + queries = []string{query} + } + for index, sql := range queries { + more := false + if index != len(queries)-1 { + more = true + } + res := c.execQuery(sql, handler, more) + if res != execSuccess { + return res != connErr } } + timings.Record(queryTimingKey, queryStart) return true } @@ -1188,9 +1195,7 @@ func (c *Conn) execQuery(query string, handler Handler, more bool) execResult { if err == nil || err == io.EOF { err = NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error")) } - if werr := c.writeErrorPacketFromError(err); werr != nil { - // If we can't even write the error, we're done. - log.Errorf("Error writing query error to %s: %v", c, werr) + if !c.writeErrorPacketFromErrorAndLog(err) { return connErr } return execErr