Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions go/mysql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,10 @@ func Connect(ctx context.Context, params *ConnParams) (*Conn, error) {
func (c *Conn) Ping() error {
// This is a new command, need to reset the sequence.
c.sequence = 0
data, pos := c.startEphemeralPacketWithHeader(1)
data[pos] = ComPing

if err := c.writePacket([]byte{ComPing}); err != nil {
if err := c.writeEphemeralPacket(); err != nil {
return NewSQLError(CRServerGone, SSUnknownSQLState, "%v", err)
}
data, err := c.readEphemeralPacket()
Expand Down Expand Up @@ -542,8 +544,7 @@ func (c *Conn) writeSSLRequest(capabilities uint32, characterSet uint8, params *
flags |= CapabilityClientConnectWithDB
}

data := c.startEphemeralPacket(length)
pos := 0
data, pos := c.startEphemeralPacketWithHeader(length)

// Client capability flags.
pos = writeUint32(data, pos, flags)
Expand Down Expand Up @@ -605,8 +606,7 @@ func (c *Conn) writeHandshakeResponse41(capabilities uint32, scrambledPassword [
length++
}

data := c.startEphemeralPacket(length)
pos := 0
data, pos := c.startEphemeralPacketWithHeader(length)

// Client capability flags.
pos = writeUint32(data, pos, flags)
Expand Down Expand Up @@ -672,8 +672,7 @@ func parseAuthSwitchRequest(data []byte) (string, []byte, error) {
// Returns a SQLError.
func (c *Conn) writeClearTextPassword(params *ConnParams) error {
length := len(params.Pass) + 1
data := c.startEphemeralPacket(length)
pos := 0
data, pos := c.startEphemeralPacketWithHeader(length)
pos = writeNullString(data, pos, params.Pass)
// Sanity check.
if pos != len(data) {
Expand All @@ -686,8 +685,7 @@ func (c *Conn) writeClearTextPassword(params *ConnParams) error {
// Returns a SQLError.
func (c *Conn) writeMysqlNativePassword(params *ConnParams, salt []byte) error {
scrambledPassword := ScramblePassword(salt, []byte(params.Pass))
data := c.startEphemeralPacket(len(scrambledPassword))
pos := 0
data, pos := c.startEphemeralPacketWithHeader(len(scrambledPassword))
pos += copy(data[pos:], scrambledPassword)
// Sanity check.
if pos != len(data) {
Expand Down
84 changes: 43 additions & 41 deletions go/mysql/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ const (
// connBufferSize is how much we buffer for reading and
// writing. It is also how much we allocate for ephemeral buffers.
connBufferSize = 16 * 1024

// packetHeaderSize is the 4 bytes of header per MySQL packet
// sent over
packetHeaderSize = 4
)

// Constants for how ephemeral buffers were used for reading / writing.
Expand Down Expand Up @@ -160,7 +164,7 @@ type Conn struct {
// Keep track of how and of the buffer we allocated for an
// ephemeral packet on the read and write sides.
// These fields are used by:
// - startEphemeralPacket / writeEphemeralPacket methods for writes.
// - startEphemeralPacketWithHeader / writeEphemeralPacket methods for writes.
// - readEphemeralPacket / recycleReadPacket methods for reads.
currentEphemeralPolicy int
// currentEphemeralBuffer for tracking allocated temporary buffer for writes and reads respectively.
Expand Down Expand Up @@ -297,7 +301,7 @@ func (c *Conn) getReader() io.Reader {
}

func (c *Conn) readHeaderFrom(r io.Reader) (int, error) {
var header [4]byte
var header [packetHeaderSize]byte
// Note io.ReadFull will return two different types of errors:
// 1. if the socket is already closed, and the go runtime knows it,
// then ReadFull will return an error (different than EOF),
Expand Down Expand Up @@ -511,47 +515,49 @@ func (c *Conn) ReadPacket() ([]byte, error) {
// writePacket writes a packet, possibly cutting it into multiple
// chunks. Note this is not very efficient, as the client probably
// has to build the []byte and that makes a memory copy.
// Try to use startEphemeralPacket/writeEphemeralPacket instead.
// Try to use startEphemeralPacketWithHeader/writeEphemeralPacket instead.
//
// This method returns a generic error, not a SQLError.
func (c *Conn) writePacket(data []byte) error {
index := 0
length := len(data)
dataLength := len(data) - packetHeaderSize

w, unget := c.getWriter()
defer unget()

var header [packetHeaderSize]byte
for {
// Packet length is capped to MaxPacketSize.
packetLength := length
if packetLength > MaxPacketSize {
packetLength = MaxPacketSize
// toBeSent is capped to MaxPacketSize.
toBeSent := dataLength
if toBeSent > MaxPacketSize {
toBeSent = MaxPacketSize
}

// save the first 4 bytes of the payload, we will overwrite them with the
// header below
copy(header[0:packetHeaderSize], data[index:index+packetHeaderSize])

// Compute and write the header.
var header [4]byte
header[0] = byte(packetLength)
header[1] = byte(packetLength >> 8)
header[2] = byte(packetLength >> 16)
header[3] = c.sequence
if n, err := w.Write(header[:]); err != nil {
return vterrors.Wrapf(err, "Write(header) failed")
} else if n != 4 {
return vterrors.Errorf(vtrpc.Code_INTERNAL, "Write(header) returned a short write: %v < 4", n)
}
data[index] = byte(toBeSent)
data[index+1] = byte(toBeSent >> 8)
data[index+2] = byte(toBeSent >> 16)
data[index+3] = c.sequence

// Write the body.
if n, err := w.Write(data[index : index+packetLength]); err != nil {
if n, err := w.Write(data[index : index+toBeSent+packetHeaderSize]); err != nil {
return vterrors.Wrapf(err, "Write(packet) failed")
} else if n != packetLength {
return vterrors.Errorf(vtrpc.Code_INTERNAL, "Write(packet) returned a short write: %v < %v", n, packetLength)
} else if n != (toBeSent + packetHeaderSize) {
return vterrors.Errorf(vtrpc.Code_INTERNAL, "Write(packet) returned a short write: %v < %v", n, (toBeSent + packetHeaderSize))
}

// restore the first 4 bytes once the network send is done
copy(data[index:index+packetHeaderSize], header[0:packetHeaderSize])

// Update our state.
c.sequence++
length -= packetLength
if length == 0 {
if packetLength == MaxPacketSize {
dataLength -= toBeSent
if dataLength == 0 {
if toBeSent == MaxPacketSize {
// The packet we just sent had exactly
// MaxPacketSize size, we need to
// sent a zero-size packet too.
Expand All @@ -561,30 +567,30 @@ func (c *Conn) writePacket(data []byte) error {
header[3] = c.sequence
if n, err := w.Write(header[:]); err != nil {
return vterrors.Wrapf(err, "Write(empty header) failed")
} else if n != 4 {
} else if n != packetHeaderSize {
return vterrors.Errorf(vtrpc.Code_INTERNAL, "Write(empty header) returned a short write: %v < 4", n)
}
c.sequence++
}
return nil
}
index += packetLength
index += toBeSent
}
}

func (c *Conn) startEphemeralPacket(length int) []byte {
func (c *Conn) startEphemeralPacketWithHeader(length int) ([]byte, int) {
if c.currentEphemeralPolicy != ephemeralUnused {
panic("startEphemeralPacket cannot be used while a packet is already started.")
panic("startEphemeralPacketWithHeader cannot be used while a packet is already started.")
}

c.currentEphemeralPolicy = ephemeralWrite
// get buffer from pool or it'll be allocated if length is too big
c.currentEphemeralBuffer = bufPool.Get(length)
return *c.currentEphemeralBuffer
c.currentEphemeralBuffer = bufPool.Get(length + packetHeaderSize)
return *c.currentEphemeralBuffer, packetHeaderSize
}

// writeEphemeralPacket writes the packet that was allocated by
// startEphemeralPacket.
// startEphemeralPacketWithHeader.
func (c *Conn) writeEphemeralPacket() error {
defer c.recycleWritePacket()

Expand Down Expand Up @@ -622,8 +628,8 @@ func (c *Conn) writeComQuit() error {
// This is a new command, need to reset the sequence.
c.sequence = 0

data := c.startEphemeralPacket(1)
data[0] = ComQuit
data, pos := c.startEphemeralPacketWithHeader(1)
data[pos] = ComQuit
if err := c.writeEphemeralPacket(); err != nil {
return NewSQLError(CRServerGone, SSUnknownSQLState, err.Error())
}
Expand Down Expand Up @@ -673,8 +679,7 @@ func (c *Conn) writeOKPacket(affectedRows, lastInsertID uint64, flags uint16, wa
lenEncIntSize(lastInsertID) +
2 + // flags
2 // warnings
data := c.startEphemeralPacket(length)
pos := 0
data, pos := c.startEphemeralPacketWithHeader(length)
pos = writeByte(data, pos, OKPacket)
pos = writeLenEncInt(data, pos, affectedRows)
pos = writeLenEncInt(data, pos, lastInsertID)
Expand All @@ -695,8 +700,7 @@ func (c *Conn) writeOKPacketWithEOFHeader(affectedRows, lastInsertID uint64, fla
lenEncIntSize(lastInsertID) +
2 + // flags
2 // warnings
data := c.startEphemeralPacket(length)
pos := 0
data, pos := c.startEphemeralPacketWithHeader(length)
pos = writeByte(data, pos, EOFPacket)
pos = writeLenEncInt(data, pos, affectedRows)
pos = writeLenEncInt(data, pos, lastInsertID)
Expand All @@ -712,8 +716,7 @@ func (c *Conn) writeOKPacketWithEOFHeader(affectedRows, lastInsertID uint64, fla
func (c *Conn) writeErrorPacket(errorCode uint16, sqlState string, format string, args ...interface{}) error {
errorMessage := fmt.Sprintf(format, args...)
length := 1 + 2 + 1 + 5 + len(errorMessage)
data := c.startEphemeralPacket(length)
pos := 0
data, pos := c.startEphemeralPacketWithHeader(length)
pos = writeByte(data, pos, ErrPacket)
pos = writeUint16(data, pos, errorCode)
pos = writeByte(data, pos, '#')
Expand Down Expand Up @@ -743,8 +746,7 @@ func (c *Conn) writeErrorPacketFromError(err error) error {
// doesn't flush (as it is used as part of a query result).
func (c *Conn) writeEOFPacket(flags uint16, warnings uint16) error {
length := 5
data := c.startEphemeralPacket(length)
pos := 0
data, pos := c.startEphemeralPacketWithHeader(length)
pos = writeByte(data, pos, EOFPacket)
pos = writeUint16(data, pos, warnings)
_ = writeUint16(data, pos, flags)
Expand Down
15 changes: 10 additions & 5 deletions go/mysql/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,12 @@ func useWritePacket(t *testing.T, cConn *Conn, data []byte) {
t.Fatalf("%v", x)
}
}()
if err := cConn.writePacket(data); err != nil {

dataLen := len(data)
dataWithHeader := make([]byte, packetHeaderSize+dataLen)
copy(dataWithHeader[packetHeaderSize:], data)

if err := cConn.writePacket(dataWithHeader); err != nil {
t.Fatalf("writePacket failed: %v", err)
}
}
Expand All @@ -91,8 +96,8 @@ func useWriteEphemeralPacketBuffered(t *testing.T, cConn *Conn, data []byte) {
cConn.startWriterBuffering()
defer cConn.endWriterBuffering()

buf := cConn.startEphemeralPacket(len(data))
copy(buf, data)
buf, pos := cConn.startEphemeralPacketWithHeader(len(data))
copy(buf[pos:], data)
if err := cConn.writeEphemeralPacket(); err != nil {
t.Fatalf("writeEphemeralPacket(false) failed: %v", err)
}
Expand All @@ -105,8 +110,8 @@ func useWriteEphemeralPacketDirect(t *testing.T, cConn *Conn, data []byte) {
}
}()

buf := cConn.startEphemeralPacket(len(data))
copy(buf, data)
buf, pos := cConn.startEphemeralPacketWithHeader(len(data))
copy(buf[pos:], data)
if err := cConn.writeEphemeralPacket(); err != nil {
t.Fatalf("writeEphemeralPacket(true) failed: %v", err)
}
Expand Down
45 changes: 18 additions & 27 deletions go/mysql/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ func (c *Conn) WriteComQuery(query string) error {
// This is a new command, need to reset the sequence.
c.sequence = 0

data := c.startEphemeralPacket(len(query) + 1)
data[0] = ComQuery
copy(data[1:], query)
data, pos := c.startEphemeralPacketWithHeader(len(query) + 1)
data[pos] = ComQuery
pos++
copy(data[pos:], query)
if err := c.writeEphemeralPacket(); err != nil {
return NewSQLError(CRServerGone, SSUnknownSQLState, err.Error())
}
Expand All @@ -55,9 +56,10 @@ func (c *Conn) WriteComQuery(query string) error {
// Client -> Server.
// Returns SQLError(CRServerGone) if it can't.
func (c *Conn) writeComInitDB(db string) error {
data := c.startEphemeralPacket(len(db) + 1)
data[0] = ComInitDB
copy(data[1:], db)
data, pos := c.startEphemeralPacketWithHeader(len(db) + 1)
data[pos] = ComInitDB
pos++
copy(data[pos:], db)
if err := c.writeEphemeralPacket(); err != nil {
return NewSQLError(CRServerGone, SSUnknownSQLState, err.Error())
}
Expand All @@ -67,9 +69,10 @@ func (c *Conn) writeComInitDB(db string) error {
// writeComSetOption changes the connection's capability of executing multi statements.
// Returns SQLError(CRServerGone) if it can't.
func (c *Conn) writeComSetOption(operation uint16) error {
data := c.startEphemeralPacket(16 + 1)
data[0] = ComSetOption
writeUint16(data, 1, operation)
data, pos := c.startEphemeralPacketWithHeader(16 + 1)
data[pos] = ComSetOption
pos++
writeUint16(data, pos, operation)
if err := c.writeEphemeralPacket(); err != nil {
return NewSQLError(CRServerGone, SSUnknownSQLState, err.Error())
}
Expand Down Expand Up @@ -861,8 +864,8 @@ func (c *Conn) parseComInitDB(data []byte) string {

func (c *Conn) sendColumnCount(count uint64) error {
length := lenEncIntSize(count)
data := c.startEphemeralPacket(length)
writeLenEncInt(data, 0, count)
data, pos := c.startEphemeralPacketWithHeader(length)
writeLenEncInt(data, pos, count)
return c.writeEphemeralPacket()
}

Expand All @@ -889,8 +892,7 @@ func (c *Conn) writeColumnDefinition(field *querypb.Field) error {
flags = int64(field.Flags)
}

data := c.startEphemeralPacket(length)
pos := 0
data, pos := c.startEphemeralPacketWithHeader(length)

pos = writeLenEncString(data, pos, "def") // Always the same.
pos = writeLenEncString(data, pos, field.Database)
Expand Down Expand Up @@ -924,8 +926,7 @@ func (c *Conn) writeRow(row []sqltypes.Value) error {
}
}

data := c.startEphemeralPacket(length)
pos := 0
data, pos := c.startEphemeralPacketWithHeader(length)
for _, val := range row {
if val.IsNull() {
pos = writeByte(data, pos, NullValue)
Expand All @@ -936,10 +937,6 @@ func (c *Conn) writeRow(row []sqltypes.Value) error {
}
}

if pos != length {
return vterrors.Errorf(vtrpc.Code_INTERNAL, "packet row: got %v bytes but expected %v", pos, length)
}

return c.writeEphemeralPacket()
}

Expand Down Expand Up @@ -1012,8 +1009,7 @@ func (c *Conn) writePrepare(fld []*querypb.Field, prepare *PrepareData) error {
prepare.ColumnNames = make([]string, columnCount)
}

data := c.startEphemeralPacket(12)
pos := 0
data, pos := c.startEphemeralPacketWithHeader(12)

pos = writeByte(data, pos, 0x00)
pos = writeUint32(data, pos, uint32(prepare.StatementID))
Expand Down Expand Up @@ -1081,8 +1077,7 @@ func (c *Conn) writeBinaryRow(fields []*querypb.Field, row []sqltypes.Value) err

length += nullBitMapLen + 1

data := c.startEphemeralPacket(length)
pos := 0
data, pos := c.startEphemeralPacketWithHeader(length)

pos = writeByte(data, pos, 0x00)

Expand All @@ -1105,10 +1100,6 @@ func (c *Conn) writeBinaryRow(fields []*querypb.Field, row []sqltypes.Value) err
}
}

if pos != length {
return fmt.Errorf("internal error packet row: got %v bytes but expected %v", pos, length)
}

return c.writeEphemeralPacket()
}

Expand Down
Loading