diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 2f5394866b7..1ad1e87a71a 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -155,6 +155,23 @@ type Conn struct { // currentEphemeralBuffer for tracking allocated temporary buffer for writes and reads respectively. // It can be allocated from bufPool or heap and should be recycled in the same manner. currentEphemeralBuffer *[]byte + + // StatementID is the prepared statement ID. + StatementID uint32 + + // PrepareData is the map to use a prepared statement. + PrepareData map[uint32]*PrepareData +} + +// PrepareData is a buffer used for store prepare statement meta data +type PrepareData struct { + StatementID uint32 + PrepareStmt string + ParsedStmt *sqlparser.Statement + ParamsCount uint16 + ParamsType []int32 + ColumnNames []string + BindVars map[string]*querypb.BindVariable } // bufPool is used to allocate and free buffers in an efficient way. @@ -180,9 +197,10 @@ func newConn(conn net.Conn) *Conn { // size for reads. func newServerConn(conn net.Conn, listener *Listener) *Conn { c := &Conn{ - conn: conn, - listener: listener, - closed: sync2.NewAtomicBool(false), + conn: conn, + listener: listener, + closed: sync2.NewAtomicBool(false), + PrepareData: make(map[uint32]*PrepareData), } if listener.connReadBufferSize > 0 { c.bufferedReader = bufio.NewReaderSize(conn, listener.connReadBufferSize) @@ -799,6 +817,223 @@ func (c *Conn) handleNextCommand(handler Handler) error { return err } } + case ComPrepare: + query := c.parseComPrepare(data) + c.recycleReadPacket() + + var queries []string + if c.Capabilities&CapabilityClientMultiStatements != 0 { + queries, err = sqlparser.SplitStatementToPieces(query) + if err != nil { + log.Errorf("Conn %v: Error splitting query: %v", c, err) + if werr := c.writeErrorPacketFromError(err); werr != nil { + // If we can't even write the error, we're done. + log.Errorf("Conn %v: Error writing query error: %v", c, werr) + return werr + } + } + } else { + queries = []string{query} + } + + if len(queries) != 1 { + return fmt.Errorf("can not prepare multiple statements") + } + + c.StatementID++ + prepare := &PrepareData{ + StatementID: c.StatementID, + PrepareStmt: queries[0], + } + + c.PrepareData[c.StatementID] = prepare + + fieldSent := false + // sendFinished is set if the response should just be an OK packet. + sendFinished := false + err := handler.ComPrepare(c, queries[0], func(qr *sqltypes.Result) error { + if sendFinished { + // Failsafe: Unreachable if server is well-behaved. + return io.EOF + } + + if !fieldSent { + fieldSent = true + if err := c.writePrepare(qr, c.PrepareData[c.StatementID]); err != nil { + return err + } + } + + return nil + }) + if err != nil { + if werr := c.writeErrorPacketFromError(err); werr != nil { + // If we can't even write the error, we're done. + log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) + return werr + } + + delete(c.PrepareData, c.StatementID) + return nil + } + case ComStmtExecute: + queryStart := time.Now() + stmtID, _, err := c.parseComStmtExecute(c.PrepareData, data) + c.recycleReadPacket() + if err != nil { + if stmtID != uint32(0) { + prepare := c.PrepareData[stmtID] + if prepare.BindVars != nil { + for k := range prepare.BindVars { + prepare.BindVars[k] = nil + } + } + } + if werr := c.writeErrorPacketFromError(err); werr != nil { + // If we can't even write the error, we're done. + log.Error("Error writing query error to client %v: %v", c.ConnectionID, werr) + return werr + } + return nil + } + + fieldSent := false + // sendFinished is set if the response should just be an OK packet. + sendFinished := false + prepare := c.PrepareData[stmtID] + err = handler.ComStmtExecute(c, prepare, func(qr *sqltypes.Result) error { + if sendFinished { + // Failsafe: Unreachable if server is well-behaved. + return io.EOF + } + + if !fieldSent { + fieldSent = true + + if len(qr.Fields) == 0 { + sendFinished = true + // We should not send any more packets after this. + return c.writeOKPacket(qr.RowsAffected, qr.InsertID, c.StatusFlags, 0) + } + if err := c.writeFields(qr); err != nil { + return err + } + } + + return c.writeBinaryRows(qr) + }) + + if prepare.BindVars != nil { + for k := range prepare.BindVars { + prepare.BindVars[k] = nil + } + } + + // If no field was sent, we expect an error. + if !fieldSent { + // This is just a failsafe. Should never happen. + if err == nil || err == io.EOF { + err = NewSQLErrorFromError(errors.New("unexpected: query ended without no results and no error")) + } + if werr := c.writeErrorPacketFromError(err); werr != nil { + // If we can't even write the error, we're done. + log.Errorf("Error writing query error to %s: %v", c, werr) + return werr + } + } else { + if err != nil { + // We can't send an error in the middle of a stream. + // All we can do is abort the send, which will cause a 2013. + log.Errorf("Error in the middle of a stream to %s: %v", c, err) + return err + } + + // Send the end packet only sendFinished is false (results were streamed). + // In this case the affectedRows and lastInsertID are always 0 since it + // was a read operation. + if !sendFinished { + if err := c.writeEndResult(false, 0, 0, handler.WarningCount(c)); err != nil { + log.Errorf("Error writing result to %s: %v", c, err) + return err + } + } + } + + timings.Record(queryTimingKey, queryStart) + case ComStmtSendLongData: + stmtID, paramID, chunkData, ok := c.parseComStmtSendLongData(data) + c.recycleReadPacket() + if !ok { + err := fmt.Errorf("error parsing statement send long data from client %v, returning error: %v", c.ConnectionID, data) + log.Error(err.Error()) + return err + } + + prepare, ok := c.PrepareData[stmtID] + if !ok { + err := fmt.Errorf("got wrong statement id from client %v, statement ID(%v) is not found from record", c.ConnectionID, stmtID) + log.Error(err.Error()) + return err + } + + if prepare.BindVars == nil || + prepare.ParamsCount == uint16(0) || + paramID >= prepare.ParamsCount { + err := fmt.Errorf("invalid parameter Number from client %v, statement: %v", c.ConnectionID, prepare.PrepareStmt) + log.Error(err.Error()) + return err + } + + chunk := make([]byte, len(chunkData)) + copy(chunk, chunkData) + + key := fmt.Sprintf("v%d", paramID+1) + if val, ok := prepare.BindVars[key]; ok { + val.Value = append(val.Value, chunk...) + } else { + v, err := sqltypes.InterfaceToValue(chunk) + if err != nil { + log.Error("build converted parameter value failed: %v", err) + return err + } + prepare.BindVars[key] = sqltypes.ValueBindVariable(v) + } + case ComStmtClose: + stmtID, ok := c.parseComStmtClose(data) + c.recycleReadPacket() + if ok { + delete(c.PrepareData, stmtID) + } + case ComStmtReset: + stmtID, ok := c.parseComStmtReset(data) + c.recycleReadPacket() + if !ok { + log.Error("Got unhandled packet from client %v, returning error: %v", c.ConnectionID, data) + if err := c.writeErrorPacket(ERUnknownComError, SSUnknownComError, "error handling packet: %v", data); err != nil { + log.Error("Error writing error packet to client: %v", err) + return err + } + } + + prepare, ok := c.PrepareData[stmtID] + if !ok { + log.Error("Commands were executed in an improper order from client %v, packet: %v", c.ConnectionID, data) + if err := c.writeErrorPacket(CRCommandsOutOfSync, SSUnknownComError, "commands were executed in an improper order: %v", data); err != nil { + log.Error("Error writing error packet to client: %v", err) + return err + } + } + + if prepare.BindVars != nil { + for k := range prepare.BindVars { + prepare.BindVars[k] = nil + } + } + + if err := c.writeOKPacket(0, 0, c.StatusFlags, 0); err != nil { + log.Error("Error writing ComStmtReset OK packet to client %v: %v", c.ConnectionID, err) + return err + } default: log.Errorf("Got unhandled packet (default) from %s, returning error: %v", c, data) c.recycleReadPacket() diff --git a/go/mysql/constants.go b/go/mysql/constants.go index 4d1a530a861..b1d4491f637 100644 --- a/go/mysql/constants.go +++ b/go/mysql/constants.go @@ -153,6 +153,24 @@ const ( // ComBinlogDump is COM_BINLOG_DUMP. ComBinlogDump = 0x12 + // ComPrepare is COM_PREPARE. + ComPrepare = 0x16 + + // ComStmtExecute is COM_STMT_EXECUTE. + ComStmtExecute = 0x17 + + // ComStmtSendLongData is COM_STMT_SEND_LONG_DATA + ComStmtSendLongData = 0x18 + + // ComStmtClose is COM_STMT_CLOSE. + ComStmtClose = 0x19 + + // ComStmtReset is COM_STMT_RESET + ComStmtReset = 0x1a + + //ComStmtFetch is COM_STMT_FETCH + ComStmtFetch = 0x1c + // ComSetOption is COM_SET_OPTION ComSetOption = 0x1b diff --git a/go/mysql/fakesqldb/server.go b/go/mysql/fakesqldb/server.go index 60dbb0547ee..77da9ea8905 100644 --- a/go/mysql/fakesqldb/server.go +++ b/go/mysql/fakesqldb/server.go @@ -432,6 +432,16 @@ func (db *DB) comQueryOrdered(query string) (*sqltypes.Result, error) { return entry.QueryResult, nil } +// ComPrepare is part of the mysql.Handler interface. +func (db *DB) ComPrepare(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error { + return nil +} + +// ComStmtExecute is part of the mysql.Handler interface. +func (db *DB) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { + return nil +} + // // Methods to add expected queries and results. // diff --git a/go/mysql/query.go b/go/mysql/query.go index b92ece87b63..5cacc0ad8d1 100644 --- a/go/mysql/query.go +++ b/go/mysql/query.go @@ -17,6 +17,11 @@ limitations under the License. package mysql import ( + "fmt" + "math" + "strconv" + "strings" + "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" @@ -510,6 +515,347 @@ func (c *Conn) parseComSetOption(data []byte) (uint16, bool) { return val, ok } +func (c *Conn) parseComPrepare(data []byte) string { + return string(data[1:]) +} + +func (c *Conn) parseComStmtExecute(prepareData map[uint32]*PrepareData, data []byte) (uint32, byte, error) { + pos := 0 + payload := data[1:] + bitMap := make([]byte, 0) + + // statement ID + stmtID, pos, ok := readUint32(payload, 0) + if !ok { + return 0, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading statement ID failed") + } + prepare, ok := prepareData[stmtID] + if !ok { + return 0, 0, NewSQLError(CRCommandsOutOfSync, SSUnknownSQLState, "statement ID is not found from record") + } + + // cursor type flags + cursorType, pos, ok := readByte(payload, pos) + if !ok { + return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading cursor type flags failed") + } + + // iteration count + iterCount, pos, ok := readUint32(payload, pos) + if !ok { + return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading iteration count failed") + } + if iterCount != uint32(1) { + return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "iteration count is not equal to 1") + } + + if prepare.ParamsCount > 0 { + bitMap, pos, ok = readBytes(payload, pos, int((prepare.ParamsCount+7)/8)) + if !ok { + return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading NULL-bitmap failed") + } + } + + newParamsBoundFlag, pos, ok := readByte(payload, pos) + if newParamsBoundFlag == 0x01 { + var mysqlType, flags byte + for i := uint16(0); i < prepare.ParamsCount; i++ { + mysqlType, pos, ok = readByte(payload, pos) + if !ok { + return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading parameter type failed") + } + + flags, pos, ok = readByte(payload, pos) + if !ok { + return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "reading parameter flags failed") + } + + // convert MySQL type to internal type. + valType, err := sqltypes.MySQLToType(int64(mysqlType), int64(flags)) + if err != nil { + return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "MySQLToType(%v,%v) failed: %v", mysqlType, flags, err) + } + + prepare.ParamsType[i] = int32(valType) + } + } + + for i := 0; i < len(prepare.ParamsType); i++ { + var val sqltypes.Value + parameterID := fmt.Sprintf("v%d", i+1) + if v, ok := prepare.BindVars[parameterID]; ok { + if v != nil { + continue + } + } + + if (bitMap[i/8] & (1 << uint(i%8))) > 0 { + val, pos, ok = c.parseStmtArgs(nil, sqltypes.Null, pos) + } else { + val, pos, ok = c.parseStmtArgs(payload, querypb.Type(prepare.ParamsType[i]), pos) + } + if !ok { + return stmtID, 0, NewSQLError(CRMalformedPacket, SSUnknownSQLState, "decoding parameter value failed: %v", prepare.ParamsType[i]) + } + + prepare.BindVars[parameterID] = sqltypes.ValueBindVariable(val) + } + + return stmtID, cursorType, nil +} + +func (c *Conn) parseStmtArgs(data []byte, typ querypb.Type, pos int) (sqltypes.Value, int, bool) { + switch typ { + case sqltypes.Null: + return sqltypes.NULL, pos, true + case sqltypes.Int8: + val, pos, ok := readByte(data, pos) + return sqltypes.NewInt64(int64(val)), pos, ok + case sqltypes.Uint8: + val, pos, ok := readByte(data, pos) + return sqltypes.NewUint64(uint64(val)), pos, ok + case sqltypes.Uint16: + val, pos, ok := readUint16(data, pos) + return sqltypes.NewUint64(uint64(val)), pos, ok + case sqltypes.Int16, sqltypes.Year: + val, pos, ok := readUint16(data, pos) + return sqltypes.NewInt64(int64(val)), pos, ok + case sqltypes.Uint24, sqltypes.Uint32: + val, pos, ok := readUint32(data, pos) + return sqltypes.NewUint64(uint64(val)), pos, ok + case sqltypes.Int24, sqltypes.Int32: + val, pos, ok := readUint32(data, pos) + return sqltypes.NewInt64(int64(val)), pos, ok + case sqltypes.Float32: + val, pos, ok := readUint32(data, pos) + return sqltypes.NewFloat64(float64(math.Float32frombits(uint32(val)))), pos, ok + case sqltypes.Uint64: + val, pos, ok := readUint64(data, pos) + return sqltypes.NewUint64(val), pos, ok + case sqltypes.Int64: + val, pos, ok := readUint64(data, pos) + return sqltypes.NewInt64(int64(val)), pos, ok + case sqltypes.Float64: + val, pos, ok := readUint64(data, pos) + return sqltypes.NewFloat64(math.Float64frombits(val)), pos, ok + case sqltypes.Timestamp, sqltypes.Date, sqltypes.Datetime: + size, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + switch size { + case 0x00: + return sqltypes.NewVarChar(" "), pos, ok + case 0x0b: + year, pos, ok := readUint16(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + month, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + day, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + hour, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + minute, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + second, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + microSecond, pos, ok := readUint32(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + val := strconv.Itoa(int(year)) + "-" + + strconv.Itoa(int(month)) + "-" + + strconv.Itoa(int(day)) + " " + + strconv.Itoa(int(hour)) + ":" + + strconv.Itoa(int(minute)) + ":" + + strconv.Itoa(int(second)) + "." + + strconv.Itoa(int(microSecond)) + + return sqltypes.NewVarChar(val), pos, ok + case 0x07: + year, pos, ok := readUint16(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + month, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + day, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + hour, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + minute, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + second, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + val := strconv.Itoa(int(year)) + "-" + + strconv.Itoa(int(month)) + "-" + + strconv.Itoa(int(day)) + " " + + strconv.Itoa(int(hour)) + ":" + + strconv.Itoa(int(minute)) + ":" + + strconv.Itoa(int(second)) + + return sqltypes.NewVarChar(val), pos, ok + case 0x04: + year, pos, ok := readUint16(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + month, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + day, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + val := strconv.Itoa(int(year)) + "-" + + strconv.Itoa(int(month)) + "-" + + strconv.Itoa(int(day)) + + return sqltypes.NewVarChar(val), pos, ok + default: + return sqltypes.NULL, 0, false + } + case sqltypes.Time: + size, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + switch size { + case 0x00: + return sqltypes.NewVarChar("00:00:00"), pos, ok + case 0x0c: + isNegative, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + days, pos, ok := readUint32(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + hour, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + + hours := uint32(hour) + days*uint32(24) + + minute, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + second, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + microSecond, pos, ok := readUint32(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + + val := "" + if isNegative == 0x01 { + val += "-" + } + val += strconv.Itoa(int(hours)) + ":" + + strconv.Itoa(int(minute)) + ":" + + strconv.Itoa(int(second)) + "." + + strconv.Itoa(int(microSecond)) + + return sqltypes.NewVarChar(val), pos, ok + case 0x08: + isNegative, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + days, pos, ok := readUint32(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + hour, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + + hours := uint32(hour) + days*uint32(24) + + minute, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + second, pos, ok := readByte(data, pos) + if !ok { + return sqltypes.NULL, 0, false + } + + val := "" + if isNegative == 0x01 { + val += "-" + } + val += strconv.Itoa(int(hours)) + ":" + + strconv.Itoa(int(minute)) + ":" + + strconv.Itoa(int(second)) + + return sqltypes.NewVarChar(val), pos, ok + default: + return sqltypes.NULL, 0, false + } + case sqltypes.Decimal, sqltypes.Text, sqltypes.Blob, sqltypes.VarChar, sqltypes.VarBinary, sqltypes.Char, + sqltypes.Bit, sqltypes.Enum, sqltypes.Set, sqltypes.Geometry, sqltypes.Binary, sqltypes.TypeJSON: + val, pos, ok := readLenEncStringAsBytes(data, pos) + return sqltypes.MakeTrusted(sqltypes.VarBinary, val), pos, ok + default: + return sqltypes.NULL, pos, false + } +} + +func (c *Conn) parseComStmtSendLongData(data []byte) (uint32, uint16, []byte, bool) { + pos := 1 + statementID, pos, ok := readUint32(data, pos) + if !ok { + return 0, 0, nil, false + } + + paramID, pos, ok := readUint16(data, pos) + if !ok { + return 0, 0, nil, false + } + + return statementID, paramID, data[pos:], true +} + +func (c *Conn) parseComStmtClose(data []byte) (uint32, bool) { + val, _, ok := readUint32(data, 1) + return val, ok +} + +func (c *Conn) parseComStmtReset(data []byte) (uint32, bool) { + val, _, ok := readUint32(data, 1) + return val, ok +} + func (c *Conn) parseComInitDB(data []byte) string { return string(data[1:]) } @@ -655,3 +1001,485 @@ func (c *Conn) writeEndResult(more bool, affectedRows, lastInsertID uint64, warn return nil } + +// writePrepare writes a prepare query response to the wire. +func (c *Conn) writePrepare(result *sqltypes.Result, prepare *PrepareData) error { + paramsCount := prepare.ParamsCount + columnCount := 0 + if result != nil { + columnCount = len(result.Fields) + } + if columnCount > 0 { + prepare.ColumnNames = make([]string, columnCount) + } + + data := c.startEphemeralPacket(12) + pos := 0 + + pos = writeByte(data, pos, 0x00) + pos = writeUint32(data, pos, uint32(prepare.StatementID)) + pos = writeUint16(data, pos, uint16(columnCount)) + pos = writeUint16(data, pos, uint16(paramsCount)) + pos = writeByte(data, pos, 0x00) + pos = writeUint16(data, pos, 0x0000) + + if err := c.writeEphemeralPacket(); err != nil { + return err + } + + if paramsCount > 0 { + for i := uint16(0); i < paramsCount; i++ { + if err := c.writeColumnDefinition(&querypb.Field{ + Name: "?", + Type: sqltypes.VarBinary, + Charset: 63}); err != nil { + return err + } + } + + // Now send an EOF packet. + if c.Capabilities&CapabilityClientDeprecateEOF == 0 { + // With CapabilityClientDeprecateEOF, we do not send this EOF. + if err := c.writeEOFPacket(c.StatusFlags, 0); err != nil { + return err + } + } + } + + if result != nil { + // Now send each Field. + for i, field := range result.Fields { + field.Name = strings.Replace(field.Name, "'?'", "?", -1) + prepare.ColumnNames[i] = field.Name + if err := c.writeColumnDefinition(field); err != nil { + return err + } + } + + if columnCount > 0 { + // Now send an EOF packet. + if c.Capabilities&CapabilityClientDeprecateEOF == 0 { + // With CapabilityClientDeprecateEOF, we do not send this EOF. + if err := c.writeEOFPacket(c.StatusFlags, 0); err != nil { + return err + } + } + } + } + + return c.flush() +} + +func (c *Conn) writeBinaryRow(fields []*querypb.Field, row []sqltypes.Value) error { + length := 0 + nullBitMapLen := (len(fields) + 7 + 2) / 8 + for _, val := range row { + if !val.IsNull() { + l, err := val2MySQLLen(val) + if err != nil { + return fmt.Errorf("internal value %v get MySQL value length error: %v", val, err) + } + length += l + } + } + + length += nullBitMapLen + 1 + + data := c.startEphemeralPacket(length) + pos := 0 + + pos = writeByte(data, pos, 0x00) + + for i := 0; i < nullBitMapLen; i++ { + pos = writeByte(data, pos, 0x00) + } + + for i, val := range row { + if val.IsNull() { + bytePos := (i+2)/8 + 1 + bitPos := (i + 2) % 8 + data[bytePos] |= 1 << uint(bitPos) + } else { + v, err := val2MySQL(val) + if err != nil { + return fmt.Errorf("internal value %v to MySQL value error: %v", val, err) + } + pos += copy(data[pos:], v) + } + } + + if pos != length { + return fmt.Errorf("internal error packet row: got %v bytes but expected %v", pos, length) + } + + return c.writeEphemeralPacket() +} + +// writeBinaryRows sends the rows of a Result with binary form. +func (c *Conn) writeBinaryRows(result *sqltypes.Result) error { + for _, row := range result.Rows { + if err := c.writeBinaryRow(result.Fields, row); err != nil { + return err + } + } + return nil +} + +func val2MySQL(v sqltypes.Value) ([]byte, error) { + var out []byte + pos := 0 + switch v.Type() { + case sqltypes.Null: + // no-op + case sqltypes.Int8: + val, err := strconv.ParseInt(v.ToString(), 10, 8) + if err != nil { + return []byte{}, err + } + out = make([]byte, 1) + writeByte(out, pos, uint8(val)) + case sqltypes.Uint8: + val, err := strconv.ParseUint(v.ToString(), 10, 8) + if err != nil { + return []byte{}, err + } + out = make([]byte, 1) + writeByte(out, pos, uint8(val)) + case sqltypes.Uint16: + val, err := strconv.ParseUint(v.ToString(), 10, 16) + if err != nil { + return []byte{}, err + } + out = make([]byte, 2) + writeUint16(out, pos, uint16(val)) + case sqltypes.Int16, sqltypes.Year: + val, err := strconv.ParseInt(v.ToString(), 10, 16) + if err != nil { + return []byte{}, err + } + out = make([]byte, 2) + writeUint16(out, pos, uint16(val)) + case sqltypes.Uint24, sqltypes.Uint32: + val, err := strconv.ParseUint(v.ToString(), 10, 32) + if err != nil { + return []byte{}, err + } + out = make([]byte, 4) + writeUint32(out, pos, uint32(val)) + case sqltypes.Int24, sqltypes.Int32: + val, err := strconv.ParseInt(v.ToString(), 10, 32) + if err != nil { + return []byte{}, err + } + out = make([]byte, 4) + writeUint32(out, pos, uint32(val)) + case sqltypes.Float32: + val, err := strconv.ParseFloat(v.ToString(), 32) + if err != nil { + return []byte{}, err + } + bits := math.Float32bits(float32(val)) + out = make([]byte, 4) + writeUint32(out, pos, bits) + case sqltypes.Uint64: + val, err := strconv.ParseUint(v.ToString(), 10, 64) + if err != nil { + return []byte{}, err + } + out = make([]byte, 8) + writeUint64(out, pos, uint64(val)) + case sqltypes.Int64: + val, err := strconv.ParseInt(v.ToString(), 10, 64) + if err != nil { + return []byte{}, err + } + out = make([]byte, 8) + writeUint64(out, pos, uint64(val)) + case sqltypes.Float64: + val, err := strconv.ParseFloat(v.ToString(), 64) + if err != nil { + return []byte{}, err + } + bits := math.Float64bits(val) + out = make([]byte, 8) + writeUint64(out, pos, bits) + case sqltypes.Timestamp, sqltypes.Date, sqltypes.Datetime: + if len(v.Raw()) > 19 { + out = make([]byte, 1+11) + out[pos] = 0x0b + pos++ + year, err := strconv.ParseUint(string(v.Raw()[0:4]), 10, 16) + if err != nil { + return []byte{}, err + } + month, err := strconv.ParseUint(string(v.Raw()[5:7]), 10, 8) + if err != nil { + return []byte{}, err + } + day, err := strconv.ParseUint(string(v.Raw()[8:10]), 10, 8) + if err != nil { + return []byte{}, err + } + hour, err := strconv.ParseUint(string(v.Raw()[11:13]), 10, 8) + if err != nil { + return []byte{}, err + } + minute, err := strconv.ParseUint(string(v.Raw()[14:16]), 10, 8) + if err != nil { + return []byte{}, err + } + second, err := strconv.ParseUint(string(v.Raw()[17:19]), 10, 8) + if err != nil { + return []byte{}, err + } + val := make([]byte, 6) + count := copy(val, v.Raw()[20:]) + for i := 0; i < (6 - count); i++ { + val[count+i] = 0x30 + } + microSecond, err := strconv.ParseUint(string(val), 10, 32) + if err != nil { + return []byte{}, err + } + pos = writeUint16(out, pos, uint16(year)) + pos = writeByte(out, pos, byte(month)) + pos = writeByte(out, pos, byte(day)) + pos = writeByte(out, pos, byte(hour)) + pos = writeByte(out, pos, byte(minute)) + pos = writeByte(out, pos, byte(second)) + pos = writeUint32(out, pos, uint32(microSecond)) + } else if len(v.Raw()) > 10 { + out = make([]byte, 1+7) + out[pos] = 0x07 + pos++ + year, err := strconv.ParseUint(string(v.Raw()[0:4]), 10, 16) + if err != nil { + return []byte{}, err + } + month, err := strconv.ParseUint(string(v.Raw()[5:7]), 10, 8) + if err != nil { + return []byte{}, err + } + day, err := strconv.ParseUint(string(v.Raw()[8:10]), 10, 8) + if err != nil { + return []byte{}, err + } + hour, err := strconv.ParseUint(string(v.Raw()[11:13]), 10, 8) + if err != nil { + return []byte{}, err + } + minute, err := strconv.ParseUint(string(v.Raw()[14:16]), 10, 8) + if err != nil { + return []byte{}, err + } + second, err := strconv.ParseUint(string(v.Raw()[17:]), 10, 8) + if err != nil { + return []byte{}, err + } + pos = writeUint16(out, pos, uint16(year)) + pos = writeByte(out, pos, byte(month)) + pos = writeByte(out, pos, byte(day)) + pos = writeByte(out, pos, byte(hour)) + pos = writeByte(out, pos, byte(minute)) + pos = writeByte(out, pos, byte(second)) + } else if len(v.Raw()) > 0 { + out = make([]byte, 1+4) + out[pos] = 0x04 + pos++ + year, err := strconv.ParseUint(string(v.Raw()[0:4]), 10, 16) + if err != nil { + return []byte{}, err + } + month, err := strconv.ParseUint(string(v.Raw()[5:7]), 10, 8) + if err != nil { + return []byte{}, err + } + day, err := strconv.ParseUint(string(v.Raw()[8:]), 10, 8) + if err != nil { + return []byte{}, err + } + pos = writeUint16(out, pos, uint16(year)) + pos = writeByte(out, pos, byte(month)) + pos = writeByte(out, pos, byte(day)) + } else { + out = make([]byte, 1) + out[pos] = 0x00 + } + case sqltypes.Time: + if string(v.Raw()) == "00:00:00" { + out = make([]byte, 1) + out[pos] = 0x00 + } else if strings.Contains(string(v.Raw()), ".") { + out = make([]byte, 1+12) + out[pos] = 0x0c + pos++ + + sub1 := strings.Split(string(v.Raw()), ":") + if len(sub1) != 3 { + err := fmt.Errorf("incorrect time value, ':' is not found") + return []byte{}, err + } + sub2 := strings.Split(sub1[2], ".") + if len(sub2) != 2 { + err := fmt.Errorf("incorrect time value, '.' is not found") + return []byte{}, err + } + + var total []byte + if strings.HasPrefix(sub1[0], "-") { + out[pos] = 0x01 + total = []byte(sub1[0]) + total = total[1:] + } else { + out[pos] = 0x00 + total = []byte(sub1[0]) + } + pos++ + + h, err := strconv.ParseUint(string(total), 10, 32) + if err != nil { + return []byte{}, err + } + + days := uint32(h) / 24 + hours := uint32(h) % 24 + minute := sub1[1] + second := sub2[0] + microSecond := sub2[1] + + minutes, err := strconv.ParseUint(minute, 10, 8) + if err != nil { + return []byte{}, err + } + + seconds, err := strconv.ParseUint(second, 10, 8) + if err != nil { + return []byte{}, err + } + pos = writeUint32(out, pos, uint32(days)) + pos = writeByte(out, pos, byte(hours)) + pos = writeByte(out, pos, byte(minutes)) + pos = writeByte(out, pos, byte(seconds)) + + val := make([]byte, 6) + count := copy(val, microSecond) + for i := 0; i < (6 - count); i++ { + val[count+i] = 0x30 + } + microSeconds, err := strconv.ParseUint(string(val), 10, 32) + if err != nil { + return []byte{}, err + } + pos = writeUint32(out, pos, uint32(microSeconds)) + } else if len(v.Raw()) > 0 { + out = make([]byte, 1+8) + out[pos] = 0x08 + pos++ + + sub1 := strings.Split(string(v.Raw()), ":") + if len(sub1) != 3 { + err := fmt.Errorf("incorrect time value, ':' is not found") + return []byte{}, err + } + + var total []byte + if strings.HasPrefix(sub1[0], "-") { + out[pos] = 0x01 + total = []byte(sub1[0]) + total = total[1:] + } else { + out[pos] = 0x00 + total = []byte(sub1[0]) + } + pos++ + + h, err := strconv.ParseUint(string(total), 10, 32) + if err != nil { + return []byte{}, err + } + + days := uint32(h) / 24 + hours := uint32(h) % 24 + minute := sub1[1] + second := sub1[2] + + minutes, err := strconv.ParseUint(minute, 10, 8) + if err != nil { + return []byte{}, err + } + + seconds, err := strconv.ParseUint(second, 10, 8) + if err != nil { + return []byte{}, err + } + pos = writeUint32(out, pos, uint32(days)) + pos = writeByte(out, pos, byte(hours)) + pos = writeByte(out, pos, byte(minutes)) + pos = writeByte(out, pos, byte(seconds)) + } else { + err := fmt.Errorf("incorrect time value") + return []byte{}, err + } + case sqltypes.Decimal, sqltypes.Text, sqltypes.Blob, sqltypes.VarChar, + sqltypes.VarBinary, sqltypes.Char, sqltypes.Bit, sqltypes.Enum, + sqltypes.Set, sqltypes.Geometry, sqltypes.Binary, sqltypes.TypeJSON: + l := len(v.Raw()) + length := lenEncIntSize(uint64(l)) + l + out = make([]byte, length) + pos = writeLenEncInt(out, pos, uint64(l)) + copy(out[pos:], v.Raw()) + default: + out = make([]byte, len(v.Raw())) + copy(out, v.Raw()) + } + return out, nil +} + +func val2MySQLLen(v sqltypes.Value) (int, error) { + var length int + var err error + + switch v.Type() { + case sqltypes.Null: + length = 0 + case sqltypes.Int8, sqltypes.Uint8: + length = 1 + case sqltypes.Uint16, sqltypes.Int16, sqltypes.Year: + length = 2 + case sqltypes.Uint24, sqltypes.Uint32, sqltypes.Int24, sqltypes.Int32, sqltypes.Float32: + length = 4 + case sqltypes.Uint64, sqltypes.Int64, sqltypes.Float64: + length = 8 + case sqltypes.Timestamp, sqltypes.Date, sqltypes.Datetime: + if len(v.Raw()) > 19 { + length = 12 + } else if len(v.Raw()) > 10 { + length = 8 + } else if len(v.Raw()) > 0 { + length = 5 + } else { + length = 1 + } + case sqltypes.Time: + if string(v.Raw()) == "00:00:00" { + length = 1 + } else if strings.Contains(string(v.Raw()), ".") { + length = 13 + } else if len(v.Raw()) > 0 { + length = 9 + } else { + err = fmt.Errorf("incorrect time value") + } + case sqltypes.Decimal, sqltypes.Text, sqltypes.Blob, sqltypes.VarChar, + sqltypes.VarBinary, sqltypes.Char, sqltypes.Bit, sqltypes.Enum, + sqltypes.Set, sqltypes.Geometry, sqltypes.Binary, sqltypes.TypeJSON: + l := len(v.Raw()) + length = lenEncIntSize(uint64(l)) + l + default: + length = len(v.Raw()) + } + if err != nil { + return 0, err + } + return length, nil +} diff --git a/go/mysql/server.go b/go/mysql/server.go index 7538261c259..db28f12bc99 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -94,6 +94,14 @@ type Handler interface { // hang on to the byte slice. ComQuery(c *Conn, query string, callback func(*sqltypes.Result) error) error + // ComPrepare is called when a connection receives a prepared + // statement query. + ComPrepare(c *Conn, query string, callback func(*sqltypes.Result) error) error + + // ComStmtExecute is called when a connection receives a statement + // execute query. + ComStmtExecute(c *Conn, prepare *PrepareData, callback func(*sqltypes.Result) error) error + // WarningCount is called at the end of each query to obtain // the value to be returned to the client in the EOF packet. // Note that this will be called either in the context of the diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index 8895425a7cd..032e67c3bbe 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -171,6 +171,14 @@ func (th *testHandler) ComQuery(c *Conn, query string, callback func(*sqltypes.R return nil } +func (th *testHandler) ComPrepare(c *Conn, query string, callback func(*sqltypes.Result) error) error { + return nil +} + +func (th *testHandler) ComStmtExecute(c *Conn, prepare *PrepareData, callback func(*sqltypes.Result) error) error { + return nil +} + func (th *testHandler) WarningCount(c *Conn) uint16 { return th.warnings } diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index 8ebddca2055..58d2f55c03a 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -1415,3 +1415,143 @@ func buildVarCharRow(values ...string) []sqltypes.Value { } return row } + +// Prepare executes a prepare statements. +func (e *Executor) Prepare(ctx context.Context, method string, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable) (result *sqltypes.Result, err error) { + logStats := NewLogStats(ctx, method, sql, bindVars) + result, err = e.prepare(ctx, safeSession, sql, bindVars, logStats) + logStats.Error = err + + // The mysql plugin runs an implicit rollback whenever a connection closes. + // To avoid spamming the log with no-op rollback records, ignore it if + // it was a no-op record (i.e. didn't issue any queries) + if !(logStats.StmtType == "ROLLBACK" && logStats.ShardQueries == 0) { + logStats.Send() + } + return result, err +} + +func (e *Executor) prepare(ctx context.Context, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, logStats *LogStats) (*sqltypes.Result, error) { + // Start an implicit transaction if necessary. + // TODO(sougou): deprecate legacyMode after all users are migrated out. + if !e.legacyAutocommit && !safeSession.Autocommit && !safeSession.InTransaction() { + if err := e.txConn.Begin(ctx, safeSession); err != nil { + return nil, err + } + } + + destKeyspace, destTabletType, dest, err := e.ParseDestinationTarget(safeSession.TargetString) + if err != nil { + return nil, err + } + + if safeSession.InTransaction() && destTabletType != topodatapb.TabletType_MASTER { + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "transactions are supported only for master tablet types, current type: %v", destTabletType) + } + if bindVars == nil { + bindVars = make(map[string]*querypb.BindVariable) + } + + stmtType := sqlparser.Preview(sql) + logStats.StmtType = sqlparser.StmtType(stmtType) + + // Mysql warnings are scoped to the current session, but are + // cleared when a "non-diagnostic statement" is executed: + // https://dev.mysql.com/doc/refman/8.0/en/show-warnings.html + // + // To emulate this behavior, clear warnings from the session + // for all statements _except_ SHOW, so that SHOW WARNINGS + // can actually return them. + if stmtType != sqlparser.StmtShow { + safeSession.ClearWarnings() + } + + switch stmtType { + case sqlparser.StmtSelect: + return e.handlePrepare(ctx, safeSession, sql, bindVars, destKeyspace, destTabletType, logStats) + case sqlparser.StmtInsert, sqlparser.StmtReplace, sqlparser.StmtUpdate, sqlparser.StmtDelete: + safeSession := safeSession + + // In legacy mode, we ignore autocommit settings. + if e.legacyAutocommit { + return &sqltypes.Result{}, nil + } + + mustCommit := false + if safeSession.Autocommit && !safeSession.InTransaction() { + mustCommit = true + if err := e.txConn.Begin(ctx, safeSession); err != nil { + return nil, err + } + // The defer acts as a failsafe. If commit was successful, + // the rollback will be a no-op. + defer e.txConn.Rollback(ctx, safeSession) + } + + // The SetAutocommitable flag should be same as mustCommit. + // If we started a transaction because of autocommit, then mustCommit + // will be true, which means that we can autocommit. If we were already + // in a transaction, it means that the app started it, or we are being + // called recursively. If so, we cannot autocommit because whatever we + // do is likely not final. + // The control flow is such that autocommitable can only be turned on + // at the beginning, but never after. + safeSession.SetAutocommitable(mustCommit) + + if mustCommit { + commitStart := time.Now() + if err = e.txConn.Commit(ctx, safeSession); err != nil { + return nil, err + } + logStats.CommitTime = time.Since(commitStart) + } + return &sqltypes.Result{}, nil + case sqlparser.StmtDDL, sqlparser.StmtBegin, sqlparser.StmtCommit, sqlparser.StmtRollback, sqlparser.StmtSet, + sqlparser.StmtUse, sqlparser.StmtOther, sqlparser.StmtComment: + return &sqltypes.Result{}, nil + case sqlparser.StmtShow: + return e.handleShow(ctx, safeSession, sql, bindVars, dest, destKeyspace, destTabletType, logStats) + } + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unrecognized statement: %s", sql) +} + +func (e *Executor) handlePrepare(ctx context.Context, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable, destKeyspace string, destTabletType topodatapb.TabletType, logStats *LogStats) (*sqltypes.Result, error) { + // V3 mode. + query, comments := sqlparser.SplitMarginComments(sql) + vcursor := newVCursorImpl(ctx, safeSession, destKeyspace, destTabletType, comments, e, logStats) + plan, err := e.getPlan( + vcursor, + query, + comments, + bindVars, + skipQueryPlanCache(safeSession), + logStats, + ) + execStart := time.Now() + logStats.PlanTime = execStart.Sub(logStats.StartTime) + + if err != nil { + logStats.Error = err + return nil, err + } + + qr, err := plan.Instructions.GetFields(vcursor, bindVars) + logStats.ExecuteTime = time.Since(execStart) + var errCount uint64 + if err != nil { + logStats.Error = err + errCount = 1 + } else { + logStats.RowsAffected = qr.RowsAffected + } + + // Check if there was partial DML execution. If so, rollback the transaction. + if err != nil && safeSession.InTransaction() && vcursor.hasPartialDML { + _ = e.txConn.Rollback(ctx, safeSession) + err = vterrors.Errorf(vtrpcpb.Code_ABORTED, "transaction rolled back due to partial DML execution: %v", err) + } + + plan.AddStats(1, time.Since(logStats.StartTime), uint64(logStats.ShardQueries), logStats.RowsAffected, errCount) + + return qr, err +} diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index e1c0e3408f0..dd0c53432a1 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -21,6 +21,7 @@ import ( "fmt" "net" "os" + "strings" "sync/atomic" "syscall" "time" @@ -33,6 +34,7 @@ import ( "vitess.io/vitess/go/vt/callinfo" "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/servenv" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vttls" querypb "vitess.io/vitess/go/vt/proto/query" @@ -160,6 +162,154 @@ func (vh *vtgateHandler) ComQuery(c *mysql.Conn, query string, callback func(*sq return callback(result) } +// ComPrepare is the handler for command prepare. +func (vh *vtgateHandler) ComPrepare(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error { + var ctx context.Context + var cancel context.CancelFunc + if *mysqlQueryTimeout != 0 { + ctx, cancel = context.WithTimeout(context.Background(), *mysqlQueryTimeout) + defer cancel() + } else { + ctx = context.Background() + } + + ctx = callinfo.MysqlCallInfo(ctx, c) + + // Fill in the ImmediateCallerID with the UserData returned by + // the AuthServer plugin for that user. If nothing was + // returned, use the User. This lets the plugin map a MySQL + // user used for authentication to a Vitess User used for + // Table ACLs and Vitess authentication in general. + im := c.UserData.Get() + ef := callerid.NewEffectiveCallerID( + c.User, /* principal: who */ + c.RemoteAddr().String(), /* component: running client process */ + "VTGate MySQL Connector" /* subcomponent: part of the client */) + ctx = callerid.NewContext(ctx, ef, im) + + session, _ := c.ClientData.(*vtgatepb.Session) + if session == nil { + session = &vtgatepb.Session{ + Options: &querypb.ExecuteOptions{ + IncludedFields: querypb.ExecuteOptions_ALL, + }, + Autocommit: true, + } + if c.Capabilities&mysql.CapabilityClientFoundRows != 0 { + session.Options.ClientFoundRows = true + } + } + + if !session.InTransaction { + atomic.AddInt32(&busyConnections, 1) + } + defer func() { + if !session.InTransaction { + atomic.AddInt32(&busyConnections, -1) + } + }() + + if c.SchemaName != "" { + session.TargetString = c.SchemaName + } + + statement, err := sqlparser.ParseStrictDDL(query) + if err != nil { + err = mysql.NewSQLErrorFromError(err) + return err + } + + paramsCount := uint16(0) + _ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) { + switch node := node.(type) { + case *sqlparser.SQLVal: + if strings.HasPrefix(string(node.Val), ":v") { + paramsCount++ + } + } + return true, nil + }, statement) + + prepare := c.PrepareData[c.StatementID] + prepare.ParsedStmt = &statement + + if paramsCount > 0 { + prepare.ParamsCount = paramsCount + prepare.ParamsType = make([]int32, paramsCount) + prepare.BindVars = make(map[string]*querypb.BindVariable, paramsCount) + } + + session, result, err := vh.vtg.Prepare(ctx, session, query, make(map[string]*querypb.BindVariable)) + c.ClientData = session + err = mysql.NewSQLErrorFromError(err) + if err != nil { + return err + } + return callback(result) +} + +func (vh *vtgateHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { + var ctx context.Context + var cancel context.CancelFunc + if *mysqlQueryTimeout != 0 { + ctx, cancel = context.WithTimeout(context.Background(), *mysqlQueryTimeout) + defer cancel() + } else { + ctx = context.Background() + } + + ctx = callinfo.MysqlCallInfo(ctx, c) + + // Fill in the ImmediateCallerID with the UserData returned by + // the AuthServer plugin for that user. If nothing was + // returned, use the User. This lets the plugin map a MySQL + // user used for authentication to a Vitess User used for + // Table ACLs and Vitess authentication in general. + im := c.UserData.Get() + ef := callerid.NewEffectiveCallerID( + c.User, /* principal: who */ + c.RemoteAddr().String(), /* component: running client process */ + "VTGate MySQL Connector" /* subcomponent: part of the client */) + ctx = callerid.NewContext(ctx, ef, im) + + session, _ := c.ClientData.(*vtgatepb.Session) + if session == nil { + session = &vtgatepb.Session{ + Options: &querypb.ExecuteOptions{ + IncludedFields: querypb.ExecuteOptions_ALL, + }, + Autocommit: true, + } + if c.Capabilities&mysql.CapabilityClientFoundRows != 0 { + session.Options.ClientFoundRows = true + } + } + + if !session.InTransaction { + atomic.AddInt32(&busyConnections, 1) + } + defer func() { + if !session.InTransaction { + atomic.AddInt32(&busyConnections, -1) + } + }() + + if c.SchemaName != "" { + session.TargetString = c.SchemaName + } + if session.Options.Workload == querypb.ExecuteOptions_OLAP { + err := vh.vtg.StreamExecute(ctx, session, prepare.PrepareStmt, prepare.BindVars, callback) + return mysql.NewSQLErrorFromError(err) + } + _, qr, err := vh.vtg.Execute(ctx, session, prepare.PrepareStmt, prepare.BindVars) + if err != nil { + err = mysql.NewSQLErrorFromError(err) + return err + } + + return callback(qr) +} + func (vh *vtgateHandler) WarningCount(c *mysql.Conn) uint16 { session, _ := c.ClientData.(*vtgatepb.Session) if session != nil { diff --git a/go/vt/vtgate/plugin_mysql_server_test.go b/go/vt/vtgate/plugin_mysql_server_test.go index c873206e6ad..07d03fa97b2 100644 --- a/go/vt/vtgate/plugin_mysql_server_test.go +++ b/go/vt/vtgate/plugin_mysql_server_test.go @@ -43,6 +43,14 @@ func (th *testHandler) ComQuery(c *mysql.Conn, q string, callback func(*sqltypes return nil } +func (th *testHandler) ComPrepare(c *mysql.Conn, q string, callback func(*sqltypes.Result) error) error { + return nil +} + +func (th *testHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { + return nil +} + func (th *testHandler) WarningCount(c *mysql.Conn) uint16 { return 0 } diff --git a/go/vt/vtgate/vtgate.go b/go/vt/vtgate/vtgate.go index 2b2e7b81d8c..f02f37433bb 100644 --- a/go/vt/vtgate/vtgate.go +++ b/go/vt/vtgate/vtgate.go @@ -832,6 +832,34 @@ func (vtg *VTGate) ResolveTransaction(ctx context.Context, dtid string) error { return formatError(vtg.txConn.Resolve(ctx, dtid)) } +// Prepare supports non-streaming prepare statement query with multi shards +func (vtg *VTGate) Prepare(ctx context.Context, session *vtgatepb.Session, sql string, bindVariables map[string]*querypb.BindVariable) (newSession *vtgatepb.Session, qr *sqltypes.Result, err error) { + // In this context, we don't care if we can't fully parse destination + destKeyspace, destTabletType, _, _ := vtg.executor.ParseDestinationTarget(session.TargetString) + statsKey := []string{"Execute", destKeyspace, topoproto.TabletTypeLString(destTabletType)} + defer vtg.timings.Record(statsKey, time.Now()) + + if bvErr := sqltypes.ValidateBindVariables(bindVariables); bvErr != nil { + err = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", bvErr) + goto handleError + } + + qr, err = vtg.executor.Prepare(ctx, "Execute", NewSafeSession(session), sql, bindVariables) + if err == nil { + vtg.rowsReturned.Add(statsKey, int64(len(qr.Rows))) + return session, qr, nil + } + +handleError: + query := map[string]interface{}{ + "Sql": sql, + "BindVariables": bindVariables, + "Session": session, + } + err = recordAndAnnotateError(err, statsKey, query, vtg.logExecute) + return session, nil, err +} + // isKeyspaceRangeBasedSharded returns true if a keyspace is sharded // by range. This is true when there is a ShardingColumnType defined // in the SrvKeyspace (that is using the range-based sharding with the diff --git a/go/vt/vtqueryserver/plugin_mysql_server.go b/go/vt/vtqueryserver/plugin_mysql_server.go index 804cd207f8b..6e86a21a172 100644 --- a/go/vt/vtqueryserver/plugin_mysql_server.go +++ b/go/vt/vtqueryserver/plugin_mysql_server.go @@ -136,6 +136,14 @@ func (mh *proxyHandler) WarningCount(c *mysql.Conn) uint16 { return 0 } +func (mh *proxyHandler) ComPrepare(c *mysql.Conn, query string, callback func(*sqltypes.Result) error) error { + return nil +} + +func (mh *proxyHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { + return nil +} + var mysqlListener *mysql.Listener var mysqlUnixListener *mysql.Listener diff --git a/go/vt/vtqueryserver/plugin_mysql_server_test.go b/go/vt/vtqueryserver/plugin_mysql_server_test.go index ca0f6b806cd..936962075e7 100644 --- a/go/vt/vtqueryserver/plugin_mysql_server_test.go +++ b/go/vt/vtqueryserver/plugin_mysql_server_test.go @@ -43,6 +43,14 @@ func (th *testHandler) ComQuery(c *mysql.Conn, q string, callback func(*sqltypes return nil } +func (th *testHandler) ComPrepare(c *mysql.Conn, q string, callback func(*sqltypes.Result) error) error { + return nil +} + +func (th *testHandler) ComStmtExecute(c *mysql.Conn, prepare *mysql.PrepareData, callback func(*sqltypes.Result) error) error { + return nil +} + func (th *testHandler) WarningCount(c *mysql.Conn) uint16 { return 0 }