Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 238 additions & 3 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,23 @@ type Conn struct {
// currentEphemeralBuffer for tracking allocated temporary buffer for writes and reads respectively.
// It can be allocated from bufPool or heap and should be recycled in the same manner.
currentEphemeralBuffer *[]byte

// StatementID is the prepared statement ID.
StatementID uint32

// PrepareData is the map to use a prepared statement.
PrepareData map[uint32]*PrepareData
}

// PrepareData is a buffer used for store prepare statement meta data
type PrepareData struct {
StatementID uint32
PrepareStmt string
ParsedStmt *sqlparser.Statement
ParamsCount uint16
ParamsType []int32
ColumnNames []string
BindVars map[string]*querypb.BindVariable
}

// bufPool is used to allocate and free buffers in an efficient way.
Expand All @@ -180,9 +197,10 @@ func newConn(conn net.Conn) *Conn {
// size for reads.
func newServerConn(conn net.Conn, listener *Listener) *Conn {
c := &Conn{
conn: conn,
listener: listener,
closed: sync2.NewAtomicBool(false),
conn: conn,
listener: listener,
closed: sync2.NewAtomicBool(false),
PrepareData: make(map[uint32]*PrepareData),
}
if listener.connReadBufferSize > 0 {
c.bufferedReader = bufio.NewReaderSize(conn, listener.connReadBufferSize)
Expand Down Expand Up @@ -799,6 +817,223 @@ func (c *Conn) handleNextCommand(handler Handler) error {
return err
}
}
case ComPrepare:
query := c.parseComPrepare(data)
c.recycleReadPacket()

var queries []string
if c.Capabilities&CapabilityClientMultiStatements != 0 {
queries, err = sqlparser.SplitStatementToPieces(query)
if err != nil {
log.Errorf("Conn %v: Error splitting query: %v", c, err)
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Errorf("Conn %v: Error writing query error: %v", c, werr)
return werr
}
}
} else {
queries = []string{query}
}

if len(queries) != 1 {
return fmt.Errorf("can not prepare multiple statements")
}

c.StatementID++
prepare := &PrepareData{
StatementID: c.StatementID,
PrepareStmt: queries[0],
}

c.PrepareData[c.StatementID] = prepare

fieldSent := false
// sendFinished is set if the response should just be an OK packet.
sendFinished := false
err := handler.ComPrepare(c, queries[0], func(qr *sqltypes.Result) error {
if sendFinished {
// Failsafe: Unreachable if server is well-behaved.
return io.EOF
}

if !fieldSent {
fieldSent = true
if err := c.writePrepare(qr, c.PrepareData[c.StatementID]); err != nil {
return err
}
}

return nil
})
if err != nil {
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr)
return werr
}

delete(c.PrepareData, c.StatementID)
return nil
}
case ComStmtExecute:
queryStart := time.Now()
stmtID, _, err := c.parseComStmtExecute(c.PrepareData, data)
c.recycleReadPacket()
if err != nil {
if stmtID != uint32(0) {
prepare := c.PrepareData[stmtID]
if prepare.BindVars != nil {
for k := range prepare.BindVars {
prepare.BindVars[k] = nil
}
}
}
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr)
return werr
}
return nil
}

fieldSent := false
// sendFinished is set if the response should just be an OK packet.
sendFinished := false
prepare := c.PrepareData[stmtID]
err = handler.ComStmtExecute(c, prepare, func(qr *sqltypes.Result) error {
if sendFinished {
// Failsafe: Unreachable if server is well-behaved.
return io.EOF
}

if !fieldSent {
fieldSent = true

if len(qr.Fields) == 0 {
sendFinished = true
// We should not send any more packets after this.
return c.writeOKPacket(qr.RowsAffected, qr.InsertID, c.StatusFlags, 0)
}
if err := c.writeFields(qr); err != nil {
return err
}
}

return c.writeBinaryRows(qr)
})

if prepare.BindVars != nil {
for k := range prepare.BindVars {
prepare.BindVars[k] = nil
}
}

// If no field was sent, we expect an error.
if !fieldSent {
// This is just a failsafe. Should never happen.
if err == nil || err == io.EOF {
err = NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error"))
}
if werr := c.writeErrorPacketFromError(err); werr != nil {
// If we can't even write the error, we're done.
log.Errorf("Error writing query error to %s: %v", c, werr)
return werr
}
} else {
if err != nil {
// We can't send an error in the middle of a stream.
// All we can do is abort the send, which will cause a 2013.
log.Errorf("Error in the middle of a stream to %s: %v", c, err)
return err
}

// Send the end packet only sendFinished is false (results were streamed).
// In this case the affectedRows and lastInsertID are always 0 since it
// was a read operation.
if !sendFinished {
if err := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); err != nil {
log.Errorf("Error writing result to %s: %v", c, err)
return err
}
}
}

