diff --git a/go/mysql/client.go b/go/mysql/client.go index a76ad0c0300..7785760deab 100644 --- a/go/mysql/client.go +++ b/go/mysql/client.go @@ -177,8 +177,10 @@ func Connect(ctx context.Context, params *ConnParams) (*Conn, error) { func (c *Conn) Ping() error { // This is a new command, need to reset the sequence. c.sequence = 0 + data, pos := c.startEphemeralPacketWithHeader(1) + data[pos] = ComPing - if err := c.writePacket([]byte{ComPing}); err != nil { + if err := c.writeEphemeralPacket(); err != nil { return NewSQLError(CRServerGone, SSUnknownSQLState, "%v", err) } data, err := c.readEphemeralPacket() @@ -542,8 +544,7 @@ func (c *Conn) writeSSLRequest(capabilities uint32, characterSet uint8, params * flags |= CapabilityClientConnectWithDB } - data := c.startEphemeralPacket(length) - pos := 0 + data, pos := c.startEphemeralPacketWithHeader(length) // Client capability flags. pos = writeUint32(data, pos, flags) @@ -605,8 +606,7 @@ func (c *Conn) writeHandshakeResponse41(capabilities uint32, scrambledPassword [ length++ } - data := c.startEphemeralPacket(length) - pos := 0 + data, pos := c.startEphemeralPacketWithHeader(length) // Client capability flags. pos = writeUint32(data, pos, flags) @@ -672,8 +672,7 @@ func parseAuthSwitchRequest(data []byte) (string, []byte, error) { // Returns a SQLError. func (c *Conn) writeClearTextPassword(params *ConnParams) error { length := len(params.Pass) + 1 - data := c.startEphemeralPacket(length) - pos := 0 + data, pos := c.startEphemeralPacketWithHeader(length) pos = writeNullString(data, pos, params.Pass) // Sanity check. if pos != len(data) { @@ -686,8 +685,7 @@ func (c *Conn) writeClearTextPassword(params *ConnParams) error { // Returns a SQLError. func (c *Conn) writeMysqlNativePassword(params *ConnParams, salt []byte) error { scrambledPassword := ScramblePassword(salt, []byte(params.Pass)) - data := c.startEphemeralPacket(len(scrambledPassword)) - pos := 0 + data, pos := c.startEphemeralPacketWithHeader(len(scrambledPassword)) pos += copy(data[pos:], scrambledPassword) // Sanity check. if pos != len(data) { diff --git a/go/mysql/conn.go b/go/mysql/conn.go index 8cb4d04041f..d515ec6c7a2 100644 --- a/go/mysql/conn.go +++ b/go/mysql/conn.go @@ -45,6 +45,10 @@ const ( // connBufferSize is how much we buffer for reading and // writing. It is also how much we allocate for ephemeral buffers. connBufferSize = 16 * 1024 + + // packetHeaderSize is the 4 bytes of header per MySQL packet + // sent over + packetHeaderSize = 4 ) // Constants for how ephemeral buffers were used for reading / writing. @@ -160,7 +164,7 @@ type Conn struct { // Keep track of how and of the buffer we allocated for an // ephemeral packet on the read and write sides. // These fields are used by: - // - startEphemeralPacket / writeEphemeralPacket methods for writes. + // - startEphemeralPacketWithHeader / writeEphemeralPacket methods for writes. // - readEphemeralPacket / recycleReadPacket methods for reads. currentEphemeralPolicy int // currentEphemeralBuffer for tracking allocated temporary buffer for writes and reads respectively. @@ -297,7 +301,7 @@ func (c *Conn) getReader() io.Reader { } func (c *Conn) readHeaderFrom(r io.Reader) (int, error) { - var header [4]byte + var header [packetHeaderSize]byte // Note io.ReadFull will return two different types of errors: // 1. if the socket is already closed, and the go runtime knows it, // then ReadFull will return an error (different than EOF), @@ -511,47 +515,49 @@ func (c *Conn) ReadPacket() ([]byte, error) { // writePacket writes a packet, possibly cutting it into multiple // chunks. Note this is not very efficient, as the client probably // has to build the []byte and that makes a memory copy. -// Try to use startEphemeralPacket/writeEphemeralPacket instead. +// Try to use startEphemeralPacketWithHeader/writeEphemeralPacket instead. // // This method returns a generic error, not a SQLError. func (c *Conn) writePacket(data []byte) error { index := 0 - length := len(data) + dataLength := len(data) - packetHeaderSize w, unget := c.getWriter() defer unget() + var header [packetHeaderSize]byte for { - // Packet length is capped to MaxPacketSize. - packetLength := length - if packetLength > MaxPacketSize { - packetLength = MaxPacketSize + // toBeSent is capped to MaxPacketSize. + toBeSent := dataLength + if toBeSent > MaxPacketSize { + toBeSent = MaxPacketSize } + // save the first 4 bytes of the payload, we will overwrite them with the + // header below + copy(header[0:packetHeaderSize], data[index:index+packetHeaderSize]) + // Compute and write the header. - var header [4]byte - header[0] = byte(packetLength) - header[1] = byte(packetLength >> 8) - header[2] = byte(packetLength >> 16) - header[3] = c.sequence - if n, err := w.Write(header[:]); err != nil { - return vterrors.Wrapf(err, "Write(header) failed") - } else if n != 4 { - return vterrors.Errorf(vtrpc.Code_INTERNAL, "Write(header) returned a short write: %v < 4", n) - } + data[index] = byte(toBeSent) + data[index+1] = byte(toBeSent >> 8) + data[index+2] = byte(toBeSent >> 16) + data[index+3] = c.sequence // Write the body. - if n, err := w.Write(data[index : index+packetLength]); err != nil { + if n, err := w.Write(data[index : index+toBeSent+packetHeaderSize]); err != nil { return vterrors.Wrapf(err, "Write(packet) failed") - } else if n != packetLength { - return vterrors.Errorf(vtrpc.Code_INTERNAL, "Write(packet) returned a short write: %v < %v", n, packetLength) + } else if n != (toBeSent + packetHeaderSize) { + return vterrors.Errorf(vtrpc.Code_INTERNAL, "Write(packet) returned a short write: %v < %v", n, (toBeSent + packetHeaderSize)) } + // restore the first 4 bytes once the network send is done + copy(data[index:index+packetHeaderSize], header[0:packetHeaderSize]) + // Update our state. c.sequence++ - length -= packetLength - if length == 0 { - if packetLength == MaxPacketSize { + dataLength -= toBeSent + if dataLength == 0 { + if toBeSent == MaxPacketSize { // The packet we just sent had exactly // MaxPacketSize size, we need to // sent a zero-size packet too. @@ -561,30 +567,30 @@ func (c *Conn) writePacket(data []byte) error { header[3] = c.sequence if n, err := w.Write(header[:]); err != nil { return vterrors.Wrapf(err, "Write(empty header) failed") - } else if n != 4 { + } else if n != packetHeaderSize { return vterrors.Errorf(vtrpc.Code_INTERNAL, "Write(empty header) returned a short write: %v < 4", n) } c.sequence++ } return nil } - index += packetLength + index += toBeSent } } -func (c *Conn) startEphemeralPacket(length int) []byte { +func (c *Conn) startEphemeralPacketWithHeader(length int) ([]byte, int) { if c.currentEphemeralPolicy != ephemeralUnused { - panic("startEphemeralPacket cannot be used while a packet is already started.") + panic("startEphemeralPacketWithHeader cannot be used while a packet is already started.") } c.currentEphemeralPolicy = ephemeralWrite // get buffer from pool or it'll be allocated if length is too big - c.currentEphemeralBuffer = bufPool.Get(length) - return *c.currentEphemeralBuffer + c.currentEphemeralBuffer = bufPool.Get(length + packetHeaderSize) + return *c.currentEphemeralBuffer, packetHeaderSize } // writeEphemeralPacket writes the packet that was allocated by -// startEphemeralPacket. +// startEphemeralPacketWithHeader. func (c *Conn) writeEphemeralPacket() error { defer c.recycleWritePacket() @@ -622,8 +628,8 @@ func (c *Conn) writeComQuit() error { // This is a new command, need to reset the sequence. c.sequence = 0 - data := c.startEphemeralPacket(1) - data[0] = ComQuit + data, pos := c.startEphemeralPacketWithHeader(1) + data[pos] = ComQuit if err := c.writeEphemeralPacket(); err != nil { return NewSQLError(CRServerGone, SSUnknownSQLState, err.Error()) } @@ -673,8 +679,7 @@ func (c *Conn) writeOKPacket(affectedRows, lastInsertID uint64, flags uint16, wa lenEncIntSize(lastInsertID) + 2 + // flags 2 // warnings - data := c.startEphemeralPacket(length) - pos := 0 + data, pos := c.startEphemeralPacketWithHeader(length) pos = writeByte(data, pos, OKPacket) pos = writeLenEncInt(data, pos, affectedRows) pos = writeLenEncInt(data, pos, lastInsertID) @@ -695,8 +700,7 @@ func (c *Conn) writeOKPacketWithEOFHeader(affectedRows, lastInsertID uint64, fla lenEncIntSize(lastInsertID) + 2 + // flags 2 // warnings - data := c.startEphemeralPacket(length) - pos := 0 + data, pos := c.startEphemeralPacketWithHeader(length) pos = writeByte(data, pos, EOFPacket) pos = writeLenEncInt(data, pos, affectedRows) pos = writeLenEncInt(data, pos, lastInsertID) @@ -712,8 +716,7 @@ func (c *Conn) writeOKPacketWithEOFHeader(affectedRows, lastInsertID uint64, fla func (c *Conn) writeErrorPacket(errorCode uint16, sqlState string, format string, args ...interface{}) error { errorMessage := fmt.Sprintf(format, args...) length := 1 + 2 + 1 + 5 + len(errorMessage) - data := c.startEphemeralPacket(length) - pos := 0 + data, pos := c.startEphemeralPacketWithHeader(length) pos = writeByte(data, pos, ErrPacket) pos = writeUint16(data, pos, errorCode) pos = writeByte(data, pos, '#') @@ -743,8 +746,7 @@ func (c *Conn) writeErrorPacketFromError(err error) error { // doesn't flush (as it is used as part of a query result). func (c *Conn) writeEOFPacket(flags uint16, warnings uint16) error { length := 5 - data := c.startEphemeralPacket(length) - pos := 0 + data, pos := c.startEphemeralPacketWithHeader(length) pos = writeByte(data, pos, EOFPacket) pos = writeUint16(data, pos, warnings) _ = writeUint16(data, pos, flags) diff --git a/go/mysql/conn_test.go b/go/mysql/conn_test.go index 405b23c40d5..d8ab1a8526a 100644 --- a/go/mysql/conn_test.go +++ b/go/mysql/conn_test.go @@ -77,7 +77,12 @@ func useWritePacket(t *testing.T, cConn *Conn, data []byte) { t.Fatalf("%v", x) } }() - if err := cConn.writePacket(data); err != nil { + + dataLen := len(data) + dataWithHeader := make([]byte, packetHeaderSize+dataLen) + copy(dataWithHeader[packetHeaderSize:], data) + + if err := cConn.writePacket(dataWithHeader); err != nil { t.Fatalf("writePacket failed: %v", err) } } @@ -91,8 +96,8 @@ func useWriteEphemeralPacketBuffered(t *testing.T, cConn *Conn, data []byte) { cConn.startWriterBuffering() defer cConn.endWriterBuffering() - buf := cConn.startEphemeralPacket(len(data)) - copy(buf, data) + buf, pos := cConn.startEphemeralPacketWithHeader(len(data)) + copy(buf[pos:], data) if err := cConn.writeEphemeralPacket(); err != nil { t.Fatalf("writeEphemeralPacket(false) failed: %v", err) } @@ -105,8 +110,8 @@ func useWriteEphemeralPacketDirect(t *testing.T, cConn *Conn, data []byte) { } }() - buf := cConn.startEphemeralPacket(len(data)) - copy(buf, data) + buf, pos := cConn.startEphemeralPacketWithHeader(len(data)) + copy(buf[pos:], data) if err := cConn.writeEphemeralPacket(); err != nil { t.Fatalf("writeEphemeralPacket(true) failed: %v", err) } diff --git a/go/mysql/query.go b/go/mysql/query.go index ef37f6eb5ae..37cb8e50e9c 100644 --- a/go/mysql/query.go +++ b/go/mysql/query.go @@ -42,9 +42,10 @@ func (c *Conn) WriteComQuery(query string) error { // This is a new command, need to reset the sequence. c.sequence = 0 - data := c.startEphemeralPacket(len(query) + 1) - data[0] = ComQuery - copy(data[1:], query) + data, pos := c.startEphemeralPacketWithHeader(len(query) + 1) + data[pos] = ComQuery + pos++ + copy(data[pos:], query) if err := c.writeEphemeralPacket(); err != nil { return NewSQLError(CRServerGone, SSUnknownSQLState, err.Error()) } @@ -55,9 +56,10 @@ func (c *Conn) WriteComQuery(query string) error { // Client -> Server. // Returns SQLError(CRServerGone) if it can't. func (c *Conn) writeComInitDB(db string) error { - data := c.startEphemeralPacket(len(db) + 1) - data[0] = ComInitDB - copy(data[1:], db) + data, pos := c.startEphemeralPacketWithHeader(len(db) + 1) + data[pos] = ComInitDB + pos++ + copy(data[pos:], db) if err := c.writeEphemeralPacket(); err != nil { return NewSQLError(CRServerGone, SSUnknownSQLState, err.Error()) } @@ -67,9 +69,10 @@ func (c *Conn) writeComInitDB(db string) error { // writeComSetOption changes the connection's capability of executing multi statements. // Returns SQLError(CRServerGone) if it can't. func (c *Conn) writeComSetOption(operation uint16) error { - data := c.startEphemeralPacket(16 + 1) - data[0] = ComSetOption - writeUint16(data, 1, operation) + data, pos := c.startEphemeralPacketWithHeader(16 + 1) + data[pos] = ComSetOption + pos++ + writeUint16(data, pos, operation) if err := c.writeEphemeralPacket(); err != nil { return NewSQLError(CRServerGone, SSUnknownSQLState, err.Error()) } @@ -861,8 +864,8 @@ func (c *Conn) parseComInitDB(data []byte) string { func (c *Conn) sendColumnCount(count uint64) error { length := lenEncIntSize(count) - data := c.startEphemeralPacket(length) - writeLenEncInt(data, 0, count) + data, pos := c.startEphemeralPacketWithHeader(length) + writeLenEncInt(data, pos, count) return c.writeEphemeralPacket() } @@ -889,8 +892,7 @@ func (c *Conn) writeColumnDefinition(field *querypb.Field) error { flags = int64(field.Flags) } - data := c.startEphemeralPacket(length) - pos := 0 + data, pos := c.startEphemeralPacketWithHeader(length) pos = writeLenEncString(data, pos, "def") // Always the same. pos = writeLenEncString(data, pos, field.Database) @@ -924,8 +926,7 @@ func (c *Conn) writeRow(row []sqltypes.Value) error { } } - data := c.startEphemeralPacket(length) - pos := 0 + data, pos := c.startEphemeralPacketWithHeader(length) for _, val := range row { if val.IsNull() { pos = writeByte(data, pos, NullValue) @@ -936,10 +937,6 @@ func (c *Conn) writeRow(row []sqltypes.Value) error { } } - if pos != length { - return vterrors.Errorf(vtrpc.Code_INTERNAL, "packet row: got %v bytes but expected %v", pos, length) - } - return c.writeEphemeralPacket() } @@ -1012,8 +1009,7 @@ func (c *Conn) writePrepare(fld []*querypb.Field, prepare *PrepareData) error { prepare.ColumnNames = make([]string, columnCount) } - data := c.startEphemeralPacket(12) - pos := 0 + data, pos := c.startEphemeralPacketWithHeader(12) pos = writeByte(data, pos, 0x00) pos = writeUint32(data, pos, uint32(prepare.StatementID)) @@ -1081,8 +1077,7 @@ func (c *Conn) writeBinaryRow(fields []*querypb.Field, row []sqltypes.Value) err length += nullBitMapLen + 1 - data := c.startEphemeralPacket(length) - pos := 0 + data, pos := c.startEphemeralPacketWithHeader(length) pos = writeByte(data, pos, 0x00) @@ -1105,10 +1100,6 @@ func (c *Conn) writeBinaryRow(fields []*querypb.Field, row []sqltypes.Value) err } } - if pos != length { - return fmt.Errorf("internal error packet row: got %v bytes but expected %v", pos, length) - } - return c.writeEphemeralPacket() } diff --git a/go/mysql/query_test.go b/go/mysql/query_test.go index 77be4940d27..f0529a1bc90 100644 --- a/go/mysql/query_test.go +++ b/go/mysql/query_test.go @@ -31,9 +31,9 @@ import ( // Utility function to write sql query as packets to test parseComPrepare func MockQueryPackets(t *testing.T, query string) []byte { - data := make([]byte, len(query)+1) + data := make([]byte, len(query)+1+packetHeaderSize) // Not sure if it makes a difference - pos := 0 + pos := packetHeaderSize pos = writeByte(data, pos, ComPrepare) copy(data[pos:], query) return data diff --git a/go/mysql/replication.go b/go/mysql/replication.go index dcc4c5e20c2..d10925b6ddd 100644 --- a/go/mysql/replication.go +++ b/go/mysql/replication.go @@ -28,8 +28,8 @@ func (c *Conn) WriteComBinlogDump(serverID uint32, binlogFilename string, binlog 2 + // flags 4 + // server-id len(binlogFilename) // binlog-filename - data := c.startEphemeralPacket(length) - pos := writeByte(data, 0, ComBinlogDump) + data, pos := c.startEphemeralPacketWithHeader(length) + pos = writeByte(data, pos, ComBinlogDump) pos = writeUint32(data, pos, binlogPos) pos = writeUint16(data, pos, flags) pos = writeUint32(data, pos, serverID) @@ -53,8 +53,8 @@ func (c *Conn) WriteComBinlogDumpGTID(serverID uint32, binlogFilename string, bi 8 + // binlog-pos 4 + // data-size len(gtidSet) // data - data := c.startEphemeralPacket(length) - pos := writeByte(data, 0, ComBinlogDumpGTID) + data, pos := c.startEphemeralPacketWithHeader(length) + pos = writeByte(data, pos, ComBinlogDumpGTID) pos = writeUint16(data, pos, flags) pos = writeUint32(data, pos, serverID) pos = writeUint32(data, pos, uint32(len(binlogFilename))) diff --git a/go/mysql/server.go b/go/mysql/server.go index 53e405b9f0c..8eb72a2d87e 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -520,8 +520,7 @@ func (c *Conn) writeHandshakeV10(serverVersion string, authServer AuthServer, en 13 + // auth-plugin-data lenNullString(MysqlNativePassword) // auth-plugin-name - data := c.startEphemeralPacket(length) - pos := 0 + data, pos := c.startEphemeralPacketWithHeader(length) // Protocol version. pos = writeByte(data, pos, protocolVersion) @@ -767,8 +766,7 @@ func (c *Conn) writeAuthSwitchRequest(pluginName string, pluginData []byte) erro len(pluginName) + 1 + // 0-terminated pluginName len(pluginData) - data := c.startEphemeralPacket(length) - pos := 0 + data, pos := c.startEphemeralPacketWithHeader(length) // Packet header. pos = writeByte(data, pos, AuthSwitchRequestPacket)