timings.Record(queryTimingKey, queryStart)
case ComStmtSendLongData:
stmtID, paramID, chunkData, ok := c.parseComStmtSendLongData(data)
c.recycleReadPacket()
if !ok {
err := fmt.Errorf("error parsing statement send long data from client %v, returning error: %v", c.ConnectionID, data)
log.Error(err.Error())
return err
}

prepare, ok := c.PrepareData[stmtID]
if !ok {
err := fmt.Errorf("got wrong statement id from client %v, statement ID(%v) is not found from record", c.ConnectionID, stmtID)
log.Error(err.Error())
return err
}

if prepare.BindVars == nil ||
prepare.ParamsCount == uint16(0) ||
paramID >= prepare.ParamsCount {
err := fmt.Errorf("invalid parameter Number from client %v, statement: %v", c.ConnectionID, prepare.PrepareStmt)
log.Error(err.Error())
return err
}

chunk := make([]byte, len(chunkData))
copy(chunk, chunkData)

key := fmt.Sprintf("v%d", paramID+1)
if val, ok := prepare.BindVars[key]; ok {
val.Value = append(val.Value, chunk...)
} else {
v, err := sqltypes.InterfaceToValue(chunk)
if err != nil {
log.Error("build converted parameter value failed: %v", err)
return err
}
prepare.BindVars[key] = sqltypes.ValueBindVariable(v)
}
case ComStmtClose:
stmtID, ok := c.parseComStmtClose(data)
c.recycleReadPacket()
if ok {
delete(c.PrepareData, stmtID)
}
case ComStmtReset:
stmtID, ok := c.parseComStmtReset(data)
c.recycleReadPacket()
if !ok {
log.Error("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data)
if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil {
log.Error("Error writing error packet to client: %v", err)
return err
}
}

prepare, ok := c.PrepareData[stmtID]
if !ok {
log.Error("Commands were executed in an improper order from client %v, packet: %v", c.ConnectionID, data)
if err := c.writeErrorPacket(CRCommandsOutOfSync, SSUnknownComError, "commands were executed in an improper order: %v", data); err != nil {
log.Error("Error writing error packet to client: %v", err)
return err
}
}

if prepare.BindVars != nil {
for k := range prepare.BindVars {
prepare.BindVars[k] = nil
}
}

if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil {
log.Error("Error writing ComStmtReset OK packet to client %v: %v", c.ConnectionID, err)
return err
}
default:
log.Errorf("Got unhandled packet (default) from %s, returning error: %v", c, data)
c.recycleReadPacket()
Expand Down
18 changes: 18 additions & 0 deletions go/mysql/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,24 @@ const (
// ComBinlogDump is COM_BINLOG_DUMP.
ComBinlogDump = 0x12

// ComPrepare is COM_PREPARE.
ComPrepare = 0x16
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is called COM_STMT_PREPARE in the MySQL codebase: https://dev.mysql.com/doc/internals/en/com-stmt-prepare.html


// ComStmtExecute is COM_STMT_EXECUTE.
ComStmtExecute = 0x17

// ComStmtSendLongData is COM_STMT_SEND_LONG_DATA
ComStmtSendLongData = 0x18

// ComStmtClose is COM_STMT_CLOSE.
ComStmtClose = 0x19

// ComStmtReset is COM_STMT_RESET
ComStmtReset = 0x1a

//ComStmtFetch is COM_STMT_FETCH
ComStmtFetch = 0x1c

// ComSetOption is COM_SET_OPTION
ComSetOption = 0x1b

Expand Down
10 changes: 10 additions & 0 deletions go/mysql/fakesqldb/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,16 @@ func (db *DB) comQueryOrdered(query string) (*sqltypes.Result, error) {
return entry.QueryResult, nil
}

// ComPrepare is part of the mysql.Handler interface.
func (db *DB) ComPrepare(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error {
return nil
}

// ComStmtExecute is part of the mysql.Handler interface.
func (db *DB) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error {
return nil
}

//
// Methods to add expected queries and results.
//
Expand Down
Loading